File size: 1,389 Bytes
cb0ad2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
from torch import nn
from torch.nn import functional as F


class FPN(nn.Module):
    """Four-level feature pyramid with a top-down pathway.

    Every input level gets a 1x1 "lateral" convolution that projects it to
    ``out_channels`` and a 3x3 "output" convolution (reflect padding, so the
    spatial size is kept) applied after the levels have been merged.

    Args:
        in_channels: Channel counts of the four input feature maps, ordered
            from the first level (``c2``) to the last (``c5``).
        out_channels: Channel count shared by all four outputs.

    Raises:
        AssertionError: If ``in_channels`` does not hold exactly four entries.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        assert len(in_channels) == 4
        self.in_channels = in_channels

        self.lat_layers = nn.ModuleList()
        self.out_layers = nn.ModuleList()
        # Build the two convolutions of a level back to back so parameters
        # are initialised level by level (lateral first, then output).
        for level_channels in in_channels:
            lateral = nn.Conv2d(
                level_channels,
                out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
            )
            self.lat_layers.append(lateral)

            smooth = nn.Conv2d(
                out_channels,
                out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                padding_mode='reflect',
            )
            self.out_layers.append(smooth)

    def forward(self, feats):
        """Fuse four feature maps top-down and return the pyramid.

        Args:
            feats: Iterable of exactly four tensors ``(c2, c3, c4, c5)``.
                Each is assumed to be laid out as (N, C, H, W); only
                ``shape[2:]`` is used as the resize target.

        Returns:
            Tuple ``(p2, p3, p4, p5)``; each output has the spatial size of
            the matching input and ``out_channels`` channels.
        """
        c2, c3, c4, c5 = feats

        # The last level has nothing above it: lateral projection only.
        p5 = self.lat_layers[3](c5)

        # Walk down the pyramid: resize the level above to the current
        # level's spatial size, then add the current lateral projection.
        up5 = F.interpolate(
            p5, size=c4.shape[2:], mode='bilinear', align_corners=False
        )
        p4 = up5 + self.lat_layers[2](c4)

        up4 = F.interpolate(
            p4, size=c3.shape[2:], mode='bilinear', align_corners=False
        )
        p3 = up4 + self.lat_layers[1](c3)

        up3 = F.interpolate(
            p3, size=c2.shape[2:], mode='bilinear', align_corners=False
        )
        p2 = up3 + self.lat_layers[0](c2)

        # Output convolutions run on the merged maps, first level first.
        out2 = self.out_layers[0](p2)
        out3 = self.out_layers[1](p3)
        out4 = self.out_layers[2](p4)
        out5 = self.out_layers[3](p5)
        return out2, out3, out4, out5


def build_fpn(in_channels, out_channels):
    """Construct an ``FPN`` from the given channel configuration.

    Args:
        in_channels: Per-level input channel counts, forwarded to ``FPN``.
        out_channels: Output channel count, forwarded to ``FPN``.

    Returns:
        A newly constructed ``FPN`` instance.
    """
    fpn = FPN(in_channels=in_channels, out_channels=out_channels)
    return fpn