b2bomber committed on
Commit
19467f5
·
verified ·
1 Parent(s): f114dc6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -61
app.py CHANGED
@@ -11,15 +11,12 @@ from realesrgan.archs.srvgg_arch import SRVGGNetCompact
11
  from gfpgan import GFPGANer
12
  from basicsr.utils.download_util import load_file_from_url
13
 
14
- # --- Model Loading ---
15
- # We create a dictionary to cache models so they are only loaded once.
16
  model_cache = {}
17
 
18
  def get_upsampler(model_name='realesr-general-x4v3'):
19
- """Loads and returns the specified RealESRGAN model."""
20
  if model_name in model_cache:
21
  return model_cache[model_name]
22
-
23
  if model_name == 'RealESRGAN_x4plus_anime_6B':
24
  model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
25
  netscale = 4
@@ -28,22 +25,18 @@ def get_upsampler(model_name='realesr-general-x4v3'):
28
  model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
29
  netscale = 4
30
  file_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'
31
-
32
  model_path = load_file_from_url(url=file_url, model_dir='weights', progress=True)
33
-
34
  upsampler = RealESRGANer(
35
  scale=netscale, model_path=model_path, model=model,
36
- tile=64, tile_pad=10, pre_pad=10, half=True, gpu_id=None # Use half precision for speed
37
  )
38
  model_cache[model_name] = upsampler
39
  return upsampler
40
 
41
  def get_face_enhancer(upsampler, outscale):
42
- """Loads and returns the GFPGAN face enhancer."""
43
  key = f'face_enhancer_{outscale}'
44
  if key in model_cache:
45
  return model_cache[key]
46
-
47
  face_enhancer = GFPGANer(
48
  model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
49
  upscale=outscale, arch='clean', channel_multiplier=2, bg_upsampler=upsampler
@@ -51,102 +44,76 @@ def get_face_enhancer(upsampler, outscale):
51
  model_cache[key] = face_enhancer
52
  return face_enhancer
53
 
54
- # --- Core Video Processing Function ---
55
  def enhance_video(video_path, model_name, outscale, face_enhance, progress=gr.Progress(track_tqdm=True)):
56
- """Enhances a video frame by frame and provides progress updates."""
57
  if not video_path:
58
  raise gr.Error("Please upload a video to enhance.")
59
-
60
  try:
61
  upsampler = get_upsampler(model_name)
62
-
63
  face_enhancer = None
64
  if face_enhance:
65
  face_enhancer = get_face_enhancer(upsampler, outscale)
66
-
67
  cap = cv2.VideoCapture(video_path)
68
- fps = cap.get(cv.CAP_PROP_FPS)
69
- width = int(cap.get(cv.CAP_PROP_FRAME_WIDTH))
70
- height = int(cap.get(cv.CAP_PROP_FRAME_HEIGHT))
71
- total_frames = int(cap.get(cv.CAP_PROP_FRAME_COUNT))
72
-
73
- # Prepare output video writer
74
  temp_dir = tempfile.mkdtemp()
75
  enhanced_video_path = os.path.join(temp_dir, "enhanced_video.mp4")
76
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
77
  writer = cv2.VideoWriter(enhanced_video_path, fourcc, fps, (width * outscale, height * outscale))
78
-
79
- # Process each frame
80
  for _ in progress.tqdm(range(total_frames), desc="Enhancing Frames..."):
81
  ret, frame = cap.read()
82
- if not ret:
83
- break
84
-
85
  if face_enhancer:
86
  _, _, enhanced_frame = face_enhancer.enhance(frame, has_aligned=False, only_center_face=False, paste_back=True)
87
  else:
88
  enhanced_frame, _ = upsampler.enhance(frame, outscale=outscale)
89
-
90
  writer.write(enhanced_frame)
91
-
92
  cap.release()
93
  writer.release()
94
-
95
- # Merge audio back into the enhanced video
96
  final_output_path = os.path.join(temp_dir, "final_output_with_audio.mp4")
97
  audio_merge_cmd = f'ffmpeg -y -i "{enhanced_video_path}" -i "{video_path}" -c:v libx264 -crf 23 -preset fast -c:a aac -b:a 128k -map 0:v:0 -map 1:a:0 -shortest "{final_output_path}"'
98
  subprocess.call(audio_merge_cmd, shell=True, stderr=subprocess.DEVNULL, stdout=subprocess.DEVNULL)
99
-
100
  return final_output_path
101
-
102
  except Exception as e:
103
  print(traceback.format_exc())
104
  raise gr.Error(f"An error occurred: {e}")
105
 
106
- # --- Gradio UI with Modern Design ---
107
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="violet"), title="πŸŽ₯ AI Video Enhancer") as demo:
108
  gr.Markdown(
109
  """
110
- # πŸŽ₯ AI Video Enhancer & Upscaler
111
- Improve video quality, upscale resolution, and restore faces with cutting-edge AI.
112
- **Note:** Processing can be slow, especially for longer videos.
113
  """
114
  )
