dn6 HF Staff commited on
Commit
7237d7b
·
1 Parent(s): 2c2d194
Files changed (1) hide show
  1. app.py +338 -518
app.py CHANGED
@@ -2,18 +2,25 @@
2
  WorldEngine Real-Time World Model Demo - ZeroGPU Edition
3
 
4
  A Gradio demo optimized for HuggingFace ZeroGPU with:
 
5
  - Persistent compilation cache for faster cold starts
6
- - Warmup pass to pre-compile kernels
7
- - @spaces.GPU decorator support
8
  """
 
 
 
 
 
 
 
 
9
 
10
  import base64
11
  import os
12
  import random
13
- import threading
14
- import time
15
- from dataclasses import dataclass, field
16
  from io import BytesIO
 
17
  from pathlib import Path
18
  from typing import Optional, Set, Tuple
19
 
@@ -26,14 +33,11 @@ from diffusers import ModularPipeline
26
  from diffusers.utils import load_image
27
 
28
  # --- ZeroGPU Compilation Cache Setup ---
29
- # Set persistent cache directory BEFORE importing torch.compile
30
  CACHE_DIR = Path.home() / ".cache" / "world_engine_compile"
31
  CACHE_DIR.mkdir(parents=True, exist_ok=True)
32
 
33
  os.environ["TORCHINDUCTOR_CACHE_DIR"] = str(CACHE_DIR)
34
  os.environ["TORCH_COMPILE_CACHE_DIR"] = str(CACHE_DIR)
35
-
36
- # Enable FX graph caching for faster recompilation
37
  os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
38
 
39
  print(f"Compilation cache directory: {CACHE_DIR}")
@@ -42,20 +46,10 @@ torch._dynamo.config.recompile_limit = 64
42
  torch.set_float32_matmul_precision("medium")
43
  torch._dynamo.config.capture_scalar_outputs = True
44
 
45
- # Check for ZeroGPU environment
46
- try:
47
- import spaces
48
- IS_ZERO_GPU = True
49
- print("ZeroGPU environment detected")
50
- except ImportError:
51
- IS_ZERO_GPU = False
52
- print("Running in standard GPU mode")
53
-
54
  # --- Configuration ---
55
  MODEL_ID = os.environ.get("MODEL_PATH", "diffusers-internal-dev/world-engine-modular")
56
- MAX_SEED = np.iinfo(np.int32).max
57
 
58
- # Seed frame URLs for reset
59
  SEED_FRAME_URLS = [
60
  "https://gist.github.com/user-attachments/assets/5d91c49a-2ae9-418f-99c0-e93ae387e1de",
61
  "https://gist.github.com/user-attachments/assets/4adc5a3d-6980-4d1e-b6e8-9033cdf61c66",
@@ -80,297 +74,216 @@ def image_to_base64(image: Image.Image) -> str:
80
  return f"data:image/png;base64,{img_str}"
81
 
82
 
83
- # --- Control State ---
 
 
 
 
 
 
 
 
84
  @dataclass
85
- class ControlState:
86
- """Thread-safe control state."""
87
-
88
- buttons: Set[int] = field(default_factory=set)
89
- mouse: Tuple[float, float] = (0.0, 0.0)
90
- _lock: threading.Lock = field(default_factory=threading.Lock)
91
-
92
- def update(self, buttons: Set[int], mouse: Tuple[float, float]):
93
- with self._lock:
94
- self.buttons = buttons.copy()
95
- self.mouse = mouse
96
-
97
- def get(self) -> Tuple[Set[int], Tuple[float, float]]:
98
- with self._lock:
99
- return self.buttons.copy(), self.mouse
100
-
101
-
102
- # --- Game State with Background Generation ---
103
- class GameState:
104
- """Global game state with background frame generation."""
105
-
106
- def __init__(self):
107
- self.pipe: Optional[ModularPipeline] = None
108
- self.state = None
109
- self.control_state = ControlState()
110
- self.frame_count = 0
111
- self.n_frames = None
112
- self.is_initialized = False
113
- self.is_warming_up = False
114
- self.prompt = "An explorable world"
115
-
116
- # Background generation
117
- self._generation_thread: Optional[threading.Thread] = None
118
- self._running = False
119
- self._state_lock = threading.Lock()
120
-
121
- # Latest frame buffer
122
- self._latest_frame: Optional[Image.Image] = None
123
- self._latest_frame_b64: Optional[str] = None
124
- self._frame_lock = threading.Lock()
125
 
126
  @spaces.GPU(duration=120)
127
- def initialize(
128
- self,
129
- model_path: str,
130
- device: str = "cuda",
131
- use_compile: bool = True,
132
- warmup_steps: int = 1,
133
- ):
134
- if self.is_initialized:
135
- return
136
-
137
- print(f"Loading ModularPipeline from: {model_path}")
138
- self.pipe = ModularPipeline.from_pretrained(model_path, trust_remote_code=True)
139
- self.pipe.load_components(
140
- device_map=device,
141
  torch_dtype=torch.bfloat16,
142
  trust_remote_code=True,
143
  )
144
 
145
- self.pipe.transformer.apply_inference_patches()
146
- self.pipe.transformer.quantize("fp8")
 
 
 
 
 
 
 
 
 
 
147
 
148
- if use_compile:
149
  print("Compiling transformer...")
150
- self.pipe.transformer = torch.compile(
151
- self.pipe.transformer,
152
  mode="max-autotune-no-cudagraphs",
153
  fullgraph=True,
154
  dynamic=False,
155
  )
156
 
157
  print("Compiling VAE...")
158
- self.pipe.vae = torch.compile(
159
- self.pipe.vae,
160
  mode="max-autotune",
161
  fullgraph=True,
162
  dynamic=False,
163
  )
164
 
165
- self.n_frames = self.pipe.transformer.config.n_frames
166
- print(f"ModularPipeline loaded! (n_frames={self.n_frames})")
167
 
168
- # Only warmup if cache is empty (first run)
169
- # Subsequent cold starts can skip warmup since kernels are already compiled
170
- cache_files = list(CACHE_DIR.glob("**/*.so"))
171
- if use_compile and warmup_steps > 0 and not cache_files:
172
- print("No cached kernels found - running warmup to compile...")
173
- self._warmup_compile(warmup_steps)
174
- elif cache_files:
175
- print(f"Found {len(cache_files)} cached kernels - skipping warmup")
176
-
177
- self.is_initialized = True
 
 
178
 
179
- # Generate initial state
180
- print("Generating initial state...")
181
  seed_image = load_seed_frame(random.choice(SEED_FRAME_URLS))
182
- self.state = self.pipe(
183
- prompt=self.prompt,
184
  image=seed_image,
185
  button=set(),
186
  mouse=(0.0, 0.0),
187
  output_type="pil",
188
  )
189
- self.frame_count = 1
190
-
191
- # Store initial frame
192
- images = self.state.values.get("images")
193
- if images is not None:
194
- with self._frame_lock:
195
- self._latest_frame = images
196
- self._latest_frame_b64 = image_to_base64(images)
197
 
198
- print("Initial state ready!")
199
- self._start_generation_thread()
 
200
 
201
- def _warmup_compile(self, num_steps: int = 3):
202
- """Run warmup passes to trigger torch.compile before first real request."""
203
- print(f"Warming up compiled model ({num_steps} steps)...")
204
- self.is_warming_up = True
205
 
206
- try:
207
- # Create a temporary state for warmup
208
- seed_image = load_seed_frame(SEED_FRAME_URLS[0])
 
 
 
 
 
209
 
210
- warmup_state = self.pipe(
211
- prompt="warmup",
212
- image=seed_image,
213
- button=set(),
214
- mouse=(0.0, 0.0),
215
- output_type="pil",
216
- )
217
 
218
- # Run a few more steps to ensure all kernels are compiled
219
- for i in range(num_steps - 1):
220
- warmup_state = self.pipe(
221
- warmup_state,
222
- prompt="warmup",
 
 
 
 
 
 
 
 
 
 
 
223
  button=set(),
224
  mouse=(0.0, 0.0),
 
 
 
 
 
 
 
 
 
 
 
 
 
225
  image=None,
226
  output_type="pil",
227
  )
228
- print(f" Warmup step {i + 2}/{num_steps} complete")
229
-
230
- print("Warmup complete! Compiled kernels cached.")
231
-
232
- except Exception as e:
233
- print(f"Warmup error (non-fatal): {e}")
234
-
235
- self.is_warming_up = False
236
-
237
- def _start_generation_thread(self):
238
- if self._generation_thread is not None and self._generation_thread.is_alive():
239
- return
240
-
241
- self._running = True
242
- self._generation_thread = threading.Thread(
243
- target=self._generation_loop, daemon=True
244
- )
245
- self._generation_thread.start()
246
- print("Background generation thread started!")
247
-
248
- def _generation_loop(self):
249
- while self._running:
250
- if not self.is_initialized or self.state is None or self.is_warming_up:
251
- time.sleep(0.01)
252
- continue
253
-
254
- buttons, mouse = self.control_state.get()
255
-
256
- with self._state_lock:
257
- try:
258
- self.state = self.pipe(
259
- self.state,
260
- prompt=self.prompt,
261
- button=buttons,
262
- mouse=mouse,
263
- image=None,
264
  output_type="pil",
265
  )
266
- self.frame_count += 1
267
-
268
- images = self.state.values.get("images")
269
- if images is not None:
270
- with self._frame_lock:
271
- self._latest_frame = images
272
- self._latest_frame_b64 = image_to_base64(images)
273
-
274
- if self.frame_count >= self.n_frames - 2:
275
- print(f"Auto-reset at frame {self.frame_count}")
276
- self._do_reset()
277
-
278
- except Exception as e:
279
- import traceback
280
-
281
- print(f"Generation error: {e}")
282
- traceback.print_exc()
283
- time.sleep(0.5)
284
-
285
- def _do_reset(self, seed_url: str = None, seed_image: Image.Image = None):
286
- self.frame_count = 0
287
-
288
- # Use provided image, or load from URL, or pick random
289
- if seed_image is not None:
290
- # Resize uploaded image to target size
291
- target_size = (360, 640) # (H, W)
292
- seed_image = seed_image.resize(
293
- (target_size[1], target_size[0]), Image.BILINEAR
294
- )
295
- else:
296
- url = seed_url or random.choice(SEED_FRAME_URLS)
297
- seed_image = load_seed_frame(url)
298
-
299
- buttons, mouse = self.control_state.get()
300
- self.state = self.pipe(
301
- prompt=self.prompt,
302
- image=seed_image,
303
- button=buttons,
304
- mouse=mouse,
305
- output_type="pil",
306
- )
307
- self.frame_count = 1
308
-
309
- images = self.state.values.get("images")
310
- if images is not None:
311
- with self._frame_lock:
312
- self._latest_frame = images
313
- self._latest_frame_b64 = image_to_base64(images)
314
-
315
- def reset(self, seed_url: str = None, seed_image: Image.Image = None):
316
- with self._state_lock:
317
- self._do_reset(seed_url=seed_url, seed_image=seed_image)
318
-
319
- def get_latest_frame(self) -> Tuple[Optional[str], int]:
320
- with self._frame_lock:
321
- return self._latest_frame_b64, self.frame_count
322
 
323
- def update_controls(self, buttons: Set[int], mouse: Tuple[float, float]):
324
- self.control_state.update(buttons, mouse)
325
 
326
- def stop(self):
327
- self._running = False
328
- if self._generation_thread is not None:
329
- self._generation_thread.join(timeout=2.0)
330
 
 
 
331
 
332
- game_state = GameState()
333
 
334
-
335
- # --- Control Input HTML Component ---
336
  CONTROL_INPUT_HTML = """
