import torch
import torch.nn as nn
import torch.nn.functional as F


# Ref: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py#L363
class ViTFeaturePyramid(nn.Module):
    """
    This module implements SimpleFeaturePyramid in :paper:`vitdet`.
    It creates pyramid features built on top of the input feature map.

    One independent stage is built per entry of ``scale_factors``; every stage
    is applied to the same input feature map in :meth:`forward`.
    """

    def __init__(
        self,
        in_channels,
        scale_factors,
    ):
        """
        Args:
            in_channels (int): number of channels of the input feature map.
            scale_factors (list[float]): list of scaling factors to upsample or
                downsample the input features for creating pyramid features.
                Supported values are 4.0, 2.0, 1.0 and 0.5.

        Raises:
            NotImplementedError: if a scale factor is not one of the supported
                values.
        """
        super().__init__()
        self.scale_factors = scale_factors

        dim = in_channels
        self.stages = nn.ModuleList()
        for scale in scale_factors:
            # Reset for every stage: only the upsampling branches reduce the
            # channel count. Without this reset, a 1.0/0.5 stage listed after
            # a 4.0/2.0 stage would inherit the reduced width and build a
            # Conv2d whose channel count does not match its input.
            out_dim = dim
            if scale == 4.0:
                # Two stride-2 transposed convs: 4x spatial size, dim // 4 channels.
                layers = [
                    nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2),
                    nn.GELU(),
                    nn.ConvTranspose2d(dim // 2, dim // 4, kernel_size=2, stride=2),
                ]
                out_dim = dim // 4
            elif scale == 2.0:
                # One stride-2 transposed conv: 2x spatial size, dim // 2 channels.
                layers = [nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2)]
                out_dim = dim // 2
            elif scale == 1.0:
                # Identity stage: an empty nn.Sequential returns its input.
                layers = []
            elif scale == 0.5:
                # Stride-2 max pooling: half spatial size, channels unchanged.
                layers = [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                raise NotImplementedError(f"scale_factor={scale} is not supported yet.")

            if scale != 1.0:
                # Every rescaled stage ends with GELU + a shape-preserving 3x3 conv.
                layers.extend(
                    [
                        nn.GELU(),
                        nn.Conv2d(out_dim, out_dim, 3, 1, 1),
                    ]
                )
            self.stages.append(nn.Sequential(*layers))

    def forward(self, x):
        """
        Apply every pyramid stage to the same input feature map.

        Args:
            x (torch.Tensor): input feature map with ``in_channels`` channels.

        Returns:
            list[torch.Tensor]: one output per entry of ``scale_factors``, in
            the same order.
        """
        return [stage(x) for stage in self.stages]


def _test():
    """Smoke test: build a pyramid, print it and the shape of each output."""
    # Fall back to CPU so the smoke test also runs on machines without a GPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ViTFeaturePyramid(
        384,
        scale_factors=[1, 2, 4],
    ).to(device)
    print(model)

    x = torch.randn(2, 384, 64, 96, device=device)
    out = model(x)
    for feat in out:
        print(feat.shape)


if __name__ == "__main__":
    _test()