Spaces:
Running
Running
| import os | |
| import sys | |
| import time | |
| # Immediate feedback | |
| print(" [Init] Python process started. Loading libraries...") | |
| print(" [Init] Loading Pytorch...", end="", flush=True) | |
| import torch | |
| import torch as th | |
| import torch.nn.functional as F | |
| print(" Done.") | |
| print(" [Init] Loading Gymnasium & SB3...", end="", flush=True) | |
| import glob | |
| import warnings | |
| import numpy as np | |
| from gymnasium import spaces | |
| from sb3_contrib import MaskablePPO | |
| from stable_baselines3.common.callbacks import BaseCallback, CheckpointCallback | |
| from stable_baselines3.common.utils import explained_variance | |
| from tqdm import tqdm | |
| print(" Done.") | |
| # Filter Numba warning | |
| warnings.filterwarnings("ignore", category=RuntimeWarning, message="nopython is set for njit") | |
| # Ensure project root is in path | |
| sys.path.append(os.getcwd()) | |
| print(" [Init] Loading LovecaSim Vector Engine...", end="", flush=True) | |
| from ai.environments.vec_env_adapter import VectorEnvAdapter | |
| from ai.utils.loveca_features_extractor import LovecaFeaturesExtractor | |
| print(" Done.") | |
class TimeCheckpointCallback(BaseCallback):
    """
    Save the model every N minutes (wall-clock, not step-based).

    With thousands of parallel envs the step counter grows too fast for a
    fixed step frequency, so this callback checkpoints on elapsed time.

    :param save_freq_minutes: Interval between automatic saves, in minutes.
    :param save_path: Directory in which the checkpoint is written.
    :param name_prefix: Prefix of the saved file name.
    :param verbose: If > 0, print a message after each save.
    """

    def __init__(self, save_freq_minutes: float, save_path: str, name_prefix: str, verbose: int = 0):
        super().__init__(verbose)
        # Keep the minute value around for accurate log messages.
        self.save_freq_minutes = save_freq_minutes
        self.save_freq_seconds = save_freq_minutes * 60
        self.save_path = save_path
        self.name_prefix = name_prefix
        self.last_time_save = time.time()

    def _on_step(self) -> bool:
        if (time.time() - self.last_time_save) > self.save_freq_seconds:
            # Always overwrite the same file: this is a rolling auto-save,
            # not a history (see ModelSnapshotCallback for history).
            save_path = os.path.join(self.save_path, f"{self.name_prefix}_time_auto")
            self.model.save(save_path)
            if self.verbose > 0:
                # BUG FIX: message previously hard-coded "after 3 minutes";
                # report the actual configured interval instead.
                print(f" [Save] Model auto-saved after {self.save_freq_minutes:g} minutes to {save_path}")
            self.last_time_save = time.time()
        return True
