# SGAN / inference.py
# ParamAhuja
# ui
# aa37e25
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
def process_tile(model, tile, device):
    """Run ``model`` on a single image tile and return the result as uint8.

    Args:
        model: callable taking a (1, C, H, W) float tensor scaled to [0, 1]
            and returning a tensor that becomes (H', W', C') after
            ``squeeze(0).permute(1, 2, 0)`` (i.e. presumably (1, C', H', W')).
        tile: numpy array of shape (H, W, C); assumed to hold 0..255 values
            (it is divided by 255 before being fed to the model).
        device: torch device (or device string) the model runs on.

    Returns:
        numpy uint8 array of shape (H', W', C') with model output clamped to
        [0, 1] and mapped back to 0..255.
    """
    tensor_tile = torch.from_numpy(tile).permute(2, 0, 1).unsqueeze(0).float() / 255.0
    tensor_tile = tensor_tile.to(device)
    with torch.no_grad():
        out = model(tensor_tile)
    out = out.squeeze(0).permute(1, 2, 0).cpu().clamp(0, 1).numpy() * 255.0
    # Round instead of truncating: in float32, (v / 255) * 255 can come out a
    # hair below v, and a bare astype() would floor that to v - 1.
    return np.rint(out).astype(np.uint8)
def _tile_starts(length, tile_size, step):
    """Return tile start offsets that cover ``range(length)`` along one axis.

    Tiles start every ``step`` pixels. The final tile is shifted back so it
    ends exactly at ``length``, which keeps every tile ``tile_size`` long
    whenever ``length >= tile_size``. The alternative is a thin sliver tile
    at the edge, or a tile fully contained in its predecessor. An axis
    shorter than ``tile_size`` gets a single tile at 0, and an empty axis
    gets no tiles.
    """
    if length <= 0:
        return []
    last = max(length - tile_size, 0)
    starts = list(range(0, last + 1, step))
    if starts[-1] != last:
        starts.append(last)
    return starts


def process_tiled(image, model, tile_size=128, overlap=32, scale=4, device="cpu"):
    """Super-resolve ``image`` tile by tile and stitch the outputs together.

    The image is cut into square input tiles of side ``tile_size`` (smaller
    only when the image itself is smaller). Neighbouring tiles overlap by at
    least ``overlap`` pixels. Each tile goes through ``process_tile`` and is
    written into the output at ``scale`` times its input position.
    Overlapping output pixels are averaged with uniform weights.

    Args:
        image: PIL image (anything ``np.array`` turns into an HxW or HxWxC
            array of 0..255 values).
        model: model forwarded to ``process_tile``. For an ``h x w`` tile its
            output must be ``h*scale x w*scale``.
        tile_size: input tile edge length in pixels; must be > 0.
        overlap: minimum overlap between neighbouring tiles in pixels; must
            satisfy ``0 <= overlap < tile_size``.
        scale: upscaling factor the model produces.
        device: device forwarded to ``process_tile``.

    Returns:
        PIL image of size ``(w*scale, h*scale)``. Single-channel input yields
        a mode "L" image.

    Raises:
        ValueError: if ``tile_size``/``overlap`` are invalid, or a tile's
            output does not have the expected ``scale``-times spatial size.
    """
    if tile_size <= 0 or not 0 <= overlap < tile_size:
        raise ValueError(
            f"need tile_size > 0 and 0 <= overlap < tile_size, "
            f"got tile_size={tile_size}, overlap={overlap}"
        )
    step = tile_size - overlap

    img_array = np.array(image)
    if img_array.ndim == 2:
        img_array = np.expand_dims(img_array, axis=2)
    h, w, c = img_array.shape
    out_h, out_w = h * scale, w * scale

    result = np.zeros((out_h, out_w, c), dtype=np.float32)
    # One weight per output pixel; it broadcasts over the channel axis.
    weight_sum = np.zeros((out_h, out_w, 1), dtype=np.float32)

    x_starts = _tile_starts(w, tile_size, step)
    for y in _tile_starts(h, tile_size, step):
        y_end = min(y + tile_size, h)
        for x in x_starts:
            x_end = min(x + tile_size, w)
            tile_out = process_tile(model, img_array[y:y_end, x:x_end, :], device)
            if tile_out.ndim == 2:
                tile_out = np.expand_dims(tile_out, axis=2)

            expected = ((y_end - y) * scale, (x_end - x) * scale)
            if tile_out.shape[:2] != expected:
                # Without this check a wrong-sized output either fails with an
                # opaque broadcast error or (e.g. a 1x1 output) is silently
                # broadcast across the whole output slot.
                raise ValueError(
                    f"model output {tile_out.shape[:2]} for tile at ({y}, {x}) "
                    f"does not match expected {expected} (scale={scale})"
                )

            out_rows = slice(y * scale, y_end * scale)
            out_cols = slice(x * scale, x_end * scale)
            result[out_rows, out_cols, :] += tile_out
            weight_sum[out_rows, out_cols, :] += 1.0

    # Uniform average of overlapping tiles. The floor only guards against
    # division by zero. A Hann/Bartlett window would blend seams more smoothly.
    result = result / np.clip(weight_sum, 1e-5, None)
    # Round (not truncate) so averaged values are not biased downward.
    result = np.clip(np.rint(result), 0, 255).astype(np.uint8)
    if c == 1:
        # Pillow infers mode "L" for a 2-D uint8 array.
        return Image.fromarray(result[:, :, 0])
    return Image.fromarray(result)
