trioskosmos committed on
Commit
5dfc8e2
·
verified ·
1 Parent(s): 5ff442a

Upload ai/training/train_optimized.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. ai/training/train_optimized.py +281 -0
ai/training/train_optimized.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ import numpy as np
5
+ from sb3_contrib import MaskablePPO
6
+ from sb3_contrib.common.wrappers import ActionMasker
7
+ from stable_baselines3.common.callbacks import BaseCallback, CallbackList, CheckpointCallback
8
+ from stable_baselines3.common.monitor import Monitor
9
+
10
+ # Ensure project root is in path for local imports
11
+ if os.getcwd() not in sys.path:
12
+ sys.path.append(os.getcwd())
13
+
14
+ print(" [Heartbeat] train_optimized.py entry point reached.", flush=True)
15
+
16
+ import argparse
17
+
18
+ import torch
19
+
20
+ # If many workers are used, we keep intra-op threads low to avoid overhead
21
+ if int(os.getenv("TRAIN_CPUS", "4")) <= 4:
22
+ torch.set_num_threads(2)
23
+ else:
24
+ torch.set_num_threads(1)
25
+
26
+ # Fix for Windows DLL loading issues in subprocesses
27
+ if sys.platform == "win32":
28
+ # Add torch lib to DLL search path
29
+ torch_lib_path = os.path.join(os.path.dirname(torch.__file__), "lib")
30
+ if os.path.exists(torch_lib_path):
31
+ os.add_dll_directory(torch_lib_path)
32
+
33
+ # Ensure CUDA_PATH is in environ if found
34
+ if "CUDA_PATH" not in os.environ:
35
+ cuda_path = "C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.2"
36
+ if os.path.exists(cuda_path):
37
+ os.environ["CUDA_PATH"] = cuda_path
38
+ os.environ["PATH"] = os.path.join(cuda_path, "bin") + os.pathsep + os.environ["PATH"]
39
+
40
+ # Import our environment
41
+ from functools import partial
42
+
43
+ from ai.batched_env import BatchedSubprocVecEnv
44
+ from ai.gym_env import LoveLiveCardGameEnv
45
+
46
+
47
class TrainingStatsCallback(BaseCallback):
    """Custom callback for logging win rates and illegal move stats from gym_env.

    Records rolling win rate, average episode reward, and average win turn to the
    SB3 logger, and prints periodic heartbeat/summary lines to the terminal.
    """

    def __init__(self, verbose=0):
        super().__init__(verbose)

    def _on_step(self) -> bool:
        if self.n_calls == 1:
            print(" [Heartbeat] Training loop is active. First step reached!", flush=True)

        infos = self.locals.get("infos")
        if infos:
            # 1. Capture Win Rate from custom env attribute (Legacy/Direct)
            if len(infos) > 0 and "win_rate" in infos[0]:
                avg_win_rate = np.mean([info.get("win_rate", 0) for info in infos])
                self.logger.record("game/win_rate_legacy", avg_win_rate)

            # 1b. Per-game Heartbeat
            for info in infos:
                if "episode" in info:
                    print(
                        f" [Heartbeat] Game completed! Reward: {info['episode']['r']:.2f} | Turns: {info['episode']['l']}",
                        flush=True,
                    )

            # 2. Capture Episode Completion Stats
            episode_infos = [info.get("episode") for info in infos if "episode" in info]
            if episode_infos:
                avg_reward = np.mean([ep["r"] for ep in episode_infos])
                # "turn"/"win" are custom keys; a plain Monitor only records r/l/t,
                # so fall back defensively instead of raising KeyError
                # (Monitor would need info_keywords=("turn", "win") to include them).
                avg_turns = np.mean([ep.get("turn", ep.get("l", 0)) for ep in episode_infos])
                win_count = sum(1 for ep in episode_infos if ep.get("win"))
                win_rate = (win_count / len(episode_infos)) * 100

                self.logger.record("game/avg_episode_reward", avg_reward)
                self.logger.record("game/avg_win_turn", avg_turns)
                self.logger.record("game/win_rate_rolling", win_rate)

                # Periodic summary to terminal (More frequent for visibility)
                if self.n_calls % 256 == 0:
                    print(
                        f" [Stats] Steps: {self.num_timesteps} | Win Rate: {win_rate:.1f}% | Avg Reward: {avg_reward:.2f} | Avg Turn: {avg_turns:.1f}",
                        flush=True,
                    )

        return True
