# IMPORTANT: Import spaces first, before any CUDA-related packages (torch, etc.)
try:
    import spaces
    ZEROGPU_AVAILABLE = True
except ImportError:
    ZEROGPU_AVAILABLE = False
    print("Warning: spaces module not available. Running without ZeroGPU support.")

import gradio as gr
import tempfile
import os
import torch
import gc
from demo_utils import load_model, process_video, save_video, image_to_video
import av
from PIL import Image
import numpy as np

# Per-device cache of loaded model setups so the heavy load happens once.
model_cache = {}


def get_model(device):
    """Return the model setup for *device*, loading and caching it on first use."""
    if device not in model_cache:
        model_cache[device] = load_model(device=device)
    return model_cache[device]


# Determine device: use CUDA if available locally or if ZeroGPU will provide it
if ZEROGPU_AVAILABLE:
    device = "cuda"  # ZeroGPU will provide GPU
    print("Using ZeroGPU (CUDA device will be allocated on demand)")
elif torch.cuda.is_available():
    device = "cuda"
    print(f"Using CUDA GPU: {torch.cuda.get_device_name(0)}")
else:
    device = "cpu"
    print("No GPU available, using CPU")


def cleanup_gpu():
    """Clean up GPU memory (collect garbage, then empty the CUDA cache)."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()


def extract_metadata(file):
    """Extract display info and basic metadata from an uploaded video or image.

    Images are first converted to a 1-frame mp4 so downstream code only deals
    with video paths.

    Returns:
        (info_text, tmp_path, total_frames, fps, width, height); all values
        after info_text are None when *file* is None.
    """
    if file is None:
        return "", None, None, None, None, None

    file_extension = os.path.splitext(file.name)[1].lower()
    is_image = file_extension in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp']

    if is_image:
        # Reserve a persistent temp path, then render the image into it as a
        # single-frame video (delete=False so it survives this function).
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_video:
            tmp_path = tmp_video.name
        metadata = image_to_video(file.name, tmp_path, fps=1.0)
        total_frames = metadata['frames']
        fps = metadata['fps']
        original_height = metadata['height']
        original_width = metadata['width']
        info_text = f"{original_width}×{original_height} | Image (1 frame)"
    else:
        tmp_path = file.name
        container = av.open(tmp_path)
        try:
            video_stream = container.streams.video[0]
            total_frames = video_stream.frames
            fps = float(video_stream.average_rate)
            original_height = video_stream.height
            original_width = video_stream.width
        finally:
            # Close even if the stream metadata is malformed.
            container.close()
        info_text = f"{original_width}×{original_height} | {total_frames} frames @ {fps:.1f} FPS"

    return info_text, tmp_path, total_frames, fps, original_width, original_height


def handle_file_upload(file):
    """Gradio change-handler: returns (info text, metadata tuple, detected fps)."""
    metadata = extract_metadata(file)
    if metadata[1] is None:
        return "", None, None
    info_text, fps = metadata[0], metadata[3]
    return info_text, metadata, fps


def _process_video_impl(file_info, gazing_ratio, task_loss_requirement, output_fps, progress=None):
    """Run the gazing model on the uploaded file, yielding UI updates.

    This is a generator consumed by Gradio: each yield is the full 8-tuple of
    output components (stats, four video paths, status message).

    Args:
        file_info: metadata tuple produced by extract_metadata / handle_file_upload.
        gazing_ratio: UI-scale gazing ratio (patches per frame / 196).
        task_loss_requirement: reconstruction-loss stopping threshold.
        output_fps: playback fps for saved outputs, or None to keep source fps.
        progress: optional Gradio progress callback.
    """
    # NOTE: this is a generator — a plain ``return value`` would be silently
    # discarded, so error states must be *yielded* before returning.
    if file_info is None:
        yield None, None, None, None, None, None, None, "No file uploaded"
        return
    _, tmp_path, total_frames, fps, _, _ = file_info
    if tmp_path is None:
        yield None, None, None, None, None, None, None, "Invalid file"
        return

    # Yield initial status
    yield None, None, None, None, None, None, None, "Loading model..."
    if progress:
        progress(0.0, desc="Loading model...")
    setup = get_model(device)

    yield None, None, None, None, None, None, None, "Processing video..."
    if progress:
        progress(0.1, desc="Processing video...")

    status_messages = []

    def update_progress(pct, msg):
        # Forward progress to Gradio and remember the latest message for yields.
        if progress:
            progress(pct, desc=msg)
        status_messages.append(msg)

    # Convert UI gazing ratio to model gazing ratio
    # UI: ranges from 1/196 to 265/196 (effective patches per frame / 196)
    # Model: needs value * (196/265) to get actual gazing ratio
    model_gazing_ratio = gazing_ratio * (196 / 265)

    for results in process_video(
        tmp_path,
        setup,
        gazing_ratio=model_gazing_ratio,
        task_loss_requirement=task_loss_requirement,
        progress_callback=update_progress,
        spatial_batch_size=2,  # process spatial chunks in small batches to avoid OOM
    ):
        if status_messages:
            yield None, None, None, None, None, None, None, status_messages[-1]

    yield None, None, None, None, None, None, None, "Saving output videos..."

    def _persist(data):
        # Copy bytes into a persistent temp .mp4 that outlives the TemporaryDirectory,
        # so Gradio can serve it after this function returns.
        out = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
        out.write(data)
        out.close()
        return out.name

    with tempfile.TemporaryDirectory() as tmpdir:
        original_path = os.path.join(tmpdir, "original.mp4")
        gazing_path = os.path.join(tmpdir, "gazing.mp4")
        recon_path = os.path.join(tmpdir, "reconstruction.mp4")
        scales_stitch_path = os.path.join(tmpdir, "scales_stitch.mp4")

        # Use output_fps if specified, otherwise use original video fps
        fps_to_use = output_fps if output_fps is not None else results['fps']
        save_video(results['original_frames'], original_path, fps_to_use)
        save_video(results['gazing_frames'], gazing_path, fps_to_use)
        save_video(results['reconstruction_frames'], recon_path, fps_to_use)
        save_video(results['scales_stitch_frames'], scales_stitch_path, fps_to_use)

        with open(original_path, "rb") as f:
            original_file_name = _persist(f.read())
        with open(gazing_path, "rb") as f:
            gazing_file_name = _persist(f.read())
        with open(recon_path, "rb") as f:
            recon_file_name = _persist(f.read())
        with open(scales_stitch_path, "rb") as f:
            scales_stitch_file_name = _persist(f.read())

    gazing_pct_text = f"{results['gazing_pct']:.2%}"
    gazing_tokens_text = f"{results['total_gazing_tokens']:,}"
    total_tokens_text = f"{results['total_possible_tokens']:,}"

    yield (
        gazing_pct_text,
        gazing_tokens_text,
        total_tokens_text,
        original_file_name,
        gazing_file_name,
        recon_file_name,
        scales_stitch_file_name,
        "Processing complete!"
    )


# Wrap in the ZeroGPU decorator when running on Spaces; plain function otherwise.
if ZEROGPU_AVAILABLE:
    process_video_ui = spaces.GPU(duration=120)(_process_video_impl)
else:
    process_video_ui = _process_video_impl


def extract_first_frame_thumbnail(video_path, output_path, size=(200, 200), force=False):
    """Extract first frame from video and save as thumbnail with fixed aspect ratio."""
    if os.path.exists(output_path) and not force:
        return
    container = av.open(video_path)
    try:
        for frame in container.decode(video=0):
            img = frame.to_image()
            # Crop to center square first, then resize
            width, height = img.size
            min_dim = min(width, height)
            left = (width - min_dim) // 2
            top = (height - min_dim) // 2
            img_cropped = img.crop((left, top, left + min_dim, top + min_dim))
            img_resized = img_cropped.resize(size, Image.LANCZOS)
            img_resized.save(output_path)
            break  # only the first decodable frame is needed
    finally:
        container.close()


# Generate thumbnails for example videos
example_videos = [
    "example_inputs/doorbell.mp4",
    "example_inputs/tomjerry.mp4",
    "example_inputs/security.mp4",
]
for video_path in example_videos:
    if os.path.exists(video_path):
        thumb_path = video_path.replace('.mp4', '_thumb.png')
        # Force regeneration with square aspect ratio at 100x100 to match gallery height
        extract_first_frame_thumbnail(video_path, thumb_path, size=(100, 100), force=True)

# Load thumbnails as numpy arrays
doorbell_thumb_img = np.array(Image.open("example_inputs/doorbell_thumb.png"))
tomjerry_thumb_img = np.array(Image.open("example_inputs/tomjerry_thumb.png"))
security_thumb_img = np.array(Image.open("example_inputs/security_thumb.png"))


with gr.Blocks(title="AutoGaze Demo", delete_cache=(86400, 86400)) as demo:
    gr.Markdown("# AutoGaze Official Demo")
    gr.Markdown("## **Attend Before Attention: Efficient and Scalable Video Understanding via Autoregressive Gazing**")
    gr.Markdown("""
