b2bomber commited on
Commit
998efdd
·
verified ·
1 Parent(s): 56fa43e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +127 -0
app.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import tempfile
5
+ from tqdm import tqdm
6
+ import gradio as gr
7
+
8
+ from basicsr.archs.rrdbnet_arch import RRDBNet
9
+ from basicsr.utils.download_util import load_file_from_url
10
+ from realesrgan import RealESRGANer
11
+ from realesrgan.archs.srvgg_arch import SRVGGNetCompact
12
+ from gfpgan import GFPGANer
13
+
14
+
15
+ # Load models
16
+ def load_model(model_name, denoise_strength=1.0):
17
+ if model_name == 'RealESRGAN_x4plus_anime_6B':
18
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
19
+ num_block=6, num_grow_ch=32, scale=4)
20
+ netscale = 4
21
+ file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth']
22
+ elif model_name == 'realesr-general-x4v3':
23
+ model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64,
24
+ num_conv=32, upscale=4, act_type='prelu')
25
+ netscale = 4
26
+ file_url = [
27
+ 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth',
28
+ 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'
29
+ ]
30
+
31
+ model_path = os.path.join('weights', model_name + '.pth')
32
+ os.makedirs('weights', exist_ok=True)
33
+
34
+ if not os.path.isfile(model_path):
35
+ for url in file_url:
36
+ model_path = load_file_from_url(url=url, model_dir='weights', progress=True)
37
+
38
+ dni_weight = None
39
+ if model_name == 'realesr-general-x4v3' and denoise_strength != 1:
40
+ model_path = [
41
+ os.path.join('weights', 'realesr-general-x4v3.pth'),
42
+ os.path.join('weights', 'realesr-general-wdn-x4v3.pth')
43
+ ]
44
+ dni_weight = [denoise_strength, 1 - denoise_strength]
45
+
46
+ upsampler = RealESRGANer(
47
+ scale=netscale,
48
+ model_path=model_path,
49
+ dni_weight=dni_weight,
50
+ model=model,
51
+ tile=128,
52
+ tile_pad=10,
53
+ pre_pad=10,
54
+ half=False,
55
+ gpu_id=None
56
+ )
57
+
58
+ return upsampler
59
+
60
+
61
+ def enhance_video(video_path, model_name, denoise_strength, face_enhance, outscale):
62
+ upsampler = load_model(model_name, denoise_strength)
63
+
64
+ if face_enhance:
65
+ face_enhancer = GFPGANer(
66
+ model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
67
+ upscale=outscale,
68
+ arch='clean',
69
+ channel_multiplier=2,
70
+ bg_upsampler=upsampler
71
+ )
72
+
73
+ cap = cv2.VideoCapture(video_path)
74
+ fps = cap.get(cv2.CAP_PROP_FPS)
75
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
76
+ w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
77
+ h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
78
+
79
+ temp_out = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
80
+ out_path = temp_out.name
81
+
82
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
83
+ writer = cv2.VideoWriter(out_path, fourcc, fps, (w * outscale, h * outscale))
84
+
85
+ for _ in tqdm(range(total_frames), desc="Enhancing video"):
86
+ success, frame = cap.read()
87
+ if not success:
88
+ break
89
+
90
+ try:
91
+ if face_enhance:
92
+ _, _, enhanced = face_enhancer.enhance(frame, has_aligned=False, only_center_face=False, paste_back=True)
93
+ else:
94
+ enhanced, _ = upsampler.enhance(frame, outscale=outscale)
95
+ writer.write(enhanced)
96
+ except RuntimeError as e:
97
+ print("Runtime error:", e)
98
+ continue
99
+
100
+ cap.release()
101
+ writer.release()
102
+
103
+ return out_path
104
+
105
+
106
+ def gradio_interface(video, model_name, denoise_strength, face_enhance, outscale):
107
+ if video is None:
108
+ return None
109
+ return enhance_video(video, model_name, denoise_strength, face_enhance, outscale)
110
+
111
+
112
+ demo = gr.Interface(
113
+ fn=gradio_interface,
114
+ inputs=[
115
+ gr.Video(label="Upload a short video (<30s)"),
116
+ gr.Dropdown(["realesr-general-x4v3", "RealESRGAN_x4plus_anime_6B"], label="Model", value="realesr-general-x4v3"),
117
+ gr.Slider(0, 1, step=0.1, value=1.0, label="Denoise Strength"),
118
+ gr.Checkbox(label="Enable Face Enhancement (GFPGAN)", value=False),
119
+ gr.Slider(1, 4, step=1, value=2, label="Upscale Factor")
120
+ ],
121
+ outputs=gr.Video(label="Enhanced Video Output"),
122
+ title="🎬 AI Video Enhancer",
123
+ description="Upscale your videos with Real-ESRGAN and optional face enhancement using GFPGAN. Optimized for Hugging Face CPU Spaces."
124
+ )
125
+
126
+ if __name__ == "__main__":
127
+ demo.launch()