# POV2 / app.py
# Uploaded by ZENLLC ("Create app.py", commit b83ea71, verified).
import json
import math
import hashlib
from dataclasses import dataclass, asdict
from typing import Dict, List, Tuple, Optional, Any
import numpy as np
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
import gradio as gr
# ============================================================
# ZEN AgentLab — Agent POV + Multi-Agent Mini-Sim Arena
#
# Additions in this version:
# - Autoplay (Start/Stop) via gr.Timer (watch agents live)
# - One-click "Cinematic Run" (full episode in one click)
# - Example presets (env+seed) + seed controls
# - Autoplay is interruptible: manual buttons still work anytime
#
# Matplotlib HF-safe: uses canvas.buffer_rgba()
# ============================================================
# -----------------------------
# Global config
# -----------------------------
# World / viewport geometry
GRID_W, GRID_H = 21, 15  # grid size in tiles
TILE = 22  # tile edge length (pixels) in the top-down render
VIEW_W, VIEW_H = 640, 360  # first-person viewport size (pixels)
RAY_W = 320  # rays cast per POV frame; each ray fills VIEW_W / RAY_W pixel columns
FOV_DEG = 78  # horizontal field of view (degrees)
MAX_DEPTH = 20  # ray-march range limit (tiles)
DIRS = [(1, 0), (0, 1), (-1, 0), (0, -1)]  # (dx, dy) step per orientation index
ORI_DEG = [0, 90, 180, 270]  # facing angle (degrees) per orientation index
# Tiles — integer codes stored in the grid
EMPTY = 0
WALL = 1
FOOD = 2
NOISE = 3
DOOR = 4
TELE = 5
KEY = 6
EXIT = 7
ARTIFACT = 8
HAZARD = 9
WOOD = 10
ORE = 11
MEDKIT = 12
SWITCH = 13
BASE = 14
# Display labels for tile codes
TILE_NAMES = {
    EMPTY: "Empty",
    WALL: "Wall",
    FOOD: "Food",
    NOISE: "Noise",
    DOOR: "Door",
    TELE: "Teleporter",
    KEY: "Key",
    EXIT: "Exit",
    ARTIFACT: "Artifact",
    HAZARD: "Hazard",
    WOOD: "Wood",
    ORE: "Ore",
    MEDKIT: "Medkit",
    SWITCH: "Switch",
    BASE: "Base",
}
# RGB color per agent name (call sites supply fallbacks for unknown names)
AGENT_COLORS = {
    "Predator": (255, 120, 90),
    "Prey": (120, 255, 160),
    "Scout": (120, 190, 255),
    "Alpha": (255, 205, 120),
    "Bravo": (160, 210, 255),
    "Guardian": (255, 120, 220),
    "BuilderA": (140, 255, 200),
    "BuilderB": (160, 200, 255),
    "Raider": (255, 160, 120),
}
# POV render palette (RGB, uint8)
SKY = np.array([14, 16, 26], dtype=np.uint8)
FLOOR_NEAR = np.array([24, 26, 40], dtype=np.uint8)
FLOOR_FAR = np.array([10, 11, 18], dtype=np.uint8)
WALL_BASE = np.array([210, 210, 225], dtype=np.uint8)
WALL_SIDE = np.array([150, 150, 170], dtype=np.uint8)
DOOR_COL = np.array([140, 210, 255], dtype=np.uint8)
# Small action space: turn Left, move Forward, turn Right, Interact
ACTIONS = ["L", "F", "R", "I"]
TRACE_MAX = 500  # max retained trace_log lines
MAX_HISTORY = 1400  # history length cap (consumer not shown in this chunk)
# -----------------------------
# Deterministic RNG
# -----------------------------
def rng_for(seed: int, step: int, stream: int = 0) -> np.random.Generator:
    """Build a deterministic NumPy generator for a (seed, step, stream) triple.

    Mixing with three distinct primes keeps different triples on different
    RNG streams while staying fully reproducible run-to-run.
    """
    mixed = seed * 1_000_003
    mixed ^= step * 9_999_937
    mixed ^= stream * 97_531
    return np.random.default_rng(mixed & 0xFFFFFFFFFFFFFFFF)
# -----------------------------
# Data structures
# -----------------------------
@dataclass
class Agent:
    """A single embodied agent on the grid.

    Attributes:
        name: Unique identifier (also used as color and Q-table key).
        x, y: Tile coordinates.
        ori: Orientation index into DIRS / ORI_DEG (0=+x, 1=+y, 2=-x, 3=-y).
        hp: Hit points; hp <= 0 is treated as dead throughout the sim.
        energy: Energy reserve (eating food restores it).
        team: Team label ("A" or "B") used for tagging and rewards.
        brain: Policy selector: "q" | "heuristic" | "random".
        inventory: Item counts. Fix: the annotation previously claimed a
            non-optional Dict while defaulting to None; Optional[...] now
            matches the accepted values. __post_init__ still normalizes
            None to a fresh dict, avoiding a shared mutable default.
    """
    name: str
    x: int
    y: int
    ori: int
    hp: int = 10
    energy: int = 100
    team: str = "A"
    brain: str = "q"  # q | heuristic | random
    inventory: Optional[Dict[str, int]] = None

    def __post_init__(self):
        if self.inventory is None:
            self.inventory = {}
@dataclass
class TrainConfig:
    """Hyperparameters for tabular Q-learning plus per-environment reward shaping."""
    # Q-learning core
    use_q: bool = True
    alpha: float = 0.15  # learning rate
    gamma: float = 0.95  # discount factor
    epsilon: float = 0.10  # initial exploration rate
    epsilon_min: float = 0.02
    epsilon_decay: float = 0.995  # multiplicative decay applied once per episode
    # Generic shaping (all environments)
    step_penalty: float = -0.01
    explore_reward: float = 0.015  # granted whenever a move succeeds
    damage_penalty: float = -0.20
    heal_reward: float = 0.10
    # "chase" environment
    chase_close_coeff: float = 0.03  # per tile of distance the Predator closes
    chase_catch_reward: float = 3.0
    chase_escaped_reward: float = 0.2
    chase_caught_penalty: float = -3.0
    food_reward: float = 0.6
    # "vault" environment
    artifact_pick_reward: float = 1.2
    exit_win_reward: float = 3.0
    guardian_tag_reward: float = 2.0
    tagged_penalty: float = -2.0
    switch_reward: float = 0.8
    key_reward: float = 0.4
    # "civ" environment
    resource_pick_reward: float = 0.15
    deposit_reward: float = 0.4
    base_progress_win_reward: float = 3.5
    raider_elim_reward: float = 2.0
    builder_elim_penalty: float = -2.0
@dataclass
class GlobalMetrics:
    """Metrics accumulated across episodes (they survive episode resets)."""
    episodes: int = 0
    wins_teamA: int = 0
    wins_teamB: int = 0
    draws: int = 0
    avg_steps: float = 0.0  # exponential moving average of episode length
    rolling_winrate_A: float = 0.0  # EMA of Team A win indicator (a draw counts 0.5)
    epsilon: float = 0.10  # current exploration rate, decayed per episode
    last_outcome: str = "init"
    last_steps: int = 0
@dataclass
class EpisodeMetrics:
    """Per-episode rollup, re-created at the start of every episode.

    Attributes:
        steps: Steps elapsed in the current episode.
        returns: Cumulative shaped reward per agent name.
        action_counts: Per-agent histogram of actions taken.
        tiles_discovered: Per-agent count of newly revealed belief tiles.

    Fix: the three dict fields were annotated as plain Dict while defaulting
    to None; Optional[...] now matches the accepted values. __post_init__
    still normalizes None to fresh dicts so instances never share state.
    """
    steps: int = 0
    returns: Optional[Dict[str, float]] = None
    action_counts: Optional[Dict[str, Dict[str, int]]] = None
    tiles_discovered: Optional[Dict[str, int]] = None

    def __post_init__(self):
        if self.returns is None:
            self.returns = {}
        if self.action_counts is None:
            self.action_counts = {}
        if self.tiles_discovered is None:
            self.tiles_discovered = {}
@dataclass
class WorldState:
    """Complete mutable simulation state for one running episode.

    Learning artifacts (cfg, q_tables, gmetrics) live here but are shared
    across episodes by the reset helpers; everything else is per-episode.
    """
    seed: int
    step: int
    env_key: str  # "chase" | "vault" | "civ" (key into ENV_BUILDERS)
    grid: List[List[int]]  # tile codes, indexed grid[y][x]
    agents: Dict[str, Agent]
    controlled: str  # agent driven by manual UI actions
    pov: str  # agent whose first-person view is rendered
    overlay: bool  # draw the crosshair overlay in the POV render
    done: bool
    outcome: str  # A_win | B_win | draw | ongoing
    door_opened_global: bool = False  # one SWITCH opens every DOOR tile at once
    base_progress: int = 0  # resources deposited so far ("civ" win condition)
    base_target: int = 10
    event_log: List[str] = None  # human-readable event messages
    trace_log: List[str] = None  # per-tick machine trace, capped at TRACE_MAX
    cfg: TrainConfig = None
    q_tables: Dict[str, Dict[str, List[float]]] = None  # agent -> obs_key -> action values
    gmetrics: GlobalMetrics = None
    emetrics: EpisodeMetrics = None

    def __post_init__(self):
        # Normalize the None sentinels so instances never share mutable defaults.
        if self.event_log is None:
            self.event_log = []
        if self.trace_log is None:
            self.trace_log = []
        # cfg must be resolved before gmetrics: the default GlobalMetrics
        # seeds its epsilon from cfg.epsilon.
        if self.cfg is None:
            self.cfg = TrainConfig()
        if self.q_tables is None:
            self.q_tables = {}
        if self.gmetrics is None:
            self.gmetrics = GlobalMetrics(epsilon=self.cfg.epsilon)
        if self.emetrics is None:
            self.emetrics = EpisodeMetrics()
@dataclass
class Snapshot:
    """Frozen record of the world at one step, used for rewind/branching."""
    branch: str
    step: int
    env_key: str
    grid: List[List[int]]
    agents: Dict[str, Dict[str, Any]]  # asdict()-serialized Agent per name
    done: bool
    outcome: str
    door_opened_global: bool
    base_progress: int
    base_target: int
    event_tail: List[str]  # last 25 event-log lines at capture time
    trace_tail: List[str]  # last 40 trace-log lines at capture time
    emetrics: Dict[str, Any]  # asdict()-serialized EpisodeMetrics
# -----------------------------
# Helpers
# -----------------------------
def in_bounds(x: int, y: int) -> bool:
    """Return True when tile coordinate (x, y) lies inside the grid."""
    if x < 0 or y < 0:
        return False
    return x < GRID_W and y < GRID_H
def is_blocking(tile: int, door_open: bool = False) -> bool:
    """Return True when movement onto *tile* is blocked.

    Walls always block; doors block only while the global door flag is closed.
    """
    if tile == WALL:
        return True
    return tile == DOOR and not door_open
