# HF Spaces badge text captured by the scrape ("Spaces: Running on Zero") —
# this app targets the ZeroGPU hardware tier; kept as a comment so the file parses.
| #!/usr/bin/env python3 | |
| """Gradio demo for Generalist-IDM model.""" | |
| import http.server | |
| import socketserver | |
| import subprocess | |
| import tempfile | |
| import threading | |
| import uuid | |
| from functools import partial | |
| from pathlib import Path | |
| import spaces | |
| import gradio as gr | |
| from loguru import logger | |
| from inference import InferenceConfig, InferencePipeline | |
# Constants
# Hugging Face Hub model ID of the inverse-dynamics model loaded at inference time.
MODEL_ID = "open-world-agents/Generalist-IDM-1B"
# HF Space embed URL (use hf.space domain for iframe embedding)
VISUALIZER_SPACE_URL = "https://open-world-agents-visualize-dataset.hf.space"
# Per-session outputs live under the OS temp dir; each request gets a UUID subdir.
OUTPUT_DIR = Path(tempfile.gettempdir()) / "idm_demo_outputs"
OUTPUT_DIR.mkdir(exist_ok=True)
# Port for the fallback CORS file server (local development; HF Spaces uses
# Gradio's own /gradio_api/file= serving instead).
FILE_SERVER_PORT = 8765
class CORSRequestHandler(http.server.SimpleHTTPRequestHandler):
    """HTTP handler with CORS support for cross-origin requests.

    Serves files from the ``directory`` keyword that callers pass at
    construction time. ``SimpleHTTPRequestHandler`` already accepts and
    stores ``directory`` itself (Python 3.7+), so the previous ``__init__``
    override — which assigned ``self.directory`` and then let ``super()``
    overwrite it — was redundant and has been removed.
    """

    def end_headers(self):
        # Inject permissive CORS headers on every response, right before the
        # header block is flushed, so the visualizer Space can fetch files.
        self.send_header("Access-Control-Allow-Origin", "*")
        self.send_header("Access-Control-Allow-Methods", "GET, OPTIONS")
        self.send_header("Access-Control-Allow-Headers", "*")
        super().end_headers()

    def do_OPTIONS(self):
        # Minimal CORS preflight response: 200 plus the headers added above.
        self.send_response(200)
        self.end_headers()
def start_file_server(directory: Path, port: int) -> None:
    """Serve *directory* over HTTP with CORS headers; blocks forever.

    Intended to run in a daemon thread as a fallback file host for local
    development.

    Args:
        directory: Root directory whose files are exposed.
        port: TCP port to bind on all interfaces.
    """

    class _ReusableTCPServer(socketserver.TCPServer):
        # BUG FIX: TCPServer binds the socket inside __init__, so the
        # original instance-level `httpd.allow_reuse_address = True` (set
        # after construction) had no effect. It must be a class attribute
        # so it is consulted before bind().
        allow_reuse_address = True

    handler = partial(CORSRequestHandler, directory=str(directory))
    with _ReusableTCPServer(("0.0.0.0", port), handler) as httpd:
        logger.info(f"File server running at http://0.0.0.0:{port}")
        httpd.serve_forever()
def cut_video_to_duration(
    input_path: str,
    output_path: str,
    duration: float = 30.0,
    *,
    width: int = 448,
    height: int = 448,
    fps: int = 60,
) -> str:
    """Trim a video to *duration* seconds, resample to *fps*, and resize.

    Re-encodes with libx264 using a fixed keyframe interval and no
    B-frames/scene-cut keyframes, which keeps the output cheap to seek
    for frame-accurate playback. Audio is stripped.

    Args:
        input_path: Source video file.
        output_path: Destination file (overwritten if present).
        duration: Seconds to keep from the start of the clip.
        width: Output width in pixels (default matches the model's 448 input).
        height: Output height in pixels.
        fps: Output frame rate.

    Returns:
        ``output_path``.

    Raises:
        RuntimeError: If ffmpeg exits with a non-zero status.
    """
    cmd = [
        "ffmpeg",
        "-y",  # Overwrite output file without prompting
        "-i",
        input_path,
        "-t",
        str(duration),  # keep only the first `duration` seconds
        "-vsync",
        "1",  # duplicate/drop frames to honor the target frame rate
        "-filter:v",
        f"fps={fps},scale={width}:{height}",
        "-c:v",
        "libx264",
        "-x264-params",
        "keyint=30:no-scenecut=1:bframes=0",
        "-an",  # No audio
        output_path,
    ]
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        raise RuntimeError(f"ffmpeg error: {result.stderr}")
    return output_path
