AnikS22 committed on
Commit
db4158f
·
verified ·
1 Parent(s): 7f58aa4

Upload train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train.py +387 -0
train.py ADDED
@@ -0,0 +1,387 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Main training script for immunogold CenterNet.
3
+
4
+ Usage:
5
+ python train.py --fold S1 --seed 42 --config config/config.yaml
6
+ python train.py --fold S1 --seed 42 --config config/config.yaml --dry-run
7
+ python train.py --fold S1 --seed 42 --config config/config.yaml --device cuda:0
8
+ """
9
+
10
+ import argparse
11
+ import os
12
+ import random
13
+ import sys
14
+ from pathlib import Path
15
+
16
+ import numpy as np
17
+ import torch
18
+ import yaml
19
+ from torch.utils.data import DataLoader
20
+ from torch.utils.tensorboard import SummaryWriter
21
+
22
+ from src.dataset import ImmunogoldDataset
23
+ from src.evaluate import match_detections_to_gt
24
+ from src.heatmap import extract_peaks
25
+ from src.loss import total_loss
26
+ from src.model import ImmunogoldCenterNet
27
+ from src.preprocessing import discover_synapse_data, load_synapse
28
+ from src.ensemble import sliding_window_inference
29
+ from src.postprocess import cross_class_nms
30
+
31
+
32
def set_seed(seed: int):
    """Seed every random number generator used by the training run.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU, plus all CUDA devices when CUDA is available), then puts cuDNN into
    deterministic mode with autotuning switched off.

    Args:
        seed: Integer seed shared by all generators.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Autotuning may select different convolution kernels between runs, so it
    # is disabled together with enabling the deterministic implementations.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
