Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
def conv(in_channels, out_channels, kernel_size, bias=False, stride=1):
    """Build a 2-D convolution padded by half of its kernel size.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Integer kernel size; it is floor-divided by 2 to get
            the padding, so with an odd kernel and ``stride=1`` the output
            keeps the input's height and width.
        bias: Whether the convolution has a learnable bias.
        stride: Convolution stride.

    Returns:
        The configured ``nn.Conv2d`` layer.
    """
    half_kernel = kernel_size // 2
    layer = nn.Conv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=half_kernel,
        bias=bias,
    )
    return layer
class CALayer(nn.Module):
    """Channel attention layer.

    Each channel is globally average-pooled to a single value, passed
    through a 1x1-conv bottleneck (conv -> ReLU -> conv -> sigmoid), and the
    resulting per-channel weights multiply the input.

    Args:
        channel: Number of input (and output) channels.
        reduction: Divisor applied to ``channel`` to get the bottleneck width.
        bias: Whether the two 1x1 convolutions have a bias term.
    """

    def __init__(self, channel, reduction=16, bias=False):
        super().__init__()
        # channel // reduction is 0 when channel < reduction, which would
        # create a zero-width bottleneck; keep at least one hidden channel.
        # The width is unchanged whenever channel >= reduction.
        hidden = max(1, channel // reduction)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv_du = nn.Sequential(
            nn.Conv2d(channel, hidden, 1, padding=0, bias=bias),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channel, 1, padding=0, bias=bias),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` scaled channel-wise by the learned attention weights."""
        y = self.avg_pool(x)
        y = self.conv_du(y)
        return x * y
class CAB(nn.Module):
    """Residual block: conv -> activation -> conv -> channel attention.

    Args:
        n_feat: Number of feature channels (input and output).
        kernel_size: Kernel size of the two convolutions.
        reduction: Channel reduction factor passed to ``CALayer``.
        bias: Whether the convolutions have a bias term.
        act: Activation module placed between the two convolutions.
    """

    def __init__(self, n_feat, kernel_size, reduction, bias, act):
        super().__init__()
        first_conv = conv(n_feat, n_feat, kernel_size, bias=bias)
        second_conv = conv(n_feat, n_feat, kernel_size, bias=bias)
        self.body = nn.Sequential(first_conv, act, second_conv)
        self.CA = CALayer(n_feat, reduction, bias=bias)

    def forward(self, x):
        """Return the attended body output plus the input (skip connection)."""
        attended = self.CA(self.body(x))
        return attended + x
class CG(nn.Module):
    """Per-channel Sobel gradient magnitude.

    Every channel of the input is filtered with the horizontal and vertical
    3x3 Sobel kernels (zero padding of 1, so height and width are kept) and
    the result is ``sqrt(clamp(gx**2 + gy**2, min=1e-6))``.
    """

    def __init__(self):
        super().__init__()
        sobel_x = torch.tensor([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]])
        sobel_y = torch.tensor([[-1.0, -2.0, -1.0], [0.0, 0.0, 0.0], [1.0, 2.0, 1.0]])
        # Non-persistent buffers follow module.to()/.cuda() but are left out
        # of state_dict, so existing checkpoints still load unchanged.
        self.register_buffer("sobel_x", sobel_x.unsqueeze(0).unsqueeze(0), persistent=False)
        self.register_buffer("sobel_y", sobel_y.unsqueeze(0).unsqueeze(0), persistent=False)

    def forward(self, image):
        """Return the gradient magnitude of ``image``, one map per channel.

        Args:
            image: 4-D tensor; dimension 1 is treated as the channel axis.

        Returns:
            Tensor with the same shape, dtype and device as ``image``.
        """
        channels = image.shape[1]
        # Match the input's device and dtype so non-float32 inputs work too
        # (a no-op when the buffers already match).
        kernel_x = self.sobel_x.to(device=image.device, dtype=image.dtype)
        kernel_y = self.sobel_y.to(device=image.device, dtype=image.dtype)
        # One grouped convolution applies the same kernel to every channel
        # independently, replacing the per-channel Python loop.
        kernel_x = kernel_x.repeat(channels, 1, 1, 1)
        kernel_y = kernel_y.repeat(channels, 1, 1, 1)
        grad_x = F.conv2d(image, kernel_x, padding=1, groups=channels)
        grad_y = F.conv2d(image, kernel_y, padding=1, groups=channels)
        # The clamp keeps sqrt away from 0, where its gradient is infinite.
        return torch.sqrt(torch.clamp(grad_x**2 + grad_y**2, min=1e-6))
class CGAM(nn.Module):
    """Gradient-guided attention module.

    Predicts a 3-channel residual image from the features, compares the
    gradient maps of the input image and the predicted image, and uses the
    difference to gate the features.

    Args:
        n_feat: Number of feature channels.
        kernel_size: Kernel size of the three convolutions.
        bias: Whether the convolutions have a bias term.
    """

    def __init__(self, n_feat, kernel_size, bias):
        super().__init__()
        img_channels = 3
        self.conv1 = conv(n_feat, n_feat, kernel_size, bias=bias)
        self.conv2 = conv(n_feat, img_channels, kernel_size, bias=bias)
        self.conv3 = conv(img_channels, n_feat, kernel_size, bias=bias)
        self.gradfilter = CG()

    def forward(self, x, x_img):
        """Return ``(gated_features, predicted_image)``.

        Args:
            x: Feature tensor with ``n_feat`` channels.
            x_img: Image tensor that the 3-channel residual is added to.
        """
        feats = self.conv1(x)
        predicted = self.conv2(x) + x_img

        # Gradient maps of the input image and of the prediction.
        input_grad = self.gradfilter(x_img)
        predicted_grad = self.gradfilter(predicted)
        grad_gap = (input_grad - predicted_grad).abs()

        # Lift the gradient difference to feature space and build the gate.
        gate = torch.sigmoid(feats * self.conv3(grad_gap))
        gated = feats + feats * gate
        return gated, predicted
class DownSample(nn.Module):
    """Halve the spatial size and widen the channels.

    Bilinear resampling with scale factor 0.5 is followed by a bias-free
    1x1 convolution mapping ``in_channels`` to ``in_channels + s_factor``.

    Args:
        in_channels: Number of input channels.
        s_factor: Number of channels added by the 1x1 convolution.
    """

    def __init__(self, in_channels, s_factor):
        super().__init__()
        shrink = nn.Upsample(scale_factor=0.5, mode="bilinear", align_corners=False)
        widen = nn.Conv2d(
            in_channels,
            in_channels + s_factor,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        self.down = nn.Sequential(shrink, widen)

    def forward(self, x):
        """Return the downsampled, channel-expanded tensor."""
        out = self.down(x)
        return out
class SkipUpSample(nn.Module):
    """Double the spatial size, narrow the channels, and add a skip tensor.

    Bilinear upsampling with scale factor 2 is followed by a bias-free 1x1
    convolution mapping ``in_channels + s_factor`` to ``in_channels``.

    Args:
        in_channels: Number of output channels (and channels of the skip).
        s_factor: Extra channels carried by the input to be upsampled.
    """

    def __init__(self, in_channels, s_factor):
        super().__init__()
        enlarge = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
        narrow = nn.Conv2d(
            in_channels + s_factor,
            in_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        self.up = nn.Sequential(enlarge, narrow)

    def forward(self, x, y):
        """Upsample ``x`` and return its sum with the skip tensor ``y``."""
        upsampled = self.up(x)
        return upsampled + y
class Encoder(nn.Module):
    """Three-level encoder built from pairs of ``CAB`` blocks.

    Level widths are ``n_feat``, ``n_feat + scale_unetfeats`` and
    ``n_feat + 2 * scale_unetfeats``; a ``DownSample`` sits between
    consecutive levels.

    Args:
        n_feat: Channel count at the first level.
        kernel_size: Kernel size used by every ``CAB``.
        reduction: Channel reduction factor used by every ``CAB``.
        act: Activation module shared by every ``CAB``.
        bias: Whether convolutions have a bias term.
        scale_unetfeats: Channels added at each deeper level.
        csff: When truthy, create the 1x1 ``csff_enc*`` / ``csff_dec*``
            convolutions that ``forward`` needs to fuse externally supplied
            encoder and decoder features.
    """

    def __init__(self, n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff):
        super().__init__()
        widths = (
            n_feat,
            n_feat + scale_unetfeats,
            n_feat + (scale_unetfeats * 2),
        )

        def cab_pair(width):
            # Two consecutive channel-attention blocks of the same width.
            first = CAB(width, kernel_size, reduction, bias=bias, act=act)
            second = CAB(width, kernel_size, reduction, bias=bias, act=act)
            return nn.Sequential(first, second)

        self.encoder_level1 = cab_pair(widths[0])
        self.encoder_level2 = cab_pair(widths[1])
        self.encoder_level3 = cab_pair(widths[2])

        self.down12 = DownSample(widths[0], scale_unetfeats)
        self.down23 = DownSample(widths[1], scale_unetfeats)

        if csff:
            # Registers csff_enc1..3 followed by csff_dec1..3.
            for prefix in ("csff_enc", "csff_dec"):
                for level, width in enumerate(widths, start=1):
                    setattr(
                        self,
                        f"{prefix}{level}",
                        nn.Conv2d(width, width, kernel_size=1, bias=bias),
                    )

    def forward(self, x, encoder_outs=None, decoder_outs=None):
        """Encode ``x`` and return ``[enc1, enc2, enc3]``.

        Args:
            x: Input feature tensor with ``n_feat`` channels.
            encoder_outs: Optional sequence of three feature tensors, one
                per level, added after the matching ``csff_enc*`` conv.
            decoder_outs: Optional sequence of three feature tensors, one
                per level, added after the matching ``csff_dec*`` conv.

        Fusion happens only when both optional arguments are given.
        """
        fuse = not (encoder_outs is None or decoder_outs is None)

        enc1 = self.encoder_level1(x)
        if fuse:
            enc1 = enc1 + self.csff_enc1(encoder_outs[0]) + self.csff_dec1(decoder_outs[0])

        enc2 = self.encoder_level2(self.down12(enc1))
        if fuse:
            enc2 = enc2 + self.csff_enc2(encoder_outs[1]) + self.csff_dec2(decoder_outs[1])

        enc3 = self.encoder_level3(self.down23(enc2))
        if fuse:
            enc3 = enc3 + self.csff_enc3(encoder_outs[2]) + self.csff_dec3(decoder_outs[2])

        return [enc1, enc2, enc3]
class Decoder(nn.Module):
    """Three-level decoder mirroring the encoder's channel widths.

    Each level is a pair of ``CAB`` blocks. Moving up a level, the deeper
    features are upsampled by ``SkipUpSample`` and added to the encoder
    features of that level after a single ``CAB`` (``skip_attn*``).

    Args:
        n_feat: Channel count at the first (shallowest) level.
        kernel_size: Kernel size used by every ``CAB``.
        reduction: Channel reduction factor used by every ``CAB``.
        act: Activation module shared by every ``CAB``.
        bias: Whether convolutions have a bias term.
        scale_unetfeats: Channels added at each deeper level.
    """

    def __init__(self, n_feat, kernel_size, reduction, act, bias, scale_unetfeats):
        super().__init__()
        width1 = n_feat
        width2 = n_feat + scale_unetfeats
        width3 = n_feat + (scale_unetfeats * 2)

        def cab(width):
            # One channel-attention block of the given width.
            return CAB(width, kernel_size, reduction, bias=bias, act=act)

        self.decoder_level1 = nn.Sequential(cab(width1), cab(width1))
        self.decoder_level2 = nn.Sequential(cab(width2), cab(width2))
        self.decoder_level3 = nn.Sequential(cab(width3), cab(width3))

        self.skip_attn1 = cab(width1)
        self.skip_attn2 = cab(width2)

        self.up21 = SkipUpSample(width1, scale_unetfeats)
        self.up32 = SkipUpSample(width2, scale_unetfeats)

    def forward(self, outs):
        """Decode encoder features and return ``[dec1, dec2, dec3]``.

        Args:
            outs: Sequence ``(enc1, enc2, enc3)`` of encoder features,
                shallowest level first.
        """
        enc1, enc2, enc3 = outs

        dec3 = self.decoder_level3(enc3)
        dec2 = self.decoder_level2(self.up32(dec3, self.skip_attn2(enc2)))
        dec1 = self.decoder_level1(self.up21(dec2, self.skip_attn1(enc1)))

        return [dec1, dec2, dec3]