chenchangliu committed on
Commit
0f3a64e
·
verified ·
1 Parent(s): a38d768

Delete code_EffNet_train_best.py

Browse files
Files changed (1) hide show
  1. code_EffNet_train_best.py +0 -423
code_EffNet_train_best.py DELETED
@@ -1,423 +0,0 @@
1
- #!/usr/bin/env python3
2
- from __future__ import annotations
3
-
4
- # Two-head EfficientNet classifier (multi-task) with:
5
- # - Dataset: reads image crops from DATA_ROOT/{train,val,test} and labels from sidecar .txt files
6
- # (each .txt contains: "species_id state_id")
7
- # - Augmentation (train only): resize to IMG_SIZE, random horizontal flip, random 0/90/180/270 rotation,
8
- # mild ColorJitter (lighting/camera variation), and small translate/scale jitter via RandomAffine
9
- # - Transfer learning: EfficientNet-B0 pretrained backbone shared by two classification heads
10
- # (species head: NUM_SPECIES classes, state head: NUM_STATES classes)
11
- # - Optimization: AdamW with separate learning rates for backbone vs heads (LR_BACKBONE, LR_HEADS)
12
- # - Warm-up: freeze backbone for the first FREEZE_EPOCHS epochs, then unfreeze and fine-tune end-to-end
13
- # - LR schedule: CosineAnnealingLR applied only after unfreezing (T_max = EPOCHS - FREEZE_EPOCHS)
14
- # - Logging: W&B (self-hosted) logs per-head losses/accuracies, combined accuracy, and current LR values
15
- # - Checkpointing: saves best.pt (by combined val accuracy = mean of two head accuracies) and last.pt
16
-
17
- ### To prevent overfitting after 15/20 epochs:
18
- # - Added: label smoothing to prevent overfitting: ce = nn.CrossEntropyLoss(label_smoothing=0.05)
19
- # - RandomErasing applied AFTER normalization, because it expects a tensor
20
- # - Increased dropout
21
- # - Increased FREEZE_EPOCHS
22
- # - Reduced color augmentation, use very small numbers
23
- # - Reduced LR_HEADS
24
- # - Try freezing batch norm
25
-
26
from pathlib import Path
import os

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.io import read_image, ImageReadMode
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights

import wandb
37
-
38
-
39
# ------------------ CONFIG ------------------
DATA_ROOT = Path("LTN_crop_twohead")  # dataset root holding the train/val/test splits
NUM_SPECIES = 12   # number of classes for the species head
NUM_STATES = 4     # number of classes for the state head

EPOCHS = 150       # total epochs, including the frozen-backbone warm-up
BATCH = 32         # mini-batch size for both train and val loaders

LR_BACKBONE = 3e-5  # fine-tuning LR for the pretrained feature extractor
LR_HEADS = 2e-4  # slightly reduced to reduce overfitting
WEIGHT_DECAY = 1e-2  # AdamW decoupled weight decay (all param groups)

FREEZE_EPOCHS = 15  # freeze backbone initially

IMG_SIZE = 224     # EfficientNet-B0 native input resolution
WORKERS = 1        # DataLoader worker processes

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# --------------------------------------------
58
-
59
- ### Freeze batchNorm
60
def set_bn_eval(m):
    """Switch a BatchNorm2d module to eval mode, freezing its running statistics.

    Intended for use with ``model.apply(set_bn_eval)``; non-BN modules are untouched.
    """
    if isinstance(m, nn.BatchNorm2d):
        m.eval()
63
- ###-----
64
-
65
class RandomRotate90:
    """Rotate a CHW image tensor by a uniformly random multiple of 90 degrees (0/90/180/270)."""

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        quarter_turns = int(torch.randint(0, 4, (1,)))
        # rot90 over dims (1, 2) spins the spatial H/W plane of a CHW tensor.
        return torch.rot90(x, quarter_turns, dims=(1, 2))
