File size: 5,999 Bytes
61d360d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
'''
NAFSSR: Stereo Image Super-Resolution Using NAFNet
@InProceedings{Chu2022NAFSSR,
author = {Xiaojie Chu and Liangyu Chen and Wenqing Yu},
title = {NAFSSR: Stereo Image Super-Resolution Using NAFNet},
booktitle = {CVPRW},
year = {2022},
}
'''
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from basicsr.models.archs.NAFNet_arch import LayerNorm2d, NAFBlock
from basicsr.models.archs.arch_util import MySequential
from basicsr.models.archs.local_arch import Local_Base
class SCAM(nn.Module):
    '''
    Stereo Cross Attention Module (SCAM).

    Exchanges information between a left and a right feature map by
    attending along the width axis, independently for every image row.
    Each left-view pixel attends over all pixels of the same row in the
    right view, and vice versa.

    The attended features are multiplied by learnable per-channel scales
    (``beta`` for right->left, ``gamma`` for left->right) and added back
    to the inputs as residuals. Both scales start at zero, so the module
    returns its inputs unchanged at initialisation.

    Args:
        c: Number of channels of each view's feature map.
    '''

    def __init__(self, c):
        super().__init__()
        self.scale = c ** -0.5

        # Normalisation + 1x1 projections on the path that builds the attention map.
        self.norm_l = LayerNorm2d(c)
        self.norm_r = LayerNorm2d(c)
        self.l_proj1 = nn.Conv2d(c, c, kernel_size=1, stride=1, padding=0)
        self.r_proj1 = nn.Conv2d(c, c, kernel_size=1, stride=1, padding=0)

        # Zero-initialised residual scales (identity mapping at start of training).
        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)))
        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)))

        # 1x1 value projections, applied to the un-normalised inputs.
        self.l_proj2 = nn.Conv2d(c, c, kernel_size=1, stride=1, padding=0)
        self.r_proj2 = nn.Conv2d(c, c, kernel_size=1, stride=1, padding=0)

    def forward(self, x_l, x_r):
        '''
        Fuse the two views.

        Args:
            x_l: Left-view features, shape (B, c, H, W).
            x_r: Right-view features, shape (B, c, H, W).

        Returns:
            ``(x_l + beta * F_r2l, x_r + gamma * F_l2r)``, each with the
            shape of the corresponding input.
        '''
        # Move channels last so each row becomes a (W, c) matrix.
        q_left = self.l_proj1(self.norm_l(x_l)).permute(0, 2, 3, 1)     # B, H, W, c
        q_right_t = self.r_proj1(self.norm_r(x_r)).permute(0, 2, 1, 3)  # B, H, c, W
        v_left = self.l_proj2(x_l).permute(0, 2, 3, 1)                   # B, H, W, c
        v_right = self.r_proj2(x_r).permute(0, 2, 3, 1)                  # B, H, W, c

        # Row-wise similarity of every left pixel with every right pixel: B, H, W_l, W_r.
        scores = torch.matmul(q_left, q_right_t) * self.scale

        # Right -> left: normalise over right-view positions.
        r2l = torch.matmul(scores.softmax(dim=-1), v_right)
        # Left -> right: reuse the same scores, transposed, normalised over left-view positions.
        l2r = torch.matmul(scores.transpose(-1, -2).softmax(dim=-1), v_left)

        # Back to B, c, H, W and add the scaled cross-view features as residuals.
        out_l = x_l + r2l.permute(0, 3, 1, 2) * self.beta
        out_r = x_r + l2r.permute(0, 3, 1, 2) * self.gamma
        return out_l, out_r
class DropPath(nn.Module):
    '''
    Stochastic-depth wrapper around a multi-input / multi-output module.

    Training mode:
        * With probability ``drop_rate`` the wrapped module is skipped and
          the input tuple is returned as-is.
        * Otherwise each output's residual ``new_x - x`` is rescaled by
          ``1 / (1 - drop_rate)`` before being added back to its input.

    Evaluation mode: the wrapped module is always applied and its output
    is returned unscaled.

    The skip decision is drawn from NumPy's global RNG (``np.random``),
    not from torch's generator.

    Args:
        drop_rate: Probability of skipping ``module`` in training mode.
        module: Callable taking ``*feats``. It is expected to return one
            output per input; outputs are paired with inputs via ``zip``.
    '''

    def __init__(self, drop_rate, module):
        super().__init__()
        self.drop_rate = drop_rate
        self.module = module

    def forward(self, *feats):
        # Evaluation: plain pass-through, no random draw, no rescaling.
        if not self.training:
            return self.module(*feats)

        # Training: skip the whole module with probability drop_rate.
        if np.random.rand() < self.drop_rate:
            return feats

        outputs = self.module(*feats)
        keep_scale = 1. / (1 - self.drop_rate)
        if keep_scale == 1.:
            return outputs
        # Rescale only the residual so the expected update matches eval mode.
        return tuple(x + keep_scale * (y - x) for x, y in zip(feats, outputs))
class NAFBlockSR(nn.Module):
    '''
    NAFBlock for super-resolution, optionally followed by stereo fusion.

    A single ``NAFBlock`` instance (shared weights) is applied to every
    input view. When ``fusion`` is True, the processed views are then
    passed jointly through a ``SCAM`` module.

    Args:
        c: Number of feature channels.
        fusion: Whether to append a SCAM cross-view fusion step.
        drop_out_rate: Forwarded to ``NAFBlock`` as ``drop_out_rate``.
    '''

    def __init__(self, c, fusion=False, drop_out_rate=0.):
        super().__init__()
        self.blk = NAFBlock(c, drop_out_rate=drop_out_rate)
        self.fusion = SCAM(c) if fusion else None

    def forward(self, *feats):
        # Same block (same weights) for each view.
        processed = tuple(self.blk(view) for view in feats)
        if self.fusion is None:
            return processed
        return self.fusion(*processed)
