| """Compression Segmentation: modulate features by local prediction residual. |
| Patches that can't be predicted from neighbors get amplified before classification.""" |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
class CompressionSegmentation(nn.Module):
    """Segmentation head that amplifies features poorly predicted by their neighbours.

    For every spatial location the feature vector is compared with the mean of
    its in-bounds 8-neighbourhood. The squared residual ("surprise") is summed
    over channels and divided by its per-image maximum, giving a value in
    ``[0, 1]``. The features are then scaled by ``1 + gain * surprise`` before a
    1x1 convolution produces per-pixel class logits.
    """

    name = "compression"
    needs_intermediates = False

    def __init__(self, feat_dim=768, num_classes=150, gain=3.0):
        """
        Args:
            feat_dim: number of channels ``C`` of the incoming feature map.
            num_classes: number of output logit channels.
            gain: extra amplification applied to the most surprising location
                of each image (it is scaled by ``1 + gain``). The default of 3.0
                reproduces the previously hard-coded factor.
        """
        super().__init__()
        self.gain = gain
        self.cls = nn.Conv2d(feat_dim, num_classes, 1)

    def forward(self, spatial, inter=None):
        """Compute class logits from surprise-modulated features.

        Args:
            spatial: feature map; unpacked as ``(B, C, H, W)`` with ``C`` equal
                to ``feat_dim``.
            inter: unused; accepted only for interface compatibility
                (``needs_intermediates`` is False).

        Returns:
            Logits of shape ``(B, num_classes, H, W)``.
        """
        _, C, H, W = spatial.shape
        # 3x3 ring: every neighbour counts once, the centre pixel is excluded.
        ring = torch.tensor(
            [[1, 1, 1], [1, 0, 1], [1, 1, 1]],
            dtype=spatial.dtype,
            device=spatial.device,
        ).reshape(1, 1, 3, 3)

        # Per-channel (depthwise) sum over the 8-neighbourhood.
        neighbor_sum = F.conv2d(spatial, ring.expand(C, 1, 3, 3), padding=1, groups=C)

        # Zero padding leaves border pixels with fewer real neighbours
        # (5 on edges, 3 at corners). Dividing by a fixed 8 would drag their
        # neighbour mean toward zero and inflate their surprise, so divide by
        # the true neighbour count instead. clamp(min=1) guards the degenerate
        # 1x1 map, which has no neighbours at all.
        count = F.conv2d(spatial.new_ones(1, 1, H, W), ring, padding=1).clamp(min=1)
        neighbor_mean = neighbor_sum / count

        surprise = (spatial - neighbor_mean).pow(2).sum(dim=1, keepdim=True)
        # Normalise per image to [0, 1]; clamp avoids 0/0 on flat inputs.
        surprise_norm = surprise / surprise.amax(dim=(2, 3), keepdim=True).clamp(min=1e-6)
        modulated = spatial * (1 + surprise_norm * self.gain)
        return self.cls(modulated)
|
|