JensLundsgaard commited on
Commit
675ee79
·
verified ·
1 Parent(s): 717483d

Upload train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train.py +1274 -0
train.py ADDED
@@ -0,0 +1,1274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR, CosineAnnealingWarmRestarts
import math
from PIL import Image
from PIL import ImageFile
# Tolerate truncated image files instead of raising during dataset loading.
ImageFile.LOAD_TRUNCATED_IMAGES = True
import os
from model import Model
from raffael_model import ConvLSTMAutoencoder
from raffael_losses import reconstruction_loss as convlstm_reconstruction_loss, temporal_smoothness_loss
import sys

from torch.utils.data import DataLoader
from dataset_ivf import IVFSequenceDataset
from tqdm import tqdm
from datetime import datetime
# Force the math scaled-dot-product-attention backend; disable the
# mem-efficient and flash kernels.
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_math_sdp(True)
batch_size = 50
# Reduce CUDA memory fragmentation for long training runs.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
from huggingface_hub import HfApi
import wandb
import gc
gc.collect()
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from huggingface_hub import login
import shutil
import hashlib
import json
# Embryo IDs held out for validation; rows whose cell_id contains any of
# these substrings go to the validation split.
# (Fix: removed a duplicate "import os" and a duplicate "BC167-4" entry —
# the duplicate regex alternative was redundant and never changed matching.)
VAL_EMBRYOS = ["CZ594-5","CJ261-10","RL747-8","TM272-9","LFA766-1","GT353-3","LGA881-2-5","LBE649-3","TH481-5","LTA908-2","BS648-7","GS955-7","HA1040-4","CM892-5","FC048-6","GC702-6","DI358-3","MM912-4","RK787-3","GSS052-2","OJ319-5","DML373-2","PS292-4","TM294-2","KT573-4","DJC641-4","FE14-020","LD400-1","MV930-2","MDCH869-4","AS662-2","LH1169-8","GA664-1","PMDPI029-1-3","DV116-3","FV709-11","GM456-3","RA361-4","LM844-1","DL020-3","VM570-4","MC833-6","LV613-2","ZS435-5","RM126-7","BK428-2","LS93-8","GS490-7","GF976-4","PMDPI029-1-11","DRL1048-1","BS294-7","CA658-12","RO793-2","GJ191-1","CC007-2","SL313-11","RC545-2-8","OJ319-9","PA289-8","TK319-10","SM686-7","KJ1077-3","BE645-10","BC167-4","VC581-1","FM162-6","PC758-2","HC459-6","DE069-10","GC340-3","BS596-5","PE256-2","LBE857-1","PH783-3","LS1045-4","CC455-3","DL617-6","BS1086-1","CK601-4","DA309-5","LTE064-1","KF460-4","LP181-1","GS349-4","LC47-8","GS205-6","EH309-8","BS1033-2","LL854-1","DHDPI042-6","BN356-6","PA145-2","GC340-1","MM334-5","AG274-2","BA518-7","BC973-4","BA1195-9","AM33-2","AB91-1","AB028-6","AL884-2","AM685-3"]
38
def setup_distributed():
    """Read torchrun environment variables and join the NCCL process group.

    Returns:
        tuple[int, int, int]: ``(rank, world_size, local_rank)``.
        Falls back to ``(0, 1, 0)`` when RANK/WORLD_SIZE are absent
        (plain single-GPU launch), in which case no process group is created.
    """
    env = os.environ
    launched_by_torchrun = 'RANK' in env and 'WORLD_SIZE' in env

    if launched_by_torchrun:
        rank = int(env["RANK"])
        world_size = int(env['WORLD_SIZE'])
        local_rank = int(env['LOCAL_RANK'])
    else:
        # Single GPU fallback
        rank, world_size, local_rank = 0, 1, 0

    if world_size > 1:
        # Multi-process launch: join the NCCL group and pin this worker's GPU.
        dist.init_process_group(backend="nccl")
        torch.cuda.set_device(local_rank)

    return rank, world_size, local_rank
55
def cleanup_distributed():
    """Destroy the distributed process group; no-op if none was initialized."""
    if not dist.is_initialized():
        return
    dist.destroy_process_group()
58
+
59
def generate_repo_name(mode, config_dict, file_paths, date_str):
    """
    Generate a unique, deterministic repository name based on configuration and code.

    The name is derived from a SHA-256 over the sorted JSON config, the raw
    bytes of every file in *file_paths* (or the path itself if the file is
    missing), and the date string, so identical runs map to the same repo.

    Args:
        mode: Training mode (e.g., "convlstm", "convlstm_latent_split")
        config_dict: Dictionary of all configuration parameters
        file_paths: List of file paths to hash
        date_str: Date string (YYYY-MM-DD)

    Returns:
        str: Repository name of the form ``embryo-{mode}-{hash8}-{date}``,
        guaranteed to be at most 96 characters.
    """
    combined = hashlib.sha256()
    combined.update(json.dumps(config_dict, sort_keys=True).encode())

    # Digest of all referenced source files; a missing file contributes its
    # path so the final hash still reflects the intended file set.
    code_digest = hashlib.sha256()
    for path in file_paths:
        if os.path.exists(path):
            with open(path, 'rb') as fh:
                code_digest.update(fh.read())
        else:
            code_digest.update(path.encode())

    combined.update(code_digest.digest())
    combined.update(date_str.encode())

    # Eight hex chars is plenty to avoid collisions between configurations.
    short_hash = combined.hexdigest()[:8]

    repo_name = f"embryo-{mode}-{short_hash}-{date_str}"
    if len(repo_name) > 96:
        # Trim only the mode segment so hash and date always survive.
        room_for_mode = 96 - len(f"embryo--{short_hash}-{date_str}")
        repo_name = f"embryo-{mode[:room_for_mode]}-{short_hash}-{date_str}"

    return repo_name
