# AutoGaze demo application (app.py) — Gradio Space entry point.
# IMPORTANT: Import spaces first, before any CUDA-related packages (torch, etc.)
try:
    import spaces

    # Flag consulted below to decide whether GPU work is wrapped in
    # spaces.GPU(...) (Hugging Face ZeroGPU on-demand allocation).
    ZEROGPU_AVAILABLE = True
except ImportError:
    # Running outside Spaces (e.g. locally): fall back to plain CUDA/CPU.
    ZEROGPU_AVAILABLE = False
    print("Warning: spaces module not available. Running without ZeroGPU support.")
import gradio as gr
import tempfile
import os
import torch
import gc
from demo_utils import load_model, process_video, save_video, image_to_video
import av
from PIL import Image
import numpy as np
# Per-device cache so the (expensive) model load happens at most once per device.
model_cache = {}


def get_model(device):
    """Return the cached model setup for *device*, loading it on first use."""
    try:
        return model_cache[device]
    except KeyError:
        setup = load_model(device=device)
        model_cache[device] = setup
        return setup
# Determine device: use CUDA if available locally or if ZeroGPU will provide it
if ZEROGPU_AVAILABLE:
    # NOTE: under ZeroGPU the GPU does not exist at import time; it is
    # allocated only inside spaces.GPU-decorated calls, so we optimistically
    # target "cuda" here.
    device = "cuda"  # ZeroGPU will provide GPU
    print("Using ZeroGPU (CUDA device will be allocated on demand)")
elif torch.cuda.is_available():
    device = "cuda"
    print(f"Using CUDA GPU: {torch.cuda.get_device_name(0)}")
else:
    device = "cpu"
    print("No GPU available, using CPU")