337
  <div id="control-input-wrapper" style="width: 100%; background: #0a0a0f; border-radius: 12px; overflow: hidden; font-family: 'JetBrains Mono', 'Fira Code', monospace;">
338
  <div style="padding: 16px; background: linear-gradient(180deg, rgba(22, 27, 34, 0.95) 0%, rgba(13, 17, 23, 0.98) 100%);">
339
- <!-- Status Bar -->
340
  <div style="display: flex; justify-content: space-between; align-items: center; margin-bottom: 16px;">
341
  <div id="status-indicator" style="display: flex; align-items: center; gap: 6px; background: rgba(0,0,0,0.4); padding: 6px 12px; border-radius: 20px; border: 1px solid rgba(88, 166, 255, 0.2);">
342
  <div id="status-dot" style="width: 8px; height: 8px; border-radius: 50%; background: #ff6b6b; box-shadow: 0 0 8px #ff6b6b;"></div>
343
  <span id="status-text" style="font-size: 11px; color: #8b949e; text-transform: uppercase; letter-spacing: 1px;">Not capturing</span>
344
  </div>
345
  </div>
346
-
347
  <div style="display: grid; grid-template-columns: 1fr 1fr; gap: 16px;">
348
- <!-- Keyboard State (touch-enabled) -->
349
  <div style="background: rgba(0,0,0,0.3); border-radius: 12px; padding: 12px; border: 1px solid rgba(88, 166, 255, 0.1);">
350
  <div style="font-size: 10px; color: #8b949e; text-transform: uppercase; letter-spacing: 1px; margin-bottom: 10px;">Movement</div>
351
  <div style="display: grid; grid-template-columns: repeat(3, 1fr); gap: 4px; max-width: 120px; margin: 0 auto;">
352
  <div></div>
353
- <div id="key-w" data-key="KeyW" style="aspect-ratio: 1; background: rgba(88, 166, 255, 0.1); border: 1px solid rgba(88, 166, 255, 0.3); border-radius: 6px; display: flex; align-items: center; justify-content: center; font-size: 11px; color: #58a6ff; transition: all 0.1s; cursor: pointer; user-select: none; -webkit-user-select: none; touch-action: manipulation;">W</div>
354
  <div></div>
355
- <div id="key-a" data-key="KeyA" style="aspect-ratio: 1; background: rgba(88, 166, 255, 0.1); border: 1px solid rgba(88, 166, 255, 0.3); border-radius: 6px; display: flex; align-items: center; justify-content: center; font-size: 11px; color: #58a6ff; transition: all 0.1s; cursor: pointer; user-select: none; -webkit-user-select: none; touch-action: manipulation;">A</div>
356
- <div id="key-s" data-key="KeyS" style="aspect-ratio: 1; background: rgba(88, 166, 255, 0.1); border: 1px solid rgba(88, 166, 255, 0.3); border-radius: 6px; display: flex; align-items: center; justify-content: center; font-size: 11px; color: #58a6ff; transition: all 0.1s; cursor: pointer; user-select: none; -webkit-user-select: none; touch-action: manipulation;">S</div>
357
- <div id="key-d" data-key="KeyD" style="aspect-ratio: 1; background: rgba(88, 166, 255, 0.1); border: 1px solid rgba(88, 166, 255, 0.3); border-radius: 6px; display: flex; align-items: center; justify-content: center; font-size: 11px; color: #58a6ff; transition: all 0.1s; cursor: pointer; user-select: none; -webkit-user-select: none; touch-action: manipulation;">D</div>
358
  </div>
359
  <div style="display: flex; gap: 4px; margin-top: 8px; justify-content: center;">
360
- <div id="key-shift" data-key="ShiftLeft" style="padding: 4px 8px; background: rgba(88, 166, 255, 0.1); border: 1px solid rgba(88, 166, 255, 0.3); border-radius: 4px; font-size: 9px; color: #58a6ff; transition: all 0.1s; cursor: pointer; user-select: none; -webkit-user-select: none; touch-action: manipulation;">SHIFT</div>
361
- <div id="key-space" data-key="Space" style="padding: 4px 16px; background: rgba(88, 166, 255, 0.1); border: 1px solid rgba(88, 166, 255, 0.3); border-radius: 4px; font-size: 9px; color: #58a6ff; transition: all 0.1s; cursor: pointer; user-select: none; -webkit-user-select: none; touch-action: manipulation;">SPACE</div>
362
- <div id="key-e" data-key="KeyE" style="padding: 4px 8px; background: rgba(88, 166, 255, 0.1); border: 1px solid rgba(88, 166, 255, 0.3); border-radius: 4px; font-size: 9px; color: #58a6ff; transition: all 0.1s; cursor: pointer; user-select: none; -webkit-user-select: none; touch-action: manipulation;">E</div>
363
  </div>
