Spaces:
Running
Running
| import gradio as gr | |
| import cv2 | |
| import numpy as np | |
| import os | |
| import tempfile | |
| import subprocess | |
| import traceback | |
| from basicsr.archs.rrdbnet_arch import RRDBNet | |
| from realesrgan import RealESRGANer | |
| from realesrgan.archs.srvgg_arch import SRVGGNetCompact | |
| from gfpgan import GFPGANer | |
| from basicsr.utils.download_util import load_file_from_url | |
| # --- Model Loading (Unchanged) --- | |
# Process-wide cache of constructed models so weights are loaded only once.
# Keys are model names (upsamplers) or derived strings (face enhancers).
model_cache = {}


def get_upsampler(model_name='realesr-general-x4v3'):
    """Return a cached ``RealESRGANer`` for *model_name*, building it on first use.

    Args:
        model_name: ``'RealESRGAN_x4plus_anime_6B'`` selects the anime RRDBNet
            model; any other value selects ``realesr-general-x4v3``.

    Returns:
        The ``RealESRGANer`` instance stored in ``model_cache`` under
        *model_name*.
    """
    if model_name in model_cache:
        return model_cache[model_name]

    # Imported here (after the cache check) because it is only needed when a
    # new upsampler is actually constructed.
    import torch

    if model_name == 'RealESRGAN_x4plus_anime_6B':
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
        netscale = 4
        file_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth'
    else:  # realesr-general-x4v3
        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
        netscale = 4
        file_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'

    model_path = load_file_from_url(url=file_url, model_dir='weights', progress=True)

    # gpu_id=None leaves device selection to RealESRGANer, which may end up on
    # the CPU. Half precision is only enabled when CUDA is present because
    # fp16 inference on CPU is not reliably supported by PyTorch.
    use_half = torch.cuda.is_available()
    upsampler = RealESRGANer(
        scale=netscale, model_path=model_path, model=model,
        tile=64, tile_pad=10, pre_pad=10, half=use_half, gpu_id=None
    )
    model_cache[model_name] = upsampler
    return upsampler
def get_face_enhancer(upsampler, outscale):
    """Return a cached ``GFPGANer`` bound to *upsampler* and *outscale*.

    Args:
        upsampler: Background upsampler handed to GFPGAN as ``bg_upsampler``.
        outscale: Upscale factor passed to GFPGAN as ``upscale``.

    Returns:
        The ``GFPGANer`` instance stored in ``model_cache`` for this
        (outscale, upsampler) combination.
    """
    # The key includes the upsampler's identity: a GFPGANer keeps a reference
    # to the bg_upsampler it was built with, so keying on outscale alone would
    # hand back an enhancer tied to a previously selected model.
    key = f'face_enhancer_{outscale}_{id(upsampler)}'
    cached = model_cache.get(key)
    if cached is not None:
        return cached

    face_enhancer = GFPGANer(
        model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
        upscale=outscale, arch='clean', channel_multiplier=2, bg_upsampler=upsampler
    )
    model_cache[key] = face_enhancer
    return face_enhancer
