| | import os, sys |
| | import torch |
| |
|
| | |
# Append the absolute current working directory to the import path so that
# the project-local packages imported below (opt, architecture.*) can be found.
# NOTE(review): this depends on the CWD at launch time, not on this file's
# location -- presumably the script is meant to be run from the repo root.
root_path = os.path.abspath(os.curdir)
sys.path.append(root_path)
| | from opt import opt |
| | from architecture.rrdb import RRDBNet |
| | from architecture.grl import GRL |
| | from architecture.dat import DAT |
| | from architecture.swinir import SwinIR |
| | from architecture.cunet import UNet_Full |
| |
|
| |
|
def load_rrdb(generator_weight_PATH, scale, print_options=False):
    ''' A simpler API to load RRDB model from Real-ESRGAN
    Args:
        generator_weight_PATH (str): The path to the weight
        scale (int): the scaling factor
        print_options (bool): whether to print the 'opt' entry stored in the
            checkpoint (if any) to show what kinds of setting is used
    Returns:
        generator (torch.nn.Module): the RRDBNet generator, in eval mode, on CUDA
    Raises:
        ValueError: if the checkpoint has none of the supported weight keys
    '''
    # Map tensors to CPU: they are copied into the freshly built model and the
    # model is moved to CUDA afterwards, so this avoids failing on checkpoints
    # saved on a GPU index that does not exist on this machine.
    # NOTE(review): torch.load unpickles the file -- only load trusted weights.
    checkpoint_g = torch.load(generator_weight_PATH, map_location="cpu")

    # Supported weight keys, checked in priority order ('params_ema' first).
    for weight_key in ('params_ema', 'params', 'model_state_dict'):
        if weight_key in checkpoint_g:
            weight = checkpoint_g[weight_key]
            break
    else:
        # Raise rather than os._exit(0): that terminated the whole process with
        # a success status and without flushing stdout.
        raise ValueError("This weight is not supported")

    # Weights saved from a torch.compile()-wrapped model carry an "_orig_mod."
    # prefix on every key; strip it so the keys match the plain RRDBNet.
    # Edit in place over a snapshot of the keys so any attributes attached to
    # the state dict object are preserved.
    prefix = "_orig_mod."
    for old_key in list(weight):
        if old_key.startswith(prefix):
            weight[old_key[len(prefix):]] = weight.pop(old_key)

    generator = RRDBNet(3, 3, scale=scale)
    generator.load_state_dict(weight)
    generator = generator.eval().cuda()

    # Optionally echo the settings stored alongside the weights.
    if print_options and 'opt' in checkpoint_g:
        options = checkpoint_g['opt']
        for key in options:
            print(f'{key} : {options[key]}')

    return generator
| |
|
| |
|
def load_cunet(generator_weight_PATH, scale, print_options=False):
    ''' A simpler API to load CUNET model from Real-CUGAN
    Args:
        generator_weight_PATH (str): The path to the weight
        scale (int): the scaling factor (only 2 is supported)
        print_options (bool): whether to print the 'opt' entry stored in the
            checkpoint (if any) to show what kinds of setting is used
    Returns:
        generator (torch.nn.Module): the UNet_Full generator, in eval mode, on CUDA
    Raises:
        NotImplementedError: if scale is not 2
        ValueError: if the checkpoint has no 'model_state_dict' entry
    '''
    # UNet_Full is built without a scale argument, so only 2x is available.
    if scale != 2:
        raise NotImplementedError("We only support 2x in CUNET")

    # Map tensors to CPU; the model is moved to CUDA after loading.
    # NOTE(review): torch.load unpickles the file -- only load trusted weights.
    checkpoint_g = torch.load(generator_weight_PATH, map_location="cpu")

    if 'model_state_dict' not in checkpoint_g:
        # Raise rather than os._exit(0), which killed the process with a
        # success status and without flushing stdout.
        raise ValueError("This weight is not supported")

    weight = checkpoint_g['model_state_dict']

    # Report the bookkeeping values saved with the checkpoint; either may be
    # absent, so both fall back to "NAN" instead of raising KeyError.
    loss = checkpoint_g.get("lowest_generator_weight", "NAN")
    iteration = checkpoint_g.get("iteration", "NAN")
    print(f"the generator weight is {loss} at iteration {iteration}")

    # Strip the "_orig_mod." key prefix added when the model was wrapped by
    # torch.compile(); edited in place to keep the state dict object intact.
    prefix = "_orig_mod."
    for old_key in list(weight):
        if old_key.startswith(prefix):
            weight[old_key[len(prefix):]] = weight.pop(old_key)

    generator = UNet_Full()
    generator.load_state_dict(weight)
    generator = generator.eval().cuda()

    # Optionally echo the settings stored alongside the weights.
    if print_options and 'opt' in checkpoint_g:
        options = checkpoint_g['opt']
        for key in options:
            print(f'{key} : {options[key]}')

    return generator