def cleanup_gpu():
    """Run garbage collection and, when CUDA is present, release cached GPU memory."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
def extract_metadata(file):
    """Read basic metadata for an uploaded video or image.

    Images are wrapped into a single-frame 1 FPS mp4 (via image_to_video) so
    the rest of the pipeline only ever sees videos.

    Args:
        file: A Gradio file object (has ``.name``), a path string, or None.

    Returns:
        Tuple ``(info_text, tmp_path, total_frames, fps, width, height)``.
        When *file* is None: ``("", None, None, None, None, None)``.
    """
    if file is None:
        return "", None, None, None, None, None
    # Handle both file objects (UI) and string paths (API)
    file_path = file.name if hasattr(file, "name") else str(file)
    file_extension = os.path.splitext(file_path)[1].lower()
    is_image = file_extension in [".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"]
    if is_image:
        # delete=False: the temp video path must outlive this function — it is
        # later handed to process_video by the click handler.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_video:
            tmp_path = tmp_video.name
        metadata = image_to_video(file_path, tmp_path, fps=1.0)
        total_frames = metadata["frames"]
        fps = metadata["fps"]
        original_height = metadata["height"]
        original_width = metadata["width"]
        info_text = f"{original_width}×{original_height} | Image (1 frame)"
    else:
        tmp_path = file_path
        container = av.open(tmp_path)
        try:
            video_stream = container.streams.video[0]
            total_frames = video_stream.frames
            fps = float(video_stream.average_rate)
            original_height = video_stream.height
            original_width = video_stream.width
        finally:
            # BUG FIX: close the container even when stream inspection raises
            # (e.g. no video stream), otherwise the file handle leaks.
            container.close()
        info_text = f"{original_width}×{original_height} | {total_frames} frames @ {fps:.1f} FPS"
    return info_text, tmp_path, total_frames, fps, original_width, original_height
def handle_file_upload(file):
    """Gradio change-handler for the upload widget.

    Returns ``(info_text, metadata_tuple, fps)`` for the file-info textbox,
    the file_metadata State, and the output-FPS number box respectively;
    ``("", None, None)`` when the file could not be read.
    """
    metadata = extract_metadata(file)
    if metadata[1] is None:
        return "", None, None
    # metadata layout: (info_text, tmp_path, total_frames, fps, width, height)
    return metadata[0], metadata, metadata[3]
def _process_video_impl(
file_info, gazing_ratio, task_loss_requirement, output_fps, progress=None
):
if file_info is None:
return None, None, None, None, None, None, None, "No file uploaded"
_, tmp_path, total_frames, fps, _, _ = file_info
if tmp_path is None:
return None, None, None, None, None, None, None, "Invalid file"
# Yield initial status
yield None, None, None, None, None, None, None, "Loading model..."
if progress:
progress(0.0, desc="Loading model...")
setup = get_model(device)
yield None, None, None, None, None, None, None, "Processing video..."
if progress:
progress(0.1, desc="Processing video...")
status_messages = []
def update_progress(pct, msg):
if progress:
progress(pct, desc=msg)
status_messages.append(msg)
# Convert UI gazing ratio to model gazing ratio
# UI: ranges from 1/196 to 265/196 (effective patches per frame / 196)
# Model: needs value * (196/265) to get actual gazing ratio
model_gazing_ratio = gazing_ratio * (196 / 265)
for results in process_video(
tmp_path,
setup,
gazing_ratio=model_gazing_ratio,
task_loss_requirement=task_loss_requirement,
progress_callback=update_progress,
spatial_batch_size=2, # Process 4 spatial chunks at a time to avoid OOM
):
if status_messages:
yield None, None, None, None, None, None, None, status_messages[-1]
yield None, None, None, None, None, None, None, "Saving output videos..."
with tempfile.TemporaryDirectory() as tmpdir:
original_path = os.path.join(tmpdir, "original.mp4")
gazing_path = os.path.join(tmpdir, "gazing.mp4")
recon_path = os.path.join(tmpdir, "reconstruction.mp4")
scales_stitch_path = os.path.join(tmpdir, "scales_stitch.mp4")
# Use output_fps if specified, otherwise use original video fps
fps_to_use = output_fps if output_fps is not None else results["fps"]
save_video(results["original_frames"], original_path, fps_to_use)
save_video(results["gazing_frames"], gazing_path, fps_to_use)
save_video(results["reconstruction_frames"], recon_path, fps_to_use)
save_video(results["scales_stitch_frames"], scales_stitch_path, fps_to_use)
with open(original_path, "rb") as f:
original_data = f.read()
with open(gazing_path, "rb") as f:
gazing_data = f.read()
with open(recon_path, "rb") as f:
recon_data = f.read()
with open(scales_stitch_path, "rb") as f:
scales_stitch_data = f.read()
original_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
original_file.write(original_data)
original_file.close()
gazing_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
gazing_file.write(gazing_data)
gazing_file.close()
recon_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
recon_file.write(recon_data)
recon_file.close()
scales_stitch_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
scales_stitch_file.write(scales_stitch_data)
scales_stitch_file.close()
gazing_pct_text = f"{results['gazing_pct']:.2%}"
gazing_tokens_text = f"{results['total_gazing_tokens']:,}"
total_tokens_text = f"{results['total_possible_tokens']:,}"
yield (
gazing_pct_text,
gazing_tokens_text,
total_tokens_text,
original_file.name,
gazing_file.name,
recon_file.name,
scales_stitch_file.name,
"Processing complete!",
)
# On Spaces, wrap the UI handler in a ZeroGPU allocation (max 120 s per call);
# otherwise expose the implementation directly.
process_video_ui = (
    spaces.GPU(duration=120)(_process_video_impl)
    if ZEROGPU_AVAILABLE
    else _process_video_impl
)
def _process_video_api_impl(
    file_path, gazing_ratio, task_loss_requirement, output_fps, progress=None
):
    """API-friendly endpoint that takes a file path string instead of gr.File.
    Returns the gazing video as a gr.File output for proper file serving."""
    if not file_path or not os.path.exists(file_path):
        raise gr.Error("file not found")
    meta = extract_metadata(file_path)
    if meta[1] is None:
        raise gr.Error("could not read file")
    # meta layout: (info_text, tmp_path, total_frames, fps, width, height)
    source_path = meta[1]
    # gr.update() with no args is a no-op for the output component; yielding it
    # keeps the streaming call alive while the slow steps run.
    yield gr.update()
    if progress:
        progress(0.0, desc="Loading model...")
    setup = get_model(device)
    yield gr.update()
    if progress:
        progress(0.1, desc="Processing video...")

    def report(pct, msg):
        # Forward model progress when a progress callback was supplied.
        if progress:
            progress(pct, desc=msg)

    # Rescale UI ratio (1/196 .. 265/196) into the model's gazing-ratio units.
    scaled_ratio = gazing_ratio * (196 / 265)
    for results in process_video(
        source_path,
        setup,
        gazing_ratio=scaled_ratio,
        task_loss_requirement=task_loss_requirement,
        progress_callback=report,
        spatial_batch_size=2,
    ):
        yield gr.update()
    playback_fps = results["fps"] if output_fps is None else output_fps
    out_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
    save_video(results["gazing_frames"], out_file.name, playback_fps)
    out_file.close()
    yield gr.File(value=out_file.name)
# Same ZeroGPU wrapping as the UI handler, applied to the API implementation.
process_video_api = (
    spaces.GPU(duration=120)(_process_video_api_impl)
    if ZEROGPU_AVAILABLE
    else _process_video_api_impl
)
def extract_first_frame_thumbnail(
    video_path, output_path, size=(200, 200), force=False
):
    """Extract first frame from video and save as thumbnail with fixed aspect ratio.

    Args:
        video_path: Source video readable by PyAV.
        output_path: Where to write the thumbnail image.
        size: (width, height) of the saved thumbnail.
        force: Regenerate even when *output_path* already exists.

    Note: if the video yields no decodable frames, no thumbnail is written.
    """
    if os.path.exists(output_path) and not force:
        return
    container = av.open(video_path)
    try:
        for frame in container.decode(video=0):
            img = frame.to_image()
            # Crop to center square first, then resize
            width, height = img.size
            min_dim = min(width, height)
            left = (width - min_dim) // 2
            top = (height - min_dim) // 2
            img_cropped = img.crop((left, top, left + min_dim, top + min_dim))
            img_resized = img_cropped.resize(size, Image.LANCZOS)
            img_resized.save(output_path)
            break  # only the first frame is needed
    finally:
        # BUG FIX: close the container even if decode/crop/save raises,
        # otherwise the file handle leaks.
        container.close()
# Generate thumbnails for example videos
example_videos = [
    "example_inputs/doorbell.mp4",
    "example_inputs/tomjerry.mp4",
    "example_inputs/security.mp4",
]
for video_path in example_videos:
    if os.path.exists(video_path):
        thumb_path = video_path.replace(".mp4", "_thumb.png")
        # Force regeneration with square aspect ratio at 100x100 to match gallery height
        extract_first_frame_thumbnail(
            video_path, thumb_path, size=(100, 100), force=True
        )
# Load thumbnails as numpy arrays
# NOTE(review): these loads assume all three example videos (and hence their
# thumbnails) exist; a missing file raises FileNotFoundError at startup.
doorbell_thumb_img = np.array(Image.open("example_inputs/doorbell_thumb.png"))
tomjerry_thumb_img = np.array(Image.open("example_inputs/tomjerry_thumb.png"))
security_thumb_img = np.array(Image.open("example_inputs/security_thumb.png"))
# Build the Gradio UI. delete_cache=(86400, 86400): purge served temp files
# older than a day, checked once a day.
with gr.Blocks(title="AutoGaze Demo", delete_cache=(86400, 86400)) as demo:
    # --- Header ---
    gr.Markdown("# AutoGaze Official Demo")
    gr.Markdown(
        "## **Attend Before Attention: Efficient and Scalable Video Understanding via Autoregressive Gazing**"
    )
    gr.Markdown("""
