use crate::core::heuristics::Heuristic;
use crate::core::logic::{CardDatabase, GameState};
#[cfg(feature = "extension-module")]
use pyo3::prelude::*;
use rand::prelude::*;
use rand::rngs::SmallRng;
use rand::SeedableRng;
// use crate::core::logic::{Phase, ActionReceiver};
// NOTE(review): the `nn` cfg attribute below is duplicated; one copy can be removed.
#[cfg(feature = "nn")]
#[cfg(feature = "nn")]
use crate::core::logic::ai_encoding::GameStateEncoding;
#[cfg(feature = "nn")]
use ort::session::Session;
#[cfg(feature = "parallel")]
use rayon::prelude::*;
use smallvec::SmallVec;
use std::collections::HashMap;
use std::f32;
#[cfg(feature = "nn")]
use std::sync::{Arc, Mutex};
use std::time::Duration;

/// Wall-clock time accumulated in each MCTS phase, used for performance
/// diagnostics. Per-thread instances are merged after a parallel search.
#[derive(Default, Clone, Copy)]
pub struct MCTSProfiler {
    pub determinization: Duration,
    pub selection: Duration,
    pub expansion: Duration,
    pub simulation: Duration,
    pub backpropagation: Duration,
}

// DeckStats and EvalMode moved to heuristics.rs

impl MCTSProfiler {
    /// Accumulate another profiler's timings into this one.
    /// Used to fold per-thread profiles into a single aggregate.
    pub fn merge(&mut self, other: &Self) {
        self.determinization += other.determinization;
        self.selection += other.selection;
        self.expansion += other.expansion;
        self.simulation += other.simulation;
        self.backpropagation += other.backpropagation;
    }

    /// Print each phase's time and its share of `total` to stdout.
    /// Silently returns when `total` is zero to avoid division by zero.
    pub fn print(&self, total: Duration) {
        let total_secs = total.as_secs_f64();
        if total_secs == 0.0 {
            return;
        }
        println!("[MCTS Profile] Breakdown:");
        let items = [
            ("Determinization", self.determinization),
            ("Selection", self.selection),
            ("Expansion", self.expansion),
            ("Simulation", self.simulation),
            ("Backpropagation", self.backpropagation),
        ];
        for (name, dur) in items {
            let secs = dur.as_secs_f64();
            println!(
                " - {:<16}: {:>8.3}s ({:>6.1}%)",
                name,
                secs,
                (secs / total_secs) * 100.0
            );
        }
    }
}

/// One node of the search tree, stored in the `MCTS::nodes` arena and
/// addressed by index (no heap-allocated tree links).
struct Node {
    // Number of simulations that passed through this node.
    visit_count: u32,
    // Sum of rewards from the perspective of `player_just_moved`.
    value_sum: f32,
    // The player (0 or 1) whose move led INTO this node; rewards are
    // credited from this player's point of view during backpropagation.
    player_just_moved: u8,
    prior_prob: f32, // AlphaZero Policy P(s,a)
    // Legal actions not yet expanded into children.
    untried_actions: SmallVec<[i32; 32]>,
    children: SmallVec<[(i32, usize); 16]>, // (Action, NodeIndex in Arena)
    // Arena index of the parent node; None for the root.
    // NOTE(review): the type parameter appears lost in extraction — used as
    // an index (usize) everywhere else; confirm `Option<usize>`.
    parent: Option,
    // Action taken from the parent to reach this node (0 for the root).
    parent_action: i32,
}

/// Determinized Monte-Carlo tree search over `GameState`.
/// Each simulation re-randomizes the opponent's hidden hand/deck, then runs
/// standard select → expand → simulate → backpropagate with PUCT selection.
/// An optional batched neural evaluator replaces the rollout phase.
pub struct MCTS {
    // Arena of tree nodes; index 0 is the root.
    // NOTE(review): element type lost in extraction — presumably `Vec<Node>`; confirm.
    nodes: Vec,
    rng: SmallRng,
    // Scratch buffer holding the opponent's combined hand+deck during determinization.
    unseen_buffer: SmallVec<[i32; 64]>,
    // Scratch buffer for legal-action generation (reused across simulations).
    legal_buffer: SmallVec<[i32; 32]>,
    // Reused state clone so each simulation avoids a fresh allocation.
    reusable_state: GameState,
    // Optional batched leaf evaluator (AlphaZero path).
    // NOTE(review): inner type parameters lost in extraction — confirm against the
    // original (an Arc'd evaluator providing `evaluate_batch`).
    pub evaluator: Option>>,
    // Number of leaves gathered per batched-inference round.
    pub cpu_batch_size: usize,
    pub leaf_batch_size: usize,
    // PUCT exploration constant (c_puct); default 1.41 ≈ sqrt(2).
    pub exploration_weight: f32,
}

/// How far simulations are allowed to run before being cut off and evaluated.
#[cfg_attr(feature = "extension-module", pyclass)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SearchHorizon {
    // Play out until the game is terminal (subject to the 200-step cap).
    GameEnd(),
    // Stop as soon as the turn counter advances past the root turn.
    TurnEnd(),
    // Stop after the given number of rollout steps.
    Limited(u32),
}

// EvalMode moved to heuristics.rs

