andreiatanov commited on
Commit
8603681
·
0 Parent(s):

VideoFlexTok demo

Browse files
.gitattributes ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ .venv310
2
+ __pycache__
3
+ .DS_Store
4
+ ml-videoflextok/
5
+ gradio_cached_examples/
README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: VideoFlexTok
3
+ emoji: 🎞️
4
+ colorFrom: pink
5
+ colorTo: indigo
6
+ sdk: gradio
7
+ sdk_version: 6.5.1
8
+ app_file: app.py
9
+ pinned: false
10
+ license: apache-2.0
11
+ short_description: 'VideoFlexTok: flexible-length coarse-to-fine video tokenizer'
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ import os
3
+ import subprocess
4
+ import sys
5
+ import tempfile
6
+ from pathlib import Path
7
+
8
+ # Install videoflextok without its deps to avoid huggingface_hub==0.25.2 conflicting
9
+ # with gradio's >=0.33.5 requirement. Compatible dep versions are in requirements.txt.
10
+ def _install_videoflextok():
11
+ try:
12
+ import videoflextok # noqa: F401
13
+ return
14
+ except ImportError:
15
+ pass
16
+ print("[VideoFlexTok] Installing videoflextok (--no-deps) ...")
17
+ subprocess.run(
18
+ [sys.executable, "-m", "pip", "install", "--quiet", "--no-deps",
19
+ "git+https://github.com/apple/ml-videoflextok.git"],
20
+ check=True,
21
+ )
22
+ importlib.invalidate_caches()
23
+
24
+ _install_videoflextok()
25
+
26
+ import spaces
27
+ import gradio as gr
28
+ import imageio.v3 as iio
29
+ import numpy as np
30
+ import torch
31
+
32
+ from videoflextok.utils.demo import denormalize, read_mp4
33
+ from videoflextok.utils.misc import detect_bf16_support, get_bf16_context
34
+ from videoflextok.wrappers import VideoFlexTokFromHub
35
+
36
+
37
+ # --- Constants ---------------------------------------------------------------------
38
+
39
+ MODEL_ID = "EPFL-VILAB/videoflextok_d18_d28"
40
+ APP_DIR = Path(__file__).resolve().parent
41
+ EXAMPLES_DIR = APP_DIR / "examples"
42
+ EXAMPLE_VIDEOS = sorted(EXAMPLES_DIR.glob("*.mp4"))
43
+ NUM_KEEP_TOKENS = [2**i for i in range(9)] # 1, 2, 4, 8, 16, 32, 64, 128, 256
44
+
45
+ APP_CSS = """
46
+ #col-container {
47
+ margin: 0 auto;
48
+ max-width: 1500px;
49
+ }
50
+ #col-input-container {
51
+ margin: 0 auto;
52
+ max-width: 420px;
53
+ }
54
+ #run-button {
55
+ margin: 0 auto;
56
+ }
57
+ """
58
+
59
+
60
+ # --- Device setup ------------------------------------------------------------------
61
+
62
+ torch.set_grad_enabled(False)
63
+ if torch.cuda.is_available():
64
+ torch.backends.cuda.matmul.allow_tf32 = True
65
+ torch.backends.cudnn.allow_tf32 = True
66
+
67
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
68
+ ENABLE_BF16 = DEVICE.type == "cuda" and detect_bf16_support()
69
+
70
+
71
+ # --- Model loading -----------------------------------------------------------------
72
+
73
+ def _patch_for_hf_spaces(model):
74
+ """Patch TorchDynamo and model for HF Spaces / ZeroGPU compatibility.
75
+
76
+ This PyTorch version's TorchDynamo cannot represent torch.device as a ConstantVariable,
77
+ causing torch.compile(flex_attention) to crash. The fix was merged into newer PyTorch;
78
+ here we backport it by adding torch.device to common_constant_types, so the Triton
79
+ kernel is used correctly instead of falling back to the dense O(n²) math implementation.
80
+
81
+ We also disable block mask compilation (compile_block_mask=False) since create_block_mask
82
+ uses a separate internal torch.compile call that would hit the same bug.
83
+ """
84
+ # Patch TorchDynamo to accept torch.device as a ConstantVariable.
85
+ # common_constant_types may be closed over in is_base_literal, so patch the method directly.
86
+ import torch._dynamo.variables.constant as _dynamo_const
87
+ _orig_is_base_literal = _dynamo_const.ConstantVariable.is_base_literal
88
+
89
+ @staticmethod
90
+ def _patched_is_base_literal(value):
91
+ return isinstance(value, torch.device) or _orig_is_base_literal(value)
92
+
93
+ _dynamo_const.ConstantVariable.is_base_literal = _patched_is_base_literal
94
+
95
+ from videoflextok.model.preprocessors.flex_seq_packing import (
96
+ BlockWiseSequencePacker,
97
+ BlockWiseSequenceInterleavePacker,
98
+ BlockWiseSequencePackerWithCrossAttention,
99
+ )
100
+ for module in model.modules():
101
+ if isinstance(module, (
102
+ BlockWiseSequencePacker,
103
+ BlockWiseSequenceInterleavePacker,
104
+ BlockWiseSequencePackerWithCrossAttention,
105
+ )):
106
+ module.compile_block_mask = False
107
+
108
+
109
+ _model = None
110
+ try:
111
+ print(f"[VideoFlexTok] Loading {MODEL_ID} ...")
112
+ _model = VideoFlexTokFromHub.from_pretrained(MODEL_ID)
113
+ _model = _model.to(torch.bfloat16).to(DEVICE).eval()
114
+ _patch_for_hf_spaces(_model)
115
+ print("[VideoFlexTok] Model ready.")
116
+ except Exception as exc:
117
+ print(f"[VideoFlexTok] FATAL: model load failed: {exc}")
118
+
119
+
120
+ # --- Inference ---------------------------------------------------------------------
121
+
122
+ def _stack_reconstructed_videos(videos, output_path: str, fps: int):
123
+ """Compose 9 reconstructions + original into a 2×5 grid video and write to output_path."""
124
+ def to_uint8_frames(video_tensor):
125
+ if video_tensor.ndim == 5:
126
+ video_tensor = video_tensor[0]
127
+ frames = denormalize(video_tensor).permute(1, 2, 3, 0).contiguous().numpy()
128
+ return (np.clip(frames, 0.0, 1.0) * 255).round().astype(np.uint8)
129
+
130
+ def add_border(frames: np.ndarray, border_px: int, color: int) -> np.ndarray:
131
+ return np.pad(
132
+ frames,
133
+ ((0, 0), (border_px, border_px), (border_px, border_px), (0, 0)),
134
+ mode="constant", constant_values=color,
135
+ )
136
+
137
+ def compose_row(row_frames: list[np.ndarray], t: int, gap_px: int) -> np.ndarray:
138
+ gap_col = np.full((row_frames[0].shape[1], gap_px, 3), 255, dtype=np.uint8)
139
+ items = []
140
+ for i, frames in enumerate(row_frames):
141
+ items.append(frames[t])
142
+ if i < len(row_frames) - 1:
143
+ items.append(gap_col)
144
+ return np.concatenate(items, axis=1)
145
+
146
+ border_px, gap_px = 8, 8
147
+ reconstructed = [add_border(to_uint8_frames(v), border_px, 255) for v in videos[:9]]
148
+ original = add_border(to_uint8_frames(videos[9]), border_px, 0)
149
+
150
+ all_panels = reconstructed + [original]
151
+ total_frames = min(p.shape[0] for p in all_panels)
152
+ all_panels = [p[:total_frames] for p in all_panels]
153
+
154
+ row1 = all_panels[:5] # k = 1, 2, 4, 8, 16
155
+ row2 = all_panels[5:] # k = 32, 64, 128, 256, Original
156
+
157
+ composed = []
158
+ for t in range(total_frames):
159
+ row1_img = compose_row(row1, t, gap_px)
160
+ row2_img = compose_row(row2, t, gap_px)
161
+ row_gap = np.full((gap_px, row1_img.shape[1], 3), 255, dtype=np.uint8)
162
+ composed.append(np.concatenate([row1_img, row_gap, row2_img], axis=0))
163
+
164
+ iio.imwrite(
165
+ output_path, np.stack(composed, axis=0),
166
+ fps=fps, plugin="FFMPEG", codec="libx264", pixelformat="yuv420p",
167
+ )
168
+
169
+
170
+ def reconstruct_video(video_path: str, input_fps: int, timesteps: int, guidance_scale: float, seed: int):
171
+ if not video_path or not Path(video_path).exists():
172
+ raise gr.Error("Upload a video first.")
173
+ if _model is None:
174
+ raise gr.Error("Model failed to load at startup — check Space logs.")
175
+
176
+ try:
177
+ preprocess_args = dict(_model.video_preprocess_args)
178
+ # Public package uses 'overlap_size'; model config key is 'overlap_size_frames'
179
+ if "overlap_size_frames" in preprocess_args and "overlap_size" not in preprocess_args:
180
+ preprocess_args["overlap_size"] = preprocess_args.pop("overlap_size_frames")
181
+ video_tensor = read_mp4(str(video_path), fps=int(input_fps), **preprocess_args)
182
+ except Exception as exc:
183
+ raise gr.Error(f"Failed to decode video: {exc}") from exc
184
+
185
+ try:
186
+ with get_bf16_context(ENABLE_BF16, device_type=DEVICE.type):
187
+ print(f"[VideoFlexTok] Tokenizing {video_tensor.shape} ...")
188
+ token_ids = _model.tokenize(video_tensor[None].to(DEVICE))
189
+ print(f"[VideoFlexTok] Decoding {len(NUM_KEEP_TOKENS)} reconstructions ...")
190
+ reconstructed = _model.detokenize(
191
+ [token_ids[0]] * len(NUM_KEEP_TOKENS),
192
+ num_keep_tokens_list=NUM_KEEP_TOKENS,
193
+ timesteps=int(timesteps),
194
+ guidance_scale=float(guidance_scale),
195
+ perform_norm_guidance=True,
196
+ generator=torch.Generator(device=DEVICE.type).manual_seed(int(seed)),
197
+ eta=0.0, momentum=0.0, norm_threshold=0.6, verbose=False,
198
+ )
199
+ reconstructed = [v.cpu().float() for v in reconstructed]
200
+ print("[VideoFlexTok] Inference complete.")
201
+ except Exception as exc:
202
+ raise gr.Error(f"Model inference failed: {exc}") from exc
203
+
204
+ tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
205
+ tmp.close()
206
+ _stack_reconstructed_videos(reconstructed + [video_tensor], output_path=tmp.name, fps=int(input_fps))
207
+
208
+ info = f"Extracted {video_tensor.shape[1]} frames at {input_fps} FPS"
209
+ return tmp.name, info
210
+
211
+
212
+ if spaces is not None and hasattr(spaces, "GPU"):
213
+ reconstruct_video = spaces.GPU(duration=60)(reconstruct_video)
214
+
215
+
216
+ # --- UI ----------------------------------------------------------------------------
217
+
218
+ with gr.Blocks(title="VideoFlexTok Demo", theme=gr.themes.Base(), css=APP_CSS) as demo:
219
+ with gr.Column(elem_id="col-container"):
220
+ gr.Markdown("# VideoFlexTok: Flexible-Length Coarse-to-Fine Video Tokenization")
221
+
222
+ with gr.Row():
223
+ with gr.Column(scale=1, elem_id="col-input-container"):
224
+ gr.Markdown(f"""
225
+ [`Website`](https://videoflextok.epfl.ch) | [`GitHub`](https://github.com/apple/ml-videoflextok) | [`Model`](https://huggingface.co/{MODEL_ID})
226
+
227
+ Research demo for **VideoFlexTok: Flexible-Length Coarse-to-Fine Video Tokenization** (arXiv 2026).
228
+ Autoencodes your video with `{MODEL_ID}` and shows coarse-to-fine reconstructions.
229
+ VideoFlexTok tokenizes video into `T × 256` tokens ordered coarse-to-fine; this demo shows
230
+ reconstructions from `T × k` tokens for k ∈ `{NUM_KEEP_TOKENS}`. Bottom-right is the original.
231
+ """)
232
+ input_video = gr.Video(
233
+ label="Input video", sources=["upload"], format="mp4",
234
+ )
235
+ run_button = gr.Button("Autoencode with VideoFlexTok", elem_id="run-button")
236
+
237
+ if EXAMPLE_VIDEOS:
238
+ gr.Examples(
239
+ examples=[str(p) for p in EXAMPLE_VIDEOS],
240
+ inputs=[input_video],
241
+ outputs=[input_video],
242
+ fn=lambda p: p,
243
+ cache_examples=True,
244
+ label="Example videos",
245
+ )
246
+
247
+ with gr.Accordion("Advanced Settings", open=False):
248
+ gr.Markdown("Adjust target FPS to control how many frames are extracted.")
249
+ input_fps = gr.Slider(minimum=1, maximum=16, value=8, step=1, label="Target FPS")
250
+ timesteps = gr.Slider(minimum=1, maximum=60, value=20, step=1, label="Denoising steps")
251
+ guidance_scale = gr.Slider(minimum=1.0, maximum=30.0, value=25.0, step=0.5, label="Guidance scale")
252
+ seed = gr.Number(value=42, precision=0, label="Seed")
253
+
254
+ with gr.Column(scale=4):
255
+ output_video = gr.Video(label="Reconstructions")
256
+ status = gr.Markdown()
257
+
258
+ run_button.click(
259
+ fn=reconstruct_video,
260
+ inputs=[input_video, input_fps, timesteps, guidance_scale, seed],
261
+ outputs=[output_video, status],
262
+ )
263
+
264
+ if DEVICE.type != "cuda":
265
+ gr.Markdown("Running on CPU — inference will be slow.")
266
+
267
+
268
+ # --- Launch ------------------------------------------------------------------------
269
+
270
+ demo.queue(max_size=16)
271
+
272
+ if __name__ == "__main__":
273
+ server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
274
+ launch_kwargs = {"server_name": server_name, "ssr_mode": False}
275
+ if port := os.environ.get("GRADIO_SERVER_PORT"):
276
+ launch_kwargs["server_port"] = int(port)
277
+ launch_kwargs["allowed_paths"] = [str(APP_DIR), tempfile.gettempdir()]
278
+ demo.launch(**launch_kwargs)
examples/apple.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6c2f7782fdb34cfa29bd36a92ebf47a4cf006f278c28891d3feb944a526b6a26
3
+ size 71661
examples/arch.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:662e89863b7479fa5323e0a209c67b249b0ad064ff337285ffa10b25a91570a7
3
+ size 63973
examples/cat.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:54f7eece320681198e6e817f6c7170a08b22e778435322421f15b68271c95734
3
+ size 58276
examples/porsche.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4738089a93af048948eda8deb0b53c47baf5a898021471508b131784f1bc39f3
3
+ size 293070
examples/sculpture.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3746cdde8a5096398280efa4012b6978751a8a1471245baabaa5605984056fc4
3
+ size 66136
requirements.txt ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio==6.5.1
2
+ imageio-ffmpeg==0.6.0
3
+ imageio
4
+
5
+ # videoflextok is installed without its deps at Space startup (see app.py).
6
+ # Its pyproject.toml pins huggingface_hub==0.25.2 which conflicts with gradio>=0.33.5,
7
+ # so we install --no-deps and provide compatible versions here instead.
8
+ # git+https://github.com/apple/ml-videoflextok.git
9
+
10
+ # Pin torch to 2.8.x — the version videoflextok was developed and tested on.
11
+ # The HF Spaces base image ships 2.9.1 which has TorchDynamo regressions.
12
+ torch==2.8.0
13
+ torchvision==0.23.0
14
+
15
+ # videoflextok dependencies (compatible versions)
16
+ diffusers>=0.28.0
17
+ einops>=0.7.0
18
+ huggingface_hub>=0.33.5,<0.40
19
+ hydra-core>=1.3.2
20
+ omegaconf>=2.3.0
21
+ PyYAML>=6.0
22
+ mup
23
+ safetensors>=0.4.0
24
+ tqdm>=4.64.1
25
+ eva-decord==0.6.1