File size: 4,402 Bytes
d7ecc62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
//! Board state extraction for probing and diagnostics. Spec §7.5.

use rayon::prelude::*;

use crate::board::GameState;

/// Board state arrays for a batch of games.
///
/// Every field is a flat, game-major buffer: the entry for game `b` at ply `t`
/// lives at index `b * max_ply + t`. `boards` stores 64 squares per ply, so the
/// square `rank * 8 + file` of that ply lives at
/// `(b * max_ply + t) * 64 + rank * 8 + file`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BoardStates {
    /// Piece code per square, `[batch * max_ply * 8 * 8]`. The encoding comes from
    /// `GameState::board_array` (white rook = 4, white king = 6 per this module's
    /// tests; TODO confirm the full table there).
    pub boards: Vec<i8>,
    /// `true` when white is to move, `[batch * max_ply]`.
    pub side_to_move: Vec<bool>,
    /// Castling-rights bitmask from `GameState::castling_rights_bits`
    /// (the initial position has `0b1111`), `[batch * max_ply]`.
    pub castling_rights: Vec<u8>,
    /// En-passant square from `GameState::ep_square`, `-1` when there is none,
    /// `[batch * max_ply]`.
    pub ep_square: Vec<i8>,
    /// Value of `GameState::is_check` for the position, `[batch * max_ply]`.
    pub is_check: Vec<bool>,
    /// Half-move clock, saturated to `u8::MAX` by `extract_board_states`,
    /// `[batch * max_ply]`.
    pub halfmove_clock: Vec<u8>,
}

/// Extract board states at each ply. The state at ply i is the board BEFORE
/// move_ids[i] is played.
///
/// # Arguments
/// * `move_ids` - move tokens, game-major: game `b`'s ply `t` is at
///   `move_ids[b * max_ply + t]`.
/// * `game_lengths` - number of valid plies in each game; its length is the batch size.
/// * `max_ply` - stride between consecutive games in `move_ids` and in the outputs.
///
/// # Returns
/// A [`BoardStates`] laid out like `move_ids`. Slots at plies `>= game_lengths[b]`
/// keep fill values: board squares `0`, `side_to_move`/`is_check` `false`,
/// `castling_rights`/`halfmove_clock` `0`, `ep_square` `-1`.
///
/// # Panics
/// * if any game length is negative or greater than `max_ply` (such a game would
///   read its neighbour's moves and overwrite its neighbour's output slots);
/// * if `move_ids` is too short to hold some game's plies;
/// * if a move is rejected by `GameState::make_move` during replay.
pub fn extract_board_states(
    move_ids: &[i16],     // [batch * max_ply]
    game_lengths: &[i16], // [batch]
    max_ply: usize,
) -> BoardStates {
    let batch = game_lengths.len();

    // Validate up front: a negative i16 would cast to a huge usize, and a length
    // above max_ply would spill into the next game's region of both buffers.
    for (b, &len) in game_lengths.iter().enumerate() {
        assert!(
            len >= 0 && len as usize <= max_ply,
            "game {}: length {} is outside 0..={}",
            b,
            len,
            max_ply
        );
    }

    // Replay every game independently (parallel).
    let per_game: Vec<BoardStates> = (0..batch)
        .into_par_iter()
        .map(|b| {
            let start = b * max_ply;
            let length = game_lengths[b] as usize;
            replay_game(&move_ids[start..start + length])
        })
        .collect();

    // Scatter each game's plies into its fixed-stride region of the flat output.
    let total = batch * max_ply;
    let mut out = BoardStates {
        boards: vec![0i8; total * 64],
        side_to_move: vec![false; total],
        castling_rights: vec![0u8; total],
        ep_square: vec![-1i8; total],
        is_check: vec![false; total],
        halfmove_clock: vec![0u8; total],
    };
    for (b, game) in per_game.iter().enumerate() {
        let start = b * max_ply;
        let end = start + game.side_to_move.len();
        out.boards[start * 64..end * 64].copy_from_slice(&game.boards);
        out.side_to_move[start..end].copy_from_slice(&game.side_to_move);
        out.castling_rights[start..end].copy_from_slice(&game.castling_rights);
        out.ep_square[start..end].copy_from_slice(&game.ep_square);
        out.is_check[start..end].copy_from_slice(&game.is_check);
        out.halfmove_clock[start..end].copy_from_slice(&game.halfmove_clock);
    }
    out
}

