# SketchMRI / app.py
# Author: Banuka — commit d3bcbda ("Upload app.py with huggingface_hub", verified)
import os
import subprocess
import sys
# --- dependency check ---
def install_requirements():
    """Pip-install any of the app's runtime dependencies that cannot be imported.

    Meant for fresh notebook environments (e.g. Colab) where the packages may
    be missing. Each entry pairs the *import* name with the *pip* package
    name, because they can differ (``PIL`` is provided by ``pillow``).

    Packages are installed with ``sys.executable -m pip`` so they land in the
    interpreter that is running this file.

    Raises:
        subprocess.CalledProcessError: if pip exits with a non-zero status.
    """
    import importlib

    required = [
        ("gradio", "gradio"),
        ("plotly", "plotly"),
        ("PIL", "pillow"),
        ("monai", "monai"),
        ("torch", "torch"),
        ("numpy", "numpy"),
    ]
    missing = []
    for module_name, package_name in required:
        try:
            __import__(module_name)
        except ImportError:
            # Any ImportError (not installed, or installed but failing to
            # import) marks the package as missing.
            missing.append(package_name)
    if not missing:
        return

    print(f"installing {', '.join(missing)}...")
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "--no-cache-dir", *missing]
    )
    # The import system caches directory contents; invalidate those caches so
    # the freshly installed packages can be imported by this same process.
    importlib.invalidate_caches()
# Runs at import time, before the third-party imports below, so a fresh
# environment gets its packages installed before they are first imported.
install_requirements()
# ------------------------
from functools import lru_cache
import numpy as np
import torch
import gradio as gr
import plotly.graph_objects as go
from PIL import Image, ImageFilter
# Data paths are resolved relative to this file, not the current working directory.
APP_DIR = os.path.dirname(os.path.abspath(__file__))
DATA_ROOT = os.path.join(APP_DIR, "preprocessed_data", "train")

# The five BraTS20 training cases offered in the demo dropdown, in display
# order. The author describes them as the "top 5 highly infected" cases; the
# selection criterion itself is not visible here.
TOP_IDS = [f"BraTS20_Training_{case_no:03d}" for case_no in (20, 15, 24, 22, 14)]
def _safe_torch_load(path):
# loading with weights_only=False because we have custom classes sometimes
return torch.load(path, map_location="cpu", weights_only=False)
def list_patient_ids():
    """Return the IDs from TOP_IDS that are present on disk, in TOP_IDS order.

    A case counts as present when ``DATA_ROOT/<pid>`` is a directory that
    contains ``image.pt`` (``label.pt`` is not checked here). Returns an
    empty list when DATA_ROOT itself does not exist.
    """
    if not os.path.isdir(DATA_ROOT):
        return []
    return [
        pid
        for pid in TOP_IDS
        if os.path.isdir(os.path.join(DATA_ROOT, pid))
        and os.path.exists(os.path.join(DATA_ROOT, pid, "image.pt"))
    ]
def build_case_choices():
    """Return the dropdown labels, one ``"<pid> | infected"`` per available case."""
    return [f"{pid} | infected" for pid in list_patient_ids()]
def parse_case(choice):
    """Split a dropdown label of the form ``"<patient id> | <mode>"``.

    Args:
        choice: the dropdown value, or None/empty when nothing is selected.

    Returns:
        ``(pid, mode)``. ``pid`` is None for an empty/None choice, otherwise
        the stripped text before the first ``|``. ``mode`` is the stripped
        text after the first ``|``. It falls back to ``"infected"`` when the
        label has no ``|`` or nothing follows it, so callers never see an
        empty mode.
    """
    if not choice:
        return None, "infected"
    pid, separator, mode = choice.partition("|")
    mode = mode.strip() if separator else ""
    return pid.strip(), mode or "infected"
