trioskosmos committed on
Commit
9a00f62
·
verified ·
1 Parent(s): 8d81185

Upload ai/agents/super_heuristic.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. ai/agents/super_heuristic.py +310 -0
ai/agents/super_heuristic.py ADDED
@@ -0,0 +1,310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ import numpy as np
4
+
5
+ from ai.headless_runner import Agent
6
+ from engine.game.game_state import GameState, Phase
7
+
8
+
9
class SuperHeuristicAgent(Agent):
    """
    "Really smart" heuristic agent.

    Uses a depth-limited minimax search with alpha-beta pruning (plus random
    candidate sampling to bound the branching factor) during the MAIN phase,
    and dedicated greedy heuristics for the mulligan and LIVE_SET phases.
    All other phases fall back to the first legal action.
    """

    def __init__(self, depth=2, beam_width=3):
        # depth: plies of lookahead for the MAIN-phase minimax search.
        # beam_width: reserved for candidate-width tuning; the search
        # currently uses fixed caps (8 at the root, 5 in recursion).
        self.depth = depth
        self.beam_width = beam_width
        self.last_turn_num = -1

    def evaluate_state(self, state: GameState, player_id: int) -> float:
        """
        Score ``state`` from ``player_id``'s perspective.  Higher is better.

        Terminal states dominate everything.  Otherwise the score combines:
        completed lives, progress toward pending live requirements, board
        presence (blades / draw icons / raw hearts), hand and energy
        resources, and a small opponent-denial term.
        """
        if state.game_over:
            if state.winner == player_id:
                return 100000.0
            elif state.winner >= 0:
                return -100000.0
            else:
                return 0.0  # Draw

        p = state.players[player_id]
        opp = state.players[1 - player_id]

        score = 0.0

        # --- 1. Score Advantage ---
        my_score = len(p.success_lives)
        opp_score = len(opp.success_lives)
        # Completed lives dominate every other term so the agent plays to win.
        score += my_score * 50000.0
        score -= opp_score * 40000.0  # Slightly less penalty (aggressive play)

        # --- 2. Live Progress (the "closeness" to performing a live) ---
        stage_hearts = p.get_total_hearts(state.member_db)

        # Total heart requirement across all lives already set in the zone.
        # Index 0-5 are colors; index 6 is the "any color" requirement.
        pending_req = np.zeros(7, dtype=np.int32)
        for live_id in p.live_zone:
            if live_id in state.live_db:
                pending_req += state.live_db[live_id].required_hearts

        fulfilled_val = 0

        # Match colored hearts first.
        rem_hearts = stage_hearts.copy()
        rem_req = pending_req.copy()

        for c in range(6):
            matched = min(rem_hearts[c], rem_req[c])
            fulfilled_val += matched * 300  # VERY high value for matching needed colors
            rem_hearts[c] -= matched
            rem_req[c] -= matched

        # "Any" requirement is satisfied by whatever colored hearts are left.
        needed_any = rem_req[6] if len(rem_req) > 6 else 0
        avail_any = np.sum(rem_hearts)
        matched_any = min(avail_any, needed_any)
        fulfilled_val += matched_any * 200

        score += fulfilled_val

        # Penalize unmet requirements (distance to goal).
        unmet_hearts = np.sum(rem_req[:6]) + max(0, needed_any - avail_any)
        score -= unmet_hearts * 100

        # Huge bonus if a set live is fully payable right now.
        if unmet_hearts == 0 and len(p.live_zone) > 0:
            score += 5000.0

        # --- 3. Board Strength (secondary) ---
        stage_blades = 0
        stage_draws = 0
        stage_raw_hearts = 0

        for cid in p.stage:
            if cid in state.member_db:
                m = state.member_db[cid]
                stage_blades += m.blades
                stage_draws += m.draw_icons
                stage_raw_hearts += np.sum(m.hearts)

        score += stage_blades * 5
        score += stage_draws * 10
        score += stage_raw_hearts * 2  # Fulfilled requirements matter more

        # --- 4. Resources ---
        score += len(p.hand) * 10
        untapped_energy = p.count_untapped_energy()
        score += untapped_energy * 5

        # --- 5. Opponent Denial (simple) ---
        # We want the opponent to have fewer cards/resources.
        score -= len(opp.hand) * 5

        return score

    def choose_action(self, state: GameState, player_id: int) -> int:
        """
        Pick a legal action index for ``player_id`` in the current phase.

        Dispatches on phase: mulligan keeps low-cost members, LIVE_SET does a
        greedy value check, MAIN runs the minimax search, everything else
        takes the first legal action.  The result is always validated against
        the legal-action mask before being returned.
        """
        legal_mask = state.get_legal_actions()
        legal_indices = np.where(legal_mask)[0]
        if len(legal_indices) == 0:
            return 0
        if len(legal_indices) == 1:
            return int(legal_indices[0])

        chosen_action = None  # Will be set by phase logic

        # --- PHASE SPECIFIC LOGIC ---

        # 1. Mulligan: keep low-cost cards (cost <= 3).
        if state.phase in (Phase.MULLIGAN_P1, Phase.MULLIGAN_P2):
            p = state.players[player_id]
            if not hasattr(p, "mulligan_selection"):
                p.mulligan_selection = set()

            # Build the list of hand slots whose keep/discard mark must flip.
            # Action 300+i toggles the mark on hand slot i.
            to_toggle = []
            for i, card_id in enumerate(p.hand):
                should_keep = False
                if card_id in state.member_db:
                    member = state.member_db[card_id]
                    if member.cost <= 3:
                        should_keep = True

                is_marked = i in p.mulligan_selection
                if should_keep and is_marked:
                    to_toggle.append(300 + i)
                elif not should_keep and not is_marked:
                    to_toggle.append(300 + i)

            # Filter to only legal toggles.
            valid_toggles = [a for a in to_toggle if a in legal_indices]
            if valid_toggles:
                chosen_action = int(np.random.choice(valid_toggles))
            else:
                chosen_action = 0  # Confirm

        # 2. Live Set: greedy value check over playable lives (actions 400-459).
        elif state.phase == Phase.LIVE_SET:
            live_actions = [i for i in legal_indices if 400 <= i <= 459]
            if not live_actions:
                chosen_action = 0
            else:
                p = state.players[player_id]
                stage_hearts = p.get_total_hearts(state.member_db)

                pending_req = np.zeros(7, dtype=np.int32)
                for live_id in p.live_zone:
                    if live_id in state.live_db:
                        pending_req += state.live_db[live_id].required_hearts

                best_action = 0
                max_val = -100

                for action in live_actions:
                    hand_idx = action - 400
                    if hand_idx >= len(p.hand):
                        continue
                    card_id = p.hand[hand_idx]
                    if card_id not in state.live_db:
                        continue

                    live = state.live_db[card_id]
                    total_req = pending_req + live.required_hearts

                    # Count hearts we would still be missing after setting
                    # this live on top of everything already pending.
                    missing = 0
                    temp_hearts = stage_hearts.copy()
                    for c in range(6):
                        needed = total_req[c]
                        have = temp_hearts[c]
                        if have < needed:
                            missing += needed - have
                            temp_hearts[c] = 0
                        else:
                            temp_hearts[c] -= needed

                    needed_any = total_req[6] if len(total_req) > 6 else 0
                    avail_any = np.sum(temp_hearts)
                    if avail_any < needed_any:
                        missing += needed_any - avail_any

                    # Value = points gained minus a penalty for missing hearts.
                    score_val = live.score * 10
                    score_val -= missing * 5

                    if score_val > 0 and score_val > max_val:
                        max_val = score_val
                        best_action = action

                chosen_action = best_action if max_val > 0 else 0

        # 3. Main Phase: minimax search with alpha-beta pruning.
        elif state.phase == Phase.MAIN:
            # depth=2 means Me -> Opponent -> Eval; 3 would also see our
            # own follow-up, but is kept at 2 for performance.
            best_action = 0
            best_val = -float("inf")

            alpha = -float("inf")
            beta = float("inf")

            # legal_indices was computed above and the state has not changed,
            # so no recomputation is needed here.  Shuffle for variety, then
            # cap the branching factor for performance.
            candidates = list(legal_indices)
            random.shuffle(candidates)

            if len(candidates) > 8:
                candidates = candidates[:8]
                if 0 not in candidates and 0 in legal_indices:
                    candidates.append(0)  # Always consider passing

            for action in candidates:
                try:
                    # MAX node (this agent moves first).
                    ns = state.step(action)
                    val = self._minimax(ns, self.depth - 1, alpha, beta, player_id)
                except Exception:
                    # If simulation fails, treat the move as unplayable.
                    continue

                if val > best_val:
                    best_val = val
                    best_action = action

                alpha = max(alpha, val)
                if beta <= alpha:
                    break  # Prune

            chosen_action = int(best_action)

        # Fallback for other phases (ENERGY, DRAW, PERFORMANCE - usually auto).
        else:
            chosen_action = int(legal_indices[0])

        # --- FINAL VALIDATION ---
        # Ensure chosen_action is actually legal; otherwise take the first
        # legal action rather than submitting an illegal one.
        legal_set = set(legal_indices.tolist())
        if chosen_action is None or chosen_action not in legal_set:
            chosen_action = int(legal_indices[0])

        return int(chosen_action)

    def _minimax(self, state: GameState, depth: int, alpha: float, beta: float, maximize_player: int) -> float:
        """
        Depth-limited alpha-beta minimax from ``maximize_player``'s view.

        Candidate moves are randomly sampled down to 5 per node (pass/action 0
        is always kept when legal) to bound the branching factor.  Failed
        simulations are skipped; if every candidate fails, fall back to the
        static evaluation instead of returning +/-inf.
        """
        if depth == 0 or state.game_over:
            return self.evaluate_state(state, maximize_player)

        current_player = state.current_player
        is_maximizing = current_player == maximize_player

        legal_mask = state.get_legal_actions()
        legal_indices = np.where(legal_mask)[0]

        if len(legal_indices) == 0:
            return self.evaluate_state(state, maximize_player)

        # Move filtering for speed: sample at most 5 candidates.
        candidates = list(legal_indices)
        if len(candidates) > 5:
            indices = np.random.choice(legal_indices, 5, replace=False)
            candidates = list(indices)
            # Ensure pass is included if legal (often a safe fallback).
            if 0 in legal_indices and 0 not in candidates:
                candidates.append(0)

        best = -float("inf") if is_maximizing else float("inf")
        for action in candidates:
            try:
                ns = state.step(action)
                eval_val = self._minimax(ns, depth - 1, alpha, beta, maximize_player)
            except Exception:
                # Was a bare `except:`, which also swallowed KeyboardInterrupt
                # and SystemExit; skip only ordinary simulation failures.
                continue

            if is_maximizing:
                best = max(best, eval_val)
                alpha = max(alpha, eval_val)
            else:
                best = min(best, eval_val)
                beta = min(beta, eval_val)
            if beta <= alpha:
                break

        if best == float("inf") or best == -float("inf"):
            # Every candidate simulation raised; use the static evaluation
            # rather than propagating an infinite sentinel up the tree.
            return self.evaluate_state(state, maximize_player)
        return best