phanerozoic committed on
Commit
0e8110e
·
verified ·
1 Parent(s): 714b88b

8 segmentation head candidates with shared losses/utils and registry

Browse files
heads/__init__.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Segmentation head registry."""
2
+
3
+ from .linear_probe.head import LinearProbe
4
+ from .cofiber_linear.head import CofiberLinear
5
+ from .cofiber_threshold.head import CofiberThreshold
6
+ from .prototype_bank.head import PrototypeBank
7
+ from .wavelet.head import Wavelet
8
+ from .patch_attention.head import PatchAttention
9
+ from .graph_crf.head import GraphCRF
10
+ from .hypercolumn_linear.head import HypercolumnLinear
11
+
12
+ REGISTRY = {
13
+ "linear_probe": LinearProbe,
14
+ "cofiber_linear": CofiberLinear,
15
+ "cofiber_threshold": CofiberThreshold,
16
+ "prototype_bank": PrototypeBank,
17
+ "wavelet": Wavelet,
18
+ "patch_attention": PatchAttention,
19
+ "graph_crf": GraphCRF,
20
+ "hypercolumn_linear": HypercolumnLinear,
21
+ }
22
+
23
+ ALL_NAMES = list(REGISTRY.keys())
24
+
25
+ def get_head(name):
26
+ if name not in REGISTRY:
27
+ raise ValueError(f"Unknown head: {name}. Available: {ALL_NAMES}")
28
+ return REGISTRY[name]()
heads/cofiber_linear/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .head import *
heads/cofiber_linear/head.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Cofiber Linear: analytic multi-scale decomposition + shared 1x1 conv per scale."""
2
+
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+
7
def cofiber_decompose(f, n_scales):
    """Split feature map ``f`` into ``n_scales`` multi-scale components.

    Each step low-passes the current residual with 2x2 average pooling,
    bilinearly upsamples it back, and records the difference (the
    "cofiber"). The final, coarsest residual is appended last. Returns a
    list of length ``n_scales``; the first entry matches ``f``'s spatial
    size, and ``n_scales == 1`` returns ``[f]`` unchanged.
    """
    layers = []
    current = f
    for _ in range(n_scales - 1):
        coarse = F.avg_pool2d(current, 2)
        reconstructed = F.interpolate(
            coarse, size=current.shape[2:], mode="bilinear", align_corners=False
        )
        layers.append(current - reconstructed)
        current = coarse
    layers.append(current)
    return layers
17
+
18
+
19
class CofiberLinear(nn.Module):
    """Shared 1x1 conv applied at every cofiber scale.

    Decomposes the input analytically, classifies each scale with the same
    linear layer, upsamples the per-scale logits back to the input grid,
    and sums them.
    """

    name = "cofiber_linear"
    needs_intermediates = False

    def __init__(self, feat_dim=768, num_classes=150, n_scales=3):
        super().__init__()
        self.n_scales = n_scales
        self.conv = nn.Conv2d(feat_dim, num_classes, 1)

    def forward(self, spatial, inter=None):
        size = spatial.shape[2:]
        per_scale = [
            F.interpolate(self.conv(c), size=size, mode="bilinear", align_corners=False)
            for c in cofiber_decompose(spatial, self.n_scales)
        ]
        return sum(per_scale)