70
-
71
-
72
class TwoHeadCrops(Dataset):
    """
    Image-crop dataset with two labels per image (species + state).

    Expects layout:
        LTN_crop_twohead/{train,val,test}/.../*.jpg
        LTN_crop_twohead/{train,val,test}/.../*.txt   (contains: "species_id state_id")

    The folder name is ignored. Labels come from the .txt next to each image.
    The training split is augmented; val/test only get resize + normalize.

    Raises:
        RuntimeError: if no labelled images are found under the split.
        ValueError: if a sidecar label file does not contain two fields.
    """

    def __init__(self, root: Path, split: str):
        img_exts = {".jpg", ".jpeg", ".png"}
        paths = []

        for p in (root / split).rglob("*"):
            if p.is_dir():
                continue
            if p.suffix.lower() not in img_exts:
                continue
            # Skip anything inside hidden files/directories (e.g. .DS_Store, .cache)
            if any(part.startswith(".") for part in p.parts):
                continue
            # Keep only images that have a sidecar label file next to them
            if not p.with_suffix(".txt").exists():
                continue
            paths.append(p)

        self.img_paths = sorted(paths)
        if not self.img_paths:
            raise RuntimeError(f"No images found under: {root / split}")

        # Augmentation only for the training split
        if split == "train":
            self.tfm = transforms.Compose([
                transforms.Resize((IMG_SIZE, IMG_SIZE), antialias=True),
                transforms.RandomHorizontalFlip(p=0.5),
                RandomRotate90(),  # random 0/90/180/270 rotation
                transforms.RandomApply([
                    transforms.RandomAffine(
                        degrees=0,               # no arbitrary angle
                        translate=(0.02, 0.02),  # small shift
                        scale=(0.95, 1.05),      # small zoom
                        shear=None,
                    )
                ], p=0.5),

                # Mild photometric jitter (lighting / camera variation)
                transforms.ColorJitter(
                    brightness=0.05,
                    contrast=0.02,
                    saturation=0.02,
                    hue=0.01,
                ),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
                # RandomErasing runs AFTER Normalize because it expects a tensor;
                # 20% of images affected, erased area 1-5% of the image.
                transforms.RandomErasing(
                    p=0.20,
                    scale=(0.01, 0.05),
                    ratio=(0.5, 2.0),
                    value=0,
                ),
            ])
        else:
            self.tfm = transforms.Compose([
                transforms.Resize((IMG_SIZE, IMG_SIZE), antialias=True),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
            ])

    def __len__(self) -> int:
        return len(self.img_paths)

    def __getitem__(self, i: int):
        p = self.img_paths[i]
        # Force 3-channel RGB decoding: grayscale / RGBA / palette PNGs would
        # otherwise break the 3-channel Normalize above.
        x = read_image(str(p), mode=ImageReadMode.RGB).float() / 255.0  # CHW in [0..1]
        x = self.tfm(x)

        lab = p.with_suffix(".txt").read_text(encoding="utf-8").strip().split()
        if len(lab) < 2:
            raise ValueError(
                f"Malformed label file (expected 'species_id state_id'): {p.with_suffix('.txt')}"
            )
        species_id = int(lab[0])
        state_id = int(lab[1])

        return x, torch.tensor(species_id, dtype=torch.long), torch.tensor(state_id, dtype=torch.long)
149
-
150
-
151
class EffNetTwoHead(nn.Module):
    """Shared EfficientNet-B0 feature extractor feeding two linear classifier heads."""

    def __init__(self, num_species: int, num_states: int, pretrained: bool = True):
        super().__init__()
        weights = EfficientNet_B0_Weights.DEFAULT if pretrained else None
        backbone = efficientnet_b0(weights=weights)

        # Reuse the backbone's conv stack and pooling; its original classifier
        # is replaced by two task-specific heads behind a shared dropout.
        self.features = backbone.features
        self.pool = backbone.avgpool
        self.drop = nn.Dropout(0.3)

        feat_dim = backbone.classifier[1].in_features
        self.head_species = nn.Linear(feat_dim, num_species)
        self.head_state = nn.Linear(feat_dim, num_states)

    def forward(self, x: torch.Tensor):
        """Return (species_logits, state_logits) for a batch of images."""
        pooled = self.pool(self.features(x))
        shared = self.drop(torch.flatten(pooled, 1))
        return self.head_species(shared), self.head_state(shared)
