|
|
import argparse |
|
|
import cv2 |
|
|
import glob |
|
|
import os |
|
|
import math |
|
|
import numpy as np |
|
|
import axengine as axe |
|
|
|
|
|
def pre_process(img, tile_size=108, tile_pad=10):
    """Zero-pad an image for tiled inference and convert it to a 4-D batch.

    Two kinds of padding are applied, both with constant zeros:

    * "mod" padding on the bottom/right edges so that height and width
      become whole multiples of ``tile_size``;
    * a border of ``tile_pad`` pixels on every side.

    Args:
        img: 3-D array indexed as (height, width, channel).
        tile_size: Edge length of one tile, in pixels.
        tile_pad: Width of the zero border added around the whole image.

    Returns:
        The padded image rearranged to (1, channel, height, width).
    """
    rows, cols = img.shape[0:2]

    # Amount needed to reach the next multiple of tile_size (0 if already one).
    row_rem = rows % tile_size
    col_rem = cols % tile_size
    extra_rows = tile_size - row_rem if row_rem else 0
    extra_cols = tile_size - col_rem if col_rem else 0

    # Both paddings are constant zeros, so they collapse into a single call.
    padded = np.pad(
        img,
        (
            (tile_pad, extra_rows + tile_pad),
            (tile_pad, extra_cols + tile_pad),
            (0, 0),
        ),
        'constant',
    )

    # (H, W, C) -> (C, H, W), then prepend the batch axis.
    return np.transpose(padded, (2, 0, 1))[np.newaxis, ...]
|
|
|
|
|
def tile_process(img, origin_shape, model, scale=2, tile_size=108, tile_pad=10, imgname=None):
    """Run the model tile by tile and stitch the results into one image.

    The input is cut into overlapping windows of
    ``tile_size + 2 * tile_pad`` pixels per side. Each window is fed to the
    model, the part of the result that corresponds to the ``tile_pad``
    border is discarded, and the remaining centre is written into the
    output canvas.

    Args:
        img: 4-D array (batch, channel, height, width). Height and width are
            expected to be ``k * tile_size + 2 * tile_pad`` (the layout
            produced by ``pre_process`` with the same tile settings).
        origin_shape: Shape of the image before padding; only the first two
            entries (height, width) are used, to crop the result.
        model: Path to the model file, or an already constructed session
            object exposing ``get_inputs()``, ``get_outputs()`` and
            ``run()`` (lets callers load the model once and reuse it).
        scale: Upsampling factor of the model.
        tile_size: Edge length of the useful centre of each tile.
        tile_pad: Border around each tile that is dropped after inference.
        imgname: Name used in progress and error messages only.

    Returns:
        float32 array of shape (origin_h * scale, origin_w * scale, channel)
        with the channel order reversed (index 2, 1, 0 of the model output).

    Raises:
        ValueError: If ``model`` is None or a path that does not exist.
        RuntimeError: If inference fails on any tile.
    """
    if model is None:
        raise ValueError(f'Model {model} does not exist.')
    if isinstance(model, (str, bytes, os.PathLike)):
        if not os.path.exists(model):
            raise ValueError(f'Model {model} does not exist.')
        session = axe.InferenceSession(model)
    else:
        # Caller supplied a ready-to-use session.
        session = model

    input_name = session.get_inputs()[0].name
    output_names = [x.name for x in session.get_outputs()]

    batch, channel, height, width = img.shape
    output_height = int(round(height * scale))
    output_width = int(round(width * scale))
    origin_h, origin_w = origin_shape[0:2]

    # Allocate directly as float32: the result is returned as float32 anyway,
    # so a float64 canvas would only double the memory footprint.
    output = np.zeros((batch, channel, output_height, output_width), dtype=np.float32)

    # Height/width include the 2 * tile_pad border, which must not be counted
    # as tile area (otherwise a large tile_pad yields phantom extra tiles).
    tiles_x = (width - 2 * tile_pad) // tile_size
    tiles_y = (height - 2 * tile_pad) // tile_size
    print(f'Tile {tiles_x} x {tiles_y} for image {imgname}')

    # Position of the useful centre inside every upscaled tile (loop-invariant).
    start_tile = int(round(tile_pad * scale))
    end_tile = int(round(tile_size * scale)) + start_tile

    for y in range(tiles_y):
        for x in range(tiles_x):
            input_start_x = x * tile_size
            input_end_x = input_start_x + tile_size
            input_start_y = y * tile_size
            input_end_y = input_start_y + tile_size

            # The window extends 2 * tile_pad past the tile because the image
            # itself is already shifted by tile_pad through the outer border.
            input_tile = img[:, :, input_start_y:(input_end_y + 2 * tile_pad),
                             input_start_x:(input_end_x + 2 * tile_pad)]

            try:
                output_tile = session.run(output_names, {input_name: input_tile})
            except RuntimeError as error:
                # Continuing would paste an unbound or stale tile into the
                # result, so surface the failure to the caller instead.
                raise RuntimeError(
                    f'Inference failed on tile ({x}, {y}) of image {imgname}: {error}'
                ) from error

            output_start_x = int(round(input_start_x * scale))
            output_end_x = int(round(input_end_x * scale))
            output_start_y = int(round(input_start_y * scale))
            output_end_y = int(round(input_end_y * scale))

            output[:, :, output_start_y:output_end_y,
                   output_start_x:output_end_x] = output_tile[0][:, :, start_tile:end_tile, start_tile:end_tile]

    # Drop the mod-padding area and the batch axis.
    output = output[:, :, :int(round(origin_h * scale)), :int(round(origin_w * scale))].squeeze(0)
    # Reverse the channel order and go back to (height, width, channel).
    output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)).astype(np.float32)

    return output
