Spaces: Running on Zero
| import os | |
| import shutil | |
| import subprocess | |
| import tempfile | |
| import time | |
| from pathlib import Path | |
| from typing import Tuple | |
| import gradio as gr | |
| import spaces | |
| import torch | |
| import cv2 | |
| from huggingface_hub import hf_hub_download | |
# Scratch workspace for downloads and per-job frame directories. It persists
# for the life of the process; individual jobs isolate themselves in
# timestamped subdirectories underneath it.
TEMP_DIR = Path(tempfile.gettempdir()) / "hf_video_enhancer"
TEMP_DIR.mkdir(parents=True, exist_ok=True)
def run_cmd(cmd):
    """Run an external command, returning its decoded stdout.

    Raises:
        RuntimeError: if the command exits non-zero (stderr is included).
    """
    result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    if result.returncode == 0:
        return result.stdout.decode()
    raise RuntimeError(f"Command failed: {result.stderr.decode()}")
def probe_video(video_path: str) -> Tuple[float, int, int, float]:
    """Probe a video with ffprobe and return (duration_s, width, height, fps).

    Fields that are missing or unparsable fall back to defaults
    (0x0, duration 0.0, 30.0 fps) instead of raising, so callers can
    validate `duration <= 0` themselves.
    """
    cmd = [
        "ffprobe", "-v", "error",
        "-select_streams", "v:0",
        "-show_entries", "stream=width,height,duration,r_frame_rate",
        "-of", "default=noprint_wrappers=1:nokey=0",
        video_path,
    ]
    p = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    out = p.stdout.decode()
    width = height = 0
    duration = 0.0
    fps = 30.0
    for line in out.splitlines():
        if line.startswith("width="):
            width = int(line.split("=")[1])
        elif line.startswith("height="):
            height = int(line.split("=")[1])
        elif line.startswith("duration="):
            try:
                duration = float(line.split("=")[1])
            except ValueError:  # ffprobe reports "N/A" for some containers
                pass
        elif line.startswith("r_frame_rate="):
            try:
                fps_str = line.split("=")[1]
                if "/" in fps_str:
                    num, den = fps_str.split("/")
                    fps = float(num) / float(den)
                else:
                    fps = float(fps_str)
            except (ValueError, ZeroDivisionError):  # "N/A" or "0/0"
                pass
    return duration, width, height, fps
def extract_frames(video_path: str, frames_dir: Path):
    """Dump every frame of *video_path* into *frames_dir* as numbered PNGs."""
    frames_dir.mkdir(parents=True, exist_ok=True)
    output_pattern = str(frames_dir / "%06d.png")
    # -vsync 0 passes frames through without duplication/dropping.
    run_cmd(["ffmpeg", "-y", "-i", video_path, "-vsync", "0", output_pattern])
def reassemble_video(frames_dir: Path, audio_src: str, out_path: str, fps: float = 30.0):
    """Encode the PNGs in *frames_dir* into an H.264 MP4 at *out_path*.

    Audio is copied from *audio_src* when it has an audio stream;
    otherwise the silent encode is moved into place as-is.
    """
    tmp_video = str(frames_dir.parent / "tmp_video.mp4")
    encode_cmd = [
        "ffmpeg", "-y", "-framerate", str(fps),
        "-i", str(frames_dir / "%06d.png"),
        "-c:v", "libx264", "-preset", "veryfast", "-pix_fmt", "yuv420p",
        "-crf", "18", tmp_video,
    ]
    run_cmd(encode_cmd)
    # Any output from this probe means the source has at least one audio stream.
    probe = subprocess.run(
        ["ffprobe", "-v", "error", "-select_streams", "a", "-show_entries",
         "stream=codec_type", "-of", "default=noprint_wrappers=1", audio_src],
        stdout=subprocess.PIPE, stderr=subprocess.PIPE,
    )
    has_audio = bool(probe.stdout.decode().strip())
    if not has_audio:
        shutil.move(tmp_video, out_path)
        return
    run_cmd([
        "ffmpeg", "-y", "-i", tmp_video, "-i", audio_src,
        "-c:v", "copy", "-c:a", "aac",
        "-map", "0:v:0", "-map", "1:a:0", out_path,
    ])
    os.remove(tmp_video)
def bicubic_upscale_frames(frames_dir: Path, scale: int):
    """Upscale every PNG in *frames_dir* in place by *scale* with bicubic interpolation."""
    for frame_path in sorted(frames_dir.glob("*.png")):
        frame = cv2.imread(str(frame_path))
        if frame is None:  # unreadable file — skip rather than crash
            continue
        height, width = frame.shape[:2]
        new_size = (width * scale, height * scale)
        enlarged = cv2.resize(frame, new_size, interpolation=cv2.INTER_CUBIC)
        cv2.imwrite(str(frame_path), enlarged)
