| |
| |
| |
| |
|
|
| use rayon::prelude::*; |
| use shakmaty::{Bitboard, Chess, Color, Move, MoveList, Position, Role, Square}; |
| use std::collections::HashMap; |
|
|
| use crate::board::GameState; |
| use crate::vocab; |
|
|
| |
// Edge-case flags. Each constant is a distinct single-bit mask in a `u64`
// (bits 0 through 42), so any combination can be OR-ed into one word.

// --- Check / game-over state of the side to move ---------------------------

/// The side to move is in check.
pub const IN_CHECK: u64 = 1 << 0;
/// The side to move is checked by two or more pieces.
pub const IN_DOUBLE_CHECK: u64 = 1 << 1;
/// No legal moves while in check.
pub const CHECKMATE: u64 = 1 << 2;
/// No legal moves while not in check.
pub const STALEMATE: u64 = 1 << 3;

// --- Move availability ------------------------------------------------------

/// A piece lined up with its king and an enemy slider has a pseudo-legal move
/// that is not legal.
pub const PIN_RESTRICTS_MOVEMENT: u64 = 1 << 4;
/// At least one legal pawn capture (en passant included) exists.
pub const PAWN_CAPTURE_AVAILABLE: u64 = 1 << 5;
/// At least one legal promoting pawn move exists.
pub const PROMOTION_AVAILABLE: u64 = 1 << 6;
/// At least one legal en passant capture exists.
pub const EP_CAPTURE_AVAILABLE: u64 = 1 << 7;

// --- Castling ---------------------------------------------------------------

/// Kingside castling is among the legal moves.
pub const CASTLE_LEGAL_KINGSIDE: u64 = 1 << 8;
/// Queenside castling is among the legal moves.
pub const CASTLE_LEGAL_QUEENSIDE: u64 = 1 << 9;
/// The side to move still holds a castling right but is in check.
pub const CASTLE_BLOCKED_CHECK: u64 = 1 << 10;
/// The kingside right was held previously and is gone now.
pub const RIGHTS_LOST_KINGSIDE: u64 = 1 << 11;
/// The queenside right was held previously and is gone now.
pub const RIGHTS_LOST_QUEENSIDE: u64 = 1 << 12;
/// The kingside right was lost and the opponent captured the kingside rook.
pub const RIGHTS_LOST_CAPTURE_KINGSIDE: u64 = 1 << 13;
/// The queenside right was lost and the opponent captured the queenside rook.
pub const RIGHTS_LOST_CAPTURE_QUEENSIDE: u64 = 1 << 14;

// `PIECE_BLOCKS_x`: with the matching castling right held, the back-rank
// square on file `x` is occupied.
pub const PIECE_BLOCKS_B: u64 = 1 << 15;
pub const PIECE_BLOCKS_C: u64 = 1 << 16;
pub const PIECE_BLOCKS_D: u64 = 1 << 17;
pub const PIECE_BLOCKS_F: u64 = 1 << 18;
pub const PIECE_BLOCKS_G: u64 = 1 << 19;

// `ATTACK_BLOCKS_x`: with the matching castling right held, the back-rank
// square on file `x` is empty but attacked by the opponent.
pub const ATTACK_BLOCKS_C: u64 = 1 << 20;
pub const ATTACK_BLOCKS_D: u64 = 1 << 21;
pub const ATTACK_BLOCKS_F: u64 = 1 << 22;
pub const ATTACK_BLOCKS_G: u64 = 1 << 23;

// --- File of the legal en passant square (bits 24..=31, file a..h) ----------

pub const EP_FILE_A: u64 = 1 << 24;
pub const EP_FILE_B: u64 = 1 << 25;
pub const EP_FILE_C: u64 = 1 << 26;
pub const EP_FILE_D: u64 = 1 << 27;
pub const EP_FILE_E: u64 = 1 << 28;
pub const EP_FILE_F: u64 = 1 << 29;
pub const EP_FILE_G: u64 = 1 << 30;
pub const EP_FILE_H: u64 = 1 << 31;

// --- Destination file of an available promotion (bits 32..=39, file a..h) ---

