trioskosmos committed on
Commit
69c4849
·
verified ·
1 Parent(s): 23592d5

Upload ai/headless_runner.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. ai/headless_runner.py +927 -0
ai/headless_runner.py ADDED
@@ -0,0 +1,927 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+ import os
4
+ import random
5
+ import sys
6
+ import time
7
+
8
+ import numpy as np
9
+
10
+ # Add parent dir to path
11
+ # Add parent dir to path (for ai directory)
12
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
13
+ # Add engine directory
14
+ # Add project root directory
15
+ sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
16
+
17
+ from ai.agents.agent_base import Agent
18
+ from ai.agents.search_prob_agent import SearchProbAgent
19
+ from engine.game.data_loader import CardDataLoader
20
+ from engine.game.game_state import GameState, Phase
21
+
22
+
23
class TrueRandomAgent(Agent):
    """Agent that samples uniformly among the legal actions, with no heuristics."""

    def choose_action(self, state: GameState, player_id: int) -> int:
        """Return a uniformly random legal action index, or 0 if none is legal."""
        candidates = np.nonzero(state.get_legal_actions())[0]
        if candidates.size == 0:
            return 0
        return int(np.random.choice(candidates))
32
+
33
+
34
class RandomAgent(Agent):
    """Mostly-random agent nudged by a few phase-specific rules."""

    def choose_action(self, state: GameState, player_id: int) -> int:
        """Pick a legal action, preferring useful actions over passing (action 0)."""
        legal = np.nonzero(state.get_legal_actions())[0]
        if legal.size == 0:
            return 0

        def within(low, high):
            # Legal actions whose index lies in the inclusive range [low, high].
            return legal[(legal >= low) & (legal <= high)]

        phase = state.phase

        # Mulligan: confirm (0) 30% of the time, otherwise toggle a card (300-359).
        if phase in (Phase.MULLIGAN_P1, Phase.MULLIGAN_P2):
            if random.random() < 0.3:
                return 0
            toggles = within(300, 359)
            return int(np.random.choice(toggles)) if toggles.size else 0

        # LIVE_SET: always set a live card (400-459) when one can be set.
        if phase == Phase.LIVE_SET:
            live_sets = within(400, 459)
            if live_sets.size:
                return int(np.random.choice(live_sets))

        # MAIN: play a card (1-180) 80% of the time when a play is available.
        if phase == Phase.MAIN:
            plays = within(1, 180)
            if plays.size and random.random() < 0.8:
                return int(np.random.choice(plays))

        # Otherwise never pass while any other action is legal.
        anything_but_pass = legal[legal != 0]
        if anything_but_pass.size:
            return int(np.random.choice(anything_but_pass))
        return 0
