import os

import torch
import torch.nn.functional as F
import torchvision.transforms.v2 as transforms

from models.base_model import UNET
def find_padding(img, depth=2**4):
    # Amount of bottom/right padding needed so that H and W become multiples of
    # `depth` (default 16, i.e. the model's assumed total downsampling factor).
    B, C, H, W = img.shape
    h_pad = (depth - H % depth) % depth
    w_pad = (depth - W % depth) % depth
    return h_pad, w_pad
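
# Quick illustration (hypothetical shape, not from the original code): for a
# 1x3x250x300 tensor with depth=16, 250 % 16 == 10 and 300 % 16 == 12, so
# find_padding returns (6, 4) -- pad 6 rows at the bottom and 4 columns on the right.
#
#   dummy = torch.zeros(1, 3, 250, 300)
#   assert find_padding(dummy) == (6, 4)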
def get_pretrained_path(model_name):
    # Supported names: 'SRUNET_x2', 'SRUNET_x3', 'SRUNET_x4', 'SRUNET_x234',
    # 'SRUNET_interpolation', 'SRUNET_x234_interpolation'
    current_path = os.path.dirname(os.path.abspath(__file__)).replace("\\", "/")
    if model_name == 'SRUNET_x2':
        return current_path + '/pretrained/SRUNET_scale_x2.pt'
    elif model_name == 'SRUNET_x3':
        return current_path + '/pretrained/SRUNET_scale_x3.pt'
    elif model_name == 'SRUNET_x4':
        return current_path + '/pretrained/SRUNET_scale_x4.pt'
    elif model_name == 'SRUNET_x234':
        return current_path + '/pretrained/SRUNET_scale_x234.pt'
    # elif model_name == 'SRUNET_interpolation':
    #     return current_path + '/pretrained/SRUNET_x3.pt'
    # elif model_name == 'SRUNET_x234_interpolation':
    #     return current_path + '/pretrained/SRUNET_x3.pt'
    else:
        raise ValueError(f'Model not found: {model_name}')
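
# For instance, if this module lived at /repo/super_resolution.py (path purely
# illustrative), get_pretrained_path('SRUNET_x2') would resolve to
# /repo/pretrained/SRUNET_scale_x2.pt; unsupported names raise ValueError.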
def upscale_image(img, model_name, scale_factor):
    # PIL size is (width, height); remember the original mode so the output
    # can be converted back (e.g. for 'L' or 'RGBA' inputs).
    width, height = img.size
    img_mode = img.mode
    if img.mode != "RGB":
        img = img.convert("RGB")

    # Bicubic-upscale to the target size, then convert to a float tensor in [0, 1].
    transform = transforms.Compose([
        transforms.Resize((height * scale_factor, width * scale_factor),
                          interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToImage(),
        transforms.ToDtype(torch.float32, scale=True),
    ])

    # Load the pretrained model on CPU.
    checkpoint = torch.load(get_pretrained_path(model_name),
                            map_location=torch.device('cpu'))
    model = UNET()
    model.load_state_dict(checkpoint['best_model_state_dict'])
    model.eval()

    data = transform(img).clamp(0, 1).unsqueeze(0)

    # Reflect-pad bottom/right so H and W are multiples of the model's depth factor.
    h_pad, w_pad = find_padding(data)
    data = F.pad(data, (0, w_pad, 0, h_pad), mode='reflect')

    with torch.no_grad():
        img_scale_pred = model(data).clamp(0, 1)

    # Crop the padded rows/columns back off.
    if h_pad > 0 and w_pad > 0:
        img_scale_pred = img_scale_pred[..., :-h_pad, :-w_pad]
    elif h_pad > 0:
        img_scale_pred = img_scale_pred[..., :-h_pad, :]
    elif w_pad > 0:
        img_scale_pred = img_scale_pred[..., :, :-w_pad]

    img_scale_pred = img_scale_pred.squeeze(0)
    return transforms.ToPILImage()(img_scale_pred).convert(img_mode)
def enhanced_image(img, model_name):
    # Same pipeline as upscale_image but without the bicubic pre-resize,
    # so the network sees the image at its original resolution.
    img_mode = img.mode
    if img.mode != "RGB":
        img = img.convert("RGB")

    # Convert to a float tensor in [0, 1].
    transform = transforms.Compose([
        transforms.ToImage(),
        transforms.ToDtype(torch.float32, scale=True),
    ])

    # Load the pretrained model on CPU.
    checkpoint = torch.load(get_pretrained_path(model_name),
                            map_location=torch.device('cpu'))
    model = UNET()
    model.load_state_dict(checkpoint['best_model_state_dict'])
    model.eval()

    data = transform(img).clamp(0, 1).unsqueeze(0)

    # Reflect-pad bottom/right so H and W are multiples of the model's depth factor.
    h_pad, w_pad = find_padding(data)
    data = F.pad(data, (0, w_pad, 0, h_pad), mode='reflect')

    with torch.no_grad():
        img_scale_pred = model(data).clamp(0, 1)

    # Crop the padded rows/columns back off.
    if h_pad > 0 and w_pad > 0:
        img_scale_pred = img_scale_pred[..., :-h_pad, :-w_pad]
    elif h_pad > 0:
        img_scale_pred = img_scale_pred[..., :-h_pad, :]
    elif w_pad > 0:
        img_scale_pred = img_scale_pred[..., :, :-w_pad]

    img_scale_pred = img_scale_pred.squeeze(0)
    return transforms.ToPILImage()(img_scale_pred).convert(img_mode)
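
# Minimal usage sketch, not part of the original module. Assumptions: Pillow is
# installed, "input.png" exists next to this file, the matching weights are
# present under ./pretrained, and the model_name choices are illustrative.
if __name__ == "__main__":
    from PIL import Image

    img = Image.open("input.png")

    # 2x super-resolution: bicubic upscale followed by the SRUNET refinement pass.
    upscaled = upscale_image(img, model_name="SRUNET_x2", scale_factor=2)
    upscaled.save("output_x2.png")

    # Enhancement at the original resolution (no resize step).
    enhanced = enhanced_image(img, model_name="SRUNET_x2")
    enhanced.save("output_enhanced.png")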