#[cfg(feature = "extension-module")] use pyo3::prelude::*; use rand::prelude::*; use rand::rngs::SmallRng; use rand::SeedableRng; use crate::core::heuristics::Heuristic; use crate::core::logic::{ GameState, CardDatabase, LiveCard, Phase, FLAG_DRAW, FLAG_SEARCH, FLAG_RECOVER, FLAG_BUFF, FLAG_CHARGE, FLAG_TEMPO, FLAG_REDUCE, FLAG_BOOST, FLAG_TRANSFORM, FLAG_WIN_COND }; use std::f32; #[cfg(feature = "parallel")] use rayon::prelude::*; use std::collections::HashMap; #[cfg(feature = "nn")] use ort::session::Session; #[cfg(feature = "nn")] use std::sync::{Arc, Mutex}; use std::time::Duration; #[derive(Default, Clone, Copy)] pub struct MCTSProfiler { pub determinization: Duration, pub selection: Duration, pub expansion: Duration, pub simulation: Duration, pub backpropagation: Duration, } impl MCTSProfiler { pub fn merge(&mut self, other: &Self) { self.determinization += other.determinization; self.selection += other.selection; self.expansion += other.expansion; self.simulation += other.simulation; self.backpropagation += other.backpropagation; } pub fn print(&self, total: Duration) { let total_secs = total.as_secs_f64(); if total_secs == 0.0 { return; } println!("[MCTS Profile] Breakdown:"); let items = [ ("Determinization", self.determinization), ("Selection", self.selection), ("Expansion", self.expansion), ("Simulation", self.simulation), ("Backpropagation", self.backpropagation), ]; for (name, dur) in items { let secs = dur.as_secs_f64(); println!(" - {:<16}: {:>8.3}s ({:>6.1}%)", name, secs, (secs / total_secs) * 100.0); } } } struct Node { visit_count: u32, value_sum: f32, player_just_moved: u8, untried_actions: Vec, children: Vec<(i32, usize)>, // (Action, NodeIndex in Arena) parent: Option, parent_action: i32, } pub struct MCTS { nodes: Vec, rng: SmallRng, unseen_buffer: Vec, legal_buffer: Vec, reusable_state: GameState, } #[cfg_attr(feature = "extension-module", pyclass(eq, eq_int))] #[derive(Debug, Clone, Copy, PartialEq)] pub enum SearchHorizon { GameEnd, TurnEnd, } #[cfg_attr(feature = "extension-module", pyclass(eq, eq_int))] #[derive(Debug, Clone, Copy, PartialEq)] pub enum EvalMode { Normal, Solitaire, Blind, } impl MCTS { pub fn new() -> Self { Self { nodes: Vec::with_capacity(1000), rng: SmallRng::from_os_rng(), unseen_buffer: Vec::with_capacity(60), legal_buffer: Vec::with_capacity(32), reusable_state: GameState::default(), } } pub fn search_parallel(&self, root_state: &GameState, db: &CardDatabase, num_sims: usize, timeout_sec: f32, horizon: SearchHorizon, heuristic: &dyn Heuristic, shuffle_self: bool) -> Vec<(i32, f32, u32)> { let start_overall = std::time::Instant::now(); #[cfg(feature = "parallel")] let num_threads = rayon::current_num_threads().max(1); #[cfg(not(feature = "parallel"))] let num_threads = 1; let sims_per_thread = if num_sims > 0 { (num_sims + num_threads - 1) / num_threads } else { 0 }; // Collect results #[cfg(feature = "parallel")] let results: Vec<(Vec<(i32, f32, u32)>, MCTSProfiler)> = (0..num_threads).into_par_iter().map(|_| { let mut mcts = MCTS::new(); mcts.search_custom(root_state, db, sims_per_thread, timeout_sec, horizon, heuristic, shuffle_self, true) }).collect(); #[cfg(not(feature = "parallel"))] let results: Vec<(Vec<(i32, f32, u32)>, MCTSProfiler)> = vec![{ let mut mcts = MCTS::new(); mcts.search_custom(root_state, db, num_sims, timeout_sec, horizon, heuristic, shuffle_self, true) }]; // Merge results let mut agg_map: HashMap = HashMap::new(); let mut total_visits = 0; let mut agg_profile = MCTSProfiler::default(); for (res, profile) in results { agg_profile.merge(&profile); for (action, score, visits) in res { let entry = agg_map.entry(action).or_insert((0.0, 0)); let total_value = score * visits as f32; entry.0 += total_value; entry.1 += visits; total_visits += visits; } } // Logging Speed let duration = start_overall.elapsed(); let sims_per_sec = total_visits as f64 / duration.as_secs_f64(); if total_visits > 100 { println!("[MCTS] Completed {} sims in {:.3}s ({:.0} sims/s)", total_visits, duration.as_secs_f64(), sims_per_sec); agg_profile.print(duration); } let mut stats: Vec<(i32, f32, u32)> = agg_map.into_iter().map(|(action, (sum_val, visits))| { if visits > 0 { (action, sum_val / visits as f32, visits) } else { (action, 0.0, 0) } }).collect(); stats.sort_by_key(|&(_, _, v)| std::cmp::Reverse(v)); stats } pub fn search_parallel_mode(&self, root_state: &GameState, db: &CardDatabase, num_sims: usize, timeout_sec: f32, horizon: SearchHorizon, eval_mode: EvalMode) -> Vec<(i32, f32, u32)> { let start_overall = std::time::Instant::now(); #[cfg(feature = "parallel")] let num_threads = rayon::current_num_threads().max(1); #[cfg(not(feature = "parallel"))] let num_threads = 1; let sims_per_thread = if num_sims > 0 { (num_sims + num_threads - 1) / num_threads } else { 0 }; // Collect results #[cfg(feature = "parallel")] let results: Vec<(Vec<(i32, f32, u32)>, MCTSProfiler)> = (0..num_threads).into_par_iter().map(|_| { let mut mcts = MCTS::new(); mcts.search_mode(root_state, db, sims_per_thread, timeout_sec, horizon, eval_mode) }).collect(); #[cfg(not(feature = "parallel"))] let results: Vec<(Vec<(i32, f32, u32)>, MCTSProfiler)> = vec![{ let mut mcts = MCTS::new(); mcts.search_mode(root_state, db, num_sims, timeout_sec, horizon, eval_mode) }]; // Merge results let mut agg_map: HashMap = HashMap::new(); let mut total_visits = 0; let mut agg_profile = MCTSProfiler::default(); for (res, profile) in results { agg_profile.merge(&profile); for (action, score, visits) in res { let entry = agg_map.entry(action).or_insert((0.0, 0)); let total_value = score * visits as f32; entry.0 += total_value; entry.1 += visits; total_visits += visits; } } // Logging Speed let duration = start_overall.elapsed(); let sims_per_sec = total_visits as f64 / duration.as_secs_f64(); if total_visits > 100 { println!("[MCTS Mode] Completed {} sims in {:.3}s ({:.0} sims/s)", total_visits, duration.as_secs_f64(), sims_per_sec); agg_profile.print(duration); } let mut stats: Vec<(i32, f32, u32)> = agg_map.into_iter().map(|(action, (sum_val, visits))| { if visits > 0 { (action, sum_val / visits as f32, visits) } else { (action, 0.0, 0) } }).collect(); stats.sort_by_key(|&(_, _, v)| std::cmp::Reverse(v)); stats } pub fn search(&mut self, root_state: &GameState, db: &CardDatabase, num_sims: usize, timeout_sec: f32, horizon: SearchHorizon, heuristic: &dyn Heuristic) -> (Vec<(i32, f32, u32)>, MCTSProfiler) { self.search_custom(root_state, db, num_sims, timeout_sec, horizon, heuristic, false, true) } pub fn search_custom(&mut self, root_state: &GameState, db: &CardDatabase, num_sims: usize, timeout_sec: f32, horizon: SearchHorizon, heuristic: &dyn Heuristic, shuffle_self: bool, enable_rollout: bool) -> (Vec<(i32, f32, u32)>, MCTSProfiler) { self.run_mcts_config(root_state, db, num_sims, timeout_sec, horizon, shuffle_self, enable_rollout, |state, _db| { if state.is_terminal() { match state.get_winner() { 0 => 1.0, 1 => 0.0, _ => 0.5, } } else { heuristic.evaluate(state, db, root_state.players[0].score, root_state.players[1].score) } }) } pub fn search_mode(&mut self, root_state: &GameState, db: &CardDatabase, num_sims: usize, timeout_sec: f32, horizon: SearchHorizon, eval_mode: EvalMode) -> (Vec<(i32, f32, u32)>, MCTSProfiler) { // Pre-calculate deck expectations for optimization // P0 (Me): Use current deck let p0_stats = Self::calculate_deck_expectations(&root_state.players[0].deck, db); // P1 (Opponent): Use Hand + Deck (Unseen) let mut p1_unseen = root_state.players[1].hand.clone(); p1_unseen.extend(root_state.players[1].deck.iter().cloned()); let p1_stats = Self::calculate_deck_expectations(&p1_unseen, db); self.run_mcts_config(root_state, db, num_sims, timeout_sec, horizon, eval_mode == EvalMode::Blind, true, |state, _db| { if state.is_terminal() { if eval_mode == EvalMode::Solitaire { match state.get_winner() { 0 => 1.0, 1 => 0.0, _ => 0.5, } } else { match state.get_winner() { 0 => 1.0, 1 => 0.0, _ => 0.5, } } } else { Self::heuristic_eval(state, db, root_state.players[0].score, root_state.players[1].score, eval_mode, Some(p0_stats), Some(p1_stats)) } }) } fn run_mcts_config(&mut self, root_state: &GameState, db: &CardDatabase, num_sims: usize, timeout_sec: f32, horizon: SearchHorizon, shuffle_self: bool, enable_rollout: bool, mut eval_fn: F) -> (Vec<(i32, f32, u32)>, MCTSProfiler) where F: FnMut(&GameState, &CardDatabase) -> f32 { self.nodes.clear(); let start_time = std::time::Instant::now(); let timeout = if timeout_sec > 0.0 { Some(std::time::Duration::from_secs_f32(timeout_sec)) } else { None }; let start_turn = root_state.turn; let mut profiler = MCTSProfiler::default(); // Root Node let mut legal_indices: Vec = Vec::with_capacity(32); root_state.generate_legal_actions(db, &mut legal_indices); if legal_indices.is_empty() { return (vec![(0, 0.5, 0)], profiler); } if legal_indices.len() == 1 { return (vec![(legal_indices[0], 0.5, 1)], profiler); } self.nodes.push(Node { visit_count: 0, value_sum: 0.0, player_just_moved: 1 - root_state.current_player, untried_actions: legal_indices, children: Vec::new(), parent: None, parent_action: 0, }); let mut sims_done = 0; loop { if num_sims > 0 && sims_done >= num_sims { break; } if let Some(to) = timeout { if start_time.elapsed() >= to { break; } } if num_sims == 0 && timeout.is_none() { break; } // Safety sims_done += 1; // 1. Setup & Determinization let t_setup = std::time::Instant::now(); let mut node_idx = 0; self.reusable_state.copy_from(&root_state); let state = &mut self.reusable_state; state.silent = true; let me = root_state.current_player as usize; let opp = 1 - me; let opp_hand_len = state.players[opp].hand.len(); self.unseen_buffer.clear(); self.unseen_buffer.extend_from_slice(&state.players[opp].hand); self.unseen_buffer.extend_from_slice(&state.players[opp].deck); self.unseen_buffer.shuffle(&mut self.rng); state.players[opp].hand.copy_from_slice(&self.unseen_buffer[0..opp_hand_len]); state.players[opp].deck.copy_from_slice(&self.unseen_buffer[opp_hand_len..]); if shuffle_self { let mut my_deck = state.players[me].deck.clone(); my_deck.shuffle(&mut self.rng); state.players[me].deck = my_deck; } profiler.determinization += t_setup.elapsed(); // 2. Selection let t_selection = std::time::Instant::now(); while self.nodes[node_idx].untried_actions.is_empty() && !self.nodes[node_idx].children.is_empty() { node_idx = Self::select_child(&self.nodes, node_idx); let action = self.nodes[node_idx].parent_action; let _ = state.step(db, action); } profiler.selection += t_selection.elapsed(); // 3. Expansion let t_expansion = std::time::Instant::now(); if !self.nodes[node_idx].untried_actions.is_empty() { let idx = self.rng.random_range(0..self.nodes[node_idx].untried_actions.len()); let action = self.nodes[node_idx].untried_actions.swap_remove(idx); let actor = state.current_player; let _ = state.step(db, action); let mut new_legal_indices: Vec = Vec::with_capacity(32); state.generate_legal_actions(db, &mut new_legal_indices); let new_node = Node { visit_count: 0, value_sum: 0.0, player_just_moved: actor, untried_actions: new_legal_indices, children: Vec::new(), parent: Some(node_idx), parent_action: action, }; let new_idx = self.nodes.len(); self.nodes.push(new_node); self.nodes[node_idx].children.push((action, new_idx)); node_idx = new_idx; } profiler.expansion += t_expansion.elapsed(); // 4. Simulation let t_simulation = std::time::Instant::now(); let mut depth = 0; if enable_rollout { while !state.is_terminal() && depth < 200 { // Horizon Check if horizon == SearchHorizon::TurnEnd && state.turn > start_turn { break; } state.generate_legal_actions(db, &mut self.legal_buffer); if self.legal_buffer.is_empty() { break; } let chunk_action = *self.legal_buffer.choose(&mut self.rng).unwrap(); let _ = state.step(db, chunk_action); depth += 1; } } profiler.simulation += t_simulation.elapsed(); // 5. Backpropagation let t_backprop = std::time::Instant::now(); let reward_p0 = eval_fn(&state, db); let mut curr = Some(node_idx); while let Some(idx) = curr { let node_p_moved = self.nodes[idx].player_just_moved; let node = &mut self.nodes[idx]; node.visit_count += 1; if node_p_moved == 0 { node.value_sum += reward_p0; } else { node.value_sum += 1.0 - reward_p0; } curr = node.parent; } profiler.backpropagation += t_backprop.elapsed(); } let mut stats: Vec<(i32, f32, u32)> = self.nodes[0].children.iter() .map(|&(act, idx)| { let child = &self.nodes[idx]; (act, child.value_sum / child.visit_count as f32, child.visit_count) }) .collect(); stats.sort_by_key(|&(_, _, v)| std::cmp::Reverse(v)); (stats, profiler) } fn select_child(nodes: &[Node], node_idx: usize) -> usize { let node = &nodes[node_idx]; let mut best_score = f32::NEG_INFINITY; let mut best_child = 0; let log_n = (node.visit_count as f32).ln(); for &(_, child_idx) in &node.children { let child = &nodes[child_idx]; // UCB1 let exploit = child.value_sum / child.visit_count as f32; let explore = 1.4 * (log_n / child.visit_count as f32).sqrt(); let score = exploit + explore; if score > best_score { best_score = score; best_child = child_idx; } } best_child } fn heuristic_eval(state: &GameState, db: &CardDatabase, p0_baseline: u32, p1_baseline: u32, eval_mode: EvalMode, p0_deck_stats: Option<([f32; 7], f32)>, p1_deck_stats: Option<([f32; 7], f32)>) -> f32 { let score0 = Self::evaluate_player(state, db, 0, p0_baseline, p0_deck_stats); if eval_mode == EvalMode::Solitaire { // Solitaire: Only care about P0 score. Normalize to [0,1]. // Max possible score estimate: 3 lives * 5 + 2 bonus + 5 board + 2 hand = ~25 return (score0 / 25.0).clamp(0.0, 1.0); } let score1 = Self::evaluate_player(state, db, 1, p1_baseline, p1_deck_stats); let mut final_val = (score0 - score1) * 0.5 + 0.5; // "Win the Live" Tie-breaker (Volume Lead) // If we are in Performance/LiveResult, checking volume matters for resolving ties. // Assuming equal lives, higher volume is better. let p0_vol = state.players[0].current_turn_volume; let p1_vol = state.players[1].current_turn_volume; if p0_vol > p1_vol { final_val += 0.05; } else if p1_vol > p0_vol { final_val -= 0.05; } final_val.clamp(0.0, 1.0) } fn evaluate_player(state: &GameState, db: &CardDatabase, p_idx: usize, baseline_score: u32, deck_stats: Option<([f32; 7], f32)>) -> f32 { let p = &state.players[p_idx]; let mut score = 0.0; // 1. Success Lives (The Goal) // Heavily weight cleared lives. score += p.success_lives.len() as f32 * 5.0; // Increased weight to prioritize winning // Bonus for score increase relative to baseline if p.success_lives.len() > baseline_score as usize { score += 2.0; } // 2. Board State (Stage) let mut stage_hearts = [0u32; 7]; let mut stage_blades = 0; for i in 0..3 { let h = state.get_effective_hearts(p_idx, i, db); for color in 0..7 { stage_hearts[color] += h[color] as u32; } stage_blades += state.get_effective_blades(p_idx, i, db); } // Small bonus for board presence/power score += stage_blades as f32 * 0.1; // 3. Deck Expectations (Yells) // Calculate expected value from deck based on probability let (avg_hearts, avg_vol) = if let Some(stats) = deck_stats { stats } else { Self::calculate_deck_expectations(&p.deck, db) }; let expected_yell_hearts: Vec = avg_hearts.iter().map(|&h| h * stage_blades as f32).collect(); let expected_volume = avg_vol * stage_blades as f32; // 4. Live Clearing Probability let mut max_prob = 0.0; for &cid in &p.live_zone { if cid >= 0 { if let Some(l) = db.get_live(cid as u16) { let prob = Self::calculate_live_success_prob( l, &stage_hearts, &expected_yell_hearts, &p.heart_req_reductions ); // Reward probability of clearing. score += prob * (l.score as f32 * 2.0); // Higher weight for potential clears if prob > max_prob { max_prob = prob; } } } } // Reward Volume if we are likely to clear at least one live if max_prob > 0.5 { score += (expected_volume + p.current_turn_volume as f32) * max_prob * 0.5; } // 5. Hand Quality (Construction) let hand_val = Self::calculate_hand_quality(state, db, p_idx); score += hand_val * 0.15; // 6. Resources & Discard score += p.hand.len() as f32 * 0.05; // Energy efficiency (using available energy) let unused_energy = p.tapped_energy.iter().filter(|&&t| !t).count(); score += unused_energy as f32 * 0.01; // Discard Recovery Potential // If hand has recovery, value good stuff in discard let has_recovery = p.hand.iter().any(|&cid| { if let Some(m) = db.get_member(cid) { m.abilities.iter().any(|a| Self::has_opcode(&a.bytecode, 15) || Self::has_opcode(&a.bytecode, 17)) } else { false } }); if has_recovery { let discard_val = p.discard.iter().filter(|&&cid| { db.get_live(cid).is_some() || db.get_member(cid).map_or(false, |m| m.cost >= 3) }).count(); score += discard_val as f32 * 0.1; } score } fn calculate_deck_expectations(deck: &[u16], db: &CardDatabase) -> ([f32; 7], f32) { if deck.is_empty() { return ([0.0; 7], 0.0); } // Since deck is determinized (shuffled) in the leaf node state, // we could just look at the top N cards if we knew N (blades). // But `blades` changes. So getting average stats of the *remaining* deck is more robust. let mut total_hearts = [0.0; 7]; let mut total_vol = 0.0; let count = deck.len() as f32; for &cid in deck { if let Some(m) = db.get_member(cid) { for i in 0..7 { total_hearts[i] += m.blade_hearts[i] as f32; } total_vol += m.volume_icons as f32; } else if let Some(l) = db.get_live(cid) { for i in 0..7 { total_hearts[i] += l.blade_hearts[i] as f32; } total_vol += l.volume_icons as f32; } } let avg_hearts = total_hearts.map(|v| v / count); let avg_vol = total_vol / count; (avg_hearts, avg_vol) } fn calculate_live_success_prob(live: &LiveCard, stage_hearts: &[u32; 7], expected_yell_hearts: &[f32], reductions: &[i32; 7]) -> f32 { let mut prob; // 1. Hearts Check let mut needed = live.required_hearts; // Apply reductions for i in 0..7 { needed[i] = (needed[i] as i32 - reductions[i]).max(0) as u8; } let mut satisfied = 0.0; let mut total_req = 0.0; let mut wildcards_avail = stage_hearts[6] as f32 + expected_yell_hearts[6]; // Specific Colors for i in 0..6 { let req = needed[i] as f32; total_req += req; let have = stage_hearts[i] as f32 + expected_yell_hearts[i]; if have >= req { satisfied += req; } else { satisfied += have; let deficit = req - have; let used_wild = wildcards_avail.min(deficit); satisfied += used_wild; wildcards_avail -= used_wild; } } // ANY (Star) Requirement let any_req = needed[6] as f32; total_req += any_req; let used_wild = wildcards_avail.min(any_req); satisfied += used_wild; let mut remaining_any = any_req - used_wild; if remaining_any > 0.0 { // Use surpluses for i in 0..6 { let req = needed[i] as f32; let have = stage_hearts[i] as f32 + expected_yell_hearts[i]; let surplus = (have - req).max(0.0); let used = surplus.min(remaining_any); satisfied += used; remaining_any -= used; if remaining_any <= 0.0 { break; } } } if total_req > 0.0 { prob = (satisfied / total_req).clamp(0.0, 1.0); // Non-linear scaling: getting 90% is much better than 50% prob = prob.powf(0.5); // Square root makes it optimistic? No, we want to penalize failure. // Actually, close to 1.0 is good. if prob >= 1.0 { prob = 1.2; } // Guaranteed win bonus } else { prob = 1.2; // Free live } prob } fn calculate_hand_quality(state: &GameState, db: &CardDatabase, p_idx: usize) -> f32 { let p = &state.players[p_idx]; let mut val = 0.0; // In Mulligan phase, assume we will have ~3 energy soon for cost evaluations let is_mulligan = match state.phase { Phase::MulliganP1 | Phase::MulliganP2 => true, _ => false, }; let max_energy = if is_mulligan { 3 } else { p.energy_zone.len() as u32 }; for (i, &cid) in p.hand.iter().enumerate() { let card_val = Self::calculate_card_potential(cid, db, max_energy); // If card is selected for mulligan, it's effectively "gone" but replaced by an "average" card // Average card potential is roughly 1.0 if is_mulligan && ((p.mulligan_selection >> i) & 1u64 == 1) { val += 1.0; // Expected potential of replacement } else { val += card_val; } } val } fn calculate_card_potential(cid: u16, db: &CardDatabase, max_energy: u32) -> f32 { if let Some(m) = db.get_member(cid) { let mut score = 0.0; // Cost Efficiency let stat_sum: u32 = m.hearts.iter().map(|&x| x as u32).sum(); score += (m.blades as f32 + stat_sum as f32) / (m.cost as f32 + 1.0); // Penalty for unplayable high-cost cards if m.cost > max_energy { let diff = m.cost - max_energy; score -= diff as f32 * 0.5; } // Fast Flag Checks (O(1)) let f = m.ability_flags; if (f & FLAG_DRAW) != 0 { score += 0.5; } if (f & FLAG_SEARCH) != 0 { score += 0.6; } if (f & FLAG_RECOVER) != 0 { score += 0.4; } if (f & FLAG_BUFF) != 0 { score += 0.3; } if (f & FLAG_CHARGE) != 0 { score += 0.8; } if (f & FLAG_TEMPO) != 0 { score += 0.2; } if (f & FLAG_REDUCE) != 0 { score += 0.4; } if (f & FLAG_BOOST) != 0 { score += 0.5; } if (f & FLAG_TRANSFORM) != 0 { score += 0.3; } if (f & FLAG_WIN_COND) != 0 { score += 0.6; } score } else if let Some(l) = db.get_live(cid) { // Lives in hand are good if we can clear them, but bad if they clog // Simple heuristic: Score value l.score as f32 * 0.2 } else { 0.0 } } fn has_opcode(bytecode: &[i32], target_op: i32) -> bool { let mut i = 0; while i < bytecode.len() { if i + 3 >= bytecode.len() { break; } let op = bytecode[i]; if op == target_op { return true; } i += 4; } false } } #[cfg(feature = "nn")] pub struct HybridMCTS { pub session: Arc>, pub neural_weight: f32, pub skip_rollout: bool, pub rng: SmallRng, } #[cfg(feature = "nn")] impl HybridMCTS { pub fn new(session: Arc>, neural_weight: f32, skip_rollout: bool) -> Self { Self { session, neural_weight, skip_rollout, rng: SmallRng::from_os_rng(), } } pub fn get_suggestions(&mut self, state: &GameState, db: &CardDatabase, num_sims: usize, timeout_sec: f32) -> Vec<(i32, f32, u32)> { let start = std::time::Instant::now(); let (stats, profile) = self.search(state, db, num_sims, timeout_sec); let duration = start.elapsed(); if num_sims > 10 { profile.print(duration); } stats } pub fn search(&mut self, root_state: &GameState, db: &CardDatabase, num_sims: usize, timeout_sec: f32) -> (Vec<(i32, f32, u32)>, MCTSProfiler) { let session_arc = self.session.clone(); let neural_weight = self.neural_weight; let mut mcts = MCTS::new(); mcts.run_mcts_config(root_state, db, num_sims, timeout_sec, SearchHorizon::GameEnd, false, !self.skip_rollout, |state: &GameState, db: &CardDatabase| { if state.is_terminal() { return match state.get_winner() { 0 => 1.0, 1 => 0.0, _ => 0.5, }; } // Normal Heuristic Baseline let h_val = MCTS::heuristic_eval(state, db, root_state.players[0].score, root_state.players[1].score, EvalMode::Normal, None, None); // NN Evaluation let input_vec = state.encode_state(db); let mut session = session_arc.lock().unwrap(); let input_shape = [1, input_vec.len()]; // Try to create input tensor and run using (shape, vec) which is version-agnostic if let Ok(input_tensor) = ort::value::Value::from_array((input_shape, input_vec)) { if let Ok(outputs) = session.run(ort::inputs![input_tensor]) { // Try to get value output. AlphaNet has 'output_1' or index 1 let val_opt = outputs.get("output_1") .or_else(|| outputs.get("value")); if let Some(v_val) = val_opt { if let Ok((_, v_slice)) = v_val.try_extract_tensor::() { let nn_val = v_slice[0]; let nn_norm = (nn_val * 0.5 + 0.5).clamp(0.0, 1.0) as f32; return h_val * (1.0 - neural_weight) + nn_norm * neural_weight; } } else if outputs.len() > 1 { // Fallback to index if names missing if let Ok((_, v_slice)) = outputs[1].try_extract_tensor::() { let nn_val = v_slice[0]; let nn_norm = (nn_val * 0.5 + 0.5).clamp(0.0, 1.0) as f32; return h_val * (1.0 - neural_weight) + nn_norm * neural_weight; } } } } h_val }) } }