JensLundsgaard committed on
Commit
d806fdd
·
verified ·
1 Parent(s): e161809

Upload train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train.py +1297 -0
train.py ADDED
@@ -0,0 +1,1297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch.optim.lr_scheduler import CosineAnnealingLR, CosineAnnealingWarmRestarts
6
+ import math
7
+ from PIL import Image
8
+ from PIL import ImageFile
9
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
10
+ import os
11
+ from model import Model
12
+ from raffael_model import ConvLSTMAutoencoder
13
+ from raffael_losses import reconstruction_loss as convlstm_reconstruction_loss, temporal_smoothness_loss
14
+ import sys
15
+
16
+ from torch.utils.data import DataLoader
17
+ from dataset_ivf import IVFSequenceDataset
18
+ from tqdm import tqdm
19
+ from datetime import datetime
20
+ torch.backends.cuda.enable_mem_efficient_sdp(False)
21
+ torch.backends.cuda.enable_flash_sdp(False)
22
+ torch.backends.cuda.enable_math_sdp(True)
23
+ batch_size = 50
24
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
25
+ from huggingface_hub import HfApi
26
+ import wandb
27
+ import gc
28
+ gc.collect()
29
+ import torch.distributed as dist
30
+ from torch.nn.parallel import DistributedDataParallel as DDP
31
+ from torch.utils.data.distributed import DistributedSampler
32
+ import os
33
+ from huggingface_hub import login
34
+ import shutil
35
+ import hashlib
36
+ import json
37
class RunningStats:
    """Streaming mean / variance accumulator (Welford's online update).

    Public fields:
        n:    number of values pushed so far
        mean: running mean of the pushed values
        m2:   running sum of squared deviations from the mean
    """

    def __init__(self):
        self.n, self.mean, self.m2 = 0, 0.0, 0.0

    def push(self, x):
        """Fold one more observation ``x`` into the running statistics."""
        self.n += 1
        offset_before = x - self.mean
        self.mean += offset_before / self.n
        # The second offset is taken against the *updated* mean.
        offset_after = x - self.mean
        self.m2 += offset_before * offset_after

    @property
    def variance(self):
        """Unbiased sample variance; 0.0 until at least two values exist."""
        if self.n <= 1:
            return 0.0
        return self.m2 / (self.n - 1)

    @property
    def std_dev(self):
        """Sample standard deviation (square root of ``variance``)."""
        return math.sqrt(self.variance)
60
+
61
+ VAL_EMBRYOS = ["CZ594-5","CJ261-10","RL747-8","TM272-9","LFA766-1","GT353-3","LGA881-2-5","LBE649-3","TH481-5","LTA908-2","BS648-7","GS955-7","HA1040-4","CM892-5","FC048-6","GC702-6","DI358-3","MM912-4","RK787-3","GSS052-2","OJ319-5","DML373-2","PS292-4","TM294-2","KT573-4","DJC641-4","FE14-020","LD400-1","MV930-2","MDCH869-4","AS662-2","LH1169-8","GA664-1","PMDPI029-1-3","DV116-3","FV709-11","GM456-3","RA361-4","LM844-1","DL020-3","VM570-4","MC833-6","LV613-2","ZS435-5","RM126-7","BK428-2","LS93-8","GS490-7","GF976-4","PMDPI029-1-11","DRL1048-1","BS294-7","CA658-12","RO793-2","GJ191-1","CC007-2","SL313-11","RC545-2-8","OJ319-9","PA289-8","TK319-10","SM686-7","KJ1077-3","BE645-10","BC167-4","VC581-1","FM162-6","PC758-2","HC459-6","DE069-10","GC340-3","BS596-5","PE256-2","LBE857-1","PH783-3","LS1045-4","CC455-3","DL617-6","BS1086-1","CK601-4","DA309-5","LTE064-1","KF460-4","LP181-1","GS349-4","LC47-8","GS205-6","EH309-8","BS1033-2","LL854-1","DHDPI042-6","BN356-6","PA145-2","GC340-1","MM334-5","AG274-2","BA518-7","BC973-4","BA1195-9","AM33-2","AB91-1","AB028-6","BC167-4","AL884-2","AM685-3"]
62
def setup_distributed():
    """Initialise distributed training from launcher environment variables.

    When both ``RANK`` and ``WORLD_SIZE`` are present in the environment they
    are used; otherwise the function falls back to a single-process setup.
    A NCCL process group is only created when ``world_size > 1``.

    Returns:
        tuple[int, int, int]: ``(rank, world_size, local_rank)``.
    """
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        if "LOCAL_RANK" in os.environ:
            local_rank = int(os.environ["LOCAL_RANK"])
        else:
            # LOCAL_RANK used to be read unconditionally, which raised
            # KeyError when only RANK/WORLD_SIZE were exported. Derive a
            # device index from the global rank instead.
            local_rank = rank % max(1, torch.cuda.device_count())
    else:
        # Single GPU fallback
        rank, world_size, local_rank = 0, 1, 0

    if world_size > 1:
        # Bind this process to its GPU before creating the NCCL group so the
        # communicator is set up on the intended device.
        torch.cuda.set_device(local_rank)
        dist.init_process_group(backend="nccl")

    return rank, world_size, local_rank
79
def cleanup_distributed():
    """Destroy the default process group if one has been initialised."""
    if not dist.is_initialized():
        return
    dist.destroy_process_group()
82
+
83
def generate_repo_name(mode, config_dict, file_paths, date_str):
    """Derive a deterministic repository name from configuration, code and date.

    The name has the form ``embryo-{mode}-{hash}-{date_str}`` where ``hash`` is
    the first 8 hex characters of a SHA-256 over the JSON-serialised config
    (keys sorted), a digest of the listed files, and the date string. For a
    path that does not exist, the path text itself is hashed in place of the
    file contents.

    Args:
        mode: Training mode label (e.g. "convlstm", "convlstm_latent_split").
        config_dict: JSON-serialisable dictionary of configuration parameters.
        file_paths: Paths of the files whose contents feed the hash.
        date_str: Date string (YYYY-MM-DD).

    Returns:
        str: Repository name; ``mode`` is truncated if the name would exceed
        96 characters.
    """
    max_len = 96
    serialized_config = json.dumps(config_dict, sort_keys=True)

    # Digest over every source file, in the order given.
    sources = hashlib.sha256()
    for path in file_paths:
        if not os.path.exists(path):
            sources.update(path.encode())
            continue
        with open(path, 'rb') as handle:
            sources.update(handle.read())

    # Fold config, file digest and date into one fingerprint.
    fingerprint = hashlib.sha256()
    for chunk in (serialized_config.encode(), sources.digest(), date_str.encode()):
        fingerprint.update(chunk)
    short_hash = fingerprint.hexdigest()[:8]

    head = "embryo-"
    tail = f"-{short_hash}-{date_str}"
    candidate = f"{head}{mode}{tail}"
    if len(candidate) <= max_len:
        return candidate

    # Only the mode label is shortened; hash and date are always kept whole.
    room_for_mode = max_len - len(head) - len(tail)
    return f"{head}{mode[:room_for_mode]}{tail}"
