# Copyright (c) Facebook, Inc. and its affiliates.
import fvcore.nn.weight_init as weight_init
import torch.nn.functional as F
from torch import nn


def _assert_strides_are_log2_contiguous(strides):
    """
    Check that every stride is exactly twice the stride before it,
    i.e. that the strides are "contiguous in log2".

    Args:
        strides: sliceable sequence of stride values, in the order to check.

    Raises:
        AssertionError: if any adjacent pair breaks the 2x relation.
    """
    # Walk adjacent (previous, current) pairs; an empty or single-element
    # sequence yields no pairs and therefore passes trivially.
    for previous, stride in zip(strides, strides[1:]):
        assert stride == 2 * previous, "Strides {} {} are not log2 contiguous".format(
            stride, previous
        )


class LastLevelMaxPool(nn.Module):
    """
    Extra top block used in the original FPN: produces a downsampled
    P6 feature from P5.
    """

    def __init__(self):
        super().__init__()
        # Name of the feature this block takes as input.
        # NOTE(review): presumably read by the enclosing FPN to select the
        # input to forward() -- confirm against the caller.
        self.in_feature = "p5"
        # forward() returns a list with this many feature maps.
        self.num_levels = 1

    def forward(self, x):
        """
        Downsample ``x`` and return the result wrapped in a one-element list.

        A max pool with kernel_size=1 and stride=2 (no padding) involves no
        actual max over a window: it keeps every second spatial location.
        """
        pooled = F.max_pool2d(x, kernel_size=1, stride=2, padding=0)
        return [pooled]


class LastLevelP6P7(nn.Module):
    """
    Extra top block used in RetinaNet: produces two additional levels,
    P6 and P7, from the C5 feature.
    """

    def __init__(self, in_channels, out_channels, in_feature="res5"):
        """
        Args:
            in_channels: number of channels of the input given to forward().
            out_channels: number of channels of both produced levels.
            in_feature: name of the feature this block takes as input
                (stored as ``self.in_feature``).
        """
        super().__init__()
        self.num_levels = 2  # forward() returns [p6, p7]
        self.in_feature = in_feature
        # Both convolutions are 3x3 with stride 2 and padding 1. They are
        # constructed first and initialized afterwards, in p6 -> p7 order.
        self.p6 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
        self.p7 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1)
        for conv in (self.p6, self.p7):
            weight_init.c2_xavier_fill(conv)

    def forward(self, c5):
        """
        Compute P6 from ``c5`` and P7 from ReLU(P6).

        Returns:
            list: ``[p6, p7]``. The returned P6 is the raw convolution
            output; the ReLU is applied only on the path feeding P7.
        """
        level6 = self.p6(c5)
        level7 = self.p7(F.relu(level6))
        return [level6, level7]