@lru_cache(maxsize=16)
def load_patient(pid):
    """Load the image and label volumes of case ``pid``.

    Reads ``DATA_ROOT/<pid>/image.pt`` and ``label.pt`` and converts each to a
    float32 numpy array. A 3-D array gets a leading axis prepended so callers
    can always index ``vol[0, ...]``.

    Results are memoised (LRU, up to 16 cases). Repeated calls return the
    *same* array objects, so callers should not modify them in place.

    Returns:
        ``(image, label)`` numpy arrays.
    """
    case_dir = os.path.join(DATA_ROOT, pid)

    def _read(filename):
        arr = _safe_torch_load(os.path.join(case_dir, filename)).float().cpu().numpy()
        # Add the missing channel axis for 3-D volumes.
        return arr[np.newaxis, ...] if arr.ndim == 3 else arr

    return _read("image.pt"), _read("label.pt")
def find_best_slice(label):
    """Return the index (along the last axis) of the slice with the most tumour.

    Args:
        label: array indexed as ``label[0]``; voxels > 0 count as tumour and
            the last axis of ``label[0]`` is the slice axis (matching
            ``vol[0, :, :, slice_idx]`` elsewhere in this file).

    Returns:
        The slice index (int) with the largest tumour area. If no slice
        contains tumour, returns the middle slice (``depth // 2``) rather
        than the edge slice 0. Returns 0 for a volume with no slices.
    """
    mask = np.asarray(label[0]) > 0
    depth = mask.shape[-1]
    if depth == 0:
        return 0
    # Collapse every axis except the slice axis, then count tumour voxels per slice.
    areas = mask.reshape(-1, depth).sum(axis=0)
    if not areas.any():
        return depth // 2
    return int(np.argmax(areas))
def to_uint8(img2d):
    """Convert intensities expected in [0, 1] to uint8 [0, 255] for PIL.

    Values are clipped to [0, 1], not rescaled. Non-finite values are mapped
    before clipping: NaN -> 0, +inf -> 255, -inf -> 0. Casting NaN/inf
    straight to uint8 is undefined. Scaling truncates toward zero
    (e.g. 0.5 -> 127). The input array is never modified.
    """
    x = np.nan_to_num(
        np.asarray(img2d, dtype=np.float32), nan=0.0, posinf=1.0, neginf=0.0
    )
    x = np.clip(x, 0.0, 1.0)
    return (x * 255.0).astype(np.uint8)
def to_pil_gray(img2d):
    """Render a 2-D intensity slice as an RGB PIL image (grey in all channels).

    Intensities go through ``to_uint8`` (clip to [0, 1], scale to 0-255).
    ``Image.fromarray`` infers mode "L" for a 2-D uint8 array, so the explicit
    ``mode="L"`` argument is no longer passed (recent Pillow deprecates it).
    """
    return Image.fromarray(to_uint8(img2d)).convert("RGB")
def mask_to_pil(mask2d):
    """Binarise a 2-D mask (> 0) into a greyscale PIL image with values 0/255.

    The uint8 2-D array is inferred as mode "L" by ``Image.fromarray``, so the
    explicit ``mode="L"`` argument (deprecated in recent Pillow) is omitted.
    """
    binary = (mask2d > 0).astype(np.uint8) * 255
    return Image.fromarray(binary)
def get_base_slice(img_vol, label_vol, slice_idx, mode):
    """Return slice ``slice_idx`` of channel 0, clipped to [0, 1].

    In ``"non-infected"`` mode, tumour voxels (label > 0) of that slice are
    replaced with the median of the slice's non-tumour voxels. If every voxel
    is tumour, the median of the whole slice is used. The input volumes are
    never modified.
    """
    plane = img_vol[0, :, :, slice_idx]
    if mode != "non-infected":
        return np.clip(plane, 0.0, 1.0)

    tumour = label_vol[0, :, :, slice_idx] > 0
    if not tumour.any():
        return np.clip(plane, 0.0, 1.0)

    outside = plane[~tumour]
    reference = outside if outside.size else plane
    filled = plane.copy()
    filled[tumour] = float(np.median(reference))
    return np.clip(filled, 0.0, 1.0)
