"""SynCMRIApp — app.py

Gradio application for cardiac MRI synthesis. Generate, upload, or select a
segmentation mask, then synthesize a realistic MRI from it with one of three
conditional models (DDPM, LDM, Flow Matching).
"""
import gradio as gr
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from monai.utils import set_determinism
from generative.networks.nets import DiffusionModelUNet, AutoencoderKL, ControlNet
from generative.networks.schedulers import DDPMScheduler
from huggingface_hub import hf_hub_download
from diffusers import UNet2DModel, DDPMScheduler as DiffusersScheduler # Rename to avoid conflict
import torch.nn as nn
import torch.nn.functional as F
from diffusion import VQVAE, Unet, LinearNoiseScheduler
# --- CONFIGURATION ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MASK_MODEL_PATH = "models/mask_diffusion.pth"
# ==========================================
# Helper Functions
# ==========================================
def get_jet_reference_colors(num_classes=4):
    """Return the (num_classes, 3) RGB palette (0-255 ints) that the 'jet'
    colormap assigns to evenly spaced class indices 0..num_classes-1."""
    cmap = plt.get_cmap('jet')
    palette = [
        [int(channel * 255) for channel in cmap(idx / (num_classes - 1))[:3]]
        for idx in range(num_classes)
    ]
    return np.array(palette)
def rgb_mask_to_onehot(mask_np):
    """Convert an RGB mask array (H, W, 3) into a one-hot tensor (1, 4, H, W)
    on DEVICE, classifying each pixel by its nearest jet reference color."""
    # Resize to the model resolution with NEAREST so class colors stay exact.
    if mask_np.shape[:2] != (128, 128):
        pil_img = Image.fromarray(mask_np.astype(np.uint8))
        mask_np = np.array(pil_img.resize((128, 128), resample=Image.NEAREST))
    # Per-pixel Euclidean distance to each of the 4 reference colors:
    # (H, W, 1, 3) - (1, 1, 4, 3) -> distances of shape (H, W, 4).
    palette = get_jet_reference_colors(4)
    distances = np.linalg.norm(
        mask_np[:, :, None, :] - palette[None, None, :, :], axis=3
    )
    # Closest color wins -> integer label map (128, 128) with values 0-3.
    label_map = np.argmin(distances, axis=2)
    # One-hot encode and move channels first: (4, 128, 128).
    labels = torch.tensor(label_map, dtype=torch.long)
    onehot = F.one_hot(labels, num_classes=4).permute(2, 0, 1).float()
    # Add the batch dimension -> (1, 4, 128, 128).
    return onehot.unsqueeze(0).to(DEVICE)
class LDMConfig:
    """Hyper-parameter bundle for the latent-diffusion pipeline.

    Holds the working image size plus the architecture dictionaries for the
    latent UNet (`ldm_params`) and the VQVAE autoencoder
    (`autoencoder_params`). Values must match those used at training time.
    """

    def __init__(self):
        # Pixel-space resolution the pipeline operates at.
        self.im_size = 128
        # The LDM UNet is conditioned on a 4-channel image (one-hot mask).
        image_condition = {
            'image_condition_input_channels': 4,
            'image_condition_output_channels': 1,
        }
        # Latent-space UNet architecture.
        self.ldm_params = {
            'time_emb_dim': 256,
            'down_channels': [128, 256, 512],
            'mid_channels': [512, 256],
            'down_sample': [True, True],
            'attn_down': [False, True],
            'norm_channels': 32,
            'num_heads': 8,
            'conv_out_channels': 128,
            'num_down_layers': 2,
            'num_mid_layers': 2,
            'num_up_layers': 2,
            'condition_config': {
                'condition_types': ['image'],
                'image_condition_config': image_condition,
            },
        }
        # VQVAE autoencoder architecture (maps pixels <-> 4-channel latents).
        self.autoencoder_params = {
            'z_channels': 4,
            'codebook_size': 8192,
            'down_channels': [64, 128, 256],
            'mid_channels': [256, 256],
            'down_sample': [True, True],
            'attn_down': [False, False],
            'norm_channels': 32,
            'num_heads': 4,
            'num_down_layers': 2,
            'num_mid_layers': 2,
            'num_up_layers': 2,
        }
