Spaces:
Running
Running
| # -*- coding: utf-8 -*- | |
| import numpy as np | |
| import torch | |
| from utils import utils_image as util | |
| import re | |
| import glob | |
| import os | |
| ''' | |
| # -------------------------------------------- | |
| # Model | |
| # -------------------------------------------- | |
| # Kai Zhang (github: https://github.com/cszn) | |
| # 03/Mar/2019 | |
| # -------------------------------------------- | |
| ''' | |
def find_last_checkpoint(save_dir, net_type='G', pretrained_path=None):
    """
    # ---------------------------------------
    # Kai Zhang (github: https://github.com/cszn)
    # 03/Mar/2019
    # ---------------------------------------
    Find the checkpoint with the highest iteration number in a folder.

    Only files whose name is exactly '<digits>_<net_type>.pth' are
    considered; any other file that happens to end with '_<net_type>.pth'
    (e.g. 'best_G.pth') is ignored instead of raising an IndexError.

    Args:
        save_dir: model folder
        net_type: 'G' or 'D' or 'optimizerG' or 'optimizerD'
        pretrained_path: pretrained model path. If save_dir does not have
            any matching model, this value is returned as init_path
    Return:
        init_iter: iteration number (0 when no checkpoint is found)
        init_path: path of the checkpoint file found on disk, or
            pretrained_path when no checkpoint is found
    # ---------------------------------------
    """
    suffix = '_{}.pth'.format(net_type)
    # Escape the suffix so that '.' (and anything unusual in net_type) is
    # matched literally rather than as a regex metacharacter.
    name_pattern = re.compile(r'(\d+)' + re.escape(suffix))

    best_iter = None
    best_name = None
    for file_ in glob.glob(os.path.join(save_dir, '*' + suffix)):
        # Match on the file name only, so digits in the directory part of
        # the path can never be mistaken for an iteration number.
        name = os.path.basename(file_)
        match = name_pattern.fullmatch(name)
        if match is None:
            continue
        iter_current = int(match.group(1))
        if best_iter is None or iter_current > best_iter:
            best_iter, best_name = iter_current, name

    if best_iter is None:
        return 0, pretrained_path
    # Use the real file name (not one rebuilt from the int) so zero-padded
    # names such as '000500_G.pth' still resolve to an existing file.
    return best_iter, os.path.join(save_dir, best_name)
