# NOTE(review): extraction residue removed here — the original three lines read
# "Spaces:" / "Sleeping" / "Sleeping", which is not valid Python; presumably a
# lost file header from the source the file was extracted from.
from typing import List

import numpy as np

from ai.agents.agent_base import Agent
from engine.game.enums import Phase as PhaseEnum
from engine.game.game_state import GameState
try:
    from numba import njit

    HAS_NUMBA = True
except ImportError:
    HAS_NUMBA = False

    def njit(*args, **kwargs):
        """No-op fallback decorator when numba is missing.

        Supports both usages of the real ``numba.njit``:
        bare ``@njit`` (called with the function itself) and
        parameterized ``@njit(cache=True, ...)`` (called with options,
        returning a decorator). The previous fallback only handled the
        bare form and raised TypeError for the parameterized one.
        """
        if len(args) == 1 and callable(args[0]) and not kwargs:
            # Bare @njit: args[0] is the decorated function.
            return args[0]

        # Parameterized @njit(...): return a decorator that is a no-op.
        def _decorate(func):
            return func

        return _decorate
| def _check_meet_jit(hearts, req): | |
| """Greedy heart requirement check matching engine logic - Optimized.""" | |
| # 1. Match specific colors (0-5) | |
| needed_specific = req[:6] | |
| have_specific = hearts[:6] | |
| # Numba doesn't support np.minimum for arrays in all versions efficiently, doing manual element-wise | |
| used_specific = np.zeros(6, dtype=np.int32) | |
| for i in range(6): | |
| if needed_specific[i] < have_specific[i]: | |
| used_specific[i] = needed_specific[i] | |
| else: | |
| used_specific[i] = have_specific[i] | |
| remaining_req_0 = req[0] - used_specific[0] | |
| remaining_req_1 = req[1] - used_specific[1] | |
| remaining_req_2 = req[2] - used_specific[2] | |
| remaining_req_3 = req[3] - used_specific[3] | |
| remaining_req_4 = req[4] - used_specific[4] | |
| remaining_req_5 = req[5] - used_specific[5] | |
| temp_hearts_0 = hearts[0] - used_specific[0] | |
| temp_hearts_1 = hearts[1] - used_specific[1] | |
| temp_hearts_2 = hearts[2] - used_specific[2] | |
| temp_hearts_3 = hearts[3] - used_specific[3] | |
| temp_hearts_4 = hearts[4] - used_specific[4] | |
| temp_hearts_5 = hearts[5] - used_specific[5] | |
| # 2. Match Any requirement (index 6) with remaining specific hearts | |
| needed_any = req[6] | |
| have_any_from_specific = ( | |
| temp_hearts_0 + temp_hearts_1 + temp_hearts_2 + temp_hearts_3 + temp_hearts_4 + temp_hearts_5 | |
| ) | |
| used_any_from_specific = needed_any | |
| if have_any_from_specific < needed_any: | |
| used_any_from_specific = have_any_from_specific | |
| # 3. Match remaining Any with Any (Wildcard) hearts (index 6) | |
| needed_any -= used_any_from_specific | |
| have_wild = hearts[6] | |
| used_wild = needed_any | |
| if have_wild < needed_any: | |
| used_wild = have_wild | |
| # Check if satisfied | |
| if remaining_req_0 > 0: | |
| return False | |
| if remaining_req_1 > 0: | |
| return False | |
| if remaining_req_2 > 0: | |
| return False | |
| if remaining_req_3 > 0: | |
| return False | |
| if remaining_req_4 > 0: | |
| return False | |
| if remaining_req_5 > 0: | |
| return False | |
| if (needed_any - used_wild) > 0: | |
| return False | |
| return True | |
def _run_sampling_jit(stage_hearts, deck_ids, global_matrix, num_yells, total_req, samples):
    """Monte-Carlo estimate of the probability that drawing ``num_yells``
    cards from ``deck_ids`` (without replacement) on top of ``stage_hearts``
    satisfies ``total_req``.

    Args:
        stage_hearts: length-7 int array of hearts already on stage.
        deck_ids: int array of card Base IDs remaining in the deck.
        global_matrix: (MAX_ID+1, 7) int array; row ``id`` holds that card's hearts.
        num_yells: number of cards drawn per simulated trial (clamped to deck size).
        total_req: length-7 int array of required hearts (index 6 = "any").
        samples: number of Monte-Carlo trials.

    Returns:
        Estimated success probability in [0.0, 1.0].
    """
    success_count = 0
    deck_size = len(deck_ids)
    # Empty deck: the outcome is deterministic, no sampling needed.
    if deck_size == 0:
        # BUGFIX: previously returned float(samples) (e.g. 150.0) on success,
        # but every other path returns a probability in [0, 1] and callers
        # multiply the result as P(success).
        if _check_meet_jit(stage_hearts, total_req):
            return 1.0
        return 0.0
    sample_size = num_yells
    if sample_size > deck_size:
        sample_size = deck_size
    # Index array reused across trials: a partial Fisher-Yates pass starting
    # from ANY permutation still yields a uniformly random prefix, so no
    # re-initialization between samples is needed.
    indices = np.arange(deck_size)
    for _ in range(samples):
        # Partial Fisher-Yates shuffle: only the first sample_size slots matter.
        # np.random.randint's upper bound is exclusive -> j in [i, deck_size).
        for i in range(sample_size):
            j = np.random.randint(i, deck_size)
            temp = indices[i]
            indices[i] = indices[j]
            indices[j] = temp
        # Accumulate the drawn cards' hearts on top of the stage hearts.
        simulated_hearts = stage_hearts.copy()
        for k in range(sample_size):
            card_id = deck_ids[indices[k]]
            # Unrolled over the 7 heart types (keeps Numba codegen simple).
            simulated_hearts[0] += global_matrix[card_id, 0]
            simulated_hearts[1] += global_matrix[card_id, 1]
            simulated_hearts[2] += global_matrix[card_id, 2]
            simulated_hearts[3] += global_matrix[card_id, 3]
            simulated_hearts[4] += global_matrix[card_id, 4]
            simulated_hearts[5] += global_matrix[card_id, 5]
            simulated_hearts[6] += global_matrix[card_id, 6]
        if _check_meet_jit(simulated_hearts, total_req):
            success_count += 1
    return success_count / samples