impl MCTS {
    /// Create a searcher with default settings and no neural evaluator.
    pub fn new() -> Self {
        Self {
            nodes: Vec::with_capacity(256), // 128 sims → ~128 nodes max; was 16384 = 7.2MB per tree
            rng: SmallRng::from_os_rng(),
            unseen_buffer: SmallVec::with_capacity(64),
            legal_buffer: SmallVec::with_capacity(32),
            reusable_state: GameState::default(),
            evaluator: None,
            cpu_batch_size: 16,
            leaf_batch_size: 512,
            exploration_weight: 1.41,
        }
    }

    /// Create a searcher that evaluates leaves with a batched neural network
    /// evaluator instead of CPU rollouts.
    // NOTE(review): `eval`'s type parameters lost in extraction — confirm.
    pub fn with_evaluator(
        eval: std::sync::Arc>,
        batch_size: usize,
    ) -> Self {
        let mut mcts = Self::new();
        mcts.evaluator = Some(eval);
        mcts.cpu_batch_size = batch_size;
        mcts
    }

    /// Number of nodes currently in the tree arena.
    pub fn get_tree_size(&self) -> usize {
        self.nodes.len()
    }

    /// Maximum depth of any node, computed by walking parent links from every
    /// node to the root. O(n · depth) — intended for diagnostics only.
    pub fn get_max_depth(&self) -> usize {
        let mut max_d = 0;
        for i in 0..self.nodes.len() {
            let mut d = 0;
            let mut curr = i;
            while let Some(p) = self.nodes[curr].parent {
                d += 1;
                curr = p;
            }
            if d > max_d {
                max_d = d;
            }
        }
        max_d
    }

    /// Run `num_sims` simulations split across rayon threads (root
    /// parallelism: each thread builds an independent tree via
    /// `search_custom`), then merge per-action statistics.
    ///
    /// Returns `(action, mean_value, visit_count)` triples sorted by visit
    /// count, descending. With the `parallel` feature disabled this runs a
    /// single sequential search.
    pub fn search_parallel(
        &self,
        root_state: &GameState,
        db: &CardDatabase,
        num_sims: usize,
        timeout_sec: f32,
        horizon: SearchHorizon,
        heuristic: &dyn Heuristic,
        shuffle_self: bool,
    ) -> Vec<(i32, f32, u32)> {
        #[cfg(feature = "parallel")]
        let mut _num_threads = rayon::current_num_threads().max(1);
        #[cfg(not(feature = "parallel"))]
        let _num_threads = 1;
        let num_threads = _num_threads;
        let start_overall = std::time::Instant::now();
        let leaf_batch_size = self.leaf_batch_size;
        // Ceiling division so the threads collectively run at least num_sims.
        let sims_per_thread = if num_sims > 0 {
            (num_sims + num_threads - 1) / num_threads
        } else {
            0
        };
        // Collect results
        #[cfg(feature = "parallel")]
        let results: Vec<(Vec<(i32, f32, u32)>, MCTSProfiler)> = {
            // Build a pool with 16 MiB stacks (deep recursion in game logic);
            // fall back to a default pool if that fails.
            let pool = rayon::ThreadPoolBuilder::new()
                .stack_size(16 * 1024 * 1024)
                .num_threads(num_threads)
                .build()
                .unwrap_or_else(|_| {
                    rayon::ThreadPoolBuilder::default()
                        .num_threads(num_threads)
                        .build()
                        .unwrap()
                });
            pool.install(|| {
                (0..num_threads)
                    .into_par_iter()
                    .map(|_| {
                        // Each worker uses its own fresh tree and RNG.
                        let mut mcts = MCTS::new();
                        mcts.leaf_batch_size = leaf_batch_size;
                        mcts.search_custom(
                            root_state,
                            db,
                            sims_per_thread,
                            timeout_sec,
                            horizon,
                            heuristic,
                            shuffle_self,
                            true,
                        )
                    })
                    .collect()
            })
        };
        #[cfg(not(feature = "parallel"))]
        let results: Vec<(Vec<(i32, f32, u32)>, MCTSProfiler)> = vec![{
            let mut mcts = MCTS::new();
            mcts.leaf_batch_size = leaf_batch_size;
            mcts.search_custom(
                root_state,
                db,
                num_sims,
                timeout_sec,
                horizon,
                heuristic,
                shuffle_self,
                true,
            )
        }];
        // Merge results: per action, accumulate (value_sum, visit_count) so the
        // merged mean is visit-weighted across threads.
        // NOTE(review): HashMap type parameters lost in extraction — the usage
        // implies `HashMap<i32, (f32, u32)>`; confirm.
        let mut agg_map: HashMap = HashMap::new();
        let mut total_visits = 0;
        let mut agg_profile = MCTSProfiler::default();
        for (res, profile) in results {
            agg_profile.merge(&profile);
            for (action, score, visits) in res {
                let entry = agg_map.entry(action).or_insert((0.0, 0));
                // `score` is a mean; convert back to a sum before accumulating.
                let total_value = score * visits as f32;
                entry.0 += total_value;
                entry.1 += visits;
                total_visits += visits;
            }
        }
        // Logging Speed
        let duration = start_overall.elapsed();
        let _sims_per_sec = total_visits as f64 / duration.as_secs_f64();
        if total_visits > 100 {
            // println!("[MCTS] Completed {} sims in {:.3}s ({:.0} sims/s)", total_visits, duration.as_secs_f64(), sims_per_sec);
            // agg_profile.print(duration);
        }
        let mut stats: Vec<(i32, f32, u32)> = agg_map
            .into_iter()
            .map(|(action, (sum_val, visits))| {
                if visits > 0 {
                    (action, sum_val / visits as f32, visits)
                } else {
                    (action, 0.0, 0)
                }
            })
            .collect();
        // Most-visited first.
        stats.sort_by_key(|&(_, _, v)| std::cmp::Reverse(v));
        stats
    }

