diff --git "a/app.py" "b/app.py" --- "a/app.py" +++ "b/app.py" @@ -1,271 +1,139 @@ -import json +# app.py import math -import hashlib -from dataclasses import dataclass, asdict +import json +import base64 +import random +from dataclasses import dataclass, asdict, field from typing import Dict, List, Tuple, Optional, Any import numpy as np -from PIL import Image, ImageDraw - -import matplotlib.pyplot as plt -from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas +from PIL import Image, ImageDraw, ImageFont import gradio as gr # ============================================================ -# ZEN AgentLab — Agent POV + Multi-Agent Mini-Sim Arena -# -# Fix included: -# - Matplotlib rendering uses FigureCanvas.buffer_rgba() (HF-safe) -# -# Features: -# - Deterministic gridworld + first-person raycast POV -# - Multiple environments (Chase / CoopVault / MiniCiv) -# - Click-to-edit tiles + pickups + hazards + simple combat tags -# - Full step trace: obs -> action -> reward -> (optional) Q-update -# - Branching timelines (rewind + fork) -# - Batch training (tabular Q-learning) + metrics dashboard -# - Export/import full runs + SHA256 proof hash +# ZEN AgentLab++ — Animated Multi-Map Agent Simulation Arena +# ============================================================ +# Goals: +# - Working, automated "autoplay" simulation (timer-driven) +# - Multiple gameplay types (Predator/Prey "Pacman", CTF, Treasure, Resource Raid) +# - Multiple maps/courses (hand-crafted + procedural mazes) +# - "Cool" UI/UX + animations in the actual environment: +# * Smooth animated top-down via SVG + CSS transitions (browser-side animation) +# * Optional pseudo-3D "POV" panel (simple raycast look) +# * Mini objective HUD + event highlights +# - Fully self-contained: just this app.py + requirements.txt # -# HF Spaces compatible: no timers, no fn_kwargs +# NOTE: The SVG renderer is the secret weapon: +# It updates positions each tick; CSS transitions animate movement smoothly +# without generating tons of frames server-side. # ============================================================ # ----------------------------- -# Global config (shared) +# Grid + Render Config # ----------------------------- -GRID_W, GRID_H = 21, 15 -TILE = 22 +GRID_W, GRID_H = 29, 19 # map resolution (tiles) +TILE = 24 # pixels per tile for SVG +HUD_H = 64 # HUD header height (px) +SVG_W = GRID_W * TILE +SVG_H = GRID_H * TILE + HUD_H -VIEW_W, VIEW_H = 640, 360 -RAY_W = 320 -FOV_DEG = 78 -MAX_DEPTH = 20 +VIEW_W, VIEW_H = 560, 315 # pseudo-3D POV panel +FOV_DEG = 74 +MAX_DEPTH = 22 DIRS = [(1, 0), (0, 1), (-1, 0), (0, -1)] ORI_DEG = [0, 90, 180, 270] +# ----------------------------- # Tiles +# ----------------------------- EMPTY = 0 WALL = 1 -FOOD = 2 -NOISE = 3 -DOOR = 4 -TELE = 5 -KEY = 6 -EXIT = 7 -ARTIFACT = 8 -HAZARD = 9 -WOOD = 10 -ORE = 11 -MEDKIT = 12 -SWITCH = 13 -BASE = 14 +PELLET = 2 +POWER = 3 +FLAG_A = 4 +FLAG_B = 5 +TREASURE = 6 +BASE_A = 7 +BASE_B = 8 +RESOURCE = 9 +HAZARD = 10 +GATE = 11 TILE_NAMES = { EMPTY: "Empty", WALL: "Wall", - FOOD: "Food", - NOISE: "Noise", - DOOR: "Door", - TELE: "Teleporter", - KEY: "Key", - EXIT: "Exit", - ARTIFACT: "Artifact", + PELLET: "Pellet", + POWER: "Power", + FLAG_A: "Flag A", + FLAG_B: "Flag B", + TREASURE: "Treasure", + BASE_A: "Base A", + BASE_B: "Base B", + RESOURCE: "Resource", HAZARD: "Hazard", - WOOD: "Wood", - ORE: "Ore", - MEDKIT: "Medkit", - SWITCH: "Switch", - BASE: "Base", + GATE: "Gate", } +# Palette (kept consistent / readable) +COL_BG = "#0b1020" +COL_PANEL = "#0f1733" +COL_GRIDLINE = "#121a3b" +COL_WALL = "#cdd2e6" +COL_EMPTY = "#19214a" +COL_PELLET = "#ffd17a" +COL_POWER = "#ff7ad9" +COL_FLAG_A = "#7affc8" +COL_FLAG_B = "#ff7a7a" +COL_TREASURE = "#ffb86b" +COL_BASE_A = "#a0ffd9" +COL_BASE_B = "#ffb0b0" +COL_RESOURCE = "#9ab0ff" +COL_HAZARD = "#ff3b3b" +COL_GATE = "#7ad9ff" + AGENT_COLORS = { - "Predator": (255, 120, 90), - "Prey": (120, 255, 160), - "Scout": (120, 190, 255), - "Alpha": (255, 205, 120), - "Bravo": (160, 210, 255), - "Guardian": (255, 120, 220), - "BuilderA": (140, 255, 200), - "BuilderB": (160, 200, 255), - "Raider": (255, 160, 120), + "Predator": "#ff6d6d", + "Prey": "#6dffb0", + "Ghost1": "#ff7ad9", + "Ghost2": "#7ad9ff", + "RunnerA": "#ffd17a", + "RunnerB": "#9ab0ff", + "GuardA": "#7affc8", + "GuardB": "#ffb0b0", + "MinerA": "#a0ffd9", + "MinerB": "#c7d2fe", + "Raider": "#ff9b6b", } -SKY = np.array([14, 16, 26], dtype=np.uint8) -FLOOR_NEAR = np.array([24, 26, 40], dtype=np.uint8) -FLOOR_FAR = np.array([10, 11, 18], dtype=np.uint8) -WALL_BASE = np.array([210, 210, 225], dtype=np.uint8) -WALL_SIDE = np.array([150, 150, 170], dtype=np.uint8) -DOOR_COL = np.array([140, 210, 255], dtype=np.uint8) - -# Small action space for tabular stability -ACTIONS = ["L", "F", "R", "I"] # interact - # ----------------------------- -# Deterministic RNG +# Utility # ----------------------------- -def rng_for(seed: int, step: int, stream: int = 0) -> np.random.Generator: - mix = (seed * 1_000_003) ^ (step * 9_999_937) ^ (stream * 97_531) - return np.random.default_rng(mix & 0xFFFFFFFFFFFFFFFF) - -# ----------------------------- -# Data structures -# ----------------------------- -@dataclass -class Agent: - name: str - x: int - y: int - ori: int - hp: int = 10 - energy: int = 100 - team: str = "A" - brain: str = "q" # q | heuristic | random - inventory: Dict[str, int] = None - - def __post_init__(self): - if self.inventory is None: - self.inventory = {} - -@dataclass -class TrainConfig: - use_q: bool = True - alpha: float = 0.15 - gamma: float = 0.95 - epsilon: float = 0.10 - epsilon_min: float = 0.02 - epsilon_decay: float = 0.995 - - # generic shaping - step_penalty: float = -0.01 - explore_reward: float = 0.015 - damage_penalty: float = -0.20 - heal_reward: float = 0.10 - - # chase - chase_close_coeff: float = 0.03 - chase_catch_reward: float = 3.0 - chase_escaped_reward: float = 0.2 - chase_caught_penalty: float = -3.0 - food_reward: float = 0.6 - - # vault - artifact_pick_reward: float = 1.2 - exit_win_reward: float = 3.0 - guardian_tag_reward: float = 2.0 - tagged_penalty: float = -2.0 - switch_reward: float = 0.8 - key_reward: float = 0.4 - - # civ - resource_pick_reward: float = 0.15 - deposit_reward: float = 0.4 - base_progress_win_reward: float = 3.5 - raider_elim_reward: float = 2.0 - builder_elim_penalty: float = -2.0 - -@dataclass -class GlobalMetrics: - episodes: int = 0 - wins_teamA: int = 0 - wins_teamB: int = 0 - draws: int = 0 - avg_steps: float = 0.0 - rolling_winrate_A: float = 0.0 - epsilon: float = 0.10 - last_outcome: str = "init" - last_steps: int = 0 - -@dataclass -class EpisodeMetrics: - steps: int = 0 - returns: Dict[str, float] = None - action_counts: Dict[str, Dict[str, int]] = None - tiles_discovered: Dict[str, int] = None - - def __post_init__(self): - if self.returns is None: - self.returns = {} - if self.action_counts is None: - self.action_counts = {} - if self.tiles_discovered is None: - self.tiles_discovered = {} - -@dataclass -class WorldState: - seed: int - step: int - env_key: str - grid: List[List[int]] - agents: Dict[str, Agent] - - controlled: str - pov: str - overlay: bool - - done: bool - outcome: str # A_win | B_win | draw | ongoing - - # env state - door_opened_global: bool = False - base_progress: int = 0 - base_target: int = 10 - - # instrumentation - event_log: List[str] = None - trace_log: List[str] = None - - # learning - cfg: TrainConfig = None - q_tables: Dict[str, Dict[str, List[float]]] = None - gmetrics: GlobalMetrics = None - emetrics: EpisodeMetrics = None - - def __post_init__(self): - if self.event_log is None: - self.event_log = [] - if self.trace_log is None: - self.trace_log = [] - if self.cfg is None: - self.cfg = TrainConfig() - if self.q_tables is None: - self.q_tables = {} - if self.gmetrics is None: - self.gmetrics = GlobalMetrics(epsilon=self.cfg.epsilon) - if self.emetrics is None: - self.emetrics = EpisodeMetrics() - -@dataclass -class Snapshot: - branch: str - step: int - env_key: str - grid: List[List[int]] - agents: Dict[str, Dict[str, Any]] - done: bool - outcome: str - door_opened_global: bool - base_progress: int - base_target: int - event_tail: List[str] - trace_tail: List[str] - emetrics: Dict[str, Any] +def clamp(v, lo, hi): + return lo if v < lo else hi if v > hi else v -# ----------------------------- -# Helpers -# ----------------------------- def in_bounds(x: int, y: int) -> bool: return 0 <= x < GRID_W and 0 <= y < GRID_H -def is_blocking(tile: int, door_open: bool = False) -> bool: - if tile == WALL: - return True - if tile == DOOR and not door_open: - return True - return False +def manhattan(a: Tuple[int, int], b: Tuple[int, int]) -> int: + return abs(a[0] - b[0]) + abs(a[1] - b[1]) + +def rng(seed: int) -> random.Random: + r = random.Random() + r.seed(seed & 0xFFFFFFFF) + return r -def manhattan_xy(ax: int, ay: int, bx: int, by: int) -> int: - return abs(ax - bx) + abs(ay - by) +def grid_copy(g: List[List[int]]) -> List[List[int]]: + return [row[:] for row in g] + +def find_all(g: List[List[int]], tile: int) -> List[Tuple[int, int]]: + out = [] + for y in range(GRID_H): + for x in range(GRID_W): + if g[y][x] == tile: + out.append((x, y)) + return out def bresenham_los(grid: List[List[int]], x0: int, y0: int, x1: int, y1: int) -> bool: dx = abs(x1 - x0) @@ -288,1465 +156,1507 @@ def bresenham_los(grid: List[List[int]], x0: int, y0: int, x1: int, y1: int) -> err += dx y += sy -def within_fov(observer: Agent, tx: int, ty: int, fov_deg: float = FOV_DEG) -> bool: - dx = tx - observer.x - dy = ty - observer.y +def within_fov(ax: int, ay: int, ori: int, tx: int, ty: int, fov_deg: float = FOV_DEG) -> bool: + dx = tx - ax + dy = ty - ay if dx == 0 and dy == 0: return True - angle = math.degrees(math.atan2(dy, dx)) % 360 - facing = ORI_DEG[observer.ori] - diff = (angle - facing + 540) % 360 - 180 + ang = (math.degrees(math.atan2(dy, dx)) % 360) + facing = ORI_DEG[ori] + diff = (ang - facing + 540) % 360 - 180 return abs(diff) <= (fov_deg / 2) -def visible(state: WorldState, observer: Agent, target: Agent) -> bool: - if not within_fov(observer, target.x, target.y, FOV_DEG): - return False - return bresenham_los(state.grid, observer.x, observer.y, target.x, target.y) - -def hash_sha256(txt: str) -> str: - return hashlib.sha256(txt.encode("utf-8")).hexdigest() - # ----------------------------- -# Beliefs / fog-of-war +# Data Models # ----------------------------- -def init_beliefs(agent_names: List[str]) -> Dict[str, np.ndarray]: - return {nm: (-1 * np.ones((GRID_H, GRID_W), dtype=np.int16)) for nm in agent_names} +@dataclass +class Agent: + name: str + team: str + x: int + y: int + ori: int = 0 + hp: int = 5 + energy: int = 200 + inventory: Dict[str, int] = field(default_factory=dict) + mode: str = "auto" # auto | manual + brain: str = "heur" # heur | random -def update_belief_for_agent(state: WorldState, belief: np.ndarray, agent: Agent) -> int: - """Returns number of newly discovered tiles this update.""" - before_unknown = int(np.sum(belief == -1)) +@dataclass +class Objective: + title: str + detail: str - belief[agent.y, agent.x] = state.grid[agent.y][agent.x] - base = math.radians(ORI_DEG[agent.ori]) - half = math.radians(FOV_DEG / 2) - rays = 45 if agent.name.lower().startswith("scout") else 33 +@dataclass +class EnvSpec: + key: str + title: str + summary: str + max_steps: int - for i in range(rays): - t = i / (rays - 1) - ang = base + (t * 2 - 1) * half - sin_a, cos_a = math.sin(ang), math.cos(ang) - ox, oy = agent.x + 0.5, agent.y + 0.5 - depth = 0.0 - while depth < MAX_DEPTH: - depth += 0.2 - tx = int(ox + cos_a * depth) - ty = int(oy + sin_a * depth) - if not in_bounds(tx, ty): - break - belief[ty, tx] = state.grid[ty][tx] - tile = state.grid[ty][tx] - if tile == WALL: - break - if tile == DOOR and not state.door_opened_global: - break +@dataclass +class World: + seed: int + step: int + env_key: str + map_key: str - after_unknown = int(np.sum(belief == -1)) - return max(0, before_unknown - after_unknown) + grid: List[List[int]] + agents: Dict[str, Agent] -# ----------------------------- -# Rendering -# ----------------------------- -def raycast_view(state: WorldState, observer: Agent) -> np.ndarray: - img = np.zeros((VIEW_H, VIEW_W, 3), dtype=np.uint8) - img[:, :] = SKY + # gameplay flags / counters + done: bool = False + outcome: str = "ongoing" # A_win | B_win | draw | ongoing - for y in range(VIEW_H // 2, VIEW_H): - t = (y - VIEW_H // 2) / (VIEW_H // 2 + 1e-6) - col = (1 - t) * FLOOR_NEAR + t * FLOOR_FAR - img[y, :] = col.astype(np.uint8) + # pacman-style + power_timer: int = 0 + pellets_left: int = 0 - fov = math.radians(FOV_DEG) - half_fov = fov / 2 + # capture-the-flag + flag_carrier: Optional[str] = None + flag_taken_from: Optional[str] = None - for rx in range(RAY_W): - cam_x = (2 * rx / (RAY_W - 1)) - 1 - ray_ang = math.radians(ORI_DEG[observer.ori]) + cam_x * half_fov + # treasure run + treasure_collected_A: int = 0 + treasure_collected_B: int = 0 - ox, oy = observer.x + 0.5, observer.y + 0.5 - sin_a = math.sin(ray_ang) - cos_a = math.cos(ray_ang) + # resource raid + baseA_progress: int = 0 + baseB_progress: int = 0 + base_target: int = 10 - depth = 0.0 - hit = None # "wall" | "door" - side = 0 + # UX + controlled: str = "" + pov: str = "" + overlay: bool = True + auto_camera: bool = True - while depth < MAX_DEPTH: - depth += 0.05 - tx = int(ox + cos_a * depth) - ty = int(oy + sin_a * depth) - if not in_bounds(tx, ty): - break - tile = state.grid[ty][tx] - if tile == WALL: - hit = "wall" - side = 1 if abs(cos_a) > abs(sin_a) else 0 - break - if tile == DOOR and not state.door_opened_global: - hit = "door" - break - - if hit is None: - continue + # logs + events: List[str] = field(default_factory=list) - depth *= math.cos(ray_ang - math.radians(ORI_DEG[observer.ori])) - depth = max(depth, 0.001) +# ----------------------------- +# Maps / Courses +# ----------------------------- +def base_border_grid() -> List[List[int]]: + g = [[EMPTY for _ in range(GRID_W)] for _ in range(GRID_H)] + for x in range(GRID_W): + g[0][x] = WALL + g[GRID_H - 1][x] = WALL + for y in range(GRID_H): + g[y][0] = WALL + g[y][GRID_W - 1] = WALL + return g - proj_h = int((VIEW_H * 0.9) / depth) - y0 = max(0, VIEW_H // 2 - proj_h // 2) - y1 = min(VIEW_H - 1, VIEW_H // 2 + proj_h // 2) +def carve_maze(seed: int, density: float = 0.66) -> List[List[int]]: + """ + Procedural "course" generator: a DFS maze with a few open plazas. + We generate walls then carve corridors. This produces interesting navigation. + """ + r = rng(seed) + g = [[WALL for _ in range(GRID_W)] for _ in range(GRID_H)] - if hit == "door": - col = DOOR_COL.copy() - else: - col = WALL_BASE.copy() if side == 0 else WALL_SIDE.copy() + # keep borders solid + for y in range(GRID_H): + for x in range(GRID_W): + if x in (0, GRID_W - 1) or y in (0, GRID_H - 1): + g[y][x] = WALL + + # carve from odd cells + def neighbors(cx, cy): + dirs = [(2, 0), (-2, 0), (0, 2), (0, -2)] + r.shuffle(dirs) + for dx, dy in dirs: + nx, ny = cx + dx, cy + dy + if 1 <= nx < GRID_W - 1 and 1 <= ny < GRID_H - 1: + yield nx, ny, dx, dy + + start = (1 + 2 * (r.randint(0, (GRID_W - 3) // 2)), + 1 + 2 * (r.randint(0, (GRID_H - 3) // 2))) + stack = [start] + g[start[1]][start[0]] = EMPTY + + visited = set([start]) + while stack: + cx, cy = stack[-1] + moved = False + for nx, ny, dx, dy in neighbors(cx, cy): + if (nx, ny) in visited: + continue + visited.add((nx, ny)) + g[cy + dy // 2][cx + dx // 2] = EMPTY + g[ny][nx] = EMPTY + stack.append((nx, ny)) + moved = True + break + if not moved: + stack.pop() + + # open up some "plazas" based on density + plazas = int((1.0 - density) * 8) + 2 + for _ in range(plazas): + px = r.randint(3, GRID_W - 4) + py = r.randint(3, GRID_H - 4) + w = r.randint(2, 4) + h = r.randint(2, 3) + for yy in range(py - h, py + h + 1): + for xx in range(px - w, px + w + 1): + if 1 <= xx < GRID_W - 1 and 1 <= yy < GRID_H - 1: + g[yy][xx] = EMPTY - dim = max(0.25, 1.0 - (depth / MAX_DEPTH)) - col = (col * dim).astype(np.uint8) + return g - x0 = int(rx * (VIEW_W / RAY_W)) - x1 = int((rx + 1) * (VIEW_W / RAY_W)) - img[y0:y1, x0:x1] = col +def map_pac_chase(seed: int) -> List[List[int]]: + g = base_border_grid() + # iconic mid-wall with gates + for x in range(4, GRID_W - 4): + g[GRID_H // 2][x] = WALL + gate_x = GRID_W // 2 + g[GRID_H // 2][gate_x] = GATE + g[GRID_H // 2][gate_x - 1] = GATE + g[GRID_H // 2][gate_x + 1] = GATE + + # pellets everywhere open + for y in range(1, GRID_H - 1): + for x in range(1, GRID_W - 1): + if g[y][x] == EMPTY: + g[y][x] = PELLET + + # power pellets at corners + for (x, y) in [(2, 2), (GRID_W - 3, 2), (2, GRID_H - 3), (GRID_W - 3, GRID_H - 3)]: + g[y][x] = POWER + + # a few internal blocks + r = rng(seed) + for _ in range(26): + x = r.randint(2, GRID_W - 3) + y = r.randint(2, GRID_H - 3) + if g[y][x] in (PELLET, EMPTY): + g[y][x] = WALL - # billboards for visible agents - for nm, other in state.agents.items(): - if nm == observer.name or other.hp <= 0: - continue - if visible(state, observer, other): - dx = other.x - observer.x - dy = other.y - observer.y - ang = (math.degrees(math.atan2(dy, dx)) % 360) - facing = ORI_DEG[observer.ori] - diff = (ang - facing + 540) % 360 - 180 - sx = int((diff / (FOV_DEG / 2)) * (VIEW_W / 2) + (VIEW_W / 2)) - dist = math.sqrt(dx * dx + dy * dy) - h = int((VIEW_H * 0.65) / max(dist, 0.75)) - w = max(10, h // 3) - y_mid = VIEW_H // 2 - y0 = max(0, y_mid - h // 2) - y1 = min(VIEW_H - 1, y_mid + h // 2) - x0 = max(0, sx - w // 2) - x1 = min(VIEW_W - 1, sx + w // 2) - col = AGENT_COLORS.get(nm, (255, 200, 120)) - img[y0:y1, x0:x1] = np.array(col, dtype=np.uint8) - - if state.overlay: - cx, cy = VIEW_W // 2, VIEW_H // 2 - img[cy - 1:cy + 2, cx - 10:cx + 10] = np.array([120, 190, 255], dtype=np.uint8) - img[cy - 10:cy + 10, cx - 1:cx + 2] = np.array([120, 190, 255], dtype=np.uint8) + return g - return img +def map_ctf_arena(seed: int) -> List[List[int]]: + g = carve_maze(seed, density=0.60) + # clear some central arena + cx, cy = GRID_W // 2, GRID_H // 2 + for y in range(cy - 3, cy + 4): + for x in range(cx - 5, cx + 6): + if 1 <= x < GRID_W - 1 and 1 <= y < GRID_H - 1: + g[y][x] = EMPTY + + # flags + bases + g[2][2] = FLAG_A + g[GRID_H - 3][GRID_W - 3] = FLAG_B + g[2][GRID_W - 3] = BASE_A + g[GRID_H - 3][2] = BASE_B + + # hazards sprinkled + r = rng(seed + 11) + for _ in range(18): + x = r.randint(2, GRID_W - 3) + y = r.randint(2, GRID_H - 3) + if g[y][x] == EMPTY: + g[y][x] = HAZARD -def render_topdown(grid: np.ndarray, agents: Dict[str, Agent], title: str, show_agents: bool = True) -> Image.Image: - w = grid.shape[1] * TILE - h = grid.shape[0] * TILE - im = Image.new("RGB", (w, h + 28), (10, 12, 18)) - draw = ImageDraw.Draw(im) - - for y in range(grid.shape[0]): - for x in range(grid.shape[1]): - t = int(grid[y, x]) - if t == -1: - col = (18, 20, 32) - elif t == EMPTY: - col = (26, 30, 44) - elif t == WALL: - col = (190, 190, 210) - elif t == FOOD: - col = (255, 210, 120) - elif t == NOISE: - col = (255, 120, 220) - elif t == DOOR: - col = (140, 210, 255) - elif t == TELE: - col = (120, 190, 255) - elif t == KEY: - col = (255, 235, 160) - elif t == EXIT: - col = (120, 255, 220) - elif t == ARTIFACT: - col = (255, 170, 60) - elif t == HAZARD: - col = (255, 90, 90) - elif t == WOOD: - col = (170, 120, 60) - elif t == ORE: - col = (140, 140, 160) - elif t == MEDKIT: - col = (120, 255, 140) - elif t == SWITCH: - col = (200, 180, 255) - elif t == BASE: - col = (220, 220, 240) - else: - col = (80, 80, 90) + return g - x0, y0 = x * TILE, y * TILE + 28 - draw.rectangle([x0, y0, x0 + TILE - 1, y0 + TILE - 1], fill=col) +def map_treasure_run(seed: int) -> List[List[int]]: + g = carve_maze(seed, density=0.70) + # treasures + r = rng(seed + 7) + for _ in range(12): + x = r.randint(2, GRID_W - 3) + y = r.randint(2, GRID_H - 3) + if g[y][x] == EMPTY: + g[y][x] = TREASURE + # bases + g[2][2] = BASE_A + g[GRID_H - 3][GRID_W - 3] = BASE_B + return g - for x in range(grid.shape[1] + 1): - xx = x * TILE - draw.line([xx, 28, xx, h + 28], fill=(12, 14, 22)) - for y in range(grid.shape[0] + 1): - yy = y * TILE + 28 - draw.line([0, yy, w, yy], fill=(12, 14, 22)) +def map_resource_raid(seed: int) -> List[List[int]]: + g = carve_maze(seed, density=0.64) + # resource clusters + r = rng(seed + 23) + for _ in range(22): + x = r.randint(2, GRID_W - 3) + y = r.randint(2, GRID_H - 3) + if g[y][x] == EMPTY: + g[y][x] = RESOURCE + # bases + g[2][2] = BASE_A + g[GRID_H - 3][GRID_W - 3] = BASE_B + return g - if show_agents: - for nm, a in agents.items(): - if a.hp <= 0: - continue - cx = a.x * TILE + TILE // 2 - cy = a.y * TILE + 28 + TILE // 2 - col = AGENT_COLORS.get(nm, (220, 220, 220)) - r = TILE // 3 - draw.ellipse([cx - r, cy - r, cx + r, cy + r], fill=col) - dx, dy = DIRS[a.ori] - draw.line([cx, cy, cx + dx * r, cy + dy * r], fill=(10, 10, 10), width=3) - - draw.rectangle([0, 0, w, 28], fill=(14, 16, 26)) - draw.text((8, 6), title, fill=(230, 230, 240)) - return im +MAP_BUILDERS = { + "Classic Pac-Chase": map_pac_chase, + "CTF Maze Arena": map_ctf_arena, + "Treasure Labyrinth": map_treasure_run, + "Resource Raid Maze": map_resource_raid, + "Procedural Maze (General)": lambda seed: carve_maze(seed, density=0.62), +} # ----------------------------- -# Environments +# Environments (Gameplay Types) # ----------------------------- -def grid_with_border() -> List[List[int]]: - g = [[EMPTY for _ in range(GRID_W)] for _ in range(GRID_H)] - for x in range(GRID_W): - g[0][x] = WALL - g[GRID_H - 1][x] = WALL - for y in range(GRID_H): - g[y][0] = WALL - g[y][GRID_W - 1] = WALL - return g +ENVS: Dict[str, EnvSpec] = { + "pac_chase": EnvSpec( + key="pac_chase", + title="Predator/Prey (Pac-Chase)", + summary="Predator hunts Prey. Prey scores by eating pellets; power flips the chase temporarily.", + max_steps=650, + ), + "ctf": EnvSpec( + key="ctf", + title="Capture The Flag", + summary="Steal the opponent’s flag and return it to your base. Hazards drain HP.", + max_steps=800, + ), + "treasure": EnvSpec( + key="treasure", + title="Treasure Run", + summary="Collect treasures scattered in the maze and deposit at base. First to 6 deposits wins.", + max_steps=750, + ), + "resource": EnvSpec( + key="resource", + title="Resource Raid", + summary="Mine resources, deposit to build base progress. Raider tries to disrupt and tag.", + max_steps=850, + ), +} -def env_chase(seed: int) -> Tuple[List[List[int]], Dict[str, Agent]]: - g = grid_with_border() - for x in range(4, 17): - g[7][x] = WALL - g[7][10] = DOOR - - g[3][4] = FOOD - g[11][15] = FOOD - g[4][14] = NOISE - g[12][5] = NOISE - g[2][18] = TELE - g[13][2] = TELE - - agents = { - "Predator": Agent("Predator", 2, 2, 0, hp=10, energy=100, team="A", brain="q"), - "Prey": Agent("Prey", 18, 12, 2, hp=10, energy=100, team="B", brain="q"), - "Scout": Agent("Scout", 10, 3, 1, hp=10, energy=100, team="A", brain="heuristic"), - } - return g, agents - -def env_vault(seed: int) -> Tuple[List[List[int]], Dict[str, Agent]]: - g = grid_with_border() - for x in range(3, 18): - g[5][x] = WALL - for x in range(3, 18): - g[9][x] = WALL - g[5][10] = DOOR - g[9][12] = DOOR - - g[2][2] = KEY - g[12][18] = EXIT - g[12][2] = ARTIFACT - g[2][18] = TELE - g[13][2] = TELE - g[7][10] = SWITCH - g[3][15] = HAZARD - g[11][6] = MEDKIT - g[2][12] = FOOD - - agents = { - "Alpha": Agent("Alpha", 2, 12, 0, hp=10, energy=100, team="A", brain="q"), - "Bravo": Agent("Bravo", 3, 12, 0, hp=10, energy=100, team="A", brain="q"), - "Guardian": Agent("Guardian", 18, 2, 2, hp=10, energy=100, team="B", brain="q"), - } - return g, agents - -def env_civ(seed: int) -> Tuple[List[List[int]], Dict[str, Agent]]: - g = grid_with_border() - for y in range(3, 12): - g[y][9] = WALL - g[7][9] = DOOR - - g[2][3] = WOOD - g[3][3] = WOOD - g[4][3] = WOOD - g[12][16] = ORE - g[11][16] = ORE - g[10][16] = ORE - g[6][4] = FOOD - g[8][15] = FOOD - - g[13][10] = BASE - g[4][15] = HAZARD - g[10][4] = HAZARD - g[2][18] = TELE - g[13][2] = TELE - g[2][2] = KEY - g[12][6] = SWITCH - - agents = { - "BuilderA": Agent("BuilderA", 3, 12, 0, hp=10, energy=100, team="A", brain="q"), - "BuilderB": Agent("BuilderB", 4, 12, 0, hp=10, energy=100, team="A", brain="q"), - "Raider": Agent("Raider", 18, 2, 2, hp=10, energy=100, team="B", brain="q"), - } - return g, agents - -ENV_BUILDERS = {"chase": env_chase, "vault": env_vault, "civ": env_civ} +def env_objectives(env_key: str) -> List[Objective]: + if env_key == "pac_chase": + return [ + Objective("Prey", "Eat pellets (+) and survive. Power pellet makes Predator vulnerable temporarily."), + Objective("Predator", "Catch the Prey (tag on same tile). Avoid chasing into power windows."), + ] + if env_key == "ctf": + return [ + Objective("Team A", "Grab Flag B and return to Base A."), + Objective("Team B", "Grab Flag A and return to Base B."), + ] + if env_key == "treasure": + return [ + Objective("Both Teams", "Collect Treasures and deposit at your Base. First to 6 deposits wins."), + ] + if env_key == "resource": + return [ + Objective("Builders (A & B)", "Collect Resources and deposit to raise base progress."), + Objective("Raider", "Tag builders (collision) to slow progress; win by eliminating both or forcing timeout."), + ] + return [Objective("Objective", "Explore.")] # ----------------------------- -# Observation / Q-learning +# Spawn / Init # ----------------------------- -def local_tile_ahead(state: WorldState, a: Agent) -> int: - dx, dy = DIRS[a.ori] - nx, ny = a.x + dx, a.y + dy - if not in_bounds(nx, ny): - return WALL - return state.grid[ny][nx] - -def nearest_enemy_vec(state: WorldState, a: Agent) -> Tuple[int, int, int]: - best = None - for _, other in state.agents.items(): - if other.hp <= 0: - continue - if other.team == a.team: - continue - d = manhattan_xy(a.x, a.y, other.x, other.y) - if best is None or d < best[0]: - best = (d, other.x - a.x, other.y - a.y) - if best is None: - return (99, 0, 0) - d, dx, dy = best - return (d, int(np.clip(dx, -6, 6)), int(np.clip(dy, -6, 6))) - -def obs_key(state: WorldState, who: str) -> str: - a = state.agents[who] - d, dx, dy = nearest_enemy_vec(state, a) - ahead = local_tile_ahead(state, a) - keys = a.inventory.get("key", 0) - art = a.inventory.get("artifact", 0) - wood = a.inventory.get("wood", 0) - ore = a.inventory.get("ore", 0) - inv_bucket = f"k{min(keys,2)}a{min(art,1)}w{min(wood,3)}o{min(ore,3)}" - door = 1 if state.door_opened_global else 0 - return f"{state.env_key}|{who}|{a.x},{a.y},{a.ori}|e{d}:{dx},{dy}|t{ahead}|hp{a.hp}|{inv_bucket}|D{door}|bp{state.base_progress}" - -def q_get(q: Dict[str, List[float]], key: str) -> List[float]: - if key not in q: - q[key] = [0.0 for _ in ACTIONS] - return q[key] - -def epsilon_greedy(qvals: List[float], eps: float, r: np.random.Generator) -> int: - if r.random() < eps: - return int(r.integers(0, len(qvals))) - return int(np.argmax(qvals)) - -def q_update(q: Dict[str, List[float]], key: str, a_idx: int, reward: float, next_key: str, - alpha: float, gamma: float) -> Tuple[float, float, float]: - qv = q_get(q, key) - nq = q_get(q, next_key) - old = qv[a_idx] - target = reward + gamma * float(np.max(nq)) - new = old + alpha * (target - old) - qv[a_idx] = new - return old, target, new +def random_empty_cell(g: List[List[int]], r: random.Random) -> Tuple[int, int]: + empties = [(x, y) for y in range(1, GRID_H - 1) for x in range(1, GRID_W - 1) if g[y][x] in (EMPTY, PELLET)] + return r.choice(empties) if empties else (2, 2) + +def init_world(seed: int, env_key: str, map_key: str) -> World: + r = rng(seed) + g = MAP_BUILDERS[map_key](seed) + spec = ENVS[env_key] + + agents: Dict[str, Agent] = {} + + if env_key == "pac_chase": + # Predator + Prey + 2 ghosts (as roaming threats / decoys) + px, py = 2, 2 + qx, qy = GRID_W - 3, GRID_H - 3 + agents["Predator"] = Agent("Predator", "A", px, py, ori=0, hp=6, mode="auto", brain="heur") + agents["Prey"] = Agent("Prey", "B", qx, qy, ori=2, hp=5, mode="auto", brain="heur") + gx1, gy1 = (GRID_W // 2, 2) + gx2, gy2 = (GRID_W // 2, GRID_H - 3) + agents["Ghost1"] = Agent("Ghost1", "A", gx1, gy1, ori=1, hp=4, mode="auto", brain="random") + agents["Ghost2"] = Agent("Ghost2", "A", gx2, gy2, ori=3, hp=4, mode="auto", brain="random") + + pellets = sum(1 for y in range(GRID_H) for x in range(GRID_W) if g[y][x] in (PELLET, POWER)) + controlled = "Prey" + pov = "Prey" + + elif env_key == "ctf": + # 2 runners + 2 guards + ax, ay = 2, GRID_H - 3 + bx, by = GRID_W - 3, 2 + agents["RunnerA"] = Agent("RunnerA", "A", ax, ay, ori=0, hp=6, mode="auto", brain="heur") + agents["GuardA"] = Agent("GuardA", "A", 2, 2, ori=0, hp=7, mode="auto", brain="heur") + agents["RunnerB"] = Agent("RunnerB", "B", bx, by, ori=2, hp=6, mode="auto", brain="heur") + agents["GuardB"] = Agent("GuardB", "B", GRID_W - 3, GRID_H - 3, ori=2, hp=7, mode="auto", brain="heur") + pellets = 0 + controlled = "RunnerA" + pov = "RunnerA" + + elif env_key == "treasure": + agents["RunnerA"] = Agent("RunnerA", "A", 2, 2, ori=0, hp=6, mode="auto", brain="heur") + agents["RunnerB"] = Agent("RunnerB", "B", GRID_W - 3, GRID_H - 3, ori=2, hp=6, mode="auto", brain="heur") + agents["GuardA"] = Agent("GuardA", "A", 2, GRID_H - 3, ori=0, hp=6, mode="auto", brain="heur") + agents["GuardB"] = Agent("GuardB", "B", GRID_W - 3, 2, ori=2, hp=6, mode="auto", brain="heur") + pellets = 0 + controlled = "RunnerA" + pov = "RunnerA" + + else: # resource + agents["MinerA"] = Agent("MinerA", "A", 2, 2, ori=0, hp=6, mode="auto", brain="heur") + agents["MinerB"] = Agent("MinerB", "B", GRID_W - 3, GRID_H - 3, ori=2, hp=6, mode="auto", brain="heur") + agents["Raider"] = Agent("Raider", "R", GRID_W - 3, 2, ori=2, hp=7, mode="auto", brain="heur") + pellets = 0 + controlled = "MinerA" + pov = "MinerA" + + w = World( + seed=seed, + step=0, + env_key=env_key, + map_key=map_key, + grid=g, + agents=agents, + pellets_left=pellets, + controlled=controlled, + pov=pov, + overlay=True, + auto_camera=True, + events=[f"Initialized: env={env_key} ({spec.title}) | map={map_key} | seed={seed}"], + ) + return w # ----------------------------- -# Baseline heuristics +# Pathing + Movement # ----------------------------- -def face_towards(a: Agent, tx: int, ty: int) -> str: +def is_blocking(tile: int) -> bool: + return tile == WALL + +def neighbors4(x: int, y: int) -> List[Tuple[int, int]]: + return [(x + 1, y), (x, y + 1), (x - 1, y), (x, y - 1)] + +def bfs_next_step(grid: List[List[int]], start: Tuple[int, int], goal: Tuple[int, int]) -> Optional[Tuple[int, int]]: + if start == goal: + return None + sx, sy = start + gx, gy = goal + q = [(sx, sy)] + prev = {start: None} + while q: + x, y = q.pop(0) + if (x, y) == (gx, gy): + break + for nx, ny in neighbors4(x, y): + if not in_bounds(nx, ny): + continue + if is_blocking(grid[ny][nx]): + continue + if (nx, ny) not in prev: + prev[(nx, ny)] = (x, y) + q.append((nx, ny)) + if (gx, gy) not in prev: + return None + # backtrack one step from goal to start + cur = (gx, gy) + while prev[cur] != start and prev[cur] is not None: + cur = prev[cur] + return cur + +def face_towards(a: Agent, tx: int, ty: int): dx = tx - a.x dy = ty - a.y - ang = (math.degrees(math.atan2(dy, dx)) % 360) - facing = ORI_DEG[a.ori] - diff = (ang - facing + 540) % 360 - 180 - if diff < -10: - return "L" - if diff > 10: - return "R" - return "F" - -def heuristic_action(state: WorldState, who: str) -> str: - a = state.agents[who] - r = rng_for(state.seed, state.step, stream=900 + hash(who) % 1000) - - # Prioritize interacting on valuable tiles - t_here = state.grid[a.y][a.x] - if t_here in (FOOD, KEY, ARTIFACT, WOOD, ORE, MEDKIT, SWITCH, BASE, EXIT): - return "I" - - # Find nearest enemy - best = None - best_d = 999 - for _, other in state.agents.items(): - if other.hp <= 0 or other.team == a.team: - continue - d = manhattan_xy(a.x, a.y, other.x, other.y) - if d < best_d: - best_d = d - best = other - - if best is not None and best_d <= 6 and visible(state, a, best): - # attackers chase, defenders try to flee - if a.team == "B": - return face_towards(a, best.x, best.y) - - dx = best.x - a.x - dy = best.y - a.y - ang = (math.degrees(math.atan2(dy, dx)) % 360) - facing = ORI_DEG[a.ori] - diff = (ang - facing + 540) % 360 - 180 - diff_away = ((diff + 180) + 540) % 360 - 180 - if diff_away < -10: - return "L" - if diff_away > 10: - return "R" - return "F" - - return r.choice(["F", "F", "L", "R", "I"]) - -def random_action(state: WorldState, who: str) -> str: - r = rng_for(state.seed, state.step, stream=700 + hash(who) % 1000) - return r.choice(ACTIONS) - -# ----------------------------- -# Movement + interaction -# ----------------------------- -def turn_left(a: Agent) -> None: - a.ori = (a.ori - 1) % 4 - -def turn_right(a: Agent) -> None: - a.ori = (a.ori + 1) % 4 + if abs(dx) > abs(dy): + a.ori = 0 if dx > 0 else 2 + else: + a.ori = 1 if dy > 0 else 3 -def move_forward(state: WorldState, a: Agent) -> str: - dx, dy = DIRS[a.ori] - nx, ny = a.x + dx, a.y + dy +def move_to(world: World, a: Agent, nx: int, ny: int) -> bool: if not in_bounds(nx, ny): - return "blocked: bounds" - tile = state.grid[ny][nx] - if is_blocking(tile, door_open=state.door_opened_global): - return "blocked: wall/door" + return False + if is_blocking(world.grid[ny][nx]): + return False a.x, a.y = nx, ny - - if state.grid[ny][nx] == TELE: - teles = [(x, y) for y in range(GRID_H) for x in range(GRID_W) if state.grid[y][x] == TELE] - if len(teles) >= 2: - teles_sorted = sorted(teles) - idx = teles_sorted.index((nx, ny)) - dest = teles_sorted[(idx + 1) % len(teles_sorted)] - a.x, a.y = dest - state.event_log.append(f"t={state.step}: {a.name} teleported.") - return "moved: teleported" - return "moved" - -def try_interact(state: WorldState, a: Agent) -> str: - t = state.grid[a.y][a.x] - - if t == SWITCH: - state.door_opened_global = True - state.grid[a.y][a.x] = EMPTY - a.inventory["switch"] = a.inventory.get("switch", 0) + 1 - return "switch: opened all doors" - - if t == KEY: - a.inventory["key"] = a.inventory.get("key", 0) + 1 - state.grid[a.y][a.x] = EMPTY - return "picked: key" - - if t == ARTIFACT: - a.inventory["artifact"] = a.inventory.get("artifact", 0) + 1 - state.grid[a.y][a.x] = EMPTY - return "picked: artifact" - - if t == FOOD: - a.energy = min(200, a.energy + 35) - state.grid[a.y][a.x] = EMPTY - return "ate: food" - - if t == WOOD: - a.inventory["wood"] = a.inventory.get("wood", 0) + 1 - state.grid[a.y][a.x] = EMPTY - return "picked: wood" - - if t == ORE: - a.inventory["ore"] = a.inventory.get("ore", 0) + 1 - state.grid[a.y][a.x] = EMPTY - return "picked: ore" - - if t == MEDKIT: - a.hp = min(10, a.hp + 3) - state.grid[a.y][a.x] = EMPTY - return "used: medkit" - - if t == BASE: - w = a.inventory.get("wood", 0) - o = a.inventory.get("ore", 0) - dep = min(w, 2) + min(o, 2) - if dep > 0: - a.inventory["wood"] = max(0, w - min(w, 2)) - a.inventory["ore"] = max(0, o - min(o, 2)) - state.base_progress += dep - return f"deposited: +{dep} base_progress" - return "base: nothing to deposit" - - if t == EXIT: - return "at_exit" - - return "interact: none" - -def apply_action(state: WorldState, who: str, action: str) -> str: - a = state.agents[who] - if a.hp <= 0: - return "dead" - if action == "L": - turn_left(a) - return "turned left" - if action == "R": - turn_right(a) - return "turned right" - if action == "F": - return move_forward(state, a) - if action == "I": - return try_interact(state, a) - return "noop" + a.energy = max(0, a.energy - 1) + return True # ----------------------------- -# Hazards / collisions / done +# Core Interactions # ----------------------------- -def resolve_hazards(state: WorldState, a: Agent) -> Tuple[bool, str]: - if a.hp <= 0: - return (False, "") - if state.grid[a.y][a.x] == HAZARD: +def apply_tile_effects(world: World, a: Agent): + t = world.grid[a.y][a.x] + + # hazards drain HP + if t == HAZARD: a.hp -= 1 - return (True, "hazard:-hp") - return (False, "") + world.events.append(f"t={world.step}: {a.name} hit a hazard (-hp).") + + if world.env_key == "pac_chase": + if t == PELLET: + world.grid[a.y][a.x] = EMPTY + world.pellets_left = max(0, world.pellets_left - 1) + a.inventory["pellets"] = a.inventory.get("pellets", 0) + 1 + elif t == POWER: + world.grid[a.y][a.x] = EMPTY + world.pellets_left = max(0, world.pellets_left - 1) + world.power_timer = 26 + world.events.append(f"t={world.step}: POWER ACTIVE — chase flips for a bit.") + + if world.env_key == "ctf": + if t == FLAG_A and a.team == "B" and world.flag_carrier is None: + world.flag_carrier = a.name + world.flag_taken_from = "A" + world.grid[a.y][a.x] = EMPTY + world.events.append(f"t={world.step}: {a.name} stole Flag A!") + if t == FLAG_B and a.team == "A" and world.flag_carrier is None: + world.flag_carrier = a.name + world.flag_taken_from = "B" + world.grid[a.y][a.x] = EMPTY + world.events.append(f"t={world.step}: {a.name} stole Flag B!") + + # return conditions + if world.flag_carrier == a.name: + if a.team == "A" and world.grid[a.y][a.x] == BASE_A and world.flag_taken_from == "B": + world.done = True + world.outcome = "A_win" + world.events.append(f"t={world.step}: Team A captured the flag!") + if a.team == "B" and world.grid[a.y][a.x] == BASE_B and world.flag_taken_from == "A": + world.done = True + world.outcome = "B_win" + world.events.append(f"t={world.step}: Team B captured the flag!") + + if world.env_key == "treasure": + if t == TREASURE: + world.grid[a.y][a.x] = EMPTY + a.inventory["treasure"] = a.inventory.get("treasure", 0) + 1 + world.events.append(f"t={world.step}: {a.name} picked treasure.") + if t == BASE_A and a.team == "A": + dep = a.inventory.get("treasure", 0) + if dep > 0: + a.inventory["treasure"] = 0 + world.treasure_collected_A += dep + world.events.append(f"t={world.step}: Team A deposited {dep} treasure (total={world.treasure_collected_A}).") + if t == BASE_B and a.team == "B": + dep = a.inventory.get("treasure", 0) + if dep > 0: + a.inventory["treasure"] = 0 + world.treasure_collected_B += dep + world.events.append(f"t={world.step}: Team B deposited {dep} treasure (total={world.treasure_collected_B}).") + + if world.env_key == "resource": + if t == RESOURCE: + world.grid[a.y][a.x] = EMPTY + a.inventory["res"] = a.inventory.get("res", 0) + 1 + world.events.append(f"t={world.step}: {a.name} mined resource.") + if t == BASE_A and a.name == "MinerA": + dep = min(2, a.inventory.get("res", 0)) + if dep > 0: + a.inventory["res"] -= dep + world.baseA_progress += dep + world.events.append(f"t={world.step}: MinerA deposited +{dep} (A={world.baseA_progress}/{world.base_target}).") + if t == BASE_B and a.name == "MinerB": + dep = min(2, a.inventory.get("res", 0)) + if dep > 0: + a.inventory["res"] -= dep + world.baseB_progress += dep + world.events.append(f"t={world.step}: MinerB deposited +{dep} (B={world.baseB_progress}/{world.base_target}).") -def resolve_tags(state: WorldState) -> List[str]: - msgs = [] - occupied: Dict[Tuple[int, int], List[str]] = {} - for nm, a in state.agents.items(): +# ----------------------------- +# Collision / Tagging +# ----------------------------- +def resolve_tags(world: World): + # If opposing agents occupy same tile: tag event + pos: Dict[Tuple[int, int], List[str]] = {} + for nm, a in world.agents.items(): if a.hp <= 0: continue - occupied.setdefault((a.x, a.y), []).append(nm) + pos.setdefault((a.x, a.y), []).append(nm) - for (x, y), names in occupied.items(): + for (x, y), names in pos.items(): if len(names) < 2: continue - teams = set(state.agents[n].team for n in names) - if len(teams) >= 2: - for n in names: - state.agents[n].hp -= 1 - msgs.append(f"t={state.step}: collision/tag at ({x},{y}) {names} (-hp all)") - return msgs - -def check_done(state: WorldState) -> None: - if state.env_key == "chase": - pred = state.agents["Predator"] - prey = state.agents["Prey"] - if pred.hp <= 0 and prey.hp <= 0: - state.done = True - state.outcome = "draw" - return - if pred.hp > 0 and prey.hp > 0 and pred.x == prey.x and pred.y == prey.y: - state.done = True - state.outcome = "A_win" - state.event_log.append(f"t={state.step}: CAUGHT (Predator wins).") - return - if state.step >= 300 and prey.hp > 0: - state.done = True - state.outcome = "B_win" - state.event_log.append(f"t={state.step}: ESCAPED (Prey survives).") - return + teams = set(world.agents[n].team for n in names) + if len(teams) <= 1: + continue - if state.env_key == "vault": - for nm in ["Alpha", "Bravo"]: - a = state.agents[nm] - if a.hp > 0 and a.inventory.get("artifact", 0) > 0 and state.grid[a.y][a.x] == EXIT: - state.done = True - state.outcome = "A_win" - state.event_log.append(f"t={state.step}: VAULT CLEARED (Team A wins).") + # pac_chase special: power flips who is vulnerable + if world.env_key == "pac_chase": + if "Predator" in names and "Prey" in names: + if world.power_timer > 0: + # Predator vulnerable + world.agents["Predator"].hp -= 2 + world.events.append(f"t={world.step}: Prey TAGGED Predator during POWER (-2hp Predator).") + else: + world.done = True + world.outcome = "A_win" + world.events.append(f"t={world.step}: Predator CAUGHT Prey.") return - alive_A = any(state.agents[n].hp > 0 for n in ["Alpha", "Bravo"]) - if not alive_A: - state.done = True - state.outcome = "B_win" - state.event_log.append(f"t={state.step}: TEAM A ELIMINATED (Guardian wins).") - return - if state.env_key == "civ": - if state.base_progress >= state.base_target: - state.done = True - state.outcome = "A_win" - state.event_log.append(f"t={state.step}: BASE COMPLETE (Builders win).") - return - alive_A = any(state.agents[n].hp > 0 for n in ["BuilderA", "BuilderB"]) - if not alive_A: - state.done = True - state.outcome = "B_win" - state.event_log.append(f"t={state.step}: BUILDERS ELIMINATED (Raider wins).") - return - if state.step >= 350: - state.done = True - state.outcome = "draw" - state.event_log.append(f"t={state.step}: TIMEOUT (draw).") - return + # otherwise, both lose hp + for n in names: + world.agents[n].hp -= 1 + world.events.append(f"t={world.step}: TAG at ({x},{y}) {names} (-hp).") + + # CTF: drop flag if carrier tagged + if world.env_key == "ctf" and world.flag_carrier in names: + carrier = world.flag_carrier + world.flag_carrier = None + # respawn flag to original side + if world.flag_taken_from == "A": + world.grid[2][2] = FLAG_A + elif world.flag_taken_from == "B": + world.grid[GRID_H - 3][GRID_W - 3] = FLAG_B + world.events.append(f"t={world.step}: {carrier} dropped the flag!") # ----------------------------- -# Rewards +# Done Conditions # ----------------------------- -def reward_for(prev: WorldState, now: WorldState, who: str, outcome_msg: str, took_damage: bool) -> float: - cfg = now.cfg - r = cfg.step_penalty - if outcome_msg.startswith("moved"): - r += cfg.explore_reward - if took_damage: - r += cfg.damage_penalty - if outcome_msg.startswith("used: medkit"): - r += cfg.heal_reward - - if now.env_key == "chase": - pred = now.agents["Predator"] - prey = now.agents["Prey"] - if who == "Predator": - d0 = manhattan_xy(prev.agents["Predator"].x, prev.agents["Predator"].y, - prev.agents["Prey"].x, prev.agents["Prey"].y) - d1 = manhattan_xy(pred.x, pred.y, prey.x, prey.y) - r += cfg.chase_close_coeff * float(d0 - d1) - if now.done and now.outcome == "A_win": - r += cfg.chase_catch_reward - if who == "Prey": - if outcome_msg.startswith("ate: food"): - r += cfg.food_reward - if now.done and now.outcome == "B_win": - r += cfg.chase_escaped_reward - if now.done and now.outcome == "A_win": - r += cfg.chase_caught_penalty - - if now.env_key == "vault": - if outcome_msg.startswith("picked: artifact"): - r += cfg.artifact_pick_reward - if outcome_msg.startswith("picked: key"): - r += cfg.key_reward - if outcome_msg.startswith("switch:"): - r += cfg.switch_reward - if now.done: - if now.outcome == "A_win" and now.agents[who].team == "A": - r += cfg.exit_win_reward - if now.outcome == "B_win" and now.agents[who].team == "B": - r += cfg.guardian_tag_reward - if now.outcome == "B_win" and now.agents[who].team == "A": - r += cfg.tagged_penalty - - if now.env_key == "civ": - if outcome_msg.startswith("picked: wood") or outcome_msg.startswith("picked: ore"): - r += cfg.resource_pick_reward - if outcome_msg.startswith("deposited:"): - r += cfg.deposit_reward - if now.done: - if now.outcome == "A_win" and now.agents[who].team == "A": - r += cfg.base_progress_win_reward - if now.outcome == "B_win" and now.agents[who].team == "B": - r += cfg.raider_elim_reward - if now.outcome == "B_win" and now.agents[who].team == "A": - r += cfg.builder_elim_penalty - - return float(r) +def check_done(world: World): + spec = ENVS[world.env_key] + if world.done: + return -# ----------------------------- -# Policy selection -# ----------------------------- -def choose_action(state: WorldState, who: str, stream: int) -> Tuple[str, str, Optional[Tuple[str, int]]]: - a = state.agents[who] - cfg = state.cfg - r = rng_for(state.seed, state.step, stream=stream) + # timeout / survival + if world.step >= spec.max_steps: + world.done = True + world.outcome = "draw" + world.events.append(f"t={world.step}: TIMEOUT (draw).") + return - if a.brain == "random": - act = random_action(state, who) - return act, "random", None - if a.brain == "heuristic": - act = heuristic_action(state, who) - return act, "heuristic", None - - if cfg.use_q: - key = obs_key(state, who) - qtab = state.q_tables.setdefault(who, {}) - qv = q_get(qtab, key) - a_idx = epsilon_greedy(qv, state.gmetrics.epsilon, r) - return ACTIONS[a_idx], f"Q eps={state.gmetrics.epsilon:.3f} q={np.round(qv,3).tolist()}", (key, a_idx) - - act = heuristic_action(state, who) - return act, "heuristic(fallback)", None + if world.env_key == "pac_chase": + prey = world.agents["Prey"] + pred = world.agents["Predator"] + if prey.hp <= 0: + world.done = True + world.outcome = "A_win" + world.events.append(f"t={world.step}: Prey eliminated — Predator wins.") + return + if pred.hp <= 0: + world.done = True + world.outcome = "B_win" + world.events.append(f"t={world.step}: Predator eliminated — Prey wins.") + return + if world.pellets_left <= 0: + world.done = True + world.outcome = "B_win" + world.events.append(f"t={world.step}: All pellets cleared — Prey wins.") + return -# ----------------------------- -# Init / reset -# ----------------------------- -def init_state(seed: int, env_key: str) -> WorldState: - g, agents = ENV_BUILDERS[env_key](seed) - st = WorldState( - seed=seed, - step=0, - env_key=env_key, - grid=g, - agents=agents, - controlled=list(agents.keys())[0], - pov=list(agents.keys())[0], - overlay=False, - done=False, - outcome="ongoing", - door_opened_global=False, - base_progress=0, - base_target=10, - ) - st.event_log = [f"Initialized env={env_key} seed={seed}."] - return st - -def reset_episode_keep_learning(state: WorldState, seed: Optional[int] = None) -> WorldState: - if seed is None: - seed = state.seed - fresh = init_state(int(seed), state.env_key) - fresh.cfg = state.cfg - fresh.q_tables = state.q_tables - fresh.gmetrics = state.gmetrics - fresh.gmetrics.epsilon = state.gmetrics.epsilon - return fresh - -def wipe_all(seed: int, env_key: str) -> WorldState: - st = init_state(seed, env_key) - st.cfg = TrainConfig() - st.gmetrics = GlobalMetrics(epsilon=st.cfg.epsilon) - st.q_tables = {} - return st + if world.env_key == "ctf": + # done handled on return + # elimination condition + aliveA = any(a.hp > 0 for a in world.agents.values() if a.team == "A") + aliveB = any(a.hp > 0 for a in world.agents.values() if a.team == "B") + if not aliveA and aliveB: + world.done = True + world.outcome = "B_win" + world.events.append(f"t={world.step}: Team A eliminated — Team B wins.") + elif not aliveB and aliveA: + world.done = True + world.outcome = "A_win" + world.events.append(f"t={world.step}: Team B eliminated — Team A wins.") + + if world.env_key == "treasure": + if world.treasure_collected_A >= 6: + world.done = True + world.outcome = "A_win" + world.events.append(f"t={world.step}: Team A reached 6 treasure — wins.") + elif world.treasure_collected_B >= 6: + world.done = True + world.outcome = "B_win" + world.events.append(f"t={world.step}: Team B reached 6 treasure — wins.") + + if world.env_key == "resource": + if world.baseA_progress >= world.base_target: + world.done = True + world.outcome = "A_win" + world.events.append(f"t={world.step}: Base A complete — MinerA wins.") + elif world.baseB_progress >= world.base_target: + world.done = True + world.outcome = "B_win" + world.events.append(f"t={world.step}: Base B complete — MinerB wins.") + # Raider wins by eliminating both miners + alive_miners = sum(1 for nm in ("MinerA", "MinerB") if world.agents.get(nm) and world.agents[nm].hp > 0) + if alive_miners == 0 and world.agents["Raider"].hp > 0: + world.done = True + world.outcome = "B_win" + world.events.append(f"t={world.step}: Miners eliminated — Raider wins.") # ----------------------------- -# History / branching +# Agent "Brains" (Heuristic + Random) # ----------------------------- -TRACE_MAX = 500 -MAX_HISTORY = 1400 - -def snapshot_of(state: WorldState, branch: str) -> Snapshot: - return Snapshot( - branch=branch, - step=state.step, - env_key=state.env_key, - grid=[row[:] for row in state.grid], - agents={k: asdict(v) for k, v in state.agents.items()}, - done=state.done, - outcome=state.outcome, - door_opened_global=state.door_opened_global, - base_progress=state.base_progress, - base_target=state.base_target, - event_tail=state.event_log[-25:], - trace_tail=state.trace_log[-40:], - emetrics=asdict(state.emetrics), - ) - -def restore_into(state: WorldState, snap: Snapshot) -> WorldState: - state.step = snap.step - state.env_key = snap.env_key - state.grid = [row[:] for row in snap.grid] - state.agents = {k: Agent(**d) for k, d in snap.agents.items()} - state.done = snap.done - state.outcome = snap.outcome - state.door_opened_global = snap.door_opened_global - state.base_progress = snap.base_progress - state.base_target = snap.base_target - state.event_log.append(f"Jumped to snapshot t={snap.step} (branch={snap.branch}).") - return state +def choose_target_pac(world: World, who: str) -> Tuple[int, int]: + a = world.agents[who] + prey = world.agents["Prey"] + pred = world.agents["Predator"] + + if who == "Prey": + # survival logic: if power active, prey can bully predator a bit, otherwise flee + if world.power_timer > 0: + return (pred.x, pred.y) # go toward predator (aggressive window) + # otherwise: go toward nearest pellet/power but avoid predator + pellets = find_all(world.grid, PELLET) + find_all(world.grid, POWER) + if pellets: + pellets.sort(key=lambda p: manhattan((a.x, a.y), p)) + return pellets[0] + return (a.x, a.y) + + if who == "Predator": + # if power active, avoid prey (pred vulnerable) + if world.power_timer > 0: + # run away from prey by targeting a far corner + corners = [(2, 2), (GRID_W - 3, 2), (2, GRID_H - 3), (GRID_W - 3, GRID_H - 3)] + corners.sort(key=lambda c: -manhattan((prey.x, prey.y), c)) + return corners[0] + return (prey.x, prey.y) + + # ghosts roam toward prey loosely + return (prey.x, prey.y) + +def choose_target_ctf(world: World, who: str) -> Tuple[int, int]: + a = world.agents[who] + # runners prioritize stealing flags; guards prioritize intercepting carrier / defending + if a.team == "A": + home_base = BASE_A + enemy_flag = FLAG_B + home_base_pos = find_all(world.grid, BASE_A)[0] + else: + home_base = BASE_B + enemy_flag = FLAG_A + home_base_pos = find_all(world.grid, BASE_B)[0] + + # if carrying flag, run home + if world.flag_carrier == who: + return home_base_pos + + # if teammate carrying flag, guard/intercept threats + if world.flag_carrier is not None: + carrier = world.agents[world.flag_carrier] + return (carrier.x, carrier.y) + + # otherwise: runners go to enemy flag; guards hover mid or defend base + if "Runner" in who: + flags = find_all(world.grid, enemy_flag) + if flags: + return flags[0] + return home_base_pos + + # guard: midpoint between base and enemy + enemy_flag_pos = find_all(world.grid, enemy_flag) + if enemy_flag_pos: + ex, ey = enemy_flag_pos[0] + bx, by = home_base_pos + return ((ex + bx) // 2, (ey + by) // 2) + return home_base_pos + +def choose_target_treasure(world: World, who: str) -> Tuple[int, int]: + a = world.agents[who] + base = BASE_A if a.team == "A" else BASE_B + base_pos = find_all(world.grid, base)[0] + + # deposit if holding + if a.inventory.get("treasure", 0) >= 2: + return base_pos + + treasures = find_all(world.grid, TREASURE) + if treasures: + treasures.sort(key=lambda p: manhattan((a.x, a.y), p)) + return treasures[0] + return base_pos + +def choose_target_resource(world: World, who: str) -> Tuple[int, int]: + a = world.agents[who] + if who == "Raider": + # hunt nearest miner + miners = [world.agents[n] for n in ("MinerA", "MinerB") if world.agents.get(n) and world.agents[n].hp > 0] + if miners: + miners.sort(key=lambda m: manhattan((a.x, a.y), (m.x, m.y))) + return (miners[0].x, miners[0].y) + return (a.x, a.y) + + # miners: deposit if holding enough + base_tile = BASE_A if who == "MinerA" else BASE_B + base_pos = find_all(world.grid, base_tile)[0] + if a.inventory.get("res", 0) >= 3: + return base_pos + + res = find_all(world.grid, RESOURCE) + if res: + res.sort(key=lambda p: manhattan((a.x, a.y), p)) + return res[0] + return base_pos + +def choose_target(world: World, who: str) -> Tuple[int, int]: + if world.env_key == "pac_chase": + return choose_target_pac(world, who) + if world.env_key == "ctf": + return choose_target_ctf(world, who) + if world.env_key == "treasure": + return choose_target_treasure(world, who) + if world.env_key == "resource": + return choose_target_resource(world, who) + return (world.agents[who].x, world.agents[who].y) + +def auto_step_agent(world: World, who: str): + a = world.agents[who] + if a.hp <= 0: + return -# ----------------------------- -# Metrics dashboard (HF-safe) -# ----------------------------- -def metrics_dashboard_image(state: WorldState) -> Image.Image: - gm = state.gmetrics - - fig = plt.figure(figsize=(7.0, 2.2), dpi=120) - ax = fig.add_subplot(111) - - x1 = max(1, gm.episodes) - ax.plot([0, x1], [gm.rolling_winrate_A, gm.rolling_winrate_A]) - ax.set_title("Global Metrics Snapshot") - ax.set_xlabel("Episodes") - ax.set_ylabel("Rolling winrate Team A") - ax.set_ylim(-0.05, 1.05) - ax.grid(True) - - txt = ( - f"env={state.env_key} | eps={gm.epsilon:.3f} | episodes={gm.episodes}\n" - f"A_wins={gm.wins_teamA} B_wins={gm.wins_teamB} draws={gm.draws} | avg_steps~{gm.avg_steps:.1f}\n" - f"last_outcome={gm.last_outcome} last_steps={gm.last_steps}" - ) - ax.text(0.01, 0.05, txt, transform=ax.transAxes, fontsize=8, va="bottom") + # choose next move + if a.brain == "random": + cand = [] + for nx, ny in neighbors4(a.x, a.y): + if in_bounds(nx, ny) and not is_blocking(world.grid[ny][nx]): + cand.append((nx, ny)) + if cand: + nx, ny = random.choice(cand) + face_towards(a, nx, ny) + move_to(world, a, nx, ny) + return - fig.tight_layout() + tx, ty = choose_target(world, who) + nxt = bfs_next_step(world.grid, (a.x, a.y), (tx, ty)) + if nxt is None: + # small wander if stuck + cand = [] + for nx, ny in neighbors4(a.x, a.y): + if in_bounds(nx, ny) and not is_blocking(world.grid[ny][nx]): + cand.append((nx, ny)) + if cand: + nx, ny = cand[world.step % len(cand)] + face_towards(a, nx, ny) + move_to(world, a, nx, ny) + return - canvas = FigureCanvas(fig) - canvas.draw() - buf = np.asarray(canvas.buffer_rgba()) # (H,W,4) - img = Image.fromarray(buf, mode="RGBA").convert("RGB") - plt.close(fig) - return img + nx, ny = nxt + face_towards(a, nx, ny) + move_to(world, a, nx, ny) + +def manual_action(world: World, action: str): + """ + Manual control for the 'controlled' agent: + L/R/F/I style minimal actions (Pacman-appropriate). + """ + who = world.controlled + a = world.agents[who] + if a.hp <= 0: + return -def action_entropy(counts: Dict[str, int]) -> float: - total = sum(counts.values()) - if total <= 0: - return 0.0 - p = np.array([c / total for c in counts.values()], dtype=np.float64) - p = np.clip(p, 1e-12, 1.0) - return float(-np.sum(p * np.log2(p))) - -def agent_scoreboard(state: WorldState) -> str: - rows = [] - header = ["agent", "team", "hp", "return", "steps", "entropy", "tiles_disc", "q_states", "inventory"] - rows.append(header) - steps = state.emetrics.steps - - for nm, a in state.agents.items(): - ret = state.emetrics.returns.get(nm, 0.0) - counts = state.emetrics.action_counts.get(nm, {}) - ent = action_entropy(counts) - td = state.emetrics.tiles_discovered.get(nm, 0) - qs = len(state.q_tables.get(nm, {})) - inv = json.dumps(a.inventory, sort_keys=True) - rows.append([nm, a.team, a.hp, f"{ret:.2f}", steps, f"{ent:.2f}", td, qs, inv]) - - col_w = [max(len(str(r[i])) for r in rows) for i in range(len(header))] - lines = [] - for ridx, r in enumerate(rows): - line = " | ".join(str(r[i]).ljust(col_w[i]) for i in range(len(header))) - lines.append(line) - if ridx == 0: - lines.append("-+-".join("-" * w for w in col_w)) - return "\n".join(lines) + if action == "L": + a.ori = (a.ori - 1) % 4 + return + if action == "R": + a.ori = (a.ori + 1) % 4 + return + if action == "F": + dx, dy = DIRS[a.ori] + nx, ny = a.x + dx, a.y + dy + if in_bounds(nx, ny) and not is_blocking(world.grid[ny][nx]): + move_to(world, a, nx, ny) + return + if action == "I": + # In this sim, "I" is effectively "interact": for some envs, that means "pick/drop". + # Most pickups happen automatically via tile effects; so we use I for "drop" in CTF if holding. + if world.env_key == "ctf" and world.flag_carrier == who: + world.flag_carrier = None + # drop flag at current position (simple) + if world.flag_taken_from == "A": + world.grid[a.y][a.x] = FLAG_A + elif world.flag_taken_from == "B": + world.grid[a.y][a.x] = FLAG_B + world.events.append(f"t={world.step}: {who} dropped the flag manually.") + return # ----------------------------- -# Tick / training +# Pseudo-3D POV Renderer (lightweight) # ----------------------------- -def clone_shallow(state: WorldState) -> WorldState: - return WorldState( - seed=state.seed, - step=state.step, - env_key=state.env_key, - grid=[row[:] for row in state.grid], - agents={k: Agent(**asdict(v)) for k, v in state.agents.items()}, - controlled=state.controlled, - pov=state.pov, - overlay=state.overlay, - done=state.done, - outcome=state.outcome, - door_opened_global=state.door_opened_global, - base_progress=state.base_progress, - base_target=state.base_target, - event_log=list(state.event_log), - trace_log=list(state.trace_log), - cfg=state.cfg, - q_tables=state.q_tables, - gmetrics=state.gmetrics, - emetrics=state.emetrics, - ) - -def update_action_counts(state: WorldState, who: str, act: str): - state.emetrics.action_counts.setdefault(who, {}) - state.emetrics.action_counts[who][act] = state.emetrics.action_counts[who].get(act, 0) + 1 +SKY = np.array([12, 14, 26], dtype=np.uint8) +FLOOR1 = np.array([24, 28, 54], dtype=np.uint8) +FLOOR2 = np.array([10, 12, 22], dtype=np.uint8) +WALL1 = np.array([205, 210, 232], dtype=np.uint8) +WALL2 = np.array([160, 168, 195], dtype=np.uint8) +GATEC = np.array([120, 220, 255], dtype=np.uint8) + +def raycast_pov(world: World, who: str) -> np.ndarray: + a = world.agents[who] + img = np.zeros((VIEW_H, VIEW_W, 3), dtype=np.uint8) + img[:, :] = SKY + # floor gradient + for y in range(VIEW_H // 2, VIEW_H): + t = (y - VIEW_H // 2) / max(1, (VIEW_H // 2)) + col = (1 - t) * FLOOR1 + t * FLOOR2 + img[y, :] = col.astype(np.uint8) -def tick(state: WorldState, beliefs: Dict[str, np.ndarray], manual_action: Optional[str] = None) -> None: - if state.done: - return + # rays + ray_cols = VIEW_W + half = math.radians(FOV_DEG / 2) + base = math.radians(ORI_DEG[a.ori]) - prev = clone_shallow(state) - chosen: Dict[str, str] = {} - reasons: Dict[str, str] = {} - qinfo: Dict[str, Optional[Tuple[str, int]]] = {} + for rx in range(ray_cols): + cam = (2 * rx / (ray_cols - 1)) - 1 + ang = base + cam * half + sin_a = math.sin(ang) + cos_a = math.cos(ang) - if manual_action is not None: - chosen[state.controlled] = manual_action - reasons[state.controlled] = "manual" - qinfo[state.controlled] = None + ox, oy = a.x + 0.5, a.y + 0.5 + depth = 0.0 + hit = None + side = 0 - order = list(state.agents.keys()) + while depth < MAX_DEPTH: + depth += 0.06 + tx = int(ox + cos_a * depth) + ty = int(oy + sin_a * depth) + if not in_bounds(tx, ty): + break + tile = world.grid[ty][tx] + if tile == WALL: + hit = "wall" + side = 1 if abs(cos_a) > abs(sin_a) else 0 + break + if tile == GATE: + hit = "gate" + break - for who in order: - if who in chosen: + if hit is None: continue - act, reason, qi = choose_action(state, who, stream=200 + (hash(who) % 1000)) - chosen[who] = act - reasons[who] = reason - qinfo[who] = qi - outcomes: Dict[str, str] = {} - took_damage: Dict[str, bool] = {nm: False for nm in order} + depth *= math.cos(ang - base) + depth = max(depth, 0.001) - for who in order: - outcomes[who] = apply_action(state, who, chosen[who]) + h = int((VIEW_H * 0.92) / depth) + y0 = max(0, VIEW_H // 2 - h // 2) + y1 = min(VIEW_H - 1, VIEW_H // 2 + h // 2) - dmg, msg = resolve_hazards(state, state.agents[who]) - took_damage[who] = dmg - if msg: - state.event_log.append(f"t={state.step}: {who} {msg}") + col = (GATEC.copy() if hit == "gate" else (WALL1.copy() if side == 0 else WALL2.copy())) + dim = max(0.28, 1.0 - depth / MAX_DEPTH) + col = (col * dim).astype(np.uint8) + img[y0:y1, rx:rx + 1] = col - update_action_counts(state, who, chosen[who]) + # simple agent sprites in view if visible + for nm, other in world.agents.items(): + if nm == who or other.hp <= 0: + continue + if not within_fov(a.x, a.y, a.ori, other.x, other.y): + continue + if not bresenham_los(world.grid, a.x, a.y, other.x, other.y): + continue - for m in resolve_tags(state): - state.event_log.append(m) + dx = other.x - a.x + dy = other.y - a.y + ang = math.degrees(math.atan2(dy, dx)) % 360 + facing = ORI_DEG[a.ori] + diff = (ang - facing + 540) % 360 - 180 + sx = int((diff / (FOV_DEG / 2)) * (VIEW_W / 2) + (VIEW_W / 2)) + dist = math.sqrt(dx * dx + dy * dy) + size = int((VIEW_H * 0.55) / max(dist, 1.0)) + size = clamp(size, 10, 110) + ymid = VIEW_H // 2 + x0 = clamp(sx - size // 4, 0, VIEW_W - 1) + x1 = clamp(sx + size // 4, 0, VIEW_W - 1) + y0 = clamp(ymid - size // 2, 0, VIEW_H - 1) + y1 = clamp(ymid + size // 2, 0, VIEW_H - 1) + + # convert agent hex color to rgb + hexcol = AGENT_COLORS.get(nm, "#ffd17a").lstrip("#") + rgb = np.array([int(hexcol[i:i+2], 16) for i in (0, 2, 4)], dtype=np.uint8) + img[y0:y1, x0:x1] = rgb + + # reticle + if world.overlay: + cx, cy = VIEW_W // 2, VIEW_H // 2 + img[cy - 1:cy + 2, cx - 16:cx + 16] = np.array([110, 210, 255], dtype=np.uint8) + img[cy - 16:cy + 16, cx - 1:cx + 2] = np.array([110, 210, 255], dtype=np.uint8) - # belief updates + discovered tiles - for nm, a in state.agents.items(): - if a.hp <= 0: - continue - disc = update_belief_for_agent(state, beliefs[nm], a) - state.emetrics.tiles_discovered[nm] = state.emetrics.tiles_discovered.get(nm, 0) + disc - - check_done(state) - - # rewards + Q - q_lines = [] - for who in order: - if who not in state.emetrics.returns: - state.emetrics.returns[who] = 0.0 - - r = reward_for(prev, state, who, outcomes[who], took_damage[who]) - state.emetrics.returns[who] += r - - if qinfo.get(who) is not None: - key, a_idx = qinfo[who] - next_key = obs_key(state, who) - qtab = state.q_tables.setdefault(who, {}) - old, tgt, new = q_update(qtab, key, a_idx, r, next_key, state.cfg.alpha, state.cfg.gamma) - q_lines.append(f"{who}: old={old:.3f} tgt={tgt:.3f} new={new:.3f} (a={ACTIONS[a_idx]})") - - trace = f"t={state.step} env={state.env_key} done={state.done} outcome={state.outcome}" - for who in order: - a = state.agents[who] - trace += f" | {who}:{chosen[who]} ({outcomes[who]}) hp={a.hp} [{reasons[who]}]" - if q_lines: - trace += " | Q: " + " ; ".join(q_lines) - - state.trace_log.append(trace) - if len(state.trace_log) > TRACE_MAX: - state.trace_log = state.trace_log[-TRACE_MAX:] - - state.step += 1 - state.emetrics.steps = state.step - -def run_episode(state: WorldState, beliefs: Dict[str, np.ndarray], max_steps: int) -> Tuple[str, int]: - while state.step < max_steps and not state.done: - tick(state, beliefs, manual_action=None) - return state.outcome, state.step - -def update_global_metrics_after_episode(state: WorldState, outcome: str, steps: int): - gm = state.gmetrics - gm.episodes += 1 - gm.last_outcome = outcome - gm.last_steps = steps - - if outcome == "A_win": - gm.wins_teamA += 1 - gm.rolling_winrate_A = 0.90 * gm.rolling_winrate_A + 0.10 * 1.0 - elif outcome == "B_win": - gm.wins_teamB += 1 - gm.rolling_winrate_A = 0.90 * gm.rolling_winrate_A + 0.10 * 0.0 - else: - gm.draws += 1 - gm.rolling_winrate_A = 0.90 * gm.rolling_winrate_A + 0.10 * 0.5 - - gm.avg_steps = (0.90 * gm.avg_steps + 0.10 * steps) if gm.avg_steps > 0 else float(steps) - gm.epsilon = max(state.cfg.epsilon_min, gm.epsilon * state.cfg.epsilon_decay) - -def train(state: WorldState, episodes: int, max_steps: int) -> WorldState: - for ep in range(episodes): - ep_seed = (state.seed * 1_000_003 + (state.gmetrics.episodes + ep) * 97_531) & 0xFFFFFFFF - state = reset_episode_keep_learning(state, seed=int(ep_seed)) - beliefs = init_beliefs(list(state.agents.keys())) - outcome, steps = run_episode(state, beliefs, max_steps=max_steps) - update_global_metrics_after_episode(state, outcome, steps) - - state.event_log.append( - f"Training: +{episodes} eps | eps={state.gmetrics.epsilon:.3f} | " - f"A={state.gmetrics.wins_teamA} B={state.gmetrics.wins_teamB} D={state.gmetrics.draws}" - ) - state = reset_episode_keep_learning(state, seed=state.seed) - return state + return img # ----------------------------- -# Export / Import +# SVG Animated Renderer (the "cool UI" core) # ----------------------------- -def export_run(state: WorldState, branches: Dict[str, List[Snapshot]], active_branch: str, rewind_idx: int) -> str: - payload = { - "seed": state.seed, - "env_key": state.env_key, - "controlled": state.controlled, - "pov": state.pov, - "overlay": state.overlay, - "cfg": asdict(state.cfg), - "gmetrics": asdict(state.gmetrics), - "q_tables": state.q_tables, - "branches": {b: [asdict(s) for s in snaps] for b, snaps in branches.items()}, - "active_branch": active_branch, - "rewind_idx": int(rewind_idx), - "grid": state.grid, - "door_opened_global": state.door_opened_global, - "base_progress": state.base_progress, - "base_target": state.base_target, - } - txt = json.dumps(payload, indent=2) - proof = hash_sha256(txt) - return txt + "\n\n" + json.dumps({"proof_sha256": proof}, indent=2) - -def import_run(txt: str) -> Tuple[WorldState, Dict[str, List[Snapshot]], str, int, Dict[str, np.ndarray]]: - parts = txt.strip().split("\n\n") - data = json.loads(parts[0]) - - st = init_state(int(data.get("seed", 1337)), data.get("env_key", "chase")) - st.controlled = data.get("controlled", st.controlled) - st.pov = data.get("pov", st.pov) - st.overlay = bool(data.get("overlay", False)) - st.grid = data.get("grid", st.grid) - st.door_opened_global = bool(data.get("door_opened_global", False)) - st.base_progress = int(data.get("base_progress", 0)) - st.base_target = int(data.get("base_target", 10)) - - st.cfg = TrainConfig(**data.get("cfg", asdict(st.cfg))) - st.gmetrics = GlobalMetrics(**data.get("gmetrics", asdict(st.gmetrics))) - st.q_tables = data.get("q_tables", {}) - - branches_in = data.get("branches", {}) - branches: Dict[str, List[Snapshot]] = {} - for bname, snaps in branches_in.items(): - branches[bname] = [Snapshot(**s) for s in snaps] - - active = data.get("active_branch", "main") - r_idx = int(data.get("rewind_idx", 0)) - - if active in branches and branches[active]: - st = restore_into(st, branches[active][-1]) - st.event_log.append("Imported run (restored last snapshot).") +def tile_color(tile: int) -> str: + return { + EMPTY: COL_EMPTY, + WALL: COL_WALL, + PELLET: COL_PELLET, + POWER: COL_POWER, + FLAG_A: COL_FLAG_A, + FLAG_B: COL_FLAG_B, + TREASURE: COL_TREASURE, + BASE_A: COL_BASE_A, + BASE_B: COL_BASE_B, + RESOURCE: COL_RESOURCE, + HAZARD: COL_HAZARD, + GATE: COL_GATE, + }.get(tile, COL_EMPTY) + +def objective_hud(world: World) -> Tuple[str, str]: + spec = ENVS[world.env_key] + # short headline + detail line + if world.env_key == "pac_chase": + prey_score = world.agents["Prey"].inventory.get("pellets", 0) + headline = f"{spec.title} • pellets_left={world.pellets_left} • prey_score={prey_score} • power={world.power_timer}" + detail = "Prey clears pellets; Predator catches. Power flips vulnerability briefly." + elif world.env_key == "ctf": + carrier = world.flag_carrier or "none" + headline = f"{spec.title} • carrier={carrier} • step={world.step}/{spec.max_steps}" + detail = "Steal opponent flag → return to base. Tagging drops the flag." + elif world.env_key == "treasure": + headline = f"{spec.title} • A={world.treasure_collected_A}/6 • B={world.treasure_collected_B}/6 • step={world.step}/{spec.max_steps}" + detail = "Collect treasures and deposit at base. First to 6 wins." else: - st.event_log.append("Imported run (no snapshots).") + headline = f"{spec.title} • A={world.baseA_progress}/{world.base_target} • B={world.baseB_progress}/{world.base_target} • step={world.step}/{spec.max_steps}" + detail = "Mine resources, deposit to build progress. Raider tags to disrupt." + return headline, detail + +def svg_render(world: World, highlight: Optional[Tuple[int, int]] = None) -> str: + headline, detail = objective_hud(world) + + # CSS transitions: smooth movement + subtle breathing glow + # Note: SVG updates each tick, browser animates between transforms. + css = f""" + + """ + + # HUD panel + svg = [f""" +
+ {css} + + + + {headline} + {detail} + """] + + # tiles + for y in range(GRID_H): + for x in range(GRID_W): + t = world.grid[y][x] + c = tile_color(t) + px = x * TILE + py = HUD_H + y * TILE + # pellets as dots on top of empty tile (for nicer look) + if t == PELLET: + # base tile + svg.append(f'') + cx = px + TILE * 0.5 + cy = py + TILE * 0.5 + svg.append(f'') + elif t == POWER: + svg.append(f'') + cx = px + TILE * 0.5 + cy = py + TILE * 0.5 + svg.append(f'') + else: + svg.append(f'') + + # gridlines (subtle) + for x in range(GRID_W + 1): + px = x * TILE + svg.append(f'') + for y in range(GRID_H + 1): + py = HUD_H + y * TILE + svg.append(f'') + + # highlight tile (optional) + if highlight is not None: + hx, hy = highlight + if in_bounds(hx, hy): + px = hx * TILE + py = HUD_H + hy * TILE + svg.append(f'') + + # agents + for nm, a in world.agents.items(): + px = a.x * TILE + py = HUD_H + a.y * TILE + col = AGENT_COLORS.get(nm, "#ffd17a") + dead_cls = " dead" if a.hp <= 0 else "" + # base transform for smooth animation + svg.append(f""" + + + + """) + + # direction pointer + dx, dy = DIRS[a.ori] + x2 = TILE/2 + dx*(TILE*0.32) + y2 = TILE/2 + dy*(TILE*0.32) + svg.append(f'') + + # name badge + badge_w = max(46, 10 * len(nm) * 0.62) + svg.append(f'') + svg.append(f'{nm}') + + # HP bar + hp = clamp(a.hp, 0, 10) + bar_w = TILE * 0.78 + bx = TILE/2 - bar_w/2 + by = TILE * 0.80 + svg.append(f'') + svg.append(f'') + + # controlled indicator + if nm == world.controlled: + svg.append(f'') + + svg.append("") + + # footer mini status + if world.done: + outcome = world.outcome + outcome_col = "rgba(122,255,200,0.95)" if outcome == "A_win" else "rgba(255,122,122,0.95)" if outcome == "B_win" else "rgba(255,209,122,0.95)" + svg.append(f""" + + + DONE • {outcome} + + """) + + svg.append("
") + return "".join(svg) - beliefs = init_beliefs(list(st.agents.keys())) - return st, branches, active, r_idx, beliefs +# ----------------------------- +# UI Text Blocks +# ----------------------------- +def agent_table(world: World) -> str: + rows = [["agent", "team", "hp", "x", "y", "ori", "mode", "inv"]] + for nm, a in world.agents.items(): + rows.append([nm, a.team, a.hp, a.x, a.y, ORI_DEG[a.ori], a.mode, json.dumps(a.inventory, sort_keys=True)]) + # pretty fixed-width + widths = [max(len(str(r[i])) for r in rows) for i in range(len(rows[0]))] + lines = [] + for i, r in enumerate(rows): + lines.append(" | ".join(str(r[j]).ljust(widths[j]) for j in range(len(widths)))) + if i == 0: + lines.append("-+-".join("-" * w for w in widths)) + return "\n".join(lines) + +def status_text(world: World) -> str: + spec = ENVS[world.env_key] + return ( + f"env={world.env_key} ({spec.title}) | map={world.map_key} | seed={world.seed}\n" + f"step={world.step}/{spec.max_steps} | done={world.done} outcome={world.outcome}\n" + f"controlled={world.controlled} (mode={world.agents[world.controlled].mode}) | pov={world.pov} | auto_camera={world.auto_camera}" + ) # ----------------------------- -# UI helpers +# Simulation Tick # ----------------------------- -def build_views(state: WorldState, beliefs: Dict[str, np.ndarray]) -> Tuple[np.ndarray, Image.Image, Image.Image, Image.Image, Image.Image, str, str, str, str]: - for nm, a in state.agents.items(): - if a.hp > 0: - update_belief_for_agent(state, beliefs[nm], a) +def tick_world(world: World, manual: Optional[str] = None): + if world.done: + return + + # decay power + if world.power_timer > 0: + world.power_timer -= 1 + + # manual action for controlled agent if in manual mode OR if manual action is pressed + if manual is not None: + manual_action(world, manual) - pov = raycast_view(state, state.agents[state.pov]) - truth_np = np.array(state.grid, dtype=np.int16) - truth_img = render_topdown(truth_np, state.agents, f"Truth Map — env={state.env_key} t={state.step} seed={state.seed}", True) + # auto agents step + for nm, a in world.agents.items(): + if a.hp <= 0: + continue + if nm == world.controlled and manual is not None: + # already acted manually this tick; still allow auto for others + pass + else: + if a.mode == "auto": + auto_step_agent(world, nm) - ctrl = state.controlled - others = [k for k in state.agents.keys() if k != ctrl] - other = others[0] if others else ctrl - b_ctrl = render_topdown(beliefs[ctrl], state.agents, f"{ctrl} Belief", True) - b_other = render_topdown(beliefs[other], state.agents, f"{other} Belief", True) + # tile effects + for nm, a in world.agents.items(): + if a.hp > 0: + apply_tile_effects(world, a) - dash = metrics_dashboard_image(state) + # collisions / tags + resolve_tags(world) - status = ( - f"env={state.env_key} | Controlled={state.controlled} | POV={state.pov} | done={state.done} outcome={state.outcome}\n" - f"Episode steps={state.step} | base_progress={state.base_progress}/{state.base_target} | doors_open={state.door_opened_global}\n" - f"Global: episodes={state.gmetrics.episodes} | A={state.gmetrics.wins_teamA} B={state.gmetrics.wins_teamB} D={state.gmetrics.draws} " - f"| winrateA~{state.gmetrics.rolling_winrate_A:.2f} | eps={state.gmetrics.epsilon:.3f}" - ) - events = "\n".join(state.event_log[-18:]) - trace = "\n".join(state.trace_log[-18:]) - scoreboard = agent_scoreboard(state) - return pov, truth_img, b_ctrl, b_other, dash, status, events, trace, scoreboard - -def grid_click_to_tile(evt: gr.SelectData, selected_tile: int, state: WorldState) -> WorldState: - x_px, y_px = evt.index - y_px -= 28 - if y_px < 0: - return state - gx = int(x_px // TILE) - gy = int(y_px // TILE) - if not in_bounds(gx, gy): - return state - if gx == 0 or gy == 0 or gx == GRID_W - 1 or gy == GRID_H - 1: - return state - state.grid[gy][gx] = selected_tile - state.event_log.append(f"t={state.step}: Tile ({gx},{gy}) -> {TILE_NAMES.get(selected_tile)}") - return state + # clean dead agents + for nm, a in world.agents.items(): + if a.hp <= 0: + a.hp = 0 + + # auto camera cuts: switch POV to "most interesting" agent + if world.auto_camera: + # crude "drama score": close to enemy / holding objective / low hp + best = None + best_score = -1e9 + for nm, a in world.agents.items(): + if a.hp <= 0: + continue + score = 0.0 + score += (10 - a.hp) * 0.8 + # chase proximity + for om, o in world.agents.items(): + if om == nm or o.hp <= 0: + continue + if a.team != o.team: + d = manhattan((a.x, a.y), (o.x, o.y)) + score += max(0, 10 - d) * 0.25 + if world.env_key == "ctf" and world.flag_carrier == nm: + score += 4.0 + if world.env_key == "treasure" and a.inventory.get("treasure", 0) > 0: + score += 2.0 + if world.env_key == "resource" and a.inventory.get("res", 0) > 0: + score += 1.5 + + if score > best_score: + best_score = score + best = nm + if best is not None: + world.pov = best + + # done conditions + check_done(world) + + # advance time + world.step += 1 + + # prune logs + if len(world.events) > 220: + world.events = world.events[-220:] + +# ----------------------------- +# UI Orchestration +# ----------------------------- +def rebuild_views(world: World, highlight: Optional[Tuple[int, int]] = None): + svg = svg_render(world, highlight=highlight) + pov = raycast_pov(world, world.pov) + status = status_text(world) + agents_txt = agent_table(world) + events_txt = "\n".join(world.events[-18:]) + return svg, pov, status, agents_txt, events_txt + +def set_agent_modes(world: World, controlled_mode: str, other_mode: str): + # controlled agent mode + if world.controlled in world.agents: + world.agents[world.controlled].mode = controlled_mode + # others + for nm, a in world.agents.items(): + if nm != world.controlled: + a.mode = other_mode + world.events.append(f"t={world.step}: Modes set — controlled={controlled_mode}, others={other_mode}") + +def swap_controlled(world: World): + names = list(world.agents.keys()) + i = names.index(world.controlled) + world.controlled = names[(i + 1) % len(names)] + world.events.append(f"t={world.step}: Controlled -> {world.controlled}") + +def swap_pov(world: World): + names = list(world.agents.keys()) + i = names.index(world.pov) + world.pov = names[(i + 1) % len(names)] + world.events.append(f"t={world.step}: POV -> {world.pov}") + +def apply_env_map_seed(seed: int, env_key: str, map_key: str) -> World: + seed = int(seed) + w = init_world(seed=seed, env_key=env_key, map_key=map_key) + return w # ----------------------------- -# Gradio app +# Gradio App # ----------------------------- -TITLE = "ZEN AgentLab — Agent POV + Multi-Agent Mini-Sim Arena" +TITLE = "ZEN AgentLab++ — Animated Multi-Map Agent Simulation Arena" with gr.Blocks(title=TITLE) as demo: gr.Markdown( f"## {TITLE}\n" - "Multi-environment agent sandbox with POV, belief maps, branching timelines, training, and metrics.\n" - "**No timers** — use Tick / Run / Train for deterministic experiments." + "A living playground: agents **navigate real maps/courses**, chase objectives, and animate smoothly.\n" + "Use **Autoplay** for hands-free demos (Pac-Chase feels like chaotic Pac-Man)." ) - st0 = init_state(1337, "chase") - st = gr.State(st0) - branches = gr.State({"main": [snapshot_of(st0, "main")]}) - active_branch = gr.State("main") - rewind_idx = gr.State(0) - beliefs = gr.State(init_beliefs(list(st0.agents.keys()))) + # state + w0 = init_world(seed=1337, env_key="pac_chase", map_key="Classic Pac-Chase") + w_state = gr.State(w0) + autoplay_on = gr.State(False) + highlight_state = gr.State(None) # (x,y) or None with gr.Row(): - pov_img = gr.Image(label="POV (Pseudo-3D)", type="numpy", width=VIEW_W, height=VIEW_H) - with gr.Column(): - status = gr.Textbox(label="Status", lines=3) - scoreboard = gr.Textbox(label="Agent Scoreboard", lines=8) + # Left: animated top-down SVG + arena = gr.HTML(label="Arena (Animated SVG)") - with gr.Row(): - truth = gr.Image(label="Truth Map (click to edit tiles)", type="pil") - belief_a = gr.Image(label="Belief (Controlled)", type="pil") - belief_b = gr.Image(label="Belief (Other)", type="pil") + # Right: POV + Status + with gr.Column(scale=1): + pov_img = gr.Image(label="Agent POV (Pseudo-3D)", type="numpy", width=VIEW_W, height=VIEW_H) + status_box = gr.Textbox(label="Status", lines=3) + agent_box = gr.Textbox(label="Agents", lines=10) with gr.Row(): - dash = gr.Image(label="Metrics Dashboard", type="pil") + events_box = gr.Textbox(label="Event Log", lines=10) with gr.Row(): - events = gr.Textbox(label="Event Log", lines=10) - trace = gr.Textbox(label="Step Trace", lines=10) + with gr.Column(scale=2): + gr.Markdown("### Scenario Controls") + env_pick = gr.Radio( + choices=[ + (ENVS["pac_chase"].title, "pac_chase"), + (ENVS["ctf"].title, "ctf"), + (ENVS["treasure"].title, "treasure"), + (ENVS["resource"].title, "resource"), + ], + value="pac_chase", + label="Gameplay Type", + ) + map_pick = gr.Dropdown( + choices=list(MAP_BUILDERS.keys()), + value="Classic Pac-Chase", + label="Map / Course", + ) + seed_box = gr.Number(value=1337, precision=0, label="Seed") + + with gr.Row(): + btn_apply = gr.Button("Apply (Env + Map + Seed)") + btn_reset = gr.Button("Reset (Same Env/Map/Seed)") + + gr.Markdown("### Autoplay / Demo Mode") + autoplay_speed = gr.Slider(0.05, 0.8, value=0.18, step=0.01, label="Autoplay tick interval (sec)") + with gr.Row(): + btn_play = gr.Button("▶ Start Autoplay") + btn_pause = gr.Button("⏸ Stop Autoplay") + with gr.Row(): + run_n = gr.Number(value=25, precision=0, label="Run N ticks") + btn_run = gr.Button("Run") - with gr.Row(): with gr.Column(scale=2): - gr.Markdown("### Manual Controls") + gr.Markdown("### Control & Camera") + with gr.Row(): + btn_ctrl = gr.Button("Swap Controlled Agent") + btn_pov = gr.Button("Swap POV Agent") + overlay = gr.Checkbox(value=True, label="POV Overlay Reticle") + auto_camera = gr.Checkbox(value=True, label="Auto Camera Cuts (Spectator Mode)") + + gr.Markdown("### Agent Modes") + controlled_mode = gr.Radio(choices=["auto", "manual"], value="auto", label="Controlled Agent Mode") + other_mode = gr.Radio(choices=["auto", "manual"], value="auto", label="Other Agents Mode") + btn_modes = gr.Button("Apply Agent Modes") + + gr.Markdown("### Manual Actions (Controlled Agent)") with gr.Row(): btn_L = gr.Button("L") btn_F = gr.Button("F") btn_R = gr.Button("R") - btn_I = gr.Button("I (Interact)") - with gr.Row(): - btn_tick = gr.Button("Tick") - run_steps = gr.Number(value=25, label="Run N steps", precision=0) - btn_run = gr.Button("Run") + btn_I = gr.Button("I") - with gr.Row(): - btn_toggle_control = gr.Button("Toggle Controlled") - btn_toggle_pov = gr.Button("Toggle POV") - overlay = gr.Checkbox(False, label="Overlay reticle") + # Timer for autoplay + timer = gr.Timer(value=0.18, active=False) - gr.Markdown("### Environment + Edit") - env_pick = gr.Radio( - choices=[("Chase (Predator vs Prey)", "chase"), - ("CoopVault (team vs guardian)", "vault"), - ("MiniCiv (build + raid)", "civ")], - value="chase", - label="Environment" - ) - tile_pick = gr.Radio( - choices=[(TILE_NAMES[k], k) for k in [EMPTY, WALL, FOOD, NOISE, DOOR, TELE, KEY, EXIT, ARTIFACT, HAZARD, WOOD, ORE, MEDKIT, SWITCH, BASE]], - value=WALL, - label="Paint tile type" - ) + # ---------- initial load ---------- + def ui_refresh(world: World, highlight): + world.overlay = bool(world.overlay) + return (*rebuild_views(world, highlight=highlight), world, highlight) - with gr.Column(scale=3): - gr.Markdown("### Training Controls (Tabular Q-learning)") - use_q = gr.Checkbox(True, label="Use Q-learning (agents with brain='q')") - alpha = gr.Slider(0.01, 0.5, value=0.15, step=0.01, label="alpha") - gamma = gr.Slider(0.5, 0.99, value=0.95, step=0.01, label="gamma") - eps = gr.Slider(0.0, 0.5, value=0.10, step=0.01, label="epsilon") - eps_decay = gr.Slider(0.90, 0.999, value=0.995, step=0.001, label="epsilon decay") - eps_min = gr.Slider(0.0, 0.2, value=0.02, step=0.01, label="epsilon min") + def on_load(world: World, highlight): + return ui_refresh(world, highlight) - episodes = gr.Number(value=50, label="Train episodes", precision=0) - max_steps = gr.Number(value=260, label="Max steps/episode", precision=0) - btn_train = gr.Button("Train") + demo.load( + on_load, + inputs=[w_state, highlight_state], + outputs=[arena, pov_img, status_box, agent_box, events_box, w_state, highlight_state], + queue=True, + ) - btn_reset = gr.Button("Reset Episode (keep learning)") - btn_reset_all = gr.Button("Reset ALL (wipe Q + metrics)") + # ---------- Apply scenario ---------- + def apply_clicked(world: World, highlight, env_key: str, map_key: str, seed: int): + world = apply_env_map_seed(seed=seed, env_key=env_key, map_key=map_key) + world.overlay = True + world.auto_camera = True + highlight = None + return ui_refresh(world, highlight) + + btn_apply.click( + apply_clicked, + inputs=[w_state, highlight_state, env_pick, map_pick, seed_box], + outputs=[arena, pov_img, status_box, agent_box, events_box, w_state, highlight_state], + queue=True, + ) - with gr.Row(): - with gr.Column(scale=2): - gr.Markdown("### Timeline + Branching") - rewind = gr.Slider(0, 0, value=0, step=1, label="Rewind index (active branch)") - btn_jump = gr.Button("Jump to index") - new_branch_name = gr.Textbox(value="fork1", label="New branch name") - btn_fork = gr.Button("Fork from current rewind") + def reset_clicked(world: World, highlight): + world = init_world(seed=world.seed, env_key=world.env_key, map_key=world.map_key) + highlight = None + return ui_refresh(world, highlight) - with gr.Column(scale=2): - branch_pick = gr.Dropdown(choices=["main"], value="main", label="Active branch") - btn_set_branch = gr.Button("Set Active Branch") - - with gr.Column(scale=3): - export_box = gr.Textbox(label="Export JSON (+ proof hash)", lines=8) - btn_export = gr.Button("Export") - import_box = gr.Textbox(label="Import JSON", lines=8) - btn_import = gr.Button("Import") - - # ---------- glue ---------- - def refresh(state: WorldState, branches_d: Dict[str, List[Snapshot]], active: str, bel: Dict[str, np.ndarray], r: int): - snaps = branches_d.get(active, []) - r_max = max(0, len(snaps) - 1) - r = max(0, min(int(r), r_max)) - pov, tr, ba, bb, dimg, stxt, etxt, ttxt, sb = build_views(state, bel) - branch_choices = sorted(list(branches_d.keys())) - return ( - pov, tr, ba, bb, dimg, stxt, sb, etxt, ttxt, - gr.update(maximum=r_max, value=r), r, - gr.update(choices=branch_choices, value=active), - gr.update(choices=branch_choices, value=active), - ) - - def push_hist(state: WorldState, branches_d: Dict[str, List[Snapshot]], active: str) -> Dict[str, List[Snapshot]]: - branches_d.setdefault(active, []) - branches_d[active].append(snapshot_of(state, active)) - if len(branches_d[active]) > MAX_HISTORY: - branches_d[active].pop(0) - return branches_d - - def set_cfg(state: WorldState, use_q_v: bool, a: float, g: float, e: float, ed: float, emin: float) -> WorldState: - state.cfg.use_q = bool(use_q_v) - state.cfg.alpha = float(a) - state.cfg.gamma = float(g) - state.gmetrics.epsilon = float(e) - state.cfg.epsilon_decay = float(ed) - state.cfg.epsilon_min = float(emin) - return state - - def do_manual(state, branches_d, active, bel, r, act): - tick(state, bel, manual_action=act) - branches_d = push_hist(state, branches_d, active) - r = len(branches_d[active]) - 1 - out = refresh(state, branches_d, active, bel, r) - return out + (state, branches_d, active, bel, r) - - def do_tick(state, branches_d, active, bel, r): - tick(state, bel, manual_action=None) - branches_d = push_hist(state, branches_d, active) - r = len(branches_d[active]) - 1 - out = refresh(state, branches_d, active, bel, r) - return out + (state, branches_d, active, bel, r) - - def do_run(state, branches_d, active, bel, r, n): + btn_reset.click( + reset_clicked, + inputs=[w_state, highlight_state], + outputs=[arena, pov_img, status_box, agent_box, events_box, w_state, highlight_state], + queue=True, + ) + + # ---------- Modes / camera ---------- + def modes_clicked(world: World, highlight, cmode: str, omode: str): + set_agent_modes(world, cmode, omode) + return ui_refresh(world, highlight) + + btn_modes.click( + modes_clicked, + inputs=[w_state, highlight_state, controlled_mode, other_mode], + outputs=[arena, pov_img, status_box, agent_box, events_box, w_state, highlight_state], + queue=True, + ) + + def ctrl_clicked(world: World, highlight): + swap_controlled(world) + return ui_refresh(world, highlight) + + btn_ctrl.click( + ctrl_clicked, + inputs=[w_state, highlight_state], + outputs=[arena, pov_img, status_box, agent_box, events_box, w_state, highlight_state], + queue=True, + ) + + def pov_clicked(world: World, highlight): + swap_pov(world) + return ui_refresh(world, highlight) + + btn_pov.click( + pov_clicked, + inputs=[w_state, highlight_state], + outputs=[arena, pov_img, status_box, agent_box, events_box, w_state, highlight_state], + queue=True, + ) + + def overlay_changed(world: World, highlight, v: bool): + world.overlay = bool(v) + return ui_refresh(world, highlight) + + overlay.change( + overlay_changed, + inputs=[w_state, highlight_state, overlay], + outputs=[arena, pov_img, status_box, agent_box, events_box, w_state, highlight_state], + queue=True, + ) + + def auto_camera_changed(world: World, highlight, v: bool): + world.auto_camera = bool(v) + world.events.append(f"t={world.step}: auto_camera={world.auto_camera}") + return ui_refresh(world, highlight) + + auto_camera.change( + auto_camera_changed, + inputs=[w_state, highlight_state, auto_camera], + outputs=[arena, pov_img, status_box, agent_box, events_box, w_state, highlight_state], + queue=True, + ) + + # ---------- Manual buttons ---------- + def manual_btn(world: World, highlight, act: str): + tick_world(world, manual=act) + return ui_refresh(world, highlight) + + btn_L.click(lambda w,h: manual_btn(w,h,"L"), inputs=[w_state, highlight_state], outputs=[arena, pov_img, status_box, agent_box, events_box, w_state, highlight_state], queue=True) + btn_F.click(lambda w,h: manual_btn(w,h,"F"), inputs=[w_state, highlight_state], outputs=[arena, pov_img, status_box, agent_box, events_box, w_state, highlight_state], queue=True) + btn_R.click(lambda w,h: manual_btn(w,h,"R"), inputs=[w_state, highlight_state], outputs=[arena, pov_img, status_box, agent_box, events_box, w_state, highlight_state], queue=True) + btn_I.click(lambda w,h: manual_btn(w,h,"I"), inputs=[w_state, highlight_state], outputs=[arena, pov_img, status_box, agent_box, events_box, w_state, highlight_state], queue=True) + + # ---------- Run N ---------- + def run_clicked(world: World, highlight, n: int): n = max(1, int(n)) for _ in range(n): - if state.done: + if world.done: break - tick(state, bel, manual_action=None) - branches_d = push_hist(state, branches_d, active) - r = len(branches_d[active]) - 1 - out = refresh(state, branches_d, active, bel, r) - return out + (state, branches_d, active, bel, r) - - def toggle_control(state, branches_d, active, bel, r): - order = list(state.agents.keys()) - i = order.index(state.controlled) - state.controlled = order[(i + 1) % len(order)] - state.event_log.append(f"Controlled -> {state.controlled}") - branches_d = push_hist(state, branches_d, active) - r = len(branches_d[active]) - 1 - out = refresh(state, branches_d, active, bel, r) - return out + (state, branches_d, active, bel, r) - - def toggle_pov(state, branches_d, active, bel, r): - order = list(state.agents.keys()) - i = order.index(state.pov) - state.pov = order[(i + 1) % len(order)] - state.event_log.append(f"POV -> {state.pov}") - branches_d = push_hist(state, branches_d, active) - r = len(branches_d[active]) - 1 - out = refresh(state, branches_d, active, bel, r) - return out + (state, branches_d, active, bel, r) - - def set_overlay(state, branches_d, active, bel, r, ov): - state.overlay = bool(ov) - out = refresh(state, branches_d, active, bel, r) - return out + (state, branches_d, active, bel, r) - - def click_truth(tile, state, branches_d, active, bel, r, evt: gr.SelectData): - state = grid_click_to_tile(evt, int(tile), state) - branches_d = push_hist(state, branches_d, active) - r = len(branches_d[active]) - 1 - out = refresh(state, branches_d, active, bel, r) - return out + (state, branches_d, active, bel, r) - - def jump(state, branches_d, active, bel, r, idx): - snaps = branches_d.get(active, []) - if not snaps: - out = refresh(state, branches_d, active, bel, r) - return out + (state, branches_d, active, bel, r) - idx = max(0, min(int(idx), len(snaps) - 1)) - state = restore_into(state, snaps[idx]) - r = idx - out = refresh(state, branches_d, active, bel, r) - return out + (state, branches_d, active, bel, r) - - def fork_branch(state, branches_d, active, bel, r, new_name): - new_name = (new_name or "").strip() or "fork" - new_name = new_name.replace(" ", "_") - snaps = branches_d.get(active, []) - if not snaps: - branches_d[new_name] = [] - branches_d[new_name].append(snapshot_of(state, new_name)) - else: - idx = max(0, min(int(r), len(snaps) - 1)) - branches_d[new_name] = [Snapshot(**asdict(s)) for s in snaps[:idx + 1]] - state = restore_into(state, branches_d[new_name][-1]) - active = new_name - state.event_log.append(f"Forked branch -> {new_name}") - branches_d = push_hist(state, branches_d, active) - r = len(branches_d[active]) - 1 - out = refresh(state, branches_d, active, bel, r) - return out + (state, branches_d, active, bel, r) - - def set_active_branch(state, branches_d, active, bel, r, br): - br = br or "main" - if br not in branches_d: - branches_d[br] = [snapshot_of(state, br)] - active = br - if branches_d[active]: - state = restore_into(state, branches_d[active][-1]) - bel = init_beliefs(list(state.agents.keys())) - r = len(branches_d[active]) - 1 - out = refresh(state, branches_d, active, bel, r) - return out + (state, branches_d, active, bel, r) - - def change_env(state, branches_d, active, bel, r, env_key): - env_key = env_key or "chase" - state.env_key = env_key - state = reset_episode_keep_learning(state, seed=state.seed) - bel = init_beliefs(list(state.agents.keys())) - active = "main" - branches_d = {"main": [snapshot_of(state, "main")]} - r = 0 - out = refresh(state, branches_d, active, bel, r) - return out + (state, branches_d, active, bel, r) - - def reset_ep(state, branches_d, active, bel, r): - state = reset_episode_keep_learning(state, seed=state.seed) - bel = init_beliefs(list(state.agents.keys())) - branches_d = {active: [snapshot_of(state, active)]} - r = 0 - out = refresh(state, branches_d, active, bel, r) - return out + (state, branches_d, active, bel, r) - - def reset_all(state, branches_d, active, bel, r, env_key): - env_key = env_key or state.env_key - state = wipe_all(seed=state.seed, env_key=env_key) - bel = init_beliefs(list(state.agents.keys())) - active = "main" - branches_d = {"main": [snapshot_of(state, "main")]} - r = 0 - out = refresh(state, branches_d, active, bel, r) - return out + (state, branches_d, active, bel, r) - - def do_train(state, branches_d, active, bel, r, - use_q_v, a, g, e, ed, emin, - eps_count, max_s): - state = set_cfg(state, use_q_v, a, g, e, ed, emin) - state = train(state, episodes=max(1, int(eps_count)), max_steps=max(10, int(max_s))) - bel = init_beliefs(list(state.agents.keys())) - branches_d = {"main": [snapshot_of(state, "main")]} - active = "main" - r = 0 - out = refresh(state, branches_d, active, bel, r) - return out + (state, branches_d, active, bel, r) - - def export_fn(state, branches_d, active, r): - return export_run(state, branches_d, active, int(r)) - - def import_fn(txt): - state, branches_d, active, r, bel = import_run(txt) - branches_d.setdefault(active, []) - if not branches_d[active]: - branches_d[active].append(snapshot_of(state, active)) - out = refresh(state, branches_d, active, bel, r) - return out + (state, branches_d, active, bel, r) - - # ---- wire events (no fn_kwargs) ---- - common_outputs = [ - pov_img, truth, belief_a, belief_b, dash, status, scoreboard, events, trace, - rewind, rewind_idx, branch_pick, branch_pick, - st, branches, active_branch, beliefs, rewind_idx - ] - - btn_L.click(lambda s,b,a,bel,r: do_manual(s,b,a,bel,r,"L"), - inputs=[st, branches, active_branch, beliefs, rewind_idx], - outputs=common_outputs, queue=True) - - btn_F.click(lambda s,b,a,bel,r: do_manual(s,b,a,bel,r,"F"), - inputs=[st, branches, active_branch, beliefs, rewind_idx], - outputs=common_outputs, queue=True) - - btn_R.click(lambda s,b,a,bel,r: do_manual(s,b,a,bel,r,"R"), - inputs=[st, branches, active_branch, beliefs, rewind_idx], - outputs=common_outputs, queue=True) - - btn_I.click(lambda s,b,a,bel,r: do_manual(s,b,a,bel,r,"I"), - inputs=[st, branches, active_branch, beliefs, rewind_idx], - outputs=common_outputs, queue=True) - - btn_tick.click(do_tick, - inputs=[st, branches, active_branch, beliefs, rewind_idx], - outputs=common_outputs, queue=True) - - btn_run.click(do_run, - inputs=[st, branches, active_branch, beliefs, rewind_idx, run_steps], - outputs=common_outputs, queue=True) - - btn_toggle_control.click(toggle_control, - inputs=[st, branches, active_branch, beliefs, rewind_idx], - outputs=common_outputs, queue=True) - - btn_toggle_pov.click(toggle_pov, - inputs=[st, branches, active_branch, beliefs, rewind_idx], - outputs=common_outputs, queue=True) - - overlay.change(set_overlay, - inputs=[st, branches, active_branch, beliefs, rewind_idx, overlay], - outputs=common_outputs, queue=True) - - env_pick.change(change_env, - inputs=[st, branches, active_branch, beliefs, rewind_idx, env_pick], - outputs=common_outputs, queue=True) - - truth.select(click_truth, - inputs=[tile_pick, st, branches, active_branch, beliefs, rewind_idx], - outputs=common_outputs, queue=True) - - btn_jump.click(jump, - inputs=[st, branches, active_branch, beliefs, rewind_idx, rewind], - outputs=common_outputs, queue=True) - - btn_fork.click(fork_branch, - inputs=[st, branches, active_branch, beliefs, rewind_idx, new_branch_name], - outputs=common_outputs, queue=True) - - btn_set_branch.click(set_active_branch, - inputs=[st, branches, active_branch, beliefs, rewind_idx, branch_pick], - outputs=common_outputs, queue=True) - - btn_reset.click(reset_ep, - inputs=[st, branches, active_branch, beliefs, rewind_idx], - outputs=common_outputs, queue=True) - - btn_reset_all.click(reset_all, - inputs=[st, branches, active_branch, beliefs, rewind_idx, env_pick], - outputs=common_outputs, queue=True) - - btn_train.click(do_train, - inputs=[st, branches, active_branch, beliefs, rewind_idx, - use_q, alpha, gamma, eps, eps_decay, eps_min, - episodes, max_steps], - outputs=common_outputs, queue=True) - - btn_export.click(export_fn, inputs=[st, branches, active_branch, rewind_idx], outputs=[export_box], queue=True) - - btn_import.click(import_fn, - inputs=[import_box], - outputs=common_outputs, queue=True) - - demo.load(refresh, - inputs=[st, branches, active_branch, beliefs, rewind_idx], - outputs=[ - pov_img, truth, belief_a, belief_b, dash, status, scoreboard, events, trace, - rewind, rewind_idx, branch_pick, branch_pick - ], - queue=True) - -# HF sometimes enables SSR by default; disable for maximum compatibility + tick_world(world, manual=None) + return ui_refresh(world, highlight) + + btn_run.click( + run_clicked, + inputs=[w_state, highlight_state, run_n], + outputs=[arena, pov_img, status_box, agent_box, events_box, w_state, highlight_state], + queue=True, + ) + + # ---------- Autoplay ---------- + def autoplay_start(world: World, highlight, interval: float): + interval = float(interval) + return gr.update(value=interval, active=True), True, world, highlight + + def autoplay_stop(world: World, highlight): + return gr.update(active=False), False, world, highlight + + btn_play.click( + autoplay_start, + inputs=[w_state, highlight_state, autoplay_speed], + outputs=[timer, autoplay_on, w_state, highlight_state], + queue=True, + ) + + btn_pause.click( + autoplay_stop, + inputs=[w_state, highlight_state], + outputs=[timer, autoplay_on, w_state, highlight_state], + queue=True, + ) + + def autoplay_tick(world: World, highlight, is_on: bool): + if not is_on: + return (*rebuild_views(world, highlight=highlight), world, highlight, is_on, gr.update()) + if not world.done: + tick_world(world, manual=None) + # stop automatically when done + if world.done: + return (*rebuild_views(world, highlight=highlight), world, highlight, False, gr.update(active=False)) + return (*rebuild_views(world, highlight=highlight), world, highlight, True, gr.update()) + + timer.tick( + autoplay_tick, + inputs=[w_state, highlight_state, autoplay_on], + outputs=[arena, pov_img, status_box, agent_box, events_box, w_state, highlight_state, autoplay_on, timer], + queue=True, + ) + demo.queue().launch(ssr_mode=False)