|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import hashlib |
|
|
from dataclasses import dataclass |
|
|
from typing import Dict, Any, List, Tuple, Optional |
|
|
|
|
|
import numpy as np |
|
|
from PIL import Image, ImageDraw |
|
|
|
|
|
|
|
|
# Tile codes. Values 0..4 are stored in the world grid; T_UNKNOWN and T_AGENT
# only appear in observations/renders as overlays on top of the grid.
T_UNKNOWN = -1  # observation cell outside the agent's field of view
T_EMPTY = 0
T_WALL = 1
T_COIN = 2
T_HAZARD = 3
T_GOAL = 4
T_AGENT = 5  # marks the agent's own cell

# Action names the agent policy can return.
ACTIONS = ["UP", "DOWN", "LEFT", "RIGHT", "WAIT"]
|
|
|
|
|
|
|
|
@dataclass
class SimConfig:
    """Tunable parameters of the grid-world simulation."""

    # Side length N of the square grid (the outer ring is always wall).
    size: int = 12
    # Probability that an interior cell is turned into a wall at reset.
    walls_pct: float = 0.18
    # Number of coins / hazards placed on free interior cells at reset
    # (fewer are placed when not enough free cells exist).
    coins: int = 5
    hazards: int = 4
    # Half-width of the square window copied into the agent's observation.
    pov_radius: int = 4
    # The episode is marked done once the step counter reaches this value.
    max_steps: int = 2000

    def to_dict(self) -> Dict[str, Any]:
        """Return the configuration as a dict of built-in ``int``/``float`` values.

        Keys appear in field-declaration order; each value is coerced with
        ``int`` (or ``float`` for ``walls_pct``).
        """
        coercions = (
            ("size", int),
            ("walls_pct", float),
            ("coins", int),
            ("hazards", int),
            ("pov_radius", int),
            ("max_steps", int),
        )
        return {name: cast(getattr(self, name)) for name, cast in coercions}
|
|
|
|
|
|
|
|
@dataclass
class SimState:
    """Complete snapshot of one simulation at a given step."""

    cfg: SimConfig
    seed: int
    # Integer drawn from the seeded RNG at reset; it is mixed into the state hash.
    rng_state_tag: int
    # Tile codes, indexed as grid[y, x].
    grid: np.ndarray
    # Positions are (x, y) pairs.
    agent_xy: Tuple[int, int]
    goal_xy: Tuple[int, int]
    score: int
    step: int
    done: bool
    # Hex digest of this state; filled in by the reset/step functions.
    last_state_sha256: Optional[str] = None

    def clone(self) -> "SimState":
        """Return an independent copy of this state.

        The grid array is copied and scalar fields are normalised to built-in
        ``int``/``bool``; the ``cfg`` object is shared with the original, not
        copied.
        """

        def as_xy(point) -> Tuple[int, int]:
            # Rebuild the pair as a fresh tuple of plain ints.
            return (int(point[0]), int(point[1]))

        fields = {
            "cfg": self.cfg,
            "seed": int(self.seed),
            "rng_state_tag": int(self.rng_state_tag),
            "grid": self.grid.copy(),
            "agent_xy": as_xy(self.agent_xy),
            "goal_xy": as_xy(self.goal_xy),
            "score": int(self.score),
            "step": int(self.step),
            "done": bool(self.done),
            "last_state_sha256": self.last_state_sha256,
        }
        return SimState(**fields)
|
|
|
|
|
|
|
|
def _sha256_hex(b: bytes) -> str: |
|
|
return hashlib.sha256(b).hexdigest() |
|
|
|
|
|
|
|
|
def _state_hash(state: SimState) -> str:
    """Return a SHA-256 hex digest that fingerprints ``state``.

    The digest covers a fixed header (grid size, agent and goal positions,
    score, step, done flag, RNG tag) followed by the grid contents as int8.
    ``seed`` and ``last_state_sha256`` are not part of the hashed data.

    The header is serialised as explicit little-endian 32-bit integers so the
    digest is the same on every host. A plain ``np.int32`` uses the machine's
    native byte order, which would make big-endian hosts produce different
    hashes for the same state. On little-endian machines the bytes (and hence
    the digest) are unchanged.
    """
    agent_x, agent_y = state.agent_xy
    goal_x, goal_y = state.goal_xy
    header_fields = [
        int(state.cfg.size),
        agent_x,
        agent_y,
        goal_x,
        goal_y,
        int(state.score),
        int(state.step),
        int(state.done),
        int(state.rng_state_tag),
    ]
    header = np.array(header_fields, dtype="<i4").tobytes()
    # int8 is a single byte per cell, so the grid has no byte-order concerns.
    grid_bytes = state.grid.astype(np.int8).tobytes()
    return _sha256_hex(header + grid_bytes)