172
-
173
-
174
@torch.no_grad()
def eval_one_epoch(model: nn.Module, loader: DataLoader, ce, device=None):
    """
    Run one evaluation pass over `loader` and report per-sample means.

    Args:
        model: two-head model returning (species_logits, state_logits).
        loader: yields (image_batch, species_targets, state_targets).
        ce: CrossEntropy-style loss taking (logits, targets).
        device: optional device override; defaults to the module-level DEVICE
            (the default keeps prior call sites unchanged).

    Returns:
        (total_loss, species_loss, state_loss, species_acc, state_acc),
        each averaged per sample over the whole loader.
    """
    if device is None:
        device = DEVICE
    model.eval()

    loss_sum_total = 0.0
    loss_sum_sp = 0.0
    loss_sum_st = 0.0
    n = 0

    correct_sp = 0
    correct_st = 0

    for x, ysp, yst in loader:
        x = x.to(device, non_blocking=True)
        ysp = ysp.to(device, non_blocking=True)
        yst = yst.to(device, non_blocking=True)

        lsp, lst = model(x)
        loss_sp = ce(lsp, ysp)
        loss_st = ce(lst, yst)
        loss = loss_sp + loss_st

        # Weight batch statistics by batch size so the final averages are exact
        # per-sample means even when the last batch is smaller.
        bs = x.size(0)
        loss_sum_total += float(loss.item()) * bs
        loss_sum_sp += float(loss_sp.item()) * bs
        loss_sum_st += float(loss_st.item()) * bs
        n += bs

        correct_sp += int((lsp.argmax(1) == ysp).sum().item())
        correct_st += int((lst.argmax(1) == yst).sum().item())

    # max(1, n) guards against an empty loader (avoids ZeroDivisionError)
    val_loss = loss_sum_total / max(1, n)
    val_loss_sp = loss_sum_sp / max(1, n)
    val_loss_st = loss_sum_st / max(1, n)
    val_acc_sp = correct_sp / max(1, n)
    val_acc_st = correct_st / max(1, n)

    return val_loss, val_loss_sp, val_loss_st, val_acc_sp, val_acc_st
