# Cardiac MRI Synthesis — Hugging Face Spaces Gradio app (mask-to-image diffusion demo).
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from PIL import Image | |
| from monai.utils import set_determinism | |
| from generative.networks.nets import DiffusionModelUNet, AutoencoderKL, ControlNet | |
| from generative.networks.schedulers import DDPMScheduler | |
| from huggingface_hub import hf_hub_download | |
| from diffusers import UNet2DModel, DDPMScheduler as DiffusersScheduler # Rename to avoid conflict | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from diffusion import VQVAE, Unet, LinearNoiseScheduler | |
# --- CONFIGURATION ---
# Run on GPU when available; every model and tensor below is moved to DEVICE.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Local checkpoint for the unconditional mask-diffusion UNet (see load_mask_model).
MASK_MODEL_PATH = "models/mask_diffusion.pth"
# ==========================================
# Helper Functions
# ==========================================
def get_jet_reference_colors(num_classes=4):
    """Return the (num_classes, 3) array of 0-255 RGB colors that the 'jet'
    colormap assigns to the evenly spaced class indices 0..num_classes-1."""
    cmap = plt.get_cmap('jet')
    # Sample the colormap at i/(num_classes-1), drop alpha, scale to 0-255.
    return np.array([
        [int(channel * 255) for channel in cmap(idx / (num_classes - 1))[:3]]
        for idx in range(num_classes)
    ])
def rgb_mask_to_onehot(mask_np):
    """
    Converts an RGB numpy mask (H,W,3) to a One-Hot Tensor (1, 4, H, W).
    """
    # Force a 128x128 canvas; NEAREST preserves the exact palette colors.
    if mask_np.shape[:2] != (128, 128):
        pil_img = Image.fromarray(mask_np.astype(np.uint8))
        mask_np = np.array(pil_img.resize((128, 128), resample=Image.NEAREST))
    # Nearest-palette classification: Euclidean distance of every pixel
    # (H, W, 1, 3) to each of the 4 reference 'jet' colors (1, 1, 4, 3).
    palette = get_jet_reference_colors(4)
    distances = np.linalg.norm(
        mask_np[:, :, None, :] - palette[None, None, :, :], axis=3
    )
    label_map = np.argmin(distances, axis=2)  # (128, 128) with values 0..3
    # One-hot encode to (4, 128, 128), then add the batch axis.
    labels = torch.tensor(label_map, dtype=torch.long)
    onehot = F.one_hot(labels, num_classes=4).permute(2, 0, 1).float()
    return onehot.unsqueeze(0).to(DEVICE)
class LDMConfig:
    """Architecture hyper-parameters for the custom VQVAE + latent-diffusion
    UNet pair used by the "LDM" synthesis option."""

    def __init__(self):
        # Pixel-space resolution of the training images.
        self.im_size = 128
        # Denoising UNet operating in the 4-channel latent space,
        # conditioned on the 4-channel one-hot segmentation mask.
        self.ldm_params = dict(
            time_emb_dim=256,
            down_channels=[128, 256, 512],
            mid_channels=[512, 256],
            down_sample=[True, True],
            attn_down=[False, True],
            norm_channels=32,
            num_heads=8,
            conv_out_channels=128,
            num_down_layers=2,
            num_mid_layers=2,
            num_up_layers=2,
            condition_config=dict(
                condition_types=['image'],
                image_condition_config=dict(
                    image_condition_input_channels=4,
                    image_condition_output_channels=1,
                ),
            ),
        )
        # VQ-VAE that maps 1-channel images to/from the 4-channel latent grid.
        self.autoencoder_params = dict(
            z_channels=4,
            codebook_size=8192,
            down_channels=[64, 128, 256],
            mid_channels=[256, 256],
            down_sample=[True, True],
            attn_down=[False, False],
            norm_channels=32,
            num_heads=4,
            num_down_layers=2,
            num_mid_layers=2,
            num_up_layers=2,
        )
