File size: 5,703 Bytes
056bc4c
998efdd
 
056bc4c
998efdd
056bc4c
f114dc6
998efdd
 
 
 
056bc4c
 
19467f5
056bc4c
998efdd
f114dc6
056bc4c
 
998efdd
056bc4c
998efdd
056bc4c
 
 
998efdd
056bc4c
 
998efdd
056bc4c
19467f5
998efdd
056bc4c
998efdd
 
056bc4c
f114dc6
056bc4c
 
 
 
 
 
 
 
 
19467f5
f114dc6
056bc4c
 
 
f114dc6
056bc4c
 
 
 
19467f5
 
 
 
056bc4c
 
 
 
f114dc6
056bc4c
19467f5
056bc4c
 
 
 
 
 
 
 
f114dc6
056bc4c
 
 
 
 
 
19467f5
056bc4c
 
 
19467f5
056bc4c
 
19467f5
 
 
 
 
 
 
 
 
056bc4c
 
 
f114dc6
056bc4c
 
19467f5
 
 
 
056bc4c
19467f5
 
 
 
056bc4c
19467f5
f114dc6
 
056bc4c
 
 
 
f114dc6
056bc4c
 
998efdd
056bc4c
f114dc6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import gradio as gr
import cv2
import numpy as np
import os
import tempfile
import subprocess
import traceback
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
from gfpgan import GFPGANer
from basicsr.utils.download_util import load_file_from_url

# --- Model loading ---
# Process-wide memo of loaded models. Keys are either a Real-ESRGAN model name
# (upsamplers, see get_upsampler) or a 'face_enhancer_...' string (GFPGAN
# restorers, see get_face_enhancer). Nothing in this file evicts entries.
model_cache = dict()

def get_upsampler(model_name='realesr-general-x4v3'):
    """Return a cached Real-ESRGAN upsampler for ``model_name``.

    Args:
        model_name: ``'RealESRGAN_x4plus_anime_6B'`` selects the 6-block RRDBNet
            anime model; any other value falls back to ``realesr-general-x4v3``
            (an SRVGGNetCompact model). Both networks have a native scale of 4.

    Returns:
        A ``RealESRGANer`` instance, stored in ``model_cache`` under
        ``model_name`` so weights are downloaded/loaded only once per name.
    """
    if model_name in model_cache:
        return model_cache[model_name]

    # torch is already a hard dependency of Real-ESRGAN; imported locally only to
    # decide the inference precision below.
    import torch

    if model_name == 'RealESRGAN_x4plus_anime_6B':
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
        netscale = 4
        file_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth'
    else:  # realesr-general-x4v3 (also the fallback for unknown names)
        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
        netscale = 4
        file_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'

    model_path = load_file_from_url(url=file_url, model_dir='weights', progress=True)
    upsampler = RealESRGANer(
        scale=netscale, model_path=model_path, model=model,
        tile=64, tile_pad=10, pre_pad=10,
        # fp16 only makes sense on CUDA: with gpu_id=None Real-ESRGAN runs on CPU
        # when no GPU is present, and half-precision convolutions are not
        # implemented (or are very slow) on many CPU PyTorch builds.
        half=torch.cuda.is_available(),
        gpu_id=None,
    )
    model_cache[model_name] = upsampler
    return upsampler

def get_face_enhancer(upsampler, outscale):
    """Return a cached GFPGAN v1.3 face restorer that uses ``upsampler`` for the background.

    Args:
        upsampler: The Real-ESRGAN upsampler GFPGAN should use as ``bg_upsampler``
            for the non-face regions of each frame.
        outscale: Upscale factor passed to GFPGAN as ``upscale``.

    Returns:
        A ``GFPGANer`` instance, stored in ``model_cache``.
    """
    # The restorer is bound to its bg_upsampler at construction time, so the key
    # must identify the upsampler as well as the scale. Otherwise, switching the
    # model in the UI would silently keep using the previously chosen background
    # model. id() is stable here: the cached GFPGANer holds a reference to the
    # upsampler, so the object cannot be collected (and its id reused) while the
    # entry exists.
    key = f'face_enhancer_{outscale}_{id(upsampler)}'
    face_enhancer = model_cache.get(key)
    if face_enhancer is None:
        face_enhancer = GFPGANer(
            model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
            upscale=outscale, arch='clean', channel_multiplier=2, bg_upsampler=upsampler
        )
        model_cache[key] = face_enhancer
    return face_enhancer