def manhattan_xy(ax: int, ay: int, bx: int, by: int) -> int:
    """L1 (taxicab) distance between (ax, ay) and (bx, by)."""
    dx = ax - bx
    dy = ay - by
    return abs(dx) + abs(dy)
def bresenham_los(grid: List[List[int]], x0: int, y0: int, x1: int, y1: int) -> bool:
    """Integer Bresenham line-of-sight test between two tiles.

    Walks the line from (x0, y0) to (x1, y1) and returns False as soon as an
    intermediate WALL tile is crossed. Both endpoints are exempt, so agents
    standing adjacent to walls can still see each other.
    """
    dx = abs(x1 - x0)
    dy = abs(y1 - y0)
    sx = 1 if x0 < x1 else -1
    sy = 1 if y0 < y1 else -1
    err = dx - dy  # classic Bresenham error accumulator
    x, y = x0, y0
    while True:
        # Only intermediate cells can block; the endpoints never do.
        if (x, y) != (x0, y0) and (x, y) != (x1, y1):
            if grid[y][x] == WALL:
                return False
        if x == x1 and y == y1:
            return True
        e2 = 2 * err
        if e2 > -dy:
            err -= dy
            x += sx
        if e2 < dx:
            err += dx
            y += sy
def within_fov(observer: Agent, tx: int, ty: int, fov_deg: float = FOV_DEG) -> bool:
    """Return True when tile (tx, ty) falls inside the observer's horizontal FOV cone."""
    dx, dy = tx - observer.x, ty - observer.y
    if (dx, dy) == (0, 0):
        return True  # the observer's own tile always counts as in view
    bearing = math.degrees(math.atan2(dy, dx)) % 360
    # Signed smallest angular difference from the facing direction, in (-180, 180].
    off_axis = (bearing - ORI_DEG[observer.ori] + 540) % 360 - 180
    return abs(off_axis) <= fov_deg / 2
def visible(state: WorldState, observer: Agent, target: Agent) -> bool:
    """Target is visible iff it is inside the FOV cone AND no wall blocks the line."""
    in_cone = within_fov(observer, target.x, target.y, FOV_DEG)
    return in_cone and bresenham_los(state.grid, observer.x, observer.y, target.x, target.y)
def hash_sha256(txt: str) -> str:
    """Hex-encoded SHA-256 digest of *txt* (UTF-8)."""
    digest = hashlib.sha256()
    digest.update(txt.encode("utf-8"))
    return digest.hexdigest()
# -----------------------------
# Beliefs
# -----------------------------
def init_beliefs(agent_names: List[str]) -> Dict[str, np.ndarray]:
    """One belief map per agent: a GRID_H x GRID_W int16 array, -1 meaning 'unknown'."""
    return {name: np.full((GRID_H, GRID_W), -1, dtype=np.int16) for name in agent_names}
def update_belief_for_agent(state: WorldState, belief: np.ndarray, agent: Agent) -> int:
    """Reveal FOV tiles onto the agent's belief map; return the number newly seen.

    Marches a fan of rays across the FOV, copying ground-truth tiles into
    *belief* until each ray is stopped by a wall or a closed door.
    """
    before_unknown = int(np.sum(belief == -1))
    belief[agent.y, agent.x] = state.grid[agent.y][agent.x]  # own tile is always known
    base = math.radians(ORI_DEG[agent.ori])
    half = math.radians(FOV_DEG / 2)
    # Scout-named agents sweep more rays, i.e. finer angular resolution.
    rays = 45 if agent.name.lower().startswith("scout") else 33
    for i in range(rays):
        t = i / (rays - 1)
        ang = base + (t * 2 - 1) * half  # spread rays evenly over [-half, +half]
        sin_a, cos_a = math.sin(ang), math.cos(ang)
        ox, oy = agent.x + 0.5, agent.y + 0.5
        depth = 0.0
        while depth < MAX_DEPTH:
            depth += 0.2
            tx = int(ox + cos_a * depth)
            ty = int(oy + sin_a * depth)
            if not in_bounds(tx, ty):
                break
            belief[ty, tx] = state.grid[ty][tx]
            tile = state.grid[ty][tx]
            if tile == WALL:
                break  # the wall itself is seen, but blocks sight beyond
            if tile == DOOR and not state.door_opened_global:
                break  # a closed door likewise blocks sight
    after_unknown = int(np.sum(belief == -1))
    return max(0, before_unknown - after_unknown)
