Ishan Kumarasinghe committed on
Commit
29da2fa
Β·
1 Parent(s): 956cffa

Update app file and requirements

Browse files
app.py CHANGED
@@ -4,13 +4,118 @@ import numpy as np
4
  import matplotlib.pyplot as plt
5
  from PIL import Image
6
  from monai.utils import set_determinism
7
- from generative.networks.nets import DiffusionModelUNet, AutoencoderKL
8
  from generative.networks.schedulers import DDPMScheduler
 
 
 
 
 
9
 
10
  # --- CONFIGURATION ---
11
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
  MASK_MODEL_PATH = "models/mask_diffusion.pth"
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  # ==========================================
15
  # 1. MODEL LOADING (Cached)
16
  # ==========================================
@@ -41,16 +146,64 @@ def load_mask_model():
41
 
42
  # Placeholder loaders for your other models
43
  def load_conditional_model(model_type):
44
- # TODO: Update architecture definitions to match your trained conditional models
45
  if model_type == "DDPM" and models["ddpm"] is None:
46
- # Example: models["ddpm"] = DiffusionModelUNet(...).to(DEVICE)
47
- # models["ddpm"].load_state_dict(torch.load("models/ddpm_conditional.pth"))
48
- pass
 
 
 
 
49
  elif model_type == "LDM" and models["ldm"] is None:
50
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  elif model_type == "FM" and models["fm"] is None:
52
- pass
53
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  return models.get(model_type.lower())
55
 
56
  # ==========================================
@@ -91,47 +244,132 @@ def colorize_mask(mask_2d):
91
  def synthesize_image(mask_input, source_type, model_choice):
92
  """
93
  Main Logic:
94
- 1. Get the mask (either generated or uploaded).
95
- 2. Pass it to the selected conditional model.
 
96
  """
97
- # A. Handle Input Source
98
- if source_type == "Upload Mask":
99
- if mask_input is None:
100
- return None, "Please upload a mask first."
101
- # Expecting RGB upload, need to convert to integer map?
102
- # Or if your conditional models take RGB, pass raw.
103
- # For safety, let's assume we convert upload to numpy.
104
- mask_np = np.array(mask_input)
105
- # Simple heuristic to get class IDs if uploaded is RGB
106
- if mask_np.ndim == 3:
107
- # Basic conversion (you might need your robust logic here)
108
- mask_idx = mask_np[:,:,0] // 85 # roughly maps 0-255 to 0-3
109
- else:
110
- mask_idx = mask_np
 
 
 
 
 
111
  else:
112
- # Input comes from the "Generate Mask" step (State variable)
113
- if mask_input is None:
114
- return None, "Please generate a mask first."
115
- mask_idx = mask_input
116
-
117
- # B. Run Conditional Inference
118
- # THIS IS WHERE YOU ADD YOUR CONDITIONAL GENERATION CODE
119
- generated_img = np.zeros((128, 128, 3), dtype=np.uint8) # Placeholder Black Image
 
 
 
 
 
120
 
 
121
  if model_choice == "DDPM":
122
- # model = load_conditional_model("DDPM")
123
- # noise = ...
124
- # cond = torch.tensor(mask_idx)...
125
- # generated_img = ...
126
- pass
 
 
 
 
 
 
 
 
 
 
 
 
127
  elif model_choice == "LDM":
128
- pass
129
- elif model_choice == "FM":
130
- pass
 
 
 
 
 
 
 
 
131
 
132
- # Return both the mask (for verification) and the result
133
- display_mask = colorize_mask(mask_idx) if source_type == "Generate Mask" else mask_input
134
- return display_mask, generated_img
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
  # ==========================================
137
  # 3. GRADIO UI
@@ -189,6 +427,9 @@ with gr.Blocks(title="Cardiac MRI Synthesis") as demo:
189
  final_mask, final_img = synthesize_image(gen_state, choice, model_name)
190
  else:
191
  final_mask, final_img = synthesize_image(upload_img, choice, model_name)
 
 
 
192
 
193
  return final_img
194
 
 
4
  import matplotlib.pyplot as plt
5
  from PIL import Image
6
  from monai.utils import set_determinism
7
+ from generative.networks.nets import DiffusionModelUNet, AutoencoderKL, ControlNet
8
  from generative.networks.schedulers import DDPMScheduler
9
+ from huggingface_hub import hf_hub_download
10
+ from diffusers import UNet2DModel, DDPMScheduler as DiffusersScheduler # Rename to avoid conflict
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from diffusion import VQVAE, Unet, LinearNoiseScheduler
14
 
15
  # --- CONFIGURATION ---
16
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
  MASK_MODEL_PATH = "models/mask_diffusion.pth"
18
 
19
+ # ==========================================
20
+ # Helper Functions
21
+ # ==========================================
22
def get_jet_reference_colors(num_classes=4):
    """Return the (num_classes, 3) RGB palette (0-255 ints) that the 'jet' colormap assigns to classes 0..num_classes-1."""
    cmap = plt.get_cmap('jet')
    palette = []
    for class_id in range(num_classes):
        # Sample the colormap at evenly spaced normalized positions in [0, 1].
        r, g, b, _ = cmap(class_id / (num_classes - 1))
        palette.append([int(r * 255), int(g * 255), int(b * 255)])
    return np.array(palette)
32
+
33
def rgb_mask_to_onehot(mask_np):
    """
    Convert an RGB numpy mask (H, W, 3) into a one-hot tensor (1, 4, H, W) on DEVICE.

    Each pixel is assigned the class whose jet reference color is closest
    in Euclidean RGB distance.
    """
    # Resize to the model resolution; NEAREST preserves exact class colors.
    if mask_np.shape[:2] != (128, 128):
        pil_img = Image.fromarray(mask_np.astype(np.uint8))
        mask_np = np.array(pil_img.resize((128, 128), resample=Image.NEAREST))

    # Distance from every pixel to each of the 4 reference colors: (H, W, 4)
    ref_colors = get_jet_reference_colors(4)
    dist = np.linalg.norm(mask_np[:, :, None, :] - ref_colors[None, None, :, :], axis=3)

    # Closest reference color wins -> integer label map (128, 128) in {0..3}
    label_map = np.argmin(dist, axis=2)

    # One-hot encode, channels first: (4, 128, 128)
    labels = torch.tensor(label_map, dtype=torch.long)
    onehot = F.one_hot(labels, num_classes=4).permute(2, 0, 1).float()

    # Add batch dimension -> (1, 4, 128, 128)
    return onehot.unsqueeze(0).to(DEVICE)
59
+
60
class LDMConfig:
    """Hyper-parameters for the custom latent-diffusion pipeline (latent UNet + VQVAE)."""

    def __init__(self):
        # Pixel resolution the pipeline was trained at.
        self.im_size = 128

        # Mask conditioning: 4 one-hot channels in, 1 channel out.
        condition_config = {
            'condition_types': ['image'],
            'image_condition_config': {
                'image_condition_input_channels': 4,
                'image_condition_output_channels': 1,
            }
        }

        # Latent-space UNet (operates on 4-channel latents).
        self.ldm_params = {
            'time_emb_dim': 256,
            'down_channels': [128, 256, 512],
            'mid_channels': [512, 256],
            'down_sample': [True, True],
            'attn_down': [False, True],
            'norm_channels': 32,
            'num_heads': 8,
            'conv_out_channels': 128,
            'num_down_layers': 2,
            'num_mid_layers': 2,
            'num_up_layers': 2,
            'condition_config': condition_config,
        }

        # VQVAE mapping 128x128 single-channel images <-> latents.
        self.autoencoder_params = {
            'z_channels': 4,
            'codebook_size': 8192,
            'down_channels': [64, 128, 256],
            'mid_channels': [256, 256],
            'down_sample': [True, True],
            'attn_down': [False, False],
            'norm_channels': 32,
            'num_heads': 4,
            'num_down_layers': 2,
            'num_mid_layers': 2,
            'num_up_layers': 2
        }
