import os
import tempfile
from fractions import Fraction

import torch
import torchvision.transforms as T
from PIL import Image
import imageio.v3 as iio
import imageio
from tqdm import tqdm
import gradio as gr
import spaces
from diffusers import AsymmetricAutoencoderKL

# -------------------------------------------------------------
# Настройки модели
# -------------------------------------------------------------
MODEL_ID = "AiArtLab/simplevae2x"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# fp16 matmuls are unsupported/very slow on CPU — only use half precision
# when we actually run on CUDA (the original forced float16 everywhere).
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32


# -------------------------------------------------------------
# Загрузка VAE
# -------------------------------------------------------------
def load_vae(model_id=MODEL_ID, device=DEVICE):
    """Load the asymmetric VAE from the Hub.

    Tries the repository root first, then the ``vae5`` subfolder.

    Args:
        model_id: Hugging Face model id to load.
        device: torch device string to move the model to.

    Returns:
        The VAE in eval mode on ``device`` with dtype ``DTYPE``.

    Raises:
        RuntimeError: if both load attempts fail (wraps the last error).
    """
    last_err = None
    for subfolder in (None, "vae5"):
        try:
            if subfolder is None:
                vae = AsymmetricAutoencoderKL.from_pretrained(
                    model_id, torch_dtype=DTYPE
                )
            else:
                vae = AsymmetricAutoencoderKL.from_pretrained(
                    model_id, subfolder=subfolder, torch_dtype=DTYPE
                )
            # from_pretrained already applied DTYPE; do NOT call .half() here —
            # that would force fp16 even on CPU and break inference there.
            return vae.to(device).eval()
        except Exception as e:
            last_err = e
    raise RuntimeError(f"Failed to load VAE {model_id}: {last_err}")


# Lazily-initialized module-level singleton (loaded on first request so the
# Space starts fast and the GPU is only touched inside @spaces.GPU calls).
_vae = None


def get_vae():
    """Return the cached VAE, loading it on first use."""
    global _vae
    if _vae is None:
        _vae = load_vae()
    return _vae


# Built once at import time — the original rebuilt this Compose for every
# frame inside the per-frame function.
_FRAME_TRANSFORM = T.Compose([
    T.ToTensor(),
    T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # map [0,1] -> [-1,1]
])


# -------------------------------------------------------------
# Апскейл одного кадра через VAE
# -------------------------------------------------------------
def upscale_frame_via_vae(frame, vae):
    """Run one video frame through the VAE encode/decode round trip.

    Args:
        frame: HxWx3 (or HxW) uint8 ndarray as yielded by imageio.
        vae: the loaded AsymmetricAutoencoderKL.

    Returns:
        uint8 HxWx3 ndarray of the decoded (2x upscaled) frame.
    """
    img = Image.fromarray(frame).convert("RGB")
    t = _FRAME_TRANSFORM(img).unsqueeze(0).to(device=DEVICE, dtype=DTYPE)
    with torch.no_grad():
        # Use the deterministic mean of the latent distribution (no sampling).
        lat = vae.encode(t).latent_dist.mean
        dec = vae.decode(lat).sample
    # [-1,1] -> [0,255] uint8, channel-last for the video writer.
    x = (dec.clamp(-1, 1) + 1) * 127.5
    return x.round().to(torch.uint8).squeeze(0).permute(1, 2, 0).cpu().numpy()


# -------------------------------------------------------------
# Основная функция апскейла видео (без аудио)
# -------------------------------------------------------------
@spaces.GPU
def upscale_video(video_file):
    """Upscale the first <=5 seconds of a video through the VAE.

    Args:
        video_file: Gradio video payload — either a dict with a ``path`` /
            ``name`` key or a plain filesystem path string.

    Returns:
        Path to the written H.264 mp4 (audio is not preserved).
    """
    vae = get_vae()
    tmp_dir = tempfile.mkdtemp()
    output_video_path = os.path.join(tmp_dir, "upscaled_video.mp4")

    # Gradio >=4 uses "path" in the payload dict; older versions used "name".
    if isinstance(video_file, dict):
        video_path = video_file.get("path") or video_file.get("name")
    else:
        video_path = video_file

    # Извлечение FPS и длительности через ffmpeg.probe
    import ffmpeg  # third-party ffmpeg-python; imported lazily like the original

    try:
        probe = ffmpeg.probe(video_path)
        video_stream = next(
            s for s in probe["streams"] if s["codec_type"] == "video"
        )
        # r_frame_rate is a "num/den" string; Fraction parses it safely.
        # (The original used eval(), which executes arbitrary metadata text.)
        fps = float(Fraction(video_stream["r_frame_rate"]))
        duration = float(video_stream.get("duration", 5.0))
    except Exception as e:
        print("⚠️ Не удалось извлечь метаданные:", e)
        fps, duration = 24, 5.0
    duration = min(duration, 5.0)  # hard cap: only the first 5 seconds
    print(f"🎞️ Video FPS: {fps}, duration: {duration:.2f}s")

    # Читаем и апскейлим кадры
    reader = iio.imiter(video_path)
    writer = imageio.get_writer(
        output_video_path, fps=fps, codec="libx264", quality=8, format="FFMPEG"
    )
    max_frames = int(fps * duration)
    try:
        for i, frame in enumerate(tqdm(reader, desc="Upscaling frames")):
            if i >= max_frames:
                break
            writer.append_data(upscale_frame_via_vae(frame, vae))
    finally:
        # Always finalize the mp4 container, even if a frame fails mid-loop
        # (the original leaked the writer on any exception).
        writer.close()
    return output_video_path


# -------------------------------------------------------------
# Gradio UI
# -------------------------------------------------------------
demo = gr.Interface(
    fn=upscale_video,
    inputs=gr.Video(label="🎥 Upload a video"),
    outputs=gr.Video(label="🎬 Upscaled video"),
    title="🧠 Asymmetric VAE 2× Video Upscaler",
    description=(
        "Апскейлит видео (первые 5 секунд) через [`AiArtLab/simplevae2x`](https://huggingface.co/AiArtLab/simplevae2x). "
        "Аудио не сохраняется."
    ),
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()