# DEFINITIONS FOR FLOW MATCHING
class MergedModel(nn.Module):
    """Expose a discrete-timestep diffusion UNet (optionally ControlNet-guided)
    behind a flow-matching interface that takes continuous time t in [0, 1]."""

    def __init__(self, unet, controlnet=None, max_timestep=1000):
        super().__init__()
        self.unet = unet
        self.controlnet = controlnet
        self.max_timestep = max_timestep
        self.has_controlnet = controlnet is not None

    def forward(self, x, t, cond=None, masks=None):
        # Map continuous t in [0, 1] onto the discrete grid [0, max_timestep - 1].
        timesteps = (t * (self.max_timestep - 1)).floor().long()
        if timesteps.dim() == 0:
            # Scalar t: broadcast to one timestep per batch element.
            timesteps = timesteps.expand(x.shape[0])
        if not self.has_controlnet:
            return self.unet(x=x, timesteps=timesteps, context=cond)
        # ControlNet produces residuals that steer the UNet with the mask.
        down_res, mid_res = self.controlnet(
            x=x, timesteps=timesteps, controlnet_cond=masks, context=cond
        )
        return self.unet(
            x=x,
            timesteps=timesteps,
            context=cond,
            down_block_additional_residuals=down_res,
            mid_block_additional_residual=mid_res,
        )
# ==========================================
# 1. MODEL LOADING (Cached)
# ==========================================
# Lazy model cache: each entry starts as None and is populated on first use,
# so heavyweight models are loaded at most once per process.
models = dict.fromkeys(("mask", "ddpm", "ldm", "fm"))
def load_mask_model():
    """Return the unconditional mask-diffusion UNet, loading it on first call.

    Subsequent calls return the cached instance from `models["mask"]`.
    """
    if models["mask"] is not None:
        return models["mask"]
    print("Loading Mask Model...")
    # 4 channels in/out: one channel per anatomical class (one-hot mask space).
    net = DiffusionModelUNet(
        spatial_dims=2,
        in_channels=4,
        out_channels=4,
        num_channels=(64, 128, 256, 512),
        attention_levels=(False, False, True, True),
        num_res_blocks=2,
        num_head_channels=32,
    ).to(DEVICE)
    net.load_state_dict(torch.load(MASK_MODEL_PATH, map_location=DEVICE))
    net.eval()
    models["mask"] = net
    return net