130
+
131
def save_and_push_model(model, repo_name, required_files, model_config=None):
    """Stage a model locally and publish it, plus training files, to the Hub.

    The local directory ``repo_name`` is used as a staging area. Weights are
    written with ``model.save_pretrained`` when that works, otherwise the
    state dict is saved as ``pytorch_model.bin``. Everything is then uploaded
    to ``JensLundsgaard/{repo_name}``.

    Fixes over the previous version:
      * the target repo is created up front, so the per-file uploads no
        longer depend on ``model.push_to_hub`` having succeeded;
      * if ``push_to_hub`` fails, the staged weight files are uploaded
        explicitly (previously the fallback ``pytorch_model.bin`` was never
        uploaded);
      * the final message reports failures instead of always claiming success.

    Upload problems are reported as warnings and do not raise (best effort).

    Args:
        model: The model to save. ``save_pretrained`` / ``push_to_hub`` are
            attempted; only ``state_dict()`` is required.
        repo_name: Repository name on HuggingFace Hub (also the local
            staging directory).
        required_files: List of file paths to include in the repo.
        model_config: Optional dictionary written as ``config.json``.
    """
    repo_id = f"JensLundsgaard/{repo_name}"
    failed = []

    # Create the local staging directory for the repo
    os.makedirs(repo_name, exist_ok=True)

    # Save the model weights
    try:
        model.save_pretrained(repo_name)
        print("Saved model using save_pretrained")
    except Exception as e:
        # If save_pretrained fails, just save the state dict
        print(f"save_pretrained failed ({e}), saving state dict only")
        torch.save(model.state_dict(), os.path.join(repo_name, "pytorch_model.bin"))

    # Save custom config.json with all ablation parameters
    if model_config is not None:
        config_path = os.path.join(repo_name, "config.json")
        with open(config_path, 'w') as f:
            json.dump(model_config, f, indent=2)
        print("Saved config.json with ablation parameters")

    # Copy all required files into the staging directory
    for file_path in required_files:
        if os.path.exists(file_path):
            shutil.copy2(file_path, repo_name)
            print(f"Added {file_path} to repository")
        else:
            print(f"Warning: {file_path} not found, skipping")

    api = HfApi()

    # Make sure the destination exists regardless of what push_to_hub does.
    try:
        api.create_repo(repo_id=repo_id, repo_type="model", exist_ok=True)
    except Exception as e:
        print(f"Warning: could not create or verify {repo_id}: {e}")

    def _upload(local_path, remote_name, label):
        """Upload one staged file; record (rather than raise) any failure."""
        try:
            api.upload_file(
                path_or_fileobj=local_path,
                path_in_repo=remote_name,
                repo_id=repo_id,
                repo_type="model"
            )
            print(f"Uploaded {label} to HuggingFace Hub")
        except Exception as e:
            print(f"Warning: Failed to upload {label}: {e}")
            failed.append(label)

    # Push model to hub (this uploads model weights and config)
    pushed = False
    try:
        model.push_to_hub(repo_name)
        pushed = True
        print(f"Pushed model weights to {repo_name}")
    except Exception as e:
        print(f"Warning: push_to_hub failed ({e}), will upload manually")

    # Names that the explicit upload steps below already take care of.
    handled = {"config.json"} | {os.path.basename(p) for p in required_files}

    if not pushed:
        # Manual fallback: upload every other staged file (i.e. the weights).
        for entry in sorted(os.listdir(repo_name)):
            staged = os.path.join(repo_name, entry)
            if entry in handled or not os.path.isfile(staged):
                continue
            _upload(staged, entry, entry)

    # Upload config.json first if it exists
    config_file = os.path.join(repo_name, "config.json")
    if os.path.exists(config_file):
        _upload(config_file, "config.json", "config.json")

    # Upload additional required files
    for file_path in required_files:
        base_name = os.path.basename(file_path)
        local_file = os.path.join(repo_name, base_name)
        if os.path.exists(local_file):
            _upload(local_file, base_name, file_path)
        else:
            print(f"Warning: {local_file} not found, skipping upload")

    if failed:
        print(f"Finished pushing to {repo_name} with {len(failed)} failed upload(s): {', '.join(failed)}")
    else:
        print(f"Successfully pushed all files to {repo_name}")
211
def gaussian_kernel(size=11, sigma=1.5):
    """Generate the 2-D Gaussian window used by SSIM.

    Args:
        size: Side length of the (square) window.
        sigma: Standard deviation of the Gaussian.

    Returns:
        float32 tensor of shape ``(size, size)`` whose entries sum to 1.
    """
    coords = torch.arange(size, dtype=torch.float32) - size // 2
    g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
    g = g / g.sum()
    # Outer product of the 1-D profile with itself.
    return g.unsqueeze(0) * g.unsqueeze(1)


def _ssim_window(img, kernel_size, sigma):
    """Gaussian window shaped ``(C, 1, k, k)`` on ``img``'s device and dtype."""
    window = gaussian_kernel(kernel_size, sigma).to(device=img.device, dtype=img.dtype)
    return window.expand(img.shape[1], 1, kernel_size, kernel_size).contiguous()


def _local_stats(img1, img2, window, pad):
    """Per-channel local means, variances and covariance under ``window``."""
    channels = img1.shape[1]

    def blur(t):
        # groups=channels -> each channel is filtered independently.
        return F.conv2d(t, window, padding=pad, groups=channels)

    mu1, mu2 = blur(img1), blur(img2)
    sigma1_sq = blur(img1 * img1) - mu1 ** 2
    sigma2_sq = blur(img2 * img2) - mu2 ** 2
    sigma12 = blur(img1 * img2) - mu1 * mu2
    return mu1, mu2, sigma1_sq, sigma2_sq, sigma12


def ssim(img1, img2, kernel_size=11, sigma=1.5, C1=0.01**2, C2=0.03**2):
    """
    Single-scale SSIM, averaged over batch, channels and pixels.

    The window is applied per channel, so any channel count works (the
    previous single-channel kernel failed for C > 1), and it follows the
    dtype of ``img1``.

    Args:
        img1, img2: (B, C, H, W)
        kernel_size, sigma: Gaussian window parameters.
        C1, C2: SSIM stabilisation constants.

    Returns:
        Scalar tensor with the mean SSIM.
    """
    window = _ssim_window(img1, kernel_size, sigma)
    mu1, mu2, sigma1_sq, sigma2_sq, sigma12 = _local_stats(
        img1, img2, window, kernel_size // 2
    )

    mu1_mu2 = mu1 * mu2
    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / \
               ((mu1 ** 2 + mu2 ** 2 + C1) * (sigma1_sq + sigma2_sq + C2))

    return ssim_map.mean()


def ms_ssim(img1, img2, kernel_size=11, sigma=1.5, weights=None, levels=5):
    """Multi-scale SSIM.

    Contrast-structure (CS) terms are taken at every scale except the
    coarsest, where the full SSIM is used. Images are downsampled by 2x
    average pooling between scales and the per-scale terms are combined as a
    weighted geometric product.

    Each term is clamped to a small positive value before being raised to its
    (fractional) weight: a negative CS/SSIM mean would otherwise produce NaN.

    Args:
        img1, img2: (B, C, H, W)
        kernel_size, sigma: Gaussian window parameters.
        weights: One weight per scale; defaults to the first ``levels``
            entries of the standard 5-scale MS-SSIM weights.
        levels: Number of scales (at least 1).

    Returns:
        Scalar tensor with the MS-SSIM value.

    Raises:
        ValueError: if ``levels < 1`` or the number of weights differs from
            ``levels`` (previously a mismatch silently dropped scales).
    """
    if levels < 1:
        raise ValueError(f"levels must be >= 1, got {levels}")
    if weights is None:
        weights = torch.tensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333],
                               device=img1.device, dtype=img1.dtype)[:levels]
    if len(weights) != levels:
        raise ValueError(
            f"expected {levels} weights (one per level), got {len(weights)}"
        )

    eps = 1e-6  # lower bound that keeps the fractional powers real-valued
    C2 = 0.03 ** 2
    pad = kernel_size // 2
    window = _ssim_window(img1, kernel_size, sigma)

    result = None
    for level in range(levels):
        if level == levels - 1:
            # Coarsest scale: full SSIM (luminance * contrast-structure).
            term = ssim(img1, img2, kernel_size, sigma)
        else:
            # Compute CS (contrast-structure) only
            _, _, sigma1_sq, sigma2_sq, sigma12 = _local_stats(img1, img2, window, pad)
            term = ((2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2)).mean()

            img1 = F.avg_pool2d(img1, 2)
            img2 = F.avg_pool2d(img2, 2)

        term = term.clamp(min=eps) ** weights[level]
        result = term if result is None else result * term

    return result
