import json
import math
from dataclasses import asdict, dataclass
from typing import Dict, List, Optional, Tuple

import numpy as np
from PIL import Image, ImageDraw
import gradio as gr

# ============================================================
# ChronoSandbox++ — Instrumented Training Arena
# - Deterministic gridworld + first-person raycast view
# - Click-to-edit environment (tiles)
# - Full step trace: obs -> action -> reward -> q-update rationale
# - Optional Q-learning (tabular) for Predator + Prey
# - Batch training: run episodes fast, track metrics
# - Export/import: environment, history, Q-tables, metrics
#
# Compatibility: avoids fn_kwargs + avoids gr.Timer
# ============================================================

# -----------------------------
# Config
# -----------------------------
# World size in cells, and the pixel size of one cell in the top-down maps.
GRID_W = 21
GRID_H = 15
TILE = 22

# First-person view: output image size, number of rays cast across it,
# horizontal field of view, and how far (in cells) a ray marches at most.
VIEW_W = 640
VIEW_H = 360
RAY_W = 320
FOV_DEG = 78
MAX_DEPTH = 20

# Orientation index k: DIRS[k] is the one-cell (dx, dy) step and ORI_DEG[k]
# the matching heading in degrees (y grows downward on screen).
DIRS = [(1, 0), (0, 1), (-1, 0), (0, -1)]
ORI_DEG = [90 * k for k in range(4)]

# Integer tile codes stored in the grid.
EMPTY, WALL, FOOD, NOISE, DOOR, TELE = range(6)

TILE_NAMES = {
    EMPTY: "Empty",
    WALL: "Wall",
    FOOD: "Food",
    NOISE: "Noise",
    DOOR: "Door",
    TELE: "Teleporter",
}

AGENT_COLORS = {
    "Predator": (255, 120, 90),
    "Prey": (120, 255, 160),
    "Scout": (120, 190, 255),
}

# Colours used by the first-person renderer.
SKY = np.array([14, 16, 26], dtype=np.uint8)
FLOOR_NEAR = np.array([24, 26, 40], dtype=np.uint8)
FLOOR_FAR = np.array([10, 11, 18], dtype=np.uint8)
WALL_BASE = np.array([210, 210, 225], dtype=np.uint8)
WALL_SIDE = np.array([150, 150, 170], dtype=np.uint8)
DOOR_COL = np.array([180, 210, 255], dtype=np.uint8)

ACTIONS = ["L", "F", "R"]  # keep small for tabular learning stability


# -----------------------------
# Deterministic RNG streams
# -----------------------------
def rng_for(seed: int, step: int, stream: int = 0) -> np.random.Generator:
    """Return a NumPy Generator fully determined by (seed, step, stream).

    Each input is scaled by its own constant and the three products are
    XOR-folded together; the result is truncated to 64 bits so it is always
    a valid non-negative seed for ``default_rng``.
    """
    folded = seed * 1_000_003
    folded ^= step * 9_999_937
    folded ^= stream * 97_531
    return np.random.default_rng(folded & ((1 << 64) - 1))


# -----------------------------
# Data structures
# -----------------------------
@dataclass
class Agent:
    """One agent in the gridworld.

    ``x``/``y`` are integer cell coordinates, ``ori`` indexes DIRS/ORI_DEG,
    and ``energy`` is a counter that ``consume_food`` raises.
    """

    name: str
    x: int
    y: int
    ori: int
    energy: int = 100


@dataclass
class TrainConfig:
    """Tabular Q-learning hyper-parameters and reward-shaping coefficients."""

    use_q_pred: bool = True
    use_q_prey: bool = True
    alpha: float = 0.15
    gamma: float = 0.95
    epsilon: float = 0.10
    epsilon_min: float = 0.02
    epsilon_decay: float = 0.995
    # reward shaping
    pred_step_penalty: float = -0.02
    pred_dist_coeff: float = 0.03
    pred_catch_reward: float = 3.0
    prey_step_penalty: float = -0.02
    prey_food_reward: float = 0.6
    prey_survive_reward: float = 0.02
    prey_caught_penalty: float = -3.0


@dataclass
class Metrics:
    """Aggregate training statistics plus the live exploration rate."""

    episodes: int = 0
    catches: int = 0
    avg_steps_to_catch: float = 0.0
    avg_path_efficiency: float = 0.0  # optimal / actual (0..1)
    last_episode_steps: int = 0
    last_episode_eff: float = 0.0
    epsilon: float = 0.10  # current epsilon; decayed by training


@dataclass
class WorldState:
    """Complete mutable simulation state, including learning tables."""

    seed: int
    step: int
    grid: List[List[int]]  # grid[y][x] holds a tile code
    agents: Dict[str, Agent]
    controlled: str  # agent that receives manual actions
    pov: str  # agent whose first-person view is rendered
    overlay: bool  # draw a reticle on the POV image
    caught: bool
    branches: Dict[str, int]
    # instrumentation
    event_log: List[str]
    trace_log: List[str]  # more detailed step trace (bounded)
    # training
    cfg: TrainConfig
    q_pred: Dict[str, List[float]]  # observation key -> one Q value per ACTIONS entry
    q_prey: Dict[str, List[float]]
    metrics: Metrics


@dataclass
class Snapshot:
    """Serializable copy of the world at one step, used for rewind/export."""

    step: int
    agents: Dict[str, Dict]  # agent name -> asdict(Agent)
    grid: List[List[int]]
    caught: bool
    event_log_tail: List[str]
    trace_tail: List[str]


# -----------------------------
# Environment
# -----------------------------
def default_grid() -> List[List[int]]:
    """Build the default map.

    It has a walled border, a horizontal wall with one door, two food tiles,
    two noise tiles and a pair of teleporters.
    """
    g = [[EMPTY for _ in range(GRID_W)] for _ in range(GRID_H)]
    for x in range(GRID_W):
        g[0][x] = WALL
        g[GRID_H - 1][x] = WALL
    for y in range(GRID_H):
        g[y][0] = WALL
        g[y][GRID_W - 1] = WALL
    for x in range(4, 17):
        g[7][x] = WALL
    g[7][10] = DOOR
    g[3][4] = FOOD
    g[11][15] = FOOD
    g[4][14] = NOISE
    g[12][5] = NOISE
    g[2][18] = TELE
    g[13][2] = TELE
    return g


