# core-jepa / src/train.py
# Author: Gajesh Ladhar — "initial src and benchmark added" (commit c71037b)
# 📦 Imports
import torch
from torch.utils.data import DataLoader
from loss import SIGReg, gram_anchor_spatial
from backbone import LeJEPA
from data import DinoDataset
from utils import *
# ---- Runtime setup & hyper-parameters -------------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
lambd = 0.01  # mixing weight: loss = (1 - lambd) * invariance + lambd * SIGReg

# ---- Data pipeline, model, regularizer, optimizer -------------------------
dataset = DinoDataset(imgsz=1024, batch_size=16, queue_size=400)
model = LeJEPA(out_dims=256).to(device)
sigreg = SIGReg(device=device).to(device)
opt = torch.optim.AdamW(params=model.parameters(), lr=1e-5)
# ---- Training loop --------------------------------------------------------
# Each step minimizes a convex combination of the invariance loss (each
# view's projection pulled toward the mean projection over views) and the
# SIGReg regularizer. Scalar copies are accumulated for per-epoch averages
# and a checkpoint is written after every epoch.
num_epochs = 10000
for epoch in range(num_epochs):
    model.train()
    pbar = tqdm.tqdm(dataset.store, desc=f"Epoch {epoch+1}/{num_epochs}")

    # Epoch accumulators (for averaging).
    sigreg_epoch = 0.0
    inv_epoch = 0.0
    lejepa_epoch = 0.0
    steps = 0

    for batch in pbar:
        # 'views' presumably stacks the augmented views of each image along
        # dim 0 — TODO confirm layout against DinoDataset.
        # Fix: move to the configured `device` (was hard-coded "cuda",
        # which crashes on CPU-only hosts).
        views = batch['views'].to(device, non_blocking=True)
        emb, proj = model(views)

        # Losses.
        sigreg_loss = sigreg(proj)
        inv_loss = (proj.transpose(0, 1).mean(0) - proj.transpose(0, 1)).square().mean()
        lejepa_loss = inv_loss * (1 - lambd) + sigreg_loss * lambd

        opt.zero_grad()
        lejepa_loss.backward()
        opt.step()

        # Accumulate scalar copies for the epoch averages.
        sigreg_epoch += sigreg_loss.item()
        inv_epoch += inv_loss.item()
        lejepa_epoch += lejepa_loss.item()
        steps += 1

        # Live per-step readout on the progress bar (.item() already
        # returns a Python float; no extra float() wrapper needed).
        pbar.set_postfix({
            "sigreg": sigreg_loss.item(),
            "inv": inv_loss.item(),
            "lejepa": lejepa_loss.item(),
        })

    # Epoch averages; guard against an empty dataset (ZeroDivisionError).
    # Fix: report epoch+1 to match the progress-bar numbering.
    if steps:
        print(
            f"Epoch {epoch+1} | sigreg: {sigreg_epoch / steps:.5f}"
            f" | inv: {inv_epoch / steps:.5f}"
            f" | lejepa: {lejepa_epoch / steps:.5f}"
        )
    torch.save(model.state_dict(), "lejepa-l.pt")