Spaces:
Runtime error
Runtime error
| # -*- coding: utf-8 -*- | |
| """ | |
| Created on Sun Jun 20 16:14:37 2021 | |
| @author: Administrator | |
| """ | |
| from __future__ import absolute_import | |
| from __future__ import division | |
| from __future__ import print_function | |
| from torchvision import transforms | |
| import torch, math | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from einops import rearrange, repeat | |
| import numbers | |
| from thop import profile | |
| import numpy as np | |
| import time | |
| from torchvision import transforms | |
class OneRestore(nn.Module):
    """Top-level restoration network.

    Pipeline (see ``forward``): normalize -> 1x1 input conv -> ``encoder`` ->
    ``backbone`` -> ``decoder`` -> 1x1 output conv -> residual add of the
    input -> de-normalize. The ``embedding`` tensor is forwarded unchanged to
    the encoder, backbone and decoder.

    Args:
        channel: Base feature width produced by the input convolution.
    """

    def __init__(self, channel=32):
        super(OneRestore, self).__init__()
        # Bias-free 1x1 projections between RGB (3 channels) and feature space.
        self.in_conv = nn.Conv2d(3, channel, kernel_size=1, stride=1, padding=0, bias=False)
        self.encoder = encoder(channel)
        self.middle = backbone(channel)
        self.decoder = decoder(channel)
        self.out_conv = nn.Conv2d(channel, 3, kernel_size=1, stride=1, padding=0, bias=False)

    # ``norm`` / ``denorm`` used to be lambdas stored on the instance. Lambdas
    # cannot be pickled, so ``torch.save(model)`` / ``pickle.dumps(model)``
    # failed. As static methods they are still callable as ``self.norm(x)``
    # and ``self.denorm(x)`` but no longer live in the instance ``__dict__``.
    @staticmethod
    def norm(x):
        """Map ``x`` through ``(x - 0.5) / 0.5`` (e.g. [0, 1] -> [-1, 1])."""
        return (x - 0.5) / 0.5

    @staticmethod
    def denorm(x):
        """Map ``x`` through ``(x + 1) / 2`` (inverse of ``norm``)."""
        return (x + 1) / 2

    def forward(self, x, embedding):
        """Restore ``x`` conditioned on ``embedding``.

        Args:
            x: Input image batch with 3 channels (required by ``in_conv``).
               Presumably valued in [0, 1] -- TODO confirm against callers.
            embedding: Conditioning tensor handed to every stage.

        Returns:
            Tensor with the same shape as ``x``.
        """
        x_in = self.in_conv(self.norm(x))
        # Encoder returns features from the largest to the smallest scale.
        x_l, x_m, x_s, x_ss = self.encoder(x_in, embedding)
        x_mid = self.middle(x_ss, embedding)
        x_out = self.decoder(x_mid, x_ss, x_s, x_m, x_l, embedding)
        # NOTE(review): the residual uses the un-normalized input ``x`` (not
        # ``self.norm(x)``); kept exactly as in the original so outputs and
        # any trained weights stay compatible.
        out = self.out_conv(x_out) + x
        return self.denorm(out)
class encoder(nn.Module):
    """Four-stage feature encoder.

    Each stage is a ``ResidualBlock``. Between stages the features go through
    a 3x3, stride-2 max-pool followed by a bias-free 1x1 convolution that
    doubles the channel count (``channel`` -> 2x -> 4x -> 8x).
    """

    def __init__(self, channel):
        super().__init__()

        def pointwise(c_in, c_out):
            # Bias-free 1x1 convolution used to change the channel width.
            return nn.Conv2d(c_in, c_out, kernel_size=1, stride=1, padding=0, bias=False)

        # Stage blocks at widths 1x, 2x, 4x and 8x of ``channel``.
        self.el = ResidualBlock(channel)
        self.em = ResidualBlock(channel * 2)
        self.es = ResidualBlock(channel * 4)
        self.ess = ResidualBlock(channel * 8)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Channel-widening convolutions applied after each down-sampling.
        self.conv_eltem = pointwise(channel, 2 * channel)
        self.conv_emtes = pointwise(2 * channel, 4 * channel)
        self.conv_estess = pointwise(4 * channel, 8 * channel)
        # Not referenced by ``forward`` below; kept so the parameter set and
        # state_dict keys stay unchanged.
        self.conv_esstesss = pointwise(8 * channel, 16 * channel)

    def forward(self, x, embedding):
        """Return the features of all four stages, largest scale first."""
        feat_l = self.el(x, embedding)
        feat_m = self.em(self.conv_eltem(self.maxpool(feat_l)), embedding)
        feat_s = self.es(self.conv_emtes(self.maxpool(feat_m)), embedding)
        feat_ss = self.ess(self.conv_estess(self.maxpool(feat_s)), embedding)
        return feat_l, feat_m, feat_s, feat_ss
