File size: 14,463 Bytes
c0c592e
 
 
 
 
 
 
 
7e3b296
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c0c592e
 
 
 
 
 
 
 
 
 
7e3b296
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9053eb9
7e3b296
 
 
 
 
 
 
 
 
 
 
9053eb9
7e3b296
 
 
 
 
 
7e6904b
7e3b296
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9053eb9
 
 
7e3b296
 
 
 
 
 
 
 
 
 
9053eb9
7e3b296
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9053eb9
7e3b296
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
# IMPORTANT: Import spaces first, before any CUDA-related packages (torch, etc.)
try:
    import spaces
    ZEROGPU_AVAILABLE = True
except ImportError:
    ZEROGPU_AVAILABLE = False
    print("Warning: spaces module not available. Running without ZeroGPU support.")

import gradio as gr
import tempfile
import os
import torch
import gc
from demo_utils import load_model, process_video, save_video, image_to_video
import av
from PIL import Image
import numpy as np

model_cache = {}

def get_model(device):
    """Return the model setup for *device*, loading it at most once per device."""
    try:
        # Fast path: the model for this device was already loaded.
        return model_cache[device]
    except KeyError:
        # First request for this device: load and memoize for later calls.
        setup = load_model(device=device)
        model_cache[device] = setup
        return setup

# Determine device: use CUDA if available locally or if ZeroGPU will provide it.
# NOTE: under ZeroGPU we hard-code "cuda" even though no GPU is attached at
# startup — spaces.GPU attaches one per decorated call (see process_video_ui).
if ZEROGPU_AVAILABLE:
    device = "cuda"  # ZeroGPU will provide GPU
    print("Using ZeroGPU (CUDA device will be allocated on demand)")
elif torch.cuda.is_available():
    device = "cuda"
    print(f"Using CUDA GPU: {torch.cuda.get_device_name(0)}")
else:
    # CPU fallback for machines without CUDA and outside HF Spaces.
    device = "cpu"
    print("No GPU available, using CPU")

def cleanup_gpu():
    """Clean up GPU memory."""
    # Drop unreachable Python objects first so their CUDA tensors are freed.
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.synchronize()

def extract_metadata(file):
    """Probe an uploaded video or image and return display info plus metadata.

    Returns a 6-tuple ``(info_text, path, total_frames, fps, width, height)``;
    all empty/None when *file* is None. Images are converted to a 1-frame
    video first so the rest of the pipeline can treat both kinds uniformly.
    """
    if file is None:
        return "", None, None, None, None, None

    file_extension = os.path.splitext(file.name)[1].lower()
    is_image = file_extension in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp']

    if is_image:
        # delete=False: the converted video must outlive this function — it is
        # consumed later by the processing step. NOTE(review): it is never
        # explicitly removed afterwards; verify cleanup policy.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_video:
            tmp_path = tmp_video.name

        metadata = image_to_video(file.name, tmp_path, fps=1.0)

        total_frames = metadata['frames']
        fps = metadata['fps']
        original_height = metadata['height']
        original_width = metadata['width']
        # Fixed mojibake: the separator was the garbled byte sequence "ร—"
        # where a multiplication sign was intended.
        info_text = f"{original_width}×{original_height} | Image (1 frame)"
    else:
        tmp_path = file.name

        container = av.open(tmp_path)
        try:
            video_stream = container.streams.video[0]
            total_frames = video_stream.frames
            fps = float(video_stream.average_rate)
            original_height = video_stream.height
            original_width = video_stream.width
        finally:
            # Always release the demuxer, even if the file has no video
            # stream or probing raises — the original leaked it on error.
            container.close()
        info_text = f"{original_width}×{original_height} | {total_frames} frames @ {fps:.1f} FPS"

    return info_text, tmp_path, total_frames, fps, original_width, original_height

def handle_file_upload(file):
    """Gradio change-handler: show file info and stash full metadata in state."""
    metadata = extract_metadata(file)

    # metadata[1] is the resolved file path; None means nothing usable arrived.
    if metadata[1] is None:
        return "", None, None

    info_text = metadata[0]
    fps = metadata[3]
    # The whole tuple goes into gr.State; fps pre-populates the output-FPS box.
    return info_text, metadata, fps

