# phanerozoic: update repository (dbbceb8)
"""Deep Evolved Detection Head.
10-layer MLP with interleaved depthwise 3x3 convolutions operating on
92 evolutionarily-selected feature dimensions. The dimension selection
was performed via GPU-batched evolutionary search (200 gen/s). The MLP
was trained on the selected dimensions with the backbone frozen.
182K params, 10.6 mAP, mAP@0.75 = 10.8.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
# Number of object classes; sets the output width of the classification head
# (presumably the 80 COCO categories -- TODO confirm against the training data).
NUM_CLASSES = 80
def cofiber_decompose(f, n_scales):
    """Split a feature map into ``n_scales`` multi-resolution residual levels.

    At every level the current map is halved with a 2x2 average pool, the
    pooled map is bilinearly resized back to the current resolution, and the
    difference between the two is stored. The pooled map then becomes the
    input of the next level. The final entry is the coarsest pooled map.

    Args:
        f: Feature map whose spatial axes are dims 2 and 3, i.e. ``(B, C, H, W)``.
        n_scales: Number of levels to produce. Any value below 2 returns
            ``[f]`` unchanged.

    Returns:
        A list of tensors ordered fine to coarse. Level ``k`` has the spatial
        size obtained after ``k`` successive 2x2 average pools.
    """
    levels = []
    current = f
    for _level in range(1, n_scales):
        pooled = F.avg_pool2d(current, 2)
        # Bring the coarse map back to the current grid so it can be subtracted.
        upsampled = F.interpolate(
            pooled,
            size=current.shape[2:],
            mode="bilinear",
            align_corners=False,
        )
        levels.append(current - upsampled)
        current = pooled
    levels.append(current)
    return levels
class SpatialDWConv(nn.Module):
    """Depthwise 3x3 convolution for channels-last ``(B, H, W, C)`` maps.

    The wrapped ``nn.Conv2d`` uses ``groups == channels`` (one 3x3 filter per
    channel) and ``padding=1``, so the spatial size is preserved.
    """

    def __init__(self, channels):
        """Create the depthwise convolution over ``channels`` feature channels."""
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=3,
            padding=1,
            groups=channels,
        )

    def forward(self, x, B, H, W):
        """Apply the depthwise convolution to a channels-last tensor.

        Args:
            x: Input tensor. Only 4-D inputs are convolved; they are read as
                ``(B, H, W, C)``.
            B, H, W: Accepted for call-site compatibility; not read here.

        Returns:
            The convolved tensor in ``(B, H, W, C)`` layout for 4-D input.
            Any input that is not 4-D is returned untouched.
        """
        if x.dim() != 4:
            # Non-4-D input carries no spatial grid here; pass it through.
            return x
        channels_first = x.permute(0, 3, 1, 2)
        convolved = self.conv(channels_first)
        return convolved.permute(0, 2, 3, 1)
class EvolvedDeepHead(nn.Module):
    """Per-location detection head over a selected subset of feature channels.

    The input feature map is split into ``n_scales`` levels with
    ``cofiber_decompose``. For every level the head:

    1. layer-normalises each spatial location over all ``feat_dim`` channels
       (one ``LayerNorm`` per level),
    2. keeps only the channels listed in ``evolved_dims``,
    3. runs a trunk of ``Linear -> GELU`` stages shared by all levels, with a
       ``SpatialDWConv`` inserted after every second stage,
    4. applies three linear heads (``cls``, ``reg``, ``ctr``).

    Args:
        evolved_dims: Channel indices to keep after normalisation. Each index
            must be valid for an axis of length ``feat_dim``.
        hidden: Width of every trunk stage.
        n_layers: Number of ``Linear -> GELU`` stages in the trunk.
        n_scales: Number of decomposition levels (must be >= 1).
        feat_dim: Channel count of the incoming feature map; also the
            ``LayerNorm`` width. Defaults to 768, the previously fixed value.

    Raises:
        ValueError: If ``n_scales < 1`` or an index in ``evolved_dims`` falls
            outside ``[-feat_dim, feat_dim)``.
    """

    name = "evolved_deep"
    needs_intermediates = False

    def __init__(self, evolved_dims, hidden=128, n_layers=10, n_scales=3, feat_dim=768):
        super().__init__()
        if n_scales < 1:
            raise ValueError(f"n_scales must be >= 1, got {n_scales}")
        dim_idx = torch.tensor(evolved_dims, dtype=torch.long)
        # Fail at construction rather than with an indexing error in forward().
        if dim_idx.numel() > 0 and (
            int(dim_idx.min()) < -feat_dim or int(dim_idx.max()) >= feat_dim
        ):
            raise ValueError(
                f"evolved_dims must index a channel axis of size {feat_dim}"
            )

        self.evolved_dims = evolved_dims
        self.n_scales = n_scales
        self.feat_dim = feat_dim
        # Integer indices are fixed state, not a learnable weight: keep them
        # as a buffer. The state_dict key ("dim_idx") is the same as before,
        # and the tensor still follows the module across devices.
        self.register_buffer("dim_idx", dim_idx)
        self.scale_norms = nn.ModuleList(
            [nn.LayerNorm(feat_dim) for _ in range(n_scales)]
        )

        layers = []
        in_dim = len(evolved_dims)
        for i in range(n_layers):
            layers.append(nn.Linear(in_dim, hidden))
            layers.append(nn.GELU())
            if i % 2 == 1:
                # Spatial mixing after every second Linear/GELU stage.
                layers.append(SpatialDWConv(hidden))
            in_dim = hidden
        self.backbone = nn.Sequential(*layers)

        # in_dim is the trunk's real output width: `hidden` once at least one
        # stage exists, otherwise the number of selected channels.
        self.cls_head = nn.Linear(in_dim, NUM_CLASSES)
        self.reg_head = nn.Linear(in_dim, 4)
        self.ctr_head = nn.Linear(in_dim, 1)
        # One learnable multiplier per level, applied to the raw regression output.
        self.scale_params = nn.Parameter(torch.ones(n_scales))

    def forward(self, spatial, inter=None):
        """Run the head on every decomposition level of ``spatial``.

        Args:
            spatial: Feature map of shape ``(B, feat_dim, H, W)``.
            inter: Unused; accepted so the call signature stays unchanged.

        Returns:
            Three lists ``(cls_l, reg_l, ctr_l)`` with one tensor per level,
            shaped ``(B, NUM_CLASSES, H_i, W_i)``, ``(B, 4, H_i, W_i)`` and
            ``(B, 1, H_i, W_i)``. ``reg`` is ``exp`` of the scaled raw output
            clamped to ``[-10, 10]``, so it is strictly positive. ``ctr`` is
            presumably a centerness logit -- confirm against the loss code.
        """
        cofibers = cofiber_decompose(spatial, self.n_scales)
        cls_l, reg_l, ctr_l = [], [], []
        for i, cof in enumerate(cofibers):
            B, C, H, W = cof.shape
            # Channels-last, one row per spatial location, normalised over C.
            f = self.scale_norms[i](cof.permute(0, 2, 3, 1).reshape(-1, C))
            f_sel = f[:, self.dim_idx].reshape(B, H, W, -1)
            h = self._forward_with_spatial(f_sel, B, H, W)
            # Flatten once and share the result between the three heads.
            flat = h.reshape(-1, h.shape[-1])
            cls = self.cls_head(flat).reshape(B, H, W, -1).permute(0, 3, 1, 2)
            # Clamp before exp so the output stays finite.
            reg_raw = (self.reg_head(flat) * self.scale_params[i]).clamp(-10, 10)
            reg = reg_raw.exp().reshape(B, H, W, 4).permute(0, 3, 1, 2)
            ctr = self.ctr_head(flat).reshape(B, H, W, 1).permute(0, 3, 1, 2)
            cls_l.append(cls)
            reg_l.append(reg)
            ctr_l.append(ctr)
        return cls_l, reg_l, ctr_l

    def _forward_with_spatial(self, x, B, H, W):
        """Run the trunk, handing the grid size to the spatial conv layers.

        ``nn.Sequential`` cannot forward the extra ``B, H, W`` arguments that
        ``SpatialDWConv.forward`` takes, so the layers are iterated by hand.
        """
        for layer in self.backbone:
            if isinstance(layer, SpatialDWConv):
                x = layer(x, B, H, W)
            else:
                x = layer(x)
        return x