8 segmentation head candidates with shared losses/utils and registry
Browse files- heads/__init__.py +28 -0
- heads/cofiber_linear/__init__.py +1 -0
- heads/cofiber_linear/head.py +36 -0
- heads/cofiber_threshold/__init__.py +1 -0
- heads/cofiber_threshold/head.py +41 -0
- heads/graph_crf/__init__.py +1 -0
- heads/graph_crf/head.py +41 -0
- heads/hypercolumn_linear/__init__.py +1 -0
- heads/hypercolumn_linear/head.py +29 -0
- heads/linear_probe/__init__.py +1 -0
- heads/linear_probe/head.py +15 -0
- heads/patch_attention/__init__.py +1 -0
- heads/patch_attention/head.py +35 -0
- heads/prototype_bank/__init__.py +1 -0
- heads/prototype_bank/head.py +22 -0
- heads/wavelet/__init__.py +1 -0
- heads/wavelet/head.py +31 -0
- losses/__init__.py +0 -0
- losses/segmentation.py +11 -0
- utils/__init__.py +0 -0
- utils/decode.py +11 -0
heads/__init__.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Segmentation head registry."""
|
| 2 |
+
|
| 3 |
+
from .linear_probe.head import LinearProbe
|
| 4 |
+
from .cofiber_linear.head import CofiberLinear
|
| 5 |
+
from .cofiber_threshold.head import CofiberThreshold
|
| 6 |
+
from .prototype_bank.head import PrototypeBank
|
| 7 |
+
from .wavelet.head import Wavelet
|
| 8 |
+
from .patch_attention.head import PatchAttention
|
| 9 |
+
from .graph_crf.head import GraphCRF
|
| 10 |
+
from .hypercolumn_linear.head import HypercolumnLinear
|
| 11 |
+
|
| 12 |
+
REGISTRY = {
|
| 13 |
+
"linear_probe": LinearProbe,
|
| 14 |
+
"cofiber_linear": CofiberLinear,
|
| 15 |
+
"cofiber_threshold": CofiberThreshold,
|
| 16 |
+
"prototype_bank": PrototypeBank,
|
| 17 |
+
"wavelet": Wavelet,
|
| 18 |
+
"patch_attention": PatchAttention,
|
| 19 |
+
"graph_crf": GraphCRF,
|
| 20 |
+
"hypercolumn_linear": HypercolumnLinear,
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
ALL_NAMES = list(REGISTRY.keys())
|
| 24 |
+
|
| 25 |
+
def get_head(name):
|
| 26 |
+
if name not in REGISTRY:
|
| 27 |
+
raise ValueError(f"Unknown head: {name}. Available: {ALL_NAMES}")
|
| 28 |
+
return REGISTRY[name]()
|
heads/cofiber_linear/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .head import *
|
heads/cofiber_linear/head.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Cofiber Linear: analytic multi-scale decomposition + shared 1x1 conv per scale."""
|
| 2 |
+
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def cofiber_decompose(f, n_scales):
    """Split *f* into a Laplacian-pyramid-style list of band-pass components.

    Each step low-passes the current residual with a stride-2 average pool,
    upsamples it back, and keeps the difference as one "cofiber"; the final
    low-frequency residual is appended last.

    Args:
        f: feature map of shape [B, C, H, W].
        n_scales: maximum number of components to produce.

    Returns:
        List of at most ``n_scales`` tensors; entry i has the (progressively
        halved) spatial size of the residual at level i.
    """
    cofibers = []
    residual = f
    for _ in range(n_scales - 1):
        # avg_pool2d needs at least a 2x2 input; stop early on tiny maps
        # instead of raising, so callers just sum whatever scales exist.
        if residual.shape[-2] < 2 or residual.shape[-1] < 2:
            break
        omega = F.avg_pool2d(residual, 2)
        sigma_omega = F.interpolate(omega, size=residual.shape[2:], mode="bilinear", align_corners=False)
        cofibers.append(residual - sigma_omega)
        residual = omega
    cofibers.append(residual)
    return cofibers
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class CofiberLinear(nn.Module):
    """Multi-scale head: analytic cofiber decomposition + one shared 1x1 conv.

    Every band-pass component is scored by the same linear classifier and the
    per-scale logit maps are summed at the input resolution.
    """

    name = "cofiber_linear"
    needs_intermediates = False

    def __init__(self, feat_dim=768, num_classes=150, n_scales=3):
        super().__init__()
        self.n_scales = n_scales
        self.conv = nn.Conv2d(feat_dim, num_classes, 1)

    def forward(self, spatial, inter=None):
        full_size = spatial.shape[2:]
        per_scale = []
        for component in cofiber_decompose(spatial, self.n_scales):
            scores = self.conv(component)
            per_scale.append(F.interpolate(scores, size=full_size, mode="bilinear", align_corners=False))
        logits = per_scale[0]
        for extra in per_scale[1:]:
            logits = logits + extra
        return logits
