| import torch |
| import torch.nn.functional as F |
| from torchvision.ops import FeaturePyramidNetwork |
| from collections import OrderedDict |
|
|
class EfficientFeaturePyramidNetwork(FeaturePyramidNetwork):
    """Feature pyramid network whose top-down pass can stop early.

    The top-down pathway starts at the last (coarsest) input level and walks
    towards the first one. As soon as the level named ``output_level`` has
    been produced the walk stops, so finer levels are never computed. If
    ``output_level`` does not match any input name, every level is computed.

    All inputs, intermediate sums and outputs are forced into
    ``torch.contiguous_format``; gradients flowing into the lateral
    convolution weights are made contiguous through a tensor hook.
    """

    def __init__(
        self,
        in_channels_list,
        out_channels,
        extra_blocks=None,
        norm_layer=None,
        output_level="res3",
    ):
        """Build the pyramid and prepare the lateral convolutions.

        Args:
            in_channels_list: forwarded unchanged to ``FeaturePyramidNetwork``.
            out_channels: forwarded unchanged to ``FeaturePyramidNetwork``.
            extra_blocks: forwarded unchanged to ``FeaturePyramidNetwork``.
            norm_layer: forwarded unchanged to ``FeaturePyramidNetwork``.
            output_level: name of the finest input level that ``forward``
                should still produce.
        """
        super().__init__(
            in_channels_list,
            out_channels,
            extra_blocks,
            norm_layer,
        )
        self.output_level = output_level

        # Walk the whole module tree rather than only the direct children of
        # ``inner_blocks``: the lateral blocks may be wrapper modules that
        # contain the Conv2d instead of being a Conv2d themselves, in which
        # case a check on the direct children would silently match nothing.
        for conv in self.inner_blocks.modules():
            if not isinstance(conv, torch.nn.Conv2d):
                continue
            # Converting in place keeps the module (and therefore groups,
            # device, dtype and initial weights) exactly as the parent class
            # created it; no rebuild-and-copy is needed.
            conv.to(memory_format=torch.contiguous_format).requires_grad_(True)
            conv.weight.register_hook(self._contiguous_grad)

    @staticmethod
    def _contiguous_grad(grad):
        """Return ``grad`` in contiguous layout (used as a tensor hook)."""
        return grad.contiguous()

    def forward(self, x):
        """
        Computes the FPN for a set of feature maps.

        Args:
            x (OrderedDict[Tensor]): feature maps for each feature level.

        Returns:
            results (OrderedDict[Tensor]): feature maps after FPN layers.
                They are ordered from the highest resolution first. Only the
                levels from ``self.output_level`` to the last input level are
                present (plus whatever ``extra_blocks`` appends); all levels
                are present when ``self.output_level`` matches no input name.
        """
        names = list(x.keys())
        feats = [
            v.contiguous(memory_format=torch.contiguous_format)
            for v in x.values()
        ]

        # The last input level seeds the top-down pathway.
        last_inner = self.get_result_from_inner_blocks(feats[-1], -1)
        results = [self.get_result_from_layer_blocks(last_inner, -1)]
        # Index of the finest level produced so far; ``names`` is sliced with
        # it once at the end so it always lines up with ``results``.
        first_kept = len(feats) - 1

        if names[-1] != self.output_level:
            for idx in range(len(feats) - 2, -1, -1):
                inner_lateral = self.get_result_from_inner_blocks(feats[idx], idx)
                feat_shape = inner_lateral.shape[-2:]
                inner_top_down = F.interpolate(
                    last_inner, size=feat_shape, mode="nearest"
                )
                last_inner = (inner_lateral + inner_top_down).contiguous()
                results.insert(
                    0, self.get_result_from_layer_blocks(last_inner, idx)
                )
                first_kept = idx

                # Requested level reached: skip all finer levels.
                if names[idx] == self.output_level:
                    break

        names = names[first_kept:]

        if self.extra_blocks is not None:
            results, names = self.extra_blocks(results, feats, names)

        return OrderedDict(
            (name, feat.contiguous(memory_format=torch.contiguous_format))
            for name, feat in zip(names, results)
        )
|
|