trioskosmos committed on
Commit
77b2fc5
·
verified ·
1 Parent(s): fa71ce1

Upload ai/training/train_vectorized.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. ai/training/train_vectorized.py +585 -0
ai/training/train_vectorized.py ADDED
@@ -0,0 +1,585 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import time
4
+
5
+ # Immediate feedback
6
+ print(" [Init] Python process started. Loading libraries...")
7
+
8
+ print(" [Init] Loading Pytorch...", end="", flush=True)
9
+ import torch
10
+ import torch as th
11
+ import torch.nn.functional as F
12
+
13
+ print(" Done.")
14
+
15
+ print(" [Init] Loading Gymnasium & SB3...", end="", flush=True)
16
+ import glob
17
+ import warnings
18
+
19
+ import numpy as np
20
+ from gymnasium import spaces
21
+ from sb3_contrib import MaskablePPO
22
+ from stable_baselines3.common.callbacks import BaseCallback, CheckpointCallback
23
+ from stable_baselines3.common.utils import explained_variance
24
+ from tqdm import tqdm
25
+
26
+ print(" Done.")
27
+
28
+ # Filter Numba warning
29
+ warnings.filterwarnings("ignore", category=RuntimeWarning, message="nopython is set for njit")
30
+
31
+ # Ensure project root is in path
32
+ sys.path.append(os.getcwd())
33
+
34
+ print(" [Init] Loading LovecaSim Vector Engine...", end="", flush=True)
35
+ from ai.environments.vec_env_adapter import VectorEnvAdapter
36
+ from ai.utils.loveca_features_extractor import LovecaFeaturesExtractor
37
+
38
+ print(" Done.")
39
+
40
+
41
class TimeCheckpointCallback(BaseCallback):
    """
    Save the model every N minutes of wall-clock time.

    The checkpoint is always written to the same file name
    (``<name_prefix>_time_auto`` inside ``save_path``), so each save
    overwrites the previous one.

    :param save_freq_minutes: Minimum number of minutes between two saves.
    :param save_path: Directory in which the checkpoint file is written.
    :param name_prefix: Prefix of the checkpoint file name.
    :param verbose: When > 0, print a message after every save.
    """

    def __init__(self, save_freq_minutes: float, save_path: str, name_prefix: str, verbose: int = 0):
        super().__init__(verbose)
        self.save_freq_seconds = save_freq_minutes * 60
        self.save_path = save_path
        self.name_prefix = name_prefix
        # The timer starts when the callback is constructed, not when training starts.
        self.last_time_save = time.time()

    def _on_step(self) -> bool:
        """Save the model if the configured interval has elapsed; always continue training."""
        if (time.time() - self.last_time_save) > self.save_freq_seconds:
            save_path = os.path.join(self.save_path, f"{self.name_prefix}_time_auto")
            self.model.save(save_path)
            if self.verbose > 0:
                # Report the configured interval instead of a hard-coded "3 minutes".
                interval_minutes = self.save_freq_seconds / 60
                print(f" [Save] Model auto-saved after {interval_minutes:g} minutes to {save_path}")
            # Restart the timer only after a save actually happened.
            self.last_time_save = time.time()
        return True
