import gradio as gr
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from monai.utils import set_determinism
from generative.networks.nets import DiffusionModelUNet, AutoencoderKL, ControlNet
from generative.networks.schedulers import DDPMScheduler
from huggingface_hub import hf_hub_download
from diffusers import UNet2DModel, DDPMScheduler as DiffusersScheduler  # rename to avoid clash with MONAI's DDPMScheduler
import torch.nn as nn
import torch.nn.functional as F

from diffusion import VQVAE, Unet, LinearNoiseScheduler

# --- CONFIGURATION ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MASK_MODEL_PATH = "models/mask_diffusion.pth"


# ==========================================
# Helper Functions
# ==========================================
def get_jet_reference_colors(num_classes=4):
    """Recreates the exact RGB colors for classes 0..num_classes-1 from the jet colormap.

    Returns:
        np.ndarray of shape (num_classes, 3) with uint-range RGB values (0-255).
    """
    cmap = plt.get_cmap('jet')
    colors = []
    for i in range(num_classes):
        # Sample the colormap at evenly spaced points over [0, 1]
        norm_val = i / (num_classes - 1)
        rgba = cmap(norm_val)
        rgb = [int(c * 255) for c in rgba[:3]]
        colors.append(rgb)
    return np.array(colors)


def rgb_mask_to_onehot(mask_np):
    """Converts an RGB numpy mask (H, W, 3) to a one-hot tensor (1, 4, H, W) on DEVICE.

    Pixels are assigned to the nearest of the 4 jet reference class colors,
    so slightly off-color uploads still map to valid class labels.
    """
    # 1. Resize if needed (Gradio usually handles this, but good to be safe)
    if mask_np.shape[:2] != (128, 128):
        img = Image.fromarray(mask_np.astype(np.uint8))
        # NEAREST preserves exact class colors (no interpolation blending)
        img = img.resize((128, 128), resample=Image.NEAREST)
        mask_np = np.array(img)

    # 2. Euclidean distance from each pixel to each reference class color:
    #    (H, W, 1, 3) - (1, 1, 4, 3) -> distances of shape (H, W, 4)
    ref_colors = get_jet_reference_colors(4)
    dist = np.linalg.norm(mask_np[:, :, None, :] - ref_colors[None, None, :, :], axis=3)

    # 3. Argmin picks the closest class index (0-3) per pixel -> (128, 128)
    label_map = np.argmin(dist, axis=2)

    # 4. One-hot encode -> (4, 128, 128)
    mask_tensor = torch.tensor(label_map, dtype=torch.long)
    mask_onehot = F.one_hot(mask_tensor, num_classes=4).permute(2, 0, 1).float()

    # 5. Add batch dimension -> (1, 4, 128, 128)
    return mask_onehot.unsqueeze(0).to(DEVICE)


class LDMConfig:
    """Static hyper-parameters for the custom LDM UNet and VQVAE autoencoder.

    NOTE(review): these must match the values the checkpoints were trained
    with — verify against the training config before swapping weights.
    """

    def __init__(self):
        # Pixel-space image size; latents are im_size // 4 (see LDM branch).
        self.im_size = 128
        self.ldm_params = {
            'time_emb_dim': 256,
            'down_channels': [128, 256, 512],
            'mid_channels': [512, 256],
            'down_sample': [True, True],
            'attn_down': [False, True],
            'norm_channels': 32,
            'num_heads': 8,
            'conv_out_channels': 128,
            'num_down_layers': 2,
            'num_mid_layers': 2,
            'num_up_layers': 2,
            'condition_config': {
                'condition_types': ['image'],
                'image_condition_config': {
                    'image_condition_input_channels': 4,
                    'image_condition_output_channels': 1,
                },
            },
        }
        self.autoencoder_params = {
            'z_channels': 4,
            'codebook_size': 8192,
            'down_channels': [64, 128, 256],
            'mid_channels': [256, 256],
            'down_sample': [True, True],
            'attn_down': [False, False],
            'norm_channels': 32,
            'num_heads': 4,
            'num_down_layers': 2,
            'num_mid_layers': 2,
            'num_up_layers': 2,
        }


