hanjang commited on
Commit
8776d83
·
verified ·
1 Parent(s): 57fde91

Upload /app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +217 -113
app.py CHANGED
@@ -5,31 +5,85 @@ from pathlib import Path
5
  import numpy as np
6
  from PIL import Image, ImageDraw, ImageOps
7
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  # ---------- optional deps ----------
9
  try:
10
  import nibabel as nib
11
  HAVE_NIB = True
12
- except Exception:
 
13
  HAVE_NIB = False
 
 
14
  try:
15
  from scipy import ndimage as ndi
16
  HAVE_SCIPY = True
17
- except Exception:
 
18
  HAVE_SCIPY = False
 
 
19
  # ---------- model predictor ----------
20
  PREDICTOR = None
21
  DEVICE = "cuda:0"
 
22
  from huggingface_hub import snapshot_download
23
- CKPT = snapshot_download("hanjang/Interactive-MEN-RT",
24
- allow_patterns=["nnUNetInteractionTrainer__nnUNetPlans__3d_fullres_scratch/**"])
25
- DATA_ROOT = Path("./samples")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  for env in ("nnUNet_raw", "nnUNet_preprocessed", "nnUNet_results"):
27
  os.environ.setdefault(env, tempfile.mkdtemp(prefix=f"{env}_"))
 
28
  def _init_predictor_once():
29
- """Load model once + (best-effort) CUDA warm-up."""
30
  global PREDICTOR
31
  if PREDICTOR is not None:
32
  return True
 
 
 
33
  try:
34
  import torch
35
  from Interactive_MEN_RT_predictor import InteractiveMENRTPredictor
@@ -41,7 +95,6 @@ def _init_predictor_once():
41
  model_training_output_dir=CKPT, use_fold=0, checkpoint_name="checkpoint_best.pth"
42
  )
43
  PREDICTOR = pred
44
- # GPU warm-up (best effort)
45
  try:
46
  if torch.cuda.is_available():
47
  x = np.zeros((1, 8, 8, 8), np.float32)
@@ -57,26 +110,23 @@ def _init_predictor_once():
57
  except Exception as e:
58
  print(f"[MODEL] init failed: {e}")
59
  return False
 
60
  def preload_model_in_background():
61
  threading.Thread(target=_init_predictor_once, daemon=True).start()
 
62
  # ---------- config ----------
63
- EXAMPLES = [
64
- "BraTS-MEN-RT-0071-1",
65
- "BraTS-MEN-RT-0223-1",
66
- "BraTS-MEN-RT-0280-1",
67
- "BraTS-MEN-RT-0422-1",
68
- "BraTS-MEN-RT-0436-1",
69
- ]
70
  RENDER_PX_DEFAULT = 384
71
- MAX_RENDER_PX = 1024
72
- ROT_CCW = True # 90 degree CCW
73
  # colors
74
  ACCENT_HEX = "#1e90ff"
75
  CROSS_RGB = (30, 144, 255)
76
- GT_RGBA_FILL = (255, 215, 0, 128) # alpha=0.5
77
- PR_RGBA_FILL = (255, 60, 60, 128) # alpha=0.5
78
  SEED_RGB = (89, 224, 154)
79
- BBOX_RGB = (255, 140, 0) # Orange for bbox
 
80
  # ---------- state ----------
81
  class State:
82
  def __init__(self):
@@ -85,33 +135,36 @@ class State:
85
  self.case_id=None; self.loaded=False
86
  self.cross={"x":0,"y":0,"z":0}
87
  self.slice={"axial":0,"sagittal":0,"coronal":0}
88
- self.seeds=[] # [(x,y,z), ...]
89
- self.seed_views=[] # ["axial"/"sagittal"/"coronal", ...] (seeds and index synchronized)
90
  self.render_px=RENDER_PX_DEFAULT
91
  self.disp_wh={"axial":(RENDER_PX_DEFAULT,RENDER_PX_DEFAULT),
92
  "sagittal":(RENDER_PX_DEFAULT,RENDER_PX_DEFAULT),
93
  "coronal":(RENDER_PX_DEFAULT,RENDER_PX_DEFAULT)}
94
- self.active_view="axial" # recent clicked view
95
- # BBox state
96
- self.bbox_mode = False # point mode vs bbox mode
97
- self.bbox_points = [] # temporary: 2 points for bbox
98
- self.bboxes = [] # [(x1,y1,z1,x2,y2,z2), ...]
99
- # NIfTI metadata
100
  self.ref_affine = None
101
  self.ref_header = None
 
102
  S = State()
 
103
  # ---------- utils ----------
104
  def _norm01(a):
105
  a=a.astype(np.float32)
106
  p2,p98=np.percentile(a,2),np.percentile(a,98)
107
  if p98<=p2: p2,p98=float(a.min()),float(a.max()) or 1.0
108
  return np.clip((a-p2)/max(p98-p2,1e-6),0,1)
 
109
  def _resize_slice_nearest(arr2d,w,h):
110
  im=Image.fromarray(arr2d); im=im.resize((w,h),Image.NEAREST); return np.array(im)
 
111
  def _rot90_if_needed(img_or_np):
112
  if not ROT_CCW: return img_or_np
113
  if isinstance(img_or_np, Image.Image): return img_or_np.rotate(90, expand=True)
114
  return np.rot90(img_or_np, k=1)
 
115
  # ---------- IO ----------
116
  def _load_png_stack(case_dir):
117
  pngs = sorted(glob.glob(str(case_dir / "png_axial" / "*.png")))
@@ -124,6 +177,7 @@ def _load_png_stack(case_dir):
124
  vol=_norm01(vol)
125
  print(f"[PIL] {len(pngs)} slices -> {vol.shape} in {time.time()-t0:.2f}s")
126
  return vol, None, None
 
127
  def _load_nifti(case_dir,case_id,ds=1):
128
  if not HAVE_NIB: return None, None, None
129
  p=case_dir/f"{case_id}_t1c.nii.gz"
@@ -134,6 +188,7 @@ def _load_nifti(case_dir,case_id,ds=1):
134
  arr=_norm01(arr)
135
  print(f"[NIfTI] {case_id} -> {arr.shape} in {time.time()-t0:.2f}s")
136
  return arr, nii.affine, nii.header
 
137
  def _resample_mask_to_vol_shape(mask_xyz, vol_shape_xyz):
138
  mx,my,mz=mask_xyz.shape; vx,vy,vz=vol_shape_xyz
139
  out=np.zeros((vx,vy,vz),dtype=np.uint8)
@@ -143,6 +198,7 @@ def _resample_mask_to_vol_shape(mask_xyz, vol_shape_xyz):
143
  im=Image.fromarray(sl).resize((vy,vx),Image.NEAREST)
144
  out[:,:,k]=(np.array(im)>0).astype(np.uint8)
145
  return out
 
146
  def _load_gt(case_dir,case_id,vol_shape):