61
+
62
+
63
class ModelSnapshotCallback(BaseCallback):
    """
    Saves a 'Model Snapshot' every X minutes:
    - model.zip
    - verified_card_pool.json (Context)
    - snapshot_meta.json (Architecture/Config)

    Each snapshot lives in its own ``<timestamp>_<steps>_steps`` sub-directory
    of ``save_path``. After every snapshot only the newest ``max_keep``
    sub-directories of ``save_path`` are kept.

    :param save_freq_minutes: Minimum number of minutes between two snapshots.
    :param save_path: Root directory that holds the snapshot sub-directories.
    :param verbose: When > 0, print progress messages.
    :param max_keep: Number of snapshot directories to retain. Values below 1
        disable pruning.
    """

    def __init__(self, save_freq_minutes: float, save_path: str, verbose=0, max_keep: int = 5):
        super().__init__(verbose)
        self.save_freq_minutes = save_freq_minutes
        self.save_path = save_path
        self.max_keep = max_keep
        self.last_save_time = time.time()
        # Create the snapshot root up front so saving and pruning use the same directory.
        os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self) -> bool:
        """Take a snapshot if the configured interval has elapsed; always continue training."""
        if time.time() - self.last_save_time > self.save_freq_minutes * 60:
            # Reset the timer first so a slow save does not shorten the next interval.
            self.last_save_time = time.time()
            self._save_snapshot()
        return True

    def _save_snapshot(self):
        """Write model, card pool and metadata into a fresh snapshot directory, then prune old ones."""
        import json
        import shutil

        timestamp = time.strftime("%Y%m%d_%H%M%S")
        steps = self.num_timesteps
        snapshot_name = f"{timestamp}_{steps}_steps"
        # Use the configured root (previously hard-coded to "historiccheckpoints"),
        # so snapshots are written where _prune_snapshots looks for them.
        snapshot_dir = os.path.join(self.save_path, snapshot_name)

        if self.verbose > 0:
            print(f" [Snapshot] Saving to {snapshot_dir}...")

        os.makedirs(snapshot_dir, exist_ok=True)

        # 1. Save Model
        model_path = os.path.join(snapshot_dir, "model.zip")
        self.model.save(model_path)

        # 2. Save Card Pool (Context) - best effort, read relative to the current working directory.
        try:
            shutil.copy("verified_card_pool.json", os.path.join(snapshot_dir, "verified_card_pool.json"))
        except OSError as e:
            print(f" [Snapshot] Warning: Could not copy card pool: {e}")

        # 3. Save Metadata (Architecture)
        meta = {
            "timestamp": timestamp,
            "timesteps": int(steps),
            "obs_dim": int(self.model.observation_space.shape[0]),
            "action_space_size": int(self.model.action_space.n),
            "features": ["GlobalVolumes", "LiveZone", "Traits", "TurnNumber"],
            "notes": "Generated by ModelSnapshotCallback",
        }
        try:
            with open(os.path.join(snapshot_dir, "snapshot_meta.json"), "w") as f:
                json.dump(meta, f, indent=2)
        except (OSError, TypeError, ValueError) as e:
            print(f" [Snapshot] Warning: Could not save meta: {e}")

        # 4. Limit to the newest `max_keep` snapshots
        self._prune_snapshots()

    def _prune_snapshots(self):
        """
        Delete the oldest sub-directories of ``save_path`` beyond ``max_keep``.

        NOTE: every sub-directory of ``save_path`` is treated as a snapshot,
        so the directory should not be shared with unrelated data.
        """
        import shutil

        search_dir = self.save_path
        if self.max_keep < 1 or not os.path.exists(search_dir):
            return

        try:
            subdirs = [
                os.path.join(search_dir, d)
                for d in os.listdir(search_dir)
                if os.path.isdir(os.path.join(search_dir, d))
            ]
            # Oldest first, based on os.path.getctime.
            subdirs.sort(key=os.path.getctime)
        except OSError as e:
            print(f" [Snapshot] Warning: Pruning failed: {e}")
            return

        excess = len(subdirs) - self.max_keep
        if excess <= 0:
            return

        for d in subdirs[:excess]:
            try:
                shutil.rmtree(d)
                if self.verbose > 0:
                    print(f" [Snapshot] Pruned old snapshot: {d}")
            except OSError as e:
                # Keep going: one undeletable directory must not block the others.
                print(f" [Snapshot] Warning: Failed to prune {d}: {e}")
165
+
166
+
167
class DetailedStatusCallback(BaseCallback):
    """
    Print phase transitions (rollout collection vs. PPO optimization),
    the collection throughput and, when CUDA is available, VRAM usage.
    """

    def __init__(self, verbose=0):
        super().__init__(verbose)
        # Wall-clock time at which the current rollout collection began.
        self.collection_start_time = 0.0

    def _on_rollout_start(self) -> None:
        """Record the start time of the rollout and announce the collection phase."""
        self.collection_start_time = time.time()
        print(f"\n [Phase] Starting Rollout Collection (Steps: {self.model.n_steps})...")

    def _on_rollout_end(self) -> None:
        """Report collection duration/throughput and announce the optimization phase."""
        elapsed = time.time() - self.collection_start_time
        collected = self.model.n_envs * self.model.n_steps
        # Guard against a zero-length interval to avoid dividing by zero.
        if elapsed > 0:
            rate = collected / elapsed
        else:
            rate = 0

        print(f" [Phase] Collection Complete. Duration: {elapsed:.2f}s ({rate:.0f} FPS)")

        # This hook runs right before the policy update.
        print(f" [Phase] Starting PPO Optimization (Epochs: {self.model.n_epochs}, Batch: {self.model.batch_size})...")

        if not torch.cuda.is_available():
            return

        gib = 1024**3
        vram_allocated = torch.cuda.memory_allocated() / gib
        vram_reserved = torch.cuda.memory_reserved() / gib
        print(f" [VRAM] Allocated: {vram_allocated:.2f} GB | Reserved: {vram_reserved:.2f} GB")
        print(" [Info] Optimization may take time if batch size is large. Please wait...")

    def _on_step(self) -> bool:
        """Per-step hook required by the base class; never stops training."""
        return True