73
+
74
+
75
class SmartHeuristicAgent(Agent):
    """Rule-based agent with separate heuristics per game phase.

    Action index ranges, as used by the code below:
      * 0        -- pass / confirm
      * 1-180    -- play hand card ``(a - 1) // 3`` into stage area ``(a - 1) % 3``
      * 200-202  -- activate the ability of the member in stage area ``a - 200``
      * 300-359  -- toggle the mulligan mark of hand card ``a - 300``
      * 400-459  -- set hand card ``a - 400`` as a live card
    """

    def __init__(self):
        # Turn number the repetition counters below belong to.
        self.last_turn_num = -1
        # action index -> number of times it was chosen as an ability this turn.
        self.turn_action_counts = {}

    def choose_action(self, state: GameState, player_id: int) -> int:
        """Return the action index chosen for ``player_id`` in ``state``.

        Returns 0 when no action is legal.
        """
        # Loop protection: repetition counters are only meaningful within one turn.
        if state.turn_number != self.last_turn_num:
            self.last_turn_num = state.turn_number
            self.turn_action_counts = {}

        legal_mask = state.get_legal_actions()
        legal_indices = np.where(legal_mask)[0]
        if len(legal_indices) == 0:
            return 0

        p = state.players[player_id]

        if state.phase in (Phase.MULLIGAN_P1, Phase.MULLIGAN_P2):
            return self._choose_mulligan(state, p, legal_indices)
        if state.phase == Phase.LIVE_SET:
            return self._choose_live_set(state, p, legal_indices)
        if state.phase == Phase.MAIN:
            return self._choose_main(state, p, player_id, legal_indices)

        # Any other phase: pick a random non-pass action if one exists.
        non_pass = [i for i in legal_indices if i != 0]
        if non_pass:
            return int(np.random.choice(non_pass))
        return int(legal_indices[0])

    def _choose_mulligan(self, state, p, legal_indices):
        """Toggle hand cards until only members with cost <= 3 are kept, then confirm."""
        # The player object may not carry a selection set yet; create it lazily.
        if not hasattr(p, "mulligan_selection"):
            p.mulligan_selection = set()

        to_toggle = []
        for i, card_id in enumerate(p.hand):
            should_keep = card_id in state.member_db and state.member_db[card_id].cost <= 3
            is_marked = i in p.mulligan_selection
            # Toggle when the current mark disagrees with what we want:
            # unmark a keepable card, or mark a card we want to return.
            if should_keep == is_marked:
                to_toggle.append(300 + i)

        if to_toggle:
            legal_set = set(legal_indices.tolist())
            valid_toggles = [a for a in to_toggle if a in legal_set]
            if valid_toggles:
                return int(np.random.choice(valid_toggles))
        return 0  # Confirm

    def _choose_live_set(self, state, p, legal_indices):
        """Set the most valuable live card whose heart requirements are already met."""
        live_actions = [i for i in legal_indices if 400 <= i <= 459]
        if not live_actions:
            return 0  # Pass

        current_hearts = p.get_total_hearts(state.member_db)

        # Requirements of the live cards that are already in the live zone.
        pending_req = np.zeros(7, dtype=np.int32)
        for live_id in p.live_zone:
            if live_id in state.live_db:
                pending_req += state.live_db[live_id].required_hearts

        best_action = -1
        max_value = -1

        for action in live_actions:
            hand_idx = action - 400
            # Guard against an action index that points past the hand.
            if hand_idx >= len(p.hand):
                continue
            card_id = p.hand[hand_idx]
            if card_id not in state.live_db:
                continue

            live = state.live_db[card_id]
            total_req = pending_req + live.required_hearts
            have = current_hearts.copy()

            # 1. Every one of the six colour requirements must be covered.
            if any(have[c] < total_req[c] for c in range(6)):
                continue
            for c in range(6):
                have[c] -= total_req[c]

            # 2. Leftover hearts must cover the "any" requirement (index 6).
            if np.sum(have) < total_req[6]:
                continue

            # Score dominates; spare hearts act as a tie-breaker.
            value = live.score * 10
            value += np.sum(have) - total_req[6]

            if value > max_value:
                max_value = value
                best_action = action

        if best_action != -1:
            return int(best_action)
        return 0  # Pass if no safe plays

    def _choose_main(self, state, p, player_id, legal_indices):
        """Prefer a useful ability activation, then the best member play, then pass."""
        best_ability_action, best_ability_score = self._best_ability(state, p, player_id, legal_indices)
        best_play_action = self._best_play(state, p, legal_indices)

        if best_ability_score > 0:
            self.turn_action_counts[best_ability_action] = self.turn_action_counts.get(best_ability_action, 0) + 1
            return int(best_ability_action)

        if best_play_action != -1:
            return int(best_play_action)

        # Pass, but only if passing is legal.
        if 0 in legal_indices:
            return 0
        return int(legal_indices[0])  # Fallback to first legal

    def _best_ability(self, state, p, player_id, legal_indices):
        """Score each activatable ability with a one-step lookahead.

        Returns ``(action, score)``; ``(-1, -1)`` when nothing was scored.
        """
        best_action = -1
        best_score = -1

        for action in (i for i in legal_indices if 200 <= i <= 202):
            area = action - 200
            card_id = p.stage[area]
            if card_id not in state.member_db:
                continue

            try:
                next_state = state.step(action)
                next_p = next_state.players[player_id]

                hand_delta = len(next_p.hand) - len(p.hand)
                energy_delta = len(next_p.energy_zone) - len(p.energy_zone)
                tap_delta = np.sum(next_p.tapped_energy) - np.sum(p.tapped_energy)
                stage_changed = not np.array_equal(next_p.stage, p.stage)
                choice_pending = len(next_state.pending_choices) > 0

                reps = self.turn_action_counts.get(action, 0)

                gained = hand_delta > 0 or energy_delta > 0
                if not (gained or stage_changed or choice_pending) and tap_delta <= 0:
                    # Nothing observable improved: treat as a no-op.
                    score = -10
                else:
                    score = 15 if gained else 10

                # Penalise repeating the same activation within a turn.
                score -= reps * 20
            except Exception:
                # Deliberate best-effort: a lookahead that crashes is never chosen.
                score = -100

            if score > best_score:
                best_score = score
                best_action = action

        return best_action, best_score

    def _best_play(self, state, p, legal_indices):
        """Return the best-scoring play action (1-180), or -1 if none scores >= 0."""
        play_actions = [i for i in legal_indices if 1 <= i <= 180]
        if not play_actions:
            return -1

        # Heart requirements of the live cards currently in the live zone.
        pending_req = np.zeros(7, dtype=np.int32)
        for live_id in p.live_zone:
            if live_id in state.live_db:
                pending_req += state.live_db[live_id].required_hearts

        current_hearts = p.get_total_hearts(state.member_db)
        # Loop-invariant, so computed once instead of per candidate.
        untapped_energy = p.count_untapped_energy()

        best_action = -1
        best_score = -1

        for action in play_actions:
            hand_idx = (action - 1) // 3
            # Skip actions that do not map to a known member card instead of
            # raising IndexError/KeyError and aborting the whole game.
            if hand_idx >= len(p.hand):
                continue
            card_id = p.hand[hand_idx]
            if card_id not in state.member_db:
                continue
            member = state.member_db[card_id]

            score = 0

            # A. Bonus for each colour we still lack that this member provides.
            prov = member.hearts
            for c in range(6):
                if pending_req[c] > current_hearts[c] and prov[c] > 0:
                    score += 20

            # A2. Total heart volume (helps with "any" requirements).
            score += prov.sum() * 5

            # B. Base stats.
            score += member.blades
            score += member.draw_icons * 5

            # C. Small penalty for non-cheap cards when no energy is untapped.
            if untapped_energy < 1 and member.cost > 1:
                score -= 2

            # D. Prefer empty slots; penalise replacing a member with more hearts.
            area = (action - 1) % 3
            occupant = p.stage[area]
            if occupant >= 0:
                if occupant in state.member_db and state.member_db[occupant].hearts.sum() > member.hearts.sum():
                    score -= 5
            else:
                score += 5

            if score > best_score:
                best_score = score
                best_action = action

        return best_action