class NAFNetSR(nn.Module):
    '''
    NAFNet for (single- or dual-view) super-resolution.

    Pipeline:
        1. ``intro`` (3x3 conv) lifts each view from ``img_channel`` to
           ``width`` channels.
        2. ``body``, a stack of ``num_blks`` NAFBlockSR blocks each wrapped
           in DropPath, processes the views jointly.
        3. ``up`` (3x3 conv + PixelShuffle) maps each view back to
           ``img_channel`` channels at ``up_scale`` times the resolution.

    A bilinearly upsampled copy of the input is added as a global residual.

    Args:
        up_scale: Spatial upscaling factor.
        width: Number of feature channels inside the body.
        num_blks: Number of NAFBlockSR blocks.
        img_channel: Image channels per view.
        drop_path_rate: Stochastic-depth skip probability for every block.
        drop_out_rate: Forwarded to each NAFBlockSR as ``drop_out_rate``.
        fusion_from, fusion_to: Blocks whose index i satisfies
            ``fusion_from <= i <= fusion_to`` get a SCAM fusion step.
            The defaults (-1, -1) disable fusion.
        dual: If True, the input holds two views (left, right) concatenated
            along the channel axis.
    '''

    def __init__(self, up_scale=4, width=48, num_blks=16, img_channel=3, drop_path_rate=0., drop_out_rate=0., fusion_from=-1, fusion_to=-1, dual=False):
        super().__init__()
        self.dual = dual  # dual input for stereo SR (left view, right view)

        self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1,
                               bias=True)

        # Built in index order so parameter initialisation order is unchanged.
        blocks = []
        for idx in range(num_blks):
            use_fusion = fusion_from <= idx <= fusion_to
            inner = NAFBlockSR(width, fusion=use_fusion, drop_out_rate=drop_out_rate)
            blocks.append(DropPath(drop_path_rate, inner))
        self.body = MySequential(*blocks)

        self.up = nn.Sequential(
            nn.Conv2d(in_channels=width, out_channels=img_channel * up_scale**2, kernel_size=3, padding=1, stride=1, groups=1, bias=True),
            nn.PixelShuffle(up_scale)
        )
        self.up_scale = up_scale

    def forward(self, inp):
        '''
        Super-resolve ``inp``.

        Args:
            inp: Tensor of shape (B, C, H, W). In dual mode it is split into
                two halves along the channel axis (one per view).

        Returns:
            Tensor of shape (B, C', H * up_scale, W * up_scale). C' is the
            concatenation of the per-view outputs and must equal C for the
            residual addition to succeed.
        '''
        upsampled = F.interpolate(inp, scale_factor=self.up_scale, mode='bilinear')
        views = inp.chunk(2, dim=1) if self.dual else (inp, )
        feats = self.body(*[self.intro(view) for view in views])
        restored = torch.cat([self.up(feat) for feat in feats], dim=1)
        return restored + upsampled
class NAFSSR(Local_Base, NAFNetSR):
    '''
    Stereo super-resolution network: a dual-view NAFNetSR prepared by
    ``Local_Base.convert``.

    Behaviour fixed by this subclass:
        * ``img_channel`` is always 3 and ``dual`` is always True, so the
          network expects left/right views stacked on the channel axis.
        * With the default ``fusion_from=-1`` and ``fusion_to=1000``, every
          block with index <= 1000 receives a SCAM fusion step.

    Args:
        *args, **kwargs: Forwarded to ``NAFNetSR.__init__``.
        train_size: (N, C, H, W) size of training crops. H and W are used to
            derive ``base_size = (int(1.5 * H), int(1.5 * W))`` for
            ``convert``.
        fast_imp: Forwarded to ``convert``. Its meaning is defined by
            ``Local_Base`` (not visible here; TODO confirm).
    '''

    def __init__(self, *args, train_size=(1, 6, 30, 90), fast_imp=False, fusion_from=-1, fusion_to=1000, **kwargs):
        Local_Base.__init__(self)
        NAFNetSR.__init__(self, *args, img_channel=3, fusion_from=fusion_from, fusion_to=fusion_to, dual=True, **kwargs)

        # Only the spatial extent of the training crops is needed here.
        _, _, crop_h, crop_w = train_size
        base_size = (int(crop_h * 1.5), int(crop_w * 1.5))

        # Conversion runs in eval mode without autograd. Note that the model
        # is left in eval mode when the constructor returns.
        self.eval()
        with torch.no_grad():
            self.convert(base_size=base_size, train_size=train_size, fast_imp=fast_imp)
if __name__ == '__main__':
    # Complexity report for a large NAFSSR configuration (requires ptflops).
    num_blks = 128
    width = 128
    droppath = 0.1
    train_size = (1, 6, 30, 90)
    net = NAFSSR(up_scale=2, train_size=train_size, fast_imp=True, width=width, num_blks=num_blks, drop_path_rate=droppath)

    # Input shape excludes the batch dimension: 6 channels = left + right RGB views.
    inp_shape = (6, 64, 64)

    from ptflops import get_model_complexity_info

    extra_flops = 0  # FLOPs not counted by ptflops, added manually if needed
    macs, params = get_model_complexity_info(net, inp_shape, verbose=False, print_per_layer_stat=True)
    print(params)
    # NOTE(review): assumes `macs` is a string ending in a 4-character unit such as "GMac".
    macs = float(macs[:-4]) + extra_flops / 10 ** 9
    print('mac', macs, params)

    # For latency instead of complexity, see
    # basicsr.models.archs.arch_util.measure_inference_speed (input: (1, 6, 128, 128)).
|