lapp0 multimodalart HF Staff commited on
Commit
1d676b9
·
1 Parent(s): 14b704d

Fix per user isolation (#4)

Browse files

- Isolate per-user state to fix cross-user seed/frame race (713796d5dbe6d59915d1b4eea3d6ad7a8cca6959)
- Switch to Biome seed catalog and drop user uploads (fd21ff2ccd7da0855bbe378452c1590c04d2e9cc)
- World selector: 3-up grid and click-to-start (6352b84a1cd840fbe382d72a58531c75e96ed103)


Co-authored-by: Apolinário from multimodal AI art <multimodalart@users.noreply.huggingface.co>

Files changed (2) hide show
  1. app.py +180 -136
  2. index.html +36 -120
app.py CHANGED
@@ -3,6 +3,9 @@ Waypoint v1.5 Real-Time World Model Demo — gradio.Server + WebSocket Edition
3
 
4
  Uses gradio.Server for ZeroGPU-compatible game start/stop, and a raw WebSocket
5
  for real-time binary JPEG frame streaming + control input. No Gradio UI polling.
 
 
 
6
  """
7
  # Check for ZeroGPU environment - must be before other imports
8
  try:
@@ -22,17 +25,16 @@ import struct
22
  import contextvars
23
  import threading
24
  import time
 
25
  from collections import deque
26
  from dataclasses import dataclass, field
27
  from multiprocessing import Queue
28
- from typing import Optional, Set, Tuple
29
 
30
  import torch
31
- import numpy as np
32
  from PIL import Image
33
- import tempfile
34
- from fastapi import WebSocket, WebSocketDisconnect, UploadFile, File
35
- from fastapi.responses import HTMLResponse, JSONResponse
36
  from gradio import Server
37
 
38
  from diffusers.modular_pipelines import ModularPipeline
@@ -44,6 +46,7 @@ IMAGE_WIDTH = 1024
44
  IMAGE_HEIGHT = 512
45
  MAX_FRAMES_BEFORE_RESET = 4096
46
  JPEG_QUALITY = 80
 
47
 
48
  torch.set_float32_matmul_precision("medium")
49
 
@@ -66,36 +69,52 @@ constants_path = hf_hub_download(repo_id=MODEL_ID, filename="transformer/transfo
66
  print("Pipeline and AOT artifacts downloaded.")
67
 
68
  _aot_initialized = False
 
69
 
70
 
71
  def _ensure_aot_on_gpu():
72
  global _aot_initialized
73
  if _aot_initialized:
74
  return
75
- broadcast_status("Moving pipeline to GPU...")
76
- print("Initializing AOT model on GPU...")
77
- pipe.to("cuda")
78
- broadcast_status("Setting up KV cache...")
79
- pipe.blocks.sub_blocks["before_denoise"].sub_blocks["setup_kv_cache"]._setup_kv_cache(
80
- pipe.transformer, torch.device("cuda"), torch.bfloat16
81
- )
82
- broadcast_status("Loading AOT-compiled model...")
83
- aot_model = torch._inductor.aoti_load_package(pt2_path)
84
- constants_map = torch.load(constants_path, map_location="cuda", weights_only=True)
85
- aot_model.load_constants(constants_map, check_full_update=True, user_managed=True)
86
- pipe.transformer.forward = aot_model
87
- _aot_initialized = True
88
- broadcast_status("AOT model ready!")
89
- print("AOT model loaded successfully!")
90
-
91
-
92
- SEED_FRAME_URLS = [
93
- "https://huggingface.co/spaces/Overworld/waypoint-1-small/resolve/main/starter_18.png",
94
- "https://huggingface.co/spaces/Overworld/waypoint-1-small/resolve/main/starter_9.png",
95
- "https://huggingface.co/spaces/Overworld/waypoint-1-small/resolve/main/starter_22.png",
96
- "https://huggingface.co/spaces/Overworld/waypoint-1-small/resolve/main/starter_14.png",
97
- "https://huggingface.co/spaces/Overworld/waypoint-1-small/resolve/main/starter_21.png",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  ]
 
99
 
100
 
101
  def load_seed_frame(url: str, target_size: Tuple[int, int] = (IMAGE_HEIGHT, IMAGE_WIDTH)) -> Image.Image:
@@ -113,7 +132,6 @@ class GenerateCommand:
113
 
114
  @dataclass
115
  class ResetCommand:
116
- seed_image: Optional[Image.Image] = None
117
  seed_url: Optional[str] = None
118
  prompt: str = "An explorable world"
119
 
@@ -126,40 +144,86 @@ class StopCommand:
126
  # --- Session ---
127
  @dataclass
128
  class GameSession:
 
129
  command_queue: Queue
130
  frame_queue: queue.Queue
131
  stop_event: threading.Event
 
132
  worker_thread: Optional[threading.Thread] = None
133
  frame_times: deque = field(default_factory=lambda: deque(maxlen=30))
 
 
 
 
134
 
135
  def stop(self):
136
  self.stop_event.set()
137
- self.command_queue.put(StopCommand())
 
 
 
138
  if self.worker_thread and self.worker_thread.is_alive():
139
  self.worker_thread.join(timeout=3.0)
140
 
141
 
142
- # Global session (one player at a time on ZeroGPU)
143
- _session: Optional[GameSession] = None
144
- _session_lock = threading.Lock()
145
 
146
- # Status queue for broadcasting init progress to WebSocket clients
147
- _status_queue: queue.Queue = queue.Queue(maxsize=20)
 
 
 
 
148
 
149
 
150
  def broadcast_status(msg: str):
151
- """Push a status message for WebSocket clients to pick up."""
 
 
 
152
  try:
153
- _status_queue.put_nowait(msg)
154
  except queue.Full:
155
  pass
156
 
157
 
158
- def gpu_worker_thread(seed_url, prompt, command_queue, frame_queue, stop_event, frame_times, seed_image=None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  try:
160
  gen = create_gpu_game_loop(
161
  command_queue,
162
- initial_seed_image=seed_image,
163
  initial_seed_url=seed_url,
164
  initial_prompt=prompt,
165
  )
@@ -194,7 +258,7 @@ def gpu_worker_thread(seed_url, prompt, command_queue, frame_queue, stop_event,
194
  print("Worker thread finished")
195
 
196
 
197
- def create_gpu_game_loop(command_queue, initial_seed_image=None, initial_seed_url=None, initial_prompt="An explorable world"):
198
  @spaces.GPU(duration=120)
199
  def gpu_game_loop():
200
  broadcast_status("GPU allocated! Loading model...")
@@ -203,9 +267,7 @@ def create_gpu_game_loop(command_queue, initial_seed_image=None, initial_seed_ur
203
  print(f"Model loaded! (n_frames={n_frames})")
204
 
205
  broadcast_status("Loading seed image...")
206
- if initial_seed_image is not None:
207
- seed_image = initial_seed_image.resize((IMAGE_WIDTH, IMAGE_HEIGHT), Image.BILINEAR)
208
- elif initial_seed_url is not None:
209
  seed_image = load_seed_frame(initial_seed_url)
210
  else:
211
  seed_image = load_seed_frame(random.choice(SEED_FRAME_URLS))
@@ -243,9 +305,7 @@ def create_gpu_game_loop(command_queue, initial_seed_image=None, initial_seed_ur
243
  break
244
 
245
  if reset_command is not None:
246
- if reset_command.seed_image is not None:
247
- seed_img = reset_command.seed_image.resize((IMAGE_WIDTH, IMAGE_HEIGHT), Image.BILINEAR)
248
- elif reset_command.seed_url:
249
  seed_img = load_seed_frame(reset_command.seed_url)
250
  else:
251
  seed_img = load_seed_frame(random.choice(SEED_FRAME_URLS))
@@ -281,91 +341,74 @@ app = Server()
281
 
282
 
283
  @app.api(name="start_game")
284
- def start_game(seed_url: str = "", prompt: str = "An explorable world") -> str:
285
- """Start a new game session. Returns 'ok' immediately; GPU init happens in background."""
286
- global _session
287
- with _session_lock:
288
- if _session is not None:
289
- _session.stop()
290
- _session = None
291
-
292
- # Drain any stale status messages
293
- while not _status_queue.empty():
294
- try:
295
- _status_queue.get_nowait()
296
- except queue.Empty:
297
- break
298
-
299
- broadcast_status("Requesting GPU from ZeroGPU...")
300
 
301
  command_queue = Queue()
302
  frame_queue = queue.Queue(maxsize=2)
303
  stop_event = threading.Event()
 
304
  frame_times = deque(maxlen=30)
305
 
306
- # Use uploaded seed image if available, otherwise use URL
307
- seed_image = _uploaded_seed_image
308
-
309
- # GPU init + first frame generation happens in the worker thread,
310
- # not here — avoids blocking the Gradio SSE queue (which caused 404 timeouts).
311
- ctx = contextvars.copy_context()
312
- worker = threading.Thread(
313
- target=ctx.run,
314
- args=(gpu_worker_thread,
315
- seed_url if seed_url and not seed_image else None,
316
- prompt or "An explorable world",
317
- command_queue, frame_queue, stop_event, frame_times),
318
- kwargs={"seed_image": seed_image},
319
- daemon=True,
320
- )
321
- worker.start()
322
-
323
  session = GameSession(
 
324
  command_queue=command_queue,
325
  frame_queue=frame_queue,
326
  stop_event=stop_event,
327
- worker_thread=worker,
328
  frame_times=frame_times,
329
  )
330
 
331
- with _session_lock:
332
- _session = session
333
 
334
- return "ok"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
335
 
336
 
337
  @app.api(name="stop_game")
338
- def stop_game() -> str:
339
- """Stop the active game session."""
340
- global _session
341
- with _session_lock:
342
- if _session is not None:
343
- _session.stop()
344
- _session = None
345
  return "stopped"
346
 
347
 
348
- # Store for uploaded seed image (one at a time, single-player)
349
- _uploaded_seed_image: Optional[Image.Image] = None
350
-
351
-
352
- @app.post("/upload_seed_image")
353
- async def upload_seed_image(file: UploadFile = File(...)):
354
- """Upload a custom seed image. Returns confirmation."""
355
- global _uploaded_seed_image
356
- data = await file.read()
357
- _uploaded_seed_image = Image.open(io.BytesIO(data)).convert("RGB")
358
- return JSONResponse({"status": "ok"})
359
-
360
-
361
- @app.post("/clear_seed_image")
362
- async def clear_seed_image():
363
- """Clear the uploaded seed image."""
364
- global _uploaded_seed_image
365
- _uploaded_seed_image = None
366
- return JSONResponse({"status": "ok"})
367
-
368
-
369
  @app.api(name="get_worlds")
370
  def get_worlds() -> list:
371
  """Return the list of seed world URLs."""
@@ -373,34 +416,37 @@ def get_worlds() -> list:
373
 
374
 
375
  @app.websocket("/ws")
376
- async def game_ws(websocket: WebSocket):
377
  """
378
- Real-time game WebSocket.
379
- - Server sends: binary messages = 8-byte header (uint32 frame_count, uint32 fps*10) + JPEG bytes
380
- - Server sends: JSON {"type": "session_ended"} when GPU session expires
381
- - Client sends: JSON {"type": "control", "buttons": [...], "mouse_x": f, "mouse_y": f, "prompt": "..."}
382
- - Client sends: JSON {"type": "reset", "seed_url": "...", "prompt": "..."}
383
  """
384
  await websocket.accept()
 
 
 
 
 
385
  loop = asyncio.get_event_loop()
386
 
387
  async def send_frames():
388
  session_ended_sent = False
389
  while True:
390
- # Relay any status messages (init progress / errors)
391
- try:
392
- status_msg = _status_queue.get_nowait()
393
- if status_msg.startswith("error:"):
394
- await websocket.send_json({"type": "error", "message": status_msg[6:]})
 
 
 
 
 
 
 
 
395
  break
396
- await websocket.send_json({"type": "status", "message": status_msg})
397
- except queue.Empty:
398
- pass
399
- except (WebSocketDisconnect, RuntimeError):
400
- break
401
 
402
- with _session_lock:
403
- session = _session
404
  if session is None:
405
  await asyncio.sleep(0.05)
406
  continue
@@ -423,6 +469,7 @@ async def game_ws(websocket: WebSocket):
423
  jpeg_bytes = buf.getvalue()
424
  header = struct.pack(">II", count, int(fps * 10))
425
  await websocket.send_bytes(header + jpeg_bytes)
 
426
  except queue.Empty:
427
  pass
428
  except (WebSocketDisconnect, RuntimeError):
@@ -432,10 +479,10 @@ async def game_ws(websocket: WebSocket):
432
  while True:
433
  try:
434
  data = await websocket.receive_json()
435
- with _session_lock:
436
- session = _session
437
  if session is None:
438
  continue
 
439
  msg_type = data.get("type", "control")
440
  if msg_type == "control":
441
  buttons = set(data.get("buttons", []))
@@ -443,13 +490,10 @@ async def game_ws(websocket: WebSocket):
443
  prompt = data.get("prompt", "An explorable world")
444
  session.command_queue.put(GenerateCommand(buttons=buttons, mouse=mouse, prompt=prompt))
445
  elif msg_type == "reset":
446
- seed_url = data.get("seed_url")
447
  prompt = data.get("prompt", "An explorable world")
448
- use_custom = data.get("use_custom_image", False)
449
- seed_image = _uploaded_seed_image if use_custom else None
450
  session.command_queue.put(ResetCommand(
451
- seed_image=seed_image,
452
- seed_url=seed_url if not seed_image else None,
453
  prompt=prompt,
454
  ))
455
  except WebSocketDisconnect:
 
3
 
4
  Uses gradio.Server for ZeroGPU-compatible game start/stop, and a raw WebSocket
5
  for real-time binary JPEG frame streaming + control input. No Gradio UI polling.
6
+
7
+ Multi-user safe: every endpoint is keyed by a per-client `session_id` so
8
+ concurrent players never share seed images, frame queues, or status messages.
9
  """
10
  # Check for ZeroGPU environment - must be before other imports
11
  try:
 
25
  import contextvars
26
  import threading
27
  import time
28
+ import uuid
29
  from collections import deque
30
  from dataclasses import dataclass, field
31
  from multiprocessing import Queue
32
+ from typing import Dict, Optional, Set, Tuple
33
 
34
  import torch
 
35
  from PIL import Image
36
+ from fastapi import WebSocket, WebSocketDisconnect
37
+ from fastapi.responses import HTMLResponse
 
38
  from gradio import Server
39
 
40
  from diffusers.modular_pipelines import ModularPipeline
 
46
  IMAGE_HEIGHT = 512
47
  MAX_FRAMES_BEFORE_RESET = 4096
48
  JPEG_QUALITY = 80
49
+ SESSION_IDLE_TIMEOUT = 600 # seconds; janitor reaps abandoned sessions
50
 
51
  torch.set_float32_matmul_precision("medium")
52
 
 
69
  print("Pipeline and AOT artifacts downloaded.")
70
 
71
  _aot_initialized = False
72
+ _aot_init_lock = threading.Lock()
73
 
74
 
75
  def _ensure_aot_on_gpu():
76
  global _aot_initialized
77
  if _aot_initialized:
78
  return
79
+ with _aot_init_lock:
80
+ if _aot_initialized:
81
+ return
82
+ broadcast_status("Moving pipeline to GPU...")
83
+ print("Initializing AOT model on GPU...")
84
+ pipe.to("cuda")
85
+ broadcast_status("Setting up KV cache...")
86
+ pipe.blocks.sub_blocks["before_denoise"].sub_blocks["setup_kv_cache"]._setup_kv_cache(
87
+ pipe.transformer, torch.device("cuda"), torch.bfloat16
88
+ )
89
+ broadcast_status("Loading AOT-compiled model...")
90
+ aot_model = torch._inductor.aoti_load_package(pt2_path)
91
+ constants_map = torch.load(constants_path, map_location="cuda", weights_only=True)
92
+ aot_model.load_constants(constants_map, check_full_update=True, user_managed=True)
93
+ pipe.transformer.forward = aot_model
94
+ _aot_initialized = True
95
+ broadcast_status("AOT model ready!")
96
+ print("AOT model loaded successfully!")
97
+
98
+
99
+ _SEED_BASE = "https://github.com/Overworldai/Biome/blob/main/seeds"
100
+ _SEED_NAMES = [
101
+ "abandoned_city_runner.jpg",
102
+ "alien_command_center.jpg",
103
+ "ancient_standing_stones.jpg",
104
+ "default.jpg",
105
+ "desert_outpost_hangar.jpg",
106
+ "enchanted_swamp_torch.jpg",
107
+ "frozen_crystal_cavern.jpg",
108
+ "frozen_valley_sniper.jpg",
109
+ "highland_castle_loch.jpg",
110
+ "mountain_ruins_gun.jpg",
111
+ "shattered_cockpit_nebula.jpg",
112
+ "shipwreck_shore_revolver.jpg",
113
+ "snowy_forest_tracks.jpg",
114
+ "stormy_countryside_rifle.jpg",
115
+ "sunken_city_depths.jpg",
116
  ]
117
+ SEED_FRAME_URLS = [f"{_SEED_BASE}/{name}?raw=true" for name in _SEED_NAMES]
118
 
119
 
120
  def load_seed_frame(url: str, target_size: Tuple[int, int] = (IMAGE_HEIGHT, IMAGE_WIDTH)) -> Image.Image:
 
132
 
133
  @dataclass
134
  class ResetCommand:
 
135
  seed_url: Optional[str] = None
136
  prompt: str = "An explorable world"
137
 
 
144
  # --- Session ---
145
  @dataclass
146
  class GameSession:
147
+ session_id: str
148
  command_queue: Queue
149
  frame_queue: queue.Queue
150
  stop_event: threading.Event
151
+ status_queue: queue.Queue
152
  worker_thread: Optional[threading.Thread] = None
153
  frame_times: deque = field(default_factory=lambda: deque(maxlen=30))
154
+ last_active: float = field(default_factory=time.time)
155
+
156
+ def touch(self):
157
+ self.last_active = time.time()
158
 
159
  def stop(self):
160
  self.stop_event.set()
161
+ try:
162
+ self.command_queue.put_nowait(StopCommand())
163
+ except Exception:
164
+ pass
165
  if self.worker_thread and self.worker_thread.is_alive():
166
  self.worker_thread.join(timeout=3.0)
167
 
168
 
169
+ # Per-client state, keyed by session_id (UUID generated by the browser).
170
+ _sessions: Dict[str, GameSession] = {}
171
+ _sessions_lock = threading.Lock()
172
 
173
+ # Contextvar carrying the active session's status queue. The worker thread
174
+ # inherits the value via contextvars.copy_context(), so broadcast_status()
175
+ # always lands in the right session's queue without explicit threading.
176
+ _current_status_queue: contextvars.ContextVar[Optional[queue.Queue]] = contextvars.ContextVar(
177
+ "waypoint_status_queue", default=None
178
+ )
179
 
180
 
181
  def broadcast_status(msg: str):
182
+ """Push a status message to the current session's WebSocket consumer."""
183
+ q = _current_status_queue.get()
184
+ if q is None:
185
+ return
186
  try:
187
+ q.put_nowait(msg)
188
  except queue.Full:
189
  pass
190
 
191
 
192
+ def _get_session(session_id: str) -> Optional[GameSession]:
193
+ with _sessions_lock:
194
+ return _sessions.get(session_id)
195
+
196
+
197
+ def _drop_session(session_id: str) -> Optional[GameSession]:
198
+ with _sessions_lock:
199
+ return _sessions.pop(session_id, None)
200
+
201
+
202
+ def _reap_idle_sessions():
203
+ """Background janitor: stops sessions whose worker died or that went idle."""
204
+ while True:
205
+ time.sleep(60)
206
+ now = time.time()
207
+ to_drop = []
208
+ with _sessions_lock:
209
+ for sid, sess in list(_sessions.items()):
210
+ worker_dead = sess.worker_thread is None or not sess.worker_thread.is_alive()
211
+ idle = (now - sess.last_active) > SESSION_IDLE_TIMEOUT
212
+ if worker_dead and idle:
213
+ to_drop.append(sid)
214
+ for sid in to_drop:
215
+ _sessions.pop(sid, None)
216
+ if to_drop:
217
+ print(f"Janitor reaped {len(to_drop)} idle session(s)")
218
+
219
+
220
+ threading.Thread(target=_reap_idle_sessions, daemon=True).start()
221
+
222
+
223
+ def gpu_worker_thread(seed_url, prompt, command_queue, frame_queue, stop_event, frame_times):
224
  try:
225
  gen = create_gpu_game_loop(
226
  command_queue,
 
227
  initial_seed_url=seed_url,
228
  initial_prompt=prompt,
229
  )
 
258
  print("Worker thread finished")
259
 
260
 
261
+ def create_gpu_game_loop(command_queue, initial_seed_url=None, initial_prompt="An explorable world"):
262
  @spaces.GPU(duration=120)
263
  def gpu_game_loop():
264
  broadcast_status("GPU allocated! Loading model...")
 
267
  print(f"Model loaded! (n_frames={n_frames})")
268
 
269
  broadcast_status("Loading seed image...")
270
+ if initial_seed_url:
 
 
271
  seed_image = load_seed_frame(initial_seed_url)
272
  else:
273
  seed_image = load_seed_frame(random.choice(SEED_FRAME_URLS))
 
305
  break
306
 
307
  if reset_command is not None:
308
+ if reset_command.seed_url:
 
 
309
  seed_img = load_seed_frame(reset_command.seed_url)
310
  else:
311
  seed_img = load_seed_frame(random.choice(SEED_FRAME_URLS))
 
341
 
342
 
343
  @app.api(name="start_game")
344
+ def start_game(session_id: str = "", seed_url: str = "", prompt: str = "An explorable world") -> str:
345
+ """Start a new game session for `session_id`. Returns the session_id used."""
346
+ if not session_id:
347
+ session_id = str(uuid.uuid4())
348
+
349
+ # Tear down any prior session for this client (e.g. a reload mid-game).
350
+ prior = _drop_session(session_id)
351
+ if prior is not None:
352
+ prior.stop()
353
+
354
+ # If no seed selected, pick a random one from the curated set so the
355
+ # player always lands somewhere instead of falling through to the model's
356
+ # default.
357
+ if not seed_url:
358
+ seed_url = random.choice(SEED_FRAME_URLS)
 
359
 
360
  command_queue = Queue()
361
  frame_queue = queue.Queue(maxsize=2)
362
  stop_event = threading.Event()
363
+ status_queue = queue.Queue(maxsize=20)
364
  frame_times = deque(maxlen=30)
365
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
366
  session = GameSession(
367
+ session_id=session_id,
368
  command_queue=command_queue,
369
  frame_queue=frame_queue,
370
  stop_event=stop_event,
371
+ status_queue=status_queue,
372
  frame_times=frame_times,
373
  )
374
 
375
+ with _sessions_lock:
376
+ _sessions[session_id] = session
377
 
378
+ # Bind the status queue into the contextvar so the worker thread (and the
379
+ # @spaces.GPU function it spawns) inherits it via copy_context().
380
+ token = _current_status_queue.set(status_queue)
381
+ try:
382
+ broadcast_status("Requesting GPU from ZeroGPU...")
383
+
384
+ ctx = contextvars.copy_context()
385
+ worker = threading.Thread(
386
+ target=ctx.run,
387
+ args=(gpu_worker_thread,
388
+ seed_url,
389
+ prompt or "An explorable world",
390
+ command_queue, frame_queue, stop_event, frame_times),
391
+ daemon=True,
392
+ )
393
+ session.worker_thread = worker
394
+ worker.start()
395
+ finally:
396
+ _current_status_queue.reset(token)
397
+
398
+ return session_id
399
 
400
 
401
  @app.api(name="stop_game")
402
+ def stop_game(session_id: str = "") -> str:
403
+ """Stop the active game session for the given client."""
404
+ if not session_id:
405
+ return "no_session"
406
+ session = _drop_session(session_id)
407
+ if session is not None:
408
+ session.stop()
409
  return "stopped"
410
 
411
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
412
  @app.api(name="get_worlds")
413
  def get_worlds() -> list:
414
  """Return the list of seed world URLs."""
 
416
 
417
 
418
  @app.websocket("/ws")
419
+ async def game_ws(websocket: WebSocket, session_id: str = ""):
420
  """
421
+ Real-time game WebSocket. Requires `?session_id=...` query param matching
422
+ the value passed to /start_game.
 
 
 
423
  """
424
  await websocket.accept()
425
+ if not session_id:
426
+ await websocket.send_json({"type": "error", "message": "missing session_id"})
427
+ await websocket.close(code=1008)
428
+ return
429
+
430
  loop = asyncio.get_event_loop()
431
 
432
  async def send_frames():
433
  session_ended_sent = False
434
  while True:
435
+ session = _get_session(session_id)
436
+
437
+ # Drain any status messages for this session (init progress / errors).
438
+ if session is not None:
439
+ try:
440
+ status_msg = session.status_queue.get_nowait()
441
+ if status_msg.startswith("error:"):
442
+ await websocket.send_json({"type": "error", "message": status_msg[6:]})
443
+ break
444
+ await websocket.send_json({"type": "status", "message": status_msg})
445
+ except queue.Empty:
446
+ pass
447
+ except (WebSocketDisconnect, RuntimeError):
448
  break
 
 
 
 
 
449
 
 
 
450
  if session is None:
451
  await asyncio.sleep(0.05)
452
  continue
 
469
  jpeg_bytes = buf.getvalue()
470
  header = struct.pack(">II", count, int(fps * 10))
471
  await websocket.send_bytes(header + jpeg_bytes)
472
+ session.touch()
473
  except queue.Empty:
474
  pass
475
  except (WebSocketDisconnect, RuntimeError):
 
479
  while True:
480
  try:
481
  data = await websocket.receive_json()
482
+ session = _get_session(session_id)
 
483
  if session is None:
484
  continue
485
+ session.touch()
486
  msg_type = data.get("type", "control")
487
  if msg_type == "control":
488
  buttons = set(data.get("buttons", []))
 
490
  prompt = data.get("prompt", "An explorable world")
491
  session.command_queue.put(GenerateCommand(buttons=buttons, mouse=mouse, prompt=prompt))
492
  elif msg_type == "reset":
493
+ seed_url = data.get("seed_url") or random.choice(SEED_FRAME_URLS)
494
  prompt = data.get("prompt", "An explorable world")
 
 
495
  session.command_queue.put(ResetCommand(
496
+ seed_url=seed_url,
 
497
  prompt=prompt,
498
  ))
499
  except WebSocketDisconnect:
index.html CHANGED
@@ -389,12 +389,13 @@
389
 
390
  /* --- World Selector --- */
391
  .world-grid {
392
- display: flex;
 
393
  gap: 6px;
394
- flex-wrap: wrap;
395
  }
396
  .world-thumb {
397
- width: 56px; height: 28px;
 
398
  border-radius: 4px;
399
  border: 2px solid transparent;
400
  cursor: pointer;
@@ -404,48 +405,6 @@
404
  .world-thumb:hover { border-color: var(--accent); transform: scale(1.05); }
405
  .world-thumb.selected { border-color: var(--accent); box-shadow: 0 0 8px rgba(88,166,255,0.4); }
406
 
407
- /* --- Custom Image Upload --- */
408
- .upload-area {
409
- border: 2px dashed rgba(88,166,255,0.25);
410
- border-radius: 8px;
411
- padding: 12px;
412
- text-align: center;
413
- cursor: pointer;
414
- transition: border-color 0.15s, background 0.15s;
415
- position: relative;
416
- min-height: 60px;
417
- display: flex;
418
- flex-direction: column;
419
- align-items: center;
420
- justify-content: center;
421
- gap: 4px;
422
- }
423
- .upload-area:hover { border-color: var(--accent); background: rgba(88,166,255,0.05); }
424
- .upload-area.has-image { border-style: solid; padding: 4px; }
425
- .upload-area img {
426
- max-width: 100%;
427
- max-height: 80px;
428
- object-fit: contain;
429
- border-radius: 4px;
430
- }
431
- .upload-area .placeholder { font-size: 11px; color: var(--text-dim); }
432
- .upload-area .icon { font-size: 20px; opacity: 0.5; }
433
- .upload-clear {
434
- position: absolute;
435
- top: 4px; right: 4px;
436
- background: rgba(255,107,107,0.8);
437
- border: none;
438
- color: #fff;
439
- width: 18px; height: 18px;
440
- border-radius: 50%;
441
- font-size: 11px;
442
- cursor: pointer;
443
- display: flex;
444
- align-items: center;
445
- justify-content: center;
446
- line-height: 1;
447
- }
448
-
449
  /* --- Prompt --- */
450
  #prompt-input {
451
  width: 100%;
@@ -516,7 +475,7 @@
516
  .key-wide { min-height: 34px; padding: 6px 14px; font-size: 11px; }
517
  .joystick-ring { width: 90px; height: 90px; min-width: 90px; }
518
 
519
- .world-thumb { width: 44px; height: 22px; }
520
  }
521
  </style>
522
  </head>
@@ -636,16 +595,6 @@
636
  <div class="world-grid" id="world-selector"></div>
637
  </div>
638
 
639
- <!-- Custom Image Upload -->
640
- <div class="sidebar-section">
641
- <div class="section-label">Custom Start Image</div>
642
- <div class="upload-area" id="upload-area">
643
- <div class="icon">+</div>
644
- <div class="placeholder">Click or drop image</div>
645
- </div>
646
- <input type="file" id="file-input" accept="image/*" style="display:none">
647
- </div>
648
-
649
  <!-- Prompt -->
650
  <div class="sidebar-section">
651
  <div class="section-label">World Prompt</div>
@@ -657,7 +606,13 @@
657
  import { Client } from "https://cdn.jsdelivr.net/npm/@gradio/client/dist/index.min.js";
658
 
659
  // --- State ---
 
 
 
 
 
660
  const state = {
 
661
  playing: false,
662
  capturing: false,
663
  pressedKeys: new Set(),
@@ -665,7 +620,6 @@ const state = {
665
  mouseVelocity: { x: 0, y: 0 },
666
  lastMouseMove: 0,
667
  selectedSeedUrl: "",
668
- useCustomImage: false,
669
  prompt: "An explorable world",
670
  ws: null,
671
  client: null,
@@ -705,8 +659,6 @@ const ctrlStatusText = document.getElementById("ctrl-status-text");
705
  const activeTags = document.getElementById("active-tags");
706
  const mouseXVal = document.getElementById("mouse-x-val");
707
  const mouseYVal = document.getElementById("mouse-y-val");
708
- const uploadArea = document.getElementById("upload-area");
709
- const fileInput = document.getElementById("file-input");
710
  const endedOverlay = document.getElementById("ended-overlay");
711
  const restartBtn = document.getElementById("restart-btn");
712
 
@@ -720,7 +672,7 @@ async function initClient() {
720
  // --- WebSocket ---
721
  function connectWS() {
722
  const proto = location.protocol === "https:" ? "wss:" : "ws:";
723
- const ws = new WebSocket(`${proto}//${location.host}/ws`);
724
  ws.binaryType = "arraybuffer";
725
 
726
  ws.onmessage = (e) => {
@@ -1027,81 +979,45 @@ joystick.addEventListener("touchend", (e) => { e.preventDefault(); resetJoystick
1027
  joystick.addEventListener("touchcancel", (e) => { e.preventDefault(); resetJoystick(); }, { passive: false });
1028
 
1029
  // --- World Selector ---
1030
- const SEED_URLS = [
1031
- "https://huggingface.co/spaces/Overworld/waypoint-1-small/resolve/main/starter_18.png",
1032
- "https://huggingface.co/spaces/Overworld/waypoint-1-small/resolve/main/starter_9.png",
1033
- "https://huggingface.co/spaces/Overworld/waypoint-1-small/resolve/main/starter_22.png",
1034
- "https://huggingface.co/spaces/Overworld/waypoint-1-small/resolve/main/starter_14.png",
1035
- "https://huggingface.co/spaces/Overworld/waypoint-1-small/resolve/main/starter_21.png",
 
 
 
 
 
 
 
 
 
 
 
1036
  ];
 
1037
 
1038
  SEED_URLS.forEach((url, i) => {
1039
  const img = document.createElement("img");
1040
  img.src = url;
1041
  img.className = "world-thumb";
1042
- img.title = `World ${i + 1}`;
1043
  img.addEventListener("click", () => {
1044
  document.querySelectorAll(".world-thumb").forEach(t => t.classList.remove("selected"));
1045
  img.classList.add("selected");
1046
  state.selectedSeedUrl = url;
1047
- state.useCustomImage = false;
1048
- // If playing, reset to this world
1049
  if (state.playing && state.ws?.readyState === WebSocket.OPEN) {
1050
  state.ws.send(JSON.stringify({ type: "reset", seed_url: url, prompt: state.prompt }));
 
 
1051
  }
1052
  });
1053
  worldSelector.appendChild(img);
1054
  });
1055
 
1056
- // --- Custom Image Upload ---
1057
- uploadArea.addEventListener("click", () => fileInput.click());
1058
- uploadArea.addEventListener("dragover", (e) => { e.preventDefault(); uploadArea.style.borderColor = "var(--accent)"; });
1059
- uploadArea.addEventListener("dragleave", () => { uploadArea.style.borderColor = ""; });
1060
- uploadArea.addEventListener("drop", (e) => {
1061
- e.preventDefault();
1062
- uploadArea.style.borderColor = "";
1063
- const file = e.dataTransfer?.files?.[0];
1064
- if (file && file.type.startsWith("image/")) handleImageUpload(file);
1065
- });
1066
- fileInput.addEventListener("change", () => {
1067
- if (fileInput.files?.[0]) handleImageUpload(fileInput.files[0]);
1068
- });
1069
-
1070
- async function handleImageUpload(file) {
1071
- // Show preview
1072
- const reader = new FileReader();
1073
- reader.onload = (e) => {
1074
- uploadArea.innerHTML = `<img src="${e.target.result}" alt="Custom seed"><button class="upload-clear" id="clear-upload">&times;</button>`;
1075
- uploadArea.classList.add("has-image");
1076
- document.getElementById("clear-upload").addEventListener("click", (ev) => {
1077
- ev.stopPropagation();
1078
- clearUpload();
1079
- });
1080
- };
1081
- reader.readAsDataURL(file);
1082
-
1083
- // Upload to server
1084
- const formData = new FormData();
1085
- formData.append("file", file);
1086
- try {
1087
- await fetch("/upload_seed_image", { method: "POST", body: formData });
1088
- state.useCustomImage = true;
1089
- // Deselect world thumbs
1090
- document.querySelectorAll(".world-thumb").forEach(t => t.classList.remove("selected"));
1091
- state.selectedSeedUrl = "";
1092
- } catch (err) {
1093
- console.error("Upload failed:", err);
1094
- }
1095
- }
1096
-
1097
- async function clearUpload() {
1098
- uploadArea.innerHTML = '<div class="icon">+</div><div class="placeholder">Click or drop image</div>';
1099
- uploadArea.classList.remove("has-image");
1100
- state.useCustomImage = false;
1101
- fileInput.value = "";
1102
- try { await fetch("/clear_seed_image", { method: "POST" }); } catch {}
1103
- }
1104
-
1105
  // --- Prompt ---
1106
  promptInput.addEventListener("input", () => {
1107
  state.prompt = promptInput.value || "An explorable world";
@@ -1122,6 +1038,7 @@ async function startGame() {
1122
 
1123
  try {
1124
  await state.client.predict("/start_game", {
 
1125
  seed_url: state.selectedSeedUrl || "",
1126
  prompt: state.prompt,
1127
  });
@@ -1138,7 +1055,7 @@ restartBtn.addEventListener("click", startGame);
1138
 
1139
  stopBtn.addEventListener("click", async () => {
1140
  if (!state.client) return;
1141
- try { await state.client.predict("/stop_game"); } catch {}
1142
  setPlaying(false);
1143
  endedOverlay.classList.add("hidden");
1144
  setStatus("idle", "Idle");
@@ -1153,7 +1070,6 @@ resetBtn.addEventListener("click", () => {
1153
  state.ws.send(JSON.stringify({
1154
  type: "reset",
1155
  seed_url: state.selectedSeedUrl || "",
1156
- use_custom_image: state.useCustomImage,
1157
  prompt: state.prompt,
1158
  }));
1159
  }
 
389
 
390
  /* --- World Selector --- */
391
  .world-grid {
392
+ display: grid;
393
+ grid-template-columns: repeat(3, 1fr);
394
  gap: 6px;
 
395
  }
396
  .world-thumb {
397
+ width: 100%;
398
+ aspect-ratio: 16 / 9;
399
  border-radius: 4px;
400
  border: 2px solid transparent;
401
  cursor: pointer;
 
405
  .world-thumb:hover { border-color: var(--accent); transform: scale(1.05); }
406
  .world-thumb.selected { border-color: var(--accent); box-shadow: 0 0 8px rgba(88,166,255,0.4); }
407
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
408
  /* --- Prompt --- */
409
  #prompt-input {
410
  width: 100%;
 
475
  .key-wide { min-height: 34px; padding: 6px 14px; font-size: 11px; }
476
  .joystick-ring { width: 90px; height: 90px; min-width: 90px; }
477
 
478
+ .world-thumb { aspect-ratio: 16 / 9; }
479
  }
480
  </style>
481
  </head>
 
595
  <div class="world-grid" id="world-selector"></div>
596
  </div>
597
 
 
 
 
 
 
 
 
 
 
 
598
  <!-- Prompt -->
599
  <div class="sidebar-section">
600
  <div class="section-label">World Prompt</div>
 
606
  import { Client } from "https://cdn.jsdelivr.net/npm/@gradio/client/dist/index.min.js";
607
 
608
  // --- State ---
609
+ // One session_id per browser tab — keeps every server-side queue, status
610
+ // stream, and uploaded seed isolated from other concurrent users.
611
+ const SESSION_ID = (crypto.randomUUID && crypto.randomUUID()) ||
612
+ (`s-${Date.now().toString(36)}-${Math.random().toString(36).slice(2, 10)}`);
613
+
614
  const state = {
615
+ sessionId: SESSION_ID,
616
  playing: false,
617
  capturing: false,
618
  pressedKeys: new Set(),
 
620
  mouseVelocity: { x: 0, y: 0 },
621
  lastMouseMove: 0,
622
  selectedSeedUrl: "",
 
623
  prompt: "An explorable world",
624
  ws: null,
625
  client: null,
 
659
  const activeTags = document.getElementById("active-tags");
660
  const mouseXVal = document.getElementById("mouse-x-val");
661
  const mouseYVal = document.getElementById("mouse-y-val");
 
 
662
  const endedOverlay = document.getElementById("ended-overlay");
663
  const restartBtn = document.getElementById("restart-btn");
664
 
 
672
  // --- WebSocket ---
673
  function connectWS() {
674
  const proto = location.protocol === "https:" ? "wss:" : "ws:";
675
+ const ws = new WebSocket(`${proto}//${location.host}/ws?session_id=${encodeURIComponent(state.sessionId)}`);
676
  ws.binaryType = "arraybuffer";
677
 
678
  ws.onmessage = (e) => {
 
979
  joystick.addEventListener("touchcancel", (e) => { e.preventDefault(); resetJoystick(); }, { passive: false });
980
 
981
  // --- World Selector ---
982
+ const SEED_BASE = "https://github.com/Overworldai/Biome/blob/main/seeds";
983
+ const SEED_NAMES = [
984
+ "abandoned_city_runner.jpg",
985
+ "alien_command_center.jpg",
986
+ "ancient_standing_stones.jpg",
987
+ "default.jpg",
988
+ "desert_outpost_hangar.jpg",
989
+ "enchanted_swamp_torch.jpg",
990
+ "frozen_crystal_cavern.jpg",
991
+ "frozen_valley_sniper.jpg",
992
+ "highland_castle_loch.jpg",
993
+ "mountain_ruins_gun.jpg",
994
+ "shattered_cockpit_nebula.jpg",
995
+ "shipwreck_shore_revolver.jpg",
996
+ "snowy_forest_tracks.jpg",
997
+ "stormy_countryside_rifle.jpg",
998
+ "sunken_city_depths.jpg",
999
  ];
1000
+ const SEED_URLS = SEED_NAMES.map((name) => `${SEED_BASE}/${name}?raw=true`);
1001
 
1002
  SEED_URLS.forEach((url, i) => {
1003
  const img = document.createElement("img");
1004
  img.src = url;
1005
  img.className = "world-thumb";
1006
+ img.title = SEED_NAMES[i].replace(/\.jpg$/, "").replace(/_/g, " ");
1007
  img.addEventListener("click", () => {
1008
  document.querySelectorAll(".world-thumb").forEach(t => t.classList.remove("selected"));
1009
  img.classList.add("selected");
1010
  state.selectedSeedUrl = url;
1011
+ // If playing, reset to this world; otherwise kick off a new game.
 
1012
  if (state.playing && state.ws?.readyState === WebSocket.OPEN) {
1013
  state.ws.send(JSON.stringify({ type: "reset", seed_url: url, prompt: state.prompt }));
1014
+ } else if (!state.playing && !startBtn.disabled) {
1015
+ startGame();
1016
  }
1017
  });
1018
  worldSelector.appendChild(img);
1019
  });
1020
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1021
  // --- Prompt ---
1022
  promptInput.addEventListener("input", () => {
1023
  state.prompt = promptInput.value || "An explorable world";
 
1038
 
1039
  try {
1040
  await state.client.predict("/start_game", {
1041
+ session_id: state.sessionId,
1042
  seed_url: state.selectedSeedUrl || "",
1043
  prompt: state.prompt,
1044
  });
 
1055
 
1056
  stopBtn.addEventListener("click", async () => {
1057
  if (!state.client) return;
1058
+ try { await state.client.predict("/stop_game", { session_id: state.sessionId }); } catch {}
1059
  setPlaying(false);
1060
  endedOverlay.classList.add("hidden");
1061
  setStatus("idle", "Idle");
 
1070
  state.ws.send(JSON.stringify({
1071
  type: "reset",
1072
  seed_url: state.selectedSeedUrl || "",
 
1073
  prompt: state.prompt,
1074
  }));
1075
  }