41
+
42
+
43
def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is the
            behaviour existing callers rely on; passing a list allows the
            parser to be driven programmatically.

    Returns:
        ``argparse.Namespace`` with the attributes ``fold``, ``seed``,
        ``config``, ``device``, ``dry_run`` and ``resume``.
    """
    parser = argparse.ArgumentParser(description="Train immunogold CenterNet")
    parser.add_argument("--fold", type=str, required=True,
                        help="Synapse ID to hold out (e.g., S1)")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--config", type=str, default="config/config.yaml")
    parser.add_argument("--device", type=str, default="auto",
                        help="Device: auto, cpu, cuda, cuda:0, etc.")
    parser.add_argument("--dry-run", action="store_true",
                        help="Load data, build model, run 1 batch, exit")
    parser.add_argument("--resume", type=str, default=None,
                        help="Path to checkpoint to resume from")
    return parser.parse_args(argv)
56
+
57
+
58
def get_device(device_str: str) -> torch.device:
    """Turn the ``--device`` option into a ``torch.device``.

    Args:
        device_str: ``"auto"`` to pick CUDA when it is available and CPU
            otherwise; any other value is handed to ``torch.device`` as is.

    Returns:
        The selected ``torch.device``.
    """
    if device_str != "auto":
        return torch.device(device_str)
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
62
+
63
+
64
def validate_epoch(
    model, val_data, device, cfg, conf_threshold=0.3,
):
    """Score the model on the held-out synapse with sliding-window inference.

    Args:
        model: Network to evaluate. It is switched to eval mode and left
            there; callers that keep training must call ``model.train()``.
        val_data: Held-out synapse data; the keys ``"synapse_id"``,
            ``"image"`` and ``"annotations"`` are read.
        device: Device handed to ``sliding_window_inference``.
        cfg: Parsed config. Reads ``data.patch_size``, ``data.stride``, the
            optional ``data.incomplete_6nm`` list,
            ``postprocessing.nms_kernel_size``,
            ``postprocessing.cross_class_nms_distance_px`` and
            ``evaluation.match_radii_px``.
        conf_threshold: Confidence threshold passed to ``extract_peaks``.

    Returns:
        Dict with ``val_f1_6nm`` (NaN when the synapse is listed in
        ``data.incomplete_6nm``), ``val_f1_12nm``, ``val_f1_mean``,
        ``detections`` and the raw ``results`` returned by
        ``match_detections_to_gt``. No validation loss is computed.
    """
    model.eval()
    has_6nm = val_data["synapse_id"] not in cfg["data"].get("incomplete_6nm", [])

    with torch.no_grad():
        heatmap_np, offset_np = sliding_window_inference(
            model, val_data["image"],
            patch_size=cfg["data"]["patch_size"],
            device=device,
        )

    # Extract detections
    heatmap_t = torch.from_numpy(heatmap_np)
    offset_t = torch.from_numpy(offset_np)

    detections = extract_peaks(
        heatmap_t, offset_t,
        stride=cfg["data"]["stride"],
        conf_threshold=conf_threshold,
        nms_kernel_sizes=cfg["postprocessing"]["nms_kernel_size"],
    )
    detections = cross_class_nms(
        detections,
        cfg["postprocessing"]["cross_class_nms_distance_px"],
    )

    # Evaluate; a class missing from the annotations is treated as having no
    # ground-truth points.
    gt = val_data["annotations"]
    results = match_detections_to_gt(
        detections,
        gt.get("6nm", np.empty((0, 2))),
        gt.get("12nm", np.empty((0, 2))),
        match_radii={k: float(v) for k, v in cfg["evaluation"]["match_radii_px"].items()},
    )

    # NOTE(review): only the reported 6nm score is masked for incomplete
    # folds; ``results["mean_f1"]`` is passed through unchanged and may still
    # include the 6nm score -- confirm against match_detections_to_gt.
    return {
        "val_f1_6nm": results["6nm"]["f1"] if has_6nm else float("nan"),
        "val_f1_12nm": results["12nm"]["f1"],
        "val_f1_mean": results["mean_f1"],
        "detections": detections,
        "results": results,
    }
113
+
114
+
115
def train_phase(
    model, train_loader, optimizer, scheduler, device, cfg,
    phase_num, n_epochs, writer, global_epoch, val_data,
    best_f1, checkpoint_dir, snapshot_epochs,
):
    """Train one phase and return the updated ``(global_epoch, best_f1)``.

    Each epoch iterates ``train_loader`` once, steps ``scheduler`` (when one
    is given) and logs the averaged losses to ``writer``. Validation runs
    whenever ``global_epoch`` is a multiple of 5 and on the last epoch of the
    phase. An improved mean F1 writes ``phase{phase_num}_best.pth``; epochs
    listed in ``snapshot_epochs`` write ``phase{phase_num}_{epoch}.pth``.

    Args:
        model: Network being trained.
        train_loader: Iterable of batches with the keys ``"image"``,
            ``"heatmap"``, ``"offsets"``, ``"offset_mask"`` and ``"conf_map"``.
        optimizer: Optimizer stepped once per batch.
        scheduler: LR scheduler stepped once per epoch, or ``None``.
        device: Device the batch tensors are moved to.
        cfg: Parsed config; reads ``training.loss`` and
            ``training.early_stopping.patience`` and is forwarded to
            ``validate_epoch``.
        phase_num: Phase index used in log tags and checkpoint names.
        n_epochs: Maximum number of epochs for this phase.
        writer: TensorBoard ``SummaryWriter``.
        global_epoch: Epoch counter carried across phases.
        val_data: Held-out synapse passed to ``validate_epoch``.
        best_f1: Best validation mean F1 seen so far.
        checkpoint_dir: ``Path`` of the directory receiving checkpoints.
        snapshot_epochs: Collection of global epochs to snapshot.

    Returns:
        Tuple ``(global_epoch, best_f1)`` after the phase.
    """
    model.train()
    loss_cfg = cfg["training"]["loss"]
    focal_alpha = loss_cfg["focal_alpha"]
    focal_beta = loss_cfg["focal_beta"]
    lambda_offset = loss_cfg["lambda_offset"]
    patience = cfg["training"]["early_stopping"]["patience"]
    val_interval = 5
    max_grad_norm = 5.0
    no_improve = 0
    # Epochs elapsed since the previous validation. The gap is not always
    # ``val_interval``: the final epoch of a phase is validated as well, and a
    # phase may begin between two multiples of the interval.
    epochs_since_val = 0

    for epoch in range(n_epochs):
        global_epoch += 1
        epochs_since_val += 1
        epoch_loss = 0.0
        epoch_hm_loss = 0.0
        epoch_off_loss = 0.0
        n_batches = 0

        # validate_epoch leaves the model in eval mode, so re-enable training
        # mode at the start of every epoch.
        model.train()
        for batch in train_loader:
            images = batch["image"].to(device)
            hm_gt = batch["heatmap"].to(device)
            off_gt = batch["offsets"].to(device)
            off_mask = batch["offset_mask"].to(device)
            conf_map = batch["conf_map"].to(device)

            optimizer.zero_grad()
            hm_pred, off_pred = model(images)

            loss, hm_loss, off_loss = total_loss(
                hm_pred, hm_gt, off_pred, off_gt, off_mask,
                lambda_offset=lambda_offset,
                focal_alpha=focal_alpha,
                focal_beta=focal_beta,
                conf_weights=conf_map,
            )

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
            optimizer.step()

            epoch_loss += loss.item()
            # float() accepts both Python numbers and 0-dim tensors, so the
            # running sums never hold on to tensors (whether total_loss
            # returns floats or tensors is not visible here).
            epoch_hm_loss += float(hm_loss)
            epoch_off_loss += float(off_loss)
            n_batches += 1

        if scheduler is not None:
            scheduler.step()

        # max(..., 1) guards against an empty loader.
        avg_loss = epoch_loss / max(n_batches, 1)
        avg_hm = epoch_hm_loss / max(n_batches, 1)
        avg_off = epoch_off_loss / max(n_batches, 1)

        # Log
        writer.add_scalar(f"Phase{phase_num}/train_loss", avg_loss, global_epoch)
        writer.add_scalar(f"Phase{phase_num}/hm_loss", avg_hm, global_epoch)
        writer.add_scalar(f"Phase{phase_num}/off_loss", avg_off, global_epoch)

        # Validate every ``val_interval`` epochs and on the phase's last epoch
        val_metrics = None
        if global_epoch % val_interval == 0 or epoch == n_epochs - 1:
            val_metrics = validate_epoch(model, val_data, device, cfg)
            writer.add_scalar(f"Phase{phase_num}/val_f1_mean", val_metrics["val_f1_mean"], global_epoch)

            if not np.isnan(val_metrics["val_f1_6nm"]):
                writer.add_scalar(f"Phase{phase_num}/val_f1_6nm", val_metrics["val_f1_6nm"], global_epoch)
            writer.add_scalar(f"Phase{phase_num}/val_f1_12nm", val_metrics["val_f1_12nm"], global_epoch)

            # Early stopping check
            if val_metrics["val_f1_mean"] > best_f1:
                best_f1 = val_metrics["val_f1_mean"]
                no_improve = 0
                # Save best checkpoint
                torch.save({
                    "epoch": global_epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "val_f1_mean": best_f1,
                    "phase": phase_num,
                }, checkpoint_dir / f"phase{phase_num}_best.pth")
            else:
                # Count the epochs that actually passed without improvement.
                no_improve += epochs_since_val
            epochs_since_val = 0

        # Snapshot checkpoints
        if global_epoch in snapshot_epochs:
            torch.save({
                "epoch": global_epoch,
                "model_state_dict": model.state_dict(),
                "val_f1_mean": best_f1,
                "phase": phase_num,
            }, checkpoint_dir / f"phase{phase_num}_{global_epoch}.pth")

        # Status
        f1_str = (
            f", val_f1={val_metrics['val_f1_mean']:.4f}"
            if val_metrics is not None else ""
        )
        print(
            f"  Phase {phase_num} | Epoch {global_epoch:3d} | "
            f"Loss {avg_loss:.4f} (hm={avg_hm:.4f}, off={avg_off:.4f})"
            f"{f1_str}"
        )

        if no_improve >= patience:
            print(f"  Early stopping at epoch {global_epoch} (patience={patience})")
            break

    return global_epoch, best_f1
222
+
223
+
224
def main():
    """Run the three-phase training schedule for one held-out fold.

    Steps: parse the CLI, load the YAML config, seed the RNGs, load the
    held-out synapse, build the training loader and the model, optionally
    warm-start from ``--resume``, then either perform a one-batch dry run or
    train phase 1 (frozen encoder), phase 2 (deep layers unfrozen) and
    phase 3 (everything unfrozen). The TensorBoard writer is closed even when
    training raises.

    Raises:
        ValueError: If ``--fold`` does not match any discovered synapse.
    """
    args = parse_args()
    with open(args.config, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    set_seed(args.seed)
    device = get_device(args.device)
    print(f"Device: {device}, Fold: {args.fold}, Seed: {args.seed}")

    # Discover data
    records = discover_synapse_data(
        cfg["data"]["root"], cfg["data"]["synapse_ids"]
    )

    # Load validation image
    val_record = [r for r in records if r.synapse_id == args.fold]
    if not val_record:
        raise ValueError(f"Fold {args.fold} not found in synapse IDs")
    val_data = load_synapse(val_record[0])

    # Create dataset
    train_dataset = ImmunogoldDataset(
        records=records,
        fold_id=args.fold,
        mode="train",
        patch_size=cfg["data"]["patch_size"],
        stride=cfg["data"]["stride"],
        hard_mining_fraction=cfg["training"]["hard_mining_fraction"],
        copy_paste_per_class=cfg["training"]["copy_paste_per_class"],
        sigmas=cfg["heatmap"]["sigmas"],
        samples_per_epoch=500,
        seed=args.seed,
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=cfg["training"]["batch_size"],
        shuffle=True,
        num_workers=4,
        # Pinned host memory only helps host-to-GPU copies; requesting it on a
        # CPU-only run just triggers a warning.
        pin_memory=device.type == "cuda",
        drop_last=True,
        worker_init_fn=ImmunogoldDataset.worker_init_fn,
    )

    # Build model
    pretrained = cfg["model"]["pretrained_weights"]
    if pretrained and not Path(pretrained).exists():
        print(f"Warning: CEM500K weights not found at {pretrained}, using ImageNet")
        pretrained = None

    model = ImmunogoldCenterNet(
        pretrained_path=pretrained,
        bifpn_channels=cfg["model"]["bifpn_channels"],
        bifpn_rounds=cfg["model"]["bifpn_rounds"],
        num_classes=cfg["model"]["num_classes"],
    ).to(device)

    if args.resume:
        # SECURITY: weights_only=False performs full unpickling, which can run
        # arbitrary code. It is used because the checkpoints written by
        # train_phase carry plain metadata next to the tensors. Only pass
        # checkpoint files you trust.
        checkpoint = torch.load(
            args.resume, map_location=device, weights_only=False
        )
        # Accept both the dicts written by train_phase and a bare state dict.
        model.load_state_dict(checkpoint.get("model_state_dict", checkpoint))
        print(
            f"Loaded model weights from {args.resume} "
            "(optimizer state and epoch counter are not restored)"
        )

    param_count = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {param_count:,}")

    # Checkpoint directory
    checkpoint_dir = Path("checkpoints") / f"fold_{args.fold}_seed{args.seed}"
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # TensorBoard
    writer = SummaryWriter(log_dir=f"logs/fold_{args.fold}_seed{args.seed}")

    # Snapshot epochs for ensemble
    snapshot_epochs = set(cfg["training"]["n_snapshot_epochs"])

    # try/finally so the event file is flushed and closed on every exit path.
    try:
        # --- Dry run ---
        if args.dry_run:
            print("=== DRY RUN ===")
            batch = next(iter(train_loader))
            images = batch["image"].to(device)
            print(f"Input shape: {images.shape}")
            hm, off = model(images)
            print(f"Heatmap shape: {hm.shape}, Offset shape: {off.shape}")

            loss_val, hm_loss, off_loss = total_loss(
                hm, batch["heatmap"].to(device),
                off, batch["offsets"].to(device),
                batch["offset_mask"].to(device),
            )
            print(f"Loss: {loss_val.item():.4f} (hm={hm_loss:.4f}, off={off_loss:.4f})")
            print("=== DRY RUN PASSED ===")
            return

        # --- Phase 1: Frozen encoder ---
        print("\n=== Phase 1: Frozen encoder ===")
        phase1_cfg = cfg["training"]["phases"]["phase1"]
        model.freeze_encoder()

        param_groups = model.get_param_groups(1, phase1_cfg)
        optimizer = torch.optim.AdamW(
            param_groups, weight_decay=phase1_cfg["weight_decay"]
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=20, T_mult=2
        )

        global_epoch = 0
        best_f1 = 0.0

        global_epoch, best_f1 = train_phase(
            model, train_loader, optimizer, scheduler, device, cfg,
            phase_num=1, n_epochs=phase1_cfg["epochs"],
            writer=writer, global_epoch=global_epoch,
            val_data=val_data, best_f1=best_f1,
            checkpoint_dir=checkpoint_dir,
            snapshot_epochs=snapshot_epochs,
        )

        # --- Phase 2: Unfreeze deep layers ---
        print("\n=== Phase 2: Unfreeze layer3+layer4 ===")
        phase2_cfg = cfg["training"]["phases"]["phase2"]
        model.unfreeze_deep_layers()

        param_groups = model.get_param_groups(2, phase2_cfg)
        optimizer = torch.optim.AdamW(
            param_groups, weight_decay=phase2_cfg["weight_decay"]
        )
        scheduler = None  # No scheduler for phase 2

        global_epoch, best_f1 = train_phase(
            model, train_loader, optimizer, scheduler, device, cfg,
            phase_num=2, n_epochs=phase2_cfg["epochs"],
            writer=writer, global_epoch=global_epoch,
            val_data=val_data, best_f1=best_f1,
            checkpoint_dir=checkpoint_dir,
            snapshot_epochs=snapshot_epochs,
        )

        # --- Phase 3: Full fine-tuning ---
        print("\n=== Phase 3: Full fine-tuning ===")
        phase3_cfg = cfg["training"]["phases"]["phase3"]
        model.unfreeze_all()

        param_groups = model.get_param_groups(3, phase3_cfg)
        optimizer = torch.optim.AdamW(
            param_groups, weight_decay=phase3_cfg["weight_decay"]
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=phase3_cfg["epochs"],
            eta_min=phase3_cfg["eta_min"],
        )

        global_epoch, best_f1 = train_phase(
            model, train_loader, optimizer, scheduler, device, cfg,
            phase_num=3, n_epochs=phase3_cfg["epochs"],
            writer=writer, global_epoch=global_epoch,
            val_data=val_data, best_f1=best_f1,
            checkpoint_dir=checkpoint_dir,
            snapshot_epochs=snapshot_epochs,
        )

        print(f"\nTraining complete. Best val F1: {best_f1:.4f}")
        print(f"Checkpoints saved to: {checkpoint_dir}")
    finally:
        writer.close()


if __name__ == "__main__":
    main()