320
+
321
+
322
def generate_random_decks(member_ids, live_ids, num_members=40, num_lives=10):
    """Build two shuffled random decks, each mixing member and live cards.

    Cards are drawn with replacement using the ``random`` module. With the
    default sizes the random calls happen in this order: deck 1 members,
    deck 1 lives, deck 2 members, deck 2 lives, shuffle deck 1, shuffle deck 2.

    Args:
        member_ids: Iterable of member card ids to sample from.
        live_ids: Iterable of live card ids to sample from.
        num_members: Number of member cards per deck (default 40).
        num_lives: Number of live cards per deck (default 10).

    Returns:
        A ``(deck1, deck2)`` tuple of lists of card ids.
    """
    # An empty pool falls back to the placeholder id 0 so sampling cannot fail.
    m_pool = list(member_ids) or [0]
    l_pool = list(live_ids) or [0]

    decks = []
    for _ in range(2):
        deck = [random.choice(m_pool) for _ in range(num_members)]
        deck += [random.choice(l_pool) for _ in range(num_lives)]
        decks.append(deck)

    # Shuffle only after both decks are drawn.
    for deck in decks:
        random.shuffle(deck)

    return decks[0], decks[1]
341
+
342
+
343
def initialize_game(use_real_data: bool = True, cards_path: str = "data/cards.json") -> GameState:
    """Install the card databases on ``GameState`` and return a fresh instance.

    When ``use_real_data`` is true the databases are loaded from ``cards_path``;
    a loading failure is reported on stdout and empty databases are used
    instead. When false, empty databases are installed (tests inject their own).
    """
    member_db, live_db = {}, {}

    if use_real_data:
        try:
            # The loader's third return value is not used here.
            member_db, live_db, _energy_db = CardDataLoader(cards_path).load()
        except Exception as e:
            print(f"Failed to load real data: {e}")
            member_db, live_db = {}, {}

    # The databases are stored as class attributes of GameState.
    GameState.member_db = member_db
    GameState.live_db = live_db
    return GameState()
