tester1hf commited on
Commit
cd78906
·
verified ·
1 Parent(s): 433874a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -26
app.py CHANGED
@@ -1,40 +1,110 @@
 
 
1
  import gradio as gr
2
- import numpy as np
3
  from PIL import Image
4
- import onnxruntime as ort
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- # Load the ONNX model
7
- ort_session = ort.InferenceSession("4xTextures_GTAV_rgt-s_fp32_opset17.onnx",
8
- providers=['CPUExecutionProvider'])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  def upscale_image(input_img):
11
- # Convert input to RGB and preprocess
12
- img = input_img.convert("RGB")
13
- lr = np.array(img).astype(np.float32) / 255.0
14
- lr = np.transpose(lr, (2, 0, 1))[np.newaxis, ...] # Add batch dimension
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- # Perform inference
17
- ort_inputs = {ort_session.get_inputs()[0].name: lr}
18
- ort_outs = ort_session.run(None, ort_inputs)
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
- # Post-process output
21
- sr = ort_outs[0][0]
22
- sr = np.clip(sr, 0, 1)
23
- sr = np.transpose(sr, (1, 2, 0)) # CHW to HWC
24
- sr = (sr * 255).astype(np.uint8)
25
 
26
- return Image.fromarray(sr)
 
 
27
 
28
- # Create Gradio interface
29
  demo = gr.Interface(
30
  fn=upscale_image,
31
- inputs=gr.Image(type="pil", label="Input Texture"),
32
- outputs=gr.Image(type="pil", label="Upscaled (4x)"),
33
- title="4x Texture Upscaler (RGT Architecture)",
34
- description="Upscale textures using RGT model (4x scale) - CPU implementation",
35
- #examples=[
36
- #["sample_texture_lowres.png"]
37
- #]
38
  )
39
 
40
  if __name__ == "__main__":
 
1
+ import os
2
+ import time
3
  import gradio as gr
4
+ import torch
5
  from PIL import Image
6
+ import numpy as np
7
+ from torchvision import transforms
8
+ import tempfile
9
+
10
+ # Configuration
11
+ CHUNK_SIZE = 256
12
+ SCALE_FACTOR = 4
13
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
14
+
15
+ # Load model
16
+ model = torch.jit.load('4xTextures_GTAV_rgt-s.pth', map_location=DEVICE)
17
+ model.eval()
18
 
19
+ def process_chunk(chunk):
20
+ """Process a single chunk through the model"""
21
+ preprocess = transforms.Compose([
22
+ transforms.ToTensor()
23
+ ])
24
+
25
+ img_tensor = preprocess(chunk).unsqueeze(0).to(DEVICE)
26
+
27
+ with torch.no_grad():
28
+ output = model(img_tensor)
29
+
30
+ output = output.squeeze().cpu().clamp(0, 1).numpy()
31
+ return Image.fromarray((output * 255).astype(np.uint8))
32
+
33
+ def split_image(img, chunk_size):
34
+ """Split image into chunks"""
35
+ width, height = img.size
36
+ chunks = []
37
+ positions = []
38
+
39
+ for y in range(0, height, chunk_size):
40
+ for x in range(0, width, chunk_size):
41
+ box = (x, y, x+chunk_size, y+chunk_size)
42
+ chunks.append(img.crop(box))
43
+ positions.append((x, y))
44
+
45
+ return chunks, positions
46
+
47
+ def merge_chunks(chunk_files, positions, target_size):
48
+ """Merge processed chunks into final image"""
49
+ merged_img = Image.new('RGB', target_size)
50
+ for (x, y), path in zip(positions, chunk_files):
51
+ chunk = Image.open(path)
52
+ merged_img.paste(chunk, (x*SCALE_FACTOR, y*SCALE_FACTOR))
53
+ os.remove(path) # Cleanup immediately
54
+ return merged_img
55
 
56
  def upscale_image(input_img):
57
+ """Main processing function"""
58
+ start_time = time.time()
59
+
60
+ # Validate input
61
+ if input_img.size[0] != input_img.size[1]:
62
+ raise ValueError("Input image must be square")
63
+
64
+ original_size = input_img.size[0]
65
+ target_size = original_size * SCALE_FACTOR
66
+
67
+ # Split into chunks
68
+ chunks, positions = split_image(input_img, CHUNK_SIZE)
69
+ total_chunks = len(chunks)
70
+
71
+ # Create temporary directory
72
+ temp_dir = tempfile.mkdtemp()
73
+ chunk_files = []
74
 
75
+ # Process chunks
76
+ for i, (chunk, (x, y)) in enumerate(zip(chunks, positions)):
77
+ chunk_start = time.time()
78
+
79
+ # Process chunk
80
+ upscaled = process_chunk(chunk)
81
+
82
+ # Save to temp file
83
+ chunk_path = os.path.join(temp_dir, f'chunk_{x}_{y}.png')
84
+ upscaled.save(chunk_path)
85
+ chunk_files.append(chunk_path)
86
+
87
+ # Calculate progress
88
+ elapsed = time.time() - chunk_start
89
+ progress = (i + 1) / total_chunks * 100
90
+ print(f"Processed chunk {i+1}/{total_chunks} ({progress:.1f}%) - {elapsed:.2f}s")
91
 
92
+ # Merge chunks
93
+ final_image = merge_chunks(chunk_files, positions, (target_size, target_size))
94
+ os.rmdir(temp_dir) # Cleanup temp directory
 
 
95
 
96
+ total_time = time.time() - start_time
97
+ print(f"Total processing time: {total_time:.2f}s")
98
+ return final_image
99
 
100
+ # Gradio interface
101
  demo = gr.Interface(
102
  fn=upscale_image,
103
+ inputs=gr.Image(type="pil", label="Input Image"),
104
+ outputs=gr.Image(type="pil", label="Upscaled Image"),
105
+ title="4x Texture Upscaler (Low Memory)",
106
+ description="Upscale large square textures using RGT model. Accepts 256, 512, 1024, 2048, 4096, 8192px images.",
107
+ allow_flagging="never"
 
 
108
  )
109
 
110
  if __name__ == "__main__":