import math
import warnings

import torch
import torch.nn as nn
from torch.nn import functional as F


class Conv2d(torch.nn.Conv2d):
    """
    A wrapper around :class:`torch.nn.Conv2d` to support empty inputs and more features.
    """

    def __init__(self, *args, **kwargs):
        """
        Extra keyword arguments supported in addition to those in `torch.nn.Conv2d`:

        Args:
            norm (nn.Module, optional): a normalization layer
            activation (callable(Tensor) -> Tensor): a callable activation function

        The norm layer is applied before the activation.
        """
        norm = kwargs.pop("norm", None)
        activation = kwargs.pop("activation", None)
        super().__init__(*args, **kwargs)

        self.norm = norm
        self.activation = activation

    def forward(self, x):
        # torchscript does not support SyncBatchNorm yet
        # https://github.com/pytorch/pytorch/issues/40507
        # and we skip these codes in torchscript since:
        # 1. currently we only support torchscript in evaluation mode
        # 2. features needed by exporting module to torchscript are added in PyTorch 1.6 or
        #    later versions; `Conv2d` in these PyTorch versions already supports empty inputs.
        if not torch.jit.is_scripting():
            # Dynamo does not support context managers yet, so skip the check while compiling.
            try:
                is_dynamo_compiling = torch.compiler.is_compiling()
            except AttributeError:  # torch.compiler.is_compiling requires PyTorch >= 2.1
                is_dynamo_compiling = False
            if not is_dynamo_compiling:
                with warnings.catch_warnings(record=True):
                    if x.numel() == 0 and self.training:
                        # https://github.com/pytorch/pytorch/issues/12013
                        assert not isinstance(
                            self.norm, torch.nn.SyncBatchNorm
                        ), "SyncBatchNorm does not support empty inputs!"

        x = F.conv2d(
            x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
        )
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)
        return x


class LayerNorm(nn.Module):
    """
    A LayerNorm variant, popularized by Transformers, that performs point-wise mean and
    variance normalization over the channel dimension for inputs that have shape
    (batch_size, channels, height, width).
    https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119  # noqa B950
    """

    def __init__(self, normalized_shape, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x


def get_norm(norm, out_channels):
    """
    Args:
        norm (str or callable): either "LN", or a callable that takes a channel number
            and returns the normalization layer as a nn.Module.
        out_channels (int): number of channels the returned layer normalizes over.

    Returns:
        nn.Module or None: the normalization layer
    """
    if norm is None:
        return None
    if isinstance(norm, str):
        if len(norm) == 0:
            return None
        norm = {
            "LN": lambda channels: LayerNorm(channels),
        }[norm]
    return norm(out_channels)


class SimpleFP(nn.Module):
    """
    This module implements SimpleFPN in :paper:`vitdet`.
    It creates pyramid features built on top of the input feature map.
    """

    def __init__(
        self,
        out_channels,
        scale_factors=(4.0, 2.0, 1.0, 0.5),
        top_block=None,
        norm="LN",
        square_pad=0,
        dim=1024,
        stride=14,
    ):
        """
        Args:
            out_channels (int): number of channels in the output feature maps.
            scale_factors (list[float]): list of scaling factors to upsample or downsample
                the input features for creating pyramid features.
            top_block (nn.Module or None): if provided, an extra operation will be performed
                on the output of the last (smallest resolution) pyramid level, and the result
                will extend the result list. The top_block further downsamples the feature map.
                It must have an attribute "num_levels", meaning the number of extra pyramid
                levels added by this block, and "in_feature", which is a string representing
                its input feature (e.g., "p5").
            norm (str): the normalization to use.
            square_pad (int): if > 0, require input images to be padded to this specific
                square size.
            dim (int): number of channels of the input feature map.
            stride (int): stride of the input feature map relative to the image.
        """
        super().__init__()
        self.scale_factors = scale_factors
        strides = [int(stride / scale) for scale in scale_factors]

        self.stages = []
        use_bias = norm == ""
        for idx, scale in enumerate(scale_factors):
            out_dim = dim
            if scale == 4.0:
                layers = [
                    nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2),
                    get_norm(norm, dim // 2),
                    nn.GELU(),
                    nn.ConvTranspose2d(dim // 2, dim // 4, kernel_size=2, stride=2),
                ]
                out_dim = dim // 4
            elif scale == 2.0:
                layers = [nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2)]
                out_dim = dim // 2
            elif scale == 1.0:
                layers = []
            elif scale == 0.5:
                layers = [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                raise NotImplementedError(f"scale_factor={scale} is not supported yet.")

            layers.extend(
                [
                    Conv2d(
                        out_dim,
                        out_channels,
                        kernel_size=1,
                        bias=use_bias,
                        norm=get_norm(norm, out_channels),
                    ),
                    Conv2d(
                        out_channels,
                        out_channels,
                        kernel_size=3,
                        padding=1,
                        bias=use_bias,
                        norm=get_norm(norm, out_channels),
                    ),
                ]
            )
            layers = nn.Sequential(*layers)

            stage = int(math.log2(strides[idx]))
            self.add_module(f"simfp_{stage}", layers)
            self.stages.append(layers)

        self.top_block = top_block
        # Return feature names are "p<stage>", like ["p2", "p3", ..., "p6"]
        self._out_feature_strides = {"p{}".format(int(math.log2(s))): s for s in strides}
        # top block output feature maps.
        if self.top_block is not None:
            for s in range(stage, stage + self.top_block.num_levels):
                self._out_feature_strides["p{}".format(s + 1)] = 2 ** (s + 1)

        self._out_features = list(self._out_feature_strides.keys())
        self._out_feature_channels = {k: out_channels for k in self._out_features}
        self._size_divisibility = strides[-1]
        self._square_pad = square_pad

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (N, C, H, W). H, W must be a multiple of
                ``self._size_divisibility``.

        Returns:
            dict[str->Tensor]: mapping from feature map name to pyramid feature map tensor
                in high to low resolution order. Returned feature names follow the FPN
                convention "p<stage>", where stage has stride = 2 ** stage,
                e.g. ["p2", "p3", ..., "p6"].
        """
        features = x
        results = []

        for stage in self.stages:
            results.append(stage(features))

        if self.top_block is not None:
            # The top block consumes the pyramid level named by its "in_feature" attribute
            # and appends the extra level(s) it produces.
            top_block_in_feature = results[self._out_features.index(self.top_block.in_feature)]
            results.extend(self.top_block(top_block_in_feature))
        assert len(self._out_features) == len(results)
        return {f: res for f, res in zip(self._out_features, results)}


if __name__ == "__main__":
    """
    Smoke test for SimpleFP (a simple feature pyramid network).
    The test uses an input tensor of shape (1, 1024, 28, 28).
    """
    # Generate a dummy input tensor of shape (batch_size=1, channels=1024, height=28, width=28)
    test_input = torch.randn(1, 1024, 28, 28)

    # Instantiate the SimpleFP feature pyramid module.
    fpn = SimpleFP(
        out_channels=256,  # number of output channels for each FPN level
        norm="LN",         # normalization type, here LayerNorm ("LN")
        square_pad=0,      # square padding size (0 means no padding)
        dim=1024,          # number of input channels from the backbone
        stride=14,         # stride of the input feature map relative to the image
    )

    # ~~~~~ Model Forward Pass ~~~~~
    # Compute FPN outputs under torch.no_grad() to avoid tracking gradients (eval/testing).
    with torch.no_grad():
        output = fpn(test_input)  # dict mapping feature names to pyramid feature tensors

    # ~~~~~ Print Input/Output Information ~~~~~
    print("SimpleFPN Test Results:")
    print(f"Input Shape: {test_input.shape}")
    print("Output feature maps from each FPN stage:")
    for feature_name, feature_map in output.items():
        print(f"  {feature_name}: shape = {feature_map.shape}")
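
    # --- Optional: extend the pyramid with a top block -----------------------------------
    # A minimal sketch (not part of the original code) of a detectron2-style
    # "LastLevelMaxPool" top block, showing the contract described in the SimpleFP
    # docstring: the module must expose "num_levels" (how many extra pyramid levels it
    # adds) and "in_feature" (the name of the pyramid level it consumes). The feature
    # name "p4" is an assumption based on the default stride=14 / scale_factors setup,
    # whose smallest-resolution level is named "p4".
    class LastLevelMaxPool(nn.Module):
        def __init__(self):
            super().__init__()
            self.num_levels = 1     # adds one extra pyramid level
            self.in_feature = "p4"  # consumes the smallest-resolution SimpleFP output

        def forward(self, x):
            # Downsample by 2x with a stride-2 max pool and return the new level(s) as a list.
            return [F.max_pool2d(x, kernel_size=1, stride=2, padding=0)]

    fpn_with_top = SimpleFP(out_channels=256, top_block=LastLevelMaxPool(), dim=1024, stride=14)
    with torch.no_grad():
        output_with_top = fpn_with_top(test_input)

    print("Output feature maps with the extra top-block level:")
    for feature_name, feature_map in output_with_top.items():
        print(f"  {feature_name}: shape = {feature_map.shape}")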