# DEFINITIONS FOR FLOW MATCHING
class MergedModel(nn.Module):
    """Wraps a MONAI DiffusionModelUNet (plus optional ControlNet) behind a
    flow-matching-style interface that takes continuous time t in [0, 1].

    The continuous t is mapped onto the discrete timestep grid
    [0, max_timestep - 1] expected by the underlying UNet.
    """

    def __init__(self, unet, controlnet=None, max_timestep=1000):
        super().__init__()
        self.unet = unet
        self.controlnet = controlnet
        self.max_timestep = max_timestep
        self.has_controlnet = controlnet is not None

    def forward(self, x, t, cond=None, masks=None):
        # Scale t from [0, 1] to integer timesteps in [0, max_timestep - 1]
        t = t * (self.max_timestep - 1)
        t = t.floor().long()
        if t.dim() == 0:
            # Broadcast a scalar timestep across the batch
            t = t.expand(x.shape[0])
        if self.has_controlnet:
            # ControlNet produces residuals injected into the UNet blocks
            down_res, mid_res = self.controlnet(x=x, timesteps=t,
                                                controlnet_cond=masks, context=cond)
            return self.unet(x=x, timesteps=t, context=cond,
                             down_block_additional_residuals=down_res,
                             mid_block_additional_residual=mid_res)
        return self.unet(x=x, timesteps=t, context=cond)
# ==========================================
# 1. MODEL LOADING (Cached)
# ==========================================
# Module-level cache so each model is built and loaded from disk only once.
models = {
    "mask": None,
    "ddpm": None,
    "ldm": None,
    "fm": None,
}


def load_mask_model():
    """Builds (once) and returns the unconditional mask-diffusion UNet on DEVICE."""
    if models["mask"] is None:
        print("Loading Mask Model...")
        model = DiffusionModelUNet(
            spatial_dims=2,
            in_channels=4,
            out_channels=4,
            num_channels=(64, 128, 256, 512),
            attention_levels=(False, False, True, True),
            num_res_blocks=2,
            num_head_channels=32,
        ).to(DEVICE)
        model.load_state_dict(torch.load(MASK_MODEL_PATH, map_location=DEVICE))
        model.eval()
        models["mask"] = model
    return models["mask"]


def load_conditional_model(model_type):
    """Lazily loads and returns the requested conditional model.

    Args:
        model_type: one of "DDPM", "LDM", "FM" (matches keys in `models`
            after lowercasing).

    Returns:
        DDPM -> (unet, scheduler); LDM -> (vqvae, ldm_unet, config);
        FM -> MergedModel.
    """
    # --- 1. DDPM LOADING ---
    if model_type == "DDPM" and models["ddpm"] is None:
        print("Loading DDPM (Diffusers)...")
        # Assumes the 'ddpm-150-finetuned' folder content lives under 'models/ddpm'
        unet = UNet2DModel.from_pretrained("models/ddpm/unet").to(DEVICE)
        scheduler = DiffusersScheduler.from_pretrained("models/ddpm/scheduler")
        models["ddpm"] = (unet, scheduler)

    # --- 2. LDM LOADING ---
    elif model_type == "LDM" and models["ldm"] is None:
        print("Loading LDM (Custom)...")
        config = LDMConfig()
        # Load VQVAE (pixel <-> latent codec)
        vqvae = VQVAE(im_channels=1, model_config=config.autoencoder_params).to(DEVICE)
        vqvae.load_state_dict(torch.load("models/vqvae.pth", map_location=DEVICE))  # ensure filename matches
        vqvae.eval()
        # Load latent-space UNet
        ldm_unet = Unet(im_channels=4, model_config=config.ldm_params).to(DEVICE)
        ldm_unet.load_state_dict(torch.load("models/ldm.pth", map_location=DEVICE))  # ensure filename matches
        ldm_unet.eval()
        models["ldm"] = (vqvae, ldm_unet, config)

    # --- 3. FLOW MATCHING LOADING ---
    elif model_type == "FM" and models["fm"] is None:
        print("Loading Flow Matching (MONAI)...")
        fm_config = {
            "spatial_dims": 2,
            "in_channels": 1,
            "out_channels": 1,
            "num_res_blocks": [2, 2, 2, 2],
            "num_channels": [32, 64, 128, 256],
            "attention_levels": [False, False, False, True],
            "norm_num_groups": 32,
            "resblock_updown": True,
            "num_head_channels": [32, 64, 128, 256],
            "transformer_num_layers": 6,
            "with_conditioning": True,
            "cross_attention_dim": 256,
        }
        unet = DiffusionModelUNet(**fm_config)
        # ControlNet shares the UNet config but takes no 'out_channels'
        cn_config = fm_config.copy()
        cn_config.pop("out_channels", None)
        controlnet = ControlNet(
            **cn_config,
            conditioning_embedding_num_channels=(16,),
        )
        model = MergedModel(unet, controlnet).to(DEVICE)
        # Download & load weights from the Hugging Face Hub
        path = hf_hub_download(repo_id="ishanthathsara/syn_mri_flow_match",
                               filename="flow_match_model.pt")
        checkpoint = torch.load(path, map_location=DEVICE)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        models["fm"] = model

    return models.get(model_type.lower())