| |
|
def load_grl(generator_weight_PATH, scale=4):
    ''' A simpler API to load GRL model
    Args:
        generator_weight_PATH (str): The path to the weight
        scale (int): Scale Factor (Usually Set as 4)
    Returns:
        generator (torch.nn.Module): the GRL generator, in eval mode, on CUDA
    Raises:
        ValueError: if the checkpoint has no 'model_state_dict' entry
    '''
    # Map tensors to CPU; the model is moved to CUDA once, after loading.
    # NOTE(review): torch.load unpickles the file -- only load trusted weights.
    checkpoint_g = torch.load(generator_weight_PATH, map_location="cpu")

    if 'model_state_dict' not in checkpoint_g:
        # Raise rather than os._exit(0), which killed the process with a
        # success status and without flushing stdout.
        raise ValueError("This weight is not supported")

    weight = checkpoint_g['model_state_dict']

    # Strip the "_orig_mod." key prefix added when the model was wrapped by
    # torch.compile(), matching the RRDB/CUNET loaders in this module.
    prefix = "_orig_mod."
    for old_key in list(weight):
        if old_key.startswith(prefix):
            weight[old_key[len(prefix):]] = weight.pop(old_key)

    # Fixed architecture hyper-parameters: load_state_dict is strict, so the
    # checkpoint must have been trained with exactly this configuration.
    generator = GRL(
        upscale=scale,
        img_size=64,
        window_size=8,
        depths=[4, 4, 4, 4],
        embed_dim=64,
        num_heads_window=[2, 2, 2, 2],
        num_heads_stripe=[2, 2, 2, 2],
        mlp_ratio=2,
        qkv_proj_type="linear",
        anchor_proj_type="avgpool",
        anchor_window_down_factor=2,
        out_proj_type="linear",
        conv_type="1conv",
        upsampler="nearest+conv",
    )
    generator.load_state_dict(weight)
    generator = generator.eval().cuda()

    # Report the trainable parameter count in millions.
    num_params = sum(p.numel() for p in generator.parameters() if p.requires_grad)
    print(f"Number of parameters {num_params / 10 ** 6: 0.2f}")

    return generator
| |
|
| |
|
| |
|
def load_dat(generator_weight_PATH, scale=4):
    ''' A simpler API to load DAT model
    Args:
        generator_weight_PATH (str): The path to the weight
        scale (int): Scale Factor, passed to DAT as `upscale` (defaults to 4)
    Returns:
        generator (torch.nn.Module): the DAT generator, in eval mode, on CUDA
    Raises:
        ValueError: if the checkpoint has no 'model_state_dict' entry
    '''
    # Map tensors to CPU; the model is moved to CUDA once, after loading.
    # NOTE(review): torch.load unpickles the file -- only load trusted weights.
    checkpoint_g = torch.load(generator_weight_PATH, map_location="cpu")

    if 'model_state_dict' not in checkpoint_g:
        # Raise rather than os._exit(0), which killed the process with a
        # success status and without flushing stdout.
        raise ValueError("This weight is not supported")

    weight = checkpoint_g['model_state_dict']

    # Strip the "_orig_mod." key prefix added when the model was wrapped by
    # torch.compile(), matching the RRDB/CUNET loaders in this module.
    prefix = "_orig_mod."
    for old_key in list(weight):
        if old_key.startswith(prefix):
            weight[old_key[len(prefix):]] = weight.pop(old_key)

    # Fixed architecture hyper-parameters: load_state_dict is strict, so the
    # checkpoint must have been trained with exactly this configuration.
    # `upscale` now follows `scale` (it was previously hard-coded to 4, which
    # silently ignored the argument); the default keeps the old behaviour.
    generator = DAT(
        upscale=scale,
        in_chans=3,
        img_size=64,
        img_range=1.,
        depth=[6, 6, 6, 6, 6, 6],
        embed_dim=180,
        num_heads=[6, 6, 6, 6, 6, 6],
        expansion_factor=2,
        resi_connection='1conv',
        split_size=[8, 16],
        upsampler='pixelshuffledirect',
    )
    generator.load_state_dict(weight)
    generator = generator.eval().cuda()

    # Report the trainable parameter count in millions.
    num_params = sum(p.numel() for p in generator.parameters() if p.requires_grad)
    print(f"Number of parameters {num_params / 10 ** 6: 0.2f}")

    return generator
| |
|