# Real-ESRGAN.axera / run_axmodel.py
# jounery-d's picture
# Update run_axmodel.py
# fdaf566 verified
import argparse
import cv2
import glob
import os
import math
import numpy as np
import axengine as axe
def pre_process(img, tile_size=108, tile_pad=10):
    """Pad an HWC image so it splits into whole tiles, then convert it to NCHW.

    The bottom and right edges are zero-padded up to the next multiple of
    ``tile_size``. Every side then receives a further ``tile_pad`` zero
    border. Finally the channel axis is moved to the front and a batch axis
    of size 1 is prepended.

    Args:
        img: image array indexed as (height, width, channels).
        tile_size: the unpadded area is rounded up to a multiple of this.
        tile_pad: width of the zero border added on every side.

    Returns:
        Array of shape (1, channels, padded_height, padded_width).
    """
    rows, cols = img.shape[:2]
    # Distance to the next multiple of tile_size (0 when already aligned).
    grow_rows = -rows % tile_size
    grow_cols = -cols % tile_size
    # Both paddings fill with zeros, so the alignment pad (bottom/right only)
    # and the border pad (all sides) can be applied in one call.
    padded = np.pad(
        img,
        (
            (tile_pad, tile_pad + grow_rows),
            (tile_pad, tile_pad + grow_cols),
            (0, 0),
        ),
        mode='constant',
    )
    # HWC -> CHW, then add the leading batch dimension.
    return padded.transpose(2, 0, 1)[np.newaxis, ...]
def tile_process(img, origin_shape, model, scale=2, tile_size=108, tile_pad=10, imgname=None):
    """Upscale a padded image tile by tile with an AXera model and merge the tiles.

    ``img`` is expected in the layout produced by ``pre_process``: a 4-D
    (batch, channel, height, width) array whose inner area (without the
    ``tile_pad`` border on each side) is a multiple of ``tile_size``. Each
    tile is fed to the model together with its ``tile_pad`` context border.
    Only the upscaled core of each model output is copied into the result.

    Args:
        img: padded input of shape (batch, channel, height, width).
        origin_shape: shape of the image before padding; only the first two
            entries (height, width) are used to crop the result.
        model: filesystem path of the model passed to ``axe.InferenceSession``.
        scale: upscaling factor applied by the model. Assumed to match the
            model's real factor (TODO confirm); a mismatch mis-slices tiles.
        tile_size: edge length of a tile's core region.
        tile_pad: context border around each tile in the input.
        imgname: name used only in the progress message.

    Returns:
        float32 array of shape (origin_h * scale, origin_w * scale, 3) with
        the first three channels in reversed order (channels 2, 1, 0).

    Raises:
        ValueError: if ``model`` does not exist.
        RuntimeError: if inference fails on any tile.
    """
    if not os.path.exists(model):
        raise ValueError(f'Model {model} does not exist.')
    session = axe.InferenceSession(model)
    input_name = session.get_inputs()[0].name
    output_names = [x.name for x in session.get_outputs()]

    batch, channel, height, width = img.shape
    output_height = int(round(height * scale))
    output_width = int(round(width * scale))
    origin_h, origin_w = origin_shape[0:2]
    # Start with a black canvas; tiles are written into it one by one.
    output = np.zeros((batch, channel, output_height, output_width))

    # Count tiles over the inner region only. Counting over the full padded
    # size adds a spurious partial tile whenever 2 * tile_pad >= tile_size.
    tiles_x = (width - 2 * tile_pad) // tile_size
    tiles_y = (height - 2 * tile_pad) // tile_size
    print(f'Tile {tiles_x} x {tiles_y} for image {imgname}')

    # Window inside each model output holding the upscaled tile core (the
    # upscaled context border is discarded). Identical for every tile.
    start_tile = int(round(tile_pad * scale))
    end_tile = int(round(tile_size * scale)) + start_tile

    for y in range(tiles_y):
        for x in range(tiles_x):
            input_start_x = x * tile_size
            input_end_x = input_start_x + tile_size
            input_start_y = y * tile_size
            input_end_y = input_start_y + tile_size
            # In padded coordinates the core sits at [start + pad, end + pad),
            # so [start, end + 2 * pad) is the core plus its context border.
            input_tile = img[:, :, input_start_y:input_end_y + 2 * tile_pad,
                             input_start_x:input_end_x + 2 * tile_pad]

            try:
                output_tile = session.run(output_names, {input_name: input_tile})
            except RuntimeError as error:
                # Re-raise so no undefined or stale tile is merged.
                raise RuntimeError(
                    f'Inference failed on tile ({x}, {y}) of image {imgname}: {error}'
                ) from error

            # Core placement on the output canvas (shifted by the border,
            # which is removed by the final crop below).
            output_start_x = int(round(input_start_x * scale))
            output_end_x = int(round(input_end_x * scale))
            output_start_y = int(round(input_start_y * scale))
            output_end_y = int(round(input_end_y * scale))
            output[:, :, output_start_y:output_end_y,
                   output_start_x:output_end_x] = output_tile[0][:, :, start_tile:end_tile, start_tile:end_tile]

    # Drop the mod-pad / border area and the batch axis.
    output = output[:, :, :int(round(origin_h * scale)), :int(round(origin_w * scale))].squeeze(0)
    # Reverse the first three channels and convert CHW -> HWC.
    output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)).astype(np.float32)
    return output
