"""Train a LeJEPA model with an invariance loss plus a SIGReg regularizer.

Every optimization step minimizes
``(1 - lambd) * inv_loss + lambd * sigreg_loss``. After each epoch the mean of
each loss term is printed and the model weights are written to
``lejepa-l.pt``.
"""

# 📦 Imports
import torch
from torch.utils.data import DataLoader

from loss import SIGReg, gram_anchor_spatial
from backbone import LeJEPA
from data import DinoDataset
# NOTE(review): `tqdm` is used below but never imported explicitly;
# presumably it is re-exported by this star import -- confirm.
from utils import *

# hyper-parameters & setup
lambd = 0.01  # weighting for SIGReg loss
device = "cuda" if torch.cuda.is_available() else "cpu"

# dataset, model, loss fn, optimizer
dataset = DinoDataset(imgsz=1024, batch_size=16, queue_size=400)
model = LeJEPA(out_dims=256).to(device)
sigreg = SIGReg(device=device).to(device)
opt = torch.optim.AdamW(params=model.parameters(), lr=1e-5)

# training loop
num_epochs = 10000
for epoch in range(num_epochs):
    model.train()
    pbar = tqdm.tqdm(dataset.store, desc=f"Epoch {epoch+1}/{num_epochs}")

    # epoch accumulators (for averaging)
    sigreg_epoch = 0
    inv_epoch = 0
    lejepa_epoch = 0
    steps = 0

    for batch in pbar:
        batch = batch['views']
        # Move to the same device as the model; a hard-coded "cuda" here
        # would crash on CPU-only hosts where `device` resolved to "cpu".
        batch = batch.to(device, non_blocking=True)

        emb, proj = model(batch)

        # losses
        sigreg_loss = sigreg(proj)
        # Squared deviation of every projection from the mean taken over
        # proj's dim 1 (presumably the views axis -- TODO confirm shape).
        proj_t = proj.transpose(0, 1)
        inv_loss = (proj_t.mean(0) - proj_t).square().mean()
        lejepa_loss = inv_loss * (1 - lambd) + sigreg_loss * lambd
        loss = lejepa_loss

        opt.zero_grad()
        loss.backward()
        opt.step()

        # Read each scalar once: every .item() call forces a host/device
        # sync, so the values are reused for both accumulation and display.
        sigreg_val = sigreg_loss.item()
        inv_val = inv_loss.item()
        lejepa_val = lejepa_loss.item()

        # accumulate
        sigreg_epoch += sigreg_val
        inv_epoch += inv_val
        lejepa_epoch += lejepa_val
        steps += 1

        # update tqdm bar
        pbar.set_postfix({
            "sigreg": sigreg_val,
            "inv": inv_val,
            "lejepa": lejepa_val,
        })

    # Fail with a clear message instead of a bare ZeroDivisionError when the
    # data source produced no batches for this epoch.
    if steps == 0:
        raise RuntimeError(
            f"Epoch {epoch}: dataset.store yielded no batches; nothing to average"
        )

    # epoch averages
    sigreg_avg = sigreg_epoch / steps
    inv_avg = inv_epoch / steps
    lejepa_avg = lejepa_epoch / steps

    print(f"Epoch {epoch} | sigreg: {sigreg_avg:.5f} | inv: {inv_avg:.5f} | lejepa: {lejepa_avg:.5f}")

    # NOTE(review): checkpoint is written after every epoch, overwriting the
    # previous file -- confirm this is the intended save cadence.
    torch.save(model.state_dict(), "lejepa-l.pt")