Spaces:
Sleeping
Sleeping
| import os, warnings, time, glob, tempfile, threading | |
| warnings.filterwarnings("ignore") | |
| os.environ["GRADIO_ANALYTICS_ENABLED"] = "False" | |
| from pathlib import Path | |
| import numpy as np | |
| from PIL import Image, ImageDraw, ImageOps | |
| import gradio as gr | |
# Startup diagnostics: print environment details and check that the bundled
# sample data is present, so deployment problems are visible in the logs.
print("\n" + "="*60)
print("π INTERACTIVE-MEN-RT DEMO DEBUG INFO")
print("="*60)
print(f"π¦ Gradio version: {gr.__version__}")
print(f"π Current directory: {os.getcwd()}")
print(f"π Directory contents: {os.listdir('.')}")
# Root folder that holds one sub-directory per demo case.
DATA_ROOT = Path("./samples")
print(f"\nπ DATA_ROOT: {DATA_ROOT}")
print(f"π DATA_ROOT exists: {DATA_ROOT.exists()}")
if DATA_ROOT.exists():
    print(f"π DATA_ROOT contents: {list(DATA_ROOT.iterdir())}")
    # Cases whose files are listed at startup.
    EXAMPLES_CHECK = ["BraTS-MEN-RT-0071-1"]
    for case in EXAMPLES_CHECK:
        case_dir = DATA_ROOT / case
        print(f"\n π¦ Case: {case}")
        print(f" Exists: {case_dir.exists()}")
        if case_dir.exists():
            files = list(case_dir.iterdir())
            print(f" Files: {[f.name for f in files]}")
            # The viewer loads the "<case>_t1c.nii.gz" volume, so report whether it is there.
            has_t1c = any(f.name.endswith('_t1c.nii.gz') for f in files)
            print(f" β Has T1c: {has_t1c}")
else:
    print("β samples folder NOT FOUND!")
print("\n" + "="*60 + "\n")
| # ---------- optional deps ---------- | |
# nibabel is optional: without it NIfTI loading, ground-truth loading and
# saving predictions are skipped (see the HAVE_NIB checks below).
try:
    import nibabel as nib
    HAVE_NIB = True
    print("β nibabel available")
except Exception as e:
    HAVE_NIB = False
    print(f"β nibabel not available: {e}")
# scipy is optional: it is only needed by the non-model fallback segmentation.
try:
    from scipy import ndimage as ndi
    HAVE_SCIPY = True
    print("β scipy available")
except Exception as e:
    HAVE_SCIPY = False
    print(f"β scipy not available: {e}")
| # ---------- model predictor ---------- | |
# Lazily created predictor instance (see _init_predictor_once).
PREDICTOR = None
# Preferred device; the predictor falls back to CPU when CUDA is unavailable.
DEVICE = "cuda:0"
from huggingface_hub import snapshot_download

# Download only the trained-model folder from the Hub. Any failure leaves
# CKPT as None so the app can still run with the fallback segmentation.
try:
    _repo_root = snapshot_download(
        "hanjang/Interactive-MEN-RT",
        allow_patterns=["nnUNetInteractionTrainer__nnUNetPlans__3d_fullres_scratch/**"],
    )
    CKPT = os.path.join(_repo_root, "nnUNetInteractionTrainer__nnUNetPlans__3d_fullres_scratch")
    print(f"[INFO] Checkpoint path: {CKPT}")
    if os.path.exists(CKPT):
        contents = os.listdir(CKPT)
        print(f"[INFO] Checkpoint contents: {contents}")
        fold_0 = os.path.join(CKPT, "fold_0")
        if os.path.exists(fold_0):
            print(f"[INFO] fold_0 contents: {os.listdir(fold_0)}")
    else:
        print("[ERROR] Checkpoint path does not exist!")
        CKPT = None
except Exception as e:
    print(f"[ERROR] Failed to download checkpoint: {e}")
    CKPT = None

# Give each nnUNet path variable a scratch directory, but only when it is not
# already set: setdefault() would evaluate mkdtemp() eagerly and leave an
# unused temporary directory behind for every variable that already exists.
for env in ("nnUNet_raw", "nnUNet_preprocessed", "nnUNet_results"):
    if env not in os.environ:
        os.environ[env] = tempfile.mkdtemp(prefix=f"{env}_")
# Serialises predictor construction: initialisation is triggered both from a
# background thread (on case load) and from the segment handler.
_PREDICTOR_INIT_LOCK = threading.Lock()


def _init_predictor_once():
    """Create the global predictor if it does not exist yet.

    Returns:
        True when a predictor is available, False when no checkpoint was
        downloaded or when construction failed.
    """
    global PREDICTOR
    if PREDICTOR is not None:
        return True
    if CKPT is None:
        print("[WARN] No checkpoint available, will use fallback only")
        return False
    with _PREDICTOR_INIT_LOCK:
        # Another thread may have finished initialising while we waited.
        if PREDICTOR is not None:
            return True
        try:
            import torch
            from Interactive_MEN_RT_predictor import InteractiveMENRTPredictor

            dev = torch.device(DEVICE if torch.cuda.is_available() else "cpu")
            pred = InteractiveMENRTPredictor(
                device=dev, use_torch_compile=False, do_autozoom=False, verbose=False
            )
            pred.initialize_from_trained_model_folder(
                model_training_output_dir=CKPT, use_fold=0, checkpoint_name="checkpoint_best.pth"
            )
            # Best-effort warm-up on a tiny volume; a failure here must not
            # prevent the predictor from being used, but it is logged.
            try:
                if torch.cuda.is_available():
                    x = np.zeros((1, 8, 8, 8), np.float32)
                    pred.reset_interactions()
                    pred.set_image(x)
                    pred.set_target_buffer(np.zeros_like(x[0], np.float32))
                    pred._finish_preprocessing_and_initialize_interactions()
                    torch.cuda.synchronize()
            except Exception as warm_err:
                print(f"[MODEL] warm-up skipped: {warm_err}")
            # Publish only after warm-up so no other thread uses the predictor
            # while the warm-up is still mutating it.
            PREDICTOR = pred
            print("[MODEL] ready")
            return True
        except Exception as e:
            print(f"[MODEL] init failed: {e}")
            return False
def preload_model_in_background():
    """Start predictor initialisation on a daemon thread and return immediately."""
    worker = threading.Thread(target=_init_predictor_once, daemon=True)
    worker.start()
