"""Gradio demo for Interactive-MEN-RT interactive meningioma segmentation.

Loads a sample case (PNG slice stack or T1c NIfTI), shows axial / sagittal /
coronal views, collects point and 3D bounding-box interactions, and runs the
interactive predictor (or a scipy-based heuristic fallback) to produce a mask.

NOTE(review): all viewer state lives in one module-level ``State`` instance,
so every browser session shares the same volume, seeds and prediction.
"""
import glob
import os
import tempfile
import threading
import time
import warnings

warnings.filterwarnings("ignore")
# Set before importing gradio (presumably read at import time).
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"

from pathlib import Path

import numpy as np
from PIL import Image, ImageDraw, ImageOps
import gradio as gr

# ---------- startup diagnostics ----------
print("\n" + "=" * 60)
print("🔍 INTERACTIVE-MEN-RT DEMO DEBUG INFO")
print("=" * 60)
print(f"📦 Gradio version: {gr.__version__}")
print(f"📁 Current directory: {os.getcwd()}")
print(f"📁 Directory contents: {os.listdir('.')}")

DATA_ROOT = Path("./samples")
print(f"\n📂 DATA_ROOT: {DATA_ROOT}")
print(f"📂 DATA_ROOT exists: {DATA_ROOT.exists()}")
if DATA_ROOT.exists():
    print(f"📂 DATA_ROOT contents: {list(DATA_ROOT.iterdir())}")
    EXAMPLES_CHECK = ["BraTS-MEN-RT-0071-1"]
    for case in EXAMPLES_CHECK:
        case_dir = DATA_ROOT / case
        print(f"\n 📦 Case: {case}")
        print(f" Exists: {case_dir.exists()}")
        if case_dir.exists():
            files = list(case_dir.iterdir())
            print(f" Files: {[f.name for f in files]}")
            has_t1c = any(f.name.endswith('_t1c.nii.gz') for f in files)
            print(f" ✓ Has T1c: {has_t1c}")
else:
    print("❌ samples folder NOT FOUND!")
print("\n" + "=" * 60 + "\n")

# ---------- optional deps ----------
try:
    import nibabel as nib
    HAVE_NIB = True
    print("✓ nibabel available")
except Exception as e:
    HAVE_NIB = False
    print(f"✗ nibabel not available: {e}")

try:
    from scipy import ndimage as ndi
    HAVE_SCIPY = True
    print("✓ scipy available")
except Exception as e:
    HAVE_SCIPY = False
    print(f"✗ scipy not available: {e}")

# ---------- model predictor ----------
PREDICTOR = None
DEVICE = "cuda:0"
# Serialises model construction: the background preload thread and a
# "Segment" click can both reach _init_predictor_once at the same time.
_PREDICTOR_LOCK = threading.Lock()

from huggingface_hub import snapshot_download

try:
    _repo_root = snapshot_download(
        "hanjang/Interactive-MEN-RT",
        allow_patterns=["nnUNetInteractionTrainer__nnUNetPlans__3d_fullres_scratch/**"],
    )
    CKPT = os.path.join(_repo_root, "nnUNetInteractionTrainer__nnUNetPlans__3d_fullres_scratch")
    print(f"[INFO] Checkpoint path: {CKPT}")
    if os.path.exists(CKPT):
        contents = os.listdir(CKPT)
        print(f"[INFO] Checkpoint contents: {contents}")
        fold_0 = os.path.join(CKPT, "fold_0")
        if os.path.exists(fold_0):
            print(f"[INFO] fold_0 contents: {os.listdir(fold_0)}")
    else:
        print("[ERROR] Checkpoint path does not exist!")
        CKPT = None
except Exception as e:
    print(f"[ERROR] Failed to download checkpoint: {e}")
    CKPT = None

# Provide throw-away nnU-Net directories only when the user has not set them;
# creating the temp dir unconditionally would leak one per unset-check.
for env in ("nnUNet_raw", "nnUNet_preprocessed", "nnUNet_results"):
    if env not in os.environ:
        os.environ[env] = tempfile.mkdtemp(prefix=f"{env}_")


def _init_predictor_once():
    """Build the global PREDICTOR on first use.

    Returns:
        True when PREDICTOR is ready; False when no checkpoint is available or
        construction failed (callers then use the heuristic fallback).

    Construction runs under _PREDICTOR_LOCK, and PREDICTOR is only published
    after warm-up, so a concurrent caller never sees a half-initialised model.
    """
    global PREDICTOR
    if PREDICTOR is not None:
        return True
    if CKPT is None:
        print("[WARN] No checkpoint available, will use fallback only")
        return False
    with _PREDICTOR_LOCK:
        # Another thread may have finished initialisation while we waited.
        if PREDICTOR is not None:
            return True
        try:
            import torch
            from Interactive_MEN_RT_predictor import InteractiveMENRTPredictor

            dev = torch.device(DEVICE if torch.cuda.is_available() else "cpu")
            pred = InteractiveMENRTPredictor(
                device=dev, use_torch_compile=False, do_autozoom=False, verbose=False
            )
            pred.initialize_from_trained_model_folder(
                model_training_output_dir=CKPT,
                use_fold=0,
                checkpoint_name="checkpoint_best.pth",
            )
            # Tiny dummy run to warm up CUDA kernels; best-effort only.
            try:
                if torch.cuda.is_available():
                    x = np.zeros((1, 8, 8, 8), np.float32)
                    pred.reset_interactions()
                    pred.set_image(x)
                    pred.set_target_buffer(np.zeros_like(x[0], np.float32))
                    pred._finish_preprocessing_and_initialize_interactions()
                    torch.cuda.synchronize()
            except Exception as warm_err:
                print(f"[MODEL] warm-up skipped: {warm_err}")
            PREDICTOR = pred
            print("[MODEL] ready")
            return True
        except Exception as e:
            print(f"[MODEL] init failed: {e}")
            return False


