| """Cofiber Threshold: adjoint cofiber decomposition + per-scale LayerNorm + threshold prediction. |
| |
| Zero-parameter multi-scale decomposition derived from the adjoint pair |
| (bilinear upsample, average pool). The cofiber of the round-trip map |
| isolates per-scale content. Classification via prototype dot products. |
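| Roughly 65,000 parameters in the prediction heads at the defaults (feat_dim=768, 80 classes); each per-scale LayerNorm adds 2 * feat_dim more. |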
| |
| Cross-domain performance (21 domains, 15K steps): avg precision 0.617, avg recall 0.368. |
| """ |
|
|
| import math |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch import Tensor |
|
|
| from losses.fcos import fcos_loss |
| from utils.decode import make_locations, decode_fcos |
|
|
# Default number of object classes; used as the default `num_classes` of
# CofiberThreshold (one prototype row per class).
# NOTE(review): 80 presumably matches the COCO category count — confirm.
NUM_CLASSES = 80
|
|
|
|
def cofiber_decompose(f, n_scales):
    """Split a feature map into per-scale cofibers of the (Sigma, Omega) adjoint pair.

    Omega is a 2x average pool and Sigma is a bilinear upsample back to the
    pre-pool resolution. At every level the cofiber is
    ``x - Sigma(Omega(x))``, i.e. the part of ``x`` that does not survive the
    pool/upsample round trip; the pooled map then becomes the next level's
    input. The coarsest pooled map is appended unchanged as the final entry.

    Args:
        f: feature map laid out as (batch, channels, height, width); the
            bilinear interpolation requires this 4-D layout.
        n_scales: number of levels to produce. Values <= 1 yield ``[f]``.

    Returns:
        A list of tensors ordered fine to coarse. The first has the spatial
        size of ``f``; each following one has half the previous size (rounded
        down, as ``avg_pool2d`` uses floor mode by default).
    """
    bands = []
    current = f
    levels_left = n_scales - 1
    while levels_left > 0:
        pooled = F.avg_pool2d(current, 2)
        # Upsample to the exact pre-pool size so odd heights/widths still line up.
        round_trip = F.interpolate(
            pooled,
            size=current.shape[2:],
            mode="bilinear",
            align_corners=False,
        )
        bands.append(current - round_trip)
        current = pooled
        levels_left -= 1
    # Whatever remains after the last pool is the coarsest band.
    bands.append(current)
    return bands
|
|
|
|
class CofiberThreshold(nn.Module):
    """FCOS-style prediction heads on top of a parameter-free cofiber decomposition.

    The input feature map is split into ``n_scales`` cofibers by
    ``cofiber_decompose`` (no learnable parameters). Each scale is normalized
    by its own ``LayerNorm`` over the channel axis, then three linear heads
    shared across scales produce per-location outputs:

    * class scores: dot product with one prototype per class, plus a bias;
    * box regression: 4 values, scaled by a per-scale learnable scalar,
      clamped to [-10, 10] and exponentiated (so always positive);
    * centerness: a single raw score.

    With the defaults (feat_dim=768, 80 classes, 3 scales) the heads hold
    about 65K parameters; each LayerNorm adds ``2 * feat_dim`` more.
    """

    name = "cofiber_threshold"
    needs_intermediates = False
    # Class-level fallback so instances created before `base_stride` existed
    # (e.g. restored from an older whole-module pickle) keep the old stride.
    base_stride = 16

    def __init__(self, feat_dim=768, num_classes=NUM_CLASSES, n_scales=3, base_stride=16):
        """Build the per-scale norms and the shared linear heads.

        Args:
            feat_dim: channel count of the input feature map.
            num_classes: number of class prototypes.
            n_scales: number of cofiber levels; must be at least 1.
            base_stride: stride (in input pixels) assigned to the finest
                level in ``get_locs``; each coarser level doubles it.

        Raises:
            ValueError: if ``n_scales`` or ``base_stride`` is smaller than 1.
        """
        super().__init__()
        # Fail early: with n_scales < 1 the decomposition still yields one
        # level but there would be no matching LayerNorm / scale parameter.
        if n_scales < 1:
            raise ValueError(f"n_scales must be >= 1, got {n_scales}")
        if base_stride < 1:
            raise ValueError(f"base_stride must be >= 1, got {base_stride}")
        self.n_scales = n_scales
        self.base_stride = base_stride
        self.scale_norms = nn.ModuleList([nn.LayerNorm(feat_dim) for _ in range(n_scales)])
        self.prototypes = nn.Parameter(torch.randn(num_classes, feat_dim) * 0.01)
        self.proto_bias = nn.Parameter(torch.zeros(num_classes))
        self.reg_weight = nn.Parameter(torch.randn(4, feat_dim) * 0.01)
        self.reg_bias = nn.Parameter(torch.zeros(4))
        self.ctr_weight = nn.Parameter(torch.randn(1, feat_dim) * 0.01)
        self.ctr_bias = nn.Parameter(torch.zeros(1))
        # One learnable multiplier per scale for the regression logits.
        self.scale_params = nn.Parameter(torch.ones(n_scales))

    def forward(self, spatial, inter=None):
        """Predict class, box and centerness maps for every cofiber level.

        Args:
            spatial: feature map of shape (B, feat_dim, H, W).
            inter: accepted for interface compatibility; not used here.

        Returns:
            ``(cls_l, reg_l, ctr_l)``: three lists with one tensor per level,
            ordered fine to coarse, shaped (B, num_classes, H_i, W_i),
            (B, 4, H_i, W_i) and (B, 1, H_i, W_i). Class and centerness
            values are raw scores (no sigmoid is applied in this method);
            regression values are positive.
        """
        cofibers = cofiber_decompose(spatial, self.n_scales)
        cls_l, reg_l, ctr_l = [], [], []
        for i, cof in enumerate(cofibers):
            B, C, H, W = cof.shape
            # Channels-last and flattened to (B*H*W, C) so the LayerNorm and
            # the linear heads act independently on every location.
            f = self.scale_norms[i](cof.permute(0, 2, 3, 1).reshape(-1, C))
            cls = (f @ self.prototypes.T + self.proto_bias).reshape(B, H, W, -1).permute(0, 3, 1, 2)
            # Clamp before exp to keep the regression output finite.
            raw = ((f @ self.reg_weight.T + self.reg_bias) * self.scale_params[i]).clamp(-10, 10)
            reg = torch.exp(raw).reshape(B, H, W, 4).permute(0, 3, 1, 2)
            ctr = (f @ self.ctr_weight.T + self.ctr_bias).reshape(B, H, W, 1).permute(0, 3, 1, 2)
            cls_l.append(cls)
            reg_l.append(reg)
            ctr_l.append(ctr)
        return cls_l, reg_l, ctr_l

    def loss(self, preds, locs, boxes_b, labels_b):
        """Forward the unpacked ``forward`` output and targets to ``fcos_loss``."""
        return fcos_loss(*preds, locs, boxes_b, labels_b)

    def decode(self, preds, locs, **kw):
        """Forward the unpacked ``forward`` output to ``decode_fcos``; ``kw`` is passed through."""
        return decode_fcos(*preds, locs, **kw)

    def get_locs(self, spatial):
        """Build the per-level location grids matching ``forward``'s output sizes.

        Args:
            spatial: feature map of shape (B, C, H, W); only its spatial
                size and device are used.

        Returns:
            Whatever ``make_locations`` returns for the per-level (H_i, W_i)
            sizes and strides ``base_stride * 2**i``.
        """
        # Only the shapes are needed, so decompose a single sample and skip
        # autograd bookkeeping.
        with torch.no_grad():
            dummy = cofiber_decompose(spatial[:1], self.n_scales)
        sizes = [(c.shape[2], c.shape[3]) for c in dummy]
        strides = [self.base_stride * (2 ** i) for i in range(self.n_scales)]
        return make_locations(sizes, strides, spatial.device)
|
|