class backbone(nn.Module):
    """Bottleneck: two ``ResidualBlock`` stages at width ``8 * channel``."""

    def __init__(self, channel):
        super().__init__()
        width = channel * 8
        self.s1 = ResidualBlock(width)
        self.s2 = ResidualBlock(width)

    def forward(self, x, embedding):
        """Apply ``s1`` then ``s2``, both conditioned on ``embedding``."""
        return self.s2(self.s1(x, embedding), embedding)
class decoder(nn.Module):
    """Four-stage decoder with additive skip connections.

    Starting from the bottleneck output, every stage adds the matching
    encoder feature, runs a ``ResidualBlock``, and (except for the last
    stage) is bilinearly resized to the next encoder feature's spatial size
    and passed through a bias-free 1x1 convolution that halves the channels.
    """

    def __init__(self, channel):
        super().__init__()

        def pointwise(c_in, c_out):
            # Bias-free 1x1 convolution used to change the channel width.
            return nn.Conv2d(c_in, c_out, kernel_size=1, stride=1, padding=0, bias=False)

        # Stage blocks at widths 8x, 4x, 2x and 1x of ``channel``.
        self.dss = ResidualBlock(channel * 8)
        self.ds = ResidualBlock(channel * 4)
        self.dm = ResidualBlock(channel * 2)
        self.dl = ResidualBlock(channel)
        # Channel-narrowing convolutions applied after each up-sampling.
        self.conv_dsstds = pointwise(8 * channel, 4 * channel)
        self.conv_dstdm = pointwise(4 * channel, 2 * channel)
        self.conv_dmtdl = pointwise(2 * channel, channel)

    def _upsample(self, x, y):
        """Bilinearly resize ``x`` to the height and width of ``y``."""
        _, _, target_h, target_w = y.size()
        return F.interpolate(x, size=(target_h, target_w), mode='bilinear')

    def forward(self, x, x_ss, x_s, x_m, x_l, embedding):
        """Decode ``x`` using encoder features ordered smallest to largest.

        Args:
            x: Bottleneck output.
            x_ss, x_s, x_m, x_l: Encoder features, smallest scale first.
            embedding: Conditioning tensor passed to every block.

        Returns:
            Output of the last (largest-scale) decoder block.
        """
        out_ss = self.dss(x + x_ss, embedding)

        up_s = self.conv_dsstds(self._upsample(out_ss, x_s))
        out_s = self.ds(up_s + x_s, embedding)

        up_m = self.conv_dstdm(self._upsample(out_s, x_m))
        out_m = self.dm(up_m + x_m, embedding)

        up_l = self.conv_dmtdl(self._upsample(out_m, x_l))
        return self.dl(up_l + x_l, embedding)
class ResidualBlock(nn.Module):
    """Thin wrapper around a single ``TransformerBlock``.

    The original author's comment describes it as an "Edge-oriented Residual
    Convolution Block" intended to address vanishing gradients.

    Args:
        channel: Feature width handed to the ``TransformerBlock`` as ``dim``.
        norm: Accepted for interface compatibility; not used by this class.
    """

    def __init__(self, channel, norm=False):
        super().__init__()
        block_config = dict(
            num_heads=8,
            ffn_expansion_factor=2.66,
            bias=False,
            LayerNorm_type='WithBias',
        )
        self.el = TransformerBlock(channel, **block_config)

    def forward(self, x, embedding):
        """Run the wrapped block on ``x`` with ``embedding`` as the query."""
        result = self.el(x, embedding)
        return result
