// Source header (upload metadata): uploaded via huggingface_hub, commit 88d4171 (verified).
#[cfg(feature = "extension-module")]
use pyo3::prelude::*;
use rand::prelude::*;
use rand::rngs::SmallRng;
use rand::SeedableRng;
use crate::core::heuristics::Heuristic;
use crate::core::logic::{
GameState, CardDatabase, LiveCard, Phase,
FLAG_DRAW, FLAG_SEARCH, FLAG_RECOVER, FLAG_BUFF, FLAG_CHARGE,
FLAG_TEMPO, FLAG_REDUCE, FLAG_BOOST, FLAG_TRANSFORM, FLAG_WIN_COND
};
use std::f32;
#[cfg(feature = "parallel")]
use rayon::prelude::*;
use std::collections::HashMap;
#[cfg(feature = "nn")]
use ort::session::Session;
#[cfg(feature = "nn")]
use std::sync::{Arc, Mutex};
use std::time::Duration;
/// Wall-clock time spent in each phase of an MCTS run.
///
/// Accumulated per-iteration inside `run_mcts_config` and merged across
/// worker threads by the `search_parallel*` entry points.
#[derive(Debug, Default, Clone, Copy)]
pub struct MCTSProfiler {
    pub determinization: Duration,
    pub selection: Duration,
    pub expansion: Duration,
    pub simulation: Duration,
    pub backpropagation: Duration,
}

impl MCTSProfiler {
    /// Adds the timings collected by `other` (e.g. another worker thread)
    /// into `self`.
    pub fn merge(&mut self, other: &Self) {
        self.determinization += other.determinization;
        self.selection += other.selection;
        self.expansion += other.expansion;
        self.simulation += other.simulation;
        self.backpropagation += other.backpropagation;
    }