|
|
|
|
|
def main():
    """Inference demo for Real-ESRGAN.

    Parses the command line, then upsamples every supported image found at
    ``--input`` (a single file or a folder) and writes the result into the
    ``--output`` folder. Images that cannot be read or whose inference fails
    are reported and skipped.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--input', type=str, default='inputs', help='Input image or folder')
    parser.add_argument('-o', '--output', type=str, default='results', help='Output folder')
    parser.add_argument('-s', '--outscale', type=float, default=2, help='The final upsampling scale of the image, [Option:2, 4]')
    parser.add_argument(
        '--model_path', type=str, default=None, help='Model path. you need to specify it [Options: ]')
    parser.add_argument('--suffix', type=str, default='out', help='Suffix of the restored image')
    parser.add_argument('-t', '--tile', type=int, default=108, help='Tile size, 0 for no tile during testing')
    parser.add_argument('--tile_pad', type=int, default=10, help='Tile padding, (tile + 2 * tile_pad must == 128.)')
    parser.add_argument(
        '--ext',
        type=str,
        default='auto',
        help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs')

    args = parser.parse_args()

    # Validate with parser.error instead of assert: asserts vanish under -O.
    if args.tile + 2 * args.tile_pad != 128:
        parser.error('the model input size: 128.')
    if args.model_path is None:
        parser.error('--model_path is required.')

    if os.path.isfile(args.input):
        paths = [args.input]
    else:
        paths = sorted(glob.glob(os.path.join(args.input, '*')))

    os.makedirs(args.output, exist_ok=True)

    supported_extensions = {'.jpg', '.jpeg', '.png', '.tif', '.tiff', '.bmp', '.webp'}

    for idx, path in enumerate(paths):
        imgname, extension = os.path.splitext(os.path.basename(path))
        print('Testing', idx, imgname)
        # Compare case-insensitively so files such as "photo.JPG" are accepted.
        if extension.lower() not in supported_extensions:
            continue

        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if img is None:
            print('Error loading image')
            continue
        img = img.astype(np.float32)
        if np.max(img) > 256:
            max_range = 65535
            print('\tInput is a 16-bit image')
        else:
            max_range = 255
        # Normalise pixel values to [0, 1].
        img = img / max_range
        if len(img.shape) == 2:
            img_mode = 'L'
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        elif img.shape[2] == 4:
            img_mode = 'RGBA'
            # The alpha plane is kept aside and upscaled by plain resizing.
            alpha = img[:, :, 3]
            img = img[:, :, 0:3]
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        else:
            img_mode = 'RGB'
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        origin_shape = img.shape
        # Pad with the same tile settings that tile_process uses below.
        img = pre_process(img, args.tile, args.tile_pad)

        try:
            output_img = tile_process(img, origin_shape, args.model_path, args.outscale, args.tile, args.tile_pad, imgname)
        except RuntimeError as error:
            print('Error', error)
            print('If you encounter out of memory, try to set --tile with a smaller number.')
            # No valid result exists for this image; do not save anything.
            continue

        if img_mode == 'L':
            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
        if img_mode == 'RGBA':
            h, w = alpha.shape[0:2]
            output_alpha = cv2.resize(
                alpha,
                (int(round(w * args.outscale)),
                 int(round(h * args.outscale))),
                interpolation=cv2.INTER_LINEAR
            )
            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
            output_img[:, :, 3] = output_alpha

        # Convert back from [0, 1] to the bit depth of the input.
        if max_range == 65535:
            output = np.clip((output_img * 65535.0), 0, 65535).astype(np.uint16)
        else:
            output = np.clip((output_img * 255.0), 0, 255).astype(np.uint8)

        if args.ext == 'auto':
            extension = extension[1:]
        else:
            extension = args.ext

        if args.suffix == '':
            save_path = os.path.join(args.output, f'{imgname}.{extension}')
        else:
            save_path = os.path.join(args.output, f'{imgname}_{args.suffix}.{extension}')
        if not cv2.imwrite(save_path, output):
            print(f'Error saving image to {save_path}')
|
|
|
|
|
if __name__ == '__main__':
    # Run the command-line demo only when executed as a script, not on import.
    main()
|
|
|