# chess/server/my_game_environment.py
# Uploaded by SadyaMeta using huggingface_hub (revision c755ba9, verified).
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Chess Environment Implementation.
An RL agent plays White against a random bot (Black).
Reward is shaped: per-move material delta + terminal win/loss/draw bonus.
"""
import random
from typing import List
from uuid import uuid4
import chess
from openenv_core.env_server.interfaces import Environment
from openenv_core.env_server.types import State
from models import ChessAction, ChessObservation
# Material weights used for reward shaping: P=1, N=3, B=3, R=5, Q=9.
# The king deliberately has no entry, so it contributes nothing to the
# material balance.
PIECE_VALUES = {
    chess.PAWN: 1,
    chess.KNIGHT: 3,
    chess.BISHOP: 3,
    chess.ROOK: 5,
    chess.QUEEN: 9,
}
class ChessEnvironment(Environment):
    """
    Chess environment where an RL agent (White) plays against a random bot (Black).

    Every call to ``step`` applies White's move and, unless that move ends
    the game, one uniformly random legal reply for Black.

    Reward shaping:
    - material_delta: change in (White material - Black material) each step
    - terminal bonus: +1 for win, -1 for loss, 0 for draw
    """

    def __init__(self):
        """Create a fresh board, a new episode state and an empty capture log."""
        self._board = chess.Board()
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._captured_pieces: List[str] = []

    def reset(self) -> ChessObservation:
        """Start a new episode and return the initial observation.

        The board is returned to the starting position, a new episode id is
        generated, the step counter is zeroed and the capture log is cleared.
        """
        self._board = chess.Board()
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._captured_pieces = []
        return ChessObservation(
            board_fen=self._board.fen(),
            legal_moves=[m.uci() for m in self._board.legal_moves],
            white_move="",
            black_move=None,
            material_balance=0.0,
            game_status="ongoing",
            captured_pieces=[],
            done=False,
            reward=0.0,
        )

    def step(self, action: ChessAction) -> ChessObservation:  # type: ignore[override]
        """Play White's move, then a random reply for Black if the game goes on.

        Args:
            action: Carries White's move as a UCI string in ``action.move``.

        Returns:
            The observation after the step. ``reward`` is the change in
            material balance over the whole step plus the terminal bonus when
            the game has ended; ``black_move`` is ``None`` when White's move
            ended the game.

        Raises:
            ValueError: If ``action.move`` is not a parseable UCI string or is
                not legal in the current position.

        NOTE: only move legality is checked here. After a draw by the
        fifty-move rule or repetition the board may still offer legal moves,
        so callers are expected to ``reset()`` once ``done`` is reported.
        """
        # --- validate ---
        try:
            move = chess.Move.from_uci(action.move)
        except ValueError as exc:
            # chess.InvalidMoveError is a ValueError subclass that only exists
            # in newer python-chess releases; catching ValueError alone covers
            # it without touching an attribute older releases do not have.
            raise ValueError(f"Invalid UCI string: {action.move!r}") from exc
        if move not in self._board.legal_moves:
            raise ValueError(
                f"Illegal move: {action.move!r}. "
                f"Legal moves: {[m.uci() for m in self._board.legal_moves]}"
            )

        self._state.step_count += 1
        balance_before = self._material_balance()

        # --- White's move ---
        self._track_capture(move)  # must inspect the board before the push
        self._board.push(move)
        white_uci = action.move
        status = self._get_game_status()

        # --- Black's move (random), only if White did not end the game ---
        black_uci = None
        if not self._is_terminal_status(status):
            # A non-terminal status rules out checkmate and stalemate, so
            # Black has at least one legal move to choose from.
            black_move = random.choice(list(self._board.legal_moves))
            self._track_capture(black_move)
            self._board.push(black_move)
            black_uci = black_move.uci()
            status = self._get_game_status()

        return self._build_observation(white_uci, black_uci, status, balance_before)

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    def _build_observation(self, white_uci, black_uci, status, balance_before):
        """Assemble the post-step observation and shaped reward for ``status``."""
        terminal = self._is_terminal_status(status)
        balance_after = self._material_balance()
        reward = balance_after - balance_before
        if terminal:
            reward += self._terminal_reward(status)
        return ChessObservation(
            board_fen=self._board.fen(),
            legal_moves=[] if terminal else [m.uci() for m in self._board.legal_moves],
            white_move=white_uci,
            black_move=black_uci,
            material_balance=balance_after,
            game_status=status,
            captured_pieces=list(self._captured_pieces),
            done=terminal,
            reward=reward,
        )

    def _track_capture(self, move: chess.Move) -> None:
        """Record the symbol of the piece ``move`` captures, if any.

        Must be called before ``move`` is pushed. Symbols follow
        ``chess.Piece.symbol()``: uppercase for White pieces, lowercase for
        Black pieces.
        """
        board = self._board
        if board.is_en_passant(move):
            # The destination square is empty in en passant; the captured pawn
            # belongs to the side that is NOT on move.
            captured = chess.Piece(chess.PAWN, not board.turn)
        elif board.is_capture(move):
            captured = board.piece_at(move.to_square)
        else:
            # Covers quiet moves and castling, including castling written as
            # king-to-rook-square, where the destination holds a friendly rook.
            captured = None
        if captured is not None:
            self._captured_pieces.append(captured.symbol())

    def _material_balance(self) -> float:
        """Return White material minus Black material, weighted by PIECE_VALUES."""
        board = self._board
        balance = sum(
            value
            * (
                len(board.pieces(piece_type, chess.WHITE))
                - len(board.pieces(piece_type, chess.BLACK))
            )
            for piece_type, value in PIECE_VALUES.items()
        )
        return float(balance)

    def _get_game_status(self) -> str:
        """Classify the current position.

        Checks run in priority order, so a checkmate is reported even when a
        draw condition also holds. ``"check"`` and ``"ongoing"`` are the only
        non-terminal results.
        """
        b = self._board
        if b.is_checkmate():
            return "checkmate"
        if b.is_stalemate():
            return "stalemate"
        if b.is_insufficient_material():
            return "draw_insufficient"
        if b.is_fifty_moves():
            return "draw_fifty"
        if b.is_repetition():
            return "draw_repetition"
        if b.is_check():
            return "check"
        return "ongoing"

    @staticmethod
    def _is_terminal_status(status: str) -> bool:
        """Return True when ``status`` denotes a finished game."""
        return status in (
            "checkmate",
            "stalemate",
            "draw_insufficient",
            "draw_fifty",
            "draw_repetition",
        )

    def _terminal_reward(self, status: str) -> float:
        """Return the terminal bonus from White's perspective for ``status``."""
        if status == "checkmate":
            # whoever is to move is in checkmate -> they lost
            if self._board.turn == chess.BLACK:
                return 1.0  # White delivered checkmate
            return -1.0  # Black delivered checkmate
        # all other terminal states are draws
        return 0.0

    @property
    def state(self) -> State:
        """Current episode state (episode id and step counter)."""
        return self._state