def to_3d(x):
    """Flatten the spatial axes: ``(b, c, h, w)`` -> ``(b, h*w, c)``."""
    tokens = rearrange(x, 'b c h w -> b (h w) c')
    return tokens
def to_4d(x, h, w):
    """Restore the spatial axes: ``(b, h*w, c)`` -> ``(b, c, h, w)``."""
    spatial = dict(h=h, w=w)
    return rearrange(x, 'b (h w) c -> b c h w', **spatial)
class BiasFree_LayerNorm(nn.Module):
    """Scale-only normalization over the last dimension.

    The input is divided by the (biased) standard deviation of its last
    axis and multiplied by a learnable per-feature weight. The mean is NOT
    subtracted and there is no bias term.

    Args:
        normalized_shape: Size of the last dimension, as an int or a
            one-element sequence.
    """

    def __init__(self, normalized_shape):
        super().__init__()
        shape = normalized_shape
        if isinstance(shape, numbers.Integral):
            shape = (shape,)
        shape = torch.Size(shape)
        # Only a single normalized (last) dimension is supported.
        assert len(shape) == 1
        self.weight = nn.Parameter(torch.ones(shape))
        self.normalized_shape = shape

    def forward(self, x):
        """Return ``x / sqrt(var(x) + 1e-5) * weight`` along the last axis."""
        variance = x.var(-1, keepdim=True, unbiased=False)
        std = torch.sqrt(variance + 1e-5)
        return x / std * self.weight
class WithBias_LayerNorm(nn.Module):
    """Layer normalization over the last dimension with weight and bias.

    The last axis is centred by its mean, divided by its (biased) standard
    deviation, then scaled and shifted by learnable per-feature parameters.

    Args:
        normalized_shape: Size of the last dimension, as an int or a
            one-element sequence.
    """

    def __init__(self, normalized_shape):
        super().__init__()
        shape = normalized_shape
        if isinstance(shape, numbers.Integral):
            shape = (shape,)
        shape = torch.Size(shape)
        # Only a single normalized (last) dimension is supported.
        assert len(shape) == 1
        self.weight = nn.Parameter(torch.ones(shape))
        self.bias = nn.Parameter(torch.zeros(shape))
        self.normalized_shape = shape

    def forward(self, x):
        """Return ``(x - mean) / sqrt(var + 1e-5) * weight + bias``."""
        mean = x.mean(-1, keepdim=True)
        variance = x.var(-1, keepdim=True, unbiased=False)
        centred = x - mean
        return centred / torch.sqrt(variance + 1e-5) * self.weight + self.bias
class LayerNorm(nn.Module):
    """Channel-wise layer norm for 4-D feature maps.

    The input is flattened with ``to_3d`` so channels become the last axis,
    normalized, and reshaped back with ``to_4d``.

    Args:
        dim: Number of channels.
        LayerNorm_type: ``'BiasFree'`` selects ``BiasFree_LayerNorm``; any
            other value selects ``WithBias_LayerNorm``.
    """

    def __init__(self, dim, LayerNorm_type):
        super().__init__()
        norm_cls = BiasFree_LayerNorm if LayerNorm_type == 'BiasFree' else WithBias_LayerNorm
        self.body = norm_cls(dim)

    def forward(self, x):
        """Normalize ``x`` over its channel axis, preserving its shape."""
        height, width = x.shape[-2:]
        tokens = to_3d(x)
        normed = self.body(tokens)
        return to_4d(normed, height, width)
