RDFNet / nets /backbone.py
PolarisFTL's picture
Add nets modules
c79402e verified
Raw
History Blame Contribute Delete
3.67 kB
import torch
import torch.nn as nn
from nets.Common import GIE, LMDNet
def autopad(k, p=None):
    """Return the padding to use for a convolution with kernel size ``k``.

    An explicit ``p`` is returned unchanged. Otherwise the padding is half the
    kernel size (floor division), computed per dimension when ``k`` is a
    sequence of sizes.
    """
    if p is not None:
        return p
    if isinstance(k, int):
        return k // 2
    return [size // 2 for size in k]
class Conv(nn.Module):
    """Conv2d (no bias) -> BatchNorm2d -> activation.

    Args:
        c1: input channels.
        c2: output channels.
        k: kernel size (int or per-dimension sequence).
        s: stride.
        p: padding; ``None`` derives it from ``k`` via ``autopad`` (k // 2).
        g: number of convolution groups.
        act: ``True`` (default) for a new ``LeakyReLU(0.1, inplace=True)``,
            an ``nn.Module`` instance to use as-is, or any other value for
            ``nn.Identity`` (no activation).
    """

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):
        super(Conv, self).__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2, eps=0.001, momentum=0.03)
        if act is True:
            # Build one activation per layer. A module instance used as the
            # default argument would be a single object shared by every Conv
            # (hooks registered on one layer's act would fire for all).
            self.act = nn.LeakyReLU(0.1, inplace=True)
        elif isinstance(act, nn.Module):
            self.act = act
        else:
            self.act = nn.Identity()

    def forward(self, x):
        """Apply convolution, batch normalization, then the activation."""
        return self.act(self.bn(self.conv(x)))

    def fuseforward(self, x):
        """Apply convolution and activation only, skipping ``self.bn``.

        NOTE(review): presumably used after BN has been folded into
        ``self.conv`` — confirm against the fusing code.
        """
        return self.act(self.conv(x))
# Multi-branch Pooling Information Fusion Module (Multi_Concat_Block + MP) #
# ------------------------------------------------------------------------- #
class Multi_Concat_Block(nn.Module):
    """Multi-branch concatenation block with a GIE-processed side branch.

    Two 1x1 convolutions project the input to ``c_ = int(c2 * e)`` channels.
    The first projection is passed through ``GIE``. The second feeds a chain
    of ``n`` 3x3 convolutions with ``c2`` output channels each. All
    intermediate tensors are collected in the order
    ``[gie_branch, conv_branch, chain_0, ..., chain_{n-1}]``. The entries
    picked by ``ids`` are concatenated on dim 1 and fused to ``c3`` channels
    by a final 1x1 convolution.

    Args:
        c1: input channels.
        c2: output channels of each chained 3x3 convolution.
        c3: output channels of the block.
        n: number of chained 3x3 convolutions.
        e: expansion ratio for the two 1x1 projections.
        ids: indices (negative allowed) into the collected tensor list that
            select what is concatenated. ``None`` means ``[0]``.
    """

    def __init__(self, c1, c2, c3, n=4, e=1, ids=None):
        super(Multi_Concat_Block, self).__init__()
        c_ = int(c2 * e)
        # Copy, so blocks constructed from one shared list do not alias it.
        self.ids = [0] if ids is None else list(ids)
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = nn.ModuleList(
            [Conv(c_ if i == 0 else c2, c2, 3, 1) for i in range(n)]
        )
        # Size the fusion conv from the branches ``ids`` actually selects.
        # The old formula c_*2 + c2*(len(ids)-2) assumed both projection
        # branches are always selected.
        self.cv4 = Conv(self._concat_channels(c_, c2, n, self.ids), c3, 1, 1)
        # NOTE(review): GIE is built with c1 but applied to cv1's output,
        # which has c_ channels. Confirm GIE is channel-agnostic, or whether
        # this should be GIE(c_). Left unchanged to keep checkpoints loadable.
        self.GIE = GIE(c1)

    @staticmethod
    def _concat_channels(c_, c2, n, ids):
        """Return the channel count of the concatenation selected by ``ids``."""
        # Widths in collection order: gie branch, conv branch, n chain outputs.
        widths = [c_, c_] + [c2] * n
        return sum(widths[i] for i in ids)

    def forward(self, x):
        """Run all branches, concatenate the selected outputs and fuse them."""
        x_1 = self.GIE(self.cv1(x))
        x_2 = self.cv2(x)
        x_all = [x_1, x_2]
        for conv in self.cv3:
            x_2 = conv(x_2)
            x_all.append(x_2)
        return self.cv4(torch.cat([x_all[i] for i in self.ids], 1))