96
+
97
+ # DEFINITIONS FOR FLOW MATCHING
98
class MergedModel(nn.Module):
    """
    Wraps a diffusion UNet (optionally paired with a ControlNet) behind a
    flow-matching interface: continuous time t in [0, 1] is mapped onto the
    UNet's discrete timestep axis before the forward pass.
    """

    def __init__(self, unet, controlnet=None, max_timestep=1000):
        super().__init__()
        self.unet = unet
        self.controlnet = controlnet
        self.max_timestep = max_timestep
        self.has_controlnet = controlnet is not None

    def forward(self, x, t, cond=None, masks=None):
        # Map continuous t in [0, 1] to integer timesteps in [0, max_timestep - 1].
        timesteps = (t * (self.max_timestep - 1)).floor().long()
        if timesteps.dim() == 0:
            # Broadcast a scalar time across the batch.
            timesteps = timesteps.expand(x.shape[0])

        if not self.has_controlnet:
            return self.unet(x=x, timesteps=timesteps, context=cond)

        # ControlNet injects mask-conditioned residuals into the UNet.
        down_res, mid_res = self.controlnet(x=x, timesteps=timesteps,
                                            controlnet_cond=masks, context=cond)
        return self.unet(x=x, timesteps=timesteps, context=cond,
                         down_block_additional_residuals=down_res,
                         mid_block_additional_residual=mid_res)
118
+
119
  # ==========================================
120
  # 1. MODEL LOADING (Cached)
121
  # ==========================================
 
146
 
147
  # Placeholder loaders for your other models
148
def load_conditional_model(model_type):
    """
    Lazily load (and cache in the module-level `models` dict) the conditional
    generator for the given type.

    :param model_type: "DDPM", "LDM" or "FM"
    :return: the cached model entry (tuple or module), or None for an unknown type
    """
    # --- 1. DDPM (diffusers UNet + scheduler) ---
    if model_type == "DDPM" and models["ddpm"] is None:
        print("Loading DDPM (Diffusers)...")
        # Assuming the 'ddpm-150-finetuned' folder content was uploaded to 'models/ddpm'
        unet = UNet2DModel.from_pretrained("models/ddpm/unet").to(DEVICE)
        # FIX: switch to inference mode (disables dropout). The LDM and FM
        # branches already call .eval(); this branch previously did not.
        unet.eval()
        scheduler = DiffusersScheduler.from_pretrained("models/ddpm/scheduler")
        models["ddpm"] = (unet, scheduler)

    # --- 2. LDM (custom VQVAE + latent UNet) ---
    elif model_type == "LDM" and models["ldm"] is None:
        print("Loading LDM (Custom)...")
        config = LDMConfig()

        # Load VQVAE
        vqvae = VQVAE(im_channels=1, model_config=config.autoencoder_params).to(DEVICE)
        vqvae.load_state_dict(torch.load("models/vqvae.pth", map_location=DEVICE))  # Ensure filename matches
        vqvae.eval()

        # Load LDM UNet
        ldm_unet = Unet(im_channels=4, model_config=config.ldm_params).to(DEVICE)
        ldm_unet.load_state_dict(torch.load("models/ldm.pth", map_location=DEVICE))  # Ensure filename matches
        ldm_unet.eval()

        models["ldm"] = (vqvae, ldm_unet, config)

    # --- 3. Flow Matching (MONAI UNet + ControlNet) ---
    elif model_type == "FM" and models["fm"] is None:
        print("Loading Flow Matching (MONAI)...")
        # Architecture config (must match training).
        fm_config = {
            "spatial_dims": 2, "in_channels": 1, "out_channels": 1,
            "num_res_blocks": [2, 2, 2, 2], "num_channels": [32, 64, 128, 256],
            "attention_levels": [False, False, False, True], "norm_num_groups": 32,
            "resblock_updown": True, "num_head_channels": [32, 64, 128, 256],
            "transformer_num_layers": 6, "with_conditioning": True, "cross_attention_dim": 256,
        }

        # Build base UNet and mask-conditioning ControlNet, then merge.
        unet = DiffusionModelUNet(**fm_config)
        controlnet = ControlNet(
            **fm_config,
            conditioning_embedding_num_channels=(16,)
        )
        model = MergedModel(unet, controlnet).to(DEVICE)

        # Download & load weights from the Hugging Face repo.
        path = hf_hub_download(repo_id="ishanthathsara/syn_mri_flow_match", filename="flow_match_model.pt")
        checkpoint = torch.load(path, map_location=DEVICE)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()

        models["fm"] = model

    # Cache keys are lowercase ("ddpm"/"ldm"/"fm"); unknown types yield None.
    return models.get(model_type.lower())
208
 
209
  # ==========================================
 
244
def synthesize_image(mask_input, source_type, model_choice):
    """
    Generate a synthetic cardiac MRI from a segmentation mask.

    1. Builds a one-hot conditioning tensor (and an RGB preview) from the mask.
    2. Runs the chosen conditional generator (DDPM / LDM / Flow Matching).
    3. Post-processes the model output into a uint8 image.

    Returns (display_mask, image); `image` is an error string when the
    required mask is missing.
    """
    # ==========================================
    # A. HANDLE INPUT & PREPARE MASKS
    # ==========================================
    if source_type == "Generate Mask":
        # Input is an integer label map (128, 128) with values 0-3 from the
        # mask-generation step.
        if mask_input is None:
            return None, "Please generate a mask first."
        label_tensor = torch.tensor(mask_input, dtype=torch.long).to(DEVICE)
        mask_onehot = torch.nn.functional.one_hot(
            label_tensor, num_classes=4).permute(2, 0, 1).float().unsqueeze(0)
        display_mask = colorize_mask(mask_input)
    else:
        # Input is an uploaded RGB image (128, 128, 3).
        if mask_input is None:
            return None, "Please upload a mask first."
        mask_onehot = rgb_mask_to_onehot(np.array(mask_input))
        display_mask = mask_input

    # ==========================================
    # B. RUN CONDITIONAL INFERENCE
    # ==========================================
    generated_img = None

    if model_choice == "DDPM":
        unet, scheduler = load_conditional_model("DDPM")
        img = torch.randn((1, 1, 128, 128)).to(DEVICE)
        for t in scheduler.timesteps:
            # [noise (1ch) | mask (4ch)] -> 5-channel UNet input
            model_input = torch.cat([img, mask_onehot], dim=1)
            with torch.no_grad():
                noise_pred = unet(model_input, t).sample
            img = scheduler.step(noise_pred, t, img).prev_sample
        generated_img = img

    elif model_choice == "LDM":
        vqvae, ldm_unet, config = load_conditional_model("LDM")

        # Latent noise at 1/4 the pixel resolution.
        latent_dim = 128 // 4  # 32
        z = torch.randn((1, 4, latent_dim, latent_dim)).to(DEVICE)

        # Scheduler parameters must match training.
        scheduler = LinearNoiseScheduler(num_timesteps=1000, beta_start=0.00085, beta_end=0.012)
        cond_input = {'image': mask_onehot}

        # Reverse diffusion in latent space.
        for t in reversed(range(1000)):
            t_tensor = torch.tensor([t], device=DEVICE)
            with torch.no_grad():
                noise_pred = ldm_unet(z, t_tensor, cond_input=cond_input)
            # sample_prev_timestep returns a pair; index 0 is the sample.
            z = scheduler.sample_prev_timestep(z, noise_pred, t_tensor)[0]

        # Decode latents to pixels.
        with torch.no_grad():
            generated_img = vqvae.decode(z)

    elif model_choice == "Flow Matching":
        model = load_conditional_model("FM")
        x = torch.randn((1, 1, 128, 128)).to(DEVICE)

        # Simple Euler ODE solver over t in [0, 1).
        steps = 50
        dt = 1.0 / steps
        mask_float = mask_onehot.float()
        for i in range(steps):
            t = torch.tensor([i * dt], device=DEVICE)
            with torch.no_grad():
                v = model(x=x, t=t, masks=mask_float)  # predicted velocity
            x = x + v * dt  # x_next = x + v * dt
        generated_img = x

    # ==========================================
    # C. POST-PROCESSING (Tensor -> Numpy)
    # ==========================================
    if generated_img is not None:
        # Drop batch/channel dims -> (128, 128) on CPU.
        img_np = generated_img.squeeze().cpu().numpy()
        # Normalize [-1, 1] -> [0, 1] and clamp.
        # (DDPM/LDM outputs are usually -1 to 1. If FM is 0-1, this might need adjustment.)
        img_np = np.clip((img_np + 1) / 2, 0, 1)
        return display_mask, (img_np * 255).astype(np.uint8)

    # Unknown model choice: return the mask with a black placeholder image.
    return display_mask, np.zeros((128, 128, 3), dtype=np.uint8)
373
 
374
  # ==========================================
375
  # 3. GRADIO UI
 
