Spaces:
Running on Zero
Running on Zero
| import importlib | |
| import os | |
| import subprocess | |
| import sys | |
| import tempfile | |
| from pathlib import Path | |
# Install videoflextok without its deps to avoid huggingface_hub==0.25.2 conflicting
# with gradio's >=0.33.5 requirement. Compatible dep versions are in requirements.txt.
def _install_videoflextok():
    """Make ``videoflextok`` importable, pip-installing it from GitHub if it is missing.

    Returns immediately when ``import videoflextok`` already succeeds. Otherwise runs
    ``pip install --quiet --no-deps`` with the current interpreter (``check=True``, so a
    failed install raises ``subprocess.CalledProcessError``) and then clears the import
    system's finder caches so the freshly installed package can be imported afterwards.
    """
    try:
        import videoflextok  # noqa: F401
    except ImportError:
        pass
    else:
        return  # already installed; nothing to do

    print("[VideoFlexTok] Installing videoflextok (--no-deps) ...")
    pip_install_cmd = [
        sys.executable, "-m", "pip", "install", "--quiet", "--no-deps",
        "git+https://github.com/apple/ml-videoflextok.git",
    ]
    subprocess.run(pip_install_cmd, check=True)
    # Packages installed while the interpreter is running may be missed by cached
    # path finders until their caches are invalidated.
    importlib.invalidate_caches()
# Runs at import time, before the `videoflextok` imports that follow, so those
# imports can succeed on a fresh environment where the package is not yet installed.
_install_videoflextok()
| import spaces | |
| import gradio as gr | |
| import imageio.v3 as iio | |
| import numpy as np | |
| import torch | |
| from videoflextok.utils.demo import denormalize, read_mp4 | |
| from videoflextok.utils.misc import detect_bf16_support, get_bf16_context | |
| from videoflextok.wrappers import VideoFlexTokFromHub | |
# --- Constants ---------------------------------------------------------------------
# Checkpoint repo passed to VideoFlexTokFromHub.from_pretrained.
MODEL_ID = "EPFL-VILAB/videoflextok_d18_d28"
APP_DIR = Path(__file__).resolve().parent
EXAMPLES_DIR = APP_DIR.joinpath("examples")
# Example clips shown in the UI, in sorted path order (empty if none are found).
EXAMPLE_VIDEOS = sorted(EXAMPLES_DIR.glob("*.mp4"))
# Values of k for the `T × k`-token reconstructions: powers of two, 1 … 256.
NUM_KEEP_TOKENS = [1 << exponent for exponent in range(9)]  # 1, 2, 4, 8, 16, 32, 64, 128, 256
# Layout CSS: centre the page / input column and cap their widths.
APP_CSS = """
#col-container {
margin: 0 auto;
max-width: 1500px;
}
#col-input-container {
margin: 0 auto;
max-width: 420px;
}
#run-button {
margin: 0 auto;
}
"""
# --- Device setup ------------------------------------------------------------------
# Inference-only app: disable autograd. Note that grad mode is thread-local, so this
# applies to the importing thread.
torch.set_grad_enabled(False)
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
if DEVICE.type == "cuda":
    # Permit TF32 in float32 matmuls and cuDNN convolutions.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
