# VideoFlexTok / app.py
# andreiatanov's picture
# Update app.py
# 26b94a3 verified
import importlib
import os
import subprocess
import sys
import tempfile
from pathlib import Path
# Install videoflextok without its deps to avoid huggingface_hub==0.25.2 conflicting
# with gradio's >=0.33.5 requirement. Compatible dep versions are in requirements.txt.
def _install_videoflextok():
    """Make sure ``videoflextok`` is importable, installing it on demand.

    If the package already imports, nothing happens. Otherwise it is installed
    from its GitHub repository with ``pip install --no-deps`` using the current
    interpreter, and the import caches are invalidated so that a subsequent
    ``import videoflextok`` can find the freshly installed package.

    Raises:
        subprocess.CalledProcessError: If the pip command exits non-zero.
    """
    try:
        import videoflextok  # noqa: F401
    except ImportError:
        print("[VideoFlexTok] Installing videoflextok (--no-deps) ...")
        pip_command = [
            sys.executable, "-m", "pip", "install", "--quiet", "--no-deps",
            "git+https://github.com/apple/ml-videoflextok.git",
        ]
        subprocess.run(pip_command, check=True)
        # Let the import system notice the newly written package files.
        importlib.invalidate_caches()
# Run at import time, before the `from videoflextok...` imports further down,
# so the package is available by the time they execute.
_install_videoflextok()
import spaces
import gradio as gr
import imageio.v3 as iio
import numpy as np
import torch
from videoflextok.utils.demo import denormalize, read_mp4
from videoflextok.utils.misc import detect_bf16_support, get_bf16_context
from videoflextok.wrappers import VideoFlexTokFromHub
# --- Constants ---------------------------------------------------------------------
MODEL_ID = "EPFL-VILAB/videoflextok_d18_d28"

# Example clips are discovered in an `examples/` folder next to this file.
APP_DIR = Path(__file__).resolve().parent
EXAMPLES_DIR = APP_DIR.joinpath("examples")
EXAMPLE_VIDEOS = sorted(EXAMPLES_DIR.glob("*.mp4"))

# Token budgets shown in the demo: the powers of two 1, 2, 4, ..., 256.
NUM_KEEP_TOKENS = [1 << exponent for exponent in range(9)]

# Stylesheet handed to gr.Blocks; selectors match the elem_id values used in the UI.
APP_CSS = """
#col-container {
margin: 0 auto;
max-width: 1500px;
}
#col-input-container {
margin: 0 auto;
max-width: 420px;
}
#run-button {
margin: 0 auto;
}
"""
# --- Device setup ------------------------------------------------------------------
# Inference only: turn autograd off globally.
torch.set_grad_enabled(False)

# Prefer CUDA when PyTorch reports it as available, otherwise run on CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if DEVICE.type == "cuda":
    # Allow TF32 for matmuls and cuDNN ops on the GPU.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

# bf16 is only considered on CUDA; detect_bf16_support() is not called on CPU.
ENABLE_BF16 = detect_bf16_support() if DEVICE.type == "cuda" else False
# --- Model loading -----------------------------------------------------------------
def _patch_for_hf_spaces(model):
    """Apply the TorchDynamo and block-mask workarounds for HF Spaces / ZeroGPU.

    Two things happen, in this order:

    1. ``ConstantVariable.is_base_literal`` (in
       ``torch._dynamo.variables.constant``) is replaced by a wrapper that
       returns True for ``torch.device`` values and otherwise defers to the
       previous implementation. Per the original author's notes, the PyTorch
       version used on Spaces cannot represent ``torch.device`` as a
       ConstantVariable, which makes ``torch.compile(flex_attention)`` crash;
       the fix exists in newer PyTorch and is backported here so the Triton
       kernel is used instead of the dense O(n²) math fallback. The method is
       patched directly (rather than ``common_constant_types``) because that
       collection may be closed over inside ``is_base_literal``.
    2. ``compile_block_mask`` is set to False on every block-wise sequence
       packer submodule of ``model``, since (again per the author's notes)
       ``create_block_mask`` uses its own internal ``torch.compile`` call that
       would hit the same bug.

    The TorchDynamo patch is process-wide and is not undone. Each call wraps
    whatever ``is_base_literal`` currently is, so repeated calls stack wrappers.

    Args:
        model: Module whose ``modules()`` are scanned for packer instances.
    """
    import torch._dynamo.variables.constant as dynamo_constant

    constant_variable = dynamo_constant.ConstantVariable
    stock_is_base_literal = constant_variable.is_base_literal

    def device_aware_is_base_literal(value):
        if isinstance(value, torch.device):
            return True
        return stock_is_base_literal(value)

    constant_variable.is_base_literal = staticmethod(device_aware_is_base_literal)

    # Imported only after the TorchDynamo patch is in place, as in the original flow.
    from videoflextok.model.preprocessors.flex_seq_packing import (
        BlockWiseSequencePacker,
        BlockWiseSequenceInterleavePacker,
        BlockWiseSequencePackerWithCrossAttention,
    )

    packer_types = (
        BlockWiseSequencePacker,
        BlockWiseSequenceInterleavePacker,
        BlockWiseSequencePackerWithCrossAttention,
    )
    for submodule in model.modules():
        if not isinstance(submodule, packer_types):
            continue
        submodule.compile_block_mask = False