427
  final_mask, final_img = synthesize_image(gen_state, choice, model_name)
428
  else:
429
  final_mask, final_img = synthesize_image(upload_img, choice, model_name)
430
+
431
+ if isinstance(final_img, str): # If final_img is an error message
432
+ raise gr.Error(final_img)
433
 
434
  return final_img
435
 
diffusion.py ADDED
@@ -0,0 +1,1199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from einops import einsum
5
+ import numpy as np
6
+ import pickle
7
+ import glob
8
+ import os
9
+
10
+ # ==========================================
11
+ # BLOCKS for VQVAE (Down, Mid, Up)
12
+ # ==========================================
13
def get_time_embedding(time_steps, temb_dim):
    r"""
    Sinusoidal timestep embedding.

    :param time_steps: 1D tensor of B timesteps
    :param temb_dim: embedding width (must be even)
    :return: (B, temb_dim) tensor — first half sines, second half cosines
    """
    assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2"

    half_dim = temb_dim // 2
    # Geometric frequency ladder: 10000^(i / half_dim) for i in [0, half_dim)
    factor = 10000 ** (
        torch.arange(start=0, end=half_dim, dtype=torch.float32,
                     device=time_steps.device) / half_dim
    )

    # (B,) -> (B, half_dim): each timestep divided by every frequency
    angles = time_steps[:, None].repeat(1, half_dim) / factor
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
33
+
34
+
35
class DownBlock(nn.Module):
    r"""
    Downsampling conv block with optional self/cross attention.

    Per layer: a time-conditioned ResNet block, then (optionally) a
    self-attention block and a cross-attention block. A strided conv halves
    the spatial resolution at the end when ``down_sample`` is set.
    """

    def __init__(self, in_channels, out_channels, t_emb_dim,
                 down_sample, num_heads, num_layers, attn, norm_channels, cross_attn=False, context_dim=None):
        super().__init__()
        self.num_layers = num_layers
        self.down_sample = down_sample
        self.attn = attn
        self.context_dim = context_dim
        self.cross_attn = cross_attn
        self.t_emb_dim = t_emb_dim

        def _ch_in(layer_idx):
            # First layer consumes the block input; later layers chain on out_channels.
            return in_channels if layer_idx == 0 else out_channels

        # norm -> SiLU -> 3x3 conv (first half of each ResNet block)
        self.resnet_conv_first = nn.ModuleList([
            nn.Sequential(
                nn.GroupNorm(norm_channels, _ch_in(layer_idx)),
                nn.SiLU(),
                nn.Conv2d(_ch_in(layer_idx), out_channels,
                          kernel_size=3, stride=1, padding=1),
            )
            for layer_idx in range(num_layers)
        ])
        if self.t_emb_dim is not None:
            # Projects the time embedding into the channel dimension, per layer.
            self.t_emb_layers = nn.ModuleList([
                nn.Sequential(nn.SiLU(), nn.Linear(self.t_emb_dim, out_channels))
                for _ in range(num_layers)
            ])
        # norm -> SiLU -> 3x3 conv (second half of each ResNet block)
        self.resnet_conv_second = nn.ModuleList([
            nn.Sequential(
                nn.GroupNorm(norm_channels, out_channels),
                nn.SiLU(),
                nn.Conv2d(out_channels, out_channels,
                          kernel_size=3, stride=1, padding=1),
            )
            for _ in range(num_layers)
        ])

        if self.attn:
            self.attention_norms = nn.ModuleList([
                nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)
            ])
            self.attentions = nn.ModuleList([
                nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                for _ in range(num_layers)
            ])

        if self.cross_attn:
            assert context_dim is not None, "Context Dimension must be passed for cross attention"
            self.cross_attention_norms = nn.ModuleList([
                nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)
            ])
            self.cross_attentions = nn.ModuleList([
                nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                for _ in range(num_layers)
            ])
            self.context_proj = nn.ModuleList([
                nn.Linear(context_dim, out_channels) for _ in range(num_layers)
            ])

        # 1x1 convs aligning the residual path's channel count.
        self.residual_input_conv = nn.ModuleList([
            nn.Conv2d(_ch_in(layer_idx), out_channels, kernel_size=1)
            for layer_idx in range(num_layers)
        ])
        self.down_sample_conv = (nn.Conv2d(out_channels, out_channels, 4, 2, 1)
                                 if self.down_sample else nn.Identity())

    def forward(self, x, t_emb=None, context=None):
        out = x
        for layer_idx in range(self.num_layers):
            # --- ResNet block with 1x1-conv residual path ---
            resnet_input = out
            out = self.resnet_conv_first[layer_idx](out)
            if self.t_emb_dim is not None:
                out = out + self.t_emb_layers[layer_idx](t_emb)[:, :, None, None]
            out = self.resnet_conv_second[layer_idx](out)
            out = out + self.residual_input_conv[layer_idx](resnet_input)

            # --- Self attention over flattened spatial positions ---
            if self.attn:
                batch_size, channels, h, w = out.shape
                in_attn = out.reshape(batch_size, channels, h * w)
                in_attn = self.attention_norms[layer_idx](in_attn)
                in_attn = in_attn.transpose(1, 2)
                out_attn, _ = self.attentions[layer_idx](in_attn, in_attn, in_attn)
                out = out + out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)

            # --- Cross attention against the conditioning context ---
            if self.cross_attn:
                assert context is not None, "context cannot be None if cross attention layers are used"
                batch_size, channels, h, w = out.shape
                in_attn = out.reshape(batch_size, channels, h * w)
                in_attn = self.cross_attention_norms[layer_idx](in_attn)
                in_attn = in_attn.transpose(1, 2)
                assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim
                context_proj = self.context_proj[layer_idx](context)
                out_attn, _ = self.cross_attentions[layer_idx](in_attn, context_proj, context_proj)
                out = out + out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)

        # Halve spatial resolution (identity when down_sample is False).
        out = self.down_sample_conv(out)
        return out
155
+
156
+
157
class MidBlock(nn.Module):
    r"""
    Bottleneck block.

    Sequence: ResNet block, then num_layers repetitions of
    [self attention -> optional cross attention -> ResNet block].
    """

    def __init__(self, in_channels, out_channels, t_emb_dim, num_heads, num_layers, norm_channels, cross_attn=None, context_dim=None):
        super().__init__()
        self.num_layers = num_layers
        self.t_emb_dim = t_emb_dim
        self.context_dim = context_dim
        self.cross_attn = cross_attn

        def _ch_in(idx):
            # First ResNet block consumes the block input channels.
            return in_channels if idx == 0 else out_channels

        # num_layers + 1 ResNet blocks sandwich the num_layers attention blocks.
        self.resnet_conv_first = nn.ModuleList([
            nn.Sequential(
                nn.GroupNorm(norm_channels, _ch_in(idx)),
                nn.SiLU(),
                nn.Conv2d(_ch_in(idx), out_channels, kernel_size=3, stride=1, padding=1),
            )
            for idx in range(num_layers + 1)
        ])

        if self.t_emb_dim is not None:
            self.t_emb_layers = nn.ModuleList([
                nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_channels))
                for _ in range(num_layers + 1)
            ])

        self.resnet_conv_second = nn.ModuleList([
            nn.Sequential(
                nn.GroupNorm(norm_channels, out_channels),
                nn.SiLU(),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            )
            for _ in range(num_layers + 1)
        ])

        # Self attention is always present in the mid block.
        self.attention_norms = nn.ModuleList([
            nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)
        ])
        self.attentions = nn.ModuleList([
            nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
            for _ in range(num_layers)
        ])

        if self.cross_attn:
            assert context_dim is not None, "Context Dimension must be passed for cross attention"
            self.cross_attention_norms = nn.ModuleList([
                nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)
            ])
            self.cross_attentions = nn.ModuleList([
                nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                for _ in range(num_layers)
            ])
            self.context_proj = nn.ModuleList([
                nn.Linear(context_dim, out_channels) for _ in range(num_layers)
            ])

        self.residual_input_conv = nn.ModuleList([
            nn.Conv2d(_ch_in(idx), out_channels, kernel_size=1)
            for idx in range(num_layers + 1)
        ])

    def _resnet_block(self, idx, out, t_emb):
        # One time-conditioned ResNet block with a 1x1-conv residual path.
        resnet_input = out
        out = self.resnet_conv_first[idx](out)
        if self.t_emb_dim is not None:
            out = out + self.t_emb_layers[idx](t_emb)[:, :, None, None]
        out = self.resnet_conv_second[idx](out)
        return out + self.residual_input_conv[idx](resnet_input)

    def forward(self, x, t_emb=None, context=None):
        # Opening ResNet block.
        out = self._resnet_block(0, x, t_emb)

        for idx in range(self.num_layers):
            # Self attention over flattened spatial positions.
            batch_size, channels, h, w = out.shape
            in_attn = out.reshape(batch_size, channels, h * w)
            in_attn = self.attention_norms[idx](in_attn)
            in_attn = in_attn.transpose(1, 2)
            out_attn, _ = self.attentions[idx](in_attn, in_attn, in_attn)
            out = out + out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)

            # Optional cross attention against the conditioning context.
            if self.cross_attn:
                assert context is not None, "context cannot be None if cross attention layers are used"
                batch_size, channels, h, w = out.shape
                in_attn = out.reshape(batch_size, channels, h * w)
                in_attn = self.cross_attention_norms[idx](in_attn)
                in_attn = in_attn.transpose(1, 2)
                assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim
                context_proj = self.context_proj[idx](context)
                out_attn, _ = self.cross_attentions[idx](in_attn, context_proj, context_proj)
                out = out + out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)

            # Closing ResNet block for this repetition.
            out = self._resnet_block(idx + 1, out, t_emb)

        return out
