multimodalart HF Staff committed on
Commit
8d8b1fb
·
verified ·
1 Parent(s): 22052ef

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +331 -140
app.py CHANGED
@@ -17,8 +17,12 @@ except ImportError:
17
 
18
  import base64
19
  import os
 
20
  import random
21
- from dataclasses import dataclass
 
 
 
22
  from io import BytesIO
23
  from multiprocessing import Queue
24
  from pathlib import Path
@@ -54,6 +58,20 @@ MODEL_ID = os.environ.get("MODEL_PATH", "diffusers-internal-dev/world-engine-mod
54
  pipe = ModularPipeline.from_pretrained(MODEL_ID, trust_remote_code=True, revision="aot-compatible")
55
  pipe.load_components(["transformer", "vae"], trust_remote_code=True, revision="aot-compatible", torch_dtype=torch.bfloat16)
56
  pipe.load_components(["text_encoder", "tokenizer"], trust_remote_code=True, torch_dtype=torch.bfloat16)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
  SEED_FRAME_URLS = [
59
  "https://gist.github.com/user-attachments/assets/5d91c49a-2ae9-418f-99c0-e93ae387e1de",
@@ -137,9 +155,65 @@ class StopCommand:
137
  pass
138
 
139
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  # --- GPU Session Generator ---
141
  def create_gpu_game_loop(command_queue: Queue, initial_seed_image=None, initial_seed_url=None, initial_prompt="An explorable world"):
142
  """Create GPU game loop generator with closure over command_queue."""
 
143
 
144
  @spaces.GPU(duration=90)
145
  def gpu_game_loop():
@@ -147,29 +221,20 @@ def create_gpu_game_loop(command_queue: Queue, initial_seed_image=None, initial_
147
  Generator that keeps GPU allocated and processes commands.
148
  Yields (frame, frame_count) tuples.
149
  """
150
- pipe.to("cuda")
151
- pipe.blocks.sub_blocks['before_denoise'].sub_blocks['setup_kv_cache']._setup_kv_cache(pipe.transformer, pipe.device, torch.bfloat16)
152
- aoti_load_(
153
- pipe.transformer,
154
- "diffusers-internal-dev/world-engine-aot",
155
- "transformer-fp8.pt2",
156
- "transformer-fp8-constants.pt"
157
- )
158
- aoti_load_(
159
- pipe.vae.decoder,
160
- "diffusers-internal-dev/world-engine-aot",
161
- "decoder.pt2",
162
- "decoder-constants.pt"
163
- )
164
  n_frames = pipe.transformer.config.n_frames
165
  print(f"Model loaded! (n_frames={n_frames})")
 
166
 
167
  # Initialize state with provided seed or random
168
  if initial_seed_image is not None:
 
169
  seed_image = initial_seed_image.resize((640, 360), Image.BILINEAR)
170
  elif initial_seed_url is not None:
 
171
  seed_image = load_seed_frame(initial_seed_url)
172
  else:
 
173
  seed_image = load_seed_frame(random.choice(SEED_FRAME_URLS))
174
 
175
  state = pipe(
@@ -188,70 +253,89 @@ def create_gpu_game_loop(command_queue: Queue, initial_seed_image=None, initial_
188
  # Yield initial frame
189
  yield (frame, frame_count)
190
 
191
- # Main loop - process commands and yield frames
192
- while True:
193
- try:
194
- # Non-blocking get with short timeout for responsiveness
195
- command = command_queue.get(timeout=0.005)
196
- except:
197
- # No command, continue idle
198
- command = None
199
-
200
- if command is None:
201
- continue
202
 
203
- if isinstance(command, StopCommand):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
  print("Stop command received, ending GPU session")
205
  break
206
 
207
- elif isinstance(command, ResetCommand):
208
- print("Reset command received")
209
- if command.seed_image is not None:
210
- seed_img = command.seed_image.resize((640, 360), Image.BILINEAR)
211
- elif command.seed_url:
212
- seed_img = load_seed_frame(command.seed_url)
 
 
 
213
  else:
 
214
  seed_img = load_seed_frame(random.choice(SEED_FRAME_URLS))
215
 
216
  state = pipe(
217
- prompt=command.prompt,
218
  image=seed_img,
219
  button=set(),
220
  mouse=(0.0, 0.0),
221
  output_type="pil",
222
  )
223
  frame_count = 1
 
224
  frame = state.values.get("images")
225
  yield (frame, frame_count)
 
226
 
227
- elif isinstance(command, GenerateCommand):
228
- # Generate next frame
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
  state = pipe(
230
- state,
231
- prompt=command.prompt,
232
- button=command.buttons,
233
- mouse=command.mouse,
234
- image=None,
235
  output_type="pil",
236
  )
237
- frame_count += 1
238
  frame = state.values.get("images")
239
 
240
- # Auto-reset near end of context
241
- if frame_count >= n_frames - 2:
242
- print(f"Auto-reset at frame {frame_count}")
243
- seed_img = load_seed_frame(random.choice(SEED_FRAME_URLS))
244
- state = pipe(
245
- prompt=command.prompt,
246
- image=seed_img,
247
- button=set(),
248
- mouse=(0.0, 0.0),
249
- output_type="pil",
250
- )
251
- frame_count = 1
252
- frame = state.values.get("images")
253
-
254
- yield (frame, frame_count)
255
 
256
  print("GPU session ended")
257
 
@@ -681,6 +765,9 @@ def create_app():
681
 
682
  # Selected seed URL from examples
683
  selected_seed_url = gr.State(None)
 
 
 
684
 
685
  gr.Markdown("""
686
  # 🌍 Waypoint 1 Small
@@ -742,7 +829,9 @@ def create_app():
742
  lines=2,
743
  )
744
 
745
- frame_display = gr.Number(label="Frame", value=0, interactive=False)
 
 
746
 
747
  # --- Event Handlers ---
748
 
@@ -754,26 +843,25 @@ def create_app():
754
 
755
  def on_gallery_start(state, evt: gr.SelectData, uploaded_image, prompt):
756
  """Handle gallery selection - start/restart game with selected world."""
 
757
  if evt.index is None or evt.index >= len(SEED_FRAME_URLS):
758
  # No valid selection, do nothing
759
  yield (state, None, 0, None, 0, gr.update(), gr.update())
760
  return
761
-
762
  selected_url = SEED_FRAME_URLS[evt.index]
763
-
764
  # If game is running, stop it first
765
- if len(state) > 0:
766
- gen, command_queue = state
767
- command_queue.put(StopCommand())
768
- try:
769
- while True:
770
- next(gen)
771
- except StopIteration:
772
- pass
773
-
774
  # Show info about controls
775
  gr.Info("Controls locked! Press ESC to release mouse/keyboard capture.", duration=5)
776
-
777
  # Show loading state
778
  loading_img = create_loading_image(text="Generating World ...")
779
  yield (
@@ -785,21 +873,42 @@ def create_app():
785
  gr.update(interactive=False),
786
  gr.update(interactive=False),
787
  )
788
-
789
  # Start new game with selected world
790
  command_queue = Queue()
 
 
 
791
  gen = create_gpu_game_loop(
792
  command_queue,
793
  initial_seed_image=None,
794
  initial_seed_url=selected_url,
795
  initial_prompt=prompt or "An explorable world"
796
  )
797
-
798
  # Get initial frame
799
  frame, frame_count = next(gen)
800
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
801
  yield (
802
- (gen, command_queue),
803
  frame,
804
  frame_count,
805
  frame,
@@ -809,10 +918,15 @@ def create_app():
809
  )
810
 
811
  def on_start(selected_url, uploaded_image, prompt):
812
- """Start GPU session - generator that shows loading then first frame."""
 
 
 
 
 
813
  # Show info about controls
814
  gr.Info("Controls locked! Press ESC to release mouse/keyboard capture.", duration=5)
815
-
816
  # Show loading state immediately
817
  loading_img = create_loading_image(text="Generating World ...")
818
  yield (
@@ -826,17 +940,32 @@ def create_app():
826
  )
827
 
828
  # Determine seed image/url
829
- # Priority: uploaded image > selected from gallery > random
830
  seed_image = None
831
  seed_url = None
832
 
833
- if uploaded_image is not None:
 
 
 
 
 
 
 
 
834
  seed_image = uploaded_image
 
835
  elif selected_url is not None:
836
  seed_url = selected_url
 
 
 
837
  # else: random will be chosen in create_gpu_game_loop
838
 
839
  command_queue = Queue()
 
 
 
840
  gen = create_gpu_game_loop(
841
  command_queue,
842
  initial_seed_image=seed_image,
@@ -844,11 +973,29 @@ def create_app():
844
  initial_prompt=prompt or "An explorable world"
845
  )
846
 
847
- # Get initial frame
848
  frame, frame_count = next(gen)
849
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
850
  yield (
851
- (gen, command_queue), # session_state
852
  frame, # latest_frame
853
  frame_count, # latest_frame_count
854
  frame, # video_output
@@ -858,20 +1005,19 @@ def create_app():
858
  )
859
 
860
  def on_stop(state):
861
- """Stop GPU session."""
862
- if len(state) == 0:
863
  return ((), None, 0, None, 0,
864
  gr.update(interactive=True), gr.update(interactive=False))
865
 
866
- gen, command_queue = state
867
- command_queue.put(StopCommand())
 
 
868
 
869
- # Drain generator
870
- try:
871
- while True:
872
- next(gen)
873
- except StopIteration:
874
- pass
875
 
876
  return (
877
  (),
@@ -883,32 +1029,42 @@ def create_app():
883
  gr.update(interactive=False),
884
  )
885
 
886
- def on_generate_tick(state, controls, prompt, current_frame, current_count):
887
- """Called by timer - send generate command and get next frame."""
888
- if len(state) == 0:
889
- return current_frame, current_count, current_frame, current_count
890
 
891
- gen, command_queue = state
892
 
893
- # Send generate command
894
  buttons = set(controls.get("buttons", []))
895
  mouse = (controls.get("mouse_x", 0.0), controls.get("mouse_y", 0.0))
896
- command_queue.put(GenerateCommand(buttons=buttons, mouse=mouse, prompt=prompt))
897
 
898
- # Get next frame
899
  try:
900
- frame, frame_count = next(gen)
901
- return frame, frame_count, frame, frame_count
902
- except StopIteration:
903
- return current_frame, current_count, current_frame, current_count
 
904
 
905
  def on_reset(state, selected_url, uploaded_image, prompt):
906
  """Reset world with new seed - starts game if not running."""
 
 
 
 
 
 
 
 
 
907
  # If game is not running, start it
908
- if len(state) == 0:
909
  # Show info about controls
910
  gr.Info("Controls locked! Press ESC to release mouse/keyboard capture.", duration=5)
911
-
912
  # Show loading state
913
  loading_img = create_loading_image(text="Generating World ...")
914
  yield (
@@ -920,29 +1076,54 @@ def create_app():
920
  gr.update(interactive=False),
921
  gr.update(interactive=False),
922
  )
923
-
924
  # Priority: uploaded image > selected from gallery > random
925
  seed_image = None
926
  seed_url = None
927
-
928
- if uploaded_image is not None:
929
  seed_image = uploaded_image
 
930
  elif selected_url is not None:
931
  seed_url = selected_url
932
-
 
 
 
933
  command_queue = Queue()
 
 
 
934
  gen = create_gpu_game_loop(
935
  command_queue,
936
  initial_seed_image=seed_image,
937
  initial_seed_url=seed_url,
938
  initial_prompt=prompt or "An explorable world"
939
  )
940
-
941
  # Get initial frame
942
  frame, frame_count = next(gen)
943
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
944
  yield (
945
- (gen, command_queue),
946
  frame,
947
  frame_count,
948
  frame,
@@ -953,40 +1134,33 @@ def create_app():
953
  return
954
 
955
  # Game is running - reset it
956
- gen, command_queue = state
957
 
958
  # Priority: uploaded image > selected from gallery > random
959
  seed_image = None
960
  seed_url = None
961
 
962
- if uploaded_image is not None:
963
  seed_image = uploaded_image
 
964
  elif selected_url is not None:
965
  seed_url = selected_url
 
 
 
966
 
967
- command_queue.put(ResetCommand(seed_image=seed_image, seed_url=seed_url, prompt=prompt))
968
 
969
- try:
970
- frame, frame_count = next(gen)
971
- yield (
972
- state,
973
- frame,
974
- frame_count,
975
- frame,
976
- frame_count,
977
- gr.update(),
978
- gr.update(),
979
- )
980
- except StopIteration:
981
- yield (
982
- state,
983
- None,
984
- 0,
985
- None,
986
- 0,
987
- gr.update(),
988
- gr.update(),
989
- )
990
 
991
  def on_controller_change(value):
992
  """Update current controls state."""
@@ -1014,7 +1188,7 @@ def create_app():
1014
 
1015
  start_btn.click(
1016
  fn=on_start,
1017
- inputs=[selected_seed_url, seed_image_upload, prompt_input],
1018
  outputs=[session_state, latest_frame, latest_frame_count,
1019
  video_output, frame_display, start_btn, stop_btn],
1020
  js="() => { setTimeout(() => { if (window.worldEngineRequestPointerLock) window.worldEngineRequestPointerLock(); }, 500); }",
@@ -1029,7 +1203,7 @@ def create_app():
1029
 
1030
  reset_btn.click(
1031
  fn=on_reset,
1032
- inputs=[session_state, selected_seed_url, seed_image_upload, current_prompt],
1033
  outputs=[session_state, latest_frame, latest_frame_count,
1034
  video_output, frame_display, start_btn, stop_btn],
1035
  js="() => { setTimeout(() => { if (window.worldEngineRequestPointerLock) window.worldEngineRequestPointerLock(); }, 500); }",
@@ -1037,13 +1211,30 @@ def create_app():
1037
 
1038
  control_input.change(fn=on_controller_change, inputs=[control_input], outputs=[current_controls])
1039
  prompt_input.change(fn=on_prompt_change, inputs=[prompt_input], outputs=[current_prompt])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1040
 
1041
  # Timer for continuous generation
1042
- timer = gr.Timer(value=1/30) # 15 FPS target
1043
  timer.tick(
1044
  fn=on_generate_tick,
1045
- inputs=[session_state, current_controls, current_prompt, latest_frame, latest_frame_count],
1046
- outputs=[latest_frame, latest_frame_count, video_output, frame_display],
1047
  )
1048
 
1049
  # Pointer lock JS - also allows clicking the game window
 
17
 
18
  import base64
19
  import os
20
+ import queue
21
  import random
22
+ import threading
23
+ import time
24
+ from collections import deque
25
+ from dataclasses import dataclass, field
26
  from io import BytesIO
27
  from multiprocessing import Queue
28
  from pathlib import Path
 
58
  pipe = ModularPipeline.from_pretrained(MODEL_ID, trust_remote_code=True, revision="aot-compatible")
59
  pipe.load_components(["transformer", "vae"], trust_remote_code=True, revision="aot-compatible", torch_dtype=torch.bfloat16)
60
  pipe.load_components(["text_encoder", "tokenizer"], trust_remote_code=True, torch_dtype=torch.bfloat16)
61
+ pipe.to("cuda")
62
+ pipe.blocks.sub_blocks['before_denoise'].sub_blocks['setup_kv_cache']._setup_kv_cache(pipe.transformer, pipe.device, torch.bfloat16)
63
+ aoti_load_(
64
+ pipe.transformer,
65
+ "diffusers-internal-dev/world-engine-aot",
66
+ "transformer-fp8.pt2",
67
+ "transformer-fp8-constants.pt"
68
+ )
69
+ aoti_load_(
70
+ pipe.vae.decoder,
71
+ "diffusers-internal-dev/world-engine-aot",
72
+ "decoder.pt2",
73
+ "decoder-constants.pt"
74
+ )
75
 
76
  SEED_FRAME_URLS = [
77
  "https://gist.github.com/user-attachments/assets/5d91c49a-2ae9-418f-99c0-e93ae387e1de",
 
155
  pass
156
 
157
 
158
+ # --- Session State ---
159
@dataclass
class GameSession:
    """Per-user game session with background worker thread.

    Groups the objects created when a game is started so they can be kept
    together in a single state value and torn down together on stop.
    """

    # Queue the UI handlers push command objects onto.
    command_queue: Queue
    # Thread-safe queue for output frames.
    frame_queue: queue.Queue
    # Background thread that consumes the generator.
    worker_thread: threading.Thread
    # Set to ask the worker thread to leave its loop.
    stop_event: threading.Event
    # The GPU generator (optional, defaults to None).
    generator: object = None
    # Track last 30 frame times for FPS; each session gets its own deque.
    frame_times: deque = field(default_factory=lambda: deque([], 30))
168
+
169
+
170
def gpu_worker_thread(gen, command_queue, frame_queue, stop_event, frame_times):
    """Consume the GPU generator and publish only the newest frame.

    Runs independently of Gradio's timer so that the timer callback can read
    frames without blocking. The loop ends when ``stop_event`` is set, when
    the generator is exhausted, or when producing a frame raises.

    Args:
        gen: Iterator yielding ``(frame, frame_count)`` tuples; ``next(gen)``
            blocks while the frame is being generated.
        command_queue: Not read here; accepted so the thread target keeps the
            same argument list its callers pass.
        frame_queue: ``queue.Queue``-like object that receives
            ``(frame, frame_count, fps)`` tuples, with ``fps`` rounded to one
            decimal. Stale entries are discarded before each put.
        stop_event: Object with ``is_set()``; checked before every frame.
        frame_times: Mutable sequence (a bounded ``deque`` at the call sites)
            that gets one ``time.time()`` timestamp appended per frame.
    """
    try:
        while not stop_event.is_set():
            # Only the generator call lives under the StopIteration handler,
            # so unrelated errors are never mistaken for normal exhaustion.
            try:
                frame, frame_count = next(gen)
            except StopIteration:
                print("Generator exhausted, worker thread ending")
                break

            # Track frame generation time for FPS.
            frame_times.append(time.time())

            # FPS over the retained window: intervals divided by elapsed time.
            fps = 0.0
            if len(frame_times) >= 2:
                elapsed = frame_times[-1] - frame_times[0]
                if elapsed > 0:
                    fps = (len(frame_times) - 1) / elapsed

            # Keep only the latest frame: throw away whatever is still queued.
            while True:
                try:
                    frame_queue.get_nowait()
                except queue.Empty:
                    break

            try:
                frame_queue.put_nowait((frame, frame_count, round(fps, 1)))
            except queue.Full:
                # Something refilled the queue between drain and put. Losing
                # one frame is harmless; the next iteration publishes a newer
                # one, so do not let this end the worker.
                pass
    except Exception as e:
        # Top-level boundary of the thread: report and finish.
        print(f"Worker thread error: {e}")
    finally:
        print("Worker thread finished")
211
+
212
+
213
  # --- GPU Session Generator ---
214
  def create_gpu_game_loop(command_queue: Queue, initial_seed_image=None, initial_seed_url=None, initial_prompt="An explorable world"):
215
  """Create GPU game loop generator with closure over command_queue."""
216
+ print(f"create_gpu_game_loop: initial_seed_image={type(initial_seed_image)}, initial_seed_url={initial_seed_url}")
217
 
218
  @spaces.GPU(duration=90)
219
  def gpu_game_loop():
 
221
  Generator that keeps GPU allocated and processes commands.
222
  Yields (frame, frame_count) tuples.
223
  """
224
+
 
 
 
 
 
 
 
 
 
 
 
 
 
225
  n_frames = pipe.transformer.config.n_frames
226
  print(f"Model loaded! (n_frames={n_frames})")
227
+ print(f"gpu_game_loop: initial_seed_image={type(initial_seed_image)}, initial_seed_url={initial_seed_url}")
228
 
229
  # Initialize state with provided seed or random
230
  if initial_seed_image is not None:
231
+ print(f"gpu_game_loop init: Using initial_seed_image {initial_seed_image.size if hasattr(initial_seed_image, 'size') else type(initial_seed_image)}")
232
  seed_image = initial_seed_image.resize((640, 360), Image.BILINEAR)
233
  elif initial_seed_url is not None:
234
+ print(f"gpu_game_loop init: Using initial_seed_url {initial_seed_url}")
235
  seed_image = load_seed_frame(initial_seed_url)
236
  else:
237
+ print("gpu_game_loop init: Using random seed")
238
  seed_image = load_seed_frame(random.choice(SEED_FRAME_URLS))
239
 
240
  state = pipe(
 
253
  # Yield initial frame
254
  yield (frame, frame_count)
255
 
256
+ # Track current input state (updated by commands)
257
+ current_buttons = set()
258
+ current_mouse = (0.0, 0.0)
259
+ current_prompt = initial_prompt
 
 
 
 
 
 
 
260
 
261
+ # Main loop - generate frames continuously, sample latest input
262
+ while True:
263
+ # Drain command queue - get all pending commands (non-blocking)
264
+ stop_requested = False
265
+ reset_command = None
266
+ while True:
267
+ try:
268
+ command = command_queue.get_nowait()
269
+ if isinstance(command, StopCommand):
270
+ stop_requested = True
271
+ break
272
+ elif isinstance(command, ResetCommand):
273
+ reset_command = command
274
+ elif isinstance(command, GenerateCommand):
275
+ # Update current input state with latest command
276
+ current_buttons = command.buttons
277
+ current_mouse = command.mouse
278
+ current_prompt = command.prompt
279
+ except:
280
+ break # Queue empty
281
+
282
+ if stop_requested:
283
  print("Stop command received, ending GPU session")
284
  break
285
 
286
+ # Handle reset if requested
287
+ if reset_command is not None:
288
+ print(f"Reset command received: seed_image={type(reset_command.seed_image)}, seed_url={reset_command.seed_url}")
289
+ if reset_command.seed_image is not None:
290
+ print(f"Using seed_image from command: {reset_command.seed_image.size if hasattr(reset_command.seed_image, 'size') else 'unknown'}")
291
+ seed_img = reset_command.seed_image.resize((640, 360), Image.BILINEAR)
292
+ elif reset_command.seed_url:
293
+ print(f"Using seed_url from command: {reset_command.seed_url}")
294
+ seed_img = load_seed_frame(reset_command.seed_url)
295
  else:
296
+ print("Using random seed from command")
297
  seed_img = load_seed_frame(random.choice(SEED_FRAME_URLS))
298
 
299
  state = pipe(
300
+ prompt=reset_command.prompt,
301
  image=seed_img,
302
  button=set(),
303
  mouse=(0.0, 0.0),
304
  output_type="pil",
305
  )
306
  frame_count = 1
307
+ current_prompt = reset_command.prompt
308
  frame = state.values.get("images")
309
  yield (frame, frame_count)
310
+ continue
311
 
312
+ # Generate next frame with current input state (ALWAYS generates)
313
+ state = pipe(
314
+ state,
315
+ prompt=current_prompt,
316
+ button=current_buttons,
317
+ mouse=current_mouse,
318
+ image=None,
319
+ output_type="pil",
320
+ )
321
+ frame_count += 1
322
+ frame = state.values.get("images")
323
+
324
+ # Auto-reset near end of context
325
+ if frame_count >= n_frames - 2:
326
+ print(f"Auto-reset at frame {frame_count}")
327
+ seed_img = load_seed_frame(random.choice(SEED_FRAME_URLS))
328
  state = pipe(
329
+ prompt=current_prompt,
330
+ image=seed_img,
331
+ button=set(),
332
+ mouse=(0.0, 0.0),
 
333
  output_type="pil",
334
  )
335
+ frame_count = 1
336
  frame = state.values.get("images")
337
 
338
+ yield (frame, frame_count)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
339
 
340
  print("GPU session ended")
341
 
 
765
 
766
  # Selected seed URL from examples
767
  selected_seed_url = gr.State(None)
768
+
769
+ # Store uploaded image in state (workaround for Gradio component value issues)
770
+ uploaded_image_state = gr.State(None)
771
 
772
  gr.Markdown("""
773
  # 🌍 Waypoint 1 Small
 
829
  lines=2,
830
  )
831
 
832
+ with gr.Row():
833
+ frame_display = gr.Number(label="Frame", value=0, interactive=False)
834
+ fps_display = gr.Number(label="FPS", value=0.0, interactive=False)
835
 
836
  # --- Event Handlers ---
837
 
 
843
 
844
  def on_gallery_start(state, evt: gr.SelectData, uploaded_image, prompt):
845
  """Handle gallery selection - start/restart game with selected world."""
846
+ print(f"on_gallery_start CALLED: evt.index={evt.index}", flush=True)
847
  if evt.index is None or evt.index >= len(SEED_FRAME_URLS):
848
  # No valid selection, do nothing
849
  yield (state, None, 0, None, 0, gr.update(), gr.update())
850
  return
851
+
852
  selected_url = SEED_FRAME_URLS[evt.index]
853
+
854
  # If game is running, stop it first
855
+ if state and isinstance(state, GameSession):
856
+ session = state
857
+ session.stop_event.set()
858
+ session.command_queue.put(StopCommand())
859
+ if session.worker_thread.is_alive():
860
+ session.worker_thread.join(timeout=2.0)
861
+
 
 
862
  # Show info about controls
863
  gr.Info("Controls locked! Press ESC to release mouse/keyboard capture.", duration=5)
864
+
865
  # Show loading state
866
  loading_img = create_loading_image(text="Generating World ...")
867
  yield (
 
873
  gr.update(interactive=False),
874
  gr.update(interactive=False),
875
  )
876
+
877
  # Start new game with selected world
878
  command_queue = Queue()
879
+ frame_queue = queue.Queue(maxsize=2)
880
+ stop_event = threading.Event()
881
+
882
  gen = create_gpu_game_loop(
883
  command_queue,
884
  initial_seed_image=None,
885
  initial_seed_url=selected_url,
886
  initial_prompt=prompt or "An explorable world"
887
  )
888
+
889
  # Get initial frame
890
  frame, frame_count = next(gen)
891
+
892
+ # Start worker thread
893
+ frame_times = deque(maxlen=30)
894
+ worker = threading.Thread(
895
+ target=gpu_worker_thread,
896
+ args=(gen, command_queue, frame_queue, stop_event, frame_times),
897
+ daemon=True
898
+ )
899
+ worker.start()
900
+
901
+ session = GameSession(
902
+ command_queue=command_queue,
903
+ frame_queue=frame_queue,
904
+ worker_thread=worker,
905
+ stop_event=stop_event,
906
+ generator=gen,
907
+ frame_times=frame_times,
908
+ )
909
+
910
  yield (
911
+ session,
912
  frame,
913
  frame_count,
914
  frame,
 
918
  )
919
 
920
  def on_start(selected_url, uploaded_image, prompt):
921
+ """Start GPU session - creates background worker thread for non-blocking frames."""
922
+ print(f"on_start CALLED:", flush=True)
923
+ print(f" uploaded_image (from state) type: {type(uploaded_image)}", flush=True)
924
+ print(f" uploaded_image is PIL: {isinstance(uploaded_image, Image.Image) if uploaded_image else False}", flush=True)
925
+ print(f" selected_url: {selected_url}", flush=True)
926
+
927
  # Show info about controls
928
  gr.Info("Controls locked! Press ESC to release mouse/keyboard capture.", duration=5)
929
+
930
  # Show loading state immediately
931
  loading_img = create_loading_image(text="Generating World ...")
932
  yield (
 
940
  )
941
 
942
  # Determine seed image/url
943
+ # Priority: uploaded image (from state) > selected from gallery > random
944
  seed_image = None
945
  seed_url = None
946
 
947
+ # Check if uploaded_image is a valid PIL Image
948
+ is_pil_image = isinstance(uploaded_image, Image.Image)
949
+ has_uploaded_image = uploaded_image is not None and is_pil_image
950
+
951
+ print(f"on_start decision:", flush=True)
952
+ print(f" is_pil_image: {is_pil_image}", flush=True)
953
+ print(f" has_uploaded_image: {has_uploaded_image}", flush=True)
954
+
955
+ if has_uploaded_image:
956
  seed_image = uploaded_image
957
+ print(f"on_start: Using uploaded image: {seed_image.size}", flush=True)
958
  elif selected_url is not None:
959
  seed_url = selected_url
960
+ print(f"on_start: Using selected URL: {seed_url}", flush=True)
961
+ else:
962
+ print("on_start: Using random seed", flush=True)
963
  # else: random will be chosen in create_gpu_game_loop
964
 
965
  command_queue = Queue()
966
+ frame_queue = queue.Queue(maxsize=2) # Thread-safe output queue
967
+ stop_event = threading.Event()
968
+
969
  gen = create_gpu_game_loop(
970
  command_queue,
971
  initial_seed_image=seed_image,
 
973
  initial_prompt=prompt or "An explorable world"
974
  )
975
 
976
+ # Get initial frame synchronously (needed to show first frame)
977
  frame, frame_count = next(gen)
978
 
979
+ # Start worker thread to consume generator in background
980
+ frame_times = deque(maxlen=30)
981
+ worker = threading.Thread(
982
+ target=gpu_worker_thread,
983
+ args=(gen, command_queue, frame_queue, stop_event, frame_times),
984
+ daemon=True
985
+ )
986
+ worker.start()
987
+
988
+ session = GameSession(
989
+ command_queue=command_queue,
990
+ frame_queue=frame_queue,
991
+ worker_thread=worker,
992
+ stop_event=stop_event,
993
+ generator=gen,
994
+ frame_times=frame_times,
995
+ )
996
+
997
  yield (
998
+ session, # session_state
999
  frame, # latest_frame
1000
  frame_count, # latest_frame_count
1001
  frame, # video_output
 
1005
  )
1006
 
1007
  def on_stop(state):
1008
+ """Stop GPU session and cleanup worker thread."""
1009
+ if not state or not isinstance(state, GameSession):
1010
  return ((), None, 0, None, 0,
1011
  gr.update(interactive=True), gr.update(interactive=False))
1012
 
1013
+ session = state
1014
+ # Signal worker to stop
1015
+ session.stop_event.set()
1016
+ session.command_queue.put(StopCommand())
1017
 
1018
+ # Wait for worker thread to finish (with timeout)
1019
+ if session.worker_thread.is_alive():
1020
+ session.worker_thread.join(timeout=2.0)
 
 
 
1021
 
1022
  return (
1023
  (),
 
1029
  gr.update(interactive=False),
1030
  )
1031
 
1032
def on_generate_tick(state, controls, prompt, current_frame, current_count, current_fps):
    """Timer callback: forward the current input and fetch a frame without blocking.

    Returns a 5-tuple ``(frame, frame_count, frame, frame_count, fps)``.
    When ``state`` is not an active ``GameSession`` the incoming frame and
    count are echoed back with an FPS of ``0.0``. When no new frame is queued
    the incoming frame, count and FPS are echoed back unchanged.
    """
    if not (state and isinstance(state, GameSession)):
        return current_frame, current_count, current_frame, current_count, 0.0

    # Push the latest input snapshot to the session (non-blocking).
    pressed = set(controls.get("buttons", []))
    pointer = (controls.get("mouse_x", 0.0), controls.get("mouse_y", 0.0))
    state.command_queue.put(GenerateCommand(buttons=pressed, mouse=pointer, prompt=prompt))

    # Take a frame if one is ready; otherwise keep showing the previous one.
    try:
        latest = state.frame_queue.get_nowait()
    except queue.Empty:
        return current_frame, current_count, current_frame, current_count, current_fps

    frame, frame_count, fps = latest
    return frame, frame_count, frame, frame_count, fps
1051
 
1052
  def on_reset(state, selected_url, uploaded_image, prompt):
1053
  """Reset world with new seed - starts game if not running."""
1054
+ print(f"on_reset CALLED:", flush=True)
1055
+ print(f" uploaded_image (from state) type: {type(uploaded_image)}", flush=True)
1056
+ print(f" uploaded_image is PIL: {isinstance(uploaded_image, Image.Image) if uploaded_image else False}", flush=True)
1057
+ print(f" selected_url: {selected_url}", flush=True)
1058
+
1059
+ # Check if uploaded_image is a valid PIL Image
1060
+ is_pil_image = isinstance(uploaded_image, Image.Image)
1061
+ has_uploaded_image = uploaded_image is not None and is_pil_image
1062
+
1063
  # If game is not running, start it
1064
+ if not state or not isinstance(state, GameSession):
1065
  # Show info about controls
1066
  gr.Info("Controls locked! Press ESC to release mouse/keyboard capture.", duration=5)
1067
+
1068
  # Show loading state
1069
  loading_img = create_loading_image(text="Generating World ...")
1070
  yield (
 
1076
  gr.update(interactive=False),
1077
  gr.update(interactive=False),
1078
  )
1079
+
1080
  # Priority: uploaded image > selected from gallery > random
1081
  seed_image = None
1082
  seed_url = None
1083
+
1084
+ if has_uploaded_image:
1085
  seed_image = uploaded_image
1086
+ print(f"on_reset (start): Using uploaded image: {seed_image.size}", flush=True)
1087
  elif selected_url is not None:
1088
  seed_url = selected_url
1089
+ print(f"on_reset (start): Using selected URL: {seed_url}", flush=True)
1090
+ else:
1091
+ print("on_reset (start): Using random seed", flush=True)
1092
+
1093
  command_queue = Queue()
1094
+ frame_queue = queue.Queue(maxsize=2)
1095
+ stop_event = threading.Event()
1096
+
1097
  gen = create_gpu_game_loop(
1098
  command_queue,
1099
  initial_seed_image=seed_image,
1100
  initial_seed_url=seed_url,
1101
  initial_prompt=prompt or "An explorable world"
1102
  )
1103
+
1104
  # Get initial frame
1105
  frame, frame_count = next(gen)
1106
+
1107
+ # Start worker thread
1108
+ frame_times = deque(maxlen=30)
1109
+ worker = threading.Thread(
1110
+ target=gpu_worker_thread,
1111
+ args=(gen, command_queue, frame_queue, stop_event, frame_times),
1112
+ daemon=True
1113
+ )
1114
+ worker.start()
1115
+
1116
+ session = GameSession(
1117
+ command_queue=command_queue,
1118
+ frame_queue=frame_queue,
1119
+ worker_thread=worker,
1120
+ stop_event=stop_event,
1121
+ generator=gen,
1122
+ frame_times=frame_times,
1123
+ )
1124
+
1125
  yield (
1126
+ session,
1127
  frame,
1128
  frame_count,
1129
  frame,
 
1134
  return
1135
 
1136
  # Game is running - reset it
1137
+ session = state
1138
 
1139
  # Priority: uploaded image > selected from gallery > random
1140
  seed_image = None
1141
  seed_url = None
1142
 
1143
+ if has_uploaded_image:
1144
  seed_image = uploaded_image
1145
+ print(f"on_reset (running): Using uploaded image: {seed_image.size}", flush=True)
1146
  elif selected_url is not None:
1147
  seed_url = selected_url
1148
+ print(f"on_reset (running): Using selected URL: {seed_url}", flush=True)
1149
+ else:
1150
+ print("on_reset (running): Using random seed", flush=True)
1151
 
1152
+ session.command_queue.put(ResetCommand(seed_image=seed_image, seed_url=seed_url, prompt=prompt))
1153
 
1154
+ # Just return current state - next timer tick will pick up the reset frame
1155
+ yield (
1156
+ state,
1157
+ None,
1158
+ 0,
1159
+ None,
1160
+ 0,
1161
+ gr.update(),
1162
+ gr.update(),
1163
+ )
 
 
 
 
 
 
 
 
 
 
 
1164
 
1165
  def on_controller_change(value):
1166
  """Update current controls state."""
 
1188
 
1189
  start_btn.click(
1190
  fn=on_start,
1191
+ inputs=[selected_seed_url, uploaded_image_state, prompt_input],
1192
  outputs=[session_state, latest_frame, latest_frame_count,
1193
  video_output, frame_display, start_btn, stop_btn],
1194
  js="() => { setTimeout(() => { if (window.worldEngineRequestPointerLock) window.worldEngineRequestPointerLock(); }, 500); }",
 
1203
 
1204
  reset_btn.click(
1205
  fn=on_reset,
1206
+ inputs=[session_state, selected_seed_url, uploaded_image_state, current_prompt],
1207
  outputs=[session_state, latest_frame, latest_frame_count,
1208
  video_output, frame_display, start_btn, stop_btn],
1209
  js="() => { setTimeout(() => { if (window.worldEngineRequestPointerLock) window.worldEngineRequestPointerLock(); }, 500); }",
 
1211
 
1212
  control_input.change(fn=on_controller_change, inputs=[control_input], outputs=[current_controls])
1213
  prompt_input.change(fn=on_prompt_change, inputs=[prompt_input], outputs=[current_prompt])
1214
+
1215
+ # Store uploaded image in state and clear gallery selection
1216
def on_image_upload(image):
    """Handle a change of the seed-image upload component.

    Returns a pair ``(stored_image, selected_url_update)``. A PIL image is
    stored and the second value is ``None``, which clears the gallery
    selection. Anything else clears the stored image and leaves the
    selection untouched via ``gr.update()``.
    """
    print(f"on_image_upload: image type={type(image)}, is PIL={isinstance(image, Image.Image) if image else False}", flush=True)

    # None is never an instance of Image.Image, so one check covers both cases.
    if not isinstance(image, Image.Image):
        print(f"on_image_upload: Clearing stored image", flush=True)
        return None, gr.update()  # Clear stored image, no change to URL

    print(f"on_image_upload: Storing uploaded image {image.size}", flush=True)
    return image, None  # Store image, clear selected_seed_url
1225
+
1226
+ seed_image_upload.change(
1227
+ fn=on_image_upload,
1228
+ inputs=[seed_image_upload],
1229
+ outputs=[uploaded_image_state, selected_seed_url],
1230
+ )
1231
 
1232
  # Timer for continuous generation
1233
+ timer = gr.Timer(value=1/30)
1234
  timer.tick(
1235
  fn=on_generate_tick,
1236
+ inputs=[session_state, current_controls, current_prompt, latest_frame, latest_frame_count, fps_display],
1237
+ outputs=[latest_frame, latest_frame_count, video_output, frame_display, fps_display],
1238
  )
1239
 
1240
  # Pointer lock JS - also allows clicking the game window