import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import models
from models import register
def default_conv(in_channels, out_channels, kernel_size, bias=True):
    return nn.Conv2d(in_channels, out_channels, kernel_size, padding=(kernel_size // 2), bias=bias)
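
# Note: with the odd kernel sizes used throughout this file, padding = kernel_size // 2
# gives 'same' convolutions, so every conv built through default_conv preserves
# spatial resolution.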
## Channel Attention (CA) Layer
class CALayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(CALayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv_du = nn.Sequential(
            nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
            nn.Sigmoid()
        )

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.conv_du(y)
        return x * y
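
# Note: CALayer is a squeeze-and-excitation gate. For an input of shape
# (B, C, H, W), avg_pool squeezes it to (B, C, 1, 1), the two 1x1 convs form a
# C -> C/reduction -> C bottleneck, and the sigmoid output rescales each
# channel of x. A hypothetical shape check:
# >>> ca = CALayer(64)
# >>> ca(torch.randn(2, 64, 32, 32)).shape   # torch.Size([2, 64, 32, 32])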
## Residual Channel Attention Block (RCAB)
class RCAB(nn.Module):
    def __init__(
            self, conv, n_feat, kernel_size, reduction,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
        super(RCAB, self).__init__()
        modules_body = []
        for i in range(2):
            modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))
            if bn:
                modules_body.append(nn.BatchNorm2d(n_feat))
            if i == 0:
                modules_body.append(act)
        modules_body.append(CALayer(n_feat, reduction))
        self.body = nn.Sequential(*modules_body)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.body(x)
        res += x
        return res
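
# Note: res_scale is stored but never applied in forward(); with the default
# res_scale=1 used everywhere in this file, multiplying by it would be a no-op
# anyway.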
## Residual Group (RG)
class ResidualGroup(nn.Module):
    def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks):
        super(ResidualGroup, self).__init__()
        modules_body = [
            RCAB(
                conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1)
            for _ in range(n_resblocks)]
        modules_body.append(conv(n_feat, n_feat, kernel_size))
        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        res = self.body(x)
        res += x
        return res
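
# Note: each group stacks n_resblocks RCABs plus one trailing conv; the
# group-level identity (res += x) is RCAN's short skip connection, while the
# long skip over all groups is added in ArbRCAN.forward below.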
class SA_upsample(nn.Module):
    def __init__(self, channels, num_experts=4, bias=False):
        super(SA_upsample, self).__init__()
        self.bias = bias
        self.num_experts = num_experts
        self.channels = channels

        # experts
        weight_compress = []
        for i in range(num_experts):
            weight_compress.append(nn.Parameter(torch.Tensor(channels // 8, channels, 1, 1)))
            nn.init.kaiming_uniform_(weight_compress[i], a=math.sqrt(5))
        self.weight_compress = nn.Parameter(torch.stack(weight_compress, 0))

        weight_expand = []
        for i in range(num_experts):
            weight_expand.append(nn.Parameter(torch.Tensor(channels, channels // 8, 1, 1)))
            nn.init.kaiming_uniform_(weight_expand[i], a=math.sqrt(5))
        self.weight_expand = nn.Parameter(torch.stack(weight_expand, 0))

        # two FC layers
        self.body = nn.Sequential(
            nn.Conv2d(4, 64, 1, 1, 0, bias=True),
            nn.ReLU(True),
            nn.Conv2d(64, 64, 1, 1, 0, bias=True),
            nn.ReLU(True),
        )
        # routing head
        self.routing = nn.Sequential(
            nn.Conv2d(64, num_experts, 1, 1, 0, bias=True),
            nn.Sigmoid()
        )
        # offset head
        self.offset = nn.Conv2d(64, 2, 1, 1, 0, bias=True)
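
    # Note on shapes: self.weight_compress stacks to
    # (num_experts, channels // 8, channels, 1, 1) and self.weight_expand to
    # (num_experts, channels, channels // 8, 1, 1). The 4 input channels of
    # self.body encode (1 / scale2, 1 / scale, relative-h offset,
    # relative-w offset), built per HR pixel in forward() below.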
    def forward(self, x, scale, scale2):
        b, c, h, w = x.size()

        # (1) coordinates in LR space
        ## coordinates in HR space
        coor_hr = [torch.arange(0, round(h * scale), 1).unsqueeze(0).float().to(x.device),
                   torch.arange(0, round(w * scale2), 1).unsqueeze(0).float().to(x.device)]
        ## coordinates in LR space
        coor_h = ((coor_hr[0] + 0.5) / scale) - (torch.floor((coor_hr[0] + 0.5) / scale + 1e-3)) - 0.5
        coor_h = coor_h.permute(1, 0)
        coor_w = ((coor_hr[1] + 0.5) / scale2) - (torch.floor((coor_hr[1] + 0.5) / scale2 + 1e-3)) - 0.5

        input = torch.cat((
            torch.ones_like(coor_h).expand([-1, round(scale2 * w)]).unsqueeze(0) / scale2,
            torch.ones_like(coor_h).expand([-1, round(scale2 * w)]).unsqueeze(0) / scale,
            coor_h.expand([-1, round(scale2 * w)]).unsqueeze(0),
            coor_w.expand([round(scale * h), -1]).unsqueeze(0)
        ), 0).unsqueeze(0)

        # (2) predict filters and offsets
        embedding = self.body(input)
        ## offsets
        offset = self.offset(embedding)
        ## filters
        routing_weights = self.routing(embedding)
        routing_weights = routing_weights.view(self.num_experts, round(scale * h) * round(scale2 * w)).transpose(0, 1)  # (h*w) * n

        weight_compress = self.weight_compress.view(self.num_experts, -1)
        weight_compress = torch.matmul(routing_weights, weight_compress)
        weight_compress = weight_compress.view(1, round(scale * h), round(scale2 * w), self.channels // 8, self.channels)

        weight_expand = self.weight_expand.view(self.num_experts, -1)
        weight_expand = torch.matmul(routing_weights, weight_expand)
        weight_expand = weight_expand.view(1, round(scale * h), round(scale2 * w), self.channels, self.channels // 8)

        # (3) grid sample & spatially varying filtering
        ## grid sample
        fea0 = grid_sample(x, offset, scale, scale2)     ## b * c * h * w
        fea = fea0.unsqueeze(-1).permute(0, 2, 3, 1, 4)  ## b * h * w * c * 1
        ## spatially varying filtering
        out = torch.matmul(weight_compress.expand([b, -1, -1, -1, -1]), fea)
        out = torch.matmul(weight_expand.expand([b, -1, -1, -1, -1]), out).squeeze(-1)
        return out.permute(0, 3, 1, 2) + fea0
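
# Shape walkthrough for SA_upsample.forward with illustrative values
# b=1, c=64, h=w=32 and scale=scale2=1.5 (so the HR grid is 48x48):
#   input            (1, 4, 48, 48)     per-pixel scale + offset encoding
#   embedding        (1, 64, 48, 48)
#   offset           (1, 2, 48, 48)
#   routing_weights  (48*48, 4)         one expert mixture per HR pixel
#   weight_compress  (1, 48, 48, 8, 64), weight_expand (1, 48, 48, 64, 8)
#   fea              (1, 48, 48, 64, 1) -> out (1, 64, 48, 48)
# Each HR pixel thus gets its own compress/expand filter pair, plus a residual
# connection to the plainly resampled feature fea0.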
class SA_adapt(nn.Module):
    def __init__(self, channels):
        super(SA_adapt, self).__init__()
        self.mask = nn.Sequential(
            nn.Conv2d(channels, 16, 3, 1, 1),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.AvgPool2d(2),
            nn.Conv2d(16, 16, 3, 1, 1),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.Conv2d(16, 16, 3, 1, 1),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(16, 1, 3, 1, 1),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        self.adapt = SA_conv(channels, channels, 3, 1, 1)

    def forward(self, x, scale, scale2):
        mask = self.mask(x)
        adapted = self.adapt(x, scale, scale2)
        return x + adapted * mask
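
# Note: the mask branch predicts a single-channel spatial gate in [0, 1], so
# SA_adapt blends the scale-adapted features into x only where the gate is
# high. The AvgPool2d(2) / Upsample(scale_factor=2) pair assumes even H and W;
# odd-sized feature maps would come back one pixel short and break the
# elementwise product with `adapted`.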
class SA_conv(nn.Module):
    def __init__(self, channels_in, channels_out, kernel_size=3, stride=1, padding=1, bias=False, num_experts=4):
        super(SA_conv, self).__init__()
        self.channels_out = channels_out
        self.channels_in = channels_in
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.num_experts = num_experts
        self.bias = bias

        # FC layers to generate routing weights
        self.routing = nn.Sequential(
            nn.Linear(2, 64),
            nn.ReLU(True),
            nn.Linear(64, num_experts),
            nn.Softmax(1)
        )

        # initialize experts
        weight_pool = []
        for i in range(num_experts):
            weight_pool.append(nn.Parameter(torch.Tensor(channels_out, channels_in, kernel_size, kernel_size)))
            nn.init.kaiming_uniform_(weight_pool[i], a=math.sqrt(5))
        self.weight_pool = nn.Parameter(torch.stack(weight_pool, 0))

        if bias:
            self.bias_pool = nn.Parameter(torch.Tensor(num_experts, channels_out))
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight_pool)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias_pool, -bound, bound)

    def forward(self, x, scale, scale2):
        # generate routing weights
        scale = torch.ones(1, 1).to(x.device) / scale
        scale2 = torch.ones(1, 1).to(x.device) / scale2
        routing_weights = self.routing(torch.cat((scale, scale2), 1)).view(self.num_experts, 1, 1)

        # fuse experts
        fused_weight = (self.weight_pool.view(self.num_experts, -1, 1) * routing_weights).sum(0)
        fused_weight = fused_weight.view(-1, self.channels_in, self.kernel_size, self.kernel_size)
        if self.bias:
            # torch.mm requires 2-D inputs, so flatten the routing weights back
            # to (1, num_experts) before mixing the bias pool
            fused_bias = torch.mm(routing_weights.view(1, -1), self.bias_pool).view(-1)
        else:
            fused_bias = None

        # convolution
        out = F.conv2d(x, fused_weight, fused_bias, stride=self.stride, padding=self.padding)
        return out
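
# SA_conv usage sketch (illustrative sizes): the routing MLP maps
# (1/scale, 1/scale2) to a softmax over num_experts, and the applied kernel is
# the expert-weighted average, so a single conv serves all scale factors.
# >>> conv = SA_conv(64, 64)
# >>> y = conv(torch.randn(2, 64, 24, 24), scale=1.7, scale2=2.3)
# >>> y.shape   # torch.Size([2, 64, 24, 24])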
def grid_sample(x, offset, scale, scale2):
    # generate grids
    b, _, h, w = x.size()
    grid = np.meshgrid(range(round(scale2 * w)), range(round(scale * h)))
    grid = np.stack(grid, axis=-1).astype(np.float64)
    grid = torch.Tensor(grid).to(x.device)

    # project into LR space
    grid[:, :, 0] = (grid[:, :, 0] + 0.5) / scale2 - 0.5
    grid[:, :, 1] = (grid[:, :, 1] + 0.5) / scale - 0.5

    # normalize to [-1, 1]
    grid[:, :, 0] = grid[:, :, 0] * 2 / (w - 1) - 1
    grid[:, :, 1] = grid[:, :, 1] * 2 / (h - 1) - 1
    grid = grid.permute(2, 0, 1).unsqueeze(0)
    grid = grid.expand([b, -1, -1, -1])

    # add offsets
    offset_0 = torch.unsqueeze(offset[:, 0, :, :] * 2 / (w - 1), dim=1)
    offset_1 = torch.unsqueeze(offset[:, 1, :, :] * 2 / (h - 1), dim=1)
    grid = grid + torch.cat((offset_0, offset_1), 1)
    grid = grid.permute(0, 2, 3, 1)

    # sampling
    # align_corners=True matches the (w - 1) / (h - 1) normalization above;
    # PyTorch >= 1.3 defaults to align_corners=False, which would silently
    # shift the sampling positions
    output = F.grid_sample(x, grid, padding_mode='zeros', align_corners=True)
    return output
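
# Worked example of the projection above (illustrative numbers): with
# scale = 2, HR row i maps to LR row (i + 0.5) / 2 - 0.5, so rows 0, 1, 2
# land at -0.25, 0.25, 0.75; each LR pixel centre is straddled by two HR
# samples, and the learned `offset` then nudges those positions in LR pixel
# units before bilinear sampling.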
class ArbRCAN(nn.Module):
    def __init__(self, encoder_spec=None, conv=default_conv):
        super(ArbRCAN, self).__init__()
        n_resgroups = 10
        n_resblocks = 20
        n_feats = 64
        kernel_size = 3
        reduction = 16
        act = nn.ReLU(True)
        n_colors = 3
        res_scale = 1
        self.n_resgroups = n_resgroups

        # head module
        modules_head = [conv(n_colors, n_feats, kernel_size)]
        self.head = nn.Sequential(*modules_head)

        # body module
        modules_body = [
            ResidualGroup(conv, n_feats, kernel_size, reduction, act=act, res_scale=res_scale,
                          n_resblocks=n_resblocks)
            for _ in range(n_resgroups)]
        modules_body.append(conv(n_feats, n_feats, kernel_size))
        self.body = nn.Sequential(*modules_body)

        # tail module
        modules_tail = [
            None,  # placeholder to match pre-trained RCAN model
                   # (nn.Sequential accepts None entries; only self.tail[1] is ever called)
            conv(n_feats, n_colors, kernel_size)]
        self.tail = nn.Sequential(*modules_tail)

        ########## our plug-in module ##########
        # scale-aware feature adaption block
        # For RCAN, feature adaption is performed after each backbone block, i.e., K=1
        self.K = 1
        sa_adapt = []
        for i in range(self.n_resgroups // self.K):
            sa_adapt.append(SA_adapt(64))
        self.sa_adapt = nn.Sequential(*sa_adapt)

        # scale-aware upsampling layer
        self.sa_upsample = SA_upsample(64)
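
    # Note: forward() below indexes self.sa_adapt[i] directly, which is valid
    # only because K == 1 here; a K > 1 configuration would need the index
    # (i + 1) // self.K - 1 to stay within the n_resgroups // K modules.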
    def set_scale(self, scale, scale2):
        self.scale = scale
        self.scale2 = scale2

    def forward(self, x, size):
        B, C, H, W = x.shape
        H_up, W_up = size
        scale = H_up / H
        scale2 = W_up / W

        # head
        x = self.head(x)

        # body
        res = x
        for i in range(self.n_resgroups):
            res = self.body[i](res)
            # scale-aware feature adaption
            if (i + 1) % self.K == 0:
                res = self.sa_adapt[i](res, scale, scale2)
        res = self.body[-1](res)
        res += x

        # scale-aware upsampling
        res = self.sa_upsample(res, scale, scale2)

        # tail
        x = self.tail[1](res)
        return x
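
# A minimal smoke test, as a sketch: the input size (1, 3, 32, 48) and target
# size (57, 93) are arbitrary illustrative values whose non-integer scales
# (57/32, 93/48) exercise the scale-aware modules; even H and W are required
# by SA_adapt (see the note above).
if __name__ == '__main__':
    model = ArbRCAN().eval()
    lr = torch.randn(1, 3, 32, 48)
    with torch.no_grad():
        sr = model(lr, size=(57, 93))
    print(sr.shape)   # torch.Size([1, 3, 57, 93])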