trioskosmos commited on
Commit
8d81185
·
verified ·
1 Parent(s): 1c22442

Upload ai/agents/search_prob_agent.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. ai/agents/search_prob_agent.py +407 -0
ai/agents/search_prob_agent.py ADDED
@@ -0,0 +1,407 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ import numpy as np
4
+
5
+ from ai.agents.agent_base import Agent
6
+ from engine.game.enums import Phase as PhaseEnum
7
+ from engine.game.game_state import GameState
8
+
9
try:
    from numba import njit

    HAS_NUMBA = True
except ImportError:
    HAS_NUMBA = False

    def njit(*args, **kwargs):
        """Stand-in for ``numba.njit`` used when numba is not installed.

        The decorated function is returned unchanged (no compilation).
        Both decorator spellings are accepted so call sites do not need to
        care whether numba is present:

        * bare form: ``@njit``
        * parameterised form: ``@njit(cache=True)`` / ``@njit("sig")``
        """
        # Bare usage: the single positional argument is the function itself.
        if len(args) == 1 and callable(args[0]) and not kwargs:
            return args[0]

        # Parameterised usage: options are ignored, hand back a pass-through
        # decorator.
        def decorate(func):
            return func

        return decorate
19
+
20
+
21
+ @njit
22
def _check_meet_jit(hearts, req):
    """Greedy test of whether ``hearts`` satisfies the requirement ``req``.

    Both arguments are read at indices 0..6.  Indices 0-5 are the specific
    colours; index 6 is the "any" requirement in ``req`` and the wildcard
    supply in ``hearts``.

    The check proceeds in three steps:

    1. every specific colour must be covered by the same colour in
       ``hearts`` (wildcards are NOT spent on specific colours);
    2. whatever specific hearts are left over are spent on the "any"
       requirement;
    3. the rest of the "any" requirement must be covered by wildcards.

    Returns True when all three steps succeed, False otherwise.
    """
    # Only scalar loops/arithmetic are used so the body stays compilable
    # when the function is decorated with numba's njit.
    spare_specific = 0
    for colour in range(6):
        have = int(hearts[colour])
        need = int(req[colour])
        if need > have:
            # A specific colour is short; nothing else can make up for it.
            return False
        spare_specific += have - need

    # Portion of the "any" requirement not absorbed by left-over colours.
    unmet_any = int(req[6]) - spare_specific
    if unmet_any < 0:
        unmet_any = 0

    # Wildcard hearts must cover whatever is still outstanding.
    return unmet_any <= int(hearts[6])
86
+
87
+
88
+ @njit
89
def _run_sampling_jit(stage_hearts, deck_ids, global_matrix, num_yells, total_req, samples):
    """Monte-Carlo estimate of the chance that a yell completes ``total_req``.

    Each trial draws ``min(num_yells, len(deck_ids))`` distinct deck
    positions with a partial Fisher-Yates shuffle, adds the matching rows of
    ``global_matrix`` to a copy of ``stage_hearts`` and tests the total with
    ``_check_meet_jit``.

    Args:
        stage_hearts: length-7 heart totals already available before drawing.
        deck_ids: array of card base IDs (ints), one entry per deck card.
        global_matrix: ``(MAX_ID + 1, 7)`` array; row ``i`` holds the hearts
            contributed by card ``i``.
        num_yells: number of cards revealed per trial.
        total_req: length-7 requirement vector.
        samples: number of trials to run.

    Returns:
        The fraction of successful trials, a float in ``[0.0, 1.0]``.
        With an empty deck the result is exactly 1.0 or 0.0; with
        ``samples <= 0`` (and a non-empty deck) it is 0.0.
    """
    deck_size = len(deck_ids)

    # Nothing to draw: the outcome is deterministic.  Report it as a
    # probability (1.0 / 0.0) rather than as a raw success count.
    if deck_size == 0:
        if _check_meet_jit(stage_hearts, total_req):
            return 1.0
        return 0.0

    # No trials requested: avoid dividing by zero below.
    if samples <= 0:
        return 0.0

    sample_size = num_yells
    if sample_size > deck_size:
        sample_size = deck_size

    num_rows = global_matrix.shape[0]

    # Position array reused across trials; a partial shuffle of an already
    # permuted array is still a uniform draw.
    indices = np.arange(deck_size)
    success_count = 0

    for _ in range(samples):
        # Fisher-Yates shuffle restricted to the first ``sample_size`` slots.
        for i in range(sample_size):
            j = np.random.randint(i, deck_size)
            temp = indices[i]
            indices[i] = indices[j]
            indices[j] = temp

        simulated_hearts = stage_hearts.copy()

        for k in range(sample_size):
            card_id = deck_ids[indices[k]]

            # A card whose ID has no row in the matrix still counts as a
            # drawn card but contributes no hearts.  Skipping it avoids an
            # out-of-range read, which compiled code would not trap.
            if card_id < 0 or card_id >= num_rows:
                continue

            for h in range(7):
                simulated_hearts[h] += global_matrix[card_id, h]

        if _check_meet_jit(simulated_hearts, total_req):
            success_count += 1

    return success_count / samples
