| import math |
|
|
| import numpy as np |
| import torch |
|
|
|
|
def infer_params(state_dict):
    """Infer network hyper-parameters from a checkpoint's parameter names and shapes.

    Parsing rules applied to each key (split on ``"."``):

    * 5-part keys whose 3rd part is ``"sub"`` set ``nb`` to the integer 4th part
      (the last such key in iteration order wins).
    * 3-part keys: the middle part is parsed as an int ``i``.
      - If ``i > 6`` and the key is ``model.<i>.weight``, it counts toward ``scale``.
      - If ``i`` is larger than every index seen so far, ``out_nc`` becomes
        ``shape[0]`` of that entry.
    * Any key containing ``"conv1x1"`` sets ``plus`` to True.

    Args:
        state_dict: Mapping from parameter name to a value with a ``shape``
            attribute (e.g. a tensor).

    Returns:
        Tuple ``(in_nc, out_nc, nf, nb, plus, scale)``:
        ``in_nc`` and ``nf`` are ``shape[1]`` and ``shape[0]`` of ``model.0.weight``,
        and ``scale`` is ``2 ** k``, where ``k`` is the count described above.

    Raises:
        KeyError: if ``model.0.weight`` is missing.
        ValueError: if a 3-part key has a non-integer middle part, or if no key
            allows ``nb`` or ``out_nc`` to be determined.
    """
    scale2x = 0
    scalemin = 6  # layer indices <= this are not counted toward the scale factor
    n_uplayer = 0
    plus = False
    # Sentinels: both are only discovered inside the loop.
    nb = None
    out_nc = None

    for block in state_dict:
        parts = block.split(".")
        n_parts = len(parts)
        if n_parts == 5 and parts[2] == "sub":
            nb = int(parts[3])
        elif n_parts == 3:
            part_num = int(parts[1])
            if part_num > scalemin and parts[0] == "model" and parts[2] == "weight":
                scale2x += 1
            # The highest-indexed layer is taken as the output layer.
            if part_num > n_uplayer:
                n_uplayer = part_num
                out_nc = state_dict[block].shape[0]
        if not plus and "conv1x1" in block:
            plus = True

    nf = state_dict["model.0.weight"].shape[0]
    in_nc = state_dict["model.0.weight"].shape[1]

    if nb is None or out_nc is None:
        missing = [
            name for name, value in (("nb", nb), ("out_nc", out_nc)) if value is None
        ]
        raise ValueError(
            "Could not infer " + ", ".join(missing) + " from state_dict keys"
        )

    scale = 2**scale2x
    return in_nc, out_nc, nf, nb, plus, scale
