import os

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR

import wandb
from tqdm import tqdm
|
|
def train_one_epoch(model, loader, criterion, optimizer, device, epoch):
    model.train()
    total_loss, correct, total = 0.0, 0, 0

    loop = tqdm(loader, desc=f"Epoch {epoch} [train]", leave=False)
    for img1, img2, labels in loop:
        img1, img2, labels = img1.to(device), img2.to(device), labels.to(device)

        optimizer.zero_grad()
        emb1, emb2 = model(img1, img2)
        loss, dist = criterion(emb1, emb2, labels)
        loss.backward()

        # Clip gradients to stabilise training
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        # A pair is predicted positive (label 1) when the embedding distance is below 0.5
        preds = (dist < 0.5).float()
        correct += (preds == labels).sum().item()
        total += labels.size(0)
        total_loss += loss.item()

        loop.set_postfix(loss=f"{loss.item():.4f}")

    return total_loss / len(loader), correct / total
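

# Note: `criterion` above is assumed to return both the loss and the per-pair
# embedding distance, i.e. `loss, dist = criterion(emb1, emb2, labels)`. The actual
# criterion is not defined in this file; the class below is only a minimal sketch of
# a contrastive loss with that interface (label 1 = similar pair, 0 = dissimilar),
# not necessarily the implementation used here.
class ContrastiveLoss(nn.Module):
    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, emb1, emb2, labels):
        # Euclidean distance between the two embeddings of each pair
        dist = nn.functional.pairwise_distance(emb1, emb2)
        # Pull similar pairs together; push dissimilar pairs at least `margin` apart
        loss = (labels * dist.pow(2)
                + (1 - labels) * torch.clamp(self.margin - dist, min=0.0).pow(2)).mean()
        return loss, dist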
|
|
|
|
@torch.no_grad()
def validate(model, loader, criterion, device, epoch):
    model.eval()
    total_loss, correct, total = 0.0, 0, 0

    loop = tqdm(loader, desc=f"Epoch {epoch} [val]  ", leave=False)
    for img1, img2, labels in loop:
        img1, img2, labels = img1.to(device), img2.to(device), labels.to(device)

        emb1, emb2 = model(img1, img2)
        loss, dist = criterion(emb1, emb2, labels)

        # Same 0.5 distance threshold as in training
        preds = (dist < 0.5).float()
        correct += (preds == labels).sum().item()
        total += labels.size(0)
        total_loss += loss.item()

    return total_loss / len(loader), correct / total
|
|
|
|
def save_checkpoint(model, optimizer, epoch, val_loss, path):
    # Guard against a bare filename (dirname would be ""), then save a full snapshot
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    torch.save({
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optim_state": optimizer.state_dict(),
        "val_loss": val_loss,
    }, path)
    print(f" Checkpoint saved → {path}")
|
|
|
|
def load_checkpoint(path, model, optimizer=None):
    # map_location lets a checkpoint saved on GPU be restored on a CPU-only machine
    ckpt = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt["model_state"])
    if optimizer is not None:
        optimizer.load_state_dict(ckpt["optim_state"])
    print(f" Resumed from epoch {ckpt['epoch']} (val_loss={ckpt['val_loss']:.4f})")
    return ckpt["epoch"]
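

# None of the helpers above are wired together in this file, although Adam,
# CosineAnnealingLR and wandb are imported for that purpose. The function below is a
# hypothetical driver-loop sketch showing one way to combine them; the wandb project
# name, hyperparameters and checkpoint path are placeholders, not values from the
# original project.
def fit(model, criterion, train_loader, val_loader, device, epochs=20, lr=3e-4,
        ckpt_path="checkpoints/best.pt"):
    optimizer = Adam(model.parameters(), lr=lr)
    scheduler = CosineAnnealingLR(optimizer, T_max=epochs)
    wandb.init(project="siamese-pairs")  # placeholder project name

    best_val = float("inf")
    for epoch in range(1, epochs + 1):
        train_loss, train_acc = train_one_epoch(model, train_loader, criterion,
                                                optimizer, device, epoch)
        val_loss, val_acc = validate(model, val_loader, criterion, device, epoch)
        scheduler.step()

        wandb.log({"train/loss": train_loss, "train/acc": train_acc,
                   "val/loss": val_loss, "val/acc": val_acc,
                   "lr": scheduler.get_last_lr()[0], "epoch": epoch})

        # Keep only the checkpoint with the lowest validation loss
        if val_loss < best_val:
            best_val = val_loss
            save_checkpoint(model, optimizer, epoch, val_loss, ckpt_path)

    return best_val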