Spaces:
Sleeping
Sleeping
| use engine_rust::core::heuristics::OriginalHeuristic; | |
| use engine_rust::core::logic::{GameState, CardDatabase}; | |
| use rand::SeedableRng; | |
| use engine_rust::core::mcts::{SearchHorizon, MCTS}; | |
| use engine_rust::test_helpers::load_real_db; | |
| use rand::Rng; | |
/// Strategy that controls one seat in a benchmark match.
///
/// The type is `Copy` on purpose: `run_matchup` hands the same agent to
/// `get_action` by value on every turn and still reads it afterwards for the
/// report, which would be a use-after-move for a non-`Copy` enum.
/// (`Eq`/`Hash` are not derivable because of the `f32` payload.)
#[derive(Debug, Clone, Copy, PartialEq)]
enum Agent {
    /// Picks one of the legal actions at random.
    Random,
    /// AlphaZero heuristic with a 1-move lookahead.
    Greedy,
    /// Full AlphaZero MCTS; the payload is the search timeout in seconds.
    Mcts(f32),
}

impl Agent {
    /// Returns the human-readable label printed in match banners and results.
    fn name(&self) -> String {
        match self {
            Agent::Random => "Random".to_string(),
            Agent::Greedy => "AlphaZero Greedy (1-Move)".to_string(),
            Agent::Mcts(t) => format!("AlphaZero MCTS ({}s)", t),
        }
    }
}
| fn get_action<R: Rng>( | |
| agent: Agent, | |
| state: &GameState, | |
| db: &CardDatabase, | |
| mcts: &mut MCTS, | |
| heuristic: &OriginalHeuristic, | |
| rng: &mut R, | |
| ) -> i32 { | |
| let p_idx = state.current_player as usize; | |
| let mut mask = vec![false; engine_rust::core::logic::ACTION_SPACE]; | |
| state.get_legal_actions_into(db, p_idx, &mut mask); | |
| let valid_actions: Vec<i32> = mask | |
| .iter() | |
| .enumerate() | |
| .filter_map(|(i, &b)| if b { Some(i as i32) } else { None }) | |
| .collect(); | |
| if valid_actions.is_empty() { | |
| return 0; | |
| } | |
| if valid_actions.len() == 1 { | |
| return valid_actions[0]; | |
| } | |
| match agent { | |
| Agent::Random => { | |
| let idx = rng.random_range(0..valid_actions.len()); | |
| valid_actions[idx] | |
| } | |
| Agent::Greedy => { | |
| // AlphaZero Heuristic evaluated perfectly at Depth 1 | |
| let mut best_action = valid_actions[0]; | |
| // To maximize our win chance, we want the highest heuristic score from *our* perspective. | |
| // But state.step might change turns. Heuristic is typically absolute P0 win rate, | |
| // so we must invert it if it's P1's turn. | |
| let mut best_score = f32::NEG_INFINITY; | |
| for &action in &valid_actions { | |
| let mut next_sim = state.clone(); | |
| let _ = next_sim.step(db, action); | |
| let mut eval = if next_sim.is_terminal() { | |
| match next_sim.get_winner() { | |
| w if w == p_idx as i32 => 1.0, | |
| w if w == 1 - p_idx as i32 => 0.0, | |
| _ => 0.5, | |
| } | |
| } else { | |
| use engine_rust::core::heuristics::Heuristic; | |
| let h_val = heuristic.evaluate( | |
| &next_sim, | |
| db, | |
| state.players[0].score, | |
| state.players[1].score, | |
| engine_rust::core::heuristics::EvalMode::Normal, | |
| None, | |
| None, | |
| ); | |
| if p_idx == 1 { | |
| 1.0 - h_val | |
| } else { | |
| h_val | |
| } // Convert P0 win rate to Current Player win rate | |
| }; | |
| // Tiebreaker: avoid passing instantly unless it's strictly better | |
| if action == 0 { | |
| eval -= 0.0001; | |
| } | |
| if eval > best_score { | |
| best_score = eval; | |
| best_action = action; | |
| } | |
| } | |
| best_action | |
| } | |
| Agent::Mcts(timeout) => { | |
| let (stats, _) = | |
| mcts.search(state, db, 0, timeout, SearchHorizon::GameEnd(), heuristic); | |
| if stats.is_empty() { | |
| valid_actions[0] | |
| } else { | |
| stats[0].0 | |
| } | |
| } | |
| } | |
| } | |
| fn parse_deck(path: &str, db: &CardDatabase) -> Vec<i32> { | |
| let mut main_deck = Vec::new(); | |
| if let Ok(content) = std::fs::read_to_string(path) { | |
| for line in content.lines() { | |
| let line = line.trim(); | |
| if line.is_empty() || line.starts_with('#') { | |
| continue; | |
| } | |
| let parts: Vec<&str> = line.split('x').map(|s| s.trim()).collect(); | |
| let no = parts[0]; | |
| let count = if parts.len() > 1 { | |
| parts[1].parse::<usize>().unwrap_or(1) | |
| } else { | |
| 1 | |
| }; | |
| if let Some(&id) = db.card_no_to_id.get(no) { | |
| for _ in 0..count { | |
| main_deck.push(id); | |
| } | |
| } | |
| } | |
| } | |
| main_deck | |
| } | |
| fn run_matchup(agent0: Agent, agent1: Agent, num_games: usize) { | |
| println!("\n======================================================="); | |
| println!("Face-off: {} (P0) vs {} (P1)", agent0.name(), agent1.name()); | |
| println!("======================================================="); | |
| let db = load_real_db(); | |
| let deck_path = if std::path::Path::new("ai/decks/muse_cup.txt").exists() { | |
| "ai/decks/muse_cup.txt" | |
| } else { | |
| "../ai/decks/muse_cup.txt" | |
| }; | |
| let mut p_main = parse_deck(deck_path, &db); | |
| if p_main.is_empty() { | |
| println!("Warning: Could not parse deck, using fallback."); | |
| p_main = db.members.keys().take(48).cloned().collect(); | |
| let mut fallback_lives: Vec<i32> = db.lives.keys().take(12).cloned().collect(); | |
| p_main.append(&mut fallback_lives); | |
| } | |
| let energy_ids: Vec<i32> = db.energy_db.keys().take(10).cloned().collect(); | |
| let mut wins_0 = 0; | |
| let mut wins_1 = 0; | |
| let mut draws = 0; | |
| for _ in 0..num_games { | |
| let mut sim = GameState::default(); | |
| sim.initialize_game( | |
| p_main.clone(), | |
| p_main.clone(), | |
| energy_ids.clone(), | |
| energy_ids.clone(), | |
| Vec::new(), | |
| Vec::new(), | |
| ); | |
| sim.ui.silent = true; | |
| sim.phase = engine_rust::core::enums::Phase::Main; // Use absolute enum path | |
| let mut mcts = MCTS::new(); | |
| let heuristic = OriginalHeuristic::default(); | |
| let mut steps = 0; | |
| let mut rng = rand::rngs::SmallRng::from_os_rng(); | |
| while !sim.is_terminal() && steps < 500 { | |
| let action = if sim.current_player == 0 { | |
| get_action(agent0, &sim, &db, &mut mcts, &heuristic, &mut rng) | |
| } else { | |
| get_action(agent1, &sim, &db, &mut mcts, &heuristic, &mut rng) | |
| }; | |
| let _ = sim.step(&db, action); | |
| if steps % 50 == 0 { | |
| print!("."); | |
| use std::io::Write; | |
| std::io::stdout().flush().unwrap(); | |
| } | |
| steps += 1; | |
| } | |
| let winner = sim.get_winner(); | |
| match winner { | |
| 0 => { | |
| wins_0 += 1; | |
| } | |
| 1 => { | |
| wins_1 += 1; | |
| } | |
| _ => { | |
| draws += 1; | |
| } | |
| } | |
| } | |
| println!( | |
| "\n--- Results for {} vs {} ---", | |
| agent0.name(), | |
| agent1.name() | |
| ); | |
| println!("Total Games: {}", num_games); | |
| println!("{} (P0) Wins: {}", agent0.name(), wins_0); | |
| println!("{} (P1) Wins: {}", agent1.name(), wins_1); | |
| println!("Draws: {}", draws); | |
| } | |
| fn main() { | |
| let num_games = 10; | |
| // As requested: | |
| // 1. Random vs AlphaZero (0.1s) | |
| run_matchup(Agent::Random, Agent::Mcts(0.1), num_games); | |
| // 2. Greedy vs AlphaZero (0.01s) | |
| run_matchup(Agent::Greedy, Agent::Mcts(0.01), num_games); | |
| // 3. AlphaZero (0.01s) vs AlphaZero (0.01s) | |
| run_matchup(Agent::Mcts(0.01), Agent::Mcts(0.01), num_games); | |
| } |