|
|
import json |
|
|
import math |
|
|
import hashlib |
|
|
from dataclasses import dataclass, asdict |
|
|
from typing import Dict, List, Tuple, Optional, Any |
|
|
|
|
|
import numpy as np |
|
|
from PIL import Image, ImageDraw |
|
|
|
|
|
import matplotlib.pyplot as plt |
|
|
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas |
|
|
|
|
|
import gradio as gr |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# World size in tiles, and the pixel size of one tile in the top-down map.
GRID_W = 21
GRID_H = 15
TILE = 22

# First-person viewport in pixels. One ray is cast per RAY_W column, so each ray
# covers VIEW_W / RAY_W pixels. FOV_DEG is the horizontal field of view and
# MAX_DEPTH is the ray-march cutoff in tiles.
VIEW_W = 640
VIEW_H = 360
RAY_W = VIEW_W // 2
FOV_DEG = 78
MAX_DEPTH = 20

# Orientation index -> unit grid step, and the matching heading in degrees
# (y grows downward in the rendered map, so index 1 points "down").
DIRS = [(1, 0), (0, 1), (-1, 0), (0, -1)]
ORI_DEG = [90 * k for k in range(len(DIRS))]
|
|
|
|
|
|
|
|
# Integer tile codes stored in the world grid; they are consecutive from 0.
(
    EMPTY,
    WALL,
    FOOD,
    NOISE,
    DOOR,
    TELE,
    KEY,
    EXIT,
    ARTIFACT,
    HAZARD,
    WOOD,
    ORE,
    MEDKIT,
    SWITCH,
    BASE,
) = range(15)

# Human-readable label for every tile code, listed in code order.
TILE_NAMES = dict(enumerate((
    "Empty",
    "Wall",
    "Food",
    "Noise",
    "Door",
    "Teleporter",
    "Key",
    "Exit",
    "Artifact",
    "Hazard",
    "Wood",
    "Ore",
    "Medkit",
    "Switch",
    "Base",
)))
|
|
|
|
|
# Sprite / map-marker colour (RGB) per agent name. The renderers fall back to a
# default colour for names missing here.
AGENT_COLORS = dict(
    Predator=(255, 120, 90),
    Prey=(120, 255, 160),
    Scout=(120, 190, 255),
    Alpha=(255, 205, 120),
    Bravo=(160, 210, 255),
    Guardian=(255, 120, 220),
    BuilderA=(140, 255, 200),
    BuilderB=(160, 200, 255),
    Raider=(255, 160, 120),
)

# uint8 RGB palette used by the first-person renderer.
SKY, FLOOR_NEAR, FLOOR_FAR, WALL_BASE, WALL_SIDE, DOOR_COL = (
    np.array(rgb, dtype=np.uint8)
    for rgb in (
        (14, 16, 26),
        (24, 26, 40),
        (10, 11, 18),
        (210, 210, 225),
        (150, 150, 170),
        (140, 210, 255),
    )
)

# Discrete action alphabet: L/R turn, F steps forward, I interacts (see apply_action).
ACTIONS = list("LFRI")

# NOTE(review): consumed outside this view; presumably caps on trace / history lengths.
TRACE_MAX = 500
MAX_HISTORY = 1400
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def rng_for(seed: int, step: int, stream: int = 0) -> np.random.Generator:
    """Return a NumPy generator seeded deterministically from (seed, step, stream).

    Each input is scaled by its own odd multiplier and the three products are
    XOR-folded. The result is masked to 64 bits so negative inputs still give a
    valid non-negative seed.
    """
    parts = (seed * 1_000_003, step * 9_999_937, stream * 97_531)
    folded = parts[0] ^ parts[1] ^ parts[2]
    return np.random.default_rng(folded & ((1 << 64) - 1))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class Agent:
    """One actor in the grid world.

    ``x``/``y`` are grid coordinates and ``ori`` indexes ``DIRS``/``ORI_DEG``.
    ``team`` groups allies. ``brain`` selects the policy used by
    ``choose_action`` ("q", "heuristic" or "random").
    """

    name: str
    x: int
    y: int
    ori: int
    hp: int = 10
    energy: int = 100
    team: str = "A"
    brain: str = "q"
    # Item name -> count; each instance gets its own dict when none is passed.
    inventory: Optional[Dict[str, int]] = None

    def __post_init__(self):
        self.inventory = {} if self.inventory is None else self.inventory
|
|
|
|
|
@dataclass
class TrainConfig:
    """Tabular Q-learning hyper-parameters and reward-shaping weights.

    ``reward_for`` reads the reward weights and ``choose_action`` reads ``use_q``.
    Where the remaining fields are consumed lies outside this view.
    """

    # When False, choose_action routes "q"-brained agents to heuristic_action.
    use_q: bool = True
    # Learning rate and discount; the names match q_update's alpha/gamma
    # parameters (NOTE(review): the call site is not visible here).
    alpha: float = 0.15
    gamma: float = 0.95
    # Initial exploration rate (WorldState copies it into GlobalMetrics.epsilon).
    epsilon: float = 0.10
    # NOTE(review): presumably the floor and multiplicative decay for epsilon;
    # the decay code is not in view.
    epsilon_min: float = 0.02
    epsilon_decay: float = 0.995

    # Shaping terms shared by all environments.
    step_penalty: float = -0.01  # added on every step
    explore_reward: float = 0.015  # action result starts with "moved"
    damage_penalty: float = -0.20  # took_damage flag set
    heal_reward: float = 0.10  # medkit used

    # "chase" environment.
    chase_close_coeff: float = 0.03  # Predator: times the decrease in Manhattan distance to Prey
    chase_catch_reward: float = 3.0  # Predator on A_win
    chase_escaped_reward: float = 0.2  # Prey on B_win
    chase_caught_penalty: float = -3.0  # Prey on A_win
    food_reward: float = 0.6  # Prey ate food

    # "vault" environment.
    artifact_pick_reward: float = 1.2  # artifact picked up
    exit_win_reward: float = 3.0  # team A on A_win
    guardian_tag_reward: float = 2.0  # team B on B_win
    tagged_penalty: float = -2.0  # team A on B_win
    switch_reward: float = 0.8  # switch used
    key_reward: float = 0.4  # key picked up

    # "civ" environment.
    resource_pick_reward: float = 0.15  # wood or ore picked up
    deposit_reward: float = 0.4  # resources deposited at BASE
    base_progress_win_reward: float = 3.5  # team A on A_win
    raider_elim_reward: float = 2.0  # team B on B_win
    builder_elim_penalty: float = -2.0  # team A on B_win
|
|
|
|
|
@dataclass
class GlobalMetrics:
    """Counters that persist across episodes.

    reset_episode_keep_learning hands the same instance to each new episode.
    choose_action reads ``epsilon``, and metrics_dashboard_image displays the
    other fields. NOTE(review): the code that updates these counters is outside
    this view.
    """

    episodes: int = 0
    wins_teamA: int = 0
    wins_teamB: int = 0
    draws: int = 0
    avg_steps: float = 0.0
    # Plotted by metrics_dashboard_image on a fixed 0..1 axis.
    rolling_winrate_A: float = 0.0
    # Current exploration rate for choose_action's epsilon-greedy pick.
    epsilon: float = 0.10
    last_outcome: str = "init"
    last_steps: int = 0
|
|
|
|
|
@dataclass
class EpisodeMetrics:
    """Per-episode counters, each dict keyed by agent name.

    Containers left as ``None`` are replaced with fresh empty dicts.
    """

    steps: int = 0
    # agent name -> episode return
    returns: Optional[Dict[str, float]] = None
    # agent name -> action letter -> times chosen
    action_counts: Optional[Dict[str, Dict[str, int]]] = None
    # agent name -> number of tiles revealed
    tiles_discovered: Optional[Dict[str, int]] = None

    def __post_init__(self):
        for attr in ("returns", "action_counts", "tiles_discovered"):
            if getattr(self, attr) is None:
                setattr(self, attr, {})
|
|
|
|
|
@dataclass
class WorldState:
    """All mutable state of one episode, plus learning state carried across episodes.

    ``grid`` is indexed ``grid[y][x]``; ``agents`` maps agent name -> Agent.
    reset_episode_keep_learning passes ``cfg``, ``q_tables`` and ``gmetrics`` by
    reference to the next episode. The remaining fields describe the current
    episode only.
    """

    seed: int  # base seed passed to rng_for
    step: int  # step counter; check_done compares it against timeouts
    env_key: str  # key into ENV_BUILDERS ("chase", "vault", "civ")
    grid: List[List[int]]  # tile codes, indexed grid[y][x]
    agents: Dict[str, Agent]  # agent name -> Agent

    # init_state sets both to the first agent's name.
    # NOTE(review): presumably the user-driven agent and the first-person camera
    # agent respectively; confirm in the UI code.
    controlled: str
    pov: str
    overlay: bool  # when True, raycast_view draws a crosshair

    done: bool  # set by check_done when an end condition holds
    outcome: str  # "ongoing" from init_state, then "A_win" / "B_win" / "draw"

    door_opened_global: bool = False  # set by a SWITCH; DOOR tiles stop blocking
    base_progress: int = 0  # resources deposited at BASE (civ)
    base_target: int = 10  # base_progress needed for a civ win

    event_log: List[str] = None  # human-readable event lines
    trace_log: List[str] = None  # NOTE(review): written outside this view; snapshots keep its tail

    cfg: TrainConfig = None
    q_tables: Dict[str, Dict[str, List[float]]] = None  # agent -> obs_key -> one Q-value per ACTIONS entry
    gmetrics: GlobalMetrics = None
    emetrics: EpisodeMetrics = None

    def __post_init__(self):
        # Replace None placeholders with fresh per-instance objects, so no
        # mutable default is shared between states.
        if self.event_log is None:
            self.event_log = []
        if self.trace_log is None:
            self.trace_log = []
        if self.cfg is None:
            self.cfg = TrainConfig()
        if self.q_tables is None:
            self.q_tables = {}
        if self.gmetrics is None:
            # Start exploration at the configured initial epsilon.
            self.gmetrics = GlobalMetrics(epsilon=self.cfg.epsilon)
        if self.emetrics is None:
            self.emetrics = EpisodeMetrics()
|
|
|
|
|
@dataclass
class Snapshot:
    """Point-in-time copy of a WorldState's episode fields, built by snapshot_of.

    ``agents`` holds ``asdict`` dumps of each Agent, and ``emetrics`` an ``asdict``
    dump of EpisodeMetrics. The log tails keep only the most recent lines.
    Learning state (cfg, q_tables, gmetrics) is not captured.
    """

    branch: str  # caller-supplied label, echoed in restore_into's log line
    step: int
    env_key: str
    grid: List[List[int]]  # row-by-row copy, indexed grid[y][x]
    agents: Dict[str, Dict[str, Any]]  # agent name -> asdict(Agent)
    done: bool
    outcome: str
    door_opened_global: bool
    base_progress: int
    base_target: int
    event_tail: List[str]  # last 25 entries of event_log
    trace_tail: List[str]  # last 40 entries of trace_log
    emetrics: Dict[str, Any]  # asdict(EpisodeMetrics)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def in_bounds(x: int, y: int) -> bool:
    """Return True when the cell (x, y) lies inside the GRID_W x GRID_H world."""
    inside_x = 0 <= x < GRID_W
    inside_y = 0 <= y < GRID_H
    return inside_x and inside_y
|
|
|
|
|
def is_blocking(tile: int, door_open: bool = False) -> bool:
    """Return True if an agent cannot step onto *tile*.

    Walls always block. Doors block unless *door_open* is set.
    """
    closed_door = tile == DOOR and not door_open
    return bool(tile == WALL or closed_door)
|
|
|
|
|
def manhattan_xy(ax: int, ay: int, bx: int, by: int) -> int:
    """Return the taxicab (L1) distance between cells (ax, ay) and (bx, by)."""
    return sum(abs(p - q) for p, q in ((ax, bx), (ay, by)))
