chenchangliu committed on
Commit
a38d768
·
verified ·
1 Parent(s): 3ece5b1

Upload code_EffNet_train_best.py

Browse files
Files changed (1) hide show
  1. code_EffNet_train_best.py +423 -0
code_EffNet_train_best.py ADDED
@@ -0,0 +1,423 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ from __future__ import annotations
3
+
4
+ # Two-head EfficientNet classifier (multi-task) with:
5
+ # - Dataset: reads image crops from DATA_ROOT/{train,val,test} and labels from sidecar .txt files
6
+ # (each .txt contains: "species_id state_id")
7
+ # - Augmentation (train only): resize to IMG_SIZE, random horizontal flip, random 0/90/180/270 rotation,
8
+ # mild ColorJitter (lighting/camera variation), and small translate/scale jitter via RandomAffine
9
+ # - Transfer learning: EfficientNet-B0 pretrained backbone shared by two classification heads
10
+ # (species head: NUM_SPECIES classes, state head: NUM_STATES classes)
11
+ # - Optimization: AdamW with separate learning rates for backbone vs heads (LR_BACKBONE, LR_HEADS)
12
+ # - Warm-up: freeze backbone for the first FREEZE_EPOCHS epochs, then unfreeze and fine-tune end-to-end
13
+ # - LR schedule: CosineAnnealingLR applied only after unfreezing (T_max = EPOCHS - FREEZE_EPOCHS)
14
+ # - Logging: W&B (self-hosted) logs per-head losses/accuracies, combined accuracy, and current LR values
15
+ # - Checkpointing: saves best.pt (by combined val accuracy = mean of two head accuracies) and last.pt
16
+
17
+ ### To prevent overfitting after 15/20 epochs:
18
+ # - Added: label smoothing to prevent overfitting: ce = nn.CrossEntropyLoss(label_smoothing=0.05)
19
+ # - RandomErasing applied AFTER normalization, because it expects a tensor
20
+ # - Increased dropout
21
+ # - Increased FREEZE_EPOCHS
22
+ # - Reduced color augmentation, use very small numbers
23
+ # - Reduced LR_HEADS
24
+ # - Try freezing batch norm
25
+
26
+ from pathlib import Path
27
+ import os
28
+
29
+ import torch
30
+ import torch.nn as nn
31
+ from torch.utils.data import Dataset, DataLoader
32
+ from torchvision.io import read_image
33
+ from torchvision import transforms
34
+ from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
35
+
36
+ import wandb
37
+
38
+
39
# ------------------ CONFIG ------------------
DATA_ROOT = Path("LTN_crop_twohead")  # root folder containing train/ val/ test/ splits
NUM_SPECIES = 12  # number of classes for the species head
NUM_STATES = 4    # number of classes for the state head

EPOCHS = 150  # total training epochs (warm-up + fine-tuning)
BATCH = 32    # mini-batch size for both train and val loaders

LR_BACKBONE = 3e-5  # small LR for the pretrained EfficientNet trunk
LR_HEADS = 2e-4     # slightly reduced to reduce overfitting
WEIGHT_DECAY = 1e-2  # AdamW weight decay, applied to all param groups

FREEZE_EPOCHS = 15  # freeze backbone initially; heads-only warm-up phase

IMG_SIZE = 224  # square input resolution fed to the network
WORKERS = 1     # DataLoader worker processes

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# --------------------------------------------
59
### Freeze batchNorm
def set_bn_eval(m: nn.Module) -> None:
    """Switch a BatchNorm module to eval mode so its running stats stay frozen.

    Intended for ``model.apply(set_bn_eval)`` right after ``model.train()``:
    BN weights keep training, but running mean/var stop updating. Generalized
    from BatchNorm2d to all BatchNorm variants (1d/2d/3d/Sync) so the freeze
    holds for any backbone; EfficientNet-B0 itself only contains BatchNorm2d,
    so behavior for this model is unchanged.

    Non-BatchNorm modules are left untouched.
    """
    if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)):
        m.eval()
###-----
64
+
65
class RandomRotate90:
    """Rotate a CHW tensor by a uniformly random multiple of 90 degrees."""

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        # Draw the number of quarter-turns uniformly from {0, 1, 2, 3}.
        quarter_turns = torch.randint(0, 4, (1,)).item()
        # dims=[1, 2] rotates the spatial plane (H, W) of a CHW tensor.
        return torch.rot90(x, quarter_turns, dims=[1, 2])