115
-
116
- # Top row for video previews
117
- with gr.Row():
118
- video_input = gr.Video(label="🎬 Original Video")
119
- video_output = gr.Video(label="🌟 Enhanced Result")
120
-
121
- # Panel for all settings and actions
122
- with gr.Box():
123
- with gr.Row():
124
- # Left side for main settings
125
- with gr.Column(scale=3):
126
  model_name = gr.Dropdown(
127
  choices=["realesr-general-x4v3", "RealESRGAN_x4plus_anime_6B"],
128
  value="realesr-general-x4v3",
129
  label="Model Type (General or Anime)"
130
  )
131
  outscale = gr.Slider(1, 4, value=2, step=1, label="Upscale Factor")
 
 
 
 
132
 
133
- # Right side for the most important actions
134
- with gr.Column(scale=1, min_width=200):
135
- face_enhance = gr.Checkbox(label="✨ Restore Faces (GFPGAN)", value=False)
136
- enhance_btn = gr.Button("πŸš€ Enhance Video", variant="primary")
137
-
138
- # Examples and Download components
139
- gr.Examples(
140
- examples=["sample_video.mp4"], # Add path to your example video
141
- inputs=[video_input],
142
- label="Click an example to start"
143
- )
144
- download_file = gr.File(label="⬇️ Download Enhanced Video", visible=False)
145
 
146
- # --- Event Logic ---
147
  def on_submit(video, model, scale, face):
148
- # When the button is clicked, start the enhancement and return the path to the output video.
149
- # Also, make the download button visible.
150
  output_path = enhance_video(video, model, scale, face)
151
  return output_path, gr.update(value=output_path, visible=True)
152
 
 
11
  from gfpgan import GFPGANer
12
  from basicsr.utils.download_util import load_file_from_url
13
 
14
+ # --- Model Loading (Unchanged) ---
 
15
  model_cache = {}
16
 
17
  def get_upsampler(model_name='realesr-general-x4v3'):
 
18
  if model_name in model_cache:
19
  return model_cache[model_name]
 
20
  if model_name == 'RealESRGAN_x4plus_anime_6B':
21
  model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
22
  netscale = 4
 
25
  model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
26
  netscale = 4
27
  file_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'
 
28
  model_path = load_file_from_url(url=file_url, model_dir='weights', progress=True)
 
29
  upsampler = RealESRGANer(
30
  scale=netscale, model_path=model_path, model=model,
31
+ tile=64, tile_pad=10, pre_pad=10, half=True, gpu_id=None
32
  )
33
  model_cache[model_name] = upsampler
34
  return upsampler
35
 
36
  def get_face_enhancer(upsampler, outscale):
 
37
  key = f'face_enhancer_{outscale}'
38
  if key in model_cache:
39
  return model_cache[key]
 