276
+
277
+
278
class UpBlock(nn.Module):
    r"""
    Upsampling conv block with optional self attention.

    Sequence: transpose-conv upsample -> concat skip connection ->
    [time-conditioned ResNet block (+ self attention)] x num_layers.
    """

    def __init__(self, in_channels, out_channels, t_emb_dim,
                 up_sample, num_heads, num_layers, attn, norm_channels):
        super().__init__()
        self.num_layers = num_layers
        self.up_sample = up_sample
        self.t_emb_dim = t_emb_dim
        self.attn = attn

        def _ch_in(idx):
            # First layer consumes the (possibly skip-concatenated) input channels.
            return in_channels if idx == 0 else out_channels

        self.resnet_conv_first = nn.ModuleList([
            nn.Sequential(
                nn.GroupNorm(norm_channels, _ch_in(idx)),
                nn.SiLU(),
                nn.Conv2d(_ch_in(idx), out_channels, kernel_size=3, stride=1, padding=1),
            )
            for idx in range(num_layers)
        ])

        if self.t_emb_dim is not None:
            self.t_emb_layers = nn.ModuleList([
                nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_channels))
                for _ in range(num_layers)
            ])

        self.resnet_conv_second = nn.ModuleList([
            nn.Sequential(
                nn.GroupNorm(norm_channels, out_channels),
                nn.SiLU(),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            )
            for _ in range(num_layers)
        ])

        if self.attn:
            self.attention_norms = nn.ModuleList([
                nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)
            ])
            self.attentions = nn.ModuleList([
                nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                for _ in range(num_layers)
            ])

        self.residual_input_conv = nn.ModuleList([
            nn.Conv2d(_ch_in(idx), out_channels, kernel_size=1)
            for idx in range(num_layers)
        ])
        # Transpose conv doubles the spatial resolution (identity when disabled).
        self.up_sample_conv = (nn.ConvTranspose2d(in_channels, in_channels, 4, 2, 1)
                               if self.up_sample else nn.Identity())

    def forward(self, x, out_down=None, t_emb=None):
        # Upsample, then fuse with the matching down-path feature map.
        x = self.up_sample_conv(x)
        if out_down is not None:
            x = torch.cat([x, out_down], dim=1)

        out = x
        for idx in range(self.num_layers):
            # ResNet block with 1x1-conv residual path.
            resnet_input = out
            out = self.resnet_conv_first[idx](out)
            if self.t_emb_dim is not None:
                out = out + self.t_emb_layers[idx](t_emb)[:, :, None, None]
            out = self.resnet_conv_second[idx](out)
            out = out + self.residual_input_conv[idx](resnet_input)

            # Self attention over flattened spatial positions.
            if self.attn:
                batch_size, channels, h, w = out.shape
                in_attn = out.reshape(batch_size, channels, h * w)
                in_attn = self.attention_norms[idx](in_attn)
                in_attn = in_attn.transpose(1, 2)
                out_attn, _ = self.attentions[idx](in_attn, in_attn, in_attn)
                out = out + out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)

        return out