def init_state(seed: int) -> WorldState:
    """Create a fresh world with default agents, empty Q-tables and metrics."""
    agents = {
        "Predator": Agent("Predator", 2, 2, 0, 100),
        "Prey": Agent("Prey", 18, 12, 2, 100),
        "Scout": Agent("Scout", 10, 3, 1, 100),
    }
    cfg = TrainConfig()
    return WorldState(
        seed=seed,
        step=0,
        grid=default_grid(),
        agents=agents,
        controlled="Predator",
        pov="Predator",
        overlay=False,
        caught=False,
        branches={"main": 0},
        event_log=["Initialized world."],
        trace_log=[],
        cfg=cfg,
        q_pred={},
        q_prey={},
        metrics=Metrics(epsilon=cfg.epsilon),
    )


# -----------------------------
# Belief maps
# -----------------------------
def init_belief() -> Dict[str, np.ndarray]:
    """Return per-agent belief grids (GRID_H x GRID_W, int16).

    Every cell starts at -1, meaning "never observed".
    """
    return {
        nm: np.full((GRID_H, GRID_W), -1, dtype=np.int16)
        for nm in ["Predator", "Prey", "Scout"]
    }


# -----------------------------
# Helpers
# -----------------------------
def in_bounds(x: int, y: int) -> bool:
    """True if (x, y) is a valid grid cell."""
    return 0 <= x < GRID_W and 0 <= y < GRID_H


def is_blocking(tile: int) -> bool:
    """True if the tile stops movement (doors open instead of blocking)."""
    return tile == WALL


def manhattan(a: Agent, b: Agent) -> int:
    """L1 distance between two agents' cells."""
    return abs(a.x - b.x) + abs(a.y - b.y)


def bresenham_los(grid: List[List[int]], x0: int, y0: int, x1: int, y1: int) -> bool:
    """Line of sight between two cells.

    Walks the Bresenham line between the cells. Only WALL cells strictly
    between the endpoints block; doors do not.
    """
    dx = abs(x1 - x0)
    dy = abs(y1 - y0)
    sx = 1 if x0 < x1 else -1
    sy = 1 if y0 < y1 else -1
    err = dx - dy
    x, y = x0, y0
    while True:
        if (x, y) != (x0, y0) and (x, y) != (x1, y1):
            if grid[y][x] == WALL:
                return False
        if x == x1 and y == y1:
            return True
        e2 = 2 * err
        if e2 > -dy:
            err -= dy
            x += sx
        if e2 < dx:
            err += dx
            y += sy


def within_fov(observer: Agent, tx: int, ty: int, fov_deg: float = FOV_DEG) -> bool:
    """True if cell (tx, ty) lies within +-fov_deg/2 of the observer's heading."""
    dx = tx - observer.x
    dy = ty - observer.y
    if dx == 0 and dy == 0:
        return True
    angle = math.degrees(math.atan2(dy, dx)) % 360
    facing = ORI_DEG[observer.ori]
    # signed angular difference wrapped into [-180, 180)
    diff = (angle - facing + 540) % 360 - 180
    return abs(diff) <= (fov_deg / 2)


def visible(observer: Agent, target: Agent, grid: List[List[int]]) -> bool:
    """Target is in the observer's FOV cone and not occluded by a wall."""
    return within_fov(observer, target.x, target.y, FOV_DEG) and bresenham_los(
        grid, observer.x, observer.y, target.x, target.y
    )


# -----------------------------
# Movement
# -----------------------------
def turn_left(a: Agent) -> None:
    """Rotate 90 degrees counter-clockwise on screen."""
    a.ori = (a.ori - 1) % 4


def turn_right(a: Agent) -> None:
    """Rotate 90 degrees clockwise on screen."""
    a.ori = (a.ori + 1) % 4


def move_forward(state: WorldState, a: Agent) -> str:
    """Step one cell forward, opening doors and following teleporters.

    Returns a short outcome string for the trace log.
    """
    dx, dy = DIRS[a.ori]
    nx, ny = a.x + dx, a.y + dy
    if not in_bounds(nx, ny):
        return "blocked: bounds"
    if is_blocking(state.grid[ny][nx]):
        return "blocked: wall"
    if state.grid[ny][nx] == DOOR:
        # doors are consumed: walking into one opens it permanently
        state.grid[ny][nx] = EMPTY
        state.event_log.append(f"t={state.step}: {a.name} opened a door.")
    a.x, a.y = nx, ny
    if state.grid[ny][nx] == TELE:
        teles = [(x, y) for y in range(GRID_H) for x in range(GRID_W) if state.grid[y][x] == TELE]
        if len(teles) >= 2:
            # deterministic pairing: jump to the next teleporter in sorted (x, y) order
            teles_sorted = sorted(teles)
            idx = teles_sorted.index((nx, ny))
            dest = teles_sorted[(idx + 1) % len(teles_sorted)]
            a.x, a.y = dest
            state.event_log.append(f"t={state.step}: {a.name} teleported.")
            return "moved: teleported"
    return "moved"


def apply_action(state: WorldState, agent_name: str, action: str) -> str:
    """Apply one of ACTIONS ("L", "F", "R") to the named agent.

    Returns an outcome string; unknown actions are a no-op.
    """
    a = state.agents[agent_name]
    if action == "L":
        turn_left(a)
        return "turned left"
    if action == "R":
        turn_right(a)
        return "turned right"
    if action == "F":
        return move_forward(state, a)
    return "noop"


