|
|
import json |
|
|
import math |
|
|
from dataclasses import dataclass, asdict |
|
|
from typing import Dict, List, Tuple, Optional |
|
|
|
|
|
import numpy as np |
|
|
from PIL import Image, ImageDraw |
|
|
|
|
|
import gradio as gr |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- World geometry ---------------------------------------------------------
GRID_W, GRID_H = 21, 15  # grid size in tiles (columns, rows)
TILE = 22  # pixel size of one tile in the top-down map renders

# --- First-person (ray-cast) view -------------------------------------------
VIEW_W, VIEW_H = 640, 360  # POV image size in pixels
RAY_W = 320  # rays cast per frame; each ray paints VIEW_W / RAY_W pixel columns
FOV_DEG = 78  # horizontal field of view in degrees
MAX_DEPTH = 20  # longest distance (in tiles) a ray is marched

# Orientation index -> unit grid step (dx, dy) and heading in degrees.
# y grows downward, so index 1 (0, 1) is "south" on screen.
DIRS = [(1, 0), (0, 1), (-1, 0), (0, -1)]
ORI_DEG = [90 * quarter for quarter in range(4)]

# --- Tile ids and their display names ---------------------------------------
EMPTY, WALL, FOOD, NOISE, DOOR, TELE = range(6)

TILE_NAMES = dict(
    zip(
        (EMPTY, WALL, FOOD, NOISE, DOOR, TELE),
        ("Empty", "Wall", "Food", "Noise", "Door", "Teleporter"),
    )
)

# --- Colours -----------------------------------------------------------------
AGENT_COLORS = {
    "Predator": (255, 120, 90),
    "Prey": (120, 255, 160),
    "Scout": (120, 190, 255),
}


def _rgb(r: int, g: int, b: int) -> np.ndarray:
    """Pack an RGB triple as a uint8 numpy array."""
    return np.array([r, g, b], dtype=np.uint8)


SKY = _rgb(14, 16, 26)
FLOOR_NEAR = _rgb(24, 26, 40)
FLOOR_FAR = _rgb(10, 11, 18)
WALL_BASE = _rgb(210, 210, 225)
WALL_SIDE = _rgb(150, 150, 170)
DOOR_COL = _rgb(180, 210, 255)

# Action alphabet; Q-table rows are indexed in this order (turn left, forward, turn right).
ACTIONS = list("LFR")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def rng_for(seed: int, step: int, stream: int = 0) -> np.random.Generator:
    """Return a reproducible numpy Generator keyed on (seed, step, stream).

    Each component is scaled by its own odd multiplier and the products are
    XOR-combined. The result is masked to 64 bits, so the seed handed to
    ``default_rng`` is always non-negative, even for negative inputs.
    """
    key = seed * 1_000_003
    key ^= step * 9_999_937
    key ^= stream * 97_531
    return np.random.default_rng(key & 0xFFFF_FFFF_FFFF_FFFF)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class Agent:
    """One actor in the world (Predator, Prey or Scout); mutated in place as it acts."""

    name: str  # key used in WorldState.agents and AGENT_COLORS
    x: int  # grid column
    y: int  # grid row (y grows downward on screen)
    ori: int  # facing: index into DIRS / ORI_DEG (0=east, 1=south, 2=west, 3=north), kept mod 4 by the turn helpers
    energy: int = 100  # prey gains +35 per food (capped at 200); nothing in the visible code decreases it
|
|
|
|
|
@dataclass
class TrainConfig:
    """Q-learning hyper-parameters and reward-shaping weights."""

    use_q_pred: bool = True  # Predator acts from its Q-table (else the visibility heuristic)
    use_q_prey: bool = True  # Prey acts from its Q-table (else the flee heuristic)
    alpha: float = 0.15  # Q-update learning rate
    gamma: float = 0.95  # discount applied to the next state's best Q value
    epsilon: float = 0.10  # initial exploration rate; init_state seeds Metrics.epsilon from it
    epsilon_min: float = 0.02  # floor for the decayed exploration rate
    epsilon_decay: float = 0.995  # multiplicative decay applied to Metrics.epsilon during training

    # Predator reward terms (see pred_reward).
    pred_step_penalty: float = -0.02  # paid every tick
    pred_dist_coeff: float = 0.03  # times the reduction in Manhattan distance to the prey this tick
    pred_catch_reward: float = 3.0  # bonus on the tick the prey is caught

    # Prey reward terms (see prey_reward).
    prey_step_penalty: float = -0.02  # paid every tick
    prey_food_reward: float = 0.6  # bonus on a tick where food was eaten
    prey_survive_reward: float = 0.02  # added every tick (including the tick it is caught)
    prey_caught_penalty: float = -3.0  # penalty on the tick it is caught
|
|
|
|
|
@dataclass
class Metrics:
    """Running training statistics, plus the live exploration rate read by choose_action."""

    episodes: int = 0  # total training episodes run so far
    catches: int = 0  # total episodes that ended with a catch
    avg_steps_to_catch: float = 0.0  # moving average of steps needed for a catch
    avg_path_efficiency: float = 0.0  # moving average of BFS-optimal distance / steps taken
    last_episode_steps: int = 0  # steps in the most recent training episode
    last_episode_eff: float = 0.0  # efficiency of the most recent training episode
    epsilon: float = 0.10  # current exploration probability used by epsilon_greedy
|
|
|
|
|
@dataclass
class WorldState:
    """Complete mutable simulation state shared by the tick loop, the trainer and the UI."""

    seed: int  # base seed; rng_for streams are keyed on (seed, step, stream)
    step: int  # tick counter, incremented once per tick
    grid: List[List[int]]  # truth map indexed grid[y][x], values are tile ids (EMPTY, WALL, ...)
    agents: Dict[str, Agent]  # "Predator" / "Prey" / "Scout" -> Agent
    controlled: str  # agent that receives the manual action in tick
    pov: str  # agent whose first-person view is rendered
    overlay: bool  # draw the crosshair reticle on the POV image

    caught: bool  # set once Predator and Prey share a cell; tick does nothing afterwards
    branches: Dict[str, int]  # initialised to {"main": 0}; not read by the visible code

    event_log: List[str]  # human-readable notable events (doors, teleports, food, catches, ...)
    trace_log: List[str]  # one explanatory line per tick (actions, rewards, Q-updates), trimmed to TRACE_MAX

    cfg: TrainConfig  # learning hyper-parameters and reward shaping
    q_pred: Dict[str, List[float]]  # predator Q-table: obs_key string -> one value per ACTIONS entry
    q_prey: Dict[str, List[float]]  # prey Q-table: obs_key string -> one value per ACTIONS entry
    metrics: Metrics  # training statistics and the live epsilon
|
|
|
|
|
@dataclass
class Snapshot:
    """Rewindable copy of the per-step world state (Q-tables and metrics are not included)."""

    step: int  # WorldState.step at capture time
    agents: Dict[str, Dict]  # agent name -> asdict(Agent)
    grid: List[List[int]]  # copy of the truth grid, indexed [y][x]
    caught: bool  # WorldState.caught at capture time
    event_log_tail: List[str]  # last entries of the event log (snapshot_of keeps 20)
    trace_tail: List[str]  # last entries of the trace log (snapshot_of keeps 40)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def default_grid() -> List[List[int]]:
    """Build the stock arena as a GRID_H x GRID_W list of tile ids, indexed [y][x].

    Layout:
      * a solid WALL border around the whole map;
      * a horizontal dividing wall on row 7 spanning x = 4..16, with a single
        DOOR at x = 10;
      * two FOOD tiles, two NOISE tiles and a pair of linked TELE pads.
    """
    grid = [[EMPTY] * GRID_W for _ in range(GRID_H)]

    # Outer border: full top/bottom rows, first/last column elsewhere.
    for row_idx, row in enumerate(grid):
        if row_idx in (0, GRID_H - 1):
            row[:] = [WALL] * GRID_W
        else:
            row[0] = row[-1] = WALL

    # Dividing wall with one door through it.
    grid[7][4:17] = [WALL] * 13
    grid[7][10] = DOOR

    # Scattered features, keyed by (x, y).
    features = {
        (4, 3): FOOD,
        (15, 11): FOOD,
        (14, 4): NOISE,
        (5, 12): NOISE,
        (18, 2): TELE,
        (2, 13): TELE,
    }
    for (fx, fy), tile in features.items():
        grid[fy][fx] = tile
    return grid
