# IMPORTANT: Import spaces first, before any CUDA-related packages (torch, etc.) try: import spaces ZEROGPU_AVAILABLE = True except ImportError: ZEROGPU_AVAILABLE = False print("Warning: spaces module not available. Running without ZeroGPU support.") import gradio as gr import tempfile import os import torch import gc from demo_utils import load_model, process_video, save_video, image_to_video import av from PIL import Image import numpy as np model_cache = {} def get_model(device): if device not in model_cache: model_cache[device] = load_model(device=device) return model_cache[device] # Determine device: use CUDA if available locally or if ZeroGPU will provide it if ZEROGPU_AVAILABLE: device = "cuda" # ZeroGPU will provide GPU print("Using ZeroGPU (CUDA device will be allocated on demand)") elif torch.cuda.is_available(): device = "cuda" print(f"Using CUDA GPU: {torch.cuda.get_device_name(0)}") else: device = "cpu" print("No GPU available, using CPU") def cleanup_gpu(): """Clean up GPU memory.""" gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.synchronize() def extract_metadata(file): if file is None: return "", None, None, None, None, None file_extension = os.path.splitext(file.name)[1].lower() is_image = file_extension in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'] if is_image: with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_video: tmp_path = tmp_video.name metadata = image_to_video(file.name, tmp_path, fps=1.0) total_frames = metadata['frames'] fps = metadata['fps'] original_height = metadata['height'] original_width = metadata['width'] info_text = f"{original_width}×{original_height} | Image (1 frame)" else: tmp_path = file.name container = av.open(tmp_path) video_stream = container.streams.video[0] total_frames = video_stream.frames fps = float(video_stream.average_rate) original_height = video_stream.height original_width = video_stream.width container.close() info_text = f"{original_width}×{original_height} | {total_frames} frames @ {fps:.1f} FPS" return info_text, tmp_path, total_frames, fps, original_width, original_height def handle_file_upload(file): metadata = extract_metadata(file) if metadata[1] is None: return "", None, None info_text, tmp_path, total_frames, fps, original_width, original_height = metadata return info_text, metadata, fps def _process_video_impl(file_info, gazing_ratio, task_loss_requirement, output_fps, progress=None): if file_info is None: return None, None, None, None, None, None, None, "No file uploaded" _, tmp_path, total_frames, fps, _, _ = file_info if tmp_path is None: return None, None, None, None, None, None, None, "Invalid file" # Yield initial status yield None, None, None, None, None, None, None, "Loading model..." if progress: progress(0.0, desc="Loading model...") setup = get_model(device) yield None, None, None, None, None, None, None, "Processing video..." if progress: progress(0.1, desc="Processing video...") status_messages = [] def update_progress(pct, msg): if progress: progress(pct, desc=msg) status_messages.append(msg) # Convert UI gazing ratio to model gazing ratio # UI: ranges from 1/196 to 265/196 (effective patches per frame / 196) # Model: needs value * (196/265) to get actual gazing ratio model_gazing_ratio = gazing_ratio * (196 / 265) for results in process_video( tmp_path, setup, gazing_ratio=model_gazing_ratio, task_loss_requirement=task_loss_requirement, progress_callback=update_progress, spatial_batch_size=2 # Process 4 spatial chunks at a time to avoid OOM ): if status_messages: yield None, None, None, None, None, None, None, status_messages[-1] yield None, None, None, None, None, None, None, "Saving output videos..." with tempfile.TemporaryDirectory() as tmpdir: original_path = os.path.join(tmpdir, "original.mp4") gazing_path = os.path.join(tmpdir, "gazing.mp4") recon_path = os.path.join(tmpdir, "reconstruction.mp4") scales_stitch_path = os.path.join(tmpdir, "scales_stitch.mp4") # Use output_fps if specified, otherwise use original video fps fps_to_use = output_fps if output_fps is not None else results['fps'] save_video(results['original_frames'], original_path, fps_to_use) save_video(results['gazing_frames'], gazing_path, fps_to_use) save_video(results['reconstruction_frames'], recon_path, fps_to_use) save_video(results['scales_stitch_frames'], scales_stitch_path, fps_to_use) with open(original_path, "rb") as f: original_data = f.read() with open(gazing_path, "rb") as f: gazing_data = f.read() with open(recon_path, "rb") as f: recon_data = f.read() with open(scales_stitch_path, "rb") as f: scales_stitch_data = f.read() original_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") original_file.write(original_data) original_file.close() gazing_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") gazing_file.write(gazing_data) gazing_file.close() recon_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") recon_file.write(recon_data) recon_file.close() scales_stitch_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") scales_stitch_file.write(scales_stitch_data) scales_stitch_file.close() gazing_pct_text = f"{results['gazing_pct']:.2%}" gazing_tokens_text = f"{results['total_gazing_tokens']:,}" total_tokens_text = f"{results['total_possible_tokens']:,}" yield ( gazing_pct_text, gazing_tokens_text, total_tokens_text, original_file.name, gazing_file.name, recon_file.name, scales_stitch_file.name, "Processing complete!" ) if ZEROGPU_AVAILABLE: process_video_ui = spaces.GPU(duration=120)(_process_video_impl) else: process_video_ui = _process_video_impl def extract_first_frame_thumbnail(video_path, output_path, size=(200, 200), force=False): """Extract first frame from video and save as thumbnail with fixed aspect ratio.""" if os.path.exists(output_path) and not force: return container = av.open(video_path) for frame in container.decode(video=0): img = frame.to_image() # Crop to center square first, then resize width, height = img.size min_dim = min(width, height) left = (width - min_dim) // 2 top = (height - min_dim) // 2 img_cropped = img.crop((left, top, left + min_dim, top + min_dim)) img_resized = img_cropped.resize(size, Image.LANCZOS) img_resized.save(output_path) break container.close() # Generate thumbnails for example videos example_videos = [ "example_inputs/doorbell.mp4", "example_inputs/tomjerry.mp4", "example_inputs/security.mp4", ] for video_path in example_videos: if os.path.exists(video_path): thumb_path = video_path.replace('.mp4', '_thumb.png') # Force regeneration with square aspect ratio at 100x100 to match gallery height extract_first_frame_thumbnail(video_path, thumb_path, size=(100, 100), force=True) # Load thumbnails as numpy arrays doorbell_thumb_img = np.array(Image.open("example_inputs/doorbell_thumb.png")) tomjerry_thumb_img = np.array(Image.open("example_inputs/tomjerry_thumb.png")) security_thumb_img = np.array(Image.open("example_inputs/security_thumb.png")) with gr.Blocks(title="AutoGaze Demo", delete_cache=(86400, 86400)) as demo: gr.Markdown("# AutoGaze Official Demo") gr.Markdown("## **Attend Before Attention: Efficient and Scalable Video Understanding via Autoregressive Gazing**") gr.Markdown("""