379
+
380
+
381
class UpBlockUnet(nn.Module):
    r"""
    Decoder (up-sampling) block with attention.

    Per forward call:
    1. Upsample the input.
    2. Concatenate the matching encoder (down-block) output.
    3. For each layer: ResNet block with optional time embedding,
       self-attention, then optional cross-attention against a context.
    """

    def __init__(self, in_channels, out_channels, t_emb_dim, up_sample,
                 num_heads, num_layers, norm_channels, cross_attn=False, context_dim=None):
        super().__init__()
        self.num_layers = num_layers
        self.up_sample = up_sample
        self.t_emb_dim = t_emb_dim
        self.cross_attn = cross_attn
        self.context_dim = context_dim

        # First half of each resnet sub-block: norm -> SiLU -> 3x3 conv.
        self.resnet_conv_first = nn.ModuleList([
            nn.Sequential(
                nn.GroupNorm(norm_channels, in_channels if layer == 0 else out_channels),
                nn.SiLU(),
                nn.Conv2d(in_channels if layer == 0 else out_channels, out_channels,
                          kernel_size=3, stride=1, padding=1),
            )
            for layer in range(num_layers)
        ])

        if self.t_emb_dim is not None:
            # Project the time embedding into the feature-channel dimension.
            self.t_emb_layers = nn.ModuleList([
                nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_channels))
                for _ in range(num_layers)
            ])

        self.resnet_conv_second = nn.ModuleList([
            nn.Sequential(
                nn.GroupNorm(norm_channels, out_channels),
                nn.SiLU(),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            )
            for _ in range(num_layers)
        ])

        self.attention_norms = nn.ModuleList(
            [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)]
        )
        self.attentions = nn.ModuleList(
            [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
             for _ in range(num_layers)]
        )

        if self.cross_attn:
            assert context_dim is not None, "Context Dimension must be passed for cross attention"
            self.cross_attention_norms = nn.ModuleList(
                [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)]
            )
            self.cross_attentions = nn.ModuleList(
                [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                 for _ in range(num_layers)]
            )
            self.context_proj = nn.ModuleList(
                [nn.Linear(context_dim, out_channels) for _ in range(num_layers)]
            )

        # 1x1 convs implementing each resnet sub-block's residual shortcut.
        self.residual_input_conv = nn.ModuleList([
            nn.Conv2d(in_channels if layer == 0 else out_channels, out_channels, kernel_size=1)
            for layer in range(num_layers)
        ])
        # Upsampling acts on in_channels // 2 because the skip connection is
        # concatenated only AFTER the upsample.
        self.up_sample_conv = (
            nn.ConvTranspose2d(in_channels // 2, in_channels // 2, 4, 2, 1)
            if self.up_sample else nn.Identity()
        )

    def forward(self, x, out_down=None, t_emb=None, context=None):
        x = self.up_sample_conv(x)
        if out_down is not None:
            # Concatenate encoder skip features along the channel axis.
            x = torch.cat([x, out_down], dim=1)

        out = x
        for layer in range(self.num_layers):
            # --- Resnet block ---
            residual = out
            out = self.resnet_conv_first[layer](out)
            if self.t_emb_dim is not None:
                out = out + self.t_emb_layers[layer](t_emb)[:, :, None, None]
            out = self.resnet_conv_second[layer](out)
            out = out + self.residual_input_conv[layer](residual)

            # --- Self attention over flattened spatial positions ---
            batch_size, channels, h, w = out.shape
            attn_in = out.reshape(batch_size, channels, h * w)
            attn_in = self.attention_norms[layer](attn_in)
            attn_in = attn_in.transpose(1, 2)
            attn_out, _ = self.attentions[layer](attn_in, attn_in, attn_in)
            attn_out = attn_out.transpose(1, 2).reshape(batch_size, channels, h, w)
            out = out + attn_out

            # --- Cross attention against external context ---
            if self.cross_attn:
                assert context is not None, "context cannot be None if cross attention layers are used"
                batch_size, channels, h, w = out.shape
                attn_in = out.reshape(batch_size, channels, h * w)
                attn_in = self.cross_attention_norms[layer](attn_in)
                attn_in = attn_in.transpose(1, 2)
                assert len(context.shape) == 3, \
                    "Context shape does not match B,_,CONTEXT_DIM"
                assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim, \
                    "Context shape does not match B,_,CONTEXT_DIM"
                ctx = self.context_proj[layer](context)
                attn_out, _ = self.cross_attentions[layer](attn_in, ctx, ctx)
                attn_out = attn_out.transpose(1, 2).reshape(batch_size, channels, h, w)
                out = out + attn_out
        return out
507
+
508
+ # ==========================================
509
+ # VQVAE Definition
510
+ # ==========================================
511
class VQVAE(nn.Module):
    """Vector-Quantized VAE used as the LDM autoencoder.

    Encoder: conv stem -> DownBlocks -> MidBlocks -> norm/conv -> pre-quant conv
    -> nearest-codebook quantization.  Decoder mirrors the encoder.

    Fixes vs. original: `commmitment_loss` local variable typo corrected
    (the returned dict key 'commitment_loss' is unchanged), and unused
    `enumerate` indexes removed from encode/decode loops.
    """

    def __init__(self, im_channels, model_config):
        super().__init__()
        self.down_channels = model_config['down_channels']
        self.mid_channels = model_config['mid_channels']
        self.down_sample = model_config['down_sample']
        self.num_down_layers = model_config['num_down_layers']
        self.num_mid_layers = model_config['num_mid_layers']
        self.num_up_layers = model_config['num_up_layers']

        # Per-resolution switches to disable attention in the encoder
        # DownBlocks and decoder UpBlocks.
        self.attns = model_config['attn_down']

        # Latent dimension and codebook settings.
        self.z_channels = model_config['z_channels']
        self.codebook_size = model_config['codebook_size']
        self.norm_channels = model_config['norm_channels']
        self.num_heads = model_config['num_heads']

        # Validate the channel configuration.
        assert self.mid_channels[0] == self.down_channels[-1]
        assert self.mid_channels[-1] == self.down_channels[-1]
        assert len(self.down_sample) == len(self.down_channels) - 1
        assert len(self.attns) == len(self.down_channels) - 1

        # Wherever the encoder downsamples, the decoder must upsample.
        self.up_sample = list(reversed(self.down_sample))

        ## Encoder ##
        self.encoder_conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=(1, 1))

        # Downblocks + Midblocks.
        self.encoder_layers = nn.ModuleList([])
        for i in range(len(self.down_channels) - 1):
            self.encoder_layers.append(DownBlock(self.down_channels[i], self.down_channels[i + 1],
                                                 t_emb_dim=None, down_sample=self.down_sample[i],
                                                 num_heads=self.num_heads,
                                                 num_layers=self.num_down_layers,
                                                 attn=self.attns[i],
                                                 norm_channels=self.norm_channels))

        self.encoder_mids = nn.ModuleList([])
        for i in range(len(self.mid_channels) - 1):
            self.encoder_mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i + 1],
                                              t_emb_dim=None,
                                              num_heads=self.num_heads,
                                              num_layers=self.num_mid_layers,
                                              norm_channels=self.norm_channels))

        self.encoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[-1])
        self.encoder_conv_out = nn.Conv2d(self.down_channels[-1], self.z_channels, kernel_size=3, padding=1)

        # Pre-quantization convolution.
        self.pre_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1)

        # Codebook of discrete latent embeddings.
        self.embedding = nn.Embedding(self.codebook_size, self.z_channels)

        ## Decoder ##
        # Post-quantization convolution.
        self.post_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1)
        self.decoder_conv_in = nn.Conv2d(self.z_channels, self.mid_channels[-1], kernel_size=3, padding=(1, 1))

        # Midblocks + Upblocks (mirrored order).
        self.decoder_mids = nn.ModuleList([])
        for i in reversed(range(1, len(self.mid_channels))):
            self.decoder_mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i - 1],
                                              t_emb_dim=None,
                                              num_heads=self.num_heads,
                                              num_layers=self.num_mid_layers,
                                              norm_channels=self.norm_channels))

        self.decoder_layers = nn.ModuleList([])
        for i in reversed(range(1, len(self.down_channels))):
            self.decoder_layers.append(UpBlock(self.down_channels[i], self.down_channels[i - 1],
                                               t_emb_dim=None, up_sample=self.down_sample[i - 1],
                                               num_heads=self.num_heads,
                                               num_layers=self.num_up_layers,
                                               attn=self.attns[i - 1],
                                               norm_channels=self.norm_channels))

        self.decoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[0])
        self.decoder_conv_out = nn.Conv2d(self.down_channels[0], im_channels, kernel_size=3, padding=1)

    def quantize(self, x):
        """Nearest-codebook quantization with a straight-through estimator.

        Args:
            x: encoder output of shape (B, C, H, W) with C == z_channels.
        Returns:
            (quantized latents (B, C, H, W),
             dict with 'codebook_loss' and 'commitment_loss',
             codebook indices per spatial position).
        """
        B, C, H, W = x.shape

        # (B, C, H, W) -> (B, H, W, C) -> (B, H*W, C)
        x = x.permute(0, 2, 3, 1)
        x = x.reshape(x.size(0), -1, x.size(-1))

        # Distances between (B, H*W, C) and the (K, C) codebook -> (B, H*W, K)
        dist = torch.cdist(x, self.embedding.weight[None, :].repeat((x.size(0), 1, 1)))
        # Index of the nearest codebook vector per position: (B, H*W)
        min_encoding_indices = torch.argmin(dist, dim=-1)

        # Replace encoder outputs with their nearest codebook vectors: (B*H*W, C)
        quant_out = torch.index_select(self.embedding.weight, 0, min_encoding_indices.view(-1))

        # Flatten x to match: (B*H*W, C)
        x = x.reshape((-1, x.size(-1)))
        commitment_loss = torch.mean((quant_out.detach() - x) ** 2)
        codebook_loss = torch.mean((quant_out - x.detach()) ** 2)
        quantize_losses = {
            'codebook_loss': codebook_loss,
            'commitment_loss': commitment_loss
        }
        # Straight-through estimator: gradients flow to x, values come from codebook.
        quant_out = x + (quant_out - x).detach()

        # Back to (B, C, H, W).
        quant_out = quant_out.reshape((B, H, W, C)).permute(0, 3, 1, 2)
        min_encoding_indices = min_encoding_indices.reshape((-1, quant_out.size(-2), quant_out.size(-1)))
        return quant_out, quantize_losses, min_encoding_indices

    def encode(self, x):
        """Encode an image to its quantized latent; returns (z_q, quantize losses)."""
        out = self.encoder_conv_in(x)
        for down in self.encoder_layers:
            out = down(out)
        for mid in self.encoder_mids:
            out = mid(out)
        out = self.encoder_norm_out(out)
        out = nn.SiLU()(out)
        out = self.encoder_conv_out(out)
        out = self.pre_quant_conv(out)
        out, quant_losses, _ = self.quantize(out)
        return out, quant_losses

    def decode(self, z):
        """Decode a (quantized) latent back to image space."""
        out = self.post_quant_conv(z)
        out = self.decoder_conv_in(out)
        for mid in self.decoder_mids:
            out = mid(out)
        for up in self.decoder_layers:
            out = up(out)

        out = self.decoder_norm_out(out)
        out = nn.SiLU()(out)
        out = self.decoder_conv_out(out)
        return out

    def forward(self, x):
        """Full autoencode pass: returns (reconstruction, latent, quantize losses)."""
        z, quant_losses = self.encode(x)
        out = self.decode(z)
        return out, z, quant_losses
662
+
663
+
664
+ # ==========================================
665
+ # SPADE Definitions
666
+ # ==========================================
667
+
668
class SPADE(nn.Module):
    """Spatially-adaptive normalization: modulates a parameter-free GroupNorm
    with per-pixel gamma/beta predicted from a segmentation map."""

    def __init__(self, norm_nc, label_nc):
        super().__init__()
        self.param_free_norm = nn.GroupNorm(32, norm_nc)
        hidden_ch = 128

        # Shared trunk over the mask, plus separate scale/shift heads.
        self.mlp_shared = nn.Sequential(
            nn.Conv2d(label_nc, hidden_ch, kernel_size=3, padding=1),
            nn.ReLU()
        )
        self.mlp_gamma = nn.Conv2d(hidden_ch, norm_nc, kernel_size=3, padding=1)
        self.mlp_beta = nn.Conv2d(hidden_ch, norm_nc, kernel_size=3, padding=1)

    def forward(self, x, segmap):
        # 1. Parameter-free normalization.
        normalized = self.param_free_norm(x)

        # 2. Resize the mask to x's spatial resolution when they differ.
        if segmap.size()[2:] != x.size()[2:]:
            segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')

        # 3. Predict modulation parameters from the mask.
        shared = self.mlp_shared(segmap)
        gamma = self.mlp_gamma(shared)
        beta = self.mlp_beta(shared)

        # 4. Modulate: scale then shift.
        return normalized * (1 + gamma) + beta
