# core-jepa / src/loss.py
# Author: Gajesh Ladhar
# History: initial src and benchmark added (commit c71037b)
import torch
from torch import nn
import torch.nn.functional as F
class SIGReg(torch.nn.Module):
    """Sliced characteristic-function regularizer.

    Projects embeddings onto random unit directions and compares the
    empirical characteristic function of each 1-D projection against
    that of a standard Gaussian (``exp(-t^2 / 2)``), integrating the
    squared deviation over ``t in [0, 3]`` with quadrature weights.
    """

    def __init__(self, knots=17, device='cpu', num_slices=256):
        """
        Args:
            knots: number of quadrature points on [0, 3].
            device: retained for backward compatibility; the random
                projection matrix is now created on the input's device
                instead (see :meth:`forward`).
            num_slices: number of random 1-D projection directions
                (previously hard-coded to 256).
        """
        super().__init__()
        self.device = device  # kept only so existing callers keep working
        self.num_slices = num_slices
        t = torch.linspace(0, 3, knots, dtype=torch.float32)
        dt = 3 / (knots - 1)
        # Quadrature weights: interior knots count double, endpoints single.
        weights = torch.full((knots,), 2 * dt, dtype=torch.float32)
        weights[[0, -1]] = dt
        # Characteristic function of N(0, 1): E[cos(tX)] = exp(-t^2 / 2).
        window = torch.exp(-t.square() / 2.0)
        self.register_buffer("t", t)
        self.register_buffer("phi", window)
        self.register_buffer("weights", weights * window)

    def forward(self, proj):
        """Compute the regularization statistic.

        Args:
            proj: tensor of shape (N, V, ...) — N samples, V views;
                trailing dimensions are flattened into a feature vector.

        Returns:
            Scalar tensor: the statistic averaged over views and slices.
        """
        N, V = proj.shape[:2]
        proj = proj.reshape(N, V, -1).transpose(0, 1)  # (V, N, D)
        # Random unit-norm projection directions, drawn fresh each call.
        # Created on the INPUT's device so the module keeps working after
        # .to(device) — previously this was pinned to the construction-time
        # `device` argument and broke when the module/input moved.
        A = torch.randn(proj.size(-1), self.num_slices, device=proj.device)
        A = A.div_(A.norm(p=2, dim=0))
        # (V, N, num_slices, knots): each projected sample scaled by each t.
        x_t = (proj @ A).unsqueeze(-1) * self.t
        # Squared deviation of the empirical characteristic function
        # (cos = real part, sin = imaginary part) from the Gaussian target.
        err = (x_t.cos().mean(-3) - self.phi).square() + x_t.sin().mean(-3).square()
        # Quadrature over t; the original `* N * (num_slices / N)` scaling
        # simplifies to `* num_slices`.
        statistic = (err @ self.weights) * self.num_slices
        return statistic.mean()
def gram_anchor_spatial(xs, xt, eps=1e-6):
    """Gram-anchoring loss over spatial tokens.

    Token features are L2-normalized along the channel dimension, so each
    (T x T) Gram matrix holds cosine similarities between spatial tokens
    (T = H * W); the loss is the mean squared difference between the two
    Gram matrices.

    Args:
        xs: features of shape (N, V, Ds, H, W).
        xt: features of shape (N, V, Dt, H, W); spatial dims must match xs.
        eps: numerical floor for the L2 normalization. NOTE: previously
            accepted but silently ignored — it is now forwarded to
            ``F.normalize``.

    Returns:
        Scalar tensor: mean squared Frobenius difference of the Grams.
    """
    N, V, Ds, H, W = xs.shape
    Dt = xt.shape[2]
    T = H * W
    # Normalize each token's feature vector (dim=2 is the channel axis);
    # eps guards against zero vectors.
    xs = F.normalize(xs, p=2, dim=2, eps=eps)
    xt = F.normalize(xt, p=2, dim=2, eps=eps)
    # (N, V, D, H, W) -> (N*V, T, D): flatten space, move tokens second,
    # merge batch and view dims.
    xs = xs.reshape(N, V, Ds, T).permute(0, 1, 3, 2).reshape(N * V, T, Ds)
    xt = xt.reshape(N, V, Dt, T).permute(0, 1, 3, 2).reshape(N * V, T, Dt)
    # Token-by-token cosine-similarity Gram matrices, (N*V, T, T).
    Gs = torch.matmul(xs, xs.transpose(-1, -2))
    Gt = torch.matmul(xt, xt.transpose(-1, -2))
    # Elementwise MSE between the two Gram matrices.
    return (Gs - Gt).pow(2).mean()