"""VideoFlexTok Gradio demo: autoencode an uploaded video and show coarse-to-fine reconstructions.

The input clip is tokenized once and then detokenized once per entry of
``NUM_KEEP_TOKENS``; all reconstructions plus the original are tiled into a single
grid video that the UI displays.
"""

import importlib
import os
import subprocess
import sys
import tempfile
from pathlib import Path


# Install videoflextok without its deps to avoid huggingface_hub==0.25.2 conflicting
# with gradio's >=0.33.5 requirement. Compatible dep versions are in requirements.txt.
def _install_videoflextok():
    """Pip-install videoflextok from GitHub with ``--no-deps`` unless it is already importable.

    Raises:
        subprocess.CalledProcessError: if the pip install fails.
    """
    try:
        import videoflextok  # noqa: F401
    except ImportError:
        pass
    else:
        return

    print("[VideoFlexTok] Installing videoflextok (--no-deps) ...")
    subprocess.run(
        [
            sys.executable, "-m", "pip", "install", "--quiet", "--no-deps",
            "git+https://github.com/apple/ml-videoflextok.git",
        ],
        check=True,
    )
    # Make the freshly installed package visible to the import system in this process.
    importlib.invalidate_caches()


_install_videoflextok()

# `spaces` (Hugging Face ZeroGPU) is only available on HF Spaces. Fall back to plain
# execution elsewhere so the demo can also run locally; see the spaces.GPU wrap below.
try:
    import spaces
except ImportError:
    spaces = None
import gradio as gr
import imageio.v3 as iio
import numpy as np
import torch

from videoflextok.utils.demo import denormalize, read_mp4
from videoflextok.utils.misc import detect_bf16_support, get_bf16_context
from videoflextok.wrappers import VideoFlexTokFromHub

# --- Constants ---------------------------------------------------------------------

MODEL_ID = "EPFL-VILAB/videoflextok_d18_d28"
APP_DIR = Path(__file__).resolve().parent
EXAMPLES_DIR = APP_DIR / "examples"
EXAMPLE_VIDEOS = sorted(EXAMPLES_DIR.glob("*.mp4"))
NUM_KEEP_TOKENS = [2**i for i in range(9)]  # 1, 2, 4, 8, 16, 32, 64, 128, 256
# Used both as the UI's initial seed and as the fallback when the seed box is cleared.
DEFAULT_SEED = 42

APP_CSS = """
#col-container { margin: 0 auto; max-width: 1500px; }
#col-input-container { margin: 0 auto; max-width: 420px; }
#run-button { margin: 0 auto; }
"""

# --- Device setup ------------------------------------------------------------------

torch.set_grad_enabled(False)
if torch.cuda.is_available():
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ENABLE_BF16 = DEVICE.type == "cuda" and detect_bf16_support()

# --- Model loading -----------------------------------------------------------------


def _patch_for_hf_spaces(model):
    """Patch TorchDynamo and model for HF Spaces / ZeroGPU compatibility.

    This PyTorch version's TorchDynamo cannot represent torch.device as a
    ConstantVariable, causing torch.compile(flex_attention) to crash. The fix was
    merged into newer PyTorch; here we backport it by making
    ``ConstantVariable.is_base_literal`` also accept torch.device, so the Triton
    kernel is used correctly instead of falling back to the dense O(n²) math
    implementation. The method is patched rather than ``common_constant_types``
    because that tuple may be closed over inside ``is_base_literal``.

    We also disable block mask compilation (compile_block_mask=False) on every
    sequence-packer module, since create_block_mask uses a separate internal
    torch.compile call that would hit the same bug.

    Note: the Dynamo patch mutates a torch class globally for the whole process.
    """
    import torch._dynamo.variables.constant as _dynamo_const

    # Accessed through the class, the staticmethod resolves to the plain function.
    _orig_is_base_literal = _dynamo_const.ConstantVariable.is_base_literal

    @staticmethod
    def _patched_is_base_literal(value):
        return isinstance(value, torch.device) or _orig_is_base_literal(value)

    _dynamo_const.ConstantVariable.is_base_literal = _patched_is_base_literal

    from videoflextok.model.preprocessors.flex_seq_packing import (
        BlockWiseSequenceInterleavePacker,
        BlockWiseSequencePacker,
        BlockWiseSequencePackerWithCrossAttention,
    )

    packer_types = (
        BlockWiseSequencePacker,
        BlockWiseSequenceInterleavePacker,
        BlockWiseSequencePackerWithCrossAttention,
    )
    for module in model.modules():
        if isinstance(module, packer_types):
            module.compile_block_mask = False


