# File: app.py
# Last commit: cd78906 "Update app.py" by tester1hf (verified)
import os
import time
import gradio as gr
import torch
from PIL import Image
import numpy as np
from torchvision import transforms
import tempfile
# Configuration
CHUNK_SIZE = 256    # edge length, in input pixels, of each tile fed to the model
SCALE_FACTOR = 4    # assumed output/input size ratio of the model; used to place tiles when merging
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Load the TorchScript model once, at import time, onto DEVICE and put it in evaluation mode.
model = torch.jit.load('4xTextures_GTAV_rgt-s.pth', map_location=DEVICE)
model.eval()
def process_chunk(chunk):
    """Upscale a single image tile with the loaded model.

    Args:
        chunk: PIL image tile. Tiles in any mode other than ``RGB`` are
            converted first (assumption: the model expects 3-channel input
            -- TODO confirm against the model definition).

    Returns:
        PIL ``RGB`` image holding the model output.
    """
    if chunk.mode != 'RGB':
        chunk = chunk.convert('RGB')
    # ToTensor maps an HWC uint8 image in [0, 255] to a CHW float tensor in
    # [0, 1]; unsqueeze adds the batch axis -> (1, C, H, W).
    img_tensor = transforms.ToTensor()(chunk).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        output = model(img_tensor)
    # Assumes the model returns (1, C, H', W') like its input. Drop only the
    # batch axis (a bare squeeze() would also drop any other size-1 axis),
    # then reorder CHW -> HWC: Image.fromarray reads the LAST axis as
    # channels and rejects a (3, H, W) uint8 array.
    output = output.squeeze(0).clamp(0, 1).permute(1, 2, 0).cpu().numpy()
    # Round rather than truncate when mapping [0, 1] back to [0, 255].
    return Image.fromarray((output * 255.0).round().astype(np.uint8))
def split_image(img, chunk_size):
    """Cut *img* into a row-major grid of ``chunk_size`` x ``chunk_size`` tiles.

    Every crop box has the full ``chunk_size`` extent, so tiles on the right
    and bottom edges may reach past the image bounds (presumably padded by
    PIL's ``crop``).

    Returns:
        ``(chunks, positions)``: parallel lists holding each cropped tile and
        the ``(x, y)`` top-left corner of that tile in the source image.
    """
    width, height = img.size
    positions = [
        (left, top)
        for top in range(0, height, chunk_size)
        for left in range(0, width, chunk_size)
    ]
    chunks = [
        img.crop((left, top, left + chunk_size, top + chunk_size))
        for left, top in positions
    ]
    return chunks, positions
def merge_chunks(chunk_files, positions, target_size):
    """Stitch upscaled tiles from disk into one RGB image, deleting each file.

    Args:
        chunk_files: paths of the upscaled tile images, parallel to *positions*.
        positions: ``(x, y)`` top-left corner of each tile in the SOURCE
            image; multiplied by SCALE_FACTOR to place the tile in the output.
        target_size: ``(width, height)`` of the merged output image.

    Returns:
        The merged PIL ``RGB`` image.
    """
    merged_img = Image.new('RGB', target_size)
    for (x, y), path in zip(positions, chunk_files):
        # Close the tile's file handle before deleting it: removing a file
        # that is still open fails on Windows, and unclosed handles would
        # otherwise pile up until garbage collection.
        with Image.open(path) as chunk:
            merged_img.paste(chunk, (x * SCALE_FACTOR, y * SCALE_FACTOR))
        os.remove(path)  # free disk space as soon as the tile is merged
    return merged_img
def upscale_image(input_img):
    """Upscale a square PIL image tile-by-tile by SCALE_FACTOR.

    The image is split into CHUNK_SIZE tiles. Each tile is upscaled and saved
    as a PNG in a temporary directory, and the PNGs are stitched together at
    the end. Only the small input tiles stay in memory while the model runs.

    Args:
        input_img: square PIL image, or None if nothing was uploaded.

    Returns:
        The upscaled PIL image, ``SCALE_FACTOR`` times the input edge length.

    Raises:
        ValueError: if no image is given or the image is not square.
    """
    start_time = time.perf_counter()
    # Validate input
    if input_img is None:
        raise ValueError("Please upload an image")
    if input_img.size[0] != input_img.size[1]:
        raise ValueError("Input image must be square")
    target_size = input_img.size[0] * SCALE_FACTOR

    chunks, positions = split_image(input_img, CHUNK_SIZE)
    total_chunks = len(chunks)

    # TemporaryDirectory removes the directory, plus any tiles still in it,
    # even if the model, a save, or the merge raises part-way through.
    with tempfile.TemporaryDirectory() as temp_dir:
        chunk_files = []
        for i, (chunk, (x, y)) in enumerate(zip(chunks, positions), start=1):
            chunk_start = time.perf_counter()
            upscaled = process_chunk(chunk)
            chunk_path = os.path.join(temp_dir, f'chunk_{x}_{y}.png')
            upscaled.save(chunk_path)
            chunk_files.append(chunk_path)
            elapsed = time.perf_counter() - chunk_start
            progress = i / total_chunks * 100
            print(f"Processed chunk {i}/{total_chunks} ({progress:.1f}%) - {elapsed:.2f}s")
        final_image = merge_chunks(chunk_files, positions, (target_size, target_size))

    total_time = time.perf_counter() - start_time
    print(f"Total processing time: {total_time:.2f}s")
    return final_image
# Gradio interface: a single PIL image goes in and the upscaled PIL image comes out.
demo = gr.Interface(
    fn=upscale_image,
    inputs=gr.Image(type="pil", label="Input Image"),
    outputs=gr.Image(type="pil", label="Upscaled Image"),
    title="4x Texture Upscaler (Low Memory)",
    description=(
        "Upscale large square textures using RGT model. "
        "Accepts 256, 512, 1024, 2048, 4096, 8192px images."
    ),
    allow_flagging="never",
)

if __name__ == "__main__":
    # Start the web server only when run as a script, not when imported.
    demo.launch()