def get_video_duration(video_path: str) -> float:
    """Return the duration of *video_path* in seconds, as reported by ffprobe.

    Raises:
        RuntimeError: If ffprobe exits with a non-zero status.
    """
    probe_cmd = [
        "ffprobe",
        "-v",
        "error",
        "-show_entries",
        "format=duration",
        "-of",
        "default=noprint_wrappers=1:nokey=1",
        video_path,
    ]
    proc = subprocess.run(probe_cmd, capture_output=True, text=True)
    if proc.returncode:
        raise RuntimeError(f"ffprobe error: {proc.stderr}")
    return float(proc.stdout.strip())
def create_mcap_from_video(video_path: str, mcap_path: str, fps: float = 20.0):
    """Write an MCAP file of synthetic screen events referencing *video_path*.

    One ``ScreenCaptured`` message per frame is emitted on the ``screen``
    topic at *fps*, each carrying a media reference to the video file by
    name together with the matching presentation timestamp.
    """
    from mcap_owa.highlevel import OWAMcapWriter
    from owa.core import MESSAGES

    ScreenCaptured = MESSAGES["desktop/ScreenCaptured"]
    total_seconds = get_video_duration(video_path)
    media_name = Path(video_path).name
    frame_interval_ns = int(1e9 / fps)
    frame_count = int(total_seconds * fps)

    with OWAMcapWriter(mcap_path) as writer:
        for frame_idx in range(frame_count):
            ts = frame_idx * frame_interval_ns
            writer.write_message(
                ScreenCaptured(utc_ns=ts, media_ref={"uri": media_name, "pts_ns": ts}),
                topic="screen",
                timestamp=ts,
            )
    logger.info(f"Created MCAP with {frame_count} frames at {fps} FPS")
def get_gpu_duration(mcap_path: str, output_mcap_path: str, duration: float):
    """Estimate the GPU allocation time (seconds) for a clip of *duration*.

    The path arguments are unused; the signature presumably mirrors
    ``run_idm_inference`` so this can serve as a dynamic ``duration=``
    callback — confirm against the Spaces GPU decorator usage.
    """
    # Roughly 10x realtime processing plus a fixed startup overhead.
    return 10 * duration + 10
# NOTE(review): `spaces` is imported at the top of the file but was never
# applied; `get_gpu_duration`'s signature exactly mirrors this function's,
# which strongly suggests it was meant as the dynamic ZeroGPU duration
# callback. Without the decorator, a ZeroGPU Space gets no GPU for this
# call ("cuda" would be unavailable) — confirm on the deployed Space.
@spaces.GPU(duration=get_gpu_duration)
def run_idm_inference(mcap_path: str, output_mcap_path: str, duration: float):
    """Run IDM inference on an MCAP file, writing predicted actions out.

    Args:
        mcap_path: Input MCAP containing `screen` events.
        output_mcap_path: Destination MCAP for pseudo-labeled actions.
        duration: Clip length in seconds; unused in the body, consumed by
            the GPU-duration callback above.
    """
    config = InferenceConfig(
        model_path=MODEL_ID,
        device="cuda",
        max_context_length=2048,
        time_shift_seconds_for_action=0.1,
    )
    pipeline = InferencePipeline(config)
    pipeline.pseudo_label_action(mcap_path, output_mcap_path)
    logger.success(f"Generated predictions: {output_mcap_path}")
def process_video(video_file, duration, progress=gr.Progress()):
    """Run the full demo pipeline on an uploaded video.

    Steps: trim/resize the clip, synthesize an input MCAP from its frames,
    run IDM inference, then build download links plus an embedded
    visualizer iframe.

    Args:
        video_file: Path of the uploaded video (or None if nothing uploaded).
        duration: Seconds of video to process.
        progress: Gradio progress reporter (injected by Gradio).

    Returns:
        Tuple of (output file list or None, visualizer URL, HTML snippet).
    """
    import os

    if video_file is None:
        return None, "", "<p>Please upload a video file.</p>"

    # Per-request UUID directory keeps sessions isolated and URLs unguessable.
    session_id = str(uuid.uuid4())[:8]
    session_dir = OUTPUT_DIR / session_id
    # parents=True guards against OUTPUT_DIR having been reaped by tmp cleanup.
    session_dir.mkdir(parents=True, exist_ok=True)

    try:
        progress(0.1, desc=f"Cutting video to {duration} seconds...")
        # Cut video to specified duration
        input_video = video_file
        output_video = str(session_dir / "input.mkv")
        cut_video_to_duration(input_video, output_video, duration)

        progress(0.2, desc="Creating MCAP from video...")
        # Create input MCAP
        input_mcap = str(session_dir / "input.mcap")
        create_mcap_from_video(output_video, input_mcap)

        progress(0.3, desc="Running IDM inference...")
        # Run inference
        output_mcap = str(session_dir / "output.mcap")
        run_idm_inference(input_mcap, output_mcap, duration)

        progress(0.9, desc="Preparing output...")
        # Return output files
        output_files = [output_video, output_mcap]

        # Build file URLs with UUID-based paths for privacy, served by
        # Gradio at /gradio_api/file={absolute_path} (works on HF Spaces).
        # Prefer the SPACE_HOST env var set by the Spaces runtime so the
        # links survive Space renames/duplication; fall back to the
        # previously hard-coded host.
        server_host = os.environ.get("SPACE_HOST", "lastdefiance20-generalist-idm.hf.space")
        mcap_url = f"https://{server_host}/gradio_api/file={output_mcap}"
        mkv_url = f"https://{server_host}/gradio_api/file={output_video}"

        # Create visualization iframe URL with direct file paths
        viz_url = f"{VISUALIZER_SPACE_URL}?mcap={mcap_url}&mkv={mkv_url}"

        # Create iframe HTML for embedded visualizer (following HF Spaces embed docs)
        iframe_html = f'''
        <div style="margin-top: 10px;">
            <p><strong>๐ Visualization:</strong> <a href="{viz_url}" target="_blank">Open in new tab โ</a></p>
            <iframe
                src="{viz_url}"
                frameborder="0"
                width="100%"
                height="1000"
                style="border: 1px solid #ddd; border-radius: 8px; background: #fff;"
                allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture"
                allowfullscreen
            ></iframe>
        </div>
        '''
        progress(1.0, desc="Done!")
        return output_files, viz_url, iframe_html
    except Exception as e:
        logger.error(f"Error processing video: {e}")
        import traceback

        error_html = f"<p style='color:red;'>Error: {str(e)}</p><pre>{traceback.format_exc()}</pre>"
        return None, "", error_html