206
+
207
+
208
class TrainingStatsCallback(BaseCallback):
    """
    Minimal statistics logger for vectorized training: records the mean
    reward and length of the episodes reported in the current step's infos.
    """

    def __init__(self, verbose=0):
        super().__init__(verbose)

    def _on_step(self) -> bool:
        """Log episode reward/length means when any info dict carries an 'episode' entry."""
        step_infos = self.locals.get("infos")
        if not step_infos:
            return True

        # 'win_rate' is not emitted by the vector env by default, so rely on
        # the 'episode' entries instead.
        finished = [info.get("episode") for info in step_infos if "episode" in info]
        if not finished:
            return True

        rewards = [episode["r"] for episode in finished]
        lengths = [episode["l"] for episode in finished]
        self.logger.record("rollout/ep_rew_mean", np.mean(rewards))
        self.logger.record("rollout/ep_len_mean", np.mean(lengths))
        return True
229
+
230
+
231
class ProgressMaskablePPO(MaskablePPO):
    """
    MaskablePPO with a tqdm progress bar during the optimization phase.

    ``train`` re-implements the PPO update loop with two additions:
    a persistent per-minibatch progress bar and CUDA automatic mixed
    precision (autocast + GradScaler).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Created lazily on the first train() call and reused afterwards.
        self.optimization_pbar = None

    def train(self) -> None:
        """
        Update policy using the currently gathered rollout buffer.
        """
        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)
        # Update optimizer learning rate
        self._update_learning_rate(self.policy.optimizer)
        # Compute current clip range
        clip_range = self.clip_range(self._current_progress_remaining)  # type: ignore[operator]
        # Optional: clip range for the value function
        if self.clip_range_vf is not None:
            clip_range_vf = self.clip_range_vf(self._current_progress_remaining)  # type: ignore[operator]

        entropy_losses = []
        pg_losses, value_losses = [], []
        clip_fractions = []

        continue_training = True

        # AMP: only meaningful when the policy itself lives on a CUDA device.
        # Keying on the model device (not merely on CUDA being installed) keeps
        # a CPU-placed model on a GPU machine from tripping the gradient scaler.
        use_amp = self.device.type == "cuda"
        scaler = getattr(self, "scaler", None)  # may be missing on models restored from old checkpoints
        if scaler is None or scaler.is_enabled() != use_amp:
            self.scaler = th.cuda.amp.GradScaler(enabled=use_amp)

        # Persistent TQDM progress bar.
        # The rollout buffer yields minibatches over buffer_size * n_envs samples,
        # and the last (possibly smaller) minibatch is yielded too, hence the ceil.
        n_samples = self.rollout_buffer.buffer_size * self.rollout_buffer.n_envs
        batches_per_epoch = -(-n_samples // self.batch_size)
        total_steps = self.n_epochs * batches_per_epoch

        if self.optimization_pbar is None:
            self.optimization_pbar = tqdm(total=total_steps, desc="Optimization", unit="batch", leave=True)
        else:
            self.optimization_pbar.reset(total=total_steps)

        # train for n_epochs epochs
        for epoch in range(self.n_epochs):
            approx_kl_divs = []
            # Do a complete pass on the rollout buffer
            for rollout_data in self.rollout_buffer.get(self.batch_size):
                actions = rollout_data.actions
                if isinstance(self.action_space, spaces.Discrete):
                    # Convert discrete action from float to long
                    actions = rollout_data.actions.long().flatten()

                with th.cuda.amp.autocast(enabled=use_amp):
                    values, log_prob, entropy = self.policy.evaluate_actions(
                        rollout_data.observations,
                        actions,
                        action_masks=rollout_data.action_masks,
                    )

                values = values.flatten()
                # Normalize advantage
                advantages = rollout_data.advantages
                if self.normalize_advantage:
                    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

                # ratio between old and new policy, should be one at the first iteration
                ratio = th.exp(log_prob - rollout_data.old_log_prob)

                # clipped surrogate loss
                policy_loss_1 = advantages * ratio
                policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
                policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()

                # Logging
                pg_losses.append(policy_loss.item())
                clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item()
                clip_fractions.append(clip_fraction)

                if self.clip_range_vf is None:
                    # No clipping
                    values_pred = values
                else:
                    # Clip the difference between old and new value
                    # NOTE: this depends on the reward scaling
                    values_pred = rollout_data.old_values + th.clamp(
                        values - rollout_data.old_values, -clip_range_vf, clip_range_vf
                    )
                # Value loss using the TD(gae_lambda) target
                value_loss = F.mse_loss(rollout_data.returns, values_pred)
                value_losses.append(value_loss.item())

                # Entropy loss favor exploration
                if entropy is None:
                    # Approximate entropy when no analytical form
                    entropy_loss = -th.mean(-log_prob)
                else:
                    entropy_loss = -th.mean(entropy)

                entropy_losses.append(entropy_loss.item())

                loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss

                # Calculate approximate form of reverse KL Divergence for early stopping
                # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417
                # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419
                # and Schulman blog: http://joschu.net/blog/kl-approx.html
                with th.no_grad():
                    log_ratio = log_prob - rollout_data.old_log_prob
                    approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
                    approx_kl_divs.append(approx_kl_div)

                if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
                    continue_training = False
                    if self.verbose >= 1:
                        print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
                    break

                # Optimization step
                self.policy.optimizer.zero_grad()

                # Backward pass on the scaled loss
                self.scaler.scale(loss).backward()

                # Unscale before clipping so the norm threshold applies to the true gradients
                self.scaler.unscale_(self.policy.optimizer)
                th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)

                # Optimizer step
                self.scaler.step(self.policy.optimizer)
                self.scaler.update()

                # Update Progress Bar
                self.optimization_pbar.update(1)

            if not continue_training:
                break

        # Don't close the bar, it is reset on the next train() call

        self._n_updates += self.n_epochs
        explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())

        # Logs
        self.logger.record("train/entropy_loss", np.mean(entropy_losses))
        self.logger.record("train/policy_gradient_loss", np.mean(pg_losses))
        self.logger.record("train/value_loss", np.mean(value_losses))
        self.logger.record("train/approx_kl", np.mean(approx_kl_divs))
        self.logger.record("train/clip_fraction", np.mean(clip_fractions))
        self.logger.record("train/loss", loss.item())
        self.logger.record("train/explained_variance", explained_var)
        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/clip_range", clip_range)
        if self.clip_range_vf is not None:
            self.logger.record("train/clip_range_vf", clip_range_vf)

    def _excluded_save_params(self) -> list[str]:
        """
        Returns the names of the parameters that should be excluded from being saved.

        The progress bar is runtime-only state, so it is kept out of checkpoints.
        """
        return super()._excluded_save_params() + ["optimization_pbar"]
391
+
392
+
393
def main():
    """
    Run vectorized MaskablePPO training end to end.

    Steps: read configuration from environment variables, create and warm up
    the vector environment, resume from a checkpoint (or build a fresh model),
    train with checkpoint/snapshot/status callbacks, and save the model on
    normal completion, Ctrl+C, or a crash. The environment is always closed.

    Environment variables read: TRAIN_STEPS, TRAIN_BATCH_SIZE, TRAIN_ENVS,
    TRAIN_N_STEPS, ENT_COEF, GAMMA, GAE_LAMBDA, LEARNING_RATE, NUM_EPOCHS,
    OMP_NUM_THREADS (logged only), LOAD_CHECKPOINT, RESTART_TRAINING,
    OBS_MODE, SAVE_FREQ_MINS.
    """
    print("========================================================")
    print(" LovecaSim - STARTING VECTORIZED TRAINING (700k+ SPS) ")
    print("========================================================")

    # Configuration from Environment Variables
    TOTAL_TIMESTEPS = int(os.getenv("TRAIN_STEPS", "100_000_000"))
    BATCH_SIZE = int(os.getenv("TRAIN_BATCH_SIZE", "8192"))
    NUM_ENVS = int(os.getenv("TRAIN_ENVS", "4096"))
    N_STEPS = int(os.getenv("TRAIN_N_STEPS", "256"))

    # Advanced Hyperparameters
    ENT_COEF = float(os.getenv("ENT_COEF", "0.01"))
    GAMMA = float(os.getenv("GAMMA", "0.99"))
    GAE_LAMBDA = float(os.getenv("GAE_LAMBDA", "0.95"))
    # Parsed once so the resume path and the fresh-model path cannot disagree.
    LEARNING_RATE = float(os.getenv("LEARNING_RATE", "3e-4"))
    NUM_EPOCHS = int(os.getenv("NUM_EPOCHS", "4"))

    SAVE_PATH = "./checkpoints/vector/"
    os.makedirs(SAVE_PATH, exist_ok=True)

    # Log Hardware/Threading Config
    omp_threads = os.getenv("OMP_NUM_THREADS", "Unset (All Cores)")
    print(f" [Config] Batch Size: {BATCH_SIZE}")
    print(f" [Config] Num Envs: {NUM_ENVS}")
    print(f" [Config] N Steps: {N_STEPS}")
    print(f" [Config] CPU Cores: {omp_threads}")

    # 1. Create Vector Environment (Numba)
    print(f" [Init] Creating {NUM_ENVS} parallel Numba environments...")
    env = VectorEnvAdapter(num_envs=NUM_ENVS)

    # --- WARMUP / COMPILATION ---
    print(" [Init] Compiling Numba functions (Reset)... This may take 30s+")
    env.reset()
    print(" [Init] Compiling Numba functions (Step)... This may take 60s+")
    # Perform a dummy step to force compilation of the massive step kernel
    dummy_actions = np.zeros(NUM_ENVS, dtype=np.int32)
    env.step(dummy_actions)
    print(" [Init] Compilation complete! Starting training...")
    # ----------------------------

    # 2. Setup or Load PPO Agent
    requested_checkpoint = os.getenv("LOAD_CHECKPOINT", "")
    checkpoint_path = requested_checkpoint

    # Auto-resolve "LATEST" or "AUTO"
    force_restart = os.getenv("RESTART_TRAINING", "FALSE").upper() == "TRUE"
    if force_restart:
        print(" [Config] RESTART_TRAINING=TRUE. Ignoring checkpoints.")
        checkpoint_path = ""
    elif checkpoint_path.upper() in ["LATEST", "AUTO"]:
        list_of_files = glob.glob(os.path.join(SAVE_PATH, "*.zip"))
        if list_of_files:
            checkpoint_path = max(list_of_files, key=os.path.getctime)
            print(f" [Config] LOAD_CHECKPOINT='{requested_checkpoint}' -> Auto-resolved to: {checkpoint_path}")
        else:
            # Echo the value the user actually set (it may be AUTO, not LATEST).
            print(f" [Config] LOAD_CHECKPOINT='{requested_checkpoint}' but no checkpoints found. Starting fresh.")
            checkpoint_path = ""

    model = None
    if checkpoint_path and os.path.exists(checkpoint_path):
        print(f" [Load] Scanning checkpoint: {checkpoint_path}")
        try:
            # Load on CPU first only to compare observation dimensions with the env
            temp_model = ProgressMaskablePPO.load(checkpoint_path, device="cpu")
            model_obs_dim = temp_model.observation_space.shape[0]
            # Drop the probe model right away so it does not stay in memory for the whole run.
            del temp_model
            env_obs_dim = env.observation_space.shape[0]

            if model_obs_dim != env_obs_dim:
                print(f" [Load] Dimension Mismatch! Model: {model_obs_dim}, Env: {env_obs_dim}")
                print(f" [Load] Cannot resume training across eras. Starting FRESH {env_obs_dim}-dim model.")
                model = None
            else:
                print(f" [Load] Dimensions match ({model_obs_dim}). Resuming training...")
                model = ProgressMaskablePPO.load(
                    checkpoint_path,
                    env=env,
                    device="cuda" if torch.cuda.is_available() else "cpu",
                    custom_objects={
                        "learning_rate": LEARNING_RATE,
                        "batch_size": BATCH_SIZE,
                        "n_epochs": NUM_EPOCHS,
                    },
                )
                reset_num_timesteps = False
                print(" [Load] Success.")
        except Exception as e:
            # Top-level boundary: any load failure falls back to a fresh model.
            print(f" [Error] Failed to load checkpoint: {e}")
            print(" [Init] Falling back to fresh model...")
            model = None

    if model is None:
        if checkpoint_path and not os.path.exists(checkpoint_path):
            print(f" [Warning] Checkpoint file not found: {checkpoint_path}")

        print(" [Init] Creating fresh ProgressMaskablePPO model...")

        # Determine Policy Args
        obs_mode_env = os.getenv("OBS_MODE", "STANDARD")
        if obs_mode_env == "ATTENTION":
            print(" [Init] Using LovecaFeaturesExtractor (Attention)")
            policy_kwargs = dict(
                features_extractor_class=LovecaFeaturesExtractor,
                features_extractor_kwargs=dict(features_dim=256),
                net_arch=[],
            )
        else:
            policy_kwargs = dict(net_arch=[512, 512])

        model = ProgressMaskablePPO(
            "MlpPolicy",
            env,
            verbose=1,
            learning_rate=LEARNING_RATE,
            n_steps=N_STEPS,
            batch_size=BATCH_SIZE,
            n_epochs=NUM_EPOCHS,
            gamma=GAMMA,
            gae_lambda=GAE_LAMBDA,
            ent_coef=ENT_COEF,
            tensorboard_log="./logs/vector_tensorboard/",
            policy_kwargs=policy_kwargs,
        )
        reset_num_timesteps = True

    print(f" [Init] PPO Model initialized. Device: {model.device}")

    # 3. Callbacks (defined at module level)

    # Standard step-based checkpoint (kept for compatibility/safety)
    checkpoint_callback = CheckpointCallback(
        save_freq=max(1, 1000000 // NUM_ENVS), save_path=SAVE_PATH, name_prefix="numba_ppo"
    )

    save_freq = float(os.getenv("SAVE_FREQ_MINS", "15.0"))

    # Time-based snapshot callback (used instead of TimeCheckpointCallback)
    snapshot_callback = ModelSnapshotCallback(
        save_freq_minutes=save_freq,
        save_path="historiccheckpoints",
        verbose=1,
    )
    # NOTE: OBS_MODE is not written into the snapshot metadata.

    # 4. Train
    print(" [Train] Starting training loop...")
    print(f" [Train] Model Mode: {os.getenv('OBS_MODE', 'STANDARD')}")
    print(f" [Train] Reset Timesteps: {reset_num_timesteps}")
    print(" [Note] Press Ctrl+C to stop and force-save.")

    # Generate a timestamped run name for TensorBoard
    run_name = f"ProgressPPO_{time.strftime('%m%d_%H%M%S')}"
    if not reset_num_timesteps:
        run_name += "_RESUME"

    try:
        model.learn(
            total_timesteps=TOTAL_TIMESTEPS,
            callback=[
                checkpoint_callback,
                snapshot_callback,
                TrainingStatsCallback(),
                DetailedStatusCallback(),
            ],
            progress_bar=True,
            reset_num_timesteps=reset_num_timesteps,
            tb_log_name=run_name,
        )
        print(" [Train] Training finished.")
        model.save(os.path.join(SAVE_PATH, "final_model"))
    except KeyboardInterrupt:
        print("\n [Train] Interrupted by user. Saving model...")
        model.save(os.path.join(SAVE_PATH, "interrupted_model"))
        # Trigger explicit snapshot on interrupt
        snapshot_callback._save_snapshot()
        print(" [Train] Model saved.")
    except Exception as e:
        # Top-level boundary: report the crash and try to preserve the weights.
        print(f"\n [Error] Training crashed: {e}")
        import traceback

        traceback.print_exc()
        emergency_save = os.path.join(SAVE_PATH, "crash_emergency")
        try:
            model.save(emergency_save)
            print(f" [Save] Crash emergency checkpoint saved to: {emergency_save}")
        except Exception as save_error:
            # A failing emergency save must not raise out of the crash handler.
            print(f" [Error] Emergency save failed: {save_error}")
    finally:
        print(" [Done] Exiting gracefully.")
        env.close()
582
+
583
+
584
# Script entry point: start training only when this file is executed directly.
if __name__ == "__main__":
    main()