heads/cofiber_threshold/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .head import *
heads/cofiber_threshold/head.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Cofiber Threshold: analytic decomposition + per-scale LayerNorm + prototype classification."""
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
def cofiber_decompose(f, n_scales):
    """Analytic multi-scale decomposition of ``f`` into ``n_scales`` parts.

    Repeatedly: low-pass with 2x2 average pooling, upsample back, and keep
    the residual difference as one component. The last element is the
    coarsest low-pass residual. NOTE(review): duplicated in
    cofiber_linear/head.py — a shared util would avoid drift.
    """
    out = []
    res = f
    remaining = n_scales - 1
    while remaining > 0:
        low = F.avg_pool2d(res, 2)
        up = F.interpolate(low, size=res.shape[2:], mode="bilinear", align_corners=False)
        out.append(res - up)
        res = low
        remaining -= 1
    out.append(res)
    return out
18
+
19
+
20
class CofiberThreshold(nn.Module):
    """Per-scale LayerNorm + prototype classification over cofiber scales.

    Each scale is normalized by its own LayerNorm, scored against a shared
    bank of class prototypes (plus bias), and the per-scale logits are
    upsampled to the input grid and summed.
    """

    name = "cofiber_threshold"
    needs_intermediates = False

    def __init__(self, feat_dim=768, num_classes=150, n_scales=3):
        super().__init__()
        self.n_scales = n_scales
        self.scale_norms = nn.ModuleList(nn.LayerNorm(feat_dim) for _ in range(n_scales))
        self.prototypes = nn.Parameter(torch.randn(num_classes, feat_dim) * 0.01)
        self.proto_bias = nn.Parameter(torch.zeros(num_classes))

    def forward(self, spatial, inter=None):
        size = spatial.shape[2:]
        total = None
        # Scales come out of cofiber_decompose in fine-to-coarse order,
        # pairing with the norms in registration order.
        for norm, cof in zip(self.scale_norms, cofiber_decompose(spatial, self.n_scales)):
            b, c, h, w = cof.shape
            flat = norm(cof.permute(0, 2, 3, 1).reshape(-1, c))
            scores = (flat @ self.prototypes.T + self.proto_bias).reshape(b, h, w, -1)
            scores = scores.permute(0, 3, 1, 2)
            scores = F.interpolate(scores, size=size, mode="bilinear", align_corners=False)
            total = scores if total is None else total + scores
        return total
heads/graph_crf/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .head import *
heads/graph_crf/head.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Graph CRF: k-NN graph in feature space + message passing + per-node classification."""
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
class GraphCRF(nn.Module):
    """CRF-style message passing over a k-NN graph in feature space.

    Tokens are projected, a k-NN graph is built once from cosine
    similarity (no gradient through graph construction), then ``rounds``
    of gated mean-aggregation message passing refine the tokens before a
    per-node linear classifier. Note: self-similarity is maximal, so each
    node's neighbor set includes itself.
    """

    name = "graph_crf"
    needs_intermediates = False

    def __init__(self, feat_dim=768, num_classes=150, dim=256, k=8, rounds=2):
        super().__init__()
        self.k = k
        self.rounds = rounds
        self.proj = nn.Linear(feat_dim, dim)
        self.msg_layers = nn.ModuleList()
        for _ in range(rounds):
            self.msg_layers.append(nn.ModuleDict({
                "msg": nn.Linear(dim, dim),
                "gate": nn.Sequential(nn.Linear(dim, dim), nn.Sigmoid()),
                "upd": nn.Linear(dim, dim),
                "norm": nn.LayerNorm(dim),
            }))
        self.cls_head = nn.Linear(dim, num_classes)

    def forward(self, spatial, inter=None):
        B, C, H, W = spatial.shape
        nodes = self.proj(spatial.flatten(2).permute(0, 2, 1))  # [B, N, dim]
        N = nodes.shape[1]
        # Build the neighbor index once from the initial embeddings.
        with torch.no_grad():
            unit = F.normalize(nodes, dim=-1)
            _, nbr_idx = torch.bmm(unit, unit.transpose(1, 2)).topk(self.k, dim=-1)
        # Flattened gather index is loop-invariant, so hoist it.
        flat_idx = nbr_idx.reshape(B, -1, 1).expand(-1, -1, nodes.shape[-1])
        for layer in self.msg_layers:
            nbrs = nodes.gather(1, flat_idx).reshape(B, N, self.k, -1)
            message = layer["msg"](nbrs.mean(dim=2))
            gate = layer["gate"](nodes)
            nodes = layer["norm"](nodes + layer["upd"](gate * message))
        return self.cls_head(nodes).reshape(B, H, W, -1).permute(0, 3, 1, 2)
