# leharris3 — Minimal HF Space deployment with gradio 5.x fix
# (commit 0917e8d)
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
def default_conv(in_channels, out_channels, kernel_size, bias=True):
    """Build a 2D convolution with 'same' spatial padding (for odd kernel sizes)."""
    same_pad = kernel_size // 2
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=same_pad, bias=bias)
class MeanShift(nn.Conv2d):
    """Fixed 1x1 convolution that shifts RGB images by the dataset mean.

    With ``sign=-1`` the mean is subtracted (input normalization); with
    ``sign=+1`` it is added back (output de-normalization). The weights are
    frozen and never trained.

    :param rgb_range: value range of the image data (e.g. 1.0 or 255)
    :param rgb_mean: per-channel mean, length 3
    :param rgb_std: per-channel std, length 3
    :param sign: -1 to subtract the mean, +1 to add it
    """

    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        # Identity mixing across channels, scaled by 1/std per channel.
        self.weight.data = torch.eye(3).view(3, 3, 1, 1)
        self.weight.data.div_(std.view(3, 1, 1, 1))
        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
        self.bias.data.div_(std)
        # BUG FIX: the original `self.requires_grad = False` only set a plain
        # attribute on the Module and did NOT freeze weight/bias; freeze the
        # actual parameters instead.
        for p in self.parameters():
            p.requires_grad = False
class BasicBlock(nn.Sequential):
    """Conv -> (optional BatchNorm) -> (optional activation), as one Sequential.

    Pass ``act=None`` to omit the activation; ``bn=False`` to omit the norm.
    """

    def __init__(
            self, in_channels, out_channels, kernel_size, stride=1, bias=False,
            bn=True, act=nn.ReLU(True)):
        layers = [
            nn.Conv2d(
                in_channels, out_channels, kernel_size,
                padding=(kernel_size // 2), stride=stride, bias=bias),
        ]
        if bn:
            layers.append(nn.BatchNorm2d(out_channels))
        if act is not None:
            layers.append(act)
        super(BasicBlock, self).__init__(*layers)
class ResBlock(nn.Module):
    """Residual block: conv -> act -> conv, scaled by ``res_scale``, plus identity."""

    def __init__(
            self, conv, n_feat, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
        super(ResBlock, self).__init__()
        layers = []
        for idx in range(2):
            layers.append(conv(n_feat, n_feat, kernel_size, bias=bias))
            if bn:
                layers.append(nn.BatchNorm2d(n_feat))
            # Activation only between the two convolutions.
            if idx == 0:
                layers.append(act)
        self.body = nn.Sequential(*layers)
        self.res_scale = res_scale

    def forward(self, x):
        scaled = self.body(x).mul(self.res_scale)
        return scaled + x
class Upsampler(nn.Sequential):
    """PixelShuffle-based upsampling head for scale 2^n or 3.

    For power-of-two scales, stacks log2(scale) (conv -> PixelShuffle(2))
    stages; for scale 3 a single (conv -> PixelShuffle(3)) stage. Other
    scales raise NotImplementedError. ``act``, when truthy, is treated as an
    activation class and instantiated per stage.
    """

    def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True):
        stages = []
        if (scale & (scale - 1)) == 0:  # power of two?
            for _ in range(int(math.log(scale, 2))):
                stages.append(conv(n_feat, 4 * n_feat, 3, bias))
                stages.append(nn.PixelShuffle(2))
                if bn:
                    stages.append(nn.BatchNorm2d(n_feat))
                if act:
                    stages.append(act())
        elif scale == 3:
            stages.append(conv(n_feat, 9 * n_feat, 3, bias))
            stages.append(nn.PixelShuffle(3))
            if bn:
                stages.append(nn.BatchNorm2d(n_feat))
            if act:
                stages.append(act())
        else:
            raise NotImplementedError
        super(Upsampler, self).__init__(*stages)
# add NonLocalBlock2D
# reference: https://github.com/AlexHex7/Non-local_pytorch/blob/master/lib/non_local_simple_version.py
class NonLocalBlock2D(nn.Module):
    """Non-local (spatial self-attention) block with a residual connection.

    Embeds the input with 1x1 convs (theta/phi/g), computes an all-pairs
    affinity over spatial positions, aggregates the g-features with it, maps
    back to ``in_channels`` via ``W``, and adds the input. ``W`` is
    zero-initialized so the block starts as an identity mapping.

    :param in_channels: channels of the input feature map
    :param inter_channels: reduced embedding channels for theta/phi/g
    """

    def __init__(self, in_channels, inter_channels):
        super(NonLocalBlock2D, self).__init__()
        self.in_channels = in_channels
        self.inter_channels = inter_channels

        self.g = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
                           kernel_size=1, stride=1, padding=0)
        self.W = nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels,
                           kernel_size=1, stride=1, padding=0)
        # BUG FIX: `nn.init.constant` is the long-deprecated alias (removed in
        # recent PyTorch); use the in-place `nn.init.constant_`.
        nn.init.constant_(self.W.weight, 0)
        nn.init.constant_(self.W.bias, 0)
        self.theta = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
                               kernel_size=1, stride=1, padding=0)
        self.phi = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
                             kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        batch_size = x.size(0)
        # Flatten spatial dims: [B, C', H*W] -> [B, H*W, C'] where needed.
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
        # [B, H*W, H*W] affinity between all spatial positions.
        f = torch.matmul(theta_x, phi_x)
        # NOTE(review): softmax over dim=1 normalizes over the query axis,
        # unlike the usual non-local formulation (dim=-1); kept as-is to match
        # the original (and any pretrained checkpoints) — confirm before changing.
        f_div_C = F.softmax(f, dim=1)
        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x  # residual connection
        return z