class MP(nn.Module):
    """Downsample by summing a max-pooled map and an average-pooled map.

    Both pools use a square window of size ``k`` with stride ``k``.
    """

    def __init__(self, k=2):
        super().__init__()
        pool_args = dict(kernel_size=k, stride=k)
        self.m1 = nn.MaxPool2d(**pool_args)
        self.m2 = nn.AvgPool2d(**pool_args)

    def forward(self, x):
        """Return max-pool(x) + avg-pool(x)."""
        return self.m1(x) + self.m2(x)
# ------------------------------------------------------------------------- #
class Backbone(nn.Module):
    """Feature-extraction backbone with an auxiliary dehazing head.

    Layout: ``stem`` (stride-2 Conv) -> ``dark2`` (stride-2 Conv + block) ->
    ``dark3`` / ``dark4`` / ``dark5`` (2x ``MP`` downsampling + block), so the
    dark3/dark4/dark5 outputs are downsampled 8x/16x/32x relative to the
    input. Those three outputs are the returned features. In training mode,
    the dark2, dark3 and dark4 outputs are also passed to ``LMDNet``
    (``self.dehze``), and its result is returned as a fourth element.

    Args:
        transition_channels: base width; dark3/dark4/dark5 emit 8x/16x/32x
            this many channels.
        block_channels: base width of the 3x3 convolution chains inside each
            ``Multi_Concat_Block``.
        n: number of chained 3x3 convolutions per ``Multi_Concat_Block``.
    """

    def __init__(self, transition_channels, block_channels, n):
        super().__init__()
        ids = [-1, -2, -3, -4]
        self.stem = Conv(3, transition_channels * 2, 3, 2)
        # Attribute name kept as-is ("dehze") so checkpoint keys still match.
        self.dehze = LMDNet()
        self.dark2 = nn.Sequential(
            Conv(transition_channels * 2, transition_channels * 4, 3, 2),
            Multi_Concat_Block(transition_channels * 4, block_channels * 2, transition_channels * 4, n=n, ids=ids),
        )
        self.dark3 = nn.Sequential(
            MP(),
            Multi_Concat_Block(transition_channels * 4, block_channels * 4, transition_channels * 8, n=n, ids=ids),
        )
        self.dark4 = nn.Sequential(
            MP(),
            Multi_Concat_Block(transition_channels * 8, block_channels * 8, transition_channels * 16, n=n, ids=ids),
        )
        self.dark5 = nn.Sequential(
            MP(),
            Multi_Concat_Block(transition_channels * 16, block_channels * 16, transition_channels * 32, n=n, ids=ids),
        )

    @staticmethod
    def _hazy_half(x):
        """Return the first half (along dim 0) of a paired training batch.

        The second half was named ``clear_x`` in the original code and is
        presumably the clear counterparts of the first half. It is not used by
        the backbone itself.

        Raises:
            ValueError: if the batch size is odd.
        """
        batch = x.shape[0]
        if batch % 2:
            raise ValueError(
                f"Backbone expects an even training batch (first half processed, "
                f"second half discarded), got batch size {batch}"
            )
        return x[: batch // 2]

    def forward(self, x):
        """Extract multi-scale features.

        In training mode only the first half of the batch is processed (see
        ``_hazy_half``); any even batch size is accepted.

        Returns:
            ``(feat1, feat2, feat3)`` in eval mode, and
            ``(feat1, feat2, feat3, dehazing)`` in training mode. The dehazing
            head runs only in training mode, because its output is not
            returned otherwise.
        """
        if self.training:
            x = self._hazy_half(x)
        x = self.stem(x)
        f1 = self.dark2(x)
        feat1 = self.dark3(f1)
        feat2 = self.dark4(feat1)
        feat3 = self.dark5(feat2)
        if self.training:
            dehazing = self.dehze(f1, feat1, feat2)
            return feat1, feat2, feat3, dehazing
        return feat1, feat2, feat3