def _process_video_impl(file_info, gazing_ratio, task_loss_requirement, output_fps, progress=None):
    if file_info is None:
        return None, None, None, None, None, None, None, "No file uploaded"

    _, tmp_path, total_frames, fps, _, _ = file_info

    if tmp_path is None:
        return None, None, None, None, None, None, None, "Invalid file"

    # Yield initial status
    yield None, None, None, None, None, None, None, "Loading model..."

    if progress:
        progress(0.0, desc="Loading model...")
    setup = get_model(device)

    yield None, None, None, None, None, None, None, "Processing video..."

    if progress:
        progress(0.1, desc="Processing video...")

    status_messages = []

    def update_progress(pct, msg):
        if progress:
            progress(pct, desc=msg)
        status_messages.append(msg)

    # Convert UI gazing ratio to model gazing ratio
    # UI: ranges from 1/196 to 265/196 (effective patches per frame / 196)
    # Model: needs value * (196/265) to get actual gazing ratio
    model_gazing_ratio = gazing_ratio * (196 / 265)

    for results in process_video(
        tmp_path,
        setup,
        gazing_ratio=model_gazing_ratio,
        task_loss_requirement=task_loss_requirement,
        progress_callback=update_progress,
        spatial_batch_size=2  # Process 4 spatial chunks at a time to avoid OOM
    ):
        if status_messages:
            yield None, None, None, None, None, None, None, status_messages[-1]

    yield None, None, None, None, None, None, None, "Saving output videos..."

    with tempfile.TemporaryDirectory() as tmpdir:
        original_path = os.path.join(tmpdir, "original.mp4")
        gazing_path = os.path.join(tmpdir, "gazing.mp4")
        recon_path = os.path.join(tmpdir, "reconstruction.mp4")
        scales_stitch_path = os.path.join(tmpdir, "scales_stitch.mp4")

        # Use output_fps if specified, otherwise use original video fps
        fps_to_use = output_fps if output_fps is not None else results['fps']

        save_video(results['original_frames'], original_path, fps_to_use)
        save_video(results['gazing_frames'], gazing_path, fps_to_use)
        save_video(results['reconstruction_frames'], recon_path, fps_to_use)
        save_video(results['scales_stitch_frames'], scales_stitch_path, fps_to_use)

        with open(original_path, "rb") as f:
            original_data = f.read()
        with open(gazing_path, "rb") as f:
            gazing_data = f.read()
        with open(recon_path, "rb") as f:
            recon_data = f.read()
        with open(scales_stitch_path, "rb") as f:
            scales_stitch_data = f.read()

        original_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
        original_file.write(original_data)
        original_file.close()

        gazing_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
        gazing_file.write(gazing_data)
        gazing_file.close()

        recon_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
        recon_file.write(recon_data)
        recon_file.close()

        scales_stitch_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
        scales_stitch_file.write(scales_stitch_data)
        scales_stitch_file.close()

    gazing_pct_text = f"{results['gazing_pct']:.2%}"
    gazing_tokens_text = f"{results['total_gazing_tokens']:,}"
    total_tokens_text = f"{results['total_possible_tokens']:,}"

    yield (
        gazing_pct_text,
        gazing_tokens_text,
        total_tokens_text,
        original_file.name,
        gazing_file.name,
        recon_file.name,
        scales_stitch_file.name,
        "Processing complete!"
    )

# Wrap the processing generator in spaces.GPU so Hugging Face ZeroGPU attaches
# a GPU for up to 120 seconds per call; outside Spaces, use it undecorated.
if ZEROGPU_AVAILABLE:
    process_video_ui = spaces.GPU(duration=120)(_process_video_impl)
else:
    process_video_ui = _process_video_impl

