File size: 3,321 Bytes
dbbceb8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
"""Deep Evolved Detection Head.

10-layer MLP with interleaved depthwise 3x3 convolutions operating on
92 evolutionarily-selected feature dimensions. The dimension selection
was performed via GPU-batched evolutionary search (200 gen/s). The MLP
was trained on the selected dimensions with the backbone frozen.

182K params, 10.6 mAP, mAP@0.75 = 10.8.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

# Number of object categories scored by ``EvolvedDeepHead.cls_head``
# (the output width of that Linear layer). 80 presumably matches COCO —
# NOTE(review): confirm against the dataset / evaluation config.
NUM_CLASSES = 80


def cofiber_decompose(f, n_scales):
    """Split ``f`` into per-level detail maps followed by a coarse residual.

    For each of the first ``n_scales - 1`` levels, the current map is
    average-pooled by a factor of 2. The pooled map is bilinearly resized
    back to the current map's spatial size, and the difference (the detail
    the pooling discarded) is recorded. The pooled map then becomes the
    input of the next level. The last pooled map is appended unchanged.

    Args:
        f: Feature map. Assumed ``(B, C, H, W)`` — ``avg_pool2d`` and
            ``interpolate`` act on the trailing two (spatial) dimensions.
        n_scales: Requested number of output maps.

    Returns:
        List of tensors ordered finest to coarsest. It holds ``n_scales``
        entries when ``n_scales >= 1``, and ``[f]`` otherwise.
    """
    detail_maps = []
    current = f
    remaining = n_scales - 1
    while remaining > 0:
        pooled = F.avg_pool2d(current, 2)
        target_size = current.shape[2:]
        upsampled = F.interpolate(
            pooled, size=target_size, mode="bilinear", align_corners=False
        )
        detail_maps.append(current - upsampled)
        current = pooled
        remaining -= 1
    detail_maps.append(current)
    return detail_maps


class SpatialDWConv(nn.Module):
    """Depthwise 3x3 convolution over channels-last features.

    Lets an otherwise per-position MLP mix information between neighbouring
    spatial positions. The convolution is depthwise (``groups=channels``)
    with ``padding=1``, so the spatial size and channel count are preserved.

    Args:
        channels: Number of feature channels (the last dimension of the input).
    """

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1, groups=channels)

    def forward(self, x, B, H, W):
        """Apply the depthwise convolution over the spatial grid.

        Args:
            x: Either a channels-last map of shape ``(B, H, W, C)`` or a
                flattened token matrix of shape ``(B * H * W, C)``.
            B, H, W: Batch size and grid size. They are only needed to
                unflatten a 2-D input; a 4-D input already carries its shape.

        Returns:
            Tensor with the same shape and layout as ``x``.

        Raises:
            ValueError: If ``x`` is neither 2-D nor 4-D, or if a 2-D input does
                not have exactly ``B * H * W`` rows. Previously such inputs
                were returned unchanged, which silently skipped the conv.
        """
        if x.dim() == 4:
            return self._conv_channels_last(x)
        if x.dim() == 2:
            if x.shape[0] != B * H * W:
                raise ValueError(
                    f"SpatialDWConv: flattened input has {x.shape[0]} rows, "
                    f"expected B*H*W = {B}*{H}*{W} = {B * H * W}"
                )
            grid = x.reshape(B, H, W, x.shape[1])
            return self._conv_channels_last(grid).reshape(x.shape)
        raise ValueError(
            f"SpatialDWConv expects a 2-D (B*H*W, C) or 4-D (B, H, W, C) "
            f"input, got {x.dim()}-D with shape {tuple(x.shape)}"
        )

    def _conv_channels_last(self, x):
        """Run ``self.conv`` on a ``(B, H, W, C)`` tensor and return the same layout."""
        # Conv2d wants channels-first, so go NHWC -> NCHW -> conv -> NHWC.
        return self.conv(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)


class EvolvedDeepHead(nn.Module):
    """Detection head: a deep channels-last MLP over a selected channel subset.

    The input map is split by ``cofiber_decompose`` into ``n_scales`` maps.
    For each scale:

    1. The map is layer-normalised over all ``in_channels`` channels (one
       LayerNorm per scale).
    2. Only the ``evolved_dims`` channels are kept.
    3. The result goes through ``n_layers`` Linear+GELU layers, with a
       ``SpatialDWConv`` after every second layer.

    The classification, box-regression and centerness heads are shared
    across scales.

    Args:
        evolved_dims: Sequence of channel indices to keep after normalisation.
            Its length is the MLP input width.
        hidden: Width of every hidden MLP layer.
        n_layers: Number of Linear+GELU layers.
        n_scales: Number of maps requested from ``cofiber_decompose``
            (must be >= 1).
        in_channels: Channel count of the input feature map. Defaults to 768,
            the value that was previously hard-coded.

    Raises:
        ValueError: If ``n_scales < 1`` or an entry of ``evolved_dims`` cannot
            index a tensor with ``in_channels`` channels.
    """

    name = "evolved_deep"
    needs_intermediates = False

    def __init__(self, evolved_dims, hidden=128, n_layers=10, n_scales=3, in_channels=768):
        super().__init__()
        if n_scales < 1:
            # cofiber_decompose always returns >= 1 map, so fewer than one
            # LayerNorm would make forward() fail with an IndexError.
            raise ValueError(f"n_scales must be >= 1, got {n_scales}")
        self.evolved_dims = evolved_dims
        self.n_scales = n_scales
        K = len(evolved_dims)

        dim_idx = torch.tensor(evolved_dims, dtype=torch.long)
        if dim_idx.numel() and (
            int(dim_idx.max()) >= in_channels or int(dim_idx.min()) < -in_channels
        ):
            # Catch bad indices here rather than as an index error (or a CUDA
            # device-side assert) deep inside forward().
            raise ValueError(
                f"evolved_dims must index into {in_channels} channels; got range "
                f"[{int(dim_idx.min())}, {int(dim_idx.max())}]"
            )
        # A non-trainable integer index tensor belongs in a buffer, not in a
        # Parameter: it should not show up in parameters(), and it would crash
        # requires_grad_(True). The state_dict key "dim_idx" is unchanged, so
        # existing checkpoints still load.
        self.register_buffer("dim_idx", dim_idx)
        self.scale_norms = nn.ModuleList([nn.LayerNorm(in_channels) for _ in range(n_scales)])

        layers = []
        in_dim = K
        for i in range(n_layers):
            layers.append(nn.Linear(in_dim, hidden))
            layers.append(nn.GELU())
            if i % 2 == 1:
                # Spatial mixing after every second Linear+GELU pair.
                layers.append(SpatialDWConv(hidden))
            in_dim = hidden
        # Held in a Sequential for registration/state_dict only. It is never
        # called directly, because SpatialDWConv needs extra arguments; see
        # _forward_with_spatial.
        self.backbone = nn.Sequential(*layers)

        self.cls_head = nn.Linear(hidden, NUM_CLASSES)
        self.reg_head = nn.Linear(hidden, 4)
        self.ctr_head = nn.Linear(hidden, 1)
        # One learnable multiplier per scale on the raw regression output.
        self.scale_params = nn.Parameter(torch.ones(n_scales))

    def forward(self, spatial, inter=None):
        """Run the head on every scale of ``spatial``.

        Args:
            spatial: 4-D feature map ``(B, C, H, W)`` with ``C == in_channels``
                (required by the per-scale LayerNorm).
            inter: Unused. It is accepted for interface compatibility with
                heads that consume intermediates (``needs_intermediates`` is
                False here).

        Returns:
            Tuple ``(cls_l, reg_l, ctr_l)`` of lists with one entry per scale,
            ordered finest to coarsest as produced by ``cofiber_decompose``:

            - ``cls``: ``(B, NUM_CLASSES, H_i, W_i)`` raw logits.
            - ``reg``: ``(B, 4, H_i, W_i)`` strictly positive values (``exp``
              of a clamped input).
            - ``ctr``: ``(B, 1, H_i, W_i)`` raw logits.
        """
        cofibers = cofiber_decompose(spatial, self.n_scales)
        cls_l, reg_l, ctr_l = [], [], []
        for i, cof in enumerate(cofibers):
            B, C, H, W = cof.shape
            # Normalise over all channels first, then keep only the selected dims.
            f = self.scale_norms[i](cof.permute(0, 2, 3, 1).reshape(-1, C))
            f_sel = f[:, self.dim_idx].reshape(B, H, W, -1)
            h = self._forward_with_spatial(f_sel, B, H, W)
            # Flatten once; the three heads all consume the same (B*H*W, hidden) matrix.
            flat = h.reshape(-1, h.shape[-1])
            cls = self.cls_head(flat).reshape(B, H, W, -1).permute(0, 3, 1, 2)
            # Clamping before exp bounds the output to [e^-10, e^10].
            reg_raw = (self.reg_head(flat) * self.scale_params[i]).clamp(-10, 10)
            reg = reg_raw.exp().reshape(B, H, W, 4).permute(0, 3, 1, 2)
            ctr = self.ctr_head(flat).reshape(B, H, W, 1).permute(0, 3, 1, 2)
            cls_l.append(cls)
            reg_l.append(reg)
            ctr_l.append(ctr)
        return cls_l, reg_l, ctr_l

    def _forward_with_spatial(self, x, B, H, W):
        """Apply ``self.backbone`` layer by layer to a ``(B, H, W, K)`` tensor.

        Linear/GELU layers act on the last (channel) dimension. SpatialDWConv
        layers also receive the grid size, which is why the Sequential is not
        simply called.
        """
        for layer in self.backbone:
            if isinstance(layer, SpatialDWConv):
                x = layer(x, B, H, W)
            else:
                x = layer(x)
        return x