    /// Deterministic tiebreak for root actions with equal visits/values:
    /// simulate the single action and compare (success-lives diff, score diff)
    /// from the current player's perspective. Returns (i32::MIN, i32::MIN)
    /// when the action cannot be stepped, ranking it last.
    fn root_action_tiebreak(
        root_state: &GameState,
        db: &CardDatabase,
        action: i32,
    ) -> (i32, i32) {
        let mut sim = root_state.clone();
        if sim.step(db, action).is_err() {
            return (i32::MIN, i32::MIN);
        }
        let me = root_state.current_player as usize;
        let opp = 1 - me;
        let live_diff = sim.core.players[me].success_lives.len() as i32
            - sim.core.players[opp].success_lives.len() as i32;
        let score_diff =
            sim.core.players[me].score as i32 - sim.core.players[opp].score as i32;
        (live_diff, score_diff)
    }

    /// Like `search_parallel` but runs `search_mode` on each thread so an
    /// explicit `EvalMode` controls leaf evaluation. Results are merged the
    /// same way (visit-weighted means, sorted by visits descending).
    pub fn search_parallel_mode(
        &self,
        root_state: &GameState,
        db: &CardDatabase,
        num_sims: usize,
        timeout_sec: f32,
        horizon: SearchHorizon,
        eval_mode: crate::core::heuristics::EvalMode,
        heuristic: &dyn Heuristic,
    ) -> Vec<(i32, f32, u32)> {
        let start_overall = std::time::Instant::now();
        #[cfg(feature = "parallel")]
        let mut _num_threads = rayon::current_num_threads().max(1);
        #[cfg(not(feature = "parallel"))]
        let mut _num_threads = 1;
        let num_threads = _num_threads;
        let leaf_batch_size = self.leaf_batch_size;
        // Ceiling division so threads collectively cover num_sims.
        let sims_per_thread = if num_sims > 0 {
            (num_sims + num_threads - 1) / num_threads
        } else {
            0
        };
        // Collect results
        // NOTE(review): unlike `search_parallel`, this uses the global rayon
        // pool (no enlarged stack size) — confirm that is intentional.
        #[cfg(feature = "parallel")]
        let results: Vec<(Vec<(i32, f32, u32)>, MCTSProfiler)> = (0..num_threads)
            .into_par_iter()
            .map(|_| {
                let mut mcts = MCTS::new();
                mcts.leaf_batch_size = leaf_batch_size;
                mcts.search_mode(
                    root_state,
                    db,
                    sims_per_thread,
                    timeout_sec,
                    horizon,
                    eval_mode,
                    heuristic,
                )
            })
            .collect();
        #[cfg(not(feature = "parallel"))]
        let results: Vec<(Vec<(i32, f32, u32)>, MCTSProfiler)> = vec![{
            let mut mcts = MCTS::new();
            mcts.leaf_batch_size = leaf_batch_size;
            mcts.search_mode(
                root_state,
                db,
                num_sims,
                timeout_sec,
                horizon,
                eval_mode,
                heuristic,
            )
        }];
        // Merge results (same visit-weighted scheme as search_parallel).
        // NOTE(review): HashMap type parameters lost in extraction — implies
        // `HashMap<i32, (f32, u32)>`; confirm.
        let mut agg_map: HashMap = HashMap::new();
        let mut total_visits = 0;
        let mut agg_profile = MCTSProfiler::default();
        for (res, profile) in results {
            agg_profile.merge(&profile);
            for (action, score, visits) in res {
                let entry = agg_map.entry(action).or_insert((0.0, 0));
                let total_value = score * visits as f32;
                entry.0 += total_value;
                entry.1 += visits;
                total_visits += visits;
            }
        }
        // Logging Speed
        let duration = start_overall.elapsed();
        let _sims_per_sec = total_visits as f64 / duration.as_secs_f64();
        if total_visits > 100 {
            // println!("[MCTS Mode] Completed {} sims in {:.3}s ({:.0} sims/s)", total_visits, duration.as_secs_f64(), sims_per_sec);
            // agg_profile.print(duration);
        }
        let mut stats: Vec<(i32, f32, u32)> = agg_map
            .into_iter()
            .map(|(action, (sum_val, visits))| {
                if visits > 0 {
                    (action, sum_val / visits as f32, visits)
                } else {
                    (action, 0.0, 0)
                }
            })
            .collect();
        stats.sort_by_key(|&(_, _, v)| std::cmp::Reverse(v));
        stats
    }

