Spaces:
Runtime error
Runtime error
Upload /app.py with huggingface_hub
Browse files
app.py
CHANGED
|
@@ -5,31 +5,85 @@ from pathlib import Path
|
|
| 5 |
import numpy as np
|
| 6 |
from PIL import Image, ImageDraw, ImageOps
|
| 7 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
# ---------- optional deps ----------
|
| 9 |
try:
|
| 10 |
import nibabel as nib
|
| 11 |
HAVE_NIB = True
|
| 12 |
-
|
|
|
|
| 13 |
HAVE_NIB = False
|
|
|
|
|
|
|
| 14 |
try:
|
| 15 |
from scipy import ndimage as ndi
|
| 16 |
HAVE_SCIPY = True
|
| 17 |
-
|
|
|
|
| 18 |
HAVE_SCIPY = False
|
|
|
|
|
|
|
| 19 |
# ---------- model predictor ----------
|
| 20 |
PREDICTOR = None
|
| 21 |
DEVICE = "cuda:0"
|
|
|
|
| 22 |
from huggingface_hub import snapshot_download
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
for env in ("nnUNet_raw", "nnUNet_preprocessed", "nnUNet_results"):
|
| 27 |
os.environ.setdefault(env, tempfile.mkdtemp(prefix=f"{env}_"))
|
|
|
|
| 28 |
def _init_predictor_once():
|
| 29 |
-
"""Load model once + (best-effort) CUDA warm-up."""
|
| 30 |
global PREDICTOR
|
| 31 |
if PREDICTOR is not None:
|
| 32 |
return True
|
|
|
|
|
|
|
|
|
|
| 33 |
try:
|
| 34 |
import torch
|
| 35 |
from Interactive_MEN_RT_predictor import InteractiveMENRTPredictor
|
|
@@ -41,7 +95,6 @@ def _init_predictor_once():
|
|
| 41 |
model_training_output_dir=CKPT, use_fold=0, checkpoint_name="checkpoint_best.pth"
|
| 42 |
)
|
| 43 |
PREDICTOR = pred
|
| 44 |
-
# GPU warm-up (best effort)
|
| 45 |
try:
|
| 46 |
if torch.cuda.is_available():
|
| 47 |
x = np.zeros((1, 8, 8, 8), np.float32)
|
|
@@ -57,26 +110,23 @@ def _init_predictor_once():
|
|
| 57 |
except Exception as e:
|
| 58 |
print(f"[MODEL] init failed: {e}")
|
| 59 |
return False
|
|
|
|
| 60 |
def preload_model_in_background():
|
| 61 |
threading.Thread(target=_init_predictor_once, daemon=True).start()
|
|
|
|
| 62 |
# ---------- config ----------
|
| 63 |
-
EXAMPLES = [
|
| 64 |
-
"BraTS-MEN-RT-0071-1",
|
| 65 |
-
"BraTS-MEN-RT-0223-1",
|
| 66 |
-
"BraTS-MEN-RT-0280-1",
|
| 67 |
-
"BraTS-MEN-RT-0422-1",
|
| 68 |
-
"BraTS-MEN-RT-0436-1",
|
| 69 |
-
]
|
| 70 |
RENDER_PX_DEFAULT = 384
|
| 71 |
-
|
| 72 |
-
|
| 73 |
# colors
|
| 74 |
ACCENT_HEX = "#1e90ff"
|
| 75 |
CROSS_RGB = (30, 144, 255)
|
| 76 |
-
GT_RGBA_FILL = (255, 215, 0, 128)
|
| 77 |
-
PR_RGBA_FILL = (255, 60, 60, 128)
|
| 78 |
SEED_RGB = (89, 224, 154)
|
| 79 |
-
BBOX_RGB = (255, 140, 0)
|
|
|
|
| 80 |
# ---------- state ----------
|
| 81 |
class State:
|
| 82 |
def __init__(self):
|
|
@@ -85,33 +135,36 @@ class State:
|
|
| 85 |
self.case_id=None; self.loaded=False
|
| 86 |
self.cross={"x":0,"y":0,"z":0}
|
| 87 |
self.slice={"axial":0,"sagittal":0,"coronal":0}
|
| 88 |
-
self.seeds=[]
|
| 89 |
-
self.seed_views=[]
|
| 90 |
self.render_px=RENDER_PX_DEFAULT
|
| 91 |
self.disp_wh={"axial":(RENDER_PX_DEFAULT,RENDER_PX_DEFAULT),
|
| 92 |
"sagittal":(RENDER_PX_DEFAULT,RENDER_PX_DEFAULT),
|
| 93 |
"coronal":(RENDER_PX_DEFAULT,RENDER_PX_DEFAULT)}
|
| 94 |
-
self.active_view="axial"
|
| 95 |
-
|
| 96 |
-
self.
|
| 97 |
-
self.
|
| 98 |
-
self.bboxes = [] # [(x1,y1,z1,x2,y2,z2), ...]
|
| 99 |
-
# NIfTI metadata
|
| 100 |
self.ref_affine = None
|
| 101 |
self.ref_header = None
|
|
|
|
| 102 |
S = State()
|
|
|
|
| 103 |
# ---------- utils ----------
|
| 104 |
def _norm01(a):
|
| 105 |
a=a.astype(np.float32)
|
| 106 |
p2,p98=np.percentile(a,2),np.percentile(a,98)
|
| 107 |
if p98<=p2: p2,p98=float(a.min()),float(a.max()) or 1.0
|
| 108 |
return np.clip((a-p2)/max(p98-p2,1e-6),0,1)
|
|
|
|
| 109 |
def _resize_slice_nearest(arr2d,w,h):
|
| 110 |
im=Image.fromarray(arr2d); im=im.resize((w,h),Image.NEAREST); return np.array(im)
|
|
|
|
| 111 |
def _rot90_if_needed(img_or_np):
|
| 112 |
if not ROT_CCW: return img_or_np
|
| 113 |
if isinstance(img_or_np, Image.Image): return img_or_np.rotate(90, expand=True)
|
| 114 |
return np.rot90(img_or_np, k=1)
|
|
|
|
| 115 |
# ---------- IO ----------
|
| 116 |
def _load_png_stack(case_dir):
|
| 117 |
pngs = sorted(glob.glob(str(case_dir / "png_axial" / "*.png")))
|
|
@@ -124,6 +177,7 @@ def _load_png_stack(case_dir):
|
|
| 124 |
vol=_norm01(vol)
|
| 125 |
print(f"[PIL] {len(pngs)} slices -> {vol.shape} in {time.time()-t0:.2f}s")
|
| 126 |
return vol, None, None
|
|
|
|
| 127 |
def _load_nifti(case_dir,case_id,ds=1):
|
| 128 |
if not HAVE_NIB: return None, None, None
|
| 129 |
p=case_dir/f"{case_id}_t1c.nii.gz"
|
|
@@ -134,6 +188,7 @@ def _load_nifti(case_dir,case_id,ds=1):
|
|
| 134 |
arr=_norm01(arr)
|
| 135 |
print(f"[NIfTI] {case_id} -> {arr.shape} in {time.time()-t0:.2f}s")
|
| 136 |
return arr, nii.affine, nii.header
|
|
|
|
| 137 |
def _resample_mask_to_vol_shape(mask_xyz, vol_shape_xyz):
|
| 138 |
mx,my,mz=mask_xyz.shape; vx,vy,vz=vol_shape_xyz
|
| 139 |
out=np.zeros((vx,vy,vz),dtype=np.uint8)
|
|
@@ -143,6 +198,7 @@ def _resample_mask_to_vol_shape(mask_xyz, vol_shape_xyz):
|
|
| 143 |
im=Image.fromarray(sl).resize((vy,vx),Image.NEAREST)
|
| 144 |
out[:,:,k]=(np.array(im)>0).astype(np.uint8)
|
| 145 |
return out
|
|
|
|
| 146 |
def _load_gt(case_dir,case_id,vol_shape):
|
| 147 |
candidates = [
|
| 148 |
f"{case_id}_gtv.nii.gz", f"{case_id}_seg.nii.gz", f"{case_id}_gt.nii.gz",
|
|
@@ -163,6 +219,7 @@ def _load_gt(case_dir,case_id,vol_shape):
|
|
| 163 |
print(f"[GT] load error {p.name}: {e}")
|
| 164 |
print("[GT] not found.")
|
| 165 |
return None
|
|
|
|
| 166 |
def load_case(case_id):
|
| 167 |
case_dir=DATA_ROOT/case_id
|
| 168 |
vol, affine, header = _load_png_stack(case_dir)
|
|
@@ -191,12 +248,14 @@ def load_case(case_id):
|
|
| 191 |
S.active_view="axial"
|
| 192 |
S.loaded=True
|
| 193 |
print(f"[LOAD] {case_id} | shape={S.shape}")
|
| 194 |
-
|
|
|
|
| 195 |
def _slice2d(view):
|
| 196 |
if view=="axial": sl=S.vol[:,:,S.slice["axial"]]
|
| 197 |
elif view=="sagittal": sl=S.vol[S.slice["sagittal"],:,:].T
|
| 198 |
else: sl=S.vol[:,S.slice["coronal"],:].T
|
| 199 |
return _rot90_if_needed(sl)
|
|
|
|
| 200 |
def _cross_pix_on_rot(view,w,h,x=None,y=None,z=None):
|
| 201 |
X,Y,Z=S.shape
|
| 202 |
if x is None: x=S.cross["x"]
|
|
@@ -212,11 +271,13 @@ def _cross_pix_on_rot(view,w,h,x=None,y=None,z=None):
|
|
| 212 |
u=int(round(z*(w-1)/max(Z-1,1)))
|
| 213 |
v=int(round((X-1-x)*(h-1)/max(X-1,1)))
|
| 214 |
return u,v
|
|
|
|
| 215 |
def _draw_cross(img_draw, view, w, h):
|
| 216 |
u,v=_cross_pix_on_rot(view,w,h)
|
| 217 |
img_draw.line([(u,0),(u,h)],fill=CROSS_RGB,width=1)
|
| 218 |
img_draw.line([(0,v),(w,v)],fill=CROSS_RGB,width=1)
|
| 219 |
img_draw.ellipse((u-5,v-5,u+5,v+5),fill=(255,255,255),outline=ACCENT_HEX,width=2)
|
|
|
|
| 220 |
def render_top(view):
|
| 221 |
if not S.loaded: return None
|
| 222 |
sl=_slice2d(view)
|
|
@@ -226,12 +287,13 @@ def render_top(view):
|
|
| 226 |
_draw_cross(dr, view, w, h)
|
| 227 |
S.disp_wh[view]=im.size
|
| 228 |
return im
|
| 229 |
-
|
| 230 |
def _axial_mask2d_rot(mask3d):
|
| 231 |
if mask3d is None: return None
|
| 232 |
m = mask3d[:,:,S.slice["axial"]].astype(np.uint8)
|
| 233 |
m = _rot90_if_needed(m)
|
| 234 |
return m
|
|
|
|
| 235 |
def _axial_overlay_fill(mask3d, rgba):
|
| 236 |
sl = _rot90_if_needed(S.vol[:,:,S.slice["axial"]])
|
| 237 |
base=Image.fromarray((sl*255).astype(np.uint8)).resize((S.render_px,S.render_px),Image.BILINEAR).convert("RGBA")
|
|
@@ -240,7 +302,7 @@ def _axial_overlay_fill(mask3d, rgba):
|
|
| 240 |
m2d = _resize_slice_nearest((m2d>0).astype(np.uint8), S.render_px, S.render_px)
|
| 241 |
over=np.zeros((S.render_px,S.render_px,4),dtype=np.uint8); over[m2d>0]=rgba
|
| 242 |
return Image.alpha_composite(base,Image.fromarray(over,"RGBA")).convert("RGB")
|
| 243 |
-
|
| 244 |
def _interaction_2d():
|
| 245 |
view = S.active_view
|
| 246 |
if not S.loaded: return None
|
|
@@ -248,9 +310,8 @@ def _interaction_2d():
|
|
| 248 |
im=Image.fromarray((sl*255).astype(np.uint8)).resize((S.render_px,S.render_px),Image.BILINEAR).convert("RGB")
|
| 249 |
w=h=S.render_px
|
| 250 |
dr=ImageDraw.Draw(im)
|
| 251 |
-
# cross
|
| 252 |
_draw_cross(dr, view, w, h)
|
| 253 |
-
|
| 254 |
tol=0
|
| 255 |
for i, (x,y,z) in enumerate(S.seeds):
|
| 256 |
on_plane = (
|
|
@@ -264,7 +325,6 @@ def _interaction_2d():
|
|
| 264 |
dr.ellipse((u-r,v-r,u+r,v+r), fill=SEED_RGB, outline=(40,140,100), width=1)
|
| 265 |
dr.text((u+6, v-8), f"{i+1}", fill=(30,30,30))
|
| 266 |
|
| 267 |
-
# bbox temp points (in bbox mode)
|
| 268 |
if S.bbox_mode:
|
| 269 |
for i, (x,y,z) in enumerate(S.bbox_points):
|
| 270 |
on_plane = (
|
|
@@ -279,7 +339,6 @@ def _interaction_2d():
|
|
| 279 |
text = "P1" if i == 0 else "P2"
|
| 280 |
dr.text((u+10, v-10), text, fill=BBOX_RGB)
|
| 281 |
|
| 282 |
-
# draw completed bboxes
|
| 283 |
for (x1,y1,z1,x2,y2,z2) in S.bboxes:
|
| 284 |
if view=="axial":
|
| 285 |
curr_z = S.slice["axial"]
|
|
@@ -289,6 +348,7 @@ def _interaction_2d():
|
|
| 289 |
dr.rectangle([u1,v1,u2,v2], outline=(0,255,0), width=2)
|
| 290 |
|
| 291 |
return im
|
|
|
|
| 292 |
# ===================== segmentation ========================================
|
| 293 |
def _segment_with_model():
|
| 294 |
if PREDICTOR is None and not _init_predictor_once():
|
|
@@ -299,18 +359,13 @@ def _segment_with_model():
|
|
| 299 |
PREDICTOR.set_image(img)
|
| 300 |
PREDICTOR.set_target_buffer(np.zeros_like(img[0], np.float32))
|
| 301 |
PREDICTOR._finish_preprocessing_and_initialize_interactions()
|
| 302 |
-
|
| 303 |
-
# Add point interactions
|
| 304 |
for (x,y,z) in S.seeds:
|
| 305 |
PREDICTOR.add_point_interaction(x, y, z, foreground=True)
|
| 306 |
-
|
| 307 |
-
# Add bbox interactions
|
| 308 |
for (x1,y1,z1,x2,y2,z2) in S.bboxes:
|
| 309 |
PREDICTOR.add_bbox_interaction(
|
| 310 |
min(x1,x2), min(y1,y2), min(z1,z2),
|
| 311 |
max(x1,x2), max(y1,y2), max(z1,z2)
|
| 312 |
)
|
| 313 |
-
|
| 314 |
PREDICTOR._predict_without_interaction()
|
| 315 |
pred = (PREDICTOR.target_buffer.astype(np.float32) > 0.5).astype(np.uint8)
|
| 316 |
if pred.shape != S.shape:
|
|
@@ -320,6 +375,7 @@ def _segment_with_model():
|
|
| 320 |
except Exception as e:
|
| 321 |
print(f"[MODEL] inference failed: {e}")
|
| 322 |
return None, "model-error"
|
|
|
|
| 323 |
def _segment_fallback():
|
| 324 |
if (not S.seeds and not S.bboxes) or not HAVE_SCIPY:
|
| 325 |
return None, "no-interactions-or-scipy"
|
|
@@ -344,6 +400,7 @@ def _segment_fallback():
|
|
| 344 |
mask=ndi.binary_closing(mask,iterations=1); mask=ndi.binary_opening(mask,iterations=1)
|
| 345 |
print(f"[FB] seg {time.time()-t0:.3f}s | vox={int(mask.sum())}")
|
| 346 |
return (mask>0).astype(np.uint8), "ok"
|
|
|
|
| 347 |
def do_segment():
|
| 348 |
pred, tag = _segment_with_model()
|
| 349 |
if pred is None:
|
|
@@ -351,7 +408,7 @@ def do_segment():
|
|
| 351 |
S.pred = pred if pred is not None else None
|
| 352 |
print(f"[SEG] done: {tag}")
|
| 353 |
return "OK" if S.pred is not None else "Failed"
|
| 354 |
-
|
| 355 |
def save_prediction():
|
| 356 |
if S.pred is None:
|
| 357 |
return "No prediction to save", None
|
|
@@ -360,21 +417,18 @@ def save_prediction():
|
|
| 360 |
try:
|
| 361 |
tmp_dir = Path(tempfile.mkdtemp(prefix="menrt_output_"))
|
| 362 |
out_path = tmp_dir / f"{S.case_id}_pred.nii.gz"
|
| 363 |
-
|
| 364 |
affine = S.ref_affine if S.ref_affine is not None else np.eye(4)
|
| 365 |
header = S.ref_header.copy() if S.ref_header is not None else None
|
| 366 |
-
|
| 367 |
nii_img = nib.Nifti1Image(S.pred.astype(np.uint8), affine, header=header)
|
| 368 |
nib.save(nii_img, str(out_path))
|
| 369 |
-
|
| 370 |
print(f"[SAVE] {out_path}")
|
| 371 |
return "Saved successfully!", str(out_path)
|
| 372 |
except Exception as e:
|
| 373 |
print(f"[SAVE] error: {e}")
|
| 374 |
return f"Save failed: {e}", None
|
| 375 |
-
|
|
|
|
| 376 |
def _seed_rows():
|
| 377 |
-
"""DataFrame rows: [[#, type, view, x, y, z], ...]"""
|
| 378 |
rows=[]
|
| 379 |
for i,(x,y,z) in enumerate(S.seeds):
|
| 380 |
v = S.seed_views[i] if i < len(S.seed_views) else ""
|
|
@@ -382,8 +436,8 @@ def _seed_rows():
|
|
| 382 |
for i,(x1,y1,z1,x2,y2,z2) in enumerate(S.bboxes):
|
| 383 |
rows.append([len(S.seeds)+i+1, "bbox", "3D", f"{x1}-{x2}", f"{y1}-{y2}", f"{z1}-{z2}"])
|
| 384 |
return rows
|
|
|
|
| 385 |
def _seed_dropdown_options():
|
| 386 |
-
"""Dropdown options & default value"""
|
| 387 |
opts=[]
|
| 388 |
for i,(x,y,z) in enumerate(S.seeds):
|
| 389 |
v = S.seed_views[i] if i < len(S.seed_views) else "axial"
|
|
@@ -391,6 +445,7 @@ def _seed_dropdown_options():
|
|
| 391 |
for i,(x1,y1,z1,x2,y2,z2) in enumerate(S.bboxes):
|
| 392 |
opts.append(f"3D → bbox {i+1} → ({x1},{y1},{z1})-({x2},{y2},{z2})")
|
| 393 |
return opts
|
|
|
|
| 394 |
def _debug_widgets(current_idx=None):
|
| 395 |
rows = _seed_rows()
|
| 396 |
df_upd = gr.update(value=rows)
|
|
@@ -401,7 +456,7 @@ def _debug_widgets(current_idx=None):
|
|
| 401 |
val = (opts[current_idx] if (0 <= current_idx < len(opts)) else (opts[-1] if opts else None))
|
| 402 |
dd_upd = gr.update(choices=opts, value=val)
|
| 403 |
return df_upd, dd_upd
|
| 404 |
-
|
| 405 |
def _figs_and_imgs():
|
| 406 |
top_ax=render_top("axial")
|
| 407 |
top_sg=render_top("sagittal")
|
|
@@ -410,37 +465,48 @@ def _figs_and_imgs():
|
|
| 410 |
ax_pr = _axial_overlay_fill(S.pred, PR_RGBA_FILL)
|
| 411 |
inter2d = _interaction_2d()
|
| 412 |
return top_ax, top_sg, top_co, ax_gt, ax_pr, inter2d
|
|
|
|
| 413 |
def _bar_ranges_and_values():
|
| 414 |
X,Y,Z=S.shape
|
| 415 |
return (gr.update(minimum=0,maximum=Z-1,value=S.slice["axial"],visible=True),
|
| 416 |
gr.update(minimum=0,maximum=X-1,value=S.slice["sagittal"],visible=True),
|
| 417 |
gr.update(minimum=0,maximum=Y-1,value=S.slice["coronal"],visible=True))
|
| 418 |
-
|
| 419 |
def _parse_evt_xy(evt):
|
| 420 |
-
"""
|
| 421 |
-
|
| 422 |
-
""
|
| 423 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 424 |
try:
|
|
|
|
| 425 |
if hasattr(evt, "index") and evt.index is not None:
|
| 426 |
ix = evt.index
|
|
|
|
| 427 |
if isinstance(ix, (list, tuple)) and len(ix) >= 2:
|
| 428 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 429 |
if hasattr(evt, "x") and hasattr(evt, "y"):
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
except Exception:
|
| 442 |
-
pass
|
| 443 |
return None
|
|
|
|
| 444 |
def _disp_to_vol(view,u,v):
|
| 445 |
X,Y,Z=S.shape; w,h=S.disp_wh[view]
|
| 446 |
if w<=0 or h<=0: w=h=S.render_px
|
|
@@ -458,7 +524,7 @@ def _disp_to_vol(view,u,v):
|
|
| 458 |
y = S.slice["coronal"]
|
| 459 |
x=max(0,min(X-1,x)); y=max(0,min(Y-1,y)); z=max(0,min(Z-1,z))
|
| 460 |
return x,y,z
|
| 461 |
-
|
| 462 |
def _thumb_from_case(case_id,px=96):
|
| 463 |
case_dir=DATA_ROOT/case_id
|
| 464 |
pngs = sorted(glob.glob(str(case_dir / "png_axial" / "*.png")))
|
|
@@ -479,13 +545,15 @@ def _thumb_from_case(case_id,px=96):
|
|
| 479 |
im = Image.fromarray((mid*255).astype(np.uint8)).resize((px,px),Image.BILINEAR)
|
| 480 |
im = _rot90_if_needed(im)
|
| 481 |
return ImageOps.expand(im,border=1,fill=200)
|
|
|
|
| 482 |
# ---------- callbacks ----------
|
| 483 |
def on_load(case_id):
|
| 484 |
load_case(case_id)
|
| 485 |
preload_model_in_background()
|
| 486 |
imgs=_figs_and_imgs(); bars=_bar_ranges_and_values(); dbg=_debug_widgets()
|
| 487 |
-
return (*imgs,*bars,*dbg, gr.update(value="Point"), None, "")
|
| 488 |
-
|
|
|
|
| 489 |
idx=0
|
| 490 |
if hasattr(evt,"index"):
|
| 491 |
ix=evt.index; idx=int(ix[0] if isinstance(ix,(list,tuple)) else ix)
|
|
@@ -493,28 +561,42 @@ def on_gallery_select(evt: gr.events.SelectData):
|
|
| 493 |
cid=EXAMPLES[idx]
|
| 494 |
out=on_load(cid)
|
| 495 |
return (*out, gr.update(value=cid))
|
|
|
|
| 496 |
def _click_common(view, evt: gr.SelectData):
|
| 497 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 498 |
xy=_parse_evt_xy(evt)
|
| 499 |
if xy is None:
|
| 500 |
-
|
|
|
|
|
|
|
|
|
|
| 501 |
u,v=xy
|
| 502 |
x,y,z=_disp_to_vol(view,u,v)
|
|
|
|
| 503 |
|
| 504 |
if S.bbox_mode:
|
| 505 |
-
# BBox mode
|
| 506 |
S.bbox_points.append((x,y,z))
|
| 507 |
if len(S.bbox_points) == 2:
|
| 508 |
p1, p2 = S.bbox_points
|
| 509 |
S.bboxes.append((*p1, *p2))
|
| 510 |
S.bbox_points = []
|
| 511 |
print(f"[UI] bbox created: {S.bboxes[-1]}")
|
|
|
|
| 512 |
else:
|
| 513 |
print(f"[UI] bbox point {len(S.bbox_points)}/2")
|
|
|
|
| 514 |
else:
|
| 515 |
-
# Point mode
|
| 516 |
S.seeds.append((x,y,z)); S.seed_views.append(view)
|
| 517 |
print(f"[UI] seed+ {(x,y,z)} total={len(S.seeds)}")
|
|
|
|
| 518 |
|
| 519 |
S.cross={"x":x,"y":y,"z":z}
|
| 520 |
S.slice={"sagittal":x,"coronal":y,"axial":z}
|
|
@@ -522,18 +604,32 @@ def _click_common(view, evt: gr.SelectData):
|
|
| 522 |
imgs=_figs_and_imgs(); bars=_bar_ranges_and_values()
|
| 523 |
idx = len(S.seeds) + len(S.bboxes) - 1
|
| 524 |
dbg=_debug_widgets(current_idx=idx)
|
| 525 |
-
return (*imgs,*bars,*dbg, gr.update(value="Point" if not S.bbox_mode else "BBox"), None,
|
| 526 |
-
|
| 527 |
-
def
|
| 528 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 529 |
def on_seg_button():
|
| 530 |
-
msg=do_segment()
|
| 531 |
-
return (*_figs_and_imgs(),*_bar_ranges_and_values(),*_debug_widgets(),
|
|
|
|
|
|
|
| 532 |
def on_clear():
|
| 533 |
S.seeds=[]; S.seed_views=[]; S.pred=None
|
| 534 |
S.bboxes=[]; S.bbox_points=[]
|
| 535 |
print("[UI] cleared all")
|
| 536 |
-
return (*_figs_and_imgs(),*_bar_ranges_and_values(),*_debug_widgets(),
|
|
|
|
|
|
|
| 537 |
def on_undo():
|
| 538 |
if S.bbox_mode and S.bbox_points:
|
| 539 |
S.bbox_points.pop()
|
|
@@ -542,49 +638,66 @@ def on_undo():
|
|
| 542 |
elif S.seeds:
|
| 543 |
S.seeds.pop(); S.seed_views.pop() if S.seed_views else None
|
| 544 |
print("[UI] undo")
|
| 545 |
-
return (*_figs_and_imgs(),*_bar_ranges_and_values(),*_debug_widgets(),
|
|
|
|
|
|
|
| 546 |
def on_mode_toggle(mode):
|
| 547 |
S.bbox_mode = (mode == "BBox")
|
| 548 |
S.bbox_points = []
|
| 549 |
print(f"[UI] mode -> {mode}")
|
| 550 |
-
return (*_figs_and_imgs(),*_bar_ranges_and_values(),*_debug_widgets(),
|
|
|
|
|
|
|
| 551 |
def on_save():
|
| 552 |
msg, path = save_prediction()
|
| 553 |
if path:
|
| 554 |
-
return (*_figs_and_imgs(),*_bar_ranges_and_values(),*_debug_widgets(),
|
| 555 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 556 |
def on_z_release(z_idx):
|
| 557 |
S.slice["axial"]=int(z_idx); S.cross["z"]=S.slice["axial"]; S.active_view="axial"
|
| 558 |
-
return (*_figs_and_imgs(),*_bar_ranges_and_values(),*_debug_widgets(),
|
|
|
|
|
|
|
| 559 |
def on_x_release(x_idx):
|
| 560 |
S.slice["sagittal"]=int(x_idx); S.cross["x"]=S.slice["sagittal"]; S.active_view="sagittal"
|
| 561 |
-
return (*_figs_and_imgs(),*_bar_ranges_and_values(),*_debug_widgets(),
|
|
|
|
|
|
|
| 562 |
def on_y_release(y_idx):
|
| 563 |
S.slice["coronal"]=int(y_idx); S.cross["y"]=S.slice["coronal"]; S.active_view="coronal"
|
| 564 |
-
return (*_figs_and_imgs(),*_bar_ranges_and_values(),*_debug_widgets(),
|
|
|
|
|
|
|
| 565 |
def _jump_to_idx(idx:int):
|
| 566 |
-
"""Common: jump to selected index from log."""
|
| 567 |
total = len(S.seeds) + len(S.bboxes)
|
| 568 |
if not (0 <= idx < total):
|
| 569 |
-
return (*_figs_and_imgs(),*_bar_ranges_and_values(),*_debug_widgets(),
|
|
|
|
| 570 |
if idx < len(S.seeds):
|
| 571 |
x,y,z = S.seeds[idx]
|
| 572 |
view = S.seed_views[idx] if idx < len(S.seed_views) else "axial"
|
| 573 |
S.cross={"x":x,"y":y,"z":z}
|
| 574 |
S.slice={"sagittal":x,"coronal":y,"axial":z}
|
| 575 |
S.active_view=view
|
| 576 |
-
return (*_figs_and_imgs(),*_bar_ranges_and_values(),*_debug_widgets(current_idx=idx),
|
|
|
|
|
|
|
| 577 |
def on_seed_df_select(evt):
|
| 578 |
-
"""DataFrame select: evt.index -> (row, col)"""
|
| 579 |
try:
|
| 580 |
row = int(evt.index[0]) if hasattr(evt, "index") else 0
|
| 581 |
except Exception:
|
| 582 |
row = 0
|
| 583 |
return _jump_to_idx(row)
|
|
|
|
| 584 |
def on_seed_dd_change(val):
|
| 585 |
-
"""Dropdown change: id parsing"""
|
| 586 |
if not val:
|
| 587 |
-
return (*_figs_and_imgs(),*_bar_ranges_and_values(),*_debug_widgets(),
|
|
|
|
| 588 |
try:
|
| 589 |
if "point" in val:
|
| 590 |
p = val.split("point")[1].strip()
|
|
@@ -595,6 +708,7 @@ def on_seed_dd_change(val):
|
|
| 595 |
except Exception:
|
| 596 |
idx = len(S.seeds) + len(S.bboxes) - 1
|
| 597 |
return _jump_to_idx(idx)
|
|
|
|
| 598 |
# ---------- UI ----------
|
| 599 |
css = """
|
| 600 |
:root {
|
|
@@ -606,11 +720,10 @@ css = """
|
|
| 606 |
.tiny .gr-slider input[type="range"]{height:6px}
|
| 607 |
.tiny .gr-form{gap:6px}
|
| 608 |
.smallnote{color:#3e6285;font-size:12px;margin-top:6px}
|
| 609 |
-
.dfshort { max-height: 190px; overflow: auto; }
|
| 610 |
-
.dfshort table { font-size: 12px; }
|
| 611 |
"""
|
| 612 |
-
|
| 613 |
THUMBS=[_thumb_from_case(cid,px=96) for cid in EXAMPLES]
|
|
|
|
| 614 |
with gr.Blocks(css=css, title="Interactive-MEN-RT", theme=gr.themes.Soft(), analytics_enabled=False) as demo:
|
| 615 |
gr.Markdown("""
|
| 616 |
<h3 style='color:#114b8b'>Interactive-MEN-RT Segmentation</h3>
|
|
@@ -619,10 +732,11 @@ with gr.Blocks(css=css, title="Interactive-MEN-RT", theme=gr.themes.Soft(), anal
|
|
| 619 |
<b>⚠️ Research only — Not for clinical use</b>
|
| 620 |
</div>
|
| 621 |
""")
|
|
|
|
| 622 |
with gr.Row():
|
| 623 |
with gr.Column(scale=1, min_width=280, elem_classes=["round"]):
|
| 624 |
-
gr.Markdown("<div class='section'>
|
| 625 |
-
gallery = gr.Gallery(value=THUMBS, columns=
|
| 626 |
case_dd = gr.Dropdown(choices=EXAMPLES, value=EXAMPLES[0], label="Case")
|
| 627 |
|
| 628 |
mode_radio = gr.Radio(["Point", "BBox"], value="Point", label="Interaction Mode")
|
|
@@ -643,7 +757,7 @@ with gr.Blocks(css=css, title="Interactive-MEN-RT", theme=gr.themes.Soft(), anal
|
|
| 643 |
wrap=True,
|
| 644 |
row_count=(0, "dynamic")
|
| 645 |
)
|
| 646 |
-
seed_dd = gr.Dropdown(choices=[], value=None, label="Go to
|
| 647 |
|
| 648 |
pred_file = gr.File(label="Download Prediction", visible=False)
|
| 649 |
status_text = gr.Textbox(label="Status", value="", interactive=False, lines=1)
|
|
@@ -660,43 +774,33 @@ with gr.Blocks(css=css, title="Interactive-MEN-RT", theme=gr.themes.Soft(), anal
|
|
| 660 |
coronal = gr.Image(type="pil", interactive=True, height=RENDER_PX_DEFAULT+8, label="Coronal (Y)")
|
| 661 |
y_bar = gr.Slider(0,1,value=0,step=1,label="Y", elem_classes=["tiny"])
|
| 662 |
with gr.Row():
|
| 663 |
-
out_ax_gt = gr.Image(type="pil", interactive=False, height=RENDER_PX_DEFAULT+8, label="
|
| 664 |
-
out_ax_pr = gr.Image(type="pil", interactive=False, height=RENDER_PX_DEFAULT+8, label="
|
| 665 |
-
inter2d = gr.Image(type="pil", interactive=False, height=RENDER_PX_DEFAULT+8, label="
|
| 666 |
|
| 667 |
-
# outputs
|
| 668 |
outputs = [axial, sagittal, coronal, out_ax_gt, out_ax_pr, inter2d,
|
| 669 |
z_bar, x_bar, y_bar, seeds_df, seed_dd, mode_radio, pred_file, status_text]
|
| 670 |
|
| 671 |
-
# initial
|
| 672 |
demo.load(lambda: on_load(EXAMPLES[0]), [], outputs)
|
| 673 |
-
|
| 674 |
-
# case change
|
| 675 |
case_dd.change(lambda cid: on_load(cid), [case_dd], outputs)
|
| 676 |
-
|
| 677 |
-
# gallery select
|
| 678 |
gallery.select(on_gallery_select, [], outputs + [case_dd])
|
| 679 |
|
| 680 |
-
#
|
| 681 |
axial.select(on_axial_select, [], outputs)
|
| 682 |
sagittal.select(on_sagittal_select, [], outputs)
|
| 683 |
coronal.select(on_coronal_select, [], outputs)
|
| 684 |
|
| 685 |
-
# bars
|
| 686 |
z_bar.release(on_z_release, [z_bar], outputs)
|
| 687 |
x_bar.release(on_x_release, [x_bar], outputs)
|
| 688 |
y_bar.release(on_y_release, [y_bar], outputs)
|
| 689 |
|
| 690 |
-
# mode toggle
|
| 691 |
mode_radio.change(on_mode_toggle, [mode_radio], outputs)
|
| 692 |
|
| 693 |
-
# buttons
|
| 694 |
seg_btn.click(on_seg_button, [], outputs)
|
| 695 |
save_btn.click(on_save, [], outputs)
|
| 696 |
clr_btn.click(on_clear, [], outputs)
|
| 697 |
undo_btn.click(on_undo, [], outputs)
|
| 698 |
|
| 699 |
-
# debug list selection
|
| 700 |
seeds_df.select(on_seed_df_select, [], outputs)
|
| 701 |
seed_dd.change(on_seed_dd_change, [seed_dd], outputs)
|
| 702 |
|
|
|
|
| 5 |
import numpy as np
|
| 6 |
from PIL import Image, ImageDraw, ImageOps
|
| 7 |
import gradio as gr
|
| 8 |
+
|
| 9 |
+
print("\n" + "="*60)
|
| 10 |
+
print("🔍 INTERACTIVE-MEN-RT DEMO DEBUG INFO")
|
| 11 |
+
print("="*60)
|
| 12 |
+
print(f"📦 Gradio version: {gr.__version__}")
|
| 13 |
+
print(f"📁 Current directory: {os.getcwd()}")
|
| 14 |
+
print(f"📁 Directory contents: {os.listdir('.')}")
|
| 15 |
+
|
| 16 |
+
DATA_ROOT = Path("./samples")
|
| 17 |
+
print(f"\n📂 DATA_ROOT: {DATA_ROOT}")
|
| 18 |
+
print(f"📂 DATA_ROOT exists: {DATA_ROOT.exists()}")
|
| 19 |
+
|
| 20 |
+
if DATA_ROOT.exists():
|
| 21 |
+
print(f"📂 DATA_ROOT contents: {list(DATA_ROOT.iterdir())}")
|
| 22 |
+
EXAMPLES_CHECK = ["BraTS-MEN-RT-0071-1"]
|
| 23 |
+
for case in EXAMPLES_CHECK:
|
| 24 |
+
case_dir = DATA_ROOT / case
|
| 25 |
+
print(f"\n 📦 Case: {case}")
|
| 26 |
+
print(f" Exists: {case_dir.exists()}")
|
| 27 |
+
if case_dir.exists():
|
| 28 |
+
files = list(case_dir.iterdir())
|
| 29 |
+
print(f" Files: {[f.name for f in files]}")
|
| 30 |
+
has_t1c = any(f.name.endswith('_t1c.nii.gz') for f in files)
|
| 31 |
+
print(f" ✓ Has T1c: {has_t1c}")
|
| 32 |
+
else:
|
| 33 |
+
print("❌ samples folder NOT FOUND!")
|
| 34 |
+
print("\n" + "="*60 + "\n")
|
| 35 |
+
|
| 36 |
# ---------- optional deps ----------
|
| 37 |
try:
|
| 38 |
import nibabel as nib
|
| 39 |
HAVE_NIB = True
|
| 40 |
+
print("✓ nibabel available")
|
| 41 |
+
except Exception as e:
|
| 42 |
HAVE_NIB = False
|
| 43 |
+
print(f"✗ nibabel not available: {e}")
|
| 44 |
+
|
| 45 |
try:
|
| 46 |
from scipy import ndimage as ndi
|
| 47 |
HAVE_SCIPY = True
|
| 48 |
+
print("✓ scipy available")
|
| 49 |
+
except Exception as e:
|
| 50 |
HAVE_SCIPY = False
|
| 51 |
+
print(f"✗ scipy not available: {e}")
|
| 52 |
+
|
| 53 |
# ---------- model predictor ----------
|
| 54 |
PREDICTOR = None
|
| 55 |
DEVICE = "cuda:0"
|
| 56 |
+
|
| 57 |
from huggingface_hub import snapshot_download
|
| 58 |
+
|
| 59 |
+
try:
|
| 60 |
+
_repo_root = snapshot_download("hanjang/Interactive-MEN-RT",
|
| 61 |
+
allow_patterns=["nnUNetInteractionTrainer__nnUNetPlans__3d_fullres_scratch/**"])
|
| 62 |
+
CKPT = os.path.join(_repo_root, "nnUNetInteractionTrainer__nnUNetPlans__3d_fullres_scratch")
|
| 63 |
+
print(f"[INFO] Checkpoint path: {CKPT}")
|
| 64 |
+
if os.path.exists(CKPT):
|
| 65 |
+
contents = os.listdir(CKPT)
|
| 66 |
+
print(f"[INFO] Checkpoint contents: {contents}")
|
| 67 |
+
fold_0 = os.path.join(CKPT, "fold_0")
|
| 68 |
+
if os.path.exists(fold_0):
|
| 69 |
+
print(f"[INFO] fold_0 contents: {os.listdir(fold_0)}")
|
| 70 |
+
else:
|
| 71 |
+
print(f"[ERROR] Checkpoint path does not exist!")
|
| 72 |
+
CKPT = None
|
| 73 |
+
except Exception as e:
|
| 74 |
+
print(f"[ERROR] Failed to download checkpoint: {e}")
|
| 75 |
+
CKPT = None
|
| 76 |
+
|
| 77 |
for env in ("nnUNet_raw", "nnUNet_preprocessed", "nnUNet_results"):
|
| 78 |
os.environ.setdefault(env, tempfile.mkdtemp(prefix=f"{env}_"))
|
| 79 |
+
|
| 80 |
def _init_predictor_once():
|
|
|
|
| 81 |
global PREDICTOR
|
| 82 |
if PREDICTOR is not None:
|
| 83 |
return True
|
| 84 |
+
if CKPT is None:
|
| 85 |
+
print("[WARN] No checkpoint available, will use fallback only")
|
| 86 |
+
return False
|
| 87 |
try:
|
| 88 |
import torch
|
| 89 |
from Interactive_MEN_RT_predictor import InteractiveMENRTPredictor
|
|
|
|
| 95 |
model_training_output_dir=CKPT, use_fold=0, checkpoint_name="checkpoint_best.pth"
|
| 96 |
)
|
| 97 |
PREDICTOR = pred
|
|
|
|
| 98 |
try:
|
| 99 |
if torch.cuda.is_available():
|
| 100 |
x = np.zeros((1, 8, 8, 8), np.float32)
|
|
|
|
| 110 |
except Exception as e:
|
| 111 |
print(f"[MODEL] init failed: {e}")
|
| 112 |
return False
|
| 113 |
+
|
| 114 |
def preload_model_in_background():
|
| 115 |
threading.Thread(target=_init_predictor_once, daemon=True).start()
|
| 116 |
+
|
| 117 |
# ---------- config ----------
|
| 118 |
+
EXAMPLES = ["BraTS-MEN-RT-0071-1"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
RENDER_PX_DEFAULT = 384
|
| 120 |
+
ROT_CCW = True
|
| 121 |
+
|
| 122 |
# colors
|
| 123 |
ACCENT_HEX = "#1e90ff"
|
| 124 |
CROSS_RGB = (30, 144, 255)
|
| 125 |
+
GT_RGBA_FILL = (255, 215, 0, 128)
|
| 126 |
+
PR_RGBA_FILL = (255, 60, 60, 128)
|
| 127 |
SEED_RGB = (89, 224, 154)
|
| 128 |
+
BBOX_RGB = (255, 140, 0)
|
| 129 |
+
|
| 130 |
# ---------- state ----------
|
| 131 |
class State:
|
| 132 |
def __init__(self):
|
|
|
|
| 135 |
self.case_id=None; self.loaded=False
|
| 136 |
self.cross={"x":0,"y":0,"z":0}
|
| 137 |
self.slice={"axial":0,"sagittal":0,"coronal":0}
|
| 138 |
+
self.seeds=[]
|
| 139 |
+
self.seed_views=[]
|
| 140 |
self.render_px=RENDER_PX_DEFAULT
|
| 141 |
self.disp_wh={"axial":(RENDER_PX_DEFAULT,RENDER_PX_DEFAULT),
|
| 142 |
"sagittal":(RENDER_PX_DEFAULT,RENDER_PX_DEFAULT),
|
| 143 |
"coronal":(RENDER_PX_DEFAULT,RENDER_PX_DEFAULT)}
|
| 144 |
+
self.active_view="axial"
|
| 145 |
+
self.bbox_mode = False
|
| 146 |
+
self.bbox_points = []
|
| 147 |
+
self.bboxes = []
|
|
|
|
|
|
|
| 148 |
self.ref_affine = None
|
| 149 |
self.ref_header = None
|
| 150 |
+
|
| 151 |
S = State()
|
| 152 |
+
|
| 153 |
# ---------- utils ----------
|
| 154 |
def _norm01(a):
|
| 155 |
a=a.astype(np.float32)
|
| 156 |
p2,p98=np.percentile(a,2),np.percentile(a,98)
|
| 157 |
if p98<=p2: p2,p98=float(a.min()),float(a.max()) or 1.0
|
| 158 |
return np.clip((a-p2)/max(p98-p2,1e-6),0,1)
|
| 159 |
+
|
| 160 |
def _resize_slice_nearest(arr2d,w,h):
|
| 161 |
im=Image.fromarray(arr2d); im=im.resize((w,h),Image.NEAREST); return np.array(im)
|
| 162 |
+
|
| 163 |
def _rot90_if_needed(img_or_np):
|
| 164 |
if not ROT_CCW: return img_or_np
|
| 165 |
if isinstance(img_or_np, Image.Image): return img_or_np.rotate(90, expand=True)
|
| 166 |
return np.rot90(img_or_np, k=1)
|
| 167 |
+
|
| 168 |
# ---------- IO ----------
|
| 169 |
def _load_png_stack(case_dir):
|
| 170 |
pngs = sorted(glob.glob(str(case_dir / "png_axial" / "*.png")))
|
|
|
|
| 177 |
vol=_norm01(vol)
|
| 178 |
print(f"[PIL] {len(pngs)} slices -> {vol.shape} in {time.time()-t0:.2f}s")
|
| 179 |
return vol, None, None
|
| 180 |
+
|
| 181 |
def _load_nifti(case_dir,case_id,ds=1):
|
| 182 |
if not HAVE_NIB: return None, None, None
|
| 183 |
p=case_dir/f"{case_id}_t1c.nii.gz"
|
|
|
|
| 188 |
arr=_norm01(arr)
|
| 189 |
print(f"[NIfTI] {case_id} -> {arr.shape} in {time.time()-t0:.2f}s")
|
| 190 |
return arr, nii.affine, nii.header
|
| 191 |
+
|
| 192 |
def _resample_mask_to_vol_shape(mask_xyz, vol_shape_xyz):
|
| 193 |
mx,my,mz=mask_xyz.shape; vx,vy,vz=vol_shape_xyz
|
| 194 |
out=np.zeros((vx,vy,vz),dtype=np.uint8)
|
|
|
|
| 198 |
im=Image.fromarray(sl).resize((vy,vx),Image.NEAREST)
|
| 199 |
out[:,:,k]=(np.array(im)>0).astype(np.uint8)
|
| 200 |
return out
|
| 201 |
+
|
| 202 |
def _load_gt(case_dir,case_id,vol_shape):
|
| 203 |
candidates = [
|
| 204 |
f"{case_id}_gtv.nii.gz", f"{case_id}_seg.nii.gz", f"{case_id}_gt.nii.gz",
|
|
|
|
| 219 |
print(f"[GT] load error {p.name}: {e}")
|
| 220 |
print("[GT] not found.")
|
| 221 |
return None
|
| 222 |
+
|
| 223 |
def load_case(case_id):
|
| 224 |
case_dir=DATA_ROOT/case_id
|
| 225 |
vol, affine, header = _load_png_stack(case_dir)
|
|
|
|
| 248 |
S.active_view="axial"
|
| 249 |
S.loaded=True
|
| 250 |
print(f"[LOAD] {case_id} | shape={S.shape}")
|
| 251 |
+
|
| 252 |
+
# ---------- 2D rendering ----------
|
| 253 |
def _slice2d(view):
|
| 254 |
if view=="axial": sl=S.vol[:,:,S.slice["axial"]]
|
| 255 |
elif view=="sagittal": sl=S.vol[S.slice["sagittal"],:,:].T
|
| 256 |
else: sl=S.vol[:,S.slice["coronal"],:].T
|
| 257 |
return _rot90_if_needed(sl)
|
| 258 |
+
|
| 259 |
def _cross_pix_on_rot(view,w,h,x=None,y=None,z=None):
|
| 260 |
X,Y,Z=S.shape
|
| 261 |
if x is None: x=S.cross["x"]
|
|
|
|
| 271 |
u=int(round(z*(w-1)/max(Z-1,1)))
|
| 272 |
v=int(round((X-1-x)*(h-1)/max(X-1,1)))
|
| 273 |
return u,v
|
| 274 |
+
|
| 275 |
def _draw_cross(img_draw, view, w, h):
|
| 276 |
u,v=_cross_pix_on_rot(view,w,h)
|
| 277 |
img_draw.line([(u,0),(u,h)],fill=CROSS_RGB,width=1)
|
| 278 |
img_draw.line([(0,v),(w,v)],fill=CROSS_RGB,width=1)
|
| 279 |
img_draw.ellipse((u-5,v-5,u+5,v+5),fill=(255,255,255),outline=ACCENT_HEX,width=2)
|
| 280 |
+
|
| 281 |
def render_top(view):
|
| 282 |
if not S.loaded: return None
|
| 283 |
sl=_slice2d(view)
|
|
|
|
| 287 |
_draw_cross(dr, view, w, h)
|
| 288 |
S.disp_wh[view]=im.size
|
| 289 |
return im
|
| 290 |
+
|
| 291 |
def _axial_mask2d_rot(mask3d):
|
| 292 |
if mask3d is None: return None
|
| 293 |
m = mask3d[:,:,S.slice["axial"]].astype(np.uint8)
|
| 294 |
m = _rot90_if_needed(m)
|
| 295 |
return m
|
| 296 |
+
|
| 297 |
def _axial_overlay_fill(mask3d, rgba):
|
| 298 |
sl = _rot90_if_needed(S.vol[:,:,S.slice["axial"]])
|
| 299 |
base=Image.fromarray((sl*255).astype(np.uint8)).resize((S.render_px,S.render_px),Image.BILINEAR).convert("RGBA")
|
|
|
|
| 302 |
m2d = _resize_slice_nearest((m2d>0).astype(np.uint8), S.render_px, S.render_px)
|
| 303 |
over=np.zeros((S.render_px,S.render_px,4),dtype=np.uint8); over[m2d>0]=rgba
|
| 304 |
return Image.alpha_composite(base,Image.fromarray(over,"RGBA")).convert("RGB")
|
| 305 |
+
|
| 306 |
def _interaction_2d():
|
| 307 |
view = S.active_view
|
| 308 |
if not S.loaded: return None
|
|
|
|
| 310 |
im=Image.fromarray((sl*255).astype(np.uint8)).resize((S.render_px,S.render_px),Image.BILINEAR).convert("RGB")
|
| 311 |
w=h=S.render_px
|
| 312 |
dr=ImageDraw.Draw(im)
|
|
|
|
| 313 |
_draw_cross(dr, view, w, h)
|
| 314 |
+
|
| 315 |
tol=0
|
| 316 |
for i, (x,y,z) in enumerate(S.seeds):
|
| 317 |
on_plane = (
|
|
|
|
| 325 |
dr.ellipse((u-r,v-r,u+r,v+r), fill=SEED_RGB, outline=(40,140,100), width=1)
|
| 326 |
dr.text((u+6, v-8), f"{i+1}", fill=(30,30,30))
|
| 327 |
|
|
|
|
| 328 |
if S.bbox_mode:
|
| 329 |
for i, (x,y,z) in enumerate(S.bbox_points):
|
| 330 |
on_plane = (
|
|
|
|
| 339 |
text = "P1" if i == 0 else "P2"
|
| 340 |
dr.text((u+10, v-10), text, fill=BBOX_RGB)
|
| 341 |
|
|
|
|
| 342 |
for (x1,y1,z1,x2,y2,z2) in S.bboxes:
|
| 343 |
if view=="axial":
|
| 344 |
curr_z = S.slice["axial"]
|
|
|
|
| 348 |
dr.rectangle([u1,v1,u2,v2], outline=(0,255,0), width=2)
|
| 349 |
|
| 350 |
return im
|
| 351 |
+
|
| 352 |
# ===================== segmentation ========================================
|
| 353 |
def _segment_with_model():
|
| 354 |
if PREDICTOR is None and not _init_predictor_once():
|
|
|
|
| 359 |
PREDICTOR.set_image(img)
|
| 360 |
PREDICTOR.set_target_buffer(np.zeros_like(img[0], np.float32))
|
| 361 |
PREDICTOR._finish_preprocessing_and_initialize_interactions()
|
|
|
|
|
|
|
| 362 |
for (x,y,z) in S.seeds:
|
| 363 |
PREDICTOR.add_point_interaction(x, y, z, foreground=True)
|
|
|
|
|
|
|
| 364 |
for (x1,y1,z1,x2,y2,z2) in S.bboxes:
|
| 365 |
PREDICTOR.add_bbox_interaction(
|
| 366 |
min(x1,x2), min(y1,y2), min(z1,z2),
|
| 367 |
max(x1,x2), max(y1,y2), max(z1,z2)
|
| 368 |
)
|
|
|
|
| 369 |
PREDICTOR._predict_without_interaction()
|
| 370 |
pred = (PREDICTOR.target_buffer.astype(np.float32) > 0.5).astype(np.uint8)
|
| 371 |
if pred.shape != S.shape:
|
|
|
|
| 375 |
except Exception as e:
|
| 376 |
print(f"[MODEL] inference failed: {e}")
|
| 377 |
return None, "model-error"
|
| 378 |
+
|
| 379 |
def _segment_fallback():
|
| 380 |
if (not S.seeds and not S.bboxes) or not HAVE_SCIPY:
|
| 381 |
return None, "no-interactions-or-scipy"
|
|
|
|
| 400 |
mask=ndi.binary_closing(mask,iterations=1); mask=ndi.binary_opening(mask,iterations=1)
|
| 401 |
print(f"[FB] seg {time.time()-t0:.3f}s | vox={int(mask.sum())}")
|
| 402 |
return (mask>0).astype(np.uint8), "ok"
|
| 403 |
+
|
| 404 |
def do_segment():
|
| 405 |
pred, tag = _segment_with_model()
|
| 406 |
if pred is None:
|
|
|
|
| 408 |
S.pred = pred if pred is not None else None
|
| 409 |
print(f"[SEG] done: {tag}")
|
| 410 |
return "OK" if S.pred is not None else "Failed"
|
| 411 |
+
|
| 412 |
def save_prediction():
|
| 413 |
if S.pred is None:
|
| 414 |
return "No prediction to save", None
|
|
|
|
| 417 |
try:
|
| 418 |
tmp_dir = Path(tempfile.mkdtemp(prefix="menrt_output_"))
|
| 419 |
out_path = tmp_dir / f"{S.case_id}_pred.nii.gz"
|
|
|
|
| 420 |
affine = S.ref_affine if S.ref_affine is not None else np.eye(4)
|
| 421 |
header = S.ref_header.copy() if S.ref_header is not None else None
|
|
|
|
| 422 |
nii_img = nib.Nifti1Image(S.pred.astype(np.uint8), affine, header=header)
|
| 423 |
nib.save(nii_img, str(out_path))
|
|
|
|
| 424 |
print(f"[SAVE] {out_path}")
|
| 425 |
return "Saved successfully!", str(out_path)
|
| 426 |
except Exception as e:
|
| 427 |
print(f"[SAVE] error: {e}")
|
| 428 |
return f"Save failed: {e}", None
|
| 429 |
+
|
| 430 |
+
# ---------- helpers ----------
|
| 431 |
def _seed_rows():
|
|
|
|
| 432 |
rows=[]
|
| 433 |
for i,(x,y,z) in enumerate(S.seeds):
|
| 434 |
v = S.seed_views[i] if i < len(S.seed_views) else ""
|
|
|
|
| 436 |
for i,(x1,y1,z1,x2,y2,z2) in enumerate(S.bboxes):
|
| 437 |
rows.append([len(S.seeds)+i+1, "bbox", "3D", f"{x1}-{x2}", f"{y1}-{y2}", f"{z1}-{z2}"])
|
| 438 |
return rows
|
| 439 |
+
|
| 440 |
def _seed_dropdown_options():
|
|
|
|
| 441 |
opts=[]
|
| 442 |
for i,(x,y,z) in enumerate(S.seeds):
|
| 443 |
v = S.seed_views[i] if i < len(S.seed_views) else "axial"
|
|
|
|
| 445 |
for i,(x1,y1,z1,x2,y2,z2) in enumerate(S.bboxes):
|
| 446 |
opts.append(f"3D → bbox {i+1} → ({x1},{y1},{z1})-({x2},{y2},{z2})")
|
| 447 |
return opts
|
| 448 |
+
|
| 449 |
def _debug_widgets(current_idx=None):
|
| 450 |
rows = _seed_rows()
|
| 451 |
df_upd = gr.update(value=rows)
|
|
|
|
| 456 |
val = (opts[current_idx] if (0 <= current_idx < len(opts)) else (opts[-1] if opts else None))
|
| 457 |
dd_upd = gr.update(choices=opts, value=val)
|
| 458 |
return df_upd, dd_upd
|
| 459 |
+
|
| 460 |
def _figs_and_imgs():
|
| 461 |
top_ax=render_top("axial")
|
| 462 |
top_sg=render_top("sagittal")
|
|
|
|
| 465 |
ax_pr = _axial_overlay_fill(S.pred, PR_RGBA_FILL)
|
| 466 |
inter2d = _interaction_2d()
|
| 467 |
return top_ax, top_sg, top_co, ax_gt, ax_pr, inter2d
|
| 468 |
+
|
| 469 |
def _bar_ranges_and_values():
|
| 470 |
X,Y,Z=S.shape
|
| 471 |
return (gr.update(minimum=0,maximum=Z-1,value=S.slice["axial"],visible=True),
|
| 472 |
gr.update(minimum=0,maximum=X-1,value=S.slice["sagittal"],visible=True),
|
| 473 |
gr.update(minimum=0,maximum=Y-1,value=S.slice["coronal"],visible=True))
|
| 474 |
+
|
| 475 |
def _parse_evt_xy(evt):
|
| 476 |
+
"""강화된 이벤트 파싱"""
|
| 477 |
+
print(f"[DEBUG_EVT] Event received: {evt}")
|
| 478 |
+
print(f"[DEBUG_EVT] Event type: {type(evt)}")
|
| 479 |
+
print(f"[DEBUG_EVT] Event dir: {dir(evt)}")
|
| 480 |
+
|
| 481 |
+
if evt is None:
|
| 482 |
+
print("[DEBUG_EVT] Event is None!")
|
| 483 |
+
return None
|
| 484 |
+
|
| 485 |
try:
|
| 486 |
+
# Method 1: evt.index
|
| 487 |
if hasattr(evt, "index") and evt.index is not None:
|
| 488 |
ix = evt.index
|
| 489 |
+
print(f"[DEBUG_EVT] Found index: {ix}")
|
| 490 |
if isinstance(ix, (list, tuple)) and len(ix) >= 2:
|
| 491 |
+
result = int(ix[0]), int(ix[1])
|
| 492 |
+
print(f"[DEBUG_EVT] Parsed from index: {result}")
|
| 493 |
+
return result
|
| 494 |
+
|
| 495 |
+
# Method 2: evt.x, evt.y
|
| 496 |
if hasattr(evt, "x") and hasattr(evt, "y"):
|
| 497 |
+
x_val = getattr(evt, "x")
|
| 498 |
+
y_val = getattr(evt, "y")
|
| 499 |
+
print(f"[DEBUG_EVT] Found x={x_val}, y={y_val}")
|
| 500 |
+
if x_val is not None and y_val is not None:
|
| 501 |
+
result = int(x_val), int(y_val)
|
| 502 |
+
print(f"[DEBUG_EVT] Parsed from x,y: {result}")
|
| 503 |
+
return result
|
| 504 |
+
except Exception as e:
|
| 505 |
+
print(f"[DEBUG_EVT] Parse error: {e}")
|
| 506 |
+
|
| 507 |
+
print("[DEBUG_EVT] Failed to parse coordinates!")
|
|
|
|
|
|
|
| 508 |
return None
|
| 509 |
+
|
| 510 |
def _disp_to_vol(view,u,v):
|
| 511 |
X,Y,Z=S.shape; w,h=S.disp_wh[view]
|
| 512 |
if w<=0 or h<=0: w=h=S.render_px
|
|
|
|
| 524 |
y = S.slice["coronal"]
|
| 525 |
x=max(0,min(X-1,x)); y=max(0,min(Y-1,y)); z=max(0,min(Z-1,z))
|
| 526 |
return x,y,z
|
| 527 |
+
|
| 528 |
def _thumb_from_case(case_id,px=96):
|
| 529 |
case_dir=DATA_ROOT/case_id
|
| 530 |
pngs = sorted(glob.glob(str(case_dir / "png_axial" / "*.png")))
|
|
|
|
| 545 |
im = Image.fromarray((mid*255).astype(np.uint8)).resize((px,px),Image.BILINEAR)
|
| 546 |
im = _rot90_if_needed(im)
|
| 547 |
return ImageOps.expand(im,border=1,fill=200)
|
| 548 |
+
|
| 549 |
# ---------- callbacks ----------
|
| 550 |
def on_load(case_id):
|
| 551 |
load_case(case_id)
|
| 552 |
preload_model_in_background()
|
| 553 |
imgs=_figs_and_imgs(); bars=_bar_ranges_and_values(); dbg=_debug_widgets()
|
| 554 |
+
return (*imgs,*bars,*dbg, gr.update(value="Point"), None, "Loaded")
|
| 555 |
+
|
| 556 |
+
def on_gallery_select(evt: gr.SelectData):
|
| 557 |
idx=0
|
| 558 |
if hasattr(evt,"index"):
|
| 559 |
ix=evt.index; idx=int(ix[0] if isinstance(ix,(list,tuple)) else ix)
|
|
|
|
| 561 |
cid=EXAMPLES[idx]
|
| 562 |
out=on_load(cid)
|
| 563 |
return (*out, gr.update(value=cid))
|
| 564 |
+
|
| 565 |
def _click_common(view, evt: gr.SelectData):
|
| 566 |
+
print(f"[CLICK_HANDLER] Called for view={view}")
|
| 567 |
+
print(f"[CLICK_HANDLER] S.loaded={S.loaded}")
|
| 568 |
+
print(f"[CLICK_HANDLER] S.bbox_mode={S.bbox_mode}")
|
| 569 |
+
|
| 570 |
+
if not S.loaded:
|
| 571 |
+
print("[CLICK_HANDLER] Not loaded yet!")
|
| 572 |
+
return (*_figs_and_imgs(),*_bar_ranges_and_values(),*_debug_widgets(),
|
| 573 |
+
gr.update(value="Point" if not S.bbox_mode else "BBox"), None, "Not loaded")
|
| 574 |
+
|
| 575 |
xy=_parse_evt_xy(evt)
|
| 576 |
if xy is None:
|
| 577 |
+
print("[CLICK_HANDLER] Failed to parse event!")
|
| 578 |
+
return (*_figs_and_imgs(),*_bar_ranges_and_values(),*_debug_widgets(),
|
| 579 |
+
gr.update(value="Point" if not S.bbox_mode else "BBox"), None, "Parse failed")
|
| 580 |
+
|
| 581 |
u,v=xy
|
| 582 |
x,y,z=_disp_to_vol(view,u,v)
|
| 583 |
+
print(f"[CLICK_HANDLER] Clicked at display ({u},{v}) -> volume ({x},{y},{z})")
|
| 584 |
|
| 585 |
if S.bbox_mode:
|
|
|
|
| 586 |
S.bbox_points.append((x,y,z))
|
| 587 |
if len(S.bbox_points) == 2:
|
| 588 |
p1, p2 = S.bbox_points
|
| 589 |
S.bboxes.append((*p1, *p2))
|
| 590 |
S.bbox_points = []
|
| 591 |
print(f"[UI] bbox created: {S.bboxes[-1]}")
|
| 592 |
+
status = "BBox created!"
|
| 593 |
else:
|
| 594 |
print(f"[UI] bbox point {len(S.bbox_points)}/2")
|
| 595 |
+
status = f"BBox point {len(S.bbox_points)}/2"
|
| 596 |
else:
|
|
|
|
| 597 |
S.seeds.append((x,y,z)); S.seed_views.append(view)
|
| 598 |
print(f"[UI] seed+ {(x,y,z)} total={len(S.seeds)}")
|
| 599 |
+
status = f"Added point {len(S.seeds)}"
|
| 600 |
|
| 601 |
S.cross={"x":x,"y":y,"z":z}
|
| 602 |
S.slice={"sagittal":x,"coronal":y,"axial":z}
|
|
|
|
| 604 |
imgs=_figs_and_imgs(); bars=_bar_ranges_and_values()
|
| 605 |
idx = len(S.seeds) + len(S.bboxes) - 1
|
| 606 |
dbg=_debug_widgets(current_idx=idx)
|
| 607 |
+
return (*imgs,*bars,*dbg, gr.update(value="Point" if not S.bbox_mode else "BBox"), None, status)
|
| 608 |
+
|
| 609 |
+
def on_axial_select(evt: gr.SelectData):
|
| 610 |
+
print("[EVENT] on_axial_select triggered!")
|
| 611 |
+
return _click_common("axial", evt)
|
| 612 |
+
|
| 613 |
+
def on_sagittal_select(evt: gr.SelectData):
|
| 614 |
+
print("[EVENT] on_sagittal_select triggered!")
|
| 615 |
+
return _click_common("sagittal", evt)
|
| 616 |
+
|
| 617 |
+
def on_coronal_select(evt: gr.SelectData):
|
| 618 |
+
print("[EVENT] on_coronal_select triggered!")
|
| 619 |
+
return _click_common("coronal", evt)
|
| 620 |
+
|
| 621 |
def on_seg_button():
|
| 622 |
+
msg=do_segment()
|
| 623 |
+
return (*_figs_and_imgs(),*_bar_ranges_and_values(),*_debug_widgets(),
|
| 624 |
+
gr.update(value="Point" if not S.bbox_mode else "BBox"), None, f"Segment: {msg}")
|
| 625 |
+
|
| 626 |
def on_clear():
|
| 627 |
S.seeds=[]; S.seed_views=[]; S.pred=None
|
| 628 |
S.bboxes=[]; S.bbox_points=[]
|
| 629 |
print("[UI] cleared all")
|
| 630 |
+
return (*_figs_and_imgs(),*_bar_ranges_and_values(),*_debug_widgets(),
|
| 631 |
+
gr.update(value="Point"), None, "Cleared")
|
| 632 |
+
|
| 633 |
def on_undo():
|
| 634 |
if S.bbox_mode and S.bbox_points:
|
| 635 |
S.bbox_points.pop()
|
|
|
|
| 638 |
elif S.seeds:
|
| 639 |
S.seeds.pop(); S.seed_views.pop() if S.seed_views else None
|
| 640 |
print("[UI] undo")
|
| 641 |
+
return (*_figs_and_imgs(),*_bar_ranges_and_values(),*_debug_widgets(),
|
| 642 |
+
gr.update(value="Point" if not S.bbox_mode else "BBox"), None, "Undo")
|
| 643 |
+
|
| 644 |
def on_mode_toggle(mode):
|
| 645 |
S.bbox_mode = (mode == "BBox")
|
| 646 |
S.bbox_points = []
|
| 647 |
print(f"[UI] mode -> {mode}")
|
| 648 |
+
return (*_figs_and_imgs(),*_bar_ranges_and_values(),*_debug_widgets(),
|
| 649 |
+
gr.update(value=mode), None, f"Mode: {mode}")
|
| 650 |
+
|
| 651 |
def on_save():
|
| 652 |
msg, path = save_prediction()
|
| 653 |
if path:
|
| 654 |
+
return (*_figs_and_imgs(),*_bar_ranges_and_values(),*_debug_widgets(),
|
| 655 |
+
gr.update(value="Point" if not S.bbox_mode else "BBox"),
|
| 656 |
+
gr.update(visible=True, value=path), msg)
|
| 657 |
+
return (*_figs_and_imgs(),*_bar_ranges_and_values(),*_debug_widgets(),
|
| 658 |
+
gr.update(value="Point" if not S.bbox_mode else "BBox"),
|
| 659 |
+
gr.update(visible=False, value=None), msg)
|
| 660 |
+
|
| 661 |
def on_z_release(z_idx):
|
| 662 |
S.slice["axial"]=int(z_idx); S.cross["z"]=S.slice["axial"]; S.active_view="axial"
|
| 663 |
+
return (*_figs_and_imgs(),*_bar_ranges_and_values(),*_debug_widgets(),
|
| 664 |
+
gr.update(value="Point" if not S.bbox_mode else "BBox"), None, f"Z={z_idx}")
|
| 665 |
+
|
| 666 |
def on_x_release(x_idx):
|
| 667 |
S.slice["sagittal"]=int(x_idx); S.cross["x"]=S.slice["sagittal"]; S.active_view="sagittal"
|
| 668 |
+
return (*_figs_and_imgs(),*_bar_ranges_and_values(),*_debug_widgets(),
|
| 669 |
+
gr.update(value="Point" if not S.bbox_mode else "BBox"), None, f"X={x_idx}")
|
| 670 |
+
|
| 671 |
def on_y_release(y_idx):
|
| 672 |
S.slice["coronal"]=int(y_idx); S.cross["y"]=S.slice["coronal"]; S.active_view="coronal"
|
| 673 |
+
return (*_figs_and_imgs(),*_bar_ranges_and_values(),*_debug_widgets(),
|
| 674 |
+
gr.update(value="Point" if not S.bbox_mode else "BBox"), None, f"Y={y_idx}")
|
| 675 |
+
|
| 676 |
def _jump_to_idx(idx:int):
|
|
|
|
| 677 |
total = len(S.seeds) + len(S.bboxes)
|
| 678 |
if not (0 <= idx < total):
|
| 679 |
+
return (*_figs_and_imgs(),*_bar_ranges_and_values(),*_debug_widgets(),
|
| 680 |
+
gr.update(value="Point" if not S.bbox_mode else "BBox"), None, "")
|
| 681 |
if idx < len(S.seeds):
|
| 682 |
x,y,z = S.seeds[idx]
|
| 683 |
view = S.seed_views[idx] if idx < len(S.seed_views) else "axial"
|
| 684 |
S.cross={"x":x,"y":y,"z":z}
|
| 685 |
S.slice={"sagittal":x,"coronal":y,"axial":z}
|
| 686 |
S.active_view=view
|
| 687 |
+
return (*_figs_and_imgs(),*_bar_ranges_and_values(),*_debug_widgets(current_idx=idx),
|
| 688 |
+
gr.update(value="Point" if not S.bbox_mode else "BBox"), None, f"Jump to {idx+1}")
|
| 689 |
+
|
| 690 |
def on_seed_df_select(evt):
|
|
|
|
| 691 |
try:
|
| 692 |
row = int(evt.index[0]) if hasattr(evt, "index") else 0
|
| 693 |
except Exception:
|
| 694 |
row = 0
|
| 695 |
return _jump_to_idx(row)
|
| 696 |
+
|
| 697 |
def on_seed_dd_change(val):
|
|
|
|
| 698 |
if not val:
|
| 699 |
+
return (*_figs_and_imgs(),*_bar_ranges_and_values(),*_debug_widgets(),
|
| 700 |
+
gr.update(value="Point" if not S.bbox_mode else "BBox"), None, "")
|
| 701 |
try:
|
| 702 |
if "point" in val:
|
| 703 |
p = val.split("point")[1].strip()
|
|
|
|
| 708 |
except Exception:
|
| 709 |
idx = len(S.seeds) + len(S.bboxes) - 1
|
| 710 |
return _jump_to_idx(idx)
|
| 711 |
+
|
| 712 |
# ---------- UI ----------
|
| 713 |
css = """
|
| 714 |
:root {
|
|
|
|
| 720 |
.tiny .gr-slider input[type="range"]{height:6px}
|
| 721 |
.tiny .gr-form{gap:6px}
|
| 722 |
.smallnote{color:#3e6285;font-size:12px;margin-top:6px}
|
|
|
|
|
|
|
| 723 |
"""
|
| 724 |
+
|
| 725 |
THUMBS=[_thumb_from_case(cid,px=96) for cid in EXAMPLES]
|
| 726 |
+
|
| 727 |
with gr.Blocks(css=css, title="Interactive-MEN-RT", theme=gr.themes.Soft(), analytics_enabled=False) as demo:
|
| 728 |
gr.Markdown("""
|
| 729 |
<h3 style='color:#114b8b'>Interactive-MEN-RT Segmentation</h3>
|
|
|
|
| 732 |
<b>⚠️ Research only — Not for clinical use</b>
|
| 733 |
</div>
|
| 734 |
""")
|
| 735 |
+
|
| 736 |
with gr.Row():
|
| 737 |
with gr.Column(scale=1, min_width=280, elem_classes=["round"]):
|
| 738 |
+
gr.Markdown("<div class='section'>Demo Case</div>")
|
| 739 |
+
gallery = gr.Gallery(value=THUMBS, columns=1, height=110, allow_preview=False, preview=False, show_label=False)
|
| 740 |
case_dd = gr.Dropdown(choices=EXAMPLES, value=EXAMPLES[0], label="Case")
|
| 741 |
|
| 742 |
mode_radio = gr.Radio(["Point", "BBox"], value="Point", label="Interaction Mode")
|
|
|
|
| 757 |
wrap=True,
|
| 758 |
row_count=(0, "dynamic")
|
| 759 |
)
|
| 760 |
+
seed_dd = gr.Dropdown(choices=[], value=None, label="Go to point")
|
| 761 |
|
| 762 |
pred_file = gr.File(label="Download Prediction", visible=False)
|
| 763 |
status_text = gr.Textbox(label="Status", value="", interactive=False, lines=1)
|
|
|
|
| 774 |
coronal = gr.Image(type="pil", interactive=True, height=RENDER_PX_DEFAULT+8, label="Coronal (Y)")
|
| 775 |
y_bar = gr.Slider(0,1,value=0,step=1,label="Y", elem_classes=["tiny"])
|
| 776 |
with gr.Row():
|
| 777 |
+
out_ax_gt = gr.Image(type="pil", interactive=False, height=RENDER_PX_DEFAULT+8, label="Ground Truth")
|
| 778 |
+
out_ax_pr = gr.Image(type="pil", interactive=False, height=RENDER_PX_DEFAULT+8, label="Prediction")
|
| 779 |
+
inter2d = gr.Image(type="pil", interactive=False, height=RENDER_PX_DEFAULT+8, label="Interactions")
|
| 780 |
|
|
|
|
| 781 |
outputs = [axial, sagittal, coronal, out_ax_gt, out_ax_pr, inter2d,
|
| 782 |
z_bar, x_bar, y_bar, seeds_df, seed_dd, mode_radio, pred_file, status_text]
|
| 783 |
|
|
|
|
| 784 |
demo.load(lambda: on_load(EXAMPLES[0]), [], outputs)
|
|
|
|
|
|
|
| 785 |
case_dd.change(lambda cid: on_load(cid), [case_dd], outputs)
|
|
|
|
|
|
|
| 786 |
gallery.select(on_gallery_select, [], outputs + [case_dd])
|
| 787 |
|
| 788 |
+
# 클릭 이벤트 - 강화된 핸들러
|
| 789 |
axial.select(on_axial_select, [], outputs)
|
| 790 |
sagittal.select(on_sagittal_select, [], outputs)
|
| 791 |
coronal.select(on_coronal_select, [], outputs)
|
| 792 |
|
|
|
|
| 793 |
z_bar.release(on_z_release, [z_bar], outputs)
|
| 794 |
x_bar.release(on_x_release, [x_bar], outputs)
|
| 795 |
y_bar.release(on_y_release, [y_bar], outputs)
|
| 796 |
|
|
|
|
| 797 |
mode_radio.change(on_mode_toggle, [mode_radio], outputs)
|
| 798 |
|
|
|
|
| 799 |
seg_btn.click(on_seg_button, [], outputs)
|
| 800 |
save_btn.click(on_save, [], outputs)
|
| 801 |
clr_btn.click(on_clear, [], outputs)
|
| 802 |
undo_btn.click(on_undo, [], outputs)
|
| 803 |
|
|
|
|
| 804 |
seeds_df.select(on_seed_df_select, [], outputs)
|
| 805 |
seed_dd.change(on_seed_dd_change, [seed_dd], outputs)
|
| 806 |
|