import functools
from math import exp

import torch
import torch.nn as nn
from torch.nn import init
from torch.autograd import Variable
import torch.nn.functional as F
# import the spectral_norm function explicitly; `import torch.nn.utils.spectral_norm`
# can bind the submodule rather than the callable depending on the Python version
from torch.nn.utils import spectral_norm as SpectralNorm
from torchvision import models
import torch.utils.model_zoo as model_zoo
################################## IO ##################################
def save(net, path, gpu_id):
    """Save `net`'s state_dict to `path`. The net is moved to CPU first so the
    checkpoint is device-agnostic, then moved back to GPU when gpu_id != '-1'."""
    if isinstance(net, nn.DataParallel):
        torch.save(net.module.cpu().state_dict(), path)
    else:
        torch.save(net.cpu().state_dict(), path)
    if gpu_id != '-1':
        net.cuda()

def todevice(net, gpu_id):
    """Move `net` to the device(s) selected by the string `gpu_id`:
    '-1' keeps it on CPU, one id uses a single GPU, and several ids
    wrap the net in nn.DataParallel."""
    if gpu_id != '-1' and len(gpu_id) == 1:
        net.cuda()
    elif gpu_id != '-1' and len(gpu_id) > 1:
        net = nn.DataParallel(net)
        net.cuda()
    return net
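# A minimal usage sketch (not from the original file; `MyNet` is a
# hypothetical module, and gpu_id follows the string convention above):
'''
net = todevice(MyNet(), gpu_id='0')       # single GPU
net = todevice(MyNet(), gpu_id='0,1')     # DataParallel over two GPUs
save(net, './checkpoints/net.pth', gpu_id='0')
'''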
# patch InstanceNorm checkpoints prior to 0.4
def patch_instance_norm_state_dict(state_dict, module, keys, i=0):
    """Fix InstanceNorm checkpoint incompatibility (prior to PyTorch 0.4)"""
    key = keys[i]
    if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
        if module.__class__.__name__.startswith('InstanceNorm') and \
                (key == 'running_mean' or key == 'running_var'):
            if getattr(module, key) is None:
                state_dict.pop('.'.join(keys))
        if module.__class__.__name__.startswith('InstanceNorm') and \
                (key == 'num_batches_tracked'):
            state_dict.pop('.'.join(keys))
    else:
        patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
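# Usage sketch (an assumption, mirroring how CycleGAN-style code drives this
# helper when loading an old checkpoint; `net` is a hypothetical module):
# iterate over a copy of the keys, since entries may be popped along the way.
'''
state_dict = torch.load('old_checkpoint.pth', map_location='cpu')
for key in list(state_dict.keys()):
    patch_instance_norm_state_dict(state_dict, net, key.split('.'))
net.load_state_dict(state_dict)
'''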
################################## initialization ##################################
def get_norm_layer(norm_type='instance', mod='2d'):
    if norm_type == 'batch':
        if mod == '2d':
            norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
        elif mod == '3d':
            norm_layer = functools.partial(nn.BatchNorm3d, affine=True)
    elif norm_type == 'instance':
        if mod == '2d':
            norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=True)
        elif mod == '3d':
            norm_layer = functools.partial(nn.InstanceNorm3d, affine=False, track_running_stats=True)
    elif norm_type == 'none':
        norm_layer = None
    else:
        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
    return norm_layer
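# Usage sketch: the functools.partial return value lets callers instantiate
# the chosen norm with just the channel count.
'''
norm_layer = get_norm_layer('instance', mod='2d')
norm = norm_layer(64)   # nn.InstanceNorm2d(64, affine=False, track_running_stats=True)
'''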
def init_weights(net, init_type='normal', gain=0.02):
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:
            init.normal_(m.weight.data, 1.0, gain)
            init.constant_(m.bias.data, 0.0)
    # print('initialize network with %s' % init_type)
    net.apply(init_func)
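# Usage sketch: net.apply() visits every submodule, so a single call
# initializes all Conv/Linear/BatchNorm2d layers (`MyNet` is hypothetical).
'''
net = MyNet()
init_weights(net, init_type='kaiming')
'''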
################################## Network structure ##################################
################################## ResnetBlock ##################################
class ResnetBlockSpectralNorm(nn.Module):
    """Residual block built from two spectral-normalized 3x3 convolutions.
    Spatial size and channel count are preserved."""
    def __init__(self, dim, padding_type, activation=nn.LeakyReLU(0.2), use_dropout=False):
        super(ResnetBlockSpectralNorm, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, activation, use_dropout)

    def build_conv_block(self, dim, padding_type, activation, use_dropout):
        conv_block = []
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        conv_block += [SpectralNorm(nn.Conv2d(dim, dim, kernel_size=3, padding=p)),
                       activation]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        conv_block += [SpectralNorm(nn.Conv2d(dim, dim, kernel_size=3, padding=p))]
        return nn.Sequential(*conv_block)

    def forward(self, x):
        out = x + self.conv_block(x)
        return out
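# Shape-check sketch (assumed, not in the original file): the block is
# shape-preserving, so the residual sum in forward() is well-defined.
'''
block = ResnetBlockSpectralNorm(dim=64, padding_type='reflect')
x = torch.randn(1, 64, 32, 32)
assert block(x).shape == x.shape
'''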
################################## Resnet ##################################
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}

def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)

def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        out = self.relu(out)
        return out
class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = norm_layer(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn2 = norm_layer(planes)
        self.conv3 = conv1x1(planes, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        out = self.relu(out)
        return out
class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, norm_layer=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, norm_layer=norm_layer)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, norm_layer=norm_layer)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, norm_layer=None):
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, norm_layer=norm_layer))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
    return model

def resnet101(pretrained=False, **kwargs):
    """Constructs a ResNet-101 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
    return model
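# Usage sketch: the norm layer is swappable via get_norm_layer above, but
# pretrained ImageNet weights only match the default BatchNorm layout.
'''
model = resnet18(pretrained=True)
model = resnet101(pretrained=False, norm_layer=get_norm_layer('instance'))
'''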
################################## Loss function ##################################
class HingeLossD(nn.Module):
    """Hinge loss for the discriminator:
    L_D = E[relu(1 - D(real))] + E[relu(1 + D(fake))]"""
    def __init__(self):
        super(HingeLossD, self).__init__()

    def forward(self, dis_fake, dis_real):
        loss_real = torch.mean(F.relu(1. - dis_real))
        loss_fake = torch.mean(F.relu(1. + dis_fake))
        return loss_real + loss_fake

class HingeLossG(nn.Module):
    """Hinge loss for the generator: L_G = -E[D(fake)]"""
    def __init__(self):
        super(HingeLossG, self).__init__()

    def forward(self, dis_fake):
        loss_fake = -torch.mean(dis_fake)
        return loss_fake
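# Training-step sketch (an assumption; netD, netG, real and fake are
# hypothetical): the discriminator sees detached fakes so generator
# gradients don't flow through its loss.
'''
criterionD, criterionG = HingeLossD(), HingeLossG()
loss_D = criterionD(netD(fake.detach()), netD(real))
loss_G = criterionG(netD(fake))
'''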
class VGGLoss(nn.Module):
    """Perceptual loss: MSE between VGG-19 feature maps of x and y,
    with deeper layers weighted more heavily."""
    def __init__(self, gpu_id):
        super(VGGLoss, self).__init__()
        self.vgg = Vgg19()
        if gpu_id != '-1' and len(gpu_id) == 1:
            self.vgg.cuda()
        elif gpu_id != '-1' and len(gpu_id) > 1:
            self.vgg = nn.DataParallel(self.vgg)
            self.vgg.cuda()
        self.criterion = nn.MSELoss()
        self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]

    def forward(self, x, y):
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        loss = 0
        for i in range(len(x_vgg)):
            # detach the target features so gradients only flow through x
            loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
        return loss
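# Usage sketch: inputs are (N, 3, H, W) image batches; gpu_id follows the
# same string convention as todevice above (fake/real are hypothetical).
'''
criterionVGG = VGGLoss(gpu_id='0')
loss_perc = criterionVGG(fake, real)
'''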
class Vgg19(torch.nn.Module):
    """Frozen VGG-19 feature extractor; the five slices end at
    relu1_1, relu2_1, relu3_1, relu4_1 and relu5_1 respectively."""
    def __init__(self, requires_grad=False):
        super(Vgg19, self).__init__()
        vgg_pretrained_features = models.vgg19(pretrained=True).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        for x in range(2):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(2, 7):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(7, 12):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(12, 21):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(21, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            # freeze the pretrained weights
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h_relu1 = self.slice1(X)
        h_relu2 = self.slice2(h_relu1)
        h_relu3 = self.slice3(h_relu2)
        h_relu4 = self.slice4(h_relu3)
        h_relu5 = self.slice5(h_relu4)
        out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
        return out
| ################################## Evaluation ################################## | |
| '''https://github.com/Po-Hsun-Su/pytorch-ssim | |
| img1 = Variable(torch.rand(1, 1, 256, 256)) | |
| img2 = Variable(torch.rand(1, 1, 256, 256)) | |
| if torch.cuda.is_available(): | |
| img1 = img1.cuda() | |
| img2 = img2.cuda() | |
| print(pytorch_ssim.ssim(img1, img2)) | |
| ssim_loss = pytorch_ssim.SSIM(window_size = 11) | |
| print(ssim_loss(img1, img2)) | |
| ''' | |
def gaussian(window_size, sigma):
    gauss = torch.Tensor([exp(-(x - window_size//2)**2 / float(2*sigma**2)) for x in range(window_size)])
    return gauss / gauss.sum()

def create_window(window_size, channel):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
    return window

def _ssim(img1, img2, window, window_size, channel, size_average=True):
    # local means via Gaussian filtering (grouped conv = per-channel filter)
    mu1 = F.conv2d(img1, window, padding=window_size//2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size//2, groups=channel)
    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2
    # local variances and covariance
    sigma1_sq = F.conv2d(img1*img1, window, padding=window_size//2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2*img2, window, padding=window_size//2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1*img2, window, padding=window_size//2, groups=channel) - mu1_mu2
    # stability constants from the SSIM paper (K1=0.01, K2=0.03), assuming
    # inputs scaled to [0, 1] (dynamic range L = 1)
    C1 = 0.01**2
    C2 = 0.03**2
    ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2)) / ((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)
class SSIM(torch.nn.Module):
    def __init__(self, window_size=11, size_average=True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()
        if channel == self.channel and self.window.data.type() == img1.data.type():
            window = self.window
        else:
            # rebuild and cache the window when channel count or dtype changes
            window = create_window(self.window_size, channel)
            if img1.is_cuda:
                window = window.cuda(img1.get_device())
            window = window.type_as(img1)
            self.window = window
            self.channel = channel
        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)

def ssim(img1, img2, window_size=11, size_average=True):
    (_, channel, _, _) = img1.size()
    window = create_window(window_size, channel)
    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)
    return _ssim(img1, img2, window, window_size, channel, size_average)