# -----------------------------
# Rendering
# -----------------------------
def raycast_view(state: WorldState, observer: Agent) -> np.ndarray:
    """Render the observer's first-person view as a (VIEW_H, VIEW_W, 3) uint8 image.

    Walls and doors come from marching RAY_W rays across the field of view.
    Other visible agents are painted as flat billboards, farthest first.
    """
    img = np.zeros((VIEW_H, VIEW_W, 3), dtype=np.uint8)
    img[:, :] = SKY
    horizon = VIEW_H // 2
    # Floor gradient: the horizon row is the farthest floor and the bottom
    # edge the nearest, so blend from FLOOR_FAR up top to FLOOR_NEAR below.
    for y in range(horizon, VIEW_H):
        t = (y - horizon) / (horizon + 1e-6)
        col = (1 - t) * FLOOR_FAR + t * FLOOR_NEAR
        img[y, :] = col.astype(np.uint8)

    facing_rad = math.radians(ORI_DEG[observer.ori])
    half_fov = math.radians(FOV_DEG) / 2
    ox, oy = observer.x + 0.5, observer.y + 0.5
    col_w = VIEW_W / RAY_W

    for rx in range(RAY_W):
        cam_x = (2 * rx / (RAY_W - 1)) - 1
        ray_ang = facing_rad + cam_x * half_fov
        sin_a = math.sin(ray_ang)
        cos_a = math.cos(ray_ang)
        depth = 0.0
        hit = None  # None, "wall", "door"
        side = 0
        while depth < MAX_DEPTH:
            depth += 0.05
            tx = int(ox + cos_a * depth)
            ty = int(oy + sin_a * depth)
            if not in_bounds(tx, ty):
                break
            tile = state.grid[ty][tx]
            if tile == WALL:
                hit = "wall"
                # crude face guess: mostly-horizontal rays hit x-facing faces
                side = 1 if abs(cos_a) > abs(sin_a) else 0
                break
            if tile == DOOR:
                hit = "door"
                break
        if hit is None:
            continue
        # project onto the view direction to remove fish-eye distortion
        depth *= math.cos(ray_ang - facing_rad)
        depth = max(depth, 0.001)
        proj_h = int((VIEW_H * 0.9) / depth)
        y0 = max(0, horizon - proj_h // 2)
        y1 = min(VIEW_H - 1, horizon + proj_h // 2)
        if hit == "door":
            col = DOOR_COL.copy()
        else:
            col = WALL_BASE.copy() if side == 0 else WALL_SIDE.copy()
        dim = max(0.25, 1.0 - (depth / MAX_DEPTH))
        col = (col * dim).astype(np.uint8)
        x0 = int(rx * col_w)
        x1 = int((rx + 1) * col_w)
        img[y0:y1, x0:x1] = col

    # billboards for visible agents, painter's order (far first)
    sprites = []
    for nm, other in state.agents.items():
        if nm == observer.name:
            continue
        dx = other.x - observer.x
        dy = other.y - observer.y
        if dx == 0 and dy == 0:
            # Same cell: there is no bearing, and the projected screen x can
            # land far off-screen, where negative slice ends would paint a
            # huge bogus rectangle. Skip it.
            continue
        if not visible(observer, other, state.grid):
            continue
        sprites.append((math.hypot(dx, dy), nm, dx, dy))

    for dist, nm, dx, dy in sorted(sprites, reverse=True):
        ang = math.degrees(math.atan2(dy, dx)) % 360
        diff = (ang - ORI_DEG[observer.ori] + 540) % 360 - 180
        sx = int((diff / (FOV_DEG / 2)) * (VIEW_W / 2) + (VIEW_W / 2))
        h = int((VIEW_H * 0.65) / max(dist, 0.75))
        w = max(10, h // 3)
        y0 = max(0, horizon - h // 2)
        y1 = min(VIEW_H - 1, horizon + h // 2)
        x0 = max(0, sx - w // 2)
        x1 = min(VIEW_W - 1, sx + w // 2)
        if x1 <= x0 or y1 <= y0:
            continue  # fully off-screen; never let a negative bound reach the slice
        col = AGENT_COLORS.get(nm, (255, 200, 120))
        img[y0:y1, x0:x1] = np.array(col, dtype=np.uint8)

    if state.overlay:
        cx, cy = VIEW_W // 2, VIEW_H // 2
        img[cy - 1:cy + 2, cx - 10:cx + 10] = np.array([120, 190, 255], dtype=np.uint8)
        img[cy - 10:cy + 10, cx - 1:cx + 2] = np.array([120, 190, 255], dtype=np.uint8)
    return img


# Fill colour per tile code for top-down maps; -1 is "unknown" in belief grids.
_TOPDOWN_PALETTE = {
    -1: (18, 20, 32),
    EMPTY: (26, 30, 44),
    WALL: (190, 190, 210),
    FOOD: (255, 210, 120),
    NOISE: (255, 120, 220),
    DOOR: (140, 210, 255),
    TELE: (120, 190, 255),
}


def render_topdown(grid: np.ndarray, agents: Dict[str, Agent], title: str, show_agents: bool = True) -> Image.Image:
    """Draw a 2-D tile map with a 28 px title bar and optional agent markers.

    ``grid`` is indexed ``[y, x]`` and holds tile codes (or -1 for unknown).
    Each agent is drawn as a disc with a line showing its heading.
    """
    header = 28
    rows, cols = grid.shape[0], grid.shape[1]
    w = cols * TILE
    h = rows * TILE
    im = Image.new("RGB", (w, h + header), (10, 12, 18))
    draw = ImageDraw.Draw(im)
    for y in range(rows):
        for x in range(cols):
            col = _TOPDOWN_PALETTE.get(int(grid[y, x]), (80, 80, 90))
            x0, y0 = x * TILE, y * TILE + header
            draw.rectangle([x0, y0, x0 + TILE - 1, y0 + TILE - 1], fill=col)
    for x in range(cols + 1):
        xx = x * TILE
        draw.line([xx, header, xx, h + header], fill=(12, 14, 22))
    for y in range(rows + 1):
        yy = y * TILE + header
        draw.line([0, yy, w, yy], fill=(12, 14, 22))
    if show_agents:
        for nm, a in agents.items():
            cx = a.x * TILE + TILE // 2
            cy = a.y * TILE + header + TILE // 2
            col = AGENT_COLORS.get(nm, (220, 220, 220))
            r = TILE // 3
            draw.ellipse([cx - r, cy - r, cx + r, cy + r], fill=col)
            dx, dy = DIRS[a.ori]
            draw.line([cx, cy, cx + dx * r, cy + dy * r], fill=(10, 10, 10), width=3)
    draw.rectangle([0, 0, w, header], fill=(14, 16, 26))
    draw.text((8, 6), title, fill=(230, 230, 240))
    return im


# -----------------------------
# Belief updates
# -----------------------------
def update_belief_for_agent(state: WorldState, belief: np.ndarray, agent: Agent) -> None:
    """Copy true tiles the agent can currently see into its belief grid (in place).

    The agent's own cell is recorded, then a fan of rays (33, or 45 for the
    Scout) spans the FOV. Each ray marks cells in 0.2-cell steps until it hits
    a wall or leaves the grid.
    """
    belief[agent.y, agent.x] = state.grid[agent.y][agent.x]
    base = math.radians(ORI_DEG[agent.ori])
    half = math.radians(FOV_DEG / 2)
    rays = 33 if agent.name != "Scout" else 45
    ox, oy = agent.x + 0.5, agent.y + 0.5
    for i in range(rays):
        t = i / (rays - 1)
        ang = base + (t * 2 - 1) * half
        sin_a, cos_a = math.sin(ang), math.cos(ang)
        depth = 0.0
        while depth < MAX_DEPTH:
            depth += 0.2
            tx = int(ox + cos_a * depth)
            ty = int(oy + sin_a * depth)
            if not in_bounds(tx, ty):
                break
            belief[ty, tx] = state.grid[ty][tx]
            if state.grid[ty][tx] == WALL:
                break


# -----------------------------
# Optimal distance (BFS) for efficiency metric
# -----------------------------
def bfs_distance(grid: List[List[int]], sx: int, sy: int, gx: int, gy: int) -> Optional[int]:
    """Shortest 4-connected path length from (sx, sy) to (gx, gy).

    Only WALL blocks; doors count as passable since walking into one opens
    it. Teleporters are ignored. Returns None if the goal is unreachable.
    """
    if (sx, sy) == (gx, gy):
        return 0
    queue = [(sx, sy)]
    dist = {(sx, sy): 0}
    head = 0  # index-based FIFO avoids O(n) list.pop(0)
    while head < len(queue):
        x, y = queue[head]
        head += 1
        for dx, dy in DIRS:
            nx, ny = x + dx, y + dy
            if not in_bounds(nx, ny) or grid[ny][nx] == WALL or (nx, ny) in dist:
                continue
            dist[(nx, ny)] = dist[(x, y)] + 1
            if (nx, ny) == (gx, gy):
                return dist[(nx, ny)]
            queue.append((nx, ny))
    return None
key) # ----------------------------- def obs_key(state: WorldState, who: str) -> str: pred = state.agents["Predator"] prey = state.agents["Prey"] a = state.agents[who] # relative position coarse-binned to keep table smaller dx = prey.x - pred.x dy = prey.y - pred.y dx_bin = int(np.clip(dx, -6, 6)) dy_bin = int(np.clip(dy, -6, 6)) vis = 1 if visible(pred, prey, state.grid) else 0 # include own orientation and role if who == "Predator": return f"P|{pred.x},{pred.y},{pred.ori}|d{dx_bin},{dy_bin}|v{vis}" if who == "Prey": # prey cares if predator is visible to it vis2 = 1 if visible(prey, pred, state.grid) else 0 ddx = pred.x - prey.x ddy = pred.y - prey.y ddx_bin = int(np.clip(ddx, -6, 6)) ddy_bin = int(np.clip(ddy, -6, 6)) return f"R|{prey.x},{prey.y},{prey.ori}|d{ddx_bin},{ddy_bin}|v{vis2}|e{int(prey.energy//25)}" # Scout: simple return f"S|{a.x},{a.y},{a.ori}" def q_get(q: Dict[str, List[float]], key: str) -> List[float]: if key not in q: q[key] = [0.0, 0.0, 0.0] return q[key] def epsilon_greedy(qvals: List[float], eps: float, r: np.random.Generator) -> int: if r.random() < eps: return int(r.integers(0, len(qvals))) return int(np.argmax(qvals)) def q_update(q: Dict[str, List[float]], key: str, a_idx: int, reward: float, next_key: str, alpha: float, gamma: float) -> Tuple[float, float, float]: qv = q_get(q, key) nq = q_get(q, next_key) old = qv[a_idx] target = reward + gamma * float(np.max(nq)) new = old + alpha * (target - old) qv[a_idx] = new return old, target, new # ----------------------------- # Baseline heuristic policies (for Scout + fallback) # ----------------------------- def heuristic_pred_action(state: WorldState) -> str: pred = state.agents["Predator"] prey = state.agents["Prey"] if visible(pred, prey, state.grid): dx = prey.x - pred.x dy = prey.y - pred.y ang = (math.degrees(math.atan2(dy, dx)) % 360) facing = ORI_DEG[pred.ori] diff = (ang - facing + 540) % 360 - 180 if diff < -10: return "L" if diff > 10: return "R" return "F" r = rng_for(state.seed, 
state.step, stream=11) return r.choice(ACTIONS) def heuristic_prey_action(state: WorldState) -> str: prey = state.agents["Prey"] pred = state.agents["Predator"] if visible(prey, pred, state.grid): dx = pred.x - prey.x dy = pred.y - prey.y ang = (math.degrees(math.atan2(dy, dx)) % 360) facing = ORI_DEG[prey.ori] diff = (ang - facing + 540) % 360 - 180 diff_away = ((diff + 180) + 540) % 360 - 180 if diff_away < -10: return "L" if diff_away > 10: return "R" return "F" r = rng_for(state.seed, state.step, stream=12) return r.choice(ACTIONS) def heuristic_scout_action(state: WorldState) -> str: r = rng_for(state.seed, state.step, stream=13) return r.choice(ACTIONS) # ----------------------------- # Reward shaping # ----------------------------- def pred_reward(state_prev: WorldState, state_now: WorldState) -> float: cfg = state_now.cfg pred0 = state_prev.agents["Predator"] prey0 = state_prev.agents["Prey"] pred1 = state_now.agents["Predator"] prey1 = state_now.agents["Prey"] d0 = abs(pred0.x - prey0.x) + abs(pred0.y - prey0.y) d1 = abs(pred1.x - prey1.x) + abs(pred1.y - prey1.y) r = cfg.pred_step_penalty + cfg.pred_dist_coeff * (d0 - d1) # reward closing distance if state_now.caught: r += cfg.pred_catch_reward return float(r) def prey_reward(state_prev: WorldState, state_now: WorldState, ate_food: bool) -> float: cfg = state_now.cfg r = cfg.prey_step_penalty + cfg.prey_survive_reward if ate_food: r += cfg.prey_food_reward if state_now.caught: r += cfg.prey_caught_penalty return float(r) # ----------------------------- # Core simulation tick (with instrumentation + optional learning) # ----------------------------- TRACE_MAX = 400 def clone_shallow(state: WorldState) -> WorldState: # clone for reward computation, minimal fields return WorldState( seed=state.seed, step=state.step, grid=[row[:] for row in state.grid], agents={k: Agent(**asdict(v)) for k, v in state.agents.items()}, controlled=state.controlled, pov=state.pov, overlay=state.overlay, caught=state.caught, 
branches=dict(state.branches), event_log=list(state.event_log), trace_log=list(state.trace_log), cfg=state.cfg, q_pred=state.q_pred, q_prey=state.q_prey, metrics=state.metrics, ) def check_catch(state: WorldState) -> None: pred = state.agents["Predator"] prey = state.agents["Prey"] if pred.x == prey.x and pred.y == prey.y: state.caught = True state.event_log.append(f"t={state.step}: CAUGHT.") def consume_food(state: WorldState) -> bool: prey = state.agents["Prey"] if state.grid[prey.y][prey.x] == FOOD: prey.energy = min(200, prey.energy + 35) state.grid[prey.y][prey.x] = EMPTY state.event_log.append(f"t={state.step}: Prey ate food (+energy).") return True return False def choose_action(state: WorldState, who: str, stream: int) -> Tuple[str, str, Optional[Tuple[str,int]]]: """ Returns (action, reason, q_info) q_info: (obs_key, action_index) if chosen by Q, else None """ cfg = state.cfg r = rng_for(state.seed, state.step, stream=stream) if who == "Predator" and cfg.use_q_pred: k = obs_key(state, "Predator") qv = q_get(state.q_pred, k) a_idx = epsilon_greedy(qv, state.metrics.epsilon, r) return ACTIONS[a_idx], f"Q(pred) eps={state.metrics.epsilon:.3f} q={np.round(qv,3).tolist()}", (k, a_idx) if who == "Prey" and cfg.use_q_prey: k = obs_key(state, "Prey") qv = q_get(state.q_prey, k) a_idx = epsilon_greedy(qv, state.metrics.epsilon, r) return ACTIONS[a_idx], f"Q(prey) eps={state.metrics.epsilon:.3f} q={np.round(qv,3).tolist()}", (k, a_idx) # fallbacks if who == "Predator": a = heuristic_pred_action(state) return a, "heuristic(pred)", None if who == "Prey": a = heuristic_prey_action(state) return a, "heuristic(prey)", None a = heuristic_scout_action(state) return a, "heuristic(scout)", None def tick(state: WorldState, manual_action: Optional[str] = None) -> None: if state.caught: return prev = clone_shallow(state) # record optimal distance for efficiency stats pred = state.agents["Predator"] prey = state.agents["Prey"] opt_dist = bfs_distance(state.grid, pred.x, pred.y, 
prey.x, prey.y) if opt_dist is None: opt_dist = 999 # Action selection chosen = {} reasons = {} qinfo = {} # manual action applies to controlled agent if manual_action: chosen[state.controlled] = manual_action reasons[state.controlled] = "manual" qinfo[state.controlled] = None # others choose for who in ["Predator", "Prey", "Scout"]: if who in chosen: continue act, reason, q_i = choose_action(state, who, stream={"Predator":21,"Prey":22,"Scout":23}[who]) chosen[who] = act reasons[who] = reason qinfo[who] = q_i # Apply actions (deterministic order) outcomes = {} for who in ["Predator", "Prey", "Scout"]: outcomes[who] = apply_action(state, who, chosen[who]) ate = consume_food(state) check_catch(state) # Rewards + Q-updates pred_r = pred_reward(prev, state) prey_r = prey_reward(prev, state, ate_food=ate) q_lines = [] if qinfo["Predator"] is not None: k, a_idx = qinfo["Predator"] nk = obs_key(state, "Predator") old, target, new = q_update(state.q_pred, k, a_idx, pred_r, nk, state.cfg.alpha, state.cfg.gamma) q_lines.append(f"Qpred: {k} a={ACTIONS[a_idx]} old={old:.3f} tgt={target:.3f} new={new:.3f}") if qinfo["Prey"] is not None: k, a_idx = qinfo["Prey"] nk = obs_key(state, "Prey") old, target, new = q_update(state.q_prey, k, a_idx, prey_r, nk, state.cfg.alpha, state.cfg.gamma) q_lines.append(f"Qprey: {k} a={ACTIONS[a_idx]} old={old:.3f} tgt={target:.3f} new={new:.3f}") # Trace line dist_now = manhattan(state.agents["Predator"], state.agents["Prey"]) eff = (opt_dist / max(1, dist_now)) if dist_now > 0 else 1.0 trace = ( f"t={state.step} optDist~{opt_dist} distNow={dist_now} " f"| Pred:{chosen['Predator']} ({outcomes['Predator']}) [{reasons['Predator']}] r={pred_r:+.3f} " f"| Prey:{chosen['Prey']} ({outcomes['Prey']}) [{reasons['Prey']}] r={prey_r:+.3f} " f"| Scout:{chosen['Scout']} ({outcomes['Scout']}) [{reasons['Scout']}] " f"| ateFood={ate} caught={state.caught}" ) if q_lines: trace += " | " + " ; ".join(q_lines) state.trace_log.append(trace) if len(state.trace_log) > 
TRACE_MAX: state.trace_log = state.trace_log[-TRACE_MAX:] state.step += 1 # ----------------------------- # Episode reset + training # ----------------------------- def reset_episode(state: WorldState, seed: Optional[int] = None) -> None: # Keep Q-tables + cfg + metrics; reset world + logs if seed is None: seed = state.seed fresh = init_state(seed) fresh.cfg = state.cfg fresh.q_pred = state.q_pred fresh.q_prey = state.q_prey fresh.metrics = state.metrics fresh.metrics.epsilon = state.metrics.epsilon state.seed = fresh.seed state.step = 0 state.grid = fresh.grid state.agents = fresh.agents state.controlled = fresh.controlled state.pov = fresh.pov state.overlay = fresh.overlay state.caught = False state.branches = fresh.branches state.event_log = ["Episode reset."] state.trace_log = [] def run_episode(state: WorldState, max_steps: int) -> Tuple[bool, int, float]: # returns (caught, steps, path_eff) start_pred = state.agents["Predator"] start_prey = state.agents["Prey"] opt = bfs_distance(state.grid, start_pred.x, start_pred.y, start_prey.x, start_prey.y) if opt is None: opt = 999 steps = 0 while steps < max_steps and not state.caught: tick(state, manual_action=None) steps += 1 caught = state.caught eff = float(opt / max(1, steps)) if opt < 999 else 0.0 return caught, steps, eff def train(state: WorldState, episodes: int, max_steps: int) -> None: m = state.metrics cfg = state.cfg catches = 0 total_steps_catch = 0 total_eff = 0.0 for ep in range(episodes): # deterministically vary episode seed so it doesn't memorize one map-layout only ep_seed = (state.seed * 1_000_003 + (m.episodes + ep) * 97_531) & 0xFFFFFFFF reset_episode(state, seed=int(ep_seed)) caught, steps, eff = run_episode(state, max_steps=max_steps) total_eff += eff if caught: catches += 1 total_steps_catch += steps # epsilon decay m.epsilon = max(cfg.epsilon_min, m.epsilon * cfg.epsilon_decay) # Update metrics m.episodes += episodes m.catches += catches m.last_episode_steps = steps m.last_episode_eff = eff 
if catches > 0: # moving average by episode count for stability avg_steps = total_steps_catch / catches m.avg_steps_to_catch = ( 0.85 * m.avg_steps_to_catch + 0.15 * avg_steps if m.avg_steps_to_catch > 0 else avg_steps ) avg_eff = total_eff / max(1, episodes) m.avg_path_efficiency = ( 0.85 * m.avg_path_efficiency + 0.15 * avg_eff if m.avg_path_efficiency > 0 else avg_eff ) state.event_log.append( f"Training: +{episodes} eps | catches={catches}/{episodes} | " f"avgStepsToCatch~{m.avg_steps_to_catch:.2f} | avgEff~{m.avg_path_efficiency:.2f} | eps={m.epsilon:.3f}" ) # ----------------------------- # History / snapshots # ----------------------------- MAX_HISTORY = 1200 def snapshot_of(state: WorldState) -> Snapshot: return Snapshot( step=state.step, agents={k: asdict(v) for k, v in state.agents.items()}, grid=[row[:] for row in state.grid], caught=state.caught, event_log_tail=state.event_log[-20:], trace_tail=state.trace_log[-40:], ) def restore_into(state: WorldState, snap: Snapshot) -> None: state.step = snap.step state.grid = [row[:] for row in snap.grid] for k, d in snap.agents.items(): state.agents[k] = Agent(**d) state.caught = snap.caught state.event_log.append(f"Jumped to snapshot t={snap.step}.") # ----------------------------- # Export / import # ----------------------------- def export_run(state: WorldState, history: List[Snapshot]) -> str: payload = { "seed": state.seed, "controlled": state.controlled, "pov": state.pov, "overlay": state.overlay, "cfg": asdict(state.cfg), "metrics": asdict(state.metrics), "q_pred": state.q_pred, "q_prey": state.q_prey, "history": [asdict(s) for s in history], "grid": state.grid, } return json.dumps(payload, indent=2) def import_run(txt: str) -> Tuple[WorldState, List[Snapshot], Dict[str, np.ndarray], int]: data = json.loads(txt) st = init_state(int(data.get("seed", 1337))) st.controlled = data.get("controlled", st.controlled) st.pov = data.get("pov", st.pov) st.overlay = bool(data.get("overlay", False)) st.grid = 
data.get("grid", st.grid) st.cfg = TrainConfig(**data.get("cfg", asdict(st.cfg))) st.metrics = Metrics(**data.get("metrics", asdict(st.metrics))) st.q_pred = data.get("q_pred", {}) st.q_prey = data.get("q_prey", {}) hist = [Snapshot(**s) for s in data.get("history", [])] bel = init_belief() r_idx = max(0, len(hist) - 1) if hist: restore_into(st, hist[-1]) st.event_log.append("Imported run.") return st, hist, bel, r_idx # ----------------------------- # UI glue # ----------------------------- def build_views(state: WorldState, beliefs: Dict[str, np.ndarray]) -> Tuple[np.ndarray, Image.Image, Image.Image, Image.Image, str, str, str]: for nm, a in state.agents.items(): update_belief_for_agent(state, beliefs[nm], a) pov = raycast_view(state, state.agents[state.pov]) truth_np = np.array(state.grid, dtype=np.int16) truth_img = render_topdown(truth_np, state.agents, f"Truth Map — t={state.step} seed={state.seed}", show_agents=True) ctrl = state.controlled other = "Prey" if ctrl == "Predator" else "Predator" b_ctrl = render_topdown(beliefs[ctrl], state.agents, f"{ctrl} Belief", show_agents=True) b_other = render_topdown(beliefs[other], state.agents, f"{other} Belief", show_agents=True) m = state.metrics pred = state.agents["Predator"] prey = state.agents["Prey"] scout = state.agents["Scout"] status = ( f"Controlled={state.controlled} | POV={state.pov} | caught={state.caught} | eps={m.epsilon:.3f}\n" f"Episodes={m.episodes} | catches={m.catches} | avgStepsToCatch~{m.avg_steps_to_catch:.2f} | avgEff~{m.avg_path_efficiency:.2f}\n" f"Pred({pred.x},{pred.y}) o={pred.ori} | Prey({prey.x},{prey.y}) o={prey.ori} e={prey.energy} | Scout({scout.x},{scout.y}) o={scout.ori}" ) events = "\n".join(state.event_log[-18:]) trace = "\n".join(state.trace_log[-18:]) return pov, truth_img, b_ctrl, b_other, status, events, trace def grid_click_to_tile(evt: gr.SelectData, selected_tile: int, state: WorldState) -> WorldState: x_px, y_px = evt.index y_px -= 28 if y_px < 0: return state gx = 
int(x_px // TILE) gy = int(y_px // TILE) if not in_bounds(gx, gy): return state if gx == 0 or gy == 0 or gx == GRID_W - 1 or gy == GRID_H - 1: return state state.grid[gy][gx] = selected_tile state.event_log.append(f"t={state.step}: Tile ({gx},{gy}) -> {TILE_NAMES.get(selected_tile)}") return state # ----------------------------- # Gradio App # ----------------------------- with gr.Blocks(title="Agent POV") as demo: gr.Markdown( "## Agent-POV by ZEN AI Co.\n" "Track every interaction, train policies, and audit why outcomes happened.\n" "No timers (compatibility). Use Tick/Run/Train for controlled experiments." ) st = gr.State(init_state(1337)) history = gr.State([snapshot_of(init_state(1337))]) beliefs = gr.State(init_belief()) rewind_idx = gr.State(0) with gr.Row(): pov_img = gr.Image(label="POV (Pseudo-3D)", type="numpy", width=VIEW_W, height=VIEW_H) with gr.Column(): status = gr.Textbox(label="Status + Metrics", lines=4) events = gr.Textbox(label="Event Log", lines=10) trace = gr.Textbox(label="Step Trace (why it happened)", lines=10) with gr.Row(): truth = gr.Image(label="Truth Map (click to edit tiles)", type="pil") belief_a = gr.Image(label="Belief (Controlled)", type="pil") belief_b = gr.Image(label="Belief (Other)", type="pil") with gr.Row(): with gr.Column(scale=2): gr.Markdown("### Manual Controls") with gr.Row(): btn_L = gr.Button("L") btn_F = gr.Button("F") btn_R = gr.Button("R") with gr.Row(): btn_tick = gr.Button("Tick") run_steps = gr.Number(value=25, label="Run N steps", precision=0) btn_run = gr.Button("Run") with gr.Row(): btn_toggle_control = gr.Button("Toggle Controlled") btn_toggle_pov = gr.Button("Toggle POV") overlay = gr.Checkbox(False, label="Overlay reticle") tile_pick = gr.Radio( choices=[(TILE_NAMES[k], k) for k in [EMPTY, WALL, FOOD, NOISE, DOOR, TELE]], value=WALL, label="Paint tile type" ) with gr.Column(scale=3): gr.Markdown("### Training Controls (Q-learning)") use_q_pred = gr.Checkbox(True, label="Use Q-learning: Predator") 
use_q_prey = gr.Checkbox(True, label="Use Q-learning: Prey") alpha = gr.Slider(0.01, 0.5, value=0.15, step=0.01, label="alpha (learn rate)") gamma = gr.Slider(0.5, 0.99, value=0.95, step=0.01, label="gamma (discount)") eps = gr.Slider(0.0, 0.5, value=0.10, step=0.01, label="epsilon (exploration)") eps_decay = gr.Slider(0.90, 0.999, value=0.995, step=0.001, label="epsilon decay") eps_min = gr.Slider(0.0, 0.2, value=0.02, step=0.01, label="epsilon min") episodes = gr.Number(value=50, label="Train episodes", precision=0) max_steps = gr.Number(value=250, label="Max steps per episode", precision=0) btn_train = gr.Button("Train") btn_reset = gr.Button("Reset Episode") btn_reset_all = gr.Button("Reset ALL (wipe Q + metrics)") with gr.Row(): with gr.Column(): rewind = gr.Slider(0, 0, value=0, step=1, label="Rewind (history index)") btn_jump = gr.Button("Jump") with gr.Column(): export_box = gr.Textbox(label="Export JSON", lines=10) btn_export = gr.Button("Export") with gr.Column(): import_box = gr.Textbox(label="Import JSON", lines=10) btn_import = gr.Button("Import") def refresh(state: WorldState, hist: List[Snapshot], bel: Dict[str, np.ndarray], r: int): r_max = max(0, len(hist) - 1) r = max(0, min(int(r), r_max)) pov, tr, ba, bb, stxt, etxt, ttxt = build_views(state, bel) return ( pov, tr, ba, bb, stxt, etxt, ttxt, gr.update(maximum=r_max, value=r), r ) def push_hist(state: WorldState, hist: List[Snapshot]) -> List[Snapshot]: hist.append(snapshot_of(state)) if len(hist) > MAX_HISTORY: hist.pop(0) return hist def set_cfg(state: WorldState, uq_pred: bool, uq_prey: bool, a: float, g: float, e: float, ed: float, emin: float): state.cfg.use_q_pred = bool(uq_pred) state.cfg.use_q_prey = bool(uq_prey) state.cfg.alpha = float(a) state.cfg.gamma = float(g) state.metrics.epsilon = float(e) state.cfg.epsilon_decay = float(ed) state.cfg.epsilon_min = float(emin) return state def do_manual(state, hist, bel, r, act): tick(state, manual_action=act) hist = push_hist(state, hist) r = 
len(hist) - 1 out = refresh(state, hist, bel, r) return out + (state, hist, bel, r) def do_tick(state, hist, bel, r): tick(state, manual_action=None) hist = push_hist(state, hist) r = len(hist) - 1 out = refresh(state, hist, bel, r) return out + (state, hist, bel, r) def do_run(state, hist, bel, r, n): n = max(1, int(n)) for _ in range(n): if state.caught: break tick(state, manual_action=None) hist = push_hist(state, hist) r = len(hist) - 1 out = refresh(state, hist, bel, r) return out + (state, hist, bel, r) def toggle_control(state, hist, bel, r): order = ["Predator", "Prey", "Scout"] i = order.index(state.controlled) state.controlled = order[(i + 1) % len(order)] state.event_log.append(f"Controlled -> {state.controlled}") hist = push_hist(state, hist) r = len(hist) - 1 out = refresh(state, hist, bel, r) return out + (state, hist, bel, r) def toggle_pov(state, hist, bel, r): order = ["Predator", "Prey", "Scout"] i = order.index(state.pov) state.pov = order[(i + 1) % len(order)] state.event_log.append(f"POV -> {state.pov}") hist = push_hist(state, hist) r = len(hist) - 1 out = refresh(state, hist, bel, r) return out + (state, hist, bel, r) def set_overlay(state, hist, bel, r, ov): state.overlay = bool(ov) out = refresh(state, hist, bel, r) return out + (state, hist, bel, r) def click_truth(tile, state, hist, bel, r, evt: gr.SelectData): state = grid_click_to_tile(evt, int(tile), state) hist = push_hist(state, hist) r = len(hist) - 1 out = refresh(state, hist, bel, r) return out + (state, hist, bel, r) def jump(state, hist, bel, r, idx): if not hist: out = refresh(state, hist, bel, r) return out + (state, hist, bel, r) idx = max(0, min(int(idx), len(hist) - 1)) restore_into(state, hist[idx]) r = idx out = refresh(state, hist, bel, r) return out + (state, hist, bel, r) def reset_ep(state, hist, bel, r): reset_episode(state, seed=state.seed) hist = [snapshot_of(state)] r = 0 bel = init_belief() out = refresh(state, hist, bel, r) return out + (state, hist, bel, r) def 
reset_all(state, hist, bel, r): seed = state.seed state = init_state(seed) hist = [snapshot_of(state)] bel = init_belief() r = 0 out = refresh(state, hist, bel, r) return out + (state, hist, bel, r) def do_train(state, hist, bel, r, uq_pred, uq_prey, a, g, e, ed, emin, eps_count, max_s): state = set_cfg(state, uq_pred, uq_prey, a, g, e, ed, emin) train(state, episodes=max(1, int(eps_count)), max_steps=max(10, int(max_s))) # After training, reset to a clean episode so user sees improved behavior reset_episode(state, seed=state.seed) hist = [snapshot_of(state)] bel = init_belief() r = 0 out = refresh(state, hist, bel, r) return out + (state, hist, bel, r) def export_fn(state, hist): return export_run(state, hist) def import_fn(txt): state, hist, bel, r = import_run(txt) pov, tr, ba, bb, stxt, etxt, ttxt = build_views(state, bel) r_max = max(0, len(hist) - 1) return ( pov, tr, ba, bb, stxt, etxt, ttxt, gr.update(maximum=r_max, value=r), state, hist, bel, r ) # --- Wire buttons (no fn_kwargs; use lambdas) --- btn_L.click(lambda s,h,b,r: do_manual(s,h,b,r,"L"), inputs=[st, history, beliefs, rewind_idx], outputs=[pov_img, truth, belief_a, belief_b, status, events, trace, rewind, rewind_idx, st, history, beliefs, rewind_idx], queue=True) btn_F.click(lambda s,h,b,r: do_manual(s,h,b,r,"F"), inputs=[st, history, beliefs, rewind_idx], outputs=[pov_img, truth, belief_a, belief_b, status, events, trace, rewind, rewind_idx, st, history, beliefs, rewind_idx], queue=True) btn_R.click(lambda s,h,b,r: do_manual(s,h,b,r,"R"), inputs=[st, history, beliefs, rewind_idx], outputs=[pov_img, truth, belief_a, belief_b, status, events, trace, rewind, rewind_idx, st, history, beliefs, rewind_idx], queue=True) btn_tick.click(do_tick, inputs=[st, history, beliefs, rewind_idx], outputs=[pov_img, truth, belief_a, belief_b, status, events, trace, rewind, rewind_idx, st, history, beliefs, rewind_idx], queue=True) btn_run.click(do_run, inputs=[st, history, beliefs, rewind_idx, run_steps], 
outputs=[pov_img, truth, belief_a, belief_b, status, events, trace, rewind, rewind_idx, st, history, beliefs, rewind_idx], queue=True) btn_toggle_control.click(toggle_control, inputs=[st, history, beliefs, rewind_idx], outputs=[pov_img, truth, belief_a, belief_b, status, events, trace, rewind, rewind_idx, st, history, beliefs, rewind_idx], queue=True) btn_toggle_pov.click(toggle_pov, inputs=[st, history, beliefs, rewind_idx], outputs=[pov_img, truth, belief_a, belief_b, status, events, trace, rewind, rewind_idx, st, history, beliefs, rewind_idx], queue=True) overlay.change(set_overlay, inputs=[st, history, beliefs, rewind_idx, overlay], outputs=[pov_img, truth, belief_a, belief_b, status, events, trace, rewind, rewind_idx, st, history, beliefs, rewind_idx], queue=True) truth.select(click_truth, inputs=[tile_pick, st, history, beliefs, rewind_idx], outputs=[pov_img, truth, belief_a, belief_b, status, events, trace, rewind, rewind_idx, st, history, beliefs, rewind_idx], queue=True) btn_jump.click(jump, inputs=[st, history, beliefs, rewind_idx, rewind], outputs=[pov_img, truth, belief_a, belief_b, status, events, trace, rewind, rewind_idx, st, history, beliefs, rewind_idx], queue=True) btn_reset.click(reset_ep, inputs=[st, history, beliefs, rewind_idx], outputs=[pov_img, truth, belief_a, belief_b, status, events, trace, rewind, rewind_idx, st, history, beliefs, rewind_idx], queue=True) btn_reset_all.click(reset_all, inputs=[st, history, beliefs, rewind_idx], outputs=[pov_img, truth, belief_a, belief_b, status, events, trace, rewind, rewind_idx, st, history, beliefs, rewind_idx], queue=True) btn_train.click(do_train, inputs=[st, history, beliefs, rewind_idx, use_q_pred, use_q_prey, alpha, gamma, eps, eps_decay, eps_min, episodes, max_steps], outputs=[pov_img, truth, belief_a, belief_b, status, events, trace, rewind, rewind_idx, st, history, beliefs, rewind_idx], queue=True) btn_export.click(export_fn, inputs=[st, history], outputs=[export_box], queue=True) 
# Import replaces state/history/beliefs/rewind index with what import_run
# produced. import_fn returns the seven rendered views, a slider update for
# `rewind`, then the four state values, so this outputs list has no second
# `rewind_idx` slot (unlike the other listeners' outputs lists).
_import_outputs = [
    pov_img, truth, belief_a, belief_b, status, events, trace,
    rewind,
    st, history, beliefs, rewind_idx,
]
btn_import.click(import_fn, inputs=[import_box], outputs=_import_outputs, queue=True)

# Initial render when the page loads: these nine components are the same ones
# every other listener fills from refresh()'s return value.
_view_outputs = [
    pov_img, truth, belief_a, belief_b, status, events, trace,
    rewind, rewind_idx,
]
demo.load(
    refresh,
    inputs=[st, history, beliefs, rewind_idx],
    outputs=_view_outputs,
    queue=True,
)

# NOTE(review): indentation is lost in this file; confirm this launch call sits
# where intended relative to the `with gr.Blocks()` context.
demo.queue().launch()