106
+
107
def save_and_push_model(model, repo_name, required_files, model_config=None):
    """
    Save model and push it along with required training files to HuggingFace Hub.

    Side effects: creates a local directory named *repo_name*, writes model
    weights and config.json into it, copies the listed source files, then
    uploads everything to the hub. Upload failures are printed, not raised,
    so a flaky network never aborts training.

    Args:
        model: The model to save
        repo_name: Repository name on HuggingFace Hub
        required_files: List of file paths to include in the repo
        model_config: Optional dictionary with model configuration to save as config.json
    """
    # Create temporary directory for the repo
    # (note: it is never removed afterwards; artifacts accumulate on disk)
    os.makedirs(repo_name, exist_ok=True)

    # Save the model weights
    try:
        # Works when the model mixes in huggingface_hub's PyTorchModelHubMixin.
        model.save_pretrained(repo_name)
        print(f"Saved model using save_pretrained")
    except Exception as e:
        # If save_pretrained fails, just save the state dict
        print(f"save_pretrained failed ({e}), saving state dict only")
        torch.save(model.state_dict(), os.path.join(repo_name, "pytorch_model.bin"))

    # Save custom config.json with all ablation parameters
    if model_config is not None:
        config_path = os.path.join(repo_name, "config.json")
        with open(config_path, 'w') as f:
            json.dump(model_config, f, indent=2)
        print(f"Saved config.json with ablation parameters")

    # Copy all required files
    for file_path in required_files:
        if os.path.exists(file_path):
            shutil.copy2(file_path, repo_name)
            print(f"Added {file_path} to repository")
        else:
            print(f"Warning: {file_path} not found, skipping")

    # Push model to hub (this uploads model weights and config)
    try:
        model.push_to_hub(repo_name)
        print(f"Pushed model weights to {repo_name}")
    except Exception as e:
        print(f"Warning: push_to_hub failed ({e}), will upload manually")

    # Upload all files using HfApi (including config.json)
    # NOTE(review): the hub namespace "JensLundsgaard" is hard-coded below;
    # other users of this script will need to change it.
    api = HfApi()

    # Upload config.json first if it exists
    config_file = os.path.join(repo_name, "config.json")
    if os.path.exists(config_file):
        try:
            api.upload_file(
                path_or_fileobj=config_file,
                path_in_repo="config.json",
                repo_id=f"JensLundsgaard/{repo_name}",
                repo_type="model"
            )
            print(f"Uploaded config.json to HuggingFace Hub")
        except Exception as e:
            print(f"Warning: Failed to upload config.json: {e}")


    # Upload additional required files
    for file_path in required_files:
        # Files were copied flat into repo_name above, so upload by basename.
        local_file = os.path.join(repo_name, os.path.basename(file_path))
        if os.path.exists(local_file):
            try:
                api.upload_file(
                    path_or_fileobj=local_file,
                    path_in_repo=os.path.basename(file_path),
                    repo_id=f"JensLundsgaard/{repo_name}",
                    repo_type="model"
                )
                print(f"Uploaded {file_path} to HuggingFace Hub")
            except Exception as e:
                print(f"Warning: Failed to upload {file_path}: {e}")
        else:
            print(f"Warning: {local_file} not found, skipping upload")

    print(f"Successfully pushed all files to {repo_name}")
187
def gaussian_kernel(size=11, sigma=1.5):
    """Return a normalized 2-D Gaussian window of shape (size, size) for SSIM."""
    offsets = torch.arange(size, dtype=torch.float32) - size // 2
    window_1d = torch.exp(-(offsets ** 2) / (2 * sigma ** 2))
    window_1d = window_1d / window_1d.sum()
    # Outer product of the normalized 1-D window with itself -> 2-D kernel.
    return window_1d.unsqueeze(0) * window_1d.unsqueeze(1)
194
+
195
+
196
def ssim(img1, img2, kernel_size=11, sigma=1.5, C1=0.01**2, C2=0.03**2):
    """
    Mean single-scale SSIM between two image batches.

    Args:
        img1, img2: (B, C, H, W) tensors on the same device.
        kernel_size, sigma: Gaussian window parameters.
        C1, C2: stabilizing constants from the SSIM definition.

    Returns:
        Scalar tensor: SSIM map averaged over all pixels and batch entries.

    NOTE(review): the (1, 1, k, k) window assumes single-channel input
    (C == 1), which matches this file's grayscale embryo frames.
    """
    pad = kernel_size // 2
    window = gaussian_kernel(kernel_size, sigma).to(img1.device)
    window = window[None, None]  # (1, 1, k, k) conv weight

    def blur(t):
        return F.conv2d(t, window, padding=pad)

    mean1 = blur(img1)
    mean2 = blur(img2)

    # Local (co)variances via E[xy] - E[x]E[y] on the blurred statistics.
    var1 = blur(img1 * img1) - mean1 ** 2
    var2 = blur(img2 * img2) - mean2 ** 2
    covar = blur(img1 * img2) - mean1 * mean2

    numerator = (2 * mean1 * mean2 + C1) * (2 * covar + C2)
    denominator = (mean1 ** 2 + mean2 ** 2 + C1) * (var1 + var2 + C2)

    return (numerator / denominator).mean()