|
heads/cofiber_threshold/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .head import *
|
heads/cofiber_threshold/head.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Cofiber Threshold: analytic decomposition + per-scale LayerNorm + prototype classification."""
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def cofiber_decompose(f, n_scales):
    """Split *f* into a Laplacian-pyramid-style list of band-pass components.

    Each step low-passes the current residual with a stride-2 average pool,
    upsamples it back, and keeps the difference as one "cofiber"; the final
    low-frequency residual is appended last.

    Args:
        f: feature map of shape [B, C, H, W].
        n_scales: maximum number of components to produce.

    Returns:
        List of at most ``n_scales`` tensors; entry i has the (progressively
        halved) spatial size of the residual at level i.
    """
    cofibers = []
    residual = f
    for _ in range(n_scales - 1):
        # avg_pool2d needs at least a 2x2 input; stop early on tiny maps
        # instead of raising, so callers just sum whatever scales exist.
        if residual.shape[-2] < 2 or residual.shape[-1] < 2:
            break
        omega = F.avg_pool2d(residual, 2)
        sigma_omega = F.interpolate(omega, size=residual.shape[2:], mode="bilinear", align_corners=False)
        cofibers.append(residual - sigma_omega)
        residual = omega
    cofibers.append(residual)
    return cofibers
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class CofiberThreshold(nn.Module):
    """Multi-scale prototype head over the analytic cofiber decomposition.

    Each scale gets its own LayerNorm; normalized features are scored against
    a shared bank of class prototypes and the per-scale logits are summed at
    the input resolution.
    """

    name = "cofiber_threshold"
    needs_intermediates = False

    def __init__(self, feat_dim=768, num_classes=150, n_scales=3):
        super().__init__()
        self.n_scales = n_scales
        self.scale_norms = nn.ModuleList([nn.LayerNorm(feat_dim) for _ in range(n_scales)])
        self.prototypes = nn.Parameter(torch.randn(num_classes, feat_dim) * 0.01)
        self.proto_bias = nn.Parameter(torch.zeros(num_classes))

    def forward(self, spatial, inter=None):
        full_size = spatial.shape[2:]
        logits = None
        for level, component in enumerate(cofiber_decompose(spatial, self.n_scales)):
            b, c, h, w = component.shape
            # Flatten to (B*H*W, C) so LayerNorm runs over the channel axis.
            flat = component.permute(0, 2, 3, 1).reshape(-1, c)
            normed = self.scale_norms[level](flat)
            scores = normed @ self.prototypes.T + self.proto_bias
            scores = scores.reshape(b, h, w, -1).permute(0, 3, 1, 2)
            scores = F.interpolate(scores, size=full_size, mode="bilinear", align_corners=False)
            logits = scores if logits is None else logits + scores
        return logits
|
heads/graph_crf/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .head import *
|
heads/graph_crf/head.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Graph CRF: k-NN graph in feature space + message passing + per-node classification."""
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class GraphCRF(nn.Module):
    """CRF-like head: cosine k-NN graph over patch tokens + gated message passing.

    Tokens are projected to ``dim``, a k-NN graph is built once from cosine
    similarity (detached, so no gradient flows through graph construction),
    then ``rounds`` rounds of mean neighbor aggregation refine the tokens
    before per-token classification.
    """

    name = "graph_crf"
    needs_intermediates = False

    def __init__(self, feat_dim=768, num_classes=150, dim=256, k=8, rounds=2):
        super().__init__()
        self.k = k
        self.rounds = rounds
        self.proj = nn.Linear(feat_dim, dim)
        self.msg_layers = nn.ModuleList()
        for _ in range(rounds):
            self.msg_layers.append(nn.ModuleDict({
                "msg": nn.Linear(dim, dim),
                "gate": nn.Sequential(nn.Linear(dim, dim), nn.Sigmoid()),
                "upd": nn.Linear(dim, dim),
                "norm": nn.LayerNorm(dim),
            }))
        self.cls_head = nn.Linear(dim, num_classes)

    def forward(self, spatial, inter=None):
        B, C, H, W = spatial.shape
        tokens = self.proj(spatial.flatten(2).permute(0, 2, 1))
        N = tokens.shape[1]
        with torch.no_grad():
            # Hoisted: normalize once instead of twice for the similarity matrix.
            unit = F.normalize(tokens, dim=-1)
            sim = torch.bmm(unit, unit.transpose(1, 2))
            # Clamp k so small inputs (N < k) do not crash topk.
            # NOTE: top-k of self-similarity includes each token itself.
            _, knn_idx = sim.topk(min(self.k, N), dim=-1)
        k_eff = knn_idx.shape[-1]
        for layer in self.msg_layers:
            neighbors = tokens.gather(1, knn_idx.reshape(B, -1, 1).expand(-1, -1, tokens.shape[-1])).reshape(B, N, k_eff, -1)
            msg = layer["msg"](neighbors.mean(dim=2))
            gate = layer["gate"](tokens)
            tokens = layer["norm"](tokens + layer["upd"](gate * msg))
        logits = self.cls_head(tokens).reshape(B, H, W, -1).permute(0, 3, 1, 2)
        return logits
