LovecaSim / ai /agents /search_prob_agent.py
trioskosmos's picture
Upload ai/agents/search_prob_agent.py with huggingface_hub
8d81185 verified
from typing import List
import numpy as np
from ai.agents.agent_base import Agent
from engine.game.enums import Phase as PhaseEnum
from engine.game.game_state import GameState
try:
from numba import njit
HAS_NUMBA = True
except ImportError:
HAS_NUMBA = False
# Mock njit decorator if numba is missing
def njit(f):
return f
@njit
def _check_meet_jit(hearts, req):
"""Greedy heart requirement check matching engine logic - Optimized."""
# 1. Match specific colors (0-5)
needed_specific = req[:6]
have_specific = hearts[:6]
# Numba doesn't support np.minimum for arrays in all versions efficiently, doing manual element-wise
used_specific = np.zeros(6, dtype=np.int32)
for i in range(6):
if needed_specific[i] < have_specific[i]:
used_specific[i] = needed_specific[i]
else:
used_specific[i] = have_specific[i]
remaining_req_0 = req[0] - used_specific[0]
remaining_req_1 = req[1] - used_specific[1]
remaining_req_2 = req[2] - used_specific[2]
remaining_req_3 = req[3] - used_specific[3]
remaining_req_4 = req[4] - used_specific[4]
remaining_req_5 = req[5] - used_specific[5]
temp_hearts_0 = hearts[0] - used_specific[0]
temp_hearts_1 = hearts[1] - used_specific[1]
temp_hearts_2 = hearts[2] - used_specific[2]
temp_hearts_3 = hearts[3] - used_specific[3]
temp_hearts_4 = hearts[4] - used_specific[4]
temp_hearts_5 = hearts[5] - used_specific[5]
# 2. Match Any requirement (index 6) with remaining specific hearts
needed_any = req[6]
have_any_from_specific = (
temp_hearts_0 + temp_hearts_1 + temp_hearts_2 + temp_hearts_3 + temp_hearts_4 + temp_hearts_5
)
used_any_from_specific = needed_any
if have_any_from_specific < needed_any:
used_any_from_specific = have_any_from_specific
# 3. Match remaining Any with Any (Wildcard) hearts (index 6)
needed_any -= used_any_from_specific
have_wild = hearts[6]
used_wild = needed_any
if have_wild < needed_any:
used_wild = have_wild
# Check if satisfied
if remaining_req_0 > 0:
return False
if remaining_req_1 > 0:
return False
if remaining_req_2 > 0:
return False
if remaining_req_3 > 0:
return False
if remaining_req_4 > 0:
return False
if remaining_req_5 > 0:
return False
if (needed_any - used_wild) > 0:
return False
return True
@njit
def _run_sampling_jit(stage_hearts, deck_ids, global_matrix, num_yells, total_req, samples):
# deck_ids: array of card Base IDs (ints)
# global_matrix: (MAX_ID+1, 7) array of hearts
success_count = 0
deck_size = len(deck_ids)
# Fix for empty deck case
if deck_size == 0:
if _check_meet_jit(stage_hearts, total_req):
return float(samples)
else:
return 0.0
sample_size = num_yells
if sample_size > deck_size:
sample_size = deck_size
# Create an index array for shuffling
indices = np.arange(deck_size)
for _ in range(samples):
# Fisher-Yates shuffle for first N elements
# Reuse existing indices array logic
for i in range(sample_size):
j = np.random.randint(i, deck_size)
# Swap
temp = indices[i]
indices[i] = indices[j]
indices[j] = temp
# Sum selected hearts using indirect lookup
simulated_hearts = stage_hearts.copy()
for k in range(sample_size):
idx = indices[k]
card_id = deck_ids[idx]
# Simple bounds check if needed, but assuming valid IDs
# Numba handles array access fast
# Unrolling 7 heart types
simulated_hearts[0] += global_matrix[card_id, 0]
simulated_hearts[1] += global_matrix[card_id, 1]
simulated_hearts[2] += global_matrix[card_id, 2]
simulated_hearts[3] += global_matrix[card_id, 3]
simulated_hearts[4] += global_matrix[card_id, 4]
simulated_hearts[5] += global_matrix[card_id, 5]
simulated_hearts[6] += global_matrix[card_id, 6]
if _check_meet_jit(simulated_hearts, total_req):
success_count += 1
return success_count / samples
class YellOddsCalculator:
"""
Calculates the probability of completing a set of lives given a known (but unordered) deck.
Optimized with Numba if available using Indirect Lookup.
"""
def __init__(self, member_db, live_db):
self.member_db = member_db
self.live_db = live_db
# Pre-compute global heart matrix for fast lookup
if self.member_db:
max_id = max(self.member_db.keys())
else:
max_id = 0
# Shape: (MaxID + 1, 7)
# We need to ensure it's contiguous and int32
self.global_heart_matrix = np.zeros((max_id + 1, 7), dtype=np.int32)
for mid, member in self.member_db.items():
self.global_heart_matrix[mid] = member.blade_hearts.astype(np.int32)
# Ensure it's ready for Numba
if HAS_NUMBA:
self.global_heart_matrix = np.ascontiguousarray(self.global_heart_matrix)
def calculate_odds(
self, deck_cards: List[int], stage_hearts: np.ndarray, live_ids: List[int], num_yells: int, samples: int = 150
) -> float:
if not live_ids:
return 1.0
# Pre-calculate requirements
total_req = np.zeros(7, dtype=np.int32)
for live_id in live_ids:
base_id = live_id & 0xFFFFF
if base_id in self.live_db:
total_req += self.live_db[base_id].required_hearts
# Optimization: Just convert deck to IDs. No object lookups.
# Mask out extra bits to get Base ID
# Vectorized operation if deck_cards was numpy, but it's list.
# List comprehension is reasonably fast for small N (~50).
deck_ids_list = [c & 0xFFFFF for c in deck_cards]
deck_ids = np.array(deck_ids_list, dtype=np.int32)
# Use JITted function
if HAS_NUMBA:
# Ensure contiguous arrays
stage_hearts_c = np.ascontiguousarray(stage_hearts, dtype=np.int32)
return _run_sampling_jit(stage_hearts_c, deck_ids, self.global_heart_matrix, num_yells, total_req, samples)
else:
return _run_sampling_jit(stage_hearts, deck_ids, self.global_heart_matrix, num_yells, total_req, samples)
def check_meet(self, hearts: np.ndarray, req: np.ndarray) -> bool:
"""Legacy wrapper for tests."""
return _check_meet_jit(hearts, req)
class SearchProbAgent(Agent):
"""
AI that uses Alpha-Beta search for decisions and sampling for probability.
Optimizes for Expected Value (EV) = P(Success) * Score.
"""
def __init__(self, depth=2, beam_width=5):
self.depth = depth
self.beam_width = beam_width
self.calculator = None
self._last_state_id = None
self._action_cache = {}
def get_calculator(self, state: GameState):
if self.calculator is None:
self.calculator = YellOddsCalculator(state.member_db, state.live_db)
return self.calculator
def evaluate_state(self, state: GameState, player_id: int) -> float:
if state.game_over:
if state.winner == player_id:
return 10000.0
if state.winner >= 0:
return -10000.0
return 0.0
p = state.players[player_id]
opp = state.players[1 - player_id]
score = 0.0
# 1. Guaranteed Score (Successful Lives)
score += len(p.success_lives) * 1000.0
score -= len(opp.success_lives) * 800.0
# 2. Board Presence (Members on Stage) - HIGH PRIORITY
stage_member_count = sum(1 for cid in p.stage if cid >= 0)
score += stage_member_count * 150.0 # Big bonus for having members on stage
# 3. Board Value (Hearts and Blades from members on stage)
total_blades = 0
total_hearts = np.zeros(7, dtype=np.int32)
for i, cid in enumerate(p.stage):
if cid >= 0:
base_id = cid & 0xFFFFF
if base_id in state.member_db:
member = state.member_db[base_id]
total_blades += member.blades
total_hearts += member.hearts
score += total_blades * 80.0
score += np.sum(total_hearts) * 40.0
# 4. Expected Score from Pending Lives
target_lives = list(p.live_zone)
if target_lives and total_blades > 0:
calc = self.get_calculator(state)
prob = calc.calculate_odds(p.main_deck, total_hearts, target_lives, total_blades)
potential_score = sum(
state.live_db[lid & 0xFFFFF].score for lid in target_lives if (lid & 0xFFFFF) in state.live_db
)
score += prob * potential_score * 500.0
if prob > 0.9:
score += 500.0
# 5. Resources
# Diminishing returns for hand size to prevent hoarding
hand_val = len(p.hand)
if hand_val > 8:
score += 80.0 + (hand_val - 8) * 1.0 # Very small bonus for cards beyond 8
else:
score += hand_val * 10.0
score += p.count_untapped_energy() * 10.0
score -= len(opp.hand) * 5.0
return score
def choose_action(self, state: GameState, player_id: int) -> int:
legal_mask = state.get_legal_actions()
legal_indices = np.where(legal_mask)[0]
if len(legal_indices) == 1:
return int(legal_indices[0])
# Skip search for simple phases
if state.phase not in (PhaseEnum.MAIN, PhaseEnum.LIVE_SET):
return int(np.random.choice(legal_indices))
# Alpha-Beta Search for Main Phase
best_action = legal_indices[0]
best_val = -float("inf")
alpha = -float("inf")
beta = float("inf")
# Limit branching factor for performance
candidates = list(legal_indices)
if len(candidates) > 15:
# Better heuristic: prioritize Play/Live/Activate over others
def action_priority(idx):
if 1 <= idx <= 180:
return 0 # Play Card
if 400 <= idx <= 459:
return 1 # Live Set
if 200 <= idx <= 202:
return 2 # Activate Ability
if idx == 0:
return 5 # Pass (End Phase)
if 900 <= idx <= 902:
return -1 # Performance (High Priority)
return 10 # Everything else (choices, target selection etc)
candidates.sort(key=action_priority)
candidates = candidates[:15]
if 0 not in candidates and 0 in legal_indices:
candidates.append(0)
for action in candidates:
try:
ns = state.copy()
ns = ns.step(action)
while ns.pending_choices and ns.current_player == player_id:
ns = ns.step(self._greedy_choice(ns))
val = self._minimax(ns, self.depth - 1, alpha, beta, False, player_id)
if val > best_val:
best_val = val
best_action = action
alpha = max(alpha, val)
except Exception:
continue
return int(best_action)
def _minimax(
self, state: GameState, depth: int, alpha: float, beta: float, is_max: bool, original_player: int
) -> float:
if depth == 0 or state.game_over:
return self.evaluate_state(state, original_player)
legal_mask = state.get_legal_actions()
legal_indices = np.where(legal_mask)[0]
if not legal_indices.any():
return self.evaluate_state(state, original_player)
# Optimization: Only search if it's still original player's turn or transition
# If it's opponent's turn, we can either do a full minimax or just use a fixed heuristic
# for their move. Let's do simple minimax.
current_is_max = state.current_player == original_player
candidates = list(legal_indices)
if len(candidates) > 8:
indices = np.random.choice(legal_indices, 8, replace=False)
candidates = list(indices)
if 0 in legal_indices and 0 not in candidates:
candidates.append(0)
if current_is_max:
max_eval = -float("inf")
for action in candidates:
try:
ns = state.copy().step(action)
while ns.pending_choices and ns.current_player == state.current_player:
ns = ns.step(self._greedy_choice(ns))
eval = self._minimax(ns, depth - 1, alpha, beta, False, original_player)
max_eval = max(max_eval, eval)
alpha = max(alpha, eval)
if beta <= alpha:
break
except:
continue
return max_eval
else:
min_eval = float("inf")
# For simplicity, if it's opponent's turn, maybe just assume they pass if we are deep enough
# or use a very shallow search.
for action in candidates:
try:
ns = state.copy().step(action)
while ns.pending_choices and ns.current_player == state.current_player:
ns = ns.step(self._greedy_choice(ns))
eval = self._minimax(ns, depth - 1, alpha, beta, True, original_player)
min_eval = min(min_eval, eval)
beta = min(beta, eval)
if beta <= alpha:
break
except:
continue
return min_eval
def _greedy_choice(self, state: GameState) -> int:
"""Fast greedy resolution for pending choices during search."""
mask = state.get_legal_actions()
indices = np.where(mask)[0]
if not indices.any():
return 0
# Simple priority: 1. Keep high cost (if mulligan), 2. Target slot 1, etc.
# For now, just pick the first valid action
return int(indices[0])