    /// Convenience wrapper: single-tree search with default options
    /// (no self-deck shuffling, no rollout phase).
    pub fn search(
        &mut self,
        root_state: &GameState,
        db: &CardDatabase,
        num_sims: usize,
        timeout_sec: f32,
        horizon: SearchHorizon,
        heuristic: &dyn Heuristic,
    ) -> (Vec<(i32, f32, u32)>, MCTSProfiler) {
        self.search_custom(
            root_state,
            db,
            num_sims,
            timeout_sec,
            horizon,
            heuristic,
            false,
            false,
        )
    }

    /// Single-tree search using the supplied heuristic (EvalMode::Normal) as
    /// the leaf evaluation function. Terminal states are scored 1.0 / 0.0 /
    /// 0.5 for a P0 win / P1 win / draw. Deck expectations for both players
    /// are precomputed once and reused across all leaf evaluations.
    pub fn search_custom(
        &mut self,
        root_state: &GameState,
        db: &CardDatabase,
        num_sims: usize,
        timeout_sec: f32,
        horizon: SearchHorizon,
        heuristic: &dyn Heuristic,
        shuffle_self: bool,
        enable_rollout: bool,
    ) -> (Vec<(i32, f32, u32)>, MCTSProfiler) {
        use crate::core::heuristics::calculate_deck_expectations;
        let p0_stats = calculate_deck_expectations(&root_state.players[0].deck, db);
        // P1's hand is hidden at the root, so expectations cover hand + deck.
        let mut p1_unseen = root_state.players[1].hand.clone();
        p1_unseen.extend(root_state.players[1].deck.iter().cloned());
        let p1_stats = calculate_deck_expectations(&p1_unseen, db);
        self.run_mcts_config(
            root_state,
            db,
            num_sims,
            timeout_sec,
            horizon,
            shuffle_self,
            enable_rollout,
            |state, _db| {
                if state.is_terminal() {
                    // Reward from player 0's perspective.
                    match state.get_winner() {
                        0 => 1.0,
                        1 => 0.0,
                        _ => 0.5,
                    }
                } else {
                    heuristic.evaluate(
                        state,
                        db,
                        root_state.players[0].score,
                        root_state.players[1].score,
                        crate::core::heuristics::EvalMode::Normal,
                        Some(p0_stats),
                        Some(p1_stats),
                    )
                }
            },
        )
    }

    /// Single-tree search where `eval_mode` controls the heuristic and
    /// determinization behavior (Blind mode also shuffles the searcher's own
    /// deck). Rollouts are always enabled on this path.
    pub fn search_mode(
        &mut self,
        root_state: &GameState,
        db: &CardDatabase,
        num_sims: usize,
        timeout_sec: f32,
        horizon: SearchHorizon,
        eval_mode: crate::core::heuristics::EvalMode,
        heuristic: &dyn Heuristic,
    ) -> (Vec<(i32, f32, u32)>, MCTSProfiler) {
        use crate::core::heuristics::{calculate_deck_expectations, EvalMode};
        // Pre-calculate deck expectations for optimization
        let p0_stats = calculate_deck_expectations(&root_state.players[0].deck, db);
        let mut p1_unseen = root_state.players[1].hand.clone();
        p1_unseen.extend(root_state.players[1].deck.iter().cloned());
        let p1_stats = calculate_deck_expectations(&p1_unseen, db);
        self.run_mcts_config(
            root_state,
            db,
            num_sims,
            timeout_sec,
            horizon,
            // Blind mode additionally shuffles the searcher's own deck.
            eval_mode == EvalMode::Blind,
            true,
            |state, _db| {
                if state.is_terminal() {
                    // NOTE(review): both arms of this Solitaire check are
                    // identical — either the Solitaire branch should score
                    // differently or the condition can be removed. Confirm.
                    if eval_mode == EvalMode::Solitaire {
                        match state.get_winner() {
                            0 => 1.0,
                            1 => 0.0,
                            _ => 0.5,
                        }
                    } else {
                        match state.get_winner() {
                            0 => 1.0,
                            1 => 0.0,
                            _ => 0.5,
                        }
                    }
                } else {
                    heuristic.evaluate(
                        state,
                        db,
                        root_state.players[0].score,
                        root_state.players[1].score,
                        eval_mode,
                        Some(p0_stats),
                        Some(p1_stats),
                    )
                }
            },
        )
    }

