Translsis committed on
Commit
89daa3c
·
verified ·
1 Parent(s): 67f8d32

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +339 -80
app.py CHANGED
@@ -1,6 +1,3 @@
1
- '''
2
- Gradio demo (almost the same code as the one used in Huggingface space)
3
- '''
4
  import os, sys
5
  import cv2
6
  import time
@@ -9,7 +6,11 @@ import gradio as gr
9
  import torch
10
  import numpy as np
11
  from torchvision.utils import save_image
12
-
 
 
 
 
13
 
14
  # Import files from the local folder
15
  root_path = os.path.abspath('.')
@@ -17,6 +18,18 @@ sys.path.append(root_path)
17
  from test_code.inference import super_resolve_img
18
  from test_code.test_utils import load_grl, load_rrdb, load_dat
19
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  def auto_download_if_needed(weight_path):
22
  if os.path.exists(weight_path):
@@ -40,62 +53,254 @@ def auto_download_if_needed(weight_path):
40
  if weight_path == "pretrained/4x_APISR_DAT_GAN_generator.pth":
41
  os.system("wget https://github.com/Kiteretsu77/APISR/releases/download/v0.3.0/4x_APISR_DAT_GAN_generator.pth")
42
  os.system("mv 4x_APISR_DAT_GAN_generator.pth pretrained")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
 
44
 
45
 
46
- def inference(img_path, model_name):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
 
 
 
 
 
 
48
  try:
49
- weight_dtype = torch.float32
50
 
51
- # Load the model
52
- if model_name == "4xGRL":
53
- weight_path = "pretrained/4x_APISR_GRL_GAN_generator.pth"
54
- auto_download_if_needed(weight_path)
55
- generator = load_grl(weight_path, scale=4) # Directly use default way now
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
- elif model_name == "4xRRDB":
58
- weight_path = "pretrained/4x_APISR_RRDB_GAN_generator.pth"
59
- auto_download_if_needed(weight_path)
60
- generator = load_rrdb(weight_path, scale=4) # Directly use default way now
 
61
 
62
- elif model_name == "2xRRDB":
63
- weight_path = "pretrained/2x_APISR_RRDB_GAN_generator.pth"
64
- auto_download_if_needed(weight_path)
65
- generator = load_rrdb(weight_path, scale=2) # Directly use default way now
66
 
67
- elif model_name == "4xDAT":
68
- weight_path = "pretrained/4x_APISR_DAT_GAN_generator.pth"
69
- auto_download_if_needed(weight_path)
70
- generator = load_dat(weight_path, scale=4) # Directly use default way now
71
 
72
- else:
73
- raise gr.Error("We don't support such Model")
 
74
 
75
- generator = generator.to(dtype=weight_dtype)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
 
78
- print("We are processing ", img_path)
79
- print("The time now is ", datetime.datetime.now(pytz.timezone('US/Eastern')))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
- # In default, we will automatically use crop to match 4x size
82
- super_resolved_img = super_resolve_img(generator, img_path, output_path=None, weight_dtype=weight_dtype, downsample_threshold=720, crop_for_4x=True)
83
- store_name = str(time.time()) + ".png"
84
- save_image(super_resolved_img, store_name)
85
- outputs = cv2.imread(store_name)
86
- outputs = cv2.cvtColor(outputs, cv2.COLOR_RGB2BGR)
87
- os.remove(store_name)
88
-
89
- return outputs
90
 
 
 
 
91
 
92
- except Exception as error:
93
- raise gr.Error(f"global exception: {error}")
94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
 
97
  if __name__ == '__main__':
98
 
 
 
 
 
99
  MARKDOWN = \
100
  """
101
  ## <p style='text-align: center'> APISR: Anime Production Inspired Real-World Anime Super-Resolution (CVPR 2024) </p>
@@ -104,50 +309,104 @@ if __name__ == '__main__':
104
 
105
  APISR aims at restoring and enhancing low-quality low-resolution **anime** images and video sources with various degradations from real-world scenarios.
106
 
107
- ### Note: Due to memory restriction, all images whose short side is over 720 pixel will be downsampled to 720 pixel with the same aspect ratio. E.g., 1920x1080 -> 1280x720
108
- ### Note: Please check [Model Zoo](https://github.com/Kiteretsu77/APISR/blob/main/docs/model_zoo.md) for the description of each weight and [Here](https://imgsli.com/MjU0MjI0) for model comparisons.
109
-
110
- ### If APISR is helpful, please help star the [GitHub Repo](https://github.com/Kiteretsu77/APISR). Thanks! ###
111
  """