<div style="text-align: left; margin: 10px 0; font-size: 1.2em; font-weight: 600;">
📄 <a href="https://arxiv.org/abs/2603.12254" target="_blank" style="text-decoration: none; color: inherit;">Paper</a> &nbsp;&nbsp;&nbsp;&nbsp;&nbsp; 🌐 <a href="https://autogaze.github.io" target="_blank" style="text-decoration: none; color: inherit;">Project Website</a>
</div>
""")
    # Holds the metadata tuple returned by extract_metadata for the current upload.
    file_metadata = gr.State()
    # --- Upload row ---
    with gr.Row():
        with gr.Column(scale=2):
            uploaded_file = gr.File(
                label="Upload Video or Image", file_types=["video", "image"]
            )
        with gr.Column(scale=1):
            file_info = gr.Textbox(label="File Info", interactive=False)
            process_button = gr.Button("Process Video", variant="primary")

    def load_example_video(evt: gr.SelectData):
        # Map gallery index -> example video path; the returned path is routed
        # into uploaded_file by example_gallery.select below.
        video_map = {
            0: "example_inputs/doorbell.mp4",
            1: "example_inputs/tomjerry.mp4",
            2: "example_inputs/security.mp4",
        }
        return video_map[evt.index]

    # --- Examples / settings (left) and results (right) ---
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Example Videos - Click Thumbnail to Load")
            example_gallery = gr.Gallery(
                value=[
                    (doorbell_thumb_img, "doorbell.mp4"),
                    (tomjerry_thumb_img, "tomjerry.mp4"),
                    (security_thumb_img, "security.mp4"),
                ],
                label="",
                show_label=False,
                columns=3,
                rows=1,
                height=200,
                object_fit="contain",
                allow_preview=False,
            )
            gr.Markdown("### Settings")
            with gr.Accordion("Output Settings", open=True):
                fps_slider = gr.Number(
                    label="Output FPS",
                    value=None,
                    minimum=1,
                    maximum=120,
                    info="Frames per second for displaying output videos (only affects playback speed)",
                )
            with gr.Accordion("Model Parameters", open=True):
                # Slider bounds mirror the UI-scale ratio convention
                # (patches per frame / 196; can exceed 1.0 with multi-scale gazing).
                gazing_ratio_slider = gr.Slider(
                    label="Gazing Ratio",
                    minimum=round(1 / 196, 2),
                    maximum=round(265 / 196, 2),
                    step=0.01,
                    value=0.75,
                    info="Max fraction of patches to gaze at per frame",
                )
                task_loss_slider = gr.Slider(
                    label="Task Loss Requirement",
                    minimum=0.0,
                    maximum=1.5,
                    step=0.05,
                    value=0.7,
                    info="Reconstruction loss threshold",
                )
            with gr.Accordion("FAQ", open=False):
                gr.Markdown("""
