File size: 18,236 Bytes
02db62d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa55778
02db62d
 
 
fa55778
 
 
 
 
 
 
 
 
 
02db62d
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
"""
2-phase training pipeline for dog breed classification.

v3 recipe:
- Phase 1: Frozen backbone, train head only, OneCycleLR warmup
- Phase 2: Unfreeze backbone, differential LR (backbone 0.01×), CosineAnnealingLR
- ArcFace angular margin loss (optional, default on)
- Progressive resizing 224→336 mid-training
- Label smoothing, MixUp/CutMix at batch level
- MPS-optimized for M4 Max
"""

import os
import json
import time

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np

from .registry import get_backbone, list_backbones
from .heads.mlp_head import MLPHead
from .losses import ArcFaceHead
from .augmentations import mixup_data, cutmix_data, mixup_criterion


# Number of dog-breed classes. Default output width of BreedClassifier and the
# value stored under "num_classes" in saved checkpoints.
NUM_CLASSES = 120


def get_device():
    """Pick the torch device to train on.

    Preference order is Apple MPS, then CUDA, then CPU. Availability probes
    are called lazily, so CUDA is only queried when MPS is unavailable.
    """
    probes = (
        ("mps", torch.backends.mps.is_available),
        ("cuda", torch.cuda.is_available),
    )
    for device_type, is_available in probes:
        if is_available():
            return torch.device(device_type)
    return torch.device("cpu")


class BreedClassifier(nn.Module):
    """Registry backbone followed by a classification head.

    With ``use_arcface=True`` the head is an ``ArcFaceHead`` (angular margin);
    otherwise it is an ``MLPHead`` producing plain logits for cross-entropy.
    """

    def __init__(
        self,
        backbone_name: str,
        num_classes: int = NUM_CLASSES,
        pretrained: bool = True,
        use_arcface: bool = False,
        arcface_scale: float = 30.0,
        arcface_margin: float = 0.3,
    ):
        super().__init__()
        self.backbone = get_backbone(backbone_name, pretrained=pretrained)
        self.backbone_name = backbone_name
        self.use_arcface = use_arcface

        feature_dim = self.backbone.embed_dim
        self.head = (
            ArcFaceHead(
                embed_dim=feature_dim,
                num_classes=num_classes,
                scale=arcface_scale,
                margin=arcface_margin,
            )
            if use_arcface
            else MLPHead(feature_dim, num_classes)
        )

    def forward(self, x, labels=None):
        """Embed ``x`` with the backbone and apply the head.

        ``labels`` is forwarded only to an ArcFace head; the MLP head ignores it.
        """
        embeddings = self.backbone(x)
        if not self.use_arcface:
            return self.head(embeddings)
        return self.head(embeddings, labels)

    def freeze_backbone(self):
        """Delegate to ``backbone.freeze()``."""
        self.backbone.freeze()

    def unfreeze_backbone(self, **kwargs):
        """Delegate to ``backbone.unfreeze(**kwargs)``."""
        self.backbone.unfreeze(**kwargs)

    def get_param_groups(self, lr: float, backbone_lr_mult: float = 0.1) -> list[dict]:
        """Return optimizer param groups: the backbone's groups, then the head at ``lr``."""
        param_groups = self.backbone.get_param_groups(lr, backbone_lr_mult)
        head_group = {"params": list(self.head.parameters()), "lr": lr}
        param_groups.append(head_group)
        return param_groups

    def get_preprocess_config(self) -> dict:
        """Return the backbone's preprocessing configuration."""
        return self.backbone.get_preprocess_config()


def _arcface_loss_and_logits(model, images, labels):
    """Return ``(margin_loss, logits)`` for one ArcFace batch; logits carry no grad.

    For a ``BreedClassifier`` the backbone runs once. Its features feed
    ``head(features, labels)`` (the training loss) and, under ``no_grad``,
    ``head(features)`` (logits for accuracy), mirroring
    ``BreedClassifier.forward``. A second full forward in train mode would
    double backbone compute and update any running statistics (e.g.
    BatchNorm) twice per batch. Other models fall back to two full calls.
    """
    if not isinstance(model, BreedClassifier):
        loss = model(images, labels)
        with torch.no_grad():
            logits = model(images)  # labels=None -> logits
        return loss, logits

    features = model.backbone(images)
    loss = model.head(features, labels)
    with torch.no_grad():
        logits = model.head(features)  # labels=None -> logits
    return loss, logits


