from typing import Sequence, Tuple

import torch
from torch import nn
from torch.nn import functional as F


class FPN(nn.Module):
    """Feature Pyramid Network neck: top-down pathway with lateral connections.

    Each input level is projected to ``out_channels`` by a 1x1 lateral conv.
    Starting from the last (coarsest) level, each merged map is bilinearly
    resized to the spatial size of the next level's input and added to that
    level's lateral projection. Every merged map is then passed through a 3x3
    conv with reflect padding.

    Args:
        in_channels: channel count of each input level, in the same order the
            feature maps are passed to ``forward`` (finest first, coarsest
            last, e.g. ``[C2, C3, C4, C5]``). Any non-zero number of levels is
            accepted; four levels reproduces the original P2-P5 layout.
        out_channels: channel count of every output level.

    Raises:
        ValueError: if ``in_channels`` is empty.
    """

    def __init__(self, in_channels: Sequence[int], out_channels: int):
        super().__init__()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len(in_channels) == 0:
            raise ValueError("FPN needs at least one input level")
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Module names (lat_layers.i / out_layers.i) are kept so existing
        # checkpoints still load.
        self.lat_layers = nn.ModuleList()
        self.out_layers = nn.ModuleList()
        for in_channels_pl in in_channels:
            self.lat_layers.append(
                nn.Conv2d(in_channels_pl, out_channels,
                          kernel_size=1, stride=1, padding=0)
            )
            self.out_layers.append(
                nn.Conv2d(out_channels, out_channels,
                          kernel_size=3, stride=1, padding=1,
                          padding_mode='reflect')
            )

    def forward(self, feats) -> Tuple[torch.Tensor, ...]:
        """Build the pyramid from per-level backbone feature maps.

        Args:
            feats: iterable of feature maps, one per entry of ``in_channels``
                and in the same order. Presumably NCHW tensors (the code
                resizes over ``shape[2:]``) -- confirm against callers.

        Returns:
            Tuple of output maps, one per input level and in the same order,
            each with ``out_channels`` channels and the spatial size of the
            corresponding input.

        Raises:
            ValueError: if the number of feature maps differs from the number
                of configured levels.
        """
        # Materialize so any iterable works, as the original tuple-unpacking did.
        feats = tuple(feats)
        num_levels = len(self.lat_layers)
        if len(feats) != num_levels:
            raise ValueError(
                f"expected {num_levels} feature maps, got {len(feats)}"
            )

        merged = [None] * num_levels
        # The coarsest level has nothing above it: lateral projection only.
        merged[-1] = self.lat_layers[-1](feats[-1])
        # Walk towards finer levels, adding the resized coarser merged map.
        for i in range(num_levels - 2, -1, -1):
            top_down = F.interpolate(
                merged[i + 1], size=feats[i].shape[2:],
                align_corners=False, mode='bilinear',
            )
            merged[i] = top_down + self.lat_layers[i](feats[i])

        return tuple(conv(p) for conv, p in zip(self.out_layers, merged))


def build_fpn(in_channels, out_channels):
    """Construct an :class:`FPN` with the given per-level input channels."""
    return FPN(in_channels, out_channels)