class Cross_Attention(nn.Module):
    """Channel attention whose queries come from an external vector.

    Keys and values are computed from the feature map ``x``; the query is a
    length-``q_dim`` vector that is linearly projected and then repeated for
    every head and every channel within a head. Keys are resized to
    ``sqrt(q_dim) x sqrt(q_dim)`` so their flattened spatial length equals
    ``q_dim``, which makes ``q @ k^T`` a per-head ``c x c`` matrix.

    Args:
        dim: Number of channels of ``x``.
        num_heads: Number of attention heads (channels are split across them).
        bias: Whether the linear / convolution layers use a bias.
        q_dim: Length of the query vector. Must be a perfect square.

    Raises:
        ValueError: If ``q_dim`` is not a perfect square. Previously such a
            value was silently truncated by ``int(math.sqrt(q_dim))`` and the
            module then failed inside ``forward`` with a matmul shape error.
    """

    def __init__(self,
                 dim,
                 num_heads,
                 bias,
                 q_dim=324):
        super(Cross_Attention, self).__init__()
        sqrt_q_dim = int(math.sqrt(q_dim))
        if sqrt_q_dim * sqrt_q_dim != q_dim:
            # The resized keys have sqrt_q_dim**2 spatial positions, which
            # must equal the query length for ``q @ k^T`` in forward().
            raise ValueError(
                "q_dim must be a perfect square so that keys can be resized "
                "to sqrt(q_dim) x sqrt(q_dim); got q_dim={}".format(q_dim)
            )
        self.dim = dim
        self.num_heads = num_heads
        # Called with the target size only, so interpolation / antialias
        # follow the installed torchvision's defaults.
        self.resize = transforms.Resize([sqrt_q_dim, sqrt_q_dim])
        # One learnable scale per head applied to the attention logits.
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
        self.q = nn.Linear(q_dim, q_dim, bias=bias)
        self.kv = nn.Conv2d(dim, dim * 2, kernel_size=1, bias=bias)
        # Depth-wise 3x3 convolution (groups == channels).
        self.kv_dwconv = nn.Conv2d(dim * 2, dim * 2, kernel_size=3, stride=1, padding=1, groups=dim * 2, bias=bias)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x, query):
        """Attend over the channels of ``x`` using ``query``.

        Args:
            x: Feature map of shape ``(b, dim, h, w)``.
            query: Tensor of shape ``(b, q_dim)``.

        Returns:
            Tensor of shape ``(b, dim, h, w)``.
        """
        _, _, h, w = x.shape
        q = self.q(query)
        k, v = self.kv_dwconv(self.kv(x)).chunk(2, dim=1)
        # Only the keys are resized; values keep the full h x w resolution.
        k = self.resize(k)

        # Every (head, channel) row of q is the same projected query vector.
        q = repeat(q, 'b l -> b head c l', head=self.num_heads, c=self.dim // self.num_heads)
        k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)

        # L2-normalize along the last axis before taking dot products.
        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)

        # (c, q_dim) @ (q_dim, c) -> per-head channel-by-channel logits.
        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)

        out = (attn @ v)
        out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
        out = self.project_out(out)
        return out
class Self_Attention(nn.Module):
    """Multi-head self-attention computed across channels.

    Queries, keys and values all come from ``x`` through a 1x1 convolution
    followed by a depth-wise 3x3 convolution. Within each head the attention
    matrix is ``c x c`` (channel by channel), built from L2-normalized
    queries and keys and scaled by a learnable per-head temperature.

    Args:
        dim: Number of channels of ``x``.
        num_heads: Number of attention heads.
        bias: Whether the convolutions use a bias.
    """

    def __init__(self, dim, num_heads, bias):
        super().__init__()
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
        triple = dim * 3
        self.qkv = nn.Conv2d(dim, triple, kernel_size=1, bias=bias)
        # Depth-wise 3x3 convolution (groups == channels).
        self.qkv_dwconv = nn.Conv2d(triple, triple, kernel_size=3, stride=1, padding=1, groups=triple, bias=bias)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        """Return the attended feature map, same shape as ``x``."""
        _, _, height, width = x.shape
        heads = self.num_heads
        split_heads = 'b (head c) h w -> b head c (h w)'

        query, key, value = self.qkv_dwconv(self.qkv(x)).chunk(3, dim=1)
        query = rearrange(query, split_heads, head=heads)
        key = rearrange(key, split_heads, head=heads)
        value = rearrange(value, split_heads, head=heads)

        # L2-normalize along the flattened spatial axis.
        query = F.normalize(query, dim=-1)
        key = F.normalize(key, dim=-1)

        weights = ((query @ key.transpose(-2, -1)) * self.temperature).softmax(dim=-1)
        merged = rearrange(weights @ value, 'b head c (h w) -> b (head c) h w',
                           head=heads, h=height, w=width)
        return self.project_out(merged)
