|
|
| import os |
| import subprocess |
| import sys |
|
|
| |
def install_requirements(required=None):
    """Install any missing third-party packages with pip.

    Each requirement is an ``(import_name, pip_name)`` pair because the name
    used in ``import`` can differ from the distribution name on PyPI (e.g.
    ``PIL`` is provided by ``pillow``).

    Args:
        required: Optional iterable of ``(module_name, package_name)`` pairs.
            Defaults to the packages this app imports.

    Returns:
        The list of pip package names that were installed (empty when every
        module was already importable).

    Raises:
        subprocess.CalledProcessError: If the pip invocation fails.
    """
    import importlib

    if required is None:
        required = (
            ("gradio", "gradio"),
            ("plotly", "plotly"),
            ("PIL", "pillow"),
            ("monai", "monai"),
            ("torch", "torch"),
            ("numpy", "numpy"),
        )

    missing = []
    for module_name, package_name in required:
        try:
            importlib.import_module(module_name)
        except ImportError:
            missing.append(package_name)

    if missing:
        print(f"installing {', '.join(missing)}...")
        # Use the running interpreter's pip so packages land in this environment.
        subprocess.check_call(
            [sys.executable, "-m", "pip", "install", "--no-cache-dir", *missing]
        )
        # Packages installed after start-up can be invisible to the import
        # system's cached path finders until the caches are invalidated.
        importlib.invalidate_caches()

    return missing
|
|
# Install any missing dependencies at import time, before the third-party
# imports that follow, so the script can bootstrap a fresh environment.
install_requirements()
| |
|
|
| from functools import lru_cache |
| import numpy as np |
| import torch |
| import gradio as gr |
| import plotly.graph_objects as go |
| from PIL import Image, ImageFilter |
|
|
| |
# Resolve data relative to this script rather than the current working
# directory, so the app can be launched from anywhere.
APP_DIR = os.path.dirname(os.path.abspath(__file__))
DATA_ROOT = os.path.join(APP_DIR, "preprocessed_data", "train")

# Fixed shortlist of case IDs offered in the UI; only the ones actually present
# under DATA_ROOT end up in the dropdown.
TOP_IDS = [
    "BraTS20_Training_020",
    "BraTS20_Training_015",
    "BraTS20_Training_024",
    "BraTS20_Training_022",
    "BraTS20_Training_014",
]
|
|
|
|
def _safe_torch_load(path):
    """Load a ``.pt`` file with every tensor mapped onto the CPU.

    NOTE(review): despite the name, this is not a restricted load:
    ``weights_only=False`` lets ``torch.load`` unpickle arbitrary objects,
    which can execute code. Only point it at trusted, locally produced files.
    Presumably full unpickling is needed because the saved objects may not be
    plain tensors; confirm before switching to ``weights_only=True``.
    """
    load_options = {"map_location": "cpu", "weights_only": False}
    return torch.load(path, **load_options)
|
|
|
|
def list_patient_ids():
    """Return the IDs from TOP_IDS whose preprocessed data exists on disk.

    A case counts as available only when its directory under DATA_ROOT holds
    both ``image.pt`` and ``label.pt`` (the two files load_patient reads), so
    the dropdown never offers a case that would fail to load.

    Returns:
        Available IDs in TOP_IDS order; an empty list if DATA_ROOT is missing.
    """
    if not os.path.isdir(DATA_ROOT):
        return []

    required_files = ("image.pt", "label.pt")
    return [
        pid
        for pid in TOP_IDS
        if all(
            os.path.isfile(os.path.join(DATA_ROOT, pid, filename))
            for filename in required_files
        )
    ]
|
|
|
|
def build_case_choices():
    """Build dropdown labels of the form ``"<pid> | infected"``.

    One label per available case; every case is tagged "infected".
    parse_case splits a label back into ``(pid, mode)``.
    """
    return [f"{pid} | infected" for pid in list_patient_ids()]
|
|
|
|
def parse_case(choice):
    """Split a dropdown label ``"<pid> | <mode>"`` into ``(pid, mode)``.

    An empty or missing choice gives ``(None, "infected")``. A label without a
    ``|`` is treated as a bare patient ID in "infected" mode. Only the first
    ``|`` separates the parts, and both parts are whitespace-stripped.
    """
    default_mode = "infected"
    if not choice:
        return None, default_mode
    pid, separator, mode = choice.partition("|")
    if not separator:
        return pid.strip(), default_mode
    return pid.strip(), mode.strip()
