import torch
from torch import nn
import torch.nn.functional as F


class SIGReg(nn.Module):
    """Sketched Isotropic Gaussian Regularization (SIGReg): penalizes
    deviation of random 1-D projections of the embeddings from a standard
    Gaussian, via an Epps-Pulley-style characteristic-function statistic.
    """

    def __init__(self, knots=17):
        super().__init__()
        # Quadrature knots on [0, 3] for the characteristic-function integral.
        t = torch.linspace(0, 3, knots, dtype=torch.float32)
        dt = 3 / (knots - 1)
        # Trapezoidal weights, doubled: the integrand is even in t, so the
        # integral over [-3, 3] is twice the integral over [0, 3].
        weights = torch.full((knots,), 2 * dt, dtype=torch.float32)
        weights[[0, -1]] = dt
        # Characteristic function of N(0, 1), reused as the weighting window.
        window = torch.exp(-t.square() / 2.0)
        self.register_buffer("t", t)
        self.register_buffer("phi", window)
        # Fold the window into the quadrature weights; buffers follow the
        # module across .to(device) calls, so no device argument is needed.
        self.register_buffer("weights", weights * window)

    def forward(self, proj):
        N, V = proj.shape[:2]
        # (N, V, ...) -> (V, N, D): one goodness-of-fit test per view.
        proj = proj.reshape(N, V, -1).transpose(0, 1)
        # 256 random projection directions with unit-norm columns, drawn
        # on proj.device so the module still works after .to("cuda").
        A = torch.randn(proj.size(-1), 256, device=proj.device)
        A = A.div_(A.norm(p=2, dim=0))
        # Arguments t * <a, x> of the empirical characteristic function.
        x_t = (proj @ A).unsqueeze(-1) * self.t
        # Squared CF error against N(0, 1), whose CF has real part phi and
        # zero imaginary part.
        err = (x_t.cos().mean(-3) - self.phi).square() + x_t.sin().mean(-3).square()
        # Epps-Pulley statistic: quadrature over t, scaled by the sample size N.
        statistic = (err @ self.weights) * proj.size(-2)
        return statistic.mean()
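

# Hypothetical smoke test (not part of the original): Gaussian inputs give a
# small baseline value of the statistic, while strongly non-Gaussian inputs
# score higher, so this is a quick check that buffers and shapes are wired
# up correctly. The shapes below are made up.
def _demo_sigreg():
    sigreg = SIGReg(knots=17)
    proj = torch.randn(512, 2, 64)  # assumed layout: (N samples, V views, D dims)
    return sigreg(proj)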


def gram_anchor_spatial(xs, xt, eps=1e-6):
    """
    xs, xt: (N, V, D, H, W)
    Gram anchoring over the H*W spatial tokens: match the student's (xs)
    token-token similarity structure to the teacher's (xt).
    """
    N, V, Ds, H, W = xs.shape
    _, _, Dt, _, _ = xt.shape  # widths may differ; the token grid must match
    T = H * W

    # L2-normalize channels so Gram entries are cosine similarities.
    xs = F.normalize(xs, p=2, dim=2, eps=eps)
    xt = F.normalize(xt, p=2, dim=2, eps=eps)

    # Flatten the spatial grid: (N, V, D, H, W) -> (N, V, D, T).
    xs = xs.reshape(N, V, Ds, T)
    xt = xt.reshape(N, V, Dt, T)

    # Tokens before channels: (N, V, T, D).
    xs = xs.permute(0, 1, 3, 2)
    xt = xt.permute(0, 1, 3, 2)

    # Merge batch and view dims for batched matmul: (N*V, T, D).
    xs = xs.reshape(N * V, T, Ds)
    xt = xt.reshape(N * V, T, Dt)

    # T x T Gram matrices; comparable even when Ds != Dt.
    Gs = torch.matmul(xs, xs.transpose(-1, -2))
    Gt = torch.matmul(xt, xt.transpose(-1, -2))

    # Mean squared difference between student and teacher Gram matrices.
    loss = (Gs - Gt).pow(2).mean()
    return loss
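

# Minimal usage sketch with made-up shapes: N=4 images, V=2 views, a 14x14
# token grid, and different student/teacher widths (the Gram matrices are
# T x T, so Ds and Dt need not match).
if __name__ == "__main__":
    xs = torch.randn(4, 2, 384, 14, 14)  # hypothetical student features
    xt = torch.randn(4, 2, 768, 14, 14)  # hypothetical teacher features
    print("gram anchoring loss:", gram_anchor_spatial(xs, xt).item())
    print("sigreg statistic:", _demo_sigreg().item())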