class FeedForward(nn.Module):
    """Gated convolutional feed-forward network.

    A 1x1 convolution expands ``dim`` to ``2 * hidden`` channels, a
    depth-wise 3x3 convolution mixes spatially, the result is split in two
    halves combined as ``gelu(first) * second``, and a final 1x1 convolution
    projects back to ``dim``.

    Args:
        dim: Number of input / output channels.
        ffn_expansion_factor: ``hidden = int(dim * ffn_expansion_factor)``.
        bias: Whether the convolutions use a bias.
    """

    def __init__(self, dim, ffn_expansion_factor, bias):
        super().__init__()
        hidden = int(dim * ffn_expansion_factor)
        doubled = hidden * 2
        self.project_in = nn.Conv2d(dim, doubled, kernel_size=1, bias=bias)
        # Depth-wise 3x3 convolution (groups == channels).
        self.dwconv = nn.Conv2d(doubled, doubled, kernel_size=3, stride=1, padding=1,
                                groups=doubled, bias=bias)
        self.project_out = nn.Conv2d(hidden, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        """Apply the gated feed-forward transform; output shape equals input."""
        expanded = self.dwconv(self.project_in(x))
        gate, content = expanded.chunk(2, dim=1)
        return self.project_out(F.gelu(gate) * content)
class TransformerBlock(nn.Module):
    """Pre-norm block: cross-attention, self-attention, feed-forward.

    Each of the three sub-layers is preceded by its own ``LayerNorm`` and
    wrapped in a residual connection.

    Args:
        dim: Number of channels.
        num_heads: Heads used by both attention layers.
        ffn_expansion_factor: Hidden-width multiplier of the feed-forward.
        bias: Whether sub-layer convolutions / linear layers use a bias.
        LayerNorm_type: Forwarded to ``LayerNorm``.
    """

    def __init__(self,
                 dim,
                 num_heads=8,
                 ffn_expansion_factor=2.66,
                 bias=False,
                 LayerNorm_type='WithBias'):
        super().__init__()
        # Sub-layers are created in the order they are applied in forward().
        self.norm1 = LayerNorm(dim, LayerNorm_type)
        self.cross_attn = Cross_Attention(dim, num_heads, bias)

        self.norm2 = LayerNorm(dim, LayerNorm_type)
        self.self_attn = Self_Attention(dim, num_heads, bias)

        self.norm3 = LayerNorm(dim, LayerNorm_type)
        self.ffn = FeedForward(dim, ffn_expansion_factor, bias)

    def forward(self, x, query):
        """Apply the three residual sub-layers; ``query`` feeds cross-attention."""
        cross = self.cross_attn(self.norm1(x), query)
        x = x + cross
        attended = self.self_attn(self.norm2(x))
        x = x + attended
        return x + self.ffn(self.norm3(x))
if __name__ == '__main__':
    # Complexity report: build the default model, then print MACs (thop),
    # parameter count, and FLOPs (fvcore) for one 512x512 input.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    net = OneRestore().to(device)

    # Alternative quick shape check (disabled):
    # x = torch.Tensor(np.random.random((2,3,256,256))).to(device)
    # query = torch.Tensor(np.random.random((2, 324))).to(device)
    # out = net(x, query)
    # print(out.shape)

    image = torch.randn(1, 3, 512, 512).to(device)
    query = torch.Tensor(np.random.random((1, 324))).to(device)
    sample = (image, query)

    macs, _ = profile(net, inputs=sample)
    total = sum(p.nelement() for p in net.parameters())
    giga_macs = macs / 1000 ** 3
    mega_params = total / 1e6
    print('Macs = ' + str(giga_macs) + 'G')
    print('Params = ' + str(mega_params) + 'M')

    # fvcore is only needed for this report, so it is imported here.
    from fvcore.nn import FlopCountAnalysis, parameter_count_table
    flops = FlopCountAnalysis(net, sample)
    print("FLOPs", flops.total() / 1000 ** 3)