    /// Prints a per-phase breakdown as absolute seconds and as a percentage
    /// of `total` (the overall search duration). No-op when `total` is zero.
    pub fn print(&self, total: Duration) {
        // `Duration::is_zero` replaces the exact float comparison
        // (`total_secs == 0.0`) the guard previously relied on; behavior is
        // identical (as_secs_f64() == 0.0 iff the duration is zero).
        if total.is_zero() {
            return;
        }
        let total_secs = total.as_secs_f64();
        println!("[MCTS Profile] Breakdown:");
        let items = [
            ("Determinization", self.determinization),
            ("Selection", self.selection),
            ("Expansion", self.expansion),
            ("Simulation", self.simulation),
            ("Backpropagation", self.backpropagation),
        ];
        for (name, dur) in items {
            let secs = dur.as_secs_f64();
            println!(" - {:<16}: {:>8.3}s ({:>6.1}%)", name, secs, (secs / total_secs) * 100.0);
        }
    }
}
/// A single node in the arena-allocated search tree.
struct Node {
// Number of simulations that have passed through this node.
visit_count: u32,
// Sum of backpropagated rewards, from the perspective of `player_just_moved`.
value_sum: f32,
// Player (0 or 1) who made the move leading into this node.
player_just_moved: u8,
// Legal actions from this node not yet expanded into children.
untried_actions: Vec<i32>,
children: Vec<(i32, usize)>, // (Action, NodeIndex in Arena)
// Arena index of the parent; `None` for the root.
parent: Option<usize>,
// Action taken from the parent to reach this node (0 for the root).
parent_action: i32,
}
/// Single-threaded MCTS worker holding reusable buffers to minimize
/// per-simulation allocations. The `search_parallel*` entry points spawn one
/// worker per thread.
pub struct MCTS {
// Arena of tree nodes; index 0 is the root of the current search.
nodes: Vec<Node>,
// Fast non-cryptographic RNG used for determinization and rollouts.
rng: SmallRng,
// Scratch space for shuffling the opponent's unseen cards (hand + deck).
unseen_buffer: Vec<u16>,
// Scratch space for legal-action generation during rollouts.
legal_buffer: Vec<i32>,
// Working copy of the root state, refreshed at the start of each simulation.
reusable_state: GameState,
}
/// How far random rollouts are allowed to run.
#[cfg_attr(feature = "extension-module", pyclass(eq, eq_int))]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum SearchHorizon {
/// Simulate until the game is over (subject to the rollout depth cap).
GameEnd,
/// Stop simulating once the turn counter advances past the root turn.
TurnEnd,
}
/// Evaluation flavor used by `search_mode`.
#[cfg_attr(feature = "extension-module", pyclass(eq, eq_int))]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EvalMode {
/// Two-player evaluation: score difference mapped into [0, 1].
Normal,
/// Ignore the opponent; only player 0's score contributes.
Solitaire,
/// Like `Normal`, but our own deck is also reshuffled each determinization
/// (no perfect knowledge of our own draw order).
Blind,
}
impl MCTS {
pub fn new() -> Self {
Self {
nodes: Vec::with_capacity(1000),
rng: SmallRng::from_os_rng(),
unseen_buffer: Vec::with_capacity(60),
legal_buffer: Vec::with_capacity(32),
reusable_state: GameState::default(),
}
}
/// Runs independent (root-parallel) searches on every rayon worker thread
/// and merges their per-action visit/value statistics.
///
/// Each thread re-determinizes hidden information independently, so merging
/// approximates sampling over possible hidden states. Returns
/// `(action, mean_value, visits)` sorted by descending visit count.
pub fn search_parallel(&self, root_state: &GameState, db: &CardDatabase, num_sims: usize, timeout_sec: f32, horizon: SearchHorizon, heuristic: &dyn Heuristic, shuffle_self: bool) -> Vec<(i32, f32, u32)> {
    let start_overall = std::time::Instant::now();
    #[cfg(feature = "parallel")]
    let num_threads = rayon::current_num_threads().max(1);
    #[cfg(not(feature = "parallel"))]
    let num_threads = 1;
    // Ceiling division so the threads together run at least `num_sims`.
    let sims_per_thread = if num_sims > 0 { num_sims.div_ceil(num_threads) } else { 0 };
    // Collect per-thread results.
    #[cfg(feature = "parallel")]
    let results: Vec<(Vec<(i32, f32, u32)>, MCTSProfiler)> = (0..num_threads).into_par_iter().map(|_| {
        let mut mcts = MCTS::new();
        mcts.search_custom(root_state, db, sims_per_thread, timeout_sec, horizon, heuristic, shuffle_self, true)
    }).collect();
    #[cfg(not(feature = "parallel"))]
    let results: Vec<(Vec<(i32, f32, u32)>, MCTSProfiler)> = vec![{
        let mut mcts = MCTS::new();
        // With one thread `sims_per_thread == num_sims`; using it here keeps
        // both branches symmetric and avoids an unused-variable warning.
        mcts.search_custom(root_state, db, sims_per_thread, timeout_sec, horizon, heuristic, shuffle_self, true)
    }];
    // Merge: accumulate total value (mean * visits) and visit counts per action.
    let mut agg_map: HashMap<i32, (f32, u32)> = HashMap::new();
    let mut total_visits = 0;
    let mut agg_profile = MCTSProfiler::default();
    for (res, profile) in results {
        agg_profile.merge(&profile);
        for (action, score, visits) in res {
            let entry = agg_map.entry(action).or_insert((0.0, 0));
            entry.0 += score * visits as f32;
            entry.1 += visits;
            total_visits += visits;
        }
    }
    // Throughput logging (skipped for trivial searches); the rate is only
    // computed when it is actually printed.
    let duration = start_overall.elapsed();
    if total_visits > 100 {
        let sims_per_sec = total_visits as f64 / duration.as_secs_f64();
        println!("[MCTS] Completed {} sims in {:.3}s ({:.0} sims/s)", total_visits, duration.as_secs_f64(), sims_per_sec);
        agg_profile.print(duration);
    }
    // Convert summed values back to means; visits can be 0 for the
    // degenerate "no legal action" result produced by `run_mcts_config`.
    let mut stats: Vec<(i32, f32, u32)> = agg_map.into_iter().map(|(action, (sum_val, visits))| {
        if visits > 0 {
            (action, sum_val / visits as f32, visits)
        } else {
            (action, 0.0, 0)
        }
    }).collect();
    stats.sort_by_key(|&(_, _, v)| std::cmp::Reverse(v));
    stats
}
/// Root-parallel variant of `search_mode`: runs one built-in-heuristic
/// search per rayon worker thread and merges the per-action statistics.
/// Returns `(action, mean_value, visits)` sorted by descending visit count.
pub fn search_parallel_mode(&self, root_state: &GameState, db: &CardDatabase, num_sims: usize, timeout_sec: f32, horizon: SearchHorizon, eval_mode: EvalMode) -> Vec<(i32, f32, u32)> {
    let start_overall = std::time::Instant::now();
    #[cfg(feature = "parallel")]
    let num_threads = rayon::current_num_threads().max(1);
    #[cfg(not(feature = "parallel"))]
    let num_threads = 1;
    // Ceiling division so the threads together run at least `num_sims`.
    let sims_per_thread = if num_sims > 0 { num_sims.div_ceil(num_threads) } else { 0 };
    // Collect per-thread results.
    #[cfg(feature = "parallel")]
    let results: Vec<(Vec<(i32, f32, u32)>, MCTSProfiler)> = (0..num_threads).into_par_iter().map(|_| {
        let mut mcts = MCTS::new();
        mcts.search_mode(root_state, db, sims_per_thread, timeout_sec, horizon, eval_mode)
    }).collect();
    #[cfg(not(feature = "parallel"))]
    let results: Vec<(Vec<(i32, f32, u32)>, MCTSProfiler)> = vec![{
        let mut mcts = MCTS::new();
        // With one thread `sims_per_thread == num_sims`; using it here keeps
        // both branches symmetric and avoids an unused-variable warning.
        mcts.search_mode(root_state, db, sims_per_thread, timeout_sec, horizon, eval_mode)
    }];
    // Merge: accumulate total value (mean * visits) and visit counts per action.
    let mut agg_map: HashMap<i32, (f32, u32)> = HashMap::new();
    let mut total_visits = 0;
    let mut agg_profile = MCTSProfiler::default();
    for (res, profile) in results {
        agg_profile.merge(&profile);
        for (action, score, visits) in res {
            let entry = agg_map.entry(action).or_insert((0.0, 0));
            entry.0 += score * visits as f32;
            entry.1 += visits;
            total_visits += visits;
        }
    }
    // Throughput logging (skipped for trivial searches); the rate is only
    // computed when it is actually printed.
    let duration = start_overall.elapsed();
    if total_visits > 100 {
        let sims_per_sec = total_visits as f64 / duration.as_secs_f64();
        println!("[MCTS Mode] Completed {} sims in {:.3}s ({:.0} sims/s)", total_visits, duration.as_secs_f64(), sims_per_sec);
        agg_profile.print(duration);
    }
    // Convert summed values back to means; visits can be 0 for the
    // degenerate "no legal action" result produced by `run_mcts_config`.
    let mut stats: Vec<(i32, f32, u32)> = agg_map.into_iter().map(|(action, (sum_val, visits))| {
        if visits > 0 {
            (action, sum_val / visits as f32, visits)
        } else {
            (action, 0.0, 0)
        }
    }).collect();
    stats.sort_by_key(|&(_, _, v)| std::cmp::Reverse(v));
    stats
}
/// Convenience wrapper around `search_custom` with self-deck shuffling
/// disabled and random rollouts enabled.
pub fn search(&mut self, root_state: &GameState, db: &CardDatabase, num_sims: usize, timeout_sec: f32, horizon: SearchHorizon, heuristic: &dyn Heuristic) -> (Vec<(i32, f32, u32)>, MCTSProfiler) {
self.search_custom(root_state, db, num_sims, timeout_sec, horizon, heuristic, false, true)
}
/// Runs MCTS with a caller-supplied heuristic for leaf evaluation.
///
/// Terminal states score 1.0 / 0.0 / 0.5 (player-0 win / loss / other);
/// non-terminal leaves are scored by `heuristic.evaluate` against the root
/// players' scores as baselines.
pub fn search_custom(&mut self, root_state: &GameState, db: &CardDatabase, num_sims: usize, timeout_sec: f32, horizon: SearchHorizon, heuristic: &dyn Heuristic, shuffle_self: bool, enable_rollout: bool) -> (Vec<(i32, f32, u32)>, MCTSProfiler) {
    // The root state never changes during the search, so the baselines can
    // be read once up front.
    let p0_baseline = root_state.players[0].score;
    let p1_baseline = root_state.players[1].score;
    self.run_mcts_config(root_state, db, num_sims, timeout_sec, horizon, shuffle_self, enable_rollout, |state, _db| {
        if !state.is_terminal() {
            return heuristic.evaluate(state, db, p0_baseline, p1_baseline);
        }
        match state.get_winner() {
            0 => 1.0,
            1 => 0.0,
            _ => 0.5,
        }
    })
}
/// Runs MCTS using the built-in heuristic (`heuristic_eval`) in the given
/// `EvalMode`. `Blind` additionally shuffles our own deck each
/// determinization so the search cannot exploit knowledge of the draw order.
pub fn search_mode(&mut self, root_state: &GameState, db: &CardDatabase, num_sims: usize, timeout_sec: f32, horizon: SearchHorizon, eval_mode: EvalMode) -> (Vec<(i32, f32, u32)>, MCTSProfiler) {
    // Pre-compute expected per-card deck stats once so leaf evaluation is O(1).
    // P0 (us): only the remaining deck is unseen.
    let p0_stats = Self::calculate_deck_expectations(&root_state.players[0].deck, db);
    // P1 (opponent): their hand is hidden too, so treat hand + deck as unseen.
    let mut p1_unseen = root_state.players[1].hand.clone();
    p1_unseen.extend(root_state.players[1].deck.iter().cloned());
    let p1_stats = Self::calculate_deck_expectations(&p1_unseen, db);
    self.run_mcts_config(root_state, db, num_sims, timeout_sec, horizon, eval_mode == EvalMode::Blind, true, |state, _db| {
        if state.is_terminal() {
            // Terminal reward does not depend on `eval_mode`: the previous
            // Solitaire branch duplicated this match verbatim, so the dead
            // branch was collapsed.
            match state.get_winner() {
                0 => 1.0,
                1 => 0.0,
                _ => 0.5,
            }
        } else {
            Self::heuristic_eval(state, db, root_state.players[0].score, root_state.players[1].score, eval_mode, Some(p0_stats), Some(p1_stats))
        }
    })
}
/// Core MCTS loop shared by every public search entry point.
///
/// * `num_sims` — simulation budget; `0` means "no budget" (timeout only).
/// * `timeout_sec` — wall-clock limit; `<= 0.0` disables it. If both the
///   budget and the timeout are disabled, the loop exits without simulating.
/// * `horizon` — `TurnEnd` cuts rollouts off once the turn counter advances.
/// * `shuffle_self` — also re-shuffle our own deck each determinization.
/// * `enable_rollout` — if false, leaves are evaluated directly by `eval_fn`
///   without a random playout.
/// * `eval_fn` — maps a (possibly terminal) state to a reward from player
///   0's perspective (1.0 best for P0, 0.0 best for P1).
///
/// Returns per-root-action `(action, mean_value, visits)` sorted by
/// descending visits, plus per-phase timings.
fn run_mcts_config<F>(&mut self, root_state: &GameState, db: &CardDatabase, num_sims: usize, timeout_sec: f32, horizon: SearchHorizon, shuffle_self: bool, enable_rollout: bool, mut eval_fn: F) -> (Vec<(i32, f32, u32)>, MCTSProfiler)
where F: FnMut(&GameState, &CardDatabase) -> f32
{
self.nodes.clear();
let start_time = std::time::Instant::now();
let timeout = if timeout_sec > 0.0 { Some(std::time::Duration::from_secs_f32(timeout_sec)) } else { None };
let start_turn = root_state.turn;
let mut profiler = MCTSProfiler::default();
// Root node. Degenerate cases (nothing to choose) return immediately.
let mut legal_indices: Vec<i32> = Vec::with_capacity(32);
root_state.generate_legal_actions(db, &mut legal_indices);
if legal_indices.is_empty() { return (vec![(0, 0.5, 0)], profiler); }
if legal_indices.len() == 1 { return (vec![(legal_indices[0], 0.5, 1)], profiler); }
self.nodes.push(Node {
visit_count: 0,
value_sum: 0.0,
player_just_moved: 1 - root_state.current_player,
untried_actions: legal_indices,
children: Vec::new(),
parent: None,
parent_action: 0,
});
let mut sims_done = 0;
loop {
// Budget / timeout checks.
if num_sims > 0 && sims_done >= num_sims { break; }
if let Some(to) = timeout {
if start_time.elapsed() >= to { break; }
}
if num_sims == 0 && timeout.is_none() { break; } // Safety: no budget and no timeout -> run nothing.
sims_done += 1;
// 1. Setup & determinization: copy the root state and randomize hidden
// information — the opponent's hand and deck are shuffled together so
// every simulation samples a different consistent world.
let t_setup = std::time::Instant::now();
let mut node_idx = 0;
self.reusable_state.copy_from(&root_state);
let state = &mut self.reusable_state;
state.silent = true;
let me = root_state.current_player as usize;
let opp = 1 - me;
let opp_hand_len = state.players[opp].hand.len();
self.unseen_buffer.clear();
self.unseen_buffer.extend_from_slice(&state.players[opp].hand);
self.unseen_buffer.extend_from_slice(&state.players[opp].deck);
self.unseen_buffer.shuffle(&mut self.rng);
state.players[opp].hand.copy_from_slice(&self.unseen_buffer[0..opp_hand_len]);
state.players[opp].deck.copy_from_slice(&self.unseen_buffer[opp_hand_len..]);
if shuffle_self {
// Optionally forget our own draw order too (e.g. Blind mode).
let mut my_deck = state.players[me].deck.clone();
my_deck.shuffle(&mut self.rng);
state.players[me].deck = my_deck;
}
profiler.determinization += t_setup.elapsed();
// 2. Selection: descend through fully-expanded nodes via UCB1,
// replaying each selected action onto the determinized state.
let t_selection = std::time::Instant::now();
while self.nodes[node_idx].untried_actions.is_empty() && !self.nodes[node_idx].children.is_empty() {
node_idx = Self::select_child(&self.nodes, node_idx);
let action = self.nodes[node_idx].parent_action;
let _ = state.step(db, action);
}
profiler.selection += t_selection.elapsed();
// 3. Expansion: play one randomly chosen untried action and add a
// child node for it.
let t_expansion = std::time::Instant::now();
if !self.nodes[node_idx].untried_actions.is_empty() {
let idx = self.rng.random_range(0..self.nodes[node_idx].untried_actions.len());
let action = self.nodes[node_idx].untried_actions.swap_remove(idx);
let actor = state.current_player;
let _ = state.step(db, action);
let mut new_legal_indices: Vec<i32> = Vec::with_capacity(32);
state.generate_legal_actions(db, &mut new_legal_indices);
let new_node = Node {
visit_count: 0,
value_sum: 0.0,
player_just_moved: actor,
untried_actions: new_legal_indices,
children: Vec::new(),
parent: Some(node_idx),
parent_action: action,
};
let new_idx = self.nodes.len();
self.nodes.push(new_node);
self.nodes[node_idx].children.push((action, new_idx));
node_idx = new_idx;
}
profiler.expansion += t_expansion.elapsed();
// 4. Simulation: uniformly-random playout, capped at 200 plies and
// optionally cut off when the turn advances (TurnEnd horizon).
let t_simulation = std::time::Instant::now();
let mut depth = 0;
if enable_rollout {
while !state.is_terminal() && depth < 200 {
// Horizon Check
if horizon == SearchHorizon::TurnEnd && state.turn > start_turn {
break;
}
state.generate_legal_actions(db, &mut self.legal_buffer);
if self.legal_buffer.is_empty() { break; }
let chunk_action = *self.legal_buffer.choose(&mut self.rng).unwrap();
let _ = state.step(db, chunk_action);
depth += 1;
}
}
profiler.simulation += t_simulation.elapsed();
// 5. Backpropagation: score the reached state once (P0 perspective),
// then credit each ancestor from the point of view of the player who
// moved into it (1 - reward for player 1's moves).
let t_backprop = std::time::Instant::now();
let reward_p0 = eval_fn(&state, db);
let mut curr = Some(node_idx);
while let Some(idx) = curr {
let node_p_moved = self.nodes[idx].player_just_moved;
let node = &mut self.nodes[idx];
node.visit_count += 1;
if node_p_moved == 0 {
node.value_sum += reward_p0;
} else {
node.value_sum += 1.0 - reward_p0;
}
curr = node.parent;
}
profiler.backpropagation += t_backprop.elapsed();
}
// Report the root's children as (action, mean value, visits), most-visited first.
let mut stats: Vec<(i32, f32, u32)> = self.nodes[0].children.iter()
.map(|&(act, idx)| {
let child = &self.nodes[idx];
(act, child.value_sum / child.visit_count as f32, child.visit_count)
})
.collect();
stats.sort_by_key(|&(_, _, v)| std::cmp::Reverse(v));
(stats, profiler)
}
/// Picks the child of `node_idx` maximizing UCB1 (exploration constant 1.4).
///
/// Ties keep the earliest child (strict `>` comparison). Every child has
/// been visited at least once — expansion backpropagates through the new
/// node — so the divisions below are well-defined.
fn select_child(nodes: &[Node], node_idx: usize) -> usize {
    let parent = &nodes[node_idx];
    let log_n = (parent.visit_count as f32).ln();
    let mut best = (f32::NEG_INFINITY, 0usize);
    for &(_, child_idx) in parent.children.iter() {
        let child = &nodes[child_idx];
        // UCB1: exploitation (mean value) + exploration bonus.
        let mean = child.value_sum / child.visit_count as f32;
        let bonus = 1.4 * (log_n / child.visit_count as f32).sqrt();
        let ucb = mean + bonus;
        if ucb > best.0 {
            best = (ucb, child_idx);
        }
    }
    best.1
}
/// Maps the two players' heuristic scores into a single value in [0, 1]
/// from player 0's perspective.
///
/// Solitaire mode only normalizes player 0's raw score; otherwise the score
/// difference is centered at 0.5 and nudged by a small volume-lead
/// tie-breaker before clamping.
fn heuristic_eval(state: &GameState, db: &CardDatabase, p0_baseline: u32, p1_baseline: u32, eval_mode: EvalMode, p0_deck_stats: Option<([f32; 7], f32)>, p1_deck_stats: Option<([f32; 7], f32)>) -> f32 {
    let score0 = Self::evaluate_player(state, db, 0, p0_baseline, p0_deck_stats);
    if eval_mode == EvalMode::Solitaire {
        // Solitaire: only P0's score matters. 25.0 is a rough upper bound
        // (3 lives * 5 + 2 bonus + 5 board + 2 hand) used for normalization.
        return (score0 / 25.0).clamp(0.0, 1.0);
    }
    let score1 = Self::evaluate_player(state, db, 1, p1_baseline, p1_deck_stats);
    let centered = (score0 - score1) * 0.5 + 0.5;
    // "Win the Live" tie-breaker: with otherwise equal positions, a volume
    // lead this turn is slightly favorable for resolving ties.
    let p0_vol = state.players[0].current_turn_volume;
    let p1_vol = state.players[1].current_turn_volume;
    let nudge = if p0_vol > p1_vol {
        0.05
    } else if p1_vol > p0_vol {
        -0.05
    } else {
        0.0
    };
    (centered + nudge).clamp(0.0, 1.0)
}
/// Scores a single player's position; larger is better. The scale is an
/// arbitrary weighted sum of heuristic terms, consumed by `heuristic_eval`.
///
/// `baseline_score` is the player's score at the search root, used to reward
/// clearing lives during the search. `deck_stats` optionally supplies
/// precomputed deck expectations (see `calculate_deck_expectations`).
fn evaluate_player(state: &GameState, db: &CardDatabase, p_idx: usize, baseline_score: u32, deck_stats: Option<([f32; 7], f32)>) -> f32 {
let p = &state.players[p_idx];
let mut score = 0.0;
// 1. Success lives (the goal): heavily weighted so winning dominates all
// the small positional terms below.
score += p.success_lives.len() as f32 * 5.0; // Increased weight to prioritize winning
// Bonus when lives improved relative to the search-root baseline.
if p.success_lives.len() > baseline_score as usize {
score += 2.0;
}
// 2. Board state: total effective hearts per color and blades across the
// three stage slots.
let mut stage_hearts = [0u32; 7];
let mut stage_blades = 0;
for i in 0..3 {
let h = state.get_effective_hearts(p_idx, i, db);
for color in 0..7 { stage_hearts[color] += h[color] as u32; }
stage_blades += state.get_effective_blades(p_idx, i, db);
}
// Small bonus for board presence/power.
score += stage_blades as f32 * 0.1;
// 3. Deck expectations (yells): expected hearts/volume contributed per
// blade, averaged over the unseen deck (precomputed when available).
let (avg_hearts, avg_vol) = if let Some(stats) = deck_stats {
stats
} else {
Self::calculate_deck_expectations(&p.deck, db)
};
let expected_yell_hearts: Vec<f32> = avg_hearts.iter().map(|&h| h * stage_blades as f32).collect();
let expected_volume = avg_vol * stage_blades as f32;
// 4. Live clearing probability: reward likely clears by their score value.
let mut max_prob = 0.0;
for &cid in &p.live_zone {
if cid >= 0 {
if let Some(l) = db.get_live(cid as u16) {
let prob = Self::calculate_live_success_prob(
l,
&stage_hearts,
&expected_yell_hearts,
&p.heart_req_reductions
);
// Reward probability of clearing, weighted by the live's score.
score += prob * (l.score as f32 * 2.0); // Higher weight for potential clears
if prob > max_prob { max_prob = prob; }
}
}
}
// Reward volume only when at least one live is likely (>50%) to clear.
if max_prob > 0.5 {
score += (expected_volume + p.current_turn_volume as f32) * max_prob * 0.5;
}
// 5. Hand quality (construction potential).
let hand_val = Self::calculate_hand_quality(state, db, p_idx);
score += hand_val * 0.15;
// 6. Resources: card advantage and untapped energy.
score += p.hand.len() as f32 * 0.05;
// Energy efficiency (small reward for energy still available).
let unused_energy = p.tapped_energy.iter().filter(|&&t| !t).count();
score += unused_energy as f32 * 0.01;
// Discard recovery potential: if the hand holds a recovery effect
// (bytecode opcodes 15/17 — presumably recover-from-discard; TODO confirm
// against the bytecode spec), value good cards sitting in the discard.
let has_recovery = p.hand.iter().any(|&cid| {
if let Some(m) = db.get_member(cid) {
m.abilities.iter().any(|a| Self::has_opcode(&a.bytecode, 15) || Self::has_opcode(&a.bytecode, 17))
} else { false }
});
if has_recovery {
// "Good stuff" = any live card, or a member costing 3+.
let discard_val = p.discard.iter().filter(|&&cid| {
db.get_live(cid).is_some() || db.get_member(cid).map_or(false, |m| m.cost >= 3)
}).count();
score += discard_val as f32 * 0.1;
}
score
}
/// Returns the average blade hearts (per color) and average volume icons of
/// the cards in `deck` — the expected per-card contribution of a random
/// reveal from the remaining unseen cards.
fn calculate_deck_expectations(deck: &[u16], db: &CardDatabase) -> ([f32; 7], f32) {
    if deck.is_empty() {
        return ([0.0; 7], 0.0);
    }
    // The deck is determinized per simulation, but how many cards get
    // revealed varies (blades change), so averaging over the whole
    // remaining deck is the more robust expectation.
    let count = deck.len() as f32;
    let mut heart_totals = [0.0f32; 7];
    let mut volume_total = 0.0f32;
    for &cid in deck {
        if let Some(m) = db.get_member(cid) {
            for (acc, &h) in heart_totals.iter_mut().zip(m.blade_hearts.iter()) {
                *acc += h as f32;
            }
            volume_total += m.volume_icons as f32;
        } else if let Some(l) = db.get_live(cid) {
            for (acc, &h) in heart_totals.iter_mut().zip(l.blade_hearts.iter()) {
                *acc += h as f32;
            }
            volume_total += l.volume_icons as f32;
        }
    }
    (heart_totals.map(|t| t / count), volume_total / count)
}
/// Estimates how likely `live` is to be completed given current stage hearts
/// plus expected hearts from yells, after applying requirement `reductions`.
/// Values above 1.0 (1.2) signal a guaranteed or requirement-free clear.
///
/// Index 6 is the wildcard/ANY slot on both sides: wildcard hearts first
/// fill specific-color deficits, then the ANY requirement; leftover
/// per-color surpluses cover whatever ANY requirement remains.
fn calculate_live_success_prob(live: &LiveCard, stage_hearts: &[u32; 7], expected_yell_hearts: &[f32], reductions: &[i32; 7]) -> f32 {
let mut prob;
// 1. Requirement after reductions (floored at zero per color).
let mut needed = live.required_hearts;
for i in 0..7 {
needed[i] = (needed[i] as i32 - reductions[i]).max(0) as u8;
}
let mut satisfied = 0.0;
let mut total_req = 0.0;
let mut wildcards_avail = stage_hearts[6] as f32 + expected_yell_hearts[6];
// 2. Specific colors: count what we cover directly, spending wildcards on
// any deficits as we go.
for i in 0..6 {
let req = needed[i] as f32;
total_req += req;
let have = stage_hearts[i] as f32 + expected_yell_hearts[i];
if have >= req {
satisfied += req;
} else {
satisfied += have;
let deficit = req - have;
let used_wild = wildcards_avail.min(deficit);
satisfied += used_wild;
wildcards_avail -= used_wild;
}
}
// 3. ANY (star) requirement: leftover wildcards first, then per-color
// surpluses.
let any_req = needed[6] as f32;
total_req += any_req;
let used_wild = wildcards_avail.min(any_req);
satisfied += used_wild;
let mut remaining_any = any_req - used_wild;
if remaining_any > 0.0 {
// Spend per-color surpluses on the remaining ANY requirement.
for i in 0..6 {
let req = needed[i] as f32;
let have = stage_hearts[i] as f32 + expected_yell_hearts[i];
let surplus = (have - req).max(0.0);
let used = surplus.min(remaining_any);
satisfied += used;
remaining_any -= used;
if remaining_any <= 0.0 { break; }
}
}
// 4. Coverage ratio -> pseudo-probability.
if total_req > 0.0 {
prob = (satisfied / total_req).clamp(0.0, 1.0);
// sqrt inflates ratios in (0, 1) — e.g. 0.5 -> ~0.71 — so partial
// coverage is scored optimistically while full coverage stays 1.0.
prob = prob.powf(0.5);
if prob >= 1.0 { prob = 1.2; } // Guaranteed-clear bonus above the nominal 1.0 cap.
} else {
prob = 1.2; // No requirement at all: free live, same bonus.
}
prob
}
/// Sums per-card potential over the player's hand. During mulligan phases,
/// cards marked for replacement count as an "average" card (potential 1.0)
/// and playability checks assume ~3 energy will be available soon.
fn calculate_hand_quality(state: &GameState, db: &CardDatabase, p_idx: usize) -> f32 {
    let p = &state.players[p_idx];
    let in_mulligan = matches!(state.phase, Phase::MulliganP1 | Phase::MulliganP2);
    // No energy exists yet during mulligan; assume ~3 for cost evaluations.
    let max_energy = if in_mulligan { 3 } else { p.energy_zone.len() as u32 };
    let mut total = 0.0;
    for (slot, &cid) in p.hand.iter().enumerate() {
        let potential = Self::calculate_card_potential(cid, db, max_energy);
        // A card selected for mulligan is effectively replaced by an average
        // card, whose expected potential is roughly 1.0.
        let selected = in_mulligan && ((p.mulligan_selection >> slot) & 1u64 == 1);
        total += if selected { 1.0 } else { potential };
    }
    total
}
/// Scores how useful a single card in hand is: stat efficiency per cost,
/// a penalty when it is unplayable with `max_energy`, plus flat bonuses for
/// each ability flag. Live cards are valued at a fraction of their score;
/// unknown card ids score 0.
fn calculate_card_potential(cid: u16, db: &CardDatabase, max_energy: u32) -> f32 {
    if let Some(m) = db.get_member(cid) {
        // Stat efficiency: (blades + total hearts) per energy spent, with +1
        // so 0-cost cards do not divide by zero.
        let stat_sum: u32 = m.hearts.iter().map(|&x| x as u32).sum();
        let mut score = (m.blades as f32 + stat_sum as f32) / (m.cost as f32 + 1.0);
        // Penalize cards we cannot currently afford, scaling with the gap.
        if m.cost > max_energy {
            score -= (m.cost - max_energy) as f32 * 0.5;
        }
        // Flat bonuses per ability flag (O(1) bit tests), applied in a fixed
        // order so the float accumulation matches the original exactly.
        let flags = m.ability_flags;
        let bonuses = [
            (FLAG_DRAW, 0.5),
            (FLAG_SEARCH, 0.6),
            (FLAG_RECOVER, 0.4),
            (FLAG_BUFF, 0.3),
            (FLAG_CHARGE, 0.8),
            (FLAG_TEMPO, 0.2),
            (FLAG_REDUCE, 0.4),
            (FLAG_BOOST, 0.5),
            (FLAG_TRANSFORM, 0.3),
            (FLAG_WIN_COND, 0.6),
        ];
        for (flag, bonus) in bonuses {
            if (flags & flag) != 0 {
                score += bonus;
            }
        }
        score
    } else if let Some(l) = db.get_live(cid) {
        // Lives in hand: potential score, discounted because they can clog.
        l.score as f32 * 0.2
    } else {
        0.0
    }
}
/// Returns true if any instruction in `bytecode` has opcode `target_op`.
///
/// Bytecode is a flat stream of fixed-width 4-i32 instructions
/// `[opcode, arg, arg, arg]`; opcodes sit at offsets 0, 4, 8, ….
/// A trailing partial instruction (fewer than 4 ints) is ignored, matching
/// the original manual bounds check (`i + 3 >= len` break).
fn has_opcode(bytecode: &[i32], target_op: i32) -> bool {
    bytecode.chunks_exact(4).any(|instr| instr[0] == target_op)
}
}
/// MCTS variant that blends the hand-written heuristic with a neural-network
/// value head (ONNX Runtime session). Only built with the `nn` feature.
#[cfg(feature = "nn")]
pub struct HybridMCTS {
// Shared ONNX session; locked for each NN leaf evaluation.
pub session: Arc<Mutex<Session>>,
// Blend factor: 0 = heuristic only, 1 = NN only (linear interpolation).
pub neural_weight: f32,
// When true, leaf states are evaluated directly without random rollouts.
pub skip_rollout: bool,
pub rng: SmallRng,
}
#[cfg(feature = "nn")]
impl HybridMCTS {
/// Creates a hybrid searcher. `neural_weight` blends the NN value into the
/// heuristic (0 = heuristic only, 1 = NN only); `skip_rollout` evaluates
/// leaves directly instead of running random playouts.
pub fn new(session: Arc<Mutex<Session>>, neural_weight: f32, skip_rollout: bool) -> Self {
Self {
session,
neural_weight,
skip_rollout,
rng: SmallRng::from_os_rng(),
}
}
/// Convenience wrapper: runs `search` and prints the profile breakdown for
/// non-trivial searches.
pub fn get_suggestions(&mut self, state: &GameState, db: &CardDatabase, num_sims: usize, timeout_sec: f32) -> Vec<(i32, f32, u32)> {
let start = std::time::Instant::now();
let (stats, profile) = self.search(state, db, num_sims, timeout_sec);
let duration = start.elapsed();
if num_sims > 10 {
profile.print(duration);
}
stats
}
/// MCTS whose leaf evaluation blends the hand-written heuristic with the
/// ONNX value head: `h * (1 - w) + nn * w`.
///
/// Any failure along the ONNX path (tensor creation, inference, missing
/// outputs) silently falls back to the pure heuristic value. Panics only if
/// the session mutex is poisoned.
pub fn search(&mut self, root_state: &GameState, db: &CardDatabase, num_sims: usize, timeout_sec: f32) -> (Vec<(i32, f32, u32)>, MCTSProfiler) {
let session_arc = self.session.clone();
let neural_weight = self.neural_weight;
// Delegates to a fresh single-threaded worker; the closure is the leaf
// evaluation function.
let mut mcts = MCTS::new();
mcts.run_mcts_config(root_state, db, num_sims, timeout_sec, SearchHorizon::GameEnd, false, !self.skip_rollout, |state: &GameState, db: &CardDatabase| {
if state.is_terminal() {
return match state.get_winner() {
0 => 1.0,
1 => 0.0,
_ => 0.5,
};
}
// Heuristic baseline — always computed; also the fallback value.
let h_val = MCTS::heuristic_eval(state, db, root_state.players[0].score, root_state.players[1].score, EvalMode::Normal, None, None);
// NN evaluation on the encoded state (batch dimension of 1).
let input_vec = state.encode_state(db);
let mut session = session_arc.lock().unwrap();
let input_shape = [1, input_vec.len()];
// (shape, vec) tensor construction works across `ort` versions.
if let Ok(input_tensor) = ort::value::Value::from_array((input_shape, input_vec)) {
if let Ok(outputs) = session.run(ort::inputs![input_tensor]) {
// The value head is named "output_1" or "value" depending on the export.
let val_opt = outputs.get("output_1")
.or_else(|| outputs.get("value"));
if let Some(v_val) = val_opt {
if let Ok((_, v_slice)) = v_val.try_extract_tensor::<f32>() {
let nn_val = v_slice[0];
// Map [-1, 1] -> [0, 1] — assumes a tanh-style value head; TODO confirm.
let nn_norm = (nn_val * 0.5 + 0.5).clamp(0.0, 1.0) as f32;
return h_val * (1.0 - neural_weight) + nn_norm * neural_weight;
}
} else if outputs.len() > 1 {
// Fall back to positional output index 1 when names are missing.
if let Ok((_, v_slice)) = outputs[1].try_extract_tensor::<f32>() {
let nn_val = v_slice[0];
let nn_norm = (nn_val * 0.5 + 0.5).clamp(0.0, 1.0) as f32;
return h_val * (1.0 - neural_weight) + nn_norm * neural_weight;
}
}
}
}
h_val
})
}
}