|
|
|
|
def tile_process(model, img, tile_pad, tile_size, scale=4):
    """Crop the input image into tiles, process each tile, and merge the results.

    The image is split into ``tile_size`` x ``tile_size`` tiles; edge tiles may be
    smaller. Each tile is passed to ``model`` together with up to ``tile_pad``
    pixels of surrounding context, clipped at the image border. Only the region
    of the model output that corresponds to the un-padded tile is copied into the
    result, presumably to reduce seams between tiles.

    Modified from: https://github.com/ata4/esrgan-launcher

    Args:
        model: Callable applied to each 4-D tile. Its output is expected to be
            ``scale`` times larger in the last two dimensions and to have the
            same channel count as ``img``, because the result buffer is
            allocated with ``img``'s channel count.
        img: 4-D tensor unpacked as ``(batch, channel, height, width)``.
        tile_pad: Context pixels added on each side of a tile.
        tile_size: Tile side length in input pixels; must be positive.
        scale: Upscaling factor applied by ``model``.

    Returns:
        Tensor of shape ``(batch, channel, height * scale, width * scale)``,
        allocated with ``img.new_zeros`` (same dtype and device as ``img``).

    Raises:
        ValueError: if ``tile_size`` is not positive.
        RuntimeError: if ``model`` raises ``RuntimeError`` on a tile. The message
            names the tile, and the original error is chained as ``__cause__``.
    """
    if tile_size <= 0:
        raise ValueError(f"tile_size must be positive, got {tile_size}")

    batch, channel, height, width = img.shape
    output_shape = (batch, channel, height * scale, width * scale)

    # Start from an all-zero image; every pixel is overwritten by some tile.
    output = img.new_zeros(output_shape)
    tiles_x = math.ceil(width / tile_size)
    tiles_y = math.ceil(height / tile_size)
    total_tiles = tiles_x * tiles_y

    for y in range(tiles_y):
        for x in range(tiles_x):
            tile_idx = y * tiles_x + x + 1

            # Tile area on the input image, without padding.
            input_start_x = x * tile_size
            input_end_x = min(input_start_x + tile_size, width)
            input_start_y = y * tile_size
            input_end_y = min(input_start_y + tile_size, height)

            # The same area grown by tile_pad on each side, clipped to the image.
            input_start_x_pad = max(input_start_x - tile_pad, 0)
            input_end_x_pad = min(input_end_x + tile_pad, width)
            input_start_y_pad = max(input_start_y - tile_pad, 0)
            input_end_y_pad = min(input_end_y + tile_pad, height)

            input_tile_width = input_end_x - input_start_x
            input_tile_height = input_end_y - input_start_y
            input_tile = img[
                :,
                :,
                input_start_y_pad:input_end_y_pad,
                input_start_x_pad:input_end_x_pad,
            ]

            try:
                with torch.no_grad():
                    output_tile = model(input_tile)
            except RuntimeError as error:
                # Continuing would leave output_tile undefined (first tile) or
                # stale from the previous tile, corrupting the merged image.
                raise RuntimeError(
                    f"Processing failed on tile {tile_idx}/{total_tiles}"
                ) from error
            print(f"\tTile {tile_idx}/{total_tiles}")

            # Destination area on the output image.
            output_start_x = input_start_x * scale
            output_end_x = input_end_x * scale
            output_start_y = input_start_y * scale
            output_end_y = input_end_y * scale

            # Location of the un-padded tile inside the (padded) model output.
            output_start_x_tile = (input_start_x - input_start_x_pad) * scale
            output_end_x_tile = output_start_x_tile + input_tile_width * scale
            output_start_y_tile = (input_start_y - input_start_y_pad) * scale
            output_end_y_tile = output_start_y_tile + input_tile_height * scale

            output[
                :, :, output_start_y:output_end_y, output_start_x:output_end_x
            ] = output_tile[
                :,
                :,
                output_start_y_tile:output_end_y_tile,
                output_start_x_tile:output_end_x_tile,
            ]

    return output
|
|
|
|
def upscale(model, img, tile_pad, tile_size, scale=4, device=None):
    """Upscale an image array with ``model`` using tiled inference.

    Args:
        model: Callable forwarded to ``tile_process``.
        img: Anything ``np.array`` accepts that can be indexed as
            ``[height, width, channel]``, i.e. it must be 3-D. Values are divided
            by 255, so an 8-bit value range is assumed.
        tile_pad: Context pixels per tile side, forwarded to ``tile_process``.
        tile_size: Tile side length, forwarded to ``tile_process``.
        scale: Upscaling factor applied by ``model``. Defaults to 4, matching
            the previously hard-coded value.
        device: Torch device used for inference. ``None`` selects ``"cuda"``
            when available and ``"cpu"`` otherwise.

    Returns:
        ``uint8`` array of shape ``(height * scale, width * scale, channel)``
        in the same channel order as ``img``.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    arr = np.array(img)
    # Reverse the channel axis before inference (presumably RGB -> BGR for the
    # model; confirm against how the model was trained).
    arr = arr[:, :, ::-1]
    chw = np.ascontiguousarray(np.transpose(arr, (2, 0, 1))) / 255
    tensor = torch.from_numpy(chw).float().unsqueeze(0).to(device)

    output = tile_process(model, tensor, tile_pad, tile_size, scale=scale)

    # Drop only the batch axis. squeeze() would also remove a channel, height,
    # or width axis of size 1 and break the moveaxis below.
    result = output[0].float().cpu().clamp_(0, 1).numpy()
    # Round instead of truncating, so that e.g. 0.999 maps to 255, not 254.
    result = np.round(255.0 * np.moveaxis(result, 0, 2)).astype(np.uint8)
    # Undo the channel reversal applied to the input.
    return result[:, :, ::-1]
|
|