File size: 9,777 Bytes
d7ecc62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
05a2457
 
 
 
 
 
 
d7ecc62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
//! Legal move mask computation by replaying game sequences. Spec §7.4.

use rayon::prelude::*;

use crate::board::GameState;

/// Replay games and compute legal move masks at each ply.
/// The label at ply t is the legal moves BEFORE move t has been played —
/// i.e., the moves available to the side about to play move_ids[t].
/// Returns (legal_move_grid, legal_promo_mask) as flat arrays.
pub fn compute_legal_move_masks(
    move_ids: &[i16],   // [batch * max_ply]
    game_lengths: &[i16], // [batch]
    max_ply: usize,
) -> (Vec<u64>, Vec<bool>) {
    let batch = game_lengths.len();
    let mut grids = vec![0u64; batch * max_ply * 64];
    let mut promos = vec![false; batch * max_ply * 44 * 4];

    // Process each game in parallel
    let results: Vec<(Vec<[u64; 64]>, Vec<[[bool; 4]; 44]>)> = (0..batch)
        .into_par_iter()
        .map(|b| {
            let length = game_lengths[b] as usize;
            let mut state = GameState::new();
            let mut game_grids = Vec::with_capacity(length);
            let mut game_promos = Vec::with_capacity(length);

            for t in 0..length {
                // Record legal moves BEFORE making the move
                game_grids.push(state.legal_move_grid());
                game_promos.push(state.legal_promo_mask());

                let token = move_ids[b * max_ply + t] as u16;
                state.make_move(token).expect("Move should be legal during replay");
            }

            (game_grids, game_promos)
        })
        .collect();

    // Pack into flat arrays
    for (b, (game_grids, game_promos)) in results.into_iter().enumerate() {
        for (t, grid) in game_grids.iter().enumerate() {
            let offset = (b * max_ply + t) * 64;
            grids[offset..offset + 64].copy_from_slice(grid);
        }
        for (t, promo) in game_promos.iter().enumerate() {
            let offset = (b * max_ply + t) * 44 * 4;
            for pair in 0..44 {
                for pt in 0..4 {
                    promos[offset + pair * 4 + pt] = promo[pair][pt];
                }
            }
        }
    }

    (grids, promos)
}

/// Replay games and produce a dense (batch, max_ply, vocab_size) bool token mask.
///
/// Fuses game replay with token mask construction — no intermediate bitboard grid.
/// Each position's legal moves are converted directly to token IDs and written
/// into the flat output array.  Rayon-parallel over games.
pub fn compute_legal_token_masks(
    move_ids: &[i16],      // [batch * max_ply]
    game_lengths: &[i16],  // [batch]
    max_ply: usize,
    vocab_size: usize,
) -> Vec<bool> {
    let stride_game = max_ply * vocab_size;

    // Zero-initialize the whole output up front (memset — fast).
    let mut masks = vec![false; game_lengths.len() * stride_game];

    // par_chunks_mut hands each game an exclusive, non-overlapping slice,
    // so the games can be replayed in parallel with no write contention.
    masks
        .par_chunks_mut(stride_game)
        .enumerate()
        .for_each(|(b, out)| {
            let plies = game_lengths[b] as usize;
            let mut pos = GameState::new();

            for t in 0..plies {
                // Mark every legal token at this ply; tokens outside the
                // vocabulary are silently skipped.
                let row = &mut out[t * vocab_size..(t + 1) * vocab_size];
                for tok in pos.legal_move_tokens() {
                    if let Some(slot) = row.get_mut(tok as usize) {
                        *slot = true;
                    }
                }

                let played = move_ids[b * max_ply + t] as u16;
                pos.make_move(played).expect("Move should be legal during replay");
            }
        });

    masks
}

