# Author: DarshanScripts
# Commit message: "Upload stratego/benchmarking/plot_metrics.py with huggingface_hub"
# Commit: 78d0ad7 (verified)
# stratego/benchmarking/plot_metrics.py
import pandas as pd
import matplotlib.pyplot as plt
import sys
def _finish_figure(
    title: str,
    xlabel: str,
    ylabel: str,
    *,
    legend: bool = False,
    grid_axis: str = "both",
) -> None:
    """Title and label the current figure, add grid (and legend), then show it."""
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    if legend:
        plt.legend()
    # Pass visible=True explicitly: when it is omitted (and no line kwargs are
    # given), pyplot.grid *toggles* the grid instead of turning it on.
    plt.grid(True, axis=grid_axis)
    plt.show()


def plot_from_csv(csv_path: str, rolling_window: int = 3) -> None:
    """Plot per-game benchmark metrics read from a CSV file.

    Five figures are created and shown one after another with ``plt.show()``
    (whether each call blocks depends on the matplotlib backend):

    1. game length (``turns`` column) per game;
    2. invalid moves per game for each player (``invalid_moves_p0`` and
       ``invalid_moves_p1`` columns);
    3. raw turns overlaid with their rolling mean, to spot stalling;
    4. a bar chart of how often each ``game_end_reason`` value occurs;
    5. the cumulative win rate of each player, derived from the ``winner``
       column compared against the integers 0 and 1.

    Games are plotted in file order; the x-axis is the 0-based row index.

    Args:
        csv_path: Path to the benchmark CSV containing the columns above.
        rolling_window: Number of games in the rolling-average window used
            by figure 3. Must be at least 1.

    Raises:
        ValueError: If ``rolling_window`` is smaller than 1.
        KeyError: If one of the required columns is missing from the CSV.
    """
    if rolling_window < 1:
        raise ValueError(f"rolling_window must be >= 1, got {rolling_window}")

    df = pd.read_csv(csv_path)
    games = range(len(df))  # x-axis: 0-based game index

    # 1. Game length per game
    plt.figure()
    plt.plot(games, df["turns"], marker="o")
    _finish_figure("Game Length per Game", "Game index", "Turns")

    # 2. Invalid moves per game
    plt.figure()
    plt.plot(games, df["invalid_moves_p0"], label="P0 Invalid Moves", marker="o")
    plt.plot(games, df["invalid_moves_p1"], label="P1 Invalid Moves", marker="o")
    _finish_figure("Invalid Moves per Game", "Game index", "Invalid moves", legend=True)

    # 3. Rolling average (stalling). Kept as a local Series instead of a new
    # DataFrame column. With the default min_periods, the first
    # rolling_window - 1 points are NaN and therefore not drawn.
    rolling_turns = df["turns"].rolling(window=rolling_window).mean()
    plt.figure()
    plt.plot(games, df["turns"], alpha=0.3, label="Raw turns")
    plt.plot(games, rolling_turns, linewidth=3, label=f"Rolling avg ({rolling_window})")
    _finish_figure(
        "Game Stalling (Rolling Average of Turns)", "Game index", "Turns", legend=True
    )

    # 4. Termination reasons
    termination_counts = df["game_end_reason"].value_counts()
    plt.figure()
    termination_counts.plot(kind="bar")
    _finish_figure("Game Termination Reasons", "Reason", "Number of games", grid_axis="y")

    # 5. Cumulative win rate. The expanding mean of the 0/1 win indicator at
    # game i equals (wins in games 0..i) / (i + 1), and it stays index-aligned
    # with df, so no separate denominator has to be built.
    win_rate_p0 = (df["winner"] == 0).astype(float).expanding().mean()
    win_rate_p1 = (df["winner"] == 1).astype(float).expanding().mean()
    plt.figure()
    plt.plot(games, win_rate_p0, label="P0 Win Rate")
    plt.plot(games, win_rate_p1, label="P1 Win Rate")
    _finish_figure("Cumulative Win Rate", "Game index", "Win rate", legend=True)
if __name__ == "__main__":
    # CLI: python plot_metrics.py <benchmark_csv> [rolling_window]
    if len(sys.argv) < 2:
        # Usage errors go to stderr so they don't mix with normal output.
        print("Usage: python plot_metrics.py <benchmark_csv> [rolling_window]", file=sys.stderr)
        sys.exit(1)

    cli_kwargs = {}
    if len(sys.argv) > 2:
        # Optional override; when absent, plot_from_csv's own default applies.
        try:
            cli_kwargs["rolling_window"] = int(sys.argv[2])
        except ValueError:
            print(f"rolling_window must be an integer, got {sys.argv[2]!r}", file=sys.stderr)
            sys.exit(1)

    plot_from_csv(sys.argv[1], **cli_kwargs)