# Generalist-IDM / app.py
# chore: modify duration 10->8 (commit 1df2909, lastdefiance20)
#!/usr/bin/env python3
"""Gradio demo for Generalist-IDM model."""
import http.server
import socketserver
import subprocess
import tempfile
import threading
import uuid
from functools import partial
from pathlib import Path
import spaces
import gradio as gr
from loguru import logger
from inference import InferenceConfig, InferencePipeline
# Constants
# Hugging Face Hub repository of the IDM checkpoint used for inference.
MODEL_ID = "open-world-agents/Generalist-IDM-1B"
# HF Space embed URL (use hf.space domain for iframe embedding)
VISUALIZER_SPACE_URL = "https://open-world-agents-visualize-dataset.hf.space"
# Per-session artifacts (trimmed video + MCAP files) live under the system
# temp directory; created eagerly at import time.
OUTPUT_DIR = Path(tempfile.gettempdir()) / "idm_demo_outputs"
OUTPUT_DIR.mkdir(exist_ok=True)
# Port for the fallback CORS file server (local development only).
FILE_SERVER_PORT = 8765
class CORSRequestHandler(http.server.SimpleHTTPRequestHandler):
    """HTTP handler with CORS support for cross-origin requests.

    ``SimpleHTTPRequestHandler`` already accepts a ``directory`` keyword
    (Python 3.7+) and stores it on the instance itself, so the previous
    ``__init__`` override — which duplicated ``self.directory = directory``
    before delegating — was redundant and has been removed. The ``partial``
    in ``start_file_server`` still supplies ``directory=``.
    """

    def end_headers(self):
        # Attach permissive CORS headers to every response so the external
        # visualizer Space can fetch generated files cross-origin.
        self.send_header("Access-Control-Allow-Origin", "*")
        self.send_header("Access-Control-Allow-Methods", "GET, OPTIONS")
        self.send_header("Access-Control-Allow-Headers", "*")
        super().end_headers()

    def do_OPTIONS(self):
        # Minimal CORS preflight handler: empty 200 with the headers above.
        self.send_response(200)
        self.end_headers()
def start_file_server(directory: Path, port: int) -> None:
    """Serve *directory* over HTTP on *port*, blocking forever.

    Intended to run in a daemon thread (see ``main``); CORS headers are
    added by ``CORSRequestHandler`` so the visualizer Space can fetch files.
    """
    handler = partial(CORSRequestHandler, directory=str(directory))

    # BUG FIX: ``allow_reuse_address`` must be set on the *class* before the
    # socket is bound. The previous code assigned it on the instance after
    # ``TCPServer(...)`` returned, by which point ``server_bind`` had already
    # run — the setting had no effect.
    class _ReusableTCPServer(socketserver.TCPServer):
        allow_reuse_address = True

    with _ReusableTCPServer(("0.0.0.0", port), handler) as httpd:
        logger.info(f"File server running at http://0.0.0.0:{port}")
        httpd.serve_forever()
def cut_video_to_duration(input_path: str, output_path: str, duration: float = 30.0) -> str:
    """Cut video to specified duration, resize to 448x448, and set keyframes using ffmpeg.

    Re-encodes with libx264 at 60 fps, a fixed 30-frame keyframe interval,
    no B-frames and no audio, as expected by the downstream MCAP/visualizer.
    Returns ``output_path``; raises ``RuntimeError`` if ffmpeg fails.
    """
    target_w, target_h = "448", "448"
    ffmpeg_cmd = [
        "ffmpeg",
        "-y",  # overwrite any existing output without prompting
        "-i", input_path,
        "-t", str(duration),  # keep only the leading `duration` seconds
        "-vsync", "1",
        "-filter:v", f"fps=60,scale={target_w}:{target_h}",
        "-c:v", "libx264",
        "-x264-params", "keyint=30:no-scenecut=1:bframes=0",
        "-an",  # strip audio
        output_path,
    ]
    proc = subprocess.run(ffmpeg_cmd, capture_output=True, text=True)
    if proc.returncode:
        raise RuntimeError(f"ffmpeg error: {proc.stderr}")
    return output_path
def get_video_duration(video_path: str) -> float:
    """Get video duration using ffprobe.

    Queries only the container-level duration and parses it as a float.
    Raises ``RuntimeError`` when ffprobe exits non-zero.
    """
    probe_cmd = [
        "ffprobe",
        "-v", "error",
        "-show_entries", "format=duration",
        "-of", "default=noprint_wrappers=1:nokey=1",
        video_path,
    ]
    proc = subprocess.run(probe_cmd, capture_output=True, text=True)
    if proc.returncode:
        raise RuntimeError(f"ffprobe error: {proc.stderr}")
    return float(proc.stdout.strip())
