AnikS22 committed on
Commit
357520b
·
verified ·
1 Parent(s): db4158f

Upload train_final.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_final.py +205 -0
train_final.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Train the final deployable model on ALL 10 images (no holdout).
3
+
4
+ LOOCV proved F1=0.94. This trains the production model using every
5
+ labeled particle for maximum generalization to new unseen images.
6
+
7
+ Usage:
8
+ python train_final.py --config config/config.yaml --device cuda:0
9
+ python train_final.py --config config/config.yaml --device mps
10
+ """
11
+
12
+ import argparse
13
+ import random
14
+ import time
15
+ from pathlib import Path
16
+
17
+ import numpy as np
18
+ import torch
19
+ import yaml
20
+ from torch.utils.data import DataLoader
21
+
22
+ from src.dataset import ImmunogoldDataset
23
+ from src.model import ImmunogoldCenterNet
24
+ from src.loss import total_loss
25
+ from src.preprocessing import discover_synapse_data, load_synapse
26
+
27
+
28
def set_seed(seed: int):
    """Seed the Python, NumPy and PyTorch random number generators.

    The same ``seed`` is handed to ``random``, ``numpy.random`` and
    ``torch``; when CUDA is available, every CUDA device generator is
    seeded as well.

    Args:
        seed: Integer seed applied to every generator.
    """
    seeders = [random.seed, np.random.seed, torch.manual_seed]
    if torch.cuda.is_available():
        # Only touch the CUDA generators when a CUDA runtime is present.
        seeders.append(torch.cuda.manual_seed_all)
    for apply_seed in seeders:
        apply_seed(seed)
34
+
35
+
36
def train_epoch(model, loader, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Args:
        model: Network called as ``model(images)``; its result is unpacked
            into a ``(heatmap_pred, offset_pred)`` pair.
        loader: Iterable of batch dicts providing the keys ``"image"``,
            ``"heatmap"``, ``"offsets"``, ``"offset_mask"`` and
            ``"conf_map"``; each value is moved to ``device`` with ``.to``.
        optimizer: Optimizer stepped once per batch.
        device: Device every batch tensor is moved to before use.

    Returns:
        The arithmetic mean of the per-batch total loss as a Python float.

    Raises:
        ValueError: If ``loader`` yields no batches. This can happen when
            the loader drops incomplete batches and the dataset is smaller
            than one batch; previously it surfaced as a bare
            ``ZeroDivisionError``.
    """
    model.train()
    loss_sum = 0.0
    num_batches = 0
    for batch in loader:
        images = batch["image"].to(device)
        optimizer.zero_grad()
        heatmap_pred, offset_pred = model(images)
        # total_loss returns three values; only the first (the combined
        # loss) is needed for backpropagation and reporting.
        loss, _, _ = total_loss(
            heatmap_pred, batch["heatmap"].to(device),
            offset_pred, batch["offsets"].to(device),
            batch["offset_mask"].to(device),
            conf_weights=batch["conf_map"].to(device),
        )
        loss.backward()
        # Cap the global gradient norm at 5.0 before the parameter update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()
        loss_sum += loss.item()
        num_batches += 1

    if num_batches == 0:
        raise ValueError(
            "train_epoch received no batches: the loader is empty "
            "(check the dataset size against batch_size / drop_last)"
        )
    return loss_sum / num_batches
56
+
57
+
58
def _select_device(name):
    """Return the torch device for ``name``; ``"auto"`` tries cuda, then mps, then cpu."""
    if name != "auto":
        return torch.device(name)
    if torch.cuda.is_available():
        return torch.device("cuda")
    # torch.backends.mps does not exist on older PyTorch builds, so look it
    # up defensively instead of assuming the attribute is present.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def _run_phase(model, loader, optimizer, device, epochs, start, *,
               scheduler=None, checkpoint_dir=None, checkpoint_prefix=None):
    """Train once per epoch number in ``epochs``, logging every 10th epoch.

    ``scheduler`` (if given) is stepped after every epoch. When
    ``checkpoint_dir`` is given, a checkpoint named
    ``{checkpoint_prefix}_{epoch}.pth`` is written on every logged epoch.
    ``start`` is the ``time.time()`` value used to report elapsed seconds.
    """
    for ep in epochs:
        loss = train_epoch(model, loader, optimizer, device)
        if scheduler is not None:
            scheduler.step()
        if ep % 10 != 0:
            continue
        elapsed = time.time() - start
        print(f" Epoch {ep:3d} | loss={loss:.4f} | {elapsed:.0f}s")
        if checkpoint_dir is not None:
            torch.save({
                "model_state_dict": model.state_dict(),
                "epoch": ep,
            }, checkpoint_dir / f"{checkpoint_prefix}_{ep}.pth")


