b2bomber commited on
Commit
056bc4c
Β·
verified Β·
1 Parent(s): b813279

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +138 -108
app.py CHANGED
@@ -1,127 +1,157 @@
1
- import os
2
  import cv2
3
  import numpy as np
 
4
  import tempfile
5
- from tqdm import tqdm
6
- import gradio as gr
7
-
8
  from basicsr.archs.rrdbnet_arch import RRDBNet
9
- from basicsr.utils.download_util import load_file_from_url
10
  from realesrgan import RealESRGANer
11
  from realesrgan.archs.srvgg_arch import SRVGGNetCompact
12
  from gfpgan import GFPGANer
 
 
 
 
 
13
 
 
 
 
 
14
 
15
- # Load models
16
- def load_model(model_name, denoise_strength=1.0):
17
  if model_name == 'RealESRGAN_x4plus_anime_6B':
18
- model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
19
- num_block=6, num_grow_ch=32, scale=4)
20
  netscale = 4
21
- file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth']
22
- elif model_name == 'realesr-general-x4v3':
23
- model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64,
24
- num_conv=32, upscale=4, act_type='prelu')
25
  netscale = 4
26
- file_url = [
27
- 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth',
28
- 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'
29
- ]
30
-
31
- model_path = os.path.join('weights', model_name + '.pth')
32
- os.makedirs('weights', exist_ok=True)
33
-
34
- if not os.path.isfile(model_path):
35
- for url in file_url:
36
- model_path = load_file_from_url(url=url, model_dir='weights', progress=True)
37
-
38
- dni_weight = None
39
- if model_name == 'realesr-general-x4v3' and denoise_strength != 1:
40
- model_path = [
41
- os.path.join('weights', 'realesr-general-x4v3.pth'),
42
- os.path.join('weights', 'realesr-general-wdn-x4v3.pth')
43
- ]
44
- dni_weight = [denoise_strength, 1 - denoise_strength]
45
 
 
 
46
  upsampler = RealESRGANer(
47
- scale=netscale,
48
- model_path=model_path,
49
- dni_weight=dni_weight,
50
- model=model,
51
- tile=128,
52
- tile_pad=10,
53
- pre_pad=10,
54
- half=False,
55
- gpu_id=None
56
  )
57
-
58
  return upsampler
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
- def enhance_video(video_path, model_name, denoise_strength, face_enhance, outscale):
62
- upsampler = load_model(model_name, denoise_strength)
63
-
64
- if face_enhance:
65
- face_enhancer = GFPGANer(
66
- model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
67
- upscale=outscale,
68
- arch='clean',
69
- channel_multiplier=2,
70
- bg_upsampler=upsampler
71
- )
72
-
73
- cap = cv2.VideoCapture(video_path)
74
- fps = cap.get(cv2.CAP_PROP_FPS)
75
- total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
76
- w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
77
- h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
78
-
79
- temp_out = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
80
- out_path = temp_out.name
81
-
82
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
83
- writer = cv2.VideoWriter(out_path, fourcc, fps, (w * outscale, h * outscale))
84
-
85
- for _ in tqdm(range(total_frames), desc="Enhancing video"):
86
- success, frame = cap.read()
87
- if not success:
88
- break
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
- try:
91
- if face_enhance:
92
- _, _, enhanced = face_enhancer.enhance(frame, has_aligned=False, only_center_face=False, paste_back=True)
93
- else:
94
- enhanced, _ = upsampler.enhance(frame, outscale=outscale)
95
- writer.write(enhanced)
96
- except RuntimeError as e:
97
- print("Runtime error:", e)
98
- continue
99
-
100
- cap.release()
101
- writer.release()
102
-
103
- return out_path
104
-
105
-
106
- def gradio_interface(video, model_name, denoise_strength, face_enhance, outscale):
107
- if video is None:
108
- return None
109
- return enhance_video(video, model_name, denoise_strength, face_enhance, outscale)
110
-
111
-
112
- demo = gr.Interface(
113
- fn=gradio_interface,
114
- inputs=[
115
- gr.Video(label="Upload a short video (<30s)"),
116
- gr.Dropdown(["realesr-general-x4v3", "RealESRGAN_x4plus_anime_6B"], label="Model", value="realesr-general-x4v3"),
117
- gr.Slider(0, 1, step=0.1, value=1.0, label="Denoise Strength"),
118
- gr.Checkbox(label="Enable Face Enhancement (GFPGAN)", value=False),
119
- gr.Slider(1, 4, step=1, value=2, label="Upscale Factor")
120
- ],
121
- outputs=gr.Video(label="Enhanced Video Output"),
122
- title="🎬 AI Video Enhancer",
123
- description="Upscale your videos with Real-ESRGAN and optional face enhancement using GFPGAN. Optimized for Hugging Face CPU Spaces."
124
- )
125
-
126
- if __name__ == "__main__":
127
- demo.launch()
 