+
71
+
72
class TwoHeadCrops(Dataset):
    """
    Dataset of image crops with two labels per image (species id + state id).

    Expects layout:
        LTN_crop_twohead/{train,val,test}/.../*.jpg
        LTN_crop_twohead/{train,val,test}/.../*.txt   (contains: "species_id state_id")

    The folder name is ignored. Labels come from the .txt next to each image.
    Files under hidden directories, non-image files, and images without a
    sibling label file are all skipped.
    """
    def __init__(self, root: Path, split: str):
        # Collect image files that have a matching label sidecar.
        img_exts = {".jpg", ".jpeg", ".png"}
        paths = []

        for p in (root / split).rglob("*"):
            if p.is_dir():
                continue
            if p.suffix.lower() not in img_exts:
                continue
            # Skip anything inside a hidden path component (e.g. checkpoints).
            if any(part.startswith(".") for part in p.parts):
                continue
            # A crop is only usable if its "species_id state_id" sidecar exists.
            if not p.with_suffix(".txt").exists():
                continue
            paths.append(p)

        self.img_paths = sorted(paths)  # sorted for deterministic ordering
        if not self.img_paths:
            raise RuntimeError(f"No images found under: {root / split}")

        # Augmentation only for training split
        if split == "train":
            self.tfm = transforms.Compose([
                transforms.Resize((IMG_SIZE, IMG_SIZE), antialias=True),
                transforms.RandomHorizontalFlip(p=0.5),
                RandomRotate90(),  # random 0/90/180/270-degree rotation
                # Mild geometric jitter, applied to half the samples.
                transforms.RandomApply([
                    transforms.RandomAffine(
                        degrees=0,               # no arbitrary angle
                        translate=(0.02, 0.02),  # small shift
                        scale=(0.95, 1.05),      # small zoom
                        shear=None,
                    )
                ], p=0.5),

                # Very mild photometric jitter (lighting/camera variation).
                transforms.ColorJitter(
                    brightness=0.05,
                    contrast=0.02,
                    saturation=0.02,
                    hue=0.01,
                ),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
                # RandomErasing runs AFTER Normalize because it expects a
                # tensor. 20% of images affected; erased area is 1-5% of
                # the image, aspect ratio between 0.5 and 2.
                transforms.RandomErasing(
                    p=0.20,
                    scale=(0.01, 0.05),
                    ratio=(0.5, 2.0),
                    value=0
                ),
            ])
        else:
            # Validation/test: deterministic resize + normalize only.
            self.tfm = transforms.Compose([
                transforms.Resize((IMG_SIZE, IMG_SIZE), antialias=True),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
            ])

    def __len__(self) -> int:
        return len(self.img_paths)

    def __getitem__(self, i: int):
        p = self.img_paths[i]
        x = read_image(str(p)).float() / 255.0  # CHW in [0..1]
        x = self.tfm(x)

        # Sidecar label file: "species_id state_id" (two whitespace-separated ints).
        lab = p.with_suffix(".txt").read_text(encoding="utf-8").strip().split()
        species_id = int(lab[0])
        state_id = int(lab[1])

        return x, torch.tensor(species_id, dtype=torch.long), torch.tensor(state_id, dtype=torch.long)
149
+
150
+
151
class EffNetTwoHead(nn.Module):
    """Shared EfficientNet-B0 trunk feeding two independent linear heads.

    The species head predicts ``num_species`` classes and the state head
    ``num_states`` classes from the same pooled backbone embedding.
    """

    def __init__(self, num_species: int, num_states: int, pretrained: bool = True):
        super().__init__()
        weights = EfficientNet_B0_Weights.DEFAULT if pretrained else None
        backbone = efficientnet_b0(weights=weights)

        # Reuse the torchvision feature extractor and pooling stages as-is.
        self.features = backbone.features
        self.pool = backbone.avgpool

        feat_dim = backbone.classifier[1].in_features
        self.drop = nn.Dropout(0.3)  # regularization before both heads

        # One linear classifier per task, sharing the pooled embedding.
        self.head_species = nn.Linear(feat_dim, num_species)
        self.head_state = nn.Linear(feat_dim, num_states)

    def forward(self, x: torch.Tensor):
        """Return ``(species_logits, state_logits)`` for a batch of images."""
        embedding = self.drop(torch.flatten(self.pool(self.features(x)), 1))
        return self.head_species(embedding), self.head_state(embedding)
