File size: 2,686 Bytes
ca63835
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
"""Auto-extracted from detection_arena.py."""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import List

from losses.fcos import fcos_loss, focal_loss, NUM_CLASSES
from losses.centernet import centernet_targets, centernet_loss
from utils.decode import make_locations, decode_fcos, decode_centernet, FPN_STRIDES

# NOTE(review): not referenced anywhere else in this file; presumably kept for
# modules that import it -- confirm before removing.
N_PREFIX = 5

class ThresholdPrototype(nn.Module):
    """Detection via threshold gates. Integer weights, Heaviside activation. ~65K params.

    The head is three per-location linear maps over the channel dimension of
    a ``(B, C, H, W)`` feature map: one prototype row per class, four
    regression rows and one centerness row. ``forward`` returns raw
    (un-thresholded) outputs; no Heaviside step is applied in this class.
    Weights are learned as floats and only become integer-valued after
    ``quantize_to_circuit`` is called.

    Args:
        feat_dim: Number of input channels ``C``.
        num_classes: Number of class prototypes.
    """
    name = "R_threshold"
    needs_intermediates = False

    # Largest magnitude a quantized weight is scaled to.
    _QUANT_MAX = 127
    # Floor for the normalisation denominators so an all-zero tensor
    # quantizes to zeros instead of NaN (0 / 0).
    _QUANT_EPS = 1e-6

    def __init__(self, feat_dim=768, num_classes=80):
        super().__init__()
        self.prototypes = nn.Parameter(torch.randn(num_classes, feat_dim) * 0.01)
        self.proto_bias = nn.Parameter(torch.zeros(num_classes))
        self.reg_weight = nn.Parameter(torch.randn(4, feat_dim) * 0.01)
        self.reg_bias = nn.Parameter(torch.zeros(4))
        self.ctr_weight = nn.Parameter(torch.randn(1, feat_dim) * 0.01)
        self.ctr_bias = nn.Parameter(torch.zeros(1))
        # Learnable multiplier applied to the regression pre-activation.
        self.scale = nn.Parameter(torch.ones(1))

    def forward(self, spatial, inter=None):
        """Apply the three linear heads at every spatial location.

        Args:
            spatial: Feature map of shape ``(B, C, H, W)``.
            inter: Unused; accepted so the signature matches callers that
                pass intermediates.

        Returns:
            A tuple ``([cls_logits], [reg], [ctr])`` of single-element lists
            holding tensors of shape ``(B, num_classes, H, W)``,
            ``(B, 4, H, W)`` and ``(B, 1, H, W)``. ``reg`` is strictly
            positive: it is ``exp`` of the pre-activation clamped to
            ``[-10, 10]``.
        """
        B, C, H, W = spatial.shape
        # Flatten to one row per location so each head is a single matmul.
        f = spatial.permute(0, 2, 3, 1).reshape(-1, C)

        cls_logits = f @ self.prototypes.T + self.proto_bias
        reg_raw = (f @ self.reg_weight.T + self.reg_bias) * self.scale
        # Clamp before exp to keep the result finite.
        reg = torch.exp(reg_raw.clamp(-10, 10))
        ctr = f @ self.ctr_weight.T + self.ctr_bias

        # Restore the channels-first (B, K, H, W) layout.
        cls_logits = cls_logits.reshape(B, H, W, -1).permute(0, 3, 1, 2)
        reg = reg.reshape(B, H, W, 4).permute(0, 3, 1, 2)
        ctr = ctr.reshape(B, H, W, 1).permute(0, 3, 1, 2)
        return [cls_logits], [reg], [ctr]

    def loss(self, preds, locs, boxes_b, labels_b):
        """Forward the unpacked ``preds`` and the targets to ``fcos_loss``."""
        return fcos_loss(*preds, locs, boxes_b, labels_b)

    def decode(self, preds, locs, **kw):
        """Forward the unpacked ``preds``, ``locs`` and ``kw`` to ``decode_fcos``."""
        return decode_fcos(*preds, locs, **kw)

    def get_locs(self, spatial):
        """Build locations for the single feature level of ``spatial``."""
        h, w = spatial.shape[2], spatial.shape[3]
        # NOTE(review): 16 is presumably the feature stride -- confirm
        # against make_locations.
        return make_locations([(h, w)], [16], spatial.device)

    def quantize_to_circuit(self):
        """Quantize weights to integers for threshold circuit serialization.

        In place, each tensor is divided by a scale, multiplied by 127 and
        rounded. ``prototypes`` and ``proto_bias`` share the scale
        ``max|prototypes|``; ``reg_weight`` and ``reg_bias`` each use their
        own maximum magnitude. Every scale is floored at ``1e-6`` so that an
        all-zero tensor yields zeros rather than NaN. ``ctr_weight``,
        ``ctr_bias`` and ``scale`` are left untouched.

        NOTE(review): ``reg_bias`` is normalised independently of
        ``reg_weight``, so the quantized regression head is not a uniform
        rescaling of the float one -- confirm this is intended.
        """
        levels = self._QUANT_MAX
        eps = self._QUANT_EPS
        with torch.no_grad():
            # Read the shared scale before prototypes is overwritten.
            proto_scale = torch.clamp(self.prototypes.abs().max(), min=eps)
            self.prototypes.copy_((self.prototypes / proto_scale * levels).round())
            self.proto_bias.copy_((self.proto_bias / proto_scale * levels).round())

            reg_w_scale = torch.clamp(self.reg_weight.abs().max(), min=eps)
            self.reg_weight.copy_((self.reg_weight / reg_w_scale * levels).round())

            reg_b_scale = torch.clamp(self.reg_bias.abs().max(), min=eps)
            self.reg_bias.copy_((self.reg_bias / reg_b_scale * levels).round())