class ModelSnapshotCallback(BaseCallback):
    """
    Saves a 'Model Snapshot' every X minutes:
    - model.zip
    - verified_card_pool.json (Context)
    - snapshot_meta.json (Architecture/Config)

    Only the most recent snapshots are kept; older snapshot directories
    are pruned (best effort, failures are logged and ignored).

    :param save_freq_minutes: Interval between snapshots, in minutes.
    :param save_path: Root directory holding the snapshot subdirectories.
    :param verbose: If > 0, print progress messages.
    """

    def __init__(self, save_freq_minutes: float, save_path: str, verbose=0):
        super().__init__(verbose)
        self.save_freq_minutes = save_freq_minutes
        self.save_path = save_path
        self.last_save_time = time.time()
        # CONSISTENCY FIX: the snapshot root was hard-coded to
        # "historiccheckpoints" here and in _save_snapshot() while
        # _prune_snapshots() read self.save_path. Use save_path everywhere
        # (the caller passes "historiccheckpoints", so behavior is unchanged).
        os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self) -> bool:
        if time.time() - self.last_save_time > self.save_freq_minutes * 60:
            self.last_save_time = time.time()
            self._save_snapshot()
        return True

    def _save_snapshot(self):
        """Write model.zip, the card-pool context and a metadata file, then prune."""
        timestamp = time.strftime("%Y%m%d_%H%M%S")
        steps = self.num_timesteps
        snapshot_name = f"{timestamp}_{steps}_steps"
        snapshot_dir = os.path.join(self.save_path, snapshot_name)

        if self.verbose > 0:
            print(f" [Snapshot] Saving to {snapshot_dir}...")
        os.makedirs(snapshot_dir, exist_ok=True)

        # 1. Save Model
        model_path = os.path.join(snapshot_dir, "model.zip")
        self.model.save(model_path)

        # 2. Save Card Pool (Context) — best effort: training continues on failure
        try:
            import shutil
            shutil.copy("verified_card_pool.json", os.path.join(snapshot_dir, "verified_card_pool.json"))
        except Exception as e:
            print(f" [Snapshot] Warning: Could not copy card pool: {e}")

        # 3. Save Metadata (Architecture)
        meta = {
            "timestamp": timestamp,
            "timesteps": int(steps),
            "obs_dim": int(self.model.observation_space.shape[0]),
            "action_space_size": int(self.model.action_space.n),
            "features": ["GlobalVolumes", "LiveZone", "Traits", "TurnNumber"],
            "notes": "Generated by ModelSnapshotCallback",
        }
        try:
            import json
            with open(os.path.join(snapshot_dir, "snapshot_meta.json"), "w") as f:
                json.dump(meta, f, indent=2)
        except Exception as e:
            print(f" [Snapshot] Warning: Could not save meta: {e}")

        # 4. Limit retained snapshots
        self._prune_snapshots()

    def _prune_snapshots(self, max_keep: int = 5):
        """
        Delete the oldest snapshot directories, keeping only the newest
        ``max_keep`` of them. Any failure is logged, never raised.

        (Cleaned up: removed a dead local and leftover scratch comments.)
        """
        search_dir = self.save_path
        if not os.path.exists(search_dir):
            return
        try:
            subdirs = [
                os.path.join(search_dir, d)
                for d in os.listdir(search_dir)
                if os.path.isdir(os.path.join(search_dir, d))
            ]
            # Sort by creation time (oldest first)
            subdirs.sort(key=os.path.getctime)
            if len(subdirs) > max_keep:
                import shutil
                for d in subdirs[:-max_keep]:
                    try:
                        shutil.rmtree(d)
                        if self.verbose > 0:
                            print(f" [Snapshot] Pruned old snapshot: {d}")
                    except Exception as e:
                        print(f" [Snapshot] Warning: Failed to prune {d}: {e}")
        except Exception as e:
            print(f" [Snapshot] Warning: Pruning failed: {e}")
class DetailedStatusCallback(BaseCallback):
    """
    Logs detailed phase information (Collection vs Optimization) and VRAM usage.
    """

    def __init__(self, verbose=0):
        super().__init__(verbose)
        self.collection_start_time = 0.0

    def _on_rollout_start(self) -> None:
        """Fires when environment-step collection (the rollout) begins."""
        self.collection_start_time = time.time()
        print(f"\n [Phase] Starting Rollout Collection (Steps: {self.model.n_steps})...")

    def _on_rollout_end(self) -> None:
        """Fires just before the policy update starts."""
        elapsed = time.time() - self.collection_start_time
        steps_collected = self.model.n_envs * self.model.n_steps
        throughput = steps_collected / elapsed if elapsed > 0 else 0
        print(f" [Phase] Collection Complete. Duration: {elapsed:.2f}s ({throughput:.0f} FPS)")
        # The PPO optimization phase follows immediately after this hook.
        print(f" [Phase] Starting PPO Optimization (Epochs: {self.model.n_epochs}, Batch: {self.model.batch_size})...")
        if torch.cuda.is_available():
            gib = 1024 ** 3
            allocated = torch.cuda.memory_allocated() / gib
            reserved = torch.cuda.memory_reserved() / gib
            print(f" [VRAM] Allocated: {allocated:.2f} GB | Reserved: {reserved:.2f} GB")
        print(" [Info] Optimization may take time if batch size is large. Please wait...")

    def _on_step(self) -> bool:
        return True
