lsnu's picture
Add files using upload-large-folder tool
5ce8761 verified
import torch
import torch.nn.functional as F
from torchvision.ops import FeaturePyramidNetwork
from collections import OrderedDict
class EfficientFeaturePyramidNetwork(FeaturePyramidNetwork):
    """FeaturePyramidNetwork whose top-down pass stops at ``output_level``.

    The top-down pathway starts at the last input level and walks toward the
    first one. Once the level whose key equals ``output_level`` has been
    produced, the remaining (earlier) levels are skipped entirely. Inputs,
    intermediate sums and outputs are forced into ``torch.contiguous_format``.
    Every ``nn.Conv2d`` inside ``inner_blocks`` is rebuilt with contiguous
    parameters and gets a hook that makes its weight gradient contiguous.

    Args:
        in_channels_list: Channel count of each input level, forwarded to
            ``FeaturePyramidNetwork``.
        out_channels: Channel count of the FPN outputs, forwarded to the base
            class.
        extra_blocks: Optional extra block forwarded to the base class and
            applied after the top-down pass.
        norm_layer: Optional normalization layer forwarded to the base class.
        output_level: Key of the finest level to compute. If no input key
            matches it, every level is computed.
    """

    def __init__(
        self,
        in_channels_list,
        out_channels,
        extra_blocks=None,
        norm_layer=None,
        output_level="res3",
    ):
        super().__init__(
            in_channels_list,
            out_channels,
            extra_blocks,
            norm_layer,
        )
        self.output_level = output_level
        # Depending on the torchvision version, inner_blocks holds either bare
        # nn.Conv2d modules or containers wrapping the conv (e.g.
        # Conv2dNormActivation, used by versions that accept norm_layer).
        # Checking only the top-level entries would skip the wrapped case, so
        # the search below is recursive.
        self._make_convs_contiguous(self.inner_blocks)

    @staticmethod
    def _contiguous_conv_copy(conv):
        """Return a fresh Conv2d configured like ``conv`` holding a contiguous copy of its parameters."""
        new_conv = torch.nn.Conv2d(
            conv.in_channels,
            conv.out_channels,
            conv.kernel_size,
            stride=conv.stride,
            padding=conv.padding,
            dilation=conv.dilation,
            groups=conv.groups,
            bias=conv.bias is not None,
            padding_mode=conv.padding_mode,
            device=conv.weight.device,
            dtype=conv.weight.dtype,
        ).to(memory_format=torch.contiguous_format)
        # Copy the base-class initialization into the fresh (contiguous) storage.
        with torch.no_grad():
            new_conv.weight.copy_(conv.weight)
            if conv.bias is not None:
                new_conv.bias.copy_(conv.bias)
        # Keep the accumulated weight gradient contiguous as well.
        new_conv.weight.register_hook(lambda grad: grad.contiguous())
        return new_conv

    @classmethod
    def _make_convs_contiguous(cls, module):
        """Replace, in place, every nn.Conv2d nested under ``module`` with a contiguous copy."""
        # Snapshot the children: entries are reassigned while iterating.
        for name, child in list(module.named_children()):
            if isinstance(child, torch.nn.Conv2d):
                setattr(module, name, cls._contiguous_conv_copy(child))
            else:
                cls._make_convs_contiguous(child)

    def forward(self, x):
        """Compute the FPN feature maps down to ``self.output_level``.

        Args:
            x (OrderedDict[str, Tensor]): feature maps for each feature level,
                presumably ordered from the highest resolution (first) to the
                lowest (last), as in torchvision's FPN.

        Returns:
            OrderedDict[str, Tensor]: feature maps after the FPN layers, ordered
            from the highest resolution first. Only the levels from
            ``output_level`` (inclusive) through the last input level are
            produced, followed by whatever ``extra_blocks`` appends. If no input
            key equals ``output_level``, all levels are produced. Every tensor
            is in ``torch.contiguous_format``.
        """
        names = list(x.keys())
        x = [v.contiguous(memory_format=torch.contiguous_format) for v in x.values()]

        # Last (coarsest) level: lateral conv, then output conv.
        last_inner = self.get_result_from_inner_blocks(x[-1], -1)
        results = [self.get_result_from_layer_blocks(last_inner, -1)]

        if names[-1] == self.output_level:
            names = names[-1:]
        else:
            for idx in range(len(x) - 2, -1, -1):
                inner_lateral = self.get_result_from_inner_blocks(x[idx], idx)
                feat_shape = inner_lateral.shape[-2:]
                inner_top_down = F.interpolate(last_inner, size=feat_shape, mode="nearest")
                last_inner = (inner_lateral + inner_top_down).contiguous()
                results.insert(0, self.get_result_from_layer_blocks(last_inner, idx))
                # Stop once the requested level exists; earlier (finer) levels
                # are never computed, and names is trimmed to match results.
                if names[idx] == self.output_level:
                    names = names[idx:]
                    break

        if self.extra_blocks is not None:
            results, names = self.extra_blocks(results, x, names)

        return OrderedDict(
            (k, v.contiguous(memory_format=torch.contiguous_format))
            for k, v in zip(names, results)
        )