Spaces:
Configuration error
Configuration error
File size: 2,655 Bytes
# stratego/benchmarking/plot_metrics.py
import pandas as pd
import matplotlib.pyplot as plt
import sys
def plot_from_csv(csv_path: str, rolling_window: int = 3) -> None:
    """Plot benchmark metrics from a CSV file, one figure per metric.

    Five figures are created and shown one after another with
    ``plt.show()``:

    1. game length (``turns``) per game,
    2. invalid moves per game for both players,
    3. raw turns together with their rolling average,
    4. a bar chart of how often each ``game_end_reason`` occurs,
    5. the cumulative win rate of player 0 and player 1.

    Args:
        csv_path: Path to the benchmark CSV. It must contain the columns
            ``turns``, ``invalid_moves_p0``, ``invalid_moves_p1``,
            ``game_end_reason`` and ``winner``.
        rolling_window: Number of consecutive games averaged for the
            rolling-average curve in the third figure.
    """
    df = pd.read_csv(csv_path)
    # 0-based game index, used as the x axis of every line plot.
    games = range(len(df))

    # ===============================
    # 1. GAME LENGTH PER GAME
    # ===============================
    plt.figure()
    plt.plot(games, df["turns"], marker="o")
    plt.title("Game Length per Game")
    plt.xlabel("Game index")
    plt.ylabel("Turns")
    plt.grid(True)
    plt.show()

    # ===============================
    # 2. INVALID MOVES PER GAME
    # ===============================
    plt.figure()
    plt.plot(games, df["invalid_moves_p0"], label="P0 Invalid Moves", marker="o")
    plt.plot(games, df["invalid_moves_p1"], label="P1 Invalid Moves", marker="o")
    plt.title("Invalid Moves per Game")
    plt.xlabel("Game index")
    plt.ylabel("Invalid moves")
    plt.legend()
    plt.grid(True)
    plt.show()

    # ===============================
    # 3. ROLLING AVERAGE (STALLING)
    # ===============================
    # Kept in a local variable: the DataFrame itself does not need an extra
    # column just to draw this curve.
    rolling_turns = df["turns"].rolling(window=rolling_window).mean()
    plt.figure()
    plt.plot(games, df["turns"], alpha=0.3, label="Raw turns")
    plt.plot(games, rolling_turns, linewidth=3, label=f"Rolling avg ({rolling_window})")
    plt.title("Game Stalling (Rolling Average of Turns)")
    plt.xlabel("Game index")
    plt.ylabel("Turns")
    plt.legend()
    plt.grid(True)
    plt.show()

    # ===============================
    # 4. TERMINATION REASONS
    # ===============================
    termination_counts = df["game_end_reason"].value_counts()
    plt.figure()
    termination_counts.plot(kind="bar")
    plt.title("Game Termination Reasons")
    plt.xlabel("Reason")
    plt.ylabel("Number of games")
    plt.grid(axis="y")
    plt.show()

    # ===============================
    # 5. CUMULATIVE WIN RATE
    # ===============================
    # Number of games played up to and including each row (1, 2, 3, ...).
    # A plain ``range`` does not support ``+ 1`` or element-wise division,
    # so the denominator is built as a Series aligned with the DataFrame.
    games_played = pd.Series(range(1, len(df) + 1), index=df.index)
    # Rows whose ``winner`` is neither 0 nor 1 count as a win for nobody.
    p0_wins = (df["winner"] == 0).cumsum()
    p1_wins = (df["winner"] == 1).cumsum()
    win_rate_p0 = p0_wins / games_played
    win_rate_p1 = p1_wins / games_played
    plt.figure()
    plt.plot(games, win_rate_p0, label="P0 Win Rate")
    plt.plot(games, win_rate_p1, label="P1 Win Rate")
    plt.title("Cumulative Win Rate")
    plt.xlabel("Game index")
    plt.ylabel("Win rate")
    plt.legend()
    plt.grid(True)
    plt.show()
if __name__ == "__main__":
    # Everything after the script name; the first entry is the CSV path.
    cli_args = sys.argv[1:]
    if not cli_args:
        print("Usage: python plot_metrics.py <benchmark_csv>")
        sys.exit(1)
    plot_from_csv(cli_args[0])
|