280
+
281
def reconstruction_loss(x_rec, x_true, l1_weight=0.5, ms_ssim_weight=0.5):
    """
    Combined reconstruction loss: L1 + MS-SSIM

    Args:
        x_rec: (B, T, C, H, W) - reconstructed video
        x_true: (B, T, C, H, W) - original video
        l1_weight: L1 loss weight
        ms_ssim_weight: MS-SSIM loss weight

    Returns:
        tuple: ``(total_loss, metrics)`` where ``total_loss`` is a scalar
        tensor and ``metrics`` is a dict of Python floats with the keys
        ``"l1_loss"``, ``"ms_ssim_loss"`` and ``"ms_ssim_value"``.
    """
    B, T, C, H, W = x_rec.shape

    # Fold time into the batch axis for the per-frame MS-SSIM computation.
    # reshape (not view) so non-contiguous inputs are accepted as well.
    x_rec_flat = x_rec.reshape(B * T, C, H, W)
    x_true_flat = x_true.reshape(B * T, C, H, W)

    # L1 Loss
    l1_loss = F.l1_loss(x_rec, x_true)

    # MS-SSIM Loss (1 - similarity, so lower is better)
    ms_ssim_val = ms_ssim(x_rec_flat, x_true_flat)
    ms_ssim_loss = 1 - ms_ssim_val

    # Combined loss
    total_loss = l1_weight * l1_loss + ms_ssim_weight * ms_ssim_loss

    return total_loss, {
        "l1_loss": l1_loss.item(),
        "ms_ssim_loss": ms_ssim_loss.item(),
        "ms_ssim_value": ms_ssim_val.item()
    }