172
+
173
+
174
@torch.no_grad()
def eval_one_epoch(model: nn.Module, loader: DataLoader, ce):
    """
    Run one full evaluation pass over ``loader``.

    Returns a 5-tuple:
        (total_loss, species_loss, state_loss, species_acc, state_acc)
    where losses are per-sample means over the whole loader and accuracies
    are fractions in [0, 1]. Tensors are moved to the module-level DEVICE.
    """
    model.eval()

    sum_total = 0.0
    sum_species = 0.0
    sum_state = 0.0
    hits_species = 0
    hits_state = 0
    seen = 0

    for images, y_species, y_state in loader:
        images = images.to(DEVICE, non_blocking=True)
        y_species = y_species.to(DEVICE, non_blocking=True)
        y_state = y_state.to(DEVICE, non_blocking=True)

        logits_species, logits_state = model(images)
        l_species = ce(logits_species, y_species)
        l_state = ce(logits_state, y_state)

        # Weight every batch by its size so means are per-sample.
        batch = images.size(0)
        seen += batch
        sum_species += float(l_species.item()) * batch
        sum_state += float(l_state.item()) * batch
        sum_total += float((l_species + l_state).item()) * batch

        hits_species += int((logits_species.argmax(1) == y_species).sum().item())
        hits_state += int((logits_state.argmax(1) == y_state).sum().item())

    denom = max(1, seen)  # guard against an empty loader
    return (
        sum_total / denom,
        sum_species / denom,
        sum_state / denom,
        hits_species / denom,
        hits_state / denom,
    )
