|
|
import os |
|
|
import time |
|
|
import gradio as gr |
|
|
import torch |
|
|
from PIL import Image |
|
|
import numpy as np |
|
|
from torchvision import transforms |
|
|
import tempfile |
|
|
|
|
|
|
|
|
# Side length, in pixels, of each square tile cut from the input image.
CHUNK_SIZE = 256

# Upscale ratio; tile offsets are multiplied by this when stitching the output,
# so it must match the ratio the loaded model actually produces.
SCALE_FACTOR = 4

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Location of the serialized model. The UPSCALER_MODEL_PATH environment variable
# can point elsewhere (e.g. a mounted volume); the default is the bundled filename.
MODEL_PATH = os.environ.get('UPSCALER_MODEL_PATH', '4xTextures_GTAV_rgt-s.pth')

# NOTE(review): torch.jit.load only accepts TorchScript archives (torch.jit.save /
# trace / script output). A plain `.pth` checkpoint such as a state_dict will fail
# to load here — confirm how this file was exported.
model = torch.jit.load(MODEL_PATH, map_location=DEVICE)
model.eval()
|
|
|
|
|
def process_chunk(chunk):
    """Upscale a single image tile with the loaded model.

    Args:
        chunk: PIL image tile. Non-RGB modes (RGBA, L, P, ...) are converted to
            RGB first so the model always receives a 3-channel tensor.

    Returns:
        PIL image built from the model output, quantized to 8 bits per channel.
    """
    if chunk.mode != 'RGB':
        chunk = chunk.convert('RGB')

    # ToTensor gives a float (C, H, W) tensor in [0, 1]; add a batch dimension.
    img_tensor = transforms.ToTensor()(chunk).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        output = model(img_tensor)

    # Assumes the model returns a (1, C, H, W) batch like its input — TODO confirm
    # against the exported model. PIL needs (H, W, C), so:
    # - drop only the batch dim (a bare squeeze() would also collapse a 1-px side);
    # - cast to float32 in case the model runs in half precision;
    # - round rather than truncate when quantizing;
    # - move channels last.
    pixels = (
        output.squeeze(0)
        .float()
        .clamp(0, 1)
        .mul(255)
        .round()
        .to(torch.uint8)
        .permute(1, 2, 0)
        .contiguous()
        .cpu()
        .numpy()
    )
    return Image.fromarray(pixels)
|
|
|
|
|
def split_image(img, chunk_size):
    """Split *img* into a row-major grid of tiles.

    Tiles are ``chunk_size`` square except along the right and bottom edges,
    where they are clipped to the image bounds. Cropping past the edge would make
    PIL pad the overhang with black, which the model would then upscale into dark
    seams along the border.

    Args:
        img: image exposing ``size`` -> (width, height) and ``crop(box)``,
            e.g. a PIL ``Image``.
        chunk_size: maximum tile side length in pixels; must be positive.

    Returns:
        (chunks, positions): the cropped tiles, and the (x, y) top-left corner of
        each tile in *img*, in matching order.

    Raises:
        ValueError: if ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    width, height = img.size
    positions = [
        (x, y)
        for y in range(0, height, chunk_size)
        for x in range(0, width, chunk_size)
    ]
    chunks = [
        img.crop((x, y, min(x + chunk_size, width), min(y + chunk_size, height)))
        for x, y in positions
    ]
    return chunks, positions
|
|
|
|
|
def merge_chunks(chunk_files, positions, target_size):
    """Stitch upscaled tiles from disk into one RGB image, deleting each file after use.

    Args:
        chunk_files: paths of the upscaled tile images, in the same order as
            *positions*.
        positions: (x, y) top-left corner of each tile in the pre-upscale image;
            multiplied by SCALE_FACTOR to get the paste location.
        target_size: (width, height) of the merged output image.

    Returns:
        The merged PIL RGB image.
    """
    merged_img = Image.new('RGB', target_size)
    for (x, y), path in zip(positions, chunk_files):
        # Close the file before deleting it. An open handle leaks a descriptor
        # and makes os.remove fail on Windows.
        with Image.open(path) as chunk:
            merged_img.paste(chunk, (x * SCALE_FACTOR, y * SCALE_FACTOR))
        os.remove(path)
    return merged_img
|
|
|
|
|
def upscale_image(input_img):
    """Upscale a square image by SCALE_FACTOR, one CHUNK_SIZE tile at a time.

    Each tile is upscaled by the model and written to a temporary PNG. The tiles
    are then stitched back together. Per-tile progress and the total time are
    printed to stdout.

    Args:
        input_img: PIL image from the Gradio input, or None when nothing was
            uploaded.

    Returns:
        The stitched upscaled image, ``input_img.size[0] * SCALE_FACTOR`` pixels
        on each side.

    Raises:
        ValueError: if no image was provided or the image is not square.
    """
    start_time = time.perf_counter()

    if input_img is None:
        raise ValueError("No input image provided")
    if input_img.size[0] != input_img.size[1]:
        raise ValueError("Input image must be square")

    # The model expects 3-channel input; normalize RGBA/L/P uploads once up front.
    if input_img.mode != 'RGB':
        input_img = input_img.convert('RGB')

    target_size = input_img.size[0] * SCALE_FACTOR

    chunks, positions = split_image(input_img, CHUNK_SIZE)
    total_chunks = len(chunks)

    # TemporaryDirectory removes the spill files and the directory even if the
    # model or the merge raises part-way through.
    with tempfile.TemporaryDirectory() as temp_dir:
        chunk_files = []
        for i, (chunk, (x, y)) in enumerate(zip(chunks, positions)):
            chunk_start = time.perf_counter()

            upscaled = process_chunk(chunk)
            chunk_path = os.path.join(temp_dir, f'chunk_{x}_{y}.png')
            upscaled.save(chunk_path)
            chunk_files.append(chunk_path)

            elapsed = time.perf_counter() - chunk_start
            progress = (i + 1) / total_chunks * 100
            print(f"Processed chunk {i+1}/{total_chunks} ({progress:.1f}%) - {elapsed:.2f}s")

        final_image = merge_chunks(chunk_files, positions, (target_size, target_size))

    total_time = time.perf_counter() - start_time
    print(f"Total processing time: {total_time:.2f}s")
    return final_image
|
|
|
|
|
|
|
|
# Gradio components: a PIL image in, the upscaled PIL image out.
_input_image = gr.Image(type="pil", label="Input Image")
_output_image = gr.Image(type="pil", label="Upscaled Image")

# Single-function web UI around upscale_image, with flagging turned off.
demo = gr.Interface(
    fn=upscale_image,
    inputs=_input_image,
    outputs=_output_image,
    title="4x Texture Upscaler (Low Memory)",
    description="Upscale large square textures using RGT model. Accepts 256, 512, 1024, 2048, 4096, 8192px images.",
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()