# Lazy loaders for the three conditional mask-to-image models.
def load_conditional_model(model_type):
    """Load (once) and return the cached model bundle for `model_type`.

    Parameters
    ----------
    model_type : str
        One of "DDPM", "LDM" or "FM". Matched case-sensitively for loading;
        the cache lookup uses the lower-cased name.

    Returns
    -------
    The cached entry from `models`:
      * "DDPM" -> (unet, scheduler) from diffusers
      * "LDM"  -> (vqvae, ldm_unet, config) custom pipeline
      * "FM"   -> MergedModel (MONAI UNet + ControlNet)
    An unrecognized model_type returns None.
    """
    # --- 1. DDPM LOADING ---
    if model_type == "DDPM" and models["ddpm"] is None:
        print("Loading DDPM (Diffusers)...")
        # Expects the 'ddpm-150-finetuned' export under models/ddpm/.
        unet = UNet2DModel.from_pretrained("models/ddpm/unet").to(DEVICE)
        scheduler = DiffusersScheduler.from_pretrained("models/ddpm/scheduler")
        models["ddpm"] = (unet, scheduler)
    # --- 2. LDM LOADING ---
    elif model_type == "LDM" and models["ldm"] is None:
        print("Loading LDM (Custom)...")
        config = LDMConfig()
        # VQVAE decodes 4-channel latents back to single-channel images.
        vqvae = VQVAE(im_channels=1, model_config=config.autoencoder_params).to(DEVICE)
        vqvae.load_state_dict(torch.load("models/vqvae.pth", map_location=DEVICE))  # Ensure filename matches
        vqvae.eval()
        # Latent-space UNet operating on the VQVAE's 4-channel latents.
        ldm_unet = Unet(im_channels=4, model_config=config.ldm_params).to(DEVICE)
        ldm_unet.load_state_dict(torch.load("models/ldm.pth", map_location=DEVICE))  # Ensure filename matches
        ldm_unet.eval()
        models["ldm"] = (vqvae, ldm_unet, config)
    # --- 3. FLOW MATCHING LOADING ---
    elif model_type == "FM" and models["fm"] is None:
        print("Loading Flow Matching (MONAI)...")
        # Architecture hyper-parameters; must match the training configuration.
        fm_config = {
            "spatial_dims": 2, "in_channels": 1, "out_channels": 1,
            "num_res_blocks": [2, 2, 2, 2], "num_channels": [32, 64, 128, 256],
            "attention_levels": [False, False, False, True], "norm_num_groups": 32,
            "resblock_updown": True, "num_head_channels": [32, 64, 128, 256],
            "transformer_num_layers": 6, "with_conditioning": True, "cross_attention_dim": 256,
        }
        # Base denoising UNet.
        unet = DiffusionModelUNet(**fm_config)
        # ControlNet shares the UNet config but takes no 'out_channels' argument.
        cn_config = fm_config.copy()
        cn_config.pop("out_channels", None)
        controlnet = ControlNet(
            **cn_config,
            conditioning_embedding_num_channels=(16,)
        )
        # Merge UNet + ControlNet behind the continuous-time interface.
        model = MergedModel(unet, controlnet).to(DEVICE)
        # Download weights from the Hugging Face Hub (cached locally by hf_hub).
        path = hf_hub_download(repo_id="ishanthathsara/syn_mri_flow_match", filename="flow_match_model.pt")
        checkpoint = torch.load(path, map_location=DEVICE)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        models["fm"] = model
    # Cache keys are lower-case; unknown model types fall through to None.
    return models.get(model_type.lower())
# ==========================================
# 2. GENERATION FUNCTIONS
# ==========================================
def generate_new_mask():
    """Sample a fresh segmentation mask from the unconditional diffusion model.

    Returns
    -------
    (rgb_display, label_map) : the colorized (128, 128, 3) uint8 preview and
    the raw (128, 128) integer label map with values 0-3.
    """
    model = load_mask_model()
    scheduler = DDPMScheduler(num_train_timesteps=1000)
    # Start from pure Gaussian noise over the 4 one-hot class channels.
    current_img = torch.randn((1, 4, 128, 128)).to(DEVICE)
    # Full 1000-step reverse diffusion. For a snappier demo, switch to a DDIM
    # scheduler and run fewer inference steps.
    for t in scheduler.timesteps:
        with torch.no_grad():
            output = model(x=current_img, timesteps=torch.Tensor((t,)).to(DEVICE), context=None)
        current_img, _ = scheduler.step(output, t, current_img)
    # Map the [-1, 1] channels to [0, 1], then take the per-pixel argmax as
    # the class label.
    current_img = (current_img + 1) / 2
    label_map = torch.argmax(current_img, dim=1).cpu().numpy()[0]  # (128, 128)
    return colorize_mask(label_map), label_map
def colorize_mask(mask_2d):
    """Render a (128, 128) integer label map (values 0-3) as an RGB uint8 image."""
    jet = plt.get_cmap('jet')
    # Normalize labels into [0, 1] before applying the colormap.
    rgba = jet(mask_2d / 3.0)
    rgb = rgba[:, :, :3]  # drop the alpha channel
    return (rgb * 255).astype(np.uint8)
