""" openenv/env.py ────────────── Stateful ChessEcon environment that implements the OpenEnv 0.1 contract: reset() → ResetResponse step() → StepResponse state() → StateResponse Key design decisions: - Each call to reset() creates a new episode (new game_id, fresh board). - step(action) accepts either UCI or SAN notation. - Rewards are computed per-step (not just terminal): +0.01 legal move played +0.05 move gives check +0.10 capture +1.00 win -1.00 loss 0.00 draw - Economy (entry fees, prize pool) is tracked per episode. - Thread-safe: each episode is independent. The FastAPI router creates one global instance and serialises access via asyncio locks. """ from __future__ import annotations import uuid import logging from typing import Optional import chess from backend.chess_engine import ChessEngine from backend.settings import settings from backend.openenv.models import ( ChessObservation, ResetResponse, StepResponse, StateResponse, ResetRequest, ) logger = logging.getLogger(__name__) # Shaping rewards (small intermediate signals) REWARD_LEGAL_MOVE = 0.01 REWARD_CHECK = 0.05 REWARD_CAPTURE = 0.10 REWARD_WIN = 1.00 REWARD_LOSS = -1.00 REWARD_DRAW = 0.00 class ChessEconEnv: """ OpenEnv-compliant Chess Economy environment. Manages a single active episode. Call reset() to start a new episode. Call step(action) to advance it. Call state() to inspect without advancing. """ def __init__( self, white_model_id: str, black_model_id: str, starting_wallet: float = 100.0, entry_fee: float = 10.0, prize_pool_fraction: float = 0.9, max_moves: int = 150, ): self.white_model_id = white_model_id self.black_model_id = black_model_id self.starting_wallet = starting_wallet self.entry_fee = entry_fee self.prize_pool_fraction = prize_pool_fraction self.max_moves = max_moves # Episode state (None until first reset()) self._engine: Optional[ChessEngine] = None self._episode_id: str = "" self._step_count: int = 0 self._status: str = "idle" self._move_history: list[str] = [] # Economy self._wallet_white: float = starting_wallet self._wallet_black: float = starting_wallet self._prize_pool: float = 0.0 # Last move for observation self._last_uci: Optional[str] = None self._last_san: Optional[str] = None # ── OpenEnv core API ─────────────────────────────────────────────────────── def reset(self, request: Optional[ResetRequest] = None) -> ResetResponse: """ Start a new episode. Deducts entry fees and returns the initial observation. """ self._engine = ChessEngine() self._episode_id = str(uuid.uuid4()) self._step_count = 0 self._status = "active" self._move_history = [] self._last_uci = None self._last_san = None # Economy: deduct entry fees self._wallet_white -= self.entry_fee self._wallet_black -= self.entry_fee self._prize_pool = self.entry_fee * 2 * self.prize_pool_fraction logger.info( "Episode %s started. Wallets: W=%.1f B=%.1f prize_pool=%.1f", self._episode_id[:8], self._wallet_white, self._wallet_black, self._prize_pool, ) obs = self._build_observation() return ResetResponse( observation=obs, info={ "episode_id": self._episode_id, "prize_pool": self._prize_pool, "entry_fee": self.entry_fee, }, ) def step(self, action: str) -> StepResponse: """ Apply a move to the board and return the next observation + reward. action: UCI string ('e2e4') or SAN string ('e4'). """ if self._engine is None or self._status != "active": raise RuntimeError("Call reset() before step()") # ── Apply the move ───────────────────────────────────────────────── # Try UCI first, then SAN uci_applied: Optional[str] = None san_applied: Optional[str] = None # UCI path san_from_uci = self._engine.apply_move_uci(action) if san_from_uci is not None: uci_applied = action san_applied = san_from_uci else: # SAN path — we need the UCI back try: move = self._engine.board.parse_san(action) uci_applied = move.uci() san_applied = self._engine.board.san(move) self._engine.board.push(move) except Exception: # Illegal move — return current state with negative reward obs = self._build_observation() return StepResponse( observation=obs, reward=-0.10, terminated=False, truncated=False, info={"error": f"Illegal move: {action}", "legal_moves": self._engine.legal_moves_uci[:10]}, ) self._last_uci = uci_applied self._last_san = san_applied self._move_history.append(san_applied) self._step_count += 1 # ── Compute per-step reward ──────────────────────────────────────── reward = self._compute_step_reward(uci_applied) # ── Check termination ────────────────────────────────────────────── terminated = bool(self._engine.is_game_over) truncated = (not terminated) and (self._step_count >= self.max_moves * 2) if terminated or truncated: reward = self._settle_game(terminated, truncated, reward) obs = self._build_observation() return StepResponse( observation=obs, reward=round(reward, 4), terminated=terminated, truncated=truncated, info={ "episode_id": self._episode_id, "step": self._step_count, "san": san_applied, "uci": uci_applied, "move_history": self._move_history[-10:], "prize_pool": self._prize_pool, }, ) def state(self) -> StateResponse: """Return current episode state without advancing it.""" if self._engine is None: # Return idle state with default observation idle_obs = ChessObservation( fen="rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1", turn="white", move_number=1, legal_moves_uci=[], wallet_white=self._wallet_white, wallet_black=self._wallet_black, white_model=self.white_model_id, black_model=self.black_model_id, ) return StateResponse( observation=idle_obs, episode_id="", step_count=0, status="idle", ) return StateResponse( observation=self._build_observation(), episode_id=self._episode_id, step_count=self._step_count, status=self._status, info={ "prize_pool": self._prize_pool, "move_history": self._move_history[-10:], }, ) # ── Internal helpers ─────────────────────────────────────────────────────── def _build_observation(self) -> ChessObservation: engine = self._engine assert engine is not None board = engine.board return ChessObservation( fen=engine.fen, turn=engine.turn, move_number=engine.move_number, last_move_uci=self._last_uci, last_move_san=self._last_san, legal_moves_uci=engine.legal_moves_uci, is_check=board.is_check(), wallet_white=round(self._wallet_white, 2), wallet_black=round(self._wallet_black, 2), white_model=self.white_model_id, black_model=self.black_model_id, info={ "move_history": self._move_history[-20:], "step_count": self._step_count, "episode_id": self._episode_id, }, ) def _compute_step_reward(self, uci: str) -> float: """ Dense per-step reward shaping. Evaluated AFTER the move has been applied, so we look at the NEW board state. """ engine = self._engine assert engine is not None board = engine.board reward = REWARD_LEGAL_MOVE # Check bonus (opponent is now in check) if board.is_check(): reward += REWARD_CHECK # Capture bonus — look at the move that was just pushed if board.move_stack: last_move = board.move_stack[-1] # Castling and en-passant: board.is_capture works on the board before the move # We check by looking at whether a piece disappeared from the target square # Simple heuristic: the move stack entry captures flag if board.is_capture(last_move): reward += REWARD_CAPTURE return reward def _settle_game(self, terminated: bool, truncated: bool, step_reward: float) -> float: """ Apply terminal reward and settle the economy. Returns the final total reward for the last move. """ engine = self._engine assert engine is not None result = engine.result or "1/2-1/2" white_reward = engine.compute_reward("white") # +1, -1, or 0 # Terminal reward if white_reward > 0: terminal = REWARD_WIN self._wallet_white += self._prize_pool logger.info("White wins! Prize: +%.1f", self._prize_pool) elif white_reward < 0: terminal = REWARD_LOSS self._wallet_black += self._prize_pool logger.info("Black wins! Prize: +%.1f", self._prize_pool) else: terminal = REWARD_DRAW self._wallet_white += self._prize_pool / 2 self._wallet_black += self._prize_pool / 2 logger.info("Draw. Split prize: +%.1f each", self._prize_pool / 2) self._status = "truncated" if truncated else "terminated" logger.info( "Episode %s ended. Result=%s Wallets: W=%.1f B=%.1f", self._episode_id[:8], result, self._wallet_white, self._wallet_black, ) return step_reward + terminal