DarshanScripts committed on
Commit
ec70901
·
verified ·
1 Parent(s): 4dd791c

Upload stratego\benchmarking\plot_metrics.py with huggingface_hub

Browse files
stratego//benchmarking//plot_metrics.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # stratego/benchmarking/plot_metrics.py
2
+
3
+ import pandas as pd
4
+ import matplotlib.pyplot as plt
5
+ import sys
6
+
7
+
8
def cumulative_win_rate(winners, player):
    """Return the running win rate of ``player`` after each game.

    Args:
        winners: Per-game winner ids in play order (a list, Series, or any
            other sequence accepted by ``pd.Series``).
        player: The id that is compared against each entry of ``winners``.

    Returns:
        A ``pd.Series`` whose i-th value is the fraction of the first
        ``i + 1`` games in which the winner equalled ``player``.
    """
    # Drop any incoming index so the numerator and denominator line up
    # position-by-position.
    won = pd.Series(winners).reset_index(drop=True) == player
    # Explicit 1..n denominator: a ``range`` object cannot take part in
    # arithmetic (``range(n) + 1`` raises TypeError).
    games_played = pd.Series(range(1, len(won) + 1), index=won.index)
    return won.cumsum() / games_played


def plot_from_csv(csv_path: str, rolling_window: int = 3):
    """Read a benchmark CSV and display five summary figures in sequence.

    The CSV must provide the columns ``turns``, ``invalid_moves_p0``,
    ``invalid_moves_p1``, ``game_end_reason`` and ``winner``; each row is
    treated as one game, in file order. Every figure is shown with
    ``plt.show()`` before the next one is built.

    Args:
        csv_path: Path of the CSV file handed to ``pd.read_csv``.
        rolling_window: Window size used for the rolling mean of ``turns``.
    """
    df = pd.read_csv(csv_path)

    # Zero-based game index shared as the x-axis of every line plot.
    games = list(range(len(df)))

    # ===============================
    # 1. GAME LENGTH PER GAME
    # ===============================
    plt.figure()
    plt.plot(games, df["turns"], marker="o")
    plt.title("Game Length per Game")
    plt.xlabel("Game index")
    plt.ylabel("Turns")
    plt.grid(True)
    plt.show()

    # ===============================
    # 2. INVALID MOVES PER GAME
    # ===============================
    plt.figure()
    plt.plot(games, df["invalid_moves_p0"], label="P0 Invalid Moves", marker="o")
    plt.plot(games, df["invalid_moves_p1"], label="P1 Invalid Moves", marker="o")
    plt.title("Invalid Moves per Game")
    plt.xlabel("Game index")
    plt.ylabel("Invalid moves")
    plt.legend()
    plt.grid(True)
    plt.show()

    # ===============================
    # 3. ROLLING AVERAGE (STALLING)
    # ===============================
    # The first ``rolling_window - 1`` entries are NaN and are simply not
    # drawn by matplotlib.
    rolling_turns = df["turns"].rolling(window=rolling_window).mean()

    plt.figure()
    plt.plot(games, df["turns"], alpha=0.3, label="Raw turns")
    plt.plot(games, rolling_turns, linewidth=3, label=f"Rolling avg ({rolling_window})")
    plt.title("Game Stalling (Rolling Average of Turns)")
    plt.xlabel("Game index")
    plt.ylabel("Turns")
    plt.legend()
    plt.grid(True)
    plt.show()

    # ===============================
    # 4. TERMINATION REASONS
    # ===============================
    termination_counts = df["game_end_reason"].value_counts()

    plt.figure()
    termination_counts.plot(kind="bar")
    plt.title("Game Termination Reasons")
    plt.xlabel("Reason")
    plt.ylabel("Number of games")
    plt.grid(axis="y")
    plt.show()

    # ===============================
    # 5. CUMULATIVE WIN RATE
    # ===============================
    # Rows whose winner is neither 0 nor 1 still count as games played, so
    # the two curves need not sum to 1.
    win_rate_p0 = cumulative_win_rate(df["winner"], 0)
    win_rate_p1 = cumulative_win_rate(df["winner"], 1)

    plt.figure()
    plt.plot(games, win_rate_p0, label="P0 Win Rate")
    plt.plot(games, win_rate_p1, label="P1 Win Rate")
    plt.title("Cumulative Win Rate")
    plt.xlabel("Game index")
    plt.ylabel("Win rate")
    plt.legend()
    plt.grid(True)
    plt.show()
83
+
84
+
85
if __name__ == "__main__":
    # Everything after the script name; the first entry is the CSV path.
    cli_args = sys.argv[1:]
    if not cli_args:
        print("Usage: python plot_metrics.py <benchmark_csv>")
        sys.exit(1)

    plot_from_csv(cli_args[0])