"""Cofiber Threshold V2: same as V1 but with 2-layer box regression (768->32->4).

Targets the mAP@0.75 collapse in V1 (0.8) caused by single-layer box regression.
Classification and cofiber decomposition are identical to V1.

~92K total params (under NanoDet-m-0.5x head at 94K).
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from losses.fcos import fcos_loss
from utils.decode import make_locations, decode_fcos

NUM_CLASSES = 80


def cofiber_decompose(f, n_scales):
    """Split a feature map into ``n_scales`` band-pass "cofiber" maps.

    For each of the first ``n_scales - 1`` steps, the current residual is 2x
    average-pooled (``omega``), bilinearly upsampled back to the residual's
    spatial size, and ``residual - upsample(omega)`` is kept as that scale's
    cofiber. The pooled map then becomes the next residual. The final residual
    is appended as the coarsest entry.

    Args:
        f: 4-D tensor whose ``shape[2:]`` are the spatial dims (presumably
            (B, C, H, W) -- confirm against callers).
        n_scales: Number of maps to return when >= 1. Values < 1 behave like 1
            (the input is returned on its own).

    Returns:
        List of tensors, finest first. Entry ``i`` has spatial size
        ``H // 2**i`` x ``W // 2**i``, because avg_pool2d floors odd sizes.
    """
    cofibers = []
    residual = f
    for _ in range(n_scales - 1):
        omega = F.avg_pool2d(residual, 2)
        sigma_omega = F.interpolate(
            omega, size=residual.shape[2:], mode="bilinear", align_corners=False
        )
        cofibers.append(residual - sigma_omega)
        residual = omega
    cofibers.append(residual)
    return cofibers


class CofiberThresholdV2(nn.Module):
    """Cofiber decomposition + LayerNorm + prototype cls + 2-layer box regression. ~92K params.

    For every cofiber scale, each spatial location is LayerNorm-ed, then fed to
    three heads:
      * classification: dot product with ``num_classes`` learned prototypes + bias;
      * box regression: Linear(feat_dim, reg_hidden) -> GELU -> Linear(reg_hidden, 4),
        scaled per-scale, clamped to [-10, 10] and exponentiated (strictly positive);
      * centerness: a single linear projection (no activation applied here).
    """

    name = "cofiber_threshold_v2"
    needs_intermediates = False
    # Stride of the finest cofiber scale; scale i uses base_stride * 2**i.
    # Presumably the backbone's output stride -- override on a subclass or
    # instance if the backbone differs.
    base_stride = 16

    def __init__(self, feat_dim=768, num_classes=NUM_CLASSES, n_scales=3, reg_hidden=32):
        """Build the head.

        Args:
            feat_dim: Channel count of the incoming feature map.
            num_classes: Number of classification prototypes.
            n_scales: Number of cofiber scales (must be >= 1).
            reg_hidden: Hidden width of the 2-layer box-regression MLP.

        Raises:
            ValueError: If ``n_scales < 1`` (forward would otherwise index an
                empty ``scale_norms``).
        """
        super().__init__()
        if n_scales < 1:
            raise ValueError(f"n_scales must be >= 1, got {n_scales}")
        self.n_scales = n_scales
        self.scale_norms = nn.ModuleList([nn.LayerNorm(feat_dim) for _ in range(n_scales)])

        # Classification: same as V1
        self.prototypes = nn.Parameter(torch.randn(num_classes, feat_dim) * 0.01)
        self.proto_bias = nn.Parameter(torch.zeros(num_classes))

        # Box regression: 2-layer with hidden dim.
        # NOTE: the attribute ``reg_hidden`` holds the Linear layer while the
        # constructor argument of the same name is its width; the attribute name
        # is kept so existing state_dict keys ("reg_hidden.weight", ...) still load.
        self.reg_hidden = nn.Linear(feat_dim, reg_hidden)
        self.reg_act = nn.GELU()
        self.reg_out = nn.Linear(reg_hidden, 4)

        # Centerness: same as V1
        self.ctr_weight = nn.Parameter(torch.randn(1, feat_dim) * 0.01)
        self.ctr_bias = nn.Parameter(torch.zeros(1))

        # One learnable multiplier per scale on the regression logits.
        self.scale_params = nn.Parameter(torch.ones(n_scales))

    def forward(self, spatial, inter=None):
        """Run all heads on every cofiber scale of ``spatial``.

        Args:
            spatial: Feature map passed to ``cofiber_decompose`` (4-D, channels
                in dim 1 as indexed below).
            inter: Unused here; presumably kept for a shared head interface
                (``needs_intermediates`` is False).

        Returns:
            ``(cls_l, reg_l, ctr_l)``: three lists with one entry per scale, of
            shapes (B, num_classes, H_i, W_i), (B, 4, H_i, W_i) and (B, 1, H_i, W_i).
        """
        cofibers = cofiber_decompose(spatial, self.n_scales)
        cls_l, reg_l, ctr_l = [], [], []
        for i, cof in enumerate(cofibers):
            B, C, H, W = cof.shape
            # Flatten to (B*H*W, C) so LayerNorm and the heads act per location.
            f = self.scale_norms[i](cof.permute(0, 2, 3, 1).reshape(-1, C))
            cls = (f @ self.prototypes.T + self.proto_bias).reshape(B, H, W, -1).permute(0, 3, 1, 2)
            # Clamp before exp so the positive regression outputs stay within
            # [e^-10, e^10] regardless of the learned per-scale multiplier.
            reg_raw = (
                self.reg_out(self.reg_act(self.reg_hidden(f))) * self.scale_params[i]
            ).clamp(-10, 10)
            reg = torch.exp(reg_raw).reshape(B, H, W, 4).permute(0, 3, 1, 2)
            ctr = (f @ self.ctr_weight.T + self.ctr_bias).reshape(B, H, W, 1).permute(0, 3, 1, 2)
            cls_l.append(cls)
            reg_l.append(reg)
            ctr_l.append(ctr)
        return cls_l, reg_l, ctr_l

    def loss(self, preds, locs, boxes_b, labels_b):
        """Delegate to ``fcos_loss`` with the unpacked ``(cls_l, reg_l, ctr_l)`` predictions."""
        return fcos_loss(*preds, locs, boxes_b, labels_b)

    def decode(self, preds, locs, **kw):
        """Delegate to ``decode_fcos`` with the unpacked predictions; ``kw`` is forwarded."""
        return decode_fcos(*preds, locs, **kw)

    def get_locs(self, spatial):
        """Return per-scale locations matching the grids produced by ``forward``.

        The decomposition is run on the first batch element only, because grid
        sizes do not depend on the batch. It runs under ``torch.no_grad`` because
        only the output shapes are needed, so no autograd graph is built.
        """
        with torch.no_grad():
            dummy = cofiber_decompose(spatial[:1], self.n_scales)
        sizes = [(c.shape[2], c.shape[3]) for c in dummy]
        strides = [self.base_stride * (2 ** i) for i in range(self.n_scales)]
        return make_locations(sizes, strides, spatial.device)