ivus-segmentation / scripts /finetune /bifurcation /annotate_bifurcation_samples.py
Aditya2162's picture
Upload folder using huggingface_hub
1d197a4 verified
#!/usr/bin/env python3
"""GUI to label bifurcation yes/no for sampled frames across videos.
Designed for files like:
- evals/frame_bank_merged/new_bifurcation_samples_300.jsonl
"""
from __future__ import annotations
import argparse
import json
import os
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.widgets import Button, Slider
from deepivus.io.dicom import read_dicom
@dataclass
class Sample:
    """One sampled frame record together with its bifurcation label."""

    # Identifier taken from the record's "sample_id" field (shown in the status line).
    sample_id: str
    # Value of the record's "group" field; stored and written back unchanged.
    group: str
    # Value of the record's "dicom_stem" field; stored and written back unchanged.
    dicom_stem: str
    # Location of the DICOM file that is passed to read_dicom().
    dicom_path: Path
    # Zero-based index of the frame along the first axis of the loaded stack.
    frame: int
    # True = bifurcation, False = no bifurcation, None = not labeled yet.
    bifurcation: bool | None
def _read_jsonl(path: Path) -> list[dict[str, Any]]:
    """Parse a JSON Lines file into a list of decoded records.

    Args:
        path: UTF-8 encoded file holding one JSON document per line.

    Returns:
        The decoded documents in file order. Lines that are empty or contain
        only whitespace are ignored.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with path.open("r", encoding="utf-8") as handle:
        # Strip lazily so the file is still consumed one line at a time.
        stripped = (text.strip() for text in handle)
        return [json.loads(entry) for entry in stripped if entry]
def _write_jsonl_atomic(path: Path, rows: list[dict[str, Any]]) -> None:
    """Write ``rows`` to ``path`` as JSON Lines, replacing the file atomically.

    The records are first written to a temporary file in the same directory
    (so the final ``os.replace`` stays on one filesystem), flushed and
    fsynced, and only then moved over ``path``. If anything fails before the
    move completes, the temporary file is removed and ``path`` is left as it
    was.

    Args:
        path: Destination file. Missing parent directories are created.
        rows: JSON-serializable records, written one per line.

    Raises:
        TypeError: If a record cannot be serialized by ``json.dumps``.
        ValueError: If a record cannot be serialized (e.g. circular reference).
        OSError: If writing, syncing or replacing the file fails.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    pending: Path | None = None
    try:
        with tempfile.NamedTemporaryFile(
            "w", encoding="utf-8", dir=path.parent, delete=False
        ) as tmp:
            pending = Path(tmp.name)
            for rec in rows:
                tmp.write(json.dumps(rec) + "\n")
            tmp.flush()
            os.fsync(tmp.fileno())
        os.replace(pending, path)
        # The temp file now *is* the destination; nothing left to clean up.
        pending = None
    finally:
        if pending is not None:
            # delete=False means nobody else removes the partial file. Cleanup
            # is best-effort so it never masks the error being propagated.
            try:
                pending.unlink()
            except OSError:
                pass
def _load_samples(path: Path) -> tuple[dict[str, Any], list[Sample]]:
    """Read a samples JSONL file into its leading record and frame samples.

    The first record of the file is returned untouched as the metadata
    record. Of the remaining records, only those whose ``record_type`` is
    ``"frame"`` are converted into ``Sample`` objects; all others are skipped.

    Args:
        path: Samples JSONL file to read.

    Returns:
        A ``(meta, samples)`` pair.

    Raises:
        ValueError: If the file holds no records, or no frame records follow
            the first record.
        KeyError: If a frame record lacks one of the required fields.
    """
    records = _read_jsonl(path)
    if len(records) == 0:
        raise ValueError(f"Empty file: {path}")

    header, *body = records
    frames = [
        Sample(
            sample_id=str(item["sample_id"]),
            group=str(item["group"]),
            dicom_stem=str(item["dicom_stem"]),
            dicom_path=Path(item["dicom_path"]),
            frame=int(item["frame"]),
            # A missing label means "not annotated yet".
            bifurcation=item.get("bifurcation"),
        )
        for item in body
        if item.get("record_type") == "frame"
    ]
    if len(frames) == 0:
        raise ValueError(f"No frame records in {path}")
    return header, frames
def _save_samples(path: Path, meta: dict[str, Any], samples: list[Sample]) -> None:
    """Serialize the metadata record and all samples back to a JSONL file.

    The output starts with ``meta`` and is followed by one ``"frame"`` record
    per sample, in list order. The write is delegated to
    ``_write_jsonl_atomic``.

    Args:
        path: Destination JSONL file.
        meta: Record written as the first line.
        samples: Samples to serialize after the metadata record.
    """
    frame_records: list[dict[str, Any]] = [
        {
            "record_type": "frame",
            "sample_id": item.sample_id,
            "group": item.group,
            "dicom_stem": item.dicom_stem,
            "dicom_path": str(item.dicom_path),
            "frame": int(item.frame),
            "bifurcation": item.bifurcation,
        }
        for item in samples
    ]
    _write_jsonl_atomic(path, [meta, *frame_records])