@spaces.GPU(duration=120)
def enhance_with_realesrgan(frames_dir: str, scale: int = 4) -> int:
    """Super-resolve every PNG in *frames_dir* in place with Real-ESRGAN.

    Decorated with @spaces.GPU so that CUDA is actually available on
    ZeroGPU Spaces — outside a decorated function, torch.cuda.is_available()
    is False and the model would silently run on CPU.

    Args:
        frames_dir: Directory containing the extracted .png frames.
        scale: 2 selects the x2 weights; any other value uses the x4 weights.

    Returns:
        The number of frames found (0 if the directory was empty).
    """
    # Imported lazily so the CPU-only (bicubic) path never needs spandrel.
    from spandrel import ImageModelDescriptor, ModelLoader
    frames_path = Path(frames_dir)
    frame_files = sorted(frames_path.glob("*.png"))
    total = len(frame_files)
    if total == 0:
        return 0
    if scale == 2:
        model_path = hf_hub_download(repo_id="ai-forever/Real-ESRGAN", filename="RealESRGAN_x2.pth")
    else:
        model_path = hf_hub_download(repo_id="ai-forever/Real-ESRGAN", filename="RealESRGAN_x4.pth")
    model = ModelLoader().load_from_file(model_path)
    assert isinstance(model, ImageModelDescriptor)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device).eval()
    print(f"Model loaded on {device}, processing {total} frames...")
    for idx, frame_path in enumerate(frame_files):
        img = cv2.imread(str(frame_path))
        if img is None:
            continue
        # BGR uint8 HWC -> RGB float NCHW in [0, 1], as spandrel models expect.
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        tensor = torch.from_numpy(img_rgb).permute(2, 0, 1).float().div(255.0)
        tensor = tensor.unsqueeze(0).to(device)
        with torch.no_grad():
            output = model(tensor)
        # Back to uint8 HWC BGR and overwrite the frame in place.
        output = output.squeeze(0).cpu().clamp(0, 1).mul(255).byte()
        output = output.permute(1, 2, 0).numpy()
        output_bgr = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
        cv2.imwrite(str(frame_path), output_bgr)
        if (idx + 1) % 5 == 0:
            print(f"Processed {idx + 1}/{total}")
    return total
def upload_video(video_file: str) -> str:
    """
    Upload a video file to the server to prepare it for enhancement.
    MCP callers must call this tool first to stage the video, then pass
    the returned path into process_video.
    Args:
        video_file: Local path to the video file to upload.
    Returns:
        The server-side path to pass directly into process_video.
    """
    # Gradio stages the upload itself; the staged path is returned unchanged.
    staged_path = video_file
    return staged_path
