# RARR-Deraining — models/modules.py
# Building blocks for the Gradio deraining demo (commit 69f6042, SeunghoEum).
import torch
import torch.nn as nn
import torch.nn.functional as F
def conv(in_channels, out_channels, kernel_size, bias=False, stride=1):
    """Build a 2-D convolution whose padding (kernel_size // 2) preserves
    spatial size for odd kernels at stride 1."""
    same_padding = kernel_size // 2
    return nn.Conv2d(in_channels, out_channels, kernel_size, padding=same_padding, bias=bias, stride=stride)
class CALayer(nn.Module):
    """Channel attention (squeeze-and-excite style): global average pool,
    bottleneck 1x1 conv MLP with sigmoid, then per-channel rescaling of x."""

    def __init__(self, channel, reduction=16, bias=False):
        super().__init__()
        squeezed = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv_du = nn.Sequential(
            nn.Conv2d(channel, squeezed, 1, padding=0, bias=bias),
            nn.ReLU(inplace=True),
            nn.Conv2d(squeezed, channel, 1, padding=0, bias=bias),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # Attention weights are (N, C, 1, 1); broadcasting scales each channel.
        attention = self.conv_du(self.avg_pool(x))
        return x * attention
class CAB(nn.Module):
    """Channel Attention Block: conv-act-conv, channel attention, residual add."""

    def __init__(self, n_feat, kernel_size, reduction, bias, act):
        super().__init__()
        layers = [
            conv(n_feat, n_feat, kernel_size, bias=bias),
            act,
            conv(n_feat, n_feat, kernel_size, bias=bias),
        ]
        self.body = nn.Sequential(*layers)
        self.CA = CALayer(n_feat, reduction, bias=bias)

    def forward(self, x):
        out = self.CA(self.body(x))
        return out + x
class CG(nn.Module):
    """Per-channel Sobel gradient-magnitude filter.

    For every input channel independently, computes sqrt(Gx^2 + Gy^2) where
    Gx/Gy are the horizontal/vertical Sobel responses (zero-padded borders).
    The squared magnitude is clamped from below at 1e-6 before the sqrt so
    the backward pass of sqrt stays finite where the gradient is zero.
    """

    def __init__(self):
        super().__init__()
        sobel_x = torch.tensor([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]])
        sobel_y = torch.tensor([[-1.0, -2.0, -1.0], [0.0, 0.0, 0.0], [1.0, 2.0, 1.0]])
        # Non-persistent buffers: they move with .to(device)/.cuda() like the
        # rest of the module, but stay out of state_dict, so checkpoints saved
        # with the old plain-attribute version still load unchanged.
        self.register_buffer("sobel_x", sobel_x.view(1, 1, 3, 3), persistent=False)
        self.register_buffer("sobel_y", sobel_y.view(1, 1, 3, 3), persistent=False)

    def forward(self, image):
        """Return a tensor shaped like `image` holding per-channel Sobel
        gradient magnitudes."""
        channels = image.shape[1]
        # Grouped conv applies the same 3x3 kernel to each channel separately,
        # replacing the original per-channel Python loop with one fused call.
        # expand() creates views, not copies; .to() is a no-op when the buffer
        # is already on the right device (kept as a safety net for callers
        # that pass inputs on a different device than the module).
        weight_x = self.sobel_x.to(image.device).expand(channels, 1, 3, 3)
        weight_y = self.sobel_y.to(image.device).expand(channels, 1, 3, 3)
        grad_x = F.conv2d(image, weight_x, padding=1, groups=channels)
        grad_y = F.conv2d(image, weight_y, padding=1, groups=channels)
        return torch.sqrt(torch.clamp(grad_x**2 + grad_y**2, min=1e-6))
class CGAM(nn.Module):
    """Gradient-attention module.

    Projects features to an RGB residual estimate, compares its Sobel
    gradients with those of the rainy input, and uses the gradient gap to
    gate the feature map. Returns (gated features, restored image estimate).
    """

    def __init__(self, n_feat, kernel_size, bias):
        super().__init__()
        self.conv1 = conv(n_feat, n_feat, kernel_size, bias=bias)
        self.conv2 = conv(n_feat, 3, kernel_size, bias=bias)
        self.conv3 = conv(3, n_feat, kernel_size, bias=bias)
        self.gradfilter = CG()

    def forward(self, x, x_img):
        feat = self.conv1(x)
        restored = self.conv2(x) + x_img  # residual RGB estimate on top of the input image
        grad_gap = torch.abs(self.gradfilter(x_img) - self.gradfilter(restored))
        gap_feat = self.conv3(grad_gap)
        gate = torch.sigmoid(feat * gap_feat)
        feat = feat + feat * gate
        return feat, restored
class DownSample(nn.Module):
    """Halve spatial resolution (bilinear) and widen channels by s_factor
    with a bias-free 1x1 conv."""

    def __init__(self, in_channels, s_factor):
        super().__init__()
        resize = nn.Upsample(scale_factor=0.5, mode="bilinear", align_corners=False)
        project = nn.Conv2d(in_channels, in_channels + s_factor, 1, stride=1, padding=0, bias=False)
        self.down = nn.Sequential(resize, project)

    def forward(self, x):
        return self.down(x)