class YellOddsCalculator:
    """
    Calculates the probability of completing a set of lives given a known (but unordered) deck.
    Optimized with Numba if available using Indirect Lookup.
    """

    def __init__(self, member_db, live_db):
        """
        Args:
            member_db: mapping of member base ID -> member object exposing a
                ``blade_hearts`` length-7 array, or None/empty.
            live_db: mapping of live base ID -> live object exposing
                ``required_hearts`` (length-7 array).
        """
        self.member_db = member_db
        self.live_db = live_db
        # Pre-compute a dense (MaxID + 1, 7) int32 heart matrix so sampling
        # can use direct integer indexing instead of per-card dict lookups.
        # BUGFIX: the original guarded max() against a falsy member_db but then
        # iterated member_db.items() unconditionally, crashing on None.
        if member_db:
            max_id = max(member_db.keys())
        else:
            max_id = 0
        self.global_heart_matrix = np.zeros((max_id + 1, 7), dtype=np.int32)
        if member_db:
            for mid, member in member_db.items():
                self.global_heart_matrix[mid] = member.blade_hearts.astype(np.int32)
        # Contiguous layout is required for efficient Numba typed-array access;
        # harmless (no copy) in the pure-Python fallback.
        self.global_heart_matrix = np.ascontiguousarray(self.global_heart_matrix)

    def calculate_odds(
        self, deck_cards: List[int], stage_hearts: np.ndarray, live_ids: List[int], num_yells: int, samples: int = 150
    ) -> float:
        """Estimate P(all ``live_ids`` requirements met) by sampling draws.

        Args:
            deck_cards: card IDs remaining in the deck (extra bits are masked off).
            stage_hearts: length-7 array of hearts already on stage.
            live_ids: pending live IDs; unknown IDs contribute no requirement.
            num_yells: number of cards drawn per simulated trial.
            samples: Monte-Carlo trial count.

        Returns:
            Probability in [0.0, 1.0]; 1.0 when there are no pending lives.
        """
        if not live_ids:
            return 1.0
        # Sum the heart requirements of all pending lives.
        total_req = np.zeros(7, dtype=np.int32)
        for live_id in live_ids:
            base_id = live_id & 0xFFFFF  # mask out extra bits -> Base ID
            if base_id in self.live_db:
                total_req += self.live_db[base_id].required_hearts
        # Convert the deck to base IDs only — no object lookups in the hot loop.
        # (List comprehension is fine for small N ~ 50.)
        deck_ids = np.array([c & 0xFFFFF for c in deck_cards], dtype=np.int32)
        # Normalize dtype/layout on BOTH paths. The original only did this in
        # the Numba branch, so the pure-Python fallback received whatever
        # dtype the caller passed (inconsistent behavior between paths).
        stage_hearts_c = np.ascontiguousarray(stage_hearts, dtype=np.int32)
        return _run_sampling_jit(stage_hearts_c, deck_ids, self.global_heart_matrix, num_yells, total_req, samples)

    def check_meet(self, hearts: np.ndarray, req: np.ndarray) -> bool:
        """Legacy wrapper for tests."""
        return _check_meet_jit(hearts, req)