|
heads/hypercolumn_linear/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .head import *
|
heads/hypercolumn_linear/head.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Hypercolumn Linear: concatenate features from intermediate blocks, single linear layer."""
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
# Number of non-patch tokens (CLS/register) to strip from each block's tokens.
N_PREFIX = 5


class HypercolumnLinear(nn.Module):
    """Hypercolumn head: concatenate per-block patch features, one 1x1 conv.

    Expects ``inter`` to hold exactly ``n_blocks`` token tensors of shape
    [B, n_prefix + H*W, C], where C equals ``feat_dim`` and (H, W) match the
    ``spatial`` map.
    """

    name = "hypercolumn_linear"
    needs_intermediates = True

    def __init__(self, feat_dim=768, num_classes=150, n_blocks=4, n_prefix=N_PREFIX):
        super().__init__()
        self.n_blocks = n_blocks
        # Generalized: prefix-token count is configurable (default matches
        # the previous hard-coded module constant).
        self.n_prefix = n_prefix
        self.conv = nn.Conv2d(feat_dim * n_blocks, num_classes, 1)

    def forward(self, spatial, inter=None):
        B, C, H, W = spatial.shape
        if inter is None:
            raise ValueError("hypercolumn_linear requires intermediate block features")
        if len(inter) != self.n_blocks:
            # Fail fast with a clear message instead of a cryptic conv
            # channel-mismatch error downstream.
            raise ValueError(f"expected {self.n_blocks} intermediate feature maps, got {len(inter)}")
        spatials = []
        for feat in inter:
            patches = feat[:, self.n_prefix:, :]  # drop CLS/register tokens
            spatials.append(patches.permute(0, 2, 1).reshape(B, C, H, W))
        return self.conv(torch.cat(spatials, dim=1))
|
heads/linear_probe/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .head import *
|
heads/linear_probe/head.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Linear probe: BatchNorm + 1x1 conv. The EUPE paper baseline."""
|
| 2 |
+
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
|
| 5 |
+
class LinearProbe(nn.Module):
    """BatchNorm followed by a 1x1 conv — the EUPE paper baseline."""

    name = "linear_probe"
    needs_intermediates = False

    def __init__(self, feat_dim=768, num_classes=150):
        super().__init__()
        self.bn = nn.BatchNorm2d(feat_dim)
        self.conv = nn.Conv2d(feat_dim, num_classes, 1)

    def forward(self, spatial, inter=None):
        # Normalize channels, then score each pixel with a linear classifier.
        normalized = self.bn(spatial)
        return self.conv(normalized)