|
|
|
|
|
def init_state(seed: int) -> WorldState:
    """Create a brand-new world at step 0.

    Starts from the stock grid with the three agents at their fixed spawn
    points. Q-tables are empty and default hyper-parameters are used; the
    metrics' epsilon starts at ``TrainConfig.epsilon``. The player controls
    and views through the Predator.
    """
    spawns = (
        ("Predator", 2, 2, 0),
        ("Prey", 18, 12, 2),
        ("Scout", 10, 3, 1),
    )
    agents = {name: Agent(name, x, y, ori, 100) for name, x, y, ori in spawns}

    cfg = TrainConfig()
    metrics = Metrics(epsilon=cfg.epsilon)
    return WorldState(
        seed=seed,
        step=0,
        grid=default_grid(),
        agents=agents,
        controlled="Predator",
        pov="Predator",
        overlay=False,
        caught=False,
        branches={"main": 0},
        event_log=["Initialized world."],
        trace_log=[],
        cfg=cfg,
        q_pred={},
        q_prey={},
        metrics=metrics,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def init_belief() -> Dict[str, np.ndarray]:
    """Return one blank belief map per agent.

    Each map is a (GRID_H, GRID_W) int16 array filled with -1, meaning "not yet
    observed"; update_belief_for_agent later copies in the tile ids it sees.
    """
    return {
        name: np.full((GRID_H, GRID_W), -1, dtype=np.int16)
        for name in ("Predator", "Prey", "Scout")
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def in_bounds(x: int, y: int) -> bool:
    """Return True when (x, y) addresses a cell of the GRID_W x GRID_H grid."""
    row_ok = 0 <= y < GRID_H
    return row_ok and 0 <= x < GRID_W
|
|
|
|
|
def is_blocking(tile: int) -> bool:
    """Return whether a tile stops movement; only WALL does (doors are walkable)."""
    return WALL == tile
|
|
|
|
|
def manhattan(a: Agent, b: Agent) -> int:
    """Return the L1 (taxicab) grid distance between two agents, ignoring walls."""
    dx = a.x - b.x
    dy = a.y - b.y
    return abs(dx) + abs(dy)
|
|
|
|
|
def bresenham_los(grid: List[List[int]], x0: int, y0: int, x1: int, y1: int) -> bool:
    """Return True when no WALL cell lies strictly between two grid cells.

    Walks the integer Bresenham line from (x0, y0) to (x1, y1). The two
    endpoints themselves are never tested, so an agent standing on a wall
    tile does not block its own sight line. Only WALL blocks; doors and
    other tiles are transparent here.
    """
    span_x = abs(x1 - x0)
    span_y = abs(y1 - y0)
    step_x = 1 if x0 < x1 else -1
    step_y = 1 if y0 < y1 else -1
    error = span_x - span_y

    cx, cy = x0, y0
    while (cx, cy) != (x1, y1):
        twice = 2 * error
        if twice > -span_y:
            error -= span_y
            cx += step_x
        if twice < span_x:
            error += span_x
            cy += step_y
        # Interior cells only: the target cell is allowed to be anything.
        if (cx, cy) != (x1, y1) and grid[cy][cx] == WALL:
            return False
    return True
|
|
|
|
|
def within_fov(observer: Agent, tx: int, ty: int, fov_deg: float = FOV_DEG) -> bool:
    """Return whether cell (tx, ty) lies inside the observer's viewing cone.

    The bearing is measured with atan2(dy, dx) on grid coordinates (y down),
    the same convention as ORI_DEG. It is compared against the facing angle,
    and the cell counts when the signed difference is within +/- fov_deg / 2.
    The observer's own cell is always inside.
    """
    dx, dy = tx - observer.x, ty - observer.y
    if not (dx or dy):
        return True
    bearing = math.degrees(math.atan2(dy, dx)) % 360
    # Signed angular difference folded into [-180, 180).
    offset = (bearing - ORI_DEG[observer.ori] + 540) % 360 - 180
    return abs(offset) <= fov_deg / 2
|
|
|
|
|
def visible(observer: Agent, target: Agent, grid: List[List[int]]) -> bool:
    """Return whether `target` is inside the observer's FOV cone and has an unobstructed line of sight."""
    if not within_fov(observer, target.x, target.y, FOV_DEG):
        return False
    return bresenham_los(grid, observer.x, observer.y, target.x, target.y)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def turn_left(a: Agent) -> None:
    """Rotate the agent a quarter turn counter-clockwise on screen (east -> north -> west -> south)."""
    a.ori = (a.ori + 3) % 4
|
|
|
|
|
def turn_right(a: Agent) -> None:
    """Rotate the agent a quarter turn clockwise on screen (east -> south -> west -> north)."""
    clockwise = a.ori + 1
    a.ori = clockwise % 4
|
|
|
|
|
def move_forward(state: WorldState, a: Agent) -> str:
    """Try to step `a` one cell along its facing and return a short outcome label.

    Outcomes:
      * "blocked: bounds" / "blocked: wall" - the agent does not move;
      * "moved" - the agent now occupies the target cell. A DOOR there is
        opened permanently (replaced by EMPTY) and logged;
      * "moved: teleported" - the target was a TELE pad and at least two pads
        exist. The agent jumps to the next pad in sorted (x, y) order,
        wrapping around.
    """
    step_x, step_y = DIRS[a.ori]
    tx, ty = a.x + step_x, a.y + step_y

    if not in_bounds(tx, ty):
        return "blocked: bounds"
    if is_blocking(state.grid[ty][tx]):
        return "blocked: wall"

    if state.grid[ty][tx] == DOOR:
        state.grid[ty][tx] = EMPTY
        state.event_log.append(f"t={state.step}: {a.name} opened a door.")
    a.x, a.y = tx, ty

    if state.grid[ty][tx] != TELE:
        return "moved"

    pads = sorted(
        (px, py) for py in range(GRID_H) for px in range(GRID_W) if state.grid[py][px] == TELE
    )
    if len(pads) < 2:
        return "moved"

    here = pads.index((tx, ty))
    a.x, a.y = pads[(here + 1) % len(pads)]
    state.event_log.append(f"t={state.step}: {a.name} teleported.")
    return "moved: teleported"
|
|
|
|
|
def apply_action(state: WorldState, agent_name: str, action: str) -> str:
    """Execute one action letter for the named agent and return its outcome label.

    "F" delegates to move_forward. "L" and "R" rotate in place and return
    "turned left" / "turned right". Anything else returns "noop".
    """
    agent = state.agents[agent_name]
    if action == "F":
        return move_forward(state, agent)
    if action == "L":
        turn_left(agent)
        return "turned left"
    if action == "R":
        turn_right(agent)
        return "turned right"
    return "noop"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _march_ray(grid: List[List[int]], ox: float, oy: float, ang: float) -> Tuple[Optional[int], int, float]:
    """March one ray from (ox, oy) at angle `ang` (radians) in 0.05-tile steps.

    Returns (tile, side, depth):
      * tile  - WALL or DOOR for the first opaque cell entered, or None when the
                ray leaves the grid or exceeds MAX_DEPTH;
      * side  - 1 when that cell was entered across a vertical (x) boundary,
                0 across a horizontal (y) boundary; used to shade the two face
                orientations differently;
      * depth - raw (not fisheye-corrected) distance marched.
    """
    cos_a, sin_a = math.cos(ang), math.sin(ang)
    prev_tx, prev_ty = int(ox), int(oy)
    depth = 0.0
    while depth < MAX_DEPTH:
        depth += 0.05
        tx = int(ox + cos_a * depth)
        ty = int(oy + sin_a * depth)
        if not in_bounds(tx, ty):
            break
        tile = grid[ty][tx]
        if tile == WALL or tile == DOOR:
            crossed_x, crossed_y = tx != prev_tx, ty != prev_ty
            if crossed_x and not crossed_y:
                side = 1
            elif crossed_y and not crossed_x:
                side = 0
            else:
                # Corner crossing (or hit in the start cell): fall back to the dominant ray axis.
                side = 1 if abs(cos_a) > abs(sin_a) else 0
            return tile, side, depth
        prev_tx, prev_ty = tx, ty
    return None, 0, depth


def raycast_view(state: WorldState, observer: Agent) -> np.ndarray:
    """Render the observer's pseudo-3D first-person view as a (VIEW_H, VIEW_W, 3) uint8 image.

    Layers, back to front:
      1. SKY above the horizon; below it a floor gradient from FLOOR_FAR at the
         horizon to FLOOR_NEAR at the bottom edge (the bottom is closest to the
         viewer);
      2. one wall/door slice per ray (RAY_W rays spread linearly over FOV_DEG).
         Slice height is inversely proportional to the fisheye-corrected hit
         distance and is dimmed with distance; wall faces entered across an
         x boundary use WALL_SIDE, others WALL_BASE;
      3. a flat coloured rectangle for every other agent that `visible`
         reports, painted far-to-near so nearer agents occlude farther ones;
      4. a crosshair when ``state.overlay`` is set.
    """
    img = np.empty((VIEW_H, VIEW_W, 3), dtype=np.uint8)
    horizon = VIEW_H // 2

    # 1. Sky and floor gradient (vectorised over rows).
    img[:horizon] = SKY
    t = (np.arange(horizon, VIEW_H) - horizon) / (horizon + 1e-6)
    floor_rows = (1 - t)[:, None] * FLOOR_FAR + t[:, None] * FLOOR_NEAR
    img[horizon:] = floor_rows.astype(np.uint8)[:, None, :]

    # 2. Walls and doors.
    facing = math.radians(ORI_DEG[observer.ori])
    half_fov = math.radians(FOV_DEG) / 2
    ox, oy = observer.x + 0.5, observer.y + 0.5
    col_scale = VIEW_W / RAY_W
    for rx in range(RAY_W):
        cam_x = (2 * rx / (RAY_W - 1)) - 1
        ray_ang = facing + cam_x * half_fov
        tile, side, depth = _march_ray(state.grid, ox, oy, ray_ang)
        if tile is None:
            continue

        # Project onto the view direction to remove fisheye distortion.
        depth = max(depth * math.cos(ray_ang - facing), 0.001)
        proj_h = int((VIEW_H * 0.9) / depth)
        y0 = max(0, horizon - proj_h // 2)
        y1 = min(VIEW_H, horizon + proj_h // 2)  # exclusive slice end, so the bottom row can be painted

        if tile == DOOR:
            base = DOOR_COL
        else:
            base = WALL_SIDE if side == 1 else WALL_BASE
        dim = max(0.25, 1.0 - (depth / MAX_DEPTH))

        x0 = int(rx * col_scale)
        x1 = int((rx + 1) * col_scale)
        img[y0:y1, x0:x1] = (base * dim).astype(np.uint8)

    # 3. Other agents as billboards.
    sprites = []
    for name, other in state.agents.items():
        if name == observer.name or not visible(observer, other, state.grid):
            continue
        dx, dy = other.x - observer.x, other.y - observer.y
        sprites.append((math.sqrt(dx * dx + dy * dy), name, dx, dy))
    # Painter's algorithm: farthest first so nearer agents end up on top.
    sprites.sort(key=lambda s: s[0], reverse=True)

    for dist, name, dx, dy in sprites:
        bearing = math.degrees(math.atan2(dy, dx)) % 360
        diff = (bearing - ORI_DEG[observer.ori] + 540) % 360 - 180
        sx = int((diff / (FOV_DEG / 2)) * (VIEW_W / 2) + (VIEW_W / 2))
        h = int((VIEW_H * 0.65) / max(dist, 0.75))
        w = max(10, h // 3)
        y0 = max(0, horizon - h // 2)
        y1 = min(VIEW_H, horizon + h // 2)
        x0 = max(0, sx - w // 2)
        x1 = min(VIEW_W, sx + w // 2)
        img[y0:y1, x0:x1] = np.array(AGENT_COLORS.get(name, (255, 200, 120)), dtype=np.uint8)

    # 4. Optional reticle.
    if state.overlay:
        cx, cy = VIEW_W // 2, VIEW_H // 2
        reticle = np.array([120, 190, 255], dtype=np.uint8)
        img[cy - 1:cy + 2, cx - 10:cx + 10] = reticle
        img[cy - 10:cy + 10, cx - 1:cx + 2] = reticle

    return img
|
|
|
|
|
def render_topdown(grid: np.ndarray, agents: Dict[str, Agent], title: str, show_agents: bool = True) -> Image.Image:
    """Draw a tile map as an RGB PIL image with a 28-pixel title bar on top.

    `grid` is indexed [row, col] and its shape sets the image size (TILE
    pixels per cell). A value of -1 marks a cell that has not been observed
    yet; tile ids without a palette entry are drawn grey. When `show_agents`
    is set, every agent is drawn as a coloured disc with a short black stroke
    along its facing.
    """
    header = 28
    rows, cols = grid.shape[0], grid.shape[1]
    width, height = cols * TILE, rows * TILE
    canvas = Image.new("RGB", (width, height + header), (10, 12, 18))
    pen = ImageDraw.Draw(canvas)

    palette = {
        -1: (18, 20, 32),
        EMPTY: (26, 30, 44),
        WALL: (190, 190, 210),
        FOOD: (255, 210, 120),
        NOISE: (255, 120, 220),
        DOOR: (140, 210, 255),
        TELE: (120, 190, 255),
    }
    fallback = (80, 80, 90)

    # Tiles.
    for row in range(rows):
        top = row * TILE + header
        for col in range(cols):
            left = col * TILE
            fill = palette.get(int(grid[row, col]), fallback)
            pen.rectangle([left, top, left + TILE - 1, top + TILE - 1], fill=fill)

    # Grid lines.
    line_col = (12, 14, 22)
    for col in range(cols + 1):
        gx = col * TILE
        pen.line([gx, header, gx, height + header], fill=line_col)
    for row in range(rows + 1):
        gy = row * TILE + header
        pen.line([0, gy, width, gy], fill=line_col)

    # Agents.
    if show_agents:
        radius = TILE // 3
        for name, agent in agents.items():
            cx = agent.x * TILE + TILE // 2
            cy = agent.y * TILE + header + TILE // 2
            pen.ellipse(
                [cx - radius, cy - radius, cx + radius, cy + radius],
                fill=AGENT_COLORS.get(name, (220, 220, 220)),
            )
            hx, hy = DIRS[agent.ori]
            pen.line([cx, cy, cx + hx * radius, cy + hy * radius], fill=(10, 10, 10), width=3)

    # Title bar.
    pen.rectangle([0, 0, width, header], fill=(14, 16, 26))
    pen.text((8, 6), title, fill=(230, 230, 240))
    return canvas
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def update_belief_for_agent(state: WorldState, belief: np.ndarray, agent: Agent) -> None:
    """Copy the tiles the agent can currently see from the truth grid into `belief`, in place.

    The agent's own cell is always recorded. A fan of rays (45 for the Scout,
    33 for everyone else) spans the FOV. Each ray is marched in 0.2-tile
    steps, recording every cell it passes through until it reaches a WALL
    (recorded, then stop), leaves the grid, or exceeds MAX_DEPTH. Doors do
    not stop these rays.
    """
    grid = state.grid
    belief[agent.y, agent.x] = grid[agent.y][agent.x]

    centre = math.radians(ORI_DEG[agent.ori])
    half_span = math.radians(FOV_DEG / 2)
    n_rays = 45 if agent.name == "Scout" else 33
    ox, oy = agent.x + 0.5, agent.y + 0.5

    for i in range(n_rays):
        frac = i / (n_rays - 1)
        ang = centre + (frac * 2 - 1) * half_span
        sin_a, cos_a = math.sin(ang), math.cos(ang)
        dist = 0.0
        while dist < MAX_DEPTH:
            dist += 0.2
            cx = int(ox + cos_a * dist)
            cy = int(oy + sin_a * dist)
            if not in_bounds(cx, cy):
                break
            seen = grid[cy][cx]
            belief[cy, cx] = seen
            if seen == WALL:
                break
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def bfs_distance(grid: List[List[int]], sx: int, sy: int, gx: int, gy: int) -> Optional[int]:
    """Return the length of the shortest 4-connected path from (sx, sy) to (gx, gy).

    Only WALL cells are impassable (doors count as open). Returns 0 when start
    and goal coincide, and None when the goal cannot be reached. The search
    expands one distance layer at a time.
    """
    start, goal = (sx, sy), (gx, gy)
    if start == goal:
        return 0

    visited = {start}
    frontier = [start]
    depth = 0
    while frontier:
        depth += 1
        upcoming = []
        for cx, cy in frontier:
            for dx, dy in DIRS:
                nx, ny = cx + dx, cy + dy
                if not in_bounds(nx, ny) or grid[ny][nx] == WALL or (nx, ny) in visited:
                    continue
                if (nx, ny) == goal:
                    return depth
                visited.add((nx, ny))
                upcoming.append((nx, ny))
        frontier = upcoming
    return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def obs_key(state: WorldState, who: str) -> str:
    """Encode an agent's observation as a discrete string key for its Q-table.

    * Predator: "P|x,y,ori|d<dx>,<dy>|v<0/1>". (dx, dy) is the offset to the
      prey, clipped to [-6, 6]; v says whether the predator sees the prey.
    * Prey: "R|x,y,ori|d<dx>,<dy>|v<0/1>|e<energy//25>". (dx, dy) is the
      clipped offset to the predator; v says whether the prey sees it.
    * anyone else: "S|x,y,ori" for that agent.
    """
    pred = state.agents["Predator"]
    prey = state.agents["Prey"]

    if who == "Predator":
        rel_x = int(np.clip(prey.x - pred.x, -6, 6))
        rel_y = int(np.clip(prey.y - pred.y, -6, 6))
        sees = 1 if visible(pred, prey, state.grid) else 0
        return f"P|{pred.x},{pred.y},{pred.ori}|d{rel_x},{rel_y}|v{sees}"

    if who == "Prey":
        rel_x = int(np.clip(pred.x - prey.x, -6, 6))
        rel_y = int(np.clip(pred.y - prey.y, -6, 6))
        sees = 1 if visible(prey, pred, state.grid) else 0
        return f"R|{prey.x},{prey.y},{prey.ori}|d{rel_x},{rel_y}|v{sees}|e{int(prey.energy//25)}"

    agent = state.agents[who]
    return f"S|{agent.x},{agent.y},{agent.ori}"
|
|
|
|
|
def q_get(q: Dict[str, List[float]], key: str) -> List[float]:
    """Return the Q-value row for `key`, inserting a fresh [0.0, 0.0, 0.0] row if it is missing.

    The returned list is the one stored in `q`, so mutating it updates the table.
    """
    return q.setdefault(key, [0.0, 0.0, 0.0])
|
|
|
|
|
def epsilon_greedy(qvals: List[float], eps: float, r: np.random.Generator) -> int:
    """Pick an action index from `qvals` with epsilon-greedy exploration.

    With probability `eps` a uniformly random index is returned. Otherwise the
    greedy action is taken, with ties broken uniformly at random among all
    maximal entries. Unseen states start as all zeros, and a plain argmax
    would always pick index 0 there; that makes fresh agents turn in place
    instead of exploring.

    Args:
        qvals: action values, one per action.
        eps: exploration probability in [0, 1].
        r: generator used for both the exploration draw and tie-breaking.

    Returns:
        The chosen index into `qvals`.
    """
    if r.random() < eps:
        return int(r.integers(0, len(qvals)))
    q = np.asarray(qvals, dtype=float)
    best = np.flatnonzero(q == q.max())
    if best.size == 0:
        # Only reachable with NaN values; keep argmax's deterministic choice.
        return int(np.argmax(q))
    if best.size == 1:
        return int(best[0])
    return int(r.choice(best))
|
|
|
|
|
def q_update(q: Dict[str, List[float]], key: str, a_idx: int, reward: float, next_key: str, alpha: float, gamma: float) -> Tuple[float, float, float]:
    """Apply one tabular Q-learning step in place and return (old, target, new).

        target     = reward + gamma * max(Q[next_key])
        Q[key][a] += alpha * (target - Q[key][a])

    Missing rows are created as zeros: `key` first, then `next_key`. The target
    always bootstraps from `next_key`; there is no terminal-state flag here.
    """
    row = q_get(q, key)
    best_next = float(np.max(q_get(q, next_key)))
    before = row[a_idx]
    td_target = reward + gamma * best_next
    row[a_idx] = before + alpha * (td_target - before)
    return before, td_target, row[a_idx]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def heuristic_pred_action(state: WorldState) -> str:
    """Scripted predator policy: chase the prey when it is visible, otherwise wander.

    When the prey is visible, the predator turns toward it ("L"/"R") if the
    bearing is off by more than 10 degrees, and otherwise moves "F". When it
    is not visible, a random action is drawn from the step's stream-11
    generator.
    """
    hunter = state.agents["Predator"]
    target = state.agents["Prey"]
    if not visible(hunter, target, state.grid):
        return rng_for(state.seed, state.step, stream=11).choice(ACTIONS)

    bearing = math.degrees(math.atan2(target.y - hunter.y, target.x - hunter.x)) % 360
    offset = (bearing - ORI_DEG[hunter.ori] + 540) % 360 - 180
    if offset < -10:
        return "L"
    if offset > 10:
        return "R"
    return "F"
|
|
|
|
|
def heuristic_prey_action(state: WorldState) -> str:
    """Scripted prey policy: flee directly away from a visible predator, otherwise wander.

    The "away" heading is the bearing to the predator rotated by 180 degrees.
    The prey turns toward that heading while it is off by more than 10
    degrees, then runs "F". Without sight of the predator, a random action is
    drawn from the step's stream-12 generator.
    """
    runner = state.agents["Prey"]
    chaser = state.agents["Predator"]
    if not visible(runner, chaser, state.grid):
        return rng_for(state.seed, state.step, stream=12).choice(ACTIONS)

    bearing = math.degrees(math.atan2(chaser.y - runner.y, chaser.x - runner.x)) % 360
    offset = (bearing - ORI_DEG[runner.ori] + 540) % 360 - 180
    away = ((offset + 180) + 540) % 360 - 180
    if away < -10:
        return "L"
    if away > 10:
        return "R"
    return "F"
|
|
|
|
|
def heuristic_scout_action(state: WorldState) -> str:
    """Scripted scout policy: a random action from the step's stream-13 generator."""
    generator = rng_for(state.seed, state.step, stream=13)
    return generator.choice(ACTIONS)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def pred_reward(state_prev: WorldState, state_now: WorldState) -> float:
    """Shaped predator reward for one tick.

    The reward is the step penalty, plus ``pred_dist_coeff`` times the drop
    in predator-prey Manhattan distance between the two states, plus
    ``pred_catch_reward`` if the prey is caught in `state_now`.
    """
    def gap(s: WorldState) -> int:
        hunter, target = s.agents["Predator"], s.agents["Prey"]
        return abs(hunter.x - target.x) + abs(hunter.y - target.y)

    cfg = state_now.cfg
    closed_in = gap(state_prev) - gap(state_now)
    reward = cfg.pred_step_penalty + cfg.pred_dist_coeff * closed_in
    if state_now.caught:
        reward += cfg.pred_catch_reward
    return float(reward)
|
|
|
|
|
def prey_reward(state_prev: WorldState, state_now: WorldState, ate_food: bool) -> float:
    """Prey reward for one tick.

    The reward is the step penalty plus the survival bonus, then
    ``prey_food_reward`` if it ate, then ``prey_caught_penalty`` if it was
    caught. `state_prev` is accepted for symmetry with pred_reward but is not
    used.
    """
    cfg = state_now.cfg
    total = cfg.prey_step_penalty + cfg.prey_survive_reward
    if ate_food:
        total += cfg.prey_food_reward
    if state_now.caught:
        total += cfg.prey_caught_penalty
    return float(total)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Maximum number of per-tick lines kept in WorldState.trace_log; tick drops the oldest beyond this.
TRACE_MAX = 400
|
|
|
|
|
def clone_shallow(state: WorldState) -> WorldState:
    """Copy a WorldState for before/after comparisons within a tick.

    Copied: the grid rows, every Agent (via asdict), branches, and both logs.
    Shared by reference with `state`: cfg, both Q-tables and metrics, so
    learning updates made through either object are visible in both.
    """
    fresh_agents = {name: Agent(**asdict(agent)) for name, agent in state.agents.items()}
    return WorldState(
        seed=state.seed,
        step=state.step,
        grid=[list(row) for row in state.grid],
        agents=fresh_agents,
        controlled=state.controlled,
        pov=state.pov,
        overlay=state.overlay,
        caught=state.caught,
        branches=dict(state.branches),
        event_log=list(state.event_log),
        trace_log=list(state.trace_log),
        cfg=state.cfg,
        q_pred=state.q_pred,
        q_prey=state.q_prey,
        metrics=state.metrics,
    )
|
|
|
|
|
def check_catch(state: WorldState) -> None:
    """Mark the episode caught (and log it) when Predator and Prey occupy the same cell."""
    pred = state.agents["Predator"]
    prey = state.agents["Prey"]
    if (pred.x, pred.y) != (prey.x, prey.y):
        return
    state.caught = True
    state.event_log.append(f"t={state.step}: CAUGHT.")
|
|
|
|
|
def consume_food(state: WorldState) -> bool:
    """Let the prey eat food on its current cell.

    On food, the prey gains 35 energy (capped at 200), the tile becomes EMPTY,
    the event is logged, and True is returned. Otherwise nothing changes and
    False is returned.
    """
    prey = state.agents["Prey"]
    row = state.grid[prey.y]
    if row[prey.x] != FOOD:
        return False
    prey.energy = min(200, prey.energy + 35)
    row[prey.x] = EMPTY
    state.event_log.append(f"t={state.step}: Prey ate food (+energy).")
    return True
|
|
|
|
|
def choose_action(state: WorldState, who: str, stream: int) -> Tuple[str, str, Optional[Tuple[str,int]]]:
    """
    Returns (action, reason, q_info)
    q_info: (obs_key, action_index) if chosen by Q, else None

    Predator and Prey use their Q-table with epsilon-greedy exploration when
    enabled in cfg (the random stream is keyed on (seed, step, stream)).
    Otherwise they use their scripted heuristic. Any other agent uses the
    scout heuristic.
    """
    cfg = state.cfg
    rng = rng_for(state.seed, state.step, stream=stream)

    learned = {
        "Predator": (cfg.use_q_pred, state.q_pred, "pred"),
        "Prey": (cfg.use_q_prey, state.q_prey, "prey"),
    }
    if who in learned and learned[who][0]:
        _, table, tag = learned[who]
        key = obs_key(state, who)
        qv = q_get(table, key)
        eps = state.metrics.epsilon
        a_idx = epsilon_greedy(qv, eps, rng)
        reason = f"Q({tag}) eps={eps:.3f} q={np.round(qv, 3).tolist()}"
        return ACTIONS[a_idx], reason, (key, a_idx)

    scripted = {
        "Predator": (heuristic_pred_action, "heuristic(pred)"),
        "Prey": (heuristic_prey_action, "heuristic(prey)"),
    }
    policy, label = scripted.get(who, (heuristic_scout_action, "heuristic(scout)"))
    return policy(state), label, None
|
|
|
|
|
# Q-table key used as the successor of a terminal (catch) transition. obs_key
# never produces it and no action is ever chosen from it, so its row stays all
# zeros and the update target reduces to the immediate reward.
_TERMINAL_KEY = "__terminal__"


def _tick_select_actions(state: WorldState, manual_action: Optional[str]) -> Tuple[Dict[str, str], Dict[str, str], Dict[str, Optional[Tuple[str, int]]]]:
    """Choose one action per agent; the controlled agent takes `manual_action` when given."""
    chosen: Dict[str, str] = {}
    reasons: Dict[str, str] = {}
    qinfo: Dict[str, Optional[Tuple[str, int]]] = {}
    if manual_action:
        chosen[state.controlled] = manual_action
        reasons[state.controlled] = "manual"
        qinfo[state.controlled] = None
    streams = {"Predator": 21, "Prey": 22, "Scout": 23}
    for who in ["Predator", "Prey", "Scout"]:
        if who in chosen:
            continue
        act, reason, q_i = choose_action(state, who, stream=streams[who])
        chosen[who] = act
        reasons[who] = reason
        qinfo[who] = q_i
    return chosen, reasons, qinfo


def _tick_apply_actions(state: WorldState, chosen: Dict[str, str]) -> Dict[str, str]:
    """Apply the chosen actions in order Predator -> Prey -> Scout, checking for a catch after each move."""
    outcomes: Dict[str, str] = {}
    for who in ["Predator", "Prey", "Scout"]:
        if who == "Prey" and state.caught:
            # The predator already reached the prey's cell this tick; a caught prey cannot step away.
            outcomes[who] = "frozen: caught"
            continue
        outcomes[who] = apply_action(state, who, chosen[who])
        # Checking per move means stepping onto the prey (or swapping cells with it) counts as a catch.
        if not state.caught:
            check_catch(state)
    return outcomes


def _tick_q_step(state: WorldState, table: Dict[str, List[float]], info: Optional[Tuple[str, int]], reward: float, who: str, label: str) -> Optional[str]:
    """Run one Q-update for `who` if it acted from its Q-table; return the trace fragment or None."""
    if info is None:
        return None
    key, a_idx = info
    # A catch ends the episode: do not bootstrap from the post-catch observation.
    next_key = _TERMINAL_KEY if state.caught else obs_key(state, who)
    old, target, new = q_update(table, key, a_idx, reward, next_key, state.cfg.alpha, state.cfg.gamma)
    return f"{label}: {key} a={ACTIONS[a_idx]} old={old:.3f} tgt={target:.3f} new={new:.3f}"


def tick(state: WorldState, manual_action: Optional[str] = None) -> None:
    """Advance the simulation by one step, in place.

    Phases:
      1. choose one action per agent (manual for the controlled agent when
         `manual_action` is given, otherwise via choose_action);
      2. apply them in the order Predator -> Prey -> Scout, checking for a
         catch after every move. Once caught, the prey's own move is skipped;
      3. let the prey eat, compute both rewards, and run the tabular
         Q-updates. A catch is treated as terminal, so its target is just the
         reward;
      4. append a one-line explanation to trace_log (trimmed to TRACE_MAX)
         and increment state.step.

    Does nothing once the episode is already caught.
    """
    if state.caught:
        return

    # Pre-move copy: pred_reward needs the previous predator/prey positions.
    prev = clone_shallow(state)

    pred = state.agents["Predator"]
    prey = state.agents["Prey"]
    opt_dist = bfs_distance(state.grid, pred.x, pred.y, prey.x, prey.y)
    if opt_dist is None:
        opt_dist = 999

    chosen, reasons, qinfo = _tick_select_actions(state, manual_action)
    outcomes = _tick_apply_actions(state, chosen)
    ate = consume_food(state)

    pred_r = pred_reward(prev, state)
    prey_r = prey_reward(prev, state, ate_food=ate)

    q_lines = []
    for line in (
        _tick_q_step(state, state.q_pred, qinfo["Predator"], pred_r, "Predator", "Qpred"),
        _tick_q_step(state, state.q_prey, qinfo["Prey"], prey_r, "Prey", "Qprey"),
    ):
        if line is not None:
            q_lines.append(line)

    dist_now = manhattan(state.agents["Predator"], state.agents["Prey"])
    trace = (
        f"t={state.step} optDist~{opt_dist} distNow={dist_now} "
        f"| Pred:{chosen['Predator']} ({outcomes['Predator']}) [{reasons['Predator']}] r={pred_r:+.3f} "
        f"| Prey:{chosen['Prey']} ({outcomes['Prey']}) [{reasons['Prey']}] r={prey_r:+.3f} "
        f"| Scout:{chosen['Scout']} ({outcomes['Scout']}) [{reasons['Scout']}] "
        f"| ateFood={ate} caught={state.caught}"
    )
    if q_lines:
        trace += " | " + " ; ".join(q_lines)

    state.trace_log.append(trace)
    if len(state.trace_log) > TRACE_MAX:
        state.trace_log = state.trace_log[-TRACE_MAX:]

    state.step += 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def reset_episode(state: WorldState, seed: Optional[int] = None) -> None:
    """Start a new episode in place, keeping everything that was learned.

    The seed is set to `seed` (default: keep the current one) and the step to 0.
    The grid, agents, controlled/pov/overlay settings and branches are replaced
    with init_state defaults. The episode flag is cleared, the event log
    restarts with "Episode reset." and the trace log is emptied. cfg,
    q_pred, q_prey and metrics are the same objects before and after.
    """
    fresh = init_state(state.seed if seed is None else seed)

    state.seed = fresh.seed
    state.step = 0
    state.grid = fresh.grid
    state.agents = fresh.agents
    state.controlled = fresh.controlled
    state.pov = fresh.pov
    state.overlay = fresh.overlay
    state.caught = False
    state.branches = fresh.branches
    state.event_log = ["Episode reset."]
    state.trace_log = []
|
|
|
|
|
def run_episode(state: WorldState, max_steps: int) -> Tuple[bool, int, float]:
    """Tick autonomously until a catch or `max_steps`; return (caught, steps, efficiency).

    Efficiency is the BFS shortest predator-to-prey path length at the start
    of the run divided by the steps taken. It is 0.0 when no path existed,
    and it is computed whether or not a catch happened.
    """
    pred = state.agents["Predator"]
    prey = state.agents["Prey"]
    shortest = bfs_distance(state.grid, pred.x, pred.y, prey.x, prey.y)
    if shortest is None:
        shortest = 999  # sentinel: unreachable

    taken = 0
    while taken < max_steps and not state.caught:
        tick(state, manual_action=None)
        taken += 1

    caught = state.caught
    efficiency = float(shortest / max(1, taken)) if shortest < 999 else 0.0
    return caught, taken, efficiency
|
|
|
|
|
def _ema(prev: float, sample: float) -> float:
    """Blend `sample` into a running average (0.85 old / 0.15 new); a zero average is seeded with `sample`."""
    return 0.85 * prev + 0.15 * sample if prev > 0 else sample


def train(state: WorldState, episodes: int, max_steps: int) -> None:
    """Run `episodes` autonomous training episodes and fold the results into state.metrics.

    Each episode resets the world (Q-tables, cfg and metrics are kept) with a
    seed derived from the current seed and the global episode index. It then
    runs up to `max_steps` ticks and decays metrics.epsilon once, never going
    below cfg.epsilon_min. Afterwards the episode/catch counters and
    last-episode fields are updated, and the batch averages are blended into
    the running averages with _ema. A summary line is appended to the event
    log.

    A non-positive `episodes` only logs a message; there is no last episode to
    report, so the metrics are left untouched.
    """
    episodes = int(episodes)  # UI number widgets may deliver floats; range() needs an int
    if episodes <= 0:
        state.event_log.append("Training: nothing to do (episodes <= 0).")
        return

    m = state.metrics
    cfg = state.cfg
    catches = 0
    total_steps_catch = 0
    total_eff = 0.0
    steps, eff = 0, 0.0

    for ep in range(episodes):
        # reset_episode stores ep_seed in state.seed, so successive seeds are chained.
        ep_seed = (state.seed * 1_000_003 + (m.episodes + ep) * 97_531) & 0xFFFFFFFF
        reset_episode(state, seed=int(ep_seed))

        caught, steps, eff = run_episode(state, max_steps=max_steps)
        total_eff += eff
        if caught:
            catches += 1
            total_steps_catch += steps

        m.epsilon = max(cfg.epsilon_min, m.epsilon * cfg.epsilon_decay)

    m.episodes += episodes
    m.catches += catches
    m.last_episode_steps = steps
    m.last_episode_eff = eff
    if catches > 0:
        m.avg_steps_to_catch = _ema(m.avg_steps_to_catch, total_steps_catch / catches)
    m.avg_path_efficiency = _ema(m.avg_path_efficiency, total_eff / max(1, episodes))

    state.event_log.append(
        f"Training: +{episodes} eps | catches={catches}/{episodes} | "
        f"avgStepsToCatch~{m.avg_steps_to_catch:.2f} | avgEff~{m.avg_path_efficiency:.2f} | eps={m.epsilon:.3f}"
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Maximum number of Snapshots kept in the rewind history; the oldest is discarded first.
MAX_HISTORY = 1200
|
|
|
|
|
def snapshot_of(state: WorldState) -> Snapshot:
    """Capture the rewindable part of `state`.

    The snapshot holds the step, agents as plain dicts, a copied grid, the
    caught flag, and the last 20 events and 40 trace lines.
    """
    agent_dicts = {name: asdict(agent) for name, agent in state.agents.items()}
    grid_copy = [list(row) for row in state.grid]
    return Snapshot(
        step=state.step,
        agents=agent_dicts,
        grid=grid_copy,
        caught=state.caught,
        event_log_tail=state.event_log[-20:],
        trace_tail=state.trace_log[-40:],
    )
|
|
|
|
|
def restore_into(state: WorldState, snap: Snapshot) -> None:
    """Overwrite `state`'s step, grid, agents and caught flag from `snap`, then log the jump.

    Agents named in the snapshot are rebuilt from their dicts. Logs, learning
    tables and settings are left as they are.
    """
    state.step = snap.step
    state.grid = [list(row) for row in snap.grid]
    for name, fields in snap.agents.items():
        state.agents[name] = Agent(**fields)
    state.caught = snap.caught
    state.event_log.append(f"Jumped to snapshot t={snap.step}.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def export_run(state: WorldState, history: List[Snapshot]) -> str:
    """Serialise settings, config, metrics, both Q-tables, the history and the live grid to indented JSON."""
    payload = dict(
        seed=state.seed,
        controlled=state.controlled,
        pov=state.pov,
        overlay=state.overlay,
        cfg=asdict(state.cfg),
        metrics=asdict(state.metrics),
        q_pred=state.q_pred,
        q_prey=state.q_prey,
        history=[asdict(snap) for snap in history],
        grid=state.grid,
    )
    return json.dumps(payload, indent=2)
|
|
|
|
|
def _known_fields(cls, raw, where: str) -> Dict:
    """Return the entries of `raw` whose keys are fields of dataclass `cls`; raise ValueError if `raw` is not a dict."""
    if not isinstance(raw, dict):
        raise ValueError(f"{where}: expected a JSON object")
    return {k: v for k, v in raw.items() if k in cls.__dataclass_fields__}


def _validated_grid(raw, where: str) -> List[List[int]]:
    """Return `raw` as a fresh GRID_H x GRID_W list of known tile ids, or raise ValueError."""
    if not isinstance(raw, list) or len(raw) != GRID_H:
        raise ValueError(f"{where}: expected {GRID_H} rows")
    grid = []
    for row in raw:
        if not isinstance(row, list) or len(row) != GRID_W:
            raise ValueError(f"{where}: expected rows of {GRID_W} tiles")
        if any(not isinstance(v, int) or v not in TILE_NAMES for v in row):
            raise ValueError(f"{where}: unknown tile id")
        grid.append(list(row))
    return grid


def import_run(txt: str) -> Tuple[WorldState, List[Snapshot], Dict[str, np.ndarray], int]:
    """Rebuild (state, history, beliefs, rewind_index) from JSON written by export_run.

    The input is treated as untrusted:
      * grids (top-level and per snapshot) must be GRID_H x GRID_W lists of
        known tile ids;
      * `controlled` / `pov` must name an existing agent;
      * unknown keys in cfg / metrics / snapshots are ignored, so files from
        newer versions still load.
    Malformed input raises ValueError (json.JSONDecodeError is a subclass).

    When history is present, the last snapshot supplies step, agents and the
    caught flag. The exported top-level grid is then re-applied, because it is
    at least as recent as that snapshot and can include later tile edits.
    Beliefs always start blank.
    """
    data = json.loads(txt)
    if not isinstance(data, dict):
        raise ValueError("Import JSON must be an object.")

    st = init_state(int(data.get("seed", 1337)))
    st.controlled = data.get("controlled", st.controlled)
    st.pov = data.get("pov", st.pov)
    known_agents = tuple(st.agents)
    for label, who in (("controlled", st.controlled), ("pov", st.pov)):
        if who not in known_agents:
            raise ValueError(f"{label}: unknown agent {who!r}")
    st.overlay = bool(data.get("overlay", False))

    live_grid = _validated_grid(data["grid"], "grid") if "grid" in data else None
    if live_grid is not None:
        st.grid = live_grid

    st.cfg = TrainConfig(**_known_fields(TrainConfig, data.get("cfg", {}), "cfg"))
    st.metrics = Metrics(**_known_fields(Metrics, data.get("metrics", {}), "metrics"))

    st.q_pred = data.get("q_pred", {})
    st.q_prey = data.get("q_prey", {})

    hist = []
    for i, raw in enumerate(data.get("history", [])):
        fields = _known_fields(Snapshot, raw, f"history[{i}]")
        fields["grid"] = _validated_grid(fields.get("grid"), f"history[{i}].grid")
        hist.append(Snapshot(**fields))
    bel = init_belief()
    r_idx = max(0, len(hist) - 1)

    if hist:
        restore_into(st, hist[-1])
        if live_grid is not None:
            # restore_into replaced the grid with the snapshot's; keep the newer exported one.
            st.grid = [row[:] for row in live_grid]
    st.event_log.append("Imported run.")
    return st, hist, bel, r_idx
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_views(state: WorldState, beliefs: Dict[str, np.ndarray]) -> Tuple[np.ndarray, Image.Image, Image.Image, Image.Image, str, str, str]:
    """Refresh every agent's belief map in place, then render all UI outputs.

    Returns, in order:
      * the POV image of ``state.pov``;
      * the truth map;
      * the controlled agent's belief map;
      * the other agent's belief map (the Prey when the Predator is
        controlled, otherwise the Predator);
      * the status text;
      * the last 18 event lines;
      * the last 18 trace lines.
    """
    for name, agent in state.agents.items():
        update_belief_for_agent(state, beliefs[name], agent)

    pov_frame = raycast_view(state, state.agents[state.pov])
    truth_img = render_topdown(
        np.array(state.grid, dtype=np.int16),
        state.agents,
        f"Truth Map — t={state.step} seed={state.seed}",
        show_agents=True,
    )

    ctrl = state.controlled
    other = "Predator" if ctrl != "Predator" else "Prey"
    ctrl_belief = render_topdown(beliefs[ctrl], state.agents, f"{ctrl} Belief", show_agents=True)
    other_belief = render_topdown(beliefs[other], state.agents, f"{other} Belief", show_agents=True)

    m = state.metrics
    pred, prey, scout = (state.agents[k] for k in ("Predator", "Prey", "Scout"))
    status = "\n".join([
        f"Controlled={state.controlled} | POV={state.pov} | caught={state.caught} | eps={m.epsilon:.3f}",
        f"Episodes={m.episodes} | catches={m.catches} | avgStepsToCatch~{m.avg_steps_to_catch:.2f} | avgEff~{m.avg_path_efficiency:.2f}",
        f"Pred({pred.x},{pred.y}) o={pred.ori} | Prey({prey.x},{prey.y}) o={prey.ori} e={prey.energy} | Scout({scout.x},{scout.y}) o={scout.ori}",
    ])
    events_txt = "\n".join(state.event_log[-18:])
    trace_txt = "\n".join(state.trace_log[-18:])
    return pov_frame, truth_img, ctrl_belief, other_belief, status, events_txt, trace_txt
|
|
|
|
|
def grid_click_to_tile(evt: gr.SelectData, selected_tile: int, state: WorldState) -> WorldState:
    """Paint `selected_tile` onto the truth-map cell under a click, then return `state`.

    `evt.index` is read as the (x, y) pixel clicked in the render_topdown
    image, whose top 28 pixel rows are the title bar. The click is ignored
    when it lands on the title bar, outside the grid, or on the protected
    outer border. It is also ignored when `selected_tile` is not a known tile
    id (for example None); such values would otherwise end up in the grid and
    break rendering.
    """
    x_px, y_px = evt.index
    y_px -= 28  # skip the title bar above the tiles
    if y_px < 0:
        return state

    gx = int(x_px // TILE)
    gy = int(y_px // TILE)
    if not in_bounds(gx, gy):
        return state
    if gx in (0, GRID_W - 1) or gy in (0, GRID_H - 1):
        return state  # keep the outer wall intact

    try:
        tile = int(selected_tile)
    except (TypeError, ValueError):
        return state
    if tile not in TILE_NAMES:
        return state

    state.grid[gy][gx] = tile
    state.event_log.append(f"t={state.step}: Tile ({gx},{gy}) -> {TILE_NAMES[tile]}")
    return state
|
|
|
|
|
|
|
|
|
|
|
|
|
|
with gr.Blocks(title="Agent POV") as demo: |
|
|
gr.Markdown( |
|
|
"## Agent-POV by ZEN AI Co.\n" |
|
|
"Track every interaction, train policies, and audit why outcomes happened.\n" |
|
|
"No timers (compatibility). Use Tick/Run/Train for controlled experiments." |
|
|
) |
|
|
|
|
|
st = gr.State(init_state(1337)) |
|
|
history = gr.State([snapshot_of(init_state(1337))]) |
|
|
beliefs = gr.State(init_belief()) |
|
|
rewind_idx = gr.State(0) |
|
|
|
|
|
with gr.Row(): |
|
|
pov_img = gr.Image(label="POV (Pseudo-3D)", type="numpy", width=VIEW_W, height=VIEW_H) |
|
|
with gr.Column(): |
|
|
status = gr.Textbox(label="Status + Metrics", lines=4) |
|
|
events = gr.Textbox(label="Event Log", lines=10) |
|
|
trace = gr.Textbox(label="Step Trace (why it happened)", lines=10) |
|
|
|
|
|
with gr.Row(): |
|
|
truth = gr.Image(label="Truth Map (click to edit tiles)", type="pil") |
|
|
belief_a = gr.Image(label="Belief (Controlled)", type="pil") |
|
|
belief_b = gr.Image(label="Belief (Other)", type="pil") |
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(scale=2): |
|
|
gr.Markdown("### Manual Controls") |
|
|
with gr.Row(): |
|
|
btn_L = gr.Button("L") |
|
|
btn_F = gr.Button("F") |
|
|
btn_R = gr.Button("R") |
|
|
with gr.Row(): |
|
|
btn_tick = gr.Button("Tick") |
|
|
run_steps = gr.Number(value=25, label="Run N steps", precision=0) |
|
|
btn_run = gr.Button("Run") |
|
|
with gr.Row(): |
|
|
btn_toggle_control = gr.Button("Toggle Controlled") |
|
|
btn_toggle_pov = gr.Button("Toggle POV") |
|
|
overlay = gr.Checkbox(False, label="Overlay reticle") |
|
|
|
|
|
tile_pick = gr.Radio( |
|
|
choices=[(TILE_NAMES[k], k) for k in [EMPTY, WALL, FOOD, NOISE, DOOR, TELE]], |
|
|
value=WALL, |
|
|
label="Paint tile type" |
|
|
) |
|
|
|
|
|
with gr.Column(scale=3): |
|
|
gr.Markdown("### Training Controls (Q-learning)") |
|
|
use_q_pred = gr.Checkbox(True, label="Use Q-learning: Predator") |
|
|
use_q_prey = gr.Checkbox(True, label="Use Q-learning: Prey") |
|
|
alpha = gr.Slider(0.01, 0.5, value=0.15, step=0.01, label="alpha (learn rate)") |
|
|
gamma = gr.Slider(0.5, 0.99, value=0.95, step=0.01, label="gamma (discount)") |
|
|
eps = gr.Slider(0.0, 0.5, value=0.10, step=0.01, label="epsilon (exploration)") |
|
|
eps_decay = gr.Slider(0.90, 0.999, value=0.995, step=0.001, label="epsilon decay") |
|
|
eps_min = gr.Slider(0.0, 0.2, value=0.02, step=0.01, label="epsilon min") |
|
|
|
|
|
episodes = gr.Number(value=50, label="Train episodes", precision=0) |
|
|
max_steps = gr.Number(value=250, label="Max steps per episode", precision=0) |
|
|
btn_train = gr.Button("Train") |
|
|
|
|
|
btn_reset = gr.Button("Reset Episode") |
|
|
btn_reset_all = gr.Button("Reset ALL (wipe Q + metrics)") |
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(): |
|
|
rewind = gr.Slider(0, 0, value=0, step=1, label="Rewind (history index)") |
|
|
btn_jump = gr.Button("Jump") |
|
|
with gr.Column(): |
|
|
export_box = gr.Textbox(label="Export JSON", lines=10) |
|
|
btn_export = gr.Button("Export") |
|
|
with gr.Column(): |
|
|
import_box = gr.Textbox(label="Import JSON", lines=10) |
|
|
btn_import = gr.Button("Import") |
|
|
|
|
|
def refresh(state: WorldState, hist: List[Snapshot], bel: Dict[str, np.ndarray], r: int):
    """Render every view for ``state`` and sync the rewind slider to ``hist``.

    Args:
        state: world state to render.
        hist: snapshot history; its length bounds the rewind slider.
        bel: belief arrays forwarded unchanged to ``build_views``.
        r: requested history index; coerced with ``int`` and clamped to
            ``[0, len(hist) - 1]`` (or 0 when ``hist`` is empty).

    Returns:
        A 9-tuple: the seven values returned by ``build_views``, a
        ``gr.update`` setting the rewind slider's maximum and value, and the
        clamped index.
    """
    last_index = len(hist) - 1 if hist else 0

    index = int(r)
    if index < 0:
        index = 0
    elif index > last_index:
        index = last_index

    (pov_image, truth_image, belief_one, belief_two,
     status_text, event_text, trace_text) = build_views(state, bel)

    slider_patch = gr.update(maximum=last_index, value=index)
    return (
        pov_image, truth_image, belief_one, belief_two,
        status_text, event_text, trace_text,
        slider_patch,
        index,
    )
|
|
|
|
|
def push_hist(state: WorldState, hist: List[Snapshot]) -> List[Snapshot]:
    """Append a snapshot of ``state`` to ``hist`` and cap it at ``MAX_HISTORY``.

    ``hist`` is mutated in place and also returned, so callers can rebind it.
    The oldest entries are dropped first. Every entry over the cap is trimmed,
    not just one. A history that already exceeds the cap (for instance one
    restored from an import, if that path does not enforce the cap) is
    therefore brought back within bounds in a single call.

    Args:
        state: world state to snapshot via ``snapshot_of``.
        hist: snapshot list to extend.

    Returns:
        The same ``hist`` list object.
    """
    hist.append(snapshot_of(state))
    excess = len(hist) - MAX_HISTORY
    if excess > 0:
        # One slice deletion instead of repeated pop(0), which is O(n) each.
        del hist[:excess]
    return hist
|
|
|
|
|
def set_cfg(state: WorldState, uq_pred: bool, uq_prey: bool, a: float, g: float, e: float, ed: float, emin: float):
    """Copy the training-control widget values into ``state`` (mutated in place).

    Args:
        state: world state whose ``cfg`` (a ``TrainConfig``) and ``metrics``
            are updated.
        uq_pred: enable Q-learning for the Predator.
        uq_prey: enable Q-learning for the Prey.
        a: learning rate (alpha).
        g: discount factor (gamma).
        e: exploration rate (epsilon).
        ed: multiplicative epsilon decay.
        emin: floor for epsilon.

    Returns:
        The same ``state`` object.
    """
    cfg = state.cfg
    cfg.use_q_pred = bool(uq_pred)
    cfg.use_q_prey = bool(uq_prey)
    cfg.alpha = float(a)
    cfg.gamma = float(g)

    epsilon = float(e)
    # Previously only metrics.epsilon was written, so TrainConfig.epsilon kept
    # its default regardless of the slider. Write both so the configured value
    # and the live value agree.
    cfg.epsilon = epsilon
    state.metrics.epsilon = epsilon

    cfg.epsilon_decay = float(ed)
    cfg.epsilon_min = float(emin)
    return state
|
|
|
|
|
def do_manual(state, hist, bel, r, act):
    """Advance one tick with ``act`` as the manual action, record it, and re-render.

    ``r`` is ignored: the rewind index jumps to the newest history entry.
    Returns the ``refresh`` 9-tuple followed by ``(state, hist, bel, index)``.
    """
    tick(state, manual_action=act)
    hist = push_hist(state, hist)
    latest = len(hist) - 1
    return (*refresh(state, hist, bel, latest), state, hist, bel, latest)
|
|
|
|
|
def do_tick(state, hist, bel, r):
    """Advance one tick with no manual action, record it, and re-render.

    This is ``do_manual`` with ``act=None``, which performs the same
    ``tick(state, manual_action=None)`` call followed by the same
    history push and refresh.
    """
    return do_manual(state, hist, bel, r, None)
|
|
|
|
|
def do_run(state, hist, bel, r, n):
    """Run up to ``n`` automatic ticks, stopping early once ``state.caught`` is set.

    Every tick is appended to the history. Afterwards the rewind index points
    at the newest entry.

    Args:
        n: step count from the "Run N steps" ``gr.Number``. It is clamped to at
            least 1. A cleared field may arrive as ``None``; that is treated as
            a single step instead of raising ``TypeError`` from ``int(None)``.

    Returns:
        The ``refresh`` 9-tuple followed by ``(state, hist, bel, index)``.
    """
    steps = max(1, int(n)) if n is not None else 1

    for _ in range(steps):
        if state.caught:
            break
        tick(state, manual_action=None)
        hist = push_hist(state, hist)

    latest = len(hist) - 1
    return refresh(state, hist, bel, latest) + (state, hist, bel, latest)
|
|
|
|
|
def toggle_control(state, hist, bel, r):
    """Cycle the controlled agent Predator -> Prey -> Scout -> Predator.

    Logs the change to ``state.event_log``, records a snapshot, and re-renders
    at the newest history entry.
    """
    rotation = ["Predator", "Prey", "Scout"]
    position = rotation.index(state.controlled)
    state.controlled = rotation[(position + 1) % len(rotation)]
    state.event_log.append(f"Controlled -> {state.controlled}")

    hist = push_hist(state, hist)
    latest = len(hist) - 1
    rendered = refresh(state, hist, bel, latest)
    return rendered + (state, hist, bel, latest)
|
|
|
|
|
def toggle_pov(state, hist, bel, r):
    """Cycle the point-of-view agent Predator -> Prey -> Scout -> Predator.

    Logs the change to ``state.event_log``, records a snapshot, and re-renders
    at the newest history entry.
    """
    agents = ["Predator", "Prey", "Scout"]
    successor = agents[(agents.index(state.pov) + 1) % len(agents)]
    state.pov = successor
    state.event_log.append(f"POV -> {state.pov}")

    hist = push_hist(state, hist)
    newest = len(hist) - 1
    return (*refresh(state, hist, bel, newest), state, hist, bel, newest)
|
|
|
|
|
def set_overlay(state, hist, bel, r, ov):
    """Set the reticle-overlay flag and re-render without recording history.

    The incoming index ``r`` is passed through unchanged as the last value.
    """
    state.overlay = bool(ov)
    rendered = refresh(state, hist, bel, r)
    return (*rendered, state, hist, bel, r)
|
|
|
|
|
def click_truth(tile, state, hist, bel, r, evt: gr.SelectData):
    """Paint the clicked truth-map cell with tile type ``tile`` and record the edit.

    ``evt`` is injected by Gradio because of its ``gr.SelectData`` annotation.
    ``grid_click_to_tile`` returns the state to continue with.
    """
    painted = grid_click_to_tile(evt, int(tile), state)
    hist = push_hist(painted, hist)
    latest = len(hist) - 1
    rendered = refresh(painted, hist, bel, latest)
    return rendered + (painted, hist, bel, latest)
|
|
|
|
|
def jump(state, hist, bel, r, idx):
    """Restore ``state`` in place from history entry ``idx`` and re-render.

    ``idx`` (the rewind slider value) is clamped into the history range. With
    an empty history nothing is restored and ``r`` is passed through unchanged.
    """
    if hist:
        target = min(max(int(idx), 0), len(hist) - 1)
        restore_into(state, hist[target])
        r = target
    return refresh(state, hist, bel, r) + (state, hist, bel, r)
|
|
|
|
|
def reset_ep(state, hist, bel, r):
    """Reset the episode in place with the current seed and start history over.

    Unlike ``reset_all``, the existing ``state`` object is reused: it is
    mutated by ``reset_episode`` rather than rebuilt with ``init_state``.
    Beliefs are reinitialised and the rewind index is reset to 0.
    """
    reset_episode(state, seed=state.seed)
    fresh_hist = [snapshot_of(state)]
    fresh_bel = init_belief()
    rendered = refresh(state, fresh_hist, fresh_bel, 0)
    return rendered + (state, fresh_hist, fresh_bel, 0)
|
|
|
|
|
def reset_all(state, hist, bel, r):
    """Replace the world with a brand-new ``init_state`` built from the current seed.

    History, beliefs, and the rewind index all start over. The old ``state``
    object is discarded rather than mutated.
    """
    new_state = init_state(state.seed)
    new_hist = [snapshot_of(new_state)]
    new_bel = init_belief()
    return (*refresh(new_state, new_hist, new_bel, 0), new_state, new_hist, new_bel, 0)
|
|
|
|
|
def do_train(state, hist, bel, r,
             uq_pred, uq_prey, a, g, e, ed, emin,
             eps_count, max_s):
    """Apply the training settings, run ``train``, then start a fresh episode.

    Args:
        uq_pred, uq_prey, a, g, e, ed, emin: widget values forwarded to
            ``set_cfg``.
        eps_count: number of training episodes (clamped to at least 1).
        max_s: step cap per episode (clamped to at least 10).

    Both counts come from ``gr.Number`` widgets. A cleared field may arrive
    as ``None``, which previously crashed in ``int(None)``. It now falls back
    to the minimum.

    Returns:
        The ``refresh`` 9-tuple followed by ``(state, hist, bel, 0)``, with
        history and beliefs reset.
    """
    state = set_cfg(state, uq_pred, uq_prey, a, g, e, ed, emin)

    n_episodes = max(1, int(eps_count)) if eps_count is not None else 1
    step_cap = max(10, int(max_s)) if max_s is not None else 10
    train(state, episodes=n_episodes, max_steps=step_cap)

    # Start a clean episode so the UI shows a fresh run after training.
    reset_episode(state, seed=state.seed)
    hist = [snapshot_of(state)]
    bel = init_belief()
    r = 0

    out = refresh(state, hist, bel, r)
    return out + (state, hist, bel, r)
|
|
|
|
|
def export_fn(state, hist):
    """Return ``export_run``'s serialization of state + history for the Export JSON box."""
    exported = export_run(state, hist)
    return exported
|
|
|
|
|
def import_fn(txt):
    """Load a run from the Import JSON text and render it.

    Rendering and slider syncing are delegated to ``refresh``, so the index
    returned by ``import_run`` is clamped into the imported history's range.
    Previously it was used verbatim, so an out-of-range index could put the
    rewind slider and ``rewind_idx`` past the end of the history.

    Returns:
        12 values matching the import outputs: the seven view values, the
        rewind-slider update, then ``(state, hist, bel, index)``.
    """
    state, hist, bel, r = import_run(txt)
    # refresh returns seven views + slider update + clamped index; the import
    # outputs do not include the separate rewind_idx slot, so split it off.
    *views_and_slider, clamped = refresh(state, hist, bel, r)
    return (*views_and_slider, state, hist, bel, clamped)
|
|
|
|
|
|
|
|
# Shared component lists for the event wiring below. Step/reset handlers take
# (state, history, beliefs, rewind_idx) and return the nine values produced by
# refresh followed by the updated (state, history, beliefs, rewind_idx).
_view_outputs = [pov_img, truth, belief_a, belief_b, status, events, trace, rewind]
_refresh_outputs = _view_outputs + [rewind_idx]
_session_components = [st, history, beliefs, rewind_idx]
_step_outputs = _refresh_outputs + _session_components

# Manual single-step buttons, one action each.
btn_L.click(lambda s, h, b, r: do_manual(s, h, b, r, "L"),
            inputs=_session_components, outputs=_step_outputs, queue=True)
btn_F.click(lambda s, h, b, r: do_manual(s, h, b, r, "F"),
            inputs=_session_components, outputs=_step_outputs, queue=True)
btn_R.click(lambda s, h, b, r: do_manual(s, h, b, r, "R"),
            inputs=_session_components, outputs=_step_outputs, queue=True)

# Automatic stepping.
btn_tick.click(do_tick,
               inputs=_session_components, outputs=_step_outputs, queue=True)
btn_run.click(do_run,
              inputs=_session_components + [run_steps], outputs=_step_outputs, queue=True)

# Controlled-agent / POV / overlay toggles.
btn_toggle_control.click(toggle_control,
                         inputs=_session_components, outputs=_step_outputs, queue=True)
btn_toggle_pov.click(toggle_pov,
                     inputs=_session_components, outputs=_step_outputs, queue=True)
overlay.change(set_overlay,
               inputs=_session_components + [overlay], outputs=_step_outputs, queue=True)

# Clicking the truth map paints the selected tile type.
truth.select(click_truth,
             inputs=[tile_pick] + _session_components, outputs=_step_outputs, queue=True)

# History navigation and resets.
btn_jump.click(jump,
               inputs=_session_components + [rewind], outputs=_step_outputs, queue=True)
btn_reset.click(reset_ep,
                inputs=_session_components, outputs=_step_outputs, queue=True)
btn_reset_all.click(reset_all,
                    inputs=_session_components, outputs=_step_outputs, queue=True)

# Training with the current slider/checkbox settings.
btn_train.click(
    do_train,
    inputs=_session_components + [
        use_q_pred, use_q_prey, alpha, gamma, eps, eps_decay, eps_min,
        episodes, max_steps,
    ],
    outputs=_step_outputs,
    queue=True,
)

# Export / import of a run.
btn_export.click(export_fn, inputs=[st, history], outputs=[export_box], queue=True)
btn_import.click(import_fn,
                 inputs=[import_box],
                 outputs=_view_outputs + _session_components,
                 queue=True)

# Initial render when the page loads.
demo.load(refresh,
          inputs=_session_components, outputs=_refresh_outputs, queue=True)
|
|
|
|
|
# Enable Gradio's request queue on the app (queue() returns the same Blocks
# instance), then start the web server.
demo.queue()
demo.launch()
|
|
|