//! Batch RL environment — owns all game state in Rust. //! //! Supports random opponents (pure Rust) and UCI engine opponents //! (via the `ruci` crate with `engine-sync` feature). use rand_chacha::ChaCha8Rng; use rand::SeedableRng; use crate::board::GameState; use crate::types::Termination; /// Per-game metadata tracked alongside the Rust GameState. struct GameMeta { agent_is_white: bool, terminated: bool, forfeited: bool, outcome_reward: f32, agent_plies: u32, termination_code: i8, // -1 if not terminated } impl GameMeta { fn new(agent_is_white: bool) -> Self { Self { agent_is_white, terminated: false, forfeited: false, outcome_reward: 0.0, agent_plies: 0, termination_code: -1, } } } /// Batch RL environment owning N concurrent chess games. pub struct BatchRLEnv { games: Vec, meta: Vec, n_games: usize, max_ply: usize, rng: ChaCha8Rng, } impl BatchRLEnv { pub fn new(n_games: usize, max_ply: usize, seed: u64) -> Self { Self { games: Vec::with_capacity(n_games), meta: Vec::with_capacity(n_games), n_games, max_ply, rng: ChaCha8Rng::seed_from_u64(seed), } } /// Reset all games. First half: agent=white, second half: agent=black. pub fn reset(&mut self) { self.games.clear(); self.meta.clear(); let n_white = self.n_games / 2; for i in 0..self.n_games { self.games.push(GameState::new()); self.meta.push(GameMeta::new(i < n_white)); } } pub fn n_games(&self) -> usize { self.n_games } pub fn all_terminated(&self) -> bool { self.meta.iter().all(|m| m.terminated) } /// Game indices where it's the agent's turn and game is not over. pub fn active_agent_games(&self) -> Vec { (0..self.n_games) .filter(|&i| { let m = &self.meta[i]; if m.terminated { return false; } let white_to_move = self.games[i].is_white_to_move(); white_to_move == m.agent_is_white }) .map(|i| i as u32) .collect() } /// Game indices where it's the opponent's turn and game is not over. pub fn active_opponent_games(&self) -> Vec { (0..self.n_games) .filter(|&i| { let m = &self.meta[i]; if m.terminated { return false; } let white_to_move = self.games[i].is_white_to_move(); white_to_move != m.agent_is_white }) .map(|i| i as u32) .collect() } // ------------------------------------------------------------------ // Finalization // ------------------------------------------------------------------ fn finalize(&mut self, gi: usize) { let m = &mut self.meta[gi]; m.terminated = true; if let Some(term) = self.games[gi].check_termination(self.max_ply) { m.termination_code = term.as_u8() as i8; match term { Termination::Checkmate => { // The side to move is checkmated (lost). let loser_is_white = self.games[gi].is_white_to_move(); if loser_is_white == m.agent_is_white { m.outcome_reward = -1.0; // agent lost } else { m.outcome_reward = 1.0; // agent won } } // All other terminations are draws _ => { m.outcome_reward = 0.0; } } } } // ------------------------------------------------------------------ // Side-agnostic moves (for autoregressive generation) // ------------------------------------------------------------------ /// Apply moves to the specified games regardless of whose turn it is. /// Returns (legality_flags, termination_codes). /// Termination code is -1 if the game continues after the move. pub fn apply_moves(&mut self, game_indices: &[u32], tokens: &[u16]) -> (Vec, Vec) { let mut flags = Vec::with_capacity(game_indices.len()); let mut term_codes = Vec::with_capacity(game_indices.len()); for (&gi, &token) in game_indices.iter().zip(tokens.iter()) { let gi = gi as usize; if self.meta[gi].terminated { flags.push(false); term_codes.push(self.meta[gi].termination_code); continue; } if self.games[gi].make_move(token).is_err() { self.meta[gi].terminated = true; self.meta[gi].forfeited = true; self.meta[gi].termination_code = -3; // forfeit flags.push(false); term_codes.push(-3); continue; } flags.push(true); if let Some(term) = self.games[gi].check_termination(self.max_ply) { self.meta[gi].terminated = true; self.meta[gi].termination_code = term.as_u8() as i8; term_codes.push(term.as_u8() as i8); } else { term_codes.push(-1); } } (flags, term_codes) } /// Load prefix move sequences into games. Each game replays the given /// moves from the start position. Returns per-game termination codes /// (-1 if still going, >=0 if terminated during prefix). pub fn load_prefixes(&mut self, move_ids: &[u16], lengths: &[u32], n_games: usize, max_ply: usize) -> Vec { let mut term_codes = Vec::with_capacity(n_games); for gi in 0..n_games { let len = lengths[gi] as usize; let mut tc: i8 = -1; for t in 0..len { let token = move_ids[gi * max_ply + t]; if self.games[gi].make_move(token).is_err() { self.meta[gi].terminated = true; self.meta[gi].forfeited = true; self.meta[gi].termination_code = -3; tc = -3; break; } if let Some(term) = self.games[gi].check_termination(self.max_ply) { self.meta[gi].terminated = true; self.meta[gi].termination_code = term.as_u8() as i8; tc = term.as_u8() as i8; break; } } term_codes.push(tc); } term_codes } // ------------------------------------------------------------------ // Agent moves // ------------------------------------------------------------------ /// Apply agent moves to the specified games. Returns legality flags. pub fn apply_agent_moves(&mut self, game_indices: &[u32], tokens: &[u16]) -> Vec { let mut flags = Vec::with_capacity(game_indices.len()); for (&gi, &token) in game_indices.iter().zip(tokens.iter()) { let gi = gi as usize; if self.meta[gi].terminated { flags.push(false); continue; } if self.games[gi].make_move(token).is_err() { self.meta[gi].terminated = true; self.meta[gi].forfeited = true; self.meta[gi].outcome_reward = -1.0; flags.push(false); continue; } flags.push(true); self.meta[gi].agent_plies += 1; if self.games[gi].check_termination(self.max_ply).is_some() { self.finalize(gi); } } flags } // ------------------------------------------------------------------ // Opponent moves (random) // ------------------------------------------------------------------ /// Make random legal moves for all games where it's the opponent's turn. /// Returns indices of games that were acted upon. pub fn apply_random_opponent_moves(&mut self) -> Vec { let opp_games = self.active_opponent_games(); for &gi in &opp_games { let gi = gi as usize; self.games[gi].make_random_move(&mut self.rng); if self.games[gi].check_termination(self.max_ply).is_some() { self.finalize(gi); } } opp_games } // ------------------------------------------------------------------ // Opponent moves (UCI engine token from Python) // ------------------------------------------------------------------ /// Apply opponent moves received from an external UCI engine. /// Does NOT increment agent_plies. Returns legality flags. pub fn apply_opponent_moves(&mut self, game_indices: &[u32], tokens: &[u16]) -> Vec { let mut flags = Vec::with_capacity(game_indices.len()); for (&gi, &token) in game_indices.iter().zip(tokens.iter()) { let gi = gi as usize; if self.meta[gi].terminated { flags.push(false); continue; } if self.games[gi].make_move(token).is_err() { // Engine move was illegal — shouldn't happen. Terminate as draw. self.meta[gi].terminated = true; self.meta[gi].outcome_reward = 0.0; flags.push(false); continue; } flags.push(true); if self.games[gi].check_termination(self.max_ply).is_some() { self.finalize(gi); } } flags } // ------------------------------------------------------------------ // Bulk data extraction // ------------------------------------------------------------------ /// Get legal move token masks for a batch of games. /// Returns flat bool vec of len(game_indices) * vocab_size. pub fn get_legal_token_masks_batch( &self, game_indices: &[u32], vocab_size: usize, ) -> Vec { let b = game_indices.len(); let mut masks = vec![false; b * vocab_size]; for (bi, &gi) in game_indices.iter().enumerate() { for tok in self.games[gi as usize].legal_move_tokens() { let idx = bi * vocab_size + tok as usize; if idx < masks.len() { masks[idx] = true; } } } masks } /// Get legal moves for a batch of games: structured data + dense masks. /// Returns (per_game_structured, flat_dense_masks). /// flat_dense_masks is a flat vec of len(game_indices) * 4096 bools. pub fn get_legal_moves_batch( &self, game_indices: &[u32], ) -> (Vec<(Vec, Vec<(u16, Vec)>)>, Vec) { let b = game_indices.len(); let mut structured = Vec::with_capacity(b); let mut flat_masks = vec![false; b * 4096]; for (bi, &gi) in game_indices.iter().enumerate() { let (indices, promos, mask) = self.games[gi as usize].legal_moves_full(); structured.push((indices, promos)); flat_masks[bi * 4096..(bi + 1) * 4096].copy_from_slice(&mask); } (structured, flat_masks) } /// Get move histories as a flat i64 vec of shape (B, max_ply) + lengths. pub fn get_move_histories(&self, game_indices: &[u32]) -> (Vec, Vec) { let b = game_indices.len(); let mut flat = vec![0i64; b * self.max_ply]; let mut lengths = Vec::with_capacity(b); for (bi, &gi) in game_indices.iter().enumerate() { let hist = self.games[gi as usize].move_history(); let len = hist.len().min(self.max_ply); for t in 0..len { flat[bi * self.max_ply + t] = hist[t] as i64; } lengths.push(len as i32); } (flat, lengths) } /// Get sentinel tokens (first legal move token) for specified games. pub fn get_sentinel_tokens(&self, game_indices: &[u32]) -> Vec { game_indices .iter() .map(|&gi| { let tokens = self.games[gi as usize].legal_move_tokens(); if tokens.is_empty() { 1 } else { tokens[0] } }) .collect() } /// Get FEN strings for specified games (for Stockfish eval). pub fn get_fens(&self, game_indices: &[u32]) -> Vec { game_indices .iter() .map(|&gi| self.games[gi as usize].fen()) .collect() } /// Get UCI position strings for specified games (for UCI engine communication). pub fn get_uci_positions(&self, game_indices: &[u32]) -> Vec { game_indices .iter() .map(|&gi| self.games[gi as usize].uci_position_string()) .collect() } /// Get ply counts for specified games. pub fn get_plies(&self, game_indices: &[u32]) -> Vec { game_indices .iter() .map(|&gi| self.games[gi as usize].ply() as u32) .collect() } /// Per-game outcome data for all N games. /// Returns (terminated, forfeited, outcome_reward, agent_plies, /// termination_codes, agent_is_white). pub fn get_outcomes(&self) -> (Vec, Vec, Vec, Vec, Vec, Vec) { let mut terminated = Vec::with_capacity(self.n_games); let mut forfeited = Vec::with_capacity(self.n_games); let mut rewards = Vec::with_capacity(self.n_games); let mut plies = Vec::with_capacity(self.n_games); let mut codes = Vec::with_capacity(self.n_games); let mut colors = Vec::with_capacity(self.n_games); for m in &self.meta { terminated.push(m.terminated); forfeited.push(m.forfeited); rewards.push(m.outcome_reward); plies.push(m.agent_plies); codes.push(m.termination_code); colors.push(m.agent_is_white); } (terminated, forfeited, rewards, plies, codes, colors) } /// Access a single game's state (for debugging/testing). pub fn game(&self, gi: usize) -> &GameState { &self.games[gi] } pub fn meta(&self, gi: usize) -> (bool, bool, bool, f32, u32, i8) { let m = &self.meta[gi]; ( m.agent_is_white, m.terminated, m.forfeited, m.outcome_reward, m.agent_plies, m.termination_code, ) } } #[cfg(test)] mod tests { use super::*; #[test] fn test_reset_and_active_games() { let mut env = BatchRLEnv::new(4, 256, 42); env.reset(); assert_eq!(env.n_games(), 4); assert!(!env.all_terminated()); // First half (0,1) = agent white, should be active (white moves first) // Second half (2,3) = agent black, opponent moves first let agent = env.active_agent_games(); let opp = env.active_opponent_games(); assert_eq!(agent, vec![0, 1]); assert_eq!(opp, vec![2, 3]); } #[test] fn test_random_opponent_moves() { let mut env = BatchRLEnv::new(4, 256, 42); env.reset(); // Apply random opponent opening moves for black-agent games let acted = env.apply_random_opponent_moves(); assert_eq!(acted, vec![2, 3]); // only the games where opponent goes first // Now all 4 games should be active for the agent let agent = env.active_agent_games(); assert_eq!(agent, vec![0, 1, 2, 3]); } #[test] fn test_fen_export() { let mut env = BatchRLEnv::new(1, 256, 42); env.reset(); let fens = env.get_fens(&[0]); assert!(fens[0].contains("rnbqkbnr")); // starting position } #[test] fn test_legal_moves_batch() { let mut env = BatchRLEnv::new(2, 256, 42); env.reset(); let (structured, masks) = env.get_legal_moves_batch(&[0, 1]); assert_eq!(structured.len(), 2); assert_eq!(masks.len(), 2 * 4096); // Starting position has 20 legal moves assert_eq!(structured[0].0.len(), 20); } }