# EDM/src/edm/neck/neck.py
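"""EDM neck: feature aggregation and correlation injection.

Projects the multi-scale backbone features to a common channel width, runs a
LoFTR-style LocalFeatureTransformer on the coarsest (1/32) level, and
progressively fuses the result into the 1/16 and 1/8 levels to produce the
coarse features used for matching.
"""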
import torch
import torch.nn as nn
import torch.nn.functional as F
from .loftr_module.transformer import LocalFeatureTransformer
class Conv2d_BN_Act(nn.Sequential):
    """Conv2d followed by BatchNorm2d, with optional activation and dropout."""

    def __init__(
        self,
        a,
        b,
        ks=1,
        stride=1,
        pad=0,
        dilation=1,
        groups=1,
        bn_weight_init=1,
        act=None,
        drop=None,
    ):
        super().__init__()
        self.inp_channel = a
        self.out_channel = b
        self.ks = ks
        self.pad = pad
        self.stride = stride
        self.dilation = dilation
        self.groups = groups
        # Convolution without bias; the following BatchNorm provides the affine shift.
        self.add_module(
            "c", nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False)
        )
        bn = nn.BatchNorm2d(b)
        nn.init.constant_(bn.weight, bn_weight_init)
        nn.init.constant_(bn.bias, 0)
        self.add_module("bn", bn)
        if act is not None:
            self.add_module("a", act)
        if drop is not None:
            self.add_module("d", nn.Dropout(drop))


class CIM(nn.Module):
    """Feature Aggregation, Correlation Injection Module."""

    def __init__(self, config):
        super().__init__()
        self.block_dims = config["backbone"]["block_dims"]
        self.drop = config["fine"]["droprate"]
        # 1x1 projections that bring the 1/32, 1/16 and 1/8 features to a common width.
        self.fc32 = Conv2d_BN_Act(
            self.block_dims[-1], self.block_dims[-1], 1, drop=self.drop
        )
        self.fc16 = Conv2d_BN_Act(
            self.block_dims[-2], self.block_dims[-1], 1, drop=self.drop
        )
        self.fc8 = Conv2d_BN_Act(
            self.block_dims[-3], self.block_dims[-1], 1, drop=self.drop
        )
        # Sigmoid gates that modulate the finer-resolution features during fusion.
        self.att32 = Conv2d_BN_Act(
            self.block_dims[-1],
            self.block_dims[-1],
            1,
            act=nn.Sigmoid(),
            drop=self.drop,
        )
        self.att16 = Conv2d_BN_Act(
            self.block_dims[-1],
            self.block_dims[-1],
            1,
            act=nn.Sigmoid(),
            drop=self.drop,
        )
        # Depthwise 3x3 + pointwise 1x1 blocks applied after each fusion step.
        self.dwconv16 = nn.Sequential(
            Conv2d_BN_Act(
                self.block_dims[-1],
                self.block_dims[-1],
                ks=3,
                pad=1,
                groups=self.block_dims[-1],
                act=nn.GELU(),
            ),
            Conv2d_BN_Act(self.block_dims[-1], self.block_dims[-1], 1),
        )
        self.dwconv8 = nn.Sequential(
            Conv2d_BN_Act(
                self.block_dims[-1],
                self.block_dims[-1],
                ks=3,
                pad=1,
                groups=self.block_dims[-1],
                act=nn.GELU(),
            ),
            Conv2d_BN_Act(self.block_dims[-1], self.block_dims[-1], 1),
        )
        # LoFTR-style self/cross attention on the coarsest (1/32) features.
        self.loftr_32 = LocalFeatureTransformer(config["neck"])
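
    # Forward pass overview:
    #   1) project the coarsest (1/32) features and run LoFTR-style self/cross attention;
    #   2) upsample the 1/32 features and inject them into the 1/16 features through a
    #      sigmoid attention gate followed by a depthwise-conv block;
    #   3) repeat the same gated fusion from 1/16 into 1/8, yielding the coarse matching
    #      features feat_c0 / feat_c1 for the two images.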
    def forward(self, ms_feats, mask_c0=None, mask_c1=None):
        if len(ms_feats) == 3:  # same image shape: the two images are stacked along the batch dim
            f8, f16, f32 = ms_feats

            f32 = self.fc32(f32)
            f32_0, f32_1 = f32.chunk(2, dim=0)
            f32_0, f32_1 = self.loftr_32(f32_0, f32_1, mask_c0, mask_c1)
            f32 = torch.cat([f32_0, f32_1], dim=0)

            f32_up = F.interpolate(f32, scale_factor=2.0, mode="bilinear")
            att32_up = F.interpolate(self.att32(f32), scale_factor=2.0, mode="bilinear")
            f16 = self.fc16(f16)
            f16 = self.dwconv16(f16 * att32_up + f32_up)

            f16_up = F.interpolate(f16, scale_factor=2.0, mode="bilinear")
            att16_up = F.interpolate(self.att16(f16), scale_factor=2.0, mode="bilinear")
            f8 = self.fc8(f8)
            f8 = self.dwconv8(f8 * att16_up + f16_up)

            feat_c0, feat_c1 = f8.chunk(2)
        elif len(ms_feats) == 6:  # different image shapes: each image has its own pyramid
            f8_0, f16_0, f32_0, f8_1, f16_1, f32_1 = ms_feats

            f32_0 = self.fc32(f32_0)
            f32_1 = self.fc32(f32_1)
            f32_0, f32_1 = self.loftr_32(f32_0, f32_1, mask_c0, mask_c1)

            # Image 0: fuse 1/32 -> 1/16 -> 1/8.
            f8, f16, f32 = f8_0, f16_0, f32_0
            f32_up = F.interpolate(f32, scale_factor=2.0, mode="bilinear")
            att32_up = F.interpolate(self.att32(f32), scale_factor=2.0, mode="bilinear")
            f16 = self.fc16(f16)
            f16 = self.dwconv16(f16 * att32_up + f32_up)
            f16_up = F.interpolate(f16, scale_factor=2.0, mode="bilinear")
            att16_up = F.interpolate(self.att16(f16), scale_factor=2.0, mode="bilinear")
            f8 = self.fc8(f8)
            f8 = self.dwconv8(f8 * att16_up + f16_up)
            feat_c0 = f8

            # Image 1: same fusion path with its own pyramid.
            f8, f16, f32 = f8_1, f16_1, f32_1
            f32_up = F.interpolate(f32, scale_factor=2.0, mode="bilinear")
            att32_up = F.interpolate(self.att32(f32), scale_factor=2.0, mode="bilinear")
            f16 = self.fc16(f16)
            f16 = self.dwconv16(f16 * att32_up + f32_up)
            f16_up = F.interpolate(f16, scale_factor=2.0, mode="bilinear")
            att16_up = F.interpolate(self.att16(f16), scale_factor=2.0, mode="bilinear")
            f8 = self.fc8(f8)
            f8 = self.dwconv8(f8 * att16_up + f16_up)
            feat_c1 = f8

        return feat_c0, feat_c1
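

if __name__ == "__main__":
    # Minimal shape-check sketch, not part of the original module: it only exercises
    # Conv2d_BN_Act, since constructing CIM requires the full EDM config (the
    # "backbone", "fine" and "neck" sections) and the LocalFeatureTransformer
    # settings, which are defined elsewhere in the repo.
    x = torch.randn(2, 64, 32, 32)
    layer = Conv2d_BN_Act(64, 128, ks=3, pad=1, act=nn.GELU(), drop=0.1)
    layer.eval()  # use running BatchNorm stats and disable dropout for the check
    with torch.no_grad():
        y = layer(x)
    print(y.shape)  # expected: torch.Size([2, 128, 32, 32])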