"""Interactive GUI for editing per-frame IVUS contour annotations.""" from __future__ import annotations import json import os from dataclasses import dataclass from datetime import datetime from glob import glob from typing import Any import matplotlib.pyplot as plt import numpy as np from matplotlib.lines import Line2D from matplotlib.widgets import Button, Slider from ..io.dicom import read_dicom @dataclass class FrameState: """In-memory annotation state for one frame.""" frame: int lumen_x: list[float] lumen_y: list[float] plaque_x: list[float] plaque_y: list[float] lumen_confidence: float | None plaque_confidence: float | None bifurcation: bool = False class AnnotationEditor: """Matplotlib-based editor for contour annotations.""" def __init__( self, dicom_path: str, annotations_path: str, edits_path: str, ) -> None: self.dicom_path = dicom_path self.annotations_path = annotations_path self.edits_path = edits_path _, images = read_dicom(dicom_path) self.images = images self.frame_states = self._load_base_annotations(annotations_path) if len(self.frame_states) != images.shape[0]: raise ValueError( "Frame count mismatch between DICOM and annotations: " f"{images.shape[0]} vs {len(self.frame_states)}" ) self._load_existing_edits(edits_path) self._ensure_edits_file_header() self.current_frame = 0 self.drag_target: tuple[str, int] | None = None self.drag_threshold_px = 10.0 self.frame_dirty = False self.fig = None self.ax_image = None self.image_artist = None self.lumen_line: Line2D | None = None self.plaque_line: Line2D | None = None self.lumen_points = None self.plaque_points = None self.slider: Slider | None = None self.status_text = None self.is_updating_slider = False @staticmethod def _safe_float(value: Any) -> float | None: if value is None: return None try: out = float(value) except (TypeError, ValueError): return None if np.isnan(out): return None return out def _load_base_annotations(self, path: str) -> list[FrameState]: states: list[FrameState] = [] with open(path, "r", encoding="utf-8") as fp: for raw in fp: line = raw.strip() if not line: continue rec = json.loads(line) if rec.get("record_type") != "frame": continue states.append( FrameState( frame=int(rec["frame"]), lumen_x=[float(v) for v in rec.get("lumen", {}).get("x", [])], lumen_y=[float(v) for v in rec.get("lumen", {}).get("y", [])], plaque_x=[float(v) for v in rec.get("plaque", {}).get("x", [])], plaque_y=[float(v) for v in rec.get("plaque", {}).get("y", [])], lumen_confidence=self._safe_float(rec.get("lumen_confidence")), plaque_confidence=self._safe_float(rec.get("plaque_confidence")), bifurcation=bool(rec.get("bifurcation", False)), ) ) states.sort(key=lambda s: s.frame) expected = list(range(len(states))) actual = [s.frame for s in states] if actual != expected: raise ValueError( "Base annotations must contain contiguous frame records in order " f"(got first frames: {actual[:10]})" ) return states def _load_existing_edits(self, path: str) -> None: if not os.path.exists(path): return latest: dict[int, dict[str, Any]] = {} with open(path, "r", encoding="utf-8") as fp: for raw in fp: line = raw.strip() if not line: continue rec = json.loads(line) if rec.get("record_type") != "frame_edit": continue frame_idx = int(rec["frame"]) latest[frame_idx] = rec for frame_idx, rec in latest.items(): if frame_idx < 0 or frame_idx >= len(self.frame_states): continue state = self.frame_states[frame_idx] state.lumen_x = [float(v) for v in rec.get("lumen", {}).get("x", state.lumen_x)] state.lumen_y = [float(v) for v in rec.get("lumen", {}).get("y", state.lumen_y)] state.plaque_x = [float(v) for v in rec.get("plaque", {}).get("x", state.plaque_x)] state.plaque_y = [float(v) for v in rec.get("plaque", {}).get("y", state.plaque_y)] state.bifurcation = bool(rec.get("bifurcation", state.bifurcation)) def _ensure_edits_file_header(self) -> None: os.makedirs(os.path.dirname(os.path.abspath(self.edits_path)), exist_ok=True) if os.path.exists(self.edits_path): return meta = { "record_type": "meta", "created_at": datetime.now().isoformat(timespec="seconds"), "dicom_path": os.path.abspath(self.dicom_path), "base_annotations_path": os.path.abspath(self.annotations_path), "format": "append_only_frame_edits", } with open(self.edits_path, "w", encoding="utf-8") as fp: fp.write(json.dumps(meta) + "\n") def _append_frame_edit(self, frame_idx: int, reason: str) -> None: state = self.frame_states[frame_idx] rec = { "record_type": "frame_edit", "saved_at": datetime.now().isoformat(timespec="seconds"), "reason": reason, "frame": frame_idx, "bifurcation": state.bifurcation, "lumen": {"x": state.lumen_x, "y": state.lumen_y}, "plaque": {"x": state.plaque_x, "y": state.plaque_y}, "lumen_confidence": state.lumen_confidence, "plaque_confidence": state.plaque_confidence, } with open(self.edits_path, "a", encoding="utf-8") as fp: fp.write(json.dumps(rec) + "\n") fp.flush() os.fsync(fp.fileno()) self.frame_dirty = False def _build_ui(self) -> None: self.fig = plt.figure(figsize=(13, 9)) self.fig.canvas.manager.set_window_title("DeepIVUS Annotation Editor") self.ax_image = self.fig.add_axes([0.05, 0.22, 0.9, 0.74]) self.ax_image.set_title("DeepIVUS Annotation Editor", fontsize=14, weight="bold") self.ax_image.set_axis_off() slider_ax = self.fig.add_axes([0.12, 0.13, 0.76, 0.035]) self.slider = Slider( ax=slider_ax, label="Frame", valmin=0, valmax=len(self.frame_states) - 1, valinit=0, valstep=1, color="#1f77b4", ) prev_ax = self.fig.add_axes([0.12, 0.05, 0.1, 0.055]) next_ax = self.fig.add_axes([0.24, 0.05, 0.1, 0.055]) save_ax = self.fig.add_axes([0.42, 0.05, 0.14, 0.055]) bif_ax = self.fig.add_axes([0.72, 0.05, 0.16, 0.055]) self.prev_button = Button(prev_ax, "Prev Frame", color="#E0E0E0", hovercolor="#D0D0D0") self.next_button = Button(next_ax, "Next Frame", color="#E0E0E0", hovercolor="#D0D0D0") self.save_button = Button(save_ax, "Save Frame", color="#D6F5D6", hovercolor="#BFF0BF") self.bif_button = Button(bif_ax, "Bifurcation: No", color="#F5D6D6", hovercolor="#F0BFBF") self.status_text = self.fig.text(0.05, 0.18, "", fontsize=10) self.slider.on_changed(self._on_slider) self.prev_button.on_clicked(self._on_prev) self.next_button.on_clicked(self._on_next) self.save_button.on_clicked(self._on_save) self.bif_button.on_clicked(self._on_toggle_bifurcation) self.fig.canvas.mpl_connect("button_press_event", self._on_press) self.fig.canvas.mpl_connect("motion_notify_event", self._on_motion) self.fig.canvas.mpl_connect("button_release_event", self._on_release) self.fig.canvas.mpl_connect("key_press_event", self._on_key) self.fig.canvas.mpl_connect("close_event", self._on_close) def _state(self) -> FrameState: return self.frame_states[self.current_frame] def _render_frame(self) -> None: state = self._state() image = self.images[self.current_frame] if self.image_artist is None: self.image_artist = self.ax_image.imshow(image, cmap="gray") self.lumen_line = self.ax_image.plot([], [], color="#1db954", lw=2)[0] self.plaque_line = self.ax_image.plot([], [], color="#ff5a5a", lw=2)[0] self.lumen_points = self.ax_image.scatter([], [], c="#1db954", s=28, edgecolors="black", linewidths=0.4) self.plaque_points = self.ax_image.scatter([], [], c="#ff5a5a", s=28, edgecolors="black", linewidths=0.4) else: self.image_artist.set_data(image) lumen_x, lumen_y = state.lumen_x, state.lumen_y plaque_x, plaque_y = state.plaque_x, state.plaque_y self.lumen_line.set_data(lumen_x + lumen_x[:1], lumen_y + lumen_y[:1]) self.plaque_line.set_data(plaque_x + plaque_x[:1], plaque_y + plaque_y[:1]) lumen_offsets = np.c_[lumen_x, lumen_y] if lumen_x and lumen_y else np.empty((0, 2)) plaque_offsets = np.c_[plaque_x, plaque_y] if plaque_x and plaque_y else np.empty((0, 2)) self.lumen_points.set_offsets(lumen_offsets) self.plaque_points.set_offsets(plaque_offsets) bif_text = "Yes" if state.bifurcation else "No" bif_color = "#D6F5D6" if state.bifurcation else "#F5D6D6" self.bif_button.label.set_text(f"Bifurcation: {bif_text}") self.bif_button.ax.set_facecolor(bif_color) self.status_text.set_text( f"Frame {self.current_frame + 1}/{len(self.frame_states)} " f"Lumen pts: {len(lumen_x)} Plaque pts: {len(plaque_x)} " f"Autosave file: {os.path.basename(self.edits_path)}" ) self.fig.canvas.draw_idle() def _set_frame(self, frame_idx: int) -> None: frame_idx = int(np.clip(frame_idx, 0, len(self.frame_states) - 1)) if frame_idx == self.current_frame: return if self.frame_dirty: self._append_frame_edit(self.current_frame, reason="frame_change") self.current_frame = frame_idx self.is_updating_slider = True self.slider.set_val(frame_idx) self.is_updating_slider = False self._render_frame() def _nearest_point(self, x: float, y: float) -> tuple[str, int] | None: state = self._state() def best_idx(xs: list[float], ys: list[float]) -> tuple[int, float] | None: if not xs: return None pts = np.column_stack((np.asarray(xs), np.asarray(ys))) dist = np.linalg.norm(pts - np.asarray([x, y]), axis=1) idx = int(np.argmin(dist)) return idx, float(dist[idx]) lumen = best_idx(state.lumen_x, state.lumen_y) plaque = best_idx(state.plaque_x, state.plaque_y) choice: tuple[str, int, float] | None = None if lumen is not None: choice = ("lumen", lumen[0], lumen[1]) if plaque is not None and (choice is None or plaque[1] < choice[2]): choice = ("plaque", plaque[0], plaque[1]) if choice is None or choice[2] > self.drag_threshold_px: return None return choice[0], choice[1] def _on_slider(self, val: float) -> None: if self.is_updating_slider: return self._set_frame(int(val)) def _on_prev(self, _event: Any) -> None: self._set_frame(self.current_frame - 1) def _on_next(self, _event: Any) -> None: self._set_frame(self.current_frame + 1) def _on_save(self, _event: Any) -> None: self._append_frame_edit(self.current_frame, reason="manual_save") self._render_frame() def _on_toggle_bifurcation(self, _event: Any) -> None: state = self._state() state.bifurcation = not state.bifurcation self.frame_dirty = True self._append_frame_edit(self.current_frame, reason="bifurcation_toggle") self._render_frame() def _on_key(self, event: Any) -> None: if event.key in {"left", "a"}: self._set_frame(self.current_frame - 1) elif event.key in {"right", "d"}: self._set_frame(self.current_frame + 1) elif event.key == "s": self._on_save(event) elif event.key == "b": self._on_toggle_bifurcation(event) def _on_press(self, event: Any) -> None: if event.inaxes != self.ax_image or event.button != 1: return if event.xdata is None or event.ydata is None: return hit = self._nearest_point(float(event.xdata), float(event.ydata)) self.drag_target = hit def _on_motion(self, event: Any) -> None: if self.drag_target is None: return if event.inaxes != self.ax_image or event.xdata is None or event.ydata is None: return state = self._state() contour_name, idx = self.drag_target h, w = self.images[self.current_frame].shape x = float(np.clip(event.xdata, 0, w - 1)) y = float(np.clip(event.ydata, 0, h - 1)) if contour_name == "lumen": state.lumen_x[idx] = x state.lumen_y[idx] = y else: state.plaque_x[idx] = x state.plaque_y[idx] = y self.frame_dirty = True self._render_frame() def _on_release(self, _event: Any) -> None: if self.drag_target is None: return self.drag_target = None if self.frame_dirty: self._append_frame_edit(self.current_frame, reason="point_drag") self._render_frame() def _on_close(self, _event: Any) -> None: if self.frame_dirty: self._append_frame_edit(self.current_frame, reason="window_close") def show(self) -> None: self._build_ui() self._render_frame() plt.show() def _default_annotations_path(dicom_path: str, output_root: str) -> str: stem = os.path.splitext(os.path.basename(dicom_path))[0] pattern = os.path.join(output_root, "*", f"{stem}_contours.jsonl") matches = [p for p in glob(pattern) if os.path.isfile(p)] if matches: matches.sort(key=lambda p: os.path.getmtime(p), reverse=True) return matches[0] raise FileNotFoundError( "No pipeline contour JSONL found for " f"'{stem}'. Expected files like: {pattern}" ) def _default_edits_path(annotations_path: str) -> str: if annotations_path.endswith("_contours.jsonl"): return annotations_path.replace("_contours.jsonl", "_edited_annotations.jsonl") root, _ = os.path.splitext(annotations_path) return root + "_edited_annotations.jsonl" def launch_annotation_editor( dicom_path: str, annotations_path: str | None = None, edits_path: str | None = None, output_root: str = "output", ) -> None: """Launch interactive annotation GUI. If ``annotations_path`` is not provided, the latest contours JSONL is selected from ``output//``. """ if annotations_path is None: annotations_path = _default_annotations_path(dicom_path, output_root) if edits_path is None: edits_path = _default_edits_path(annotations_path) editor = AnnotationEditor( dicom_path=dicom_path, annotations_path=annotations_path, edits_path=edits_path, ) editor.show()