360
+
361
+
362
def create_easy_cards():
    """Create one trivially playable member card and one trivially clearable live card.

    Returns:
        A ``(member, live)`` tuple: a cost-1 member (id 888) providing one heart
        of each of the six colours and one blade, and a score-1 live (id 39999)
        requiring a single "any" heart.
    """
    # Import the card classes from the same package path the module-level
    # imports use (``engine.game``); fall back to the legacy ``game`` path.
    try:
        from engine.game.game_state import LiveCard, MemberCard
    except ImportError:
        from game.game_state import LiveCard, MemberCard

    # Easy Member: Cost 1, provides 1 of each heart + 1 blade
    easy_member = MemberCard(
        card_id=888,
        card_no="PL!-sd1-001-SD",
        name="Easy Member",
        cost=1,
        hearts=np.array([1, 1, 1, 1, 1, 1], dtype=np.int32),
        blade_hearts=np.array([0, 0, 0, 0, 0, 0], dtype=np.int32),
        blades=1,
        volume_icons=0,
        draw_icons=0,
        img_path="cards/PLSD01/PL!-sd1-001-SD.png",
        group="Easy",
    )

    # Easy Live: Score 1, requires 1 "any" heart (last slot of required_hearts)
    easy_live = LiveCard(
        card_id=39999,
        card_no="PL!-pb1-019-SD",
        name="Easy Live",
        score=1,
        required_hearts=np.array([0, 0, 0, 0, 0, 0, 1], dtype=np.int32),
        volume_icons=0,
        draw_icons=0,
        img_path="cards/PLSD01/PL!-pb1-019-SD.png",
        group="Easy",
    )

    return easy_member, easy_live
396
+
397
+
398
def setup_game(args):
    """Create a ``GameState`` with decks, opening hands and starting energy.

    Args:
        args: Namespace providing ``deck_type`` ("normal", "easy" or
            "ability_only"), ``cards_path`` and ``seed``.

    Returns:
        The prepared ``GameState``.
    """
    use_easy = args.deck_type == "easy"

    state = initialize_game(use_real_data=(not use_easy), cards_path=args.cards_path)

    # Seed both RNGs so deck construction is reproducible.
    np.random.seed(args.seed)
    random.seed(args.seed)

    if use_easy:
        # Inject the two synthetic cards into the (empty) databases.
        easy_member, easy_live = create_easy_cards()
        state.member_db[888] = easy_member
        state.live_db[39999] = easy_live

        # Single main deck holding both members (48) and lives (12), shuffled.
        for p in state.players:
            p.main_deck = [888] * 48 + [39999] * 12
            random.shuffle(p.main_deck)
            p.live_zone = []
            p.discard = []
            p.stage = np.array([-1, -1, -1], dtype=np.int32)
    else:
        # Random decks built from the loaded card databases.
        member_keys = list(state.member_db.keys())

        if args.deck_type == "ability_only":
            # Restrict the pool to members that have at least one ability.
            member_keys = [mid for mid in member_keys if state.member_db[mid].abilities]
            if not member_keys:
                print("WARNING: No members with abilities found! Reverting to all members.")
                member_keys = list(state.member_db.keys())

        deck1, deck2 = generate_random_decks(member_keys, state.live_db.keys())
        state.players[0].main_deck = deck1
        state.players[1].main_deck = deck2

    # Start from empty hands and energy zones.
    for p in state.players:
        p.hand = []
        p.energy_zone = []

    # Initial draw: up to 5 cards each, alternating between the two players.
    for _ in range(5):
        if state.players[0].main_deck:
            state.players[0].hand.append(state.players[0].main_deck.pop())
        if state.players[1].main_deck:
            state.players[1].hand.append(state.players[1].main_deck.pop())

    # Energy decks (Rule 6.1.1.3: 12 cards); move the first 3 cards into the
    # energy zone (Rule 6.2.1.7).
    for p in state.players:
        p.energy_deck = [40000] * 12
        p.energy_zone = []
        for _ in range(3):
            if p.energy_deck:
                p.energy_zone.append(p.energy_deck.pop(0))

    return state