def main():
    """Train the final model on every discovered image and save checkpoints.

    Training runs in three phases over 140 epochs in total:

    1. epochs 1-40 with the encoder frozen,
    2. epochs 41-80 after ``model.unfreeze_deep_layers()``,
    3. epochs 81-140 after ``model.unfreeze_all()``.

    ``phase1.pth``, ``phase2.pth``, ``phase3_<epoch>.pth`` (every 10th
    epoch of phase 3) and ``final_model.pth`` are written to the output
    directory.

    Command-line flags: ``--config``, ``--device``, ``--seed``,
    ``--num-workers`` and ``--out-dir``. The defaults of the last two
    reproduce the previously hard-coded values.
    """
    parser = argparse.ArgumentParser(description="Train final deployable model")
    parser.add_argument("--config", default="config/config.yaml")
    parser.add_argument("--device", default="auto")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--num-workers", type=int, default=4,
                        help="DataLoader worker processes (0 loads in the main process)")
    parser.add_argument("--out-dir", default="checkpoints/final",
                        help="Directory that receives the checkpoints")
    args = parser.parse_args()

    with open(args.config, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    set_seed(args.seed)

    device = _select_device(args.device)
    print(f"Device: {device}")

    # Load ALL data — no holdout
    records = discover_synapse_data(cfg["data"]["root"], cfg["data"]["synapse_ids"])

    # NOTE(review): "__NONE__" is presumably a sentinel fold id that matches
    # no image, so nothing is excluded — confirm against ImmunogoldDataset.
    dataset = ImmunogoldDataset(
        records=records,
        fold_id="__NONE__",  # no image excluded
        mode="train",
        patch_size=cfg["data"]["patch_size"],
        stride=cfg["data"]["stride"],
        hard_mining_fraction=cfg["training"]["hard_mining_fraction"],
        copy_paste_per_class=cfg["training"]["copy_paste_per_class"],
        sigmas=cfg["heatmap"]["sigmas"],
        samples_per_epoch=500,
        seed=args.seed,
    )

    loader = DataLoader(
        dataset, batch_size=cfg["training"]["batch_size"],
        shuffle=True, num_workers=args.num_workers, drop_last=True,
        worker_init_fn=ImmunogoldDataset.worker_init_fn,
    )

    n_particles = sum(
        len(ann["6nm"]) + len(ann["12nm"])
        for ann in dataset.annotations.values()
    )
    print(f"Training on ALL {len(dataset.images)} images, {n_particles} particles")

    # Model. A null or absent ``pretrained_weights`` entry is treated like a
    # missing file instead of crashing inside Path().
    pretrained = cfg["model"].get("pretrained_weights")
    if not pretrained or not Path(pretrained).exists():
        pretrained = None
        print("Warning: CEM500K weights not found, using ImageNet")

    model = ImmunogoldCenterNet(
        pretrained_path=pretrained,
        bifpn_channels=cfg["model"]["bifpn_channels"],
        bifpn_rounds=cfg["model"]["bifpn_rounds"],
    ).to(device)
    print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    start = time.time()

    # Phase 1: Frozen encoder (40 epochs — slightly shorter since more data)
    print("\n=== Phase 1: Frozen encoder (40 epochs) ===")
    model.freeze_encoder()
    opt = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=1e-3, weight_decay=1e-4,
    )
    sched = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=15, T_mult=2)
    _run_phase(model, loader, opt, device, range(1, 41), start, scheduler=sched)

    torch.save({"model_state_dict": model.state_dict(), "epoch": 40},
               out_dir / "phase1.pth")

    # Phase 2: Unfreeze deep layers (40 epochs); no scheduler is used here,
    # so each parameter group keeps its learning rate for the whole phase.
    print("\n=== Phase 2: Unfreeze layer3+4 (40 epochs) ===")
    model.unfreeze_deep_layers()
    opt = torch.optim.AdamW([
        {"params": model.layer3.parameters(), "lr": 1e-5},
        {"params": model.layer4.parameters(), "lr": 5e-5},
        {"params": model.bifpn.parameters(), "lr": 5e-4},
        {"params": model.upsample.parameters(), "lr": 5e-4},
        {"params": model.heatmap_head.parameters(), "lr": 5e-4},
        {"params": model.offset_head.parameters(), "lr": 5e-4},
    ], weight_decay=1e-4)
    _run_phase(model, loader, opt, device, range(41, 81), start)

    torch.save({"model_state_dict": model.state_dict(), "epoch": 80},
               out_dir / "phase2.pth")

    # Phase 3: Full fine-tune (60 epochs)
    print("\n=== Phase 3: Full fine-tune (60 epochs) ===")
    model.unfreeze_all()
    opt = torch.optim.AdamW([
        {"params": model.stem.parameters(), "lr": 1e-6},
        {"params": model.layer1.parameters(), "lr": 5e-6},
        {"params": model.layer2.parameters(), "lr": 1e-5},
        {"params": model.layer3.parameters(), "lr": 5e-5},
        {"params": model.layer4.parameters(), "lr": 1e-4},
        {"params": model.bifpn.parameters(), "lr": 2e-4},
        {"params": model.upsample.parameters(), "lr": 2e-4},
        {"params": model.heatmap_head.parameters(), "lr": 2e-4},
        {"params": model.offset_head.parameters(), "lr": 2e-4},
    ], weight_decay=1e-4)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=60, eta_min=1e-7)
    _run_phase(
        model, loader, opt, device, range(81, 141), start,
        scheduler=sched, checkpoint_dir=out_dir, checkpoint_prefix="phase3",
    )

    # Save final model (the only checkpoint that also embeds the config)
    torch.save({
        "model_state_dict": model.state_dict(),
        "epoch": 140,
        "config": cfg,
    }, out_dir / "final_model.pth")

    elapsed = time.time() - start
    print(f"\n=== Done: 140 epochs in {elapsed:.0f}s ({elapsed/60:.1f}min) ===")
    print(f"Final model: {out_dir / 'final_model.pth'}")
    print("\nTo detect particles in a new image:")
    print(f" python predict.py --image path/to/new_image.tif --checkpoint {out_dir / 'final_model.pth'}")


if __name__ == "__main__":
    main()