698
+
699
class SPADEResnetBlock(nn.Module):
    """Single SPADE-conditioned conv step: SPADE norm -> SiLU -> 3x3 conv.

    No internal shortcut is kept here: the enclosing SpadeDown/Mid/Up blocks
    provide the residual connection around pairs of these blocks.
    """

    def __init__(self, in_channels, out_channels, label_nc):
        super().__init__()
        # Mask-conditioned normalization.
        self.norm1 = SPADE(in_channels, label_nc)
        self.act1 = nn.SiLU()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x, segmap):
        # SPADE norm -> activation -> convolution, as one chained expression.
        return self.conv1(self.act1(self.norm1(x, segmap)))
719
+
720
+ # ==========================================
721
+ # BLOCKS (Down, Mid, Up)
722
+ # ==========================================
723
+
724
def get_time_embedding(time_steps, temb_dim):
    """Sinusoidal timestep embedding: (B,) int tensor -> (B, temb_dim) floats."""
    assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2"
    half = temb_dim // 2
    # Transformer-style geometric frequency spacing: 10000^(i / half).
    freqs = 10000 ** (torch.arange(
        start=0, end=half, dtype=torch.float32, device=time_steps.device) / half)
    args = time_steps[:, None].repeat(1, half) / freqs
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
732
+
733
+
734
class SpadeDownBlock(nn.Module):
    """Encoder block whose resnet convs are SPADE-conditioned on a segmentation map.

    Per layer: two SPADE resnet convs (time embedding added between them),
    optional self-attention, optional cross-attention; finally an optional
    2x strided-conv downsample.
    """

    def __init__(self, in_channels, out_channels, t_emb_dim, down_sample, num_heads,
                 num_layers, attn, norm_channels, cross_attn=False, context_dim=None, label_nc=4):
        super().__init__()
        self.num_layers = num_layers
        self.down_sample = down_sample
        self.attn = attn
        self.context_dim = context_dim
        self.cross_attn = cross_attn
        self.t_emb_dim = t_emb_dim

        # SPADE-conditioned replacement for the usual norm/act/conv stack.
        self.resnet_conv_first = nn.ModuleList([
            SPADEResnetBlock(in_channels if layer == 0 else out_channels, out_channels, label_nc)
            for layer in range(num_layers)
        ])

        if self.t_emb_dim is not None:
            self.t_emb_layers = nn.ModuleList([
                nn.Sequential(nn.SiLU(), nn.Linear(self.t_emb_dim, out_channels))
                for _ in range(num_layers)
            ])

        self.resnet_conv_second = nn.ModuleList([
            SPADEResnetBlock(out_channels, out_channels, label_nc)
            for _ in range(num_layers)
        ])

        if self.attn:
            self.attention_norms = nn.ModuleList(
                [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)])
            self.attentions = nn.ModuleList(
                [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                 for _ in range(num_layers)])

        if self.cross_attn:
            self.cross_attention_norms = nn.ModuleList(
                [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)])
            self.cross_attentions = nn.ModuleList(
                [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                 for _ in range(num_layers)])
            self.context_proj = nn.ModuleList(
                [nn.Linear(context_dim, out_channels) for _ in range(num_layers)])

        # 1x1 convs for the residual shortcut around each resnet pair.
        self.residual_input_conv = nn.ModuleList([
            nn.Conv2d(in_channels if layer == 0 else out_channels, out_channels, kernel_size=1)
            for layer in range(num_layers)
        ])
        self.down_sample_conv = (nn.Conv2d(out_channels, out_channels, 4, 2, 1)
                                 if self.down_sample else nn.Identity())

    def forward(self, x, t_emb=None, context=None, segmap=None):
        out = x
        for layer in range(self.num_layers):
            residual = out

            # SPADE resnet conv 1 (mask injected).
            out = self.resnet_conv_first[layer](out, segmap)
            if self.t_emb_dim is not None:
                out = out + self.t_emb_layers[layer](t_emb)[:, :, None, None]
            # SPADE resnet conv 2 (mask injected).
            out = self.resnet_conv_second[layer](out, segmap)
            # Residual shortcut around the whole resnet pair.
            out = out + self.residual_input_conv[layer](residual)

            if self.attn:
                bs, ch, h, w = out.shape
                attn_in = out.reshape(bs, ch, h * w)
                attn_in = self.attention_norms[layer](attn_in)
                attn_in = attn_in.transpose(1, 2)
                attn_out, _ = self.attentions[layer](attn_in, attn_in, attn_in)
                out = out + attn_out.transpose(1, 2).reshape(bs, ch, h, w)

            if self.cross_attn:
                bs, ch, h, w = out.shape
                attn_in = out.reshape(bs, ch, h * w)
                attn_in = self.cross_attention_norms[layer](attn_in)
                attn_in = attn_in.transpose(1, 2)
                ctx = self.context_proj[layer](context)
                attn_out, _ = self.cross_attentions[layer](attn_in, ctx, ctx)
                out = out + attn_out.transpose(1, 2).reshape(bs, ch, h, w)

        out = self.down_sample_conv(out)
        return out
817
+
818
+
819
class SpadeMidBlock(nn.Module):
    """Bottleneck block with SPADE conditioning.

    Structure: one leading SPADE resnet pair, then num_layers iterations of
    (self-attention -> optional cross-attention -> SPADE resnet pair).
    """

    def __init__(self, in_channels, out_channels, t_emb_dim, num_heads, num_layers, norm_channels, cross_attn=None, context_dim=None, label_nc=4):
        super().__init__()
        self.num_layers = num_layers
        self.t_emb_dim = t_emb_dim
        self.context_dim = context_dim
        self.cross_attn = cross_attn

        # num_layers + 1 resnet pairs: one before the attention loop,
        # one after each attention step.
        self.resnet_conv_first = nn.ModuleList([
            SPADEResnetBlock(in_channels if idx == 0 else out_channels, out_channels, label_nc)
            for idx in range(num_layers + 1)
        ])

        if self.t_emb_dim is not None:
            self.t_emb_layers = nn.ModuleList([
                nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_channels))
                for _ in range(num_layers + 1)
            ])

        self.resnet_conv_second = nn.ModuleList([
            SPADEResnetBlock(out_channels, out_channels, label_nc)
            for _ in range(num_layers + 1)
        ])

        self.attention_norms = nn.ModuleList(
            [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)])
        self.attentions = nn.ModuleList(
            [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
             for _ in range(num_layers)])

        if self.cross_attn:
            self.cross_attention_norms = nn.ModuleList(
                [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)])
            self.cross_attentions = nn.ModuleList(
                [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                 for _ in range(num_layers)])
            self.context_proj = nn.ModuleList(
                [nn.Linear(context_dim, out_channels) for _ in range(num_layers)])

        self.residual_input_conv = nn.ModuleList([
            nn.Conv2d(in_channels if idx == 0 else out_channels, out_channels, kernel_size=1)
            for idx in range(num_layers + 1)
        ])

    def forward(self, x, t_emb=None, context=None, segmap=None):
        out = x

        # Leading resnet pair (no attention before it).
        residual = out
        out = self.resnet_conv_first[0](out, segmap)
        if self.t_emb_dim is not None:
            out = out + self.t_emb_layers[0](t_emb)[:, :, None, None]
        out = self.resnet_conv_second[0](out, segmap)
        out = out + self.residual_input_conv[0](residual)

        for idx in range(self.num_layers):
            # Self attention over flattened spatial positions.
            bs, ch, h, w = out.shape
            attn_in = out.reshape(bs, ch, h * w)
            attn_in = self.attention_norms[idx](attn_in)
            attn_in = attn_in.transpose(1, 2)
            attn_out, _ = self.attentions[idx](attn_in, attn_in, attn_in)
            out = out + attn_out.transpose(1, 2).reshape(bs, ch, h, w)

            if self.cross_attn:
                bs, ch, h, w = out.shape
                attn_in = out.reshape(bs, ch, h * w)
                attn_in = self.cross_attention_norms[idx](attn_in)
                attn_in = attn_in.transpose(1, 2)
                ctx = self.context_proj[idx](context)
                attn_out, _ = self.cross_attentions[idx](attn_in, ctx, ctx)
                out = out + attn_out.transpose(1, 2).reshape(bs, ch, h, w)

            # Trailing resnet pair for this iteration.
            residual = out
            out = self.resnet_conv_first[idx + 1](out, segmap)
            if self.t_emb_dim is not None:
                out = out + self.t_emb_layers[idx + 1](t_emb)[:, :, None, None]
            out = self.resnet_conv_second[idx + 1](out, segmap)
            out = out + self.residual_input_conv[idx + 1](residual)

        return out
