use std::collections::{HashMap, HashSet}; use std::io::{self, BufRead}; use std::sync::{Arc, Mutex}; use std::sync::atomic::{AtomicU64, AtomicBool, Ordering}; static NODES_SEARCHED: AtomicU64 = AtomicU64::new(0); use dashmap::DashMap; use shakmaty::{Chess, Move, Position, Setup, CastlingMode}; use shakmaty::zobrist::{Zobrist64, ZobristHash}; use ort::session::builder::GraphOptimizationLevel; use ort::session::Session; use shakmaty_syzygy::{Tablebase, Wdl, AmbiguousWdl}; use std::path::Path; struct PolyglotBook { file: std::fs::File, num_entries: u64, } impl PolyglotBook { fn new(path: &str) -> Result { let file = std::fs::File::open(path)?; let metadata = file.metadata()?; let num_entries = metadata.len() / 16; Ok(Self { file, num_entries }) } fn lookup(&self, hash: u64) -> Option { use std::io::{Read, Seek, SeekFrom}; let mut low = 0; let mut high = self.num_entries as i64 - 1; let mut f = &self.file; while low <= high { let mid = low + (high - low) / 2; f.seek(SeekFrom::Start(mid as u64 * 16)).ok()?; let mut buf = [0u8; 8]; f.read_exact(&mut buf).ok()?; let key = u64::from_be_bytes(buf); if key == hash { f.read_exact(&mut buf).ok()?; let move_bits = u16::from_be_bytes([buf[0], buf[1]]); return Some(Self::decode_move(move_bits)); } else if key < hash { low = mid + 1; } else { high = mid - 1; } } None } fn decode_move(move_bits: u16) -> String { let to_file = (move_bits & 7) as u8; let to_row = ((move_bits >> 3) & 7) as u8; let from_file = ((move_bits >> 6) & 7) as u8; let from_row = ((move_bits >> 9) & 7) as u8; let prom = (move_bits >> 12) & 7; let from_sq = format!("{}{}", (b'a' + from_file) as char, (b'1' + from_row) as char); let to_sq = format!("{}{}", (b'a' + to_file) as char, (b'1' + to_row) as char); let prom_char = match prom { 1 => "n", 2 => "b", 3 => "r", 4 => "q", _ => "", }; format!("{}{}{}", from_sq, to_sq, prom_char) } } fn load_vocab() -> (HashMap, HashMap) { let vocab_str = std::fs::read_to_string("vocab.json").expect("Failed to read vocab.json"); let vocab: HashMap = serde_json::from_str(&vocab_str).expect("Invalid vocab.json"); let mut inv_vocab = HashMap::new(); for (k, &v) in &vocab { inv_vocab.insert(v, k.clone()); } (vocab, inv_vocab) } fn get_piece_value(role: shakmaty::Role) -> i32 { match role { shakmaty::Role::Pawn => 100, shakmaty::Role::Knight => 320, shakmaty::Role::Bishop => 330, shakmaty::Role::Rook => 500, shakmaty::Role::Queen => 900, shakmaty::Role::King => 20000, } } #[derive(Clone)] struct RootKVCache { seq_len: usize, layers: Vec<(Vec, Vec)>, } fn compute_root_kv_cache( session: &mut Session, history: &[String], vocab: &HashMap, ) -> RootKVCache { let mut seq = Vec::new(); let bos_id = vocab.get("").copied().unwrap_or(0); seq.push(bos_id); let max_len = 120; let start_idx = if history.len() > max_len { history.len() - max_len } else { 0 }; for tok in &history[start_idx..] { if tok == "" { continue; } if let Some(&id) = vocab.get(tok) { seq.push(id); } else { seq.push(vocab.get("").copied().unwrap_or(0)); } } let seq_len = seq.len(); let input_value = ort::value::Tensor::from_array((vec![1, seq_len], seq)).unwrap(); let pk_0 = ort::value::Tensor::from_array(ndarray::Array4::::zeros((1, 2, 0, 64))).unwrap(); let pv_0 = ort::value::Tensor::from_array(ndarray::Array4::::zeros((1, 2, 0, 64))).unwrap(); let pk_1 = ort::value::Tensor::from_array(ndarray::Array4::::zeros((1, 2, 0, 64))).unwrap(); let pv_1 = ort::value::Tensor::from_array(ndarray::Array4::::zeros((1, 2, 0, 64))).unwrap(); let pk_2 = ort::value::Tensor::from_array(ndarray::Array4::::zeros((1, 2, 0, 64))).unwrap(); let pv_2 = ort::value::Tensor::from_array(ndarray::Array4::::zeros((1, 2, 0, 64))).unwrap(); let pk_3 = ort::value::Tensor::from_array(ndarray::Array4::::zeros((1, 2, 0, 64))).unwrap(); let pv_3 = ort::value::Tensor::from_array(ndarray::Array4::::zeros((1, 2, 0, 64))).unwrap(); let pk_4 = ort::value::Tensor::from_array(ndarray::Array4::::zeros((1, 2, 0, 64))).unwrap(); let pv_4 = ort::value::Tensor::from_array(ndarray::Array4::::zeros((1, 2, 0, 64))).unwrap(); let pk_5 = ort::value::Tensor::from_array(ndarray::Array4::::zeros((1, 2, 0, 64))).unwrap(); let pv_5 = ort::value::Tensor::from_array(ndarray::Array4::::zeros((1, 2, 0, 64))).unwrap(); let outputs = session.run(ort::inputs![ "input_ids" => input_value, "past_k_0" => pk_0, "past_v_0" => pv_0, "past_k_1" => pk_1, "past_v_1" => pv_1, "past_k_2" => pk_2, "past_v_2" => pv_2, "past_k_3" => pk_3, "past_v_3" => pv_3, "past_k_4" => pk_4, "past_v_4" => pv_4, "past_k_5" => pk_5, "past_v_5" => pv_5 ]).unwrap(); let mut cache = RootKVCache { seq_len, layers: vec![] }; for i in 0..6 { let pk = outputs[format!("present_k_{}", i).as_str()].try_extract_tensor::().unwrap().1.to_vec(); let pv = outputs[format!("present_v_{}", i).as_str()].try_extract_tensor::().unwrap().1.to_vec(); cache.layers.push((pk, pv)); } cache } fn evaluate_onnx_value( session: &mut Session, history: &[String], vocab: &HashMap, root_cache: &Option, root_history_len: usize, ) -> f32 { let mut seq = Vec::new(); let is_root = root_cache.is_none(); if is_root { let bos_id = vocab.get("").copied().unwrap_or(0); seq.push(bos_id); for tok in history { if tok == "" { continue; } if let Some(&id) = vocab.get(tok) { seq.push(id); } else { seq.push(vocab.get("").copied().unwrap_or(0)); } } } else { let new_moves = if history.len() > root_history_len { &history[root_history_len..] } else { &[] }; for tok in new_moves { if tok == "" { continue; } if let Some(&id) = vocab.get(tok) { seq.push(id); } else { seq.push(vocab.get("").copied().unwrap_or(0)); } } if seq.is_empty() { seq.push(vocab.get("").copied().unwrap_or(0)); } } let seq_len = seq.len(); let input_value = ort::value::Tensor::from_array((vec![1, seq_len], seq)).unwrap(); let past_seq_len = root_cache.as_ref().map(|c| c.seq_len).unwrap_or(0); let pk_0 = if !is_root { let cache = root_cache.as_ref().unwrap(); ort::value::Tensor::from_array((vec![1, 2, past_seq_len, 64], cache.layers[0].0.clone())).unwrap() } else { ort::value::Tensor::from_array(ndarray::Array4::::zeros((1, 2, 0, 64))).unwrap() }; let pv_0 = if !is_root { let cache = root_cache.as_ref().unwrap(); ort::value::Tensor::from_array((vec![1, 2, past_seq_len, 64], cache.layers[0].1.clone())).unwrap() } else { ort::value::Tensor::from_array(ndarray::Array4::::zeros((1, 2, 0, 64))).unwrap() }; let pk_1 = if !is_root { let cache = root_cache.as_ref().unwrap(); ort::value::Tensor::from_array((vec![1, 2, past_seq_len, 64], cache.layers[1].0.clone())).unwrap() } else { ort::value::Tensor::from_array(ndarray::Array4::::zeros((1, 2, 0, 64))).unwrap() }; let pv_1 = if !is_root { let cache = root_cache.as_ref().unwrap(); ort::value::Tensor::from_array((vec![1, 2, past_seq_len, 64], cache.layers[1].1.clone())).unwrap() } else { ort::value::Tensor::from_array(ndarray::Array4::::zeros((1, 2, 0, 64))).unwrap() }; let pk_2 = if !is_root { let cache = root_cache.as_ref().unwrap(); ort::value::Tensor::from_array((vec![1, 2, past_seq_len, 64], cache.layers[2].0.clone())).unwrap() } else { ort::value::Tensor::from_array(ndarray::Array4::::zeros((1, 2, 0, 64))).unwrap() }; let pv_2 = if !is_root { let cache = root_cache.as_ref().unwrap(); ort::value::Tensor::from_array((vec![1, 2, past_seq_len, 64], cache.layers[2].1.clone())).unwrap() } else { ort::value::Tensor::from_array(ndarray::Array4::::zeros((1, 2, 0, 64))).unwrap() }; let pk_3 = if !is_root { let cache = root_cache.as_ref().unwrap(); ort::value::Tensor::from_array((vec![1, 2, past_seq_len, 64], cache.layers[3].0.clone())).unwrap() } else { ort::value::Tensor::from_array(ndarray::Array4::::zeros((1, 2, 0, 64))).unwrap() }; let pv_3 = if !is_root { let cache = root_cache.as_ref().unwrap(); ort::value::Tensor::from_array((vec![1, 2, past_seq_len, 64], cache.layers[3].1.clone())).unwrap() } else { ort::value::Tensor::from_array(ndarray::Array4::::zeros((1, 2, 0, 64))).unwrap() }; let pk_4 = if !is_root { let cache = root_cache.as_ref().unwrap(); ort::value::Tensor::from_array((vec![1, 2, past_seq_len, 64], cache.layers[4].0.clone())).unwrap() } else { ort::value::Tensor::from_array(ndarray::Array4::::zeros((1, 2, 0, 64))).unwrap() }; let pv_4 = if !is_root { let cache = root_cache.as_ref().unwrap(); ort::value::Tensor::from_array((vec![1, 2, past_seq_len, 64], cache.layers[4].1.clone())).unwrap() } else { ort::value::Tensor::from_array(ndarray::Array4::::zeros((1, 2, 0, 64))).unwrap() }; let pk_5 = if !is_root { let cache = root_cache.as_ref().unwrap(); ort::value::Tensor::from_array((vec![1, 2, past_seq_len, 64], cache.layers[5].0.clone())).unwrap() } else { ort::value::Tensor::from_array(ndarray::Array4::::zeros((1, 2, 0, 64))).unwrap() }; let pv_5 = if !is_root { let cache = root_cache.as_ref().unwrap(); ort::value::Tensor::from_array((vec![1, 2, past_seq_len, 64], cache.layers[5].1.clone())).unwrap() } else { ort::value::Tensor::from_array(ndarray::Array4::::zeros((1, 2, 0, 64))).unwrap() }; let outputs = session.run(ort::inputs![ "input_ids" => input_value, "past_k_0" => pk_0, "past_v_0" => pv_0, "past_k_1" => pk_1, "past_v_1" => pv_1, "past_k_2" => pk_2, "past_v_2" => pv_2, "past_k_3" => pk_3, "past_v_3" => pv_3, "past_k_4" => pk_4, "past_v_4" => pv_4, "past_k_5" => pk_5, "past_v_5" => pv_5 ]).unwrap(); let value_tensor = outputs["value"].try_extract_tensor::().unwrap(); value_tensor.1[0] } #[derive(Clone, Copy, PartialEq)] enum TTNodeType { Exact, LowerBound, UpperBound } #[derive(Clone)] struct TTEntry { depth: u8, score: f32, node_type: TTNodeType, best_move: Option, } fn alpha_beta( board: &Chess, depth: u8, mut alpha: f32, mut beta: f32, history: &mut Vec, session: &mut Option, vocab: &HashMap, tt: &DashMap, tablebases: &Option>, root_cache: &Option, root_history_len: usize, stop_flag: &AtomicBool, is_null: bool, ) -> f32 { NODES_SEARCHED.fetch_add(1, Ordering::Relaxed); if stop_flag.load(Ordering::Relaxed) { return 0.0; } if let Some(tb) = tablebases { if board.board().occupied().count() <= 6 { if let Ok(wdl) = tb.probe_wdl(board) { match wdl { AmbiguousWdl::Win => return 0.99, AmbiguousWdl::Loss => return -0.99, AmbiguousWdl::Draw | AmbiguousWdl::BlessedLoss | AmbiguousWdl::CursedWin | AmbiguousWdl::MaybeLoss | AmbiguousWdl::MaybeWin => return 0.0, } } } } if board.is_game_over() { if board.is_checkmate() { return -1.0; } return 0.0; } // Null Move Pruning removed because passing a turn without appending to history // breaks the Transformer's odd/even sequence length mapping to player turn, // resulting in hallucinated evaluations. let hash = board.zobrist_hash::(shakmaty::EnPassantMode::Legal).0; let mut tt_move = None; if let Some(entry) = tt.get(&hash) { if entry.depth >= depth { match entry.node_type { TTNodeType::Exact => return entry.score, TTNodeType::LowerBound => { if entry.score > alpha { alpha = entry.score; } }, TTNodeType::UpperBound => { if entry.score < beta { beta = entry.score; } }, } if alpha >= beta { return entry.score; } } tt_move = entry.best_move.clone(); } if depth == 0 { let nn_val = if let Some(sess) = session { evaluate_onnx_value(sess, history, vocab, root_cache, root_history_len) } else { 0.0 }; let val = if board.turn() == shakmaty::Color::Black { -nn_val } else { nn_val }; return val; } let mut legals: Vec<_> = board.legal_moves().into_iter().collect(); legals.sort_by_cached_key(|m| { let uci = m.to_uci(CastlingMode::Standard); if Some(uci.clone()) == tt_move { return -20000; } if m.is_capture() { let victim = m.capture().map(|r| get_piece_value(r)).unwrap_or(0); let attacker = get_piece_value(m.role()); return -victim + attacker - 10000; } 0 }); let mut best_val = -2.0; let mut best_move = None; let old_alpha = alpha; let mut move_count = 0; for m in &legals { move_count += 1; let uci = m.to_uci(CastlingMode::Standard); let mut next_board = board.clone(); next_board.play_unchecked(m); history.push(uci.to_string()); let is_tactical = m.is_capture() || m.is_promotion() || next_board.is_check(); let is_tt = Some(uci.clone()) == tt_move; let mut score; if depth >= 3 && move_count >= 4 && !is_tactical && !is_tt { let r_depth = depth - 2; score = -alpha_beta(&next_board, r_depth, -beta, -alpha, history, session, vocab, tt, tablebases, root_cache, root_history_len, stop_flag, false); if score > alpha { score = -alpha_beta(&next_board, depth - 1, -beta, -alpha, history, session, vocab, tt, tablebases, root_cache, root_history_len, stop_flag, false); } } else { score = -alpha_beta(&next_board, depth - 1, -beta, -alpha, history, session, vocab, tt, tablebases, root_cache, root_history_len, stop_flag, false); } history.pop(); if score > best_val { best_val = score; best_move = Some(uci.clone()); } if score > alpha { alpha = score; } if alpha >= beta { break; } } let node_type = if best_val <= old_alpha { TTNodeType::UpperBound } else if best_val >= beta { TTNodeType::LowerBound } else { TTNodeType::Exact }; tt.insert(hash, TTEntry { depth, score: best_val, node_type, best_move, }); best_val } fn main() { let stdin = io::stdin(); let mut board = Chess::default(); let mut history_moves = vec!["".to_string()]; let _ = ort::init().with_name("neural_engine").commit(); let num_threads = std::thread::available_parallelism().map(|n| n.get()).unwrap_or(4); let mut tb = Tablebase::::new(); let _ = tb.add_directory("syzygy"); let tablebases = Arc::new(Some(tb)); let polyglot_book = Arc::new(std::sync::Mutex::new(PolyglotBook::new("book.bin").ok())); let (vocab, _) = load_vocab(); let vocab = Arc::new(vocab); let tt = Arc::new(DashMap::new()); let mut global_root_hash = board.zobrist_hash::(shakmaty::EnPassantMode::Legal).0; for line in stdin.lock().lines() { let line = line.expect("Failed to read line"); let tokens: Vec<&str> = line.trim().split_whitespace().collect(); if tokens.is_empty() { continue; } match tokens[0] { "uci" => { println!("id name Neurex Pure Value AB"); println!("id author Neural Engine Architect"); println!("uciok"); } "isready" => { println!("readyok"); } "position" => { board = Chess::default(); history_moves.clear(); history_moves.push("".to_string()); global_root_hash = board.zobrist_hash::(shakmaty::EnPassantMode::Legal).0; if tokens.contains(&"moves") { let moves_idx = tokens.iter().position(|&r| r == "moves").unwrap(); for m_str in &tokens[moves_idx + 1..] { if let Ok(uci_move) = shakmaty::uci::Uci::from_ascii(m_str.as_bytes()) { if let Ok(m) = uci_move.to_move(&board) { board.play_unchecked(&m); history_moves.push(m_str.to_string()); global_root_hash = board.zobrist_hash::(shakmaty::EnPassantMode::Legal).0; } } } } } "go" => { let mut movetime_ms = 5000; let mut wtime = 0; let mut btime = 0; let mut inc = 0; let mut i = 1; while i < tokens.len() { match tokens[i] { "wtime" => { wtime = tokens[i+1].parse().unwrap_or(0); i += 2; } "btime" => { btime = tokens[i+1].parse().unwrap_or(0); i += 2; } "winc" => { if board.turn() == shakmaty::Color::White { inc = tokens[i+1].parse().unwrap_or(0); } i += 2; } "binc" => { if board.turn() == shakmaty::Color::Black { inc = tokens[i+1].parse().unwrap_or(0); } i += 2; } "movetime" => { movetime_ms = tokens[i+1].parse().unwrap_or(5000); i += 2; } _ => { i += 1; } } } let mut time_left = 0; if wtime > 0 || btime > 0 { time_left = if board.turn() == shakmaty::Color::White { wtime } else { btime }; let estimated_moves_remaining = 30; movetime_ms = time_left / estimated_moves_remaining + inc / 2; let max_allowable = time_left / 4; movetime_ms = movetime_ms.min(max_allowable).max(100); } if let Ok(book_opt) = polyglot_book.lock() { if let Some(book) = book_opt.as_ref() { if let Some(m_str) = book.lookup(global_root_hash) { println!("bestmove {}", m_str); continue; } } } tt.clear(); NODES_SEARCHED.store(0, Ordering::Relaxed); let start_time = std::time::Instant::now(); let stop_flag = Arc::new(AtomicBool::new(false)); let mut handles = Vec::new(); let root_history_len = history_moves.len(); let mut root_cache = None; if let Some(mut sess) = if std::path::Path::new("model_int8.onnx").exists() { Some(Session::builder().unwrap() .with_optimization_level(GraphOptimizationLevel::Level3).unwrap() .with_intra_threads(1).unwrap() .commit_from_file("model_int8.onnx").unwrap()) } else { None } { root_cache = Some(compute_root_kv_cache(&mut sess, &history_moves, &vocab)); } for thread_id in 0..num_threads { let t_board = board.clone(); let mut t_history = history_moves.clone(); let t_vocab = Arc::clone(&vocab); let t_tb = Arc::clone(&tablebases); let t_stop_flag = Arc::clone(&stop_flag); let t_root_cache = root_cache.clone(); let t_tt = Arc::clone(&tt); let handle = std::thread::spawn(move || { let mut t_session = if Path::new("model_int8.onnx").exists() { Some(Session::builder().unwrap() .with_optimization_level(GraphOptimizationLevel::Level3).unwrap() .with_intra_threads(1).unwrap() .commit_from_file("model_int8.onnx").unwrap()) } else { None }; let max_depth = 100; let mut prev_score = 0.0; for depth in 1..=max_depth { if t_stop_flag.load(Ordering::Relaxed) { break; } let mut alpha = -2.0; let mut beta = 2.0; if depth >= 3 { alpha = prev_score - 0.2; beta = prev_score + 0.2; } let mut score = alpha_beta( &t_board, depth, alpha, beta, &mut t_history, &mut t_session, &t_vocab, &t_tt, &t_tb, &t_root_cache, root_history_len, &t_stop_flag, false ); if score <= alpha || score >= beta { alpha = -2.0; beta = 2.0; score = alpha_beta( &t_board, depth, alpha, beta, &mut t_history, &mut t_session, &t_vocab, &t_tt, &t_tb, &t_root_cache, root_history_len, &t_stop_flag, false ); } prev_score = score; let is_main_thread = thread_id == 0; if is_main_thread && !t_stop_flag.load(Ordering::Relaxed) { let elapsed = start_time.elapsed().as_secs_f64(); let n = NODES_SEARCHED.load(Ordering::Relaxed); let nps = if elapsed > 0.0 { (n as f64 / elapsed) as u64 } else { 0 }; let cp = (prev_score * 100.0) as i32; let mut pv = String::new(); if let Some(entry) = t_tt.get(&global_root_hash) { if let Some(m) = &entry.best_move { pv = m.to_string(); } } println!("info depth {} score cp {} nodes {} nps {} time {} pv {}", depth, cp, n, nps, (elapsed * 1000.0) as u64, pv); } } }); handles.push(handle); } std::thread::sleep(std::time::Duration::from_millis(movetime_ms as u64)); stop_flag.store(true, Ordering::Relaxed); for h in handles { let _ = h.join(); } let mut bestmove = String::from("0000"); if let Some(entry) = tt.get(&global_root_hash) { if let Some(m) = &entry.best_move { bestmove = m.to_string(); } } if bestmove == "0000" { if let Some(m) = board.legal_moves().into_iter().next() { bestmove = m.to_uci(CastlingMode::Standard).to_string(); } } println!("bestmove {}", bestmove); } "quit" => break, _ => {} } } }