# bf16 autocast is used only on CUDA, and only if detect_bf16_support() says so
# (short-circuits on CPU, so detect_bf16_support is not called there).
ENABLE_BF16 = DEVICE.type == "cuda" and detect_bf16_support()
# --- Model loading -----------------------------------------------------------------
def _patch_for_hf_spaces(model):
    """Patch TorchDynamo and ``model`` for HF Spaces / ZeroGPU compatibility.

    This PyTorch version's TorchDynamo cannot represent torch.device as a ConstantVariable,
    causing torch.compile(flex_attention) to crash. The fix was merged into newer PyTorch;
    here we backport it by making ``ConstantVariable.is_base_literal`` also accept
    ``torch.device`` values, so the Triton kernel is used correctly instead of falling back
    to the dense O(n²) math implementation.

    We also disable block mask compilation (compile_block_mask=False) since create_block_mask
    uses a separate internal torch.compile call that would hit the same bug.

    The Dynamo patch is process-global and installed at most once; calling this function
    again does not stack another wrapper.

    Args:
        model: Module whose videoflextok sequence-packer submodules get
            ``compile_block_mask = False``. Modified in place; returns None.
    """
    # Patch TorchDynamo to accept torch.device as a ConstantVariable.
    # common_constant_types may be closed over in is_base_literal, so patch the method directly.
    import torch._dynamo.variables.constant as _dynamo_const

    const_var_cls = _dynamo_const.ConstantVariable
    orig_is_base_literal = const_var_cls.is_base_literal
    if not getattr(orig_is_base_literal, "_accepts_torch_device", False):

        def _patched_is_base_literal(value):
            return isinstance(value, torch.device) or orig_is_base_literal(value)

        # Marker checked above so a repeat call doesn't wrap the wrapper.
        _patched_is_base_literal._accepts_torch_device = True
        # Store as a staticmethod: a plain function assigned to a class attribute would be
        # bound as a method when looked up through an instance, passing the instance as
        # `value`. (Presumably the original is a @staticmethod — confirm per PyTorch version.)
        const_var_cls.is_base_literal = staticmethod(_patched_is_base_literal)

    from videoflextok.model.preprocessors.flex_seq_packing import (
        BlockWiseSequencePacker,
        BlockWiseSequenceInterleavePacker,
        BlockWiseSequencePackerWithCrossAttention,
    )

    packer_types = (
        BlockWiseSequencePacker,
        BlockWiseSequenceInterleavePacker,
        BlockWiseSequencePackerWithCrossAttention,
    )
    for module in model.modules():
        if isinstance(module, packer_types):
            module.compile_block_mask = False
_model = None


def _load_model():
    """Load ``MODEL_ID``, cast it to bfloat16, move it to ``DEVICE`` and apply the Spaces patches.

    Returns:
        The model in eval mode. Any failure propagates to the caller.
    """
    print(f"[VideoFlexTok] Loading {MODEL_ID} ...")
    model = VideoFlexTokFromHub.from_pretrained(MODEL_ID)
    model = model.to(torch.bfloat16).to(DEVICE).eval()
    _patch_for_hf_spaces(model)
    return model


try:
    # Publish the model only once every step has succeeded. Otherwise a failure part-way
    # (e.g. while moving to the device or patching) would leave a half-initialized model
    # in `_model` that reconstruct_video would go on to use.
    _model = _load_model()
    print("[VideoFlexTok] Model ready.")
except Exception as exc:
    # Best effort: keep the UI up (reconstruct_video reports a missing model to users),
    # but log the full traceback rather than only the exception message.
    import traceback

    print(f"[VideoFlexTok] FATAL: model load failed: {exc}")
    traceback.print_exc()
# --- Inference ---------------------------------------------------------------------
def _stack_reconstructed_videos(videos, output_path: str, fps: int):
    """Tile 9 reconstructions plus the original into a 2×5 grid video at ``output_path``.

    ``videos[:9]`` are the reconstructions, placed row-major five per row, and
    ``videos[9]`` is the original clip, which fills the bottom-right cell.
    Reconstructions get an 8 px white border and the original an 8 px black one; cells
    are separated by 8 px white gaps. Every clip is truncated to the shortest clip's
    frame count, and the grid is encoded with libx264 / yuv420p at ``fps``.
    """
    pad_px = 8
    spacing_px = 8

    def _to_frames(clip):
        # Drop a leading batch axis, denormalize, move axis 0 to the end, and quantize
        # to uint8. (Presumably (C, T, H, W) -> (T, H, W, C); confirm against read_mp4.)
        if clip.ndim == 5:
            clip = clip[0]
        pixels = denormalize(clip).permute(1, 2, 3, 0).contiguous().numpy()
        scaled = np.clip(pixels, 0.0, 1.0) * 255
        return scaled.round().astype(np.uint8)

    def _with_frame(frames, fill):
        # Pad the two spatial axes only; time and channel axes are left untouched.
        spatial = (pad_px, pad_px)
        return np.pad(
            frames,
            ((0, 0), spatial, spatial, (0, 0)),
            mode="constant", constant_values=fill,
        )

    cells = [_with_frame(_to_frames(clip), 255) for clip in videos[:9]]
    cells.append(_with_frame(_to_frames(videos[9]), 0))
    n_frames = min(cell.shape[0] for cell in cells)
    cells = [cell[:n_frames] for cell in cells]
    top_cells, bottom_cells = cells[:5], cells[5:]  # k = 1..16 | k = 32..256 + original

    def _join_row(row, t):
        # Frame t of each cell, left to right, with white gap columns in between.
        gap = np.full((row[0].shape[1], spacing_px, 3), 255, dtype=np.uint8)
        pieces = [row[0][t]]
        for cell in row[1:]:
            pieces.extend((gap, cell[t]))
        return np.concatenate(pieces, axis=1)

    def _grid_frame(t):
        top = _join_row(top_cells, t)
        bottom = _join_row(bottom_cells, t)
        divider = np.full((spacing_px, top.shape[1], 3), 255, dtype=np.uint8)
        return np.concatenate([top, divider, bottom], axis=0)

    grid = np.stack([_grid_frame(t) for t in range(n_frames)], axis=0)
    iio.imwrite(
        output_path, grid,
        fps=fps, plugin="FFMPEG", codec="libx264", pixelformat="yuv420p",
    )
