//! Game state wrapper around shakmaty, with our vocabulary mapping. use rand::Rng; use shakmaty::{ Chess, Color, EnPassantMode, Move, MoveList, Piece, Position, Role, Square, }; use shakmaty::fen::Fen; use crate::types::Termination; use crate::vocab; /// Convert our square index (file-major: a1=0, b1=1, ..., h8=63) to shakmaty Square. #[inline] pub fn our_sq_to_shakmaty(sq: u8) -> Square { // Our indexing: file = sq % 8, rank = sq / 8 // shakmaty Square::new(file, rank) expects File and Rank enums // but Square also has from_coords(file, rank) let file = sq % 8; let rank = sq / 8; Square::from_coords( shakmaty::File::new(file as u32), shakmaty::Rank::new(rank as u32), ) } /// Convert shakmaty Square to our square index. #[inline] pub fn shakmaty_sq_to_ours(sq: Square) -> u8 { let file = sq.file() as u8; let rank = sq.rank() as u8; rank * 8 + file } /// Convert a shakmaty Move to our token index. pub fn move_to_token(m: &Move) -> u16 { let (src, dst) = match m { Move::Normal { from, to, .. } => (*from, *to), Move::EnPassant { from, to } => (*from, *to), Move::Castle { king, rook } => { // UCI king-movement notation let king_sq = *king; let rook_sq = *rook; let dst = if rook_sq.file() > king_sq.file() { // Kingside: king goes to g-file Square::from_coords(shakmaty::File::G, king_sq.rank()) } else { // Queenside: king goes to c-file Square::from_coords(shakmaty::File::C, king_sq.rank()) }; (king_sq, dst) } Move::Put { .. } => panic!("Put moves not supported in standard chess"), }; let src_idx = shakmaty_sq_to_ours(src); let dst_idx = shakmaty_sq_to_ours(dst); // Check if this is a promotion if let Move::Normal { promotion: Some(role), .. } = m { let promo_type = match role { Role::Queen => 0, Role::Rook => 1, Role::Bishop => 2, Role::Knight => 3, _ => panic!("Invalid promotion role: {:?}", role), }; vocab::promo_token(src_idx, dst_idx, promo_type) .expect("Promotion move should have a valid promo pair") } else { vocab::base_grid_token(src_idx, dst_idx) } } /// Convert our token index to a shakmaty Move, given the current position. /// Finds the legal move matching the token's (src, dst, promo) decomposition. pub fn token_to_move(pos: &Chess, token: u16) -> Option { // Validate the token is decomposable (not PAD/outcome) vocab::decompose_token(token)?; let legal = pos.legal_moves(); for m in &legal { if move_to_token(m) == token { return Some(m.clone()); } } None } /// Piece encoding for board state extraction. /// 0=empty, 1-6=white P/N/B/R/Q/K, 7-12=black P/N/B/R/Q/K pub fn piece_to_code(piece: Option) -> i8 { match piece { None => 0, Some(p) => { let base = match p.role { Role::Pawn => 1, Role::Knight => 2, Role::Bishop => 3, Role::Rook => 4, Role::Queen => 5, Role::King => 6, }; if p.color == Color::White { base } else { base + 6 } } } } /// Full game state for replaying and analysis. #[derive(Clone)] pub struct GameState { pos: Chess, move_history: Vec, // tokens position_hashes: Vec, halfmove_clock: u32, } impl GameState { pub fn new() -> Self { let pos = Chess::default(); let hash = Self::position_hash(&pos); Self { pos, move_history: Vec::new(), position_hashes: vec![hash], halfmove_clock: 0, } } /// Simple position hash for repetition detection. /// Uses the board layout + castling rights + ep square + side to move. fn position_hash(pos: &Chess) -> u64 { use std::hash::{Hash, Hasher}; use std::collections::hash_map::DefaultHasher; let mut hasher = DefaultHasher::new(); // Hash piece placement for sq in Square::ALL { let piece = pos.board().piece_at(sq); piece.hash(&mut hasher); } // Hash side to move pos.turn().hash(&mut hasher); // Hash castling rights pos.castles().castling_rights().hash(&mut hasher); // Hash en passant square // shakmaty's legal_ep_square accounts for actual EP capture availability pos.legal_ep_square().hash(&mut hasher); hasher.finish() } pub fn position(&self) -> &Chess { &self.pos } pub fn turn(&self) -> Color { self.pos.turn() } pub fn is_white_to_move(&self) -> bool { self.pos.turn() == Color::White } pub fn ply(&self) -> usize { self.move_history.len() } pub fn move_history(&self) -> &[u16] { &self.move_history } pub fn halfmove_clock(&self) -> u32 { self.halfmove_clock } /// Get legal moves as token indices. pub fn legal_move_tokens(&self) -> Vec { let legal = self.pos.legal_moves(); legal.iter().map(|m| move_to_token(m)).collect() } /// Get legal moves as shakmaty Move objects. pub fn legal_moves(&self) -> MoveList { self.pos.legal_moves() } /// Apply a move given as a token index. pub fn make_move(&mut self, token: u16) -> Result<(), String> { let m = token_to_move(&self.pos, token) .ok_or_else(|| format!("Token {} is not a legal move at ply {}", token, self.ply()))?; // Update halfmove clock let is_pawn = match &m { Move::Normal { role, .. } => *role == Role::Pawn, Move::EnPassant { .. } => true, Move::Castle { .. } => false, Move::Put { .. } => false, }; let is_capture = m.is_capture(); if is_pawn || is_capture { self.halfmove_clock = 0; } else { self.halfmove_clock += 1; } self.pos.play_unchecked(m); self.move_history.push(token); let hash = Self::position_hash(&self.pos); self.position_hashes.push(hash); Ok(()) } /// Check if the game is over. Returns the termination reason if so. pub fn check_termination(&self, max_ply: usize) -> Option { let legal = self.pos.legal_moves(); // Check terminal states (checkmate/stalemate) before ply limit so // that games ending in checkmate on the final ply get the correct // termination code rather than PlyLimit. if legal.is_empty() { if self.pos.is_check() { return Some(Termination::Checkmate); } else { return Some(Termination::Stalemate); } } if self.ply() >= max_ply { return Some(Termination::PlyLimit); } // 75-move rule: 150 halfmoves without capture or pawn push if self.halfmove_clock >= 150 { return Some(Termination::SeventyFiveMoveRule); } // Fivefold repetition if self.is_fivefold_repetition() { return Some(Termination::FivefoldRepetition); } // Insufficient material if self.pos.is_insufficient_material() { return Some(Termination::InsufficientMaterial); } None } pub fn is_fivefold_repetition(&self) -> bool { let current = self.position_hashes.last().unwrap(); let count = self.position_hashes.iter().filter(|h| *h == current).count(); count >= 5 } /// Compute the legal move grid: [u64; 64] where bit d of grid[s] is set /// if a move from square s to square d is legal. pub fn legal_move_grid(&self) -> [u64; 64] { let mut grid = [0u64; 64]; let legal = self.pos.legal_moves(); for m in &legal { let token = move_to_token(m); if let Some((src, dst, _promo)) = vocab::decompose_token(token) { grid[src as usize] |= 1u64 << dst; } } grid } /// Compute the promotion mask: [[bool; 4]; 44] where mask[pair_idx][promo_type] /// is true if that specific promotion is legal. pub fn legal_promo_mask(&self) -> [[bool; 4]; 44] { let mut mask = [[false; 4]; 44]; let legal = self.pos.legal_moves(); for m in &legal { if let Move::Normal { from, to, promotion: Some(role), .. } = m { let src = shakmaty_sq_to_ours(*from); let dst = shakmaty_sq_to_ours(*to); if let Some(pair_idx) = vocab::promo_pair_index(src, dst) { let promo_type = match role { Role::Queen => 0, Role::Rook => 1, Role::Bishop => 2, Role::Knight => 3, _ => continue, }; mask[pair_idx][promo_type] = true; } } } mask } /// Extract board state for probing. pub fn board_array(&self) -> [[i8; 8]; 8] { let mut board = [[0i8; 8]; 8]; for rank in 0..8 { for file in 0..8 { let sq = Square::from_coords( shakmaty::File::new(file as u32), shakmaty::Rank::new(rank as u32), ); board[rank][file] = piece_to_code(self.pos.board().piece_at(sq)); } } board } /// Get castling rights as a 4-bit field: bit 0=K, 1=Q, 2=k, 3=q. pub fn castling_rights_bits(&self) -> u8 { let rights = self.pos.castles().castling_rights(); let mut bits = 0u8; if rights.contains(Square::H1) { bits |= 1; } // White kingside if rights.contains(Square::A1) { bits |= 2; } // White queenside if rights.contains(Square::H8) { bits |= 4; } // Black kingside if rights.contains(Square::A8) { bits |= 8; } // Black queenside bits } /// Get en passant square as our index (0-63), or -1 if none. pub fn ep_square(&self) -> i8 { match self.pos.legal_ep_square() { Some(sq) => shakmaty_sq_to_ours(sq) as i8, None => -1, } } pub fn is_check(&self) -> bool { self.pos.is_check() } /// Get legal moves structured for RL move selection. /// /// Returns (grid_indices, promotions) where: /// - grid_indices: flat src*64+dst for every legal move (promotion pairs deduplicated) /// - promotions: Vec of (pair_idx, legal_promo_types) for each promotion-eligible square pair pub fn legal_moves_structured(&self) -> (Vec, Vec<(u16, Vec)>) { let legal = self.pos.legal_moves(); let mut grid_indices: Vec = Vec::with_capacity(legal.len()); let mut promo_map: Vec<(u16, Vec)> = Vec::new(); let mut seen_promo_flat: u16 = u16::MAX; // track last seen promo flat_idx for dedup for m in &legal { let token = move_to_token(m); let (src, dst, promo) = vocab::decompose_token(token).unwrap(); let flat_idx = (src as u16) * 64 + (dst as u16); if promo == 0 { grid_indices.push(flat_idx); } else { let pair_idx = vocab::promo_pair_index(src, dst).unwrap(); let promo_type = promo - 1; // 1-indexed to 0-indexed if flat_idx != seen_promo_flat { // New promotion pair — add grid index and start new entry grid_indices.push(flat_idx); promo_map.push((pair_idx as u16, vec![promo_type])); seen_promo_flat = flat_idx; } else { // Same pair, add promo type promo_map.last_mut().unwrap().1.push(promo_type); } } } (grid_indices, promo_map) } /// Return a dense 4096-element mask: true if (src*64+dst) has a legal move. pub fn legal_moves_grid_mask(&self) -> [bool; 4096] { let legal = self.pos.legal_moves(); let mut mask = [false; 4096]; for m in &legal { let token = move_to_token(m); let (src, dst, _promo) = vocab::decompose_token(token).unwrap(); let flat_idx = (src as usize) * 64 + (dst as usize); mask[flat_idx] = true; } mask } /// Get all legal move data in a single pass: structured moves + dense grid mask. /// /// Computes `legal_moves()` once and derives both structured data (for promo /// handling) and the dense 4096-bool mask (for softmax masking). pub fn legal_moves_full(&self) -> (Vec, Vec<(u16, Vec)>, [bool; 4096]) { let legal = self.pos.legal_moves(); let mut grid_indices: Vec = Vec::with_capacity(legal.len()); let mut promo_map: Vec<(u16, Vec)> = Vec::new(); let mut seen_promo_flat: u16 = u16::MAX; let mut mask = [false; 4096]; for m in &legal { let token = move_to_token(m); let (src, dst, promo) = vocab::decompose_token(token).unwrap(); let flat_idx = (src as u16) * 64 + (dst as u16); mask[flat_idx as usize] = true; if promo == 0 { grid_indices.push(flat_idx); } else { let pair_idx = vocab::promo_pair_index(src, dst).unwrap(); let promo_type = promo - 1; if flat_idx != seen_promo_flat { grid_indices.push(flat_idx); promo_map.push((pair_idx as u16, vec![promo_type])); seen_promo_flat = flat_idx; } else { promo_map.last_mut().unwrap().1.push(promo_type); } } } (grid_indices, promo_map, mask) } /// Apply a move and return its UCI string. Returns Err if illegal. pub fn make_move_uci(&mut self, token: u16) -> Result { let uci = vocab::token_to_uci(token) .ok_or_else(|| format!("Token {} has no UCI representation", token))?; self.make_move(token)?; Ok(uci) } /// Get the UCI position string for engine communication. /// Returns "position startpos" or "position startpos moves e2e4 e7e5 ..." pub fn uci_position_string(&self) -> String { if self.move_history.is_empty() { return "position startpos".to_string(); } let mut s = String::with_capacity(24 + self.move_history.len() * 6); s.push_str("position startpos moves"); for &token in &self.move_history { s.push(' '); s.push_str(&vocab::token_to_uci(token).unwrap()); } s } /// Get the FEN string for the current position. pub fn fen(&self) -> String { let setup = self.pos.to_setup(EnPassantMode::Legal); let fen = Fen::try_from(setup).expect("valid position should produce valid FEN"); fen.to_string() } /// Pick a random legal move, apply it, and return the token. /// Returns None if no legal moves (game is over). pub fn make_random_move(&mut self, rng: &mut impl Rng) -> Option { let legal = self.pos.legal_moves(); if legal.is_empty() { return None; } let idx = rng.gen_range(0..legal.len()); let m = &legal[idx]; let token = move_to_token(m); // We know the move is legal, so this should always succeed self.make_move(token).ok(); Some(token) } /// Create a GameState by replaying a sequence of move tokens from the starting position. /// Returns an error if any token is invalid or illegal. pub fn from_move_tokens(tokens: &[u16]) -> Result { let mut state = Self::new(); for (i, &token) in tokens.iter().enumerate() { state.make_move(token).map_err(|e| format!("ply {}: {}", i, e))?; } Ok(state) } /// Play out a random game from the current position to completion. /// Returns the termination type. pub fn play_random_to_end(&mut self, rng: &mut impl Rng, max_ply: usize) -> Termination { loop { if let Some(term) = self.check_termination(max_ply) { return term; } if self.make_random_move(rng).is_none() { return Termination::Stalemate; } } } } #[cfg(test)] mod tests { use super::*; #[test] fn test_square_conversion_roundtrip() { for i in 0..64u8 { let sq = our_sq_to_shakmaty(i); assert_eq!(shakmaty_sq_to_ours(sq), i, "Roundtrip failed for {}", i); } } #[test] fn test_initial_legal_moves() { let state = GameState::new(); let tokens = state.legal_move_tokens(); // Starting position has 20 legal moves assert_eq!(tokens.len(), 20, "Starting position should have 20 legal moves"); } #[test] fn test_make_move() { let mut state = GameState::new(); // e2e4: src=e2=12, dst=e4=28 let token = vocab::base_grid_token(12, 28); state.make_move(token).unwrap(); assert_eq!(state.ply(), 1); assert_eq!(state.turn(), Color::Black); } #[test] fn test_legal_move_grid() { let state = GameState::new(); let grid = state.legal_move_grid(); // Count total legal moves from grid let total: u32 = grid.iter().map(|g| g.count_ones()).sum(); assert_eq!(total, 20); } #[test] fn test_castling_token() { // Verify e1g1 maps correctly let src = shakmaty_sq_to_ours(Square::E1); let dst = shakmaty_sq_to_ours(Square::G1); assert_eq!(src, 4); // e1 assert_eq!(dst, 6); // g1 let token = vocab::base_grid_token(src, dst); let uci = vocab::token_to_uci(token).unwrap(); assert_eq!(uci, "e1g1"); } }