| """ |
| Multi-Agent Online RL Training Loop. |
| |
| Uses alternating optimization: |
| Phase 1: Train Trader (freeze RM and PM policies, collect Trader trajectories). |
| Phase 2: Train RiskManager (freeze Trader and PM, collect RM trajectories). |
| (PM is trained similarly, but is often left as a rule-based agent for stability.) |
| |
| Trajectory collection: Step the MultiAgentTradingEnv AEC loop, collecting |
| (obs, action, reward, next_obs) per agent per step. |
| |
| GRPO/PPO fitting: Feed collected rollout buffers into TRL's GRPOTrainer |
| (for LLM-based agents) or a simple PPO loop (for numeric-action agents). |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import time |
| from collections import defaultdict |
| from pathlib import Path |
| from typing import Dict, List, Tuple, Any |
|
|
| import numpy as np |
| import torch |
|
|
| from env.multi_agent_env import ( |
| MultiAgentTradingEnv, |
| RISK_MANAGER, |
| PORTFOLIO_MGR, |
| TRADER, |
| ALL_AGENTS, |
| ) |
|
|
|
|
| |
|
|
class TrajectoryBuffer:
    """Per-agent rollout storage kept as three parallel lists.

    Index ``i`` of ``observations``, ``actions`` and ``rewards`` together
    describe one recorded transition.
    """

    def __init__(self):
        self.observations: List[np.ndarray] = []
        self.actions: List[Any] = []
        self.rewards: List[float] = []

    def add(self, obs: np.ndarray, action: Any, reward: float):
        """Record one (observation, action, reward) transition."""
        for store, item in (
            (self.observations, obs),
            (self.actions, action),
            (self.rewards, reward),
        ):
            store.append(item)

    def discounted_returns(self, gamma: float = 0.99) -> np.ndarray:
        """Return G_t = r_t + gamma * G_{t+1} for every step, as float32.

        The recursion is evaluated from the last reward back to the first.
        """
        count = len(self.rewards)
        out = np.zeros(count, dtype=np.float32)
        acc = 0.0
        for idx in range(count - 1, -1, -1):
            acc = self.rewards[idx] + gamma * acc
            out[idx] = acc
        return out

    def clear(self):
        """Empty all three lists in place."""
        for store in (self.observations, self.actions, self.rewards):
            store.clear()

    def __len__(self) -> int:
        return len(self.rewards)
|
|
|
|
| |
|
|
class RuleRiskManagerPolicy:
    """Baseline rule-based RM policy — sets constraints based on obs.

    ``act`` returns a float32 vector ``[size_limit, allow_new, force_reduce]``.
    Gaussian noise (std ``NOISE_STD``) is added to every component, and the
    result is then clipped to [0, 1]:

    * ``size_limit`` shrinks linearly with drawdown and is bounded to
      [0.05, 0.80] before noise.
    * ``allow_new`` is 1 while drawdown < ``ALLOW_NEW_MAX_DRAWDOWN``, else 0.
    * ``force_reduce`` is 1 once drawdown > ``FORCE_REDUCE_MIN_DRAWDOWN``,
      else 0.
    """

    ALLOW_NEW_MAX_DRAWDOWN = 0.20
    FORCE_REDUCE_MIN_DRAWDOWN = 0.25
    NOISE_STD = 0.05

    def act(self, obs: np.ndarray) -> np.ndarray:
        """Map an RM observation to a noisy constraint vector (see class doc)."""
        # NOTE(review): obs[19] is treated as the current drawdown; confirm
        # against the env's RM observation layout. Short obs default to 0.0.
        drawdown = float(obs[19]) if len(obs) > 19 else 0.0
        size_limit = float(np.clip(0.5 - drawdown * 2.0, 0.05, 0.80))
        allow_new = 1.0 if drawdown < self.ALLOW_NEW_MAX_DRAWDOWN else 0.0
        force_reduce = 1.0 if drawdown > self.FORCE_REDUCE_MIN_DRAWDOWN else 0.0

        noise = np.random.normal(0, self.NOISE_STD, 3)
        action = np.clip(
            np.array([size_limit, allow_new, force_reduce], dtype=np.float32) + noise,
            0.0,
            1.0,
        )
        # Adding the float64 noise upcasts the array. Cast back so the action
        # keeps the float32 dtype it was constructed with.
        return action.astype(np.float32)
|
|
|
|
class RulePortfolioManagerPolicy:
    """Baseline rule-based PM policy.

    ``act`` returns a float32 vector ``[cap_alloc, override_strength]`` in
    [0, 1]. Capital allocation grows with grade and shrinks with drawdown,
    bounded to [0.05, 0.90] before noise. Override strength is 0 before noise.
    Gaussian noise (std ``NOISE_STD``) is added to both components before
    clipping.
    """

    NOISE_STD = 0.03

    def act(self, obs: np.ndarray) -> np.ndarray:
        """Map a PM observation to a noisy allocation vector (see class doc)."""
        # NOTE(review): obs[22] is read as grade and obs[21] as drawdown;
        # confirm against the env's PM observation layout.
        grade = float(obs[22]) if len(obs) > 22 else 0.5
        drawdown = float(obs[21]) if len(obs) > 21 else 0.0
        cap_alloc = float(np.clip(0.3 + 0.5 * grade - drawdown * 1.5, 0.05, 0.90))
        override_str = 0.0

        noise = np.random.normal(0, self.NOISE_STD, 2)
        action = np.clip(
            np.array([cap_alloc, override_str], dtype=np.float32) + noise,
            0.0,
            1.0,
        )
        # Adding the float64 noise upcasts the array. Restore float32.
        return action.astype(np.float32)
|
|
|
|
class RuleTraderPolicy:
    """Baseline rule-based Trader policy for warm-up rollouts.

    Chooses a direction from two indicators and a randomized size bounded by
    the risk manager's size limit.
    """

    def act(self, obs: np.ndarray) -> Dict:
        """Return a Trader action dict with keys direction, size, sl and tp.

        ``direction`` is 1 when both indicators are low, 2 when both are high,
        and 0 otherwise (NOTE(review): presumably buy / sell / hold; confirm
        against the env's action encoding). ``size`` is a float32 array of
        shape (1,) that lies in [min(0.01, limit), limit], where ``limit`` is
        the RM size limit clamped to be non-negative. ``sl`` and ``tp`` are
        always zero.
        """
        # NOTE(review): obs[5] / obs[11] / obs[24] are read as RSI, Bollinger
        # position and the RM size limit; confirm against the env's layout.
        rsi = float(obs[5]) if len(obs) > 5 else 0.5
        bb_pos = float(obs[11]) if len(obs) > 11 else 0.5
        # A negative limit would yield a negative position size; clamp at 0.
        rm_limit = max(float(obs[24]) if len(obs) > 24 else 0.5, 0.0)

        if rsi < 0.35 and bb_pos < 0.25:
            direction = 1
        elif rsi > 0.65 and bb_pos > 0.75:
            direction = 2
        else:
            direction = 0

        # Keep low <= high for np.random.uniform. NumPy leaves high < low
        # undefined, which happens whenever the RM limit drops below 0.05.
        high = min(0.3, rm_limit)
        low = min(0.05, high)
        raw_size = np.random.uniform(low, high) + np.random.normal(0, 0.03)
        # Likewise keep the clip bounds ordered when the limit is below 0.01.
        size = float(np.clip(raw_size, min(0.01, rm_limit), rm_limit))
        return {
            "direction": direction,
            "size": np.array([size], dtype=np.float32),
            "sl": np.array([0.0], dtype=np.float32),
            "tp": np.array([0.0], dtype=np.float32),
        }
|
|
|
|
| |
|
|
def _agent_is_done(env: MultiAgentTradingEnv, agent: str) -> bool:
    """Return True if ``env`` flags ``agent`` as terminated or truncated.

    Reads the PettingZoo AEC ``terminations`` / ``truncations`` dicts when
    the env exposes them. An env without them never reports a finished agent.
    """
    terminations = getattr(env, "terminations", None) or {}
    truncations = getattr(env, "truncations", None) or {}
    return bool(terminations.get(agent, False) or truncations.get(agent, False))


def collect_rollout(
    env: MultiAgentTradingEnv,
    policies: Dict,
    max_steps: int = 300,
) -> Tuple[Dict[str, TrajectoryBuffer], Dict]:
    """
    Run one episode on the PettingZoo AEC env and record per-agent transitions.

    Each recorded transition holds the observation an agent acted on, the
    action it took, and the reward(s) reported for that agent *after* the
    action. That is the value of ``env.rewards[agent]`` at the agent's next
    turn. For the last transition of a finished episode, it is the value
    right after the final ``env.step``.

    Args:
        env: AEC env; reset at the start of the rollout.
        policies: mapping agent name -> object with ``act(obs)``. Agents
            missing from the mapping sample from their action space.
        max_steps: cap on the number of ``env.step`` calls. This counts AEC
            turns (one per agent action), not full rounds of all agents.

    Returns:
        ``(buffers, final_info)``. ``buffers`` maps each agent in
        ``ALL_AGENTS`` to its TrajectoryBuffer. ``final_info`` is the Trader's
        ``env.infos`` entry when the rollout stopped, whether the episode
        ended or ``max_steps`` was reached.
    """
    buffers = {ag: TrajectoryBuffer() for ag in ALL_AGENTS}
    env.reset()

    step_count = 0
    while env.agents and step_count < max_steps:
        agent = env.agent_selection
        buf = buffers[agent]

        # The reward visible at this turn was produced by steps taken since
        # this agent's previous action. Credit it to that action, not to the
        # one about to be chosen.
        if len(buf):
            buf.rewards[-1] += float(env.rewards.get(agent, 0.0))

        if _agent_is_done(env, agent):
            # PettingZoo AEC contract: a terminated/truncated agent must be
            # stepped with None so the env can retire it.
            action = None
        else:
            obs = env.observe(agent)
            policy = policies.get(agent)
            if policy is None:
                action = env.action_space(agent).sample()
            else:
                action = policy.act(obs)
            # Reward is filled in later, once the env reports it.
            buf.add(obs, action, 0.0)

        env.step(action)
        step_count += 1

    if not env.agents:
        # No later turn will read the rewards from the episode's final step,
        # so credit them to each agent's last transition now.
        for ag, ag_buf in buffers.items():
            if len(ag_buf):
                ag_buf.rewards[-1] += float(env.rewards.get(ag, 0.0))

    # Read the Trader's info even when the step cap stopped the rollout. The
    # original only read it when the episode ended, so capped rollouts
    # reported {}.
    final_info = env.infos.get(TRADER, {})
    return buffers, final_info
|
|
|
|
def compute_policy_gradient_loss(
    buffers: Dict[str, TrajectoryBuffer],
    target_agent: str,
    gamma: float = 0.99,
) -> float:
    """
    Score one agent's rollout by its mean discounted return.

    Despite the name, no loss or gradient is computed. The value is the
    average over time steps of the discounted returns G_t from the agent's
    buffer, used as a proxy for policy quality. Returns 0.0 when the agent
    has no buffer or its buffer is empty.
    """
    agent_buf = buffers.get(target_agent)
    if not agent_buf:
        return 0.0
    return float(agent_buf.discounted_returns(gamma=gamma).mean())
|
|
|
|
def train(
    n_episodes: int = 200,
    max_steps_ep: int = 300,
    gamma: float = 0.99,
    alternating_freq: int = 10,
    output_dir: str = "outputs/multi_agent",
    difficulty: str = "hard",
    save_every: int = 25,
    log_every: int = 10,
) -> Dict:
    """
    Main multi-agent training loop.

    Uses alternating optimization:
        Episodes [0, alternating_freq): optimize Trader
        Episodes [alternating_freq, 2*alternating_freq): optimize RiskManager
        Then restart cycle.

    For each non-optimized agent, uses the rule-based fallback.

    NOTE: this loop does not update any policy yet. Every agent runs its
    rule-based policy, and the scheduled ``opt_agent`` is only printed and
    stored in ``metrics["opt_agent"]``.

    Args:
        n_episodes: number of rollouts to collect.
        max_steps_ep: passed as the env's ``max_steps`` and as the rollout's
            step cap.
        gamma: discount factor for the per-agent return metrics.
        alternating_freq: episodes per phase of the Trader/RM schedule.
        output_dir: directory for checkpoints, ``best_episode.json`` and
            ``metrics_final.json``. Created if missing.
        difficulty: forwarded to ``MultiAgentTradingEnv``.
        save_every: write ``metrics_ep{N}.json`` every this many episodes.
        log_every: print a progress line every this many episodes.

    Returns:
        ``defaultdict(list)`` mapping metric name -> per-episode values.

    Raises:
        ValueError: if ``alternating_freq``, ``save_every`` or ``log_every``
            is < 1. All three are used as modulo divisors.
    """
    for arg_name, arg_value in (
        ("alternating_freq", alternating_freq),
        ("save_every", save_every),
        ("log_every", log_every),
    ):
        if arg_value < 1:
            raise ValueError(f"{arg_name} must be >= 1, got {arg_value}")

    out_path = Path(output_dir)
    out_path.mkdir(parents=True, exist_ok=True)

    env = MultiAgentTradingEnv(difficulty=difficulty, max_steps=max_steps_ep)

    policies = {
        RISK_MANAGER: RuleRiskManagerPolicy(),
        PORTFOLIO_MGR: RulePortfolioManagerPolicy(),
        TRADER: RuleTraderPolicy(),
    }

    metrics: Dict = defaultdict(list)
    best_trader_return = -np.inf

    print("=" * 60)
    print(" Multi-Agent Trading - Alternating Optimization Loop")
    print(f" Episodes: {n_episodes} | Steps/ep: {max_steps_ep} | gamma={gamma}")
    print("=" * 60)

    for ep in range(n_episodes):
        cycle_pos = ep % (2 * alternating_freq)
        opt_agent = TRADER if cycle_pos < alternating_freq else RISK_MANAGER

        t0 = time.perf_counter()
        buffers, info = collect_rollout(env, policies, max_steps=max_steps_ep)
        elapsed = time.perf_counter() - t0

        trader_return = compute_policy_gradient_loss(buffers, TRADER, gamma)
        rm_return = compute_policy_gradient_loss(buffers, RISK_MANAGER, gamma)
        pm_return = compute_policy_gradient_loss(buffers, PORTFOLIO_MGR, gamma)

        # Cast once so numpy scalars from the env never reach json.dump.
        pnl_pct = float(info.get("pnl_pct", 0.0))
        drawdown = float(info.get("max_drawdown", 0.0))
        grade = float(info.get("grade", 0.0))
        sharpe = float(info.get("sharpe_ratio", 0.0))
        # A present-but-None "governance" entry falls back to {}.
        governance = info.get("governance") or {}
        compliant = bool(governance.get("was_compliant", False))

        metrics["episode"].append(ep)
        metrics["trader_return"].append(float(trader_return))
        metrics["rm_return"].append(float(rm_return))
        metrics["pm_return"].append(float(pm_return))
        metrics["pnl_pct"].append(pnl_pct)
        metrics["max_drawdown"].append(drawdown)
        metrics["grade"].append(grade)
        metrics["sharpe"].append(sharpe)
        metrics["compliant"].append(compliant)
        metrics["opt_agent"].append(opt_agent)

        if ep % log_every == 0:
            print(
                f"Ep {ep:4d} [{opt_agent:20s}] | "
                f"Trader G={trader_return:+.4f} | RM G={rm_return:+.4f} | "
                f"PnL={pnl_pct:+.2%} | DD={drawdown:.2%} | Grade={grade:.3f} | "
                f"Sharpe={sharpe:+.3f} | {elapsed:.1f}s"
            )

        # Only consider episodes where the Trader acted more than 10 times.
        if trader_return > best_trader_return and len(buffers[TRADER]) > 10:
            best_trader_return = trader_return
            with open(out_path / "best_episode.json", "w", encoding="utf-8") as f:
                json.dump(
                    {"episode": ep, "trader_return": trader_return, "grade": grade},
                    f,
                    indent=2,
                )

        if ep % save_every == (save_every - 1):
            _save_metrics(metrics, out_path / f"metrics_ep{ep+1}.json")
            print(f" -> Checkpoint saved at episode {ep+1}")

    _save_metrics(metrics, out_path / "metrics_final.json")
    print("\nTraining complete.")
    print(f" Best Trader Return: {best_trader_return:.4f}")
    recent_grades = metrics["grade"][-20:]
    # np.mean([]) warns and returns nan; report nan explicitly when no episodes ran.
    mean_grade = float(np.mean(recent_grades)) if recent_grades else float("nan")
    print(f" Final Mean Grade: {mean_grade:.4f}")
    return metrics
|
|
|
|
def _save_metrics(metrics: Dict, path: Path):
    """Write ``metrics`` (name -> list of values) to ``path`` as indented JSON.

    NumPy scalars, including ``np.bool_``, are converted to the matching
    Python type via ``.item()``, so integers stay integers. NumPy arrays
    become nested lists. Other values are written unchanged and must already
    be JSON-serializable.
    """

    def _to_builtin(value):
        # json cannot encode NumPy scalars/arrays directly.
        if isinstance(value, np.generic):
            return value.item()
        if isinstance(value, np.ndarray):
            return value.tolist()
        return value

    serialized = {key: [_to_builtin(x) for x in values] for key, values in metrics.items()}
    with open(path, "w", encoding="utf-8") as f:
        json.dump(serialized, f, indent=2)
|
|
|
|
| |
|
|
if __name__ == "__main__":
    # Command-line entry point: parse options and forward them to train().
    cli = argparse.ArgumentParser(description="Multi-Agent Online RL Training")
    cli.add_argument("--episodes", type=int, default=200)
    cli.add_argument("--max-steps", type=int, default=300)
    cli.add_argument("--gamma", type=float, default=0.99)
    cli.add_argument(
        "--alt-freq",
        type=int,
        default=10,
        help="Alternating optimization frequency (episodes)",
    )
    cli.add_argument("--output-dir", type=str, default="outputs/multi_agent")
    cli.add_argument(
        "--difficulty",
        type=str,
        default="hard",
        choices=["easy", "medium", "hard"],
    )
    cli.add_argument("--save-every", type=int, default=25)
    opts = cli.parse_args()

    train_kwargs = dict(
        n_episodes=opts.episodes,
        max_steps_ep=opts.max_steps,
        gamma=opts.gamma,
        alternating_freq=opts.alt_freq,
        output_dir=opts.output_dir,
        difficulty=opts.difficulty,
        save_every=opts.save_every,
    )
    metrics = train(**train_kwargs)
|
|