# NOTE(review): removed non-Python paste artifacts ("Spaces:" / "Runtime error")
# that preceded the module; they were extraction residue, not source code.
| import math | |
| import torch | |
| import torch.nn as nn | |
| from torch.nn import functional as F | |
| import warnings | |
class Conv2d(torch.nn.Conv2d):
    """
    A wrapper around :class:`torch.nn.Conv2d` to support empty inputs and more features.
    """

    def __init__(self, *args, **kwargs):
        """
        Extra keyword arguments supported in addition to those in `torch.nn.Conv2d`:

        Args:
            norm (nn.Module, optional): a normalization layer, applied to the
                conv output before the activation.
            activation (callable(Tensor) -> Tensor): a callable activation function,
                applied last.

        It assumes that norm layer is used before activation.
        """
        norm = kwargs.pop("norm", None)
        activation = kwargs.pop("activation", None)
        super().__init__(*args, **kwargs)
        self.norm = norm
        self.activation = activation

    def forward(self, x):
        # torchscript does not support SyncBatchNorm yet
        # https://github.com/pytorch/pytorch/issues/40507
        # and we skip these codes in torchscript since:
        # 1. currently we only support torchscript in evaluation mode
        # 2. features needed by exporting module to torchscript are added in PyTorch 1.6 or
        # later version, `Conv2d` in these PyTorch versions has already supported empty inputs.
        if not torch.jit.is_scripting():
            # Dynamo doesn't support context managers yet, so only run the
            # warnings-guarded check when NOT compiling.
            # BUG FIX: this flag was previously hard-coded to True, which made
            # the empty-input SyncBatchNorm sanity check below unreachable.
            try:
                is_dynamo_compiling = torch.compiler.is_compiling()
            except AttributeError:
                # Older PyTorch without torch.compiler: assume eager execution.
                is_dynamo_compiling = False
            if not is_dynamo_compiling:
                with warnings.catch_warnings(record=True):
                    if x.numel() == 0 and self.training:
                        # https://github.com/pytorch/pytorch/issues/12013
                        assert not isinstance(
                            self.norm, torch.nn.SyncBatchNorm
                        ), "SyncBatchNorm does not support empty inputs!"

        x = F.conv2d(
            x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
        )
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)
        return x
class LayerNorm(nn.Module):
    """
    Channel-wise LayerNorm for NCHW tensors, popularized by Transformers:
    per-position mean/variance normalization over the channel dimension of a
    (batch_size, channels, height, width) input, with learnable per-channel
    scale and shift.
    https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa B950
    """

    def __init__(self, normalized_shape, eps=1e-6):
        super().__init__()
        # Learnable affine parameters, one per channel.
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        mean = x.mean(1, keepdim=True)
        centered = x - mean
        var = centered.pow(2).mean(1, keepdim=True)
        normed = centered / torch.sqrt(var + self.eps)
        # Broadcast (C,) affine params over the spatial dimensions.
        return self.weight[:, None, None] * normed + self.bias[:, None, None]
def get_norm(norm, out_channels):
    """
    Create a normalization layer for `out_channels` channels.

    Args:
        norm (str or callable or None): currently only "LN" (LayerNorm) is
            supported as a string; alternatively a callable that takes a
            channel count and returns the normalization layer as an nn.Module.
            None or "" disables normalization.
            NOTE: unlike the upstream helper this was trimmed from, BN/SyncBN/
            FrozenBN/GN are NOT implemented here.
        out_channels (int): number of channels for the returned layer.

    Returns:
        nn.Module or None: the normalization layer, or None if disabled.

    Raises:
        KeyError: if `norm` is an unsupported string.
    """
    if norm is None:
        return None
    if isinstance(norm, str):
        if not norm:
            return None
        norm = {
            "LN": lambda channels: LayerNorm(channels),
        }[norm]
    return norm(out_channels)
