# LETTER/src/run_training.py
# (uploaded via huggingface_hub by Sharath33 — commit 02ac88d, verified)
import torch
import json, os
from torchvision.datasets import Omniglot
from torchvision import transforms
from torch.utils.data import DataLoader
import wandb
from model import SiameseNet
from loss import ContrastiveLoss
from dataset import SiamesePairDataset
from train import train_one_epoch, validate, save_checkpoint
# ── Config ────────────────────────────────────────────────────
# All hyperparameters and paths for this run, kept in one mapping so the
# whole configuration can be handed to WandB for experiment tracking.
CFG = dict(
    epochs=30,
    batch_size=32,
    lr=1e-3,                  # initial Adam learning rate
    embedding_dim=128,        # dimensionality of the Siamese embedding
    margin=1.0,               # contrastive-loss margin
    num_workers=4,            # DataLoader worker processes
    num_pairs_train=10000,    # sampled image pairs per training epoch
    num_pairs_val=2000,       # sampled image pairs for validation
    data_root="../data",
    ckpt_dir="../checkpoints",
)
# ── WandB ─────────────────────────────────────────────────────
# Start the experiment-tracking run; passing CFG as `config` records every
# hyperparameter alongside the logged metrics.
wandb.init(project="siamese-few-shot", name="run-01", config=CFG)
# ── Data ──────────────────────────────────────────────────────
# Single-channel normalization stats — presumably computed over the Omniglot
# background set (mostly-white images, hence mean ≈ 0.92). TODO confirm.
MEAN, STD = [0.9220], [0.2256]
# Training pipeline: deterministic resize + normalize, plus augmentation
# (padded random crop, flips, brightness/contrast jitter).
# NOTE(review): RandomHorizontalFlip can change a glyph's identity for
# character data — confirm mirrored characters are intended here.
train_transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((105, 105)),
    transforms.RandomCrop(105, padding=8),  # pad by 8 then crop back to 105
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(MEAN, STD),
])
# Evaluation pipeline: same resize + normalize, no augmentation, so
# validation results are deterministic.
eval_transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((105, 105)),
    transforms.ToTensor(),
    transforms.Normalize(MEAN, STD),
])
# Raw Omniglot background split; transform=None because each pair dataset
# applies its own (train vs eval) transform.
bg = Omniglot(root=CFG["data_root"], background=True, download=True, transform=None)
# Class split prepared offline — presumably lists of class ids/names under
# "train" and "val" keys; verify against how SiamesePairDataset consumes it.
with open(os.path.join(CFG["data_root"], "class_split.json")) as f:
    split = json.load(f)
# Pair datasets sample positive/negative image pairs from their own classes.
train_ds = SiamesePairDataset(bg, split["train"], transform=train_transform,
                              num_pairs=CFG["num_pairs_train"])
val_ds = SiamesePairDataset(bg, split["val"], transform=eval_transform,
                            num_pairs=CFG["num_pairs_val"])
# NOTE(review): num_workers > 0 with all of this at module level (no
# `if __name__ == "__main__":` guard) breaks on spawn-based platforms
# (Windows/macOS) — confirm this only runs on Linux, or add a guard.
train_loader = DataLoader(train_ds, batch_size=CFG["batch_size"], shuffle=True,
                          num_workers=CFG["num_workers"], pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=CFG["batch_size"], shuffle=False,
                        num_workers=CFG["num_workers"], pin_memory=True)
# ── Model / Loss / Optimiser ──────────────────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SiameseNet(embedding_dim=CFG["embedding_dim"]).to(device)
criterion = ContrastiveLoss(margin=CFG["margin"])
optimizer = torch.optim.Adam(model.parameters(), lr=CFG["lr"])
# Cosine decay of the LR over the full run (one scheduler.step() per epoch).
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CFG["epochs"])
print(f"Training on : {device}")
print(f"Train pairs : {len(train_ds)} | Val pairs: {len(val_ds)}")
# ── Training loop ─────────────────────────────────────────────
# Train for CFG["epochs"] epochs, tracking the best validation loss and
# checkpointing both the best and the final model.
best_val_loss = float("inf")
for epoch in range(1, CFG["epochs"] + 1):
    # One pass over the training pairs, then evaluate on held-out pairs.
    train_loss, train_acc = train_one_epoch(model, train_loader, criterion,
                                            optimizer, device, epoch)
    val_loss, val_acc = validate(model, val_loader, criterion, device, epoch)
    scheduler.step()  # cosine LR decay, stepped once per epoch

    summary = (f"Epoch {epoch:02d} | "
               f"train loss {train_loss:.4f} acc {train_acc*100:.1f}% | "
               f"val loss {val_loss:.4f} acc {val_acc*100:.1f}%")
    print(summary)

    metrics = {
        "epoch": epoch,
        "train/loss": train_loss,
        "train/acc": train_acc,
        "val/loss": val_loss,
        "val/acc": val_acc,
        "lr": scheduler.get_last_lr()[0],
    }
    wandb.log(metrics)

    # Keep the checkpoint with the lowest validation loss seen so far.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        save_checkpoint(model, optimizer, epoch, val_loss,
                        f"{CFG['ckpt_dir']}/best.pt")

# Save final checkpoint regardless of validation performance.
save_checkpoint(model, optimizer, CFG["epochs"], val_loss,
                f"{CFG['ckpt_dir']}/final.pt")
wandb.finish()
print("Training complete.")