def create_mcap_from_video(video_path: str, mcap_path: str, fps: float = 20.0):
    """Create MCAP file with screen events from video.

    Writes one ``ScreenCaptured`` message per synthetic frame at a fixed
    `fps`, each referencing the video file by name with a presentation
    timestamp in nanoseconds.
    """
    # Imported lazily so the module can be loaded without the OWA stack.
    from mcap_owa.highlevel import OWAMcapWriter
    from owa.core import MESSAGES

    ScreenCaptured = MESSAGES["desktop/ScreenCaptured"]
    video_filename = Path(video_path).name
    frame_interval_ns = int(1e9 / fps)
    total_frames = int(get_video_duration(video_path) * fps)

    with OWAMcapWriter(mcap_path) as writer:
        for frame_idx in range(total_frames):
            ts_ns = frame_idx * frame_interval_ns
            writer.write_message(
                ScreenCaptured(utc_ns=ts_ns, media_ref={"uri": video_filename, "pts_ns": ts_ns}),
                topic="screen",
                timestamp=ts_ns,
            )
    logger.info(f"Created MCAP with {total_frames} frames at {fps} FPS")
def get_gpu_duration(mcap_path: str, output_mcap_path: str, duration: float):
    """Calculate dynamic GPU duration based on video duration.

    Budget is 10x the clip length plus a flat 10-second startup margin.
    The path arguments are unused here; the signature mirrors
    ``run_idm_inference`` because ``spaces.GPU`` passes the wrapped
    function's arguments to this callback.
    """
    return 10 + 10 * duration
@spaces.GPU(duration=get_gpu_duration)
def run_idm_inference(mcap_path: str, output_mcap_path: str, duration: float):
    """Run IDM inference on MCAP file.

    `duration` is not used in the body; it is consumed by the
    ``spaces.GPU`` duration callback (``get_gpu_duration``) above.
    """
    inference_cfg = InferenceConfig(
        model_path=MODEL_ID,
        device="cuda",
        max_context_length=2048,
        time_shift_seconds_for_action=0.1,
    )
    InferencePipeline(inference_cfg).pseudo_label_action(mcap_path, output_mcap_path)
    logger.success(f"Generated predictions: {output_mcap_path}")
def _build_iframe_html(viz_url: str) -> str:
    """Return embed HTML (direct link + iframe) for the OWA dataset visualizer."""
    return f'''
    <div style="margin-top: 10px;">
        <p><strong>๐Ÿ“Š Visualization:</strong> <a href="{viz_url}" target="_blank">Open in new tab โ†—</a></p>
        <iframe
            src="{viz_url}"
            frameborder="0"
            width="100%"
            height="1000"
            style="border: 1px solid #ddd; border-radius: 8px; background: #fff;"
            allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture"
            allowfullscreen
        ></iframe>
    </div>
    '''


def process_video(video_file, duration, progress=gr.Progress()):
    """Main processing function for Gradio.

    Pipeline: trim/rescale the upload -> synthesize an input MCAP of screen
    events -> run IDM inference on GPU -> return output files plus an
    embedded visualizer iframe.

    Returns a 3-tuple matching the Gradio outputs:
    (list of output file paths or None, visualizer URL, HTML fragment).
    Never raises: failures are rendered into the HTML output instead.
    """
    if video_file is None:
        return None, "", "<p>Please upload a video file.</p>"

    # Unique per-session directory so concurrent users never collide and
    # output URLs are not trivially guessable.
    session_id = str(uuid.uuid4())[:8]
    session_dir = OUTPUT_DIR / session_id
    # parents=True guards against OUTPUT_DIR having been reaped from the
    # temp directory after import time (the old call would then fail).
    session_dir.mkdir(parents=True, exist_ok=True)

    try:
        progress(0.1, desc=f"Cutting video to {duration} seconds...")
        # Cut video to specified duration (also rescales / re-keyframes).
        output_video = str(session_dir / "input.mkv")
        cut_video_to_duration(video_file, output_video, duration)

        progress(0.2, desc="Creating MCAP from video...")
        input_mcap = str(session_dir / "input.mcap")
        create_mcap_from_video(output_video, input_mcap)

        progress(0.3, desc="Running IDM inference...")
        output_mcap = str(session_dir / "output.mcap")
        run_idm_inference(input_mcap, output_mcap, duration)

        progress(0.9, desc="Preparing output...")
        output_files = [output_video, output_mcap]

        # Gradio serves files at /gradio_api/file={absolute_path}, which
        # works on HF Spaces. NOTE(review): the Space host is hard-coded —
        # confirm it matches the deployed Space name if the repo is cloned.
        server_host = "lastdefiance20-generalist-idm.hf.space"
        mcap_url = f"https://{server_host}/gradio_api/file={output_mcap}"
        mkv_url = f"https://{server_host}/gradio_api/file={output_video}"
        # Visualizer Space reads the artifacts directly via query params.
        viz_url = f"{VISUALIZER_SPACE_URL}?mcap={mcap_url}&mkv={mkv_url}"

        iframe_html = _build_iframe_html(viz_url)
        progress(1.0, desc="Done!")
        return output_files, viz_url, iframe_html
    except Exception as e:
        logger.error(f"Error processing video: {e}")
        import traceback

        error_html = f"<p style='color:red;'>Error: {str(e)}</p><pre>{traceback.format_exc()}</pre>"
        return None, "", error_html