def _load_model():
    """Load the tokenizer model, returning it fully prepared or None on failure.

    The model is fetched with ``VideoFlexTokFromHub.from_pretrained(MODEL_ID)``,
    cast to bfloat16, moved to ``DEVICE`` and put in eval mode. Only a model that
    completed all of those steps is returned, so callers checking for ``None``
    never see a half-initialised object (previously a failure during the
    cast/device move left the partially prepared model bound to ``_model``
    while the log reported a fatal load failure).

    A failure in ``_patch_for_hf_spaces`` is reported separately as a warning
    and the loaded model is still returned, matching the previous effective
    behaviour in that case but with an accurate log message.

    Returns:
        The prepared model, or ``None`` if loading or preparing it failed.
    """
    print(f"[VideoFlexTok] Loading {MODEL_ID} ...")
    try:
        model = VideoFlexTokFromHub.from_pretrained(MODEL_ID)
        # NOTE(review): the cast to bfloat16 is unconditional, while the autocast
        # context used at inference depends on ENABLE_BF16 — confirm this is
        # intended for the CPU path.
        model = model.to(torch.bfloat16).to(DEVICE).eval()
    except Exception as exc:
        # Top-level boundary: keep the app importable so the UI can report the problem.
        print(f"[VideoFlexTok] FATAL: model load failed: {exc}")
        return None

    try:
        _patch_for_hf_spaces(model)
    except Exception as exc:
        print(f"[VideoFlexTok] WARNING: HF Spaces compatibility patch failed: {exc}")

    print("[VideoFlexTok] Model ready.")
    return model


# None means the model is unavailable; reconstruct_video checks for this.
_model = _load_model()
# --- Inference ---------------------------------------------------------------------
def _stack_reconstructed_videos(videos, output_path: str, fps: int):
    """Tile nine reconstructions and the original into a 2×5 grid video.

    Args:
        videos: Sequence of at least ten video tensors. Entries 0-8 are treated
            as reconstructions (white border) and entry 9 as the original
            (black border). A 5-D tensor has its leading axis dropped by
            taking index 0.
        output_path: Destination path passed to ``iio.imwrite``.
        fps: Frame rate passed to the encoder.

    The grid's first row holds panels 0-4 and the second row panels 5-9. All
    panels are truncated to the shortest panel's frame count.
    """
    border_px = gap_px = 8
    white, black = 255, 0

    def as_uint8(clip):
        # After denormalizing, axis 0 is moved last so frames are indexed first;
        # the padding and gap shapes below treat the result as (T, H, W, 3).
        clip = clip[0] if clip.ndim == 5 else clip
        pixels = denormalize(clip).permute(1, 2, 3, 0).contiguous().numpy()
        scaled = np.clip(pixels, 0.0, 1.0) * 255
        return scaled.round().astype(np.uint8)

    def framed(clip_frames: np.ndarray, shade: int) -> np.ndarray:
        # Pad only the two spatial axes; frame and channel axes are untouched.
        spatial_pad = (border_px, border_px)
        return np.pad(
            clip_frames,
            ((0, 0), spatial_pad, spatial_pad, (0, 0)),
            mode="constant", constant_values=shade,
        )

    def tile_row(panels: list[np.ndarray], t: int) -> np.ndarray:
        # Panels side by side at frame t, separated by white spacer columns.
        spacer = np.full((panels[0].shape[1], gap_px, 3), white, dtype=np.uint8)
        pieces = [panels[0][t]]
        for panel in panels[1:]:
            pieces += [spacer, panel[t]]
        return np.concatenate(pieces, axis=1)

    panels = [framed(as_uint8(video), white) for video in videos[:9]]
    panels.append(framed(as_uint8(videos[9]), black))

    frame_count = min(panel.shape[0] for panel in panels)
    panels = [panel[:frame_count] for panel in panels]
    top_row = panels[:5]     # k = 1, 2, 4, 8, 16
    bottom_row = panels[5:]  # k = 32, 64, 128, 256, Original

    def tile_frame(t: int) -> np.ndarray:
        upper = tile_row(top_row, t)
        lower = tile_row(bottom_row, t)
        divider = np.full((gap_px, upper.shape[1], 3), white, dtype=np.uint8)
        return np.concatenate([upper, divider, lower], axis=0)

    grid = np.stack([tile_frame(t) for t in range(frame_count)], axis=0)
    iio.imwrite(
        output_path, grid,
        fps=fps, plugin="FFMPEG", codec="libx264", pixelformat="yuv420p",
    )
