import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
import wandb
from tqdm import tqdm
import os


def train_one_epoch(model, loader, criterion, optimizer, device, epoch):
    """Run one training epoch and return (mean batch loss, accuracy).

    Args:
        model: siamese-style network; ``model(img1, img2)`` returns a pair
            of embeddings — TODO confirm against the model definition.
        loader: iterable yielding ``(img1, img2, labels)`` batches.
        criterion: callable returning ``(loss, dist)`` — a scalar loss and a
            per-pair distance tensor (contrastive-style).
        optimizer: optimizer stepping ``model``'s parameters.
        device: device the batch tensors are moved to.
        epoch: epoch index, used only for the progress-bar label.

    Returns:
        Tuple ``(avg_loss, accuracy)`` over the epoch (0.0/0.0 for an
        empty loader instead of a ZeroDivisionError).
    """
    model.train()
    total_loss, correct, total = 0.0, 0, 0
    loop = tqdm(loader, desc=f"Epoch {epoch} [train]", leave=False)
    for img1, img2, labels in loop:
        img1, img2, labels = img1.to(device), img2.to(device), labels.to(device)
        optimizer.zero_grad()
        emb1, emb2 = model(img1, img2)
        loss, dist = criterion(emb1, emb2, labels)
        loss.backward()
        # Gradient clipping — prevents exploding gradients early in training
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        # Accuracy: predict same-class if distance < 0.5
        preds = (dist < 0.5).float()
        correct += (preds == labels).sum().item()
        total += labels.size(0)
        total_loss += loss.item()
        loop.set_postfix(loss=f"{loss.item():.4f}")
    # max(..., 1) guards the empty-loader edge case.
    return total_loss / max(len(loader), 1), correct / max(total, 1)


@torch.no_grad()
def validate(model, loader, criterion, device, epoch):
    """Evaluate the model for one epoch; gradient tracking is disabled.

    Same contract as ``train_one_epoch`` but no parameter updates.

    Returns:
        Tuple ``(avg_loss, accuracy)`` over the validation set.
    """
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    loop = tqdm(loader, desc=f"Epoch {epoch} [val]  ", leave=False)
    for img1, img2, labels in loop:
        img1, img2, labels = img1.to(device), img2.to(device), labels.to(device)
        emb1, emb2 = model(img1, img2)
        loss, dist = criterion(emb1, emb2, labels)
        # Same decision rule as training: same-class iff distance < 0.5.
        preds = (dist < 0.5).float()
        correct += (preds == labels).sum().item()
        total += labels.size(0)
        total_loss += loss.item()
    # max(..., 1) guards the empty-loader edge case.
    return total_loss / max(len(loader), 1), correct / max(total, 1)


def save_checkpoint(model, optimizer, epoch, val_loss, path):
    """Serialize model/optimizer state plus metadata to *path*.

    Creates the parent directory if needed; a bare filename (empty
    dirname) is written to the current directory.
    """
    dirname = os.path.dirname(path)
    # os.makedirs("") raises FileNotFoundError, so only create a real dirname.
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    torch.save({
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optim_state": optimizer.state_dict(),
        "val_loss": val_loss,
    }, path)
    print(f" Checkpoint saved → {path}")


def load_checkpoint(path, model, optimizer=None, map_location=None):
    """Restore training state written by ``save_checkpoint``.

    Args:
        path: checkpoint file path.
        model: module whose state dict is loaded in place.
        optimizer: if given, its state dict is restored as well.
        map_location: forwarded to ``torch.load`` (e.g. ``"cpu"`` to load
            a GPU-trained checkpoint on a CPU-only machine). ``None``
            preserves the original default behavior.

    Returns:
        The epoch number stored in the checkpoint.
    """
    ckpt = torch.load(path, map_location=map_location)
    model.load_state_dict(ckpt["model_state"])
    if optimizer is not None:
        optimizer.load_state_dict(ckpt["optim_state"])
    # NOTE(review): the original source had a literal line break inside this
    # f-string (a syntax error); reconstructed as a single-line message.
    print(f" Resumed from epoch {ckpt['epoch']} (val_loss={ckpt['val_loss']:.4f})")
    return ckpt["epoch"]