RDFNet / nets /Common.py
PolarisFTL's picture
Add nets modules
c79402e verified
Raw
History Blame Contribute Delete
9.76 kB
import torch
import torch.nn as nn
from thop import profile
class SiLU(nn.Module):
@staticmethod
def forward(x):
return x * torch.sigmoid(x)
def autopad(k, p=None):
if p is None:
p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
return p
class Conv(nn.Module):
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=nn.LeakyReLU(0.1, inplace=True)): # ch_in, ch_out, kernel, stride, padding, groups
super(Conv, self).__init__()
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
self.bn = nn.BatchNorm2d(c2, eps=0.001, momentum=0.03)
self.act = nn.LeakyReLU(0.1, inplace=True) if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
def forward(self, x):
return self.act(self.bn(self.conv(x)))
def fuseforward(self, x):
return self.act(self.conv(x))
class BasicConv(nn.Module):
def __init__(
self,
in_planes,
out_planes,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
relu=True,
bn=True,
bias=False,
):
super(BasicConv, self).__init__()
self.out_channels = out_planes
self.conv = nn.Conv2d(
in_planes,
out_planes,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
)
self.bn = (
nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True)
if bn
else None
)
self.relu = nn.ReLU() if relu else None
def forward(self, x):
x = self.conv(x)
if self.bn is not None:
x = self.bn(x)
if self.relu is not None:
x = self.relu(x)
return x
class ChannelPool(nn.Module):
def forward(self, x):
return torch.cat(
(torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1
)
class SpatialGate(nn.Module):
def __init__(self):
super(SpatialGate, self).__init__()
kernel_size = 7
self.compress = ChannelPool()
self.spatial = BasicConv(
2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, relu=False
)
def forward(self, x):
x_compress = self.compress(x)
x_out = self.spatial(x_compress)
scale = torch.sigmoid_(x_out)
return x * scale
def autopad(k, p=None):
if p is None:
p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
return p
class Conv(nn.Module):
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=nn.LeakyReLU(0.1, inplace=True)): # ch_in, ch_out, kernel, stride, padding, groups
super(Conv, self).__init__()
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
self.bn = nn.BatchNorm2d(c2, eps=0.001, momentum=0.03)
self.act = nn.LeakyReLU(0.1, inplace=True) if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
def forward(self, x):
return self.act(self.bn(self.conv(x)))
#lighting dehaze network
class LMDNet(nn.Module):
def __init__(self):
super(LMDNet, self).__init__()
# mainNet Architecture
self.AAM = nn.Sequential(
nn.Conv2d(64, 3, 1, 1),
nn.LeakyReLU(inplace=True),
nn.Dropout(0.5)
)
self.AAM_1 = nn.Sequential(
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
nn.Conv2d(128, 32, 1, 1),
nn.LeakyReLU(inplace=True),
nn.Dropout(0.5)
)
self.AAM_2 = nn.Sequential(
nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True),
nn.Conv2d(256, 32, 1, 1),
nn.LeakyReLU(inplace=True),
nn.Dropout(0.5)
)
self.TA = TripletAttention(64)
self.conv = Conv(64, 3, 3, 1)
self.up4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
self.relu = nn.LeakyReLU(0.1, inplace=True)
def forward(self, f1, f2, f3):
t = self.AAM(f1)
f2 = self.AAM_1(f2)
f3 = self.AAM_2(f3)
x1 = f1
x2 = torch.cat([f2, f3], dim=1)
x = x1 + x2
x = self.TA(x)
x = self.conv(x)
dehaze = ((x * t) - x + 1)
out = self.up4(dehaze)
out = self.relu(out)
return out
class TripletAttention(nn.Module):
def __init__(
self,
in_channels,
reduction_ratio=16,
pool_types=["avg", "max"],
no_spatial=False,
):
super(TripletAttention, self).__init__()
self.ChannelGateH = SpatialGate()
self.ChannelGateW = SpatialGate()
self.no_spatial = no_spatial
if not no_spatial:
self.SpatialGate = SpatialGate()
def forward(self, x):
x_perm1 = x.permute(0, 2, 1, 3).contiguous()
x_out1 = self.ChannelGateH(x_perm1)
x_out11 = x_out1.permute(0, 2, 1, 3).contiguous()
x_perm2 = x.permute(0, 3, 2, 1).contiguous()
x_out2 = self.ChannelGateW(x_perm2)
x_out21 = x_out2.permute(0, 3, 2, 1).contiguous()
if not self.no_spatial:
x_out = self.SpatialGate(x)
x_out = (1 / 3) * (x_out + x_out11 + x_out21)
else:
x_out = (1 / 2) * (x_out11 + x_out21)
return x_out
class SiLU(nn.Module):
@staticmethod
def forward(x):
return x * torch.sigmoid(x)
def autopad(k, p=None):
if p is None:
p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
return p
class Conv(nn.Module):
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=nn.LeakyReLU(0.1, inplace=True)):
super(Conv, self).__init__()
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
self.bn = nn.BatchNorm2d(c2, eps=0.001, momentum=0.03)
self.act = nn.LeakyReLU(0.1, inplace=True) if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
def forward(self, x):
return self.act(self.bn(self.conv(x)))
def fuseforward(self, x):
return self.act(self.conv(x))
class GIE(torch.nn.Module):
def __init__(self, epsilon=1e-8):
super(GIE, self).__init__()
self.epsilon = epsilon
def forward(self, x):
# Step 1: Pixel Mean Squared
x_mean = torch.mean(x, dim=(2, 3), keepdim=True)
epsilon = (x - x_mean) ** 2
# Step 2: Global Extraction
epsilon_mean = torch.mean(epsilon, dim=(2, 3), keepdim=False)
epsilon_mean += self.epsilon
# Step 3: Gamma Calculation (Global Extraction)
gamma_t_c = epsilon / epsilon_mean.unsqueeze(-1).unsqueeze(-1)
sigmoid_gamma = torch.sigmoid(gamma_t_c)
output = x * sigmoid_gamma
return output
# Multi-branch Pooling Information Fusion Module
class MPIF(nn.Module):
def __init__(self, c1, c2, c3, s=2, n=4, e=1, ids=[0]):
super(MPIF, self).__init__()
c_ = int(c2 * e)
self.ids = ids
if s == 1:
self.m1 = nn.MaxPool2d(kernel_size=3, stride=s, padding=1)
self.m2 = nn.AvgPool2d(kernel_size=3, stride=s, padding=1)
else:
self.m1 = nn.MaxPool2d(kernel_size=2, stride=s)
self.m2 = nn.AvgPool2d(kernel_size=2, stride=s)
self.cv1 = Conv(c1, c_, 1, 1)
self.cv2 = Conv(c1, c_, 1, 1)
self.cv3 = nn.ModuleList(
[Conv(c_ if i ==0 else c2, c2, 3, 1) for i in range(n)]
)
self.cv4 = Conv(c_ * 2 + c2 * (len(ids) - 2), c3, 1, 1)
self.GIE = GIE(c1)
def forward(self, x):
x1 = self.m1(x)
x2 = self.m2(x)
x = x1 + x2
x_1 = self.cv1(x)
x_1 = self.GIE(x_1)
x_2 = self.cv2(x)
x_all = [x_1, x_2]
for i in range(len(self.cv3)):
x_2 = self.cv3[i](x_2)
x_all.append(x_2)
out = self.cv4(torch.cat([x_all[id] for id in self.ids], 1))
return out
class DilatedConvNet(nn.Module):
def __init__(self, in_channels, out_channels, dilation, padding, kernel_size):
super(DilatedConvNet, self).__init__()
self.dilated_conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding, dilation=dilation)
self.relu = nn.ReLU(inplace=False)
def forward(self, x):
x = self.dilated_conv(x)
x = self.relu(x)
return x
class SPPELAN(nn.Module):
def __init__(self, c1, c2, c3=16):
super().__init__()
self.c = c3
self.cv1 = Conv(c1, c3, 1, 1)
self.cv2 = DilatedConvNet(c3, c3, kernel_size=3, padding=1, dilation=1)
self.cv3 = DilatedConvNet(c3, c3, kernel_size=3, padding=2, dilation=2)
self.cv4 = DilatedConvNet(c3, c3, kernel_size=3, padding=3, dilation=3)
self.cv5 = Conv(4*c3, c2, 1, 1)
def forward(self, x):
y = [self.cv1(x)]
y.extend(m(y[-1]) for m in [self.cv2, self.cv3, self.cv4])
return self.cv5(torch.cat(y, 1))
def print_model_flops_and_params(model, inputs):
flops, params = profile(model, inputs=inputs)
print(f"FLOPs: {flops / 1e9:.2f} GFLOPs")
print(f"Parameters: {params / 1e6:.2f} M")
if __name__ == "__main__":
feat1 = torch.randn(1, 64, 160, 160)
feat2 = torch.randn(1, 128, 80, 80)
feat3 = torch.randn(1, 256, 40, 40)
encoder = LMDNet()
print_model_flops_and_params(encoder, (feat1, feat2, feat3))