def synthesize_image(mask_input, source_type, model_choice):
    """Synthesize an MRI from a segmentation mask with the chosen model.

    Parameters
    ----------
    mask_input : np.ndarray | None
        For "Generate Mask": an integer label map (128, 128) with values 0-3.
        For "Upload Mask" / "Select Mask": an RGB image (H, W, 3).
    source_type : str
        One of "Generate Mask", "Upload Mask", "Select Mask".
    model_choice : str
        One of "DDPM", "LDM", "Flow Matching".

    Returns
    -------
    (display_mask, result) : the RGB mask used and either the synthesized
    (128, 128) uint8 image, a black placeholder when no model matched, or an
    error-message string when the input mask is missing/invalid.
    """
    # ==========================================
    # A. HANDLE INPUT & PREPARE MASKS
    # ==========================================
    mask_onehot = None
    display_mask = None
    if source_type == "Generate Mask":
        # Input is a raw integer label map; one-hot it for the models.
        if mask_input is None:
            return None, "Please generate a mask first."
        mask_tensor = torch.tensor(mask_input, dtype=torch.long).to(DEVICE)
        mask_onehot = torch.nn.functional.one_hot(mask_tensor, num_classes=4).permute(2, 0, 1).float()
        mask_onehot = mask_onehot.unsqueeze(0)  # -> (1, 4, 128, 128)
        display_mask = colorize_mask(mask_input)
    elif source_type in ["Upload Mask", "Select Mask"]:
        # Input is an RGB image: map colors back to classes, then one-hot.
        if mask_input is None:
            return None, "Please upload a mask first."
        mask_onehot = rgb_mask_to_onehot(np.array(mask_input))
        display_mask = mask_input
    if mask_onehot is None:
        # Unknown source_type: report instead of crashing in the model call.
        return None, "Unknown mask source."
    # ==========================================
    # B. RUN CONDITIONAL INFERENCE
    # ==========================================
    generated_img = None
    # --- OPTION 1: DDPM (pixel-space, mask concatenated as extra channels) ---
    if model_choice == "DDPM":
        unet, scheduler = load_conditional_model("DDPM")
        img = torch.randn((1, 1, 128, 128)).to(DEVICE)  # start from pure noise
        for t in scheduler.timesteps:
            # Concatenate [noise (1ch) + mask (4ch)] -> 5-channel model input.
            model_input = torch.cat([img, mask_onehot], dim=1)
            with torch.no_grad():
                noise_pred = unet(model_input, t).sample
            img = scheduler.step(noise_pred, t, img).prev_sample
        generated_img = img
    # --- OPTION 2: LDM (diffusion in VQVAE latent space) ---
    elif model_choice == "LDM":
        vqvae, ldm_unet, config = load_conditional_model("LDM")
        latent_dim = 128 // 4  # two downsampling stages -> 32x32 latents
        z = torch.randn((1, 4, latent_dim, latent_dim)).to(DEVICE)
        # Scheduler parameters must match those used at training time.
        scheduler = LinearNoiseScheduler(num_timesteps=1000, beta_start=0.00085, beta_end=0.012)
        cond_input = {'image': mask_onehot}
        # Reverse diffusion in latent space.
        for t in reversed(range(1000)):
            t_tensor = torch.tensor([t], device=DEVICE)
            with torch.no_grad():
                noise_pred = ldm_unet(z, t_tensor, cond_input=cond_input)
            # sample_prev_timestep returns (x_prev, x0-estimate); keep x_prev.
            z = scheduler.sample_prev_timestep(z, noise_pred, t_tensor)[0]
        # Decode latents back to pixel space.
        with torch.no_grad():
            generated_img = vqvae.decode(z)
    # --- OPTION 3: FLOW MATCHING (Euler ODE solve, ControlNet-guided) ---
    elif model_choice == "Flow Matching":
        model = load_conditional_model("FM")
        x = torch.randn((1, 1, 128, 128)).to(DEVICE)
        steps = 50
        dt = 1.0 / steps
        # The FM ControlNet expects a single-channel class-index map, so
        # collapse the one-hot (1, 4, 128, 128) mask back to (1, 1, 128, 128).
        mask_float = mask_onehot.float()
        if mask_float.shape[1] == 4:
            mask_float = torch.argmax(mask_float, dim=1, keepdim=True).float()
        for i in range(steps):
            t = torch.tensor([i * dt], device=DEVICE)
            with torch.no_grad():
                v = model(x=x, t=t, masks=mask_float)  # predicted velocity
            x = x + v * dt  # Euler step: x_{t+dt} = x_t + v * dt
        generated_img = x
    # ==========================================
    # C. POST-PROCESSING (Tensor -> Numpy)
    # ==========================================
    if generated_img is not None:
        img_np = generated_img.squeeze().cpu().numpy()  # -> (128, 128)
        # Map model output range [-1, 1] to [0, 1].
        # NOTE(review): assumes the FM model also emits [-1, 1] data — confirm.
        img_np = (img_np + 1) / 2
        img_np = np.clip(img_np, 0, 1)
        final_image = (img_np * 255).astype(np.uint8)
        return display_mask, final_image
    # No recognized model choice: return a black placeholder image.
    return display_mask, np.zeros((128, 128, 3), dtype=np.uint8)
