// Uploaded by dpv007 with huggingface_hub (commit 43b516e, verified); file size 25.8 kB.
use std::collections::{HashMap, HashSet};
use std::io::{self, BufRead};
use std::sync::{Arc, Mutex};
use std::sync::atomic::{AtomicU64, AtomicBool, Ordering};
static NODES_SEARCHED: AtomicU64 = AtomicU64::new(0);
use dashmap::DashMap;
use shakmaty::{Chess, Move, Position, Setup, CastlingMode};
use shakmaty::zobrist::{Zobrist64, ZobristHash};
use ort::session::builder::GraphOptimizationLevel;
use ort::session::Session;
use shakmaty_syzygy::{Tablebase, Wdl, AmbiguousWdl};
use std::path::Path;
/// Read-only view of an opening book stored in Polyglot `.bin` format.
///
/// The file is treated as a flat array of 16-byte big-endian records ordered by
/// position key, which is what allows [`PolyglotBook::lookup`] to binary-search it.
struct PolyglotBook {
    /// Open handle to the book file; every lookup seeks inside it.
    file: std::fs::File,
    /// Number of complete 16-byte records in the file.
    num_entries: u64,
}

impl PolyglotBook {
    /// Size in bytes of one record: key (8), move (2), weight (2), learn (4).
    const ENTRY_SIZE: u64 = 16;

    /// Opens the book at `path` and derives the record count from the file size.
    ///
    /// # Errors
    /// Returns the underlying I/O error when the file cannot be opened or its
    /// metadata cannot be read.
    fn new(path: &str) -> Result<Self, io::Error> {
        let file = std::fs::File::open(path)?;
        let num_entries = file.metadata()?.len() / Self::ENTRY_SIZE;
        Ok(Self { file, num_entries })
    }

    /// Reads record `index` and returns `(key, move_bits, weight)`, or `None`
    /// when the seek or the read fails.
    fn read_entry(&self, index: u64) -> Option<(u64, u16, u16)> {
        use std::io::{Read, Seek, SeekFrom};
        // `Read` and `Seek` are implemented for `&File`, so `&self` is enough.
        let mut f = &self.file;
        f.seek(SeekFrom::Start(index * Self::ENTRY_SIZE)).ok()?;
        let mut buf = [0u8; 12];
        f.read_exact(&mut buf).ok()?;
        let mut key_bytes = [0u8; 8];
        key_bytes.copy_from_slice(&buf[..8]);
        let key = u64::from_be_bytes(key_bytes);
        let move_bits = u16::from_be_bytes([buf[8], buf[9]]);
        let weight = u16::from_be_bytes([buf[10], buf[11]]);
        Some((key, move_bits, weight))
    }

    /// Returns the book move for position key `hash` as coordinate text
    /// (for example `"e2e4"` or `"e7e8q"`), or `None` when the key is absent
    /// or the file cannot be read.
    ///
    /// A book may hold several records for the same key. They are contiguous
    /// because the file is sorted, so the search first locates the first
    /// record with the key and then scans the run, returning the move with
    /// the highest weight (the earliest record wins a tie). Returning whatever
    /// record a plain binary search happens to land on would ignore the
    /// weights entirely.
    fn lookup(&self, hash: u64) -> Option<String> {
        // Lower-bound binary search: `lo` ends at the first record with key >= hash.
        let mut lo = 0u64;
        let mut hi = self.num_entries;
        while lo < hi {
            let mid = lo + (hi - lo) / 2;
            let (key, _, _) = self.read_entry(mid)?;
            if key < hash {
                lo = mid + 1;
            } else {
                hi = mid;
            }
        }

        // Scan the run of records sharing the key and keep the heaviest one.
        let mut best: Option<(u16, u16)> = None; // (weight, move_bits)
        let mut index = lo;
        while index < self.num_entries {
            let (key, move_bits, weight) = match self.read_entry(index) {
                Some(entry) => entry,
                None => break, // keep whatever was found before the read error
            };
            if key != hash {
                break;
            }
            if best.map_or(true, |(best_weight, _)| weight > best_weight) {
                best = Some((weight, move_bits));
            }
            index += 1;
        }
        best.map(|(_, move_bits)| Self::decode_move(move_bits))
    }