112
 
113
  block = gr.Blocks().queue(max_size=10)
114
  with block:
115
- with gr.Row():
116
- gr.Markdown(MARKDOWN)
117
- with gr.Row(elem_classes=["container"]):
118
- with gr.Column(scale=2):
119
- input_image = gr.Image(type="filepath", label="Input")
120
- model_name = gr.Dropdown(
121
- [
122
- "2xRRDB",
123
- "4xRRDB",
124
- "4xGRL",
125
- "4xDAT",
126
- ],
127
- type="value",
128
- value="4xGRL",
129
- label="model",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  )
131
- run_btn = gr.Button(value="Submit")
132
-
133
- with gr.Column(scale=3):
134
- output_image = gr.Image(type="numpy", label="Output image")
135
-
136
- with gr.Row(elem_classes=["container"]):
137
- gr.Examples(
138
- [
139
- ["__assets__/lr_inputs/image-00277.png"],
140
- ["__assets__/lr_inputs/image-00542.png"],
141
- ["__assets__/lr_inputs/41.png"],
142
- ["__assets__/lr_inputs/f91.jpg"],
143
- ["__assets__/lr_inputs/image-00440.png"],
144
- ["__assets__/lr_inputs/image-00164.jpg"],
145
- ["__assets__/lr_inputs/img_eva.jpeg"],
146
- ["__assets__/lr_inputs/naruto.jpg"],
147
- ],
148
- [input_image],
149
- )
150
-
151
- run_btn.click(inference, inputs=[input_image, model_name], outputs=[output_image])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
 
153
- block.launch()
 
 
 
 
1
  import os, sys
2
  import cv2
3
  import time
 
6
  import torch
7
  import numpy as np
8
  from torchvision.utils import save_image
9
+ import json
10
+ import threading
11
+ from queue import Queue
12
+ from pathlib import Path
13
+ import shutil
14
 
15
  # Import files from the local folder
16
  root_path = os.path.abspath('.')
 
18
  from test_code.inference import super_resolve_img
19
  from test_code.test_utils import load_grl, load_rrdb, load_dat
20
 
21
+ # Global configuration
22
+ OUTPUT_DIR = "outputs"
23
+ HISTORY_FILE = "history.json"
24
+ VIDEO_QUEUE_FILE = "video_queue.json"
25
+ video_queue = Queue()
26
+ processing_status = {}
27
+
28
+ # Initialize directories
29
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
30
+ os.makedirs(os.path.join(OUTPUT_DIR, "images"), exist_ok=True)
31
+ os.makedirs(os.path.join(OUTPUT_DIR, "videos"), exist_ok=True)
32
+
33
 
34
  def auto_download_if_needed(weight_path):
35
  if os.path.exists(weight_path):
 
53
  if weight_path == "pretrained/4x_APISR_DAT_GAN_generator.pth":
54
  os.system("wget https://github.com/Kiteretsu77/APISR/releases/download/v0.3.0/4x_APISR_DAT_GAN_generator.pth")
55
  os.system("mv 4x_APISR_DAT_GAN_generator.pth pretrained")