# Custom CSS for wide layout: widen the container, fix visualizer/input heights.
css = ".gradio-container {max-width: 2000px !important; margin: 0 auto} #output_visualizer {height: 1000px;} #input_video {height: 400px;}"

# Create Gradio interface (built at import time so HF Spaces can find `demo`).
with gr.Blocks(title="Generalist-IDM Demo", css=css, theme=gr.themes.Soft()) as demo:
    gr.Markdown("# ๐ŸŽฎ Generalist-IDM-1B Demo")
    gr.Markdown("Upload a gameplay video and the model will predict keyboard and mouse actions. [[model]](https://huggingface.co/open-world-agents/Generalist-IDM-1B)")

    with gr.Row():
        with gr.Column(scale=2):
            # Uploaded clip; re-encoded to mp4 by Gradio before processing.
            video_input = gr.Video(label="Upload Gameplay Video", format="mp4", elem_id="input_video")
        with gr.Column(scale=2):
            # Clip length fed to ffmpeg trimming and the GPU-duration budget.
            duration_slider = gr.Slider(
                label="Process Duration (seconds)",
                minimum=1,
                maximum=10,
                value=10,
                step=1,
                info="Length of video to process (max 10s)",
            )
            process_btn = gr.Button("๐Ÿš€ Process Video", variant="primary")
            with gr.Accordion("Output Files & Links", open=False):
                output_files = gr.Files(label="Output Files")
                viz_link = gr.Textbox(label="Visualizer URL", interactive=False)
            with gr.Accordion("How it works", open=False):
                gr.Markdown("""
                1. Video is trimmed to specified duration (max 10 seconds)
                2. IDM model predicts keyboard/mouse actions for each frame
                3. Predictions are saved to output MCAP file
                4. View results in the embedded OWA Dataset Visualizer below
                """)

    # Examples with Brotato and CSGO2 (bundled clips; 8-second duration).
    gr.Markdown("### ๐ŸŽฌ Example Videos")
    gr.Examples(
        examples=[
            ["Brotato.mkv", 8],
            ["CSGO2.mkv", 8],
            ["CSGO2_2.mkv", 8],
        ],
        inputs=[video_input, duration_slider],
    )

    # Visualizer at the bottom, full width; replaced by process_video's HTML.
    gr.Markdown("### ๐Ÿ“Š Prediction Visualizer")
    visualizer_frame = gr.HTML(
        value="<p style='color: #666; padding: 20px;'>Upload a video and click 'Process Video' to see predictions here.</p>",
        elem_id="output_visualizer"
    )

    # Wire the button to the processing pipeline.
    process_btn.click(
        fn=process_video,
        inputs=[video_input, duration_slider],
        outputs=[output_files, viz_link, visualizer_frame],
    )
def main():
    """Launch the demo: register static paths, start the fallback file server,
    then run the Gradio app on port 7860."""
    # Let Gradio serve artifacts under OUTPUT_DIR via /gradio_api/file=...
    gr.set_static_paths(paths=[str(OUTPUT_DIR)])

    # Fallback CORS file server for local development; daemon thread so it
    # terminates with the main process.
    threading.Thread(
        target=start_file_server,
        args=(OUTPUT_DIR, FILE_SERVER_PORT),
        daemon=True,
    ).start()
    logger.info(f"Started file server on port {FILE_SERVER_PORT} serving {OUTPUT_DIR}")

    # Launch Gradio with queuing enabled (required for GPU-scheduled jobs).
    demo.queue().launch(
        server_name="0.0.0.0", server_port=7860, allowed_paths=[str(OUTPUT_DIR)]
    )


if __name__ == "__main__":
    main()