def _load_model():
    """Download, cast to bf16, move to DEVICE, eval and patch the model; return it ready to use."""
    print(f"[VideoFlexTok] Loading {MODEL_ID} ...")
    model = VideoFlexTokFromHub.from_pretrained(MODEL_ID)
    model = model.to(torch.bfloat16).to(DEVICE).eval()
    _patch_for_hf_spaces(model)
    return model


# Publish the model only after every setup step has succeeded. A failure part-way
# (device move, Dynamo patch) then leaves `_model` as None, so reconstruct_video
# reports the load failure instead of running a half-initialized model.
_model = None
try:
    _model = _load_model()
    print("[VideoFlexTok] Model ready.")
except Exception as exc:
    print(f"[VideoFlexTok] FATAL: model load failed: {exc}")

# --- Inference ---------------------------------------------------------------------


def _stack_reconstructed_videos(videos, output_path: str, fps: int, ncols: int = 5):
    """Tile the reconstructions and the original into one grid video written to ``output_path``.

    Args:
        videos: ``videos[:-1]`` are the reconstructions in display order and
            ``videos[-1]`` is the original clip. Each is a tensor accepted by
            ``denormalize``. The layout below assumes (C, T, H, W), optionally with a
            leading batch dim of which only element 0 is used — TODO confirm against
            read_mp4 / detokenize. All panels must share H, W and C.
        output_path: Destination .mp4 path (libx264, yuv420p).
        fps: Frame rate of the written video.
        ncols: Panels per row. Panels fill the grid row-major; if the last row is
            short, white blank panels are inserted before the original so it always
            sits in the bottom-right cell. With 9 reconstructions and the default
            ``ncols=5`` this is the 2×5 grid: k = 1..16 on top, k = 32..256 + original
            below.

    Raises:
        ValueError: if fewer than two videos are given or ``ncols`` < 1.
    """
    if len(videos) < 2:
        raise ValueError("Expected at least one reconstruction followed by the original video.")
    if ncols < 1:
        raise ValueError(f"ncols must be >= 1, got {ncols}")

    border_px, gap_px = 8, 8

    def to_uint8_frames(video_tensor):
        # Drop a leading batch dim, then move channels last: (C, T, H, W) -> (T, H, W, C).
        if video_tensor.ndim == 5:
            video_tensor = video_tensor[0]
        frames = denormalize(video_tensor).permute(1, 2, 3, 0).contiguous().numpy()
        return (np.clip(frames, 0.0, 1.0) * 255).round().astype(np.uint8)

    def add_border(frames: np.ndarray, color: int) -> np.ndarray:
        # Pad only the spatial axes (H, W) of a (T, H, W, C) array.
        return np.pad(
            frames,
            ((0, 0), (border_px, border_px), (border_px, border_px), (0, 0)),
            mode="constant",
            constant_values=color,
        )

    def concat_with_gaps(arrays: list[np.ndarray], axis: int, gap_shape: tuple) -> np.ndarray:
        # Concatenate along `axis`, separating neighbours with a white spacer.
        gap = np.full(gap_shape, 255, dtype=np.uint8)
        pieces = []
        for i, array in enumerate(arrays):
            if i:
                pieces.append(gap)
            pieces.append(array)
        return np.concatenate(pieces, axis=axis)

    # Reconstructions get a white border; the original a black one to set it apart.
    panels = [add_border(to_uint8_frames(v), 255) for v in videos[:-1]]
    original = add_border(to_uint8_frames(videos[-1]), 0)

    # Clips may differ in length; trim every panel to the shortest one.
    total_frames = min(p.shape[0] for p in [*panels, original])
    panels = [p[:total_frames] for p in panels]
    original = original[:total_frames]

    # Blank cells needed to complete the last row (0 for the default 10 panels / 5 columns).
    n_blank = -(len(panels) + 1) % ncols
    panels += [np.full_like(original, 255)] * n_blank
    panels.append(original)

    # Build each row for all frames at once: (T, H, W, C) panels joined along width.
    t, h, _, c = original.shape
    rows = [
        concat_with_gaps(panels[i:i + ncols], axis=2, gap_shape=(t, h, gap_px, c))
        for i in range(0, len(panels), ncols)
    ]
    grid = concat_with_gaps(rows, axis=1, gap_shape=(t, gap_px, rows[0].shape[2], c))

    iio.imwrite(
        output_path,
        grid,
        fps=fps,
        plugin="FFMPEG",
        codec="libx264",
        pixelformat="yuv420p",
    )


