#!/usr/bin/env python3 """GUI to label bifurcation yes/no for sampled frames across videos. Designed for files like: - evals/frame_bank_merged/new_bifurcation_samples_300.jsonl """ from __future__ import annotations import argparse import json import os import tempfile from dataclasses import dataclass from pathlib import Path from typing import Any import matplotlib.pyplot as plt import numpy as np from matplotlib.widgets import Button, Slider from deepivus.io.dicom import read_dicom @dataclass class Sample: sample_id: str group: str dicom_stem: str dicom_path: Path frame: int bifurcation: bool | None def _read_jsonl(path: Path) -> list[dict[str, Any]]: rows = [] with path.open("r", encoding="utf-8") as fp: for raw in fp: line = raw.strip() if not line: continue rows.append(json.loads(line)) return rows def _write_jsonl_atomic(path: Path, rows: list[dict[str, Any]]) -> None: path.parent.mkdir(parents=True, exist_ok=True) with tempfile.NamedTemporaryFile("w", encoding="utf-8", dir=path.parent, delete=False) as tmp: tmp_path = Path(tmp.name) for rec in rows: tmp.write(json.dumps(rec) + "\n") tmp.flush() os.fsync(tmp.fileno()) os.replace(tmp_path, path) def _load_samples(path: Path) -> tuple[dict[str, Any], list[Sample]]: rows = _read_jsonl(path) if not rows: raise ValueError(f"Empty file: {path}") meta = rows[0] samples: list[Sample] = [] for rec in rows[1:]: if rec.get("record_type") != "frame": continue samples.append( Sample( sample_id=str(rec["sample_id"]), group=str(rec["group"]), dicom_stem=str(rec["dicom_stem"]), dicom_path=Path(rec["dicom_path"]), frame=int(rec["frame"]), bifurcation=rec.get("bifurcation", None), ) ) if not samples: raise ValueError(f"No frame records in {path}") return meta, samples def _save_samples(path: Path, meta: dict[str, Any], samples: list[Sample]) -> None: rows: list[dict[str, Any]] = [meta] for s in samples: rows.append( { "record_type": "frame", "sample_id": s.sample_id, "group": s.group, "dicom_stem": s.dicom_stem, "dicom_path": str(s.dicom_path), "frame": int(s.frame), "bifurcation": s.bifurcation, } ) _write_jsonl_atomic(path, rows) class BifurcationSampleEditor: def __init__(self, samples_path: Path) -> None: self.samples_path = samples_path self.meta, self.samples = _load_samples(samples_path) self.current = 0 self.is_updating_slider = False self.image_cache: dict[Path, np.ndarray] = {} self.fig = None self.ax_image = None self.image_artist = None self.slider = None self.status_text = None self.toggle_button = None self.save_button = None def _sample(self) -> Sample: return self.samples[self.current] def _load_frame(self, sample: Sample) -> np.ndarray: if sample.dicom_path not in self.image_cache: _, stack = read_dicom(str(sample.dicom_path)) self.image_cache[sample.dicom_path] = stack stack = self.image_cache[sample.dicom_path] if sample.frame < 0 or sample.frame >= stack.shape[0]: raise IndexError( f"Frame {sample.frame} out of range for {sample.dicom_path} (num_frames={stack.shape[0]})" ) return stack[sample.frame] def _autosave(self) -> None: _save_samples(self.samples_path, self.meta, self.samples) def _set_current(self, idx: int) -> None: self.current = int(np.clip(idx, 0, len(self.samples) - 1)) self.is_updating_slider = True self.slider.set_val(self.current) self.is_updating_slider = False self._render() def _render(self) -> None: s = self._sample() frame = self._load_frame(s) if self.image_artist is None: self.image_artist = self.ax_image.imshow(frame, cmap="gray") else: self.image_artist.set_data(frame) label = s.bifurcation if label is True: txt = "Bifurcation: Yes" color = "#D6F5D6" elif label is False: txt = "Bifurcation: No" color = "#F5D6D6" else: txt = "Bifurcation: Unset" color = "#F5F1D6" self.toggle_button.label.set_text(txt) self.toggle_button.ax.set_facecolor(color) done = sum(1 for x in self.samples if x.bifurcation is not None) self.status_text.set_text( f"Sample {self.current + 1}/{len(self.samples)} ({done} labeled) " f"{s.sample_id} frame={s.frame} " "Keys: left/right navigate, y=yes, n=no, u=unset, s=save" ) self.fig.canvas.draw_idle() def _set_label(self, value: bool | None) -> None: self._sample().bifurcation = value self._autosave() self._render() def _on_slider(self, val: float) -> None: if self.is_updating_slider: return self._set_current(int(val)) def _on_prev(self, _event: Any) -> None: self._set_current(self.current - 1) def _on_next(self, _event: Any) -> None: self._set_current(self.current + 1) def _on_toggle(self, _event: Any) -> None: cur = self._sample().bifurcation if cur is None: self._set_label(True) else: self._set_label(not cur) def _on_save(self, _event: Any) -> None: self._autosave() self._render() def _on_key(self, event: Any) -> None: if event.key in {"left", "a"}: self._on_prev(event) elif event.key in {"right", "d"}: self._on_next(event) elif event.key == "y": self._set_label(True) elif event.key == "n": self._set_label(False) elif event.key == "u": self._set_label(None) elif event.key == "s": self._on_save(event) def show(self) -> None: self.fig = plt.figure(figsize=(12, 9)) self.fig.canvas.manager.set_window_title("Bifurcation Sample Labeler") self.ax_image = self.fig.add_axes([0.05, 0.22, 0.9, 0.74]) self.ax_image.set_axis_off() self.ax_image.set_title("Bifurcation Frame Labeling", fontsize=14, weight="bold") slider_ax = self.fig.add_axes([0.12, 0.13, 0.76, 0.035]) self.slider = Slider( ax=slider_ax, label="Sample", valmin=0, valmax=len(self.samples) - 1, valinit=0, valstep=1, color="#1f77b4", ) prev_ax = self.fig.add_axes([0.12, 0.05, 0.1, 0.055]) next_ax = self.fig.add_axes([0.24, 0.05, 0.1, 0.055]) toggle_ax = self.fig.add_axes([0.38, 0.05, 0.2, 0.055]) save_ax = self.fig.add_axes([0.62, 0.05, 0.12, 0.055]) prev_button = Button(prev_ax, "Prev", color="#E0E0E0", hovercolor="#D0D0D0") next_button = Button(next_ax, "Next", color="#E0E0E0", hovercolor="#D0D0D0") self.toggle_button = Button(toggle_ax, "Bifurcation: Unset", color="#F5F1D6", hovercolor="#EDE6BF") self.save_button = Button(save_ax, "Save", color="#D6F5D6", hovercolor="#BFF0BF") self.status_text = self.fig.text(0.05, 0.18, "", fontsize=10) self.slider.on_changed(self._on_slider) prev_button.on_clicked(self._on_prev) next_button.on_clicked(self._on_next) self.toggle_button.on_clicked(self._on_toggle) self.save_button.on_clicked(self._on_save) self.fig.canvas.mpl_connect("key_press_event", self._on_key) self._render() plt.show() def main() -> None: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( "--samples-jsonl", type=Path, default=Path("evals/frame_bank_merged/new_bifurcation_samples_300.jsonl"), help="Path to sampled frames JSONL.", ) args = parser.parse_args() editor = BifurcationSampleEditor(args.samples_jsonl) editor.show() if __name__ == "__main__": main()