| | import math |
| | import os |
| | from typing import Any, cast, Dict, List, Union |
| |
|
| | import torch |
| | from torch import nn, Tensor |
| | from torch.nn import functional as F_torch |
| | from torchvision import models, transforms |
| | from torchvision.models.feature_extraction import create_feature_extractor |
| |
|
# Public API: the two network classes and their factory functions.
__all__ = [
    "DiscriminatorForVGG", "SRResNet",
    "discriminator_for_vgg", "srresnet_x2", "srresnet_x4", "srresnet_x8",
]
| |
|
| | feature_extractor_net_cfgs: Dict[str, List[Union[str, int]]] = { |
| | "vgg11": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"], |
| | "vgg13": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"], |
| | "vgg16": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"], |
| | "vgg19": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M"], |
| | } |
| |
|
| |
|
| | def _make_layers(net_cfg_name: str, batch_norm: bool = False) -> nn.Sequential: |
| | net_cfg = feature_extractor_net_cfgs[net_cfg_name] |
| | layers: nn.Sequential[nn.Module] = nn.Sequential() |
| | in_channels = 3 |
| | for v in net_cfg: |
| | if v == "M": |
| | layers.append(nn.MaxPool2d((2, 2), (2, 2))) |
| | else: |
| | v = cast(int, v) |
| | conv2d = nn.Conv2d(in_channels, v, (3, 3), (1, 1), (1, 1)) |
| | if batch_norm: |
| | layers.append(conv2d) |
| | layers.append(nn.BatchNorm2d(v)) |
| | layers.append(nn.ReLU(True)) |
| | else: |
| | layers.append(conv2d) |
| | layers.append(nn.ReLU(True)) |
| | in_channels = v |
| |
|
| | return layers |
| |
|
| |
|
class _FeatureExtractor(nn.Module):
    """VGG-style network used as a perceptual feature-extractor backbone.

    Mirrors torchvision's VGG layout (features -> avgpool -> classifier) so
    that ImageNet-trained VGG weights can be loaded into it.
    """

    def __init__(
            self,
            net_cfg_name: str = "vgg19",
            batch_norm: bool = False,
            num_classes: int = 1000) -> None:
        """
        Args:
            net_cfg_name: Key into ``feature_extractor_net_cfgs`` selecting the VGG variant.
            batch_norm: Insert BatchNorm2d after each convolution when True.
            num_classes: Output size of the final classification layer.
        """
        super(_FeatureExtractor, self).__init__()
        self.features = _make_layers(net_cfg_name, batch_norm)

        # Pool to a fixed 7x7 map so the fully-connected head works for any input size.
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))

        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Linear(4096, num_classes),
        )

        # Weight initialization follows the torchvision VGG recipe.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.constant_(module.bias, 0)

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)

    def _forward_impl(self, x: Tensor) -> Tensor:
        # features -> pooled 7x7 map -> flatten -> class scores
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)

        return x
| |
|
| |
|
class SRResNet(nn.Module):
    """SRResNet generator: conv head, residual trunk with a long skip,
    sub-pixel upsampling stages, and a conv reconstruction tail.

    The output is clamped in place to the valid image range [0, 1].
    """

    def __init__(
            self,
            in_channels: int = 3,
            out_channels: int = 3,
            channels: int = 64,
            num_rcb: int = 16,
            upscale: int = 4,
    ) -> None:
        super(SRResNet, self).__init__()
        # Head: large 9x9 receptive field for low-frequency content.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, channels, (9, 9), (1, 1), (4, 4)),
            nn.PReLU(),
        )

        # Trunk: a chain of residual conv blocks.
        self.trunk = nn.Sequential(*(_ResidualConvBlock(channels) for _ in range(num_rcb)))

        # Fusion conv+BN applied before the long skip connection.
        self.conv2 = nn.Sequential(
            nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(channels),
        )

        # One 2x sub-pixel stage per power of two in `upscale`; any other
        # factor yields an empty (identity) Sequential, as before.
        stages = []
        if upscale in (2, 4, 8):
            stages = [_UpsampleBlock(channels, 2) for _ in range(int(math.log(upscale, 2)))]
        self.upsampling = nn.Sequential(*stages)

        # Tail: reconstruct the output image.
        self.conv3 = nn.Conv2d(channels, out_channels, (9, 9), (1, 1), (4, 4))

        # Initialize neural network weights.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)

    def _forward_impl(self, x: Tensor) -> Tensor:
        head = self.conv1(x)
        out = self.conv2(self.trunk(head))
        # Long skip connection around the whole trunk.
        out = torch.add(out, head)
        out = self.conv3(self.upsampling(out))

        # In-place clamp to the valid image range.
        return out.clamp_(0.0, 1.0)