40
  face_enhancer = GFPGANer(
41
  model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
42
  upscale=outscale, arch='clean', channel_multiplier=2, bg_upsampler=upsampler
 
44
  model_cache[key] = face_enhancer
45
  return face_enhancer
46
 
47
+ # --- Core Video Processing Function (Unchanged) ---
48
  def enhance_video(video_path, model_name, outscale, face_enhance, progress=gr.Progress(track_tqdm=True)):
 
49
  if not video_path:
50
  raise gr.Error("Please upload a video to enhance.")
 
51
  try:
52
  upsampler = get_upsampler(model_name)
 
53
  face_enhancer = None
54
  if face_enhance:
55
  face_enhancer = get_face_enhancer(upsampler, outscale)
 
56
  cap = cv2.VideoCapture(video_path)
57
+ fps = cap.get(cv2.CAP_PROP_FPS)
58
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
59
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
60
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
 
 
61
  temp_dir = tempfile.mkdtemp()
62
  enhanced_video_path = os.path.join(temp_dir, "enhanced_video.mp4")
63
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
64
  writer = cv2.VideoWriter(enhanced_video_path, fourcc, fps, (width * outscale, height * outscale))
 
 
65
  for _ in progress.tqdm(range(total_frames), desc="Enhancing Frames..."):
66
  ret, frame = cap.read()
67
+ if not ret: break
 
 
68
  if face_enhancer:
69
  _, _, enhanced_frame = face_enhancer.enhance(frame, has_aligned=False, only_center_face=False, paste_back=True)
70
  else:
71
  enhanced_frame, _ = upsampler.enhance(frame, outscale=outscale)
 
72
  writer.write(enhanced_frame)
 
73
  cap.release()
74
  writer.release()
 
 
75
  final_output_path = os.path.join(temp_dir, "final_output_with_audio.mp4")
76
  audio_merge_cmd = f'ffmpeg -y -i "{enhanced_video_path}" -i "{video_path}" -c:v libx264 -crf 23 -preset fast -c:a aac -b:a 128k -map 0:v:0 -map 1:a:0 -shortest "{final_output_path}"'
77
  subprocess.call(audio_merge_cmd, shell=True, stderr=subprocess.DEVNULL, stdout=subprocess.DEVNULL)
 
78
  return final_output_path
 
79
  except Exception as e:
80
  print(traceback.format_exc())
81
  raise gr.Error(f"An error occurred: {e}")
82
 
83
+ # --- Gradio UI with Corrected Layout ---
84
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="violet"), title="πŸŽ₯ AI Video Enhancer") as demo:
85
  gr.Markdown(
86
  """
87
+ Improve video quality, upscale resolution, and restore faces with cutting-edge AI.
 
 
88
  """
89
  )
90
+
91
+ # Main two-column layout
92
+ with gr.Row(variant="panel"):
93
+ # --- Input Column on the Left ---
94
+ with gr.Column(scale=1):
95
+ video_input = gr.Video(label="🎬 Upload Your Video")
96
+
97
+ # Accordion for less frequently used settings
98
+ with gr.Accordion("Advanced Options", open=False):
 
 
99
  model_name = gr.Dropdown(
100
  choices=["realesr-general-x4v3", "RealESRGAN_x4plus_anime_6B"],
101
  value="realesr-general-x4v3",
102
  label="Model Type (General or Anime)"
103
  )
104
  outscale = gr.Slider(1, 4, value=2, step=1, label="Upscale Factor")
105
+
106
+ # --- Output Column on the Right ---
107
+ with gr.Column(scale=1):
108
+ video_output = gr.Video(label="🌟 Enhanced Result")
109
 
110
+ # ✅ FIX: Controls are now placed directly under the output video
111
+ face_enhance = gr.Checkbox(label="✨ Restore Faces (GFPGAN)", value=False, elem_id="face-enhance-checkbox")
112
+ enhance_btn = gr.Button("πŸš€ Enhance Video", variant="primary")
113
+ download_file = gr.File(label="⬇️ Download Enhanced Video", visible=False)
 
 
 
 
 
 
 
 
114
 
115
+ # --- Event Logic (Unchanged) ---
116
  def on_submit(video, model, scale, face):
 
 
117
  output_path = enhance_video(video, model, scale, face)
118
  return output_path, gr.update(value=output_path, visible=True)
119