import json
import math
from dataclasses import dataclass, asdict
from typing import Dict, List, Tuple, Optional
import numpy as np
from PIL import Image, ImageDraw
import gradio as gr
# ============================================================
# ChronoSandbox++ — Instrumented Training Arena
# - Deterministic gridworld + first-person raycast view
# - Click-to-edit environment (tiles)
# - Full step trace: obs -> action -> reward -> q-update rationale
# - Optional Q-learning (tabular) for Predator + Prey
# - Batch training: run episodes fast, track metrics
# - Export/import: environment, history, Q-tables, metrics
#
# Compatibility: avoids fn_kwargs + avoids gr.Timer
# ============================================================
# -----------------------------
# Config
# -----------------------------
GRID_W, GRID_H = 21, 15
TILE = 22
VIEW_W, VIEW_H = 640, 360
RAY_W = 320
FOV_DEG = 78
MAX_DEPTH = 20
DIRS = [(1, 0), (0, 1), (-1, 0), (0, -1)]
ORI_DEG = [0, 90, 180, 270]
EMPTY = 0
WALL = 1
FOOD = 2
NOISE = 3
DOOR = 4
TELE = 5
TILE_NAMES = {
EMPTY: "Empty",
WALL: "Wall",
FOOD: "Food",
NOISE: "Noise",
DOOR: "Door",
TELE: "Teleporter",
}
AGENT_COLORS = {
"Predator": (255, 120, 90),
"Prey": (120, 255, 160),
"Scout": (120, 190, 255),
}
SKY = np.array([14, 16, 26], dtype=np.uint8)
FLOOR_NEAR = np.array([24, 26, 40], dtype=np.uint8)
FLOOR_FAR = np.array([10, 11, 18], dtype=np.uint8)
WALL_BASE = np.array([210, 210, 225], dtype=np.uint8)
WALL_SIDE = np.array([150, 150, 170], dtype=np.uint8)
DOOR_COL = np.array([180, 210, 255], dtype=np.uint8)
ACTIONS = ["L", "F", "R"] # keep small for tabular learning stability
# -----------------------------
# Deterministic RNG streams
# -----------------------------
def rng_for(seed: int, step: int, stream: int = 0) -> np.random.Generator:
mix = (seed * 1_000_003) ^ (step * 9_999_937) ^ (stream * 97_531)
return np.random.default_rng(mix & 0xFFFFFFFFFFFFFFFF)
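# Illustrative note: because the mix is a pure function of (seed, step, stream),
# the same triple always yields an identical generator, which is what makes
# replays and rewinds bit-for-bit reproducible, e.g.:
#   rng_for(1337, 5, stream=11).random() == rng_for(1337, 5, stream=11).random()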
# -----------------------------
# Data structures
# -----------------------------
@dataclass
class Agent:
name: str
x: int
y: int
ori: int
energy: int = 100
@dataclass
class TrainConfig:
use_q_pred: bool = True
use_q_prey: bool = True
alpha: float = 0.15
gamma: float = 0.95
epsilon: float = 0.10
epsilon_min: float = 0.02
epsilon_decay: float = 0.995
# reward shaping
pred_step_penalty: float = -0.02
pred_dist_coeff: float = 0.03
pred_catch_reward: float = 3.0
prey_step_penalty: float = -0.02
prey_food_reward: float = 0.6
prey_survive_reward: float = 0.02
prey_caught_penalty: float = -3.0
@dataclass
class Metrics:
episodes: int = 0
catches: int = 0
avg_steps_to_catch: float = 0.0
avg_path_efficiency: float = 0.0 # optimal / actual (0..1)
last_episode_steps: int = 0
last_episode_eff: float = 0.0
epsilon: float = 0.10
@dataclass
class WorldState:
seed: int
step: int
grid: List[List[int]]
agents: Dict[str, Agent]
controlled: str
pov: str
overlay: bool
caught: bool
branches: Dict[str, int]
# instrumentation
event_log: List[str]
trace_log: List[str] # more detailed step trace (bounded)
# training
cfg: TrainConfig
q_pred: Dict[str, List[float]]
q_prey: Dict[str, List[float]]
metrics: Metrics
@dataclass
class Snapshot:
step: int
agents: Dict[str, Dict]
grid: List[List[int]]
caught: bool
event_log_tail: List[str]
trace_tail: List[str]
# -----------------------------
# Environment
# -----------------------------
def default_grid() -> List[List[int]]:
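    """Bordered arena: a horizontal dividing wall with a door at (10, 7),
    two food tiles, two noise tiles, and a pair of teleporters."""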
g = [[EMPTY for _ in range(GRID_W)] for _ in range(GRID_H)]
for x in range(GRID_W):
g[0][x] = WALL
g[GRID_H - 1][x] = WALL
for y in range(GRID_H):
g[y][0] = WALL
g[y][GRID_W - 1] = WALL
for x in range(4, 17):
g[7][x] = WALL
g[7][10] = DOOR
g[3][4] = FOOD
g[11][15] = FOOD
g[4][14] = NOISE
g[12][5] = NOISE
g[2][18] = TELE
g[13][2] = TELE
return g
def init_state(seed: int) -> WorldState:
agents = {
"Predator": Agent("Predator", 2, 2, 0, 100),
"Prey": Agent("Prey", 18, 12, 2, 100),
"Scout": Agent("Scout", 10, 3, 1, 100),
}
cfg = TrainConfig()
return WorldState(
seed=seed,
step=0,
grid=default_grid(),
agents=agents,
controlled="Predator",
pov="Predator",
overlay=False,
caught=False,
branches={"main": 0},
event_log=["Initialized world."],
trace_log=[],
cfg=cfg,
q_pred={},
q_prey={},
metrics=Metrics(epsilon=cfg.epsilon),
)
# -----------------------------
# Belief maps
# -----------------------------
def init_belief() -> Dict[str, np.ndarray]:
    # -1 marks "never observed"; render_topdown draws it as a dark unknown tile.
    b = {}
    for nm in ["Predator", "Prey", "Scout"]:
        b[nm] = np.full((GRID_H, GRID_W), -1, dtype=np.int16)
    return b
# -----------------------------
# Helpers
# -----------------------------
def in_bounds(x: int, y: int) -> bool:
return 0 <= x < GRID_W and 0 <= y < GRID_H
def is_blocking(tile: int) -> bool:
return tile == WALL
def manhattan(a: Agent, b: Agent) -> int:
return abs(a.x - b.x) + abs(a.y - b.y)
def bresenham_los(grid: List[List[int]], x0: int, y0: int, x1: int, y1: int) -> bool:
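    """Return True if no WALL blocks the straight line between the two cells.

    Both endpoint cells are deliberately skipped, so observer and target
    are never occluded by the cells they themselves occupy.
    """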
dx = abs(x1 - x0)
dy = abs(y1 - y0)
sx = 1 if x0 < x1 else -1
sy = 1 if y0 < y1 else -1
err = dx - dy
x, y = x0, y0
while True:
if (x, y) != (x0, y0) and (x, y) != (x1, y1):
if grid[y][x] == WALL:
return False
if x == x1 and y == y1:
return True
e2 = 2 * err
if e2 > -dy:
err -= dy
x += sx
if e2 < dx:
err += dx
y += sy
def within_fov(observer: Agent, tx: int, ty: int, fov_deg: float = FOV_DEG) -> bool:
dx = tx - observer.x
dy = ty - observer.y
if dx == 0 and dy == 0:
return True
angle = math.degrees(math.atan2(dy, dx)) % 360
facing = ORI_DEG[observer.ori]
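    # Wrap the signed angle difference into [-180, 180) so the FOV test is
    # symmetric about the facing direction.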
diff = (angle - facing + 540) % 360 - 180
return abs(diff) <= (fov_deg / 2)
def visible(observer: Agent, target: Agent, grid: List[List[int]]) -> bool:
return within_fov(observer, target.x, target.y, FOV_DEG) and bresenham_los(grid, observer.x, observer.y, target.x, target.y)
# -----------------------------
# Movement
# -----------------------------
def turn_left(a: Agent) -> None:
a.ori = (a.ori - 1) % 4
def turn_right(a: Agent) -> None:
a.ori = (a.ori + 1) % 4
def move_forward(state: WorldState, a: Agent) -> str:
dx, dy = DIRS[a.ori]
nx, ny = a.x + dx, a.y + dy
if not in_bounds(nx, ny):
return "blocked: bounds"
if is_blocking(state.grid[ny][nx]):
return "blocked: wall"
if state.grid[ny][nx] == DOOR:
state.grid[ny][nx] = EMPTY
state.event_log.append(f"t={state.step}: {a.name} opened a door.")
a.x, a.y = nx, ny
if state.grid[ny][nx] == TELE:
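        # Teleporter pads are paired by sorted (x, y) order: landing on one
        # sends the agent to the next pad in that ordering (a swap with 2 pads).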
teles = [(x, y) for y in range(GRID_H) for x in range(GRID_W) if state.grid[y][x] == TELE]
if len(teles) >= 2:
teles_sorted = sorted(teles)
idx = teles_sorted.index((nx, ny))
dest = teles_sorted[(idx + 1) % len(teles_sorted)]
a.x, a.y = dest
state.event_log.append(f"t={state.step}: {a.name} teleported.")
return "moved: teleported"
return "moved"
def apply_action(state: WorldState, agent_name: str, action: str) -> str:
a = state.agents[agent_name]
if action == "L":
turn_left(a)
return "turned left"
if action == "R":
turn_right(a)
return "turned right"
if action == "F":
return move_forward(state, a)
return "noop"
# -----------------------------
# Rendering
# -----------------------------
def raycast_view(state: WorldState, observer: Agent) -> np.ndarray:
img = np.zeros((VIEW_H, VIEW_W, 3), dtype=np.uint8)
img[:, :] = SKY
for y in range(VIEW_H // 2, VIEW_H):
t = (y - VIEW_H // 2) / (VIEW_H // 2 + 1e-6)
col = (1 - t) * FLOOR_NEAR + t * FLOOR_FAR
img[y, :] = col.astype(np.uint8)
fov = math.radians(FOV_DEG)
half_fov = fov / 2
for rx in range(RAY_W):
cam_x = (2 * rx / (RAY_W - 1)) - 1
ray_ang = math.radians(ORI_DEG[observer.ori]) + cam_x * half_fov
ox, oy = observer.x + 0.5, observer.y + 0.5
sin_a = math.sin(ray_ang)
cos_a = math.cos(ray_ang)
depth = 0.0
hit = None # None, "wall", "door"
side = 0
while depth < MAX_DEPTH:
depth += 0.05
tx = int(ox + cos_a * depth)
ty = int(oy + sin_a * depth)
if not in_bounds(tx, ty):
break
tile = state.grid[ty][tx]
if tile == WALL:
hit = "wall"
side = 1 if abs(cos_a) > abs(sin_a) else 0
break
if tile == DOOR:
hit = "door"
break
if hit is None:
continue
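        # Fisheye correction: scale the hit distance by the cosine of the
        # ray's angle from the view axis so walls render as flat slabs.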
depth *= math.cos(ray_ang - math.radians(ORI_DEG[observer.ori]))
depth = max(depth, 0.001)
proj_h = int((VIEW_H * 0.9) / depth)
y0 = max(0, VIEW_H // 2 - proj_h // 2)
y1 = min(VIEW_H - 1, VIEW_H // 2 + proj_h // 2)
if hit == "door":
col = DOOR_COL.copy()
else:
col = WALL_BASE.copy() if side == 0 else WALL_SIDE.copy()
dim = max(0.25, 1.0 - (depth / MAX_DEPTH))
col = (col * dim).astype(np.uint8)
x0 = int(rx * (VIEW_W / RAY_W))
x1 = int((rx + 1) * (VIEW_W / RAY_W))
img[y0:y1, x0:x1] = col
# billboards for visible agents
for nm, other in state.agents.items():
if nm == observer.name:
continue
if visible(observer, other, state.grid):
dx = other.x - observer.x
dy = other.y - observer.y
ang = (math.degrees(math.atan2(dy, dx)) % 360)
facing = ORI_DEG[observer.ori]
diff = (ang - facing + 540) % 360 - 180
sx = int((diff / (FOV_DEG / 2)) * (VIEW_W / 2) + (VIEW_W / 2))
dist = math.sqrt(dx * dx + dy * dy)
h = int((VIEW_H * 0.65) / max(dist, 0.75))
w = max(10, h // 3)
y_mid = VIEW_H // 2
y0 = max(0, y_mid - h // 2)
y1 = min(VIEW_H - 1, y_mid + h // 2)
x0 = max(0, sx - w // 2)
x1 = min(VIEW_W - 1, sx + w // 2)
col = AGENT_COLORS.get(nm, (255, 200, 120))
img[y0:y1, x0:x1] = np.array(col, dtype=np.uint8)
if state.overlay:
cx, cy = VIEW_W // 2, VIEW_H // 2
img[cy - 1:cy + 2, cx - 10:cx + 10] = np.array([120, 190, 255], dtype=np.uint8)
img[cy - 10:cy + 10, cx - 1:cx + 2] = np.array([120, 190, 255], dtype=np.uint8)
return img
def render_topdown(grid: np.ndarray, agents: Dict[str, Agent], title: str, show_agents: bool = True) -> Image.Image:
w = grid.shape[1] * TILE
h = grid.shape[0] * TILE
im = Image.new("RGB", (w, h + 28), (10, 12, 18))
draw = ImageDraw.Draw(im)
for y in range(grid.shape[0]):
for x in range(grid.shape[1]):
t = int(grid[y, x])
if t == -1:
col = (18, 20, 32)
elif t == EMPTY:
col = (26, 30, 44)
elif t == WALL:
col = (190, 190, 210)
elif t == FOOD:
col = (255, 210, 120)
elif t == NOISE:
col = (255, 120, 220)
elif t == DOOR:
col = (140, 210, 255)
elif t == TELE:
col = (120, 190, 255)
else:
col = (80, 80, 90)
x0, y0 = x * TILE, y * TILE + 28
draw.rectangle([x0, y0, x0 + TILE - 1, y0 + TILE - 1], fill=col)
for x in range(grid.shape[1] + 1):
xx = x * TILE
draw.line([xx, 28, xx, h + 28], fill=(12, 14, 22))
for y in range(grid.shape[0] + 1):
yy = y * TILE + 28
draw.line([0, yy, w, yy], fill=(12, 14, 22))
if show_agents:
for nm, a in agents.items():
cx = a.x * TILE + TILE // 2
cy = a.y * TILE + 28 + TILE // 2
col = AGENT_COLORS.get(nm, (220, 220, 220))
r = TILE // 3
draw.ellipse([cx - r, cy - r, cx + r, cy + r], fill=col)
dx, dy = DIRS[a.ori]
draw.line([cx, cy, cx + dx * r, cy + dy * r], fill=(10, 10, 10), width=3)
draw.rectangle([0, 0, w, 28], fill=(14, 16, 26))
draw.text((8, 6), title, fill=(230, 230, 240))
return im
# -----------------------------
# Belief updates
# -----------------------------
def update_belief_for_agent(state: WorldState, belief: np.ndarray, agent: Agent) -> None:
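    """Cast a fan of rays from the agent and stamp every tile it can see into
    its belief map; the Scout gets a denser fan (45 rays vs. 33)."""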
belief[agent.y, agent.x] = state.grid[agent.y][agent.x]
base = math.radians(ORI_DEG[agent.ori])
half = math.radians(FOV_DEG / 2)
rays = 33 if agent.name != "Scout" else 45
for i in range(rays):
t = i / (rays - 1)
ang = base + (t * 2 - 1) * half
sin_a, cos_a = math.sin(ang), math.cos(ang)
ox, oy = agent.x + 0.5, agent.y + 0.5
depth = 0.0
while depth < MAX_DEPTH:
depth += 0.2
tx = int(ox + cos_a * depth)
ty = int(oy + sin_a * depth)
if not in_bounds(tx, ty):
break
belief[ty, tx] = state.grid[ty][tx]
if state.grid[ty][tx] == WALL:
break
# -----------------------------
# Optimal distance (BFS) for efficiency metric
# -----------------------------
def bfs_distance(grid: List[List[int]], sx: int, sy: int, gx: int, gy: int) -> Optional[int]:
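    """Shortest 4-connected path length from (sx, sy) to (gx, gy) avoiding
    WALL tiles, or None if the goal is unreachable."""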
if (sx, sy) == (gx, gy):
return 0
    q = [(sx, sy)]
    dist = {(sx, sy): 0}
    head = 0  # index-based queue avoids O(n) pops from the front of a list
    while head < len(q):
        x, y = q[head]
        head += 1
for dx, dy in DIRS:
nx, ny = x + dx, y + dy
if not in_bounds(nx, ny):
continue
if grid[ny][nx] == WALL:
continue
if (nx, ny) in dist:
continue
dist[(nx, ny)] = dist[(x, y)] + 1
if (nx, ny) == (gx, gy):
return dist[(nx, ny)]
q.append((nx, ny))
return None
# -----------------------------
# Observation encoding (compact state key)
# -----------------------------
def obs_key(state: WorldState, who: str) -> str:
pred = state.agents["Predator"]
prey = state.agents["Prey"]
a = state.agents[who]
# relative position coarse-binned to keep table smaller
dx = prey.x - pred.x
dy = prey.y - pred.y
dx_bin = int(np.clip(dx, -6, 6))
dy_bin = int(np.clip(dy, -6, 6))
vis = 1 if visible(pred, prey, state.grid) else 0
# include own orientation and role
if who == "Predator":
return f"P|{pred.x},{pred.y},{pred.ori}|d{dx_bin},{dy_bin}|v{vis}"
if who == "Prey":
# prey cares if predator is visible to it
vis2 = 1 if visible(prey, pred, state.grid) else 0
ddx = pred.x - prey.x
ddy = pred.y - prey.y
ddx_bin = int(np.clip(ddx, -6, 6))
ddy_bin = int(np.clip(ddy, -6, 6))
return f"R|{prey.x},{prey.y},{prey.ori}|d{ddx_bin},{ddy_bin}|v{vis2}|e{int(prey.energy//25)}"
# Scout: simple
return f"S|{a.x},{a.y},{a.ori}"
def q_get(q: Dict[str, List[float]], key: str) -> List[float]:
if key not in q:
q[key] = [0.0, 0.0, 0.0]
return q[key]
def epsilon_greedy(qvals: List[float], eps: float, r: np.random.Generator) -> int:
if r.random() < eps:
return int(r.integers(0, len(qvals)))
return int(np.argmax(qvals))
def q_update(q: Dict[str, List[float]], key: str, a_idx: int, reward: float, next_key: str, alpha: float, gamma: float) -> Tuple[float, float, float]:
qv = q_get(q, key)
nq = q_get(q, next_key)
old = qv[a_idx]
target = reward + gamma * float(np.max(nq))
new = old + alpha * (target - old)
qv[a_idx] = new
return old, target, new
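# The update above is the standard tabular Q-learning rule,
#   Q(s, a) <- Q(s, a) + alpha * (r + gamma * max_a' Q(s', a') - Q(s, a)),
# returned as (old, target, new) so the step trace can show exactly how and
# why each value moved.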
# -----------------------------
# Baseline heuristic policies (for Scout + fallback)
# -----------------------------
def heuristic_pred_action(state: WorldState) -> str:
pred = state.agents["Predator"]
prey = state.agents["Prey"]
if visible(pred, prey, state.grid):
dx = prey.x - pred.x
dy = prey.y - pred.y
ang = (math.degrees(math.atan2(dy, dx)) % 360)
facing = ORI_DEG[pred.ori]
diff = (ang - facing + 540) % 360 - 180
if diff < -10:
return "L"
if diff > 10:
return "R"
return "F"
r = rng_for(state.seed, state.step, stream=11)
return r.choice(ACTIONS)
def heuristic_prey_action(state: WorldState) -> str:
prey = state.agents["Prey"]
pred = state.agents["Predator"]
if visible(prey, pred, state.grid):
dx = pred.x - prey.x
dy = pred.y - prey.y
ang = (math.degrees(math.atan2(dy, dx)) % 360)
facing = ORI_DEG[prey.ori]
diff = (ang - facing + 540) % 360 - 180
diff_away = ((diff + 180) + 540) % 360 - 180
if diff_away < -10:
return "L"
if diff_away > 10:
return "R"
return "F"
r = rng_for(state.seed, state.step, stream=12)
return r.choice(ACTIONS)
def heuristic_scout_action(state: WorldState) -> str:
r = rng_for(state.seed, state.step, stream=13)
return r.choice(ACTIONS)
# -----------------------------
# Reward shaping
# -----------------------------
def pred_reward(state_prev: WorldState, state_now: WorldState) -> float:
cfg = state_now.cfg
pred0 = state_prev.agents["Predator"]
prey0 = state_prev.agents["Prey"]
pred1 = state_now.agents["Predator"]
prey1 = state_now.agents["Prey"]
d0 = abs(pred0.x - prey0.x) + abs(pred0.y - prey0.y)
d1 = abs(pred1.x - prey1.x) + abs(pred1.y - prey1.y)
r = cfg.pred_step_penalty + cfg.pred_dist_coeff * (d0 - d1) # reward closing distance
if state_now.caught:
r += cfg.pred_catch_reward
return float(r)
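# Note: the (d0 - d1) term rewards any step that closes the Manhattan gap,
# akin to potential-based shaping; the small step penalty discourages idling,
# while the catch bonus dominates terminal states.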
def prey_reward(state_prev: WorldState, state_now: WorldState, ate_food: bool) -> float:
cfg = state_now.cfg
r = cfg.prey_step_penalty + cfg.prey_survive_reward
if ate_food:
r += cfg.prey_food_reward
if state_now.caught:
r += cfg.prey_caught_penalty
return float(r)
# -----------------------------
# Core simulation tick (with instrumentation + optional learning)
# -----------------------------
TRACE_MAX = 400
def clone_shallow(state: WorldState) -> WorldState:
    # Working copy for reward computation: grid and agents are deep-copied,
    # while cfg, Q-tables, and metrics remain shared by reference.
return WorldState(
seed=state.seed,
step=state.step,
grid=[row[:] for row in state.grid],
agents={k: Agent(**asdict(v)) for k, v in state.agents.items()},
controlled=state.controlled,
pov=state.pov,
overlay=state.overlay,
caught=state.caught,
branches=dict(state.branches),
event_log=list(state.event_log),
trace_log=list(state.trace_log),
cfg=state.cfg,
q_pred=state.q_pred,
q_prey=state.q_prey,
metrics=state.metrics,
)
def check_catch(state: WorldState) -> None:
pred = state.agents["Predator"]
prey = state.agents["Prey"]
if pred.x == prey.x and pred.y == prey.y:
state.caught = True
state.event_log.append(f"t={state.step}: CAUGHT.")
def consume_food(state: WorldState) -> bool:
prey = state.agents["Prey"]
if state.grid[prey.y][prey.x] == FOOD:
prey.energy = min(200, prey.energy + 35)
state.grid[prey.y][prey.x] = EMPTY
state.event_log.append(f"t={state.step}: Prey ate food (+energy).")
return True
return False
def choose_action(state: WorldState, who: str, stream: int) -> Tuple[str, str, Optional[Tuple[str,int]]]:
"""
Returns (action, reason, q_info)
q_info: (obs_key, action_index) if chosen by Q, else None
"""
cfg = state.cfg
r = rng_for(state.seed, state.step, stream=stream)
if who == "Predator" and cfg.use_q_pred:
k = obs_key(state, "Predator")
qv = q_get(state.q_pred, k)
a_idx = epsilon_greedy(qv, state.metrics.epsilon, r)
return ACTIONS[a_idx], f"Q(pred) eps={state.metrics.epsilon:.3f} q={np.round(qv,3).tolist()}", (k, a_idx)
if who == "Prey" and cfg.use_q_prey:
k = obs_key(state, "Prey")
qv = q_get(state.q_prey, k)
a_idx = epsilon_greedy(qv, state.metrics.epsilon, r)
return ACTIONS[a_idx], f"Q(prey) eps={state.metrics.epsilon:.3f} q={np.round(qv,3).tolist()}", (k, a_idx)
# fallbacks
if who == "Predator":
a = heuristic_pred_action(state)
return a, "heuristic(pred)", None
if who == "Prey":
a = heuristic_prey_action(state)
return a, "heuristic(prey)", None
a = heuristic_scout_action(state)
return a, "heuristic(scout)", None
def tick(state: WorldState, manual_action: Optional[str] = None) -> None:
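    """Advance the world one step: select actions (manual / Q / heuristic),
    apply them in a fixed agent order, resolve food and catches, compute
    shaped rewards, run any Q-updates, and append a bounded trace line."""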
if state.caught:
return
prev = clone_shallow(state)
# record optimal distance for efficiency stats
pred = state.agents["Predator"]
prey = state.agents["Prey"]
opt_dist = bfs_distance(state.grid, pred.x, pred.y, prey.x, prey.y)
if opt_dist is None:
opt_dist = 999
# Action selection
chosen = {}
reasons = {}
qinfo = {}
# manual action applies to controlled agent
if manual_action:
chosen[state.controlled] = manual_action
reasons[state.controlled] = "manual"
qinfo[state.controlled] = None
# others choose
for who in ["Predator", "Prey", "Scout"]:
if who in chosen:
continue
act, reason, q_i = choose_action(state, who, stream={"Predator":21,"Prey":22,"Scout":23}[who])
chosen[who] = act
reasons[who] = reason
qinfo[who] = q_i
# Apply actions (deterministic order)
outcomes = {}
for who in ["Predator", "Prey", "Scout"]:
outcomes[who] = apply_action(state, who, chosen[who])
ate = consume_food(state)
check_catch(state)
# Rewards + Q-updates
pred_r = pred_reward(prev, state)
prey_r = prey_reward(prev, state, ate_food=ate)
q_lines = []
if qinfo["Predator"] is not None:
k, a_idx = qinfo["Predator"]
nk = obs_key(state, "Predator")
old, target, new = q_update(state.q_pred, k, a_idx, pred_r, nk, state.cfg.alpha, state.cfg.gamma)
q_lines.append(f"Qpred: {k} a={ACTIONS[a_idx]} old={old:.3f} tgt={target:.3f} new={new:.3f}")
if qinfo["Prey"] is not None:
k, a_idx = qinfo["Prey"]
nk = obs_key(state, "Prey")
old, target, new = q_update(state.q_prey, k, a_idx, prey_r, nk, state.cfg.alpha, state.cfg.gamma)
q_lines.append(f"Qprey: {k} a={ACTIONS[a_idx]} old={old:.3f} tgt={target:.3f} new={new:.3f}")
# Trace line
    dist_now = manhattan(state.agents["Predator"], state.agents["Prey"])
trace = (
f"t={state.step} optDist~{opt_dist} distNow={dist_now} "
f"| Pred:{chosen['Predator']} ({outcomes['Predator']}) [{reasons['Predator']}] r={pred_r:+.3f} "
f"| Prey:{chosen['Prey']} ({outcomes['Prey']}) [{reasons['Prey']}] r={prey_r:+.3f} "
f"| Scout:{chosen['Scout']} ({outcomes['Scout']}) [{reasons['Scout']}] "
f"| ateFood={ate} caught={state.caught}"
)
if q_lines:
trace += " | " + " ; ".join(q_lines)
state.trace_log.append(trace)
if len(state.trace_log) > TRACE_MAX:
state.trace_log = state.trace_log[-TRACE_MAX:]
state.step += 1
# -----------------------------
# Episode reset + training
# -----------------------------
def reset_episode(state: WorldState, seed: Optional[int] = None) -> None:
# Keep Q-tables + cfg + metrics; reset world + logs
if seed is None:
seed = state.seed
fresh = init_state(seed)
fresh.cfg = state.cfg
fresh.q_pred = state.q_pred
fresh.q_prey = state.q_prey
fresh.metrics = state.metrics
state.seed = fresh.seed
state.step = 0
state.grid = fresh.grid
state.agents = fresh.agents
state.controlled = fresh.controlled
state.pov = fresh.pov
state.overlay = fresh.overlay
state.caught = False
state.branches = fresh.branches
state.event_log = ["Episode reset."]
state.trace_log = []
def run_episode(state: WorldState, max_steps: int) -> Tuple[bool, int, float]:
# returns (caught, steps, path_eff)
start_pred = state.agents["Predator"]
start_prey = state.agents["Prey"]
opt = bfs_distance(state.grid, start_pred.x, start_pred.y, start_prey.x, start_prey.y)
if opt is None:
opt = 999
steps = 0
while steps < max_steps and not state.caught:
tick(state, manual_action=None)
steps += 1
caught = state.caught
eff = float(opt / max(1, steps)) if opt < 999 else 0.0
return caught, steps, eff
def train(state: WorldState, episodes: int, max_steps: int) -> None:
m = state.metrics
cfg = state.cfg
catches = 0
total_steps_catch = 0
total_eff = 0.0
for ep in range(episodes):
        # deterministically vary the episode seed so exploration noise differs
        # across episodes (the map layout itself is fixed by default_grid)
ep_seed = (state.seed * 1_000_003 + (m.episodes + ep) * 97_531) & 0xFFFFFFFF
reset_episode(state, seed=int(ep_seed))
caught, steps, eff = run_episode(state, max_steps=max_steps)
total_eff += eff
if caught:
catches += 1
total_steps_catch += steps
# epsilon decay
m.epsilon = max(cfg.epsilon_min, m.epsilon * cfg.epsilon_decay)
# Update metrics
m.episodes += episodes
m.catches += catches
m.last_episode_steps = steps
m.last_episode_eff = eff
if catches > 0:
# moving average by episode count for stability
avg_steps = total_steps_catch / catches
m.avg_steps_to_catch = (
0.85 * m.avg_steps_to_catch + 0.15 * avg_steps
if m.avg_steps_to_catch > 0 else avg_steps
)
avg_eff = total_eff / max(1, episodes)
m.avg_path_efficiency = (
0.85 * m.avg_path_efficiency + 0.15 * avg_eff
if m.avg_path_efficiency > 0 else avg_eff
)
state.event_log.append(
f"Training: +{episodes} eps | catches={catches}/{episodes} | "
f"avgStepsToCatch~{m.avg_steps_to_catch:.2f} | avgEff~{m.avg_path_efficiency:.2f} | eps={m.epsilon:.3f}"
)
# -----------------------------
# History / snapshots
# -----------------------------
MAX_HISTORY = 1200
def snapshot_of(state: WorldState) -> Snapshot:
return Snapshot(
step=state.step,
agents={k: asdict(v) for k, v in state.agents.items()},
grid=[row[:] for row in state.grid],
caught=state.caught,
event_log_tail=state.event_log[-20:],
trace_tail=state.trace_log[-40:],
)
def restore_into(state: WorldState, snap: Snapshot) -> None:
state.step = snap.step
state.grid = [row[:] for row in snap.grid]
for k, d in snap.agents.items():
state.agents[k] = Agent(**d)
state.caught = snap.caught
state.event_log.append(f"Jumped to snapshot t={snap.step}.")
# -----------------------------
# Export / import
# -----------------------------
def export_run(state: WorldState, history: List[Snapshot]) -> str:
payload = {
"seed": state.seed,
"controlled": state.controlled,
"pov": state.pov,
"overlay": state.overlay,
"cfg": asdict(state.cfg),
"metrics": asdict(state.metrics),
"q_pred": state.q_pred,
"q_prey": state.q_prey,
"history": [asdict(s) for s in history],
"grid": state.grid,
}
return json.dumps(payload, indent=2)
def import_run(txt: str) -> Tuple[WorldState, List[Snapshot], Dict[str, np.ndarray], int]:
data = json.loads(txt)
st = init_state(int(data.get("seed", 1337)))
st.controlled = data.get("controlled", st.controlled)
st.pov = data.get("pov", st.pov)
st.overlay = bool(data.get("overlay", False))
st.grid = data.get("grid", st.grid)
st.cfg = TrainConfig(**data.get("cfg", asdict(st.cfg)))
st.metrics = Metrics(**data.get("metrics", asdict(st.metrics)))
st.q_pred = data.get("q_pred", {})
st.q_prey = data.get("q_prey", {})
hist = [Snapshot(**s) for s in data.get("history", [])]
bel = init_belief()
r_idx = max(0, len(hist) - 1)
if hist:
restore_into(st, hist[-1])
st.event_log.append("Imported run.")
return st, hist, bel, r_idx
# -----------------------------
# UI glue
# -----------------------------
def build_views(state: WorldState, beliefs: Dict[str, np.ndarray]) -> Tuple[np.ndarray, Image.Image, Image.Image, Image.Image, str, str, str]:
for nm, a in state.agents.items():
update_belief_for_agent(state, beliefs[nm], a)
pov = raycast_view(state, state.agents[state.pov])
truth_np = np.array(state.grid, dtype=np.int16)
truth_img = render_topdown(truth_np, state.agents, f"Truth Map — t={state.step} seed={state.seed}", show_agents=True)
ctrl = state.controlled
other = "Prey" if ctrl == "Predator" else "Predator"
b_ctrl = render_topdown(beliefs[ctrl], state.agents, f"{ctrl} Belief", show_agents=True)
b_other = render_topdown(beliefs[other], state.agents, f"{other} Belief", show_agents=True)
m = state.metrics
pred = state.agents["Predator"]
prey = state.agents["Prey"]
scout = state.agents["Scout"]
status = (
f"Controlled={state.controlled} | POV={state.pov} | caught={state.caught} | eps={m.epsilon:.3f}\n"
f"Episodes={m.episodes} | catches={m.catches} | avgStepsToCatch~{m.avg_steps_to_catch:.2f} | avgEff~{m.avg_path_efficiency:.2f}\n"
f"Pred({pred.x},{pred.y}) o={pred.ori} | Prey({prey.x},{prey.y}) o={prey.ori} e={prey.energy} | Scout({scout.x},{scout.y}) o={scout.ori}"
)
events = "\n".join(state.event_log[-18:])
trace = "\n".join(state.trace_log[-18:])
return pov, truth_img, b_ctrl, b_other, status, events, trace
def grid_click_to_tile(evt: gr.SelectData, selected_tile: int, state: WorldState) -> WorldState:
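    """Paint the selected tile type at the clicked cell. evt.index is in image
    pixels, so the 28 px title bar is subtracted first; border cells stay
    walls and cannot be edited."""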
x_px, y_px = evt.index
y_px -= 28
if y_px < 0:
return state
gx = int(x_px // TILE)
gy = int(y_px // TILE)
if not in_bounds(gx, gy):
return state
if gx == 0 or gy == 0 or gx == GRID_W - 1 or gy == GRID_H - 1:
return state
state.grid[gy][gx] = selected_tile
state.event_log.append(f"t={state.step}: Tile ({gx},{gy}) -> {TILE_NAMES.get(selected_tile)}")
return state
# -----------------------------
# Gradio App
# -----------------------------
with gr.Blocks(title="Agent POV") as demo:
gr.Markdown(
"## Agent-POV by ZEN AI Co.\n"
"Track every interaction, train policies, and audit why outcomes happened.\n"
"No timers (compatibility). Use Tick/Run/Train for controlled experiments."
)
st = gr.State(init_state(1337))
history = gr.State([snapshot_of(init_state(1337))])
beliefs = gr.State(init_belief())
rewind_idx = gr.State(0)
with gr.Row():
pov_img = gr.Image(label="POV (Pseudo-3D)", type="numpy", width=VIEW_W, height=VIEW_H)
with gr.Column():
status = gr.Textbox(label="Status + Metrics", lines=4)
events = gr.Textbox(label="Event Log", lines=10)
trace = gr.Textbox(label="Step Trace (why it happened)", lines=10)
with gr.Row():
truth = gr.Image(label="Truth Map (click to edit tiles)", type="pil")
belief_a = gr.Image(label="Belief (Controlled)", type="pil")
belief_b = gr.Image(label="Belief (Other)", type="pil")
with gr.Row():
with gr.Column(scale=2):
gr.Markdown("### Manual Controls")
with gr.Row():
btn_L = gr.Button("L")
btn_F = gr.Button("F")
btn_R = gr.Button("R")
with gr.Row():
btn_tick = gr.Button("Tick")
run_steps = gr.Number(value=25, label="Run N steps", precision=0)
btn_run = gr.Button("Run")
with gr.Row():
btn_toggle_control = gr.Button("Toggle Controlled")
btn_toggle_pov = gr.Button("Toggle POV")
overlay = gr.Checkbox(False, label="Overlay reticle")
tile_pick = gr.Radio(
choices=[(TILE_NAMES[k], k) for k in [EMPTY, WALL, FOOD, NOISE, DOOR, TELE]],
value=WALL,
label="Paint tile type"
)
with gr.Column(scale=3):
gr.Markdown("### Training Controls (Q-learning)")
use_q_pred = gr.Checkbox(True, label="Use Q-learning: Predator")
use_q_prey = gr.Checkbox(True, label="Use Q-learning: Prey")
alpha = gr.Slider(0.01, 0.5, value=0.15, step=0.01, label="alpha (learn rate)")
gamma = gr.Slider(0.5, 0.99, value=0.95, step=0.01, label="gamma (discount)")
eps = gr.Slider(0.0, 0.5, value=0.10, step=0.01, label="epsilon (exploration)")
eps_decay = gr.Slider(0.90, 0.999, value=0.995, step=0.001, label="epsilon decay")
eps_min = gr.Slider(0.0, 0.2, value=0.02, step=0.01, label="epsilon min")
episodes = gr.Number(value=50, label="Train episodes", precision=0)
max_steps = gr.Number(value=250, label="Max steps per episode", precision=0)
btn_train = gr.Button("Train")
btn_reset = gr.Button("Reset Episode")
btn_reset_all = gr.Button("Reset ALL (wipe Q + metrics)")
with gr.Row():
with gr.Column():
rewind = gr.Slider(0, 0, value=0, step=1, label="Rewind (history index)")
btn_jump = gr.Button("Jump")
with gr.Column():
export_box = gr.Textbox(label="Export JSON", lines=10)
btn_export = gr.Button("Export")
with gr.Column():
import_box = gr.Textbox(label="Import JSON", lines=10)
btn_import = gr.Button("Import")
def refresh(state: WorldState, hist: List[Snapshot], bel: Dict[str, np.ndarray], r: int):
r_max = max(0, len(hist) - 1)
r = max(0, min(int(r), r_max))
pov, tr, ba, bb, stxt, etxt, ttxt = build_views(state, bel)
return (
pov, tr, ba, bb,
stxt, etxt, ttxt,
gr.update(maximum=r_max, value=r),
r
)
def push_hist(state: WorldState, hist: List[Snapshot]) -> List[Snapshot]:
hist.append(snapshot_of(state))
if len(hist) > MAX_HISTORY:
hist.pop(0)
return hist
def set_cfg(state: WorldState, uq_pred: bool, uq_prey: bool, a: float, g: float, e: float, ed: float, emin: float):
state.cfg.use_q_pred = bool(uq_pred)
state.cfg.use_q_prey = bool(uq_prey)
state.cfg.alpha = float(a)
state.cfg.gamma = float(g)
state.metrics.epsilon = float(e)
state.cfg.epsilon_decay = float(ed)
state.cfg.epsilon_min = float(emin)
return state
def do_manual(state, hist, bel, r, act):
tick(state, manual_action=act)
hist = push_hist(state, hist)
r = len(hist) - 1
out = refresh(state, hist, bel, r)
return out + (state, hist, bel, r)
def do_tick(state, hist, bel, r):
tick(state, manual_action=None)
hist = push_hist(state, hist)
r = len(hist) - 1
out = refresh(state, hist, bel, r)
return out + (state, hist, bel, r)
def do_run(state, hist, bel, r, n):
n = max(1, int(n))
for _ in range(n):
if state.caught:
break
tick(state, manual_action=None)
hist = push_hist(state, hist)
r = len(hist) - 1
out = refresh(state, hist, bel, r)
return out + (state, hist, bel, r)
def toggle_control(state, hist, bel, r):
order = ["Predator", "Prey", "Scout"]
i = order.index(state.controlled)
state.controlled = order[(i + 1) % len(order)]
state.event_log.append(f"Controlled -> {state.controlled}")
hist = push_hist(state, hist)
r = len(hist) - 1
out = refresh(state, hist, bel, r)
return out + (state, hist, bel, r)
def toggle_pov(state, hist, bel, r):
order = ["Predator", "Prey", "Scout"]
i = order.index(state.pov)
state.pov = order[(i + 1) % len(order)]
state.event_log.append(f"POV -> {state.pov}")
hist = push_hist(state, hist)
r = len(hist) - 1
out = refresh(state, hist, bel, r)
return out + (state, hist, bel, r)
def set_overlay(state, hist, bel, r, ov):
state.overlay = bool(ov)
out = refresh(state, hist, bel, r)
return out + (state, hist, bel, r)
def click_truth(tile, state, hist, bel, r, evt: gr.SelectData):
state = grid_click_to_tile(evt, int(tile), state)
hist = push_hist(state, hist)
r = len(hist) - 1
out = refresh(state, hist, bel, r)
return out + (state, hist, bel, r)
def jump(state, hist, bel, r, idx):
if not hist:
out = refresh(state, hist, bel, r)
return out + (state, hist, bel, r)
idx = max(0, min(int(idx), len(hist) - 1))
restore_into(state, hist[idx])
r = idx
out = refresh(state, hist, bel, r)
return out + (state, hist, bel, r)
def reset_ep(state, hist, bel, r):
reset_episode(state, seed=state.seed)
hist = [snapshot_of(state)]
r = 0
bel = init_belief()
out = refresh(state, hist, bel, r)
return out + (state, hist, bel, r)
def reset_all(state, hist, bel, r):
seed = state.seed
state = init_state(seed)
hist = [snapshot_of(state)]
bel = init_belief()
r = 0
out = refresh(state, hist, bel, r)
return out + (state, hist, bel, r)
def do_train(state, hist, bel, r,
uq_pred, uq_prey, a, g, e, ed, emin,
eps_count, max_s):
state = set_cfg(state, uq_pred, uq_prey, a, g, e, ed, emin)
train(state, episodes=max(1, int(eps_count)), max_steps=max(10, int(max_s)))
# After training, reset to a clean episode so user sees improved behavior
reset_episode(state, seed=state.seed)
hist = [snapshot_of(state)]
bel = init_belief()
r = 0
out = refresh(state, hist, bel, r)
return out + (state, hist, bel, r)
def export_fn(state, hist):
return export_run(state, hist)
def import_fn(txt):
state, hist, bel, r = import_run(txt)
pov, tr, ba, bb, stxt, etxt, ttxt = build_views(state, bel)
r_max = max(0, len(hist) - 1)
return (
pov, tr, ba, bb, stxt, etxt, ttxt,
gr.update(maximum=r_max, value=r),
state, hist, bel, r
)
# --- Wire buttons (no fn_kwargs; use lambdas) ---
btn_L.click(lambda s,h,b,r: do_manual(s,h,b,r,"L"),
inputs=[st, history, beliefs, rewind_idx],
outputs=[pov_img, truth, belief_a, belief_b, status, events, trace, rewind, rewind_idx, st, history, beliefs, rewind_idx],
queue=True)
btn_F.click(lambda s,h,b,r: do_manual(s,h,b,r,"F"),
inputs=[st, history, beliefs, rewind_idx],
outputs=[pov_img, truth, belief_a, belief_b, status, events, trace, rewind, rewind_idx, st, history, beliefs, rewind_idx],
queue=True)
btn_R.click(lambda s,h,b,r: do_manual(s,h,b,r,"R"),
inputs=[st, history, beliefs, rewind_idx],
outputs=[pov_img, truth, belief_a, belief_b, status, events, trace, rewind, rewind_idx, st, history, beliefs, rewind_idx],
queue=True)
btn_tick.click(do_tick,
inputs=[st, history, beliefs, rewind_idx],
outputs=[pov_img, truth, belief_a, belief_b, status, events, trace, rewind, rewind_idx, st, history, beliefs, rewind_idx],
queue=True)
btn_run.click(do_run,
inputs=[st, history, beliefs, rewind_idx, run_steps],
outputs=[pov_img, truth, belief_a, belief_b, status, events, trace, rewind, rewind_idx, st, history, beliefs, rewind_idx],
queue=True)
btn_toggle_control.click(toggle_control,
inputs=[st, history, beliefs, rewind_idx],
outputs=[pov_img, truth, belief_a, belief_b, status, events, trace, rewind, rewind_idx, st, history, beliefs, rewind_idx],
queue=True)
btn_toggle_pov.click(toggle_pov,
inputs=[st, history, beliefs, rewind_idx],
outputs=[pov_img, truth, belief_a, belief_b, status, events, trace, rewind, rewind_idx, st, history, beliefs, rewind_idx],
queue=True)
overlay.change(set_overlay,
inputs=[st, history, beliefs, rewind_idx, overlay],
outputs=[pov_img, truth, belief_a, belief_b, status, events, trace, rewind, rewind_idx, st, history, beliefs, rewind_idx],
queue=True)
truth.select(click_truth,
inputs=[tile_pick, st, history, beliefs, rewind_idx],
outputs=[pov_img, truth, belief_a, belief_b, status, events, trace, rewind, rewind_idx, st, history, beliefs, rewind_idx],
queue=True)
btn_jump.click(jump,
inputs=[st, history, beliefs, rewind_idx, rewind],
outputs=[pov_img, truth, belief_a, belief_b, status, events, trace, rewind, rewind_idx, st, history, beliefs, rewind_idx],
queue=True)
btn_reset.click(reset_ep,
inputs=[st, history, beliefs, rewind_idx],
outputs=[pov_img, truth, belief_a, belief_b, status, events, trace, rewind, rewind_idx, st, history, beliefs, rewind_idx],
queue=True)
btn_reset_all.click(reset_all,
inputs=[st, history, beliefs, rewind_idx],
outputs=[pov_img, truth, belief_a, belief_b, status, events, trace, rewind, rewind_idx, st, history, beliefs, rewind_idx],
queue=True)
btn_train.click(do_train,
inputs=[st, history, beliefs, rewind_idx,
use_q_pred, use_q_prey, alpha, gamma, eps, eps_decay, eps_min,
episodes, max_steps],
outputs=[pov_img, truth, belief_a, belief_b, status, events, trace, rewind, rewind_idx, st, history, beliefs, rewind_idx],
queue=True)
btn_export.click(export_fn, inputs=[st, history], outputs=[export_box], queue=True)
btn_import.click(import_fn,
inputs=[import_box],
outputs=[pov_img, truth, belief_a, belief_b, status, events, trace, rewind, st, history, beliefs, rewind_idx],
queue=True)
demo.load(refresh,
inputs=[st, history, beliefs, rewind_idx],
outputs=[pov_img, truth, belief_a, belief_b, status, events, trace, rewind, rewind_idx],
queue=True)
demo.queue().launch()