def extract_first_frame_thumbnail(video_path, output_path, size=(200, 200), force=False):
    """Extract first frame from video and save as thumbnail with fixed aspect ratio.

    The frame is center-cropped to a square, resized to *size*, and written to
    *output_path*. An existing thumbnail is kept unless force=True.
    """
    if os.path.exists(output_path) and not force:
        return
    container = av.open(video_path)
    try:
        for frame in container.decode(video=0):
            img = frame.to_image()
            # Crop to center square first, then resize, so the thumbnail is
            # never distorted regardless of the source aspect ratio.
            width, height = img.size
            min_dim = min(width, height)
            left = (width - min_dim) // 2
            top = (height - min_dim) // 2
            img_cropped = img.crop((left, top, left + min_dim, top + min_dim))
            img_resized = img_cropped.resize(size, Image.LANCZOS)
            img_resized.save(output_path)
            break  # only the first decodable frame is needed
    finally:
        # Close the demuxer even when decoding raises (truncated/corrupt
        # file) — the original leaked the container on that path.
        container.close()

# Generate thumbnails for example videos
example_videos = [
    "example_inputs/doorbell.mp4",
    "example_inputs/tomjerry.mp4",
    "example_inputs/security.mp4",
]

# Regenerated on every startup (force=True) so stale non-square thumbnails
# from earlier runs are overwritten.
for video_path in example_videos:
    if os.path.exists(video_path):
        thumb_path = video_path.replace('.mp4', '_thumb.png')
        # Force regeneration with square aspect ratio at 100x100 to match gallery height
        extract_first_frame_thumbnail(video_path, thumb_path, size=(100, 100), force=True)

# Load thumbnails as numpy arrays for the gr.Gallery below.
# NOTE(review): assumes all three example videos exist — Image.open raises
# FileNotFoundError if a thumbnail was not generated above; confirm deployment
# always ships example_inputs/.
doorbell_thumb_img = np.array(Image.open("example_inputs/doorbell_thumb.png"))
tomjerry_thumb_img = np.array(Image.open("example_inputs/tomjerry_thumb.png"))
security_thumb_img = np.array(Image.open("example_inputs/security_thumb.png"))