def extract_sketch_mask(sketch, base_pil):
if sketch is None:
return None
# gradio 4.x returns a dictionary for ImageEditor
sketch_img = None
sketch_mask = None
if isinstance(sketch, dict):
if "layers" in sketch and sketch["layers"]:
# drawing is in the layers
layer = sketch["layers"][0]
if isinstance(layer, Image.Image):
return np.array(layer.split()[-1]).astype(np.float32) / 255.0
# fallback
sketch_img = sketch.get("composite") or sketch.get("image") or sketch.get("background")
sketch_mask = sketch.get("mask")
elif isinstance(sketch, (list, tuple)) and len(sketch) == 2:
sketch_img, sketch_mask = sketch
else:
sketch_img = sketch
if sketch_mask is not None:
if isinstance(sketch_mask, Image.Image):
mask = np.array(sketch_mask.convert("L"))
else:
mask = np.array(sketch_mask)
return (mask / 255.0).astype(np.float32)
if sketch_img is None:
return None
if not isinstance(sketch_img, Image.Image):
try:
sketch_img = Image.fromarray(sketch_img)
except Exception:
return None
# compare sketch to base to find where user drew
base_arr = np.array(base_pil.convert("RGB")).astype(np.float32) / 255.0
sketch_arr = np.array(sketch_img.convert("RGB")).astype(np.float32) / 255.0
diff = np.abs(sketch_arr - base_arr).mean(axis=2)
mask = (diff > 0.05).astype(np.float32)
return mask
def smooth_mask(mask, radius=3):
    """Gaussian-blur a [0, 1] float mask with PIL and return float32 in [0, 1].

    The mask is scaled to 0-255 and truncated to uint8 before blurring, so
    the result has 1/255 resolution. ``radius`` is passed to
    ``ImageFilter.GaussianBlur``. Returns None when ``mask`` is None.
    """
    if mask is None:
        return None
    as_bytes = np.clip(mask * 255.0, 0, 255).astype(np.uint8)
    # A 2-D uint8 array is inferred as mode "L"; the explicit mode= argument
    # to Image.fromarray is deprecated in recent Pillow releases.
    blurred = Image.fromarray(as_bytes).filter(ImageFilter.GaussianBlur(radius=radius))
    return np.asarray(blurred, dtype=np.float32) / 255.0
def generate_image(base, mask):
    """Simulate the model's output on a 2-D slice (demo stand-in for diffusion).

    Brightens ``base`` inside a softened copy of ``mask`` by 0.35 plus
    Gaussian jitter (std 0.03, drawn from NumPy's global RNG), then clips to
    [0, 1].

    Returns:
        ``(generated, |generated - base|, softened_mask)``. When ``mask`` is
        None or has no positive value, ``base`` itself is returned with two
        zero arrays.
    """
    if mask is None or mask.max() <= 0:
        return base, np.zeros_like(base), np.zeros_like(base)

    weight = smooth_mask(mask, radius=2)
    jitter = (np.random.randn(*base.shape) * 0.03).astype(np.float32)
    simulated = np.clip(base + weight * (0.35 + jitter), 0.0, 1.0)
    return simulated, np.abs(simulated - base), weight
def build_base_volume(img_vol, label_vol, mode):
    """Return channel 0 of ``img_vol`` as a new array clipped to [0, 1].

    In ``"non-infected"`` mode, every tumour voxel (label > 0) is replaced by
    the median of all non-tumour voxels in the volume. If every voxel is
    tumour, the median of the whole volume is used. ``img_vol`` is never
    modified.
    """
    volume = np.array(img_vol[0], copy=True)
    if mode != "non-infected":
        return np.clip(volume, 0.0, 1.0)

    tumour = label_vol[0] > 0
    if tumour.any():
        healthy = volume[~tumour]
        reference = healthy if healthy.size else volume
        volume[tumour] = float(np.median(reference))
    return np.clip(volume, 0.0, 1.0)
def make_3d_mask(mask2d, depth, center_idx, base_vol, sigma=2.5, max_radius=8):
    """Extrude a 2-D sketch mask into 3-D with a Gaussian falloff across slices.

    Slice ``z`` gets weight ``exp(-0.5 * ((z - center_idx) / sigma) ** 2)``.
    Slices further than ``max_radius`` from the centre get zero (no cut-off
    when ``max_radius`` is None). Weights are then divided by
    ``max + 1e-6``. The result is zeroed wherever ``base_vol <= 0.05``, a
    rough "outside the brain" test.

    Returns:
        ``mask2d[..., None] * weights * brain``, shaped like ``base_vol``;
        or None when ``mask2d`` is None or has no positive value.
    """
    if mask2d is None or np.max(mask2d) <= 0:
        return None

    offsets = np.arange(depth, dtype=np.float32) - center_idx
    profile = np.exp(-0.5 * (offsets / sigma) ** 2)
    if max_radius is not None:
        profile[np.abs(offsets) > max_radius] = 0.0
    profile = profile / (profile.max() + 1e-6)

    inside_brain = base_vol > 0.05
    return mask2d[:, :, np.newaxis] * profile[np.newaxis, np.newaxis, :] * inside_brain