/// Sparse variant: return flat i64 indices into a (batch, seq_len, vocab_size) tensor.
///
/// Each index encodes `b * seq_len * vocab_size + t * vocab_size + token_id`,
/// ready for direct GPU scatter via `index_fill_`.  Output is ~2 MB instead of
/// ~70 MB for the dense version (legal moves are <1% of the vocabulary).
pub fn compute_legal_token_masks_sparse(
    move_ids: &[i16],      // [batch * max_ply]
    game_lengths: &[i16],  // [batch]
    max_ply: usize,
    seq_len: usize,        // typically max_ply + 1
    vocab_size: usize,
) -> Vec<i64> {
    let batch = game_lengths.len();

    // One index vector per game, replayed in parallel.
    let per_game: Vec<Vec<i64>> = (0..batch)
        .into_par_iter()
        .map(|b| {
            let plies = game_lengths[b] as usize;
            let game_base = (b * seq_len * vocab_size) as i64;
            let mut out = Vec::with_capacity(plies * 32);
            let mut pos = GameState::new();

            for t in 0..plies {
                let row_base = game_base + (t * vocab_size) as i64;
                // Collect the legal tokens at this ply, dropping anything
                // outside the vocabulary range.
                out.extend(
                    pos.legal_move_tokens()
                        .into_iter()
                        .map(|tok| tok as usize)
                        .filter(|&ti| ti < vocab_size)
                        .map(|ti| row_base + ti as i64),
                );

                let played = move_ids[b * max_ply + t] as u16;
                pos.make_move(played).expect("Move should be legal during replay");
            }

            // At position `length`, the target is PAD (end of game).
            // Include PAD token in the legal mask so loss is finite.
            if plies < seq_len {
                out.push(game_base + (plies * vocab_size) as i64); // PAD_TOKEN = 0
            }

            out
        })
        .collect();

    // Serial flatten keeps game order; total is small (~288K for a typical batch).
    let total: usize = per_game.iter().map(Vec::len).sum();
    let mut flat = Vec::with_capacity(total);
    for indices in per_game {
        flat.extend(indices);
    }
    flat
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::batch::generate_training_batch;

    #[test]
    fn test_labels_match_fused() {
        // Recompute grid/promo labels by replay and check they agree with the
        // fused labels emitted during batch generation.
        let batch = generate_training_batch(4, 256, 42);
        let (grids, promos) =
            compute_legal_move_masks(&batch.move_ids, &batch.game_lengths, 256);
        assert_eq!(grids, batch.legal_move_grid, "Replayed grids must match fused grids");
        assert_eq!(promos, batch.legal_promo_mask, "Replayed promos must match fused promos");
    }

    #[test]
    fn test_token_masks_via_replay() {
        // Cross-check compute_legal_token_masks against an independent replay
        // that queries legal_move_tokens() directly.
        let batch_size = 8;
        let max_ply = 256;
        let vocab_size = 4278;
        let batch = generate_training_batch(batch_size, max_ply, 99);

        let token_masks = compute_legal_token_masks(
            &batch.move_ids, &batch.game_lengths, max_ply, vocab_size,
        );

        for b in 0..batch_size {
            let gl = batch.game_lengths[b] as usize;
            let mut state = GameState::new();

            for t in 0..gl {
                let legal_tokens = state.legal_move_tokens();
                let mask_off = (b * max_ply + t) * vocab_size;
                let row = &token_masks[mask_off..mask_off + vocab_size];

                // Every legal token must be set...
                for &tok in &legal_tokens {
                    assert!(
                        token_masks[mask_off + tok as usize],
                        "game {b} ply {t}: legal token {tok} not set in mask"
                    );
                }

                // ...and nothing else: the popcount equals the legal-move count.
                let mask_count = row.iter().filter(|&&set| set).count();
                assert_eq!(
                    mask_count, legal_tokens.len(),
                    "game {b} ply {t}: mask has {mask_count} legal tokens but expected {}",
                    legal_tokens.len()
                );

                let move_tok = batch.move_ids[b * max_ply + t] as u16;
                state.make_move(move_tok).unwrap();
            }

            // Plies past the end of the game must remain all-false.
            for t in gl..max_ply {
                let mask_off = (b * max_ply + t) * vocab_size;
                let any_set = token_masks[mask_off..mask_off + vocab_size]
                    .iter()
                    .any(|&set| set);
                assert!(!any_set, "game {b} ply {t} (past game end): mask should be all-false");
            }
        }
    }

    #[test]
    fn test_sparse_matches_dense() {
        let batch_size = 8;
        let max_ply = 256;
        let seq_len = max_ply + 1;
        let vocab_size = 4278;
        let batch = generate_training_batch(batch_size, max_ply, 77);

        let dense = compute_legal_token_masks(
            &batch.move_ids, &batch.game_lengths, max_ply, vocab_size,
        );
        let sparse = compute_legal_token_masks_sparse(
            &batch.move_ids, &batch.game_lengths, max_ply, seq_len, vocab_size,
        );

        // Scatter the sparse indices back into a dense buffer...
        let mut reconstructed = vec![false; batch_size * seq_len * vocab_size];
        for &idx in &sparse {
            reconstructed[idx as usize] = true;
        }

        // ...then compare the in-game region. Dense is laid out (B, max_ply, V)
        // while sparse targets (B, seq_len, V), so only the overlap is comparable.
        for b in 0..batch_size {
            let gl = batch.game_lengths[b] as usize;
            for t in 0..gl {
                for v in 0..vocab_size {
                    let dense_val = dense[b * max_ply * vocab_size + t * vocab_size + v];
                    let sparse_val = reconstructed[b * seq_len * vocab_size + t * vocab_size + v];
                    assert_eq!(
                        dense_val, sparse_val,
                        "Mismatch at game {b} ply {t} token {v}"
                    );
                }
            }
        }
    }
}