mutou0308 commited on
Commit
bc3e084
·
verified ·
1 Parent(s): 3652801

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -359
app.py DELETED
@@ -1,359 +0,0 @@
1
-
2
-
3
- import torch
4
- import numpy as np
5
- import gradio as gr
6
- from PIL import Image
7
- import math
8
- import torch.nn.functional as F
9
- import os
10
- import tempfile
11
- import time
12
- import threading
13
-
14
- from utils.hatropeamp import HATNOUP_ROPE_AMP
15
- from utils.fea2gsropeamp import Fea2GS_ROPE_AMP
16
- from utils.edsrbaseline import EDSRNOUP
17
- from utils.hatropeamp import HATNOUP_ROPE_AMP
18
- from utils.rdn import RDNNOUP
19
- from utils.swinir import SwinIRNOUP
20
- from utils.fea2gsropeamp import Fea2GS_ROPE_AMP
21
- from utils.gaussian_splatting import generate_2D_gaussian_splatting_step
22
- from utils.split_and_joint_image import split_and_joint_image
23
- from huggingface_hub import hf_hub_download
24
- import subprocess
25
- import sys
26
- import spaces
27
-
28
-
29
-
30
- # Device setup
31
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
32
-
33
- # Global stop flag for interrupting inference
34
- stop_inference = False
35
- inference_lock = threading.Lock()
36
-
37
- def load_model(
38
- pretrained_model_name_or_path: str = "mutou0308/GSASR",
39
- model_name: str = "HATL_SA1B",
40
- device: str | torch.device = "cuda"
41
- ):
42
- enc_path = hf_hub_download(
43
- repo_id=pretrained_model_name_or_path, filename=os.path.join(model_name, 'GSASR_enhenced_ultra', 'encoder.pth')
44
- )
45
- dec_path = hf_hub_download(
46
- repo_id=pretrained_model_name_or_path, filename=os.path.join(model_name, 'GSASR_enhenced_ultra', 'decoder.pth')
47
- )
48
-
49
- enc_weight = torch.load(enc_path, weights_only=True)['params_ema']
50
- dec_weight = torch.load(dec_path, weights_only=True)['params_ema']
51
-
52
- if model_name in ['EDSR_DIV2K', 'EDSR_DF2K']:
53
- encoder = EDSRNOUP()
54
- decoder = Fea2GS_ROPE_AMP()
55
- elif model_name in ['RDN_DIV2K', 'RDN_DF2K']:
56
- encoder = RDNNOUP()
57
- decoder = Fea2GS_ROPE_AMP(num_crossattn_blocks = 2)
58
- elif model_name in ['SwinIR_DIV2K', 'SwinIR_DF2K']:
59
- encoder = SwinIRNOUP()
60
- decoder = Fea2GS_ROPE_AMP(num_crossattn_blocks=2, num_crossattn_layers=4, num_gs_seed=256, window_size=16)
61
- elif model_name in ['HATL_SA1B']:
62
- encoder = HATNOUP_ROPE_AMP()
63
- decoder = Fea2GS_ROPE_AMP(channel=192, num_crossattn_blocks=4, num_crossattn_layers=4, num_selfattn_blocks=8, num_selfattn_layers=6,
64
- num_gs_seed=256, window_size=16)
65
- else:
66
- raise ValueError(f"args.model-{model_name} must be in ['EDSR_DIV2K', 'EDSR_DF2K', 'RDN_DIV2K', 'RDN_DF2K', 'SwinIR_DIV2K', 'SwinIR_DF2K', 'HATL_SA1B']")
67
-
68
- encoder.load_state_dict(enc_weight, strict=True)
69
- decoder.load_state_dict(dec_weight, strict=True)
70
- encoder.eval()
71
- decoder.eval()
72
- encoder = encoder.to(device)
73
- decoder = decoder.to(device)
74
- return encoder, decoder
75
-
76
-
77
- def preprocess(x, denominator=16):
78
- """Preprocess image to ensure dimensions are multiples of denominator"""
79
- _, c, h, w = x.shape
80
- if h % denominator > 0:
81
- pad_h = denominator - h % denominator
82
- else:
83
- pad_h = 0
84
- if w % denominator > 0:
85
- pad_w = denominator - w % denominator
86
- else:
87
- pad_w = 0
88
- x_new = F.pad(x, (0, pad_w, 0, pad_h), 'reflect')
89
- return x_new
90
-
91
- def postprocess(x, gt_size_h, gt_size_w):
92
- """Post-process by cropping to target size"""
93
- x_new = x[:, :, :gt_size_h, :gt_size_w]
94
- return x_new
95
-
96
- def should_use_tile(image_height, image_width, threshold=1024):
97
- """Determine if tile processing should be used based on image resolution"""
98
- return max(image_height, image_width) > threshold
99
-
100
- def set_stop_flag():
101
- """Set the global stop flag to interrupt inference"""
102
- global stop_inference
103
- with inference_lock:
104
- stop_inference = True
105
- return "🛑 Stopping inference...", gr.update(interactive=False)
106
-
107
- def reset_stop_flag():
108
- """Reset the global stop flag"""
109
- global stop_inference
110
- with inference_lock:
111
- stop_inference = False
112
-
113
- def check_stop_flag():
114
- """Check if inference should be stopped"""
115
- global stop_inference
116
- with inference_lock:
117
- return stop_inference
118
-
119
- @spaces.GPU
120
- def super_resolution_inference(image, scale=4.0):
121
- """Super-resolution inference function with automatic tile processing"""
122
-
123
- # Check if gscuda setup has been run
124
- setup_marker = ".setup_complete"
125
- if not os.path.exists(setup_marker):
126
- print("First run detected, installing dependencies...")
127
- try:
128
- # subprocess.check_call(["pip", "install", "-e", "."])
129
- subprocess.check_call(["pip", "install", "dist/gscuda-0.0.0-cp310-cp310-linux_x86_64.whl"])
130
- # Create marker file to indicate setup is complete
131
- with open(setup_marker, "w") as f:
132
- f.write("Setup completed")
133
- print("Setup completed successfully!")
134
- except subprocess.CalledProcessError as e:
135
- return None, f"❌ Setup failed with error: {e}", None
136
-
137
-
138
-
139
- if image is None:
140
- return None, "Please upload an image", None
141
-
142
- # Load model
143
- encoder, decoder = load_model(model_name="HATL_SA1B")
144
-
145
- # Reset stop flag at the beginning
146
- reset_stop_flag()
147
-
148
- # Fixed parameters
149
- tile_overlap = 16 # Fixed overlap size
150
- crop_size = 8 # Fixed crop size
151
- tile_size = 1024 # Fixed tile size for large images
152
-
153
- try:
154
- # Check for interruption
155
- if check_stop_flag():
156
- return None, "❌ Inference interrupted", None
157
-
158
- # Convert PIL image to numpy array
159
- img_np = np.array(image)
160
- if len(img_np.shape) == 3:
161
- img_np = img_np[:, :, [2, 1, 0]] # RGB to BGR
162
-
163
- # Convert to tensor
164
- img = torch.from_numpy(np.transpose(img_np.astype(np.float32) / 255., (2, 0, 1))).float()
165
- img = img.unsqueeze(0).to(device)
166
-
167
- # Check for interruption
168
- if check_stop_flag():
169
- return None, "❌ Inference interrupted", None
170
-
171
- # Calculate target size
172
- gt_size = [math.floor(scale * img.shape[2]), math.floor(scale * img.shape[3])]
173
-
174
- # Determine if tile processing should be used
175
- use_tile = should_use_tile(img.shape[2], img.shape[3])
176
-
177
- # Force AMP mixed precision
178
- with torch.inference_mode():
179
- with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
180
- # Check for interruption before main processing
181
- if check_stop_flag():
182
- return None, "❌ Inference interrupted", None
183
-
184
- if use_tile:
185
- # Use tile processing
186
- assert tile_size % 16 == 0, f"tile_size-{tile_size} must be divisible by 16"
187
- assert 2 * tile_overlap < tile_size, f"2 * tile_overlap must be less than tile_size"
188
- assert 2 * crop_size <= tile_overlap, f"2 * crop_size must be less than or equal to tile_overlap"
189
-
190
- with torch.no_grad():
191
- output = split_and_joint_image(
192
- lq=img,
193
- scale_factor=scale,
194
- split_size=tile_size,
195
- overlap_size=tile_overlap,
196
- model_g=encoder,
197
- model_fea2gs=decoder,
198
- crop_size=crop_size,
199
- scale_modify=torch.tensor([scale, scale]),
200
- default_step_size=1.2,
201
- cuda_rendering=True,
202
- mode='scale_modify',
203
- if_dmax=True,
204
- dmax_mode='fix',
205
- dmax=0.1
206
- )
207
- else:
208
- # Direct processing without tiles
209
- lq_pad = preprocess(img, 16) # denominator=16 for HATL
210
- gt_size_pad = torch.tensor([math.floor(scale * lq_pad.shape[2]),
211
- math.floor(scale * lq_pad.shape[3])])
212
- gt_size_pad = gt_size_pad.unsqueeze(0)
213
-
214
- with torch.no_grad():
215
- # Check for interruption before encoder
216
- if check_stop_flag():
217
- return None, "❌ Inference interrupted", None
218
-
219
- # Encoder output
220
- encoder_output = encoder(lq_pad) # b,c,h,w
221
-
222
- # Check for interruption before decoder
223
- if check_stop_flag():
224
- return None, "❌ Inference interrupted", None
225
-
226
- scale_vector = torch.tensor(scale, dtype=torch.float32).unsqueeze(0).to(device)
227
-
228
- # Decoder output
229
- batch_gs_parameters = decoder(encoder_output, scale_vector)
230
- gs_parameters = batch_gs_parameters[0, :]
231
-
232
- # Check for interruption before gaussian rendering
233
- if check_stop_flag():
234
- return None, "❌ Inference interrupted", None
235
-
236
- # Gaussian rendering
237
- b_output = generate_2D_gaussian_splatting_step(
238
- gs_parameters=gs_parameters,
239
- sr_size=gt_size_pad[0],
240
- scale=scale,
241
- sample_coords=None,
242
- scale_modify=torch.tensor([scale, scale]),
243
- default_step_size=1.2,
244
- cuda_rendering=True,
245
- mode='scale_modify',
246
- if_dmax=True,
247
- dmax_mode='fix',
248
- dmax=0.1
249
- )
250
- output = b_output.unsqueeze(0)
251
-
252
- # Check for interruption before post-processing
253
- if check_stop_flag():
254
- return None, "❌ Inference interrupted", None
255
-
256
- # Post-processing
257
- output = postprocess(output, gt_size[0], gt_size[1])
258
-
259
- # Convert back to PIL image format
260
- output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
261
- output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) # BGR to RGB
262
- output = (output * 255.0).round().astype(np.uint8)
263
-
264
- # Convert to PIL image
265
- output_pil = Image.fromarray(output)
266
-
267
- # Generate result information
268
- original_size = f"{img.shape[3]}x{img.shape[2]}"
269
- output_size = f"{output.shape[1]}x{output.shape[0]}"
270
- tile_info = f"Tile processing enabled (size: {tile_size})" if use_tile else "Direct processing (no tiles)"
271
- result_info = f"✅ Processing completed successfully!\nOriginal size: {original_size}\nSuper-resolution size: {output_size}\nScale factor: {scale:.2f}x\nProcessing mode: {tile_info}\nAMP acceleration: Force enabled\nOverlap size: {tile_overlap}\nCrop size: {crop_size}"
272
-
273
- return output_pil, result_info, output_pil
274
-
275
- except Exception as e:
276
- if check_stop_flag():
277
- return None, "❌ Inference interrupted", None
278
- return None, f"❌ Error during processing: {str(e)}", None
279
-
280
- def predict(image, scale):
281
- """Gradio prediction function"""
282
- output_image, info, download_image = super_resolution_inference(image, scale)
283
-
284
- # If processing successful, save image for download
285
- if output_image is not None:
286
- # Create temporary filename
287
- timestamp = int(time.time())
288
- temp_filename = f"GSASR_SR_result_{scale}x_{timestamp}.png"
289
- temp_path = os.path.join(tempfile.gettempdir(), temp_filename)
290
-
291
- # Save image
292
- output_image.save(temp_path, "PNG")
293
-
294
- return output_image, temp_path, "✅ Ready", gr.update(interactive=True)
295
- else:
296
- return output_image, None, info if info else "❌ Processing failed", gr.update(interactive=True)
297
-
298
- # Create Gradio interface
299
- with gr.Blocks(title="🚀 GSASR (2D Gaussian Splatting Super-Resolution)") as demo:
300
- gr.Markdown("# **🚀 GSASR (Generalized and efficient 2d gaussian splatting for arbitrary-scale super-resolution)**")
301
- gr.Markdown("Official demo for GSASR. Please refer to our [paper](https://arxiv.org/pdf/2501.06838), [project page](https://mt-cly.github.io/GSASR.github.io/), and [github](https://github.com/ChrisDud0257/GSASR) for more details.")
302
-
303
- with gr.Row():
304
- with gr.Column():
305
- input_image = gr.Image(type="pil", label="Input Image")
306
-
307
- # Scale parameters
308
- with gr.Group():
309
- gr.Markdown("### SR Scale")
310
- scale_slider = gr.Slider(minimum=1.0, maximum=30.0, value=4.0, step=0.1, label="SR Scale")
311
-
312
- # Control buttons
313
- with gr.Row():
314
- submit_btn = gr.Button("🚀 Start Super-Resolution", variant="primary")
315
- stop_btn = gr.Button("🛑 Stop Inference", variant="stop")
316
-
317
- with gr.Column():
318
- output_image = gr.Image(type="pil", label="Super-Resolution Result")
319
-
320
- # Status display
321
- status_text = gr.Textbox(label="Status", value="✅ Ready", interactive=False)
322
-
323
- # Download component
324
- with gr.Group():
325
- gr.Markdown("### 📥 Download Super-Resolution Result")
326
- download_btn = gr.File(visible=True)
327
-
328
- # Event handlers
329
- submit_event = submit_btn.click(
330
- fn=predict,
331
- inputs=[input_image, scale_slider],
332
- outputs=[output_image, download_btn, status_text, stop_btn]
333
- )
334
-
335
- stop_btn.click(
336
- fn=set_stop_flag,
337
- inputs=[],
338
- outputs=[status_text, stop_btn],
339
- cancels=[submit_event]
340
- )
341
-
342
- # Example images
343
- gr.Markdown("### 📚 Example Images")
344
- gr.Markdown("Try these examples with different scales:")
345
-
346
- gr.Examples(
347
- examples=[
348
- ["assets/0846x4.png", 1.5],
349
- ["assets/0892x4.png", 2.8],
350
- ["assets/0873x4_cropped_120x120.png", 30.0]
351
- ],
352
- inputs=[input_image, scale_slider],
353
- examples_per_page=3,
354
- cache_examples=False,
355
- label="Examples"
356
- )
357
-
358
- if __name__ == "__main__":
359
- demo.launch(share=True, server_name="0.0.0.0")