# app.py — AI Video Enhancer Space: Real-ESRGAN upscaling (ZeroGPU) with bicubic fallback.
import os
import shutil
import subprocess
import tempfile
import time
from pathlib import Path
from typing import Tuple
import gradio as gr
import spaces
import torch
import cv2
from huggingface_hub import hf_hub_download
# Scratch area for URL downloads and per-job frame directories; created eagerly
# at import time so the request handlers can assume it exists.
TEMP_DIR = Path(tempfile.gettempdir()) / "hf_video_enhancer"
TEMP_DIR.mkdir(parents=True, exist_ok=True)
def run_cmd(cmd):
    """Run *cmd* (an argv list) and return its decoded stdout.

    Args:
        cmd: Command and arguments, passed to subprocess without a shell.

    Returns:
        The process's stdout decoded as text.

    Raises:
        RuntimeError: if the process exits non-zero; the message carries the
            captured stderr.  Output is decoded with errors="replace" so a
            stray non-UTF-8 byte in tool output cannot mask the real failure
            with a UnicodeDecodeError.
    """
    p = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    if p.returncode != 0:
        raise RuntimeError(f"Command failed: {p.stderr.decode(errors='replace')}")
    return p.stdout.decode(errors="replace")
def probe_video(video_path: str) -> Tuple[float, int, int, float]:
    """Probe *video_path* with ffprobe for basic stream properties.

    Args:
        video_path: Path to the video file to inspect.

    Returns:
        (duration_seconds, width, height, fps).  Fields ffprobe omits or
        reports unparsable (e.g. "N/A") keep their defaults:
        duration 0.0, width/height 0, fps 30.0.
    """
    cmd = [
        "ffprobe", "-v", "error",
        "-select_streams", "v:0",
        "-show_entries", "stream=width,height,duration,r_frame_rate",
        "-of", "default=noprint_wrappers=1:nokey=0",
        video_path
    ]
    p = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    out = p.stdout.decode()
    width = height = 0
    duration = 0.0
    fps = 30.0
    for line in out.splitlines():
        if line.startswith("width="):
            width = int(line.split("=")[1])
        elif line.startswith("height="):
            height = int(line.split("=")[1])
        elif line.startswith("duration="):
            try:
                duration = float(line.split("=")[1])
            except ValueError:
                # ffprobe prints "N/A" when the container has no duration.
                pass
        elif line.startswith("r_frame_rate="):
            try:
                fps_str = line.split("=")[1]
                if "/" in fps_str:
                    # Frame rate is usually a rational like "30000/1001".
                    num, den = fps_str.split("/")
                    fps = float(num) / float(den)
                else:
                    fps = float(fps_str)
            except (ValueError, ZeroDivisionError):
                # Malformed rate (or "0/0") — keep the 30 fps default.
                pass
    return duration, width, height, fps
def extract_frames(video_path: str, frames_dir: Path):
    """Dump every frame of *video_path* into *frames_dir* as zero-padded PNGs."""
    frames_dir.mkdir(parents=True, exist_ok=True)
    # -vsync 0 writes frames as decoded, without duplicating/dropping to a rate.
    output_pattern = str(frames_dir / "%06d.png")
    cmd = ["ffmpeg", "-y", "-i", video_path, "-vsync", "0", output_pattern]
    run_cmd(cmd)
