ZENLLC committed on
Commit
1f24d62
·
verified ·
1 Parent(s): 0ef25d9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +160 -272
app.py CHANGED
@@ -6,23 +6,28 @@ from typing import Dict, List, Tuple, Optional, Any
6
 
7
  import numpy as np
8
  from PIL import Image, ImageDraw
 
9
  import matplotlib.pyplot as plt
 
 
10
  import gradio as gr
11
 
12
  # ============================================================
13
  # ZEN AgentLab — Agent POV + Multi-Agent Mini-Sim Arena
14
- # State-of-the-art evolution of your "ChronoSandbox++" reference.
 
 
15
  #
16
  # Features:
17
  # - Deterministic gridworld + first-person raycast POV
18
  # - Multiple environments (Chase / CoopVault / MiniCiv)
19
- # - Click-to-edit tiles + inventory pickups
20
  # - Full step trace: obs -> action -> reward -> (optional) Q-update
21
- # - Branching timelines (fork from any rewind point)
22
  # - Batch training (tabular Q-learning) + metrics dashboard
23
- # - Export/import full runs + SHA256 proof hash ("Proof-of-Run")
24
  #
25
- # Hugging Face Spaces compatible: no timers, no fn_kwargs
26
  # ============================================================
27
 
28
  # -----------------------------
@@ -74,16 +79,13 @@ TILE_NAMES = {
74
  BASE: "Base",
75
  }
76
 
