trioskosmos committed on
Commit
c6d22b8
·
verified ·
1 Parent(s): a8f17ae

Upload ai/training/ppo_self_play.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. ai/training/ppo_self_play.py +129 -0
ai/training/ppo_self_play.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import os
3
+ import sys
4
+
5
+ import numpy as np
6
+ from sb3_contrib import MaskablePPO
7
+
8
+ # Ensure project root is in path
9
+ sys.path.append(os.getcwd())
10
+
11
+ from ai.vector_env import VectorGameState
12
+
13
+
14
def run_self_play():
    """Run batched PPO-vs-PPO self-play games and write a win/draw summary.

    Loads the most recent MaskablePPO checkpoint from ``checkpoints/vector/``
    (falling back to a fixed path when none is found), drives both players of
    a raw ``VectorGameState`` with the same policy, and records per-player
    wins and draws to ``benchmarks/self_play_results.txt``.

    Side effects: reads a model checkpoint from disk, prints progress to
    stdout, and writes the results file (creating ``benchmarks/`` if needed).
    """
    print("--- PPO Self-Play Verification ---")

    # 1. Config
    USE_LATEST = True
    # Fallback checkpoint used when no newer .zip is found below.
    MODEL_PATH = "checkpoints/vector/interrupted_model.zip"
    BATCH_SIZE = 50  # games simulated in parallel per batch
    N_GAMES = 100    # total games requested (rounded up to whole batches)

    if USE_LATEST:
        list_of_files = glob.glob("checkpoints/vector/*.zip")
        if list_of_files:
            # Pick the most recently modified checkpoint.
            MODEL_PATH = max(list_of_files, key=os.path.getmtime)
    print(f"Model: {MODEL_PATH}")

    # 2. Load Model (CPU is sufficient for batched inference here).
    print("Loading PPO Model...")
    model = MaskablePPO.load(MODEL_PATH, device="cpu")

    # 3. Init Raw VectorGameState (No Adapter, manually stepping for self-play)
    print(f"Initializing {BATCH_SIZE} environments...")
    env = VectorGameState(num_envs=BATCH_SIZE)
    env.reset()

    total_wins_p0 = 0
    total_wins_p1 = 0
    total_draws = 0
    games_played = 0

    # Ceiling division so N_GAMES not divisible by BATCH_SIZE still runs.
    num_batches = (N_GAMES + BATCH_SIZE - 1) // BATCH_SIZE

    for b in range(num_batches):
        env.reset()
        active = np.ones(BATCH_SIZE, dtype=bool)

        # Track if game ended in this batch loop
        batch_wins_p0 = np.zeros(BATCH_SIZE, dtype=bool)
        batch_wins_p1 = np.zeros(BATCH_SIZE, dtype=bool)
        batch_draws = np.zeros(BATCH_SIZE, dtype=bool)

        step_count = 0
        while np.any(active) and step_count < 150:  # Slightly longer for self-play
            # Player 0 Turn (Perspective 0)
            obs0 = env.get_observations(player_id=0)
            masks0 = env.get_action_masks(player_id=0)
            num_legal = np.sum(masks0[0])
            if step_count == 0:
                print(f" Step {step_count}: {num_legal} legal actions.")

            act0_raw, _ = model.predict(obs0, action_masks=masks0, deterministic=True)
            act0 = act0_raw.astype(np.int32)

            # Player 1 Turn (Perspective 1)
            obs1 = env.get_observations(player_id=1)
            masks1 = env.get_action_masks(player_id=1)
            act1_raw, _ = model.predict(obs1, action_masks=masks1, deterministic=True)
            act1 = act1_raw.astype(np.int32)

            # Step both!
            env.step(act0, opp_actions=act1)

            # Detailed Logging for Turn 1-5
            if step_count < 5:
                # Get more context for P0 (env index 0 only, as a sample).
                stg0 = env.batch_stage[0]
                sc0 = env.batch_scores[0]
                ph0 = env.batch_global_ctx[0, 8]
                print(f" T{step_count + 1} | P0 Act: {act0[0]} | Stage: {stg0} | Score: {sc0} | Ph: {ph0}")

            # Check for dones (Custom logic since no adapter)
            for i in range(BATCH_SIZE):
                if active[i]:
                    sc0 = env.batch_scores[i]
                    sc1 = env.opp_scores[i]

                    # Game over on first-to-3 or on the (shortened) turn cap.
                    is_done = False
                    if sc0 >= 3 or sc1 >= 3:
                        is_done = True
                    elif env.turn >= 50:  # Shorten for speed in debug
                        is_done = True

                    if is_done:
                        active[i] = False
                        if sc0 >= 3 and sc0 > sc1:
                            batch_wins_p0[i] = True
                        elif sc1 >= 3 and sc1 > sc0:
                            batch_wins_p1[i] = True
                        else:
                            batch_draws[i] = True

            step_count += 1

        # Fix: games still active when the 150-step safety cap trips were
        # previously never tallied, so wins + draws could undercount
        # games_played. Count timed-out games as draws.
        batch_draws |= active

        total_wins_p0 += np.sum(batch_wins_p0)
        total_wins_p1 += np.sum(batch_wins_p1)
        total_draws += np.sum(batch_draws)
        games_played += BATCH_SIZE
        print(f"Batch {b + 1} done. P0 Wins: {total_wins_p0}, P1 Wins: {total_wins_p1}, Draws: {total_draws}")

    # Fix: open() fails with FileNotFoundError if benchmarks/ does not exist.
    os.makedirs("benchmarks", exist_ok=True)
    with open("benchmarks/self_play_results.txt", "w") as f:
        f.write("\n--- Final Self-Play Results ---\n")
        f.write(f"Total Games: {games_played}\n")
        f.write(f"Player 0 Wins: {total_wins_p0}\n")
        f.write(f"Player 1 Wins: {total_wins_p1}\n")
        f.write(f"Draws: {total_draws}\n")

        # Analyze: any win at all proves the policy can close out a game.
        if total_wins_p0 + total_wins_p1 > 0:
            f.write("RESULT: Agent CAN win games when playing against itself.\n")
        else:
            f.write("RESULT: Agent fails to win even in self-play. Policy likely broken.\n")

    print("Results saved to benchmarks/self_play_results.txt")
126
+
127
+
128
# Script entry point: run the self-play benchmark when executed directly.
if __name__ == "__main__":
    run_self_play()