217
-
218
-
219
def main():
    """Train the two-head EfficientNet: frozen-backbone warm-up, then end-to-end fine-tuning.

    Logs per-head metrics to W&B, saves best.pt (by combined val accuracy) and last.pt.
    """
    # W&B setup (self-hosted server); env var is only set if not already present
    os.environ.setdefault("WANDB_BASE_URL", "http://k8s.tu-ilmenau.de:31020")

    run = wandb.init(
        project="EffNetCls",
        entity="mase-students",
        config={
            "epochs": EPOCHS,
            "batch": BATCH,
            "lr_backbone": LR_BACKBONE,
            "lr_heads": LR_HEADS,
            "weight_decay": WEIGHT_DECAY,
            "freeze_epochs": FREEZE_EPOCHS,
            "img_size": IMG_SIZE,
        },
    )

    # Data
    train_ds = TwoHeadCrops(DATA_ROOT, "train")
    val_ds = TwoHeadCrops(DATA_ROOT, "val")

    train_loader = DataLoader(
        train_ds,
        batch_size=BATCH,
        shuffle=True,
        num_workers=WORKERS,
        pin_memory=True,
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=BATCH,
        shuffle=False,
        num_workers=WORKERS,
        pin_memory=True,
    )

    # Model
    model = EffNetTwoHead(NUM_SPECIES, NUM_STATES, pretrained=True).to(DEVICE)

    # Freeze backbone initially (heads learn first, then fine-tune backbone)
    for p in model.features.parameters():
        p.requires_grad = False

    # Optimizer with separate LR groups. NOTE(review): the pool and drop groups
    # contain no trainable parameters (AvgPool/Dropout), so they are harmless no-ops.
    opt = torch.optim.AdamW(
        [
            {"params": model.features.parameters(), "lr": LR_BACKBONE},
            {"params": model.pool.parameters(), "lr": LR_BACKBONE},
            {"params": model.drop.parameters(), "lr": LR_BACKBONE},
            {"params": model.head_species.parameters(), "lr": LR_HEADS},
            {"params": model.head_state.parameters(), "lr": LR_HEADS},
        ],
        weight_decay=WEIGHT_DECAY,
    )

    # Cosine decay over the post-warm-up phase only: the scheduler is stepped
    # exclusively once the backbone is unfrozen (see the epoch loop), so T_max
    # covers the remaining EPOCHS - FREEZE_EPOCHS steps down to eta_min.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        opt,
        T_max=EPOCHS - FREEZE_EPOCHS,
        eta_min=1e-6,
    )

    ce_train = nn.CrossEntropyLoss(label_smoothing=0.05)  # apply label smoothing only for training
    ce_val = nn.CrossEntropyLoss()

    # Save best by combined accuracy; always save last at end
    best_acc = -1.0

    for epoch in range(1, EPOCHS + 1):

        # Unfreeze backbone after warm-up
        if epoch == FREEZE_EPOCHS + 1:
            for p in model.features.parameters():
                p.requires_grad = True
            print(f"[epoch {epoch:03d}] Backbone unfrozen")

        # ---- Train one epoch ----
        model.train()
        # Freeze BatchNorm running statistics after unfreezing the backbone.
        # model.train() above re-enables BN updates, so this must be re-applied
        # every epoch of the fine-tuning phase.
        if epoch > FREEZE_EPOCHS:
            model.apply(set_bn_eval)

        # Per-sample-weighted running sums (exact means even with a short last batch)
        loss_sum_total = 0.0
        loss_sum_sp = 0.0
        loss_sum_st = 0.0
        n = 0

        correct_sp = 0
        correct_st = 0

        for x, ysp, yst in train_loader:
            x = x.to(DEVICE, non_blocking=True)
            ysp = ysp.to(DEVICE, non_blocking=True)
            yst = yst.to(DEVICE, non_blocking=True)

            opt.zero_grad(set_to_none=True)
            lsp, lst = model(x)

            # per-head losses, summed with equal weight for the joint update
            loss_sp = ce_train(lsp, ysp)
            loss_st = ce_train(lst, yst)

            loss = loss_sp + loss_st

            loss.backward()
            opt.step()

            bs = x.size(0)
            loss_sum_total += float(loss.item()) * bs
            loss_sum_sp += float(loss_sp.item()) * bs
            loss_sum_st += float(loss_st.item()) * bs
            n += bs

            correct_sp += int((lsp.argmax(1) == ysp).sum().item())
            correct_st += int((lst.argmax(1) == yst).sum().item())

        train_loss = loss_sum_total / max(1, n)
        train_loss_sp = loss_sum_sp / max(1, n)
        train_loss_st = loss_sum_st / max(1, n)
        train_acc_sp = correct_sp / max(1, n)
        train_acc_st = correct_st / max(1, n)

        # ---- Validate ----
        val_loss, val_loss_sp, val_loss_st, val_acc_sp, val_acc_st = eval_one_epoch(model, val_loader, ce_val)

        # Step the cosine schedule only after unfreeze; avoids "wasting" the
        # decay while the backbone is still frozen.
        if epoch > FREEZE_EPOCHS:
            scheduler.step()

        # Read current LRs after scheduler.step()
        # (group 0 = backbone, last group = state head; both head groups share LR_HEADS)
        lr_backbone = opt.param_groups[0]["lr"]
        lr_heads = opt.param_groups[-1]["lr"]

        # Checkpoint-selection metric: mean of the two head accuracies
        combined_acc = 0.5 * (val_acc_sp + val_acc_st)

        # ---- Print per-epoch summary ----
        print(
            f"epoch {epoch:03d} | "
            f"train_loss={train_loss:.4f} (sp={train_loss_sp:.4f}, st={train_loss_st:.4f}) | "
            f"train_acc_sp={train_acc_sp:.3f} | train_acc_st={train_acc_st:.3f} | "
            f"val_loss={val_loss:.4f} (sp={val_loss_sp:.4f}, st={val_loss_st:.4f}) | "
            f"val_acc_sp={val_acc_sp:.3f} | val_acc_st={val_acc_st:.3f} | "
            f"val_acc_combined={combined_acc:.3f} | "
            f"lr_backbone={lr_backbone:.6f} | lr_heads={lr_heads:.6f}"
        )

        # ---- W&B logging ----
        wandb.log({
            "epoch": epoch,

            "train/loss_total": train_loss,
            "train/loss_species": train_loss_sp,
            "train/loss_state": train_loss_st,
            "train/acc_species": train_acc_sp,
            "train/acc_state": train_acc_st,

            "val/loss_total": val_loss,
            "val/loss_species": val_loss_sp,
            "val/loss_state": val_loss_st,
            "val/acc_species": val_acc_sp,
            "val/acc_state": val_acc_st,
            "val/acc_combined": combined_acc,

            "lr/backbone": lr_backbone,
            "lr/heads": lr_heads,
        })

        # ---- Save best checkpoint by combined accuracy ----
        if combined_acc > best_acc:
            best_acc = combined_acc
            torch.save(
                {
                    "model": model.state_dict(),
                    "epoch": epoch,
                    "best_acc": best_acc,
                    "val_acc_species": val_acc_sp,
                    "val_acc_state": val_acc_st,
                    "val_acc_combined": combined_acc,
                    "num_species": NUM_SPECIES,
                    "num_states": NUM_STATES,
                    "img_size": IMG_SIZE,
                },
                "best.pt",
            )

    # Always save last checkpoint
    torch.save(
        {
            "model": model.state_dict(),
            "epoch": EPOCHS,
            "num_species": NUM_SPECIES,
            "num_states": NUM_STATES,
            "img_size": IMG_SIZE,
        },
        "last.pt",
    )

    run.finish()
    print("Done. Saved best.pt and last.pt")
420
-
421
-
422
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()