898
+
899
+
900
class SpadeUpBlock(nn.Module):
    """Decoder block with SPADE conditioning.

    Per forward call: upsample -> concatenate encoder skip -> for each layer:
    SPADE resnet pair (time embedding in between), self-attention (always on),
    then optional cross-attention.
    """

    def __init__(self, in_channels, out_channels, t_emb_dim, up_sample, num_heads,
                 num_layers, norm_channels, cross_attn=False, context_dim=None, label_nc=4):
        super().__init__()
        self.num_layers = num_layers
        self.up_sample = up_sample
        self.t_emb_dim = t_emb_dim
        self.cross_attn = cross_attn
        self.context_dim = context_dim

        # SPADE-conditioned resnet convs.
        self.resnet_conv_first = nn.ModuleList([
            SPADEResnetBlock(in_channels if layer == 0 else out_channels, out_channels, label_nc)
            for layer in range(num_layers)
        ])

        if self.t_emb_dim is not None:
            self.t_emb_layers = nn.ModuleList([
                nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_channels))
                for _ in range(num_layers)
            ])

        self.resnet_conv_second = nn.ModuleList([
            SPADEResnetBlock(out_channels, out_channels, label_nc)
            for _ in range(num_layers)
        ])

        self.attention_norms = nn.ModuleList(
            [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)])
        self.attentions = nn.ModuleList(
            [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
             for _ in range(num_layers)])

        if self.cross_attn:
            self.cross_attention_norms = nn.ModuleList(
                [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)])
            self.cross_attentions = nn.ModuleList(
                [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                 for _ in range(num_layers)])
            self.context_proj = nn.ModuleList(
                [nn.Linear(context_dim, out_channels) for _ in range(num_layers)])

        self.residual_input_conv = nn.ModuleList([
            nn.Conv2d(in_channels if layer == 0 else out_channels, out_channels, kernel_size=1)
            for layer in range(num_layers)
        ])
        # Upsampling acts on in_channels // 2: the skip is concatenated after it.
        self.up_sample_conv = (nn.ConvTranspose2d(in_channels // 2, in_channels // 2, 4, 2, 1)
                               if self.up_sample else nn.Identity())

    def forward(self, x, out_down=None, t_emb=None, context=None, segmap=None):
        x = self.up_sample_conv(x)
        if out_down is not None:
            # Concatenate encoder skip features along the channel axis.
            x = torch.cat([x, out_down], dim=1)

        out = x
        for layer in range(self.num_layers):
            residual = out
            # SPADE resnet conv 1 (mask injected).
            out = self.resnet_conv_first[layer](out, segmap)
            if self.t_emb_dim is not None:
                out = out + self.t_emb_layers[layer](t_emb)[:, :, None, None]
            # SPADE resnet conv 2 (mask injected).
            out = self.resnet_conv_second[layer](out, segmap)
            out = out + self.residual_input_conv[layer](residual)

            # Self attention (always applied in the up path).
            bs, ch, h, w = out.shape
            attn_in = out.reshape(bs, ch, h * w)
            attn_in = self.attention_norms[layer](attn_in)
            attn_in = attn_in.transpose(1, 2)
            attn_out, _ = self.attentions[layer](attn_in, attn_in, attn_in)
            out = out + attn_out.transpose(1, 2).reshape(bs, ch, h, w)

            if self.cross_attn:
                bs, ch, h, w = out.shape
                attn_in = out.reshape(bs, ch, h * w)
                attn_in = self.cross_attention_norms[layer](attn_in)
                attn_in = attn_in.transpose(1, 2)
                ctx = self.context_proj[layer](context)
                attn_out, _ = self.cross_attentions[layer](attn_in, ctx, ctx)
                out = out + attn_out.transpose(1, 2).reshape(bs, ch, h, w)

        return out
977
+
978
+ # ==========================================
979
 + # Helper Functions
980
+ # ==========================================
981
+
982
def validate_image_config(condition_config):
    """Assert that image conditioning is configured with in/out channel counts."""
    assert 'image_condition_config' in condition_config, "Image conditioning desired but config missing"
    image_cfg = condition_config['image_condition_config']
    assert 'image_condition_input_channels' in image_cfg, "Input channels missing"
    assert 'image_condition_output_channels' in image_cfg, "Output channels missing"
986
+
987
def validate_image_conditional_input(cond_input, x):
    """Assert that an image condition is present and batch-aligned with x."""
    assert 'image' in cond_input, "Model initialized with image conditioning but input missing"
    assert cond_input['image'].shape[0] == x.shape[0], "Batch size mismatch"
990
+
991
def get_config_value(config, key, default_value):
    """Return config[key] when present, otherwise default_value.

    Uses dict.get, which is exactly equivalent to the
    `config[key] if key in config else default_value` form but performs
    a single lookup.
    """
    return config.get(key, default_value)
993
+
994
def get_time_embedding(time_steps, temb_dim):
    """Sinusoidal timestep embedding: (B,) int tensor -> (B, temb_dim) floats.

    NOTE(review): this is an exact duplicate of the get_time_embedding defined
    earlier in this file; consider removing one of the two definitions.
    """
    assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2"
    half = temb_dim // 2
    # Geometric frequency spacing, transformer-style: 10000^(i / half).
    freqs = 10000 ** (torch.arange(
        start=0, end=half, dtype=torch.float32, device=time_steps.device) / half)
    args = time_steps[:, None].repeat(1, half) / freqs
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
1002
+
1003
def drop_image_condition(image_condition, im, im_drop_prob):
    """Classifier-free-guidance-style dropout of the conditioning image.

    Zeroes the conditioning for a random subset of the batch, each sample
    independently dropped with probability im_drop_prob. With a
    non-positive probability the condition is returned unchanged.
    """
    if im_drop_prob <= 0:
        return image_condition
    # Per-sample keep mask, broadcast over (C, H, W).
    keep_mask = torch.zeros((im.shape[0], 1, 1, 1), device=im.device).float().uniform_(0, 1) > im_drop_prob
    return image_condition * keep_mask
1009
+
1010
+ # ==========================================
1011
+ # UNET Definition
1012
+ # ==========================================
1013
class Unet(nn.Module):
    """Latent-space UNet denoiser with SPADE mask conditioning.

    The segmentation mask is injected into every Down/Mid/Up block through
    SPADE normalization (segmap argument) instead of being concatenated
    with the latent at the input.
    """

    def __init__(self, im_channels, model_config):
        super().__init__()
        self.down_channels = model_config['down_channels']
        self.mid_channels = model_config['mid_channels']
        self.t_emb_dim = model_config['time_emb_dim']
        self.down_sample = model_config['down_sample']
        self.num_down_layers = model_config['num_down_layers']
        self.num_mid_layers = model_config['num_mid_layers']
        self.num_up_layers = model_config['num_up_layers']
        self.attns = model_config['attn_down']
        self.norm_channels = model_config['norm_channels']
        self.num_heads = model_config['num_heads']
        self.conv_out_channels = model_config['conv_out_channels']

        # Channel configuration must line up between encoder and bottleneck.
        assert self.mid_channels[0] == self.down_channels[-1]
        assert self.mid_channels[-1] == self.down_channels[-2]
        assert len(self.down_sample) == len(self.down_channels) - 1
        assert len(self.attns) == len(self.down_channels) - 1

        # Conditioning setup.
        self.image_cond = False
        self.condition_config = get_config_value(model_config, 'condition_config', None)

        # Default mask channels (usually 4: BG, LV, Myo, RV).
        self.im_cond_input_ch = 4

        if self.condition_config is not None:
            if 'image' in self.condition_config.get('condition_types', []):
                self.image_cond = True
                self.im_cond_input_ch = self.condition_config['image_condition_config']['image_condition_input_channels']
                self.im_cond_output_ch = self.condition_config['image_condition_config']['image_condition_output_channels']

        # Standard stem: SPADE injects the mask later, so only the latent enters here.
        self.conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=1)

        # Two-layer MLP applied to the sinusoidal time embedding.
        self.t_proj = nn.Sequential(
            nn.Linear(self.t_emb_dim, self.t_emb_dim), nn.SiLU(), nn.Linear(self.t_emb_dim, self.t_emb_dim)
        )

        self.up_sample = list(reversed(self.down_sample))

        self.downs = nn.ModuleList([])
        for i in range(len(self.down_channels) - 1):
            self.downs.append(SpadeDownBlock(
                self.down_channels[i], self.down_channels[i + 1], self.t_emb_dim,
                down_sample=self.down_sample[i], num_heads=self.num_heads,
                num_layers=self.num_down_layers, attn=self.attns[i],
                norm_channels=self.norm_channels,
                label_nc=self.im_cond_input_ch  # SPADE needs the mask channel count
            ))

        self.mids = nn.ModuleList([])
        for i in range(len(self.mid_channels) - 1):
            self.mids.append(SpadeMidBlock(
                self.mid_channels[i], self.mid_channels[i + 1], self.t_emb_dim,
                num_heads=self.num_heads, num_layers=self.num_mid_layers,
                norm_channels=self.norm_channels,
                label_nc=self.im_cond_input_ch
            ))

        self.ups = nn.ModuleList([])
        for i in reversed(range(len(self.down_channels) - 1)):
            # Input is doubled by the skip-connection concat inside SpadeUpBlock.
            self.ups.append(SpadeUpBlock(
                self.down_channels[i] * 2, self.down_channels[i - 1] if i != 0 else self.conv_out_channels,
                self.t_emb_dim, up_sample=self.down_sample[i], num_heads=self.num_heads,
                num_layers=self.num_up_layers, norm_channels=self.norm_channels,
                label_nc=self.im_cond_input_ch
            ))

        self.norm_out = nn.GroupNorm(self.norm_channels, self.conv_out_channels)
        self.conv_out = nn.Conv2d(self.conv_out_channels, im_channels, kernel_size=3, padding=1)

    def forward(self, x, t, cond_input=None):
        # 1. Validate / fetch the conditioning mask. It is NOT concatenated
        #    here — the blocks consume it via SPADE.
        if self.image_cond:
            validate_image_conditional_input(cond_input, x)
            im_cond = cond_input['image']
        else:
            im_cond = None

        # 2. Stem conv on the latent.
        out = self.conv_in(x)

        # 3. Time embedding.
        t_emb = get_time_embedding(torch.as_tensor(t).long(), self.t_emb_dim)
        t_emb = self.t_proj(t_emb)

        # 4. Encoder path (mask injected into every block).
        down_outs = []
        for down in self.downs:
            down_outs.append(out)
            out = down(out, t_emb, segmap=im_cond)

        # 5. Bottleneck.
        for mid in self.mids:
            out = mid(out, t_emb, segmap=im_cond)

        # 6. Decoder path with skip connections.
        for up in self.ups:
            skip = down_outs.pop()
            out = up(out, skip, t_emb, segmap=im_cond)

        out = self.norm_out(out)
        out = nn.SiLU()(out)
        out = self.conv_out(out)
        return out