| |
|
| |
|
class DiscriminatorForVGG(nn.Module):
    """SRGAN discriminator for fixed 96x96 input patches.

    Eight 3x3 convolutions with alternating stride 1/2 reduce the input from
    96x96 to 6x6 while growing channels (channels -> 8*channels); a two-layer
    MLP then emits one raw (un-activated) score per image.
    """

    def __init__(
            self,
            in_channels: int = 3,
            out_channels: int = 1,
            channels: int = 64,
    ) -> None:
        super(DiscriminatorForVGG, self).__init__()
        self.features = nn.Sequential(
            # (in_channels) x 96 x 96
            nn.Conv2d(in_channels, channels, (3, 3), (1, 1), (1, 1), bias=True),
            nn.LeakyReLU(0.2, True),
            # stride-2: 96 -> 48
            nn.Conv2d(channels, channels, (3, 3), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(channels),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(channels, int(2 * channels), (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(int(2 * channels)),
            nn.LeakyReLU(0.2, True),
            # stride-2: 48 -> 24
            nn.Conv2d(int(2 * channels), int(2 * channels), (3, 3), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(int(2 * channels)),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(int(2 * channels), int(4 * channels), (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(int(4 * channels)),
            nn.LeakyReLU(0.2, True),
            # stride-2: 24 -> 12
            nn.Conv2d(int(4 * channels), int(4 * channels), (3, 3), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(int(4 * channels)),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(int(4 * channels), int(8 * channels), (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(int(8 * channels)),
            nn.LeakyReLU(0.2, True),
            # stride-2: 12 -> 6
            nn.Conv2d(int(8 * channels), int(8 * channels), (3, 3), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(int(8 * channels)),
            nn.LeakyReLU(0.2, True),
        )

        # Flattened 6x6 feature map -> single score.
        self.classifier = nn.Sequential(
            nn.Linear(int(8 * channels) * 6 * 6, 1024),
            nn.LeakyReLU(0.2, True),
            nn.Linear(1024, out_channels),
        )

    def forward(self, x: Tensor) -> Tensor:
        # Bug fix: input validation used `assert`, which is stripped when
        # Python runs with -O; raise an explicit exception instead.
        if x.size(2) != 96 or x.size(3) != 96:
            raise ValueError("Input image size must be 96x96")

        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)

        return x
| |
|
| |
|
| | class _ResidualConvBlock(nn.Module): |
| | def __init__(self, channels: int) -> None: |
| | super(_ResidualConvBlock, self).__init__() |
| | self.rcb = nn.Sequential( |
| | nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1), bias=False), |
| | nn.BatchNorm2d(channels), |
| | nn.PReLU(), |
| | nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1), bias=False), |
| | nn.BatchNorm2d(channels), |
| | ) |
| |
|
| | def forward(self, x: Tensor) -> Tensor: |
| | identity = x |
| |
|
| | x = self.rcb(x) |
| |
|
| | x = torch.add(x, identity) |
| |
|
| | return x |
| |
|
| |
|
| | class _UpsampleBlock(nn.Module): |
| | def __init__(self, channels: int, upscale_factor: int) -> None: |
| | super(_UpsampleBlock, self).__init__() |
| | self.upsample_block = nn.Sequential( |
| | nn.Conv2d(channels, channels * upscale_factor * upscale_factor, (3, 3), (1, 1), (1, 1)), |
| | nn.PixelShuffle(upscale_factor), |
| | nn.PReLU(), |
| | ) |
| |
|
| | def forward(self, x: Tensor) -> Tensor: |
| | x = self.upsample_block(x) |
| |
|
| | return x |
| |
|
| |
|
class ContentLoss(nn.Module):
    """Content (perceptual) loss based on a VGG feature extractor.

    Using high-level feature maps from the later layers focuses the loss on
    the texture content of the image rather than per-pixel differences.

    Paper reference list:
    -`Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network <https://arxiv.org/pdf/1609.04802.pdf>` paper.
    -`ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks <https://arxiv.org/pdf/1809.00219.pdf>` paper.
    -`Perceptual Extreme Super Resolution Network with Receptive Field Block <https://arxiv.org/pdf/2005.12597.pdf>` paper.
    """

    def __init__(
            self,
            net_cfg_name: str,
            batch_norm: bool,
            num_classes: int,
            model_weights_path: str,
            feature_nodes: list,
            feature_normalize_mean: list,
            feature_normalize_std: list,
    ) -> None:
        super(ContentLoss, self).__init__()
        model = _FeatureExtractor(net_cfg_name, batch_norm, num_classes)
        # Weight resolution: "" -> torchvision's pretrained VGG19; an existing
        # path -> load that checkpoint; anything else is an error.
        if model_weights_path == "":
            model = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1)
        elif model_weights_path is not None and os.path.exists(model_weights_path):
            checkpoint = torch.load(model_weights_path, map_location=lambda storage, loc: storage)
            state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
            model.load_state_dict(state_dict)
        else:
            raise FileNotFoundError("Model weight file not found")

        # Keep only the requested intermediate feature nodes.
        self.feature_extractor = create_feature_extractor(model, feature_nodes)
        self.feature_extractor_nodes = feature_nodes
        # Inputs must be normalized with the statistics the extractor was trained on.
        self.normalize = transforms.Normalize(feature_normalize_mean, feature_normalize_std)

        # The extractor is a fixed loss network: freeze it and switch to eval mode.
        for model_parameters in self.feature_extractor.parameters():
            model_parameters.requires_grad = False
        self.feature_extractor.eval()

    def forward(self, sr_tensor: Tensor, gt_tensor: Tensor) -> Tensor:
        """Return a (1, num_nodes) tensor of per-node MSE feature losses."""
        # Bug fix: size check used `assert` (stripped under -O); also the
        # return annotation `[Tensor]` was not a valid type expression.
        if sr_tensor.size() != gt_tensor.size():
            raise ValueError("Two tensor must have the same size")

        sr_tensor = self.normalize(sr_tensor)
        gt_tensor = self.normalize(gt_tensor)

        sr_feature = self.feature_extractor(sr_tensor)
        gt_feature = self.feature_extractor(gt_tensor)

        losses = [
            F_torch.mse_loss(sr_feature[node], gt_feature[node])
            for node in self.feature_extractor_nodes
        ]

        # Bug fix: `torch.Tensor([losses])` copied the values into a new leaf
        # tensor, detaching them from the autograd graph so the perceptual
        # loss could not backpropagate. Stacking keeps the graph and device;
        # unsqueeze preserves the original (1, num_nodes) output shape.
        return torch.stack(losses).unsqueeze(0)
| |
|
| |
|
def srresnet_x2(**kwargs: Any) -> SRResNet:
    """Build an SRResNet configured for 2x super-resolution."""
    return SRResNet(upscale=2, **kwargs)
| |
|
| |
|
def srresnet_x4(**kwargs: Any) -> SRResNet:
    """Build an SRResNet configured for 4x super-resolution."""
    return SRResNet(upscale=4, **kwargs)
| |
|
| |
|
def srresnet_x8(**kwargs: Any) -> SRResNet:
    """Build an SRResNet configured for 8x super-resolution."""
    return SRResNet(upscale=8, **kwargs)
| |
|
| |
|
def discriminator_for_vgg(**kwargs) -> DiscriminatorForVGG:
    """Build the SRGAN discriminator with the given keyword overrides."""
    return DiscriminatorForVGG(**kwargs)
| |
|
| | import torch |
| | import torch.nn as nn |
| | from typing import Any |
| |
|
| | |
def test_srresnet(upscale_factor: int = 4):
    """Smoke test: push a random tensor through SRResNet and print the shapes."""
    batch_size, channels = 1, 1
    height, width = 24, 24

    input_tensor = torch.rand((batch_size, channels, height, width))

    model = SRResNet(in_channels=channels, out_channels=channels, upscale=upscale_factor)

    output_tensor = model(input_tensor)

    print(f"Test SRResnet Input shape: {input_tensor.shape}")
    print(f"Test SRResnet Output shape: {output_tensor.shape}")


# NOTE(review): this runs at import time as a module-level side effect.
test_srresnet(upscale_factor=1)
| |
|
| | import torch |
| | import torch.nn as nn |
| |
|
class ResidualBlock(nn.Module):
    """Conv-BN-PReLU-Conv-BN with an additive identity skip.

    NOTE(review): the skip `x + net(x)` only works when
    in_channels == out_channels and the conv settings preserve the spatial
    size (k=3, p=1 does) — confirm callers always satisfy this.
    """

    def __init__(self, in_channels, out_channels, k=3, p=1):
        super(ResidualBlock, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=k, padding=p),
            nn.BatchNorm2d(out_channels),
            nn.PReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=k, padding=p),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x):
        return self.net(x) + x
| | |
class UpsampleBLock(nn.Module):
    """Sub-pixel (PixelShuffle) upsampling block: conv to C*r^2 channels, shuffle by r."""

    def __init__(self, in_channels, scaleFactor, k=3, p=1):
        super(UpsampleBLock, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, in_channels * (scaleFactor ** 2), kernel_size=k, padding=p),
            nn.PixelShuffle(scaleFactor),
            nn.PReLU()
        )

    def forward(self, x):
        out = self.net(x)
        return out
| | |
class Generator(nn.Module):
    """SRGAN generator: 9x9 head, residual trunk with a long skip,
    two 2x sub-pixel stages (4x overall), and a 9x9 output conv.

    The tanh output is rescaled from [-1, 1] into [0, 1].
    """

    def __init__(self, n_residual=8):
        super(Generator, self).__init__()
        self.n_residual = n_residual
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=9, padding=4),
            nn.PReLU()
        )

        # Residual blocks are registered as attributes residual1..residualN.
        for i in range(n_residual):
            self.add_module('residual' + str(i+1), ResidualBlock(64, 64))

        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.PReLU()
        )

        self.upsample = nn.Sequential(
            UpsampleBLock(64, 2),
            UpsampleBLock(64, 2),
            nn.Conv2d(64, 3, kernel_size=9, padding=4)
        )

    def forward(self, x):
        y = self.conv1(x)
        # Keep the head activation for the long skip. No `.clone()` is needed:
        # the loop below only rebinds `y`, it never mutates the tensor in place.
        cache = y

        for i in range(self.n_residual):
            # Idiom fix: `getattr` instead of calling `self.__getattr__` directly.
            y = getattr(self, 'residual' + str(i+1))(y)

        y = self.conv2(y)
        y = self.upsample(y + cache)

        # Map tanh output from [-1, 1] into [0, 1].
        return (torch.tanh(y) + 1.0) / 2.0
