Spaces:
Configuration error
Configuration error
| # stratego/benchmarking/plot_metrics.py | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| import sys | |
def plot_from_csv(csv_path: str, rolling_window: int = 3) -> None:
    """Plot benchmark metrics from a per-game results CSV.

    Reads ``csv_path`` with pandas and displays five matplotlib figures one
    after another, each through its own ``plt.show()`` call:

    1. game length (``turns``) per game,
    2. invalid moves per game for both players,
    3. raw turns overlaid with a rolling mean of turns,
    4. a bar chart counting each ``game_end_reason`` value,
    5. the cumulative win rate of player 0 and player 1.

    Args:
        csv_path: Path of the CSV file to read. One row per game; the columns
            ``turns``, ``invalid_moves_p0``, ``invalid_moves_p1``,
            ``game_end_reason`` and ``winner`` are looked up by name.
        rolling_window: Number of consecutive games averaged by the rolling
            mean in figure 3.

    Raises:
        KeyError: Propagated from pandas when one of the columns above is
            absent from the CSV.
    """
    df = pd.read_csv(csv_path)
    game_count = len(df)
    # Zero-based game index shared as the x-axis of every line plot.
    games = range(game_count)

    # ===============================
    # 1. GAME LENGTH PER GAME
    # ===============================
    plt.figure()
    plt.plot(games, df["turns"], marker="o")
    plt.title("Game Length per Game")
    plt.xlabel("Game index")
    plt.ylabel("Turns")
    plt.grid(True)
    plt.show()

    # ===============================
    # 2. INVALID MOVES PER GAME
    # ===============================
    plt.figure()
    plt.plot(games, df["invalid_moves_p0"], label="P0 Invalid Moves", marker="o")
    plt.plot(games, df["invalid_moves_p1"], label="P1 Invalid Moves", marker="o")
    plt.title("Invalid Moves per Game")
    plt.xlabel("Game index")
    plt.ylabel("Invalid moves")
    plt.legend()
    plt.grid(True)
    plt.show()

    # ===============================
    # 3. ROLLING AVERAGE (STALLING)
    # ===============================
    # With pandas' default min_periods the first ``rolling_window - 1`` values
    # are NaN, so the smoothed line starts a few games in.
    rolling_turns = df["turns"].rolling(window=rolling_window).mean()
    plt.figure()
    plt.plot(games, df["turns"], alpha=0.3, label="Raw turns")
    plt.plot(games, rolling_turns, linewidth=3, label=f"Rolling avg ({rolling_window})")
    plt.title("Game Stalling (Rolling Average of Turns)")
    plt.xlabel("Game index")
    plt.ylabel("Turns")
    plt.legend()
    plt.grid(True)
    plt.show()

    # ===============================
    # 4. TERMINATION REASONS
    # ===============================
    termination_counts = df["game_end_reason"].value_counts()
    plt.figure()
    termination_counts.plot(kind="bar")
    plt.title("Game Termination Reasons")
    plt.xlabel("Reason")
    plt.ylabel("Number of games")
    plt.grid(axis="y")
    plt.show()

    # ===============================
    # 5. CUMULATIVE WIN RATE
    # ===============================
    # A built-in ``range`` does not support arithmetic (``games + 1`` raises
    # TypeError), so the 1-based "games played so far" denominator is built
    # explicitly as a Series aligned with the DataFrame's index.
    games_played = pd.Series(range(1, game_count + 1), index=df.index)
    # Rows whose ``winner`` is neither 0 nor 1 count as a win for neither player.
    win_rate_p0 = (df["winner"] == 0).cumsum() / games_played
    win_rate_p1 = (df["winner"] == 1).cumsum() / games_played
    plt.figure()
    plt.plot(games, win_rate_p0, label="P0 Win Rate")
    plt.plot(games, win_rate_p1, label="P1 Win Rate")
    plt.title("Cumulative Win Rate")
    plt.xlabel("Game index")
    plt.ylabel("Win rate")
    plt.legend()
    plt.grid(True)
    plt.show()
if __name__ == "__main__":
    # Script entry point: the first command-line argument is the benchmark CSV
    # to plot. Without it, print the usage line and exit with status 1.
    cli_args = sys.argv[1:]
    if not cli_args:
        print("Usage: python plot_metrics.py <benchmark_csv>")
        sys.exit(1)
    plot_from_csv(cli_args[0])