def train_one_epoch(
    model: nn.Module,
    loader: DataLoader,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    device: torch.device,
    scheduler=None,
    mixup_alpha: float = 0.2,
    cutmix_alpha: float = 1.0,
    mix_prob: float = 0.5,
) -> tuple[float, float]:
    """Train ``model`` for one pass over ``loader``.

    Two paths:

    * ArcFace models (``model.use_arcface`` truthy): the model/head returns
      the margin loss given labels. No MixUp/CutMix, because the margin loss
      needs clean labels.
    * Otherwise: ``criterion`` on ``model(images)``. With probability
      ``mix_prob`` a batch is mixed: MixUp for the first half of that
      probability mass, CutMix for the second half.

    Gradients are clipped to a global norm of 1.0. If ``scheduler`` is given,
    it is stepped once per batch after ``optimizer.step()``.

    On mixed batches, accuracy is measured against the original labels, so it
    is only an approximate signal while MixUp/CutMix is active.

    Returns:
        ``(mean_loss_per_sample, top1_accuracy_percent)``.

    Raises:
        ValueError: if ``loader`` yields no batches (e.g. ``drop_last=True``
            with a dataset smaller than one batch).
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0
    is_arcface = getattr(model, "use_arcface", False)

    pbar = tqdm(loader, desc="Train", leave=False)
    for images, labels in pbar:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        if is_arcface:
            loss, logits = _arcface_loss_and_logits(model, images, labels)
        else:
            # Standard CE path with MixUp/CutMix
            r = np.random.random()
            if r < mix_prob / 2:
                images, targets_a, targets_b, lam = mixup_data(images, labels, mixup_alpha)
                use_mix = True
            elif r < mix_prob:
                images, targets_a, targets_b, lam = cutmix_data(images, labels, cutmix_alpha)
                use_mix = True
            else:
                use_mix = False

            logits = model(images)

            if use_mix:
                loss = mixup_criterion(criterion, logits, targets_a, targets_b, lam)
            else:
                loss = criterion(logits, labels)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        if scheduler is not None:
            scheduler.step()

        # Read the scalar once: each .item() is a host/device sync.
        loss_value = loss.item()
        total_loss += loss_value * images.size(0)
        _, predicted = logits.max(1)
        correct += predicted.eq(labels).sum().item()
        total += labels.size(0)

        pbar.set_postfix(
            loss=f"{loss_value:.3f}",
            acc=f"{100.*correct/total:.1f}%",
            lr=f"{optimizer.param_groups[-1]['lr']:.1e}",
        )

    if total == 0:
        raise ValueError("train_one_epoch: loader yielded no batches")
    return total_loss / total, 100.0 * correct / total


@torch.no_grad()
def evaluate(
    model: nn.Module,
    loader: DataLoader,
    criterion: nn.Module,
    device: torch.device,
) -> dict:
    """Evaluate ``model`` on ``loader`` without gradients.

    The model is called without labels (for ArcFace models this yields
    logits), and ``criterion`` is applied to those outputs.

    Returns:
        ``{"loss": mean per-sample loss, "top1_acc": %, "top5_acc": %}``.
        Top-5 uses ``min(5, num_classes)`` so heads with fewer than five
        classes do not crash ``topk``.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    correct_top5 = 0
    total = 0

    for images, labels in tqdm(loader, desc="Eval", leave=False):
        images = images.to(device)
        labels = labels.to(device)

        # Called without labels: ArcFace models return logits here.
        outputs = model(images)
        loss = criterion(outputs, labels)

        total_loss += loss.item() * images.size(0)
        _, predicted = outputs.max(1)
        correct += predicted.eq(labels).sum().item()

        # topk raises if k exceeds the class dimension.
        k = min(5, outputs.size(1))
        _, topk_pred = outputs.topk(k, 1, True, True)
        # (B, k) vs (B, 1) broadcasts; at most one match per row.
        correct_top5 += topk_pred.eq(labels.view(-1, 1)).sum().item()
        total += labels.size(0)

    if total == 0:
        raise ValueError("evaluate: loader yielded no batches")
    return {
        "loss": total_loss / total,
        "top1_acc": 100.0 * correct / total,
        "top5_acc": 100.0 * correct_top5 / total,
    }


