"""Auto-extracted from detection_arena.py."""
import math
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from losses.fcos import fcos_loss, focal_loss, NUM_CLASSES
from losses.centernet import centernet_targets, centernet_loss
from utils.decode import make_locations, decode_fcos, decode_centernet, FPN_STRIDES
# NOTE(review): N_PREFIX is not referenced anywhere in the visible part of this
# file; its meaning cannot be determined from here. Kept so external importers
# keep working.
N_PREFIX = 5
class ThresholdPrototype(nn.Module):
    """Detection via threshold gates: a single-level, purely linear head.

    Every spatial location of the input feature map is projected with three
    learned linear maps:

    * ``prototypes`` / ``proto_bias``  -> ``num_classes`` class logits,
    * ``reg_weight`` / ``reg_bias``    -> 4 regression values, multiplied by
      the learned scalar ``scale``, clamped to ``[-10, 10]`` and passed
      through ``exp`` (so they are strictly positive),
    * ``ctr_weight`` / ``ctr_bias``    -> 1 ``ctr`` value (presumably the FCOS
      centerness logit, since the outputs are fed to ``fcos_loss``).

    With the default sizes (``feat_dim=768``, ``num_classes=80``) the module
    has roughly 65K parameters.

    NOTE(review): the forward pass is an ordinary float linear head; it does
    not apply a Heaviside/step activation. Weights only become integer-valued
    after :meth:`quantize_to_circuit` has been called.
    """

    name = "R_threshold"
    needs_intermediates = False

    # Quantized weights are mapped into [-_QUANT_MAX, _QUANT_MAX].
    _QUANT_MAX = 127
    # Floor for the normalisation scale so all-zero tensors do not divide by 0.
    _QUANT_EPS = 1e-6

    def __init__(self, feat_dim=768, num_classes=80):
        """Create the linear projections.

        Args:
            feat_dim: Channel count ``C`` of the incoming feature map.
            num_classes: Number of class logits produced per location.
        """
        super().__init__()
        # Weight matrices start as small Gaussian noise, biases at zero.
        # The creation order is kept fixed so seeded runs stay reproducible.
        self.prototypes = nn.Parameter(torch.randn(num_classes, feat_dim) * 0.01)
        self.proto_bias = nn.Parameter(torch.zeros(num_classes))
        self.reg_weight = nn.Parameter(torch.randn(4, feat_dim) * 0.01)
        self.reg_bias = nn.Parameter(torch.zeros(4))
        self.ctr_weight = nn.Parameter(torch.randn(1, feat_dim) * 0.01)
        self.ctr_bias = nn.Parameter(torch.zeros(1))
        # Learned scalar applied to the raw regression output before exp().
        self.scale = nn.Parameter(torch.ones(1))

    def forward(self, spatial, inter=None):
        """Run the head on a ``(B, C, H, W)`` feature map.

        Args:
            spatial: 4-D feature tensor; its channel dimension must match the
                ``feat_dim`` the module was built with.
            inter: Unused; accepted only for interface compatibility.

        Returns:
            A tuple ``([cls_logits], [reg], [ctr])`` of one-element lists with
            shapes ``(B, num_classes, H, W)``, ``(B, 4, H, W)`` and
            ``(B, 1, H, W)`` respectively.
        """
        B, C, H, W = spatial.shape
        # Flatten to one row per spatial location: (B*H*W, C).
        feats = spatial.permute(0, 2, 3, 1).reshape(-1, C)

        cls_logits = feats @ self.prototypes.T + self.proto_bias
        reg_raw = (feats @ self.reg_weight.T + self.reg_bias) * self.scale
        # Clamp before exp() so the regression output cannot overflow.
        reg = torch.exp(reg_raw.clamp(-10, 10))
        ctr = feats @ self.ctr_weight.T + self.ctr_bias

        def to_map(flat):
            # (B*H*W, K) -> (B, K, H, W)
            return flat.reshape(B, H, W, flat.shape[-1]).permute(0, 3, 1, 2)

        return [to_map(cls_logits)], [to_map(reg)], [to_map(ctr)]

    def loss(self, preds, locs, boxes_b, labels_b):
        """Forward ``preds`` (unpacked) and the targets to ``fcos_loss``."""
        return fcos_loss(*preds, locs, boxes_b, labels_b)

    def decode(self, preds, locs, **kw):
        """Forward ``preds`` (unpacked), ``locs`` and ``kw`` to ``decode_fcos``."""
        return decode_fcos(*preds, locs, **kw)

    def get_locs(self, spatial):
        """Build locations for ``spatial``'s ``H x W`` grid on its device.

        The head is single-level, so one grid size and one stride (hard-coded
        to 16) are passed to ``make_locations``.
        """
        h, w = spatial.shape[2], spatial.shape[3]
        return make_locations([(h, w)], [16], spatial.device)

    def _quant_scale(self, tensor):
        """Return ``max|tensor|`` floored at ``_QUANT_EPS`` (0-dim tensor)."""
        return tensor.abs().max().clamp(min=self._QUANT_EPS)

    def quantize_to_circuit(self):
        """Quantize weights to integers for threshold circuit serialization.

        Classification and regression parameters are rescaled in place so the
        largest weight magnitude maps to ``127`` and then rounded to the
        nearest integer (values stay stored as floats). ``ctr_weight``,
        ``ctr_bias`` and ``scale`` are left untouched.

        Every normalisation scale is floored at ``1e-6``; without that floor
        an all-zero weight tensor would turn into NaNs (0 / 0).
        """
        qmax = self._QUANT_MAX
        with torch.no_grad():
            # proto_bias shares the prototypes' scale so that the relative
            # weighting of weights and bias in each class logit is preserved.
            proto_scale = self._quant_scale(self.prototypes)
            self.prototypes.copy_((self.prototypes / proto_scale * qmax).round())
            self.proto_bias.copy_((self.proto_bias / proto_scale * qmax).round())

            reg_scale = self._quant_scale(self.reg_weight)
            self.reg_weight.copy_((self.reg_weight / reg_scale * qmax).round())

            # NOTE(review): reg_bias is normalised by its own maximum rather
            # than by reg_scale, unlike proto_bias above. Behaviour kept as
            # in the original; confirm whether this is intended.
            bias_scale = self._quant_scale(self.reg_bias)
            self.reg_bias.copy_((self.reg_bias / bias_scale * qmax).round())