class SearchProbAgent(Agent):
    """
    AI that uses Alpha-Beta search for decisions and sampling for probability.
    Optimizes for Expected Value (EV) = P(Success) * Score.
    """

    def __init__(self, depth=2, beam_width=5):
        self.depth = depth  # minimax search depth
        self.beam_width = beam_width  # reserved for beam pruning (unused here)
        self.calculator = None  # lazily-built YellOddsCalculator
        self._last_state_id = None
        self._action_cache = {}

    def get_calculator(self, state: GameState):
        """Lazily construct (and cache) the odds calculator from the state's DBs."""
        if self.calculator is None:
            self.calculator = YellOddsCalculator(state.member_db, state.live_db)
        return self.calculator

    def evaluate_state(self, state: GameState, player_id: int) -> float:
        """Heuristic evaluation of ``state`` from ``player_id``'s perspective.

        Terminal states score +/-10000; otherwise a weighted sum of completed
        lives, board presence, board value, expected live score, and resources.
        """
        if state.game_over:
            if state.winner == player_id:
                return 10000.0
            if state.winner >= 0:
                return -10000.0
            return 0.0  # no winner recorded (draw-like outcome)
        p = state.players[player_id]
        opp = state.players[1 - player_id]
        score = 0.0
        # 1. Guaranteed Score (Successful Lives) — own lives weighted above opponent's.
        score += len(p.success_lives) * 1000.0
        score -= len(opp.success_lives) * 800.0
        # 2. Board Presence (Members on Stage) - HIGH PRIORITY.
        # Negative slot values mean "empty" — presumably -1; verify against engine.
        stage_member_count = sum(1 for cid in p.stage if cid >= 0)
        score += stage_member_count * 150.0  # big bonus for having members on stage
        # 3. Board Value (Hearts and Blades from members on stage)
        total_blades = 0
        total_hearts = np.zeros(7, dtype=np.int32)
        for cid in p.stage:
            if cid >= 0:
                base_id = cid & 0xFFFFF  # mask extra bits -> Base ID
                if base_id in state.member_db:
                    member = state.member_db[base_id]
                    total_blades += member.blades
                    total_hearts += member.hearts
        score += total_blades * 80.0
        score += np.sum(total_hearts) * 40.0
        # 4. Expected Score from Pending Lives (needs at least one blade to yell)
        target_lives = list(p.live_zone)
        if target_lives and total_blades > 0:
            calc = self.get_calculator(state)
            prob = calc.calculate_odds(p.main_deck, total_hearts, target_lives, total_blades)
            potential_score = sum(
                state.live_db[lid & 0xFFFFF].score for lid in target_lives if (lid & 0xFFFFF) in state.live_db
            )
            score += prob * potential_score * 500.0
            if prob > 0.9:
                score += 500.0  # near-certain completion bonus
        # 5. Resources — diminishing returns on hand size to prevent hoarding
        hand_val = len(p.hand)
        if hand_val > 8:
            score += 80.0 + (hand_val - 8) * 1.0  # very small bonus beyond 8 cards
        else:
            score += hand_val * 10.0
        score += p.count_untapped_energy() * 10.0
        score -= len(opp.hand) * 5.0
        return score

    def choose_action(self, state: GameState, player_id: int) -> int:
        """Pick an action: trivial/random for simple phases, alpha-beta for MAIN/LIVE_SET."""
        legal_mask = state.get_legal_actions()
        legal_indices = np.where(legal_mask)[0]
        if len(legal_indices) == 1:
            return int(legal_indices[0])
        # Skip search for simple phases — a random legal move suffices.
        if state.phase not in (PhaseEnum.MAIN, PhaseEnum.LIVE_SET):
            return int(np.random.choice(legal_indices))
        # Alpha-Beta search over the main phase actions.
        best_action = legal_indices[0]
        best_val = -float("inf")
        alpha = -float("inf")
        beta = float("inf")
        # Limit branching factor for performance.
        candidates = list(legal_indices)
        if len(candidates) > 15:

            def action_priority(idx):
                """Lower value = searched earlier. Ranges mirror the engine's action encoding."""
                if 900 <= idx <= 902:
                    return -1  # Performance (highest priority)
                if 1 <= idx <= 180:
                    return 0  # Play Card
                if 400 <= idx <= 459:
                    return 1  # Live Set
                if 200 <= idx <= 202:
                    return 2  # Activate Ability
                if idx == 0:
                    return 5  # Pass (End Phase)
                return 10  # everything else (choices, target selection, ...)

            candidates.sort(key=action_priority)
            candidates = candidates[:15]
            # Always keep "pass" searchable so the agent can end the phase.
            if 0 not in candidates and 0 in legal_indices:
                candidates.append(0)
        for action in candidates:
            try:
                ns = state.copy()
                ns = ns.step(action)
                # Auto-resolve our own pending choices greedily during search.
                while ns.pending_choices and ns.current_player == player_id:
                    ns = ns.step(self._greedy_choice(ns))
                val = self._minimax(ns, self.depth - 1, alpha, beta, False, player_id)
                if val > best_val:
                    best_val = val
                    best_action = action
                alpha = max(alpha, val)
            except Exception:
                # A candidate that crashes the simulator is simply skipped.
                continue
        return int(best_action)

    def _minimax(
        self, state: GameState, depth: int, alpha: float, beta: float, is_max: bool, original_player: int
    ) -> float:
        """Alpha-beta minimax to ``depth`` plies, evaluating for ``original_player``.

        NOTE: ``is_max`` is kept for interface compatibility; the maximizing
        side is re-derived from ``state.current_player`` each ply, which also
        handles consecutive moves by the same player.
        """
        if depth == 0 or state.game_over:
            return self.evaluate_state(state, original_player)
        legal_mask = state.get_legal_actions()
        legal_indices = np.where(legal_mask)[0]
        # BUGFIX: was `not legal_indices.any()`, which tests the VALUES of the
        # index array — it was also True when the only legal action was index 0,
        # wrongly treating a live position as terminal. Test emptiness instead.
        if legal_indices.size == 0:
            return self.evaluate_state(state, original_player)
        current_is_max = state.current_player == original_player
        # Sample the branching factor down to 8 (+ "pass" if legal) for speed.
        candidates = list(legal_indices)
        if len(candidates) > 8:
            candidates = list(np.random.choice(legal_indices, 8, replace=False))
            if 0 in legal_indices and 0 not in candidates:
                candidates.append(0)
        if current_is_max:
            max_eval = -float("inf")
            for action in candidates:
                try:
                    ns = state.copy().step(action)
                    while ns.pending_choices and ns.current_player == state.current_player:
                        ns = ns.step(self._greedy_choice(ns))
                    child_val = self._minimax(ns, depth - 1, alpha, beta, False, original_player)
                    max_eval = max(max_eval, child_val)
                    alpha = max(alpha, child_val)
                    if beta <= alpha:
                        break  # beta cutoff
                except Exception:  # was a bare `except:` — don't swallow SystemExit/KeyboardInterrupt
                    continue
            return max_eval
        else:
            min_eval = float("inf")
            # Opponent's turn: full (shallow) minimax over the sampled candidates.
            for action in candidates:
                try:
                    ns = state.copy().step(action)
                    while ns.pending_choices and ns.current_player == state.current_player:
                        ns = ns.step(self._greedy_choice(ns))
                    child_val = self._minimax(ns, depth - 1, alpha, beta, True, original_player)
                    min_eval = min(min_eval, child_val)
                    beta = min(beta, child_val)
                    if beta <= alpha:
                        break  # alpha cutoff
                except Exception:  # was a bare `except:` — see above
                    continue
            return min_eval

    def _greedy_choice(self, state: GameState) -> int:
        """Fast greedy resolution for pending choices during search."""
        mask = state.get_legal_actions()
        indices = np.where(mask)[0]
        # BUGFIX: was `not indices.any()` — wrong emptiness test (see _minimax).
        if indices.size == 0:
            return 0
        # Simple priority could go here (keep high cost on mulligan, etc.);
        # for now, just pick the first valid action.
        return int(indices[0])