def reconstruct_video(video_path: str, input_fps: int, timesteps: int, guidance_scale: float, seed: int):
    """Autoencode a video and render reconstructions for every token budget.

    The video is decoded with ``read_mp4``, tokenized once, detokenized once per
    entry of ``NUM_KEEP_TOKENS``, and the results plus the decoded input are
    composed into a single grid video by ``_stack_reconstructed_videos``.

    Args:
        video_path: Path to the input video file.
        input_fps: Frame rate used when decoding the input and encoding the output.
        timesteps: Number of denoising steps passed to ``detokenize``.
        guidance_scale: Guidance scale passed to ``detokenize``.
        seed: Seed for the ``torch.Generator`` passed to ``detokenize``.

    Returns:
        A ``(output_path, info)`` tuple: the path of the composed video (a
        temporary file that is intentionally not deleted here, since the path is
        handed back to the caller) and a short status string.

    Raises:
        gr.Error: If no video was provided, the model is unavailable, or
            decoding, inference, or writing the output video fails.
    """
    if not video_path or not Path(video_path).exists():
        raise gr.Error("Upload a video first.")
    if _model is None:
        raise gr.Error("Model failed to load at startup — check Space logs.")

    try:
        preprocess_args = dict(_model.video_preprocess_args)
        # Public package uses 'overlap_size'; model config key is 'overlap_size_frames'
        if "overlap_size_frames" in preprocess_args and "overlap_size" not in preprocess_args:
            preprocess_args["overlap_size"] = preprocess_args.pop("overlap_size_frames")
        video_tensor = read_mp4(str(video_path), fps=int(input_fps), **preprocess_args)
    except Exception as exc:
        raise gr.Error(f"Failed to decode video: {exc}") from exc

    try:
        with get_bf16_context(ENABLE_BF16, device_type=DEVICE.type):
            print(f"[VideoFlexTok] Tokenizing {video_tensor.shape} ...")
            token_ids = _model.tokenize(video_tensor[None].to(DEVICE))
            print(f"[VideoFlexTok] Decoding {len(NUM_KEEP_TOKENS)} reconstructions ...")
            # The same token sequence is decoded once per budget in NUM_KEEP_TOKENS.
            reconstructed = _model.detokenize(
                [token_ids[0]] * len(NUM_KEEP_TOKENS),
                num_keep_tokens_list=NUM_KEEP_TOKENS,
                timesteps=int(timesteps),
                guidance_scale=float(guidance_scale),
                perform_norm_guidance=True,
                generator=torch.Generator(device=DEVICE.type).manual_seed(int(seed)),
                eta=0.0, momentum=0.0, norm_threshold=0.6, verbose=False,
            )
        # Bring results to CPU float32 for the NumPy-based composition step.
        reconstructed = [v.cpu().float() for v in reconstructed]
        print("[VideoFlexTok] Inference complete.")
    except Exception as exc:
        raise gr.Error(f"Model inference failed: {exc}") from exc

    # Reserve a unique output path; delete=False because the file must outlive
    # this call (its path is returned).
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        output_path = tmp.name
    try:
        _stack_reconstructed_videos(
            reconstructed + [video_tensor], output_path=output_path, fps=int(input_fps)
        )
    except Exception as exc:
        # Don't leak the reserved temp file when composing/encoding fails, and
        # surface the failure the same way as the other stages.
        Path(output_path).unlink(missing_ok=True)
        raise gr.Error(f"Failed to write output video: {exc}") from exc

    info = f"Extracted {video_tensor.shape[1]} frames at {input_fps} FPS"
    return output_path, info


