import torch
from torch import nn
import torch.nn.functional as F


class SIGReg(torch.nn.Module):
    """Sketched characteristic-function regularizer.

    Projects embeddings onto random unit directions and penalizes the
    squared distance between the empirical characteristic function of
    the projections and the standard-normal target phi(t) = exp(-t^2/2),
    integrated over t in [0, 3] by the trapezoid rule.
    """

    def __init__(self, knots=17, device='cpu', proj_dim=256):
        """
        Args:
            knots: number of quadrature knots on [0, 3].
            device: kept for backward compatibility; the random
                projection is now drawn on the input's device in
                ``forward`` (see bug-fix note there).
            proj_dim: number of random projection directions
                (generalizes the previously hard-coded 256).
        """
        super().__init__()
        self.device = device
        self.proj_dim = proj_dim
        t = torch.linspace(0, 3, knots, dtype=torch.float32)
        # Trapezoidal-rule weights: dt at the two endpoints, 2*dt inside.
        dt = 3 / (knots - 1)
        weights = torch.full((knots,), 2 * dt, dtype=torch.float32)
        weights[[0, -1]] = dt
        # Characteristic function of N(0, 1) evaluated at the knots.
        window = torch.exp(-t.square() / 2.0)
        self.register_buffer("t", t)
        self.register_buffer("phi", window)
        # Fold the Gaussian window into the quadrature weights once.
        self.register_buffer("weights", weights * window)

    def forward(self, proj):
        """Compute the mean sketched-CF statistic.

        Args:
            proj: tensor of shape (N, V, ...) — N samples, V views,
                trailing dims flattened into the feature dimension.

        Returns:
            Scalar tensor (non-negative): the statistic averaged over
            views and projection directions.
        """
        N, V = proj.shape[:2]
        proj = proj.reshape(N, V, -1).transpose(0, 1)  # (V, N, D)
        # BUG FIX: draw the projection on the *input's* device rather
        # than the device captured at construction time; otherwise the
        # matmul fails after `module.to(...)` moves the module.
        A = torch.randn(proj.size(-1), self.proj_dim, device=proj.device)
        A = A.div_(A.norm(p=2, dim=0))  # unit-norm columns
        # x_t: (V, N, proj_dim, knots) — each projection scaled by each knot.
        x_t = (proj @ A).unsqueeze(-1) * self.t
        # Squared gap between the empirical CF (mean over the N samples,
        # dim -3) and the target CF, per view/direction/knot. The target
        # is real, so the imaginary (sin) part is penalized toward zero.
        err = (x_t.cos().mean(-3) - self.phi).square() + x_t.sin().mean(-3).square()
        # Integrate over t. The original `* N * (256 / N)` scaling
        # simplifies exactly to `* proj_dim` (and avoids float noise).
        statistic = (err @ self.weights) * self.proj_dim
        return statistic.mean()


def gram_anchor_spatial(xs, xt, eps=1e-6):
    """Gram anchoring loss over spatial tokens.

    Args:
        xs: student features, shape (N, V, Ds, H, W).
        xt: teacher features, shape (N, V, Dt, H, W); the feature dim
            may differ from ``xs`` but N, V, H, W must match.
        eps: numerical floor for the per-token L2 normalization.
            BUG FIX: previously accepted but ignored; now forwarded to
            ``F.normalize`` so near-zero tokens are stabilized as the
            signature advertised.

    Returns:
        Scalar tensor: mean squared Frobenius distance between the
        (T x T) token cosine-similarity Gram matrices, T = H * W.
    """
    N, V, Ds, H, W = xs.shape
    Nt, Vt, Dt, Ht, Wt = xt.shape  # renamed: no longer shadows N/V/H/W
    T = H * W

    # 1. Unit-normalize each token's feature vector (dim=2 is D), so the
    #    Gram entries below are cosine similarities.
    xs = F.normalize(xs, p=2, dim=2, eps=eps)
    xt = F.normalize(xt, p=2, dim=2, eps=eps)

    # 2-4. (N, V, D, H, W) -> flatten spatial -> token-major -> (N*V, T, D)
    xs = xs.reshape(N, V, Ds, T).permute(0, 1, 3, 2).reshape(N * V, T, Ds)
    xt = xt.reshape(N, V, Dt, T).permute(0, 1, 3, 2).reshape(N * V, T, Dt)

    # 5. Token-token Gram matrices: (N*V, T, T).
    Gs = torch.matmul(xs, xs.transpose(-1, -2))
    Gt = torch.matmul(xt, xt.transpose(-1, -2))

    # 6. Mean squared Frobenius distance across all entries and maps.
    return (Gs - Gt).pow(2).mean()