bfshi commited on
Commit
7e3b296
·
1 Parent(s): fac3282
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,365 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import tempfile
3
+ import os
4
+ import torch
5
+ import gc
6
+ from demo_utils import load_model, process_video, save_video, image_to_video
7
+ import av
8
+ from PIL import Image
9
+ import numpy as np
10
+
11
# Detect HuggingFace ZeroGPU support. The `spaces` package is only present
# on HF Spaces; when it is missing we fall back to plain local execution.
try:
    import spaces
    ZEROGPU_AVAILABLE = True
except ImportError:
    ZEROGPU_AVAILABLE = False
    print("Warning: spaces module not available. Running without ZeroGPU support.")

# Per-device cache of loaded model setups, populated lazily by get_model().
model_cache = {}
19
+
20
def get_model(device):
    """Return the model setup for *device*, loading and caching it on first use."""
    cached = model_cache.get(device)
    if cached is None:
        cached = load_model(device=device)
        model_cache[device] = cached
    return cached
24
+
25
# Prefer CUDA. Under ZeroGPU the GPU is attached lazily, so assume "cuda"
# even when torch.cuda.is_available() is False at import time.
device = "cuda" if torch.cuda.is_available() or ZEROGPU_AVAILABLE else "cpu"
26
+
27
def cleanup_gpu():
    """Clean up GPU memory."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    # Return cached blocks to the driver and wait for in-flight kernels.
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
33
+
34
def extract_metadata(file):
    """Probe an uploaded file and return (info_text, path, frames, fps, width, height).

    Still images are first converted to a one-frame MP4 so downstream code
    only ever deals with videos. Returns ("", None, None, None, None, None)
    when no file is given.
    """
    if file is None:
        return "", None, None, None, None, None

    ext = os.path.splitext(file.name)[1].lower()
    image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp')

    if ext in image_extensions:
        # Reserve a persistent temp path, then render the image into it as
        # a single-frame video.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_video:
            tmp_path = tmp_video.name
        meta = image_to_video(file.name, tmp_path, fps=1.0)
        total_frames = meta['frames']
        fps = meta['fps']
        original_height = meta['height']
        original_width = meta['width']
        info_text = f"{original_width}×{original_height} | Image (1 frame)"
        return info_text, tmp_path, total_frames, fps, original_width, original_height

    # Video upload: read stream metadata straight from the container.
    tmp_path = file.name
    container = av.open(tmp_path)
    stream = container.streams.video[0]
    total_frames = stream.frames
    fps = float(stream.average_rate)
    original_height = stream.height
    original_width = stream.width
    container.close()
    info_text = f"{original_width}×{original_height} | {total_frames} frames @ {fps:.1f} FPS"
    return info_text, tmp_path, total_frames, fps, original_width, original_height
65
+
66
def handle_file_upload(file):
    """Gradio callback: refresh the info textbox, metadata state, and FPS field."""
    metadata = extract_metadata(file)

    # metadata[1] is the resolved file path; None means nothing usable was uploaded.
    if metadata[1] is None:
        return "", None, None

    info_text = metadata[0]
    fps = metadata[3]
    return info_text, metadata, fps
74
+
75
+ def _process_video_impl(file_info, gazing_ratio, task_loss_requirement, output_fps, progress=None):
76
+ if file_info is None:
77
+ return None, None, None, None, None, None, None, "No file uploaded"
78
+
79
+ _, tmp_path, total_frames, fps, _, _ = file_info
80
+
81
+ if tmp_path is None:
82
+ return None, None, None, None, None, None, None, "Invalid file"
83
+
84
+ # Yield initial status
85
+ yield None, None, None, None, None, None, None, "Loading model..."
86
+
87
+ if progress:
88
+ progress(0.0, desc="Loading model...")
89
+ setup = get_model(device)
90
+
91
+ yield None, None, None, None, None, None, None, "Processing video..."
92
+
93
+ if progress:
94
+ progress(0.1, desc="Processing video...")
95
+
96
+ status_messages = []
97
+
98
+ def update_progress(pct, msg):
99
+ if progress:
100
+ progress(pct, desc=msg)
101
+ status_messages.append(msg)
102
+
103
+ # Convert UI gazing ratio to model gazing ratio
104
+ # UI: ranges from 1/196 to 265/196 (effective patches per frame / 196)
105
+ # Model: needs value * (196/265) to get actual gazing ratio
106
+ model_gazing_ratio = gazing_ratio * (196 / 265)
107
+
108
+ for results in process_video(
109
+ tmp_path,
110
+ setup,
111
+ gazing_ratio=model_gazing_ratio,
112
+ task_loss_requirement=task_loss_requirement,
113
+ progress_callback=update_progress,
114
+ spatial_batch_size=2 # Process 4 spatial chunks at a time to avoid OOM
115
+ ):
116
+ if status_messages:
117
+ yield None, None, None, None, None, None, None, status_messages[-1]
118
+
119
+ yield None, None, None, None, None, None, None, "Saving output videos..."
120
+
121
+ with tempfile.TemporaryDirectory() as tmpdir:
122
+ original_path = os.path.join(tmpdir, "original.mp4")
123
+ gazing_path = os.path.join(tmpdir, "gazing.mp4")
124
+ recon_path = os.path.join(tmpdir, "reconstruction.mp4")
125
+ scales_stitch_path = os.path.join(tmpdir, "scales_stitch.mp4")
126
+
127
+ # Use output_fps if specified, otherwise use original video fps
128
+ fps_to_use = output_fps if output_fps is not None else results['fps']
129
+
130
+ save_video(results['original_frames'], original_path, fps_to_use)
131
+ save_video(results['gazing_frames'], gazing_path, fps_to_use)
132
+ save_video(results['reconstruction_frames'], recon_path, fps_to_use)
133
+ save_video(results['scales_stitch_frames'], scales_stitch_path, fps_to_use)
134
+
135
+ with open(original_path, "rb") as f:
136
+ original_data = f.read()
137
+ with open(gazing_path, "rb") as f:
138
+ gazing_data = f.read()
139
+ with open(recon_path, "rb") as f:
140
+ recon_data = f.read()
141
+ with open(scales_stitch_path, "rb") as f:
142
+ scales_stitch_data = f.read()
143
+
144
+ original_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
145
+ original_file.write(original_data)
146
+ original_file.close()
147
+
148
+ gazing_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
149
+ gazing_file.write(gazing_data)
150
+ gazing_file.close()
151
+
152
+ recon_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
153
+ recon_file.write(recon_data)
154
+ recon_file.close()
155
+
156
+ scales_stitch_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
157
+ scales_stitch_file.write(scales_stitch_data)
158
+ scales_stitch_file.close()
159
+
160
+ gazing_pct_text = f"{results['gazing_pct']:.2%}"
161
+ gazing_tokens_text = f"{results['total_gazing_tokens']:,}"
162
+ total_tokens_text = f"{results['total_possible_tokens']:,}"
163
+
164
+ yield (
165
+ gazing_pct_text,
166
+ gazing_tokens_text,
167
+ total_tokens_text,
168
+ original_file.name,
169
+ gazing_file.name,
170
+ recon_file.name,
171
+ scales_stitch_file.name,
172
+ "Processing complete!"
173
+ )
174
+
175
# On HF Spaces, wrap the handler with ZeroGPU's allocator (120 s GPU budget
# per call); locally, expose the implementation directly.
if ZEROGPU_AVAILABLE:
    process_video_ui = spaces.GPU(duration=120)(_process_video_impl)
else:
    process_video_ui = _process_video_impl
179
+
180
def extract_first_frame_thumbnail(video_path, output_path, size=(200, 200), force=False):
    """Extract first frame from video and save as thumbnail with fixed aspect ratio.

    The frame is center-cropped to a square and resized to *size*. An
    existing thumbnail is kept unless *force* is True.
    """
    if os.path.exists(output_path) and not force:
        return
    container = av.open(video_path)
    try:
        for frame in container.decode(video=0):
            img = frame.to_image()
            # Crop to center square first, then resize
            width, height = img.size
            min_dim = min(width, height)
            left = (width - min_dim) // 2
            top = (height - min_dim) // 2
            img_cropped = img.crop((left, top, left + min_dim, top + min_dim))
            img_resized = img_cropped.resize(size, Image.LANCZOS)
            img_resized.save(output_path)
            break  # only the first decodable frame is needed
    finally:
        # BUGFIX: close the container even if decoding/saving raises,
        # so the demuxer/file handle is never leaked.
        container.close()
197
+
198
# Generate thumbnails for example videos
example_videos = [
    "example_inputs/aerial.mp4",
    "example_inputs/doorbell.mp4",
    "example_inputs/tomjerry.mp4",
]

for video_path in example_videos:
    if os.path.exists(video_path):
        thumb_path = video_path.replace('.mp4', '_thumb.png')
        # Force regeneration with square aspect ratio at 100x100 to match gallery height
        extract_first_frame_thumbnail(video_path, thumb_path, size=(100, 100), force=True)

# Load thumbnails as numpy arrays for the gr.Gallery below.
# NOTE(review): this assumes the example videos (and hence thumbnails) exist;
# Image.open raises FileNotFoundError otherwise — confirm assets ship with the app.
aerial_thumb_img = np.array(Image.open("example_inputs/aerial_thumb.png"))
doorbell_thumb_img = np.array(Image.open("example_inputs/doorbell_thumb.png"))
tomjerry_thumb_img = np.array(Image.open("example_inputs/tomjerry_thumb.png"))
215
+
216
# ---------------------------------------------------------------------------
# Gradio UI: upload/examples on the left, results (stats + 4 videos) on the
# right. delete_cache evicts Gradio-cached files after 24 h.
# ---------------------------------------------------------------------------
with gr.Blocks(title="AutoGaze Demo", delete_cache=(86400, 86400)) as demo:
    gr.Markdown("# AutoGaze Official Demo")
    gr.Markdown("## **Attend Before Attention: Efficient and Scalable Video Understanding via Autoregressive Gazing**")
    gr.Markdown("""
    <div style="text-align: left; margin: 10px 0; font-size: 1.2em; font-weight: 600;">
    📄 <a href="https://arxiv.org/abs/PLACEHOLDER" target="_blank" style="text-decoration: none; color: inherit;">Paper</a> &nbsp;&nbsp;&nbsp;&nbsp;&nbsp; 🌐 <a href="https://placeholder-website.com" target="_blank" style="text-decoration: none; color: inherit;">Project Website</a>
    </div>
    """)

    # Holds the tuple returned by extract_metadata for the current upload.
    file_metadata = gr.State()

    with gr.Row():
        with gr.Column(scale=2):
            uploaded_file = gr.File(
                label="Upload Video or Image",
                file_types=["video", "image"]
            )
        with gr.Column(scale=1):
            file_info = gr.Textbox(label="File Info", interactive=False)
            process_button = gr.Button("Process Video", variant="primary")

    def load_example_video(evt: gr.SelectData):
        # Map the clicked gallery thumbnail index to its example video path.
        video_map = {
            0: "example_inputs/aerial.mp4",
            1: "example_inputs/doorbell.mp4",
            2: "example_inputs/tomjerry.mp4",
        }
        return video_map[evt.index]

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Example Videos - Click Thumbnail to Load")
            example_gallery = gr.Gallery(
                value=[
                    (aerial_thumb_img, "aerial.mp4"),
                    (doorbell_thumb_img, "doorbell.mp4"),
                    (tomjerry_thumb_img, "tomjerry.mp4"),
                ],
                label="",
                show_label=False,
                columns=3,
                rows=1,
                height=200,
                object_fit="contain",
                allow_preview=False
            )
            gr.Markdown("### Settings")

            with gr.Accordion("Output Settings", open=True):
                fps_slider = gr.Number(
                    label="Output FPS",
                    value=None,
                    minimum=1,
                    maximum=120,
                    info="Frames per second for displaying output videos (only affects playback speed)"
                )

            with gr.Accordion("Model Parameters", open=True):
                # Slider bounds mirror the model's 1..265 effective patches per
                # 196-patch frame (see conversion in _process_video_impl).
                gazing_ratio_slider = gr.Slider(
                    label="Gazing Ratio",
                    minimum=round(1/196, 2),
                    maximum=round(265/196, 2),
                    step=0.01,
                    value=0.75,
                    info="Max fraction of patches to gaze at per frame"
                )
                task_loss_slider = gr.Slider(
                    label="Task Loss Requirement",
                    minimum=0.0,
                    maximum=1.5,
                    step=0.05,
                    value=0.6,
                    info="Reconstruction loss threshold"
                )

            with gr.Accordion("FAQ", open=False):
                gr.Markdown("""
                **What file formats are supported?**

                The app supports common video formats (MP4, AVI, MOV, etc.) and image formats (JPG, PNG, etc.).

                **What is the Gazing Ratio?**

                The gazing ratio explicitly controls how many patches the model looks at per frame. Higher values mean more patches are selected. The range extends to past 1.0 because of multi-scale gazing; if all patches at all scales are selected, the ratio can reach up to 1.35.

                **What is Task Loss Requirement?**

                This threshold determines when the model stops gazing at a frame, based on the predicted reconstruction loss from the current gazed patches. Lower = more gazing, higher = less gazing.

                **How do Gazing Ratio and Task Loss interact?**

                These two parameters separately control the number of gazed patches in an image/video. This demo will take the stricter of the two requirements when determining how many patches to gaze at. For example, if the gazing ratio suggests gazing at 15% of patches, but the task loss requirement is met after only 7% patches, then only 7% patches will be gazed at. To only use one of the two parameters, set the other to its maximum value.
                """)

        with gr.Column(scale=2):
            gr.Markdown("### Results")

            status_text = gr.Markdown("Ready")

            with gr.Row():
                gazing_pct = gr.Textbox(label="Gazing %", interactive=False)
                gazing_tokens = gr.Textbox(label="# Gazed Patches", interactive=False)
                total_tokens = gr.Textbox(label="Total Patches", interactive=False)

            with gr.Row():
                original_video = gr.Video(label="Original", autoplay=False, loop=True)
                gazing_video = gr.Video(label="Gazing Pattern (all scales)", autoplay=False, loop=True)
                reconstruction_video = gr.Video(label="Reconstruction", autoplay=False, loop=True)

            with gr.Row():
                scales_stitch_video = gr.Video(label="Gazing Pattern (individual scales)", autoplay=False, loop=True)

    # Wiring: thumbnail click loads an example; any file change re-probes
    # metadata; the Process button streams generator output into the results.
    example_gallery.select(load_example_video, outputs=uploaded_file)
    uploaded_file.change(
        fn=handle_file_upload,
        inputs=[uploaded_file],
        outputs=[file_info, file_metadata, fps_slider]
    )

    process_button.click(
        fn=process_video_ui,
        inputs=[file_metadata, gazing_ratio_slider, task_loss_slider, fps_slider],
        outputs=[
            gazing_pct,
            gazing_tokens,
            total_tokens,
            original_video,
            gazing_video,
            reconstruction_video,
            scales_stitch_video,
            status_text
        ]
    ).then(
        fn=cleanup_gpu,
        inputs=None,
        outputs=None
    )

    # Clean up GPU memory when user disconnects
    demo.unload(cleanup_gpu)

# Clear any cached models and free GPU memory at app startup
print("Clearing model cache and GPU memory at startup...")
model_cache.clear()
cleanup_gpu()
print("Startup cleanup complete.")

if __name__ == "__main__":
    demo.launch(share=True)
demo_utils.py ADDED
@@ -0,0 +1,579 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+ import av
7
+ import imageio
8
+ from transformers import VivitImageProcessor
9
+ from PIL import Image, ImageDraw, ImageFont
10
+ from omegaconf import OmegaConf
11
+ from einops import rearrange
12
+
13
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'gengaze'))
14
+ from autogaze.models.autogaze import AutoGaze
15
+ from autogaze.datasets.video_utils import read_video_pyav, transform_video_for_pytorch
16
+ from autogaze.tasks.video_mae_reconstruction import VideoMAEReconstruction
17
+ from autogaze.utils import UnNormalize
18
+ from tqdm import trange
19
+
20
# Optional HuggingFace ZeroGPU support; same detection as in app.py.
try:
    import spaces
    ZEROGPU_AVAILABLE = True
except ImportError:
    ZEROGPU_AVAILABLE = False
25
+
26
+
27
def image_to_video(image_path, output_path, fps):
    """
    Convert a single image to a single-frame video file.

    Args:
        image_path: Path to input image
        output_path: Path to output video file
        fps: Frame rate for the video

    Returns:
        Dictionary with video metadata (width, height, frames, fps)
    """
    with Image.open(image_path) as img:
        rgb = img.convert('RGB') if img.mode != 'RGB' else img
        pixels = np.array(rgb)

    # One append_data call -> one-frame H.264 MP4.
    with imageio.get_writer(output_path, fps=fps, format='FFMPEG', codec='libx264', pixelformat='yuv420p') as writer:
        writer.append_data(pixels)

    height, width = pixels.shape[0], pixels.shape[1]
    return {
        'width': width,
        'height': height,
        'frames': 1,
        'fps': fps
    }
54
+
55
+
56
def load_model(device='cuda'):
    """Load the AutoGaze model plus the VideoMAE reconstruction task.

    Args:
        device: Device the models are moved to (e.g. 'cuda' or 'cpu').

    Returns:
        dict with:
            'model': AutoGaze gazing model in eval mode.
            'task': VideoMAEReconstruction with fine-tuned weights, eval mode.
            'unnorm': transform that undoes the processor normalization.
            'scales': the model's gazing scales.
            'transform': VivitImageProcessor (used for its mean/std stats).
    """
    print("Loading AutoGaze model from HuggingFace...")
    model = AutoGaze.from_pretrained("bfshi/AutoGaze")
    model = model.to(device)
    model.eval()

    # Processor sized to the largest gazing scale; downstream code uses only
    # its normalization statistics (image_mean/image_std/rescale_factor).
    transform = VivitImageProcessor.from_pretrained(
        "facebook/vit-mae-large",
        size=model.scales[-1],
        crop_size=model.scales[-1]
    )

    unnorm = UnNormalize(
        mean=transform.image_mean,
        std=transform.image_std,
        rescale_factor=transform.rescale_factor
    )

    print("Loading VideoMAE model from HuggingFace...")
    scales_str = '+'.join(map(str, model.scales))
    recon_model_config = OmegaConf.create({
        'scale_embed': True,
        'max_num_frames': 256,
        'time_embed': True,
        'causal': True,
        'loss_type': 'l1+dinov2_reg+siglip2',
        'loss_weights': '1',
        'l1_loss_config': {},
        'dinov2_reg_loss_config': {
            'model': 'facebook/dinov2-with-registers-base'
        },
        'siglip2_loss_config': {
            'model': 'google/siglip2-base-patch16-224'
        }
    })
    task = VideoMAEReconstruction(
        recon_model='facebook/vit-mae-large',
        recon_model_config=recon_model_config,
        scales=scales_str,
        recon_sample_rate=1,
        attn_mode='sdpa'
    )

    # Load fine-tuned weights from HuggingFace
    from huggingface_hub import hf_hub_download
    checkpoint_path = hf_hub_download(repo_id="bfshi/VideoMAE_AutoGaze", filename="videomae.pt")
    print(f"Loading VideoMAE checkpoint from {checkpoint_path}...")
    task_sd = torch.load(checkpoint_path, map_location='cpu')
    # Strip the DDP-style 'module.mae.' prefix so keys match the bare MAE module.
    task_sd = {k.replace('module.mae.', ''): v for k, v in task_sd.items()}
    task.mae.load_state_dict(task_sd, strict=True)
    print("Loaded VideoMAE checkpoint from HuggingFace")

    task = task.to(device)
    task.eval()

    return {
        'model': model,
        'task': task,
        'unnorm': unnorm,
        'scales': model.scales,
        'transform': transform,
    }
118
+
119
+
120
+ def process_video(video_path, setup, gazing_ratio=0.75, task_loss_requirement=0.6, progress_callback=None, spatial_batch_size=16):
121
+ """
122
+ Process a video file with AutoGaze using chunking for any resolution/duration.
123
+
124
+ Args:
125
+ video_path: Path to video file
126
+ setup: Dictionary with model, task, unnorm, scales, transform
127
+ gazing_ratio: Maximum percentage of patches to gaze per frame
128
+ task_loss_requirement: Reconstruction loss threshold
129
+ progress_callback: Optional callback function for progress updates
130
+
131
+ Yields:
132
+ Dictionary with original frames, gazing frames, reconstruction frames, and statistics
133
+ """
134
+ model = setup['model']
135
+ task = setup['task']
136
+ transform = setup['transform']
137
+ device = next(model.parameters()).device
138
+ if device == 'cuda':
139
+ torch.cuda.empty_cache()
140
+
141
+ container = av.open(video_path)
142
+ video_stream = container.streams.video[0]
143
+ total_frames_available = video_stream.frames
144
+ fps = float(video_stream.average_rate)
145
+ container.close()
146
+
147
+ container = av.open(video_path)
148
+ sample_indices = list(range(total_frames_available))
149
+ video = read_video_pyav(container=container, indices=sample_indices) # (T, H, W, 3) numpy array
150
+ container.close()
151
+
152
+ # Keep video on CPU for preprocessing to save GPU memory
153
+ video_tensor = torch.from_numpy(video).float() # (T, H, W, 3)
154
+ video_tensor = video_tensor / 255.0 # Normalize to [0, 1]
155
+ video_tensor = video_tensor.permute(0, 3, 1, 2) # (T, C, H, W)
156
+ T, C, H, W = video_tensor.shape
157
+ if T > 200:
158
+ print(f'Video has {T} frames, which may require significant GPU memory. Decreasing spatial_batch_size to 2.')
159
+ spatial_batch_size //= 2
160
+
161
+ # Clone for later visualization (keep on CPU)
162
+ video_tensor_original = video_tensor.clone()
163
+
164
+ # Pad video to be divisible by 224x224 and 16 frames
165
+ pad_t = (16 - T % 16) % 16
166
+ pad_h = (224 - H % 224) % 224
167
+ pad_w = (224 - W % 224) % 224
168
+
169
+ if pad_t > 0 or pad_h > 0 or pad_w > 0:
170
+ video_tensor = F.pad(video_tensor, (0, pad_w, 0, pad_h, 0, 0, 0, pad_t))
171
+
172
+ # Chunk video into 16-frame, 224x224 chunks (following QUICK_START.md)
173
+ video_tensor = video_tensor.unsqueeze(0) # 1 * T * C * H * W
174
+
175
+ # Calculate chunking dimensions
176
+ nt = (T + pad_t) // 16
177
+ nh = (H + pad_h) // 224
178
+ nw = (W + pad_w) // 224
179
+ num_spatial_chunks = nh * nw
180
+ num_chunks = nt * num_spatial_chunks
181
+
182
+ # Chunk into (num_chunks, 16, C, 224, 224)
183
+ video_chunks = rearrange(video_tensor, 'B (nt t) C (nh h) (nw w) -> (B nt nh nw) t C h w', t=16, h=224, w=224)
184
+
185
+ print(f"Video chunked into {num_chunks} chunks ({nt} temporal x {num_spatial_chunks} spatial) of shape (16, {C}, 224, 224). Original shape: ({T}, {C}, {H}, {W})")
186
+
187
+ # Apply VivitImageProcessor normalization to chunks
188
+ # Rearrange chunks to process all frames: (num_chunks, 16, C, H, W) -> (num_chunks * 16, C, H, W)
189
+ chunks_flat = rearrange(video_chunks, 'b t c h w -> (b t) c h w')
190
+
191
+ # Apply normalization using VivitImageProcessor's mean and std (on CPU)
192
+ mean = torch.tensor(transform.image_mean).view(1, 3, 1, 1)
193
+ std = torch.tensor(transform.image_std).view(1, 3, 1, 1)
194
+ chunks_flat = (chunks_flat - mean) / std
195
+
196
+ video_chunks = rearrange(chunks_flat, '(b t) c h w -> b t c h w', b=num_chunks, t=16)
197
+ video_chunks = rearrange(video_chunks, '(ns nt) t c h w -> ns nt t c h w', ns=num_spatial_chunks, nt=nt)
198
+
199
+ # Keep video_chunks on CPU - only move mini-batches to GPU as needed
200
+ print(f'video_chunks shape (spatial, temporal, frames, C, H, W): {video_chunks.shape}')
201
+
202
+ del video_tensor, chunks_flat, mean, std
203
+
204
+ with torch.inference_mode():
205
+ # Process spatial locations in mini-batches (keep all temporal chunks together per spatial location)
206
+ num_spatial_batches = (num_spatial_chunks + spatial_batch_size - 1) // spatial_batch_size
207
+
208
+ all_gaze_outputs = []
209
+ total_gazing_tokens = 0
210
+
211
+ for batch_idx in range(num_spatial_batches):
212
+ start_idx = batch_idx * spatial_batch_size
213
+ end_idx = min(start_idx + spatial_batch_size, num_spatial_chunks)
214
+ batch_size = end_idx - start_idx
215
+
216
+ gazing_pct = int(((batch_idx + 1) / num_spatial_batches) * 100)
217
+ if progress_callback:
218
+ progress_callback(0.1 + 0.4 * (batch_idx / num_spatial_batches), f"Gazing progress: {gazing_pct}%")
219
+ yield None
220
+
221
+ # Extract mini-batch from CPU and move to GPU: (batch_size, nt, 16, C, H, W)
222
+ spatial_batch = video_chunks[start_idx:end_idx].to(device)
223
+ # Flatten to (batch_size * nt, 16, C, H, W) for model
224
+ spatial_batch = rearrange(spatial_batch, 'bs nt t c h w -> (bs nt) t c h w')
225
+ print(f'Processing spatial batch {batch_idx+1}/{num_spatial_batches} with {batch_size} spatial locations x {nt} temporal = {spatial_batch.shape[0]} chunks')
226
+
227
+ # Run AutoGaze on this mini-batch
228
+ batch_gaze_output = model({"video": spatial_batch}, gazing_ratio=gazing_ratio, task_loss_requirement=task_loss_requirement)
229
+
230
+ # Free GPU memory after forward pass
231
+ del spatial_batch
232
+
233
+ # Count gazing tokens for this batch
234
+ if_padded = batch_gaze_output.get('if_padded_gazing')
235
+ if if_padded is not None:
236
+ total_gazing_tokens += (~if_padded).sum().item()
237
+ else:
238
+ total_gazing_tokens += (batch_gaze_output['gazing_pos'] < (196 * 16)).sum().item()
239
+
240
+ # Store the output
241
+ all_gaze_outputs.append(batch_gaze_output)
242
+ if torch.cuda.is_available():
243
+ torch.cuda.empty_cache()
244
+
245
+ print("Merging mini-batch results...")
246
+
247
+ # Find max sequence length across all mini-batches
248
+ max_seq_len = max(out['gazing_pos'].shape[1] for out in all_gaze_outputs)
249
+
250
+ # Pad gazing_pos and if_padded_gazing to same length (they have variable seq length)
251
+ # gazing_mask doesn't need padding since all chunks have same shape
252
+ padded_gazing_pos = []
253
+ padded_if_padded_gazing = []
254
+
255
+ for out in all_gaze_outputs:
256
+ seq_len = out['gazing_pos'].shape[1]
257
+ pad_len = max_seq_len - seq_len
258
+
259
+ # Pad gazing_pos with zeros
260
+ padded_pos = F.pad(out['gazing_pos'], (0, pad_len), value=0)
261
+ padded_gazing_pos.append(padded_pos)
262
+
263
+ # Pad if_padded_gazing and mark new positions as True (padded)
264
+ if 'if_padded_gazing' in out:
265
+ padded_if_pad = F.pad(out['if_padded_gazing'], (0, pad_len), value=True)
266
+ padded_if_padded_gazing.append(padded_if_pad)
267
+
268
+ # Store num_gazing_each_frame per mini-batch for later per-chunk extraction
269
+ num_gazing_each_frame_list = [out['num_gazing_each_frame'] for out in all_gaze_outputs]
270
+ batch_sizes = [out['gazing_pos'].shape[0] for out in all_gaze_outputs]
271
+
272
+ gaze_output = {
273
+ 'gazing_pos': torch.cat(padded_gazing_pos, dim=0),
274
+ 'gazing_mask': [torch.cat([out['gazing_mask'][i] for out in all_gaze_outputs], dim=0) for i in range(4)],
275
+ 'num_gazing_each_frame_list': num_gazing_each_frame_list, # List of values per mini-batch
276
+ 'batch_sizes': batch_sizes, # Track which chunks came from which mini-batch
277
+ 'frame_sampling_rate': all_gaze_outputs[0]['frame_sampling_rate'],
278
+ 'num_vision_tokens_each_frame': all_gaze_outputs[0]['num_vision_tokens_each_frame'],
279
+ }
280
+ if len(padded_if_padded_gazing) > 0:
281
+ gaze_output['if_padded_gazing'] = torch.cat(padded_if_padded_gazing, dim=0)
282
+
283
+ # Clean up mini-batch outputs
284
+ del all_gaze_outputs
285
+
286
+ total_possible_tokens = 196 * 16 * num_chunks
287
+
288
+ # Extract gazing masks for later visualization (already in batched form)
289
+ gazing_masks_batched = gaze_output['gazing_mask'] # List of 4 scales, each (num_chunks, 16, num_patches)
290
+
291
+ # Flatten video_chunks back to (num_chunks, 16, C, H, W) for reconstruction
292
+ video_chunks_flat = rearrange(video_chunks, 'ns nt t c h w -> (ns nt) t c h w').cpu()
293
+
294
+ # Pre-allocate reconstruction tensor on CPU to avoid memory accumulation
295
+ total_frames = num_chunks * 16
296
+ C = video_chunks_flat.shape[2]
297
+ reconstruction_chunks = torch.zeros((total_frames, C, 224, 224), dtype=torch.float32)
298
+ frame_idx_counter = 0
299
+
300
+ # Process reconstruction in mini-batches matching AutoGaze batch structure
301
+ num_autogaze_batches = len(gaze_output['num_gazing_each_frame_list'])
302
+ print(f'Reconstructing {num_chunks} chunks in {num_autogaze_batches} batches (aligned with AutoGaze batches)...')
303
+
304
+ chunk_idx = 0
305
+ for autogaze_batch_idx in range(num_autogaze_batches):
306
+ batch_size = gaze_output['batch_sizes'][autogaze_batch_idx]
307
+ start_chunk_idx = chunk_idx
308
+ end_chunk_idx = chunk_idx + batch_size
309
+
310
+ print(f'Reconstructing chunks {start_chunk_idx+1}-{end_chunk_idx}/{num_chunks}...')
311
+
312
+ # Extract videos for all chunks in this AutoGaze batch
313
+ batch_videos = video_chunks_flat[start_chunk_idx:end_chunk_idx].to(device) # (batch_size, 16, C, H, W)
314
+
315
+ # Extract gazing data for all chunks in this AutoGaze batch
316
+ batch_gazing_pos = gaze_output['gazing_pos'][start_chunk_idx:end_chunk_idx]
317
+ batch_gazing_mask = [scale_mask[start_chunk_idx:end_chunk_idx] for scale_mask in gaze_output['gazing_mask']]
318
+ batch_num_gazing_each_frame = gaze_output['num_gazing_each_frame_list'][autogaze_batch_idx]
319
+
320
+ # Trim to expected sequence length for this AutoGaze batch
321
+ expected_seq_len = batch_num_gazing_each_frame.sum().item()
322
+ batch_gazing_pos = batch_gazing_pos[:, :expected_seq_len]
323
+
324
+ chunk_idx = end_chunk_idx
325
+
326
+ batch_gaze_output = {
327
+ 'gazing_pos': batch_gazing_pos,
328
+ 'gazing_mask': batch_gazing_mask,
329
+ 'num_gazing_each_frame': batch_num_gazing_each_frame,
330
+ 'frame_sampling_rate': gaze_output['frame_sampling_rate'],
331
+ 'num_vision_tokens_each_frame': gaze_output['num_vision_tokens_each_frame'],
332
+ }
333
+
334
+ if 'if_padded_gazing' in gaze_output:
335
+ batch_if_padded = gaze_output['if_padded_gazing'][start_chunk_idx:end_chunk_idx]
336
+ batch_if_padded = batch_if_padded[:, :expected_seq_len]
337
+ batch_gaze_output['if_padded_gazing'] = batch_if_padded
338
+
339
+ # Reconstruct frame by frame for this batch
340
+ batch_video_dict = {"video": batch_videos}
341
+ # Pre-allocate batch_reconstructions tensor to avoid list + stack memory spike
342
+ batch_reconstructions = torch.zeros((16, batch_size, C, 224, 224), device=device)
343
+ for frame_idx in range(16):
344
+ # Update progress for each frame
345
+ frame_pct = int(((autogaze_batch_idx * 16 + frame_idx + 1) / (num_autogaze_batches * 16)) * 100)
346
+ if progress_callback:
347
+ progress_callback(0.5 + 0.4 * ((autogaze_batch_idx * 16 + frame_idx + 1) / (num_autogaze_batches * 16)), f"Reconstruction progress: {frame_pct}%")
348
+ yield None
349
+
350
+ task_output = task.forward_output(batch_video_dict, batch_gaze_output, frame_idx_to_reconstruct=[frame_idx])
351
+ batch_reconstructions[frame_idx] = task_output['reconstruction'][:, 0] # (recon_batch_size, C, H, W)
352
+ del task_output
353
+
354
+ # Reorder from (16, recon_batch_size, C, H, W) to (recon_batch_size, 16, C, H, W) to match expected chunk ordering
355
+ # batch_reconstructions already in shape (16, recon_batch_size, C, H, W)
356
+ batch_reconstructions = rearrange(batch_reconstructions, 't b c h w -> (b t) c h w') # (recon_batch_size * 16, C, H, W)
357
+
358
+ # Write directly into pre-allocated tensor
359
+ batch_size_frames = batch_reconstructions.shape[0]
360
+ reconstruction_chunks[frame_idx_counter:frame_idx_counter+batch_size_frames] = batch_reconstructions.cpu()
361
+ frame_idx_counter += batch_size_frames
362
+
363
+ # Clean up batch-specific variables
364
+ del batch_videos, batch_gaze_output, batch_video_dict, batch_reconstructions
365
+ print('Reconstruction complete.')
366
+ # Manually reverse the mean/std normalization to get back to [0, 1] range
367
+ mean = torch.tensor(transform.image_mean).view(1, 3, 1, 1).to(reconstruction_chunks.device)
368
+ std = torch.tensor(transform.image_std).view(1, 3, 1, 1).to(reconstruction_chunks.device)
369
+ reconstruction_chunks = reconstruction_chunks * std + mean
370
+
371
+ # Clean up video chunks and gaze output to free GPU memory (keep gazing_masks_batched for later)
372
+ del video_chunks, video_chunks_flat, gaze_output
373
+
374
+ # Reshape chunks back to original structure (nt, nh, nw already calculated earlier)
375
+ print(f'Reshaping reconstructed chunks back to video tensor...')
376
+ reconstruction_tensor = rearrange(reconstruction_chunks, '(nt nh nw t) C h w -> (nt t) C (nh h) (nw w)', nt=nt, nh=nh, nw=nw, t=16)
377
+ reconstruction_tensor = reconstruction_tensor[:T, :, :H, :W] # Remove padding
378
+
379
+ # Move reconstruction to GPU for visualization
380
+ reconstruction_tensor = reconstruction_tensor.to(device)
381
+
382
+ gazing_mask_assembled = []
383
+ for scale_idx in range(4):
384
+ scale_masks_stacked = gazing_masks_batched[scale_idx]
385
+
386
+ # Reshape: (num_chunks, 16, num_patches) -> (num_chunks * 16, num_patches)
387
+ scale_masks_flat = scale_masks_stacked.reshape(-1, scale_masks_stacked.shape[-1])
388
+
389
+ # Rearrange back to original video structure
390
+ scale_masks_reshaped = rearrange(scale_masks_flat, '(nt nh nw t) n -> (nt t) (nh nw) n', nt=nt, nh=nh, nw=nw, t=16)
391
+ scale_masks_reshaped = scale_masks_reshaped[:T] # Remove temporal padding
392
+
393
+ gazing_mask_assembled.append(scale_masks_reshaped)
394
+
395
+ del scale_masks_stacked, scale_masks_flat, scale_masks_reshaped
396
+
397
+ del gazing_masks_batched
398
+
399
+ pct = total_gazing_tokens / total_possible_tokens
400
+
401
+ # Move original video to GPU for visualization
402
+ video_viz = video_tensor_original.to(device)
403
+
404
+ # Generate frame-by-frame visualizations
405
+ original_frames = []
406
+ composite_frames = []
407
+ reconstruction_frames = []
408
+ scales_stitch_frames = []
409
+
410
+ print('Visualizing...')
411
+ if progress_callback:
412
+ progress_callback(0.9, "Visualizing...")
413
+ yield None
414
+ for t in trange(T):
415
+ # Original frame
416
+ frame = video_viz[t].permute(1, 2, 0)
417
+ frame = torch.clip(frame, 0, 1)
418
+ frame_uint8 = (frame * 255).byte().cpu().numpy()
419
+ original_frames.append(frame_uint8)
420
+
421
+ # Reconstruction frame
422
+ recon_frame = reconstruction_tensor[t].permute(1, 2, 0)
423
+ recon_frame = torch.clip(recon_frame, 0, 1)
424
+ recon_uint8 = (recon_frame * 255).byte().cpu().numpy()
425
+ reconstruction_frames.append(recon_uint8)
426
+
427
+ composite = torch.zeros((H, W, 3)).to(device)
428
+ scales = setup['scales']
429
+ alpha_values = [0.4, 0.5, 0.6, 0.7] # Per-scale opacity (coarse to fine)
430
+ colors = [
431
+ [1.0, 0.0, 0.0], # Scale 0 (coarsest): Red
432
+ [0.0, 1.0, 0.0], # Scale 1: Green
433
+ [0.0, 0.0, 1.0], # Scale 2: Blue
434
+ [1.0, 1.0, 0.0] # Scale 3 (finest): Yellow
435
+ ]
436
+
437
+ for scale_idx in range(4):
438
+ scale = scales[scale_idx]
439
+ scale_h = int(scale * H / 224)
440
+ scale_w = int(scale * W / 224)
441
+
442
+ # Get mask for this scale and frame
443
+ mask = gazing_mask_assembled[scale_idx][t] # (nh * nw, num_patches)
444
+
445
+ # print(f'Frame {t}, Scale {scale}: mask shape {mask.shape}')
446
+ # print(mask)
447
+ # print()
448
+
449
+ # Reshape mask: (nh * nw, num_patches) where num_patches = s^2
450
+ num_patches_per_chunk = mask.shape[-1]
451
+ s = int(num_patches_per_chunk ** 0.5)
452
+
453
+ # Rearrange to 2D spatial grid
454
+ mask_2d = rearrange(mask, '(nh nw) (h w) -> (nh h) (nw w)', nh=nh, nw=nw, h=s, w=s)
455
+
456
+ # Convert to tensor if needed
457
+ if isinstance(mask_2d, np.ndarray):
458
+ mask_tensor = torch.from_numpy(mask_2d)
459
+ else:
460
+ mask_tensor = mask_2d
461
+
462
+ mask_resized = F.interpolate(mask_tensor.unsqueeze(0).unsqueeze(0).float(), size=(scale_h, scale_w), mode='nearest')[0, 0]
463
+
464
+ frame_tensor = video_viz[t]
465
+ frame_scaled = F.interpolate(frame_tensor.unsqueeze(0), size=(scale_h, scale_w), mode='bicubic', align_corners=False).squeeze().clamp(0, 1)
466
+
467
+ frame_scaled_masked = frame_scaled * mask_resized.unsqueeze(0)
468
+
469
+ # Upsample both masked frame and mask to full size
470
+ frame_upsampled = F.interpolate(frame_scaled_masked.unsqueeze(0), size=(H, W), mode='nearest').squeeze() #.cpu().numpy()
471
+ mask_upsampled = F.interpolate(mask_resized.unsqueeze(0).unsqueeze(0), size=(H, W), mode='nearest').squeeze() #.cpu().numpy()
472
+
473
+ frame_upsampled = frame_upsampled.permute(1, 2, 0)
474
+
475
+ composite = composite * (1 - mask_upsampled[:, :, None] * alpha_values[scale_idx]) + frame_upsampled * alpha_values[scale_idx]
476
+
477
+ composite_np = composite.detach().cpu().numpy()
478
+ composite_np = (composite_np - composite_np.min()) / (composite_np.max() - composite_np.min() + 1e-8)
479
+ composite_uint8 = (composite_np * 255).astype(np.uint8)
480
+ composite_frames.append(composite_uint8)
481
+
482
+ # Create individual scale visualizations for horizontal stitch
483
+ scale_composites = []
484
+ label_bar_height = 30
485
+
486
+ for scale_idx in range(4):
487
+ scale = scales[scale_idx]
488
+ scale_h = int(scale * H / 224)
489
+ scale_w = int(scale * W / 224)
490
+
491
+ # Get mask for this scale and frame
492
+ mask = gazing_mask_assembled[scale_idx][t]
493
+
494
+ # Reshape mask to 2D spatial grid
495
+ num_patches_per_chunk = mask.shape[-1]
496
+ s = int(num_patches_per_chunk ** 0.5)
497
+ mask_2d = rearrange(mask, '(nh nw) (h w) -> (nh h) (nw w)', nh=nh, nw=nw, h=s, w=s)
498
+
499
+ if isinstance(mask_2d, np.ndarray):
500
+ mask_tensor_scale = torch.from_numpy(mask_2d)
501
+ else:
502
+ mask_tensor_scale = mask_2d
503
+
504
+ mask_resized_scale = F.interpolate(mask_tensor_scale.unsqueeze(0).unsqueeze(0).float(), size=(scale_h, scale_w), mode='nearest')[0, 0]
505
+
506
+ frame_tensor_scale = video_viz[t]
507
+ frame_scaled_scale = F.interpolate(frame_tensor_scale.unsqueeze(0), size=(scale_h, scale_w), mode='bicubic', align_corners=False).squeeze().clamp(0, 1)
508
+
509
+ # Apply gazing pattern: gazed tiles = 1.0 brightness, ungazed tiles = 0.2 brightness
510
+ frame_scaled_permuted = frame_scaled_scale.permute(1, 2, 0)
511
+ scale_composite = frame_scaled_permuted * (mask_resized_scale[:, :, None] * 1.0 + (1 - mask_resized_scale[:, :, None]) * 0.2)
512
+
513
+ scale_composite_np = scale_composite.detach().cpu().numpy()
514
+ scale_composite_np = np.clip(scale_composite_np, 0, 1)
515
+ scale_composite_uint8 = (scale_composite_np * 255).astype(np.uint8)
516
+
517
+ # Resize visualization to common display height first (preserving aspect ratio)
518
+ display_width = int(scale_w * H / scale_h)
519
+ scale_composite_pil = Image.fromarray(scale_composite_uint8)
520
+ scale_composite_resized = scale_composite_pil.resize((display_width, H), Image.NEAREST)
521
+ scale_composite_resized_np = np.array(scale_composite_resized)
522
+
523
+ # Create label bar matching the resized visualization width
524
+ label_bar = np.ones((label_bar_height, display_width, 3), dtype=np.uint8) * 255
525
+ label_bar_pil = Image.fromarray(label_bar)
526
+ draw = ImageDraw.Draw(label_bar_pil)
527
+ try:
528
+ font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 20)
529
+ except:
530
+ font = ImageFont.load_default()
531
+
532
+ label = f"Scale {scale_idx + 1}"
533
+ draw.text((5, 5), label, fill=(0, 0, 0), font=font)
534
+ label_bar_np = np.array(label_bar_pil)
535
+
536
+ # Stack label bar above the visualization
537
+ scale_with_label = np.vstack([label_bar_np, scale_composite_resized_np])
538
+
539
+ scale_composites.append(scale_with_label)
540
+
541
+ # Add 10px white padding between scales
542
+ padding = np.ones((H + label_bar_height, 10, 3), dtype=np.uint8) * 255
543
+
544
+ # Concatenate all scales horizontally with padding
545
+ stitched = scale_composites[0]
546
+ for i in range(1, 4):
547
+ stitched = np.concatenate([stitched, padding, scale_composites[i]], axis=1)
548
+
549
+ # Add white padding at the top to prevent Gradio's label from blocking content
550
+ top_padding = np.ones((50, stitched.shape[1], 3), dtype=np.uint8) * 255
551
+ stitched = np.vstack([top_padding, stitched])
552
+
553
+ scales_stitch_frames.append(stitched)
554
+
555
+ del frame_tensor, mask_tensor, mask_resized, frame_scaled, frame_scaled_masked, frame_upsampled, mask_upsampled
556
+
557
+ del gazing_mask_assembled
558
+
559
+ del video_tensor_original, reconstruction_tensor, video_viz, reconstruction_chunks
560
+
561
+ if device == 'cuda':
562
+ torch.cuda.empty_cache()
563
+
564
+ yield {
565
+ 'original_frames': original_frames,
566
+ 'gazing_frames': composite_frames,
567
+ 'reconstruction_frames': reconstruction_frames,
568
+ 'scales_stitch_frames': scales_stitch_frames,
569
+ 'fps': fps,
570
+ 'gazing_pct': pct,
571
+ 'total_gazing_tokens': total_gazing_tokens,
572
+ 'total_possible_tokens': total_possible_tokens
573
+ }
574
+
575
+
576
def save_video(frames, output_path, fps):
    """Encode a sequence of RGB frames into an H.264 MP4 file.

    NOTE(review): this local definition shadows the ``save_video`` imported
    from ``demo_utils`` at the top of the file — confirm which implementation
    callers are meant to use.

    Args:
        frames: iterable of H x W x 3 uint8 numpy arrays (one per frame).
        output_path: destination path for the encoded video file.
        fps: playback frame rate passed to the FFMPEG writer.
    """
    # Open the writer explicitly, then rely on its context-manager protocol
    # to guarantee the stream is finalized even if appending raises.
    writer = imageio.get_writer(
        output_path,
        fps=fps,
        format='FFMPEG',
        codec='libx264',
        pixelformat='yuv420p',
    )
    with writer:
        for img in frames:
            writer.append_data(img)
environment.yaml ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: gengaze_demo
2
+ channels:
3
+ - nvidia
4
+ - conda-forge
5
+ - defaults
6
+ dependencies:
7
+ - python=3.10
8
+ - pip
9
+ - nvidia::cuda-toolkit=12.6
10
+ - pip:
11
+ - torch==2.7.1
12
+ - torchvision==0.22.1
13
+ - torchaudio==2.7.1
14
+ - numpy==1.26.4
15
+ - pillow==10.4.0
16
+ - matplotlib==3.10.1
17
+ - gradio>=4.0.0
18
+ - spaces
19
+ - flash_attn==2.8.0.post2
20
+ - hydra-core==1.3.2
21
+ - wandb==0.21.0
22
+ - loguru==0.7.3
23
+ - timm==1.0.15
24
+ - tqdm==4.67.1
25
+ - transformers==4.53.0
26
+ - omegaconf==2.3.0
27
+ - einops==0.8.1
28
+ - av==14.4.0
29
+ - imageio==2.37.0
example_inputs/aerial.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e90d807c5d0438ff80112a2634b8cc10c4700cfbeaede96c5bd931035f170f46
3
+ size 298170
example_inputs/aerial_thumb.png ADDED
example_inputs/doorbell.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f8667f28dd39c89b630e4ad29cf49e8bc82fb5ed26196fe0371b0f10e54a2ba9
3
+ size 460064
example_inputs/doorbell_thumb.png ADDED
example_inputs/tomjerry.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0e0a7d90ea96f817268e44dc38f58bdac3348df7f64a3eb54bba79ed5e7df7a3
3
+ size 435371
example_inputs/tomjerry_thumb.png ADDED
requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ numpy==1.26.4
2
+ pillow==10.4.0
3
+ matplotlib==3.10.1
4
+ gradio>=4.0.0
5
+ spaces
6
+ hydra-core==1.3.2
7
+ wandb==0.21.0
8
+ loguru==0.7.3
9
+ timm==1.0.15
10
+ tqdm==4.67.1
11
+ transformers==4.53.0
12
+ omegaconf==2.3.0
13
+ einops==0.8.1
14
+ av==14.4.0
15
+ imageio==2.37.0