Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| WorldEngine Real-Time World Model Demo - ZeroGPU Edition | |
| A Gradio demo optimized for HuggingFace ZeroGPU with: | |
| - Generator-based GPU session that stays alive | |
| - Persistent compilation cache for faster cold starts | |
| - Command queue for real-time control | |
| """ | |
# Check for ZeroGPU environment - must be before other imports.
# IS_ZERO_GPU records whether the `spaces` package could be imported.
try:
    import spaces
except ImportError:
    IS_ZERO_GPU = False
    print("Running in standard GPU mode")
else:
    IS_ZERO_GPU = True
    print("ZeroGPU environment detected")
| import base64 | |
| import os | |
| import queue | |
| import random | |
| import threading | |
| import time | |
| from collections import deque | |
| from dataclasses import dataclass, field | |
| from io import BytesIO | |
| from multiprocessing import Queue | |
| from pathlib import Path | |
| from typing import Optional, Set, Tuple | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from diffusers import ModularPipeline | |
| from diffusers.utils import load_image | |
| from aoti import aoti_load_ | |
| # --- ZeroGPU Compilation Cache Setup --- | |
| CACHE_DIR = Path.home() / ".cache" / "world_engine_compile" | |
| CACHE_DIR.mkdir(parents=True, exist_ok=True) | |
| os.environ["TORCHINDUCTOR_CACHE_DIR"] = str(CACHE_DIR) | |
| os.environ["TORCH_COMPILE_CACHE_DIR"] = str(CACHE_DIR) | |
| os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1" | |
| os.environ["TORCHDYNAMO_VERBOSE"] = "1" | |
| print(f"Compilation cache directory: {CACHE_DIR}") | |
| torch._dynamo.config.recompile_limit = 64 | |
| torch.set_float32_matmul_precision("medium") | |
| torch._dynamo.config.capture_scalar_outputs = True | |
# --- Configuration ---
# Model repo id (or local path) for the modular pipeline; overridable through
# the MODEL_PATH environment variable.
MODEL_ID = os.environ.get("MODEL_PATH", "diffusers-internal-dev/world-engine-modular")
# Everything below runs at import time: build the pipeline, load its
# components in bfloat16, and move them to CUDA. The transformer and VAE are
# loaded from the "aot-compatible" revision; the text encoder and tokenizer use
# the default revision.
pipe = ModularPipeline.from_pretrained(MODEL_ID, trust_remote_code=True, revision="aot-compatible")
pipe.load_components(["transformer", "vae"], trust_remote_code=True, revision="aot-compatible", torch_dtype=torch.bfloat16)
pipe.load_components(["text_encoder", "tokenizer"], trust_remote_code=True, torch_dtype=torch.bfloat16)
pipe.to("cuda")
# Set up the transformer's KV cache up front using the pipeline's own
# `setup_kv_cache` sub-block (note: this calls a private `_setup_kv_cache`).
# NOTE(review): presumably this has to happen before the AOT artifact is loaded
# below so the compiled module finds the cache it expects - confirm.
pipe.blocks.sub_blocks['before_denoise'].sub_blocks['setup_kv_cache']._setup_kv_cache(pipe.transformer, pipe.device, torch.bfloat16)
# Load the ahead-of-time compiled transformer ("transformer.pt2" plus
# "transformer-constants.pt" from the diffusers/waypoint-1-small-aot repo) into
# pipe.transformer via the local `aoti` helper.
aoti_load_(
    pipe.transformer,
    "diffusers/waypoint-1-small-aot",
    "transformer.pt2",
    "transformer-constants.pt"
)
# NOTE: AOT loading of the VAE decoder is currently disabled, so pipe.vae.decoder
# is used as loaded above. Kept for reference.
#aoti_load_(
#    pipe.vae.decoder,
#    "diffusers/waypoint-1-small-aot",
#    "decoder.pt2",
#    "decoder-constants.pt"
#)
# Preset starter worlds: shown in the world-selection gallery and used as the
# pool for random seeds.
SEED_FRAME_URLS = [f"starter_{idx}.png" for idx in (21, 18, 22, 14, 9)]
def load_seed_frame(url: str, target_size: Tuple[int, int] = (360, 640)) -> Image.Image:
    """Fetch the image at *url* and scale it to *target_size*.

    Args:
        url: Location passed to ``diffusers.utils.load_image``.
        target_size: ``(height, width)`` of the returned frame.

    Returns:
        The resized PIL image (bilinear resampling).
    """
    # PIL's resize takes (width, height): the reverse of target_size.
    new_size = (target_size[1], target_size[0])
    return load_image(url).resize(new_size, Image.BILINEAR)
def image_to_base64(image: Image.Image) -> str:
    """Encode *image* as PNG and return it as a ``data:image/png;base64,...`` URL."""
    with BytesIO() as buf:
        image.save(buf, format="PNG")
        png_bytes = buf.getvalue()
    encoded = base64.b64encode(png_bytes).decode()
    return "data:image/png;base64," + encoded
def create_loading_image(
    width: int = 640,
    height: int = 360,
    text: str = "Loading...",
    subtext: Optional[str] = None,
    elapsed: Optional[float] = None,
) -> Image.Image:
    """Create a dark loading placeholder with centered status text.

    Args:
        width: Image width in pixels.
        height: Image height in pixels.
        text: Main message, centered horizontally and placed slightly above
            the vertical center.
        subtext: Optional secondary line drawn below the main text (skipped
            when ``None`` or empty).
        elapsed: Optional elapsed seconds, drawn as e.g. ``"12s"`` between
            the main text and the subtext.

    Returns:
        A new RGB PIL image.
    """
    from PIL import ImageDraw

    img = Image.new("RGB", (width, height), color=(20, 20, 30))
    draw = ImageDraw.Draw(img)

    def measure(s: str) -> Tuple[int, int]:
        # textbbox returns (left, top, right, bottom) for the default font.
        left, top, right, bottom = draw.textbbox((0, 0), s)
        return right - left, bottom - top

    def centered_x(s: str) -> int:
        return (width - measure(s)[0]) // 2

    # Main text: centered, nudged 10px above the middle.
    text_width, text_height = measure(text)
    y = (height - text_height) // 2 - 10
    draw.text(((width - text_width) // 2, y), text, fill=(150, 150, 170))

    # Elapsed time directly below the main text.
    if elapsed is not None:
        time_text = f"{elapsed:.0f}s"
        draw.text((centered_x(time_text), y + text_height + 8), time_text, fill=(100, 100, 120))

    # Subtext below the elapsed-time slot (the offset is fixed whether or not
    # the elapsed time is drawn).
    if subtext:
        draw.text((centered_x(subtext), y + text_height + 28), subtext, fill=(80, 80, 100))

    return img
# --- Command Types ---
# Messages placed on the command queue consumed by the GPU game loop. They are
# dataclasses so they can be constructed with their fields as arguments.


@dataclass
class GenerateCommand:
    """Generate next frame with given controls."""

    # Currently held buttons (presumably the integer key codes sent by the
    # control widget - confirm against the caller).
    buttons: Set[int]
    # (x, y) look/mouse velocity.
    mouse: Tuple[float, float]
    # Text prompt to condition the next frame on.
    prompt: str


@dataclass
class ResetCommand:
    """Reset world with new seed image."""

    # Explicit seed image; takes precedence over seed_url when both are set.
    seed_image: Optional["Image.Image"] = None
    # Seed frame location; when both are unset a random preset is used.
    seed_url: Optional[str] = None
    prompt: str = "An explorable world"


@dataclass
class StopCommand:
    """Stop the GPU session."""
# --- Session State ---
@dataclass
class GameSession:
    """Per-user game session with background worker thread."""

    # Queue of GenerateCommand / ResetCommand / StopCommand objects read by the
    # GPU generator.
    command_queue: Queue
    # Thread-safe queue for output frames; gpu_worker_thread keeps only the
    # latest (frame, frame_count, fps) tuple in it.
    frame_queue: queue.Queue
    # Background thread that advances the generator (presumably running
    # gpu_worker_thread - confirm at the construction site).
    worker_thread: threading.Thread
    # Set to ask the worker thread to stop.
    stop_event: threading.Event
    # The GPU generator.
    generator: object = None
    # Timestamps of the last 30 generated frames, used for the FPS estimate.
    # default_factory gives each session its own deque.
    frame_times: deque = field(default_factory=lambda: deque(maxlen=30))
def _rolling_fps(frame_times):
    """Return frames-per-second across the timestamps in ``frame_times`` (0.0 if undefined)."""
    if len(frame_times) < 2:
        return 0.0
    elapsed = frame_times[-1] - frame_times[0]
    return (len(frame_times) - 1) / elapsed if elapsed > 0 else 0.0


def _publish_latest(frame_queue, item):
    """Drop any unread entries in ``frame_queue`` and enqueue ``item`` (keep only the latest)."""
    while True:
        try:
            frame_queue.get_nowait()
        except queue.Empty:
            break
    frame_queue.put_nowait(item)


def gpu_worker_thread(gen, command_queue, frame_queue, stop_event, frame_times):
    """
    Worker thread that consumes the GPU generator and pushes frames to frame_queue.

    Runs independently of Gradio's timer, allowing non-blocking frame reads.
    Each iteration advances ``gen`` (blocking on GPU work), records the
    timestamp in ``frame_times``, and replaces the contents of ``frame_queue``
    with ``(frame, frame_count, fps)``, where fps is rounded to one decimal.

    The loop ends when ``stop_event`` is set, the generator is exhausted, or
    it raises. On exit the generator is closed, so its cleanup runs now rather
    than whenever it is garbage-collected. Otherwise, a generator abandoned
    because ``stop_event`` was set would stay suspended and never see a queued
    StopCommand.

    Args:
        gen: Iterator yielding ``(frame, frame_count)`` tuples.
        command_queue: Not used here (commands are read by the generator);
            kept for signature compatibility.
        frame_queue: ``queue.Queue`` receiving ``(frame, frame_count, fps)``.
        stop_event: ``threading.Event`` checked before each frame.
        frame_times: Bounded ``deque`` of timestamps; mutated in place.
    """
    import traceback

    try:
        while not stop_event.is_set():
            try:
                # Blocks until the generator produces its next frame.
                frame, frame_count = next(gen)
                frame_times.append(time.time())
                fps = _rolling_fps(frame_times)
                _publish_latest(frame_queue, (frame, frame_count, round(fps, 1)))
            except StopIteration:
                print("Generator exhausted, worker thread ending")
                break
            except Exception as e:
                print(f"Worker thread error: {e}")
                traceback.print_exc()
                break
    finally:
        # Plain iterators have no close(); generators do.
        close = getattr(gen, "close", None)
        if close is not None:
            try:
                close()
            except Exception as e:
                print(f"Error closing GPU generator: {e}")
        print("Worker thread finished")
# --- GPU Session Generator ---
def _start_world(prompt, image):
    """Start a fresh world from ``image`` and return ``(state, frame)``.

    Runs the pipeline with no buttons pressed and zero mouse motion. Used for
    the initial frame, explicit resets, and auto-resets. ``frame`` is whatever
    ``state.values.get("images")`` holds for ``output_type="pil"``.
    """
    state = pipe(
        prompt=prompt,
        image=image,
        button=set(),
        mouse=(0.0, 0.0),
        output_type="pil",
    )
    return state, state.values.get("images")


def create_gpu_game_loop(command_queue: Queue, initial_seed_image=None, initial_seed_url=None, initial_prompt="An explorable world"):
    """Create GPU game loop generator with closure over command_queue.

    Args:
        command_queue: Queue of GenerateCommand / ResetCommand / StopCommand
            objects, drained (non-blocking) before every frame.
        initial_seed_image: Optional image for the first world (resized to
            640x360); takes precedence over ``initial_seed_url``.
        initial_seed_url: Optional seed frame location; when neither seed is
            given, a random entry of SEED_FRAME_URLS is used.
        initial_prompt: Prompt for the first world.

    Returns:
        A generator yielding ``(frame, frame_count)`` tuples. It stops when a
        StopCommand is received.
    """
    print(f"create_gpu_game_loop: initial_seed_image={type(initial_seed_image)}, initial_seed_url={initial_seed_url}")

    # NOTE(review): no `spaces.GPU` wrapper is applied here even though
    # IS_ZERO_GPU is computed at import; confirm GPU allocation is requested
    # elsewhere.
    def gpu_game_loop():
        """
        Generator that keeps GPU allocated and processes commands.
        Yields (frame, frame_count) tuples.
        """
        n_frames = pipe.transformer.config.n_frames
        print(f"Model loaded! (n_frames={n_frames})")
        print(f"gpu_game_loop: initial_seed_image={type(initial_seed_image)}, initial_seed_url={initial_seed_url}")

        # Initialize state with provided seed or random. Empty URLs fall back
        # to a random seed, matching the reset path below.
        if initial_seed_image is not None:
            print(f"gpu_game_loop init: Using initial_seed_image {initial_seed_image.size if hasattr(initial_seed_image, 'size') else type(initial_seed_image)}")
            seed_image = initial_seed_image.resize((640, 360), Image.BILINEAR)
        elif initial_seed_url:
            print(f"gpu_game_loop init: Using initial_seed_url {initial_seed_url}")
            seed_image = load_seed_frame(initial_seed_url)
        else:
            print("gpu_game_loop init: Using random seed")
            seed_image = load_seed_frame(random.choice(SEED_FRAME_URLS))

        state, frame = _start_world(initial_prompt, seed_image)
        frame_count = 1
        print("Initial frame generated, entering game loop...")
        yield (frame, frame_count)

        # Latest input state; overwritten by each GenerateCommand.
        current_buttons = set()
        current_mouse = (0.0, 0.0)
        current_prompt = initial_prompt

        # Main loop - generate frames continuously, sampling the latest input.
        while True:
            # Drain every pending command without blocking. Only the last
            # GenerateCommand / ResetCommand matters; a StopCommand wins outright.
            stop_requested = False
            reset_command = None
            while True:
                try:
                    command = command_queue.get_nowait()
                except queue.Empty:
                    break
                if isinstance(command, StopCommand):
                    stop_requested = True
                    break
                if isinstance(command, ResetCommand):
                    reset_command = command
                elif isinstance(command, GenerateCommand):
                    current_buttons = command.buttons
                    current_mouse = command.mouse
                    current_prompt = command.prompt

            if stop_requested:
                print("Stop command received, ending GPU session")
                break

            # Handle reset if requested: start a new world and skip this step.
            if reset_command is not None:
                print(f"Reset command received: seed_image={type(reset_command.seed_image)}, seed_url={reset_command.seed_url}")
                if reset_command.seed_image is not None:
                    print(f"Using seed_image from command: {reset_command.seed_image.size if hasattr(reset_command.seed_image, 'size') else 'unknown'}")
                    seed_img = reset_command.seed_image.resize((640, 360), Image.BILINEAR)
                elif reset_command.seed_url:
                    print(f"Using seed_url from command: {reset_command.seed_url}")
                    seed_img = load_seed_frame(reset_command.seed_url)
                else:
                    print("Using random seed from command")
                    seed_img = load_seed_frame(random.choice(SEED_FRAME_URLS))
                state, frame = _start_world(reset_command.prompt, seed_img)
                frame_count = 1
                current_prompt = reset_command.prompt
                yield (frame, frame_count)
                continue

            # Generate next frame with current input state (ALWAYS generates).
            state = pipe(
                state,
                prompt=current_prompt,
                button=current_buttons,
                mouse=current_mouse,
                image=None,
                output_type="pil",
            )
            frame_count += 1
            frame = state.values.get("images")

            # Auto-reset to a random preset world before the model's frame
            # context (n_frames) is exhausted.
            if frame_count >= n_frames - 2:
                print(f"Auto-reset at frame {frame_count}")
                seed_img = load_seed_frame(random.choice(SEED_FRAME_URLS))
                state, frame = _start_world(current_prompt, seed_img)
                frame_count = 1

            yield (frame, frame_count)

        print("GPU session ended")

    # Return the generator
    return gpu_game_loop()
# --- Gradio App ---
# Markup for the on-screen control panel: status pill and mobile toggle,
# W/A/S/D + SHIFT/SPACE/E keys (each tagged with data-key), the "Look"
# joystick with velocity readout, and the active-buttons strip. The element
# ids are looked up by CONTROL_INPUT_JS, so keep the two in sync.
CONTROL_INPUT_HTML = """
<div id="control-input-wrapper" style="width: 100%; background: #0a0a0f; border-radius: 12px; overflow: hidden; font-family: 'JetBrains Mono', 'Fira Code', monospace;">
<div style="padding: 12px; background: linear-gradient(180deg, rgba(22, 27, 34, 0.95) 0%, rgba(13, 17, 23, 0.98) 100%);">
<div style="display: flex; justify-content: space-between; align-items: center; margin-bottom: 12px;">
<div id="status-indicator" style="display: flex; align-items: center; gap: 6px; background: rgba(0,0,0,0.4); padding: 6px 12px; border-radius: 20px; border: 1px solid rgba(88, 166, 255, 0.2);">
<div id="status-dot" style="width: 8px; height: 8px; border-radius: 50%; background: #ff6b6b; box-shadow: 0 0 8px #ff6b6b;"></div>
<span id="status-text" style="font-size: 11px; color: #8b949e; text-transform: uppercase; letter-spacing: 1px;">Tap to enable</span>
</div>
<button id="mobile-toggle" style="display: none; padding: 6px 12px; background: rgba(88, 166, 255, 0.2); border: 1px solid rgba(88, 166, 255, 0.4); border-radius: 20px; color: #58a6ff; font-size: 11px; cursor: pointer;">Enable Controls</button>
</div>
<div style="display: grid; grid-template-columns: 1fr 1fr; gap: 12px;">
<div style="background: rgba(0,0,0,0.3); border-radius: 12px; padding: 10px; border: 1px solid rgba(88, 166, 255, 0.1);">
<div style="font-size: 10px; color: #8b949e; text-transform: uppercase; letter-spacing: 1px; margin-bottom: 8px;">Movement</div>
<div style="display: grid; grid-template-columns: repeat(3, 1fr); gap: 4px; max-width: 100px; margin: 0 auto;">
<div></div>
<div id="key-w" data-key="KeyW" style="aspect-ratio: 1; min-height: 32px; background: rgba(88, 166, 255, 0.1); border: 1px solid rgba(88, 166, 255, 0.3); border-radius: 6px; display: flex; align-items: center; justify-content: center; font-size: 12px; color: #58a6ff; transition: all 0.1s; cursor: pointer; user-select: none; -webkit-user-select: none; touch-action: manipulation;">W</div>
<div></div>
<div id="key-a" data-key="KeyA" style="aspect-ratio: 1; min-height: 32px; background: rgba(88, 166, 255, 0.1); border: 1px solid rgba(88, 166, 255, 0.3); border-radius: 6px; display: flex; align-items: center; justify-content: center; font-size: 12px; color: #58a6ff; transition: all 0.1s; cursor: pointer; user-select: none; -webkit-user-select: none; touch-action: manipulation;">A</div>
<div id="key-s" data-key="KeyS" style="aspect-ratio: 1; min-height: 32px; background: rgba(88, 166, 255, 0.1); border: 1px solid rgba(88, 166, 255, 0.3); border-radius: 6px; display: flex; align-items: center; justify-content: center; font-size: 12px; color: #58a6ff; transition: all 0.1s; cursor: pointer; user-select: none; -webkit-user-select: none; touch-action: manipulation;">S</div>
<div id="key-d" data-key="KeyD" style="aspect-ratio: 1; min-height: 32px; background: rgba(88, 166, 255, 0.1); border: 1px solid rgba(88, 166, 255, 0.3); border-radius: 6px; display: flex; align-items: center; justify-content: center; font-size: 12px; color: #58a6ff; transition: all 0.1s; cursor: pointer; user-select: none; -webkit-user-select: none; touch-action: manipulation;">D</div>
</div>
<div style="display: flex; gap: 4px; margin-top: 8px; justify-content: center; flex-wrap: wrap;">
<div id="key-shift" data-key="ShiftLeft" style="padding: 6px 10px; min-height: 28px; background: rgba(88, 166, 255, 0.1); border: 1px solid rgba(88, 166, 255, 0.3); border-radius: 4px; font-size: 10px; color: #58a6ff; cursor: pointer; user-select: none; -webkit-user-select: none; touch-action: manipulation; display: flex; align-items: center;">SHIFT</div>
<div id="key-space" data-key="Space" style="padding: 6px 14px; min-height: 28px; background: rgba(88, 166, 255, 0.1); border: 1px solid rgba(88, 166, 255, 0.3); border-radius: 4px; font-size: 10px; color: #58a6ff; cursor: pointer; user-select: none; -webkit-user-select: none; touch-action: manipulation; display: flex; align-items: center;">SPACE</div>
<div id="key-e" data-key="KeyE" style="padding: 6px 10px; min-height: 28px; background: rgba(88, 166, 255, 0.1); border: 1px solid rgba(88, 166, 255, 0.3); border-radius: 4px; font-size: 10px; color: #58a6ff; cursor: pointer; user-select: none; -webkit-user-select: none; touch-action: manipulation; display: flex; align-items: center;">E</div>
</div>
</div>
<div style="background: rgba(0,0,0,0.3); border-radius: 12px; padding: 10px; border: 1px solid rgba(88, 166, 255, 0.1); min-width: 0; overflow: visible;">
<div style="font-size: 10px; color: #8b949e; text-transform: uppercase; letter-spacing: 1px; margin-bottom: 8px;">Look</div>
<div style="display: flex; align-items: center; justify-content: center; gap: 6px; flex-wrap: wrap;">
<div id="mouse-joystick" style="width: 70px; height: 70px; min-width: 70px; min-height: 70px; background: rgba(88, 166, 255, 0.05); border: 2px solid rgba(88, 166, 255, 0.3); border-radius: 50%; position: relative; cursor: pointer; touch-action: none; flex-shrink: 0;">
<div id="mouse-dot" style="width: 18px; height: 18px; background: #58a6ff; border-radius: 50%; position: absolute; top: 50%; left: 50%; transform: translate(-50%, -50%); box-shadow: 0 0 12px rgba(88, 166, 255, 0.6); pointer-events: none;"></div>
</div>
<div style="text-align: left; font-size: 10px;">
<div style="color: #8b949e; margin-bottom: 2px;">Velocity</div>
<div style="color: #58a6ff;">X: <span id="mouse-x-value">0.0</span></div>
<div style="color: #58a6ff;">Y: <span id="mouse-y-value">0.0</span></div>
</div>
</div>
</div>
</div>
<div style="margin-top: 10px; background: rgba(0,0,0,0.3); border-radius: 8px; padding: 6px 10px; border: 1px solid rgba(88, 166, 255, 0.1);">
<div style="display: flex; align-items: center; gap: 8px;">
<span style="font-size: 10px; color: #8b949e; text-transform: uppercase; letter-spacing: 1px;">Active:</span>
<div id="active-buttons" style="display: flex; gap: 4px; flex-wrap: wrap; min-height: 18px;">
<span style="font-size: 11px; color: #484f58;">None</span>
</div>
</div>
</div>
</div>
</div>
"""
# Client-side script passed as `js_on_load` to the control_input gr.HTML
# component. It relies on `element`, `props` and `trigger` from that context
# (presumably injected by Gradio), captures keyboard/pointer/touch input, and
# emits {buttons, mouse_x, mouse_y} via trigger('change', ...).
# The velocity-decay timer only emits while there is velocity to decay, and it
# does not decay while the on-screen joystick is held. This prevents an idle
# session from sending a change event every 100 ms, and keeps a stationary
# joystick from losing its velocity.
CONTROL_INPUT_JS = """
(() => {
const statusDot = element.querySelector('#status-dot');
const statusText = element.querySelector('#status-text');
const mouseDot = element.querySelector('#mouse-dot');
const mouseXValue = element.querySelector('#mouse-x-value');
const mouseYValue = element.querySelector('#mouse-y-value');
const activeButtonsDisplay = element.querySelector('#active-buttons');
const mouseJoystick = element.querySelector('#mouse-joystick');
const mobileToggle = element.querySelector('#mobile-toggle');
// Detect mobile/touch device
const isMobile = ('ontouchstart' in window) || (navigator.maxTouchPoints > 0) || (window.innerWidth <= 768);
let isCapturing = false;
let pressedKeys = new Set();
let mouseVelocity = { x: 0, y: 0 };
let lastMouseMove = Date.now();
const BUTTON_MAP = {
'KeyW': 87, 'KeyA': 65, 'KeyS': 83, 'KeyD': 68,
'KeyQ': 81, 'KeyE': 69, 'KeyR': 82, 'KeyF': 70,
'Space': 32, 'ShiftLeft': 16, 'ShiftRight': 16,
};
const KEY_DISPLAY_MAP = {
'KeyW': 'key-w', 'KeyA': 'key-a', 'KeyS': 'key-s', 'KeyD': 'key-d',
'ShiftLeft': 'key-shift', 'Space': 'key-space', 'KeyE': 'key-e',
};
function updateKeyDisplay(code, pressed) {
const elementId = KEY_DISPLAY_MAP[code];
if (elementId) {
const keyEl = element.querySelector('#' + elementId);
if (keyEl) {
keyEl.style.background = pressed ? 'rgba(88, 166, 255, 0.4)' : 'rgba(88, 166, 255, 0.1)';
keyEl.style.borderColor = pressed ? '#58a6ff' : 'rgba(88, 166, 255, 0.3)';
}
}
}
function updateMouseDisplay() {
const maxRadius = 40;
const displayX = Math.max(-1, Math.min(1, mouseVelocity.x / 10));
const displayY = Math.max(-1, Math.min(1, mouseVelocity.y / 10));
mouseDot.style.left = (50 + displayX * maxRadius) + '%';
mouseDot.style.top = (50 + displayY * maxRadius) + '%';
mouseXValue.textContent = mouseVelocity.x.toFixed(1);
mouseYValue.textContent = mouseVelocity.y.toFixed(1);
}
function updateActiveButtonsDisplay() {
if (pressedKeys.size === 0) {
activeButtonsDisplay.innerHTML = '<span style="font-size: 11px; color: #484f58;">None</span>';
} else {
activeButtonsDisplay.innerHTML = Array.from(pressedKeys)
.map(code => code.replace('Key', '').replace('Left', ''))
.map(name => `<span style="font-size: 10px; background: rgba(88, 166, 255, 0.2); color: #58a6ff; padding: 2px 6px; border-radius: 4px;">${name}</span>`)
.join('');
}
}
function triggerUpdate() {
const buttonIds = Array.from(pressedKeys)
.filter(code => BUTTON_MAP[code] !== undefined)
.map(code => BUTTON_MAP[code]);
props.value = { buttons: buttonIds, mouse_x: mouseVelocity.x, mouse_y: mouseVelocity.y };
trigger('change', props.value);
}
function setCapturing(capturing) {
isCapturing = capturing;
statusDot.style.background = isCapturing ? '#3fb950' : '#ff6b6b';
statusDot.style.boxShadow = isCapturing ? '0 0 8px #3fb950' : '0 0 8px #ff6b6b';
if (isMobile) {
statusText.textContent = isCapturing ? 'Controls active' : 'Tap to enable';
mobileToggle.textContent = isCapturing ? 'Disable' : 'Enable';
} else {
statusText.textContent = isCapturing ? 'Capturing - ESC to release' : 'Click game to capture';
}
if (!isCapturing) {
pressedKeys.clear();
mouseVelocity = { x: 0, y: 0 };
Object.keys(KEY_DISPLAY_MAP).forEach(code => updateKeyDisplay(code, false));
updateMouseDisplay();
updateActiveButtonsDisplay();
triggerUpdate();
}
}
// Expose setCapturing globally so we can trigger it from Start Game button
window.worldEngineSetCapturing = setCapturing;
window.worldEngineRequestPointerLock = () => {
if (!isMobile) {
document.body.requestPointerLock();
} else {
setCapturing(true);
}
};
// Mobile: show toggle button and enable controls on tap
if (isMobile) {
mobileToggle.style.display = 'block';
statusText.textContent = 'Tap to enable';
mobileToggle.addEventListener('click', () => {
setCapturing(!isCapturing);
});
// Also enable on tapping the control wrapper
element.querySelector('#control-input-wrapper').addEventListener('click', (e) => {
if (e.target === mobileToggle) return;
if (!isCapturing) setCapturing(true);
});
}
// Desktop: use pointer lock
document.addEventListener('pointerlockchange', () => {
if (!isMobile) {
setCapturing(document.pointerLockElement !== null);
}
});
// Keyboard controls (desktop)
document.addEventListener('keydown', (e) => {
if (!isCapturing) return;
if (e.code === 'Escape') {
if (isMobile) {
setCapturing(false);
} else {
document.exitPointerLock();
}
return;
}
if (BUTTON_MAP[e.code] !== undefined && !pressedKeys.has(e.code)) {
pressedKeys.add(e.code);
updateKeyDisplay(e.code, true);
updateActiveButtonsDisplay();
triggerUpdate();
}
e.preventDefault();
});
document.addEventListener('keyup', (e) => {
if (!isCapturing) return;
if (pressedKeys.has(e.code)) {
pressedKeys.delete(e.code);
updateKeyDisplay(e.code, false);
updateActiveButtonsDisplay();
triggerUpdate();
}
});
// Mouse movement (desktop pointer lock)
document.addEventListener('mousemove', (e) => {
if (!isCapturing || isMobile) return;
mouseVelocity.x = e.movementX * 1.5;
mouseVelocity.y = e.movementY * 1.5;
updateMouseDisplay();
triggerUpdate();
lastMouseMove = Date.now();
});
// Joystick state (touch + mouse); declared before the decay timer that reads it.
let joystickActive = false;
// Decay mouse velocity when not moving. Skipped while the joystick is held (a
// stationary joystick keeps its velocity) and once velocity is already zero,
// so an idle session does not emit a change event every 100 ms.
setInterval(() => {
if (!isCapturing || joystickActive) return;
if (Date.now() - lastMouseMove <= 50) return;
if (mouseVelocity.x === 0 && mouseVelocity.y === 0) return;
mouseVelocity.x *= 0.8;
mouseVelocity.y *= 0.8;
if (Math.abs(mouseVelocity.x) < 0.01) mouseVelocity.x = 0;
if (Math.abs(mouseVelocity.y) < 0.01) mouseVelocity.y = 0;
updateMouseDisplay();
triggerUpdate();
}, 100);
// Joystick controls (touch + mouse)
const getJoystickCenter = () => {
const rect = mouseJoystick.getBoundingClientRect();
return { x: rect.left + rect.width / 2, y: rect.top + rect.height / 2 };
};
const handleJoystickMove = (clientX, clientY) => {
const center = getJoystickCenter();
const maxDist = mouseJoystick.offsetWidth / 2;
let dx = clientX - center.x, dy = clientY - center.y;
const dist = Math.sqrt(dx*dx + dy*dy);
if (dist > maxDist) { dx = dx/dist*maxDist; dy = dy/dist*maxDist; }
mouseVelocity.x = (dx/maxDist) * 10;
mouseVelocity.y = (dy/maxDist) * 10;
updateMouseDisplay();
triggerUpdate();
lastMouseMove = Date.now();
};
const resetJoystick = () => {
joystickActive = false;
mouseVelocity = {x:0, y:0};
updateMouseDisplay();
triggerUpdate();
};
// Mouse joystick (desktop fallback when not in pointer lock)
mouseJoystick.addEventListener('mousedown', (e) => {
e.preventDefault();
joystickActive = true;
handleJoystickMove(e.clientX, e.clientY);
});
document.addEventListener('mousemove', (e) => {
if (joystickActive && !document.pointerLockElement) handleJoystickMove(e.clientX, e.clientY);
});
document.addEventListener('mouseup', () => {
if (joystickActive) resetJoystick();
});
// Touch joystick
mouseJoystick.addEventListener('touchstart', (e) => {
e.preventDefault();
joystickActive = true;
if (!isCapturing) setCapturing(true);
const touch = e.touches[0];
handleJoystickMove(touch.clientX, touch.clientY);
}, {passive: false});
mouseJoystick.addEventListener('touchmove', (e) => {
e.preventDefault();
if (joystickActive) {
const touch = e.touches[0];
handleJoystickMove(touch.clientX, touch.clientY);
}
}, {passive: false});
mouseJoystick.addEventListener('touchend', (e) => {
e.preventDefault();
resetJoystick();
}, {passive: false});
mouseJoystick.addEventListener('touchcancel', (e) => {
e.preventDefault();
resetJoystick();
}, {passive: false});
// Touch controls for movement keys
element.querySelectorAll('[data-key]').forEach(keyEl => {
const keyCode = keyEl.dataset.key;
// Touch events
keyEl.addEventListener('touchstart', (e) => {
e.preventDefault();
if (!isCapturing) setCapturing(true);
pressedKeys.add(keyCode);
updateKeyDisplay(keyCode, true);
updateActiveButtonsDisplay();
triggerUpdate();
}, {passive: false});
keyEl.addEventListener('touchend', (e) => {
e.preventDefault();
pressedKeys.delete(keyCode);
updateKeyDisplay(keyCode, false);
updateActiveButtonsDisplay();
triggerUpdate();
}, {passive: false});
keyEl.addEventListener('touchcancel', (e) => {
e.preventDefault();
pressedKeys.delete(keyCode);
updateKeyDisplay(keyCode, false);
updateActiveButtonsDisplay();
triggerUpdate();
}, {passive: false});
// Mouse click events (for desktop users who prefer clicking)
keyEl.addEventListener('mousedown', (e) => {
e.preventDefault();
pressedKeys.add(keyCode);
updateKeyDisplay(keyCode, true);
updateActiveButtonsDisplay();
triggerUpdate();
});
keyEl.addEventListener('mouseup', (e) => {
e.preventDefault();
pressedKeys.delete(keyCode);
updateKeyDisplay(keyCode, false);
updateActiveButtonsDisplay();
triggerUpdate();
});
keyEl.addEventListener('mouseleave', (e) => {
if (pressedKeys.has(keyCode)) {
pressedKeys.delete(keyCode);
updateKeyDisplay(keyCode, false);
updateActiveButtonsDisplay();
triggerUpdate();
}
});
});
updateMouseDisplay();
})();
"""
# Page-level CSS passed to gr.Blocks. At widths <= 768px the main row stacks
# vertically, `.game-column` uses `display: contents` so its children join the
# main row, and `order` places video, buttons, controls, then world selector.
# NOTE(review): `.main-row { flex-direction: column !important; }` appears twice
# inside the media query; the second rule makes the first redundant.
css = """
#col-container { max-width: 1200px; margin: 0 auto; }
#video-output { aspect-ratio: 16/9; max-width: 640px; }
#video-output img { width: 100%; height: 100%; object-fit: contain; }
.seed-image-upload img { max-height: 120px !important; object-fit: contain; }
.main-row { align-items: flex-start !important; }
.controls-column { min-width: 280px; }
.world-gallery { margin-bottom: 12px; }
.world-gallery .gallery { gap: 8px !important; }
.world-gallery .gallery-item { border-radius: 8px; overflow: hidden; }
/* Mobile responsive styles */
@media (max-width: 768px) {
.main-row {
flex-direction: column !important;
}
.main-row > div {
width: 100% !important;
max-width: 100% !important;
flex: none !important;
}
.controls-column {
min-width: unset;
margin-top: 12px;
}
#video-output { max-width: 100%; }
/* On mobile, flatten the game-column children and reorder */
.game-column {
display: contents !important;
}
/* Make the main-row handle all ordering */
.main-row {
display: flex !important;
flex-direction: column !important;
}
/* Order: video-output wrapper (1), button-row (2), controls-column (3), world-accordion (4) */
.main-row #video-output { order: 1 !important; }
.main-row .button-row { order: 2 !important; width: 100% !important; }
.main-row .controls-column { order: 3 !important; }
.main-row .world-accordion { order: 4 !important; width: 100% !important; }
/* Ensure proper spacing */
.main-row > * {
margin-bottom: 8px;
}
}
"""
| def create_app(): | |
| with gr.Blocks(css=css, theme=gr.themes.Soft(), title="WorldEngine") as demo: | |
| # State: (generator, command_queue) or empty tuple | |
| session_state = gr.State(()) | |
# Per-session gr.State values (not rendered) read and written by the event handlers wired up below.
| # Current controls (updated by JS) | |
# current_controls mirrors control_input via its change event; every timer tick reads it.
| current_controls = gr.State({"buttons": [], "mouse_x": 0.0, "mouse_y": 0.0}) | |
| current_prompt = gr.State("An explorable world") | |
| # Latest frame for display | |
# latest_frame / latest_frame_count are fed back into each timer tick so the previous
# frame can be re-shown when no new frame is ready yet.
| latest_frame = gr.State(None) | |
| latest_frame_count = gr.State(0) | |
| # Selected seed URL from examples | |
# Set when a gallery world is clicked; cleared when a custom image is uploaded.
| selected_seed_url = gr.State(None) | |
| # Store uploaded image in state (workaround for Gradio component value issues) | |
# Written by the seed_image_upload change handler; Start/Restart read this state
# rather than the image component itself.
| uploaded_image_state = gr.State(None) | |
| gr.Markdown(""" | |
| # 🌍 Waypoint 1 Small | |
| Interactive frame-by-frame world generation. This model is running on ZeroGPU (H200), and works even better on Blackwell GPUs (B200, RTX 6000 Blackwell, RTX 5090s) | |
| [[blog](https://huggingface.co/blog/waypoint-1)], [[model](https://huggingface.co/Overworld/Waypoint-1-Small)], [Overworld Streaming Client](https://www.overworld.stream) | |
| **Controls:** Click "Start Game" → WASD to move • Mouse to look • Press ESC to release controls (or touch the controls on mobile) | |
| """) | |
# Two-column layout: a game column (scale=2) and a controls column (scale=1).
| with gr.Row(elem_classes=["main-row"]): | |
| with gr.Column(scale=2, elem_classes=["game-column"]): | |
# Shows the loading image and generated frames; elem_id="video-output" is the element
# the demo.load JS makes clickable for pointer lock.
| video_output = gr.Image( | |
| label="Game View", | |
| elem_id="video-output", | |
| streaming=True, | |
| width=640, | |
| height=360, | |
| show_label=False, | |
| ) | |
| with gr.Row(elem_classes=["button-row"]): | |
| start_btn = gr.Button("🎮 Start Game", variant="primary") | |
# End starts disabled; the lifecycle handlers toggle Start/End interactivity.
| stop_btn = gr.Button("⏹ End Game", interactive=False) | |
| with gr.Accordion("World Selection", open=True, elem_classes=["world-accordion"]): | |
| gr.Markdown("**Choose a starting world** (or leave blank for random):") | |
| # Gallery for world selection | |
# Items are SEED_FRAME_URLS in order; the select handlers map evt.index back into that list.
| world_gallery = gr.Gallery( | |
| value=SEED_FRAME_URLS, | |
| label="Preset Worlds", | |
| columns=5, | |
| rows=1, | |
| height=200, | |
| object_fit="cover", | |
| allow_preview=False, | |
| elem_classes=["world-gallery"], | |
| ) | |
| gr.Markdown("**Or upload your own:**") | |
# type="pil" so handlers receive a PIL Image (they check isinstance(..., Image.Image)).
| seed_image_upload = gr.Image( | |
| label="Custom World Image", | |
| type="pil", | |
| sources=["upload", "clipboard"], | |
| height=220, | |
| elem_classes=["seed-image-upload"], | |
| ) | |
| reset_btn = gr.Button("Restart World", variant="secondary", size="sm") | |
| with gr.Column(scale=1, elem_classes=["controls-column"]): | |
# Custom HTML/JS control pad; its value is the {"buttons", "mouse_x", "mouse_y"} dict
# that is copied into current_controls on change.
| control_input = gr.HTML( | |
| value={"buttons": [], "mouse_x": 0.0, "mouse_y": 0.0}, | |
| html_template=CONTROL_INPUT_HTML, | |
| js_on_load=CONTROL_INPUT_JS, | |
| ) | |
# Copied into current_prompt on change.
| prompt_input = gr.Textbox( | |
| label="World Prompt", | |
| value="An explorable world", | |
| lines=2, | |
| ) | |
# Read-only readouts updated by the timer tick.
| with gr.Row(): | |
| frame_display = gr.Number(label="Frame", value=0, interactive=False) | |
| fps_display = gr.Number(label="FPS", value=0.0, interactive=False) | |
| # --- Event Handlers --- | |
def on_gallery_select(evt: gr.SelectData):
    """Translate a gallery click into the preset world's URL.

    Returns the entry of SEED_FRAME_URLS at the clicked index, or None when
    the event carries no index or the index is out of range.
    """
    clicked = evt.index
    if clicked is None or clicked >= len(SEED_FRAME_URLS):
        return None
    return SEED_FRAME_URLS[clicked]
def _idle_outputs():
    """Outputs for "no game running": clear state and view, enable Start, disable End.

    Order matches the lifecycle outputs: (session_state, latest_frame,
    latest_frame_count, video_output, frame_display, start_btn, stop_btn).
    """
    return (
        (),
        None,
        0,
        None,
        0,
        gr.update(interactive=True),
        gr.update(interactive=False),
    )


def _pick_seed(uploaded_image, selected_url, log_prefix):
    """Choose what to seed a new/reset world with and log the decision.

    Priority: uploaded PIL image > gallery URL > random.

    Returns:
        (seed_image, seed_url). At most one is not None; (None, None) means
        "random" (chosen downstream, per the original comments).
    """
    if isinstance(uploaded_image, Image.Image):
        print(f"{log_prefix}: Using uploaded image: {uploaded_image.size}", flush=True)
        return uploaded_image, None
    if selected_url is not None:
        print(f"{log_prefix}: Using selected URL: {selected_url}", flush=True)
        return None, selected_url
    print(f"{log_prefix}: Using random seed", flush=True)
    return None, None


def _shutdown_game_session(session):
    """Ask a session's worker to stop, then wait up to 2 s for its thread to exit."""
    session.stop_event.set()
    session.command_queue.put(StopCommand())
    if session.worker_thread.is_alive():
        session.worker_thread.join(timeout=2.0)


def _launch_game_session(seed_image, seed_url, prompt):
    """Create the GPU generator, pull its first frame, and hand it to a worker thread.

    Returns:
        (session, first_frame, first_frame_count).

    Raises:
        Whatever create_gpu_game_loop / the first next() raises.
    """
    command_queue = Queue()
    frame_queue = queue.Queue(maxsize=2)  # Thread-safe output queue
    stop_event = threading.Event()
    gen = create_gpu_game_loop(
        command_queue,
        initial_seed_image=seed_image,
        initial_seed_url=seed_url,
        initial_prompt=prompt or "An explorable world",
    )
    # The first frame is produced synchronously so it can be shown right away.
    frame, frame_count = next(gen)
    frame_times = deque(maxlen=30)
    worker = threading.Thread(
        target=gpu_worker_thread,
        args=(gen, command_queue, frame_queue, stop_event, frame_times),
        daemon=True,
    )
    worker.start()
    session = GameSession(
        command_queue=command_queue,
        frame_queue=frame_queue,
        worker_thread=worker,
        stop_event=stop_event,
        generator=gen,
        frame_times=frame_times,
    )
    return session, frame, frame_count


def _start_game(seed_image, seed_url, prompt):
    """Generator shared by every "start a new world" path.

    Yields the loading screen first. Then it yields either the running-session
    outputs, or, if the session cannot be started, the idle outputs followed
    by a gr.Error. Without that recovery, a failed start left Start and End
    both disabled forever.
    """
    gr.Info("Controls locked! Press ESC to release mouse/keyboard capture.", duration=5)
    loading_img = create_loading_image(text="Generating World ...")
    yield (
        (),  # session_state (empty during loading)
        None,  # latest_frame
        0,  # latest_frame_count
        loading_img,  # video_output - show loading
        0,  # frame_display
        gr.update(interactive=False),  # start_btn - disable
        gr.update(interactive=False),  # stop_btn - disable during load
    )
    try:
        session, frame, frame_count = _launch_game_session(seed_image, seed_url, prompt)
    except Exception as exc:  # UI boundary: recover the buttons, then report
        print(f"Failed to start game session: {exc!r}", flush=True)
        yield _idle_outputs()
        if isinstance(exc, gr.Error):
            raise exc  # already user-facing (e.g. raised by ZeroGPU)
        raise gr.Error(f"Could not start the world: {exc}") from exc
    yield (
        session,  # session_state
        frame,  # latest_frame
        frame_count,  # latest_frame_count
        frame,  # video_output
        frame_count,  # frame_display
        gr.update(interactive=False),  # start_btn
        gr.update(interactive=True),  # stop_btn
    )


def on_gallery_start(state, evt: gr.SelectData, uploaded_image, prompt):
    """Start (or restart) the game on the gallery world that was just clicked.

    Any running session is shut down first. `uploaded_image` is accepted to
    match the event wiring but ignored: a gallery click always seeds from the
    gallery.
    """
    print(f"on_gallery_start CALLED: evt.index={evt.index}", flush=True)
    if evt.index is None or evt.index >= len(SEED_FRAME_URLS):
        # Not a usable selection: leave the current game and its view untouched.
        yield (state, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update())
        return
    selected_url = SEED_FRAME_URLS[evt.index]
    if state and isinstance(state, GameSession):
        _shutdown_game_session(state)
    yield from _start_game(None, selected_url, prompt)


def on_start(selected_url, uploaded_image, prompt):
    """Start a GPU session from the Start button (worker thread produces frames).

    Seed priority: uploaded image (from state) > gallery selection > random.
    """
    print("on_start CALLED:", flush=True)
    print(f" uploaded_image (from state) type: {type(uploaded_image)}", flush=True)
    print(f" uploaded_image is PIL: {isinstance(uploaded_image, Image.Image) if uploaded_image else False}", flush=True)
    print(f" selected_url: {selected_url}", flush=True)
    is_pil_image = isinstance(uploaded_image, Image.Image)
    print("on_start decision:", flush=True)
    print(f" is_pil_image: {is_pil_image}", flush=True)
    print(f" has_uploaded_image: {uploaded_image is not None and is_pil_image}", flush=True)
    seed_image, seed_url = _pick_seed(uploaded_image, selected_url, "on_start")
    yield from _start_game(seed_image, seed_url, prompt)


def on_stop(state):
    """Stop the running GPU session (if any) and return the UI to its idle state."""
    if state and isinstance(state, GameSession):
        _shutdown_game_session(state)
    return _idle_outputs()


def on_generate_tick(state, controls, prompt, current_frame, current_count, current_fps):
    """Timer callback: queue one generate command and show the newest frame (non-blocking).

    Returns (latest_frame, latest_frame_count, video_output, frame_display,
    fps_display). With no running session the current frame is echoed back
    and FPS reads 0.0. With a session but no new frame yet, the previous
    frame and FPS are kept.
    """
    if not state or not isinstance(state, GameSession):
        return current_frame, current_count, current_frame, current_count, 0.0
    session = state
    buttons = set(controls.get("buttons", []))
    mouse = (controls.get("mouse_x", 0.0), controls.get("mouse_y", 0.0))
    session.command_queue.put(GenerateCommand(buttons=buttons, mouse=mouse, prompt=prompt))
    # frame_queue can buffer more than one frame, and get_nowait() returns the
    # OLDEST. Drain it and keep only the newest so the view never lags.
    newest = None
    while True:
        try:
            newest = session.frame_queue.get_nowait()
        except queue.Empty:
            break
    if newest is None:
        # No new frame yet: keep showing the previous one (never blocks).
        return current_frame, current_count, current_frame, current_count, current_fps
    frame, frame_count, fps = newest
    return frame, frame_count, frame, frame_count, fps


def on_reset(state, selected_url, uploaded_image, prompt):
    """Reset the world with a new seed; starts a game if none is running.

    Seed priority: uploaded image > gallery selection > random. A running
    session receives a ResetCommand. The current picture stays up until the
    timer tick delivers the reset frame.
    """
    print("on_reset CALLED:", flush=True)
    print(f" uploaded_image (from state) type: {type(uploaded_image)}", flush=True)
    print(f" uploaded_image is PIL: {isinstance(uploaded_image, Image.Image) if uploaded_image else False}", flush=True)
    print(f" selected_url: {selected_url}", flush=True)
    if not state or not isinstance(state, GameSession):
        seed_image, seed_url = _pick_seed(uploaded_image, selected_url, "on_reset (start)")
        yield from _start_game(seed_image, seed_url, prompt)
        return
    seed_image, seed_url = _pick_seed(uploaded_image, selected_url, "on_reset (running)")
    state.command_queue.put(ResetCommand(seed_image=seed_image, seed_url=seed_url, prompt=prompt))
    # Keep the current state and view; the next timer tick picks up the reset frame.
    yield (state, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update())
def on_controller_change(value):
    """Copy the control pad's payload into state.

    A falsy payload (None, empty dict) is replaced by a fresh "no input"
    dict; any truthy payload is returned as-is.
    """
    if not value:
        return {"buttons": [], "mouse_x": 0.0, "mouse_y": 0.0}
    return value
def on_prompt_change(value):
    """Mirror the prompt textbox into the prompt state, unchanged."""
    prompt_text = value
    return prompt_text
# Wire up events.
# Every game-lifecycle handler (gallery start, Start, End, Restart) writes these
# seven components, in this order.
game_lifecycle_outputs = (
    session_state,
    latest_frame,
    latest_frame_count,
    video_output,
    frame_display,
    start_btn,
    stop_btn,
)
# Client-side hook: request pointer lock through the game's helper after a 500 ms delay.
pointer_lock_js = "() => { setTimeout(() => { if (window.worldEngineRequestPointerLock) window.worldEngineRequestPointerLock(); }, 500); }"

# A gallery click both remembers the chosen world URL...
world_gallery.select(fn=on_gallery_select, inputs=[], outputs=[selected_seed_url])
# ...and (re)starts the game on that world right away.
world_gallery.select(
    fn=on_gallery_start,
    inputs=[session_state, seed_image_upload, current_prompt],
    outputs=list(game_lifecycle_outputs),
    js=pointer_lock_js,
)
start_btn.click(
    fn=on_start,
    inputs=[selected_seed_url, uploaded_image_state, prompt_input],
    outputs=list(game_lifecycle_outputs),
    js=pointer_lock_js,
)
stop_btn.click(
    fn=on_stop,
    inputs=[session_state],
    outputs=list(game_lifecycle_outputs),
)
reset_btn.click(
    fn=on_reset,
    inputs=[session_state, selected_seed_url, uploaded_image_state, current_prompt],
    outputs=list(game_lifecycle_outputs),
    js=pointer_lock_js,
)
# Keep the state mirrors in sync with their widgets.
control_input.change(fn=on_controller_change, inputs=[control_input], outputs=[current_controls])
prompt_input.change(fn=on_prompt_change, inputs=[prompt_input], outputs=[current_prompt])
# Store uploaded image in state and clear gallery selection
def on_image_upload(image):
    """Handle a change of the custom-image upload.

    Returns (uploaded_image_state value, selected_seed_url value):
    - a PIL image: (image, None), so the stored image is set and the gallery pick is cleared;
    - anything else (e.g. the image was removed): (None, gr.update()), so the
      stored image is cleared and the URL is left unchanged.
    """
    looks_like_pil = isinstance(image, Image.Image) if image else False
    print(f"on_image_upload: image type={type(image)}, is PIL={looks_like_pil}", flush=True)
    if image is None or not isinstance(image, Image.Image):
        print("on_image_upload: Clearing stored image", flush=True)
        return None, gr.update()
    print(f"on_image_upload: Storing uploaded image {image.size}", flush=True)
    return image, None
# Upload/clear of a custom seed image: on_image_upload stores it in uploaded_image_state
# and, on upload, clears selected_seed_url so the upload wins over an earlier gallery pick.
| seed_image_upload.change( | |
| fn=on_image_upload, | |
| inputs=[seed_image_upload], | |
| outputs=[uploaded_image_state, selected_seed_url], | |
| ) | |
| # Timer for continuous generation | |
# gr.Timer's value is in seconds, so this fires about 30 times per second.
| timer = gr.Timer(value=1/30) | |
# Each tick forwards the current controls/prompt to the worker and shows a frame if one
# is ready. latest_frame, latest_frame_count and fps_display are both inputs and outputs
# so the previous values can be re-shown when no new frame has arrived.
| timer.tick( | |
| fn=on_generate_tick, | |
| inputs=[session_state, current_controls, current_prompt, latest_frame, latest_frame_count, fps_display], | |
| outputs=[latest_frame, latest_frame_count, video_output, frame_display, fps_display], | |
| ) | |
| # Pointer lock JS - also allows clicking the game window | |
# Client-side only (fn=None). It polls every 100 ms until #video-output exists, then makes
# a click on it request pointer lock via window.worldEngineRequestPointerLock, falling back
# to document.body.requestPointerLock().
| demo.load(fn=None, js=""" | |
| () => { | |
| const insertButton = () => { | |
| const output = document.querySelector('#video-output'); | |
| if (!output) { setTimeout(insertButton, 100); return; } | |
| output.style.cursor = 'pointer'; | |
| output.onclick = () => { | |
| if (window.worldEngineRequestPointerLock) { | |
| window.worldEngineRequestPointerLock(); | |
| } else { | |
| document.body.requestPointerLock(); | |
| } | |
| }; | |
| }; | |
| insertButton(); | |
| } | |
| """) | |
# Hand the assembled Blocks app back to the caller (main() launches it).
| return demo | |
# Register a no-op spaces.GPU function so ZeroGPU does not report its
# "no GPU function" error.
if IS_ZERO_GPU:
    spaces.GPU(lambda: None)


def main():
    """Print the runtime configuration, then build and serve the app on 0.0.0.0:7860."""
    startup_info = (
        ("Model", MODEL_ID),
        ("Cache dir", CACHE_DIR),
        ("ZeroGPU", IS_ZERO_GPU),
    )
    for label, setting in startup_info:
        print(f"{label}: {setting}")
    app = create_app()
    app.launch(server_name="0.0.0.0", server_port=7860)


if __name__ == "__main__":
    main()