def reassemble_video(frames_dir: Path, audio_src: str, out_path: str, fps: float = 30.0):
    """Encode the PNGs in *frames_dir* into an H.264 MP4 at *out_path*.

    If *audio_src* has an audio stream, it is muxed back in (re-encoded to
    AAC); otherwise the silent encode is moved into place as-is.
    """
    tmp_video = str(frames_dir.parent / "tmp_video.mp4")
    # Step 1: encode the frame sequence to a temporary video-only file.
    encode_cmd = [
        "ffmpeg", "-y", "-framerate", str(fps),
        "-i", str(frames_dir / "%06d.png"),
        "-c:v", "libx264", "-preset", "veryfast", "-pix_fmt", "yuv420p",
        "-crf", "18", tmp_video,
    ]
    run_cmd(encode_cmd)
    # Step 2: check whether the source actually carries an audio stream.
    probe = subprocess.run(
        ["ffprobe", "-v", "error", "-select_streams", "a", "-show_entries",
         "stream=codec_type", "-of", "default=noprint_wrappers=1", audio_src],
        stdout=subprocess.PIPE, stderr=subprocess.PIPE
    )
    if not probe.stdout.decode().strip():
        # No audio: the temporary encode is the final output.
        shutil.move(tmp_video, out_path)
        return
    # Step 3: mux the original audio alongside the freshly-encoded video.
    mux_cmd = [
        "ffmpeg", "-y", "-i", tmp_video, "-i", audio_src,
        "-c:v", "copy", "-c:a", "aac",
        "-map", "0:v:0", "-map", "1:a:0", out_path,
    ]
    run_cmd(mux_cmd)
    os.remove(tmp_video)
def bicubic_upscale_frames(frames_dir: Path, scale: int):
    """Upscale every PNG in *frames_dir* in place by *scale* via bicubic resize."""
    for frame_path in sorted(frames_dir.glob("*.png")):
        image = cv2.imread(str(frame_path))
        if image is None:
            # Unreadable frame — skip it rather than abort the batch.
            continue
        height, width = image.shape[:2]
        new_size = (width * scale, height * scale)
        resized = cv2.resize(image, new_size, interpolation=cv2.INTER_CUBIC)
        cv2.imwrite(str(frame_path), resized)
@spaces.GPU(duration=120)
def enhance_with_realesrgan(frames_dir: str, scale: int = 4) -> int:
    """Super-resolve every PNG in *frames_dir* in place with Real-ESRGAN.

    The @spaces.GPU decorator requests a ZeroGPU slot for up to 120 s.
    Weights are fetched from the ai-forever/Real-ESRGAN Hub repo and loaded
    via spandrel; each frame is overwritten with its upscaled version.

    Args:
        frames_dir: Directory containing the extracted ``*.png`` frames.
        scale: 2 selects the x2 checkpoint; any other value uses the x4 one.

    Returns:
        The number of frames found (0 if the directory was empty); frames
        cv2 cannot read are skipped but still counted in the total.
    """
    # Imported lazily so the module can be imported without these heavy deps.
    from spandrel import ImageModelDescriptor, ModelLoader
    import numpy as np  # NOTE(review): appears unused in this function
    frames_path = Path(frames_dir)
    frame_files = sorted(frames_path.glob("*.png"))
    total = len(frame_files)
    if total == 0:
        return 0
    # Only x2 and x4 checkpoints exist; anything other than 2 falls back to x4.
    if scale == 2:
        model_path = hf_hub_download(repo_id="ai-forever/Real-ESRGAN", filename="RealESRGAN_x2.pth")
    else:
        model_path = hf_hub_download(repo_id="ai-forever/Real-ESRGAN", filename="RealESRGAN_x4.pth")
    model = ModelLoader().load_from_file(model_path)
    assert isinstance(model, ImageModelDescriptor)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device).eval()
    print(f"Model loaded on {device}, processing {total} frames...")
    for idx, frame_path in enumerate(frame_files):
        img = cv2.imread(str(frame_path))
        if img is None:
            continue
        # BGR (OpenCV) -> RGB, then HWC uint8 -> 1xCxHxW float32 in [0, 1].
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        tensor = torch.from_numpy(img_rgb).permute(2, 0, 1).float().div(255.0)
        tensor = tensor.unsqueeze(0).to(device)
        with torch.no_grad():
            output = model(tensor)
        # Back to HWC uint8, RGB -> BGR, and overwrite the frame in place.
        output = output.squeeze(0).cpu().clamp(0, 1).mul(255).byte()
        output = output.permute(1, 2, 0).numpy()
        output_bgr = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
        cv2.imwrite(str(frame_path), output_bgr)
        if (idx + 1) % 5 == 0:
            print(f"Processed {idx + 1}/{total}")
    return total