# DEFINITIONS FOR FLOW MATCHING
class MergedModel(nn.Module):
    """Wraps a diffusion UNet (optionally paired with a ControlNet) behind a
    flow-matching interface that takes a continuous time t in [0, 1]."""

    def __init__(self, unet, controlnet=None, max_timestep=1000):
        super().__init__()
        self.unet = unet
        self.controlnet = controlnet
        self.max_timestep = max_timestep
        self.has_controlnet = controlnet is not None

    def forward(self, x, t, cond=None, masks=None):
        # Map continuous t in [0, 1] to integer timesteps in [0, max_timestep-1].
        timesteps = (t * (self.max_timestep - 1)).floor().long()
        if timesteps.dim() == 0:
            timesteps = timesteps.expand(x.shape[0])
        if not self.has_controlnet:
            return self.unet(x=x, timesteps=timesteps, context=cond)
        # ControlNet residuals are injected into the UNet's down/mid blocks.
        down_res, mid_res = self.controlnet(
            x=x, timesteps=timesteps, controlnet_cond=masks, context=cond
        )
        return self.unet(
            x=x,
            timesteps=timesteps,
            context=cond,
            down_block_additional_residuals=down_res,
            mid_block_additional_residual=mid_res,
        )
# ==========================================
# 1. MODEL LOADING (Cached)
# ==========================================
# Module-level cache so each model is built and loaded at most once
# per process; entries start as None and are filled lazily.
models = dict.fromkeys(("mask", "ddpm", "ldm", "fm"))
def load_mask_model():
    """Lazily build the unconditional mask-diffusion UNet, load its local
    checkpoint, and cache it in `models["mask"]`."""
    if models["mask"] is not None:
        return models["mask"]
    print("Loading Mask Model...")
    net = DiffusionModelUNet(
        spatial_dims=2,
        in_channels=4,   # one channel per segmentation class
        out_channels=4,
        num_channels=(64, 128, 256, 512),
        attention_levels=(False, False, True, True),
        num_res_blocks=2,
        num_head_channels=32,
    ).to(DEVICE)
    net.load_state_dict(torch.load(MASK_MODEL_PATH, map_location=DEVICE))
    net.eval()
    models["mask"] = net
    return net
# Loaders for the three conditional synthesis models, one helper per model.
def _load_ddpm():
    """Build + cache the diffusers DDPM (UNet2DModel, DDPMScheduler) pair."""
    print("Loading DDPM (Diffusers)...")
    # Assuming you uploaded the 'ddpm-150-finetuned' folder content to 'models/ddpm'
    unet = UNet2DModel.from_pretrained("models/ddpm/unet").to(DEVICE)
    scheduler = DiffusersScheduler.from_pretrained("models/ddpm/scheduler")
    models["ddpm"] = (unet, scheduler)


def _load_ldm():
    """Build + cache the custom latent-diffusion stack (VQVAE, UNet, config)."""
    print("Loading LDM (Custom)...")
    config = LDMConfig()
    # VQ-VAE used to decode latents back into pixel space.
    vqvae = VQVAE(im_channels=1, model_config=config.autoencoder_params).to(DEVICE)
    vqvae.load_state_dict(torch.load("models/vqvae.pth", map_location=DEVICE))  # Ensure filename matches
    vqvae.eval()
    # Latent-space denoising UNet.
    ldm_unet = Unet(im_channels=4, model_config=config.ldm_params).to(DEVICE)
    ldm_unet.load_state_dict(torch.load("models/ldm.pth", map_location=DEVICE))  # Ensure filename matches
    ldm_unet.eval()
    models["ldm"] = (vqvae, ldm_unet, config)


