u
File size: 3,336 Bytes
cd78906
 
4f20e59
cd78906
f612ed1
cd78906
 
 
 
 
 
 
 
 
 
 
 
f612ed1
cd78906
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f612ed1
 
cd78906
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f612ed1
cd78906
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f612ed1
cd78906
 
 
f612ed1
cd78906
 
 
f612ed1
cd78906
f612ed1
 
cd78906
 
 
 
 
f612ed1
4f20e59
 
f612ed1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import os
import time
import gradio as gr
import torch
from PIL import Image
import numpy as np
from torchvision import transforms
import tempfile

# Configuration
CHUNK_SIZE = 256  # edge length (px) of the square tiles fed to the model
SCALE_FACTOR = 4  # upscale factor applied by the model (output tile = CHUNK_SIZE * 4)
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load model once at import time; all requests share this instance.
# NOTE(review): torch.jit.load expects a TorchScript archive — confirm
# '4xTextures_GTAV_rgt-s.pth' was saved with torch.jit.save, not torch.save(state_dict).
model = torch.jit.load('4xTextures_GTAV_rgt-s.pth', map_location=DEVICE)
model.eval()

def process_chunk(chunk):
    """Upscale one square RGB tile through the model.

    Args:
        chunk: PIL.Image tile (RGB expected by the 3-channel pipeline).

    Returns:
        PIL.Image: the upscaled tile (SCALE_FACTOR times larger).
    """
    # ToTensor converts PIL (H, W, C) uint8 -> float tensor (C, H, W) in [0, 1].
    img_tensor = transforms.ToTensor()(chunk).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        output = model(img_tensor)

    # Drop only the batch dim (squeeze(0), not squeeze(), so a single-channel
    # output would keep its channel axis), then move channels last —
    # Image.fromarray requires (H, W, C), not the model's (C, H, W) layout.
    output = output.squeeze(0).cpu().clamp(0, 1)
    array = (output.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
    return Image.fromarray(array)

def split_image(img, chunk_size):
    """Tile *img* into chunk_size x chunk_size crops in row-major order.

    Args:
        img: PIL.Image (only .size and .crop are used).
        chunk_size: edge length of each square tile in pixels.

    Returns:
        (chunks, positions): parallel lists, where positions[i] is the
        (x, y) top-left corner of chunks[i] in the original image.
    """
    width, height = img.size
    positions = [
        (x, y)
        for y in range(0, height, chunk_size)
        for x in range(0, width, chunk_size)
    ]
    chunks = [img.crop((x, y, x + chunk_size, y + chunk_size)) for x, y in positions]
    return chunks, positions

def merge_chunks(chunk_files, positions, target_size):
    """Paste upscaled chunk files into one image, deleting each file after use.

    Args:
        chunk_files: paths of saved upscaled tiles, parallel to positions.
        positions: (x, y) top-left corners in ORIGINAL-image coordinates;
            scaled by SCALE_FACTOR when pasting.
        target_size: (width, height) of the final merged image.

    Returns:
        PIL.Image: the merged RGB image.
    """
    merged_img = Image.new('RGB', target_size)
    for (x, y), path in zip(positions, chunk_files):
        # Close the file handle before removal — deleting a still-open file
        # raises on Windows and leaks descriptors elsewhere.
        with Image.open(path) as chunk:
            merged_img.paste(chunk, (x * SCALE_FACTOR, y * SCALE_FACTOR))
        os.remove(path)  # free disk space as we go (low-memory/low-disk goal)
    return merged_img

def upscale_image(input_img):
    """Upscale a square image by SCALE_FACTOR, tile-by-tile to bound memory.

    Splits the input into CHUNK_SIZE tiles, runs each through the model,
    spools results to a temporary directory, then merges them.

    Args:
        input_img: square PIL.Image (Gradio passes None when no image given).

    Returns:
        PIL.Image: the upscaled image, SCALE_FACTOR times larger per side.

    Raises:
        ValueError: if no image was provided or the image is not square.
    """
    start_time = time.time()

    # Validate input — Gradio hands us None when the user submits nothing.
    if input_img is None:
        raise ValueError("No input image provided")
    width, height = input_img.size
    if width != height:
        raise ValueError("Input image must be square")

    target_size = width * SCALE_FACTOR

    # Split into chunks
    chunks, positions = split_image(input_img, CHUNK_SIZE)
    total_chunks = len(chunks)

    # TemporaryDirectory guarantees cleanup even if a chunk fails mid-run
    # (the old mkdtemp + os.rmdir leaked the directory on any exception).
    with tempfile.TemporaryDirectory() as temp_dir:
        chunk_files = []

        for i, (chunk, (x, y)) in enumerate(zip(chunks, positions)):
            chunk_start = time.time()

            upscaled = process_chunk(chunk)

            # Spool each upscaled tile to disk so only one tile lives in RAM.
            chunk_path = os.path.join(temp_dir, f'chunk_{x}_{y}.png')
            upscaled.save(chunk_path)
            chunk_files.append(chunk_path)

            elapsed = time.time() - chunk_start
            progress = (i + 1) / total_chunks * 100
            print(f"Processed chunk {i+1}/{total_chunks} ({progress:.1f}%) - {elapsed:.2f}s")

        # merge_chunks deletes each tile file as it pastes; the context
        # manager then removes the (now empty) directory.
        final_image = merge_chunks(chunk_files, positions, (target_size, target_size))

    total_time = time.time() - start_time
    print(f"Total processing time: {total_time:.2f}s")
    return final_image

# Gradio interface: single image in -> single image out, wired to upscale_image.
demo = gr.Interface(
    fn=upscale_image,
    inputs=gr.Image(type="pil", label="Input Image"),
    outputs=gr.Image(type="pil", label="Upscaled Image"),
    title="4x Texture Upscaler (Low Memory)",
    description="Upscale large square textures using RGT model. Accepts 256, 512, 1024, 2048, 4096, 8192px images.",
    # NOTE(review): allow_flagging is deprecated in Gradio 4.x (renamed to
    # flagging_mode) — confirm against the pinned gradio version.
    allow_flagging="never"
)

if __name__ == "__main__":
    demo.launch()