| """ | |
| Main training entry point. | |
| Uses SB3 RecurrentPPO (LSTM) with curriculum learning. | |
| """ | |
| from __future__ import annotations | |
| import os | |
| import sys | |
| import yaml | |
| import click | |
| from pathlib import Path | |
| from dotenv import load_dotenv | |
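
# Load local environment variables (e.g. API keys) from a .env file, if one exists.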
load_dotenv()


# CLI options reconstructed from the function signature and how the arguments are
# used below; the defaults here are assumptions, adjust them to the project layout.
@click.command()
@click.option("--config", required=True, help="Path to the training YAML config.")
@click.option("--phase", default=1, type=int, show_default=True, help="Curriculum phase to start from.")
@click.option("--timesteps", default=None, type=int, help="Override total_timesteps from the config.")
@click.option("--demo-mode", is_flag=True, help="Run against the real SpindleFlow backend.")
@click.option("--checkpoint", default=None, help="Checkpoint to resume training from.")
def train(config, phase, timesteps, demo_mode, checkpoint):
| """Train the SpindleFlow RL delegation policy.""" | |
| try: | |
| from sb3_contrib import RecurrentPPO | |
| except ImportError: | |
| print("ERROR: sb3-contrib required. Run: pip install sb3-contrib") | |
| sys.exit(1) | |
| from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize | |
| from stable_baselines3.common.callbacks import ( | |
| CheckpointCallback, EvalCallback, BaseCallback | |
| ) | |
| from training.curriculum import CurriculumManager | |
| from training.specialist_improvement_callback import SpecialistImprovementCallback | |
| from env.spindleflow_env import SpindleFlowEnv | |
| from policy.lstm_policy import build_policy_kwargs | |
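
    # Pull hyperparameters out of the YAML config; a --timesteps value from the
    # CLI overrides the config's total_timesteps.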
    with open(config) as f:
        cfg = yaml.safe_load(f)

    ppo_cfg = cfg["ppo"]
    training_cfg = cfg["training"]
    lstm_cfg = cfg["lstm"]
    total_ts = timesteps or training_cfg["total_timesteps"]

    curriculum = CurriculumManager(config_path=config)

    print(f"\n{'='*60}")
    print("SpindleFlow RL Training")
    print(f"  Phase: {phase}")
    print(f"  Timesteps: {total_ts}")
    print(f"  Demo mode (real SpindleFlow): {demo_mode}")
    print(f"{'='*60}\n")
    def make_env():
        return SpindleFlowEnv(
            config_path=config,
            phase=phase,
            use_real_spindleflow=demo_mode,
        )
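
    # Observations and rewards are normalized online; the same VecNormalize
    # statistics are saved at the end of training so they can be reused at
    # inference time.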
    n_envs = training_cfg.get("n_envs", 1)
    env = DummyVecEnv([make_env for _ in range(n_envs)])
    env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=10.0)

    eval_env = DummyVecEnv([make_env])
    eval_env = VecNormalize(eval_env, norm_obs=True, norm_reward=False)
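
    # LSTM policy sizing (hidden units, number of layers) comes from the config.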
    policy_kwargs = build_policy_kwargs(
        hidden_size=lstm_cfg["hidden_size"],
        num_lstm_layers=lstm_cfg["num_layers"],
    )
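
    # Resume from an existing checkpoint when one is given; otherwise build a
    # fresh RecurrentPPO model from the config hyperparameters.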
    if checkpoint and os.path.exists(checkpoint):
        print(f"Loading checkpoint: {checkpoint}")
        model = RecurrentPPO.load(checkpoint, env=env)
    else:
        model = RecurrentPPO(
            policy="MlpLstmPolicy",
            env=env,
            learning_rate=ppo_cfg["learning_rate"],
            n_steps=ppo_cfg["n_steps"],
            batch_size=ppo_cfg["batch_size"],
            n_epochs=ppo_cfg["n_epochs"],
            gamma=ppo_cfg["gamma"],
            gae_lambda=ppo_cfg["gae_lambda"],
            clip_range=ppo_cfg["clip_range"],
            ent_coef=ppo_cfg["ent_coef"],
            vf_coef=ppo_cfg["vf_coef"],
            max_grad_norm=ppo_cfg["max_grad_norm"],
            policy_kwargs=policy_kwargs,
            tensorboard_log="./tensorboard_logs/",
            verbose=1,
            seed=training_cfg["seed"],
            device=training_cfg["device"],
        )

    _max_specialists = cfg["environment"].get("max_specialists_per_episode", 6)
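
    # In-process callback that tracks per-episode return and the entropy of the
    # delegation head (softmax over the specialist slice of the action vector),
    # and feeds episode returns back to the curriculum manager.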
    class _RewardLogger(BaseCallback):
        def __init__(self, max_specialists: int, curriculum: CurriculumManager):
            super().__init__()
            self.episode_rewards: list[float] = []
            self.episode_entropies: list[float] = []
            self._running_reward = 0.0
            self._running_entropy: list[float] = []
            self._max_specialists = max_specialists
            self._curriculum = curriculum

        def _on_step(self):
            import numpy as np

            rewards = self.locals.get("rewards", [])
            dones = self.locals.get("dones", [])
            actions = self.locals.get("actions", None)
            if actions is not None:
                for action_vec in actions:
                    n = self._max_specialists
                    logits = action_vec[1:1 + n]
                    logits = logits - logits.max()
                    exp_l = np.exp(logits)
                    probs = exp_l / (exp_l.sum() + 1e-8)
                    entropy = float(-np.sum(probs * np.log(probs + 1e-8)))
                    self._running_entropy.append(entropy)
            for r, d in zip(rewards, dones):
                self._running_reward += float(r)
                if d:
                    ep_reward = self._running_reward
                    self.episode_rewards.append(ep_reward)
                    if self._running_entropy:
                        self.episode_entropies.append(
                            float(sum(self._running_entropy) / len(self._running_entropy))
                        )
                        self._running_entropy = []
                    self._running_reward = 0.0
                    self._curriculum.on_episode_end(ep_reward)
            return True

    reward_logger = _RewardLogger(max_specialists=_max_specialists, curriculum=curriculum)
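
    # Periodic checkpointing, plus evaluation on the separate (non-reward-normalized)
    # eval env; the best-scoring model is saved under checkpoints/best/.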
    checkpoint_cb = CheckpointCallback(
        save_freq=2000,
        save_path="./checkpoints/",
        name_prefix="spindleflow_ppo",
    )
    eval_cb = EvalCallback(
        eval_env,
        best_model_save_path="./checkpoints/best/",
        log_path="./eval_logs/",
        eval_freq=1000,
        n_eval_episodes=5,
        verbose=1,
    )

    si_cfg = cfg.get("specialist_improvement", {})
    improvement_cb = SpecialistImprovementCallback(
        improve_every_n_episodes=si_cfg.get("improve_every_n_episodes", 100),
        verbose=1,
    )
| print(f"Starting training for {total_ts} timesteps...") | |
| print(f"TensorBoard: tensorboard --logdir tensorboard_logs/\n") | |
    model.learn(
        total_timesteps=total_ts,
        callback=[checkpoint_cb, eval_cb, reward_logger, improvement_cb],
        reset_num_timesteps=checkpoint is None,
    )

    os.makedirs("checkpoints", exist_ok=True)
    model.save("checkpoints/spindleflow_final")
    env.save("checkpoints/vec_normalize.pkl")
    print("\nTraining complete. Model saved to checkpoints/spindleflow_final")

    # Save reward curve for the Streamlit dashboard
    import json
    import numpy as np

    ep = reward_logger.episode_rewards
    if ep:
        os.makedirs("demo/assets", exist_ok=True)
        step = max(1, len(ep) // 200)
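        # 20-episode rolling mean, downsampled to at most ~200 points for the chart.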
        smoothed = [float(np.mean(ep[max(0, i-19):i+1])) for i in range(len(ep))]
        with open("demo/assets/reward_curve.json", "w") as f:
            json.dump({"episodes": list(range(len(ep)))[::step],
                       "mean_rewards": smoothed[::step]}, f)
        print(f"Saved demo/assets/reward_curve.json ({len(ep)} episodes)")

    # Save entropy log for the Training tab entropy chart
    ep_e = reward_logger.episode_entropies
    if ep_e:
        step_e = max(1, len(ep_e) // 200)
        with open("demo/assets/entropy_log.json", "w") as f:
            json.dump({
                "episodes": list(range(len(ep_e)))[::step_e],
                "mean_entropies": ep_e[::step_e],
            }, f)
        print(f"Saved demo/assets/entropy_log.json ({len(ep_e)} episodes)")


if __name__ == "__main__":
    train()
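
# Example invocation (option names defined above; the script name, config path,
# and values here are illustrative):
#   python train.py --config config/training.yaml --phase 1 --timesteps 200000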