class SaveOnBestWinRateCallback(BaseCallback):
    """Callback to save the model when win rate reaches a new peak.

    Checks every `check_freq` steps; a checkpoint named "best_win_rate_model" is
    written to `save_path` only when the averaged `win_rate` reported in the step
    infos both exceeds the previous best AND clears a minimum threshold (filters
    out noisy early spikes).
    """

    def __init__(self, check_freq: int, save_path: str, verbose=1):
        super().__init__(verbose)
        self.check_freq = check_freq
        self.save_path = save_path
        self.best_win_rate = -np.inf
        self.min_win_rate_threshold = 30.0  # Only save 'best' if above 30% to avoid early noise

    def _init_callback(self) -> None:
        # Make sure the checkpoint directory exists before the first save.
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self) -> bool:
        if self.n_calls % self.check_freq == 0:
            infos = self.locals.get("infos")
            if infos:
                # Envs that don't report "win_rate" contribute 0 to the average.
                avg_win_rate = np.mean([info.get("win_rate", 0) for info in infos])
                if avg_win_rate > self.best_win_rate and avg_win_rate > self.min_win_rate_threshold:
                    self.best_win_rate = avg_win_rate
                    if self.verbose > 0:
                        print(
                            f" [Saving] New Best Win Rate: {avg_win_rate:.1f}%! Progressing towards big moment...",
                            flush=True,
                        )
                    self.model.save(os.path.join(self.save_path, "best_win_rate_model"))
        return True
class SelfPlayUpdateCallback(BaseCallback):
    """Callback to save the model for self-play opponents.

    Every `update_freq` steps, snapshots the current model to
    `save_path/self_play_opponent` so opponent workers can reload it.
    """

    def __init__(self, update_freq: int, save_path: str, verbose=0):
        super().__init__(verbose)
        self.update_freq = update_freq
        self.save_path = save_path

    def _init_callback(self) -> None:
        # Make sure the snapshot directory exists before the first save.
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self) -> bool:
        if self.n_calls % self.update_freq == 0:
            if self.verbose > 0:
                print(" [Self-Play] Updating opponent model...", flush=True)
            self.model.save(os.path.join(self.save_path, "self_play_opponent"))
        return True
def create_env(rank, usage=0.5, deck_type="random_verified", opponent_type="random"):
    """Build one monitored, action-masked game environment for worker `rank`."""
    game_env = LoveLiveCardGameEnv(
        target_cpu_usage=usage, deck_type=deck_type, opponent_type=opponent_type
    )
    monitored = Monitor(game_env)
    masked = ActionMasker(monitored, lambda wrapped: wrapped.unwrapped.action_masks())

    # Distinct seed per worker so rollouts diverge across the vectorized envs.
    masked.reset(seed=42 + rank)
    return masked
