trioskosmos commited on
Commit
a384afe
·
verified ·
1 Parent(s): aa9d06f

Upload ai/utils/tournament.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. ai/utils/tournament.py +215 -0
ai/utils/tournament.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import os
4
+ import random
5
+ import sys
6
+
7
+ import torch
8
+ from tqdm import tqdm
9
+
10
+ # Add project root to path
11
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
12
+
13
+ import engine_rust
14
+
15
+ from ai.agents.neural_mcts import HybridMCTSAgent
16
+ from ai.models.training_config import POLICY_SIZE
17
+ from ai.training.train import AlphaNet
18
+ from ai.utils.benchmark_decks import parse_deck
19
+
20
+
21
class Agent:
    """Abstract base class for tournament competitors.

    Subclasses implement :meth:`get_action` to choose a legal action id
    for the current game state.
    """

    def get_action(self, game, db):
        """Return an action id for *game*.

        Parameters
        ----------
        game : engine game state exposing ``get_legal_action_ids()`` / ``step()``.
        db : card database handle (some agents ignore it).

        Raises
        ------
        NotImplementedError
            Always; subclasses must override this method.
        """
        # The original returned None here, which would eventually be fed to
        # game.step() and fail far from the real cause. Fail fast instead.
        raise NotImplementedError
24
+
25
+
26
class RandomAgent(Agent):
    """Baseline competitor: plays a uniformly random legal action."""

    def get_action(self, game, db):
        """Return a random legal action id, or 0 when none are available."""
        legal = game.get_legal_action_ids()
        return random.choice(legal) if legal else 0
32
+
33
+
34
class MCTSAgent(Agent):
    """Competitor that relies on the engine's built-in MCTS search."""

    def __init__(self, sims=100):
        # Number of MCTS simulations to run per decision.
        self.sims = sims

    def get_action(self, game, db):
        """Return the engine's top-ranked MCTS suggestion, or 0 if none."""
        ranked = game.get_mcts_suggestions(self.sims, engine_rust.SearchHorizon.TurnEnd)
        if not ranked:
            return 0
        # Suggestions are (action_id, score) pairs ordered best-first.
        best = ranked[0]
        return best[0]
