use pyo3::prelude::*; use rayon::prelude::*; use numpy::{PyArray1, PyArray2, PyReadonlyArray1, PyArrayMethods}; use crate::core::logic::{GameState, PlayerState, Phase}; use crate::core::mcts::{SearchHorizon, EvalMode}; use smallvec::SmallVec; // use crate::core::heuristics::{OriginalHeuristic, SimpleHeuristic}; #[pyclass] #[derive(Clone)] pub struct PyPlayerState { pub inner: PlayerState, } #[pymethods] impl PyPlayerState { #[getter] fn player_id(&self) -> u8 { self.inner.player_id } #[getter] fn score(&self) -> u32 { self.inner.score } #[getter] fn success_lives(&self) -> Vec { self.inner.success_lives.iter().map(|&x| x as u32).collect() } #[getter] fn hand(&self) -> Vec { self.inner.hand.iter().map(|&x| x as u32).collect() } #[getter] fn stage(&self) -> [i32; 3] { self.inner.stage.map(|x| x as i32) } #[getter] fn discard(&self) -> Vec { self.inner.discard.iter().map(|&x| x as u32).collect() } #[getter] fn exile(&self) -> Vec { self.inner.exile.iter().map(|&x| x as u32).collect() } #[getter] fn deck(&self) -> Vec { self.inner.deck.iter().map(|&x| x as u32).collect() } #[getter] fn energy_zone(&self) -> Vec { self.inner.energy_zone.iter().map(|&x| x as u32).collect() } #[getter] fn live_zone(&self) -> [i32; 3] { self.inner.live_zone.map(|x| x as i32) } #[getter] fn live_zone_revealed(&self) -> [bool; 3] { self.inner.live_zone_revealed } #[getter] fn tapped_energy(&self) -> Vec { self.inner.tapped_energy.to_vec() } #[getter] fn tapped_members(&self) -> [bool; 3] { self.inner.tapped_members } #[setter] fn set_hand(&mut self, val: Vec) { self.inner.hand = val.into_iter().map(|x| x as u16).collect(); } #[setter] fn set_energy_zone(&mut self, val: Vec) { self.inner.energy_zone = val.into_iter().map(|x| x as u16).collect(); } #[setter] fn set_discard(&mut self, val: Vec) { self.inner.discard = val.into_iter().map(|x| x as u16).collect(); } #[setter] fn set_score(&mut self, val: u32) { self.inner.score = val; } #[setter(deck)] fn set_deck(&mut self, val: Vec) { self.inner.deck = val.into_iter().map(|x| x as u16).collect(); } #[setter] fn set_tapped_energy(&mut self, val: Vec) { self.inner.tapped_energy = SmallVec::from_vec(val); } #[setter] fn set_tapped_members(&mut self, val: [bool; 3]) { self.inner.tapped_members = val; } #[setter] fn set_moved_members_this_turn(&mut self, val: [bool; 3]) { self.inner.moved_members_this_turn = val; } #[setter] fn set_stage(&mut self, val: [i32; 3]) { self.inner.stage = val.map(|x| x as i16); } #[setter] fn set_live_zone(&mut self, val: [i32; 3]) { self.inner.live_zone = val.map(|x| x as i16); } #[setter] fn set_live_zone_revealed(&mut self, val: [bool; 3]) { self.inner.live_zone_revealed = val; } #[getter] fn mulligan_selection(&self) -> u64 { self.inner.mulligan_selection } #[getter] fn deck_count(&self) -> usize { self.inner.deck.len() } #[getter] fn energy_deck_count(&self) -> usize { self.inner.energy_deck.len() } #[getter] fn hand_added_turn(&self) -> Vec { self.inner.hand_added_turn.iter().map(|&x| x as u32).collect() } #[getter] fn looked_cards(&self) -> Vec { self.inner.looked_cards.iter().map(|&x| x as u32).collect() } #[getter] fn yell_cards(&self) -> Vec { // Moved to GameState Vec::new() } #[setter] fn set_yell_cards(&mut self, _val: Vec) { // Moved to GameState } #[getter] pub fn heart_buffs(&self) -> Vec> { self.inner.heart_buffs.iter().map(|h| h.to_vec()).collect() } #[setter] pub fn set_heart_buffs(&mut self, val: Vec>) { for (i, v) in val.iter().enumerate() { if i < 3 && v.len() == 7 { for (j, &heart) in v.iter().enumerate() { self.inner.heart_buffs[i][j] = heart; } } } } #[getter] pub fn blade_buffs(&self) -> Vec { self.inner.blade_buffs.to_vec() } #[setter] pub fn set_blade_buffs(&mut self, val: Vec) { for (i, &v) in val.iter().enumerate() { if i < 3 { self.inner.blade_buffs[i] = v; } } } } #[pyclass] #[derive(Clone)] pub struct PyCardDatabase { pub inner: std::sync::Arc, } #[pymethods] impl PyCardDatabase { #[new] fn new(json_str: &str) -> PyResult { let db = crate::core::logic::CardDatabase::from_json(json_str) .map_err(|e: serde_json::Error| pyo3::exceptions::PyValueError::new_err(e.to_string()))?; Ok(Self { inner: std::sync::Arc::new(db) }) } #[getter] fn member_count(&self) -> usize { self.inner.members.len() } #[getter] fn live_count(&self) -> usize { self.inner.lives.len() } fn has_member(&self, card_id: u32) -> bool { self.inner.members.contains_key(&(card_id as u16)) } fn get_member_ids(&self) -> Vec { self.inner.members.keys().map(|&k| k as u32).collect() } } #[pyclass] pub struct PyGameState { pub inner: GameState, pub db: PyCardDatabase, pub legal_action_buffer: Vec, } #[pymethods] impl PyGameState { #[new] fn new(db: PyCardDatabase) -> PyResult { Ok(Self { inner: GameState::default(), db, legal_action_buffer: vec![false; crate::core::logic::ACTION_SPACE], }) } pub fn copy(&self) -> Self { Self { inner: self.inner.clone(), db: self.db.clone(), legal_action_buffer: vec![false; crate::core::logic::ACTION_SPACE], } } pub fn ping(&self) -> String { "pong_v_force_fix_1212".to_string() } #[getter] fn db(&self) -> PyCardDatabase { self.db.clone() } #[getter] fn current_player(&self) -> u8 { self.inner.current_player } #[setter] fn set_current_player(&mut self, val: u8) { self.inner.current_player = val; } #[getter] fn first_player(&self) -> u8 { self.inner.first_player } #[setter] fn set_first_player(&mut self, val: u8) { self.inner.first_player = val; } #[getter] fn rule_log(&self) -> Vec { self.inner.rule_log.clone() } #[getter] fn phase(&self) -> i8 { self.inner.phase as i8 } #[getter] fn turn(&self) -> u32 { self.inner.turn as u32 } #[setter] fn set_turn(&mut self, val: u32) { self.inner.turn = val as u16; } #[getter] fn silent(&self) -> bool { self.inner.silent } #[setter] fn set_silent(&mut self, val: bool) { self.inner.silent = val; } #[getter] fn performance_results(&self) -> String { serde_json::to_string(&self.inner.performance_results).unwrap_or_default() } #[getter] fn pending_card_id(&self) -> i32 { self.inner.pending_card_id as i32 } #[getter] fn pending_ab_idx(&self) -> i32 { self.inner.pending_ab_idx as i32 } #[getter] fn pending_effect_opcode(&self) -> i32 { self.inner.pending_effect_opcode as i32 } #[getter] fn pending_choice_type(&self) -> String { self.inner.pending_choice_type.clone() } #[getter] fn yell_cards(&self) -> Vec { // Moved to GameState Vec::new() } #[setter] fn set_yell_cards(&mut self, _val: Vec) { // Moved to GameState } #[getter] fn pending_area_idx(&self) -> i32 { if let Some(ctx) = &self.inner.pending_ctx { ctx.area_idx as i32 } else { -1 } } #[getter] fn pending_player_id(&self) -> i32 { if let Some(ctx) = &self.inner.pending_ctx { ctx.player_id as i32 } else { -1 } } #[getter] fn last_performance_results(&self) -> String { serde_json::to_string(&self.inner.last_performance_results).unwrap_or_else(|_| "{}".to_string()) } #[getter] fn performance_history(&self) -> String { serde_json::to_string(&self.inner.performance_history).unwrap_or_else(|_| "[]".to_string()) } #[getter] fn pending_choices(&self) -> Vec<(String, String)> { use crate::core::logic::*; let mut result = Vec::new(); let op = self.inner.pending_effect_opcode; let p_idx = if let Some(ctx) = &self.inner.pending_ctx { ctx.player_id as usize } else { self.inner.current_player as usize }; if op == O_ORDER_DECK as i16 || op == O_LOOK_AND_CHOOSE as i16 || op == O_REVEAL as i16 { let looked = &self.inner.players[p_idx].looked_cards; let params = serde_json::json!({ "cards": looked }); let type_str = if op == O_ORDER_DECK as i16 { "ORDER_DECK" } else { "SELECT_FROM_LIST" }; result.push((type_str.to_string(), params.to_string())); } else if op == O_TAP_O as i16 { result.push(("TARGET_OPPONENT_MEMBER".to_string(), "{}".to_string())); } else if op == O_MOVE_MEMBER as i16 { result.push(("MOVE_MEMBER".to_string(), "{}".to_string())); } else if op == O_ACTIVATE_MEMBER as i16 { result.push(("TAP_MEMBER".to_string(), "{}".to_string())); } else if op == O_COLOR_SELECT as i16 { result.push(("COLOR_SELECT".to_string(), "{}".to_string())); } else if op == O_SELECT_MODE as i16 { // We might need to store the options in the state if we want better labels result.push(("SELECT_MODE".to_string(), "{}".to_string())); } result } #[getter] fn pending_effects(&self) -> Vec { Vec::new() } fn get_player(&self, idx: usize) -> PyResult { if idx < 2 { Ok(PyPlayerState { inner: self.inner.players[idx].clone() }) } else { Err(pyo3::exceptions::PyIndexError::new_err("Player index out of bounds")) } } fn initialize_game(&mut self, p0_deck: Vec, p1_deck: Vec, p0_energy: Vec, p1_energy: Vec, p0_lives: Vec, p1_lives: Vec) { let p0_d: Vec = p0_deck.into_iter().map(|x| x as u16).collect(); let p1_d: Vec = p1_deck.into_iter().map(|x| x as u16).collect(); let p0_e: Vec = p0_energy.into_iter().map(|x| x as u16).collect(); let p1_e: Vec = p1_energy.into_iter().map(|x| x as u16).collect(); let p0_l: Vec = p0_lives.into_iter().map(|x| x as u16).collect(); let p1_l: Vec = p1_lives.into_iter().map(|x| x as u16).collect(); self.inner.initialize_game(p0_d, p1_d, p0_e, p1_e, p0_l, p1_l); } fn initialize_game_with_seed(&mut self, p0_deck: Vec, p1_deck: Vec, p0_energy: Vec, p1_energy: Vec, p0_lives: Vec, p1_lives: Vec, seed: u64) { let p0_d: Vec = p0_deck.into_iter().map(|x| x as u16).collect(); let p1_d: Vec = p1_deck.into_iter().map(|x| x as u16).collect(); let p0_e: Vec = p0_energy.into_iter().map(|x| x as u16).collect(); let p1_e: Vec = p1_energy.into_iter().map(|x| x as u16).collect(); let p0_l: Vec = p0_lives.into_iter().map(|x| x as u16).collect(); let p1_l: Vec = p1_lives.into_iter().map(|x| x as u16).collect(); self.inner.initialize_game_with_seed(p0_d, p1_d, p0_e, p1_e, p0_l, p1_l, Some(seed)); } fn get_legal_actions(&mut self) -> Vec { self.inner.get_legal_actions_into(&self.db.inner, &mut self.legal_action_buffer); self.legal_action_buffer.clone() } fn get_legal_action_ids(&mut self) -> Vec { self.inner.get_legal_action_ids(&self.db.inner) } fn get_observation(&self) -> Vec { self.inner.get_observation(&self.db.inner) } fn is_terminal(&self) -> bool { self.inner.phase == Phase::Terminal } fn get_winner(&self) -> i32 { self.inner.get_winner() } fn get_effective_blades(&self, p_idx: usize, slot_idx: usize) -> u32 { self.inner.get_effective_blades(p_idx, slot_idx, &self.db.inner) } fn get_effective_hearts(&self, p_idx: usize, slot_idx: usize) -> [u8; 7] { self.inner.get_effective_hearts(p_idx, slot_idx, &self.db.inner) } fn get_total_blades(&self, p_idx: usize) -> u32 { self.inner.get_total_blades(p_idx, &self.db.inner) } fn get_total_hearts(&self, p_idx: usize) -> [u32; 7] { self.inner.get_total_hearts(p_idx, &self.db.inner) } fn get_member_cost(&self, p_idx: usize, card_id: i32, slot_idx: i32) -> i32 { self.inner.get_member_cost(p_idx, card_id, slot_idx, &self.db.inner) } fn execute_mulligan(&mut self, player_idx: usize, discard_indices: Vec) { self.inner.execute_mulligan(player_idx, discard_indices); } fn step(&mut self, action: i32) -> PyResult<()> { let db = &self.db.inner; self.inner.step(db, action) .map_err(|e| pyo3::exceptions::PyValueError::new_err(e)) } fn debug_execute_bytecode(&mut self, bytecode: Vec, player_id: u8, area_idx: i32, source_card_id: i32, target_slot: i32, choice_index: i32, selected_color: i32) { let db = &self.db.inner; let ctx = crate::core::logic::AbilityContext { player_id, area_idx: area_idx as i16, source_card_id: source_card_id as i16, target_slot: target_slot as i16, choice_index: choice_index as i16, selected_color: selected_color as i16, program_counter: 0, ability_index: -1, }; self.inner.resolve_bytecode(db, &bytecode, &ctx); } fn integrated_step(&mut self, action: i32, opp_mode: u8, mcts_sims: usize, enable_rollout: bool) -> (f32, bool) { let db = &self.db.inner; self.inner.integrated_step(db, action, opp_mode, mcts_sims, enable_rollout) } #[pyo3(signature = (p0_sims, p1_sims, p0_heuristic_id, p1_heuristic_id, horizon=SearchHorizon::GameEnd, p0_rollout=true, p1_rollout=true))] fn play_asymmetric_match(&mut self, p0_sims: usize, p1_sims: usize, p0_heuristic_id: i32, p1_heuristic_id: i32, horizon: SearchHorizon, p0_rollout: bool, p1_rollout: bool) -> (i32, u32) { let db = &self.db.inner; self.inner.play_asymmetric_match(db, p0_sims, p1_sims, p0_heuristic_id, p1_heuristic_id, horizon, p0_rollout, p1_rollout) } #[pyo3(signature = (p0_sims, p1_sims, p0_heuristic_id, p1_heuristic_id, horizon=SearchHorizon::GameEnd, enable_rollout=true))] fn play_mirror_match(&mut self, p0_sims: usize, p1_sims: usize, p0_heuristic_id: i32, p1_heuristic_id: i32, horizon: SearchHorizon, enable_rollout: bool) -> (i32, u32) { let db = &self.db.inner; self.inner.play_mirror_match(db, p0_sims, p1_sims, p0_heuristic_id, p1_heuristic_id, horizon, enable_rollout) } fn step_opponent(&mut self) { let db = &self.db.inner; self.inner.step_opponent(db); } fn step_opponent_mcts(&mut self, sims: usize) { let db = &self.db.inner; self.inner.step_opponent_mcts(db, sims, &crate::core::heuristics::OriginalHeuristic); } fn step_opponent_greedy(&mut self) { let db = &self.db.inner; self.inner.step_opponent_greedy(db, &crate::core::heuristics::OriginalHeuristic); } fn get_greedy_action(&mut self) -> i32 { let db = &self.db.inner; self.inner.get_greedy_action(db, &crate::core::heuristics::OriginalHeuristic) } #[pyo3(signature = (sims, horizon=SearchHorizon::GameEnd, eval_mode=EvalMode::Blind))] fn get_mcts_suggestions(&mut self, sims: usize, horizon: SearchHorizon, eval_mode: EvalMode) -> Vec<(i32, f32, u32)> { let db = &self.db.inner; self.inner.get_mcts_suggestions(db, sims, horizon, eval_mode) } #[setter] fn set_phase(&mut self, val: i8) { self.inner.phase = match val { -1 => Phase::MulliganP1, 0 => Phase::MulliganP2, 1 => Phase::Active, 2 => Phase::Energy, 3 => Phase::Draw, 4 => Phase::Main, 5 => Phase::LiveSet, 6 => Phase::PerformanceP1, 7 => Phase::PerformanceP2, 8 => Phase::LiveResult, 9 => Phase::Terminal, 10 => Phase::Response, _ => Phase::Setup, }; } fn set_player(&mut self, idx: usize, player: PyPlayerState) -> PyResult<()> { if idx < 2 { self.inner.log(format!("set_player {}: Discard len = {}", idx, player.inner.discard.len())); self.inner.players[idx] = player.inner; Ok(()) } else { Err(pyo3::exceptions::PyIndexError::new_err("Player index out of bounds")) } } fn set_stage_card(&mut self, p_idx: usize, slot_idx: usize, card_id: i32) { if p_idx < 2 && slot_idx < 3 { self.inner.players[p_idx].stage[slot_idx] = card_id as i16; } } fn set_live_card(&mut self, p_idx: usize, slot_idx: usize, card_id: i32, revealed: bool) { if p_idx < 2 && slot_idx < 3 { self.inner.players[p_idx].live_zone[slot_idx] = card_id as i16; self.inner.players[p_idx].live_zone_revealed[slot_idx] = revealed; } } fn set_hand_cards(&mut self, p_idx: usize, cards: Vec) { if p_idx < 2 { self.inner.players[p_idx].hand = cards.into_iter().map(|x| x as u16).collect(); self.inner.players[p_idx].hand_added_turn = SmallVec::from_vec(vec![self.inner.turn as u16; self.inner.players[p_idx].hand.len()]); } } fn resolve_bytecode(&mut self, bytecode: Vec, player_id: u8, area_idx: i32) { let ctx = crate::core::logic::AbilityContext { player_id, area_idx: area_idx as i16, ..crate::core::logic::AbilityContext::default() }; self.inner.resolve_bytecode(&self.db.inner, &bytecode, &ctx); } fn trigger_abilities(&mut self, trigger: i32, player_id: u8) { let trigger_type = unsafe { std::mem::transmute::(trigger as i8) }; let ctx = crate::core::logic::AbilityContext { player_id, ..crate::core::logic::AbilityContext::default() }; self.inner.trigger_abilities(&self.db.inner, trigger_type, &ctx); } #[pyo3(signature = (num_sims=0, seconds=0.0, heuristic_type="original", horizon=SearchHorizon::GameEnd, eval_mode=EvalMode::Blind, model_path=None))] fn search_mcts(&self, num_sims: usize, seconds: f32, heuristic_type: &str, horizon: SearchHorizon, eval_mode: EvalMode, model_path: Option<&str>) -> Vec<(i32, f32, u32)> { if heuristic_type == "resnet" || heuristic_type == "hybrid" { #[cfg(feature = "nn")] { let actual_path = model_path.unwrap_or("ai/models/alphanet_best.onnx"); // Helper for fallback to avoid repetition let run_fallback = || { let mcts = crate::core::mcts::MCTS::new(); let h = crate::core::heuristics::OriginalHeuristic; mcts.search_parallel(&self.inner, &self.db.inner, num_sims, seconds, horizon, &h, eval_mode == EvalMode::Blind) }; // Check if file exists to avoid panics if !std::path::Path::new(actual_path).exists() { return run_fallback(); } let session_res = ort::session::Session::builder() .and_then(|b| b.with_optimization_level(ort::session::builder::GraphOptimizationLevel::Level3)) .and_then(|b| b.with_intra_threads(1)) .and_then(|b| b.commit_from_file(actual_path)); match session_res { Ok(session) => { let weight = if heuristic_type == "resnet" { 1.0 } else { 0.5 }; let session_arc = std::sync::Arc::new(std::sync::Mutex::new(session)); let mut hybrid = crate::core::mcts::HybridMCTS::new(session_arc, weight, false); return hybrid.get_suggestions(&self.inner, &self.db.inner, num_sims, seconds); } Err(_) => { return run_fallback(); } } } #[cfg(not(feature = "nn"))] { let mcts = crate::core::mcts::MCTS::new(); let h = crate::core::heuristics::OriginalHeuristic; return mcts.search_parallel(&self.inner, &self.db.inner, num_sims, seconds, horizon, &h, eval_mode == EvalMode::Blind); } } let mcts = crate::core::mcts::MCTS::new(); let h: Box = match heuristic_type { "simple" => Box::new(crate::core::heuristics::SimpleHeuristic), _ => Box::new(crate::core::heuristics::OriginalHeuristic), }; mcts.search_parallel(&self.inner, &self.db.inner, num_sims, seconds, horizon, h.as_ref(), eval_mode == EvalMode::Blind) } } #[pyclass] pub struct PyVectorGameState { envs: Vec, db: PyCardDatabase, p0_deck: Vec, p1_deck: Vec, p0_lives: Vec, p1_lives: Vec, seeds: Vec, opp_mode: u8, mcts_sims: usize, } #[pymethods] impl PyVectorGameState { #[new] #[pyo3(signature = (num_envs, db, opp_mode=0, mcts_sims=50))] fn new(num_envs: usize, db: PyCardDatabase, opp_mode: u8, mcts_sims: usize) -> Self { let mut envs = Vec::with_capacity(num_envs); for _ in 0..num_envs { envs.push(GameState::default()); } Self { envs, db, p0_deck: Vec::new(), p1_deck: Vec::new(), p0_lives: Vec::new(), p1_lives: Vec::new(), seeds: vec![0; num_envs], opp_mode, mcts_sims, } } fn initialize(&mut self, p0_deck: Vec, p1_deck: Vec, p0_lives: Vec, p1_lives: Vec, seed: u64) { self.p0_deck = p0_deck; self.p1_deck = p1_deck; self.p0_lives = p0_lives; self.p1_lives = p1_lives; let num_envs = self.envs.len(); for i in 0..num_envs { self.seeds[i] = seed + i as u64; } self.envs.par_iter_mut().enumerate().for_each(|(i, env)| { env.initialize_game_with_seed( self.p0_deck.iter().map(|&x| x as u16).collect(), self.p1_deck.iter().map(|&x| x as u16).collect(), Vec::new(), Vec::new(), self.p0_lives.iter().map(|&x| x as u16).collect(), self.p1_lives.iter().map(|&x| x as u16).collect(), Some(self.seeds[i]) ); }); } #[allow(clippy::too_many_arguments)] fn step<'py>( &mut self, _py: Python<'py>, actions: PyReadonlyArray1<'py, i32>, obs_out: &Bound<'py, PyArray2>, rewards_out: &Bound<'py, PyArray1>, dones_out: &Bound<'py, PyArray1>, term_obs_out: &Bound<'py, PyArray2>, ) -> PyResult> { let actions = actions.as_slice()?; let obs_slice = unsafe { obs_out.as_slice_mut()? }; let rewards_slice = unsafe { rewards_out.as_slice_mut()? }; let dones_slice = unsafe { dones_out.as_slice_mut()? }; let term_obs_slice = unsafe { term_obs_out.as_slice_mut()? }; let num_envs = self.envs.len(); let db = &self.db.inner; let obs_dim = 320; if actions.len() != num_envs { return Err(pyo3::exceptions::PyValueError::new_err("Action dim mismatch")); } // 1. Step let opp_mode = self.opp_mode; let mcts_sims = self.mcts_sims; let results: Vec<(f32, bool)> = self.envs.par_iter_mut().zip(actions.par_iter()) .map(|(env, &act)| { env.integrated_step(db, act, opp_mode, mcts_sims, true) }).collect(); results.par_iter().zip(rewards_slice.par_iter_mut()).zip(dones_slice.par_iter_mut()) .for_each(|((&(r, d), r_out), d_out)| { *r_out = r; *d_out = d; }); // 2. Filter Done let mut done_indices = Vec::with_capacity(num_envs / 10); for (i, &(_, done)) in results.iter().enumerate() { if done { done_indices.push(i); } } // 3. Write Terminal Obs (Before Reset) if !done_indices.is_empty() { term_obs_slice.par_chunks_mut(obs_dim).zip(done_indices.par_iter()) .for_each(|(chunk, &env_idx)| { self.envs[env_idx].write_observation(db, chunk); }); } // 4. Auto-Reset let p0_deck = &self.p0_deck; let p1_deck = &self.p1_deck; let p0_lives = &self.p0_lives; let p1_lives = &self.p1_lives; self.envs.par_iter_mut().zip(results.par_iter()).for_each(|(env, &(_, done))| { if done { env.initialize_game_with_seed( p0_deck.iter().map(|&x| x as u16).collect(), p1_deck.iter().map(|&x| x as u16).collect(), Vec::new(), Vec::new(), p0_lives.iter().map(|&x| x as u16).collect(), p1_lives.iter().map(|&x| x as u16).collect(), None ); } }); // 5. Write Final Obs obs_slice.par_chunks_mut(obs_dim).zip(self.envs.par_iter()) .for_each(|(chunk, env)| { env.write_observation(db, chunk); }); Ok(done_indices) } // New: Zero-Copy get_observations fn get_observations<'py>(&self, _py: Python<'py>, out: &Bound<'py, PyArray2>) -> PyResult<()> { let db = &self.db.inner; let obs_dim = 320; let obs_slice = unsafe { out.as_slice_mut()? }; obs_slice.par_chunks_mut(obs_dim).zip(self.envs.par_iter()) .for_each(|(chunk, env)| { env.write_observation(db, chunk); }); Ok(()) } // New: Zero-Copy get_action_masks fn get_action_masks<'py>(&self, _py: Python<'py>, out: &Bound<'py, PyArray2>) -> PyResult<()> { let db = &self.db.inner; let action_dim = crate::core::logic::ACTION_SPACE; let mask_slice = unsafe { out.as_slice_mut()? }; mask_slice.par_chunks_mut(action_dim).zip(self.envs.par_iter()) .for_each(|(chunk, env)| { env.get_legal_actions_into(db, chunk); }); Ok(()) } } #[cfg(feature = "nn")] #[pyclass] pub struct PyHybridMCTS { pub session: std::sync::Arc>, pub neural_weight: f32, pub skip_rollout: bool, } #[cfg(feature = "nn")] #[pymethods] impl PyHybridMCTS { #[new] #[pyo3(signature = (model_path, neural_weight=0.3, skip_rollout=false))] fn new(model_path: &str, neural_weight: f32, skip_rollout: bool) -> PyResult { let session = ort::session::Session::builder() .map_err(|e: ort::Error| pyo3::exceptions::PyValueError::new_err(e.to_string()))? .commit_from_file(model_path) .map_err(|e: ort::Error| pyo3::exceptions::PyValueError::new_err(e.to_string()))?; Ok(Self { session: std::sync::Arc::new(std::sync::Mutex::new(session)), neural_weight, skip_rollout }) } #[pyo3(signature = (game, num_sims=0, seconds=0.0))] fn get_suggestions(&mut self, game: &mut PyGameState, num_sims: usize, seconds: f32) -> Vec<(i32, f32, u32)> { let mut mcts = crate::core::mcts::HybridMCTS::new( self.session.clone(), self.neural_weight, self.skip_rollout ); mcts.get_suggestions(&game.inner, &game.db.inner, num_sims, seconds) } } pub fn register_python_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; #[cfg(feature = "nn")] m.add_class::()?; m.add_class::()?; m.add_class::()?; Ok(()) }