466
+
467
+
468
class AbilityFocusAgent(SmartHeuristicAgent):
    """
    Agent biased towards abilities: it activates them and plays cards that
    carry them whenever possible. Intended for stress-testing ability code.
    """

    def choose_action(self, state: GameState, player_id: int) -> int:
        """Return an ability-related action if one exists, else defer to the parent."""
        legal = np.nonzero(state.get_legal_actions())[0]
        if legal.size == 0:
            return 0

        # A pending choice must be answered: prefer any non-zero option.
        if state.pending_choices:
            answers = [a for a in legal if a != 0]
            pool = answers if answers else legal
            return int(np.random.choice(pool))

        if state.phase == Phase.MAIN:
            hand = state.players[player_id].hand
            wanted = []

            # Play actions (1-180) whose card has a trigger-1 or trigger-7 ability
            # (1=ON_PLAY, 7=ACTIVATED per the original inline comment).
            for action_id in legal:
                if not 1 <= action_id <= 180:
                    continue
                slot = (action_id - 1) // 3
                if slot >= len(hand):
                    continue
                card_id = hand[slot]
                if card_id not in state.member_db:
                    continue
                abilities = state.member_db[card_id].abilities
                if abilities and any(ab.trigger in (1, 7) for ab in abilities):
                    wanted.append(action_id)

            # Activated-ability actions (200-202) always count as priority.
            wanted.extend(a for a in legal if 200 <= a <= 202)

            if wanted:
                return int(np.random.choice(wanted))

        # No ability-related action: use the parent's heuristics (incl. LIVE_SET).
        return super().choose_action(state, player_id)
518
+
519
+
520
class ConservativeAgent(SmartHeuristicAgent):
    """
    Cautious agent: during LIVE_SET it only sets a live card when the hearts
    reported by ``get_total_hearts`` already cover every requirement. All other
    phases use the SmartHeuristicAgent logic.
    """

    def choose_action(self, state: GameState, player_id: int) -> int:
        """Return the chosen action index; 0 (pass) if no live card is fully covered."""
        if state.phase != Phase.LIVE_SET:
            return super().choose_action(state, player_id)

        me = state.players[player_id]
        legal = np.where(state.get_legal_actions())[0]
        candidates = [a for a in legal if 400 <= a <= 459]
        if not candidates:
            return 0  # Pass

        stage_hearts = me.get_total_hearts(state.member_db)

        # Requirements of the lives already sitting in the live zone.
        pending_req = np.zeros(7, dtype=np.int32)
        for live_id in me.live_zone:
            if live_id in state.live_db:
                pending_req += state.live_db[live_id].required_hearts

        chosen = -1
        top_value = -1
        for action in candidates:
            card_id = me.hand[action - 400]
            if card_id not in state.live_db:
                continue
            live = state.live_db[card_id]
            value = self._strict_value(live, pending_req + live.required_hearts, stage_hearts)
            if value is not None and value > top_value:
                top_value = value
                chosen = action

        return int(chosen) if chosen != -1 else 0

    @staticmethod
    def _strict_value(live, total_req, stage_hearts):
        """Return the play value of ``live`` or None when requirements are not met."""
        spare = stage_hearts.copy()

        # Each of the six colour requirements must be met by that colour.
        if any(spare[c] < total_req[c] for c in range(6)):
            return None
        for c in range(6):
            spare[c] -= total_req[c]

        # What is left must cover the "any" requirement stored at index 6.
        if np.sum(spare) < total_req[6]:
            return None

        # Score first, leftover hearts as a small safety bonus.
        return live.score * 10 + (np.sum(spare) - total_req[6])