def default_x4_upscale(image):
    """Return ``image`` enlarged 4x in each dimension using bicubic resampling.

    Intended as a dummy fallback for when a learned model is missing or fails.
    """
    width, height = image.size
    target_size = (width * 4, height * 4)
    return image.resize(target_size, Image.BICUBIC)
def run_inference(image, models_dict, x8_mode=False, device="cpu"):
    """Upscale ``image`` with each known backend and collect the results.

    Backends are visited in the fixed order ``srcnn``, ``satlas``,
    ``esrgan``.

    - ``satlas`` always uses a 4x bicubic placeholder.
    - ``srcnn`` super-resolves only the luma (Y) channel through
      ``process_tiled`` and bicubic-upscales the chroma channels.
    - ``esrgan`` runs the full image through ``process_tiled``.

    A backend missing from ``models_dict``, or one that raises during
    inference, maps to ``None``; the error is printed, not re-raised.

    Args:
        image: PIL image to upscale.
        models_dict: mapping of backend name to model object.
        x8_mode: if true, every successful result is bicubic-resized a
            further 2x.
        device: device forwarded to ``process_tiled``.

    Returns:
        dict mapping each backend name to a PIL image or ``None``.
    """
    outputs = {}
    for name in ("srcnn", "satlas", "esrgan"):
        upscaled = None
        if name == "satlas":
            print("Bypassing tiled inference for satlas backbone, using bicubic placeholder.")
            upscaled = default_x4_upscale(image)
        elif name in models_dict:
            try:
                print(f"Running inference with {name}...")
                if name != "srcnn":
                    upscaled = process_tiled(image, models_dict[name], device=device)
                else:
                    luma, blue_diff, red_diff = image.convert('YCbCr').split()
                    width, height = image.size
                    chroma_size = (width * 4, height * 4)
                    blue_diff = blue_diff.resize(chroma_size, Image.BICUBIC)
                    red_diff = red_diff.resize(chroma_size, Image.BICUBIC)
                    luma_sr = process_tiled(luma, models_dict[name], device=device)
                    upscaled = Image.merge('YCbCr', (luma_sr, blue_diff, red_diff)).convert('RGB')
            except Exception as err:
                print(f"Error inferencing {name}: {err}")
                upscaled = None

        if upscaled is not None and x8_mode:
            # Extra 2x bicubic step on top of the 4x result.
            width, height = upscaled.size
            upscaled = upscaled.resize((width * 2, height * 2), Image.BICUBIC)
        outputs[name] = upscaled
    return outputs