| # --- Core Video Processing Function (Unchanged) --- | |
def enhance_video(video_path, model_name, outscale, face_enhance, progress=gr.Progress(track_tqdm=True)):
    """Upscale every frame of a video and re-attach the original audio.

    Args:
        video_path: Path of the uploaded video file.
        model_name: Model identifier forwarded to ``get_upsampler``.
        outscale: Upscale factor; coerced to ``int`` before use.
        face_enhance: When truthy, frames go through the GFPGAN face enhancer
            instead of the plain upsampler.
        progress: Gradio progress tracker used to wrap the frame loop.

    Returns:
        Path of the H.264 output with audio when the ffmpeg step succeeds,
        otherwise the path of the intermediate (video-only) enhanced file.

    Raises:
        gr.Error: If no video was supplied or any processing step fails.
    """
    if not video_path:
        raise gr.Error("Please upload a video to enhance.")

    try:
        # The frame size handed to cv2.VideoWriter must be integral, so a
        # float slider value (e.g. 2.0) is normalised up front.
        outscale = int(outscale)

        upsampler = get_upsampler(model_name)
        face_enhancer = get_face_enhancer(upsampler, outscale) if face_enhance else None

        temp_dir = tempfile.mkdtemp()
        enhanced_video_path = os.path.join(temp_dir, "enhanced_video.mp4")

        cap = cv2.VideoCapture(video_path)
        writer = None
        frames_written = 0
        try:
            if not cap.isOpened():
                raise ValueError("could not open the uploaded video")

            fps = cap.get(cv2.CAP_PROP_FPS)
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            writer = cv2.VideoWriter(enhanced_video_path, fourcc, fps, (width * outscale, height * outscale))

            for _ in progress.tqdm(range(total_frames), desc="Enhancing Frames..."):
                ret, frame = cap.read()
                if not ret:
                    break
                if face_enhancer:
                    _, _, enhanced_frame = face_enhancer.enhance(
                        frame, has_aligned=False, only_center_face=False, paste_back=True
                    )
                else:
                    enhanced_frame, _ = upsampler.enhance(frame, outscale=outscale)
                writer.write(enhanced_frame)
                frames_written += 1
        finally:
            # Release the capture and writer even when a frame fails, so file
            # handles are not leaked.
            cap.release()
            if writer is not None:
                writer.release()

        if frames_written == 0:
            raise ValueError("no frames could be read from the video")

        final_output_path = os.path.join(temp_dir, "final_output_with_audio.mp4")
        # Argument list (no shell) so the user-supplied path is never
        # interpreted by a shell. The trailing '?' on the audio map makes the
        # audio stream optional, so videos without audio still encode.
        audio_merge_cmd = [
            'ffmpeg', '-y',
            '-i', enhanced_video_path,
            '-i', video_path,
            '-c:v', 'libx264', '-crf', '23', '-preset', 'fast',
            '-c:a', 'aac', '-b:a', '128k',
            '-map', '0:v:0', '-map', '1:a:0?',
            '-shortest', final_output_path,
        ]
        try:
            merge_status = subprocess.call(
                audio_merge_cmd, stderr=subprocess.DEVNULL, stdout=subprocess.DEVNULL
            )
        except OSError:
            # ffmpeg binary missing or not executable.
            merge_status = 1

        if merge_status != 0 or not os.path.exists(final_output_path):
            # Fall back to the video-only result rather than returning a path
            # that does not exist.
            return enhanced_video_path
        return final_output_path
    except Exception as e:
        print(traceback.format_exc())
        raise gr.Error(f"An error occurred: {e}") from e
| # --- Gradio UI with Corrected Layout --- | |
# Builds the Gradio UI: inputs on the left, result and action controls on the right.
# NOTE(review): the emoji in the labels below were reconstructed from
# mis-encoded text. Those on "Enhanced Result" and "Enhance Video" could not be
# recovered from the surviving bytes and are best guesses; confirm against the
# intended design.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="violet"), title="🎥 AI Video Enhancer") as demo:
    gr.Markdown(
        """
        Improve video quality, upscale resolution, and restore faces with cutting-edge AI.
        """
    )

    # Main two-column layout
    with gr.Row(variant="panel"):
        # --- Input Column on the Left ---
        with gr.Column(scale=1):
            video_input = gr.Video(label="🎬 Upload Your Video")
            # Accordion for less frequently used settings
            with gr.Accordion("Advanced Options", open=False):
                model_name = gr.Dropdown(
                    choices=["realesr-general-x4v3", "RealESRGAN_x4plus_anime_6B"],
                    value="realesr-general-x4v3",
                    label="Model Type (General or Anime)"
                )
                outscale = gr.Slider(1, 4, value=2, step=1, label="Upscale Factor")

        # --- Output Column on the Right ---
        with gr.Column(scale=1):
            video_output = gr.Video(label="🌟 Enhanced Result")
            # Controls are placed directly under the output video
            face_enhance = gr.Checkbox(label="✨ Restore Faces (GFPGAN)", value=False, elem_id="face-enhance-checkbox")
            enhance_btn = gr.Button("🚀 Enhance Video", variant="primary")
            download_file = gr.File(label="⬇️ Download Enhanced Video", visible=False)

    # --- Event Logic ---
    def on_submit(video, model, scale, face, progress=gr.Progress(track_tqdm=True)):
        """Run the enhancement and reveal the download widget with the result.

        The ``progress`` default makes Gradio inject a live tracker into this
        event handler; it is forwarded so the frame loop in ``enhance_video``
        reports progress to the UI.
        """
        output_path = enhance_video(video, model, scale, face, progress)
        return output_path, gr.update(value=output_path, visible=True)

    enhance_btn.click(
        fn=on_submit,
        inputs=[video_input, model_name, outscale, face_enhance],
        outputs=[video_output, download_file]
    )

if __name__ == '__main__':
    demo.launch()