311
+
312
+
313
def train_convlstm(
    loss_type="l1",
    ms_ssim_weight=0.5,
    rec_weight=0.5,
    temporal_weight=0.1,
    dropout_rate=0.1,
    use_convlstm=True,
    use_residual=True,
    use_batchnorm=True,
    model_name="",
    latent_size = 4096
):
    """Training ConvLSTM Autoencoder with configurable loss (single GPU)

    Runs 10 training epochs, logs to Weights & Biases, pushes the model and
    training files to the HuggingFace Hub, and finally evaluates on the
    embryos listed in ``VAL_EMBRYOS``.

    Args:
        loss_type: "l1" or "mse" - type of reconstruction loss to use with MS-SSIM
        ms_ssim_weight: float - weight for MS-SSIM loss (0 to disable)
        rec_weight: float - weight for reconstruction loss L1/MSE (0 to disable)
        temporal_weight: float - weight for temporal smoothness loss (0 to disable)
        dropout_rate: float - dropout rate (0 to disable)
        use_convlstm: bool - whether to use ConvLSTM (False = no temporal modeling)
        use_residual: bool - whether to use residual connections
        use_batchnorm: bool - whether to use batch normalization
        model_name: str - prefix of the W&B run name and of the Hub repository
        latent_size: int - latent size passed to the autoencoder

    Raises:
        ValueError: if ``loss_type`` is neither "l1" nor "mse".
    """
    # Fail fast, before any W&B run is created or data is loaded.
    if loss_type not in ("l1", "mse"):
        raise ValueError(f"Invalid loss_type: {loss_type}. Must be 'l1' or 'mse'")

    # Single source of truth for values that are both used and logged.
    learning_rate = 2e-4
    weight_decay = 1e-5
    epochs = 10
    train_batch_size = 10
    seq_len = 50
    image_size = 128
    gradient_clip = 5.0

    gc.collect()
    print(torch.cuda.memory_summary(device=None, abbreviated=False))
    torch.cuda.empty_cache()
    # NOTE(review): the previous code called torch.autograd.detect_anomaly(True)
    # here without entering it as a context manager, so anomaly detection was
    # never actually enabled. It is left off (same effective behaviour); use
    # torch.autograd.set_detect_anomaly(True) if it is really wanted.
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Build loss description for logging
    loss_components = []
    if ms_ssim_weight > 0:
        loss_components.append(f"MS-SSIM({ms_ssim_weight})")
    if rec_weight > 0:
        loss_components.append(f"{loss_type.upper()}({rec_weight})")
    if temporal_weight > 0:
        loss_components.append(f"Temporal({temporal_weight})")
    loss_description = " + ".join(loss_components) if loss_components else "None"

    # Build model description for logging
    model_features = []
    if use_convlstm:
        model_features.append("ConvLSTM")
    if use_residual:
        model_features.append("Residual")
    if use_batchnorm:
        model_features.append("BatchNorm")
    if dropout_rate > 0:
        model_features.append(f"Dropout({dropout_rate})")
    model_description = "+".join(model_features) if model_features else "Baseline"
    date_label = datetime.now().strftime("%Y-%m-%d")

    wandb.login(key=os.getenv("WANDB_KEY"))
    run = wandb.init(
        entity="jenslundsgaard7-uw-madison",
        project="IVF-Training",
        name=model_name + "-" + date_label,
        config={
            # Previously logged as 0.02 although the optimizer uses 2e-4.
            "learning_rate": learning_rate,
            "architecture": "ConvLSTM Autoencoder",
            "model_features": model_description,
            "dataset": "https://zenodo.org/records/7912264",
            "epochs": epochs,
            # Nominal values; the actual split is defined by VAL_EMBRYOS and
            # its real fraction is printed below as "val size".
            "train_split": 0.85,
            "val_split": 0.15,
            "loss": loss_description,
            "loss_type": loss_type,
            "ms_ssim_weight": ms_ssim_weight,
            "rec_weight": rec_weight,
            "temporal_weight": temporal_weight,
            "dropout_rate": dropout_rate,
            "use_convlstm": use_convlstm,
            "use_residual": use_residual,
            "use_batchnorm": use_batchnorm,
            "latent_size": latent_size,
            "seq_len": seq_len,
            "image_size": image_size,
            "distributed": False,
        },
    )

    login(os.getenv("HF_KEY"))
    print(torch.cuda.memory_summary(device=None, abbreviated=False))
    print(DEVICE)
    print(f"\n{'='*60}")
    print(f"ABLATION STUDY - Training Configuration")
    print(f"{'='*60}")
    print(f"\nLoss Configuration:")
    print(f" Base Loss Type: {loss_type.upper()}")
    print(f" MS-SSIM Weight: {ms_ssim_weight} {'(DISABLED)' if ms_ssim_weight == 0 else ''}")
    print(f" Reconstruction Weight: {rec_weight} {'(DISABLED)' if rec_weight == 0 else ''}")
    print(f" Temporal Smoothness Weight: {temporal_weight} {'(DISABLED)' if temporal_weight == 0 else ''}")
    print(f" Combined Loss: {loss_description}")
    print(f"\nModel Architecture Configuration:")
    print(f" ConvLSTM: {'ENABLED' if use_convlstm else 'DISABLED'}")
    print(f" Residual Connections: {'ENABLED' if use_residual else 'DISABLED'}")
    print(f" Batch Normalization: {'ENABLED' if use_batchnorm else 'DISABLED'}")
    print(f" Dropout Rate: {dropout_rate} {'(DISABLED)' if dropout_rate == 0 else ''}")
    print(f" Model Features: {model_description}")
    print(f"{'='*60}\n")

    # Save detailed training configuration
    config_content = f"""ConvLSTM Autoencoder Training Configuration (ABLATION)
================================================================================
Date: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}

ABLATION STUDY CONFIGURATION
================================================================================

"""

    with open("training_config_detailed.txt", "w") as f:
        f.write(config_content)

    print("Configuration saved to training_config_detailed.txt")

    model = ConvLSTMAutoencoder(
        None,
        seq_len=seq_len,
        input_channels=1,
        encoder_hidden_dim=256,
        encoder_layers=2,
        decoder_hidden_dim=128,
        decoder_layers=2,
        latent_size=latent_size,
        use_classifier=False,
        num_classes=2,
        use_latent_split=False,
        # Ablation parameters
        dropout_rate=dropout_rate,
        use_convlstm=use_convlstm,
        use_residual=use_residual,
        use_batchnorm=use_batchnorm
    )

    model = model.to(DEVICE)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    df = pd.read_csv(os.path.abspath("index.csv"))
    # NOTE(review): this is an unescaped, substring-style regex match, so an id
    # such as "GC340-1" would also match a cell_id containing "GC340-10".
    # Confirm against the cell_id format in index.csv.
    mask = df["cell_id"].str.contains("|".join(VAL_EMBRYOS), regex=True)
    val_df = df[mask]
    train_df = df[~mask]
    train_dataset = IVFSequenceDataset(train_df, resize=image_size, norm="minmax01")
    val_dataset = IVFSequenceDataset(val_df, resize=image_size, norm="minmax01")
    print("val size: ", str(len(val_df) / len(df)))

    # Create DataLoaders
    loader = DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        drop_last=True
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=1,
        shuffle=False,  # No shuffle for validation
        num_workers=4,
        pin_memory=True,
        drop_last=False  # Don't drop last for validation
    )
    # The scheduler is stepped once per batch, hence T_max in batches.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, len(loader) * epochs)

    for epoch in range(epochs):
        model.train()
        pbar = tqdm(loader, desc=f"epoch {epoch}")
        total = 0.0
        count = 0

        for index, (embryo_vol, _, _) in enumerate(pbar):
            optimizer.zero_grad()

            # Unpacked below as (B, T, C, H, W)
            embryo_vol = embryo_vol.to(DEVICE)

            # Forward pass - returns (reconstruction, latent_seq)
            embryo_recon, embryo_lat = model(embryo_vol)
            if index % 47 == 0:
                # Log last frame of the first sample next to its reconstruction.
                vol_img = embryo_vol[0, -1, 0].cpu().detach().numpy()
                recon_img = embryo_recon[0, -1, 0].cpu().detach().numpy()

                # Clip before the uint8 cast so out-of-range reconstructions
                # saturate instead of wrapping around.
                vol_img = np.clip(vol_img * 255, 0, 255).astype(np.uint8)
                recon_img = np.clip(recon_img * 255, 0, 255).astype(np.uint8)
                comparison = np.concatenate((vol_img, recon_img), axis=1)

                images = wandb.Image(comparison, caption="Embryo vs Recon comparison")
                run.log({"reconstruction": images})

            # Reconstruction loss using MS-SSIM + L1 or MSE (with configurable weights)
            if loss_type == "l1":
                rec_loss, rec_metrics = convlstm_reconstruction_loss(
                    embryo_recon, embryo_vol, l1_weight=rec_weight, ms_ssim_weight=ms_ssim_weight
                )
            else:
                # loss_type == "mse" (validated on entry): MS-SSIM + MSE loss
                B, T, C, H, W = embryo_recon.shape
                x_rec_flat = embryo_recon.reshape(B * T, C, H, W)
                x_true_flat = embryo_vol.reshape(B * T, C, H, W)

                mse_loss = F.mse_loss(embryo_recon, embryo_vol)
                ms_ssim_val = ms_ssim(x_rec_flat, x_true_flat)
                ms_ssim_loss = 1 - ms_ssim_val

                rec_loss = rec_weight * mse_loss + ms_ssim_weight * ms_ssim_loss
                rec_metrics = {
                    "mse_loss": mse_loss.item(),
                    "ms_ssim_loss": ms_ssim_loss.item(),
                    "ms_ssim_value": ms_ssim_val.item()
                }

            # Temporal smoothness loss on the latent sequence (configurable weight)
            if temporal_weight > 0:
                smooth_loss = temporal_smoothness_loss(embryo_lat, weight=temporal_weight)
                loss = rec_loss + smooth_loss
            else:
                smooth_loss = torch.tensor(0.0, device=DEVICE)
                loss = rec_loss

            if torch.isnan(loss) or torch.isinf(loss):
                print(f"NaN/Inf detected, skipping batch")
                continue

            loss.backward()

            # clip_grad_norm_ returns the total norm measured before clipping,
            # so no separate pass over the parameters is needed.
            total_norm = float(torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip))
            if total_norm > 100:
                print(f"Warning: Large gradient norm: {total_norm:.2f}")

            # Update the weights first, then advance the LR schedule
            # (the order used to be reversed).
            optimizer.step()
            scheduler.step()
            total += loss.item()
            count += 1

            if (index % 50 == 0) and run is not None:
                log_dict = {
                    "step": epoch * len(loader) + index,
                    "loss": loss.item(),
                    "rec_loss": rec_loss.item(),
                    "smooth_loss": smooth_loss.item(),
                    "ms_ssim": rec_metrics["ms_ssim_value"],
                    "lr": scheduler.get_last_lr()[0]
                }

                # Add loss-specific metrics
                if loss_type == "l1":
                    log_dict["l1_loss"] = rec_metrics["l1_loss"]
                elif loss_type == "mse":
                    log_dict["mse_loss"] = rec_metrics["mse_loss"]

                run.log(log_dict)

            pbar.set_postfix(
                loss=f"{loss.item():.4f}",
                rec=f"{rec_loss.item():.4f}",
                smooth=f"{smooth_loss.item():.4f}"
            )

        avg_loss = total / max(1, count)
        run.log({"avg_loss": avg_loss})
        print(f"epoch {epoch} avg loss={avg_loss:.4f}")

        # Save the state dict; overwritten every epoch so the file always
        # holds the most recent weights.
        torch.save(model.state_dict(), "convlstm_model_weights.pth")

    # Generate unique repo name based on config and code
    date_label = datetime.now().strftime("%Y-%m-%d")

    # Collect all config for hashing
    config_for_hash = {
        "mode": "convlstm",
        "loss_type": loss_type,
        "ms_ssim_weight": ms_ssim_weight,
        "rec_weight": rec_weight,
        "temporal_weight": temporal_weight,
        "dropout_rate": dropout_rate,
        "use_convlstm": use_convlstm,
        "use_residual": use_residual,
        "use_batchnorm": use_batchnorm,
        "learning_rate": learning_rate,
        "encoder_hidden_dim": 256,
        "encoder_layers": 2,
        "decoder_hidden_dim": 128,
        "decoder_layers": 2,
        "latent_size": latent_size,
        "seq_len": seq_len,
        "image_size": image_size,
    }

    # Required files for ConvLSTM model
    required_files = [
        "train.py",
        "raffael_model.py",
        "raffael_losses.py",
        "raffael_conv_lstm.py",
        "dataset_ivf.py",
        "train_model.sh",
        "training_config.txt",
        "training_config_detailed.txt",
    ]

    # Generate unique repo name
    repo_name = generate_repo_name("convlstm", config_for_hash, required_files, date_label)

    # The name ends with "-{hash}-{date}". Strip the date and take the last
    # dash-separated token. (split("-")[-2] used to return the month, because
    # the date itself contains dashes.)
    config_hash = repo_name[: -(len(date_label) + 1)].rsplit("-", 1)[-1]

    # Create comprehensive config for HuggingFace
    hf_config = {
        "model_type": "ConvLSTMAutoencoder",
        "architecture": "ConvLSTM Autoencoder",
        # Model architecture parameters
        "seq_len": seq_len,
        "input_channels": 1,
        "encoder_hidden_dim": 256,
        "encoder_layers": 2,
        "decoder_hidden_dim": 128,
        "decoder_layers": 2,
        "latent_size": latent_size,
        "use_classifier": False,
        "num_classes": 2,
        "use_latent_split": False,
        "image_size": image_size,
        # Ablation parameters
        "dropout_rate": dropout_rate,
        "use_convlstm": use_convlstm,
        "use_residual": use_residual,
        "use_batchnorm": use_batchnorm,
        # Loss configuration
        "loss_type": loss_type,
        "ms_ssim_weight": ms_ssim_weight,
        "rec_weight": rec_weight,
        "temporal_weight": temporal_weight,
        "loss_description": loss_description,
        # Training configuration
        "learning_rate": learning_rate,
        "weight_decay": weight_decay,
        "optimizer": "Adam",
        "scheduler": "CosineAnnealingLR",
        # Previously recorded as 1 although the training loader uses 10.
        "batch_size": train_batch_size,
        "epochs": epochs,
        "gradient_clip": gradient_clip,
        # Dataset
        "dataset": "https://zenodo.org/records/7912264",
        "resize": image_size,
        "normalization": "minmax01",
        # Reproducibility
        "repo_name": repo_name,
        "date": date_label,
        "hash": config_hash,
    }

    # NOTE(review): the model is pushed under "{model_name}-{date}", not under
    # the generated repo_name recorded in hf_config. Kept as-is so existing
    # repository names do not change; confirm which name is intended.
    save_and_push_model(model, model_name + "-" + date_label, required_files, model_config=hf_config)

    val_metrics = {
        'mse': RunningStats(),
        'l1': RunningStats(),
        'ssim': RunningStats(),
        'temp': RunningStats()
    }
    model.eval()  # Set model to evaluation mode
    with torch.no_grad():
        for embryo_vol, _, _ in val_loader:
            embryo_vol = embryo_vol.to(DEVICE)
            val_recon, val_lat = model(embryo_vol)
            B, T, C, H, W = embryo_vol.shape

            # MSE
            val_metrics['mse'].push(F.mse_loss(val_recon, embryo_vol).item())

            # L1
            val_metrics['l1'].push(F.l1_loss(val_recon, embryo_vol).item())

            # MS-SSIM (stored as a loss: 1 - similarity)
            val_recon_flat = val_recon.reshape(B * T, C, H, W)
            embryo_vol_flat = embryo_vol.reshape(B * T, C, H, W)
            ms_ssim_val = ms_ssim(val_recon_flat, embryo_vol_flat)
            val_metrics['ssim'].push((1 - ms_ssim_val).item())

            # Temporal smoothness of latents: mean L2 norm of the differences
            # between consecutive time steps (dim=1 of val_lat).
            if T > 1:
                lat_diff = torch.diff(val_lat, dim=1)
                temporal_smooth = lat_diff.norm(dim=-1).mean()
                # This value used to be computed but never recorded, so
                # val_temp was always logged as 0.
                val_metrics['temp'].push(temporal_smooth.item())

    # Log to wandb with val_ prefix
    val_log_dict = {
        f"val_{key}": value.mean for key, value in val_metrics.items()
    }
    val_log_std_dict = {
        f"val_{key}_std": value.std_dev for key, value in val_metrics.items()
    }

    run.log(val_log_dict)
    run.log(val_log_std_dict)

    run.finish()
    gc.collect()
    torch.cuda.empty_cache()
