# Spaces: Running
"""
REINFORCE trainer for the EA Policy Network.
Runs on AMD ROCm (PyTorch CUDA interface) or CPU.
Trains in ~30 seconds on CPU, <5 seconds on ROCm.
"""
import json
import logging
import os
import time

import numpy as np
import torch
import torch.optim as optim

from backend.drl.policy_network import EAPolicyNetwork, get_device
from backend.drl.environment import EAEnvironment
# Module-level logger named after this module.
log = logging.getLogger(__name__)

# Checkpoints are stored in a "checkpoints" folder next to this source file.
CHECKPOINT_DIR = os.path.join(
    os.path.dirname(__file__),
    "checkpoints",
)
# Default file used for both saving and loading the policy.
DEFAULT_CHECKPOINT = os.path.join(
    CHECKPOINT_DIR,
    "ea_policy_v1.pt",
)
class REINFORCETrainer:
    """Vanilla policy gradient trainer for EAPolicyNetwork on EAEnvironment.

    Every episode is a single step: the environment is reset, the policy
    samples one action for the resulting state, the environment returns a
    reward, and one Adam step is taken on ``-(log_prob_sum * return)``.
    """

    def __init__(
        self,
        policy: EAPolicyNetwork | None = None,
        env: EAEnvironment | None = None,
        lr: float = 1e-3,
        device: torch.device | None = None,
    ):
        """Set up the policy, environment and optimizer.

        Args:
            policy: Network to train; a fresh ``EAPolicyNetwork()`` when None.
            env: Environment to train on; a fresh ``EAEnvironment()`` when None.
            lr: Learning rate passed to Adam.
            device: Device to train on; ``get_device()`` decides when None.
        """
        # Explicit ``is None`` checks: a supplied object must never be replaced
        # just because it happens to evaluate as falsy.
        self.device = get_device() if device is None else device
        self.policy = (EAPolicyNetwork() if policy is None else policy).to(self.device)
        self.env = EAEnvironment() if env is None else env
        self.optimizer = optim.Adam(self.policy.parameters(), lr=lr)
        # One dict per trained episode; see train() for the keys.
        self.training_log: list[dict] = []

    def compute_returns(self, rewards: list[float], gamma: float = 0.99) -> torch.Tensor:
        """Compute discounted returns (single-step env: just the reward).

        Args:
            rewards: Per-step rewards of one episode, in time order.
            gamma: Discount factor.

        Returns:
            A float32 tensor on ``self.device`` with one return per step.
            With two or more steps and a standard deviation above 1e-8 the
            returns are normalised to zero mean / unit std; otherwise they
            are returned unnormalised.
        """
        discounted: list[float] = []
        running = 0.0
        # Accumulate back-to-front with append + reverse (linear time) instead
        # of inserting at the head of the list on every step.
        for r in reversed(rewards):
            running = float(r) + gamma * running
            discounted.append(running)
        discounted.reverse()

        returns_t = torch.tensor(discounted, dtype=torch.float32, device=self.device)

        # The sample std of fewer than two values is NaN (and PyTorch warns
        # about it), so normalisation is only attempted for multi-step episodes.
        if returns_t.numel() > 1:
            std = returns_t.std()
            if std > 1e-8:
                returns_t = (returns_t - returns_t.mean()) / (std + 1e-8)
        return returns_t

    def train(self, n_episodes: int = 100, gamma: float = 0.99) -> list[dict]:
        """
        Run REINFORCE for n_episodes.

        Returns training_log, a list of
        {episode, reward, loss, avg_reward_last10} dicts (entries from this
        run are appended to any entries already present).
        Saves a checkpoint to DEFAULT_CHECKPOINT whenever the zero-based
        episode index is a multiple of 25 and after the last episode, then
        writes the training log as JSON next to the checkpoint.
        """
        log.info(f"Starting REINFORCE training: {n_episodes} episodes on {self.device}")
        os.makedirs(CHECKPOINT_DIR, exist_ok=True)
        t0 = time.perf_counter()
        episode_rewards: list[float] = []

        for episode in range(n_episodes):
            self.policy.train()
            state = self.env.reset()
            state_t = torch.FloatTensor(state).to(self.device)

            # Single-step episode: sample one ordering
            action_indices, log_prob_sum = self.policy.sample_action(state_t)
            action_np = action_indices.detach().cpu().numpy()
            _, reward, _ = self.env.step(action_np)
            # NOTE(review): the environment is not visible here and may return
            # a NumPy scalar. Coerce to a builtin float so the log stays
            # JSON-serialisable and the checkpoint holds only plain types.
            reward = float(reward)

            # REINFORCE update
            returns = self.compute_returns([reward], gamma)
            loss = -(log_prob_sum * returns[0])
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.policy.parameters(), 1.0)
            self.optimizer.step()

            # Read the scalar once; every .item() call forces a device sync.
            loss_value = loss.item()
            episode_rewards.append(reward)
            entry = {
                "episode": episode,
                "reward": round(reward, 4),
                "loss": round(loss_value, 6),
                "avg_reward_last10": round(float(np.mean(episode_rewards[-10:])), 4),
            }
            self.training_log.append(entry)

            if episode % 25 == 0 or episode == n_episodes - 1:
                log.info(
                    f" Episode {episode+1}/{n_episodes} | "
                    f"reward={reward:.4f} | loss={loss_value:.6f} | "
                    f"avg10={entry['avg_reward_last10']:.4f}"
                )
                self.save_checkpoint(DEFAULT_CHECKPOINT)

        elapsed = time.perf_counter() - t0
        # With zero episodes there is nothing to average; report NaN without
        # triggering NumPy's empty-slice warning.
        recent = episode_rewards[-20:]
        final_avg = float(np.mean(recent)) if recent else float("nan")
        log.info(f"Training completed in {elapsed:.1f}s | final avg reward: {final_avg:.4f}")

        # Save training log alongside checkpoint. Only the file extension is
        # swapped, so a ".pt" elsewhere in the directory path is left alone.
        log_path = os.path.splitext(DEFAULT_CHECKPOINT)[0] + "_log.json"
        with open(log_path, "w", encoding="utf-8") as f:
            json.dump(self.training_log, f, indent=2)
        return self.training_log

    def save_checkpoint(self, path: str) -> None:
        """Write policy weights, optimizer state, training log and device name to ``path``."""
        # Create the target directory so that custom paths work as well as the
        # default one that train() prepares.
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        torch.save({
            "policy_state_dict": self.policy.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "training_log": self.training_log,
            "device": str(self.device),
        }, path)

    def load_checkpoint(self, path: str = DEFAULT_CHECKPOINT) -> bool:
        """Restore trainer state from ``path``.

        Returns:
            False (after logging a warning) when the file does not exist,
            True once the policy weights — and, when present in the file, the
            optimizer state and training log — have been restored.
        """
        if not os.path.exists(path):
            log.warning(f"Checkpoint not found: {path}")
            return False
        # SECURITY: torch.load is pickle-based; only open checkpoints that come
        # from a trusted source.
        checkpoint = torch.load(path, map_location=self.device)
        self.policy.load_state_dict(checkpoint["policy_state_dict"])
        if "optimizer_state_dict" in checkpoint:
            self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        if "training_log" in checkpoint:
            self.training_log = checkpoint["training_log"]
        log.info(f"Loaded checkpoint from {path}")
        return True