142
+
143
+
144
class YellOddsCalculator:
    """
    Calculates the probability of completing a set of lives given a known (but unordered) deck.
    Optimized with Numba if available using Indirect Lookup.

    The constructor flattens ``member_db`` into ``global_heart_matrix`` so
    that the sampling loop can look hearts up by integer card ID instead of
    touching Python objects.
    """

    def __init__(self, member_db, live_db):
        """Store the databases and pre-compute the per-card heart matrix.

        Args:
            member_db: mapping of member base ID -> object exposing
                ``blade_hearts`` (length-7 array-like).  May be empty or
                ``None``, in which case the matrix is a single zero row.
            live_db: mapping of live base ID -> object exposing
                ``required_hearts``.
        """
        self.member_db = member_db
        self.live_db = live_db

        members = member_db if member_db else {}
        max_id = max(members, default=0)

        # Shape: (MaxID + 1, 7), int32, one row per possible base ID.
        # IDs without a member entry keep an all-zero row.
        matrix = np.zeros((max_id + 1, 7), dtype=np.int32)
        for mid, member in members.items():
            matrix[mid] = np.asarray(member.blade_hearts, dtype=np.int32)

        # Contiguous layout for the (optionally compiled) sampling routine.
        self.global_heart_matrix = np.ascontiguousarray(matrix)

    def calculate_odds(
        self, deck_cards: List[int], stage_hearts: np.ndarray, live_ids: List[int], num_yells: int, samples: int = 150
    ) -> float:
        """Estimate the probability that the given lives can be completed.

        Args:
            deck_cards: card IDs remaining in the deck; only the low 20 bits
                (``& 0xFFFFF``) are used as the base ID.
            stage_hearts: length-7 heart totals already on stage.
            live_ids: IDs of the lives to complete; their
                ``required_hearts`` are summed.  IDs missing from
                ``live_db`` contribute nothing.
            num_yells: number of deck cards revealed per sample.
            samples: number of random samples to draw.

        Returns:
            1.0 when ``live_ids`` is empty, otherwise the value returned by
            ``_run_sampling_jit``.
        """
        if not live_ids:
            return 1.0

        # Combined requirement over all target lives.
        live_db = self.live_db if self.live_db else {}
        total_req = np.zeros(7, dtype=np.int32)
        for live_id in live_ids:
            base_id = live_id & 0xFFFFF
            if base_id in live_db:
                total_req += live_db[base_id].required_hearts

        # Strip the extra bits so only base IDs reach the sampler; a list
        # comprehension is fine for deck-sized inputs.
        deck_ids = np.array([c & 0xFFFFF for c in deck_cards], dtype=np.int32)

        # Same int32, contiguous input whether or not numba is installed, so
        # both code paths compute on identical data.
        stage_hearts_c = np.ascontiguousarray(stage_hearts, dtype=np.int32)
        return _run_sampling_jit(stage_hearts_c, deck_ids, self.global_heart_matrix, num_yells, total_req, samples)

    def check_meet(self, hearts: np.ndarray, req: np.ndarray) -> bool:
        """Legacy wrapper for tests."""
        return _check_meet_jit(hearts, req)