    /// Core search loop. Builds the tree in `self.nodes` and returns
    /// `(action, mean_value, visits)` for each root child plus phase timings.
    ///
    /// Per simulation: (1) determinize — reshuffle the opponent's hidden
    /// hand+deck (and optionally our own deck); (2) select down the tree via
    /// PUCT; (3) expand one untried action; (4) evaluate the leaf — batched NN
    /// inference when `self.evaluator` is set, otherwise optional random
    /// rollout followed by `eval_fn`; (5) backpropagate, crediting each node
    /// from its `player_just_moved` perspective.
    ///
    /// `eval_fn` must return a reward in [0, 1] from player 0's perspective.
    /// Stops at `num_sims` simulations or `timeout_sec`, whichever first;
    /// returns immediately with a stub result if the root has 0 or 1 legal
    /// actions.
    // NOTE(review): the generic parameter list (e.g. `<F>`) appears lost in
    // extraction — the `where` clause below names `F`; confirm.
    pub fn run_mcts_config(
        &mut self,
        root_state: &GameState,
        db: &CardDatabase,
        num_sims: usize,
        timeout_sec: f32,
        horizon: SearchHorizon,
        shuffle_self: bool,
        enable_rollout: bool,
        mut eval_fn: F,
    ) -> (Vec<(i32, f32, u32)>, MCTSProfiler)
    where
        F: FnMut(&GameState, &CardDatabase) -> f32,
    {
        self.nodes.clear();
        // Reserve exactly enough for the entire search upfront.
        // Each simulation adds ≤1 node, so this is a tight upper bound.
        // Prevents Vec from doubling (8192→16384 = 7.2MB) when system RAM is low.
        if num_sims > 0 {
            self.nodes.reserve(num_sims + 1);
        }
        let start_time = std::time::Instant::now();
        let timeout = if timeout_sec > 0.0 {
            Some(std::time::Duration::from_secs_f32(timeout_sec))
        } else {
            None
        };
        let root_start_turn = root_state.turn;
        let mut profiler = MCTSProfiler::default();
        // Root Node
        let mut legal_indices: SmallVec<[i32; 32]> = SmallVec::with_capacity(32);
        root_state.generate_legal_actions(
            db,
            root_state.current_player as usize,
            &mut legal_indices,
        );
        // No legal moves: return a dummy neutral entry.
        if legal_indices.is_empty() {
            return (vec![(0, 0.5, 0)], profiler);
        }
        // Forced move: no search needed.
        if legal_indices.len() == 1 {
            return (vec![(legal_indices[0], 0.5, 1)], profiler);
        }
        self.nodes.push(Node {
            visit_count: 0,
            value_sum: 0.0,
            // The opponent "just moved" into the root position.
            player_just_moved: 1 - root_state.current_player,
            prior_prob: 1.0,
            untried_actions: legal_indices,
            children: SmallVec::new(),
            parent: None,
            parent_action: 0,
        });
        let mut sims_done = 0;
        loop {
            if num_sims > 0 && sims_done >= num_sims {
                break;
            }
            if let Some(to) = timeout {
                if start_time.elapsed() >= to {
                    break;
                }
            }
            // No sim budget and no timeout would loop forever.
            if num_sims == 0 && timeout.is_none() {
                break;
            } // Safety
            // Verbose per-step tracing, only meaningful for a single sim.
            let is_trace = num_sims == 1
                && std::env::var("RABUKA_MCTS_TRACE")
                    .map(|value| value == "1")
                    .unwrap_or(false);
            let t_setup = std::time::Instant::now();
            // --- BRANCH: AlphaZero Batch or Sequential ---
            if let Some(ref evaluator) = self.evaluator {
                let remaining_sims = num_sims.saturating_sub(sims_done).max(1);
                let batch_size = self.cpu_batch_size.max(1).min(remaining_sims);
                let mut batch_leaf_indices = Vec::with_capacity(batch_size);
                let mut batch_leaf_states = Vec::with_capacity(batch_size);
                for _ in 0..batch_size {
                    // 1. Setup (Determinization)
                    self.reusable_state.copy_from(&root_state);
                    let state = &mut self.reusable_state;
                    state.ui.silent = true;
                    // (Determinization logic duplicated for now within batch loop)
                    let me = root_state.current_player as usize;
                    let opp = 1 - me;
                    let opp_hand_len = state.players[opp].hand.len();
                    // Pool the opponent's hidden hand + deck, shuffle, and
                    // deal back a random hand of the same size.
                    self.unseen_buffer.clear();
                    self.unseen_buffer
                        .extend_from_slice(&state.players[opp].hand);
                    self.unseen_buffer
                        .extend_from_slice(&state.players[opp].deck);
                    self.unseen_buffer.shuffle(&mut self.rng);
                    state.players[opp]
                        .hand
                        .copy_from_slice(&self.unseen_buffer[0..opp_hand_len]);
                    state.players[opp]
                        .deck
                        .copy_from_slice(&self.unseen_buffer[opp_hand_len..]);
                    // 2. Selection — descend while fully expanded.
                    let mut node_idx = 0;
                    while self.nodes[node_idx].untried_actions.is_empty()
                        && !self.nodes[node_idx].children.is_empty()
                    {
                        node_idx = Self::select_child_logic(
                            &self.nodes,
                            node_idx,
                            self.exploration_weight,
                        );
                        let action = self.nodes[node_idx].parent_action;
                        let _ = state.step(db, action);
                    }
                    // 3. Expansion — pick one random untried action.
                    if !self.nodes[node_idx].untried_actions.is_empty() {
                        let idx = self
                            .rng
                            .random_range(0..self.nodes[node_idx].untried_actions.len());
                        let action = self.nodes[node_idx].untried_actions.swap_remove(idx);
                        let actor = state.current_player;
                        let _ = state.step(db, action);
                        state.sync_stat_caches(0, db);
                        state.sync_stat_caches(1, db);
                        self.legal_buffer.clear();
                        state.generate_legal_actions(
                            db,
                            state.current_player as usize,
                            &mut self.legal_buffer,
                        );
                        let new_legal_indices = self.legal_buffer.clone();
                        // Uniform prior until the NN policy overwrites it below.
                        let prob = 1.0 / (new_legal_indices.len().max(1) as f32);
                        let new_idx = self.nodes.len();
                        self.nodes.push(Node {
                            visit_count: 0,
                            value_sum: 0.0,
                            player_just_moved: actor,
                            prior_prob: prob,
                            untried_actions: new_legal_indices,
                            children: SmallVec::new(),
                            parent: Some(node_idx),
                            parent_action: action,
                        });
                        self.nodes[node_idx].children.push((action, new_idx));
                        node_idx = new_idx;
                    }
                    batch_leaf_indices.push(node_idx);
                    batch_leaf_states.push(state.clone());
                }
                // 4. Collective Inference
                let outputs = evaluator.evaluate_batch(&batch_leaf_states, db);
                // 5. Apply results & Backprop
                for (i, output) in outputs.into_iter().enumerate() {
                    let leaf_idx = batch_leaf_indices[i];
                    // Update Policy Priors for children (if expanded)
                    if !self.nodes[leaf_idx].children.is_empty() {
                        // Assuming ACTION_SPACE is 8192
                        for j in 0..self.nodes[leaf_idx].children.len() {
                            let (action, child_idx) = self.nodes[leaf_idx].children[j];
                            if (action as usize) < output.policy.len() {
                                self.nodes[child_idx].prior_prob =
                                    output.policy[action as usize];
                            }
                        }
                    }
                    // Hybrid Reward: NN Value + Heuristic Weighting
                    let leaf_state: &GameState = &batch_leaf_states[i];
                    let reward_p0 = if leaf_state.is_terminal() {
                        match leaf_state.get_winner() {
                            0 => 1.0,
                            1 => 0.0,
                            _ => 0.5,
                        }
                    } else {
                        // Map the NN value from [-1, 1] (current player's view)
                        // to a [0, 1] win probability, then flip to P0's view.
                        let current_player_win_prob =
                            ((output.value.clamp(-1.0, 1.0) + 1.0) * 0.5).clamp(0.0, 1.0);
                        if leaf_state.current_player == 0 {
                            current_player_win_prob
                        } else {
                            1.0 - current_player_win_prob
                        }
                    };
                    // Backprop — credit each ancestor from its mover's view.
                    let mut curr = Some(leaf_idx);
                    while let Some(idx) = curr {
                        let node_p_moved = self.nodes[idx].player_just_moved;
                        let node = &mut self.nodes[idx];
                        node.visit_count += 1;
                        if node_p_moved == 0 {
                            node.value_sum += reward_p0;
                        } else {
                            node.value_sum += 1.0 - reward_p0;
                        }
                        curr = node.parent;
                    }
                }
                sims_done += batch_size;
            } else {
                // NOTE(review): `sims_done` is incremented here AND again at the
                // end of this branch, so each sequential simulation counts twice
                // against the budget (halving the effective sim count). Confirm
                // which increment should be removed.
                sims_done += 1;
                let mut node_idx = 0;
                // 1. Setup & Determinization (Sequential)
                self.reusable_state.copy_from(&root_state);
                let state = &mut self.reusable_state;
                state.ui.silent = true;
                let me = root_state.current_player as usize;
                let opp = 1 - me;
                let opp_hand_len = state.players[opp].hand.len();
                // Re-randomize the opponent's hidden hand+deck.
                self.unseen_buffer.clear();
                self.unseen_buffer
                    .extend_from_slice(&state.players[opp].hand);
                self.unseen_buffer
                    .extend_from_slice(&state.players[opp].deck);
                self.unseen_buffer.shuffle(&mut self.rng);
                state.players[opp]
                    .hand
                    .copy_from_slice(&self.unseen_buffer[0..opp_hand_len]);
                state.players[opp]
                    .deck
                    .copy_from_slice(&self.unseen_buffer[opp_hand_len..]);
                // Optionally randomize our own deck order too (Blind mode).
                if shuffle_self {
                    let mut my_deck = state.players[me].deck.clone();
                    my_deck.shuffle(&mut self.rng);
                    state.players[me].deck = my_deck;
                }
                // --- SEQUENTIAL PATH (Original logic) ---
                profiler.determinization += t_setup.elapsed();
                // 2. Selection
                let t_selection = std::time::Instant::now();
                while self.nodes[node_idx].untried_actions.is_empty()
                    && !self.nodes[node_idx].children.is_empty()
                {
                    node_idx =
                        Self::select_child_logic(&self.nodes, node_idx, self.exploration_weight);
                    let action = self.nodes[node_idx].parent_action;
                    if is_trace {
                        println!("[MCTS TRACE] Selection: Action {}", action);
                    }
                    let _ = state.step(db, action);
                }
                profiler.selection += t_selection.elapsed();
                // 3. Expansion
                let t_expansion = std::time::Instant::now();
                if !self.nodes[node_idx].untried_actions.is_empty() {
                    let idx = self
                        .rng
                        .random_range(0..self.nodes[node_idx].untried_actions.len());
                    let action = self.nodes[node_idx].untried_actions.swap_remove(idx);
                    if is_trace {
                        println!("[MCTS TRACE] Expansion: Action {}", action);
                    }
                    let actor = state.current_player;
                    let _ = state.step(db, action);
                    state.sync_stat_caches(0, db);
                    state.sync_stat_caches(1, db);
                    self.legal_buffer.clear();
                    // Past the TurnEnd horizon, leave the child leaf action-less
                    // so the tree does not grow beyond the current turn.
                    let mut is_capped = false;
                    if let SearchHorizon::TurnEnd() = horizon {
                        if state.turn > root_start_turn {
                            is_capped = true;
                        }
                    }
                    if !is_capped {
                        state.generate_legal_actions(
                            db,
                            state.current_player as usize,
                            &mut self.legal_buffer,
                        );
                    }
                    let new_legal_indices = self.legal_buffer.clone();
                    let prob = 1.0 / (new_legal_indices.len().max(1) as f32);
                    let new_node = Node {
                        visit_count: 0,
                        value_sum: 0.0,
                        player_just_moved: actor,
                        prior_prob: prob,
                        untried_actions: new_legal_indices,
                        children: SmallVec::new(),
                        parent: Some(node_idx),
                        parent_action: action,
                    };
                    let new_idx = self.nodes.len();
                    self.nodes.push(new_node);
                    self.nodes[node_idx].children.push((action, new_idx));
                    node_idx = new_idx;
                }
                profiler.expansion += t_expansion.elapsed();
                // 4. Simulation
                let t_simulation = std::time::Instant::now();
                let reward_p0 = {
                    // Standard CPU Rollout
                    let mut depth = 0;
                    if enable_rollout {
                        // Random playout, hard-capped at 200 steps.
                        while !state.is_terminal() && depth < 200 {
                            if let SearchHorizon::Limited(max_depth) = horizon {
                                if depth >= max_depth as usize {
                                    break;
                                }
                            }
                            if let SearchHorizon::TurnEnd() = horizon {
                                if state.turn > root_start_turn {
                                    break;
                                }
                            }
                            state.generate_legal_actions(
                                db,
                                state.current_player as usize,
                                &mut self.legal_buffer,
                            );
                            if self.legal_buffer.is_empty() {
                                break;
                            }
                            let chunk_action =
                                *self.legal_buffer.choose(&mut self.rng).unwrap();
                            let _ = state.step(db, chunk_action);
                            depth += 1;
                        }
                    }
                    eval_fn(&state, db)
                };
                profiler.simulation += t_simulation.elapsed();
                // 5. Backpropagation
                // NOTE(review): this timer is started but never added to
                // profiler.backpropagation, so that field stays zero on this path.
                let _t_backprop = std::time::Instant::now();
                let mut curr = Some(node_idx);
                let batch_weight = 1.0;
                let batch_visits = 1;
                while let Some(idx) = curr {
                    let node_p_moved = self.nodes[idx].player_just_moved;
                    let node = &mut self.nodes[idx];
                    node.visit_count += batch_visits as u32;
                    if node_p_moved == 0 {
                        node.value_sum += reward_p0 * batch_weight;
                    } else {
                        node.value_sum += (1.0 - reward_p0) * batch_weight;
                    }
                    curr = node.parent;
                }
                sims_done += 1;
            }
        }
        // Collect per-action stats from the root's children.
        let mut stats: Vec<(i32, f32, u32)> = self.nodes[0]
            .children
            .iter()
            .map(|&(act, idx)| {
                let child = &self.nodes[idx];
                (
                    act,
                    child.value_sum / child.visit_count as f32,
                    child.visit_count,
                )
            })
            .collect();
        // Order: visits desc, then value desc, then one-ply tiebreak
        // (lives diff, score diff) for determinism among near-equal actions.
        stats.sort_by(|left, right| {
            let left_tiebreak = Self::root_action_tiebreak(root_state, db, left.0);
            let right_tiebreak = Self::root_action_tiebreak(root_state, db, right.0);
            right
                .2
                .cmp(&left.2)
                .then_with(|| right.1.partial_cmp(&left.1).unwrap_or(std::cmp::Ordering::Equal))
                .then_with(|| right_tiebreak.0.cmp(&left_tiebreak.0))
                .then_with(|| right_tiebreak.1.cmp(&left_tiebreak.1))
        });
        (stats, profiler)
    }