364
  </div>
365
-
366
- <!-- Mouse/Look Joystick (touch-enabled) -->
367
  <div style="background: rgba(0,0,0,0.3); border-radius: 12px; padding: 12px; border: 1px solid rgba(88, 166, 255, 0.1);">
368
  <div style="font-size: 10px; color: #8b949e; text-transform: uppercase; letter-spacing: 1px; margin-bottom: 10px;">Look (drag)</div>
369
  <div style="display: flex; align-items: center; justify-content: center; gap: 16px;">
370
  <div id="mouse-joystick" style="width: 100px; height: 100px; background: rgba(88, 166, 255, 0.05); border: 2px solid rgba(88, 166, 255, 0.3); border-radius: 50%; position: relative; cursor: pointer; touch-action: none;">
371
- <div id="mouse-dot" style="width: 24px; height: 24px; background: #58a6ff; border-radius: 50%; position: absolute; top: 50%; left: 50%; transform: translate(-50%, -50%); box-shadow: 0 0 12px rgba(88, 166, 255, 0.6); transition: all 0.05s ease-out; pointer-events: none;"></div>
372
- <div style="position: absolute; top: 50%; left: 0; right: 0; height: 1px; background: rgba(88, 166, 255, 0.2); pointer-events: none;"></div>
373
- <div style="position: absolute; left: 50%; top: 0; bottom: 0; width: 1px; background: rgba(88, 166, 255, 0.2); pointer-events: none;"></div>
374
  </div>
375
  <div style="text-align: left;">
376
  <div style="font-size: 10px; color: #8b949e; margin-bottom: 4px;">Velocity</div>
@@ -380,7 +293,6 @@ CONTROL_INPUT_HTML = """
380
  </div>
381
  </div>
382
  </div>
383
-
384
  <div style="margin-top: 12px; background: rgba(0,0,0,0.3); border-radius: 8px; padding: 8px 12px; border: 1px solid rgba(88, 166, 255, 0.1);">
385
  <div style="display: flex; align-items: center; gap: 8px;">
386
  <span style="font-size: 10px; color: #8b949e; text-transform: uppercase; letter-spacing: 1px;">Active:</span>
@@ -404,26 +316,14 @@ CONTROL_INPUT_JS = """
404
  const mouseJoystick = element.querySelector('#mouse-joystick');
405
 
406
  let isCapturing = false;
407
- let isTouchActive = false;
408
  let pressedKeys = new Set();
409
- let touchPressedKeys = new Set();
410
  let mouseVelocity = { x: 0, y: 0 };
411
  let lastMouseMove = Date.now();
412
 
413
- // Windows Virtual Key codes
414
  const BUTTON_MAP = {
415
  'KeyW': 87, 'KeyA': 65, 'KeyS': 83, 'KeyD': 68,
416
  'KeyQ': 81, 'KeyE': 69, 'KeyR': 82, 'KeyF': 70,
417
- 'KeyT': 84, 'KeyG': 71, 'KeyZ': 90, 'KeyX': 88,
418
- 'KeyC': 67, 'KeyV': 86, 'KeyB': 66, 'KeyN': 78,
419
- 'KeyM': 77, 'KeyH': 72, 'KeyJ': 74, 'KeyK': 75,
420
- 'KeyL': 76, 'KeyI': 73, 'KeyO': 79, 'KeyP': 80,
421
- 'KeyU': 85, 'KeyY': 89,
422
- 'Digit1': 49, 'Digit2': 50, 'Digit3': 51, 'Digit4': 52,
423
- 'Digit5': 53, 'Digit6': 54, 'Digit7': 55, 'Digit8': 56,
424
- 'Digit9': 57, 'Digit0': 48,
425
- 'Space': 32,
426
- 'ShiftLeft': 16, 'ShiftRight': 16,
427
  };
428
 
429
  const KEY_DISPLAY_MAP = {
@@ -431,8 +331,6 @@ CONTROL_INPUT_JS = """
431
  'ShiftLeft': 'key-shift', 'Space': 'key-space', 'KeyE': 'key-e',
432
  };
433
 
434
- const isTouchDevice = 'ontouchstart' in window || navigator.maxTouchPoints > 0;
435
-
436
  function updateKeyDisplay(code, pressed) {
437
  const elementId = KEY_DISPLAY_MAP[code];
438
  if (elementId) {
@@ -440,7 +338,6 @@ CONTROL_INPUT_JS = """
440
  if (keyEl) {
441
  keyEl.style.background = pressed ? 'rgba(88, 166, 255, 0.4)' : 'rgba(88, 166, 255, 0.1)';
442
  keyEl.style.borderColor = pressed ? '#58a6ff' : 'rgba(88, 166, 255, 0.3)';
443
- keyEl.style.boxShadow = pressed ? '0 0 8px rgba(88, 166, 255, 0.4)' : 'none';
444
  }
445
  }
446
  }
@@ -448,9 +345,8 @@ CONTROL_INPUT_JS = """
448
  function updateMouseDisplay() {
449
  const displayX = Math.max(-1, Math.min(1, mouseVelocity.x / 20));
450
  const displayY = Math.max(-1, Math.min(1, mouseVelocity.y / 20));
451
- const maxOffset = 40;
452
- mouseDot.style.left = (50 + displayX * maxOffset) + '%';
453
- mouseDot.style.top = (50 + displayY * maxOffset) + '%';
454
  mouseXValue.textContent = mouseVelocity.x.toFixed(1);
455
  mouseYValue.textContent = mouseVelocity.y.toFixed(1);
456
  }
@@ -460,18 +356,12 @@ CONTROL_INPUT_JS = """
460
  activeButtonsDisplay.innerHTML = '<span style="font-size: 11px; color: #484f58;">None</span>';
461
  } else {
462
  activeButtonsDisplay.innerHTML = Array.from(pressedKeys)
463
- .map(code => code.replace('Key', '').replace('Digit', '').replace('Left', ''))
464
  .map(name => `<span style="font-size: 10px; background: rgba(88, 166, 255, 0.2); color: #58a6ff; padding: 2px 6px; border-radius: 4px;">${name}</span>`)
465
  .join('');
466
  }
467
  }
468
 
469
- function updateStatus(capturing) {
470
- statusDot.style.background = capturing ? '#3fb950' : '#ff6b6b';
471
- statusDot.style.boxShadow = capturing ? '0 0 8px #3fb950' : '0 0 8px #ff6b6b';
472
- statusText.textContent = capturing ? 'Capturing - ESC to release' : 'Not capturing';
473
- }
474
-
475
  function triggerUpdate() {
476
  const buttonIds = Array.from(pressedKeys)
477
  .filter(code => BUTTON_MAP[code] !== undefined)
@@ -482,7 +372,8 @@ CONTROL_INPUT_JS = """
482
 
483
  document.addEventListener('pointerlockchange', () => {
484
  isCapturing = document.pointerLockElement !== null;
485
- updateStatus(isCapturing);
 
486
  if (!isCapturing) {
487
  pressedKeys.clear();
488
  mouseVelocity = { x: 0, y: 0 };
@@ -515,19 +406,17 @@ CONTROL_INPUT_JS = """
515
  }
516
  });
517
 
518
- const MOUSE_SENSITIVITY = 1.5;
519
-
520
  document.addEventListener('mousemove', (e) => {
521
  if (!isCapturing) return;
522
- mouseVelocity.x = e.movementX * MOUSE_SENSITIVITY;
523
- mouseVelocity.y = e.movementY * MOUSE_SENSITIVITY;
524
  updateMouseDisplay();
525
  triggerUpdate();
526
  lastMouseMove = Date.now();
527
  });
528
 
