Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from thop import profile | |
class SiLU(nn.Module):
    """Sigmoid-weighted linear unit: ``x * sigmoid(x)``."""

    def forward(self, x):
        """Return ``x * sigmoid(x)`` computed element-wise.

        ``forward`` must accept ``self``: ``nn.Module.__call__`` invokes it as
        a bound method, so a signature of ``forward(x)`` raises ``TypeError``
        as soon as the module is called with an input.
        """
        return x * torch.sigmoid(x)
def autopad(k, p=None):
    """Return the padding to use for kernel size ``k``.

    An explicit ``p`` is returned unchanged. When ``p`` is None the padding is
    half the kernel size (rounded down): a single int for an int ``k``, or a
    list with one entry per element when ``k`` is a sequence.
    """
    if p is not None:
        return p
    if isinstance(k, int):
        return k // 2
    return [size // 2 for size in k]
class Conv(nn.Module):
    """Conv2d (no bias) -> BatchNorm2d -> activation.

    Args:
        c1: Number of input channels.
        c2: Number of output channels.
        k: Kernel size (int or per-dimension sequence).
        s: Stride.
        p: Padding; when None it is derived from ``k`` via ``autopad``.
        g: Number of convolution groups.
        act: ``True`` builds a fresh ``LeakyReLU(0.1, inplace=True)``; an
            ``nn.Module`` instance is used as given; anything else disables
            the activation (``nn.Identity``).
    """

    # The default used to be a LeakyReLU instance created once at definition
    # time, which made every Conv built with the default share one activation
    # module object. ``True`` yields the same activation, one per instance.
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super(Conv, self).__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2, eps=0.001, momentum=0.03)
        if act is True:
            self.act = nn.LeakyReLU(0.1, inplace=True)
        elif isinstance(act, nn.Module):
            self.act = act
        else:
            self.act = nn.Identity()

    def forward(self, x):
        """Apply convolution, batch normalisation and activation in that order."""
        return self.act(self.bn(self.conv(x)))

    def fuseforward(self, x):
        """Apply convolution and activation only, skipping ``self.bn``.

        NOTE(review): presumably meant for use after BatchNorm has been folded
        into the convolution weights - confirm against the caller.
        """
        return self.act(self.conv(x))
class BasicConv(nn.Module):
    """Conv2d followed by optional BatchNorm2d and optional ReLU.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels (also stored as ``out_channels``).
        kernel_size, stride, padding, dilation, groups, bias: Forwarded to
            ``nn.Conv2d``.
        relu: When truthy, a ``nn.ReLU`` is applied last; otherwise ``self.relu``
            is None.
        bn: When truthy, a ``nn.BatchNorm2d`` (eps=1e-5, momentum=0.01) follows
            the convolution; otherwise ``self.bn`` is None.
    """

    def __init__(
        self,
        in_planes,
        out_planes,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        relu=True,
        bn=True,
        bias=False,
    ):
        super().__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

        if bn:
            self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True)
        else:
            self.bn = None

        if relu:
            self.relu = nn.ReLU()
        else:
            self.relu = None

    def forward(self, x):
        """Run the convolution, then whichever of bn / relu are enabled."""
        out = self.conv(x)
        for stage in (self.bn, self.relu):
            if stage is not None:
                out = stage(out)
        return out
class ChannelPool(nn.Module):
    """Reduce dim 1 to two planes: its maximum and its mean."""

    def forward(self, x):
        """Return ``cat([max over dim 1, mean over dim 1], dim=1)``.

        The reduced dimension is kept with size 1 in each part, so the result
        has size 2 along dim 1 (max first, mean second).
        """
        peak = x.max(dim=1, keepdim=True)[0]
        average = x.mean(dim=1, keepdim=True)
        return torch.cat((peak, average), dim=1)
class SpatialGate(nn.Module):
    """Gate the input with a sigmoid map built from its dim-1 max/mean planes."""

    def __init__(self):
        super().__init__()
        kernel_size = 7
        # Padding of (k - 1) // 2 with stride 1 keeps the spatial size for the odd kernel.
        padding = (kernel_size - 1) // 2
        self.compress = ChannelPool()
        self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=padding, relu=False)

    def forward(self, x):
        """Return ``x`` multiplied by ``sigmoid(spatial(compress(x)))``."""
        gate = self.spatial(self.compress(x))
        # In-place sigmoid: the conv/bn output is only used as the gate.
        return x * gate.sigmoid_()