147
  candidates = [
148
  f"{case_id}_gtv.nii.gz", f"{case_id}_seg.nii.gz", f"{case_id}_gt.nii.gz",
@@ -163,6 +219,7 @@ def _load_gt(case_dir,case_id,vol_shape):
163
  print(f"[GT] load error {p.name}: {e}")
164
  print("[GT] not found.")
165
  return None
 
166
  def load_case(case_id):
167
  case_dir=DATA_ROOT/case_id
168
  vol, affine, header = _load_png_stack(case_dir)
@@ -191,12 +248,14 @@ def load_case(case_id):
191
  S.active_view="axial"
192
  S.loaded=True
193
  print(f"[LOAD] {case_id} | shape={S.shape}")
194
- # ---------- 2D (90° CCW) ----------
 
195
  def _slice2d(view):
196
  if view=="axial": sl=S.vol[:,:,S.slice["axial"]]
197
  elif view=="sagittal": sl=S.vol[S.slice["sagittal"],:,:].T
198
  else: sl=S.vol[:,S.slice["coronal"],:].T
199
  return _rot90_if_needed(sl)
 
200
  def _cross_pix_on_rot(view,w,h,x=None,y=None,z=None):
201
  X,Y,Z=S.shape
202
  if x is None: x=S.cross["x"]
@@ -212,11 +271,13 @@ def _cross_pix_on_rot(view,w,h,x=None,y=None,z=None):
212
  u=int(round(z*(w-1)/max(Z-1,1)))
213
  v=int(round((X-1-x)*(h-1)/max(X-1,1)))
214
  return u,v
 
215
  def _draw_cross(img_draw, view, w, h):
216
  u,v=_cross_pix_on_rot(view,w,h)
217
  img_draw.line([(u,0),(u,h)],fill=CROSS_RGB,width=1)
218
  img_draw.line([(0,v),(w,v)],fill=CROSS_RGB,width=1)
219
  img_draw.ellipse((u-5,v-5,u+5,v+5),fill=(255,255,255),outline=ACCENT_HEX,width=2)
 
220
  def render_top(view):
221
  if not S.loaded: return None
222
  sl=_slice2d(view)
@@ -226,12 +287,13 @@ def render_top(view):
226
  _draw_cross(dr, view, w, h)
227
  S.disp_wh[view]=im.size
228
  return im
229
- # ---- Axial overlays (mask also 90° CCW) ----
230
  def _axial_mask2d_rot(mask3d):
231
  if mask3d is None: return None
232
  m = mask3d[:,:,S.slice["axial"]].astype(np.uint8)
233
  m = _rot90_if_needed(m)
234
  return m
 
235
  def _axial_overlay_fill(mask3d, rgba):
236
  sl = _rot90_if_needed(S.vol[:,:,S.slice["axial"]])
237
  base=Image.fromarray((sl*255).astype(np.uint8)).resize((S.render_px,S.render_px),Image.BILINEAR).convert("RGBA")
@@ -240,7 +302,7 @@ def _axial_overlay_fill(mask3d, rgba):
240
  m2d = _resize_slice_nearest((m2d>0).astype(np.uint8), S.render_px, S.render_px)
241
  over=np.zeros((S.render_px,S.render_px,4),dtype=np.uint8); over[m2d>0]=rgba
242
  return Image.alpha_composite(base,Image.fromarray(over,"RGBA")).convert("RGB")
243
- # ---------- Interaction 2D ----------
244
  def _interaction_2d():
245
  view = S.active_view
246
  if not S.loaded: return None
@@ -248,9 +310,8 @@ def _interaction_2d():
248
  im=Image.fromarray((sl*255).astype(np.uint8)).resize((S.render_px,S.render_px),Image.BILINEAR).convert("RGB")
249
  w=h=S.render_px
250
  dr=ImageDraw.Draw(im)
251
- # cross
252
  _draw_cross(dr, view, w, h)
253
- # seeds (only show on the plane)
254
  tol=0
255
  for i, (x,y,z) in enumerate(S.seeds):
256
  on_plane = (
@@ -264,7 +325,6 @@ def _interaction_2d():
264
  dr.ellipse((u-r,v-r,u+r,v+r), fill=SEED_RGB, outline=(40,140,100), width=1)
265
  dr.text((u+6, v-8), f"{i+1}", fill=(30,30,30))
266
 
267
- # bbox temp points (in bbox mode)
268
  if S.bbox_mode:
269
  for i, (x,y,z) in enumerate(S.bbox_points):
270
  on_plane = (
@@ -279,7 +339,6 @@ def _interaction_2d():
279
  text = "P1" if i == 0 else "P2"
280
  dr.text((u+10, v-10), text, fill=BBOX_RGB)
281
 
282
- # draw completed bboxes
283
  for (x1,y1,z1,x2,y2,z2) in S.bboxes:
284
  if view=="axial":
285
  curr_z = S.slice["axial"]
@@ -289,6 +348,7 @@ def _interaction_2d():
289
  dr.rectangle([u1,v1,u2,v2], outline=(0,255,0), width=2)
290
 
291
  return im
 
292
  # ===================== segmentation ========================================
293
  def _segment_with_model():
294
  if PREDICTOR is None and not _init_predictor_once():
@@ -299,18 +359,13 @@ def _segment_with_model():
299
  PREDICTOR.set_image(img)
300
  PREDICTOR.set_target_buffer(np.zeros_like(img[0], np.float32))
301
  PREDICTOR._finish_preprocessing_and_initialize_interactions()
302
-
303
- # Add point interactions
304
  for (x,y,z) in S.seeds:
305
  PREDICTOR.add_point_interaction(x, y, z, foreground=True)
306
-
307
- # Add bbox interactions
308
  for (x1,y1,z1,x2,y2,z2) in S.bboxes:
309
  PREDICTOR.add_bbox_interaction(
310
  min(x1,x2), min(y1,y2), min(z1,z2),
311
  max(x1,x2), max(y1,y2), max(z1,z2)
312
  )
313
-
314
  PREDICTOR._predict_without_interaction()
315
  pred = (PREDICTOR.target_buffer.astype(np.float32) > 0.5).astype(np.uint8)
316
  if pred.shape != S.shape:
@@ -320,6 +375,7 @@ def _segment_with_model():
320
  except Exception as e:
321
  print(f"[MODEL] inference failed: {e}")
322
  return None, "model-error"
 
323
  def _segment_fallback():
324
  if (not S.seeds and not S.bboxes) or not HAVE_SCIPY:
325
  return None, "no-interactions-or-scipy"
@@ -344,6 +400,7 @@ def _segment_fallback():
344
  mask=ndi.binary_closing(mask,iterations=1); mask=ndi.binary_opening(mask,iterations=1)
345
  print(f"[FB] seg {time.time()-t0:.3f}s | vox={int(mask.sum())}")
346
  return (mask>0).astype(np.uint8), "ok"
 
347
  def do_segment():
348
  pred, tag = _segment_with_model()
349
  if pred is None:
@@ -351,7 +408,7 @@ def do_segment():
351
  S.pred = pred if pred is not None else None
352
  print(f"[SEG] done: {tag}")
353
  return "OK" if S.pred is not None else "Failed"
354
- # ---------- Save NIfTI ----------
355
  def save_prediction():
356
  if S.pred is None:
357
  return "No prediction to save", None
@@ -360,21 +417,18 @@ def save_prediction():
360
  try:
361
  tmp_dir = Path(tempfile.mkdtemp(prefix="menrt_output_"))
362
  out_path = tmp_dir / f"{S.case_id}_pred.nii.gz"
363
-
364
  affine = S.ref_affine if S.ref_affine is not None else np.eye(4)
365
  header = S.ref_header.copy() if S.ref_header is not None else None
366
-
367
  nii_img = nib.Nifti1Image(S.pred.astype(np.uint8), affine, header=header)
368
  nib.save(nii_img, str(out_path))
369
-
370
  print(f"[SAVE] {out_path}")
371
  return "Saved successfully!", str(out_path)
372
  except Exception as e:
373
  print(f"[SAVE] error: {e}")
374
  return f"Save failed: {e}", None
375
- # ---------- debug widgets helpers ----------
 
376
  def _seed_rows():
377
- """DataFrame rows: [[#, type, view, x, y, z], ...]"""
378
  rows=[]
379
  for i,(x,y,z) in enumerate(S.seeds):
380
  v = S.seed_views[i] if i < len(S.seed_views) else ""
@@ -382,8 +436,8 @@ def _seed_rows():
382
  for i,(x1,y1,z1,x2,y2,z2) in enumerate(S.bboxes):
383
  rows.append([len(S.seeds)+i+1, "bbox", "3D", f"{x1}-{x2}", f"{y1}-{y2}", f"{z1}-{z2}"])
384
  return rows
 
385
  def _seed_dropdown_options():
386
- """Dropdown options & default value"""
387
  opts=[]
388
  for i,(x,y,z) in enumerate(S.seeds):
389
  v = S.seed_views[i] if i < len(S.seed_views) else "axial"
@@ -391,6 +445,7 @@ def _seed_dropdown_options():
391
  for i,(x1,y1,z1,x2,y2,z2) in enumerate(S.bboxes):
392
  opts.append(f"3D → bbox {i+1} → ({x1},{y1},{z1})-({x2},{y2},{z2})")
393
  return opts
 
394
  def _debug_widgets(current_idx=None):
395
  rows = _seed_rows()
396
  df_upd = gr.update(value=rows)
@@ -401,7 +456,7 @@ def _debug_widgets(current_idx=None):
401
  val = (opts[current_idx] if (0 <= current_idx < len(opts)) else (opts[-1] if opts else None))
402
  dd_upd = gr.update(choices=opts, value=val)
403
  return df_upd, dd_upd
404
- # ---------- pack outputs ----------
405
  def _figs_and_imgs():
406
  top_ax=render_top("axial")
407
  top_sg=render_top("sagittal")
@@ -410,37 +465,48 @@ def _figs_and_imgs():
410
  ax_pr = _axial_overlay_fill(S.pred, PR_RGBA_FILL)
411
  inter2d = _interaction_2d()
412
  return top_ax, top_sg, top_co, ax_gt, ax_pr, inter2d
 
413
  def _bar_ranges_and_values():
414
  X,Y,Z=S.shape
415
  return (gr.update(minimum=0,maximum=Z-1,value=S.slice["axial"],visible=True),
416
  gr.update(minimum=0,maximum=X-1,value=S.slice["sagittal"],visible=True),
417
  gr.update(minimum=0,maximum=Y-1,value=S.slice["coronal"],visible=True))
418
- # ---------- robust event parsing ----------
419
  def _parse_evt_xy(evt):
420
- """
421
- Gradio 3/4 호환: evt.index, evt.x/y, dict(evt['index'], evt['x']/['y']) 모두 대응
422
- """
423
- if evt is None: return None
 
 
 
 
 
424
  try:
 
425
  if hasattr(evt, "index") and evt.index is not None:
426
  ix = evt.index
 
427
  if isinstance(ix, (list, tuple)) and len(ix) >= 2:
428
- return int(ix[0]), int(ix[1])
 
 
 
 
429
  if hasattr(evt, "x") and hasattr(evt, "y"):
430
- return int(getattr(evt, "x")), int(getattr(evt, "y"))
431
- if isinstance(evt, dict):
432
- if "index" in evt and evt["index"] is not None:
433
- ix = evt["index"]
434
- if isinstance(ix, (list, tuple)) and len(ix) >= 2:
435
- return int(ix[0]), int(ix[1])
436
- if "x" in evt and "y" in evt:
437
- return int(evt["x"]), int(evt["y"])
438
- d = evt.get("evt", {}).get("data", {})
439
- if "x" in d and "y" in d:
440
- return int(d["x"]), int(d["y"])
441
- except Exception:
442
- pass
443
  return None
 
444
  def _disp_to_vol(view,u,v):
445
  X,Y,Z=S.shape; w,h=S.disp_wh[view]
446
  if w<=0 or h<=0: w=h=S.render_px
@@ -458,7 +524,7 @@ def _disp_to_vol(view,u,v):
458
  y = S.slice["coronal"]
459
  x=max(0,min(X-1,x)); y=max(0,min(Y-1,y)); z=max(0,min(Z-1,z))
460
  return x,y,z
461
- # ---------- thumbnails ----------
462
  def _thumb_from_case(case_id,px=96):
463
  case_dir=DATA_ROOT/case_id
464
  pngs = sorted(glob.glob(str(case_dir / "png_axial" / "*.png")))
@@ -479,13 +545,15 @@ def _thumb_from_case(case_id,px=96):
479
  im = Image.fromarray((mid*255).astype(np.uint8)).resize((px,px),Image.BILINEAR)
480
  im = _rot90_if_needed(im)
481
  return ImageOps.expand(im,border=1,fill=200)
 
482
  # ---------- callbacks ----------
483
  def on_load(case_id):
484
  load_case(case_id)
485
  preload_model_in_background()
486
  imgs=_figs_and_imgs(); bars=_bar_ranges_and_values(); dbg=_debug_widgets()
487
- return (*imgs,*bars,*dbg, gr.update(value="Point"), None, "")
488
- def on_gallery_select(evt: gr.events.SelectData):
 
489
  idx=0
490
  if hasattr(evt,"index"):
491
  ix=evt.index; idx=int(ix[0] if isinstance(ix,(list,tuple)) else ix)
@@ -493,28 +561,42 @@ def on_gallery_select(evt: gr.events.SelectData):
493
  cid=EXAMPLES[idx]
494
  out=on_load(cid)
495
  return (*out, gr.update(value=cid))
 
496
  def _click_common(view, evt: gr.SelectData):
497
- if not S.loaded: return (*_figs_and_imgs(),*_bar_ranges_and_values(),*_debug_widgets(), gr.update(value="Point" if not S.bbox_mode else "BBox"), None, "")
 
 
 
 
 
 
 
 
498
  xy=_parse_evt_xy(evt)
499
  if xy is None:
500
- return (*_figs_and_imgs(),*_bar_ranges_and_values(),*_debug_widgets(), gr.update(value="Point" if not S.bbox_mode else "BBox"), None, "")
 
 
 
501
  u,v=xy
502
  x,y,z=_disp_to_vol(view,u,v)
 
503
 
504
  if S.bbox_mode:
505
- # BBox mode
506
  S.bbox_points.append((x,y,z))
507
  if len(S.bbox_points) == 2:
508
  p1, p2 = S.bbox_points
509
  S.bboxes.append((*p1, *p2))
510
  S.bbox_points = []
511
  print(f"[UI] bbox created: {S.bboxes[-1]}")
 
512
  else:
513
  print(f"[UI] bbox point {len(S.bbox_points)}/2")
 
514
  else:
515
- # Point mode
516
  S.seeds.append((x,y,z)); S.seed_views.append(view)
517
  print(f"[UI] seed+ {(x,y,z)} total={len(S.seeds)}")
 
518
 
519
  S.cross={"x":x,"y":y,"z":z}
520
  S.slice={"sagittal":x,"coronal":y,"axial":z}
@@ -522,18 +604,32 @@ def _click_common(view, evt: gr.SelectData):
522
  imgs=_figs_and_imgs(); bars=_bar_ranges_and_values()
523
  idx = len(S.seeds) + len(S.bboxes) - 1
524
  dbg=_debug_widgets(current_idx=idx)
525
- return (*imgs,*bars,*dbg, gr.update(value="Point" if not S.bbox_mode else "BBox"), None, "")
526
- def on_axial_select(evt: gr.SelectData): return _click_common("axial", evt)
527
- def on_sagittal_select(evt: gr.SelectData): return _click_common("sagittal", evt)
528
- def on_coronal_select(evt: gr.SelectData): return _click_common("coronal", evt)
 
 
 
 
 
 
 
 
 
 
529
  def on_seg_button():
530
- msg=do_segment(); print(f"[UI] Segment -> {msg}")
531
- return (*_figs_and_imgs(),*_bar_ranges_and_values(),*_debug_widgets(), gr.update(value="Point" if not S.bbox_mode else "BBox"), None, "")
 
 
532
  def on_clear():
533
  S.seeds=[]; S.seed_views=[]; S.pred=None
534
  S.bboxes=[]; S.bbox_points=[]
535
  print("[UI] cleared all")
536
- return (*_figs_and_imgs(),*_bar_ranges_and_values(),*_debug_widgets(), gr.update(value="Point"), None, "")
 
 
537
  def on_undo():
538
  if S.bbox_mode and S.bbox_points:
539
  S.bbox_points.pop()
@@ -542,49 +638,66 @@ def on_undo():
542
  elif S.seeds:
543
  S.seeds.pop(); S.seed_views.pop() if S.seed_views else None
544
  print("[UI] undo")
545
- return (*_figs_and_imgs(),*_bar_ranges_and_values(),*_debug_widgets(), gr.update(value="Point" if not S.bbox_mode else "BBox"), None, "")
 
 
546
  def on_mode_toggle(mode):
547
  S.bbox_mode = (mode == "BBox")
548
  S.bbox_points = []
549
  print(f"[UI] mode -> {mode}")
550
- return (*_figs_and_imgs(),*_bar_ranges_and_values(),*_debug_widgets(), gr.update(value=mode), None, "")
 
 
551
  def on_save():
552
  msg, path = save_prediction()
553
  if path:
554
- return (*_figs_and_imgs(),*_bar_ranges_and_values(),*_debug_widgets(), gr.update(value="Point" if not S.bbox_mode else "BBox"), gr.update(visible=True, value=path), msg)
555
- return (*_figs_and_imgs(),*_bar_ranges_and_values(),*_debug_widgets(), gr.update(value="Point" if not S.bbox_mode else "BBox"), gr.update(visible=False, value=None), msg)
 
 
 
 
 
556
  def on_z_release(z_idx):
557
  S.slice["axial"]=int(z_idx); S.cross["z"]=S.slice["axial"]; S.active_view="axial"
558
- return (*_figs_and_imgs(),*_bar_ranges_and_values(),*_debug_widgets(), gr.update(value="Point" if not S.bbox_mode else "BBox"), None, "")
 
 
559
  def on_x_release(x_idx):
560
  S.slice["sagittal"]=int(x_idx); S.cross["x"]=S.slice["sagittal"]; S.active_view="sagittal"
561
- return (*_figs_and_imgs(),*_bar_ranges_and_values(),*_debug_widgets(), gr.update(value="Point" if not S.bbox_mode else "BBox"), None, "")
 
 
562
  def on_y_release(y_idx):
563
  S.slice["coronal"]=int(y_idx); S.cross["y"]=S.slice["coronal"]; S.active_view="coronal"
564
- return (*_figs_and_imgs(),*_bar_ranges_and_values(),*_debug_widgets(), gr.update(value="Point" if not S.bbox_mode else "BBox"), None, "")
 
 
565
  def _jump_to_idx(idx:int):
566
- """Common: jump to selected index from log."""
567
  total = len(S.seeds) + len(S.bboxes)
568
  if not (0 <= idx < total):
569
- return (*_figs_and_imgs(),*_bar_ranges_and_values(),*_debug_widgets(), gr.update(value="Point" if not S.bbox_mode else "BBox"), None, "")
 
570
  if idx < len(S.seeds):
571
  x,y,z = S.seeds[idx]
572
  view = S.seed_views[idx] if idx < len(S.seed_views) else "axial"
573
  S.cross={"x":x,"y":y,"z":z}
574
  S.slice={"sagittal":x,"coronal":y,"axial":z}
575
  S.active_view=view
576
- return (*_figs_and_imgs(),*_bar_ranges_and_values(),*_debug_widgets(current_idx=idx), gr.update(value="Point" if not S.bbox_mode else "BBox"), None, "")
 
 
577
  def on_seed_df_select(evt):
578
- """DataFrame select: evt.index -> (row, col)"""
579
  try:
580
  row = int(evt.index[0]) if hasattr(evt, "index") else 0
581
  except Exception:
582
  row = 0
583
  return _jump_to_idx(row)
 
584
  def on_seed_dd_change(val):
585
- """Dropdown change: id parsing"""
586
  if not val:
587
- return (*_figs_and_imgs(),*_bar_ranges_and_values(),*_debug_widgets(), gr.update(value="Point" if not S.bbox_mode else "BBox"), None, "")
 
588
  try:
589
  if "point" in val:
590
  p = val.split("point")[1].strip()
@@ -595,6 +708,7 @@ def on_seed_dd_change(val):
595
  except Exception:
596
  idx = len(S.seeds) + len(S.bboxes) - 1
597
  return _jump_to_idx(idx)
 
598
  # ---------- UI ----------
599
  css = """
600
  :root {
@@ -606,11 +720,10 @@ css = """
606
  .tiny .gr-slider input[type="range"]{height:6px}
607
  .tiny .gr-form{gap:6px}
608
  .smallnote{color:#3e6285;font-size:12px;margin-top:6px}
609
- .dfshort { max-height: 190px; overflow: auto; }
610
- .dfshort table { font-size: 12px; }
611
  """
612
- # thumbnails are created before UI (definition/dependency order guaranteed)
613
  THUMBS=[_thumb_from_case(cid,px=96) for cid in EXAMPLES]
 
614
  with gr.Blocks(css=css, title="Interactive-MEN-RT", theme=gr.themes.Soft(), analytics_enabled=False) as demo:
615
  gr.Markdown("""
616
  <h3 style='color:#114b8b'>Interactive-MEN-RT Segmentation</h3>
@@ -619,10 +732,11 @@ with gr.Blocks(css=css, title="Interactive-MEN-RT", theme=gr.themes.Soft(), anal
619
  <b>⚠️ Research only — Not for clinical use</b>
620
  </div>
621
  """)
 
622
  with gr.Row():
623
  with gr.Column(scale=1, min_width=280, elem_classes=["round"]):
624
- gr.Markdown("<div class='section'>Samples &amp; Tools</div>")
625
- gallery = gr.Gallery(value=THUMBS, columns=len(EXAMPLES), height=110, allow_preview=False, preview=False, show_label=False)
626
  case_dd = gr.Dropdown(choices=EXAMPLES, value=EXAMPLES[0], label="Case")
627
 
628
  mode_radio = gr.Radio(["Point", "BBox"], value="Point", label="Interaction Mode")
@@ -643,7 +757,7 @@ with gr.Blocks(css=css, title="Interactive-MEN-RT", theme=gr.themes.Soft(), anal
643
  wrap=True,
644
  row_count=(0, "dynamic")
645
  )
646
- seed_dd = gr.Dropdown(choices=[], value=None, label="Go to (select a point)")
647
 
648
  pred_file = gr.File(label="Download Prediction", visible=False)
649
  status_text = gr.Textbox(label="Status", value="", interactive=False, lines=1)
@@ -660,43 +774,33 @@ with gr.Blocks(css=css, title="Interactive-MEN-RT", theme=gr.themes.Soft(), anal
660
  coronal = gr.Image(type="pil", interactive=True, height=RENDER_PX_DEFAULT+8, label="Coronal (Y)")
661
  y_bar = gr.Slider(0,1,value=0,step=1,label="Y", elem_classes=["tiny"])
662
  with gr.Row():
663
- out_ax_gt = gr.Image(type="pil", interactive=False, height=RENDER_PX_DEFAULT+8, label="Axial GT (fill)")
664
- out_ax_pr = gr.Image(type="pil", interactive=False, height=RENDER_PX_DEFAULT+8, label="Axial Pred (fill)")
665
- inter2d = gr.Image(type="pil", interactive=False, height=RENDER_PX_DEFAULT+8, label="Interaction (2D)")
666
 
667
- # outputs
668
  outputs = [axial, sagittal, coronal, out_ax_gt, out_ax_pr, inter2d,
669
  z_bar, x_bar, y_bar, seeds_df, seed_dd, mode_radio, pred_file, status_text]
670
 
671
- # initial
672
  demo.load(lambda: on_load(EXAMPLES[0]), [], outputs)
673
-
674
- # case change
675
  case_dd.change(lambda cid: on_load(cid), [case_dd], outputs)
676
-
677
- # gallery select
678
  gallery.select(on_gallery_select, [], outputs + [case_dd])
679
 
680
- # clicks
681
  axial.select(on_axial_select, [], outputs)
682
  sagittal.select(on_sagittal_select, [], outputs)
683
  coronal.select(on_coronal_select, [], outputs)
684
 
685
- # bars
686
  z_bar.release(on_z_release, [z_bar], outputs)
687
  x_bar.release(on_x_release, [x_bar], outputs)
688
  y_bar.release(on_y_release, [y_bar], outputs)
689
 
690
- # mode toggle
691
  mode_radio.change(on_mode_toggle, [mode_radio], outputs)
692
 
693
- # buttons
694
  seg_btn.click(on_seg_button, [], outputs)
695
  save_btn.click(on_save, [], outputs)
696
  clr_btn.click(on_clear, [], outputs)
697
  undo_btn.click(on_undo, [], outputs)
698
 
699
- # debug list selection
700
  seeds_df.select(on_seed_df_select, [], outputs)
701
  seed_dd.change(on_seed_dd_change, [seed_dd], outputs)
702
 
 
5
  import numpy as np
6
  from PIL import Image, ImageDraw, ImageOps
7
  import gradio as gr
8
+
9
+ print("\n" + "="*60)
10
+ print("🔍 INTERACTIVE-MEN-RT DEMO DEBUG INFO")
11
+ print("="*60)
12
+ print(f"📦 Gradio version: {gr.__version__}")
13
+ print(f"📁 Current directory: {os.getcwd()}")
14
+ print(f"📁 Directory contents: {os.listdir('.')}")
15
+
16
+ DATA_ROOT = Path("./samples")
17
+ print(f"\n📂 DATA_ROOT: {DATA_ROOT}")
18
+ print(f"📂 DATA_ROOT exists: {DATA_ROOT.exists()}")
19
+
20
+ if DATA_ROOT.exists():
21
+ print(f"📂 DATA_ROOT contents: {list(DATA_ROOT.iterdir())}")
22
+ EXAMPLES_CHECK = ["BraTS-MEN-RT-0071-1"]
23
+ for case in EXAMPLES_CHECK:
24
+ case_dir = DATA_ROOT / case
25
+ print(f"\n 📦 Case: {case}")
26
+ print(f" Exists: {case_dir.exists()}")
27
+ if case_dir.exists():
28
+ files = list(case_dir.iterdir())
29
+ print(f" Files: {[f.name for f in files]}")
30
+ has_t1c = any(f.name.endswith('_t1c.nii.gz') for f in files)
31
+ print(f" ✓ Has T1c: {has_t1c}")
32
+ else:
33
+ print("❌ samples folder NOT FOUND!")
34
+ print("\n" + "="*60 + "\n")
35
+
36
  # ---------- optional deps ----------
37
  try:
38
  import nibabel as nib
39
  HAVE_NIB = True
40
+ print("✓ nibabel available")
41
+ except Exception as e:
42
  HAVE_NIB = False
43
+ print(f"✗ nibabel not available: {e}")
44
+
45
  try:
46
  from scipy import ndimage as ndi
47
  HAVE_SCIPY = True
48
+ print("✓ scipy available")
49
+ except Exception as e:
50
  HAVE_SCIPY = False
51
+ print(f"✗ scipy not available: {e}")
52
+
53
  # ---------- model predictor ----------
54
  PREDICTOR = None
55
  DEVICE = "cuda:0"
56
+
57
  from huggingface_hub import snapshot_download
58
+
59
+ try:
60
+ _repo_root = snapshot_download("hanjang/Interactive-MEN-RT",
61
+ allow_patterns=["nnUNetInteractionTrainer__nnUNetPlans__3d_fullres_scratch/**"])
62
+ CKPT = os.path.join(_repo_root, "nnUNetInteractionTrainer__nnUNetPlans__3d_fullres_scratch")
63
+ print(f"[INFO] Checkpoint path: {CKPT}")
64
+ if os.path.exists(CKPT):
65
+ contents = os.listdir(CKPT)
66
+ print(f"[INFO] Checkpoint contents: {contents}")
67
+ fold_0 = os.path.join(CKPT, "fold_0")
68
+ if os.path.exists(fold_0):
69
+ print(f"[INFO] fold_0 contents: {os.listdir(fold_0)}")
70
+ else:
71
+ print(f"[ERROR] Checkpoint path does not exist!")
72
+ CKPT = None
73
+ except Exception as e:
74
+ print(f"[ERROR] Failed to download checkpoint: {e}")
75
+ CKPT = None
76
+
77
  for env in ("nnUNet_raw", "nnUNet_preprocessed", "nnUNet_results"):
78
  os.environ.setdefault(env, tempfile.mkdtemp(prefix=f"{env}_"))
79
+
80
  def _init_predictor_once():
 
81
  global PREDICTOR
82
  if PREDICTOR is not None:
83
  return True
84
+ if CKPT is None:
85
+ print("[WARN] No checkpoint available, will use fallback only")
86
+ return False
87
  try:
88
  import torch
89
  from Interactive_MEN_RT_predictor import InteractiveMENRTPredictor
 
95
  model_training_output_dir=CKPT, use_fold=0, checkpoint_name="checkpoint_best.pth"
96
  )
97
  PREDICTOR = pred
 
98
  try:
99
  if torch.cuda.is_available():
100
  x = np.zeros((1, 8, 8, 8), np.float32)
 
110
  except Exception as e:
111
  print(f"[MODEL] init failed: {e}")
112
  return False
113
+
114
  def preload_model_in_background():
115
  threading.Thread(target=_init_predictor_once, daemon=True).start()
116
+
117
  # ---------- config ----------
118
+ EXAMPLES = ["BraTS-MEN-RT-0071-1"]
 
 
 
 
 
 
119
  RENDER_PX_DEFAULT = 384
120
+ ROT_CCW = True
121
+
122
  # colors
123
  ACCENT_HEX = "#1e90ff"
124
  CROSS_RGB = (30, 144, 255)
125
+ GT_RGBA_FILL = (255, 215, 0, 128)
126
+ PR_RGBA_FILL = (255, 60, 60, 128)
127
  SEED_RGB = (89, 224, 154)
128
+ BBOX_RGB = (255, 140, 0)
129
+
130
  # ---------- state ----------
131
  class State:
132
  def __init__(self):
 
135
  self.case_id=None; self.loaded=False
136
  self.cross={"x":0,"y":0,"z":0}
137
  self.slice={"axial":0,"sagittal":0,"coronal":0}
138
+ self.seeds=[]
139
+ self.seed_views=[]
140
  self.render_px=RENDER_PX_DEFAULT
141
  self.disp_wh={"axial":(RENDER_PX_DEFAULT,RENDER_PX_DEFAULT),
142
  "sagittal":(RENDER_PX_DEFAULT,RENDER_PX_DEFAULT),
143
  "coronal":(RENDER_PX_DEFAULT,RENDER_PX_DEFAULT)}
144
+ self.active_view="axial"
145
+ self.bbox_mode = False
146
+ self.bbox_points = []
147
+ self.bboxes = []
 
 
148
  self.ref_affine = None
149
  self.ref_header = None
150
+
151
  S = State()
152
+
153
  # ---------- utils ----------
154
  def _norm01(a):
155
  a=a.astype(np.float32)
156
  p2,p98=np.percentile(a,2),np.percentile(a,98)
157
  if p98<=p2: p2,p98=float(a.min()),float(a.max()) or 1.0
158
  return np.clip((a-p2)/max(p98-p2,1e-6),0,1)
159
+
160
  def _resize_slice_nearest(arr2d,w,h):
161
  im=Image.fromarray(arr2d); im=im.resize((w,h),Image.NEAREST); return np.array(im)
162
+
163
  def _rot90_if_needed(img_or_np):
164
  if not ROT_CCW: return img_or_np
165
  if isinstance(img_or_np, Image.Image): return img_or_np.rotate(90, expand=True)
166
  return np.rot90(img_or_np, k=1)
167
+
168
  # ---------- IO ----------
169
  def _load_png_stack(case_dir):
170
  pngs = sorted(glob.glob(str(case_dir / "png_axial" / "*.png")))
 
177
  vol=_norm01(vol)
178
  print(f"[PIL] {len(pngs)} slices -> {vol.shape} in {time.time()-t0:.2f}s")
179
  return vol, None, None
180
+
181
  def _load_nifti(case_dir,case_id,ds=1):
182
  if not HAVE_NIB: return None, None, None
183
  p=case_dir/f"{case_id}_t1c.nii.gz"
 
188
  arr=_norm01(arr)
189
  print(f"[NIfTI] {case_id} -> {arr.shape} in {time.time()-t0:.2f}s")
190
  return arr, nii.affine, nii.header
191
+
192
  def _resample_mask_to_vol_shape(mask_xyz, vol_shape_xyz):
193
  mx,my,mz=mask_xyz.shape; vx,vy,vz=vol_shape_xyz
194
  out=np.zeros((vx,vy,vz),dtype=np.uint8)
 
198
  im=Image.fromarray(sl).resize((vy,vx),Image.NEAREST)
199
  out[:,:,k]=(np.array(im)>0).astype(np.uint8)
200
  return out
201
+
202
  def _load_gt(case_dir,case_id,vol_shape):
203
  candidates = [
204
  f"{case_id}_gtv.nii.gz", f"{case_id}_seg.nii.gz", f"{case_id}_gt.nii.gz",
 
219
  print(f"[GT] load error {p.name}: {e}")
220
  print("[GT] not found.")
221
  return None
222
+
223
  def load_case(case_id):
224
  case_dir=DATA_ROOT/case_id
225
  vol, affine, header = _load_png_stack(case_dir)
 
248
  S.active_view="axial"
249
  S.loaded=True
250
  print(f"[LOAD] {case_id} | shape={S.shape}")
251
+
252
+ # ---------- 2D rendering ----------
253
  def _slice2d(view):
254
  if view=="axial": sl=S.vol[:,:,S.slice["axial"]]
255
  elif view=="sagittal": sl=S.vol[S.slice["sagittal"],:,:].T
256
  else: sl=S.vol[:,S.slice["coronal"],:].T
257
  return _rot90_if_needed(sl)
258
+
259
  def _cross_pix_on_rot(view,w,h,x=None,y=None,z=None):
260
  X,Y,Z=S.shape
261
  if x is None: x=S.cross["x"]
 
271
  u=int(round(z*(w-1)/max(Z-1,1)))
272
  v=int(round((X-1-x)*(h-1)/max(X-1,1)))
273
  return u,v
274
+
275
  def _draw_cross(img_draw, view, w, h):
276
  u,v=_cross_pix_on_rot(view,w,h)
277
  img_draw.line([(u,0),(u,h)],fill=CROSS_RGB,width=1)
278
  img_draw.line([(0,v),(w,v)],fill=CROSS_RGB,width=1)
279
  img_draw.ellipse((u-5,v-5,u+5,v+5),fill=(255,255,255),outline=ACCENT_HEX,width=2)
280
+
281
  def render_top(view):
282
  if not S.loaded: return None
283
  sl=_slice2d(view)
 
287
  _draw_cross(dr, view, w, h)
288
  S.disp_wh[view]=im.size
289
  return im
290
+
291
  def _axial_mask2d_rot(mask3d):
292
  if mask3d is None: return None
293
  m = mask3d[:,:,S.slice["axial"]].astype(np.uint8)
294
  m = _rot90_if_needed(m)
295
  return m
296
+
297
  def _axial_overlay_fill(mask3d, rgba):
298
  sl = _rot90_if_needed(S.vol[:,:,S.slice["axial"]])
299
  base=Image.fromarray((sl*255).astype(np.uint8)).resize((S.render_px,S.render_px),Image.BILINEAR).convert("RGBA")
 
302
  m2d = _resize_slice_nearest((m2d>0).astype(np.uint8), S.render_px, S.render_px)
303
  over=np.zeros((S.render_px,S.render_px,4),dtype=np.uint8); over[m2d>0]=rgba
304
  return Image.alpha_composite(base,Image.fromarray(over,"RGBA")).convert("RGB")
305
+
306
  def _interaction_2d():
307
  view = S.active_view
308
  if not S.loaded: return None
 
310
  im=Image.fromarray((sl*255).astype(np.uint8)).resize((S.render_px,S.render_px),Image.BILINEAR).convert("RGB")
311
  w=h=S.render_px
312
  dr=ImageDraw.Draw(im)
 
313
  _draw_cross(dr, view, w, h)
314
+
315
  tol=0
316
  for i, (x,y,z) in enumerate(S.seeds):
317
  on_plane = (
 
325
  dr.ellipse((u-r,v-r,u+r,v+r), fill=SEED_RGB, outline=(40,140,100), width=1)
326
  dr.text((u+6, v-8), f"{i+1}", fill=(30,30,30))
327
 
 
328
  if S.bbox_mode:
329
  for i, (x,y,z) in enumerate(S.bbox_points):
330
  on_plane = (
 
339
  text = "P1" if i == 0 else "P2"
340
  dr.text((u+10, v-10), text, fill=BBOX_RGB)
341
 
 
342
  for (x1,y1,z1,x2,y2,z2) in S.bboxes:
343
  if view=="axial":
344
  curr_z = S.slice["axial"]
 
348
  dr.rectangle([u1,v1,u2,v2], outline=(0,255,0), width=2)
349
 
350
  return im
351
+
352
  # ===================== segmentation ========================================
353
  def _segment_with_model():
354
  if PREDICTOR is None and not _init_predictor_once():
 
359
  PREDICTOR.set_image(img)
360
  PREDICTOR.set_target_buffer(np.zeros_like(img[0], np.float32))
361
  PREDICTOR._finish_preprocessing_and_initialize_interactions()
 
 
362
  for (x,y,z) in S.seeds:
363
  PREDICTOR.add_point_interaction(x, y, z, foreground=True)
 
 
364
  for (x1,y1,z1,x2,y2,z2) in S.bboxes:
365
  PREDICTOR.add_bbox_interaction(
366
  min(x1,x2), min(y1,y2), min(z1,z2),
367
  max(x1,x2), max(y1,y2), max(z1,z2)
368
  )
 
369
  PREDICTOR._predict_without_interaction()
370
  pred = (PREDICTOR.target_buffer.astype(np.float32) > 0.5).astype(np.uint8)
371
  if pred.shape != S.shape:
 
375
  except Exception as e:
376
  print(f"[MODEL] inference failed: {e}")
377
  return None, "model-error"
378
+
379
  def _segment_fallback():
380
  if (not S.seeds and not S.bboxes) or not HAVE_SCIPY:
381
  return None, "no-interactions-or-scipy"
 
400
  mask=ndi.binary_closing(mask,iterations=1); mask=ndi.binary_opening(mask,iterations=1)
401
  print(f"[FB] seg {time.time()-t0:.3f}s | vox={int(mask.sum())}")
402
  return (mask>0).astype(np.uint8), "ok"
403
+
404
  def do_segment():
405
  pred, tag = _segment_with_model()
406
  if pred is None:
 
408
  S.pred = pred if pred is not None else None
409
  print(f"[SEG] done: {tag}")
410
  return "OK" if S.pred is not None else "Failed"
411
+
412
  def save_prediction():
413
  if S.pred is None:
414
  return "No prediction to save", None
 
417
  try:
418
  tmp_dir = Path(tempfile.mkdtemp(prefix="menrt_output_"))
419
  out_path = tmp_dir / f"{S.case_id}_pred.nii.gz"
 
420
  affine = S.ref_affine if S.ref_affine is not None else np.eye(4)
421
  header = S.ref_header.copy() if S.ref_header is not None else None
 
422
  nii_img = nib.Nifti1Image(S.pred.astype(np.uint8), affine, header=header)
423
  nib.save(nii_img, str(out_path))
 
424
  print(f"[SAVE] {out_path}")
425
  return "Saved successfully!", str(out_path)
426
  except Exception as e:
427
  print(f"[SAVE] error: {e}")
428
  return f"Save failed: {e}", None
429
+
430
+ # ---------- helpers ----------
431
  def _seed_rows():
 
432
  rows=[]
433
  for i,(x,y,z) in enumerate(S.seeds):
434
  v = S.seed_views[i] if i < len(S.seed_views) else ""
 
436
  for i,(x1,y1,z1,x2,y2,z2) in enumerate(S.bboxes):
437
  rows.append([len(S.seeds)+i+1, "bbox", "3D", f"{x1}-{x2}", f"{y1}-{y2}", f"{z1}-{z2}"])
438
  return rows
439
+
440
  def _seed_dropdown_options():
 
441
  opts=[]
442
  for i,(x,y,z) in enumerate(S.seeds):
443
  v = S.seed_views[i] if i < len(S.seed_views) else "axial"
 
445
  for i,(x1,y1,z1,x2,y2,z2) in enumerate(S.bboxes):
446
  opts.append(f"3D → bbox {i+1} → ({x1},{y1},{z1})-({x2},{y2},{z2})")
447
  return opts
448
+
449
  def _debug_widgets(current_idx=None):
450
  rows = _seed_rows()
451
  df_upd = gr.update(value=rows)
 
456
  val = (opts[current_idx] if (0 <= current_idx < len(opts)) else (opts[-1] if opts else None))
457
  dd_upd = gr.update(choices=opts, value=val)
458
  return df_upd, dd_upd
459
+
460
  def _figs_and_imgs():
461
  top_ax=render_top("axial")
462
  top_sg=render_top("sagittal")
 
465
  ax_pr = _axial_overlay_fill(S.pred, PR_RGBA_FILL)
466
  inter2d = _interaction_2d()
467
  return top_ax, top_sg, top_co, ax_gt, ax_pr, inter2d
468
+
469
  def _bar_ranges_and_values():
470
  X,Y,Z=S.shape
471
  return (gr.update(minimum=0,maximum=Z-1,value=S.slice["axial"],visible=True),
472
  gr.update(minimum=0,maximum=X-1,value=S.slice["sagittal"],visible=True),
473
  gr.update(minimum=0,maximum=Y-1,value=S.slice["coronal"],visible=True))
474
+
475
  def _parse_evt_xy(evt):
476
+ """강화된 이벤트 파싱"""
477
+ print(f"[DEBUG_EVT] Event received: {evt}")
478
+ print(f"[DEBUG_EVT] Event type: {type(evt)}")
479
+ print(f"[DEBUG_EVT] Event dir: {dir(evt)}")
480
+
481
+ if evt is None:
482
+ print("[DEBUG_EVT] Event is None!")
483
+ return None
484
+
485
  try:
486
+ # Method 1: evt.index
487
  if hasattr(evt, "index") and evt.index is not None:
488
  ix = evt.index
489
+ print(f"[DEBUG_EVT] Found index: {ix}")
490
  if isinstance(ix, (list, tuple)) and len(ix) >= 2:
491
+ result = int(ix[0]), int(ix[1])
492
+ print(f"[DEBUG_EVT] Parsed from index: {result}")
493
+ return result
494
+
495
+ # Method 2: evt.x, evt.y
496
  if hasattr(evt, "x") and hasattr(evt, "y"):
497
+ x_val = getattr(evt, "x")
498
+ y_val = getattr(evt, "y")
499
+ print(f"[DEBUG_EVT] Found x={x_val}, y={y_val}")
500
+ if x_val is not None and y_val is not None:
501
+ result = int(x_val), int(y_val)
502
+ print(f"[DEBUG_EVT] Parsed from x,y: {result}")
503
+ return result
504
+ except Exception as e:
505
+ print(f"[DEBUG_EVT] Parse error: {e}")
506
+
507
+ print("[DEBUG_EVT] Failed to parse coordinates!")
 
 
508
  return None
509
+
510
  def _disp_to_vol(view,u,v):
511
  X,Y,Z=S.shape; w,h=S.disp_wh[view]
512
  if w<=0 or h<=0: w=h=S.render_px
 
524
  y = S.slice["coronal"]
525
  x=max(0,min(X-1,x)); y=max(0,min(Y-1,y)); z=max(0,min(Z-1,z))
526
  return x,y,z
527
+
528
  def _thumb_from_case(case_id,px=96):
529
  case_dir=DATA_ROOT/case_id
530
  pngs = sorted(glob.glob(str(case_dir / "png_axial" / "*.png")))
 
545
  im = Image.fromarray((mid*255).astype(np.uint8)).resize((px,px),Image.BILINEAR)
546
  im = _rot90_if_needed(im)
547
  return ImageOps.expand(im,border=1,fill=200)
548
+
549
  # ---------- callbacks ----------
550
  def on_load(case_id):
551
  load_case(case_id)
552
  preload_model_in_background()
553
  imgs=_figs_and_imgs(); bars=_bar_ranges_and_values(); dbg=_debug_widgets()
554
+ return (*imgs,*bars,*dbg, gr.update(value="Point"), None, "Loaded")
555
+
556
+ def on_gallery_select(evt: gr.SelectData):
557
  idx=0
558
  if hasattr(evt,"index"):
559
  ix=evt.index; idx=int(ix[0] if isinstance(ix,(list,tuple)) else ix)
 
561
  cid=EXAMPLES[idx]
562
  out=on_load(cid)
563
  return (*out, gr.update(value=cid))
564
+
565
  def _click_common(view, evt: gr.SelectData):
566
+ print(f"[CLICK_HANDLER] Called for view={view}")
567
+ print(f"[CLICK_HANDLER] S.loaded={S.loaded}")
568
+ print(f"[CLICK_HANDLER] S.bbox_mode={S.bbox_mode}")
569
+
570
+ if not S.loaded:
571
+ print("[CLICK_HANDLER] Not loaded yet!")
572
+ return (*_figs_and_imgs(),*_bar_ranges_and_values(),*_debug_widgets(),
573
+ gr.update(value="Point" if not S.bbox_mode else "BBox"), None, "Not loaded")
574
+
575
  xy=_parse_evt_xy(evt)
576
  if xy is None:
577
+ print("[CLICK_HANDLER] Failed to parse event!")
578
+ return (*_figs_and_imgs(),*_bar_ranges_and_values(),*_debug_widgets(),
579
+ gr.update(value="Point" if not S.bbox_mode else "BBox"), None, "Parse failed")
580
+
581
  u,v=xy
582
  x,y,z=_disp_to_vol(view,u,v)
583
+ print(f"[CLICK_HANDLER] Clicked at display ({u},{v}) -> volume ({x},{y},{z})")
584
 
585
  if S.bbox_mode:
 
586
  S.bbox_points.append((x,y,z))
587
  if len(S.bbox_points) == 2:
588
  p1, p2 = S.bbox_points
589
  S.bboxes.append((*p1, *p2))
590
  S.bbox_points = []
591
  print(f"[UI] bbox created: {S.bboxes[-1]}")
592
+ status = "BBox created!"
593
  else:
594
  print(f"[UI] bbox point {len(S.bbox_points)}/2")
595
+ status = f"BBox point {len(S.bbox_points)}/2"
596
  else:
 
597
  S.seeds.append((x,y,z)); S.seed_views.append(view)
598
  print(f"[UI] seed+ {(x,y,z)} total={len(S.seeds)}")
599
+ status = f"Added point {len(S.seeds)}"
600
 
601
  S.cross={"x":x,"y":y,"z":z}
602
  S.slice={"sagittal":x,"coronal":y,"axial":z}
 
604
  imgs=_figs_and_imgs(); bars=_bar_ranges_and_values()
605
  idx = len(S.seeds) + len(S.bboxes) - 1
606
  dbg=_debug_widgets(current_idx=idx)
607
+ return (*imgs,*bars,*dbg, gr.update(value="Point" if not S.bbox_mode else "BBox"), None, status)
608
+
609
+ def on_axial_select(evt: gr.SelectData):
610
+ print("[EVENT] on_axial_select triggered!")
611
+ return _click_common("axial", evt)
612
+
613
+ def on_sagittal_select(evt: gr.SelectData):
614
+ print("[EVENT] on_sagittal_select triggered!")
615
+ return _click_common("sagittal", evt)
616
+
617
+ def on_coronal_select(evt: gr.SelectData):
618
+ print("[EVENT] on_coronal_select triggered!")
619
+ return _click_common("coronal", evt)
620
+
621
  def on_seg_button():
622
+ msg=do_segment()
623
+ return (*_figs_and_imgs(),*_bar_ranges_and_values(),*_debug_widgets(),
624
+ gr.update(value="Point" if not S.bbox_mode else "BBox"), None, f"Segment: {msg}")
625
+
626
  def on_clear():
627
  S.seeds=[]; S.seed_views=[]; S.pred=None
628
  S.bboxes=[]; S.bbox_points=[]
629
  print("[UI] cleared all")
630
+ return (*_figs_and_imgs(),*_bar_ranges_and_values(),*_debug_widgets(),
631
+ gr.update(value="Point"), None, "Cleared")
632
+
633
  def on_undo():
634
  if S.bbox_mode and S.bbox_points:
635
  S.bbox_points.pop()
 
638
  elif S.seeds:
639
  S.seeds.pop(); S.seed_views.pop() if S.seed_views else None
640
  print("[UI] undo")
641
+ return (*_figs_and_imgs(),*_bar_ranges_and_values(),*_debug_widgets(),
642
+ gr.update(value="Point" if not S.bbox_mode else "BBox"), None, "Undo")
643
+
644
  def on_mode_toggle(mode):
645
  S.bbox_mode = (mode == "BBox")
646
  S.bbox_points = []
647
  print(f"[UI] mode -> {mode}")
648
+ return (*_figs_and_imgs(),*_bar_ranges_and_values(),*_debug_widgets(),
649
+ gr.update(value=mode), None, f"Mode: {mode}")
650
+
651
  def on_save():
652
  msg, path = save_prediction()
653
  if path:
654
+ return (*_figs_and_imgs(),*_bar_ranges_and_values(),*_debug_widgets(),
655
+ gr.update(value="Point" if not S.bbox_mode else "BBox"),
656
+ gr.update(visible=True, value=path), msg)
657
+ return (*_figs_and_imgs(),*_bar_ranges_and_values(),*_debug_widgets(),
658
+ gr.update(value="Point" if not S.bbox_mode else "BBox"),
659
+ gr.update(visible=False, value=None), msg)
660
+
661
  def on_z_release(z_idx):
662
  S.slice["axial"]=int(z_idx); S.cross["z"]=S.slice["axial"]; S.active_view="axial"
663
+ return (*_figs_and_imgs(),*_bar_ranges_and_values(),*_debug_widgets(),
664
+ gr.update(value="Point" if not S.bbox_mode else "BBox"), None, f"Z={z_idx}")
665
+
666
  def on_x_release(x_idx):
667
  S.slice["sagittal"]=int(x_idx); S.cross["x"]=S.slice["sagittal"]; S.active_view="sagittal"
668
+ return (*_figs_and_imgs(),*_bar_ranges_and_values(),*_debug_widgets(),
669
+ gr.update(value="Point" if not S.bbox_mode else "BBox"), None, f"X={x_idx}")
670
+
671
  def on_y_release(y_idx):
672
  S.slice["coronal"]=int(y_idx); S.cross["y"]=S.slice["coronal"]; S.active_view="coronal"
673
+ return (*_figs_and_imgs(),*_bar_ranges_and_values(),*_debug_widgets(),
674
+ gr.update(value="Point" if not S.bbox_mode else "BBox"), None, f"Y={y_idx}")
675
+
676
  def _jump_to_idx(idx:int):
 
677
  total = len(S.seeds) + len(S.bboxes)
678
  if not (0 <= idx < total):
679
+ return (*_figs_and_imgs(),*_bar_ranges_and_values(),*_debug_widgets(),
680
+ gr.update(value="Point" if not S.bbox_mode else "BBox"), None, "")
681
  if idx < len(S.seeds):
682
  x,y,z = S.seeds[idx]
683
  view = S.seed_views[idx] if idx < len(S.seed_views) else "axial"
684
  S.cross={"x":x,"y":y,"z":z}
685
  S.slice={"sagittal":x,"coronal":y,"axial":z}
686
  S.active_view=view
687
+ return (*_figs_and_imgs(),*_bar_ranges_and_values(),*_debug_widgets(current_idx=idx),
688
+ gr.update(value="Point" if not S.bbox_mode else "BBox"), None, f"Jump to {idx+1}")
689
+
690
  def on_seed_df_select(evt):
 
691
  try:
692
  row = int(evt.index[0]) if hasattr(evt, "index") else 0
693
  except Exception:
694
  row = 0
695
  return _jump_to_idx(row)
696
+
697
  def on_seed_dd_change(val):
 
698
  if not val:
699
+ return (*_figs_and_imgs(),*_bar_ranges_and_values(),*_debug_widgets(),
700
+ gr.update(value="Point" if not S.bbox_mode else "BBox"), None, "")
701
  try:
702
  if "point" in val:
703
  p = val.split("point")[1].strip()
 
708
  except Exception:
709
  idx = len(S.seeds) + len(S.bboxes) - 1
710
  return _jump_to_idx(idx)
711
+
712
  # ---------- UI ----------
713
  css = """
714
  :root {
 
720
  .tiny .gr-slider input[type="range"]{height:6px}
721
  .tiny .gr-form{gap:6px}
722
  .smallnote{color:#3e6285;font-size:12px;margin-top:6px}
 
 
723
  """
724
+
725
  THUMBS=[_thumb_from_case(cid,px=96) for cid in EXAMPLES]
726
+
727
  with gr.Blocks(css=css, title="Interactive-MEN-RT", theme=gr.themes.Soft(), analytics_enabled=False) as demo:
728
  gr.Markdown("""
729
  <h3 style='color:#114b8b'>Interactive-MEN-RT Segmentation</h3>
 
732
  <b>⚠️ Research only — Not for clinical use</b>
733
  </div>
734
  """)
735
+
736
  with gr.Row():
737
  with gr.Column(scale=1, min_width=280, elem_classes=["round"]):
738
+ gr.Markdown("<div class='section'>Demo Case</div>")
739
+ gallery = gr.Gallery(value=THUMBS, columns=1, height=110, allow_preview=False, preview=False, show_label=False)
740
  case_dd = gr.Dropdown(choices=EXAMPLES, value=EXAMPLES[0], label="Case")
741
 
742
  mode_radio = gr.Radio(["Point", "BBox"], value="Point", label="Interaction Mode")
 
757
  wrap=True,
758
  row_count=(0, "dynamic")
759
  )
760
+ seed_dd = gr.Dropdown(choices=[], value=None, label="Go to point")
761
 
762
  pred_file = gr.File(label="Download Prediction", visible=False)
763
  status_text = gr.Textbox(label="Status", value="", interactive=False, lines=1)
 
774
  coronal = gr.Image(type="pil", interactive=True, height=RENDER_PX_DEFAULT+8, label="Coronal (Y)")
775
  y_bar = gr.Slider(0,1,value=0,step=1,label="Y", elem_classes=["tiny"])
776
  with gr.Row():
777
+ out_ax_gt = gr.Image(type="pil", interactive=False, height=RENDER_PX_DEFAULT+8, label="Ground Truth")
778
+ out_ax_pr = gr.Image(type="pil", interactive=False, height=RENDER_PX_DEFAULT+8, label="Prediction")
779
+ inter2d = gr.Image(type="pil", interactive=False, height=RENDER_PX_DEFAULT+8, label="Interactions")
780
 
 
781
  outputs = [axial, sagittal, coronal, out_ax_gt, out_ax_pr, inter2d,
782
  z_bar, x_bar, y_bar, seeds_df, seed_dd, mode_radio, pred_file, status_text]
783
 
 
784
  demo.load(lambda: on_load(EXAMPLES[0]), [], outputs)
 
 
785
  case_dd.change(lambda cid: on_load(cid), [case_dd], outputs)
 
 
786
  gallery.select(on_gallery_select, [], outputs + [case_dd])
787
 
788
+ # 클릭 이벤트 - 강화된 핸들러
789
  axial.select(on_axial_select, [], outputs)
790
  sagittal.select(on_sagittal_select, [], outputs)
791
  coronal.select(on_coronal_select, [], outputs)
792
 
 
793
  z_bar.release(on_z_release, [z_bar], outputs)
794
  x_bar.release(on_x_release, [x_bar], outputs)
795
  y_bar.release(on_y_release, [y_bar], outputs)
796
 
 
797
  mode_radio.change(on_mode_toggle, [mode_radio], outputs)
798
 
 
799
  seg_btn.click(on_seg_button, [], outputs)
800
  save_btn.click(on_save, [], outputs)
801
  clr_btn.click(on_clear, [], outputs)
802
  undo_btn.click(on_undo, [], outputs)
803
 
 
804
  seeds_df.select(on_seed_df_select, [], outputs)
805
  seed_dd.change(on_seed_dd_change, [seed_dd], outputs)
806