import torch
import json, os
from torchvision.datasets import Omniglot
from torchvision import transforms
from torch.utils.data import DataLoader
import wandb
from model import SiameseNet
from loss import ContrastiveLoss
from dataset import SiamesePairDataset
from train import train_one_epoch, validate, save_checkpoint
# ── Config ────────────────────────────────────────────────────
CFG = {
    "epochs"         : 30,
    "batch_size"     : 32,
    "lr"             : 1e-3,
    "embedding_dim"  : 128,
    "margin"         : 1.0,
    "num_workers"    : 4,
    "num_pairs_train": 10000,
    "num_pairs_val"  : 2000,
    "data_root"      : "../data",
    "ckpt_dir"       : "../checkpoints",
}
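# Reproducibility and filesystem setup. Seeding is optional and the value is
# arbitrary; save_checkpoint is assumed not to create directories itself, so
# the checkpoint directory is created up front.
torch.manual_seed(0)
os.makedirs(CFG["ckpt_dir"], exist_ok=True)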
# ── WandB ─────────────────────────────────────────────────────
wandb.init(project="siamese-few-shot", name="run-01", config=CFG)
# ── Data ──────────────────────────────────────────────────────
MEAN, STD = [0.9220], [0.2256]
train_transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((105, 105)),
    transforms.RandomCrop(105, padding=8),
    # Horizontal flips are deliberately avoided: mirroring an Omniglot glyph
    # can turn it into a different character and corrupt the pair labels.
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(MEAN, STD),
])
eval_transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((105, 105)),
    transforms.ToTensor(),
    transforms.Normalize(MEAN, STD),
])
bg = Omniglot(root=CFG["data_root"], background=True, download=True, transform=None)
with open(os.path.join(CFG["data_root"], "class_split.json")) as f:
    split = json.load(f)
train_ds = SiamesePairDataset(bg, split["train"], transform=train_transform,
                              num_pairs=CFG["num_pairs_train"])
val_ds   = SiamesePairDataset(bg, split["val"], transform=eval_transform,
                              num_pairs=CFG["num_pairs_val"])
train_loader = DataLoader(train_ds, batch_size=CFG["batch_size"], shuffle=True,
                          num_workers=CFG["num_workers"], pin_memory=True)
val_loader   = DataLoader(val_ds, batch_size=CFG["batch_size"], shuffle=False,
                          num_workers=CFG["num_workers"], pin_memory=True)
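# Optional sanity check: SiamesePairDataset is assumed to yield
# (img1, img2, label) triples; pull one batch and confirm the shapes.
x1, x2, y = next(iter(train_loader))
print(f"Batch shapes : {tuple(x1.shape)} {tuple(x2.shape)} labels {tuple(y.shape)}")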
# ── Model / Loss / Optimiser ──────────────────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SiameseNet(embedding_dim=CFG["embedding_dim"]).to(device)
criterion = ContrastiveLoss(margin=CFG["margin"])
optimizer = torch.optim.Adam(model.parameters(), lr=CFG["lr"])
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CFG["epochs"])
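# Optionally let W&B record gradient/parameter histograms alongside the
# scalar metrics logged below; log_freq=100 keeps the overhead small.
wandb.watch(model, log="all", log_freq=100)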
print(f"Training on : {device}")
print(f"Train pairs : {len(train_ds)} | Val pairs: {len(val_ds)}")
# ── Training loop ─────────────────────────────────────────────
best_val_loss = float("inf")
for epoch in range(1, CFG["epochs"] + 1):
    train_loss, train_acc = train_one_epoch(model, train_loader, criterion,
                                            optimizer, device, epoch)
    val_loss, val_acc = validate(model, val_loader, criterion, device, epoch)
    scheduler.step()
    print(f"Epoch {epoch:02d} | "
          f"train loss {train_loss:.4f} acc {train_acc*100:.1f}% | "
          f"val loss {val_loss:.4f} acc {val_acc*100:.1f}%")
    wandb.log({
        "epoch"      : epoch,
        "train/loss" : train_loss,
        "train/acc"  : train_acc,
        "val/loss"   : val_loss,
        "val/acc"    : val_acc,
        "lr"         : scheduler.get_last_lr()[0],
    })
    # Save best checkpoint
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        save_checkpoint(model, optimizer, epoch, val_loss,
                        f"{CFG['ckpt_dir']}/best.pt")
# Save final checkpoint regardless
save_checkpoint(model, optimizer, CFG["epochs"], val_loss,
                f"{CFG['ckpt_dir']}/final.pt")
wandb.finish()
print("Training complete.") |