# ==========================================
# 2. GENERATION FUNCTIONS
# ==========================================
def generate_new_mask():
    """Generates a fresh segmentation mask with the unconditional diffusion model.

    Returns:
        (rgb_mask, label_map): RGB uint8 image for display, and the raw
        (128, 128) integer class map (values 0-3) kept for conditioning.
    """
    model = load_mask_model()
    scheduler = DDPMScheduler(num_train_timesteps=1000)

    # 1. Start from pure Gaussian noise over the 4 class channels
    current_img = torch.randn((1, 4, 128, 128)).to(DEVICE)

    # 2. Denoising loop. NOTE: 1000 DDPM steps is slow for a live demo;
    #    a DDIM scheduler would allow far fewer inference steps.
    for t in scheduler.timesteps:
        with torch.no_grad():
            output = model(x=current_img,
                           timesteps=torch.Tensor((t,)).to(DEVICE),
                           context=None)
            current_img, _ = scheduler.step(output, t, current_img)

    # 3. Post-process: [-1, 1] -> [0, 1], argmax over channels -> class map
    current_img = (current_img + 1) / 2
    mask_idx = torch.argmax(current_img, dim=1).cpu().numpy()[0]  # (128, 128)
    return colorize_mask(mask_idx), mask_idx


def colorize_mask(mask_2d):
    """Converts a (128, 128) integer mask (classes 0-3) to an RGB uint8 image."""
    cmap = plt.get_cmap('jet')
    norm_mask = mask_2d / 3.0  # map class indices onto [0, 1] for the colormap
    colored = cmap(norm_mask)[:, :, :3]  # drop alpha channel
    return (colored * 255).astype(np.uint8)


def _run_ddpm(mask_onehot):
    """Pixel-space DDPM: denoises [noise (1ch) | mask (4ch)] into a 1ch image tensor."""
    unet, scheduler = load_conditional_model("DDPM")
    img = torch.randn((1, 1, 128, 128)).to(DEVICE)
    for t in scheduler.timesteps:
        # Concatenate [noise (1ch) + mask (4ch)] -> 5-channel model input
        model_input = torch.cat([img, mask_onehot], dim=1)
        with torch.no_grad():
            noise_pred = unet(model_input, t).sample
            img = scheduler.step(noise_pred, t, img).prev_sample
    return img