    /// Decodes a packed 16-bit book move into coordinate text.
    ///
    /// Bit layout (least significant first): destination file (3 bits),
    /// destination row (3), origin file (3), origin row (3), promotion piece (3)
    /// where 1 = knight, 2 = bishop, 3 = rook, 4 = queen.
    ///
    /// NOTE(review): the Polyglot format writes castling as "king moves onto its
    /// own rook" (e.g. `e1h1`). That is not translated here because the board
    /// is not available; callers that need standard UCI castling must
    /// normalise the result themselves.
    fn decode_move(move_bits: u16) -> String {
        let to_file = (move_bits & 7) as u8;
        let to_row = ((move_bits >> 3) & 7) as u8;
        let from_file = ((move_bits >> 6) & 7) as u8;
        let from_row = ((move_bits >> 9) & 7) as u8;
        let promotion = match (move_bits >> 12) & 7 {
            1 => "n",
            2 => "b",
            3 => "r",
            4 => "q",
            _ => "",
        };
        format!(
            "{}{}{}{}{}",
            (b'a' + from_file) as char,
            (b'1' + from_row) as char,
            (b'a' + to_file) as char,
            (b'1' + to_row) as char,
            promotion
        )
    }
}
/// Loads the token vocabulary from `vocab.json` in the working directory.
///
/// Returns the token -> id map together with the inverse id -> token map.
///
/// # Panics
/// Panics when the file cannot be read or is not a JSON object of integers.
fn load_vocab() -> (HashMap<String, i64>, HashMap<i64, String>) {
    let raw = std::fs::read_to_string("vocab.json").expect("Failed to read vocab.json");
    let vocab: HashMap<String, i64> = serde_json::from_str(&raw).expect("Invalid vocab.json");
    let inv_vocab: HashMap<i64, String> = vocab
        .iter()
        .map(|(token, &id)| (id, token.clone()))
        .collect();
    (vocab, inv_vocab)
}
fn get_piece_value(role: shakmaty::Role) -> i32 {
match role {
shakmaty::Role::Pawn => 100,
shakmaty::Role::Knight => 320,
shakmaty::Role::Bishop => 330,
shakmaty::Role::Rook => 500,
shakmaty::Role::Queen => 900,
shakmaty::Role::King => 20000,
}
}
/// Key/value tensors captured from one model run over the root move history,
/// so later evaluations only need to feed the moves played after the root.
#[derive(Clone)]
struct RootKVCache {
/// Number of token ids fed to the model when the cache was built
/// (the leading `<bos>` id included).
seq_len: usize,
/// One `(present_k, present_v)` pair per model layer, kept as flat `f32` buffers;
/// `evaluate_onnx_value` feeds them back with shape `[1, 2, seq_len, 64]`.
layers: Vec<(Vec<f32>, Vec<f32>)>,
}
fn compute_root_kv_cache(
session: &mut Session,
history: &[String],
vocab: &HashMap<String, i64>,
) -> RootKVCache {
let mut seq = Vec::new();
let bos_id = vocab.get("<bos>").copied().unwrap_or(0);
seq.push(bos_id);
let max_len = 120;
let start_idx = if history.len() > max_len { history.len() - max_len } else { 0 };
for tok in &history[start_idx..] {
if tok == "<bos>" { continue; }
if let Some(&id) = vocab.get(tok) { seq.push(id); }
else { seq.push(vocab.get("<unk>").copied().unwrap_or(0)); }
}
let seq_len = seq.len();
let input_value = ort::value::Tensor::from_array((vec![1, seq_len], seq)).unwrap();
let pk_0 = ort::value::Tensor::from_array(ndarray::Array4::<f32>::zeros((1, 2, 0, 64))).unwrap();
let pv_0 = ort::value::Tensor::from_array(ndarray::Array4::<f32>::zeros((1, 2, 0, 64))).unwrap();
let pk_1 = ort::value::Tensor::from_array(ndarray::Array4::<f32>::zeros((1, 2, 0, 64))).unwrap();
let pv_1 = ort::value::Tensor::from_array(ndarray::Array4::<f32>::zeros((1, 2, 0, 64))).unwrap();
let pk_2 = ort::value::Tensor::from_array(ndarray::Array4::<f32>::zeros((1, 2, 0, 64))).unwrap();
let pv_2 = ort::value::Tensor::from_array(ndarray::Array4::<f32>::zeros((1, 2, 0, 64))).unwrap();
let pk_3 = ort::value::Tensor::from_array(ndarray::Array4::<f32>::zeros((1, 2, 0, 64))).unwrap();
let pv_3 = ort::value::Tensor::from_array(ndarray::Array4::<f32>::zeros((1, 2, 0, 64))).unwrap();
let pk_4 = ort::value::Tensor::from_array(ndarray::Array4::<f32>::zeros((1, 2, 0, 64))).unwrap();
let pv_4 = ort::value::Tensor::from_array(ndarray::Array4::<f32>::zeros((1, 2, 0, 64))).unwrap();
let pk_5 = ort::value::Tensor::from_array(ndarray::Array4::<f32>::zeros((1, 2, 0, 64))).unwrap();
let pv_5 = ort::value::Tensor::from_array(ndarray::Array4::<f32>::zeros((1, 2, 0, 64))).unwrap();
let outputs = session.run(ort::inputs![
"input_ids" => input_value,
"past_k_0" => pk_0, "past_v_0" => pv_0,
"past_k_1" => pk_1, "past_v_1" => pv_1,
"past_k_2" => pk_2, "past_v_2" => pv_2,
"past_k_3" => pk_3, "past_v_3" => pv_3,
"past_k_4" => pk_4, "past_v_4" => pv_4,
"past_k_5" => pk_5, "past_v_5" => pv_5
]).unwrap();
let mut cache = RootKVCache { seq_len, layers: vec![] };
for i in 0..6 {
let pk = outputs[format!("present_k_{}", i).as_str()].try_extract_tensor::<f32>().unwrap().1.to_vec();
let pv = outputs[format!("present_v_{}", i).as_str()].try_extract_tensor::<f32>().unwrap().1.to_vec();
cache.layers.push((pk, pv));
}
cache
}
fn evaluate_onnx_value(
session: &mut Session,
history: &[String],
vocab: &HashMap<String, i64>,
root_cache: &Option<RootKVCache>,
root_history_len: usize,
) -> f32 {
let mut seq = Vec::new();
let is_root = root_cache.is_none();
if is_root {
let bos_id = vocab.get("<bos>").copied().unwrap_or(0);
seq.push(bos_id);
for tok in history {
if tok == "<bos>" { continue; }
if let Some(&id) = vocab.get(tok) { seq.push(id); }
else { seq.push(vocab.get("<unk>").copied().unwrap_or(0)); }
}
} else {
let new_moves = if history.len() > root_history_len { &history[root_history_len..] } else { &[] };
for tok in new_moves {
if tok == "<bos>" { continue; }
if let Some(&id) = vocab.get(tok) { seq.push(id); }
else { seq.push(vocab.get("<unk>").copied().unwrap_or(0)); }
}
if seq.is_empty() {
seq.push(vocab.get("<pad>").copied().unwrap_or(0));
}
}
let seq_len = seq.len();
let input_value = ort::value::Tensor::from_array((vec![1, seq_len], seq)).unwrap();
let past_seq_len = root_cache.as_ref().map(|c| c.seq_len).unwrap_or(0);
let pk_0 = if !is_root { let cache = root_cache.as_ref().unwrap(); ort::value::Tensor::from_array((vec![1, 2, past_seq_len, 64], cache.layers[0].0.clone())).unwrap() } else { ort::value::Tensor::from_array(ndarray::Array4::<f32>::zeros((1, 2, 0, 64))).unwrap() };
let pv_0 = if !is_root { let cache = root_cache.as_ref().unwrap(); ort::value::Tensor::from_array((vec![1, 2, past_seq_len, 64], cache.layers[0].1.clone())).unwrap() } else { ort::value::Tensor::from_array(ndarray::Array4::<f32>::zeros((1, 2, 0, 64))).unwrap() };
let pk_1 = if !is_root { let cache = root_cache.as_ref().unwrap(); ort::value::Tensor::from_array((vec![1, 2, past_seq_len, 64], cache.layers[1].0.clone())).unwrap() } else { ort::value::Tensor::from_array(ndarray::Array4::<f32>::zeros((1, 2, 0, 64))).unwrap() };
let pv_1 = if !is_root { let cache = root_cache.as_ref().unwrap(); ort::value::Tensor::from_array((vec![1, 2, past_seq_len, 64], cache.layers[1].1.clone())).unwrap() } else { ort::value::Tensor::from_array(ndarray::Array4::<f32>::zeros((1, 2, 0, 64))).unwrap() };
let pk_2 = if !is_root { let cache = root_cache.as_ref().unwrap(); ort::value::Tensor::from_array((vec![1, 2, past_seq_len, 64], cache.layers[2].0.clone())).unwrap() } else { ort::value::Tensor::from_array(ndarray::Array4::<f32>::zeros((1, 2, 0, 64))).unwrap() };
let pv_2 = if !is_root { let cache = root_cache.as_ref().unwrap(); ort::value::Tensor::from_array((vec![1, 2, past_seq_len, 64], cache.layers[2].1.clone())).unwrap() } else { ort::value::Tensor::from_array(ndarray::Array4::<f32>::zeros((1, 2, 0, 64))).unwrap() };
let pk_3 = if !is_root { let cache = root_cache.as_ref().unwrap(); ort::value::Tensor::from_array((vec![1, 2, past_seq_len, 64], cache.layers[3].0.clone())).unwrap() } else { ort::value::Tensor::from_array(ndarray::Array4::<f32>::zeros((1, 2, 0, 64))).unwrap() };
let pv_3 = if !is_root { let cache = root_cache.as_ref().unwrap(); ort::value::Tensor::from_array((vec![1, 2, past_seq_len, 64], cache.layers[3].1.clone())).unwrap() } else { ort::value::Tensor::from_array(ndarray::Array4::<f32>::zeros((1, 2, 0, 64))).unwrap() };
let pk_4 = if !is_root { let cache = root_cache.as_ref().unwrap(); ort::value::Tensor::from_array((vec![1, 2, past_seq_len, 64], cache.layers[4].0.clone())).unwrap() } else { ort::value::Tensor::from_array(ndarray::Array4::<f32>::zeros((1, 2, 0, 64))).unwrap() };
let pv_4 = if !is_root { let cache = root_cache.as_ref().unwrap(); ort::value::Tensor::from_array((vec![1, 2, past_seq_len, 64], cache.layers[4].1.clone())).unwrap() } else { ort::value::Tensor::from_array(ndarray::Array4::<f32>::zeros((1, 2, 0, 64))).unwrap() };
let pk_5 = if !is_root { let cache = root_cache.as_ref().unwrap(); ort::value::Tensor::from_array((vec![1, 2, past_seq_len, 64], cache.layers[5].0.clone())).unwrap() } else { ort::value::Tensor::from_array(ndarray::Array4::<f32>::zeros((1, 2, 0, 64))).unwrap() };
let pv_5 = if !is_root { let cache = root_cache.as_ref().unwrap(); ort::value::Tensor::from_array((vec![1, 2, past_seq_len, 64], cache.layers[5].1.clone())).unwrap() } else { ort::value::Tensor::from_array(ndarray::Array4::<f32>::zeros((1, 2, 0, 64))).unwrap() };
let outputs = session.run(ort::inputs![
"input_ids" => input_value,
"past_k_0" => pk_0, "past_v_0" => pv_0,
"past_k_1" => pk_1, "past_v_1" => pv_1,
"past_k_2" => pk_2, "past_v_2" => pv_2,
"past_k_3" => pk_3, "past_v_3" => pv_3,
"past_k_4" => pk_4, "past_v_4" => pv_4,
"past_k_5" => pk_5, "past_v_5" => pv_5
]).unwrap();
let value_tensor = outputs["value"].try_extract_tensor::<f32>().unwrap();
value_tensor.1[0]
}
/// How a score stored in the transposition table relates to the search window.
/// When the stored depth is sufficient, `alpha_beta` returns an `Exact` score
/// directly, uses a `LowerBound` score to raise alpha and an `UpperBound`
/// score to lower beta.
#[derive(Clone, Copy, PartialEq)]
enum TTNodeType { Exact, LowerBound, UpperBound }
/// One transposition-table record, keyed in the table by the position's Zobrist hash.
#[derive(Clone)]
struct TTEntry {
/// Remaining search depth the node was searched with when the entry was stored.
depth: u8,
/// Best score found for the node, from the point of view of the side to move.
score: f32,
/// Whether `score` is exact or only a bound (see `TTNodeType`).
node_type: TTNodeType,
/// Best move found at the node, if any; tried first when the node is revisited.
best_move: Option<shakmaty::uci::Uci>,
}
fn alpha_beta(
board: &Chess,
depth: u8,
mut alpha: f32,
mut beta: f32,
history: &mut Vec<String>,
session: &mut Option<Session>,
vocab: &HashMap<String, i64>,
tt: &DashMap<u64, TTEntry>,
tablebases: &Option<Tablebase<Chess>>,
root_cache: &Option<RootKVCache>,
root_history_len: usize,
stop_flag: &AtomicBool,
is_null: bool,
) -> f32 {
NODES_SEARCHED.fetch_add(1, Ordering::Relaxed);
if stop_flag.load(Ordering::Relaxed) { return 0.0; }
if let Some(tb) = tablebases {
if board.board().occupied().count() <= 6 {
if let Ok(wdl) = tb.probe_wdl(board) {
match wdl {
AmbiguousWdl::Win => return 0.99,
AmbiguousWdl::Loss => return -0.99,
AmbiguousWdl::Draw | AmbiguousWdl::BlessedLoss | AmbiguousWdl::CursedWin | AmbiguousWdl::MaybeLoss | AmbiguousWdl::MaybeWin => return 0.0,
}
}
}
}
if board.is_game_over() {
if board.is_checkmate() { return -1.0; }
return 0.0;
}
// Null Move Pruning removed because passing a turn without appending to history
// breaks the Transformer's odd/even sequence length mapping to player turn,
// resulting in hallucinated evaluations.
let hash = board.zobrist_hash::<Zobrist64>(shakmaty::EnPassantMode::Legal).0;
let mut tt_move = None;
if let Some(entry) = tt.get(&hash) {
if entry.depth >= depth {
match entry.node_type {
TTNodeType::Exact => return entry.score,
TTNodeType::LowerBound => { if entry.score > alpha { alpha = entry.score; } },
TTNodeType::UpperBound => { if entry.score < beta { beta = entry.score; } },
}
if alpha >= beta { return entry.score; }
}
tt_move = entry.best_move.clone();
}
if depth == 0 {
let nn_val = if let Some(sess) = session {
evaluate_onnx_value(sess, history, vocab, root_cache, root_history_len)
} else {
0.0
};
let val = if board.turn() == shakmaty::Color::Black { -nn_val } else { nn_val };
return val;
}
let mut legals: Vec<_> = board.legal_moves().into_iter().collect();
legals.sort_by_cached_key(|m| {
let uci = m.to_uci(CastlingMode::Standard);
if Some(uci.clone()) == tt_move { return -20000; }
if m.is_capture() {
let victim = m.capture().map(|r| get_piece_value(r)).unwrap_or(0);
let attacker = get_piece_value(m.role());
return -victim + attacker - 10000;
}
0
});
let mut best_val = -2.0;
let mut best_move = None;
let old_alpha = alpha;
let mut move_count = 0;
for m in &legals {
move_count += 1;
let uci = m.to_uci(CastlingMode::Standard);
let mut next_board = board.clone();
next_board.play_unchecked(m);
history.push(uci.to_string());
let is_tactical = m.is_capture() || m.is_promotion() || next_board.is_check();
let is_tt = Some(uci.clone()) == tt_move;
let mut score;
if depth >= 3 && move_count >= 4 && !is_tactical && !is_tt {
let r_depth = depth - 2;
score = -alpha_beta(&next_board, r_depth, -beta, -alpha, history, session, vocab, tt, tablebases, root_cache, root_history_len, stop_flag, false);
if score > alpha {
score = -alpha_beta(&next_board, depth - 1, -beta, -alpha, history, session, vocab, tt, tablebases, root_cache, root_history_len, stop_flag, false);
}
} else {
score = -alpha_beta(&next_board, depth - 1, -beta, -alpha, history, session, vocab, tt, tablebases, root_cache, root_history_len, stop_flag, false);
}
history.pop();
if score > best_val {
best_val = score;
best_move = Some(uci.clone());
}
if score > alpha { alpha = score; }
if alpha >= beta { break; }
}
let node_type = if best_val <= old_alpha { TTNodeType::UpperBound }
else if best_val >= beta { TTNodeType::LowerBound }
else { TTNodeType::Exact };
tt.insert(hash, TTEntry {
depth,
score: best_val,
node_type,
best_move,
});
best_val
}
fn main() {
let stdin = io::stdin();
let mut board = Chess::default();
let mut history_moves = vec!["<bos>".to_string()];
let _ = ort::init().with_name("neural_engine").commit();
let num_threads = std::thread::available_parallelism().map(|n| n.get()).unwrap_or(4);
let mut tb = Tablebase::<Chess>::new();
let _ = tb.add_directory("syzygy");
let tablebases = Arc::new(Some(tb));
let polyglot_book = Arc::new(std::sync::Mutex::new(PolyglotBook::new("book.bin").ok()));
let (vocab, _) = load_vocab();
let vocab = Arc::new(vocab);
let tt = Arc::new(DashMap::new());
let mut global_root_hash = board.zobrist_hash::<Zobrist64>(shakmaty::EnPassantMode::Legal).0;
for line in stdin.lock().lines() {
let line = line.expect("Failed to read line");
let tokens: Vec<&str> = line.trim().split_whitespace().collect();
if tokens.is_empty() { continue; }
match tokens[0] {
"uci" => {
println!("id name Neurex Pure Value AB");
println!("id author Neural Engine Architect");
println!("uciok");
}
"isready" => {
println!("readyok");
}
"position" => {
board = Chess::default();
history_moves.clear();
history_moves.push("<bos>".to_string());
global_root_hash = board.zobrist_hash::<Zobrist64>(shakmaty::EnPassantMode::Legal).0;
if tokens.contains(&"moves") {
let moves_idx = tokens.iter().position(|&r| r == "moves").unwrap();
for m_str in &tokens[moves_idx + 1..] {
if let Ok(uci_move) = shakmaty::uci::Uci::from_ascii(m_str.as_bytes()) {
if let Ok(m) = uci_move.to_move(&board) {
board.play_unchecked(&m);
history_moves.push(m_str.to_string());
global_root_hash = board.zobrist_hash::<Zobrist64>(shakmaty::EnPassantMode::Legal).0;
}
}
}
}
}
"go" => {
let mut movetime_ms = 5000;
let mut wtime = 0;
let mut btime = 0;
let mut inc = 0;
let mut i = 1;
while i < tokens.len() {
match tokens[i] {
"wtime" => { wtime = tokens[i+1].parse().unwrap_or(0); i += 2; }
"btime" => { btime = tokens[i+1].parse().unwrap_or(0); i += 2; }
"winc" => { if board.turn() == shakmaty::Color::White { inc = tokens[i+1].parse().unwrap_or(0); } i += 2; }
"binc" => { if board.turn() == shakmaty::Color::Black { inc = tokens[i+1].parse().unwrap_or(0); } i += 2; }
"movetime" => { movetime_ms = tokens[i+1].parse().unwrap_or(5000); i += 2; }
_ => { i += 1; }
}
}
let mut time_left = 0;
if wtime > 0 || btime > 0 {
time_left = if board.turn() == shakmaty::Color::White { wtime } else { btime };
let estimated_moves_remaining = 30;
movetime_ms = time_left / estimated_moves_remaining + inc / 2;
let max_allowable = time_left / 4;
movetime_ms = movetime_ms.min(max_allowable).max(100);
}
if let Ok(book_opt) = polyglot_book.lock() {
if let Some(book) = book_opt.as_ref() {
if let Some(m_str) = book.lookup(global_root_hash) {
println!("bestmove {}", m_str);
continue;
}
}
}
tt.clear();
NODES_SEARCHED.store(0, Ordering::Relaxed);
let start_time = std::time::Instant::now();
let stop_flag = Arc::new(AtomicBool::new(false));
let mut handles = Vec::new();
let root_history_len = history_moves.len();
let mut root_cache = None;
if let Some(mut sess) = if std::path::Path::new("model_int8.onnx").exists() {
Some(Session::builder().unwrap()
.with_optimization_level(GraphOptimizationLevel::Level3).unwrap()
.with_intra_threads(1).unwrap()
.commit_from_file("model_int8.onnx").unwrap())
} else { None } {
root_cache = Some(compute_root_kv_cache(&mut sess, &history_moves, &vocab));
}
for thread_id in 0..num_threads {
let t_board = board.clone();
let mut t_history = history_moves.clone();
let t_vocab = Arc::clone(&vocab);
let t_tb = Arc::clone(&tablebases);
let t_stop_flag = Arc::clone(&stop_flag);
let t_root_cache = root_cache.clone();
let t_tt = Arc::clone(&tt);
let handle = std::thread::spawn(move || {
let mut t_session = if Path::new("model_int8.onnx").exists() {
Some(Session::builder().unwrap()
.with_optimization_level(GraphOptimizationLevel::Level3).unwrap()
.with_intra_threads(1).unwrap()
.commit_from_file("model_int8.onnx").unwrap())
} else { None };
let max_depth = 100;
let mut prev_score = 0.0;
for depth in 1..=max_depth {
if t_stop_flag.load(Ordering::Relaxed) { break; }
let mut alpha = -2.0;
let mut beta = 2.0;
if depth >= 3 {
alpha = prev_score - 0.2;
beta = prev_score + 0.2;
}
let mut score = alpha_beta(
&t_board,
depth,
alpha,
beta,
&mut t_history,
&mut t_session,
&t_vocab,
&t_tt,
&t_tb,
&t_root_cache,
root_history_len,
&t_stop_flag,
false
);
if score <= alpha || score >= beta {
alpha = -2.0;
beta = 2.0;
score = alpha_beta(
&t_board,
depth,
alpha,
beta,
&mut t_history,
&mut t_session,
&t_vocab,
&t_tt,
&t_tb,
&t_root_cache,
root_history_len,
&t_stop_flag,
false
);
}
prev_score = score;
let is_main_thread = thread_id == 0;
if is_main_thread && !t_stop_flag.load(Ordering::Relaxed) {
let elapsed = start_time.elapsed().as_secs_f64();
let n = NODES_SEARCHED.load(Ordering::Relaxed);
let nps = if elapsed > 0.0 { (n as f64 / elapsed) as u64 } else { 0 };
let cp = (prev_score * 100.0) as i32;
let mut pv = String::new();
if let Some(entry) = t_tt.get(&global_root_hash) {
if let Some(m) = &entry.best_move {
pv = m.to_string();
}
}
println!("info depth {} score cp {} nodes {} nps {} time {} pv {}",
depth, cp, n, nps, (elapsed * 1000.0) as u64, pv);
}
}
});
handles.push(handle);
}
std::thread::sleep(std::time::Duration::from_millis(movetime_ms as u64));
stop_flag.store(true, Ordering::Relaxed);
for h in handles {
let _ = h.join();
}
let mut bestmove = String::from("0000");
if let Some(entry) = tt.get(&global_root_hash) {
if let Some(m) = &entry.best_move {
bestmove = m.to_string();
}
}
if bestmove == "0000" {
if let Some(m) = board.legal_moves().into_iter().next() {
bestmove = m.to_uci(CastlingMode::Standard).to_string();
}
}
println!("bestmove {}", bestmove);
}
"quit" => break,
_ => {}
}
}
}