725
+
726
+
727
+ def train_convlstm_latent_split(
728
+ loss_type="l1",
729
+ ms_ssim_weight=0.5,
730
+ rec_weight=0.5,
731
+ temporal_weight=0.1,
732
+ dropout_rate=0.1,
733
+ use_convlstm=True,
734
+ use_residual=True,
735
+ use_batchnorm=True,
736
+ model_name ="",
737
+ latent_size = 4096
738
+ ):
739
+ gc.collect()
740
+ """Training ConvLSTM Autoencoder with LATENT SPLIT enabled (single GPU)
741
+
742
+ Args:
743
+ loss_type: "l1" or "mse" - type of reconstruction loss to use with MS-SSIM
744
+ ms_ssim_weight: float - weight for MS-SSIM loss (0 to disable)
745
+ rec_weight: float - weight for reconstruction loss L1/MSE (0 to disable)
746
+ temporal_weight: float - weight for temporal smoothness loss (0 to disable)
747
+ dropout_rate: float - dropout rate (0 to disable)
748
+ use_convlstm: bool - whether to use ConvLSTM (False = no temporal modeling)
749
+ use_residual: bool - whether to use residual connections
750
+ use_batchnorm: bool - whether to use batch normalization
751
+ """
752
+ print(torch.cuda.memory_summary(device=None, abbreviated=False))
753
+ torch.cuda.empty_cache()
754
+ torch.autograd.detect_anomaly(True)
755
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
756
+
757
+ # Build loss description for logging
758
+ loss_components = []
759
+ if ms_ssim_weight > 0:
760
+ loss_components.append(f"MS-SSIM({ms_ssim_weight})")
761
+ if rec_weight > 0:
762
+ loss_components.append(f"{loss_type.upper()}({rec_weight})")
763
+ if temporal_weight > 0:
764
+ loss_components.append(f"Temporal({temporal_weight})")
765
+ loss_description = " + ".join(loss_components) if loss_components else "None"
766
+
767
+ # Build model description for logging
768
+ model_features = []
769
+ if use_convlstm:
770
+ model_features.append("ConvLSTM")
771
+ if use_residual:
772
+ model_features.append("Residual")
773
+ if use_batchnorm:
774
+ model_features.append("BatchNorm")
775
+ if dropout_rate > 0:
776
+ model_features.append(f"Dropout({dropout_rate})")
777
+ model_description = "+".join(model_features) if model_features else "Baseline"
778
+ date_label = datetime.now().strftime("%Y-%m-%d")
779
+
780
+ wandb.login(key=os.getenv("WANDB_KEY"))
781
+ run = wandb.init(
782
+ entity="jenslundsgaard7-uw-madison",
783
+ project="IVF-Training",
784
+ name=model_name + "-" + date_label,
785
+ config={
786
+ "learning_rate": 0.02,
787
+ "architecture": "ConvLSTM Autoencoder with Latent Split",
788
+ "model_features": model_description,
789
+ "dataset": "https://zenodo.org/records/7912264",
790
+ "epochs": 10,
791
+ "train_split": 0.85,
792
+ "val_split": 0.15,
793
+ "loss": loss_description,
794
+ "loss_type": loss_type,
795
+ "ms_ssim_weight": ms_ssim_weight,
796
+ "rec_weight": rec_weight,
797
+ "temporal_weight": temporal_weight,
798
+ "dropout_rate": dropout_rate,
799
+ "use_convlstm": use_convlstm,
800
+ "use_residual": use_residual,
801
+ "use_batchnorm": use_batchnorm,
802
+ "latent_size": latent_size,
803
+ "latent_split": True,
804
+ "embryo_latent_size": latent_size//2,
805
+ "empty_latent_size": latent_size//2,
806
+ "seq_len": 50,
807
+ "image_size": 128,
808
+ "distributed": False,
809
+ },
810
+ )
811
+
812
+ login(os.getenv("HF_KEY"))
813
+ print(torch.cuda.memory_summary(device=None, abbreviated=False))
814
+ print(DEVICE)
815
+ print(f"\n{'='*60}")
816
+ print(f"ABLATION STUDY - Training Configuration")
817
+ print(f"{'='*60}")
818
+ print(f"\nLoss Configuration:")
819
+ print(f" Base Loss Type: {loss_type.upper()}")
820
+ print(f" MS-SSIM Weight: {ms_ssim_weight} {'(DISABLED)' if ms_ssim_weight == 0 else ''}")
821
+ print(f" Reconstruction Weight: {rec_weight} {'(DISABLED)' if rec_weight == 0 else ''}")
822
+ print(f" Temporal Smoothness Weight: {temporal_weight} {'(DISABLED)' if temporal_weight == 0 else ''}")
823
+ print(f" Combined Loss: {loss_description}")
824
+ print(f"\nModel Architecture Configuration:")
825
+ print(f" ConvLSTM: {'ENABLED' if use_convlstm else 'DISABLED'}")
826
+ print(f" Residual Connections: {'ENABLED' if use_residual else 'DISABLED'}")
827
+ print(f" Batch Normalization: {'ENABLED' if use_batchnorm else 'DISABLED'}")
828
+ print(f" Dropout Rate: {dropout_rate} {'(DISABLED)' if dropout_rate == 0 else ''}")
829
+ print(f" Model Features: {model_description}")
830
+ print(f"\nLatent Configuration:")
831
+ print(f" Latent Split: ENABLED (2048 for empty, 2048 for embryo)")
832
+ print(f"{'='*60}\n")
833
+
834
+ # Save detailed training configuration
835
+ config_content = f"""ConvLSTM Autoencoder Training Configuration (LATENT SPLIT + ABLATION)
836
+ ================================================================================
837
+ Date: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
838
+
839
+ ABLATION STUDY CONFIGURATION
840
+ ================================================================================
841
+
842
+ """
843
+
844
+ with open("training_config_latent_split.txt", "w") as f:
845
+ f.write(config_content)
846
+
847
+ print("Configuration saved to training_config_latent_split.txt")
848
+
849
+ # Create model with LATENT SPLIT and ABLATION parameters
850
+ """model = ConvLSTMAutoencoder(
851
+ None,
852
+ seq_len=50,
853
+ input_channels=1,
854
+ encoder_hidden_dim=256,
855
+ encoder_layers=2,
856
+ decoder_hidden_dim=128,
857
+ decoder_layers=2,
858
+ latent_size=latent_size,
859
+ use_classifier=False,
860
+ num_classes=2,
861
+ use_latent_split=True, # ENABLE LATENT SPLIT
862
+ # Ablation parameters
863
+ dropout_rate=dropout_rate,
864
+ use_convlstm=use_convlstm,
865
+ use_residual=use_residual,
866
+ use_batchnorm=use_batchnorm
867
+ """
868
+ model = ConvLSTMAutoencoder(
869
+ seq_len=50,
870
+ input_channels=1,
871
+ encoder_hidden_dim=256,
872
+ encoder_layers=2,
873
+ decoder_hidden_dim=128,
874
+ decoder_layers=2,
875
+ latent_size=4096,
876
+ use_classifier=True,
877
+ num_classes=2
878
+ )
879
+ model = model.to(DEVICE)
880
+
881
+ learning_rate = 2e-4
882
+ optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)
883
+
884
+ df = pd.read_csv(os.path.abspath("index.csv"))
885
+ mask = df["cell_id"].str.contains("|".join(VAL_EMBRYOS), regex=True)
886
+ val_df = df[mask]
887
+ train_df = df[~mask]
888
+ train_dataset = IVFSequenceDataset(train_df, resize=128, norm="minmax01")
889
+ val_dataset = IVFSequenceDataset(val_df, resize=128, norm="minmax01")
890
+ print("val size: ", str(len(val_df) / len(df)))
891
+
892
+ #generator = torch.Generator().manual_seed(42)
893
+ #train_dataset, val_dataset = torch.utils.data.random_split(ds, [train_size, val_size], generator=generator)
894
+
895
+ # Create DataLoaders
896
+ loader = DataLoader(
897
+ train_dataset,
898
+ batch_size=1,
899
+ shuffle=True,
900
+ num_workers=4,
901
+ pin_memory=True,
902
+ drop_last=True
903
+ )
904
+ val_loader = DataLoader(
905
+ val_dataset,
906
+ batch_size=1,
907
+ shuffle=False, # No shuffle for validation
908
+ num_workers=4,
909
+ pin_memory=True,
910
+ drop_last=False # Don't drop last for validation
911
+ )
912
+
913
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, len(loader) * 10)
914
+
915
+ for epoch in range(10):
916
+ model.train()
917
+ pbar = tqdm(loader, desc=f"epoch {epoch}")
918
+ total = 0.0
919
+ count = 0
920
+
921
+ for index, (embryo_vol, empty_well_vol, _) in enumerate(pbar):
922
+ optimizer.zero_grad()
923
+
924
+ # embryo_vol and empty_well_vol are (1, T, 1, H, W)
925
+ embryo_vol = embryo_vol.to(DEVICE)
926
+ empty_well_vol = empty_well_vol.to(DEVICE)
927
+
928
+ # Forward pass for embryo (uses second half of latent: 2048:4096)
929
+ embryo_recon, embryo_lat = model(embryo_vol, empty_well=False)
930
+
931
+ # Forward pass for empty well (uses first half of latent: 0:2048)
932
+ empty_recon, empty_lat = model(empty_well_vol, empty_well=True)
933
+ if(index % 47 == 0):
934
+ vol_img = embryo_vol[0, -1, 0].cpu().detach().numpy()
935
+ recon_img = embryo_recon[0, -1, 0].cpu().detach().numpy()
936
+
937
+ vol_img = (vol_img * 255).astype(np.uint8)
938
+ recon_img = (recon_img * 255).astype(np.uint8)
939
+ comparison = np.concatenate((vol_img, recon_img), axis=1)
940
+
941
+ images = wandb.Image(comparison, caption="Embryo vs Recon comparison")
942
+ run.log({"reconstruction": images})
943
+ # Reconstruction loss for embryo (with configurable weights)
944
+ if loss_type == "l1":
945
+ rec_loss_embryo, rec_metrics_embryo = convlstm_reconstruction_loss(
946
+ embryo_recon, embryo_vol, l1_weight=rec_weight, ms_ssim_weight=ms_ssim_weight
947
+ )
948
+ elif loss_type == "mse":
949
+ B, T, C, H, W = embryo_recon.shape
950
+ x_rec_flat = embryo_recon.view(B * T, C, H, W)
951
+ x_true_flat = embryo_vol.view(B * T, C, H, W)
952
+
953
+ mse_loss = F.mse_loss(embryo_recon, embryo_vol)
954
+ ms_ssim_val = ms_ssim(x_rec_flat, x_true_flat)
955
+ ms_ssim_loss = 1 - ms_ssim_val
956
+
957
+ rec_loss_embryo = rec_weight * mse_loss + ms_ssim_weight * ms_ssim_loss
958
+ rec_metrics_embryo = {
959
+ "mse_loss": mse_loss.item(),
960
+ "ms_ssim_loss": ms_ssim_loss.item(),
961
+ "ms_ssim_value": ms_ssim_val.item()
962
+ }
963
+ else:
964
+ raise ValueError(f"Invalid loss_type: {loss_type}. Must be 'l1' or 'mse'")
965
+
966
+ # Reconstruction loss for empty well (with configurable weights)
967
+ if loss_type == "l1":
968
+ rec_loss_empty, rec_metrics_empty = convlstm_reconstruction_loss(
969
+ empty_recon, empty_well_vol, l1_weight=rec_weight, ms_ssim_weight=ms_ssim_weight
970
+ )
971
+ elif loss_type == "mse":
972
+ B, T, C, H, W = empty_recon.shape
973
+ x_rec_flat = empty_recon.view(B * T, C, H, W)
974
+ x_true_flat = empty_well_vol.view(B * T, C, H, W)
975
+
976
+ mse_loss = F.mse_loss(empty_recon, empty_well_vol)
977
+ ms_ssim_val = ms_ssim(x_rec_flat, x_true_flat)
978
+ ms_ssim_loss = 1 - ms_ssim_val
979
+
980
+ rec_loss_empty = rec_weight * mse_loss + ms_ssim_weight * ms_ssim_loss
981
+ rec_metrics_empty = {
982
+ "mse_loss": mse_loss.item(),
983
+ "ms_ssim_loss": ms_ssim_loss.item(),
984
+ "ms_ssim_value": ms_ssim_val.item()
985
+ }
986
+
987
+ # Total reconstruction loss
988
+ rec_loss = rec_loss_embryo + rec_loss_empty
989
+
990
+ # Temporal smoothness loss (with configurable weight)
991
+ if temporal_weight > 0:
992
+ smooth_loss_embryo = temporal_smoothness_loss(embryo_lat, weight=temporal_weight)
993
+ smooth_loss_empty = temporal_smoothness_loss(empty_lat, weight=temporal_weight)
994
+ smooth_loss = smooth_loss_embryo + smooth_loss_empty
995
+ loss = rec_loss + smooth_loss
996
+ else:
997
+ smooth_loss = torch.tensor(0.0, device=DEVICE)
998
+ loss = rec_loss
999
+
1000
+ if torch.isnan(loss) or torch.isinf(loss):
1001
+ print(f"NaN/Inf detected, skipping batch")
1002
+ continue
1003
+
1004
+ loss.backward()
1005
+ total_norm = 0
1006
+ for p in model.parameters():
1007
+ if p.grad is not None:
1008
+ param_norm = p.grad.data.norm(2)
1009
+ total_norm += param_norm.item() ** 2
1010
+ total_norm = total_norm ** 0.5
1011
+
1012
+ if total_norm > 100:
1013
+ print(f"Warning: Large gradient norm: {total_norm:.2f}")
1014
+
1015
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
1016
+ scheduler.step()
1017
+ optimizer.step()
1018
+ total += loss.item()
1019
+ count += 1
1020
+
1021
+ if (index % 50 == 0) and run is not None:
1022
+ log_dict = {
1023
+ "step": epoch * len(loader) + index,
1024
+ "loss": loss.item(),
1025
+ "rec_loss": rec_loss.item(),
1026
+ "rec_loss_embryo": rec_loss_embryo.item(),
1027
+ "rec_loss_empty": rec_loss_empty.item(),
1028
+ "smooth_loss": smooth_loss.item(),
1029
+ "ms_ssim_embryo": rec_metrics_embryo["ms_ssim_value"],
1030
+ "ms_ssim_empty": rec_metrics_empty["ms_ssim_value"],
1031
+ "lr": scheduler.get_last_lr()[0]
1032
+ }
1033
+
1034
+ # Add loss-specific metrics
1035
+ if loss_type == "l1":
1036
+ log_dict["l1_loss_embryo"] = rec_metrics_embryo["l1_loss"]
1037
+ log_dict["l1_loss_empty"] = rec_metrics_empty["l1_loss"]
1038
+ elif loss_type == "mse":
1039
+ log_dict["mse_loss_embryo"] = rec_metrics_embryo["mse_loss"]
1040
+ log_dict["mse_loss_empty"] = rec_metrics_empty["mse_loss"]
1041
+
1042
+ run.log(log_dict)
1043
+
1044
+ pbar.set_postfix(
1045
+ loss=f"{loss.item():.4f}",
1046
+ rec_e=f"{rec_loss_embryo.item():.4f}",
1047
+ rec_empty=f"{rec_loss_empty.item():.4f}",
1048
+ smooth=f"{smooth_loss.item():.4f}"
1049
+ )
1050
+
1051
+ avg_loss = total/max(1, count)
1052
+ run.log({"avg_loss": avg_loss})
1053
+ print(f"epoch {epoch} avg loss={avg_loss:.4f}")
1054
+
1055
+ # Save the state dict
1056
+ torch.save(model.state_dict(), "convlstm_latent_split_weights.pth")
1057
+
1058
+ # Generate unique repo name based on config and code
1059
+ date_label = datetime.now().strftime("%Y-%m-%d")
1060
+
1061
+ # Collect all config for hashing
1062
+ config_for_hash = {
1063
+ "mode": "convlstm_latent_split",
1064
+ "loss_type": loss_type,
1065
+ "ms_ssim_weight": ms_ssim_weight,
1066
+ "rec_weight": rec_weight,
1067
+ "temporal_weight": temporal_weight,
1068
+ "dropout_rate": dropout_rate,
1069
+ "use_convlstm": use_convlstm,
1070
+ "use_residual": use_residual,
1071
+ "use_batchnorm": use_batchnorm,
1072
+ "use_latent_split": True,
1073
+ "learning_rate": 2e-4,
1074
+ "encoder_hidden_dim": 256,
1075
+ "encoder_layers": 2,
1076
+ "decoder_hidden_dim": 128,
1077
+ "decoder_layers": 2,
1078
+ "latent_size": latent_size,
1079
+ "embryo_latent_size": 2048,
1080
+ "empty_latent_size": 2048,
1081
+ "seq_len": 50,
1082
+ "image_size": 128,
1083
+ }
1084
+
1085
+ # Required files for ConvLSTM model with latent split
1086
+ required_files = [
1087
+ "train.py",
1088
+ "raffael_model.py",
1089
+ "raffael_losses.py",
1090
+ "raffael_conv_lstm.py",
1091
+ "dataset_ivf.py",
1092
+ "train_model.sh",
1093
+ "training_config.txt",
1094
+ "training_config_latent_split.txt",
1095
+ ]
1096
+
1097
+ # Generate unique repo name
1098
+ repo_name = generate_repo_name("convlstm-ls", config_for_hash, required_files, date_label)
1099
+
1100
+ # Create comprehensive config for HuggingFace
1101
+ hf_config = {
1102
+ "model_type": "ConvLSTMAutoencoder",
1103
+ "architecture": "ConvLSTM Autoencoder with Latent Split",
1104
+ # Model architecture parameters
1105
+ "seq_len": 50,
1106
+ "input_channels": 1,
1107
+ "encoder_hidden_dim": 256,
1108
+ "encoder_layers": 2,
1109
+ "decoder_hidden_dim": 128,
1110
+ "decoder_layers": 2,
1111
+ "latent_size": latent_size,
1112
+ "use_classifier": False,
1113
+ "num_classes": 2,
1114
+ "use_latent_split": True,
1115
+ "embryo_latent_size": 2048,
1116
+ "empty_latent_size": 2048,
1117
+ "image_size": 128,
1118
+ # Ablation parameters
1119
+ "dropout_rate": dropout_rate,
1120
+ "use_convlstm": use_convlstm,
1121
+ "use_residual": use_residual,
1122
+ "use_batchnorm": use_batchnorm,
1123
+ # Loss configuration
1124
+ "loss_type": loss_type,
1125
+ "ms_ssim_weight": ms_ssim_weight,
1126
+ "rec_weight": rec_weight,
1127
+ "temporal_weight": temporal_weight,
1128
+ "loss_description": loss_description,
1129
+ # Training configuration
1130
+ "learning_rate": 2e-4,
1131
+ "weight_decay": 1e-5,
1132
+ "optimizer": "Adam",
1133
+ "scheduler": "CosineAnnealingLR",
1134
+ "batch_size": 1,
1135
+ "epochs": 10,
1136
+ "gradient_clip": 5.0,
1137
+ # Dataset
1138
+ "dataset": "https://zenodo.org/records/7912264",
1139
+ "resize": 128,
1140
+ "normalization": "minmax01",
1141
+ # Reproducibility
1142
+ "repo_name": repo_name,
1143
+ "date": date_label,
1144
+ "hash": repo_name.split("-")[-2] if "-" in repo_name else "",
1145
+ }
1146
+
1147
+ save_and_push_model(model, model_name + "-" + date_label, required_files, model_config=hf_config)
1148
+ val_metrics = {
1149
+ 'mse': RunningStats(),
1150
+ 'l1': RunningStats(),
1151
+ 'ssim': RunningStats(),
1152
+ 'temp': RunningStats()
1153
+ }
1154
+ model.eval() # Set model to evaluation mode
1155
+ with torch.no_grad():
1156
+ for embryo_vol, _, _ in val_loader:
1157
+ embryo_vol = embryo_vol.to(DEVICE) # (1, T, 1, H, W)
1158
+ val_recon, val_lat = model(embryo_vol, empty_well=False)
1159
+ _, empty_val_lat = model(embryo_vol, empty_well=True)
1160
+ val_lat = torch.cat([val_lat, empty_val_lat], dim= 2)
1161
+ B, T, C, H, W = embryo_vol.shape
1162
+
1163
+ # MSE
1164
+ val_metrics['mse'].push(F.mse_loss(val_recon, embryo_vol).item())
1165
+
1166
+ # L1
1167
+ val_metrics['l1'].push(F.l1_loss(val_recon, embryo_vol).item())
1168
+
1169
+ # MS-SSIM
1170
+ val_recon_flat = val_recon.view(B * T, C, H, W)
1171
+ embryo_vol_flat = embryo_vol.view(B * T, C, H, W)
1172
+ ms_ssim_val = ms_ssim(val_recon_flat, embryo_vol_flat)
1173
+ val_metrics['ssim'].push((1 - ms_ssim_val).item())
1174
+
1175
+ # Temporal smoothness of latents
1176
+ # val_lat is (B, T, latent_size)
1177
+ if T > 1:
1178
+ lat_diff = torch.diff(val_lat, dim=1) # (B, T-1, latent_size)
1179
+ temporal_smooth = lat_diff.norm(dim=-1).mean() # Average L2 norm of differences
1180
+ # Log to wandb with val_ prefix
1181
+ val_log_dict = {
1182
+ f"val_{key}": value.mean for key, value in val_metrics.items()
1183
+ }
1184
+ val_log_std_dict = {
1185
+ f"val_{key}_std": value.std_dev for key, value in val_metrics.items()
1186
+ }
1187
+
1188
+ run.log(val_log_dict)
1189
+ run.log(val_log_std_dict)
1190
+
1191
+ run.finish()
1192
+ gc.collect()
1193
+ torch.cuda.empty_cache()
1194
+
1195
def train_mse_distributed():
    """Placeholder for the ``mse_distributed`` mode: prints "hi" and returns None."""
    placeholder_message = "hi"
    print(placeholder_message)