def load_trained_policy(checkpoint_path: str = DEFAULT_CHECKPOINT) -> EAPolicyNetwork:
    """Build an EAPolicyNetwork for inference, restoring saved weights if available.

    The network is created on the device chosen by ``get_device()``. When
    ``checkpoint_path`` exists, its ``policy_state_dict`` entry is loaded;
    otherwise a warning is logged and the freshly initialised network is
    used. In both cases the network is put in eval mode before it is returned.
    """
    target_device = get_device()
    net = EAPolicyNetwork().to(target_device)

    if not os.path.exists(checkpoint_path):
        # No saved weights: fall through with the untrained network.
        log.warning(f"No checkpoint at {checkpoint_path} — using random policy. Run: python -m pipeline.train_drl")
    else:
        saved = torch.load(checkpoint_path, map_location=target_device)
        net.load_state_dict(saved["policy_state_dict"])
        log.info(f"Policy loaded from {checkpoint_path}")

    net.eval()
    return net
if __name__ == "__main__":
    # Script entry point: configure logging, train for 150 episodes with the
    # default trainer, then report where the checkpoint was written.
    logging.basicConfig(
        format="%(asctime)s %(levelname)s %(message)s",
        level=logging.INFO,
    )
    REINFORCETrainer().train(n_episodes=150)
    print(f"Checkpoint saved to: {DEFAULT_CHECKPOINT}")