|
|
|
|
@lru_cache(maxsize=16)
def load_patient(pid):
    """Load one case's image and label volumes as float32 NumPy arrays.

    Reads ``image.pt`` then ``label.pt`` from ``DATA_ROOT/<pid>``. A 3-D array
    gets a leading singleton axis, so it comes back 4-D; arrays of any other
    rank are returned as stored.

    Results are memoised per ``pid`` (up to 16 cases) and the very same array
    objects are returned to every caller, so callers must not modify them in
    place.
    """
    case_dir = os.path.join(DATA_ROOT, pid)

    def _read_volume(filename):
        # .float() -> float32 tensor, then moved to CPU and exposed as NumPy.
        array = _safe_torch_load(os.path.join(case_dir, filename)).float().cpu().numpy()
        return array[None, ...] if array.ndim == 3 else array

    image = _read_volume("image.pt")
    label = _read_volume("label.pt")
    return image, label
|
|
|
|
def find_best_slice(label):
    """Return the index of the slice (last axis) with the largest tumor area.

    Only channel 0 of ``label`` is used, and any value > 0 counts as tumor.
    Slices are taken along the last axis, matching how ``img_vol[..., idx]``
    is indexed elsewhere.

    Returns:
        The slice index with the most positive voxels. Ties resolve to the
        lowest index, so a volume with no tumor gives 0. An empty volume also
        gives 0.
    """
    mask = label[0] > 0
    # Per-slice area: sum over the two in-plane axes. Unlike reshape(H, W, -1),
    # this does not fail when an in-plane dimension is zero. np.ravel flattens
    # any extra trailing axes in C order, and turns a 2-D mask's scalar sum
    # into a length-1 array.
    areas = np.ravel(mask.sum(axis=(0, 1)))
    if areas.size == 0:
        return 0
    return int(np.argmax(areas))
|
|
|
|
def to_uint8(img2d):
    """Convert an intensity image in [0, 1] to 8-bit grey levels.

    Values are clipped to [0, 1] and mapped to 0..255 by rounding to the
    nearest level. Plain truncation would bias every pixel downward
    (e.g. 0.999 -> 254). NaN maps to 0 rather than relying on an undefined
    float->uint8 cast; +/-inf saturate to 255/0.
    """
    x = np.nan_to_num(np.asarray(img2d, dtype=np.float32), nan=0.0)
    x = np.clip(x, 0.0, 1.0)
    return np.rint(x * 255.0).astype(np.uint8)
|
|
|
|
def to_pil_gray(img2d):
    """Render a 2-D intensity image in [0, 1] as an RGB PIL image.

    ``Image.fromarray`` infers mode "L" from the 2-D uint8 array produced by
    to_uint8, so the explicit ``mode=`` argument (deprecated in recent Pillow)
    is unnecessary. The RGB conversion copies the grey level into all three
    channels.
    """
    return Image.fromarray(to_uint8(img2d)).convert("RGB")
|
|
|
|
def mask_to_pil(mask2d):
    """Render a 2-D mask as a black/white 8-bit image (255 where mask2d > 0).

    ``Image.fromarray`` infers mode "L" from the 2-D uint8 array, so the
    explicit ``mode=`` argument (deprecated in recent Pillow) is dropped.
    """
    binary = (np.asarray(mask2d) > 0).astype(np.uint8) * np.uint8(255)
    return Image.fromarray(binary)
|
|
|
|
def get_base_slice(img_vol, label_vol, slice_idx, mode, tissue_threshold=0.05):
    """Return slice ``slice_idx`` (last axis) of channel 0, clipped to [0, 1].

    In "non-infected" mode, the labelled tumour pixels (channel 0 of
    ``label_vol`` > 0) are painted over with the median intensity of the
    non-tumour pixels brighter than ``tissue_threshold``. This gives a crude
    tumour-free baseline. Near-zero background is excluded from the median
    because it can make up most of a slice and would otherwise drag the fill
    value down to background level, leaving a dark hole. If no such pixels
    exist, the median falls back to all non-tumour pixels, then to the whole
    slice. The default threshold matches the brain threshold used by
    make_3d_mask.

    ``img_vol`` is never modified.
    """
    base = img_vol[0, :, :, slice_idx]
    if mode == "non-infected":
        tumour = label_vol[0, :, :, slice_idx] > 0
        if tumour.any():
            healthy = ~tumour
            tissue = healthy & (base > tissue_threshold)
            if tissue.any():
                pool = base[tissue]
            elif healthy.any():
                pool = base[healthy]
            else:
                pool = base
            fill_val = float(np.median(pool))
            # `base` is a view into img_vol; copy before writing.
            base = base.copy()
            base[tumour] = fill_val
    return np.clip(base, 0.0, 1.0)