1
+ import gradio as gr
2
  import cv2
3
  import numpy as np
4
+ import os
5
  import tempfile
6
+ import subprocess
 
 
7
  from basicsr.archs.rrdbnet_arch import RRDBNet
 
8
  from realesrgan import RealESRGANer
9
  from realesrgan.archs.srvgg_arch import SRVGGNetCompact
10
  from gfpgan import GFPGANer
11
+ from basicsr.utils.download_util import load_file_from_url
12
+
13
+ # --- Model Loading ---
14
+ # We create a dictionary to cache models so they are only loaded once.
15
+ model_cache = {}
16
 
17
+ def get_upsampler(model_name='realesr-general-x4v3', denoise_strength=1):
18
+ """Loads and returns the specified RealESRGAN model."""
19
+ if model_name in model_cache:
20
+ return model_cache[model_name]
21
 
 
 
22
  if model_name == 'RealESRGAN_x4plus_anime_6B':
23
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
 
24
  netscale = 4
25
+ file_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth'
26
+ else: # realesr-general-x4v3
27
+ model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
 
28
  netscale = 4
29
+ file_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
+ model_path = load_file_from_url(url=file_url, model_dir='weights', progress=True)
32
+
33
  upsampler = RealESRGANer(
34
+ scale=netscale, model_path=model_path, model=model,
35
+ tile=64, tile_pad=10, pre_pad=10, half=False, gpu_id=None
 
 
 
 
 
 
 
36
  )
37
+ model_cache[model_name] = upsampler
38
  return upsampler
39
 
