Spaces:
Running
Running
File size: 5,703 Bytes
056bc4c 998efdd 056bc4c 998efdd 056bc4c f114dc6 998efdd 056bc4c 19467f5 056bc4c 998efdd f114dc6 056bc4c 998efdd 056bc4c 998efdd 056bc4c 998efdd 056bc4c 998efdd 056bc4c 19467f5 998efdd 056bc4c 998efdd 056bc4c f114dc6 056bc4c 19467f5 f114dc6 056bc4c f114dc6 056bc4c 19467f5 056bc4c f114dc6 056bc4c 19467f5 056bc4c f114dc6 056bc4c 19467f5 056bc4c 19467f5 056bc4c 19467f5 056bc4c f114dc6 056bc4c 19467f5 056bc4c 19467f5 056bc4c 19467f5 f114dc6 056bc4c f114dc6 056bc4c 998efdd 056bc4c f114dc6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
import gradio as gr
import cv2
import numpy as np
import os
import tempfile
import subprocess
import traceback
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
from gfpgan import GFPGANer
from basicsr.utils.download_util import load_file_from_url
# --- Model Loading ---
# Built models are memoized here so repeated requests reuse the same instance.
model_cache = {}

def get_upsampler(model_name='realesr-general-x4v3'):
    """Return a cached RealESRGANer for *model_name*, building it on first use.

    Args:
        model_name: Either 'RealESRGAN_x4plus_anime_6B' (anime-tuned RRDB net)
            or anything else, which falls back to the compact
            'realesr-general-x4v3' general-purpose model.

    Returns:
        A ready-to-use RealESRGANer wrapping the selected network.
    """
    cached = model_cache.get(model_name)
    if cached is not None:
        return cached

    if model_name == 'RealESRGAN_x4plus_anime_6B':
        network = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
        weights_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth'
    else:  # realesr-general-x4v3
        network = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
        weights_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'

    weights_path = load_file_from_url(url=weights_url, model_dir='weights', progress=True)
    # Both supported networks are native 4x upscalers; tiling keeps VRAM bounded.
    upsampler = RealESRGANer(
        scale=4,
        model_path=weights_path,
        model=network,
        tile=64,
        tile_pad=10,
        pre_pad=10,
        half=True,
        gpu_id=None,
    )
    model_cache[model_name] = upsampler
    return upsampler
def get_face_enhancer(upsampler, outscale):
    """Return a cached GFPGANer face restorer for the given output scale.

    Args:
        upsampler: RealESRGANer used as the background upsampler.
        outscale: Overall upscale factor; part of the cache key because GFPGAN
            is configured per scale.

    Returns:
        A GFPGANer instance (one per distinct *outscale*).
    """
    cache_key = f'face_enhancer_{outscale}'
    enhancer = model_cache.get(cache_key)
    if enhancer is None:
        enhancer = GFPGANer(
            model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
            upscale=outscale,
            arch='clean',
            channel_multiplier=2,
            bg_upsampler=upsampler,
        )
        model_cache[cache_key] = enhancer
    return enhancer