def generate_volume(base_vol, mask3d):
    """Apply the simulated-growth brightening to a whole volume.

    Inside ``mask3d`` the volume is raised by ``mask3d * (0.28 + jitter)``,
    where jitter is Gaussian noise with std 0.02 drawn from NumPy's global
    RNG. The result is clipped to [0, 1].

    Returns:
        ``(generated, |generated - base_vol|, mask3d)``. When ``mask3d`` is
        None or has no positive value, ``(base_vol, zeros, None)``.
    """
    if mask3d is None or np.max(mask3d) <= 0:
        return base_vol, np.zeros_like(base_vol), None

    jitter = (np.random.randn(*base_vol.shape) * 0.02).astype(np.float32)
    grown = np.clip(base_vol + mask3d * (0.28 + jitter), 0.0, 1.0)
    return grown, np.abs(grown - base_vol), mask3d
def downsample_3d(vol, factor=2):
    """Stride-subsample the first three axes of ``vol`` by ``factor``.

    Returns ``vol`` unchanged (same object) when ``factor <= 1`` or ``vol`` is None.
    """
    if factor > 1 and vol is not None:
        return vol[::factor, ::factor, ::factor]
    return vol
def build_volume_plot(base_vol, tumor_mask=None, sketch_mask=None, downsample=2):
    """Build a Plotly 3-D figure of the brain with optional tumour overlays.

    All inputs are stride-downsampled by ``downsample`` (default 2) to keep
    the iso-surface meshes small enough for the browser. Layers:
      * "Brain": grey, 10% opacity shell over the min-max normalised volume
        (iso-levels 0.01-0.8, 3 surfaces);
      * "Existing Tumor": red overlay of ``tumor_mask`` (if it has a positive value);
      * "New Growth": orange-to-dark-red overlay of ``sketch_mask`` (same condition).
    Iso-surface caps are hidden so tissue touching the volume edge does not
    produce flat walls.

    Returns:
        ``plotly.graph_objects.Figure`` with axes hidden and ``aspectmode='data'``.
    """
    vol = downsample_3d(base_vol, factor=downsample).astype(np.float32)
    lo, hi = float(vol.min()), float(vol.max())
    scaled = (vol - lo) / (hi - lo + 1e-6)

    gx, gy, gz = np.mgrid[0 : vol.shape[0], 0 : vol.shape[1], 0 : vol.shape[2]]
    # Grid coordinates are identical for every trace, so flatten them once.
    coords = dict(x=gx.flatten(), y=gy.flatten(), z=gz.flatten())
    no_caps = dict(x_show=False, y_show=False, z_show=False)

    def isosurface(values, **style):
        """One iso-surface trace on the shared grid, caps and colour bar hidden."""
        return go.Isosurface(
            **coords,
            value=values.flatten(),
            caps=no_caps,
            showscale=False,
            **style,
        )

    fig = go.Figure()
    fig.add_trace(
        isosurface(
            scaled,
            isomin=0.01,  # catch even faint brain tissue
            isomax=0.8,
            opacity=0.1,
            surface_count=3,  # few surfaces for performance
            colorscale="Greys",
            name="Brain",
        )
    )

    overlays = (
        (
            tumor_mask,
            dict(
                opacity=0.4,
                colorscale=[[0, 'rgb(255, 200, 200)'], [1, 'rgb(255, 0, 0)']],
                name="Existing Tumor",
            ),
        ),
        (
            sketch_mask,
            dict(
                opacity=0.5,
                colorscale=[[0, 'rgb(255, 100, 50)'], [1, 'rgb(180, 0, 0)']],
                name="New Growth",
            ),
        ),
    )
    for overlay, style in overlays:
        if overlay is None or not np.max(overlay) > 0:
            continue
        fig.add_trace(
            isosurface(
                downsample_3d(overlay, factor=downsample),
                isomin=0.1,
                isomax=1.0,
                surface_count=5,
                **style,
            )
        )

    fig.update_layout(
        scene=dict(
            xaxis_visible=False,
            yaxis_visible=False,
            zaxis_visible=False,
            aspectmode='data',  # keep true proportions so the brain is not squashed
        ),
        margin=dict(l=0, r=0, t=10, b=0),
        height=420,
        legend=dict(yanchor="top", y=0.9, xanchor="left", x=0.1),
    )
    return fig