56
+
57
+
58
+ def load_history():
59
+ """Load processing history from JSON file"""
60
+ if os.path.exists(HISTORY_FILE):
61
+ with open(HISTORY_FILE, 'r') as f:
62
+ return json.load(f)
63
+ return []
64
+
65
+
66
+ def save_history(history):
67
+ """Save processing history to JSON file"""
68
+ with open(HISTORY_FILE, 'w') as f:
69
+ json.dump(history, f, indent=2)
70
+
71
+
72
+ def add_to_history(input_path, output_path, model_name, process_type, status="completed"):
73
+ """Add a record to history"""
74
+ history = load_history()
75
+ record = {
76
+ "timestamp": datetime.datetime.now().isoformat(),
77
+ "input_path": input_path,
78
+ "output_path": output_path,
79
+ "model_name": model_name,
80
+ "process_type": process_type,
81
+ "status": status
82
+ }
83
+ history.insert(0, record) # Add to beginning
84
+ save_history(history)
85
+
86
+
87
+ def load_generator(model_name):
88
+ """Load the appropriate model"""
89
+ if model_name == "4xGRL":
90
+ weight_path = "pretrained/4x_APISR_GRL_GAN_generator.pth"
91
+ auto_download_if_needed(weight_path)
92
+ generator = load_grl(weight_path, scale=4)
93
+
94
+ elif model_name == "4xRRDB":
95
+ weight_path = "pretrained/4x_APISR_RRDB_GAN_generator.pth"
96
+ auto_download_if_needed(weight_path)
97
+ generator = load_rrdb(weight_path, scale=4)
98
+
99
+ elif model_name == "2xRRDB":
100
+ weight_path = "pretrained/2x_APISR_RRDB_GAN_generator.pth"
101
+ auto_download_if_needed(weight_path)
102
+ generator = load_rrdb(weight_path, scale=2)
103
+
104
+ elif model_name == "4xDAT":
105
+ weight_path = "pretrained/4x_APISR_DAT_GAN_generator.pth"
106
+ auto_download_if_needed(weight_path)
107
+ generator = load_dat(weight_path, scale=4)
108
+ else:
109
+ raise ValueError(f"Model {model_name} not supported")
110
 
111
+ return generator.to(device='cpu')
112
 
113
 
114
+ def inference_image(img_path, model_name):
115
+ """Process a single image"""
116
+ try:
117
+ generator = load_generator(model_name)
118
+
119
+ print("Processing image:", img_path)
120
+ print("Time:", datetime.datetime.now(pytz.timezone('US/Eastern')))
121
+
122
+ # Process image
123
+ super_resolved_img = super_resolve_img(
124
+ generator, img_path, output_path=None,
125
+ downsample_threshold=720, crop_for_4x=True
126
+ )
127
+
128
+ # Save output
129
+ timestamp = int(time.time() * 1000)
130
+ output_name = f"image_{timestamp}.png"
131
+ output_path = os.path.join(OUTPUT_DIR, "images", output_name)
132
+ save_image(super_resolved_img, output_path)
133
+
134
+ # Load and convert for display
135
+ outputs = cv2.imread(output_path)
136
+ outputs = cv2.cvtColor(outputs, cv2.COLOR_BGR2RGB)
137
+
138
+ # Add to history
139
+ add_to_history(img_path, output_path, model_name, "image")
140
+
141
+ return outputs, f"βœ… Saved to: {output_path}"
142
 
143
+ except Exception as error:
144
+ raise gr.Error(f"Error processing image: {error}")
145
+
146
+
147
+ def process_video_frame_by_frame(video_path, model_name, task_id):
148
+ """Process video frame by frame"""
149
  try:
150
+ processing_status[task_id] = {"status": "processing", "progress": 0}
151
 
152
+ # Load model
153
+ generator = load_generator(model_name)
154
+
155
+ # Open video
156
+ cap = cv2.VideoCapture(video_path)
157
+ if not cap.isOpened():
158
+ raise ValueError("Cannot open video file")
159
+
160
+ # Get video properties
161
+ fps = cap.get(cv2.CAP_PROP_FPS)
162
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
163
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
164
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
165
+
166
+ # Prepare output
167
+ timestamp = int(time.time() * 1000)
168
+ output_name = f"video_{timestamp}.mp4"
169
+ output_path = os.path.join(OUTPUT_DIR, "videos", output_name)
170
+
171
+ # Create temporary directory for frames
172
+ temp_dir = f"temp_frames_{timestamp}"
173
+ os.makedirs(temp_dir, exist_ok=True)
174
+
175
+ # Process frames
176
+ frame_count = 0
177
+ while True:
178
+ ret, frame = cap.read()
179
+ if not ret:
180
+ break
181
+
182
+ # Save frame temporarily
183
+ temp_frame_path = os.path.join(temp_dir, f"frame_{frame_count:06d}.png")
184
+ cv2.imwrite(temp_frame_path, frame)
185
 
186
+ # Super resolve frame
187
+ super_resolved_img = super_resolve_img(
188
+ generator, temp_frame_path, output_path=None,
189
+ downsample_threshold=720, crop_for_4x=True
190
+ )
191
 
192
+ # Save processed frame
193
+ output_frame_path = os.path.join(temp_dir, f"output_{frame_count:06d}.png")
194
+ save_image(super_resolved_img, output_frame_path)
 
195
 
