Spaces:
Running
on
Zero
Running
on
Zero
| import os | |
| import tempfile | |
| import torch | |
| import torchvision.transforms as T | |
| from PIL import Image | |
| import imageio.v3 as iio | |
| import imageio | |
| from tqdm import tqdm | |
| import gradio as gr | |
| import spaces | |
| from diffusers import AsymmetricAutoencoderKL | |
# -------------------------------------------------------------
# Model settings
# -------------------------------------------------------------
# Hugging Face Hub repository that holds the asymmetric 2x VAE weights.
MODEL_ID = "AiArtLab/simplevae2x"

# Run on the GPU when torch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Precision used when loading the VAE and when converting input frames.
# NOTE(review): float16 is applied even on the CPU path, where some torch ops
# lack Half kernels — confirm the CPU fallback is actually expected to work.
DTYPE = torch.float16
# -------------------------------------------------------------
# VAE loading
# -------------------------------------------------------------
def load_vae(model_id=MODEL_ID, device=DEVICE, subfolders=(None, "vae5"), dtype=DTYPE):
    """Load the asymmetric VAE, move it to ``device`` and switch it to eval mode.

    Candidate weight locations are tried in order: the repository root
    (``None``) and then each named subfolder. The first location that loads
    successfully wins.

    Args:
        model_id: Repo id or path passed to
            ``AsymmetricAutoencoderKL.from_pretrained``.
        device: Torch device the model is moved to.
        subfolders: Candidate subfolders tried in order. ``None`` means the
            repository root.
        dtype: Torch dtype for the weights. Used both at load time and for the
            final conversion, so the model never silently ends up in a
            different precision than requested.

    Returns:
        The loaded ``AsymmetricAutoencoderKL`` in ``eval()`` mode.

    Raises:
        RuntimeError: If no candidate location could be loaded. The message
            lists every attempt's error, and the last error is chained as
            ``__cause__``.
    """
    errors = []
    last_err = None
    for subfolder in subfolders:
        kwargs = {"torch_dtype": dtype}
        if subfolder is not None:
            kwargs["subfolder"] = subfolder
        try:
            vae = AsymmetricAutoencoderKL.from_pretrained(model_id, **kwargs)
        except Exception as e:  # any load failure just means "try the next location"
            errors.append(f"{subfolder or '<root>'}: {e}")
            last_err = e
            continue
        # Kept outside the try on purpose: a failure while moving the model
        # (e.g. out of memory) is not a "wrong location" problem, so retrying
        # another subfolder would not help and would hide the real error.
        vae.to(device=device, dtype=dtype)
        vae.eval()
        return vae
    raise RuntimeError(f"Failed to load VAE {model_id}: {'; '.join(errors)}") from last_err
# Module-level cache for the VAE; stays None until the first get_vae() call.
_vae = None


def get_vae():
    """Return the shared VAE instance, calling ``load_vae()`` only on first use."""
    global _vae
    if _vae is not None:
        return _vae
    _vae = load_vae()
    return _vae
# -------------------------------------------------------------
# Upscaling a single frame through the VAE
# -------------------------------------------------------------
# Built once instead of on every frame. ToTensor maps a PIL image to a float
# tensor in [0, 1]; Normalize(mean=0.5, std=0.5) then maps it to [-1, 1].
_FRAME_TRANSFORM = T.Compose([
    T.ToTensor(),
    T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])


def upscale_frame_via_vae(frame, vae):
    """Round-trip one video frame through the VAE and return the decoded image.

    The frame is converted to RGB and normalised to [-1, 1]. It is then
    encoded, the posterior mean is used as the latent (no sampling), and the
    latent is decoded. For the asymmetric "2x" VAE the decoded image is
    presumably twice the input resolution — TODO confirm against the model card.

    Args:
        frame: Image array accepted by ``PIL.Image.fromarray``.
        vae: Model exposing ``encode(x).latent_dist.mean`` and
            ``decode(z).sample``.

    Returns:
        ``numpy`` uint8 array laid out as (height, width, channels).
    """
    img = Image.fromarray(frame).convert("RGB")
    t = _FRAME_TRANSFORM(img).unsqueeze(0).to(device=DEVICE, dtype=DTYPE)
    with torch.inference_mode():
        lat = vae.encode(t).latent_dist.mean
        dec = vae.decode(lat).sample
        # Map the decoder output from [-1, 1] back to [0, 255].
        x = (dec.clamp(-1, 1) + 1) * 127.5
        # Drop the batch dim and go from (C, H, W) to (H, W, C) for the writer.
        return x.round().to(torch.uint8).squeeze(0).permute(1, 2, 0).cpu().numpy()
