"""Train a Siamese network with contrastive loss on Omniglot image pairs.

Relies on the project-local modules `model`, `loss`, `dataset`, and `train`,
and on a `class_split.json` file under CFG["data_root"] mapping split names
("train" / "val") to class lists consumed by SiamesePairDataset.
"""

import json
import os

import torch
import wandb
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import Omniglot

from dataset import SiamesePairDataset
from loss import ContrastiveLoss
from model import SiameseNet
from train import train_one_epoch, validate, save_checkpoint

# ── Config ────────────────────────────────────────────────────
CFG = {
    "epochs"         : 30,
    "batch_size"     : 32,
    "lr"             : 1e-3,
    "embedding_dim"  : 128,
    "margin"         : 1.0,    # contrastive-loss margin
    "num_workers"    : 4,
    "num_pairs_train": 10000,
    "num_pairs_val"  : 2000,
    "data_root"      : "../data",
    "ckpt_dir"       : "../checkpoints",
}

# ── WandB ─────────────────────────────────────────────────────
wandb.init(project="siamese-few-shot", name="run-01", config=CFG)

# ── Data ──────────────────────────────────────────────────────
# Grayscale normalization statistics; presumably precomputed over Omniglot
# (mostly-white background would explain the high mean) — TODO confirm
# against the preprocessing that produced them.
MEAN, STD = [0.9220], [0.2256]

train_transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((105, 105)),
    transforms.RandomCrop(105, padding=8),
    # NOTE(review): a horizontal flip can turn one glyph into a different
    # (or invalid) character — confirm this augmentation is intentional
    # for Omniglot before relying on it.
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(MEAN, STD),
])

eval_transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((105, 105)),
    transforms.ToTensor(),
    transforms.Normalize(MEAN, STD),
])

# Raw images only; pair sampling and per-image transforms are applied by
# SiamesePairDataset, so the base dataset gets transform=None.
bg = Omniglot(root=CFG["data_root"], background=True, download=True,
              transform=None)

with open(os.path.join(CFG["data_root"], "class_split.json")) as f:
    split = json.load(f)

train_ds = SiamesePairDataset(bg, split["train"], transform=train_transform,
                              num_pairs=CFG["num_pairs_train"])
val_ds = SiamesePairDataset(bg, split["val"], transform=eval_transform,
                            num_pairs=CFG["num_pairs_val"])

train_loader = DataLoader(train_ds, batch_size=CFG["batch_size"],
                          shuffle=True, num_workers=CFG["num_workers"],
                          pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=CFG["batch_size"],
                        shuffle=False, num_workers=CFG["num_workers"],
                        pin_memory=True)

# ── Model / Loss / Optimiser ──────────────────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SiameseNet(embedding_dim=CFG["embedding_dim"]).to(device)
criterion = ContrastiveLoss(margin=CFG["margin"])
optimizer = torch.optim.Adam(model.parameters(), lr=CFG["lr"])
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                       T_max=CFG["epochs"])

print(f"Training on : {device}")
print(f"Train pairs : {len(train_ds)} | Val pairs: {len(val_ds)}")

# ── Training loop ─────────────────────────────────────────────
# Ensure the checkpoint directory exists before the first save —
# save_checkpoint is not guaranteed to create it, and "../checkpoints"
# will not exist on a fresh clone.
os.makedirs(CFG["ckpt_dir"], exist_ok=True)

best_val_loss = float("inf")
for epoch in range(1, CFG["epochs"] + 1):
    train_loss, train_acc = train_one_epoch(model, train_loader, criterion,
                                            optimizer, device, epoch)
    val_loss, val_acc = validate(model, val_loader, criterion, device, epoch)
    scheduler.step()  # cosine decay, stepped once per epoch

    print(f"Epoch {epoch:02d} | "
          f"train loss {train_loss:.4f} acc {train_acc*100:.1f}% | "
          f"val loss {val_loss:.4f} acc {val_acc*100:.1f}%")

    wandb.log({
        "epoch"      : epoch,
        "train/loss" : train_loss,
        "train/acc"  : train_acc,
        "val/loss"   : val_loss,
        "val/acc"    : val_acc,
        "lr"         : scheduler.get_last_lr()[0],
    })

    # Save best checkpoint (lowest validation loss seen so far).
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        save_checkpoint(model, optimizer, epoch, val_loss,
                        f"{CFG['ckpt_dir']}/best.pt")

# Save final checkpoint regardless of validation performance.
save_checkpoint(model, optimizer, CFG["epochs"], val_loss,
                f"{CFG['ckpt_dir']}/final.pt")

wandb.finish()
print("Training complete.")