ZENLLC committed on
Commit 26366a0 · verified · parent: 850d45c

Create app.py

Files changed (1): app.py (+1864 lines, new file)
import json
import math
import hashlib
from dataclasses import dataclass, asdict
from typing import Dict, List, Tuple, Optional, Any

import numpy as np
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
import gradio as gr

# ============================================================
# ZEN AgentLab — Agent POV + Multi-Agent Mini-Sim Arena
# State-of-the-art evolution of your "ChronoSandbox++" reference.
#
# Features:
# - Deterministic gridworld + first-person raycast POV
# - Multiple environments (Chase / CoopVault / MiniCiv)
# - Click-to-edit tiles + inventory pickups
# - Full step trace: obs -> action -> reward -> (optional) Q-update
# - Branching timelines (fork from any rewind point)
# - Batch training (tabular Q-learning) + metrics dashboard
# - Export/import full runs + SHA256 proof hash ("Proof-of-Run")
#
# Hugging Face Spaces compatible: no timers, no fn_kwargs
# ============================================================

# -----------------------------
# Global config (shared)
# -----------------------------
GRID_W, GRID_H = 21, 15
TILE = 22

VIEW_W, VIEW_H = 640, 360
RAY_W = 320
FOV_DEG = 78
MAX_DEPTH = 20

DIRS = [(1, 0), (0, 1), (-1, 0), (0, -1)]
ORI_DEG = [0, 90, 180, 270]

# Tiles
EMPTY = 0
WALL = 1
FOOD = 2
NOISE = 3
DOOR = 4
TELE = 5
KEY = 6
EXIT = 7
ARTIFACT = 8
HAZARD = 9
WOOD = 10
ORE = 11
MEDKIT = 12
SWITCH = 13
BASE = 14

TILE_NAMES = {
    EMPTY: "Empty",
    WALL: "Wall",
    FOOD: "Food",
    NOISE: "Noise",
    DOOR: "Door",
    TELE: "Teleporter",
    KEY: "Key",
    EXIT: "Exit",
    ARTIFACT: "Artifact",
    HAZARD: "Hazard",
    WOOD: "Wood",
    ORE: "Ore",
    MEDKIT: "Medkit",
    SWITCH: "Switch",
    BASE: "Base",
}

# Colors
AGENT_COLORS = {
    "Predator": (255, 120, 90),
    "Prey": (120, 255, 160),
    "Scout": (120, 190, 255),

    "Alpha": (255, 205, 120),
    "Bravo": (160, 210, 255),
    "Guardian": (255, 120, 220),

    "BuilderA": (140, 255, 200),
    "BuilderB": (160, 200, 255),
    "Raider": (255, 160, 120),
}

SKY = np.array([14, 16, 26], dtype=np.uint8)
FLOOR_NEAR = np.array([24, 26, 40], dtype=np.uint8)
FLOOR_FAR = np.array([10, 11, 18], dtype=np.uint8)
WALL_BASE = np.array([210, 210, 225], dtype=np.uint8)
WALL_SIDE = np.array([150, 150, 170], dtype=np.uint8)
DOOR_COL = np.array([140, 210, 255], dtype=np.uint8)

# Keep actions small for tabular stability
ACTIONS = ["L", "F", "R", "I"]  # I = interact (pickups/door/key/switch/base)

# -----------------------------
# Deterministic RNG streams
# -----------------------------
def rng_for(seed: int, step: int, stream: int = 0) -> np.random.Generator:
    mix = (seed * 1_000_003) ^ (step * 9_999_937) ^ (stream * 97_531)
    return np.random.default_rng(mix & 0xFFFFFFFFFFFFFFFF)
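
# Illustrative note (added, not in the original file): rng_for is a pure
# function, so the same (seed, step, stream) triple always yields the same
# draws, which is what makes replays and forked timelines reproducible.
#   >>> rng_for(1337, 5).random() == rng_for(1337, 5).random()
#   True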

# -----------------------------
# Data structures
# -----------------------------
@dataclass
class Agent:
    name: str
    x: int
    y: int
    ori: int
    hp: int = 10
    energy: int = 100
    team: str = "A"
    brain: str = "q"  # q | heuristic | random
    inventory: Dict[str, int] = None

    def __post_init__(self):
        if self.inventory is None:
            self.inventory = {}

@dataclass
class TrainConfig:
    use_q: bool = True
    alpha: float = 0.15
    gamma: float = 0.95
    epsilon: float = 0.10
    epsilon_min: float = 0.02
    epsilon_decay: float = 0.995

    # shaping (generic)
    step_penalty: float = -0.01
    explore_reward: float = 0.015
    damage_penalty: float = -0.20
    heal_reward: float = 0.10

    # chase env shaping
    chase_close_coeff: float = 0.03
    chase_catch_reward: float = 3.0
    chase_escaped_reward: float = 0.2
    chase_caught_penalty: float = -3.0
    food_reward: float = 0.6

    # vault env shaping
    artifact_pick_reward: float = 1.2
    exit_win_reward: float = 3.0
    guardian_tag_reward: float = 2.0
    tagged_penalty: float = -2.0
    switch_reward: float = 0.8
    key_reward: float = 0.4

    # civ env shaping
    resource_pick_reward: float = 0.15
    deposit_reward: float = 0.4
    base_progress_win_reward: float = 3.5
    raider_elim_reward: float = 2.0
    builder_elim_penalty: float = -2.0

@dataclass
class GlobalMetrics:
    episodes: int = 0
    wins_teamA: int = 0
    wins_teamB: int = 0
    draws: int = 0
    avg_steps: float = 0.0
    rolling_winrate_A: float = 0.0
    epsilon: float = 0.10
    last_outcome: str = "init"
    last_steps: int = 0

@dataclass
class EpisodeMetrics:
    steps: int = 0
    returns: Dict[str, float] = None
    action_counts: Dict[str, Dict[str, int]] = None
    tiles_discovered: Dict[str, int] = None
    q_states: Dict[str, int] = None

    def __post_init__(self):
        if self.returns is None:
            self.returns = {}
        if self.action_counts is None:
            self.action_counts = {}
        if self.tiles_discovered is None:
            self.tiles_discovered = {}
        if self.q_states is None:
            self.q_states = {}

@dataclass
class WorldState:
    seed: int
    step: int
    env_key: str
    grid: List[List[int]]
    agents: Dict[str, Agent]

    controlled: str
    pov: str
    overlay: bool

    done: bool
    outcome: str  # "A_win" | "B_win" | "draw" | "ongoing"

    # env state
    door_opened_global: bool = False
    base_progress: int = 0
    base_target: int = 10

    # instrumentation
    event_log: List[str] = None
    trace_log: List[str] = None

    # learning
    cfg: TrainConfig = None
    q_tables: Dict[str, Dict[str, List[float]]] = None  # per-agent Q
    gmetrics: GlobalMetrics = None
    emetrics: EpisodeMetrics = None

    def __post_init__(self):
        if self.event_log is None:
            self.event_log = []
        if self.trace_log is None:
            self.trace_log = []
        if self.cfg is None:
            self.cfg = TrainConfig()
        if self.q_tables is None:
            self.q_tables = {}
        if self.gmetrics is None:
            self.gmetrics = GlobalMetrics(epsilon=self.cfg.epsilon)
        if self.emetrics is None:
            self.emetrics = EpisodeMetrics()

@dataclass
class Snapshot:
    branch: str
    step: int
    env_key: str
    grid: List[List[int]]
    agents: Dict[str, Dict[str, Any]]
    done: bool
    outcome: str
    door_opened_global: bool
    base_progress: int
    base_target: int
    event_tail: List[str]
    trace_tail: List[str]
    emetrics: Dict[str, Any]

# -----------------------------
# Helpers
# -----------------------------
def in_bounds(x: int, y: int) -> bool:
    return 0 <= x < GRID_W and 0 <= y < GRID_H

def is_blocking(tile: int, door_open: bool = False) -> bool:
    if tile == WALL:
        return True
    if tile == DOOR and not door_open:
        return True
    return False

def manhattan_xy(ax: int, ay: int, bx: int, by: int) -> int:
    return abs(ax - bx) + abs(ay - by)

def bresenham_los(grid: List[List[int]], x0: int, y0: int, x1: int, y1: int) -> bool:
    dx = abs(x1 - x0)
    dy = abs(y1 - y0)
    sx = 1 if x0 < x1 else -1
    sy = 1 if y0 < y1 else -1
    err = dx - dy
    x, y = x0, y0
    while True:
        if (x, y) != (x0, y0) and (x, y) != (x1, y1):
            if grid[y][x] == WALL:
                return False
        if x == x1 and y == y1:
            return True
        e2 = 2 * err
        if e2 > -dy:
            err -= dy
            x += sx
        if e2 < dx:
            err += dx
            y += sy

def within_fov(observer: Agent, tx: int, ty: int, fov_deg: float = FOV_DEG) -> bool:
    dx = tx - observer.x
    dy = ty - observer.y
    if dx == 0 and dy == 0:
        return True
    angle = math.degrees(math.atan2(dy, dx)) % 360
    facing = ORI_DEG[observer.ori]
    diff = (angle - facing + 540) % 360 - 180
    return abs(diff) <= (fov_deg / 2)
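
# Illustrative note (added): the (angle - facing + 540) % 360 - 180 trick maps
# a raw angular difference into [-180, 180). E.g. angle=350, facing=10 gives
# (350 - 10 + 540) % 360 - 180 = 880 % 360 - 180 = 160 - 180 = -20 degrees,
# i.e. 20 degrees to the observer's left rather than a spurious 340-degree gap.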

def visible(state: WorldState, observer: Agent, target: Agent) -> bool:
    # FOV gate first, then line of sight. Note: bresenham_los treats only
    # WALL tiles as occluders, so doors never block sight here (open or not).
    if not within_fov(observer, target.x, target.y, FOV_DEG):
        return False
    return bresenham_los(state.grid, observer.x, observer.y, target.x, target.y)

def hash_sha256(txt: str) -> str:
    return hashlib.sha256(txt.encode("utf-8")).hexdigest()
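
# Illustrative check (added): hash_sha256 is a thin wrapper over hashlib, e.g.
#   >>> hash_sha256("")
#   'e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'
# The "Proof-of-Run" export below hashes the exported JSON the same way.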

# -----------------------------
# Belief maps / fog-of-war
# -----------------------------
def init_beliefs(agent_names: List[str]) -> Dict[str, np.ndarray]:
    b = {}
    for nm in agent_names:
        b[nm] = -1 * np.ones((GRID_H, GRID_W), dtype=np.int16)
    return b

def update_belief_for_agent(state: WorldState, belief: np.ndarray, agent: Agent) -> None:
    belief[agent.y, agent.x] = state.grid[agent.y][agent.x]
    base = math.radians(ORI_DEG[agent.ori])
    half = math.radians(FOV_DEG / 2)
    rays = 45 if agent.name.lower().startswith("scout") else 33

    for i in range(rays):
        t = i / (rays - 1)
        ang = base + (t * 2 - 1) * half
        sin_a, cos_a = math.sin(ang), math.cos(ang)
        ox, oy = agent.x + 0.5, agent.y + 0.5
        depth = 0.0
        while depth < MAX_DEPTH:
            depth += 0.2
            tx = int(ox + cos_a * depth)
            ty = int(oy + sin_a * depth)
            if not in_bounds(tx, ty):
                break
            belief[ty, tx] = state.grid[ty][tx]
            tile = state.grid[ty][tx]
            if tile == WALL:
                break
            if tile == DOOR and not state.door_opened_global:
                break

# -----------------------------
# Rendering
# -----------------------------
def raycast_view(state: WorldState, observer: Agent) -> np.ndarray:
    img = np.zeros((VIEW_H, VIEW_W, 3), dtype=np.uint8)
    img[:, :] = SKY

    for y in range(VIEW_H // 2, VIEW_H):
        t = (y - VIEW_H // 2) / (VIEW_H // 2 + 1e-6)
        col = (1 - t) * FLOOR_NEAR + t * FLOOR_FAR
        img[y, :] = col.astype(np.uint8)

    fov = math.radians(FOV_DEG)
    half_fov = fov / 2

    for rx in range(RAY_W):
        cam_x = (2 * rx / (RAY_W - 1)) - 1
        ray_ang = math.radians(ORI_DEG[observer.ori]) + cam_x * half_fov

        ox, oy = observer.x + 0.5, observer.y + 0.5
        sin_a = math.sin(ray_ang)
        cos_a = math.cos(ray_ang)

        depth = 0.0
        hit = None  # "wall" | "door"
        side = 0

        while depth < MAX_DEPTH:
            depth += 0.05
            tx = int(ox + cos_a * depth)
            ty = int(oy + sin_a * depth)
            if not in_bounds(tx, ty):
                break
            tile = state.grid[ty][tx]
            if tile == WALL:
                hit = "wall"
                side = 1 if abs(cos_a) > abs(sin_a) else 0
                break
            if tile == DOOR and not state.door_opened_global:
                hit = "door"
                break

        if hit is None:
            continue

        depth *= math.cos(ray_ang - math.radians(ORI_DEG[observer.ori]))
        depth = max(depth, 0.001)

        proj_h = int((VIEW_H * 0.9) / depth)
        y0 = max(0, VIEW_H // 2 - proj_h // 2)
        y1 = min(VIEW_H - 1, VIEW_H // 2 + proj_h // 2)

        if hit == "door":
            col = DOOR_COL.copy()
        else:
            col = WALL_BASE.copy() if side == 0 else WALL_SIDE.copy()

        dim = max(0.25, 1.0 - (depth / MAX_DEPTH))
        col = (col * dim).astype(np.uint8)

        x0 = int(rx * (VIEW_W / RAY_W))
        x1 = int((rx + 1) * (VIEW_W / RAY_W))
        img[y0:y1, x0:x1] = col

    # billboards for visible agents
    for nm, other in state.agents.items():
        if nm == observer.name:
            continue
        if visible(state, observer, other):
            dx = other.x - observer.x
            dy = other.y - observer.y
            ang = (math.degrees(math.atan2(dy, dx)) % 360)
            facing = ORI_DEG[observer.ori]
            diff = (ang - facing + 540) % 360 - 180
            sx = int((diff / (FOV_DEG / 2)) * (VIEW_W / 2) + (VIEW_W / 2))
            dist = math.sqrt(dx * dx + dy * dy)
            h = int((VIEW_H * 0.65) / max(dist, 0.75))
            w = max(10, h // 3)
            y_mid = VIEW_H // 2
            y0 = max(0, y_mid - h // 2)
            y1 = min(VIEW_H - 1, y_mid + h // 2)
            x0 = max(0, sx - w // 2)
            x1 = min(VIEW_W - 1, sx + w // 2)
            col = AGENT_COLORS.get(nm, (255, 200, 120))
            img[y0:y1, x0:x1] = np.array(col, dtype=np.uint8)

    if state.overlay:
        cx, cy = VIEW_W // 2, VIEW_H // 2
        img[cy - 1:cy + 2, cx - 10:cx + 10] = np.array([120, 190, 255], dtype=np.uint8)
        img[cy - 10:cy + 10, cx - 1:cx + 2] = np.array([120, 190, 255], dtype=np.uint8)

    return img
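
# Illustrative note (added): the depth *= cos(ray_ang - facing) line above is
# the standard fisheye correction: it converts the euclidean ray length into a
# perpendicular distance to the camera plane, so a flat wall renders flat.
# Column height then falls off as 1/depth via proj_h = int((VIEW_H * 0.9) / depth).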

def render_topdown(grid: np.ndarray, agents: Dict[str, Agent], title: str, show_agents: bool = True) -> Image.Image:
    w = grid.shape[1] * TILE
    h = grid.shape[0] * TILE
    im = Image.new("RGB", (w, h + 28), (10, 12, 18))
    draw = ImageDraw.Draw(im)

    for y in range(grid.shape[0]):
        for x in range(grid.shape[1]):
            t = int(grid[y, x])
            if t == -1:
                col = (18, 20, 32)
            elif t == EMPTY:
                col = (26, 30, 44)
            elif t == WALL:
                col = (190, 190, 210)
            elif t == FOOD:
                col = (255, 210, 120)
            elif t == NOISE:
                col = (255, 120, 220)
            elif t == DOOR:
                col = (140, 210, 255)
            elif t == TELE:
                col = (120, 190, 255)
            elif t == KEY:
                col = (255, 235, 160)
            elif t == EXIT:
                col = (120, 255, 220)
            elif t == ARTIFACT:
                col = (255, 170, 60)
            elif t == HAZARD:
                col = (255, 90, 90)
            elif t == WOOD:
                col = (170, 120, 60)
            elif t == ORE:
                col = (140, 140, 160)
            elif t == MEDKIT:
                col = (120, 255, 140)
            elif t == SWITCH:
                col = (200, 180, 255)
            elif t == BASE:
                col = (220, 220, 240)
            else:
                col = (80, 80, 90)

            x0, y0 = x * TILE, y * TILE + 28
            draw.rectangle([x0, y0, x0 + TILE - 1, y0 + TILE - 1], fill=col)

    for x in range(grid.shape[1] + 1):
        xx = x * TILE
        draw.line([xx, 28, xx, h + 28], fill=(12, 14, 22))
    for y in range(grid.shape[0] + 1):
        yy = y * TILE + 28
        draw.line([0, yy, w, yy], fill=(12, 14, 22))

    if show_agents:
        for nm, a in agents.items():
            if a.hp <= 0:
                continue
            cx = a.x * TILE + TILE // 2
            cy = a.y * TILE + 28 + TILE // 2
            col = AGENT_COLORS.get(nm, (220, 220, 220))
            r = TILE // 3
            draw.ellipse([cx - r, cy - r, cx + r, cy + r], fill=col)
            dx, dy = DIRS[a.ori]
            draw.line([cx, cy, cx + dx * r, cy + dy * r], fill=(10, 10, 10), width=3)

    draw.rectangle([0, 0, w, 28], fill=(14, 16, 26))
    draw.text((8, 6), title, fill=(230, 230, 240))
    return im

# -----------------------------
# Environments (3 modes)
# -----------------------------
def grid_with_border() -> List[List[int]]:
    g = [[EMPTY for _ in range(GRID_W)] for _ in range(GRID_H)]
    for x in range(GRID_W):
        g[0][x] = WALL
        g[GRID_H - 1][x] = WALL
    for y in range(GRID_H):
        g[y][0] = WALL
        g[y][GRID_W - 1] = WALL
    return g

def env_chase(seed: int) -> Tuple[List[List[int]], Dict[str, Agent]]:
    g = grid_with_border()
    # mid-wall + door
    for x in range(4, 17):
        g[7][x] = WALL
    g[7][10] = DOOR

    # objects
    g[3][4] = FOOD
    g[11][15] = FOOD
    g[4][14] = NOISE
    g[12][5] = NOISE
    g[2][18] = TELE
    g[13][2] = TELE

    agents = {
        "Predator": Agent("Predator", 2, 2, 0, hp=10, energy=100, team="A", brain="q"),
        "Prey": Agent("Prey", 18, 12, 2, hp=10, energy=100, team="B", brain="q"),
        "Scout": Agent("Scout", 10, 3, 1, hp=10, energy=100, team="A", brain="heuristic"),
    }
    return g, agents

def env_vault(seed: int) -> Tuple[List[List[int]], Dict[str, Agent]]:
    g = grid_with_border()
    # internal maze
    for x in range(3, 18):
        g[5][x] = WALL
    for x in range(3, 18):
        g[9][x] = WALL
    g[5][10] = DOOR
    g[9][12] = DOOR

    # special tiles
    g[2][2] = KEY
    g[12][18] = EXIT
    g[12][2] = ARTIFACT
    g[2][18] = TELE
    g[13][2] = TELE
    g[7][10] = SWITCH
    g[3][15] = HAZARD
    g[11][6] = MEDKIT
    g[2][12] = FOOD

    agents = {
        "Alpha": Agent("Alpha", 2, 12, 0, hp=10, energy=100, team="A", brain="q"),
        "Bravo": Agent("Bravo", 3, 12, 0, hp=10, energy=100, team="A", brain="q"),
        "Guardian": Agent("Guardian", 18, 2, 2, hp=10, energy=100, team="B", brain="q"),
    }
    return g, agents

def env_civ(seed: int) -> Tuple[List[List[int]], Dict[str, Agent]]:
    g = grid_with_border()

    # walls forming zones
    for y in range(3, 12):
        g[y][9] = WALL
    g[7][9] = DOOR

    # resources
    g[2][3] = WOOD
    g[3][3] = WOOD
    g[4][3] = WOOD
    g[12][16] = ORE
    g[11][16] = ORE
    g[10][16] = ORE
    g[6][4] = FOOD
    g[8][15] = FOOD

    # base + hazards + switch/key (the zone door was already placed above)
    g[13][10] = BASE
    g[4][15] = HAZARD
    g[10][4] = HAZARD
    g[2][18] = TELE
    g[13][2] = TELE
    g[2][2] = KEY
    g[12][6] = SWITCH

    agents = {
        "BuilderA": Agent("BuilderA", 3, 12, 0, hp=10, energy=100, team="A", brain="q"),
        "BuilderB": Agent("BuilderB", 4, 12, 0, hp=10, energy=100, team="A", brain="q"),
        "Raider": Agent("Raider", 18, 2, 2, hp=10, energy=100, team="B", brain="q"),
    }
    return g, agents

ENV_BUILDERS = {
    "chase": env_chase,
    "vault": env_vault,
    "civ": env_civ,
}

# -----------------------------
# Observation encoding (compact)
# -----------------------------
def local_tile_ahead(state: WorldState, a: Agent) -> int:
    dx, dy = DIRS[a.ori]
    nx, ny = a.x + dx, a.y + dy
    if not in_bounds(nx, ny):
        return WALL
    return state.grid[ny][nx]

def nearest_enemy_vec(state: WorldState, a: Agent) -> Tuple[int, int, int]:
    best = None
    for nm, other in state.agents.items():
        if other.hp <= 0:
            continue
        if other.team == a.team:
            continue
        d = manhattan_xy(a.x, a.y, other.x, other.y)
        if best is None or d < best[0]:
            best = (d, other.x - a.x, other.y - a.y)
    if best is None:
        return (99, 0, 0)
    d, dx, dy = best
    return (d, int(np.clip(dx, -6, 6)), int(np.clip(dy, -6, 6)))

def obs_key(state: WorldState, who: str) -> str:
    a = state.agents[who]
    # include env key so tables don't cross-contaminate;
    # include coarse enemy vector, tile ahead, and a coarse inventory bucket
    d, dx, dy = nearest_enemy_vec(state, a)
    ahead = local_tile_ahead(state, a)
    keys = a.inventory.get("key", 0)
    art = a.inventory.get("artifact", 0)
    wood = a.inventory.get("wood", 0)
    ore = a.inventory.get("ore", 0)

    inv_bucket = f"k{min(keys, 2)}a{min(art, 1)}w{min(wood, 3)}o{min(ore, 3)}"
    door = 1 if state.door_opened_global else 0
    return f"{state.env_key}|{who}|{a.x},{a.y},{a.ori}|e{d}:{dx},{dy}|t{ahead}|hp{a.hp}|{inv_bucket}|D{door}|bp{state.base_progress}"

def q_get(q: Dict[str, List[float]], key: str) -> List[float]:
    if key not in q:
        q[key] = [0.0 for _ in ACTIONS]
    return q[key]

def epsilon_greedy(qvals: List[float], eps: float, r: np.random.Generator) -> int:
    if r.random() < eps:
        return int(r.integers(0, len(qvals)))
    return int(np.argmax(qvals))

def q_update(q: Dict[str, List[float]], key: str, a_idx: int, reward: float, next_key: str,
             alpha: float, gamma: float) -> Tuple[float, float, float]:
    qv = q_get(q, key)
    nq = q_get(q, next_key)
    old = qv[a_idx]
    target = reward + gamma * float(np.max(nq))
    new = old + alpha * (target - old)
    qv[a_idx] = new
    return old, target, new
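
# Worked example (added, hypothetical numbers): with alpha=0.15, gamma=0.95,
# old Q=0.0, reward=1.0 and max next-state Q=0.5, the update gives
#   target = 1.0 + 0.95 * 0.5 = 1.475
#   new    = 0.0 + 0.15 * (1.475 - 0.0) = 0.22125
# i.e. the standard one-step Q-learning rule Q += alpha * (r + gamma*maxQ' - Q).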

# -----------------------------
# Heuristic baselines
# -----------------------------
def face_towards(a: Agent, tx: int, ty: int) -> str:
    dx = tx - a.x
    dy = ty - a.y
    ang = (math.degrees(math.atan2(dy, dx)) % 360)
    facing = ORI_DEG[a.ori]
    diff = (ang - facing + 540) % 360 - 180
    if diff < -10:
        return "L"
    if diff > 10:
        return "R"
    return "F"

def heuristic_action(state: WorldState, who: str) -> str:
    a = state.agents[who]
    r = rng_for(state.seed, state.step, stream=900 + hash(who) % 1000)

    # if an enemy is visible, team B chases and team A flees;
    # but first, interact if standing on something valuable
    tile_here = state.grid[a.y][a.x]
    if tile_here in (FOOD, KEY, ARTIFACT, WOOD, ORE, MEDKIT, SWITCH, BASE, EXIT):
        return "I"

    # find nearest enemy
    best_d = 999
    best = None
    for nm, other in state.agents.items():
        if other.hp <= 0 or other.team == a.team:
            continue
        d = manhattan_xy(a.x, a.y, other.x, other.y)
        if d < best_d:
            best_d = d
            best = other

    if best is not None and best_d <= 6 and visible(state, a, best):
        # attackers chase
        if a.team == "B":
            return face_towards(a, best.x, best.y)
        # defenders flee: turn away from the enemy vector
        dx = best.x - a.x
        dy = best.y - a.y
        ang = (math.degrees(math.atan2(dy, dx)) % 360)
        facing = ORI_DEG[a.ori]
        diff = (ang - facing + 540) % 360 - 180
        diff_away = ((diff + 180) + 540) % 360 - 180
        if diff_away < -10:
            return "L"
        if diff_away > 10:
            return "R"
        return "F"

    # mild exploration bias: prefer moving forward
    return r.choice(["F", "F", "L", "R", "I"])

def random_action(state: WorldState, who: str) -> str:
    r = rng_for(state.seed, state.step, stream=700 + hash(who) % 1000)
    return r.choice(ACTIONS)

# -----------------------------
# Movement + interaction
# -----------------------------
def turn_left(a: Agent) -> None:
    a.ori = (a.ori - 1) % 4

def turn_right(a: Agent) -> None:
    a.ori = (a.ori + 1) % 4

def move_forward(state: WorldState, a: Agent) -> str:
    dx, dy = DIRS[a.ori]
    nx, ny = a.x + dx, a.y + dy
    if not in_bounds(nx, ny):
        return "blocked: bounds"
    tile = state.grid[ny][nx]
    if is_blocking(tile, door_open=state.door_opened_global):
        return "blocked: wall/door"
    a.x, a.y = nx, ny

    if state.grid[ny][nx] == TELE:
        teles = [(x, y) for y in range(GRID_H) for x in range(GRID_W) if state.grid[y][x] == TELE]
        if len(teles) >= 2:
            teles_sorted = sorted(teles)
            idx = teles_sorted.index((nx, ny))
            dest = teles_sorted[(idx + 1) % len(teles_sorted)]
            a.x, a.y = dest
            state.event_log.append(f"t={state.step}: {a.name} teleported.")
            return "moved: teleported"
    return "moved"

def try_interact(state: WorldState, a: Agent) -> str:
    t = state.grid[a.y][a.x]

    # a switch opens all doors globally
    if t == SWITCH:
        state.door_opened_global = True
        state.grid[a.y][a.x] = EMPTY
        a.inventory["switch"] = a.inventory.get("switch", 0) + 1
        return "switch: opened all doors"

    if t == KEY:
        a.inventory["key"] = a.inventory.get("key", 0) + 1
        state.grid[a.y][a.x] = EMPTY
        return "picked: key"

    if t == ARTIFACT:
        a.inventory["artifact"] = a.inventory.get("artifact", 0) + 1
        state.grid[a.y][a.x] = EMPTY
        return "picked: artifact"

    if t == FOOD:
        a.energy = min(200, a.energy + 35)
        state.grid[a.y][a.x] = EMPTY
        return "ate: food"

    if t == WOOD:
        a.inventory["wood"] = a.inventory.get("wood", 0) + 1
        state.grid[a.y][a.x] = EMPTY
        return "picked: wood"

    if t == ORE:
        a.inventory["ore"] = a.inventory.get("ore", 0) + 1
        state.grid[a.y][a.x] = EMPTY
        return "picked: ore"

    if t == MEDKIT:
        a.hp = min(10, a.hp + 3)
        state.grid[a.y][a.x] = EMPTY
        return "used: medkit"

    if t == BASE:
        # deposit up to 2 wood + 2 ore into base_progress
        w = a.inventory.get("wood", 0)
        o = a.inventory.get("ore", 0)
        dep = min(w, 2) + min(o, 2)
        if dep > 0:
            a.inventory["wood"] = max(0, w - min(w, 2))
            a.inventory["ore"] = max(0, o - min(o, 2))
            state.base_progress += dep
            return f"deposited: +{dep} base_progress"
        return "base: nothing to deposit"

    if t == EXIT:
        return "at_exit"

    return "interact: none"

def apply_action(state: WorldState, who: str, action: str) -> str:
    a = state.agents[who]
    if a.hp <= 0:
        return "dead"
    if action == "L":
        turn_left(a)
        return "turned left"
    if action == "R":
        turn_right(a)
        return "turned right"
    if action == "F":
        return move_forward(state, a)
    if action == "I":
        return try_interact(state, a)
    return "noop"

# -----------------------------
# Combat / hazards / win conditions
# -----------------------------
def resolve_hazards(state: WorldState, a: Agent) -> Tuple[bool, str]:
    # returns (took_damage, msg)
    if a.hp <= 0:
        return (False, "")
    t = state.grid[a.y][a.x]
    if t == HAZARD:
        a.hp -= 1
        return (True, "hazard:-hp")
    return (False, "")

def resolve_tags(state: WorldState) -> List[str]:
    # if opposing agents occupy the same tile, everyone on it takes 1 hp
    msgs = []
    occupied = {}
    for nm, a in state.agents.items():
        if a.hp <= 0:
            continue
        occupied.setdefault((a.x, a.y), []).append(nm)

    for (x, y), names in occupied.items():
        if len(names) < 2:
            continue
        teams = set(state.agents[n].team for n in names)
        if len(teams) >= 2:
            # tag: both sides take 1 hp damage; log who collided
            for n in names:
                state.agents[n].hp -= 1
            msgs.append(f"t={state.step}: collision/tag at ({x},{y}) {names} (-hp all)")
    return msgs

def check_done(state: WorldState) -> None:
    # Determine environment-specific terminal conditions
    if state.env_key == "chase":
        pred = state.agents["Predator"]
        prey = state.agents["Prey"]
        if pred.hp <= 0 and prey.hp <= 0:
            state.done = True
            state.outcome = "draw"
            return
        if pred.x == prey.x and pred.y == prey.y and pred.hp > 0 and prey.hp > 0:
            state.done = True
            state.outcome = "A_win"  # Predator is team A
            state.event_log.append(f"t={state.step}: CAUGHT (Predator wins).")
            return
        # Prey "escapes" by surviving to step 300
        if state.step >= 300 and prey.hp > 0:
            state.done = True
            state.outcome = "B_win"
            state.event_log.append(f"t={state.step}: ESCAPED (Prey survives).")
            return

    if state.env_key == "vault":
        # Team A wins if any A agent holds the artifact while standing on the exit
        for nm in ["Alpha", "Bravo"]:
            a = state.agents[nm]
            if a.hp > 0 and a.inventory.get("artifact", 0) > 0 and state.grid[a.y][a.x] == EXIT:
                state.done = True
                state.outcome = "A_win"
                state.event_log.append(f"t={state.step}: VAULT CLEARED (Team A wins).")
                return
        # Team B wins if all A agents are eliminated
        alive_A = any(state.agents[n].hp > 0 for n in ["Alpha", "Bravo"])
        if not alive_A:
            state.done = True
            state.outcome = "B_win"
            state.event_log.append(f"t={state.step}: TEAM A ELIMINATED (Guardian wins).")
            return

    if state.env_key == "civ":
        # Team A wins if base_progress reaches the target
        if state.base_progress >= state.base_target:
            state.done = True
            state.outcome = "A_win"
            state.event_log.append(f"t={state.step}: BASE COMPLETE (Builders win).")
            return
        # Team B wins if both builders are eliminated
        alive_A = any(state.agents[n].hp > 0 for n in ["BuilderA", "BuilderB"])
        if not alive_A:
            state.done = True
            state.outcome = "B_win"
            state.event_log.append(f"t={state.step}: BUILDERS ELIMINATED (Raider wins).")
            return
        # draw if the episode runs too long
        if state.step >= 350:
            state.done = True
            state.outcome = "draw"
            state.event_log.append(f"t={state.step}: TIMEOUT (draw).")
            return

# -----------------------------
# Rewards
# -----------------------------
def reward_for(state_prev: WorldState, state_now: WorldState, who: str, outcome_msg: str,
               took_damage: bool, interacted: bool) -> float:
    cfg = state_now.cfg

    r = cfg.step_penalty

    # exploration is approximated by a tiny bonus for any successful move;
    # newly discovered tiles are tracked separately in emetrics
    if outcome_msg.startswith("moved"):
        r += cfg.explore_reward

    if took_damage:
        r += cfg.damage_penalty

    # heal reward if a medkit was used
    if outcome_msg.startswith("used: medkit"):
        r += cfg.heal_reward

    # environment shaping
    if state_now.env_key == "chase":
        pred = state_now.agents["Predator"]
        prey = state_now.agents["Prey"]
        if who == "Predator":
            d0 = manhattan_xy(state_prev.agents["Predator"].x, state_prev.agents["Predator"].y,
                              state_prev.agents["Prey"].x, state_prev.agents["Prey"].y)
            d1 = manhattan_xy(pred.x, pred.y, prey.x, prey.y)
            r += cfg.chase_close_coeff * float(d0 - d1)
            if state_now.done and state_now.outcome == "A_win":
                r += cfg.chase_catch_reward
        if who == "Prey":
            if outcome_msg.startswith("ate: food"):
                r += cfg.food_reward
            if state_now.done and state_now.outcome == "B_win":
                r += cfg.chase_escaped_reward
            if state_now.done and state_now.outcome == "A_win":
                r += cfg.chase_caught_penalty

    if state_now.env_key == "vault":
        if outcome_msg.startswith("picked: artifact"):
            r += cfg.artifact_pick_reward
        if outcome_msg.startswith("picked: key"):
            r += cfg.key_reward
        if outcome_msg.startswith("switch:"):
            r += cfg.switch_reward
        if state_now.done:
            if state_now.outcome == "A_win" and state_now.agents[who].team == "A":
                r += cfg.exit_win_reward
            if state_now.outcome == "B_win" and state_now.agents[who].team == "B":
                r += cfg.guardian_tag_reward
            if state_now.outcome == "B_win" and state_now.agents[who].team == "A":
                r += cfg.tagged_penalty

    if state_now.env_key == "civ":
        if outcome_msg.startswith("picked: wood") or outcome_msg.startswith("picked: ore"):
            r += cfg.resource_pick_reward
        if outcome_msg.startswith("deposited:"):
            r += cfg.deposit_reward
        if state_now.done:
            if state_now.outcome == "A_win" and state_now.agents[who].team == "A":
                r += cfg.base_progress_win_reward
            if state_now.outcome == "B_win" and state_now.agents[who].team == "B":
                r += cfg.raider_elim_reward
            if state_now.outcome == "B_win" and state_now.agents[who].team == "A":
                r += cfg.builder_elim_penalty

    return float(r)
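
# Worked example (added, hypothetical step): a Predator that moves and closes
# the gap to the Prey by 2 tiles collects
#   r = step_penalty + explore_reward + chase_close_coeff * (d0 - d1)
#     = -0.01 + 0.015 + 0.03 * 2 = 0.065
# so distance shaping dominates the per-step noise long before a catch (+3.0).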

# -----------------------------
# Q / policy selection
# -----------------------------
def choose_action(state: WorldState, who: str, stream: int) -> Tuple[str, str, Optional[Tuple[str, int]]]:
    """
    Returns (action, reason, q_info)
    q_info: (obs_key, action_index) if Q-based else None
    """
    a = state.agents[who]
    cfg = state.cfg
    r = rng_for(state.seed, state.step, stream=stream)

    # manual control is handled outside
    if a.brain == "random":
        act = random_action(state, who)
        return act, "random", None
    if a.brain == "heuristic":
        act = heuristic_action(state, who)
        return act, "heuristic", None

    # Q-learning
    if cfg.use_q:
        key = obs_key(state, who)
        qtab = state.q_tables.setdefault(who, {})
        qv = q_get(qtab, key)
        a_idx = epsilon_greedy(qv, state.gmetrics.epsilon, r)
        return ACTIONS[a_idx], f"Q eps={state.gmetrics.epsilon:.3f} q={np.round(qv, 3).tolist()}", (key, a_idx)

    act = heuristic_action(state, who)
    return act, "heuristic(fallback)", None

# -----------------------------
# Episode initialization / reset
# -----------------------------
def init_state(seed: int, env_key: str) -> WorldState:
    g, agents = ENV_BUILDERS[env_key](seed)

    st = WorldState(
        seed=seed,
        step=0,
        env_key=env_key,
        grid=g,
        agents=agents,
        controlled=list(agents.keys())[0],
        pov=list(agents.keys())[0],
        overlay=False,
        done=False,
        outcome="ongoing",
        door_opened_global=False,
        base_progress=0,
        base_target=10,
    )
    st.event_log = [f"Initialized env={env_key} seed={seed}."]
    return st

def reset_episode_keep_learning(state: WorldState, seed: Optional[int] = None) -> WorldState:
    if seed is None:
        seed = state.seed
    fresh = init_state(int(seed), state.env_key)
    # carry learning + global metrics across episodes
    fresh.cfg = state.cfg
    fresh.q_tables = state.q_tables
    fresh.gmetrics = state.gmetrics
    fresh.gmetrics.epsilon = state.gmetrics.epsilon
    return fresh

def wipe_all(seed: int, env_key: str) -> WorldState:
    st = init_state(seed, env_key)
    st.cfg = TrainConfig()
    st.gmetrics = GlobalMetrics(epsilon=st.cfg.epsilon)
    st.q_tables = {}
    return st

# -----------------------------
# History / branching
# -----------------------------
TRACE_MAX = 500
MAX_HISTORY = 1400

def snapshot_of(state: WorldState, branch: str) -> Snapshot:
    return Snapshot(
        branch=branch,
        step=state.step,
        env_key=state.env_key,
        grid=[row[:] for row in state.grid],
        agents={k: asdict(v) for k, v in state.agents.items()},
        done=state.done,
        outcome=state.outcome,
        door_opened_global=state.door_opened_global,
        base_progress=state.base_progress,
        base_target=state.base_target,
        event_tail=state.event_log[-25:],
        trace_tail=state.trace_log[-40:],
        emetrics=asdict(state.emetrics),
    )

def restore_into(state: WorldState, snap: Snapshot) -> WorldState:
    state.step = snap.step
    state.env_key = snap.env_key
    state.grid = [row[:] for row in snap.grid]
    state.agents = {k: Agent(**d) for k, d in snap.agents.items()}
    state.done = snap.done
    state.outcome = snap.outcome
    state.door_opened_global = snap.door_opened_global
    state.base_progress = snap.base_progress
    state.base_target = snap.base_target
    state.event_log.append(f"Jumped to snapshot t={snap.step} (branch={snap.branch}).")
    return state

# -----------------------------
# Metrics / dashboard
# -----------------------------
def update_action_counts(state: WorldState, who: str, act: str):
    state.emetrics.action_counts.setdefault(who, {})
    state.emetrics.action_counts[who][act] = state.emetrics.action_counts[who].get(act, 0) + 1

def action_entropy(counts: Dict[str, int]) -> float:
    total = sum(counts.values())
    if total <= 0:
        return 0.0
    p = np.array([c / total for c in counts.values()], dtype=np.float64)
    p = np.clip(p, 1e-12, 1.0)
    return float(-np.sum(p * np.log2(p)))
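
# Illustrative example (added): entropy is in bits, so an agent that used the
# four actions uniformly (counts {"L": 5, "F": 5, "R": 5, "I": 5}) scores
# -4 * (0.25 * log2(0.25)) = 2.0, while an agent that only ever moved forward
# scores 0.0, a quick policy-collapse indicator on the scoreboard.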

def metrics_dashboard_image(state: WorldState) -> Image.Image:
    gm = state.gmetrics
    fig = plt.figure(figsize=(7.0, 2.2), dpi=120)
    ax = fig.add_subplot(111)
    ax.plot([0, gm.episodes], [gm.rolling_winrate_A, gm.rolling_winrate_A])
    ax.set_title("Global Metrics Snapshot")
    ax.set_xlabel("Episodes (scalar)")
    ax.set_ylabel("Rolling winrate Team A")
    ax.grid(True)

    # annotate
    txt = (
        f"env={state.env_key} | eps={gm.epsilon:.3f} | episodes={gm.episodes}\n"
        f"A_wins={gm.wins_teamA} B_wins={gm.wins_teamB} draws={gm.draws} | "
        f"avg_steps~{gm.avg_steps:.1f}\n"
        f"last_outcome={gm.last_outcome} last_steps={gm.last_steps}"
    )
    ax.text(0.01, 0.05, txt, transform=ax.transAxes, fontsize=8, va="bottom")

    fig.tight_layout()
    fig.canvas.draw()
    # buffer_rgba() replaces the tostring_rgb() API removed in newer Matplotlib
    img = np.asarray(fig.canvas.buffer_rgba(), dtype=np.uint8)[:, :, :3].copy()
    plt.close(fig)
    return Image.fromarray(img)

def agent_scoreboard(state: WorldState) -> str:
    rows = []
    header = ["agent", "team", "hp", "return", "steps", "entropy", "tiles_disc", "q_states", "inventory"]
    rows.append(header)

    steps = state.emetrics.steps
    for nm, a in state.agents.items():
        ret = state.emetrics.returns.get(nm, 0.0)
        counts = state.emetrics.action_counts.get(nm, {})
        ent = action_entropy(counts)
        td = state.emetrics.tiles_discovered.get(nm, 0)
        qs = len(state.q_tables.get(nm, {}))
        inv = json.dumps(a.inventory, sort_keys=True)
        rows.append([nm, a.team, a.hp, f"{ret:.2f}", steps, f"{ent:.2f}", td, qs, inv])

    # pretty-print as a fixed-width table
    col_w = [max(len(str(r[i])) for r in rows) for i in range(len(header))]
    lines = []
    for ridx, r in enumerate(rows):
        line = " | ".join(str(r[i]).ljust(col_w[i]) for i in range(len(header)))
        lines.append(line)
        if ridx == 0:
            lines.append("-+-".join("-" * w for w in col_w))
    return "\n".join(lines)

# -----------------------------
# Tick (core simulation step)
# -----------------------------
def clone_shallow(state: WorldState) -> WorldState:
    # minimal clone used to compute rewards against the pre-step state
    st = WorldState(
        seed=state.seed,
        step=state.step,
        env_key=state.env_key,
        grid=[row[:] for row in state.grid],
        agents={k: Agent(**asdict(v)) for k, v in state.agents.items()},
        controlled=state.controlled,
        pov=state.pov,
        overlay=state.overlay,
        done=state.done,
        outcome=state.outcome,
        door_opened_global=state.door_opened_global,
        base_progress=state.base_progress,
        base_target=state.base_target,
        event_log=list(state.event_log),
        trace_log=list(state.trace_log),
        cfg=state.cfg,
        q_tables=state.q_tables,
        gmetrics=state.gmetrics,
        emetrics=state.emetrics,
    )
    return st

def tick(state: WorldState, beliefs: Dict[str, np.ndarray], manual_action: Optional[str] = None) -> None:
    if state.done:
        return

    prev = clone_shallow(state)

    # pick actions
    chosen: Dict[str, str] = {}
    reasons: Dict[str, str] = {}
    qinfo: Dict[str, Optional[Tuple[str, int]]] = {}

    if manual_action is not None:
        chosen[state.controlled] = manual_action
        reasons[state.controlled] = "manual"
        qinfo[state.controlled] = None

    # the other agents choose for themselves
    for who in list(state.agents.keys()):
        if who in chosen:
            continue
        act, reason, qi = choose_action(state, who, stream=200 + (hash(who) % 1000))
        chosen[who] = act
        reasons[who] = reason
        qinfo[who] = qi

    # apply actions in fixed order (deterministic)
    order = list(state.agents.keys())
    outcomes: Dict[str, str] = {}
    took_damage: Dict[str, bool] = {nm: False for nm in order}
    interacted: Dict[str, bool] = {nm: False for nm in order}

    for who in order:
        outcomes[who] = apply_action(state, who, chosen[who])
        if chosen[who] == "I" and outcomes[who] != "interact: none":
            interacted[who] = True

        dmg, msg = resolve_hazards(state, state.agents[who])
        took_damage[who] = dmg
        if msg:
            state.event_log.append(f"t={state.step}: {who} {msg}")

        # track action counts
        update_action_counts(state, who, chosen[who])

    # collisions/tags after movement
    tag_msgs = resolve_tags(state)
    for m in tag_msgs:
        state.event_log.append(m)

    # update beliefs / tiles-discovered metric
    for nm, a in state.agents.items():
        if a.hp <= 0:
            continue
        before_unknown = int(np.sum(beliefs[nm] == -1))
        update_belief_for_agent(state, beliefs[nm], a)
        after_unknown = int(np.sum(beliefs[nm] == -1))
        discovered = max(0, before_unknown - after_unknown)
        state.emetrics.tiles_discovered[nm] = state.emetrics.tiles_discovered.get(nm, 0) + discovered

    # check terminal conditions
    check_done(state)

    # rewards + Q updates + returns
    q_lines = []
    for who in order:
        if who not in state.emetrics.returns:
            state.emetrics.returns[who] = 0.0
        r = reward_for(prev, state, who, outcomes[who], took_damage[who], interacted[who])
        state.emetrics.returns[who] += r

        if qinfo.get(who) is not None:
            key, a_idx = qinfo[who]
            next_key = obs_key(state, who)
            qtab = state.q_tables.setdefault(who, {})
            old, tgt, new = q_update(qtab, key, a_idx, r, next_key, state.cfg.alpha, state.cfg.gamma)
            q_lines.append(f"{who}: old={old:.3f} tgt={tgt:.3f} new={new:.3f} (a={ACTIONS[a_idx]})")

    # trace
    trace = f"t={state.step} env={state.env_key} done={state.done} outcome={state.outcome}"
    for who in order:
        a = state.agents[who]
        trace += f" | {who}:{chosen[who]} ({outcomes[who]}) hp={a.hp} [{reasons[who]}]"
    if q_lines:
        trace += " | Q: " + " ; ".join(q_lines)

    state.trace_log.append(trace)
    if len(state.trace_log) > TRACE_MAX:
        state.trace_log = state.trace_log[-TRACE_MAX:]

    state.step += 1
    state.emetrics.steps = state.step

# -----------------------------
# Training
# -----------------------------
def run_episode(state: WorldState, beliefs: Dict[str, np.ndarray], max_steps: int) -> Tuple[str, int]:
    while state.step < max_steps and not state.done:
        tick(state, beliefs, manual_action=None)
    return state.outcome, state.step

def update_global_metrics_after_episode(state: WorldState, outcome: str, steps: int):
    gm = state.gmetrics
    gm.episodes += 1
    gm.last_outcome = outcome
    gm.last_steps = steps

    if outcome == "A_win":
        gm.wins_teamA += 1
        gm.rolling_winrate_A = 0.90 * gm.rolling_winrate_A + 0.10 * 1.0
    elif outcome == "B_win":
        gm.wins_teamB += 1
        gm.rolling_winrate_A = 0.90 * gm.rolling_winrate_A + 0.10 * 0.0
    else:
        gm.draws += 1
        gm.rolling_winrate_A = 0.90 * gm.rolling_winrate_A + 0.10 * 0.5

    gm.avg_steps = (0.90 * gm.avg_steps + 0.10 * steps) if gm.avg_steps > 0 else float(steps)

    # epsilon decay
    gm.epsilon = max(state.cfg.epsilon_min, gm.epsilon * state.cfg.epsilon_decay)
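
# Worked example (added, hypothetical numbers): the rolling winrate is an EMA
# with weight 0.10, so from 0.50 a Team-A win moves it to
#   0.90 * 0.50 + 0.10 * 1.0 = 0.55
# and roughly the last ~10 episodes dominate the estimate. Epsilon decays the
# same way each episode: 0.10 * 0.995 -> 0.0995, floored at epsilon_min.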

def train(state: WorldState, episodes: int, max_steps: int) -> WorldState:
    for ep in range(episodes):
        # vary the seed deterministically per episode
        ep_seed = (state.seed * 1_000_003 + (state.gmetrics.episodes + ep) * 97_531) & 0xFFFFFFFF
        state = reset_episode_keep_learning(state, seed=int(ep_seed))
        beliefs = init_beliefs(list(state.agents.keys()))
        outcome, steps = run_episode(state, beliefs, max_steps=max_steps)
        update_global_metrics_after_episode(state, outcome, steps)
    # after training, return a clean episode at the current seed;
    # the summary is appended after the reset so it survives in the fresh log
    state = reset_episode_keep_learning(state, seed=state.seed)
    state.event_log.append(
        f"Training: +{episodes} eps | eps={state.gmetrics.epsilon:.3f} | "
        f"A={state.gmetrics.wins_teamA} B={state.gmetrics.wins_teamB} D={state.gmetrics.draws}"
    )
    return state

# -----------------------------
# Export / Import
# -----------------------------
def export_run(state: WorldState, branches: Dict[str, List[Snapshot]], active_branch: str, rewind_idx: int) -> str:
    payload = {
        "seed": state.seed,
        "env_key": state.env_key,
        "controlled": state.controlled,
        "pov": state.pov,
        "overlay": state.overlay,
        "cfg": asdict(state.cfg),
        "gmetrics": asdict(state.gmetrics),
        "q_tables": state.q_tables,
        "branches": {b: [asdict(s) for s in snaps] for b, snaps in branches.items()},
        "active_branch": active_branch,
        "rewind_idx": rewind_idx,
        "grid": state.grid,
        "door_opened_global": state.door_opened_global,
        "base_progress": state.base_progress,
        "base_target": state.base_target,
    }
    txt = json.dumps(payload, indent=2)
    proof = hash_sha256(txt)
    return txt + "\n\n" + json.dumps({"proof_sha256": proof}, indent=2)
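
# Minimal verification sketch (added, not part of the original app): a reader
# can re-check the "Proof-of-Run" hash offline by splitting the export on the
# blank line and re-hashing the payload half:
#
#   def verify_export(exported: str) -> bool:
#       payload, proof_block = exported.rsplit("\n\n", 1)
#       claimed = json.loads(proof_block)["proof_sha256"]
#       return hash_sha256(payload) == claimed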

def import_run(txt: str) -> Tuple[WorldState, Dict[str, List[Snapshot]], str, int, Dict[str, np.ndarray]]:
    # tolerate the trailing proof block appended by export_run
    parts = txt.strip().split("\n\n")
    data = json.loads(parts[0])

    st = init_state(int(data.get("seed", 1337)), data.get("env_key", "chase"))
    st.controlled = data.get("controlled", st.controlled)
    st.pov = data.get("pov", st.pov)
    st.overlay = bool(data.get("overlay", False))
    st.grid = data.get("grid", st.grid)
    st.door_opened_global = bool(data.get("door_opened_global", False))
    st.base_progress = int(data.get("base_progress", 0))
    st.base_target = int(data.get("base_target", 10))

    st.cfg = TrainConfig(**data.get("cfg", asdict(st.cfg)))
    st.gmetrics = GlobalMetrics(**data.get("gmetrics", asdict(st.gmetrics)))
    st.q_tables = data.get("q_tables", {})

    branches_in = data.get("branches", {})
    branches: Dict[str, List[Snapshot]] = {}
    for bname, snaps in branches_in.items():
        branches[bname] = [Snapshot(**s) for s in snaps]

    active = data.get("active_branch", "main")
    r_idx = int(data.get("rewind_idx", 0))

    # restore the last snapshot of the active branch, if any
    if active in branches and branches[active]:
        st = restore_into(st, branches[active][-1])
        st.event_log.append("Imported run (restored last snapshot).")
    else:
        st.event_log.append("Imported run (no snapshots).")

    beliefs = init_beliefs(list(st.agents.keys()))
    return st, branches, active, r_idx, beliefs

# -----------------------------
# UI helpers
# -----------------------------
def build_views(state: WorldState, beliefs: Dict[str, np.ndarray]) -> Tuple[np.ndarray, Image.Image, Image.Image, Image.Image, Image.Image, str, str, str, str]:
    # update beliefs
    for nm, a in state.agents.items():
        if a.hp > 0:
            update_belief_for_agent(state, beliefs[nm], a)

    pov = raycast_view(state, state.agents[state.pov])
    truth_np = np.array(state.grid, dtype=np.int16)
    truth_img = render_topdown(truth_np, state.agents, f"Truth Map — env={state.env_key} t={state.step} seed={state.seed}", show_agents=True)

    ctrl = state.controlled
    # pick someone else's belief map as the "other" view
    others = [k for k in state.agents.keys() if k != ctrl]
    other = others[0] if others else ctrl
    b_ctrl = render_topdown(beliefs[ctrl], state.agents, f"{ctrl} Belief", show_agents=True)
    b_other = render_topdown(beliefs[other], state.agents, f"{other} Belief", show_agents=True)

    dash = metrics_dashboard_image(state)

    status = (
        f"env={state.env_key} | Controlled={state.controlled} | POV={state.pov} | done={state.done} outcome={state.outcome}\n"
        f"Episode steps={state.step} | base_progress={state.base_progress}/{state.base_target} | doors_open={state.door_opened_global}\n"
        f"Global: episodes={state.gmetrics.episodes} | A={state.gmetrics.wins_teamA} B={state.gmetrics.wins_teamB} D={state.gmetrics.draws} "
        f"| winrateA~{state.gmetrics.rolling_winrate_A:.2f} | eps={state.gmetrics.epsilon:.3f}"
    )
    events = "\n".join(state.event_log[-18:])
    trace = "\n".join(state.trace_log[-18:])
    scoreboard = agent_scoreboard(state)
    return pov, truth_img, b_ctrl, b_other, dash, status, events, trace, scoreboard

def grid_click_to_tile(evt: gr.SelectData, selected_tile: int, state: WorldState) -> WorldState:
    x_px, y_px = evt.index
    y_px -= 28  # skip the title bar drawn above the grid
    if y_px < 0:
        return state
    gx = int(x_px // TILE)
    gy = int(y_px // TILE)
    if not in_bounds(gx, gy):
        return state
    # keep border walls fixed
    if gx == 0 or gy == 0 or gx == GRID_W - 1 or gy == GRID_H - 1:
        return state
    state.grid[gy][gx] = selected_tile
    state.event_log.append(f"t={state.step}: Tile ({gx},{gy}) -> {TILE_NAMES.get(selected_tile)}")
    return state
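
# Worked example (added, hypothetical click): render_topdown draws a 28 px
# title bar, so a click at pixel (100, 80) maps to y_px = 80 - 28 = 52 and
# grid cell (gx, gy) = (100 // 22, 52 // 22) = (4, 2), i.e. the cell that gets
# painted with the selected tile type.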
1453
+
1454
+ # -----------------------------
1455
+ # Gradio App
1456
+ # -----------------------------
1457
+ TITLE = "ZEN AgentLab — Agent POV + Multi-Agent Mini-Sim Arena"
1458
+
1459
+ with gr.Blocks(title=TITLE) as demo:
1460
+ gr.Markdown(
1461
+ f"## {TITLE}\n"
1462
+ "A multi-environment agent observatory with POV, belief maps, branching timelines, training, and metrics.\n"
1463
+ "**Controls:** No timers. Use Tick / Run / Train for deterministic experiments."
1464
+ )
1465
+
1466
+ # Core state
1467
+ st = gr.State(init_state(1337, "chase"))
1468
+ branches = gr.State({"main": [snapshot_of(init_state(1337, "chase"), "main")]})
1469
+ active_branch = gr.State("main")
1470
+ rewind_idx = gr.State(0)
1471
+ beliefs = gr.State(init_beliefs(list(init_state(1337, "chase").agents.keys())))
1472
+
1473
+ with gr.Row():
1474
+ pov_img = gr.Image(label="POV (Pseudo-3D)", type="numpy", width=VIEW_W, height=VIEW_H)
1475
+ with gr.Column():
1476
+ status = gr.Textbox(label="Status", lines=3)
1477
+ scoreboard = gr.Textbox(label="Agent Scoreboard", lines=8)
1478
+
1479
+ with gr.Row():
1480
+ truth = gr.Image(label="Truth Map (click to edit tiles)", type="pil")
1481
+ belief_a = gr.Image(label="Belief (Controlled)", type="pil")
1482
+ belief_b = gr.Image(label="Belief (Other)", type="pil")
1483
+
1484
+ with gr.Row():
1485
+ dash = gr.Image(label="Metrics Dashboard", type="pil")
1486
+
1487
+ with gr.Row():
1488
+ events = gr.Textbox(label="Event Log", lines=10)
1489
+ trace = gr.Textbox(label="Step Trace (why it happened)", lines=10)
1490
+
1491
+ with gr.Row():
1492
+ with gr.Column(scale=2):
1493
+ gr.Markdown("### Manual Controls")
1494
+ with gr.Row():
1495
+ btn_L = gr.Button("L")
1496
+ btn_F = gr.Button("F")
1497
+ btn_R = gr.Button("R")
1498
+ btn_I = gr.Button("I (Interact)")
1499
+ with gr.Row():
1500
+ btn_tick = gr.Button("Tick")
1501
+ run_steps = gr.Number(value=25, label="Run N steps", precision=0)
1502
+ btn_run = gr.Button("Run")
1503
+
1504
+ with gr.Row():
1505
+ btn_toggle_control = gr.Button("Toggle Controlled")
1506
+ btn_toggle_pov = gr.Button("Toggle POV")
1507
+ overlay = gr.Checkbox(False, label="Overlay reticle")
1508
+
1509
+ gr.Markdown("### Environment + Edit")
1510
+ env_pick = gr.Radio(
1511
+ choices=[("Chase (Predator vs Prey)", "chase"),
1512
+ ("CoopVault (team vs guardian)", "vault"),
1513
+ ("MiniCiv (build + raid)", "civ")],
1514
+ value="chase",
1515
+ label="Environment"
1516
+ )
1517
+ tile_pick = gr.Radio(
1518
+ choices=[(TILE_NAMES[k], k) for k in [EMPTY, WALL, FOOD, NOISE, DOOR, TELE, KEY, EXIT, ARTIFACT, HAZARD, WOOD, ORE, MEDKIT, SWITCH, BASE]],
1519
+ value=WALL,
1520
+ label="Paint tile type"
1521
+ )
1522
+
+        with gr.Column(scale=3):
+            gr.Markdown("### Training Controls (Tabular Q-learning)")
+            use_q = gr.Checkbox(True, label="Use Q-learning (agents with brain='q')")
+            alpha = gr.Slider(0.01, 0.5, value=0.15, step=0.01, label="alpha (learn rate)")
+            gamma = gr.Slider(0.5, 0.99, value=0.95, step=0.01, label="gamma (discount)")
+            eps = gr.Slider(0.0, 0.5, value=0.10, step=0.01, label="epsilon (exploration)")
+            eps_decay = gr.Slider(0.90, 0.999, value=0.995, step=0.001, label="epsilon decay")
+            eps_min = gr.Slider(0.0, 0.2, value=0.02, step=0.01, label="epsilon min")
+
+            episodes = gr.Number(value=50, label="Train episodes", precision=0)
+            max_steps = gr.Number(value=260, label="Max steps per episode", precision=0)
+            btn_train = gr.Button("Train")
+
+            btn_reset = gr.Button("Reset Episode (keep learning)")
+            btn_reset_all = gr.Button("Reset ALL (wipe Q + metrics)")
+
+    with gr.Row():
+        with gr.Column(scale=2):
+            gr.Markdown("### Timeline + Branching")
+            rewind = gr.Slider(0, 0, value=0, step=1, label="Rewind index (active branch)")
+            btn_jump = gr.Button("Jump to index")
+            new_branch_name = gr.Textbox(value="fork1", label="New branch name")
+            btn_fork = gr.Button("Fork from current rewind")
+
+        with gr.Column(scale=2):
+            branch_pick = gr.Dropdown(choices=["main"], value="main", label="Active branch")
+            btn_set_branch = gr.Button("Set Active Branch")
+
+        with gr.Column(scale=3):
+            export_box = gr.Textbox(label="Export JSON (+ proof hash)", lines=8)
+            btn_export = gr.Button("Export")
+            import_box = gr.Textbox(label="Import JSON", lines=8)
+            btn_import = gr.Button("Import")
+
+    # ---------- UI glue ----------
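+    # Every handler below follows one contract: it returns the 13 view
+    # values produced by refresh() (5 images, 4 textboxes, a slider update,
+    # the rewind index, and two dropdown updates) followed by the 5 session
+    # states (st, branches, active_branch, beliefs, rewind_idx) -- 18 output
+    # slots in total, matching the outputs lists in the wiring section.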
+    def refresh(state: WorldState, branches_d: Dict[str, List[Snapshot]], active: str, bel: Dict[str, np.ndarray], r: int):
+        snaps = branches_d.get(active, [])
+        r_max = max(0, len(snaps) - 1)
+        r = max(0, min(int(r), r_max))
+        pov, tr, ba, bb, dimg, stxt, etxt, ttxt, sb = build_views(state, bel)
+
+        branch_choices = sorted(list(branches_d.keys()))
+        return (
+            pov, tr, ba, bb, dimg,
+            stxt, sb, etxt, ttxt,
+            gr.update(maximum=r_max, value=r),
+            r,
+            # the dropdown update is emitted twice because every outputs
+            # list below lists branch_pick in two slots
+            gr.update(choices=branch_choices, value=active),
+            gr.update(choices=branch_choices, value=active),
+        )
+
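+    # push_hist keeps a bounded history: once a branch holds MAX_HISTORY
+    # snapshots, the oldest is dropped, so old rewind indices slide off.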
+    def push_hist(state: WorldState, branches_d: Dict[str, List[Snapshot]], active: str) -> Dict[str, List[Snapshot]]:
+        branches_d.setdefault(active, [])
+        branches_d[active].append(snapshot_of(state, active))
+        if len(branches_d[active]) > MAX_HISTORY:
+            branches_d[active].pop(0)
+        return branches_d
+
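+    # Note: the live epsilon value is stored on state.gmetrics, while
+    # alpha/gamma and the epsilon decay/floor live on state.cfg.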
+    def set_cfg(state: WorldState, use_q_v: bool, a: float, g: float, e: float, ed: float, emin: float) -> WorldState:
+        state.cfg.use_q = bool(use_q_v)
+        state.cfg.alpha = float(a)
+        state.cfg.gamma = float(g)
+        state.gmetrics.epsilon = float(e)
+        state.cfg.epsilon_decay = float(ed)
+        state.cfg.epsilon_min = float(emin)
+        return state
+
+    def do_manual(state, branches_d, active, bel, r, act):
+        tick(state, bel, manual_action=act)
+        branches_d = push_hist(state, branches_d, active)
+        r = len(branches_d[active]) - 1
+        out = refresh(state, branches_d, active, bel, r)
+        return out + (state, branches_d, active, bel, r)
+
+    def do_tick(state, branches_d, active, bel, r):
+        tick(state, bel, manual_action=None)
+        branches_d = push_hist(state, branches_d, active)
+        r = len(branches_d[active]) - 1
+        out = refresh(state, branches_d, active, bel, r)
+        return out + (state, branches_d, active, bel, r)
+
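+    # Note: do_run snapshots once after the loop, so "Run N" adds a single
+    # rewind entry (the end state) rather than one entry per step.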
+    def do_run(state, branches_d, active, bel, r, n):
+        n = max(1, int(n))
+        for _ in range(n):
+            if state.done:
+                break
+            tick(state, bel, manual_action=None)
+        branches_d = push_hist(state, branches_d, active)
+        r = len(branches_d[active]) - 1
+        out = refresh(state, branches_d, active, bel, r)
+        return out + (state, branches_d, active, bel, r)
+
+    def toggle_control(state, branches_d, active, bel, r):
+        order = list(state.agents.keys())
+        i = order.index(state.controlled)
+        state.controlled = order[(i + 1) % len(order)]
+        state.event_log.append(f"Controlled -> {state.controlled}")
+        branches_d = push_hist(state, branches_d, active)
+        r = len(branches_d[active]) - 1
+        out = refresh(state, branches_d, active, bel, r)
+        return out + (state, branches_d, active, bel, r)
+
+    def toggle_pov(state, branches_d, active, bel, r):
+        order = list(state.agents.keys())
+        i = order.index(state.pov)
+        state.pov = order[(i + 1) % len(order)]
+        state.event_log.append(f"POV -> {state.pov}")
+        branches_d = push_hist(state, branches_d, active)
+        r = len(branches_d[active]) - 1
+        out = refresh(state, branches_d, active, bel, r)
+        return out + (state, branches_d, active, bel, r)
+
+    def set_overlay(state, branches_d, active, bel, r, ov):
+        # view-only toggle: does not advance the sim or push history
+        state.overlay = bool(ov)
+        out = refresh(state, branches_d, active, bel, r)
+        return out + (state, branches_d, active, bel, r)
+
+    def click_truth(tile, state, branches_d, active, bel, r, evt: gr.SelectData):
+        state = grid_click_to_tile(evt, int(tile), state)
+        branches_d = push_hist(state, branches_d, active)
+        r = len(branches_d[active]) - 1
+        out = refresh(state, branches_d, active, bel, r)
+        return out + (state, branches_d, active, bel, r)
+
+    def jump(state, branches_d, active, bel, r, idx):
+        snaps = branches_d.get(active, [])
+        if not snaps:
+            out = refresh(state, branches_d, active, bel, r)
+            return out + (state, branches_d, active, bel, r)
+        idx = max(0, min(int(idx), len(snaps) - 1))
+        state = restore_into(state, snaps[idx])
+        r = idx
+        out = refresh(state, branches_d, active, bel, r)
+        return out + (state, branches_d, active, bel, r)
+
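+    # Forking copies the active branch's snapshots up to the rewind index.
+    # Caveat (assumption about Snapshot's shape): asdict() recursively
+    # converts nested dataclasses to plain dicts, so Snapshot(**asdict(s))
+    # is only a faithful deep copy if Snapshot's fields are plain data or
+    # restore_into tolerates dicts in their place.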
+    def fork_branch(state, branches_d, active, bel, r, new_name):
+        new_name = (new_name or "").strip() or "fork"
+        new_name = new_name.replace(" ", "_")
+        snaps = branches_d.get(active, [])
+        if not snaps:
+            branches_d[new_name] = []
+            branches_d[new_name].append(snapshot_of(state, new_name))
+        else:
+            idx = max(0, min(int(r), len(snaps) - 1))
+            # fork snapshots up to idx (inclusive)
+            branches_d[new_name] = [Snapshot(**asdict(s)) for s in snaps[:idx + 1]]
+            # restore state at fork point (last fork snap)
+            state = restore_into(state, branches_d[new_name][-1])
+        active = new_name
+        state.event_log.append(f"Forked branch -> {new_name}")
+        branches_d = push_hist(state, branches_d, active)
+        r = len(branches_d[active]) - 1
+        out = refresh(state, branches_d, active, bel, r)
+        return out + (state, branches_d, active, bel, r)
+
+    def set_active_branch(state, branches_d, active, bel, r, br):
+        br = br or "main"
+        if br not in branches_d:
+            branches_d[br] = [snapshot_of(state, br)]
+        active = br
+        # restore the latest state on that branch; belief maps are not part
+        # of snapshots, so they are re-initialized on switch
+        if branches_d[active]:
+            state = restore_into(state, branches_d[active][-1])
+            bel = init_beliefs(list(state.agents.keys()))
+        r = len(branches_d[active]) - 1
+        out = refresh(state, branches_d, active, bel, r)
+        return out + (state, branches_d, active, bel, r)
+
+    def change_env(state, branches_d, active, bel, r, env_key):
+        env_key = env_key or "chase"
+        # reset the episode but keep the learning tables (they are keyed
+        # per agent, so they carry over safely)
+        state.env_key = env_key
+        state = reset_episode_keep_learning(state, seed=state.seed)
+        bel = init_beliefs(list(state.agents.keys()))
+        active = "main"
+        branches_d = {"main": [snapshot_of(state, "main")]}
+        r = 0
+        out = refresh(state, branches_d, active, bel, r)
+        return out + (state, branches_d, active, bel, r)
+
+    def reset_ep(state, branches_d, active, bel, r):
+        state = reset_episode_keep_learning(state, seed=state.seed)
+        bel = init_beliefs(list(state.agents.keys()))
+        branches_d = {active: [snapshot_of(state, active)]}
+        r = 0
+        out = refresh(state, branches_d, active, bel, r)
+        return out + (state, branches_d, active, bel, r)
+
+    def reset_all(state, branches_d, active, bel, r, env_key):
+        env_key = env_key or state.env_key
+        state = wipe_all(seed=state.seed, env_key=env_key)
+        bel = init_beliefs(list(state.agents.keys()))
+        active = "main"
+        branches_d = {"main": [snapshot_of(state, "main")]}
+        r = 0
+        out = refresh(state, branches_d, active, bel, r)
+        return out + (state, branches_d, active, bel, r)
+
+    def do_train(state, branches_d, active, bel, r,
+                 use_q_v, a, g, e, ed, emin,
+                 eps_count, max_s):
+        state = set_cfg(state, use_q_v, a, g, e, ed, emin)
+        state = train(state, episodes=max(1, int(eps_count)), max_steps=max(10, int(max_s)))
+        bel = init_beliefs(list(state.agents.keys()))
+        # training rebuilds history: branches collapse back to a fresh "main"
+        branches_d = {"main": [snapshot_of(state, "main")]}
+        active = "main"
+        r = 0
+        out = refresh(state, branches_d, active, bel, r)
+        return out + (state, branches_d, active, bel, r)
+
+    def export_fn(state, branches_d, active, r):
+        return export_run(state, branches_d, active, int(r))
+
+    def import_fn(txt):
+        state, branches_d, active, r, bel = import_run(txt)
+        # ensure at least one snapshot exists on the active branch
+        branches_d.setdefault(active, [])
+        if not branches_d[active]:
+            branches_d[active].append(snapshot_of(state, active))
+        out = refresh(state, branches_d, active, bel, r)
+        return out + (state, branches_d, active, bel, r)
+
+    # ----- Wire buttons (no fn_kwargs) -----
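+    # All listeners below share the same 18-component outputs list: the 13
+    # refresh() views plus the 5 states (branch_pick and rewind_idx each
+    # occupy two slots, fed identical values). A de-duplication sketch --
+    # not applied here, just illustrative:
+    #
+    #   OUTS = [pov_img, truth, belief_a, belief_b, dash, status, scoreboard,
+    #           events, trace, rewind, rewind_idx, branch_pick, branch_pick,
+    #           st, branches, active_branch, beliefs, rewind_idx]
+    #   STATES = [st, branches, active_branch, beliefs, rewind_idx]
+    #   btn_tick.click(do_tick, inputs=STATES, outputs=OUTS, queue=True)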
+    btn_L.click(lambda s, b, a, bel, r: do_manual(s, b, a, bel, r, "L"),
+                inputs=[st, branches, active_branch, beliefs, rewind_idx],
+                outputs=[pov_img, truth, belief_a, belief_b, dash, status, scoreboard, events, trace,
+                         rewind, rewind_idx, branch_pick, branch_pick, st, branches, active_branch, beliefs, rewind_idx],
+                queue=True)
+
+    btn_F.click(lambda s, b, a, bel, r: do_manual(s, b, a, bel, r, "F"),
+                inputs=[st, branches, active_branch, beliefs, rewind_idx],
+                outputs=[pov_img, truth, belief_a, belief_b, dash, status, scoreboard, events, trace,
+                         rewind, rewind_idx, branch_pick, branch_pick, st, branches, active_branch, beliefs, rewind_idx],
+                queue=True)
+
+    btn_R.click(lambda s, b, a, bel, r: do_manual(s, b, a, bel, r, "R"),
+                inputs=[st, branches, active_branch, beliefs, rewind_idx],
+                outputs=[pov_img, truth, belief_a, belief_b, dash, status, scoreboard, events, trace,
+                         rewind, rewind_idx, branch_pick, branch_pick, st, branches, active_branch, beliefs, rewind_idx],
+                queue=True)
+
+    btn_I.click(lambda s, b, a, bel, r: do_manual(s, b, a, bel, r, "I"),
+                inputs=[st, branches, active_branch, beliefs, rewind_idx],
+                outputs=[pov_img, truth, belief_a, belief_b, dash, status, scoreboard, events, trace,
+                         rewind, rewind_idx, branch_pick, branch_pick, st, branches, active_branch, beliefs, rewind_idx],
+                queue=True)
+
+    btn_tick.click(do_tick,
+                   inputs=[st, branches, active_branch, beliefs, rewind_idx],
+                   outputs=[pov_img, truth, belief_a, belief_b, dash, status, scoreboard, events, trace,
+                            rewind, rewind_idx, branch_pick, branch_pick, st, branches, active_branch, beliefs, rewind_idx],
+                   queue=True)
+
+    btn_run.click(do_run,
+                  inputs=[st, branches, active_branch, beliefs, rewind_idx, run_steps],
+                  outputs=[pov_img, truth, belief_a, belief_b, dash, status, scoreboard, events, trace,
+                           rewind, rewind_idx, branch_pick, branch_pick, st, branches, active_branch, beliefs, rewind_idx],
+                  queue=True)
+
+    btn_toggle_control.click(toggle_control,
+                             inputs=[st, branches, active_branch, beliefs, rewind_idx],
+                             outputs=[pov_img, truth, belief_a, belief_b, dash, status, scoreboard, events, trace,
+                                      rewind, rewind_idx, branch_pick, branch_pick, st, branches, active_branch, beliefs, rewind_idx],
+                             queue=True)
+
+    btn_toggle_pov.click(toggle_pov,
+                         inputs=[st, branches, active_branch, beliefs, rewind_idx],
+                         outputs=[pov_img, truth, belief_a, belief_b, dash, status, scoreboard, events, trace,
+                                  rewind, rewind_idx, branch_pick, branch_pick, st, branches, active_branch, beliefs, rewind_idx],
+                         queue=True)
+
+    overlay.change(set_overlay,
+                   inputs=[st, branches, active_branch, beliefs, rewind_idx, overlay],
+                   outputs=[pov_img, truth, belief_a, belief_b, dash, status, scoreboard, events, trace,
+                            rewind, rewind_idx, branch_pick, branch_pick, st, branches, active_branch, beliefs, rewind_idx],
+                   queue=True)
+
+    env_pick.change(change_env,
+                    inputs=[st, branches, active_branch, beliefs, rewind_idx, env_pick],
+                    outputs=[pov_img, truth, belief_a, belief_b, dash, status, scoreboard, events, trace,
+                             rewind, rewind_idx, branch_pick, branch_pick, st, branches, active_branch, beliefs, rewind_idx],
+                    queue=True)
+
+    truth.select(click_truth,
+                 inputs=[tile_pick, st, branches, active_branch, beliefs, rewind_idx],
+                 outputs=[pov_img, truth, belief_a, belief_b, dash, status, scoreboard, events, trace,
+                          rewind, rewind_idx, branch_pick, branch_pick, st, branches, active_branch, beliefs, rewind_idx],
+                 queue=True)
+
+    btn_jump.click(jump,
+                   inputs=[st, branches, active_branch, beliefs, rewind_idx, rewind],
+                   outputs=[pov_img, truth, belief_a, belief_b, dash, status, scoreboard, events, trace,
+                            rewind, rewind_idx, branch_pick, branch_pick, st, branches, active_branch, beliefs, rewind_idx],
+                   queue=True)
+
+    btn_fork.click(fork_branch,
+                   inputs=[st, branches, active_branch, beliefs, rewind_idx, new_branch_name],
+                   outputs=[pov_img, truth, belief_a, belief_b, dash, status, scoreboard, events, trace,
+                            rewind, rewind_idx, branch_pick, branch_pick, st, branches, active_branch, beliefs, rewind_idx],
+                   queue=True)
+
+    btn_set_branch.click(set_active_branch,
+                         inputs=[st, branches, active_branch, beliefs, rewind_idx, branch_pick],
+                         outputs=[pov_img, truth, belief_a, belief_b, dash, status, scoreboard, events, trace,
+                                  rewind, rewind_idx, branch_pick, branch_pick, st, branches, active_branch, beliefs, rewind_idx],
+                         queue=True)
+
+    btn_reset.click(reset_ep,
+                    inputs=[st, branches, active_branch, beliefs, rewind_idx],
+                    outputs=[pov_img, truth, belief_a, belief_b, dash, status, scoreboard, events, trace,
+                             rewind, rewind_idx, branch_pick, branch_pick, st, branches, active_branch, beliefs, rewind_idx],
+                    queue=True)
+
+    btn_reset_all.click(reset_all,
+                        inputs=[st, branches, active_branch, beliefs, rewind_idx, env_pick],
+                        outputs=[pov_img, truth, belief_a, belief_b, dash, status, scoreboard, events, trace,
+                                 rewind, rewind_idx, branch_pick, branch_pick, st, branches, active_branch, beliefs, rewind_idx],
+                        queue=True)
+
+    btn_train.click(do_train,
+                    inputs=[st, branches, active_branch, beliefs, rewind_idx,
+                            use_q, alpha, gamma, eps, eps_decay, eps_min,
+                            episodes, max_steps],
+                    outputs=[pov_img, truth, belief_a, belief_b, dash, status, scoreboard, events, trace,
+                             rewind, rewind_idx, branch_pick, branch_pick, st, branches, active_branch, beliefs, rewind_idx],
+                    queue=True)
+
+    btn_export.click(export_fn, inputs=[st, branches, active_branch, rewind_idx], outputs=[export_box], queue=True)
+
+    btn_import.click(import_fn,
+                     inputs=[import_box],
+                     outputs=[pov_img, truth, belief_a, belief_b, dash, status, scoreboard, events, trace,
+                              rewind, rewind_idx, branch_pick, branch_pick, st, branches, active_branch, beliefs, rewind_idx],
+                     queue=True)
+
+    # initial render: refresh the 13 view slots only; the gr.State values
+    # keep their constructor defaults
+    demo.load(refresh,
+              inputs=[st, branches, active_branch, beliefs, rewind_idx],
+              outputs=[pov_img, truth, belief_a, belief_b, dash, status, scoreboard, events, trace,
+                       rewind, rewind_idx, branch_pick, branch_pick],
+              queue=True)
+
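+# enable the request queue so long-running handlers (Run, Train) are not
+# cut short by HTTP timeouts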
+demo.queue().launch()