sdzt's picture
Add source code
33569f9 verified
Raw
History Blame Contribute Delete
5.12 kB
import torch
from torch import nn
from torch.nn import functional as F
from .models import register_neck
from .blocks import MaskedConv1D, LayerNorm
@register_neck("fpn")
class FPN1D(nn.Module):
    """
    Feature pyramid network over a list of per-level features.

    For every used level the module owns a 1x1 lateral conv, a kernel-3
    depthwise conv (groups == out_channel) and an optional layer norm.
    In the forward pass, coarser levels are upsampled with nearest
    interpolation and added into the next finer level before the
    depthwise conv / norm are applied.
    """

    def __init__(
        self,
        in_channels,       # input feature channels, len(in_channels) = # levels
        out_channel,       # output feature channel
        scale_factor=2.0,  # downsampling rate between two fpn levels
        start_level=0,     # start fpn level
        end_level=-1,      # end fpn level
        with_ln=True       # if to apply layer norm at the end
    ):
        super().__init__()
        assert isinstance(in_channels, (list, tuple))

        self.in_channels = in_channels
        self.out_channel = out_channel
        self.scale_factor = scale_factor
        self.start_level = start_level

        # end_level == -1 selects every level up to the last one
        num_levels = len(in_channels)
        self.end_level = num_levels if end_level == -1 else end_level
        assert self.end_level <= num_levels
        assert 0 <= self.start_level < self.end_level

        self.lateral_convs = nn.ModuleList()
        self.fpn_convs = nn.ModuleList()
        self.fpn_norms = nn.ModuleList()

        # convs carry a bias only when no layer norm is used
        use_bias = not with_ln
        for level in range(self.start_level, self.end_level):
            # NOTE: modules are created level by level (lateral, fpn conv,
            # norm) so that parameter initialisation order is unchanged
            self.lateral_convs.append(
                MaskedConv1D(in_channels[level], out_channel, 1, bias=use_bias)
            )
            # depthwise conv here for efficiency
            self.fpn_convs.append(
                MaskedConv1D(
                    out_channel, out_channel, 3,
                    padding=1, bias=use_bias, groups=out_channel
                )
            )
            # layer norm for order (B C T), or a no-op when disabled
            self.fpn_norms.append(
                LayerNorm(out_channel) if with_ln else nn.Identity()
            )

    def forward(self, inputs, fpn_masks):
        """
        Run the pyramid on the levels in [start_level, end_level).

        Args:
            inputs: list / tuple of features, one entry per in_channels entry.
            fpn_masks: list / tuple of masks, one entry per in_channels entry.

        Returns:
            (fpn_feats, new_fpn_masks): two tuples with one entry per used
            level; the masks are the ones returned by the fpn convs.
        """
        num_inputs = len(self.in_channels)
        assert len(inputs) == num_inputs
        assert len(fpn_masks) == num_inputs

        offset = self.start_level

        # lateral projections; the masks returned by these convs are dropped
        laterals = []
        for idx, lateral_conv in enumerate(self.lateral_convs):
            feat, _ = lateral_conv(inputs[idx + offset], fpn_masks[idx + offset])
            laterals.append(feat)

        # top-down path: walk from the coarsest level towards the finest,
        # adding the upsampled coarser level into its finer neighbour in place
        num_used = len(laterals)
        for idx in reversed(range(1, num_used)):
            upsampled = F.interpolate(
                laterals[idx],
                scale_factor=self.scale_factor,
                mode='nearest'
            )
            laterals[idx - 1] += upsampled

        # fpn conv / norm -> outputs
        out_feats, out_masks = [], []
        for idx, (fpn_conv, fpn_norm) in enumerate(
            zip(self.fpn_convs, self.fpn_norms)
        ):
            feat, mask = fpn_conv(laterals[idx], fpn_masks[idx + offset])
            out_feats.append(fpn_norm(feat))
            out_masks.append(mask)

        return tuple(out_feats), tuple(out_masks)
@register_neck('identity')
class FPNIdentity(nn.Module):
    """
    Pass-through neck: no convs and no top-down merging.

    Each used level is only sent through an optional layer norm, so the
    input channels of every used level must already equal out_channel.
    """

    def __init__(
        self,
        in_channels,       # input feature channels, len(in_channels) = # levels
        out_channel,       # output feature channel
        scale_factor=2.0,  # downsampling rate between two fpn levels
        start_level=0,     # start fpn level
        end_level=-1,      # end fpn level
        with_ln=True       # if to apply layer norm at the end
    ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channel = out_channel
        # stored but not read anywhere else in this class
        self.scale_factor = scale_factor
        self.start_level = start_level

        # end_level == -1 selects every level up to the last one
        num_levels = len(in_channels)
        self.end_level = num_levels if end_level == -1 else end_level
        assert self.end_level <= num_levels
        assert 0 <= self.start_level < self.end_level

        self.fpn_norms = nn.ModuleList()
        for level in range(self.start_level, self.end_level):
            # check feat dims: features are passed through without projection
            assert self.in_channels[level] == self.out_channel
            # layer norm for order (B C T), or a no-op when disabled
            self.fpn_norms.append(
                LayerNorm(out_channel) if with_ln else nn.Identity()
            )

    def forward(self, inputs, fpn_masks):
        """
        Normalise the levels in [start_level, end_level).

        Args:
            inputs: list / tuple of features, one entry per in_channels entry.
            fpn_masks: list / tuple of masks, one entry per in_channels entry.

        Returns:
            (fpn_feats, new_fpn_masks): two tuples with one entry per used
            level; the masks are the input masks of those levels, unchanged.
        """
        num_inputs = len(self.in_channels)
        assert len(inputs) == num_inputs
        assert len(fpn_masks) == num_inputs

        offset = self.start_level

        # apply the norms; the masks are forwarded as they are
        out_feats, out_masks = [], []
        for idx, fpn_norm in enumerate(self.fpn_norms):
            out_feats.append(fpn_norm(inputs[idx + offset]))
            out_masks.append(fpn_masks[idx + offset])

        return tuple(out_feats), tuple(out_masks)