| """Hypercolumn Linear: concatenate features from intermediate blocks, single linear layer.""" |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
# Number of leading tokens dropped from each intermediate feature sequence
# before the remaining tokens are reshaped into a spatial grid.
# NOTE(review): presumably these are non-patch tokens (e.g. class/register
# tokens) emitted by the backbone -- confirm against the feature extractor.
N_PREFIX = 5
|
|
|
|
class HypercolumnLinear(nn.Module):
    """Linear head over channel-concatenated intermediate block features.

    Each intermediate feature sequence has its leading prefix tokens removed,
    is reshaped into a ``(B, C, H, W)`` map, and the maps from all blocks are
    concatenated along the channel axis. A single 1x1 convolution (a per-pixel
    linear layer) then maps the stacked channels to ``num_classes`` outputs.

    Args:
        feat_dim: Channel width of one block's features.
        num_classes: Number of output channels produced per spatial location.
        n_blocks: Number of intermediate blocks whose features are stacked;
            the convolution expects ``feat_dim * n_blocks`` input channels.
        n_prefix: Number of leading tokens to drop from every intermediate
            sequence. ``None`` (the default) uses the module-level
            ``N_PREFIX``, which preserves the previous behaviour.

    Raises:
        ValueError: If ``n_prefix`` is negative.
    """

    name = "hypercolumn_linear"
    needs_intermediates = True

    def __init__(self, feat_dim=768, num_classes=150, n_blocks=4, n_prefix=None):
        super().__init__()
        # Resolve the default lazily so an explicit value never depends on
        # the module-level constant.
        if n_prefix is None:
            n_prefix = N_PREFIX
        if n_prefix < 0:
            # A negative count would silently slice from the END of the token
            # sequence instead of skipping leading tokens.
            raise ValueError(f"n_prefix must be non-negative, got {n_prefix}")
        self.n_blocks = n_blocks
        self.n_prefix = n_prefix
        # Kernel size 1: an independent linear map at every spatial position.
        self.conv = nn.Conv2d(feat_dim * n_blocks, num_classes, 1)

    def forward(self, spatial, inter=None):
        """Compute per-location outputs from the intermediate block features.

        Args:
            spatial: 4-D tensor ``(B, C, H, W)``. Only its shape is read; it
                supplies the target grid for reshaping the token sequences.
            inter: Iterable of 3-D tensors, one per block, each laid out as
                ``(B, n_prefix + H * W, C)``.

        Returns:
            Tensor of shape ``(B, num_classes, H, W)``.

        Raises:
            ValueError: If ``inter`` is ``None``.
        """
        # Fail with the intended message before touching any shapes.
        if inter is None:
            raise ValueError("hypercolumn_linear requires intermediate block features")
        B, C, H, W = spatial.shape
        # Drop prefix tokens, move channels ahead of the token axis, then fold
        # the flat token axis back into the (H, W) grid.
        maps = [
            feat[:, self.n_prefix:, :].permute(0, 2, 1).reshape(B, C, H, W)
            for feat in inter
        ]
        return self.conv(torch.cat(maps, dim=1))
|
|