| | |
| | |
| | |
| | |
| | |
| |
|
| | """ |
| | Chess Environment Implementation. |
| | |
| | An RL agent plays White against a random bot (Black). |
| | Reward is shaped: per-move material delta + terminal win/loss/draw bonus. |
| | """ |
| |
|
| | import random |
| | from typing import List |
| | from uuid import uuid4 |
| |
|
| | import chess |
| |
|
| | from openenv_core.env_server.interfaces import Environment |
| | from openenv_core.env_server.types import State |
| |
|
| | from models import ChessAction, ChessObservation |
| |
|
| | |
# Conventional material values, in pawn units, keyed by python-chess piece
# type. Kings are intentionally absent, so they contribute nothing to the
# material balance.
PIECE_VALUES = dict(
    zip(
        (chess.PAWN, chess.KNIGHT, chess.BISHOP, chess.ROOK, chess.QUEEN),
        (1, 3, 3, 5, 9),
    )
)
| |
|
| |
|
class ChessEnvironment(Environment):
    """
    Chess environment where an RL agent (White) plays against a random bot (Black).

    Each call to ``step`` applies the agent's White move and then, unless that
    move ended the game, a uniformly random legal reply for Black (chosen with
    the module-global ``random`` generator).

    Reward shaping:
        - material_delta: change in (White material - Black material) over the
          step, valued with ``PIECE_VALUES``
        - terminal bonus: +1 for win, -1 for loss, 0 for draw

    Terminal statuses are checkmate, stalemate, insufficient material, the
    fifty-move rule and threefold repetition. The last two end the episode
    immediately as draws.
    """

    def __init__(self):
        self._board = chess.Board()
        self._state = State(episode_id=str(uuid4()), step_count=0)
        # Symbols of pieces captured this episode, in capture order, using
        # python-chess case conventions (uppercase = White, lowercase = Black).
        self._captured_pieces: List[str] = []

    def reset(self) -> ChessObservation:
        """Start a new episode from the standard initial position.

        Returns:
            The opening observation: White to move, zero material balance, no
            moves played yet (``white_move`` is ``""``, ``black_move`` is None).
        """
        self._board = chess.Board()
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._captured_pieces = []

        return self._build_observation(
            white_move="",
            black_move=None,
            balance=self._material_balance(),
            status="ongoing",
            reward=0.0,
            done=False,
        )

    def step(self, action: ChessAction) -> ChessObservation:
        """Play the agent's White move, then a random Black reply.

        Args:
            action: Carries the agent's move as a UCI string (e.g. ``"e2e4"``).

        Returns:
            The observation after Black's reply. If White's move itself ended
            the game, the observation is taken right after it and
            ``black_move`` is None.

        Raises:
            ValueError: if the episode has already ended, or ``action.move`` is
                not a valid UCI string, or is not legal in the current position.
        """
        # Draw terminals (fifty-move, repetition, insufficient material) still
        # have legal moves, so without this guard play would silently continue
        # past an observation that reported done=True.
        current_status = self._get_game_status()
        if self._is_terminal_status(current_status):
            raise ValueError(
                f"Game is already over ({current_status}); "
                "call reset() to start a new episode."
            )

        try:
            move = chess.Move.from_uci(action.move)
        except (chess.InvalidMoveError, ValueError) as exc:
            raise ValueError(f"Invalid UCI string: {action.move!r}") from exc

        if move not in self._board.legal_moves:
            raise ValueError(
                f"Illegal move: {action.move!r}. "
                f"Legal moves: {[m.uci() for m in self._board.legal_moves]}"
            )

        self._state.step_count += 1
        balance_before = self._material_balance()

        # --- White (agent) move ---
        self._track_capture(move)  # must run before push, while the victim is on the board
        self._board.push(move)
        white_uci = action.move

        status = self._get_game_status()
        if self._is_terminal_status(status):
            balance_after = self._material_balance()
            reward = (balance_after - balance_before) + self._terminal_reward(status)
            return self._build_observation(
                white_move=white_uci,
                black_move=None,
                balance=balance_after,
                status=status,
                reward=reward,
                done=True,
            )

        # --- Black (random bot) move ---
        # A non-terminal status rules out checkmate/stalemate, so Black has at
        # least one legal move here.
        black_move = random.choice(list(self._board.legal_moves))
        self._track_capture(black_move)
        self._board.push(black_move)

        status = self._get_game_status()
        terminal = self._is_terminal_status(status)
        balance_after = self._material_balance()
        reward = balance_after - balance_before
        if terminal:
            reward += self._terminal_reward(status)

        return self._build_observation(
            white_move=white_uci,
            black_move=black_move.uci(),
            balance=balance_after,
            status=status,
            reward=reward,
            done=terminal,
        )

    def _build_observation(
        self,
        *,
        white_move: str,
        black_move: "str | None",
        balance: float,
        status: str,
        reward: float,
        done: bool,
    ) -> ChessObservation:
        """Assemble an observation of the current board.

        ``legal_moves`` is empty when ``done`` is True; otherwise it lists the
        side-to-move's legal moves in UCI form.
        """
        return ChessObservation(
            board_fen=self._board.fen(),
            legal_moves=[] if done else [m.uci() for m in self._board.legal_moves],
            white_move=white_move,
            black_move=black_move,
            material_balance=balance,
            game_status=status,
            captured_pieces=list(self._captured_pieces),  # copy: don't leak internal list
            done=done,
            reward=reward,
        )

    def _track_capture(self, move: chess.Move) -> None:
        """Record the symbol of the piece captured by ``move``, if any.

        Must be called *before* ``move`` is pushed. Handles en passant, where
        the captured pawn is not on the destination square.
        """
        board = self._board
        # is_capture() only counts enemy pieces (plus en passant), so a
        # castling move written king-takes-own-rook (e.g. "e1h1") is not
        # mistaken for capturing the mover's own rook.
        if not board.is_capture(move):
            return
        if board.is_en_passant(move):
            # The victim belongs to the side that is NOT moving; keep its colour
            # so the symbol's case matches ordinary captures ('P' vs 'p').
            captured = chess.Piece(chess.PAWN, not board.turn)
        else:
            captured = board.piece_at(move.to_square)
        self._captured_pieces.append(captured.symbol())

    def _material_balance(self) -> float:
        """Return White material minus Black material, valued by ``PIECE_VALUES``.

        Piece types missing from ``PIECE_VALUES`` (the king) count as 0.
        """
        board = self._board
        return float(
            sum(
                value
                * (
                    len(board.pieces(piece_type, chess.WHITE))
                    - len(board.pieces(piece_type, chess.BLACK))
                )
                for piece_type, value in PIECE_VALUES.items()
            )
        )

    def _get_game_status(self) -> str:
        """Classify the current position.

        Returns one of the terminal statuses ("checkmate", "stalemate",
        "draw_insufficient", "draw_fifty", "draw_repetition"), or the
        non-terminal "check" / "ongoing". Checks run in priority order, so
        checkmate wins over any simultaneous draw condition.
        """
        b = self._board
        if b.is_checkmate():
            return "checkmate"
        if b.is_stalemate():
            return "stalemate"
        if b.is_insufficient_material():
            return "draw_insufficient"
        if b.is_fifty_moves():
            return "draw_fifty"
        if b.is_repetition():
            return "draw_repetition"
        if b.is_check():
            return "check"
        return "ongoing"

    @staticmethod
    def _is_terminal_status(status: str) -> bool:
        """Return True if ``status`` (from ``_get_game_status``) ends the episode."""
        return status in ("checkmate", "stalemate", "draw_insufficient",
                          "draw_fifty", "draw_repetition")

    def _terminal_reward(self, status: str) -> float:
        """Return the terminal bonus from White's perspective.

        +1 if White delivered checkmate, -1 if White is checkmated, 0 for
        every draw status.
        """
        if status == "checkmate":
            # After checkmate the side to move is the side that was mated.
            if self._board.turn == chess.BLACK:
                return 1.0
            else:
                return -1.0
        return 0.0

    @property
    def state(self) -> State:
        """Episode bookkeeping (episode id and number of agent steps taken)."""
        return self._state
| |
|