ZENLLC commited on
Commit
a8e6497
·
verified ·
1 Parent(s): 5872f77

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +784 -454
app.py CHANGED
@@ -9,33 +9,31 @@ from PIL import Image, ImageDraw
9
  import gradio as gr
10
 
11
  # ============================================================
12
- # ChronoSandbox — Agent Timeline Lab (Deterministic, Inspectable)
13
- # - Multi-agent gridworld
14
- # - First-person pseudo-3D raycast view for selected agent
15
- # - Global truth map + per-agent belief maps (fog-of-war memory)
16
- # - AutoRun animation, time dilation, rewind scrubber
17
- # - Branching timelines (fork from any previous step)
18
- # - Click-to-edit map tiles
19
  #
20
- # Compatible with older Gradio versions by avoiding fn_kwargs in .click()
21
  # ============================================================
22
 
23
  # -----------------------------
24
- # World / render config
25
  # -----------------------------
26
  GRID_W, GRID_H = 21, 15
27
- TILE = 22 # top-down pixels per tile
28
 
29
  VIEW_W, VIEW_H = 640, 360
30
  RAY_W = 320
31
  FOV_DEG = 78
32
  MAX_DEPTH = 20
33
 
34
- # 0=E,1=S,2=W,3=N
35
  DIRS = [(1, 0), (0, 1), (-1, 0), (0, -1)]
36
  ORI_DEG = [0, 90, 180, 270]
37
 
38
- # Tile types
39
  EMPTY = 0
40
  WALL = 1
41
  FOOD = 2
@@ -52,37 +50,69 @@ TILE_NAMES = {
52
  TELE: "Teleporter",
53
  }
54
 
55
- # Palette (simple + inspectable)
 
 
 
 
 
56
  SKY = np.array([14, 16, 26], dtype=np.uint8)
57
  FLOOR_NEAR = np.array([24, 26, 40], dtype=np.uint8)
58
  FLOOR_FAR = np.array([10, 11, 18], dtype=np.uint8)
59
  WALL_BASE = np.array([210, 210, 225], dtype=np.uint8)
60
  WALL_SIDE = np.array([150, 150, 170], dtype=np.uint8)
 
61
 
62
- AGENT_COLORS = {
63
- "Predator": (255, 120, 90),
64
- "Prey": (120, 255, 160),
65
- "Scout": (120, 190, 255),
66
- }
67
 
68
  # -----------------------------
69
- # Deterministic RNG helper
70
  # -----------------------------
71
  def rng_for(seed: int, step: int, stream: int = 0) -> np.random.Generator:
72
  mix = (seed * 1_000_003) ^ (step * 9_999_937) ^ (stream * 97_531)
73
  return np.random.default_rng(mix & 0xFFFFFFFFFFFFFFFF)
74
 
75
  # -----------------------------
76
- # State definitions
77
  # -----------------------------
78
  @dataclass
79
  class Agent:
80
  name: str
81
  x: int
82
  y: int
83
- ori: int # 0..3
84
  energy: int = 100
85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  @dataclass
87
  class WorldState:
88
  seed: int
@@ -91,24 +121,35 @@ class WorldState:
91
  agents: Dict[str, Agent]
92
  controlled: str
93
  pov: str
94
- autorun: bool
95
- speed_hz: float
96
  overlay: bool
97
- event_log: List[str]
98
  caught: bool
99
  branches: Dict[str, int]
100
 
 
 
 
 
 
 
 
 
 
 
101
  @dataclass
102
  class Snapshot:
103
  step: int
104
  agents: Dict[str, Dict]
105
  grid: List[List[int]]
106
- event_log_tail: List[str]
107
  caught: bool
 
 
108
 
 
 
 
109
  def default_grid() -> List[List[int]]:
110
  g = [[EMPTY for _ in range(GRID_W)] for _ in range(GRID_H)]
111
- # Border walls
112
  for x in range(GRID_W):
113
  g[0][x] = WALL
114
  g[GRID_H - 1][x] = WALL
@@ -116,12 +157,10 @@ def default_grid() -> List[List[int]]:
116
  g[y][0] = WALL
117
  g[y][GRID_W - 1] = WALL
118
 
119
- # Interior structure
120
  for x in range(4, 17):
121
  g[7][x] = WALL
122
  g[7][10] = DOOR
123
 
124
- # Items
125
  g[3][4] = FOOD
126
  g[11][15] = FOOD
127
  g[4][14] = NOISE
@@ -136,6 +175,7 @@ def init_state(seed: int) -> WorldState:
136
  "Prey": Agent("Prey", 18, 12, 2, 100),
137
  "Scout": Agent("Scout", 10, 3, 1, 100),
138
  }
 
139
  return WorldState(
140
  seed=seed,
141
  step=0,
@@ -143,25 +183,28 @@ def init_state(seed: int) -> WorldState:
143
  agents=agents,
144
  controlled="Predator",
145
  pov="Predator",
146
- autorun=False,
147
- speed_hz=8.0,
148
  overlay=False,
149
- event_log=["Initialized world."],
150
  caught=False,
151
  branches={"main": 0},
 
 
 
 
 
 
152
  )
153
 
154
  # -----------------------------
155
- # Belief memory
156
  # -----------------------------
157
  def init_belief() -> Dict[str, np.ndarray]:
158
  b = {}
159
- for name in ["Predator", "Prey", "Scout"]:
160
- b[name] = -1 * np.ones((GRID_H, GRID_W), dtype=np.int16)
161
  return b
162
 
163
  # -----------------------------
164
- # Movement + collision
165
  # -----------------------------
166
  def in_bounds(x: int, y: int) -> bool:
167
  return 0 <= x < GRID_W and 0 <= y < GRID_H
@@ -169,37 +212,10 @@ def in_bounds(x: int, y: int) -> bool:
169
  def is_blocking(tile: int) -> bool:
170
  return tile == WALL
171
 
172
- def move_forward(state: WorldState, a: Agent) -> None:
173
- dx, dy = DIRS[a.ori]
174
- nx, ny = a.x + dx, a.y + dy
175
- if not in_bounds(nx, ny):
176
- return
177
- if is_blocking(state.grid[ny][nx]):
178
- return
179
- if state.grid[ny][nx] == DOOR:
180
- state.grid[ny][nx] = EMPTY
181
- state.event_log.append(f"t={state.step}: {a.name} opened a door.")
182
- a.x, a.y = nx, ny
183
 
184
- if state.grid[ny][nx] == TELE:
185
- teles = [(x, y) for y in range(GRID_H) for x in range(GRID_W) if state.grid[y][x] == TELE]
186
- if len(teles) >= 2:
187
- teles_sorted = sorted(teles)
188
- idx = teles_sorted.index((nx, ny))
189
- dest = teles_sorted[(idx + 1) % len(teles_sorted)]
190
- a.x, a.y = dest
191
- state.event_log.append(f"t={state.step}: {a.name} teleported.")
192
-
193
- def turn_left(a: Agent) -> None:
194
- a.ori = (a.ori - 1) % 4
195
-
196
- def turn_right(a: Agent) -> None:
197
- a.ori = (a.ori + 1) % 4
198
-
199
- # -----------------------------
200
- # LOS + FOV visibility
201
- # -----------------------------
202
- def los_clear(grid: List[List[int]], x0: int, y0: int, x1: int, y1: int) -> bool:
203
  dx = abs(x1 - x0)
204
  dy = abs(y1 - y0)
205
  sx = 1 if x0 < x1 else -1
@@ -220,7 +236,7 @@ def los_clear(grid: List[List[int]], x0: int, y0: int, x1: int, y1: int) -> bool
220
  err += dx
221
  y += sy
222
 
223
- def within_fov(observer: Agent, tx: int, ty: int, fov_deg: float = 78.0) -> bool:
224
  dx = tx - observer.x
225
  dy = ty - observer.y
226
  if dx == 0 and dy == 0:
@@ -231,10 +247,54 @@ def within_fov(observer: Agent, tx: int, ty: int, fov_deg: float = 78.0) -> bool
231
  return abs(diff) <= (fov_deg / 2)
232
 
233
  def visible(observer: Agent, target: Agent, grid: List[List[int]]) -> bool:
234
- return within_fov(observer, target.x, target.y, FOV_DEG) and los_clear(grid, observer.x, observer.y, target.x, target.y)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
 
236
  # -----------------------------
237
- # Raycast pseudo-3D render
238
  # -----------------------------
239
  def raycast_view(state: WorldState, observer: Agent) -> np.ndarray:
240
  img = np.zeros((VIEW_H, VIEW_W, 3), dtype=np.uint8)
@@ -257,7 +317,8 @@ def raycast_view(state: WorldState, observer: Agent) -> np.ndarray:
257
  cos_a = math.cos(ray_ang)
258
 
259
  depth = 0.0
260
- hit_side = 0
 
261
 
262
  while depth < MAX_DEPTH:
263
  depth += 0.05
@@ -265,16 +326,16 @@ def raycast_view(state: WorldState, observer: Agent) -> np.ndarray:
265
  ty = int(oy + sin_a * depth)
266
  if not in_bounds(tx, ty):
267
  break
268
-
269
  tile = state.grid[ty][tx]
270
  if tile == WALL:
271
- hit_side = 1 if abs(cos_a) > abs(sin_a) else 0
 
272
  break
273
  if tile == DOOR:
274
- hit_side = 2
275
  break
276
 
277
- if depth >= MAX_DEPTH:
278
  continue
279
 
280
  depth *= math.cos(ray_ang - math.radians(ORI_DEG[observer.ori]))
