Spaces:
Running
Running
| use pyo3::prelude::*; | |
| use rand::prelude::*; | |
| use rand::rngs::SmallRng; | |
| use rand::SeedableRng; | |
| use crate::core::heuristics::Heuristic; | |
| use crate::core::logic::{ | |
| GameState, CardDatabase, LiveCard, Phase, | |
| FLAG_DRAW, FLAG_SEARCH, FLAG_RECOVER, FLAG_BUFF, FLAG_CHARGE, | |
| FLAG_TEMPO, FLAG_REDUCE, FLAG_BOOST, FLAG_TRANSFORM, FLAG_WIN_COND | |
| }; | |
| use std::f32; | |
| use rayon::prelude::*; | |
| use std::collections::HashMap; | |
| use ort::session::Session; | |
| use std::sync::{Arc, Mutex}; | |
| use std::time::Duration; | |
/// Wall-clock time spent in each phase of the search loop, accumulated over simulations.
///
/// `Default` yields an all-zero profile; the search routines start from it and add the
/// measured duration of every phase after each iteration.
#[derive(Debug, Clone, Default)]
pub struct MCTSProfiler {
    /// Time spent copying the root state and reshuffling hidden cards.
    pub determinization: Duration,
    /// Time spent descending the tree through already-expanded nodes.
    pub selection: Duration,
    /// Time spent creating new tree nodes.
    pub expansion: Duration,
    /// Time spent in random rollouts.
    pub simulation: Duration,
    /// Time spent evaluating the leaf and propagating the reward to the root.
    pub backpropagation: Duration,
}

impl MCTSProfiler {
    /// Adds every phase duration of `other` onto `self`.
    pub fn merge(&mut self, other: &Self) {
        self.determinization += other.determinization;
        self.selection += other.selection;
        self.expansion += other.expansion;
        self.simulation += other.simulation;
        self.backpropagation += other.backpropagation;
    }

