Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from nets.Common import GIE, LMDNet | |
def autopad(k, p=None):
    """Return the padding to use for a convolution with kernel size ``k``.

    An explicit ``p`` is returned unchanged. Otherwise the padding is half the
    kernel size (floor division): an int for an int kernel, or a list with one
    entry per dimension for a sequence of kernel sizes.
    """
    if p is not None:
        return p
    if isinstance(k, int):
        return k // 2
    return [size // 2 for size in k]
class Conv(nn.Module):
    """Conv2d (no bias) -> BatchNorm2d -> activation.

    Args:
        c1: input channels.
        c2: output channels.
        k: kernel size (int or per-dimension sequence).
        s: stride.
        p: padding; ``None`` lets ``autopad`` pick ``k // 2``.
        g: number of groups.
        act: ``True`` -> a fresh ``nn.LeakyReLU(0.1, inplace=True)`` (the default);
            an ``nn.Module`` -> used as-is; anything else -> ``nn.Identity()``.
    """

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):
        super(Conv, self).__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2, eps=0.001, momentum=0.03)
        # The default activation is built per instance. A module used as a
        # default argument would be created once at definition time and the
        # same object would be registered inside every Conv.
        if act is True:
            self.act = nn.LeakyReLU(0.1, inplace=True)
        elif isinstance(act, nn.Module):
            self.act = act
        else:
            self.act = nn.Identity()

    def forward(self, x):
        """Apply conv, batch norm and activation in sequence."""
        return self.act(self.bn(self.conv(x)))

    def fuseforward(self, x):
        """Apply conv and activation, skipping ``bn``.

        NOTE(review): presumably used once BatchNorm has been folded into
        ``conv``'s weights; confirm against the fusing code.
        """
        return self.act(self.conv(x))
| # Multi-branch Pooling Information Fusion Module (Multi_Concat_Block + MP)# | |
| # ------------------------------------------------------------------------- # | |
class Multi_Concat_Block(nn.Module):
    """Multi-branch block: two 1x1 stems, a chain of 3x3 convs, selective concat.

    ``cv1`` (followed by ``GIE``) and ``cv2`` both project the input to ``c_``
    channels. ``cv3`` is a chain of ``n`` 3x3 convs fed from the ``cv2`` branch.
    ``forward`` collects ``[x_1, x_2, cv3[0] out, ..., cv3[n-1] out]``,
    concatenates the entries at positions ``ids`` along dim 1, and fuses them
    with the 1x1 conv ``cv4``.

    Entries not selected by ``ids`` are still computed but never reach the
    output. For example, ``x_1`` and ``x_2`` are unused when ``n >= 4`` and
    ``ids == [-1, -2, -3, -4]``.

    Args:
        c1: input channels.
        c2: output channels of every 3x3 conv in ``cv3``.
        c3: output channels of the block.
        n: number of 3x3 convs in ``cv3``.
        e: expansion ratio; the stems produce ``c_ = int(c2 * e)`` channels.
        ids: indices (negative allowed) into the collected list selecting the
            tensors to concatenate.
    """

    def __init__(self, c1, c2, c3, n=4, e=1, ids=(0,)):
        super(Multi_Concat_Block, self).__init__()
        c_ = int(c2 * e)
        # Copy so a list shared by the caller is never aliased, and a tuple
        # default avoids the mutable-default-argument pitfall.
        self.ids = list(ids)
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = nn.ModuleList(
            [Conv(c_ if i == 0 else c2, c2, 3, 1) for i in range(n)]
        )
        # Derive cv4's input width from the tensors actually selected, rather
        # than assuming exactly two of them are the c_-wide stems.
        self.cv4 = Conv(self._cat_channels(c_, c2, n, self.ids), c3, 1, 1)
        # NOTE(review): GIE is constructed with c1 but applied to cv1's output,
        # which has c_ channels. Confirm this matches GIE's constructor contract.
        self.GIE = GIE(c1)

    @staticmethod
    def _cat_channels(c_, c2, n, ids):
        """Return the channel count of the concatenation selected by ``ids``.

        Mirrors the list built in ``forward``: two ``c_``-channel tensors
        followed by ``n`` tensors with ``c2`` channels. Raises IndexError for
        an out-of-range index.
        """
        widths = [c_, c_] + [c2] * n
        return sum(widths[i] for i in ids)

    def forward(self, x):
        """Run both stems and the conv chain, then concat the ``ids`` picks and fuse."""
        x_1 = self.GIE(self.cv1(x))
        x_2 = self.cv2(x)
        x_all = [x_1, x_2]
        for conv in self.cv3:
            x_2 = conv(x_2)
            x_all.append(x_2)
        return self.cv4(torch.cat([x_all[i] for i in self.ids], 1))