def process_video(
    video_file: str,
    scale: str = "4",
    oauth_token: gr.OAuthToken | None = None,
    request: gr.Request | None = None,
) -> Tuple[str, str]:
    """
    Upscale and enhance a video using Real-ESRGAN AI super-resolution or bicubic interpolation.
    Authentication determines the upscaling mode:
    - MCP callers: add your HF Pro token to the Authorization header in your MCP client config:
      Authorization: Bearer hf_YOUR_TOKEN
      ZeroGPU will bill your own Pro quota. No token is needed in this function itself.
    - Browser users: click the Login button to authenticate with your own ZeroGPU quota.
    - Unauthenticated callers: bicubic upscaling is used automatically (no GPU required).
    Args:
        video_file: Path or URL to the input video file. Supported formats: mp4, avi, mov, mkv, webm.
        scale: Upscaling factor. Use "2" for 2x or "4" for 4x resolution (default: "4").
    Returns:
        A tuple of (status_message, output_video_path). Status describes the method used
        and resolution change (e.g. '[AI] 480x270 -> 1920x1080'). On error the path is None.
    """
    if video_file is None:
        return "⚠️ Please upload a video file.", None
    # Handle FileData dict from MCP ({"path": "...", "url": "..."})
    if isinstance(video_file, dict):
        video_file = video_file.get("url") or video_file.get("path") or ""
    # Download if it's an HTTP URL
    if isinstance(video_file, str) and video_file.startswith("http"):
        import urllib.request
        ts_dl = int(time.time() * 1000)
        dl_dir = TEMP_DIR / f"dl_{ts_dl}"
        dl_dir.mkdir(parents=True, exist_ok=True)
        suffix = Path(video_file.split("?")[0]).suffix or ".mp4"
        dl_path = dl_dir / f"input{suffix}"
        try:
            urllib.request.urlretrieve(video_file, dl_path)
        except Exception as e:
            return f"❌ Failed to download video: {e}", None
        video_file = str(dl_path)
    # Convert scale to int safely; bad input falls back to the 4x default.
    try:
        scale_int = int(scale)
    except (ValueError, TypeError):
        scale_int = 4
    # Auth check: browser OAuth login OR Authorization header from MCP caller
    use_ai = oauth_token is not None
    if not use_ai and request is not None:
        auth_header = request.headers.get("authorization", "")
        if auth_header.lower().startswith("bearer hf_"):
            use_ai = True
            print("AI mode enabled via Authorization header")
    mode_label = "AI (Real-ESRGAN)" if use_ai else "Bicubic"
    print(f"Upscaling mode: {mode_label} | scale: {scale_int}x")
    # Per-job scratch directory keyed by millisecond timestamp so concurrent
    # requests don't clobber each other's frames.
    ts = int(time.time() * 1000)
    base_dir = TEMP_DIR / f"job_{ts}"
    base_dir.mkdir(parents=True, exist_ok=True)
    in_path = base_dir / "input_video"
    try:
        shutil.copy(video_file, in_path)
    except Exception as e:
        return f"Error copying file: {e}", None
    try:
        duration, w, h, fps = probe_video(str(in_path))
    except Exception as e:
        shutil.rmtree(base_dir, ignore_errors=True)
        return f"Error probing video: {e}", None
    if duration <= 0:
        shutil.rmtree(base_dir, ignore_errors=True)
        return "Could not determine video duration.", None
    # Cap at ~30 seconds of footage to stay within the ZeroGPU time budget.
    max_frames = int(fps * 30)
    print(f"Video: {w}x{h}, {duration:.1f}s, {fps:.1f}fps")
    frames_dir = base_dir / "frames"
    try:
        extract_frames(str(in_path), frames_dir)
    except Exception as e:
        shutil.rmtree(base_dir, ignore_errors=True)
        return f"Failed extracting frames: {e}", None
    frame_files = sorted(frames_dir.glob("*.png"))
    num_frames = len(frame_files)
    if num_frames > max_frames:
        print(f"Limiting from {num_frames} to {max_frames} frames")
        for f in frame_files[max_frames:]:
            f.unlink()
        num_frames = max_frames
    print(f"Processing {num_frames} frames with {mode_label}...")
    if use_ai:
        try:
            enhanced = enhance_with_realesrgan(str(frames_dir), scale_int)
            print(f"AI-enhanced {enhanced} frames")
        except Exception as e:
            shutil.rmtree(base_dir, ignore_errors=True)
            return f"❌ AI enhancement failed: {e}", None
    else:
        try:
            bicubic_upscale_frames(frames_dir, scale_int)
        except Exception as e:
            shutil.rmtree(base_dir, ignore_errors=True)
            return f"Bicubic upscaling failed: {e}", None
    out_video = base_dir / "enhanced_output.mp4"
    try:
        reassemble_video(frames_dir, str(in_path), str(out_video), fps)
    except Exception as e:
        shutil.rmtree(base_dir, ignore_errors=True)
        return f"Failed reassembling video: {e}", None
    shutil.rmtree(frames_dir, ignore_errors=True)
    tag = "AI" if use_ai else "Bicubic"
    try:
        _, out_w, out_h, _ = probe_video(str(out_video))
        return f"✅ Done! [{tag}] {w}x{h} → {out_w}x{out_h}", str(out_video)
    except Exception:  # best-effort: still hand back the video if the probe fails
        return f"✅ Done! [{tag}]", str(out_video)
# Gradio UI
with gr.Blocks(title="AI Video Enhancer", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎬 AI Video Enhancer")
    gr.Markdown(
        "Upscale videos using Real-ESRGAN AI enhancement.\n\n"
        "- 🔓 **Not logged in** → bicubic upscaling (fast, no GPU)\n"
        "- ✅ **Logged in (HF Pro)** → Real-ESRGAN AI upscaling via your own ZeroGPU quota"
    )
    gr.LoginButton()
    with gr.Row():
        with gr.Column(scale=2):
            video_in = gr.File(label="Upload video", file_types=[".mp4", ".avi", ".mov", ".mkv", ".webm"])
            scale_choice = gr.Radio(choices=["2", "4"], value="4", label="Upscale Factor")
            btn = gr.Button("🚀 Enhance", variant="primary")
            status = gr.Textbox(label="Status", interactive=False)
        with gr.Column(scale=1):
            out_video = gr.Video(label="Result")
    gr.Markdown("**Note:** AI mode is limited to ~30 seconds for ZeroGPU. Longer videos will be truncated.")
    # Only user-facing inputs are wired here; Gradio injects the gr.Request /
    # gr.OAuthToken parameters of process_video automatically.
    btn.click(
        fn=process_video,
        inputs=[video_in, scale_choice],
        outputs=[status, out_video],
    )
    # Expose upload_video as an extra API endpoint / MCP tool.
    gr.api(upload_video)
if __name__ == "__main__":
    # mcp_server=True also serves the app's functions as MCP tools.
    demo.launch(mcp_server=True)