AutoGaze / app.py
bfshi's picture
update links (#2)
7e6904b
# IMPORTANT: Import spaces first, before any CUDA-related packages (torch, etc.)
try:
import spaces
ZEROGPU_AVAILABLE = True
except ImportError:
ZEROGPU_AVAILABLE = False
print("Warning: spaces module not available. Running without ZeroGPU support.")
import gradio as gr
import tempfile
import os
import torch
import gc
from demo_utils import load_model, process_video, save_video, image_to_video
import av
from PIL import Image
import numpy as np
# Loaded model setups, keyed by device string (e.g. "cuda" or "cpu").
model_cache = {}


def get_model(device):
    """Return the model setup for ``device``, loading it on first use.

    The first call for a given device invokes ``load_model(device=device)``
    and stores the result in ``model_cache``; later calls return that same
    stored object without reloading.
    """
    try:
        return model_cache[device]
    except KeyError:
        pass
    loaded_setup = load_model(device=device)
    model_cache[device] = loaded_setup
    return loaded_setup
# Choose the torch device used for inference. On ZeroGPU, a CUDA device is
# attached on demand, so "cuda" is chosen even though no GPU may be visible
# yet. Otherwise use a local GPU when torch can see one, falling back to CPU.
if ZEROGPU_AVAILABLE:
    device = "cuda"
    print("Using ZeroGPU (CUDA device will be allocated on demand)")
elif not torch.cuda.is_available():
    device = "cpu"
    print("No GPU available, using CPU")
else:
    device = "cuda"
    print(f"Using CUDA GPU: {torch.cuda.get_device_name(0)}")
def cleanup_gpu():
    """Free memory held by unreachable objects and the CUDA caching allocator.

    Always runs a Python garbage-collection pass. When CUDA is available, it
    also releases unoccupied cached GPU memory and then blocks until queued
    CUDA work has finished. Returns ``None``.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
def extract_metadata(file):
if file is None:
return "", None, None, None, None, None
file_extension = os.path.splitext(file.name)[1].lower()
is_image = file_extension in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp']
if is_image:
with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_video:
tmp_path = tmp_video.name
metadata = image_to_video(file.name, tmp_path, fps=1.0)
total_frames = metadata['frames']
fps = metadata['fps']
original_height = metadata['height']
original_width = metadata['width']
info_text = f"{original_width}×{original_height} | Image (1 frame)"
else:
tmp_path = file.name
container = av.open(tmp_path)
video_stream = container.streams.video[0]
total_frames = video_stream.frames
fps = float(video_stream.average_rate)
original_height = video_stream.height
original_width = video_stream.width
container.close()
info_text = f"{original_width}×{original_height} | {total_frames} frames @ {fps:.1f} FPS"
return info_text, tmp_path, total_frames, fps, original_width, original_height
def handle_file_upload(file):
    """Compute the UI updates that follow a file upload.

    Returns a 3-tuple: the info text, the full metadata tuple produced by
    ``extract_metadata``, and the source FPS. When ``extract_metadata``
    reports no usable path, it returns ``("", None, None)`` instead.
    """
    file_meta = extract_metadata(file)
    info_text, video_path, _, source_fps, _, _ = file_meta
    if video_path is None:
        return "", None, None
    return info_text, file_meta, source_fps
def _process_video_impl(file_info, gazing_ratio, task_loss_requirement, output_fps, progress=None):
if file_info is None:
return None, None, None, None, None, None, None, "No file uploaded"
_, tmp_path, total_frames, fps, _, _ = file_info
if tmp_path is None:
return None, None, None, None, None, None, None, "Invalid file"
# Yield initial status
yield None, None, None, None, None, None, None, "Loading model..."
if progress:
progress(0.0, desc="Loading model...")
setup = get_model(device)
yield None, None, None, None, None, None, None, "Processing video..."
if progress:
progress(0.1, desc="Processing video...")
status_messages = []
def update_progress(pct, msg):
if progress:
progress(pct, desc=msg)
status_messages.append(msg)
# Convert UI gazing ratio to model gazing ratio
# UI: ranges from 1/196 to 265/196 (effective patches per frame / 196)
# Model: needs value * (196/265) to get actual gazing ratio
model_gazing_ratio = gazing_ratio * (196 / 265)
for results in process_video(
tmp_path,
setup,
gazing_ratio=model_gazing_ratio,
task_loss_requirement=task_loss_requirement,
progress_callback=update_progress,
spatial_batch_size=2 # Process 4 spatial chunks at a time to avoid OOM
):
if status_messages:
yield None, None, None, None, None, None, None, status_messages[-1]
yield None, None, None, None, None, None, None, "Saving output videos..."
with tempfile.TemporaryDirectory() as tmpdir:
original_path = os.path.join(tmpdir, "original.mp4")
gazing_path = os.path.join(tmpdir, "gazing.mp4")
recon_path = os.path.join(tmpdir, "reconstruction.mp4")
scales_stitch_path = os.path.join(tmpdir, "scales_stitch.mp4")
# Use output_fps if specified, otherwise use original video fps
fps_to_use = output_fps if output_fps is not None else results['fps']
save_video(results['original_frames'], original_path, fps_to_use)
save_video(results['gazing_frames'], gazing_path, fps_to_use)
save_video(results['reconstruction_frames'], recon_path, fps_to_use)
save_video(results['scales_stitch_frames'], scales_stitch_path, fps_to_use)
with open(original_path, "rb") as f:
original_data = f.read()
with open(gazing_path, "rb") as f:
gazing_data = f.read()
with open(recon_path, "rb") as f:
recon_data = f.read()
with open(scales_stitch_path, "rb") as f:
scales_stitch_data = f.read()
original_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
original_file.write(original_data)
original_file.close()
gazing_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
gazing_file.write(gazing_data)
gazing_file.close()
recon_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
recon_file.write(recon_data)
recon_file.close()
scales_stitch_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
scales_stitch_file.write(scales_stitch_data)
scales_stitch_file.close()
gazing_pct_text = f"{results['gazing_pct']:.2%}"
gazing_tokens_text = f"{results['total_gazing_tokens']:,}"
total_tokens_text = f"{results['total_possible_tokens']:,}"
yield (
gazing_pct_text,
gazing_tokens_text,
total_tokens_text,
original_file.name,
gazing_file.name,
recon_file.name,
scales_stitch_file.name,
"Processing complete!"
)
# Event handler for the "Process Video" button. On ZeroGPU, the processing
# generator is wrapped with spaces.GPU (duration=120) so a GPU is attached
# for each call. Elsewhere the generator is used directly.
process_video_ui = (
    spaces.GPU(duration=120)(_process_video_impl)
    if ZEROGPU_AVAILABLE
    else _process_video_impl
)
def extract_first_frame_thumbnail(video_path, output_path, size=(200, 200), force=False):
    """Save a center-square-cropped, resized copy of a video's first frame.

    Args:
        video_path: Path of the video to read.
        output_path: Destination image path. PIL picks the format from the
            file extension.
        size: ``(width, height)`` of the saved thumbnail.
        force: Regenerate even if ``output_path`` already exists.

    Does nothing if ``output_path`` exists and ``force`` is false, or if the
    video yields no decodable frame.
    """
    if os.path.exists(output_path) and not force:
        return
    # The context manager closes the container even if decoding raises.
    with av.open(video_path) as container:
        first_frame = next(container.decode(video=0), None)
        if first_frame is None:
            return
        img = first_frame.to_image()
    # Crop to the centered square first, then resize, so the image is not
    # distorted when a square ``size`` is requested.
    width, height = img.size
    side = min(width, height)
    left = (width - side) // 2
    top = (height - side) // 2
    img_cropped = img.crop((left, top, left + side, top + side))
    img_cropped.resize(size, Image.LANCZOS).save(output_path)
# Example clips offered in the gallery. The gallery order must match this list.
example_videos = [
    "example_inputs/doorbell.mp4",
    "example_inputs/tomjerry.mp4",
    "example_inputs/security.mp4",
]


def _thumbnail_path(video_path):
    """Return the thumbnail path for a clip (``foo.mp4`` -> ``foo_thumb.png``)."""
    return os.path.splitext(video_path)[0] + "_thumb.png"


def _load_thumbnail(path):
    """Load an image file into a numpy array, closing the file afterwards."""
    with Image.open(path) as img:
        return np.array(img)


# Regenerate (force=True) a square 100x100 thumbnail for every example clip
# that exists on disk, at each startup.
for video_path in example_videos:
    if os.path.exists(video_path):
        extract_first_frame_thumbnail(video_path, _thumbnail_path(video_path), size=(100, 100), force=True)

# Load thumbnails as numpy arrays, one per entry of example_videos, in order.
doorbell_thumb_img, tomjerry_thumb_img, security_thumb_img = [
    _load_thumbnail(_thumbnail_path(path)) for path in example_videos
]
# Gradio UI: an upload row on top, then a two-column row with example clips
# and settings on the left and processing results on the right, followed by
# the event wiring.
# NOTE(review): Gradio documents delete_cache as (frequency_s, age_s);
# (86400, 86400) presumably purges cached files older than a day, once a day.
with gr.Blocks(title="AutoGaze Demo", delete_cache=(86400, 86400)) as demo:
    # Page header: demo title, paper title and external links.
    gr.Markdown("# AutoGaze Official Demo")
    gr.Markdown("## **Attend Before Attention: Efficient and Scalable Video Understanding via Autoregressive Gazing**")
    gr.Markdown("""
<div style="text-align: left; margin: 10px 0; font-size: 1.2em; font-weight: 600;">
📄 <a href="https://arxiv.org/abs/2603.12254" target="_blank" style="text-decoration: none; color: inherit;">Paper</a> &nbsp;&nbsp;&nbsp;&nbsp;&nbsp; 🌐 <a href="https://autogaze.github.io" target="_blank" style="text-decoration: none; color: inherit;">Project Website</a>
</div>
""")
    # Session state: holds the metadata tuple returned by handle_file_upload
    # and is passed as the first input to process_video_ui.
    file_metadata = gr.State()
    with gr.Row():
        with gr.Column(scale=2):
            uploaded_file = gr.File(
                label="Upload Video or Image",
                file_types=["video", "image"]
            )
        with gr.Column(scale=1):
            file_info = gr.Textbox(label="File Info", interactive=False)
            process_button = gr.Button("Process Video", variant="primary")

    def load_example_video(evt: gr.SelectData):
        """Return the video path for the clicked example-gallery thumbnail.

        The result is written into ``uploaded_file`` by the
        ``example_gallery.select`` wiring below. Indices must stay in the
        same order as the gallery's thumbnails.
        """
        video_map = {
            0: "example_inputs/doorbell.mp4",
            1: "example_inputs/tomjerry.mp4",
            2: "example_inputs/security.mp4",
        }
        return video_map[evt.index]

    with gr.Row():
        # Left column: example clips and all user-tunable settings.
        with gr.Column(scale=1):
            gr.Markdown("### Example Videos - Click Thumbnail to Load")
            example_gallery = gr.Gallery(
                value=[
                    (doorbell_thumb_img, "doorbell.mp4"),
                    (tomjerry_thumb_img, "tomjerry.mp4"),
                    (security_thumb_img, "security.mp4"),
                ],
                label="",
                show_label=False,
                columns=3,
                rows=1,
                height=200,
                object_fit="contain",
                allow_preview=False
            )
            gr.Markdown("### Settings")
            with gr.Accordion("Output Settings", open=True):
                # Starts empty (None). An upload fills in the source FPS, and
                # None at processing time means "use the FPS from the results".
                fps_slider = gr.Number(
                    label="Output FPS",
                    value=None,
                    minimum=1,
                    maximum=120,
                    info="Frames per second for displaying output videos (only affects playback speed)"
                )
            with gr.Accordion("Model Parameters", open=True):
                # UI range is 1/196 .. 265/196 (about 0.01 .. 1.35). The
                # processing function rescales this value by 196/265 before
                # passing it to the model.
                gazing_ratio_slider = gr.Slider(
                    label="Gazing Ratio",
                    minimum=round(1/196, 2),
                    maximum=round(265/196, 2),
                    step=0.01,
                    value=0.75,
                    info="Max fraction of patches to gaze at per frame"
                )
                task_loss_slider = gr.Slider(
                    label="Task Loss Requirement",
                    minimum=0.0,
                    maximum=1.5,
                    step=0.05,
                    value=0.7,
                    info="Reconstruction loss threshold"
                )
            with gr.Accordion("FAQ", open=False):
                gr.Markdown("""
**What file formats are supported?**
The app supports common video formats (MP4, AVI, MOV, etc.) and image formats (JPG, PNG, etc.).
**What is the Gazing Ratio?**
The gazing ratio explicitly controls how many patches the model looks at per frame. Higher values mean more patches are selected. The range extends to past 1.0 because of multi-scale gazing; if all patches at all scales are selected, the ratio can reach up to 1.35.
**What is Task Loss Requirement?**
This threshold determines when the model stops gazing at a frame, based on the predicted reconstruction loss from the current gazed patches. Lower = more gazing, higher = less gazing.
**How do Gazing Ratio and Task Loss interact?**
These two parameters separately control the number of gazed patches in an image/video. This demo will take the stricter of the two requirements when determining how many patches to gaze at. For example, if the gazing ratio suggests gazing at 15% of patches, but the task loss requirement is met after only 7% patches, then only 7% patches will be gazed at. To only use one of the two parameters, set the other to its maximum value.
""")
        # Right column: status, summary statistics and the four output videos.
        with gr.Column(scale=2):
            gr.Markdown("### Results")
            status_text = gr.Markdown("Ready")
            with gr.Row():
                gazing_pct = gr.Textbox(label="Gazing %", interactive=False)
                gazing_tokens = gr.Textbox(label="# Gazed Patches", interactive=False)
                total_tokens = gr.Textbox(label="Total Patches", interactive=False)
            with gr.Row():
                original_video = gr.Video(label="Original", autoplay=False, loop=True)
                gazing_video = gr.Video(label="Gazing Pattern (all scales)", autoplay=False, loop=True)
                reconstruction_video = gr.Video(label="Reconstruction", autoplay=False, loop=True)
            with gr.Row():
                scales_stitch_video = gr.Video(label="Gazing Pattern (individual scales)", autoplay=False, loop=True)

    # Clicking a gallery thumbnail puts that clip's path into the upload
    # widget. Gradio's .change also fires on programmatic updates, so the
    # upload handler below runs too.
    example_gallery.select(load_example_video, outputs=uploaded_file)
    # On upload: show file info, stash the metadata tuple in session state,
    # and prefill the output FPS field with the source FPS.
    uploaded_file.change(
        fn=handle_file_upload,
        inputs=[uploaded_file],
        outputs=[file_info, file_metadata, fps_slider]
    )
    # process_video_ui yields 8-tuples matching these outputs in order. Once
    # it finishes, cleanup_gpu frees GPU memory.
    process_button.click(
        fn=process_video_ui,
        inputs=[file_metadata, gazing_ratio_slider, task_loss_slider, fps_slider],
        outputs=[
            gazing_pct,
            gazing_tokens,
            total_tokens,
            original_video,
            gazing_video,
            reconstruction_video,
            scales_stitch_video,
            status_text
        ]
    ).then(
        fn=cleanup_gpu,
        inputs=None,
        outputs=None
    )
    # Clean up GPU memory when user disconnects
    demo.unload(cleanup_gpu)
def _startup_cleanup():
    """Empty the model cache and release GPU memory, logging each step."""
    print("Clearing model cache and GPU memory at startup...")
    model_cache.clear()
    cleanup_gpu()
    print("Startup cleanup complete.")


# Runs once at import time, before the app is launched.
_startup_cleanup()

if __name__ == "__main__":
    demo.launch(share=True)