class TrainingStatsCallback(BaseCallback):
    """
    Simple stats logging for Vectorized Training.
    """

    def __init__(self, verbose=0):
        super().__init__(verbose)

    def _on_step(self) -> bool:
        # The vectorized env does not emit 'win_rate' in infos by default;
        # fall back to the standard 'episode' dicts attached on episode end.
        infos = self.locals.get("infos")
        if infos:
            finished = [info.get("episode") for info in infos if "episode" in info]
            if finished:
                self.logger.record("rollout/ep_rew_mean", np.mean([ep["r"] for ep in finished]))
                self.logger.record("rollout/ep_len_mean", np.mean([ep["l"] for ep in finished]))
        return True
class ProgressMaskablePPO(MaskablePPO):
    """
    MaskablePPO with a tqdm progress bar during the optimization phase.

    Also wraps the forward pass in torch autocast and scales gradients with a
    GradScaler when CUDA is available (mixed-precision optimization). The
    rollout/learn machinery is inherited unchanged from MaskablePPO; only
    ``train()`` is overridden.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Persistent tqdm bar: created lazily on the first train() call and
        # reset (not recreated) on later calls. Excluded from saves via
        # _excluded_save_params below.
        self.optimization_pbar = None

    def train(self) -> None:
        """
        Update policy using the currently gathered rollout buffer.

        Mirrors the upstream MaskablePPO.train() loop with two additions:
        a persistent tqdm progress bar over optimization batches and
        AMP (autocast + GradScaler) around the forward/backward pass.
        """
        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)
        # Update optimizer learning rate
        self._update_learning_rate(self.policy.optimizer)
        # Compute current clip range
        clip_range = self.clip_range(self._current_progress_remaining)  # type: ignore[operator]
        # Optional: clip range for the value function
        if self.clip_range_vf is not None:
            clip_range_vf = self.clip_range_vf(self._current_progress_remaining)  # type: ignore[operator]

        entropy_losses = []
        pg_losses, value_losses = [], []
        clip_fractions = []

        continue_training = True
        # train for n_epochs epochs
        # ADDED: Persistent TQDM Progress Bar
        # Total batches = epochs * (buffer_size // batch_size).
        total_steps = self.n_epochs * (self.rollout_buffer.buffer_size // self.batch_size)
        if self.optimization_pbar is None:
            self.optimization_pbar = tqdm(total=total_steps, desc="Optimization", unit="batch", leave=True)
        else:
            self.optimization_pbar.reset(total=total_steps)

        for epoch in range(self.n_epochs):
            approx_kl_divs = []
            # Do a complete pass on the rollout buffer
            for rollout_data in self.rollout_buffer.get(self.batch_size):
                actions = rollout_data.actions
                if isinstance(self.action_space, spaces.Discrete):
                    # Convert discrete action from float to long
                    actions = rollout_data.actions.long().flatten()

                # Mixed-precision forward pass; no-op when CUDA is unavailable.
                with th.cuda.amp.autocast(enabled=th.cuda.is_available()):
                    values, log_prob, entropy = self.policy.evaluate_actions(
                        rollout_data.observations,
                        actions,
                        action_masks=rollout_data.action_masks,
                    )
                    values = values.flatten()
                    # Normalize advantage
                    advantages = rollout_data.advantages
                    if self.normalize_advantage:
                        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

                    # ratio between old and new policy, should be one at the first iteration
                    ratio = th.exp(log_prob - rollout_data.old_log_prob)

                    # clipped surrogate loss
                    policy_loss_1 = advantages * ratio
                    policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
                    policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()

                    # Logging
                    pg_losses.append(policy_loss.item())
                    clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item()
                    clip_fractions.append(clip_fraction)

                    if self.clip_range_vf is None:
                        # No clipping
                        values_pred = values
                    else:
                        # Clip the different between old and new value
                        # NOTE: this depends on the reward scaling
                        values_pred = rollout_data.old_values + th.clamp(
                            values - rollout_data.old_values, -clip_range_vf, clip_range_vf
                        )
                    # Value loss using the TD(gae_lambda) target
                    value_loss = F.mse_loss(rollout_data.returns, values_pred)
                    value_losses.append(value_loss.item())

                    # Entropy loss favor exploration
                    if entropy is None:
                        # Approximate entropy when no analytical form
                        entropy_loss = -th.mean(-log_prob)
                    else:
                        entropy_loss = -th.mean(entropy)
                    entropy_losses.append(entropy_loss.item())

                    loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss

                # Calculate approximate form of reverse KL Divergence for early stopping
                # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417
                # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419
                # and Schulman blog: http://joschu.net/blog/kl-approx.html
                with th.no_grad():
                    log_ratio = log_prob - rollout_data.old_log_prob
                    approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
                    approx_kl_divs.append(approx_kl_div)

                if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
                    continue_training = False
                    if self.verbose >= 1:
                        print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
                    break

                # Optimization step
                self.policy.optimizer.zero_grad()

                # AMP: Automatic Mixed Precision
                # Check if scaler exists (backward compatibility: models loaded
                # from older checkpoints won't have one yet)
                if not hasattr(self, "scaler"):
                    self.scaler = th.cuda.amp.GradScaler(enabled=th.cuda.is_available())

                # Backward pass (scaled loss to avoid fp16 underflow)
                self.scaler.scale(loss).backward()
                # Clip grad norm — must unscale first so the clip threshold
                # applies to the true gradient magnitudes
                self.scaler.unscale_(self.policy.optimizer)
                th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
                # Optimizer step
                self.scaler.step(self.policy.optimizer)
                self.scaler.update()
                # Update Progress Bar
                self.optimization_pbar.update(1)

            if not continue_training:
                break

        # Don't close the pbar, just leave it for the next reset
        self._n_updates += self.n_epochs
        explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())

        # Logs
        self.logger.record("train/entropy_loss", np.mean(entropy_losses))
        self.logger.record("train/policy_gradient_loss", np.mean(pg_losses))
        self.logger.record("train/value_loss", np.mean(value_losses))
        self.logger.record("train/approx_kl", np.mean(approx_kl_divs))
        self.logger.record("train/clip_fraction", np.mean(clip_fractions))
        self.logger.record("train/loss", loss.item())
        self.logger.record("train/explained_variance", explained_var)
        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/clip_range", clip_range)
        if self.clip_range_vf is not None:
            self.logger.record("train/clip_range_vf", clip_range_vf)

    def _excluded_save_params(self) -> list[str]:
        """
        Returns the names of the parameters that should be excluded from being saved.

        The tqdm bar is a live handle and must not be pickled into the checkpoint.
        """
        return super()._excluded_save_params() + ["optimization_pbar"]
def main():
    """
    Entry point: read configuration from environment variables, build the
    vectorized Numba environment, create or resume a ProgressMaskablePPO
    model, attach checkpoint/snapshot/status callbacks, and run training.

    Recognized environment variables: TRAIN_STEPS, TRAIN_BATCH_SIZE,
    TRAIN_ENVS, TRAIN_N_STEPS, ENT_COEF, GAMMA, GAE_LAMBDA, LEARNING_RATE,
    NUM_EPOCHS, OBS_MODE, LOAD_CHECKPOINT, RESTART_TRAINING, SAVE_FREQ_MINS,
    OMP_NUM_THREADS (logged only).
    """
    print("========================================================")
    print(" LovecaSim - STARTING VECTORIZED TRAINING (700k+ SPS) ")
    print("========================================================")

    # Configuration from Environment Variables
    TOTAL_TIMESTEPS = int(os.getenv("TRAIN_STEPS", "100_000_000"))
    BATCH_SIZE = int(os.getenv("TRAIN_BATCH_SIZE", "8192"))
    NUM_ENVS = int(os.getenv("TRAIN_ENVS", "4096"))
    N_STEPS = int(os.getenv("TRAIN_N_STEPS", "256"))

    # Advanced Hyperparameters
    ENT_COEF = float(os.getenv("ENT_COEF", "0.01"))
    GAMMA = float(os.getenv("GAMMA", "0.99"))
    GAE_LAMBDA = float(os.getenv("GAE_LAMBDA", "0.95"))

    SAVE_PATH = "./checkpoints/vector/"
    os.makedirs(SAVE_PATH, exist_ok=True)

    # Log Hardware/Threading Config
    omp_threads = os.getenv("OMP_NUM_THREADS", "Unset (All Cores)")
    print(f" [Config] Batch Size: {BATCH_SIZE}")
    print(f" [Config] Num Envs: {NUM_ENVS}")
    print(f" [Config] N Steps: {N_STEPS}")
    print(f" [Config] CPU Cores: {omp_threads}")

    # 1. Create Vector Environment (Numba)
    print(f" [Init] Creating {NUM_ENVS} parallel Numba environments...")
    env = VectorEnvAdapter(num_envs=NUM_ENVS)

    # --- WARMUP / COMPILATION ---
    # Numba JIT-compiles on first call; trigger reset() and step() once now
    # so the one-off compilation cost is paid before training starts.
    print(" [Init] Compiling Numba functions (Reset)... This may take 30s+")
    env.reset()
    print(" [Init] Compiling Numba functions (Step)... This may take 60s+")
    # Perform a dummy step to force compilation of the massive step kernel
    dummy_actions = np.zeros(NUM_ENVS, dtype=np.int32)
    env.step(dummy_actions)
    print(" [Init] Compilation complete! Starting training...")
    # ----------------------------

    # 2. Setup or Load PPO Agent
    checkpoint_path = os.getenv("LOAD_CHECKPOINT", "")

    # Auto-resolve "LATEST" or "AUTO"; RESTART_TRAINING=TRUE takes precedence
    # and forces a fresh model.
    force_restart = os.getenv("RESTART_TRAINING", "FALSE").upper() == "TRUE"
    if force_restart:
        print(" [Config] RESTART_TRAINING=TRUE. Ignoring checkpoints.")
        checkpoint_path = ""
    elif checkpoint_path.upper() in ["LATEST", "AUTO"]:
        list_of_files = glob.glob(os.path.join(SAVE_PATH, "*.zip"))
        if list_of_files:
            # Most recently created checkpoint wins
            checkpoint_path = max(list_of_files, key=os.path.getctime)
            print(f" [Config] LOAD_CHECKPOINT='{os.getenv('LOAD_CHECKPOINT')}' -> Auto-resolved to: {checkpoint_path}")
        else:
            print(" [Config] LOAD_CHECKPOINT='LATEST' but no checkpoints found. Starting fresh.")
            checkpoint_path = ""

    model = None
    if checkpoint_path and os.path.exists(checkpoint_path):
        print(f" [Load] Scanning checkpoint: {checkpoint_path}")
        try:
            # Check dimensions before full load if possible, or load and check.
            # Loaded on CPU first so a dimension mismatch wastes no GPU memory.
            temp_model = ProgressMaskablePPO.load(checkpoint_path, device="cpu")
            model_obs_dim = temp_model.observation_space.shape[0]
            env_obs_dim = env.observation_space.shape[0]
            if model_obs_dim != env_obs_dim:
                print(f" [Load] Dimension Mismatch! Model: {model_obs_dim}, Env: {env_obs_dim}")
                print(f" [Load] Cannot resume training across eras. Starting FRESH {env_obs_dim}-dim model.")
                model = None
            else:
                print(f" [Load] Dimensions match ({model_obs_dim}). Resuming training...")
                model = ProgressMaskablePPO.load(
                    checkpoint_path,
                    env=env,
                    device="cuda" if torch.cuda.is_available() else "cpu",
                    # custom_objects overrides hyperparameters stored in the
                    # checkpoint zip with the current env-var configuration.
                    custom_objects={
                        "learning_rate": float(os.getenv("LEARNING_RATE", "3e-4")),
                        "batch_size": BATCH_SIZE,
                        "n_epochs": int(os.getenv("NUM_EPOCHS", "4")),
                    },
                )
                reset_num_timesteps = False
                print(" [Load] Success.")
        except Exception as e:
            print(f" [Error] Failed to load checkpoint: {e}")
            print(" [Init] Falling back to fresh model...")
            model = None

    # Fresh-model path: reached when no checkpoint was requested/found,
    # on dimension mismatch, or after a failed load.
    if model is None:
        if checkpoint_path and not os.path.exists(checkpoint_path):
            print(f" [Warning] Checkpoint file not found: {checkpoint_path}")
        print(" [Init] Creating fresh ProgressMaskablePPO model...")
        # Determine Policy Args
        obs_mode_env = os.getenv("OBS_MODE", "STANDARD")
        if obs_mode_env == "ATTENTION":
            print(" [Init] Using LovecaFeaturesExtractor (Attention)")
            policy_kwargs = dict(
                features_extractor_class=LovecaFeaturesExtractor,
                features_extractor_kwargs=dict(features_dim=256),
                # Empty net_arch: the extractor output feeds the heads directly
                net_arch=[],
            )
        else:
            policy_kwargs = dict(net_arch=[512, 512])
        model = ProgressMaskablePPO(
            "MlpPolicy",
            env,
            verbose=1,
            learning_rate=float(os.getenv("LEARNING_RATE", "3e-4")),
            n_steps=N_STEPS,
            batch_size=BATCH_SIZE,
            n_epochs=int(os.getenv("NUM_EPOCHS", "4")),
            gamma=GAMMA,
            gae_lambda=GAE_LAMBDA,
            ent_coef=ENT_COEF,
            tensorboard_log="./logs/vector_tensorboard/",
            policy_kwargs=policy_kwargs,
        )
        reset_num_timesteps = True

    print(f" [Init] PPO Model initialized. Device: {model.device}")

    # 3. Callbacks
    # Refactored: Callbacks moved to module level.
    # Standard Checkpoint (Keep for compatibility/safety) — roughly every
    # 1M total env steps, expressed in per-env steps.
    checkpoint_callback = CheckpointCallback(
        save_freq=max(1, 1000000 // NUM_ENVS), save_path=SAVE_PATH, name_prefix="numba_ppo"
    )
    save_freq = float(os.getenv("SAVE_FREQ_MINS", "15.0"))
    # Snapshot Callback (Replaces TimeCheckpointCallback)
    snapshot_callback = ModelSnapshotCallback(
        save_freq_minutes=save_freq,
        save_path="historiccheckpoints",
        verbose=1,
    )
    # Store OBS_MODE in snapshot meta
    # (We need to update ModelSnapshotCallback logic or just trust env stores it?
    # Ideally pass it to callback or update meta generation.
    # Let's keep it simple: Environment tracks it.)

    # 4. Train
    print(" [Train] Starting training loop...")
    print(f" [Train] Model Mode: {os.getenv('OBS_MODE', 'STANDARD')}")
    print(f" [Train] Reset Timesteps: {reset_num_timesteps}")
    print(" [Note] Press Ctrl+C to stop and force-save.")

    # Generate a timestamped run name for TensorBoard
    run_name = f"ProgressPPO_{time.strftime('%m%d_%H%M%S')}"
    if not reset_num_timesteps:
        run_name += "_RESUME"

    try:
        model.learn(
            total_timesteps=TOTAL_TIMESTEPS,
            callback=[
                checkpoint_callback,
                snapshot_callback,
                TrainingStatsCallback(),
                DetailedStatusCallback(),
            ],  # Use Snapshot + DetailedStatus!
            progress_bar=True,
            reset_num_timesteps=reset_num_timesteps,
            tb_log_name=run_name,
        )
        print(" [Train] Training finished.")
        model.save(f"{SAVE_PATH}/final_model")
    except KeyboardInterrupt:
        # Graceful stop: save both a plain checkpoint and a full snapshot.
        print("\n [Train] Interrupted by user. Saving model...")
        model.save(f"{SAVE_PATH}/interrupted_model")
        # Trigger explicit snapshot on interrupt
        snapshot_callback._save_snapshot()
        print(" [Train] Model saved.")
    except Exception as e:
        # Last-resort emergency save so a crash doesn't lose the run.
        print(f"\n [Error] Training crashed: {e}")
        import traceback
        traceback.print_exc()
        emergency_save = os.path.join(SAVE_PATH, "crash_emergency")
        model.save(emergency_save)
        print(f" [Save] Crash emergency checkpoint saved to: {emergency_save}")
    finally:
        print(" [Done] Exiting gracefully.")
        env.close()


if __name__ == "__main__":
    main()