# Wrap the handler with the `spaces.GPU` decorator (60 s duration) when available.
if spaces is not None and hasattr(spaces, "GPU"):
    reconstruct_video = spaces.GPU(duration=60)(reconstruct_video)
# --- UI ----------------------------------------------------------------------------
# Build the Gradio interface. Components are created in layout order, so the
# statement order below determines what appears where on the page.
with gr.Blocks(title="VideoFlexTok Demo", theme=gr.themes.Base(), css=APP_CSS) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# VideoFlexTok: Flexible-Length Coarse-to-Fine Video Tokenization")
        with gr.Row():
            # Narrow column: description, upload widget, run button, examples, settings.
            with gr.Column(scale=1, elem_id="col-input-container"):
                # The description interpolates MODEL_ID and NUM_KEEP_TOKENS.
                gr.Markdown(f"""
[`Website`](https://videoflextok.epfl.ch) | [`Paper`](https://arxiv.org/abs/2604.12887) | [`GitHub`](https://github.com/apple/ml-videoflextok) | [`Model`](https://huggingface.co/EPFL-VILAB/videoflextok_d18_d28)
Research demo for **VideoFlexTok: Flexible-Length Coarse-to-Fine Video Tokenization** (arXiv 2026).
Autoencodes your video with `{MODEL_ID}` and shows coarse-to-fine reconstructions.
VideoFlexTok tokenizes video into `T × 256` tokens ordered coarse-to-fine; this demo shows
reconstructions from `T × k` tokens for k ∈ `{NUM_KEEP_TOKENS}`. Bottom-right is the original.
""")
                input_video = gr.Video(
                    label="Input video", sources=["upload"], format="mp4",
                )
                run_button = gr.Button("Autoencode with VideoFlexTok", elem_id="run-button")
                # Only offer examples when at least one .mp4 was found in EXAMPLES_DIR.
                if EXAMPLE_VIDEOS:
                    # Selecting an example just fills the input video (identity fn).
                    gr.Examples(
                        examples=[str(p) for p in EXAMPLE_VIDEOS],
                        inputs=[input_video],
                        outputs=[input_video],
                        fn=lambda p: p,
                        cache_examples=True,
                        label="Example videos",
                    )
                with gr.Accordion("Advanced Settings", open=False):
                    gr.Markdown("Adjust target FPS to control how many frames are extracted.")
                    input_fps = gr.Slider(minimum=1, maximum=16, value=8, step=1, label="Target FPS")
                    timesteps = gr.Slider(minimum=1, maximum=60, value=20, step=1, label="Denoising steps")
                    guidance_scale = gr.Slider(minimum=1.0, maximum=30.0, value=25.0, step=0.5, label="Guidance scale")
                    seed = gr.Number(value=42, precision=0, label="Seed")
            # Wide column: the composed grid video and a status line.
            with gr.Column(scale=4):
                output_video = gr.Video(label="Reconstructions")
                status = gr.Markdown()
    # Inputs are passed positionally in the order of reconstruct_video's parameters;
    # its (path, info) return value fills the output video and the status line.
    run_button.click(
        fn=reconstruct_video,
        inputs=[input_video, input_fps, timesteps, guidance_scale, seed],
        outputs=[output_video, status],
    )
    # Warn users when no CUDA device was selected at startup.
    if DEVICE.type != "cuda":
        gr.Markdown("Running on CPU — inference will be slow.")
# --- Launch ------------------------------------------------------------------------
# Queue incoming requests, holding at most 16 at a time.
demo.queue(max_size=16)

if __name__ == "__main__":
    # Host and port can be overridden through the environment; the host
    # defaults to all interfaces and the port is left to Gradio when unset.
    server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
    launch_kwargs = dict(server_name=server_name, ssr_mode=False)
    port = os.environ.get("GRADIO_SERVER_PORT")
    if port:
        launch_kwargs["server_port"] = int(port)
    # Files under the app directory and the system temp directory may be served.
    launch_kwargs["allowed_paths"] = [str(APP_DIR), tempfile.gettempdir()]
    demo.launch(**launch_kwargs)