| use std::collections::{HashMap, HashSet}; |
| use std::io::{self, BufRead}; |
| use std::sync::{Arc, Mutex}; |
| use std::sync::atomic::{AtomicU64, AtomicBool, Ordering}; |
|
|
/// Process-wide counter of search nodes visited; reset at the start of every
/// `go` command and read back when reporting `nodes` / `nps`.
static NODES_SEARCHED: AtomicU64 = AtomicU64::new(0);
| use dashmap::DashMap; |
| use shakmaty::{Chess, Move, Position, Setup, CastlingMode}; |
| use shakmaty::zobrist::{Zobrist64, ZobristHash}; |
| use ort::session::builder::GraphOptimizationLevel; |
| use ort::session::Session; |
| use shakmaty_syzygy::{Tablebase, Wdl, AmbiguousWdl}; |
| use std::path::Path; |
|
|
/// Read-only handle on a Polyglot-style opening book file.
///
/// The file is treated as a flat array of 16-byte records sorted by key.
/// Within a record this code reads a big-endian `u64` position key at
/// offset 0, a big-endian `u16` packed move at offset 8 and a big-endian
/// `u16` weight at offset 10; the remaining four bytes are ignored.
struct PolyglotBook {
    file: std::fs::File,
    num_entries: u64,
}

impl PolyglotBook {
    /// Size in bytes of one book record.
    const ENTRY_SIZE: u64 = 16;

    /// Opens the book at `path` and derives the record count from the file
    /// length (a trailing partial record is ignored).
    ///
    /// # Errors
    /// Returns the underlying I/O error if the file cannot be opened or its
    /// metadata cannot be read.
    fn new(path: &str) -> Result<Self, io::Error> {
        let file = std::fs::File::open(path)?;
        let num_entries = file.metadata()?.len() / Self::ENTRY_SIZE;
        Ok(Self { file, num_entries })
    }

    /// Reads record `index` and returns `(key, move_bits, weight)`, or `None`
    /// on any seek/read failure.
    fn read_entry(&self, index: u64) -> Option<(u64, u16, u16)> {
        use std::io::{Read, Seek, SeekFrom};

        // `Read`/`Seek` are implemented for `&File`, so a shared borrow is enough.
        let mut f = &self.file;
        let offset = index.checked_mul(Self::ENTRY_SIZE)?;
        f.seek(SeekFrom::Start(offset)).ok()?;

        let mut key_bytes = [0u8; 8];
        f.read_exact(&mut key_bytes).ok()?;
        let mut tail = [0u8; 4];
        f.read_exact(&mut tail).ok()?;

        Some((
            u64::from_be_bytes(key_bytes),
            u16::from_be_bytes([tail[0], tail[1]]),
            u16::from_be_bytes([tail[2], tail[3]]),
        ))
    }

    /// Looks up `hash` and returns the book move for it in coordinate
    /// notation (see [`Self::decode_move`]), or `None` when the position is
    /// not in the book or the file cannot be read.
    ///
    /// A book may hold several records for the same key. The search therefore
    /// locates the *first* record with that key and then scans the whole run,
    /// returning the move with the highest weight (the earliest record wins
    /// ties). A plain binary search would return whichever duplicate the
    /// probe happened to land on.
    fn lookup(&self, hash: u64) -> Option<String> {
        // Lower bound: first index whose key is >= hash.
        let mut low = 0u64;
        let mut high = self.num_entries;
        while low < high {
            let mid = low + (high - low) / 2;
            let (key, _, _) = self.read_entry(mid)?;
            if key < hash {
                low = mid + 1;
            } else {
                high = mid;
            }
        }

        // Walk the run of records sharing the key, keeping the heaviest one.
        let mut best: Option<(u16, u16)> = None; // (weight, move_bits)
        let mut index = low;
        while index < self.num_entries {
            match self.read_entry(index) {
                Some((key, move_bits, weight)) if key == hash => {
                    let improves = match best {
                        Some((best_weight, _)) => weight > best_weight,
                        None => true,
                    };
                    if improves {
                        best = Some((weight, move_bits));
                    }
                }
                // Different key or unreadable record: the run is over.
                _ => break,
            }
            index += 1;
        }

        best.map(|(_, move_bits)| Self::decode_move(move_bits))
    }