|
|
|
|
|
def bresenham_los(grid: List[List[int]], x0: int, y0: int, x1: int, y1: int) -> bool:
    """Return True when no WALL tile lies strictly between (x0, y0) and (x1, y1).

    Walks the integer Bresenham line from the first cell to the second. The two
    endpoint cells are never tested, and only WALL occludes. DOOR tiles are
    transparent here whatever their state, whereas raycast_view and
    update_belief_for_agent stop at closed doors.
    NOTE(review): confirm that seeing agents through closed doors is intended.

    ``grid`` is indexed ``grid[y][x]``. Both endpoints are assumed to be in
    bounds, since no bounds check is performed.
    """
    dx = abs(x1 - x0)
    dy = abs(y1 - y0)
    # Unit step direction along each axis.
    sx = 1 if x0 < x1 else -1
    sy = 1 if y0 < y1 else -1
    # Error term tracking the offset from the ideal line.
    err = dx - dy
    x, y = x0, y0
    while True:
        # The observer's and target's own cells never block.
        if (x, y) != (x0, y0) and (x, y) != (x1, y1):
            if grid[y][x] == WALL:
                return False
        if x == x1 and y == y1:
            return True
        e2 = 2 * err
        # Advance in x, in y, or both (diagonal) depending on the accumulated error.
        if e2 > -dy:
            err -= dy
            x += sx
        if e2 < dx:
            err += dx
            y += sy
|
|
|
|
|
def within_fov(observer: Agent, tx: int, ty: int, fov_deg: float = FOV_DEG) -> bool:
    """Return True if cell (tx, ty) is inside *observer*'s horizontal view cone.

    The cone is centred on the observer's heading and spans *fov_deg* degrees.
    The observer's own cell always counts as inside.
    """
    dx, dy = tx - observer.x, ty - observer.y
    if not (dx or dy):
        return True
    bearing = math.degrees(math.atan2(dy, dx)) % 360
    # Signed angular offset from the heading, wrapped into [-180, 180).
    offset = (bearing - ORI_DEG[observer.ori] + 540) % 360 - 180
    return abs(offset) <= fov_deg / 2
|
|
|
|
|
def visible(state: WorldState, observer: Agent, target: Agent) -> bool:
    """Return True if *target* is inside *observer*'s view cone and has wall-free line of sight."""
    return within_fov(observer, target.x, target.y, FOV_DEG) and bresenham_los(
        state.grid, observer.x, observer.y, target.x, target.y
    )
|
|
|
|
|
def hash_sha256(txt: str) -> str:
    """Return the lowercase hex SHA-256 digest of *txt* encoded as UTF-8."""
    digest = hashlib.new("sha256")
    digest.update(bytes(txt, "utf-8"))
    return digest.hexdigest()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def init_beliefs(agent_names: List[str]) -> Dict[str, np.ndarray]:
    """Return one (GRID_H, GRID_W) int16 belief map per agent, every cell -1 (unknown)."""
    unknown = np.full((GRID_H, GRID_W), -1, dtype=np.int16)
    # Each agent gets its own copy so updates never leak between agents.
    return {name: unknown.copy() for name in agent_names}
|
|
|
|
|
def update_belief_for_agent(state: WorldState, belief: np.ndarray, agent: Agent) -> int:
    """Reveal the tiles *agent* can currently see into *belief*, in place.

    A fan of rays spans the field of view: 45 rays for agents whose name starts
    with "scout" (case-insensitive), otherwise 33. Each ray marches in 0.2-tile
    steps up to MAX_DEPTH. Every tile it passes is copied from the true grid,
    and a ray stops at the world edge, a WALL, or a closed DOOR.

    Returns how many belief cells changed from unknown (-1) to known.
    """
    unknown_before = int(np.sum(belief == -1))
    grid = state.grid
    belief[agent.y, agent.x] = grid[agent.y][agent.x]

    heading = math.radians(ORI_DEG[agent.ori])
    half_fov = math.radians(FOV_DEG / 2)
    n_rays = 45 if agent.name.lower().startswith("scout") else 33
    # Rays start from the centre of the agent's cell.
    origin_x, origin_y = agent.x + 0.5, agent.y + 0.5

    for k in range(n_rays):
        frac = k / (n_rays - 1)
        theta = heading + (frac * 2 - 1) * half_fov
        step_y, step_x = math.sin(theta), math.cos(theta)
        dist = 0.0
        while dist < MAX_DEPTH:
            dist += 0.2
            cx = int(origin_x + step_x * dist)
            cy = int(origin_y + step_y * dist)
            if not in_bounds(cx, cy):
                break
            cell = grid[cy][cx]
            belief[cy, cx] = cell
            if cell == WALL or (cell == DOOR and not state.door_opened_global):
                break

    unknown_after = int(np.sum(belief == -1))
    return max(0, unknown_before - unknown_after)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _fill_sky_and_floor(img: np.ndarray) -> None:
    """Paint sky above the horizon and a floor gradient below it, in place."""
    horizon = VIEW_H // 2
    img[:horizon] = SKY
    # t runs from 0 at the horizon row to ~1 at the bottom row. The bottom of the
    # screen shows the floor closest to the viewer, so it gets FLOOR_NEAR and the
    # horizon fades to FLOOR_FAR. This matches the walls, which also darken with
    # distance. The original code had the two colours swapped.
    t = (np.arange(horizon, VIEW_H) - horizon) / (horizon + 1e-6)
    rows = (1 - t)[:, None] * FLOOR_FAR + t[:, None] * FLOOR_NEAR
    img[horizon:] = rows.astype(np.uint8)[:, None, :]


