| import cv2 |
| import numpy as np |
| import torch |
| import os |
| import tempfile |
| from app.models.ESRGAN import RRDBNet_arch as arch |
|
|
# Absolute path of the ESRGAN x4 weights, resolved against this module's own
# directory so loading does not depend on the process's working directory.
_MODEL_PATH = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    'models',
    'RRDB_ESRGAN_x4.pth',
)

# Inference device: the default CUDA device when torch reports one, else CPU.
_device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Cached network instance; stays None until _get_model() builds it.
_model = None
|
|
|
|
def _get_model():
    """Return the shared ESRGAN network, building and loading it on first use.

    The network is constructed once and cached in the module-level ``_model``;
    subsequent calls return the cached instance.

    Returns:
        The ``arch.RRDBNet`` instance with the weights from ``_MODEL_PATH``
        loaded (strict key matching), switched to eval mode and moved to
        ``_device``.

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise (e.g. a missing
        weights file or mismatched state-dict keys). On failure nothing is
        cached, so a later call retries instead of returning an unloaded net.
    """
    global _model
    if _model is None:
        # Build into a local first: publishing the global before the weights
        # are loaded would leave a randomly-initialised network cached if
        # loading failed, and every later call would silently reuse it.
        # NOTE(review): args presumably (in_nc=3, out_nc=3, nf=64, nb=23) per
        # the ESRGAN RRDBNet signature — confirm against RRDBNet_arch.
        model = arch.RRDBNet(3, 3, 64, 23, gc=32)
        # NOTE(review): consider weights_only=True (torch>=1.13) if the weights
        # file is not fully trusted; torch.load unpickles by default on older torch.
        state_dict = torch.load(_MODEL_PATH, map_location=_device)
        model.load_state_dict(state_dict, strict=True)
        model.eval()
        _model = model.to(_device)
    return _model
|
|
|
|
def upscale_image(image_filepath):
    """Upscale an image file with ESRGAN and write the result to a temporary PNG.

    Args:
        image_filepath: Path of the input image. It is read with OpenCV as a
            3-channel 8-bit BGR image (``cv2.IMREAD_COLOR``).

    Returns:
        Path of a newly created ``.png`` file from ``tempfile.mkstemp``. The
        caller owns this file and is responsible for deleting it.

    Raises:
        ValueError: if OpenCV cannot read ``image_filepath`` (missing file,
            unsupported/corrupt format) or cannot write the output PNG.
    """
    model = _get_model()

    print('Up-scaling with device: {}...'.format(_device))

    image = cv2.imread(image_filepath, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise ValueError('Could not read image: {}'.format(image_filepath))

    # uint8 BGR (H, W, C) -> float RGB (1, C, H, W) scaled to [0, 1].
    image = image * 1.0 / 255
    image = torch.from_numpy(np.transpose(image[:, :, [2, 1, 0]], (2, 0, 1))).float()
    image_low_res = image.unsqueeze(0).to(_device)

    with torch.no_grad():
        # Drop only the batch axis, clamp to [0, 1], then back to BGR (H, W, C).
        image_high_res = model(image_low_res).squeeze(0).float().cpu().clamp_(0, 1).numpy()
        image_high_res = np.transpose(image_high_res[[2, 1, 0], :, :], (1, 2, 0))
        # Values are already in [0, 255] after the clamp, so the cast is lossless.
        image_high_res = (image_high_res * 255.0).round().astype(np.uint8)

    # mkstemp returns an open OS-level descriptor alongside the path; close it
    # so each call does not leak an fd (cv2.imwrite reopens the file by path).
    fd, output_filepath = tempfile.mkstemp(suffix='.png')
    os.close(fd)
    if not cv2.imwrite(output_filepath, image_high_res):
        # Don't hand back a path to an empty file.
        os.remove(output_filepath)
        raise ValueError('Could not write upscaled image to: {}'.format(output_filepath))
    print("image saved as:", output_filepath)

    return output_filepath
|
|
|
|
|
|
|
|