def reconstruct_video(video_path: str, input_fps: int, timesteps: int, guidance_scale: float, seed: int):
    """Autoencode ``video_path`` and return ``(grid_video_path, info_markdown)``.

    The clip is decoded at ``input_fps``, tokenized once, then detokenized with
    ``num_keep_tokens_list=NUM_KEEP_TOKENS`` (one reconstruction per k). The
    reconstructions and the original are tiled into a grid video in a temp file,
    which is not deleted so Gradio can serve it.

    Raises:
        gr.Error: if no video was uploaded, the model failed to load, or decoding,
            inference or output encoding fails.
    """
    if not video_path or not Path(video_path).exists():
        raise gr.Error("Upload a video first.")
    if _model is None:
        raise gr.Error("Model failed to load at startup — check Space logs.")

    try:
        fps = int(input_fps)
        preprocess_args = dict(_model.video_preprocess_args)
        # Public package uses 'overlap_size'; model config key is 'overlap_size_frames'
        if "overlap_size_frames" in preprocess_args and "overlap_size" not in preprocess_args:
            preprocess_args["overlap_size"] = preprocess_args.pop("overlap_size_frames")
        video_tensor = read_mp4(str(video_path), fps=fps, **preprocess_args)
    except Exception as exc:
        raise gr.Error(f"Failed to decode video: {exc}") from exc

    try:
        # A cleared gr.Number arrives as None; fall back to the UI default seed.
        seed = DEFAULT_SEED if seed is None else int(seed)
        with get_bf16_context(ENABLE_BF16, device_type=DEVICE.type):
            print(f"[VideoFlexTok] Tokenizing {video_tensor.shape} ...")
            token_ids = _model.tokenize(video_tensor[None].to(DEVICE))
            print(f"[VideoFlexTok] Decoding {len(NUM_KEEP_TOKENS)} reconstructions ...")
            reconstructed = _model.detokenize(
                [token_ids[0]] * len(NUM_KEEP_TOKENS),
                num_keep_tokens_list=NUM_KEEP_TOKENS,
                timesteps=int(timesteps),
                guidance_scale=float(guidance_scale),
                perform_norm_guidance=True,
                generator=torch.Generator(device=DEVICE.type).manual_seed(seed),
                eta=0.0,
                momentum=0.0,
                norm_threshold=0.6,
                verbose=False,
            )
        reconstructed = [v.cpu().float() for v in reconstructed]
        print("[VideoFlexTok] Inference complete.")
    except Exception as exc:
        raise gr.Error(f"Model inference failed: {exc}") from exc

    tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
    tmp.close()
    try:
        _stack_reconstructed_videos(reconstructed + [video_tensor], output_path=tmp.name, fps=fps)
    except Exception as exc:
        # Don't leave an empty or partial temp file behind when encoding fails.
        Path(tmp.name).unlink(missing_ok=True)
        raise gr.Error(f"Failed to write output video: {exc}") from exc

    info = f"Extracted {video_tensor.shape[1]} frames at {fps} FPS"
    return tmp.name, info