class SimpleFP(nn.Module):
    """
    This module implements SimpleFPN in :paper:`vitdet`.
    It creates pyramid features built on top of the input feature map.
    """

    def __init__(
        self,
        out_channels,
        scale_factors=(4.0, 2.0, 1.0, 0.5),
        top_block=None,
        norm="LN",
        square_pad=0,
        dim=1024,
        stride=14,
    ):
        """
        Args:
            out_channels (int): number of channels in the output feature maps.
            scale_factors (sequence[float]): scaling factors to upsample or downsample
                the input features for creating pyramid features. Only 4.0, 2.0,
                1.0 and 0.5 are supported.
            top_block (nn.Module or None): if provided, an extra operation will
                be performed on the output of the last (smallest resolution)
                pyramid output, and the result will extend the result list. The top_block
                further downsamples the feature map. It must have an attribute
                "num_levels", meaning the number of extra pyramid levels added by
                this block, and "in_feature", which is a string representing
                its input feature (e.g., p5).
            norm (str): the normalization to use.
            square_pad (int): If > 0, require input images to be padded to specific square size.
            dim (int): number of channels of the input feature map.
            stride (int): stride of the input feature map w.r.t. the input image.
        """
        super().__init__()
        # BUG FIX: the default was a mutable list ([4.0, 2.0, 1.0, 0.5]); a
        # tuple default avoids the shared-mutable-default pitfall. Stored as a
        # list for backward compatibility with code reading the attribute.
        self.scale_factors = list(scale_factors)
        strides = [int(stride / scale) for scale in scale_factors]
        self.stages = []
        use_bias = norm == ""
        for idx, scale in enumerate(scale_factors):
            out_dim = dim
            if scale == 4.0:
                # 4x upsampling: two stride-2 transposed convs, halving channels each.
                layers = [
                    nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2),
                    get_norm(norm, dim // 2),
                    nn.GELU(),
                    nn.ConvTranspose2d(dim // 2, dim // 4, kernel_size=2, stride=2),
                ]
                out_dim = dim // 4
            elif scale == 2.0:
                layers = [nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2)]
                out_dim = dim // 2
            elif scale == 1.0:
                layers = []
            elif scale == 0.5:
                layers = [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                raise NotImplementedError(f"scale_factor={scale} is not supported yet.")
            # BUG FIX: get_norm(...) returns None when norm == "", and
            # nn.Sequential rejects None modules — drop disabled norms.
            layers = [m for m in layers if m is not None]
            # Common head: 1x1 conv to out_channels, then a 3x3 conv.
            layers.extend(
                [
                    Conv2d(
                        out_dim,
                        out_channels,
                        kernel_size=1,
                        bias=use_bias,
                        norm=get_norm(norm, out_channels),
                    ),
                    Conv2d(
                        out_channels,
                        out_channels,
                        kernel_size=3,
                        padding=1,
                        bias=use_bias,
                        norm=get_norm(norm, out_channels),
                    ),
                ]
            )
            layers = nn.Sequential(*layers)
            stage = int(math.log2(strides[idx]))
            self.add_module(f"simfp_{stage}", layers)
            self.stages.append(layers)
        self.top_block = top_block
        # Return feature names are "p<stage>", like ["p2", "p3", ..., "p6"]
        self._out_feature_strides = {"p{}".format(int(math.log2(s))): s for s in strides}
        # top block output feature maps.
        if self.top_block is not None:
            for s in range(stage, stage + self.top_block.num_levels):
                self._out_feature_strides["p{}".format(s + 1)] = 2 ** (s + 1)
        self._out_features = list(self._out_feature_strides.keys())
        self._out_feature_channels = {k: out_channels for k in self._out_features}
        self._size_divisibility = strides[-1]
        self._square_pad = square_pad

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.

        Returns:
            list[Tensor]: pyramid feature maps in high to low resolution order,
            one per name in ``self._out_features`` (FPN convention "p<stage>",
            where stage has stride = 2 ** stage, e.g. ["p2", "p3", ..., "p6"]).
        """
        results = [stage(x) for stage in self.stages]
        # BUG FIX: top_block levels were advertised in self._out_features but
        # never computed, so the assert below failed whenever a top_block was
        # supplied. Feed it the pyramid level named by its "in_feature".
        if self.top_block is not None:
            top_in = results[self._out_features.index(self.top_block.in_feature)]
            results.extend(self.top_block(top_in))
        assert len(self._out_features) == len(results)
        return results
| if __name__ == "__main__": | |
| """ | |
| Test the functionality of SimpleFPN (Feature Pyramid Network). | |
| The test uses an input tensor of shape (1, 1024, 28, 28). | |
| """ | |
| import torch | |
| # Generate a dummy input tensor of shape (batch_size=1, channels=1024, height=28, width=28) | |
| test_input = torch.randn(1, 1024, 28, 28) | |
| # Instantiate the SimpleFP (assumed to be the Feature Pyramid module) | |
| # Note: The arguments below should be checked and adapted according to SimpleFP's actual constructor. | |
| fpn = SimpleFP( | |
| out_channels=256, # Number of output channels for FPN layers | |
| norm="LN", # Normalization type, here using LayerNorm ("LN") | |
| square_pad=0, # Square padding size if needed (here 0 means no padding) | |
| dim=1024, # Number of input channels/features from the backbone | |
| stride=14 # Stride setting, typically related to feature scaling | |
| ) | |
| # ~~~~~ Model Forward Pass ~~~~~ | |
| # Compute FPN outputs with torch.no_grad() to avoid tracking gradients (for eval/testing) | |
| with torch.no_grad(): | |
| output = fpn(test_input) # Expected: result is a list of feature tensors at different FPN stages | |
| # ~~~~~ Print Input/Output Information ~~~~~ | |
| print("SimpleFPN Test Results:") | |
| print(f"Input Shape: {test_input.shape}") | |
| print("Output feature maps from each FPN stage:") | |
| # NOTE: If the output is a list, the features are not named (unlike dict). Adapt print info accordingly. | |
| for idx, feature_map in enumerate(output): | |
| print(f" Output stage {idx}: shape = {feature_map.shape}") | |
| # If output were a dict with feature names, iterate as below instead: | |
| # for feature_name, feature_map in output.items(): | |
| # print(f" {feature_name}: {feature_map.shape}") | |