Update training notebook and verifiers
Browse files- mate_training.ipynb +318 -424
- training/grpo_verifiers_multiagent.py +136 -0
- training/train_grpo_multiagent.py +42 -165
mate_training.ipynb
CHANGED
|
@@ -2,468 +2,415 @@
|
|
| 2 |
"cells": [
|
| 3 |
{
|
| 4 |
"cell_type": "markdown",
|
| 5 |
-
"id": "header",
|
| 6 |
"metadata": {},
|
| 7 |
"source": [
|
| 8 |
-
"#
|
| 9 |
"\n",
|
| 10 |
-
"
|
| 11 |
-
"\n",
|
| 12 |
-
"This notebook is the single deliverable that proves QuantHive's training story end-to-end:\n",
|
| 13 |
-
"\n",
|
| 14 |
-
"| Section | What It Does | GPU? |\n",
|
| 15 |
-
"|:---|:---|:---|\n",
|
| 16 |
-
"| \u00a71-2 | Install deps, clone repo | \u274c |\n",
|
| 17 |
-
"| \u00a73 | Validate PettingZoo AEC env (3 agents, obs shapes, governance) | \u274c |\n",
|
| 18 |
-
"| \u00a74 | Run PettingZoo API compliance test | \u274c |\n",
|
| 19 |
-
"| \u00a75 | Multi-agent REINFORCE training (rule-based warm-up) | \u274c |\n",
|
| 20 |
-
"| \u00a76 | Generate per-agent reward/loss plots | \u274c |\n",
|
| 21 |
-
"| \u00a77 | Preview governance-aware GRPO prompt | \u274c |\n",
|
| 22 |
-
"| \u00a78 | **GRPO training** \u2014 Qwen 2.5-1.5B via Unsloth | \u2705 T4 |\n",
|
| 23 |
-
"| \u00a79 | Display all committed training evidence | \u274c |\n",
|
| 24 |
-
"\n",
|
| 25 |
-
"**Architecture**: PettingZoo AEC with 3 independent RL agents:\n",
|
| 26 |
-
"- `risk_manager_0` \u2014 obs: 24D, act: Box(3) \u2014 rewarded for restricting risk\n",
|
| 27 |
-
"- `portfolio_manager_0` \u2014 obs: 27D, act: Box(2) \u2014 rewarded for portfolio grade\n",
|
| 28 |
-
"- `trader_0` \u2014 obs: 29D, act: Dict \u2014 rewarded for PnL + compliance\n",
|
| 29 |
-
"\n",
|
| 30 |
-
"Turn order: **RM \u2192 PM \u2192 Trader** per market cycle. Each agent's output becomes part of the next agent's observation."
|
| 31 |
]
|
| 32 |
},
|
| 33 |
{
|
| 34 |
"cell_type": "markdown",
|
| 35 |
-
"id": "sec1_header",
|
| 36 |
"metadata": {},
|
| 37 |
"source": [
|
| 38 |
-
"
|
| 39 |
-
"## 1. Install Dependencies\n",
|
| 40 |
"\n",
|
| 41 |
-
"
|
| 42 |
]
|
| 43 |
},
|
| 44 |
{
|
| 45 |
"cell_type": "code",
|
| 46 |
"execution_count": null,
|
| 47 |
-
"id": "install_deps",
|
| 48 |
"metadata": {},
|
| 49 |
"outputs": [],
|
| 50 |
"source": [
|
| 51 |
-
"
|
| 52 |
-
"
|
| 53 |
-
"
|
| 54 |
-
"\n",
|
| 55 |
-
"# Data sources\n",
|
| 56 |
-
"%pip install yfinance ccxt\n",
|
| 57 |
-
"\n",
|
| 58 |
-
"# ML / Training\n",
|
| 59 |
-
"%pip install torch transformers trl peft accelerate datasets safetensors\n",
|
| 60 |
"\n",
|
| 61 |
-
"
|
| 62 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
"\n",
|
| 64 |
-
"
|
| 65 |
-
"
|
|
|
|
|
|
|
| 66 |
]
|
| 67 |
},
|
| 68 |
{
|
| 69 |
"cell_type": "markdown",
|
| 70 |
-
"id": "sec2_header",
|
| 71 |
"metadata": {},
|
| 72 |
"source": [
|
| 73 |
-
"
|
| 74 |
-
"
|
|
|
|
| 75 |
]
|
| 76 |
},
|
| 77 |
{
|
| 78 |
"cell_type": "code",
|
| 79 |
"execution_count": null,
|
| 80 |
-
"id": "clone_repo",
|
| 81 |
"metadata": {},
|
| 82 |
"outputs": [],
|
| 83 |
"source": [
|
| 84 |
-
"
|
| 85 |
-
"
|
| 86 |
-
"\n",
|
| 87 |
-
"
|
| 88 |
-
"
|
| 89 |
-
"\n",
|
| 90 |
-
"
|
| 91 |
-
"
|
| 92 |
-
"\n",
|
| 93 |
-
"
|
| 94 |
-
"
|
| 95 |
-
"
|
|
|
|
| 96 |
"\n",
|
| 97 |
-
"
|
| 98 |
-
"print(
|
| 99 |
]
|
| 100 |
},
|
| 101 |
{
|
| 102 |
"cell_type": "markdown",
|
| 103 |
-
"id": "sec3_header",
|
| 104 |
"metadata": {},
|
| 105 |
"source": [
|
| 106 |
-
"-
|
| 107 |
-
"## 3. Validate PettingZoo Multi-Agent Environment\n",
|
| 108 |
"\n",
|
| 109 |
-
"
|
| 110 |
]
|
| 111 |
},
|
| 112 |
{
|
| 113 |
"cell_type": "code",
|
| 114 |
"execution_count": null,
|
| 115 |
-
"id": "validate_env_setup",
|
| 116 |
"metadata": {},
|
| 117 |
"outputs": [],
|
| 118 |
"source": [
|
|
|
|
| 119 |
"import numpy as np\n",
|
|
|
|
| 120 |
"from env.multi_agent_env import (\n",
|
| 121 |
-
" MultiAgentTradingEnv,\n",
|
| 122 |
-
" RISK_MANAGER,\n",
|
| 123 |
-
" PORTFOLIO_MGR,\n",
|
| 124 |
-
" TRADER,\n",
|
| 125 |
" ALL_AGENTS,\n",
|
| 126 |
" BASE_OBS_SIZE,\n",
|
| 127 |
-
"
|
| 128 |
" PM_MSG_SIZE,\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
")\n",
|
| 130 |
"\n",
|
|
|
|
|
|
|
|
|
|
| 131 |
"env = MultiAgentTradingEnv(difficulty=\"easy\", max_steps=50)\n",
|
| 132 |
-
"env.reset()\n",
|
| 133 |
-
"\n",
|
| 134 |
-
"print(\"
|
| 135 |
-
"print(\"
|
| 136 |
-
"print(\"
|
| 137 |
-
"print(f\"
|
| 138 |
-
"print(f\"
|
| 139 |
-
"print(f\"
|
| 140 |
-
"print(f\"
|
| 141 |
-
"print(f\"
|
| 142 |
-
"print(f\"
|
| 143 |
-
"print(f\"
|
| 144 |
-
"print(f\"
|
| 145 |
]
|
| 146 |
},
|
| 147 |
{
|
| 148 |
"cell_type": "code",
|
| 149 |
"execution_count": null,
|
| 150 |
-
"id": "validate_aec_cycle",
|
| 151 |
"metadata": {},
|
| 152 |
"outputs": [],
|
| 153 |
"source": [
|
| 154 |
-
"
|
| 155 |
-
"
|
|
|
|
| 156 |
"rm_action = np.array([0.20, 1.0, 0.0], dtype=np.float32)\n",
|
| 157 |
"env.step(rm_action)\n",
|
| 158 |
-
"print(f\"\u2705 RM acted: size_limit=0.20, allow_new=yes, force_reduce=no\")\n",
|
| 159 |
"\n",
|
| 160 |
-
"
|
| 161 |
-
"pm_action = np.array([0.50, 0.0], dtype=np.float32)\n",
|
| 162 |
"env.step(pm_action)\n",
|
| 163 |
-
"print(f\"\u2705 PM acted: cap_alloc=0.50, override=0.0\")\n",
|
| 164 |
"\n",
|
| 165 |
-
"# Trader proposes a RECKLESS buy: size=0.85 (way over RM's 0.20 limit)\n",
|
| 166 |
"trader_action = {\n",
|
| 167 |
" \"direction\": 1,\n",
|
| 168 |
-
" \"size\": np.array([0.
|
| 169 |
" \"sl\": np.array([0.0], dtype=np.float32),\n",
|
| 170 |
" \"tp\": np.array([0.0], dtype=np.float32),\n",
|
| 171 |
"}\n",
|
| 172 |
"env.step(trader_action)\n",
|
| 173 |
-
"print(f\"\u2705 Trader acted: proposed BUY size=0.85\")\n",
|
| 174 |
-
"\n",
|
| 175 |
-
"# Inspect governance outcome\n",
|
| 176 |
-
"info = env.infos[TRADER]\n",
|
| 177 |
-
"gov = info[\"governance\"]\n",
|
| 178 |
-
"\n",
|
| 179 |
-
"print(f\"\\n{'='*60}\")\n",
|
| 180 |
-
"print(f\" GOVERNANCE NEGOTIATION RESULT\")\n",
|
| 181 |
-
"print(f\"{'='*60}\")\n",
|
| 182 |
-
"print(f\" Proposed: direction={gov['proposed']['direction']}, size={gov['proposed']['size']:.2f}\")\n",
|
| 183 |
-
"print(f\" Executed: direction={gov['executed']['direction']}, size={gov['executed']['size']:.2f}\")\n",
|
| 184 |
-
"print(f\" Compliant: {gov['was_compliant']}\")\n",
|
| 185 |
-
"print(f\" Interventions ({len(gov['interventions'])}):\") \n",
|
| 186 |
-
"for iv in gov[\"interventions\"]:\n",
|
| 187 |
-
" print(f\" \u2192 {iv['agent']}: {iv['type']}\")\n",
|
| 188 |
-
"\n",
|
| 189 |
-
"print(f\"\\n Per-Agent Rewards:\")\n",
|
| 190 |
-
"for agent, r in info[\"rewards\"].items():\n",
|
| 191 |
-
" print(f\" {agent:25s} {r:+.4f}\")\n",
|
| 192 |
-
"\n",
|
| 193 |
-
"print(f\"\\n Portfolio: ${info['portfolio_value']:,.2f} | PnL: {info['pnl_pct']:+.2%}\")"
|
| 194 |
-
]
|
| 195 |
-
},
|
| 196 |
-
{
|
| 197 |
-
"cell_type": "markdown",
|
| 198 |
-
"id": "sec4_header",
|
| 199 |
-
"metadata": {},
|
| 200 |
-
"source": [
|
| 201 |
-
"---\n",
|
| 202 |
-
"## 4. PettingZoo API Compliance Test\n",
|
| 203 |
"\n",
|
| 204 |
-
"
|
|
|
|
|
|
|
|
|
|
| 205 |
]
|
| 206 |
},
|
| 207 |
{
|
| 208 |
"cell_type": "code",
|
| 209 |
"execution_count": null,
|
| 210 |
-
"id": "pz_api_test",
|
| 211 |
"metadata": {},
|
| 212 |
"outputs": [],
|
| 213 |
"source": [
|
| 214 |
"from pettingzoo.test import api_test\n",
|
| 215 |
"\n",
|
| 216 |
-
"
|
| 217 |
-
"api_test(
|
| 218 |
-
"print(\"
|
| 219 |
]
|
| 220 |
},
|
| 221 |
{
|
| 222 |
"cell_type": "markdown",
|
| 223 |
-
"id": "sec5_header",
|
| 224 |
"metadata": {},
|
| 225 |
"source": [
|
| 226 |
-
"--
|
| 227 |
-
"## 5. Multi-Agent REINFORCE Training (Rule-Based Policies)\n",
|
| 228 |
-
"\n",
|
| 229 |
-
"This is the CPU-friendly training path. Three rule-based policies (RM, PM, Trader) are trained using alternating optimization with REINFORCE-style policy gradients.\n",
|
| 230 |
"\n",
|
| 231 |
-
"
|
| 232 |
-
"- Episodes 0-9: optimize Trader (RM/PM frozen)\n",
|
| 233 |
-
"- Episodes 10-19: optimize Risk Manager (Trader/PM frozen)\n",
|
| 234 |
-
"- Repeat"
|
| 235 |
]
|
| 236 |
},
|
| 237 |
{
|
| 238 |
"cell_type": "code",
|
| 239 |
"execution_count": null,
|
| 240 |
-
"id": "reinforce_training",
|
| 241 |
"metadata": {},
|
| 242 |
"outputs": [],
|
| 243 |
"source": [
|
| 244 |
"from training.train_multi_agent import train\n",
|
| 245 |
"\n",
|
| 246 |
"metrics = train(\n",
|
| 247 |
-
" n_episodes=
|
| 248 |
-
" max_steps_ep=
|
| 249 |
" gamma=0.99,\n",
|
| 250 |
" alternating_freq=10,\n",
|
| 251 |
" difficulty=\"easy\",\n",
|
| 252 |
" output_dir=\"outputs/multi_agent\",\n",
|
| 253 |
-
" save_every=
|
| 254 |
-
")"
|
| 255 |
-
]
|
| 256 |
-
},
|
| 257 |
-
{
|
| 258 |
-
"cell_type": "markdown",
|
| 259 |
-
"id": "sec5b_header",
|
| 260 |
-
"metadata": {},
|
| 261 |
-
"source": [
|
| 262 |
-
"### 5.1 Curriculum Verification\n",
|
| 263 |
"\n",
|
| 264 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 265 |
]
|
| 266 |
},
|
| 267 |
{
|
| 268 |
"cell_type": "code",
|
| 269 |
"execution_count": null,
|
| 270 |
-
"id": "curriculum_check",
|
| 271 |
"metadata": {},
|
| 272 |
"outputs": [],
|
| 273 |
"source": [
|
| 274 |
-
"from training.train_multi_agent import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 275 |
"\n",
|
| 276 |
"policies = {\n",
|
| 277 |
-
" RISK_MANAGER:
|
| 278 |
" PORTFOLIO_MGR: RulePortfolioManagerPolicy(),\n",
|
| 279 |
-
" TRADER:
|
| 280 |
"}\n",
|
| 281 |
"\n",
|
| 282 |
-
"print(f\"{'Difficulty':<12} {'Episodes':<10} {'Mean Trader Return':<22} {'Mean PnL':<
|
| 283 |
-
"print(\"-\" *
|
| 284 |
"\n",
|
| 285 |
"for diff in [\"easy\", \"medium\", \"hard\"]:\n",
|
| 286 |
" returns, pnls, dds = [], [], []\n",
|
| 287 |
" test_env = MultiAgentTradingEnv(difficulty=diff, max_steps=100)\n",
|
| 288 |
" for _ in range(10):\n",
|
| 289 |
" buffers, info = collect_rollout(test_env, policies, max_steps=100)\n",
|
| 290 |
-
"
|
| 291 |
-
"
|
| 292 |
-
"
|
| 293 |
-
"
|
| 294 |
-
" print(f\"{diff:<12} {10:<10} {np.mean(returns):+.6f} {np.mean(pnls):+.4%} {np.mean(dds):.4%}\")"
|
| 295 |
]
|
| 296 |
},
|
| 297 |
{
|
| 298 |
"cell_type": "markdown",
|
| 299 |
-
"id": "sec6_header",
|
| 300 |
"metadata": {},
|
| 301 |
"source": [
|
| 302 |
-
"
|
| 303 |
-
"## 6. Generate Per-Agent Reward & Loss Plots\n",
|
| 304 |
"\n",
|
| 305 |
-
"
|
| 306 |
]
|
| 307 |
},
|
| 308 |
{
|
| 309 |
"cell_type": "code",
|
| 310 |
"execution_count": null,
|
| 311 |
-
"id": "generate_plots",
|
| 312 |
"metadata": {},
|
| 313 |
"outputs": [],
|
| 314 |
"source": [
|
| 315 |
-
"import json
|
|
|
|
|
|
|
| 316 |
"matplotlib.use(\"Agg\")\n",
|
| 317 |
"import matplotlib.pyplot as plt\n",
|
| 318 |
"\n",
|
| 319 |
"metrics_path = Path(\"outputs/multi_agent/metrics_final.json\")\n",
|
|
|
|
|
|
|
| 320 |
"if not metrics_path.exists():\n",
|
| 321 |
-
"
|
| 322 |
-
"
|
| 323 |
-
"
|
|
|
|
| 324 |
"\n",
|
| 325 |
-
"with open(metrics_path) as
|
| 326 |
-
" m = json.load(
|
| 327 |
"\n",
|
|
|
|
|
|
|
| 328 |
"episodes = m[\"episode\"]\n",
|
| 329 |
"n_eps = len(episodes)\n",
|
| 330 |
"print(f\"Loaded {n_eps} episodes from {metrics_path}\")\n",
|
| 331 |
"\n",
|
| 332 |
-
"
|
| 333 |
-
"
|
|
|
|
|
|
|
|
|
|
| 334 |
"\n",
|
| 335 |
-
"
|
| 336 |
-
"
|
| 337 |
-
"
|
| 338 |
-
"\n",
|
| 339 |
-
"
|
| 340 |
-
"
|
| 341 |
-
"
|
| 342 |
-
"
|
| 343 |
-
"ep_s
|
| 344 |
-
"\n",
|
| 345 |
-
"ax.
|
| 346 |
-
"ax.
|
| 347 |
-
"ax.
|
| 348 |
-
"ax.
|
| 349 |
-
"ax.
|
| 350 |
-
"ax.legend(); ax.grid(True, alpha=0.3)\n",
|
| 351 |
"plt.tight_layout()\n",
|
| 352 |
-
"fig.savefig(\"
|
| 353 |
"plt.show()\n",
|
| 354 |
-
"print(\"Saved: plots/reward_curve.png\")\n",
|
| 355 |
"\n",
|
| 356 |
-
"# \u2500\u2500 Loss Curve (PnL convergence) \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
|
| 357 |
"fig2, ax2 = plt.subplots(figsize=(12, 6))\n",
|
| 358 |
-
"pnl_s = smooth(m[\"pnl_pct\"],
|
| 359 |
-
"ax2.plot(episodes[:len(pnl_s)], pnl_s, color=\"#e74c3c\", linewidth=2)\n",
|
| 360 |
-
"ax2.axhline(y=0, color=\"gray\", linestyle=\"--\", alpha=0.5)\n",
|
| 361 |
"pnl_arr = np.array(pnl_s)\n",
|
| 362 |
-
"ax2.
|
| 363 |
-
"
|
| 364 |
-
"ax2.fill_between(episodes[:len(pnl_s)], 0, pnl_s,\n",
|
| 365 |
-
"
|
| 366 |
-
"ax2.set_xlabel(\"Episode\")
|
| 367 |
-
"ax2.
|
|
|
|
| 368 |
"ax2.grid(True, alpha=0.3)\n",
|
| 369 |
"plt.tight_layout()\n",
|
| 370 |
-
"fig2.savefig(\"
|
| 371 |
"plt.show()\n",
|
| 372 |
-
"print(\"Saved: plots/loss_curve.png\")\n",
|
| 373 |
"\n",
|
| 374 |
-
"# \u2500\u2500 Baseline vs Trained \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
|
| 375 |
"if n_eps >= 20:\n",
|
| 376 |
" fig3, ax3 = plt.subplots(figsize=(10, 6))\n",
|
| 377 |
" names = [\"Trader Return\", \"Grade\", \"Max Drawdown\", \"Sharpe\"]\n",
|
| 378 |
-
" early = [np.mean(m[
|
| 379 |
-
" late
|
| 380 |
" x = np.arange(len(names))\n",
|
| 381 |
" ax3.bar(x - 0.175, early, 0.35, label=\"First 20 eps\", color=\"#e74c3c\", alpha=0.8)\n",
|
| 382 |
-
" ax3.bar(x + 0.175, late,
|
| 383 |
-
" ax3.set_ylabel(\"Value\")
|
| 384 |
-
" ax3.
|
| 385 |
-
" ax3.
|
|
|
|
|
|
|
|
|
|
| 386 |
" plt.tight_layout()\n",
|
| 387 |
-
" fig3.savefig(\"
|
| 388 |
" plt.show()\n",
|
| 389 |
-
"
|
|
|
|
| 390 |
]
|
| 391 |
},
|
| 392 |
{
|
| 393 |
"cell_type": "markdown",
|
| 394 |
-
"id": "sec7_header",
|
| 395 |
"metadata": {},
|
| 396 |
"source": [
|
| 397 |
-
"-
|
| 398 |
-
"## 7. Preview Governance-Aware GRPO Prompt\n",
|
| 399 |
"\n",
|
| 400 |
-
"
|
| 401 |
-
"Each prompt includes the RM and PM governance messages \u2014 the Trader must learn to read and respect them."
|
| 402 |
]
|
| 403 |
},
|
| 404 |
{
|
| 405 |
"cell_type": "code",
|
| 406 |
"execution_count": null,
|
| 407 |
-
"id": "preview_prompt",
|
| 408 |
"metadata": {},
|
| 409 |
"outputs": [],
|
| 410 |
"source": [
|
| 411 |
-
"from training.
|
| 412 |
"\n",
|
| 413 |
"scenarios = generate_pz_scenarios(n=3, difficulty=\"easy\", max_env_steps=30)\n",
|
| 414 |
-
"print(f\"Generated {len(scenarios)} scenarios
|
| 415 |
-
"\n",
|
| 416 |
-
"for i,
|
| 417 |
-
" print(
|
| 418 |
-
" print(f\"
|
| 419 |
-
" print(
|
| 420 |
-
"
|
| 421 |
-
"
|
| 422 |
-
"\n",
|
| 423 |
-
"
|
| 424 |
-
"
|
| 425 |
-
"
|
| 426 |
-
"
|
| 427 |
-
"
|
|
|
|
|
|
|
|
|
|
| 428 |
]
|
| 429 |
},
|
| 430 |
{
|
| 431 |
"cell_type": "code",
|
| 432 |
"execution_count": null,
|
| 433 |
-
"id": "verify_verifiers",
|
| 434 |
"metadata": {},
|
| 435 |
"outputs": [],
|
| 436 |
"source": [
|
| 437 |
-
"
|
| 438 |
-
"from training.
|
| 439 |
-
" risk_reward_func_multiagent,\n",
|
| 440 |
" governance_reward_func_multiagent,\n",
|
|
|
|
| 441 |
")\n",
|
| 442 |
-
"from env.reward import format_reward_func, alignment_reward_func, profit_reward_func\n",
|
| 443 |
"\n",
|
| 444 |
"test_prompt = build_prompt_multiagent(scenarios[0])\n",
|
|
|
|
| 445 |
"\n",
|
| 446 |
-
"# A compliant completion\n",
|
| 447 |
"compliant = (\n",
|
| 448 |
-
"
|
| 449 |
-
"
|
| 450 |
-
"
|
| 451 |
-
"
|
| 452 |
-
" '
|
| 453 |
-
"
|
| 454 |
-
"
|
| 455 |
-
" '{{\"direction\": 1, \"size\": {size:.2f}, \"sl\": 49000, \"tp\": 52000}}\\n'\n",
|
| 456 |
-
" '</action>'\n",
|
| 457 |
-
").format(\n",
|
| 458 |
-
" limit=scenarios[0][\"rm_size_limit\"],\n",
|
| 459 |
-
" cap=scenarios[0][\"pm_cap_alloc\"],\n",
|
| 460 |
-
" size=min(scenarios[0][\"rm_size_limit\"] * 0.7, scenarios[0][\"pm_cap_alloc\"] * 0.7),\n",
|
| 461 |
-
")\n",
|
| 462 |
"\n",
|
| 463 |
-
"# A reckless completion\n",
|
| 464 |
"reckless = (\n",
|
| 465 |
-
"
|
| 466 |
-
"
|
| 467 |
")\n",
|
| 468 |
"\n",
|
| 469 |
"prompts = [test_prompt, test_prompt]\n",
|
|
@@ -472,241 +419,188 @@
|
|
| 472 |
"print(f\"{'Verifier':<25} {'Compliant':<12} {'Reckless':<12}\")\n",
|
| 473 |
"print(\"-\" * 49)\n",
|
| 474 |
"for name, func in [\n",
|
| 475 |
-
" (\"Format\",
|
| 476 |
-
" (\"Alignment\",
|
| 477 |
-
" (\"Risk
|
| 478 |
-
" (\"Profit\",
|
| 479 |
-
" (\"Governance
|
| 480 |
"]:\n",
|
| 481 |
" scores = func(prompts, completions)\n",
|
| 482 |
-
" print(f\"{name:<25} {scores[0]:<12.2f} {scores[1]:<12.2f}\")"
|
| 483 |
]
|
| 484 |
},
|
| 485 |
{
|
| 486 |
"cell_type": "markdown",
|
| 487 |
-
"id": "sec8_header",
|
| 488 |
"metadata": {},
|
| 489 |
"source": [
|
| 490 |
-
"
|
| 491 |
-
"## 8. GRPO Training \u2014 Qwen 2.5-1.5B (GPU Required)\n",
|
| 492 |
-
"\n",
|
| 493 |
-
"This section trains the Trader agent as a language model using **GRPO** (Group Relative Policy Optimization) via Unsloth + TRL.\n",
|
| 494 |
-
"\n",
|
| 495 |
-
"**Requirements**: CUDA GPU (Colab T4 is sufficient).\n",
|
| 496 |
"\n",
|
| 497 |
-
"
|
| 498 |
-
"1. **Format** \u2014 valid `<thought>` + `<action>` XML tags\n",
|
| 499 |
-
"2. **Alignment** \u2014 reasoning matches market signals (anti-hallucination)\n",
|
| 500 |
-
"3. **Risk** \u2014 size \u2264 RM's dynamic `size_limit` (reads from governance in prompt)\n",
|
| 501 |
-
"4. **Profit** \u2014 direction matches price trend\n",
|
| 502 |
-
"5. **Governance** \u2014 would this action pass without intervention? Checks compliance against *both* RM and PM constraints"
|
| 503 |
]
|
| 504 |
},
|
| 505 |
{
|
| 506 |
"cell_type": "code",
|
| 507 |
"execution_count": null,
|
| 508 |
-
"id": "grpo_training",
|
| 509 |
"metadata": {},
|
| 510 |
"outputs": [],
|
| 511 |
"source": [
|
| 512 |
-
"
|
| 513 |
-
"
|
| 514 |
-
"\n",
|
| 515 |
-
"
|
| 516 |
-
"
|
| 517 |
-
"
|
| 518 |
-
"\n",
|
| 519 |
-
"
|
| 520 |
-
"
|
| 521 |
-
"\n",
|
| 522 |
-
"
|
| 523 |
-
"
|
| 524 |
-
"
|
| 525 |
-
"
|
| 526 |
-
"
|
| 527 |
-
"
|
| 528 |
-
"
|
| 529 |
-
"
|
| 530 |
-
"
|
| 531 |
-
"
|
| 532 |
-
"
|
| 533 |
-
"
|
| 534 |
-
"
|
| 535 |
-
"
|
| 536 |
-
"
|
| 537 |
-
"
|
| 538 |
-
"
|
| 539 |
-
"
|
| 540 |
-
"
|
| 541 |
-
"
|
| 542 |
-
"
|
| 543 |
-
"
|
| 544 |
-
"
|
| 545 |
-
"
|
| 546 |
-
"
|
| 547 |
-
"
|
| 548 |
-
"
|
| 549 |
-
"
|
| 550 |
-
"
|
| 551 |
-
"
|
| 552 |
-
"
|
| 553 |
-
"
|
| 554 |
-
"
|
| 555 |
-
"
|
| 556 |
-
"
|
| 557 |
-
"
|
| 558 |
-
"
|
| 559 |
-
"
|
| 560 |
-
"
|
| 561 |
-
"
|
| 562 |
-
"
|
| 563 |
-
"
|
| 564 |
-
"
|
| 565 |
-
"
|
| 566 |
-
"
|
| 567 |
-
"
|
| 568 |
-
"
|
| 569 |
-
"
|
| 570 |
-
"
|
| 571 |
-
"
|
| 572 |
-
"from utils.plotting import plot_training_results\n",
|
| 573 |
-
"
|
| 574 |
-
"
|
| 575 |
-
"
|
| 576 |
-
"
|
| 577 |
-
"
|
| 578 |
-
"
|
| 579 |
-
"
|
| 580 |
-
"
|
| 581 |
-
"
|
| 582 |
]
|
| 583 |
},
|
| 584 |
{
|
| 585 |
"cell_type": "markdown",
|
| 586 |
-
"id": "sec9_header",
|
| 587 |
"metadata": {},
|
| 588 |
"source": [
|
| 589 |
-
"
|
| 590 |
-
"## 9. Display All Training Evidence\n",
|
| 591 |
-
"\n",
|
| 592 |
-
"The repository ships committed training plots. This section displays them (whether generated in this notebook run or previously committed)."
|
| 593 |
]
|
| 594 |
},
|
| 595 |
{
|
| 596 |
"cell_type": "code",
|
| 597 |
"execution_count": null,
|
| 598 |
-
"id": "display_plots",
|
| 599 |
"metadata": {},
|
| 600 |
"outputs": [],
|
| 601 |
"source": [
|
| 602 |
-
"from IPython.display import Image,
|
| 603 |
"\n",
|
| 604 |
"plot_files = [\n",
|
| 605 |
-
" (\"plots/reward_curve.png\",
|
| 606 |
-
" (\"plots/loss_curve.png\",
|
| 607 |
-
" (\"plots/baseline_comparison.png\", \"Early vs Late Training
|
| 608 |
-
" (\"plots/grade_progression.png\", \"Portfolio Grade + Sharpe Over Training\"),\n",
|
| 609 |
"]\n",
|
| 610 |
"\n",
|
| 611 |
"for path, title in plot_files:\n",
|
| 612 |
" if Path(path).exists():\n",
|
| 613 |
" display(Markdown(f\"### {title}\"))\n",
|
| 614 |
-
" display(Image(filename=path, width=
|
| 615 |
" else:\n",
|
| 616 |
-
" print(f\"
|
| 617 |
]
|
| 618 |
},
|
| 619 |
{
|
| 620 |
"cell_type": "markdown",
|
| 621 |
-
"id": "sec10_header",
|
| 622 |
"metadata": {},
|
| 623 |
"source": [
|
| 624 |
-
"
|
| 625 |
-
"## 10. Submission Checklist"
|
| 626 |
]
|
| 627 |
},
|
| 628 |
{
|
| 629 |
"cell_type": "code",
|
| 630 |
"execution_count": null,
|
| 631 |
-
"id": "checklist",
|
| 632 |
"metadata": {},
|
| 633 |
"outputs": [],
|
| 634 |
"source": [
|
|
|
|
|
|
|
| 635 |
"checks = [\n",
|
| 636 |
-
" (\"openenv.yaml
|
| 637 |
-
" (\"README.md\",
|
| 638 |
-
" (\"WRITEUP.md\",
|
| 639 |
-
" (\"Multi-agent env\",
|
| 640 |
-
" (\"REINFORCE trainer\",
|
| 641 |
-
" (\"GRPO trainer
|
| 642 |
-
" (\"
|
| 643 |
-
" (\"
|
| 644 |
-
" (\"
|
| 645 |
-
" (\"
|
| 646 |
-
" (\"
|
| 647 |
-
" (\"
|
| 648 |
-
" (\"
|
|
|
|
| 649 |
"]\n",
|
| 650 |
"\n",
|
| 651 |
-
"print(f\"{'Deliverable':<
|
| 652 |
-
"print(\"-\" *
|
| 653 |
"for name, path in checks:\n",
|
| 654 |
-
" status = \"
|
| 655 |
-
" print(f\"{name:<
|
| 656 |
-
"\n",
|
| 657 |
-
"
|
| 658 |
-
"
|
| 659 |
-
"
|
| 660 |
-
"
|
| 661 |
-
"
|
| 662 |
-
" ep = manifest.get(\"environment\", {}).get(\"entry_point\", \"\")\n",
|
| 663 |
-
" is_pz = \"multi_agent_env\" in ep\n",
|
| 664 |
-
" print(f\"\\nopenenv.yaml entry_point: {ep}\")\n",
|
| 665 |
-
" print(f\"{'\u2705' if is_pz else '\u274c'} Points to {'PettingZoo' if is_pz else 'WRONG'} env\")\n",
|
| 666 |
-
"except Exception:\n",
|
| 667 |
-
" print(\"\\n\u26a0 Could not parse openenv.yaml (pyyaml may not be installed)\")"
|
| 668 |
-
]
|
| 669 |
-
},
|
| 670 |
-
{
|
| 671 |
-
"cell_type": "markdown",
|
| 672 |
-
"id": "footer",
|
| 673 |
-
"metadata": {},
|
| 674 |
-
"source": [
|
| 675 |
-
"---\n",
|
| 676 |
-
"\n",
|
| 677 |
-
"## Summary\n",
|
| 678 |
-
"\n",
|
| 679 |
-
"This notebook establishes the full QuantHive training evidence chain:\n",
|
| 680 |
-
"\n",
|
| 681 |
-
"1. **Environment validated** \u2014 `MultiAgentTradingEnv` passes PettingZoo's official API test (50 cycles)\n",
|
| 682 |
-
"2. **Governance works** \u2014 RM clamped a reckless 85% trade down to 20%, logged 3 interventions\n",
|
| 683 |
-
"3. **Multi-agent training** \u2014 alternating optimization with per-agent reward curves\n",
|
| 684 |
-
"4. **GRPO ready** \u2014 prompts include RM/PM constraints; verifiers check dynamic compliance\n",
|
| 685 |
-
"5. **Curriculum works** \u2014 env runs at easy/medium/hard with decreasing Trader returns\n",
|
| 686 |
-
"\n",
|
| 687 |
-
"**Key differentiator**: GRPO verifiers #3 (Risk) and #5 (Governance) check compliance against the RM's *learned, dynamic* `size_limit` \u2014 not a hardcoded constant. This means the Trader must learn to **read and respect** governance messages.\n",
|
| 688 |
-
"\n",
|
| 689 |
-
"---\n",
|
| 690 |
-
"*Built for the OpenEnv April '26 Hackathon | Theme 1: Multi-Agent Interactions* \n",
|
| 691 |
-
"*Author: Arka Sarkar*"
|
| 692 |
]
|
| 693 |
}
|
| 694 |
],
|
| 695 |
"metadata": {
|
| 696 |
-
"accelerator": "GPU",
|
| 697 |
-
"colab": {
|
| 698 |
-
"gpuType": "T4",
|
| 699 |
-
"provenance": []
|
| 700 |
-
},
|
| 701 |
"kernelspec": {
|
| 702 |
"display_name": "Python 3",
|
|
|
|
| 703 |
"name": "python3"
|
| 704 |
},
|
| 705 |
"language_info": {
|
| 706 |
-
"name": "python"
|
| 707 |
-
"version": "3.12.0"
|
| 708 |
}
|
| 709 |
},
|
| 710 |
"nbformat": 4,
|
| 711 |
"nbformat_minor": 5
|
| 712 |
-
}
|
|
|
|
| 2 |
"cells": [
|
| 3 |
{
|
| 4 |
"cell_type": "markdown",
|
|
|
|
| 5 |
"metadata": {},
|
| 6 |
"source": [
|
| 7 |
+
"# QuantHive Multi-Agent Training Notebook\n",
|
| 8 |
"\n",
|
| 9 |
+
"Rerunnable notebook for validating the PettingZoo environment, running the rule-based multi-agent trainer, previewing governance-aware prompts, and optionally launching GRPO training on a GPU runtime.\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
]
|
| 11 |
},
|
| 12 |
{
|
| 13 |
"cell_type": "markdown",
|
|
|
|
| 14 |
"metadata": {},
|
| 15 |
"source": [
|
| 16 |
+
"## 1. Bootstrap The Repo\n",
|
|
|
|
| 17 |
"\n",
|
| 18 |
+
"Clone from GitHub when the notebook is running outside the repo. If the current working directory already looks like the repository, reuse it.\n"
|
| 19 |
]
|
| 20 |
},
|
| 21 |
{
|
| 22 |
"cell_type": "code",
|
| 23 |
"execution_count": null,
|
|
|
|
| 24 |
"metadata": {},
|
| 25 |
"outputs": [],
|
| 26 |
"source": [
|
| 27 |
+
"import os\n",
|
| 28 |
+
"import subprocess\n",
|
| 29 |
+
"import sys\n",
|
| 30 |
+
"from pathlib import Path\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
"\n",
|
| 32 |
+
"REPO_URL = \"https://github.com/ARKAISW/multi-agent-trading-env.git\"\n",
|
| 33 |
+
"\n",
|
| 34 |
+
"def looks_like_repo(path: Path) -> bool:\n",
|
| 35 |
+
" return (\n",
|
| 36 |
+
" (path / \"openenv.yaml\").exists()\n",
|
| 37 |
+
" and (path / \"env\").exists()\n",
|
| 38 |
+
" and (path / \"training\").exists()\n",
|
| 39 |
+
" )\n",
|
| 40 |
+
"\n",
|
| 41 |
+
"start_dir = Path.cwd()\n",
|
| 42 |
+
"if looks_like_repo(start_dir):\n",
|
| 43 |
+
" REPO_DIR = start_dir\n",
|
| 44 |
+
"else:\n",
|
| 45 |
+
" clone_parent = Path(\"/content\") if Path(\"/content\").exists() else start_dir\n",
|
| 46 |
+
" REPO_DIR = clone_parent / \"multi-agent-trading-env\"\n",
|
| 47 |
+
" if not REPO_DIR.exists():\n",
|
| 48 |
+
" subprocess.check_call([\"git\", \"clone\", REPO_URL, str(REPO_DIR)])\n",
|
| 49 |
+
"\n",
|
| 50 |
+
"os.chdir(REPO_DIR)\n",
|
| 51 |
+
"if str(REPO_DIR) not in sys.path:\n",
|
| 52 |
+
" sys.path.insert(0, str(REPO_DIR))\n",
|
| 53 |
"\n",
|
| 54 |
+
"commit = subprocess.check_output([\"git\", \"rev-parse\", \"--short\", \"HEAD\"], text=True).strip()\n",
|
| 55 |
+
"print(f\"Working directory: {Path.cwd()}\")\n",
|
| 56 |
+
"print(f\"Repo commit: {commit}\")\n",
|
| 57 |
+
"print(f\"Python: {sys.version.split()[0]}\")\n"
|
| 58 |
]
|
| 59 |
},
|
| 60 |
{
|
| 61 |
"cell_type": "markdown",
|
|
|
|
| 62 |
"metadata": {},
|
| 63 |
"source": [
|
| 64 |
+
"## 2. Install Notebook Dependencies\n",
|
| 65 |
+
"\n",
|
| 66 |
+
"Install the lightweight stack needed for environment validation, rule-based training, plots, and prompt preview. The heavier GRPO packages stay in the optional GPU section.\n"
|
| 67 |
]
|
| 68 |
},
|
| 69 |
{
|
| 70 |
"cell_type": "code",
|
| 71 |
"execution_count": null,
|
|
|
|
| 72 |
"metadata": {},
|
| 73 |
"outputs": [],
|
| 74 |
"source": [
|
| 75 |
+
"BASE_PACKAGES = [\n",
|
| 76 |
+
" \"openenv\",\n",
|
| 77 |
+
" \"pyyaml\",\n",
|
| 78 |
+
" \"pettingzoo>=1.24.0\",\n",
|
| 79 |
+
" \"gymnasium\",\n",
|
| 80 |
+
" \"numpy\",\n",
|
| 81 |
+
" \"pandas\",\n",
|
| 82 |
+
" \"matplotlib\",\n",
|
| 83 |
+
" \"scipy\",\n",
|
| 84 |
+
" \"torch\",\n",
|
| 85 |
+
" \"yfinance\",\n",
|
| 86 |
+
" \"ccxt\",\n",
|
| 87 |
+
"]\n",
|
| 88 |
"\n",
|
| 89 |
+
"subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", *BASE_PACKAGES])\n",
|
| 90 |
+
"print(\"Installed base notebook dependencies.\")\n"
|
| 91 |
]
|
| 92 |
},
|
| 93 |
{
|
| 94 |
"cell_type": "markdown",
|
|
|
|
| 95 |
"metadata": {},
|
| 96 |
"source": [
|
| 97 |
+
"## 3. Validate The PettingZoo Multi-Agent Environment\n",
|
|
|
|
| 98 |
"\n",
|
| 99 |
+
"Probe the constructor first so the notebook does not assume the wrong signature, then instantiate and inspect observation and action shapes.\n"
|
| 100 |
]
|
| 101 |
},
|
| 102 |
{
|
| 103 |
"cell_type": "code",
|
| 104 |
"execution_count": null,
|
|
|
|
| 105 |
"metadata": {},
|
| 106 |
"outputs": [],
|
| 107 |
"source": [
|
| 108 |
+
"import inspect\n",
|
| 109 |
"import numpy as np\n",
|
| 110 |
+
"\n",
|
| 111 |
"from env.multi_agent_env import (\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
" ALL_AGENTS,\n",
|
| 113 |
" BASE_OBS_SIZE,\n",
|
| 114 |
+
" MultiAgentTradingEnv,\n",
|
| 115 |
" PM_MSG_SIZE,\n",
|
| 116 |
+
" PORTFOLIO_MGR,\n",
|
| 117 |
+
" RISK_MANAGER,\n",
|
| 118 |
+
" RM_MSG_SIZE,\n",
|
| 119 |
+
" TRADER,\n",
|
| 120 |
")\n",
|
| 121 |
"\n",
|
| 122 |
+
"print(\"MultiAgentTradingEnv signature:\")\n",
|
| 123 |
+
"print(inspect.signature(MultiAgentTradingEnv))\n",
|
| 124 |
+
"\n",
|
| 125 |
"env = MultiAgentTradingEnv(difficulty=\"easy\", max_steps=50)\n",
|
| 126 |
+
"reset_result = env.reset(seed=7)\n",
|
| 127 |
+
"\n",
|
| 128 |
+
"print(\"\\nEnvironment summary\")\n",
|
| 129 |
+
"print(\"-\" * 60)\n",
|
| 130 |
+
"print(f\"reset() returned: {reset_result}\")\n",
|
| 131 |
+
"print(f\"Agents: {env.agents}\")\n",
|
| 132 |
+
"print(f\"Turn order: RM -> PM -> Trader\")\n",
|
| 133 |
+
"print(f\"RM obs: {env.observe(RISK_MANAGER).shape} (base={BASE_OBS_SIZE})\")\n",
|
| 134 |
+
"print(f\"PM obs: {env.observe(PORTFOLIO_MGR).shape} (base+rm={BASE_OBS_SIZE + RM_MSG_SIZE})\")\n",
|
| 135 |
+
"print(f\"Trader obs: {env.observe(TRADER).shape} (base+rm+pm={BASE_OBS_SIZE + RM_MSG_SIZE + PM_MSG_SIZE})\")\n",
|
| 136 |
+
"print(f\"RM action space: {env.action_space(RISK_MANAGER)}\")\n",
|
| 137 |
+
"print(f\"PM action space: {env.action_space(PORTFOLIO_MGR)}\")\n",
|
| 138 |
+
"print(f\"Trader action space: {env.action_space(TRADER)}\")\n"
|
| 139 |
]
|
| 140 |
},
|
| 141 |
{
|
| 142 |
"cell_type": "code",
|
| 143 |
"execution_count": null,
|
|
|
|
| 144 |
"metadata": {},
|
| 145 |
"outputs": [],
|
| 146 |
"source": [
|
| 147 |
+
"env = MultiAgentTradingEnv(difficulty=\"easy\", max_steps=10)\n",
|
| 148 |
+
"env.reset(seed=7)\n",
|
| 149 |
+
"\n",
|
| 150 |
"rm_action = np.array([0.20, 1.0, 0.0], dtype=np.float32)\n",
|
| 151 |
"env.step(rm_action)\n",
|
|
|
|
| 152 |
"\n",
|
| 153 |
+
"pm_action = np.array([0.35, 0.0], dtype=np.float32)\n",
|
|
|
|
| 154 |
"env.step(pm_action)\n",
|
|
|
|
| 155 |
"\n",
|
|
|
|
| 156 |
"trader_action = {\n",
|
| 157 |
" \"direction\": 1,\n",
|
| 158 |
+
" \"size\": np.array([0.10], dtype=np.float32),\n",
|
| 159 |
" \"sl\": np.array([0.0], dtype=np.float32),\n",
|
| 160 |
" \"tp\": np.array([0.0], dtype=np.float32),\n",
|
| 161 |
"}\n",
|
| 162 |
"env.step(trader_action)\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
"\n",
|
| 164 |
+
"print(\"Completed one AEC cycle.\")\n",
|
| 165 |
+
"print(f\"Current step: {env._current_step}\")\n",
|
| 166 |
+
"print(f\"Current agent selection: {env.agent_selection}\")\n",
|
| 167 |
+
"print(f\"Latest rewards: {env.rewards}\")\n"
|
| 168 |
]
|
| 169 |
},
|
| 170 |
{
|
| 171 |
"cell_type": "code",
|
| 172 |
"execution_count": null,
|
|
|
|
| 173 |
"metadata": {},
|
| 174 |
"outputs": [],
|
| 175 |
"source": [
|
| 176 |
"from pettingzoo.test import api_test\n",
|
| 177 |
"\n",
|
| 178 |
+
"api_env = MultiAgentTradingEnv(difficulty=\"easy\", max_steps=20)\n",
|
| 179 |
+
"api_test(api_env, num_cycles=20, verbose_progress=True)\n",
|
| 180 |
+
"print(\"PettingZoo API test passed.\")\n"
|
| 181 |
]
|
| 182 |
},
|
| 183 |
{
|
| 184 |
"cell_type": "markdown",
|
|
|
|
| 185 |
"metadata": {},
|
| 186 |
"source": [
|
| 187 |
+
"## 4. Run The Rule-Based Multi-Agent Trainer\n",
|
|
|
|
|
|
|
|
|
|
| 188 |
"\n",
|
| 189 |
+
"This section exercises the multi-agent training loop without needing the GRPO stack.\n"
|
|
|
|
|
|
|
|
|
|
| 190 |
]
|
| 191 |
},
|
| 192 |
{
|
| 193 |
"cell_type": "code",
|
| 194 |
"execution_count": null,
|
|
|
|
| 195 |
"metadata": {},
|
| 196 |
"outputs": [],
|
| 197 |
"source": [
|
| 198 |
"from training.train_multi_agent import train\n",
|
| 199 |
"\n",
|
| 200 |
"metrics = train(\n",
|
| 201 |
+
" n_episodes=40,\n",
|
| 202 |
+
" max_steps_ep=150,\n",
|
| 203 |
" gamma=0.99,\n",
|
| 204 |
" alternating_freq=10,\n",
|
| 205 |
" difficulty=\"easy\",\n",
|
| 206 |
" output_dir=\"outputs/multi_agent\",\n",
|
| 207 |
+
" save_every=20,\n",
|
| 208 |
+
")\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
"\n",
|
| 210 |
+
"print(f\"Saved training outputs to: {Path('outputs/multi_agent').resolve()}\")\n",
|
| 211 |
+
"if metrics.get(\"trader_return\"):\n",
|
| 212 |
+
" print(f\"Final trader return: {metrics['trader_return'][-1]:+.4f}\")\n",
|
| 213 |
+
"if metrics.get(\"grade\"):\n",
|
| 214 |
+
" print(f\"Final grade: {metrics['grade'][-1]:.4f}\")\n"
|
| 215 |
]
|
| 216 |
},
|
| 217 |
{
|
| 218 |
"cell_type": "code",
|
| 219 |
"execution_count": null,
|
|
|
|
| 220 |
"metadata": {},
|
| 221 |
"outputs": [],
|
| 222 |
"source": [
|
| 223 |
+
"from training.train_multi_agent import (\n",
|
| 224 |
+
" RulePortfolioManagerPolicy,\n",
|
| 225 |
+
" RuleRiskManagerPolicy,\n",
|
| 226 |
+
" RuleTraderPolicy,\n",
|
| 227 |
+
" collect_rollout,\n",
|
| 228 |
+
")\n",
|
| 229 |
"\n",
|
| 230 |
"policies = {\n",
|
| 231 |
+
" RISK_MANAGER: RuleRiskManagerPolicy(),\n",
|
| 232 |
" PORTFOLIO_MGR: RulePortfolioManagerPolicy(),\n",
|
| 233 |
+
" TRADER: RuleTraderPolicy(),\n",
|
| 234 |
"}\n",
|
| 235 |
"\n",
|
| 236 |
+
"print(f\"{'Difficulty':<12} {'Episodes':<10} {'Mean Trader Return':<22} {'Mean PnL':<12} {'Mean DD':<12}\")\n",
|
| 237 |
+
"print(\"-\" * 74)\n",
|
| 238 |
"\n",
|
| 239 |
"for diff in [\"easy\", \"medium\", \"hard\"]:\n",
|
| 240 |
" returns, pnls, dds = [], [], []\n",
|
| 241 |
" test_env = MultiAgentTradingEnv(difficulty=diff, max_steps=100)\n",
|
| 242 |
" for _ in range(10):\n",
|
| 243 |
" buffers, info = collect_rollout(test_env, policies, max_steps=100)\n",
|
| 244 |
+
" returns.append(float(np.mean(buffers[TRADER].discounted_returns())))\n",
|
| 245 |
+
" pnls.append(float(info.get(\"pnl_pct\", 0.0)))\n",
|
| 246 |
+
" dds.append(float(info.get(\"max_drawdown\", 0.0)))\n",
|
| 247 |
+
" print(f\"{diff:<12} {10:<10} {np.mean(returns):+.6f} {np.mean(pnls):+.4%} {np.mean(dds):.4%}\")\n"
|
|
|
|
| 248 |
]
|
| 249 |
},
|
| 250 |
{
|
| 251 |
"cell_type": "markdown",
|
|
|
|
| 252 |
"metadata": {},
|
| 253 |
"source": [
|
| 254 |
+
"## 5. Generate Training Plots\n",
|
|
|
|
| 255 |
"\n",
|
| 256 |
+
"Use the saved metrics when available, or fall back to the in-memory `metrics` object from the training cell.\n"
|
| 257 |
]
|
| 258 |
},
|
| 259 |
{
|
| 260 |
"cell_type": "code",
|
| 261 |
"execution_count": null,
|
|
|
|
| 262 |
"metadata": {},
|
| 263 |
"outputs": [],
|
| 264 |
"source": [
|
| 265 |
+
"import json\n",
|
| 266 |
+
"\n",
|
| 267 |
+
"import matplotlib\n",
|
| 268 |
"matplotlib.use(\"Agg\")\n",
|
| 269 |
"import matplotlib.pyplot as plt\n",
|
| 270 |
"\n",
|
| 271 |
"metrics_path = Path(\"outputs/multi_agent/metrics_final.json\")\n",
|
| 272 |
+
"metrics_path.parent.mkdir(parents=True, exist_ok=True)\n",
|
| 273 |
+
"\n",
|
| 274 |
"if not metrics_path.exists():\n",
|
| 275 |
+
" if \"metrics\" not in globals():\n",
|
| 276 |
+
" raise RuntimeError(\"Run the training cell first so metrics are available.\")\n",
|
| 277 |
+
" with open(metrics_path, \"w\", encoding=\"utf-8\") as handle:\n",
|
| 278 |
+
" json.dump(dict(metrics), handle, indent=2)\n",
|
| 279 |
"\n",
|
| 280 |
+
"with open(metrics_path, encoding=\"utf-8\") as handle:\n",
|
| 281 |
+
" m = json.load(handle)\n",
|
| 282 |
"\n",
|
| 283 |
+
"plots_dir = Path(\"plots\")\n",
|
| 284 |
+
"plots_dir.mkdir(parents=True, exist_ok=True)\n",
|
| 285 |
"episodes = m[\"episode\"]\n",
|
| 286 |
"n_eps = len(episodes)\n",
|
| 287 |
"print(f\"Loaded {n_eps} episodes from {metrics_path}\")\n",
|
| 288 |
"\n",
|
| 289 |
+
"def smooth(values, window=10):\n",
|
| 290 |
+
" if len(values) < window:\n",
|
| 291 |
+
" return values\n",
|
| 292 |
+
" kernel = np.ones(window) / window\n",
|
| 293 |
+
" return np.convolve(values, kernel, mode=\"valid\").tolist()\n",
|
| 294 |
"\n",
|
| 295 |
+
"window = max(1, n_eps // 15)\n",
|
| 296 |
+
"\n",
|
| 297 |
+
"fig, ax = plt.subplots(figsize=(12, 6))\n",
|
| 298 |
+
"trader_s = smooth(m[\"trader_return\"], window)\n",
|
| 299 |
+
"rm_s = smooth(m[\"rm_return\"], window)\n",
|
| 300 |
+
"pm_s = smooth(m[\"pm_return\"], window)\n",
|
| 301 |
+
"ep_s = episodes[: len(trader_s)]\n",
|
| 302 |
+
"ax.plot(ep_s, trader_s, label=\"Trader\", color=\"#2ecc71\", linewidth=2)\n",
|
| 303 |
+
"ax.plot(ep_s, rm_s, label=\"Risk Manager\", color=\"#e74c3c\", linewidth=2)\n",
|
| 304 |
+
"ax.plot(ep_s, pm_s, label=\"Portfolio Manager\", color=\"#3498db\", linewidth=2)\n",
|
| 305 |
+
"ax.set_xlabel(\"Episode\")\n",
|
| 306 |
+
"ax.set_ylabel(\"Discounted Return\")\n",
|
| 307 |
+
"ax.set_title(\"QuantHive: Per-Agent Reward Curves\")\n",
|
| 308 |
+
"ax.legend()\n",
|
| 309 |
+
"ax.grid(True, alpha=0.3)\n",
|
|
|
|
| 310 |
"plt.tight_layout()\n",
|
| 311 |
+
"fig.savefig(plots_dir / \"reward_curve.png\", dpi=150)\n",
|
| 312 |
"plt.show()\n",
|
|
|
|
| 313 |
"\n",
|
|
|
|
| 314 |
"fig2, ax2 = plt.subplots(figsize=(12, 6))\n",
|
| 315 |
+
"pnl_s = smooth(m[\"pnl_pct\"], window)\n",
|
|
|
|
|
|
|
| 316 |
"pnl_arr = np.array(pnl_s)\n",
|
| 317 |
+
"ax2.plot(episodes[: len(pnl_s)], pnl_s, color=\"#e74c3c\", linewidth=2)\n",
|
| 318 |
+
"ax2.axhline(y=0, color=\"gray\", linestyle=\"--\", alpha=0.5)\n",
|
| 319 |
+
"ax2.fill_between(episodes[: len(pnl_s)], 0, pnl_s, where=pnl_arr > 0, color=\"#2ecc71\", alpha=0.2)\n",
|
| 320 |
+
"ax2.fill_between(episodes[: len(pnl_s)], 0, pnl_s, where=pnl_arr <= 0, color=\"#e74c3c\", alpha=0.2)\n",
|
| 321 |
+
"ax2.set_xlabel(\"Episode\")\n",
|
| 322 |
+
"ax2.set_ylabel(\"PnL %\")\n",
|
| 323 |
+
"ax2.set_title(\"QuantHive: PnL Over Training\")\n",
|
| 324 |
"ax2.grid(True, alpha=0.3)\n",
|
| 325 |
"plt.tight_layout()\n",
|
| 326 |
+
"fig2.savefig(plots_dir / \"loss_curve.png\", dpi=150)\n",
|
| 327 |
"plt.show()\n",
|
|
|
|
| 328 |
"\n",
|
|
|
|
| 329 |
"if n_eps >= 20:\n",
|
| 330 |
" fig3, ax3 = plt.subplots(figsize=(10, 6))\n",
|
| 331 |
" names = [\"Trader Return\", \"Grade\", \"Max Drawdown\", \"Sharpe\"]\n",
|
| 332 |
+
" early = [np.mean(m[key][:20]) for key in [\"trader_return\", \"grade\", \"max_drawdown\", \"sharpe\"]]\n",
|
| 333 |
+
" late = [np.mean(m[key][-20:]) for key in [\"trader_return\", \"grade\", \"max_drawdown\", \"sharpe\"]]\n",
|
| 334 |
" x = np.arange(len(names))\n",
|
| 335 |
" ax3.bar(x - 0.175, early, 0.35, label=\"First 20 eps\", color=\"#e74c3c\", alpha=0.8)\n",
|
| 336 |
+
" ax3.bar(x + 0.175, late, 0.35, label=\"Last 20 eps\", color=\"#2ecc71\", alpha=0.8)\n",
|
| 337 |
+
" ax3.set_ylabel(\"Value\")\n",
|
| 338 |
+
" ax3.set_title(\"QuantHive: Early vs Late Training\")\n",
|
| 339 |
+
" ax3.set_xticks(x)\n",
|
| 340 |
+
" ax3.set_xticklabels(names)\n",
|
| 341 |
+
" ax3.legend()\n",
|
| 342 |
+
" ax3.grid(True, alpha=0.3, axis=\"y\")\n",
|
| 343 |
" plt.tight_layout()\n",
|
| 344 |
+
" fig3.savefig(plots_dir / \"baseline_comparison.png\", dpi=150)\n",
|
| 345 |
" plt.show()\n",
|
| 346 |
+
"\n",
|
| 347 |
+
"print(f\"Saved plots to: {plots_dir.resolve()}\")\n"
|
| 348 |
]
|
| 349 |
},
|
| 350 |
{
|
| 351 |
"cell_type": "markdown",
|
|
|
|
| 352 |
"metadata": {},
|
| 353 |
"source": [
|
| 354 |
+
"## 6. Preview Governance-Aware GRPO Prompts\n",
|
|
|
|
| 355 |
"\n",
|
| 356 |
+
"Import only the lightweight prompt helpers so prompt preview does not depend on the trainer stack.\n"
|
|
|
|
| 357 |
]
|
| 358 |
},
|
| 359 |
{
|
| 360 |
"cell_type": "code",
|
| 361 |
"execution_count": null,
|
|
|
|
| 362 |
"metadata": {},
|
| 363 |
"outputs": [],
|
| 364 |
"source": [
|
| 365 |
+
"from training.prompt_utils import build_prompt_multiagent, generate_pz_scenarios\n",
|
| 366 |
"\n",
|
| 367 |
"scenarios = generate_pz_scenarios(n=3, difficulty=\"easy\", max_env_steps=30)\n",
|
| 368 |
+
"print(f\"Generated {len(scenarios)} scenarios.\\n\")\n",
|
| 369 |
+
"\n",
|
| 370 |
+
"for i, scenario in enumerate(scenarios, start=1):\n",
|
| 371 |
+
" print(\"=\" * 60)\n",
|
| 372 |
+
" print(f\"Scenario {i}\")\n",
|
| 373 |
+
" print(\n",
|
| 374 |
+
" f\"RM: size_limit={scenario['rm_size_limit']:.2f}, \"\n",
|
| 375 |
+
" f\"allow_new={scenario['rm_allow_new']}, force_reduce={scenario['rm_force_reduce']}\"\n",
|
| 376 |
+
" )\n",
|
| 377 |
+
" print(\n",
|
| 378 |
+
" f\"PM: cap_alloc={scenario['pm_cap_alloc']:.2f}, \"\n",
|
| 379 |
+
" f\"override={scenario['pm_override']:.2f}\"\n",
|
| 380 |
+
" )\n",
|
| 381 |
+
"\n",
|
| 382 |
+
"print(\"\\nFull prompt for scenario 1\")\n",
|
| 383 |
+
"print(\"=\" * 60)\n",
|
| 384 |
+
"print(build_prompt_multiagent(scenarios[0]))\n"
|
| 385 |
]
|
| 386 |
},
|
| 387 |
{
|
| 388 |
"cell_type": "code",
|
| 389 |
"execution_count": null,
|
|
|
|
| 390 |
"metadata": {},
|
| 391 |
"outputs": [],
|
| 392 |
"source": [
|
| 393 |
+
"from env.reward import alignment_reward_func, format_reward_func, profit_reward_func\n",
|
| 394 |
+
"from training.grpo_verifiers_multiagent import (\n",
|
|
|
|
| 395 |
" governance_reward_func_multiagent,\n",
|
| 396 |
+
" risk_reward_func_multiagent,\n",
|
| 397 |
")\n",
|
|
|
|
| 398 |
"\n",
|
| 399 |
"test_prompt = build_prompt_multiagent(scenarios[0])\n",
|
| 400 |
+
"effective_limit = min(scenarios[0][\"rm_size_limit\"], scenarios[0][\"pm_cap_alloc\"])\n",
|
| 401 |
"\n",
|
|
|
|
| 402 |
"compliant = (\n",
|
| 403 |
+
" \"<thought>\\n\"\n",
|
| 404 |
+
" \"Risk limits are active, so I will stay inside both the RM size cap and the PM allocation.\\n\"\n",
|
| 405 |
+
" \"</thought>\\n\"\n",
|
| 406 |
+
" \"<action>\\n\"\n",
|
| 407 |
+
" '{{\"direction\": 1, \"size\": %.2f, \"sl\": 49000, \"tp\": 52000}}\\n'\n",
|
| 408 |
+
" \"</action>\"\n",
|
| 409 |
+
") % (effective_limit * 0.7)\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 410 |
"\n",
|
|
|
|
| 411 |
"reckless = (\n",
|
| 412 |
+
" \"<thought>\\nMarket is moving up. Going all in.\\n</thought>\\n\"\n",
|
| 413 |
+
" \"<action>\\n{\\\"direction\\\": 1, \\\"size\\\": 0.95, \\\"sl\\\": 0, \\\"tp\\\": 0}\\n</action>\"\n",
|
| 414 |
")\n",
|
| 415 |
"\n",
|
| 416 |
"prompts = [test_prompt, test_prompt]\n",
|
|
|
|
| 419 |
"print(f\"{'Verifier':<25} {'Compliant':<12} {'Reckless':<12}\")\n",
|
| 420 |
"print(\"-\" * 49)\n",
|
| 421 |
"for name, func in [\n",
|
| 422 |
+
" (\"Format\", format_reward_func),\n",
|
| 423 |
+
" (\"Alignment\", alignment_reward_func),\n",
|
| 424 |
+
" (\"Risk\", risk_reward_func_multiagent),\n",
|
| 425 |
+
" (\"Profit\", profit_reward_func),\n",
|
| 426 |
+
" (\"Governance\", governance_reward_func_multiagent),\n",
|
| 427 |
"]:\n",
|
| 428 |
" scores = func(prompts, completions)\n",
|
| 429 |
+
" print(f\"{name:<25} {scores[0]:<12.2f} {scores[1]:<12.2f}\")\n"
|
| 430 |
]
|
| 431 |
},
|
| 432 |
{
|
| 433 |
"cell_type": "markdown",
|
|
|
|
| 434 |
"metadata": {},
|
| 435 |
"source": [
|
| 436 |
+
"## 7. Optional GPU GRPO Training\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 437 |
"\n",
|
| 438 |
+
"Leave `RUN_GRPO` as `False` unless the runtime has a GPU and you want to install the heavier TRL and Unsloth stack.\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 439 |
]
|
| 440 |
},
|
| 441 |
{
|
| 442 |
"cell_type": "code",
|
| 443 |
"execution_count": null,
|
|
|
|
| 444 |
"metadata": {},
|
| 445 |
"outputs": [],
|
| 446 |
"source": [
|
| 447 |
+
"RUN_GRPO = False\n",
|
| 448 |
+
"\n",
|
| 449 |
+
"if not RUN_GRPO:\n",
|
| 450 |
+
" print(\"Set RUN_GRPO = True after enabling a GPU runtime.\")\n",
|
| 451 |
+
"else:\n",
|
| 452 |
+
" import json\n",
|
| 453 |
+
" import torch\n",
|
| 454 |
+
" from types import SimpleNamespace\n",
|
| 455 |
+
"\n",
|
| 456 |
+
" assert torch.cuda.is_available(), \"GPU required for GRPO training\"\n",
|
| 457 |
+
" print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n",
|
| 458 |
+
"\n",
|
| 459 |
+
" extra_packages = [\n",
|
| 460 |
+
" \"datasets\",\n",
|
| 461 |
+
" \"transformers\",\n",
|
| 462 |
+
" \"trl\",\n",
|
| 463 |
+
" \"peft\",\n",
|
| 464 |
+
" \"accelerate\",\n",
|
| 465 |
+
" \"safetensors\",\n",
|
| 466 |
+
" \"unsloth\",\n",
|
| 467 |
+
" ]\n",
|
| 468 |
+
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", *extra_packages])\n",
|
| 469 |
+
"\n",
|
| 470 |
+
" from datasets import Dataset\n",
|
| 471 |
+
" from training.train_grpo_multiagent import load_model, make_trainer, save_model\n",
|
| 472 |
+
"\n",
|
| 473 |
+
" print(\"Generating 256 scenarios from MultiAgentTradingEnv...\")\n",
|
| 474 |
+
" scenarios = generate_pz_scenarios(n=256, difficulty=\"easy\", max_env_steps=50)\n",
|
| 475 |
+
" prompts = [{\"prompt\": build_prompt_multiagent(sc)} for sc in scenarios]\n",
|
| 476 |
+
" dataset = Dataset.from_list(prompts)\n",
|
| 477 |
+
" print(f\"Dataset: {len(dataset)} prompts\")\n",
|
| 478 |
+
"\n",
|
| 479 |
+
" model, tokenizer = load_model(\n",
|
| 480 |
+
" \"unsloth/Qwen2.5-1.5B-Instruct-bnb-4bit\",\n",
|
| 481 |
+
" max_seq_length=1024,\n",
|
| 482 |
+
" )\n",
|
| 483 |
+
"\n",
|
| 484 |
+
" args = SimpleNamespace(\n",
|
| 485 |
+
" output_dir=\"models/local_policy_grpo_multiagent\",\n",
|
| 486 |
+
" learning_rate=5e-5,\n",
|
| 487 |
+
" per_device_batch_size=2,\n",
|
| 488 |
+
" gradient_accumulation_steps=2,\n",
|
| 489 |
+
" max_steps=100,\n",
|
| 490 |
+
" save_steps=25,\n",
|
| 491 |
+
" logging_steps=1,\n",
|
| 492 |
+
" max_prompt_length=768,\n",
|
| 493 |
+
" max_completion_length=200,\n",
|
| 494 |
+
" num_generations=4,\n",
|
| 495 |
+
" )\n",
|
| 496 |
+
"\n",
|
| 497 |
+
" trainer = make_trainer(model, tokenizer, dataset, args, torch)\n",
|
| 498 |
+
" print(f\"Starting GRPO training ({args.max_steps} steps)...\")\n",
|
| 499 |
+
" trainer.train()\n",
|
| 500 |
+
"\n",
|
| 501 |
+
" history = trainer.state.log_history\n",
|
| 502 |
+
" rewards = [item[\"reward\"] for item in history if \"reward\" in item]\n",
|
| 503 |
+
" losses = [item[\"loss\"] for item in history if \"loss\" in item]\n",
|
| 504 |
+
" if not losses:\n",
|
| 505 |
+
" losses = [0.0]\n",
|
| 506 |
+
"\n",
|
| 507 |
+
" from utils.plotting import plot_training_results\n",
|
| 508 |
+
"\n",
|
| 509 |
+
" plot_training_results(rewards, losses, output_dir=\"plots\")\n",
|
| 510 |
+
" save_model(model, tokenizer, args.output_dir)\n",
|
| 511 |
+
"\n",
|
| 512 |
+
" Path(\"outputs\").mkdir(parents=True, exist_ok=True)\n",
|
| 513 |
+
" with open(\"outputs/grpo_metrics.json\", \"w\", encoding=\"utf-8\") as handle:\n",
|
| 514 |
+
" json.dump({\"rewards\": rewards, \"losses\": losses}, handle, indent=2)\n",
|
| 515 |
+
"\n",
|
| 516 |
+
" print(f\"Model saved to {args.output_dir}\")\n"
|
| 517 |
]
|
| 518 |
},
|
| 519 |
{
|
| 520 |
"cell_type": "markdown",
|
|
|
|
| 521 |
"metadata": {},
|
| 522 |
"source": [
|
| 523 |
+
"## 8. Display Generated Artifacts\n"
|
|
|
|
|
|
|
|
|
|
| 524 |
]
|
| 525 |
},
|
| 526 |
{
|
| 527 |
"cell_type": "code",
|
| 528 |
"execution_count": null,
|
|
|
|
| 529 |
"metadata": {},
|
| 530 |
"outputs": [],
|
| 531 |
"source": [
|
| 532 |
+
"from IPython.display import Image, Markdown, display\n",
|
| 533 |
"\n",
|
| 534 |
"plot_files = [\n",
|
| 535 |
+
" (\"plots/reward_curve.png\", \"Per-Agent Reward Curves\"),\n",
|
| 536 |
+
" (\"plots/loss_curve.png\", \"PnL Or Policy Loss Curve\"),\n",
|
| 537 |
+
" (\"plots/baseline_comparison.png\", \"Early vs Late Training\"),\n",
|
|
|
|
| 538 |
"]\n",
|
| 539 |
"\n",
|
| 540 |
"for path, title in plot_files:\n",
|
| 541 |
" if Path(path).exists():\n",
|
| 542 |
" display(Markdown(f\"### {title}\"))\n",
|
| 543 |
+
" display(Image(filename=path, width=700))\n",
|
| 544 |
" else:\n",
|
| 545 |
+
" print(f\"Missing: {path}\")\n"
|
| 546 |
]
|
| 547 |
},
|
| 548 |
{
|
| 549 |
"cell_type": "markdown",
|
|
|
|
| 550 |
"metadata": {},
|
| 551 |
"source": [
|
| 552 |
+
"## 9. Submission Checklist\n"
|
|
|
|
| 553 |
]
|
| 554 |
},
|
| 555 |
{
|
| 556 |
"cell_type": "code",
|
| 557 |
"execution_count": null,
|
|
|
|
| 558 |
"metadata": {},
|
| 559 |
"outputs": [],
|
| 560 |
"source": [
|
| 561 |
+
"import yaml\n",
|
| 562 |
+
"\n",
|
| 563 |
"checks = [\n",
|
| 564 |
+
" (\"openenv.yaml\", Path(\"openenv.yaml\")),\n",
|
| 565 |
+
" (\"README.md\", Path(\"README.md\")),\n",
|
| 566 |
+
" (\"WRITEUP.md\", Path(\"WRITEUP.md\")),\n",
|
| 567 |
+
" (\"Multi-agent env\", Path(\"env/multi_agent_env.py\")),\n",
|
| 568 |
+
" (\"REINFORCE trainer\", Path(\"training/train_multi_agent.py\")),\n",
|
| 569 |
+
" (\"GRPO trainer\", Path(\"training/train_grpo_multiagent.py\")),\n",
|
| 570 |
+
" (\"Prompt helpers\", Path(\"training/prompt_utils.py\")),\n",
|
| 571 |
+
" (\"GRPO verifiers\", Path(\"training/grpo_verifiers_multiagent.py\")),\n",
|
| 572 |
+
" (\"Training notebook\", Path(\"mate_training.ipynb\")),\n",
|
| 573 |
+
" (\"Reward curve\", Path(\"plots/reward_curve.png\")),\n",
|
| 574 |
+
" (\"Loss curve\", Path(\"plots/loss_curve.png\")),\n",
|
| 575 |
+
" (\"Baseline comparison\", Path(\"plots/baseline_comparison.png\")),\n",
|
| 576 |
+
" (\"Dockerfile\", Path(\"Dockerfile\")),\n",
|
| 577 |
+
" (\"requirements-space.txt\", Path(\"requirements-space.txt\")),\n",
|
| 578 |
"]\n",
|
| 579 |
"\n",
|
| 580 |
+
"print(f\"{'Deliverable':<28} {'Status':<10}\")\n",
|
| 581 |
+
"print(\"-\" * 42)\n",
|
| 582 |
"for name, path in checks:\n",
|
| 583 |
+
" status = \"OK\" if path.exists() else \"MISSING\"\n",
|
| 584 |
+
" print(f\"{name:<28} {status:<10}\")\n",
|
| 585 |
+
"\n",
|
| 586 |
+
"with open(\"openenv.yaml\", encoding=\"utf-8\") as handle:\n",
|
| 587 |
+
" manifest = yaml.safe_load(handle)\n",
|
| 588 |
+
"entry_point = manifest.get(\"environment\", {}).get(\"entry_point\", \"\")\n",
|
| 589 |
+
"print(f\"\\nopenenv.yaml entry_point: {entry_point}\")\n",
|
| 590 |
+
"print(f\"PettingZoo env configured: {'multi_agent_env' in entry_point}\")\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 591 |
]
|
| 592 |
}
|
| 593 |
],
|
| 594 |
"metadata": {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 595 |
"kernelspec": {
|
| 596 |
"display_name": "Python 3",
|
| 597 |
+
"language": "python",
|
| 598 |
"name": "python3"
|
| 599 |
},
|
| 600 |
"language_info": {
|
| 601 |
+
"name": "python"
|
|
|
|
| 602 |
}
|
| 603 |
},
|
| 604 |
"nbformat": 4,
|
| 605 |
"nbformat_minor": 5
|
| 606 |
+
}
|
training/grpo_verifiers_multiagent.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Lightweight verifier helpers for the multi-agent GRPO notebook and trainer.
|
| 3 |
+
|
| 4 |
+
These functions intentionally avoid importing the training stack so notebooks can
|
| 5 |
+
preview prompts and reward functions without loading model or trainer deps.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import json
|
| 11 |
+
import re
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def _extract_json_action(completion: str):
|
| 17 |
+
match = re.search(r"<action>\s*({.*?})\s*</action>", completion, re.DOTALL)
|
| 18 |
+
if not match:
|
| 19 |
+
return None
|
| 20 |
+
return json.loads(match.group(1))
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _extract_signal_value(prompt: str, key: str):
|
| 24 |
+
json_match = re.search(rf'"{key}"\s*:\s*(-?[\d\.]+)', prompt)
|
| 25 |
+
if json_match:
|
| 26 |
+
return float(json_match.group(1))
|
| 27 |
+
|
| 28 |
+
plain_match = re.search(rf"{key}\s*[:=]\s*(-?[\d\.]+)", prompt)
|
| 29 |
+
if plain_match:
|
| 30 |
+
return float(plain_match.group(1))
|
| 31 |
+
|
| 32 |
+
return None
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def risk_reward_func_multiagent(prompts, completions, **kwargs) -> list[float]:
|
| 36 |
+
"""Read the Risk Manager limit from the prompt and reward compliant sizing."""
|
| 37 |
+
|
| 38 |
+
rewards = []
|
| 39 |
+
for prompt, completion in zip(prompts, completions):
|
| 40 |
+
try:
|
| 41 |
+
limit = _extract_signal_value(prompt, "rm_size_limit")
|
| 42 |
+
if limit is None:
|
| 43 |
+
limit = _extract_signal_value(prompt, "position_limit")
|
| 44 |
+
if limit is None:
|
| 45 |
+
limit = 1.0
|
| 46 |
+
|
| 47 |
+
data = _extract_json_action(completion)
|
| 48 |
+
if data is None:
|
| 49 |
+
rewards.append(0.0)
|
| 50 |
+
continue
|
| 51 |
+
|
| 52 |
+
size = float(data.get("size", 0.0))
|
| 53 |
+
score = 0.7 if size <= limit else 0.0
|
| 54 |
+
|
| 55 |
+
try:
|
| 56 |
+
thought = completion.split("<thought>")[1].split("</thought>")[0].lower()
|
| 57 |
+
if any(kw in thought for kw in ["risk", "limit", "constraint", "size_limit"]):
|
| 58 |
+
score += 0.3
|
| 59 |
+
except (IndexError, AttributeError):
|
| 60 |
+
pass
|
| 61 |
+
|
| 62 |
+
rewards.append(score)
|
| 63 |
+
except Exception:
|
| 64 |
+
rewards.append(0.0)
|
| 65 |
+
|
| 66 |
+
return rewards
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def governance_reward_func_multiagent(prompts, completions, **kwargs) -> list[float]:
|
| 70 |
+
"""Score compliance against both Risk Manager and Portfolio Manager limits."""
|
| 71 |
+
|
| 72 |
+
rewards = []
|
| 73 |
+
for prompt, completion in zip(prompts, completions):
|
| 74 |
+
try:
|
| 75 |
+
data = _extract_json_action(completion)
|
| 76 |
+
if data is None:
|
| 77 |
+
rewards.append(0.0)
|
| 78 |
+
continue
|
| 79 |
+
|
| 80 |
+
size = float(data.get("size", 0.0))
|
| 81 |
+
direction = int(data.get("direction", 0))
|
| 82 |
+
|
| 83 |
+
limit = _extract_signal_value(prompt, "rm_size_limit")
|
| 84 |
+
if limit is None:
|
| 85 |
+
limit = _extract_signal_value(prompt, "position_limit")
|
| 86 |
+
if limit is None:
|
| 87 |
+
limit = 1.0
|
| 88 |
+
|
| 89 |
+
pm_cap = _extract_signal_value(prompt, "pm_cap_alloc")
|
| 90 |
+
effective_limit = min(limit, pm_cap) if pm_cap is not None else limit
|
| 91 |
+
|
| 92 |
+
score = 0.0
|
| 93 |
+
if size <= effective_limit:
|
| 94 |
+
score += 0.40
|
| 95 |
+
if 0 < size <= effective_limit * 0.8:
|
| 96 |
+
score += 0.20
|
| 97 |
+
else:
|
| 98 |
+
score -= 0.50
|
| 99 |
+
|
| 100 |
+
try:
|
| 101 |
+
thought = completion.split("<thought>")[1].split("</thought>")[0].lower()
|
| 102 |
+
governance_keywords = [
|
| 103 |
+
"risk",
|
| 104 |
+
"limit",
|
| 105 |
+
"constraint",
|
| 106 |
+
"compliance",
|
| 107 |
+
"conservative",
|
| 108 |
+
"governance",
|
| 109 |
+
"restrict",
|
| 110 |
+
"drawdown",
|
| 111 |
+
"cap",
|
| 112 |
+
"position limit",
|
| 113 |
+
"size_limit",
|
| 114 |
+
"risk manager",
|
| 115 |
+
"portfolio manager",
|
| 116 |
+
"allocation",
|
| 117 |
+
]
|
| 118 |
+
if any(kw in thought for kw in governance_keywords):
|
| 119 |
+
score += 0.20
|
| 120 |
+
except (IndexError, AttributeError):
|
| 121 |
+
pass
|
| 122 |
+
|
| 123 |
+
if direction != 0:
|
| 124 |
+
score += 0.20
|
| 125 |
+
|
| 126 |
+
rewards.append(float(np.clip(score, 0.0, 1.0)))
|
| 127 |
+
except Exception:
|
| 128 |
+
rewards.append(0.0)
|
| 129 |
+
|
| 130 |
+
return rewards
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
__all__ = [
|
| 134 |
+
"governance_reward_func_multiagent",
|
| 135 |
+
"risk_reward_func_multiagent",
|
| 136 |
+
]
|
training/train_grpo_multiagent.py
CHANGED
|
@@ -1,178 +1,54 @@
|
|
| 1 |
"""
|
| 2 |
-
PettingZoo-
|
| 3 |
|
| 4 |
-
Uses MultiAgentTradingEnv
|
| 5 |
-
send governance messages that become part of the Trader
|
| 6 |
-
The Trader is trained
|
| 7 |
-
|
| 8 |
-
RM and PM use rule-based policies during Trader training (alternating opt.).
|
| 9 |
"""
|
| 10 |
|
| 11 |
from __future__ import annotations
|
| 12 |
|
| 13 |
-
import os
|
| 14 |
-
os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
|
| 15 |
-
os.environ.setdefault("OMP_NUM_THREADS", "1")
|
| 16 |
-
|
| 17 |
import argparse
|
| 18 |
import inspect
|
| 19 |
import json
|
|
|
|
| 20 |
import random
|
| 21 |
import sys
|
| 22 |
from pathlib import Path
|
| 23 |
-
from typing import Dict, List
|
| 24 |
|
| 25 |
import numpy as np
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
ROOT = Path(__file__).resolve().parents[1]
|
| 28 |
if str(ROOT) not in sys.path:
|
| 29 |
sys.path.insert(0, str(ROOT))
|
| 30 |
|
| 31 |
-
from datasets import Dataset
|
| 32 |
-
|
| 33 |
-
from env.multi_agent_env import (
|
| 34 |
-
MultiAgentTradingEnv,
|
| 35 |
-
RISK_MANAGER,
|
| 36 |
-
PORTFOLIO_MGR,
|
| 37 |
-
TRADER,
|
| 38 |
-
)
|
| 39 |
from env.reward import (
|
| 40 |
-
format_reward_func,
|
| 41 |
alignment_reward_func,
|
|
|
|
| 42 |
profit_reward_func,
|
| 43 |
)
|
| 44 |
-
from training.
|
| 45 |
-
|
| 46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
)
|
| 48 |
|
| 49 |
|
| 50 |
-
# ─── Constants ─────────────────────────────────────────────────────────────────
|
| 51 |
-
|
| 52 |
DEFAULT_MODEL_NAME = "unsloth/Qwen2.5-1.5B-Instruct-bnb-4bit"
|
| 53 |
DEFAULT_OUTPUT_DIR = "models/local_policy_grpo_multiagent"
|
| 54 |
|
| 55 |
-
from training.prompt_utils import SYSTEM_PROMPT, generate_pz_scenarios, build_prompt_multiagent
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
# ─── Updated GRPO Verifiers ───────────────────────────────────────────────────
|
| 60 |
-
|
| 61 |
-
def _extract_json_action(completion: str):
|
| 62 |
-
import re
|
| 63 |
-
match = re.search(r"<action>\s*({.*?})\s*</action>", completion, re.DOTALL)
|
| 64 |
-
if not match:
|
| 65 |
-
return None
|
| 66 |
-
return json.loads(match.group(1))
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
def _extract_signal_value(prompt: str, key: str):
|
| 70 |
-
import re
|
| 71 |
-
json_match = re.search(rf'"{key}"\s*:\s*(-?[\d\.]+)', prompt)
|
| 72 |
-
if json_match:
|
| 73 |
-
return float(json_match.group(1))
|
| 74 |
-
plain_match = re.search(rf"{key}\s*[:=]\s*(-?[\d\.]+)", prompt)
|
| 75 |
-
if plain_match:
|
| 76 |
-
return float(plain_match.group(1))
|
| 77 |
-
return None
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
def risk_reward_func_multiagent(prompts, completions, **kwargs) -> list[float]:
|
| 81 |
-
"""Updated risk verifier: reads RM's dynamic size_limit from the prompt."""
|
| 82 |
-
rewards = []
|
| 83 |
-
for prompt, completion in zip(prompts, completions):
|
| 84 |
-
try:
|
| 85 |
-
# Read RM's size_limit from the governance block
|
| 86 |
-
limit = _extract_signal_value(prompt, "rm_size_limit")
|
| 87 |
-
if limit is None:
|
| 88 |
-
limit = _extract_signal_value(prompt, "position_limit")
|
| 89 |
-
if limit is None:
|
| 90 |
-
limit = 1.0
|
| 91 |
-
|
| 92 |
-
data = _extract_json_action(completion)
|
| 93 |
-
if data is not None:
|
| 94 |
-
size = float(data.get("size", 0.0))
|
| 95 |
-
score = 0.7 if size <= limit else 0.0
|
| 96 |
-
|
| 97 |
-
try:
|
| 98 |
-
thought = completion.split("<thought>")[1].split("</thought>")[0].lower()
|
| 99 |
-
if any(kw in thought for kw in ["risk", "limit", "constraint", "size_limit"]):
|
| 100 |
-
score += 0.3
|
| 101 |
-
except (IndexError, AttributeError):
|
| 102 |
-
pass
|
| 103 |
-
rewards.append(score)
|
| 104 |
-
else:
|
| 105 |
-
rewards.append(0.0)
|
| 106 |
-
except Exception:
|
| 107 |
-
rewards.append(0.0)
|
| 108 |
-
return rewards
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
def governance_reward_func_multiagent(prompts, completions, **kwargs) -> list[float]:
|
| 112 |
-
"""Updated governance verifier: checks compliance against *learned* RM constraints.
|
| 113 |
-
|
| 114 |
-
The key differentiator: the governance verifier now checks compliance against
|
| 115 |
-
RM's size_limit from the prompt, not a hardcoded position_limit.
|
| 116 |
-
"""
|
| 117 |
-
rewards = []
|
| 118 |
-
for prompt, completion in zip(prompts, completions):
|
| 119 |
-
try:
|
| 120 |
-
data = _extract_json_action(completion)
|
| 121 |
-
if data is None:
|
| 122 |
-
rewards.append(0.0)
|
| 123 |
-
continue
|
| 124 |
-
|
| 125 |
-
size = float(data.get("size", 0.0))
|
| 126 |
-
direction = int(data.get("direction", 0))
|
| 127 |
-
|
| 128 |
-
# Use RM's dynamic limit
|
| 129 |
-
limit = _extract_signal_value(prompt, "rm_size_limit")
|
| 130 |
-
if limit is None:
|
| 131 |
-
limit = _extract_signal_value(prompt, "position_limit")
|
| 132 |
-
if limit is None:
|
| 133 |
-
limit = 1.0
|
| 134 |
-
|
| 135 |
-
# Also check PM cap
|
| 136 |
-
pm_cap = _extract_signal_value(prompt, "pm_cap_alloc")
|
| 137 |
-
effective_limit = min(limit, pm_cap) if pm_cap is not None else limit
|
| 138 |
-
|
| 139 |
-
score = 0.0
|
| 140 |
-
|
| 141 |
-
# Core compliance: within both RM limit and PM cap
|
| 142 |
-
if size <= effective_limit:
|
| 143 |
-
score += 0.40
|
| 144 |
-
if 0 < size <= effective_limit * 0.8:
|
| 145 |
-
score += 0.20
|
| 146 |
-
else:
|
| 147 |
-
score -= 0.50
|
| 148 |
-
|
| 149 |
-
# Reasoning quality: governance-aware language
|
| 150 |
-
try:
|
| 151 |
-
thought = completion.split("<thought>")[1].split("</thought>")[0].lower()
|
| 152 |
-
governance_keywords = [
|
| 153 |
-
"risk", "limit", "constraint", "compliance", "conservative",
|
| 154 |
-
"governance", "restrict", "drawdown", "cap", "position limit",
|
| 155 |
-
"size_limit", "risk manager", "portfolio manager", "allocation",
|
| 156 |
-
]
|
| 157 |
-
if any(kw in thought for kw in governance_keywords):
|
| 158 |
-
score += 0.20
|
| 159 |
-
except (IndexError, AttributeError):
|
| 160 |
-
pass
|
| 161 |
-
|
| 162 |
-
# Activity bonus
|
| 163 |
-
if direction != 0:
|
| 164 |
-
score += 0.20
|
| 165 |
-
|
| 166 |
-
rewards.append(float(np.clip(score, 0.0, 1.0)))
|
| 167 |
-
except Exception:
|
| 168 |
-
rewards.append(0.0)
|
| 169 |
-
return rewards
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
# ─── Model Loading ─────────────────────────────────────────────────────────────
|
| 173 |
|
| 174 |
def require_cuda():
|
| 175 |
import torch
|
|
|
|
| 176 |
if not torch.cuda.is_available():
|
| 177 |
raise SystemExit("GRPO training requires CUDA.")
|
| 178 |
return torch
|
|
@@ -180,6 +56,7 @@ def require_cuda():
|
|
| 180 |
|
| 181 |
def load_model(model_name: str, max_seq_length: int):
|
| 182 |
from unsloth import FastLanguageModel, PatchFastRL
|
|
|
|
| 183 |
PatchFastRL("GRPO", "unsloth")
|
| 184 |
|
| 185 |
model, tokenizer = FastLanguageModel.from_pretrained(
|
|
@@ -191,8 +68,15 @@ def load_model(model_name: str, max_seq_length: int):
|
|
| 191 |
model = FastLanguageModel.get_peft_model(
|
| 192 |
model,
|
| 193 |
r=16,
|
| 194 |
-
target_modules=[
|
| 195 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
lora_alpha=16,
|
| 197 |
lora_dropout=0,
|
| 198 |
bias="none",
|
|
@@ -205,8 +89,6 @@ def load_model(model_name: str, max_seq_length: int):
|
|
| 205 |
return model, tokenizer
|
| 206 |
|
| 207 |
|
| 208 |
-
# ─── Trainer ───────────────────────────────────────────────────────────────────
|
| 209 |
-
|
| 210 |
def make_trainer(model, tokenizer, dataset, args, torch_module):
|
| 211 |
from trl.trainer.grpo_config import GRPOConfig
|
| 212 |
from trl.trainer.grpo_trainer import GRPOTrainer
|
|
@@ -231,9 +113,9 @@ def make_trainer(model, tokenizer, dataset, args, torch_module):
|
|
| 231 |
reward_funcs = [
|
| 232 |
format_reward_func,
|
| 233 |
alignment_reward_func,
|
| 234 |
-
risk_reward_func_multiagent,
|
| 235 |
profit_reward_func,
|
| 236 |
-
governance_reward_func_multiagent,
|
| 237 |
]
|
| 238 |
|
| 239 |
trainer_kwargs = {
|
|
@@ -261,10 +143,8 @@ def save_model(model, tokenizer, output_dir: str) -> None:
|
|
| 261 |
tokenizer.save_pretrained(output_dir)
|
| 262 |
|
| 263 |
|
| 264 |
-
# ─── CLI ───────────────────────────────────────────────────────────────────────
|
| 265 |
-
|
| 266 |
def parse_args():
|
| 267 |
-
parser = argparse.ArgumentParser(description="Multi-
|
| 268 |
parser.add_argument("--model-name", default=DEFAULT_MODEL_NAME)
|
| 269 |
parser.add_argument("--output-dir", default=DEFAULT_OUTPUT_DIR)
|
| 270 |
parser.add_argument("--difficulty", choices=["easy", "medium", "hard"], default="easy")
|
|
@@ -288,43 +168,40 @@ def main():
|
|
| 288 |
random.seed(args.seed)
|
| 289 |
np.random.seed(args.seed)
|
| 290 |
|
| 291 |
-
|
| 292 |
-
|
|
|
|
|
|
|
| 293 |
scenarios = generate_pz_scenarios(n=args.num_scenarios, difficulty=args.difficulty)
|
| 294 |
print(f" Generated {len(scenarios)} scenarios.")
|
| 295 |
|
| 296 |
-
# 2. Build dataset
|
| 297 |
prompts = [{"prompt": build_prompt_multiagent(sc)} for sc in scenarios]
|
| 298 |
dataset = Dataset.from_list(prompts)
|
| 299 |
|
| 300 |
-
# 3. Load model
|
| 301 |
torch_module = require_cuda()
|
| 302 |
model, tokenizer = load_model(args.model_name, args.max_seq_length)
|
| 303 |
|
| 304 |
-
# 4. Train
|
| 305 |
trainer = make_trainer(model, tokenizer, dataset, args, torch_module)
|
| 306 |
print(f"Starting multi-agent GRPO training on {len(dataset)} prompts...")
|
| 307 |
-
|
| 308 |
|
| 309 |
-
# 5. Generate plots
|
| 310 |
history = trainer.state.log_history
|
| 311 |
rewards = [x["reward"] for x in history if "reward" in x]
|
| 312 |
losses = [x["loss"] for x in history if "loss" in x]
|
| 313 |
|
| 314 |
try:
|
| 315 |
from utils.plotting import plot_training_results
|
|
|
|
| 316 |
plot_training_results(rewards, losses)
|
| 317 |
-
except Exception as
|
| 318 |
-
print(f" Warning: could not generate plots: {
|
| 319 |
|
| 320 |
-
# 6. Save model
|
| 321 |
print(f"Saving GRPO policy to {args.output_dir}...")
|
| 322 |
save_model(model, tokenizer, args.output_dir)
|
| 323 |
|
| 324 |
-
# 7. Save training metrics
|
| 325 |
metrics_path = Path(args.output_dir) / "training_metrics.json"
|
| 326 |
-
with open(metrics_path, "w") as
|
| 327 |
-
json.dump({"rewards": rewards, "losses": losses},
|
| 328 |
|
| 329 |
print("Multi-agent GRPO training complete.")
|
| 330 |
print(f" Model saved to: {args.output_dir}")
|
|
|
|
| 1 |
"""
|
| 2 |
+
PettingZoo-compatible GRPO training pipeline for Qwen 2.5.
|
| 3 |
|
| 4 |
+
Uses MultiAgentTradingEnv-derived scenarios where the Risk Manager and
|
| 5 |
+
Portfolio Manager send governance messages that become part of the Trader
|
| 6 |
+
prompt. The Trader is then trained with Unsloth + TRL GRPOTrainer.
|
|
|
|
|
|
|
| 7 |
"""
|
| 8 |
|
| 9 |
from __future__ import annotations
|
| 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
import argparse
|
| 12 |
import inspect
|
| 13 |
import json
|
| 14 |
+
import os
|
| 15 |
import random
|
| 16 |
import sys
|
| 17 |
from pathlib import Path
|
|
|
|
| 18 |
|
| 19 |
import numpy as np
|
| 20 |
+
from datasets import Dataset
|
| 21 |
+
|
| 22 |
+
os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
|
| 23 |
+
os.environ.setdefault("OMP_NUM_THREADS", "1")
|
| 24 |
|
| 25 |
ROOT = Path(__file__).resolve().parents[1]
|
| 26 |
if str(ROOT) not in sys.path:
|
| 27 |
sys.path.insert(0, str(ROOT))
|
| 28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
from env.reward import (
|
|
|
|
| 30 |
alignment_reward_func,
|
| 31 |
+
format_reward_func,
|
| 32 |
profit_reward_func,
|
| 33 |
)
|
| 34 |
+
from training.grpo_verifiers_multiagent import (
|
| 35 |
+
governance_reward_func_multiagent,
|
| 36 |
+
risk_reward_func_multiagent,
|
| 37 |
+
)
|
| 38 |
+
from training.prompt_utils import (
|
| 39 |
+
SYSTEM_PROMPT,
|
| 40 |
+
build_prompt_multiagent,
|
| 41 |
+
generate_pz_scenarios,
|
| 42 |
)
|
| 43 |
|
| 44 |
|
|
|
|
|
|
|
| 45 |
DEFAULT_MODEL_NAME = "unsloth/Qwen2.5-1.5B-Instruct-bnb-4bit"
|
| 46 |
DEFAULT_OUTPUT_DIR = "models/local_policy_grpo_multiagent"
|
| 47 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
def require_cuda():
|
| 50 |
import torch
|
| 51 |
+
|
| 52 |
if not torch.cuda.is_available():
|
| 53 |
raise SystemExit("GRPO training requires CUDA.")
|
| 54 |
return torch
|
|
|
|
| 56 |
|
| 57 |
def load_model(model_name: str, max_seq_length: int):
|
| 58 |
from unsloth import FastLanguageModel, PatchFastRL
|
| 59 |
+
|
| 60 |
PatchFastRL("GRPO", "unsloth")
|
| 61 |
|
| 62 |
model, tokenizer = FastLanguageModel.from_pretrained(
|
|
|
|
| 68 |
model = FastLanguageModel.get_peft_model(
|
| 69 |
model,
|
| 70 |
r=16,
|
| 71 |
+
target_modules=[
|
| 72 |
+
"q_proj",
|
| 73 |
+
"k_proj",
|
| 74 |
+
"v_proj",
|
| 75 |
+
"o_proj",
|
| 76 |
+
"gate_proj",
|
| 77 |
+
"up_proj",
|
| 78 |
+
"down_proj",
|
| 79 |
+
],
|
| 80 |
lora_alpha=16,
|
| 81 |
lora_dropout=0,
|
| 82 |
bias="none",
|
|
|
|
| 89 |
return model, tokenizer
|
| 90 |
|
| 91 |
|
|
|
|
|
|
|
| 92 |
def make_trainer(model, tokenizer, dataset, args, torch_module):
|
| 93 |
from trl.trainer.grpo_config import GRPOConfig
|
| 94 |
from trl.trainer.grpo_trainer import GRPOTrainer
|
|
|
|
| 113 |
reward_funcs = [
|
| 114 |
format_reward_func,
|
| 115 |
alignment_reward_func,
|
| 116 |
+
risk_reward_func_multiagent,
|
| 117 |
profit_reward_func,
|
| 118 |
+
governance_reward_func_multiagent,
|
| 119 |
]
|
| 120 |
|
| 121 |
trainer_kwargs = {
|
|
|
|
| 143 |
tokenizer.save_pretrained(output_dir)
|
| 144 |
|
| 145 |
|
|
|
|
|
|
|
| 146 |
def parse_args():
|
| 147 |
+
parser = argparse.ArgumentParser(description="Multi-agent GRPO training for Trader (Qwen 2.5)")
|
| 148 |
parser.add_argument("--model-name", default=DEFAULT_MODEL_NAME)
|
| 149 |
parser.add_argument("--output-dir", default=DEFAULT_OUTPUT_DIR)
|
| 150 |
parser.add_argument("--difficulty", choices=["easy", "medium", "hard"], default="easy")
|
|
|
|
| 168 |
random.seed(args.seed)
|
| 169 |
np.random.seed(args.seed)
|
| 170 |
|
| 171 |
+
print(
|
| 172 |
+
f"Generating {args.num_scenarios} scenarios from MultiAgentTradingEnv "
|
| 173 |
+
f"(difficulty={args.difficulty})..."
|
| 174 |
+
)
|
| 175 |
scenarios = generate_pz_scenarios(n=args.num_scenarios, difficulty=args.difficulty)
|
| 176 |
print(f" Generated {len(scenarios)} scenarios.")
|
| 177 |
|
|
|
|
| 178 |
prompts = [{"prompt": build_prompt_multiagent(sc)} for sc in scenarios]
|
| 179 |
dataset = Dataset.from_list(prompts)
|
| 180 |
|
|
|
|
| 181 |
torch_module = require_cuda()
|
| 182 |
model, tokenizer = load_model(args.model_name, args.max_seq_length)
|
| 183 |
|
|
|
|
| 184 |
trainer = make_trainer(model, tokenizer, dataset, args, torch_module)
|
| 185 |
print(f"Starting multi-agent GRPO training on {len(dataset)} prompts...")
|
| 186 |
+
trainer.train()
|
| 187 |
|
|
|
|
| 188 |
history = trainer.state.log_history
|
| 189 |
rewards = [x["reward"] for x in history if "reward" in x]
|
| 190 |
losses = [x["loss"] for x in history if "loss" in x]
|
| 191 |
|
| 192 |
try:
|
| 193 |
from utils.plotting import plot_training_results
|
| 194 |
+
|
| 195 |
plot_training_results(rewards, losses)
|
| 196 |
+
except Exception as exc:
|
| 197 |
+
print(f" Warning: could not generate plots: {exc}")
|
| 198 |
|
|
|
|
| 199 |
print(f"Saving GRPO policy to {args.output_dir}...")
|
| 200 |
save_model(model, tokenizer, args.output_dir)
|
| 201 |
|
|
|
|
| 202 |
metrics_path = Path(args.output_dir) / "training_metrics.json"
|
| 203 |
+
with open(metrics_path, "w", encoding="utf-8") as handle:
|
| 204 |
+
json.dump({"rewards": rewards, "losses": losses}, handle, indent=2)
|
| 205 |
|
| 206 |
print("Multi-agent GRPO training complete.")
|
| 207 |
print(f" Model saved to: {args.output_dir}")
|