phanerozoic's picture
Restructure: one folder per head, shared losses/utils, registry runner
ca63835 verified
"""Auto-extracted from detection_arena.py."""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import List
from losses.fcos import fcos_loss, focal_loss, NUM_CLASSES
from losses.centernet import centernet_targets, centernet_loss
from utils.decode import make_locations, decode_fcos, decode_centernet, FPN_STRIDES
# NOTE(review): not referenced by any code visible in this module; presumably the
# number of leading (non-spatial) tokens that other modules skip — TODO confirm
# against callers before relying on or removing it.
N_PREFIX = 5
class ThresholdPrototype(nn.Module):
    """Detection via threshold gates. Integer weights, Heaviside activation. ~65K params.

    The head is three per-location affine maps over the channel dimension of a
    single feature map: class prototypes (``num_classes`` outputs), box
    regression (4 outputs, exponentiated) and centerness (1 output).  The
    forward pass itself uses ordinary float arithmetic; integer weights only
    appear after :meth:`quantize_to_circuit` has been called.

    Class attributes:
        name: Identifier string for this head (``"R_threshold"``).
        needs_intermediates: ``False``; ``forward`` ignores its ``inter`` argument.
    """

    name = "R_threshold"
    needs_intermediates = False

    # Quantized weights are rounded onto the integer range [-127, 127].
    _QUANT_LEVELS = 127
    # Floor applied to every quantization scale so an all-zero tensor divides
    # by a tiny positive number instead of zero (which would yield NaN).
    _QUANT_EPS = 1e-6

    def __init__(self, feat_dim=768, num_classes=80):
        """Create the head's parameters.

        Args:
            feat_dim: Channel count of the incoming feature map.
            num_classes: Number of classification outputs.
        """
        super().__init__()
        # Creation order is kept as-is so a fixed RNG seed reproduces the same init.
        self.prototypes = nn.Parameter(torch.randn(num_classes, feat_dim) * 0.01)
        self.proto_bias = nn.Parameter(torch.zeros(num_classes))
        self.reg_weight = nn.Parameter(torch.randn(4, feat_dim) * 0.01)
        self.reg_bias = nn.Parameter(torch.zeros(4))
        self.ctr_weight = nn.Parameter(torch.randn(1, feat_dim) * 0.01)
        self.ctr_bias = nn.Parameter(torch.zeros(1))
        # Learnable multiplier applied to the raw regression output before exp().
        self.scale = nn.Parameter(torch.ones(1))

    def forward(self, spatial: Tensor, inter=None):
        """Compute class logits, box regression and centerness for every location.

        Args:
            spatial: Feature map of shape ``(B, C, H, W)``.
            inter: Unused; accepted so the signature matches other heads.

        Returns:
            A 3-tuple ``([cls_logits], [reg], [ctr])`` of single-element lists
            holding tensors of shape ``(B, num_classes, H, W)``,
            ``(B, 4, H, W)`` and ``(B, 1, H, W)`` respectively.  ``reg`` is
            strictly positive because it is an exponential.
        """
        B, C, H, W = spatial.shape
        # Flatten to one row per spatial location: (B*H*W, C).
        feats = spatial.permute(0, 2, 3, 1).reshape(-1, C)

        cls_flat = feats @ self.prototypes.T + self.proto_bias

        reg_flat = (feats @ self.reg_weight.T + self.reg_bias) * self.scale
        # Clamp before exp() so the result stays within [e^-10, e^10].
        reg_flat = torch.exp(reg_flat.clamp(-10, 10))

        ctr_flat = feats @ self.ctr_weight.T + self.ctr_bias

        def to_map(flat: Tensor, channels: int) -> Tensor:
            # (B*H*W, channels) -> (B, channels, H, W)
            return flat.reshape(B, H, W, channels).permute(0, 3, 1, 2)

        cls_logits = to_map(cls_flat, -1)
        reg = to_map(reg_flat, 4)
        ctr = to_map(ctr_flat, 1)
        return [cls_logits], [reg], [ctr]

    def loss(self, preds, locs, boxes_b, labels_b):
        """Forward ``preds`` (unpacked) and the targets to ``fcos_loss``."""
        return fcos_loss(*preds, locs, boxes_b, labels_b)

    def decode(self, preds, locs, **kw):
        """Forward ``preds`` (unpacked), ``locs`` and keyword options to ``decode_fcos``."""
        return decode_fcos(*preds, locs, **kw)

    def get_locs(self, spatial):
        """Build locations for ``spatial``'s ``(H, W)`` grid using a fixed stride of 16."""
        h, w = spatial.shape[2], spatial.shape[3]
        return make_locations([(h, w)], [16], spatial.device)

    @classmethod
    def _quant_scale(cls, tensor: Tensor) -> Tensor:
        """Return ``max(|tensor|)`` floored at ``_QUANT_EPS`` as a 0-dim tensor."""
        return tensor.abs().max().clamp_min(cls._QUANT_EPS)

    def quantize_to_circuit(self):
        """Quantize weights to integers for threshold circuit serialization.

        Rewrites ``prototypes``, ``proto_bias``, ``reg_weight`` and ``reg_bias``
        in place with integer-valued floats.  ``prototypes`` and ``proto_bias``
        share one scale (the largest absolute prototype weight), so the sign of
        each class logit is preserved up to rounding.  Every scale is floored at
        ``_QUANT_EPS`` so an all-zero tensor quantizes to zeros instead of NaN.

        NOTE(review): ``reg_bias`` is scaled independently of ``reg_weight``, so
        the quantized regression output is not a uniform rescale of the float
        one.  ``ctr_weight``, ``ctr_bias`` and ``scale`` are not touched by the
        code visible here.  Both look worth confirming against the intended
        circuit format.
        """
        levels = self._QUANT_LEVELS
        with torch.no_grad():
            # The right-hand sides are fully evaluated before copy_ overwrites
            # the parameter, and each scale is a separate 0-dim tensor.
            proto_scale = self._quant_scale(self.prototypes)
            self.prototypes.copy_((self.prototypes / proto_scale * levels).round())
            self.proto_bias.copy_((self.proto_bias / proto_scale * levels).round())

            reg_w_scale = self._quant_scale(self.reg_weight)
            self.reg_weight.copy_((self.reg_weight / reg_w_scale * levels).round())

            reg_b_scale = self._quant_scale(self.reg_bias)
            self.reg_bias.copy_((self.reg_bias / reg_b_scale * levels).round())