    /// PUCT child selection (AlphaZero formula): returns the arena index of
    /// the child maximizing Q + c · P · sqrt(N_parent) / (1 + N_child).
    /// Free function over the arena slice so it can be used while `self` is
    /// otherwise borrowed. Returns index 0 if the node has no children.
    fn select_child_logic(nodes: &[Node], node_idx: usize, exploration_weight: f32) -> usize {
        let node = &nodes[node_idx];
        let mut best_score = f32::NEG_INFINITY;
        let mut best_child = 0;
        let total_visits_sqrt = (node.visit_count as f32).sqrt();
        for &(_, child_idx) in &node.children {
            let child = &nodes[child_idx];
            // Average Value (Q). If no visits, evaluate slightly optimistically (0.5)
            let exploit = if child.visit_count > 0 {
                child.value_sum / child.visit_count as f32
            } else {
                0.5
            };
            // PUCT Exploration (AlphaZero)
            let explore = exploration_weight
                * child.prior_prob
                * (total_visits_sqrt / (1.0 + child.visit_count as f32));
            let score = exploit + explore;
            if score > best_score {
                best_score = score;
                best_child = child_idx;
            }
        }
        best_child
    }

    /// Method-form wrapper around `select_child_logic`.
    #[allow(dead_code)]
    fn select_child(&self, node_idx: usize) -> usize {
        Self::select_child_logic(&self.nodes, node_idx, self.exploration_weight)
    }