196
+ frame_count += 1
197
+ progress = int((frame_count / total_frames) * 100)
198
+ processing_status[task_id] = {"status": "processing", "progress": progress}
 
199
 
200
+ print(f"Task {task_id}: Processed frame {frame_count}/{total_frames} ({progress}%)")
201
+
202
+ cap.release()
203
 
204
+ # Combine frames into video using ffmpeg
205
+ print(f"Task {task_id}: Combining frames into video...")
206
+ processing_status[task_id] = {"status": "encoding", "progress": 100}
207
+
208
+ os.system(f"ffmpeg -framerate {fps} -i {temp_dir}/output_%06d.png -c:v libx264 -pix_fmt yuv420p {output_path}")
209
+
210
+ # Clean up
211
+ shutil.rmtree(temp_dir)
212
+
213
+ processing_status[task_id] = {"status": "completed", "progress": 100, "output": output_path}
214
+ add_to_history(video_path, output_path, model_name, "video")
215
+
216
+ print(f"Task {task_id}: Completed! Output: {output_path}")
217
+
218
+ except Exception as error:
219
+ processing_status[task_id] = {"status": "error", "error": str(error)}
220
+ print(f"Task {task_id}: Error - {error}")
221
 
222
 
223
+ def video_queue_worker():
224
+ """Background worker to process video queue"""
225
+ print("Video queue worker started...")
226
+ while True:
227
+ try:
228
+ task = video_queue.get()
229
+ if task is None: # Poison pill to stop worker
230
+ break
231
+
232
+ task_id, video_path, model_name = task
233
+ print(f"Starting task {task_id}...")
234
+ process_video_frame_by_frame(video_path, model_name, task_id)
235
+
236
+ except Exception as e:
237
+ print(f"Worker error: {e}")
238
+ finally:
239
+ video_queue.task_done()
240
+
241
 
242
+ def submit_video(video_path, model_name):
243
+ """Submit video to processing queue"""
244
+ if video_path is None:
245
+ return None, "❌ Please upload a video first"
 
 
 
 
 
246
 
247
+ task_id = f"task_{int(time.time() * 1000)}"
248
+ video_queue.put((task_id, video_path, model_name))
249
+ processing_status[task_id] = {"status": "queued", "progress": 0}
250
 
251
+ return None, f"βœ… Video submitted to queue! Task ID: {task_id}\nCheck status in the monitoring section."
252
+
253
 
254
+ def get_queue_status():
255
+ """Get current queue status"""
256
+ status_text = "πŸ“Š **Queue Status**\n\n"
257
+ status_text += f"Videos in queue: {video_queue.qsize()}\n\n"
258
+
259
+ if processing_status:
260
+ status_text += "**Active Tasks:**\n"
261
+ for task_id, status in processing_status.items():
262
+ status_text += f"\n🎬 {task_id}:\n"
263
+ status_text += f" Status: {status['status']}\n"
264
+ status_text += f" Progress: {status.get('progress', 0)}%\n"
265
+ if 'output' in status:
266
+ status_text += f" Output: {status['output']}\n"
267
+ if 'error' in status:
268
+ status_text += f" Error: {status['error']}\n"
269
+ else:
270
+ status_text += "No active tasks"
271
+
272
+ return status_text
273
+
274
+
275
+ def get_history_display():
276
+ """Get formatted history for display"""
277
+ history = load_history()
278
+ if not history:
279
+ return "No history available"
280
+
281
+ history_text = "πŸ“œ **Processing History**\n\n"
282
+ for idx, record in enumerate(history[:50]): # Show last 50
283
+ history_text += f"**{idx + 1}. {record['process_type'].upper()}** - {record['timestamp']}\n"
284
+ history_text += f" Model: {record['model_name']}\n"
285
+ history_text += f" Status: {record['status']}\n"
286
+ history_text += f" Output: {record['output_path']}\n\n"
287
+
288
+ return history_text
289
+
290
+
291
+ def clear_history():
292
+ """Clear all history"""
293
+ if os.path.exists(HISTORY_FILE):
294
+ os.remove(HISTORY_FILE)
295
+ return "βœ… History cleared!"
296
 
297
 
298
  if __name__ == '__main__':
299
 
300
+ # Start background worker thread
301
+ worker_thread = threading.Thread(target=video_queue_worker, daemon=True)
302
+ worker_thread.start()
303
+
304
  MARKDOWN = \