|
heads/patch_attention/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .head import *
|
heads/patch_attention/head.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Patch Attention: each patch attends to k nearest neighbors before classifying."""
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class PatchAttention(nn.Module):
    """Local attention head: each patch attends over its k nearest patches.

    A cosine k-NN (computed without gradients) selects the neighborhood; one
    multi-head attention layer refines every token against its neighbors
    before per-token classification.
    """

    name = "patch_attention"
    needs_intermediates = False

    def __init__(self, feat_dim=768, num_classes=150, dim=256, k=16):
        super().__init__()
        self.k = k
        self.proj = nn.Linear(feat_dim, dim)
        self.attn = nn.MultiheadAttention(dim, 4, batch_first=True)
        self.norm = nn.LayerNorm(dim)
        self.cls_head = nn.Linear(dim, num_classes)

    def forward(self, spatial, inter=None):
        B, C, H, W = spatial.shape
        tokens = self.proj(spatial.flatten(2).permute(0, 2, 1))
        N = tokens.shape[1]
        with torch.no_grad():
            unit = F.normalize(tokens, dim=-1)
            sim = torch.bmm(unit, unit.transpose(1, 2))
            _, knn_idx = sim.topk(self.k, dim=-1)
        # Pull each token's k neighbors to serve as keys/values; one query
        # token attends to its own k-neighborhood.
        flat_idx = knn_idx.reshape(B, -1, 1).expand(-1, -1, tokens.shape[-1])
        kv = tokens.gather(1, flat_idx).reshape(B * N, self.k, -1)
        queries = tokens.reshape(B * N, 1, -1)
        refined, _ = self.attn(queries, kv, kv)
        tokens = self.norm(tokens + refined.reshape(B, N, -1))
        return self.cls_head(tokens).reshape(B, H, W, -1).permute(0, 3, 1, 2)
|
heads/prototype_bank/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .head import *
|
heads/prototype_bank/head.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Prototype Bank: learned class prototypes, per-pixel cosine similarity, argmax. No conv layers."""
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class PrototypeBank(nn.Module):
    """Per-pixel cosine similarity against a bank of learned class prototypes.

    No convolutions: logits are scaled cosine scores between L2-normalized
    pixel features and L2-normalized prototype vectors.
    """

    name = "prototype_bank"
    needs_intermediates = False

    def __init__(self, feat_dim=768, num_classes=150):
        super().__init__()
        self.prototypes = nn.Parameter(torch.randn(num_classes, feat_dim) * 0.01)
        self.scale = nn.Parameter(torch.ones(1) * 10.0)

    def forward(self, spatial, inter=None):
        # Unit-normalize each pixel's channel vector and each prototype,
        # then take their dot products directly in NCHW layout.
        pixel_dirs = F.normalize(spatial, dim=1)
        proto_dirs = F.normalize(self.prototypes, dim=-1)
        cosine = torch.einsum("bchw,kc->bkhw", pixel_dirs, proto_dirs)
        return cosine * self.scale
|
heads/wavelet/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .head import *
|
heads/wavelet/head.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Wavelet: Haar decomposition + per-subband classification."""
|
| 2 |
+
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class Wavelet(nn.Module):
    """Haar-pyramid head: classify at several scales and sum the logits.

    The feature map is repeatedly downsampled with a 2x2 Haar average; each
    scale has its own 1x1 conv classifier and the per-scale logits are summed
    at the input resolution.
    """

    name = "wavelet"
    needs_intermediates = False

    def __init__(self, feat_dim=768, num_classes=150, n_scales=3):
        super().__init__()
        self.n_scales = n_scales
        self.heads = nn.ModuleList([nn.Conv2d(feat_dim, num_classes, 1) for _ in range(n_scales)])

    @staticmethod
    def haar_down(x):
        """2x downsample via the Haar low-pass (2x2 block average).

        Odd trailing rows/columns are cropped first so the four shifted
        slices align (the previous version broadcast-failed on odd sizes).
        """
        H, W = x.shape[-2:]
        x = x[:, :, : H - H % 2, : W - W % 2]
        return (x[:, :, 0::2, 0::2] + x[:, :, 0::2, 1::2] +
                x[:, :, 1::2, 0::2] + x[:, :, 1::2, 1::2]) / 4

    def forward(self, spatial, inter=None):
        target_size = spatial.shape[2:]
        f = spatial
        logits = None
        for i in range(self.n_scales):
            out = self.heads[i](f)
            out = F.interpolate(out, size=target_size, mode="bilinear", align_corners=False)
            logits = out if logits is None else logits + out
            if i < self.n_scales - 1:
                # Stop once the map is too small to halve again.
                if min(f.shape[-2:]) < 2:
                    break
                f = self.haar_down(f)
        return logits
|
losses/__init__.py
ADDED
|
File without changes
|
losses/segmentation.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Segmentation losses."""
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def cross_entropy_loss(logits, targets, ignore_index=255):
    """Per-pixel cross entropy.

    Args:
        logits: [B, C, H, W] class scores; bilinearly resized to the label
            resolution when the spatial sizes differ.
        targets: [B, H, W] integer class labels.
        ignore_index: label value excluded from the loss (default 255).
    """
    label_size = targets.shape[1:]
    if logits.shape[2:] != label_size:
        logits = F.interpolate(logits, size=label_size, mode="bilinear", align_corners=False)
    return F.cross_entropy(logits, targets, ignore_index=ignore_index)
|
utils/__init__.py
ADDED
|
File without changes
|
utils/decode.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Shared utilities for segmentation heads."""
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def upsample_and_argmax(logits, target_size):
    """Upsample logits to target spatial size and return class indices."""
    resized = logits
    if resized.shape[2:] != target_size:
        resized = F.interpolate(resized, size=target_size, mode="bilinear", align_corners=False)
    # Channel dimension holds per-class scores; argmax gives the prediction.
    return resized.argmax(dim=1)
|