def preload_model_in_background():
    """Start model construction on a daemon thread unless it already exists."""
    if PREDICTOR is None:
        threading.Thread(target=_init_predictor_once, daemon=True).start()


# ---------- config ----------
EXAMPLES = ["BraTS-MEN-RT-0071-1"]
RENDER_PX_DEFAULT = 384
ROT_CCW = True

# colors
ACCENT_HEX = "#1e90ff"
CROSS_RGB = (30, 144, 255)
GT_RGBA_FILL = (255, 215, 0, 128)
PR_RGBA_FILL = (255, 60, 60, 128)
SEED_RGB = (89, 224, 154)
BBOX_RGB = (255, 140, 0)


# ---------- state ----------
class State:
    """Mutable viewer state shared by all callbacks (one global instance, S)."""

    def __init__(self):
        self.vol = None          # normalised volume, indexed [x, y, z]
        self.shape = None        # (X, Y, Z) of self.vol
        self.gt = None           # ground-truth mask (uint8) or None
        self.pred = None         # predicted mask (uint8) or None
        self.case_id = None
        self.loaded = False
        self.cross = {"x": 0, "y": 0, "z": 0}
        self.slice = {"axial": 0, "sagittal": 0, "coronal": 0}
        self.seeds = []          # list of (x, y, z) foreground clicks
        self.seed_views = []     # view name each seed was clicked in
        self.render_px = RENDER_PX_DEFAULT
        self.disp_wh = {
            "axial": (RENDER_PX_DEFAULT, RENDER_PX_DEFAULT),
            "sagittal": (RENDER_PX_DEFAULT, RENDER_PX_DEFAULT),
            "coronal": (RENDER_PX_DEFAULT, RENDER_PX_DEFAULT),
        }
        self.active_view = "axial"
        self.bbox_mode = False
        self.bbox_points = []    # pending corners (0 or 1) of the bbox being drawn
        self.bboxes = []         # list of (x1, y1, z1, x2, y2, z2)
        self.ref_affine = None
        self.ref_header = None
        # Row index currently selected in the "Go to point" dropdown; None
        # means "last row". Remembered so unrelated events don't reset it.
        self.sel_idx = None


S = State()


# ---------- utils ----------
def _norm01(a):
    """Rescale *a* to float32 in [0, 1] using its 2nd/98th percentiles."""
    a = a.astype(np.float32)
    p2, p98 = np.percentile(a, 2), np.percentile(a, 98)
    if p98 <= p2:
        # Degenerate percentile window: fall back to the full min/max range.
        p2, p98 = float(a.min()), float(a.max()) or 1.0
    return np.clip((a - p2) / max(p98 - p2, 1e-6), 0, 1)


def _resize_slice_nearest(arr2d, w, h):
    """Nearest-neighbour resize of a 2D array to width *w*, height *h*."""
    im = Image.fromarray(arr2d)
    im = im.resize((w, h), Image.NEAREST)
    return np.array(im)


def _rot90_if_needed(img_or_np):
    """Rotate a PIL image or numpy slice 90° CCW when ROT_CCW is enabled."""
    if not ROT_CCW:
        return img_or_np
    if isinstance(img_or_np, Image.Image):
        return img_or_np.rotate(90, expand=True)
    return np.rot90(img_or_np, k=1)


# ---------- IO ----------
def _find_axial_slices(case_dir):
    """Return sorted slice paths under ``case_dir/png_axial`` (PNG, else JPG)."""
    pngs = sorted(glob.glob(str(case_dir / "png_axial" / "*.png")))
    if not pngs:
        pngs = sorted(glob.glob(str(case_dir / "png_axial" / "*.jpg")))
    return pngs


def _load_png_stack(case_dir):
    """Load a 2D slice stack as a normalised volume.

    Returns:
        (vol, None, None) with slices stacked along axis 2, or
        (None, None, None) when no slices are found.
    """
    pngs = _find_axial_slices(case_dir)
    if not pngs:
        return None, None, None
    t0 = time.time()
    arr = []
    for p in pngs:
        with Image.open(p) as src:  # close each file handle promptly
            arr.append(np.array(src.convert("L")))
    vol = np.stack(arr, axis=2).astype(np.float32)
    vol = _norm01(vol)
    print(f"[PIL] {len(pngs)} slices -> {vol.shape} in {time.time()-t0:.2f}s")
    return vol, None, None


def _load_nifti(case_dir, case_id, ds=1):
    """Load ``<case_id>_t1c.nii.gz`` (strided by *ds*) as a normalised volume.

    Returns:
        (vol, affine, header), or (None, None, None) if nibabel is missing or
        the file does not exist.
    """
    if not HAVE_NIB:
        return None, None, None
    p = case_dir / f"{case_id}_t1c.nii.gz"
    if not p.exists():
        return None, None, None
    t0 = time.time()
    nii = nib.load(str(p))
    arr = np.asanyarray(nii.dataobj[::ds, ::ds, ::ds], dtype=np.float32)
    arr = _norm01(arr)
    print(f"[NIfTI] {case_id} -> {arr.shape} in {time.time()-t0:.2f}s")
    return arr, nii.affine, nii.header