# Build the UI. delete_cache=(86400, 86400): Gradio purges cached files older
# than a day, once a day.
with gr.Blocks(title="AutoGaze Demo", delete_cache=(86400, 86400)) as demo:
    gr.Markdown("# AutoGaze Official Demo")
    gr.Markdown("## **Attend Before Attention: Efficient and Scalable Video Understanding via Autoregressive Gazing**")
    gr.Markdown("""
        <div style="text-align: left; margin: 10px 0; font-size: 1.2em; font-weight: 600;">
            ๐Ÿ“„ <a href="https://arxiv.org/abs/2603.12254" target="_blank" style="text-decoration: none; color: inherit;">Paper</a> &nbsp;&nbsp;&nbsp;&nbsp;&nbsp; ๐ŸŒ <a href="https://autogaze.github.io" target="_blank" style="text-decoration: none; color: inherit;">Project Website</a>
        </div>
    """)

    # Per-session holder for the extract_metadata tuple; filled by
    # handle_file_upload and consumed by process_video_ui.
    file_metadata = gr.State()

    with gr.Row():
        with gr.Column(scale=2):
            uploaded_file = gr.File(
                label="Upload Video or Image",
                file_types=["video", "image"]
            )
        with gr.Column(scale=1):
            file_info = gr.Textbox(label="File Info", interactive=False)
            process_button = gr.Button("Process Video", variant="primary")


    def load_example_video(evt: gr.SelectData):
        """Map a gallery selection index to its example video path."""
        # Order must match the gallery value list defined below.
        examples = dict(enumerate((
            "example_inputs/doorbell.mp4",
            "example_inputs/tomjerry.mp4",
            "example_inputs/security.mp4",
        )))
        return examples[evt.index]

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Example Videos - Click Thumbnail to Load")
            # Thumbnail order must stay in sync with load_example_video's
            # index-to-path mapping.
            example_gallery = gr.Gallery(
                value=[
                    (doorbell_thumb_img, "doorbell.mp4"),
                    (tomjerry_thumb_img, "tomjerry.mp4"),
                    (security_thumb_img, "security.mp4"),
                ],
                label="",
                show_label=False,
                columns=3,
                rows=1,
                height=200,
                object_fit="contain",
                allow_preview=False
            )
            gr.Markdown("### Settings")

            with gr.Accordion("Output Settings", open=True):
                # value=None means "use the source video's FPS" — see the
                # fps_to_use fallback in the processing function.
                fps_slider = gr.Number(
                    label="Output FPS",
                    value=None,
                    minimum=1,
                    maximum=120,
                    info="Frames per second for displaying output videos (only affects playback speed)"
                )

            with gr.Accordion("Model Parameters", open=True):
                # UI-scale ratio (patches per frame / 196); the processing
                # function rescales it by 196/265 before the model sees it.
                gazing_ratio_slider = gr.Slider(
                    label="Gazing Ratio",
                    minimum=round(1/196, 2),
                    maximum=round(265/196, 2),
                    step=0.01,
                    value=0.75,
                    info="Max fraction of patches to gaze at per frame"
                )
                task_loss_slider = gr.Slider(
                    label="Task Loss Requirement",
                    minimum=0.0,
                    maximum=1.5,
                    step=0.05,
                    value=0.7,
                    info="Reconstruction loss threshold"
                )

            with gr.Accordion("FAQ", open=False):
                gr.Markdown("""
                    **What file formats are supported?**
                    
                    The app supports common video formats (MP4, AVI, MOV, etc.) and image formats (JPG, PNG, etc.).
                
                    **What is the Gazing Ratio?**
                    
                    The gazing ratio explicitly controls how many patches the model looks at per frame. Higher values mean more patches are selected. The range extends to past 1.0 because of multi-scale gazing; if all patches at all scales are selected, the ratio can reach up to 1.35.
                
                    **What is Task Loss Requirement?**
                    
                    This threshold determines when the model stops gazing at a frame, based on the predicted reconstruction loss from the current gazed patches. Lower = more gazing, higher = less gazing.
                    
                    **How do Gazing Ratio and Task Loss interact?**
                    
                    These two parameters separately control the number of gazed patches in an image/video. This demo will take the stricter of the two requirements when determining how many patches to gaze at. For example, if the gazing ratio suggests gazing at 15% of patches, but the task loss requirement is met after only 7% patches, then only 7% patches will be gazed at. To only use one of the two parameters, set the other to its maximum value.
                """)

        with gr.Column(scale=2):
            gr.Markdown("### Results")

            # Streaming status line, updated by each yield of process_video_ui.
            status_text = gr.Markdown("Ready")

            with gr.Row():
                gazing_pct = gr.Textbox(label="Gazing %", interactive=False)
                gazing_tokens = gr.Textbox(label="# Gazed Patches", interactive=False)
                total_tokens = gr.Textbox(label="Total Patches", interactive=False)

            # Output video players; all loop, none autoplay.
            with gr.Row():
                original_video = gr.Video(label="Original", autoplay=False, loop=True)
                gazing_video = gr.Video(label="Gazing Pattern (all scales)", autoplay=False, loop=True)
                reconstruction_video = gr.Video(label="Reconstruction", autoplay=False, loop=True)

            with gr.Row():
                scales_stitch_video = gr.Video(label="Gazing Pattern (individual scales)", autoplay=False, loop=True)

    # Clicking a thumbnail loads the matching example into the upload widget,
    # which in turn fires the change handler below.
    example_gallery.select(load_example_video, outputs=uploaded_file)
    uploaded_file.change(
        fn=handle_file_upload,
        inputs=[uploaded_file],
        outputs=[file_info, file_metadata, fps_slider]
    )

    # The outputs list order matches the 8-tuples yielded by the processing
    # generator; .then() frees GPU memory after the stream finishes.
    process_button.click(
        fn=process_video_ui,
        inputs=[file_metadata, gazing_ratio_slider, task_loss_slider, fps_slider],
        outputs=[
            gazing_pct,
            gazing_tokens,
            total_tokens,
            original_video,
            gazing_video,
            reconstruction_video,
            scales_stitch_video,
            status_text
        ]
    ).then(
        fn=cleanup_gpu,
        inputs=None,
        outputs=None
    )

    # Clean up GPU memory when user disconnects
    demo.unload(cleanup_gpu)

# Clear any cached models and free GPU memory at app startup
print("Clearing model cache and GPU memory at startup...")
model_cache.clear()
cleanup_gpu()
print("Startup cleanup complete.")

if __name__ == "__main__":
    # share=True exposes a temporary public gradio.live URL.
    demo.launch(share=True)