| # ---------- config ---------- | |
# Case identifiers offered in the demo; each one names a sub-directory of DATA_ROOT.
EXAMPLES = ["BraTS-MEN-RT-0071-1"]
# Side length, in pixels, of every rendered (square) view.
RENDER_PX_DEFAULT = 384
# When True, slices are rotated 90 degrees counter-clockwise before display.
ROT_CCW = True
# colors
ACCENT_HEX = "#1e90ff"  # outline of the crosshair centre marker
CROSS_RGB = (30, 144, 255)  # crosshair lines
GT_RGBA_FILL = (255, 215, 0, 128)  # semi-transparent ground-truth overlay
PR_RGBA_FILL = (255, 60, 60, 128)  # semi-transparent prediction overlay
SEED_RGB = (89, 224, 154)  # point-interaction markers
BBOX_RGB = (255, 140, 0)  # pending bounding-box corner markers
| # ---------- state ---------- | |
class State:
    """Mutable viewer state shared by all callbacks through the module-level ``S``."""

    def __init__(self):
        # Loaded data: image volume, its shape, ground-truth mask, prediction mask.
        self.vol = None
        self.shape = None
        self.gt = None
        self.pred = None
        self.case_id = None
        self.loaded = False

        # Crosshair position (voxel indices) and the slice index shown per view.
        self.cross = {"x": 0, "y": 0, "z": 0}
        self.slice = {"axial": 0, "sagittal": 0, "coronal": 0}

        # Point interactions and, in parallel, the view each one was placed in.
        self.seeds = []
        self.seed_views = []

        # Rendering: square side length and the last displayed size per view.
        self.render_px = RENDER_PX_DEFAULT
        side = RENDER_PX_DEFAULT
        self.disp_wh = {
            name: (side, side) for name in ("axial", "sagittal", "coronal")
        }
        self.active_view = "axial"

        # Bounding-box interactions: pending corner clicks and finished boxes.
        self.bbox_mode = False
        self.bbox_points = []
        self.bboxes = []

        # Geometry of the reference image, reused when saving a prediction.
        self.ref_affine = None
        self.ref_header = None


S = State()
| # ---------- utils ---------- | |
| def _norm01(a): | |
| a=a.astype(np.float32) | |
| p2,p98=np.percentile(a,2),np.percentile(a,98) | |
| if p98<=p2: p2,p98=float(a.min()),float(a.max()) or 1.0 | |
| return np.clip((a-p2)/max(p98-p2,1e-6),0,1) | |
def _resize_slice_nearest(arr2d, w, h):
    """Resize a 2D array to ``w`` x ``h`` pixels with nearest-neighbour sampling."""
    resized = Image.fromarray(arr2d).resize((w, h), Image.NEAREST)
    return np.array(resized)
def _rot90_if_needed(img_or_np):
    """Rotate a PIL image or array 90 degrees counter-clockwise when ROT_CCW is set."""
    if ROT_CCW:
        if isinstance(img_or_np, Image.Image):
            return img_or_np.rotate(90, expand=True)
        return np.rot90(img_or_np, k=1)
    return img_or_np
| # ---------- IO ---------- | |
def _load_png_stack(case_dir):
    """Load a pre-rendered axial slice stack from ``<case_dir>/png_axial``.

    ``*.png`` files are preferred; ``*.jpg`` is tried when no PNG exists.
    Slices are stacked along the third axis in sorted filename order and the
    result is normalised with ``_norm01``.

    Returns:
        ``(volume, None, None)`` on success (no affine or header is available
        for image stacks), or ``(None, None, None)`` when no slice files exist.
    """
    slice_dir = case_dir / "png_axial"
    pngs = sorted(glob.glob(str(slice_dir / "*.png")))
    if not pngs:
        pngs = sorted(glob.glob(str(slice_dir / "*.jpg")))
    if not pngs:
        return None, None, None
    t0 = time.time()
    arr = []
    for p in pngs:
        # Close each file as soon as its pixels are copied out.
        with Image.open(p) as im:
            arr.append(np.array(im.convert("L")))
    vol = np.stack(arr, axis=2).astype(np.float32)
    vol = _norm01(vol)
    print(f"[PIL] {len(pngs)} slices -> {vol.shape} in {time.time()-t0:.2f}s")
    return vol, None, None
def _load_nifti(case_dir, case_id, ds=1):
    """Load ``<case_id>_t1c.nii.gz`` from ``case_dir``, keeping every ``ds``-th voxel.

    Returns:
        ``(volume, affine, header)`` with the volume normalised by ``_norm01``,
        or ``(None, None, None)`` when nibabel is unavailable or the file is missing.
    """
    nothing = (None, None, None)
    if not HAVE_NIB:
        return nothing
    nii_path = case_dir / f"{case_id}_t1c.nii.gz"
    if not nii_path.exists():
        return nothing
    started = time.time()
    image = nib.load(str(nii_path))
    step = slice(None, None, ds)
    data = np.asanyarray(image.dataobj[step, step, step], dtype=np.float32)
    data = _norm01(data)
    print(f"[NIfTI] {case_id} -> {data.shape} in {time.time()-started:.2f}s")
    return data, image.affine, image.header
def _resample_mask_to_vol_shape(mask_xyz, vol_shape_xyz):
    """Resample a 3D mask to ``vol_shape_xyz`` and return it as a 0/1 uint8 array.

    Each output slice along the third axis takes the proportionally nearest
    source slice, which is binarised and resized in-plane with
    nearest-neighbour sampling.
    """
    _, _, src_depth = mask_xyz.shape
    nx, ny, nz = vol_shape_xyz
    resampled = np.zeros((nx, ny, nz), dtype=np.uint8)
    depth_span = max(nz - 1, 1)
    for dst_k in range(nz):
        src_k = int(round(dst_k * (src_depth - 1) / depth_span))
        plane = (mask_xyz[:, :, src_k] > 0).astype(np.uint8) * 255
        # PIL sizes are (width, height), i.e. (columns, rows) of the array.
        plane_img = Image.fromarray(plane).resize((ny, nx), Image.NEAREST)
        resampled[:, :, dst_k] = (np.array(plane_img) > 0).astype(np.uint8)
    return resampled
