# PengLiu: push inference code (56ef371)
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
import warnings
class Conv2d(torch.nn.Conv2d):
    """
    A wrapper around :class:`torch.nn.Conv2d` that can additionally apply a
    normalization layer and an activation function to the convolution output.
    """

    def __init__(self, *args, **kwargs):
        """
        Extra keyword arguments supported in addition to those in `torch.nn.Conv2d`:

        Args:
            norm (nn.Module, optional): a normalization layer
            activation (callable(Tensor) -> Tensor): a callable activation function

        It assumes that norm layer is used before activation.
        """
        # Strip the extra keywords before delegating, since torch.nn.Conv2d
        # does not accept them.
        norm = kwargs.pop("norm", None)
        activation = kwargs.pop("activation", None)
        super().__init__(*args, **kwargs)

        self.norm = norm
        self.activation = activation

    def forward(self, x):
        """
        Apply the convolution, then ``self.norm`` and ``self.activation`` when set.

        Args:
            x (Tensor): input passed to ``F.conv2d`` with this layer's parameters.

        Returns:
            Tensor: the convolved (and optionally normalized / activated) tensor.

        Raises:
            AssertionError: in eager training mode, if ``x`` is empty and
                ``self.norm`` is a ``torch.nn.SyncBatchNorm``.
        """
        # torchscript does not support SyncBatchNorm yet
        # https://github.com/pytorch/pytorch/issues/40507
        # and we skip these codes in torchscript since:
        # 1. currently we only support torchscript in evaluation mode
        # 2. features needed by exporting module to torchscript are added in PyTorch 1.6 or
        # later version, `Conv2d` in these PyTorch versions has already supported empty inputs.
        if not torch.jit.is_scripting():
            # Dynamo doesn't support context managers yet, so the guarded check
            # below is skipped while compiling. Older PyTorch releases have no
            # `torch.compiler.is_compiling`; there we cannot be compiling.
            if hasattr(torch, "compiler") and hasattr(torch.compiler, "is_compiling"):
                is_dynamo_compiling = torch.compiler.is_compiling()
            else:
                is_dynamo_compiling = False
            if not is_dynamo_compiling:
                # record=True captures any warning raised while evaluating the
                # check instead of showing it to the user.
                with warnings.catch_warnings(record=True):
                    if x.numel() == 0 and self.training:
                        # https://github.com/pytorch/pytorch/issues/12013
                        if isinstance(self.norm, torch.nn.SyncBatchNorm):
                            # Raised explicitly (not via `assert`) so the guard
                            # is not stripped under `python -O`.
                            raise AssertionError(
                                "SyncBatchNorm does not support empty inputs!"
                            )

        x = F.conv2d(
            x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
        )
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)
        return x