📄 Paper       🌐 Project Website
""")

    file_metadata = gr.State()

    with gr.Row():
        with gr.Column(scale=2):
            uploaded_file = gr.File(
                label="Upload Video or Image",
                file_types=["video", "image"]
            )
        with gr.Column(scale=1):
            file_info = gr.Textbox(label="File Info", interactive=False)
            process_button = gr.Button("Process Video", variant="primary")

    def load_example_video(evt: gr.SelectData):
        # Map gallery index -> example video path.
        video_map = {
            0: "example_inputs/doorbell.mp4",
            1: "example_inputs/tomjerry.mp4",
            2: "example_inputs/security.mp4",
        }
        return video_map[evt.index]

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Example Videos - Click Thumbnail to Load")
            example_gallery = gr.Gallery(
                value=[
                    (doorbell_thumb_img, "doorbell.mp4"),
                    (tomjerry_thumb_img, "tomjerry.mp4"),
                    (security_thumb_img, "security.mp4"),
                ],
                label="",
                show_label=False,
                columns=3,
                rows=1,
                height=200,
                object_fit="contain",
                allow_preview=False
            )

            gr.Markdown("### Settings")
            with gr.Accordion("Output Settings", open=True):
                fps_slider = gr.Number(
                    label="Output FPS",
                    value=None,
                    minimum=1,
                    maximum=120,
                    info="Frames per second for displaying output videos (only affects playback speed)"
                )
            with gr.Accordion("Model Parameters", open=True):
                gazing_ratio_slider = gr.Slider(
                    label="Gazing Ratio",
                    minimum=round(1 / 196, 2),
                    maximum=round(265 / 196, 2),
                    step=0.01,
                    value=0.75,
                    info="Max fraction of patches to gaze at per frame"
                )
                task_loss_slider = gr.Slider(
                    label="Task Loss Requirement",
                    minimum=0.0,
                    maximum=1.5,
                    step=0.05,
                    value=0.7,
                    info="Reconstruction loss threshold"
                )
            with gr.Accordion("FAQ", open=False):
                gr.Markdown("""
