Spaces:
Running on Zero
| # IMPORTANT: Import spaces first, before any CUDA-related packages (torch, etc.) | |
| try: | |
| import spaces | |
| ZEROGPU_AVAILABLE = True | |
| except ImportError: | |
| ZEROGPU_AVAILABLE = False | |
| print("Warning: spaces module not available. Running without ZeroGPU support.") | |
| import gradio as gr | |
| import tempfile | |
| import os | |
| import torch | |
| import gc | |
| from demo_utils import load_model, process_video, save_video, image_to_video | |
| import av | |
| from PIL import Image | |
| import numpy as np | |
# Process-wide cache of loaded model setups, keyed by device string.
model_cache = {}


def get_model(device):
    """Return the model setup for ``device``, loading it on first request.

    Subsequent calls with the same device reuse the cached object stored in
    ``model_cache``.
    """
    try:
        return model_cache[device]
    except KeyError:
        pass
    loaded = load_model(device=device)
    model_cache[device] = loaded
    return loaded
# Pick the compute device once, at import time. Under ZeroGPU the GPU is only
# attached on demand, so "cuda" is selected even though none is visible yet.
if ZEROGPU_AVAILABLE:
    device = "cuda"
    print("Using ZeroGPU (CUDA device will be allocated on demand)")
elif not torch.cuda.is_available():
    device = "cpu"
    print("No GPU available, using CPU")
else:
    device = "cuda"
    print(f"Using CUDA GPU: {torch.cuda.get_device_name(0)}")