def load_case(choice):
    """Handle "Load Data": load a case and populate every output component.

    Shows the slice with the largest tumour area. The same PIL image is used
    for the input preview and as the sketch editor's background.

    Returns:
        10-tuple matching the click outputs: input image, sketch value,
        tumour mask, three cleared outputs (generated / difference / sketch
        mask), 3-D figure, slider update, status text, session state dict.
        When no patient is selected, every component is None except the
        slider (reset to 0) and the status text.
    """
    pid, mode = parse_case(choice)
    if not pid:
        cleared = (None,) * 7
        return (*cleared, gr.update(value=0), "No patient found.", None)

    img_vol, label_vol = load_patient(pid)
    best_slice = find_best_slice(label_vol)
    slice_pil = to_pil_gray(get_base_slice(img_vol, label_vol, best_slice, mode))
    tumour_pil = mask_to_pil(label_vol[0, :, :, best_slice])

    # The existing tumour is overlaid in 3-D only in "infected" mode.
    existing = label_vol[0] if mode == "infected" else None
    figure = build_volume_plot(
        build_base_volume(img_vol, label_vol, mode),
        tumor_mask=existing,
        sketch_mask=None,
    )

    message = f"Loaded {pid} ({mode}). Best slice is {best_slice}."
    session = {"pid": pid, "image": img_vol, "label": label_vol, "mode": mode}
    slider = gr.update(value=best_slice, maximum=img_vol.shape[-1] - 1)
    return (
        slice_pil,
        slice_pil,
        tumour_pil,
        None,
        None,
        None,
        figure,
        slider,
        message,
        session,
    )
def update_slice(slice_idx, state):
    """Handle slider moves: re-render the chosen slice and its tumour mask.

    ``slice_idx`` is converted to int and clamped to the volume's slice range.

    Returns:
        ``(input image, sketch background, tumour mask image, status text)``.
        Before a case is loaded (``state`` is None), three Nones and a hint.
    """
    if state is None:
        return None, None, None, "Please load a patient first."

    img_vol, label_vol, mode = state["image"], state["label"], state["mode"]
    idx = max(0, min(int(slice_idx), img_vol.shape[-1] - 1))

    preview = to_pil_gray(get_base_slice(img_vol, label_vol, idx, mode))
    tumour_pil = mask_to_pil(label_vol[0, :, :, idx])
    return preview, preview, tumour_pil, f"Slice {idx} ({mode})"
def generate_from_sketch(sketch, slice_idx, state):
    """Handle "Generate": simulate growth from the sketch, in 2-D and in 3-D.

    Args:
        sketch: value of the sketch editor (see ``extract_sketch_mask``).
        slice_idx: slider value. It is clamped to the volume's slice range,
            exactly as ``update_slice`` does, so an out-of-range value cannot
            index past the volume.
        state: session dict from ``load_case`` (keys "image", "label", "mode"),
            or None before a case is loaded.

    Returns:
        ``(generated slice, normalised difference, sketch mask, 3-D figure)``,
        or four Nones when no case is loaded.
    """
    if state is None:
        return None, None, None, None

    img_vol = state["image"]
    label_vol = state["label"]
    mode = state["mode"]
    slice_idx = max(0, min(int(slice_idx), img_vol.shape[-1] - 1))

    # --- 2-D simulation on the selected slice ---
    base = get_base_slice(img_vol, label_vol, slice_idx, mode)
    mask = extract_sketch_mask(sketch, to_pil_gray(base))
    gen, diff, mask_sm = generate_image(base, mask)
    gen_pil = to_pil_gray(gen)
    # Stretch the difference so its largest change is (almost) white.
    diff_pil = to_pil_gray(diff / (diff.max() + 1e-6))
    mask_pil = mask_to_pil(mask_sm) if mask_sm is not None else None

    # --- 3-D visualisation ---
    base_vol = build_base_volume(img_vol, label_vol, mode)
    mask2d = mask_sm if mask_sm is not None else mask
    sketch_mask_3d = make_3d_mask(
        mask2d, depth=base_vol.shape[2], center_idx=slice_idx, base_vol=base_vol
    )
    # The existing tumour is overlaid only in "infected" mode.
    existing_tumor_mask_3d = label_vol[0] if mode == "infected" else None
    gen_vol, _, _ = generate_volume(base_vol, sketch_mask_3d)
    vol_plot = build_volume_plot(
        gen_vol, tumor_mask=existing_tumor_mask_3d, sketch_mask=sketch_mask_3d
    )
    return gen_pil, diff_pil, mask_pil, vol_plot