**What file formats are supported?**

The app supports common video formats (MP4, AVI, MOV, etc.) and image formats (JPG, PNG, etc.).

**What is the Gazing Ratio?**

The gazing ratio explicitly controls how many patches the model looks at per frame. Higher values mean more patches are selected. The range extends to past 1.0 because of multi-scale gazing; if all patches at all scales are selected, the ratio can reach up to 1.35.

**What is Task Loss Requirement?**

This threshold determines when the model stops gazing at a frame, based on the predicted reconstruction loss from the current gazed patches. Lower = more gazing, higher = less gazing.

**How do Gazing Ratio and Task Loss interact?**

These two parameters separately control the number of gazed patches in an image/video. This demo will take the stricter of the two requirements when determining how many patches to gaze at. For example, if the gazing ratio suggests gazing at 15% of patches, but the task loss requirement is met after only 7% patches, then only 7% patches will be gazed at. To only use one of the two parameters, set the other to its maximum value.
""")

        with gr.Column(scale=2):
            gr.Markdown("### Results")
            status_text = gr.Markdown("Ready")
            with gr.Row():
                gazing_pct = gr.Textbox(label="Gazing %", interactive=False)
                gazing_tokens = gr.Textbox(label="# Gazed Patches", interactive=False)
                total_tokens = gr.Textbox(label="Total Patches", interactive=False)
            with gr.Row():
                original_video = gr.Video(label="Original", autoplay=False, loop=True)
                gazing_video = gr.Video(label="Gazing Pattern (all scales)", autoplay=False, loop=True)
                reconstruction_video = gr.Video(label="Reconstruction", autoplay=False, loop=True)
            with gr.Row():
                scales_stitch_video = gr.Video(label="Gazing Pattern (individual scales)", autoplay=False, loop=True)

    example_gallery.select(load_example_video, outputs=uploaded_file)

    uploaded_file.change(
        fn=handle_file_upload,
        inputs=[uploaded_file],
        outputs=[file_info, file_metadata, fps_slider]
    )

    process_button.click(
        fn=process_video_ui,
        inputs=[file_metadata, gazing_ratio_slider, task_loss_slider, fps_slider],
        outputs=[
            gazing_pct, gazing_tokens, total_tokens,
            original_video, gazing_video, reconstruction_video,
            scales_stitch_video, status_text
        ]
    ).then(
        fn=cleanup_gpu,
        inputs=None,
        outputs=None
    )

    # Clean up GPU memory when user disconnects
    demo.unload(cleanup_gpu)

# Clear any cached models and free GPU memory at app startup
print("Clearing model cache and GPU memory at startup...")
model_cache.clear()
cleanup_gpu()
print("Startup cleanup complete.")

if __name__ == "__main__":
    demo.launch(share=True)