File size: 4,370 Bytes
02ac88d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import torch
import json, os
from torchvision.datasets import Omniglot
from torchvision import transforms
from torch.utils.data import DataLoader
import wandb

from model   import SiameseNet
from loss    import ContrastiveLoss
from dataset import SiamesePairDataset
from train   import train_one_epoch, validate, save_checkpoint

# ── Config ────────────────────────────────────────────────────
# Run hyper-parameters and filesystem locations. The dict is also passed to
# wandb.init as the run config. "data_root" / "ckpt_dir" are relative paths,
# so they resolve against the current working directory.
CFG = dict(
    epochs=30,
    batch_size=32,
    lr=1e-3,
    embedding_dim=128,
    margin=1.0,
    num_workers=4,
    num_pairs_train=10000,
    num_pairs_val=2000,
    data_root="../data",
    ckpt_dir="../checkpoints",
)

# ── Data transforms ───────────────────────────────────────────
# Single-channel normalisation statistics.
# NOTE(review): presumably measured on the Omniglot training images — confirm.
MEAN, STD = [0.9220], [0.2256]

train_transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((105, 105)),
    # Omniglot glyphs are dark strokes on a white background (hence the ~0.92
    # mean above). Pad with white (255 in PIL "L" mode) instead of the default
    # fill=0, otherwise every training crop gets black borders that eval
    # images never have.
    transforms.RandomCrop(105, padding=8, fill=255),
    # NOTE(review): mirroring can change a glyph's identity in some scripts;
    # confirm this augmentation is intended.
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(MEAN, STD),
])

eval_transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((105, 105)),
    transforms.ToTensor(),
    transforms.Normalize(MEAN, STD),
])


def _build_loaders(device):
    """Build the train/val pair DataLoaders from Omniglot's background set.

    The background set is downloaded into ``CFG["data_root"]`` if missing.
    The classes for each split come from the "train" and "val" entries of
    ``<data_root>/class_split.json``.

    Args:
        device: the ``torch.device`` training will run on. It only decides
            whether pinned host memory is worth enabling.

    Returns:
        ``(train_loader, val_loader)``.
    """
    bg = Omniglot(root=CFG["data_root"], background=True, download=True,
                  transform=None)

    split_path = os.path.join(CFG["data_root"], "class_split.json")
    with open(split_path, encoding="utf-8") as f:
        split = json.load(f)

    train_ds = SiamesePairDataset(bg, split["train"], transform=train_transform,
                                  num_pairs=CFG["num_pairs_train"])
    val_ds = SiamesePairDataset(bg, split["val"], transform=eval_transform,
                                num_pairs=CFG["num_pairs_val"])

    # Pinned memory only speeds up host→GPU copies; it is useless on CPU.
    pin = device.type == "cuda"
    train_loader = DataLoader(train_ds, batch_size=CFG["batch_size"], shuffle=True,
                              num_workers=CFG["num_workers"], pin_memory=pin)
    val_loader = DataLoader(val_ds, batch_size=CFG["batch_size"], shuffle=False,
                            num_workers=CFG["num_workers"], pin_memory=pin)
    return train_loader, val_loader


def main():
    """Train the Siamese network, log metrics to WandB and save checkpoints.

    Writes ``best.pt`` whenever validation loss improves. Writes ``final.pt``
    after the last epoch, if at least one epoch ran. Both files go to
    ``CFG["ckpt_dir"]``.
    """
    # ── WandB ─────────────────────────────────────────────────
    wandb.init(project="siamese-few-shot", name="run-01", config=CFG)

    # ── Data ──────────────────────────────────────────────────
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader, val_loader = _build_loaders(device)

    # ── Model / Loss / Optimiser ──────────────────────────────
    model = SiameseNet(embedding_dim=CFG["embedding_dim"]).to(device)
    criterion = ContrastiveLoss(margin=CFG["margin"])
    optimizer = torch.optim.Adam(model.parameters(), lr=CFG["lr"])
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CFG["epochs"])

    print(f"Training on : {device}")
    print(f"Train pairs : {len(train_loader.dataset)} | "
          f"Val pairs: {len(val_loader.dataset)}")

    # Create the checkpoint directory up front so the first save cannot fail
    # after a full epoch of work.
    os.makedirs(CFG["ckpt_dir"], exist_ok=True)
    best_path = os.path.join(CFG["ckpt_dir"], "best.pt")
    final_path = os.path.join(CFG["ckpt_dir"], "final.pt")

    # ── Training loop ─────────────────────────────────────────
    best_val_loss = float("inf")
    val_loss = None  # stays None if CFG["epochs"] == 0

    for epoch in range(1, CFG["epochs"] + 1):
        # Read the LR used during this epoch *before* the scheduler advances;
        # reading it after step() would log next epoch's value.
        epoch_lr = scheduler.get_last_lr()[0]

        train_loss, train_acc = train_one_epoch(model, train_loader, criterion,
                                                optimizer, device, epoch)
        val_loss, val_acc = validate(model, val_loader, criterion, device, epoch)
        scheduler.step()

        print(f"Epoch {epoch:02d} | "
              f"train loss {train_loss:.4f}  acc {train_acc*100:.1f}%  | "
              f"val loss {val_loss:.4f}  acc {val_acc*100:.1f}%")

        wandb.log({
            "epoch"      : epoch,
            "train/loss" : train_loss,
            "train/acc"  : train_acc,
            "val/loss"   : val_loss,
            "val/acc"    : val_acc,
            "lr"         : epoch_lr,
        })

        # Save best checkpoint
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            save_checkpoint(model, optimizer, epoch, val_loss, best_path)

    # Save the final checkpoint regardless of quality, if any epoch ran.
    if val_loss is not None:
        save_checkpoint(model, optimizer, CFG["epochs"], val_loss, final_path)

    wandb.finish()
    print("Training complete.")


# DataLoader workers (num_workers > 0) re-import this script under the
# "spawn" start method (Windows/macOS default). The guard stops them from
# re-running training.
if __name__ == "__main__":
    main()