77
- # Colors
78
  AGENT_COLORS = {
79
  "Predator": (255, 120, 90),
80
  "Prey": (120, 255, 160),
81
  "Scout": (120, 190, 255),
82
-
83
  "Alpha": (255, 205, 120),
84
  "Bravo": (160, 210, 255),
85
  "Guardian": (255, 120, 220),
86
-
87
  "BuilderA": (140, 255, 200),
88
  "BuilderB": (160, 200, 255),
89
  "Raider": (255, 160, 120),
@@ -96,17 +98,16 @@ WALL_BASE = np.array([210, 210, 225], dtype=np.uint8)
96
  WALL_SIDE = np.array([150, 150, 170], dtype=np.uint8)
97
  DOOR_COL = np.array([140, 210, 255], dtype=np.uint8)
98
 
99
- # Keep actions small for tabular stability
100
- ACTIONS = ["L", "F", "R", "I"] # I = interact (pickups/door/key/switch/base)
101
 
102
  # -----------------------------
103
- # Deterministic RNG streams
104
  # -----------------------------
105
def rng_for(seed: int, step: int, stream: int = 0) -> np.random.Generator:
    """Return a deterministic NumPy generator derived from (seed, step, stream).

    The three inputs are mixed with fixed prime multipliers and XOR so each
    (seed, step, stream) triple yields its own reproducible random stream.
    """
    mixed = (seed * 1_000_003) ^ (step * 9_999_937) ^ (stream * 97_531)
    mixed &= 0xFFFFFFFFFFFFFFFF  # clamp to 64 bits before seeding
    return np.random.default_rng(mixed)
108
 
109
-
110
  # -----------------------------
111
  # Data structures
112
  # -----------------------------
@@ -135,20 +136,20 @@ class TrainConfig:
135
  epsilon_min: float = 0.02
136
  epsilon_decay: float = 0.995
137
 
138
- # shaping (generic)
139
  step_penalty: float = -0.01
140
  explore_reward: float = 0.015
141
  damage_penalty: float = -0.20
142
  heal_reward: float = 0.10
143
 
144
- # chase env shaping
145
  chase_close_coeff: float = 0.03
146
  chase_catch_reward: float = 3.0
147
  chase_escaped_reward: float = 0.2
148
  chase_caught_penalty: float = -3.0
149
  food_reward: float = 0.6
150
 
151
- # vault env shaping
152
  artifact_pick_reward: float = 1.2
153
  exit_win_reward: float = 3.0
154
  guardian_tag_reward: float = 2.0
@@ -156,7 +157,7 @@ class TrainConfig:
156
  switch_reward: float = 0.8
157
  key_reward: float = 0.4
158
 
159
- # civ env shaping
160
  resource_pick_reward: float = 0.15
161
  deposit_reward: float = 0.4
162
  base_progress_win_reward: float = 3.5
@@ -181,7 +182,6 @@ class EpisodeMetrics:
181
  returns: Dict[str, float] = None
182
  action_counts: Dict[str, Dict[str, int]] = None
183
  tiles_discovered: Dict[str, int] = None
184
- q_states: Dict[str, int] = None
185
 
186
  def __post_init__(self):
187
  if self.returns is None:
@@ -190,8 +190,6 @@ class EpisodeMetrics:
190
  self.action_counts = {}
191
  if self.tiles_discovered is None:
192
  self.tiles_discovered = {}
193
- if self.q_states is None:
194
- self.q_states = {}
195
 
196
  @dataclass
197
  class WorldState:
@@ -206,7 +204,7 @@ class WorldState:
206
  overlay: bool
207
 
208
  done: bool
209
- outcome: str # "A_win" | "B_win" | "draw" | "ongoing"
210
 
211
  # env state
212
  door_opened_global: bool = False
@@ -219,7 +217,7 @@ class WorldState:
219
 
220
  # learning
221
  cfg: TrainConfig = None
222
- q_tables: Dict[str, Dict[str, List[float]]] = None # per-agent Q
223
  gmetrics: GlobalMetrics = None
224
  emetrics: EpisodeMetrics = None
225
 
@@ -301,25 +299,23 @@ def within_fov(observer: Agent, tx: int, ty: int, fov_deg: float = FOV_DEG) -> b
301
  return abs(diff) <= (fov_deg / 2)
302
 
303
def visible(state: WorldState, observer: Agent, target: Agent) -> bool:
    """Return True when *target* is inside *observer*'s FOV cone with clear line of sight."""
    in_cone = within_fov(observer, target.x, target.y, FOV_DEG)
    if not in_cone:
        return False
    return bresenham_los(state.grid, observer.x, observer.y, target.x, target.y)
309
 
310
def hash_sha256(txt: str) -> str:
    """Return the hex SHA-256 digest of *txt* (UTF-8 encoded)."""
    digest = hashlib.sha256()
    digest.update(txt.encode("utf-8"))
    return digest.hexdigest()
312
 
313
  # -----------------------------
314
- # Belief maps / fog-of-war
315
  # -----------------------------
316
def init_beliefs(agent_names: List[str]) -> Dict[str, np.ndarray]:
    """Create one belief grid per agent, filled with -1 ("unknown tile")."""
    return {
        name: np.full((GRID_H, GRID_W), -1, dtype=np.int16)
        for name in agent_names
    }
 
321
 
322
- def update_belief_for_agent(state: WorldState, belief: np.ndarray, agent: Agent) -> None:
323
  belief[agent.y, agent.x] = state.grid[agent.y][agent.x]
324
  base = math.radians(ORI_DEG[agent.ori])
325
  half = math.radians(FOV_DEG / 2)
@@ -344,6 +340,9 @@ def update_belief_for_agent(state: WorldState, belief: np.ndarray, agent: Agent)
344
  if tile == DOOR and not state.door_opened_global:
345
  break
346
 
 
 
 
347
  # -----------------------------
348
  # Rendering
349
  # -----------------------------
@@ -410,7 +409,7 @@ def raycast_view(state: WorldState, observer: Agent) -> np.ndarray:
410
 
411
  # billboards for visible agents
412
  for nm, other in state.agents.items():
413
- if nm == observer.name:
414
  continue
415
  if visible(state, observer, other):
416
  dx = other.x - observer.x
@@ -508,7 +507,7 @@ def render_topdown(grid: np.ndarray, agents: Dict[str, Agent], title: str, show_
508
  return im
509
 
510
  # -----------------------------
511
- # Environments (3 modes)
512
  # -----------------------------
513
  def grid_with_border() -> List[List[int]]:
514
  g = [[EMPTY for _ in range(GRID_W)] for _ in range(GRID_H)]
@@ -522,12 +521,10 @@ def grid_with_border() -> List[List[int]]:
522
 
523
  def env_chase(seed: int) -> Tuple[List[List[int]], Dict[str, Agent]]:
524
  g = grid_with_border()
525
- # mid-wall + door
526
  for x in range(4, 17):
527
  g[7][x] = WALL
528
  g[7][10] = DOOR
529
 
530
- # objects
531
  g[3][4] = FOOD
532
  g[11][15] = FOOD
533
  g[4][14] = NOISE
@@ -544,7 +541,6 @@ def env_chase(seed: int) -> Tuple[List[List[int]], Dict[str, Agent]]:
544
 
545
  def env_vault(seed: int) -> Tuple[List[List[int]], Dict[str, Agent]]:
546
  g = grid_with_border()
547
- # internal maze
548
  for x in range(3, 18):
549
  g[5][x] = WALL
550
  for x in range(3, 18):
@@ -552,7 +548,6 @@ def env_vault(seed: int) -> Tuple[List[List[int]], Dict[str, Agent]]:
552
  g[5][10] = DOOR
553
  g[9][12] = DOOR
554
 
555
- # special tiles
556
  g[2][2] = KEY
557
  g[12][18] = EXIT
558
  g[12][2] = ARTIFACT
@@ -572,13 +567,10 @@ def env_vault(seed: int) -> Tuple[List[List[int]], Dict[str, Agent]]:
572
 
573
  def env_civ(seed: int) -> Tuple[List[List[int]], Dict[str, Agent]]:
574
  g = grid_with_border()
575
-
576
- # walls forming zones
577
  for y in range(3, 12):
578
  g[y][9] = WALL
579
  g[7][9] = DOOR
580
 
581
- # resources
582
  g[2][3] = WOOD
583
  g[3][3] = WOOD
584
  g[4][3] = WOOD
@@ -588,7 +580,6 @@ def env_civ(seed: int) -> Tuple[List[List[int]], Dict[str, Agent]]:
588
  g[6][4] = FOOD
589
  g[8][15] = FOOD
590
 
591
- # base + hazards + switch/key
592
  g[13][10] = BASE
593
  g[4][15] = HAZARD
594
  g[10][4] = HAZARD
@@ -596,7 +587,6 @@ def env_civ(seed: int) -> Tuple[List[List[int]], Dict[str, Agent]]:
596
  g[13][2] = TELE
597
  g[2][2] = KEY
598
  g[12][6] = SWITCH
599
- g[7][9] = DOOR
600
 
601
  agents = {
602
  "BuilderA": Agent("BuilderA", 3, 12, 0, hp=10, energy=100, team="A", brain="q"),
@@ -605,14 +595,10 @@ def env_civ(seed: int) -> Tuple[List[List[int]], Dict[str, Agent]]:
605
  }
606
  return g, agents
607
 
608
- ENV_BUILDERS = {
609
- "chase": env_chase,
610
- "vault": env_vault,
611
- "civ": env_civ,
612
- }
613
 
614
  # -----------------------------
615
- # Observation encoding (compact)
616
  # -----------------------------
617
  def local_tile_ahead(state: WorldState, a: Agent) -> int:
618
  dx, dy = DIRS[a.ori]
@@ -623,7 +609,7 @@ def local_tile_ahead(state: WorldState, a: Agent) -> int:
623
 
624
  def nearest_enemy_vec(state: WorldState, a: Agent) -> Tuple[int, int, int]:
625
  best = None
626
- for nm, other in state.agents.items():
627
  if other.hp <= 0:
628
  continue
629
  if other.team == a.team:
@@ -638,15 +624,12 @@ def nearest_enemy_vec(state: WorldState, a: Agent) -> Tuple[int, int, int]:
638
 
639
def obs_key(state: WorldState, who: str) -> str:
    """Build the compact, hashable observation string used as a tabular-Q state key.

    Encodes environment, agent identity/pose, coarse nearest-enemy vector, the
    tile directly ahead, HP, a bucketed inventory summary, the global door
    flag, and base progress.
    """
    agent = state.agents[who]
    d, dx, dy = nearest_enemy_vec(state, agent)
    ahead = local_tile_ahead(state, agent)
    inv = agent.inventory
    keys = inv.get("key", 0)
    art = inv.get("artifact", 0)
    wood = inv.get("wood", 0)
    ore = inv.get("ore", 0)
    # bucket counts so the state space stays small for tabular learning
    inv_bucket = f"k{min(keys,2)}a{min(art,1)}w{min(wood,3)}o{min(ore,3)}"
    door = 1 if state.door_opened_global else 0
    return (
        f"{state.env_key}|{who}|{agent.x},{agent.y},{agent.ori}"
        f"|e{d}:{dx},{dy}|t{ahead}|hp{agent.hp}|{inv_bucket}|D{door}"
        f"|bp{state.base_progress}"
    )
@@ -672,7 +655,7 @@ def q_update(q: Dict[str, List[float]], key: str, a_idx: int, reward: float, nex
672
  return old, target, new
673
 
674
  # -----------------------------
675
- # Heuristic baselines
676
  # -----------------------------
677
  def face_towards(a: Agent, tx: int, ty: int) -> str:
678
  dx = tx - a.x
@@ -690,30 +673,27 @@ def heuristic_action(state: WorldState, who: str) -> str:
690
  a = state.agents[who]
691
  r = rng_for(state.seed, state.step, stream=900 + hash(who) % 1000)
692
 
693
- # simple: if enemy visible, chase if team B (raider/guardian/predator) else flee-ish
694
- # also, prioritize interact if standing on something valuable
695
- tile_here = state.grid[a.y][a.x]
696
- if tile_here in (FOOD, KEY, ARTIFACT, WOOD, ORE, MEDKIT, SWITCH, BASE, EXIT):
697
  return "I"
698
 
699
- # find nearest enemy
700
- best_nm = None
701
- best_d = 999
702
  best = None
703
- for nm, other in state.agents.items():
 
704
  if other.hp <= 0 or other.team == a.team:
705
  continue
706
  d = manhattan_xy(a.x, a.y, other.x, other.y)
707
  if d < best_d:
708
  best_d = d
709
- best_nm = nm
710
  best = other
711
 
712
  if best is not None and best_d <= 6 and visible(state, a, best):
713
- # attackers chase
714
  if a.team == "B":
715
  return face_towards(a, best.x, best.y)
716
- # defenders flee: turn away from enemy vector
717
  dx = best.x - a.x
718
  dy = best.y - a.y
719
  ang = (math.degrees(math.atan2(dy, dx)) % 360)
@@ -726,7 +706,6 @@ def heuristic_action(state: WorldState, who: str) -> str:
726
  return "R"
727
  return "F"
728
 
729
- # mild exploration bias: try forward more
730
  return r.choice(["F", "F", "L", "R", "I"])
731
 
732
  def random_action(state: WorldState, who: str) -> str:
@@ -766,7 +745,6 @@ def move_forward(state: WorldState, a: Agent) -> str:
766
  def try_interact(state: WorldState, a: Agent) -> str:
767
  t = state.grid[a.y][a.x]
768
 
769
- # door open global via key or switch
770
  if t == SWITCH:
771
  state.door_opened_global = True
772
  state.grid[a.y][a.x] = EMPTY
@@ -804,7 +782,6 @@ def try_interact(state: WorldState, a: Agent) -> str:
804
  return "used: medkit"
805
 
806
  if t == BASE:
807
- # deposit resources into base_progress
808
  w = a.inventory.get("wood", 0)
809
  o = a.inventory.get("ore", 0)
810
  dep = min(w, 2) + min(o, 2)
@@ -837,22 +814,19 @@ def apply_action(state: WorldState, who: str, action: str) -> str:
837
  return "noop"
838
 
839
  # -----------------------------
840
- # Combat / hazards / win conditions
841
  # -----------------------------
842
  def resolve_hazards(state: WorldState, a: Agent) -> Tuple[bool, str]:
843
- # returns (took_damage, msg)
844
  if a.hp <= 0:
845
  return (False, "")
846
- t = state.grid[a.y][a.x]
847
- if t == HAZARD:
848
  a.hp -= 1
849
  return (True, "hazard:-hp")
850
  return (False, "")
851
 
852
  def resolve_tags(state: WorldState) -> List[str]:
853
- # if two opposing agents occupy same tile: team B "tags" team A
854
  msgs = []
855
- occupied = {}
856
  for nm, a in state.agents.items():
857
  if a.hp <= 0:
858
  continue
@@ -863,14 +837,12 @@ def resolve_tags(state: WorldState) -> List[str]:
863
  continue
864
  teams = set(state.agents[n].team for n in names)
865
  if len(teams) >= 2:
866
- # tag: both sides take 1 hp damage, but log who collided
867
  for n in names:
868
  state.agents[n].hp -= 1
869
  msgs.append(f"t={state.step}: collision/tag at ({x},{y}) {names} (-hp all)")
870
  return msgs
871
 
872
  def check_done(state: WorldState) -> None:
873
- # Determine environment-specific terminal conditions
874
  if state.env_key == "chase":
875
  pred = state.agents["Predator"]
876
  prey = state.agents["Prey"]
@@ -878,12 +850,11 @@ def check_done(state: WorldState) -> None:
878
  state.done = True
879
  state.outcome = "draw"
880
  return
881
- if pred.x == prey.x and pred.y == prey.y and pred.hp > 0 and prey.hp > 0:
882
  state.done = True
883
- state.outcome = "A_win" # Predator team A
884
  state.event_log.append(f"t={state.step}: CAUGHT (Predator wins).")
885
  return
886
- # prey "escape" win if survives long enough with food? Use energy threshold
887
  if state.step >= 300 and prey.hp > 0:
888
  state.done = True
889
  state.outcome = "B_win"
@@ -891,7 +862,6 @@ def check_done(state: WorldState) -> None:
891
  return
892
 
893
  if state.env_key == "vault":
894
- # Team A wins if any A has artifact and reaches exit
895
  for nm in ["Alpha", "Bravo"]:
896
  a = state.agents[nm]
897
  if a.hp > 0 and a.inventory.get("artifact", 0) > 0 and state.grid[a.y][a.x] == EXIT:
@@ -899,7 +869,6 @@ def check_done(state: WorldState) -> None:
899
  state.outcome = "A_win"
900
  state.event_log.append(f"t={state.step}: VAULT CLEARED (Team A wins).")
901
  return
902
- # Team B wins if all A agents eliminated
903
  alive_A = any(state.agents[n].hp > 0 for n in ["Alpha", "Bravo"])
904
  if not alive_A:
905
  state.done = True
@@ -908,20 +877,17 @@ def check_done(state: WorldState) -> None:
908
  return
909
 
910
  if state.env_key == "civ":
911
- # Team A wins if base_progress reaches target
912
  if state.base_progress >= state.base_target:
913
  state.done = True
914
  state.outcome = "A_win"
915
  state.event_log.append(f"t={state.step}: BASE COMPLETE (Builders win).")
916
  return
917
- # Team B wins if both builders eliminated
918
  alive_A = any(state.agents[n].hp > 0 for n in ["BuilderA", "BuilderB"])
919
  if not alive_A:
920
  state.done = True
921
  state.outcome = "B_win"
922
  state.event_log.append(f"t={state.step}: BUILDERS ELIMINATED (Raider wins).")
923
  return
924
- # draw if too long
925
  if state.step >= 350:
926
  state.done = True
927
  state.outcome = "draw"
@@ -931,88 +897,72 @@ def check_done(state: WorldState) -> None:
931
  # -----------------------------
932
  # Rewards
933
  # -----------------------------
934
- def reward_for(state_prev: WorldState, state_now: WorldState, who: str, outcome_msg: str,
935
- took_damage: bool, interacted: bool) -> float:
936
- cfg = state_now.cfg
937
- a0 = state_prev.agents[who]
938
- a1 = state_now.agents[who]
939
-
940
  r = cfg.step_penalty
941
-
942
- # exploration reward: if agent discovered new tiles in belief (we approximate via emetrics tiles_discovered)
943
- # we update this outside; here just tiny reward if moved
944
  if outcome_msg.startswith("moved"):
945
  r += cfg.explore_reward
946
-
947
  if took_damage:
948
  r += cfg.damage_penalty
949
-
950
- # heal reward if used medkit
951
  if outcome_msg.startswith("used: medkit"):
952
  r += cfg.heal_reward
953
 
954
- # environment shaping
955
- if state_now.env_key == "chase":
956
- pred = state_now.agents["Predator"]
957
- prey = state_now.agents["Prey"]
958
  if who == "Predator":
959
- d0 = manhattan_xy(state_prev.agents["Predator"].x, state_prev.agents["Predator"].y,
960
- state_prev.agents["Prey"].x, state_prev.agents["Prey"].y)
961
  d1 = manhattan_xy(pred.x, pred.y, prey.x, prey.y)
962
  r += cfg.chase_close_coeff * float(d0 - d1)
963
- if state_now.done and state_now.outcome == "A_win":
964
  r += cfg.chase_catch_reward
965
  if who == "Prey":
966
  if outcome_msg.startswith("ate: food"):
967
  r += cfg.food_reward
968
- if state_now.done and state_now.outcome == "B_win":
969
  r += cfg.chase_escaped_reward
970
- if state_now.done and state_now.outcome == "A_win":
971
  r += cfg.chase_caught_penalty
972
 
973
- if state_now.env_key == "vault":
974
  if outcome_msg.startswith("picked: artifact"):
975
  r += cfg.artifact_pick_reward
976
  if outcome_msg.startswith("picked: key"):
977
  r += cfg.key_reward
978
  if outcome_msg.startswith("switch:"):
979
  r += cfg.switch_reward
980
- if state_now.done:
981
- if state_now.outcome == "A_win" and state_now.agents[who].team == "A":
982
  r += cfg.exit_win_reward
983
- if state_now.outcome == "B_win" and state_now.agents[who].team == "B":
984
  r += cfg.guardian_tag_reward
985
- if state_now.outcome == "B_win" and state_now.agents[who].team == "A":
986
  r += cfg.tagged_penalty
987
 
988
- if state_now.env_key == "civ":
989
  if outcome_msg.startswith("picked: wood") or outcome_msg.startswith("picked: ore"):
990
  r += cfg.resource_pick_reward
991
  if outcome_msg.startswith("deposited:"):
992
  r += cfg.deposit_reward
993
- if state_now.done:
994
- if state_now.outcome == "A_win" and state_now.agents[who].team == "A":
995
  r += cfg.base_progress_win_reward
996
- if state_now.outcome == "B_win" and state_now.agents[who].team == "B":
997
  r += cfg.raider_elim_reward
998
- if state_now.outcome == "B_win" and state_now.agents[who].team == "A":
999
  r += cfg.builder_elim_penalty
1000
 
1001
  return float(r)
1002
 
1003
  # -----------------------------
1004
- # Q / policy selection
1005
  # -----------------------------
1006
  def choose_action(state: WorldState, who: str, stream: int) -> Tuple[str, str, Optional[Tuple[str, int]]]:
1007
- """
1008
- Returns (action, reason, q_info)
1009
- q_info: (obs_key, action_index) if Q-based else None
1010
- """
1011
  a = state.agents[who]
1012
  cfg = state.cfg
1013
  r = rng_for(state.seed, state.step, stream=stream)
1014
 
1015
- # manual control handled outside
1016
  if a.brain == "random":
1017
  act = random_action(state, who)
1018
  return act, "random", None
@@ -1020,7 +970,6 @@ def choose_action(state: WorldState, who: str, stream: int) -> Tuple[str, str, O
1020
  act = heuristic_action(state, who)
1021
  return act, "heuristic", None
1022
 
1023
- # Q learning
1024
  if cfg.use_q:
1025
  key = obs_key(state, who)
1026
  qtab = state.q_tables.setdefault(who, {})
@@ -1032,11 +981,10 @@ def choose_action(state: WorldState, who: str, stream: int) -> Tuple[str, str, O
1032
  return act, "heuristic(fallback)", None
1033
 
1034
  # -----------------------------
1035
- # Episode initialization / reset
1036
  # -----------------------------
1037
  def init_state(seed: int, env_key: str) -> WorldState:
1038
  g, agents = ENV_BUILDERS[env_key](seed)
1039
-
1040
  st = WorldState(
1041
  seed=seed,
1042
  step=0,
@@ -1059,7 +1007,6 @@ def reset_episode_keep_learning(state: WorldState, seed: Optional[int] = None) -
1059
  if seed is None:
1060
  seed = state.seed
1061
  fresh = init_state(int(seed), state.env_key)
1062
- # carry learning + global metrics
1063
  fresh.cfg = state.cfg
1064
  fresh.q_tables = state.q_tables
1065
  fresh.gmetrics = state.gmetrics
@@ -1110,52 +1057,52 @@ def restore_into(state: WorldState, snap: Snapshot) -> WorldState:
1110
  return state
1111
 
1112
  # -----------------------------
1113
- # Metrics / dashboard
1114
  # -----------------------------
1115
def update_action_counts(state: WorldState, who: str, act: str):
    """Increment the per-agent tally for action *act* in the episode metrics."""
    per_agent = state.emetrics.action_counts.setdefault(who, {})
    per_agent[act] = per_agent.get(act, 0) + 1
1118
-
1119
def action_entropy(counts: Dict[str, int]) -> float:
    """Shannon entropy (bits) of an action-count distribution; 0.0 for empty counts."""
    total = sum(counts.values())
    if total <= 0:
        return 0.0
    probs = np.array([c / total for c in counts.values()], dtype=np.float64)
    # clip away zeros so log2 stays finite
    probs = np.clip(probs, 1e-12, 1.0)
    return float(-np.sum(probs * np.log2(probs)))
1126
-
1127
  def metrics_dashboard_image(state: WorldState) -> Image.Image:
1128
  gm = state.gmetrics
 
1129
  fig = plt.figure(figsize=(7.0, 2.2), dpi=120)
1130
  ax = fig.add_subplot(111)
1131
- ax.plot([0, gm.episodes], [gm.rolling_winrate_A, gm.rolling_winrate_A])
 
 
1132
  ax.set_title("Global Metrics Snapshot")
1133
- ax.set_xlabel("Episodes (scalar)")
1134
  ax.set_ylabel("Rolling winrate Team A")
 
1135
  ax.grid(True)
1136
 
1137
- # annotate
1138
  txt = (
1139
  f"env={state.env_key} | eps={gm.epsilon:.3f} | episodes={gm.episodes}\n"
1140
- f"A_wins={gm.wins_teamA} B_wins={gm.wins_teamB} draws={gm.draws} | "
1141
- f"avg_steps~{gm.avg_steps:.1f}\n"
1142
  f"last_outcome={gm.last_outcome} last_steps={gm.last_steps}"
1143
  )
1144
  ax.text(0.01, 0.05, txt, transform=ax.transAxes, fontsize=8, va="bottom")
1145
 
1146
  fig.tight_layout()
1147
- fig.canvas.draw()
1148
- w, h = fig.canvas.get_width_height()
1149
- img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(h, w, 3)
 
 
1150
  plt.close(fig)
1151
- return Image.fromarray(img)
 
 
 
 
 
 
 
 
1152
 
1153
  def agent_scoreboard(state: WorldState) -> str:
1154
  rows = []
1155
  header = ["agent", "team", "hp", "return", "steps", "entropy", "tiles_disc", "q_states", "inventory"]
1156
  rows.append(header)
1157
-
1158
  steps = state.emetrics.steps
 
1159
  for nm, a in state.agents.items():
1160
  ret = state.emetrics.returns.get(nm, 0.0)
1161
  counts = state.emetrics.action_counts.get(nm, {})
@@ -1165,7 +1112,6 @@ def agent_scoreboard(state: WorldState) -> str:
1165
  inv = json.dumps(a.inventory, sort_keys=True)
1166
  rows.append([nm, a.team, a.hp, f"{ret:.2f}", steps, f"{ent:.2f}", td, qs, inv])
1167
 
1168
- # pretty format as fixed-width table
1169
  col_w = [max(len(str(r[i])) for r in rows) for i in range(len(header))]
1170
  lines = []
1171
  for ridx, r in enumerate(rows):
@@ -1176,11 +1122,10 @@ def agent_scoreboard(state: WorldState) -> str:
1176
  return "\n".join(lines)
1177
 
1178
  # -----------------------------
1179
- # Tick (core simulation step)
1180
  # -----------------------------
1181
  def clone_shallow(state: WorldState) -> WorldState:
1182
- # minimal clone to compute rewards
1183
- st = WorldState(
1184
  seed=state.seed,
1185
  step=state.step,
1186
  env_key=state.env_key,
@@ -1201,15 +1146,16 @@ def clone_shallow(state: WorldState) -> WorldState:
1201
  gmetrics=state.gmetrics,
1202
  emetrics=state.emetrics,
1203
  )
1204
- return st
 
 
 
1205
 
1206
  def tick(state: WorldState, beliefs: Dict[str, np.ndarray], manual_action: Optional[str] = None) -> None:
1207
  if state.done:
1208
  return
1209
 
1210
  prev = clone_shallow(state)
1211
-
1212
- # pick actions
1213
  chosen: Dict[str, str] = {}
1214
  reasons: Dict[str, str] = {}
1215
  qinfo: Dict[str, Optional[Tuple[str, int]]] = {}
@@ -1219,8 +1165,9 @@ def tick(state: WorldState, beliefs: Dict[str, np.ndarray], manual_action: Optio
1219
  reasons[state.controlled] = "manual"
1220
  qinfo[state.controlled] = None
1221
 
1222
- # others choose
1223
- for who in list(state.agents.keys()):
 
1224
  if who in chosen:
1225
  continue
1226
  act, reason, qi = choose_action(state, who, stream=200 + (hash(who) % 1000))
@@ -1228,50 +1175,38 @@ def tick(state: WorldState, beliefs: Dict[str, np.ndarray], manual_action: Optio
1228
  reasons[who] = reason
1229
  qinfo[who] = qi
1230
 
1231
- # apply actions in fixed order (deterministic)
1232
- order = list(state.agents.keys())
1233
  outcomes: Dict[str, str] = {}
1234
  took_damage: Dict[str, bool] = {nm: False for nm in order}
1235
- interacted: Dict[str, bool] = {nm: False for nm in order}
1236
 
1237
  for who in order:
1238
- before_tile = state.grid[state.agents[who].y][state.agents[who].x] if state.agents[who].hp > 0 else EMPTY
1239
  outcomes[who] = apply_action(state, who, chosen[who])
1240
- if chosen[who] == "I" and outcomes[who] != "interact: none":
1241
- interacted[who] = True
1242
 
1243
  dmg, msg = resolve_hazards(state, state.agents[who])
1244
  took_damage[who] = dmg
1245
  if msg:
1246
  state.event_log.append(f"t={state.step}: {who} {msg}")
1247
 
1248
- # track action counts
1249
  update_action_counts(state, who, chosen[who])
1250
 
1251
- # collisions/tags after movement
1252
- tag_msgs = resolve_tags(state)
1253
- for m in tag_msgs:
1254
  state.event_log.append(m)
1255
 
1256
- # update beliefs / tiles discovered metric
1257
  for nm, a in state.agents.items():
1258
  if a.hp <= 0:
1259
  continue
1260
- before_unknown = int(np.sum(beliefs[nm] == -1))
1261
- update_belief_for_agent(state, beliefs[nm], a)
1262
- after_unknown = int(np.sum(beliefs[nm] == -1))
1263
- discovered = max(0, before_unknown - after_unknown)
1264
- state.emetrics.tiles_discovered[nm] = state.emetrics.tiles_discovered.get(nm, 0) + discovered
1265
 
1266
- # check done conditions
1267
  check_done(state)
1268
 
1269
- # rewards + Q updates + returns
1270
  q_lines = []
1271
  for who in order:
1272
  if who not in state.emetrics.returns:
1273
  state.emetrics.returns[who] = 0.0
1274
- r = reward_for(prev, state, who, outcomes[who], took_damage[who], interacted[who])
 
1275
  state.emetrics.returns[who] += r
1276
 
1277
  if qinfo.get(who) is not None:
@@ -1281,7 +1216,6 @@ def tick(state: WorldState, beliefs: Dict[str, np.ndarray], manual_action: Optio
1281
  old, tgt, new = q_update(qtab, key, a_idx, r, next_key, state.cfg.alpha, state.cfg.gamma)
1282
  q_lines.append(f"{who}: old={old:.3f} tgt={tgt:.3f} new={new:.3f} (a={ACTIONS[a_idx]})")
1283
 
1284
- # trace
1285
  trace = f"t={state.step} env={state.env_key} done={state.done} outcome={state.outcome}"
1286
  for who in order:
1287
  a = state.agents[who]
@@ -1296,9 +1230,6 @@ def tick(state: WorldState, beliefs: Dict[str, np.ndarray], manual_action: Optio
1296
  state.step += 1
1297
  state.emetrics.steps = state.step
1298
 
1299
- # -----------------------------
1300
- # Training
1301
- # -----------------------------
1302
  def run_episode(state: WorldState, beliefs: Dict[str, np.ndarray], max_steps: int) -> Tuple[str, int]:
1303
  while state.step < max_steps and not state.done:
1304
  tick(state, beliefs, manual_action=None)
@@ -1321,23 +1252,20 @@ def update_global_metrics_after_episode(state: WorldState, outcome: str, steps:
1321
  gm.rolling_winrate_A = 0.90 * gm.rolling_winrate_A + 0.10 * 0.5
1322
 
1323
  gm.avg_steps = (0.90 * gm.avg_steps + 0.10 * steps) if gm.avg_steps > 0 else float(steps)
1324
-
1325
- # epsilon decay
1326
  gm.epsilon = max(state.cfg.epsilon_min, gm.epsilon * state.cfg.epsilon_decay)
1327
 
1328
def train(state: WorldState, episodes: int, max_steps: int) -> WorldState:
    """Run *episodes* training episodes, then return a freshly reset episode.

    Each episode reseeds deterministically from the base seed and the global
    episode counter, so repeated training runs reproduce exactly.  Learning
    state (Q-tables, global metrics) is carried across episode resets.
    """
    for idx in range(episodes):
        episode_seed = (
            state.seed * 1_000_003 + (state.gmetrics.episodes + idx) * 97_531
        ) & 0xFFFFFFFF
        state = reset_episode_keep_learning(state, seed=int(episode_seed))
        beliefs = init_beliefs(list(state.agents.keys()))
        outcome, steps = run_episode(state, beliefs, max_steps=max_steps)
        update_global_metrics_after_episode(state, outcome, steps)

    state.event_log.append(
        f"Training: +{episodes} eps | eps={state.gmetrics.epsilon:.3f} | "
        f"A={state.gmetrics.wins_teamA} B={state.gmetrics.wins_teamB} D={state.gmetrics.draws}"
    )
    # hand back a clean episode at the original seed, learning intact
    state = reset_episode_keep_learning(state, seed=state.seed)
    return state
1343
 
@@ -1356,7 +1284,7 @@ def export_run(state: WorldState, branches: Dict[str, List[Snapshot]], active_br
1356
  "q_tables": state.q_tables,
1357
  "branches": {b: [asdict(s) for s in snaps] for b, snaps in branches.items()},
1358
  "active_branch": active_branch,
1359
- "rewind_idx": rewind_idx,
1360
  "grid": state.grid,
1361
  "door_opened_global": state.door_opened_global,
1362
  "base_progress": state.base_progress,
@@ -1367,7 +1295,6 @@ def export_run(state: WorldState, branches: Dict[str, List[Snapshot]], active_br
1367
  return txt + "\n\n" + json.dumps({"proof_sha256": proof}, indent=2)
1368
 
1369
  def import_run(txt: str) -> Tuple[WorldState, Dict[str, List[Snapshot]], str, int, Dict[str, np.ndarray]]:
1370
- # allow trailing proof block
1371
  parts = txt.strip().split("\n\n")
1372
  data = json.loads(parts[0])
1373
 
@@ -1392,7 +1319,6 @@ def import_run(txt: str) -> Tuple[WorldState, Dict[str, List[Snapshot]], str, in
1392
  active = data.get("active_branch", "main")
1393
  r_idx = int(data.get("rewind_idx", 0))
1394
 
1395
- # restore last snap of active branch if exists
1396
  if active in branches and branches[active]:
1397
  st = restore_into(st, branches[active][-1])
1398
  st.event_log.append("Imported run (restored last snapshot).")
@@ -1406,21 +1332,19 @@ def import_run(txt: str) -> Tuple[WorldState, Dict[str, List[Snapshot]], str, in
1406
  # UI helpers
1407
  # -----------------------------
1408
  def build_views(state: WorldState, beliefs: Dict[str, np.ndarray]) -> Tuple[np.ndarray, Image.Image, Image.Image, Image.Image, Image.Image, str, str, str, str]:
1409
- # update beliefs
1410
  for nm, a in state.agents.items():
1411
  if a.hp > 0:
1412
  update_belief_for_agent(state, beliefs[nm], a)
1413
 
1414
  pov = raycast_view(state, state.agents[state.pov])
1415
  truth_np = np.array(state.grid, dtype=np.int16)
1416
- truth_img = render_topdown(truth_np, state.agents, f"Truth Map — env={state.env_key} t={state.step} seed={state.seed}", show_agents=True)
1417
 
1418
  ctrl = state.controlled
1419
- # pick "other belief" as someone else
1420
  others = [k for k in state.agents.keys() if k != ctrl]
1421
  other = others[0] if others else ctrl
1422
- b_ctrl = render_topdown(beliefs[ctrl], state.agents, f"{ctrl} Belief", show_agents=True)
1423
- b_other = render_topdown(beliefs[other], state.agents, f"{other} Belief", show_agents=True)
1424
 
1425
  dash = metrics_dashboard_image(state)
1426
 
@@ -1444,7 +1368,6 @@ def grid_click_to_tile(evt: gr.SelectData, selected_tile: int, state: WorldState
1444
  gy = int(y_px // TILE)
1445
  if not in_bounds(gx, gy):
1446
  return state
1447
- # keep border walls fixed
1448
  if gx == 0 or gy == 0 or gx == GRID_W - 1 or gy == GRID_H - 1:
1449
  return state
1450
  state.grid[gy][gx] = selected_tile
@@ -1452,23 +1375,23 @@ def grid_click_to_tile(evt: gr.SelectData, selected_tile: int, state: WorldState
1452
  return state
1453
 
1454
  # -----------------------------
1455
- # Gradio App
1456
  # -----------------------------
1457
  TITLE = "ZEN AgentLab — Agent POV + Multi-Agent Mini-Sim Arena"
1458
 
1459
  with gr.Blocks(title=TITLE) as demo:
1460
  gr.Markdown(
1461
  f"## {TITLE}\n"
1462
- "A multi-environment agent observatory with POV, belief maps, branching timelines, training, and metrics.\n"
1463
- "**Controls:** No timers. Use Tick / Run / Train for deterministic experiments."
1464
  )
1465
 
1466
- # Core state
1467
- st = gr.State(init_state(1337, "chase"))
1468
- branches = gr.State({"main": [snapshot_of(init_state(1337, "chase"), "main")]})
1469
  active_branch = gr.State("main")
1470
  rewind_idx = gr.State(0)
1471
- beliefs = gr.State(init_beliefs(list(init_state(1337, "chase").agents.keys())))
1472
 
1473
  with gr.Row():
1474
  pov_img = gr.Image(label="POV (Pseudo-3D)", type="numpy", width=VIEW_W, height=VIEW_H)
@@ -1486,7 +1409,7 @@ with gr.Blocks(title=TITLE) as demo:
1486
 
1487
  with gr.Row():
1488
  events = gr.Textbox(label="Event Log", lines=10)
1489
- trace = gr.Textbox(label="Step Trace (why it happened)", lines=10)
1490
 
1491
  with gr.Row():
1492
  with gr.Column(scale=2):
@@ -1523,14 +1446,14 @@ with gr.Blocks(title=TITLE) as demo:
1523
  with gr.Column(scale=3):
1524
  gr.Markdown("### Training Controls (Tabular Q-learning)")
1525
  use_q = gr.Checkbox(True, label="Use Q-learning (agents with brain='q')")
1526
- alpha = gr.Slider(0.01, 0.5, value=0.15, step=0.01, label="alpha (learn rate)")
1527
- gamma = gr.Slider(0.5, 0.99, value=0.95, step=0.01, label="gamma (discount)")
1528
- eps = gr.Slider(0.0, 0.5, value=0.10, step=0.01, label="epsilon (exploration)")
1529
  eps_decay = gr.Slider(0.90, 0.999, value=0.995, step=0.001, label="epsilon decay")
1530
  eps_min = gr.Slider(0.0, 0.2, value=0.02, step=0.01, label="epsilon min")
1531
 
1532
  episodes = gr.Number(value=50, label="Train episodes", precision=0)
1533
- max_steps = gr.Number(value=260, label="Max steps per episode", precision=0)
1534
  btn_train = gr.Button("Train")
1535
 
1536
  btn_reset = gr.Button("Reset Episode (keep learning)")
@@ -1554,19 +1477,16 @@ with gr.Blocks(title=TITLE) as demo:
1554
  import_box = gr.Textbox(label="Import JSON", lines=8)
1555
  btn_import = gr.Button("Import")
1556
 
1557
- # ---------- UI glue ----------
1558
  def refresh(state: WorldState, branches_d: Dict[str, List[Snapshot]], active: str, bel: Dict[str, np.ndarray], r: int):
1559
  snaps = branches_d.get(active, [])
1560
  r_max = max(0, len(snaps) - 1)
1561
  r = max(0, min(int(r), r_max))
1562
  pov, tr, ba, bb, dimg, stxt, etxt, ttxt, sb = build_views(state, bel)
1563
-
1564
  branch_choices = sorted(list(branches_d.keys()))
1565
  return (
1566
- pov, tr, ba, bb, dimg,
1567
- stxt, sb, etxt, ttxt,
1568
- gr.update(maximum=r_max, value=r),
1569
- r,
1570
  gr.update(choices=branch_choices, value=active),
1571
  gr.update(choices=branch_choices, value=active),
1572
  )
@@ -1664,9 +1584,7 @@ with gr.Blocks(title=TITLE) as demo:
1664
  branches_d[new_name].append(snapshot_of(state, new_name))
1665
  else:
1666
  idx = max(0, min(int(r), len(snaps) - 1))
1667
- # fork snapshots up to idx (inclusive)
1668
  branches_d[new_name] = [Snapshot(**asdict(s)) for s in snaps[:idx + 1]]
1669
- # restore state at fork point (last fork snap)
1670
  state = restore_into(state, branches_d[new_name][-1])
1671
  active = new_name
1672
  state.event_log.append(f"Forked branch -> {new_name}")
@@ -1680,17 +1598,15 @@ with gr.Blocks(title=TITLE) as demo:
1680
  if br not in branches_d:
1681
  branches_d[br] = [snapshot_of(state, br)]
1682
  active = br
1683
- # restore latest state on that branch
1684
  if branches_d[active]:
1685
  state = restore_into(state, branches_d[active][-1])
1686
  bel = init_beliefs(list(state.agents.keys()))
1687
- out = refresh(state, branches_d, active, bel, len(branches_d[active]) - 1)
1688
  r = len(branches_d[active]) - 1
 
1689
  return out + (state, branches_d, active, bel, r)
1690
 
1691
  def change_env(state, branches_d, active, bel, r, env_key):
1692
  env_key = env_key or "chase"
1693
- # reset episode but keep learning tables (they are per-agent key so safe)
1694
  state.env_key = env_key
1695
  state = reset_episode_keep_learning(state, seed=state.seed)
1696
  bel = init_beliefs(list(state.agents.keys()))
@@ -1735,130 +1651,102 @@ with gr.Blocks(title=TITLE) as demo:
1735
 
1736
  def import_fn(txt):
1737
  state, branches_d, active, r, bel = import_run(txt)
1738
- # ensure at least a snapshot
1739
  branches_d.setdefault(active, [])
1740
  if not branches_d[active]:
1741
  branches_d[active].append(snapshot_of(state, active))
1742
  out = refresh(state, branches_d, active, bel, r)
1743
  return out + (state, branches_d, active, bel, r)
1744
 
1745
- # ----- Wire buttons (no fn_kwargs) -----
 
 
 
 
 
 
1746
  btn_L.click(lambda s,b,a,bel,r: do_manual(s,b,a,bel,r,"L"),
1747
  inputs=[st, branches, active_branch, beliefs, rewind_idx],
1748
- outputs=[pov_img, truth, belief_a, belief_b, dash, status, scoreboard, events, trace,
1749
- rewind, rewind_idx, branch_pick, branch_pick, st, branches, active_branch, beliefs, rewind_idx],
1750
- queue=True)
1751
 
1752
  btn_F.click(lambda s,b,a,bel,r: do_manual(s,b,a,bel,r,"F"),
1753
  inputs=[st, branches, active_branch, beliefs, rewind_idx],
1754
- outputs=[pov_img, truth, belief_a, belief_b, dash, status, scoreboard, events, trace,
1755
- rewind, rewind_idx, branch_pick, branch_pick, st, branches, active_branch, beliefs, rewind_idx],
1756
- queue=True)
1757
 
1758
  btn_R.click(lambda s,b,a,bel,r: do_manual(s,b,a,bel,r,"R"),
1759
  inputs=[st, branches, active_branch, beliefs, rewind_idx],
1760
- outputs=[pov_img, truth, belief_a, belief_b, dash, status, scoreboard, events, trace,
1761
- rewind, rewind_idx, branch_pick, branch_pick, st, branches, active_branch, beliefs, rewind_idx],
1762
- queue=True)
1763
 
1764
  btn_I.click(lambda s,b,a,bel,r: do_manual(s,b,a,bel,r,"I"),
1765
  inputs=[st, branches, active_branch, beliefs, rewind_idx],
1766
- outputs=[pov_img, truth, belief_a, belief_b, dash, status, scoreboard, events, trace,
1767
- rewind, rewind_idx, branch_pick, branch_pick, st, branches, active_branch, beliefs, rewind_idx],
1768
- queue=True)
1769
 
1770
  btn_tick.click(do_tick,
1771
  inputs=[st, branches, active_branch, beliefs, rewind_idx],
1772
- outputs=[pov_img, truth, belief_a, belief_b, dash, status, scoreboard, events, trace,
1773
- rewind, rewind_idx, branch_pick, branch_pick, st, branches, active_branch, beliefs, rewind_idx],
1774
- queue=True)
1775
 
1776
  btn_run.click(do_run,
1777
  inputs=[st, branches, active_branch, beliefs, rewind_idx, run_steps],
1778
- outputs=[pov_img, truth, belief_a, belief_b, dash, status, scoreboard, events, trace,
1779
- rewind, rewind_idx, branch_pick, branch_pick, st, branches, active_branch, beliefs, rewind_idx],
1780
- queue=True)
1781
 
1782
  btn_toggle_control.click(toggle_control,
1783
  inputs=[st, branches, active_branch, beliefs, rewind_idx],
1784
- outputs=[pov_img, truth, belief_a, belief_b, dash, status, scoreboard, events, trace,
1785
- rewind, rewind_idx, branch_pick, branch_pick, st, branches, active_branch, beliefs, rewind_idx],
1786
- queue=True)
1787
 
1788
  btn_toggle_pov.click(toggle_pov,
1789
  inputs=[st, branches, active_branch, beliefs, rewind_idx],
1790
- outputs=[pov_img, truth, belief_a, belief_b, dash, status, scoreboard, events, trace,
1791
- rewind, rewind_idx, branch_pick, branch_pick, st, branches, active_branch, beliefs, rewind_idx],
1792
- queue=True)
1793
 
1794
  overlay.change(set_overlay,
1795
  inputs=[st, branches, active_branch, beliefs, rewind_idx, overlay],
1796
- outputs=[pov_img, truth, belief_a, belief_b, dash, status, scoreboard, events, trace,
1797
- rewind, rewind_idx, branch_pick, branch_pick, st, branches, active_branch, beliefs, rewind_idx],
1798
- queue=True)
1799
 
1800
  env_pick.change(change_env,
1801
  inputs=[st, branches, active_branch, beliefs, rewind_idx, env_pick],
1802
- outputs=[pov_img, truth, belief_a, belief_b, dash, status, scoreboard, events, trace,
1803
- rewind, rewind_idx, branch_pick, branch_pick, st, branches, active_branch, beliefs, rewind_idx],
1804
- queue=True)
1805
 
1806
  truth.select(click_truth,
1807
  inputs=[tile_pick, st, branches, active_branch, beliefs, rewind_idx],
1808
- outputs=[pov_img, truth, belief_a, belief_b, dash, status, scoreboard, events, trace,
1809
- rewind, rewind_idx, branch_pick, branch_pick, st, branches, active_branch, beliefs, rewind_idx],
1810
- queue=True)
1811
 
1812
  btn_jump.click(jump,
1813
  inputs=[st, branches, active_branch, beliefs, rewind_idx, rewind],
1814
- outputs=[pov_img, truth, belief_a, belief_b, dash, status, scoreboard, events, trace,
1815
- rewind, rewind_idx, branch_pick, branch_pick, st, branches, active_branch, beliefs, rewind_idx],
1816
- queue=True)
1817
 
1818
  btn_fork.click(fork_branch,
1819
  inputs=[st, branches, active_branch, beliefs, rewind_idx, new_branch_name],
1820
- outputs=[pov_img, truth, belief_a, belief_b, dash, status, scoreboard, events, trace,
1821
- rewind, rewind_idx, branch_pick, branch_pick, st, branches, active_branch, beliefs, rewind_idx],
1822
- queue=True)
1823
 
1824
  btn_set_branch.click(set_active_branch,
1825
  inputs=[st, branches, active_branch, beliefs, rewind_idx, branch_pick],
1826
- outputs=[pov_img, truth, belief_a, belief_b, dash, status, scoreboard, events, trace,
1827
- rewind, rewind_idx, branch_pick, branch_pick, st, branches, active_branch, beliefs, rewind_idx],
1828
- queue=True)
1829
 
1830
  btn_reset.click(reset_ep,
1831
  inputs=[st, branches, active_branch, beliefs, rewind_idx],
1832
- outputs=[pov_img, truth, belief_a, belief_b, dash, status, scoreboard, events, trace,
1833
- rewind, rewind_idx, branch_pick, branch_pick, st, branches, active_branch, beliefs, rewind_idx],
1834
- queue=True)
1835
 
1836
  btn_reset_all.click(reset_all,
1837
  inputs=[st, branches, active_branch, beliefs, rewind_idx, env_pick],
1838
- outputs=[pov_img, truth, belief_a, belief_b, dash, status, scoreboard, events, trace,
1839
- rewind, rewind_idx, branch_pick, branch_pick, st, branches, active_branch, beliefs, rewind_idx],
1840
- queue=True)
1841
 
1842
  btn_train.click(do_train,
1843
  inputs=[st, branches, active_branch, beliefs, rewind_idx,
1844
  use_q, alpha, gamma, eps, eps_decay, eps_min,
1845
  episodes, max_steps],
1846
- outputs=[pov_img, truth, belief_a, belief_b, dash, status, scoreboard, events, trace,
1847
- rewind, rewind_idx, branch_pick, branch_pick, st, branches, active_branch, beliefs, rewind_idx],
1848
- queue=True)
1849
 
1850
  btn_export.click(export_fn, inputs=[st, branches, active_branch, rewind_idx], outputs=[export_box], queue=True)
1851
 
1852
  btn_import.click(import_fn,
1853
  inputs=[import_box],
1854
- outputs=[pov_img, truth, belief_a, belief_b, dash, status, scoreboard, events, trace,
1855
- rewind, rewind_idx, branch_pick, branch_pick, st, branches, active_branch, beliefs, rewind_idx],
1856
- queue=True)
1857
 
1858
  demo.load(refresh,
1859
  inputs=[st, branches, active_branch, beliefs, rewind_idx],
1860
- outputs=[pov_img, truth, belief_a, belief_b, dash, status, scoreboard, events, trace,
1861
- rewind, rewind_idx, branch_pick, branch_pick],
 
 
1862
  queue=True)
1863
 
1864
- demo.queue().launch()
 
 
6
 
7
  import numpy as np
8
  from PIL import Image, ImageDraw
9
+
10
  import matplotlib.pyplot as plt
11
+ from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
12
+
13
  import gradio as gr
14
 
15
  # ============================================================
16
  # ZEN AgentLab — Agent POV + Multi-Agent Mini-Sim Arena
17
+ #
18
+ # Fix included:
19
+ # - Matplotlib rendering uses FigureCanvas.buffer_rgba() (HF-safe)
20
  #
21
  # Features:
22
  # - Deterministic gridworld + first-person raycast POV
23
  # - Multiple environments (Chase / CoopVault / MiniCiv)
24
+ # - Click-to-edit tiles + pickups + hazards + simple combat tags
25
  # - Full step trace: obs -> action -> reward -> (optional) Q-update
26
+ # - Branching timelines (rewind + fork)
27
  # - Batch training (tabular Q-learning) + metrics dashboard
28
+ # - Export/import full runs + SHA256 proof hash
29
  #
30
+ # HF Spaces compatible: no timers, no fn_kwargs
31
  # ============================================================
32
 
33
  # -----------------------------
 
79
  BASE: "Base",
80
  }
81
 
 
82
  AGENT_COLORS = {
83
  "Predator": (255, 120, 90),
84
  "Prey": (120, 255, 160),
85
  "Scout": (120, 190, 255),
 
86
  "Alpha": (255, 205, 120),
87
  "Bravo": (160, 210, 255),
88
  "Guardian": (255, 120, 220),
 
89
  "BuilderA": (140, 255, 200),
90
  "BuilderB": (160, 200, 255),
91
  "Raider": (255, 160, 120),
 
98
  WALL_SIDE = np.array([150, 150, 170], dtype=np.uint8)
99
  DOOR_COL = np.array([140, 210, 255], dtype=np.uint8)
100
 
101
+ # Small action space for tabular stability
102
+ ACTIONS = ["L", "F", "R", "I"] # interact
103
 
104
  # -----------------------------
105
+ # Deterministic RNG
106
  # -----------------------------
107
  def rng_for(seed: int, step: int, stream: int = 0) -> np.random.Generator:
108
  mix = (seed * 1_000_003) ^ (step * 9_999_937) ^ (stream * 97_531)
109
  return np.random.default_rng(mix & 0xFFFFFFFFFFFFFFFF)
110
 
 
111
  # -----------------------------
112
  # Data structures
113
  # -----------------------------
 
136
  epsilon_min: float = 0.02
137
  epsilon_decay: float = 0.995
138
 
139
+ # generic shaping
140
  step_penalty: float = -0.01
141
  explore_reward: float = 0.015
142
  damage_penalty: float = -0.20
143
  heal_reward: float = 0.10
144
 
145
+ # chase
146
  chase_close_coeff: float = 0.03
147
  chase_catch_reward: float = 3.0
148
  chase_escaped_reward: float = 0.2
149
  chase_caught_penalty: float = -3.0
150
  food_reward: float = 0.6
151
 
152
+ # vault
153
  artifact_pick_reward: float = 1.2
154
  exit_win_reward: float = 3.0
155
  guardian_tag_reward: float = 2.0
 
157
  switch_reward: float = 0.8
158
  key_reward: float = 0.4
159
 
160
+ # civ
161
  resource_pick_reward: float = 0.15
162
  deposit_reward: float = 0.4
163
  base_progress_win_reward: float = 3.5
 
182
  returns: Dict[str, float] = None
183
  action_counts: Dict[str, Dict[str, int]] = None
184
  tiles_discovered: Dict[str, int] = None
 
185
 
186
  def __post_init__(self):
187
  if self.returns is None:
 
190
  self.action_counts = {}
191
  if self.tiles_discovered is None:
192
  self.tiles_discovered = {}
 
 
193
 
194
  @dataclass
195
  class WorldState:
 
204
  overlay: bool
205
 
206
  done: bool
207
+ outcome: str # A_win | B_win | draw | ongoing
208
 
209
  # env state
210
  door_opened_global: bool = False
 
217
 
218
  # learning
219
  cfg: TrainConfig = None
220
+ q_tables: Dict[str, Dict[str, List[float]]] = None
221
  gmetrics: GlobalMetrics = None
222
  emetrics: EpisodeMetrics = None
223
 
 
299
  return abs(diff) <= (fov_deg / 2)
300
 
301
  def visible(state: WorldState, observer: Agent, target: Agent) -> bool:
 
302
  if not within_fov(observer, target.x, target.y, FOV_DEG):
303
  return False
 
304
  return bresenham_los(state.grid, observer.x, observer.y, target.x, target.y)
305
 
306
  def hash_sha256(txt: str) -> str:
307
  return hashlib.sha256(txt.encode("utf-8")).hexdigest()
308
 
309
  # -----------------------------
310
+ # Beliefs / fog-of-war
311
  # -----------------------------
312
  def init_beliefs(agent_names: List[str]) -> Dict[str, np.ndarray]:
313
+ return {nm: (-1 * np.ones((GRID_H, GRID_W), dtype=np.int16)) for nm in agent_names}
314
+
315
+ def update_belief_for_agent(state: WorldState, belief: np.ndarray, agent: Agent) -> int:
316
+ """Returns number of newly discovered tiles this update."""
317
+ before_unknown = int(np.sum(belief == -1))
318
 
 
319
  belief[agent.y, agent.x] = state.grid[agent.y][agent.x]
320
  base = math.radians(ORI_DEG[agent.ori])
321
  half = math.radians(FOV_DEG / 2)
 
340
  if tile == DOOR and not state.door_opened_global:
341
  break
342
 
343
+ after_unknown = int(np.sum(belief == -1))
344
+ return max(0, before_unknown - after_unknown)
345
+
346
  # -----------------------------
347
  # Rendering
348
  # -----------------------------
 
409
 
410
  # billboards for visible agents
411
  for nm, other in state.agents.items():
412
+ if nm == observer.name or other.hp <= 0:
413
  continue
414
  if visible(state, observer, other):
415
  dx = other.x - observer.x
 
507
  return im
508
 
509
  # -----------------------------
510
+ # Environments
511
  # -----------------------------
512
  def grid_with_border() -> List[List[int]]:
513
  g = [[EMPTY for _ in range(GRID_W)] for _ in range(GRID_H)]
 
521
 
522
  def env_chase(seed: int) -> Tuple[List[List[int]], Dict[str, Agent]]:
523
  g = grid_with_border()
 
524
  for x in range(4, 17):
525
  g[7][x] = WALL
526
  g[7][10] = DOOR
527
 
 
528
  g[3][4] = FOOD
529
  g[11][15] = FOOD
530
  g[4][14] = NOISE
 
541
 
542
  def env_vault(seed: int) -> Tuple[List[List[int]], Dict[str, Agent]]:
543
  g = grid_with_border()
 
544
  for x in range(3, 18):
545
  g[5][x] = WALL
546
  for x in range(3, 18):
 
548
  g[5][10] = DOOR
549
  g[9][12] = DOOR
550
 
 
551
  g[2][2] = KEY
552
  g[12][18] = EXIT
553
  g[12][2] = ARTIFACT
 
567
 
568
  def env_civ(seed: int) -> Tuple[List[List[int]], Dict[str, Agent]]:
569
  g = grid_with_border()
 
 
570
  for y in range(3, 12):
571
  g[y][9] = WALL
572
  g[7][9] = DOOR
573
 
 
574
  g[2][3] = WOOD
575
  g[3][3] = WOOD
576
  g[4][3] = WOOD
 
580
  g[6][4] = FOOD
581
  g[8][15] = FOOD
582
 
 
583
  g[13][10] = BASE
584
  g[4][15] = HAZARD
585
  g[10][4] = HAZARD
 
587
  g[13][2] = TELE
588
  g[2][2] = KEY
589
  g[12][6] = SWITCH
 
590
 
591
  agents = {
592
  "BuilderA": Agent("BuilderA", 3, 12, 0, hp=10, energy=100, team="A", brain="q"),
 
595
  }
596
  return g, agents
597
 
598
+ ENV_BUILDERS = {"chase": env_chase, "vault": env_vault, "civ": env_civ}
 
 
 
 
599
 
600
  # -----------------------------
601
+ # Observation / Q-learning
602
  # -----------------------------
603
  def local_tile_ahead(state: WorldState, a: Agent) -> int:
604
  dx, dy = DIRS[a.ori]
 
609
 
610
  def nearest_enemy_vec(state: WorldState, a: Agent) -> Tuple[int, int, int]:
611
  best = None
612
+ for _, other in state.agents.items():
613
  if other.hp <= 0:
614
  continue
615
  if other.team == a.team:
 
624
 
625
  def obs_key(state: WorldState, who: str) -> str:
626
  a = state.agents[who]
 
 
627
  d, dx, dy = nearest_enemy_vec(state, a)
628
  ahead = local_tile_ahead(state, a)
629
  keys = a.inventory.get("key", 0)
630
  art = a.inventory.get("artifact", 0)
631
  wood = a.inventory.get("wood", 0)
632
  ore = a.inventory.get("ore", 0)
 
633
  inv_bucket = f"k{min(keys,2)}a{min(art,1)}w{min(wood,3)}o{min(ore,3)}"
634
  door = 1 if state.door_opened_global else 0
635
  return f"{state.env_key}|{who}|{a.x},{a.y},{a.ori}|e{d}:{dx},{dy}|t{ahead}|hp{a.hp}|{inv_bucket}|D{door}|bp{state.base_progress}"
 
655
  return old, target, new
656
 
657
  # -----------------------------
658
+ # Baseline heuristics
659
  # -----------------------------
660
  def face_towards(a: Agent, tx: int, ty: int) -> str:
661
  dx = tx - a.x
 
673
  a = state.agents[who]
674
  r = rng_for(state.seed, state.step, stream=900 + hash(who) % 1000)
675
 
676
+ # Prioritize interacting on valuable tiles
677
+ t_here = state.grid[a.y][a.x]
678
+ if t_here in (FOOD, KEY, ARTIFACT, WOOD, ORE, MEDKIT, SWITCH, BASE, EXIT):
 
679
  return "I"
680
 
681
+ # Find nearest enemy
 
 
682
  best = None
683
+ best_d = 999
684
+ for _, other in state.agents.items():
685
  if other.hp <= 0 or other.team == a.team:
686
  continue
687
  d = manhattan_xy(a.x, a.y, other.x, other.y)
688
  if d < best_d:
689
  best_d = d
 
690
  best = other
691
 
692
  if best is not None and best_d <= 6 and visible(state, a, best):
693
+ # attackers chase, defenders try to flee
694
  if a.team == "B":
695
  return face_towards(a, best.x, best.y)
696
+
697
  dx = best.x - a.x
698
  dy = best.y - a.y
699
  ang = (math.degrees(math.atan2(dy, dx)) % 360)
 
706
  return "R"
707
  return "F"
708
 
 
709
  return r.choice(["F", "F", "L", "R", "I"])
710
 
711
  def random_action(state: WorldState, who: str) -> str:
 
745
  def try_interact(state: WorldState, a: Agent) -> str:
746
  t = state.grid[a.y][a.x]
747
 
 
748
  if t == SWITCH:
749
  state.door_opened_global = True
750
  state.grid[a.y][a.x] = EMPTY
 
782
  return "used: medkit"
783
 
784
  if t == BASE:
 
785
  w = a.inventory.get("wood", 0)
786
  o = a.inventory.get("ore", 0)
787
  dep = min(w, 2) + min(o, 2)
 
814
  return "noop"
815
 
816
  # -----------------------------
817
+ # Hazards / collisions / done
818
  # -----------------------------
819
  def resolve_hazards(state: WorldState, a: Agent) -> Tuple[bool, str]:
 
820
  if a.hp <= 0:
821
  return (False, "")
822
+ if state.grid[a.y][a.x] == HAZARD:
 
823
  a.hp -= 1
824
  return (True, "hazard:-hp")
825
  return (False, "")
826
 
827
  def resolve_tags(state: WorldState) -> List[str]:
 
828
  msgs = []
829
+ occupied: Dict[Tuple[int, int], List[str]] = {}
830
  for nm, a in state.agents.items():
831
  if a.hp <= 0:
832
  continue
 
837
  continue
838
  teams = set(state.agents[n].team for n in names)
839
  if len(teams) >= 2:
 
840
  for n in names:
841
  state.agents[n].hp -= 1
842
  msgs.append(f"t={state.step}: collision/tag at ({x},{y}) {names} (-hp all)")
843
  return msgs
844
 
845
  def check_done(state: WorldState) -> None:
 
846
  if state.env_key == "chase":
847
  pred = state.agents["Predator"]
848
  prey = state.agents["Prey"]
 
850
  state.done = True
851
  state.outcome = "draw"
852
  return
853
+ if pred.hp > 0 and prey.hp > 0 and pred.x == prey.x and pred.y == prey.y:
854
  state.done = True
855
+ state.outcome = "A_win"
856
  state.event_log.append(f"t={state.step}: CAUGHT (Predator wins).")
857
  return
 
858
  if state.step >= 300 and prey.hp > 0:
859
  state.done = True
860
  state.outcome = "B_win"
 
862
  return
863
 
864
  if state.env_key == "vault":
 
865
  for nm in ["Alpha", "Bravo"]:
866
  a = state.agents[nm]
867
  if a.hp > 0 and a.inventory.get("artifact", 0) > 0 and state.grid[a.y][a.x] == EXIT:
 
869
  state.outcome = "A_win"
870
  state.event_log.append(f"t={state.step}: VAULT CLEARED (Team A wins).")
871
  return
 
872
  alive_A = any(state.agents[n].hp > 0 for n in ["Alpha", "Bravo"])
873
  if not alive_A:
874
  state.done = True
 
877
  return
878
 
879
  if state.env_key == "civ":
 
880
  if state.base_progress >= state.base_target:
881
  state.done = True
882
  state.outcome = "A_win"
883
  state.event_log.append(f"t={state.step}: BASE COMPLETE (Builders win).")
884
  return
 
885
  alive_A = any(state.agents[n].hp > 0 for n in ["BuilderA", "BuilderB"])
886
  if not alive_A:
887
  state.done = True
888
  state.outcome = "B_win"
889
  state.event_log.append(f"t={state.step}: BUILDERS ELIMINATED (Raider wins).")
890
  return
 
891
  if state.step >= 350:
892
  state.done = True
893
  state.outcome = "draw"
 
897
  # -----------------------------
898
  # Rewards
899
  # -----------------------------
900
+ def reward_for(prev: WorldState, now: WorldState, who: str, outcome_msg: str, took_damage: bool) -> float:
901
+ cfg = now.cfg
 
 
 
 
902
  r = cfg.step_penalty
 
 
 
903
  if outcome_msg.startswith("moved"):
904
  r += cfg.explore_reward
 
905
  if took_damage:
906
  r += cfg.damage_penalty
 
 
907
  if outcome_msg.startswith("used: medkit"):
908
  r += cfg.heal_reward
909
 
910
+ if now.env_key == "chase":
911
+ pred = now.agents["Predator"]
912
+ prey = now.agents["Prey"]
 
913
  if who == "Predator":
914
+ d0 = manhattan_xy(prev.agents["Predator"].x, prev.agents["Predator"].y,
915
+ prev.agents["Prey"].x, prev.agents["Prey"].y)
916
  d1 = manhattan_xy(pred.x, pred.y, prey.x, prey.y)
917
  r += cfg.chase_close_coeff * float(d0 - d1)
918
+ if now.done and now.outcome == "A_win":
919
  r += cfg.chase_catch_reward
920
  if who == "Prey":
921
  if outcome_msg.startswith("ate: food"):
922
  r += cfg.food_reward
923
+ if now.done and now.outcome == "B_win":
924
  r += cfg.chase_escaped_reward
925
+ if now.done and now.outcome == "A_win":
926
  r += cfg.chase_caught_penalty
927
 
928
+ if now.env_key == "vault":
929
  if outcome_msg.startswith("picked: artifact"):
930
  r += cfg.artifact_pick_reward
931
  if outcome_msg.startswith("picked: key"):
932
  r += cfg.key_reward
933
  if outcome_msg.startswith("switch:"):
934
  r += cfg.switch_reward
935
+ if now.done:
936
+ if now.outcome == "A_win" and now.agents[who].team == "A":
937
  r += cfg.exit_win_reward
938
+ if now.outcome == "B_win" and now.agents[who].team == "B":
939
  r += cfg.guardian_tag_reward
940
+ if now.outcome == "B_win" and now.agents[who].team == "A":
941
  r += cfg.tagged_penalty
942
 
943
+ if now.env_key == "civ":
944
  if outcome_msg.startswith("picked: wood") or outcome_msg.startswith("picked: ore"):
945
  r += cfg.resource_pick_reward
946
  if outcome_msg.startswith("deposited:"):
947
  r += cfg.deposit_reward
948
+ if now.done:
949
+ if now.outcome == "A_win" and now.agents[who].team == "A":
950
  r += cfg.base_progress_win_reward
951
+ if now.outcome == "B_win" and now.agents[who].team == "B":
952
  r += cfg.raider_elim_reward
953
+ if now.outcome == "B_win" and now.agents[who].team == "A":
954
  r += cfg.builder_elim_penalty
955
 
956
  return float(r)
957
 
958
  # -----------------------------
959
+ # Policy selection
960
  # -----------------------------
961
  def choose_action(state: WorldState, who: str, stream: int) -> Tuple[str, str, Optional[Tuple[str, int]]]:
 
 
 
 
962
  a = state.agents[who]
963
  cfg = state.cfg
964
  r = rng_for(state.seed, state.step, stream=stream)
965
 
 
966
  if a.brain == "random":
967
  act = random_action(state, who)
968
  return act, "random", None
 
970
  act = heuristic_action(state, who)
971
  return act, "heuristic", None
972
 
 
973
  if cfg.use_q:
974
  key = obs_key(state, who)
975
  qtab = state.q_tables.setdefault(who, {})
 
981
  return act, "heuristic(fallback)", None
982
 
983
  # -----------------------------
984
+ # Init / reset
985
  # -----------------------------
986
  def init_state(seed: int, env_key: str) -> WorldState:
987
  g, agents = ENV_BUILDERS[env_key](seed)
 
988
  st = WorldState(
989
  seed=seed,
990
  step=0,
 
1007
  if seed is None:
1008
  seed = state.seed
1009
  fresh = init_state(int(seed), state.env_key)
 
1010
  fresh.cfg = state.cfg
1011
  fresh.q_tables = state.q_tables
1012
  fresh.gmetrics = state.gmetrics
 
1057
  return state
1058
 
1059
  # -----------------------------
1060
+ # Metrics dashboard (HF-safe)
1061
  # -----------------------------
 
 
 
 
 
 
 
 
 
 
 
 
1062
  def metrics_dashboard_image(state: WorldState) -> Image.Image:
1063
  gm = state.gmetrics
1064
+
1065
  fig = plt.figure(figsize=(7.0, 2.2), dpi=120)
1066
  ax = fig.add_subplot(111)
1067
+
1068
+ x1 = max(1, gm.episodes)
1069
+ ax.plot([0, x1], [gm.rolling_winrate_A, gm.rolling_winrate_A])
1070
  ax.set_title("Global Metrics Snapshot")
1071
+ ax.set_xlabel("Episodes")
1072
  ax.set_ylabel("Rolling winrate Team A")
1073
+ ax.set_ylim(-0.05, 1.05)
1074
  ax.grid(True)
1075
 
 
1076
  txt = (
1077
  f"env={state.env_key} | eps={gm.epsilon:.3f} | episodes={gm.episodes}\n"
1078
+ f"A_wins={gm.wins_teamA} B_wins={gm.wins_teamB} draws={gm.draws} | avg_steps~{gm.avg_steps:.1f}\n"
 
1079
  f"last_outcome={gm.last_outcome} last_steps={gm.last_steps}"
1080
  )
1081
  ax.text(0.01, 0.05, txt, transform=ax.transAxes, fontsize=8, va="bottom")
1082
 
1083
  fig.tight_layout()
1084
+
1085
+ canvas = FigureCanvas(fig)
1086
+ canvas.draw()
1087
+ buf = np.asarray(canvas.buffer_rgba()) # (H,W,4)
1088
+ img = Image.fromarray(buf, mode="RGBA").convert("RGB")
1089
  plt.close(fig)
1090
+ return img
1091
+
1092
+ def action_entropy(counts: Dict[str, int]) -> float:
1093
+ total = sum(counts.values())
1094
+ if total <= 0:
1095
+ return 0.0
1096
+ p = np.array([c / total for c in counts.values()], dtype=np.float64)
1097
+ p = np.clip(p, 1e-12, 1.0)
1098
+ return float(-np.sum(p * np.log2(p)))
1099
 
1100
  def agent_scoreboard(state: WorldState) -> str:
1101
  rows = []
1102
  header = ["agent", "team", "hp", "return", "steps", "entropy", "tiles_disc", "q_states", "inventory"]
1103
  rows.append(header)
 
1104
  steps = state.emetrics.steps
1105
+
1106
  for nm, a in state.agents.items():
1107
  ret = state.emetrics.returns.get(nm, 0.0)
1108
  counts = state.emetrics.action_counts.get(nm, {})
 
1112
  inv = json.dumps(a.inventory, sort_keys=True)
1113
  rows.append([nm, a.team, a.hp, f"{ret:.2f}", steps, f"{ent:.2f}", td, qs, inv])
1114
 
 
1115
  col_w = [max(len(str(r[i])) for r in rows) for i in range(len(header))]
1116
  lines = []
1117
  for ridx, r in enumerate(rows):
 
1122
  return "\n".join(lines)
1123
 
1124
  # -----------------------------
1125
+ # Tick / training
1126
  # -----------------------------
1127
  def clone_shallow(state: WorldState) -> WorldState:
1128
+ return WorldState(
 
1129
  seed=state.seed,
1130
  step=state.step,
1131
  env_key=state.env_key,
 
1146
  gmetrics=state.gmetrics,
1147
  emetrics=state.emetrics,
1148
  )
1149
+
1150
+ def update_action_counts(state: WorldState, who: str, act: str):
1151
+ state.emetrics.action_counts.setdefault(who, {})
1152
+ state.emetrics.action_counts[who][act] = state.emetrics.action_counts[who].get(act, 0) + 1
1153
 
1154
  def tick(state: WorldState, beliefs: Dict[str, np.ndarray], manual_action: Optional[str] = None) -> None:
1155
  if state.done:
1156
  return
1157
 
1158
  prev = clone_shallow(state)
 
 
1159
  chosen: Dict[str, str] = {}
1160
  reasons: Dict[str, str] = {}
1161
  qinfo: Dict[str, Optional[Tuple[str, int]]] = {}
 
1165
  reasons[state.controlled] = "manual"
1166
  qinfo[state.controlled] = None
1167
 
1168
+ order = list(state.agents.keys())
1169
+
1170
+ for who in order:
1171
  if who in chosen:
1172
  continue
1173
  act, reason, qi = choose_action(state, who, stream=200 + (hash(who) % 1000))
 
1175
  reasons[who] = reason
1176
  qinfo[who] = qi
1177
 
 
 
1178
  outcomes: Dict[str, str] = {}
1179
  took_damage: Dict[str, bool] = {nm: False for nm in order}
 
1180
 
1181
  for who in order:
 
1182
  outcomes[who] = apply_action(state, who, chosen[who])
 
 
1183
 
1184
  dmg, msg = resolve_hazards(state, state.agents[who])
1185
  took_damage[who] = dmg
1186
  if msg:
1187
  state.event_log.append(f"t={state.step}: {who} {msg}")
1188
 
 
1189
  update_action_counts(state, who, chosen[who])
1190
 
1191
+ for m in resolve_tags(state):
 
 
1192
  state.event_log.append(m)
1193
 
1194
+ # belief updates + discovered tiles
1195
  for nm, a in state.agents.items():
1196
  if a.hp <= 0:
1197
  continue
1198
+ disc = update_belief_for_agent(state, beliefs[nm], a)
1199
+ state.emetrics.tiles_discovered[nm] = state.emetrics.tiles_discovered.get(nm, 0) + disc
 
 
 
1200
 
 
1201
  check_done(state)
1202
 
1203
+ # rewards + Q
1204
  q_lines = []
1205
  for who in order:
1206
  if who not in state.emetrics.returns:
1207
  state.emetrics.returns[who] = 0.0
1208
+
1209
+ r = reward_for(prev, state, who, outcomes[who], took_damage[who])
1210
  state.emetrics.returns[who] += r
1211
 
1212
  if qinfo.get(who) is not None:
 
1216
  old, tgt, new = q_update(qtab, key, a_idx, r, next_key, state.cfg.alpha, state.cfg.gamma)
1217
  q_lines.append(f"{who}: old={old:.3f} tgt={tgt:.3f} new={new:.3f} (a={ACTIONS[a_idx]})")
1218
 
 
1219
  trace = f"t={state.step} env={state.env_key} done={state.done} outcome={state.outcome}"
1220
  for who in order:
1221
  a = state.agents[who]
 
1230
  state.step += 1
1231
  state.emetrics.steps = state.step
1232
 
 
 
 
1233
  def run_episode(state: WorldState, beliefs: Dict[str, np.ndarray], max_steps: int) -> Tuple[str, int]:
1234
  while state.step < max_steps and not state.done:
1235
  tick(state, beliefs, manual_action=None)
 
1252
  gm.rolling_winrate_A = 0.90 * gm.rolling_winrate_A + 0.10 * 0.5
1253
 
1254
  gm.avg_steps = (0.90 * gm.avg_steps + 0.10 * steps) if gm.avg_steps > 0 else float(steps)
 
 
1255
  gm.epsilon = max(state.cfg.epsilon_min, gm.epsilon * state.cfg.epsilon_decay)
1256
 
1257
def train(state: WorldState, episodes: int, max_steps: int) -> WorldState:
    """Run a batch of tabular Q-learning episodes and return the refreshed state.

    Every episode receives a deterministic seed derived from the base world
    seed and the global episode counter, so repeated training runs with the
    same configuration are reproducible.
    """
    for ep_i in range(episodes):
        # Fold the base seed and episode index into a 32-bit per-episode seed.
        ep_seed = (state.seed * 1_000_003 + (state.gmetrics.episodes + ep_i) * 97_531) & 0xFFFFFFFF
        state = reset_episode_keep_learning(state, seed=int(ep_seed))
        beliefs = init_beliefs(list(state.agents.keys()))
        outcome, steps = run_episode(state, beliefs, max_steps=max_steps)
        update_global_metrics_after_episode(state, outcome, steps)

    # One summary entry for the whole batch, then return to the base seed.
    state.event_log.append(
        f"Training: +{episodes} eps | eps={state.gmetrics.epsilon:.3f} | "
        f"A={state.gmetrics.wins_teamA} B={state.gmetrics.wins_teamB} D={state.gmetrics.draws}"
    )
    state = reset_episode_keep_learning(state, seed=state.seed)
    return state
1271
 
 
1284
  "q_tables": state.q_tables,
1285
  "branches": {b: [asdict(s) for s in snaps] for b, snaps in branches.items()},
1286
  "active_branch": active_branch,
1287
+ "rewind_idx": int(rewind_idx),
1288
  "grid": state.grid,
1289
  "door_opened_global": state.door_opened_global,
1290
  "base_progress": state.base_progress,
 
1295
  return txt + "\n\n" + json.dumps({"proof_sha256": proof}, indent=2)
1296
 
1297
  def import_run(txt: str) -> Tuple[WorldState, Dict[str, List[Snapshot]], str, int, Dict[str, np.ndarray]]:
 
1298
  parts = txt.strip().split("\n\n")
1299
  data = json.loads(parts[0])
1300
 
 
1319
  active = data.get("active_branch", "main")
1320
  r_idx = int(data.get("rewind_idx", 0))
1321
 
 
1322
  if active in branches and branches[active]:
1323
  st = restore_into(st, branches[active][-1])
1324
  st.event_log.append("Imported run (restored last snapshot).")
 
1332
  # UI helpers
1333
  # -----------------------------
1334
  def build_views(state: WorldState, beliefs: Dict[str, np.ndarray]) -> Tuple[np.ndarray, Image.Image, Image.Image, Image.Image, Image.Image, str, str, str, str]:
 
1335
  for nm, a in state.agents.items():
1336
  if a.hp > 0:
1337
  update_belief_for_agent(state, beliefs[nm], a)
1338
 
1339
  pov = raycast_view(state, state.agents[state.pov])
1340
  truth_np = np.array(state.grid, dtype=np.int16)
1341
+ truth_img = render_topdown(truth_np, state.agents, f"Truth Map — env={state.env_key} t={state.step} seed={state.seed}", True)
1342
 
1343
  ctrl = state.controlled
 
1344
  others = [k for k in state.agents.keys() if k != ctrl]
1345
  other = others[0] if others else ctrl
1346
+ b_ctrl = render_topdown(beliefs[ctrl], state.agents, f"{ctrl} Belief", True)
1347
+ b_other = render_topdown(beliefs[other], state.agents, f"{other} Belief", True)
1348
 
1349
  dash = metrics_dashboard_image(state)
1350
 
 
1368
  gy = int(y_px // TILE)
1369
  if not in_bounds(gx, gy):
1370
  return state
 
1371
  if gx == 0 or gy == 0 or gx == GRID_W - 1 or gy == GRID_H - 1:
1372
  return state
1373
  state.grid[gy][gx] = selected_tile
 
1375
  return state
1376
 
1377
  # -----------------------------
1378
+ # Gradio app
1379
  # -----------------------------
1380
  TITLE = "ZEN AgentLab — Agent POV + Multi-Agent Mini-Sim Arena"
1381
 
1382
  with gr.Blocks(title=TITLE) as demo:
1383
  gr.Markdown(
1384
  f"## {TITLE}\n"
1385
+ "Multi-environment agent sandbox with POV, belief maps, branching timelines, training, and metrics.\n"
1386
+ "**No timers** use Tick / Run / Train for deterministic experiments."
1387
  )
1388
 
1389
+ st0 = init_state(1337, "chase")
1390
+ st = gr.State(st0)
1391
+ branches = gr.State({"main": [snapshot_of(st0, "main")]})
1392
  active_branch = gr.State("main")
1393
  rewind_idx = gr.State(0)
1394
+ beliefs = gr.State(init_beliefs(list(st0.agents.keys())))
1395
 
1396
  with gr.Row():
1397
  pov_img = gr.Image(label="POV (Pseudo-3D)", type="numpy", width=VIEW_W, height=VIEW_H)
 
1409
 
1410
  with gr.Row():
1411
  events = gr.Textbox(label="Event Log", lines=10)
1412
+ trace = gr.Textbox(label="Step Trace", lines=10)
1413
 
1414
  with gr.Row():
1415
  with gr.Column(scale=2):
 
1446
  with gr.Column(scale=3):
1447
  gr.Markdown("### Training Controls (Tabular Q-learning)")
1448
  use_q = gr.Checkbox(True, label="Use Q-learning (agents with brain='q')")
1449
+ alpha = gr.Slider(0.01, 0.5, value=0.15, step=0.01, label="alpha")
1450
+ gamma = gr.Slider(0.5, 0.99, value=0.95, step=0.01, label="gamma")
1451
+ eps = gr.Slider(0.0, 0.5, value=0.10, step=0.01, label="epsilon")
1452
  eps_decay = gr.Slider(0.90, 0.999, value=0.995, step=0.001, label="epsilon decay")
1453
  eps_min = gr.Slider(0.0, 0.2, value=0.02, step=0.01, label="epsilon min")
1454
 
1455
  episodes = gr.Number(value=50, label="Train episodes", precision=0)
1456
+ max_steps = gr.Number(value=260, label="Max steps/episode", precision=0)
1457
  btn_train = gr.Button("Train")
1458
 
1459
  btn_reset = gr.Button("Reset Episode (keep learning)")
 
1477
  import_box = gr.Textbox(label="Import JSON", lines=8)
1478
  btn_import = gr.Button("Import")
1479
 
1480
+ # ---------- glue ----------
1481
  def refresh(state: WorldState, branches_d: Dict[str, List[Snapshot]], active: str, bel: Dict[str, np.ndarray], r: int):
1482
  snaps = branches_d.get(active, [])
1483
  r_max = max(0, len(snaps) - 1)
1484
  r = max(0, min(int(r), r_max))
1485
  pov, tr, ba, bb, dimg, stxt, etxt, ttxt, sb = build_views(state, bel)
 
1486
  branch_choices = sorted(list(branches_d.keys()))
1487
  return (
1488
+ pov, tr, ba, bb, dimg, stxt, sb, etxt, ttxt,
1489
+ gr.update(maximum=r_max, value=r), r,
 
 
1490
  gr.update(choices=branch_choices, value=active),
1491
  gr.update(choices=branch_choices, value=active),
1492
  )
 
1584
  branches_d[new_name].append(snapshot_of(state, new_name))
1585
  else:
1586
  idx = max(0, min(int(r), len(snaps) - 1))
 
1587
  branches_d[new_name] = [Snapshot(**asdict(s)) for s in snaps[:idx + 1]]
 
1588
  state = restore_into(state, branches_d[new_name][-1])
1589
  active = new_name
1590
  state.event_log.append(f"Forked branch -> {new_name}")
 
1598
  if br not in branches_d:
1599
  branches_d[br] = [snapshot_of(state, br)]
1600
  active = br
 
1601
  if branches_d[active]:
1602
  state = restore_into(state, branches_d[active][-1])
1603
  bel = init_beliefs(list(state.agents.keys()))
 
1604
  r = len(branches_d[active]) - 1
1605
+ out = refresh(state, branches_d, active, bel, r)
1606
  return out + (state, branches_d, active, bel, r)
1607
 
1608
  def change_env(state, branches_d, active, bel, r, env_key):
1609
  env_key = env_key or "chase"
 
1610
  state.env_key = env_key
1611
  state = reset_episode_keep_learning(state, seed=state.seed)
1612
  bel = init_beliefs(list(state.agents.keys()))
 
1651
 
1652
  def import_fn(txt):
1653
  state, branches_d, active, r, bel = import_run(txt)
 
1654
  branches_d.setdefault(active, [])
1655
  if not branches_d[active]:
1656
  branches_d[active].append(snapshot_of(state, active))
1657
  out = refresh(state, branches_d, active, bel, r)
1658
  return out + (state, branches_d, active, bel, r)
1659
 
1660
+ # ---- wire events (no fn_kwargs) ----
1661
+ common_outputs = [
1662
+ pov_img, truth, belief_a, belief_b, dash, status, scoreboard, events, trace,
1663
+ rewind, rewind_idx, branch_pick, branch_pick,
1664
+ st, branches, active_branch, beliefs, rewind_idx
1665
+ ]
1666
+
1667
  btn_L.click(lambda s,b,a,bel,r: do_manual(s,b,a,bel,r,"L"),
1668
  inputs=[st, branches, active_branch, beliefs, rewind_idx],
1669
+ outputs=common_outputs, queue=True)
 
 
1670
 
1671
  btn_F.click(lambda s,b,a,bel,r: do_manual(s,b,a,bel,r,"F"),
1672
  inputs=[st, branches, active_branch, beliefs, rewind_idx],
1673
+ outputs=common_outputs, queue=True)
 
 
1674
 
1675
  btn_R.click(lambda s,b,a,bel,r: do_manual(s,b,a,bel,r,"R"),
1676
  inputs=[st, branches, active_branch, beliefs, rewind_idx],
1677
+ outputs=common_outputs, queue=True)
 
 
1678
 
1679
  btn_I.click(lambda s,b,a,bel,r: do_manual(s,b,a,bel,r,"I"),
1680
  inputs=[st, branches, active_branch, beliefs, rewind_idx],
1681
+ outputs=common_outputs, queue=True)
 
 
1682
 
1683
  btn_tick.click(do_tick,
1684
  inputs=[st, branches, active_branch, beliefs, rewind_idx],
1685
+ outputs=common_outputs, queue=True)
 
 
1686
 
1687
  btn_run.click(do_run,
1688
  inputs=[st, branches, active_branch, beliefs, rewind_idx, run_steps],
1689
+ outputs=common_outputs, queue=True)
 
 
1690
 
1691
  btn_toggle_control.click(toggle_control,
1692
  inputs=[st, branches, active_branch, beliefs, rewind_idx],
1693
+ outputs=common_outputs, queue=True)
 
 
1694
 
1695
  btn_toggle_pov.click(toggle_pov,
1696
  inputs=[st, branches, active_branch, beliefs, rewind_idx],
1697
+ outputs=common_outputs, queue=True)
 
 
1698
 
1699
  overlay.change(set_overlay,
1700
  inputs=[st, branches, active_branch, beliefs, rewind_idx, overlay],
1701
+ outputs=common_outputs, queue=True)
 
 
1702
 
1703
  env_pick.change(change_env,
1704
  inputs=[st, branches, active_branch, beliefs, rewind_idx, env_pick],
1705
+ outputs=common_outputs, queue=True)
 
 
1706
 
1707
  truth.select(click_truth,
1708
  inputs=[tile_pick, st, branches, active_branch, beliefs, rewind_idx],
1709
+ outputs=common_outputs, queue=True)
 
 
1710
 
1711
  btn_jump.click(jump,
1712
  inputs=[st, branches, active_branch, beliefs, rewind_idx, rewind],
1713
+ outputs=common_outputs, queue=True)
 
 
1714
 
1715
  btn_fork.click(fork_branch,
1716
  inputs=[st, branches, active_branch, beliefs, rewind_idx, new_branch_name],
1717
+ outputs=common_outputs, queue=True)
 
 
1718
 
1719
  btn_set_branch.click(set_active_branch,
1720
  inputs=[st, branches, active_branch, beliefs, rewind_idx, branch_pick],
1721
+ outputs=common_outputs, queue=True)
 
 
1722
 
1723
  btn_reset.click(reset_ep,
1724
  inputs=[st, branches, active_branch, beliefs, rewind_idx],
1725
+ outputs=common_outputs, queue=True)
 
 
1726
 
1727
  btn_reset_all.click(reset_all,
1728
  inputs=[st, branches, active_branch, beliefs, rewind_idx, env_pick],
1729
+ outputs=common_outputs, queue=True)
 
 
1730
 
1731
  btn_train.click(do_train,
1732
  inputs=[st, branches, active_branch, beliefs, rewind_idx,
1733
  use_q, alpha, gamma, eps, eps_decay, eps_min,
1734
  episodes, max_steps],
1735
+ outputs=common_outputs, queue=True)
 
 
1736
 
1737
  btn_export.click(export_fn, inputs=[st, branches, active_branch, rewind_idx], outputs=[export_box], queue=True)
1738
 
1739
  btn_import.click(import_fn,
1740
  inputs=[import_box],
1741
+ outputs=common_outputs, queue=True)
 
 
1742
 
1743
  demo.load(refresh,
1744
  inputs=[st, branches, active_branch, beliefs, rewind_idx],
1745
+ outputs=[
1746
+ pov_img, truth, belief_a, belief_b, dash, status, scoreboard, events, trace,
1747
+ rewind, rewind_idx, branch_pick, branch_pick
1748
+ ],
1749
  queue=True)
1750
 
1751
+ # HF sometimes enables SSR by default; disable for maximum compatibility
1752
+ demo.queue().launch(ssr_mode=False)