class BifurcationSampleEditor:
    """Matplotlib GUI for stepping through sampled frames and labeling them.

    The editor loads the samples file once, shows one frame at a time and
    lets the user mark it as bifurcation yes / no / unset through buttons or
    keys. Every label change is written straight back to the samples file.
    """

    # Keys this tool handles itself, grouped by the Matplotlib default keymap
    # that would otherwise also react to them (save-figure dialog on "s",
    # view history on the arrow keys).
    _OWN_KEYS_BY_KEYMAP = {
        "keymap.save": ("s",),
        "keymap.back": ("left",),
        "keymap.forward": ("right",),
    }

    def __init__(self, samples_path: Path) -> None:
        """Load the samples file; widgets are created later by ``show()``.

        Args:
            samples_path: JSONL file that is read now and rewritten on save.
        """
        self.samples_path = samples_path
        self.meta, self.samples = _load_samples(samples_path)
        self.current = 0
        # True only while the slider is being moved programmatically.
        self.is_updating_slider = False
        # Whole decoded stacks keyed by file path; entries are never evicted.
        self.image_cache: dict[Path, np.ndarray] = {}
        self.fig = None
        self.ax_image = None
        self.image_artist = None
        # Shape of the frame currently held by ``image_artist``.
        self.image_shape = None
        self.slider = None
        self.status_text = None
        self.prev_button = None
        self.next_button = None
        self.toggle_button = None
        self.save_button = None

    def _sample(self) -> Sample:
        """Return the sample at the current index."""
        return self.samples[self.current]

    def _load_frame(self, sample: Sample) -> np.ndarray:
        """Return the image for ``sample``, reading its DICOM on first use.

        Raises:
            IndexError: If ``sample.frame`` is outside the loaded stack.
        """
        if sample.dicom_path not in self.image_cache:
            _, stack = read_dicom(str(sample.dicom_path))
            self.image_cache[sample.dicom_path] = stack
        stack = self.image_cache[sample.dicom_path]
        if sample.frame < 0 or sample.frame >= stack.shape[0]:
            raise IndexError(
                f"Frame {sample.frame} out of range for {sample.dicom_path} (num_frames={stack.shape[0]})"
            )
        return stack[sample.frame]

    def _autosave(self) -> None:
        """Write the current labels back to the samples file."""
        _save_samples(self.samples_path, self.meta, self.samples)

    def _set_current(self, idx: int) -> None:
        """Jump to sample ``idx`` (clamped to the valid range) and redraw."""
        self.current = int(np.clip(idx, 0, len(self.samples) - 1))
        self.is_updating_slider = True
        try:
            self.slider.set_val(self.current)
        finally:
            # Always release the guard; a stuck flag would make the slider
            # callback ignore every later change.
            self.is_updating_slider = False
        self._render()

    def _render(self) -> None:
        """Redraw the image, the label button and the status line."""
        sample = self._sample()
        frame = self._load_frame(sample)

        if self.image_artist is not None and self.image_shape != frame.shape:
            # set_data() keeps the old extent, so a frame with different
            # dimensions would be stretched; start over with a fresh artist.
            self.image_artist.remove()
            self.image_artist = None

        if self.image_artist is None:
            self.image_artist = self.ax_image.imshow(frame, cmap="gray")
        else:
            self.image_artist.set_data(frame)
            # set_data() does not touch the color limits; rescale so each
            # frame is shown the way imshow() would show it on its own.
            self.image_artist.autoscale()
        self.image_shape = frame.shape

        label = sample.bifurcation
        if label is True:
            txt = "Bifurcation: Yes"
            color = "#D6F5D6"
        elif label is False:
            txt = "Bifurcation: No"
            color = "#F5D6D6"
        else:
            txt = "Bifurcation: Unset"
            color = "#F5F1D6"
        self.toggle_button.label.set_text(txt)
        self.toggle_button.ax.set_facecolor(color)

        done = sum(1 for item in self.samples if item.bifurcation is not None)
        self.status_text.set_text(
            f"Sample {self.current + 1}/{len(self.samples)} ({done} labeled) "
            f"{sample.sample_id} frame={sample.frame} "
            "Keys: left/right navigate, y=yes, n=no, u=unset, s=save"
        )
        self.fig.canvas.draw_idle()

    def _set_label(self, value: bool | None) -> None:
        """Assign ``value`` to the current sample, save, and redraw."""
        self._sample().bifurcation = value
        self._autosave()
        self._render()

    def _on_slider(self, val: float) -> None:
        """Slider callback: move to the selected sample."""
        if self.is_updating_slider:
            return
        # Round rather than truncate so a value such as 2.9999999 maps to 3.
        self._set_current(int(round(val)))

    def _on_prev(self, _event: Any) -> None:
        """Go to the previous sample."""
        self._set_current(self.current - 1)

    def _on_next(self, _event: Any) -> None:
        """Go to the next sample."""
        self._set_current(self.current + 1)

    def _on_toggle(self, _event: Any) -> None:
        """Flip the label; an unset label becomes "yes"."""
        cur = self._sample().bifurcation
        self._set_label(True if cur is None else not cur)

    def _on_save(self, _event: Any) -> None:
        """Save explicitly and refresh the display."""
        self._autosave()
        self._render()

    def _on_key(self, event: Any) -> None:
        """Dispatch keyboard shortcuts; unknown keys are ignored."""
        if event.key in {"left", "a"}:
            self._on_prev(event)
        elif event.key in {"right", "d"}:
            self._on_next(event)
        elif event.key == "y":
            self._set_label(True)
        elif event.key == "n":
            self._set_label(False)
        elif event.key == "u":
            self._set_label(None)
        elif event.key == "s":
            self._on_save(event)

    def _release_conflicting_keymaps(self) -> None:
        """Remove this tool's keys from Matplotlib's default shortcuts.

        This edits the process-wide ``plt.rcParams``; other bindings in the
        same keymaps (for example ``ctrl+s``) are kept.
        """
        for param, taken in self._OWN_KEYS_BY_KEYMAP.items():
            if param in plt.rcParams:
                plt.rcParams[param] = [
                    key for key in plt.rcParams[param] if key not in taken
                ]

    def show(self) -> None:
        """Build the figure and widgets, draw the first sample and show it."""
        # Must happen before the figure exists so every toolbar flavor picks
        # up the reduced keymaps.
        self._release_conflicting_keymaps()

        self.fig = plt.figure(figsize=(12, 9))
        self.fig.canvas.manager.set_window_title("Bifurcation Sample Labeler")
        self.ax_image = self.fig.add_axes([0.05, 0.22, 0.9, 0.74])
        self.ax_image.set_axis_off()
        self.ax_image.set_title("Bifurcation Frame Labeling", fontsize=14, weight="bold")

        slider_ax = self.fig.add_axes([0.12, 0.13, 0.76, 0.035])
        self.slider = Slider(
            ax=slider_ax,
            label="Sample",
            valmin=0,
            valmax=len(self.samples) - 1,
            valinit=0,
            valstep=1,
            color="#1f77b4",
        )

        prev_ax = self.fig.add_axes([0.12, 0.05, 0.1, 0.055])
        next_ax = self.fig.add_axes([0.24, 0.05, 0.1, 0.055])
        toggle_ax = self.fig.add_axes([0.38, 0.05, 0.2, 0.055])
        save_ax = self.fig.add_axes([0.62, 0.05, 0.12, 0.055])

        # Widgets only stay responsive while something references them, so
        # all four buttons are kept on the instance rather than as locals.
        self.prev_button = Button(prev_ax, "Prev", color="#E0E0E0", hovercolor="#D0D0D0")
        self.next_button = Button(next_ax, "Next", color="#E0E0E0", hovercolor="#D0D0D0")
        self.toggle_button = Button(toggle_ax, "Bifurcation: Unset", color="#F5F1D6", hovercolor="#EDE6BF")
        self.save_button = Button(save_ax, "Save", color="#D6F5D6", hovercolor="#BFF0BF")
        self.status_text = self.fig.text(0.05, 0.18, "", fontsize=10)

        self.slider.on_changed(self._on_slider)
        self.prev_button.on_clicked(self._on_prev)
        self.next_button.on_clicked(self._on_next)
        self.toggle_button.on_clicked(self._on_toggle)
        self.save_button.on_clicked(self._on_save)
        self.fig.canvas.mpl_connect("key_press_event", self._on_key)

        self._render()
        plt.show()
def main() -> None:
    """Parse the command line and open the labeling GUI.

    The only option is ``--samples-jsonl``, the samples file to annotate; it
    falls back to the bundled default path when omitted.
    """
    default_samples = Path("evals/frame_bank_merged/new_bifurcation_samples_300.jsonl")

    # ``__doc__`` here resolves to the module docstring, used as the help text.
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument(
        "--samples-jsonl",
        type=Path,
        default=default_samples,
        help="Path to sampled frames JSONL.",
    )
    options = cli.parse_args()

    BifurcationSampleEditor(options.samples_jsonl).show()
# Launch the labeling GUI only when run as a script, not when imported.
if __name__ == "__main__":
    main()