# ==========================================
# 3. GRADIO UI
# ==========================================
# Build the Gradio UI: mask input (generate / upload / select) on the left,
# model choice and synthesis output on the right.
with gr.Blocks(title="Cardiac MRI Synthesis") as demo:
    gr.Markdown("# 🫀 Cardiac MRI Synthesis: Mask-to-Image")
    gr.Markdown("Generate a synthetic cardiac mask or upload one, then turn it into a realistic MRI.")
    with gr.Row():
        with gr.Column():
            gr.Markdown("### 1. Mask Input")
            tab_choice = gr.Radio(["Generate Mask", "Upload Mask", "Select Mask"], label="Source", value="Generate Mask")
            # Source 1: sample a mask with the unconditional diffusion model.
            with gr.Group(visible=True) as group_gen:
                btn_gen_mask = gr.Button("Generate Random Mask", variant="primary")
                out_gen_mask = gr.Image(label="Generated Mask", type="numpy", interactive=False)
                state_mask = gr.State()  # holds the raw integer label map (0-3), hidden from view
            # Source 2: user-provided mask image.
            with gr.Group(visible=False) as group_up:
                in_upload_mask = gr.Image(label="Upload Mask (PNG)", type="numpy")
            # Source 3: pick one of the bundled sample masks.
            with gr.Group(visible=False) as group_sel:
                in_select_mask = gr.Image(label="Selected Mask", type="numpy", interactive=False)
                gr.Examples(
                    examples=[
                        "sample_masks/img_1.png",
                        "sample_masks/img_2.png",
                        "sample_masks/img_3.png",
                        "sample_masks/img_4.png",
                        "sample_masks/img_5.png"
                    ],
                    inputs=in_select_mask,
                    label="Click a mask to select it"
                )
        with gr.Column():
            gr.Markdown("### 2. Image Synthesis")
            model_dropdown = gr.Dropdown(["DDPM", "LDM", "Flow Matching"], label="Select Conditional Model", value="DDPM")
            btn_synthesize = gr.Button("✨ Synthesize MRI", variant="primary")
            out_final_img = gr.Image(label="Synthetic MRI")
    # --- INTERACTIONS ---
    def toggle_input(choice):
        """Show only the input group matching the selected mask source."""
        return {
            group_gen: gr.update(visible=(choice == "Generate Mask")),
            group_up: gr.update(visible=(choice == "Upload Mask")),
            group_sel: gr.update(visible=(choice == "Select Mask"))
        }
    tab_choice.change(toggle_input, tab_choice, [group_gen, group_up, group_sel])
    def on_gen_mask():
        """Sample a new mask; update both the preview and the hidden raw state."""
        rgb, raw = generate_new_mask()
        return rgb, raw
    btn_gen_mask.click(on_gen_mask, outputs=[out_gen_mask, state_mask])
    def on_synthesize(choice, gen_state, upload_img, select_img, model_name):
        """Dispatch to synthesize_image with whichever mask source is active."""
        # NOTE(review): `choice` comes from a Radio limited to these three
        # values; any other value would leave final_mask/final_img unbound.
        if choice == "Generate Mask":
            final_mask, final_img = synthesize_image(gen_state, choice, model_name)
        elif choice == "Upload Mask":
            final_mask, final_img = synthesize_image(upload_img, choice, model_name)
        elif choice == "Select Mask":
            final_mask, final_img = synthesize_image(select_img, choice, model_name)
        # synthesize_image signals failure by returning a message string.
        if isinstance(final_img, str):
            raise gr.Error(final_img)
        return final_img
    btn_synthesize.click(
        on_synthesize,
        inputs=[tab_choice, state_mask, in_upload_mask, in_select_mask, model_dropdown],
        outputs=[out_final_img]
    )
if __name__ == "__main__":
    demo.launch()