File size: 2,414 Bytes
02ac88d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
import wandb
from tqdm import tqdm
import os

def train_one_epoch(model, loader, criterion, optimizer, device, epoch, *,
                    threshold=0.5, max_grad_norm=1.0):
    """Train ``model`` for one pass over ``loader``.

    Each batch is moved to ``device``, run through ``model(img1, img2)`` and
    scored with ``criterion(emb1, emb2, labels)``, which must return
    ``(loss, dist)``. Gradients are clipped to ``max_grad_norm`` before each
    optimizer step.

    Args:
        model: module called as ``model(img1, img2)`` returning ``(emb1, emb2)``.
        loader: iterable of ``(img1, img2, labels)`` batches. It does not need
            to support ``len()`` -- batches are counted while iterating.
        criterion: callable returning ``(loss, dist)``; ``loss`` must be a
            scalar tensor (``.backward()`` and ``.item()`` are called on it).
        optimizer: optimizer updating ``model``'s parameters.
        device: target device for the batch tensors.
        epoch: epoch index, used only in the progress-bar label.
        threshold: a pair is predicted as ``1`` (same class) when
            ``dist < threshold``.
        max_grad_norm: ``max_norm`` passed to ``clip_grad_norm_``.

    Returns:
        ``(mean_loss, accuracy)``: ``mean_loss`` is the unweighted mean of the
        per-batch losses, ``accuracy`` is the fraction of pairs whose
        prediction equals ``labels``. Both are ``nan`` if ``loader`` yields
        no batches.
    """
    model.train()
    total_loss, correct, total, n_batches = 0.0, 0, 0, 0

    loop = tqdm(loader, desc=f"Epoch {epoch} [train]", leave=False)
    for img1, img2, labels in loop:
        img1, img2, labels = img1.to(device), img2.to(device), labels.to(device)

        optimizer.zero_grad()
        emb1, emb2 = model(img1, img2)
        loss, dist = criterion(emb1, emb2, labels)
        loss.backward()

        # Gradient clipping — prevents exploding gradients early in training
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)

        optimizer.step()

        # .item() forces a device sync; read the scalar once per batch.
        batch_loss = loss.item()

        # Accuracy: predict same-class (1) if distance < threshold.
        # NOTE(review): assumes dist and labels share a shape such as (batch,);
        # a (batch, 1) vs (batch,) mismatch would broadcast — confirm.
        preds = (dist < threshold).float()
        correct += (preds == labels).sum().item()
        total += labels.size(0)
        total_loss += batch_loss
        n_batches += 1

        loop.set_postfix(loss=f"{batch_loss:.4f}")

    # Count batches ourselves instead of len(loader) so loaders without
    # __len__ work, and report nan instead of dividing by zero when empty.
    mean_loss = total_loss / n_batches if n_batches else float("nan")
    accuracy = correct / total if total else float("nan")
    return mean_loss, accuracy


@torch.no_grad()
def validate(model, loader, criterion, device, epoch, *, threshold=0.5):
    """Evaluate ``model`` over ``loader`` without tracking gradients.

    Puts ``model`` in eval mode, runs ``model(img1, img2)`` on every batch and
    scores it with ``criterion(emb1, emb2, labels)``, which must return
    ``(loss, dist)``. No parameters are updated.

    Args:
        model: module called as ``model(img1, img2)`` returning ``(emb1, emb2)``.
        loader: iterable of ``(img1, img2, labels)`` batches. It does not need
            to support ``len()`` -- batches are counted while iterating.
        criterion: callable returning ``(loss, dist)``; ``loss`` must be a
            scalar tensor.
        device: target device for the batch tensors.
        epoch: epoch index, used only in the progress-bar label.
        threshold: a pair is predicted as ``1`` (same class) when
            ``dist < threshold``.

    Returns:
        ``(mean_loss, accuracy)``: ``mean_loss`` is the unweighted mean of the
        per-batch losses, ``accuracy`` is the fraction of pairs whose
        prediction equals ``labels``. Both are ``nan`` if ``loader`` yields
        no batches.
    """
    model.eval()
    total_loss, correct, total, n_batches = 0.0, 0, 0, 0

    loop = tqdm(loader, desc=f"Epoch {epoch} [val]  ", leave=False)
    for img1, img2, labels in loop:
        img1, img2, labels = img1.to(device), img2.to(device), labels.to(device)

        emb1, emb2 = model(img1, img2)
        loss, dist = criterion(emb1, emb2, labels)

        # Accuracy: predict same-class (1) if distance < threshold.
        # NOTE(review): assumes dist and labels share a shape such as (batch,)
        # — confirm against the criterion.
        preds = (dist < threshold).float()
        correct += (preds == labels).sum().item()
        total += labels.size(0)
        total_loss += loss.item()
        n_batches += 1

    # Count batches ourselves instead of len(loader) so loaders without
    # __len__ work, and report nan instead of dividing by zero when empty.
    mean_loss = total_loss / n_batches if n_batches else float("nan")
    accuracy = correct / total if total else float("nan")
    return mean_loss, accuracy


def save_checkpoint(model, optimizer, epoch, val_loss, path):
    """Write model/optimizer state and bookkeeping to ``path`` via ``torch.save``.

    The saved dict has the keys ``epoch``, ``model_state``, ``optim_state``
    and ``val_loss``, which is the layout ``load_checkpoint`` reads back.
    Missing parent directories are created first.

    Args:
        model: module whose ``state_dict()`` is saved.
        optimizer: optimizer whose ``state_dict()`` is saved.
        epoch: epoch value stored alongside the state.
        val_loss: validation loss stored alongside the state.
        path: destination file path; may be a bare filename.
    """
    directory = os.path.dirname(path)
    # os.path.dirname("ckpt.pt") == "" and os.makedirs("") raises
    # FileNotFoundError, so only create a directory when there is one.
    if directory:
        os.makedirs(directory, exist_ok=True)
    torch.save({
        "epoch"      : epoch,
        "model_state": model.state_dict(),
        "optim_state": optimizer.state_dict(),
        "val_loss"   : val_loss,
    }, path)
    print(f"  Checkpoint saved → {path}")


def load_checkpoint(path, model, optimizer=None, *, map_location="cpu"):
    """Restore ``model`` (and optionally ``optimizer``) from a checkpoint file.

    Expects the dict layout written by ``save_checkpoint``: keys ``epoch``,
    ``model_state``, ``optim_state`` and ``val_loss``.

    Args:
        path: checkpoint file to read.
        model: module whose state is overwritten in place.
        optimizer: if not ``None``, its state is overwritten in place as well.
        map_location: forwarded to ``torch.load``. The ``"cpu"`` default lets
            a checkpoint saved on a GPU load on a CPU-only host;
            ``load_state_dict`` then copies the values onto the devices of the
            existing parameters. Pass ``None`` for torch's default placement.

    Returns:
        The ``epoch`` value stored in the checkpoint.

    Raises:
        KeyError: if the checkpoint lacks one of the keys read here.
    """
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(path, map_location=map_location)
    model.load_state_dict(ckpt["model_state"])
    if optimizer is not None:
        optimizer.load_state_dict(ckpt["optim_state"])
    print(f"  Resumed from epoch {ckpt['epoch']} (val_loss={ckpt['val_loss']:.4f})")
    return ckpt["epoch"]