def upload_video(video_file: str) -> str:
    """
    Upload a video file to the server to prepare it for enhancement.
    MCP callers must call this tool first to stage the video, then pass
    the returned path into process_video.
    Args:
        video_file: Local path to the video file to upload.
    Returns:
        The server-side path to pass directly into process_video.
    """
    # NOTE(review): intentionally a pass-through — presumably Gradio has
    # already staged the uploaded file into a server-side temp path by the
    # time this runs, so the path itself is the "upload" result. Confirm
    # against the gr.api/MCP file-handling behavior.
    return video_file
def process_video(
    video_file: str,
    scale: str = "4",
    oauth_token: gr.OAuthToken | None = None,
    request: gr.Request | None = None,
) -> Tuple[str, str]:
    """
    Upscale and enhance a video using Real-ESRGAN AI super-resolution or bicubic interpolation.
    Authentication determines the upscaling mode:
    - MCP callers: add your HF Pro token to the Authorization header in your MCP client config:
      Authorization:Bearer hf_YOUR_TOKEN
      ZeroGPU will bill your own Pro quota. No token is needed in this function itself.
    - Browser users: click the Login button to authenticate with your own ZeroGPU quota.
    - Unauthenticated callers: bicubic upscaling is used automatically (no GPU required).
    Args:
        video_file: Path or URL to the input video file. Supported formats: mp4, avi, mov, mkv, webm.
        scale: Upscaling factor. Use "2" for 2x or "4" for 4x resolution (default: "4").
    Returns:
        A tuple of (status_message, output_video_path). Status describes the method used
        and resolution change (e.g. '[AI] 480x270 β†’ 1920x1080' or '[Bicubic] 480x270 β†’ 1920x1080').
    """
    if video_file is None:
        return "⚠️ Please upload a video file.", None
    # Handle FileData dict from MCP ({"path": "...", "url": "..."})
    if isinstance(video_file, dict):
        video_file = video_file.get("url") or video_file.get("path") or ""
    # Download if it's an HTTP URL
    if isinstance(video_file, str) and video_file.startswith("http"):
        import urllib.request
        ts_dl = int(time.time() * 1000)
        dl_dir = TEMP_DIR / f"dl_{ts_dl}"
        dl_dir.mkdir(parents=True, exist_ok=True)
        # Strip query string before reading the suffix so "?token=..." doesn't leak in.
        suffix = Path(video_file.split("?")[0]).suffix or ".mp4"
        dl_path = dl_dir / f"input{suffix}"
        try:
            urllib.request.urlretrieve(video_file, dl_path)
        except Exception as e:
            return f"❌ Failed to download video: {e}", None
    # Convert scale to int safely.  Only 2x and 4x are supported: the AI path
    # treats anything != 2 as the x4 model, so normalize other values to 4
    # here to keep bicubic output consistent with the advertised options.
    try:
        scale_int = int(scale)
    except (ValueError, TypeError):
        scale_int = 4
    if scale_int not in (2, 4):
        scale_int = 4
    # Auth check: browser OAuth login OR Authorization header from MCP caller
    use_ai = oauth_token is not None
    if not use_ai and request is not None:
        auth_header = request.headers.get("authorization", "")
        if auth_header.lower().startswith("bearer hf_"):
            use_ai = True
            print("AI mode enabled via Authorization header")
    mode_label = "AI (Real-ESRGAN)" if use_ai else "Bicubic"
    print(f"Upscaling mode: {mode_label} | scale: {scale_int}x")
    # Per-job working directory keyed by millisecond timestamp.
    ts = int(time.time() * 1000)
    base_dir = TEMP_DIR / f"job_{ts}"
    base_dir.mkdir(parents=True, exist_ok=True)
    in_path = base_dir / "input_video"
    try:
        shutil.copy(video_file, in_path)
    except Exception as e:
        return f"Error copying file: {e}", None
    try:
        duration, w, h, fps = probe_video(str(in_path))
    except Exception as e:
        shutil.rmtree(base_dir, ignore_errors=True)
        return f"Error probing video: {e}", None
    if duration <= 0:
        shutil.rmtree(base_dir, ignore_errors=True)
        return "Could not determine video duration.", None
    # Cap work at ~30 seconds of footage (ZeroGPU time budget).
    max_frames = int(fps * 30)
    print(f"Video: {w}x{h}, {duration:.1f}s, {fps:.1f}fps")
    frames_dir = base_dir / "frames"
    try:
        extract_frames(str(in_path), frames_dir)
    except Exception as e:
        shutil.rmtree(base_dir, ignore_errors=True)
        return f"Failed extracting frames: {e}", None
    frame_files = sorted(frames_dir.glob("*.png"))
    num_frames = len(frame_files)
    if num_frames > max_frames:
        # Truncate by deleting trailing frames; the remaining sequence stays
        # contiguous so reassembly's %06d pattern still works.
        print(f"Limiting from {num_frames} to {max_frames} frames")
        for f in frame_files[max_frames:]:
            f.unlink()
        num_frames = max_frames
    print(f"Processing {num_frames} frames with {mode_label}...")
    if use_ai:
        try:
            enhanced = enhance_with_realesrgan(str(frames_dir), scale_int)
            print(f"AI-enhanced {enhanced} frames")
        except Exception as e:
            shutil.rmtree(base_dir, ignore_errors=True)
            return f"❌ AI enhancement failed: {e}", None
    else:
        try:
            bicubic_upscale_frames(frames_dir, scale_int)
        except Exception as e:
            shutil.rmtree(base_dir, ignore_errors=True)
            return f"Bicubic upscaling failed: {e}", None
    out_video = base_dir / "enhanced_output.mp4"
    try:
        reassemble_video(frames_dir, str(in_path), str(out_video), fps)
    except Exception as e:
        shutil.rmtree(base_dir, ignore_errors=True)
        return f"Failed reassembling video: {e}", None
    # Frames are no longer needed once the output is encoded.
    shutil.rmtree(frames_dir, ignore_errors=True)
    tag = "AI" if use_ai else "Bicubic"
    try:
        _, out_w, out_h, _ = probe_video(str(out_video))
        return f"βœ… Done! [{tag}] {w}x{h} β†’ {out_w}x{out_h}", str(out_video)
    except Exception:
        # Output exists even if the final probe fails — return it without
        # the resolution summary instead of erroring out.
        return f"βœ… Done! [{tag}]", str(out_video)