217
+
218
+
219
def main():
    """Train the two-head EfficientNet.

    Phases: heads-only warm-up with a frozen backbone for FREEZE_EPOCHS,
    then end-to-end fine-tuning with cosine LR decay (and frozen BatchNorm
    running stats). Logs per-head metrics to W&B and writes best.pt
    (selected by combined validation accuracy) and last.pt.
    """
    # W&B setup (self-hosted server); a pre-set WANDB_BASE_URL wins.
    os.environ.setdefault("WANDB_BASE_URL", "http://k8s.tu-ilmenau.de:31020")

    run = wandb.init(
        project="EffNetCls",
        entity="mase-students",
        config={
            "epochs": EPOCHS,
            "batch": BATCH,
            "lr_backbone": LR_BACKBONE,
            "lr_heads": LR_HEADS,
            "weight_decay": WEIGHT_DECAY,
            "freeze_epochs": FREEZE_EPOCHS,
            "img_size": IMG_SIZE,
        },
    )

    # Data: train split gets augmentation, val split only resize + normalize.
    train_ds = TwoHeadCrops(DATA_ROOT, "train")
    val_ds = TwoHeadCrops(DATA_ROOT, "val")

    train_loader = DataLoader(
        train_ds,
        batch_size=BATCH,
        shuffle=True,
        num_workers=WORKERS,
        pin_memory=True,
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=BATCH,
        shuffle=False,
        num_workers=WORKERS,
        pin_memory=True,
    )

    # Model
    model = EffNetTwoHead(NUM_SPECIES, NUM_STATES, pretrained=True).to(DEVICE)

    # Freeze backbone initially (heads learn first, then fine-tune backbone)
    for p in model.features.parameters():
        p.requires_grad = False

    # Optimizer with separate LR groups for backbone vs heads.
    # NOTE(review): the pool (avgpool) and drop (Dropout) groups presumably
    # contribute no trainable parameters — harmless but redundant; confirm.
    opt = torch.optim.AdamW(
        [
            {"params": model.features.parameters(), "lr": LR_BACKBONE},
            {"params": model.pool.parameters(), "lr": LR_BACKBONE},
            {"params": model.drop.parameters(), "lr": LR_BACKBONE},
            {"params": model.head_species.parameters(), "lr": LR_HEADS},
            {"params": model.head_state.parameters(), "lr": LR_HEADS},
        ],
        weight_decay=WEIGHT_DECAY,
    )

    # Cosine decay sized for the post-warm-up phase only (T_max =
    # EPOCHS - FREEZE_EPOCHS); it is stepped exclusively after the backbone
    # is unfrozen (see the training loop below).
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        opt,
        T_max=EPOCHS - FREEZE_EPOCHS,
        eta_min=1e-6,
    )

    ce_train = nn.CrossEntropyLoss(label_smoothing=0.05)  # apply label smoothing only for training
    ce_val = nn.CrossEntropyLoss()

    # Save best by combined accuracy; always save last at end
    best_acc = -1.0

    for epoch in range(1, EPOCHS + 1):

        # Unfreeze backbone after warm-up
        if epoch == FREEZE_EPOCHS + 1:
            for p in model.features.parameters():
                p.requires_grad = True
            print(f"[epoch {epoch:03d}] Backbone unfrozen")

        # ---- Train one epoch ----
        model.train()
        # Freeze BatchNorm running statistics after unfreezing the backbone
        # (BN weights still train; only running mean/var stop updating).
        if epoch > FREEZE_EPOCHS:
            model.apply(set_bn_eval)

        # Running sums weighted by batch size so the means are per-sample.
        loss_sum_total = 0.0
        loss_sum_sp = 0.0
        loss_sum_st = 0.0
        n = 0

        correct_sp = 0
        correct_st = 0

        for x, ysp, yst in train_loader:
            x = x.to(DEVICE, non_blocking=True)
            ysp = ysp.to(DEVICE, non_blocking=True)
            yst = yst.to(DEVICE, non_blocking=True)

            opt.zero_grad(set_to_none=True)
            lsp, lst = model(x)

            # per-head losses, summed with equal weight for the joint update
            loss_sp = ce_train(lsp, ysp)
            loss_st = ce_train(lst, yst)

            loss = loss_sp + loss_st

            loss.backward()
            opt.step()

            bs = x.size(0)
            loss_sum_total += float(loss.item()) * bs
            loss_sum_sp += float(loss_sp.item()) * bs
            loss_sum_st += float(loss_st.item()) * bs
            n += bs

            correct_sp += int((lsp.argmax(1) == ysp).sum().item())
            correct_st += int((lst.argmax(1) == yst).sum().item())

        train_loss = loss_sum_total / max(1, n)
        train_loss_sp = loss_sum_sp / max(1, n)
        train_loss_st = loss_sum_st / max(1, n)
        train_acc_sp = correct_sp / max(1, n)
        train_acc_st = correct_st / max(1, n)

        # ---- Validate ----
        val_loss, val_loss_sp, val_loss_st, val_acc_sp, val_acc_st = eval_one_epoch(model, val_loader, ce_val)

        # Step the schedule only after unfreeze; avoids "wasting" cosine
        # decay while the backbone is frozen.
        if epoch > FREEZE_EPOCHS:
            scheduler.step()

        # Read current LRs after scheduler.step()
        # (group 0 = backbone features, last group = state head).
        lr_backbone = opt.param_groups[0]["lr"]
        lr_heads = opt.param_groups[-1]["lr"]

        # Model-selection metric: unweighted mean of the two head accuracies.
        combined_acc = 0.5 * (val_acc_sp + val_acc_st)

        # ---- Print per-epoch summary ----
        print(
            f"epoch {epoch:03d} | "
            f"train_loss={train_loss:.4f} (sp={train_loss_sp:.4f}, st={train_loss_st:.4f}) | "
            f"train_acc_sp={train_acc_sp:.3f} | train_acc_st={train_acc_st:.3f} | "
            f"val_loss={val_loss:.4f} (sp={val_loss_sp:.4f}, st={val_loss_st:.4f}) | "
            f"val_acc_sp={val_acc_sp:.3f} | val_acc_st={val_acc_st:.3f} | "
            f"val_acc_combined={combined_acc:.3f} | "
            f"lr_backbone={lr_backbone:.6f} | lr_heads={lr_heads:.6f}"
        )

        # ---- W&B logging ----
        wandb.log({
            "epoch": epoch,

            "train/loss_total": train_loss,
            "train/loss_species": train_loss_sp,
            "train/loss_state": train_loss_st,
            "train/acc_species": train_acc_sp,
            "train/acc_state": train_acc_st,

            "val/loss_total": val_loss,
            "val/loss_species": val_loss_sp,
            "val/loss_state": val_loss_st,
            "val/acc_species": val_acc_sp,
            "val/acc_state": val_acc_st,
            "val/acc_combined": combined_acc,

            "lr/backbone": lr_backbone,
            "lr/heads": lr_heads,
        })

        # ---- Save best checkpoint by combined accuracy ----
        if combined_acc > best_acc:
            best_acc = combined_acc
            torch.save(
                {
                    "model": model.state_dict(),
                    "epoch": epoch,
                    "best_acc": best_acc,
                    "val_acc_species": val_acc_sp,
                    "val_acc_state": val_acc_st,
                    "val_acc_combined": combined_acc,
                    "num_species": NUM_SPECIES,
                    "num_states": NUM_STATES,
                    "img_size": IMG_SIZE,
                },
                "best.pt",
            )

    # Always save last checkpoint (records the configured final epoch)
    torch.save(
        {
            "model": model.state_dict(),
            "epoch": EPOCHS,
            "num_species": NUM_SPECIES,
            "num_states": NUM_STATES,
            "img_size": IMG_SIZE,
        },
        "last.pt",
    )

    run.finish()
    print("Done. Saved best.pt and last.pt")


if __name__ == "__main__":
    main()