File size: 9,777 Bytes
d7ecc62 05a2457 d7ecc62 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 | //! Legal move mask computation by replaying game sequences. Spec §7.4.
use rayon::prelude::*;
use crate::board::GameState;
/// Replay games and compute legal move masks at each ply.
/// The label at ply t is the legal moves BEFORE move t has been played —
/// i.e., the moves available to the side about to play move_ids[t].
/// Returns (legal_move_grid, legal_promo_mask) as flat arrays.
pub fn compute_legal_move_masks(
move_ids: &[i16], // [batch * max_ply]
game_lengths: &[i16], // [batch]
max_ply: usize,
) -> (Vec<u64>, Vec<bool>) {
let batch = game_lengths.len();
let mut grids = vec![0u64; batch * max_ply * 64];
let mut promos = vec![false; batch * max_ply * 44 * 4];
// Process each game in parallel
let results: Vec<(Vec<[u64; 64]>, Vec<[[bool; 4]; 44]>)> = (0..batch)
.into_par_iter()
.map(|b| {
let length = game_lengths[b] as usize;
let mut state = GameState::new();
let mut game_grids = Vec::with_capacity(length);
let mut game_promos = Vec::with_capacity(length);
for t in 0..length {
// Record legal moves BEFORE making the move
game_grids.push(state.legal_move_grid());
game_promos.push(state.legal_promo_mask());
let token = move_ids[b * max_ply + t] as u16;
state.make_move(token).expect("Move should be legal during replay");
}
(game_grids, game_promos)
})
.collect();
// Pack into flat arrays
for (b, (game_grids, game_promos)) in results.into_iter().enumerate() {
for (t, grid) in game_grids.iter().enumerate() {
let offset = (b * max_ply + t) * 64;
grids[offset..offset + 64].copy_from_slice(grid);
}
for (t, promo) in game_promos.iter().enumerate() {
let offset = (b * max_ply + t) * 44 * 4;
for pair in 0..44 {
for pt in 0..4 {
promos[offset + pair * 4 + pt] = promo[pair][pt];
}
}
}
}
(grids, promos)
}
/// Replay games and produce a dense (batch, max_ply, vocab_size) bool token mask.
///
/// Fuses game replay with token mask construction — no intermediate bitboard grid.
/// Each position's legal moves are converted directly to token IDs and written
/// into the flat output array. Rayon-parallel over games; entries past a game's
/// length remain false.
pub fn compute_legal_token_masks(
    move_ids: &[i16],     // [batch * max_ply]
    game_lengths: &[i16], // [batch]
    max_ply: usize,
    vocab_size: usize,
) -> Vec<bool> {
    let batch = game_lengths.len();
    let stride_game = max_ply * vocab_size;
    // Zero-initialize output (memset — fast)
    let mut masks = vec![false; batch * stride_game];
    // Games write to disjoint chunks, so parallelism needs no synchronization.
    masks
        .par_chunks_mut(stride_game)
        .enumerate()
        .for_each(|(b, game_mask)| {
            let plies = game_lengths[b] as usize;
            let mut state = GameState::new();
            // One vocab_size-wide row per ply; only the first `plies` rows are filled.
            for (t, row) in game_mask.chunks_mut(vocab_size).take(plies).enumerate() {
                for tok in state.legal_move_tokens() {
                    // Out-of-range tokens are silently skipped (get_mut returns None).
                    if let Some(slot) = row.get_mut(tok as usize) {
                        *slot = true;
                    }
                }
                let move_tok = move_ids[b * max_ply + t] as u16;
                state.make_move(move_tok).expect("Move should be legal during replay");
            }
        });
    masks
}
/// Sparse variant: return flat i64 indices into a (batch, seq_len, vocab_size) tensor.
///
/// Each index encodes `b * seq_len * vocab_size + t * vocab_size + token_id`,
/// ready for direct GPU scatter via `index_fill_`. Output is ~2 MB instead of
/// ~70 MB for the dense version (legal moves are <1% of the vocabulary).
pub fn compute_legal_token_masks_sparse(
    move_ids: &[i16],     // [batch * max_ply]
    game_lengths: &[i16], // [batch]
    max_ply: usize,
    seq_len: usize, // typically max_ply + 1
    vocab_size: usize,
) -> Vec<i64> {
    let batch = game_lengths.len();
    let gathered: Vec<Vec<i64>> = (0..batch)
        .into_par_iter()
        .map(|b| {
            let plies = game_lengths[b] as usize;
            let game_base = (b * seq_len * vocab_size) as i64;
            let mut state = GameState::new();
            // ~32 legal moves per position is a reasonable capacity guess.
            let mut indices = Vec::with_capacity(plies * 32);
            for t in 0..plies {
                let row_base = game_base + (t * vocab_size) as i64;
                indices.extend(
                    state
                        .legal_move_tokens()
                        .into_iter()
                        .map(|tok| tok as usize)
                        .filter(|&ti| ti < vocab_size)
                        .map(|ti| row_base + ti as i64),
                );
                let move_tok = move_ids[b * max_ply + t] as u16;
                state.make_move(move_tok).expect("Move should be legal during replay");
            }
            // At position `plies`, the target is PAD (end of game).
            // Include PAD token in the legal mask so loss is finite.
            if plies < seq_len {
                indices.push(game_base + (plies * vocab_size) as i64); // PAD_TOKEN = 0
            }
            indices
        })
        .collect();
    // Flatten — total size ~288K for a typical batch
    let total = gathered.iter().map(Vec::len).sum();
    let mut flat = Vec::with_capacity(total);
    gathered.into_iter().for_each(|game| flat.extend(game));
    flat
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::batch::generate_training_batch;

    /// Replayed grid/promo labels must agree with the fused batch generator.
    #[test]
    fn test_labels_match_fused() {
        // Generate a batch with fused labels, then recompute via replay and compare
        let batch = generate_training_batch(4, 256, 42);
        let (grids, promos) =
            compute_legal_move_masks(&batch.move_ids, &batch.game_lengths, 256);
        assert_eq!(grids, batch.legal_move_grid, "Replayed grids must match fused grids");
        assert_eq!(promos, batch.legal_promo_mask, "Replayed promos must match fused promos");
    }

    /// compute_legal_token_masks must agree with an independent replay
    /// using legal_move_tokens(), and be all-false past each game's end.
    #[test]
    fn test_token_masks_via_replay() {
        let (batch_size, max_ply, vocab_size) = (8usize, 256usize, 4278usize);
        let batch = generate_training_batch(batch_size, max_ply, 99);
        let token_masks =
            compute_legal_token_masks(&batch.move_ids, &batch.game_lengths, max_ply, vocab_size);
        for b in 0..batch_size {
            let gl = batch.game_lengths[b] as usize;
            let mut state = GameState::new();
            for t in 0..gl {
                let legal_tokens = state.legal_move_tokens();
                let mask_off = (b * max_ply + t) * vocab_size;
                let row = &token_masks[mask_off..mask_off + vocab_size];
                // Every legal token should be marked true
                for &tok in &legal_tokens {
                    assert!(
                        row[tok as usize],
                        "game {b} ply {t}: legal token {tok} not set in mask"
                    );
                }
                // Count of true entries should match number of legal tokens
                let mask_count = row.iter().filter(|&&set| set).count();
                assert_eq!(
                    mask_count, legal_tokens.len(),
                    "game {b} ply {t}: mask has {mask_count} legal tokens but expected {}",
                    legal_tokens.len()
                );
                state.make_move(batch.move_ids[b * max_ply + t] as u16).unwrap();
            }
            // Verify positions beyond game_length are all-false
            for t in gl..max_ply {
                let mask_off = (b * max_ply + t) * vocab_size;
                let any_set = token_masks[mask_off..mask_off + vocab_size]
                    .iter()
                    .any(|&set| set);
                assert!(!any_set, "game {b} ply {t} (past game end): mask should be all-false");
            }
        }
    }

    /// Sparse indices, scattered back into a dense tensor, must reproduce
    /// the dense mask over every in-game ply.
    #[test]
    fn test_sparse_matches_dense() {
        let (batch_size, max_ply, vocab_size) = (8usize, 256usize, 4278usize);
        let seq_len = max_ply + 1;
        let batch = generate_training_batch(batch_size, max_ply, 77);
        let dense =
            compute_legal_token_masks(&batch.move_ids, &batch.game_lengths, max_ply, vocab_size);
        let sparse = compute_legal_token_masks_sparse(
            &batch.move_ids, &batch.game_lengths, max_ply, seq_len, vocab_size,
        );
        // Reconstruct dense from sparse and compare
        let mut reconstructed = vec![false; batch_size * seq_len * vocab_size];
        sparse.iter().for_each(|&idx| reconstructed[idx as usize] = true);
        // Dense uses (B, max_ply, V), sparse uses (B, seq_len, V) layout.
        // Compare the overlapping region.
        for b in 0..batch_size {
            let gl = batch.game_lengths[b] as usize;
            for t in 0..gl {
                let dense_row = (b * max_ply + t) * vocab_size;
                let sparse_row = (b * seq_len + t) * vocab_size;
                for v in 0..vocab_size {
                    let dense_val = dense[dense_row + v];
                    let sparse_val = reconstructed[sparse_row + v];
                    assert_eq!(
                        dense_val, sparse_val,
                        "Mismatch at game {b} ply {t} token {v}"
                    );
                }
            }
        }
    }
}
|