| |
|
|
| use rayon::prelude::*; |
|
|
| use crate::board::GameState; |
|
|
| |
| |
| |
| |
| pub fn compute_legal_move_masks( |
| move_ids: &[i16], |
| game_lengths: &[i16], |
| max_ply: usize, |
| ) -> (Vec<u64>, Vec<bool>) { |
| let batch = game_lengths.len(); |
| let mut grids = vec![0u64; batch * max_ply * 64]; |
| let mut promos = vec![false; batch * max_ply * 44 * 4]; |
|
|
| |
| let results: Vec<(Vec<[u64; 64]>, Vec<[[bool; 4]; 44]>)> = (0..batch) |
| .into_par_iter() |
| .map(|b| { |
| let length = game_lengths[b] as usize; |
| let mut state = GameState::new(); |
| let mut game_grids = Vec::with_capacity(length); |
| let mut game_promos = Vec::with_capacity(length); |
|
|
| for t in 0..length { |
| |
| game_grids.push(state.legal_move_grid()); |
| game_promos.push(state.legal_promo_mask()); |
|
|
| let token = move_ids[b * max_ply + t] as u16; |
| state.make_move(token).expect("Move should be legal during replay"); |
| } |
|
|
| (game_grids, game_promos) |
| }) |
| .collect(); |
|
|
| |
| for (b, (game_grids, game_promos)) in results.into_iter().enumerate() { |
| for (t, grid) in game_grids.iter().enumerate() { |
| let offset = (b * max_ply + t) * 64; |
| grids[offset..offset + 64].copy_from_slice(grid); |
| } |
| for (t, promo) in game_promos.iter().enumerate() { |
| let offset = (b * max_ply + t) * 44 * 4; |
| for pair in 0..44 { |
| for pt in 0..4 { |
| promos[offset + pair * 4 + pt] = promo[pair][pt]; |
| } |
| } |
| } |
| } |
|
|
| (grids, promos) |
| } |
|
|
| |
| |
| |
| |
| |
| pub fn compute_legal_token_masks( |
| move_ids: &[i16], |
| game_lengths: &[i16], |
| max_ply: usize, |
| vocab_size: usize, |
| ) -> Vec<bool> { |
| let batch = game_lengths.len(); |
| let stride_game = max_ply * vocab_size; |
|
|
| |
| let mut masks = vec![false; batch * stride_game]; |
|
|
| |
| masks |
| .par_chunks_mut(stride_game) |
| .enumerate() |
| .for_each(|(b, game_mask)| { |
| let length = game_lengths[b] as usize; |
| let mut state = GameState::new(); |
|
|
| for t in 0..length { |
| let ply_base = t * vocab_size; |
| let tokens = state.legal_move_tokens(); |
| for tok in tokens { |
| let ti = tok as usize; |
| if ti < vocab_size { |
| game_mask[ply_base + ti] = true; |
| } |
| } |
| let move_tok = move_ids[b * max_ply + t] as u16; |
| state.make_move(move_tok).expect("Move should be legal during replay"); |
| } |
| }); |
|
|
| masks |
| } |
|
|
| |
| |
| |
| |
| |
| pub fn compute_legal_token_masks_sparse( |
| move_ids: &[i16], |
| game_lengths: &[i16], |
| max_ply: usize, |
| seq_len: usize, |
| vocab_size: usize, |
| ) -> Vec<i64> { |
| let batch = game_lengths.len(); |
|
|
| let per_game: Vec<Vec<i64>> = (0..batch) |
| .into_par_iter() |
| .map(|b| { |
| let length = game_lengths[b] as usize; |
| let mut state = GameState::new(); |
| let game_base = (b * seq_len * vocab_size) as i64; |
| let mut indices = Vec::with_capacity(length * 32); |
|
|
| for t in 0..length { |
| let ply_base = game_base + (t * vocab_size) as i64; |
| for tok in state.legal_move_tokens() { |
| let ti = tok as usize; |
| if ti < vocab_size { |
| indices.push(ply_base + ti as i64); |
| } |
| } |
| let move_tok = move_ids[b * max_ply + t] as u16; |
| state.make_move(move_tok).expect("Move should be legal during replay"); |
| } |
|
|
| |
| |
| if length < seq_len { |
| let pad_base = game_base + (length * vocab_size) as i64; |
| indices.push(pad_base); |
| } |
|
|
| indices |
| }) |
| .collect(); |
|
|
| |
| let total: usize = per_game.iter().map(|v| v.len()).sum(); |
| let mut flat = Vec::with_capacity(total); |
| for v in per_game { |
| flat.extend(v); |
| } |
| flat |
| } |
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::batch::generate_training_batch;

    /// Replayed grid/promo masks must be byte-identical to the ones produced
    /// by the fused batch generator.
    #[test]
    fn test_labels_match_fused() {
        let batch = generate_training_batch(4, 256, 42);
        let (grids, promos) =
            compute_legal_move_masks(&batch.move_ids, &batch.game_lengths, 256);
        assert_eq!(grids, batch.legal_move_grid, "Replayed grids must match fused grids");
        assert_eq!(promos, batch.legal_promo_mask, "Replayed promos must match fused promos");
    }

    /// Cross-checks the dense token masks against an independent replay:
    /// every legal token is set, the set-count matches exactly, and plies past
    /// the end of each game are all-false.
    #[test]
    fn test_token_masks_via_replay() {
        let batch_size = 8;
        let max_ply = 256;
        let vocab_size = 4278;
        let batch = generate_training_batch(batch_size, max_ply, 99);

        let token_masks = compute_legal_token_masks(
            &batch.move_ids, &batch.game_lengths, max_ply, vocab_size,
        );

        for b in 0..batch_size {
            let gl = batch.game_lengths[b] as usize;
            let mut state = GameState::new();

            for t in 0..gl {
                let legal_tokens = state.legal_move_tokens();
                let mask_off = (b * max_ply + t) * vocab_size;
                let ply_mask = &token_masks[mask_off..mask_off + vocab_size];

                // Every legal token must be flagged in the mask...
                for &tok in &legal_tokens {
                    assert!(
                        ply_mask[tok as usize],
                        "game {b} ply {t}: legal token {tok} not set in mask"
                    );
                }

                // ...and nothing else: the counts must agree exactly.
                let mask_count = ply_mask.iter().filter(|&&set| set).count();
                assert_eq!(
                    mask_count, legal_tokens.len(),
                    "game {b} ply {t}: mask has {mask_count} legal tokens but expected {}",
                    legal_tokens.len()
                );

                let move_tok = batch.move_ids[b * max_ply + t] as u16;
                state.make_move(move_tok).unwrap();
            }

            // Padding plies must stay completely unset.
            for t in gl..max_ply {
                let mask_off = (b * max_ply + t) * vocab_size;
                let any_set = token_masks[mask_off..mask_off + vocab_size]
                    .iter()
                    .any(|&set| set);
                assert!(!any_set, "game {b} ply {t} (past game end): mask should be all-false");
            }
        }
    }

    /// Scattering the sparse indices back to a dense buffer must reproduce
    /// the dense masks on every real ply (the sparse form additionally emits
    /// a padding sentinel, which lies outside the compared range).
    #[test]
    fn test_sparse_matches_dense() {
        let batch_size = 8;
        let max_ply = 256;
        let seq_len = max_ply + 1;
        let vocab_size = 4278;
        let batch = generate_training_batch(batch_size, max_ply, 77);

        let dense = compute_legal_token_masks(
            &batch.move_ids, &batch.game_lengths, max_ply, vocab_size,
        );
        let sparse = compute_legal_token_masks_sparse(
            &batch.move_ids, &batch.game_lengths, max_ply, seq_len, vocab_size,
        );

        // Reconstruct a dense [batch, seq_len, vocab] buffer from the indices.
        let mut reconstructed = vec![false; batch_size * seq_len * vocab_size];
        for &idx in &sparse {
            reconstructed[idx as usize] = true;
        }

        // Note the differing ply strides: dense uses max_ply, sparse seq_len.
        for b in 0..batch_size {
            let gl = batch.game_lengths[b] as usize;
            for t in 0..gl {
                for v in 0..vocab_size {
                    let dense_val = dense[(b * max_ply + t) * vocab_size + v];
                    let sparse_val = reconstructed[(b * seq_len + t) * vocab_size + v];
                    assert_eq!(
                        dense_val, sparse_val,
                        "Mismatch at game {b} ply {t} token {v}"
                    );
                }
            }
        }
    }
}
|
|