pub const PROMO_FILE_A: u64 = 1 << 32;
pub const PROMO_FILE_B: u64 = 1 << 33;
pub const PROMO_FILE_C: u64 = 1 << 34;
pub const PROMO_FILE_D: u64 = 1 << 35;
pub const PROMO_FILE_E: u64 = 1 << 36;
pub const PROMO_FILE_F: u64 = 1 << 37;
pub const PROMO_FILE_G: u64 = 1 << 38;
pub const PROMO_FILE_H: u64 = 1 << 39;

// --- Draw conditions, reported only on the final position of a game ---------

/// The halfmove clock has reached 150.
pub const SEVENTY_FIVE_MOVE_RULE: u64 = 1 << 40;
/// The game state reports a fivefold repetition.
pub const FIVEFOLD_REPETITION: u64 = 1 << 41;
/// The position has insufficient mating material.
pub const INSUFFICIENT_MATERIAL: u64 = 1 << 42;
|
|
| |
| pub fn edge_case_bits() -> HashMap<String, u64> { |
| let mut m = HashMap::new(); |
| m.insert("IN_CHECK".into(), IN_CHECK); |
| m.insert("IN_DOUBLE_CHECK".into(), IN_DOUBLE_CHECK); |
| m.insert("CHECKMATE".into(), CHECKMATE); |
| m.insert("STALEMATE".into(), STALEMATE); |
| m.insert("PIN_RESTRICTS_MOVEMENT".into(), PIN_RESTRICTS_MOVEMENT); |
| m.insert("PAWN_CAPTURE_AVAILABLE".into(), PAWN_CAPTURE_AVAILABLE); |
| m.insert("PROMOTION_AVAILABLE".into(), PROMOTION_AVAILABLE); |
| m.insert("EP_CAPTURE_AVAILABLE".into(), EP_CAPTURE_AVAILABLE); |
| m.insert("CASTLE_LEGAL_KINGSIDE".into(), CASTLE_LEGAL_KINGSIDE); |
| m.insert("CASTLE_LEGAL_QUEENSIDE".into(), CASTLE_LEGAL_QUEENSIDE); |
| m.insert("CASTLE_BLOCKED_CHECK".into(), CASTLE_BLOCKED_CHECK); |
| m.insert("RIGHTS_LOST_KINGSIDE".into(), RIGHTS_LOST_KINGSIDE); |
| m.insert("RIGHTS_LOST_QUEENSIDE".into(), RIGHTS_LOST_QUEENSIDE); |
| m.insert("RIGHTS_LOST_CAPTURE_KINGSIDE".into(), RIGHTS_LOST_CAPTURE_KINGSIDE); |
| m.insert("RIGHTS_LOST_CAPTURE_QUEENSIDE".into(), RIGHTS_LOST_CAPTURE_QUEENSIDE); |
| m.insert("PIECE_BLOCKS_B".into(), PIECE_BLOCKS_B); |
| m.insert("PIECE_BLOCKS_C".into(), PIECE_BLOCKS_C); |
| m.insert("PIECE_BLOCKS_D".into(), PIECE_BLOCKS_D); |
| m.insert("PIECE_BLOCKS_F".into(), PIECE_BLOCKS_F); |
| m.insert("PIECE_BLOCKS_G".into(), PIECE_BLOCKS_G); |
| m.insert("ATTACK_BLOCKS_C".into(), ATTACK_BLOCKS_C); |
| m.insert("ATTACK_BLOCKS_D".into(), ATTACK_BLOCKS_D); |
| m.insert("ATTACK_BLOCKS_F".into(), ATTACK_BLOCKS_F); |
| m.insert("ATTACK_BLOCKS_G".into(), ATTACK_BLOCKS_G); |
| m.insert("EP_FILE_A".into(), EP_FILE_A); |
| m.insert("EP_FILE_B".into(), EP_FILE_B); |
| m.insert("EP_FILE_C".into(), EP_FILE_C); |
| m.insert("EP_FILE_D".into(), EP_FILE_D); |
| m.insert("EP_FILE_E".into(), EP_FILE_E); |
| m.insert("EP_FILE_F".into(), EP_FILE_F); |
| m.insert("EP_FILE_G".into(), EP_FILE_G); |
| m.insert("EP_FILE_H".into(), EP_FILE_H); |
| m.insert("PROMO_FILE_A".into(), PROMO_FILE_A); |
| m.insert("PROMO_FILE_B".into(), PROMO_FILE_B); |
| m.insert("PROMO_FILE_C".into(), PROMO_FILE_C); |
| m.insert("PROMO_FILE_D".into(), PROMO_FILE_D); |
| m.insert("PROMO_FILE_E".into(), PROMO_FILE_E); |
| m.insert("PROMO_FILE_F".into(), PROMO_FILE_F); |
| m.insert("PROMO_FILE_G".into(), PROMO_FILE_G); |
| m.insert("PROMO_FILE_H".into(), PROMO_FILE_H); |
| m.insert("SEVENTY_FIVE_MOVE_RULE".into(), SEVENTY_FIVE_MOVE_RULE); |
| m.insert("FIVEFOLD_REPETITION".into(), FIVEFOLD_REPETITION); |
| m.insert("INSUFFICIENT_MATERIAL".into(), INSUFFICIENT_MATERIAL); |
| m |
| } |
|
|
| |
| |
| |
| |
| fn compute_ply_bits( |
| pos: &Chess, |
| legal: &MoveList, |
| prev_own_ks: bool, |
| prev_own_qs: bool, |
| curr_own_ks: bool, |
| curr_own_qs: bool, |
| opponent_captured_ks_rook: bool, |
| opponent_captured_qs_rook: bool, |
| is_final_ply: bool, |
| termination_bits: u64, |
| ) -> u64 { |
| let mut bits: u64 = 0; |
| let turn = pos.turn(); |
|
|
| |
| let in_check = pos.is_check(); |
| if in_check { |
| bits |= IN_CHECK; |
|
|
| |
| let checkers = pos.checkers(); |
| if checkers.count() >= 2 { |
| bits |= IN_DOUBLE_CHECK; |
| } |
| } |
|
|
| if legal.is_empty() { |
| if in_check { |
| bits |= CHECKMATE; |
| } else { |
| bits |= STALEMATE; |
| } |
| } |
|
|
| |
| |
| |
| |
| |
| |
| if has_restricting_pin(pos, legal) { |
| bits |= PIN_RESTRICTS_MOVEMENT; |
| } |
|
|
| |
| let mut has_pawn_capture = false; |
| let mut has_promotion = false; |
| let mut promo_dst_files: u8 = 0; |
|
|
| for m in legal { |
| match m { |
| Move::Normal { role: Role::Pawn, from, to, promotion, .. } => { |
| if from.file() != to.file() { |
| |
| has_pawn_capture = true; |
| } |
| if promotion.is_some() { |
| has_promotion = true; |
| promo_dst_files |= 1 << (to.file() as u8); |
| } |
| } |
| Move::EnPassant { .. } => { |
| |
| has_pawn_capture = true; |
| } |
| _ => {} |
| } |
| } |
|
|
| if has_pawn_capture { |
| bits |= PAWN_CAPTURE_AVAILABLE; |
| } |
| if has_promotion { |
| bits |= PROMOTION_AVAILABLE; |
| |
| for file in 0u8..8 { |
| if promo_dst_files & (1 << file) != 0 { |
| bits |= 1u64 << (32 + file as u64); |
| } |
| } |
| } |
|
|
| |
| let has_ep = legal.iter().any(|m| matches!(m, Move::EnPassant { .. })); |
| if has_ep { |
| bits |= EP_CAPTURE_AVAILABLE; |
| |
| if let Some(ep_sq) = pos.legal_ep_square() { |
| let file = ep_sq.file() as u8; |
| bits |= 1u64 << (24 + file as u64); |
| } |
| } |
|
|
| |
| let back_rank = if turn == Color::White { |
| shakmaty::Rank::First |
| } else { |
| shakmaty::Rank::Eighth |
| }; |
|
|
| |
| let has_castle_ks = legal.iter().any(|m| { |
| matches!(m, Move::Castle { king, rook } if rook.file() > king.file()) |
| }); |
| let has_castle_qs = legal.iter().any(|m| { |
| matches!(m, Move::Castle { king, rook } if rook.file() < king.file()) |
| }); |
|
|
| if has_castle_ks { |
| bits |= CASTLE_LEGAL_KINGSIDE; |
| } |
| if has_castle_qs { |
| bits |= CASTLE_LEGAL_QUEENSIDE; |
| } |
|
|
| |
| if in_check && (curr_own_ks || curr_own_qs) { |
| bits |= CASTLE_BLOCKED_CHECK; |
| } |
|
|
| |
| if prev_own_ks && !curr_own_ks { |
| bits |= RIGHTS_LOST_KINGSIDE; |
| if opponent_captured_ks_rook { |
| bits |= RIGHTS_LOST_CAPTURE_KINGSIDE; |
| } |
| } |
| if prev_own_qs && !curr_own_qs { |
| bits |= RIGHTS_LOST_QUEENSIDE; |
| if opponent_captured_qs_rook { |
| bits |= RIGHTS_LOST_CAPTURE_QUEENSIDE; |
| } |
| } |
|
|
| |
| |
| let opp = !turn; |
| if curr_own_qs { |
| let b_sq = Square::from_coords(shakmaty::File::B, back_rank); |
| let c_sq = Square::from_coords(shakmaty::File::C, back_rank); |
| let d_sq = Square::from_coords(shakmaty::File::D, back_rank); |
|
|
| let b_occupied = pos.board().piece_at(b_sq).is_some(); |
| let c_occupied = pos.board().piece_at(c_sq).is_some(); |
| let d_occupied = pos.board().piece_at(d_sq).is_some(); |
|
|
| if b_occupied { bits |= PIECE_BLOCKS_B; } |
| if c_occupied { |
| bits |= PIECE_BLOCKS_C; |
| } else if is_attacked(pos, c_sq, opp) { |
| bits |= ATTACK_BLOCKS_C; |
| } |
| if d_occupied { |
| bits |= PIECE_BLOCKS_D; |
| } else if is_attacked(pos, d_sq, opp) { |
| bits |= ATTACK_BLOCKS_D; |
| } |
| } |
| if curr_own_ks { |
| let f_sq = Square::from_coords(shakmaty::File::F, back_rank); |
| let g_sq = Square::from_coords(shakmaty::File::G, back_rank); |
|
|
| let f_occupied = pos.board().piece_at(f_sq).is_some(); |
| let g_occupied = pos.board().piece_at(g_sq).is_some(); |
|
|
| if f_occupied { |
| bits |= PIECE_BLOCKS_F; |
| } else if is_attacked(pos, f_sq, opp) { |
| bits |= ATTACK_BLOCKS_F; |
| } |
| if g_occupied { |
| bits |= PIECE_BLOCKS_G; |
| } else if is_attacked(pos, g_sq, opp) { |
| bits |= ATTACK_BLOCKS_G; |
| } |
| } |
|
|
| |
| if is_final_ply { |
| bits |= termination_bits; |
| } |
|
|
| bits |
| } |
|
|
| |
| fn is_attacked(pos: &Chess, sq: Square, by_color: Color) -> bool { |
| let attackers = pos.board().attacks_to(sq, by_color, pos.board().occupied()); |
| !attackers.is_empty() |
| } |
|
|
| |
| |
| |
| |
| |
| |
| fn has_restricting_pin(pos: &Chess, legal: &MoveList) -> bool { |
| let turn = pos.turn(); |
| let king_sq = pos.board().king_of(turn).expect("King must exist"); |
|
|
| |
| let mut legal_from_to: std::collections::HashSet<(Square, Square)> = |
| std::collections::HashSet::new(); |
| let mut piece_legal_count: HashMap<Square, usize> = HashMap::new(); |
|
|
| for m in legal { |
| let from = match m { |
| Move::Normal { from, to, role, .. } if *role != Role::King => { |
| legal_from_to.insert((*from, *to)); |
| *from |
| } |
| Move::EnPassant { from, to } => { |
| legal_from_to.insert((*from, *to)); |
| *from |
| } |
| _ => continue, |
| }; |
| *piece_legal_count.entry(from).or_insert(0) += 1; |
| } |
|
|
| |
| |
| let our_pieces = pos.board().by_color(turn) & !Bitboard::from_square(king_sq); |
|
|
| for sq in our_pieces { |
| let piece = pos.board().piece_at(sq).unwrap(); |
| if piece.role == Role::King { |
| continue; |
| } |
|
|
| |
| |
| let between_king_and_attacker = is_on_pin_ray(pos, sq, king_sq, turn); |
| if !between_king_and_attacker { |
| continue; |
| } |
|
|
| |
| |
| |
| let pseudo_moves = pseudo_legal_moves_for_piece(pos, sq); |
| for (_, to) in &pseudo_moves { |
| if !legal_from_to.contains(&(sq, *to)) { |
| return true; |
| } |
| } |
| } |
|
|
| false |
| } |
|
|
| |
| fn is_on_pin_ray(pos: &Chess, piece_sq: Square, king_sq: Square, our_color: Color) -> bool { |
| |
| let pf = piece_sq.file() as i8; |
| let pr = piece_sq.rank() as i8; |
| let kf = king_sq.file() as i8; |
| let kr = king_sq.rank() as i8; |
|
|
| let df = (pf - kf).signum(); |
| let dr = (pr - kr).signum(); |
|
|
| |
| if df == 0 && dr == 0 { |
| return false; |
| } |
| let on_file = pf == kf; |
| let on_rank = pr == kr; |
| let on_diag = (pf - kf).abs() == (pr - kr).abs(); |
|
|
| if !on_file && !on_rank && !on_diag { |
| return false; |
| } |
|
|
| |
| let enemy = !our_color; |
| let mut cf = pf + df; |
| let mut cr = pr + dr; |
| while cf >= 0 && cf < 8 && cr >= 0 && cr < 8 { |
| let sq = Square::from_coords( |
| shakmaty::File::new(cf as u32), |
| shakmaty::Rank::new(cr as u32), |
| ); |
| if let Some(p) = pos.board().piece_at(sq) { |
| if p.color == enemy { |
| |
| let is_slider = match p.role { |
| Role::Bishop => on_diag, |
| Role::Rook => on_file || on_rank, |
| Role::Queen => true, |
| _ => false, |
| }; |
| return is_slider; |
| } else { |
| |
| return false; |
| } |
| } |
| cf += df; |
| cr += dr; |
| } |
| false |
| } |
|
|
| |
| |
| fn pseudo_legal_moves_for_piece(pos: &Chess, sq: Square) -> Vec<(Square, Square)> { |
| let piece = match pos.board().piece_at(sq) { |
| Some(p) => p, |
| None => return vec![], |
| }; |
|
|
| let mut moves = Vec::new(); |
| let occupied = pos.board().occupied(); |
| let our_pieces = pos.board().by_color(piece.color); |
|
|
| let attacks = match piece.role { |
| Role::Pawn => { |
| |
| let mut pawn_moves = Vec::new(); |
| let rank = sq.rank() as i8; |
| let file = sq.file() as i8; |
| let (dir, home_rank) = if piece.color == Color::White { |
| (1i8, 1i8) |
| } else { |
| (-1i8, 6i8) |
| }; |
|
|
| |
| let fwd_rank = rank + dir; |
| if fwd_rank >= 0 && fwd_rank < 8 { |
| let fwd = Square::from_coords( |
| shakmaty::File::new(file as u32), |
| shakmaty::Rank::new(fwd_rank as u32), |
| ); |
| if pos.board().piece_at(fwd).is_none() { |
| pawn_moves.push((sq, fwd)); |
|
|
| |
| if rank == home_rank { |
| let dbl_rank = rank + 2 * dir; |
| if dbl_rank >= 0 && dbl_rank < 8 { |
| let dbl = Square::from_coords( |
| shakmaty::File::new(file as u32), |
| shakmaty::Rank::new(dbl_rank as u32), |
| ); |
| if pos.board().piece_at(dbl).is_none() { |
| pawn_moves.push((sq, dbl)); |
| } |
| } |
| } |
| } |
| } |
|
|
| |
| for cap_df in [-1i8, 1i8] { |
| let cap_file = file + cap_df; |
| if cap_file >= 0 && cap_file < 8 && fwd_rank >= 0 && fwd_rank < 8 { |
| let cap_sq = Square::from_coords( |
| shakmaty::File::new(cap_file as u32), |
| shakmaty::Rank::new(fwd_rank as u32), |
| ); |
| let is_enemy = pos.board().piece_at(cap_sq) |
| .map_or(false, |p| p.color != piece.color); |
| let is_ep = pos.legal_ep_square() == Some(cap_sq); |
| if is_enemy || is_ep { |
| pawn_moves.push((sq, cap_sq)); |
| } |
| } |
| } |
|
|
| return pawn_moves; |
| } |
| Role::Knight => shakmaty::attacks::knight_attacks(sq), |
| Role::Bishop => shakmaty::attacks::bishop_attacks(sq, occupied), |
| Role::Rook => shakmaty::attacks::rook_attacks(sq, occupied), |
| Role::Queen => shakmaty::attacks::queen_attacks(sq, occupied), |
| Role::King => shakmaty::attacks::king_attacks(sq), |
| }; |
|
|
| |
| let dests = attacks & !our_pieces; |
| for dst in dests { |
| moves.push((sq, dst)); |
| } |
| moves |
| } |
|
|
| |
| |
| fn color_castling_rights(pos: &Chess, color: Color) -> (bool, bool) { |
| let rights = pos.castles().castling_rights(); |
| match color { |
| Color::White => ( |
| rights.contains(Square::H1), |
| rights.contains(Square::A1), |
| ), |
| Color::Black => ( |
| rights.contains(Square::H8), |
| rights.contains(Square::A8), |
| ), |
| } |
| } |
|
|
| |
| pub fn compute_edge_stats_per_ply( |
| move_ids: &[i16], |
| game_lengths: &[i16], |
| max_ply: usize, |
| ) -> (Vec<u64>, Vec<u64>, Vec<u64>) { |
| let batch = game_lengths.len(); |
|
|
| let results: Vec<(Vec<u64>, u64, u64)> = (0..batch) |
| .into_par_iter() |
| .map(|b| { |
| let length = game_lengths[b] as usize; |
| let mut state = GameState::new(); |
| |
| let mut ply_bits = vec![0u64; length + 1]; |
| let mut white_acc: u64 = 0; |
| let mut black_acc: u64 = 0; |
|
|
| |
| |
| let mut white_prev_ks = true; |
| let mut white_prev_qs = true; |
| let mut black_prev_ks = true; |
| let mut black_prev_qs = true; |
| let mut white_first_turn = true; |
| let mut black_first_turn = true; |
|
|
| |
| let mut opponent_captured_white_ks_rook = false; |
| let mut opponent_captured_white_qs_rook = false; |
| let mut opponent_captured_black_ks_rook = false; |
| let mut opponent_captured_black_qs_rook = false; |
|
|
| for t in 0..length { |
| let pos = state.position(); |
| let turn = pos.turn(); |
| let legal = pos.legal_moves(); |
|
|
| let (curr_own_ks, curr_own_qs) = color_castling_rights(pos, turn); |
|
|
| let (prev_ks, prev_qs) = match turn { |
| Color::White => { |
| if white_first_turn { |
| white_first_turn = false; |
| (true, true) |
| } else { |
| (white_prev_ks, white_prev_qs) |
| } |
| } |
| Color::Black => { |
| if black_first_turn { |
| black_first_turn = false; |
| (true, true) |
| } else { |
| (black_prev_ks, black_prev_qs) |
| } |
| } |
| }; |
|
|
| let (opp_cap_ks, opp_cap_qs) = match turn { |
| Color::White => (opponent_captured_white_ks_rook, opponent_captured_white_qs_rook), |
| Color::Black => (opponent_captured_black_ks_rook, opponent_captured_black_qs_rook), |
| }; |
|
|
| let bits = compute_ply_bits( |
| pos, &legal, |
| prev_ks, prev_qs, |
| curr_own_ks, curr_own_qs, |
| opp_cap_ks, opp_cap_qs, |
| false, 0, |
| ); |
|
|
| ply_bits[t] = bits; |
| match turn { |
| Color::White => white_acc |= bits, |
| Color::Black => black_acc |= bits, |
| } |
|
|
| |
| match turn { |
| Color::White => { |
| white_prev_ks = curr_own_ks; |
| white_prev_qs = curr_own_qs; |
| |
| opponent_captured_white_ks_rook = false; |
| opponent_captured_white_qs_rook = false; |
| } |
| Color::Black => { |
| black_prev_ks = curr_own_ks; |
| black_prev_qs = curr_own_qs; |
| opponent_captured_black_ks_rook = false; |
| opponent_captured_black_qs_rook = false; |
| } |
| } |
|
|
| |
| let token = move_ids[b * max_ply + t] as u16; |
|
|
| |
| if let Some((_src, dst, _)) = vocab::decompose_token(token) { |
| let dst_sq = crate::board::our_sq_to_shakmaty(dst); |
| if let Some(captured) = pos.board().piece_at(dst_sq) { |
| if captured.role == Role::Rook { |
| match captured.color { |
| Color::White => { |
| if dst_sq == Square::H1 { opponent_captured_white_ks_rook = true; } |
| if dst_sq == Square::A1 { opponent_captured_white_qs_rook = true; } |
| } |
| Color::Black => { |
| if dst_sq == Square::H8 { opponent_captured_black_ks_rook = true; } |
| if dst_sq == Square::A8 { opponent_captured_black_qs_rook = true; } |
| } |
| } |
| } |
| } |
| } |
|
|
| state.make_move(token).expect("Move should be legal"); |
| } |
|
|
| |
| |
| { |
| let pos = state.position(); |
| let turn = pos.turn(); |
| let legal = pos.legal_moves(); |
|
|
| let (curr_own_ks, curr_own_qs) = color_castling_rights(pos, turn); |
| let (prev_ks, prev_qs) = match turn { |
| Color::White => (white_prev_ks, white_prev_qs), |
| Color::Black => (black_prev_ks, black_prev_qs), |
| }; |
| let (opp_cap_ks, opp_cap_qs) = match turn { |
| Color::White => (opponent_captured_white_ks_rook, opponent_captured_white_qs_rook), |
| Color::Black => (opponent_captured_black_ks_rook, opponent_captured_black_qs_rook), |
| }; |
|
|
| let mut term_bits = 0u64; |
| if state.halfmove_clock() >= 150 { |
| term_bits |= SEVENTY_FIVE_MOVE_RULE; |
| } |
| if state.is_fivefold_repetition() { |
| term_bits |= FIVEFOLD_REPETITION; |
| } |
| if pos.is_insufficient_material() { |
| term_bits |= INSUFFICIENT_MATERIAL; |
| } |
|
|
| let bits = compute_ply_bits( |
| pos, &legal, |
| prev_ks, prev_qs, |
| curr_own_ks, curr_own_qs, |
| opp_cap_ks, opp_cap_qs, |
| true, term_bits, |
| ); |
|
|
| ply_bits[length] = bits; |
| match turn { |
| Color::White => white_acc |= bits, |
| Color::Black => black_acc |= bits, |
| } |
| } |
|
|
| (ply_bits, white_acc, black_acc) |
| }) |
| .collect(); |
|
|
| |
| |
| |
| |
| let mut per_ply = vec![0u64; batch * max_ply]; |
| let mut white = Vec::with_capacity(batch); |
| let mut black = Vec::with_capacity(batch); |
|
|
| for (b, (ply_bits, w, bl)) in results.into_iter().enumerate() { |
| let length = game_lengths[b] as usize; |
| |
| let copy_len = std::cmp::min(length + 1, max_ply); |
| per_ply[b * max_ply..b * max_ply + copy_len] |
| .copy_from_slice(&ply_bits[..copy_len]); |
| white.push(w); |
| black.push(bl); |
| } |
|
|
| (per_ply, white, black) |
| } |
|
|
| |
| pub fn compute_edge_stats_per_game( |
| move_ids: &[i16], |
| game_lengths: &[i16], |
| max_ply: usize, |
| ) -> (Vec<u64>, Vec<u64>) { |
| let (_, white, black) = compute_edge_stats_per_ply(move_ids, game_lengths, max_ply); |
| (white, black) |
| } |
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::batch::generate_random_games;

    /// The first recorded position of a generated game must not raise any of
    /// the check, mate, en passant, promotion or pawn-capture flags.
    #[test]
    fn test_initial_position_bits() {
        let games = generate_random_games(1, 256, 42);
        let (per_ply, _white, _black) =
            compute_edge_stats_per_ply(&games.move_ids, &games.game_lengths, 256);

        let first = per_ply[0];
        for flag in [
            IN_CHECK,
            CHECKMATE,
            EP_CAPTURE_AVAILABLE,
            PROMOTION_AVAILABLE,
            PAWN_CAPTURE_AVAILABLE,
        ] {
            assert_eq!(first & flag, 0);
        }
    }

    /// The per-game entry point must report the same accumulators as the
    /// per-ply entry point on the same input.
    #[test]
    fn test_edge_stats_accumulators() {
        let games = generate_random_games(10, 256, 42);
        let (_, white_from_ply, black_from_ply) =
            compute_edge_stats_per_ply(&games.move_ids, &games.game_lengths, 256);
        let (white_from_game, black_from_game) =
            compute_edge_stats_per_game(&games.move_ids, &games.game_lengths, 256);
        assert_eq!(white_from_ply, white_from_game);
        assert_eq!(black_from_ply, black_from_game);
    }
}
|
|