# --- Core video processing ---
def _iter_frames(cap):
    """Yield frames from an open ``cv2.VideoCapture`` until ``read()`` fails."""
    while True:
        ok, frame = cap.read()
        if not ok:
            return
        yield frame


def _upscale_frames(video_path, output_path, upsampler, face_enhancer, outscale, progress):
    """Write every frame of ``video_path``, upscaled by ``outscale``, to a silent mp4v file.

    Uses ``face_enhancer`` (GFPGAN) when it is not None, otherwise ``upsampler``.

    Raises:
        gr.Error: If the input cannot be opened, the output cannot be created,
            or no frame could be decoded.
    """
    cap = cv2.VideoCapture(video_path)
    writer = None
    try:
        if not cap.isOpened():
            raise gr.Error("Could not open the uploaded video.")
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps or fps <= 0:
            # NOTE(review): some containers report no frame rate; 30 fps is an
            # assumed fallback so VideoWriter can still be created.
            fps = 30.0
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # The frame count is container metadata that may be missing or wrong, so it
        # only drives the progress bar; frames are read until the stream ends.
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        out_size = (width * outscale, height * outscale)
        writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, out_size)
        if not writer.isOpened():
            raise gr.Error("Could not create the intermediate video file.")

        frames_written = 0
        for frame in progress.tqdm(
            _iter_frames(cap),
            desc="Enhancing Frames...",
            total=total_frames if total_frames > 0 else None,
        ):
            if face_enhancer is not None:
                _, _, enhanced_frame = face_enhancer.enhance(
                    frame, has_aligned=False, only_center_face=False, paste_back=True
                )
            else:
                enhanced_frame, _ = upsampler.enhance(frame, outscale=outscale)
            # VideoWriter silently drops frames whose size differs from out_size.
            if (enhanced_frame.shape[1], enhanced_frame.shape[0]) != out_size:
                enhanced_frame = cv2.resize(enhanced_frame, out_size, interpolation=cv2.INTER_CUBIC)
            writer.write(enhanced_frame)
            frames_written += 1
        if frames_written == 0:
            raise gr.Error("No frames could be decoded from the uploaded video.")
    finally:
        cap.release()
        if writer is not None:
            writer.release()


def _merge_audio(video_only_path, audio_source_path, output_path):
    """Re-encode ``video_only_path`` to H.264 and mux in the first audio stream of ``audio_source_path``.

    Raises:
        gr.Error: If ffmpeg exits with an error or produces no output file.
    """
    # Argument list with shell=False: upload paths are never interpreted by a shell.
    cmd = [
        "ffmpeg", "-y", "-nostdin", "-loglevel", "error",
        "-i", video_only_path,
        "-i", audio_source_path,
        "-c:v", "libx264", "-crf", "23", "-preset", "fast",
        "-c:a", "aac", "-b:a", "128k",
        # The trailing '?' makes the audio mapping optional, so videos without an
        # audio track still produce output instead of failing the whole merge.
        "-map", "0:v:0", "-map", "1:a:0?",
        "-shortest",
        output_path,
    ]
    result = subprocess.run(
        cmd, stdin=subprocess.DEVNULL, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE
    )
    if result.returncode != 0 or not os.path.exists(output_path):
        details = result.stderr.decode(errors="replace").strip()[-500:]
        raise gr.Error(f"ffmpeg failed to produce the output video: {details}")


