import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
import wandb
from tqdm import tqdm
import os
def train_one_epoch(model, loader, criterion, optimizer, device, epoch, threshold=0.5):
    """Run one training epoch over a Siamese pair loader.

    Args:
        model: network called as ``model(img1, img2) -> (emb1, emb2)``.
        loader: iterable yielding ``(img1, img2, labels)`` batches.
        criterion: callable ``(emb1, emb2, labels) -> (loss, dist)``.
        optimizer: optimizer stepping ``model``'s parameters.
        device: device the batch tensors are moved to.
        epoch: current epoch number (used only for the progress-bar label).
        threshold: pairwise distance below which a pair is predicted
            same-class. Defaults to 0.5, the previously hard-coded value,
            so existing callers are unaffected.

    Returns:
        tuple: (mean per-batch loss, pair-classification accuracy).
    """
    model.train()
    total_loss, correct, total = 0.0, 0, 0
    loop = tqdm(loader, desc=f"Epoch {epoch} [train]", leave=False)
    for img1, img2, labels in loop:
        img1, img2, labels = img1.to(device), img2.to(device), labels.to(device)
        optimizer.zero_grad()
        emb1, emb2 = model(img1, img2)
        loss, dist = criterion(emb1, emb2, labels)
        loss.backward()
        # Gradient clipping — prevents exploding gradients early in training
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        # Accuracy: predict same-class if distance < threshold
        preds = (dist < threshold).float()
        correct += (preds == labels).sum().item()
        total += labels.size(0)
        total_loss += loss.item()
        loop.set_postfix(loss=f"{loss.item():.4f}")
    # NOTE(review): an empty loader would divide by zero here — assumed
    # non-empty by the caller, matching the original behavior.
    return total_loss / len(loader), correct / total
@torch.no_grad()
def validate(model, loader, criterion, device, epoch, threshold=0.5):
    """Evaluate the model on a validation pair loader (no gradients).

    Args:
        model: network called as ``model(img1, img2) -> (emb1, emb2)``.
        loader: iterable yielding ``(img1, img2, labels)`` batches.
        criterion: callable ``(emb1, emb2, labels) -> (loss, dist)``.
        device: device the batch tensors are moved to.
        epoch: current epoch number (used only for the progress-bar label).
        threshold: pairwise distance below which a pair is predicted
            same-class. Defaults to 0.5, the previously hard-coded value,
            so existing callers are unaffected.

    Returns:
        tuple: (mean per-batch loss, pair-classification accuracy).
    """
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    loop = tqdm(loader, desc=f"Epoch {epoch} [val]  ", leave=False)
    for img1, img2, labels in loop:
        img1, img2, labels = img1.to(device), img2.to(device), labels.to(device)
        emb1, emb2 = model(img1, img2)
        loss, dist = criterion(emb1, emb2, labels)
        # Same decision rule as training: same-class iff distance < threshold
        preds = (dist < threshold).float()
        correct += (preds == labels).sum().item()
        total += labels.size(0)
        total_loss += loss.item()
    # NOTE(review): an empty loader would divide by zero here — assumed
    # non-empty by the caller, matching the original behavior.
    return total_loss / len(loader), correct / total
def save_checkpoint(model, optimizer, epoch, val_loss, path):
    """Serialize model/optimizer state plus metadata to ``path``.

    Creates any missing parent directories. Fixes a crash in the original:
    when ``path`` has no directory component, ``os.path.dirname(path)`` is
    ``""`` and ``os.makedirs("")`` raises ``FileNotFoundError`` even with
    ``exist_ok=True`` — so only create directories when one is present.

    Args:
        model: module whose ``state_dict()`` is saved.
        optimizer: optimizer whose ``state_dict()`` is saved.
        epoch: epoch number stored under the ``"epoch"`` key.
        val_loss: validation loss stored under the ``"val_loss"`` key.
        path: destination file path.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save({
        "epoch"      : epoch,
        "model_state": model.state_dict(),
        "optim_state": optimizer.state_dict(),
        "val_loss"   : val_loss,
    }, path)
    print(f"  Checkpoint saved → {path}")
def load_checkpoint(path, model, optimizer=None, map_location=None):
    """Restore model (and optionally optimizer) state from a checkpoint.

    Args:
        path: checkpoint file written by ``save_checkpoint``.
        model: module receiving ``"model_state"``.
        optimizer: if given, receives ``"optim_state"``.
        map_location: forwarded to ``torch.load`` — pass ``"cpu"`` to load
            a GPU-saved checkpoint on a CPU-only machine. ``None`` keeps
            the original default behavior.

    Returns:
        int: the epoch number stored in the checkpoint.
    """
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    ckpt = torch.load(path, map_location=map_location)
    model.load_state_dict(ckpt["model_state"])
    if optimizer is not None:
        optimizer.load_state_dict(ckpt["optim_state"])
    print(f"  Resumed from epoch {ckpt['epoch']} (val_loss={ckpt['val_loss']:.4f})")
    return ckpt["epoch"]