# Source: ivus-segmentation/deepivus/gui/annotation_editor.py
# Uploaded by Aditya2162 using huggingface_hub (commit 1d197a4, verified).
"""Interactive GUI for editing per-frame IVUS contour annotations."""
from __future__ import annotations
import json
import os
from dataclasses import dataclass
from datetime import datetime
from glob import glob
from typing import Any
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.lines import Line2D
from matplotlib.widgets import Button, Slider
from ..io.dicom import read_dicom
@dataclass
class FrameState:
    """Editable annotation record for a single frame.

    The contour coordinates are stored as parallel lists: the i-th entry of
    an ``*_x`` list pairs with the i-th entry of the matching ``*_y`` list.
    """

    # Index of the frame this record belongs to.
    frame: int
    # Lumen contour point coordinates.
    lumen_x: list[float]
    lumen_y: list[float]
    # Plaque contour point coordinates.
    plaque_x: list[float]
    plaque_y: list[float]
    # Confidence values carried along with the contours; ``None`` when absent.
    lumen_confidence: float | None
    plaque_confidence: float | None
    # Whether the frame is flagged as a bifurcation.
    bifurcation: bool = False
class AnnotationEditor:
    """Matplotlib-based editor for contour annotations.

    One frame is shown at a time together with its lumen and plaque
    contours.  Contour points can be dragged with the left mouse button and
    a per-frame bifurcation flag can be toggled.  Every committed change is
    appended to ``edits_path`` as one JSON line; the base annotations file is
    only ever read.
    """

    # Bifurcation flag -> (button label, face colour, hover colour).
    _BIFURCATION_STYLES = {
        False: ("Bifurcation: No", "#F5D6D6", "#F0BFBF"),
        True: ("Bifurcation: Yes", "#D6F5D6", "#BFF0BF"),
    }

    def __init__(
        self,
        dicom_path: str,
        annotations_path: str,
        edits_path: str,
    ) -> None:
        """Load images, base annotations and any previously saved edits.

        Args:
            dicom_path: Path handed to ``read_dicom``; the second element of
                its return value is used as the frame stack.
            annotations_path: JSONL file holding the base ``frame`` records.
            edits_path: Append-only JSONL file that receives frame edits. It
                is created (with a ``meta`` header line) when missing.

        Raises:
            ValueError: If the annotations are not contiguous from frame 0 or
                their count differs from the number of image frames.
        """
        self.dicom_path = dicom_path
        self.annotations_path = annotations_path
        self.edits_path = edits_path

        _, images = read_dicom(dicom_path)
        self.images = images

        self.frame_states = self._load_base_annotations(annotations_path)
        if len(self.frame_states) != images.shape[0]:
            raise ValueError(
                "Frame count mismatch between DICOM and annotations: "
                f"{images.shape[0]} vs {len(self.frame_states)}"
            )

        # Replay earlier edits on top of the base annotations before the
        # header check, so an existing edits file is read as-is.
        self._load_existing_edits(edits_path)
        self._ensure_edits_file_header()

        # Interaction state.
        self.current_frame = 0
        self.drag_target: tuple[str, int] | None = None
        self.drag_threshold_px = 10.0
        self.frame_dirty = False

        # UI handles; populated by ``_build_ui`` / ``_render_frame``.
        self.fig = None
        self.ax_image = None
        self.image_artist = None
        self.lumen_line: Line2D | None = None
        self.plaque_line: Line2D | None = None
        self.lumen_points = None
        self.plaque_points = None
        self.slider: Slider | None = None
        self.prev_button: Button | None = None
        self.next_button: Button | None = None
        self.save_button: Button | None = None
        self.bif_button: Button | None = None
        self.status_text = None
        self.is_updating_slider = False

    @staticmethod
    def _safe_float(value: Any) -> float | None:
        """Convert ``value`` to ``float``; return ``None`` if missing, invalid or NaN."""
        if value is None:
            return None
        try:
            out = float(value)
        except (TypeError, ValueError):
            return None
        if np.isnan(out):
            return None
        return out

    @classmethod
    def _bifurcation_style(cls, bifurcation: bool) -> tuple[str, str, str]:
        """Return ``(label, face colour, hover colour)`` for the bifurcation button."""
        return cls._BIFURCATION_STYLES[bool(bifurcation)]

    def _load_base_annotations(self, path: str) -> list[FrameState]:
        """Read the base ``frame`` records from a JSONL file.

        Blank lines and records whose ``record_type`` is not ``"frame"`` are
        skipped.  The result is sorted by frame index.

        Raises:
            ValueError: If the sorted frame indices are not exactly
                ``0, 1, ..., n - 1``.
        """
        states: list[FrameState] = []
        with open(path, "r", encoding="utf-8") as fp:
            for raw in fp:
                line = raw.strip()
                if not line:
                    continue
                rec = json.loads(line)
                if rec.get("record_type") != "frame":
                    continue
                lumen = rec.get("lumen", {})
                plaque = rec.get("plaque", {})
                states.append(
                    FrameState(
                        frame=int(rec["frame"]),
                        lumen_x=[float(v) for v in lumen.get("x", [])],
                        lumen_y=[float(v) for v in lumen.get("y", [])],
                        plaque_x=[float(v) for v in plaque.get("x", [])],
                        plaque_y=[float(v) for v in plaque.get("y", [])],
                        lumen_confidence=self._safe_float(rec.get("lumen_confidence")),
                        plaque_confidence=self._safe_float(rec.get("plaque_confidence")),
                        bifurcation=bool(rec.get("bifurcation", False)),
                    )
                )

        states.sort(key=lambda s: s.frame)
        expected = list(range(len(states)))
        actual = [s.frame for s in states]
        if actual != expected:
            raise ValueError(
                "Base annotations must contain contiguous frame records in order "
                f"(got first frames: {actual[:10]})"
            )
        return states

    def _load_existing_edits(self, path: str) -> None:
        """Apply previously saved ``frame_edit`` records to ``frame_states``.

        Does nothing when ``path`` does not exist.  When a frame has several
        edit records, only the last one in the file is applied.  Edits for
        frame indices outside the loaded range are ignored.
        """
        if not os.path.exists(path):
            return

        latest: dict[int, dict[str, Any]] = {}
        with open(path, "r", encoding="utf-8") as fp:
            for raw in fp:
                line = raw.strip()
                if not line:
                    continue
                rec = json.loads(line)
                if rec.get("record_type") != "frame_edit":
                    continue
                # Later lines overwrite earlier ones for the same frame.
                latest[int(rec["frame"])] = rec

        for frame_idx, rec in latest.items():
            if frame_idx < 0 or frame_idx >= len(self.frame_states):
                continue
            state = self.frame_states[frame_idx]
            lumen = rec.get("lumen", {})
            plaque = rec.get("plaque", {})
            # Missing keys fall back to the current (base) values.
            state.lumen_x = [float(v) for v in lumen.get("x", state.lumen_x)]
            state.lumen_y = [float(v) for v in lumen.get("y", state.lumen_y)]
            state.plaque_x = [float(v) for v in plaque.get("x", state.plaque_x)]
            state.plaque_y = [float(v) for v in plaque.get("y", state.plaque_y)]
            state.bifurcation = bool(rec.get("bifurcation", state.bifurcation))

    def _ensure_edits_file_header(self) -> None:
        """Create the edits file with a ``meta`` header line if it is missing.

        The parent directory is created when needed; an existing file is left
        untouched.
        """
        os.makedirs(os.path.dirname(os.path.abspath(self.edits_path)), exist_ok=True)
        if os.path.exists(self.edits_path):
            return
        meta = {
            "record_type": "meta",
            "created_at": datetime.now().isoformat(timespec="seconds"),
            "dicom_path": os.path.abspath(self.dicom_path),
            "base_annotations_path": os.path.abspath(self.annotations_path),
            "format": "append_only_frame_edits",
        }
        with open(self.edits_path, "w", encoding="utf-8") as fp:
            fp.write(json.dumps(meta) + "\n")

    def _append_frame_edit(self, frame_idx: int, reason: str) -> None:
        """Append the current state of ``frame_idx`` to the edits file.

        The record is flushed and fsynced before returning, after which the
        dirty flag is cleared.

        Args:
            frame_idx: Index into ``frame_states`` of the frame to persist.
            reason: Free-text tag stored in the record's ``reason`` field.
        """
        state = self.frame_states[frame_idx]
        rec = {
            "record_type": "frame_edit",
            "saved_at": datetime.now().isoformat(timespec="seconds"),
            "reason": reason,
            "frame": frame_idx,
            "bifurcation": state.bifurcation,
            "lumen": {"x": state.lumen_x, "y": state.lumen_y},
            "plaque": {"x": state.plaque_x, "y": state.plaque_y},
            "lumen_confidence": state.lumen_confidence,
            "plaque_confidence": state.plaque_confidence,
        }
        with open(self.edits_path, "a", encoding="utf-8") as fp:
            fp.write(json.dumps(rec) + "\n")
            fp.flush()
            os.fsync(fp.fileno())
        self.frame_dirty = False

    def _build_ui(self) -> None:
        """Create the figure, widgets and event bindings."""
        self.fig = plt.figure(figsize=(13, 9))
        self.fig.canvas.manager.set_window_title("DeepIVUS Annotation Editor")

        self.ax_image = self.fig.add_axes([0.05, 0.22, 0.9, 0.74])
        self.ax_image.set_title("DeepIVUS Annotation Editor", fontsize=14, weight="bold")
        self.ax_image.set_axis_off()

        slider_ax = self.fig.add_axes([0.12, 0.13, 0.76, 0.035])
        self.slider = Slider(
            ax=slider_ax,
            label="Frame",
            valmin=0,
            valmax=len(self.frame_states) - 1,
            valinit=0,
            valstep=1,
            color="#1f77b4",
        )

        prev_ax = self.fig.add_axes([0.12, 0.05, 0.1, 0.055])
        next_ax = self.fig.add_axes([0.24, 0.05, 0.1, 0.055])
        save_ax = self.fig.add_axes([0.42, 0.05, 0.14, 0.055])
        bif_ax = self.fig.add_axes([0.72, 0.05, 0.16, 0.055])

        self.prev_button = Button(prev_ax, "Prev Frame", color="#E0E0E0", hovercolor="#D0D0D0")
        self.next_button = Button(next_ax, "Next Frame", color="#E0E0E0", hovercolor="#D0D0D0")
        self.save_button = Button(save_ax, "Save Frame", color="#D6F5D6", hovercolor="#BFF0BF")
        bif_label, bif_color, bif_hover = self._bifurcation_style(False)
        self.bif_button = Button(bif_ax, bif_label, color=bif_color, hovercolor=bif_hover)

        self.status_text = self.fig.text(0.05, 0.18, "", fontsize=10)

        self.slider.on_changed(self._on_slider)
        self.prev_button.on_clicked(self._on_prev)
        self.next_button.on_clicked(self._on_next)
        self.save_button.on_clicked(self._on_save)
        self.bif_button.on_clicked(self._on_toggle_bifurcation)

        canvas = self.fig.canvas
        canvas.mpl_connect("button_press_event", self._on_press)
        canvas.mpl_connect("motion_notify_event", self._on_motion)
        canvas.mpl_connect("button_release_event", self._on_release)
        canvas.mpl_connect("key_press_event", self._on_key)
        canvas.mpl_connect("close_event", self._on_close)

    def _state(self) -> FrameState:
        """Return the annotation state of the frame currently displayed."""
        return self.frame_states[self.current_frame]

    def _render_frame(self) -> None:
        """Redraw image, contours, bifurcation button and status line."""
        state = self._state()
        image = self.images[self.current_frame]

        if self.image_artist is None:
            # First draw: create the artists once, then only update their data.
            self.image_artist = self.ax_image.imshow(image, cmap="gray")
            self.lumen_line = self.ax_image.plot([], [], color="#1db954", lw=2)[0]
            self.plaque_line = self.ax_image.plot([], [], color="#ff5a5a", lw=2)[0]
            self.lumen_points = self.ax_image.scatter(
                [], [], c="#1db954", s=28, edgecolors="black", linewidths=0.4
            )
            self.plaque_points = self.ax_image.scatter(
                [], [], c="#ff5a5a", s=28, edgecolors="black", linewidths=0.4
            )
        else:
            self.image_artist.set_data(image)

        lumen_x, lumen_y = state.lumen_x, state.lumen_y
        plaque_x, plaque_y = state.plaque_x, state.plaque_y

        # Repeat the first point at the end so the outline is drawn closed.
        self.lumen_line.set_data(lumen_x + lumen_x[:1], lumen_y + lumen_y[:1])
        self.plaque_line.set_data(plaque_x + plaque_x[:1], plaque_y + plaque_y[:1])

        lumen_offsets = np.c_[lumen_x, lumen_y] if lumen_x and lumen_y else np.empty((0, 2))
        plaque_offsets = np.c_[plaque_x, plaque_y] if plaque_x and plaque_y else np.empty((0, 2))
        self.lumen_points.set_offsets(lumen_offsets)
        self.plaque_points.set_offsets(plaque_offsets)

        bif_label, bif_color, bif_hover = self._bifurcation_style(state.bifurcation)
        self.bif_button.label.set_text(bif_label)
        # Button repaints its axes from ``color``/``hovercolor`` on mouse
        # motion, so these must be updated too or the state colour is lost
        # as soon as the pointer moves.
        self.bif_button.color = bif_color
        self.bif_button.hovercolor = bif_hover
        self.bif_button.ax.set_facecolor(bif_color)

        self.status_text.set_text(
            f"Frame {self.current_frame + 1}/{len(self.frame_states)} "
            f"Lumen pts: {len(lumen_x)} Plaque pts: {len(plaque_x)} "
            f"Autosave file: {os.path.basename(self.edits_path)}"
        )
        self.fig.canvas.draw_idle()

    def _set_frame(self, frame_idx: int) -> None:
        """Switch to ``frame_idx`` (clamped to the valid range) and redraw.

        Unsaved changes of the frame being left are appended to the edits
        file first.  Nothing happens when the clamped index equals the
        current frame.
        """
        frame_idx = int(np.clip(frame_idx, 0, len(self.frame_states) - 1))
        if frame_idx == self.current_frame:
            return
        if self.frame_dirty:
            self._append_frame_edit(self.current_frame, reason="frame_change")
        self.current_frame = frame_idx

        # Guard against re-entering through the slider's change callback.
        self.is_updating_slider = True
        self.slider.set_val(frame_idx)
        self.is_updating_slider = False
        self._render_frame()

    def _nearest_point(self, x: float, y: float) -> tuple[str, int] | None:
        """Find the contour point closest to ``(x, y)`` on the current frame.

        Returns:
            ``(contour_name, point_index)`` with ``contour_name`` being
            ``"lumen"`` or ``"plaque"``, or ``None`` when there are no points
            or the closest one is farther than ``drag_threshold_px``.  On an
            exact distance tie the lumen point wins.
        """
        state = self._state()

        def best_idx(xs: list[float], ys: list[float]) -> tuple[int, float] | None:
            if not xs:
                return None
            pts = np.column_stack((np.asarray(xs), np.asarray(ys)))
            dist = np.linalg.norm(pts - np.asarray([x, y]), axis=1)
            idx = int(np.argmin(dist))
            return idx, float(dist[idx])

        lumen = best_idx(state.lumen_x, state.lumen_y)
        plaque = best_idx(state.plaque_x, state.plaque_y)

        choice: tuple[str, int, float] | None = None
        if lumen is not None:
            choice = ("lumen", lumen[0], lumen[1])
        if plaque is not None and (choice is None or plaque[1] < choice[2]):
            choice = ("plaque", plaque[0], plaque[1])

        if choice is None or choice[2] > self.drag_threshold_px:
            return None
        return choice[0], choice[1]

    def _clamp_to_image(self, x: float, y: float) -> tuple[float, float]:
        """Clamp ``(x, y)`` to the pixel extent of the current frame.

        Only the first two image dimensions are used as ``(height, width)``,
        so frames with a trailing channel axis are handled as well.
        """
        height, width = self.images[self.current_frame].shape[:2]
        return float(np.clip(x, 0, width - 1)), float(np.clip(y, 0, height - 1))

    def _on_slider(self, val: float) -> None:
        """Slider callback; ignored while ``_set_frame`` is moving the slider."""
        if self.is_updating_slider:
            return
        self._set_frame(int(val))

    def _on_prev(self, _event: Any) -> None:
        """Go to the previous frame."""
        self._set_frame(self.current_frame - 1)

    def _on_next(self, _event: Any) -> None:
        """Go to the next frame."""
        self._set_frame(self.current_frame + 1)

    def _on_save(self, _event: Any) -> None:
        """Append the current frame to the edits file unconditionally."""
        self._append_frame_edit(self.current_frame, reason="manual_save")
        self._render_frame()

    def _on_toggle_bifurcation(self, _event: Any) -> None:
        """Flip the bifurcation flag of the current frame and save immediately."""
        state = self._state()
        state.bifurcation = not state.bifurcation
        self.frame_dirty = True
        self._append_frame_edit(self.current_frame, reason="bifurcation_toggle")
        self._render_frame()

    def _on_key(self, event: Any) -> None:
        """Keyboard shortcuts: left/a, right/d navigate; s saves; b toggles bifurcation."""
        if event.key in {"left", "a"}:
            self._set_frame(self.current_frame - 1)
        elif event.key in {"right", "d"}:
            self._set_frame(self.current_frame + 1)
        elif event.key == "s":
            self._on_save(event)
        elif event.key == "b":
            self._on_toggle_bifurcation(event)

    def _on_press(self, event: Any) -> None:
        """Start dragging the contour point under a left click, if any."""
        if event.inaxes is not self.ax_image or event.button != 1:
            return
        if event.xdata is None or event.ydata is None:
            return
        self.drag_target = self._nearest_point(float(event.xdata), float(event.ydata))

    def _on_motion(self, event: Any) -> None:
        """Move the dragged point to the cursor, clamped to the image."""
        if self.drag_target is None:
            return
        if event.inaxes is not self.ax_image or event.xdata is None or event.ydata is None:
            return

        state = self._state()
        contour_name, idx = self.drag_target
        x, y = self._clamp_to_image(event.xdata, event.ydata)

        if contour_name == "lumen":
            state.lumen_x[idx] = x
            state.lumen_y[idx] = y
        else:
            state.plaque_x[idx] = x
            state.plaque_y[idx] = y

        self.frame_dirty = True
        self._render_frame()

    def _on_release(self, _event: Any) -> None:
        """Finish a drag; persist the frame if the drag changed anything."""
        if self.drag_target is None:
            return
        self.drag_target = None
        if self.frame_dirty:
            self._append_frame_edit(self.current_frame, reason="point_drag")
        self._render_frame()

    def _on_close(self, _event: Any) -> None:
        """Persist pending changes when the window is closed."""
        if self.frame_dirty:
            self._append_frame_edit(self.current_frame, reason="window_close")

    def show(self) -> None:
        """Build the UI, draw the first frame and start the matplotlib event loop."""
        self._build_ui()
        self._render_frame()
        plt.show()