class MP(nn.Module):
    """Downsample by summing a max-pool and an average-pool over the same window.

    Args:
        k: pooling window size, also used as the stride.
    """

    def __init__(self, k=2):
        super().__init__()
        window = k
        self.m1 = nn.MaxPool2d(kernel_size=window, stride=window)
        self.m2 = nn.AvgPool2d(kernel_size=window, stride=window)

    def forward(self, x):
        """Return ``max_pool(x) + avg_pool(x)``."""
        return self.m1(x) + self.m2(x)
| # ------------------------------------------------------------------------- # | |
class Backbone(nn.Module):
    """Feature extractor with an auxiliary dehazing head (``dehze``).

    Stages: ``stem`` (3x3, stride 2), ``dark2`` (3x3 stride-2 conv +
    Multi_Concat_Block), then ``dark3``/``dark4``/``dark5`` (each an MP
    downsample followed by a Multi_Concat_Block). The stage outputs have
    ``transition_channels`` times 2, 4, 8, 16 and 32 channels respectively.

    Args:
        transition_channels: base channel width of the stage outputs.
        block_channels: base width of the 3x3 convs inside each Multi_Concat_Block.
        n: number of 3x3 convs in each Multi_Concat_Block.
    """

    def __init__(self, transition_channels, block_channels, n):
        super().__init__()
        ids = [-1, -2, -3, -4]
        self.stem = Conv(3, transition_channels * 2, 3, 2)
        # Attribute name kept as-is (sic) so existing state_dict keys still load.
        self.dehze = LMDNet()
        self.dark2 = nn.Sequential(
            Conv(transition_channels * 2, transition_channels * 4, 3, 2),
            Multi_Concat_Block(transition_channels * 4, block_channels * 2, transition_channels * 4, n=n, ids=ids),
        )
        self.dark3 = nn.Sequential(
            MP(),
            Multi_Concat_Block(transition_channels * 4, block_channels * 4, transition_channels * 8, n=n, ids=ids),
        )
        self.dark4 = nn.Sequential(
            MP(),
            Multi_Concat_Block(transition_channels * 8, block_channels * 8, transition_channels * 16, n=n, ids=ids),
        )
        self.dark5 = nn.Sequential(
            MP(),
            Multi_Concat_Block(transition_channels * 16, block_channels * 16, transition_channels * 32, n=n, ids=ids),
        )

    @staticmethod
    def _first_half(x):
        """Return the first half of ``x`` along dim 0.

        Replaces a hard-coded ``split((8, 8))``, which only accepted a batch
        of exactly 16. Raises ValueError for an odd batch size.
        """
        batch = x.shape[0]
        if batch % 2:
            raise ValueError(
                f"training batch must split into two equal halves, got batch size {batch}"
            )
        return x[: batch // 2]

    def forward(self, x):
        """Compute the three pyramid features (plus the dehazing output in training).

        In training mode only the first half of the batch is processed; the
        second half (named ``clear_x`` in the original code) is not used here.

        Returns:
            ``(feat1, feat2, feat3)``, the outputs of dark3, dark4 and dark5.
            In training mode a fourth element is appended: ``self.dehze``
            applied to the dark2, dark3 and dark4 outputs.
        """
        if self.training:
            x = self._first_half(x)
        f1 = self.dark2(self.stem(x))
        feat1 = self.dark3(f1)
        feat2 = self.dark4(feat1)
        feat3 = self.dark5(feat2)
        if not self.training:
            # The dehazing output is only returned in training, so skip the
            # head entirely at inference time.
            return feat1, feat2, feat3
        dehazing = self.dehze(f1, feat1, feat2)
        return feat1, feat2, feat3, dehazing