|
|
|
|
def _layers_alpha_mask(layers, size):
    """Merge the alpha channels of ImageEditor layers into one float mask.

    Only PIL layers that carry an alpha band are used. Each alpha channel is
    resized (nearest neighbour) to ``size`` if needed, and the per-pixel
    maximum over all layers is returned as float32 in [0, 1]. Returns None
    when no such layer exists.
    """
    if not layers:
        return None
    alphas = []
    for layer in layers:
        if not isinstance(layer, Image.Image) or "A" not in layer.getbands():
            continue
        alpha = layer.getchannel("A")
        if alpha.size != size:
            alpha = alpha.resize(size, Image.NEAREST)
        alphas.append(np.asarray(alpha, dtype=np.float32) / 255.0)
    if not alphas:
        return None
    return np.maximum.reduce(alphas)


def extract_sketch_mask(sketch, base_pil):
    """Turn the sketch editor's value into a 2-D float mask.

    Accepted ``sketch`` values:

    * ``None`` -> returns ``None``.
    * dict (ImageEditor value): if any layer is a PIL image with an alpha
      band, the mask is the maximum alpha over all such layers. Strokes on
      any layer count, not just the first; layers without alpha are ignored.
      Otherwise the first non-None of ``"composite"``, ``"image"``,
      ``"background"`` is the sketched image and ``"mask"`` an explicit mask.
    * 2-element list/tuple ``(image, mask)``.
    * anything else is taken as the sketched image itself.

    An explicit mask is scaled from 0..255 to 0..1. Otherwise the sketched
    image is compared with ``base_pil``: pixels whose mean absolute RGB
    difference exceeds 0.05 become 1.0. PIL masks and images of a different
    size are resized to ``base_pil.size`` so the result lines up with the
    base slice instead of failing to broadcast later.

    Returns:
        A float32 mask, or ``None`` when nothing usable could be extracted.
    """
    if sketch is None:
        return None

    sketch_img = None
    sketch_mask = None

    if isinstance(sketch, dict):
        layer_mask = _layers_alpha_mask(sketch.get("layers"), base_pil.size)
        if layer_mask is not None:
            return layer_mask
        # Explicit None checks: an `or` chain would raise on NumPy arrays.
        for key in ("composite", "image", "background"):
            if sketch.get(key) is not None:
                sketch_img = sketch[key]
                break
        sketch_mask = sketch.get("mask")
    elif isinstance(sketch, (list, tuple)) and len(sketch) == 2:
        sketch_img, sketch_mask = sketch
    else:
        sketch_img = sketch

    if sketch_mask is not None:
        if isinstance(sketch_mask, Image.Image):
            mask_img = sketch_mask.convert("L")
            if mask_img.size != base_pil.size:
                mask_img = mask_img.resize(base_pil.size, Image.NEAREST)
            mask = np.array(mask_img)
        else:
            mask = np.array(sketch_mask)
        return (mask / 255.0).astype(np.float32)

    if sketch_img is None:
        return None

    if not isinstance(sketch_img, Image.Image):
        try:
            sketch_img = Image.fromarray(sketch_img)
        except Exception:
            # Best effort: a value that is not image-like means "no sketch".
            return None

    if sketch_img.size != base_pil.size:
        sketch_img = sketch_img.resize(base_pil.size)

    # Anything that differs noticeably from the untouched slice is a stroke.
    base_arr = np.asarray(base_pil.convert("RGB"), dtype=np.float32) / 255.0
    sketch_arr = np.asarray(sketch_img.convert("RGB"), dtype=np.float32) / 255.0
    diff = np.abs(sketch_arr - base_arr).mean(axis=2)
    return (diff > 0.05).astype(np.float32)