    // Evaluation functions moved to heuristics.rs
}

/// MCTS variant whose leaf evaluation blends a heuristic baseline with an
/// ONNX neural-network value head, weighted by `neural_weight`.
#[cfg(feature = "nn")]
pub struct HybridMCTS {
    // Shared ONNX Runtime session, locked per evaluation.
    // NOTE(review): inner type parameters lost in extraction — presumably
    // `Arc<Mutex<Session>>`; confirm.
    pub session: Arc>,
    // Blend factor in [0, 1]: 0 = pure heuristic, 1 = pure NN value.
    pub neural_weight: f32,
    // When true, skip the random rollout and evaluate leaves directly.
    pub skip_rollout: bool,
    pub rng: SmallRng,
}

#[cfg(feature = "nn")]
impl HybridMCTS {
    /// Build a hybrid searcher around an existing ONNX session.
    pub fn new(session: Arc>, neural_weight: f32, skip_rollout: bool) -> Self {
        Self {
            session,
            neural_weight,
            skip_rollout,
            rng: SmallRng::from_os_rng(),
        }
    }

    /// Run a search and return the sorted action stats, printing the phase
    /// profile to stdout when more than 10 simulations were requested.
    pub fn get_suggestions(
        &mut self,
        state: &GameState,
        db: &CardDatabase,
        num_sims: usize,
        timeout_sec: f32,
    ) -> Vec<(i32, f32, u32)> {
        let start = std::time::Instant::now();
        let (stats, profile) = self.search(state, db, num_sims, timeout_sec);
        let duration = start.elapsed();
        if num_sims > 10 {
            profile.print(duration);
        }
        stats
    }