def reconstruct_video(video_path: str, input_fps: int, timesteps: int, guidance_scale: float, seed: int):
    """Autoencode an uploaded clip and return a grid video of coarse-to-fine reconstructions.

    The clip is read with ``read_mp4`` at ``input_fps`` using the model's own preprocessing
    arguments, tokenized once, and detokenized once per entry of ``NUM_KEEP_TOKENS``
    (passed as ``num_keep_tokens_list``). The reconstructions plus the original are tiled
    into one MP4 by ``_stack_reconstructed_videos``.

    Args:
        video_path: Path of the uploaded MP4; empty/None when nothing was uploaded.
        input_fps: Frame rate at which frames are extracted (converted with ``int``).
        timesteps: Number of denoising steps passed to ``detokenize``.
        guidance_scale: Guidance scale passed to ``detokenize``.
        seed: Seed for the decoding RNG; ``None`` when the Number field was cleared.

    Returns:
        ``(output_path, info)``: path of a temporary grid MP4 and a status string.

    Raises:
        gr.Error: on missing input, missing model, missing seed, or any failure while
            decoding the video, running the model, or writing the output video.
    """
    if not video_path or not Path(video_path).exists():
        raise gr.Error("Upload a video first.")
    if _model is None:
        raise gr.Error("Model failed to load at startup — check Space logs.")
    # gr.Number yields None when the field is cleared; report that directly instead of
    # letting int(None) surface later as a misleading "Model inference failed" error.
    if seed is None:
        raise gr.Error("Enter a seed.")
    fps = int(input_fps)

    try:
        preprocess_args = dict(_model.video_preprocess_args)
        # Public package uses 'overlap_size'; model config key is 'overlap_size_frames'
        if "overlap_size_frames" in preprocess_args and "overlap_size" not in preprocess_args:
            preprocess_args["overlap_size"] = preprocess_args.pop("overlap_size_frames")
        video_tensor = read_mp4(str(video_path), fps=fps, **preprocess_args)
    except Exception as exc:
        raise gr.Error(f"Failed to decode video: {exc}") from exc

    try:
        # Grad mode is thread-local: the module-level torch.set_grad_enabled(False) may not
        # cover the thread Gradio runs this handler in, so disable autograd explicitly.
        with torch.no_grad(), get_bf16_context(ENABLE_BF16, device_type=DEVICE.type):
            print(f"[VideoFlexTok] Tokenizing {video_tensor.shape} ...")
            token_ids = _model.tokenize(video_tensor[None].to(DEVICE))
            print(f"[VideoFlexTok] Decoding {len(NUM_KEEP_TOKENS)} reconstructions ...")
            reconstructed = _model.detokenize(
                [token_ids[0]] * len(NUM_KEEP_TOKENS),
                num_keep_tokens_list=NUM_KEEP_TOKENS,
                timesteps=int(timesteps),
                guidance_scale=float(guidance_scale),
                perform_norm_guidance=True,
                generator=torch.Generator(device=DEVICE.type).manual_seed(int(seed)),
                eta=0.0, momentum=0.0, norm_threshold=0.6, verbose=False,
            )
        reconstructed = [v.cpu().float() for v in reconstructed]
        print("[VideoFlexTok] Inference complete.")
    except Exception as exc:
        raise gr.Error(f"Model inference failed: {exc}") from exc

    # delete=False: the file must outlive this function so Gradio can serve it.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        output_path = tmp.name
    try:
        _stack_reconstructed_videos(reconstructed + [video_tensor], output_path=output_path, fps=fps)
    except Exception as exc:
        # Don't leave an empty or partial file behind in the temp directory.
        Path(output_path).unlink(missing_ok=True)
        raise gr.Error(f"Failed to write output video: {exc}") from exc

    # Assumes axis 1 of video_tensor is time, matching the permute in _stack_reconstructed_videos.
    info = f"Extracted {video_tensor.shape[1]} frames at {fps} FPS"
    return output_path, info
