import torch
import json, os
from torchvision.datasets import Omniglot
from torchvision import transforms
from torch.utils.data import DataLoader
import wandb

from model import SiameseNet
from loss import ContrastiveLoss
from dataset import SiamesePairDataset
from train import train_one_epoch, validate, save_checkpoint

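# Hyperparameters and paths for the run; the same dict is logged to wandb as the run config.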
CFG = {
    "epochs"         : 30,
    "batch_size"     : 32,
    "lr"             : 1e-3,
    "embedding_dim"  : 128,
    "margin"         : 1.0,
    "num_workers"    : 4,
    "num_pairs_train": 10000,
    "num_pairs_val"  : 2000,
    "data_root"      : "../data",
    "ckpt_dir"       : "../checkpoints",
}

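# Start the Weights & Biases run for experiment tracking.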
wandb.init(project="siamese-few-shot", name="run-01", config=CFG)

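# Normalization statistics for the single grayscale channel.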
MEAN, STD = [0.9220], [0.2256]

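# Training-time transform with light augmentation.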
train_transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((105, 105)),
    transforms.RandomCrop(105, padding=8),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(MEAN, STD),
])

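# Deterministic transform for validation (no augmentation).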
eval_transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((105, 105)),
    transforms.ToTensor(),
    transforms.Normalize(MEAN, STD),
])

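# Raw Omniglot "background" set; per-image transforms are applied later by SiamesePairDataset.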
bg = Omniglot(root=CFG["data_root"], background=True, download=True, transform=None)

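# Pre-computed train/val split over character classes.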
with open(os.path.join(CFG["data_root"], "class_split.json")) as f:
    split = json.load(f)

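# Fixed-size pair datasets drawn from the train and val class splits.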
train_ds = SiamesePairDataset(bg, split["train"], transform=train_transform,
                              num_pairs=CFG["num_pairs_train"])
val_ds = SiamesePairDataset(bg, split["val"], transform=eval_transform,
                            num_pairs=CFG["num_pairs_val"])

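# Batched pair loaders; only the training set is shuffled.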
train_loader = DataLoader(train_ds, batch_size=CFG["batch_size"], shuffle=True,
                          num_workers=CFG["num_workers"], pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=CFG["batch_size"], shuffle=False,
                        num_workers=CFG["num_workers"], pin_memory=True)

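# Model, contrastive loss, optimizer and cosine learning-rate schedule.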
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SiameseNet(embedding_dim=CFG["embedding_dim"]).to(device)
criterion = ContrastiveLoss(margin=CFG["margin"])
optimizer = torch.optim.Adam(model.parameters(), lr=CFG["lr"])
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CFG["epochs"])

print(f"Training on : {device}")
print(f"Train pairs : {len(train_ds)} | Val pairs: {len(val_ds)}")

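# Training loop: log metrics each epoch and track the best validation loss.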
os.makedirs(CFG["ckpt_dir"], exist_ok=True)  # ensure the checkpoint directory exists before saving
best_val_loss = float("inf")

for epoch in range(1, CFG["epochs"] + 1):
    train_loss, train_acc = train_one_epoch(model, train_loader, criterion,
                                            optimizer, device, epoch)
    val_loss, val_acc = validate(model, val_loader, criterion, device, epoch)
    scheduler.step()

    print(f"Epoch {epoch:02d} | "
          f"train loss {train_loss:.4f} acc {train_acc*100:.1f}% | "
          f"val loss {val_loss:.4f} acc {val_acc*100:.1f}%")

    wandb.log({
        "epoch"      : epoch,
        "train/loss" : train_loss,
        "train/acc"  : train_acc,
        "val/loss"   : val_loss,
        "val/acc"    : val_acc,
        "lr"         : scheduler.get_last_lr()[0],
    })

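    # Checkpoint whenever validation loss improves.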
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        save_checkpoint(model, optimizer, epoch, val_loss,
                        f"{CFG['ckpt_dir']}/best.pt")

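# Always save the last-epoch weights as well, then close the wandb run.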
save_checkpoint(model, optimizer, CFG["epochs"], val_loss,
                f"{CFG['ckpt_dir']}/final.pt")

wandb.finish()
print("Training complete.")