phanerozoic's picture
8 segmentation head candidates with shared losses/utils and registry
0e8110e verified
"""Hypercolumn Linear: concatenate features from intermediate blocks, single linear layer."""
import torch
import torch.nn as nn
import torch.nn.functional as F
# Default number of leading tokens dropped from every intermediate token
# sequence before the remaining tokens are reshaped into a spatial grid.
# NOTE(review): presumably 1 class token + 4 register tokens — confirm against
# the backbone that produces the intermediates.
N_PREFIX = 5


class HypercolumnLinear(nn.Module):
    """Linear segmentation head over concatenated intermediate features.

    Each intermediate feature is stripped of its leading prefix tokens,
    reshaped to the layout of ``spatial``, concatenated with the others along
    the channel axis and passed through a single 1x1 convolution.
    """

    name = "hypercolumn_linear"
    needs_intermediates = True

    def __init__(self, feat_dim=768, num_classes=150, n_blocks=4, n_prefix=N_PREFIX):
        """Build the head.

        Args:
            feat_dim: Channel width of one intermediate feature map.
            num_classes: Number of output channels of the classifier.
            n_blocks: Number of intermediate feature maps that get stacked.
            n_prefix: Number of leading tokens to drop from each intermediate
                sequence. Defaults to the module-level ``N_PREFIX``.
        """
        super().__init__()
        self.n_blocks = n_blocks
        self.n_prefix = n_prefix
        # A 1x1 convolution applies the same linear map at every location.
        self.conv = nn.Conv2d(feat_dim * n_blocks, num_classes, 1)

    def forward(self, spatial, inter=None):
        """Classify every location from the stacked intermediate features.

        Args:
            spatial: 4-D tensor ``(B, C, H, W)``. Only its shape is used, as
                the target layout for the intermediates.
            inter: Iterable of intermediate features. Each one is sliced as
                ``feat[:, n_prefix:, :]``, so it is expected to be
                ``(B, n_prefix + H * W, C)`` — TODO confirm token order with
                the backbone.

        Returns:
            The output of the 1x1 convolution over the stacked features,
            with ``num_classes`` channels and the spatial size ``(H, W)``.

        Raises:
            ValueError: If ``inter`` is ``None`` or empty, or if the stacked
                channel count does not match what the classifier expects.
        """
        if inter is None:
            raise ValueError("hypercolumn_linear requires intermediate block features")

        B, C, H, W = spatial.shape
        # Drop prefix tokens, then (B, tokens, C) -> (B, C, tokens) -> (B, C, H, W).
        spatials = [
            feat[:, self.n_prefix:, :].permute(0, 2, 1).reshape(B, C, H, W)
            for feat in inter
        ]
        if not spatials:
            raise ValueError(
                "hypercolumn_linear received an empty list of intermediate features"
            )

        stacked = torch.cat(spatials, dim=1)
        expected_channels = self.conv.in_channels
        if stacked.shape[1] != expected_channels:
            # Fail with an actionable message instead of an opaque conv error.
            raise ValueError(
                f"hypercolumn_linear expected {expected_channels} stacked channels "
                f"({self.n_blocks} blocks), got {stacked.shape[1]} "
                f"from {len(spatials)} feature maps"
            )
        return self.conv(stacked)