| import torch |
| import torch.nn.functional as F |
| import numpy as np |
| from PIL import Image |
|
|
def process_tile(model, tile, device):
    """Run `model` on a single image tile and return the result as uint8.

    Parameters
    ----------
    model : callable mapping a (1, C, H, W) float tensor in [0, 1] to a
        tensor of the same layout (possibly upscaled); run under no_grad.
    tile : np.ndarray of shape (H, W, C), uint8 values in [0, 255].
    device : torch device (string or object) to run inference on.

    Returns
    -------
    np.ndarray, uint8, shape (H', W', C) — model output mapped back to [0, 255].
    """
    # HWC uint8 -> 1CHW float in [0, 1].
    tensor_tile = torch.from_numpy(tile).permute(2, 0, 1).unsqueeze(0).float() / 255.0
    tensor_tile = tensor_tile.to(device)
    with torch.no_grad():
        out = model(tensor_tile)
    # 1CHW -> HWC on CPU; clamp to the valid range before rescaling.
    out = out.squeeze(0).permute(1, 2, 0).cpu().clamp(0, 1).numpy() * 255.0
    # Fix: round to nearest instead of truncating — plain astype(uint8) would
    # turn e.g. 99.9999 into 99, darkening the output by up to one level.
    return np.rint(out).astype(np.uint8)
|
|
def process_tiled(image, model, tile_size=128, overlap=32, scale=4, device="cpu"):
    """Super-resolve a PIL image by running `model` over overlapping tiles.

    The input is cut into `tile_size`-pixel tiles whose origins advance by
    `tile_size - overlap` pixels; each tile is upscaled independently via
    `process_tile` and the overlapping outputs are averaged back together.

    Parameters
    ----------
    image : PIL.Image — input image, grayscale or multi-channel.
    model : torch module assumed to upscale by exactly `scale`
        (a mismatched model makes the accumulation slices fail — TODO confirm
        callers guarantee this).
    tile_size : int — tile edge length in input pixels.
    overlap : int — overlap between neighbouring tiles; must be < tile_size.
    scale : int — the model's upscaling factor.
    device : torch device used for inference.

    Returns
    -------
    PIL.Image of size (W*scale, H*scale); mode "L" for single-channel input.

    Raises
    ------
    ValueError
        If `overlap >= tile_size`, which would make the tiling step
        non-positive (previously a confusing range() error or silent
        all-zero output).
    """
    if tile_size <= overlap:
        raise ValueError(
            f"overlap ({overlap}) must be smaller than tile_size ({tile_size})"
        )

    img_array = np.array(image)
    if len(img_array.shape) == 2:
        # Promote grayscale to HW1 so the rest of the code can assume 3 dims.
        img_array = np.expand_dims(img_array, axis=2)
    h, w, c = img_array.shape

    out_h, out_w = h * scale, w * scale

    # Accumulate tile outputs and per-pixel hit counts; overlapping regions
    # are blended by simple averaging.
    result = np.zeros((out_h, out_w, c), dtype=np.float32)
    weight_sum = np.zeros((out_h, out_w, c), dtype=np.float32)

    for y in range(0, h, tile_size - overlap):
        for x in range(0, w, tile_size - overlap):
            y_end = min(y + tile_size, h)
            x_end = min(x + tile_size, w)
            tile = img_array[y:y_end, x:x_end, :]

            tile_out = process_tile(model, tile, device)

            # Map the tile's input-space footprint to output coordinates.
            out_y = y * scale
            out_x = x * scale
            out_y_end = y_end * scale
            out_x_end = x_end * scale

            if len(tile_out.shape) == 2:
                # Some models squeeze single-channel outputs; restore HW1.
                tile_out = np.expand_dims(tile_out, axis=2)
            result[out_y:out_y_end, out_x:out_x_end, :] += tile_out
            weight_sum[out_y:out_y_end, out_x:out_x_end, :] += 1.0

    # Every output pixel is covered at least once, so the clip only guards
    # against division by zero if that invariant is ever broken.
    result = result / np.clip(weight_sum, 1e-5, None)
    result = np.clip(result, 0, 255).astype(np.uint8)
    if c == 1:
        return Image.fromarray(result[:, :, 0], mode="L")
    return Image.fromarray(result)
|
|
def default_x4_upscale(image):
    """Return `image` bicubically upscaled to four times its size.

    Used as a model-free fallback for backends that bypass tiled inference.
    """
    width, height = image.size
    target_size = (4 * width, 4 * height)
    return image.resize(target_size, Image.BICUBIC)
|
|
def run_inference(image, models_dict, x8_mode=False, device="cpu"):
    """Run every known super-resolution backend on `image`.

    Parameters
    ----------
    image : PIL.Image — RGB (or convertible) input image.
    models_dict : dict mapping backend name ("srcnn", "esrgan", ...) to a
        loaded torch model; missing backends yield None results.
    x8_mode : bool — if True, bicubically doubles each 4x result to 8x.
    device : torch device passed through to tiled inference.

    Returns
    -------
    dict with keys "srcnn", "satlas", "esrgan"; each value is a PIL.Image
    or None (backend missing or inference failed).
    """
    results = {}

    for model_name in ["srcnn", "satlas", "esrgan"]:
        if model_name == "satlas":
            print("Bypassing tiled inference for satlas backbone, using bicubic placeholder.")
            # Consistency fix: reuse the shared bicubic fallback instead of
            # duplicating the resize inline.
            sr_img = default_x4_upscale(image)
        elif model_name in models_dict:
            try:
                print(f"Running inference with {model_name}...")
                if model_name == "srcnn":
                    # SRCNN operates on luma only: the Y channel goes through
                    # the model, chroma is upscaled bicubically, then merged.
                    img_ycbcr = image.convert('YCbCr')
                    y, cb, cr = img_ycbcr.split()
                    w, h = image.size
                    cb = cb.resize((w*4, h*4), Image.BICUBIC)
                    cr = cr.resize((w*4, h*4), Image.BICUBIC)
                    y_out = process_tiled(y, models_dict[model_name], device=device)
                    sr_img = Image.merge('YCbCr', (y_out, cb, cr)).convert('RGB')
                else:
                    sr_img = process_tiled(image, models_dict[model_name], device=device)
            except Exception as e:
                # Best-effort: a failing backend reports None rather than
                # aborting the remaining models.
                print(f"Error inferencing {model_name}: {e}")
                sr_img = None
        else:
            sr_img = None

        if sr_img is not None:
            if x8_mode:
                # 4x model output followed by 2x bicubic = effective 8x.
                w, h = sr_img.size
                sr_img = sr_img.resize((w * 2, h * 2), Image.BICUBIC)
            results[model_name] = sr_img
        else:
            results[model_name] = None

    return results
|
|