# Author: phanerozoic — last change "update repository" (commit dbbceb8)
"""Cofiber Threshold: adjoint cofiber decomposition + per-scale LayerNorm + threshold prediction.
Zero-parameter multi-scale decomposition derived from the adjoint pair
(bilinear upsample, average pool). The cofiber of the round-trip map
isolates per-scale content. Classification via prototype dot products.
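~65,000 head parameters, plus 2 * feat_dim LayerNorm parameters per scale (~70,000 total at the defaults: feat_dim=768, 80 classes, 3 scales).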
Cross-domain performance (21 domains, 15K steps): avg precision 0.617, avg recall 0.368.
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from losses.fcos import fcos_loss
from utils.decode import make_locations, decode_fcos
# Default number of classes for CofiberThreshold (one learned prototype per class).
# NOTE(review): 80 presumably corresponds to the COCO detection categories — confirm.
NUM_CLASSES = 80
def cofiber_decompose(f, n_scales):
    """Split a feature map into per-scale cofibers using the adjoint pair (Sigma, Omega).

    Omega is 2x average pooling and Sigma is bilinear upsampling back to the
    finer grid. At each of the first ``n_scales - 1`` levels the cofiber is
    ``level - Sigma(Omega(level))``: the content at this resolution that the
    next-coarser level cannot express. The final entry is the remaining
    coarse map itself, not a difference.

    Args:
        f: 4-D tensor (batch, channels, height, width).
        n_scales: number of outputs wanted. Values <= 1 yield ``[f]``.

    Returns:
        List of tensors, finest first. Level k has the spatial size of ``f``
        halved (floor) k times.
    """
    out = []
    level = f
    remaining = n_scales - 1
    while remaining > 0:
        coarse = F.avg_pool2d(level, kernel_size=2)
        # Lift the coarse map back to this level's grid; the difference is the cofiber.
        lifted = F.interpolate(coarse, size=level.shape[2:], mode="bilinear", align_corners=False)
        out.append(level - lifted)
        level = coarse
        remaining -= 1
    out.append(level)
    return out
class CofiberThreshold(nn.Module):
    """Adjoint cofiber decomposition (0 params) + per-scale LayerNorm + shared linear heads.

    All scales share the class prototypes and the regression/centerness weights.
    Only the LayerNorm and a scalar regression multiplier are per-scale.

    Parameter count:
      - classification: ``num_classes * feat_dim + num_classes``
      - regression + centerness: ``5 * feat_dim + 5``
      - per-scale regression multipliers: ``n_scales``
      - LayerNorm affine parameters: ``2 * feat_dim * n_scales``

    At the defaults that is ~65K head parameters plus ~4.6K LayerNorm
    parameters (~70K total).
    """
    name = "cofiber_threshold"
    needs_intermediates = False  # forward() ignores its ``inter`` argument

    def __init__(self, feat_dim=768, num_classes=NUM_CLASSES, n_scales=3, base_stride=16):
        """
        Args:
            feat_dim: channel count of the input feature map. It must equal the
                input's channel dimension, because LayerNorm(feat_dim) is
                applied per location.
            num_classes: number of class prototypes, i.e. classification outputs.
            n_scales: number of decomposition levels; must be >= 1.
            base_stride: stride passed to ``make_locations`` for the finest
                level. Level i uses ``base_stride * 2**i``. The default 16 keeps
                the previously hard-coded value (presumably the backbone patch
                size — confirm for other backbones).

        Raises:
            ValueError: if ``n_scales < 1``. Previously this failed later with
                an IndexError in ``forward``.
        """
        super().__init__()
        if n_scales < 1:
            raise ValueError(f"n_scales must be >= 1, got {n_scales}")
        self.n_scales = n_scales
        self.base_stride = base_stride
        self.scale_norms = nn.ModuleList([nn.LayerNorm(feat_dim) for _ in range(n_scales)])
        # Classification: dot product with one prototype per class, plus bias.
        self.prototypes = nn.Parameter(torch.randn(num_classes, feat_dim) * 0.01)
        self.proto_bias = nn.Parameter(torch.zeros(num_classes))
        # Box regression (4 outputs) and centerness (1 output) linear heads.
        self.reg_weight = nn.Parameter(torch.randn(4, feat_dim) * 0.01)
        self.reg_bias = nn.Parameter(torch.zeros(4))
        self.ctr_weight = nn.Parameter(torch.randn(1, feat_dim) * 0.01)
        self.ctr_bias = nn.Parameter(torch.zeros(1))
        # One learned multiplier per scale, applied to the regression pre-activation.
        self.scale_params = nn.Parameter(torch.ones(n_scales))

    def forward(self, spatial, inter=None):
        """Predict per-scale classification, regression and centerness maps.

        Args:
            spatial: 4-D feature map (B, C, H, W) with C == feat_dim.
            inter: unused (``needs_intermediates`` is False). Kept for interface
                compatibility.

        Returns:
            Tuple ``(cls_l, reg_l, ctr_l)``. Each is a list with one tensor per
            scale, finest first:
              - cls: (B, num_classes, H_i, W_i) raw prototype scores
              - reg: (B, 4, H_i, W_i) strictly positive (exp of a value clamped
                to [-10, 10])
              - ctr: (B, 1, H_i, W_i) raw scores
        """
        cofibers = cofiber_decompose(spatial, self.n_scales)
        cls_l, reg_l, ctr_l = [], [], []
        for i, cof in enumerate(cofibers):
            B, C, H, W = cof.shape
            # Flatten to (B*H*W, C) so LayerNorm and the heads act per location.
            f = self.scale_norms[i](cof.permute(0, 2, 3, 1).reshape(-1, C))
            cls = (f @ self.prototypes.T + self.proto_bias).reshape(B, H, W, -1).permute(0, 3, 1, 2)
            # Clamp before exp so the positive box distances cannot overflow.
            raw = ((f @ self.reg_weight.T + self.reg_bias) * self.scale_params[i]).clamp(-10, 10)
            reg = torch.exp(raw).reshape(B, H, W, 4).permute(0, 3, 1, 2)
            ctr = (f @ self.ctr_weight.T + self.ctr_bias).reshape(B, H, W, 1).permute(0, 3, 1, 2)
            cls_l.append(cls)
            reg_l.append(reg)
            ctr_l.append(ctr)
        return cls_l, reg_l, ctr_l

    def loss(self, preds, locs, boxes_b, labels_b):
        """Delegate to ``fcos_loss``.

        ``preds`` is unpacked positionally and is presumably the
        ``(cls_l, reg_l, ctr_l)`` tuple returned by ``forward``.
        """
        return fcos_loss(*preds, locs, boxes_b, labels_b)

    def decode(self, preds, locs, **kw):
        """Delegate to ``decode_fcos``, passing ``preds`` unpacked and any extra keyword options."""
        return decode_fcos(*preds, locs, **kw)

    def get_locs(self, spatial):
        """Build per-scale location grids whose sizes match ``forward``'s outputs.

        Runs the parameter-free decomposition on the first sample only, under
        ``torch.no_grad()``. This uses exactly the same pooling as ``forward``
        (so sizes always agree) without recording an autograd graph that is
        only needed for its shapes.
        """
        with torch.no_grad():
            levels = cofiber_decompose(spatial[:1], self.n_scales)
        sizes = [(lvl.shape[2], lvl.shape[3]) for lvl in levels]
        # getattr keeps whole-module pickles saved before base_stride existed working.
        base = getattr(self, "base_stride", 16)
        strides = [base * (2 ** i) for i in range(self.n_scales)]
        return make_locations(sizes, strides, spatial.device)