# Source-viewer header (not code), kept as comments so the module parses:
# phanerozoic - commit cd7a8ba (verified): "Add tropical, compression, curvature, renormalization heads"
"""Compression Segmentation: modulate features by local prediction residual.
Patches that can't be predicted from neighbors get amplified before classification."""
import torch
import torch.nn as nn
import torch.nn.functional as F
class CompressionSegmentation(nn.Module):
    """Segmentation head that amplifies features their neighbours fail to predict.

    For every spatial position the feature vector is compared with the mean of
    its in-bounds 8-neighbourhood (3x3 ring, centre excluded). The squared L2
    residual is normalised per image so the largest residual maps to 1, and the
    features are scaled by ``1 + gain * normalised_residual`` before a 1x1
    convolution produces per-class logits.

    Args:
        feat_dim: Number of input feature channels ``C``.
        num_classes: Number of output classes.
        gain: Extra amplification applied at the most surprising position
            (that position is scaled by ``1 + gain``). Defaults to 3.0, the
            value that was previously hard-coded.
    """

    name = "compression"
    # forward() ignores `inter`; presumably callers read this flag to decide
    # whether to pass intermediate backbone features - confirm against caller.
    needs_intermediates = False

    def __init__(self, feat_dim=768, num_classes=150, gain=3.0):
        super().__init__()
        self.gain = gain
        self.cls = nn.Conv2d(feat_dim, num_classes, 1)

    def forward(self, spatial, inter=None):
        """Return class logits for a feature map.

        Args:
            spatial: Feature map of shape ``(B, C, H, W)`` with ``C == feat_dim``.
            inter: Unused; accepted for interface compatibility.

        Returns:
            Logits of shape ``(B, num_classes, H, W)``.
        """
        _, C, H, W = spatial.shape
        # 3x3 ring kernel: ones everywhere except the centre.
        ring = torch.ones(3, 3, dtype=spatial.dtype, device=spatial.device)
        ring[1, 1] = 0
        ring = ring.reshape(1, 1, 3, 3)

        # Per-channel sum of the neighbours (depthwise conv, zero padding).
        neighbor_sum = F.conv2d(spatial, ring.expand(C, 1, 3, 3), padding=1, groups=C)
        # Number of in-bounds neighbours at each position (8 in the interior,
        # fewer along the borders). Dividing by this instead of a fixed 8 stops
        # zero padding from making border patches look unpredictable. The clamp
        # avoids 0/0 on 1x1 maps, where the neighbour mean then becomes 0.
        counts = F.conv2d(spatial.new_ones(1, 1, H, W), ring, padding=1).clamp(min=1)
        neighbor_mean = neighbor_sum / counts

        surprise = (spatial - neighbor_mean).pow(2).sum(dim=1, keepdim=True)
        # Normalise per image so the most surprising position maps to 1; the
        # clamp keeps a perfectly predictable map from dividing by zero.
        surprise_norm = surprise / surprise.amax(dim=(2, 3), keepdim=True).clamp(min=1e-6)
        modulated = spatial * (1 + surprise_norm * self.gain)
        return self.cls(modulated)