def _build_loaders(data_dir, preprocess_config, batch_size, img_size_override=None, num_workers=4):
    """Build train/val/test dataloaders, optionally overriding image size.

    Used by progressive resizing to rebuild loaders at a new resolution.

    ``data_dir`` must contain a ``train`` subdirectory; ``val`` and ``test``
    are optional. Each is read with torchvision ``ImageFolder``. The train
    loader shuffles and drops the last partial batch; val/test do neither.

    Args:
        data_dir: root containing the split subdirectories.
        preprocess_config: passed to ``get_transforms``; ``input_size`` is
            replaced when ``img_size_override`` is given.
        batch_size: batch size for all three loaders.
        img_size_override: optional new ``input_size``.
        num_workers: DataLoader worker processes (default 4, as before).
            Persistent workers are enabled only when this is > 0, since
            DataLoader rejects ``persistent_workers=True`` with zero workers.

    Returns:
        ``(train_loader, val_loader, test_loader)``. A split whose directory
        is missing yields ``None``.
    """
    from .dataset import get_transforms
    from torchvision import datasets

    if img_size_override is not None:
        preprocess_config = {**preprocess_config, "input_size": img_size_override}

    train_transform = get_transforms(preprocess_config, is_train=True)
    val_transform = get_transforms(preprocess_config, is_train=False)

    loader_kwargs = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "pin_memory": True,
        "persistent_workers": num_workers > 0,
    }

    def _make_loader(split_dir, transform, is_train):
        # One ImageFolder-backed loader; only the train split shuffles/drops.
        dataset = datasets.ImageFolder(split_dir, transform=transform)
        return DataLoader(dataset, shuffle=is_train, drop_last=is_train, **loader_kwargs)

    train_loader = _make_loader(os.path.join(data_dir, "train"), train_transform, True)

    val_dir = os.path.join(data_dir, "val")
    val_loader = _make_loader(val_dir, val_transform, False) if os.path.isdir(val_dir) else None

    test_dir = os.path.join(data_dir, "test")
    test_loader = _make_loader(test_dir, val_transform, False) if os.path.isdir(test_dir) else None

    return train_loader, val_loader, test_loader


