| |
| |
| |
| |
| |
| |
| import math |
| import torch |
| from torch import nn as nn |
| from torch.nn import functional as F |
| from torch.nn import init as init |
| from torch.nn.modules.batchnorm import _BatchNorm |
|
|
@torch.no_grad()
def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
    """Initialize network weights.

    Every ``nn.Conv2d`` and ``nn.Linear`` found in the given modules gets
    Kaiming-normal weights multiplied by ``scale``; every batch-norm layer
    gets its weight set to 1. Biases, where present, are filled with
    ``bias_fill``. Runs under ``torch.no_grad()`` and modifies the modules
    in place.

    Args:
        module_list (nn.Module | list[nn.Module] | tuple[nn.Module]): A single
            module, or a list/tuple of modules, to initialize.
        scale (float): Factor applied to the freshly initialized conv/linear
            weights. Default: 1.
        bias_fill (float): Value written into every existing bias. Default: 0.
        kwargs (dict): Extra keyword arguments forwarded to
            ``init.kaiming_normal_``.
    """
    # Anything that is not already a list/tuple is treated as a single module.
    if not isinstance(module_list, (list, tuple)):
        module_list = [module_list]
    for module in module_list:
        for m in module.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                init.kaiming_normal_(m.weight, **kwargs)
                m.weight.data *= scale
            elif isinstance(m, _BatchNorm):
                # Norm layers built with affine=False carry weight=None.
                if m.weight is not None:
                    init.constant_(m.weight, 1)
            else:
                continue
            if m.bias is not None:
                m.bias.data.fill_(bias_fill)
|
|
|
|
def make_layer(basic_block, num_basic_block, **kwarg):
    """Make layers by stacking the same blocks.

    Args:
        basic_block (callable): Called once per layer as
            ``basic_block(**kwarg)`` to build a fresh block.
        num_basic_block (int): How many blocks to build.
        kwarg (dict): Keyword arguments passed to every ``basic_block`` call.

    Returns:
        nn.Sequential: The blocks chained in construction order.
    """
    stacked = (basic_block(**kwarg) for _ in range(num_basic_block))
    return nn.Sequential(*stacked)
|
|
|
|
class ResidualBlockNoBN(nn.Module):
    """Residual block without BN.

    Computes ``x + res_scale * conv2(relu(conv1(x)))`` where both convolutions
    are 3x3, stride 1, padding 1 and keep the channel count at ``num_feat``.

    Args:
        num_feat (int): Number of input/output channels. Default: 64.
        res_scale (float): Multiplier applied to the residual branch.
            Default: 1.
        pytorch_init (bool): If True, keep PyTorch's default initialization;
            otherwise both convolutions are re-initialized through
            ``default_init_weights`` with a scale of 0.1. Default: False.
    """

    def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
        super().__init__()
        self.res_scale = res_scale
        conv_cfg = dict(kernel_size=3, stride=1, padding=1, bias=True)
        self.conv1 = nn.Conv2d(num_feat, num_feat, **conv_cfg)
        self.conv2 = nn.Conv2d(num_feat, num_feat, **conv_cfg)
        self.relu = nn.ReLU(inplace=True)
        if pytorch_init:
            return
        default_init_weights([self.conv1, self.conv2], 0.1)

    def forward(self, x):
        """Return ``x`` plus the scaled output of the conv-relu-conv branch."""
        residual = self.conv1(x)
        residual = self.relu(residual)
        residual = self.conv2(residual)
        return x + residual * self.res_scale
|
|
|
|
class Upsample(nn.Sequential):
    """Upsample module.

    Builds a chain of ``nn.Conv2d`` + ``nn.PixelShuffle`` stages. A scale of
    ``2^n`` yields ``n`` stages that each expand channels 4x and shuffle by 2
    (so a scale of 1 yields an empty sequence); a scale of 3 yields a single
    stage that expands channels 9x and shuffles by 3.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.

    Raises:
        ValueError: If ``scale`` is neither a positive power of two nor 3.
    """

    def __init__(self, scale, num_feat):
        m = []
        # `scale > 0` is required: 0 also satisfies the bit test below but is
        # not a power of two.
        if scale > 0 and (scale & (scale - 1)) == 0:
            # For an exact power of two, bit_length() - 1 is log2(scale),
            # computed with integer arithmetic only.
            for _ in range(int(scale).bit_length() - 1):
                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                m.append(nn.PixelShuffle(2))
        elif scale == 3:
            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            m.append(nn.PixelShuffle(3))
        else:
            raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
        super(Upsample, self).__init__(*m)
|
|
|
|
class LayerNormFunction(torch.autograd.Function):
    """Channel-wise layer normalization with a hand-written backward pass.

    For a 4-D input unpacked as ``(N, C, H, W)``, every spatial position is
    normalized across dimension 1 using the biased variance, then scaled and
    shifted by per-channel ``weight`` and ``bias``.
    """

    @staticmethod
    def forward(ctx, x, weight, bias, eps):
        """Normalize ``x`` over dim 1 and apply the per-channel affine.

        Args:
            ctx: Autograd context used to stash values for ``backward``.
            x (Tensor): Input; its size must unpack into four dimensions.
            weight (Tensor): Scale, viewed as ``(1, C, 1, 1)``.
            bias (Tensor): Shift, viewed as ``(1, C, 1, 1)``.
            eps (float): Added to the variance before the square root.

        Returns:
            Tensor: ``weight * (x - mean) / sqrt(var + eps) + bias``.
        """
        ctx.eps = eps
        # Unpacking also enforces that the input is 4-D.
        _, C, _, _ = x.size()
        mu = x.mean(1, keepdim=True)
        var = (x - mu).pow(2).mean(1, keepdim=True)
        y = (x - mu) / (var + eps).sqrt()
        # Save the normalized activations (before the affine) for backward.
        ctx.save_for_backward(y, var, weight)
        return weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``(x, weight, bias, eps)``; ``eps`` gets None."""
        eps = ctx.eps
        _, C, _, _ = grad_output.size()
        # `saved_tensors` is the supported accessor; `saved_variables` is
        # deprecated.
        y, var, weight = ctx.saved_tensors
        g = grad_output * weight.view(1, C, 1, 1)
        mean_g = g.mean(dim=1, keepdim=True)
        mean_gy = (g * y).mean(dim=1, keepdim=True)
        grad_x = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
        # Per-channel parameter gradients: reduce over every dim except dim 1.
        grad_weight = (grad_output * y).sum(dim=(0, 2, 3))
        grad_bias = grad_output.sum(dim=(0, 2, 3))
        return grad_x, grad_weight, grad_bias, None
|
|
|
|
class LayerNorm2d(nn.Module):
    """Module wrapper that applies ``LayerNormFunction`` with learnable affine.

    Args:
        channels (int): Length of the learnable ``weight`` (initialized to
            ones) and ``bias`` (initialized to zeros) vectors.
        eps (float): Epsilon forwarded to ``LayerNormFunction``.
            Default: 1e-6.
    """

    def __init__(self, channels, eps=1e-6):
        super().__init__()
        # Assigning an nn.Parameter attribute registers it on the module.
        self.weight = nn.Parameter(torch.ones(channels))
        self.bias = nn.Parameter(torch.zeros(channels))
        self.eps = eps

    def forward(self, x):
        """Return ``LayerNormFunction`` applied to ``x`` with this module's parameters."""
        scale, shift = self.weight, self.bias
        return LayerNormFunction.apply(x, scale, shift, self.eps)
|
|