|
|
|
|
|
|
|
|
def _in_bounds(N: int, x: int, y: int) -> bool: |
|
|
return 0 <= x < N and 0 <= y < N |
|
|
|
|
|
|
|
|
def reset_sim(cfg: SimConfig, seed: int) -> SimState:
    """Build the initial state of a world, fully determined by ``cfg`` and ``seed``.

    Layout:
      * the outer ring of the grid is wall;
      * each interior cell becomes a wall with probability ``cfg.walls_pct``;
      * the agent starts at ``(1, 1)`` and the goal is at ``(N - 2, N - 2)``;
        both cells are cleared of walls;
      * up to ``cfg.coins`` coins and then up to ``cfg.hazards`` hazards are
        placed on distinct, randomly chosen free interior cells.

    Nothing here checks that the goal is reachable from the start cell.

    Args:
        cfg: Simulation parameters. Negative ``coins``/``hazards`` counts are
            treated as zero.
        seed: Seed for the NumPy ``RandomState`` that drives world generation.

    Returns:
        A fresh ``SimState`` at step 0 with ``last_state_sha256`` populated.
    """
    rng = np.random.RandomState(int(seed))
    N = int(cfg.size)

    # NOTE: the order of RNG draws (wall mask -> shuffle -> tag) defines the
    # world produced for a given seed; do not reorder them.
    wall_mask = rng.rand(N, N) < float(cfg.walls_pct)
    # The border ring is always wall, whatever the random draw said.
    wall_mask[0, :] = True
    wall_mask[N - 1, :] = True
    wall_mask[:, 0] = True
    wall_mask[:, N - 1] = True

    grid = np.zeros((N, N), dtype=np.int8)
    grid[wall_mask] = T_WALL

    agent_xy = (1, 1)
    goal_xy = (N - 2, N - 2)
    # Start and goal cells must never be walls. The grid is indexed [y, x].
    grid[agent_xy[1], agent_xy[0]] = T_EMPTY
    grid[goal_xy[1], goal_xy[0]] = T_GOAL

    reserved = (agent_xy, goal_xy)
    empties = [
        (x, y)
        for y in range(1, N - 1)
        for x in range(1, N - 1)
        if grid[y, x] == T_EMPTY and (x, y) not in reserved
    ]
    rng.shuffle(empties)

    # Clamp both counts to [0, cells available]. Without the lower bound a
    # negative ``coins`` value yields a negative start index, so hazards would
    # be taken from the end of ``empties`` (or raise IndexError).
    n_coins = max(0, min(int(cfg.coins), len(empties)))
    n_hazards = max(0, min(int(cfg.hazards), len(empties) - n_coins))

    for x, y in empties[:n_coins]:
        grid[y, x] = T_COIN
    for x, y in empties[n_coins:n_coins + n_hazards]:
        grid[y, x] = T_HAZARD

    st = SimState(
        cfg=cfg,
        seed=int(seed),
        rng_state_tag=int(rng.randint(0, 2**31 - 1)),
        grid=grid,
        agent_xy=agent_xy,
        goal_xy=goal_xy,
        score=0,
        step=0,
        done=False,
        last_state_sha256=None,
    )
    st.last_state_sha256 = _state_hash(st)
    return st