43
+
44
+
45
class ResNetAgent(Agent):
    """Competitor that plays greedily from a trained AlphaNet policy head."""

    def __init__(self, model_path):
        """Load an AlphaNet checkpoint from *model_path* and prepare it for inference."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        checkpoint = torch.load(model_path, map_location=self.device)

        # Newer checkpoints wrap the weights in a dict under "model_state";
        # older ones are a bare state dict.
        if isinstance(checkpoint, dict) and "model_state" in checkpoint:
            state_dict = checkpoint["model_state"]
        else:
            state_dict = checkpoint

        # Infer the policy-head width from the saved bias so the network we
        # build matches the checkpoint; fall back to the configured size.
        bias = state_dict.get("policy_head_fc.bias")
        detected_policy_size = POLICY_SIZE if bias is None else bias.shape[0]
        print(f"ResNetAgent: Detected Policy Size {detected_policy_size}")

        self.policy_size = detected_policy_size
        self.model = AlphaNet(policy_size=detected_policy_size).to(self.device)
        self.model.load_state_dict(state_dict)
        self.model.eval()

    def get_action(self, game, db):
        """Return the legal action id with the highest policy logit."""
        # Encode the current state and run a single forward pass.
        features = torch.FloatTensor(game.encode_state(db)).unsqueeze(0).to(self.device)
        with torch.no_grad():
            logits, _value = self.model(features)

        # Additive mask: 0 for legal action ids, -1e9 for everything else,
        # so argmax can never pick an illegal move.
        mask = torch.full((self.policy_size,), -1e9).to(self.device)
        for action_id in game.get_legal_action_ids():
            if action_id < self.policy_size:
                mask[int(action_id)] = 0.0

        return int(torch.argmax(logits.squeeze(0) + mask).item())
86
+
87
+
88
def play_match(agent0, agent1, db_content, decks, game_id):
    """Play one full game between *agent0* (seat 0) and *agent1* (seat 1).

    Parameters
    ----------
    agent0, agent1 : Agent
        Competitors for player 0 and player 1 respectively.
    db_content : str
        Raw JSON card-database text used to construct a fresh engine database.
    decks : list
        Parsed (deck, lives, energy) tuples; one is drawn at random per player.
    game_id : int
        Caller-supplied game index (currently unused by this function).

    Returns
    -------
    tuple
        (winner, player0_score, player1_score, turn_count) as reported by the engine.
    """
    db = engine_rust.PyCardDatabase(db_content)
    game = engine_rust.PyGameState(db)

    # Each player independently draws a random deck (they may get the same one).
    p0_deck, p0_lives, p0_energy = random.choice(decks)
    p1_deck, p1_lives, p1_energy = random.choice(decks)

    game.initialize_game(p0_deck, p1_deck, p0_energy, p1_energy, p0_lives, p1_lives)

    agents = [agent0, agent1]
    step = 0
    # Hard cap of 1000 steps guards against games that never terminate.
    while not game.is_terminal() and step < 1000:
        cp = game.current_player
        phase = game.phase

        # NOTE(review): phases -1, 0, 4 and 5 appear to be the ones requiring
        # a player decision — confirm against the engine's phase enum.
        is_interactive = phase in [-1, 0, 4, 5]

        if is_interactive:
            # The agent for the current seat picks an action from the live state.
            action = agents[cp].get_action(game, game.db)
            try:
                game.step(action)
            except Exception:
                # Best-effort recovery: if the agent's action is rejected by
                # the engine, fall back to the first legal action; abandon the
                # game if no legal action remains.
                legal = game.get_legal_action_ids()
                if legal:
                    game.step(int(legal[0]))
                else:
                    break
        else:
            # Non-interactive phase: advance the engine with a no-op action.
            game.step(0)
        step += 1

    return game.get_winner(), game.get_player(0).score, game.get_player(1).score, game.turn
123
+
124
+
125
+ def run_tournament(num_games=10):
126
+ with open("engine/data/cards_compiled.json", "r", encoding="utf-8") as f:
127
+ db_content = f.read()
128
+ db_json = json.loads(db_content)
129
+
130
+ # Load Decks
131
+ deck_paths = [
132
+ "ai/decks/aqours_cup.txt",
133
+ "ai/decks/hasunosora_cup.txt",
134
+ "ai/decks/liella_cup.txt",
135
+ "ai/decks/muse_cup.txt",
136
+ "ai/decks/nijigaku_cup.txt",
137
+ ]
138
+ decks = []
139
+ for dp in deck_paths:
140
+ if os.path.exists(dp):
141
+ decks.append(parse_deck(dp, db_json["member_db"], db_json["live_db"], db_json.get("energy_db", {})))
142
+
143
+ # Agents
144
+ # Agents
145
+ random_agent = RandomAgent()
146
+ mcts_agent = MCTSAgent(sims=100)
147
+ # resnet_agent = ResNetAgent("ai/models/alphanet_best.pt")
148
+
149
+ competitors = {
150
+ "Random": random_agent,
151
+ "MCTS-100": mcts_agent,
152
+ # "ResNet-Standalone": resnet_agent,
153
+ # "Neural-Hybrid (Py)": NeuralHeuristicAgent("ai/models/alphanet_best.pt", sims=100),
154
+ # "Neural-Rust (Full)": NeuralMCTSFullAgent("ai/models/alphanet.onnx", sims=100),
155
+ "Neural-Rust (Hybrid)": HybridMCTSAgent("ai/models/alphanet_best.onnx", sims=100, neural_weight=0.3),
156
+ }
157
+
158
+ results = {name: {"wins": 0, "draws": 0, "losses": 0, "total_score": 0, "turns": []} for name in competitors}
159
+
160
+ matchups = [("Neural-Rust (Hybrid)", "MCTS-100"), ("Neural-Rust (Hybrid)", "Random")]
161
+
162
+ print(f"Starting Tournament: {num_games} rounds per matchup...")
163
+ for p0_name, p1_name in matchups:
164
+ print(f"Matchup: {p0_name} vs {p1_name}")
165
+ for i in tqdm(range(num_games)):
166
+ # Swap sides every game
167
+ if i % 2 == 0:
168
+ winner, s0, s1, t = play_match(competitors[p0_name], competitors[p1_name], db_content, decks, i)
169
+ results[p0_name]["total_score"] += s0
170
+ results[p1_name]["total_score"] += s1
171
+ results[p0_name]["turns"].append(t)
172
+ results[p1_name]["turns"].append(t)
173
+ if winner == 0:
174
+ results[p0_name]["wins"] += 1
175
+ results[p1_name]["losses"] += 1
176
+ elif winner == 1:
177
+ results[p1_name]["wins"] += 1
178
+ results[p0_name]["losses"] += 1
179
+ else:
180
+ results[p0_name]["draws"] += 1
181
+ results[p1_name]["draws"] += 1
182
+ else:
183
+ winner, s1, s0, t = play_match(competitors[p1_name], competitors[p0_name], db_content, decks, i)
184
+ results[p0_name]["total_score"] += s0
185
+ results[p1_name]["total_score"] += s1
186
+ results[p0_name]["turns"].append(t)
187
+ results[p1_name]["turns"].append(t)
188
+ if winner == 0:
189
+ results[p1_name]["wins"] += 1
190
+ results[p0_name]["losses"] += 1
191
+ elif winner == 1:
192
+ results[p0_name]["wins"] += 1
193
+ results[p1_name]["losses"] += 1
194
+ else:
195
+ results[p0_name]["draws"] += 1
196
+ results[p1_name]["draws"] += 1
197
+
198
+ print("\nTournament Results:")
199
+ print(f"{'Agent':<18} | {'Wins':<5} | {'Draws':<5} | {'Losses':<5} | {'Avg Score':<10} | {'Avg Turns':<10}")
200
+ print("-" * 75)
201
+ for name, stat in results.items():
202
+ total_games = stat["wins"] + stat["draws"] + stat["losses"]
203
+ avg_score = stat["total_score"] / total_games if total_games > 0 else 0
204
+ avg_turns = sum(stat["turns"]) / len(stat["turns"]) if stat["turns"] else 0
205
+ print(
206
+ f"{name:<18} | {stat['wins']:<5} | {stat['draws']:<5} | {stat['losses']:<5} | {avg_score:<10.2f} | {avg_turns:<10.2f}"
207
+ )
208
+
209
+
210
if __name__ == "__main__":
    # CLI entry point: --rounds controls the number of games per matchup.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--rounds", type=int, default=10)
    cli_args = arg_parser.parse_args()

    run_tournament(num_games=cli_args.rounds)