import gradio as gr
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from monai.utils import set_determinism
from generative.networks.nets import DiffusionModelUNet, AutoencoderKL, ControlNet
from generative.networks.schedulers import DDPMScheduler
from huggingface_hub import hf_hub_download
from diffusers import UNet2DModel, DDPMScheduler as DiffusersScheduler # Rename to avoid conflict
import torch.nn as nn
import torch.nn.functional as F
from diffusion import VQVAE, Unet, LinearNoiseScheduler

# --- CONFIGURATION ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MASK_MODEL_PATH = "models/mask_diffusion.pth"

# ==========================================
# Helper Functions
# ==========================================
def get_jet_reference_colors(num_classes=4):
    """Recreates the exact RGB colors for classes 0-3 from jet colormap."""
    cmap = plt.get_cmap('jet')
    colors = []
    for i in range(num_classes):
        norm_val = i / (num_classes - 1)
        rgba = cmap(norm_val)
        rgb = [int(c * 255) for c in rgba[:3]]
        colors.append(rgb)
    return np.array(colors)

def rgb_mask_to_onehot(mask_np):
    """
    Converts an RGB numpy mask (H,W,3) to a One-Hot Tensor (1, 4, H, W).
    """
    # 1. Resize if needed (Gradio usually handles this, but good to be safe)
    if mask_np.shape[:2] != (128, 128):
        # Convert to PIL for easy resizing
        img = Image.fromarray(mask_np.astype(np.uint8))
        # Use NEAREST to preserve exact colors (no interpolation)
        img = img.resize((128, 128), resample=Image.NEAREST)
        mask_np = np.array(img)

    # 2. Euclidean distance to find closest class color
    ref_colors = get_jet_reference_colors(4)
    # Calculate distance: (H, W, 1, 3) - (1, 1, 4, 3)
    dist = np.linalg.norm(mask_np[:, :, None, :] - ref_colors[None, None, :, :], axis=3)
    
    # 3. Argmin to get indices (0, 1, 2, 3)
    label_map = np.argmin(dist, axis=2) # Shape: (128, 128)
    
    # 4. One-Hot Encoding
    mask_tensor = torch.tensor(label_map, dtype=torch.long)
    mask_onehot = F.one_hot(mask_tensor, num_classes=4).permute(2, 0, 1).float()
    
    # 5. Add Batch Dimension -> (1, 4, 128, 128)
    return mask_onehot.unsqueeze(0).to(DEVICE)
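
# Example usage (a quick shape sanity check; the dummy array is illustrative only):
# dummy_rgb = np.zeros((128, 128, 3), dtype=np.uint8)
# assert rgb_mask_to_onehot(dummy_rgb).shape == (1, 4, 128, 128)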

class LDMConfig:
    def __init__(self):
        self.im_size = 128
        self.ldm_params = {
            'time_emb_dim': 256,
            'down_channels': [128, 256, 512],
            'mid_channels': [512, 256],
            'down_sample': [True, True],
            'attn_down': [False, True],
            'norm_channels': 32,
            'num_heads': 8,
            'conv_out_channels': 128,
            'num_down_layers': 2,
            'num_mid_layers': 2,
            'num_up_layers': 2,
            'condition_config': {
                'condition_types': ['image'], 
                'image_condition_config': {
                    'image_condition_input_channels': 4, 
                    'image_condition_output_channels': 1, 
                }
            }
        }
        self.autoencoder_params = {
            'z_channels': 4,
            'codebook_size': 8192,
            'down_channels': [64, 128, 256], 
            'mid_channels': [256, 256],
            'down_sample': [True, True],
            'attn_down': [False, False],
            'norm_channels': 32,
            'num_heads': 4,
            'num_down_layers': 2,
            'num_mid_layers': 2,
            'num_up_layers': 2
        }

# DEFINITIONS FOR FLOW MATCHING
class MergedModel(nn.Module):
    def __init__(self, unet, controlnet=None, max_timestep=1000):
        super().__init__()
        self.unet = unet
        self.controlnet = controlnet
        self.max_timestep = max_timestep
        self.has_controlnet = controlnet is not None

    def forward(self, x, t, cond=None, masks=None):
        # Scale t from [0,1] to [0, 999]
        t = t * (self.max_timestep - 1)
        t = t.floor().long()
        if t.dim() == 0: t = t.expand(x.shape[0])
        
        if self.has_controlnet:
            down_res, mid_res = self.controlnet(x=x, timesteps=t, controlnet_cond=masks, context=cond)
            return self.unet(x=x, timesteps=t, context=cond, 
                           down_block_additional_residuals=down_res, 
                           mid_block_additional_residual=mid_res)
        return self.unet(x=x, timesteps=t, context=cond)

# ==========================================
# 1. MODEL LOADING (Cached)
# ==========================================
# We use global variables to load models only once
models = {
    "mask": None,
    "ddpm": None,
    "ldm": None,
    "fm": None
}

def load_mask_model():
    if models["mask"] is None:
        print("Loading Mask Model...")
        model = DiffusionModelUNet(
            spatial_dims=2,
            in_channels=4,
            out_channels=4, 
            num_channels=(64, 128, 256, 512),
            attention_levels=(False, False, True, True),
            num_res_blocks=2,
            num_head_channels=32,
        ).to(DEVICE)
        model.load_state_dict(torch.load(MASK_MODEL_PATH, map_location=DEVICE))
        model.eval()
        models["mask"] = model
    return models["mask"]

# Loaders for the conditional models (cached in the global `models` dict above)
def load_conditional_model(model_type):
    # --- 1. DDPM LOADING ---
    if model_type == "DDPM" and models["ddpm"] is None:
        print("Loading DDPM (Diffusers)...")
        # Assuming you uploaded the 'ddpm-150-finetuned' folder content to 'models/ddpm'
        unet = UNet2DModel.from_pretrained("models/ddpm/unet").to(DEVICE)
        scheduler = DiffusersScheduler.from_pretrained("models/ddpm/scheduler")
        models["ddpm"] = (unet, scheduler)

    # --- 2. LDM LOADING ---
    elif model_type == "LDM" and models["ldm"] is None:
        print("Loading LDM (Custom)...")
        config = LDMConfig()
        
        # Load VQVAE
        vqvae = VQVAE(im_channels=1, model_config=config.autoencoder_params).to(DEVICE)
        vqvae.load_state_dict(torch.load("models/vqvae.pth", map_location=DEVICE)) # Ensure filename matches
        vqvae.eval()
        
        # Load LDM UNet
        ldm_unet = Unet(im_channels=4, model_config=config.ldm_params).to(DEVICE)
        ldm_unet.load_state_dict(torch.load("models/ldm.pth", map_location=DEVICE)) # Ensure filename matches
        ldm_unet.eval()
        
        models["ldm"] = (vqvae, ldm_unet, config)

    # --- 3. FLOW MATCHING LOADING ---
    elif model_type == "FM" and models["fm"] is None:
        print("Loading Flow Matching (MONAI)...")
        # Define Config (From your notebook)
        fm_config = {
            "spatial_dims": 2, "in_channels": 1, "out_channels": 1,
            "num_res_blocks": [2, 2, 2, 2], "num_channels": [32, 64, 128, 256],
            "attention_levels": [False, False, False, True], "norm_num_groups": 32,
            "resblock_updown": True, "num_head_channels": [32, 64, 128, 256],
            "transformer_num_layers": 6, "with_conditioning": True, "cross_attention_dim": 256,
        }
        
        # Build Base UNet
        unet = DiffusionModelUNet(**fm_config)

        # Create a copy of config for ControlNet and remove 'out_channels'
        cn_config = fm_config.copy()
        cn_config.pop("out_channels", None)
        
        # Build ControlNet
        controlnet = ControlNet(
            **cn_config, 
            conditioning_embedding_num_channels=(16,)
        )
        
        # Merge
        model = MergedModel(unet, controlnet).to(DEVICE)
        
        # Download & load weights from the Hugging Face Hub
        path = hf_hub_download(repo_id="ishanthathsara/syn_mri_flow_match", filename="flow_match_model.pt")
        checkpoint = torch.load(path, map_location=DEVICE)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        
        models["fm"] = model

    return models.get(model_type.lower())

# ==========================================
# 2. GENERATION FUNCTIONS
# ==========================================
def generate_new_mask():
    """Generates a fresh mask using the Unconditional Diffusion Model."""
    model = load_mask_model()
    scheduler = DDPMScheduler(num_train_timesteps=1000)
    
    # 1. Noise
    noise = torch.randn((1, 4, 128, 128)).to(DEVICE)
    current_img = noise

    # 2. Denoising Loop
    # Note: running the full 1000 DDPM steps is slow for a live demo; switching to a
    # DDIMScheduler with fewer inference steps is a common speedup. The standard
    # DDPM loop is kept here.
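    # A faster alternative (a sketch only, assuming MONAI Generative's DDIMScheduler
    # is available alongside DDPMScheduler; commented out, not used by default):
    # from generative.networks.schedulers import DDIMScheduler
    # scheduler = DDIMScheduler(num_train_timesteps=1000)
    # scheduler.set_timesteps(num_inference_steps=50)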
    
    for t in scheduler.timesteps:
        with torch.no_grad():
            output = model(x=current_img, timesteps=torch.Tensor((t,)).to(DEVICE), context=None)
            current_img, _ = scheduler.step(output, t, current_img)

    # 3. Post Process
    current_img = (current_img + 1) / 2
    mask_idx = torch.argmax(current_img, dim=1).cpu().numpy()[0] # (128, 128)
    
    return colorize_mask(mask_idx), mask_idx

def colorize_mask(mask_2d):
    """Converts (128,128) integer mask to RGB image for display."""
    cmap = plt.get_cmap('jet')
    norm_mask = mask_2d / 3.0
    colored = cmap(norm_mask)[:, :, :3] # Drop Alpha
    return (colored * 255).astype(np.uint8)

def synthesize_image(mask_input, source_type, model_choice):
    """
    Main Logic:
    1. Prepares the mask (One-Hot Tensor for models, RGB for display).
    2. Runs the selected conditional model.
    3. Processes the output for display.
    """
    # ==========================================
    # A. HANDLE INPUT & PREPARE MASKS
    # ==========================================
    mask_onehot = None
    display_mask = None
    
    # CASE 1: Generated Mask (Input is Integer Array [128, 128] with values 0-3)
    if source_type == "Generate Mask":
        if mask_input is None: return None, "Please generate a mask first."
        
        # 1. Create One-Hot Tensor for Model: [1, 4, 128, 128]
        mask_tensor = torch.tensor(mask_input, dtype=torch.long).to(DEVICE)
        mask_onehot = torch.nn.functional.one_hot(mask_tensor, num_classes=4).permute(2, 0, 1).float()
        mask_onehot = mask_onehot.unsqueeze(0) 
        
        # 2. Create Display Mask
        display_mask = colorize_mask(mask_input)

    # CASE 2: Uploaded Mask (Input is RGB Image [128, 128, 3])
    elif source_type in ["Upload Mask", "Select Mask"]:
        if mask_input is None: return None, "Please upload a mask first."
        
        # 1. Create One-Hot Tensor with the rgb_mask_to_onehot helper defined above
        mask_onehot = rgb_mask_to_onehot(np.array(mask_input))
        
        # 2. Display Mask is just the input
        display_mask = mask_input

    # ==========================================
    # B. RUN CONDITIONAL INFERENCE
    # ==========================================
    generated_img = None
    
    # --- OPTION 1: DDPM ---
    if model_choice == "DDPM":
        unet, scheduler = load_conditional_model("DDPM")
        
        # Start with Noise
        img = torch.randn((1, 1, 128, 128)).to(DEVICE)
        
        for t in scheduler.timesteps:
            # Concatenate [Noise (1ch) + Mask (4ch)] -> Input (5ch)
            model_input = torch.cat([img, mask_onehot], dim=1)
            
            with torch.no_grad():
                noise_pred = unet(model_input, t).sample
            
            img = scheduler.step(noise_pred, t, img).prev_sample
            
        generated_img = img

    # --- OPTION 2: LDM ---
    elif model_choice == "LDM":
        vqvae, ldm_unet, config = load_conditional_model("LDM")
        
        # 1. Latent Noise (32x32)
        latent_dim = 128 // 4  # 32
        z = torch.randn((1, 4, latent_dim, latent_dim)).to(DEVICE)
        
        # 2. Scheduler (Must match training params!)
        scheduler = LinearNoiseScheduler(num_timesteps=1000, beta_start=0.00085, beta_end=0.012)
        
        # 3. Conditioning
        cond_input = {'image': mask_onehot} 
        
        # 4. Reverse Diffusion in Latent Space
        for t in reversed(range(1000)):
            t_tensor = torch.tensor([t], device=DEVICE)
            with torch.no_grad():
                noise_pred = ldm_unet(z, t_tensor, cond_input=cond_input)
                # [0] is because sample_prev_timestep returns (mean, x0)
                z = scheduler.sample_prev_timestep(z, noise_pred, t_tensor)[0]
        
        # 5. Decode Latents to Pixels
        with torch.no_grad():
            generated_img = vqvae.decode(z)

    # --- OPTION 3: FLOW MATCHING ---
    elif model_choice == "Flow Matching":
        model = load_conditional_model("FM")
        
        # 1. Initial Noise
        x = torch.randn((1, 1, 128, 128)).to(DEVICE)
        
        # 2. Euler Solver (Simple Loop)
        steps = 50 
        dt = 1.0 / steps
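        # The loop below performs a simple Euler integration of dx/dt = v(x, t)
        # from t = 0 (noise) to t = 1 (image): x_{t+dt} = x_t + v(x_t, t) * dt.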

        # Convert the one-hot mask [1, 4, 128, 128] back to class indices [1, 1, 128, 128],
        # the single-channel conditioning format the ControlNet branch expects.
        mask_float = mask_onehot.float()
        if mask_float.shape[1] == 4:
            mask_float = torch.argmax(mask_float, dim=1, keepdim=True).float()

        for i in range(steps):
            t = torch.tensor([i * dt], device=DEVICE)

            with torch.no_grad():
                # Predict the velocity field at (x, t) conditioned on the mask
                v = model(x=x, t=t, masks=mask_float)
            
            # Step: x_next = x + v * dt
            x = x + v * dt
            
        generated_img = x

    # ==========================================
    # C. POST-PROCESSING (Tensor -> Numpy)
    # ==========================================
    if generated_img is not None:
        # 1. Move to CPU and remove batch dim: (128, 128)
        img_np = generated_img.squeeze().cpu().numpy()
        
        # 2. Normalize [-1, 1] -> [0, 1]
        # (DDPM/LDM outputs are usually -1 to 1. If FM is 0-1, this might need adjustment)
        img_np = (img_np + 1) / 2
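        # (Assumption) If a model's output range turns out not to be [-1, 1], a per-image
        # min-max rescale is a safe fallback; uncomment to use instead of the fixed mapping:
        # img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min() + 1e-8)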
        
        # 3. Clamp to valid range
        img_np = np.clip(img_np, 0, 1)
        
        # 4. Convert to uint8 [0, 255]
        final_image = (img_np * 255).astype(np.uint8)
        
        return display_mask, final_image
    
    return display_mask, np.zeros((128, 128, 3), dtype=np.uint8)

# ==========================================
# 3. GRADIO UI
# ==========================================
with gr.Blocks(title="Cardiac MRI Synthesis") as demo:
    gr.Markdown("# 🫀 Cardiac MRI Synthesis: Mask-to-Image")
    gr.Markdown("Generate a synthetic cardiac mask or upload one, then turn it into a realistic MRI.")

    with gr.Row():
        with gr.Column():
            gr.Markdown("### 1. Mask Input")
            tab_choice = gr.Radio(["Generate Mask", "Upload Mask", "Select Mask"], label="Source", value="Generate Mask")
            
            # Tab 1: Generate
            with gr.Group(visible=True) as group_gen:
                btn_gen_mask = gr.Button("Generate Random Mask", variant="primary")
                out_gen_mask = gr.Image(label="Generated Mask", type="numpy", interactive=False)
                state_mask = gr.State() # Stores the raw integer mask (0-3) hidden from view

            # Tab 2: Upload
            with gr.Group(visible=False) as group_up:
                in_upload_mask = gr.Image(label="Upload Mask (PNG)", type="numpy")

            with gr.Group(visible=False) as group_sel:
                in_select_mask = gr.Image(label="Selected Mask", type="numpy", interactive=False)
                gr.Examples(
                    examples=[
                        "sample_masks/img_1.png", # Replace with your actual filenames!
                        "sample_masks/img_2.png",
                        "sample_masks/img_3.png",
                        "sample_masks/img_4.png",
                        "sample_masks/img_5.png"
                    ],
                    inputs=in_select_mask,
                    label="Click a mask to select it"
                )

        with gr.Column():
            gr.Markdown("### 2. Image Synthesis")
            model_dropdown = gr.Dropdown(["DDPM", "LDM", "Flow Matching"], label="Select Conditional Model", value="DDPM")
            btn_synthesize = gr.Button("✨ Synthesize MRI", variant="primary")
            out_final_img = gr.Image(label="Synthetic MRI")

    # --- INTERACTIONS ---
    
    # Toggle Tabs
    def toggle_input(choice):
        return {
            group_gen: gr.update(visible=(choice == "Generate Mask")),
            group_up: gr.update(visible=(choice == "Upload Mask")),
            group_sel: gr.update(visible=(choice == "Select Mask"))
        }

    tab_choice.change(toggle_input, tab_choice, [group_gen, group_up, group_sel])

    # Generate Mask Action
    def on_gen_mask():
        rgb, raw = generate_new_mask()
        return rgb, raw # Update Image and State
        
    btn_gen_mask.click(on_gen_mask, outputs=[out_gen_mask, state_mask])

    # Synthesize Action
    def on_synthesize(choice, gen_state, upload_img, select_img, model_name):
        # We pass the State (raw mask) AND the Upload image or Selected image
        # The logic inside determines which to use based on 'choice'
        if choice == "Generate Mask":
            final_mask, final_img = synthesize_image(gen_state, choice, model_name)
        elif choice == "Upload Mask":
            final_mask, final_img = synthesize_image(upload_img, choice, model_name)
        elif choice == "Select Mask":
            final_mask, final_img = synthesize_image(select_img, choice, model_name)
        
        if isinstance(final_img, str): # If final_img is an error message
            raise gr.Error(final_img)
            
        return final_img

    btn_synthesize.click(
        on_synthesize, 
        inputs=[tab_choice, state_mask, in_upload_mask, in_select_mask, model_dropdown], 
        outputs=[out_final_img]
    )

if __name__ == "__main__":
    demo.launch()