class SkipUpSample(nn.Module):
    """Double spatial resolution (bilinear), narrow channels back by s_factor
    with a bias-free 1x1 conv, then add the skip-connection tensor y."""

    def __init__(self, in_channels, s_factor):
        super().__init__()
        resize = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
        project = nn.Conv2d(in_channels + s_factor, in_channels, 1, stride=1, padding=0, bias=False)
        self.up = nn.Sequential(resize, project)

    def forward(self, x, y):
        upsampled = self.up(x)
        return upsampled + y
class Encoder(nn.Module):
    """Three-level U-Net encoder built from pairs of CABs, with optional
    cross-stage feature fusion (CSFF) from a previous stage's encoder and
    decoder outputs. Returns the feature maps of all three levels."""

    def __init__(self, n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff):
        super().__init__()
        feats = [n_feat, n_feat + scale_unetfeats, n_feat + 2 * scale_unetfeats]

        def cab_pair(channels):
            # Two stacked channel-attention blocks at a single resolution level.
            return nn.Sequential(
                CAB(channels, kernel_size, reduction, bias=bias, act=act),
                CAB(channels, kernel_size, reduction, bias=bias, act=act),
            )

        self.encoder_level1 = cab_pair(feats[0])
        self.encoder_level2 = cab_pair(feats[1])
        self.encoder_level3 = cab_pair(feats[2])
        self.down12 = DownSample(feats[0], scale_unetfeats)
        self.down23 = DownSample(feats[1], scale_unetfeats)
        if csff:
            # 1x1 projections injecting the previous stage's features per level.
            self.csff_enc1 = nn.Conv2d(feats[0], feats[0], kernel_size=1, bias=bias)
            self.csff_enc2 = nn.Conv2d(feats[1], feats[1], kernel_size=1, bias=bias)
            self.csff_enc3 = nn.Conv2d(feats[2], feats[2], kernel_size=1, bias=bias)
            self.csff_dec1 = nn.Conv2d(feats[0], feats[0], kernel_size=1, bias=bias)
            self.csff_dec2 = nn.Conv2d(feats[1], feats[1], kernel_size=1, bias=bias)
            self.csff_dec3 = nn.Conv2d(feats[2], feats[2], kernel_size=1, bias=bias)

    def forward(self, x, encoder_outs=None, decoder_outs=None):
        fuse = (encoder_outs is not None) and (decoder_outs is not None)
        enc1 = self.encoder_level1(x)
        if fuse:
            enc1 = enc1 + self.csff_enc1(encoder_outs[0]) + self.csff_dec1(decoder_outs[0])
        enc2 = self.encoder_level2(self.down12(enc1))
        if fuse:
            enc2 = enc2 + self.csff_enc2(encoder_outs[1]) + self.csff_dec2(decoder_outs[1])
        enc3 = self.encoder_level3(self.down23(enc2))
        if fuse:
            enc3 = enc3 + self.csff_enc3(encoder_outs[2]) + self.csff_dec3(decoder_outs[2])
        return [enc1, enc2, enc3]
class Decoder(nn.Module):
    """Three-level U-Net decoder mirroring Encoder: CAB pairs per level, with
    channel-attention-gated skip connections merged via SkipUpSample.
    Takes the encoder's [enc1, enc2, enc3] and returns [dec1, dec2, dec3]."""

    def __init__(self, n_feat, kernel_size, reduction, act, bias, scale_unetfeats):
        super().__init__()
        feats = [n_feat, n_feat + scale_unetfeats, n_feat + 2 * scale_unetfeats]

        def cab_pair(channels):
            # Two stacked channel-attention blocks at a single resolution level.
            return nn.Sequential(
                CAB(channels, kernel_size, reduction, bias=bias, act=act),
                CAB(channels, kernel_size, reduction, bias=bias, act=act),
            )

        self.decoder_level1 = cab_pair(feats[0])
        self.decoder_level2 = cab_pair(feats[1])
        self.decoder_level3 = cab_pair(feats[2])
        # Channel attention applied to each encoder skip before merging.
        self.skip_attn1 = CAB(feats[0], kernel_size, reduction, bias=bias, act=act)
        self.skip_attn2 = CAB(feats[1], kernel_size, reduction, bias=bias, act=act)
        self.up21 = SkipUpSample(feats[0], scale_unetfeats)
        self.up32 = SkipUpSample(feats[1], scale_unetfeats)

    def forward(self, outs):
        enc1, enc2, enc3 = outs
        dec3 = self.decoder_level3(enc3)
        dec2 = self.decoder_level2(self.up32(dec3, self.skip_attn2(enc2)))
        dec1 = self.decoder_level1(self.up21(dec2, self.skip_attn1(enc1)))
        return [dec1, dec2, dec3]