File size: 5,117 Bytes
33569f9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 | import torch
from torch import nn
from torch.nn import functional as F
from .models import register_neck
from .blocks import MaskedConv1D, LayerNorm
@register_neck("fpn")
class FPN1D(nn.Module):
    """
    Feature pyramid network over 1D (temporal) feature maps.

    For every level in ``[start_level, end_level)`` the input feature is
    projected to ``out_channel`` channels by a 1x1 ``MaskedConv1D`` (lateral
    conv). A top-down pass then upsamples each coarser level with
    nearest-neighbour interpolation and adds it to the next finer one.
    Finally every merged level goes through a depthwise 3-tap
    ``MaskedConv1D`` followed by an optional ``LayerNorm``.
    """

    def __init__(
        self,
        in_channels,       # input feature channels, len(in_channels) = # levels
        out_channel,       # output feature channel
        scale_factor=2.0,  # downsampling rate between two fpn levels
        start_level=0,     # start fpn level
        end_level=-1,      # end fpn level
        with_ln=True,      # if to apply layer norm at the end
    ):
        """
        Args:
            in_channels: list or tuple holding the channel count of every
                input level.
            out_channel: channel count of every output level.
            scale_factor: nominal length ratio between consecutive levels.
                It is stored as an attribute. The top-down upsampling matches
                the finer level's actual length, which is the same as
                upsampling by this factor whenever the lengths differ by
                exactly this ratio.
            start_level: first input level used (inclusive).
            end_level: input level at which to stop (exclusive); -1 means
                "up to the last level".
            with_ln: if True, apply ``LayerNorm`` to each output level and
                build the convs without bias. If False, use ``nn.Identity``.
        """
        super().__init__()
        assert isinstance(in_channels, (list, tuple))
        self.in_channels = in_channels
        self.out_channel = out_channel
        self.scale_factor = scale_factor
        self.start_level = start_level
        if end_level == -1:
            self.end_level = len(in_channels)
        else:
            self.end_level = end_level
        assert self.end_level <= len(in_channels)
        assert (self.start_level >= 0) and (self.start_level < self.end_level)

        # Submodules are created level by level (lateral, fpn conv, norm) in
        # the same order as before, so parameter init order and state_dict
        # keys are unchanged.
        self.lateral_convs = nn.ModuleList()
        self.fpn_convs = nn.ModuleList()
        self.fpn_norms = nn.ModuleList()
        for i in range(self.start_level, self.end_level):
            # disable bias if using layer norm
            l_conv = MaskedConv1D(
                in_channels[i], out_channel, 1, bias=(not with_ln))
            # use depthwise conv here for efficiency
            fpn_conv = MaskedConv1D(
                out_channel, out_channel, 3,
                padding=1, bias=(not with_ln), groups=out_channel
            )
            # layer norm for order (B C T)
            if with_ln:
                fpn_norm = LayerNorm(out_channel)
            else:
                fpn_norm = nn.Identity()
            self.lateral_convs.append(l_conv)
            self.fpn_convs.append(fpn_conv)
            self.fpn_norms.append(fpn_norm)

    def forward(self, inputs, fpn_masks):
        """
        Run the pyramid over per-level features.

        Args:
            inputs: list/tuple with one feature tensor per entry of
                ``in_channels``. Each is assumed to be laid out as (B, C, T);
                TODO confirm. Only levels ``start_level .. end_level - 1``
                are used.
            fpn_masks: list/tuple of masks, one per entry of ``inputs``.

        Returns:
            Tuple ``(fpn_feats, new_fpn_masks)``. Both are tuples with one
            entry per used level: the (normalized) output of each depthwise
            fpn conv and the mask that conv returned.
        """
        # inputs must be a list / tuple
        assert len(inputs) == len(self.in_channels)
        assert len(fpn_masks) == len(self.in_channels)

        # Build laterals. The mask returned by each 1x1 conv is discarded;
        # the fpn convs below are fed the original per-level masks instead.
        laterals = []
        for i, l_conv in enumerate(self.lateral_convs):
            level = i + self.start_level
            x, _ = l_conv(inputs[level], fpn_masks[level])
            laterals.append(x)

        # Top-down path: upsample each coarser level to the exact temporal
        # length of the next finer level, then add.
        # - Using the target size instead of a fixed scale_factor avoids a
        #   shape mismatch when lengths are not exact multiples.
        # - The add is out-of-place, so no tensor that autograd may still
        #   need is modified in place.
        used_backbone_levels = len(laterals)
        for i in range(used_backbone_levels - 1, 0, -1):
            laterals[i - 1] = laterals[i - 1] + F.interpolate(
                laterals[i],
                size=laterals[i - 1].shape[-1],
                mode='nearest',
            )

        # fpn conv / norm -> outputs; the masks come from the fpn convs
        fpn_feats = []
        new_fpn_masks = []
        for i, (fpn_conv, fpn_norm) in enumerate(
                zip(self.fpn_convs, self.fpn_norms)):
            x, new_mask = fpn_conv(laterals[i], fpn_masks[i + self.start_level])
            fpn_feats.append(fpn_norm(x))
            new_fpn_masks.append(new_mask)
        return tuple(fpn_feats), tuple(new_fpn_masks)
@register_neck('identity')
class FPNIdentity(nn.Module):
    """
    Pass-through neck.

    Returns the selected input levels unchanged, apart from an optional
    per-level ``LayerNorm``. Masks are forwarded as-is. The constructor
    accepts the same arguments as ``FPN1D`` (presumably so the two necks are
    interchangeable). ``scale_factor`` is only stored and never used. Every
    used input level must already have ``out_channel`` channels, because no
    projection is performed.
    """

    def __init__(
        self,
        in_channels,       # input feature channels, len(in_channels) = # levels
        out_channel,       # output feature channel
        scale_factor=2.0,  # downsampling rate between two fpn levels
        start_level=0,     # start fpn level
        end_level=-1,      # end fpn level
        with_ln=True,      # if to apply layer norm at the end
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channel = out_channel
        self.scale_factor = scale_factor
        self.start_level = start_level
        # -1 selects every level up to the last one
        self.end_level = len(in_channels) if end_level == -1 else end_level
        assert self.end_level <= len(in_channels)
        assert (self.start_level >= 0) and (self.start_level < self.end_level)

        norms = []
        for lvl in range(self.start_level, self.end_level):
            # nothing here changes the channel count, so it must already match
            assert self.in_channels[lvl] == self.out_channel
            # layer norm for order (B C T)
            norms.append(LayerNorm(out_channel) if with_ln else nn.Identity())
        self.fpn_norms = nn.ModuleList(norms)

    def forward(self, inputs, fpn_masks):
        """
        Args:
            inputs: list/tuple of per-level features, one per entry of
                ``in_channels``.
            fpn_masks: list/tuple of per-level masks of the same length.

        Returns:
            ``(fpn_feats, new_fpn_masks)``: tuples covering levels
            ``start_level .. end_level - 1``. Features pass through the
            per-level norm; masks are returned unchanged.
        """
        expected = len(self.in_channels)
        assert len(inputs) == expected
        assert len(fpn_masks) == expected

        first = self.start_level
        fpn_feats = tuple(
            norm(inputs[first + k]) for k, norm in enumerate(self.fpn_norms)
        )
        new_fpn_masks = tuple(
            fpn_masks[first + k] for k in range(len(self.fpn_norms))
        )
        return fpn_feats, new_fpn_masks
|