305
  """
306
  ## <p style='text-align: center'> APISR: Anime Production Inspired Real-World Anime Super-Resolution (CVPR 2024) </p>
 
309
 
310
  APISR aims at restoring and enhancing low-quality low-resolution **anime** images and video sources with various degradations from real-world scenarios.
311
 
312
+ ### ⚠️ Note: Images with short side > 720px will be downsampled to 720px (e.g., 1920x1080 → 1280x720)
313
+ ### 📹 New: Video processing runs in background queue - you can close the browser and it continues!
 
 
314
  """
315
 
316
  block = gr.Blocks().queue(max_size=10)
317
  with block:
318
+ gr.Markdown(MARKDOWN)
319
+
320
+ with gr.Tabs():
321
+ # Tab 1: Image Processing
322
+ with gr.Tab("πŸ–ΌοΈ Image Processing"):
323
+ with gr.Row():
324
+ with gr.Column(scale=2):
325
+ input_image = gr.Image(type="filepath", label="Input Image")
326
+ image_model = gr.Dropdown(
327
+ ["2xRRDB", "4xRRDB", "4xGRL", "4xDAT"],
328
+ value="4xGRL",
329
+ label="Model"
330
+ )
331
+ image_btn = gr.Button("πŸš€ Process Image", variant="primary")
332
+
333
+ with gr.Column(scale=3):
334
+ output_image = gr.Image(type="numpy", label="Output Image")
335
+ image_status = gr.Textbox(label="Status", lines=2)
336
+
337
+ with gr.Row():
338
+ gr.Examples(
339
+ [
340
+ ["__assets__/lr_inputs/image-00277.png"],
341
+ ["__assets__/lr_inputs/image-00542.png"],
342
+ ["__assets__/lr_inputs/41.png"],
343
+ ["__assets__/lr_inputs/f91.jpg"],
344
+ ],
345
+ [input_image],
346
+ )
347
+
348
+ image_btn.click(
349
+ inference_image,
350
+ inputs=[input_image, image_model],
351
+ outputs=[output_image, image_status]
352
  )
353
+
354
+ # Tab 2: Video Processing
355
+ with gr.Tab("🎬 Video Processing"):
356
+ gr.Markdown("""
357
+ ### Video Processing Queue
358
+ Videos are processed in the background. You can submit multiple videos and close the browser - processing continues!
359
+ """)
360
+
361
+ with gr.Row():
362
+ with gr.Column():
363
+ input_video = gr.Video(label="Input Video")
364
+ video_model = gr.Dropdown(
365
+ ["2xRRDB", "4xRRDB", "4xGRL", "4xDAT"],
366
+ value="4xGRL",
367
+ label="Model"
368
+ )
369
+ video_btn = gr.Button("πŸ“€ Submit to Queue", variant="primary")
370
+ video_status = gr.Textbox(label="Submission Status", lines=3)
371
+
372
+ with gr.Column():
373
+ gr.Markdown("### πŸ“Š Queue Monitor")
374
+ queue_status = gr.Textbox(label="Queue Status", lines=15)
375
+ refresh_btn = gr.Button("πŸ”„ Refresh Status")
376
+
377
+ video_btn.click(
378
+ submit_video,
379
+ inputs=[input_video, video_model],
380
+ outputs=[input_video, video_status]
381
+ )
382
+
383
+ refresh_btn.click(
384
+ get_queue_status,
385
+ outputs=[queue_status]
386
+ )
387
+
388
+ # Tab 3: History
389
+ with gr.Tab("πŸ“œ History"):
390
+ gr.Markdown("### Processing History")
391
+
392
+ with gr.Row():
393
+ refresh_history_btn = gr.Button("πŸ”„ Refresh History")
394
+ clear_history_btn = gr.Button("πŸ—‘οΈ Clear History", variant="stop")
395
+
396
+ history_display = gr.Textbox(label="History", lines=20)
397
+ clear_status = gr.Textbox(label="Status", lines=1)
398
+
399
+ refresh_history_btn.click(
400
+ get_history_display,
401
+ outputs=[history_display]
402
+ )
403
+
404
+ clear_history_btn.click(
405
+ clear_history,
406
+ outputs=[clear_status]
407
+ )
408
+
409
+ # Auto-load history on tab open
410
+ block.load(get_history_display, outputs=[history_display])
411
 
412
+ block.launch(server_name="0.0.0.0", server_port=7860)