def _load_fm():
    """Build + cache the MONAI flow-matching UNet + ControlNet, with weights
    downloaded from the Hugging Face Hub."""
    print("Loading Flow Matching (MONAI)...")
    fm_config = {
        "spatial_dims": 2, "in_channels": 1, "out_channels": 1,
        "num_res_blocks": [2, 2, 2, 2], "num_channels": [32, 64, 128, 256],
        "attention_levels": [False, False, False, True], "norm_num_groups": 32,
        "resblock_updown": True, "num_head_channels": [32, 64, 128, 256],
        "transformer_num_layers": 6, "with_conditioning": True, "cross_attention_dim": 256,
    }
    unet = DiffusionModelUNet(**fm_config)
    # ControlNet shares the UNet config but takes no 'out_channels' argument.
    cn_config = fm_config.copy()
    cn_config.pop("out_channels", None)
    controlnet = ControlNet(
        **cn_config,
        conditioning_embedding_num_channels=(16,)
    )
    model = MergedModel(unet, controlnet).to(DEVICE)
    path = hf_hub_download(repo_id="ishanthathsara/syn_mri_flow_match", filename="flow_match_model.pt")
    checkpoint = torch.load(path, map_location=DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    models["fm"] = model


def load_conditional_model(model_type):
    """Return the cached model bundle for `model_type`, loading it on first use.

    Args:
        model_type: "DDPM", "LDM" or "FM" (matching the cache keys, lowercased).

    Returns:
        The cached entry — a (unet, scheduler) tuple for DDPM, a
        (vqvae, ldm_unet, config) tuple for LDM, a MergedModel for FM —
        or None for an unknown type.
    """
    if model_type == "DDPM" and models["ddpm"] is None:
        _load_ddpm()
    elif model_type == "LDM" and models["ldm"] is None:
        _load_ldm()
    elif model_type == "FM" and models["fm"] is None:
        _load_fm()
    return models.get(model_type.lower())
# ==========================================
# 2. GENERATION FUNCTIONS
# ==========================================
def generate_new_mask():
    """Generates a fresh mask using the Unconditional Diffusion Model.

    Returns a (rgb_display_image, integer_label_map) pair, where the label
    map is a (128, 128) array with values 0-3.
    """
    model = load_mask_model()
    scheduler = DDPMScheduler(num_train_timesteps=1000)
    # Start from pure Gaussian noise over the 4 class channels.
    current = torch.randn((1, 4, 128, 128)).to(DEVICE)
    # Full 1000-step reverse diffusion. NOTE(review): slow for a live demo —
    # switching to a DDIM scheduler with fewer inference steps would help.
    for t in scheduler.timesteps:
        with torch.no_grad():
            predicted = model(x=current, timesteps=torch.Tensor((t,)).to(DEVICE), context=None)
            current, _ = scheduler.step(predicted, t, current)
    # Map [-1, 1] -> [0, 1], then pick the most likely class per pixel.
    current = (current + 1) / 2
    label_map = torch.argmax(current, dim=1).cpu().numpy()[0]  # (128, 128)
    return colorize_mask(label_map), label_map
def colorize_mask(mask_2d):
    """Converts (128,128) integer mask to RGB image for display."""
    jet = plt.get_cmap('jet')
    # Normalise labels 0..3 into [0, 1], colour-map, and drop the alpha channel.
    rgb = jet(mask_2d / 3.0)[:, :, :3]
    return (rgb * 255).astype(np.uint8)
def _prepare_masks(mask_input, source_type):
    """Build (one-hot tensor [1,4,128,128], RGB display mask) for the chosen
    mask source; returns (None, None) for an unknown source type."""
    if source_type == "Generate Mask":
        # Input is an integer label map (128, 128) with values 0-3.
        labels = torch.tensor(mask_input, dtype=torch.long).to(DEVICE)
        onehot = torch.nn.functional.one_hot(labels, num_classes=4).permute(2, 0, 1).float()
        return onehot.unsqueeze(0), colorize_mask(mask_input)
    if source_type in ("Upload Mask", "Select Mask"):
        # Input is an RGB image (128, 128, 3); classify colors back to labels.
        return rgb_mask_to_onehot(np.array(mask_input)), mask_input
    return None, None


def _run_ddpm(mask_onehot):
    """Pixel-space DDPM: denoise 1-channel noise conditioned on the mask."""
    unet, scheduler = load_conditional_model("DDPM")
    img = torch.randn((1, 1, 128, 128)).to(DEVICE)
    for t in scheduler.timesteps:
        # Concatenate [Noise (1ch) + Mask (4ch)] -> Input (5ch)
        model_input = torch.cat([img, mask_onehot], dim=1)
        with torch.no_grad():
            noise_pred = unet(model_input, t).sample
        img = scheduler.step(noise_pred, t, img).prev_sample
    return img


def _run_ldm(mask_onehot):
    """Latent diffusion: denoise in the 32x32 latent space, then VQVAE-decode."""
    vqvae, ldm_unet, config = load_conditional_model("LDM")
    latent_dim = 128 // 4  # 32
    z = torch.randn((1, 4, latent_dim, latent_dim)).to(DEVICE)
    # Scheduler parameters must match training.
    scheduler = LinearNoiseScheduler(num_timesteps=1000, beta_start=0.00085, beta_end=0.012)
    cond_input = {'image': mask_onehot}
    for t in reversed(range(1000)):
        t_tensor = torch.tensor([t], device=DEVICE)
        with torch.no_grad():
            noise_pred = ldm_unet(z, t_tensor, cond_input=cond_input)
        # [0] because sample_prev_timestep returns (mean, x0)
        z = scheduler.sample_prev_timestep(z, noise_pred, t_tensor)[0]
    with torch.no_grad():
        return vqvae.decode(z)


def _run_fm(mask_onehot):
    """Flow matching: integrate the learned velocity field with a fixed-step
    Euler solver over t in [0, 1)."""
    model = load_conditional_model("FM")
    x = torch.randn((1, 1, 128, 128)).to(DEVICE)
    steps = 50
    dt = 1.0 / steps
    # The ControlNet conditioning expects class indices (1 channel),
    # so collapse the one-hot stack [1,4,128,128] -> [1,1,128,128].
    mask_float = mask_onehot.float()
    if mask_float.shape[1] == 4:
        mask_float = torch.argmax(mask_float, dim=1, keepdim=True).float()
    for i in range(steps):
        t = torch.tensor([i * dt], device=DEVICE)
        with torch.no_grad():
            v = model(x=x, t=t, masks=mask_float)  # predicted velocity
        x = x + v * dt  # Euler step: x_next = x + v * dt
    return x


def _to_display_image(tensor_img):
    """Convert a (1,1,128,128) model output tensor to a uint8 (128,128) array."""
    img_np = tensor_img.squeeze().cpu().numpy()
    # NOTE(review): assumes the model output range is [-1, 1]; if Flow Matching
    # already emits [0, 1] this rescale compresses its range — confirm.
    img_np = np.clip((img_np + 1) / 2, 0, 1)
    return (img_np * 255).astype(np.uint8)


def synthesize_image(mask_input, source_type, model_choice):
    """
    Main Logic:
    1. Prepares the mask (One-Hot Tensor for models, RGB for display).
    2. Runs the selected conditional model.
    3. Processes the output for display.

    Returns (display_mask_rgb, mri_uint8) on success, or (None, error_string)
    when no mask has been supplied yet.
    """
    # A. Handle missing input, then prepare the mask tensors.
    if mask_input is None:
        if source_type == "Generate Mask":
            return None, "Please generate a mask first."
        if source_type in ("Upload Mask", "Select Mask"):
            return None, "Please upload a mask first."
    mask_onehot, display_mask = _prepare_masks(mask_input, source_type)

    # B. Run the selected conditional model.
    if model_choice == "DDPM":
        generated_img = _run_ddpm(mask_onehot)
    elif model_choice == "LDM":
        generated_img = _run_ldm(mask_onehot)
    elif model_choice == "Flow Matching":
        generated_img = _run_fm(mask_onehot)
    else:
        generated_img = None

    # C. Post-process (tensor -> uint8 numpy), or fall back to a black image.
    if generated_img is not None:
        return display_mask, _to_display_image(generated_img)
    return display_mask, np.zeros((128, 128, 3), dtype=np.uint8)
# ==========================================
# 3. GRADIO UI
# ==========================================
with gr.Blocks(title="Cardiac MRI Synthesis") as demo:
    gr.Markdown("# 🫀 Cardiac MRI Synthesis: Mask-to-Image")
    gr.Markdown("Generate a synthetic cardiac mask or upload one, then turn it into a realistic MRI.")
    with gr.Row():
        with gr.Column():
            gr.Markdown("### 1. Mask Input")
            tab_choice = gr.Radio(["Generate Mask", "Upload Mask", "Select Mask"], label="Source", value="Generate Mask")
            # Tab 1: Generate — a button drives the unconditional mask model.
            with gr.Group(visible=True) as group_gen:
                btn_gen_mask = gr.Button("Generate Random Mask", variant="primary")
                out_gen_mask = gr.Image(label="Generated Mask", type="numpy", interactive=False)
                state_mask = gr.State()  # Stores the raw integer mask (0-3) hidden from view
            # Tab 2: Upload — user supplies an RGB mask image.
            with gr.Group(visible=False) as group_up:
                in_upload_mask = gr.Image(label="Upload Mask (PNG)", type="numpy")
            # Tab 3: Select — pick one of the bundled example masks.
            with gr.Group(visible=False) as group_sel:
                in_select_mask = gr.Image(label="Selected Mask", type="numpy", interactive=False)
                gr.Examples(
                    examples=[
                        "sample_masks/img_1.png",  # Replace with your actual filenames!
                        "sample_masks/img_2.png",
                        "sample_masks/img_3.png",
                        "sample_masks/img_4.png",
                        "sample_masks/img_5.png"
                    ],
                    inputs=in_select_mask,
                    label="Click a mask to select it"
                )
        with gr.Column():
            gr.Markdown("### 2. Image Synthesis")
            model_dropdown = gr.Dropdown(["DDPM", "LDM", "Flow Matching"], label="Select Conditional Model", value="DDPM")
            btn_synthesize = gr.Button("✨ Synthesize MRI", variant="primary")
            out_final_img = gr.Image(label="Synthetic MRI")
    # --- INTERACTIONS ---
    # Show only the input group matching the selected source radio button.
    def toggle_input(choice):
        return {
            group_gen: gr.update(visible=(choice == "Generate Mask")),
            group_up: gr.update(visible=(choice == "Upload Mask")),
            group_sel: gr.update(visible=(choice == "Select Mask"))
        }
    tab_choice.change(toggle_input, tab_choice, [group_gen, group_up, group_sel])
    # Generate Mask Action: RGB goes to the image widget, the raw integer
    # label map is stashed in gr.State for the later synthesis step.
    def on_gen_mask():
        rgb, raw = generate_new_mask()
        return rgb, raw  # Update Image and State
    btn_gen_mask.click(on_gen_mask, outputs=[out_gen_mask, state_mask])
    # Synthesize Action
    def on_synthesize(choice, gen_state, upload_img, select_img, model_name):
        # We pass the State (raw mask) AND the Upload image or Selected image
        # The logic inside determines which to use based on 'choice'
        if choice == "Generate Mask":
            final_mask, final_img = synthesize_image(gen_state, choice, model_name)
        elif choice == "Upload Mask":
            final_mask, final_img = synthesize_image(upload_img, choice, model_name)
        elif choice == "Select Mask":
            final_mask, final_img = synthesize_image(select_img, choice, model_name)
        if isinstance(final_img, str):  # If final_img is an error message
            raise gr.Error(final_img)
        return final_img
    btn_synthesize.click(
        on_synthesize,
        inputs=[tab_choice, state_mask, in_upload_mask, in_select_mask, model_dropdown],
        outputs=[out_final_img]
    )
if __name__ == "__main__":
    demo.launch()