# -------------------------------------------------------------
# Main video upscaling entry point (audio is not preserved)
# -------------------------------------------------------------
_DEFAULT_FPS = 24
# Only this many seconds from the start of the clip are processed. This is
# also used as the assumed duration when it cannot be probed.
_MAX_DURATION_S = 5.0


def _probe_fps_and_duration(video_path):
    """Return ``(fps, duration_seconds)`` of the first video stream in ``video_path``.

    Falls back to ``(24, 5.0)`` if probing fails for any reason: unreadable
    file, no video stream, or a malformed or non-positive frame rate such as
    ``"0/0"``. A stream without a ``duration`` field is treated as 5 seconds.
    """
    import ffmpeg
    from fractions import Fraction

    try:
        probe = ffmpeg.probe(video_path)
        video_stream = next(s for s in probe["streams"] if s["codec_type"] == "video")
        # r_frame_rate is a rational string such as "30000/1001". It comes from
        # an uploaded (untrusted) file, so parse it with Fraction, never eval().
        fps = float(Fraction(video_stream["r_frame_rate"]))
        if fps <= 0:
            raise ValueError(f"non-positive frame rate: {fps}")
        duration = float(video_stream.get("duration", _MAX_DURATION_S))
    except Exception as e:
        print("⚠️ Не удалось извлечь метаданные:", e)
        fps, duration = _DEFAULT_FPS, _MAX_DURATION_S
    return fps, duration


# On a ZeroGPU Space the GPU is only attached inside functions decorated with
# spaces.GPU; outside ZeroGPU the decorator is documented as effect-free.
# NOTE(review): spaces.GPU applies a default per-call time budget; if 5-second
# clips time out, pass a larger `duration=`.
@spaces.GPU
def upscale_video(video_file):
    """Upscale up to the first 5 seconds of a video through the VAE, frame by frame.

    Args:
        video_file: Path to the uploaded video, or a dict carrying the path
            under ``"name"``.

    Returns:
        Path to an MP4 encoded with libx264 (no audio track), written to a
        fresh temporary directory. The directory is intentionally not removed,
        because Gradio still has to serve the file.

    Raises:
        gr.Error: If no video was provided.
    """
    import itertools
    from contextlib import closing

    if isinstance(video_file, dict):
        video_path = video_file.get("name")
    else:
        video_path = video_file
    if not video_path:
        raise gr.Error("No video provided.")

    vae = get_vae()

    fps, duration = _probe_fps_and_duration(video_path)
    duration = min(duration, _MAX_DURATION_S)
    print(f"🎞️ Video FPS: {fps}, duration: {duration:.2f}s")
    max_frames = int(fps * duration)

    tmp_dir = tempfile.mkdtemp()
    output_video_path = os.path.join(tmp_dir, "upscaled_video.mp4")

    # closing(...) shuts the imiter generator (and its underlying file) even
    # when we stop early. The writer's context manager finalises the MP4 even
    # if a frame fails.
    with closing(iio.imiter(video_path)) as reader, imageio.get_writer(
        output_video_path,
        fps=fps,
        codec="libx264",
        quality=8,
        format="FFMPEG",
    ) as writer:
        frames = itertools.islice(reader, max_frames)
        for frame in tqdm(frames, total=max_frames, desc="Upscaling frames"):
            writer.append_data(upscale_frame_via_vae(frame, vae))
    return output_video_path
# -------------------------------------------------------------
# Gradio UI
# -------------------------------------------------------------
# Text shown under the title (in Russian): "Upscales the video (first 5 seconds)
# via AiArtLab/simplevae2x. Audio is not preserved."
_description = (
    "Апскейлит видео (первые 5 секунд) через [`AiArtLab/simplevae2x`](https://huggingface.co/AiArtLab/simplevae2x). "
    "Аудио не сохраняется."
)
_video_input = gr.Video(label="🎥 Upload a video")
_video_output = gr.Video(label="🎬 Upscaled video")

# NOTE(review): newer Gradio releases deprecate `allow_flagging` in favour of
# `flagging_mode` — confirm against the installed Gradio version.
demo = gr.Interface(
    fn=upscale_video,
    inputs=_video_input,
    outputs=_video_output,
    title="🧠 Asymmetric VAE 2× Video Upscaler",
    description=_description,
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()