# Gradio UI
with gr.Blocks(title="AI Video Enhancer", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎬 AI Video Enhancer")
    gr.Markdown(
        "Upscale videos using Real-ESRGAN AI enhancement.\n\n"
        "- πŸ”’ **Not logged in** β€” bicubic upscaling (fast, no GPU)\n"
        "- βœ… **Logged in (HF Pro)** β€” Real-ESRGAN AI upscaling via your own ZeroGPU quota"
    )
    # Supplies the gr.OAuthToken that process_video uses to pick AI vs bicubic.
    gr.LoginButton()
    with gr.Row():
        with gr.Column(scale=2):
            video_in = gr.File(label="Upload video", file_types=[".mp4", ".avi", ".mov", ".mkv", ".webm"])
            scale_choice = gr.Radio(choices=["2", "4"], value="4", label="Upscale Factor")
            btn = gr.Button("πŸš€ Enhance", variant="primary")
            status = gr.Textbox(label="Status", interactive=False)
        with gr.Column(scale=1):
            out_video = gr.Video(label="Result")
    gr.Markdown("**Note:** AI mode is limited to ~30 seconds for ZeroGPU. Longer videos will be truncated.")
    # oauth_token/request parameters of process_video are injected by Gradio,
    # so only the two UI inputs are wired here.
    btn.click(
        fn=process_video,
        inputs=[video_in, scale_choice],
        outputs=[status, out_video],
    )
    # Exposes upload_video as an additional API/MCP endpoint without UI.
    gr.api(upload_video)
if __name__ == "__main__":
    # mcp_server=True additionally serves the app's endpoints as MCP tools.
    demo.launch(mcp_server=True)