Spaces:
Paused
Paused
File size: 4,363 Bytes
332ad95 ade5830 4ec83db 8f43770 0d5f536 ade5830 332ad95 ade5830 54839fc ade5830 4ec83db ade5830 29dbc6d 4ec83db ade5830 4ec83db 4a1c292 ade5830 332ad95 4ec83db 332ad95 4ec83db 8f43770 4ec83db ade5830 332ad95 5ba4a34 332ad95 6625197 ade5830 5ba4a34 8f43770 4ec83db 8f43770 0d5f536 5ba4a34 332ad95 0d5f536 332ad95 0d5f536 332ad95 0d5f536 332ad95 5ba4a34 332ad95 0d5f536 5ba4a34 332ad95 0d5f536 332ad95 8f43770 4ec83db 8f43770 5ba4a34 332ad95 ade5830 332ad95 4ec83db 5ba4a34 4ec83db 332ad95 ade5830 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 | import os
import tempfile
import torch
import torchvision.transforms as T
from PIL import Image
import imageio.v3 as iio
import imageio
from tqdm import tqdm
import gradio as gr
import spaces
from diffusers import AsymmetricAutoencoderKL
# -------------------------------------------------------------
# Model settings
# -------------------------------------------------------------
# Repository id handed to AsymmetricAutoencoderKL.from_pretrained in load_vae.
MODEL_ID = "AiArtLab/simplevae2x"
# Prefer the GPU when torch reports one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Dtype used when loading the weights and when building input tensors.
DTYPE = torch.float16
# -------------------------------------------------------------
# VAE loading
# -------------------------------------------------------------
def load_vae(model_id=MODEL_ID, device=DEVICE):
    """Load the asymmetric VAE, move it to *device* and put it in eval mode.

    The weights are first looked up at the root of the repository and, if
    that fails for any reason, in its ``"vae5"`` subfolder.

    Args:
        model_id: Repository id (or path) passed to ``from_pretrained``.
        device: Device the model is moved to.

    Returns:
        The loaded ``AsymmetricAutoencoderKL`` in ``DTYPE`` precision.

    Raises:
        RuntimeError: If every location failed; the last underlying
            exception is attached as ``__cause__``.
    """
    last_err = None
    for subfolder in (None, "vae5"):
        load_kwargs = {"torch_dtype": DTYPE}
        if subfolder is not None:
            load_kwargs["subfolder"] = subfolder
        try:
            vae = AsymmetricAutoencoderKL.from_pretrained(model_id, **load_kwargs)
            # Use DTYPE rather than a hard-coded .half() so the precision is
            # controlled from a single place.
            vae.to(device=device, dtype=DTYPE)
            vae.eval()
            return vae
        except Exception as e:
            # Best effort: remember the failure and try the next location.
            last_err = e
    # Chain the original error so its traceback is not lost.
    raise RuntimeError(f"Failed to load VAE {model_id}: {last_err}") from last_err
_vae = None


def get_vae():
    """Return the shared VAE instance, loading it on the first call."""
    global _vae
    if _vae is not None:
        return _vae
    _vae = load_vae()
    return _vae
# -------------------------------------------------------------
# Upscale a single frame through the VAE
# -------------------------------------------------------------
def upscale_frame_via_vae(frame, vae):
    """Run one video frame through the VAE encoder and decoder.

    Args:
        frame: Image array accepted by ``PIL.Image.fromarray``; it is
            converted to RGB before processing.
        vae: Model exposing ``encode``/``decode`` plus ``device`` and
            ``dtype`` attributes (as diffusers models do).

    Returns:
        The decoded frame as a channels-last ``uint8`` NumPy array.
        Presumably twice the input resolution, per the model name -- the
        output size is decided entirely by the decoder.
    """
    img = Image.fromarray(frame).convert("RGB")
    tfm = T.Compose([
        T.ToTensor(),
        # Map [0, 1] pixel values to [-1, 1].
        T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])
    # Follow the placement and precision of the model that was actually
    # passed in rather than the module-level defaults, so a VAE loaded on a
    # different device still receives a matching input tensor.
    t = tfm(img).unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
    with torch.no_grad():
        # Use the mean of the latent distribution (no sampling) so the
        # result is deterministic.
        lat = vae.encode(t).latent_dist.mean
        dec = vae.decode(lat).sample
    # Map [-1, 1] back to [0, 255].
    x = (dec.clamp(-1, 1) + 1) * 127.5
    x = x.round().to(torch.uint8).squeeze(0).permute(1, 2, 0).cpu().numpy()
    return x