    /// Unpacks a 16-bit book move into `<from><to>[promotion]` text.
    ///
    /// Bit layout (least significant first): to-file (3), to-row (3),
    /// from-file (3), from-row (3), promotion piece (3; 1 = knight,
    /// 2 = bishop, 3 = rook, 4 = queen, anything else = none).
    ///
    /// NOTE(review): the squares are emitted exactly as stored. Polyglot
    /// books encode castling as "king moves onto its own rook" (e.g. `e1h1`),
    /// so callers should resolve the move against the actual position before
    /// sending it to a GUI.
    fn decode_move(move_bits: u16) -> String {
        let field = |shift: u16| ((move_bits >> shift) & 7) as u8;

        let mut text = String::with_capacity(5);
        text.push((b'a' + field(6)) as char); // from file
        text.push((b'1' + field(9)) as char); // from row
        text.push((b'a' + field(0)) as char); // to file
        text.push((b'1' + field(3)) as char); // to row

        match field(12) {
            1 => text.push('n'),
            2 => text.push('b'),
            3 => text.push('r'),
            4 => text.push('q'),
            _ => {}
        }
        text
    }
}
|
|
| fn load_vocab() -> (HashMap<String, i64>, HashMap<i64, String>) { |
| let vocab_str = std::fs::read_to_string("vocab.json").expect("Failed to read vocab.json"); |
| let vocab: HashMap<String, i64> = serde_json::from_str(&vocab_str).expect("Invalid vocab.json"); |
| let mut inv_vocab = HashMap::new(); |
| for (k, &v) in &vocab { |
| inv_vocab.insert(v, k.clone()); |
| } |
| (vocab, inv_vocab) |
| } |
|
|
| fn get_piece_value(role: shakmaty::Role) -> i32 { |
| match role { |
| shakmaty::Role::Pawn => 100, |
| shakmaty::Role::Knight => 320, |
| shakmaty::Role::Bishop => 330, |
| shakmaty::Role::Rook => 500, |
| shakmaty::Role::Queen => 900, |
| shakmaty::Role::King => 20000, |
| } |
| } |
|
|
/// Attention key/value state captured for the root position so that leaf
/// evaluations only have to feed the moves played after the root.
#[derive(Clone)]
struct RootKVCache {
    /// Number of token ids that were fed to the model to produce this cache.
    seq_len: usize,
    /// One `(keys, values)` pair per model layer, each stored as a flat
    /// `f32` buffer; consumers reshape them to `[1, 2, seq_len, 64]`.
    layers: Vec<(Vec<f32>, Vec<f32>)>,
}
|
|
| fn compute_root_kv_cache( |
| session: &mut Session, |
| history: &[String], |
| vocab: &HashMap<String, i64>, |
| ) -> RootKVCache { |
| let mut seq = Vec::new(); |
| let bos_id = vocab.get("<bos>").copied().unwrap_or(0); |
| seq.push(bos_id); |
| let max_len = 120; |
| let start_idx = if history.len() > max_len { history.len() - max_len } else { 0 }; |
| for tok in &history[start_idx..] { |
| if tok == "<bos>" { continue; } |
| if let Some(&id) = vocab.get(tok) { seq.push(id); } |
| else { seq.push(vocab.get("<unk>").copied().unwrap_or(0)); } |
| } |
| |
| let seq_len = seq.len(); |
| let input_value = ort::value::Tensor::from_array((vec![1, seq_len], seq)).unwrap(); |
| |
| let pk_0 = ort::value::Tensor::from_array(ndarray::Array4::<f32>::zeros((1, 2, 0, 64))).unwrap(); |
| let pv_0 = ort::value::Tensor::from_array(ndarray::Array4::<f32>::zeros((1, 2, 0, 64))).unwrap(); |
| let pk_1 = ort::value::Tensor::from_array(ndarray::Array4::<f32>::zeros((1, 2, 0, 64))).unwrap(); |
| let pv_1 = ort::value::Tensor::from_array(ndarray::Array4::<f32>::zeros((1, 2, 0, 64))).unwrap(); |
| let pk_2 = ort::value::Tensor::from_array(ndarray::Array4::<f32>::zeros((1, 2, 0, 64))).unwrap(); |
| let pv_2 = ort::value::Tensor::from_array(ndarray::Array4::<f32>::zeros((1, 2, 0, 64))).unwrap(); |
| let pk_3 = ort::value::Tensor::from_array(ndarray::Array4::<f32>::zeros((1, 2, 0, 64))).unwrap(); |
| let pv_3 = ort::value::Tensor::from_array(ndarray::Array4::<f32>::zeros((1, 2, 0, 64))).unwrap(); |
| let pk_4 = ort::value::Tensor::from_array(ndarray::Array4::<f32>::zeros((1, 2, 0, 64))).unwrap(); |
| let pv_4 = ort::value::Tensor::from_array(ndarray::Array4::<f32>::zeros((1, 2, 0, 64))).unwrap(); |
| let pk_5 = ort::value::Tensor::from_array(ndarray::Array4::<f32>::zeros((1, 2, 0, 64))).unwrap(); |
| let pv_5 = ort::value::Tensor::from_array(ndarray::Array4::<f32>::zeros((1, 2, 0, 64))).unwrap(); |
|
|
| let outputs = session.run(ort::inputs![ |
| "input_ids" => input_value, |
| "past_k_0" => pk_0, "past_v_0" => pv_0, |
| "past_k_1" => pk_1, "past_v_1" => pv_1, |
| "past_k_2" => pk_2, "past_v_2" => pv_2, |
| "past_k_3" => pk_3, "past_v_3" => pv_3, |
| "past_k_4" => pk_4, "past_v_4" => pv_4, |
| "past_k_5" => pk_5, "past_v_5" => pv_5 |
| ]).unwrap(); |
| |
| let mut cache = RootKVCache { seq_len, layers: vec![] }; |
| for i in 0..6 { |
| let pk = outputs[format!("present_k_{}", i).as_str()].try_extract_tensor::<f32>().unwrap().1.to_vec(); |
| let pv = outputs[format!("present_v_{}", i).as_str()].try_extract_tensor::<f32>().unwrap().1.to_vec(); |
| cache.layers.push((pk, pv)); |
| } |
| |
| cache |
| } |
|
|
| fn evaluate_onnx_value( |
| session: &mut Session, |
| history: &[String], |
| vocab: &HashMap<String, i64>, |
| root_cache: &Option<RootKVCache>, |
| root_history_len: usize, |
| ) -> f32 { |
| let mut seq = Vec::new(); |
| let is_root = root_cache.is_none(); |
| |
| if is_root { |
| let bos_id = vocab.get("<bos>").copied().unwrap_or(0); |
| seq.push(bos_id); |
| for tok in history { |
| if tok == "<bos>" { continue; } |
| if let Some(&id) = vocab.get(tok) { seq.push(id); } |
| else { seq.push(vocab.get("<unk>").copied().unwrap_or(0)); } |
| } |
| } else { |
| let new_moves = if history.len() > root_history_len { &history[root_history_len..] } else { &[] }; |
| for tok in new_moves { |
| if tok == "<bos>" { continue; } |
| if let Some(&id) = vocab.get(tok) { seq.push(id); } |
| else { seq.push(vocab.get("<unk>").copied().unwrap_or(0)); } |
| } |
| if seq.is_empty() { |
| seq.push(vocab.get("<pad>").copied().unwrap_or(0)); |
| } |
| } |
| |
| let seq_len = seq.len(); |
| let input_value = ort::value::Tensor::from_array((vec![1, seq_len], seq)).unwrap(); |
| |
| let past_seq_len = root_cache.as_ref().map(|c| c.seq_len).unwrap_or(0); |
| |
| let pk_0 = if !is_root { let cache = root_cache.as_ref().unwrap(); ort::value::Tensor::from_array((vec![1, 2, past_seq_len, 64], cache.layers[0].0.clone())).unwrap() } else { ort::value::Tensor::from_array(ndarray::Array4::<f32>::zeros((1, 2, 0, 64))).unwrap() }; |
| let pv_0 = if !is_root { let cache = root_cache.as_ref().unwrap(); ort::value::Tensor::from_array((vec![1, 2, past_seq_len, 64], cache.layers[0].1.clone())).unwrap() } else { ort::value::Tensor::from_array(ndarray::Array4::<f32>::zeros((1, 2, 0, 64))).unwrap() }; |
| let pk_1 = if !is_root { let cache = root_cache.as_ref().unwrap(); ort::value::Tensor::from_array((vec![1, 2, past_seq_len, 64], cache.layers[1].0.clone())).unwrap() } else { ort::value::Tensor::from_array(ndarray::Array4::<f32>::zeros((1, 2, 0, 64))).unwrap() }; |
| let pv_1 = if !is_root { let cache = root_cache.as_ref().unwrap(); ort::value::Tensor::from_array((vec![1, 2, past_seq_len, 64], cache.layers[1].1.clone())).unwrap() } else { ort::value::Tensor::from_array(ndarray::Array4::<f32>::zeros((1, 2, 0, 64))).unwrap() }; |
| let pk_2 = if !is_root { let cache = root_cache.as_ref().unwrap(); ort::value::Tensor::from_array((vec![1, 2, past_seq_len, 64], cache.layers[2].0.clone())).unwrap() } else { ort::value::Tensor::from_array(ndarray::Array4::<f32>::zeros((1, 2, 0, 64))).unwrap() }; |
| let pv_2 = if !is_root { let cache = root_cache.as_ref().unwrap(); ort::value::Tensor::from_array((vec![1, 2, past_seq_len, 64], cache.layers[2].1.clone())).unwrap() } else { ort::value::Tensor::from_array(ndarray::Array4::<f32>::zeros((1, 2, 0, 64))).unwrap() }; |
| let pk_3 = if !is_root { let cache = root_cache.as_ref().unwrap(); ort::value::Tensor::from_array((vec![1, 2, past_seq_len, 64], cache.layers[3].0.clone())).unwrap() } else { ort::value::Tensor::from_array(ndarray::Array4::<f32>::zeros((1, 2, 0, 64))).unwrap() }; |
| let pv_3 = if !is_root { let cache = root_cache.as_ref().unwrap(); ort::value::Tensor::from_array((vec![1, 2, past_seq_len, 64], cache.layers[3].1.clone())).unwrap() } else { ort::value::Tensor::from_array(ndarray::Array4::<f32>::zeros((1, 2, 0, 64))).unwrap() }; |
| let pk_4 = if !is_root { let cache = root_cache.as_ref().unwrap(); ort::value::Tensor::from_array((vec![1, 2, past_seq_len, 64], cache.layers[4].0.clone())).unwrap() } else { ort::value::Tensor::from_array(ndarray::Array4::<f32>::zeros((1, 2, 0, 64))).unwrap() }; |
| let pv_4 = if !is_root { let cache = root_cache.as_ref().unwrap(); ort::value::Tensor::from_array((vec![1, 2, past_seq_len, 64], cache.layers[4].1.clone())).unwrap() } else { ort::value::Tensor::from_array(ndarray::Array4::<f32>::zeros((1, 2, 0, 64))).unwrap() }; |
| let pk_5 = if !is_root { let cache = root_cache.as_ref().unwrap(); ort::value::Tensor::from_array((vec![1, 2, past_seq_len, 64], cache.layers[5].0.clone())).unwrap() } else { ort::value::Tensor::from_array(ndarray::Array4::<f32>::zeros((1, 2, 0, 64))).unwrap() }; |
| let pv_5 = if !is_root { let cache = root_cache.as_ref().unwrap(); ort::value::Tensor::from_array((vec![1, 2, past_seq_len, 64], cache.layers[5].1.clone())).unwrap() } else { ort::value::Tensor::from_array(ndarray::Array4::<f32>::zeros((1, 2, 0, 64))).unwrap() }; |
|
|
| let outputs = session.run(ort::inputs![ |
| "input_ids" => input_value, |
| "past_k_0" => pk_0, "past_v_0" => pv_0, |
| "past_k_1" => pk_1, "past_v_1" => pv_1, |
| "past_k_2" => pk_2, "past_v_2" => pv_2, |
| "past_k_3" => pk_3, "past_v_3" => pv_3, |
| "past_k_4" => pk_4, "past_v_4" => pv_4, |
| "past_k_5" => pk_5, "past_v_5" => pv_5 |
| ]).unwrap(); |
| let value_tensor = outputs["value"].try_extract_tensor::<f32>().unwrap(); |
| value_tensor.1[0] |
| } |
|
|
/// How a transposition-table score relates to the true value of the node.
#[derive(Clone, Copy, PartialEq)]
enum TTNodeType {
    /// The stored score is the node's value.
    Exact,
    /// The stored score is a lower bound (used to raise alpha).
    LowerBound,
    /// The stored score is an upper bound (used to lower beta).
    UpperBound,
}
|
|
| #[derive(Clone)] |
| struct TTEntry { |
| depth: u8, |
| score: f32, |
| node_type: TTNodeType, |
| best_move: Option<shakmaty::uci::Uci>, |
| } |
|
|
| fn alpha_beta( |
| board: &Chess, |
| depth: u8, |
| mut alpha: f32, |
| mut beta: f32, |
| history: &mut Vec<String>, |
| session: &mut Option<Session>, |
| vocab: &HashMap<String, i64>, |
| tt: &DashMap<u64, TTEntry>, |
| tablebases: &Option<Tablebase<Chess>>, |
| root_cache: &Option<RootKVCache>, |
| root_history_len: usize, |
| stop_flag: &AtomicBool, |
| is_null: bool, |
| ) -> f32 { |
| NODES_SEARCHED.fetch_add(1, Ordering::Relaxed); |
| if stop_flag.load(Ordering::Relaxed) { return 0.0; } |
|
|
| if let Some(tb) = tablebases { |
| if board.board().occupied().count() <= 6 { |
| if let Ok(wdl) = tb.probe_wdl(board) { |
| match wdl { |
| AmbiguousWdl::Win => return 0.99, |
| AmbiguousWdl::Loss => return -0.99, |
| AmbiguousWdl::Draw | AmbiguousWdl::BlessedLoss | AmbiguousWdl::CursedWin | AmbiguousWdl::MaybeLoss | AmbiguousWdl::MaybeWin => return 0.0, |
| } |
| } |
| } |
| } |
|
|
| if board.is_game_over() { |
| if board.is_checkmate() { return -1.0; } |
| return 0.0; |
| } |
| |
| |
| |
| |
| |
| let hash = board.zobrist_hash::<Zobrist64>(shakmaty::EnPassantMode::Legal).0; |
| |
| let mut tt_move = None; |
| if let Some(entry) = tt.get(&hash) { |
| if entry.depth >= depth { |
| match entry.node_type { |
| TTNodeType::Exact => return entry.score, |
| TTNodeType::LowerBound => { if entry.score > alpha { alpha = entry.score; } }, |
| TTNodeType::UpperBound => { if entry.score < beta { beta = entry.score; } }, |
| } |
| if alpha >= beta { return entry.score; } |
| } |
| tt_move = entry.best_move.clone(); |
| } |
| |
| if depth == 0 { |
| let nn_val = if let Some(sess) = session { |
| evaluate_onnx_value(sess, history, vocab, root_cache, root_history_len) |
| } else { |
| 0.0 |
| }; |
| let val = if board.turn() == shakmaty::Color::Black { -nn_val } else { nn_val }; |
| return val; |
| } |
| |
| let mut legals: Vec<_> = board.legal_moves().into_iter().collect(); |
| legals.sort_by_cached_key(|m| { |
| let uci = m.to_uci(CastlingMode::Standard); |
| if Some(uci.clone()) == tt_move { return -20000; } |
| if m.is_capture() { |
| let victim = m.capture().map(|r| get_piece_value(r)).unwrap_or(0); |
| let attacker = get_piece_value(m.role()); |
| return -victim + attacker - 10000; |
| } |
| 0 |
| }); |
| |
| let mut best_val = -2.0; |
| let mut best_move = None; |
| let old_alpha = alpha; |
| |
| let mut move_count = 0; |
| |
| for m in &legals { |
| move_count += 1; |
| let uci = m.to_uci(CastlingMode::Standard); |
| let mut next_board = board.clone(); |
| next_board.play_unchecked(m); |
| history.push(uci.to_string()); |
| |
| let is_tactical = m.is_capture() || m.is_promotion() || next_board.is_check(); |
| let is_tt = Some(uci.clone()) == tt_move; |
| |
| let mut score; |
| if depth >= 3 && move_count >= 4 && !is_tactical && !is_tt { |
| let r_depth = depth - 2; |
| score = -alpha_beta(&next_board, r_depth, -beta, -alpha, history, session, vocab, tt, tablebases, root_cache, root_history_len, stop_flag, false); |
| if score > alpha { |
| score = -alpha_beta(&next_board, depth - 1, -beta, -alpha, history, session, vocab, tt, tablebases, root_cache, root_history_len, stop_flag, false); |
| } |
| } else { |
| score = -alpha_beta(&next_board, depth - 1, -beta, -alpha, history, session, vocab, tt, tablebases, root_cache, root_history_len, stop_flag, false); |
| } |
| |
| history.pop(); |
| |
| if score > best_val { |
| best_val = score; |
| best_move = Some(uci.clone()); |
| } |
| if score > alpha { alpha = score; } |
| if alpha >= beta { break; } |
| } |
| |
| let node_type = if best_val <= old_alpha { TTNodeType::UpperBound } |
| else if best_val >= beta { TTNodeType::LowerBound } |
| else { TTNodeType::Exact }; |
| |
| tt.insert(hash, TTEntry { |
| depth, |
| score: best_val, |
| node_type, |
| best_move, |
| }); |
| |
| best_val |
| } |
|
|
| fn main() { |
| let stdin = io::stdin(); |
| let mut board = Chess::default(); |
| let mut history_moves = vec!["<bos>".to_string()]; |
| |
| let _ = ort::init().with_name("neural_engine").commit(); |
| |
| let num_threads = std::thread::available_parallelism().map(|n| n.get()).unwrap_or(4); |
| |
| let mut tb = Tablebase::<Chess>::new(); |
| let _ = tb.add_directory("syzygy"); |
| let tablebases = Arc::new(Some(tb)); |
| |
| let polyglot_book = Arc::new(std::sync::Mutex::new(PolyglotBook::new("book.bin").ok())); |
| let (vocab, _) = load_vocab(); |
| let vocab = Arc::new(vocab); |
|
|
| let tt = Arc::new(DashMap::new()); |
| let mut global_root_hash = board.zobrist_hash::<Zobrist64>(shakmaty::EnPassantMode::Legal).0; |
|
|
| for line in stdin.lock().lines() { |
| let line = line.expect("Failed to read line"); |
| let tokens: Vec<&str> = line.trim().split_whitespace().collect(); |
| if tokens.is_empty() { continue; } |
|
|
| match tokens[0] { |
| "uci" => { |
| println!("id name Neurex Pure Value AB"); |
| println!("id author Neural Engine Architect"); |
| println!("uciok"); |
| } |
| "isready" => { |
| println!("readyok"); |
| } |
| "position" => { |
| board = Chess::default(); |
| history_moves.clear(); |
| history_moves.push("<bos>".to_string()); |
| global_root_hash = board.zobrist_hash::<Zobrist64>(shakmaty::EnPassantMode::Legal).0; |
| |
| if tokens.contains(&"moves") { |
| let moves_idx = tokens.iter().position(|&r| r == "moves").unwrap(); |
| for m_str in &tokens[moves_idx + 1..] { |
| if let Ok(uci_move) = shakmaty::uci::Uci::from_ascii(m_str.as_bytes()) { |
| if let Ok(m) = uci_move.to_move(&board) { |
| board.play_unchecked(&m); |
| history_moves.push(m_str.to_string()); |
| global_root_hash = board.zobrist_hash::<Zobrist64>(shakmaty::EnPassantMode::Legal).0; |
| } |
| } |
| } |
| } |
| } |
| "go" => { |
| let mut movetime_ms = 5000; |
| let mut wtime = 0; |
| let mut btime = 0; |
| let mut inc = 0; |
| |
| let mut i = 1; |
| while i < tokens.len() { |
| match tokens[i] { |
| "wtime" => { wtime = tokens[i+1].parse().unwrap_or(0); i += 2; } |
| "btime" => { btime = tokens[i+1].parse().unwrap_or(0); i += 2; } |
| "winc" => { if board.turn() == shakmaty::Color::White { inc = tokens[i+1].parse().unwrap_or(0); } i += 2; } |
| "binc" => { if board.turn() == shakmaty::Color::Black { inc = tokens[i+1].parse().unwrap_or(0); } i += 2; } |
| "movetime" => { movetime_ms = tokens[i+1].parse().unwrap_or(5000); i += 2; } |
| _ => { i += 1; } |
| } |
| } |
| |
| let mut time_left = 0; |
| if wtime > 0 || btime > 0 { |
| time_left = if board.turn() == shakmaty::Color::White { wtime } else { btime }; |
| let estimated_moves_remaining = 30; |
| movetime_ms = time_left / estimated_moves_remaining + inc / 2; |
| let max_allowable = time_left / 4; |
| movetime_ms = movetime_ms.min(max_allowable).max(100); |
| } |
| |
| if let Ok(book_opt) = polyglot_book.lock() { |
| if let Some(book) = book_opt.as_ref() { |
| if let Some(m_str) = book.lookup(global_root_hash) { |
| println!("bestmove {}", m_str); |
| continue; |
| } |
| } |
| } |
| |
| tt.clear(); |
| NODES_SEARCHED.store(0, Ordering::Relaxed); |
| let start_time = std::time::Instant::now(); |
| let stop_flag = Arc::new(AtomicBool::new(false)); |
| let mut handles = Vec::new(); |
| |
| let root_history_len = history_moves.len(); |
| let mut root_cache = None; |
| if let Some(mut sess) = if std::path::Path::new("model_int8.onnx").exists() { |
| Some(Session::builder().unwrap() |
| .with_optimization_level(GraphOptimizationLevel::Level3).unwrap() |
| .with_intra_threads(1).unwrap() |
| .commit_from_file("model_int8.onnx").unwrap()) |
| } else { None } { |
| root_cache = Some(compute_root_kv_cache(&mut sess, &history_moves, &vocab)); |
| } |
|
|
| for thread_id in 0..num_threads { |
| let t_board = board.clone(); |
| let mut t_history = history_moves.clone(); |
| let t_vocab = Arc::clone(&vocab); |
| let t_tb = Arc::clone(&tablebases); |
| let t_stop_flag = Arc::clone(&stop_flag); |
| let t_root_cache = root_cache.clone(); |
| let t_tt = Arc::clone(&tt); |
| |
| let handle = std::thread::spawn(move || { |
| let mut t_session = if Path::new("model_int8.onnx").exists() { |
| Some(Session::builder().unwrap() |
| .with_optimization_level(GraphOptimizationLevel::Level3).unwrap() |
| .with_intra_threads(1).unwrap() |
| .commit_from_file("model_int8.onnx").unwrap()) |
| } else { None }; |
| |
| let max_depth = 100; |
| let mut prev_score = 0.0; |
| for depth in 1..=max_depth { |
| if t_stop_flag.load(Ordering::Relaxed) { break; } |
| |
| let mut alpha = -2.0; |
| let mut beta = 2.0; |
| |
| if depth >= 3 { |
| alpha = prev_score - 0.2; |
| beta = prev_score + 0.2; |
| } |
| |
| let mut score = alpha_beta( |
| &t_board, |
| depth, |
| alpha, |
| beta, |
| &mut t_history, |
| &mut t_session, |
| &t_vocab, |
| &t_tt, |
| &t_tb, |
| &t_root_cache, |
| root_history_len, |
| &t_stop_flag, |
| false |
| ); |
| |
| if score <= alpha || score >= beta { |
| alpha = -2.0; |
| beta = 2.0; |
| score = alpha_beta( |
| &t_board, |
| depth, |
| alpha, |
| beta, |
| &mut t_history, |
| &mut t_session, |
| &t_vocab, |
| &t_tt, |
| &t_tb, |
| &t_root_cache, |
| root_history_len, |
| &t_stop_flag, |
| false |
| ); |
| } |
| prev_score = score; |
| |
| let is_main_thread = thread_id == 0; |
| if is_main_thread && !t_stop_flag.load(Ordering::Relaxed) { |
| let elapsed = start_time.elapsed().as_secs_f64(); |
| let n = NODES_SEARCHED.load(Ordering::Relaxed); |
| let nps = if elapsed > 0.0 { (n as f64 / elapsed) as u64 } else { 0 }; |
| let cp = (prev_score * 100.0) as i32; |
| |
| let mut pv = String::new(); |
| if let Some(entry) = t_tt.get(&global_root_hash) { |
| if let Some(m) = &entry.best_move { |
| pv = m.to_string(); |
| } |
| } |
| |
| println!("info depth {} score cp {} nodes {} nps {} time {} pv {}", |
| depth, cp, n, nps, (elapsed * 1000.0) as u64, pv); |
| } |
| } |
| }); |
| handles.push(handle); |
| } |
| |
| std::thread::sleep(std::time::Duration::from_millis(movetime_ms as u64)); |
| stop_flag.store(true, Ordering::Relaxed); |
| |
| for h in handles { |
| let _ = h.join(); |
| } |
| |
| let mut bestmove = String::from("0000"); |
| if let Some(entry) = tt.get(&global_root_hash) { |
| if let Some(m) = &entry.best_move { |
| bestmove = m.to_string(); |
| } |
| } |
| |
| if bestmove == "0000" { |
| if let Some(m) = board.legal_moves().into_iter().next() { |
| bestmove = m.to_uci(CastlingMode::Standard).to_string(); |
| } |
| } |
| |
| println!("bestmove {}", bestmove); |
| } |
| "quit" => break, |
| _ => {} |
| } |
| } |
| } |
|
|