| pub mod types; |
| pub mod vocab; |
| pub mod board; |
| pub mod random; |
| pub mod pgn; |
| pub mod uci; |
| pub mod engine_gen; |
| pub mod rl_batch; |
| pub mod batch; |
| pub mod labels; |
| pub mod edgestats; |
| pub mod diagnostic; |
| pub mod validate; |
| pub mod extract; |
|
|
| use pyo3::prelude::*; |
| use pyo3::types::{PyDict, PyList}; |
| use numpy::{PyArray1, PyArray2, PyArray3, PyArray4, PyArrayMethods, PyUntypedArrayMethods}; |
|
|
| |
| #[pyfunction] |
| fn export_move_vocabulary(py: Python<'_>) -> PyResult<PyObject> { |
| let dict = PyDict::new(py); |
|
|
| let sq_names: Vec<&str> = vocab::SQUARE_NAMES.to_vec(); |
| dict.set_item("square_names", sq_names)?; |
|
|
| let promo_pieces: Vec<&str> = vocab::PROMO_PIECES.to_vec(); |
| dict.set_item("promo_pieces", promo_pieces)?; |
|
|
| let pairs = vocab::promo_pairs(); |
| let promo_pairs_list = PyList::empty(py); |
| for &(s, d) in pairs.iter() { |
| promo_pairs_list.append((s, d))?; |
| } |
| dict.set_item("promo_pairs", promo_pairs_list)?; |
|
|
| let (t2m, m2t) = vocab::build_vocab_maps(); |
| let t2m_dict = PyDict::new(py); |
| for (k, v) in &t2m { |
| t2m_dict.set_item(*k, v.as_str())?; |
| } |
| dict.set_item("token_to_move", t2m_dict)?; |
|
|
| let m2t_dict = PyDict::new(py); |
| for (k, v) in &m2t { |
| m2t_dict.set_item(k.as_str(), *v)?; |
| } |
| dict.set_item("move_to_token", m2t_dict)?; |
|
|
| Ok(dict.into()) |
| } |
|
|
| |
| #[pyfunction] |
| #[pyo3(signature = (batch_size, max_ply=256, seed=42))] |
| fn generate_training_batch<'py>( |
| py: Python<'py>, |
| batch_size: usize, |
| max_ply: usize, |
| seed: u64, |
| ) -> PyResult<( |
| Bound<'py, PyArray2<i16>>, |
| Bound<'py, PyArray1<i16>>, |
| Bound<'py, PyArray3<u64>>, |
| Bound<'py, PyArray4<bool>>, |
| Bound<'py, PyArray1<u8>>, |
| )> { |
| let result = py.allow_threads(|| { |
| batch::generate_training_batch(batch_size, max_ply, seed) |
| }); |
|
|
| let move_ids = numpy::PyArray::from_vec(py, result.move_ids) |
| .reshape([batch_size, max_ply])?; |
| let game_lengths = numpy::PyArray::from_vec(py, result.game_lengths); |
| let legal_move_grid = numpy::PyArray::from_vec(py, result.legal_move_grid) |
| .reshape([batch_size, max_ply, 64])?; |
| let legal_promo_mask = numpy::PyArray::from_vec(py, result.legal_promo_mask) |
| .reshape([batch_size, max_ply, 44, 4])?; |
| let termination_codes = numpy::PyArray::from_vec(py, result.termination_codes); |
|
|
| Ok((move_ids, game_lengths, legal_move_grid, legal_promo_mask, termination_codes)) |
| } |
|
|
| |
| #[pyfunction] |
| #[pyo3(signature = (n_white_wins, n_black_wins, max_ply=256, seed=42))] |
| fn generate_checkmate_games<'py>( |
| py: Python<'py>, |
| n_white_wins: usize, |
| n_black_wins: usize, |
| max_ply: usize, |
| seed: u64, |
| ) -> PyResult<( |
| Bound<'py, PyArray2<i16>>, |
| Bound<'py, PyArray1<i16>>, |
| Bound<'py, PyArray1<u8>>, |
| usize, |
| )> { |
| let (result, total_generated) = py.allow_threads(|| { |
| batch::generate_checkmate_games(n_white_wins, n_black_wins, max_ply, seed) |
| }); |
|
|
| let n_games = result.n_games; |
| let move_ids = numpy::PyArray::from_vec(py, result.move_ids) |
| .reshape([n_games, max_ply])?; |
| let game_lengths = numpy::PyArray::from_vec(py, result.game_lengths); |
| let termination_codes = numpy::PyArray::from_vec(py, result.termination_codes); |
|
|
| Ok((move_ids, game_lengths, termination_codes, total_generated)) |
| } |
|
|
| |
| #[pyfunction] |
| #[pyo3(signature = (n_games, max_ply=256, seed=42))] |
| fn generate_checkmate_training_batch<'py>( |
| py: Python<'py>, |
| n_games: usize, |
| max_ply: usize, |
| seed: u64, |
| ) -> PyResult<( |
| Bound<'py, PyArray2<i16>>, |
| Bound<'py, PyArray1<i16>>, |
| Bound<'py, PyArray2<u64>>, |
| Bound<'py, PyArray2<u64>>, |
| usize, |
| )> { |
| let result = py.allow_threads(|| { |
| batch::generate_checkmate_training_batch(n_games, max_ply, seed) |
| }); |
|
|
| let n = result.n_games; |
| let move_ids = numpy::PyArray::from_vec(py, result.move_ids) |
| .reshape([n, max_ply])?; |
| let game_lengths = numpy::PyArray::from_vec(py, result.game_lengths); |
| let checkmate_targets = numpy::PyArray::from_vec(py, result.checkmate_targets) |
| .reshape([n, 64])?; |
| let legal_grids = numpy::PyArray::from_vec(py, result.legal_grids) |
| .reshape([n, 64])?; |
|
|
| Ok((move_ids, game_lengths, checkmate_targets, legal_grids, result.total_generated)) |
| } |
|
|
| |
| #[pyfunction] |
| #[pyo3(signature = (n_games, max_ply=256, seed=42, discard_ply_limit=false))] |
| fn generate_random_games<'py>( |
| py: Python<'py>, |
| n_games: usize, |
| max_ply: usize, |
| seed: u64, |
| discard_ply_limit: bool, |
| ) -> PyResult<( |
| Bound<'py, PyArray2<i16>>, |
| Bound<'py, PyArray1<i16>>, |
| Bound<'py, PyArray1<u8>>, |
| )> { |
| let result = py.allow_threads(|| { |
| if discard_ply_limit { |
| batch::generate_completed_games(n_games, max_ply, seed) |
| } else { |
| batch::generate_random_games(n_games, max_ply, seed) |
| } |
| }); |
|
|
| let move_ids = numpy::PyArray::from_vec(py, result.move_ids) |
| .reshape([n_games, max_ply])?; |
| let game_lengths = numpy::PyArray::from_vec(py, result.game_lengths); |
| let termination_codes = numpy::PyArray::from_vec(py, result.termination_codes); |
|
|
| Ok((move_ids, game_lengths, termination_codes)) |
| } |
|
|
| |
| |
| |
| |
| |
| #[pyfunction] |
| #[pyo3(signature = (batch_size, seq_len=256, seed=42, discard_ply_limit=false))] |
| fn generate_clm_batch<'py>( |
| py: Python<'py>, |
| batch_size: usize, |
| seq_len: usize, |
| seed: u64, |
| discard_ply_limit: bool, |
| ) -> PyResult<( |
| Bound<'py, PyArray2<i16>>, |
| Bound<'py, PyArray2<i16>>, |
| Bound<'py, PyArray2<bool>>, |
| Bound<'py, PyArray2<i16>>, |
| Bound<'py, PyArray1<i16>>, |
| Bound<'py, PyArray1<u8>>, |
| )> { |
| let result = py.allow_threads(|| { |
| batch::generate_clm_batch(batch_size, seq_len, seed, discard_ply_limit) |
| }); |
|
|
| let max_ply = seq_len - 1; |
| let input_ids = numpy::PyArray::from_vec(py, result.input_ids) |
| .reshape([batch_size, seq_len])?; |
| let targets = numpy::PyArray::from_vec(py, result.targets) |
| .reshape([batch_size, seq_len])?; |
| let loss_mask = numpy::PyArray::from_vec(py, result.loss_mask) |
| .reshape([batch_size, seq_len])?; |
| let move_ids = numpy::PyArray::from_vec(py, result.move_ids) |
| .reshape([batch_size, max_ply])?; |
| let game_lengths = numpy::PyArray::from_vec(py, result.game_lengths); |
| let termination_codes = numpy::PyArray::from_vec(py, result.termination_codes); |
|
|
| Ok((input_ids, targets, loss_mask, move_ids, game_lengths, termination_codes)) |
| } |
|
|
| |
| #[pyfunction] |
| fn compute_legal_move_masks<'py>( |
| py: Python<'py>, |
| move_ids: numpy::PyReadonlyArray2<'py, i16>, |
| game_lengths: numpy::PyReadonlyArray1<'py, i16>, |
| ) -> PyResult<( |
| Bound<'py, PyArray3<u64>>, |
| Bound<'py, PyArray4<bool>>, |
| )> { |
| let move_ids_slice = move_ids.as_slice()?; |
| let game_lengths_slice = game_lengths.as_slice()?; |
| let dims = move_ids.shape(); |
| let batch_size = dims[0]; |
| let max_ply = dims[1]; |
|
|
| if game_lengths_slice.len() != batch_size { |
| return Err(pyo3::exceptions::PyValueError::new_err( |
| format!("game_lengths has {} elements but move_ids has {} rows", game_lengths_slice.len(), batch_size), |
| )); |
| } |
| for (i, &gl) in game_lengths_slice.iter().enumerate() { |
| if gl < 0 || (gl as usize) > max_ply { |
| return Err(pyo3::exceptions::PyValueError::new_err( |
| format!("game_lengths[{}] = {} is out of range [0, {}]", i, gl, max_ply), |
| )); |
| } |
| } |
|
|
| let (grids, promos) = py.allow_threads(|| { |
| labels::compute_legal_move_masks(move_ids_slice, game_lengths_slice, max_ply) |
| }); |
|
|
| let grids_arr = numpy::PyArray::from_vec(py, grids) |
| .reshape([batch_size, max_ply, 64])?; |
| let promos_arr = numpy::PyArray::from_vec(py, promos) |
| .reshape([batch_size, max_ply, 44, 4])?; |
|
|
| Ok((grids_arr, promos_arr)) |
| } |
|
|
| |
| |
| #[pyfunction] |
| fn compute_legal_token_masks<'py>( |
| py: Python<'py>, |
| move_ids: numpy::PyReadonlyArray2<'py, i16>, |
| game_lengths: numpy::PyReadonlyArray1<'py, i16>, |
| vocab_size: usize, |
| ) -> PyResult<Bound<'py, PyArray3<bool>>> { |
| let move_ids_slice = move_ids.as_slice()?; |
| let game_lengths_slice = game_lengths.as_slice()?; |
| let dims = move_ids.shape(); |
| let batch_size = dims[0]; |
| let max_ply = dims[1]; |
|
|
| if game_lengths_slice.len() != batch_size { |
| return Err(pyo3::exceptions::PyValueError::new_err( |
| format!("game_lengths has {} elements but move_ids has {} rows", |
| game_lengths_slice.len(), batch_size), |
| )); |
| } |
| for (i, &gl) in game_lengths_slice.iter().enumerate() { |
| if gl < 0 || (gl as usize) > max_ply { |
| return Err(pyo3::exceptions::PyValueError::new_err( |
| format!("game_lengths[{}] = {} is out of range [0, {}]", i, gl, max_ply), |
| )); |
| } |
| } |
|
|
| let masks = py.allow_threads(|| { |
| labels::compute_legal_token_masks( |
| move_ids_slice, game_lengths_slice, max_ply, vocab_size, |
| ) |
| }); |
|
|
| let arr = numpy::PyArray::from_vec(py, masks) |
| .reshape([batch_size, max_ply, vocab_size])?; |
| Ok(arr) |
| } |
|
|
| |
| #[pyfunction] |
| fn compute_legal_token_masks_sparse<'py>( |
| py: Python<'py>, |
| move_ids: numpy::PyReadonlyArray2<'py, i16>, |
| game_lengths: numpy::PyReadonlyArray1<'py, i16>, |
| seq_len: usize, |
| vocab_size: usize, |
| ) -> PyResult<Bound<'py, PyArray1<i64>>> { |
| let move_ids_slice = move_ids.as_slice()?; |
| let game_lengths_slice = game_lengths.as_slice()?; |
| let dims = move_ids.shape(); |
| let batch_size = dims[0]; |
| let max_ply = dims[1]; |
|
|
| if game_lengths_slice.len() != batch_size { |
| return Err(pyo3::exceptions::PyValueError::new_err( |
| format!("game_lengths has {} elements but move_ids has {} rows", |
| game_lengths_slice.len(), batch_size), |
| )); |
| } |
| for (i, &gl) in game_lengths_slice.iter().enumerate() { |
| if gl < 0 || (gl as usize) > max_ply { |
| return Err(pyo3::exceptions::PyValueError::new_err( |
| format!("game_lengths[{}] = {} is out of range [0, {}]", i, gl, max_ply), |
| )); |
| } |
| } |
|
|
| let indices = py.allow_threads(|| { |
| labels::compute_legal_token_masks_sparse( |
| move_ids_slice, game_lengths_slice, max_ply, seq_len, vocab_size, |
| ) |
| }); |
|
|
| Ok(numpy::PyArray::from_vec(py, indices)) |
| } |
|
|
| |
| #[pyfunction] |
| fn extract_board_states<'py>( |
| py: Python<'py>, |
| move_ids: numpy::PyReadonlyArray2<'py, i16>, |
| game_lengths: numpy::PyReadonlyArray1<'py, i16>, |
| ) -> PyResult<( |
| Bound<'py, PyArray4<i8>>, |
| Bound<'py, PyArray2<bool>>, |
| Bound<'py, PyArray2<u8>>, |
| Bound<'py, PyArray2<i8>>, |
| Bound<'py, PyArray2<bool>>, |
| Bound<'py, PyArray2<u8>>, |
| )> { |
| let move_ids_slice = move_ids.as_slice()?; |
| let game_lengths_slice = game_lengths.as_slice()?; |
| let dims = move_ids.shape(); |
| let batch_size = dims[0]; |
| let max_ply = dims[1]; |
|
|
| if game_lengths_slice.len() != batch_size { |
| return Err(pyo3::exceptions::PyValueError::new_err( |
| format!("game_lengths has {} elements but move_ids has {} rows", game_lengths_slice.len(), batch_size), |
| )); |
| } |
| for (i, &gl) in game_lengths_slice.iter().enumerate() { |
| if gl < 0 || (gl as usize) > max_ply { |
| return Err(pyo3::exceptions::PyValueError::new_err( |
| format!("game_lengths[{}] = {} is out of range [0, {}]", i, gl, max_ply), |
| )); |
| } |
| } |
|
|
| let states = py.allow_threads(|| { |
| extract::extract_board_states(move_ids_slice, game_lengths_slice, max_ply) |
| }); |
|
|
| let boards = numpy::PyArray::from_vec(py, states.boards) |
| .reshape([batch_size, max_ply, 8, 8])?; |
| let side_to_move = numpy::PyArray::from_vec(py, states.side_to_move) |
| .reshape([batch_size, max_ply])?; |
| let castling_rights = numpy::PyArray::from_vec(py, states.castling_rights) |
| .reshape([batch_size, max_ply])?; |
| let ep_square = numpy::PyArray::from_vec(py, states.ep_square) |
| .reshape([batch_size, max_ply])?; |
| let is_check = numpy::PyArray::from_vec(py, states.is_check) |
| .reshape([batch_size, max_ply])?; |
| let halfmove_clock = numpy::PyArray::from_vec(py, states.halfmove_clock) |
| .reshape([batch_size, max_ply])?; |
|
|
| Ok((boards, side_to_move, castling_rights, ep_square, is_check, halfmove_clock)) |
| } |
|
|
| |
| #[pyfunction] |
| fn validate_games_py<'py>( |
| py: Python<'py>, |
| move_ids: numpy::PyReadonlyArray2<'py, i16>, |
| game_lengths: numpy::PyReadonlyArray1<'py, i16>, |
| ) -> PyResult<( |
| Bound<'py, PyArray1<bool>>, |
| Bound<'py, PyArray1<i16>>, |
| )> { |
| let move_ids_slice = move_ids.as_slice()?; |
| let game_lengths_slice = game_lengths.as_slice()?; |
| let batch_size = move_ids.shape()[0]; |
| let max_ply = move_ids.shape()[1]; |
|
|
| if game_lengths_slice.len() != batch_size { |
| return Err(pyo3::exceptions::PyValueError::new_err( |
| format!("game_lengths has {} elements but move_ids has {} rows", game_lengths_slice.len(), batch_size), |
| )); |
| } |
|
|
| let (is_valid, first_illegal) = py.allow_threads(|| { |
| validate::validate_games(move_ids_slice, game_lengths_slice, max_ply) |
| }); |
|
|
| let is_valid_arr = numpy::PyArray::from_vec(py, is_valid); |
| let first_illegal_arr = numpy::PyArray::from_vec(py, first_illegal); |
|
|
| Ok((is_valid_arr, first_illegal_arr)) |
| } |
|
|
| |
| #[pyfunction] |
| fn compute_edge_stats_per_ply_py<'py>( |
| py: Python<'py>, |
| move_ids: numpy::PyReadonlyArray2<'py, i16>, |
| game_lengths: numpy::PyReadonlyArray1<'py, i16>, |
| ) -> PyResult<( |
| Bound<'py, PyArray2<u64>>, |
| Bound<'py, PyArray1<u64>>, |
| Bound<'py, PyArray1<u64>>, |
| )> { |
| let move_ids_slice = move_ids.as_slice()?; |
| let game_lengths_slice = game_lengths.as_slice()?; |
| let dims = move_ids.shape(); |
| let batch_size = dims[0]; |
| let max_ply = dims[1]; |
|
|
| if game_lengths_slice.len() != batch_size { |
| return Err(pyo3::exceptions::PyValueError::new_err( |
| format!("game_lengths has {} elements but move_ids has {} rows", game_lengths_slice.len(), batch_size), |
| )); |
| } |
|
|
| let (per_ply, white, black) = py.allow_threads(|| { |
| edgestats::compute_edge_stats_per_ply(move_ids_slice, game_lengths_slice, max_ply) |
| }); |
|
|
| let per_ply_arr = numpy::PyArray::from_vec(py, per_ply) |
| .reshape([batch_size, max_ply])?; |
| let white_arr = numpy::PyArray::from_vec(py, white); |
| let black_arr = numpy::PyArray::from_vec(py, black); |
|
|
| Ok((per_ply_arr, white_arr, black_arr)) |
| } |
|
|
| |
| #[pyfunction] |
| fn compute_edge_stats_per_game_py<'py>( |
| py: Python<'py>, |
| move_ids: numpy::PyReadonlyArray2<'py, i16>, |
| game_lengths: numpy::PyReadonlyArray1<'py, i16>, |
| ) -> PyResult<( |
| Bound<'py, PyArray1<u64>>, |
| Bound<'py, PyArray1<u64>>, |
| )> { |
| let move_ids_slice = move_ids.as_slice()?; |
| let game_lengths_slice = game_lengths.as_slice()?; |
| let max_ply = move_ids.shape()[1]; |
|
|
| let (white, black) = py.allow_threads(|| { |
| edgestats::compute_edge_stats_per_game(move_ids_slice, game_lengths_slice, max_ply) |
| }); |
|
|
| let white_arr = numpy::PyArray::from_vec(py, white); |
| let black_arr = numpy::PyArray::from_vec(py, black); |
|
|
| Ok((white_arr, black_arr)) |
| } |
|
|
| |
| #[pyfunction] |
| #[pyo3(signature = (quotas_white, quotas_black, total_games, max_ply=256, seed=42, max_simulated_factor=100.0))] |
| fn generate_diagnostic_sets_py<'py>( |
| py: Python<'py>, |
| quotas_white: numpy::PyReadonlyArray1<'py, i32>, |
| quotas_black: numpy::PyReadonlyArray1<'py, i32>, |
| total_games: usize, |
| max_ply: usize, |
| seed: u64, |
| max_simulated_factor: f64, |
| ) -> PyResult<( |
| Bound<'py, PyArray2<i16>>, |
| Bound<'py, PyArray1<i16>>, |
| Bound<'py, PyArray1<u8>>, |
| Bound<'py, PyArray2<u64>>, |
| Bound<'py, PyArray1<u64>>, |
| Bound<'py, PyArray1<u64>>, |
| Bound<'py, PyArray1<u64>>, |
| Bound<'py, PyArray1<u64>>, |
| Bound<'py, PyArray1<i32>>, |
| Bound<'py, PyArray1<i32>>, |
| )> { |
| let qw_slice = quotas_white.as_slice()?; |
| let qb_slice = quotas_black.as_slice()?; |
|
|
| if qw_slice.len() < 64 { |
| return Err(pyo3::exceptions::PyValueError::new_err( |
| format!("quotas_white has {} elements, expected at least 64", qw_slice.len()), |
| )); |
| } |
| if qb_slice.len() < 64 { |
| return Err(pyo3::exceptions::PyValueError::new_err( |
| format!("quotas_black has {} elements, expected at least 64", qb_slice.len()), |
| )); |
| } |
|
|
| let mut qw = [0i32; 64]; |
| let mut qb = [0i32; 64]; |
| qw.copy_from_slice(&qw_slice[..64]); |
| qb.copy_from_slice(&qb_slice[..64]); |
|
|
| let output = py.allow_threads(|| { |
| diagnostic::generate_diagnostic_sets(&qw, &qb, total_games, max_ply, seed, max_simulated_factor) |
| }); |
|
|
| let n = output.n_games; |
| let move_ids = numpy::PyArray::from_vec(py, output.move_ids) |
| .reshape([n, max_ply])?; |
| let game_lengths = numpy::PyArray::from_vec(py, output.game_lengths); |
| let termination_codes = numpy::PyArray::from_vec(py, output.termination_codes); |
| let per_ply_stats = numpy::PyArray::from_vec(py, output.per_ply_stats) |
| .reshape([n, max_ply])?; |
| let white = numpy::PyArray::from_vec(py, output.white); |
| let black = numpy::PyArray::from_vec(py, output.black); |
| let qa_white = numpy::PyArray::from_vec(py, output.quota_assignment_white); |
| let qa_black = numpy::PyArray::from_vec(py, output.quota_assignment_black); |
| let qf_white = numpy::PyArray::from_vec(py, output.quotas_filled_white); |
| let qf_black = numpy::PyArray::from_vec(py, output.quotas_filled_black); |
|
|
| Ok((move_ids, game_lengths, termination_codes, per_ply_stats, |
| white, black, qa_white, qa_black, qf_white, qf_black)) |
| } |
|
|
| |
| #[pyfunction] |
| fn edge_case_bits(py: Python<'_>) -> PyResult<PyObject> { |
| let bits = edgestats::edge_case_bits(); |
| let dict = PyDict::new(py); |
| for (k, v) in &bits { |
| dict.set_item(k.as_str(), *v)?; |
| } |
| Ok(dict.into()) |
| } |
|
|
| |
| #[pyfunction] |
| fn hello() -> &'static str { |
| "pawn-engine ready" |
| } |
|
|
| |
| |
| |
|
|
| |
| #[pyclass] |
| struct PyGameState { |
| inner: board::GameState, |
| max_ply: usize, |
| } |
|
|
| #[pymethods] |
| impl PyGameState { |
| #[new] |
| #[pyo3(signature = (max_ply=256))] |
| fn new(max_ply: usize) -> Self { |
| Self { |
| inner: board::GameState::new(), |
| max_ply, |
| } |
| } |
|
|
| |
| fn make_move(&mut self, token: u16) -> bool { |
| self.inner.make_move(token).is_ok() |
| } |
|
|
| |
| fn legal_move_tokens(&self) -> Vec<u16> { |
| self.inner.legal_move_tokens() |
| } |
|
|
| |
| fn ply(&self) -> usize { |
| self.inner.ply() |
| } |
|
|
| |
| fn is_white_to_move(&self) -> bool { |
| self.inner.is_white_to_move() |
| } |
|
|
| |
| |
| |
| fn check_termination(&self) -> i8 { |
| match self.inner.check_termination(self.max_ply) { |
| Some(t) => t as i8, |
| None => -1, |
| } |
| } |
|
|
| |
| fn is_game_over(&self) -> bool { |
| self.inner.check_termination(self.max_ply).is_some() |
| } |
|
|
| |
| fn is_checkmate(&self) -> bool { |
| self.inner.check_termination(self.max_ply) == Some(types::Termination::Checkmate) |
| } |
|
|
| |
| fn is_check(&self) -> bool { |
| self.inner.is_check() |
| } |
|
|
| |
| fn move_history(&self) -> Vec<u16> { |
| self.inner.move_history().to_vec() |
| } |
|
|
| |
| |
| |
| fn legal_moves_structured(&self) -> (Vec<u16>, Vec<(u16, Vec<u8>)>) { |
| self.inner.legal_moves_structured() |
| } |
|
|
| |
| fn legal_moves_grid_mask(&self) -> Vec<bool> { |
| self.inner.legal_moves_grid_mask().to_vec() |
| } |
|
|
| |
| |
| fn legal_moves_full(&self) -> (Vec<u16>, Vec<(u16, Vec<u8>)>, Vec<bool>) { |
| let (indices, promos, mask) = self.inner.legal_moves_full(); |
| (indices, promos, mask.to_vec()) |
| } |
|
|
| |
| fn make_move_uci(&mut self, token: u16) -> Option<String> { |
| self.inner.make_move_uci(token).ok() |
| } |
|
|
| |
| |
| fn uci_position(&self) -> String { |
| self.inner.uci_position_string() |
| } |
| } |
|
|
| |
| |
| |
|
|
| |
| |
| #[pyfunction] |
| #[pyo3(signature = (path, max_ply=128, max_games=1000000, min_ply=10))] |
| fn parse_pgn_file<'py>( |
| py: Python<'py>, |
| path: &str, |
| max_ply: usize, |
| max_games: usize, |
| min_ply: usize, |
| ) -> PyResult<( |
| Bound<'py, numpy::PyArray2<i16>>, |
| Bound<'py, numpy::PyArray1<i16>>, |
| usize, |
| )> { |
| let (flat, lengths, n_parsed) = py.allow_threads(|| { |
| pgn::pgn_file_to_tokens(path, max_ply, max_games, min_ply) |
| }); |
|
|
| let n = lengths.len(); |
| let move_ids = numpy::PyArray::from_vec(py, flat) |
| .reshape([n, max_ply])?; |
| let lengths_arr = numpy::PyArray::from_vec(py, lengths); |
|
|
| Ok((move_ids, lengths_arr, n_parsed)) |
| } |
|
|
| |
| |
| #[pyfunction] |
| #[pyo3(signature = (games, max_ply=256))] |
| fn pgn_to_tokens<'py>( |
| py: Python<'py>, |
| games: Vec<Vec<String>>, |
| max_ply: usize, |
| ) -> PyResult<( |
| Bound<'py, numpy::PyArray2<i16>>, |
| Bound<'py, numpy::PyArray1<i16>>, |
| )> { |
| let n = games.len(); |
| let (flat, lengths) = py.allow_threads(|| { |
| let refs: Vec<Vec<&str>> = games |
| .iter() |
| .map(|g| g.iter().map(|s| s.as_str()).collect()) |
| .collect(); |
| pgn::batch_san_to_tokens(&refs, max_ply) |
| }); |
|
|
| let move_ids = numpy::PyArray::from_vec(py, flat) |
| .reshape([n, max_ply])?; |
| let lengths_arr = numpy::PyArray::from_vec(py, lengths); |
|
|
| Ok((move_ids, lengths_arr)) |
| } |
|
|
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| #[pyfunction] |
| #[pyo3(signature = (path, max_ply=256, max_games=1_000_000, min_ply=10))] |
| fn parse_uci_file<'py>( |
| py: Python<'py>, |
| path: &str, |
| max_ply: usize, |
| max_games: usize, |
| min_ply: usize, |
| ) -> PyResult<( |
| Bound<'py, numpy::PyArray2<i16>>, |
| Bound<'py, numpy::PyArray1<i16>>, |
| usize, |
| )> { |
| let (flat, lengths, n_parsed) = py.allow_threads(|| { |
| uci::uci_file_to_tokens(path, max_ply, max_games, min_ply) |
| }); |
|
|
| let n = lengths.len(); |
| let move_ids = numpy::PyArray::from_vec(py, flat) |
| .reshape([n, max_ply])?; |
| let lengths_arr = numpy::PyArray::from_vec(py, lengths); |
|
|
| Ok((move_ids, lengths_arr, n_parsed)) |
| } |
|
|
| |
| |
| |
| |
| #[pyfunction] |
| #[pyo3(signature = (games, max_ply=256))] |
| fn uci_to_tokens<'py>( |
| py: Python<'py>, |
| games: Vec<Vec<String>>, |
| max_ply: usize, |
| ) -> PyResult<( |
| Bound<'py, numpy::PyArray2<i16>>, |
| Bound<'py, numpy::PyArray1<i16>>, |
| )> { |
| let n = games.len(); |
| let (flat, lengths) = py.allow_threads(|| { |
| let mut flat = vec![0i16; n * max_ply]; |
| let mut lengths = Vec::with_capacity(n); |
| |
| let results: Vec<(Vec<u16>, usize)> = games |
| .iter() |
| .map(|g| { |
| let refs: Vec<&str> = g.iter().map(|s| s.as_str()).collect(); |
| uci::uci_moves_to_tokens(&refs, max_ply) |
| }) |
| .collect(); |
| for (gi, (tokens, n_valid)) in results.iter().enumerate() { |
| for (t, &tok) in tokens.iter().enumerate() { |
| flat[gi * max_ply + t] = tok as i16; |
| } |
| lengths.push(*n_valid as i16); |
| } |
| (flat, lengths) |
| }); |
|
|
| let move_ids = numpy::PyArray::from_vec(py, flat) |
| .reshape([n, max_ply])?; |
| let lengths_arr = numpy::PyArray::from_vec(py, lengths); |
|
|
| Ok((move_ids, lengths_arr)) |
| } |
|
|
| |
| |
| |
| |
| #[pyfunction] |
| fn pgn_to_uci(py: Python<'_>, games: Vec<Vec<String>>) -> PyResult<Vec<Vec<String>>> { |
| let results = py.allow_threads(|| { |
| let refs: Vec<Vec<&str>> = games |
| .iter() |
| .map(|g| g.iter().map(|s| s.as_str()).collect()) |
| .collect(); |
| uci::batch_san_to_uci(&refs) |
| }); |
|
|
| Ok(results.into_iter().map(|(uci, _)| uci).collect()) |
| } |
|
|
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| #[pyfunction] |
| #[pyo3(signature = ( |
| engine_path, |
| nodes = 1, |
| total_games = 100_000, |
| n_workers = 8, |
| base_seed = 10_000, |
| temperature = 1.0, |
| multi_pv = 5, |
| sample_plies = 999, |
| hash_mb = 16, |
| max_ply = 500, |
| ))] |
| fn generate_engine_games_py( |
| py: Python<'_>, |
| engine_path: &str, |
| nodes: u32, |
| total_games: u32, |
| n_workers: u32, |
| base_seed: u64, |
| temperature: f64, |
| multi_pv: u32, |
| sample_plies: u32, |
| hash_mb: u32, |
| max_ply: u32, |
| ) -> PyResult<PyObject> { |
| let results = py.allow_threads(|| { |
| engine_gen::generate_engine_games( |
| engine_path, nodes, total_games, n_workers, base_seed, |
| temperature, multi_pv, sample_plies, hash_mb, max_ply, |
| ) |
| }); |
|
|
| let dict = PyDict::new(py); |
|
|
| let uci: Vec<String> = results.iter().map(|r| r.uci.clone()).collect(); |
| let result: Vec<String> = results.iter().map(|r| r.result.clone()).collect(); |
| let n_ply: Vec<u16> = results.iter().map(|r| r.n_ply).collect(); |
|
|
| |
| let base = total_games / n_workers; |
| let remainder = total_games % n_workers; |
| let mut worker_ids: Vec<u16> = Vec::with_capacity(results.len()); |
| let mut seeds: Vec<u64> = Vec::with_capacity(results.len()); |
| for i in 0..n_workers { |
| let games = base + if i < remainder { 1 } else { 0 }; |
| for _ in 0..games { |
| worker_ids.push(i as u16); |
| seeds.push(base_seed + i as u64); |
| } |
| } |
|
|
| dict.set_item("uci", uci)?; |
| dict.set_item("result", result)?; |
| dict.set_item("n_ply", n_ply)?; |
| dict.set_item("worker_id", worker_ids)?; |
| dict.set_item("seed", seeds)?; |
| dict.set_item("nodes", vec![nodes as i32; results.len()])?; |
|
|
| Ok(dict.into()) |
| } |
|
|
| |
| |
| |
|
|
| #[pyclass] |
| struct PyBatchRLEnv { |
| inner: rl_batch::BatchRLEnv, |
| } |
|
|
| #[pymethods] |
| impl PyBatchRLEnv { |
| #[new] |
| #[pyo3(signature = (n_games, max_ply=256, seed=42))] |
| fn new(n_games: usize, max_ply: usize, seed: u64) -> Self { |
| Self { |
| inner: rl_batch::BatchRLEnv::new(n_games, max_ply, seed), |
| } |
| } |
|
|
| fn reset(&mut self) { |
| self.inner.reset(); |
| } |
|
|
| fn n_games(&self) -> usize { |
| self.inner.n_games() |
| } |
|
|
| fn all_terminated(&self) -> bool { |
| self.inner.all_terminated() |
| } |
|
|
| fn active_agent_games(&self) -> Vec<u32> { |
| self.inner.active_agent_games() |
| } |
|
|
| fn active_opponent_games(&self) -> Vec<u32> { |
| self.inner.active_opponent_games() |
| } |
|
|
| |
| fn apply_moves<'py>( |
| &mut self, |
| py: Python<'py>, |
| game_indices: Vec<u32>, |
| tokens: Vec<u16>, |
| ) -> ( |
| Bound<'py, numpy::PyArray1<bool>>, |
| Bound<'py, numpy::PyArray1<i8>>, |
| ) { |
| let (flags, codes) = self.inner.apply_moves(&game_indices, &tokens); |
| ( |
| numpy::PyArray::from_vec(py, flags), |
| numpy::PyArray::from_vec(py, codes), |
| ) |
| } |
|
|
| |
| fn load_prefixes<'py>( |
| &mut self, |
| py: Python<'py>, |
| move_ids: numpy::PyReadonlyArray2<'py, u16>, |
| lengths: numpy::PyReadonlyArray1<'py, u32>, |
| ) -> Bound<'py, numpy::PyArray1<i8>> { |
| let ids = move_ids.as_slice().unwrap(); |
| let lens = lengths.as_slice().unwrap(); |
| let n_games = lens.len(); |
| let max_ply = if n_games > 0 { ids.len() / n_games } else { 0 }; |
| let codes = self.inner.load_prefixes(ids, lens, n_games, max_ply); |
| numpy::PyArray::from_vec(py, codes) |
| } |
|
|
| |
| fn apply_agent_moves(&mut self, game_indices: Vec<u32>, tokens: Vec<u16>) -> Vec<bool> { |
| self.inner.apply_agent_moves(&game_indices, &tokens) |
| } |
|
|
| |
| fn apply_random_opponent_moves(&mut self) -> Vec<u32> { |
| self.inner.apply_random_opponent_moves() |
| } |
|
|
| |
| fn apply_opponent_moves(&mut self, game_indices: Vec<u32>, tokens: Vec<u16>) -> Vec<bool> { |
| self.inner.apply_opponent_moves(&game_indices, &tokens) |
| } |
|
|
| |
| |
| #[pyo3(signature = (game_indices, vocab_size=4278))] |
| fn get_legal_token_masks_batch<'py>( |
| &self, |
| py: Python<'py>, |
| game_indices: Vec<u32>, |
| vocab_size: usize, |
| ) -> PyResult<Bound<'py, numpy::PyArray2<bool>>> { |
| let b = game_indices.len(); |
| let flat = self.inner.get_legal_token_masks_batch(&game_indices, vocab_size); |
| let arr = numpy::PyArray::from_vec(py, flat) |
| .reshape([b, vocab_size])?; |
| Ok(arr) |
| } |
|
|
| |
| |
| |
| fn get_legal_moves_batch<'py>( |
| &self, |
| py: Python<'py>, |
| game_indices: Vec<u32>, |
| ) -> PyResult<( |
| Vec<(Vec<u16>, Vec<(u16, Vec<u8>)>)>, |
| Bound<'py, numpy::PyArray2<bool>>, |
| )> { |
| let b = game_indices.len(); |
| let (structured, flat_masks) = self.inner.get_legal_moves_batch(&game_indices); |
| let mask_array = numpy::PyArray::from_vec(py, flat_masks) |
| .reshape([b, 4096])?; |
| Ok((structured, mask_array)) |
| } |
|
|
| |
| fn get_move_histories<'py>( |
| &self, |
| py: Python<'py>, |
| game_indices: Vec<u32>, |
| ) -> PyResult<( |
| Bound<'py, numpy::PyArray2<i64>>, |
| Bound<'py, numpy::PyArray1<i32>>, |
| )> { |
| let b = game_indices.len(); |
| let (flat, lengths) = self.inner.get_move_histories(&game_indices); |
| let cols = if b > 0 { flat.len() / b } else { 0 }; |
| let move_ids = numpy::PyArray::from_vec(py, flat) |
| .reshape([b, cols])?; |
| let lengths_arr = numpy::PyArray::from_vec(py, lengths); |
| Ok((move_ids, lengths_arr)) |
| } |
|
|
| |
| fn get_sentinel_tokens(&self, game_indices: Vec<u32>) -> Vec<u16> { |
| self.inner.get_sentinel_tokens(&game_indices) |
| } |
|
|
| |
| fn get_fens(&self, game_indices: Vec<u32>) -> Vec<String> { |
| self.inner.get_fens(&game_indices) |
| } |
|
|
| |
| fn get_uci_positions(&self, game_indices: Vec<u32>) -> Vec<String> { |
| self.inner.get_uci_positions(&game_indices) |
| } |
|
|
| |
| fn get_plies(&self, game_indices: Vec<u32>) -> Vec<u32> { |
| self.inner.get_plies(&game_indices) |
| } |
|
|
| |
| |
| fn get_outcomes<'py>( |
| &self, |
| py: Python<'py>, |
| ) -> ( |
| Bound<'py, numpy::PyArray1<bool>>, |
| Bound<'py, numpy::PyArray1<bool>>, |
| Bound<'py, numpy::PyArray1<f32>>, |
| Bound<'py, numpy::PyArray1<u32>>, |
| Bound<'py, numpy::PyArray1<i8>>, |
| Bound<'py, numpy::PyArray1<bool>>, |
| ) { |
| let (terminated, forfeited, rewards, plies, codes, colors) = self.inner.get_outcomes(); |
| ( |
| numpy::PyArray::from_vec(py, terminated), |
| numpy::PyArray::from_vec(py, forfeited), |
| numpy::PyArray::from_vec(py, rewards), |
| numpy::PyArray::from_vec(py, plies), |
| numpy::PyArray::from_vec(py, codes), |
| numpy::PyArray::from_vec(py, colors), |
| ) |
| } |
|
|
| |
| fn get_game_info(&self, game_idx: u32) -> (bool, bool, bool, f32, u32, i8) { |
| self.inner.meta(game_idx as usize) |
| } |
| } |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| #[pyfunction] |
| #[pyo3(signature = (n_games=1000, max_ply=255, n_rollouts=32, sample_rate=0.01, seed=77777))] |
| fn compute_accuracy_ceiling_py( |
| py: Python<'_>, |
| n_games: usize, |
| max_ply: usize, |
| n_rollouts: usize, |
| sample_rate: f64, |
| seed: u64, |
| ) -> PyResult<PyObject> { |
| let results = py.allow_threads(|| { |
| random::compute_accuracy_ceiling(n_games, max_ply, n_rollouts, sample_rate, seed) |
| }); |
|
|
| let n = results.len(); |
| let mut uncond_sum = 0.0f64; |
| let mut cond_sum = 0.0f64; |
| let mut naive_cond_sum = 0.0f64; |
|
|
| |
| let mut plies = Vec::with_capacity(n); |
| let mut game_lengths = Vec::with_capacity(n); |
| let mut n_legals = Vec::with_capacity(n); |
| let mut unconditionals = Vec::with_capacity(n); |
| let mut conditionals = Vec::with_capacity(n); |
| let mut naive_conditionals = Vec::with_capacity(n); |
| let mut outcomes = Vec::with_capacity(n); |
|
|
| for r in &results { |
| uncond_sum += r.unconditional; |
| cond_sum += r.conditional; |
| naive_cond_sum += r.naive_conditional; |
| plies.push(r.ply); |
| game_lengths.push(r.game_length); |
| n_legals.push(r.n_legal); |
| unconditionals.push(r.unconditional as f32); |
| conditionals.push(r.conditional as f32); |
| naive_conditionals.push(r.naive_conditional as f32); |
| outcomes.push(r.actual_outcome); |
| } |
|
|
| let dict = pyo3::types::PyDict::new(py); |
| dict.set_item("n_positions", n)?; |
| dict.set_item("n_games", n_games)?; |
| dict.set_item("n_rollouts", n_rollouts)?; |
| dict.set_item("sample_rate", sample_rate)?; |
| dict.set_item("unconditional_ceiling", if n > 0 { uncond_sum / n as f64 } else { 0.0 })?; |
| dict.set_item("naive_conditional_ceiling", if n > 0 { naive_cond_sum / n as f64 } else { 0.0 })?; |
| dict.set_item("conditional_ceiling", if n > 0 { cond_sum / n as f64 } else { 0.0 })?; |
|
|
| |
| let np = py.import("numpy")?; |
| dict.set_item("ply", np.call_method1("array", (plies,))?)?; |
| dict.set_item("game_length", np.call_method1("array", (game_lengths,))?)?; |
| dict.set_item("n_legal", np.call_method1("array", (n_legals,))?)?; |
| dict.set_item("unconditional", np.call_method1("array", (unconditionals,))?)?; |
| dict.set_item("naive_conditional", np.call_method1("array", (naive_conditionals,))?)?; |
| dict.set_item("conditional", np.call_method1("array", (conditionals,))?)?; |
| let outcomes_u16: Vec<u16> = outcomes.iter().map(|&x| x as u16).collect(); |
| dict.set_item("outcome", np.call_method1("array", (outcomes_u16,))?)?; |
|
|
| Ok(dict.into()) |
| } |
|
|
| #[pymodule] |
| fn _engine(m: &Bound<'_, PyModule>) -> PyResult<()> { |
| m.add_function(wrap_pyfunction!(hello, m)?)?; |
| m.add_function(wrap_pyfunction!(export_move_vocabulary, m)?)?; |
| m.add_function(wrap_pyfunction!(generate_training_batch, m)?)?; |
| m.add_function(wrap_pyfunction!(generate_random_games, m)?)?; |
| m.add_function(wrap_pyfunction!(generate_clm_batch, m)?)?; |
| m.add_function(wrap_pyfunction!(generate_checkmate_games, m)?)?; |
| m.add_function(wrap_pyfunction!(generate_checkmate_training_batch, m)?)?; |
| m.add_function(wrap_pyfunction!(compute_legal_move_masks, m)?)?; |
| m.add_function(wrap_pyfunction!(compute_legal_token_masks, m)?)?; |
| m.add_function(wrap_pyfunction!(compute_legal_token_masks_sparse, m)?)?; |
| m.add_function(wrap_pyfunction!(extract_board_states, m)?)?; |
| m.add_function(wrap_pyfunction!(validate_games_py, m)?)?; |
| m.add_function(wrap_pyfunction!(compute_edge_stats_per_ply_py, m)?)?; |
| m.add_function(wrap_pyfunction!(compute_edge_stats_per_game_py, m)?)?; |
| m.add_function(wrap_pyfunction!(generate_diagnostic_sets_py, m)?)?; |
| m.add_function(wrap_pyfunction!(edge_case_bits, m)?)?; |
| m.add_class::<PyGameState>()?; |
| m.add_class::<PyBatchRLEnv>()?; |
| m.add_function(wrap_pyfunction!(parse_pgn_file, m)?)?; |
| m.add_function(wrap_pyfunction!(pgn_to_tokens, m)?)?; |
| m.add_function(wrap_pyfunction!(parse_uci_file, m)?)?; |
| m.add_function(wrap_pyfunction!(uci_to_tokens, m)?)?; |
| m.add_function(wrap_pyfunction!(pgn_to_uci, m)?)?; |
| m.add_function(wrap_pyfunction!(generate_engine_games_py, m)?)?; |
| m.add_function(wrap_pyfunction!(compute_accuracy_ceiling_py, m)?)?; |
| Ok(()) |
| } |
|
|