|
|
|
|
|
|
|
|
def _agent_policy(cfg: SimConfig, state: SimState) -> str:
    """Pick the agent's next action with a greedy, wall-avoiding rule.

    Moves that bring the agent closer to the goal are tried first (horizontal
    before vertical), followed by the fixed fallback order UP, DOWN, LEFT,
    RIGHT, WAIT. The first action whose target cell is inside the grid and is
    not a wall is returned. Only walls are avoided; other tile types are not
    inspected.
    """
    ax, ay = state.agent_xy
    gx, gy = state.goal_xy
    size = int(cfg.size)

    # (dx, dy) per action; y grows downwards, matching grid[y, x] indexing.
    offsets = {
        "UP": (0, -1),
        "DOWN": (0, 1),
        "LEFT": (-1, 0),
        "RIGHT": (1, 0),
        "WAIT": (0, 0),
    }

    ordered: List[str] = []
    if gx > ax:
        ordered.append("RIGHT")
    if gx < ax:
        ordered.append("LEFT")
    if gy > ay:
        ordered.append("DOWN")
    if gy < ay:
        ordered.append("UP")
    # Fallbacks; duplicates of the preferred moves are harmless.
    ordered.extend(["UP", "DOWN", "LEFT", "RIGHT", "WAIT"])

    for name in ordered:
        dx, dy = offsets[name]
        tx, ty = ax + dx, ay + dy
        if _in_bounds(size, tx, ty) and int(state.grid[ty, tx]) != T_WALL:
            return name
    return "WAIT"
|
|
|
|
|
|
|
|
def step_sim(cfg: SimConfig, state: SimState) -> Tuple[SimState, str]:
    """Advance the simulation by one step using the built-in agent policy.

    Returns ``(new_state, action)``. The input state is not modified; a clone
    is updated instead. If ``state.done`` is already set, the same state
    object is returned together with ``"WAIT"``.

    Effects of the cell the agent ends the step on:
      * coin   -> score +1, and the coin is removed from the grid;
      * hazard -> score -2 (the hazard stays in the grid);
      * goal   -> score +10, and the episode is marked done.
    The episode is also marked done once the step counter reaches
    ``cfg.max_steps``.
    """
    if state.done:
        return state, "WAIT"

    action = _agent_policy(cfg, state)
    ax, ay = state.agent_xy

    # Any action without an entry here (including "WAIT") keeps the position.
    deltas = {"UP": (0, -1), "DOWN": (0, 1), "LEFT": (-1, 0), "RIGHT": (1, 0)}
    dx, dy = deltas.get(action, (0, 0))
    nx, ny = ax + dx, ay + dy

    nxt = state.clone()
    nxt.step += 1

    size = int(cfg.size)
    blocked = not _in_bounds(size, nx, ny) or int(nxt.grid[ny, nx]) == T_WALL
    if blocked:
        # Illegal move: stay where we are.
        nx, ny = ax, ay

    landed_on = int(nxt.grid[ny, nx])
    if landed_on == T_COIN:
        nxt.score += 1
        nxt.grid[ny, nx] = T_EMPTY
    elif landed_on == T_HAZARD:
        nxt.score -= 2
    elif landed_on == T_GOAL:
        nxt.score += 10
        nxt.done = True

    nxt.agent_xy = (nx, ny)

    if nxt.step >= int(cfg.max_steps):
        nxt.done = True

    nxt.last_state_sha256 = _state_hash(nxt)
    return nxt, action
|
|
|
|
|
|
|
|
def observation_array(state: SimState) -> np.ndarray:
    """Return the agent's partial view of the world as an ``N`` x ``N`` int8 array.

    Cells inside the square window of half-width ``cfg.pov_radius`` around the
    agent (clipped to the grid) carry the grid's tile codes, every other cell
    is ``T_UNKNOWN``, and the agent's own cell is overwritten with ``T_AGENT``.
    """
    size = int(state.cfg.size)
    radius = int(state.cfg.pov_radius)
    ax, ay = state.agent_xy

    # Visible window, clipped to the grid; rows are y, columns are x.
    rows = slice(max(0, ay - radius), min(size, ay + radius + 1))
    cols = slice(max(0, ax - radius), min(size, ax + radius + 1))

    view = np.full((size, size), T_UNKNOWN, dtype=np.int8)
    view[rows, cols] = state.grid[rows, cols]
    view[ay, ax] = T_AGENT
    return view