|
|
|
|
def smooth_mask(mask, radius=3):
    """Soften a [0, 1] mask with a Gaussian blur of the given radius.

    The mask is quantised to 8 bits for PIL's GaussianBlur and returned as
    float32 in [0, 1]. ``None`` passes through unchanged. ``Image.fromarray``
    infers mode "L" from the 2-D uint8 array, so the deprecated ``mode=``
    argument is not passed.
    """
    if mask is None:
        return None
    as_uint8 = np.clip(mask * 255.0, 0, 255).astype(np.uint8)
    blurred = Image.fromarray(as_uint8).filter(ImageFilter.GaussianBlur(radius=radius))
    return np.asarray(blurred, dtype=np.float32) / 255.0
|
|
|
|
def generate_image(base, mask, boost=0.35, noise_std=0.03, rng=None):
    """Simulate growth on a 2-D slice by brightening the sketched region.

    The mask is blurred (radius 2). Inside it, ``boost`` plus Gaussian noise
    with std ``noise_std`` is added to ``base``, and the result is clipped to
    [0, 1].

    Args:
        base: 2-D intensity slice.
        mask: Sketch mask, assumed to have base's shape, or None.
        boost: Intensity added inside the mask.
        noise_std: Standard deviation of the multiplicative-mask noise.
        rng: Optional ``numpy.random.Generator`` for reproducible noise. When
            omitted, the global NumPy random state is used, as before.

    Returns:
        ``(generated, abs_difference, smoothed_mask)``. For a missing or
        all-non-positive mask, ``base`` itself is returned with two zero
        arrays.
    """
    if mask is None or mask.max() <= 0:
        return base, np.zeros_like(base), np.zeros_like(base)

    m = smooth_mask(mask, radius=2)

    # np.random.standard_normal(shape) draws the same stream as randn(*shape).
    draw = np.random.standard_normal if rng is None else rng.standard_normal
    noise = (draw(base.shape) * noise_std).astype(np.float32)
    gen = np.clip(base + m * (boost + noise), 0.0, 1.0)
    return gen, np.abs(gen - base), m
|
|
|
|
def build_base_volume(img_vol, label_vol, mode, tissue_threshold=0.05):
    """Return channel 0 of ``img_vol`` as a new 3-D array clipped to [0, 1].

    In "non-infected" mode, the labelled tumour voxels (channel 0 of
    ``label_vol`` > 0) are replaced by the median of the non-tumour voxels
    brighter than ``tissue_threshold``. Near-zero background is excluded
    because it can make up most of the volume and would otherwise drag the
    fill value down to background level, leaving a dark hole. If no such
    voxels exist, the median falls back to all non-tumour voxels, then to the
    whole volume. The default threshold matches the brain threshold used by
    make_3d_mask.

    ``img_vol`` is never modified, and the result never shares its memory.
    """
    base_vol = img_vol[0]
    if mode == "non-infected":
        tumour = label_vol[0] > 0
        if tumour.any():
            healthy = ~tumour
            tissue = healthy & (base_vol > tissue_threshold)
            if tissue.any():
                pool = base_vol[tissue]
            elif healthy.any():
                pool = base_vol[healthy]
            else:
                pool = base_vol
            fill_val = float(np.median(pool))
            # Copy only when writing; np.clip below already returns a new array.
            base_vol = base_vol.copy()
            base_vol[tumour] = fill_val
    return np.clip(base_vol, 0.0, 1.0)
|
|
|
|
def make_3d_mask(mask2d, depth, center_idx, base_vol, sigma=2.5, max_radius=8):
    """Extrude a 2-D sketch mask through the last (depth) axis.

    Each depth slice ``z`` gets the mask scaled by a Gaussian profile
    ``exp(-0.5 * ((z - center_idx) / sigma) ** 2)``. The profile is zeroed
    beyond ``max_radius`` slices (unless ``max_radius`` is None) and then
    rescaled so its peak is ~1. Voxels where ``base_vol <= 0.05`` are zeroed,
    which keeps the growth inside the brain.

    Returns:
        An array of shape ``mask2d.shape + (depth,)``, or None when
        ``mask2d`` is None or has no positive value.
    """
    if mask2d is None or np.max(mask2d) <= 0:
        return None

    offsets = np.arange(depth, dtype=np.float32) - center_idx
    profile = np.exp(-0.5 * (offsets / sigma) ** 2)
    if max_radius is not None:
        profile[np.abs(offsets) > max_radius] = 0.0
    # Epsilon guards an all-zero profile (center far outside the volume).
    profile = profile / (profile.max() + 1e-6)

    extruded = mask2d[:, :, np.newaxis] * profile[np.newaxis, np.newaxis, :]

    inside_brain = base_vol > 0.05
    return extruded * inside_brain