1197
def train_mse_single():
    """Placeholder for the ``mse_single`` mode: prints "hi" and returns None."""
    placeholder_message = "hi"
    print(placeholder_message)
1199
def train():
    """Placeholder default training entry point: prints "hi" and returns None."""
    placeholder_message = "hi"
    print(placeholder_message)
1201
+
1202
def _build_ablation_parser(description):
    """Build the argument parser shared by the ConvLSTM training modes.

    The ``convlstm`` and ``convlstm_latent_split`` modes accept exactly the
    same flags; only the parser description differs. Defining the flags once
    keeps the two modes from drifting apart.

    Args:
        description: Text shown at the top of ``--help`` output.

    Returns:
        An ``argparse.ArgumentParser`` with the positional ``mode`` argument
        plus the loss and model ablation options.
    """
    # Imported locally so the helper also works when this module is imported
    # rather than executed as a script.
    import argparse

    parser = argparse.ArgumentParser(description=description)
    parser.add_argument("mode", type=str, help="Training mode")

    # Loss ablation arguments
    parser.add_argument("--loss-type", type=str, default="l1", choices=["l1", "mse"],
                        help="Reconstruction loss type: l1 or mse (default: l1)")
    parser.add_argument("--ms-ssim-weight", type=float, default=0.5,
                        help="Weight for MS-SSIM loss (default: 0.5, set to 0 to disable)")
    parser.add_argument("--rec-weight", type=float, default=0.5,
                        help="Weight for reconstruction loss (default: 0.5, set to 0 to disable)")
    parser.add_argument("--temporal-weight", type=float, default=0.1,
                        help="Weight for temporal smoothness loss (default: 0.1, set to 0 to disable)")

    # Model ablation arguments
    parser.add_argument("--dropout-rate", type=float, default=0.1,
                        help="Dropout rate (default: 0.1, set to 0 to disable)")
    parser.add_argument("--no-convlstm", action="store_true",
                        help="Disable ConvLSTM (no temporal modeling)")
    parser.add_argument("--no-residual", action="store_true",
                        help="Disable residual connections")
    parser.add_argument("--no-batchnorm", action="store_true",
                        help="Disable batch normalization")

    # Run naming and latent dimensionality
    parser.add_argument("--name", type=str, default="", help="model name duhh")
    parser.add_argument("--size", type=int, default=4096, help="lat size bruh")
    return parser