    /// Run MCTS to game end, evaluating each leaf as
    /// `h_val · (1 − w) + nn_value · w`. Falls back to the heuristic alone if
    /// any step of NN inference (tensor creation, run, output extraction)
    /// fails. Terminal states score 1.0 / 0.0 / 0.5 for P0 win / P1 win / draw.
    pub fn search(
        &mut self,
        root_state: &GameState,
        db: &CardDatabase,
        num_sims: usize,
        timeout_sec: f32,
    ) -> (Vec<(i32, f32, u32)>, MCTSProfiler) {
        let session_arc = self.session.clone();
        let neural_weight = self.neural_weight;
        let mut mcts = MCTS::new();
        mcts.run_mcts_config(
            root_state,
            db,
            num_sims,
            timeout_sec,
            SearchHorizon::GameEnd(),
            false,
            !self.skip_rollout,
            |state: &GameState, db: &CardDatabase| {
                if state.is_terminal() {
                    return match state.get_winner() {
                        0 => 1.0,
                        1 => 0.0,
                        _ => 0.5,
                    };
                }
                // Normal Heuristic Baseline
                use crate::core::heuristics::{EvalMode, OriginalHeuristic};
                let h = OriginalHeuristic::default();
                let h_val = h.evaluate(
                    state,
                    db,
                    root_state.players[0].score,
                    root_state.players[1].score,
                    EvalMode::Normal,
                    None,
                    None,
                );
                // NN Evaluation
                let input_vec = state.encode_state(db);
                let mut session = session_arc.lock().unwrap();
                // Batch dimension of 1.
                let input_shape = [1, input_vec.len()];
                // Try to create input tensor and run using (shape, vec) which is version-agnostic
                if let Ok(input_tensor) =
                    ort::value::Value::from_array((input_shape, input_vec))
                {
                    if let Ok(outputs) = session.run(ort::inputs![input_tensor]) {
                        // Try to get value output. AlphaNet has 'output_1' or index 1
                        let val_opt = outputs.get("output_1").or_else(|| outputs.get("value"));
                        if let Some(v_val) = val_opt {
                            // NOTE(review): turbofish type argument lost in
                            // extraction (likely `::<f32>()` given the `as f32`
                            // cast below); confirm.
                            if let Ok((_, v_slice)) = v_val.try_extract_tensor::() {
                                let nn_val = v_slice[0];
                                // Map [-1, 1] value to [0, 1] win probability.
                                let nn_norm = (nn_val * 0.5 + 0.5).clamp(0.0, 1.0) as f32;
                                return h_val * (1.0 - neural_weight) + nn_norm * neural_weight;
                            }
                        } else if outputs.len() > 1 {
                            // Fallback to index if names missing
                            if let Ok((_, v_slice)) = outputs[1].try_extract_tensor::() {
                                let nn_val = v_slice[0];
                                let nn_norm = (nn_val * 0.5 + 0.5).clamp(0.0, 1.0) as f32;
                                return h_val * (1.0 - neural_weight) + nn_norm * neural_weight;
                            }
                        }
                    }
                }
                // Any NN failure: heuristic-only evaluation.
                h_val
            },
        )
    }
}