def _resample_mask_to_vol_shape(mask_xyz, vol_shape_xyz):
    """Nearest-neighbour resample a 3D mask to *vol_shape_xyz*, slice by slice along z."""
    mx, my, mz = mask_xyz.shape
    vx, vy, vz = vol_shape_xyz
    out = np.zeros((vx, vy, vz), dtype=np.uint8)
    for k in range(vz):
        src_k = int(round(k * (mz - 1) / max(vz - 1, 1)))
        sl = (mask_xyz[:, :, src_k] > 0).astype(np.uint8) * 255
        # PIL sizes are (width, height) = (columns, rows) -> array shape (vx, vy).
        im = Image.fromarray(sl).resize((vy, vx), Image.NEAREST)
        out[:, :, k] = (np.array(im) > 0).astype(np.uint8)
    return out


def _load_gt(case_dir, case_id, vol_shape):
    """Find and load a ground-truth mask for the case, resampled to *vol_shape*.

    Returns a binary uint8 mask, or None when no candidate file loads.
    """
    candidates = [
        f"{case_id}_gtv.nii.gz", f"{case_id}_seg.nii.gz", f"{case_id}_gt.nii.gz",
        "gtv.nii.gz", "seg.nii.gz", "gt.nii.gz",
        f"{case_id}_gtv.nii", f"{case_id}_seg.nii", f"{case_id}_gt.nii",
    ]
    if HAVE_NIB:
        for name in candidates:
            p = case_dir / name
            if p.exists():
                try:
                    # Binarise before narrowing to uint8 so labels >255 don't wrap to 0.
                    m = (np.asanyarray(nib.load(str(p)).dataobj) > 0).astype(np.uint8)
                    print(f"[GT] found {p.name} raw={m.shape}")
                    m = _resample_mask_to_vol_shape(m, vol_shape)
                    print(f"[GT] resized -> {m.shape}")
                    return (m > 0).astype(np.uint8)
                except Exception as e:
                    print(f"[GT] load error {p.name}: {e}")
    print("[GT] not found.")
    return None