def train():
    """Configure hardware limits, build vectorized envs, and run MaskablePPO training.

    All tunables come from environment variables (TRAIN_CPUS, TRAIN_USAGE,
    TRAIN_DECK, TRAIN_GPU_USAGE, TRAIN_BATCH_SIZE, TRAIN_EPOCHS, TRAIN_STEPS,
    TRAIN_OPPONENT, LOAD_MODEL). Supports `--dry-run` to verify worker startup
    and exit without training.
    """
    # CLI flags parsed up front; --dry-run still exercises worker startup below,
    # which is the point of the dry run.
    parser = argparse.ArgumentParser()
    parser.add_argument("--dry-run", action="store_true", help="Initialize and exit")
    args, _unknown = parser.parse_known_args()

    # 1. Hardware Constraints Setup
    num_cpu = int(os.getenv("TRAIN_CPUS", "4"))
    usage = float(os.getenv("TRAIN_USAGE", "0.5"))
    deck_type = os.getenv("TRAIN_DECK", "random_verified")
    gpu_usage = float(os.getenv("TRAIN_GPU_USAGE", "0.7"))
    batch_size = int(os.getenv("TRAIN_BATCH_SIZE", "256"))
    n_epochs = int(os.getenv("TRAIN_EPOCHS", "10"))
    n_steps = int(os.getenv("TRAIN_STEPS", "2048"))
    opponent_type = os.getenv("TRAIN_OPPONENT", "random")

    # Single device decision, shared by the load and fresh-init paths
    # (fix: fresh init previously hard-coded "cuda", crashing on CPU-only hosts).
    device = "cuda" if torch.cuda.is_available() else "cpu"

    if torch.cuda.is_available() and gpu_usage < 1.0:
        try:
            print(f"Limiting GPU memory usage to {int(gpu_usage * 100)}%...", flush=True)
            torch.cuda.set_per_process_memory_fraction(gpu_usage)
        except Exception as e:
            print(f"Warning: Could not set GPU memory fraction: {e}. Proceeding without limit.", flush=True)

    print(
        f"Initializing {num_cpu} parallel environments ({deck_type}) with opponent {opponent_type} and {int(usage * 100)}% per-core throttle...",
        flush=True,
    )

    # Create Vectorized Environment
    try:
        # Optimization: Workers always use "random" internally because BatchedSubprocVecEnv intercepts and runs the real opponent
        # This prevents workers from importing torch/sb3 and saves GBs of RAM.
        env_fns = [
            partial(create_env, rank=i, usage=usage, deck_type=deck_type, opponent_type="random")
            for i in range(num_cpu)
        ]

        # Use our new Batched inference environment
        opponent_path = os.path.join(os.getcwd(), "checkpoints", "self_play_opponent.zip")
        env = BatchedSubprocVecEnv(
            env_fns, opponent_model_path=opponent_path if opponent_type == "self_play" else None
        )
        print("Batched workers initialized! Starting training loop...", flush=True)

    except Exception as e:
        print(f"CRITICAL ERROR during worker initialization: {e}", flush=True)
        import traceback

        traceback.print_exc()
        return

    # 2. Model Configuration: resume from LOAD_MODEL if set and present,
    # otherwise fall through to a fresh model below.
    load_path = os.getenv("LOAD_MODEL")
    model = None
    if load_path and os.path.exists(load_path):
        try:
            print(f" [LOAD] Loading existing model from {load_path}...", flush=True)
            model = MaskablePPO.load(load_path, env=env, device=device)
            print(" [LOAD] Model loaded successfully.", flush=True)
        except ValueError as val_err:
            # A mismatched observation space means the checkpoint predates an
            # engine change; start fresh rather than crash.
            if "Observation spaces do not match" in str(val_err):
                print(
                    f" [WARNING] Checkpoint {load_path} has incompatible observation space (likely from an older engine version).",
                    flush=True,
                )
                print(" [WARNING] Skipping load and starting fresh to maintain stability.", flush=True)
                model = None  # Force fresh start
            else:
                raise val_err
        except Exception as load_err:
            print(f" [CRITICAL ERROR] Failed to load checkpoint: {load_err}", flush=True)
            import traceback

            traceback.print_exc()
            env.close()
            sys.exit(1)

    if model is None:
        print(" [INFO] Initializing fresh MaskablePPO model...", flush=True)
        model = MaskablePPO(
            "MlpPolicy",
            env,
            verbose=0,
            gamma=0.99,
            learning_rate=3e-4,
            n_steps=n_steps,
            batch_size=batch_size,
            n_epochs=n_epochs,
            tensorboard_log="./logs/ppo_tensorboard/",
            device=device,
        )

    if args.dry_run:
        print(" [Dry Run] Workers initialized successfully. Exiting.", flush=True)
        env.close()
        return

    # Checkpoint Callback: ~every 200k total timesteps regardless of worker count.
    checkpoint_callback = CheckpointCallback(
        save_freq=max(1, 200000 // num_cpu), save_path="./checkpoints/", name_prefix="lovelive_ppo_checkpoint"
    )

    # 3. Learning Loop
    stats_callback = TrainingStatsCallback()
    best_rate_callback = SaveOnBestWinRateCallback(check_freq=1024, save_path="./checkpoints/")
    self_play_callback = SelfPlayUpdateCallback(update_freq=20000, save_path="./checkpoints/")

    callback_list = CallbackList([checkpoint_callback, stats_callback, best_rate_callback, self_play_callback])

    try:
        print(f"Starting Long-Running Training on {num_cpu} workers (Usage: {usage * 100}%)...", flush=True)
        model.learn(total_timesteps=2_000_000_000, progress_bar=False, callback=callback_list)

        # Save Final Model
        os.makedirs("checkpoints", exist_ok=True)
        model.save("checkpoints/lovelive_ppo_optimized")
        print("Training Complete. Model Saved.")

    except KeyboardInterrupt:
        print("\nTraining interrupted. Saving current progress...")
        model.save("checkpoints/lovelive_ppo_interrupted")
    finally:
        # Always tear down worker subprocesses.
        env.close()
# Standard script entry guard: run training only when executed directly.
if __name__ == "__main__":
    train()