trioskosmos commited on
Commit
764819e
·
verified ·
1 Parent(s): ae55ffb

Upload ai/utils/benchmark_decks.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. ai/utils/benchmark_decks.py +126 -0
ai/utils/benchmark_decks.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import concurrent.futures
2
+ import json
3
+
4
+ import engine_rust
5
+
6
+
7
+ def parse_deck(deck_file, member_db, live_db, energy_db):
8
+ with open(deck_file, "r", encoding="utf-8") as f:
9
+ lines = f.readlines()
10
+
11
+ main_deck = []
12
+ lives = []
13
+ energy_deck = []
14
+
15
+ db_map = {"member": member_db, "live": live_db, "energy": energy_db}
16
+
17
+ for line in lines:
18
+ line = line.strip()
19
+ if not line or " x " not in line:
20
+ continue
21
+
22
+ parts = line.split(" x ")
23
+ card_no = parts[0].strip()
24
+ count = int(parts[1].strip())
25
+
26
+ found_type = None
27
+ found_id = None
28
+
29
+ for db_type, db in db_map.items():
30
+ for i, card in db.items():
31
+ if card.get("card_no") == card_no:
32
+ found_id = int(i)
33
+ found_type = db_type
34
+ break
35
+ if found_id is not None:
36
+ break
37
+
38
+ if found_id is not None:
39
+ if found_type == "live":
40
+ lives.extend([found_id] * count)
41
+ elif found_type == "energy":
42
+ energy_deck.extend([found_id] * count)
43
+ else:
44
+ main_deck.extend([found_id] * count)
45
+ else:
46
+ # Fallback for Energy if it's named like one or matches known patterns
47
+ if "energy" in card_no.lower() or "sd1-036" in card_no.lower():
48
+ default_energy_id = 20000 # Use new offset
49
+ energy_deck.extend([default_energy_id] * count)
50
+ else:
51
+ print(f"Warning: Card {card_no} not found in DB")
52
+
53
+ return main_deck, lives, energy_deck
54
+
55
+
56
+ def run_benchmark(deck_name, deck_file, db_content, sims=100):
57
+ db_json = json.loads(db_content)
58
+ member_db = db_json["member_db"]
59
+ live_db = db_json["live_db"]
60
+ energy_db = db_json.get("energy_db", {})
61
+
62
+ main_deck, lives, energy_deck = parse_deck(deck_file, member_db, live_db, energy_db)
63
+
64
+ # Padding/Trimming to standard sizes if needed
65
+ test_lives = lives[:12]
66
+ test_deck = main_deck[:48] # Rule 6.1.1.1
67
+ test_energy = energy_deck[:12] # Rule 6.1.1.3
68
+
69
+ db = engine_rust.PyCardDatabase(db_content)
70
+ game = engine_rust.PyGameState(db)
71
+
72
+ game.initialize_game(test_deck, test_deck, test_energy, test_energy, test_lives, test_lives)
73
+
74
+ turn_limit = 10
75
+ step = 0
76
+ while not game.is_terminal() and game.turn <= turn_limit and step < 1000:
77
+ cp = game.current_player
78
+ phase = game.phase
79
+ is_interactive = phase in [-1, 0, 4, 5]
80
+
81
+ if is_interactive:
82
+ # Use TurnEnd horizon specifically for this bench
83
+ suggestions = game.get_mcts_suggestions(sims, engine_rust.SearchHorizon.TurnEnd)
84
+ best_action = suggestions[0][0]
85
+ game.step(best_action)
86
+ else:
87
+ game.step(0)
88
+ step += 1
89
+
90
+ p0 = game.get_player(0)
91
+ return {
92
+ "deck": deck_name,
93
+ "score": p0.score,
94
+ "lives_cleared": len(p0.success_lives),
95
+ "turns": game.turn,
96
+ "steps": step,
97
+ }
98
+
99
+
100
+ def main():
101
+ with open("data/cards_compiled.json", "r", encoding="utf-8") as f:
102
+ db_content = f.read()
103
+
104
+ deck_files = {
105
+ "Aqours": "ai/decks/aqours_cup.txt",
106
+ "Hasunosora": "ai/decks/hasunosora_cup.txt",
107
+ "Liella": "ai/decks/liella_cup.txt",
108
+ "Muse": "ai/decks/muse_cup.txt",
109
+ "Nijigasaki": "ai/decks/nijigaku_cup.txt",
110
+ }
111
+
112
+ print(f"{'Deck':<12} | {'Score':<5} | {'Lives':<5} | {'Turns':<5}")
113
+ print("-" * 40)
114
+
115
+ # Run in parallel to save time
116
+ with concurrent.futures.ProcessPoolExecutor() as executor:
117
+ futures = {
118
+ executor.submit(run_benchmark, name, path, db_content, 50): name for name, path in deck_files.items()
119
+ }
120
+ for future in concurrent.futures.as_completed(futures):
121
+ res = future.result()
122
+ print(f"{res['deck']:<12} | {res['score']:<5} | {res['lives_cleared']:<5} | {res['turns']:<5}")
123
+
124
+
125
+ if __name__ == "__main__":
126
+ main()