# File size: 2,655 Bytes
# 78d0ad7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# stratego/benchmarking/plot_metrics.py

import pandas as pd
import matplotlib.pyplot as plt
import sys


def plot_from_csv(csv_path: str, rolling_window: int = 3):
    """Plot benchmark metrics from a CSV file, one figure per metric.

    Five figures are created in order, each displayed with its own
    ``plt.show()`` call:

    1. Turns per game.
    2. Invalid moves per game for each player.
    3. Raw turns overlaid with a rolling mean of turns.
    4. Bar chart of how often each ``game_end_reason`` value occurs.
    5. Cumulative win rate for ``winner == 0`` and ``winner == 1``.

    Args:
        csv_path: Path to the CSV file. It must contain the columns
            ``turns``, ``invalid_moves_p0``, ``invalid_moves_p1``,
            ``game_end_reason`` and ``winner``.
        rolling_window: Number of consecutive games averaged in the
            rolling-mean figure.

    Raises:
        KeyError: If any of the required columns is missing.
    """
    df = pd.read_csv(csv_path)

    # Fail before any figure is shown instead of halfway through the plots.
    required_columns = (
        "turns",
        "invalid_moves_p0",
        "invalid_moves_p1",
        "game_end_reason",
        "winner",
    )
    missing = [col for col in required_columns if col not in df.columns]
    if missing:
        raise KeyError(f"Missing required column(s) in {csv_path}: {missing}")

    # Zero-based game index, used as the x-axis of every line plot.
    games = range(len(df))

    # ===============================
    # 1. GAME LENGTH PER GAME
    # ===============================
    plt.figure()
    plt.plot(games, df["turns"], marker="o")
    plt.title("Game Length per Game")
    plt.xlabel("Game index")
    plt.ylabel("Turns")
    plt.grid(True)
    plt.show()

    # ===============================
    # 2. INVALID MOVES PER GAME
    # ===============================
    plt.figure()
    plt.plot(games, df["invalid_moves_p0"], label="P0 Invalid Moves", marker="o")
    plt.plot(games, df["invalid_moves_p1"], label="P1 Invalid Moves", marker="o")
    plt.title("Invalid Moves per Game")
    plt.xlabel("Game index")
    plt.ylabel("Invalid moves")
    plt.legend()
    plt.grid(True)
    plt.show()

    # ===============================
    # 3. ROLLING AVERAGE (STALLING)
    # ===============================
    # The first ``rolling_window - 1`` entries are NaN (pandas default
    # min_periods), so the smoothed line starts later than the raw one.
    rolling_turns = df["turns"].rolling(window=rolling_window).mean()

    plt.figure()
    plt.plot(games, df["turns"], alpha=0.3, label="Raw turns")
    plt.plot(games, rolling_turns, linewidth=3, label=f"Rolling avg ({rolling_window})")
    plt.title("Game Stalling (Rolling Average of Turns)")
    plt.xlabel("Game index")
    plt.ylabel("Turns")
    plt.legend()
    plt.grid(True)
    plt.show()

    # ===============================
    # 4. TERMINATION REASONS
    # ===============================
    termination_counts = df["game_end_reason"].value_counts()

    plt.figure()
    termination_counts.plot(kind="bar")
    plt.title("Game Termination Reasons")
    plt.xlabel("Reason")
    plt.ylabel("Number of games")
    plt.grid(axis="y")
    plt.show()

    # ===============================
    # 5. CUMULATIVE WIN RATE
    # ===============================
    # A ``range`` does not support ``+ 1``; build the 1-based count of games
    # played so far as a Series aligned with the cumulative win counts.
    games_played = pd.Series(range(1, len(df) + 1), index=df.index)

    p0_wins = (df["winner"] == 0).cumsum()
    p1_wins = (df["winner"] == 1).cumsum()

    win_rate_p0 = p0_wins / games_played
    win_rate_p1 = p1_wins / games_played

    plt.figure()
    plt.plot(games, win_rate_p0, label="P0 Win Rate")
    plt.plot(games, win_rate_p1, label="P1 Win Rate")
    plt.title("Cumulative Win Rate")
    plt.xlabel("Game index")
    plt.ylabel("Win rate")
    plt.legend()
    plt.grid(True)
    plt.show()


if __name__ == "__main__":
    # The benchmark CSV path is the first command-line argument; without it,
    # print the usage line and exit with a non-zero status.
    cli_args = sys.argv[1:]
    if not cli_args:
        print("Usage: python plot_metrics.py <benchmark_csv>")
        sys.exit(1)

    plot_from_csv(cli_args[0])