| def test_mode(model, L, mode=0, refield=32, min_size=256, sf=1, modulo=1): | |
| ''' | |
| # --------------------------------------- | |
| # Kai Zhang (github: https://github.com/cszn) | |
| # 03/Mar/2019 | |
| # --------------------------------------- | |
| Args: | |
| model: trained model | |
| L: input Low-quality image | |
| mode: | |
| (0) normal: test(model, L) | |
| (1) pad: test_pad(model, L, modulo=16) | |
| (2) split: test_split(model, L, refield=32, min_size=256, sf=1, modulo=1) | |
| (3) x8: test_x8(model, L, modulo=1) ^_^ | |
| (4) split and x8: test_split_x8(model, L, refield=32, min_size=256, sf=1, modulo=1) | |
| refield: effective receptive filed of the network, 32 is enough | |
| useful when split, i.e., mode=2, 4 | |
| min_size: min_sizeXmin_size image, e.g., 256X256 image | |
| useful when split, i.e., mode=2, 4 | |
| sf: scale factor for super-resolution, otherwise 1 | |
| modulo: 1 if split | |
| useful when pad, i.e., mode=1 | |
| Returns: | |
| E: estimated image | |
| # --------------------------------------- | |
| ''' | |
| if mode == 0: | |
| E = test(model, L) | |
| elif mode == 1: | |
| E = test_pad(model, L, modulo, sf) | |
| elif mode == 2: | |
| E = test_split(model, L, refield, min_size, sf, modulo) | |
| elif mode == 3: | |
| E = test_x8(model, L, modulo, sf) | |
| elif mode == 4: | |
| E = test_split_x8(model, L, refield, min_size, sf, modulo) | |
| return E | |
| ''' | |
| # -------------------------------------------- | |
| # normal (0) | |
| # -------------------------------------------- | |
| ''' | |
| def test(model, L): | |
| E = model(L) | |
| return E | |
| ''' | |
| # -------------------------------------------- | |
| # pad (1) | |
| # -------------------------------------------- | |
| ''' | |
| def test_pad(model, L, modulo=16, sf=1): | |
| h, w = L.size()[-2:] | |
| paddingBottom = int(np.ceil(h/modulo)*modulo-h) | |
| paddingRight = int(np.ceil(w/modulo)*modulo-w) | |
| L = torch.nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(L) | |
| E = model(L) | |
| E = E[..., :h*sf, :w*sf] | |
| return E | |
| ''' | |
| # -------------------------------------------- | |
| # split (function) | |
| # -------------------------------------------- | |
| ''' | |
| def test_split_fn(model, L, refield=32, min_size=256, sf=1, modulo=1): | |
| """ | |
| Args: | |
| model: trained model | |
| L: input Low-quality image | |
| refield: effective receptive filed of the network, 32 is enough | |
| min_size: min_sizeXmin_size image, e.g., 256X256 image | |
| sf: scale factor for super-resolution, otherwise 1 | |
| modulo: 1 if split | |
| Returns: | |
| E: estimated result | |
| """ | |
| h, w = L.size()[-2:] | |
| if h*w <= min_size**2: | |
| L = torch.nn.ReplicationPad2d((0, int(np.ceil(w/modulo)*modulo-w), 0, int(np.ceil(h/modulo)*modulo-h)))(L) | |
| E = model(L) | |
| E = E[..., :h*sf, :w*sf] | |
| else: | |
| top = slice(0, (h//2//refield+1)*refield) | |
| bottom = slice(h - (h//2//refield+1)*refield, h) | |
| left = slice(0, (w//2//refield+1)*refield) | |
| right = slice(w - (w//2//refield+1)*refield, w) | |
| Ls = [L[..., top, left], L[..., top, right], L[..., bottom, left], L[..., bottom, right]] | |
| if h * w <= 4*(min_size**2): | |
| Es = [model(Ls[i]) for i in range(4)] | |
| else: | |
| Es = [test_split_fn(model, Ls[i], refield=refield, min_size=min_size, sf=sf, modulo=modulo) for i in range(4)] | |
| b, c = Es[0].size()[:2] | |
| E = torch.zeros(b, c, sf * h, sf * w).type_as(L) | |
| E[..., :h//2*sf, :w//2*sf] = Es[0][..., :h//2*sf, :w//2*sf] | |
| E[..., :h//2*sf, w//2*sf:w*sf] = Es[1][..., :h//2*sf, (-w + w//2)*sf:] | |
| E[..., h//2*sf:h*sf, :w//2*sf] = Es[2][..., (-h + h//2)*sf:, :w//2*sf] | |
| E[..., h//2*sf:h*sf, w//2*sf:w*sf] = Es[3][..., (-h + h//2)*sf:, (-w + w//2)*sf:] | |
| return E | |
| ''' | |
| # -------------------------------------------- | |
| # split (2) | |
| # -------------------------------------------- | |
| ''' | |
def test_split(model, L, refield=32, min_size=256, sf=1, modulo=1):
    """Split-mode entry point: forwards every argument to test_split_fn."""
    return test_split_fn(model, L, refield=refield, min_size=min_size, sf=sf, modulo=modulo)
| ''' | |
| # -------------------------------------------- | |
| # x8 (3) | |
| # -------------------------------------------- | |
| ''' | |
def test_x8(model, L, modulo=1, sf=1):
    """Average the model output over eight augmented versions of L.

    Each of the eight `util.augment_img_tensor4` modes (0..7) is applied to
    L, the result goes through `test_pad`, and the output is transformed
    again before averaging: with mode `8 - i` for i in (3, 5), with mode `i`
    otherwise (presumably the inverse transform -- verify against
    util.augment_img_tensor4).
    """
    # Phase 1: run the model on all eight augmented inputs.
    outputs = [
        test_pad(model, util.augment_img_tensor4(L, mode=i), modulo=modulo, sf=sf)
        for i in range(8)
    ]
    # Phase 2: transform each output back before averaging.
    restored = [
        util.augment_img_tensor4(out, mode=(8 - i) if i in (3, 5) else i)
        for i, out in enumerate(outputs)
    ]
    return torch.stack(restored, dim=0).mean(dim=0, keepdim=False)
| ''' | |
| # -------------------------------------------- | |
| # split and x8 (4) | |
| # -------------------------------------------- | |
| ''' | |
def test_split_x8(model, L, refield=32, min_size=256, sf=1, modulo=1):
    """Combine split testing with x8 augmentation averaging.

    Each of the eight `util.augment_img_tensor4` modes (0..7) is applied to
    L, the result goes through `test_split_fn`, and the output is
    transformed again before averaging: with mode `8 - i` for i in (3, 5),
    with mode `i` otherwise (presumably the inverse transform -- verify
    against util.augment_img_tensor4).
    """
    # Phase 1: split-test all eight augmented inputs.
    outputs = [
        test_split_fn(
            model,
            util.augment_img_tensor4(L, mode=i),
            refield=refield,
            min_size=min_size,
            sf=sf,
            modulo=modulo,
        )
        for i in range(8)
    ]
    # Phase 2: transform each output back before averaging.
    restored = [
        util.augment_img_tensor4(out, mode=(8 - i) if i in (3, 5) else i)
        for i, out in enumerate(outputs)
    ]
    return torch.stack(restored, dim=0).mean(dim=0, keepdim=False)
| ''' | |
| # ^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^- | |
| # _^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^ | |
| # ^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^- | |
| ''' | |
| ''' | |
| # -------------------------------------------- | |
| # -------------------------------------------- | |
| ''' | |
| # -------------------------------------------- | |
| # print model | |
| # -------------------------------------------- | |
def print_model(model):
    """Print the text produced by describe_model(model)."""
    print(describe_model(model))
| # -------------------------------------------- | |
| # print params | |
| # -------------------------------------------- | |
def print_params(model):
    """Print the text produced by describe_params(model)."""
    print(describe_params(model))
| ''' | |
| # -------------------------------------------- | |
| # information | |
| # -------------------------------------------- | |
| ''' | |
| # -------------------------------------------- | |
| # model information | |
| # -------------------------------------------- | |
def info_model(model):
    """Return the text produced by describe_model(model)."""
    return describe_model(model)
| # -------------------------------------------- | |
| # params information | |
| # -------------------------------------------- | |
def info_params(model):
    """Return the text produced by describe_params(model)."""
    return describe_params(model)
| ''' | |
| # -------------------------------------------- | |
| # description | |
| # -------------------------------------------- | |
| ''' | |
| # -------------------------------------------- | |
| # model name and total number of parameters | |
| # -------------------------------------------- | |
def describe_model(model):
    """Build a text summary of a model: class name, parameter count, structure.

    A `torch.nn.DataParallel` wrapper is unwrapped first so the summary
    describes the wrapped module.

    Returns:
        msg: multi-line string that starts and ends with a newline
    """
    net = model.module if isinstance(model, torch.nn.DataParallel) else model
    total_params = sum(p.numel() for p in net.parameters())
    # Empty first/last entries give the leading and trailing newline.
    pieces = [
        '',
        'models name: {}'.format(net.__class__.__name__),
        'Params number: {}'.format(total_params),
        'Net structure:\n{}'.format(str(net)),
        '',
    ]
    return '\n'.join(pieces)
| # -------------------------------------------- | |
| # parameters description | |
| # -------------------------------------------- | |
def describe_params(model):
    """Build a text table with per-tensor statistics of a model's state dict.

    A `torch.nn.DataParallel` wrapper is unwrapped first. For every entry of
    `state_dict()` whose name does not contain 'num_batches_tracked', one
    row is emitted with mean, min, max, std, shape and name.

    Returns:
        msg: multi-line string that starts and ends with a newline; the
            first non-empty line is the column header
    """
    if isinstance(model, torch.nn.DataParallel):
        model = model.module

    # The header needs six columns to line up with the six values in each
    # row. The original format string had only five placeholders, so 'shape'
    # was printed over the name column and 'param_name' was silently dropped.
    header = ' | {:^6s} | {:^6s} | {:^6s} | {:^6s} | {:^6s} || {:<20s}'.format(
        'mean', 'min', 'max', 'std', 'shape', 'param_name')
    row_format = ' | {:>6.3f} | {:>6.3f} | {:>6.3f} | {:>6.3f} | {} || {:s}'

    lines = ['', header]
    for name, param in model.state_dict().items():
        if 'num_batches_tracked' in name:
            continue
        # Work on a float copy so integer tensors can be summarised too.
        v = param.data.clone().float()
        lines.append(row_format.format(v.mean(), v.min(), v.max(), v.std(), v.shape, name))
    lines.append('')
    return '\n'.join(lines)
if __name__ == '__main__':
    # Smoke test: run a single-conv network through every test mode.

    class Net(torch.nn.Module):
        """Minimal network: one 3x3 convolution that keeps the spatial size."""

        def __init__(self, in_channels=3, out_channels=3):
            super(Net, self).__init__()
            self.conv = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1)

        def forward(self, x):
            x = self.conv(x)
            return x

    # NOTE: the two torch.cuda.Event objects the original script created were
    # never used, and constructing them can fail on CPU-only PyTorch builds,
    # so they are removed.
    model = Net()
    model = model.eval()
    print_model(model)
    print_params(model)
    x = torch.randn((2, 3, 401, 401))
    torch.cuda.empty_cache()
    with torch.no_grad():
        for mode in range(5):
            y = test_mode(model, x, mode, refield=32, min_size=256, sf=1, modulo=1)
            print(y.shape)
| # run utils/utils_model.py | |