# Source: phanerozoic — "8 segmentation head candidates with shared losses/utils and registry"
# commit 0e8110e (verified)
"""Patch Attention: each patch attends to k nearest neighbors before classifying."""
import torch
import torch.nn as nn
import torch.nn.functional as F
class PatchAttention(nn.Module):
    """Segmentation head where each patch token attends to its k nearest
    neighbors (by cosine similarity) before per-patch classification.

    Args:
        feat_dim: channel dimension of the incoming spatial feature map.
        num_classes: number of output classes per spatial position.
        dim: internal projection width (must be divisible by the 4 heads).
        k: neighborhood size; effectively clamped to the token count when
            the feature map has fewer than ``k`` positions.
    """

    name = "patch_attention"
    needs_intermediates = False

    def __init__(self, feat_dim=768, num_classes=150, dim=256, k=16):
        super().__init__()
        self.k = k
        self.proj = nn.Linear(feat_dim, dim)
        self.attn = nn.MultiheadAttention(dim, 4, batch_first=True)
        self.norm = nn.LayerNorm(dim)
        self.cls_head = nn.Linear(dim, num_classes)

    def forward(self, spatial, inter=None):
        """Classify every spatial position after k-NN self-attention.

        Args:
            spatial: ``(B, C, H, W)`` feature map.
            inter: unused; kept for registry interface compatibility.

        Returns:
            ``(B, num_classes, H, W)`` logits.
        """
        B, C, H, W = spatial.shape
        # (B, C, H, W) -> (B, N, dim) token sequence, N = H * W.
        tokens = self.proj(spatial.flatten(2).permute(0, 2, 1))
        N, D = tokens.shape[1], tokens.shape[2]
        # BUGFIX: topk raised RuntimeError when H*W < k (small feature
        # maps); clamp neighborhood size to the available token count.
        k = min(self.k, N)
        # Neighbor selection is a hard, non-differentiable choice, so the
        # similarity matrix is computed outside the autograd graph.
        with torch.no_grad():
            normed = F.normalize(tokens, dim=-1)  # hoisted: was computed twice
            sim = torch.bmm(normed, normed.transpose(1, 2))
            # (B, N, k) neighbor indices; includes self (self-sim == 1).
            _, knn_idx = sim.topk(k, dim=-1)
        # Gather each token's k neighbors as its key/value set: (B, N, k, D).
        kv = tokens.gather(
            1, knn_idx.reshape(B, -1, 1).expand(-1, -1, D)
        ).reshape(B, N, k, D)
        # Fold tokens into the batch dim so every token attends
        # independently over its own k-neighborhood.
        kv_flat = kv.reshape(B * N, k, D)
        q = tokens.reshape(B * N, 1, D)
        out, _ = self.attn(q, kv_flat, kv_flat)
        # Residual + LayerNorm, then per-token classification; reshape
        # back to channels-first (B, num_classes, H, W).
        out = self.norm(tokens + out.reshape(B, N, -1))
        logits = self.cls_head(out).reshape(B, H, W, -1).permute(0, 3, 1, 2)
        return logits