|
|
|
|
def generate_volume(base_vol, mask3d, boost=0.28, noise_std=0.02, rng=None):
    """Simulate growth in 3-D by brightening ``base_vol`` where ``mask3d`` > 0.

    Adds ``mask3d * (boost + noise)`` to the volume, with Gaussian noise of
    std ``noise_std``, and clips the result to [0, 1].

    Args:
        base_vol: 3-D intensity volume.
        mask3d: Growth weights broadcastable to base_vol, or None.
        boost: Intensity added where the mask is 1.
        noise_std: Standard deviation of the noise term.
        rng: Optional ``numpy.random.Generator`` for reproducible noise. When
            omitted, the global NumPy random state is used, as before.

    Returns:
        ``(generated_volume, abs_difference, mask3d)``. For a missing or
        all-non-positive mask, ``base_vol`` itself, a zero array and None.
    """
    if mask3d is None or np.max(mask3d) <= 0:
        return base_vol, np.zeros_like(base_vol), None

    # np.random.standard_normal(shape) draws the same stream as randn(*shape).
    draw = np.random.standard_normal if rng is None else rng.standard_normal
    noise = (draw(base_vol.shape) * noise_std).astype(np.float32)
    gen_vol = np.clip(base_vol + mask3d * (boost + noise), 0.0, 1.0)
    return gen_vol, np.abs(gen_vol - base_vol), mask3d
|
|
|
|
def downsample_3d(vol, factor=2):
    """Subsample the first three axes, keeping every ``factor``-th element.

    ``factor <= 1`` returns ``vol`` itself (even when it is None). Otherwise
    None passes through, and arrays come back as a strided view rather than
    a copy.
    """
    if factor <= 1:
        return vol
    if vol is None:
        return None
    step = slice(None, None, factor)
    return vol[step, step, step]
|
|
|
|
def _isosurface_trace(coords, values, **style):
    """Build a ``go.Isosurface`` over the flattened voxel grid ``coords``.

    Applies the settings shared by every trace in the 3-D view: end caps
    hidden, no colour bar, and an explicit legend entry. Isosurface traces
    are hidden from the legend by default.
    """
    x, y, z = coords
    return go.Isosurface(
        x=x,
        y=y,
        z=z,
        value=np.asarray(values).ravel(),
        caps=dict(x_show=False, y_show=False, z_show=False),
        showscale=False,
        showlegend=True,
        **style,
    )


def build_volume_plot(base_vol, tumor_mask=None, sketch_mask=None, downsample=2):
    """Render a rotatable 3-D view of a volume with optional red overlays.

    Every volume is subsampled with downsample_3d(factor=downsample) to keep
    the figure light. The brain is drawn as faint grey isosurfaces of its
    min-max-normalised intensity. ``tumor_mask`` ("Existing Tumor") and
    ``sketch_mask`` ("New Growth") are added when given and containing a
    positive value. The masks must share base_vol's shape so they line up on
    the same voxel grid.

    Returns:
        A ``plotly.graph_objects.Figure`` (420 px high, axes hidden).
    """
    v = downsample_3d(base_vol, factor=downsample).astype(np.float32)
    v_min, v_max = float(v.min()), float(v.max())
    # Min-max normalise; the epsilon guards a constant volume.
    v_norm = (v - v_min) / (v_max - v_min + 1e-6)

    # Voxel-index coordinates shared by every trace, flattened only once.
    grid = np.mgrid[0 : v.shape[0], 0 : v.shape[1], 0 : v.shape[2]]
    coords = tuple(axis.ravel() for axis in grid)

    fig = go.Figure()
    fig.add_trace(
        _isosurface_trace(
            coords,
            v_norm,
            isomin=0.01,
            isomax=0.8,
            opacity=0.1,
            surface_count=3,
            colorscale="Greys",
            name="Brain",
        )
    )

    overlays = (
        (tumor_mask, 0.4, [[0, "rgb(255, 200, 200)"], [1, "rgb(255, 0, 0)"]], "Existing Tumor"),
        (sketch_mask, 0.5, [[0, "rgb(255, 100, 50)"], [1, "rgb(180, 0, 0)"]], "New Growth"),
    )
    for mask, opacity, colorscale, name in overlays:
        if mask is not None and np.max(mask) > 0:
            fig.add_trace(
                _isosurface_trace(
                    coords,
                    downsample_3d(mask, factor=downsample),
                    isomin=0.1,
                    isomax=1.0,
                    opacity=opacity,
                    surface_count=5,
                    colorscale=colorscale,
                    name=name,
                )
            )

    fig.update_layout(
        scene=dict(
            xaxis_visible=False,
            yaxis_visible=False,
            zaxis_visible=False,
            aspectmode="data",
        ),
        margin=dict(l=0, r=0, t=10, b=0),
        height=420,
        legend=dict(yanchor="top", y=0.9, xanchor="left", x=0.1),
    )
    return fig