def main():
    """Inference demo for Real-ESRGAN on an AXera model.

    Reads one image (when ``--input`` is a file) or every supported image in
    a folder. Each image is upscaled tile by tile with the model at
    ``--model_path``, and the result is written into ``--output``. Gray, RGB
    and RGBA inputs at 8-bit or 16-bit depth are handled. For RGBA input the
    alpha channel skips the model and is resized with bilinear interpolation.
    """
    # Fixed spatial input size of the model: tile + 2 * tile_pad must equal it.
    model_input_size = 128

    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--input', type=str, default='inputs', help='Input image or folder')
    parser.add_argument('-o', '--output', type=str, default='results', help='Output folder')
    parser.add_argument('-s', '--outscale', type=float, default=2, help='The final upsampling scale of the image, [Option:2, 4]')
    parser.add_argument(
        '--model_path', type=str, default=None, help='Model path. you need to specify it [Options: ]')
    parser.add_argument('--suffix', type=str, default='out', help='Suffix of the restored image')
    parser.add_argument('-t', '--tile', type=int, default=108, help='Tile size (tile + 2 * tile_pad must == 128)')
    parser.add_argument('--tile_pad', type=int, default=10, help='Tile padding (tile + 2 * tile_pad must == 128)')
    parser.add_argument(
        '--ext',
        type=str,
        default='auto',
        help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs')
    args = parser.parse_args()

    # Validate with parser.error instead of assert: asserts are stripped under -O.
    # A missing --model_path would otherwise crash inside os.path.exists(None).
    if args.model_path is None:
        parser.error('--model_path is required.')
    if args.tile <= 0 or args.tile_pad < 0:
        parser.error('--tile must be positive and --tile_pad must be non-negative.')
    if args.tile + 2 * args.tile_pad != model_input_size:
        parser.error(
            f'the model input size: {model_input_size} '
            f'(got --tile {args.tile} + 2 * --tile_pad {args.tile_pad}).')

    # Collect inputs: a single file, or every entry of a folder (non-images
    # are filtered out by extension below).
    if os.path.isfile(args.input):
        paths = [args.input]
    else:
        paths = sorted(glob.glob(os.path.join(args.input, '*')))
    if not paths:
        print(f'No input found at {args.input}')
    os.makedirs(args.output, exist_ok=True)

    for idx, path in enumerate(paths):
        imgname, extension = os.path.splitext(os.path.basename(path))
        print('Testing', idx, imgname)
        # Case-insensitive so files such as "photo.JPG" are not skipped.
        if extension.lower() not in ('.jpg', '.jpeg', '.png', '.tif', '.tiff', '.bmp', '.webp'):
            continue
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if img is None:
            print('Error loading image')
            continue

        # Normalize to [0, 1]. A maximum above 256 is taken to mean 16-bit data.
        img = img.astype(np.float32)
        if np.max(img) > 256:  # 16-bit image
            max_range = 65535
            print('\tInput is a 16-bit image')
        else:
            max_range = 255
        img = img / max_range

        # The model gets 3-channel RGB. Alpha is split off and handled separately.
        if len(img.shape) == 2:  # gray image
            img_mode = 'L'
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        elif img.shape[2] == 4:  # RGBA image with alpha channel
            img_mode = 'RGBA'
            alpha = img[:, :, 3]
            img = img[:, :, 0:3]
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        else:
            img_mode = 'RGB'
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        origin_shape = img.shape
        # Pass tile_pad explicitly so the border matches the offsets that
        # tile_process slices with.
        img = pre_process(img, args.tile, args.tile_pad)
        try:
            output_img = tile_process(img, origin_shape, args.model_path, args.outscale,
                                      args.tile, args.tile_pad, imgname)
        except RuntimeError as error:
            print('Error', error)
            print('If you encounter out of memory, try to set --tile with a smaller number.')
            # Skip this image rather than saving an undefined or stale result.
            continue

        # tile_process returns BGR-ordered HWC float data.
        if img_mode == 'L':
            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
        if img_mode == 'RGBA':
            h, w = alpha.shape[0:2]
            output_alpha = cv2.resize(
                alpha,
                (int(round(w * args.outscale)),
                 int(round(h * args.outscale))),
                interpolation=cv2.INTER_LINEAR
            )
            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
            output_img[:, :, 3] = output_alpha

        # Back to the input's integer bit depth.
        if max_range == 65535:  # 16-bit image
            output = np.clip((output_img * 65535.0), 0, 65535).astype(np.uint16)
        else:
            output = np.clip((output_img * 255.0), 0, 255).astype(np.uint8)

        if args.ext == 'auto':
            extension = extension[1:]
        else:
            extension = args.ext
        if args.suffix == '':
            save_path = os.path.join(args.output, f'{imgname}.{extension}')
        else:
            save_path = os.path.join(args.output, f'{imgname}_{args.suffix}.{extension}')
        # cv2.imwrite reports failure through its return value, not an exception.
        if not cv2.imwrite(save_path, output):
            print(f'Failed to save {save_path}')
if __name__ == '__main__':
    # Run the CLI when executed as a script. main() returns None, so the
    # process exits with status 0 on success.
    raise SystemExit(main())