591
+
592
+
593
class GambleAgent(SmartHeuristicAgent):
    """
    Risk-taking agent: during LIVE_SET it counts half of the blades on stage
    (rounded down) as extra "any" hearts when judging whether a live is
    reachable. Colour requirements must still be met exactly.
    """

    def choose_action(self, state: GameState, player_id: int) -> int:
        """Return the chosen action index; 0 (pass) if no live looks reachable."""
        if state.phase != Phase.LIVE_SET:
            return super().choose_action(state, player_id)

        me = state.players[player_id]
        legal = np.where(state.get_legal_actions())[0]
        candidates = [a for a in legal if 400 <= a <= 459]
        if not candidates:
            return 0

        stage_hearts = me.get_total_hearts(state.member_db)
        # Optimistic estimate: every two blades yield one extra heart.
        bonus_hearts = me.get_total_blades(state.member_db) // 2

        # Requirements of the lives already in the live zone.
        pending_req = np.zeros(7, dtype=np.int32)
        for live_id in me.live_zone:
            if live_id in state.live_db:
                pending_req += state.live_db[live_id].required_hearts

        chosen = -1
        top_value = -1
        for action in candidates:
            card_id = me.hand[action - 400]
            if card_id not in state.live_db:
                continue

            live = state.live_db[card_id]
            total_req = pending_req + live.required_hearts
            spare = stage_hearts.copy()

            # Colours are never gambled on: all six must be covered now.
            if any(spare[c] < total_req[c] for c in range(6)):
                continue
            for c in range(6):
                spare[c] -= total_req[c]

            # Only the "any" requirement (index 6) may lean on the blade bonus.
            total_have = np.sum(spare) + bonus_hearts
            if total_have < total_req[6]:
                continue

            value = live.score * 10 + (total_have - total_req[6])
            if value > top_value:
                top_value = value
                chosen = action

        return int(chosen) if chosen != -1 else 0
665
+
666
+
667
class NNAgent(Agent):
    """
    Agent that samples actions from a neural-network policy.

    If the network modules cannot be imported or initialised, the agent
    degrades to uniform random play over the legal actions.
    """

    def __init__(self, device=None, model_path=None):
        """Build the network wrapper and optionally load weights.

        Args:
            device: Passed through to ``TorchNetworkWrapper``.
            model_path: If truthy, weights are loaded from this path.
        """
        # Define every attribute up front so they exist even if setup fails.
        self.config = None
        self.net = None
        self.device = None
        try:
            # Lazy import to avoid a hard dependency when this agent is unused.
            # NOTE(review): the module-level imports use the ``engine.game``
            # package path; confirm ``game.network`` is importable as written.
            from game.network import NetworkConfig
            from game.network_torch import TorchNetworkWrapper

            self.config = NetworkConfig()
            self.net = TorchNetworkWrapper(self.config, device=device)
            self.device = self.net.device

            if model_path:
                print(f"Loading model from {model_path}...")
                self.net.load(model_path)

        except ImportError as e:
            print(f"WARNING: PyTorch or network modules not found. NNAgent falling back to Random. Error: {e}")
            self.net = None
        except Exception as e:
            print(f"WARNING: Failed to initialize NNAgent: {e}")
            self.net = None

    @staticmethod
    def _random_legal_action(legal_mask):
        """Return a uniformly random legal action index, or 0 if none is legal."""
        legal_indices = np.where(legal_mask)[0]
        return int(np.random.choice(legal_indices)) if len(legal_indices) > 0 else 0

    def choose_action(self, state: GameState, player_id: int) -> int:
        """Sample an action from the network policy, restricted to legal actions."""
        legal_mask = np.asarray(state.get_legal_actions(), dtype=bool)

        if self.net is None:
            return self._random_legal_action(legal_mask)

        # Direct policy sampling (no search); the value estimate is unused.
        policy, _value = self.net.predict(state)
        policy = np.asarray(policy, dtype=np.float64)

        # Zero out illegal actions so sampling cannot return one. This is a
        # no-op if the network already masks; it is skipped if shapes differ.
        if policy.shape == legal_mask.shape:
            policy = np.where(legal_mask, policy, 0.0)

        # Renormalise so the probabilities sum to 1.
        policy_sum = policy.sum()
        if policy_sum > 0:
            return int(np.random.choice(len(policy), p=policy / policy_sum))

        # No usable probability mass: fall back to a random legal action.
        return self._random_legal_action(legal_mask)