220
+
221
+
222
def ms_ssim(img1, img2, kernel_size=11, sigma=1.5, weights=None, levels=5):
    """
    Multi-Scale SSIM (MS-SSIM).

    Args:
        img1, img2: (B, C, H, W)
        weights: weights for each scale, default [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]
        levels: number of scales

    Returns:
        Scalar tensor in (0, 1]; 1 for identical inputs.
    """
    if weights is None:
        weights = torch.tensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333],
                               device=img1.device)

    # Ensure weight count matches and renormalize to sum to 1.
    weights = weights[:levels]
    weights = weights / weights.sum()

    mcs_list = []
    ssim_val = None

    for i in range(levels):
        if i == levels - 1:
            # Finest remaining scale contributes full SSIM (luminance included).
            ssim_val = ssim(img1, img2, kernel_size, sigma)
        else:
            # Intermediate scales contribute only the contrast/structure term.
            kernel = gaussian_kernel(kernel_size, sigma).to(img1.device)
            kernel = kernel.unsqueeze(0).unsqueeze(0)

            mu1 = F.conv2d(img1, kernel, padding=kernel_size//2)
            mu2 = F.conv2d(img2, kernel, padding=kernel_size//2)

            sigma1_sq = F.conv2d(img1 * img1, kernel, padding=kernel_size//2) - mu1 ** 2
            sigma2_sq = F.conv2d(img2 * img2, kernel, padding=kernel_size//2) - mu2 ** 2
            sigma12 = F.conv2d(img1 * img2, kernel, padding=kernel_size//2) - mu1 * mu2

            C2 = 0.03 ** 2
            mcs = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2)
            mcs_list.append(mcs.mean())

        # Downsample to next level
        if i < levels - 1:
            img1 = F.avg_pool2d(img1, 2)
            img2 = F.avg_pool2d(img2, 2)

    # Combine all scales with the standard Wang et al. formula:
    #   MS-SSIM = SSIM_L^{w_L} * prod_{i<L} mcs_i^{w_i}
    # BUG FIX: the previous code re-exponentiated the running product by
    # weights[i] at every step (compounding the exponents of earlier scales),
    # which is not the MS-SSIM definition. Each factor is now raised to its
    # own weight exactly once.
    ms_ssim_val = ssim_val ** weights[-1]
    for i, mcs in enumerate(mcs_list):
        ms_ssim_val = ms_ssim_val * mcs ** weights[i]

    return ms_ssim_val
272
+
273
+
274
def reconstruction_loss(x_rec, x_true, l1_weight=0.5, ms_ssim_weight=0.5):
    """
    Combined reconstruction loss: weighted L1 + (1 - MS-SSIM).

    Args:
        x_rec: (B, T, 1, H, W) - reconstructed video
        x_true: (B, T, 1, H, W) - original video
        l1_weight: L1 loss weight
        ms_ssim_weight: MS-SSIM loss weight

    Returns:
        (total_loss, metrics): scalar tensor plus a dict of float diagnostics.
    """
    batch, steps, channels, height, width = x_rec.shape

    # MS-SSIM operates on 2-D frames, so fold time into the batch dimension.
    frames_rec = x_rec.view(batch * steps, channels, height, width)
    frames_true = x_true.view(batch * steps, channels, height, width)

    l1_term = F.l1_loss(x_rec, x_true)

    ms_ssim_score = ms_ssim(frames_rec, frames_true)
    ms_ssim_term = 1 - ms_ssim_score

    combined = l1_weight * l1_term + ms_ssim_weight * ms_ssim_term

    metrics = {
        "l1_loss": l1_term.item(),
        "ms_ssim_loss": ms_ssim_term.item(),
        "ms_ssim_value": ms_ssim_score.item(),
    }
    return combined, metrics
304
+
305
+
306
def train_convlstm(
    loss_type: str = "l1",
    ms_ssim_weight: float = 0.5,
    rec_weight: float = 0.5,
    temporal_weight: float = 0.1,
    dropout_rate: float = 0.1,
    use_convlstm: bool = True,
    use_residual: bool = True,
    use_batchnorm: bool = True,
    model_name: str = ""
) -> None:
    gc.collect()
    # NOTE(review): because gc.collect() precedes it, the string below is a
    # plain expression statement, not the function's actual docstring.
    """Training ConvLSTM Autoencoder with configurable loss (single GPU)

    Args:
        loss_type: "l1" or "mse" - type of reconstruction loss to use with MS-SSIM
        ms_ssim_weight: float - weight for MS-SSIM loss (0 to disable)
        rec_weight: float - weight for reconstruction loss L1/MSE (0 to disable)
        temporal_weight: float - weight for temporal smoothness loss (0 to disable)
        dropout_rate: float - dropout rate (0 to disable)
        use_convlstm: bool - whether to use ConvLSTM (False = no temporal modeling)
        use_residual: bool - whether to use residual connections
        use_batchnorm: bool - whether to use batch normalization
    """
    print(torch.cuda.memory_summary(device=None, abbreviated=False))
    torch.cuda.empty_cache()
    # NOTE(review): detect_anomaly(True) constructs a context manager that is
    # never entered, so anomaly detection is NOT actually enabled here;
    # torch.autograd.set_detect_anomaly(True) is presumably what was intended.
    torch.autograd.detect_anomaly(True)
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Build loss description for logging
    loss_components = []
    if ms_ssim_weight > 0:
        loss_components.append(f"MS-SSIM({ms_ssim_weight})")
    if rec_weight > 0:
        loss_components.append(f"{loss_type.upper()}({rec_weight})")
    if temporal_weight > 0:
        loss_components.append(f"Temporal({temporal_weight})")
    loss_description = " + ".join(loss_components) if loss_components else "None"

    # Build model description for logging
    model_features = []
    if use_convlstm:
        model_features.append("ConvLSTM")
    if use_residual:
        model_features.append("Residual")
    if use_batchnorm:
        model_features.append("BatchNorm")
    if dropout_rate > 0:
        model_features.append(f"Dropout({dropout_rate})")
    model_description = "+".join(model_features) if model_features else "Baseline"
    date_label = datetime.now().strftime("%Y-%m-%d")

    wandb.login(key=os.getenv("WANDB_KEY"))
    # NOTE(review): "learning_rate": 0.02 logged here does not match the
    # actual optimizer learning rate of 2e-4 set below — confirm which is
    # intended before comparing runs by this config field.
    run = wandb.init(
        entity="jenslundsgaard7-uw-madison",
        project="IVF-Training",
        name=model_name +"-" + date_label,
        config={
            "learning_rate": 0.02,
            "architecture": "ConvLSTM Autoencoder",
            "model_features": model_description,
            "dataset": "https://zenodo.org/records/7912264",
            "epochs": 10,
            "train_split": 0.85,
            "val_split": 0.15,
            "loss": loss_description,
            "loss_type": loss_type,
            "ms_ssim_weight": ms_ssim_weight,
            "rec_weight": rec_weight,
            "temporal_weight": temporal_weight,
            "dropout_rate": dropout_rate,
            "use_convlstm": use_convlstm,
            "use_residual": use_residual,
            "use_batchnorm": use_batchnorm,
            "latent_size": 4096,
            "seq_len": 50,
            "image_size": 128,
            "distributed": False,
        },
    )

    # Authenticate with HuggingFace Hub for the per-epoch checkpoint pushes.
    login(os.getenv("HF_KEY"))
    print(torch.cuda.memory_summary(device=None, abbreviated=False))
    print(DEVICE)
    print(f"\n{'='*60}")
    print(f"ABLATION STUDY - Training Configuration")
    print(f"{'='*60}")
    print(f"\nLoss Configuration:")
    print(f" Base Loss Type: {loss_type.upper()}")
    print(f" MS-SSIM Weight: {ms_ssim_weight} {'(DISABLED)' if ms_ssim_weight == 0 else ''}")
    print(f" Reconstruction Weight: {rec_weight} {'(DISABLED)' if rec_weight == 0 else ''}")
    print(f" Temporal Smoothness Weight: {temporal_weight} {'(DISABLED)' if temporal_weight == 0 else ''}")
    print(f" Combined Loss: {loss_description}")
    print(f"\nModel Architecture Configuration:")
    print(f" ConvLSTM: {'ENABLED' if use_convlstm else 'DISABLED'}")
    print(f" Residual Connections: {'ENABLED' if use_residual else 'DISABLED'}")
    print(f" Batch Normalization: {'ENABLED' if use_batchnorm else 'DISABLED'}")
    print(f" Dropout Rate: {dropout_rate} {'(DISABLED)' if dropout_rate == 0 else ''}")
    print(f" Model Features: {model_description}")
    print(f"{'='*60}\n")

    # Save detailed training configuration
    config_content = f"""ConvLSTM Autoencoder Training Configuration (ABLATION)
================================================================================
Date: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}

ABLATION STUDY CONFIGURATION
================================================================================

"""

    with open("training_config_detailed.txt", "w") as f:
        f.write(config_content)

    print("Configuration saved to training_config_detailed.txt")

    model = ConvLSTMAutoencoder(
        seq_len=50,
        input_channels=1,
        encoder_hidden_dim=256,
        encoder_layers=2,
        decoder_hidden_dim=128,
        decoder_layers=2,
        latent_size=4096,
        use_classifier=False,
        num_classes=2,
        use_latent_split=False,
        # Ablation parameters
        dropout_rate=dropout_rate,
        use_convlstm=use_convlstm,
        use_residual=use_residual,
        use_batchnorm=use_batchnorm
    )

    model = model.to(DEVICE)

    learning_rate = 2e-4
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)

    # Deterministic train/val split: any row whose cell_id contains one of
    # the held-out embryo IDs (substring regex match) goes to validation.
    df = pd.read_csv(os.path.abspath("index.csv"))
    mask = df["cell_id"].str.contains("|".join(VAL_EMBRYOS), regex=True)
    val_df = df[mask]
    train_df = df[~mask]
    train_dataset = IVFSequenceDataset(train_df, resize=128, norm="minmax01")
    val_dataset = IVFSequenceDataset(val_df, resize=128, norm="minmax01")
    print("val size: ", str(len(val_df) / len(df)))

    #generator = torch.Generator().manual_seed(42)
    #train_dataset, val_dataset = torch.utils.data.random_split(ds, [train_size, val_size], generator=generator)

    # Create DataLoaders
    loader = DataLoader(
        train_dataset,
        batch_size=1,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        drop_last=True
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=1,
        shuffle=False,  # No shuffle for validation
        num_workers=4,
        pin_memory=True,
        drop_last=False  # Don't drop last for validation
    )
    # Cosine anneal over the full run: len(loader) steps/epoch * 10 epochs.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, len(loader) * 10)

    for epoch in range(10):
        model.train()
        pbar = tqdm(loader, desc=f"epoch {epoch}")
        total = 0.0
        count = 0

        for index, (embryo_vol, _, _) in enumerate(pbar):
            optimizer.zero_grad()

            embryo_vol = embryo_vol.to(DEVICE)  # (1, T, 1, 500, 500)

            # Forward pass - returns (reconstruction, latent_seq)
            embryo_recon, embryo_lat = model(embryo_vol)

            # Reconstruction loss using MS-SSIM + L1 or MSE (with configurable weights)
            if loss_type == "l1":
                rec_loss, rec_metrics = convlstm_reconstruction_loss(
                    embryo_recon, embryo_vol, l1_weight=rec_weight, ms_ssim_weight=ms_ssim_weight
                )
            elif loss_type == "mse":
                # MS-SSIM + MSE loss
                B, T, C, H, W = embryo_recon.shape
                x_rec_flat = embryo_recon.view(B * T, C, H, W)
                x_true_flat = embryo_vol.view(B * T, C, H, W)

                mse_loss = F.mse_loss(embryo_recon, embryo_vol)
                ms_ssim_val = ms_ssim(x_rec_flat, x_true_flat)
                ms_ssim_loss = 1 - ms_ssim_val

                rec_loss = rec_weight * mse_loss + ms_ssim_weight * ms_ssim_loss
                rec_metrics = {
                    "mse_loss": mse_loss.item(),
                    "ms_ssim_loss": ms_ssim_loss.item(),
                    "ms_ssim_value": ms_ssim_val.item()
                }
            else:
                raise ValueError(f"Invalid loss_type: {loss_type}. Must be 'l1' or 'mse'")

            # Temporal smoothness loss (with configurable weight)
            # embryo_lat is (1, T, 4096) - encourages smooth transitions between frames
            if temporal_weight > 0:
                smooth_loss = temporal_smoothness_loss(embryo_lat, weight=temporal_weight)
                loss = rec_loss + smooth_loss
            else:
                # Zero tensor keeps the logging code below uniform.
                smooth_loss = torch.tensor(0.0, device=DEVICE)
                loss = rec_loss

            if torch.isnan(loss) or torch.isinf(loss):
                print(f"NaN/Inf detected, skipping batch")
                continue

            loss.backward()
            # Manual global gradient-norm computation, used only for the
            # warning below; clipping itself is done by clip_grad_norm_.
            total_norm = 0
            for p in model.parameters():
                if p.grad is not None:
                    param_norm = p.grad.data.norm(2)
                    total_norm += param_norm.item() ** 2
            total_norm = total_norm ** 0.5

            if total_norm > 100:
                print(f"Warning: Large gradient norm: {total_norm:.2f}")

            torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
            # NOTE(review): scheduler.step() is called BEFORE optimizer.step();
            # PyTorch expects the opposite order, so the LR schedule is shifted
            # by one step (and PyTorch will emit a warning) — confirm intent.
            scheduler.step()
            optimizer.step()
            total += loss.item()
            count += 1

            # Log to wandb every 50 batches to keep overhead low.
            if (index % 50 == 0) and run is not None:
                log_dict = {
                    "step": epoch * len(loader) + index,
                    "loss": loss.item(),
                    "rec_loss": rec_loss.item(),
                    "smooth_loss": smooth_loss.item(),
                    "ms_ssim": rec_metrics["ms_ssim_value"],
                    "lr": scheduler.get_last_lr()[0]
                }

                # Add loss-specific metrics
                if loss_type == "l1":
                    log_dict["l1_loss"] = rec_metrics["l1_loss"]
                elif loss_type == "mse":
                    log_dict["mse_loss"] = rec_metrics["mse_loss"]

                run.log(log_dict)

            pbar.set_postfix(
                loss=f"{loss.item():.4f}",
                rec=f"{rec_loss.item():.4f}",
                smooth=f"{smooth_loss.item():.4f}"
            )



        avg_loss = total/max(1, count)
        run.log({"avg_loss": avg_loss})
        print(f"epoch {epoch} avg loss={avg_loss:.4f}")

        # Save the state dict
        # NOTE(review): the checkpoint/push/validation below runs once per
        # epoch (indentation reconstructed from a whitespace-mangled source —
        # val_epoch logging strongly suggests per-epoch; confirm against the
        # original file).
        torch.save(model.state_dict(), "convlstm_model_weights.pth")

        # Generate unique repo name based on config and code
        date_label = datetime.now().strftime("%Y-%m-%d")

        # Collect all config for hashing
        config_for_hash = {
            "mode": "convlstm",
            "loss_type": loss_type,
            "ms_ssim_weight": ms_ssim_weight,
            "rec_weight": rec_weight,
            "temporal_weight": temporal_weight,
            "dropout_rate": dropout_rate,
            "use_convlstm": use_convlstm,
            "use_residual": use_residual,
            "use_batchnorm": use_batchnorm,
            "learning_rate": 2e-4,
            "encoder_hidden_dim": 256,
            "encoder_layers": 2,
            "decoder_hidden_dim": 128,
            "decoder_layers": 2,
            "latent_size": 4096,
            "seq_len": 50,
            "image_size": 128,
        }

        # Required files for ConvLSTM model
        required_files = [
            "train.py",
            "raffael_model.py",
            "raffael_losses.py",
            "raffael_conv_lstm.py",
            "dataset_ivf.py",
            "train_model.sh",
            "training_config.txt",
            "training_config_detailed.txt",
        ]

        # Generate unique repo name
        repo_name = generate_repo_name("convlstm", config_for_hash, required_files, date_label)

        # Create comprehensive config for HuggingFace
        hf_config = {
            "model_type": "ConvLSTMAutoencoder",
            "architecture": "ConvLSTM Autoencoder",
            # Model architecture parameters
            "seq_len": 50,
            "input_channels": 1,
            "encoder_hidden_dim": 256,
            "encoder_layers": 2,
            "decoder_hidden_dim": 128,
            "decoder_layers": 2,
            "latent_size": 4096,
            "use_classifier": False,
            "num_classes": 2,
            "use_latent_split": False,
            "image_size": 128,
            # Ablation parameters
            "dropout_rate": dropout_rate,
            "use_convlstm": use_convlstm,
            "use_residual": use_residual,
            "use_batchnorm": use_batchnorm,
            # Loss configuration
            "loss_type": loss_type,
            "ms_ssim_weight": ms_ssim_weight,
            "rec_weight": rec_weight,
            "temporal_weight": temporal_weight,
            "loss_description": loss_description,
            # Training configuration
            "learning_rate": 2e-4,
            "weight_decay": 1e-5,
            "optimizer": "Adam",
            "scheduler": "CosineAnnealingLR",
            "batch_size": 1,
            "epochs": 10,
            "gradient_clip": 5.0,
            # Dataset
            "dataset": "https://zenodo.org/records/7912264",
            "resize": 128,
            "normalization": "minmax01",
            # Reproducibility
            "repo_name": repo_name,
            "date": date_label,
            "hash": repo_name.split("-")[-2] if "-" in repo_name else "",
        }

        # NOTE(review): the generated repo_name is computed above but the
        # model is pushed under model_name + date instead — confirm which
        # repository naming scheme is intended.
        save_and_push_model(model, model_name +"-"+ date_label, required_files, model_config=hf_config)

        # Comprehensive validation with multiple metrics
        val_metrics = {
            'mse': 0.0,
            'l1': 0.0,
            'ms_ssim_value': 0.0,
            'ms_ssim_loss': 0.0,
            'temporal_smoothness': 0.0
        }
        val_count = 0

        model.eval()  # Set model to evaluation mode
        with torch.no_grad():
            for embryo_vol, _, _ in val_loader:
                embryo_vol = embryo_vol.to(DEVICE)  # (1, T, 1, H, W)
                val_recon, val_lat = model(embryo_vol)

                B, T, C, H, W = embryo_vol.shape

                # MSE
                val_metrics['mse'] += F.mse_loss(val_recon, embryo_vol).item()

                # L1
                val_metrics['l1'] += F.l1_loss(val_recon, embryo_vol).item()

                # MS-SSIM
                val_recon_flat = val_recon.view(B * T, C, H, W)
                embryo_vol_flat = embryo_vol.view(B * T, C, H, W)
                ms_ssim_val = ms_ssim(val_recon_flat, embryo_vol_flat)
                val_metrics['ms_ssim_value'] += ms_ssim_val.item()
                val_metrics['ms_ssim_loss'] += (1 - ms_ssim_val).item()

                # Temporal smoothness of latents
                # val_lat is (B, T, latent_size)
                if T > 1:
                    lat_diff = torch.diff(val_lat, dim=1)  # (B, T-1, latent_size)
                    temporal_smooth = lat_diff.norm(dim=-1).mean()  # Average L2 norm of differences
                    val_metrics['temporal_smoothness'] += temporal_smooth.item()

                val_count += 1

        # Average all metrics
        for key in val_metrics:
            val_metrics[key] /= max(1, val_count)

        # Log to wandb with val_ prefix
        val_log_dict = {
            f"val_{key}": value for key, value in val_metrics.items()
        }
        val_log_dict['val_epoch'] = epoch
        run.log(val_log_dict)

        print(f"Validation - MSE: {val_metrics['mse']:.4f}, L1: {val_metrics['l1']:.4f}, "
              f"MS-SSIM: {val_metrics['ms_ssim_value']:.4f}, Temporal Smoothness: {val_metrics['temporal_smoothness']:.4f}")

    run.finish()
    gc.collect()
    torch.cuda.empty_cache()
719
+
720
+
721
+ def train_convlstm_latent_split(
722
+ loss_type="l1",
723
+ ms_ssim_weight=0.5,
724
+ rec_weight=0.5,
725
+ temporal_weight=0.1,
726
+ dropout_rate=0.1,
727
+ use_convlstm=True,
728
+ use_residual=True,
729
+ use_batchnorm=True,
730
+ model_name =""
731
+ ):
732
+ gc.collect()
733
+ """Training ConvLSTM Autoencoder with LATENT SPLIT enabled (single GPU)
734
+
735
+ Args:
736
+ loss_type: "l1" or "mse" - type of reconstruction loss to use with MS-SSIM
737
+ ms_ssim_weight: float - weight for MS-SSIM loss (0 to disable)
738
+ rec_weight: float - weight for reconstruction loss L1/MSE (0 to disable)
739
+ temporal_weight: float - weight for temporal smoothness loss (0 to disable)
740
+ dropout_rate: float - dropout rate (0 to disable)
741
+ use_convlstm: bool - whether to use ConvLSTM (False = no temporal modeling)
742
+ use_residual: bool - whether to use residual connections
743
+ use_batchnorm: bool - whether to use batch normalization
744
+ """
745
+ print(torch.cuda.memory_summary(device=None, abbreviated=False))
746
+ torch.cuda.empty_cache()
747
+ torch.autograd.detect_anomaly(True)
748
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
749
+
750
+ # Build loss description for logging
751
+ loss_components = []
752
+ if ms_ssim_weight > 0:
753
+ loss_components.append(f"MS-SSIM({ms_ssim_weight})")
754
+ if rec_weight > 0:
755
+ loss_components.append(f"{loss_type.upper()}({rec_weight})")
756
+ if temporal_weight > 0:
757
+ loss_components.append(f"Temporal({temporal_weight})")
758
+ loss_description = " + ".join(loss_components) if loss_components else "None"
759
+
760
+ # Build model description for logging
761
+ model_features = []
762
+ if use_convlstm:
763
+ model_features.append("ConvLSTM")
764
+ if use_residual:
765
+ model_features.append("Residual")
766
+ if use_batchnorm:
767
+ model_features.append("BatchNorm")
768
+ if dropout_rate > 0:
769
+ model_features.append(f"Dropout({dropout_rate})")
770
+ model_description = "+".join(model_features) if model_features else "Baseline"
771
+ date_label = datetime.now().strftime("%Y-%m-%d")
772
+
773
+ wandb.login(key=os.getenv("WANDB_KEY"))
774
+ run = wandb.init(
775
+ entity="jenslundsgaard7-uw-madison",
776
+ project="IVF-Training",
777
+ name=model_name + "-" + date_label,
778
+ config={
779
+ "learning_rate": 0.02,
780
+ "architecture": "ConvLSTM Autoencoder with Latent Split",
781
+ "model_features": model_description,
782
+ "dataset": "https://zenodo.org/records/7912264",
783
+ "epochs": 10,
784
+ "train_split": 0.85,
785
+ "val_split": 0.15,
786
+ "loss": loss_description,
787
+ "loss_type": loss_type,
788
+ "ms_ssim_weight": ms_ssim_weight,
789
+ "rec_weight": rec_weight,
790
+ "temporal_weight": temporal_weight,
791
+ "dropout_rate": dropout_rate,
792
+ "use_convlstm": use_convlstm,
793
+ "use_residual": use_residual,
794
+ "use_batchnorm": use_batchnorm,
795
+ "latent_size": 4096,
796
+ "latent_split": True,
797
+ "embryo_latent_size": 2048,
798
+ "empty_latent_size": 2048,
799
+ "seq_len": 50,
800
+ "image_size": 128,
801
+ "distributed": False,
802
+ },
803
+ )
804
+
805
+ login(os.getenv("HF_KEY"))
806
+ print(torch.cuda.memory_summary(device=None, abbreviated=False))
807
+ print(DEVICE)
808
+ print(f"\n{'='*60}")
809
+ print(f"ABLATION STUDY - Training Configuration")
810
+ print(f"{'='*60}")
811
+ print(f"\nLoss Configuration:")
812
+ print(f" Base Loss Type: {loss_type.upper()}")
813
+ print(f" MS-SSIM Weight: {ms_ssim_weight} {'(DISABLED)' if ms_ssim_weight == 0 else ''}")
814
+ print(f" Reconstruction Weight: {rec_weight} {'(DISABLED)' if rec_weight == 0 else ''}")
815
+ print(f" Temporal Smoothness Weight: {temporal_weight} {'(DISABLED)' if temporal_weight == 0 else ''}")
816
+ print(f" Combined Loss: {loss_description}")
817
+ print(f"\nModel Architecture Configuration:")
818
+ print(f" ConvLSTM: {'ENABLED' if use_convlstm else 'DISABLED'}")
819
+ print(f" Residual Connections: {'ENABLED' if use_residual else 'DISABLED'}")
820
+ print(f" Batch Normalization: {'ENABLED' if use_batchnorm else 'DISABLED'}")
821
+ print(f" Dropout Rate: {dropout_rate} {'(DISABLED)' if dropout_rate == 0 else ''}")
822
+ print(f" Model Features: {model_description}")
823
+ print(f"\nLatent Configuration:")
824
+ print(f" Latent Split: ENABLED (2048 for empty, 2048 for embryo)")
825
+ print(f"{'='*60}\n")
826
+
827
+ # Save detailed training configuration
828
+ config_content = f"""ConvLSTM Autoencoder Training Configuration (LATENT SPLIT + ABLATION)
829
+ ================================================================================
830
+ Date: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
831
+
832
+ ABLATION STUDY CONFIGURATION
833
+ ================================================================================
834
+
835
+ """
836
+
837
+ with open("training_config_latent_split.txt", "w") as f:
838
+ f.write(config_content)
839
+
840
+ print("Configuration saved to training_config_latent_split.txt")
841
+
842
+ # Create model with LATENT SPLIT and ABLATION parameters
843
+ model = ConvLSTMAutoencoder(
844
+ seq_len=50,
845
+ input_channels=1,
846
+ encoder_hidden_dim=256,
847
+ encoder_layers=2,
848
+ decoder_hidden_dim=128,
849
+ decoder_layers=2,
850
+ latent_size=4096,
851
+ use_classifier=False,
852
+ num_classes=2,
853
+ use_latent_split=True, # ENABLE LATENT SPLIT
854
+ # Ablation parameters
855
+ dropout_rate=dropout_rate,
856
+ use_convlstm=use_convlstm,
857
+ use_residual=use_residual,
858
+ use_batchnorm=use_batchnorm
859
+ )
860
+
861
+ model = model.to(DEVICE)
862
+
863
+ learning_rate = 2e-4
864
+ optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)
865
+
866
+ df = pd.read_csv(os.path.abspath("index.csv"))
867
+ mask = df["cell_id"].str.contains("|".join(VAL_EMBRYOS), regex=True)
868
+ val_df = df[mask]
869
+ train_df = df[~mask]
870
+ train_dataset = IVFSequenceDataset(train_df, resize=128, norm="minmax01")
871
+ val_dataset = IVFSequenceDataset(val_df, resize=128, norm="minmax01")
872
+ print("val size: ", str(len(val_df) / len(df)))
873
+
874
+ #generator = torch.Generator().manual_seed(42)
875
+ #train_dataset, val_dataset = torch.utils.data.random_split(ds, [train_size, val_size], generator=generator)
876
+
877
+ # Create DataLoaders
878
+ loader = DataLoader(
879
+ train_dataset,
880
+ batch_size=1,
881
+ shuffle=True,
882
+ num_workers=4,
883
+ pin_memory=True,
884
+ drop_last=True
885
+ )
886
+ val_loader = DataLoader(
887
+ val_dataset,
888
+ batch_size=1,
889
+ shuffle=False, # No shuffle for validation
890
+ num_workers=4,
891
+ pin_memory=True,
892
+ drop_last=False # Don't drop last for validation
893
+ )
894
+
895
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, len(loader) * 10)
896
+
897
+ for epoch in range(10):
898
+ model.train()
899
+ pbar = tqdm(loader, desc=f"epoch {epoch}")
900
+ total = 0.0
901
+ count = 0
902
+
903
+ for index, (embryo_vol, empty_well_vol, _) in enumerate(pbar):
904
+ optimizer.zero_grad()
905
+
906
+ # embryo_vol and empty_well_vol are (1, T, 1, H, W)
907
+ embryo_vol = embryo_vol.to(DEVICE)
908
+ empty_well_vol = empty_well_vol.to(DEVICE)
909
+
910
+ # Forward pass for embryo (uses second half of latent: 2048:4096)
911
+ embryo_recon, embryo_lat = model(embryo_vol, empty_well=False)
912
+
913
+ # Forward pass for empty well (uses first half of latent: 0:2048)
914
+ empty_recon, empty_lat = model(empty_well_vol, empty_well=True)
915
+
916
+ # Reconstruction loss for embryo (with configurable weights)
917
+ if loss_type == "l1":
918
+ rec_loss_embryo, rec_metrics_embryo = convlstm_reconstruction_loss(
919
+ embryo_recon, embryo_vol, l1_weight=rec_weight, ms_ssim_weight=ms_ssim_weight
920
+ )
921
+ elif loss_type == "mse":
922
+ B, T, C, H, W = embryo_recon.shape
923
+ x_rec_flat = embryo_recon.view(B * T, C, H, W)
924
+ x_true_flat = embryo_vol.view(B * T, C, H, W)
925
+
926
+ mse_loss = F.mse_loss(embryo_recon, embryo_vol)
927
+ ms_ssim_val = ms_ssim(x_rec_flat, x_true_flat)
928
+ ms_ssim_loss = 1 - ms_ssim_val
929
+
930
+ rec_loss_embryo = rec_weight * mse_loss + ms_ssim_weight * ms_ssim_loss
931
+ rec_metrics_embryo = {
932
+ "mse_loss": mse_loss.item(),
933
+ "ms_ssim_loss": ms_ssim_loss.item(),
934
+ "ms_ssim_value": ms_ssim_val.item()
935
+ }
936
+ else:
937
+ raise ValueError(f"Invalid loss_type: {loss_type}. Must be 'l1' or 'mse'")
938
+
939
+ # Reconstruction loss for empty well (with configurable weights)
940
+ if loss_type == "l1":
941
+ rec_loss_empty, rec_metrics_empty = convlstm_reconstruction_loss(
942
+ empty_recon, empty_well_vol, l1_weight=rec_weight, ms_ssim_weight=ms_ssim_weight
943
+ )
944
+ elif loss_type == "mse":
945
+ B, T, C, H, W = empty_recon.shape
946
+ x_rec_flat = empty_recon.view(B * T, C, H, W)
947
+ x_true_flat = empty_well_vol.view(B * T, C, H, W)
948
+
949
+ mse_loss = F.mse_loss(empty_recon, empty_well_vol)
950
+ ms_ssim_val = ms_ssim(x_rec_flat, x_true_flat)
951
+ ms_ssim_loss = 1 - ms_ssim_val
952
+
953
+ rec_loss_empty = rec_weight * mse_loss + ms_ssim_weight * ms_ssim_loss
954
+ rec_metrics_empty = {
955
+ "mse_loss": mse_loss.item(),
956
+ "ms_ssim_loss": ms_ssim_loss.item(),
957
+ "ms_ssim_value": ms_ssim_val.item()
958
+ }
959
+
960
+ # Total reconstruction loss
961
+ rec_loss = rec_loss_embryo + rec_loss_empty
962
+
963
+ # Temporal smoothness loss (with configurable weight)
964
+ if temporal_weight > 0:
965
+ smooth_loss_embryo = temporal_smoothness_loss(embryo_lat, weight=temporal_weight)
966
+ smooth_loss_empty = temporal_smoothness_loss(empty_lat, weight=temporal_weight)
967
+ smooth_loss = smooth_loss_embryo + smooth_loss_empty
968
+ loss = rec_loss + smooth_loss
969
+ else:
970
+ smooth_loss = torch.tensor(0.0, device=DEVICE)
971
+ loss = rec_loss
972
+
973
+ if torch.isnan(loss) or torch.isinf(loss):
974
+ print(f"NaN/Inf detected, skipping batch")
975
+ continue
976
+
977
+ loss.backward()
978
+ total_norm = 0
979
+ for p in model.parameters():
980
+ if p.grad is not None:
981
+ param_norm = p.grad.data.norm(2)
982
+ total_norm += param_norm.item() ** 2
983
+ total_norm = total_norm ** 0.5
984
+
985
+ if total_norm > 100:
986
+ print(f"Warning: Large gradient norm: {total_norm:.2f}")
987
+
988
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
989
+ scheduler.step()
990
+ optimizer.step()
991
+ total += loss.item()
992
+ count += 1
993
+
994
+ if (index % 50 == 0) and run is not None:
995
+ log_dict = {
996
+ "step": epoch * len(loader) + index,
997
+ "loss": loss.item(),
998
+ "rec_loss": rec_loss.item(),
999
+ "rec_loss_embryo": rec_loss_embryo.item(),
1000
+ "rec_loss_empty": rec_loss_empty.item(),
1001
+ "smooth_loss": smooth_loss.item(),
1002
+ "ms_ssim_embryo": rec_metrics_embryo["ms_ssim_value"],
1003
+ "ms_ssim_empty": rec_metrics_empty["ms_ssim_value"],
1004
+ "lr": scheduler.get_last_lr()[0]
1005
+ }
1006
+
1007
+ # Add loss-specific metrics
1008
+ if loss_type == "l1":
1009
+ log_dict["l1_loss_embryo"] = rec_metrics_embryo["l1_loss"]
1010
+ log_dict["l1_loss_empty"] = rec_metrics_empty["l1_loss"]
1011
+ elif loss_type == "mse":
1012
+ log_dict["mse_loss_embryo"] = rec_metrics_embryo["mse_loss"]
1013
+ log_dict["mse_loss_empty"] = rec_metrics_empty["mse_loss"]
1014
+
1015
+ run.log(log_dict)
1016
+
1017
+ pbar.set_postfix(
1018
+ loss=f"{loss.item():.4f}",
1019
+ rec_e=f"{rec_loss_embryo.item():.4f}",
1020
+ rec_empty=f"{rec_loss_empty.item():.4f}",
1021
+ smooth=f"{smooth_loss.item():.4f}"
1022
+ )
1023
+
1024
+ avg_loss = total/max(1, count)
1025
+ run.log({"avg_loss": avg_loss})
1026
+ print(f"epoch {epoch} avg loss={avg_loss:.4f}")
1027
+
1028
+ # Save the state dict
1029
+ torch.save(model.state_dict(), "convlstm_latent_split_weights.pth")
1030
+
1031
+ # Generate unique repo name based on config and code
1032
+ date_label = datetime.now().strftime("%Y-%m-%d")
1033
+
1034
+ # Collect all config for hashing
1035
+ config_for_hash = {
1036
+ "mode": "convlstm_latent_split",
1037
+ "loss_type": loss_type,
1038
+ "ms_ssim_weight": ms_ssim_weight,
1039
+ "rec_weight": rec_weight,
1040
+ "temporal_weight": temporal_weight,
1041
+ "dropout_rate": dropout_rate,
1042
+ "use_convlstm": use_convlstm,
1043
+ "use_residual": use_residual,
1044
+ "use_batchnorm": use_batchnorm,
1045
+ "use_latent_split": True,
1046
+ "learning_rate": 2e-4,
1047
+ "encoder_hidden_dim": 256,
1048
+ "encoder_layers": 2,
1049
+ "decoder_hidden_dim": 128,
1050
+ "decoder_layers": 2,
1051
+ "latent_size": 4096,
1052
+ "embryo_latent_size": 2048,
1053
+ "empty_latent_size": 2048,
1054
+ "seq_len": 50,
1055
+ "image_size": 128,
1056
+ }
1057
+
1058
+ # Required files for ConvLSTM model with latent split
1059
+ required_files = [
1060
+ "train.py",
1061
+ "raffael_model.py",
1062
+ "raffael_losses.py",
1063
+ "raffael_conv_lstm.py",
1064
+ "dataset_ivf.py",
1065
+ "train_model.sh",
1066
+ "training_config.txt",
1067
+ "training_config_latent_split.txt",
1068
+ ]
1069
+
1070
+ # Generate unique repo name
1071
+ repo_name = generate_repo_name("convlstm-ls", config_for_hash, required_files, date_label)
1072
+
1073
+ # Create comprehensive config for HuggingFace
1074
+ hf_config = {
1075
+ "model_type": "ConvLSTMAutoencoder",
1076
+ "architecture": "ConvLSTM Autoencoder with Latent Split",
1077
+ # Model architecture parameters
1078
+ "seq_len": 50,
1079
+ "input_channels": 1,
1080
+ "encoder_hidden_dim": 256,
1081
+ "encoder_layers": 2,
1082
+ "decoder_hidden_dim": 128,
1083
+ "decoder_layers": 2,
1084
+ "latent_size": 4096,
1085
+ "use_classifier": False,
1086
+ "num_classes": 2,
1087
+ "use_latent_split": True,
1088
+ "embryo_latent_size": 2048,
1089
+ "empty_latent_size": 2048,
1090
+ "image_size": 128,
1091
+ # Ablation parameters
1092
+ "dropout_rate": dropout_rate,
1093
+ "use_convlstm": use_convlstm,
1094
+ "use_residual": use_residual,
1095
+ "use_batchnorm": use_batchnorm,
1096
+ # Loss configuration
1097
+ "loss_type": loss_type,
1098
+ "ms_ssim_weight": ms_ssim_weight,
1099
+ "rec_weight": rec_weight,
1100
+ "temporal_weight": temporal_weight,
1101
+ "loss_description": loss_description,
1102
+ # Training configuration
1103
+ "learning_rate": 2e-4,
1104
+ "weight_decay": 1e-5,
1105
+ "optimizer": "Adam",
1106
+ "scheduler": "CosineAnnealingLR",
1107
+ "batch_size": 1,
1108
+ "epochs": 10,
1109
+ "gradient_clip": 5.0,
1110
+ # Dataset
1111
+ "dataset": "https://zenodo.org/records/7912264",
1112
+ "resize": 128,
1113
+ "normalization": "minmax01",
1114
+ # Reproducibility
1115
+ "repo_name": repo_name,
1116
+ "date": date_label,
1117
+ "hash": repo_name.split("-")[-2] if "-" in repo_name else "",
1118
+ }
1119
+
1120
+ save_and_push_model(model, model_name + "-" + date_label, required_files, model_config=hf_config)
1121
+ val_metrics = {
1122
+ 'mse': 0.0,
1123
+ 'l1': 0.0,
1124
+ 'ms_ssim_value': 0.0,
1125
+ 'ms_ssim_loss': 0.0,
1126
+ 'temporal_smoothness': 0.0
1127
+ }
1128
+ val_count = 0
1129
+
1130
+ model.eval() # Set model to evaluation mode
1131
+ with torch.no_grad():
1132
+ for embryo_vol, _, _ in val_loader:
1133
+ embryo_vol = embryo_vol.to(DEVICE) # (1, T, 1, H, W)
1134
+ val_recon, val_lat = model(embryo_vol, empty_well=False)
1135
+ _, empty_val_lat = model(embryo_vol, empty_well=True)
1136
+ val_lat = torch.cat([val_lat, empty_val_lat], dim= 2)
1137
+ B, T, C, H, W = embryo_vol.shape
1138
+
1139
+ # MSE
1140
+ val_metrics['mse'] += F.mse_loss(val_recon, embryo_vol).item()
1141
+
1142
+ # L1
1143
+ val_metrics['l1'] += F.l1_loss(val_recon, embryo_vol).item()
1144
+
1145
+ # MS-SSIM
1146
+ val_recon_flat = val_recon.view(B * T, C, H, W)
1147
+ embryo_vol_flat = embryo_vol.view(B * T, C, H, W)
1148
+ ms_ssim_val = ms_ssim(val_recon_flat, embryo_vol_flat)
1149
+ val_metrics['ms_ssim_value'] += ms_ssim_val.item()
1150
+ val_metrics['ms_ssim_loss'] += (1 - ms_ssim_val).item()
1151
+
1152
+ # Temporal smoothness of latents
1153
+ # val_lat is (B, T, latent_size)
1154
+ if T > 1:
1155
+ lat_diff = torch.diff(val_lat, dim=1) # (B, T-1, latent_size)
1156
+ temporal_smooth = lat_diff.norm(dim=-1).mean() # Average L2 norm of differences
1157
+ val_metrics['temporal_smoothness'] += temporal_smooth.item()
1158
+
1159
+ val_count += 1
1160
+
1161
+ # Average all metrics
1162
+ for key in val_metrics:
1163
+ val_metrics[key] /= max(1, val_count)
1164
+
1165
+ # Log to wandb with val_ prefix
1166
+ val_log_dict = {
1167
+ f"val_{key}": value for key, value in val_metrics.items()
1168
+ }
1169
+ val_log_dict['val_epoch'] = epoch
1170
+ run.log(val_log_dict)
1171
+
1172
+
1173
+ run.finish()
1174
+ gc.collect()
1175
+ torch.cuda.empty_cache()
1176
+
1177
def train_mse_distributed():
    """Stub entry point for distributed MSE training (not yet implemented)."""
    message = "hi"
    print(message)
def train_mse_single():
    """Stub entry point for single-device MSE training (not yet implemented)."""
    message = "hi"
    print(message)
def train():
    """Stub fallback training entry point (not yet implemented)."""
    message = "hi"
    print(message)
+
1184
if __name__ == "__main__":
    import sys
    import argparse

    def _build_ablation_parser(description):
        """Build the CLI parser shared by the ConvLSTM ablation-study modes.

        The ``convlstm`` and ``convlstm_latent_split`` modes accept exactly
        the same flags; only the parser description differs, so it is taken
        as a parameter instead of duplicating ~30 lines of parser setup.
        """
        parser = argparse.ArgumentParser(description=description)
        parser.add_argument("mode", type=str, help="Training mode")

        # Loss ablation arguments
        parser.add_argument("--loss-type", type=str, default="l1", choices=["l1", "mse"],
                            help="Reconstruction loss type: l1 or mse (default: l1)")
        parser.add_argument("--ms-ssim-weight", type=float, default=0.5,
                            help="Weight for MS-SSIM loss (default: 0.5, set to 0 to disable)")
        parser.add_argument("--rec-weight", type=float, default=0.5,
                            help="Weight for reconstruction loss (default: 0.5, set to 0 to disable)")
        parser.add_argument("--temporal-weight", type=float, default=0.1,
                            help="Weight for temporal smoothness loss (default: 0.1, set to 0 to disable)")

        # Model ablation arguments
        parser.add_argument("--dropout-rate", type=float, default=0.1,
                            help="Dropout rate (default: 0.1, set to 0 to disable)")
        parser.add_argument("--no-convlstm", action="store_true",
                            help="Disable ConvLSTM (no temporal modeling)")
        parser.add_argument("--no-residual", action="store_true",
                            help="Disable residual connections")
        parser.add_argument("--no-batchnorm", action="store_true",
                            help="Disable batch normalization")
        parser.add_argument("--name", type=str, default="", help="model name duhh")
        return parser

    def _ablation_kwargs(args):
        """Translate parsed CLI args into the keyword arguments shared by
        both ConvLSTM training entry points (the ``--no-*`` flags are
        inverted into ``use_*`` booleans)."""
        return {
            "loss_type": args.loss_type,
            "ms_ssim_weight": args.ms_ssim_weight,
            "rec_weight": args.rec_weight,
            "temporal_weight": args.temporal_weight,
            "dropout_rate": args.dropout_rate,
            "use_convlstm": not args.no_convlstm,
            "use_residual": not args.no_residual,
            "use_batchnorm": not args.no_batchnorm,
            "model_name": args.name,
        }

    # Dispatch on the positional mode argument; an unknown or absent mode
    # falls through to the generic train() entry point, matching the
    # original command-line interface.
    known_modes = ["mse_distributed", "mse_single", "convlstm", "convlstm_latent_split"]
    if len(sys.argv) > 1 and sys.argv[1] in known_modes:
        mode = sys.argv[1]
        if mode == "mse_distributed":
            train_mse_distributed()
        elif mode == "mse_single":
            train_mse_single()
        elif mode == "convlstm":
            args = _build_ablation_parser(
                "Train ConvLSTM Autoencoder with Ablation Studies"
            ).parse_args()
            train_convlstm(**_ablation_kwargs(args))
        elif mode == "convlstm_latent_split":
            args = _build_ablation_parser(
                "Train ConvLSTM Autoencoder with Latent Split and Ablation Studies"
            ).parse_args()
            train_convlstm_latent_split(**_ablation_kwargs(args))
    else:
        train()