def load_case(case_id):
    """Load *case_id* into S (PNG stack, else NIfTI, else a synthetic blob) and reset interactions."""
    case_dir = DATA_ROOT / case_id
    vol, affine, header = _load_png_stack(case_dir)
    if vol is None:
        vol, affine, header = _load_nifti(case_dir, case_id, ds=1)
    if vol is None:
        Z = 96
        # Each axis needs its own dimension so the three grids broadcast to (X, Y, Z).
        x = np.linspace(-1, 1, RENDER_PX_DEFAULT)[:, None, None]
        y = np.linspace(-1, 1, RENDER_PX_DEFAULT)[None, :, None]
        z = np.linspace(-1, 1, Z)[None, None, :]
        vol = np.exp(-(x**2 + y**2 + z**2) * 6).astype(np.float32)
        affine = np.eye(4)
        header = None
        print("[VOL] dummy")
    S.vol = vol
    S.shape = vol.shape
    S.pred = None
    S.seeds = []
    S.seed_views = []
    S.bbox_mode = False
    S.bbox_points = []
    S.bboxes = []
    S.sel_idx = None
    S.case_id = case_id
    S.ref_affine = affine
    S.ref_header = header
    X, Y, Z = S.shape
    S.cross = {"x": X // 2, "y": Y // 2, "z": Z // 2}
    S.slice = {"sagittal": S.cross["x"], "coronal": S.cross["y"], "axial": S.cross["z"]}
    S.gt = _load_gt(case_dir, case_id, S.shape)
    S.render_px = RENDER_PX_DEFAULT
    S.active_view = "axial"
    S.loaded = True
    print(f"[LOAD] {case_id} | shape={S.shape}")


# ---------- 2D rendering ----------
def _slice2d(view):
    """Return the current 2D slice for *view*, rotated for display."""
    if view == "axial":
        sl = S.vol[:, :, S.slice["axial"]]
    elif view == "sagittal":
        sl = S.vol[S.slice["sagittal"], :, :].T
    else:
        sl = S.vol[:, S.slice["coronal"], :].T
    return _rot90_if_needed(sl)


def _slice_rgb(view):
    """Render the current *view* slice as a render_px x render_px RGB image."""
    sl = _slice2d(view)
    im = Image.fromarray((sl * 255).astype(np.uint8))
    return im.resize((S.render_px, S.render_px), Image.BILINEAR).convert("RGB")


def _cross_pix_on_rot(view, w, h, x=None, y=None, z=None):
    """Map a voxel (defaults to the cross) to display pixel (u, v) in a w x h *view* image."""
    X, Y, Z = S.shape
    if x is None:
        x = S.cross["x"]
    if y is None:
        y = S.cross["y"]
    if z is None:
        z = S.cross["z"]
    if view == "axial":
        u = int(round(x * (w - 1) / max(X - 1, 1)))
        v = int(round((Y - 1 - y) * (h - 1) / max(Y - 1, 1)))
    elif view == "sagittal":
        u = int(round(z * (w - 1) / max(Z - 1, 1)))
        v = int(round((Y - 1 - y) * (h - 1) / max(Y - 1, 1)))
    else:
        u = int(round(z * (w - 1) / max(Z - 1, 1)))
        v = int(round((X - 1 - x) * (h - 1) / max(X - 1, 1)))
    return u, v


def _draw_cross(img_draw, view, w, h):
    """Draw the crosshair lines and centre marker for *view*."""
    u, v = _cross_pix_on_rot(view, w, h)
    img_draw.line([(u, 0), (u, h)], fill=CROSS_RGB, width=1)
    img_draw.line([(0, v), (w, v)], fill=CROSS_RGB, width=1)
    img_draw.ellipse((u - 5, v - 5, u + 5, v + 5), fill=(255, 255, 255), outline=ACCENT_HEX, width=2)


def render_top(view):
    """Render *view* with its crosshair; records the display size. None if nothing is loaded."""
    if not S.loaded:
        return None
    im = _slice_rgb(view)
    w = h = S.render_px
    dr = ImageDraw.Draw(im)
    _draw_cross(dr, view, w, h)
    S.disp_wh[view] = im.size
    return im


def _axial_mask2d_rot(mask3d):
    """Current axial slice of *mask3d*, rotated for display (None passes through)."""
    if mask3d is None:
        return None
    m = mask3d[:, :, S.slice["axial"]].astype(np.uint8)
    m = _rot90_if_needed(m)
    return m


def _axial_overlay_fill(mask3d, rgba):
    """Axial slice with *mask3d* alpha-blended in *rgba*; None if nothing is loaded."""
    if not S.loaded:
        return None
    sl = _slice2d("axial")
    base = Image.fromarray((sl * 255).astype(np.uint8)).resize(
        (S.render_px, S.render_px), Image.BILINEAR
    ).convert("RGBA")
    m2d = _axial_mask2d_rot(mask3d)
    if m2d is None:
        return base.convert("RGB")
    m2d = _resize_slice_nearest((m2d > 0).astype(np.uint8), S.render_px, S.render_px)
    over = np.zeros((S.render_px, S.render_px, 4), dtype=np.uint8)
    over[m2d > 0] = rgba
    return Image.alpha_composite(base, Image.fromarray(over, "RGBA")).convert("RGB")


def _on_plane(view, x, y, z, tol=0):
    """True when voxel (x, y, z) lies within *tol* of the slice shown in *view*."""
    if view == "axial":
        return abs(z - S.slice["axial"]) <= tol
    if view == "sagittal":
        return abs(x - S.slice["sagittal"]) <= tol
    if view == "coronal":
        return abs(y - S.slice["coronal"]) <= tol
    return False


def _interaction_2d():
    """Render the active view with seeds, pending bbox corners and (axial only) finished bboxes."""
    view = S.active_view
    if not S.loaded:
        return None
    im = _slice_rgb(view)
    w = h = S.render_px
    dr = ImageDraw.Draw(im)
    _draw_cross(dr, view, w, h)

    for i, (x, y, z) in enumerate(S.seeds):
        if not _on_plane(view, x, y, z):
            continue
        u, v = _cross_pix_on_rot(view, w, h, x, y, z)
        r = 4
        dr.ellipse((u - r, v - r, u + r, v + r), fill=SEED_RGB, outline=(40, 140, 100), width=1)
        dr.text((u + 6, v - 8), f"{i+1}", fill=(30, 30, 30))

    if S.bbox_mode:
        for i, (x, y, z) in enumerate(S.bbox_points):
            if not _on_plane(view, x, y, z):
                continue
            u, v = _cross_pix_on_rot(view, w, h, x, y, z)
            r = 6
            dr.rectangle((u - r, v - r, u + r, v + r), outline=BBOX_RGB, width=3)
            text = "P1" if i == 0 else "P2"
            dr.text((u + 10, v - 10), text, fill=BBOX_RGB)

    if view == "axial":
        curr_z = S.slice["axial"]
        for (x1, y1, z1, x2, y2, z2) in S.bboxes:
            if min(z1, z2) - 1 <= curr_z <= max(z1, z2) + 1:
                u1, v1 = _cross_pix_on_rot(view, w, h, x1, y1, curr_z)
                u2, v2 = _cross_pix_on_rot(view, w, h, x2, y2, curr_z)
                # Pillow requires x1 >= x0 and y1 >= y0; corners come in click order.
                dr.rectangle(
                    [min(u1, u2), min(v1, v2), max(u1, u2), max(v1, v2)],
                    outline=(0, 255, 0),
                    width=2,
                )
    return im


# ===================== segmentation ========================================
def _segment_with_model():
    """Run the interactive predictor on S.vol with all seeds and bboxes.

    Returns:
        (mask, "ok") on success; (None, "model-init-failed") or
        (None, "model-error") otherwise.

    NOTE(review): assumes InteractiveMENRTPredictor exposes the
    reset/set_image/add_*_interaction API used below and that target_buffer
    is a numpy array — confirm against the predictor module.
    """
    if PREDICTOR is None and not _init_predictor_once():
        return None, "model-init-failed"
    try:
        img = S.vol[None].astype(np.float32)
        PREDICTOR.reset_interactions()
        PREDICTOR.set_image(img)
        PREDICTOR.set_target_buffer(np.zeros_like(img[0], np.float32))
        PREDICTOR._finish_preprocessing_and_initialize_interactions()
        for (x, y, z) in S.seeds:
            PREDICTOR.add_point_interaction(x, y, z, foreground=True)
        for (x1, y1, z1, x2, y2, z2) in S.bboxes:
            PREDICTOR.add_bbox_interaction(
                min(x1, x2), min(y1, y2), min(z1, z2),
                max(x1, x2), max(y1, y2), max(z1, z2),
            )
        PREDICTOR._predict_without_interaction()
        pred = (PREDICTOR.target_buffer.astype(np.float32) > 0.5).astype(np.uint8)
        if pred.shape != S.shape:
            print(f"[MODEL] resize pred {pred.shape}->{S.shape}")
            pred = _resample_mask_to_vol_shape(pred, S.shape)
        return pred, "ok"
    except Exception as e:
        print(f"[MODEL] inference failed: {e}")
        return None, "model-error"


def _segment_fallback():
    """Heuristic mask from a Gaussian-blurred interaction field (requires scipy).

    Seeds contribute 1.0 and bbox interiors 0.5 to a field that is blurred,
    thresholded at the 70th percentile of its non-zero values, restricted to
    the connected components touched by a seed or bbox centre, then cleaned
    with a closing and an opening.
    """
    if (not S.seeds and not S.bboxes) or not HAVE_SCIPY:
        return None, "no-interactions-or-scipy"
    X, Y, Z = S.shape
    field = np.zeros((X, Y, Z), dtype=np.float32)
    for (x, y, z) in S.seeds:
        field[x, y, z] = 1.0
    for (x1, y1, z1, x2, y2, z2) in S.bboxes:
        field[min(x1, x2):max(x1, x2) + 1,
              min(y1, y2):max(y1, y2) + 1,
              min(z1, z2):max(z1, z2) + 1] = 0.5
    t0 = time.time()
    prob = ndi.gaussian_filter(field, sigma=6.0)
    if prob.max() > 0:
        prob /= prob.max()
    nz = prob[prob > 0]
    thr = np.percentile(nz, 70) if nz.size else 0.5
    mask = prob >= max(thr, 1e-3)
    lab, nlab = ndi.label(mask.astype(np.uint8))
    if nlab > 1:
        keep = np.zeros(nlab + 1, np.uint8)
        for (x, y, z) in S.seeds:
            keep[lab[x, y, z]] = 1
        for (x1, y1, z1, x2, y2, z2) in S.bboxes:
            xm, ym, zm = (x1 + x2) // 2, (y1 + y2) // 2, (z1 + z2) // 2
            keep[lab[xm, ym, zm]] = 1
        # Label 0 is background: an anchor falling outside the mask must not
        # flip the whole background into the prediction.
        keep[0] = 0
        mask = keep[lab] > 0
    mask = ndi.binary_closing(mask, iterations=1)
    mask = ndi.binary_opening(mask, iterations=1)
    print(f"[FB] seg {time.time()-t0:.3f}s | vox={int(mask.sum())}")
    return (mask > 0).astype(np.uint8), "ok"


def do_segment():
    """Segment with the model, falling back to the heuristic; returns "OK" or "Failed"."""
    if not S.loaded:
        print("[SEG] skipped: no volume loaded")
        return "Failed"
    pred, tag = _segment_with_model()
    if pred is None:
        pred, tag2 = _segment_fallback()
        tag = f"{tag}->{tag2}"
    S.pred = pred
    print(f"[SEG] done: {tag}")
    return "OK" if S.pred is not None else "Failed"


def save_prediction():
    """Write S.pred to a temporary NIfTI file.

    Returns:
        (status message, file path) on success, or (message, None).
    """
    if S.pred is None:
        return "No prediction to save", None
    if not HAVE_NIB:
        return "nibabel not installed", None
    try:
        tmp_dir = Path(tempfile.mkdtemp(prefix="menrt_output_"))
        out_path = tmp_dir / f"{S.case_id}_pred.nii.gz"
        affine = S.ref_affine if S.ref_affine is not None else np.eye(4)
        header = S.ref_header.copy() if S.ref_header is not None else None
        nii_img = nib.Nifti1Image(S.pred.astype(np.uint8), affine, header=header)
        nib.save(nii_img, str(out_path))
        print(f"[SAVE] {out_path}")
        return "Saved successfully!", str(out_path)
    except Exception as e:
        print(f"[SAVE] error: {e}")
        return f"Save failed: {e}", None


# ---------- helpers ----------
def _seed_rows():
    """Rows for the interactions table: seeds first, then bboxes."""
    rows = []
    for i, (x, y, z) in enumerate(S.seeds):
        v = S.seed_views[i] if i < len(S.seed_views) else ""
        rows.append([i + 1, "point", v, x, y, z])
    for i, (x1, y1, z1, x2, y2, z2) in enumerate(S.bboxes):
        rows.append([len(S.seeds) + i + 1, "bbox", "3D", f"{x1}-{x2}", f"{y1}-{y2}", f"{z1}-{z2}"])
    return rows


def _seed_dropdown_options():
    """Dropdown labels in the same order as _seed_rows (seeds, then bboxes)."""
    opts = []
    for i, (x, y, z) in enumerate(S.seeds):
        v = S.seed_views[i] if i < len(S.seed_views) else "axial"
        opts.append(f"{v} → point {i+1} → ({x},{y},{z})")
    for i, (x1, y1, z1, x2, y2, z2) in enumerate(S.bboxes):
        opts.append(f"3D → bbox {i+1} → ({x1},{y1},{z1})-({x2},{y2},{z2})")
    return opts


def _dropdown_label_to_idx(val):
    """Map a dropdown label back to its row index; None if it cannot be parsed."""
    try:
        if "point" in val:
            return int(val.split("point")[1].split("→")[0].strip()) - 1
        if "bbox" in val:
            return len(S.seeds) + int(val.split("bbox")[1].split("→")[0].strip()) - 1
    except (IndexError, ValueError):
        pass
    return None


def _debug_widgets(current_idx=None):
    """Updates for the interactions table and the "Go to point" dropdown.

    Args:
        current_idx: row to select; when None the previously selected row
            (S.sel_idx) is kept, falling back to the last row. Keeping the
            selection stable avoids a spurious dropdown ``.change`` that would
            jump the cross away from where the user just navigated.
    """
    rows = _seed_rows()
    opts = _seed_dropdown_options()
    if current_idx is not None:
        S.sel_idx = current_idx
    idx = S.sel_idx
    if idx is None or not (0 <= idx < len(opts)):
        idx = len(opts) - 1
    val = opts[idx] if opts else None
    return gr.update(value=rows), gr.update(choices=opts, value=val)


def _figs_and_imgs():
    """All six image outputs: three orthogonal views, GT, prediction, interactions."""
    top_ax = render_top("axial")
    top_sg = render_top("sagittal")
    top_co = render_top("coronal")
    ax_gt = _axial_overlay_fill(S.gt, GT_RGBA_FILL)
    ax_pr = _axial_overlay_fill(S.pred, PR_RGBA_FILL)
    inter2d = _interaction_2d()
    return top_ax, top_sg, top_co, ax_gt, ax_pr, inter2d


def _bar_ranges_and_values():
    """Slider updates (Z, X, Y); no-ops while nothing is loaded."""
    if not S.loaded or S.shape is None:
        return gr.update(), gr.update(), gr.update()
    X, Y, Z = S.shape
    return (
        gr.update(minimum=0, maximum=Z - 1, value=S.slice["axial"], visible=True),
        gr.update(minimum=0, maximum=X - 1, value=S.slice["sagittal"], visible=True),
        gr.update(minimum=0, maximum=Y - 1, value=S.slice["coronal"], visible=True),
    )


def _mode_value():
    """Radio value matching the current interaction mode."""
    return "BBox" if S.bbox_mode else "Point"


def _respond(status, *, mode=None, file_upd=None, current_idx=None):
    """Build the 14-value tuple matching the ``outputs`` component list.

    Order: 6 images, 3 sliders, table, dropdown, mode radio, file, status.
    """
    return (
        *_figs_and_imgs(),
        *_bar_ranges_and_values(),
        *_debug_widgets(current_idx=current_idx),
        gr.update(value=mode if mode is not None else _mode_value()),
        file_upd,
        status,
    )


def _parse_evt_xy(evt):
    """Enhanced event parsing: extract integer (x, y) pixel coordinates from a select event."""
    print(f"[DEBUG_EVT] Event received: {evt}")
    print(f"[DEBUG_EVT] Event type: {type(evt)}")
    print(f"[DEBUG_EVT] Event dir: {dir(evt)}")
    if evt is None:
        print("[DEBUG_EVT] Event is None!")
        return None
    try:
        # Method 1: evt.index
        if hasattr(evt, "index") and evt.index is not None:
            ix = evt.index
            print(f"[DEBUG_EVT] Found index: {ix}")
            if isinstance(ix, (list, tuple)) and len(ix) >= 2:
                result = int(ix[0]), int(ix[1])
                print(f"[DEBUG_EVT] Parsed from index: {result}")
                return result
        # Method 2: evt.x, evt.y
        if hasattr(evt, "x") and hasattr(evt, "y"):
            x_val = getattr(evt, "x")
            y_val = getattr(evt, "y")
            print(f"[DEBUG_EVT] Found x={x_val}, y={y_val}")
            if x_val is not None and y_val is not None:
                result = int(x_val), int(y_val)
                print(f"[DEBUG_EVT] Parsed from x,y: {result}")
                return result
    except Exception as e:
        print(f"[DEBUG_EVT] Parse error: {e}")
    print("[DEBUG_EVT] Failed to parse coordinates!")
    return None


def _disp_to_vol(view, u, v):
    """Inverse of _cross_pix_on_rot: display pixel (u, v) in *view* -> clamped voxel (x, y, z)."""
    X, Y, Z = S.shape
    w, h = S.disp_wh[view]
    if w <= 0 or h <= 0:
        w = h = S.render_px
    if view == "axial":
        x = int(round(u * (X - 1) / max(w - 1, 1)))
        y = int(round((Y - 1) - v * (Y - 1) / max(h - 1, 1)))
        z = S.slice["axial"]
    elif view == "sagittal":
        z = int(round(u * (Z - 1) / max(w - 1, 1)))
        y = int(round((Y - 1) - v * (Y - 1) / max(h - 1, 1)))
        x = S.slice["sagittal"]
    else:
        z = int(round(u * (Z - 1) / max(w - 1, 1)))
        x = int(round((X - 1) - v * (X - 1) / max(h - 1, 1)))
        y = S.slice["coronal"]
    x = max(0, min(X - 1, x))
    y = max(0, min(Y - 1, y))
    z = max(0, min(Z - 1, z))
    return x, y, z


def _thumb_from_case(case_id, px=96):
    """Gallery thumbnail: middle axial slice (PNG stack, else NIfTI), with a 1px border."""
    case_dir = DATA_ROOT / case_id
    pngs = _find_axial_slices(case_dir)
    if pngs:
        with Image.open(pngs[len(pngs) // 2]) as src:
            gray = src.convert("L")
        arr = _norm01(np.array(gray).astype(np.float32))
        im = Image.fromarray((arr * 255).astype(np.uint8)).resize((px, px), Image.BILINEAR)
    else:
        # No slice images exist, so only the NIfTI loader can succeed here.
        vol, _, _ = _load_nifti(case_dir, case_id, ds=2)
        if vol is None:
            im = Image.new("L", (px, px), 30)
        else:
            mid = vol[:, :, vol.shape[2] // 2]
            im = Image.fromarray((mid * 255).astype(np.uint8)).resize((px, px), Image.BILINEAR)
    im = _rot90_if_needed(im)
    return ImageOps.expand(im, border=1, fill=200)


# ---------- callbacks ----------
def on_load(case_id):
    """Load *case_id*, kick off model preload, and refresh every output."""
    load_case(case_id)
    preload_model_in_background()
    return _respond("Loaded", mode="Point")


def on_gallery_select(evt: gr.SelectData):
    """Load the case whose thumbnail was clicked and sync the case dropdown."""
    idx = 0
    ix = getattr(evt, "index", None)
    if ix is not None:
        idx = int(ix[0] if isinstance(ix, (list, tuple)) else ix)
    idx = max(0, min(len(EXAMPLES) - 1, idx))
    cid = EXAMPLES[idx]
    out = on_load(cid)
    return (*out, gr.update(value=cid))


def _click_common(view, evt: gr.SelectData):
    """Handle a click on an orthogonal view: add a seed or a bbox corner and move the cross."""
    print(f"[CLICK_HANDLER] Called for view={view}")
    print(f"[CLICK_HANDLER] S.loaded={S.loaded}")
    print(f"[CLICK_HANDLER] S.bbox_mode={S.bbox_mode}")
    if not S.loaded:
        print("[CLICK_HANDLER] Not loaded yet!")
        return _respond("Not loaded")

    xy = _parse_evt_xy(evt)
    if xy is None:
        print("[CLICK_HANDLER] Failed to parse event!")
        return _respond("Parse failed")

    u, v = xy
    x, y, z = _disp_to_vol(view, u, v)
    print(f"[CLICK_HANDLER] Clicked at display ({u},{v}) -> volume ({x},{y},{z})")

    new_idx = None  # keep the current dropdown selection unless a row was added
    if S.bbox_mode:
        S.bbox_points.append((x, y, z))
        if len(S.bbox_points) == 2:
            p1, p2 = S.bbox_points
            S.bboxes.append((*p1, *p2))
            S.bbox_points = []
            print(f"[UI] bbox created: {S.bboxes[-1]}")
            status = "BBox created!"
            new_idx = len(S.seeds) + len(S.bboxes) - 1
        else:
            print(f"[UI] bbox point {len(S.bbox_points)}/2")
            status = f"BBox point {len(S.bbox_points)}/2"
    else:
        S.seeds.append((x, y, z))
        S.seed_views.append(view)
        print(f"[UI] seed+ {(x,y,z)} total={len(S.seeds)}")
        status = f"Added point {len(S.seeds)}"
        # Seeds are listed before bboxes, so the new seed is row len(seeds)-1.
        new_idx = len(S.seeds) - 1

    S.cross = {"x": x, "y": y, "z": z}
    S.slice = {"sagittal": x, "coronal": y, "axial": z}
    S.active_view = view
    return _respond(status, current_idx=new_idx)


def on_axial_select(evt: gr.SelectData):
    print("[EVENT] on_axial_select triggered!")
    return _click_common("axial", evt)


def on_sagittal_select(evt: gr.SelectData):
    print("[EVENT] on_sagittal_select triggered!")
    return _click_common("sagittal", evt)


def on_coronal_select(evt: gr.SelectData):
    print("[EVENT] on_coronal_select triggered!")
    return _click_common("coronal", evt)


def on_seg_button():
    """Run segmentation and refresh outputs."""
    msg = do_segment()
    return _respond(f"Segment: {msg}")


def on_clear():
    """Drop all interactions and the prediction; return to Point mode."""
    S.seeds = []
    S.seed_views = []
    S.pred = None
    S.bboxes = []
    S.bbox_points = []
    S.bbox_mode = False  # keep state consistent with the "Point" radio value we return
    S.sel_idx = None
    print("[UI] cleared all")
    return _respond("Cleared", mode="Point")


def on_undo():
    """Remove the pending bbox corner, else the last bbox, else the last seed."""
    if S.bbox_mode and S.bbox_points:
        S.bbox_points.pop()
    elif S.bboxes:
        S.bboxes.pop()
    elif S.seeds:
        S.seeds.pop()
        if S.seed_views:
            S.seed_views.pop()
    S.sel_idx = None  # row indices shifted; select the last remaining row
    print("[UI] undo")
    return _respond("Undo")


def on_mode_toggle(mode):
    """Switch between "Point" and "BBox" interaction modes."""
    S.bbox_mode = mode == "BBox"
    S.bbox_points = []
    print(f"[UI] mode -> {mode}")
    return _respond(f"Mode: {mode}", mode=mode)


def on_save():
    """Save the prediction and show (or hide) the download component."""
    msg, path = save_prediction()
    if path:
        return _respond(msg, file_upd=gr.update(visible=True, value=path))
    return _respond(msg, file_upd=gr.update(visible=False, value=None))


def on_z_release(z_idx):
    S.slice["axial"] = int(z_idx)
    S.cross["z"] = S.slice["axial"]
    S.active_view = "axial"
    return _respond(f"Z={z_idx}")


def on_x_release(x_idx):
    S.slice["sagittal"] = int(x_idx)
    S.cross["x"] = S.slice["sagittal"]
    S.active_view = "sagittal"
    return _respond(f"X={x_idx}")


def on_y_release(y_idx):
    S.slice["coronal"] = int(y_idx)
    S.cross["y"] = S.slice["coronal"]
    S.active_view = "coronal"
    return _respond(f"Y={y_idx}")


def _jump_to_idx(idx: int):
    """Move the cross to interaction row *idx* (a seed, or a bbox's centre)."""
    total = len(S.seeds) + len(S.bboxes)
    if not (0 <= idx < total):
        return _respond("")
    if idx < len(S.seeds):
        x, y, z = S.seeds[idx]
        view = S.seed_views[idx] if idx < len(S.seed_views) else "axial"
    else:
        x1, y1, z1, x2, y2, z2 = S.bboxes[idx - len(S.seeds)]
        x, y, z = (x1 + x2) // 2, (y1 + y2) // 2, (z1 + z2) // 2
        view = "axial"  # bboxes are only drawn on the axial interaction view
    S.cross = {"x": x, "y": y, "z": z}
    S.slice = {"sagittal": x, "coronal": y, "axial": z}
    S.active_view = view
    return _respond(f"Jump to {idx+1}", current_idx=idx)


def on_seed_df_select(evt: gr.SelectData):
    """Jump to the clicked table row.

    The gr.SelectData annotation is required: Gradio only injects event data
    into a parameter type-hinted with an EventData subclass.
    """
    try:
        row = int(evt.index[0]) if hasattr(evt, "index") else 0
    except Exception:
        row = 0
    return _jump_to_idx(row)


def on_seed_dd_change(val):
    """Jump to the interaction selected in the "Go to point" dropdown."""
    if not val:
        return _respond("")
    idx = _dropdown_label_to_idx(val)
    if idx is None:
        idx = len(S.seeds) + len(S.bboxes) - 1
    return _jump_to_idx(idx)


# ---------- UI ----------
css = """
:root { --bg:#f7fbff; --card:#ffffff; --accent:#1e90ff; --shadow:0 8px 26px rgba(45,156,219,.12); }
.gradio-container{font-family:Inter,ui-sans-serif,system-ui;background:var(--bg)}
.round{background:var(--card);border-radius:16px;padding:10px;box-shadow:var(--shadow);border:1px solid #e8f0fb}
.section{font-weight:700;color:#114b8b;margin:2px 0 8px}
.tiny .gr-slider input[type="range"]{height:6px}
.tiny .gr-form{gap:6px}
.smallnote{color:#3e6285;font-size:12px;margin-top:6px}
"""

THUMBS = [_thumb_from_case(cid, px=96) for cid in EXAMPLES]

with gr.Blocks(css=css, title="Interactive-MEN-RT", theme=gr.themes.Soft(), analytics_enabled=False) as demo:
    gr.Markdown(
        "## Interactive-MEN-RT Segmentation\n\n"
        "Domain-Specialized Interactive Segmentation for Meningioma Radiotherapy Planning  \n"
        "**Research only — Not for clinical use**"
    )

    with gr.Row():
        with gr.Column(scale=1, min_width=280, elem_classes=["round"]):
            gr.Markdown("<div class='section'>Demo Case</div>")
            gallery = gr.Gallery(value=THUMBS, columns=1, height=110, allow_preview=False,
                                 preview=False, show_label=False)
            case_dd = gr.Dropdown(choices=EXAMPLES, value=EXAMPLES[0], label="Case")
            mode_radio = gr.Radio(["Point", "BBox"], value="Point", label="Interaction Mode")
            with gr.Row():
                seg_btn = gr.Button("Segment", variant="primary")
                save_btn = gr.Button("Save NIfTI", variant="secondary")
            with gr.Row():
                undo_btn = gr.Button("Undo")
                clr_btn = gr.Button("Clear")
            gr.Markdown("<div class='section'>Interactions</div>")
            seeds_df = gr.Dataframe(
                headers=["#", "type", "view", "x", "y", "z"],
                value=[],
                datatype=["number", "str", "str", "str", "str", "str"],
                interactive=False,
                wrap=True,
                row_count=(0, "dynamic"),
            )
            seed_dd = gr.Dropdown(choices=[], value=None, label="Go to point")
            pred_file = gr.File(label="Download Prediction", visible=False)
            status_text = gr.Textbox(label="Status", value="", interactive=False, lines=1)

        with gr.Column(scale=5):
            with gr.Row():
                with gr.Column(elem_classes=["round"]):
                    axial = gr.Image(type="pil", interactive=True, height=RENDER_PX_DEFAULT + 8, label="Axial (Z)")
                    z_bar = gr.Slider(0, 1, value=0, step=1, label="Z", elem_classes=["tiny"])
                with gr.Column(elem_classes=["round"]):
                    sagittal = gr.Image(type="pil", interactive=True, height=RENDER_PX_DEFAULT + 8, label="Sagittal (X)")
                    x_bar = gr.Slider(0, 1, value=0, step=1, label="X", elem_classes=["tiny"])
                with gr.Column(elem_classes=["round"]):
                    coronal = gr.Image(type="pil", interactive=True, height=RENDER_PX_DEFAULT + 8, label="Coronal (Y)")
                    y_bar = gr.Slider(0, 1, value=0, step=1, label="Y", elem_classes=["tiny"])
            with gr.Row():
                out_ax_gt = gr.Image(type="pil", interactive=False, height=RENDER_PX_DEFAULT + 8, label="Ground Truth")
                out_ax_pr = gr.Image(type="pil", interactive=False, height=RENDER_PX_DEFAULT + 8, label="Prediction")
                inter2d = gr.Image(type="pil", interactive=False, height=RENDER_PX_DEFAULT + 8, label="Interactions")

    # Order must match the tuple built by _respond().
    outputs = [axial, sagittal, coronal, out_ax_gt, out_ax_pr, inter2d,
               z_bar, x_bar, y_bar, seeds_df, seed_dd, mode_radio, pred_file, status_text]

    demo.load(lambda: on_load(EXAMPLES[0]), [], outputs)
    case_dd.change(lambda cid: on_load(cid), [case_dd], outputs)
    gallery.select(on_gallery_select, [], outputs + [case_dd])

    # Click events - enhanced handlers
    axial.select(on_axial_select, [], outputs)
    sagittal.select(on_sagittal_select, [], outputs)
    coronal.select(on_coronal_select, [], outputs)

    z_bar.release(on_z_release, [z_bar], outputs)
    x_bar.release(on_x_release, [x_bar], outputs)
    y_bar.release(on_y_release, [y_bar], outputs)
    mode_radio.change(on_mode_toggle, [mode_radio], outputs)
    seg_btn.click(on_seg_button, [], outputs)
    save_btn.click(on_save, [], outputs)
    clr_btn.click(on_clear, [], outputs)
    undo_btn.click(on_undo, [], outputs)
    seeds_df.select(on_seed_df_select, [], outputs)
    seed_dd.change(on_seed_dd_change, [seed_dd], outputs)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True, share=True,
                allowed_paths=[str(DATA_ROOT)])