# ZeroGPU: when the `spaces` package exposes a GPU decorator, wrap the handler with a
# 60-second per-call GPU duration budget.
_zero_gpu_available = spaces is not None and hasattr(spaces, "GPU")
if _zero_gpu_available:
    reconstruct_video = spaces.GPU(duration=60)(reconstruct_video)
# --- UI ----------------------------------------------------------------------------
# Two-column layout: description, input, examples and advanced settings on the left;
# the reconstruction grid and a status line on the right. The order in which
# components are created inside each `with` block determines the on-page layout.
with gr.Blocks(title="VideoFlexTok Demo", theme=gr.themes.Base(), css=APP_CSS) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# VideoFlexTok: Flexible-Length Coarse-to-Fine Video Tokenization")
        with gr.Row():
            # Left (input) column, width-capped by the #col-input-container CSS rule.
            with gr.Column(scale=1, elem_id="col-input-container"):
                gr.Markdown(f"""
[`Website`](https://videoflextok.epfl.ch) | [`Paper`](https://arxiv.org/abs/2604.12887) | [`GitHub`](https://github.com/apple/ml-videoflextok) | [`Model`](https://huggingface.co/EPFL-VILAB/videoflextok_d18_d28)
Research demo for **VideoFlexTok: Flexible-Length Coarse-to-Fine Video Tokenization** (arXiv 2026).
Autoencodes your video with `{MODEL_ID}` and shows coarse-to-fine reconstructions.
VideoFlexTok tokenizes video into `T × 256` tokens ordered coarse-to-fine; this demo shows
reconstructions from `T × k` tokens for k ∈ `{NUM_KEEP_TOKENS}`. Bottom-right is the original.
""")
                input_video = gr.Video(
                    label="Input video", sources=["upload"], format="mp4",
                )
                run_button = gr.Button("Autoencode with VideoFlexTok", elem_id="run-button")
                if EXAMPLE_VIDEOS:
                    # Identity fn: selecting an example just loads that file into the input
                    # widget (with cache_examples=True, results are cached ahead of time).
                    gr.Examples(
                        examples=[str(p) for p in EXAMPLE_VIDEOS],
                        inputs=[input_video],
                        outputs=[input_video],
                        fn=lambda p: p,
                        cache_examples=True,
                        label="Example videos",
                    )
                # Decoding knobs forwarded to reconstruct_video.
                with gr.Accordion("Advanced Settings", open=False):
                    gr.Markdown("Adjust target FPS to control how many frames are extracted.")
                    input_fps = gr.Slider(minimum=1, maximum=16, value=8, step=1, label="Target FPS")
                    timesteps = gr.Slider(minimum=1, maximum=60, value=20, step=1, label="Denoising steps")
                    guidance_scale = gr.Slider(minimum=1.0, maximum=30.0, value=25.0, step=0.5, label="Guidance scale")
                    seed = gr.Number(value=42, precision=0, label="Seed")
            # Right (output) column, scale=4 relative to the left column's scale=1.
            with gr.Column(scale=4):
                output_video = gr.Video(label="Reconstructions")
                status = gr.Markdown()
    # Input order must match reconstruct_video(video_path, input_fps, timesteps,
    # guidance_scale, seed); outputs match its (output_path, info) return value.
    run_button.click(
        fn=reconstruct_video,
        inputs=[input_video, input_fps, timesteps, guidance_scale, seed],
        outputs=[output_video, status],
    )
    if DEVICE.type != "cuda":
        gr.Markdown("Running on CPU — inference will be slow.")
# --- Launch ------------------------------------------------------------------------
# Bound the number of queued requests; excess requests are rejected by Gradio's queue.
demo.queue(max_size=16)
if __name__ == "__main__":
    # Host defaults to all interfaces; the port is only set when GRADIO_SERVER_PORT is non-empty.
    launch_kwargs = {
        "server_name": os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0"),
        "ssr_mode": False,
    }
    port = os.environ.get("GRADIO_SERVER_PORT")
    if port:
        launch_kwargs["server_port"] = int(port)
    # NOTE(review): allowed_paths lets Gradio serve files under these directories; the whole
    # system temp dir may be broader than needed for the returned MP4s — confirm it's required.
    launch_kwargs["allowed_paths"] = [str(APP_DIR), tempfile.gettempdir()]
    demo.launch(**launch_kwargs)