# On ZeroGPU, run inference inside a GPU allocation; elsewhere call it directly.
if spaces is not None and hasattr(spaces, "GPU"):
    reconstruct_video = spaces.GPU(duration=60)(reconstruct_video)

# --- UI ----------------------------------------------------------------------------

with gr.Blocks(title="VideoFlexTok Demo", theme=gr.themes.Base(), css=APP_CSS) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# VideoFlexTok: Flexible-Length Coarse-to-Fine Video Tokenization")
        with gr.Row():
            with gr.Column(scale=1, elem_id="col-input-container"):
                gr.Markdown(f"""
                [`Website`](https://videoflextok.epfl.ch) | [`Paper`](https://arxiv.org/abs/2604.12887) | [`GitHub`](https://github.com/apple/ml-videoflextok) | [`Model`](https://huggingface.co/EPFL-VILAB/videoflextok_d18_d28)

                Research demo for **VideoFlexTok: Flexible-Length Coarse-to-Fine Video Tokenization** (arXiv 2026).
                Autoencodes your video with `{MODEL_ID}` and shows coarse-to-fine reconstructions.

                VideoFlexTok tokenizes video into `T × 256` tokens ordered coarse-to-fine; this demo shows reconstructions from `T × k` tokens for k ∈ `{NUM_KEEP_TOKENS}`.
                Bottom-right is the original.
                """)
                input_video = gr.Video(
                    label="Input video",
                    sources=["upload"],
                    format="mp4",
                )
                run_button = gr.Button("Autoencode with VideoFlexTok", elem_id="run-button")

                if EXAMPLE_VIDEOS:
                    gr.Examples(
                        examples=[str(p) for p in EXAMPLE_VIDEOS],
                        inputs=[input_video],
                        outputs=[input_video],
                        fn=lambda p: p,
                        cache_examples=True,
                        label="Example videos",
                    )

                with gr.Accordion("Advanced Settings", open=False):
                    gr.Markdown("Adjust target FPS to control how many frames are extracted.")
                    input_fps = gr.Slider(minimum=1, maximum=16, value=8, step=1, label="Target FPS")
                    timesteps = gr.Slider(minimum=1, maximum=60, value=20, step=1, label="Denoising steps")
                    guidance_scale = gr.Slider(minimum=1.0, maximum=30.0, value=25.0, step=0.5, label="Guidance scale")
                    seed = gr.Number(value=DEFAULT_SEED, precision=0, label="Seed")

            with gr.Column(scale=4):
                output_video = gr.Video(label="Reconstructions")
                status = gr.Markdown()

        run_button.click(
            fn=reconstruct_video,
            inputs=[input_video, input_fps, timesteps, guidance_scale, seed],
            outputs=[output_video, status],
        )

        if DEVICE.type != "cuda":
            gr.Markdown("Running on CPU — inference will be slow.")

# --- Launch ------------------------------------------------------------------------

demo.queue(max_size=16)

if __name__ == "__main__":
    server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
    launch_kwargs = {"server_name": server_name, "ssr_mode": False}
    if port := os.environ.get("GRADIO_SERVER_PORT"):
        launch_kwargs["server_port"] = int(port)
    # Allow Gradio to serve example files and the temp-dir output videos.
    launch_kwargs["allowed_paths"] = [str(APP_DIR), tempfile.gettempdir()]
    demo.launch(**launch_kwargs)