"""Hypercolumn Linear: concatenate features from intermediate blocks, single linear layer."""

import torch
import torch.nn as nn
import torch.nn.functional as F

# Default number of leading tokens dropped from every intermediate feature
# sequence before it is reshaped into a spatial map.
# NOTE(review): presumably these are non-patch tokens (e.g. CLS/register
# tokens) — confirm against the backbone that produces ``inter``.
N_PREFIX = 5


class HypercolumnLinear(nn.Module):
    """Per-pixel linear classifier over channel-concatenated block features.

    Each intermediate feature sequence has its first ``n_prefix`` tokens
    removed and is reshaped to ``(B, C, H, W)``. The resulting maps are
    concatenated along the channel axis and passed through one 1x1
    convolution, i.e. a single linear layer applied independently at every
    spatial position.
    """

    name = "hypercolumn_linear"
    needs_intermediates = True

    def __init__(self, feat_dim=768, num_classes=150, n_blocks=4,
                 n_prefix=N_PREFIX):
        """Build the head.

        Args:
            feat_dim: Channel count of each intermediate feature map.
            num_classes: Number of output channels of the 1x1 convolution.
            n_blocks: Number of intermediate feature maps that will be
                concatenated; fixes the convolution's input width to
                ``feat_dim * n_blocks``.
            n_prefix: Number of leading tokens to drop from each
                intermediate sequence. Defaults to the module-level
                ``N_PREFIX``.
        """
        super().__init__()
        self.n_blocks = n_blocks
        self.n_prefix = n_prefix
        self.conv = nn.Conv2d(feat_dim * n_blocks, num_classes, 1)

    def forward(self, spatial, inter=None):
        """Classify every spatial position from the intermediate features.

        Args:
            spatial: 4-D tensor ``(B, C, H, W)``. Only its shape is read;
                its values do not contribute to the output.
            inter: Iterable of exactly ``n_blocks`` 3-D tensors. After the
                first ``n_prefix`` entries along dim 1 are dropped, each
                must be reshapeable to ``(B, C, H, W)`` — i.e. it is
                expected to be ``(B, n_prefix + H * W, C)``.

        Returns:
            Tensor of shape ``(B, num_classes, H, W)``.

        Raises:
            ValueError: If ``inter`` is ``None`` or does not contain
                exactly ``n_blocks`` tensors.
        """
        B, C, H, W = spatial.shape
        if inter is None:
            raise ValueError("hypercolumn_linear requires intermediate block features")

        # Materialize first so any iterable (including a generator) can be
        # counted without being consumed twice.
        feats = list(inter)
        if len(feats) != self.n_blocks:
            # Fail here with a readable message instead of letting the
            # convolution raise a channel-mismatch error further down.
            raise ValueError(
                f"hypercolumn_linear expected {self.n_blocks} intermediate "
                f"feature maps, got {len(feats)}"
            )

        # (B, tokens, C) -> drop prefix -> (B, C, tokens) -> (B, C, H, W)
        spatials = [
            feat[:, self.n_prefix:, :].permute(0, 2, 1).reshape(B, C, H, W)
            for feat in feats
        ]
        stacked = torch.cat(spatials, dim=1)
        return self.conv(stacked)