""" Main training entry point. Uses SB3 RecurrentPPO (LSTM) with curriculum learning. """ from __future__ import annotations import os import sys import yaml import click from pathlib import Path from dotenv import load_dotenv load_dotenv() @click.command() @click.option("--config", default="configs/training_config.yaml", help="Training config path") @click.option("--phase", default=1, type=int, help="Starting curriculum phase (1/2/3)") @click.option("--timesteps", default=None, type=int, help="Override total timesteps") @click.option("--demo-mode", is_flag=True, help="Use real SpindleFlow (slower, for demo)") @click.option("--checkpoint", default=None, help="Resume from checkpoint path") def train(config, phase, timesteps, demo_mode, checkpoint): """Train the SpindleFlow RL delegation policy.""" try: from sb3_contrib import RecurrentPPO except ImportError: print("ERROR: sb3-contrib required. Run: pip install sb3-contrib") sys.exit(1) from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize from stable_baselines3.common.callbacks import ( CheckpointCallback, EvalCallback, BaseCallback ) from training.curriculum import CurriculumManager from training.specialist_improvement_callback import SpecialistImprovementCallback from env.spindleflow_env import SpindleFlowEnv from policy.lstm_policy import build_policy_kwargs with open(config) as f: cfg = yaml.safe_load(f) ppo_cfg = cfg["ppo"] training_cfg = cfg["training"] lstm_cfg = cfg["lstm"] total_ts = timesteps or training_cfg["total_timesteps"] curriculum = CurriculumManager(config_path=config) print(f"\n{'='*60}") print(f"SpindleFlow RL Training") print(f" Phase: {phase}") print(f" Timesteps: {total_ts}") print(f" Demo mode (real SpindleFlow): {demo_mode}") print(f"{'='*60}\n") def make_env(): return SpindleFlowEnv( config_path=config, phase=phase, use_real_spindleflow=demo_mode, ) n_envs = training_cfg.get("n_envs", 1) env = DummyVecEnv([make_env for _ in range(n_envs)]) env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=10.0) eval_env = DummyVecEnv([make_env]) eval_env = VecNormalize(eval_env, norm_obs=True, norm_reward=False) policy_kwargs = build_policy_kwargs( hidden_size=lstm_cfg["hidden_size"], num_lstm_layers=lstm_cfg["num_layers"], ) if checkpoint and os.path.exists(checkpoint): print(f"Loading checkpoint: {checkpoint}") model = RecurrentPPO.load(checkpoint, env=env) else: model = RecurrentPPO( policy="MlpLstmPolicy", env=env, learning_rate=ppo_cfg["learning_rate"], n_steps=ppo_cfg["n_steps"], batch_size=ppo_cfg["batch_size"], n_epochs=ppo_cfg["n_epochs"], gamma=ppo_cfg["gamma"], gae_lambda=ppo_cfg["gae_lambda"], clip_range=ppo_cfg["clip_range"], ent_coef=ppo_cfg["ent_coef"], vf_coef=ppo_cfg["vf_coef"], max_grad_norm=ppo_cfg["max_grad_norm"], policy_kwargs=policy_kwargs, tensorboard_log="./tensorboard_logs/", verbose=1, seed=training_cfg["seed"], device=training_cfg["device"], ) _max_specialists = cfg["environment"].get("max_specialists_per_episode", 6) class _RewardLogger(BaseCallback): def __init__(self, max_specialists: int, curriculum: CurriculumManager): super().__init__() self.episode_rewards: list[float] = [] self.episode_entropies: list[float] = [] self._running_reward = 0.0 self._running_entropy: list[float] = [] self._max_specialists = max_specialists self._curriculum = curriculum def _on_step(self): import numpy as np rewards = self.locals.get("rewards", []) dones = self.locals.get("dones", []) actions = self.locals.get("actions", None) if actions is not None: for action_vec in actions: n = self._max_specialists logits = action_vec[1:1 + n] logits = logits - logits.max() exp_l = np.exp(logits) probs = exp_l / (exp_l.sum() + 1e-8) entropy = float(-np.sum(probs * np.log(probs + 1e-8))) self._running_entropy.append(entropy) for r, d in zip(rewards, dones): self._running_reward += float(r) if d: ep_reward = self._running_reward self.episode_rewards.append(ep_reward) if self._running_entropy: self.episode_entropies.append( float(sum(self._running_entropy) / len(self._running_entropy)) ) self._running_entropy = [] self._running_reward = 0.0 self._curriculum.on_episode_end(ep_reward) return True reward_logger = _RewardLogger(max_specialists=_max_specialists, curriculum=curriculum) checkpoint_cb = CheckpointCallback( save_freq=2000, save_path="./checkpoints/", name_prefix="spindleflow_ppo", ) eval_cb = EvalCallback( eval_env, best_model_save_path="./checkpoints/best/", log_path="./eval_logs/", eval_freq=1000, n_eval_episodes=5, verbose=1, ) si_cfg = cfg.get("specialist_improvement", {}) improvement_cb = SpecialistImprovementCallback( improve_every_n_episodes=si_cfg.get("improve_every_n_episodes", 100), verbose=1, ) print(f"Starting training for {total_ts} timesteps...") print(f"TensorBoard: tensorboard --logdir tensorboard_logs/\n") model.learn( total_timesteps=total_ts, callback=[checkpoint_cb, eval_cb, reward_logger, improvement_cb], reset_num_timesteps=checkpoint is None, ) os.makedirs("checkpoints", exist_ok=True) model.save("checkpoints/spindleflow_final") env.save("checkpoints/vec_normalize.pkl") print("\nTraining complete. Model saved to checkpoints/spindleflow_final") # Save reward curve for the Streamlit dashboard import json, numpy as np ep = reward_logger.episode_rewards if ep: os.makedirs("demo/assets", exist_ok=True) step = max(1, len(ep) // 200) smoothed = [float(np.mean(ep[max(0, i-19):i+1])) for i in range(len(ep))] with open("demo/assets/reward_curve.json", "w") as f: json.dump({"episodes": list(range(len(ep)))[::step], "mean_rewards": smoothed[::step]}, f) print(f"Saved demo/assets/reward_curve.json ({len(ep)} episodes)") # Save entropy log for Training tab entropy chart ep_e = reward_logger.episode_entropies if ep_e: step_e = max(1, len(ep_e) // 200) with open("demo/assets/entropy_log.json", "w") as f: json.dump({ "episodes": list(range(len(ep_e)))[::step_e], "mean_entropies": ep_e[::step_e], }, f) print(f"Saved demo/assets/entropy_log.json ({len(ep_e)} episodes)") if __name__ == "__main__": train()