# videoupscaler / app.py
# Author: recoilme — "Update app.py" (commit 29dbc6d, verified)
import os
import tempfile
import torch
import torchvision.transforms as T
from PIL import Image
import imageio.v3 as iio
import imageio
from tqdm import tqdm
import gradio as gr
import spaces
from diffusers import AsymmetricAutoencoderKL
# -------------------------------------------------------------
# Model settings
# -------------------------------------------------------------
MODEL_ID = "AiArtLab/simplevae2x"  # Hugging Face repo of the 2x asymmetric VAE
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # prefer GPU when present
DTYPE = torch.float16  # half precision; NOTE(review): assumes CUDA — fp16 inference on CPU may be unsupported for some ops
# -------------------------------------------------------------
# VAE loading
# -------------------------------------------------------------
def load_vae(model_id=MODEL_ID, device=DEVICE):
    """Load the asymmetric VAE from the Hub and move it to *device*.

    Tries the repository root first, then the ``vae5`` subfolder, so the
    function works with either repo layout.

    Args:
        model_id: Hugging Face repo id to load from.
        device: torch device string the model is moved to.

    Returns:
        The VAE in eval mode, cast to float16.

    Raises:
        RuntimeError: if both load attempts fail; chains the last
        underlying exception for a full traceback.
    """
    last_err = None
    for subfolder in (None, "vae5"):
        # Build kwargs once so both layouts share a single call site.
        kwargs = {"torch_dtype": DTYPE}
        if subfolder is not None:
            kwargs["subfolder"] = subfolder
        try:
            vae = AsymmetricAutoencoderKL.from_pretrained(model_id, **kwargs)
            vae.to(device)
            # NOTE(review): .half() on CPU may break some ops — confirm the
            # Space always runs this on CUDA.
            vae.eval().half()
            return vae
        except Exception as e:
            last_err = e
    # Chain the underlying cause instead of discarding the traceback.
    raise RuntimeError(f"Failed to load VAE {model_id}: {last_err}") from last_err
# Module-level cache so the model is downloaded/initialized at most once.
_vae = None


def get_vae():
    """Return the memoized VAE, loading it lazily on first use."""
    global _vae
    if _vae is not None:
        return _vae
    _vae = load_vae()
    return _vae
# -------------------------------------------------------------
# Upscale a single frame through the VAE
# -------------------------------------------------------------
def upscale_frame_via_vae(frame, vae):
    """Upscale one video frame by a VAE encode/decode round trip.

    Args:
        frame: HxWxC uint8 array (any mode PIL can ingest; converted to RGB).
        vae: the asymmetric VAE returned by :func:`get_vae`; its decoder
            outputs at 2x the input resolution.

    Returns:
        Upscaled frame as an HxWx3 uint8 numpy array.
    """
    img = Image.fromarray(frame).convert("RGB")
    # to_tensor gives float in [0, 1]; 2x - 1 maps it to [-1, 1].  This is
    # numerically identical to Normalize(mean=0.5, std=0.5) but avoids
    # rebuilding a T.Compose pipeline on every frame.
    t = T.functional.to_tensor(img).mul(2.0).sub(1.0)
    t = t.unsqueeze(0).to(device=DEVICE, dtype=DTYPE)
    with torch.no_grad():
        # Use the distribution mean (deterministic) rather than sampling.
        lat = vae.encode(t).latent_dist.mean
        dec = vae.decode(lat).sample
    # Map [-1, 1] back to [0, 255] and drop the batch dimension.
    out = (dec.clamp(-1, 1) + 1) * 127.5
    return out.round().to(torch.uint8).squeeze(0).permute(1, 2, 0).cpu().numpy()
# -------------------------------------------------------------
# Main video upscaling entry point (audio is dropped)
# -------------------------------------------------------------
@spaces.GPU
def upscale_video(video_file):
    """Upscale (the first ~5 seconds of) an uploaded video frame by frame.

    Args:
        video_file: either a filesystem path or a Gradio file dict with a
            ``"name"`` key pointing at the uploaded video.

    Returns:
        Path to the upscaled mp4 in a fresh temporary directory.
    """
    vae = get_vae()
    tmp_dir = tempfile.mkdtemp()
    output_video_path = os.path.join(tmp_dir, "upscaled_video.mp4")

    # Gradio may hand us a dict (older API) or a plain path string.
    if isinstance(video_file, dict):
        video_path = video_file.get("name")
    else:
        video_path = video_file

    # Probe FPS and duration; fall back to safe defaults on any failure.
    import ffmpeg
    from fractions import Fraction
    try:
        probe = ffmpeg.probe(video_path)
        video_stream = next(s for s in probe["streams"] if s["codec_type"] == "video")
        # r_frame_rate is a ratio string like "30000/1001".  Parse it with
        # Fraction instead of eval(): eval on external metadata is arbitrary
        # code execution.
        fps = float(Fraction(video_stream["r_frame_rate"]))
        duration = float(video_stream.get("duration", 5.0))
    except Exception as e:
        print("⚠️ Не удалось извлечь метаданные:", e)
        fps, duration = 24, 5.0
    # Cap processing to the first 5 seconds to bound GPU time.
    duration = min(duration, 5.0)
    print(f"🎞️ Video FPS: {fps}, duration: {duration:.2f}s")

    # Stream frames in, upscale, and write out; audio is intentionally dropped.
    reader = iio.imiter(video_path)
    writer = imageio.get_writer(
        output_video_path,
        fps=fps,
        codec="libx264",
        quality=8,
        format="FFMPEG"
    )
    max_frames = int(fps * duration)
    try:
        for i, frame in enumerate(tqdm(reader, desc="Upscaling frames")):
            if i >= max_frames:
                break
            upscaled = upscale_frame_via_vae(frame, vae)
            writer.append_data(upscaled)
    finally:
        # Always finalize the container, even if an upscale step raises.
        writer.close()
    return output_video_path
# -------------------------------------------------------------
# Gradio UI
# -------------------------------------------------------------
demo = gr.Interface(
    fn=upscale_video,
    inputs=gr.Video(label="🎥 Upload a video"),
    outputs=gr.Video(label="🎬 Upscaled video"),
    title="🧠 Asymmetric VAE 2× Video Upscaler",
    # User-facing description (Russian): "Upscales the first 5 seconds of the
    # video via AiArtLab/simplevae2x. Audio is not preserved."
    description=(
        "Апскейлит видео (первые 5 секунд) через [`AiArtLab/simplevae2x`](https://huggingface.co/AiArtLab/simplevae2x). "
        "Аудио не сохраняется."
    ),
    # NOTE(review): allow_flagging is deprecated in Gradio 4.x (renamed to
    # flagging_mode) — confirm against the Space's pinned Gradio version.
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()