| | import torch |
| | import yaml |
| |
|
| | from model import Swin2MoSE |
| |
|
| |
|
def to_shape(t1, t2):
    """Tile a per-channel vector so it broadcasts against a batched tensor.

    Args:
        t1: 1-D tensor holding one value per channel of ``t2``.
        t2: batched, channel-first tensor of shape ``(batch, channels, ...)``.
            Only its shape is read.

    Returns:
        A tensor of shape ``(batch, channels, 1, ..., 1)`` with the same rank
        as ``t2``. Every batch row is a copy of ``t1``, so it broadcasts
        channel-wise against ``t2``. For 4-D ``t2`` this is
        ``(batch, channels, 1, 1)``.
    """
    batch, channels = t2.shape[:2]
    # One copy of the channel vector per batch element: (batch, channels).
    tiled = t1[None].repeat(batch, 1)
    # Pad with singleton dims up to t2's rank. Hard-coding two trailing dims
    # would broadcast to a wrong, larger shape for non-4-D inputs.
    trailing = (1,) * (len(t2.shape) - 2)
    return tiled.view(batch, channels, *trailing)
| |
|
| |
|
def norm(tensor, mean, std):
    """Standardise ``tensor`` channel by channel: ``(tensor - mean) / std``.

    Args:
        tensor: batched, channel-first tensor, expected as
            ``(batch, channels, H, W)``. Dim 1 must match the length of
            ``mean`` and ``std``.
        mean: per-channel means; anything ``torch.tensor`` accepts.
        std: per-channel standard deviations; same form as ``mean``.

    Returns:
        The normalised tensor, computed out of place.
    """
    dev = tensor.device
    # Build both statistic vectors on the input's device before reshaping.
    mean_vec, std_vec = (torch.tensor(v).to(dev) for v in (mean, std))
    centred = tensor - to_shape(mean_vec, tensor)
    return centred / to_shape(std_vec, tensor)
| |
|
| |
|
def denorm(tensor, mean, std):
    """Undo ``norm``: map a standardised tensor back via ``tensor * std + mean``.

    Args:
        tensor: batched, channel-first tensor, expected as
            ``(batch, channels, H, W)``. Dim 1 must match the length of
            ``mean`` and ``std``.
        mean: per-channel means; anything ``torch.tensor`` accepts.
        std: per-channel standard deviations; same form as ``mean``.

    Returns:
        The de-normalised tensor, computed out of place.
    """
    dev = tensor.device
    # Build both statistic vectors on the input's device before reshaping.
    mean_vec, std_vec = (torch.tensor(v).to(dev) for v in (mean, std))
    scaled = tensor * to_shape(std_vec, tensor)
    return scaled + to_shape(mean_vec, tensor)
| |
|
| |
|
def load_config(path):
    """Read a YAML file with ``yaml.safe_load`` and return the parsed object.

    Args:
        path: filesystem path of the YAML configuration file.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file (typically a dict;
        ``None`` for an empty file).
    """
    # YAML is UTF-8; pin the encoding instead of using the platform locale
    # default, which can mis-decode non-ASCII configs (e.g. cp1252 on Windows).
    with open(path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)
| |
|
| |
|
def load_swin2_mose(model_weights, cfg, map_location='cpu'):
    """Build a Swin2MoSE model from ``cfg`` and load checkpoint weights into it.

    Args:
        model_weights: checkpoint path or file object passed to ``torch.load``.
            It must contain a ``'model_state_dict'`` entry.
        cfg: parsed config dict. ``cfg['super_res']['model']`` supplies the
            ``Swin2MoSE`` constructor keyword arguments.
        map_location: passed to ``torch.load``. Defaults to ``'cpu'``, so a
            checkpoint saved on a GPU can be opened on a CPU-only machine.
            ``load_state_dict`` copies the values into the model's own
            parameters either way.

    Returns:
        The model in eval mode, with ``cfg`` attached as ``sr_model.cfg``.
    """
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints. Consider weights_only=True if the checkpoint allows it.
    checkpoint = torch.load(model_weights, map_location=map_location)

    sr_model = Swin2MoSE(**cfg['super_res']['model'])
    sr_model.load_state_dict(checkpoint['model_state_dict'])
    # Freshly built modules are in training mode. Switch to inference mode so
    # dropout and other training-only behaviour is disabled.
    sr_model.eval()

    # run_swin2_mose reads the normalisation stats from this attribute.
    sr_model.cfg = cfg
    return sr_model
| |
|
| |
|
def run_swin2_mose(model, lr, hr):
    """Super-resolve ``lr`` with ``model`` and return lr/sr/hr tensors.

    Args:
        model: a Swin2MoSE model carrying its config as ``model.cfg``. The
            config provides ``['dataset']['stats']`` entries
            ``'tensor_10m_b2b3b4b8'`` (input stats) and
            ``'tensor_05m_b2b3b4b8'`` (output stats), each with per-channel
            ``'mean'`` and ``'std'``.
        lr: array-like low-resolution image with channels on axis 0. Only
            channels 3, 2, 1 and 7 (in that order) are fed to the model.
            Presumably these are the bands matching the ``*_b2b3b4b8`` stats;
            confirm the ordering against the data source.
        hr: array-like high-resolution reference, returned unchanged except
            for conversion to a float tensor.

    Returns:
        dict with CPU float tensors (batch axis removed):
        ``"lr"``: the selected input channels in original units;
        ``"sr"``: the model output, de-normalised with the high-res stats;
        ``"hr"``: the reference image.
    """
    cfg = model.cfg
    hr_stats = cfg['dataset']['stats']['tensor_05m_b2b3b4b8']
    lr_stats = cfg['dataset']['stats']['tensor_10m_b2b3b4b8']

    # Add a batch axis, convert to float and keep the four model bands.
    lr_orig = torch.tensor(lr)[None].float()[:, [3, 2, 1, 7]]
    hr_orig = torch.tensor(hr)[None].float()

    # Feed the network on whichever device holds its weights.
    device = next(model.parameters()).device
    lr_in = norm(lr_orig.to(device), mean=lr_stats['mean'], std=lr_stats['std'])

    # Inference only: without no_grad the returned tensor would keep the whole
    # autograd graph (every activation) alive and could not call .numpy().
    with torch.no_grad():
        sr = model(lr_in)
    # The model may return (output, extra) rather than a bare tensor.
    if not torch.is_tensor(sr):
        sr, _ = sr

    sr = denorm(sr, mean=hr_stats['mean'], std=hr_stats['std'])

    return {
        "lr": lr_orig[0],
        "sr": sr[0].cpu(),
        "hr": hr_orig[0],
    }
| |
|