def _cast_walls(state: WorldState, observer: Agent, img: np.ndarray) -> None:
    """Ray-march one column per RAY_W step and draw wall / closed-door slices in place."""
    facing = math.radians(ORI_DEG[observer.ori])
    half_fov = math.radians(FOV_DEG) / 2
    ox, oy = observer.x + 0.5, observer.y + 0.5
    col_px = VIEW_W / RAY_W

    for rx in range(RAY_W):
        cam_x = (2 * rx / (RAY_W - 1)) - 1
        ray_ang = facing + cam_x * half_fov
        sin_a, cos_a = math.sin(ray_ang), math.cos(ray_ang)

        depth = 0.0
        hit = None
        side = 0
        while depth < MAX_DEPTH:
            depth += 0.05
            tx = int(ox + cos_a * depth)
            ty = int(oy + sin_a * depth)
            if not in_bounds(tx, ty):
                break
            tile = state.grid[ty][tx]
            if tile == WALL:
                hit = "wall"
                # Crude shading: the ray's dominant axis stands in for the wall face.
                side = 1 if abs(cos_a) > abs(sin_a) else 0
                break
            if tile == DOOR and not state.door_opened_global:
                hit = "door"
                break

        if hit is None:
            continue

        # Project onto the view direction to remove fish-eye distortion.
        depth = max(depth * math.cos(ray_ang - facing), 0.001)
        proj_h = int((VIEW_H * 0.9) / depth)
        y0 = max(0, VIEW_H // 2 - proj_h // 2)
        # Slice ends are exclusive, so clamp to VIEW_H (not VIEW_H - 1) to reach the last row.
        y1 = min(VIEW_H, VIEW_H // 2 + proj_h // 2)

        if hit == "door":
            base = DOOR_COL
        else:
            base = WALL_BASE if side == 0 else WALL_SIDE
        dim = max(0.25, 1.0 - (depth / MAX_DEPTH))

        x0 = int(rx * col_px)
        x1 = int((rx + 1) * col_px)
        img[y0:y1, x0:x1] = (base * dim).astype(np.uint8)


def _draw_sprites(state: WorldState, observer: Agent, img: np.ndarray) -> None:
    """Draw every other living, visible agent as a flat billboard, farthest first."""
    facing = ORI_DEG[observer.ori]
    sprites = []
    for nm, other in state.agents.items():
        if nm == observer.name or other.hp <= 0:
            continue
        if not visible(state, observer, other):
            continue
        dx = other.x - observer.x
        dy = other.y - observer.y
        sprites.append((math.sqrt(dx * dx + dy * dy), nm, dx, dy))

    # Painter's algorithm: there is no depth buffer, so draw far sprites first
    # and let nearer ones overwrite them. Dict order could hide a near agent.
    sprites.sort(key=lambda s: s[0], reverse=True)

    y_mid = VIEW_H // 2
    for dist, nm, dx, dy in sprites:
        ang = math.degrees(math.atan2(dy, dx)) % 360
        diff = (ang - facing + 540) % 360 - 180
        sx = int((diff / (FOV_DEG / 2)) * (VIEW_W / 2) + (VIEW_W / 2))
        h = int((VIEW_H * 0.65) / max(dist, 0.75))
        w = max(10, h // 3)
        y0 = max(0, y_mid - h // 2)
        y1 = min(VIEW_H, y_mid + h // 2)
        x0 = max(0, sx - w // 2)
        x1 = min(VIEW_W, sx + w // 2)
        if x0 < x1 and y0 < y1:
            img[y0:y1, x0:x1] = np.array(AGENT_COLORS.get(nm, (255, 200, 120)), dtype=np.uint8)


def _draw_crosshair(img: np.ndarray) -> None:
    """Draw a small plus-shaped crosshair at the image centre, in place."""
    cx, cy = VIEW_W // 2, VIEW_H // 2
    color = np.array([120, 190, 255], dtype=np.uint8)
    img[cy - 1:cy + 2, cx - 10:cx + 10] = color
    img[cy - 10:cy + 10, cx - 1:cx + 2] = color


def raycast_view(state: WorldState, observer: Agent) -> np.ndarray:
    """Render *observer*'s first-person view as a (VIEW_H, VIEW_W, 3) uint8 RGB image.

    Draw order is background, then walls and closed doors (one ray per RAY_W
    column), then visible agents far-to-near, then an optional crosshair when
    ``state.overlay`` is set.
    """
    img = np.zeros((VIEW_H, VIEW_W, 3), dtype=np.uint8)
    _fill_sky_and_floor(img)
    _cast_walls(state, observer, img)
    _draw_sprites(state, observer, img)
    if state.overlay:
        _draw_crosshair(img)
    return img
|
|
|
|
|
# Fill colour per tile code in the top-down map; -1 marks an unexplored belief cell.
_TOPDOWN_COLORS = {
    -1: (18, 20, 32),
    EMPTY: (26, 30, 44),
    WALL: (190, 190, 210),
    FOOD: (255, 210, 120),
    NOISE: (255, 120, 220),
    DOOR: (140, 210, 255),
    TELE: (120, 190, 255),
    KEY: (255, 235, 160),
    EXIT: (120, 255, 220),
    ARTIFACT: (255, 170, 60),
    HAZARD: (255, 90, 90),
    WOOD: (170, 120, 60),
    ORE: (140, 140, 160),
    MEDKIT: (120, 255, 140),
    SWITCH: (200, 180, 255),
    BASE: (220, 220, 240),
}
_TOPDOWN_FALLBACK = (80, 80, 90)
_TOPDOWN_HEADER_H = 28


def render_topdown(grid: np.ndarray, agents: Dict[str, Agent], title: str, show_agents: bool = True) -> Image.Image:
    """Render *grid* as a tiled top-down map with a title bar.

    *grid* is a 2-D array of tile codes (``-1`` = unknown), indexed
    ``grid[y, x]``. Each cell becomes a TILE x TILE square below a
    _TOPDOWN_HEADER_H-pixel header holding *title*. When *show_agents* is set,
    living agents are drawn as coloured discs with a heading tick.
    """
    header = _TOPDOWN_HEADER_H
    rows, cols = grid.shape[0], grid.shape[1]
    w = cols * TILE
    h = rows * TILE
    im = Image.new("RGB", (w, h + header), (10, 12, 18))
    draw = ImageDraw.Draw(im)

    for y in range(rows):
        for x in range(cols):
            col = _TOPDOWN_COLORS.get(int(grid[y, x]), _TOPDOWN_FALLBACK)
            x0, y0 = x * TILE, y * TILE + header
            draw.rectangle([x0, y0, x0 + TILE - 1, y0 + TILE - 1], fill=col)

    # Grid lines. The closing right/bottom lines are clamped onto the last
    # pixel column/row; unclamped they would fall outside the image.
    line_col = (12, 14, 22)
    bottom = h + header - 1
    for x in range(cols + 1):
        xx = min(x * TILE, w - 1)
        draw.line([xx, header, xx, bottom], fill=line_col)
    for y in range(rows + 1):
        yy = min(y * TILE + header, bottom)
        draw.line([0, yy, w - 1, yy], fill=line_col)

    if show_agents:
        r = TILE // 3
        for nm, a in agents.items():
            if a.hp <= 0:
                continue
            cx = a.x * TILE + TILE // 2
            cy = a.y * TILE + header + TILE // 2
            draw.ellipse([cx - r, cy - r, cx + r, cy + r], fill=AGENT_COLORS.get(nm, (220, 220, 220)))
            dx, dy = DIRS[a.ori]
            draw.line([cx, cy, cx + dx * r, cy + dy * r], fill=(10, 10, 10), width=3)

    # Header covers rows 0..header-1 only. PIL rectangles are inclusive, so the
    # old [0, 0, w, 28] also painted over the first grid row, which starts at y=28.
    draw.rectangle([0, 0, w - 1, header - 1], fill=(14, 16, 26))
    draw.text((8, 6), title, fill=(230, 230, 240))
    return im
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def grid_with_border() -> List[List[int]]:
    """Return a fresh GRID_H x GRID_W grid (indexed [y][x]) of EMPTY cells ringed by WALL."""
    last_x, last_y = GRID_W - 1, GRID_H - 1
    return [
        [WALL if x in (0, last_x) or y in (0, last_y) else EMPTY for x in range(GRID_W)]
        for y in range(GRID_H)
    ]
|
|
|
|
|
def env_chase(seed: int) -> Tuple[List[List[int]], Dict[str, Agent]]:
    """Build the predator/prey arena and its starting agents.

    The layout is fixed: *seed* is accepted for parity with the other
    ENV_BUILDERS entries but is not used. A horizontal wall on row 7 with a door
    at x=10 splits the map.
    """
    g = grid_with_border()
    g[7][4:17] = [WALL] * (17 - 4)

    # (x, y, tile) placements, applied after the wall row.
    for x, y, tile in (
        (10, 7, DOOR),
        (4, 3, FOOD),
        (15, 11, FOOD),
        (14, 4, NOISE),
        (5, 12, NOISE),
        (18, 2, TELE),
        (2, 13, TELE),
    ):
        g[y][x] = tile

    roster = (
        ("Predator", 2, 2, 0, "A", "q"),
        ("Prey", 18, 12, 2, "B", "q"),
        ("Scout", 10, 3, 1, "A", "heuristic"),
    )
    agents = {
        nm: Agent(nm, x, y, ori, hp=10, energy=100, team=team, brain=brain)
        for nm, x, y, ori, team, brain in roster
    }
    return g, agents
|
|
|
|
|
def env_vault(seed: int) -> Tuple[List[List[int]], Dict[str, Agent]]:
    """Build the vault heist map: team A must carry the artifact to the exit past a guardian.

    The layout is fixed: *seed* is accepted for parity with the other
    ENV_BUILDERS entries but is not used. Walls on rows 5 and 9 (x 3..17) form
    three bands, with one door in each wall.
    """
    g = grid_with_border()
    for row in (5, 9):
        g[row][3:18] = [WALL] * (18 - 3)

    # (x, y, tile) placements, applied after the wall rows.
    for x, y, tile in (
        (10, 5, DOOR),
        (12, 9, DOOR),
        (2, 2, KEY),
        (18, 12, EXIT),
        (2, 12, ARTIFACT),
        (18, 2, TELE),
        (2, 13, TELE),
        (10, 7, SWITCH),
        (15, 3, HAZARD),
        (6, 11, MEDKIT),
        (12, 2, FOOD),
    ):
        g[y][x] = tile

    roster = (
        ("Alpha", 2, 12, 0, "A"),
        ("Bravo", 3, 12, 0, "A"),
        ("Guardian", 18, 2, 2, "B"),
    )
    agents = {
        nm: Agent(nm, x, y, ori, hp=10, energy=100, team=team, brain="q")
        for nm, x, y, ori, team in roster
    }
    return g, agents
|
|
|
|
|
def env_civ(seed: int) -> Tuple[List[List[int]], Dict[str, Agent]]:
    """Build the base-building map: builders gather wood/ore for the base while a raider hunts them.

    The layout is fixed: *seed* is accepted for parity with the other
    ENV_BUILDERS entries but is not used. A vertical wall at x=9 (y 3..11) with
    a door at y=7 splits the map.
    """
    g = grid_with_border()
    for y in range(3, 12):
        g[y][9] = WALL

    # (x, y, tile) placements, applied after the wall column.
    for x, y, tile in (
        (9, 7, DOOR),
        (3, 2, WOOD),
        (3, 3, WOOD),
        (3, 4, WOOD),
        (16, 12, ORE),
        (16, 11, ORE),
        (16, 10, ORE),
        (4, 6, FOOD),
        (15, 8, FOOD),
        (10, 13, BASE),
        (15, 4, HAZARD),
        (4, 10, HAZARD),
        (18, 2, TELE),
        (2, 13, TELE),
        (2, 2, KEY),
        (6, 12, SWITCH),
    ):
        g[y][x] = tile

    roster = (
        ("BuilderA", 3, 12, 0, "A"),
        ("BuilderB", 4, 12, 0, "A"),
        ("Raider", 18, 2, 2, "B"),
    )
    agents = {
        nm: Agent(nm, x, y, ori, hp=10, energy=100, team=team, brain="q")
        for nm, x, y, ori, team in roster
    }
    return g, agents
|
|
|
|
|
# Environment key -> builder returning (grid, agents); used by init_state.
ENV_BUILDERS = dict(chase=env_chase, vault=env_vault, civ=env_civ)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def local_tile_ahead(state: WorldState, a: Agent) -> int:
    """Return the tile code directly in front of *a*; off-grid counts as WALL."""
    step_x, step_y = DIRS[a.ori]
    tx, ty = a.x + step_x, a.y + step_y
    return state.grid[ty][tx] if in_bounds(tx, ty) else WALL
|
|
|
|
|
def nearest_enemy_vec(state: WorldState, a: Agent) -> Tuple[int, int, int]:
    """Describe the closest living agent on another team, relative to *a*.

    Returns ``(manhattan_distance, dx, dy)`` with dx/dy clamped to [-6, 6].
    Returns ``(99, 0, 0)`` when no living enemy exists.
    """
    foes = [o for o in state.agents.values() if o.hp > 0 and o.team != a.team]
    if not foes:
        return (99, 0, 0)
    # min() keeps the first of equally distant foes, like a strict "<" scan.
    foe = min(foes, key=lambda o: manhattan_xy(a.x, a.y, o.x, o.y))
    dist = manhattan_xy(a.x, a.y, foe.x, foe.y)
    rel_x = max(-6, min(6, foe.x - a.x))
    rel_y = max(-6, min(6, foe.y - a.y))
    return (dist, int(rel_x), int(rel_y))
|
|
|
|
|
def obs_key(state: WorldState, who: str) -> str:
    """Build the discrete, "|"-separated observation string used as a Q-table key.

    Fields, in order:
    - env key and agent name
    - pose "x,y,ori"
    - nearest-enemy vector "e<d>:<dx>,<dy>"
    - tile ahead "t<code>"
    - hp
    - capped inventory bucket "k<=2 a<=1 w<=3 o<=3"
    - door flag "D0/D1"
    - base progress "bp<n>"
    """
    me = state.agents[who]
    dist, rel_x, rel_y = nearest_enemy_vec(state, me)
    inv = me.inventory
    bucket = "k{}a{}w{}o{}".format(
        min(inv.get("key", 0), 2),
        min(inv.get("artifact", 0), 1),
        min(inv.get("wood", 0), 3),
        min(inv.get("ore", 0), 3),
    )
    parts = [
        state.env_key,
        who,
        f"{me.x},{me.y},{me.ori}",
        f"e{dist}:{rel_x},{rel_y}",
        f"t{local_tile_ahead(state, me)}",
        f"hp{me.hp}",
        bucket,
        f"D{1 if state.door_opened_global else 0}",
        f"bp{state.base_progress}",
    ]
    return "|".join(parts)
|
|
|
|
|
def q_get(q: Dict[str, List[float]], key: str) -> List[float]:
    """Return the Q-value row for *key* (one float per ACTIONS entry).

    A zero-filled row is inserted first if the key is new.
    """
    try:
        return q[key]
    except KeyError:
        row = q[key] = [0.0] * len(ACTIONS)
        return row
|
|
|
|
|
def epsilon_greedy(qvals: List[float], eps: float, r: np.random.Generator) -> int:
    """Pick an action index epsilon-greedily.

    With probability *eps* a uniformly random index is returned. Otherwise the
    arg-max of *qvals* is returned, with ties broken uniformly at random using
    *r*. ``np.argmax`` alone always chose index 0 among ties, so a fresh
    all-zero Q row made greedy agents repeat ACTIONS[0] ("L") and spin in place.
    """
    if r.random() < eps:
        return int(r.integers(0, len(qvals)))
    q = np.asarray(qvals, dtype=np.float64)
    best = np.flatnonzero(q == q.max())
    if best.size == 0:
        # Only possible with NaNs: fall back to argmax's first-NaN behaviour.
        return int(np.argmax(q))
    if best.size == 1:
        return int(best[0])
    return int(r.choice(best))
|
|
|
|
|
def q_update(q: Dict[str, List[float]], key: str, a_idx: int, reward: float, next_key: str,
             alpha: float, gamma: float, done: bool = False) -> Tuple[float, float, float]:
    """Apply one tabular Q-learning update to ``q[key][a_idx]`` in place.

    target = reward + gamma * max_a' Q(next_key, a')   (non-terminal)
    target = reward                                     (when *done* is True)

    Bootstrapping from the successor state after the episode has ended leaks
    value across episode boundaries, so terminal transitions should pass
    ``done=True``. The default keeps the original always-bootstrap behaviour.
    Missing rows are created as zeros via q_get.

    Returns ``(old_value, target, new_value)``.
    """
    qv = q_get(q, key)
    old = qv[a_idx]
    if done:
        target = float(reward)
    else:
        nq = q_get(q, next_key)
        target = reward + gamma * float(np.max(nq))
    new = old + alpha * (target - old)
    qv[a_idx] = new
    return old, target, new
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def heuristic_action(state: WorldState, who: str) -> str:
    """Scripted policy for agent *who*. Returns one of the ACTIONS letters as a plain str.

    Priority order:
    1. Interact ("I") when standing on a tile that interaction actually changes.
    2. If the nearest living enemy is within 6 tiles (Manhattan) and visible,
       turn toward it (more than 10 degrees off-axis) or step forward.
    3. Otherwise take a seeded random action, biased toward "F".
    """
    a = state.agents[who]
    # Built-in hash() of a str is salted per process (PYTHONHASHSEED), which made
    # this RNG stream, and so whole runs, irreproducible. A digest is stable.
    salt = int(hashlib.sha256(who.encode("utf-8")).hexdigest(), 16) % 1000
    r = rng_for(state.seed, state.step, stream=900 + salt)

    t_here = state.grid[a.y][a.x]
    if t_here in (FOOD, KEY, ARTIFACT, WOOD, ORE, MEDKIT, SWITCH):
        return "I"
    # BASE only does something when there is wood/ore to deposit. Interacting on
    # EXIT never changes state (check_done detects the win by position alone).
    # Returning "I" unconditionally on either tile froze the agent there forever.
    if t_here == BASE and (a.inventory.get("wood", 0) > 0 or a.inventory.get("ore", 0) > 0):
        return "I"

    best = None
    best_d = 999
    for other in state.agents.values():
        if other.hp <= 0 or other.team == a.team:
            continue
        d = manhattan_xy(a.x, a.y, other.x, other.y)
        if d < best_d:
            best_d = d
            best = other

    if best is not None and best_d <= 6 and visible(state, a, best):
        dx = best.x - a.x
        dy = best.y - a.y
        ang = math.degrees(math.atan2(dy, dx)) % 360
        diff = (ang - ORI_DEG[a.ori] + 540) % 360 - 180
        if diff < -10:
            return "L"
        if diff > 10:
            return "R"
        return "F"

    # r.choice returns numpy.str_; normalise to a plain str.
    return str(r.choice(["F", "F", "L", "R", "I"]))
|
|
|
|
|
def random_action(state: WorldState, who: str) -> str:
    """Return a uniformly random ACTIONS letter, seeded by (seed, step, agent name).

    The per-agent stream offset comes from a SHA-256 digest of *who*, not
    built-in hash(). hash() of a str is salted per process (PYTHONHASHSEED), so
    it would make runs irreproducible.
    """
    salt = int(hashlib.sha256(who.encode("utf-8")).hexdigest(), 16) % 1000
    r = rng_for(state.seed, state.step, stream=700 + salt)
    # r.choice returns numpy.str_; normalise to a plain str.
    return str(r.choice(ACTIONS))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def turn_left(a: Agent) -> None:
    """Rotate *a* one step backward through DIRS (0 wraps to 3), in place."""
    a.ori = (a.ori + 3) % 4
|
|
|
|
|
def turn_right(a: Agent) -> None:
    """Rotate *a* one step forward through DIRS (3 wraps to 0), in place."""
    next_ori = a.ori + 1
    a.ori = next_ori % 4
|
|
|
|
|
def move_forward(state: WorldState, a: Agent) -> str:
    """Step *a* one tile along its heading and return a result message.

    Returns "blocked: bounds" or "blocked: wall/door" without moving when the
    target is off-grid or blocking. On landing on a TELE tile with at least one
    other pad, the agent jumps to the next pad in sorted (x, y) order (wrapping),
    an event is logged, and "moved: teleported" is returned. Otherwise the
    result is "moved".
    """
    step_x, step_y = DIRS[a.ori]
    nx, ny = a.x + step_x, a.y + step_y
    if not in_bounds(nx, ny):
        return "blocked: bounds"
    if is_blocking(state.grid[ny][nx], door_open=state.door_opened_global):
        return "blocked: wall/door"

    a.x, a.y = nx, ny
    if state.grid[ny][nx] != TELE:
        return "moved"

    pads = sorted(
        (x, y) for y in range(GRID_H) for x in range(GRID_W) if state.grid[y][x] == TELE
    )
    if len(pads) < 2:
        return "moved"
    a.x, a.y = pads[(pads.index((nx, ny)) + 1) % len(pads)]
    state.event_log.append(f"t={state.step}: {a.name} teleported.")
    return "moved: teleported"
|
|
|
|
|
def try_interact(state: WorldState, a: Agent) -> str:
    """Interact with the tile under *a* and return a result message.

    - Pickups (key, artifact, wood, ore) go into the inventory.
    - A switch opens all doors.
    - Food restores energy (capped at 200); a medkit restores hp (capped at 10).
    - All of these consume the tile, turning it EMPTY.
    - BASE takes up to 2 wood and 2 ore into ``state.base_progress``.
    - EXIT only reports "at_exit".
    - Anything else is "interact: none".
    """
    here = state.grid[a.y][a.x]
    inv = a.inventory

    def consume() -> None:
        state.grid[a.y][a.x] = EMPTY

    pickups = {KEY: "key", ARTIFACT: "artifact", WOOD: "wood", ORE: "ore"}
    if here in pickups:
        item = pickups[here]
        inv[item] = inv.get(item, 0) + 1
        consume()
        return f"picked: {item}"

    if here == SWITCH:
        state.door_opened_global = True
        consume()
        inv["switch"] = inv.get("switch", 0) + 1
        return "switch: opened all doors"

    if here == FOOD:
        a.energy = min(200, a.energy + 35)
        consume()
        return "ate: food"

    if here == MEDKIT:
        a.hp = min(10, a.hp + 3)
        consume()
        return "used: medkit"

    if here == BASE:
        wood, ore = inv.get("wood", 0), inv.get("ore", 0)
        moved_wood, moved_ore = min(wood, 2), min(ore, 2)
        deposited = moved_wood + moved_ore
        if deposited <= 0:
            return "base: nothing to deposit"
        inv["wood"] = max(0, wood - moved_wood)
        inv["ore"] = max(0, ore - moved_ore)
        state.base_progress += deposited
        return f"deposited: +{deposited} base_progress"

    if here == EXIT:
        return "at_exit"

    return "interact: none"
|
|
|
|
|
def apply_action(state: WorldState, who: str, action: str) -> str:
    """Execute one action letter for agent *who* and return a short result message.

    Dead agents do nothing ("dead"), and unknown letters yield "noop".
    """
    actor = state.agents[who]
    if actor.hp <= 0:
        return "dead"
    if action == "F":
        return move_forward(state, actor)
    if action == "I":
        return try_interact(state, actor)
    if action in ("L", "R"):
        rotate, message = (turn_left, "turned left") if action == "L" else (turn_right, "turned right")
        rotate(actor)
        return message
    return "noop"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def resolve_hazards(state: WorldState, a: Agent) -> Tuple[bool, str]:
    """Apply hazard damage to a living *a* standing on a HAZARD tile.

    Returns ``(True, "hazard:-hp")`` after removing 1 hp, else ``(False, "")``.
    """
    on_hazard = a.hp > 0 and state.grid[a.y][a.x] == HAZARD
    if not on_hazard:
        return (False, "")
    a.hp -= 1
    return (True, "hazard:-hp")
|
|
|
|
|
def resolve_tags(state: WorldState) -> List[str]:
    """Damage living agents that share a cell with an agent from another team.

    Every agent on a mixed-team cell loses 1 hp, and one message per such cell
    is returned. Same-team stacking is harmless.
    """
    by_cell: Dict[Tuple[int, int], List[str]] = {}
    for name, agent in state.agents.items():
        if agent.hp > 0:
            by_cell.setdefault((agent.x, agent.y), []).append(name)

    messages = []
    for (x, y), names in by_cell.items():
        if len(names) < 2:
            continue
        if len({state.agents[n].team for n in names}) < 2:
            continue
        for n in names:
            state.agents[n].hp -= 1
        messages.append(f"t={state.step}: collision/tag at ({x},{y}) {names} (-hp all)")
    return messages
|
|
|
|
|
def _end_episode(state: "WorldState", outcome: str, message: str) -> None:
    """Mark *state* finished with *outcome* and log ``t=<step>: <message>``."""
    state.done = True
    state.outcome = outcome
    state.event_log.append(f"t={state.step}: {message}")


def check_done(state: "WorldState") -> None:
    """Set ``state.done`` / ``state.outcome`` if the current env's end condition holds.

    Outcomes are "A_win", "B_win" or "draw", and every terminal branch logs an
    event line.

    chase:
    - both dead -> draw
    - Predator on Prey -> A_win
    - Prey dead -> A_win
    - step >= 300 -> B_win

    vault:
    - a living A carrying the artifact stands on EXIT -> A_win
    - Alpha and Bravo both dead -> B_win
    - step >= 350 -> draw

    civ:
    - base_progress >= base_target -> A_win
    - builders dead -> B_win
    - step >= 350 -> draw
    """
    if state.env_key == "chase":
        pred = state.agents["Predator"]
        prey = state.agents["Prey"]
        if pred.hp <= 0 and prey.hp <= 0:
            _end_episode(state, "draw", "BOTH ELIMINATED (draw).")
            return
        if pred.hp > 0 and prey.hp > 0 and pred.x == prey.x and pred.y == prey.y:
            _end_episode(state, "A_win", "CAUGHT (Predator wins).")
            return
        # Other team-A agents (e.g. Scout) can tag the prey to 0 hp via
        # resolve_tags. Previously a dead prey with a living predator meant the
        # timeout below never fired and the episode could not end.
        if prey.hp <= 0:
            _end_episode(state, "A_win", "PREY ELIMINATED (Predator wins).")
            return
        if state.step >= 300:
            _end_episode(state, "B_win", "ESCAPED (Prey survives).")
        return

    if state.env_key == "vault":
        for nm in ("Alpha", "Bravo"):
            a = state.agents[nm]
            if a.hp > 0 and a.inventory.get("artifact", 0) > 0 and state.grid[a.y][a.x] == EXIT:
                _end_episode(state, "A_win", "VAULT CLEARED (Team A wins).")
                return
        if not any(state.agents[n].hp > 0 for n in ("Alpha", "Bravo")):
            _end_episode(state, "B_win", "TEAM A ELIMINATED (Guardian wins).")
            return
        # Same step cap as civ. Without it a stalled vault episode never terminated.
        if state.step >= 350:
            _end_episode(state, "draw", "TIMEOUT (draw).")
        return

    if state.env_key == "civ":
        if state.base_progress >= state.base_target:
            _end_episode(state, "A_win", "BASE COMPLETE (Builders win).")
            return
        if not any(state.agents[n].hp > 0 for n in ("BuilderA", "BuilderB")):
            _end_episode(state, "B_win", "BUILDERS ELIMINATED (Raider wins).")
            return
        if state.step >= 350:
            _end_episode(state, "draw", "TIMEOUT (draw).")
        return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _reward_chase(total: float, prev: WorldState, now: WorldState, who: str, outcome_msg: str) -> float:
    """Add chase terms to *total*.

    Predator: closing distance and the catch. Prey: food, escape, being caught.
    """
    cfg = now.cfg
    pred = now.agents["Predator"]
    prey = now.agents["Prey"]
    if who == "Predator":
        old_pred, old_prey = prev.agents["Predator"], prev.agents["Prey"]
        gap_before = manhattan_xy(old_pred.x, old_pred.y, old_prey.x, old_prey.y)
        gap_after = manhattan_xy(pred.x, pred.y, prey.x, prey.y)
        # Positive when the predator closed in this step.
        total += cfg.chase_close_coeff * float(gap_before - gap_after)
        if now.done and now.outcome == "A_win":
            total += cfg.chase_catch_reward
    if who == "Prey":
        if outcome_msg.startswith("ate: food"):
            total += cfg.food_reward
        if now.done and now.outcome == "B_win":
            total += cfg.chase_escaped_reward
        if now.done and now.outcome == "A_win":
            total += cfg.chase_caught_penalty
    return total


def _reward_vault(total: float, prev: WorldState, now: WorldState, who: str, outcome_msg: str) -> float:
    """Add vault terms to *total*: pickups, switch use, and team terminal rewards."""
    cfg = now.cfg
    if outcome_msg.startswith("picked: artifact"):
        total += cfg.artifact_pick_reward
    if outcome_msg.startswith("picked: key"):
        total += cfg.key_reward
    if outcome_msg.startswith("switch:"):
        total += cfg.switch_reward
    if now.done:
        if now.outcome == "A_win" and now.agents[who].team == "A":
            total += cfg.exit_win_reward
        if now.outcome == "B_win" and now.agents[who].team == "B":
            total += cfg.guardian_tag_reward
        if now.outcome == "B_win" and now.agents[who].team == "A":
            total += cfg.tagged_penalty
    return total


def _reward_civ(total: float, prev: WorldState, now: WorldState, who: str, outcome_msg: str) -> float:
    """Add civ terms to *total*: resource pickups, deposits, and team terminal rewards."""
    cfg = now.cfg
    if outcome_msg.startswith(("picked: wood", "picked: ore")):
        total += cfg.resource_pick_reward
    if outcome_msg.startswith("deposited:"):
        total += cfg.deposit_reward
    if now.done:
        if now.outcome == "A_win" and now.agents[who].team == "A":
            total += cfg.base_progress_win_reward
        if now.outcome == "B_win" and now.agents[who].team == "B":
            total += cfg.raider_elim_reward
        if now.outcome == "B_win" and now.agents[who].team == "A":
            total += cfg.builder_elim_penalty
    return total


def reward_for(prev: WorldState, now: WorldState, who: str, outcome_msg: str, took_damage: bool) -> float:
    """Return the shaped reward for *who* on the transition *prev* -> *now*.

    Starts from the per-step penalty. Adds shared terms (movement, damage,
    healing), then the environment-specific terms selected by ``now.env_key``.
    Terms are accumulated in a fixed order so float results are reproducible.
    """
    cfg = now.cfg
    total = cfg.step_penalty
    if outcome_msg.startswith("moved"):
        total += cfg.explore_reward
    if took_damage:
        total += cfg.damage_penalty
    if outcome_msg.startswith("used: medkit"):
        total += cfg.heal_reward

    shaping = {"chase": _reward_chase, "vault": _reward_vault, "civ": _reward_civ}.get(now.env_key)
    if shaping is not None:
        total = shaping(total, prev, now, who, outcome_msg)
    return float(total)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def choose_action(state: WorldState, who: str, stream: int) -> Tuple[str, str, Optional[Tuple[str, int]]]:
    """Pick the next action letter for agent *who*.

    Returns ``(action, explanation, q_ref)``. *q_ref* is ``(obs_key, action_index)``
    when the choice came from the Q-table, else ``None``. "random" and
    "heuristic" brains use their scripted policies. Other brains use
    epsilon-greedy Q-learning when ``cfg.use_q`` is set and the heuristic
    otherwise. *stream* selects the rng_for stream used for exploration.
    """
    brain = state.agents[who].brain
    if brain == "random":
        return random_action(state, who), "random", None
    if brain == "heuristic":
        return heuristic_action(state, who), "heuristic", None
    if not state.cfg.use_q:
        return heuristic_action(state, who), "heuristic(fallback)", None

    rng = rng_for(state.seed, state.step, stream=stream)
    key = obs_key(state, who)
    table = state.q_tables.setdefault(who, {})
    qvals = q_get(table, key)
    eps = state.gmetrics.epsilon
    idx = epsilon_greedy(qvals, eps, rng)
    explanation = f"Q eps={eps:.3f} q={np.round(qvals,3).tolist()}"
    return ACTIONS[idx], explanation, (key, idx)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def init_state(seed: int, env_key: str) -> WorldState:
    """Create a fresh WorldState for *env_key* at step 0.

    The first agent of the built roster becomes both ``controlled`` and ``pov``.
    Learning state (cfg, q_tables, metrics) gets fresh defaults from
    WorldState.__post_init__.
    """
    grid, agents = ENV_BUILDERS[env_key](seed)
    first = list(agents)[0]
    state = WorldState(
        seed=seed,
        step=0,
        env_key=env_key,
        grid=grid,
        agents=agents,
        controlled=first,
        pov=first,
        overlay=False,
        done=False,
        outcome="ongoing",
        door_opened_global=False,
        base_progress=0,
        base_target=10,
    )
    state.event_log = [f"Initialized env={env_key} seed={seed}."]
    return state
|
|
|
|
|
def reset_episode_keep_learning(state: WorldState, seed: Optional[int] = None) -> WorldState:
    """Start a new episode of the same environment, keeping what has been learned.

    *seed* defaults to the current state's seed. The config, Q-tables and global
    metrics (including epsilon) are handed to the new state by reference, so
    training continues. Everything else is rebuilt by init_state.
    """
    chosen = state.seed if seed is None else seed
    fresh = init_state(int(chosen), state.env_key)
    fresh.cfg, fresh.q_tables, fresh.gmetrics = state.cfg, state.q_tables, state.gmetrics
    return fresh
|
|
|
|
|
def wipe_all(seed: int, env_key: str) -> WorldState:
    """Create a fresh state with default config, empty Q-tables and reset global metrics."""
    state = init_state(seed, env_key)
    cfg = TrainConfig()
    state.cfg = cfg
    state.gmetrics = GlobalMetrics(epsilon=cfg.epsilon)
    state.q_tables = {}
    return state
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def snapshot_of(state: WorldState, branch: str) -> Snapshot:
    """Capture *state*'s episode fields into a Snapshot labelled *branch*.

    The grid is copied row by row, and agents and episode metrics are dumped
    with ``asdict``. Only the last 25 event and last 40 trace lines are kept.
    """
    grid_copy = [list(row) for row in state.grid]
    agent_dumps = {name: asdict(agent) for name, agent in state.agents.items()}
    return Snapshot(
        branch=branch,
        step=state.step,
        env_key=state.env_key,
        grid=grid_copy,
        agents=agent_dumps,
        done=state.done,
        outcome=state.outcome,
        door_opened_global=state.door_opened_global,
        base_progress=state.base_progress,
        base_target=state.base_target,
        event_tail=state.event_log[-25:],
        trace_tail=state.trace_log[-40:],
        emetrics=asdict(state.emetrics),
    )
|
|
|
|
|
def restore_into(state: WorldState, snap: Snapshot) -> WorldState:
    """Overwrite *state*'s episode fields with the contents of *snap*, in place.

    Agents and episode metrics are rebuilt from deep copies, so playing on from
    the restored state never mutates the stored snapshot. ``Agent(**d)`` alone
    would share each inventory dict with *snap*, so later pickups would corrupt
    it and a second jump would restore the wrong inventory. Episode metrics are
    restored too, keeping them consistent with the restored ``step``.

    The event log is kept, with a "Jumped to snapshot" line appended. The trace
    log, learning state (cfg, q_tables, gmetrics) and view fields are untouched.

    Returns the same *state* object.
    """
    import copy

    state.step = snap.step
    state.env_key = snap.env_key
    state.grid = [list(row) for row in snap.grid]
    state.agents = {k: Agent(**copy.deepcopy(d)) for k, d in snap.agents.items()}
    state.done = snap.done
    state.outcome = snap.outcome
    state.door_opened_global = snap.door_opened_global
    state.base_progress = snap.base_progress
    state.base_target = snap.base_target
    state.emetrics = EpisodeMetrics(**copy.deepcopy(snap.emetrics))
    state.event_log.append(f"Jumped to snapshot t={snap.step} (branch={snap.branch}).")
    return state
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def metrics_dashboard_image(state: WorldState) -> Image.Image:
    """Render a small RGB summary chart of ``state.gmetrics``.

    Draws the rolling team-A win rate as a flat line over the episode count,
    plus a text box of counters. The figure is always closed, even if rendering
    raises; previously an exception leaked the pyplot figure.
    """
    gm = state.gmetrics
    fig = plt.figure(figsize=(7.0, 2.2), dpi=120)
    try:
        ax = fig.add_subplot(111)

        x1 = max(1, gm.episodes)
        ax.plot([0, x1], [gm.rolling_winrate_A, gm.rolling_winrate_A])
        ax.set_title("Global Metrics Snapshot")
        ax.set_xlabel("Episodes")
        ax.set_ylabel("Rolling winrate Team A")
        ax.set_ylim(-0.05, 1.05)
        ax.grid(True)

        txt = (
            f"env={state.env_key} | eps={gm.epsilon:.3f} | episodes={gm.episodes}\n"
            f"A_wins={gm.wins_teamA} B_wins={gm.wins_teamB} draws={gm.draws} | avg_steps~{gm.avg_steps:.1f}\n"
            f"last_outcome={gm.last_outcome} last_steps={gm.last_steps}"
        )
        ax.text(0.01, 0.05, txt, transform=ax.transAxes, fontsize=8, va="bottom")

        fig.tight_layout()
        canvas = FigureCanvas(fig)
        canvas.draw()
        buf = np.asarray(canvas.buffer_rgba())
        # An (H, W, 4) uint8 array is inferred as RGBA. The explicit mode=
        # argument is deprecated in recent Pillow. convert() makes an
        # independent copy before the canvas is released.
        return Image.fromarray(buf).convert("RGB")
    finally:
        plt.close(fig)
|
|
|
|
|
def action_entropy(counts: Dict[str, int]) -> float:
    """Shannon entropy, in bits, of the empirical action distribution in *counts*.

    Returns 0.0 when the counts sum to zero or less. Each probability is
    clipped to at least 1e-12 before taking log2, so zero counts contribute a
    negligible positive amount instead of producing NaN.
    """
    n = sum(counts.values())
    if n <= 0:
        return 0.0
    probs = np.fromiter((v / n for v in counts.values()), dtype=np.float64, count=len(counts))
    probs = np.clip(probs, 1e-12, 1.0)
    return float(-(probs * np.log2(probs)).sum())
|
|
|
|
|
def agent_scoreboard(state: WorldState) -> str:
    """Format a fixed-width, pipe-separated table with one row per agent.

    Columns: name, team, hp, episode return (2 dp), episode step count (same
    for every row), action entropy in bits (2 dp), tiles discovered, number of
    Q-table entries, and the inventory as sorted-key JSON. A dashed separator
    line follows the header row.
    """
    em = state.emetrics
    header = ["agent", "team", "hp", "return", "steps", "entropy", "tiles_disc", "q_states", "inventory"]

    table = [header] + [
        [
            name,
            agent.team,
            agent.hp,
            f"{em.returns.get(name, 0.0):.2f}",
            em.steps,
            f"{action_entropy(em.action_counts.get(name, {})):.2f}",
            em.tiles_discovered.get(name, 0),
            len(state.q_tables.get(name, {})),
            json.dumps(agent.inventory, sort_keys=True),
        ]
        for name, agent in state.agents.items()
    ]

    cells = [[str(value) for value in row] for row in table]
    widths = [max(len(row[col]) for row in cells) for col in range(len(header))]
    rendered = [" | ".join(text.ljust(w) for text, w in zip(row, widths)) for row in cells]
    # Separator goes directly under the header row.
    rendered.insert(1, "-+-".join("-" * w for w in widths))
    return "\n".join(rendered)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def clone_shallow(state: WorldState) -> WorldState:
    """Return a partial copy of *state*, used as the "before" picture of a step.

    Copied: the grid rows, every Agent (rebuilt via asdict), and the event and
    trace log lists. Shared by reference with *state*: cfg, q_tables, gmetrics
    and emetrics. In-place changes to those after this call are visible
    through both objects.
    """
    grid_rows = [row[:] for row in state.grid]
    agent_copies = {name: Agent(**asdict(agent)) for name, agent in state.agents.items()}

    return WorldState(
        # plain values
        seed=state.seed,
        step=state.step,
        env_key=state.env_key,
        controlled=state.controlled,
        pov=state.pov,
        overlay=state.overlay,
        done=state.done,
        outcome=state.outcome,
        door_opened_global=state.door_opened_global,
        base_progress=state.base_progress,
        base_target=state.base_target,
        # independent copies
        grid=grid_rows,
        agents=agent_copies,
        event_log=list(state.event_log),
        trace_log=list(state.trace_log),
        # shared with the original
        cfg=state.cfg,
        q_tables=state.q_tables,
        gmetrics=state.gmetrics,
        emetrics=state.emetrics,
    )
|
|
|
|
|
def update_action_counts(state: WorldState, who: str, act: str):
    """Add one to the episode tally of how often agent *who* chose action *act*."""
    tally = state.emetrics.action_counts.setdefault(who, {})
    tally[act] = tally.get(act, 0) + 1
|
|
|
|
|
def _agent_rng_stream(who: str) -> int:
    """Return the per-agent ``stream`` value passed to choose_action, in [200, 1200)."""
    # Builtin hash(str) is salted per interpreter run (PYTHONHASHSEED), so the
    # old `hash(who) % 1000` gave different streams after every restart and
    # broke seeded reproducibility. A truncated SHA-256 digest is stable.
    digest = hashlib.sha256(who.encode("utf-8")).digest()
    return 200 + int.from_bytes(digest[:8], "big") % 1000


def tick(state: WorldState, beliefs: Dict[str, np.ndarray], manual_action: Optional[str] = None) -> None:
    """Advance *state* by one simulation step, mutating it in place.

    Args:
        state: World to advance. Nothing happens if ``state.done`` is already set.
        beliefs: Per-agent belief arrays; refreshed for every agent with hp > 0.
        manual_action: If given, used as the action of ``state.controlled``
            instead of calling choose_action. That agent then receives no
            Q-update this step.

    Order of work: pick actions -> apply each action and its hazards (agents in
    dict order) -> resolve tags -> update beliefs -> check for episode end ->
    rewards and Q-learning updates -> append a trace line -> increment step.
    """
    if state.done:
        return

    # Pre-step copy handed to reward_for() so it can compare before/after.
    prev = clone_shallow(state)

    chosen: Dict[str, str] = {}
    reasons: Dict[str, str] = {}
    qinfo: Dict[str, Optional[Tuple[str, int]]] = {}

    if manual_action is not None:
        chosen[state.controlled] = manual_action
        reasons[state.controlled] = "manual"
        qinfo[state.controlled] = None  # no Q-update for a human-chosen action

    # 1) Action selection for every agent not already driven manually.
    order = list(state.agents.keys())
    for who in order:
        if who in chosen:
            continue
        act, reason, qi = choose_action(state, who, stream=_agent_rng_stream(who))
        chosen[who] = act
        reasons[who] = reason
        qinfo[who] = qi

    # 2) Apply actions sequentially, resolving hazards after each one.
    outcomes: Dict[str, str] = {}
    took_damage: Dict[str, bool] = {nm: False for nm in order}
    for who in order:
        outcomes[who] = apply_action(state, who, chosen[who])
        dmg, msg = resolve_hazards(state, state.agents[who])
        took_damage[who] = dmg
        if msg:
            state.event_log.append(f"t={state.step}: {who} {msg}")
        update_action_counts(state, who, chosen[who])

    for m in resolve_tags(state):
        state.event_log.append(m)

    # 3) Belief refresh for living agents; accumulate newly discovered tiles.
    for nm, a in state.agents.items():
        if a.hp <= 0:
            continue
        disc = update_belief_for_agent(state, beliefs[nm], a)
        state.emetrics.tiles_discovered[nm] = state.emetrics.tiles_discovered.get(nm, 0) + disc

    check_done(state)

    # 4) Rewards and tabular Q-learning updates.
    q_lines = []
    for who in order:
        state.emetrics.returns.setdefault(who, 0.0)
        r = reward_for(prev, state, who, outcomes[who], took_damage[who])
        state.emetrics.returns[who] += r

        if qinfo.get(who) is not None:
            key, a_idx = qinfo[who]
            next_key = obs_key(state, who)
            qtab = state.q_tables.setdefault(who, {})
            old, tgt, new = q_update(qtab, key, a_idx, r, next_key, state.cfg.alpha, state.cfg.gamma)
            q_lines.append(f"{who}: old={old:.3f} tgt={tgt:.3f} new={new:.3f} (a={ACTIONS[a_idx]})")

    # 5) One trace line per step, capped at TRACE_MAX entries.
    trace = f"t={state.step} env={state.env_key} done={state.done} outcome={state.outcome}"
    for who in order:
        a = state.agents[who]
        trace += f" | {who}:{chosen[who]} ({outcomes[who]}) hp={a.hp} [{reasons[who]}]"
    if q_lines:
        trace += " | Q: " + " ; ".join(q_lines)

    state.trace_log.append(trace)
    if len(state.trace_log) > TRACE_MAX:
        state.trace_log = state.trace_log[-TRACE_MAX:]

    state.step += 1
    state.emetrics.steps = state.step
|
|
|
|
|
def run_episode(state: WorldState, beliefs: Dict[str, np.ndarray], max_steps: int) -> Tuple[str, int]:
    """Tick autonomously until the episode ends or ``state.step`` reaches *max_steps*.

    Returns:
        ``(state.outcome, state.step)`` after the last tick.
    """
    while not state.done:
        if state.step >= max_steps:
            break
        tick(state, beliefs, manual_action=None)
    return state.outcome, state.step
|
|
|
|
|
def update_global_metrics_after_episode(state: WorldState, outcome: str, steps: int):
    """Fold one finished episode into ``state.gmetrics``.

    "A_win" counts as a Team-A win (score 1.0) and "B_win" as a Team-B win
    (score 0.0). Any other outcome is a draw (score 0.5). The rolling winrate
    and average episode length are exponential moving averages with weight
    0.10 on the newest episode; the average is seeded with the first episode's
    length. Epsilon is decayed multiplicatively and floored at
    ``cfg.epsilon_min``.
    """
    gm = state.gmetrics
    gm.episodes += 1
    gm.last_outcome = outcome
    gm.last_steps = steps

    if outcome == "A_win":
        gm.wins_teamA += 1
        score = 1.0
    elif outcome == "B_win":
        gm.wins_teamB += 1
        score = 0.0
    else:
        gm.draws += 1
        score = 0.5
    gm.rolling_winrate_A = 0.90 * gm.rolling_winrate_A + 0.10 * score

    if gm.avg_steps > 0:
        gm.avg_steps = 0.90 * gm.avg_steps + 0.10 * steps
    else:
        gm.avg_steps = float(steps)

    gm.epsilon = max(state.cfg.epsilon_min, gm.epsilon * state.cfg.epsilon_decay)
|
|
|
|
|
def train(state: WorldState, episodes: int, max_steps: int) -> WorldState:
    """Run *episodes* autonomous episodes, accumulating learning across them.

    Each episode is reset via reset_episode_keep_learning() with a seed derived
    from the base seed and the global episode index. It is played for at most
    *max_steps* ticks on fresh beliefs and then folded into the global metrics.
    Afterwards the state is reset once more on the base seed and a training
    summary is logged.

    Returns:
        The (possibly new) state object produced by the final reset.
    """
    # Read the global episode counter once. update_global_metrics_after_episode()
    # increments gmetrics.episodes every iteration, so adding `ep` to the live
    # counter advanced the seed index by 2 per episode. It also made later
    # training calls replay seeds that earlier calls had already used.
    first_index = state.gmetrics.episodes
    for ep in range(episodes):
        ep_seed = (state.seed * 1_000_003 + (first_index + ep) * 97_531) & 0xFFFFFFFF
        state = reset_episode_keep_learning(state, seed=int(ep_seed))
        beliefs = init_beliefs(list(state.agents.keys()))
        outcome, steps = run_episode(state, beliefs, max_steps=max_steps)
        update_global_metrics_after_episode(state, outcome, steps)

    state = reset_episode_keep_learning(state, seed=state.seed)
    # Logged after the final reset so the summary is kept even if the reset
    # starts a fresh event log.
    state.event_log.append(
        f"Training: +{episodes} eps | eps={state.gmetrics.epsilon:.3f} | "
        f"A={state.gmetrics.wins_teamA} B={state.gmetrics.wins_teamB} D={state.gmetrics.draws}"
    )
    return state
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def export_run(state: WorldState, branches: Dict[str, List[Snapshot]], active_branch: str, rewind_idx: int) -> str:
    """Serialize the session to indented JSON followed by a proof-hash footer.

    The output is the payload JSON, a blank line, then
    ``{"proof_sha256": hash_sha256(payload_json)}``. Payload key order is
    fixed, so identical sessions always produce identical text and hash.
    """
    serialized_branches = {name: [asdict(snap) for snap in snaps] for name, snaps in branches.items()}

    payload = {
        "seed": state.seed,
        "env_key": state.env_key,
        "controlled": state.controlled,
        "pov": state.pov,
        "overlay": state.overlay,
        "cfg": asdict(state.cfg),
        "gmetrics": asdict(state.gmetrics),
        "q_tables": state.q_tables,
        "branches": serialized_branches,
        "active_branch": active_branch,
        "rewind_idx": int(rewind_idx),
        "grid": state.grid,
        "door_opened_global": state.door_opened_global,
        "base_progress": state.base_progress,
        "base_target": state.base_target,
    }

    body = json.dumps(payload, indent=2)
    footer = json.dumps({"proof_sha256": hash_sha256(body)}, indent=2)
    return f"{body}\n\n{footer}"
|
|
|
|
|
def import_run(txt: str) -> Tuple[WorldState, Dict[str, List[Snapshot]], str, int, Dict[str, np.ndarray]]:
    """Rebuild a session from text produced by export_run().

    The text is a JSON payload, optionally followed by a JSON footer
    ``{"proof_sha256": ...}``. A fresh state is built with init_state() from
    the payload's seed and env_key. The saved view settings, grid, flags, cfg,
    gmetrics and Q-tables are then applied. If the active branch has
    snapshots, its last one is restored. The footer hash is checked against
    the payload text. A missing or mismatching hash only adds an event-log
    note; it does not block the import.

    Returns:
        ``(state, branches, active_branch, rewind_idx, fresh_beliefs)``.

    Raises:
        json.JSONDecodeError: if the payload is not valid JSON.
    """
    body = txt.strip()
    # raw_decode parses exactly one JSON document and reports where it ended.
    # The payload is therefore found regardless of the whitespace (e.g. CRLF
    # line endings) between it and the footer, where split("\n\n") failed.
    data, end = json.JSONDecoder().raw_decode(body)
    payload_text = body[:end]

    claimed_proof = None
    footer = body[end:].strip()
    if footer:
        try:
            claimed_proof = json.loads(footer).get("proof_sha256")
        except (ValueError, AttributeError):
            claimed_proof = None  # unreadable footer: treat as "no proof"

    st = init_state(int(data.get("seed", 1337)), data.get("env_key", "chase"))
    st.controlled = data.get("controlled", st.controlled)
    st.pov = data.get("pov", st.pov)
    st.overlay = bool(data.get("overlay", False))
    st.grid = data.get("grid", st.grid)
    st.door_opened_global = bool(data.get("door_opened_global", False))
    st.base_progress = int(data.get("base_progress", 0))
    st.base_target = int(data.get("base_target", 10))

    st.cfg = TrainConfig(**data.get("cfg", asdict(st.cfg)))
    st.gmetrics = GlobalMetrics(**data.get("gmetrics", asdict(st.gmetrics)))
    st.q_tables = data.get("q_tables", {})

    branches: Dict[str, List[Snapshot]] = {}
    for bname, snaps in data.get("branches", {}).items():
        branches[bname] = [Snapshot(**s) for s in snaps]

    active = data.get("active_branch", "main")
    r_idx = int(data.get("rewind_idx", 0))

    if active in branches and branches[active]:
        st = restore_into(st, branches[active][-1])
        st.event_log.append("Imported run (restored last snapshot).")
    else:
        st.event_log.append("Imported run (no snapshots).")

    # Integrity note; uses the same hash_sha256 that export_run used.
    if claimed_proof is None:
        st.event_log.append("Import note: no proof_sha256 footer found; payload not verified.")
    elif claimed_proof != hash_sha256(payload_text):
        st.event_log.append("Import WARNING: proof_sha256 does not match payload (edited or reformatted).")

    beliefs = init_beliefs(list(st.agents.keys()))
    return st, branches, active, r_idx, beliefs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_views(state: WorldState, beliefs: Dict[str, np.ndarray]) -> Tuple[np.ndarray, Image.Image, Image.Image, Image.Image, Image.Image, str, str, str, str]:
    """Render every UI panel for *state*.

    Calls update_belief_for_agent for each living agent first; its return value
    is ignored here. It presumably updates ``beliefs[name]`` in place (TODO
    confirm). Then renders the POV, the truth map, the belief maps of the
    controlled agent and of the first other agent, and the metrics dashboard,
    and formats the status/event/trace/scoreboard text. Events and trace show
    the last 18 entries.

    Returns:
        ``(pov, truth_img, ctrl_belief_img, other_belief_img, dashboard_img,
        status_text, events_text, trace_text, scoreboard_text)``
    """
    for name, agent in state.agents.items():
        if agent.hp <= 0:
            continue
        update_belief_for_agent(state, beliefs[name], agent)

    pov_frame = raycast_view(state, state.agents[state.pov])
    truth_img = render_topdown(
        np.array(state.grid, dtype=np.int16),
        state.agents,
        f"Truth Map — env={state.env_key} t={state.step} seed={state.seed}",
        True,
    )

    ctrl = state.controlled
    # First agent that is not the controlled one; fall back to ctrl if alone.
    other = next((name for name in state.agents.keys() if name != ctrl), ctrl)
    ctrl_img = render_topdown(beliefs[ctrl], state.agents, f"{ctrl} Belief", True)
    other_img = render_topdown(beliefs[other], state.agents, f"{other} Belief", True)

    dashboard = metrics_dashboard_image(state)

    g = state.gmetrics
    status_text = (
        f"env={state.env_key} | seed={state.seed} | Controlled={state.controlled} | POV={state.pov} | done={state.done} outcome={state.outcome}\n"
        f"Episode steps={state.step} | base_progress={state.base_progress}/{state.base_target} | doors_open={state.door_opened_global}\n"
        f"Global: episodes={g.episodes} | A={g.wins_teamA} B={g.wins_teamB} D={g.draws} "
        f"| winrateA~{g.rolling_winrate_A:.2f} | eps={g.epsilon:.3f}"
    )

    return (
        pov_frame,
        truth_img,
        ctrl_img,
        other_img,
        dashboard,
        status_text,
        "\n".join(state.event_log[-18:]),
        "\n".join(state.trace_log[-18:]),
        agent_scoreboard(state),
    )
|
|
|
|
|
def grid_click_to_tile(evt: gr.SelectData, selected_tile: int, state: WorldState) -> WorldState:
    """Paint *selected_tile* into the grid cell under a click on the truth-map image.

    The click's pixel y is shifted up by 28 px before dividing by TILE. Presumably
    that is the height of the title strip drawn by render_topdown (TODO confirm).
    Clicks above that strip, outside the grid, or on the outer border ring are
    ignored. A successful edit is logged to the event log.

    Returns:
        The same *state* object, modified in place when a tile was painted.
    """
    x_px, y_px = evt.index
    y_px -= 28
    if y_px < 0:
        return state

    gx, gy = int(x_px // TILE), int(y_px // TILE)
    if not in_bounds(gx, gy) or gx in (0, GRID_W - 1) or gy in (0, GRID_H - 1):
        return state

    state.grid[gy][gx] = selected_tile
    state.event_log.append(f"t={state.step}: Tile ({gx},{gy}) -> {TILE_NAMES.get(selected_tile)}")
    return state
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# App title: used as the browser-tab title (gr.Blocks(title=...)) and as the page heading markdown.
TITLE = "ZEN AgentLab — Agent POV + Autoplay Multi-Agent Sims"
|
|
|
|
|
with gr.Blocks(title=TITLE) as demo:
    gr.Markdown(
        f"## {TITLE}\n"
        "**Press Start Autoplay** to watch the sim unfold live. Interject anytime with manual actions or edits.\n"
        "Use **Cinematic Run** for an instant full-episode spectacle. No background timers beyond the UI autoplay."
    )

    # ------------------------------------------------------------------
    # Per-session values held in gr.State.
    # ------------------------------------------------------------------
    st0 = init_state(1337, "chase")
    st = gr.State(st0)
    branches = gr.State({"main": [snapshot_of(st0, "main")]})  # branch name -> list of snapshots
    active_branch = gr.State("main")
    rewind_idx = gr.State(0)
    beliefs = gr.State(init_beliefs(list(st0.agents.keys())))
    autoplay_on = gr.State(False)

    # ------------------------------------------------------------------
    # Display panels
    # ------------------------------------------------------------------
    with gr.Row():
        pov_img = gr.Image(label="POV (Pseudo-3D)", type="numpy", width=VIEW_W, height=VIEW_H)
        with gr.Column():
            status = gr.Textbox(label="Status", lines=3)
            scoreboard = gr.Textbox(label="Agent Scoreboard", lines=8)

    with gr.Row():
        truth = gr.Image(label="Truth Map (click to edit tiles)", type="pil")
        belief_a = gr.Image(label="Belief (Controlled)", type="pil")
        belief_b = gr.Image(label="Belief (Other)", type="pil")

    with gr.Row():
        dash = gr.Image(label="Metrics Dashboard", type="pil")

    with gr.Row():
        events = gr.Textbox(label="Event Log", lines=10)
        trace = gr.Textbox(label="Step Trace", lines=10)

    # ------------------------------------------------------------------
    # Controls
    # ------------------------------------------------------------------
    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("### Quick Start (Examples)")

            # gr.Examples needs the component objects it fills in, but the
            # picker row is laid out below the examples. The pickers are
            # therefore created unrendered here and placed with .render()
            # further down. The examples used to be built with inputs=[],
            # so clicking one did nothing.
            env_pick = gr.Radio(
                choices=[("Chase (Predator vs Prey)", "chase"),
                         ("CoopVault (team vs guardian)", "vault"),
                         ("MiniCiv (build + raid)", "civ")],
                value="chase",
                label="Environment",
                render=False,
            )
            seed_box = gr.Number(value=1337, precision=0, label="Seed", render=False)

            gr.Examples(
                examples=[
                    ["chase", 1337],
                    ["vault", 2024],
                    ["civ", 777],
                ],
                inputs=[env_pick, seed_box],
                label="",
            )

            gr.Markdown("Pick an environment + seed below, then click **Apply**.")

            with gr.Row():
                env_pick.render()
                seed_box.render()

            with gr.Row():
                btn_apply_env_seed = gr.Button("Apply (Env + Seed)")
                btn_reset_ep = gr.Button("Reset Episode (keep learning)")

            gr.Markdown("### Autoplay + Spectacle")
            with gr.Row():
                autoplay_speed = gr.Slider(0.05, 1.0, value=0.20, step=0.05, label="Autoplay step interval (seconds)")
            with gr.Row():
                btn_autoplay_start = gr.Button("▶ Start Autoplay")
                btn_autoplay_stop = gr.Button("⏸ Stop Autoplay")
            with gr.Row():
                cinematic_steps = gr.Number(value=350, precision=0, label="Cinematic max steps")
                btn_cinematic = gr.Button("🎬 Cinematic Run (Full Episode)")

            gr.Markdown("### Manual Controls (Interject Anytime)")
            with gr.Row():
                btn_L = gr.Button("L")
                btn_F = gr.Button("F")
                btn_R = gr.Button("R")
                btn_I = gr.Button("I (Interact)")
            with gr.Row():
                btn_tick = gr.Button("Tick")
                run_steps = gr.Number(value=25, label="Run N steps", precision=0)
                btn_run = gr.Button("Run")

            with gr.Row():
                btn_toggle_control = gr.Button("Toggle Controlled")
                btn_toggle_pov = gr.Button("Toggle POV")
                overlay = gr.Checkbox(False, label="Overlay reticle")

            tile_pick = gr.Radio(
                choices=[(TILE_NAMES[k], k) for k in [EMPTY, WALL, FOOD, NOISE, DOOR, TELE, KEY, EXIT, ARTIFACT, HAZARD, WOOD, ORE, MEDKIT, SWITCH, BASE]],
                value=WALL,
                label="Paint tile type"
            )

        with gr.Column(scale=3):
            gr.Markdown("### Training Controls (Tabular Q-learning)")
            use_q = gr.Checkbox(True, label="Use Q-learning (agents with brain='q')")
            alpha = gr.Slider(0.01, 0.5, value=0.15, step=0.01, label="alpha")
            gamma = gr.Slider(0.5, 0.99, value=0.95, step=0.01, label="gamma")
            eps = gr.Slider(0.0, 0.5, value=0.10, step=0.01, label="epsilon")
            eps_decay = gr.Slider(0.90, 0.999, value=0.995, step=0.001, label="epsilon decay")
            eps_min = gr.Slider(0.0, 0.2, value=0.02, step=0.01, label="epsilon min")

            episodes = gr.Number(value=50, label="Train episodes", precision=0)
            max_steps = gr.Number(value=260, label="Max steps/episode", precision=0)
            btn_train = gr.Button("Train")

            btn_reset_all = gr.Button("Reset ALL (wipe Q + metrics)")

    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("### Timeline + Branching")
            rewind = gr.Slider(0, 0, value=0, step=1, label="Rewind index (active branch)")
            btn_jump = gr.Button("Jump to index")
            new_branch_name = gr.Textbox(value="fork1", label="New branch name")
            btn_fork = gr.Button("Fork from current rewind")

        with gr.Column(scale=2):
            branch_pick = gr.Dropdown(choices=["main"], value="main", label="Active branch")
            btn_set_branch = gr.Button("Set Active Branch")

        with gr.Column(scale=3):
            export_box = gr.Textbox(label="Export JSON (+ proof hash)", lines=8)
            btn_export = gr.Button("Export")
            import_box = gr.Textbox(label="Import JSON", lines=8)
            btn_import = gr.Button("Import")

    # Drives autoplay; started/stopped by the autoplay buttons.
    timer = gr.Timer(value=0.20, active=False)

    # ------------------------------------------------------------------
    # Handlers
    # ------------------------------------------------------------------
    def refresh(state: WorldState, branches_d: Dict[str, List[Snapshot]], active: str, bel: Dict[str, np.ndarray], r: int):
        """Render every panel and resync the timeline/branch widgets.

        Returns 13 values matching the first 13 entries of ``common_outputs``:
        the 9 panel values (scoreboard is placed before events/trace), a rewind
        slider update, the rewind index clamped to the active branch, and two
        identical branch-dropdown updates (branch_pick is listed twice there).
        """
        snaps = branches_d.get(active, [])
        r_max = max(0, len(snaps) - 1)
        r = max(0, min(int(r), r_max))
        pov, tr, ba, bb, dimg, stxt, etxt, ttxt, sb = build_views(state, bel)
        branch_choices = sorted(list(branches_d.keys()))
        return (
            pov, tr, ba, bb, dimg, stxt, sb, etxt, ttxt,
            gr.update(maximum=r_max, value=r), r,
            gr.update(choices=branch_choices, value=active),
            gr.update(choices=branch_choices, value=active),
        )

    def respond(state, branches_d, active, bel, r):
        """Build the full ``common_outputs`` tuple: rendered panels followed by the 5 session values."""
        return refresh(state, branches_d, active, bel, r) + (state, branches_d, active, bel, r)

    def push_hist(state: WorldState, branches_d: Dict[str, List[Snapshot]], active: str) -> Dict[str, List[Snapshot]]:
        """Append a snapshot of *state* to the active branch, dropping the oldest beyond MAX_HISTORY."""
        branches_d.setdefault(active, [])
        branches_d[active].append(snapshot_of(state, active))
        if len(branches_d[active]) > MAX_HISTORY:
            branches_d[active].pop(0)
        return branches_d

    def set_cfg(state: WorldState, use_q_v: bool, a: float, g: float, e: float, ed: float, emin: float) -> WorldState:
        """Copy the training-control values into state.cfg (epsilon goes to gmetrics)."""
        state.cfg.use_q = bool(use_q_v)
        state.cfg.alpha = float(a)
        state.cfg.gamma = float(g)
        state.gmetrics.epsilon = float(e)
        state.cfg.epsilon_decay = float(ed)
        state.cfg.epsilon_min = float(emin)
        return state

    def do_manual(state, branches_d, active, bel, r, act):
        """Tick once with *act* as the controlled agent's action and record the step."""
        tick(state, bel, manual_action=act)
        branches_d = push_hist(state, branches_d, active)
        r = len(branches_d[active]) - 1
        return respond(state, branches_d, active, bel, r)

    def _manual_handler(act):
        """Return a 5-input click handler that performs manual action *act*."""
        def handler(s, b, a, bel, r):
            return do_manual(s, b, a, bel, r, act)
        return handler

    def do_tick(state, branches_d, active, bel, r):
        """Tick once with every agent choosing automatically and record the step."""
        tick(state, bel, manual_action=None)
        branches_d = push_hist(state, branches_d, active)
        r = len(branches_d[active]) - 1
        return respond(state, branches_d, active, bel, r)

    def do_run(state, branches_d, active, bel, r, n):
        """Tick up to *n* times (at least once), stopping early when the episode ends."""
        n = max(1, int(n))
        for _ in range(n):
            if state.done:
                break
            tick(state, bel, manual_action=None)
            branches_d = push_hist(state, branches_d, active)
        r = len(branches_d[active]) - 1
        return respond(state, branches_d, active, bel, r)

    def toggle_control(state, branches_d, active, bel, r):
        """Cycle the manually controlled agent to the next one in agent order."""
        order = list(state.agents.keys())
        i = order.index(state.controlled)
        state.controlled = order[(i + 1) % len(order)]
        state.event_log.append(f"Controlled -> {state.controlled}")
        branches_d = push_hist(state, branches_d, active)
        r = len(branches_d[active]) - 1
        return respond(state, branches_d, active, bel, r)

    def toggle_pov(state, branches_d, active, bel, r):
        """Cycle the POV camera to the next agent in agent order."""
        order = list(state.agents.keys())
        i = order.index(state.pov)
        state.pov = order[(i + 1) % len(order)]
        state.event_log.append(f"POV -> {state.pov}")
        branches_d = push_hist(state, branches_d, active)
        r = len(branches_d[active]) - 1
        return respond(state, branches_d, active, bel, r)

    def set_overlay(state, branches_d, active, bel, r, ov):
        """Toggle the POV reticle overlay (not recorded in history)."""
        state.overlay = bool(ov)
        return respond(state, branches_d, active, bel, r)

    def click_truth(tile, state, branches_d, active, bel, r, evt: gr.SelectData):
        """Paint the selected tile type where the truth map was clicked and record it."""
        state = grid_click_to_tile(evt, int(tile), state)
        branches_d = push_hist(state, branches_d, active)
        r = len(branches_d[active]) - 1
        return respond(state, branches_d, active, bel, r)

    def jump(state, branches_d, active, bel, r, idx):
        """Restore snapshot *idx* (clamped) of the active branch."""
        snaps = branches_d.get(active, [])
        if not snaps:
            return respond(state, branches_d, active, bel, r)
        idx = max(0, min(int(idx), len(snaps) - 1))
        state = restore_into(state, snaps[idx])
        # Beliefs are not stored in snapshots. Rebuild them so the belief maps
        # do not show knowledge gathered after the restored time step
        # (same as set_active_branch).
        bel = init_beliefs(list(state.agents.keys()))
        r = idx
        return respond(state, branches_d, active, bel, r)

    def fork_branch(state, branches_d, active, bel, r, new_name):
        """Copy the active branch up to the rewind index into a new branch and switch to it.

        Spaces in the name become underscores and an empty name becomes "fork".
        If the name is already taken, a numeric suffix (_2, _3, ...) is added,
        so an existing branch (including "main") is never overwritten.
        """
        base_name = ((new_name or "").strip() or "fork").replace(" ", "_")
        new_name = base_name
        suffix = 2
        while new_name in branches_d:
            new_name = f"{base_name}_{suffix}"
            suffix += 1

        snaps = branches_d.get(active, [])
        if not snaps:
            branches_d[new_name] = [snapshot_of(state, new_name)]
        else:
            idx = max(0, min(int(r), len(snaps) - 1))
            branches_d[new_name] = [Snapshot(**asdict(s)) for s in snaps[:idx + 1]]
        state = restore_into(state, branches_d[new_name][-1])
        # Beliefs are not stored in snapshots; rebuild them for the restored state.
        bel = init_beliefs(list(state.agents.keys()))
        active = new_name
        state.event_log.append(f"Forked branch -> {new_name}")
        branches_d = push_hist(state, branches_d, active)
        r = len(branches_d[active]) - 1
        return respond(state, branches_d, active, bel, r)

    def set_active_branch(state, branches_d, active, bel, r, br):
        """Switch to branch *br* (created from the current state if unknown) and restore its latest snapshot."""
        br = br or "main"
        if br not in branches_d:
            branches_d[br] = [snapshot_of(state, br)]
        active = br
        if branches_d[active]:
            state = restore_into(state, branches_d[active][-1])
        bel = init_beliefs(list(state.agents.keys()))
        r = len(branches_d[active]) - 1
        return respond(state, branches_d, active, bel, r)

    def apply_env_seed(state, branches_d, active, bel, r, env_key, seed_val):
        """Start a new world for *env_key*/*seed_val*, keeping cfg, Q-tables and global metrics."""
        env_key = env_key or "chase"
        seed_val = int(seed_val) if seed_val is not None else state.seed

        old_cfg = state.cfg
        old_q = state.q_tables
        old_gm = state.gmetrics

        state = init_state(seed_val, env_key)
        state.cfg = old_cfg
        state.q_tables = old_q
        state.gmetrics = old_gm

        bel = init_beliefs(list(state.agents.keys()))
        active = "main"
        branches_d = {"main": [snapshot_of(state, "main")]}
        r = 0
        return respond(state, branches_d, active, bel, r)

    def reset_ep(state, branches_d, active, bel, r):
        """Reset the episode on the current seed (learning kept) and restart the active branch's history."""
        state = reset_episode_keep_learning(state, seed=state.seed)
        bel = init_beliefs(list(state.agents.keys()))
        branches_d = {active: [snapshot_of(state, active)]}
        r = 0
        return respond(state, branches_d, active, bel, r)

    def reset_all(state, branches_d, active, bel, r, env_key, seed_val):
        """Wipe everything (Q-tables and metrics included) via wipe_all and start over on "main"."""
        env_key = env_key or state.env_key
        seed_val = int(seed_val) if seed_val is not None else state.seed
        state = wipe_all(seed=seed_val, env_key=env_key)
        bel = init_beliefs(list(state.agents.keys()))
        active = "main"
        branches_d = {"main": [snapshot_of(state, "main")]}
        r = 0
        return respond(state, branches_d, active, bel, r)

    def do_train(state, branches_d, active, bel, r,
                 use_q_v, a, g, e, ed, emin,
                 eps_count, max_s):
        """Apply training settings, run train(), and restart history on "main"."""
        state = set_cfg(state, use_q_v, a, g, e, ed, emin)
        state = train(state, episodes=max(1, int(eps_count)), max_steps=max(10, int(max_s)))
        bel = init_beliefs(list(state.agents.keys()))
        branches_d = {"main": [snapshot_of(state, "main")]}
        active = "main"
        r = 0
        return respond(state, branches_d, active, bel, r)

    def cinematic_run(state, branches_d, active, bel, r, max_s):
        """Reset the episode and play it to completion (at most *max_s* steps, minimum 10) in one call."""
        max_s = max(10, int(max_s))
        state = reset_episode_keep_learning(state, seed=state.seed)
        bel = init_beliefs(list(state.agents.keys()))
        while state.step < max_s and not state.done:
            tick(state, bel, manual_action=None)
        state.event_log.append(f"Cinematic finished: outcome={state.outcome} steps={state.step}")
        branches_d = push_hist(state, branches_d, active)
        r = len(branches_d[active]) - 1
        return respond(state, branches_d, active, bel, r)

    def export_fn(state, branches_d, active, r):
        """Return the export_run() text for the export textbox."""
        return export_run(state, branches_d, active, int(r))

    def import_fn(txt):
        """Load a session from exported text, making sure the active branch has at least one snapshot."""
        state, branches_d, active, r, bel = import_run(txt)
        branches_d.setdefault(active, [])
        if not branches_d[active]:
            branches_d[active].append(snapshot_of(state, active))
        return respond(state, branches_d, active, bel, r)

    def autoplay_start(state, branches_d, active, bel, r, interval_s):
        """Activate the timer at *interval_s* seconds and mark autoplay as on."""
        interval_s = float(interval_s)
        return (
            gr.update(value=interval_s, active=True),
            True,
            state, branches_d, active, bel, r
        )

    def autoplay_stop(state, branches_d, active, bel, r):
        """Deactivate the timer and mark autoplay as off."""
        return (
            gr.update(active=False),
            False,
            state, branches_d, active, bel, r
        )

    def autoplay_tick(state, branches_d, active, bel, r, is_on: bool):
        """Timer callback: advance one step while autoplay is on; stop the timer once the episode ends.

        Returns the ``common_outputs`` values plus (autoplay_on, timer update).
        """
        if not is_on:
            return respond(state, branches_d, active, bel, r) + (is_on, gr.update())

        if not state.done:
            tick(state, bel, manual_action=None)
            branches_d = push_hist(state, branches_d, active)
            r = len(branches_d[active]) - 1

        if state.done:
            return respond(state, branches_d, active, bel, r) + (False, gr.update(active=False))

        return respond(state, branches_d, active, bel, r) + (True, gr.update())

    # ------------------------------------------------------------------
    # Wiring
    # ------------------------------------------------------------------
    # Order must match respond(): 13 refresh() values, then the 5 session
    # values. rewind_idx and branch_pick appear twice, mirroring refresh().
    common_outputs = [
        pov_img, truth, belief_a, belief_b, dash, status, scoreboard, events, trace,
        rewind, rewind_idx, branch_pick, branch_pick,
        st, branches, active_branch, beliefs, rewind_idx,
    ]
    session_inputs = [st, branches, active_branch, beliefs, rewind_idx]

    btn_L.click(_manual_handler("L"), inputs=session_inputs, outputs=common_outputs, queue=True)
    btn_F.click(_manual_handler("F"), inputs=session_inputs, outputs=common_outputs, queue=True)
    btn_R.click(_manual_handler("R"), inputs=session_inputs, outputs=common_outputs, queue=True)
    btn_I.click(_manual_handler("I"), inputs=session_inputs, outputs=common_outputs, queue=True)

    btn_tick.click(do_tick, inputs=session_inputs, outputs=common_outputs, queue=True)
    btn_run.click(do_run, inputs=session_inputs + [run_steps], outputs=common_outputs, queue=True)

    btn_toggle_control.click(toggle_control, inputs=session_inputs, outputs=common_outputs, queue=True)
    btn_toggle_pov.click(toggle_pov, inputs=session_inputs, outputs=common_outputs, queue=True)
    overlay.change(set_overlay, inputs=session_inputs + [overlay], outputs=common_outputs, queue=True)

    truth.select(click_truth, inputs=[tile_pick] + session_inputs, outputs=common_outputs, queue=True)

    btn_jump.click(jump, inputs=session_inputs + [rewind], outputs=common_outputs, queue=True)
    btn_fork.click(fork_branch, inputs=session_inputs + [new_branch_name], outputs=common_outputs, queue=True)
    btn_set_branch.click(set_active_branch, inputs=session_inputs + [branch_pick], outputs=common_outputs, queue=True)

    btn_apply_env_seed.click(apply_env_seed, inputs=session_inputs + [env_pick, seed_box], outputs=common_outputs, queue=True)
    btn_reset_ep.click(reset_ep, inputs=session_inputs, outputs=common_outputs, queue=True)
    btn_reset_all.click(reset_all, inputs=session_inputs + [env_pick, seed_box], outputs=common_outputs, queue=True)

    btn_train.click(do_train,
                    inputs=session_inputs + [use_q, alpha, gamma, eps, eps_decay, eps_min, episodes, max_steps],
                    outputs=common_outputs, queue=True)

    btn_cinematic.click(cinematic_run, inputs=session_inputs + [cinematic_steps], outputs=common_outputs, queue=True)

    btn_export.click(export_fn, inputs=[st, branches, active_branch, rewind_idx], outputs=[export_box], queue=True)
    btn_import.click(import_fn, inputs=[import_box], outputs=common_outputs, queue=True)

    btn_autoplay_start.click(
        autoplay_start,
        inputs=session_inputs + [autoplay_speed],
        outputs=[timer, autoplay_on, st, branches, active_branch, beliefs, rewind_idx],
        queue=True
    )

    btn_autoplay_stop.click(
        autoplay_stop,
        inputs=session_inputs,
        outputs=[timer, autoplay_on, st, branches, active_branch, beliefs, rewind_idx],
        queue=True
    )

    timer.tick(
        autoplay_tick,
        inputs=session_inputs + [autoplay_on],
        outputs=common_outputs + [autoplay_on, timer],
        queue=True
    )

    demo.load(
        refresh,
        inputs=session_inputs,
        outputs=[
            pov_img, truth, belief_a, belief_b, dash, status, scoreboard, events, trace,
            rewind, rewind_idx, branch_pick, branch_pick
        ],
        queue=True
    )
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Guarded so that importing this module only builds `demo` and does not
    # start a server; running the file as a script launches the app as before.
    demo.queue().launch(ssr_mode=False)
|
|
|