def train_model(
    backbone_name: str,
    train_loader: DataLoader,
    val_loader: DataLoader,
    epochs: int = 50,
    warmup_epochs: int = 2,
    lr: float = 1e-3,
    backbone_lr_mult: float = 0.01,  # 1/100th — Codex+Gemini recommendation
    label_smoothing: float = 0.1,
    mixup_alpha: float = 0.8,
    cutmix_alpha: float = 1.0,
    mix_prob: float = 0.5,
    no_aug_final_epochs: int = 5,  # Turn off MixUp/CutMix for last N epochs
    unfreeze_warmup_epochs: int = 3,  # Linear warmup after unfreeze
    early_stop_patience: int = 10,
    output_dir: str = "models",
    time_limit_minutes: float = 180.0,
    # ArcFace settings
    use_arcface: bool = True,
    arcface_scale: float = 30.0,
    arcface_margin: float = 0.3,
    # Progressive resizing: switch to this resolution at resize_at_epoch
    prog_resize_to: int = None,  # e.g. 336
    prog_resize_at_epoch: int = None,  # e.g. 15 (absolute epoch number)
    prog_resize_batch_size: int = None,  # reduced batch for higher res
    data_dir: str = None,  # needed for progressive resizing to rebuild loaders
    # Keep test_loader for backward compat but don't use for selection
    test_loader: DataLoader = None,
) -> dict:
    """Train with v3 recipe (ArcFace + progressive resizing).

    Key improvements over v2:
    - ArcFace angular margin loss for fine-grained discrimination
    - Progressive resizing: start at 224, bump to 336 mid-training
    - More epochs (50) and patience (10) for thorough convergence

    Schedule:
    - Phase 1 (``warmup_epochs``): backbone frozen. Head trained with AdamW
      under one OneCycleLR schedule spanning all Phase 1 epochs, no mixing.
    - Phase 2 (``epochs - warmup_epochs``): backbone unfrozen. Per-group LRs
      come from ``model.get_param_groups`` (backbone scaled by
      ``backbone_lr_mult``), with AdamW and CosineAnnealingLR stepped once per
      epoch. The first ``unfreeze_warmup_epochs`` ramp the LR linearly and
      skip MixUp/CutMix. The last ``no_aug_final_epochs`` also skip mixing.
      ArcFace never mixes.
    - When ``prog_resize_to``, ``prog_resize_at_epoch`` and ``data_dir`` are
      set, loaders are rebuilt at the new resolution at the start of that
      absolute (1-based) epoch, and patience resets. Whenever
      ``prog_resize_to`` and ``prog_resize_at_epoch`` are set, early stopping
      is deferred until that epoch.

    Phase 2 stops after ``early_stop_patience`` epochs without a new best val
    top-1, or once ``time_limit_minutes`` have elapsed (checked at the start
    of each Phase 2 epoch).

    Returns:
        ``{"backbone", "best_top1", "history"}``. The best weights by val
        top-1 go to ``{output_dir}/{backbone_name}_best.pt``; the per-epoch
        history goes to ``{output_dir}/{backbone_name}_history.json``.
    """
    # Track whether validation is really the test split, so progressive
    # resizing can rebuild it at the new resolution as well.
    val_is_test = False
    if val_loader is None and test_loader is not None:
        print("  WARNING: Using test_loader for validation (no val_loader provided)")
        val_loader = test_loader
        val_is_test = True

    device = get_device()
    loss_type = "ArcFace" if use_arcface else "CE"
    print(f"\n{'='*60}")
    print(f"Training: {backbone_name} (v3 recipe — {loss_type})")
    print(f"Device: {device}")
    if backbone_lr_mult > 0:
        print(f"Backbone LR mult: {backbone_lr_mult} (1/{int(1/backbone_lr_mult)}th of head)")
    else:
        # Avoid ZeroDivisionError in the 1/mult ratio above.
        print(f"Backbone LR mult: {backbone_lr_mult}")
    if prog_resize_to:
        print(f"Progressive resize: 224 → {prog_resize_to} at epoch {prog_resize_at_epoch}")
    print(f"{'='*60}")

    model = BreedClassifier(
        backbone_name,
        use_arcface=use_arcface,
        arcface_scale=arcface_scale,
        arcface_margin=arcface_margin,
    )
    model = model.to(device)

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params:,}")

    # CE criterion used for eval (always) and for training if not ArcFace
    criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
    os.makedirs(output_dir, exist_ok=True)

    best_val_acc = 0.0
    patience_counter = 0
    history = []
    start_time = time.time()
    time_limit_sec = time_limit_minutes * 60
    current_img_size = 224

    # ─── Phase 1: Frozen backbone, train head only (no augmentation) ───
    print(f"\n[Phase 1] Frozen backbone — training head ({warmup_epochs} epochs)")
    model.freeze_backbone()
    head_params = [p for p in model.parameters() if p.requires_grad]
    warmup_optimizer = optim.AdamW(head_params, lr=lr, weight_decay=0.01)

    # A single OneCycle schedule over all Phase 1 steps. Re-creating it every
    # epoch (with epochs=1) restarted the warmup/anneal cycle each epoch.
    warmup_scheduler = None
    if warmup_epochs > 0:
        warmup_scheduler = optim.lr_scheduler.OneCycleLR(
            warmup_optimizer,
            max_lr=lr,
            steps_per_epoch=len(train_loader),
            epochs=warmup_epochs,
            pct_start=0.3,
        )

    for epoch in range(warmup_epochs):
        epoch_start = time.time()
        train_loss, train_acc = train_one_epoch(
            model, train_loader, criterion, warmup_optimizer, device,
            scheduler=warmup_scheduler, mixup_alpha=0, cutmix_alpha=0, mix_prob=0,
        )
        val_metrics = evaluate(model, val_loader, criterion, device)
        epoch_time = time.time() - epoch_start

        record = {
            "epoch": epoch + 1, "phase": 1, "img_size": current_img_size,
            "train_loss": train_loss, "train_acc": train_acc,
            **val_metrics, "epoch_time": epoch_time,
        }
        history.append(record)
        print(
            f"  Epoch {epoch+1}: Train {train_acc:.1f}% | "
            f"Val T1={val_metrics['top1_acc']:.1f}% T5={val_metrics['top5_acc']:.1f}% | "
            f"{epoch_time:.0f}s"
        )

        if val_metrics["top1_acc"] > best_val_acc:
            best_val_acc = val_metrics["top1_acc"]
            _save_checkpoint(model, backbone_name, epoch + 1, val_metrics, output_dir)

    # ─── Phase 2: Unfreeze backbone with careful LR recipe ───
    remaining = epochs - warmup_epochs
    print(f"\n[Phase 2] Unfrozen backbone — {remaining} epochs")
    print(f"  Unfreeze warmup: {unfreeze_warmup_epochs} epochs (no MixUp/CutMix)")
    print(f"  Final no-aug stage: last {no_aug_final_epochs} epochs")

    model.unfreeze_backbone()
    param_groups = model.get_param_groups(lr, backbone_lr_mult)
    # Remember each group's target LR; the linear warmup below scales from it.
    for pg in param_groups:
        pg['initial_lr'] = pg['lr']
    optimizer = optim.AdamW(param_groups, weight_decay=0.05)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=remaining, eta_min=1e-6)

    for epoch in range(remaining):
        elapsed = time.time() - start_time
        if elapsed > time_limit_sec:
            print(f"\nTime limit reached ({time_limit_minutes}min). Stopping.")
            break

        epoch_num = warmup_epochs + epoch + 1

        # ─── Progressive resizing: switch resolution mid-training ───
        if (prog_resize_to and prog_resize_at_epoch
                and epoch_num == prog_resize_at_epoch
                and data_dir is not None):
            print(f"\n  >>> PROGRESSIVE RESIZE: {current_img_size} → {prog_resize_to}px")
            current_img_size = prog_resize_to
            new_batch = prog_resize_batch_size or max(16, train_loader.batch_size // 2)
            preprocess_cfg = model.get_preprocess_config()
            train_loader, val_loader_new, test_loader_new = _build_loaders(
                data_dir, preprocess_cfg, new_batch, img_size_override=prog_resize_to,
            )
            if val_loader_new is not None:
                val_loader = val_loader_new
            elif val_is_test and test_loader_new is not None:
                # Validation is the test split: keep it at the new resolution too.
                val_loader = test_loader_new
            print(f"  >>> New batch size: {new_batch}, loader rebuilt\n")

            # Reset patience — resolution change means model needs time to adapt
            patience_counter = 0

        # MixUp/CutMix schedule (disabled for ArcFace regardless)
        if use_arcface or epoch < unfreeze_warmup_epochs:
            ep_mix_prob = 0
            ep_mixup = 0
            ep_cutmix = 0
            phase_label = "2a-warmup" if epoch < unfreeze_warmup_epochs else "2b-arcface"
        elif epoch >= remaining - no_aug_final_epochs:
            ep_mix_prob = 0
            ep_mixup = 0
            ep_cutmix = 0
            phase_label = "2c-refine"
        else:
            ep_mix_prob = mix_prob
            ep_mixup = mixup_alpha
            ep_cutmix = cutmix_alpha
            phase_label = "2b-train"

        # Linear LR warmup during unfreeze warmup phase (overrides the
        # cosine-scheduled LR for these epochs).
        if epoch < unfreeze_warmup_epochs:
            warmup_factor = (epoch + 1) / unfreeze_warmup_epochs
            for pg in optimizer.param_groups:
                pg['lr'] = pg['initial_lr'] * warmup_factor if 'initial_lr' in pg else pg['lr']

        epoch_start = time.time()

        train_loss, train_acc = train_one_epoch(
            model, train_loader, criterion, optimizer, device,
            scheduler=None,
            mixup_alpha=ep_mixup, cutmix_alpha=ep_cutmix, mix_prob=ep_mix_prob,
        )
        scheduler.step()  # per-epoch cosine step
        val_metrics = evaluate(model, val_loader, criterion, device)
        epoch_time = time.time() - epoch_start

        record = {
            "epoch": epoch_num, "phase": phase_label, "img_size": current_img_size,
            "train_loss": train_loss, "train_acc": train_acc,
            **val_metrics, "epoch_time": epoch_time,
        }
        history.append(record)

        improved = ""
        if val_metrics["top1_acc"] > best_val_acc:
            best_val_acc = val_metrics["top1_acc"]
            _save_checkpoint(model, backbone_name, epoch_num, val_metrics, output_dir)
            improved = " *NEW BEST*"
            patience_counter = 0
        else:
            patience_counter += 1

        # Group 0 comes from the backbone, the last group is the head
        # (see BreedClassifier.get_param_groups).
        bb_lr = optimizer.param_groups[0]['lr']
        head_lr = optimizer.param_groups[-1]['lr']
        print(
            f"  E{epoch_num} [{phase_label}] {current_img_size}px: Train {train_acc:.1f}% | "
            f"Val T1={val_metrics['top1_acc']:.1f}% T5={val_metrics['top5_acc']:.1f}% | "
            f"LR bb={bb_lr:.1e} head={head_lr:.1e} | {epoch_time:.0f}s{improved}"
        )

        # Don't early stop before progressive resize kicks in
        resize_pending = (prog_resize_to and prog_resize_at_epoch
                          and epoch_num < prog_resize_at_epoch)
        if patience_counter >= early_stop_patience and not resize_pending:
            print(f"\n  Early stopping: no improvement for {early_stop_patience} epochs")
            break
        elif patience_counter >= early_stop_patience and resize_pending:
            print(f"  (patience exhausted but holding for resize at epoch {prog_resize_at_epoch})")

    # Save history
    hist_path = os.path.join(output_dir, f"{backbone_name}_history.json")
    with open(hist_path, "w") as f:
        json.dump(history, f, indent=2)

    total_time = time.time() - start_time
    print(f"\n{backbone_name} complete — {total_time/60:.1f}min, best val top1: {best_val_acc:.1f}%")

    return {"backbone": backbone_name, "best_top1": best_val_acc, "history": history}


def _save_checkpoint(model, backbone_name, epoch, val_metrics, output_dir):
    """Save ``model``'s weights and metadata to ``{output_dir}/{backbone_name}_best.pt``.

    The checkpoint dict holds ``model_state_dict`` plus the backbone name,
    epoch, val top-1/top-5 (from ``val_metrics["top1_acc"/"top5_acc"]``),
    ``num_classes`` and ``use_arcface`` (the head type, so loaders need not
    guess it from key names).

    The data is written to a sibling ``.tmp`` file and then moved into place
    with ``os.replace``. An interrupted save therefore leaves the previous
    best checkpoint intact instead of truncating it.
    """
    save_path = os.path.join(output_dir, f"{backbone_name}_best.pt")
    tmp_path = save_path + ".tmp"
    checkpoint = {
        "model_state_dict": model.state_dict(),
        "backbone_name": backbone_name,
        "epoch": epoch,
        "val_top1": val_metrics["top1_acc"],
        "val_top5": val_metrics["top5_acc"],
        "num_classes": NUM_CLASSES,
        "use_arcface": bool(getattr(model, "use_arcface", False)),
    }
    try:
        torch.save(checkpoint, tmp_path)
        os.replace(tmp_path, save_path)
    except BaseException:
        # Clean up the partial temp file; the previous checkpoint is untouched.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


def load_model(backbone_name: str, checkpoint_path: str, device: torch.device = None) -> BreedClassifier:
    """Load a trained ``BreedClassifier`` from a checkpoint, ready for inference.

    The head type comes from the checkpoint's ``use_arcface`` entry when
    present. Checkpoints without it fall back to detecting ArcFace by the
    substring ``"arcface"`` in any state-dict key.

    Args:
        backbone_name: registry name used to rebuild the backbone (no
            pretrained download; weights come from the checkpoint).
        checkpoint_path: path to a dict checkpoint containing
            ``model_state_dict`` (and optionally ``num_classes``/``use_arcface``).
        device: target device; defaults to ``get_device()``.

    Returns:
        The model with weights loaded, moved to ``device`` and in eval mode.
    """
    if device is None:
        device = get_device()
    # SECURITY: weights_only=False unpickles arbitrary Python objects; only
    # load checkpoints from trusted sources.
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
    state_dict = ckpt["model_state_dict"]

    use_arcface = ckpt.get("use_arcface")
    if use_arcface is None:
        # Legacy checkpoints: infer the head type from key names.
        # NOTE(review): BreedClassifier registers its head as ``head.``, so this
        # only matches if ArcFaceHead's own parameter/submodule names contain
        # "arcface" — confirm against losses.ArcFaceHead.
        use_arcface = any("arcface" in k for k in state_dict)

    model = BreedClassifier(
        backbone_name,
        num_classes=ckpt.get("num_classes", NUM_CLASSES),
        pretrained=False,
        use_arcface=use_arcface,
    )
    model.load_state_dict(state_dict)
    model.to(device).eval()
    return model