718
+
719
+
720
def _build_agent(name, args, fallback):
    """Instantiate the agent registered under ``name``; unknown names use ``fallback``."""
    if name == "search":
        # Only the search agent needs the depth option.
        return SearchProbAgent(depth=args.depth)
    agent_classes = {
        "random": RandomAgent,
        "smart": SmartHeuristicAgent,
        "ability_focus": AbilityFocusAgent,
        "conservative": ConservativeAgent,
        "gamble": GambleAgent,
        "nn": NNAgent,
    }
    return agent_classes.get(name, fallback)()


def run_simulation(args):
    """Play ``args.num_games`` games, print aggregate stats and save one game log.

    Args:
        args: Namespace with ``num_games``, ``seed``, ``agent``, ``agent_p2``,
            ``depth`` (only read for the search agent), ``max_turns``,
            ``log_file`` and the fields ``setup_game`` reads.

    The saved log is the one from the most recent game that either had a
    winner (0 or 1) or beat the best combined score seen so far. A game that
    raises is reported on stderr and skipped.
    """
    import io
    import json
    import traceback

    # Logging is managed manually so each game's log can be captured on its own.
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)

    # Console only shows warnings/errors during the run.
    console = logging.StreamHandler()
    console.setLevel(logging.WARNING)
    root_logger.addHandler(console)

    best_combined_score = -1
    best_log_content = ""
    best_game_idx = -1
    best_winner = -1

    results = []

    start_total = time.time()

    for game_idx in range(args.num_games):
        # Capture this game's INFO log into an in-memory buffer.
        log_capture = io.StringIO()
        handler = logging.StreamHandler(log_capture)
        handler.setLevel(logging.INFO)
        handler.setFormatter(logging.Formatter("%(message)s"))

        # Replace all root handlers with the console plus this game's capture.
        root_logger.handlers = [console, handler]

        logging.info(f"=== Game {game_idx + 1} ===")

        try:
            # NOTE: setup_game seeds with args.seed itself, so deck building
            # uses the same seed every game; play is then reseeded per game.
            state = setup_game(args)
            current_seed = args.seed + game_idx
            random.seed(current_seed)
            np.random.seed(current_seed)

            # Both players go through the same factory so every advertised
            # agent name works for P1 as well as P0.
            p0_agent = _build_agent(args.agent, args, SmartHeuristicAgent)
            p1_agent = _build_agent(args.agent_p2, args, RandomAgent)
            agents = [p0_agent, p1_agent]

            action_count = 0
            while not state.game_over:
                # Safety limit on the number of actions.
                if action_count > args.max_turns:
                    break
                state.check_win_condition()
                if state.game_over:
                    break

                active_pid = state.current_player

                logging.info("-" * 40)
                logging.info(f"Turn {state.turn_number} | Phase {state.phase.name} | Active: P{active_pid}")
                p0 = state.players[0]
                p1 = state.players[1]
                logging.info(f"Score: P0({len(p0.success_lives)}) - P1({len(p1.success_lives)})")
                logging.info(f"Hand: P0({len(p0.hand)}) - P1({len(p1.hand)})")

                action = agents[active_pid].choose_action(state, active_pid)
                logging.info(f"Action: P{active_pid} chooses {action}")

                state = state.step(action)
                action_count += 1

            # Game end: a player's score is the number of successful lives.
            p0_score = len(state.players[0].success_lives)
            p1_score = len(state.players[1].success_lives)
            combined_score = p0_score + p1_score
            winner = state.winner

            logging.info("=" * 40)
            logging.info(f"Game Over. Winner: {winner}. Score: {p0_score}-{p1_score}")

            res = {
                "id": game_idx,
                "winner": winner,
                "score_total": combined_score,
                "p0_score": p0_score,
                "p1_score": p1_score,
                "actions": action_count,
                "game_turns": state.turn_number,
            }
            results.append(res)
            print(f"DEBUG: Game {game_idx} Winner: {winner}")

            # Keep this game's log if it had a winner or a new best score.
            is_win = winner == 0 or winner == 1
            if is_win or combined_score > best_combined_score:
                if is_win and best_winner == -1:
                    print(f"Found a Winner in Game {game_idx + 1}! (Winner: P{winner})")

                best_log_content = log_capture.getvalue()
                best_combined_score = combined_score
                best_winner = winner
                best_game_idx = game_idx

            if (game_idx + 1) % 100 == 0:
                print(f"Simulated {game_idx + 1} games... Best Score: {best_combined_score}")

        except Exception as e:
            # Deliberate best-effort: report the failed game and continue.
            msg = f"Error in game {game_idx}: {e}"
            print(msg, file=sys.stderr)
            traceback.print_exc()

        finally:
            # Detach the handler before closing its buffer so later log calls
            # never write to a closed stream.
            root_logger.removeHandler(handler)
            log_capture.close()

    total_time = time.time() - start_total

    # Write the retained log.
    with open(args.log_file, "w", encoding="utf-8") as f:
        f.write(best_log_content)

    print("\n=== Simulation Complete ===")
    print(f"Total Games Ran: {len(results)}")
    print(f"Total Time: {total_time:.2f}s")

    wins0 = sum(1 for r in results if r["winner"] == 0)
    wins1 = sum(1 for r in results if r["winner"] == 1)
    draws = sum(1 for r in results if r["winner"] == 2)

    print(f"Wins: P0={wins0}, P1={wins1}, Draws={draws}")

    total_actions = sum(r["actions"] for r in results)
    total_game_turns = sum(r["game_turns"] for r in results)

    if total_time > 0:
        print(f"APS (Actions Per Second): {total_actions / total_time:.2f}")
        print(f"TPS (Turns Per Second): {total_game_turns / total_time:.2f}")

    print(
        f"Best Game was Game {best_game_idx + 1} with Score Total {best_combined_score if best_combined_score >= 0 else 0}"
    )
    print(f"Log for best game saved to {args.log_file}")

    if results:
        print(f"Last Game Summary: {json.dumps(results[-1], indent=2)}")
