Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torchvision.models as models | |
| from .nn import mean_flat | |
| # input image range [-1,1] | |
class VGG(nn.Module):
    """Perceptual (feature-space) loss on truncated, frozen VGG-19 features.

    Both inputs are expected in the range [-1, 1] (see the module-level
    comment). They are rescaled to [0, 1], normalised by ``MeanShift`` and
    passed through the leading layers of an ImageNet-pretrained VGG-19.
    The loss is ``mean_flat`` applied to the squared feature difference.

    Args:
        conv_index: Selects where the VGG-19 feature stack is cut. A value
            containing ``'22'`` keeps ``features[:8]``; otherwise a value
            containing ``'54'`` keeps ``features[:35]``.
        rgb_range: Scale forwarded to ``MeanShift`` and multiplied into the
            normalisation std.

    Raises:
        ValueError: If ``conv_index`` contains neither ``'22'`` nor ``'54'``.
    """

    def __init__(self, conv_index='22', rgb_range=1):
        super(VGG, self).__init__()

        # Resolve the cut-off before touching the (potentially downloaded)
        # weights. Previously an unrecognised value left ``self.vgg`` unset
        # and only surfaced as an AttributeError inside ``forward``.
        if '22' in conv_index:
            depth = 8
        elif '54' in conv_index:
            depth = 35
        else:
            raise ValueError(
                "conv_index must contain '22' or '54', got %r" % (conv_index,)
            )

        # Newer torchvision releases deprecate ``pretrained=True`` in favour
        # of the ``weights`` enum; fall back to the old flag when the enum is
        # not available so older installations keep working.
        weights_enum = getattr(models, 'VGG19_Weights', None)
        if weights_enum is not None:
            backbone = models.vgg19(weights=weights_enum.IMAGENET1K_V1)
        else:
            backbone = models.vgg19(pretrained=True)
        self.vgg = nn.Sequential(*list(backbone.features)[:depth])

        vgg_mean = (0.485, 0.456, 0.406)
        vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
        self.sub_mean = MeanShift(rgb_range, vgg_mean, vgg_std)

        # The network is a fixed feature extractor; nothing here is trained.
        for p in self.parameters():
            p.requires_grad = False

    def _features(self, x):
        """Normalise ``x`` and return its truncated VGG-19 feature map."""
        return self.vgg(self.sub_mean(x))

    def forward(self, sr, hr):
        """Return the feature-space squared error between ``sr`` and ``hr``.

        Args:
            sr: Predicted image tensor in [-1, 1]; gradients flow through it.
            hr: Reference image tensor in [-1, 1]; treated as a constant.

        Returns:
            The result of ``mean_flat`` on the squared feature difference.
            ``mean_flat`` comes from ``.nn``; presumably it averages over all
            non-batch dimensions -- confirm against that module.
        """
        # Map [-1, 1] -> [0, 1] before the ImageNet normalisation.
        # NOTE(review): this always rescales to [0, 1], while ``MeanShift``
        # was built with ``rgb_range``; for rgb_range != 1 the two look
        # inconsistent -- confirm the intended input range.
        sr = (sr + 1.) / 2.
        hr = (hr + 1.) / 2.

        vgg_sr = self._features(sr)
        # No graph is needed for the target branch.
        with torch.no_grad():
            vgg_hr = self._features(hr.detach())

        return mean_flat((vgg_sr - vgg_hr) ** 2)
class MeanShift(nn.Conv2d):
    """Frozen 1x1 convolution applying a per-channel affine normalisation.

    For each of the three channels ``c`` the layer computes
    ``(x[c] + sign * rgb_range * rgb_mean[c]) / rgb_std[c]``; with the
    default ``sign=-1`` this subtracts the scaled mean and divides by std.

    Args:
        rgb_range: Scalar multiplied into the mean before it is shifted.
        rgb_mean: Three per-channel means.
        rgb_std: Three per-channel standard deviations (must be non-zero).
        sign: ``-1`` to subtract the mean, ``+1`` to add it back.

    Raises:
        ValueError: If ``rgb_mean`` or ``rgb_std`` does not hold exactly
            three values, or if ``rgb_std`` contains a zero.
    """

    def __init__(
            self, rgb_range,
            rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)

        dtype = self.weight.dtype
        mean = torch.as_tensor(rgb_mean, dtype=dtype).flatten()
        std = torch.as_tensor(rgb_std, dtype=dtype).flatten()

        # Fail early with a clear message instead of an opaque reshape error
        # or a silently mis-shaped bias.
        if mean.numel() != 3 or std.numel() != 3:
            raise ValueError(
                "rgb_mean and rgb_std must each contain exactly 3 values"
            )
        # A zero std would silently yield inf/nan weights.
        if bool((std == 0).any()):
            raise ValueError("rgb_std must not contain zeros")

        with torch.no_grad():
            # Diagonal kernel: each output channel is its own input channel
            # scaled by 1 / std.
            self.weight.copy_(
                torch.eye(3, dtype=dtype).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
            )
            self.bias.copy_(sign * rgb_range * mean / std)

        # Normalisation constants are fixed, never learned.
        for p in self.parameters():
            p.requires_grad = False