class LayerNorm(nn.Module):
    """
    Channel-wise layer normalization for feature maps laid out as
    (batch_size, channels, height, width).

    Each spatial position is normalized to zero mean and unit variance across
    the channel dimension, then scaled and shifted by learnable per-channel
    parameters. This is the LayerNorm variant popularized by Transformers; see
    https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119  # noqa B950
    """

    def __init__(self, normalized_shape, eps=1e-6):
        """
        Args:
            normalized_shape (int): number of channels; sizes the learnable
                ``weight`` (initialized to ones) and ``bias`` (initialized to zeros).
            eps (float): constant added to the variance before the square root.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        """Normalize ``x`` over dim 1 and apply the per-channel affine transform."""
        channel_mean = x.mean(1, keepdim=True)
        centered = x - channel_mean
        # Biased variance: mean of squared deviations over the channel dim.
        channel_var = centered.pow(2).mean(1, keepdim=True)
        normalized = centered / torch.sqrt(channel_var + self.eps)
        # Add two trailing singleton dims so the 1-D parameters broadcast
        # against the (N, C, H, W) tensor.
        scale = self.weight[:, None, None]
        shift = self.bias[:, None, None]
        return scale * normalized + shift
def get_norm(norm, out_channels):
    """
    Build a normalization layer for ``out_channels`` channels.

    Args:
        norm (str or callable or None): ``None`` or ``""`` for no normalization;
            ``"LN"`` for the :class:`LayerNorm` defined in this file;
            or a callable that takes a channel number and returns
            the normalization layer as a nn.Module.
        out_channels (int): channel count handed to the layer factory.

    Returns:
        nn.Module or None: the normalization layer

    Raises:
        KeyError: if ``norm`` is a non-empty string other than ``"LN"``.
    """
    if norm is None:
        return None
    if isinstance(norm, str):
        if not norm:
            return None
        # The lambda defers the LayerNorm lookup until the factory is called.
        factories = {
            "LN": lambda channels: LayerNorm(channels),
        }
        if norm not in factories:
            raise KeyError(
                f"Unsupported norm {norm!r}; supported names: {sorted(factories)}"
            )
        norm = factories[norm]
    return norm(out_channels)
class SimpleFP(nn.Module):
    """
    This module implements SimpleFPN in :paper:`vitdet`.
    It creates pyramid features built on top of the input feature map.
    """

    def __init__(
        self,
        out_channels,
        scale_factors=(4.0, 2.0, 1.0, 0.5),
        top_block=None,
        norm="LN",
        square_pad=0,
        dim=1024,
        stride=14,
    ):
        """
        Args:
            out_channels (int): number of channels in the output feature maps.
            scale_factors (sequence[float]): scaling factors to upsample or downsample
                the input features for creating pyramid features. Each entry must be
                one of 4.0, 2.0, 1.0 or 0.5.
            top_block (nn.Module or None): if provided, an extra operation will
                be performed on a pyramid output, and the result will extend the
                result list. It must have an attribute "num_levels", meaning the
                number of extra pyramid levels added by this block, and should have
                "in_feature", a string naming its input feature (e.g., p5). When
                "in_feature" does not name one of this module's pyramid outputs, the
                last (smallest resolution) pyramid output is used.
            norm (str): the normalization to use; forwarded to ``get_norm``.
            square_pad (int): stored as ``_square_pad``; it is not otherwise
                used by this class.
            dim (int): number of channels of the input feature map.
            stride (int): stride of the input feature map; each level's stride is
                ``int(stride / scale)``.

        Raises:
            ValueError: if ``scale_factors`` is empty.
            NotImplementedError: if a scale factor is not supported.
        """
        super().__init__()
        # Copy into a fresh list so the stored attribute never aliases the
        # caller's sequence (or a shared default).
        scale_factors = list(scale_factors)
        if not scale_factors:
            raise ValueError("scale_factors must contain at least one scale.")
        self.scale_factors = scale_factors

        strides = [int(stride / scale) for scale in scale_factors]

        # Plain list used for iteration in forward(); each stage is also
        # registered through add_module below so its parameters are tracked
        # under the "simfp_<stage>" name.
        self.stages = []
        use_bias = norm == ""
        for idx, scale in enumerate(scale_factors):
            out_dim = dim
            if scale == 4.0:
                upsample = nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2)
                mid_norm = get_norm(norm, dim // 2)
                layers = [
                    upsample,
                    # get_norm returns None when normalization is disabled; a
                    # None entry inside nn.Sequential is not callable, so use a
                    # no-op module that keeps the layer indices unchanged.
                    mid_norm if mid_norm is not None else nn.Identity(),
                    nn.GELU(),
                    nn.ConvTranspose2d(dim // 2, dim // 4, kernel_size=2, stride=2),
                ]
                out_dim = dim // 4
            elif scale == 2.0:
                layers = [nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2)]
                out_dim = dim // 2
            elif scale == 1.0:
                layers = []
            elif scale == 0.5:
                layers = [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                raise NotImplementedError(f"scale_factor={scale} is not supported yet.")

            # Every level ends with a 1x1 projection to out_channels followed
            # by a 3x3 convolution.
            layers.extend(
                [
                    Conv2d(
                        out_dim,
                        out_channels,
                        kernel_size=1,
                        bias=use_bias,
                        norm=get_norm(norm, out_channels),
                    ),
                    Conv2d(
                        out_channels,
                        out_channels,
                        kernel_size=3,
                        padding=1,
                        bias=use_bias,
                        norm=get_norm(norm, out_channels),
                    ),
                ]
            )
            layers = nn.Sequential(*layers)

            stage = int(math.log2(strides[idx]))
            self.add_module(f"simfp_{stage}", layers)
            self.stages.append(layers)

        self.top_block = top_block
        # Return feature names are "p<stage>", where stage = int(log2(stride)).
        self._out_feature_strides = {"p{}".format(int(math.log2(s))): s for s in strides}
        # top block output feature maps.
        if self.top_block is not None:
            last_stage = int(math.log2(strides[-1]))
            for s in range(last_stage, last_stage + self.top_block.num_levels):
                self._out_feature_strides["p{}".format(s + 1)] = 2 ** (s + 1)

        self._out_features = list(self._out_feature_strides.keys())
        self._out_feature_channels = {k: out_channels for k in self._out_features}
        self._size_divisibility = strides[-1]
        self._square_pad = square_pad

    def forward(self, x):
        """
        Args:
            x: input feature map; every pyramid stage is applied to this same tensor.
                Presumably of shape (N, C, H, W) with C == ``dim`` -- not checked here.

        Returns:
            list[Tensor]:
                one feature map per entry of ``scale_factors``, in that order,
                followed by the outputs of ``top_block`` when one was provided.
                The i-th tensor corresponds to the i-th name in ``self._out_features``.
        """
        results = [stage(x) for stage in self.stages]

        if self.top_block is not None:
            # Feed the top block the pyramid level it names; fall back to the
            # last pyramid output when the name is missing or unknown.
            pyramid_names = self._out_features[: len(results)]
            in_feature = getattr(self.top_block, "in_feature", None)
            if in_feature in pyramid_names:
                top_input = results[pyramid_names.index(in_feature)]
            else:
                top_input = results[-1]
            results.extend(self.top_block(top_input))

        assert len(self._out_features) == len(results)
        return results
if __name__ == "__main__":
    # Smoke test: push one random (1, 1024, 28, 28) feature map through
    # SimpleFP and print the shape of every pyramid level it returns.
    dummy_features = torch.randn(1, 1024, 28, 28)

    pyramid_net = SimpleFP(
        out_channels=256,
        norm="LN",
        square_pad=0,
        dim=1024,
        stride=14,
    )

    # Gradients are not needed for a shape check.
    with torch.no_grad():
        pyramid = pyramid_net(dummy_features)

    print("SimpleFPN Test Results:")
    print(f"Input Shape: {dummy_features.shape}")
    print("Output feature maps from each FPN stage:")
    # forward() returns a list, so levels are reported by position.
    for level, level_features in enumerate(pyramid):
        print(f" Output stage {level}: shape = {level_features.shape}")