def cleanup_gpu():
    """Run Python's garbage collector, then release cached CUDA memory.

    The CUDA calls are skipped entirely when torch reports no CUDA device.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
def extract_metadata(file):
    """Describe an uploaded file as a video and return its basic metadata.

    Images are first converted to a one-frame ``.mp4`` (written to a
    persistent temporary file) via ``image_to_video``, so downstream code only
    ever processes videos.

    Args:
        file: Value of the ``gr.File`` component: a path string, an
            ``os.PathLike``, an object exposing the path as ``.name``, or
            ``None``.

    Returns:
        ``(info_text, video_path, total_frames, fps, width, height)``. When
        ``file`` is ``None`` this is ``("", None, None, None, None, None)``.
        For videos, ``fps`` is ``None`` if the stream reports no usable frame
        rate, and ``total_frames`` is whatever PyAV reports (0 when the
        container does not store a frame count).
    """
    if file is None:
        return "", None, None, None, None, None

    # Gradio passes a str (subclass) path; older versions passed a tempfile
    # wrapper whose ``.name`` is the path. Don't use ``.name`` on PathLike
    # objects: for pathlib.Path it is only the basename.
    path = os.fspath(file) if isinstance(file, (str, os.PathLike)) else file.name
    extension = os.path.splitext(path)[1].lower()

    if extension in ('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'):
        # delete=False: the converted video must outlive this call.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_video:
            video_path = tmp_video.name
        metadata = image_to_video(path, video_path, fps=1.0)
        total_frames = metadata['frames']
        fps = metadata['fps']
        height = metadata['height']
        width = metadata['width']
        info_text = f"{width}×{height} | Image (1 frame)"
    else:
        video_path = path
        # The context manager closes the container even if reading the
        # stream properties raises (e.g. a file without a video stream).
        with av.open(video_path) as container:
            stream = container.streams.video[0]
            total_frames = stream.frames
            # average_rate can be None for some containers; fall back to
            # PyAV's guessed rate before giving up.
            rate = stream.average_rate or stream.guessed_rate
            height = stream.height
            width = stream.width
        fps = float(rate) if rate else None
        fps_text = f"{fps:.1f} FPS" if fps is not None else "unknown FPS"
        info_text = f"{width}×{height} | {total_frames} frames @ {fps_text}"
    return info_text, video_path, total_frames, fps, width, height
def handle_file_upload(file):
    """Compute the UI updates for a newly uploaded (or cleared) file.

    Returns ``(info_text, metadata, fps)`` for the file-info textbox, the
    metadata ``gr.State`` and the output-FPS field. If ``extract_metadata``
    produced no path, all three outputs are reset to ``("", None, None)``.
    """
    metadata = extract_metadata(file)
    info_text, video_path, _frames, fps, _width, _height = metadata
    if video_path is None:
        return "", None, None
    return info_text, metadata, fps
| def _process_video_impl(file_info, gazing_ratio, task_loss_requirement, output_fps, progress=None): | |
| if file_info is None: | |
| return None, None, None, None, None, None, None, "No file uploaded" | |
| _, tmp_path, total_frames, fps, _, _ = file_info | |
| if tmp_path is None: | |
| return None, None, None, None, None, None, None, "Invalid file" | |
| # Yield initial status | |
| yield None, None, None, None, None, None, None, "Loading model..." | |
| if progress: | |
| progress(0.0, desc="Loading model...") | |
| setup = get_model(device) | |
| yield None, None, None, None, None, None, None, "Processing video..." | |
| if progress: | |
| progress(0.1, desc="Processing video...") | |
| status_messages = [] | |
| def update_progress(pct, msg): | |
| if progress: | |
| progress(pct, desc=msg) | |
| status_messages.append(msg) | |
| # Convert UI gazing ratio to model gazing ratio | |
| # UI: ranges from 1/196 to 265/196 (effective patches per frame / 196) | |
| # Model: needs value * (196/265) to get actual gazing ratio | |
| model_gazing_ratio = gazing_ratio * (196 / 265) | |
| for results in process_video( | |
| tmp_path, | |
| setup, | |
| gazing_ratio=model_gazing_ratio, | |
| task_loss_requirement=task_loss_requirement, | |
| progress_callback=update_progress, | |
| spatial_batch_size=2 # Process 4 spatial chunks at a time to avoid OOM | |
| ): | |
| if status_messages: | |
| yield None, None, None, None, None, None, None, status_messages[-1] | |
| yield None, None, None, None, None, None, None, "Saving output videos..." | |
| with tempfile.TemporaryDirectory() as tmpdir: | |
| original_path = os.path.join(tmpdir, "original.mp4") | |
| gazing_path = os.path.join(tmpdir, "gazing.mp4") | |
| recon_path = os.path.join(tmpdir, "reconstruction.mp4") | |
| scales_stitch_path = os.path.join(tmpdir, "scales_stitch.mp4") | |
| # Use output_fps if specified, otherwise use original video fps | |
| fps_to_use = output_fps if output_fps is not None else results['fps'] | |
| save_video(results['original_frames'], original_path, fps_to_use) | |
| save_video(results['gazing_frames'], gazing_path, fps_to_use) | |
| save_video(results['reconstruction_frames'], recon_path, fps_to_use) | |
| save_video(results['scales_stitch_frames'], scales_stitch_path, fps_to_use) | |
| with open(original_path, "rb") as f: | |
| original_data = f.read() | |
| with open(gazing_path, "rb") as f: | |
| gazing_data = f.read() | |
| with open(recon_path, "rb") as f: | |
| recon_data = f.read() | |
| with open(scales_stitch_path, "rb") as f: | |
| scales_stitch_data = f.read() | |
| original_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") | |
| original_file.write(original_data) | |
| original_file.close() | |
| gazing_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") | |
| gazing_file.write(gazing_data) | |
| gazing_file.close() | |
| recon_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") | |
| recon_file.write(recon_data) | |
| recon_file.close() | |
| scales_stitch_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") | |
| scales_stitch_file.write(scales_stitch_data) | |
| scales_stitch_file.close() | |
| gazing_pct_text = f"{results['gazing_pct']:.2%}" | |
| gazing_tokens_text = f"{results['total_gazing_tokens']:,}" | |
| total_tokens_text = f"{results['total_possible_tokens']:,}" | |
| yield ( | |
| gazing_pct_text, | |
| gazing_tokens_text, | |
| total_tokens_text, | |
| original_file.name, | |
| gazing_file.name, | |
| recon_file.name, | |
| scales_stitch_file.name, | |
| "Processing complete!" | |
| ) | |
# Under ZeroGPU, wrap the implementation with spaces.GPU (requesting a
# 120-second duration per call); otherwise use the implementation directly.
process_video_ui = (
    spaces.GPU(duration=120)(_process_video_impl)
    if ZEROGPU_AVAILABLE
    else _process_video_impl
)
def extract_first_frame_thumbnail(video_path, output_path, size=(200, 200), force=False):
    """Save a center-cropped thumbnail of a video's first frame.

    The frame is center-cropped to the aspect ratio of ``size`` (a square crop
    for square sizes) and then resized to exactly ``size`` with Lanczos
    resampling, so the thumbnail is never distorted.

    Args:
        video_path: Video to read.
        output_path: Destination image path (format chosen by PIL from the
            extension).
        size: ``(width, height)`` of the thumbnail.
        force: Regenerate even if ``output_path`` already exists.

    Does nothing if ``output_path`` exists and ``force`` is false, or if the
    video yields no decodable frame.
    """
    if os.path.exists(output_path) and not force:
        return
    # The context manager closes the container even if decoding raises.
    with av.open(video_path) as container:
        first_frame = next(container.decode(video=0), None)
        if first_frame is None:
            return
        img = first_frame.to_image()

    target_w, target_h = size
    width, height = img.size
    # Largest centered box with the target aspect ratio (integer math keeps
    # the square case identical to a min(width, height) crop).
    if width * target_h > height * target_w:
        crop_w, crop_h = height * target_w // target_h, height
    else:
        crop_w, crop_h = width, width * target_h // target_w
    left = (width - crop_w) // 2
    top = (height - crop_h) // 2
    thumbnail = img.crop((left, top, left + crop_w, top + crop_h)).resize(size, Image.LANCZOS)
    thumbnail.save(output_path)
# Example videos shown in the gallery. Each thumbnail is written next to its
# video as "<name>_thumb.png".
example_videos = [
    "example_inputs/doorbell.mp4",
    "example_inputs/tomjerry.mp4",
    "example_inputs/security.mp4",
]
for video_path in example_videos:
    if os.path.exists(video_path):
        thumb_path = os.path.splitext(video_path)[0] + "_thumb.png"
        # Force regeneration with square aspect ratio at 100x100 to match gallery height
        extract_first_frame_thumbnail(video_path, thumb_path, size=(100, 100), force=True)


def _load_thumbnail(path, size=(100, 100)):
    """Load a thumbnail image as a numpy array for the example gallery.

    If the file is missing (e.g. its example video was not shipped), return a
    plain grey ``size`` placeholder instead of crashing the app at import.
    """
    if not os.path.exists(path):
        print(f"Warning: thumbnail {path} not found; using a placeholder.")
        return np.full((size[1], size[0], 3), 200, dtype=np.uint8)
    with Image.open(path) as img:
        return np.array(img)


# Load thumbnails as numpy arrays
doorbell_thumb_img = _load_thumbnail("example_inputs/doorbell_thumb.png")
tomjerry_thumb_img = _load_thumbnail("example_inputs/tomjerry_thumb.png")
security_thumb_img = _load_thumbnail("example_inputs/security_thumb.png")
# Gradio UI definition.
# - The uploaded file is parsed by handle_file_upload into the metadata tuple
#   stored in ``file_metadata`` (a gr.State).
# - The Process button streams 8-tuples from process_video_ui into the results
#   column, then runs cleanup_gpu.
# NOTE(review): delete_cache=(86400, 86400) is presumably Gradio's
# (frequency, age) in seconds for purging cached files; confirm against the
# installed Gradio version.
with gr.Blocks(title="AutoGaze Demo", delete_cache=(86400, 86400)) as demo:
    gr.Markdown("# AutoGaze Official Demo")
    gr.Markdown("## **Attend Before Attention: Efficient and Scalable Video Understanding via Autoregressive Gazing**")
    gr.Markdown("""
    <div style="text-align: left; margin: 10px 0; font-size: 1.2em; font-weight: 600;">
    📄 <a href="https://arxiv.org/abs/2603.12254" target="_blank" style="text-decoration: none; color: inherit;">Paper</a> 🌐 <a href="https://autogaze.github.io" target="_blank" style="text-decoration: none; color: inherit;">Project Website</a>
    </div>
    """)
    # Holds the 6-tuple returned by extract_metadata(); read by process_video_ui.
    file_metadata = gr.State()
    with gr.Row():
        with gr.Column(scale=2):
            uploaded_file = gr.File(
                label="Upload Video or Image",
                file_types=["video", "image"]
            )
        with gr.Column(scale=1):
            file_info = gr.Textbox(label="File Info", interactive=False)
            process_button = gr.Button("Process Video", variant="primary")

    def load_example_video(evt: gr.SelectData):
        """Return the example video path for the clicked gallery thumbnail.

        The index-to-path mapping must stay in the same order as the
        thumbnails passed to ``example_gallery`` below.
        """
        video_map = {
            0: "example_inputs/doorbell.mp4",
            1: "example_inputs/tomjerry.mp4",
            2: "example_inputs/security.mp4",
        }
        return video_map[evt.index]

    with gr.Row():
        # Left column: examples, settings and FAQ.
        with gr.Column(scale=1):
            gr.Markdown("### Example Videos - Click Thumbnail to Load")
            example_gallery = gr.Gallery(
                value=[
                    (doorbell_thumb_img, "doorbell.mp4"),
                    (tomjerry_thumb_img, "tomjerry.mp4"),
                    (security_thumb_img, "security.mp4"),
                ],
                label="",
                show_label=False,
                columns=3,
                rows=1,
                height=200,
                object_fit="contain",
                # Presumably disables the enlarged preview so a click only
                # triggers the select handler wired up below.
                allow_preview=False
            )
            gr.Markdown("### Settings")
            with gr.Accordion("Output Settings", open=True):
                # A gr.Number despite the name. It is filled with the source FPS
                # by handle_file_upload; when left empty (None),
                # _process_video_impl falls back to the FPS from process_video.
                fps_slider = gr.Number(
                    label="Output FPS",
                    value=None,
                    minimum=1,
                    maximum=120,
                    info="Frames per second for displaying output videos (only affects playback speed)"
                )
            with gr.Accordion("Model Parameters", open=True):
                # UI scale is (effective patches per frame / 196), rounded to two
                # decimals; _process_video_impl rescales it by 196/265 before
                # passing it to the model.
                gazing_ratio_slider = gr.Slider(
                    label="Gazing Ratio",
                    minimum=round(1/196, 2),
                    maximum=round(265/196, 2),
                    step=0.01,
                    value=0.75,
                    info="Max fraction of patches to gaze at per frame"
                )
                task_loss_slider = gr.Slider(
                    label="Task Loss Requirement",
                    minimum=0.0,
                    maximum=1.5,
                    step=0.05,
                    value=0.7,
                    info="Reconstruction loss threshold"
                )
            # NOTE(review): the FAQ says "Lower = more gazing" for the task loss,
            # yet also says to neutralise a parameter by setting it to its maximum.
            # For the task loss, the maximum means the *least* gazing, so the
            # neutral setting is presumably the minimum (0.0). Verify and fix the text.
            with gr.Accordion("FAQ", open=False):
                gr.Markdown("""
                **What file formats are supported?**
                The app supports common video formats (MP4, AVI, MOV, etc.) and image formats (JPG, PNG, etc.).
                **What is the Gazing Ratio?**
                The gazing ratio explicitly controls how many patches the model looks at per frame. Higher values mean more patches are selected. The range extends to past 1.0 because of multi-scale gazing; if all patches at all scales are selected, the ratio can reach up to 1.35.
                **What is Task Loss Requirement?**
                This threshold determines when the model stops gazing at a frame, based on the predicted reconstruction loss from the current gazed patches. Lower = more gazing, higher = less gazing.
                **How do Gazing Ratio and Task Loss interact?**
                These two parameters separately control the number of gazed patches in an image/video. This demo will take the stricter of the two requirements when determining how many patches to gaze at. For example, if the gazing ratio suggests gazing at 15% of patches, but the task loss requirement is met after only 7% patches, then only 7% patches will be gazed at. To only use one of the two parameters, set the other to its maximum value.
                """)
        # Right column: status line, token statistics and the four output videos.
        with gr.Column(scale=2):
            gr.Markdown("### Results")
            status_text = gr.Markdown("Ready")
            with gr.Row():
                gazing_pct = gr.Textbox(label="Gazing %", interactive=False)
                gazing_tokens = gr.Textbox(label="# Gazed Patches", interactive=False)
                total_tokens = gr.Textbox(label="Total Patches", interactive=False)
            with gr.Row():
                original_video = gr.Video(label="Original", autoplay=False, loop=True)
                gazing_video = gr.Video(label="Gazing Pattern (all scales)", autoplay=False, loop=True)
                reconstruction_video = gr.Video(label="Reconstruction", autoplay=False, loop=True)
            with gr.Row():
                scales_stitch_video = gr.Video(label="Gazing Pattern (individual scales)", autoplay=False, loop=True)

    # Clicking a thumbnail loads its video into the upload widget, which in
    # turn fires uploaded_file.change below.
    example_gallery.select(load_example_video, outputs=uploaded_file)
    uploaded_file.change(
        fn=handle_file_upload,
        inputs=[uploaded_file],
        outputs=[file_info, file_metadata, fps_slider]
    )
    # The output order must match the 8-tuples yielded by _process_video_impl.
    process_button.click(
        fn=process_video_ui,
        inputs=[file_metadata, gazing_ratio_slider, task_loss_slider, fps_slider],
        outputs=[
            gazing_pct,
            gazing_tokens,
            total_tokens,
            original_video,
            gazing_video,
            reconstruction_video,
            scales_stitch_video,
            status_text
        ]
    ).then(
        fn=cleanup_gpu,
        inputs=None,
        outputs=None
    )
    # Clean up GPU memory when user disconnects
    demo.unload(cleanup_gpu)
def _startup_cleanup():
    """Drop any cached models and release GPU memory before serving requests."""
    print("Clearing model cache and GPU memory at startup...")
    model_cache.clear()
    cleanup_gpu()
    print("Startup cleanup complete.")


_startup_cleanup()

if __name__ == "__main__":
    # share=True requests a public share link when run locally.
    demo.launch(share=True)