def _load_gt(case_dir, case_id, vol_shape):
    """Find and load a ground-truth mask for a case, resampled to ``vol_shape``.

    Several filename conventions are tried in order; the first file that
    loads successfully wins. Returns a 0/1 uint8 array, or None when nibabel
    is unavailable or no candidate could be loaded.
    """
    tags = ("gtv", "seg", "gt")
    candidates = (
        [f"{case_id}_{tag}.nii.gz" for tag in tags]
        + [f"{tag}.nii.gz" for tag in tags]
        + [f"{case_id}_{tag}.nii" for tag in tags]
    )
    if not HAVE_NIB:
        print("[GT] not found.")
        return None
    for filename in candidates:
        path = case_dir / filename
        if not path.exists():
            continue
        try:
            mask = np.asanyarray(nib.load(str(path)).dataobj, dtype=np.uint8)
            print(f"[GT] found {path.name} raw={mask.shape}")
            mask = _resample_mask_to_vol_shape(mask, vol_shape)
            print(f"[GT] resized -> {mask.shape}")
            return (mask > 0).astype(np.uint8)
        except Exception as e:
            # A broken candidate should not stop the search.
            print(f"[GT] load error {path.name}: {e}")
    print("[GT] not found.")
    return None
def load_case(case_id):
    """Load a case into the global state ``S`` and reset all interactions.

    The volume comes from, in order of preference: the PNG/JPG axial stack,
    the NIfTI T1c file, or a synthetic Gaussian blob when neither exists.
    The crosshair and all three slice indices are centred in the volume and
    the ground-truth mask is loaded when available.
    """
    case_dir = DATA_ROOT / case_id
    vol, affine, header = _load_png_stack(case_dir)
    if vol is None:
        vol, affine, header = _load_nifti(case_dir, case_id, ds=1)
    if vol is None:
        Z = 96
        # Each axis needs its own dimension so the three terms broadcast to
        # (N, N, Z); the previous (N,1)/(1,N) shapes could not be combined
        # with the (1,1,Z) term and raised a ValueError.
        x = np.linspace(-1, 1, RENDER_PX_DEFAULT)[:, None, None]
        y = np.linspace(-1, 1, RENDER_PX_DEFAULT)[None, :, None]
        z = np.linspace(-1, 1, Z)[None, None, :]
        vol = np.exp(-(x**2 + y**2 + z**2) * 6).astype(np.float32)
        affine = np.eye(4)
        header = None
        print("[VOL] dummy")

    S.vol = vol
    S.shape = vol.shape
    S.pred = None
    S.seeds = []
    S.seed_views = []
    S.bbox_mode = False
    S.bbox_points = []
    S.bboxes = []
    S.case_id = case_id
    S.ref_affine = affine
    S.ref_header = header

    X, Y, Z = S.shape
    S.cross = {"x": X // 2, "y": Y // 2, "z": Z // 2}
    S.slice = {"sagittal": S.cross["x"], "coronal": S.cross["y"], "axial": S.cross["z"]}
    S.gt = _load_gt(case_dir, case_id, S.shape)
    S.render_px = RENDER_PX_DEFAULT
    S.active_view = "axial"
    S.loaded = True
    print(f"[LOAD] {case_id} | shape={S.shape}")
| # ---------- 2D rendering ---------- | |
def _slice2d(view):
    """Return the current 2D slice of ``S.vol`` for ``view``, oriented for display."""
    if view == "axial":
        return _rot90_if_needed(S.vol[:, :, S.slice["axial"]])
    if view == "sagittal":
        return _rot90_if_needed(S.vol[S.slice["sagittal"], :, :].T)
    # Any other value is treated as the coronal view.
    return _rot90_if_needed(S.vol[:, S.slice["coronal"], :].T)
def _cross_pix_on_rot(view, w, h, x=None, y=None, z=None):
    """Map a voxel position to pixel coordinates ``(u, v)`` in a displayed view.

    Coordinates left as None default to the current crosshair. ``w`` and
    ``h`` are the pixel width and height of the displayed image.
    """
    X, Y, Z = S.shape
    x = S.cross["x"] if x is None else x
    y = S.cross["y"] if y is None else y
    z = S.cross["z"] if z is None else z

    def _to_px(coord, pixels, extent):
        # Scale an index in [0, extent-1] onto [0, pixels-1].
        return int(round(coord * (pixels - 1) / max(extent - 1, 1)))

    if view == "axial":
        return _to_px(x, w, X), _to_px(Y - 1 - y, h, Y)
    if view == "sagittal":
        return _to_px(z, w, Z), _to_px(Y - 1 - y, h, Y)
    return _to_px(z, w, Z), _to_px(X - 1 - x, h, X)
def _draw_cross(img_draw, view, w, h):
    """Draw the crosshair (two lines and a centre marker) for ``view`` on ``img_draw``."""
    cx, cy = _cross_pix_on_rot(view, w, h)
    vertical = [(cx, 0), (cx, h)]
    horizontal = [(0, cy), (w, cy)]
    img_draw.line(vertical, fill=CROSS_RGB, width=1)
    img_draw.line(horizontal, fill=CROSS_RGB, width=1)
    marker_box = (cx - 5, cy - 5, cx + 5, cy + 5)
    img_draw.ellipse(marker_box, fill=(255, 255, 255), outline=ACCENT_HEX, width=2)
def render_top(view):
    """Render the current slice of ``view`` with the crosshair; None if nothing is loaded.

    Also records the rendered image size in ``S.disp_wh[view]``.
    """
    if not S.loaded:
        return None
    side = S.render_px
    gray = (_slice2d(view) * 255).astype(np.uint8)
    picture = Image.fromarray(gray).resize((side, side), Image.BILINEAR).convert("RGB")
    _draw_cross(ImageDraw.Draw(picture), view, side, side)
    S.disp_wh[view] = picture.size
    return picture
def _axial_mask2d_rot(mask3d):
    """Return the current axial slice of ``mask3d`` as uint8, display-oriented; None passes through."""
    if mask3d is None:
        return None
    axial_plane = mask3d[:, :, S.slice["axial"]].astype(np.uint8)
    return _rot90_if_needed(axial_plane)
def _axial_overlay_fill(mask3d, rgba):
    """Render the current axial slice with ``mask3d`` filled in colour ``rgba``.

    Returns:
        An RGB image, or None when no case is loaded. The None case matches
        ``render_top`` and ``_interaction_2d`` and avoids subscripting the
        unset volume. When ``mask3d`` is None the plain slice is returned.
    """
    if not S.loaded or S.vol is None:
        return None
    side = S.render_px
    sl = _rot90_if_needed(S.vol[:, :, S.slice["axial"]])
    base = Image.fromarray((sl * 255).astype(np.uint8)).resize((side, side), Image.BILINEAR).convert("RGBA")
    m2d = _axial_mask2d_rot(mask3d)
    if m2d is None:
        return base.convert("RGB")
    m2d = _resize_slice_nearest((m2d > 0).astype(np.uint8), side, side)
    # Transparent layer with the fill colour only where the mask is set.
    over = np.zeros((side, side, 4), dtype=np.uint8)
    over[m2d > 0] = rgba
    return Image.alpha_composite(base, Image.fromarray(over, "RGBA")).convert("RGB")
def _interaction_2d():
    """Render the active view annotated with all interactions on the current slice.

    Draws the crosshair, numbered point markers, pending bounding-box corners
    (only in BBox mode) and, in the axial view, the outline of every finished
    box whose z-range (widened by one slice) contains the current slice.

    Returns:
        An RGB image, or None when no case is loaded.
    """
    view = S.active_view
    if not S.loaded:
        return None
    sl = _slice2d(view)
    im = Image.fromarray((sl * 255).astype(np.uint8)).resize((S.render_px, S.render_px), Image.BILINEAR).convert("RGB")
    w = h = S.render_px
    dr = ImageDraw.Draw(im)
    _draw_cross(dr, view, w, h)

    tol = 0
    # Index of the coordinate that is perpendicular to each view.
    axis_of_view = {"axial": 2, "sagittal": 0, "coronal": 1}

    def _on_plane(point):
        axis = axis_of_view.get(view)
        if axis is None:
            return False
        return abs(point[axis] - S.slice[view]) <= tol

    for i, (x, y, z) in enumerate(S.seeds):
        if not _on_plane((x, y, z)):
            continue
        u, v = _cross_pix_on_rot(view, w, h, x, y, z)
        r = 4
        dr.ellipse((u - r, v - r, u + r, v + r), fill=SEED_RGB, outline=(40, 140, 100), width=1)
        dr.text((u + 6, v - 8), f"{i+1}", fill=(30, 30, 30))

    if S.bbox_mode:
        for i, (x, y, z) in enumerate(S.bbox_points):
            if not _on_plane((x, y, z)):
                continue
            u, v = _cross_pix_on_rot(view, w, h, x, y, z)
            r = 6
            dr.rectangle((u - r, v - r, u + r, v + r), outline=BBOX_RGB, width=3)
            text = "P1" if i == 0 else "P2"
            dr.text((u + 10, v - 10), text, fill=BBOX_RGB)

    if view == "axial":
        curr_z = S.slice["axial"]
        for (x1, y1, z1, x2, y2, z2) in S.bboxes:
            if not (min(z1, z2) - 1 <= curr_z <= max(z1, z2) + 1):
                continue
            u1, v1 = _cross_pix_on_rot(view, w, h, x1, y1, curr_z)
            u2, v2 = _cross_pix_on_rot(view, w, h, x2, y2, curr_z)
            # The v axis is flipped and corners can be clicked in any order,
            # so sort them: ImageDraw.rectangle needs x0 <= x1 and y0 <= y1.
            box = [min(u1, u2), min(v1, v2), max(u1, u2), max(v1, v2)]
            dr.rectangle(box, outline=(0, 255, 0), width=2)
    return im
| # ===================== segmentation ======================================== | |
def _segment_with_model():
    """Run the interactive model on the loaded volume with the current interactions.

    Returns:
        ``(mask, "ok")`` on success, ``(None, "model-init-failed")`` when no
        predictor could be created, or ``(None, "model-error")`` when
        inference raised.
    """
    predictor_ready = PREDICTOR is not None or _init_predictor_once()
    if not predictor_ready:
        return None, "model-init-failed"
    try:
        volume = S.vol[None].astype(np.float32)
        PREDICTOR.reset_interactions()
        PREDICTOR.set_image(volume)
        PREDICTOR.set_target_buffer(np.zeros_like(volume[0], np.float32))
        PREDICTOR._finish_preprocessing_and_initialize_interactions()

        for px, py, pz in S.seeds:
            PREDICTOR.add_point_interaction(px, py, pz, foreground=True)
        for ax, ay, az, bx, by, bz in S.bboxes:
            # Corners may be stored in any order; pass them as (min, max).
            lower = (min(ax, bx), min(ay, by), min(az, bz))
            upper = (max(ax, bx), max(ay, by), max(az, bz))
            PREDICTOR.add_bbox_interaction(*lower, *upper)

        PREDICTOR._predict_without_interaction()
        mask = (PREDICTOR.target_buffer.astype(np.float32) > 0.5).astype(np.uint8)
        if mask.shape != S.shape:
            print(f"[MODEL] resize pred {mask.shape}->{S.shape}")
            mask = _resample_mask_to_vol_shape(mask, S.shape)
        return mask, "ok"
    except Exception as e:
        print(f"[MODEL] inference failed: {e}")
        return None, "model-error"
def _segment_fallback():
    """Heuristic segmentation used when the model is unavailable.

    Points (weight 1.0) and boxes (weight 0.5) are painted into a field that
    is Gaussian-blurred, normalised and thresholded at the 70th percentile of
    its non-zero values. When several connected components result, only those
    containing a point or a box centre are kept.

    Returns:
        ``(mask, "ok")`` with a 0/1 uint8 mask, or
        ``(None, "no-interactions-or-scipy")`` when there is nothing to work
        from or scipy is missing.
    """
    if (not S.seeds and not S.bboxes) or not HAVE_SCIPY:
        return None, "no-interactions-or-scipy"
    X, Y, Z = S.shape
    field = np.zeros((X, Y, Z), dtype=np.float32)
    for (x, y, z) in S.seeds:
        field[x, y, z] = 1.0
    for (x1, y1, z1, x2, y2, z2) in S.bboxes:
        field[min(x1, x2):max(x1, x2) + 1, min(y1, y2):max(y1, y2) + 1, min(z1, z2):max(z1, z2) + 1] = 0.5
    t0 = time.time()
    prob = ndi.gaussian_filter(field, sigma=6.0)
    if prob.max() > 0:
        prob /= prob.max()
    nz = prob[prob > 0]
    thr = np.percentile(nz, 70) if nz.size else 0.5
    mask = prob >= max(thr, 1e-3)
    lab, nlab = ndi.label(mask.astype(np.uint8))
    if nlab > 1:
        keep = np.zeros(nlab + 1, np.uint8)
        for (x, y, z) in S.seeds:
            keep[lab[x, y, z]] = 1
        for (x1, y1, z1, x2, y2, z2) in S.bboxes:
            xm, ym, zm = (x1 + x2) // 2, (y1 + y2) // 2, (z1 + z2) // 2
            keep[lab[xm, ym, zm]] = 1
        # Label 0 is background. An anchor that fell below the threshold must
        # not select it, otherwise the whole background joins the mask.
        keep[0] = 0
        # If no anchor hit a component, keep every component rather than
        # returning an empty mask.
        if keep.any():
            mask = keep[lab] > 0
    mask = ndi.binary_closing(mask, iterations=1)
    mask = ndi.binary_opening(mask, iterations=1)
    print(f"[FB] seg {time.time()-t0:.3f}s | vox={int(mask.sum())}")
    return (mask > 0).astype(np.uint8), "ok"
def do_segment():
    """Segment with the model, falling back to the heuristic; store the result in ``S.pred``.

    Returns "OK" when a mask was produced and "Failed" otherwise.
    """
    result, tag = _segment_with_model()
    if result is None:
        result, fallback_tag = _segment_fallback()
        tag = f"{tag}->{fallback_tag}"
    S.pred = result
    print(f"[SEG] done: {tag}")
    if S.pred is None:
        return "Failed"
    return "OK"
def save_prediction():
    """Write ``S.pred`` to a NIfTI file in a fresh temporary directory.

    The reference image's affine and header are reused when available so the
    mask keeps the source geometry.

    Returns:
        ``(message, path)``; ``path`` is None when nothing was written.
    """
    if S.pred is None:
        return "No prediction to save", None
    if not HAVE_NIB:
        return "nibabel not installed", None
    try:
        tmp_dir = Path(tempfile.mkdtemp(prefix="menrt_output_"))
        out_path = tmp_dir / f"{S.case_id}_pred.nii.gz"
        affine = S.ref_affine if S.ref_affine is not None else np.eye(4)
        header = S.ref_header.copy() if S.ref_header is not None else None
        nii_img = nib.Nifti1Image(S.pred.astype(np.uint8), affine, header=header)
        # A supplied header carries the reference image's on-disk dtype;
        # force uint8 so the label mask is stored as integers.
        nii_img.set_data_dtype(np.uint8)
        nib.save(nii_img, str(out_path))
        print(f"[SAVE] {out_path}")
        return "Saved successfully!", str(out_path)
    except Exception as e:
        print(f"[SAVE] error: {e}")
        return f"Save failed: {e}", None
| # ---------- helpers ---------- | |
def _seed_rows():
    """Build the interaction table rows: all points first, then all bounding boxes."""
    table = []
    for idx, (px, py, pz) in enumerate(S.seeds):
        placed_in = S.seed_views[idx] if idx < len(S.seed_views) else ""
        table.append([idx + 1, "point", placed_in, px, py, pz])
    first_bbox_number = len(S.seeds) + 1
    for offset, (x1, y1, z1, x2, y2, z2) in enumerate(S.bboxes):
        table.append(
            [first_bbox_number + offset, "bbox", "3D", f"{x1}-{x2}", f"{y1}-{y2}", f"{z1}-{z2}"]
        )
    return table
def _seed_dropdown_options():
    """Build the dropdown labels: one per point, followed by one per bounding box."""
    labels = []
    for idx, (px, py, pz) in enumerate(S.seeds):
        placed_in = S.seed_views[idx] if idx < len(S.seed_views) else "axial"
        labels.append(f"{placed_in} β point {idx+1} β ({px},{py},{pz})")
    labels.extend(
        f"3D β bbox {n+1} β ({x1},{y1},{z1})-({x2},{y2},{z2})"
        for n, (x1, y1, z1, x2, y2, z2) in enumerate(S.bboxes)
    )
    return labels
def _debug_widgets(current_idx=None):
    """Return updates for the interaction table and the jump dropdown.

    The dropdown selects option ``current_idx`` when it is a valid index and
    the last option (or None when there are none) otherwise.
    """
    table_update = gr.update(value=_seed_rows())
    options = _seed_dropdown_options()
    fallback = options[-1] if options else None
    if current_idx is not None and 0 <= current_idx < len(options):
        selected = options[current_idx]
    else:
        selected = fallback
    dropdown_update = gr.update(choices=options, value=selected)
    return table_update, dropdown_update
def _figs_and_imgs():
    """Render all six images: three views, GT overlay, prediction overlay, interactions."""
    views = tuple(render_top(name) for name in ("axial", "sagittal", "coronal"))
    gt_overlay = _axial_overlay_fill(S.gt, GT_RGBA_FILL)
    pred_overlay = _axial_overlay_fill(S.pred, PR_RGBA_FILL)
    interactions = _interaction_2d()
    return (*views, gt_overlay, pred_overlay, interactions)
def _bar_ranges_and_values():
    """Return slider updates (Z, X, Y) matching the loaded volume and current slices.

    When no volume is loaded there is no shape to read, so three no-op
    updates are returned and the sliders stay as they are.
    """
    if not S.loaded or S.shape is None:
        return gr.update(), gr.update(), gr.update()
    X, Y, Z = S.shape
    return (
        gr.update(minimum=0, maximum=Z - 1, value=S.slice["axial"], visible=True),
        gr.update(minimum=0, maximum=X - 1, value=S.slice["sagittal"], visible=True),
        gr.update(minimum=0, maximum=Y - 1, value=S.slice["coronal"], visible=True),
    )
| def _parse_evt_xy(evt): | |
| """κ°νλ μ΄λ²€νΈ νμ±""" | |
| print(f"[DEBUG_EVT] Event received: {evt}") | |
| print(f"[DEBUG_EVT] Event type: {type(evt)}") | |
| print(f"[DEBUG_EVT] Event dir: {dir(evt)}") | |
| if evt is None: | |
| print("[DEBUG_EVT] Event is None!") | |
| return None | |
| try: | |
| # Method 1: evt.index | |
| if hasattr(evt, "index") and evt.index is not None: | |
| ix = evt.index | |
| print(f"[DEBUG_EVT] Found index: {ix}") | |
| if isinstance(ix, (list, tuple)) and len(ix) >= 2: | |
| result = int(ix[0]), int(ix[1]) | |
| print(f"[DEBUG_EVT] Parsed from index: {result}") | |
| return result | |
| # Method 2: evt.x, evt.y | |
| if hasattr(evt, "x") and hasattr(evt, "y"): | |
| x_val = getattr(evt, "x") | |
| y_val = getattr(evt, "y") | |
| print(f"[DEBUG_EVT] Found x={x_val}, y={y_val}") | |
| if x_val is not None and y_val is not None: | |
| result = int(x_val), int(y_val) | |
| print(f"[DEBUG_EVT] Parsed from x,y: {result}") | |
| return result | |
| except Exception as e: | |
| print(f"[DEBUG_EVT] Parse error: {e}") | |
| print("[DEBUG_EVT] Failed to parse coordinates!") | |
| return None | |
def _disp_to_vol(view, u, v):
    """Convert a click at display pixel ``(u, v)`` in ``view`` to voxel indices ``(x, y, z)``.

    The coordinate perpendicular to the view is taken from the current slice
    index. All results are clamped to the volume bounds.
    """
    X, Y, Z = S.shape
    w, h = S.disp_wh[view]
    if w <= 0 or h <= 0:
        w = h = S.render_px
    u_span = max(w - 1, 1)
    v_span = max(h - 1, 1)

    if view == "axial":
        x = int(round(u * (X - 1) / u_span))
        y = int(round((Y - 1) - v * (Y - 1) / v_span))
        z = S.slice["axial"]
    elif view == "sagittal":
        x = S.slice["sagittal"]
        y = int(round((Y - 1) - v * (Y - 1) / v_span))
        z = int(round(u * (Z - 1) / u_span))
    else:
        x = int(round((X - 1) - v * (X - 1) / v_span))
        y = S.slice["coronal"]
        z = int(round(u * (Z - 1) / u_span))

    def _clamp(value, upper):
        return max(0, min(upper, value))

    return _clamp(x, X - 1), _clamp(y, Y - 1), _clamp(z, Z - 1)
def _thumb_from_case(case_id, px=96):
    """Build a bordered ``px`` x ``px`` grayscale thumbnail for a case.

    The middle file of the PNG/JPG axial stack is used when present;
    otherwise the middle axial slice of the downsampled NIfTI volume; otherwise
    a flat placeholder. The result is oriented like the viewer's slices.
    """
    case_dir = DATA_ROOT / case_id
    slice_dir = case_dir / "png_axial"
    pngs = sorted(glob.glob(str(slice_dir / "*.png")))
    if not pngs:
        pngs = sorted(glob.glob(str(slice_dir / "*.jpg")))
    if pngs:
        # Close the file once its pixels are copied out.
        with Image.open(pngs[len(pngs) // 2]) as src:
            arr = _norm01(np.array(src.convert("L")).astype(np.float32))
        im = Image.fromarray((arr * 255).astype(np.uint8)).resize((px, px), Image.BILINEAR)
    else:
        # No slice images exist (checked above), so go straight to NIfTI
        # instead of scanning the PNG folder a second time.
        vol, _, _ = _load_nifti(case_dir, case_id, ds=2)
        if vol is None:
            im = Image.new("L", (px, px), 30)
        else:
            mid = vol[:, :, vol.shape[2] // 2]
            im = Image.fromarray((mid * 255).astype(np.uint8)).resize((px, px), Image.BILINEAR)
    im = _rot90_if_needed(im)
    return ImageOps.expand(im, border=1, fill=200)
| # ---------- callbacks ---------- | |
def on_load(case_id):
    """Load a case, start model preloading, and return updates for all 14 outputs."""
    load_case(case_id)
    preload_model_in_background()
    images = _figs_and_imgs()
    sliders = _bar_ranges_and_values()
    widgets = _debug_widgets()
    mode_update = gr.update(value="Point")
    return (*images, *sliders, *widgets, mode_update, None, "Loaded")
def on_gallery_select(evt: gr.SelectData):
    """Load the case whose thumbnail was clicked and sync the case dropdown."""
    picked = 0
    if hasattr(evt, "index"):
        raw = evt.index
        if isinstance(raw, (list, tuple)):
            raw = raw[0]
        picked = int(raw)
    # Clamp to the available examples.
    picked = max(0, min(len(EXAMPLES) - 1, picked))
    case_id = EXAMPLES[picked]
    updates = on_load(case_id)
    return (*updates, gr.update(value=case_id))
def _click_common(view, evt: gr.SelectData):
    """Handle a click in one of the three views.

    In point mode the click adds a point; in BBox mode it adds a corner, and
    the second corner completes a box. The crosshair and all slice indices
    move to the clicked voxel and ``view`` becomes the active view.

    Returns:
        Updates for all 14 outputs, ending with the status message.
    """
    print(f"[CLICK_HANDLER] Called for view={view}")
    print(f"[CLICK_HANDLER] S.loaded={S.loaded}")
    print(f"[CLICK_HANDLER] S.bbox_mode={S.bbox_mode}")
    if not S.loaded:
        print("[CLICK_HANDLER] Not loaded yet!")
        return (*_figs_and_imgs(), *_bar_ranges_and_values(), *_debug_widgets(),
                gr.update(value="Point" if not S.bbox_mode else "BBox"), None, "Not loaded")
    xy = _parse_evt_xy(evt)
    if xy is None:
        print("[CLICK_HANDLER] Failed to parse event!")
        return (*_figs_and_imgs(), *_bar_ranges_and_values(), *_debug_widgets(),
                gr.update(value="Point" if not S.bbox_mode else "BBox"), None, "Parse failed")
    u, v = xy
    x, y, z = _disp_to_vol(view, u, v)
    print(f"[CLICK_HANDLER] Clicked at display ({u},{v}) -> volume ({x},{y},{z})")

    # Index, in the combined points-then-boxes list, of the entry this click
    # created; None when only the first box corner was placed.
    new_idx = None
    if S.bbox_mode:
        S.bbox_points.append((x, y, z))
        if len(S.bbox_points) == 2:
            p1, p2 = S.bbox_points
            S.bboxes.append((*p1, *p2))
            S.bbox_points = []
            print(f"[UI] bbox created: {S.bboxes[-1]}")
            status = "BBox created!"
            new_idx = len(S.seeds) + len(S.bboxes) - 1
        else:
            print(f"[UI] bbox point {len(S.bbox_points)}/2")
            status = f"BBox point {len(S.bbox_points)}/2"
    else:
        S.seeds.append((x, y, z))
        S.seed_views.append(view)
        print(f"[UI] seed+ {(x,y,z)} total={len(S.seeds)}")
        status = f"Added point {len(S.seeds)}"
        # Points are listed before boxes, so the new point is the last point,
        # not the last entry overall.
        new_idx = len(S.seeds) - 1

    S.cross = {"x": x, "y": y, "z": z}
    S.slice = {"sagittal": x, "coronal": y, "axial": z}
    S.active_view = view
    imgs = _figs_and_imgs()
    bars = _bar_ranges_and_values()
    dbg = _debug_widgets(current_idx=new_idx)
    return (*imgs, *bars, *dbg, gr.update(value="Point" if not S.bbox_mode else "BBox"), None, status)
def on_axial_select(evt: gr.SelectData):
    """Click handler for the axial image."""
    print("[EVENT] on_axial_select triggered!")
    updates = _click_common("axial", evt)
    return updates
def on_sagittal_select(evt: gr.SelectData):
    """Click handler for the sagittal image."""
    print("[EVENT] on_sagittal_select triggered!")
    updates = _click_common("sagittal", evt)
    return updates
def on_coronal_select(evt: gr.SelectData):
    """Click handler for the coronal image."""
    print("[EVENT] on_coronal_select triggered!")
    updates = _click_common("coronal", evt)
    return updates
def on_seg_button():
    """Run segmentation and return refreshed outputs with the result in the status."""
    outcome = do_segment()
    images = _figs_and_imgs()
    sliders = _bar_ranges_and_values()
    widgets = _debug_widgets()
    mode_update = gr.update(value="BBox" if S.bbox_mode else "Point")
    return (*images, *sliders, *widgets, mode_update, None, f"Segment: {outcome}")
def on_clear():
    """Remove every interaction and the current prediction, then refresh all outputs."""
    S.seeds = []
    S.seed_views = []
    S.pred = None
    S.bboxes = []
    S.bbox_points = []
    print("[UI] cleared all")
    images = _figs_and_imgs()
    sliders = _bar_ranges_and_values()
    widgets = _debug_widgets()
    return (*images, *sliders, *widgets, gr.update(value="Point"), None, "Cleared")
def on_undo():
    """Undo one interaction.

    Priority: a pending box corner (only in BBox mode), then the most recent
    finished box, then the most recent point.
    """
    if S.bbox_mode and S.bbox_points:
        S.bbox_points.pop()
    elif S.bboxes:
        S.bboxes.pop()
    elif S.seeds:
        S.seeds.pop()
        if S.seed_views:
            S.seed_views.pop()
    print("[UI] undo")
    images = _figs_and_imgs()
    sliders = _bar_ranges_and_values()
    widgets = _debug_widgets()
    mode_update = gr.update(value="BBox" if S.bbox_mode else "Point")
    return (*images, *sliders, *widgets, mode_update, None, "Undo")
def on_mode_toggle(mode):
    """Switch between Point and BBox mode; any pending box corners are discarded."""
    S.bbox_mode = mode == "BBox"
    S.bbox_points = []
    print(f"[UI] mode -> {mode}")
    images = _figs_and_imgs()
    sliders = _bar_ranges_and_values()
    widgets = _debug_widgets()
    return (*images, *sliders, *widgets, gr.update(value=mode), None, f"Mode: {mode}")
def on_save():
    """Save the prediction and show the download widget when a file was written."""
    message, saved_path = save_prediction()
    images = _figs_and_imgs()
    sliders = _bar_ranges_and_values()
    widgets = _debug_widgets()
    mode_update = gr.update(value="BBox" if S.bbox_mode else "Point")
    if saved_path:
        file_update = gr.update(visible=True, value=saved_path)
    else:
        file_update = gr.update(visible=False, value=None)
    return (*images, *sliders, *widgets, mode_update, file_update, message)
def on_z_release(z_idx):
    """Move the axial slice and crosshair z to the slider value; make axial the active view."""
    z_now = int(z_idx)
    S.slice["axial"] = z_now
    S.cross["z"] = z_now
    S.active_view = "axial"
    images = _figs_and_imgs()
    sliders = _bar_ranges_and_values()
    widgets = _debug_widgets()
    mode_update = gr.update(value="BBox" if S.bbox_mode else "Point")
    return (*images, *sliders, *widgets, mode_update, None, f"Z={z_idx}")
def on_x_release(x_idx):
    """Move the sagittal slice and crosshair x to the slider value; make sagittal the active view."""
    x_now = int(x_idx)
    S.slice["sagittal"] = x_now
    S.cross["x"] = x_now
    S.active_view = "sagittal"
    images = _figs_and_imgs()
    sliders = _bar_ranges_and_values()
    widgets = _debug_widgets()
    mode_update = gr.update(value="BBox" if S.bbox_mode else "Point")
    return (*images, *sliders, *widgets, mode_update, None, f"X={x_idx}")
def on_y_release(y_idx):
    """Move the coronal slice and crosshair y to the slider value; make coronal the active view."""
    y_now = int(y_idx)
    S.slice["coronal"] = y_now
    S.cross["y"] = y_now
    S.active_view = "coronal"
    images = _figs_and_imgs()
    sliders = _bar_ranges_and_values()
    widgets = _debug_widgets()
    mode_update = gr.update(value="BBox" if S.bbox_mode else "Point")
    return (*images, *sliders, *widgets, mode_update, None, f"Y={y_idx}")
def _jump_to_idx(idx: int):
    """Select entry ``idx`` of the combined points-then-boxes list.

    For a point, the crosshair, slices and active view move to that point.
    For a box, only the dropdown selection and status change. An
    out-of-range index refreshes the outputs with an empty status.
    """
    def _mode_update():
        return gr.update(value="BBox" if S.bbox_mode else "Point")

    entry_count = len(S.seeds) + len(S.bboxes)
    if idx < 0 or idx >= entry_count:
        return (*_figs_and_imgs(), *_bar_ranges_and_values(), *_debug_widgets(),
                _mode_update(), None, "")

    if idx < len(S.seeds):
        px, py, pz = S.seeds[idx]
        placed_in = S.seed_views[idx] if idx < len(S.seed_views) else "axial"
        S.cross = {"x": px, "y": py, "z": pz}
        S.slice = {"sagittal": px, "coronal": py, "axial": pz}
        S.active_view = placed_in

    images = _figs_and_imgs()
    sliders = _bar_ranges_and_values()
    widgets = _debug_widgets(current_idx=idx)
    return (*images, *sliders, *widgets, _mode_update(), None, f"Jump to {idx+1}")
def on_seed_df_select(evt: gr.SelectData):
    """Jump to the interaction whose table row was clicked.

    The ``gr.SelectData`` annotation is what makes Gradio pass the event
    object; the listener is registered without inputs, so without the
    annotation the callback would be called with no argument.
    """
    try:
        row = int(evt.index[0]) if hasattr(evt, "index") else 0
    except Exception:
        # Fall back to the first row when the index has an unexpected shape.
        row = 0
    return _jump_to_idx(row)
def on_seed_dd_change(val):
    """Jump to the interaction selected in the dropdown.

    The label is parsed for ``point N`` or ``bbox N``. Box numbers are offset
    by the number of points because boxes are listed after points. A label
    that cannot be parsed resolves to the last entry.
    """
    import re

    if not val:
        return (*_figs_and_imgs(), *_bar_ranges_and_values(), *_debug_widgets(),
                gr.update(value="Point" if not S.bbox_mode else "BBox"), None, "")
    last_idx = len(S.seeds) + len(S.bboxes) - 1
    try:
        # Matching on the keyword and number keeps the parse independent of
        # whatever separator the label uses.
        found = re.search(r"\b(point|bbox)\s+(\d+)", val)
    except TypeError:
        found = None
    if found is None:
        idx = last_idx
    else:
        number = int(found.group(2)) - 1
        idx = number if found.group(1) == "point" else len(S.seeds) + number
    return _jump_to_idx(idx)
| # ---------- UI ---------- | |
# Custom stylesheet passed to gr.Blocks: page background, rounded "card"
# containers (.round), section headings (.section), compact sliders (.tiny)
# and the small note text (.smallnote).
css = """
:root {
--bg:#f7fbff; --card:#ffffff; --accent:#1e90ff; --shadow:0 8px 26px rgba(45,156,219,.12);
}
.gradio-container{font-family:Inter,ui-sans-serif,system-ui;background:var(--bg)}
.round{background:var(--card);border-radius:16px;padding:10px;box-shadow:var(--shadow);border:1px solid #e8f0fb}
.section{font-weight:700;color:#114b8b;margin:2px 0 8px}
.tiny .gr-slider input[type="range"]{height:6px}
.tiny .gr-form{gap:6px}
.smallnote{color:#3e6285;font-size:12px;margin-top:6px}
"""
# One thumbnail per example case, built once at import time for the gallery.
THUMBS=[_thumb_from_case(cid,px=96) for cid in EXAMPLES]
# UI layout and event wiring. Every callback returns updates for the same 14
# components, in the order given by the `outputs` list below.
with gr.Blocks(css=css, title="Interactive-MEN-RT", theme=gr.themes.Soft(), analytics_enabled=False) as demo:
    gr.Markdown("""
<h3 style='color:#114b8b'>Interactive-MEN-RT Segmentation</h3>
<div class='smallnote'>
Domain-Specialized Interactive Segmentation for Meningioma Radiotherapy Planning<br>
<b>Research only β Not for clinical use</b>
</div>
""")
    with gr.Row():
        # Left column: case selection, interaction mode, actions, interaction list.
        with gr.Column(scale=1, min_width=280, elem_classes=["round"]):
            gr.Markdown("<div class='section'>Demo Case</div>")
            gallery = gr.Gallery(value=THUMBS, columns=1, height=110, allow_preview=False, preview=False, show_label=False)
            case_dd = gr.Dropdown(choices=EXAMPLES, value=EXAMPLES[0], label="Case")
            mode_radio = gr.Radio(["Point", "BBox"], value="Point", label="Interaction Mode")
            with gr.Row():
                seg_btn = gr.Button("Segment", variant="primary")
                save_btn = gr.Button("Save NIfTI", variant="secondary")
            with gr.Row():
                undo_btn = gr.Button("Undo")
                clr_btn = gr.Button("Clear")
            gr.Markdown("<div class='section' style='margin-top:8px'>Interactions</div>")
            seeds_df = gr.Dataframe(
                headers=["#", "type", "view", "x", "y", "z"],
                value=[],
                datatype=["number","str","str","str","str","str"],
                interactive=False,
                wrap=True,
                row_count=(0, "dynamic")
            )
            seed_dd = gr.Dropdown(choices=[], value=None, label="Go to point")
            # Hidden until a prediction has been saved (see on_save).
            pred_file = gr.File(label="Download Prediction", visible=False)
            status_text = gr.Textbox(label="Status", value="", interactive=False, lines=1)
        # Right column: three clickable views with sliders, then the overlay images.
        with gr.Column(scale=5):
            with gr.Row():
                with gr.Column(elem_classes=["round"]):
                    axial = gr.Image(type="pil", interactive=True, height=RENDER_PX_DEFAULT+8, label="Axial (Z)")
                    z_bar = gr.Slider(0,1,value=0,step=1,label="Z", elem_classes=["tiny"])
                with gr.Column(elem_classes=["round"]):
                    sagittal = gr.Image(type="pil", interactive=True, height=RENDER_PX_DEFAULT+8, label="Sagittal (X)")
                    x_bar = gr.Slider(0,1,value=0,step=1,label="X", elem_classes=["tiny"])
                with gr.Column(elem_classes=["round"]):
                    coronal = gr.Image(type="pil", interactive=True, height=RENDER_PX_DEFAULT+8, label="Coronal (Y)")
                    y_bar = gr.Slider(0,1,value=0,step=1,label="Y", elem_classes=["tiny"])
            with gr.Row():
                out_ax_gt = gr.Image(type="pil", interactive=False, height=RENDER_PX_DEFAULT+8, label="Ground Truth")
                out_ax_pr = gr.Image(type="pil", interactive=False, height=RENDER_PX_DEFAULT+8, label="Prediction")
                inter2d = gr.Image(type="pil", interactive=False, height=RENDER_PX_DEFAULT+8, label="Interactions")
    # Shared output list: 6 images, 3 sliders, table, dropdown, mode, file, status.
    outputs = [axial, sagittal, coronal, out_ax_gt, out_ax_pr, inter2d,
               z_bar, x_bar, y_bar, seeds_df, seed_dd, mode_radio, pred_file, status_text]
    # Load the first example when the page opens.
    demo.load(lambda: on_load(EXAMPLES[0]), [], outputs)
    case_dd.change(lambda cid: on_load(cid), [case_dd], outputs)
    # The gallery handler returns one extra update, for the case dropdown.
    gallery.select(on_gallery_select, [], outputs + [case_dd])
    # Click events - enhanced handlers
    axial.select(on_axial_select, [], outputs)
    sagittal.select(on_sagittal_select, [], outputs)
    coronal.select(on_coronal_select, [], outputs)
    # Sliders act on release rather than on every intermediate value.
    z_bar.release(on_z_release, [z_bar], outputs)
    x_bar.release(on_x_release, [x_bar], outputs)
    y_bar.release(on_y_release, [y_bar], outputs)
    mode_radio.change(on_mode_toggle, [mode_radio], outputs)
    seg_btn.click(on_seg_button, [], outputs)
    save_btn.click(on_save, [], outputs)
    clr_btn.click(on_clear, [], outputs)
    undo_btn.click(on_undo, [], outputs)
    seeds_df.select(on_seed_df_select, [], outputs)
    seed_dd.change(on_seed_dd_change, [seed_dd], outputs)
# Script entry point: serve on all interfaces at port 7860 and allow Gradio
# to serve files from the sample-data directory.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True, share=True,
                allowed_paths=[str(DATA_ROOT)])