# -----------------------------
# Rendering
# -----------------------------
def raycast_view(state: WorldState, observer: Agent) -> np.ndarray:
    """Render the observer's first-person view as a (VIEW_H, VIEW_W, 3) uint8 image.

    Classic 2.5D raycaster: sky color on top, a floor gradient below, one
    wall/door column per cast ray, then visible agents drawn as billboards,
    plus an optional crosshair overlay.
    """
    img = np.zeros((VIEW_H, VIEW_W, 3), dtype=np.uint8)
    img[:, :] = SKY
    # Floor: vertical gradient from FLOOR_NEAR (screen middle) to FLOOR_FAR (bottom).
    for y in range(VIEW_H // 2, VIEW_H):
        t = (y - VIEW_H // 2) / (VIEW_H // 2 + 1e-6)
        col = (1 - t) * FLOOR_NEAR + t * FLOOR_FAR
        img[y, :] = col.astype(np.uint8)
    fov = math.radians(FOV_DEG)
    half_fov = fov / 2
    for rx in range(RAY_W):
        cam_x = (2 * rx / (RAY_W - 1)) - 1  # ray position in [-1, 1] across the FOV
        ray_ang = math.radians(ORI_DEG[observer.ori]) + cam_x * half_fov
        ox, oy = observer.x + 0.5, observer.y + 0.5  # ray origin: tile center
        sin_a = math.sin(ray_ang)
        cos_a = math.cos(ray_ang)
        depth = 0.0
        hit = None
        side = 0
        # March the ray in small fixed steps until a blocking tile or max range.
        while depth < MAX_DEPTH:
            depth += 0.05
            tx = int(ox + cos_a * depth)
            ty = int(oy + sin_a * depth)
            if not in_bounds(tx, ty):
                break
            tile = state.grid[ty][tx]
            if tile == WALL:
                hit = "wall"
                # Crude side classification for shading (dominant travel axis).
                side = 1 if abs(cos_a) > abs(sin_a) else 0
                break
            if tile == DOOR and not state.door_opened_global:
                hit = "door"
                break
        if hit is None:
            continue  # ray escaped: leave the sky/floor gradient for this column
        # Fisheye correction: project depth onto the view axis.
        depth *= math.cos(ray_ang - math.radians(ORI_DEG[observer.ori]))
        depth = max(depth, 0.001)
        proj_h = int((VIEW_H * 0.9) / depth)
        y0 = max(0, VIEW_H // 2 - proj_h // 2)
        y1 = min(VIEW_H - 1, VIEW_H // 2 + proj_h // 2)
        if hit == "door":
            col = DOOR_COL.copy()
        else:
            col = WALL_BASE.copy() if side == 0 else WALL_SIDE.copy()
        # Distance dimming, clamped so far walls stay faintly visible.
        dim = max(0.25, 1.0 - (depth / MAX_DEPTH))
        col = (col * dim).astype(np.uint8)
        x0 = int(rx * (VIEW_W / RAY_W))
        x1 = int((rx + 1) * (VIEW_W / RAY_W))
        img[y0:y1, x0:x1] = col
    # Draw other living, visible agents as colored billboards.
    for nm, other in state.agents.items():
        if nm == observer.name or other.hp <= 0:
            continue
        if visible(state, observer, other):
            dx = other.x - observer.x
            dy = other.y - observer.y
            ang = (math.degrees(math.atan2(dy, dx)) % 360)
            facing = ORI_DEG[observer.ori]
            diff = (ang - facing + 540) % 360 - 180  # signed bearing offset
            # Horizontal screen position from the angular offset within the FOV.
            sx = int((diff / (FOV_DEG / 2)) * (VIEW_W / 2) + (VIEW_W / 2))
            dist = math.sqrt(dx * dx + dy * dy)
            h = int((VIEW_H * 0.65) / max(dist, 0.75))  # nearer agents draw taller
            w = max(10, h // 3)
            y_mid = VIEW_H // 2
            y0 = max(0, y_mid - h // 2)
            y1 = min(VIEW_H - 1, y_mid + h // 2)
            x0 = max(0, sx - w // 2)
            x1 = min(VIEW_W - 1, sx + w // 2)
            col = AGENT_COLORS.get(nm, (255, 200, 120))
            img[y0:y1, x0:x1] = np.array(col, dtype=np.uint8)
    if state.overlay:
        # Crosshair overlay at the screen center.
        cx, cy = VIEW_W // 2, VIEW_H // 2
        img[cy - 1:cy + 2, cx - 10:cx + 10] = np.array([120, 190, 255], dtype=np.uint8)
        img[cy - 10:cy + 10, cx - 1:cx + 2] = np.array([120, 190, 255], dtype=np.uint8)
    return img
def render_topdown(grid: np.ndarray, agents: Dict[str, Agent], title: str, show_agents: bool = True) -> Image.Image:
    """Render a grid (world or belief map) as a top-down PIL image.

    *grid* may contain -1 ("unknown") entries for belief maps. A 28-px title
    bar is drawn across the top; living agents appear as circles with a
    short facing tick.
    """
    w = grid.shape[1] * TILE
    h = grid.shape[0] * TILE
    im = Image.new("RGB", (w, h + 28), (10, 12, 18))  # +28 px for the title bar
    draw = ImageDraw.Draw(im)
    # Tile fills.
    for y in range(grid.shape[0]):
        for x in range(grid.shape[1]):
            t = int(grid[y, x])
            if t == -1:
                col = (18, 20, 32)  # unknown (belief maps only)
            elif t == EMPTY:
                col = (26, 30, 44)
            elif t == WALL:
                col = (190, 190, 210)
            elif t == FOOD:
                col = (255, 210, 120)
            elif t == NOISE:
                col = (255, 120, 220)
            elif t == DOOR:
                col = (140, 210, 255)
            elif t == TELE:
                col = (120, 190, 255)
            elif t == KEY:
                col = (255, 235, 160)
            elif t == EXIT:
                col = (120, 255, 220)
            elif t == ARTIFACT:
                col = (255, 170, 60)
            elif t == HAZARD:
                col = (255, 90, 90)
            elif t == WOOD:
                col = (170, 120, 60)
            elif t == ORE:
                col = (140, 140, 160)
            elif t == MEDKIT:
                col = (120, 255, 140)
            elif t == SWITCH:
                col = (200, 180, 255)
            elif t == BASE:
                col = (220, 220, 240)
            else:
                col = (80, 80, 90)  # unrecognized tile code
            x0, y0 = x * TILE, y * TILE + 28
            draw.rectangle([x0, y0, x0 + TILE - 1, y0 + TILE - 1], fill=col)
    # Grid lines.
    for x in range(grid.shape[1] + 1):
        xx = x * TILE
        draw.line([xx, 28, xx, h + 28], fill=(12, 14, 22))
    for y in range(grid.shape[0] + 1):
        yy = y * TILE + 28
        draw.line([0, yy, w, yy], fill=(12, 14, 22))
    # Living agents: filled circle plus a line showing the facing direction.
    if show_agents:
        for nm, a in agents.items():
            if a.hp <= 0:
                continue
            cx = a.x * TILE + TILE // 2
            cy = a.y * TILE + 28 + TILE // 2
            col = AGENT_COLORS.get(nm, (220, 220, 220))
            r = TILE // 3
            draw.ellipse([cx - r, cy - r, cx + r, cy + r], fill=col)
            dx, dy = DIRS[a.ori]
            draw.line([cx, cy, cx + dx * r, cy + dy * r], fill=(10, 10, 10), width=3)
    # Title bar (drawn last so it covers any spill from tiles above).
    draw.rectangle([0, 0, w, 28], fill=(14, 16, 26))
    draw.text((8, 6), title, fill=(230, 230, 240))
    return im
# -----------------------------
# Environments
# -----------------------------
def grid_with_border() -> List[List[int]]:
    """Create a fresh all-EMPTY grid with a solid WALL ring along every edge."""
    rows = [[EMPTY] * GRID_W for _ in range(GRID_H)]
    for x in range(GRID_W):
        rows[0][x] = WALL
        rows[GRID_H - 1][x] = WALL
    for y in range(GRID_H):
        rows[y][0] = WALL
        rows[y][GRID_W - 1] = WALL
    return rows
def env_chase(seed: int) -> Tuple[List[List[int]], Dict[str, Agent]]:
    """Build the 'chase' environment: Predator (team A) hunts Prey (team B).

    NOTE(review): *seed* is currently unused — the layout is fixed.
    """
    g = grid_with_border()
    # Horizontal wall across the middle with a single DOOR gap.
    for x in range(4, 17):
        g[7][x] = WALL
    g[7][10] = DOOR
    # Pickups and distractions.
    g[3][4] = FOOD
    g[11][15] = FOOD
    g[4][14] = NOISE
    g[12][5] = NOISE
    # Paired teleporters.
    g[2][18] = TELE
    g[13][2] = TELE
    agents = {
        "Predator": Agent("Predator", 2, 2, 0, hp=10, energy=100, team="A", brain="q"),
        "Prey": Agent("Prey", 18, 12, 2, hp=10, energy=100, team="B", brain="q"),
        "Scout": Agent("Scout", 10, 3, 1, hp=10, energy=100, team="A", brain="heuristic"),
    }
    return g, agents
def env_vault(seed: int) -> Tuple[List[List[int]], Dict[str, Agent]]:
    """Build the 'vault' heist: Alpha/Bravo (team A) vs the Guardian (team B).

    NOTE(review): *seed* is currently unused — the layout is fixed.
    """
    g = grid_with_border()
    # Two horizontal walls carve three corridors; each wall has one DOOR.
    for x in range(3, 18):
        g[5][x] = WALL
    for x in range(3, 18):
        g[9][x] = WALL
    g[5][10] = DOOR
    g[9][12] = DOOR
    g[2][2] = KEY
    g[12][18] = EXIT
    g[12][2] = ARTIFACT
    # Paired teleporters.
    g[2][18] = TELE
    g[13][2] = TELE
    g[7][10] = SWITCH  # flipping it opens every door (global flag)
    g[3][15] = HAZARD
    g[11][6] = MEDKIT
    g[2][12] = FOOD
    agents = {
        "Alpha": Agent("Alpha", 2, 12, 0, hp=10, energy=100, team="A", brain="q"),
        "Bravo": Agent("Bravo", 3, 12, 0, hp=10, energy=100, team="A", brain="q"),
        "Guardian": Agent("Guardian", 18, 2, 2, hp=10, energy=100, team="B", brain="q"),
    }
    return g, agents
def env_civ(seed: int) -> Tuple[List[List[int]], Dict[str, Agent]]:
    """Build the 'civ' economy map: Builders gather/deposit vs the Raider.

    NOTE(review): *seed* is currently unused — the layout is fixed.
    """
    g = grid_with_border()
    # Vertical wall splits the map; one DOOR in the middle.
    for y in range(3, 12):
        g[y][9] = WALL
    g[7][9] = DOOR
    # Resource clusters: wood on the west side, ore on the east side.
    g[2][3] = WOOD
    g[3][3] = WOOD
    g[4][3] = WOOD
    g[12][16] = ORE
    g[11][16] = ORE
    g[10][16] = ORE
    g[6][4] = FOOD
    g[8][15] = FOOD
    g[13][10] = BASE  # deposit target for the Builders' win condition
    g[4][15] = HAZARD
    g[10][4] = HAZARD
    # Paired teleporters.
    g[2][18] = TELE
    g[13][2] = TELE
    g[2][2] = KEY
    g[12][6] = SWITCH
    agents = {
        "BuilderA": Agent("BuilderA", 3, 12, 0, hp=10, energy=100, team="A", brain="q"),
        "BuilderB": Agent("BuilderB", 4, 12, 0, hp=10, energy=100, team="A", brain="q"),
        "Raider": Agent("Raider", 18, 2, 2, hp=10, energy=100, team="B", brain="q"),
    }
    return g, agents
# Registry mapping env key -> builder returning (grid, agents).
ENV_BUILDERS = {"chase": env_chase, "vault": env_vault, "civ": env_civ}
# -----------------------------
# Observation / Q-learning
# -----------------------------
def local_tile_ahead(state: WorldState, a: Agent) -> int:
    """Tile code of the cell directly in front of *a* (WALL when out of bounds)."""
    step_x, step_y = DIRS[a.ori]
    ahead_x, ahead_y = a.x + step_x, a.y + step_y
    if in_bounds(ahead_x, ahead_y):
        return state.grid[ahead_y][ahead_x]
    return WALL
def nearest_enemy_vec(state: WorldState, a: Agent) -> Tuple[int, int, int]:
    """(distance, clipped dx, clipped dy) to the closest living enemy of *a*.

    Returns (99, 0, 0) when no enemy is alive; the offsets are clipped to
    [-6, 6] to keep the observation space small. Ties keep the first enemy
    encountered in agent-dict order.
    """
    best_d = None
    best_dx = 0
    best_dy = 0
    for other in state.agents.values():
        if other.hp <= 0 or other.team == a.team:
            continue
        d = manhattan_xy(a.x, a.y, other.x, other.y)
        if best_d is None or d < best_d:
            best_d = d
            best_dx = other.x - a.x
            best_dy = other.y - a.y
    if best_d is None:
        return (99, 0, 0)
    return (best_d, int(np.clip(best_dx, -6, 6)), int(np.clip(best_dy, -6, 6)))
def obs_key(state: WorldState, who: str) -> str:
    """Compact, hashable observation string — the tabular Q-learning state key.

    Encodes position/orientation, the nearest-enemy vector, the tile ahead,
    hp, a bucketed inventory summary, the global door flag, and base progress.
    """
    a = state.agents[who]
    d, dx, dy = nearest_enemy_vec(state, a)
    ahead = local_tile_ahead(state, a)
    keys = a.inventory.get("key", 0)
    art = a.inventory.get("artifact", 0)
    wood = a.inventory.get("wood", 0)
    ore = a.inventory.get("ore", 0)
    # Bucketing caps item counts so the state space stays small.
    inv_bucket = f"k{min(keys,2)}a{min(art,1)}w{min(wood,3)}o{min(ore,3)}"
    door = 1 if state.door_opened_global else 0
    return f"{state.env_key}|{who}|{a.x},{a.y},{a.ori}|e{d}:{dx},{dy}|t{ahead}|hp{a.hp}|{inv_bucket}|D{door}|bp{state.base_progress}"
def q_get(q: Dict[str, List[float]], key: str) -> List[float]:
    """Return q[key], lazily initializing unseen states to all-zero action values."""
    return q.setdefault(key, [0.0] * len(ACTIONS))
def epsilon_greedy(qvals: List[float], eps: float, r: np.random.Generator) -> int:
    """Pick argmax of *qvals*, or a uniform random index with probability *eps*."""
    explore = r.random() < eps
    if explore:
        return int(r.integers(0, len(qvals)))
    return int(np.argmax(qvals))
def q_update(q: Dict[str, List[float]], key: str, a_idx: int, reward: float, next_key: str,
             alpha: float, gamma: float) -> Tuple[float, float, float]:
    """One tabular Q-learning backup; returns (old, td_target, new) for tracing."""
    qv = q_get(q, key)
    best_next = float(np.max(q_get(q, next_key)))
    old = qv[a_idx]
    target = reward + gamma * best_next
    updated = old + alpha * (target - old)
    qv[a_idx] = updated
    return old, target, updated
# -----------------------------
# Baseline heuristics
# -----------------------------
def heuristic_action(state: WorldState, who: str) -> str:
    """Scripted policy: grab what you stand on, steer toward a visible nearby
    enemy, otherwise wander with a forward-biased random walk.

    NOTE(review): hash(who) is salted per Python process, so the RNG stream
    differs across runs even for the same seed — confirm this is intended.
    """
    a = state.agents[who]
    r = rng_for(state.seed, state.step, stream=900 + hash(who) % 1000)
    t_here = state.grid[a.y][a.x]
    # Standing on anything interactable -> interact.
    if t_here in (FOOD, KEY, ARTIFACT, WOOD, ORE, MEDKIT, SWITCH, BASE, EXIT):
        return "I"
    # Find the closest living enemy by Manhattan distance.
    best = None
    best_d = 999
    for _, other in state.agents.items():
        if other.hp <= 0 or other.team == a.team:
            continue
        d = manhattan_xy(a.x, a.y, other.x, other.y)
        if d < best_d:
            best_d = d
            best = other
    if best is not None and best_d <= 6 and visible(state, a, best):
        # Turn until roughly facing the enemy, then advance.
        dx = best.x - a.x
        dy = best.y - a.y
        ang = (math.degrees(math.atan2(dy, dx)) % 360)
        facing = ORI_DEG[a.ori]
        diff = (ang - facing + 540) % 360 - 180  # signed bearing offset
        if diff < -10:
            return "L"
        if diff > 10:
            return "R"
        return "F"
    # Wander: forward is twice as likely as each other option.
    return r.choice(["F", "F", "L", "R", "I"])
def random_action(state: WorldState, who: str) -> str:
    """Uniform random action drawn from a per-agent deterministic RNG stream."""
    stream_id = 700 + hash(who) % 1000
    rng = rng_for(state.seed, state.step, stream=stream_id)
    return rng.choice(ACTIONS)
# -----------------------------
# Movement + interaction
# -----------------------------
def turn_left(a: Agent) -> None:
a.ori = (a.ori - 1) % 4
def turn_right(a: Agent) -> None:
a.ori = (a.ori + 1) % 4
def move_forward(state: WorldState, a: Agent) -> str:
    """Advance *a* one tile in its facing direction, resolving teleporters.

    Returns a short outcome string ("moved", "moved: teleported", or a
    "blocked: ..." reason); reward shaping keys off the "moved" prefix.
    """
    dx, dy = DIRS[a.ori]
    nx, ny = a.x + dx, a.y + dy
    if not in_bounds(nx, ny):
        return "blocked: bounds"
    tile = state.grid[ny][nx]
    if is_blocking(tile, door_open=state.door_opened_global):
        return "blocked: wall/door"
    a.x, a.y = nx, ny
    # Stepping onto a teleporter hops to the next TELE pad in sorted cycle order.
    if state.grid[ny][nx] == TELE:
        teles = [(x, y) for y in range(GRID_H) for x in range(GRID_W) if state.grid[y][x] == TELE]
        if len(teles) >= 2:
            teles_sorted = sorted(teles)
            idx = teles_sorted.index((nx, ny))
            dest = teles_sorted[(idx + 1) % len(teles_sorted)]
            a.x, a.y = dest
            state.event_log.append(f"t={state.step}: {a.name} teleported.")
            return "moved: teleported"
    return "moved"
def try_interact(state: WorldState, a: Agent) -> str:
    """Interact with the tile the agent stands on; return a short outcome string.

    Pickups are removed from the grid and added to the inventory. The outcome
    prefixes ("picked:", "switch:", "deposited:", ...) drive reward shaping.
    """
    t = state.grid[a.y][a.x]
    if t == SWITCH:
        # One switch opens every door in the world, permanently.
        state.door_opened_global = True
        state.grid[a.y][a.x] = EMPTY
        a.inventory["switch"] = a.inventory.get("switch", 0) + 1
        return "switch: opened all doors"
    if t == KEY:
        a.inventory["key"] = a.inventory.get("key", 0) + 1
        state.grid[a.y][a.x] = EMPTY
        return "picked: key"
    if t == ARTIFACT:
        a.inventory["artifact"] = a.inventory.get("artifact", 0) + 1
        state.grid[a.y][a.x] = EMPTY
        return "picked: artifact"
    if t == FOOD:
        a.energy = min(200, a.energy + 35)  # energy is capped at 200
        state.grid[a.y][a.x] = EMPTY
        return "ate: food"
    if t == WOOD:
        a.inventory["wood"] = a.inventory.get("wood", 0) + 1
        state.grid[a.y][a.x] = EMPTY
        return "picked: wood"
    if t == ORE:
        a.inventory["ore"] = a.inventory.get("ore", 0) + 1
        state.grid[a.y][a.x] = EMPTY
        return "picked: ore"
    if t == MEDKIT:
        a.hp = min(10, a.hp + 3)  # hp is capped at 10
        state.grid[a.y][a.x] = EMPTY
        return "used: medkit"
    if t == BASE:
        # Deposit up to 2 wood + 2 ore per interaction; the base tile persists.
        w = a.inventory.get("wood", 0)
        o = a.inventory.get("ore", 0)
        dep = min(w, 2) + min(o, 2)
        if dep > 0:
            a.inventory["wood"] = max(0, w - min(w, 2))
            a.inventory["ore"] = max(0, o - min(o, 2))
            state.base_progress += dep
            return f"deposited: +{dep} base_progress"
        return "base: nothing to deposit"
    if t == EXIT:
        return "at_exit"  # the win condition itself is evaluated in check_done
    return "interact: none"
def apply_action(state: WorldState, who: str, action: str) -> str:
    """Execute one primitive action for agent *who*; return a short outcome string.

    Dead agents ignore all actions; unknown action codes are a no-op.
    """
    agent = state.agents[who]
    if agent.hp <= 0:
        return "dead"
    if action == "F":
        return move_forward(state, agent)
    if action == "I":
        return try_interact(state, agent)
    if action == "L":
        turn_left(agent)
        return "turned left"
    if action == "R":
        turn_right(agent)
        return "turned right"
    return "noop"
# -----------------------------
# Hazards / collisions / done
# -----------------------------
def resolve_hazards(state: WorldState, a: Agent) -> Tuple[bool, str]:
    """Apply standing-on-hazard damage to a living agent.

    Returns (took_damage, message); dead agents are skipped entirely.
    """
    if a.hp > 0 and state.grid[a.y][a.x] == HAZARD:
        a.hp -= 1
        return (True, "hazard:-hp")
    return (False, "")
def resolve_tags(state: WorldState) -> List[str]:
msgs = []
occupied: Dict[Tuple[int, int], List[str]] = {}
for nm, a in state.agents.items():
if a.hp <= 0:
continue
occupied.setdefault((a.x, a.y), []).append(nm)
for (x, y), names in occupied.items():
if len(names) < 2:
continue
teams = set(state.agents[n].team for n in names)
if len(teams) >= 2:
for n in names:
state.agents[n].hp -= 1
msgs.append(f"t={state.step}: collision/tag at ({x},{y}) {names} (-hp all)")
return msgs
def check_done(state: WorldState) -> None:
    """Evaluate per-environment terminal conditions, setting done/outcome.

    Outcomes: "A_win" | "B_win" | "draw"; the state stays "ongoing" otherwise.
    """
    if state.env_key == "chase":
        pred = state.agents["Predator"]
        prey = state.agents["Prey"]
        if pred.hp <= 0 and prey.hp <= 0:
            state.done = True
            state.outcome = "draw"
            return
        # Predator wins by sharing a tile with the Prey.
        if pred.hp > 0 and prey.hp > 0 and pred.x == prey.x and pred.y == prey.y:
            state.done = True
            state.outcome = "A_win"
            state.event_log.append(f"t={state.step}: CAUGHT (Predator wins).")
            return
        # Prey wins by surviving 300 steps.
        if state.step >= 300 and prey.hp > 0:
            state.done = True
            state.outcome = "B_win"
            state.event_log.append(f"t={state.step}: ESCAPED (Prey survives).")
            return
    if state.env_key == "vault":
        # Team A wins if either member stands on EXIT while carrying the artifact.
        for nm in ["Alpha", "Bravo"]:
            a = state.agents[nm]
            if a.hp > 0 and a.inventory.get("artifact", 0) > 0 and state.grid[a.y][a.x] == EXIT:
                state.done = True
                state.outcome = "A_win"
                state.event_log.append(f"t={state.step}: VAULT CLEARED (Team A wins).")
                return
        alive_A = any(state.agents[n].hp > 0 for n in ["Alpha", "Bravo"])
        if not alive_A:
            state.done = True
            state.outcome = "B_win"
            state.event_log.append(f"t={state.step}: TEAM A ELIMINATED (Guardian wins).")
            return
    if state.env_key == "civ":
        # Builders win by depositing enough resources at the base.
        if state.base_progress >= state.base_target:
            state.done = True
            state.outcome = "A_win"
            state.event_log.append(f"t={state.step}: BASE COMPLETE (Builders win).")
            return
        alive_A = any(state.agents[n].hp > 0 for n in ["BuilderA", "BuilderB"])
        if not alive_A:
            state.done = True
            state.outcome = "B_win"
            state.event_log.append(f"t={state.step}: BUILDERS ELIMINATED (Raider wins).")
            return
        # Timeout only exists in "civ"; chase has its own 300-step rule above.
        if state.step >= 350:
            state.done = True
            state.outcome = "draw"
            state.event_log.append(f"t={state.step}: TIMEOUT (draw).")
            return
# -----------------------------
# Rewards
# -----------------------------
def reward_for(prev: WorldState, now: WorldState, who: str, outcome_msg: str, took_damage: bool) -> float:
    """Shaped reward for agent *who* over one tick.

    *prev* is the pre-tick clone (used for the Predator's distance-closing
    shaping); *outcome_msg* is the action-outcome string whose prefix selects
    pickup/deposit bonuses. Terminal bonuses apply on the tick the episode ends.
    """
    cfg = now.cfg
    r = cfg.step_penalty  # small per-step cost discourages dithering
    if outcome_msg.startswith("moved"):
        r += cfg.explore_reward
    if took_damage:
        r += cfg.damage_penalty
    if outcome_msg.startswith("used: medkit"):
        r += cfg.heal_reward
    if now.env_key == "chase":
        pred = now.agents["Predator"]
        prey = now.agents["Prey"]
        if who == "Predator":
            # Reward proportional to distance closed on the Prey this tick.
            d0 = manhattan_xy(prev.agents["Predator"].x, prev.agents["Predator"].y,
                              prev.agents["Prey"].x, prev.agents["Prey"].y)
            d1 = manhattan_xy(pred.x, pred.y, prey.x, prey.y)
            r += cfg.chase_close_coeff * float(d0 - d1)
            if now.done and now.outcome == "A_win":
                r += cfg.chase_catch_reward
        if who == "Prey":
            if outcome_msg.startswith("ate: food"):
                r += cfg.food_reward
            if now.done and now.outcome == "B_win":
                r += cfg.chase_escaped_reward
            if now.done and now.outcome == "A_win":
                r += cfg.chase_caught_penalty
    if now.env_key == "vault":
        if outcome_msg.startswith("picked: artifact"):
            r += cfg.artifact_pick_reward
        if outcome_msg.startswith("picked: key"):
            r += cfg.key_reward
        if outcome_msg.startswith("switch:"):
            r += cfg.switch_reward
        if now.done:
            # Terminal bonuses/penalties depend on the agent's team vs the winner.
            if now.outcome == "A_win" and now.agents[who].team == "A":
                r += cfg.exit_win_reward
            if now.outcome == "B_win" and now.agents[who].team == "B":
                r += cfg.guardian_tag_reward
            if now.outcome == "B_win" and now.agents[who].team == "A":
                r += cfg.tagged_penalty
    if now.env_key == "civ":
        if outcome_msg.startswith("picked: wood") or outcome_msg.startswith("picked: ore"):
            r += cfg.resource_pick_reward
        if outcome_msg.startswith("deposited:"):
            r += cfg.deposit_reward
        if now.done:
            if now.outcome == "A_win" and now.agents[who].team == "A":
                r += cfg.base_progress_win_reward
            if now.outcome == "B_win" and now.agents[who].team == "B":
                r += cfg.raider_elim_reward
            if now.outcome == "B_win" and now.agents[who].team == "A":
                r += cfg.builder_elim_penalty
    return float(r)
# -----------------------------
# Policy selection
# -----------------------------
def choose_action(state: WorldState, who: str, stream: int) -> Tuple[str, str, Optional[Tuple[str, int]]]:
    """Pick an action for *who* per its brain; return (action, reason, qinfo).

    *qinfo* is (obs_key, action_index) when a Q-policy acted — the tick loop
    needs it to apply the Q-update once rewards are known — otherwise None.
    """
    a = state.agents[who]
    cfg = state.cfg
    r = rng_for(state.seed, state.step, stream=stream)
    if a.brain == "random":
        act = random_action(state, who)
        return act, "random", None
    if a.brain == "heuristic":
        act = heuristic_action(state, who)
        return act, "heuristic", None
    if cfg.use_q:
        key = obs_key(state, who)
        qtab = state.q_tables.setdefault(who, {})
        qv = q_get(qtab, key)
        a_idx = epsilon_greedy(qv, state.gmetrics.epsilon, r)
        return ACTIONS[a_idx], f"Q eps={state.gmetrics.epsilon:.3f} q={np.round(qv,3).tolist()}", (key, a_idx)
    # Q-brained agent with use_q disabled falls back to the scripted policy.
    act = heuristic_action(state, who)
    return act, "heuristic(fallback)", None
# -----------------------------
# Init / reset
# -----------------------------
def init_state(seed: int, env_key: str) -> WorldState:
    """Build a fresh WorldState for *env_key*; control and POV default to the first agent."""
    g, agents = ENV_BUILDERS[env_key](seed)
    st = WorldState(
        seed=seed,
        step=0,
        env_key=env_key,
        grid=g,
        agents=agents,
        controlled=list(agents.keys())[0],
        pov=list(agents.keys())[0],
        overlay=False,
        done=False,
        outcome="ongoing",
        door_opened_global=False,
        base_progress=0,
        base_target=10,
    )
    st.event_log = [f"Initialized env={env_key} seed={seed}."]
    return st
def reset_episode_keep_learning(state: WorldState, seed: Optional[int] = None) -> WorldState:
    """Start a fresh episode while carrying over cfg, Q-tables and global metrics."""
    next_seed = state.seed if seed is None else seed
    fresh = init_state(int(next_seed), state.env_key)
    # Learning artifacts survive the reset; all per-episode state does not.
    fresh.cfg = state.cfg
    fresh.q_tables = state.q_tables
    fresh.gmetrics = state.gmetrics
    fresh.gmetrics.epsilon = state.gmetrics.epsilon
    return fresh
def wipe_all(seed: int, env_key: str) -> WorldState:
    """Hard reset: fresh episode AND brand-new config, metrics, and Q-tables."""
    st = init_state(seed, env_key)
    st.cfg = TrainConfig()
    st.q_tables = {}
    st.gmetrics = GlobalMetrics(epsilon=st.cfg.epsilon)
    return st
# -----------------------------
# History / branching
# -----------------------------
def snapshot_of(state: WorldState, branch: str) -> Snapshot:
    """Capture a deep-enough copy of *state* for rewind/branching.

    Grid rows and agents are copied, logs are truncated to short tails, and
    the learning artifacts (cfg/q_tables/gmetrics) are excluded entirely.
    """
    return Snapshot(
        branch=branch,
        step=state.step,
        env_key=state.env_key,
        grid=[row[:] for row in state.grid],
        agents={k: asdict(v) for k, v in state.agents.items()},
        done=state.done,
        outcome=state.outcome,
        door_opened_global=state.door_opened_global,
        base_progress=state.base_progress,
        base_target=state.base_target,
        event_tail=state.event_log[-25:],
        trace_tail=state.trace_log[-40:],
        emetrics=asdict(state.emetrics),
    )
def restore_into(state: WorldState, snap: Snapshot) -> WorldState:
    """Restore world fields from *snap* in place; Q-tables/metrics/cfg are untouched.

    NOTE(review): snap.emetrics and the log tails are NOT restored — only an
    event-log line is appended. Confirm this rewind scope is intended.
    """
    state.step = snap.step
    state.env_key = snap.env_key
    state.grid = [row[:] for row in snap.grid]
    state.agents = {k: Agent(**d) for k, d in snap.agents.items()}
    state.done = snap.done
    state.outcome = snap.outcome
    state.door_opened_global = snap.door_opened_global
    state.base_progress = snap.base_progress
    state.base_target = snap.base_target
    state.event_log.append(f"Jumped to snapshot t={snap.step} (branch={snap.branch}).")
    return state
# -----------------------------
# Metrics / dashboard
# -----------------------------
def metrics_dashboard_image(state: WorldState) -> Image.Image:
    """Render a small matplotlib dashboard of global metrics as a PIL image.

    Uses the Agg canvas + buffer_rgba() so it works headless (HF Spaces safe).
    """
    gm = state.gmetrics
    fig = plt.figure(figsize=(7.0, 2.2), dpi=120)
    ax = fig.add_subplot(111)
    x1 = max(1, gm.episodes)
    # Flat line at the current rolling winrate (a snapshot, not a history curve).
    ax.plot([0, x1], [gm.rolling_winrate_A, gm.rolling_winrate_A])
    ax.set_title("Global Metrics Snapshot")
    ax.set_xlabel("Episodes")
    ax.set_ylabel("Rolling winrate Team A")
    ax.set_ylim(-0.05, 1.05)
    ax.grid(True)
    txt = (
        f"env={state.env_key} | eps={gm.epsilon:.3f} | episodes={gm.episodes}\n"
        f"A_wins={gm.wins_teamA} B_wins={gm.wins_teamB} draws={gm.draws} | avg_steps~{gm.avg_steps:.1f}\n"
        f"last_outcome={gm.last_outcome} last_steps={gm.last_steps}"
    )
    ax.text(0.01, 0.05, txt, transform=ax.transAxes, fontsize=8, va="bottom")
    fig.tight_layout()
    canvas = FigureCanvas(fig)
    canvas.draw()
    buf = np.asarray(canvas.buffer_rgba())
    img = Image.fromarray(buf, mode="RGBA").convert("RGB")
    plt.close(fig)  # free the figure so repeated calls don't leak memory
    return img
def action_entropy(counts: Dict[str, int]) -> float:
    """Shannon entropy (bits) of an action-count histogram; 0.0 for empty/zero totals."""
    total = sum(counts.values())
    if total <= 0:
        return 0.0
    probs = np.array([c / total for c in counts.values()], dtype=np.float64)
    # Clip away zeros so log2 never sees 0.
    probs = np.clip(probs, 1e-12, 1.0)
    return float(-(probs * np.log2(probs)).sum())
def agent_scoreboard(state: WorldState) -> str:
    """Format a fixed-width text table of per-agent episode statistics."""
    rows = []
    header = ["agent", "team", "hp", "return", "steps", "entropy", "tiles_disc", "q_states", "inventory"]
    rows.append(header)
    steps = state.emetrics.steps
    for nm, a in state.agents.items():
        ret = state.emetrics.returns.get(nm, 0.0)
        counts = state.emetrics.action_counts.get(nm, {})
        ent = action_entropy(counts)
        td = state.emetrics.tiles_discovered.get(nm, 0)
        qs = len(state.q_tables.get(nm, {}))
        inv = json.dumps(a.inventory, sort_keys=True)
        rows.append([nm, a.team, a.hp, f"{ret:.2f}", steps, f"{ent:.2f}", td, qs, inv])
    # Size each column to its widest cell; add a separator after the header row.
    col_w = [max(len(str(r[i])) for r in rows) for i in range(len(header))]
    lines = []
    for ridx, r in enumerate(rows):
        line = " | ".join(str(r[i]).ljust(col_w[i]) for i in range(len(header)))
        lines.append(line)
        if ridx == 0:
            lines.append("-+-".join("-" * w for w in col_w))
    return "\n".join(lines)
# -----------------------------
# Tick / training
# -----------------------------
def clone_shallow(state: WorldState) -> WorldState:
    """Copy *state* deeply enough for before/after reward comparison in tick().

    Grid rows, agents, and logs are copied; cfg, q_tables, gmetrics and
    emetrics are shared by reference (cloning must not fork learning state).
    """
    return WorldState(
        seed=state.seed,
        step=state.step,
        env_key=state.env_key,
        grid=[row[:] for row in state.grid],
        agents={k: Agent(**asdict(v)) for k, v in state.agents.items()},
        controlled=state.controlled,
        pov=state.pov,
        overlay=state.overlay,
        done=state.done,
        outcome=state.outcome,
        door_opened_global=state.door_opened_global,
        base_progress=state.base_progress,
        base_target=state.base_target,
        event_log=list(state.event_log),
        trace_log=list(state.trace_log),
        cfg=state.cfg,
        q_tables=state.q_tables,
        gmetrics=state.gmetrics,
        emetrics=state.emetrics,
    )
def update_action_counts(state: WorldState, who: str, act: str):
state.emetrics.action_counts.setdefault(who, {})
state.emetrics.action_counts[who][act] = state.emetrics.action_counts[who].get(act, 0) + 1
def tick(state: WorldState, beliefs: Dict[str, np.ndarray], manual_action: Optional[str] = None) -> None:
    """Advance the whole world by one step.

    Phases: choose actions (manual override for the controlled agent) ->
    apply actions + hazards -> collision tags -> belief updates ->
    terminal check -> rewards + Q-updates -> trace logging -> step counter.
    """
    if state.done:
        return
    prev = clone_shallow(state)  # pre-tick snapshot, needed for reward shaping
    chosen: Dict[str, str] = {}
    reasons: Dict[str, str] = {}
    qinfo: Dict[str, Optional[Tuple[str, int]]] = {}
    if manual_action is not None:
        # The controlled agent obeys the UI; everyone else runs their policy.
        chosen[state.controlled] = manual_action
        reasons[state.controlled] = "manual"
        qinfo[state.controlled] = None
    order = list(state.agents.keys())
    for who in order:
        if who in chosen:
            continue
        act, reason, qi = choose_action(state, who, stream=200 + (hash(who) % 1000))
        chosen[who] = act
        reasons[who] = reason
        qinfo[who] = qi
    outcomes: Dict[str, str] = {}
    took_damage: Dict[str, bool] = {nm: False for nm in order}
    # Apply actions sequentially in dict order; hazards resolve per agent.
    for who in order:
        outcomes[who] = apply_action(state, who, chosen[who])
        dmg, msg = resolve_hazards(state, state.agents[who])
        took_damage[who] = dmg
        if msg:
            state.event_log.append(f"t={state.step}: {who} {msg}")
        update_action_counts(state, who, chosen[who])
    for m in resolve_tags(state):
        state.event_log.append(m)
    # Reveal newly-seen tiles on each living agent's belief map.
    for nm, a in state.agents.items():
        if a.hp <= 0:
            continue
        disc = update_belief_for_agent(state, beliefs[nm], a)
        state.emetrics.tiles_discovered[nm] = state.emetrics.tiles_discovered.get(nm, 0) + disc
    check_done(state)
    # Rewards are computed after check_done so terminal bonuses land this tick.
    q_lines = []
    for who in order:
        state.emetrics.returns.setdefault(who, 0.0)
        r = reward_for(prev, state, who, outcomes[who], took_damage[who])
        state.emetrics.returns[who] += r
        if qinfo.get(who) is not None:
            key, a_idx = qinfo[who]
            next_key = obs_key(state, who)
            qtab = state.q_tables.setdefault(who, {})
            old, tgt, new = q_update(qtab, key, a_idx, r, next_key, state.cfg.alpha, state.cfg.gamma)
            q_lines.append(f"{who}: old={old:.3f} tgt={tgt:.3f} new={new:.3f} (a={ACTIONS[a_idx]})")
    trace = f"t={state.step} env={state.env_key} done={state.done} outcome={state.outcome}"
    for who in order:
        a = state.agents[who]
        trace += f" | {who}:{chosen[who]} ({outcomes[who]}) hp={a.hp} [{reasons[who]}]"
    if q_lines:
        trace += " | Q: " + " ; ".join(q_lines)
    state.trace_log.append(trace)
    if len(state.trace_log) > TRACE_MAX:
        state.trace_log = state.trace_log[-TRACE_MAX:]
    state.step += 1
    state.emetrics.steps = state.step
def run_episode(state: WorldState, beliefs: Dict[str, np.ndarray], max_steps: int) -> Tuple[str, int]:
    """Advance the sim until it terminates or *max_steps* is reached.

    Returns (outcome, steps_elapsed).
    """
    while not state.done and state.step < max_steps:
        tick(state, beliefs, manual_action=None)
    return state.outcome, state.step
def update_global_metrics_after_episode(state: WorldState, outcome: str, steps: int):
    """Fold one finished episode into the persistent global metrics.

    Updates win/draw tallies, an exponentially-weighted rolling win-rate for
    team A, a rolling average episode length, and decays exploration epsilon.
    """
    gm = state.gmetrics
    gm.episodes += 1
    gm.last_outcome = outcome
    gm.last_steps = steps
    # Translate the outcome into team A's score, then apply one EMA step.
    if outcome == "A_win":
        gm.wins_teamA += 1
        score_a = 1.0
    elif outcome == "B_win":
        gm.wins_teamB += 1
        score_a = 0.0
    else:
        gm.draws += 1
        score_a = 0.5
    gm.rolling_winrate_A = 0.90 * gm.rolling_winrate_A + 0.10 * score_a
    # Seed the rolling average with the first observed episode length.
    gm.avg_steps = (0.90 * gm.avg_steps + 0.10 * steps) if gm.avg_steps > 0 else float(steps)
    # Anneal epsilon but never drop below the configured floor.
    gm.epsilon = max(state.cfg.epsilon_min, gm.epsilon * state.cfg.epsilon_decay)
def train(state: WorldState, episodes: int, max_steps: int) -> WorldState:
    """Run several headless self-play episodes while keeping learned Q-tables.

    Each episode gets a deterministic seed derived from the base seed and the
    global episode counter, so training runs are reproducible.
    """
    for i in range(episodes):
        # Mix the base seed with the running episode index, masked to 32 bits.
        mixed = state.seed * 1_000_003 + (state.gmetrics.episodes + i) * 97_531
        state = reset_episode_keep_learning(state, seed=int(mixed & 0xFFFFFFFF))
        beliefs = init_beliefs(list(state.agents.keys()))
        outcome, steps = run_episode(state, beliefs, max_steps=max_steps)
        update_global_metrics_after_episode(state, outcome, steps)
    state.event_log.append(
        f"Training: +{episodes} eps | eps={state.gmetrics.epsilon:.3f} | "
        f"A={state.gmetrics.wins_teamA} B={state.gmetrics.wins_teamB} D={state.gmetrics.draws}"
    )
    # Hand back a fresh episode at the base seed so the UI resumes cleanly.
    state = reset_episode_keep_learning(state, seed=state.seed)
    return state
# -----------------------------
# Export / Import
# -----------------------------
def export_run(state: WorldState, branches: Dict[str, List[Snapshot]], active_branch: str, rewind_idx: int) -> str:
    """Serialize the run (world, learning, timeline) to JSON plus a proof hash.

    The result is the payload JSON, a blank line, then a small JSON object
    holding the SHA-256 of the payload text.
    """
    payload: Dict[str, Any] = {
        "seed": state.seed,
        "env_key": state.env_key,
        "controlled": state.controlled,
        "pov": state.pov,
        "overlay": state.overlay,
        "cfg": asdict(state.cfg),
        "gmetrics": asdict(state.gmetrics),
        "q_tables": state.q_tables,
        "branches": {name: [asdict(snap) for snap in snaps] for name, snaps in branches.items()},
        "active_branch": active_branch,
        "rewind_idx": int(rewind_idx),
        "grid": state.grid,
        "door_opened_global": state.door_opened_global,
        "base_progress": state.base_progress,
        "base_target": state.base_target,
    }
    txt = json.dumps(payload, indent=2)
    # Detached integrity hash; `import_run` simply ignores this trailer.
    return txt + "\n\n" + json.dumps({"proof_sha256": hash_sha256(txt)}, indent=2)
def import_run(txt: str) -> Tuple[WorldState, Dict[str, List[Snapshot]], str, int, Dict[str, np.ndarray]]:
    """Rebuild world state, branches, and beliefs from exported JSON text.

    Only the first blank-line-separated section is parsed; the trailing
    proof-hash section (if present) is ignored.
    """
    data = json.loads(txt.strip().split("\n\n")[0])
    st = init_state(int(data.get("seed", 1337)), data.get("env_key", "chase"))
    st.controlled = data.get("controlled", st.controlled)
    st.pov = data.get("pov", st.pov)
    st.overlay = bool(data.get("overlay", False))
    st.grid = data.get("grid", st.grid)
    st.door_opened_global = bool(data.get("door_opened_global", False))
    st.base_progress = int(data.get("base_progress", 0))
    st.base_target = int(data.get("base_target", 10))
    st.cfg = TrainConfig(**data.get("cfg", asdict(st.cfg)))
    st.gmetrics = GlobalMetrics(**data.get("gmetrics", asdict(st.gmetrics)))
    st.q_tables = data.get("q_tables", {})
    # Rehydrate every timeline branch back into Snapshot dataclasses.
    branches: Dict[str, List[Snapshot]] = {
        name: [Snapshot(**s) for s in snaps]
        for name, snaps in data.get("branches", {}).items()
    }
    active = data.get("active_branch", "main")
    r_idx = int(data.get("rewind_idx", 0))
    if branches.get(active):
        # Land on the tip of the active branch so the UI shows a live state.
        st = restore_into(st, branches[active][-1])
        st.event_log.append("Imported run (restored last snapshot).")
    else:
        st.event_log.append("Imported run (no snapshots).")
    beliefs = init_beliefs(list(st.agents.keys()))
    return st, branches, active, r_idx, beliefs
# -----------------------------
# UI helpers
# -----------------------------
def build_views(state: WorldState, beliefs: Dict[str, np.ndarray]) -> Tuple[np.ndarray, Image.Image, Image.Image, Image.Image, Image.Image, str, str, str, str]:
    """Render every UI surface in one pass.

    Returns: POV frame, truth map, controlled-agent belief, other-agent
    belief, metrics dashboard, status text, event log, step trace, scoreboard.
    """
    # Refresh beliefs for living agents before drawing anything.
    for name, agent in state.agents.items():
        if agent.hp > 0:
            update_belief_for_agent(state, beliefs[name], agent)
    pov = raycast_view(state, state.agents[state.pov])
    truth_img = render_topdown(
        np.array(state.grid, dtype=np.int16),
        state.agents,
        f"Truth Map — env={state.env_key} t={state.step} seed={state.seed}",
        True,
    )
    ctrl = state.controlled
    # Pick the first non-controlled agent for the second belief panel,
    # falling back to the controlled agent when it is the only one.
    remaining = [nm for nm in state.agents if nm != ctrl]
    other = remaining[0] if remaining else ctrl
    b_ctrl = render_topdown(beliefs[ctrl], state.agents, f"{ctrl} Belief", True)
    b_other = render_topdown(beliefs[other], state.agents, f"{other} Belief", True)
    dash = metrics_dashboard_image(state)
    status = (
        f"env={state.env_key} | seed={state.seed} | Controlled={state.controlled} | POV={state.pov} | done={state.done} outcome={state.outcome}\n"
        f"Episode steps={state.step} | base_progress={state.base_progress}/{state.base_target} | doors_open={state.door_opened_global}\n"
        f"Global: episodes={state.gmetrics.episodes} | A={state.gmetrics.wins_teamA} B={state.gmetrics.wins_teamB} D={state.gmetrics.draws} "
        f"| winrateA~{state.gmetrics.rolling_winrate_A:.2f} | eps={state.gmetrics.epsilon:.3f}"
    )
    events = "\n".join(state.event_log[-18:])
    trace = "\n".join(state.trace_log[-18:])
    return pov, truth_img, b_ctrl, b_other, dash, status, events, trace, agent_scoreboard(state)
def grid_click_to_tile(evt: gr.SelectData, selected_tile: int, state: WorldState) -> WorldState:
    """Paint `selected_tile` onto the truth grid at the clicked pixel.

    Clicks on the title strip or the outer wall border are ignored and the
    state is returned unchanged.
    """
    x_px, y_px = evt.index
    # The rendered map has a 28px title band above the grid — skip it.
    y_px -= 28
    if y_px < 0:
        return state
    gx, gy = int(x_px // TILE), int(y_px // TILE)
    if not in_bounds(gx, gy):
        return state
    # Border cells are structural walls; they stay un-editable.
    if gx in (0, GRID_W - 1) or gy in (0, GRID_H - 1):
        return state
    state.grid[gy][gx] = selected_tile
    state.event_log.append(f"t={state.step}: Tile ({gx},{gy}) -> {TILE_NAMES.get(selected_tile)}")
    return state
# -----------------------------
# Gradio app
# -----------------------------
TITLE = "ZEN AgentLab — Agent POV + Autoplay Multi-Agent Sims"
with gr.Blocks(title=TITLE) as demo:
gr.Markdown(
f"## {TITLE}\n"
"**Press Start Autoplay** to watch the sim unfold live. Interject anytime with manual actions or edits.\n"
"Use **Cinematic Run** for an instant full-episode spectacle. No background timers beyond the UI autoplay."
)
# Session-scoped state: world, timeline branches, active branch name,
# rewind cursor, per-agent belief maps, and the autoplay on/off flag.
st0 = init_state(1337, "chase")
st = gr.State(st0)
branches = gr.State({"main": [snapshot_of(st0, "main")]})
active_branch = gr.State("main")
rewind_idx = gr.State(0)
beliefs = gr.State(init_beliefs(list(st0.agents.keys())))
autoplay_on = gr.State(False)
# --- Display panels (POV, maps, dashboard, text logs) ---
with gr.Row():
pov_img = gr.Image(label="POV (Pseudo-3D)", type="numpy", width=VIEW_W, height=VIEW_H)
with gr.Column():
status = gr.Textbox(label="Status", lines=3)
scoreboard = gr.Textbox(label="Agent Scoreboard", lines=8)
with gr.Row():
truth = gr.Image(label="Truth Map (click to edit tiles)", type="pil")
belief_a = gr.Image(label="Belief (Controlled)", type="pil")
belief_b = gr.Image(label="Belief (Other)", type="pil")
with gr.Row():
dash = gr.Image(label="Metrics Dashboard", type="pil")
with gr.Row():
events = gr.Textbox(label="Event Log", lines=10)
trace = gr.Textbox(label="Step Trace", lines=10)
# --- Left column: presets, env/seed pickers, autoplay, manual controls ---
with gr.Row():
with gr.Column(scale=2):
gr.Markdown("### Quick Start (Examples)")
# NOTE(review): the example rows carry two values each but inputs=[]
# wires them to nothing — the table is display-only; confirm intended.
examples = gr.Examples(
examples=[
["chase", 1337],
["vault", 2024],
["civ", 777],
],
inputs=[],
label="",
)
gr.Markdown("Pick an environment + seed below, then click **Apply**.")
with gr.Row():
env_pick = gr.Radio(
choices=[("Chase (Predator vs Prey)", "chase"),
("CoopVault (team vs guardian)", "vault"),
("MiniCiv (build + raid)", "civ")],
value="chase",
label="Environment"
)
seed_box = gr.Number(value=1337, precision=0, label="Seed")
with gr.Row():
btn_apply_env_seed = gr.Button("Apply (Env + Seed)")
btn_reset_ep = gr.Button("Reset Episode (keep learning)")
gr.Markdown("### Autoplay + Spectacle")
with gr.Row():
autoplay_speed = gr.Slider(0.05, 1.0, value=0.20, step=0.05, label="Autoplay step interval (seconds)")
with gr.Row():
btn_autoplay_start = gr.Button("▶ Start Autoplay")
btn_autoplay_stop = gr.Button("⏸ Stop Autoplay")
with gr.Row():
cinematic_steps = gr.Number(value=350, precision=0, label="Cinematic max steps")
btn_cinematic = gr.Button("🎬 Cinematic Run (Full Episode)")
gr.Markdown("### Manual Controls (Interject Anytime)")
with gr.Row():
btn_L = gr.Button("L")
btn_F = gr.Button("F")
btn_R = gr.Button("R")
btn_I = gr.Button("I (Interact)")
with gr.Row():
btn_tick = gr.Button("Tick")
run_steps = gr.Number(value=25, label="Run N steps", precision=0)
btn_run = gr.Button("Run")
with gr.Row():
btn_toggle_control = gr.Button("Toggle Controlled")
btn_toggle_pov = gr.Button("Toggle POV")
overlay = gr.Checkbox(False, label="Overlay reticle")
tile_pick = gr.Radio(
choices=[(TILE_NAMES[k], k) for k in [EMPTY, WALL, FOOD, NOISE, DOOR, TELE, KEY, EXIT, ARTIFACT, HAZARD, WOOD, ORE, MEDKIT, SWITCH, BASE]],
value=WALL,
label="Paint tile type"
)
# --- Right column: Q-learning hyperparameters and training actions ---
with gr.Column(scale=3):
gr.Markdown("### Training Controls (Tabular Q-learning)")
use_q = gr.Checkbox(True, label="Use Q-learning (agents with brain='q')")
alpha = gr.Slider(0.01, 0.5, value=0.15, step=0.01, label="alpha")
gamma = gr.Slider(0.5, 0.99, value=0.95, step=0.01, label="gamma")
eps = gr.Slider(0.0, 0.5, value=0.10, step=0.01, label="epsilon")
eps_decay = gr.Slider(0.90, 0.999, value=0.995, step=0.001, label="epsilon decay")
eps_min = gr.Slider(0.0, 0.2, value=0.02, step=0.01, label="epsilon min")
episodes = gr.Number(value=50, label="Train episodes", precision=0)
max_steps = gr.Number(value=260, label="Max steps/episode", precision=0)
btn_train = gr.Button("Train")
btn_reset_all = gr.Button("Reset ALL (wipe Q + metrics)")
# --- Bottom row: timeline rewind/branching and JSON export/import ---
with gr.Row():
with gr.Column(scale=2):
gr.Markdown("### Timeline + Branching")
rewind = gr.Slider(0, 0, value=0, step=1, label="Rewind index (active branch)")
btn_jump = gr.Button("Jump to index")
new_branch_name = gr.Textbox(value="fork1", label="New branch name")
btn_fork = gr.Button("Fork from current rewind")
with gr.Column(scale=2):
branch_pick = gr.Dropdown(choices=["main"], value="main", label="Active branch")
btn_set_branch = gr.Button("Set Active Branch")
with gr.Column(scale=3):
export_box = gr.Textbox(label="Export JSON (+ proof hash)", lines=8)
btn_export = gr.Button("Export")
import_box = gr.Textbox(label="Import JSON", lines=8)
btn_import = gr.Button("Import")
# Autoplay timer (inactive by default)
timer = gr.Timer(value=0.20, active=False)
# ---------- glue ----------
def refresh(state: WorldState, branches_d: Dict[str, List[Snapshot]], active: str, bel: Dict[str, np.ndarray], r: int):
    """Recompute all UI panels plus the rewind slider and branch dropdowns.

    Returns 13 values: nine view panels, the slider update, the clamped
    rewind index, and the branch-list dropdown update twice (the dropdown
    is wired as two separate outputs downstream).
    """
    snaps = branches_d.get(active, [])
    hi = max(0, len(snaps) - 1)
    # Clamp the requested rewind index into the valid snapshot range.
    idx = min(max(int(r), 0), hi)
    pov, tr, ba, bb, dimg, stxt, etxt, ttxt, sb = build_views(state, bel)
    names = sorted(branches_d)
    return (
        pov, tr, ba, bb, dimg, stxt, sb, etxt, ttxt,
        gr.update(maximum=hi, value=idx), idx,
        gr.update(choices=names, value=active),
        gr.update(choices=names, value=active),
    )
def push_hist(state: WorldState, branches_d: Dict[str, List[Snapshot]], active: str) -> Dict[str, List[Snapshot]]:
    """Append a snapshot of `state` to the active branch, capped at MAX_HISTORY."""
    snaps = branches_d.setdefault(active, [])
    snaps.append(snapshot_of(state, active))
    # Evict from the front once the branch exceeds its history budget.
    while len(snaps) > MAX_HISTORY:
        snaps.pop(0)
    return branches_d
def set_cfg(state: WorldState, use_q_v: bool, a: float, g: float, e: float, ed: float, emin: float) -> WorldState:
    """Copy UI slider values into the training config.

    The live exploration rate deliberately lands on `gmetrics.epsilon`
    (not on cfg) because training decays it there per episode.
    """
    cfg = state.cfg
    cfg.use_q = bool(use_q_v)
    cfg.alpha = float(a)
    cfg.gamma = float(g)
    cfg.epsilon_decay = float(ed)
    cfg.epsilon_min = float(emin)
    state.gmetrics.epsilon = float(e)
    return state
def _advance_once(state, branches_d, active, bel, action):
    """Tick once (optionally forcing a manual action), snapshot, and refresh."""
    tick(state, bel, manual_action=action)
    branches_d = push_hist(state, branches_d, active)
    r = len(branches_d[active]) - 1
    return refresh(state, branches_d, active, bel, r) + (state, branches_d, active, bel, r)

def do_manual(state, branches_d, active, bel, r, act):
    """Manual control: `act` drives the controlled agent for this tick."""
    return _advance_once(state, branches_d, active, bel, act)

def do_tick(state, branches_d, active, bel, r):
    """Advance one autonomous tick (all agents choose their own actions)."""
    return _advance_once(state, branches_d, active, bel, None)
def do_run(state, branches_d, active, bel, r, n):
    """Run up to `n` autonomous ticks (stopping early when the episode ends),
    then record a single snapshot for the whole burst."""
    remaining = max(1, int(n))
    while remaining > 0 and not state.done:
        tick(state, bel, manual_action=None)
        remaining -= 1
    branches_d = push_hist(state, branches_d, active)
    r = len(branches_d[active]) - 1
    return refresh(state, branches_d, active, bel, r) + (state, branches_d, active, bel, r)
def toggle_control(state, branches_d, active, bel, r):
    """Hand manual control to the next agent in roster order."""
    roster = list(state.agents.keys())
    state.controlled = roster[(roster.index(state.controlled) + 1) % len(roster)]
    state.event_log.append(f"Controlled -> {state.controlled}")
    branches_d = push_hist(state, branches_d, active)
    r = len(branches_d[active]) - 1
    return refresh(state, branches_d, active, bel, r) + (state, branches_d, active, bel, r)

def toggle_pov(state, branches_d, active, bel, r):
    """Cycle the first-person camera to the next agent in roster order."""
    roster = list(state.agents.keys())
    state.pov = roster[(roster.index(state.pov) + 1) % len(roster)]
    state.event_log.append(f"POV -> {state.pov}")
    branches_d = push_hist(state, branches_d, active)
    r = len(branches_d[active]) - 1
    return refresh(state, branches_d, active, bel, r) + (state, branches_d, active, bel, r)
def set_overlay(state, branches_d, active, bel, r, ov):
    """Toggle the POV reticle overlay (purely cosmetic — no snapshot taken)."""
    state.overlay = bool(ov)
    return refresh(state, branches_d, active, bel, r) + (state, branches_d, active, bel, r)

def click_truth(tile, state, branches_d, active, bel, r, evt: gr.SelectData):
    """Paint the picked tile type onto the clicked map cell, then snapshot."""
    state = grid_click_to_tile(evt, int(tile), state)
    branches_d = push_hist(state, branches_d, active)
    r = len(branches_d[active]) - 1
    return refresh(state, branches_d, active, bel, r) + (state, branches_d, active, bel, r)
def jump(state, branches_d, active, bel, r, idx):
    """Restore the snapshot at `idx` on the active branch (no-op when empty)."""
    snaps = branches_d.get(active, [])
    if snaps:
        # Clamp into range, then time-travel the world state to that point.
        r = min(max(int(idx), 0), len(snaps) - 1)
        state = restore_into(state, snaps[r])
    return refresh(state, branches_d, active, bel, r) + (state, branches_d, active, bel, r)
def fork_branch(state, branches_d, active, bel, r, new_name):
    """Copy the active branch's history up to the rewind point into a new branch
    and switch to it."""
    new_name = ((new_name or "").strip() or "fork").replace(" ", "_")
    snaps = branches_d.get(active, [])
    if not snaps:
        branches_d[new_name] = [snapshot_of(state, new_name)]
    else:
        cut = min(max(int(r), 0), len(snaps) - 1)
        # Deep-copy the snapshots so edits on the fork never touch the parent.
        branches_d[new_name] = [Snapshot(**asdict(s)) for s in snaps[:cut + 1]]
        state = restore_into(state, branches_d[new_name][-1])
    active = new_name
    state.event_log.append(f"Forked branch -> {new_name}")
    branches_d = push_hist(state, branches_d, active)
    r = len(branches_d[active]) - 1
    return refresh(state, branches_d, active, bel, r) + (state, branches_d, active, bel, r)
def set_active_branch(state, branches_d, active, bel, r, br):
    """Switch to branch `br`, creating it from the current state if missing."""
    active = br or "main"
    if active not in branches_d:
        branches_d[active] = [snapshot_of(state, active)]
    if branches_d[active]:
        # Land on the branch tip and rebuild beliefs from scratch.
        state = restore_into(state, branches_d[active][-1])
        bel = init_beliefs(list(state.agents.keys()))
    r = len(branches_d[active]) - 1
    return refresh(state, branches_d, active, bel, r) + (state, branches_d, active, bel, r)
def apply_env_seed(state, branches_d, active, bel, r, env_key, seed_val):
    """Rebuild the world for a new env/seed while carrying learning forward.

    Q-tables, training config, and global metrics survive the swap; the
    timeline restarts on a fresh 'main' branch.
    """
    env_key = env_key or "chase"
    seed_val = state.seed if seed_val is None else int(seed_val)
    kept_cfg, kept_q, kept_gm = state.cfg, state.q_tables, state.gmetrics
    state = init_state(seed_val, env_key)
    state.cfg, state.q_tables, state.gmetrics = kept_cfg, kept_q, kept_gm
    bel = init_beliefs(list(state.agents.keys()))
    active = "main"
    branches_d = {"main": [snapshot_of(state, "main")]}
    r = 0
    return refresh(state, branches_d, active, bel, r) + (state, branches_d, active, bel, r)
def reset_ep(state, branches_d, active, bel, r):
    """Restart the episode at the current seed; Q-tables and metrics are kept.

    NOTE(review): the timeline dict is rebuilt with only the active branch,
    so any other branches are dropped here — confirm that is intentional.
    """
    state = reset_episode_keep_learning(state, seed=state.seed)
    bel = init_beliefs(list(state.agents.keys()))
    branches_d = {active: [snapshot_of(state, active)]}
    r = 0
    return refresh(state, branches_d, active, bel, r) + (state, branches_d, active, bel, r)
def reset_all(state, branches_d, active, bel, r, env_key, seed_val):
    """Hard reset: wipe Q-tables and metrics, then restart on a fresh 'main'."""
    env_key = env_key or state.env_key
    seed_val = state.seed if seed_val is None else int(seed_val)
    state = wipe_all(seed=seed_val, env_key=env_key)
    bel = init_beliefs(list(state.agents.keys()))
    active = "main"
    branches_d = {"main": [snapshot_of(state, "main")]}
    r = 0
    return refresh(state, branches_d, active, bel, r) + (state, branches_d, active, bel, r)
def do_train(state, branches_d, active, bel, r,
             use_q_v, a, g, e, ed, emin,
             eps_count, max_s):
    """Apply the slider config, run headless training, and reset the timeline."""
    state = set_cfg(state, use_q_v, a, g, e, ed, emin)
    # Guard against degenerate UI values (0 episodes / tiny step caps).
    state = train(state, episodes=max(1, int(eps_count)), max_steps=max(10, int(max_s)))
    bel = init_beliefs(list(state.agents.keys()))
    branches_d = {"main": [snapshot_of(state, "main")]}
    active = "main"
    r = 0
    return refresh(state, branches_d, active, bel, r) + (state, branches_d, active, bel, r)
def cinematic_run(state, branches_d, active, bel, r, max_s):
    """Reset the episode and play it to the end (or `max_s` steps) in one call."""
    budget = max(10, int(max_s))
    # Start from a clean episode so the whole arc is visible in one go.
    state = reset_episode_keep_learning(state, seed=state.seed)
    bel = init_beliefs(list(state.agents.keys()))
    while not state.done and state.step < budget:
        tick(state, bel, manual_action=None)
    state.event_log.append(f"Cinematic finished: outcome={state.outcome} steps={state.step}")
    branches_d = push_hist(state, branches_d, active)
    r = len(branches_d[active]) - 1
    return refresh(state, branches_d, active, bel, r) + (state, branches_d, active, bel, r)
def export_fn(state, branches_d, active, r):
    """Serialize the whole run to JSON text for the export box."""
    return export_run(state, branches_d, active, int(r))

def import_fn(txt):
    """Parse pasted JSON and rebuild the entire UI state from it."""
    state, branches_d, active, r, bel = import_run(txt)
    # Guarantee the active branch has at least one snapshot to stand on.
    if not branches_d.setdefault(active, []):
        branches_d[active].append(snapshot_of(state, active))
    return refresh(state, branches_d, active, bel, r) + (state, branches_d, active, bel, r)
# ---- Autoplay control ----
def autoplay_start(state, branches_d, active, bel, r, interval_s):
    """Arm the UI timer at the requested interval and raise the autoplay flag."""
    return (
        gr.update(value=float(interval_s), active=True),
        True,
        state, branches_d, active, bel, r,
    )

def autoplay_stop(state, branches_d, active, bel, r):
    """Disarm the timer and lower the autoplay flag; state passes through."""
    return (
        gr.update(active=False),
        False,
        state, branches_d, active, bel, r,
    )
def autoplay_tick(state, branches_d, active, bel, r, is_on: bool):
    """One timer beat: advance the sim if autoplay is on; self-stop when done.

    Returns the full refresh payload plus the (possibly updated) autoplay
    flag and a timer-component update.
    """
    if not is_on:
        # Flag is down: just repaint, leaving the timer component untouched.
        out = refresh(state, branches_d, active, bel, r)
        return out + (state, branches_d, active, bel, r, is_on, gr.update())
    if not state.done:
        tick(state, bel, manual_action=None)
        branches_d = push_hist(state, branches_d, active)
        r = len(branches_d[active]) - 1
    if state.done:
        # Episode over: repaint once more and switch the timer off.
        out = refresh(state, branches_d, active, bel, r)
        return out + (state, branches_d, active, bel, r, False, gr.update(active=False))
    out = refresh(state, branches_d, active, bel, r)
    return out + (state, branches_d, active, bel, r, True, gr.update())
# ---- wiring ----
# Shared output list for every handler: 13 values from refresh() followed by
# the 5 gr.State holders. rewind_idx and branch_pick each appear twice on
# purpose — refresh() emits the slider update + raw index, and the same
# dropdown update for two wired slots; positions must match exactly.
common_outputs = [
pov_img, truth, belief_a, belief_b, dash, status, scoreboard, events, trace,
rewind, rewind_idx, branch_pick, branch_pick,
st, branches, active_branch, beliefs, rewind_idx
]
# Manual action buttons: each forwards a fixed action code to do_manual.
btn_L.click(lambda s,b,a,bel,r: do_manual(s,b,a,bel,r,"L"),
inputs=[st, branches, active_branch, beliefs, rewind_idx],
outputs=common_outputs, queue=True)
btn_F.click(lambda s,b,a,bel,r: do_manual(s,b,a,bel,r,"F"),
inputs=[st, branches, active_branch, beliefs, rewind_idx],
outputs=common_outputs, queue=True)
btn_R.click(lambda s,b,a,bel,r: do_manual(s,b,a,bel,r,"R"),
inputs=[st, branches, active_branch, beliefs, rewind_idx],
outputs=common_outputs, queue=True)
btn_I.click(lambda s,b,a,bel,r: do_manual(s,b,a,bel,r,"I"),
inputs=[st, branches, active_branch, beliefs, rewind_idx],
outputs=common_outputs, queue=True)
btn_tick.click(do_tick,
inputs=[st, branches, active_branch, beliefs, rewind_idx],
outputs=common_outputs, queue=True)
btn_run.click(do_run,
inputs=[st, branches, active_branch, beliefs, rewind_idx, run_steps],
outputs=common_outputs, queue=True)
btn_toggle_control.click(toggle_control,
inputs=[st, branches, active_branch, beliefs, rewind_idx],
outputs=common_outputs, queue=True)
btn_toggle_pov.click(toggle_pov,
inputs=[st, branches, active_branch, beliefs, rewind_idx],
outputs=common_outputs, queue=True)
overlay.change(set_overlay,
inputs=[st, branches, active_branch, beliefs, rewind_idx, overlay],
outputs=common_outputs, queue=True)
# Map editing: clicking the truth image paints the selected tile type.
truth.select(click_truth,
inputs=[tile_pick, st, branches, active_branch, beliefs, rewind_idx],
outputs=common_outputs, queue=True)
# Timeline controls: rewind, fork, and branch switching.
btn_jump.click(jump,
inputs=[st, branches, active_branch, beliefs, rewind_idx, rewind],
outputs=common_outputs, queue=True)
btn_fork.click(fork_branch,
inputs=[st, branches, active_branch, beliefs, rewind_idx, new_branch_name],
outputs=common_outputs, queue=True)
btn_set_branch.click(set_active_branch,
inputs=[st, branches, active_branch, beliefs, rewind_idx, branch_pick],
outputs=common_outputs, queue=True)
btn_apply_env_seed.click(apply_env_seed,
inputs=[st, branches, active_branch, beliefs, rewind_idx, env_pick, seed_box],
outputs=common_outputs, queue=True)
btn_reset_ep.click(reset_ep,
inputs=[st, branches, active_branch, beliefs, rewind_idx],
outputs=common_outputs, queue=True)
btn_reset_all.click(reset_all,
inputs=[st, branches, active_branch, beliefs, rewind_idx, env_pick, seed_box],
outputs=common_outputs, queue=True)
btn_train.click(do_train,
inputs=[st, branches, active_branch, beliefs, rewind_idx,
use_q, alpha, gamma, eps, eps_decay, eps_min,
episodes, max_steps],
outputs=common_outputs, queue=True)
btn_cinematic.click(cinematic_run,
inputs=[st, branches, active_branch, beliefs, rewind_idx, cinematic_steps],
outputs=common_outputs, queue=True)
btn_export.click(export_fn, inputs=[st, branches, active_branch, rewind_idx], outputs=[export_box], queue=True)
btn_import.click(import_fn,
inputs=[import_box],
outputs=common_outputs, queue=True)
# Autoplay start/stop wires
btn_autoplay_start.click(
autoplay_start,
inputs=[st, branches, active_branch, beliefs, rewind_idx, autoplay_speed],
outputs=[timer, autoplay_on, st, branches, active_branch, beliefs, rewind_idx],
queue=True
)
btn_autoplay_stop.click(
autoplay_stop,
inputs=[st, branches, active_branch, beliefs, rewind_idx],
outputs=[timer, autoplay_on, st, branches, active_branch, beliefs, rewind_idx],
queue=True
)
# Timer tick: step and update UI; auto-stop when done
timer.tick(
autoplay_tick,
inputs=[st, branches, active_branch, beliefs, rewind_idx, autoplay_on],
outputs=[
pov_img, truth, belief_a, belief_b, dash, status, scoreboard, events, trace,
rewind, rewind_idx, branch_pick, branch_pick,
st, branches, active_branch, beliefs, rewind_idx,
autoplay_on, timer
],
queue=True
)
# Initial paint: refresh returns only the 13 view/control values here,
# so the state holders are not listed as outputs.
demo.load(
refresh,
inputs=[st, branches, active_branch, beliefs, rewind_idx],
outputs=[
pov_img, truth, belief_a, belief_b, dash, status, scoreboard, events, trace,
rewind, rewind_idx, branch_pick, branch_pick
],
queue=True
)
# Disable SSR for HF stability
demo.queue().launch(ssr_mode=False)