**What file formats are supported?**
The app supports common video formats (MP4, AVI, MOV, etc.) and image formats (JPG, PNG, etc.).
**What is the Gazing Ratio?**
The gazing ratio explicitly controls how many patches the model looks at per frame. Higher values mean more patches are selected. The range extends to past 1.0 because of multi-scale gazing; if all patches at all scales are selected, the ratio can reach up to 1.35.
**What is Task Loss Requirement?**
This threshold determines when the model stops gazing at a frame, based on the predicted reconstruction loss from the current gazed patches. Lower = more gazing, higher = less gazing.
**How do Gazing Ratio and Task Loss interact?**
These two parameters separately control the number of gazed patches in an image/video. This demo will take the stricter of the two requirements when determining how many patches to gaze at. For example, if the gazing ratio suggests gazing at 15% of patches, but the task loss requirement is met after only 7% patches, then only 7% patches will be gazed at. To only use one of the two parameters, set the other to its maximum value.
""")
        with gr.Column(scale=2):
            gr.Markdown("### Results")
            status_text = gr.Markdown("Ready")
            with gr.Row():
                gazing_pct = gr.Textbox(label="Gazing %", interactive=False)
                gazing_tokens = gr.Textbox(label="# Gazed Patches", interactive=False)
                total_tokens = gr.Textbox(label="Total Patches", interactive=False)
            with gr.Row():
                original_video = gr.Video(label="Original", autoplay=False, loop=True)
                gazing_video = gr.Video(
                    label="Gazing Pattern (all scales)", autoplay=False, loop=True
                )
                reconstruction_video = gr.Video(
                    label="Reconstruction", autoplay=False, loop=True
                )
            with gr.Row():
                scales_stitch_video = gr.Video(
                    label="Gazing Pattern (individual scales)",
                    autoplay=False,
                    loop=True,
                )
    # --- Event wiring ---
    example_gallery.select(load_example_video, outputs=uploaded_file)
    uploaded_file.change(
        fn=handle_file_upload,
        inputs=[uploaded_file],
        outputs=[file_info, file_metadata, fps_slider],
    )
    # Streamed outputs from process_video_ui; GPU memory is released afterwards.
    process_button.click(
        fn=process_video_ui,
        inputs=[file_metadata, gazing_ratio_slider, task_loss_slider, fps_slider],
        outputs=[
            gazing_pct,
            gazing_tokens,
            total_tokens,
            original_video,
            gazing_video,
            reconstruction_video,
            scales_stitch_video,
            status_text,
        ],
    ).then(fn=cleanup_gpu, inputs=None, outputs=None)
    # --- API-friendly endpoint (hidden tab, bypasses FileData validation) ---
    with gr.Tab("API", visible=False):
        api_file_path = gr.Textbox(label="File Path")
        api_gazing_ratio = gr.Slider(
            minimum=round(1 / 196, 2),
            maximum=round(265 / 196, 2),
            step=0.01,
            value=0.75,
            label="Gazing Ratio",
        )
        api_task_loss = gr.Slider(
            minimum=0.0,
            maximum=1.5,
            step=0.05,
            value=0.7,
            label="Task Loss Requirement",
        )
        api_output_fps = gr.Number(label="Output FPS", value=None)
        api_button = gr.Button("Process (API)")
        api_result = gr.File(label="Result")
        # api_name makes this callable as /process_video_api via the Gradio client.
        api_button.click(
            fn=process_video_api,
            inputs=[api_file_path, api_gazing_ratio, api_task_loss, api_output_fps],
            outputs=[api_result],
            api_name="process_video_api",
        )
# Clean up GPU memory when user disconnects
demo.unload(cleanup_gpu)
# Clear any cached models and free GPU memory at app startup
print("Clearing model cache and GPU memory at startup...")
model_cache.clear()
cleanup_gpu()
print("Startup cleanup complete.")
if __name__ == "__main__":
    # share=True requests a public Gradio link; show_error surfaces handler
    # tracebacks in the UI instead of a generic error banner.
    demo.launch(share=True, show_error=True)