def autopad(k, p=None):
    """Pick a padding for kernel size ``k`` unless ``p`` is given explicitly.

    Returns ``p`` as-is when it is not None. Otherwise returns ``k // 2`` for
    an int ``k``, or a list of ``v // 2`` for each ``v`` in a sequence ``k``.
    """
    if p is None:
        if isinstance(k, int):
            p = k // 2
        else:
            p = [extent // 2 for extent in k]
    return p
class Conv(nn.Module):
    """Conv2d (no bias) -> BatchNorm2d -> activation.

    Args:
        c1: Number of input channels.
        c2: Number of output channels.
        k: Kernel size (int or per-dimension sequence).
        s: Stride.
        p: Padding; when None it is derived from ``k`` via ``autopad``.
        g: Number of convolution groups.
        act: ``True`` builds a fresh ``LeakyReLU(0.1, inplace=True)``; an
            ``nn.Module`` instance is used as given; anything else disables
            the activation (``nn.Identity``).
    """

    # A module instance as the default value is created once, at definition
    # time, and would be shared by every Conv relying on the default. Using
    # ``True`` gives each instance its own, otherwise identical, activation.
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super(Conv, self).__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2, eps=0.001, momentum=0.03)
        if act is True:
            self.act = nn.LeakyReLU(0.1, inplace=True)
        elif isinstance(act, nn.Module):
            self.act = act
        else:
            self.act = nn.Identity()

    def forward(self, x):
        """Apply convolution, batch normalisation and activation in that order."""
        return self.act(self.bn(self.conv(x)))
| #lighting dehaze network | |
class LMDNet(nn.Module):
    """Fuse three feature maps and produce a 3-channel output at 4x the size of ``f1``.

    The layers expect ``f1`` / ``f2`` / ``f3`` with 64 / 128 / 256 channels.
    ``f2`` and ``f3`` are upsampled by 2 and 4 before being added to ``f1``,
    so they must be at 1/2 and 1/4 of ``f1``'s spatial size.
    """

    def __init__(self):
        super().__init__()
        # mainNet Architecture
        # Attribute names and creation order are kept so parameter names and
        # initialisation order stay the same.
        self.AAM = self._pointwise_branch(64, 3)
        self.AAM_1 = self._pointwise_branch(128, 32, upscale=2)
        self.AAM_2 = self._pointwise_branch(256, 32, upscale=4)
        self.TA = TripletAttention(64)
        self.conv = Conv(64, 3, 3, 1)
        self.up4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
        self.relu = nn.LeakyReLU(0.1, inplace=True)

    @staticmethod
    def _pointwise_branch(in_channels, out_channels, upscale=None):
        """Build [optional bilinear Upsample] -> 1x1 Conv2d -> LeakyReLU -> Dropout(0.5)."""
        layers = []
        if upscale is not None:
            layers.append(
                nn.Upsample(scale_factor=upscale, mode='bilinear', align_corners=True)
            )
        layers.append(nn.Conv2d(in_channels, out_channels, 1, 1))
        layers.append(nn.LeakyReLU(inplace=True))
        layers.append(nn.Dropout(0.5))
        return nn.Sequential(*layers)

    def forward(self, f1, f2, f3):
        """Return ``relu(up4(x * t - x + 1))``.

        ``t`` is the 3-channel map from ``AAM(f1)``. ``x`` is the 3-channel
        result of ``conv(TA(f1 + cat(AAM_1(f2), AAM_2(f3))))``.
        """
        t = self.AAM(f1)
        # The two 32-channel branches concatenate to 64 channels, matching f1.
        side = torch.cat([self.AAM_1(f2), self.AAM_2(f3)], dim=1)
        fused = f1 + side
        x = self.conv(self.TA(fused))
        dehaze = (x * t) - x + 1
        return self.relu(self.up4(dehaze))
class TripletAttention(nn.Module):
    """Average of SpatialGate outputs taken over three views of the input.

    Two gates see the input with dim 1 swapped with dim 2 and with dim 3
    respectively. Unless ``no_spatial`` is set, a third gate sees the input
    unpermuted.

    ``in_channels``, ``reduction_ratio`` and ``pool_types`` are accepted but
    not used by this implementation.
    """

    def __init__(
        self,
        in_channels,
        reduction_ratio=16,
        pool_types=["avg", "max"],
        no_spatial=False,
    ):
        super().__init__()
        self.ChannelGateH = SpatialGate()
        self.ChannelGateW = SpatialGate()
        self.no_spatial = no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()

    @staticmethod
    def _gate_permuted(x, gate, order):
        """Permute ``x``, apply ``gate``, then undo the permutation.

        Each ``order`` used here swaps exactly two dims, so applying it a
        second time restores the original layout.
        """
        gated = gate(x.permute(*order).contiguous())
        return gated.permute(*order).contiguous()

    def forward(self, x):
        """Return the mean of the enabled gate branches."""
        branch_h = self._gate_permuted(x, self.ChannelGateH, (0, 2, 1, 3))
        branch_w = self._gate_permuted(x, self.ChannelGateW, (0, 3, 2, 1))
        if self.no_spatial:
            return (1 / 2) * (branch_h + branch_w)
        x_out = self.SpatialGate(x)
        return (1 / 3) * (x_out + branch_h + branch_w)
