# SRUNET super-resolution inference helpers.
import os
import torch
import torch.nn.functional as F
import torchvision.transforms.v2 as transforms
from models.base_model import UNET
def find_padding(img, depth=2**4):
    """Return ``(h_pad, w_pad)`` making H and W multiples of *depth*.

    The UNET downsamples the input several times, so spatial dims must be
    divisible by ``depth`` (default 16 = 2**4 for 4 pooling stages —
    presumably; confirm against the model architecture).
    """
    # (-dim) % depth is 0 when dim is already a multiple, else the shortfall.
    height, width = img.shape[-2:]
    return (-height) % depth, (-width) % depth
def get_pretrained_path(model_name):
    """Return the absolute path to the pretrained checkpoint for *model_name*.

    Supported names: 'SRUNET_x2', 'SRUNET_x3', 'SRUNET_x4', 'SRUNET_x234'.

    Raises:
        ValueError: if *model_name* is not a known model (ValueError is an
            Exception subclass, so existing ``except Exception`` callers
            still work).
    """
    checkpoints = {
        'SRUNET_x2': 'SRUNET_scale_x2.pt',
        'SRUNET_x3': 'SRUNET_scale_x3.pt',
        'SRUNET_x4': 'SRUNET_scale_x4.pt',
        'SRUNET_x234': 'SRUNET_scale_x234.pt',
    }
    try:
        filename = checkpoints[model_name]
    except KeyError:
        raise ValueError('Model not found')
    # Forward slashes keep the path consistent on Windows as well.
    current_path = os.path.dirname(os.path.abspath(__file__)).replace("\\", "/")
    return current_path + '/pretrained/' + filename
def upscale_image(img, model_name, scale_factor):
    """Upscale a PIL image by *scale_factor* using a pretrained SRUNET.

    The image is first bicubic-resized to the target size, then refined by
    the network. The result is converted back to the input's original mode.

    Args:
        img: PIL image (any mode; converted to RGB for inference).
        model_name: checkpoint key understood by ``get_pretrained_path``.
        scale_factor: integer upscaling factor.

    Returns:
        A PIL image of size ``(width * scale_factor, height * scale_factor)``
        in the same mode as the input.
    """
    # PIL's .size is (width, height).
    width, height = img.size
    img_mode = img.mode
    if img.mode != "RGB":
        img = img.convert("RGB")
    transform = transforms.Compose([
        transforms.Resize((height * scale_factor, width * scale_factor),
                          interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToImage(),
        transforms.ToDtype(torch.float32, scale=True),
    ])
    # Load pretrained weights on CPU.
    checkpoint = torch.load(get_pretrained_path(model_name),
                            map_location=torch.device('cpu'))
    model = UNET()
    model.load_state_dict(checkpoint['best_model_state_dict'])
    model.eval()
    data = transform(img).clamp(0, 1).unsqueeze(0)
    # Remember the target size, pad so H/W divide the UNET's downsampling
    # factor, then crop the prediction back. Slicing with the pre-pad size
    # replaces the original four-way branching (":-0" would be a bug, which
    # is why branching was needed at all). Assumes the model output matches
    # its input spatial size — the original cropping relied on this too.
    out_h, out_w = data.shape[-2:]
    h_pad, w_pad = find_padding(data)
    data = F.pad(data, (0, w_pad, 0, h_pad), mode='reflect')
    with torch.no_grad():
        img_scale_pred = model(data).clamp(0, 1)
    img_scale_pred = img_scale_pred[..., :out_h, :out_w].squeeze(0)
    return transforms.ToPILImage()(img_scale_pred).convert(img_mode)
def enhanced_image(img, model_name):
    """Enhance a PIL image at its original resolution using a pretrained SRUNET.

    Unlike ``upscale_image`` there is no resizing step: the network acts as
    a same-size restoration/enhancement pass.

    Args:
        img: PIL image (any mode; converted to RGB for inference).
        model_name: checkpoint key understood by ``get_pretrained_path``.

    Returns:
        A PIL image of the same size and mode as the input.
    """
    img_mode = img.mode
    if img.mode != "RGB":
        img = img.convert("RGB")
    transform = transforms.Compose([
        transforms.ToImage(),
        transforms.ToDtype(torch.float32, scale=True),
    ])
    # Load pretrained weights on CPU.
    checkpoint = torch.load(get_pretrained_path(model_name),
                            map_location=torch.device('cpu'))
    model = UNET()
    model.load_state_dict(checkpoint['best_model_state_dict'])
    model.eval()
    data = transform(img).clamp(0, 1).unsqueeze(0)
    # Remember the target size, pad so H/W divide the UNET's downsampling
    # factor, then crop the prediction back. Slicing with the pre-pad size
    # replaces the original four-way branching (":-0" would be a bug, which
    # is why branching was needed at all). Assumes the model output matches
    # its input spatial size — the original cropping relied on this too.
    out_h, out_w = data.shape[-2:]
    h_pad, w_pad = find_padding(data)
    data = F.pad(data, (0, w_pad, 0, h_pad), mode='reflect')
    with torch.no_grad():
        img_scale_pred = model(data).clamp(0, 1)
    img_scale_pred = img_scale_pred[..., :out_h, :out_w].squeeze(0)
    return transforms.ToPILImage()(img_scale_pred).convert(img_mode)