529
  setInterval(() => {
530
- if ((isCapturing || isTouchActive) && Date.now() - lastMouseMove > 50) {
531
  mouseVelocity.x *= 0.8;
532
  mouseVelocity.y *= 0.8;
533
  if (Math.abs(mouseVelocity.x) < 0.01) mouseVelocity.x = 0;
@@ -537,203 +426,37 @@ CONTROL_INPUT_JS = """
537
  }
538
  }, 100);
539
 
540
- // Touch controls
541
- const touchableKeys = element.querySelectorAll('[data-key]');
542
-
543
- touchableKeys.forEach(keyEl => {
544
- const keyCode = keyEl.dataset.key;
545
-
546
- const handleTouchStart = (e) => {
547
- e.preventDefault();
548
- isTouchActive = true;
549
- if (!pressedKeys.has(keyCode)) {
550
- pressedKeys.add(keyCode);
551
- touchPressedKeys.add(keyCode);
552
- updateKeyDisplay(keyCode, true);
553
- updateActiveButtonsDisplay();
554
- triggerUpdate();
555
- }
556
- };
557
-
558
- const handleTouchEnd = (e) => {
559
- e.preventDefault();
560
- if (touchPressedKeys.has(keyCode)) {
561
- pressedKeys.delete(keyCode);
562
- touchPressedKeys.delete(keyCode);
563
- updateKeyDisplay(keyCode, false);
564
- updateActiveButtonsDisplay();
565
- triggerUpdate();
566
- }
567
- };
568
-
569
- keyEl.addEventListener('touchstart', handleTouchStart, { passive: false });
570
- keyEl.addEventListener('touchend', handleTouchEnd, { passive: false });
571
- keyEl.addEventListener('touchcancel', handleTouchEnd, { passive: false });
572
-
573
- keyEl.addEventListener('mousedown', (e) => {
574
- e.preventDefault();
575
- isTouchActive = true;
576
- if (!pressedKeys.has(keyCode)) {
577
- pressedKeys.add(keyCode);
578
- touchPressedKeys.add(keyCode);
579
- updateKeyDisplay(keyCode, true);
580
- updateActiveButtonsDisplay();
581
- triggerUpdate();
582
- }
583
- });
584
-
585
- keyEl.addEventListener('mouseup', (e) => {
586
- if (touchPressedKeys.has(keyCode)) {
587
- pressedKeys.delete(keyCode);
588
- touchPressedKeys.delete(keyCode);
589
- updateKeyDisplay(keyCode, false);
590
- updateActiveButtonsDisplay();
591
- triggerUpdate();
592
- }
593
- });
594
-
595
- keyEl.addEventListener('mouseleave', (e) => {
596
- if (touchPressedKeys.has(keyCode)) {
597
- pressedKeys.delete(keyCode);
598
- touchPressedKeys.delete(keyCode);
599
- updateKeyDisplay(keyCode, false);
600
- updateActiveButtonsDisplay();
601
- triggerUpdate();
602
- }
603
- });
604
- });
605
-
606
- // Touch joystick
607
  let joystickActive = false;
608
- const JOYSTICK_SENSITIVITY = 0.5;
609
- const JOYSTICK_RADIUS = 50;
610
-
611
  const getJoystickCenter = () => {
612
  const rect = mouseJoystick.getBoundingClientRect();
613
  return { x: rect.left + rect.width / 2, y: rect.top + rect.height / 2 };
614
  };
615
-
616
  const handleJoystickMove = (clientX, clientY) => {
617
  const center = getJoystickCenter();
618
- let dx = clientX - center.x;
619
- let dy = clientY - center.y;
620
- const distance = Math.sqrt(dx * dx + dy * dy);
621
- if (distance > JOYSTICK_RADIUS) {
622
- dx = (dx / distance) * JOYSTICK_RADIUS;
623
- dy = (dy / distance) * JOYSTICK_RADIUS;
624
- }
625
- mouseVelocity.x = (dx / JOYSTICK_RADIUS) * 20 * JOYSTICK_SENSITIVITY;
626
- mouseVelocity.y = (dy / JOYSTICK_RADIUS) * 20 * JOYSTICK_SENSITIVITY;
627
  updateMouseDisplay();
628
  triggerUpdate();
629
- lastMouseMove = Date.now();
630
  };
 
 
 
631
 
632
- mouseJoystick.addEventListener('touchstart', (e) => {
633
- e.preventDefault();
634
- joystickActive = true;
635
- isTouchActive = true;
636
- handleJoystickMove(e.touches[0].clientX, e.touches[0].clientY);
637
- }, { passive: false });
638
-
639
- mouseJoystick.addEventListener('touchmove', (e) => {
640
- e.preventDefault();
641
- if (joystickActive) handleJoystickMove(e.touches[0].clientX, e.touches[0].clientY);
642
- }, { passive: false });
643
-
644
- mouseJoystick.addEventListener('touchend', (e) => {
645
- e.preventDefault();
646
- joystickActive = false;
647
- mouseVelocity = { x: 0, y: 0 };
648
- updateMouseDisplay();
649
- triggerUpdate();
650
- }, { passive: false });
651
-
652
- mouseJoystick.addEventListener('touchcancel', (e) => {
653
- joystickActive = false;
654
- mouseVelocity = { x: 0, y: 0 };
655
- updateMouseDisplay();
656
- triggerUpdate();
657
- }, { passive: false });
658
-
659
- mouseJoystick.addEventListener('mousedown', (e) => {
660
- e.preventDefault();
661
- joystickActive = true;
662
- isTouchActive = true;
663
- handleJoystickMove(e.clientX, e.clientY);
664
- });
665
-
666
- document.addEventListener('mousemove', (e) => {
667
- if (joystickActive) handleJoystickMove(e.clientX, e.clientY);
668
- });
669
-
670
- document.addEventListener('mouseup', () => {
671
- if (joystickActive) {
672
- joystickActive = false;
673
- mouseVelocity = { x: 0, y: 0 };
674
- updateMouseDisplay();
675
- triggerUpdate();
676
- }
677
  });
678
 
679
- if (isTouchDevice) {
680
- statusText.textContent = 'Touch controls ready';
681
- statusDot.style.background = '#3fb950';
682
- statusDot.style.boxShadow = '0 0 8px #3fb950';
683
- } else {
684
- updateStatus(false);
685
- }
686
-
687
  updateMouseDisplay();
688
  })();
689
  """
690
 
691
- CAPTURE_JS = """
692
- () => {
693
- const isTouchDevice = 'ontouchstart' in window || navigator.maxTouchPoints > 0;
694
- if (isTouchDevice) return;
695
-
696
- const insertButton = () => {
697
- const controlPanel = document.querySelector('#control-input-wrapper');
698
- if (!controlPanel) { setTimeout(insertButton, 100); return; }
699
-
700
- const btn = document.createElement('button');
701
- btn.id = 'capture-btn';
702
- btn.textContent = '🎮 Click to Capture Controls';
703
- btn.style.cssText = `
704
- width: 100%; padding: 12px 24px; margin-bottom: 12px;
705
- font-size: 14px; font-weight: bold;
706
- background: linear-gradient(135deg, #58a6ff 0%, #a371f7 100%);
707
- color: white; border: none; border-radius: 8px; cursor: pointer;
708
- box-shadow: 0 4px 12px rgba(0,0,0,0.3); transition: all 0.2s;
709
- `;
710
- btn.onmouseenter = () => btn.style.transform = 'scale(1.02)';
711
- btn.onmouseleave = () => btn.style.transform = 'scale(1)';
712
- btn.onclick = async () => {
713
- try { await document.body.requestPointerLock(); }
714
- catch (e) { console.log('Pointer lock failed:', e); }
715
- };
716
- controlPanel.parentNode.insertBefore(btn, controlPanel);
717
- };
718
-
719
- insertButton();
720
-
721
- document.addEventListener('pointerlockchange', () => {
722
- const btn = document.querySelector('#capture-btn');
723
- if (!btn) return;
724
- if (document.pointerLockElement) {
725
- btn.textContent = '🔒 Capturing (ESC to release)';
726
- btn.style.background = 'linear-gradient(135deg, #3fb950 0%, #2ea043 100%)';
727
- } else {
728
- btn.textContent = '🎮 Click to Capture Controls';
729
- btn.style.background = 'linear-gradient(135deg, #58a6ff 0%, #a371f7 100%)';
730
- }
731
- });
732
- }
733
- """
734
-
735
-
736
- # --- Gradio App ---
737
  css = """
738
  #col-container { max-width: 1200px; margin: 0 auto; }
739
  #video-output { aspect-ratio: 16/9; max-width: 640px; }
@@ -743,12 +466,23 @@ css = """
743
 
744
  def create_app():
745
  with gr.Blocks(css=css, theme=gr.themes.Soft(), title="WorldEngine") as demo:
 
 
 
 
 
 
 
 
 
 
 
746
  gr.Markdown("""
747
  # 🌍 WorldEngine — Real-Time World Model
748
 
749
- Interactive frame-by-frame world generation.
750
 
751
- **Controls:** WASD to move Mouse to look • Space to jump Shift to sprintE to interact • ESC to release
752
  """)
753
 
754
  with gr.Row():
@@ -761,7 +495,10 @@ def create_app():
761
  height=360,
762
  show_label=False,
763
  )
764
- frame_display = gr.Number(label="Frame", value=0, interactive=False)
 
 
 
765
 
766
  with gr.Column(scale=1):
767
  control_input = gr.HTML(
@@ -777,83 +514,166 @@ def create_app():
777
  )
778
 
779
  with gr.Accordion("Seed Selection", open=False):
780
- seed_image_upload = gr.Image(
781
- label="Upload Seed Image (optional)",
782
- type="pil",
783
- sources=["upload"],
784
- height=150,
785
- )
786
- gr.Markdown("*Or select a preset seed:*", elem_classes=["text-sm"])
787
  seed_dropdown = gr.Dropdown(
788
- choices=["Random"]
789
- + [f"Seed {i + 1}" for i in range(len(SEED_FRAME_URLS))],
790
  value="Random",
791
  label="Preset Seeds",
792
  )
793
- reset_btn = gr.Button("Reset World", variant="primary")
794
-
795
- # Event handlers
796
- def on_controller_change(controller_value):
797
- if not controller_value or not game_state.is_initialized:
798
- return
799
- buttons = set(controller_value.get("buttons", []))
800
- mouse_x = controller_value.get("mouse_x", 0.0)
801
- mouse_y = controller_value.get("mouse_y", 0.0)
802
- game_state.update_controls(buttons, (mouse_x, mouse_y))
803
-
804
- def on_prompt_change(prompt):
805
- game_state.prompt = prompt
806
-
807
- def on_reset(uploaded_image, seed_choice):
808
- if uploaded_image is not None:
809
- game_state.reset(seed_image=uploaded_image)
810
- elif seed_choice and seed_choice != "Random":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
811
  idx = int(seed_choice.split()[-1]) - 1
812
- game_state.reset(seed_url=SEED_FRAME_URLS[idx])
813
- else:
814
- game_state.reset()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
815
 
816
- def get_frame():
817
- with game_state._frame_lock:
818
- return game_state._latest_frame, game_state.frame_count
 
 
819
 
820
- control_input.change(fn=on_controller_change, inputs=[control_input])
821
- prompt_input.change(fn=on_prompt_change, inputs=[prompt_input])
822
- reset_btn.click(fn=on_reset, inputs=[seed_image_upload, seed_dropdown])
823
 
824
- timer = gr.Timer(value=1 / 30)
825
- timer.tick(fn=get_frame, outputs=[video_output, frame_display])
 
 
 
 
 
826
 
827
- demo.load(fn=None, js=CAPTURE_JS)
 
 
 
 
 
 
 
 
 
 
 
828
 
829
  return demo
830
 
831
 
832
def main():
    """Initialize the global game state and launch the Gradio demo.

    Reads the ``COMPILE`` and ``WARMUP_STEPS`` environment variables, reports
    the effective configuration, initializes ``game_state`` with it and then
    serves the app on port 7860.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    use_compile = os.environ.get("COMPILE", "1") == "1"
    warmup_steps = int(os.environ.get("WARMUP_STEPS", "1"))

    print(f"Model path: {MODEL_ID}")
    print(f"Device: {device}")
    print(f"Compile: {use_compile}")
    print(f"Warmup steps: {warmup_steps}")
    print(f"Cache dir: {CACHE_DIR}")

    # Check if cache exists from previous run
    cache_files = list(CACHE_DIR.glob("**/*.json")) + list(CACHE_DIR.glob("**/*.py"))
    if cache_files:
        print(f"Found {len(cache_files)} cached compilation files")

    game_state.initialize(
        model_path=MODEL_ID,
        device=device,
        use_compile=use_compile,
        warmup_steps=warmup_steps,
    )

    demo = create_app()
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
857
 
858
 
859
  if __name__ == "__main__":
 
2
  WorldEngine Real-Time World Model Demo - ZeroGPU Edition
3
 
4
  A Gradio demo optimized for HuggingFace ZeroGPU with:
5
+ - Generator-based GPU session that stays alive
6
  - Persistent compilation cache for faster cold starts
7
+ - Command queue for real-time control
 
8
  """
9
# Check for ZeroGPU environment - must be before other imports.
# IS_ZERO_GPU records whether the `spaces` package could be imported; code
# that touches `spaces` should check this flag first.
try:
    import spaces
    IS_ZERO_GPU = True
    print("ZeroGPU environment detected")
except ImportError:
    IS_ZERO_GPU = False
    print("Running in standard GPU mode")
17
 
18
  import base64
19
  import os
20
  import random
21
+ from dataclasses import dataclass
 
 
22
  from io import BytesIO
23
+ from multiprocessing import Queue
24
  from pathlib import Path
25
  from typing import Optional, Set, Tuple
26
 
 
33
  from diffusers.utils import load_image
34
 
35
  # --- ZeroGPU Compilation Cache Setup ---
 
36
  CACHE_DIR = Path.home() / ".cache" / "world_engine_compile"
37
  CACHE_DIR.mkdir(parents=True, exist_ok=True)
38
 
39
  os.environ["TORCHINDUCTOR_CACHE_DIR"] = str(CACHE_DIR)
40
  os.environ["TORCH_COMPILE_CACHE_DIR"] = str(CACHE_DIR)
 
 
41
  os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
42
 
43
  print(f"Compilation cache directory: {CACHE_DIR}")
 
46
  torch.set_float32_matmul_precision("medium")
47
  torch._dynamo.config.capture_scalar_outputs = True
48
 
 
 
 
 
 
 
 
 
 
49
  # --- Configuration ---
50
  MODEL_ID = os.environ.get("MODEL_PATH", "diffusers-internal-dev/world-engine-modular")
51
+ USE_COMPILE = os.environ.get("COMPILE", "1") == "1"
52
 
 
53
  SEED_FRAME_URLS = [
54
  "https://gist.github.com/user-attachments/assets/5d91c49a-2ae9-418f-99c0-e93ae387e1de",
55
  "https://gist.github.com/user-attachments/assets/4adc5a3d-6980-4d1e-b6e8-9033cdf61c66",
 
74
  return f"data:image/png;base64,{img_str}"
75
 
76
 
77
+ # --- Command Types ---
78
@dataclass
class GenerateCommand:
    """Generate next frame with given controls."""

    # Button ids currently held down (passed to the pipeline as `button`).
    buttons: Set[int]
    # Mouse input as an (x, y) pair (passed to the pipeline as `mouse`).
    mouse: Tuple[float, float]
    # Text prompt forwarded to the pipeline for this frame.
    prompt: str
84
+
85
+
86
@dataclass
class ResetCommand:
    """Reset world with new seed image.

    When both ``seed_image`` and ``seed_url`` are left as ``None`` the game
    loop picks a random preset seed frame.
    """

    # Explicit seed image; takes precedence over `seed_url` when set.
    seed_image: Optional[Image.Image] = None
    # URL of a seed frame to load when no image is supplied.
    seed_url: Optional[str] = None
    # Prompt used when regenerating the first frame of the new world.
    prompt: str = "An explorable world"
92
+
93
+
94
@dataclass
class StopCommand:
    """Stop the GPU session."""
98
+
99
+
100
+ # --- GPU Session Generator ---
101
def create_gpu_game_loop(command_queue: Queue):
    """Create GPU game loop generator with closure over command_queue.

    Args:
        command_queue: Queue on which callers put ``GenerateCommand``,
            ``ResetCommand`` and ``StopCommand`` objects.

    Returns:
        A generator yielding ``(frame, frame_count)`` tuples: first the
        initial frame, then one tuple per processed generate/reset command.
        The generator finishes once a ``StopCommand`` has been received.
    """
    # Queue.get(timeout=...) raises queue.Empty when nothing arrives in time.
    from queue import Empty

    def _start_world(pipe, prompt, seed_img):
        """Run the pipeline on a seed image with no buttons and no mouse input."""
        return pipe(
            prompt=prompt,
            image=seed_img,
            button=set(),
            mouse=(0.0, 0.0),
            output_type="pil",
        )

    def gpu_game_loop():
        """
        Generator that keeps GPU allocated and processes commands.
        Yields (frame, frame_count) tuples.
        """
        print("Starting GPU session...")

        # Load model
        print(f"Loading ModularPipeline from: {MODEL_ID}")
        pipe = ModularPipeline.from_pretrained(MODEL_ID, trust_remote_code=True)
        pipe.load_components(
            device_map="cuda",
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )

        pipe.transformer.apply_inference_patches()
        pipe.transformer.quantize("fp8")

        # Check if we have cached kernels
        cache_files = list(CACHE_DIR.glob("**/*.so"))
        needs_warmup = USE_COMPILE and not cache_files

        if USE_COMPILE:
            if cache_files:
                print(f"Found {len(cache_files)} cached kernels")
            else:
                print("No cached kernels - first run will compile")

            print("Compiling transformer...")
            pipe.transformer = torch.compile(
                pipe.transformer,
                mode="max-autotune-no-cudagraphs",
                fullgraph=True,
                dynamic=False,
            )

            print("Compiling VAE...")
            pipe.vae = torch.compile(
                pipe.vae,
                mode="max-autotune",
                fullgraph=True,
                dynamic=False,
            )

        n_frames = pipe.transformer.config.n_frames
        print(f"Model loaded! (n_frames={n_frames})")

        # Warmup pass if no cached kernels (triggers compilation)
        if needs_warmup:
            print("Running warmup pass to compile kernels...")
            warmup_image = load_seed_frame(SEED_FRAME_URLS[0])
            _ = _start_world(pipe, "warmup", warmup_image)
            print("Warmup complete! Kernels cached for future runs.")

        # Initialize state from a random preset seed frame
        seed_image = load_seed_frame(random.choice(SEED_FRAME_URLS))
        state = _start_world(pipe, "An explorable world", seed_image)
        frame_count = 1

        # Get initial frame
        frame = state.values.get("images")
        print("Initial frame generated, entering game loop...")

        # Yield initial frame
        yield (frame, frame_count)

        # Main loop - process commands and yield frames
        while True:
            try:
                # Short timeout so the loop polls for commands frequently.
                command = command_queue.get(timeout=0.01)
            except Empty:
                # No command yet; any other queue error is allowed to
                # propagate instead of being silently retried forever.
                continue

            if isinstance(command, StopCommand):
                print("Stop command received, ending GPU session")
                break

            if isinstance(command, ResetCommand):
                print("Reset command received")
                if command.seed_image is not None:
                    seed_img = command.seed_image.resize((640, 360), Image.BILINEAR)
                elif command.seed_url:
                    seed_img = load_seed_frame(command.seed_url)
                else:
                    seed_img = load_seed_frame(random.choice(SEED_FRAME_URLS))

                state = _start_world(pipe, command.prompt, seed_img)
                frame_count = 1
                frame = state.values.get("images")
                yield (frame, frame_count)

            elif isinstance(command, GenerateCommand):
                # Generate next frame, continuing from the previous state
                state = pipe(
                    state,
                    prompt=command.prompt,
                    button=command.buttons,
                    mouse=command.mouse,
                    image=None,
                    output_type="pil",
                )
                frame_count += 1
                frame = state.values.get("images")

                # Auto-reset near end of context
                if frame_count >= n_frames - 2:
                    print(f"Auto-reset at frame {frame_count}")
                    seed_img = load_seed_frame(random.choice(SEED_FRAME_URLS))
                    state = _start_world(pipe, command.prompt, seed_img)
                    frame_count = 1
                    frame = state.values.get("images")

                yield (frame, frame_count)

        print("GPU session ended")

    # `spaces` is only defined when its import succeeded (IS_ZERO_GPU), so the
    # decorator is applied conditionally to avoid a NameError in standard mode.
    if IS_ZERO_GPU:
        gpu_game_loop = spaces.GPU(duration=120)(gpu_game_loop)

    # Return the generator
    return gpu_game_loop()
253
 
 
254
 
255
+ # --- Gradio App ---
 
256
  CONTROL_INPUT_HTML = """
257
  <div id="control-input-wrapper" style="width: 100%; background: #0a0a0f; border-radius: 12px; overflow: hidden; font-family: 'JetBrains Mono', 'Fira Code', monospace;">
258
  <div style="padding: 16px; background: linear-gradient(180deg, rgba(22, 27, 34, 0.95) 0%, rgba(13, 17, 23, 0.98) 100%);">
 
259
  <div style="display: flex; justify-content: space-between; align-items: center; margin-bottom: 16px;">
260
  <div id="status-indicator" style="display: flex; align-items: center; gap: 6px; background: rgba(0,0,0,0.4); padding: 6px 12px; border-radius: 20px; border: 1px solid rgba(88, 166, 255, 0.2);">
261
  <div id="status-dot" style="width: 8px; height: 8px; border-radius: 50%; background: #ff6b6b; box-shadow: 0 0 8px #ff6b6b;"></div>
262
  <span id="status-text" style="font-size: 11px; color: #8b949e; text-transform: uppercase; letter-spacing: 1px;">Not capturing</span>
263
  </div>
264
  </div>
 
265
  <div style="display: grid; grid-template-columns: 1fr 1fr; gap: 16px;">
 
266
  <div style="background: rgba(0,0,0,0.3); border-radius: 12px; padding: 12px; border: 1px solid rgba(88, 166, 255, 0.1);">
267
  <div style="font-size: 10px; color: #8b949e; text-transform: uppercase; letter-spacing: 1px; margin-bottom: 10px;">Movement</div>
268
  <div style="display: grid; grid-template-columns: repeat(3, 1fr); gap: 4px; max-width: 120px; margin: 0 auto;">
269
  <div></div>
270
+ <div id="key-w" data-key="KeyW" style="aspect-ratio: 1; background: rgba(88, 166, 255, 0.1); border: 1px solid rgba(88, 166, 255, 0.3); border-radius: 6px; display: flex; align-items: center; justify-content: center; font-size: 11px; color: #58a6ff; transition: all 0.1s; cursor: pointer; user-select: none;">W</div>
271
  <div></div>
272
+ <div id="key-a" data-key="KeyA" style="aspect-ratio: 1; background: rgba(88, 166, 255, 0.1); border: 1px solid rgba(88, 166, 255, 0.3); border-radius: 6px; display: flex; align-items: center; justify-content: center; font-size: 11px; color: #58a6ff; transition: all 0.1s; cursor: pointer; user-select: none;">A</div>
273
+ <div id="key-s" data-key="KeyS" style="aspect-ratio: 1; background: rgba(88, 166, 255, 0.1); border: 1px solid rgba(88, 166, 255, 0.3); border-radius: 6px; display: flex; align-items: center; justify-content: center; font-size: 11px; color: #58a6ff; transition: all 0.1s; cursor: pointer; user-select: none;">S</div>
274
+ <div id="key-d" data-key="KeyD" style="aspect-ratio: 1; background: rgba(88, 166, 255, 0.1); border: 1px solid rgba(88, 166, 255, 0.3); border-radius: 6px; display: flex; align-items: center; justify-content: center; font-size: 11px; color: #58a6ff; transition: all 0.1s; cursor: pointer; user-select: none;">D</div>
275
  </div>
276
  <div style="display: flex; gap: 4px; margin-top: 8px; justify-content: center;">
277
+ <div id="key-shift" data-key="ShiftLeft" style="padding: 4px 8px; background: rgba(88, 166, 255, 0.1); border: 1px solid rgba(88, 166, 255, 0.3); border-radius: 4px; font-size: 9px; color: #58a6ff; cursor: pointer; user-select: none;">SHIFT</div>
278
+ <div id="key-space" data-key="Space" style="padding: 4px 16px; background: rgba(88, 166, 255, 0.1); border: 1px solid rgba(88, 166, 255, 0.3); border-radius: 4px; font-size: 9px; color: #58a6ff; cursor: pointer; user-select: none;">SPACE</div>
279
+ <div id="key-e" data-key="KeyE" style="padding: 4px 8px; background: rgba(88, 166, 255, 0.1); border: 1px solid rgba(88, 166, 255, 0.3); border-radius: 4px; font-size: 9px; color: #58a6ff; cursor: pointer; user-select: none;">E</div>
280
  </div>
281
  </div>
 
 
282
  <div style="background: rgba(0,0,0,0.3); border-radius: 12px; padding: 12px; border: 1px solid rgba(88, 166, 255, 0.1);">
283
  <div style="font-size: 10px; color: #8b949e; text-transform: uppercase; letter-spacing: 1px; margin-bottom: 10px;">Look (drag)</div>
284
  <div style="display: flex; align-items: center; justify-content: center; gap: 16px;">
285
  <div id="mouse-joystick" style="width: 100px; height: 100px; background: rgba(88, 166, 255, 0.05); border: 2px solid rgba(88, 166, 255, 0.3); border-radius: 50%; position: relative; cursor: pointer; touch-action: none;">
286
+ <div id="mouse-dot" style="width: 24px; height: 24px; background: #58a6ff; border-radius: 50%; position: absolute; top: 50%; left: 50%; transform: translate(-50%, -50%); box-shadow: 0 0 12px rgba(88, 166, 255, 0.6); pointer-events: none;"></div>
 
 
287
  </div>
288
  <div style="text-align: left;">
289
  <div style="font-size: 10px; color: #8b949e; margin-bottom: 4px;">Velocity</div>
 
293
  </div>
294
  </div>
295
  </div>
 
296
  <div style="margin-top: 12px; background: rgba(0,0,0,0.3); border-radius: 8px; padding: 8px 12px; border: 1px solid rgba(88, 166, 255, 0.1);">
297
  <div style="display: flex; align-items: center; gap: 8px;">
298
  <span style="font-size: 10px; color: #8b949e; text-transform: uppercase; letter-spacing: 1px;">Active:</span>
 
316
  const mouseJoystick = element.querySelector('#mouse-joystick');
317
 
318
  let isCapturing = false;
 
319
  let pressedKeys = new Set();
 
320
  let mouseVelocity = { x: 0, y: 0 };
321
  let lastMouseMove = Date.now();
322
 
 
323
  const BUTTON_MAP = {
324
  'KeyW': 87, 'KeyA': 65, 'KeyS': 83, 'KeyD': 68,
325
  'KeyQ': 81, 'KeyE': 69, 'KeyR': 82, 'KeyF': 70,
326
+ 'Space': 32, 'ShiftLeft': 16, 'ShiftRight': 16,
 
 
 
 
 
 
 
 
 
327
  };
328
 
329
  const KEY_DISPLAY_MAP = {
 
331
  'ShiftLeft': 'key-shift', 'Space': 'key-space', 'KeyE': 'key-e',
332
  };
333
 
 
 
334
  function updateKeyDisplay(code, pressed) {
335
  const elementId = KEY_DISPLAY_MAP[code];
336
  if (elementId) {
 
338
  if (keyEl) {
339
  keyEl.style.background = pressed ? 'rgba(88, 166, 255, 0.4)' : 'rgba(88, 166, 255, 0.1)';
340
  keyEl.style.borderColor = pressed ? '#58a6ff' : 'rgba(88, 166, 255, 0.3)';
 
341
  }
342
  }
343
  }
 
345
  function updateMouseDisplay() {
346
  const displayX = Math.max(-1, Math.min(1, mouseVelocity.x / 20));
347
  const displayY = Math.max(-1, Math.min(1, mouseVelocity.y / 20));
348
+ mouseDot.style.left = (50 + displayX * 40) + '%';
349
+ mouseDot.style.top = (50 + displayY * 40) + '%';
 
350
  mouseXValue.textContent = mouseVelocity.x.toFixed(1);
351
  mouseYValue.textContent = mouseVelocity.y.toFixed(1);
352
  }
 
356
  activeButtonsDisplay.innerHTML = '<span style="font-size: 11px; color: #484f58;">None</span>';
357
  } else {
358
  activeButtonsDisplay.innerHTML = Array.from(pressedKeys)
359
+ .map(code => code.replace('Key', '').replace('Left', ''))
360
  .map(name => `<span style="font-size: 10px; background: rgba(88, 166, 255, 0.2); color: #58a6ff; padding: 2px 6px; border-radius: 4px;">${name}</span>`)
361
  .join('');
362
  }
363
  }
364
 
 
 
 
 
 
 
365
  function triggerUpdate() {
366
  const buttonIds = Array.from(pressedKeys)
367
  .filter(code => BUTTON_MAP[code] !== undefined)
 
372
 
373
  document.addEventListener('pointerlockchange', () => {
374
  isCapturing = document.pointerLockElement !== null;
375
+ statusDot.style.background = isCapturing ? '#3fb950' : '#ff6b6b';
376
+ statusText.textContent = isCapturing ? 'Capturing - ESC to release' : 'Not capturing';
377
  if (!isCapturing) {
378
  pressedKeys.clear();
379
  mouseVelocity = { x: 0, y: 0 };
 
406
  }
407
  });
408
 
 
 
409
  document.addEventListener('mousemove', (e) => {
410
  if (!isCapturing) return;
411
+ mouseVelocity.x = e.movementX * 1.5;
412
+ mouseVelocity.y = e.movementY * 1.5;
413
  updateMouseDisplay();
414
  triggerUpdate();
415
  lastMouseMove = Date.now();
416
  });
417
 
418
  setInterval(() => {
419
+ if (isCapturing && Date.now() - lastMouseMove > 50) {
420
  mouseVelocity.x *= 0.8;
421
  mouseVelocity.y *= 0.8;
422
  if (Math.abs(mouseVelocity.x) < 0.01) mouseVelocity.x = 0;
 
426
  }
427
  }, 100);
428
 
429
+ // Touch/click controls for joystick
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
430
  let joystickActive = false;
 
 
 
431
  const getJoystickCenter = () => {
432
  const rect = mouseJoystick.getBoundingClientRect();
433
  return { x: rect.left + rect.width / 2, y: rect.top + rect.height / 2 };
434
  };
 
435
  const handleJoystickMove = (clientX, clientY) => {
436
  const center = getJoystickCenter();
437
+ let dx = clientX - center.x, dy = clientY - center.y;
438
+ const dist = Math.sqrt(dx*dx + dy*dy);
439
+ if (dist > 50) { dx = dx/dist*50; dy = dy/dist*50; }
440
+ mouseVelocity.x = (dx/50) * 10;
441
+ mouseVelocity.y = (dy/50) * 10;
 
 
 
 
442
  updateMouseDisplay();
443
  triggerUpdate();
 
444
  };
445
+ mouseJoystick.addEventListener('mousedown', (e) => { joystickActive = true; handleJoystickMove(e.clientX, e.clientY); });
446
+ document.addEventListener('mousemove', (e) => { if (joystickActive) handleJoystickMove(e.clientX, e.clientY); });
447
+ document.addEventListener('mouseup', () => { if (joystickActive) { joystickActive = false; mouseVelocity = {x:0,y:0}; updateMouseDisplay(); triggerUpdate(); }});
448
 
449
+ // Touch controls for keys
450
+ element.querySelectorAll('[data-key]').forEach(keyEl => {
451
+ const keyCode = keyEl.dataset.key;
452
+ keyEl.addEventListener('touchstart', (e) => { e.preventDefault(); pressedKeys.add(keyCode); updateKeyDisplay(keyCode, true); updateActiveButtonsDisplay(); triggerUpdate(); }, {passive:false});
453
+ keyEl.addEventListener('touchend', (e) => { e.preventDefault(); pressedKeys.delete(keyCode); updateKeyDisplay(keyCode, false); updateActiveButtonsDisplay(); triggerUpdate(); }, {passive:false});
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
454
  });
455
 
 
 
 
 
 
 
 
 
456
  updateMouseDisplay();
457
  })();
458
  """
459
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
460
  css = """
461
  #col-container { max-width: 1200px; margin: 0 auto; }
462
  #video-output { aspect-ratio: 16/9; max-width: 640px; }
 
466
 
467
  def create_app():
468
  with gr.Blocks(css=css, theme=gr.themes.Soft(), title="WorldEngine") as demo:
469
+ # State: (generator, command_queue) or empty tuple
470
+ session_state = gr.State(())
471
+
472
+ # Current controls (updated by JS)
473
+ current_controls = gr.State({"buttons": [], "mouse_x": 0.0, "mouse_y": 0.0})
474
+ current_prompt = gr.State("An explorable world")
475
+
476
+ # Latest frame for display
477
+ latest_frame = gr.State(None)
478
+ latest_frame_count = gr.State(0)
479
+
480
  gr.Markdown("""
481
  # 🌍 WorldEngine — Real-Time World Model
482
 
483
+ Interactive frame-by-frame world generation on ZeroGPU.
484
 
485
+ **Controls:** Click "Start GPU" • Click game view to capture • WASD to move • Mouse to look • ESC to release
486
  """)
487
 
488
  with gr.Row():
 
495
  height=360,
496
  show_label=False,
497
  )
498
+ with gr.Row():
499
+ frame_display = gr.Number(label="Frame", value=0, interactive=False)
500
+ start_btn = gr.Button("🎮 Start Game", variant="primary")
501
+ stop_btn = gr.Button("⏹ End Game", interactive=False)
502
 
503
  with gr.Column(scale=1):
504
  control_input = gr.HTML(
 
514
  )
515
 
516
  with gr.Accordion("Seed Selection", open=False):
 
 
 
 
 
 
 
517
  seed_dropdown = gr.Dropdown(
518
+ choices=["Random"] + [f"Seed {i+1}" for i in range(len(SEED_FRAME_URLS))],
 
519
  value="Random",
520
  label="Preset Seeds",
521
  )
522
+ reset_btn = gr.Button("Reset World", variant="secondary")
523
+
524
+ # --- Event Handlers ---
525
+
526
+ def on_start():
527
+ """Start GPU session - returns generator and queue."""
528
+ command_queue = Queue()
529
+ gen = create_gpu_game_loop(command_queue)
530
+
531
+ # Get initial frame
532
+ frame, frame_count = next(gen)
533
+
534
+ return (
535
+ (gen, command_queue), # session_state
536
+ frame, # latest_frame
537
+ frame_count, # latest_frame_count
538
+ frame, # video_output
539
+ frame_count, # frame_display
540
+ gr.update(interactive=False), # start_btn
541
+ gr.update(interactive=True), # stop_btn
542
+ )
543
+
544
+ def on_stop(state):
545
+ """Stop GPU session."""
546
+ if len(state) == 0:
547
+ return ((), None, 0, None, 0,
548
+ gr.update(interactive=True), gr.update(interactive=False))
549
+
550
+ gen, command_queue = state
551
+ command_queue.put(StopCommand())
552
+
553
+ # Drain generator
554
+ try:
555
+ while True:
556
+ next(gen)
557
+ except StopIteration:
558
+ pass
559
+
560
+ return (
561
+ (),
562
+ None,
563
+ 0,
564
+ None,
565
+ 0,
566
+ gr.update(interactive=True),
567
+ gr.update(interactive=False),
568
+ )
569
+
570
+ def on_generate_tick(state, controls, prompt, current_frame, current_count):
571
+ """Called by timer - send generate command and get next frame."""
572
+ if len(state) == 0:
573
+ return current_frame, current_count, current_frame, current_count
574
+
575
+ gen, command_queue = state
576
+
577
+ # Send generate command
578
+ buttons = set(controls.get("buttons", []))
579
+ mouse = (controls.get("mouse_x", 0.0), controls.get("mouse_y", 0.0))
580
+ command_queue.put(GenerateCommand(buttons=buttons, mouse=mouse, prompt=prompt))
581
+
582
+ # Get next frame
583
+ try:
584
+ frame, frame_count = next(gen)
585
+ return frame, frame_count, frame, frame_count
586
+ except StopIteration:
587
+ return current_frame, current_count, current_frame, current_count
588
+
589
+ def on_reset(state, seed_choice, prompt):
590
+ """Reset world with new seed."""
591
+ if len(state) == 0:
592
+ return None, 0, None, 0
593
+
594
+ gen, command_queue = state
595
+
596
+ seed_url = None
597
+ if seed_choice and seed_choice != "Random":
598
  idx = int(seed_choice.split()[-1]) - 1
599
+ seed_url = SEED_FRAME_URLS[idx]
600
+
601
+ command_queue.put(ResetCommand(seed_url=seed_url, prompt=prompt))
602
+
603
+ try:
604
+ frame, frame_count = next(gen)
605
+ return frame, frame_count, frame, frame_count
606
+ except StopIteration:
607
+ return None, 0, None, 0
608
+
609
+ def on_controller_change(value):
610
+ """Update current controls state."""
611
+ return value or {"buttons": [], "mouse_x": 0.0, "mouse_y": 0.0}
612
+
613
+ def on_prompt_change(value):
614
+ """Update current prompt state."""
615
+ return value
616
+
617
+ # Wire up events
618
+ start_btn.click(
619
+ fn=on_start,
620
+ outputs=[session_state, latest_frame, latest_frame_count,
621
+ video_output, frame_display, start_btn, stop_btn],
622
+ )
623
+
624
+ stop_btn.click(
625
+ fn=on_stop,
626
+ inputs=[session_state],
627
+ outputs=[session_state, latest_frame, latest_frame_count,
628
+ video_output, frame_display, start_btn, stop_btn],
629
+ )
630
 
631
+ reset_btn.click(
632
+ fn=on_reset,
633
+ inputs=[session_state, seed_dropdown, current_prompt],
634
+ outputs=[latest_frame, latest_frame_count, video_output, frame_display],
635
+ )
636
 
637
+ control_input.change(fn=on_controller_change, inputs=[control_input], outputs=[current_controls])
638
+ prompt_input.change(fn=on_prompt_change, inputs=[prompt_input], outputs=[current_prompt])
 
639
 
640
+ # Timer for continuous generation
641
+ timer = gr.Timer(value=1/15) # 15 FPS target
642
+ timer.tick(
643
+ fn=on_generate_tick,
644
+ inputs=[session_state, current_controls, current_prompt, latest_frame, latest_frame_count],
645
+ outputs=[latest_frame, latest_frame_count, video_output, frame_display],
646
+ )
647
 
648
+ # Pointer lock JS
649
+ demo.load(fn=None, js="""
650
+ () => {
651
+ const insertButton = () => {
652
+ const output = document.querySelector('#video-output');
653
+ if (!output) { setTimeout(insertButton, 100); return; }
654
+ output.style.cursor = 'pointer';
655
+ output.onclick = () => document.body.requestPointerLock();
656
+ };
657
+ insertButton();
658
+ }
659
+ """)
660
 
661
  return demo
662
 
663
 
664
# Avoid ZeroGPU "no GPU function" error.
# NOTE(review): presumably ZeroGPU expects at least one spaces.GPU-wrapped
# function to exist at import time; the real GPU function is only created
# later inside create_gpu_game_loop() - confirm against the ZeroGPU docs.
if IS_ZERO_GPU:
    spaces.GPU(lambda: None)
 
 
 
 
 
 
 
667
 
 
 
 
 
668
 
669
def main():
    """Print the effective configuration and launch the Gradio demo.

    The app is served on all interfaces (``0.0.0.0``) at port 7860.
    """
    print(f"Model: {MODEL_ID}")
    print(f"Compile: {USE_COMPILE}")
    print(f"Cache dir: {CACHE_DIR}")
    print(f"ZeroGPU: {IS_ZERO_GPU}")

    demo = create_app()
    demo.launch(server_name="0.0.0.0", server_port=7860)
677
 
678
 
679
  if __name__ == "__main__":