class SiLU(nn.Module):
    """Sigmoid-weighted linear unit: ``x * sigmoid(x)``."""

    def forward(self, x):
        """Return ``x * sigmoid(x)`` computed element-wise.

        The ``self`` parameter is required: the module is called through
        ``nn.Module.__call__``, which invokes ``forward`` as a bound method.
        Without it, any call with an input tensor raises ``TypeError``.
        """
        return x * torch.sigmoid(x)
def autopad(k, p=None):
    """Resolve the padding for kernel size ``k``.

    A non-None ``p`` wins and is returned untouched. Otherwise the result is
    the floor of half the kernel size: an int for an int ``k``, a list of
    ints for a sequence ``k``.
    """
    if p is not None:
        return p
    return k // 2 if isinstance(k, int) else [side // 2 for side in k]
class Conv(nn.Module):
    """Conv2d (no bias) -> BatchNorm2d -> activation.

    Args:
        c1: Number of input channels.
        c2: Number of output channels.
        k: Kernel size (int or per-dimension sequence).
        s: Stride.
        p: Padding; when None it is derived from ``k`` via ``autopad``.
        g: Number of convolution groups.
        act: ``True`` builds a fresh ``LeakyReLU(0.1, inplace=True)``; an
            ``nn.Module`` instance is used as given; anything else disables
            the activation (``nn.Identity``).
    """

    # Default is ``True`` rather than a LeakyReLU instance: a module used as a
    # default value is created once and then shared by every Conv that relies
    # on it. Each instance now gets its own, otherwise identical, activation.
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):
        super(Conv, self).__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2, eps=0.001, momentum=0.03)
        if act is True:
            self.act = nn.LeakyReLU(0.1, inplace=True)
        elif isinstance(act, nn.Module):
            self.act = act
        else:
            self.act = nn.Identity()

    def forward(self, x):
        """Apply convolution, batch normalisation and activation in that order."""
        return self.act(self.bn(self.conv(x)))

    def fuseforward(self, x):
        """Apply convolution and activation only, skipping ``self.bn``.

        NOTE(review): presumably meant for use after BatchNorm has been folded
        into the convolution weights - confirm against the caller.
        """
        return self.act(self.conv(x))
class GIE(torch.nn.Module):
    """Gate each element by a sigmoid of its normalised squared deviation.

    Args:
        epsilon: Constant added to the mean squared deviation before dividing.
    """

    def __init__(self, epsilon=1e-8):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, x):
        """Return ``x * sigmoid(d / (mean(d) + epsilon))``.

        ``d`` is the squared deviation of ``x`` from its mean, with both means
        taken over dims 2 and 3.
        """
        reduce_dims = (2, 3)

        # Step 1: squared deviation from the mean over dims (2, 3)
        centered = x - x.mean(dim=reduce_dims, keepdim=True)
        sq_dev = centered ** 2

        # Step 2: mean squared deviation, kept broadcastable against sq_dev
        denom = sq_dev.mean(dim=reduce_dims, keepdim=True) + self.epsilon

        # Step 3: sigmoid gate applied to the input
        return x * torch.sigmoid(sq_dev / denom)
