dn6 HF Staff commited on
Commit
2c2d194
·
1 Parent(s): 23bd5e7
Files changed (2) hide show
  1. app.py +857 -4
  2. requirements.txt +19 -0
app.py CHANGED
@@ -1,7 +1,860 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
1
+ """
2
+ WorldEngine Real-Time World Model Demo - ZeroGPU Edition
3
+
4
+ A Gradio demo optimized for HuggingFace ZeroGPU with:
5
+ - Persistent compilation cache for faster cold starts
6
+ - Warmup pass to pre-compile kernels
7
+ - @spaces.GPU decorator support
8
+ """
9
+
10
+ import base64
11
+ import os
12
+ import random
13
+ import threading
14
+ import time
15
+ from dataclasses import dataclass, field
16
+ from io import BytesIO
17
+ from pathlib import Path
18
+ from typing import Optional, Set, Tuple
19
+
20
  import gradio as gr
21
+ import numpy as np
22
+ import torch
23
+ from PIL import Image
24
+
25
+ from diffusers import ModularPipeline
26
+ from diffusers.utils import load_image
27
+
28
+ # --- ZeroGPU Compilation Cache Setup ---
29
+ # Set persistent cache directory BEFORE importing torch.compile
30
+ CACHE_DIR = Path.home() / ".cache" / "world_engine_compile"
31
+ CACHE_DIR.mkdir(parents=True, exist_ok=True)
32
+
33
+ os.environ["TORCHINDUCTOR_CACHE_DIR"] = str(CACHE_DIR)
34
+ os.environ["TORCH_COMPILE_CACHE_DIR"] = str(CACHE_DIR)
35
+
36
+ # Enable FX graph caching for faster recompilation
37
+ os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
38
+
39
+ print(f"Compilation cache directory: {CACHE_DIR}")
40
+
41
+ torch._dynamo.config.recompile_limit = 64
42
+ torch.set_float32_matmul_precision("medium")
43
+ torch._dynamo.config.capture_scalar_outputs = True
44
+
45
+ # Check for ZeroGPU environment
46
+ try:
47
+ import spaces
48
+ IS_ZERO_GPU = True
49
+ print("ZeroGPU environment detected")
50
+ except ImportError:
51
+ IS_ZERO_GPU = False
52
+ print("Running in standard GPU mode")
53
+
54
+ # --- Configuration ---
55
+ MODEL_ID = os.environ.get("MODEL_PATH", "diffusers-internal-dev/world-engine-modular")
56
+ MAX_SEED = np.iinfo(np.int32).max
57
+
58
+ # Seed frame URLs for reset
59
+ SEED_FRAME_URLS = [
60
+ "https://gist.github.com/user-attachments/assets/5d91c49a-2ae9-418f-99c0-e93ae387e1de",
61
+ "https://gist.github.com/user-attachments/assets/4adc5a3d-6980-4d1e-b6e8-9033cdf61c66",
62
+ "https://gist.github.com/user-attachments/assets/ae398747-de4c-4d43-bac4-54fe61ab0ca8",
63
+ "https://gist.github.com/user-attachments/assets/9d7336fa-5cec-4c7d-bb65-eaebac0a6336",
64
+ "https://gist.github.com/user-attachments/assets/55dae2d3-00e3-4d03-bb6c-2c7e0ac70f5f",
65
+ ]
66
+
67
+
68
+ def load_seed_frame(url: str, target_size: Tuple[int, int] = (360, 640)) -> Image.Image:
69
+ """Load and resize seed frame to target size."""
70
+ img = load_image(url)
71
+ img = img.resize((target_size[1], target_size[0]), Image.BILINEAR)
72
+ return img
73
+
74
+
75
+ def image_to_base64(image: Image.Image) -> str:
76
+ """Convert PIL image to base64 data URL."""
77
+ buffered = BytesIO()
78
+ image.save(buffered, format="PNG")
79
+ img_str = base64.b64encode(buffered.getvalue()).decode()
80
+ return f"data:image/png;base64,{img_str}"
81
+
82
+
83
+ # --- Control State ---
84
+ @dataclass
85
+ class ControlState:
86
+ """Thread-safe control state."""
87
+
88
+ buttons: Set[int] = field(default_factory=set)
89
+ mouse: Tuple[float, float] = (0.0, 0.0)
90
+ _lock: threading.Lock = field(default_factory=threading.Lock)
91
+
92
+ def update(self, buttons: Set[int], mouse: Tuple[float, float]):
93
+ with self._lock:
94
+ self.buttons = buttons.copy()
95
+ self.mouse = mouse
96
+
97
+ def get(self) -> Tuple[Set[int], Tuple[float, float]]:
98
+ with self._lock:
99
+ return self.buttons.copy(), self.mouse
100
+
101
+
102
+ # --- Game State with Background Generation ---
103
+ class GameState:
104
+ """Global game state with background frame generation."""
105
+
106
+ def __init__(self):
107
+ self.pipe: Optional[ModularPipeline] = None
108
+ self.state = None
109
+ self.control_state = ControlState()
110
+ self.frame_count = 0
111
+ self.n_frames = None
112
+ self.is_initialized = False
113
+ self.is_warming_up = False
114
+ self.prompt = "An explorable world"
115
+
116
+ # Background generation
117
+ self._generation_thread: Optional[threading.Thread] = None
118
+ self._running = False
119
+ self._state_lock = threading.Lock()
120
+
121
+ # Latest frame buffer
122
+ self._latest_frame: Optional[Image.Image] = None
123
+ self._latest_frame_b64: Optional[str] = None
124
+ self._frame_lock = threading.Lock()
125
+
126
+ @spaces.GPU(duration=120)
127
+ def initialize(
128
+ self,
129
+ model_path: str,
130
+ device: str = "cuda",
131
+ use_compile: bool = True,
132
+ warmup_steps: int = 1,
133
+ ):
134
+ if self.is_initialized:
135
+ return
136
+
137
+ print(f"Loading ModularPipeline from: {model_path}")
138
+ self.pipe = ModularPipeline.from_pretrained(model_path, trust_remote_code=True)
139
+ self.pipe.load_components(
140
+ device_map=device,
141
+ torch_dtype=torch.bfloat16,
142
+ trust_remote_code=True,
143
+ )
144
+
145
+ self.pipe.transformer.apply_inference_patches()
146
+ self.pipe.transformer.quantize("fp8")
147
+
148
+ if use_compile:
149
+ print("Compiling transformer...")
150
+ self.pipe.transformer = torch.compile(
151
+ self.pipe.transformer,
152
+ mode="max-autotune-no-cudagraphs",
153
+ fullgraph=True,
154
+ dynamic=False,
155
+ )
156
+
157
+ print("Compiling VAE...")
158
+ self.pipe.vae = torch.compile(
159
+ self.pipe.vae,
160
+ mode="max-autotune",
161
+ fullgraph=True,
162
+ dynamic=False,
163
+ )
164
+
165
+ self.n_frames = self.pipe.transformer.config.n_frames
166
+ print(f"ModularPipeline loaded! (n_frames={self.n_frames})")
167
+
168
+ # Only warmup if cache is empty (first run)
169
+ # Subsequent cold starts can skip warmup since kernels are already compiled
170
+ cache_files = list(CACHE_DIR.glob("**/*.so"))
171
+ if use_compile and warmup_steps > 0 and not cache_files:
172
+ print("No cached kernels found - running warmup to compile...")
173
+ self._warmup_compile(warmup_steps)
174
+ elif cache_files:
175
+ print(f"Found {len(cache_files)} cached kernels - skipping warmup")
176
+
177
+ self.is_initialized = True
178
+
179
+ # Generate initial state
180
+ print("Generating initial state...")
181
+ seed_image = load_seed_frame(random.choice(SEED_FRAME_URLS))
182
+ self.state = self.pipe(
183
+ prompt=self.prompt,
184
+ image=seed_image,
185
+ button=set(),
186
+ mouse=(0.0, 0.0),
187
+ output_type="pil",
188
+ )
189
+ self.frame_count = 1
190
+
191
+ # Store initial frame
192
+ images = self.state.values.get("images")
193
+ if images is not None:
194
+ with self._frame_lock:
195
+ self._latest_frame = images
196
+ self._latest_frame_b64 = image_to_base64(images)
197
+
198
+ print("Initial state ready!")
199
+ self._start_generation_thread()
200
+
201
+ def _warmup_compile(self, num_steps: int = 3):
202
+ """Run warmup passes to trigger torch.compile before first real request."""
203
+ print(f"Warming up compiled model ({num_steps} steps)...")
204
+ self.is_warming_up = True
205
+
206
+ try:
207
+ # Create a temporary state for warmup
208
+ seed_image = load_seed_frame(SEED_FRAME_URLS[0])
209
+
210
+ warmup_state = self.pipe(
211
+ prompt="warmup",
212
+ image=seed_image,
213
+ button=set(),
214
+ mouse=(0.0, 0.0),
215
+ output_type="pil",
216
+ )
217
+
218
+ # Run a few more steps to ensure all kernels are compiled
219
+ for i in range(num_steps - 1):
220
+ warmup_state = self.pipe(
221
+ warmup_state,
222
+ prompt="warmup",
223
+ button=set(),
224
+ mouse=(0.0, 0.0),
225
+ image=None,
226
+ output_type="pil",
227
+ )
228
+ print(f" Warmup step {i + 2}/{num_steps} complete")
229
+
230
+ print("Warmup complete! Compiled kernels cached.")
231
+
232
+ except Exception as e:
233
+ print(f"Warmup error (non-fatal): {e}")
234
+
235
+ self.is_warming_up = False
236
+
237
+ def _start_generation_thread(self):
238
+ if self._generation_thread is not None and self._generation_thread.is_alive():
239
+ return
240
+
241
+ self._running = True
242
+ self._generation_thread = threading.Thread(
243
+ target=self._generation_loop, daemon=True
244
+ )
245
+ self._generation_thread.start()
246
+ print("Background generation thread started!")
247
+
248
+ def _generation_loop(self):
249
+ while self._running:
250
+ if not self.is_initialized or self.state is None or self.is_warming_up:
251
+ time.sleep(0.01)
252
+ continue
253
+
254
+ buttons, mouse = self.control_state.get()
255
+
256
+ with self._state_lock:
257
+ try:
258
+ self.state = self.pipe(
259
+ self.state,
260
+ prompt=self.prompt,
261
+ button=buttons,
262
+ mouse=mouse,
263
+ image=None,
264
+ output_type="pil",
265
+ )
266
+ self.frame_count += 1
267
+
268
+ images = self.state.values.get("images")
269
+ if images is not None:
270
+ with self._frame_lock:
271
+ self._latest_frame = images
272
+ self._latest_frame_b64 = image_to_base64(images)
273
+
274
+ if self.frame_count >= self.n_frames - 2:
275
+ print(f"Auto-reset at frame {self.frame_count}")
276
+ self._do_reset()
277
+
278
+ except Exception as e:
279
+ import traceback
280
+
281
+ print(f"Generation error: {e}")
282
+ traceback.print_exc()
283
+ time.sleep(0.5)
284
+
285
+ def _do_reset(self, seed_url: str = None, seed_image: Image.Image = None):
286
+ self.frame_count = 0
287
+
288
+ # Use provided image, or load from URL, or pick random
289
+ if seed_image is not None:
290
+ # Resize uploaded image to target size
291
+ target_size = (360, 640) # (H, W)
292
+ seed_image = seed_image.resize(
293
+ (target_size[1], target_size[0]), Image.BILINEAR
294
+ )
295
+ else:
296
+ url = seed_url or random.choice(SEED_FRAME_URLS)
297
+ seed_image = load_seed_frame(url)
298
+
299
+ buttons, mouse = self.control_state.get()
300
+ self.state = self.pipe(
301
+ prompt=self.prompt,
302
+ image=seed_image,
303
+ button=buttons,
304
+ mouse=mouse,
305
+ output_type="pil",
306
+ )
307
+ self.frame_count = 1
308
+
309
+ images = self.state.values.get("images")
310
+ if images is not None:
311
+ with self._frame_lock:
312
+ self._latest_frame = images
313
+ self._latest_frame_b64 = image_to_base64(images)
314
+
315
+ def reset(self, seed_url: str = None, seed_image: Image.Image = None):
316
+ with self._state_lock:
317
+ self._do_reset(seed_url=seed_url, seed_image=seed_image)
318
+
319
+ def get_latest_frame(self) -> Tuple[Optional[str], int]:
320
+ with self._frame_lock:
321
+ return self._latest_frame_b64, self.frame_count
322
+
323
+ def update_controls(self, buttons: Set[int], mouse: Tuple[float, float]):
324
+ self.control_state.update(buttons, mouse)
325
+
326
+ def stop(self):
327
+ self._running = False
328
+ if self._generation_thread is not None:
329
+ self._generation_thread.join(timeout=2.0)
330
+
331
+
332
+ game_state = GameState()
333
+
334
+
335
+ # --- Control Input HTML Component ---
336
+ CONTROL_INPUT_HTML = """
337
+ <div id="control-input-wrapper" style="width: 100%; background: #0a0a0f; border-radius: 12px; overflow: hidden; font-family: 'JetBrains Mono', 'Fira Code', monospace;">
338
+ <div style="padding: 16px; background: linear-gradient(180deg, rgba(22, 27, 34, 0.95) 0%, rgba(13, 17, 23, 0.98) 100%);">
339
+ <!-- Status Bar -->
340
+ <div style="display: flex; justify-content: space-between; align-items: center; margin-bottom: 16px;">
341
+ <div id="status-indicator" style="display: flex; align-items: center; gap: 6px; background: rgba(0,0,0,0.4); padding: 6px 12px; border-radius: 20px; border: 1px solid rgba(88, 166, 255, 0.2);">
342
+ <div id="status-dot" style="width: 8px; height: 8px; border-radius: 50%; background: #ff6b6b; box-shadow: 0 0 8px #ff6b6b;"></div>
343
+ <span id="status-text" style="font-size: 11px; color: #8b949e; text-transform: uppercase; letter-spacing: 1px;">Not capturing</span>
344
+ </div>
345
+ </div>
346
+
347
+ <div style="display: grid; grid-template-columns: 1fr 1fr; gap: 16px;">
348
+ <!-- Keyboard State (touch-enabled) -->
349
+ <div style="background: rgba(0,0,0,0.3); border-radius: 12px; padding: 12px; border: 1px solid rgba(88, 166, 255, 0.1);">
350
+ <div style="font-size: 10px; color: #8b949e; text-transform: uppercase; letter-spacing: 1px; margin-bottom: 10px;">Movement</div>
351
+ <div style="display: grid; grid-template-columns: repeat(3, 1fr); gap: 4px; max-width: 120px; margin: 0 auto;">
352
+ <div></div>
353
+ <div id="key-w" data-key="KeyW" style="aspect-ratio: 1; background: rgba(88, 166, 255, 0.1); border: 1px solid rgba(88, 166, 255, 0.3); border-radius: 6px; display: flex; align-items: center; justify-content: center; font-size: 11px; color: #58a6ff; transition: all 0.1s; cursor: pointer; user-select: none; -webkit-user-select: none; touch-action: manipulation;">W</div>
354
+ <div></div>
355
+ <div id="key-a" data-key="KeyA" style="aspect-ratio: 1; background: rgba(88, 166, 255, 0.1); border: 1px solid rgba(88, 166, 255, 0.3); border-radius: 6px; display: flex; align-items: center; justify-content: center; font-size: 11px; color: #58a6ff; transition: all 0.1s; cursor: pointer; user-select: none; -webkit-user-select: none; touch-action: manipulation;">A</div>
356
+ <div id="key-s" data-key="KeyS" style="aspect-ratio: 1; background: rgba(88, 166, 255, 0.1); border: 1px solid rgba(88, 166, 255, 0.3); border-radius: 6px; display: flex; align-items: center; justify-content: center; font-size: 11px; color: #58a6ff; transition: all 0.1s; cursor: pointer; user-select: none; -webkit-user-select: none; touch-action: manipulation;">S</div>
357
+ <div id="key-d" data-key="KeyD" style="aspect-ratio: 1; background: rgba(88, 166, 255, 0.1); border: 1px solid rgba(88, 166, 255, 0.3); border-radius: 6px; display: flex; align-items: center; justify-content: center; font-size: 11px; color: #58a6ff; transition: all 0.1s; cursor: pointer; user-select: none; -webkit-user-select: none; touch-action: manipulation;">D</div>
358
+ </div>
359
+ <div style="display: flex; gap: 4px; margin-top: 8px; justify-content: center;">
360
+ <div id="key-shift" data-key="ShiftLeft" style="padding: 4px 8px; background: rgba(88, 166, 255, 0.1); border: 1px solid rgba(88, 166, 255, 0.3); border-radius: 4px; font-size: 9px; color: #58a6ff; transition: all 0.1s; cursor: pointer; user-select: none; -webkit-user-select: none; touch-action: manipulation;">SHIFT</div>
361
+ <div id="key-space" data-key="Space" style="padding: 4px 16px; background: rgba(88, 166, 255, 0.1); border: 1px solid rgba(88, 166, 255, 0.3); border-radius: 4px; font-size: 9px; color: #58a6ff; transition: all 0.1s; cursor: pointer; user-select: none; -webkit-user-select: none; touch-action: manipulation;">SPACE</div>
362
+ <div id="key-e" data-key="KeyE" style="padding: 4px 8px; background: rgba(88, 166, 255, 0.1); border: 1px solid rgba(88, 166, 255, 0.3); border-radius: 4px; font-size: 9px; color: #58a6ff; transition: all 0.1s; cursor: pointer; user-select: none; -webkit-user-select: none; touch-action: manipulation;">E</div>
363
+ </div>
364
+ </div>
365
+
366
+ <!-- Mouse/Look Joystick (touch-enabled) -->
367
+ <div style="background: rgba(0,0,0,0.3); border-radius: 12px; padding: 12px; border: 1px solid rgba(88, 166, 255, 0.1);">
368
+ <div style="font-size: 10px; color: #8b949e; text-transform: uppercase; letter-spacing: 1px; margin-bottom: 10px;">Look (drag)</div>
369
+ <div style="display: flex; align-items: center; justify-content: center; gap: 16px;">
370
+ <div id="mouse-joystick" style="width: 100px; height: 100px; background: rgba(88, 166, 255, 0.05); border: 2px solid rgba(88, 166, 255, 0.3); border-radius: 50%; position: relative; cursor: pointer; touch-action: none;">
371
+ <div id="mouse-dot" style="width: 24px; height: 24px; background: #58a6ff; border-radius: 50%; position: absolute; top: 50%; left: 50%; transform: translate(-50%, -50%); box-shadow: 0 0 12px rgba(88, 166, 255, 0.6); transition: all 0.05s ease-out; pointer-events: none;"></div>
372
+ <div style="position: absolute; top: 50%; left: 0; right: 0; height: 1px; background: rgba(88, 166, 255, 0.2); pointer-events: none;"></div>
373
+ <div style="position: absolute; left: 50%; top: 0; bottom: 0; width: 1px; background: rgba(88, 166, 255, 0.2); pointer-events: none;"></div>
374
+ </div>
375
+ <div style="text-align: left;">
376
+ <div style="font-size: 10px; color: #8b949e; margin-bottom: 4px;">Velocity</div>
377
+ <div style="font-size: 12px; color: #58a6ff;">X: <span id="mouse-x-value">0.0</span></div>
378
+ <div style="font-size: 12px; color: #58a6ff;">Y: <span id="mouse-y-value">0.0</span></div>
379
+ </div>
380
+ </div>
381
+ </div>
382
+ </div>
383
+
384
+ <div style="margin-top: 12px; background: rgba(0,0,0,0.3); border-radius: 8px; padding: 8px 12px; border: 1px solid rgba(88, 166, 255, 0.1);">
385
+ <div style="display: flex; align-items: center; gap: 8px;">
386
+ <span style="font-size: 10px; color: #8b949e; text-transform: uppercase; letter-spacing: 1px;">Active:</span>
387
+ <div id="active-buttons" style="display: flex; gap: 4px; flex-wrap: wrap; min-height: 20px;">
388
+ <span style="font-size: 11px; color: #484f58;">None</span>
389
+ </div>
390
+ </div>
391
+ </div>
392
+ </div>
393
+ </div>
394
+ """
395
+
396
+ CONTROL_INPUT_JS = """
397
+ (() => {
398
+ const statusDot = element.querySelector('#status-dot');
399
+ const statusText = element.querySelector('#status-text');
400
+ const mouseDot = element.querySelector('#mouse-dot');
401
+ const mouseXValue = element.querySelector('#mouse-x-value');
402
+ const mouseYValue = element.querySelector('#mouse-y-value');
403
+ const activeButtonsDisplay = element.querySelector('#active-buttons');
404
+ const mouseJoystick = element.querySelector('#mouse-joystick');
405
+
406
+ let isCapturing = false;
407
+ let isTouchActive = false;
408
+ let pressedKeys = new Set();
409
+ let touchPressedKeys = new Set();
410
+ let mouseVelocity = { x: 0, y: 0 };
411
+ let lastMouseMove = Date.now();
412
+
413
+ // Windows Virtual Key codes
414
+ const BUTTON_MAP = {
415
+ 'KeyW': 87, 'KeyA': 65, 'KeyS': 83, 'KeyD': 68,
416
+ 'KeyQ': 81, 'KeyE': 69, 'KeyR': 82, 'KeyF': 70,
417
+ 'KeyT': 84, 'KeyG': 71, 'KeyZ': 90, 'KeyX': 88,
418
+ 'KeyC': 67, 'KeyV': 86, 'KeyB': 66, 'KeyN': 78,
419
+ 'KeyM': 77, 'KeyH': 72, 'KeyJ': 74, 'KeyK': 75,
420
+ 'KeyL': 76, 'KeyI': 73, 'KeyO': 79, 'KeyP': 80,
421
+ 'KeyU': 85, 'KeyY': 89,
422
+ 'Digit1': 49, 'Digit2': 50, 'Digit3': 51, 'Digit4': 52,
423
+ 'Digit5': 53, 'Digit6': 54, 'Digit7': 55, 'Digit8': 56,
424
+ 'Digit9': 57, 'Digit0': 48,
425
+ 'Space': 32,
426
+ 'ShiftLeft': 16, 'ShiftRight': 16,
427
+ };
428
+
429
+ const KEY_DISPLAY_MAP = {
430
+ 'KeyW': 'key-w', 'KeyA': 'key-a', 'KeyS': 'key-s', 'KeyD': 'key-d',
431
+ 'ShiftLeft': 'key-shift', 'Space': 'key-space', 'KeyE': 'key-e',
432
+ };
433
+
434
+ const isTouchDevice = 'ontouchstart' in window || navigator.maxTouchPoints > 0;
435
+
436
+ function updateKeyDisplay(code, pressed) {
437
+ const elementId = KEY_DISPLAY_MAP[code];
438
+ if (elementId) {
439
+ const keyEl = element.querySelector('#' + elementId);
440
+ if (keyEl) {
441
+ keyEl.style.background = pressed ? 'rgba(88, 166, 255, 0.4)' : 'rgba(88, 166, 255, 0.1)';
442
+ keyEl.style.borderColor = pressed ? '#58a6ff' : 'rgba(88, 166, 255, 0.3)';
443
+ keyEl.style.boxShadow = pressed ? '0 0 8px rgba(88, 166, 255, 0.4)' : 'none';
444
+ }
445
+ }
446
+ }
447
+
448
+ function updateMouseDisplay() {
449
+ const displayX = Math.max(-1, Math.min(1, mouseVelocity.x / 20));
450
+ const displayY = Math.max(-1, Math.min(1, mouseVelocity.y / 20));
451
+ const maxOffset = 40;
452
+ mouseDot.style.left = (50 + displayX * maxOffset) + '%';
453
+ mouseDot.style.top = (50 + displayY * maxOffset) + '%';
454
+ mouseXValue.textContent = mouseVelocity.x.toFixed(1);
455
+ mouseYValue.textContent = mouseVelocity.y.toFixed(1);
456
+ }
457
+
458
+ function updateActiveButtonsDisplay() {
459
+ if (pressedKeys.size === 0) {
460
+ activeButtonsDisplay.innerHTML = '<span style="font-size: 11px; color: #484f58;">None</span>';
461
+ } else {
462
+ activeButtonsDisplay.innerHTML = Array.from(pressedKeys)
463
+ .map(code => code.replace('Key', '').replace('Digit', '').replace('Left', ''))
464
+ .map(name => `<span style="font-size: 10px; background: rgba(88, 166, 255, 0.2); color: #58a6ff; padding: 2px 6px; border-radius: 4px;">${name}</span>`)
465
+ .join('');
466
+ }
467
+ }
468
+
469
+ function updateStatus(capturing) {
470
+ statusDot.style.background = capturing ? '#3fb950' : '#ff6b6b';
471
+ statusDot.style.boxShadow = capturing ? '0 0 8px #3fb950' : '0 0 8px #ff6b6b';
472
+ statusText.textContent = capturing ? 'Capturing - ESC to release' : 'Not capturing';
473
+ }
474
+
475
+ function triggerUpdate() {
476
+ const buttonIds = Array.from(pressedKeys)
477
+ .filter(code => BUTTON_MAP[code] !== undefined)
478
+ .map(code => BUTTON_MAP[code]);
479
+ props.value = { buttons: buttonIds, mouse_x: mouseVelocity.x, mouse_y: mouseVelocity.y };
480
+ trigger('change', props.value);
481
+ }
482
+
483
+ document.addEventListener('pointerlockchange', () => {
484
+ isCapturing = document.pointerLockElement !== null;
485
+ updateStatus(isCapturing);
486
+ if (!isCapturing) {
487
+ pressedKeys.clear();
488
+ mouseVelocity = { x: 0, y: 0 };
489
+ Object.keys(KEY_DISPLAY_MAP).forEach(code => updateKeyDisplay(code, false));
490
+ updateMouseDisplay();
491
+ updateActiveButtonsDisplay();
492
+ triggerUpdate();
493
+ }
494
+ });
495
+
496
+ document.addEventListener('keydown', (e) => {
497
+ if (!isCapturing) return;
498
+ if (e.code === 'Escape') { document.exitPointerLock(); return; }
499
+ if (BUTTON_MAP[e.code] !== undefined && !pressedKeys.has(e.code)) {
500
+ pressedKeys.add(e.code);
501
+ updateKeyDisplay(e.code, true);
502
+ updateActiveButtonsDisplay();
503
+ triggerUpdate();
504
+ }
505
+ e.preventDefault();
506
+ });
507
+
508
+ document.addEventListener('keyup', (e) => {
509
+ if (!isCapturing) return;
510
+ if (pressedKeys.has(e.code)) {
511
+ pressedKeys.delete(e.code);
512
+ updateKeyDisplay(e.code, false);
513
+ updateActiveButtonsDisplay();
514
+ triggerUpdate();
515
+ }
516
+ });
517
+
518
+ const MOUSE_SENSITIVITY = 1.5;
519
+
520
+ document.addEventListener('mousemove', (e) => {
521
+ if (!isCapturing) return;
522
+ mouseVelocity.x = e.movementX * MOUSE_SENSITIVITY;
523
+ mouseVelocity.y = e.movementY * MOUSE_SENSITIVITY;
524
+ updateMouseDisplay();
525
+ triggerUpdate();
526
+ lastMouseMove = Date.now();
527
+ });
528
+
529
+ setInterval(() => {
530
+ if ((isCapturing || isTouchActive) && Date.now() - lastMouseMove > 50) {
531
+ mouseVelocity.x *= 0.8;
532
+ mouseVelocity.y *= 0.8;
533
+ if (Math.abs(mouseVelocity.x) < 0.01) mouseVelocity.x = 0;
534
+ if (Math.abs(mouseVelocity.y) < 0.01) mouseVelocity.y = 0;
535
+ updateMouseDisplay();
536
+ triggerUpdate();
537
+ }
538
+ }, 100);
539
+
540
+ // Touch controls
541
+ const touchableKeys = element.querySelectorAll('[data-key]');
542
+
543
+ touchableKeys.forEach(keyEl => {
544
+ const keyCode = keyEl.dataset.key;
545
+
546
+ const handleTouchStart = (e) => {
547
+ e.preventDefault();
548
+ isTouchActive = true;
549
+ if (!pressedKeys.has(keyCode)) {
550
+ pressedKeys.add(keyCode);
551
+ touchPressedKeys.add(keyCode);
552
+ updateKeyDisplay(keyCode, true);
553
+ updateActiveButtonsDisplay();
554
+ triggerUpdate();
555
+ }
556
+ };
557
+
558
+ const handleTouchEnd = (e) => {
559
+ e.preventDefault();
560
+ if (touchPressedKeys.has(keyCode)) {
561
+ pressedKeys.delete(keyCode);
562
+ touchPressedKeys.delete(keyCode);
563
+ updateKeyDisplay(keyCode, false);
564
+ updateActiveButtonsDisplay();
565
+ triggerUpdate();
566
+ }
567
+ };
568
+
569
+ keyEl.addEventListener('touchstart', handleTouchStart, { passive: false });
570
+ keyEl.addEventListener('touchend', handleTouchEnd, { passive: false });
571
+ keyEl.addEventListener('touchcancel', handleTouchEnd, { passive: false });
572
+
573
+ keyEl.addEventListener('mousedown', (e) => {
574
+ e.preventDefault();
575
+ isTouchActive = true;
576
+ if (!pressedKeys.has(keyCode)) {
577
+ pressedKeys.add(keyCode);
578
+ touchPressedKeys.add(keyCode);
579
+ updateKeyDisplay(keyCode, true);
580
+ updateActiveButtonsDisplay();
581
+ triggerUpdate();
582
+ }
583
+ });
584
+
585
+ keyEl.addEventListener('mouseup', (e) => {
586
+ if (touchPressedKeys.has(keyCode)) {
587
+ pressedKeys.delete(keyCode);
588
+ touchPressedKeys.delete(keyCode);
589
+ updateKeyDisplay(keyCode, false);
590
+ updateActiveButtonsDisplay();
591
+ triggerUpdate();
592
+ }
593
+ });
594
+
595
+ keyEl.addEventListener('mouseleave', (e) => {
596
+ if (touchPressedKeys.has(keyCode)) {
597
+ pressedKeys.delete(keyCode);
598
+ touchPressedKeys.delete(keyCode);
599
+ updateKeyDisplay(keyCode, false);
600
+ updateActiveButtonsDisplay();
601
+ triggerUpdate();
602
+ }
603
+ });
604
+ });
605
+
606
+ // Touch joystick
607
+ let joystickActive = false;
608
+ const JOYSTICK_SENSITIVITY = 0.5;
609
+ const JOYSTICK_RADIUS = 50;
610
+
611
+ const getJoystickCenter = () => {
612
+ const rect = mouseJoystick.getBoundingClientRect();
613
+ return { x: rect.left + rect.width / 2, y: rect.top + rect.height / 2 };
614
+ };
615
+
616
+ const handleJoystickMove = (clientX, clientY) => {
617
+ const center = getJoystickCenter();
618
+ let dx = clientX - center.x;
619
+ let dy = clientY - center.y;
620
+ const distance = Math.sqrt(dx * dx + dy * dy);
621
+ if (distance > JOYSTICK_RADIUS) {
622
+ dx = (dx / distance) * JOYSTICK_RADIUS;
623
+ dy = (dy / distance) * JOYSTICK_RADIUS;
624
+ }
625
+ mouseVelocity.x = (dx / JOYSTICK_RADIUS) * 20 * JOYSTICK_SENSITIVITY;
626
+ mouseVelocity.y = (dy / JOYSTICK_RADIUS) * 20 * JOYSTICK_SENSITIVITY;
627
+ updateMouseDisplay();
628
+ triggerUpdate();
629
+ lastMouseMove = Date.now();
630
+ };
631
+
632
+ mouseJoystick.addEventListener('touchstart', (e) => {
633
+ e.preventDefault();
634
+ joystickActive = true;
635
+ isTouchActive = true;
636
+ handleJoystickMove(e.touches[0].clientX, e.touches[0].clientY);
637
+ }, { passive: false });
638
+
639
+ mouseJoystick.addEventListener('touchmove', (e) => {
640
+ e.preventDefault();
641
+ if (joystickActive) handleJoystickMove(e.touches[0].clientX, e.touches[0].clientY);
642
+ }, { passive: false });
643
+
644
+ mouseJoystick.addEventListener('touchend', (e) => {
645
+ e.preventDefault();
646
+ joystickActive = false;
647
+ mouseVelocity = { x: 0, y: 0 };
648
+ updateMouseDisplay();
649
+ triggerUpdate();
650
+ }, { passive: false });
651
+
652
+ mouseJoystick.addEventListener('touchcancel', (e) => {
653
+ joystickActive = false;
654
+ mouseVelocity = { x: 0, y: 0 };
655
+ updateMouseDisplay();
656
+ triggerUpdate();
657
+ }, { passive: false });
658
+
659
+ mouseJoystick.addEventListener('mousedown', (e) => {
660
+ e.preventDefault();
661
+ joystickActive = true;
662
+ isTouchActive = true;
663
+ handleJoystickMove(e.clientX, e.clientY);
664
+ });
665
+
666
+ document.addEventListener('mousemove', (e) => {
667
+ if (joystickActive) handleJoystickMove(e.clientX, e.clientY);
668
+ });
669
+
670
+ document.addEventListener('mouseup', () => {
671
+ if (joystickActive) {
672
+ joystickActive = false;
673
+ mouseVelocity = { x: 0, y: 0 };
674
+ updateMouseDisplay();
675
+ triggerUpdate();
676
+ }
677
+ });
678
+
679
+ if (isTouchDevice) {
680
+ statusText.textContent = 'Touch controls ready';
681
+ statusDot.style.background = '#3fb950';
682
+ statusDot.style.boxShadow = '0 0 8px #3fb950';
683
+ } else {
684
+ updateStatus(false);
685
+ }
686
+
687
+ updateMouseDisplay();
688
+ })();
689
+ """
690
+
691
+ CAPTURE_JS = """
692
+ () => {
693
+ const isTouchDevice = 'ontouchstart' in window || navigator.maxTouchPoints > 0;
694
+ if (isTouchDevice) return;
695
+
696
+ const insertButton = () => {
697
+ const controlPanel = document.querySelector('#control-input-wrapper');
698
+ if (!controlPanel) { setTimeout(insertButton, 100); return; }
699
+
700
+ const btn = document.createElement('button');
701
+ btn.id = 'capture-btn';
702
+ btn.textContent = '🎮 Click to Capture Controls';
703
+ btn.style.cssText = `
704
+ width: 100%; padding: 12px 24px; margin-bottom: 12px;
705
+ font-size: 14px; font-weight: bold;
706
+ background: linear-gradient(135deg, #58a6ff 0%, #a371f7 100%);
707
+ color: white; border: none; border-radius: 8px; cursor: pointer;
708
+ box-shadow: 0 4px 12px rgba(0,0,0,0.3); transition: all 0.2s;
709
+ `;
710
+ btn.onmouseenter = () => btn.style.transform = 'scale(1.02)';
711
+ btn.onmouseleave = () => btn.style.transform = 'scale(1)';
712
+ btn.onclick = async () => {
713
+ try { await document.body.requestPointerLock(); }
714
+ catch (e) { console.log('Pointer lock failed:', e); }
715
+ };
716
+ controlPanel.parentNode.insertBefore(btn, controlPanel);
717
+ };
718
+
719
+ insertButton();
720
+
721
+ document.addEventListener('pointerlockchange', () => {
722
+ const btn = document.querySelector('#capture-btn');
723
+ if (!btn) return;
724
+ if (document.pointerLockElement) {
725
+ btn.textContent = '🔒 Capturing (ESC to release)';
726
+ btn.style.background = 'linear-gradient(135deg, #3fb950 0%, #2ea043 100%)';
727
+ } else {
728
+ btn.textContent = '🎮 Click to Capture Controls';
729
+ btn.style.background = 'linear-gradient(135deg, #58a6ff 0%, #a371f7 100%)';
730
+ }
731
+ });
732
+ }
733
+ """
734
+
735
+
736
+ # --- Gradio App ---
737
+ css = """
738
+ #col-container { max-width: 1200px; margin: 0 auto; }
739
+ #video-output { aspect-ratio: 16/9; max-width: 640px; }
740
+ #video-output img { width: 100%; height: 100%; object-fit: contain; }
741
+ """
742
+
743
+
744
+ def create_app():
745
+ with gr.Blocks(css=css, theme=gr.themes.Soft(), title="WorldEngine") as demo:
746
+ gr.Markdown("""
747
+ # 🌍 WorldEngine — Real-Time World Model
748
+
749
+ Interactive frame-by-frame world generation.
750
+
751
+ **Controls:** WASD to move • Mouse to look • Space to jump • Shift to sprint • E to interact • ESC to release
752
+ """)
753
+
754
+ with gr.Row():
755
+ with gr.Column(scale=3):
756
+ video_output = gr.Image(
757
+ label="Game View",
758
+ elem_id="video-output",
759
+ streaming=True,
760
+ width=640,
761
+ height=360,
762
+ show_label=False,
763
+ )
764
+ frame_display = gr.Number(label="Frame", value=0, interactive=False)
765
+
766
+ with gr.Column(scale=1):
767
+ control_input = gr.HTML(
768
+ value={"buttons": [], "mouse_x": 0.0, "mouse_y": 0.0},
769
+ html_template=CONTROL_INPUT_HTML,
770
+ js_on_load=CONTROL_INPUT_JS,
771
+ )
772
+
773
+ prompt_input = gr.Textbox(
774
+ label="World Prompt",
775
+ value="An explorable world",
776
+ lines=2,
777
+ )
778
+
779
+ with gr.Accordion("Seed Selection", open=False):
780
+ seed_image_upload = gr.Image(
781
+ label="Upload Seed Image (optional)",
782
+ type="pil",
783
+ sources=["upload"],
784
+ height=150,
785
+ )
786
+ gr.Markdown("*Or select a preset seed:*", elem_classes=["text-sm"])
787
+ seed_dropdown = gr.Dropdown(
788
+ choices=["Random"]
789
+ + [f"Seed {i + 1}" for i in range(len(SEED_FRAME_URLS))],
790
+ value="Random",
791
+ label="Preset Seeds",
792
+ )
793
+ reset_btn = gr.Button("Reset World", variant="primary")
794
+
795
+ # Event handlers
796
+ def on_controller_change(controller_value):
797
+ if not controller_value or not game_state.is_initialized:
798
+ return
799
+ buttons = set(controller_value.get("buttons", []))
800
+ mouse_x = controller_value.get("mouse_x", 0.0)
801
+ mouse_y = controller_value.get("mouse_y", 0.0)
802
+ game_state.update_controls(buttons, (mouse_x, mouse_y))
803
+
804
+ def on_prompt_change(prompt):
805
+ game_state.prompt = prompt
806
+
807
+ def on_reset(uploaded_image, seed_choice):
808
+ if uploaded_image is not None:
809
+ game_state.reset(seed_image=uploaded_image)
810
+ elif seed_choice and seed_choice != "Random":
811
+ idx = int(seed_choice.split()[-1]) - 1
812
+ game_state.reset(seed_url=SEED_FRAME_URLS[idx])
813
+ else:
814
+ game_state.reset()
815
+
816
+ def get_frame():
817
+ with game_state._frame_lock:
818
+ return game_state._latest_frame, game_state.frame_count
819
+
820
+ control_input.change(fn=on_controller_change, inputs=[control_input])
821
+ prompt_input.change(fn=on_prompt_change, inputs=[prompt_input])
822
+ reset_btn.click(fn=on_reset, inputs=[seed_image_upload, seed_dropdown])
823
+
824
+ timer = gr.Timer(value=1 / 30)
825
+ timer.tick(fn=get_frame, outputs=[video_output, frame_display])
826
+
827
+ demo.load(fn=None, js=CAPTURE_JS)
828
+
829
+ return demo
830
+
831
+
832
+ def main():
833
+ device = "cuda" if torch.cuda.is_available() else "cpu"
834
+ use_compile = os.environ.get("COMPILE", "1") == "1"
835
+ warmup_steps = int(os.environ.get("WARMUP_STEPS", "1"))
836
+
837
+ print(f"Model path: {MODEL_ID}")
838
+ print(f"Device: {device}")
839
+ print(f"Compile: {use_compile}")
840
+ print(f"Warmup steps: {warmup_steps}")
841
+ print(f"Cache dir: {CACHE_DIR}")
842
+
843
+ # Check if cache exists from previous run
844
+ cache_files = list(CACHE_DIR.glob("**/*.json")) + list(CACHE_DIR.glob("**/*.py"))
845
+ if cache_files:
846
+ print(f"Found {len(cache_files)} cached compilation files")
847
+
848
+ game_state.initialize(
849
+ model_path=MODEL_ID,
850
+ device=device,
851
+ use_compile=use_compile,
852
+ warmup_steps=warmup_steps,
853
+ )
854
+
855
+ demo = create_app()
856
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
857
 
 
 
858
 
859
+ if __name__ == "__main__":
860
+ main()
requirements.txt ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Core dependencies
2
+ torch>=2.4.0
3
+ torchvision>=0.19.0
4
+
5
+ # Gradio for web interface
6
+ gradio>=5.0.0
7
+
8
+ # Diffusers with modular pipeline support
9
+ diffusers>=0.36.0
10
+
11
+ # WorldEngine modular dependencies
12
+ transformers>=4.40.0
13
+ einops>=0.8.0
14
+ tensordict>=0.5.0
15
+ regex>=2024.0.0
16
+ ftfy>=6.0.0
17
+
18
+ # HuggingFace Hub
19
+ huggingface_hub>=0.20.0