def build_demo():
    """Create the Gradio Blocks UI and wire its callbacks.

    Components are created in this order: intro Markdown; patient dropdown,
    "Load Data" button and slice slider; 3-D plot; status text; session
    state; input slice beside the sketch editor; generated / difference /
    sketch-mask / tumour-mask images; "Generate" button. Wiring:
    "Load Data" -> ``load_case``, slider change -> ``update_slice``,
    "Generate" -> ``generate_from_sketch``.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """
    choices = build_case_choices()
    intro_md = (
        "\n"
        "# Brain Tumor Simulation\n"
        "This is a prototype for my Final Year Project.\n"
        "Instructions:\n"
        "1. Select a patient ID from the dropdown. (Showing Top 5 Infected Cases)\n"
        "2. The system will load the MRI and find the slice with the largest tumor.\n"
        "3. Use the **Sketch** tool to draw projected growth.\n"
        "4. Click **Generate** to simulate the growth prediction.\n"
    )

    with gr.Blocks(title="FYP Brain Simulation") as demo:
        gr.Markdown(intro_md)

        with gr.Row():
            case_dd = gr.Dropdown(
                choices=choices,
                value=choices[0] if choices else None,
                label="Select Patient",
            )
            load_btn = gr.Button("Load Data")
            slice_slider = gr.Slider(0, 95, value=0, step=1, label="Slice Number")

        # The 3-D plot sits directly under the controls.
        vol_plot = gr.Plot(label="3D Volume (Full Brain View - Rotate to Explore)")
        status = gr.Markdown("Load a case to begin.")
        state = gr.State()

        with gr.Row():
            input_img = gr.Image(label="Input MRI (selected slice)", interactive=False)
            # ImageEditor lets the user paint strokes over the slice.
            sketch_img = gr.ImageEditor(label="Sketch tumor growth", type="pil", interactive=True)

        with gr.Row():
            gen_img = gr.Image(label="Generated MRI")
            diff_img = gr.Image(label="Difference (changes)")
            sketch_mask = gr.Image(label="Sketch Mask")
            tumor_mask = gr.Image(label="Tumor Mask (label)")

        generate_btn = gr.Button("Generate")

        load_outputs = [
            input_img, sketch_img, tumor_mask, gen_img, diff_img,
            sketch_mask, vol_plot, slice_slider, status, state,
        ]
        load_btn.click(fn=load_case, inputs=[case_dd], outputs=load_outputs)

        slice_outputs = [input_img, sketch_img, tumor_mask, status]
        slice_slider.change(fn=update_slice, inputs=[slice_slider, state], outputs=slice_outputs)

        generate_outputs = [gen_img, diff_img, sketch_mask, vol_plot]
        generate_btn.click(
            fn=generate_from_sketch,
            inputs=[sketch_img, slice_slider, state],
            outputs=generate_outputs,
        )

    return demo
# The UI is built at import time, so code that imports this module (e.g. a
# notebook) gets a ready `demo` object and can call `demo.launch()` itself.
demo = build_demo()
# The server starts only when this file is run directly (`python app.py`).
# Importing the module does not launch it, which avoids a double launch in
# notebooks.
if __name__ == "__main__":
    demo.launch()