# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Chess Environment Implementation.

An RL agent plays White against a random bot (Black).
Reward is shaped: per-move material delta + terminal win/loss/draw bonus.
"""

import random
from typing import List
from uuid import uuid4

import chess
from openenv_core.env_server.interfaces import Environment
from openenv_core.env_server.types import State

from models import ChessAction, ChessObservation

# Piece values: P=1, N=3, B=3, R=5, Q=9
PIECE_VALUES = {
    chess.PAWN: 1,
    chess.KNIGHT: 3,
    chess.BISHOP: 3,
    chess.ROOK: 5,
    chess.QUEEN: 9,
}


class ChessEnvironment(Environment):
    """
    Chess environment where an RL agent (White) plays against a random bot (Black).

    Reward shaping:
        - material_delta: change in (White material - Black material) each step
        - terminal bonus: +1 for win, -1 for loss, 0 for draw
    """

    def __init__(self):
        """Create an environment with a fresh board and a new episode id."""
        self._board = chess.Board()
        self._state = State(episode_id=str(uuid4()), step_count=0)
        # Symbols of captured pieces, in capture order (uppercase = White piece).
        self._captured_pieces: List[str] = []

    def reset(self) -> ChessObservation:
        """Start a new episode from the standard starting position.

        Returns:
            The initial observation: White to move, no captures, zero reward.
        """
        self._board = chess.Board()
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._captured_pieces = []
        return ChessObservation(
            board_fen=self._board.fen(),
            legal_moves=[m.uci() for m in self._board.legal_moves],
            white_move="",
            black_move=None,
            material_balance=0.0,
            game_status="ongoing",
            captured_pieces=[],
            done=False,
            reward=0.0,
        )

    def step(self, action: ChessAction) -> ChessObservation:  # type: ignore[override]
        """Play White's move, then (if the game goes on) a random reply for Black.

        Args:
            action: Action whose ``move`` attribute is White's move in UCI notation.

        Returns:
            The observation after the full turn. ``reward`` is the change in
            material balance over the turn plus the terminal bonus when the
            game ended. ``black_move`` is ``None`` when White's move ended
            the game.

        Raises:
            ValueError: If ``action.move`` is not valid UCI or is not legal in
                the current position.

        Note:
            This method does not itself reject calls made after a drawn game;
            callers are expected to ``reset`` once ``done`` is reported.
        """
        # --- validate ---
        try:
            move = chess.Move.from_uci(action.move)
        except (chess.InvalidMoveError, ValueError) as exc:
            raise ValueError(f"Invalid UCI string: {action.move!r}") from exc

        if move not in self._board.legal_moves:
            raise ValueError(
                f"Illegal move: {action.move!r}. "
                f"Legal moves: {[m.uci() for m in self._board.legal_moves]}"
            )

        self._state.step_count += 1
        balance_before = self._material_balance()

        # --- White's move ---
        self._track_capture(move)  # must run before push: it inspects the pre-move board
        self._board.push(move)
        white_uci = action.move
        status = self._get_game_status()

        # --- Black's move (random), only if White did not end the game ---
        black_uci = None
        if not self._is_terminal_status(status):
            black_move = random.choice(list(self._board.legal_moves))
            self._track_capture(black_move)
            self._board.push(black_move)
            black_uci = black_move.uci()
            status = self._get_game_status()

        # --- post-move evaluation ---
        terminal = self._is_terminal_status(status)
        balance_after = self._material_balance()
        reward = balance_after - balance_before
        if terminal:
            reward += self._terminal_reward(status)

        return ChessObservation(
            board_fen=self._board.fen(),
            legal_moves=[] if terminal else [m.uci() for m in self._board.legal_moves],
            white_move=white_uci,
            black_move=black_uci,
            material_balance=balance_after,
            game_status=status,
            captured_pieces=list(self._captured_pieces),
            done=terminal,
            reward=reward,
        )

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _track_capture(self, move: chess.Move) -> None:
        """Record the symbol of the piece captured by ``move``, if any.

        Must be called before ``move`` is pushed, while the captured piece is
        still on the board and ``board.turn`` is the side making the move.
        """
        board = self._board
        if board.is_en_passant(move):
            # The captured pawn is not on the destination square, so build it
            # from the opponent's colour; this keeps the symbol's case
            # consistent with ordinary captures ("P" for White, "p" for Black).
            captured = chess.Piece(chess.PAWN, not board.turn)
        else:
            captured = board.piece_at(move.to_square)
        if captured is not None:
            self._captured_pieces.append(captured.symbol())

    def _material_balance(self) -> float:
        """Return White material minus Black material.

        Pieces without an entry in ``PIECE_VALUES`` (the kings) count as 0.
        """
        balance = 0
        for piece in self._board.piece_map().values():
            value = PIECE_VALUES.get(piece.piece_type, 0)
            balance += value if piece.color == chess.WHITE else -value
        return float(balance)

    def _get_game_status(self) -> str:
        """Classify the current position as a status string.

        Checks are ordered so that game-ending conditions take precedence
        over a plain "check".
        """
        b = self._board
        if b.is_checkmate():
            return "checkmate"
        if b.is_stalemate():
            return "stalemate"
        if b.is_insufficient_material():
            return "draw_insufficient"
        if b.is_fifty_moves():
            return "draw_fifty"
        if b.is_repetition():
            return "draw_repetition"
        if b.is_check():
            return "check"
        return "ongoing"

    @staticmethod
    def _is_terminal_status(status: str) -> bool:
        """Return True if ``status`` is one of the game-ending statuses."""
        return status in (
            "checkmate",
            "stalemate",
            "draw_insufficient",
            "draw_fifty",
            "draw_repetition",
        )

    def _terminal_reward(self, status: str) -> float:
        """Return the terminal bonus from White's perspective.

        +1.0 if White delivered checkmate, -1.0 if Black did, 0.0 otherwise.
        """
        if status == "checkmate":
            # whoever is to move is in checkmate → they lost
            if self._board.turn == chess.BLACK:
                return 1.0  # White delivered checkmate
            else:
                return -1.0  # Black delivered checkmate
        # all other terminal states are draws
        return 0.0

    @property
    def state(self) -> State:
        """Current episode state (episode id and step count)."""
        return self._state