def _run_ldm(mask_onehot):
    """Latent diffusion: denoises 32x32 latents conditioned on the mask, then decodes."""
    vqvae, ldm_unet, config = load_conditional_model("LDM")
    # 1. Latent noise (128 / 4 = 32 spatial size)
    latent_dim = 128 // 4
    z = torch.randn((1, 4, latent_dim, latent_dim)).to(DEVICE)
    # 2. Scheduler — betas must match training params!
    scheduler = LinearNoiseScheduler(num_timesteps=1000, beta_start=0.00085, beta_end=0.012)
    # 3. Image conditioning input expected by the custom Unet
    cond_input = {'image': mask_onehot}
    # 4. Reverse diffusion in latent space
    for t in reversed(range(1000)):
        t_tensor = torch.tensor([t], device=DEVICE)
        with torch.no_grad():
            noise_pred = ldm_unet(z, t_tensor, cond_input=cond_input)
            # sample_prev_timestep returns (x_prev, x0); keep x_prev
            z = scheduler.sample_prev_timestep(z, noise_pred, t_tensor)[0]
    # 5. Decode latents back to pixel space
    with torch.no_grad():
        return vqvae.decode(z)


def _run_flow_matching(mask_onehot):
    """Flow matching: integrates the learned velocity field with a simple Euler solver."""
    model = load_conditional_model("FM")
    x = torch.randn((1, 1, 128, 128)).to(DEVICE)
    steps = 50
    dt = 1.0 / steps
    # The ControlNet conditioning expects a single-channel class-index map,
    # so collapse the one-hot mask [1, 4, 128, 128] back to indices [1, 1, 128, 128].
    mask_float = mask_onehot.float()
    if mask_float.shape[1] == 4:
        mask_float = torch.argmax(mask_float, dim=1, keepdim=True).float()
    for i in range(steps):
        t = torch.tensor([i * dt], device=DEVICE)
        with torch.no_grad():
            v = model(x=x, t=t, masks=mask_float)  # predicted velocity
        x = x + v * dt  # Euler step
    return x


def synthesize_image(mask_input, source_type, model_choice):
    """Turns a mask into a synthetic MRI with the selected conditional model.

    Args:
        mask_input: integer (128, 128) class map for "Generate Mask", or an
            RGB image for "Upload Mask" / "Select Mask".
        source_type: one of "Generate Mask", "Upload Mask", "Select Mask".
        model_choice: one of "DDPM", "LDM", "Flow Matching".

    Returns:
        (display_mask, final_image) on success; on missing input returns
        (None, error_message_str) which the UI surfaces as a gr.Error.
    """
    # ==========================================
    # A. HANDLE INPUT & PREPARE MASKS
    # ==========================================
    mask_onehot = None
    display_mask = None

    # CASE 1: generated mask — integer array (128, 128) with values 0-3
    if source_type == "Generate Mask":
        if mask_input is None:
            return None, "Please generate a mask first."
        mask_tensor = torch.tensor(mask_input, dtype=torch.long).to(DEVICE)
        mask_onehot = torch.nn.functional.one_hot(mask_tensor, num_classes=4).permute(2, 0, 1).float()
        mask_onehot = mask_onehot.unsqueeze(0)  # -> (1, 4, 128, 128)
        display_mask = colorize_mask(mask_input)

    # CASE 2: uploaded/selected mask — RGB image (128, 128, 3)
    elif source_type in ["Upload Mask", "Select Mask"]:
        if mask_input is None:
            return None, "Please upload a mask first."
        mask_onehot = rgb_mask_to_onehot(np.array(mask_input))
        display_mask = mask_input

    # ==========================================
    # B. RUN CONDITIONAL INFERENCE
    # ==========================================
    generated_img = None
    if model_choice == "DDPM":
        generated_img = _run_ddpm(mask_onehot)
    elif model_choice == "LDM":
        generated_img = _run_ldm(mask_onehot)
    elif model_choice == "Flow Matching":
        generated_img = _run_flow_matching(mask_onehot)

    # ==========================================
    # C. POST-PROCESSING (Tensor -> Numpy)
    # ==========================================
    if generated_img is not None:
        # Drop batch/channel dims -> (128, 128)
        img_np = generated_img.squeeze().cpu().numpy()
        # Normalize [-1, 1] -> [0, 1].
        # NOTE(review): assumes all three models emit values in [-1, 1];
        # if FM outputs [0, 1] this rescaling needs adjustment — confirm.
        img_np = (img_np + 1) / 2
        img_np = np.clip(img_np, 0, 1)
        final_image = (img_np * 255).astype(np.uint8)
        return display_mask, final_image

    # Unknown model choice: return a blank image rather than crashing
    return display_mask, np.zeros((128, 128, 3), dtype=np.uint8)


