| """Patch Attention: each patch attends to k nearest neighbors before classifying.""" |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
class PatchAttention(nn.Module):
    """Patch-level attention head: each patch token attends to its k most
    cosine-similar patches (kNN over patch tokens) before a per-patch
    linear classification.

    Input:  spatial feature map of shape (B, feat_dim, H, W).
    Output: per-patch logits of shape (B, num_classes, H, W).
    """

    name = "patch_attention"
    # This head consumes only the final spatial features; it does not use
    # intermediate-layer activations (``inter`` is accepted but ignored).
    needs_intermediates = False

    def __init__(self, feat_dim=768, num_classes=150, dim=256, k=16):
        """
        Args:
            feat_dim: channel dimension of the incoming spatial features.
            num_classes: number of output classes per patch.
            dim: internal attention width (must be divisible by the 4 heads).
            k: number of nearest neighbors each patch attends to. Clamped to
               the actual patch count at forward time, so small feature maps
               do not crash ``topk``.
        """
        super().__init__()
        self.k = k
        self.proj = nn.Linear(feat_dim, dim)
        self.attn = nn.MultiheadAttention(dim, 4, batch_first=True)
        self.norm = nn.LayerNorm(dim)
        self.cls_head = nn.Linear(dim, num_classes)

    def forward(self, spatial, inter=None):
        """Classify each patch after attending to its k nearest neighbors.

        Args:
            spatial: feature map of shape (B, feat_dim, H, W).
            inter: unused; kept for interface compatibility with heads that
                set ``needs_intermediates = True``.

        Returns:
            Logits of shape (B, num_classes, H, W).
        """
        B, C, H, W = spatial.shape
        # (B, C, H, W) -> (B, N, dim) token sequence, where N = H * W.
        tokens = self.proj(spatial.flatten(2).permute(0, 2, 1))
        N = tokens.shape[1]
        # Bug fix: topk raises when k exceeds the number of patches
        # (e.g. the default k=16 on a 3x3 feature map). Clamp to N.
        k = min(self.k, N)

        with torch.no_grad():
            # Cosine similarity between all patch pairs. Only the *indices*
            # of the top-k are used, so no gradient flows through the
            # neighbor selection. Normalize once (the original normalized
            # the same tensor twice).
            normed = F.normalize(tokens, dim=-1)
            sim = torch.bmm(normed, normed.transpose(1, 2))
            _, knn_idx = sim.topk(k, dim=-1)  # (B, N, k); includes self.

        # Gather each patch's k neighbor tokens: (B, N*k, dim) -> (B, N, k, dim).
        D = tokens.shape[-1]
        kv = tokens.gather(
            1, knn_idx.reshape(B, -1, 1).expand(-1, -1, D)
        ).reshape(B, N, k, -1)
        # Fold N into the batch: each patch is a length-1 query attending
        # over its own k keys/values.
        kv_flat = kv.reshape(B * N, k, -1)
        q = tokens.reshape(B * N, 1, -1)
        out, _ = self.attn(q, kv_flat, kv_flat)
        # Residual connection + LayerNorm, then per-patch classification.
        out = self.norm(tokens + out.reshape(B, N, -1))
        logits = self.cls_head(out).reshape(B, H, W, -1).permute(0, 3, 1, 2)
        return logits
|
|