def _default_annotations_path(dicom_path: str, output_root: str) -> str:
stem = os.path.splitext(os.path.basename(dicom_path))[0]
pattern = os.path.join(output_root, "*", f"{stem}_contours.jsonl")
matches = [p for p in glob(pattern) if os.path.isfile(p)]
if matches:
matches.sort(key=lambda p: os.path.getmtime(p), reverse=True)
return matches[0]
raise FileNotFoundError(
"No pipeline contour JSONL found for "
f"'{stem}'. Expected files like: {pattern}"
)
def _default_edits_path(annotations_path: str) -> str:
if annotations_path.endswith("_contours.jsonl"):
return annotations_path.replace("_contours.jsonl", "_edited_annotations.jsonl")
root, _ = os.path.splitext(annotations_path)
return root + "_edited_annotations.jsonl"
def launch_annotation_editor(
    dicom_path: str,
    annotations_path: str | None = None,
    edits_path: str | None = None,
    output_root: str = "output",
) -> None:
    """Open the interactive annotation editor for a DICOM file.

    Args:
        dicom_path: DICOM file to display.
        annotations_path: Base contour JSONL. When ``None``, the most
            recently modified ``<stem>_contours.jsonl`` found under
            ``<output_root>/*/`` is used.
        edits_path: Append-only edits JSONL. When ``None``, it is derived
            from the annotations path.
        output_root: Root directory searched for pipeline outputs when
            ``annotations_path`` is not given.
    """
    resolved_annotations = (
        _default_annotations_path(dicom_path, output_root)
        if annotations_path is None
        else annotations_path
    )
    resolved_edits = (
        _default_edits_path(resolved_annotations)
        if edits_path is None
        else edits_path
    )
    AnnotationEditor(
        dicom_path=dicom_path,
        annotations_path=resolved_annotations,
        edits_path=resolved_edits,
    ).show()