# ==========================================
# 3. GRADIO UI
# ==========================================
with gr.Blocks(title="Cardiac MRI Synthesis") as demo:
    gr.Markdown("# 🫀 Cardiac MRI Synthesis: Mask-to-Image")
    gr.Markdown("Generate a synthetic cardiac mask or upload one, then turn it into a realistic MRI.")

    with gr.Row():
        with gr.Column():
            gr.Markdown("### 1. Mask Input")
            tab_choice = gr.Radio(["Generate Mask", "Upload Mask", "Select Mask"],
                                  label="Source", value="Generate Mask")

            # Tab 1: Generate
            with gr.Group(visible=True) as group_gen:
                btn_gen_mask = gr.Button("Generate Random Mask", variant="primary")
                out_gen_mask = gr.Image(label="Generated Mask", type="numpy", interactive=False)
                state_mask = gr.State()  # raw integer mask (0-3), hidden from view

            # Tab 2: Upload
            with gr.Group(visible=False) as group_up:
                in_upload_mask = gr.Image(label="Upload Mask (PNG)", type="numpy")

            # Tab 3: Select from bundled examples
            with gr.Group(visible=False) as group_sel:
                in_select_mask = gr.Image(label="Selected Mask", type="numpy", interactive=False)
                gr.Examples(
                    examples=[
                        "sample_masks/img_1.png",  # replace with your actual filenames!
                        "sample_masks/img_2.png",
                        "sample_masks/img_3.png",
                        "sample_masks/img_4.png",
                        "sample_masks/img_5.png",
                    ],
                    inputs=in_select_mask,
                    label="Click a mask to select it",
                )

        with gr.Column():
            gr.Markdown("### 2. Image Synthesis")
            model_dropdown = gr.Dropdown(["DDPM", "LDM", "Flow Matching"],
                                         label="Select Conditional Model", value="DDPM")
            btn_synthesize = gr.Button("✨ Synthesize MRI", variant="primary")
            out_final_img = gr.Image(label="Synthetic MRI")

    # --- INTERACTIONS ---
    def toggle_input(choice):
        """Shows only the input group matching the selected mask source."""
        return {
            group_gen: gr.update(visible=(choice == "Generate Mask")),
            group_up: gr.update(visible=(choice == "Upload Mask")),
            group_sel: gr.update(visible=(choice == "Select Mask")),
        }

    tab_choice.change(toggle_input, tab_choice, [group_gen, group_up, group_sel])

    def on_gen_mask():
        """Generates a mask; updates the preview image and the raw-mask state."""
        rgb, raw = generate_new_mask()
        return rgb, raw

    btn_gen_mask.click(on_gen_mask, outputs=[out_gen_mask, state_mask])

    def on_synthesize(choice, gen_state, upload_img, select_img, model_name):
        """Routes the correct mask input to synthesize_image based on the source choice."""
        if choice == "Generate Mask":
            final_mask, final_img = synthesize_image(gen_state, choice, model_name)
        elif choice == "Upload Mask":
            final_mask, final_img = synthesize_image(upload_img, choice, model_name)
        elif choice == "Select Mask":
            final_mask, final_img = synthesize_image(select_img, choice, model_name)
        else:
            # FIX: previously fell through with final_img undefined -> NameError
            raise gr.Error(f"Unknown mask source: {choice}")
        if isinstance(final_img, str):  # error message from synthesize_image
            raise gr.Error(final_img)
        return final_img

    btn_synthesize.click(
        on_synthesize,
        inputs=[tab_choice, state_mask, in_upload_mask, in_select_mask, model_dropdown],
        outputs=[out_final_img],
    )


if __name__ == "__main__":
    demo.launch()