202
+
203
+
204
class SearchProbAgent(Agent):
    """
    AI that uses Alpha-Beta search for decisions and sampling for probability.
    Optimizes for Expected Value (EV) = P(Success) * Score.
    """

    # Root: keep at most this many prioritised actions (pass is re-added).
    _ROOT_BRANCH_LIMIT = 15
    # Inner nodes: sample at most this many actions (pass is re-added).
    _NODE_BRANCH_LIMIT = 8

    def __init__(self, depth=2, beam_width=5):
        """Create the agent.

        Args:
            depth: search depth in plies, counted from the root action.
            beam_width: stored on the instance; not read by the search in
                this class.
        """
        self.depth = depth
        self.beam_width = beam_width
        self.calculator = None  # built lazily by get_calculator()
        self._last_state_id = None
        self._action_cache = {}

    def get_calculator(self, state: GameState):
        """Return the shared YellOddsCalculator, creating it on first use."""
        if self.calculator is None:
            self.calculator = YellOddsCalculator(state.member_db, state.live_db)
        return self.calculator

    def evaluate_state(self, state: GameState, player_id: int) -> float:
        """Heuristic value of ``state`` from ``player_id``'s point of view.

        Finished games score +10000 (win), -10000 (opponent win) or 0.
        Otherwise the score sums: completed lives, stage presence, blades
        and hearts on stage, the expected value of pending lives, and
        hand/energy resources.
        """
        if state.game_over:
            if state.winner == player_id:
                return 10000.0
            if state.winner >= 0:
                return -10000.0
            return 0.0

        p = state.players[player_id]
        opp = state.players[1 - player_id]

        score = 0.0

        # 1. Guaranteed Score (Successful Lives)
        score += len(p.success_lives) * 1000.0
        score -= len(opp.success_lives) * 800.0

        # 2. Board Presence (Members on Stage) - HIGH PRIORITY
        # 3. Board Value (Hearts and Blades from members on stage)
        stage_member_count = 0
        total_blades = 0
        total_hearts = np.zeros(7, dtype=np.int32)
        for cid in p.stage:
            if cid < 0:
                continue  # negative id is treated as an empty slot
            stage_member_count += 1
            base_id = cid & 0xFFFFF
            if base_id in state.member_db:
                member = state.member_db[base_id]
                total_blades += member.blades
                total_hearts += member.hearts

        score += stage_member_count * 150.0  # Big bonus for having members on stage
        score += total_blades * 80.0
        score += np.sum(total_hearts) * 40.0

        # 4. Expected Score from Pending Lives
        target_lives = list(p.live_zone)
        if target_lives and total_blades > 0:
            calc = self.get_calculator(state)
            prob = calc.calculate_odds(p.main_deck, total_hearts, target_lives, total_blades)
            potential_score = sum(
                state.live_db[lid & 0xFFFFF].score for lid in target_lives if (lid & 0xFFFFF) in state.live_db
            )
            score += prob * potential_score * 500.0
            if prob > 0.9:
                score += 500.0

        # 5. Resources
        # Diminishing returns for hand size to prevent hoarding
        hand_val = len(p.hand)
        if hand_val > 8:
            score += 80.0 + (hand_val - 8) * 1.0  # Very small bonus for cards beyond 8
        else:
            score += hand_val * 10.0

        score += p.count_untapped_energy() * 10.0
        score -= len(opp.hand) * 5.0

        return float(score)

    def choose_action(self, state: GameState, player_id: int) -> int:
        """Pick an action index for ``player_id`` in ``state``.

        A single legal action is returned directly.  Phases other than MAIN
        and LIVE_SET pick uniformly at random.  Otherwise each candidate is
        scored with ``_minimax`` and the best one is returned.  If there is
        no legal action at all, 0 is returned (same convention as
        ``_greedy_choice``).
        """
        legal_mask = state.get_legal_actions()
        legal_indices = np.where(legal_mask)[0]

        if legal_indices.size == 0:
            return 0
        if legal_indices.size == 1:
            return int(legal_indices[0])

        # Skip search for simple phases
        if state.phase not in (PhaseEnum.MAIN, PhaseEnum.LIVE_SET):
            return int(np.random.choice(legal_indices))

        # Alpha-Beta Search for Main Phase
        best_action = legal_indices[0]
        best_val = -float("inf")
        alpha = -float("inf")
        beta = float("inf")

        for action in self._root_candidates(legal_indices):
            try:
                ns = self._apply_action(state, action, player_id)
                val = self._minimax(ns, self.depth - 1, alpha, beta, False, player_id)
            except Exception:
                # Best effort: an action the engine cannot simulate is
                # simply not considered.
                continue

            if val > best_val:
                best_val = val
                best_action = action

            alpha = max(alpha, val)

        return int(best_action)

    @staticmethod
    def _root_priority(idx) -> int:
        """Sort key for root actions; lower values are searched first."""
        if 1 <= idx <= 180:
            return 0  # Play Card
        if 400 <= idx <= 459:
            return 1  # Live Set
        if 200 <= idx <= 202:
            return 2  # Activate Ability
        if idx == 0:
            return 5  # Pass (End Phase)
        if 900 <= idx <= 902:
            return -1  # Performance (High Priority)
        return 10  # Everything else (choices, target selection etc)

    def _root_candidates(self, legal_indices) -> list:
        """Limit the root branching factor, always keeping pass (0) if legal."""
        candidates = list(legal_indices)
        if len(candidates) > self._ROOT_BRANCH_LIMIT:
            candidates.sort(key=self._root_priority)
            candidates = candidates[: self._ROOT_BRANCH_LIMIT]
            if 0 not in candidates and 0 in legal_indices:
                candidates.append(0)
        return candidates

    def _apply_action(self, state: GameState, action, actor: int) -> GameState:
        """Step a copy of ``state`` and greedily resolve ``actor``'s pending choices."""
        ns = state.copy().step(action)
        while ns.pending_choices and ns.current_player == actor:
            ns = ns.step(self._greedy_choice(ns))
        return ns

    def _minimax(
        self, state: GameState, depth: int, alpha: float, beta: float, is_max: bool, original_player: int
    ) -> float:
        """Depth-limited alpha-beta search returning a value for ``original_player``.

        Whether a node maximises or minimises is decided by comparing
        ``state.current_player`` with ``original_player``; ``is_max`` is
        accepted for interface compatibility but not consulted.
        """
        # ``<=`` so a non-positive starting depth still terminates.
        if depth <= 0 or state.game_over:
            return self.evaluate_state(state, original_player)

        legal_mask = state.get_legal_actions()
        legal_indices = np.where(legal_mask)[0]
        # Test emptiness, not truthiness of the values: a lone pass action
        # (index 0) is still a legal move.
        if legal_indices.size == 0:
            return self.evaluate_state(state, original_player)

        actor = state.current_player
        maximizing = actor == original_player

        candidates = list(legal_indices)
        if len(candidates) > self._NODE_BRANCH_LIMIT:
            candidates = list(np.random.choice(legal_indices, self._NODE_BRANCH_LIMIT, replace=False))
            if 0 in legal_indices and 0 not in candidates:
                candidates.append(0)

        best = -float("inf") if maximizing else float("inf")
        evaluated = False

        for action in candidates:
            try:
                ns = self._apply_action(state, action, actor)
                value = self._minimax(ns, depth - 1, alpha, beta, not maximizing, original_player)
            except Exception:
                # Best effort: skip actions the engine cannot simulate.
                continue

            evaluated = True
            if maximizing:
                best = max(best, value)
                alpha = max(alpha, value)
            else:
                best = min(best, value)
                beta = min(beta, value)

            if beta <= alpha:
                break

        # If every candidate failed, fall back to the static evaluation
        # instead of leaking +/-inf into the parent node.
        if not evaluated:
            return self.evaluate_state(state, original_player)
        return best

    def _greedy_choice(self, state: GameState) -> int:
        """Fast greedy resolution for pending choices during search."""
        mask = state.get_legal_actions()
        indices = np.where(mask)[0]
        if indices.size == 0:
            return 0

        # Simple priority: 1. Keep high cost (if mulligan), 2. Target slot 1, etc.
        # For now, just pick the first valid action
        return int(indices[0])