| # Multi-branch Pooling Information Fusion Module | |
class MPIF(nn.Module):
    """Pool the input two ways, process it in parallel branches and fuse a selection.

    The max- and average-pooled inputs are summed. ``cv1`` followed by ``GIE``
    and ``cv2`` form two stem branches, and ``cv3`` is a chain of ``n`` 3x3
    convolutions applied on top of the ``cv2`` branch. The branches selected
    by ``ids`` are concatenated along dim 1 and projected by ``cv4``.

    Args:
        c1: Input channels.
        c2: Output channels of each ``cv3`` convolution.
        c3: Output channels of the block.
        s: Pooling stride. ``s == 1`` uses 3x3 pools with padding 1; any other
            value uses 2x2 pools with stride ``s``.
        n: Number of chained 3x3 convolutions.
        e: Multiplier giving the stem width ``c_ = int(c2 * e)``.
        ids: Indices into ``[x_1, x_2, cv3[0] out, ..., cv3[n-1] out]`` naming
            the branches to concatenate. Negative indices are allowed.
    """

    def __init__(self, c1, c2, c3, s=2, n=4, e=1, ids=(0,)):
        super(MPIF, self).__init__()
        c_ = int(c2 * e)
        # Copy so the instance never aliases the caller's (or the default) sequence.
        self.ids = list(ids)
        if s == 1:
            self.m1 = nn.MaxPool2d(kernel_size=3, stride=s, padding=1)
            self.m2 = nn.AvgPool2d(kernel_size=3, stride=s, padding=1)
        else:
            self.m1 = nn.MaxPool2d(kernel_size=2, stride=s)
            self.m2 = nn.AvgPool2d(kernel_size=2, stride=s)
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = nn.ModuleList(
            [Conv(c_ if i == 0 else c2, c2, 3, 1) for i in range(n)]
        )
        # Width of every entry of x_all in forward(); cv4's input width is the
        # sum over the selected entries. This equals the former
        # ``c_ * 2 + c2 * (len(ids) - 2)`` whenever that expression matched the
        # concatenated tensor, and is also correct for other selections.
        branch_channels = [c_, c_] + [c2] * n
        fused_channels = sum(branch_channels[i] for i in self.ids)
        self.cv4 = Conv(fused_channels, c3, 1, 1)
        # NOTE(review): GIE's only constructor argument is ``epsilon``, so this
        # passes the channel count c1 as epsilon. Left unchanged because it
        # affects numerical output - confirm whether ``GIE()`` was intended.
        self.GIE = GIE(c1)

    def forward(self, x):
        """Return the fused ``c3``-channel output for input ``x``."""
        x = self.m1(x) + self.m2(x)
        x_1 = self.GIE(self.cv1(x))
        x_2 = self.cv2(x)
        x_all = [x_1, x_2]
        for conv in self.cv3:
            x_2 = conv(x_2)
            x_all.append(x_2)
        return self.cv4(torch.cat([x_all[i] for i in self.ids], 1))
class DilatedConvNet(nn.Module):
    """Stride-1 Conv2d with configurable dilation, followed by ReLU.

    Args:
        in_channels: Input channels of the convolution.
        out_channels: Output channels of the convolution.
        dilation: Dilation passed to ``nn.Conv2d``.
        padding: Padding passed to ``nn.Conv2d``.
        kernel_size: Kernel size passed to ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, dilation, padding, kernel_size):
        super().__init__()
        conv_options = dict(
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
            dilation=dilation,
        )
        self.dilated_conv = nn.Conv2d(in_channels, out_channels, **conv_options)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        """Return ``relu(dilated_conv(x))``."""
        return self.relu(self.dilated_conv(x))
class SPPELAN(nn.Module):
    """1x1 projection followed by a chain of three dilated 3x3 convolutions.

    The projection output and the output of every chain stage are concatenated
    along dim 1 (``4 * c3`` channels) and mapped to ``c2`` channels by ``cv5``.

    Args:
        c1: Input channels.
        c2: Output channels.
        c3: Width of the projection and of each dilated stage.
    """

    def __init__(self, c1, c2, c3=16):
        super().__init__()
        self.c = c3
        self.cv1 = Conv(c1, c3, 1, 1)
        # Padding equals dilation, so each 3x3 stage keeps the spatial size.
        self.cv2 = DilatedConvNet(c3, c3, kernel_size=3, padding=1, dilation=1)
        self.cv3 = DilatedConvNet(c3, c3, kernel_size=3, padding=2, dilation=2)
        self.cv4 = DilatedConvNet(c3, c3, kernel_size=3, padding=3, dilation=3)
        self.cv5 = Conv(4 * c3, c2, 1, 1)

    def forward(self, x):
        """Return ``cv5`` applied to the concatenation of all intermediate maps."""
        feats = [self.cv1(x)]
        # Each stage consumes the previous stage's output, not the projection.
        for stage in (self.cv2, self.cv3, self.cv4):
            feats.append(stage(feats[-1]))
        return self.cv5(torch.cat(feats, 1))
def print_model_flops_and_params(model, inputs):
    """Profile ``model`` on ``inputs`` and print its FLOPs and parameter count.

    Args:
        model: Module handed to ``profile``.
        inputs: Tuple of positional inputs, forwarded as ``profile(inputs=...)``.

    The two numbers returned by ``profile`` are printed as GFLOPs (divided by
    1e9) and millions of parameters (divided by 1e6), two decimals each.
    """
    stats = profile(model, inputs=inputs)
    flops, params = stats
    print(f"FLOPs: {flops / 1e9:.2f} GFLOPs")
    print(f"Parameters: {params / 1e6:.2f} M")
if __name__ == "__main__":
    # Random stand-ins for the three feature maps LMDNet.forward takes:
    # 64 / 128 / 256 channels at 160 / 80 / 40 pixels per side.
    feature_shapes = (
        (1, 64, 160, 160),
        (1, 128, 80, 80),
        (1, 256, 40, 40),
    )
    dummy_feats = tuple(torch.randn(*shape) for shape in feature_shapes)
    encoder = LMDNet()
    print_model_flops_and_params(encoder, dummy_feats)