|
|
|
|
def load_case(choice):
    """Load the case picked in the dropdown and refresh every output panel.

    Returns a 10-tuple in the order of the ``load_btn.click`` outputs:
    (input image, sketch editor value, tumor-mask image, generated image,
    difference image, sketch-mask image, 3-D plot, slider update, status
    text, session state). The generated/difference/sketch-mask panels are
    cleared. With no patient selected, all panels are cleared, the slider
    goes to 0 and the status reads "No patient found.".
    """
    pid, mode = parse_case(choice)
    if not pid:
        cleared_panels = (None,) * 7
        return (*cleared_panels, gr.update(value=0), "No patient found.", None)

    img_vol, label_vol = load_patient(pid)
    best_slice = find_best_slice(label_vol)

    # 2-D panels for the most tumour-rich slice; the editor starts from the
    # same image so the user sketches on top of it.
    slice_pil = to_pil_gray(get_base_slice(img_vol, label_vol, best_slice, mode))
    label_pil = mask_to_pil(label_vol[0, :, :, best_slice])

    # 3-D overview; the existing tumour is overlaid only in "infected" mode.
    overlay = label_vol[0] if mode == "infected" else None
    plot = build_volume_plot(
        build_base_volume(img_vol, label_vol, mode),
        tumor_mask=overlay,
        sketch_mask=None,
    )

    slider = gr.update(value=best_slice, maximum=img_vol.shape[-1] - 1)
    status = f"Loaded {pid} ({mode}). Best slice is {best_slice}."
    state = {"pid": pid, "image": img_vol, "label": label_vol, "mode": mode}
    return (slice_pil, slice_pil, label_pil, None, None, None, plot, slider, status, state)
|
|
|
|
def update_slice(slice_idx, state):
    """Re-render the 2-D panels for a new slider position.

    Returns ``(input image, sketch editor value, tumor-mask image, status)``.
    The index is clamped to the loaded volume's depth. Without a loaded case,
    the panels are None and the status asks the user to load one.
    """
    if state is None:
        return None, None, None, "Please load a patient first."

    img_vol, label_vol, mode = state["image"], state["label"], state["mode"]
    idx = max(0, min(int(slice_idx), img_vol.shape[-1] - 1))

    base_pil = to_pil_gray(get_base_slice(img_vol, label_vol, idx, mode))
    label_pil = mask_to_pil(label_vol[0, :, :, idx])
    return base_pil, base_pil, label_pil, f"Slice {idx} ({mode})"