def _ablation_kwargs(args):
    """Translate parsed CLI arguments into training-function keyword arguments.

    The ``--no-*`` switches are inverted into the positive ``use_*`` flags
    that the training functions are called with.

    Args:
        args: Namespace produced by the parser from ``_build_ablation_parser``.

    Returns:
        A dict of keyword arguments for ``train_convlstm`` or
        ``train_convlstm_latent_split``.
    """
    return {
        "loss_type": args.loss_type,
        "ms_ssim_weight": args.ms_ssim_weight,
        "rec_weight": args.rec_weight,
        "temporal_weight": args.temporal_weight,
        "dropout_rate": args.dropout_rate,
        "use_convlstm": not args.no_convlstm,
        "use_residual": not args.no_residual,
        "use_batchnorm": not args.no_batchnorm,
        "model_name": args.name,
        "latent_size": args.size,
    }


if __name__ == "__main__":
    import sys

    # The first positional argument selects the training mode.
    # NOTE(review): the fallback to train() for a missing or unrecognised mode
    # assumes the trailing `else` belonged to the outer mode check - confirm.
    mode = sys.argv[1] if len(sys.argv) > 1 else None

    if mode == "mse_distributed":
        train_mse_distributed()
    elif mode == "mse_single":
        train_mse_single()
    elif mode == "convlstm":
        args = _build_ablation_parser(
            "Train ConvLSTM Autoencoder with Ablation Studies"
        ).parse_args()
        train_convlstm(**_ablation_kwargs(args))
    elif mode == "convlstm_latent_split":
        args = _build_ablation_parser(
            "Train ConvLSTM Autoencoder with Latent Split and Ablation Studies"
        ).parse_args()
        train_convlstm_latent_split(**_ablation_kwargs(args))
    else:
        train()