def enhance_video(video_path, model_name, outscale, face_enhance, progress=gr.Progress(track_tqdm=True)):
    """Upscale every frame of a video and re-attach the original audio.

    Args:
        video_path: Path to the uploaded video file.
        model_name: Real-ESRGAN model name, forwarded to ``get_upsampler``.
        outscale: Upscale factor. It is cast to ``int`` because the output frame
            size must be integral and the UI slider may deliver a float.
        face_enhance: When truthy, frames go through GFPGAN face restoration
            (which uses the Real-ESRGAN upsampler for the background).
        progress: Gradio progress tracker. Callers should pass the tracker from
            their own event handler so the progress bar is shown.

    Returns:
        Path to an H.264/AAC MP4 in a fresh temporary directory.

    Raises:
        gr.Error: If no video was given, or any decoding, enhancement or ffmpeg
            step fails.
    """
    if not video_path:
        raise gr.Error("Please upload a video to enhance.")
    try:
        outscale = int(outscale)
        upsampler = get_upsampler(model_name)
        face_enhancer = get_face_enhancer(upsampler, outscale) if face_enhance else None

        temp_dir = tempfile.mkdtemp()
        enhanced_video_path = os.path.join(temp_dir, "enhanced_video.mp4")
        final_output_path = os.path.join(temp_dir, "final_output_with_audio.mp4")

        _upscale_frames(video_path, enhanced_video_path, upsampler, face_enhancer, outscale, progress)
        _merge_audio(enhanced_video_path, video_path, final_output_path)

        # Best-effort cleanup: the silent intermediate is large and no longer needed.
        try:
            os.remove(enhanced_video_path)
        except OSError:
            pass
        return final_output_path
    except gr.Error:
        # Already user-facing; do not wrap it in a second "An error occurred" layer.
        raise
    except Exception as e:
        print(traceback.format_exc())
        raise gr.Error(f"An error occurred: {e}") from e

# --- Gradio UI ---
with gr.Blocks(theme=gr.themes.Soft(primary_hue="violet"), title="🎥 AI Video Enhancer") as demo:
    gr.Markdown(
        """
        Improve video quality, upscale resolution, and restore faces with cutting-edge AI.
        """
    )

    # Main two-column layout
    with gr.Row(variant="panel"):
        # --- Input column (left): source video and model settings ---
        with gr.Column(scale=1):
            video_input = gr.Video(label="🎬 Upload Your Video")

            # Accordion for less frequently used settings
            with gr.Accordion("Advanced Options", open=False):
                model_name = gr.Dropdown(
                    choices=["realesr-general-x4v3", "RealESRGAN_x4plus_anime_6B"],
                    value="realesr-general-x4v3",
                    label="Model Type (General or Anime)"
                )
                outscale = gr.Slider(1, 4, value=2, step=1, label="Upscale Factor")

        # --- Output column (right): result plus the run controls ---
        with gr.Column(scale=1):
            video_output = gr.Video(label="🌟 Enhanced Result")

            # The controls sit directly under the output video.
            face_enhance = gr.Checkbox(label="✨ Restore Faces (GFPGAN)", value=False, elem_id="face-enhance-checkbox")
            enhance_btn = gr.Button("🚀 Enhance Video", variant="primary")
            download_file = gr.File(label="⬇️ Download Enhanced Video", visible=False)

    # --- Event logic ---
    def on_submit(video, model, scale, face, progress=gr.Progress(track_tqdm=True)):
        """Run the enhancer and reveal the download box with the resulting file.

        Gradio only attaches progress tracking to a ``gr.Progress`` default
        declared on the event handler itself, so it is declared here and
        forwarded to ``enhance_video``.
        """
        output_path = enhance_video(video, model, scale, face, progress=progress)
        return output_path, gr.update(value=output_path, visible=True)

    enhance_btn.click(
        fn=on_submit,
        inputs=[video_input, model_name, outscale, face_enhance],
        outputs=[video_output, download_file]
    )

def main():
    """Start the Gradio server for the ``demo`` app with default launch settings."""
    demo.launch()


if __name__ == "__main__":
    main()