# Custom CSS for wide layout: widen the app container and pin the heights of
# the visualizer iframe and upload widget (matched by elem_id below).
css = ".gradio-container {max-width: 2000px !important; margin: 0 auto} #output_visualizer {height: 1000px;} #input_video {height: 400px;}"

# Create Gradio interface (module-level so main() can launch `demo`).
with gr.Blocks(title="Generalist-IDM Demo", css=css, theme=gr.themes.Soft()) as demo:
    gr.Markdown("# ๐ฎ Generalist-IDM-1B Demo")
    gr.Markdown("Upload a gameplay video and the model will predict keyboard and mouse actions. [[model]](https://huggingface.co/open-world-agents/Generalist-IDM-1B)")
    with gr.Row():
        with gr.Column(scale=2):
            # Upload widget; height fixed via the #input_video CSS rule above.
            video_input = gr.Video(label="Upload Gameplay Video", format="mp4", elem_id="input_video")
        with gr.Column(scale=2):
            duration_slider = gr.Slider(
                label="Process Duration (seconds)",
                minimum=1,
                maximum=10,
                value=10,
                step=1,
                info="Length of video to process (max 10s)",
            )
            process_btn = gr.Button("๐ Process Video", variant="primary")
    with gr.Accordion("Output Files & Links", open=False):
        # Download links for the trimmed video and prediction MCAP.
        output_files = gr.Files(label="Output Files")
        viz_link = gr.Textbox(label="Visualizer URL", interactive=False)
    with gr.Accordion("How it works", open=False):
        gr.Markdown("""
        1. Video is trimmed to specified duration (max 10 seconds)
        2. IDM model predicts keyboard/mouse actions for each frame
        3. Predictions are saved to output MCAP file
        4. View results in the embedded OWA Dataset Visualizer below
        """)
    # Examples with Brotato and CSGO2
    gr.Markdown("### ๐ฌ Example Videos")
    gr.Examples(
        examples=[
            ["Brotato.mkv", 8],
            ["CSGO2.mkv", 8],
            ["CSGO2_2.mkv", 8],
        ],
        inputs=[video_input, duration_slider],
    )
    # Visualizer at the bottom, full width
    gr.Markdown("### ๐ Prediction Visualizer")
    visualizer_frame = gr.HTML(
        value="<p style='color: #666; padding: 20px;'>Upload a video and click 'Process Video' to see predictions here.</p>",
        elem_id="output_visualizer"
    )
    # Wire the button to the pipeline; outputs map to (Files, Textbox, HTML).
    process_btn.click(
        fn=process_video,
        inputs=[video_input, duration_slider],
        outputs=[output_files, viz_link, visualizer_frame],
    )
def main():
    """Start the fallback file server and launch the Gradio app."""
    # Let Gradio serve files under OUTPUT_DIR via /gradio_api/file=...
    gr.set_static_paths(paths=[str(OUTPUT_DIR)])

    # Background CORS file server — fallback for local development only.
    threading.Thread(
        target=start_file_server,
        args=(OUTPUT_DIR, FILE_SERVER_PORT),
        daemon=True,
    ).start()
    logger.info(f"Started file server on port {FILE_SERVER_PORT} serving {OUTPUT_DIR}")

    # Launch Gradio; allowed_paths whitelists OUTPUT_DIR for file serving.
    demo.queue().launch(server_name="0.0.0.0", server_port=7860, allowed_paths=[str(OUTPUT_DIR)])


if __name__ == "__main__":
    main()