|
|
|
|
def generate_from_sketch(sketch, slice_idx, state):
    """Run the 2-D and 3-D growth simulation for the current sketch.

    Returns ``(generated image, normalised difference image, sketch-mask
    image, 3-D plot)``, or four Nones when no case is loaded. The 2-D sketch
    mask is extruded around ``slice_idx`` to build the 3-D growth region.
    """
    if state is None:
        return None, None, None, None

    img_vol = state["image"]
    label_vol = state["label"]
    mode = state["mode"]

    # Clamp exactly like update_slice, so an out-of-range slider value cannot
    # raise IndexError (or silently wrap around when negative).
    slice_idx = max(0, min(int(slice_idx), img_vol.shape[-1] - 1))

    base = get_base_slice(img_vol, label_vol, slice_idx, mode)
    mask = extract_sketch_mask(sketch, to_pil_gray(base))
    gen, diff, mask_sm = generate_image(base, mask)

    gen_pil = to_pil_gray(gen)
    # Stretch the difference to the full grey range so small changes show.
    diff_pil = to_pil_gray(diff / (diff.max() + 1e-6))
    mask_pil = mask_to_pil(mask_sm) if mask_sm is not None else None

    base_vol = build_base_volume(img_vol, label_vol, mode)
    mask2d = mask_sm if mask_sm is not None else mask
    growth_3d = make_3d_mask(
        mask2d, depth=base_vol.shape[2], center_idx=slice_idx, base_vol=base_vol
    )

    existing_tumor_3d = label_vol[0] if mode == "infected" else None
    gen_vol, _, _ = generate_volume(base_vol, growth_3d)
    vol_plot = build_volume_plot(gen_vol, tumor_mask=existing_tumor_3d, sketch_mask=growth_3d)

    return gen_pil, diff_pil, mask_pil, vol_plot
|
|
|
|
def build_demo():
    """Build the Gradio Blocks UI and wire its events to the handlers.

    Components appear in creation order: the dropdown, "Load Data" button and
    slice slider row; the 3-D plot; the status text; the input/sketch image
    row; the generated/difference/sketch-mask/tumor-mask image row; and the
    "Generate" button.

    Returns:
        The un-launched ``gr.Blocks`` app.
    """
    case_choices = build_case_choices()
    # Preselect the first available case; None leaves the dropdown empty.
    default_choice = case_choices[0] if case_choices else None

    with gr.Blocks(title="FYP Brain Simulation") as demo:
        gr.Markdown(
            """
# Brain Tumor Simulation

This is a prototype for my Final Year Project.
Instructions:
1. Select a patient ID from the dropdown. (Showing Top 5 Infected Cases)
2. The system will load the MRI and find the slice with the largest tumor.
3. Use the **Sketch** tool to draw projected growth.
4. Click **Generate** to simulate the growth prediction.
"""
        )

        with gr.Row():
            case_dd = gr.Dropdown(
                choices=case_choices,
                value=default_choice,
                label="Select Patient",
            )
            load_btn = gr.Button("Load Data")
            # Placeholder range: load_case returns gr.update(maximum=depth - 1).
            slice_slider = gr.Slider(0, 95, value=0, step=1, label="Slice Number")

        vol_plot = gr.Plot(label="3D Volume (Full Brain View - Rotate to Explore)")

        status = gr.Markdown("Load a case to begin.")
        # Per-session dict {"pid", "image", "label", "mode"} filled by load_case.
        state = gr.State()

        with gr.Row():
            input_img = gr.Image(label="Input MRI (selected slice)", interactive=False)
            # load_case/update_slice set this editor's value to the current slice.
            sketch_img = gr.ImageEditor(label="Sketch tumor growth", type="pil", interactive=True)

        with gr.Row():
            gen_img = gr.Image(label="Generated MRI")
            diff_img = gr.Image(label="Difference (changes)")
            sketch_mask = gr.Image(label="Sketch Mask")
            tumor_mask = gr.Image(label="Tumor Mask (label)")

        generate_btn = gr.Button("Generate")

        # Output order must match the 10-tuple returned by load_case.
        load_btn.click(
            fn=load_case,
            inputs=[case_dd],
            outputs=[input_img, sketch_img, tumor_mask, gen_img, diff_img, sketch_mask, vol_plot, slice_slider, status, state],
        )

        # Gradio also fires .change for programmatic updates, so the slider
        # value set by load_case re-renders the slice panels as well.
        slice_slider.change(
            fn=update_slice,
            inputs=[slice_slider, state],
            outputs=[input_img, sketch_img, tumor_mask, status],
        )

        generate_btn.click(
            fn=generate_from_sketch,
            inputs=[sketch_img, slice_slider, state],
            outputs=[gen_img, diff_img, sketch_mask, vol_plot],
        )

    return demo
|
|
|
|
# Built at import time (not only under __main__), so importing this module
# already exposes a ready-to-launch ``demo`` object.
demo = build_demo()


# Start the local Gradio server only when executed as a script.
if __name__ == "__main__":
    demo.launch()
|
|