@@ -284,12 +345,10 @@ def raycast_view(state: WorldState, observer: Agent) -> np.ndarray:
284
  y0 = max(0, VIEW_H // 2 - proj_h // 2)
285
  y1 = min(VIEW_H - 1, VIEW_H // 2 + proj_h // 2)
286
 
287
- if hit_side == 0:
288
- col = WALL_BASE.copy()
289
- elif hit_side == 1:
290
- col = WALL_SIDE.copy()
291
  else:
292
- col = np.array([180, 210, 255], dtype=np.uint8)
293
 
294
  dim = max(0.25, 1.0 - (depth / MAX_DEPTH))
295
  col = (col * dim).astype(np.uint8)
@@ -298,8 +357,9 @@ def raycast_view(state: WorldState, observer: Agent) -> np.ndarray:
298
  x1 = int((rx + 1) * (VIEW_W / RAY_W))
299
  img[y0:y1, x0:x1] = col
300
 
301
- for other_name, other in state.agents.items():
302
- if other_name == observer.name:
 
303
  continue
304
  if visible(observer, other, state.grid):
305
  dx = other.x - observer.x
@@ -316,7 +376,7 @@ def raycast_view(state: WorldState, observer: Agent) -> np.ndarray:
316
  y1 = min(VIEW_H - 1, y_mid + h // 2)
317
  x0 = max(0, sx - w // 2)
318
  x1 = min(VIEW_W - 1, sx + w // 2)
319
- col = AGENT_COLORS.get(other_name, (255, 200, 120))
320
  img[y0:y1, x0:x1] = np.array(col, dtype=np.uint8)
321
 
322
  if state.overlay:
@@ -326,9 +386,6 @@ def raycast_view(state: WorldState, observer: Agent) -> np.ndarray:
326
 
327
  return img
328
 
329
- # -----------------------------
330
- # Top-down render
331
- # -----------------------------
332
  def render_topdown(grid: np.ndarray, agents: Dict[str, Agent], title: str, show_agents: bool = True) -> Image.Image:
333
  w = grid.shape[1] * TILE
334
  h = grid.shape[0] * TILE
@@ -366,10 +423,10 @@ def render_topdown(grid: np.ndarray, agents: Dict[str, Agent], title: str, show_
366
  draw.line([0, yy, w, yy], fill=(12, 14, 22))
367
 
368
  if show_agents:
369
- for name, a in agents.items():
370
  cx = a.x * TILE + TILE // 2
371
  cy = a.y * TILE + 28 + TILE // 2
372
- col = AGENT_COLORS.get(name, (220, 220, 220))
373
  r = TILE // 3
374
  draw.ellipse([cx - r, cy - r, cx + r, cy + r], fill=col)
375
  dx, dy = DIRS[a.ori]
@@ -380,9 +437,105 @@ def render_topdown(grid: np.ndarray, agents: Dict[str, Agent], title: str, show_
380
  return im
381
 
382
  # -----------------------------
383
- # Policies (explicit + deterministic)
384
  # -----------------------------
385
- def predator_policy(state: WorldState, step: int) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
386
  pred = state.agents["Predator"]
387
  prey = state.agents["Prey"]
388
  if visible(pred, prey, state.grid):
@@ -396,10 +549,10 @@ def predator_policy(state: WorldState, step: int) -> str:
396
  if diff > 10:
397
  return "R"
398
  return "F"
399
- r = rng_for(state.seed, step, stream=1)
400
- return r.choice(["F", "L", "R", "F", "F"])
401
 
402
- def prey_policy(state: WorldState, step: int) -> str:
403
  prey = state.agents["Prey"]
404
  pred = state.agents["Predator"]
405
  if visible(prey, pred, state.grid):
@@ -414,60 +567,62 @@ def prey_policy(state: WorldState, step: int) -> str:
414
  if diff_away > 10:
415
  return "R"
416
  return "F"
417
- for turn in [0, -1, 1, 2]:
418
- ori = (prey.ori + turn) % 4
419
- dx, dy = DIRS[ori]
420
- nx, ny = prey.x + dx, prey.y + dy
421
- if in_bounds(nx, ny) and state.grid[ny][nx] == FOOD:
422
- if turn == 0:
423
- return "F"
424
- if turn == -1:
425
- return "L"
426
- if turn == 1:
427
- return "R"
428
- return "R"
429
- r = rng_for(state.seed, step, stream=2)
430
- return r.choice(["F", "L", "R", "F"])
431
 
432
- def scout_policy(state: WorldState, step: int) -> str:
433
- scout = state.agents["Scout"]
434
- pred = state.agents["Predator"]
435
- if los_clear(state.grid, scout.x, scout.y, pred.x, pred.y):
436
- dist = abs(scout.x - pred.x) + abs(scout.y - pred.y)
437
- if dist <= 3:
438
- return "R"
439
- r = rng_for(state.seed, step, stream=3)
440
- return r.choice(["F", "L", "R", "F"])
441
- dx = pred.x - scout.x
442
- dy = pred.y - scout.y
443
- ang = (math.degrees(math.atan2(dy, dx)) % 360)
444
- facing = ORI_DEG[scout.ori]
445
- diff = (ang - facing + 540) % 360 - 180
446
- if diff < -10:
447
- return "L"
448
- if diff > 10:
449
- return "R"
450
- return "F"
451
-
452
- # -----------------------------
453
- # Simulation step
454
- # -----------------------------
455
- def apply_action(state: WorldState, agent_name: str, action: str) -> None:
456
- a = state.agents[agent_name]
457
- if action == "L":
458
- turn_left(a)
459
- elif action == "R":
460
- turn_right(a)
461
- elif action == "F":
462
- move_forward(state, a)
463
 
464
- def consume_tiles(state: WorldState) -> None:
465
- prey = state.agents["Prey"]
466
- tile = state.grid[prey.y][prey.x]
467
- if tile == FOOD:
468
- prey.energy = min(200, prey.energy + 35)
469
- state.grid[prey.y][prey.x] = EMPTY
470
- state.event_log.append(f"t={state.step}: Prey ate food (+energy).")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
471
 
472
  def check_catch(state: WorldState) -> None:
473
  pred = state.agents["Predator"]
@@ -476,50 +631,220 @@ def check_catch(state: WorldState) -> None:
476
  state.caught = True
477
  state.event_log.append(f"t={state.step}: CAUGHT.")
478
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
479
  def tick(state: WorldState, manual_action: Optional[str] = None) -> None:
480
  if state.caught:
481
  return
482
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
483
  if manual_action:
484
- apply_action(state, state.controlled, manual_action)
485
-
486
- step = state.step
487
- if state.autorun and not manual_action:
488
- if state.controlled == "Predator":
489
- act = predator_policy(state, step)
490
- elif state.controlled == "Prey":
491
- act = prey_policy(state, step)
492
- else:
493
- act = scout_policy(state, step)
494
- apply_action(state, state.controlled, act)
495
 
496
- for name in ["Predator", "Prey", "Scout"]:
497
- if name == state.controlled:
 
498
  continue
499
- if name == "Predator":
500
- act = predator_policy(state, step)
501
- elif name == "Prey":
502
- act = prey_policy(state, step)
503
- else:
504
- act = scout_policy(state, step)
505
- apply_action(state, name, act)
506
 
507
- consume_tiles(state)
 
 
 
 
 
508
  check_catch(state)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
509
  state.step += 1
510
 
511
  # -----------------------------
512
- # History
513
  # -----------------------------
514
- MAX_HISTORY = 3000
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
515
 
516
  def snapshot_of(state: WorldState) -> Snapshot:
517
  return Snapshot(
518
  step=state.step,
519
  agents={k: asdict(v) for k, v in state.agents.items()},
520
  grid=[row[:] for row in state.grid],
521
- event_log_tail=state.event_log[-12:],
522
  caught=state.caught,
 
 
523
  )
524
 
525
  def restore_into(state: WorldState, snap: Snapshot) -> None:
@@ -528,70 +853,82 @@ def restore_into(state: WorldState, snap: Snapshot) -> None:
528
  for k, d in snap.agents.items():
529
  state.agents[k] = Agent(**d)
530
  state.caught = snap.caught
531
- state.event_log.append(f"Jumped to t={snap.step} (rewind).")
532
 
533
  # -----------------------------
534
- # Belief updates
535
  # -----------------------------
536
- def update_belief_for_agent(state: WorldState, belief: np.ndarray, agent: Agent) -> None:
537
- belief[agent.y, agent.x] = state.grid[agent.y][agent.x]
538
- base = math.radians(ORI_DEG[agent.ori])
539
- half = math.radians(FOV_DEG / 2)
540
- rays = 33 if agent.name != "Scout" else 45
 
 
 
 
 
 
 
 
 
541
 
542
- for i in range(rays):
543
- t = i / (rays - 1)
544
- ang = base + (t * 2 - 1) * half
545
- sin_a, cos_a = math.sin(ang), math.cos(ang)
546
- ox, oy = agent.x + 0.5, agent.y + 0.5
547
- depth = 0.0
548
- while depth < MAX_DEPTH:
549
- depth += 0.2
550
- tx = int(ox + cos_a * depth)
551
- ty = int(oy + sin_a * depth)
552
- if not in_bounds(tx, ty):
553
- break
554
- belief[ty, tx] = state.grid[ty][tx]
555
- if state.grid[ty][tx] == WALL:
556
- break
557
 
558
- # -----------------------------
559
- # Views + UI helpers
560
- # -----------------------------
561
- def build_views(state: WorldState, beliefs: Dict[str, np.ndarray]) -> Tuple[np.ndarray, Image.Image, Image.Image, Image.Image, str, str]:
562
- pov_agent = state.agents[state.pov]
563
 
564
- for name, a in state.agents.items():
565
- update_belief_for_agent(state, beliefs[name], a)
 
 
 
 
 
 
 
 
 
566
 
567
- pov_img = raycast_view(state, pov_agent)
 
 
 
 
 
568
 
 
569
  truth_np = np.array(state.grid, dtype=np.int16)
570
- truth_img = render_topdown(truth_np, state.agents, f"Truth Map — t={state.step} seed={state.seed}", show_agents=True)
571
 
572
  ctrl = state.controlled
573
  other = "Prey" if ctrl == "Predator" else "Predator"
574
- ctrl_img = render_topdown(beliefs[ctrl], state.agents, f"{ctrl} Belief (Fog-of-War)", show_agents=True)
575
- other_img = render_topdown(beliefs[other], state.agents, f"{other} Belief (Fog-of-War)", show_agents=True)
576
 
 
577
  pred = state.agents["Predator"]
578
  prey = state.agents["Prey"]
579
  scout = state.agents["Scout"]
580
 
581
  status = (
582
- f"Controlled: {state.controlled} | POV: {state.pov} | "
583
- f"AutoRun: {state.autorun} @ {state.speed_hz:.2f} Hz | "
584
- f"Caught: {state.caught}\n"
585
- f"Pred({pred.x},{pred.y}) ori={pred.ori} | "
586
- f"Prey({prey.x},{prey.y}) ori={prey.ori} energy={prey.energy} | "
587
- f"Scout({scout.x},{scout.y}) ori={scout.ori}"
588
  )
589
- log = "\n".join(state.event_log[-14:])
590
- return pov_img, truth_img, ctrl_img, other_img, status, log
 
591
 
592
  def grid_click_to_tile(evt: gr.SelectData, selected_tile: int, state: WorldState) -> WorldState:
593
  x_px, y_px = evt.index
594
- y_px = y_px - 28
595
  if y_px < 0:
596
  return state
597
  gx = int(x_px // TILE)
@@ -601,308 +938,301 @@ def grid_click_to_tile(evt: gr.SelectData, selected_tile: int, state: WorldState
601
  if gx == 0 or gy == 0 or gx == GRID_W - 1 or gy == GRID_H - 1:
602
  return state
603
  state.grid[gy][gx] = selected_tile
604
- state.event_log.append(f"t={state.step}: Edited tile ({gx},{gy}) -> {TILE_NAMES.get(selected_tile, selected_tile)}.")
605
  return state
606
 
607
- def export_run(state: WorldState, history: List[Snapshot]) -> str:
608
- payload = {
609
- "seed": state.seed,
610
- "current_step": state.step,
611
- "controlled": state.controlled,
612
- "pov": state.pov,
613
- "autorun": state.autorun,
614
- "speed_hz": state.speed_hz,
615
- "overlay": state.overlay,
616
- "branches": state.branches,
617
- "history": [asdict(s) for s in history],
618
- }
619
- return json.dumps(payload, indent=2)
620
-
621
- def import_run(txt: str) -> Tuple[WorldState, List[Snapshot], Dict[str, np.ndarray], int]:
622
- data = json.loads(txt)
623
- st = init_state(int(data["seed"]))
624
- st.controlled = data.get("controlled", "Predator")
625
- st.pov = data.get("pov", st.controlled)
626
- st.autorun = bool(data.get("autorun", False))
627
- st.speed_hz = float(data.get("speed_hz", 8.0))
628
- st.overlay = bool(data.get("overlay", False))
629
- st.branches = dict(data.get("branches", {"main": 0}))
630
-
631
- hist = [Snapshot(**s) for s in data.get("history", [])]
632
- bel = init_belief()
633
-
634
- r_idx = min(len(hist) - 1, len(hist) - 1 if hist else 0)
635
- if hist:
636
- restore_into(st, hist[-1])
637
- st.event_log.append("Imported run.")
638
- return st, hist, bel, r_idx
639
-
640
  # -----------------------------
641
- # Gradio app
642
  # -----------------------------
643
- with gr.Blocks(title="ChronoSandbox — Agent Timeline Lab") as demo:
644
  gr.Markdown(
645
- "## ChronoSandbox — Agent Timeline Lab\n"
646
- "Deterministic multi-agent POV sandbox with **time dilation, rewind, and branching**.\n"
647
- "Explicit rules, replayable runs."
648
  )
649
 
650
- st = gr.State(init_state(seed=1337))
651
- history = gr.State([snapshot_of(init_state(seed=1337))])
652
  beliefs = gr.State(init_belief())
653
- rewind_index = gr.State(0)
654
 
655
  with gr.Row():
656
- pov_img = gr.Image(label="First-Person POV (Pseudo-3D)", type="numpy", width=VIEW_W, height=VIEW_H)
657
  with gr.Column():
658
- status = gr.Textbox(label="Status", lines=3)
659
- log = gr.Textbox(label="Event Log", lines=14)
 
660
 
661
  with gr.Row():
662
  truth = gr.Image(label="Truth Map (click to edit tiles)", type="pil")
663
- belief_a = gr.Image(label="Belief A", type="pil")
664
- belief_b = gr.Image(label="Belief B", type="pil")
665
 
666
  with gr.Row():
667
  with gr.Column(scale=2):
668
- gr.Markdown("### Controls")
669
  with gr.Row():
670
- btn_L = gr.Button("Turn Left (L)")
671
- btn_F = gr.Button("Forward (F)")
672
- btn_R = gr.Button("Turn Right (R)")
673
  with gr.Row():
674
- toggle_control = gr.Button("Toggle Controlled Agent")
675
- toggle_pov = gr.Button("Toggle POV Camera")
676
- btn_step = gr.Button("Tick (Single Step)")
677
  with gr.Row():
678
- autorun = gr.Checkbox(False, label="AutoRun")
679
- overlay = gr.Checkbox(False, label="Overlay (reticle)")
680
- speed = gr.Slider(0.25, 32.0, value=8.0, step=0.25, label="Speed (Hz) — time dilation")
 
681
  tile_pick = gr.Radio(
682
  choices=[(TILE_NAMES[k], k) for k in [EMPTY, WALL, FOOD, NOISE, DOOR, TELE]],
683
  value=WALL,
684
- label="Click-edit tile type"
685
  )
686
- with gr.Column(scale=2):
687
- gr.Markdown("### Time Travel")
688
- rewind = gr.Slider(0, 0, value=0, step=1, label="Rewind Scrubber (history index)")
689
- btn_jump = gr.Button("Jump to Rewind Index")
690
- btn_branch = gr.Button("Branch From Current (fork timeline)")
691
- branch_name = gr.Textbox(value="branch_1", label="Branch name")
692
- gr.Markdown("### Import / Export")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
693
  export_box = gr.Textbox(label="Export JSON", lines=10)
694
- btn_export = gr.Button("Export Run")
 
695
  import_box = gr.Textbox(label="Import JSON", lines=10)
696
- btn_import = gr.Button("Import Run")
697
-
698
- timer = gr.Timer(0.12)
699
 
700
- def refresh(state: WorldState, hist: List[Snapshot], bel: Dict[str, np.ndarray], r_idx: int):
701
  r_max = max(0, len(hist) - 1)
702
- r_idx = max(0, min(int(r_idx), r_max))
703
- pov_np, truth_im, a_im, b_im, stxt, ltxt = build_views(state, bel)
704
  return (
705
- pov_np,
706
- truth_im,
707
- a_im,
708
- b_im,
709
- stxt,
710
- ltxt,
711
- gr.update(maximum=r_max, value=r_idx),
712
- r_idx
713
  )
714
 
715
- def do_action(state: WorldState, hist: List[Snapshot], bel: Dict[str, np.ndarray], r_idx: int, act: str):
716
- tick(state, manual_action=act)
717
  hist.append(snapshot_of(state))
718
  if len(hist) > MAX_HISTORY:
719
  hist.pop(0)
720
- r_idx = len(hist) - 1
721
- return refresh(state, hist, bel, r_idx) + (state, hist, bel, r_idx)
 
 
 
 
 
 
 
 
 
722
 
723
- def do_tick(state: WorldState, hist: List[Snapshot], bel: Dict[str, np.ndarray], r_idx: int):
724
- tick(state, manual_action=None)
725
- hist.append(snapshot_of(state))
726
- if len(hist) > MAX_HISTORY:
727
- hist.pop(0)
728
- r_idx = len(hist) - 1
729
- return refresh(state, hist, bel, r_idx) + (state, hist, bel, r_idx)
730
 
731
- def set_toggles(state: WorldState, hist: List[Snapshot], bel: Dict[str, np.ndarray], r_idx: int, ar: bool, sp: float, ov: bool):
732
- state.autorun = bool(ar)
733
- state.speed_hz = float(sp)
734
- state.overlay = bool(ov)
735
- return refresh(state, hist, bel, r_idx) + (state, hist, bel, r_idx)
 
 
 
 
 
 
 
 
 
 
 
 
736
 
737
- def toggle_control_fn(state: WorldState, hist: List[Snapshot], bel: Dict[str, np.ndarray], r_idx: int):
738
  order = ["Predator", "Prey", "Scout"]
739
  i = order.index(state.controlled)
740
  state.controlled = order[(i + 1) % len(order)]
741
- state.event_log.append(f"t={state.step}: Controlled -> {state.controlled}.")
742
- return refresh(state, hist, bel, r_idx) + (state, hist, bel, r_idx)
 
 
 
743
 
744
- def toggle_pov_fn(state: WorldState, hist: List[Snapshot], bel: Dict[str, np.ndarray], r_idx: int):
745
  order = ["Predator", "Prey", "Scout"]
746
  i = order.index(state.pov)
747
  state.pov = order[(i + 1) % len(order)]
748
- state.event_log.append(f"t={state.step}: POV -> {state.pov}.")
749
- return refresh(state, hist, bel, r_idx) + (state, hist, bel, r_idx)
 
 
 
750
 
751
- def jump_fn(state: WorldState, hist: List[Snapshot], bel: Dict[str, np.ndarray], r_idx: int, idx: int):
752
- if not hist:
753
- return refresh(state, hist, bel, r_idx) + (state, hist, bel, r_idx)
754
- idx = max(0, min(int(idx), len(hist) - 1))
755
- restore_into(state, hist[idx])
756
- r_idx = idx
757
- return refresh(state, hist, bel, r_idx) + (state, hist, bel, r_idx)
758
-
759
- def branch_fn(state: WorldState, hist: List[Snapshot], bel: Dict[str, np.ndarray], r_idx: int, name: str):
760
- nm = (name or "").strip() or f"branch_{len(state.branches)+1}"
761
- state.branches[nm] = r_idx
762
- state.event_log.append(f"t={state.step}: Branched timeline '{nm}' at history idx={r_idx}.")
763
- return refresh(state, hist, bel, r_idx) + (state, hist, bel, r_idx)
764
 
765
- def truth_click(tile: int, state: WorldState, hist: List[Snapshot], bel: Dict[str, np.ndarray], r_idx: int, evt: gr.SelectData):
766
  state = grid_click_to_tile(evt, int(tile), state)
767
- hist.append(snapshot_of(state))
768
- if len(hist) > MAX_HISTORY:
769
- hist.pop(0)
770
- r_idx = len(hist) - 1
771
- return refresh(state, hist, bel, r_idx) + (state, hist, bel, r_idx)
772
 
773
- def export_fn(state: WorldState, hist: List[Snapshot]):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
774
  return export_run(state, hist)
775
 
776
- def import_fn(txt: str):
777
- state, hist, bel, r_idx = import_run(txt)
778
- pov_np, truth_im, a_im, b_im, stxt, ltxt = build_views(state, bel)
779
  r_max = max(0, len(hist) - 1)
780
  return (
781
- pov_np, truth_im, a_im, b_im, stxt, ltxt,
782
- gr.update(maximum=r_max, value=r_idx),
783
- state, hist, bel, r_idx
784
  )
785
 
786
- # --- CLICK HANDLERS (NO fn_kwargs; use lambdas for compatibility) ---
787
- btn_L.click(
788
- lambda s, h, b, r: do_action(s, h, b, r, "L"),
789
- inputs=[st, history, beliefs, rewind_index],
790
- outputs=[pov_img, truth, belief_a, belief_b, status, log, rewind, rewind_index, st, history, beliefs, rewind_index],
791
- api_name=False,
792
- queue=True,
793
- )
794
- btn_F.click(
795
- lambda s, h, b, r: do_action(s, h, b, r, "F"),
796
- inputs=[st, history, beliefs, rewind_index],
797
- outputs=[pov_img, truth, belief_a, belief_b, status, log, rewind, rewind_index, st, history, beliefs, rewind_index],
798
- api_name=False,
799
- queue=True,
800
- )
801
- btn_R.click(
802
- lambda s, h, b, r: do_action(s, h, b, r, "R"),
803
- inputs=[st, history, beliefs, rewind_index],
804
- outputs=[pov_img, truth, belief_a, belief_b, status, log, rewind, rewind_index, st, history, beliefs, rewind_index],
805
- api_name=False,
806
- queue=True,
807
- )
808
-
809
- btn_step.click(
810
- do_tick,
811
- inputs=[st, history, beliefs, rewind_index],
812
- outputs=[pov_img, truth, belief_a, belief_b, status, log, rewind, rewind_index, st, history, beliefs, rewind_index],
813
- queue=True
814
- )
815
-
816
- toggle_control.click(
817
- toggle_control_fn,
818
- inputs=[st, history, beliefs, rewind_index],
819
- outputs=[pov_img, truth, belief_a, belief_b, status, log, rewind, rewind_index, st, history, beliefs, rewind_index],
820
- queue=True
821
- )
822
- toggle_pov.click(
823
- toggle_pov_fn,
824
- inputs=[st, history, beliefs, rewind_index],
825
- outputs=[pov_img, truth, belief_a, belief_b, status, log, rewind, rewind_index, st, history, beliefs, rewind_index],
826
- queue=True
827
- )
828
-
829
- autorun.change(
830
- set_toggles,
831
- inputs=[st, history, beliefs, rewind_index, autorun, speed, overlay],
832
- outputs=[pov_img, truth, belief_a, belief_b, status, log, rewind, rewind_index, st, history, beliefs, rewind_index],
833
- queue=True
834
- )
835
- speed.change(
836
- set_toggles,
837
- inputs=[st, history, beliefs, rewind_index, autorun, speed, overlay],
838
- outputs=[pov_img, truth, belief_a, belief_b, status, log, rewind, rewind_index, st, history, beliefs, rewind_index],
839
- queue=True
840
- )
841
- overlay.change(
842
- set_toggles,
843
- inputs=[st, history, beliefs, rewind_index, autorun, speed, overlay],
844
- outputs=[pov_img, truth, belief_a, belief_b, status, log, rewind, rewind_index, st, history, beliefs, rewind_index],
845
- queue=True
846
- )
847
-
848
- btn_jump.click(
849
- jump_fn,
850
- inputs=[st, history, beliefs, rewind_index, rewind],
851
- outputs=[pov_img, truth, belief_a, belief_b, status, log, rewind, rewind_index, st, history, beliefs, rewind_index],
852
- queue=True
853
- )
854
- btn_branch.click(
855
- branch_fn,
856
- inputs=[st, history, beliefs, rewind_index, branch_name],
857
- outputs=[pov_img, truth, belief_a, belief_b, status, log, rewind, rewind_index, st, history, beliefs, rewind_index],
858
- queue=True
859
- )
860
-
861
- truth.select(
862
- truth_click,
863
- inputs=[tile_pick, st, history, beliefs, rewind_index],
864
- outputs=[pov_img, truth, belief_a, belief_b, status, log, rewind, rewind_index, st, history, beliefs, rewind_index],
865
- queue=True
866
- )
867
 
868
  btn_export.click(export_fn, inputs=[st, history], outputs=[export_box], queue=True)
869
- btn_import.click(
870
- import_fn,
871
- inputs=[import_box],
872
- outputs=[pov_img, truth, belief_a, belief_b, status, log, rewind, st, history, beliefs, rewind_index],
873
- queue=True
874
- )
875
-
876
- # Timer-driven autorun
877
- def timer_fn(state: WorldState, hist: List[Snapshot], bel: Dict[str, np.ndarray], r_idx: int, ar: bool, sp: float):
878
- state.autorun = bool(ar)
879
- state.speed_hz = float(sp)
880
 
881
- if not state.autorun or state.caught:
882
- return refresh(state, hist, bel, r_idx) + (state, hist, bel, r_idx)
 
 
883
 
884
- ticks_per_frame = max(1, int(round(state.speed_hz * 0.12)))
885
- for _ in range(ticks_per_frame):
886
- tick(state, manual_action=None)
887
- hist.append(snapshot_of(state))
888
- if len(hist) > MAX_HISTORY:
889
- hist.pop(0)
890
-
891
- r_idx = len(hist) - 1
892
- return refresh(state, hist, bel, r_idx) + (state, hist, bel, r_idx)
893
-
894
- timer.tick(
895
- timer_fn,
896
- inputs=[st, history, beliefs, rewind_index, autorun, speed],
897
- outputs=[pov_img, truth, belief_a, belief_b, status, log, rewind, rewind_index, st, history, beliefs, rewind_index],
898
- queue=True
899
- )
900
-
901
- demo.load(
902
- refresh,
903
- inputs=[st, history, beliefs, rewind_index],
904
- outputs=[pov_img, truth, belief_a, belief_b, status, log, rewind, rewind_index],
905
- queue=True
906
- )
907
 
908
  demo.queue().launch()
 
9
  import gradio as gr
10
 
11
  # ============================================================
12
+ # ChronoSandbox++Instrumented Training Arena
13
+ # - Deterministic gridworld + first-person raycast view
14
+ # - Click-to-edit environment (tiles)
15
+ # - Full step trace: obs -> action -> reward -> q-update rationale
16
+ # - Optional Q-learning (tabular) for Predator + Prey
17
+ # - Batch training: run episodes fast, track metrics
18
+ # - Export/import: environment, history, Q-tables, metrics
19
  #
20
+ # Compatibility: avoids fn_kwargs + avoids gr.Timer
21
  # ============================================================
22
 
23
  # -----------------------------
24
+ # Config
25
  # -----------------------------
26
  GRID_W, GRID_H = 21, 15
27
+ TILE = 22
28
 
29
  VIEW_W, VIEW_H = 640, 360
30
  RAY_W = 320
31
  FOV_DEG = 78
32
  MAX_DEPTH = 20
33
 
 
34
  DIRS = [(1, 0), (0, 1), (-1, 0), (0, -1)]
35
  ORI_DEG = [0, 90, 180, 270]
36
 
 
37
  EMPTY = 0
38
  WALL = 1
39
  FOOD = 2
 
50
  TELE: "Teleporter",
51
  }
52
 
53
+ AGENT_COLORS = {
54
+ "Predator": (255, 120, 90),
55
+ "Prey": (120, 255, 160),
56
+ "Scout": (120, 190, 255),
57
+ }
58
+
59
  SKY = np.array([14, 16, 26], dtype=np.uint8)
60
  FLOOR_NEAR = np.array([24, 26, 40], dtype=np.uint8)
61
  FLOOR_FAR = np.array([10, 11, 18], dtype=np.uint8)
62
  WALL_BASE = np.array([210, 210, 225], dtype=np.uint8)
63
  WALL_SIDE = np.array([150, 150, 170], dtype=np.uint8)
64
+ DOOR_COL = np.array([180, 210, 255], dtype=np.uint8)
65
 
66
+ ACTIONS = ["L", "F", "R"] # keep small for tabular learning stability
 
 
 
 
67
 
68
  # -----------------------------
69
+ # Deterministic RNG streams
70
  # -----------------------------
71
def rng_for(seed: int, step: int, stream: int = 0) -> np.random.Generator:
    """Deterministic generator for a (seed, step, stream) triple.

    Every decision point draws from its own stream so replays and branches
    reproduce exactly.
    """
    h = seed * 1_000_003
    h ^= step * 9_999_937
    h ^= stream * 97_531
    return np.random.default_rng(h & 0xFFFFFFFFFFFFFFFF)
74
 
75
  # -----------------------------
76
+ # Data structures
77
  # -----------------------------
78
@dataclass
class Agent:
    """A single actor on the grid."""
    name: str        # "Predator" | "Prey" | "Scout"
    x: int           # grid column
    y: int           # grid row
    ori: int         # facing direction: index into DIRS (0=E, 1=S, 2=W, 3=N)
    energy: int = 100  # Prey gains energy from FOOD tiles
85
 
86
@dataclass
class TrainConfig:
    """Tabular Q-learning hyper-parameters and reward-shaping coefficients."""
    use_q_pred: bool = True   # Predator picks actions from its Q-table
    use_q_prey: bool = True   # Prey picks actions from its Q-table
    alpha: float = 0.15       # learning rate
    gamma: float = 0.95       # discount factor
    epsilon: float = 0.10     # initial exploration rate (live value lives in Metrics)
    epsilon_min: float = 0.02
    epsilon_decay: float = 0.995  # multiplicative decay per training episode

    # reward shaping
    pred_step_penalty: float = -0.02  # per-step cost, discourages dawdling
    pred_dist_coeff: float = 0.03     # reward per unit of distance closed
    pred_catch_reward: float = 3.0

    prey_step_penalty: float = -0.02
    prey_food_reward: float = 0.6
    prey_survive_reward: float = 0.02
    prey_caught_penalty: float = -3.0
105
+
106
@dataclass
class Metrics:
    """Cross-episode training statistics (survives episode resets)."""
    episodes: int = 0
    catches: int = 0
    avg_steps_to_catch: float = 0.0
    avg_path_efficiency: float = 0.0  # optimal / actual (0..1)
    last_episode_steps: int = 0
    last_episode_eff: float = 0.0
    epsilon: float = 0.10  # live exploration rate; decays during training
115
+
116
  @dataclass
117
  class WorldState:
118
  seed: int
 
121
  agents: Dict[str, Agent]
122
  controlled: str
123
  pov: str
 
 
124
  overlay: bool
125
+
126
  caught: bool
127
  branches: Dict[str, int]
128
 
129
+ # instrumentation
130
+ event_log: List[str]
131
+ trace_log: List[str] # more detailed step trace (bounded)
132
+
133
+ # training
134
+ cfg: TrainConfig
135
+ q_pred: Dict[str, List[float]]
136
+ q_prey: Dict[str, List[float]]
137
+ metrics: Metrics
138
+
139
@dataclass
class Snapshot:
    """Rewindable freeze of the world at one step (plus short log tails)."""
    step: int
    agents: Dict[str, Dict]   # agent name -> asdict(Agent)
    grid: List[List[int]]     # deep copy of the tile grid
    caught: bool
    event_log_tail: List[str]  # last few event lines at snapshot time
    trace_tail: List[str]      # last few trace lines at snapshot time
147
 
148
+ # -----------------------------
149
+ # Environment
150
+ # -----------------------------
151
  def default_grid() -> List[List[int]]:
152
  g = [[EMPTY for _ in range(GRID_W)] for _ in range(GRID_H)]
 
153
  for x in range(GRID_W):
154
  g[0][x] = WALL
155
  g[GRID_H - 1][x] = WALL
 
157
  g[y][0] = WALL
158
  g[y][GRID_W - 1] = WALL
159
 
 
160
  for x in range(4, 17):
161
  g[7][x] = WALL
162
  g[7][10] = DOOR
163
 
 
164
  g[3][4] = FOOD
165
  g[11][15] = FOOD
166
  g[4][14] = NOISE
 
175
  "Prey": Agent("Prey", 18, 12, 2, 100),
176
  "Scout": Agent("Scout", 10, 3, 1, 100),
177
  }
178
+ cfg = TrainConfig()
179
  return WorldState(
180
  seed=seed,
181
  step=0,
 
183
  agents=agents,
184
  controlled="Predator",
185
  pov="Predator",
 
 
186
  overlay=False,
 
187
  caught=False,
188
  branches={"main": 0},
189
+ event_log=["Initialized world."],
190
+ trace_log=[],
191
+ cfg=cfg,
192
+ q_pred={},
193
+ q_prey={},
194
+ metrics=Metrics(epsilon=cfg.epsilon),
195
  )
196
 
197
  # -----------------------------
198
+ # Belief maps
199
  # -----------------------------
200
def init_belief() -> Dict[str, np.ndarray]:
    """One belief grid per agent, initialized to -1 (= unknown tile)."""
    names = ("Predator", "Prey", "Scout")
    return {nm: np.full((GRID_H, GRID_W), -1, dtype=np.int16) for nm in names}
205
 
206
  # -----------------------------
207
+ # Helpers
208
  # -----------------------------
209
def in_bounds(x: int, y: int) -> bool:
    """True iff (x, y) lies inside the grid rectangle."""
    if x < 0 or y < 0:
        return False
    return x < GRID_W and y < GRID_H
 
212
def is_blocking(tile: int) -> bool:
    """True iff the tile stops movement (only walls block; doors/teleporters pass)."""
    return tile == WALL
214
 
215
def manhattan(a: Agent, b: Agent) -> int:
    """L1 (taxicab) grid distance between two agents."""
    dx = a.x - b.x
    dy = a.y - b.y
    return abs(dx) + abs(dy)
 
 
 
 
 
 
 
 
 
217
 
218
+ def bresenham_los(grid: List[List[int]], x0: int, y0: int, x1: int, y1: int) -> bool:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
  dx = abs(x1 - x0)
220
  dy = abs(y1 - y0)
221
  sx = 1 if x0 < x1 else -1
 
236
  err += dx
237
  y += sy
238
 
239
+ def within_fov(observer: Agent, tx: int, ty: int, fov_deg: float = FOV_DEG) -> bool:
240
  dx = tx - observer.x
241
  dy = ty - observer.y
242
  if dx == 0 and dy == 0:
 
247
  return abs(diff) <= (fov_deg / 2)
248
 
249
def visible(observer: Agent, target: Agent, grid: List[List[int]]) -> bool:
    """Target is visible iff inside the observer's FOV cone AND unoccluded by walls."""
    if not within_fov(observer, target.x, target.y, FOV_DEG):
        return False
    return bresenham_los(grid, observer.x, observer.y, target.x, target.y)
251
+
252
+ # -----------------------------
253
+ # Movement
254
+ # -----------------------------
255
def turn_left(a: Agent) -> None:
    """Rotate 90° counter-clockwise ((ori + 3) % 4 == (ori - 1) % 4)."""
    a.ori = (a.ori + 3) % 4
257
+
258
def turn_right(a: Agent) -> None:
    """Rotate 90° clockwise (next index in DIRS)."""
    a.ori = (a.ori + 1) % 4
260
+
261
def move_forward(state: WorldState, a: Agent) -> str:
    """Advance one cell in the facing direction.

    Returns a short outcome tag used by the step trace:
    "blocked: bounds" / "blocked: wall" / "moved" / "moved: teleported".
    """
    step_x, step_y = DIRS[a.ori]
    nx = a.x + step_x
    ny = a.y + step_y

    if not in_bounds(nx, ny):
        return "blocked: bounds"
    if is_blocking(state.grid[ny][nx]):
        return "blocked: wall"

    # Doors open permanently (tile becomes EMPTY) as an agent walks through.
    if state.grid[ny][nx] == DOOR:
        state.grid[ny][nx] = EMPTY
        state.event_log.append(f"t={state.step}: {a.name} opened a door.")

    a.x, a.y = nx, ny

    # Paired teleporters: landing on a pad warps to the next pad in sorted order.
    if state.grid[ny][nx] == TELE:
        pads = [(x, y) for y in range(GRID_H) for x in range(GRID_W) if state.grid[y][x] == TELE]
        if len(pads) >= 2:
            ordered = sorted(pads)
            here = ordered.index((nx, ny))
            a.x, a.y = ordered[(here + 1) % len(ordered)]
            state.event_log.append(f"t={state.step}: {a.name} teleported.")
            return "moved: teleported"
    return "moved"
283
+
284
def apply_action(state: WorldState, agent_name: str, action: str) -> str:
    """Dispatch a one-letter action ("L"/"F"/"R") for the named agent.

    Unknown actions are a no-op; returns the outcome tag for the trace.
    """
    actor = state.agents[agent_name]
    if action == "F":
        return move_forward(state, actor)
    if action == "L":
        turn_left(actor)
        return "turned left"
    if action == "R":
        turn_right(actor)
        return "turned right"
    return "noop"
295
 
296
  # -----------------------------
297
+ # Rendering
298
  # -----------------------------
299
  def raycast_view(state: WorldState, observer: Agent) -> np.ndarray:
300
  img = np.zeros((VIEW_H, VIEW_W, 3), dtype=np.uint8)
 
317
  cos_a = math.cos(ray_ang)
318
 
319
  depth = 0.0
320
+ hit = None # None, "wall", "door"
321
+ side = 0
322
 
323
  while depth < MAX_DEPTH:
324
  depth += 0.05
 
326
  ty = int(oy + sin_a * depth)
327
  if not in_bounds(tx, ty):
328
  break
 
329
  tile = state.grid[ty][tx]
330
  if tile == WALL:
331
+ hit = "wall"
332
+ side = 1 if abs(cos_a) > abs(sin_a) else 0
333
  break
334
  if tile == DOOR:
335
+ hit = "door"
336
  break
337
 
338
+ if hit is None:
339
  continue
340
 
341
  depth *= math.cos(ray_ang - math.radians(ORI_DEG[observer.ori]))
 
345
  y0 = max(0, VIEW_H // 2 - proj_h // 2)
346
  y1 = min(VIEW_H - 1, VIEW_H // 2 + proj_h // 2)
347
 
348
+ if hit == "door":
349
+ col = DOOR_COL.copy()
 
 
350
  else:
351
+ col = WALL_BASE.copy() if side == 0 else WALL_SIDE.copy()
352
 
353
  dim = max(0.25, 1.0 - (depth / MAX_DEPTH))
354
  col = (col * dim).astype(np.uint8)
 
357
  x1 = int((rx + 1) * (VIEW_W / RAY_W))
358
  img[y0:y1, x0:x1] = col
359
 
360
+ # billboards for visible agents
361
+ for nm, other in state.agents.items():
362
+ if nm == observer.name:
363
  continue
364
  if visible(observer, other, state.grid):
365
  dx = other.x - observer.x
 
376
  y1 = min(VIEW_H - 1, y_mid + h // 2)
377
  x0 = max(0, sx - w // 2)
378
  x1 = min(VIEW_W - 1, sx + w // 2)
379
+ col = AGENT_COLORS.get(nm, (255, 200, 120))
380
  img[y0:y1, x0:x1] = np.array(col, dtype=np.uint8)
381
 
382
  if state.overlay:
 
386
 
387
  return img
388
 
 
 
 
389
  def render_topdown(grid: np.ndarray, agents: Dict[str, Agent], title: str, show_agents: bool = True) -> Image.Image:
390
  w = grid.shape[1] * TILE
391
  h = grid.shape[0] * TILE
 
423
  draw.line([0, yy, w, yy], fill=(12, 14, 22))
424
 
425
  if show_agents:
426
+ for nm, a in agents.items():
427
  cx = a.x * TILE + TILE // 2
428
  cy = a.y * TILE + 28 + TILE // 2
429
+ col = AGENT_COLORS.get(nm, (220, 220, 220))
430
  r = TILE // 3
431
  draw.ellipse([cx - r, cy - r, cx + r, cy + r], fill=col)
432
  dx, dy = DIRS[a.ori]
 
437
  return im
438
 
439
  # -----------------------------
440
+ # Belief updates
441
  # -----------------------------
442
def update_belief_for_agent(state: WorldState, belief: np.ndarray, agent: Agent) -> None:
    """Reveal the agent's current FOV into its fog-of-war belief grid (in place).

    Marches a fan of rays across the FOV; each ray copies true tile values into
    `belief` until it leaves the grid or hits a wall (walls occlude, but the
    wall tile itself is recorded).
    """
    # The agent always knows the tile it stands on.
    belief[agent.y, agent.x] = state.grid[agent.y][agent.x]
    base = math.radians(ORI_DEG[agent.ori])
    half = math.radians(FOV_DEG / 2)
    # Scout gets a denser ray fan (wider effective coverage).
    rays = 33 if agent.name != "Scout" else 45

    for i in range(rays):
        t = i / (rays - 1)
        ang = base + (t * 2 - 1) * half  # sweep from -half to +half around facing
        sin_a, cos_a = math.sin(ang), math.cos(ang)
        ox, oy = agent.x + 0.5, agent.y + 0.5  # ray origin at cell center
        depth = 0.0
        while depth < MAX_DEPTH:
            depth += 0.2  # coarse march step; belief is tile-granular anyway
            tx = int(ox + cos_a * depth)
            ty = int(oy + sin_a * depth)
            if not in_bounds(tx, ty):
                break
            belief[ty, tx] = state.grid[ty][tx]
            if state.grid[ty][tx] == WALL:
                break  # wall occludes everything behind it on this ray
463
+
464
+ # -----------------------------
465
+ # Optimal distance (BFS) for efficiency metric
466
+ # -----------------------------
467
def bfs_distance(grid: List[List[int]], sx: int, sy: int, gx: int, gy: int) -> Optional[int]:
    """Shortest 4-connected path length from (sx,sy) to (gx,gy) avoiding walls.

    Returns None when the goal is unreachable.
    """
    if (sx, sy) == (gx, gy):
        return 0
    frontier = [(sx, sy)]
    dist = {(sx, sy): 0}
    cursor = 0  # list-as-queue: cursor avoids O(n) pop(0)
    while cursor < len(frontier):
        x, y = frontier[cursor]
        cursor += 1
        here = dist[(x, y)]
        for dx, dy in DIRS:
            nxt = (x + dx, y + dy)
            if nxt in dist:
                continue
            if not in_bounds(nxt[0], nxt[1]):
                continue
            if grid[nxt[1]][nxt[0]] == WALL:
                continue
            dist[nxt] = here + 1
            if nxt == (gx, gy):
                return here + 1
            frontier.append(nxt)
    return None
488
+
489
+ # -----------------------------
490
+ # Observation encoding (compact state key)
491
+ # -----------------------------
492
def obs_key(state: WorldState, who: str) -> str:
    """Compact, hashable observation key for the tabular Q-learner.

    Key formats (unchanged):
      Predator: "P|x,y,ori|d<dx>,<dy>|v<0/1>"             (offset to prey, LOS flag)
      Prey:     "R|x,y,ori|d<dx>,<dy>|v<0/1>|e<energy//25>" (offset to predator)
      Scout:    "S|x,y,ori"
    Relative offsets are clipped to [-6, 6] to bound the table size.

    Improvement over the original: each branch computes only what it needs —
    the predator->prey line-of-sight test (a Bresenham walk) is no longer run
    for Prey/Scout observations where its result was discarded.
    """
    pred = state.agents["Predator"]
    prey = state.agents["Prey"]

    if who == "Predator":
        dx_bin = int(np.clip(prey.x - pred.x, -6, 6))
        dy_bin = int(np.clip(prey.y - pred.y, -6, 6))
        vis = 1 if visible(pred, prey, state.grid) else 0
        return f"P|{pred.x},{pred.y},{pred.ori}|d{dx_bin},{dy_bin}|v{vis}"

    if who == "Prey":
        # prey cares if predator is visible to it
        vis2 = 1 if visible(prey, pred, state.grid) else 0
        ddx_bin = int(np.clip(pred.x - prey.x, -6, 6))
        ddy_bin = int(np.clip(pred.y - prey.y, -6, 6))
        return f"R|{prey.x},{prey.y},{prey.ori}|d{ddx_bin},{ddy_bin}|v{vis2}|e{int(prey.energy//25)}"

    # Scout: position/orientation only
    a = state.agents[who]
    return f"S|{a.x},{a.y},{a.ori}"
515
+
516
def q_get(q: Dict[str, List[float]], key: str) -> List[float]:
    """Fetch (creating zeros on first touch) the mutable Q-row for a state key."""
    return q.setdefault(key, [0.0, 0.0, 0.0])
520
+
521
def epsilon_greedy(qvals: List[float], eps: float, r: np.random.Generator) -> int:
    """With probability eps pick a uniform random index, else the greedy argmax."""
    explore = r.random() < eps
    if explore:
        return int(r.integers(0, len(qvals)))
    return int(np.argmax(qvals))
525
+
526
def q_update(q: Dict[str, List[float]], key: str, a_idx: int, reward: float, next_key: str, alpha: float, gamma: float) -> Tuple[float, float, float]:
    """One tabular Q-learning backup.

    Returns (old, target, new) so callers can log the update rationale.
    """
    row = q_get(q, key)
    next_row = q_get(q, next_key)
    prev_val = row[a_idx]
    td_target = reward + gamma * float(np.max(next_row))
    updated = prev_val + alpha * (td_target - prev_val)
    row[a_idx] = updated
    return prev_val, td_target, updated
534
+
535
+ # -----------------------------
536
+ # Baseline heuristic policies (for Scout + fallback)
537
+ # -----------------------------
538
+ def heuristic_pred_action(state: WorldState) -> str:
539
  pred = state.agents["Predator"]
540
  prey = state.agents["Prey"]
541
  if visible(pred, prey, state.grid):
 
549
  if diff > 10:
550
  return "R"
551
  return "F"
552
+ r = rng_for(state.seed, state.step, stream=11)
553
+ return r.choice(ACTIONS)
554
 
555
+ def heuristic_prey_action(state: WorldState) -> str:
556
  prey = state.agents["Prey"]
557
  pred = state.agents["Predator"]
558
  if visible(prey, pred, state.grid):
 
567
  if diff_away > 10:
568
  return "R"
569
  return "F"
570
+ r = rng_for(state.seed, state.step, stream=12)
571
+ return r.choice(ACTIONS)
 
 
 
 
 
 
 
 
 
 
 
 
572
 
573
def heuristic_scout_action(state: WorldState) -> str:
    """Scout wanders: uniform random action from its deterministic per-step stream."""
    return rng_for(state.seed, state.step, stream=13).choice(ACTIONS)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
576
 
577
+ # -----------------------------
578
+ # Reward shaping
579
+ # -----------------------------
580
def pred_reward(state_prev: WorldState, state_now: WorldState) -> float:
    """Predator reward: step cost, plus shaping for distance closed, plus catch bonus."""
    cfg = state_now.cfg
    p0, y0 = state_prev.agents["Predator"], state_prev.agents["Prey"]
    p1, y1 = state_now.agents["Predator"], state_now.agents["Prey"]
    before = abs(p0.x - y0.x) + abs(p0.y - y0.y)
    after = abs(p1.x - y1.x) + abs(p1.y - y1.y)
    # positive when the predator closed distance this step
    total = cfg.pred_step_penalty + cfg.pred_dist_coeff * (before - after)
    if state_now.caught:
        total += cfg.pred_catch_reward
    return float(total)
592
+
593
def prey_reward(state_prev: WorldState, state_now: WorldState, ate_food: bool) -> float:
    """Prey reward: survive bonus minus step cost, plus food, big penalty if caught.

    NOTE: state_prev is accepted for signature symmetry with pred_reward; it is
    not used in the current shaping.
    """
    cfg = state_now.cfg
    total = cfg.prey_step_penalty + cfg.prey_survive_reward
    if ate_food:
        total += cfg.prey_food_reward
    if state_now.caught:
        total += cfg.prey_caught_penalty
    return float(total)
601
+
602
+ # -----------------------------
603
+ # Core simulation tick (with instrumentation + optional learning)
604
+ # -----------------------------
605
+ TRACE_MAX = 400
606
+
607
def clone_shallow(state: WorldState) -> WorldState:
    """Copy the world for before/after reward computation.

    Grid rows, agents, branches and logs are duplicated; the Q-tables, cfg and
    metrics are shared by reference (they are cross-episode objects and must
    not be forked).
    """
    agents_copy = {name: Agent(**asdict(ag)) for name, ag in state.agents.items()}
    return WorldState(
        seed=state.seed,
        step=state.step,
        grid=[list(row) for row in state.grid],
        agents=agents_copy,
        controlled=state.controlled,
        pov=state.pov,
        overlay=state.overlay,
        caught=state.caught,
        branches=dict(state.branches),
        event_log=list(state.event_log),
        trace_log=list(state.trace_log),
        cfg=state.cfg,
        q_pred=state.q_pred,
        q_prey=state.q_prey,
        metrics=state.metrics,
    )
626
 
627
  def check_catch(state: WorldState) -> None:
628
  pred = state.agents["Predator"]
 
631
  state.caught = True
632
  state.event_log.append(f"t={state.step}: CAUGHT.")
633
 
634
def consume_food(state: WorldState) -> bool:
    """If the Prey stands on FOOD, eat it (+35 energy, capped at 200).

    Returns True iff food was eaten this step; the tile becomes EMPTY.
    """
    prey = state.agents["Prey"]
    if state.grid[prey.y][prey.x] != FOOD:
        return False
    prey.energy = min(200, prey.energy + 35)
    state.grid[prey.y][prey.x] = EMPTY
    state.event_log.append(f"t={state.step}: Prey ate food (+energy).")
    return True
642
+
643
def choose_action(state: WorldState, who: str, stream: int) -> Tuple[str, str, Optional[Tuple[str, int]]]:
    """Select an action for `who`.

    Returns (action, reason, q_info); q_info is (obs_key, action_index) when
    the action came from a Q-table, else None.
    """
    cfg = state.cfg
    r = rng_for(state.seed, state.step, stream=stream)

    use_q = (who == "Predator" and cfg.use_q_pred) or (who == "Prey" and cfg.use_q_prey)
    if use_q:
        table = state.q_pred if who == "Predator" else state.q_prey
        label = "pred" if who == "Predator" else "prey"
        k = obs_key(state, who)
        qv = q_get(table, k)
        a_idx = epsilon_greedy(qv, state.metrics.epsilon, r)
        reason = f"Q({label}) eps={state.metrics.epsilon:.3f} q={np.round(qv,3).tolist()}"
        return ACTIONS[a_idx], reason, (k, a_idx)

    # heuristic fallbacks (also used when Q is disabled for a role)
    if who == "Predator":
        return heuristic_pred_action(state), "heuristic(pred)", None
    if who == "Prey":
        return heuristic_prey_action(state), "heuristic(prey)", None
    return heuristic_scout_action(state), "heuristic(scout)", None
672
+
673
def tick(state: WorldState, manual_action: Optional[str] = None) -> None:
    """Advance the world one step: choose actions, apply them, reward, learn, trace.

    manual_action, when given, overrides the controlled agent's policy for
    this step (no Q-update is performed for a manual action).

    Fix vs. original: removed the dead local `eff` that was computed from
    opt_dist/dist_now but never used (the trace already carries both values).
    """
    if state.caught:
        return  # episode over; world is frozen until reset

    prev = clone_shallow(state)  # pre-step copy for reward deltas

    # Optimal predator->prey distance, recorded in the step trace.
    pred = state.agents["Predator"]
    prey = state.agents["Prey"]
    opt_dist = bfs_distance(state.grid, pred.x, pred.y, prey.x, prey.y)
    if opt_dist is None:
        opt_dist = 999  # sentinel: prey unreachable

    # Action selection
    chosen = {}
    reasons = {}
    qinfo = {}  # (obs_key, action_index) when chosen by a Q-table, else None

    # manual action applies to controlled agent
    if manual_action:
        chosen[state.controlled] = manual_action
        reasons[state.controlled] = "manual"
        qinfo[state.controlled] = None

    # others choose via Q-policy or heuristic fallback
    for who in ["Predator", "Prey", "Scout"]:
        if who in chosen:
            continue
        act, reason, q_i = choose_action(state, who, stream={"Predator": 21, "Prey": 22, "Scout": 23}[who])
        chosen[who] = act
        reasons[who] = reason
        qinfo[who] = q_i

    # Apply actions (deterministic order)
    outcomes = {}
    for who in ["Predator", "Prey", "Scout"]:
        outcomes[who] = apply_action(state, who, chosen[who])

    ate = consume_food(state)
    check_catch(state)

    # Rewards + Q-updates (only for actions that came from a Q-table)
    pred_r = pred_reward(prev, state)
    prey_r = prey_reward(prev, state, ate_food=ate)

    q_lines = []
    if qinfo["Predator"] is not None:
        k, a_idx = qinfo["Predator"]
        nk = obs_key(state, "Predator")
        old, target, new = q_update(state.q_pred, k, a_idx, pred_r, nk, state.cfg.alpha, state.cfg.gamma)
        q_lines.append(f"Qpred: {k} a={ACTIONS[a_idx]} old={old:.3f} tgt={target:.3f} new={new:.3f}")

    if qinfo["Prey"] is not None:
        k, a_idx = qinfo["Prey"]
        nk = obs_key(state, "Prey")
        old, target, new = q_update(state.q_prey, k, a_idx, prey_r, nk, state.cfg.alpha, state.cfg.gamma)
        q_lines.append(f"Qprey: {k} a={ACTIONS[a_idx]} old={old:.3f} tgt={target:.3f} new={new:.3f}")

    # Human-auditable trace line for this step.
    dist_now = manhattan(state.agents["Predator"], state.agents["Prey"])
    trace = (
        f"t={state.step} optDist~{opt_dist} distNow={dist_now} "
        f"| Pred:{chosen['Predator']} ({outcomes['Predator']}) [{reasons['Predator']}] r={pred_r:+.3f} "
        f"| Prey:{chosen['Prey']} ({outcomes['Prey']}) [{reasons['Prey']}] r={prey_r:+.3f} "
        f"| Scout:{chosen['Scout']} ({outcomes['Scout']}) [{reasons['Scout']}] "
        f"| ateFood={ate} caught={state.caught}"
    )
    if q_lines:
        trace += " | " + " ; ".join(q_lines)

    state.trace_log.append(trace)
    if len(state.trace_log) > TRACE_MAX:
        state.trace_log = state.trace_log[-TRACE_MAX:]

    state.step += 1
749
 
750
  # -----------------------------
751
+ # Episode reset + training
752
  # -----------------------------
753
def reset_episode(state: WorldState, seed: Optional[int] = None) -> None:
    """Reset the world for a new episode, keeping learned artifacts.

    Q-tables, TrainConfig and Metrics survive (they are cross-episode); the
    grid, agents, flags and logs are rebuilt from init_state.

    Fix vs. original: dropped `fresh.metrics.epsilon = state.metrics.epsilon`,
    which was a self-assignment no-op (fresh.metrics had just been aliased to
    state.metrics on the line above).
    """
    if seed is None:
        seed = state.seed
    fresh = init_state(seed)
    # Share the cross-episode objects into the fresh world.
    fresh.cfg = state.cfg
    fresh.q_pred = state.q_pred
    fresh.q_prey = state.q_prey
    fresh.metrics = state.metrics
    # Copy the fresh world back field-by-field (state is mutated in place so
    # every caller holding a reference sees the reset).
    state.seed = fresh.seed
    state.step = 0
    state.grid = fresh.grid
    state.agents = fresh.agents
    state.controlled = fresh.controlled
    state.pov = fresh.pov
    state.overlay = fresh.overlay
    state.caught = False
    state.branches = fresh.branches
    state.event_log = ["Episode reset."]
    state.trace_log = []
775
def run_episode(state: WorldState, max_steps: int) -> Tuple[bool, int, float]:
    """Roll the simulation forward until a catch or the step cap.

    Returns (caught, steps, path_eff) where path_eff compares the initial
    optimal predator->prey distance against the steps actually taken.
    """
    pred = state.agents["Predator"]
    prey = state.agents["Prey"]
    opt = bfs_distance(state.grid, pred.x, pred.y, prey.x, prey.y)
    if opt is None:
        opt = 999  # unreachable: efficiency reported as 0
    steps = 0
    while steps < max_steps and not state.caught:
        tick(state, manual_action=None)
        steps += 1
    eff = float(opt / max(1, steps)) if opt < 999 else 0.0
    return state.caught, steps, eff
789
+
790
def train(state: WorldState, episodes: int, max_steps: int) -> None:
    """Run `episodes` headless episodes, updating Q-tables, epsilon and metrics.

    Fix vs. original: `steps`/`eff` were referenced after the loop and raised
    NameError when episodes == 0; they are now pre-initialized so a zero-episode
    call is a safe no-op on the last-episode stats.
    """
    m = state.metrics
    cfg = state.cfg
    catches = 0
    total_steps_catch = 0
    total_eff = 0.0
    steps = 0   # last-episode stats default to zero when episodes == 0
    eff = 0.0

    for ep in range(episodes):
        # deterministically vary episode seed so it doesn't memorize one map-layout only
        # NOTE: reset_episode writes ep_seed into state.seed, so seeds chain
        # from episode to episode (deliberate variation, still deterministic).
        ep_seed = (state.seed * 1_000_003 + (m.episodes + ep) * 97_531) & 0xFFFFFFFF
        reset_episode(state, seed=int(ep_seed))

        caught, steps, eff = run_episode(state, max_steps=max_steps)
        total_eff += eff

        if caught:
            catches += 1
            total_steps_catch += steps

        # epsilon decay (floored at cfg.epsilon_min)
        m.epsilon = max(cfg.epsilon_min, m.epsilon * cfg.epsilon_decay)

    # Update cross-run metrics (exponential smoothing for stability).
    m.episodes += episodes
    m.catches += catches
    m.last_episode_steps = steps
    m.last_episode_eff = eff
    if catches > 0:
        avg_steps = total_steps_catch / catches
        m.avg_steps_to_catch = (
            0.85 * m.avg_steps_to_catch + 0.15 * avg_steps
            if m.avg_steps_to_catch > 0 else avg_steps
        )
    avg_eff = total_eff / max(1, episodes)
    m.avg_path_efficiency = (
        0.85 * m.avg_path_efficiency + 0.15 * avg_eff
        if m.avg_path_efficiency > 0 else avg_eff
    )

    state.event_log.append(
        f"Training: +{episodes} eps | catches={catches}/{episodes} | "
        f"avgStepsToCatch~{m.avg_steps_to_catch:.2f} | avgEff~{m.avg_path_efficiency:.2f} | eps={m.epsilon:.3f}"
    )
835
+ # -----------------------------
836
+ # History / snapshots
837
+ # -----------------------------
838
+ MAX_HISTORY = 1200
839
 
840
def snapshot_of(state: WorldState) -> Snapshot:
    """Freeze the rewindable parts of the world, with short log tails for context."""
    frozen_agents = {name: asdict(ag) for name, ag in state.agents.items()}
    return Snapshot(
        step=state.step,
        agents=frozen_agents,
        grid=[list(row) for row in state.grid],
        caught=state.caught,
        event_log_tail=state.event_log[-20:],
        trace_tail=state.trace_log[-40:],
    )
849
 
850
  def restore_into(state: WorldState, snap: Snapshot) -> None:
 
853
  for k, d in snap.agents.items():
854
  state.agents[k] = Agent(**d)
855
  state.caught = snap.caught
856
+ state.event_log.append(f"Jumped to snapshot t={snap.step}.")
857
 
858
  # -----------------------------
859
+ # Export / import
860
  # -----------------------------
861
def export_run(state: WorldState, history: List[Snapshot]) -> str:
    """Serialize environment, learned tables, metrics and history as pretty JSON."""
    payload = dict(
        seed=state.seed,
        controlled=state.controlled,
        pov=state.pov,
        overlay=state.overlay,
        cfg=asdict(state.cfg),
        metrics=asdict(state.metrics),
        q_pred=state.q_pred,
        q_prey=state.q_prey,
        history=[asdict(snap) for snap in history],
        grid=state.grid,
    )
    return json.dumps(payload, indent=2)
875
 
876
def import_run(txt: str) -> Tuple[WorldState, List[Snapshot], Dict[str, np.ndarray], int]:
    """Inverse of export_run: rebuild state, history, fresh beliefs, rewind index.

    Missing keys fall back to init_state defaults; beliefs are NOT exported,
    so they restart as unknown. Raises json.JSONDecodeError on malformed input.
    """
    data = json.loads(txt)
    st = init_state(int(data.get("seed", 1337)))
    st.controlled = data.get("controlled", st.controlled)
    st.pov = data.get("pov", st.pov)
    st.overlay = bool(data.get("overlay", False))
    st.grid = data.get("grid", st.grid)

    # Rehydrate dataclasses from their asdict() form.
    st.cfg = TrainConfig(**data.get("cfg", asdict(st.cfg)))
    st.metrics = Metrics(**data.get("metrics", asdict(st.metrics)))

    st.q_pred = data.get("q_pred", {})
    st.q_prey = data.get("q_prey", {})

    hist = [Snapshot(**s) for s in data.get("history", [])]
    bel = init_belief()
    r_idx = max(0, len(hist) - 1)

    # Land on the most recent snapshot so the UI resumes where the export ended.
    if hist:
        restore_into(st, hist[-1])
    st.event_log.append("Imported run.")
    return st, hist, bel, r_idx
898
 
899
+ # -----------------------------
900
+ # UI glue
901
+ # -----------------------------
902
def build_views(state: WorldState, beliefs: Dict[str, np.ndarray]) -> Tuple[np.ndarray, Image.Image, Image.Image, Image.Image, str, str, str]:
    """Refresh belief maps, render every panel, and format status/log text.

    Returns (pov, truth, belief_controlled, belief_other, status, events, trace).
    Side effect: beliefs are updated in place from each agent's current FOV.
    """
    for name, ag in state.agents.items():
        update_belief_for_agent(state, beliefs[name], ag)

    pov = raycast_view(state, state.agents[state.pov])
    truth_img = render_topdown(
        np.array(state.grid, dtype=np.int16),
        state.agents,
        f"Truth Map — t={state.step} seed={state.seed}",
        show_agents=True,
    )

    ctrl = state.controlled
    other = "Prey" if ctrl == "Predator" else "Predator"
    b_ctrl = render_topdown(beliefs[ctrl], state.agents, f"{ctrl} Belief", show_agents=True)
    b_other = render_topdown(beliefs[other], state.agents, f"{other} Belief", show_agents=True)

    m = state.metrics
    pred = state.agents["Predator"]
    prey = state.agents["Prey"]
    scout = state.agents["Scout"]
    status = (
        f"Controlled={state.controlled} | POV={state.pov} | caught={state.caught} | eps={m.epsilon:.3f}\n"
        f"Episodes={m.episodes} | catches={m.catches} | avgStepsToCatch~{m.avg_steps_to_catch:.2f} | avgEff~{m.avg_path_efficiency:.2f}\n"
        f"Pred({pred.x},{pred.y}) o={pred.ori} | Prey({prey.x},{prey.y}) o={prey.ori} e={prey.energy} | Scout({scout.x},{scout.y}) o={scout.ori}"
    )
    events = "\n".join(state.event_log[-18:])
    trace = "\n".join(state.trace_log[-18:])
    return pov, truth_img, b_ctrl, b_other, status, events, trace
928
 
929
  def grid_click_to_tile(evt: gr.SelectData, selected_tile: int, state: WorldState) -> WorldState:
930
  x_px, y_px = evt.index
931
+ y_px -= 28
932
  if y_px < 0:
933
  return state
934
  gx = int(x_px // TILE)
 
938
  if gx == 0 or gy == 0 or gx == GRID_W - 1 or gy == GRID_H - 1:
939
  return state
940
  state.grid[gy][gx] = selected_tile
941
+ state.event_log.append(f"t={state.step}: Tile ({gx},{gy}) -> {TILE_NAMES.get(selected_tile)}")
942
  return state
943
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
944
  # -----------------------------
945
+ # Gradio App
946
  # -----------------------------
947
+ with gr.Blocks(title="ChronoSandbox++Training Arena") as demo:
948
  gr.Markdown(
949
+ "## ChronoSandbox++Instrumented Agent Training Arena\n"
950
+ "Track every interaction, train policies, and audit why outcomes happened.\n"
951
+ "No timers (compatibility). Use Tick/Run/Train for controlled experiments."
952
  )
953
 
954
+ st = gr.State(init_state(1337))
955
+ history = gr.State([snapshot_of(init_state(1337))])
956
  beliefs = gr.State(init_belief())
957
+ rewind_idx = gr.State(0)
958
 
959
  with gr.Row():
960
+ pov_img = gr.Image(label="POV (Pseudo-3D)", type="numpy", width=VIEW_W, height=VIEW_H)
961
  with gr.Column():
962
+ status = gr.Textbox(label="Status + Metrics", lines=4)
963
+ events = gr.Textbox(label="Event Log", lines=10)
964
+ trace = gr.Textbox(label="Step Trace (why it happened)", lines=10)
965
 
966
  with gr.Row():
967
  truth = gr.Image(label="Truth Map (click to edit tiles)", type="pil")
968
+ belief_a = gr.Image(label="Belief (Controlled)", type="pil")
969
+ belief_b = gr.Image(label="Belief (Other)", type="pil")
970
 
971
  with gr.Row():
972
  with gr.Column(scale=2):
973
+ gr.Markdown("### Manual Controls")
974
  with gr.Row():
975
+ btn_L = gr.Button("L")
976
+ btn_F = gr.Button("F")
977
+ btn_R = gr.Button("R")
978
  with gr.Row():
979
+ btn_tick = gr.Button("Tick")
980
+ run_steps = gr.Number(value=25, label="Run N steps", precision=0)
981
+ btn_run = gr.Button("Run")
982
  with gr.Row():
983
+ btn_toggle_control = gr.Button("Toggle Controlled")
984
+ btn_toggle_pov = gr.Button("Toggle POV")
985
+ overlay = gr.Checkbox(False, label="Overlay reticle")
986
+
987
  tile_pick = gr.Radio(
988
  choices=[(TILE_NAMES[k], k) for k in [EMPTY, WALL, FOOD, NOISE, DOOR, TELE]],
989
  value=WALL,
990
+ label="Paint tile type"
991
  )
992
+
993
+ with gr.Column(scale=3):
994
+ gr.Markdown("### Training Controls (Q-learning)")
995
+ use_q_pred = gr.Checkbox(True, label="Use Q-learning: Predator")
996
+ use_q_prey = gr.Checkbox(True, label="Use Q-learning: Prey")
997
+ alpha = gr.Slider(0.01, 0.5, value=0.15, step=0.01, label="alpha (learn rate)")
998
+ gamma = gr.Slider(0.5, 0.99, value=0.95, step=0.01, label="gamma (discount)")
999
+ eps = gr.Slider(0.0, 0.5, value=0.10, step=0.01, label="epsilon (exploration)")
1000
+ eps_decay = gr.Slider(0.90, 0.999, value=0.995, step=0.001, label="epsilon decay")
1001
+ eps_min = gr.Slider(0.0, 0.2, value=0.02, step=0.01, label="epsilon min")
1002
+
1003
+ episodes = gr.Number(value=50, label="Train episodes", precision=0)
1004
+ max_steps = gr.Number(value=250, label="Max steps per episode", precision=0)
1005
+ btn_train = gr.Button("Train")
1006
+
1007
+ btn_reset = gr.Button("Reset Episode")
1008
+ btn_reset_all = gr.Button("Reset ALL (wipe Q + metrics)")
1009
+
1010
+ with gr.Row():
1011
+ with gr.Column():
1012
+ rewind = gr.Slider(0, 0, value=0, step=1, label="Rewind (history index)")
1013
+ btn_jump = gr.Button("Jump")
1014
+ with gr.Column():
1015
  export_box = gr.Textbox(label="Export JSON", lines=10)
1016
+ btn_export = gr.Button("Export")
1017
+ with gr.Column():
1018
  import_box = gr.Textbox(label="Import JSON", lines=10)
1019
+ btn_import = gr.Button("Import")
 
 
1020
 
1021
    def refresh(state: WorldState, hist: List[Snapshot], bel: Dict[str, np.ndarray], r: int):
        """Re-render all output widgets; clamps the rewind index into history range."""
        r_max = max(0, len(hist) - 1)
        r = max(0, min(int(r), r_max))
        pov, tr, ba, bb, stxt, etxt, ttxt = build_views(state, bel)
        return (
            pov, tr, ba, bb,
            stxt, etxt, ttxt,
            # keep the slider's maximum in sync with the current history length
            gr.update(maximum=r_max, value=r),
            r
        )
1031
 
1032
+ def push_hist(state: WorldState, hist: List[Snapshot]) -> List[Snapshot]:
 
1033
  hist.append(snapshot_of(state))
1034
  if len(hist) > MAX_HISTORY:
1035
  hist.pop(0)
1036
+ return hist
1037
+
1038
    def set_cfg(state: WorldState, uq_pred: bool, uq_prey: bool, a: float, g: float, e: float, ed: float, emin: float):
        """Copy UI control values into the training config.

        NOTE: epsilon (`e`) is written to state.metrics.epsilon — the live,
        decaying value — not to cfg, so each Train click restarts exploration
        from the slider value.
        """
        state.cfg.use_q_pred = bool(uq_pred)
        state.cfg.use_q_prey = bool(uq_prey)
        state.cfg.alpha = float(a)
        state.cfg.gamma = float(g)
        state.metrics.epsilon = float(e)
        state.cfg.epsilon_decay = float(ed)
        state.cfg.epsilon_min = float(emin)
        return state
1047
 
1048
    def do_manual(state, hist, bel, r, act):
        """One step where the controlled agent performs `act` ("L"/"F"/"R"); others act normally."""
        tick(state, manual_action=act)
        hist = push_hist(state, hist)
        r = len(hist) - 1
        out = refresh(state, hist, bel, r)
        return out + (state, hist, bel, r)
 
1054
 
1055
    def do_tick(state, hist, bel, r):
        """Advance one fully-autonomous step and snapshot it."""
        tick(state, manual_action=None)
        hist = push_hist(state, hist)
        r = len(hist) - 1
        out = refresh(state, hist, bel, r)
        return out + (state, hist, bel, r)
1061
+
1062
+ def do_run(state, hist, bel, r, n):
1063
+ n = max(1, int(n))
1064
+ for _ in range(n):
1065
+ if state.caught:
1066
+ break
1067
+ tick(state, manual_action=None)
1068
+ hist = push_hist(state, hist)
1069
+ r = len(hist) - 1
1070
+ out = refresh(state, hist, bel, r)
1071
+ return out + (state, hist, bel, r)
1072
 
1073
+ def toggle_control(state, hist, bel, r):
1074
  order = ["Predator", "Prey", "Scout"]
1075
  i = order.index(state.controlled)
1076
  state.controlled = order[(i + 1) % len(order)]
1077
+ state.event_log.append(f"Controlled -> {state.controlled}")
1078
+ hist = push_hist(state, hist)
1079
+ r = len(hist) - 1
1080
+ out = refresh(state, hist, bel, r)
1081
+ return out + (state, hist, bel, r)
1082
 
1083
def toggle_pov(state, hist, bel, r):
    """Cycle the first-person camera between the three agents, logging it."""
    order = ["Predator", "Prey", "Scout"]
    nxt = (order.index(state.pov) + 1) % len(order)
    state.pov = order[nxt]
    state.event_log.append(f"POV -> {state.pov}")
    hist = push_hist(state, hist)
    r = len(hist) - 1
    views = refresh(state, hist, bel, r)
    return views + (state, hist, bel, r)
1092
 
1093
def set_overlay(state, hist, bel, r, ov):
    """Flip the overlay flag; purely cosmetic, so no history snapshot is taken."""
    state.overlay = bool(ov)
    views = refresh(state, hist, bel, r)
    return views + (state, hist, bel, r)
 
 
 
 
 
 
 
 
 
1097
 
1098
def click_truth(tile, state, hist, bel, r, evt: gr.SelectData):
    """Paint the currently picked tile type onto the clicked truth-map cell.

    Each edit is snapshotted so map changes participate in rewind/branching.
    """
    state = grid_click_to_tile(evt, int(tile), state)
    hist = push_hist(state, hist)
    r = len(hist) - 1
    views = refresh(state, hist, bel, r)
    return views + (state, hist, bel, r)
 
1104
 
1105
def jump(state, hist, bel, r, idx):
    """Restore the snapshot at ``idx`` (clamped to history bounds) and refresh.

    With an empty history this is a pure view refresh — nothing to restore.
    """
    if hist:
        target = min(max(int(idx), 0), len(hist) - 1)
        restore_into(state, hist[target])
        r = target
    views = refresh(state, hist, bel, r)
    return views + (state, hist, bel, r)
1114
+
1115
def reset_ep(state, hist, bel, r):
    """Restart the current episode with the same seed, wiping history/beliefs."""
    reset_episode(state, seed=state.seed)
    bel = init_belief()
    hist = [snapshot_of(state)]  # history restarts from the fresh state
    r = 0
    views = refresh(state, hist, bel, r)
    return views + (state, hist, bel, r)
1122
+
1123
def reset_all(state, hist, bel, r):
    """Rebuild the whole world from scratch, reusing the current seed."""
    state = init_state(state.seed)
    hist = [snapshot_of(state)]
    bel = init_belief()
    r = 0
    views = refresh(state, hist, bel, r)
    return views + (state, hist, bel, r)
1131
+
1132
def do_train(state, hist, bel, r,
             uq_pred, uq_prey, a, g, e, ed, emin,
             eps_count, max_s):
    """Apply UI hyperparameters, run Q-learning episodes, then reset for viewing.

    The episode is reset afterwards so the user immediately watches the
    freshly trained policies act from a clean start.
    """
    state = set_cfg(state, uq_pred, uq_prey, a, g, e, ed, emin)
    train(state,
          episodes=max(1, int(eps_count)),
          max_steps=max(10, int(max_s)))
    reset_episode(state, seed=state.seed)
    hist = [snapshot_of(state)]
    bel = init_belief()
    r = 0
    views = refresh(state, hist, bel, r)
    return views + (state, hist, bel, r)
1144
+
1145
def export_fn(state, hist):
    """Serialize the current run (state + full history) to a shareable string."""
    return export_run(state, hist)
1147
 
1148
def import_fn(txt):
    """Deserialize a run, rebuild every view, and resize the rewind slider.

    Unlike the other handlers this also emits a ``gr.update`` so the rewind
    slider's maximum matches the imported history length.
    """
    state, hist, bel, r = import_run(txt)
    pov, tr, ba, bb, stxt, etxt, ttxt = build_views(state, bel)
    slider_max = max(0, len(hist) - 1)
    return (pov, tr, ba, bb, stxt, etxt, ttxt,
            gr.update(maximum=slider_max, value=r),
            state, hist, bel, r)
1157
 
1158
# --- Wire buttons (no fn_kwargs; use lambdas for old-Gradio compatibility) ---
# Every handler reads the same four state objects and refreshes the same view
# set, so the component lists are factored out here.
# NOTE(review): `rewind_idx` appears twice in the output list (once for the
# refreshed index display, once for the returned state value) — confirm the
# installed Gradio version accepts duplicate output components.
state_inputs = [st, history, beliefs, rewind_idx]
std_outputs = [pov_img, truth, belief_a, belief_b, status, events, trace,
               rewind, rewind_idx, st, history, beliefs, rewind_idx]

# Manual turn-left / forward / turn-right for the controlled agent.
btn_L.click(lambda s, h, b, r: do_manual(s, h, b, r, "L"),
            inputs=state_inputs, outputs=std_outputs, queue=True)
btn_F.click(lambda s, h, b, r: do_manual(s, h, b, r, "F"),
            inputs=state_inputs, outputs=std_outputs, queue=True)
btn_R.click(lambda s, h, b, r: do_manual(s, h, b, r, "R"),
            inputs=state_inputs, outputs=std_outputs, queue=True)

btn_tick.click(do_tick, inputs=state_inputs, outputs=std_outputs, queue=True)
btn_run.click(do_run, inputs=state_inputs + [run_steps],
              outputs=std_outputs, queue=True)

btn_toggle_control.click(toggle_control, inputs=state_inputs,
                         outputs=std_outputs, queue=True)
btn_toggle_pov.click(toggle_pov, inputs=state_inputs,
                     outputs=std_outputs, queue=True)

overlay.change(set_overlay, inputs=state_inputs + [overlay],
               outputs=std_outputs, queue=True)

# Clicking the truth map paints the currently selected tile type.
truth.select(click_truth, inputs=[tile_pick] + state_inputs,
             outputs=std_outputs, queue=True)

btn_jump.click(jump, inputs=state_inputs + [rewind],
               outputs=std_outputs, queue=True)

btn_reset.click(reset_ep, inputs=state_inputs, outputs=std_outputs, queue=True)
btn_reset_all.click(reset_all, inputs=state_inputs, outputs=std_outputs, queue=True)

btn_train.click(do_train,
                inputs=state_inputs + [use_q_pred, use_q_prey, alpha, gamma,
                                       eps, eps_decay, eps_min,
                                       episodes, max_steps],
                outputs=std_outputs, queue=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1225
 
1226
# Export the current run (state + full history) into a copyable text blob.
btn_export.click(export_fn, inputs=[st, history], outputs=[export_box], queue=True)

# Import a previously exported run. import_fn also returns a gr.update for the
# rewind slider (new maximum/value), so its output list differs from the
# standard refresh outputs used by the other handlers.
btn_import.click(import_fn,
                 inputs=[import_box],
                 outputs=[pov_img, truth, belief_a, belief_b, status, events, trace, rewind, st, history, beliefs, rewind_idx],
                 queue=True)

# Initial paint: render all views once as soon as the page loads.
demo.load(refresh,
          inputs=[st, history, beliefs, rewind_idx],
          outputs=[pov_img, truth, belief_a, belief_b, status, events, trace, rewind, rewind_idx],
          queue=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1237
 
1238
# Enable the request queue (needed for the long-running run/train callbacks)
# and start the Gradio server.
demo.queue().launch()