    /// Prints each phase in seconds and as a percentage of `total`.
    ///
    /// Prints nothing when `total` is zero, because the percentages would be undefined.
    pub fn print(&self, total: Duration) {
        let total_secs = total.as_secs_f64();
        if total_secs == 0.0 {
            return;
        }
        println!("[MCTS Profile] Breakdown:");
        let phases = [
            ("Determinization", self.determinization),
            ("Selection", self.selection),
            ("Expansion", self.expansion),
            ("Simulation", self.simulation),
            ("Backpropagation", self.backpropagation),
        ];
        for (name, spent) in phases {
            let secs = spent.as_secs_f64();
            println!(" - {:<16}: {:>8.3}s ({:>6.1}%)", name, secs, (secs / total_secs) * 100.0);
        }
    }
}
/// One vertex of the search tree, stored in an arena (`Vec<Node>`) and addressed by index.
struct Node {
    /// Number of simulations that were backpropagated through this node.
    visit_count: u32,
    /// Accumulated reward, expressed from the point of view of `player_just_moved`.
    value_sum: f32,
    /// Index of the player whose action led into this node.
    player_just_moved: u8,
    /// Legal actions recorded for this node that have no child yet.
    untried_actions: Vec<i32>,
    /// Expanded children as `(action, arena index)` pairs.
    children: Vec<(i32, usize)>,
    /// Arena index of the parent; `None` for the root.
    parent: Option<usize>,
    /// Action that was applied to the parent to reach this node.
    parent_action: i32,
}
| pub struct MCTS { | |
| nodes: Vec<Node>, | |
| rng: SmallRng, | |
| unseen_buffer: Vec<u16>, | |
| legal_buffer: Vec<i32>, | |
| reusable_state: GameState, | |
| } | |
/// How far a rollout is allowed to look ahead.
///
/// The search compares values of this type with `==` and hands the same value to several
/// worker closures, so it must be `PartialEq` and `Copy`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SearchHorizon {
    /// Roll out until the game is terminal (or the rollout depth cap is hit).
    GameEnd,
    /// Stop the rollout as soon as the turn counter advances past the root's turn.
    TurnEnd,
}
/// Selects how leaf states are scored by the built-in heuristic.
///
/// The search compares values of this type with `==` and reuses the same value on every
/// leaf evaluation, so it must be `PartialEq` and `Copy`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EvalMode {
    /// Score player 0 against player 1.
    Normal,
    /// Score only player 0's own progress and ignore the opponent.
    Solitaire,
    /// Like `Normal`, but the searching player's own deck is reshuffled every simulation.
    Blind,
}
| impl MCTS { | |
| pub fn new() -> Self { | |
| Self { | |
| nodes: Vec::with_capacity(1000), | |
| rng: SmallRng::from_os_rng(), | |
| unseen_buffer: Vec::with_capacity(60), | |
| legal_buffer: Vec::with_capacity(32), | |
| reusable_state: GameState::default(), | |
| } | |
| } | |
| pub fn search_parallel(&self, root_state: &GameState, db: &CardDatabase, num_sims: usize, timeout_sec: f32, horizon: SearchHorizon, heuristic: &dyn Heuristic, shuffle_self: bool) -> Vec<(i32, f32, u32)> { | |
| let start_overall = std::time::Instant::now(); | |
| let num_threads = rayon::current_num_threads().max(1); | |
| let num_threads = 1; | |
| let sims_per_thread = if num_sims > 0 { (num_sims + num_threads - 1) / num_threads } else { 0 }; | |
| // Collect results | |
| let results: Vec<(Vec<(i32, f32, u32)>, MCTSProfiler)> = (0..num_threads).into_par_iter().map(|_| { | |
| let mut mcts = MCTS::new(); | |
| mcts.search_custom(root_state, db, sims_per_thread, timeout_sec, horizon, heuristic, shuffle_self, true) | |
| }).collect(); | |
| let results: Vec<(Vec<(i32, f32, u32)>, MCTSProfiler)> = vec![{ | |
| let mut mcts = MCTS::new(); | |
| mcts.search_custom(root_state, db, num_sims, timeout_sec, horizon, heuristic, shuffle_self, true) | |
| }]; | |
| // Merge results | |
| let mut agg_map: HashMap<i32, (f32, u32)> = HashMap::new(); | |
| let mut total_visits = 0; | |
| let mut agg_profile = MCTSProfiler::default(); | |
| for (res, profile) in results { | |
| agg_profile.merge(&profile); | |
| for (action, score, visits) in res { | |
| let entry = agg_map.entry(action).or_insert((0.0, 0)); | |
| let total_value = score * visits as f32; | |
| entry.0 += total_value; | |
| entry.1 += visits; | |
| total_visits += visits; | |
| } | |
| } | |
| // Logging Speed | |
| let duration = start_overall.elapsed(); | |
| let sims_per_sec = total_visits as f64 / duration.as_secs_f64(); | |
| if total_visits > 100 { | |
| println!("[MCTS] Completed {} sims in {:.3}s ({:.0} sims/s)", total_visits, duration.as_secs_f64(), sims_per_sec); | |
| agg_profile.print(duration); | |
| } | |
| let mut stats: Vec<(i32, f32, u32)> = agg_map.into_iter().map(|(action, (sum_val, visits))| { | |
| if visits > 0 { | |
| (action, sum_val / visits as f32, visits) | |
| } else { | |
| (action, 0.0, 0) | |
| } | |
| }).collect(); | |
| stats.sort_by_key(|&(_, _, v)| std::cmp::Reverse(v)); | |
| stats | |
| } | |
| pub fn search_parallel_mode(&self, root_state: &GameState, db: &CardDatabase, num_sims: usize, timeout_sec: f32, horizon: SearchHorizon, eval_mode: EvalMode) -> Vec<(i32, f32, u32)> { | |
| let start_overall = std::time::Instant::now(); | |
| let num_threads = rayon::current_num_threads().max(1); | |
| let num_threads = 1; | |
| let sims_per_thread = if num_sims > 0 { (num_sims + num_threads - 1) / num_threads } else { 0 }; | |
| // Collect results | |
| let results: Vec<(Vec<(i32, f32, u32)>, MCTSProfiler)> = (0..num_threads).into_par_iter().map(|_| { | |
| let mut mcts = MCTS::new(); | |
| mcts.search_mode(root_state, db, sims_per_thread, timeout_sec, horizon, eval_mode) | |
| }).collect(); | |
| let results: Vec<(Vec<(i32, f32, u32)>, MCTSProfiler)> = vec![{ | |
| let mut mcts = MCTS::new(); | |
| mcts.search_mode(root_state, db, num_sims, timeout_sec, horizon, eval_mode) | |
| }]; | |
| // Merge results | |
| let mut agg_map: HashMap<i32, (f32, u32)> = HashMap::new(); | |
| let mut total_visits = 0; | |
| let mut agg_profile = MCTSProfiler::default(); | |
| for (res, profile) in results { | |
| agg_profile.merge(&profile); | |
| for (action, score, visits) in res { | |
| let entry = agg_map.entry(action).or_insert((0.0, 0)); | |
| let total_value = score * visits as f32; | |
| entry.0 += total_value; | |
| entry.1 += visits; | |
| total_visits += visits; | |
| } | |
| } | |
| // Logging Speed | |
| let duration = start_overall.elapsed(); | |
| let sims_per_sec = total_visits as f64 / duration.as_secs_f64(); | |
| if total_visits > 100 { | |
| println!("[MCTS Mode] Completed {} sims in {:.3}s ({:.0} sims/s)", total_visits, duration.as_secs_f64(), sims_per_sec); | |
| agg_profile.print(duration); | |
| } | |
| let mut stats: Vec<(i32, f32, u32)> = agg_map.into_iter().map(|(action, (sum_val, visits))| { | |
| if visits > 0 { | |
| (action, sum_val / visits as f32, visits) | |
| } else { | |
| (action, 0.0, 0) | |
| } | |
| }).collect(); | |
| stats.sort_by_key(|&(_, _, v)| std::cmp::Reverse(v)); | |
| stats | |
| } | |
| pub fn search(&mut self, root_state: &GameState, db: &CardDatabase, num_sims: usize, timeout_sec: f32, horizon: SearchHorizon, heuristic: &dyn Heuristic) -> (Vec<(i32, f32, u32)>, MCTSProfiler) { | |
| self.search_custom(root_state, db, num_sims, timeout_sec, horizon, heuristic, false, true) | |
| } | |
| pub fn search_custom(&mut self, root_state: &GameState, db: &CardDatabase, num_sims: usize, timeout_sec: f32, horizon: SearchHorizon, heuristic: &dyn Heuristic, shuffle_self: bool, enable_rollout: bool) -> (Vec<(i32, f32, u32)>, MCTSProfiler) { | |
| self.run_mcts_config(root_state, db, num_sims, timeout_sec, horizon, shuffle_self, enable_rollout, |state, _db| { | |
| if state.is_terminal() { | |
| match state.get_winner() { | |
| 0 => 1.0, | |
| 1 => 0.0, | |
| _ => 0.5, | |
| } | |
| } else { | |
| heuristic.evaluate(state, db, root_state.players[0].score, root_state.players[1].score) | |
| } | |
| }) | |
| } | |
| pub fn search_mode(&mut self, root_state: &GameState, db: &CardDatabase, num_sims: usize, timeout_sec: f32, horizon: SearchHorizon, eval_mode: EvalMode) -> (Vec<(i32, f32, u32)>, MCTSProfiler) { | |
| // Pre-calculate deck expectations for optimization | |
| // P0 (Me): Use current deck | |
| let p0_stats = Self::calculate_deck_expectations(&root_state.players[0].deck, db); | |
| // P1 (Opponent): Use Hand + Deck (Unseen) | |
| let mut p1_unseen = root_state.players[1].hand.clone(); | |
| p1_unseen.extend(root_state.players[1].deck.iter().cloned()); | |
| let p1_stats = Self::calculate_deck_expectations(&p1_unseen, db); | |
| self.run_mcts_config(root_state, db, num_sims, timeout_sec, horizon, eval_mode == EvalMode::Blind, true, |state, _db| { | |
| if state.is_terminal() { | |
| if eval_mode == EvalMode::Solitaire { | |
| match state.get_winner() { | |
| 0 => 1.0, | |
| 1 => 0.0, | |
| _ => 0.5, | |
| } | |
| } else { | |
| match state.get_winner() { | |
| 0 => 1.0, | |
| 1 => 0.0, | |
| _ => 0.5, | |
| } | |
| } | |
| } else { | |
| Self::heuristic_eval(state, db, root_state.players[0].score, root_state.players[1].score, eval_mode, Some(p0_stats), Some(p1_stats)) | |
| } | |
| }) | |
| } | |
| fn run_mcts_config<F>(&mut self, root_state: &GameState, db: &CardDatabase, num_sims: usize, timeout_sec: f32, horizon: SearchHorizon, shuffle_self: bool, enable_rollout: bool, mut eval_fn: F) -> (Vec<(i32, f32, u32)>, MCTSProfiler) | |
| where F: FnMut(&GameState, &CardDatabase) -> f32 | |
| { | |
| self.nodes.clear(); | |
| let start_time = std::time::Instant::now(); | |
| let timeout = if timeout_sec > 0.0 { Some(std::time::Duration::from_secs_f32(timeout_sec)) } else { None }; | |
| let start_turn = root_state.turn; | |
| let mut profiler = MCTSProfiler::default(); | |
| // Root Node | |
| let mut legal_indices: Vec<i32> = Vec::with_capacity(32); | |
| root_state.generate_legal_actions(db, &mut legal_indices); | |
| if legal_indices.is_empty() { return (vec![(0, 0.5, 0)], profiler); } | |
| if legal_indices.len() == 1 { return (vec![(legal_indices[0], 0.5, 1)], profiler); } | |
| self.nodes.push(Node { | |
| visit_count: 0, | |
| value_sum: 0.0, | |
| player_just_moved: 1 - root_state.current_player, | |
| untried_actions: legal_indices, | |
| children: Vec::new(), | |
| parent: None, | |
| parent_action: 0, | |
| }); | |
| let mut sims_done = 0; | |
| loop { | |
| if num_sims > 0 && sims_done >= num_sims { break; } | |
| if let Some(to) = timeout { | |
| if start_time.elapsed() >= to { break; } | |
| } | |
| if num_sims == 0 && timeout.is_none() { break; } // Safety | |
| sims_done += 1; | |
| // 1. Setup & Determinization | |
| let t_setup = std::time::Instant::now(); | |
| let mut node_idx = 0; | |
| self.reusable_state.copy_from(&root_state); | |
| let state = &mut self.reusable_state; | |
| state.silent = true; | |
| let me = root_state.current_player as usize; | |
| let opp = 1 - me; | |
| let opp_hand_len = state.players[opp].hand.len(); | |
| self.unseen_buffer.clear(); | |
| self.unseen_buffer.extend_from_slice(&state.players[opp].hand); | |
| self.unseen_buffer.extend_from_slice(&state.players[opp].deck); | |
| self.unseen_buffer.shuffle(&mut self.rng); | |
| state.players[opp].hand.copy_from_slice(&self.unseen_buffer[0..opp_hand_len]); | |
| state.players[opp].deck.copy_from_slice(&self.unseen_buffer[opp_hand_len..]); | |
| if shuffle_self { | |
| let mut my_deck = state.players[me].deck.clone(); | |
| my_deck.shuffle(&mut self.rng); | |
| state.players[me].deck = my_deck; | |
| } | |
| profiler.determinization += t_setup.elapsed(); | |
| // 2. Selection | |
| let t_selection = std::time::Instant::now(); | |
| while self.nodes[node_idx].untried_actions.is_empty() && !self.nodes[node_idx].children.is_empty() { | |
| node_idx = Self::select_child(&self.nodes, node_idx); | |
| let action = self.nodes[node_idx].parent_action; | |
| let _ = state.step(db, action); | |
| } | |
| profiler.selection += t_selection.elapsed(); | |
| // 3. Expansion | |
| let t_expansion = std::time::Instant::now(); | |
| if !self.nodes[node_idx].untried_actions.is_empty() { | |
| let idx = self.rng.random_range(0..self.nodes[node_idx].untried_actions.len()); | |
| let action = self.nodes[node_idx].untried_actions.swap_remove(idx); | |
| let actor = state.current_player; | |
| let _ = state.step(db, action); | |
| let mut new_legal_indices: Vec<i32> = Vec::with_capacity(32); | |
| state.generate_legal_actions(db, &mut new_legal_indices); | |
| let new_node = Node { | |
| visit_count: 0, | |
| value_sum: 0.0, | |
| player_just_moved: actor, | |
| untried_actions: new_legal_indices, | |
| children: Vec::new(), | |
| parent: Some(node_idx), | |
| parent_action: action, | |
| }; | |
| let new_idx = self.nodes.len(); | |
| self.nodes.push(new_node); | |
| self.nodes[node_idx].children.push((action, new_idx)); | |
| node_idx = new_idx; | |
| } | |
| profiler.expansion += t_expansion.elapsed(); | |
| // 4. Simulation | |
| let t_simulation = std::time::Instant::now(); | |
| let mut depth = 0; | |
| if enable_rollout { | |
| while !state.is_terminal() && depth < 200 { | |
| // Horizon Check | |
| if horizon == SearchHorizon::TurnEnd && state.turn > start_turn { | |
| break; | |
| } | |
| state.generate_legal_actions(db, &mut self.legal_buffer); | |
| if self.legal_buffer.is_empty() { break; } | |
| let chunk_action = *self.legal_buffer.choose(&mut self.rng).unwrap(); | |
| let _ = state.step(db, chunk_action); | |
| depth += 1; | |
| } | |
| } | |
| profiler.simulation += t_simulation.elapsed(); | |
| // 5. Backpropagation | |
| let t_backprop = std::time::Instant::now(); | |
| let reward_p0 = eval_fn(&state, db); | |
| let mut curr = Some(node_idx); | |
| while let Some(idx) = curr { | |
| let node_p_moved = self.nodes[idx].player_just_moved; | |
| let node = &mut self.nodes[idx]; | |
| node.visit_count += 1; | |
| if node_p_moved == 0 { | |
| node.value_sum += reward_p0; | |
| } else { | |
| node.value_sum += 1.0 - reward_p0; | |
| } | |
| curr = node.parent; | |
| } | |
| profiler.backpropagation += t_backprop.elapsed(); | |
| } | |
| let mut stats: Vec<(i32, f32, u32)> = self.nodes[0].children.iter() | |
| .map(|&(act, idx)| { | |
| let child = &self.nodes[idx]; | |
| (act, child.value_sum / child.visit_count as f32, child.visit_count) | |
| }) | |
| .collect(); | |
| stats.sort_by_key(|&(_, _, v)| std::cmp::Reverse(v)); | |
| (stats, profiler) | |
| } | |
| fn select_child(nodes: &[Node], node_idx: usize) -> usize { | |
| let node = &nodes[node_idx]; | |
| let mut best_score = f32::NEG_INFINITY; | |
| let mut best_child = 0; | |
| let log_n = (node.visit_count as f32).ln(); | |
| for &(_, child_idx) in &node.children { | |
| let child = &nodes[child_idx]; | |
| // UCB1 | |
| let exploit = child.value_sum / child.visit_count as f32; | |
| let explore = 1.4 * (log_n / child.visit_count as f32).sqrt(); | |
| let score = exploit + explore; | |
| if score > best_score { | |
| best_score = score; | |
| best_child = child_idx; | |
| } | |
| } | |
| best_child | |
| } | |
| fn heuristic_eval(state: &GameState, db: &CardDatabase, p0_baseline: u32, p1_baseline: u32, eval_mode: EvalMode, p0_deck_stats: Option<([f32; 7], f32)>, p1_deck_stats: Option<([f32; 7], f32)>) -> f32 { | |
| let score0 = Self::evaluate_player(state, db, 0, p0_baseline, p0_deck_stats); | |
| if eval_mode == EvalMode::Solitaire { | |
| // Solitaire: Only care about P0 score. Normalize to [0,1]. | |
| // Max possible score estimate: 3 lives * 5 + 2 bonus + 5 board + 2 hand = ~25 | |
| return (score0 / 25.0).clamp(0.0, 1.0); | |
| } | |
| let score1 = Self::evaluate_player(state, db, 1, p1_baseline, p1_deck_stats); | |
| let mut final_val = (score0 - score1) * 0.5 + 0.5; | |
| // "Win the Live" Tie-breaker (Volume Lead) | |
| // If we are in Performance/LiveResult, checking volume matters for resolving ties. | |
| // Assuming equal lives, higher volume is better. | |
| let p0_vol = state.players[0].current_turn_volume; | |
| let p1_vol = state.players[1].current_turn_volume; | |
| if p0_vol > p1_vol { | |
| final_val += 0.05; | |
| } else if p1_vol > p0_vol { | |
| final_val -= 0.05; | |
| } | |
| final_val.clamp(0.0, 1.0) | |
| } | |
| fn evaluate_player(state: &GameState, db: &CardDatabase, p_idx: usize, baseline_score: u32, deck_stats: Option<([f32; 7], f32)>) -> f32 { | |
| let p = &state.players[p_idx]; | |
| let mut score = 0.0; | |
| // 1. Success Lives (The Goal) | |
| // Heavily weight cleared lives. | |
| score += p.success_lives.len() as f32 * 5.0; // Increased weight to prioritize winning | |
| // Bonus for score increase relative to baseline | |
| if p.success_lives.len() > baseline_score as usize { | |
| score += 2.0; | |
| } | |
| // 2. Board State (Stage) | |
| let mut stage_hearts = [0u32; 7]; | |
| let mut stage_blades = 0; | |
| for i in 0..3 { | |
| let h = state.get_effective_hearts(p_idx, i, db); | |
| for color in 0..7 { stage_hearts[color] += h[color] as u32; } | |
| stage_blades += state.get_effective_blades(p_idx, i, db); | |
| } | |
| // Small bonus for board presence/power | |
| score += stage_blades as f32 * 0.1; | |
| // 3. Deck Expectations (Yells) | |
| // Calculate expected value from deck based on probability | |
| let (avg_hearts, avg_vol) = if let Some(stats) = deck_stats { | |
| stats | |
| } else { | |
| Self::calculate_deck_expectations(&p.deck, db) | |
| }; | |
| let expected_yell_hearts: Vec<f32> = avg_hearts.iter().map(|&h| h * stage_blades as f32).collect(); | |
| let expected_volume = avg_vol * stage_blades as f32; | |
| // 4. Live Clearing Probability | |
| let mut max_prob = 0.0; | |
| for &cid in &p.live_zone { | |
| if cid >= 0 { | |
| if let Some(l) = db.get_live(cid as u16) { | |
| let prob = Self::calculate_live_success_prob( | |
| l, | |
| &stage_hearts, | |
| &expected_yell_hearts, | |
| &p.heart_req_reductions | |
| ); | |
| // Reward probability of clearing. | |
| score += prob * (l.score as f32 * 2.0); // Higher weight for potential clears | |
| if prob > max_prob { max_prob = prob; } | |
| } | |
| } | |
| } | |
| // Reward Volume if we are likely to clear at least one live | |
| if max_prob > 0.5 { | |
| score += (expected_volume + p.current_turn_volume as f32) * max_prob * 0.5; | |
| } | |
| // 5. Hand Quality (Construction) | |
| let hand_val = Self::calculate_hand_quality(state, db, p_idx); | |
| score += hand_val * 0.15; | |
| // 6. Resources & Discard | |
| score += p.hand.len() as f32 * 0.05; | |
| // Energy efficiency (using available energy) | |
| let unused_energy = p.tapped_energy.iter().filter(|&&t| !t).count(); | |
| score += unused_energy as f32 * 0.01; | |
| // Discard Recovery Potential | |
| // If hand has recovery, value good stuff in discard | |
| let has_recovery = p.hand.iter().any(|&cid| { | |
| if let Some(m) = db.get_member(cid) { | |
| m.abilities.iter().any(|a| Self::has_opcode(&a.bytecode, 15) || Self::has_opcode(&a.bytecode, 17)) | |
| } else { false } | |
| }); | |
| if has_recovery { | |
| let discard_val = p.discard.iter().filter(|&&cid| { | |
| db.get_live(cid).is_some() || db.get_member(cid).map_or(false, |m| m.cost >= 3) | |
| }).count(); | |
| score += discard_val as f32 * 0.1; | |
| } | |
| score | |
| } | |
| fn calculate_deck_expectations(deck: &[u16], db: &CardDatabase) -> ([f32; 7], f32) { | |
| if deck.is_empty() { return ([0.0; 7], 0.0); } | |
| // Since deck is determinized (shuffled) in the leaf node state, | |
| // we could just look at the top N cards if we knew N (blades). | |
| // But `blades` changes. So getting average stats of the *remaining* deck is more robust. | |
| let mut total_hearts = [0.0; 7]; | |
| let mut total_vol = 0.0; | |
| let count = deck.len() as f32; | |
| for &cid in deck { | |
| if let Some(m) = db.get_member(cid) { | |
| for i in 0..7 { total_hearts[i] += m.blade_hearts[i] as f32; } | |
| total_vol += m.volume_icons as f32; | |
| } else if let Some(l) = db.get_live(cid) { | |
| for i in 0..7 { total_hearts[i] += l.blade_hearts[i] as f32; } | |
| total_vol += l.volume_icons as f32; | |
| } | |
| } | |
| let avg_hearts = total_hearts.map(|v| v / count); | |
| let avg_vol = total_vol / count; | |
| (avg_hearts, avg_vol) | |
| } | |
| fn calculate_live_success_prob(live: &LiveCard, stage_hearts: &[u32; 7], expected_yell_hearts: &[f32], reductions: &[i32; 7]) -> f32 { | |
| let mut prob; | |
| // 1. Hearts Check | |
| let mut needed = live.required_hearts; | |
| // Apply reductions | |
| for i in 0..7 { | |
| needed[i] = (needed[i] as i32 - reductions[i]).max(0) as u8; | |
| } | |
| let mut satisfied = 0.0; | |
| let mut total_req = 0.0; | |
| let mut wildcards_avail = stage_hearts[6] as f32 + expected_yell_hearts[6]; | |
| // Specific Colors | |
| for i in 0..6 { | |
| let req = needed[i] as f32; | |
| total_req += req; | |
| let have = stage_hearts[i] as f32 + expected_yell_hearts[i]; | |
| if have >= req { | |
| satisfied += req; | |
| } else { | |
| satisfied += have; | |
| let deficit = req - have; | |
| let used_wild = wildcards_avail.min(deficit); | |
| satisfied += used_wild; | |
| wildcards_avail -= used_wild; | |
| } | |
| } | |
| // ANY (Star) Requirement | |
| let any_req = needed[6] as f32; | |
| total_req += any_req; | |
| let used_wild = wildcards_avail.min(any_req); | |
| satisfied += used_wild; | |
| let mut remaining_any = any_req - used_wild; | |
| if remaining_any > 0.0 { | |
| // Use surpluses | |
| for i in 0..6 { | |
| let req = needed[i] as f32; | |
| let have = stage_hearts[i] as f32 + expected_yell_hearts[i]; | |
| let surplus = (have - req).max(0.0); | |
| let used = surplus.min(remaining_any); | |
| satisfied += used; | |
| remaining_any -= used; | |
| if remaining_any <= 0.0 { break; } | |
| } | |
| } | |
| if total_req > 0.0 { | |
| prob = (satisfied / total_req).clamp(0.0, 1.0); | |
| // Non-linear scaling: getting 90% is much better than 50% | |
| prob = prob.powf(0.5); // Square root makes it optimistic? No, we want to penalize failure. | |
| // Actually, close to 1.0 is good. | |
| if prob >= 1.0 { prob = 1.2; } // Guaranteed win bonus | |
| } else { | |
| prob = 1.2; // Free live | |
| } | |
| prob | |
| } | |
| fn calculate_hand_quality(state: &GameState, db: &CardDatabase, p_idx: usize) -> f32 { | |
| let p = &state.players[p_idx]; | |
| let mut val = 0.0; | |
| // In Mulligan phase, assume we will have ~3 energy soon for cost evaluations | |
| let is_mulligan = match state.phase { | |
| Phase::MulliganP1 | Phase::MulliganP2 => true, | |
| _ => false, | |
| }; | |
| let max_energy = if is_mulligan { 3 } else { p.energy_zone.len() as u32 }; | |
| for (i, &cid) in p.hand.iter().enumerate() { | |
| let card_val = Self::calculate_card_potential(cid, db, max_energy); | |
| // If card is selected for mulligan, it's effectively "gone" but replaced by an "average" card | |
| // Average card potential is roughly 1.0 | |
| if is_mulligan && ((p.mulligan_selection >> i) & 1u64 == 1) { | |
| val += 1.0; // Expected potential of replacement | |
| } else { | |
| val += card_val; | |
| } | |
| } | |
| val | |
| } | |
| fn calculate_card_potential(cid: u16, db: &CardDatabase, max_energy: u32) -> f32 { | |
| if let Some(m) = db.get_member(cid) { | |
| let mut score = 0.0; | |
| // Cost Efficiency | |
| let stat_sum: u32 = m.hearts.iter().map(|&x| x as u32).sum(); | |
| score += (m.blades as f32 + stat_sum as f32) / (m.cost as f32 + 1.0); | |
| // Penalty for unplayable high-cost cards | |
| if m.cost > max_energy { | |
| let diff = m.cost - max_energy; | |
| score -= diff as f32 * 0.5; | |
| } | |
| // Fast Flag Checks (O(1)) | |
| let f = m.ability_flags; | |
| if (f & FLAG_DRAW) != 0 { score += 0.5; } | |
| if (f & FLAG_SEARCH) != 0 { score += 0.6; } | |
| if (f & FLAG_RECOVER) != 0 { score += 0.4; } | |
| if (f & FLAG_BUFF) != 0 { score += 0.3; } | |
| if (f & FLAG_CHARGE) != 0 { score += 0.8; } | |
| if (f & FLAG_TEMPO) != 0 { score += 0.2; } | |
| if (f & FLAG_REDUCE) != 0 { score += 0.4; } | |
| if (f & FLAG_BOOST) != 0 { score += 0.5; } | |
| if (f & FLAG_TRANSFORM) != 0 { score += 0.3; } | |
| if (f & FLAG_WIN_COND) != 0 { score += 0.6; } | |
| score | |
| } else if let Some(l) = db.get_live(cid) { | |
| // Lives in hand are good if we can clear them, but bad if they clog | |
| // Simple heuristic: Score value | |
| l.score as f32 * 0.2 | |
| } else { | |
| 0.0 | |
| } | |
| } | |
/// Reports whether any complete 4-word instruction in `bytecode` starts with `target_op`.
///
/// Only the first word of each instruction is inspected; a trailing group of fewer than
/// four words is ignored.
fn has_opcode(bytecode: &[i32], target_op: i32) -> bool {
    bytecode
        .chunks_exact(4)
        .any(|instruction| instruction[0] == target_op)
}
| } | |
| pub struct HybridMCTS { | |
| pub session: Arc<Mutex<Session>>, | |
| pub neural_weight: f32, | |
| pub skip_rollout: bool, | |
| pub rng: SmallRng, | |
| } | |
| impl HybridMCTS { | |
| pub fn new(session: Arc<Mutex<Session>>, neural_weight: f32, skip_rollout: bool) -> Self { | |
| Self { | |
| session, | |
| neural_weight, | |
| skip_rollout, | |
| rng: SmallRng::from_os_rng(), | |
| } | |
| } | |
| pub fn get_suggestions(&mut self, state: &GameState, db: &CardDatabase, num_sims: usize, timeout_sec: f32) -> Vec<(i32, f32, u32)> { | |
| let start = std::time::Instant::now(); | |
| let (stats, profile) = self.search(state, db, num_sims, timeout_sec); | |
| let duration = start.elapsed(); | |
| if num_sims > 10 { | |
| profile.print(duration); | |
| } | |
| stats | |
| } | |
| pub fn search(&mut self, root_state: &GameState, db: &CardDatabase, num_sims: usize, timeout_sec: f32) -> (Vec<(i32, f32, u32)>, MCTSProfiler) { | |
| let session_arc = self.session.clone(); | |
| let neural_weight = self.neural_weight; | |
| let mut mcts = MCTS::new(); | |
| mcts.run_mcts_config(root_state, db, num_sims, timeout_sec, SearchHorizon::GameEnd, false, !self.skip_rollout, |state: &GameState, db: &CardDatabase| { | |
| if state.is_terminal() { | |
| return match state.get_winner() { | |
| 0 => 1.0, | |
| 1 => 0.0, | |
| _ => 0.5, | |
| }; | |
| } | |
| // Normal Heuristic Baseline | |
| let h_val = MCTS::heuristic_eval(state, db, root_state.players[0].score, root_state.players[1].score, EvalMode::Normal, None, None); | |
| // NN Evaluation | |
| let input_vec = state.encode_state(db); | |
| let mut session = session_arc.lock().unwrap(); | |
| let input_shape = [1, input_vec.len()]; | |
| // Try to create input tensor and run using (shape, vec) which is version-agnostic | |
| if let Ok(input_tensor) = ort::value::Value::from_array((input_shape, input_vec)) { | |
| if let Ok(outputs) = session.run(ort::inputs![input_tensor]) { | |
| // Try to get value output. AlphaNet has 'output_1' or index 1 | |
| let val_opt = outputs.get("output_1") | |
| .or_else(|| outputs.get("value")); | |
| if let Some(v_val) = val_opt { | |
| if let Ok((_, v_slice)) = v_val.try_extract_tensor::<f32>() { | |
| let nn_val = v_slice[0]; | |
| let nn_norm = (nn_val * 0.5 + 0.5).clamp(0.0, 1.0) as f32; | |
| return h_val * (1.0 - neural_weight) + nn_norm * neural_weight; | |
| } | |
| } else if outputs.len() > 1 { | |
| // Fallback to index if names missing | |
| if let Ok((_, v_slice)) = outputs[1].try_extract_tensor::<f32>() { | |
| let nn_val = v_slice[0]; | |
| let nn_norm = (nn_val * 0.5 + 0.5).clamp(0.0, 1.0) as f32; | |
| return h_val * (1.0 - neural_weight) + nn_norm * neural_weight; | |
| } | |
| } | |
| } | |
| } | |
| h_val | |
| }) | |
| } | |
| } | |