import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image


def process_tile(model, tile, device):
    """Run the super-resolution model on a single HWC uint8 tile.

    Args:
        model: Callable torch module mapping a (1, C, H, W) float tensor in
            [0, 1] to an upscaled (1, C, H*scale, W*scale) tensor.
        tile: numpy uint8 array of shape (H, W, C).
        device: torch device (string or object) to run inference on.

    Returns:
        numpy uint8 array of shape (H*scale, W*scale, C) with values in [0, 255].
    """
    # HWC uint8 -> 1CHW float in [0, 1]
    tensor_tile = torch.from_numpy(tile).permute(2, 0, 1).unsqueeze(0).float() / 255.0
    tensor_tile = tensor_tile.to(device)
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        out = model(tensor_tile)
    # 1CHW -> HWC; clamp to the valid range before quantizing back to uint8
    out = out.squeeze(0).permute(1, 2, 0).cpu().clamp(0, 1).numpy() * 255.0
    return out.astype(np.uint8)


def process_tiled(image, model, tile_size=128, overlap=32, scale=4, device="cpu"):
    """Upscale a PIL image by `scale` using overlapping tiled model passes.

    The image is cut into `tile_size` tiles stepping by `tile_size - overlap`;
    each tile is super-resolved independently and overlapping output regions
    are averaged (uniform weighting — a Hann/Bartlett window would further
    reduce visible seams).

    Args:
        image: input PIL image (RGB or single-channel).
        model: model passed through to `process_tile`; assumed to upscale
            by `scale` — TODO confirm against the model registry.
        tile_size: side length of each square tile, in input pixels.
        overlap: overlap between adjacent tiles, in input pixels.
        scale: upscale factor of the model output.
        device: torch device for inference.

    Returns:
        Upscaled PIL image; mode "L" if the input had one channel.

    Raises:
        ValueError: if overlap >= tile_size — the tiling step would be <= 0,
            which previously yielded a cryptic range() error or, worse, a
            silently all-black output.
    """
    if overlap >= tile_size:
        raise ValueError(
            f"overlap ({overlap}) must be smaller than tile_size ({tile_size})"
        )
    step = tile_size - overlap  # hoisted loop-invariant stride

    img_array = np.array(image)
    if img_array.ndim == 2:  # grayscale -> add channel axis so HWC indexing works
        img_array = np.expand_dims(img_array, axis=2)
    h, w, c = img_array.shape
    out_h, out_w = h * scale, w * scale

    # Accumulate tile outputs and per-pixel hit counts, then average.
    result = np.zeros((out_h, out_w, c), dtype=np.float32)
    weight_sum = np.zeros((out_h, out_w, c), dtype=np.float32)

    for y in range(0, h, step):
        for x in range(0, w, step):
            # Extract tile (edge tiles may be smaller than tile_size).
            y_end = min(y + tile_size, h)
            x_end = min(x + tile_size, w)
            tile = img_array[y:y_end, x:x_end, :]

            # Predict
            tile_out = process_tile(model, tile, device)
            # Some models drop the channel axis for single-channel tiles.
            if tile_out.ndim == 2:
                tile_out = np.expand_dims(tile_out, axis=2)

            # Place in the output accumulator at scaled coordinates.
            out_y, out_x = y * scale, x * scale
            out_y_end, out_x_end = y_end * scale, x_end * scale
            result[out_y:out_y_end, out_x:out_x_end, :] += tile_out
            weight_sum[out_y:out_y_end, out_x:out_x_end, :] += 1.0

    # Avoid division by zero on any pixel no tile touched.
    result = result / np.clip(weight_sum, 1e-5, None)
    result = np.clip(result, 0, 255).astype(np.uint8)

    if c == 1:
        return Image.fromarray(result[:, :, 0], mode="L")
    return Image.fromarray(result)


def default_x4_upscale(image):
    """Bicubic 4x fallback used when a model is missing or crashes."""
    w, h = image.size
    return image.resize((w*4, h*4), Image.BICUBIC)


def run_inference(image, models_dict, x8_mode=False, device="cpu"):
    """Run every known model on `image` and collect the results.

    Args:
        image: input PIL image.
        models_dict: mapping of model name -> torch module. Models absent
            from the dict (and models that raise) yield None in the result.
        x8_mode: if True, bicubic-upscale each 4x result by a further 2x.
        device: torch device passed through to tiled inference.

    Returns:
        dict mapping each of "srcnn"/"satlas"/"esrgan" to an upscaled PIL
        image, or None when the model was unavailable or failed.
    """
    results = {}
    for model_name in ["srcnn", "satlas", "esrgan"]:
        if model_name == "satlas":
            # Deliberate bypass: satlas gets a bicubic placeholder regardless
            # of models_dict contents.
            print("Bypassing tiled inference for satlas backbone, using bicubic placeholder.")
            w, h = image.size
            sr_img = image.resize((w*4, h*4), Image.BICUBIC)
        elif model_name in models_dict:
            try:
                print(f"Running inference with {model_name}...")
                if model_name == "srcnn":
                    # SRCNN operates on luma only: super-resolve Y, bicubic
                    # the chroma channels, then recombine.
                    img_ycbcr = image.convert('YCbCr')
                    y, cb, cr = img_ycbcr.split()
                    w, h = image.size
                    cb = cb.resize((w*4, h*4), Image.BICUBIC)
                    cr = cr.resize((w*4, h*4), Image.BICUBIC)
                    y_out = process_tiled(y, models_dict[model_name], device=device)
                    sr_img = Image.merge('YCbCr', (y_out, cb, cr)).convert('RGB')
                else:
                    sr_img = process_tiled(image, models_dict[model_name], device=device)
            except Exception as e:
                # Best-effort by design: one failing model must not abort the
                # others; the failure is logged and reported as None.
                print(f"Error inferencing {model_name}: {e}")
                sr_img = None
        else:
            sr_img = None

        if sr_img is not None:
            if x8_mode:
                w, h = sr_img.size
                sr_img = sr_img.resize((w * 2, h * 2), Image.BICUBIC)
            results[model_name] = sr_img
        else:
            results[model_name] = None
    return results