rabukasim / tools /benchmarks /profile_agents.py
trioskosmos's picture
Upload folder using huggingface_hub
463f868 verified
import json
import os
import random
import sys
import time
# Make the repository root importable so the project-local imports below
# (e.g. the `ai` package) resolve when this script is run directly.
# This file lives at <root>/tools/benchmarks/profile_agents.py, so the root is
# three directory levels above the file; two levels would only reach <root>/tools.
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)
import engine_rust
from ai.benchmark_decks import parse_deck
from ai.tournament import MCTSAgent, RandomAgent, ResNetAgent
# Phase ids in which the agent under test is asked to choose an action; in every
# other phase the game is advanced with action 0.
# NOTE(review): the meaning of these ids is defined by engine_rust and is not
# visible here — confirm against the engine's phase enum.
_AGENT_PHASES = frozenset({-1, 0, 4, 5})


def _load_decks(db_json, deck_paths):
    """Parse every existing deck file in *deck_paths* against the card database.

    Args:
        db_json: Parsed contents of the compiled card JSON; must provide
            "member_db" and "live_db" keys ("energy_db" is optional).
        deck_paths: Iterable of deck file paths, relative to the working directory.

    Returns:
        A list with one ``parse_deck`` result per deck file that exists.
        Missing files are skipped with a warning on stderr.
    """
    decks = []
    for path in deck_paths:
        if not os.path.exists(path):
            print(f"Warning: deck file not found, skipping: {path}", file=sys.stderr)
            continue
        decks.append(
            parse_deck(path, db_json["member_db"], db_json["live_db"], db_json.get("energy_db", {}))
        )
    return decks


def _play_profiled_game(game, agent, timing, max_steps):
    """Play one game with *agent* making every decision, timing each decision.

    Args:
        game: An initialized engine game state.
        agent: Object exposing ``get_action(game, db)``.
        timing: Dict with "calls" and "total_time" entries; updated in place.
        max_steps: Upper bound on engine steps (agent and non-agent) per game.
    """
    step = 0
    while not game.is_terminal() and step < max_steps:
        if game.phase in _AGENT_PHASES:
            # Only the agent's decision is timed, not the engine's step.
            start = time.perf_counter()
            action = agent.get_action(game, game.db)
            timing["total_time"] += time.perf_counter() - start
            timing["calls"] += 1
            try:
                game.step(action)
            except (KeyboardInterrupt, SystemExit):
                # Never swallow a user interrupt or an explicit exit.
                raise
            except BaseException:
                # The engine rejected the agent's action; fall back to the first
                # legal action so the benchmark can keep going. BaseException
                # (rather than Exception) is deliberate: engine_rust is presumably
                # a PyO3 extension, and PyO3 surfaces Rust panics as
                # PanicException, which derives from BaseException.
                legal = game.get_legal_action_ids()
                if not legal:
                    break
                game.step(int(legal[0]))
        else:
            game.step(0)
        step += 1


def _print_results(stats):
    """Print a table of mean decision latency and decisions per second per agent."""
    print("\nSpeed Results:")
    print(f"{'Agent':<12} | {'Avg Time/Move':<15} | {'Moves/Sec':<12}")
    print("-" * 50)
    for name, data in stats.items():
        if data["calls"] > 0:
            avg = data["total_time"] / data["calls"]
            m_sec = 1.0 / avg if avg > 0 else 0
            print(f"{name:<12} | {avg * 1000:>10.2f} ms | {m_sec:>10.1f}")
        else:
            print(f"{name:<12} | N/A")


def profile_agents(num_games=3, max_steps=200):
    """Measure how long each built-in agent takes to choose an action.

    Every agent plays *num_games* games in which it makes the decisions for both
    seats; each game uses two decks drawn at random from the benchmark decks.
    Only the time spent inside ``agent.get_action`` is measured. A summary
    table is printed at the end.

    Paths ("data/cards_compiled.json", "ai/decks/...", "ai/models/...") are
    resolved relative to the current working directory.

    Args:
        num_games: Number of games to play per agent.
        max_steps: Maximum number of engine steps per game before it is
            abandoned (defaults to the previous hard-coded cap of 200).

    Raises:
        FileNotFoundError: If games are requested but none of the deck files exist.
    """
    with open("data/cards_compiled.json", "r", encoding="utf-8") as f:
        db_content = f.read()
    db_json = json.loads(db_content)

    deck_paths = [
        "ai/decks/aqours_cup.txt",
        "ai/decks/hasunosora_cup.txt",
        "ai/decks/liella_cup.txt",
        "ai/decks/muse_cup.txt",
        "ai/decks/nijigaku_cup.txt",
    ]
    decks = _load_decks(db_json, deck_paths)
    if num_games > 0 and not decks:
        # Fail fast with a clear message instead of an IndexError from
        # random.choice([]) after the (possibly expensive) agents are built.
        raise FileNotFoundError(f"No deck files found; looked for: {', '.join(deck_paths)}")

    # Insertion order of this dict fixes both construction and report order.
    agents = {
        "Random": RandomAgent(),
        "MCTS-100": MCTSAgent(sims=100),
        "ResNet-v1": ResNetAgent("ai/models/alphanet_best.pt"),
    }

    print(f"Profiling Agents over {num_games} test games each...")
    stats = {name: {"calls": 0, "total_time": 0} for name in agents}
    db = engine_rust.PyCardDatabase(db_content)

    for name, agent in agents.items():
        print(f"Profiling {name}...")
        for _ in range(num_games):
            game = engine_rust.PyGameState(db)
            p0_deck, p0_lives = random.choice(decks)
            p1_deck, p1_lives = random.choice(decks)
            # NOTE(review): the two [0] * 10 lists look like placeholder energy
            # decks for each player — confirm against initialize_game's signature.
            game.initialize_game(p0_deck, p1_deck, [0] * 10, [0] * 10, p0_lives, p1_lives)
            _play_profiled_game(game, agent, stats[name], max_steps)

    _print_results(stats)
if __name__ == "__main__":
    # Script entry point: run the profiler with its default of three games per agent.
    profile_agents(num_games=3)