|
|
|
|
|
import torch |
|
|
from torch.utils.data import DataLoader |
|
|
from loss import SIGReg, gram_anchor_spatial |
|
|
from backbone import LeJEPA |
|
|
from data import DinoDataset |
|
|
from utils import * |
|
|
|
|
|
|
|
|
# ----- configuration & experiment setup -----

# Mixing weight between the invariance term and the SIGReg regularizer
# (loss = (1 - lambd) * invariance + lambd * sigreg).
lambd = 0.01

# Prefer the GPU whenever one is available.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Streaming dataset of augmented views (project-local class);
# iteration happens over `dataset.store` in the training loop.
dataset = DinoDataset(imgsz=1024, batch_size=16, queue_size=400)

# Backbone + projector; forward returns (embedding, projection).
model = LeJEPA(out_dims=256).to(device)

# Sketched isotropic-Gaussian regularization loss (project-local module).
sigreg = SIGReg(device=device).to(device)

# AdamW over all trainable parameters.
opt = torch.optim.AdamW(params=model.parameters(), lr=1e-5)
|
|
|
|
|
|
|
|
|
|
|
# ----- training loop -----
num_epochs = 10000

for epoch in range(num_epochs):
    model.train()

    # NOTE(review): tqdm is not imported by name in this file — it presumably
    # comes in through `from utils import *`; confirm, or add `import tqdm`.
    pbar = tqdm.tqdm(dataset.store, desc=f"Epoch {epoch+1}/{num_epochs}")

    # Running sums for per-epoch loss averages.
    sigreg_epoch = 0.0
    inv_epoch = 0.0
    lejepa_epoch = 0.0
    steps = 0

    for batch in pbar:
        # Each item is a dict carrying a stack of augmented views.
        views = batch['views']
        # BUG FIX: was hard-coded `to("cuda", ...)`, which crashes on
        # CPU-only machines even though `device` is selected above.
        views = views.to(device, non_blocking=True)

        emb, proj = model(views)

        # Regularizer: pushes projections toward an isotropic Gaussian.
        sigreg_loss = sigreg(proj)

        # Invariance term: mean squared deviation of each view's projection
        # from the per-sample mean over views. Assumes `proj` is laid out as
        # (views, batch, dim) so the transpose puts the view axis where
        # mean(0) reduces over it — TODO confirm against LeJEPA.forward.
        inv_loss = (proj.transpose(0, 1).mean(0) - proj.transpose(0, 1)).square().mean()

        # Convex combination of the two objectives (lambd set at top of file).
        lejepa_loss = inv_loss * (1 - lambd) + sigreg_loss * lambd
        loss = lejepa_loss

        # set_to_none=True releases grad tensors instead of zero-filling
        # them — same optimization semantics, less memory traffic.
        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()

        sigreg_epoch += sigreg_loss.item()
        inv_epoch += inv_loss.item()
        lejepa_epoch += lejepa_loss.item()
        steps += 1

        pbar.set_postfix({
            "sigreg": float(sigreg_loss.item()),
            "inv": float(inv_loss.item()),
            "lejepa": float(lejepa_loss.item()),
        })

    # ROBUSTNESS: avoid ZeroDivisionError if the dataset yielded no batches.
    if steps == 0:
        print(f"Epoch {epoch+1} | no batches seen - skipping stats")
        continue

    sigreg_avg = sigreg_epoch / steps
    inv_avg = inv_epoch / steps
    lejepa_avg = lejepa_epoch / steps

    # CONSISTENCY FIX: report 1-based epoch numbers to match the
    # progress-bar description above (previously printed the 0-based index).
    print(f"Epoch {epoch+1} | sigreg: {sigreg_avg:.5f} | inv: {inv_avg:.5f} | lejepa: {lejepa_avg:.5f}")

    # Checkpoint after every epoch (overwrites the same file each time).
    torch.save(model.state_dict(), "lejepa-l.pt")