# -------------------------------------------------------------
# Main video upscaling function (audio is dropped)
# -------------------------------------------------------------
@spaces.GPU
def upscale_video(video_file):
    """Upscale the first seconds of a video frame by frame through the VAE.

    At most 5 seconds of input are processed. Frame rate and duration are
    read with ``ffmpeg.probe``; if that fails, 24 fps and 5 seconds are
    assumed. The result is written without audio.

    Args:
        video_file: Path to the input video, or a dict whose ``"name"``
            entry holds that path.

    Returns:
        Path to the H.264 ``.mp4`` written into a fresh temporary directory.

    Raises:
        gr.Error: If no input path was supplied.
    """
    from fractions import Fraction

    # Resolve the input path.
    if isinstance(video_file, dict):
        video_path = video_file.get("name")
    else:
        video_path = video_file
    if not video_path:
        raise gr.Error("No video file provided.")

    vae = get_vae()
    # The directory is intentionally not removed here: the returned file
    # must outlive this call.
    tmp_dir = tempfile.mkdtemp()
    output_video_path = os.path.join(tmp_dir, "upscaled_video.mp4")

    # Extract FPS and duration via ffmpeg.probe (best effort).
    try:
        # Imported inside the try so a missing ffmpeg binding also falls
        # back to the defaults instead of aborting the request.
        import ffmpeg

        probe = ffmpeg.probe(video_path)
        video_stream = next(s for s in probe["streams"] if s["codec_type"] == "video")
        # r_frame_rate is a "num/den" string taken from the uploaded file.
        # Parse it as a fraction; never eval() untrusted metadata.
        fps = float(Fraction(video_stream["r_frame_rate"]))
        if fps <= 0:
            raise ValueError(f"non-positive frame rate: {fps}")
        duration = float(video_stream.get("duration", 5.0))
    except Exception as e:
        print("⚠️ Не удалось извлечь метаданные:", e)
        fps, duration = 24, 5.0
    duration = min(duration, 5.0)
    print(f"🎞️ Video FPS: {fps}, duration: {duration:.2f}s")

    # Read and upscale the frames.
    max_frames = int(fps * duration)
    reader = iio.imiter(video_path)
    writer = imageio.get_writer(
        output_video_path,
        fps=fps,
        codec="libx264",
        quality=8,
        format="FFMPEG",
    )
    try:
        for i, frame in enumerate(tqdm(reader, desc="Upscaling frames")):
            if i >= max_frames:
                break
            writer.append_data(upscale_frame_via_vae(frame, vae))
    finally:
        # Release the encoder and the decoder even if a frame fails or the
        # loop stops early.
        try:
            writer.close()
        finally:
            reader.close()
    return output_video_path
# -------------------------------------------------------------
# Gradio UI
# -------------------------------------------------------------
# Single-function interface: one uploaded video in, one upscaled video out.
demo = gr.Interface(
    title="🧠 Asymmetric VAE 2× Video Upscaler",
    description=(
        "Апскейлит видео (первые 5 секунд) через "
        "[`AiArtLab/simplevae2x`](https://huggingface.co/AiArtLab/simplevae2x). "
        "Аудио не сохраняется."
    ),
    fn=upscale_video,
    inputs=gr.Video(label="🎥 Upload a video"),
    outputs=gr.Video(label="🎬 Upscaled video"),
    allow_flagging="never",
)

# Start the app only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()
|