"""Auto-extracted from detection_arena.py."""

import math
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from losses.fcos import fcos_loss, focal_loss, NUM_CLASSES
from losses.centernet import centernet_targets, centernet_loss
from utils.decode import make_locations, decode_fcos, decode_centernet, FPN_STRIDES

N_PREFIX = 5


class ThresholdPrototype(nn.Module):
    """Detection via threshold gates. Integer weights, Heaviside activation. ~65K params.

    Three per-location linear heads over a single feature map:
      * classification: ``num_classes`` prototype vectors plus a bias,
      * regression: 4 outputs, scaled by a learned scalar, clamped, then exponentiated,
      * ctr: 1 output.
    Outputs are fed to ``fcos_loss`` / ``decode_fcos``. With the defaults
    (feat_dim=768, num_classes=80) the parameter count is ~65K.
    """

    name = "R_threshold"
    needs_intermediates = False

    def __init__(self, feat_dim=768, num_classes=80):
        super().__init__()
        # Small random init for weights, zero init for biases.
        self.prototypes = nn.Parameter(torch.randn(num_classes, feat_dim) * 0.01)
        self.proto_bias = nn.Parameter(torch.zeros(num_classes))
        self.reg_weight = nn.Parameter(torch.randn(4, feat_dim) * 0.01)
        self.reg_bias = nn.Parameter(torch.zeros(4))
        self.ctr_weight = nn.Parameter(torch.randn(1, feat_dim) * 0.01)
        self.ctr_bias = nn.Parameter(torch.zeros(1))
        # Learned scalar multiplier applied to the raw regression output.
        self.scale = nn.Parameter(torch.ones(1))

    def forward(self, spatial, inter=None):
        """Apply the three linear heads at every spatial location.

        Args:
            spatial: feature map unpacked as (B, C, H, W); C must equal ``feat_dim``
                for the matrix products to be valid.
            inter: ignored by this model (``needs_intermediates`` is False).

        Returns:
            Three single-element lists: class logits (B, num_classes, H, W),
            regression outputs (B, 4, H, W) which are always positive because they
            are ``exp`` of a value clamped to [-10, 10], and the ctr-head output
            (B, 1, H, W).
        """
        B, C, H, W = spatial.shape
        # Flatten to one row per location: (B*H*W, C).
        f = spatial.permute(0, 2, 3, 1).reshape(-1, C)

        cls_logits = f @ self.prototypes.T + self.proto_bias
        reg_raw = (f @ self.reg_weight.T + self.reg_bias) * self.scale
        # Clamp before exp so the regression output cannot overflow.
        reg_raw = reg_raw.clamp(-10, 10)
        reg = torch.exp(reg_raw)
        ctr = f @ self.ctr_weight.T + self.ctr_bias

        # Restore channel-first layout (B, K, H, W).
        cls_logits = cls_logits.reshape(B, H, W, -1).permute(0, 3, 1, 2)
        reg = reg.reshape(B, H, W, 4).permute(0, 3, 1, 2)
        ctr = ctr.reshape(B, H, W, 1).permute(0, 3, 1, 2)
        return [cls_logits], [reg], [ctr]

    def loss(self, preds, locs, boxes_b, labels_b):
        """Compute the FCOS loss from the ``forward`` outputs."""
        return fcos_loss(*preds, locs, boxes_b, labels_b)

    def decode(self, preds, locs, **kw):
        """Decode ``forward`` outputs into detections via ``decode_fcos``."""
        return decode_fcos(*preds, locs, **kw)

    def get_locs(self, spatial):
        """Return location coordinates for the (H, W) grid of ``spatial``.

        The stride is hard-coded to 16. NOTE(review): this assumes the backbone
        feature map is at stride 16; confirm against the backbone.
        """
        h, w = spatial.shape[2], spatial.shape[3]
        return make_locations([(h, w)], [16], spatial.device)

    @staticmethod
    def _head_scale(weight: Tensor, bias: Tensor, eps: float = 1e-6) -> Tensor:
        """Return the shared divisor for one affine head.

        This is max|weight|, falling back to max|bias| when the weight is all zero,
        and floored at ``eps`` so the division is always defined.
        """
        absmax = weight.abs().max()
        if absmax == 0:
            absmax = bias.abs().max()
        return absmax.clamp_min(eps)

    def quantize_to_circuit(self):
        """Quantize head weights to integers for threshold circuit serialization.

        Each affine head (classification, regression, ctr) is rescaled in place.
        The weight and its bias are divided by one shared head scale, multiplied
        by 127, and rounded. Weights therefore land in [-127, 127], and each head
        output ``f @ W.T + b`` becomes a positive multiple of the float output, up
        to rounding, so the sign of every threshold gate is preserved.

        Biases are not clipped to [-127, 127]. Tensors keep their float dtype but
        hold integer values. ``self.scale`` is left untouched.
        """
        with torch.no_grad():
            heads = (
                (self.prototypes, self.proto_bias),
                (self.reg_weight, self.reg_bias),
                (self.ctr_weight, self.ctr_bias),
            )
            for weight, bias in heads:
                # Share one scale per head. Scaling the bias independently (as
                # reg_bias once was) breaks the affine relation W·f + b.
                s = self._head_scale(weight, bias)
                weight.copy_((weight / s * 127).round())
                bias.copy_((bias / s * 127).round())