40
+ def get_face_enhancer(upsampler, outscale):
41
+ """Loads and returns the GFPGAN face enhancer."""
42
+ key = 'face_enhancer'
43
+ if key in model_cache:
44
+ return model_cache[key]
45
+
46
+ face_enhancer = GFPGANer(
47
+ model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
48
+ upscale=outscale, arch='clean', channel_multiplier=2, bg_upsampler=upsampler
49
+ )
50
+ model_cache[key] = face_enhancer
51
+ return face_enhancer
52
+
53
+ # --- Core Video Processing Function ---
54
+ def enhance_video(video_path, model_name, denoise_strength, outscale, face_enhance, progress=gr.Progress(track_tqdm=True)):
55
+ """Enhances a video frame by frame and provides progress updates."""
56
+ if not video_path:
57
+ raise gr.Error("Please upload a video to enhance.")
58
+
59
+ try:
60
+ upsampler = get_upsampler(model_name, denoise_strength)
61
+
62
+ face_enhancer = None
63
+ if face_enhance:
64
+ face_enhancer = get_face_enhancer(upsampler, outscale)
65
+
66
+ cap = cv2.VideoCapture(video_path)
67
+ fps = cap.get(cv2.CAP_PROP_FPS)
68
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
69
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
70
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
71
+
72
+ # Prepare output video writer
73
+ temp_dir = tempfile.mkdtemp()
74
+ enhanced_video_path = os.path.join(temp_dir, "enhanced_video.mp4")
75
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
76
+ writer = cv2.VideoWriter(enhanced_video_path, fourcc, fps, (width * outscale, height * outscale))
77
+
78
+ # Process each frame
79
+ for i in progress.tqdm(range(total_frames), desc="Enhancing Frames..."):
80
+ ret, frame = cap.read()
81
+ if not ret:
82
+ break
83
+
84
+ if face_enhancer:
85
+ _, _, enhanced_frame = face_enhancer.enhance(frame, has_aligned=False, only_center_face=False, paste_back=True)
86
+ else:
87
+ enhanced_frame, _ = upsampler.enhance(frame, outscale=outscale)
88
+
89
+ writer.write(enhanced_frame)
90
+
91
+ cap.release()
92
+ writer.release()
93
+
94
+ # Merge audio back into the enhanced video
95
+ final_output_path = os.path.join(temp_dir, "final_output_with_audio.mp4")
96
+ audio_merge_cmd = f"ffmpeg -y -i {enhanced_video_path} -i {video_path} -c:v libx264 -crf 23 -preset fast -c:a aac -b:a 128k -map 0:v:0 -map 1:a:0 -shortest {final_output_path}"
97
+ subprocess.call(audio_merge_cmd, shell=True, stderr=subprocess.DEVNULL, stdout=subprocess.DEVNULL)
98
+
99
+ return final_output_path
100
+
101
+ except Exception as e:
102
+ print(traceback.format_exc())
103
+ raise gr.Error(f"An error occurred: {e}")
104
+
105
+ # --- Gradio UI with Modern Design ---
106
+ with gr.Blocks(theme=gr.themes.Soft(primary_hue="violet"), title="πŸŽ₯ AI Video Enhancer") as demo:
107
+ gr.Markdown(
108
+ """
109
+ # πŸŽ₯ AI Video Enhancer & Upscaler
110
+ Improve video quality, upscale resolution, and restore faces with cutting-edge AI.
111
+ **Note:** Processing can be slow, especially for longer videos.
112
+ """
113
+ )
114
 
115
+ with gr.Row(variant="panel"):
116
+ # --- Input Column ---
117
+ with gr.Column(scale=1):
118
+ video_input = gr.Video(label="🎬 Upload Your Video")
119
+ gr.Examples(
120
+ examples=["sample_video.mp4"], # Add path to your example video
121
+ inputs=[video_input],
122
+ label="Click an example to start"
123
+ )
124
+
125
+ # Settings in a clean Accordion
126
+ with gr.Accordion("βš™οΈ Enhancement Options", open=True):
127
+ model_name = gr.Dropdown(
128
+ choices=["realesr-general-x4v3", "RealESRGAN_x4plus_anime_6B"],
129
+ value="realesr-general-x4v3",
130
+ label="Model Type"
131
+ )
132
+ outscale = gr.Slider(1, 4, value=2, step=1, label="Upscale Factor")
133
+ face_enhance = gr.Checkbox(label="Restore Faces (GFPGAN)")
134
+ denoise_strength = gr.Slider(0, 1, value=0.5, step=0.1, label="Denoise Strength (for general model only)")
135
+
136
+ enhance_btn = gr.Button("✨ Enhance Video", variant="primary")
137
+
138
+ # --- Output Column ---
139
+ with gr.Column(scale=1):
140
+ video_output = gr.Video(label="🌟 Enhanced Result")
141
+ download_file = gr.File(label="⬇️ Download Enhanced Video", visible=False)
142
+
143
+ # --- Event Logic ---
144
+ def on_submit(video, model, denoise, scale, face):
145
+ # When the button is clicked, start the enhancement and return the path to the output video.
146
+ # Also, make the download button visible.
147
+ output_path = enhance_video(video, model, denoise, scale, face)
148
+ return output_path, gr.update(value=output_path, visible=True)
149
+
150
+ enhance_btn.click(
151
+ fn=on_submit,
152
+ inputs=[video_input, model_name, denoise_strength, outscale, face_enhance],
153
+ outputs=[video_output, download_file]
154
+ )
155
 
156
+ if __name__ == '__main__':
157
+ demo.launch()