## define trunk branch
class TrunkBranch(nn.Module):
    """Trunk branch of the residual attention module: two stacked ResBlocks."""

    def __init__(
            self, conv, n_feat, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
        super(TrunkBranch, self).__init__()
        # NOTE(review): bias/bn/act arguments are accepted but the ResBlocks
        # hard-code bias=True, bn=False, act=ReLU — preserved for compatibility.
        self.body = nn.Sequential(*[
            ResBlock(conv, n_feat, kernel_size, bias=True, bn=False,
                     act=nn.ReLU(True), res_scale=res_scale)
            for _ in range(2)
        ])

    def forward(self, x):
        return self.body(x)
## define mask branch
class MaskBranchDownUp(nn.Module):
    """Mask branch producing a sigmoid attention map in [0, 1].

    Pipeline: ResBlock -> stride-2 conv (down) -> 2 ResBlocks ->
    transposed conv (up) -> ResBlock -> 1x1 conv -> sigmoid, with a skip
    from the pre-downsample features into the final ResBlock.
    """

    def __init__(
            self, conv, n_feat, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
        super(MaskBranchDownUp, self).__init__()

        def res_block():
            # NOTE(review): bias/bn/act args are accepted but hard-coded here,
            # matching the original — preserved for checkpoint compatibility.
            return ResBlock(conv, n_feat, kernel_size, bias=True, bn=False,
                            act=nn.ReLU(True), res_scale=res_scale)

        # Each stage stays wrapped in nn.Sequential so state_dict keys
        # (e.g. "MB_Down.0.weight") match the original layout.
        self.MB_RB1 = nn.Sequential(res_block())
        self.MB_Down = nn.Sequential(nn.Conv2d(n_feat, n_feat, 3, stride=2, padding=1))
        self.MB_RB2 = nn.Sequential(res_block(), res_block())
        self.MB_Up = nn.Sequential(nn.ConvTranspose2d(n_feat, n_feat, 6, stride=2, padding=2))
        self.MB_RB3 = nn.Sequential(res_block())
        self.MB_1x1conv = nn.Sequential(nn.Conv2d(n_feat, n_feat, 1, padding=0, bias=True))
        self.MB_sigmoid = nn.Sequential(nn.Sigmoid())

    def forward(self, x):
        head = self.MB_RB1(x)
        restored = self.MB_Up(self.MB_RB2(self.MB_Down(head)))
        mask = self.MB_RB3(head + restored)
        return self.MB_sigmoid(self.MB_1x1conv(mask))
## define nonlocal mask branch
class NLMaskBranchDownUp(nn.Module):
    """Mask branch with a leading non-local block; outputs a sigmoid mask.

    Same down/up pipeline as MaskBranchDownUp, but the first stage is
    NonLocalBlock2D (at half the feature width) followed by a ResBlock.
    """

    def __init__(
            self, conv, n_feat, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
        super(NLMaskBranchDownUp, self).__init__()

        def res_block():
            # NOTE(review): bias/bn/act args are accepted but hard-coded here,
            # matching the original — preserved for checkpoint compatibility.
            return ResBlock(conv, n_feat, kernel_size, bias=True, bn=False,
                            act=nn.ReLU(True), res_scale=res_scale)

        # nn.Sequential wrappers keep state_dict keys identical to the original.
        self.MB_RB1 = nn.Sequential(NonLocalBlock2D(n_feat, n_feat // 2), res_block())
        self.MB_Down = nn.Sequential(nn.Conv2d(n_feat, n_feat, 3, stride=2, padding=1))
        self.MB_RB2 = nn.Sequential(res_block(), res_block())
        self.MB_Up = nn.Sequential(nn.ConvTranspose2d(n_feat, n_feat, 6, stride=2, padding=2))
        self.MB_RB3 = nn.Sequential(res_block())
        self.MB_1x1conv = nn.Sequential(nn.Conv2d(n_feat, n_feat, 1, padding=0, bias=True))
        self.MB_sigmoid = nn.Sequential(nn.Sigmoid())

    def forward(self, x):
        head = self.MB_RB1(x)
        restored = self.MB_Up(self.MB_RB2(self.MB_Down(head)))
        mask = self.MB_RB3(head + restored)
        return self.MB_sigmoid(self.MB_1x1conv(mask))
## define residual attention module
class ResAttModuleDownUpPlus(nn.Module):
    """Residual attention module: head ResBlock, trunk * (mask) + skip, tail.

    Output = tail(trunk(h) * mask(h) + h) where h = head(input).
    """

    def __init__(
            self, conv, n_feat, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
        super(ResAttModuleDownUpPlus, self).__init__()

        def res_block():
            return ResBlock(conv, n_feat, kernel_size, bias=True, bn=False,
                            act=nn.ReLU(True), res_scale=res_scale)

        self.RA_RB1 = nn.Sequential(res_block())
        self.RA_TB = nn.Sequential(
            TrunkBranch(conv, n_feat, kernel_size, bias=True, bn=False,
                        act=nn.ReLU(True), res_scale=res_scale))
        self.RA_MB = nn.Sequential(
            MaskBranchDownUp(conv, n_feat, kernel_size, bias=True, bn=False,
                             act=nn.ReLU(True), res_scale=res_scale))
        self.RA_tail = nn.Sequential(res_block(), res_block())

    def forward(self, input):
        head = self.RA_RB1(input)
        # Attention-weighted trunk plus identity skip from the head features.
        fused = self.RA_TB(head) * self.RA_MB(head) + head
        return self.RA_tail(fused)
## define nonlocal residual attention module
class NLResAttModuleDownUpPlus(nn.Module):
    """Non-local residual attention module.

    Identical shape to ResAttModuleDownUpPlus but the mask branch is the
    non-local variant (NLMaskBranchDownUp).
    """

    def __init__(
            self, conv, n_feat, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
        super(NLResAttModuleDownUpPlus, self).__init__()

        def res_block():
            return ResBlock(conv, n_feat, kernel_size, bias=True, bn=False,
                            act=nn.ReLU(True), res_scale=res_scale)

        self.RA_RB1 = nn.Sequential(res_block())
        self.RA_TB = nn.Sequential(
            TrunkBranch(conv, n_feat, kernel_size, bias=True, bn=False,
                        act=nn.ReLU(True), res_scale=res_scale))
        self.RA_MB = nn.Sequential(
            NLMaskBranchDownUp(conv, n_feat, kernel_size, bias=True, bn=False,
                               act=nn.ReLU(True), res_scale=res_scale))
        self.RA_tail = nn.Sequential(res_block(), res_block())

    def forward(self, input):
        head = self.RA_RB1(input)
        # Attention-weighted trunk plus identity skip from the head features.
        fused = self.RA_TB(head) * self.RA_MB(head) + head
        return self.RA_tail(fused)
def make_model(args, parent=False):
    """Factory entry point used by the EDSR-style training framework.

    :param args: config namespace (n_resgroups, n_feats, scale, n_colors, ...)
    :param parent: unused; kept for framework-interface compatibility
    """
    # BUG FIX: `RNAN(args)` passed the config namespace positionally into the
    # `scale_factor` parameter; pass it by keyword so RNAN reads the config.
    return RNAN(args=args)
### RNAN
### residual attention + downscale upscale + denoising
class _ResGroup(nn.Module):
    """One local residual attention module followed by a feature conv."""

    def __init__(self, conv, n_feats, kernel_size, act, res_scale):
        super(_ResGroup, self).__init__()
        # NOTE(review): the `act` argument is accepted but a fresh ReLU is
        # hard-coded below — preserved as in the original.
        self.body = nn.Sequential(
            ResAttModuleDownUpPlus(conv, n_feats, kernel_size, bias=True,
                                   bn=False, act=nn.ReLU(True),
                                   res_scale=res_scale),
            conv(n_feats, n_feats, kernel_size),
        )

    def forward(self, x):
        return self.body(x)
### nonlocal residual attention + downscale upscale + denoising
class _NLResGroup(nn.Module):
    """One non-local residual attention module followed by a feature conv."""

    def __init__(self, conv, n_feats, kernel_size, act, res_scale):
        super(_NLResGroup, self).__init__()
        # NOTE(review): the `act` argument is accepted but a fresh ReLU is
        # hard-coded below — preserved as in the original.
        # The trailing conv is required when the group residual is disabled.
        self.body = nn.Sequential(
            NLResAttModuleDownUpPlus(conv, n_feats, kernel_size, bias=True,
                                     bn=False, act=nn.ReLU(True),
                                     res_scale=res_scale),
            conv(n_feats, n_feats, kernel_size),
        )

    def forward(self, x):
        # Group residual (res += x) intentionally disabled in the original.
        return self.body(x)
class RNAN(nn.Module):
    """Residual Non-local Attention Network for image restoration / SR.

    Reference: "Residual Non-local Attention Networks for Image Restoration",
    https://arxiv.org/pdf/1903.10082

    Layout: head conv -> non-local group (low) -> (n_resgroup - 2) local
    groups + conv -> non-local group (high) -> global skip -> pixel-shuffle
    upsampler -> output conv.

    :param scale_factor: super-resolution factor (2, 4 or 8), used only when
        `args` is not supplied.
    :param args: optional namespace-style config (attributes n_resgroups,
        n_resblocks, n_feats, reduction, scale, n_colors) — note this is
        attribute-accessed, not a dict.
    :param conv: convolution factory used throughout the network.
    """

    def __init__(self, scale_factor: int = 8, args=None, conv=default_conv):
        super(RNAN, self).__init__()
        # FIX: identity comparison (`is not None`) instead of `!= None`.
        if args is not None:
            n_resgroup = args.n_resgroups
            n_resblock = args.n_resblocks
            n_feats = args.n_feats
            reduction = args.reduction
            scale = args.scale[0]
            n_colors = args.n_colors
        else:
            # Default parameters from the original paper.
            n_colors = 1          # input channel dim (e.g. C=1 grayscale)
            n_resgroup = 10       # yields 10 - 2 = 8 local groups below
            n_resblock = 2        # unused in this build
            n_feats = 64
            reduction = None      # unused in this build
            scale = scale_factor  # standard SR factor
        assert scale in [2, 4, 8]
        kernel_size = 3
        act = nn.ReLU(True)

        # define head module
        modules_head = [conv(n_colors, n_feats, kernel_size)]

        # define body module: two non-local groups bracket the local stack.
        # NOTE(review): `res_scale=scale` feeds the SR factor (e.g. 8) in as
        # the residual scaling — unusual (res_scale is normally ~1); kept
        # as-is, confirm against the training configuration.
        modules_body_nl_low = [
            _NLResGroup(
                conv, n_feats, kernel_size, act=act, res_scale=scale)]
        # The paper uses 8 local residual groups; with n_resgroup=10 this
        # loop creates 10 - 2 = 8 of them.
        modules_body = [
            _ResGroup(
                conv, n_feats, kernel_size, act=act, res_scale=scale)
            for _ in range(n_resgroup - 2)]
        modules_body_nl_high = [
            _NLResGroup(
                conv, n_feats, kernel_size, act=act, res_scale=scale)]
        modules_body.append(conv(n_feats, n_feats, kernel_size))

        # define tail module
        modules_tail = [
            Upsampler(conv, scale, n_feats, act=False),
            conv(n_feats, n_colors, kernel_size)]

        self.head = nn.Sequential(*modules_head)
        self.body_nl_low = nn.Sequential(*modules_body_nl_low)
        self.body = nn.Sequential(*modules_body)
        self.body_nl_high = nn.Sequential(*modules_body_nl_high)
        self.tail = nn.Sequential(*modules_tail)

    def forward(self, x: torch.Tensor):
        """Restore/upscale `x`; accepts [B, H, W] or [B, C, H, W]."""
        # [B, H, W] -> [B, 1, H, W]
        if len(x.shape) == 3:
            x = x.unsqueeze(1)
        feats_shallow = self.head(x)
        res = self.body_nl_low(feats_shallow)
        res = self.body(res)
        res = self.body_nl_high(res)
        res += feats_shallow  # global residual connection
        res_main = self.tail(res)
        return res_main

    def load_state_dict(self, state_dict, strict=False):
        """Load a checkpoint, tolerating a mismatched `tail` (different scale).

        :param state_dict: parameters to load
        :param strict: when True, raise on unexpected/missing non-tail keys
        """
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name in own_state:
                if isinstance(param, nn.Parameter):
                    param = param.data
                try:
                    own_state[name].copy_(param)
                except Exception:
                    # A shape mismatch in the upsampler tail is expected when
                    # switching SR scales; anything else is a real error.
                    if name.find('tail') >= 0:
                        print('Replace pre-trained upsampler to new one...')
                    else:
                        raise RuntimeError('While copying the parameter named {}, '
                                           'whose dimensions in the model are {} and '
                                           'whose dimensions in the checkpoint are {}.'
                                           .format(name, own_state[name].size(), param.size()))
            elif strict:
                if name.find('tail') == -1:
                    raise KeyError('unexpected key "{}" in state_dict'
                                   .format(name))
        if strict:
            missing = set(own_state.keys()) - set(state_dict.keys())
            if len(missing) > 0:
                raise KeyError('missing keys in state_dict: "{}"'.format(missing))
if __name__ == "__main__":
    # Smoke test: default RNAN expects single-channel input, scale factor 8.
    # FIX: removed a leftover `breakpoint()` debugging call that would drop
    # any direct invocation into the interactive debugger.
    model = RNAN()
    x = torch.rand((1, 1, 64, 64))
    model(x)