Spaces:
Running
Running
| use pyo3::prelude::*; | |
| use rayon::prelude::*; | |
| use numpy::{PyArray1, PyArray2, PyReadonlyArray1, PyArrayMethods}; | |
| use crate::core::logic::{GameState, PlayerState, Phase}; | |
| use crate::core::mcts::{SearchHorizon, EvalMode}; | |
| use smallvec::SmallVec; | |
| // use crate::core::heuristics::{OriginalHeuristic, SimpleHeuristic}; | |
| pub struct PyPlayerState { | |
| pub inner: PlayerState, | |
| } | |
| impl PyPlayerState { | |
| fn player_id(&self) -> u8 { | |
| self.inner.player_id | |
| } | |
| fn score(&self) -> u32 { | |
| self.inner.score | |
| } | |
| fn success_lives(&self) -> Vec<u32> { | |
| self.inner.success_lives.iter().map(|&x| x as u32).collect() | |
| } | |
| fn hand(&self) -> Vec<u32> { | |
| self.inner.hand.iter().map(|&x| x as u32).collect() | |
| } | |
| fn stage(&self) -> [i32; 3] { | |
| self.inner.stage.map(|x| x as i32) | |
| } | |
| fn discard(&self) -> Vec<u32> { | |
| self.inner.discard.iter().map(|&x| x as u32).collect() | |
| } | |
| fn exile(&self) -> Vec<u32> { | |
| self.inner.exile.iter().map(|&x| x as u32).collect() | |
| } | |
| fn deck(&self) -> Vec<u32> { | |
| self.inner.deck.iter().map(|&x| x as u32).collect() | |
| } | |
| fn energy_zone(&self) -> Vec<u32> { | |
| self.inner.energy_zone.iter().map(|&x| x as u32).collect() | |
| } | |
| fn live_zone(&self) -> [i32; 3] { | |
| self.inner.live_zone.map(|x| x as i32) | |
| } | |
| fn live_zone_revealed(&self) -> [bool; 3] { | |
| self.inner.live_zone_revealed | |
| } | |
| fn tapped_energy(&self) -> Vec<bool> { | |
| self.inner.tapped_energy.to_vec() | |
| } | |
| fn tapped_members(&self) -> [bool; 3] { | |
| self.inner.tapped_members | |
| } | |
| fn set_hand(&mut self, val: Vec<u32>) { | |
| self.inner.hand = val.into_iter().map(|x| x as u16).collect(); | |
| } | |
| fn set_energy_zone(&mut self, val: Vec<u32>) { | |
| self.inner.energy_zone = val.into_iter().map(|x| x as u16).collect(); | |
| } | |
| fn set_discard(&mut self, val: Vec<u32>) { | |
| self.inner.discard = val.into_iter().map(|x| x as u16).collect(); | |
| } | |
| fn set_score(&mut self, val: u32) { | |
| self.inner.score = val; | |
| } | |
| fn set_deck(&mut self, val: Vec<u32>) { | |
| self.inner.deck = val.into_iter().map(|x| x as u16).collect(); | |
| } | |
| fn set_tapped_energy(&mut self, val: Vec<bool>) { | |
| self.inner.tapped_energy = SmallVec::from_vec(val); | |
| } | |
| fn set_tapped_members(&mut self, val: [bool; 3]) { | |
| self.inner.tapped_members = val; | |
| } | |
| fn set_moved_members_this_turn(&mut self, val: [bool; 3]) { | |
| self.inner.moved_members_this_turn = val; | |
| } | |
| fn set_stage(&mut self, val: [i32; 3]) { | |
| self.inner.stage = val.map(|x| x as i16); | |
| } | |
| fn set_live_zone(&mut self, val: [i32; 3]) { | |
| self.inner.live_zone = val.map(|x| x as i16); | |
| } | |
| fn set_live_zone_revealed(&mut self, val: [bool; 3]) { | |
| self.inner.live_zone_revealed = val; | |
| } | |
| fn mulligan_selection(&self) -> u64 { | |
| self.inner.mulligan_selection | |
| } | |
| fn deck_count(&self) -> usize { | |
| self.inner.deck.len() | |
| } | |
| fn energy_deck_count(&self) -> usize { | |
| self.inner.energy_deck.len() | |
| } | |
| fn hand_added_turn(&self) -> Vec<u32> { | |
| self.inner.hand_added_turn.iter().map(|&x| x as u32).collect() | |
| } | |
| fn looked_cards(&self) -> Vec<u32> { | |
| self.inner.looked_cards.iter().map(|&x| x as u32).collect() | |
| } | |
| fn yell_cards(&self) -> Vec<u32> { | |
| // Moved to GameState | |
| Vec::new() | |
| } | |
| fn set_yell_cards(&mut self, _val: Vec<u32>) { | |
| // Moved to GameState | |
| } | |
| pub fn heart_buffs(&self) -> Vec<Vec<i32>> { | |
| self.inner.heart_buffs.iter().map(|h| h.to_vec()).collect() | |
| } | |
| pub fn set_heart_buffs(&mut self, val: Vec<Vec<i32>>) { | |
| for (i, v) in val.iter().enumerate() { | |
| if i < 3 && v.len() == 7 { | |
| for (j, &heart) in v.iter().enumerate() { | |
| self.inner.heart_buffs[i][j] = heart; | |
| } | |
| } | |
| } | |
| } | |
| pub fn blade_buffs(&self) -> Vec<i32> { | |
| self.inner.blade_buffs.to_vec() | |
| } | |
| pub fn set_blade_buffs(&mut self, val: Vec<i32>) { | |
| for (i, &v) in val.iter().enumerate() { | |
| if i < 3 { | |
| self.inner.blade_buffs[i] = v; | |
| } | |
| } | |
| } | |
| } | |
| pub struct PyCardDatabase { | |
| pub inner: std::sync::Arc<crate::core::logic::CardDatabase>, | |
| } | |
| impl PyCardDatabase { | |
| fn new(json_str: &str) -> PyResult<Self> { | |
| let db = crate::core::logic::CardDatabase::from_json(json_str) | |
| .map_err(|e: serde_json::Error| pyo3::exceptions::PyValueError::new_err(e.to_string()))?; | |
| Ok(Self { inner: std::sync::Arc::new(db) }) | |
| } | |
| fn member_count(&self) -> usize { | |
| self.inner.members.len() | |
| } | |
| fn live_count(&self) -> usize { | |
| self.inner.lives.len() | |
| } | |
| fn has_member(&self, card_id: u32) -> bool { | |
| self.inner.members.contains_key(&(card_id as u16)) | |
| } | |
| fn get_member_ids(&self) -> Vec<u32> { | |
| self.inner.members.keys().map(|&k| k as u32).collect() | |
| } | |
| } | |
| pub struct PyGameState { | |
| pub inner: GameState, | |
| pub db: PyCardDatabase, | |
| pub legal_action_buffer: Vec<bool>, | |
| } | |
| impl PyGameState { | |
| fn new(db: PyCardDatabase) -> PyResult<Self> { | |
| Ok(Self { | |
| inner: GameState::default(), | |
| db, | |
| legal_action_buffer: vec![false; crate::core::logic::ACTION_SPACE], | |
| }) | |
| } | |
| pub fn copy(&self) -> Self { | |
| Self { | |
| inner: self.inner.clone(), | |
| db: self.db.clone(), | |
| legal_action_buffer: vec![false; crate::core::logic::ACTION_SPACE], | |
| } | |
| } | |
| pub fn ping(&self) -> String { | |
| "pong_v_force_fix_1212".to_string() | |
| } | |
| fn db(&self) -> PyCardDatabase { | |
| self.db.clone() | |
| } | |
| fn current_player(&self) -> u8 { | |
| self.inner.current_player | |
| } | |
| fn set_current_player(&mut self, val: u8) { | |
| self.inner.current_player = val; | |
| } | |
| fn first_player(&self) -> u8 { | |
| self.inner.first_player | |
| } | |
| fn set_first_player(&mut self, val: u8) { | |
| self.inner.first_player = val; | |
| } | |
| fn rule_log(&self) -> Vec<String> { | |
| self.inner.rule_log.clone() | |
| } | |
| fn phase(&self) -> i8 { | |
| self.inner.phase as i8 | |
| } | |
| fn turn(&self) -> u32 { | |
| self.inner.turn as u32 | |
| } | |
| fn set_turn(&mut self, val: u32) { | |
| self.inner.turn = val as u16; | |
| } | |
| fn silent(&self) -> bool { | |
| self.inner.silent | |
| } | |
| fn set_silent(&mut self, val: bool) { | |
| self.inner.silent = val; | |
| } | |
| fn performance_results(&self) -> String { | |
| serde_json::to_string(&self.inner.performance_results).unwrap_or_default() | |
| } | |
| fn pending_card_id(&self) -> i32 { | |
| self.inner.pending_card_id as i32 | |
| } | |
| fn pending_ab_idx(&self) -> i32 { | |
| self.inner.pending_ab_idx as i32 | |
| } | |
| fn pending_effect_opcode(&self) -> i32 { | |
| self.inner.pending_effect_opcode as i32 | |
| } | |
| fn pending_choice_type(&self) -> String { | |
| self.inner.pending_choice_type.clone() | |
| } | |
| fn yell_cards(&self) -> Vec<u32> { | |
| // Moved to GameState | |
| Vec::new() | |
| } | |
| fn set_yell_cards(&mut self, _val: Vec<u32>) { | |
| // Moved to GameState | |
| } | |
| fn pending_area_idx(&self) -> i32 { | |
| if let Some(ctx) = &self.inner.pending_ctx { | |
| ctx.area_idx as i32 | |
| } else { | |
| -1 | |
| } | |
| } | |
| fn pending_player_id(&self) -> i32 { | |
| if let Some(ctx) = &self.inner.pending_ctx { | |
| ctx.player_id as i32 | |
| } else { | |
| -1 | |
| } | |
| } | |
| fn last_performance_results(&self) -> String { | |
| serde_json::to_string(&self.inner.last_performance_results).unwrap_or_else(|_| "{}".to_string()) | |
| } | |
| fn performance_history(&self) -> String { | |
| serde_json::to_string(&self.inner.performance_history).unwrap_or_else(|_| "[]".to_string()) | |
| } | |
| fn pending_choices(&self) -> Vec<(String, String)> { | |
| use crate::core::logic::*; | |
| let mut result = Vec::new(); | |
| let op = self.inner.pending_effect_opcode; | |
| let p_idx = if let Some(ctx) = &self.inner.pending_ctx { | |
| ctx.player_id as usize | |
| } else { | |
| self.inner.current_player as usize | |
| }; | |
| if op == O_ORDER_DECK as i16 || op == O_LOOK_AND_CHOOSE as i16 || op == O_REVEAL as i16 { | |
| let looked = &self.inner.players[p_idx].looked_cards; | |
| let params = serde_json::json!({ | |
| "cards": looked | |
| }); | |
| let type_str = if op == O_ORDER_DECK as i16 { "ORDER_DECK" } else { "SELECT_FROM_LIST" }; | |
| result.push((type_str.to_string(), params.to_string())); | |
| } else if op == O_TAP_O as i16 { | |
| result.push(("TARGET_OPPONENT_MEMBER".to_string(), "{}".to_string())); | |
| } else if op == O_MOVE_MEMBER as i16 { | |
| result.push(("MOVE_MEMBER".to_string(), "{}".to_string())); | |
| } else if op == O_ACTIVATE_MEMBER as i16 { | |
| result.push(("TAP_MEMBER".to_string(), "{}".to_string())); | |
| } else if op == O_COLOR_SELECT as i16 { | |
| result.push(("COLOR_SELECT".to_string(), "{}".to_string())); | |
| } else if op == O_SELECT_MODE as i16 { | |
| // We might need to store the options in the state if we want better labels | |
| result.push(("SELECT_MODE".to_string(), "{}".to_string())); | |
| } | |
| result | |
| } | |
| fn pending_effects(&self) -> Vec<String> { | |
| Vec::new() | |
| } | |
| fn get_player(&self, idx: usize) -> PyResult<PyPlayerState> { | |
| if idx < 2 { | |
| Ok(PyPlayerState { inner: self.inner.players[idx].clone() }) | |
| } else { | |
| Err(pyo3::exceptions::PyIndexError::new_err("Player index out of bounds")) | |
| } | |
| } | |
| fn initialize_game(&mut self, p0_deck: Vec<u32>, p1_deck: Vec<u32>, p0_energy: Vec<u32>, p1_energy: Vec<u32>, p0_lives: Vec<u32>, p1_lives: Vec<u32>) { | |
| let p0_d: Vec<u16> = p0_deck.into_iter().map(|x| x as u16).collect(); | |
| let p1_d: Vec<u16> = p1_deck.into_iter().map(|x| x as u16).collect(); | |
| let p0_e: Vec<u16> = p0_energy.into_iter().map(|x| x as u16).collect(); | |
| let p1_e: Vec<u16> = p1_energy.into_iter().map(|x| x as u16).collect(); | |
| let p0_l: Vec<u16> = p0_lives.into_iter().map(|x| x as u16).collect(); | |
| let p1_l: Vec<u16> = p1_lives.into_iter().map(|x| x as u16).collect(); | |
| self.inner.initialize_game(p0_d, p1_d, p0_e, p1_e, p0_l, p1_l); | |
| } | |
| fn initialize_game_with_seed(&mut self, p0_deck: Vec<u32>, p1_deck: Vec<u32>, p0_energy: Vec<u32>, p1_energy: Vec<u32>, p0_lives: Vec<u32>, p1_lives: Vec<u32>, seed: u64) { | |
| let p0_d: Vec<u16> = p0_deck.into_iter().map(|x| x as u16).collect(); | |
| let p1_d: Vec<u16> = p1_deck.into_iter().map(|x| x as u16).collect(); | |
| let p0_e: Vec<u16> = p0_energy.into_iter().map(|x| x as u16).collect(); | |
| let p1_e: Vec<u16> = p1_energy.into_iter().map(|x| x as u16).collect(); | |
| let p0_l: Vec<u16> = p0_lives.into_iter().map(|x| x as u16).collect(); | |
| let p1_l: Vec<u16> = p1_lives.into_iter().map(|x| x as u16).collect(); | |
| self.inner.initialize_game_with_seed(p0_d, p1_d, p0_e, p1_e, p0_l, p1_l, Some(seed)); | |
| } | |
| fn get_legal_actions(&mut self) -> Vec<bool> { | |
| self.inner.get_legal_actions_into(&self.db.inner, &mut self.legal_action_buffer); | |
| self.legal_action_buffer.clone() | |
| } | |
| fn get_legal_action_ids(&mut self) -> Vec<i32> { | |
| self.inner.get_legal_action_ids(&self.db.inner) | |
| } | |
| fn get_observation(&self) -> Vec<f32> { | |
| self.inner.get_observation(&self.db.inner) | |
| } | |
| fn is_terminal(&self) -> bool { | |
| self.inner.phase == Phase::Terminal | |
| } | |
| fn get_winner(&self) -> i32 { | |
| self.inner.get_winner() | |
| } | |
| fn get_effective_blades(&self, p_idx: usize, slot_idx: usize) -> u32 { | |
| self.inner.get_effective_blades(p_idx, slot_idx, &self.db.inner) | |
| } | |
| fn get_effective_hearts(&self, p_idx: usize, slot_idx: usize) -> [u8; 7] { | |
| self.inner.get_effective_hearts(p_idx, slot_idx, &self.db.inner) | |
| } | |
| fn get_total_blades(&self, p_idx: usize) -> u32 { | |
| self.inner.get_total_blades(p_idx, &self.db.inner) | |
| } | |
| fn get_total_hearts(&self, p_idx: usize) -> [u32; 7] { | |
| self.inner.get_total_hearts(p_idx, &self.db.inner) | |
| } | |
| fn get_member_cost(&self, p_idx: usize, card_id: i32, slot_idx: i32) -> i32 { | |
| self.inner.get_member_cost(p_idx, card_id, slot_idx, &self.db.inner) | |
| } | |
| fn execute_mulligan(&mut self, player_idx: usize, discard_indices: Vec<usize>) { | |
| self.inner.execute_mulligan(player_idx, discard_indices); | |
| } | |
| fn step(&mut self, action: i32) -> PyResult<()> { | |
| let db = &self.db.inner; | |
| self.inner.step(db, action) | |
| .map_err(|e| pyo3::exceptions::PyValueError::new_err(e)) | |
| } | |
| fn debug_execute_bytecode(&mut self, bytecode: Vec<i32>, player_id: u8, area_idx: i32, source_card_id: i32, target_slot: i32, choice_index: i32, selected_color: i32) { | |
| let db = &self.db.inner; | |
| let ctx = crate::core::logic::AbilityContext { | |
| player_id, | |
| area_idx: area_idx as i16, | |
| source_card_id: source_card_id as i16, | |
| target_slot: target_slot as i16, | |
| choice_index: choice_index as i16, | |
| selected_color: selected_color as i16, | |
| program_counter: 0, | |
| ability_index: -1, | |
| }; | |
| self.inner.resolve_bytecode(db, &bytecode, &ctx); | |
| } | |
| fn integrated_step(&mut self, action: i32, opp_mode: u8, mcts_sims: usize, enable_rollout: bool) -> (f32, bool) { | |
| let db = &self.db.inner; | |
| self.inner.integrated_step(db, action, opp_mode, mcts_sims, enable_rollout) | |
| } | |
| fn play_asymmetric_match(&mut self, p0_sims: usize, p1_sims: usize, p0_heuristic_id: i32, p1_heuristic_id: i32, horizon: SearchHorizon, p0_rollout: bool, p1_rollout: bool) -> (i32, u32) { | |
| let db = &self.db.inner; | |
| self.inner.play_asymmetric_match(db, p0_sims, p1_sims, p0_heuristic_id, p1_heuristic_id, horizon, p0_rollout, p1_rollout) | |
| } | |
| fn play_mirror_match(&mut self, p0_sims: usize, p1_sims: usize, p0_heuristic_id: i32, p1_heuristic_id: i32, horizon: SearchHorizon, enable_rollout: bool) -> (i32, u32) { | |
| let db = &self.db.inner; | |
| self.inner.play_mirror_match(db, p0_sims, p1_sims, p0_heuristic_id, p1_heuristic_id, horizon, enable_rollout) | |
| } | |
| fn step_opponent(&mut self) { | |
| let db = &self.db.inner; | |
| self.inner.step_opponent(db); | |
| } | |
| fn step_opponent_mcts(&mut self, sims: usize) { | |
| let db = &self.db.inner; | |
| self.inner.step_opponent_mcts(db, sims, &crate::core::heuristics::OriginalHeuristic); | |
| } | |
| fn step_opponent_greedy(&mut self) { | |
| let db = &self.db.inner; | |
| self.inner.step_opponent_greedy(db, &crate::core::heuristics::OriginalHeuristic); | |
| } | |
| fn get_greedy_action(&mut self) -> i32 { | |
| let db = &self.db.inner; | |
| self.inner.get_greedy_action(db, &crate::core::heuristics::OriginalHeuristic) | |
| } | |
| fn get_mcts_suggestions(&mut self, sims: usize, horizon: SearchHorizon, eval_mode: EvalMode) -> Vec<(i32, f32, u32)> { | |
| let db = &self.db.inner; | |
| self.inner.get_mcts_suggestions(db, sims, horizon, eval_mode) | |
| } | |
| fn set_phase(&mut self, val: i8) { | |
| self.inner.phase = match val { | |
| -1 => Phase::MulliganP1, | |
| 0 => Phase::MulliganP2, | |
| 1 => Phase::Active, | |
| 2 => Phase::Energy, | |
| 3 => Phase::Draw, | |
| 4 => Phase::Main, | |
| 5 => Phase::LiveSet, | |
| 6 => Phase::PerformanceP1, | |
| 7 => Phase::PerformanceP2, | |
| 8 => Phase::LiveResult, | |
| 9 => Phase::Terminal, | |
| 10 => Phase::Response, | |
| _ => Phase::Setup, | |
| }; | |
| } | |
| fn set_player(&mut self, idx: usize, player: PyPlayerState) -> PyResult<()> { | |
| if idx < 2 { | |
| self.inner.log(format!("set_player {}: Discard len = {}", idx, player.inner.discard.len())); | |
| self.inner.players[idx] = player.inner; | |
| Ok(()) | |
| } else { | |
| Err(pyo3::exceptions::PyIndexError::new_err("Player index out of bounds")) | |
| } | |
| } | |
| fn set_stage_card(&mut self, p_idx: usize, slot_idx: usize, card_id: i32) { | |
| if p_idx < 2 && slot_idx < 3 { | |
| self.inner.players[p_idx].stage[slot_idx] = card_id as i16; | |
| } | |
| } | |
| fn set_live_card(&mut self, p_idx: usize, slot_idx: usize, card_id: i32, revealed: bool) { | |
| if p_idx < 2 && slot_idx < 3 { | |
| self.inner.players[p_idx].live_zone[slot_idx] = card_id as i16; | |
| self.inner.players[p_idx].live_zone_revealed[slot_idx] = revealed; | |
| } | |
| } | |
| fn set_hand_cards(&mut self, p_idx: usize, cards: Vec<u32>) { | |
| if p_idx < 2 { | |
| self.inner.players[p_idx].hand = cards.into_iter().map(|x| x as u16).collect(); | |
| self.inner.players[p_idx].hand_added_turn = SmallVec::from_vec(vec![self.inner.turn as u16; self.inner.players[p_idx].hand.len()]); | |
| } | |
| } | |
| fn resolve_bytecode(&mut self, bytecode: Vec<i32>, player_id: u8, area_idx: i32) { | |
| let ctx = crate::core::logic::AbilityContext { | |
| player_id, | |
| area_idx: area_idx as i16, | |
| ..crate::core::logic::AbilityContext::default() | |
| }; | |
| self.inner.resolve_bytecode(&self.db.inner, &bytecode, &ctx); | |
| } | |
| fn trigger_abilities(&mut self, trigger: i32, player_id: u8) { | |
| let trigger_type = unsafe { std::mem::transmute::<i8, crate::core::enums::TriggerType>(trigger as i8) }; | |
| let ctx = crate::core::logic::AbilityContext { | |
| player_id, | |
| ..crate::core::logic::AbilityContext::default() | |
| }; | |
| self.inner.trigger_abilities(&self.db.inner, trigger_type, &ctx); | |
| } | |
| fn search_mcts(&self, num_sims: usize, seconds: f32, heuristic_type: &str, horizon: SearchHorizon, eval_mode: EvalMode, model_path: Option<&str>) -> Vec<(i32, f32, u32)> { | |
| if heuristic_type == "resnet" || heuristic_type == "hybrid" { | |
| { | |
| let actual_path = model_path.unwrap_or("ai/models/alphanet_best.onnx"); | |
| // Helper for fallback to avoid repetition | |
| let run_fallback = || { | |
| let mcts = crate::core::mcts::MCTS::new(); | |
| let h = crate::core::heuristics::OriginalHeuristic; | |
| mcts.search_parallel(&self.inner, &self.db.inner, num_sims, seconds, horizon, &h, eval_mode == EvalMode::Blind) | |
| }; | |
| // Check if file exists to avoid panics | |
| if !std::path::Path::new(actual_path).exists() { | |
| return run_fallback(); | |
| } | |
| let session_res = ort::session::Session::builder() | |
| .and_then(|b| b.with_optimization_level(ort::session::builder::GraphOptimizationLevel::Level3)) | |
| .and_then(|b| b.with_intra_threads(1)) | |
| .and_then(|b| b.commit_from_file(actual_path)); | |
| match session_res { | |
| Ok(session) => { | |
| let weight = if heuristic_type == "resnet" { 1.0 } else { 0.5 }; | |
| let session_arc = std::sync::Arc::new(std::sync::Mutex::new(session)); | |
| let mut hybrid = crate::core::mcts::HybridMCTS::new(session_arc, weight, false); | |
| return hybrid.get_suggestions(&self.inner, &self.db.inner, num_sims, seconds); | |
| } | |
| Err(_) => { | |
| return run_fallback(); | |
| } | |
| } | |
| } | |
| { | |
| let mcts = crate::core::mcts::MCTS::new(); | |
| let h = crate::core::heuristics::OriginalHeuristic; | |
| return mcts.search_parallel(&self.inner, &self.db.inner, num_sims, seconds, horizon, &h, eval_mode == EvalMode::Blind); | |
| } | |
| } | |
| let mcts = crate::core::mcts::MCTS::new(); | |
| let h: Box<dyn crate::core::heuristics::Heuristic> = match heuristic_type { | |
| "simple" => Box::new(crate::core::heuristics::SimpleHeuristic), | |
| _ => Box::new(crate::core::heuristics::OriginalHeuristic), | |
| }; | |
| mcts.search_parallel(&self.inner, &self.db.inner, num_sims, seconds, horizon, h.as_ref(), eval_mode == EvalMode::Blind) | |
| } | |
| } | |
| pub struct PyVectorGameState { | |
| envs: Vec<GameState>, | |
| db: PyCardDatabase, | |
| p0_deck: Vec<u32>, | |
| p1_deck: Vec<u32>, | |
| p0_lives: Vec<u32>, | |
| p1_lives: Vec<u32>, | |
| seeds: Vec<u64>, | |
| opp_mode: u8, | |
| mcts_sims: usize, | |
| } | |
| impl PyVectorGameState { | |
| fn new(num_envs: usize, db: PyCardDatabase, opp_mode: u8, mcts_sims: usize) -> Self { | |
| let mut envs = Vec::with_capacity(num_envs); | |
| for _ in 0..num_envs { | |
| envs.push(GameState::default()); | |
| } | |
| Self { | |
| envs, | |
| db, | |
| p0_deck: Vec::new(), | |
| p1_deck: Vec::new(), | |
| p0_lives: Vec::new(), | |
| p1_lives: Vec::new(), | |
| seeds: vec![0; num_envs], | |
| opp_mode, | |
| mcts_sims, | |
| } | |
| } | |
| fn initialize(&mut self, p0_deck: Vec<u32>, p1_deck: Vec<u32>, p0_lives: Vec<u32>, p1_lives: Vec<u32>, seed: u64) { | |
| self.p0_deck = p0_deck; | |
| self.p1_deck = p1_deck; | |
| self.p0_lives = p0_lives; | |
| self.p1_lives = p1_lives; | |
| let num_envs = self.envs.len(); | |
| for i in 0..num_envs { | |
| self.seeds[i] = seed + i as u64; | |
| } | |
| self.envs.par_iter_mut().enumerate().for_each(|(i, env)| { | |
| env.initialize_game_with_seed( | |
| self.p0_deck.iter().map(|&x| x as u16).collect(), | |
| self.p1_deck.iter().map(|&x| x as u16).collect(), | |
| Vec::new(), Vec::new(), | |
| self.p0_lives.iter().map(|&x| x as u16).collect(), | |
| self.p1_lives.iter().map(|&x| x as u16).collect(), | |
| Some(self.seeds[i]) | |
| ); | |
| }); | |
| } | |
| fn step<'py>( | |
| &mut self, | |
| _py: Python<'py>, | |
| actions: PyReadonlyArray1<'py, i32>, | |
| obs_out: &Bound<'py, PyArray2<f32>>, | |
| rewards_out: &Bound<'py, PyArray1<f32>>, | |
| dones_out: &Bound<'py, PyArray1<bool>>, | |
| term_obs_out: &Bound<'py, PyArray2<f32>>, | |
| ) -> PyResult<Vec<usize>> { | |
| let actions = actions.as_slice()?; | |
| let obs_slice = unsafe { obs_out.as_slice_mut()? }; | |
| let rewards_slice = unsafe { rewards_out.as_slice_mut()? }; | |
| let dones_slice = unsafe { dones_out.as_slice_mut()? }; | |
| let term_obs_slice = unsafe { term_obs_out.as_slice_mut()? }; | |
| let num_envs = self.envs.len(); | |
| let db = &self.db.inner; | |
| let obs_dim = 320; | |
| if actions.len() != num_envs { | |
| return Err(pyo3::exceptions::PyValueError::new_err("Action dim mismatch")); | |
| } | |
| // 1. Step | |
| let opp_mode = self.opp_mode; | |
| let mcts_sims = self.mcts_sims; | |
| let results: Vec<(f32, bool)> = self.envs.par_iter_mut().zip(actions.par_iter()) | |
| .map(|(env, &act)| { | |
| env.integrated_step(db, act, opp_mode, mcts_sims, true) | |
| }).collect(); | |
| results.par_iter().zip(rewards_slice.par_iter_mut()).zip(dones_slice.par_iter_mut()) | |
| .for_each(|((&(r, d), r_out), d_out)| { | |
| *r_out = r; | |
| *d_out = d; | |
| }); | |
| // 2. Filter Done | |
| let mut done_indices = Vec::with_capacity(num_envs / 10); | |
| for (i, &(_, done)) in results.iter().enumerate() { | |
| if done { done_indices.push(i); } | |
| } | |
| // 3. Write Terminal Obs (Before Reset) | |
| if !done_indices.is_empty() { | |
| term_obs_slice.par_chunks_mut(obs_dim).zip(done_indices.par_iter()) | |
| .for_each(|(chunk, &env_idx)| { | |
| self.envs[env_idx].write_observation(db, chunk); | |
| }); | |
| } | |
| // 4. Auto-Reset | |
| let p0_deck = &self.p0_deck; | |
| let p1_deck = &self.p1_deck; | |
| let p0_lives = &self.p0_lives; | |
| let p1_lives = &self.p1_lives; | |
| self.envs.par_iter_mut().zip(results.par_iter()).for_each(|(env, &(_, done))| { | |
| if done { | |
| env.initialize_game_with_seed( | |
| p0_deck.iter().map(|&x| x as u16).collect(), | |
| p1_deck.iter().map(|&x| x as u16).collect(), | |
| Vec::new(), Vec::new(), | |
| p0_lives.iter().map(|&x| x as u16).collect(), | |
| p1_lives.iter().map(|&x| x as u16).collect(), | |
| None | |
| ); | |
| } | |
| }); | |
| // 5. Write Final Obs | |
| obs_slice.par_chunks_mut(obs_dim).zip(self.envs.par_iter()) | |
| .for_each(|(chunk, env)| { | |
| env.write_observation(db, chunk); | |
| }); | |
| Ok(done_indices) | |
| } | |
| // New: Zero-Copy get_observations | |
| fn get_observations<'py>(&self, _py: Python<'py>, out: &Bound<'py, PyArray2<f32>>) -> PyResult<()> { | |
| let db = &self.db.inner; | |
| let obs_dim = 320; | |
| let obs_slice = unsafe { out.as_slice_mut()? }; | |
| obs_slice.par_chunks_mut(obs_dim).zip(self.envs.par_iter()) | |
| .for_each(|(chunk, env)| { | |
| env.write_observation(db, chunk); | |
| }); | |
| Ok(()) | |
| } | |
| // New: Zero-Copy get_action_masks | |
| fn get_action_masks<'py>(&self, _py: Python<'py>, out: &Bound<'py, PyArray2<bool>>) -> PyResult<()> { | |
| let db = &self.db.inner; | |
| let action_dim = crate::core::logic::ACTION_SPACE; | |
| let mask_slice = unsafe { out.as_slice_mut()? }; | |
| mask_slice.par_chunks_mut(action_dim).zip(self.envs.par_iter()) | |
| .for_each(|(chunk, env)| { | |
| env.get_legal_actions_into(db, chunk); | |
| }); | |
| Ok(()) | |
| } | |
| } | |
| pub struct PyHybridMCTS { | |
| pub session: std::sync::Arc<std::sync::Mutex<ort::session::Session>>, | |
| pub neural_weight: f32, | |
| pub skip_rollout: bool, | |
| } | |
| impl PyHybridMCTS { | |
| fn new(model_path: &str, neural_weight: f32, skip_rollout: bool) -> PyResult<Self> { | |
| let session = ort::session::Session::builder() | |
| .map_err(|e: ort::Error| pyo3::exceptions::PyValueError::new_err(e.to_string()))? | |
| .commit_from_file(model_path) | |
| .map_err(|e: ort::Error| pyo3::exceptions::PyValueError::new_err(e.to_string()))?; | |
| Ok(Self { | |
| session: std::sync::Arc::new(std::sync::Mutex::new(session)), | |
| neural_weight, | |
| skip_rollout | |
| }) | |
| } | |
| fn get_suggestions(&mut self, game: &mut PyGameState, num_sims: usize, seconds: f32) -> Vec<(i32, f32, u32)> { | |
| let mut mcts = crate::core::mcts::HybridMCTS::new( | |
| self.session.clone(), | |
| self.neural_weight, | |
| self.skip_rollout | |
| ); | |
| mcts.get_suggestions(&game.inner, &game.db.inner, num_sims, seconds) | |
| } | |
| } | |
| pub fn register_python_module(m: &Bound<'_, PyModule>) -> PyResult<()> { | |
| m.add_class::<PyGameState>()?; | |
| m.add_class::<PyPlayerState>()?; | |
| m.add_class::<PyCardDatabase>()?; | |
| m.add_class::<PyVectorGameState>()?; | |
| m.add_class::<PyHybridMCTS>()?; | |
| m.add_class::<SearchHorizon>()?; | |
| m.add_class::<EvalMode>()?; | |
| Ok(()) | |
| } | |