| |
|
class Discriminator(nn.Module):
    """SRGAN discriminator: conv stack with BatchNorm, global average pool,
    1x1 conv head, and a sigmoid probability per image."""

    def __init__(self, in_channels=3, l=0.2):
        super(Discriminator, self).__init__()
        layers = [
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
            nn.LeakyReLU(l),
        ]
        # Conv/BN/LeakyReLU groups, alternating stride 2 (downsample) and
        # stride 1 (channel growth): 64 -> 128 -> 256 -> 512.
        for c_in, c_out, stride in (
            (64, 64, 2), (64, 128, 1), (128, 128, 2), (128, 256, 1),
            (256, 256, 2), (256, 512, 1), (512, 512, 2),
        ):
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(l),
            ]
        # Size-agnostic head: global average pool then 1x1 convolutions.
        layers += [
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, 1024, kernel_size=1),
            nn.LeakyReLU(l),
            nn.Conv2d(1024, 1, kernel_size=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        score = self.net(x)
        # One probability per batch element.
        return torch.sigmoid(score).view(score.size(0))
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
class Discriminator_WGAN(nn.Module):
    """WGAN critic: same conv stack as the GAN discriminator but without
    BatchNorm or a sigmoid — it returns an unbounded score per image."""

    def __init__(self, l=0.2):
        super(Discriminator_WGAN, self).__init__()
        layers = [
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.LeakyReLU(l),
        ]
        # Conv/LeakyReLU pairs, alternating stride 2 and stride 1.
        for c_in, c_out, stride in (
            (64, 64, 2), (64, 128, 1), (128, 128, 2), (128, 256, 1),
            (256, 256, 2), (256, 512, 1), (512, 512, 2),
        ):
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=1),
                nn.LeakyReLU(l),
            ]
        layers += [
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, 1024, kernel_size=1),
            nn.LeakyReLU(l),
            nn.Conv2d(1024, 1, kernel_size=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        score = self.net(x)
        # Raw critic score, one per batch element (no sigmoid for WGAN).
        return score.view(score.size()[0])
| |
|
def compute_gradient_penalty(D, real_samples, fake_samples):
    """WGAN-GP gradient penalty (Gulrajani et al., 2017).

    Samples points uniformly along the segment between real and fake samples
    and penalizes the squared deviation of the critic's gradient norm from 1.

    Args:
        D: Critic callable mapping a (N, C, H, W) batch to per-sample scores.
        real_samples: Batch of real samples.
        fake_samples: Batch of generated samples, same shape as ``real_samples``.

    Returns:
        Scalar tensor with the mean gradient penalty.
    """
    # Bug fix: the interpolation coefficient must be uniform in [0, 1]
    # (torch.rand), not standard normal (torch.randn) — otherwise the
    # interpolates leave the segment between real and fake samples.
    # Allocating directly on the inputs' device also replaces the
    # cuda-only `.cuda()` branch.
    alpha = torch.rand(real_samples.size(0), 1, 1, 1, device=real_samples.device)

    interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
    d_interpolates = D(interpolates)
    # d(outputs)/d(inputs) weights: ones, matching the critic output's device/shape.
    fake = torch.ones_like(d_interpolates)

    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=fake,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty
| |
|
| | import numpy as np |
| | import torch |
| |
|
| | import os |
| | from os import listdir |
| | from os.path import join |
| |
|
| | from PIL import Image |
| |
|
| | import torch.utils.data |
| | from torch.utils.data import DataLoader |
| | from torch.utils.data.dataset import Dataset |
| |
|
| | import torchvision.utils as utils |
| | from torchvision.transforms import Compose, RandomCrop, ToTensor, ToPILImage, CenterCrop, Resize, Normalize |
| |
|
def is_image_file(filename):
    """Return True if *filename* has a PNG/JPEG extension (case-insensitive)."""
    # str.endswith accepts a tuple of suffixes; lower-casing also covers
    # mixed-case extensions (e.g. ".Png") the old explicit list missed.
    return filename.lower().endswith(('.png', '.jpg', '.jpeg'))
| |
|
def calculate_valid_crop_size(crop_size, upscale_factor):
    """Return the largest multiple of *upscale_factor* not exceeding *crop_size*."""
    return (crop_size // upscale_factor) * upscale_factor
| |
|
def to_image():
    """Return a transform pipeline converting a tensor to a PIL image and back."""
    return Compose([ToPILImage(), ToTensor()])
| | |
class TrainDataset(Dataset):
    """(LR, HR) training pairs from the images in *dataset_dir*.

    HR: CenterCrop(384) then a random crop of the valid crop size;
    LR: bicubic downscale of that HR crop by *upscale_factor*.
    """

    def __init__(self, dataset_dir, crop_size, upscale_factor):
        super(TrainDataset, self).__init__()
        self.image_filenames = [join(dataset_dir, name) for name in listdir(dataset_dir) if is_image_file(name)]
        # Crop size must be divisible by the upscale factor.
        valid_crop = calculate_valid_crop_size(crop_size, upscale_factor)
        self.hr_preprocess = Compose([CenterCrop(384), RandomCrop(valid_crop), ToTensor()])
        self.lr_preprocess = Compose([ToPILImage(), Resize(valid_crop // upscale_factor, interpolation=Image.BICUBIC), ToTensor()])

    def __getitem__(self, index):
        hr_image = self.hr_preprocess(Image.open(self.image_filenames[index]))
        lr_image = self.lr_preprocess(hr_image)
        return lr_image, hr_image

    def __len__(self):
        return len(self.image_filenames)
| | |
class DevDataset(Dataset):
    """Validation triples (LR, bicubic-restored HR, HR) from a 128px center crop."""

    def __init__(self, dataset_dir, upscale_factor):
        super(DevDataset, self).__init__()
        self.upscale_factor = upscale_factor
        self.image_filenames = [join(dataset_dir, name) for name in listdir(dataset_dir) if is_image_file(name)]

    def __getitem__(self, index):
        hr_image = Image.open(self.image_filenames[index])
        crop_size = calculate_valid_crop_size(128, self.upscale_factor)
        lr_scale = Resize(crop_size // self.upscale_factor, interpolation=Image.BICUBIC)
        hr_scale = Resize(crop_size, interpolation=Image.BICUBIC)
        hr_image = CenterCrop(crop_size)(hr_image)
        lr_image = lr_scale(hr_image)
        # Bicubic upscale of the LR image, used as a restoration baseline.
        hr_restore_img = hr_scale(lr_image)
        to_tensor = ToTensor()
        return to_tensor(lr_image), to_tensor(hr_restore_img), to_tensor(hr_image)

    def __len__(self):
        return len(self.image_filenames)
| |
|
def print_first_parameter(net):
    """Print the first trainable parameter's leading slice (debug aid) and stop."""
    for name, param in net.named_parameters():
        if param.requires_grad:
            print(str(name) + ':' + str(param.data[0]))
            return
| |
|
def check_grads(model, model_name):
    """Check *model*'s gradients; warn and return False if the mean or max of
    the per-parameter gradient means exceeds 100, else return True."""
    # Idiom fixes: `p.grad is not None` (was `not p.grad is None`) and a
    # comprehension instead of a manual append loop.
    grads = np.array([float(p.grad.mean()) for p in model.parameters() if p.grad is not None])
    if grads.any() and grads.mean() > 100:
        print('WARNING!' + model_name + ' gradients mean is over 100.')
        return False
    if grads.any() and grads.max() > 100:
        print('WARNING!' + model_name + ' gradients max is over 100.')
        return False

    return True
| |
|
def get_grads_D(net):
    """Mean absolute gradient of the discriminator's first ('net.0') and
    last ('net.26') conv weights; 0 if the named parameter is absent."""
    top = 0
    bottom = 0
    for name, param in net.named_parameters():
        if not param.requires_grad:
            continue
        if name == 'net.0.weight':
            top = param.grad.abs().mean()
        elif name == 'net.26.weight':
            bottom = param.grad.abs().mean()
    return top, bottom
| | |
def get_grads_D_WAN(net):
    """Mean absolute gradient of the WGAN critic's first ('net.0') and
    last ('net.19') conv weights; 0 if the named parameter is absent."""
    top = 0
    bottom = 0
    for name, param in net.named_parameters():
        if not param.requires_grad:
            continue
        if name == 'net.0.weight':
            top = param.grad.abs().mean()
        elif name == 'net.19.weight':
            bottom = param.grad.abs().mean()
    return top, bottom
| |
|
def get_grads_G(net):
    """Mean absolute gradient of the generator's first ('conv1.0') and
    last ('upsample.2') conv weights; 0 if the named parameter is absent."""
    top = 0
    bottom = 0
    for name, param in net.named_parameters():
        if not param.requires_grad:
            continue
        if name == 'conv1.0.weight':
            top = param.grad.abs().mean()
        elif name == 'upsample.2.weight':
            bottom = param.grad.abs().mean()
    return top, bottom
| |
|
| | import torch |
| |
|
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| |
|
| | |