heads/hypercolumn_linear/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .head import *
heads/hypercolumn_linear/head.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Hypercolumn Linear: concatenate features from intermediate blocks, single linear layer."""
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
# Tokens dropped from the front of each intermediate block's sequence
# (presumably CLS + register tokens — TODO confirm against the backbone).
N_PREFIX = 5


class HypercolumnLinear(nn.Module):
    """Hypercolumn head: concatenate patch features from several
    intermediate transformer blocks and classify with one 1x1 conv.

    Assumes each entry of ``inter`` is a token tensor of shape
    [B, N_PREFIX + H*W, feat_dim] whose patch grid matches ``spatial``.
    """

    name = "hypercolumn_linear"
    needs_intermediates = True

    def __init__(self, feat_dim=768, num_classes=150, n_blocks=4):
        super().__init__()
        self.n_blocks = n_blocks
        self.conv = nn.Conv2d(feat_dim * n_blocks, num_classes, 1)

    def forward(self, spatial, inter=None):
        """Return [B, num_classes, H, W] logits from the hypercolumn.

        Raises:
            ValueError: if ``inter`` is missing or its length does not
                match ``n_blocks`` (would otherwise surface as a cryptic
                conv channel-count mismatch).
        """
        B, C, H, W = spatial.shape
        if inter is None:
            raise ValueError("hypercolumn_linear requires intermediate block features")
        if len(inter) != self.n_blocks:
            raise ValueError(
                f"hypercolumn_linear expects {self.n_blocks} intermediate "
                f"blocks, got {len(inter)}"
            )
        spatials = []
        for feat in inter:
            patches = feat[:, N_PREFIX:, :]  # drop prefix tokens, keep patch grid
            spatials.append(patches.permute(0, 2, 1).reshape(B, C, H, W))
        return self.conv(torch.cat(spatials, dim=1))
heads/linear_probe/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .head import *
heads/linear_probe/head.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Linear probe: BatchNorm + 1x1 conv. The EUPE paper baseline."""
2
+
3
+ import torch.nn as nn
4
+
5
class LinearProbe(nn.Module):
    """Linear probe baseline (the EUPE paper): BatchNorm2d then 1x1 conv."""

    name = "linear_probe"
    needs_intermediates = False

    def __init__(self, feat_dim=768, num_classes=150):
        super().__init__()
        self.bn = nn.BatchNorm2d(feat_dim)
        self.conv = nn.Conv2d(feat_dim, num_classes, 1)

    def forward(self, spatial, inter=None):
        normalized = self.bn(spatial)
        return self.conv(normalized)
heads/patch_attention/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .head import *
heads/patch_attention/head.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Patch Attention: each patch attends to k nearest neighbors before classifying."""
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
class PatchAttention(nn.Module):
    """Each patch token attends over its k nearest neighbors, then is
    classified per-node.

    The k-NN index is built from cosine similarity without gradient; since
    self-similarity is maximal, each token's neighbor set includes itself.
    """

    name = "patch_attention"
    needs_intermediates = False

    def __init__(self, feat_dim=768, num_classes=150, dim=256, k=16):
        super().__init__()
        self.k = k
        self.proj = nn.Linear(feat_dim, dim)
        self.attn = nn.MultiheadAttention(dim, 4, batch_first=True)
        self.norm = nn.LayerNorm(dim)
        self.cls_head = nn.Linear(dim, num_classes)

    def forward(self, spatial, inter=None):
        B, C, H, W = spatial.shape
        tokens = self.proj(spatial.flatten(2).permute(0, 2, 1))  # [B, N, dim]
        N = tokens.shape[1]
        D = tokens.shape[-1]
        with torch.no_grad():
            unit = F.normalize(tokens, dim=-1)
            _, nbr_idx = torch.bmm(unit, unit.transpose(1, 2)).topk(self.k, dim=-1)
        # Each token is a single-element query over its k-neighbor KV set.
        neighbours = tokens.gather(1, nbr_idx.reshape(B, -1, 1).expand(-1, -1, D))
        kv = neighbours.reshape(B * N, self.k, D)
        q = tokens.reshape(B * N, 1, D)
        attended, _ = self.attn(q, kv, kv)
        fused = self.norm(tokens + attended.reshape(B, N, D))
        return self.cls_head(fused).reshape(B, H, W, -1).permute(0, 3, 1, 2)
heads/prototype_bank/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .head import *
heads/prototype_bank/head.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Prototype Bank: learned class prototypes, per-pixel cosine similarity, argmax. No conv layers."""
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
class PrototypeBank(nn.Module):
    """Per-pixel scaled cosine similarity against learned class
    prototypes. Contains no convolutional layers."""

    name = "prototype_bank"
    needs_intermediates = False

    def __init__(self, feat_dim=768, num_classes=150):
        super().__init__()
        self.prototypes = nn.Parameter(torch.randn(num_classes, feat_dim) * 0.01)
        # Learned temperature applied to the cosine similarities.
        self.scale = nn.Parameter(torch.ones(1) * 10.0)

    def forward(self, spatial, inter=None):
        B, C, H, W = spatial.shape
        pixels = spatial.permute(0, 2, 3, 1).reshape(-1, C)
        sims = F.normalize(pixels, dim=-1) @ F.normalize(self.prototypes, dim=-1).T
        return (sims * self.scale).reshape(B, H, W, -1).permute(0, 3, 1, 2)