|
|
|
|
|
|
|
|
def observation_sha256(state: SimState) -> str:
    """Return the SHA-256 hex digest of the agent's observation as int8 bytes."""
    raw = observation_array(state).astype(np.int8).tobytes()
    return _sha256_hex(raw)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
_BG = (10, 14, 22) |
|
|
_GRID = (38, 52, 80) |
|
|
_WALL = (160, 170, 190) |
|
|
_EMPTY = (18, 24, 36) |
|
|
_COIN = (240, 210, 60) |
|
|
_HAZ = (255, 90, 90) |
|
|
_GOAL = (120, 255, 170) |
|
|
_AGENT = (120, 180, 255) |
|
|
_UNKNOWN = (0, 0, 0) |
|
|
|
|
|
CELL = 24 |
|
|
PAD = 12 |
|
|
|
|
|
|
|
|
def _tile_color(t: int) -> Tuple[int, int, int]:
    """Return the RGB colour used to draw tile code ``t``.

    Any code without an explicit entry (including ``T_EMPTY``) is drawn with
    the empty-cell colour.
    """
    palette = (
        (T_WALL, _WALL),
        (T_COIN, _COIN),
        (T_HAZARD, _HAZ),
        (T_GOAL, _GOAL),
        (T_AGENT, _AGENT),
        (T_UNKNOWN, _UNKNOWN),
    )
    for code, rgb in palette:
        if t == code:
            return rgb
    return _EMPTY
|
|
|
|
|
|
|
|
def render_world_image(state: SimState) -> Image.Image:
    """Render the full world (no fog) as an RGB image.

    The image has a title line with seed/step/score, the grid drawn as
    ``CELL``-sized squares with the agent's cell shown in the agent colour,
    and a footer with the first 16 hex characters of the state hash.
    """
    N = int(state.cfg.size)
    board_px = N * CELL
    w = 2 * PAD + board_px
    h = 2 * PAD + board_px + 44  # extra 44 px hold the title and footer text

    img = Image.new("RGB", (w, h), _BG)
    d = ImageDraw.Draw(img)

    title = f"World | seed={state.seed} step={state.step} score={state.score}"
    d.text((PAD, 10), title, fill=(235, 235, 235))

    # Top-left pixel of the board; it sits below the title line.
    ox, oy = PAD, PAD + 34
    for y in range(N):
        top = oy + y * CELL
        for x in range(N):
            left = ox + x * CELL
            code = T_AGENT if (x, y) == state.agent_xy else int(state.grid[y, x])
            box = [left, top, left + CELL - 1, top + CELL - 1]
            d.rectangle(box, fill=_tile_color(code))
            d.rectangle(box, outline=_GRID)

    hs = (state.last_state_sha256 or "")[:16]
    d.text((PAD, h - 18), f"state_hash={hs}", fill=(170, 170, 170))
    return img
|
|
|
|
|
|
|
|
def render_pov_image(state: SimState) -> Image.Image:
    """Render the agent's partial observation as an RGB image.

    The layout matches the world view: a title line (POV radius plus the
    first 12 hex characters of the observation hash), the observation grid
    with unseen cells drawn in the unknown colour, and a footer with the first
    16 hex characters of the state hash.
    """
    N = int(state.cfg.size)
    obs = observation_array(state)

    board_px = N * CELL
    w = 2 * PAD + board_px
    h = 2 * PAD + board_px + 44  # extra 44 px hold the title and footer text

    img = Image.new("RGB", (w, h), _BG)
    d = ImageDraw.Draw(img)

    obs_digest = observation_sha256(state)[:12]
    d.text(
        (PAD, 10),
        f"Agent POV | radius={state.cfg.pov_radius} obs_hash={obs_digest}",
        fill=(235, 235, 235),
    )

    # Top-left pixel of the board; it sits below the title line.
    ox, oy = PAD, PAD + 34
    for y in range(N):
        top = oy + y * CELL
        for x in range(N):
            left = ox + x * CELL
            box = [left, top, left + CELL - 1, top + CELL - 1]
            d.rectangle(box, fill=_tile_color(int(obs[y, x])))
            d.rectangle(box, outline=_GRID)

    hs = (state.last_state_sha256 or "")[:16]
    d.text((PAD, h - 18), f"state_hash={hs}", fill=(170, 170, 170))
    return img