import json import math import hashlib from dataclasses import dataclass, asdict from typing import Dict, List, Tuple, Optional, Any import numpy as np from PIL import Image, ImageDraw import matplotlib.pyplot as plt from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas import gradio as gr # ============================================================ # ZEN AgentLab — Agent POV + Multi-Agent Mini-Sim Arena # # Additions in this version: # - Autoplay (Start/Stop) via gr.Timer (watch agents live) # - One-click "Cinematic Run" (full episode in one click) # - Example presets (env+seed) + seed controls # - Autoplay is interruptible: manual buttons still work anytime # # Matplotlib HF-safe: uses canvas.buffer_rgba() # ============================================================ # ----------------------------- # Global config # ----------------------------- GRID_W, GRID_H = 21, 15 TILE = 22 VIEW_W, VIEW_H = 640, 360 RAY_W = 320 FOV_DEG = 78 MAX_DEPTH = 20 DIRS = [(1, 0), (0, 1), (-1, 0), (0, -1)] ORI_DEG = [0, 90, 180, 270] # Tiles EMPTY = 0 WALL = 1 FOOD = 2 NOISE = 3 DOOR = 4 TELE = 5 KEY = 6 EXIT = 7 ARTIFACT = 8 HAZARD = 9 WOOD = 10 ORE = 11 MEDKIT = 12 SWITCH = 13 BASE = 14 TILE_NAMES = { EMPTY: "Empty", WALL: "Wall", FOOD: "Food", NOISE: "Noise", DOOR: "Door", TELE: "Teleporter", KEY: "Key", EXIT: "Exit", ARTIFACT: "Artifact", HAZARD: "Hazard", WOOD: "Wood", ORE: "Ore", MEDKIT: "Medkit", SWITCH: "Switch", BASE: "Base", } AGENT_COLORS = { "Predator": (255, 120, 90), "Prey": (120, 255, 160), "Scout": (120, 190, 255), "Alpha": (255, 205, 120), "Bravo": (160, 210, 255), "Guardian": (255, 120, 220), "BuilderA": (140, 255, 200), "BuilderB": (160, 200, 255), "Raider": (255, 160, 120), } SKY = np.array([14, 16, 26], dtype=np.uint8) FLOOR_NEAR = np.array([24, 26, 40], dtype=np.uint8) FLOOR_FAR = np.array([10, 11, 18], dtype=np.uint8) WALL_BASE = np.array([210, 210, 225], dtype=np.uint8) WALL_SIDE = np.array([150, 150, 170], dtype=np.uint8) DOOR_COL = np.array([140, 210, 255], dtype=np.uint8) # Small action space ACTIONS = ["L", "F", "R", "I"] # interact TRACE_MAX = 500 MAX_HISTORY = 1400 # ----------------------------- # Deterministic RNG # ----------------------------- def rng_for(seed: int, step: int, stream: int = 0) -> np.random.Generator: mix = (seed * 1_000_003) ^ (step * 9_999_937) ^ (stream * 97_531) return np.random.default_rng(mix & 0xFFFFFFFFFFFFFFFF) # ----------------------------- # Data structures # ----------------------------- @dataclass class Agent: name: str x: int y: int ori: int hp: int = 10 energy: int = 100 team: str = "A" brain: str = "q" # q | heuristic | random inventory: Dict[str, int] = None def __post_init__(self): if self.inventory is None: self.inventory = {} @dataclass class TrainConfig: use_q: bool = True alpha: float = 0.15 gamma: float = 0.95 epsilon: float = 0.10 epsilon_min: float = 0.02 epsilon_decay: float = 0.995 step_penalty: float = -0.01 explore_reward: float = 0.015 damage_penalty: float = -0.20 heal_reward: float = 0.10 chase_close_coeff: float = 0.03 chase_catch_reward: float = 3.0 chase_escaped_reward: float = 0.2 chase_caught_penalty: float = -3.0 food_reward: float = 0.6 artifact_pick_reward: float = 1.2 exit_win_reward: float = 3.0 guardian_tag_reward: float = 2.0 tagged_penalty: float = -2.0 switch_reward: float = 0.8 key_reward: float = 0.4 resource_pick_reward: float = 0.15 deposit_reward: float = 0.4 base_progress_win_reward: float = 3.5 raider_elim_reward: float = 2.0 builder_elim_penalty: float = -2.0 @dataclass class GlobalMetrics: episodes: int = 0 wins_teamA: int = 0 wins_teamB: int = 0 draws: int = 0 avg_steps: float = 0.0 rolling_winrate_A: float = 0.0 epsilon: float = 0.10 last_outcome: str = "init" last_steps: int = 0 @dataclass class EpisodeMetrics: steps: int = 0 returns: Dict[str, float] = None action_counts: Dict[str, Dict[str, int]] = None tiles_discovered: Dict[str, int] = None def __post_init__(self): if self.returns is None: self.returns = {} if self.action_counts is None: self.action_counts = {} if self.tiles_discovered is None: self.tiles_discovered = {} @dataclass class WorldState: seed: int step: int env_key: str grid: List[List[int]] agents: Dict[str, Agent] controlled: str pov: str overlay: bool done: bool outcome: str # A_win | B_win | draw | ongoing door_opened_global: bool = False base_progress: int = 0 base_target: int = 10 event_log: List[str] = None trace_log: List[str] = None cfg: TrainConfig = None q_tables: Dict[str, Dict[str, List[float]]] = None gmetrics: GlobalMetrics = None emetrics: EpisodeMetrics = None def __post_init__(self): if self.event_log is None: self.event_log = [] if self.trace_log is None: self.trace_log = [] if self.cfg is None: self.cfg = TrainConfig() if self.q_tables is None: self.q_tables = {} if self.gmetrics is None: self.gmetrics = GlobalMetrics(epsilon=self.cfg.epsilon) if self.emetrics is None: self.emetrics = EpisodeMetrics() @dataclass class Snapshot: branch: str step: int env_key: str grid: List[List[int]] agents: Dict[str, Dict[str, Any]] done: bool outcome: str door_opened_global: bool base_progress: int base_target: int event_tail: List[str] trace_tail: List[str] emetrics: Dict[str, Any] # ----------------------------- # Helpers # ----------------------------- def in_bounds(x: int, y: int) -> bool: return 0 <= x < GRID_W and 0 <= y < GRID_H def is_blocking(tile: int, door_open: bool = False) -> bool: if tile == WALL: return True if tile == DOOR and not door_open: return True return False def manhattan_xy(ax: int, ay: int, bx: int, by: int) -> int: return abs(ax - bx) + abs(ay - by) def bresenham_los(grid: List[List[int]], x0: int, y0: int, x1: int, y1: int) -> bool: dx = abs(x1 - x0) dy = abs(y1 - y0) sx = 1 if x0 < x1 else -1 sy = 1 if y0 < y1 else -1 err = dx - dy x, y = x0, y0 while True: if (x, y) != (x0, y0) and (x, y) != (x1, y1): if grid[y][x] == WALL: return False if x == x1 and y == y1: return True e2 = 2 * err if e2 > -dy: err -= dy x += sx if e2 < dx: err += dx y += sy def within_fov(observer: Agent, tx: int, ty: int, fov_deg: float = FOV_DEG) -> bool: dx = tx - observer.x dy = ty - observer.y if dx == 0 and dy == 0: return True angle = math.degrees(math.atan2(dy, dx)) % 360 facing = ORI_DEG[observer.ori] diff = (angle - facing + 540) % 360 - 180 return abs(diff) <= (fov_deg / 2) def visible(state: WorldState, observer: Agent, target: Agent) -> bool: if not within_fov(observer, target.x, target.y, FOV_DEG): return False return bresenham_los(state.grid, observer.x, observer.y, target.x, target.y) def hash_sha256(txt: str) -> str: return hashlib.sha256(txt.encode("utf-8")).hexdigest() # ----------------------------- # Beliefs # ----------------------------- def init_beliefs(agent_names: List[str]) -> Dict[str, np.ndarray]: return {nm: (-1 * np.ones((GRID_H, GRID_W), dtype=np.int16)) for nm in agent_names} def update_belief_for_agent(state: WorldState, belief: np.ndarray, agent: Agent) -> int: before_unknown = int(np.sum(belief == -1)) belief[agent.y, agent.x] = state.grid[agent.y][agent.x] base = math.radians(ORI_DEG[agent.ori]) half = math.radians(FOV_DEG / 2) rays = 45 if agent.name.lower().startswith("scout") else 33 for i in range(rays): t = i / (rays - 1) ang = base + (t * 2 - 1) * half sin_a, cos_a = math.sin(ang), math.cos(ang) ox, oy = agent.x + 0.5, agent.y + 0.5 depth = 0.0 while depth < MAX_DEPTH: depth += 0.2 tx = int(ox + cos_a * depth) ty = int(oy + sin_a * depth) if not in_bounds(tx, ty): break belief[ty, tx] = state.grid[ty][tx] tile = state.grid[ty][tx] if tile == WALL: break if tile == DOOR and not state.door_opened_global: break after_unknown = int(np.sum(belief == -1)) return max(0, before_unknown - after_unknown) # ----------------------------- # Rendering # ----------------------------- def raycast_view(state: WorldState, observer: Agent) -> np.ndarray: img = np.zeros((VIEW_H, VIEW_W, 3), dtype=np.uint8) img[:, :] = SKY for y in range(VIEW_H // 2, VIEW_H): t = (y - VIEW_H // 2) / (VIEW_H // 2 + 1e-6) col = (1 - t) * FLOOR_NEAR + t * FLOOR_FAR img[y, :] = col.astype(np.uint8) fov = math.radians(FOV_DEG) half_fov = fov / 2 for rx in range(RAY_W): cam_x = (2 * rx / (RAY_W - 1)) - 1 ray_ang = math.radians(ORI_DEG[observer.ori]) + cam_x * half_fov ox, oy = observer.x + 0.5, observer.y + 0.5 sin_a = math.sin(ray_ang) cos_a = math.cos(ray_ang) depth = 0.0 hit = None side = 0 while depth < MAX_DEPTH: depth += 0.05 tx = int(ox + cos_a * depth) ty = int(oy + sin_a * depth) if not in_bounds(tx, ty): break tile = state.grid[ty][tx] if tile == WALL: hit = "wall" side = 1 if abs(cos_a) > abs(sin_a) else 0 break if tile == DOOR and not state.door_opened_global: hit = "door" break if hit is None: continue depth *= math.cos(ray_ang - math.radians(ORI_DEG[observer.ori])) depth = max(depth, 0.001) proj_h = int((VIEW_H * 0.9) / depth) y0 = max(0, VIEW_H // 2 - proj_h // 2) y1 = min(VIEW_H - 1, VIEW_H // 2 + proj_h // 2) if hit == "door": col = DOOR_COL.copy() else: col = WALL_BASE.copy() if side == 0 else WALL_SIDE.copy() dim = max(0.25, 1.0 - (depth / MAX_DEPTH)) col = (col * dim).astype(np.uint8) x0 = int(rx * (VIEW_W / RAY_W)) x1 = int((rx + 1) * (VIEW_W / RAY_W)) img[y0:y1, x0:x1] = col for nm, other in state.agents.items(): if nm == observer.name or other.hp <= 0: continue if visible(state, observer, other): dx = other.x - observer.x dy = other.y - observer.y ang = (math.degrees(math.atan2(dy, dx)) % 360) facing = ORI_DEG[observer.ori] diff = (ang - facing + 540) % 360 - 180 sx = int((diff / (FOV_DEG / 2)) * (VIEW_W / 2) + (VIEW_W / 2)) dist = math.sqrt(dx * dx + dy * dy) h = int((VIEW_H * 0.65) / max(dist, 0.75)) w = max(10, h // 3) y_mid = VIEW_H // 2 y0 = max(0, y_mid - h // 2) y1 = min(VIEW_H - 1, y_mid + h // 2) x0 = max(0, sx - w // 2) x1 = min(VIEW_W - 1, sx + w // 2) col = AGENT_COLORS.get(nm, (255, 200, 120)) img[y0:y1, x0:x1] = np.array(col, dtype=np.uint8) if state.overlay: cx, cy = VIEW_W // 2, VIEW_H // 2 img[cy - 1:cy + 2, cx - 10:cx + 10] = np.array([120, 190, 255], dtype=np.uint8) img[cy - 10:cy + 10, cx - 1:cx + 2] = np.array([120, 190, 255], dtype=np.uint8) return img def render_topdown(grid: np.ndarray, agents: Dict[str, Agent], title: str, show_agents: bool = True) -> Image.Image: w = grid.shape[1] * TILE h = grid.shape[0] * TILE im = Image.new("RGB", (w, h + 28), (10, 12, 18)) draw = ImageDraw.Draw(im) for y in range(grid.shape[0]): for x in range(grid.shape[1]): t = int(grid[y, x]) if t == -1: col = (18, 20, 32) elif t == EMPTY: col = (26, 30, 44) elif t == WALL: col = (190, 190, 210) elif t == FOOD: col = (255, 210, 120) elif t == NOISE: col = (255, 120, 220) elif t == DOOR: col = (140, 210, 255) elif t == TELE: col = (120, 190, 255) elif t == KEY: col = (255, 235, 160) elif t == EXIT: col = (120, 255, 220) elif t == ARTIFACT: col = (255, 170, 60) elif t == HAZARD: col = (255, 90, 90) elif t == WOOD: col = (170, 120, 60) elif t == ORE: col = (140, 140, 160) elif t == MEDKIT: col = (120, 255, 140) elif t == SWITCH: col = (200, 180, 255) elif t == BASE: col = (220, 220, 240) else: col = (80, 80, 90) x0, y0 = x * TILE, y * TILE + 28 draw.rectangle([x0, y0, x0 + TILE - 1, y0 + TILE - 1], fill=col) for x in range(grid.shape[1] + 1): xx = x * TILE draw.line([xx, 28, xx, h + 28], fill=(12, 14, 22)) for y in range(grid.shape[0] + 1): yy = y * TILE + 28 draw.line([0, yy, w, yy], fill=(12, 14, 22)) if show_agents: for nm, a in agents.items(): if a.hp <= 0: continue cx = a.x * TILE + TILE // 2 cy = a.y * TILE + 28 + TILE // 2 col = AGENT_COLORS.get(nm, (220, 220, 220)) r = TILE // 3 draw.ellipse([cx - r, cy - r, cx + r, cy + r], fill=col) dx, dy = DIRS[a.ori] draw.line([cx, cy, cx + dx * r, cy + dy * r], fill=(10, 10, 10), width=3) draw.rectangle([0, 0, w, 28], fill=(14, 16, 26)) draw.text((8, 6), title, fill=(230, 230, 240)) return im # ----------------------------- # Environments # ----------------------------- def grid_with_border() -> List[List[int]]: g = [[EMPTY for _ in range(GRID_W)] for _ in range(GRID_H)] for x in range(GRID_W): g[0][x] = WALL g[GRID_H - 1][x] = WALL for y in range(GRID_H): g[y][0] = WALL g[y][GRID_W - 1] = WALL return g def env_chase(seed: int) -> Tuple[List[List[int]], Dict[str, Agent]]: g = grid_with_border() for x in range(4, 17): g[7][x] = WALL g[7][10] = DOOR g[3][4] = FOOD g[11][15] = FOOD g[4][14] = NOISE g[12][5] = NOISE g[2][18] = TELE g[13][2] = TELE agents = { "Predator": Agent("Predator", 2, 2, 0, hp=10, energy=100, team="A", brain="q"), "Prey": Agent("Prey", 18, 12, 2, hp=10, energy=100, team="B", brain="q"), "Scout": Agent("Scout", 10, 3, 1, hp=10, energy=100, team="A", brain="heuristic"), } return g, agents def env_vault(seed: int) -> Tuple[List[List[int]], Dict[str, Agent]]: g = grid_with_border() for x in range(3, 18): g[5][x] = WALL for x in range(3, 18): g[9][x] = WALL g[5][10] = DOOR g[9][12] = DOOR g[2][2] = KEY g[12][18] = EXIT g[12][2] = ARTIFACT g[2][18] = TELE g[13][2] = TELE g[7][10] = SWITCH g[3][15] = HAZARD g[11][6] = MEDKIT g[2][12] = FOOD agents = { "Alpha": Agent("Alpha", 2, 12, 0, hp=10, energy=100, team="A", brain="q"), "Bravo": Agent("Bravo", 3, 12, 0, hp=10, energy=100, team="A", brain="q"), "Guardian": Agent("Guardian", 18, 2, 2, hp=10, energy=100, team="B", brain="q"), } return g, agents def env_civ(seed: int) -> Tuple[List[List[int]], Dict[str, Agent]]: g = grid_with_border() for y in range(3, 12): g[y][9] = WALL g[7][9] = DOOR g[2][3] = WOOD g[3][3] = WOOD g[4][3] = WOOD g[12][16] = ORE g[11][16] = ORE g[10][16] = ORE g[6][4] = FOOD g[8][15] = FOOD g[13][10] = BASE g[4][15] = HAZARD g[10][4] = HAZARD g[2][18] = TELE g[13][2] = TELE g[2][2] = KEY g[12][6] = SWITCH agents = { "BuilderA": Agent("BuilderA", 3, 12, 0, hp=10, energy=100, team="A", brain="q"), "BuilderB": Agent("BuilderB", 4, 12, 0, hp=10, energy=100, team="A", brain="q"), "Raider": Agent("Raider", 18, 2, 2, hp=10, energy=100, team="B", brain="q"), } return g, agents ENV_BUILDERS = {"chase": env_chase, "vault": env_vault, "civ": env_civ} # ----------------------------- # Observation / Q-learning # ----------------------------- def local_tile_ahead(state: WorldState, a: Agent) -> int: dx, dy = DIRS[a.ori] nx, ny = a.x + dx, a.y + dy if not in_bounds(nx, ny): return WALL return state.grid[ny][nx] def nearest_enemy_vec(state: WorldState, a: Agent) -> Tuple[int, int, int]: best = None for _, other in state.agents.items(): if other.hp <= 0: continue if other.team == a.team: continue d = manhattan_xy(a.x, a.y, other.x, other.y) if best is None or d < best[0]: best = (d, other.x - a.x, other.y - a.y) if best is None: return (99, 0, 0) d, dx, dy = best return (d, int(np.clip(dx, -6, 6)), int(np.clip(dy, -6, 6))) def obs_key(state: WorldState, who: str) -> str: a = state.agents[who] d, dx, dy = nearest_enemy_vec(state, a) ahead = local_tile_ahead(state, a) keys = a.inventory.get("key", 0) art = a.inventory.get("artifact", 0) wood = a.inventory.get("wood", 0) ore = a.inventory.get("ore", 0) inv_bucket = f"k{min(keys,2)}a{min(art,1)}w{min(wood,3)}o{min(ore,3)}" door = 1 if state.door_opened_global else 0 return f"{state.env_key}|{who}|{a.x},{a.y},{a.ori}|e{d}:{dx},{dy}|t{ahead}|hp{a.hp}|{inv_bucket}|D{door}|bp{state.base_progress}" def q_get(q: Dict[str, List[float]], key: str) -> List[float]: if key not in q: q[key] = [0.0 for _ in ACTIONS] return q[key] def epsilon_greedy(qvals: List[float], eps: float, r: np.random.Generator) -> int: if r.random() < eps: return int(r.integers(0, len(qvals))) return int(np.argmax(qvals)) def q_update(q: Dict[str, List[float]], key: str, a_idx: int, reward: float, next_key: str, alpha: float, gamma: float) -> Tuple[float, float, float]: qv = q_get(q, key) nq = q_get(q, next_key) old = qv[a_idx] target = reward + gamma * float(np.max(nq)) new = old + alpha * (target - old) qv[a_idx] = new return old, target, new # ----------------------------- # Baseline heuristics # ----------------------------- def heuristic_action(state: WorldState, who: str) -> str: a = state.agents[who] r = rng_for(state.seed, state.step, stream=900 + hash(who) % 1000) t_here = state.grid[a.y][a.x] if t_here in (FOOD, KEY, ARTIFACT, WOOD, ORE, MEDKIT, SWITCH, BASE, EXIT): return "I" best = None best_d = 999 for _, other in state.agents.items(): if other.hp <= 0 or other.team == a.team: continue d = manhattan_xy(a.x, a.y, other.x, other.y) if d < best_d: best_d = d best = other if best is not None and best_d <= 6 and visible(state, a, best): dx = best.x - a.x dy = best.y - a.y ang = (math.degrees(math.atan2(dy, dx)) % 360) facing = ORI_DEG[a.ori] diff = (ang - facing + 540) % 360 - 180 if diff < -10: return "L" if diff > 10: return "R" return "F" return r.choice(["F", "F", "L", "R", "I"]) def random_action(state: WorldState, who: str) -> str: r = rng_for(state.seed, state.step, stream=700 + hash(who) % 1000) return r.choice(ACTIONS) # ----------------------------- # Movement + interaction # ----------------------------- def turn_left(a: Agent) -> None: a.ori = (a.ori - 1) % 4 def turn_right(a: Agent) -> None: a.ori = (a.ori + 1) % 4 def move_forward(state: WorldState, a: Agent) -> str: dx, dy = DIRS[a.ori] nx, ny = a.x + dx, a.y + dy if not in_bounds(nx, ny): return "blocked: bounds" tile = state.grid[ny][nx] if is_blocking(tile, door_open=state.door_opened_global): return "blocked: wall/door" a.x, a.y = nx, ny if state.grid[ny][nx] == TELE: teles = [(x, y) for y in range(GRID_H) for x in range(GRID_W) if state.grid[y][x] == TELE] if len(teles) >= 2: teles_sorted = sorted(teles) idx = teles_sorted.index((nx, ny)) dest = teles_sorted[(idx + 1) % len(teles_sorted)] a.x, a.y = dest state.event_log.append(f"t={state.step}: {a.name} teleported.") return "moved: teleported" return "moved" def try_interact(state: WorldState, a: Agent) -> str: t = state.grid[a.y][a.x] if t == SWITCH: state.door_opened_global = True state.grid[a.y][a.x] = EMPTY a.inventory["switch"] = a.inventory.get("switch", 0) + 1 return "switch: opened all doors" if t == KEY: a.inventory["key"] = a.inventory.get("key", 0) + 1 state.grid[a.y][a.x] = EMPTY return "picked: key" if t == ARTIFACT: a.inventory["artifact"] = a.inventory.get("artifact", 0) + 1 state.grid[a.y][a.x] = EMPTY return "picked: artifact" if t == FOOD: a.energy = min(200, a.energy + 35) state.grid[a.y][a.x] = EMPTY return "ate: food" if t == WOOD: a.inventory["wood"] = a.inventory.get("wood", 0) + 1 state.grid[a.y][a.x] = EMPTY return "picked: wood" if t == ORE: a.inventory["ore"] = a.inventory.get("ore", 0) + 1 state.grid[a.y][a.x] = EMPTY return "picked: ore" if t == MEDKIT: a.hp = min(10, a.hp + 3) state.grid[a.y][a.x] = EMPTY return "used: medkit" if t == BASE: w = a.inventory.get("wood", 0) o = a.inventory.get("ore", 0) dep = min(w, 2) + min(o, 2) if dep > 0: a.inventory["wood"] = max(0, w - min(w, 2)) a.inventory["ore"] = max(0, o - min(o, 2)) state.base_progress += dep return f"deposited: +{dep} base_progress" return "base: nothing to deposit" if t == EXIT: return "at_exit" return "interact: none" def apply_action(state: WorldState, who: str, action: str) -> str: a = state.agents[who] if a.hp <= 0: return "dead" if action == "L": turn_left(a) return "turned left" if action == "R": turn_right(a) return "turned right" if action == "F": return move_forward(state, a) if action == "I": return try_interact(state, a) return "noop" # ----------------------------- # Hazards / collisions / done # ----------------------------- def resolve_hazards(state: WorldState, a: Agent) -> Tuple[bool, str]: if a.hp <= 0: return (False, "") if state.grid[a.y][a.x] == HAZARD: a.hp -= 1 return (True, "hazard:-hp") return (False, "") def resolve_tags(state: WorldState) -> List[str]: msgs = [] occupied: Dict[Tuple[int, int], List[str]] = {} for nm, a in state.agents.items(): if a.hp <= 0: continue occupied.setdefault((a.x, a.y), []).append(nm) for (x, y), names in occupied.items(): if len(names) < 2: continue teams = set(state.agents[n].team for n in names) if len(teams) >= 2: for n in names: state.agents[n].hp -= 1 msgs.append(f"t={state.step}: collision/tag at ({x},{y}) {names} (-hp all)") return msgs def check_done(state: WorldState) -> None: if state.env_key == "chase": pred = state.agents["Predator"] prey = state.agents["Prey"] if pred.hp <= 0 and prey.hp <= 0: state.done = True state.outcome = "draw" return if pred.hp > 0 and prey.hp > 0 and pred.x == prey.x and pred.y == prey.y: state.done = True state.outcome = "A_win" state.event_log.append(f"t={state.step}: CAUGHT (Predator wins).") return if state.step >= 300 and prey.hp > 0: state.done = True state.outcome = "B_win" state.event_log.append(f"t={state.step}: ESCAPED (Prey survives).") return if state.env_key == "vault": for nm in ["Alpha", "Bravo"]: a = state.agents[nm] if a.hp > 0 and a.inventory.get("artifact", 0) > 0 and state.grid[a.y][a.x] == EXIT: state.done = True state.outcome = "A_win" state.event_log.append(f"t={state.step}: VAULT CLEARED (Team A wins).") return alive_A = any(state.agents[n].hp > 0 for n in ["Alpha", "Bravo"]) if not alive_A: state.done = True state.outcome = "B_win" state.event_log.append(f"t={state.step}: TEAM A ELIMINATED (Guardian wins).") return if state.env_key == "civ": if state.base_progress >= state.base_target: state.done = True state.outcome = "A_win" state.event_log.append(f"t={state.step}: BASE COMPLETE (Builders win).") return alive_A = any(state.agents[n].hp > 0 for n in ["BuilderA", "BuilderB"]) if not alive_A: state.done = True state.outcome = "B_win" state.event_log.append(f"t={state.step}: BUILDERS ELIMINATED (Raider wins).") return if state.step >= 350: state.done = True state.outcome = "draw" state.event_log.append(f"t={state.step}: TIMEOUT (draw).") return # ----------------------------- # Rewards # ----------------------------- def reward_for(prev: WorldState, now: WorldState, who: str, outcome_msg: str, took_damage: bool) -> float: cfg = now.cfg r = cfg.step_penalty if outcome_msg.startswith("moved"): r += cfg.explore_reward if took_damage: r += cfg.damage_penalty if outcome_msg.startswith("used: medkit"): r += cfg.heal_reward if now.env_key == "chase": pred = now.agents["Predator"] prey = now.agents["Prey"] if who == "Predator": d0 = manhattan_xy(prev.agents["Predator"].x, prev.agents["Predator"].y, prev.agents["Prey"].x, prev.agents["Prey"].y) d1 = manhattan_xy(pred.x, pred.y, prey.x, prey.y) r += cfg.chase_close_coeff * float(d0 - d1) if now.done and now.outcome == "A_win": r += cfg.chase_catch_reward if who == "Prey": if outcome_msg.startswith("ate: food"): r += cfg.food_reward if now.done and now.outcome == "B_win": r += cfg.chase_escaped_reward if now.done and now.outcome == "A_win": r += cfg.chase_caught_penalty if now.env_key == "vault": if outcome_msg.startswith("picked: artifact"): r += cfg.artifact_pick_reward if outcome_msg.startswith("picked: key"): r += cfg.key_reward if outcome_msg.startswith("switch:"): r += cfg.switch_reward if now.done: if now.outcome == "A_win" and now.agents[who].team == "A": r += cfg.exit_win_reward if now.outcome == "B_win" and now.agents[who].team == "B": r += cfg.guardian_tag_reward if now.outcome == "B_win" and now.agents[who].team == "A": r += cfg.tagged_penalty if now.env_key == "civ": if outcome_msg.startswith("picked: wood") or outcome_msg.startswith("picked: ore"): r += cfg.resource_pick_reward if outcome_msg.startswith("deposited:"): r += cfg.deposit_reward if now.done: if now.outcome == "A_win" and now.agents[who].team == "A": r += cfg.base_progress_win_reward if now.outcome == "B_win" and now.agents[who].team == "B": r += cfg.raider_elim_reward if now.outcome == "B_win" and now.agents[who].team == "A": r += cfg.builder_elim_penalty return float(r) # ----------------------------- # Policy selection # ----------------------------- def choose_action(state: WorldState, who: str, stream: int) -> Tuple[str, str, Optional[Tuple[str, int]]]: a = state.agents[who] cfg = state.cfg r = rng_for(state.seed, state.step, stream=stream) if a.brain == "random": act = random_action(state, who) return act, "random", None if a.brain == "heuristic": act = heuristic_action(state, who) return act, "heuristic", None if cfg.use_q: key = obs_key(state, who) qtab = state.q_tables.setdefault(who, {}) qv = q_get(qtab, key) a_idx = epsilon_greedy(qv, state.gmetrics.epsilon, r) return ACTIONS[a_idx], f"Q eps={state.gmetrics.epsilon:.3f} q={np.round(qv,3).tolist()}", (key, a_idx) act = heuristic_action(state, who) return act, "heuristic(fallback)", None # ----------------------------- # Init / reset # ----------------------------- def init_state(seed: int, env_key: str) -> WorldState: g, agents = ENV_BUILDERS[env_key](seed) st = WorldState( seed=seed, step=0, env_key=env_key, grid=g, agents=agents, controlled=list(agents.keys())[0], pov=list(agents.keys())[0], overlay=False, done=False, outcome="ongoing", door_opened_global=False, base_progress=0, base_target=10, ) st.event_log = [f"Initialized env={env_key} seed={seed}."] return st def reset_episode_keep_learning(state: WorldState, seed: Optional[int] = None) -> WorldState: if seed is None: seed = state.seed fresh = init_state(int(seed), state.env_key) fresh.cfg = state.cfg fresh.q_tables = state.q_tables fresh.gmetrics = state.gmetrics fresh.gmetrics.epsilon = state.gmetrics.epsilon return fresh def wipe_all(seed: int, env_key: str) -> WorldState: st = init_state(seed, env_key) st.cfg = TrainConfig() st.gmetrics = GlobalMetrics(epsilon=st.cfg.epsilon) st.q_tables = {} return st # ----------------------------- # History / branching # ----------------------------- def snapshot_of(state: WorldState, branch: str) -> Snapshot: return Snapshot( branch=branch, step=state.step, env_key=state.env_key, grid=[row[:] for row in state.grid], agents={k: asdict(v) for k, v in state.agents.items()}, done=state.done, outcome=state.outcome, door_opened_global=state.door_opened_global, base_progress=state.base_progress, base_target=state.base_target, event_tail=state.event_log[-25:], trace_tail=state.trace_log[-40:], emetrics=asdict(state.emetrics), ) def restore_into(state: WorldState, snap: Snapshot) -> WorldState: state.step = snap.step state.env_key = snap.env_key state.grid = [row[:] for row in snap.grid] state.agents = {k: Agent(**d) for k, d in snap.agents.items()} state.done = snap.done state.outcome = snap.outcome state.door_opened_global = snap.door_opened_global state.base_progress = snap.base_progress state.base_target = snap.base_target state.event_log.append(f"Jumped to snapshot t={snap.step} (branch={snap.branch}).") return state # ----------------------------- # Metrics / dashboard # ----------------------------- def metrics_dashboard_image(state: WorldState) -> Image.Image: gm = state.gmetrics fig = plt.figure(figsize=(7.0, 2.2), dpi=120) ax = fig.add_subplot(111) x1 = max(1, gm.episodes) ax.plot([0, x1], [gm.rolling_winrate_A, gm.rolling_winrate_A]) ax.set_title("Global Metrics Snapshot") ax.set_xlabel("Episodes") ax.set_ylabel("Rolling winrate Team A") ax.set_ylim(-0.05, 1.05) ax.grid(True) txt = ( f"env={state.env_key} | eps={gm.epsilon:.3f} | episodes={gm.episodes}\n" f"A_wins={gm.wins_teamA} B_wins={gm.wins_teamB} draws={gm.draws} | avg_steps~{gm.avg_steps:.1f}\n" f"last_outcome={gm.last_outcome} last_steps={gm.last_steps}" ) ax.text(0.01, 0.05, txt, transform=ax.transAxes, fontsize=8, va="bottom") fig.tight_layout() canvas = FigureCanvas(fig) canvas.draw() buf = np.asarray(canvas.buffer_rgba()) img = Image.fromarray(buf, mode="RGBA").convert("RGB") plt.close(fig) return img def action_entropy(counts: Dict[str, int]) -> float: total = sum(counts.values()) if total <= 0: return 0.0 p = np.array([c / total for c in counts.values()], dtype=np.float64) p = np.clip(p, 1e-12, 1.0) return float(-np.sum(p * np.log2(p))) def agent_scoreboard(state: WorldState) -> str: rows = [] header = ["agent", "team", "hp", "return", "steps", "entropy", "tiles_disc", "q_states", "inventory"] rows.append(header) steps = state.emetrics.steps for nm, a in state.agents.items(): ret = state.emetrics.returns.get(nm, 0.0) counts = state.emetrics.action_counts.get(nm, {}) ent = action_entropy(counts) td = state.emetrics.tiles_discovered.get(nm, 0) qs = len(state.q_tables.get(nm, {})) inv = json.dumps(a.inventory, sort_keys=True) rows.append([nm, a.team, a.hp, f"{ret:.2f}", steps, f"{ent:.2f}", td, qs, inv]) col_w = [max(len(str(r[i])) for r in rows) for i in range(len(header))] lines = [] for ridx, r in enumerate(rows): line = " | ".join(str(r[i]).ljust(col_w[i]) for i in range(len(header))) lines.append(line) if ridx == 0: lines.append("-+-".join("-" * w for w in col_w)) return "\n".join(lines) # ----------------------------- # Tick / training # ----------------------------- def clone_shallow(state: WorldState) -> WorldState: return WorldState( seed=state.seed, step=state.step, env_key=state.env_key, grid=[row[:] for row in state.grid], agents={k: Agent(**asdict(v)) for k, v in state.agents.items()}, controlled=state.controlled, pov=state.pov, overlay=state.overlay, done=state.done, outcome=state.outcome, door_opened_global=state.door_opened_global, base_progress=state.base_progress, base_target=state.base_target, event_log=list(state.event_log), trace_log=list(state.trace_log), cfg=state.cfg, q_tables=state.q_tables, gmetrics=state.gmetrics, emetrics=state.emetrics, ) def update_action_counts(state: WorldState, who: str, act: str): state.emetrics.action_counts.setdefault(who, {}) state.emetrics.action_counts[who][act] = state.emetrics.action_counts[who].get(act, 0) + 1 def tick(state: WorldState, beliefs: Dict[str, np.ndarray], manual_action: Optional[str] = None) -> None: if state.done: return prev = clone_shallow(state) chosen: Dict[str, str] = {} reasons: Dict[str, str] = {} qinfo: Dict[str, Optional[Tuple[str, int]]] = {} if manual_action is not None: chosen[state.controlled] = manual_action reasons[state.controlled] = "manual" qinfo[state.controlled] = None order = list(state.agents.keys()) for who in order: if who in chosen: continue act, reason, qi = choose_action(state, who, stream=200 + (hash(who) % 1000)) chosen[who] = act reasons[who] = reason qinfo[who] = qi outcomes: Dict[str, str] = {} took_damage: Dict[str, bool] = {nm: False for nm in order} for who in order: outcomes[who] = apply_action(state, who, chosen[who]) dmg, msg = resolve_hazards(state, state.agents[who]) took_damage[who] = dmg if msg: state.event_log.append(f"t={state.step}: {who} {msg}") update_action_counts(state, who, chosen[who]) for m in resolve_tags(state): state.event_log.append(m) for nm, a in state.agents.items(): if a.hp <= 0: continue disc = update_belief_for_agent(state, beliefs[nm], a) state.emetrics.tiles_discovered[nm] = state.emetrics.tiles_discovered.get(nm, 0) + disc check_done(state) q_lines = [] for who in order: state.emetrics.returns.setdefault(who, 0.0) r = reward_for(prev, state, who, outcomes[who], took_damage[who]) state.emetrics.returns[who] += r if qinfo.get(who) is not None: key, a_idx = qinfo[who] next_key = obs_key(state, who) qtab = state.q_tables.setdefault(who, {}) old, tgt, new = q_update(qtab, key, a_idx, r, next_key, state.cfg.alpha, state.cfg.gamma) q_lines.append(f"{who}: old={old:.3f} tgt={tgt:.3f} new={new:.3f} (a={ACTIONS[a_idx]})") trace = f"t={state.step} env={state.env_key} done={state.done} outcome={state.outcome}" for who in order: a = state.agents[who] trace += f" | {who}:{chosen[who]} ({outcomes[who]}) hp={a.hp} [{reasons[who]}]" if q_lines: trace += " | Q: " + " ; ".join(q_lines) state.trace_log.append(trace) if len(state.trace_log) > TRACE_MAX: state.trace_log = state.trace_log[-TRACE_MAX:] state.step += 1 state.emetrics.steps = state.step def run_episode(state: WorldState, beliefs: Dict[str, np.ndarray], max_steps: int) -> Tuple[str, int]: while state.step < max_steps and not state.done: tick(state, beliefs, manual_action=None) return state.outcome, state.step def update_global_metrics_after_episode(state: WorldState, outcome: str, steps: int): gm = state.gmetrics gm.episodes += 1 gm.last_outcome = outcome gm.last_steps = steps if outcome == "A_win": gm.wins_teamA += 1 gm.rolling_winrate_A = 0.90 * gm.rolling_winrate_A + 0.10 * 1.0 elif outcome == "B_win": gm.wins_teamB += 1 gm.rolling_winrate_A = 0.90 * gm.rolling_winrate_A + 0.10 * 0.0 else: gm.draws += 1 gm.rolling_winrate_A = 0.90 * gm.rolling_winrate_A + 0.10 * 0.5 gm.avg_steps = (0.90 * gm.avg_steps + 0.10 * steps) if gm.avg_steps > 0 else float(steps) gm.epsilon = max(state.cfg.epsilon_min, gm.epsilon * state.cfg.epsilon_decay) def train(state: WorldState, episodes: int, max_steps: int) -> WorldState: for ep in range(episodes): ep_seed = (state.seed * 1_000_003 + (state.gmetrics.episodes + ep) * 97_531) & 0xFFFFFFFF state = reset_episode_keep_learning(state, seed=int(ep_seed)) beliefs = init_beliefs(list(state.agents.keys())) outcome, steps = run_episode(state, beliefs, max_steps=max_steps) update_global_metrics_after_episode(state, outcome, steps) state.event_log.append( f"Training: +{episodes} eps | eps={state.gmetrics.epsilon:.3f} | " f"A={state.gmetrics.wins_teamA} B={state.gmetrics.wins_teamB} D={state.gmetrics.draws}" ) state = reset_episode_keep_learning(state, seed=state.seed) return state # ----------------------------- # Export / Import # ----------------------------- def export_run(state: WorldState, branches: Dict[str, List[Snapshot]], active_branch: str, rewind_idx: int) -> str: payload = { "seed": state.seed, "env_key": state.env_key, "controlled": state.controlled, "pov": state.pov, "overlay": state.overlay, "cfg": asdict(state.cfg), "gmetrics": asdict(state.gmetrics), "q_tables": state.q_tables, "branches": {b: [asdict(s) for s in snaps] for b, snaps in branches.items()}, "active_branch": active_branch, "rewind_idx": int(rewind_idx), "grid": state.grid, "door_opened_global": state.door_opened_global, "base_progress": state.base_progress, "base_target": state.base_target, } txt = json.dumps(payload, indent=2) proof = hash_sha256(txt) return txt + "\n\n" + json.dumps({"proof_sha256": proof}, indent=2) def import_run(txt: str) -> Tuple[WorldState, Dict[str, List[Snapshot]], str, int, Dict[str, np.ndarray]]: parts = txt.strip().split("\n\n") data = json.loads(parts[0]) st = init_state(int(data.get("seed", 1337)), data.get("env_key", "chase")) st.controlled = data.get("controlled", st.controlled) st.pov = data.get("pov", st.pov) st.overlay = bool(data.get("overlay", False)) st.grid = data.get("grid", st.grid) st.door_opened_global = bool(data.get("door_opened_global", False)) st.base_progress = int(data.get("base_progress", 0)) st.base_target = int(data.get("base_target", 10)) st.cfg = TrainConfig(**data.get("cfg", asdict(st.cfg))) st.gmetrics = GlobalMetrics(**data.get("gmetrics", asdict(st.gmetrics))) st.q_tables = data.get("q_tables", {}) branches_in = data.get("branches", {}) branches: Dict[str, List[Snapshot]] = {} for bname, snaps in branches_in.items(): branches[bname] = [Snapshot(**s) for s in snaps] active = data.get("active_branch", "main") r_idx = int(data.get("rewind_idx", 0)) if active in branches and branches[active]: st = restore_into(st, branches[active][-1]) st.event_log.append("Imported run (restored last snapshot).") else: st.event_log.append("Imported run (no snapshots).") beliefs = init_beliefs(list(st.agents.keys())) return st, branches, active, r_idx, beliefs # ----------------------------- # UI helpers # ----------------------------- def build_views(state: WorldState, beliefs: Dict[str, np.ndarray]) -> Tuple[np.ndarray, Image.Image, Image.Image, Image.Image, Image.Image, str, str, str, str]: for nm, a in state.agents.items(): if a.hp > 0: update_belief_for_agent(state, beliefs[nm], a) pov = raycast_view(state, state.agents[state.pov]) truth_np = np.array(state.grid, dtype=np.int16) truth_img = render_topdown(truth_np, state.agents, f"Truth Map — env={state.env_key} t={state.step} seed={state.seed}", True) ctrl = state.controlled others = [k for k in state.agents.keys() if k != ctrl] other = others[0] if others else ctrl b_ctrl = render_topdown(beliefs[ctrl], state.agents, f"{ctrl} Belief", True) b_other = render_topdown(beliefs[other], state.agents, f"{other} Belief", True) dash = metrics_dashboard_image(state) status = ( f"env={state.env_key} | seed={state.seed} | Controlled={state.controlled} | POV={state.pov} | done={state.done} outcome={state.outcome}\n" f"Episode steps={state.step} | base_progress={state.base_progress}/{state.base_target} | doors_open={state.door_opened_global}\n" f"Global: episodes={state.gmetrics.episodes} | A={state.gmetrics.wins_teamA} B={state.gmetrics.wins_teamB} D={state.gmetrics.draws} " f"| winrateA~{state.gmetrics.rolling_winrate_A:.2f} | eps={state.gmetrics.epsilon:.3f}" ) events = "\n".join(state.event_log[-18:]) trace = "\n".join(state.trace_log[-18:]) scoreboard = agent_scoreboard(state) return pov, truth_img, b_ctrl, b_other, dash, status, events, trace, scoreboard def grid_click_to_tile(evt: gr.SelectData, selected_tile: int, state: WorldState) -> WorldState: x_px, y_px = evt.index y_px -= 28 if y_px < 0: return state gx = int(x_px // TILE) gy = int(y_px // TILE) if not in_bounds(gx, gy): return state if gx == 0 or gy == 0 or gx == GRID_W - 1 or gy == GRID_H - 1: return state state.grid[gy][gx] = selected_tile state.event_log.append(f"t={state.step}: Tile ({gx},{gy}) -> {TILE_NAMES.get(selected_tile)}") return state # ----------------------------- # Gradio app # ----------------------------- TITLE = "ZEN AgentLab — Agent POV + Autoplay Multi-Agent Sims" with gr.Blocks(title=TITLE) as demo: gr.Markdown( f"## {TITLE}\n" "**Press Start Autoplay** to watch the sim unfold live. Interject anytime with manual actions or edits.\n" "Use **Cinematic Run** for an instant full-episode spectacle. No background timers beyond the UI autoplay." ) st0 = init_state(1337, "chase") st = gr.State(st0) branches = gr.State({"main": [snapshot_of(st0, "main")]}) active_branch = gr.State("main") rewind_idx = gr.State(0) beliefs = gr.State(init_beliefs(list(st0.agents.keys()))) autoplay_on = gr.State(False) with gr.Row(): pov_img = gr.Image(label="POV (Pseudo-3D)", type="numpy", width=VIEW_W, height=VIEW_H) with gr.Column(): status = gr.Textbox(label="Status", lines=3) scoreboard = gr.Textbox(label="Agent Scoreboard", lines=8) with gr.Row(): truth = gr.Image(label="Truth Map (click to edit tiles)", type="pil") belief_a = gr.Image(label="Belief (Controlled)", type="pil") belief_b = gr.Image(label="Belief (Other)", type="pil") with gr.Row(): dash = gr.Image(label="Metrics Dashboard", type="pil") with gr.Row(): events = gr.Textbox(label="Event Log", lines=10) trace = gr.Textbox(label="Step Trace", lines=10) with gr.Row(): with gr.Column(scale=2): gr.Markdown("### Quick Start (Examples)") examples = gr.Examples( examples=[ ["chase", 1337], ["vault", 2024], ["civ", 777], ], inputs=[], label="", ) gr.Markdown("Pick an environment + seed below, then click **Apply**.") with gr.Row(): env_pick = gr.Radio( choices=[("Chase (Predator vs Prey)", "chase"), ("CoopVault (team vs guardian)", "vault"), ("MiniCiv (build + raid)", "civ")], value="chase", label="Environment" ) seed_box = gr.Number(value=1337, precision=0, label="Seed") with gr.Row(): btn_apply_env_seed = gr.Button("Apply (Env + Seed)") btn_reset_ep = gr.Button("Reset Episode (keep learning)") gr.Markdown("### Autoplay + Spectacle") with gr.Row(): autoplay_speed = gr.Slider(0.05, 1.0, value=0.20, step=0.05, label="Autoplay step interval (seconds)") with gr.Row(): btn_autoplay_start = gr.Button("▶ Start Autoplay") btn_autoplay_stop = gr.Button("⏸ Stop Autoplay") with gr.Row(): cinematic_steps = gr.Number(value=350, precision=0, label="Cinematic max steps") btn_cinematic = gr.Button("🎬 Cinematic Run (Full Episode)") gr.Markdown("### Manual Controls (Interject Anytime)") with gr.Row(): btn_L = gr.Button("L") btn_F = gr.Button("F") btn_R = gr.Button("R") btn_I = gr.Button("I (Interact)") with gr.Row(): btn_tick = gr.Button("Tick") run_steps = gr.Number(value=25, label="Run N steps", precision=0) btn_run = gr.Button("Run") with gr.Row(): btn_toggle_control = gr.Button("Toggle Controlled") btn_toggle_pov = gr.Button("Toggle POV") overlay = gr.Checkbox(False, label="Overlay reticle") tile_pick = gr.Radio( choices=[(TILE_NAMES[k], k) for k in [EMPTY, WALL, FOOD, NOISE, DOOR, TELE, KEY, EXIT, ARTIFACT, HAZARD, WOOD, ORE, MEDKIT, SWITCH, BASE]], value=WALL, label="Paint tile type" ) with gr.Column(scale=3): gr.Markdown("### Training Controls (Tabular Q-learning)") use_q = gr.Checkbox(True, label="Use Q-learning (agents with brain='q')") alpha = gr.Slider(0.01, 0.5, value=0.15, step=0.01, label="alpha") gamma = gr.Slider(0.5, 0.99, value=0.95, step=0.01, label="gamma") eps = gr.Slider(0.0, 0.5, value=0.10, step=0.01, label="epsilon") eps_decay = gr.Slider(0.90, 0.999, value=0.995, step=0.001, label="epsilon decay") eps_min = gr.Slider(0.0, 0.2, value=0.02, step=0.01, label="epsilon min") episodes = gr.Number(value=50, label="Train episodes", precision=0) max_steps = gr.Number(value=260, label="Max steps/episode", precision=0) btn_train = gr.Button("Train") btn_reset_all = gr.Button("Reset ALL (wipe Q + metrics)") with gr.Row(): with gr.Column(scale=2): gr.Markdown("### Timeline + Branching") rewind = gr.Slider(0, 0, value=0, step=1, label="Rewind index (active branch)") btn_jump = gr.Button("Jump to index") new_branch_name = gr.Textbox(value="fork1", label="New branch name") btn_fork = gr.Button("Fork from current rewind") with gr.Column(scale=2): branch_pick = gr.Dropdown(choices=["main"], value="main", label="Active branch") btn_set_branch = gr.Button("Set Active Branch") with gr.Column(scale=3): export_box = gr.Textbox(label="Export JSON (+ proof hash)", lines=8) btn_export = gr.Button("Export") import_box = gr.Textbox(label="Import JSON", lines=8) btn_import = gr.Button("Import") # Autoplay timer (inactive by default) timer = gr.Timer(value=0.20, active=False) # ---------- glue ---------- def refresh(state: WorldState, branches_d: Dict[str, List[Snapshot]], active: str, bel: Dict[str, np.ndarray], r: int): snaps = branches_d.get(active, []) r_max = max(0, len(snaps) - 1) r = max(0, min(int(r), r_max)) pov, tr, ba, bb, dimg, stxt, etxt, ttxt, sb = build_views(state, bel) branch_choices = sorted(list(branches_d.keys())) return ( pov, tr, ba, bb, dimg, stxt, sb, etxt, ttxt, gr.update(maximum=r_max, value=r), r, gr.update(choices=branch_choices, value=active), gr.update(choices=branch_choices, value=active), ) def push_hist(state: WorldState, branches_d: Dict[str, List[Snapshot]], active: str) -> Dict[str, List[Snapshot]]: branches_d.setdefault(active, []) branches_d[active].append(snapshot_of(state, active)) if len(branches_d[active]) > MAX_HISTORY: branches_d[active].pop(0) return branches_d def set_cfg(state: WorldState, use_q_v: bool, a: float, g: float, e: float, ed: float, emin: float) -> WorldState: state.cfg.use_q = bool(use_q_v) state.cfg.alpha = float(a) state.cfg.gamma = float(g) state.gmetrics.epsilon = float(e) state.cfg.epsilon_decay = float(ed) state.cfg.epsilon_min = float(emin) return state def do_manual(state, branches_d, active, bel, r, act): tick(state, bel, manual_action=act) branches_d = push_hist(state, branches_d, active) r = len(branches_d[active]) - 1 out = refresh(state, branches_d, active, bel, r) return out + (state, branches_d, active, bel, r) def do_tick(state, branches_d, active, bel, r): tick(state, bel, manual_action=None) branches_d = push_hist(state, branches_d, active) r = len(branches_d[active]) - 1 out = refresh(state, branches_d, active, bel, r) return out + (state, branches_d, active, bel, r) def do_run(state, branches_d, active, bel, r, n): n = max(1, int(n)) for _ in range(n): if state.done: break tick(state, bel, manual_action=None) branches_d = push_hist(state, branches_d, active) r = len(branches_d[active]) - 1 out = refresh(state, branches_d, active, bel, r) return out + (state, branches_d, active, bel, r) def toggle_control(state, branches_d, active, bel, r): order = list(state.agents.keys()) i = order.index(state.controlled) state.controlled = order[(i + 1) % len(order)] state.event_log.append(f"Controlled -> {state.controlled}") branches_d = push_hist(state, branches_d, active) r = len(branches_d[active]) - 1 out = refresh(state, branches_d, active, bel, r) return out + (state, branches_d, active, bel, r) def toggle_pov(state, branches_d, active, bel, r): order = list(state.agents.keys()) i = order.index(state.pov) state.pov = order[(i + 1) % len(order)] state.event_log.append(f"POV -> {state.pov}") branches_d = push_hist(state, branches_d, active) r = len(branches_d[active]) - 1 out = refresh(state, branches_d, active, bel, r) return out + (state, branches_d, active, bel, r) def set_overlay(state, branches_d, active, bel, r, ov): state.overlay = bool(ov) out = refresh(state, branches_d, active, bel, r) return out + (state, branches_d, active, bel, r) def click_truth(tile, state, branches_d, active, bel, r, evt: gr.SelectData): state = grid_click_to_tile(evt, int(tile), state) branches_d = push_hist(state, branches_d, active) r = len(branches_d[active]) - 1 out = refresh(state, branches_d, active, bel, r) return out + (state, branches_d, active, bel, r) def jump(state, branches_d, active, bel, r, idx): snaps = branches_d.get(active, []) if not snaps: out = refresh(state, branches_d, active, bel, r) return out + (state, branches_d, active, bel, r) idx = max(0, min(int(idx), len(snaps) - 1)) state = restore_into(state, snaps[idx]) r = idx out = refresh(state, branches_d, active, bel, r) return out + (state, branches_d, active, bel, r) def fork_branch(state, branches_d, active, bel, r, new_name): new_name = (new_name or "").strip() or "fork" new_name = new_name.replace(" ", "_") snaps = branches_d.get(active, []) if not snaps: branches_d[new_name] = [snapshot_of(state, new_name)] else: idx = max(0, min(int(r), len(snaps) - 1)) branches_d[new_name] = [Snapshot(**asdict(s)) for s in snaps[:idx + 1]] state = restore_into(state, branches_d[new_name][-1]) active = new_name state.event_log.append(f"Forked branch -> {new_name}") branches_d = push_hist(state, branches_d, active) r = len(branches_d[active]) - 1 out = refresh(state, branches_d, active, bel, r) return out + (state, branches_d, active, bel, r) def set_active_branch(state, branches_d, active, bel, r, br): br = br or "main" if br not in branches_d: branches_d[br] = [snapshot_of(state, br)] active = br if branches_d[active]: state = restore_into(state, branches_d[active][-1]) bel = init_beliefs(list(state.agents.keys())) r = len(branches_d[active]) - 1 out = refresh(state, branches_d, active, bel, r) return out + (state, branches_d, active, bel, r) def apply_env_seed(state, branches_d, active, bel, r, env_key, seed_val): env_key = env_key or "chase" seed_val = int(seed_val) if seed_val is not None else state.seed # Preserve learning across env swaps old_cfg = state.cfg old_q = state.q_tables old_gm = state.gmetrics state = init_state(seed_val, env_key) state.cfg = old_cfg state.q_tables = old_q state.gmetrics = old_gm bel = init_beliefs(list(state.agents.keys())) active = "main" branches_d = {"main": [snapshot_of(state, "main")]} r = 0 out = refresh(state, branches_d, active, bel, r) return out + (state, branches_d, active, bel, r) def reset_ep(state, branches_d, active, bel, r): state = reset_episode_keep_learning(state, seed=state.seed) bel = init_beliefs(list(state.agents.keys())) branches_d = {active: [snapshot_of(state, active)]} r = 0 out = refresh(state, branches_d, active, bel, r) return out + (state, branches_d, active, bel, r) def reset_all(state, branches_d, active, bel, r, env_key, seed_val): env_key = env_key or state.env_key seed_val = int(seed_val) if seed_val is not None else state.seed state = wipe_all(seed=seed_val, env_key=env_key) bel = init_beliefs(list(state.agents.keys())) active = "main" branches_d = {"main": [snapshot_of(state, "main")]} r = 0 out = refresh(state, branches_d, active, bel, r) return out + (state, branches_d, active, bel, r) def do_train(state, branches_d, active, bel, r, use_q_v, a, g, e, ed, emin, eps_count, max_s): state = set_cfg(state, use_q_v, a, g, e, ed, emin) state = train(state, episodes=max(1, int(eps_count)), max_steps=max(10, int(max_s))) bel = init_beliefs(list(state.agents.keys())) branches_d = {"main": [snapshot_of(state, "main")]} active = "main" r = 0 out = refresh(state, branches_d, active, bel, r) return out + (state, branches_d, active, bel, r) def cinematic_run(state, branches_d, active, bel, r, max_s): max_s = max(10, int(max_s)) # Reset episode so the cinematic is clean state = reset_episode_keep_learning(state, seed=state.seed) bel = init_beliefs(list(state.agents.keys())) # Run to completion (or max steps) in one click while state.step < max_s and not state.done: tick(state, bel, manual_action=None) state.event_log.append(f"Cinematic finished: outcome={state.outcome} steps={state.step}") branches_d = push_hist(state, branches_d, active) r = len(branches_d[active]) - 1 out = refresh(state, branches_d, active, bel, r) return out + (state, branches_d, active, bel, r) def export_fn(state, branches_d, active, r): return export_run(state, branches_d, active, int(r)) def import_fn(txt): state, branches_d, active, r, bel = import_run(txt) branches_d.setdefault(active, []) if not branches_d[active]: branches_d[active].append(snapshot_of(state, active)) out = refresh(state, branches_d, active, bel, r) return out + (state, branches_d, active, bel, r) # ---- Autoplay control ---- def autoplay_start(state, branches_d, active, bel, r, interval_s): interval_s = float(interval_s) # Enable timer + autoplay flag return ( gr.update(value=interval_s, active=True), True, state, branches_d, active, bel, r ) def autoplay_stop(state, branches_d, active, bel, r): return ( gr.update(active=False), False, state, branches_d, active, bel, r ) def autoplay_tick(state, branches_d, active, bel, r, is_on: bool): # If not on, do nothing (also keep timer active state as-is) if not is_on: out = refresh(state, branches_d, active, bel, r) return out + (state, branches_d, active, bel, r, is_on, gr.update()) # Step once if not state.done: tick(state, bel, manual_action=None) branches_d = push_hist(state, branches_d, active) r = len(branches_d[active]) - 1 # If done, stop autoplay automatically if state.done: out = refresh(state, branches_d, active, bel, r) return out + (state, branches_d, active, bel, r, False, gr.update(active=False)) out = refresh(state, branches_d, active, bel, r) return out + (state, branches_d, active, bel, r, True, gr.update()) # ---- wiring ---- common_outputs = [ pov_img, truth, belief_a, belief_b, dash, status, scoreboard, events, trace, rewind, rewind_idx, branch_pick, branch_pick, st, branches, active_branch, beliefs, rewind_idx ] btn_L.click(lambda s,b,a,bel,r: do_manual(s,b,a,bel,r,"L"), inputs=[st, branches, active_branch, beliefs, rewind_idx], outputs=common_outputs, queue=True) btn_F.click(lambda s,b,a,bel,r: do_manual(s,b,a,bel,r,"F"), inputs=[st, branches, active_branch, beliefs, rewind_idx], outputs=common_outputs, queue=True) btn_R.click(lambda s,b,a,bel,r: do_manual(s,b,a,bel,r,"R"), inputs=[st, branches, active_branch, beliefs, rewind_idx], outputs=common_outputs, queue=True) btn_I.click(lambda s,b,a,bel,r: do_manual(s,b,a,bel,r,"I"), inputs=[st, branches, active_branch, beliefs, rewind_idx], outputs=common_outputs, queue=True) btn_tick.click(do_tick, inputs=[st, branches, active_branch, beliefs, rewind_idx], outputs=common_outputs, queue=True) btn_run.click(do_run, inputs=[st, branches, active_branch, beliefs, rewind_idx, run_steps], outputs=common_outputs, queue=True) btn_toggle_control.click(toggle_control, inputs=[st, branches, active_branch, beliefs, rewind_idx], outputs=common_outputs, queue=True) btn_toggle_pov.click(toggle_pov, inputs=[st, branches, active_branch, beliefs, rewind_idx], outputs=common_outputs, queue=True) overlay.change(set_overlay, inputs=[st, branches, active_branch, beliefs, rewind_idx, overlay], outputs=common_outputs, queue=True) truth.select(click_truth, inputs=[tile_pick, st, branches, active_branch, beliefs, rewind_idx], outputs=common_outputs, queue=True) btn_jump.click(jump, inputs=[st, branches, active_branch, beliefs, rewind_idx, rewind], outputs=common_outputs, queue=True) btn_fork.click(fork_branch, inputs=[st, branches, active_branch, beliefs, rewind_idx, new_branch_name], outputs=common_outputs, queue=True) btn_set_branch.click(set_active_branch, inputs=[st, branches, active_branch, beliefs, rewind_idx, branch_pick], outputs=common_outputs, queue=True) btn_apply_env_seed.click(apply_env_seed, inputs=[st, branches, active_branch, beliefs, rewind_idx, env_pick, seed_box], outputs=common_outputs, queue=True) btn_reset_ep.click(reset_ep, inputs=[st, branches, active_branch, beliefs, rewind_idx], outputs=common_outputs, queue=True) btn_reset_all.click(reset_all, inputs=[st, branches, active_branch, beliefs, rewind_idx, env_pick, seed_box], outputs=common_outputs, queue=True) btn_train.click(do_train, inputs=[st, branches, active_branch, beliefs, rewind_idx, use_q, alpha, gamma, eps, eps_decay, eps_min, episodes, max_steps], outputs=common_outputs, queue=True) btn_cinematic.click(cinematic_run, inputs=[st, branches, active_branch, beliefs, rewind_idx, cinematic_steps], outputs=common_outputs, queue=True) btn_export.click(export_fn, inputs=[st, branches, active_branch, rewind_idx], outputs=[export_box], queue=True) btn_import.click(import_fn, inputs=[import_box], outputs=common_outputs, queue=True) # Autoplay start/stop wires btn_autoplay_start.click( autoplay_start, inputs=[st, branches, active_branch, beliefs, rewind_idx, autoplay_speed], outputs=[timer, autoplay_on, st, branches, active_branch, beliefs, rewind_idx], queue=True ) btn_autoplay_stop.click( autoplay_stop, inputs=[st, branches, active_branch, beliefs, rewind_idx], outputs=[timer, autoplay_on, st, branches, active_branch, beliefs, rewind_idx], queue=True ) # Timer tick: step and update UI; auto-stop when done timer.tick( autoplay_tick, inputs=[st, branches, active_branch, beliefs, rewind_idx, autoplay_on], outputs=[ pov_img, truth, belief_a, belief_b, dash, status, scoreboard, events, trace, rewind, rewind_idx, branch_pick, branch_pick, st, branches, active_branch, beliefs, rewind_idx, autoplay_on, timer ], queue=True ) demo.load( refresh, inputs=[st, branches, active_branch, beliefs, rewind_idx], outputs=[ pov_img, truth, belief_a, belief_b, dash, status, scoreboard, events, trace, rewind, rewind_idx, branch_pick, branch_pick ], queue=True ) # Disable SSR for HF stability demo.queue().launch(ssr_mode=False)