892
+
893
+
894
if __name__ == "__main__":
    # The default card database lives in ../engine/data relative to this script.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    default_cards_path = os.path.join(script_dir, "..", "engine", "data", "cards.json")

    # Both players accept the same set of agent names.
    agent_choices = ["random", "smart", "ability_focus", "conservative", "gamble", "nn", "search"]

    parser = argparse.ArgumentParser()
    # (flag, keyword arguments) pairs, registered in display order.
    option_specs = [
        ("--cards_path", dict(default=default_cards_path, help="Path to cards.json")),
        (
            "--deck_type",
            dict(
                default="normal",
                choices=["normal", "easy", "ability_only"],
                help="Deck type: normal, easy, or ability_only",
            ),
        ),
        ("--max_turns", dict(type=int, default=1000, help="Max steps/turns to run")),
        ("--log_file", dict(default="game_log.txt", help="Output log file")),
        ("--seed", dict(type=int, default=42, help="Random seed")),
        ("--num_games", dict(type=int, default=1, help="Number of games to run")),
        ("--agent", dict(default="smart", choices=list(agent_choices), help="Agent type to control P0")),
        ("--agent_p2", dict(default="random", choices=list(agent_choices), help="Agent type to control P1")),
        ("--depth", dict(type=int, default=2, help="Search depth for SearchProbAgent")),
    ]
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    args = parser.parse_args()

    run_simulation(args)