heads/wavelet/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .head import *
heads/wavelet/head.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Wavelet: Haar decomposition + per-subband classification."""
2
+
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+
7
class Wavelet(nn.Module):
    """Haar-style pyramid: classify the feature map at each scale with its
    own 1x1 conv and sum the logits after upsampling to the input grid."""

    name = "wavelet"
    needs_intermediates = False

    def __init__(self, feat_dim=768, num_classes=150, n_scales=3):
        super().__init__()
        self.n_scales = n_scales
        self.heads = nn.ModuleList([nn.Conv2d(feat_dim, num_classes, 1) for _ in range(n_scales)])

    @staticmethod
    def haar_down(x):
        """Average each non-overlapping 2x2 block (the Haar LL subband).

        Fix: odd trailing rows/columns are cropped first. The previous
        version raised a shape-mismatch error for odd H or W (the even and
        odd stride-2 slices had different lengths), e.g. on a 37x37 ViT
        patch grid. Inputs already smaller than 2x2 are returned as-is.
        """
        H, W = x.shape[2], x.shape[3]
        if H < 2 or W < 2:
            return x
        x = x[:, :, : H - H % 2, : W - W % 2]
        return (x[:, :, 0::2, 0::2] + x[:, :, 0::2, 1::2] +
                x[:, :, 1::2, 0::2] + x[:, :, 1::2, 1::2]) / 4

    def forward(self, spatial, inter=None):
        """Return [B, num_classes, H, W] logits summed over all scales."""
        target_size = spatial.shape[2:]
        f = spatial
        logits = None
        for i in range(self.n_scales):
            out = self.heads[i](f)
            out = F.interpolate(out, size=target_size, mode="bilinear", align_corners=False)
            logits = out if logits is None else logits + out
            if i < self.n_scales - 1:
                f = self.haar_down(f)
        return logits
losses/__init__.py ADDED
File without changes
losses/segmentation.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Segmentation losses."""
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+
7
def cross_entropy_loss(logits, targets, ignore_index=255):
    """Standard per-pixel cross entropy.

    Args:
        logits: [B, C, H, W] raw class scores; bilinearly upsampled to the
            label grid when the spatial sizes differ.
        targets: [B, H, W] integer class labels.
        ignore_index: label value excluded from the loss (default 255).

    Returns:
        Scalar loss tensor.
    """
    spatial = targets.shape[1:]
    aligned = (
        logits
        if logits.shape[2:] == spatial
        else F.interpolate(logits, size=spatial, mode="bilinear", align_corners=False)
    )
    return F.cross_entropy(aligned, targets, ignore_index=ignore_index)
utils/__init__.py ADDED
File without changes
utils/decode.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Shared utilities for segmentation heads."""
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+
7
def upsample_and_argmax(logits, target_size):
    """Resize logits to ``target_size`` (H, W) and return per-pixel class
    indices of shape [B, H, W]."""
    if logits.shape[2:] == target_size:
        return logits.argmax(dim=1)
    resized = F.interpolate(logits, size=target_size, mode="bilinear", align_corners=False)
    return resized.argmax(dim=1)