1130
+
1131
+ # ==========================================
1132
 + # Noise Scheduler Definition
1133
+ # ==========================================
1134
class LinearNoiseScheduler:
    """DDPM noise scheduler with betas linearly spaced in sqrt-space.

    Precomputes alphas and cumulative products for the forward process
    q(x_t | x_0) and the reverse sampling step.

    Fixes vs. original: the large block of commented-out dead code after
    sample_prev_timestep has been removed.
    """

    def __init__(self, num_timesteps, beta_start, beta_end):
        self.num_timesteps = num_timesteps
        self.beta_start = beta_start
        self.beta_end = beta_end
        # Interpolate linearly in sqrt(beta) space, then square back.
        self.betas = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_timesteps) ** 2
        self.alphas = 1. - self.betas
        self.alpha_cum_prod = torch.cumprod(self.alphas, dim=0)
        self.sqrt_alpha_cum_prod = torch.sqrt(self.alpha_cum_prod)
        self.sqrt_one_minus_alpha_cum_prod = torch.sqrt(1 - self.alpha_cum_prod)

    def add_noise(self, original, noise, t):
        """Forward diffusion q(x_t | x_0).

        Args:
            original: clean sample x_0, shape (B, ...).
            noise: standard Gaussian noise with the same shape.
            t: per-sample timestep indices, shape (B,).
        Returns:
            sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise.
        """
        original_shape = original.shape
        batch_size = original_shape[0]
        sqrt_alpha_cum_prod = self.sqrt_alpha_cum_prod.to(original.device)[t].reshape(batch_size)
        sqrt_one_minus_alpha_cum_prod = self.sqrt_one_minus_alpha_cum_prod.to(original.device)[t].reshape(batch_size)

        # Broadcast the per-sample scalars over the remaining dimensions.
        for _ in range(len(original_shape) - 1):
            sqrt_alpha_cum_prod = sqrt_alpha_cum_prod.unsqueeze(-1)
            sqrt_one_minus_alpha_cum_prod = sqrt_one_minus_alpha_cum_prod.unsqueeze(-1)

        return sqrt_alpha_cum_prod * original + sqrt_one_minus_alpha_cum_prod * noise

    def sample_prev_timestep(self, xt, noise_pred, t):
        """Reverse diffusion step: sample x_{t-1} from x_t and predicted noise.

        Args:
            xt: noisy sample at step t, shape (B, C, H, W).
            noise_pred: model's noise estimate, same shape as xt.
            t: timestep indices, shape (B,) (all equal during sampling).
        Returns:
            (x_{t-1} sample, clamped x_0 estimate).
        """
        sqrt_one_minus_alpha_bar = self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t].view(-1, 1, 1, 1)
        sqrt_alpha_bar = self.sqrt_alpha_cum_prod.to(xt.device)[t].view(-1, 1, 1, 1)
        beta_t = self.betas.to(xt.device)[t].view(-1, 1, 1, 1)
        alpha_t = self.alphas.to(xt.device)[t].view(-1, 1, 1, 1)

        # 1. Estimate x_0 from the noise prediction and clamp to data range.
        x0 = (xt - (sqrt_one_minus_alpha_bar * noise_pred)) / sqrt_alpha_bar
        x0 = torch.clamp(x0, -1., 1.)

        # 2. Posterior mean of x_{t-1}.
        mean = (xt - (beta_t * noise_pred) / sqrt_one_minus_alpha_bar) / torch.sqrt(alpha_t)

        # 3. Add posterior noise unless this is the final step.
        if t[0] == 0:
            return mean, x0
        else:
            # Posterior variance, reshaped for broadcasting over (B, 1, 1, 1).
            variance = ((1 - self.alpha_cum_prod.to(xt.device)[t - 1]) /
                        (1.0 - self.alpha_cum_prod.to(xt.device)[t])) * self.betas.to(xt.device)[t]
            sigma = (variance ** 0.5).view(-1, 1, 1, 1)
            z = torch.randn(xt.shape).to(xt.device)
            return mean + sigma * z, x0
models/{ddpm-150-finetuned β†’ ddpm}/model_index.json RENAMED
File without changes
models/{ddpm-150-finetuned β†’ ddpm}/scheduler/scheduler_config.json RENAMED
File without changes
models/{ddpm-150-finetuned β†’ ddpm}/unet/config.json RENAMED
File without changes
models/{ddpm-150-finetuned β†’ ddpm}/unet/diffusion_pytorch_model.safetensors RENAMED
File without changes
models/{ldm_cardiac_cond128_150_10.pth β†’ ldm.pth} RENAMED
File without changes
models/{vqvae_cardiac_autoencoder128_150_10.pth β†’ vqvae.pth} RENAMED
File without changes
requirements.txt CHANGED
@@ -9,4 +9,7 @@ tqdm
9
  gradio
10
  scipy
11
  safetensors
12
- huggingface_hub
 
 
 
 
9
  gradio
10
  scipy
11
  safetensors
12
+ huggingface_hub
13
+ monai
14
+ monai-generative
15
+ diffusers