/// Replays one game from `GameState::new()` and records the state before each move.
///
/// The result is a single-game `BoardStates` whose stride equals `moves.len()`.
///
/// # Panics
/// If `GameState::make_move` rejects one of the moves.
fn replay_game(moves: &[i16]) -> BoardStates {
    let length = moves.len();
    let mut state = GameState::new();
    let mut out = BoardStates {
        boards: vec![0i8; length * 64],
        side_to_move: vec![false; length],
        castling_rights: vec![0u8; length],
        ep_square: vec![-1i8; length],
        is_check: vec![false; length],
        halfmove_clock: vec![0u8; length],
    };

    for (t, &move_id) in moves.iter().enumerate() {
        // Extract state BEFORE the move.
        let board = state.board_array();
        let squares = &mut out.boards[t * 64..(t + 1) * 64];
        for (rank, row) in squares.chunks_exact_mut(8).enumerate() {
            for (file, square) in row.iter_mut().enumerate() {
                *square = board[rank][file];
            }
        }
        out.side_to_move[t] = state.is_white_to_move();
        out.castling_rights[t] = state.castling_rights_bits();
        out.ep_square[t] = state.ep_square();
        out.is_check[t] = state.is_check();
        // Saturate instead of wrapping when the clock exceeds what a u8 can hold.
        out.halfmove_clock[t] = std::cmp::min(state.halfmove_clock(), u8::MAX as u32) as u8;

        // Apply the move; tokens are stored as i16 and reinterpreted as u16 ids.
        state
            .make_move(move_id as u16)
            .expect("Move should be legal during replay");
    }
    out
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::batch::generate_random_games;

    /// Ply stride used for both generation and extraction in these tests.
    const MAX_PLY: usize = 256;

    /// White back rank a1..h1: R N B Q K B N R = 4 2 3 5 6 3 2 4.
    const WHITE_BACK_RANK: [i8; 8] = [4, 2, 3, 5, 6, 3, 2, 4];

    #[test]
    fn test_extract_initial_position() {
        // Generate a single game and check the state recorded before its first move.
        let batch = generate_random_games(1, 256, 42);
        let states = extract_board_states(&batch.move_ids, &batch.game_lengths, MAX_PLY);

        // Ply 0: initial position, white to move.
        assert!(states.side_to_move[0]); // White
        assert_eq!(states.castling_rights[0], 0b1111); // All castling rights
        assert_eq!(states.ep_square[0], -1); // No EP
        assert!(!states.is_check[0]); // Not in check
        assert_eq!(states.halfmove_clock[0], 0); // No moves played yet

        // Rank 0 holds the full white back rank (a1..h1).
        assert_eq!(&states.boards[0..8], &WHITE_BACK_RANK[..]);
    }

    #[test]
    fn test_side_to_move_alternates_and_padding_keeps_fill_values() {
        let batch = generate_random_games(1, 256, 42);
        let states = extract_board_states(&batch.move_ids, &batch.game_lengths, MAX_PLY);
        let length = batch.game_lengths[0] as usize;

        // Output buffers are sized by the stride, not by the game length.
        assert_eq!(states.side_to_move.len(), MAX_PLY);
        assert_eq!(states.boards.len(), MAX_PLY * 64);

        // Every played move hands the turn to the other side.
        for t in 0..length {
            assert_eq!(states.side_to_move[t], t % 2 == 0, "ply {}", t);
        }

        // Plies past the end of the game are left at their fill values.
        for t in length..MAX_PLY {
            assert!(!states.side_to_move[t], "ply {}", t);
            assert_eq!(states.castling_rights[t], 0, "ply {}", t);
            assert_eq!(states.ep_square[t], -1, "ply {}", t);
            assert!(!states.is_check[t], "ply {}", t);
            assert_eq!(states.halfmove_clock[t], 0, "ply {}", t);
            assert!(
                states.boards[t * 64..(t + 1) * 64].iter().all(|&sq| sq == 0),
                "ply {}",
                t
            );
        }
    }
}