# --- Core Video Processing Function ---
def enhance_video(video_path, model_name, outscale, face_enhance, progress=gr.Progress(track_tqdm=True)):
    """Upscale a video frame-by-frame with Real-ESRGAN (optionally GFPGAN for
    faces), then re-attach the original audio track with ffmpeg.

    Args:
        video_path: Path to the uploaded source video; falsy raises gr.Error.
        model_name: Model key understood by get_upsampler.
        outscale: Upscale factor (1-4) applied to both dimensions.
        face_enhance: When True, run GFPGAN face restoration on every frame.
        progress: Gradio progress tracker (injected automatically by the UI).

    Returns:
        Path to the final mp4 — with audio when the source has an audio
        stream, otherwise the silent enhanced video.

    Raises:
        gr.Error: On missing input or any processing failure.
    """
    if not video_path:
        raise gr.Error("Please upload a video to enhance.")
    try:
        upsampler = get_upsampler(model_name)
        face_enhancer = get_face_enhancer(upsampler, outscale) if face_enhance else None

        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise gr.Error("Could not open the uploaded video.")
        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Gradio sliders may deliver floats; VideoWriter needs integer dims.
        outscale = int(outscale)
        temp_dir = tempfile.mkdtemp()
        enhanced_video_path = os.path.join(temp_dir, "enhanced_video.mp4")
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        writer = cv2.VideoWriter(enhanced_video_path, fourcc, fps, (width * outscale, height * outscale))
        try:
            for _ in progress.tqdm(range(total_frames), desc="Enhancing Frames..."):
                ret, frame = cap.read()
                if not ret:
                    break
                if face_enhancer:
                    # GFPGANer.enhance returns (cropped_faces, restored_faces, restored_img)
                    _, _, enhanced_frame = face_enhancer.enhance(
                        frame, has_aligned=False, only_center_face=False, paste_back=True)
                else:
                    enhanced_frame, _ = upsampler.enhance(frame, outscale=outscale)
                writer.write(enhanced_frame)
        finally:
            # Release handles even if a frame fails, so the partial file is flushed.
            cap.release()
            writer.release()

        final_output_path = os.path.join(temp_dir, "final_output_with_audio.mp4")
        # Argument list + shell=False avoids quoting bugs and shell injection
        # through uploaded file names (the old f-string/shell=True command broke
        # on paths containing quotes).
        merge_cmd = [
            'ffmpeg', '-y',
            '-i', enhanced_video_path,
            '-i', video_path,
            '-c:v', 'libx264', '-crf', '23', '-preset', 'fast',
            '-c:a', 'aac', '-b:a', '128k',
            '-map', '0:v:0', '-map', '1:a:0', '-shortest',
            final_output_path,
        ]
        result = subprocess.run(merge_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        if result.returncode != 0 or not os.path.exists(final_output_path):
            # ffmpeg fails when the source has no audio stream (-map 1:a:0);
            # fall back to the silent enhanced video instead of returning a
            # path that does not exist.
            return enhanced_video_path
        return final_output_path
    except gr.Error:
        # Re-raise user-facing errors untouched instead of wrapping them.
        raise
    except Exception as e:
        print(traceback.format_exc())
        raise gr.Error(f"An error occurred: {e}")
# --- Gradio UI with Corrected Layout ---
with gr.Blocks(theme=gr.themes.Soft(primary_hue="violet"), title="🎥 AI Video Enhancer") as demo:
    gr.Markdown(
        """
        Improve video quality, upscale resolution, and restore faces with cutting-edge AI.
        """
    )
    # Main two-column layout
    with gr.Row(variant="panel"):
        # --- Input Column on the Left ---
        with gr.Column(scale=1):
            video_input = gr.Video(label="🎬 Upload Your Video")
            # Accordion for less frequently used settings
            with gr.Accordion("Advanced Options", open=False):
                model_name = gr.Dropdown(
                    choices=["realesr-general-x4v3", "RealESRGAN_x4plus_anime_6B"],
                    value="realesr-general-x4v3",
                    label="Model Type (General or Anime)"
                )
                outscale = gr.Slider(1, 4, value=2, step=1, label="Upscale Factor")
        # --- Output Column on the Right ---
        with gr.Column(scale=1):
            video_output = gr.Video(label="🚀 Enhanced Result")
            # FIX: Controls are now placed directly under the output video
            face_enhance = gr.Checkbox(label="✨ Restore Faces (GFPGAN)", value=False, elem_id="face-enhance-checkbox")
            enhance_btn = gr.Button("🚀 Enhance Video", variant="primary")
            download_file = gr.File(label="⬇️ Download Enhanced Video", visible=False)

    # --- Event Logic ---
    def on_submit(video, model, scale, face):
        """Run enhancement, then show the result and reveal the download link."""
        output_path = enhance_video(video, model, scale, face)
        return output_path, gr.update(value=output_path, visible=True)

    enhance_btn.click(
        fn=on_submit,
        inputs=[video_input, model_name, outscale, face_enhance],
        outputs=[video_output, download_file]
    )

if __name__ == '__main__':
    demo.launch()