ARKAISW commited on
Commit
a3c00eb
·
1 Parent(s): 30a586b

Update training notebook and verifiers

Browse files
mate_training.ipynb CHANGED
@@ -2,468 +2,415 @@
2
  "cells": [
3
  {
4
  "cell_type": "markdown",
5
- "id": "header",
6
  "metadata": {},
7
  "source": [
8
- "# \ud83c\udfdb\ufe0f QuantHive \u2014 Multi-Agent GRPO Training Notebook\n",
9
  "\n",
10
- "**Re-runnable training pipeline for the OpenEnv April '26 Hackathon.**\n",
11
- "\n",
12
- "This notebook is the single deliverable that proves QuantHive's training story end-to-end:\n",
13
- "\n",
14
- "| Section | What It Does | GPU? |\n",
15
- "|:---|:---|:---|\n",
16
- "| \u00a71-2 | Install deps, clone repo | \u274c |\n",
17
- "| \u00a73 | Validate PettingZoo AEC env (3 agents, obs shapes, governance) | \u274c |\n",
18
- "| \u00a74 | Run PettingZoo API compliance test | \u274c |\n",
19
- "| \u00a75 | Multi-agent REINFORCE training (rule-based warm-up) | \u274c |\n",
20
- "| \u00a76 | Generate per-agent reward/loss plots | \u274c |\n",
21
- "| \u00a77 | Preview governance-aware GRPO prompt | \u274c |\n",
22
- "| \u00a78 | **GRPO training** \u2014 Qwen 2.5-1.5B via Unsloth | \u2705 T4 |\n",
23
- "| \u00a79 | Display all committed training evidence | \u274c |\n",
24
- "\n",
25
- "**Architecture**: PettingZoo AEC with 3 independent RL agents:\n",
26
- "- `risk_manager_0` \u2014 obs: 24D, act: Box(3) \u2014 rewarded for restricting risk\n",
27
- "- `portfolio_manager_0` \u2014 obs: 27D, act: Box(2) \u2014 rewarded for portfolio grade\n",
28
- "- `trader_0` \u2014 obs: 29D, act: Dict \u2014 rewarded for PnL + compliance\n",
29
- "\n",
30
- "Turn order: **RM \u2192 PM \u2192 Trader** per market cycle. Each agent's output becomes part of the next agent's observation."
31
  ]
32
  },
33
  {
34
  "cell_type": "markdown",
35
- "id": "sec1_header",
36
  "metadata": {},
37
  "source": [
38
- "---\n",
39
- "## 1. Install Dependencies\n",
40
  "\n",
41
- "Run on Google Colab (T4 GPU recommended for \u00a78). All other sections work on CPU."
42
  ]
43
  },
44
  {
45
  "cell_type": "code",
46
  "execution_count": null,
47
- "id": "install_deps",
48
  "metadata": {},
49
  "outputs": [],
50
  "source": [
51
- "%%capture\n",
52
- "# Core environment\n",
53
- "%pip install pettingzoo>=1.24.0 gymnasium numpy pandas matplotlib scipy\n",
54
- "\n",
55
- "# Data sources\n",
56
- "%pip install yfinance ccxt\n",
57
- "\n",
58
- "# ML / Training\n",
59
- "%pip install torch transformers trl peft accelerate datasets safetensors\n",
60
  "\n",
61
- "# Server (not needed for training, but imported by some modules)\n",
62
- "%pip install fastapi uvicorn python-dotenv openai aiohttp\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  "\n",
64
- "# Unsloth for GRPO (uncomment for GPU training in \u00a78)\n",
65
- "# %pip install unsloth"
 
 
66
  ]
67
  },
68
  {
69
  "cell_type": "markdown",
70
- "id": "sec2_header",
71
  "metadata": {},
72
  "source": [
73
- "---\n",
74
- "## 2. Clone Repository"
 
75
  ]
76
  },
77
  {
78
  "cell_type": "code",
79
  "execution_count": null,
80
- "id": "clone_repo",
81
  "metadata": {},
82
  "outputs": [],
83
  "source": [
84
- "import os, sys\n",
85
- "from pathlib import Path\n",
86
- "\n",
87
- "REPO_URL = \"https://huggingface.co/spaces/ARKAISW/QuantHive\"\n",
88
- "REPO_DIR = Path(\"/content/QuantHive\")\n",
89
- "\n",
90
- "if not REPO_DIR.exists():\n",
91
- " !git clone {REPO_URL} {REPO_DIR}\n",
92
- "\n",
93
- "%cd {REPO_DIR}\n",
94
- "if str(REPO_DIR) not in sys.path:\n",
95
- " sys.path.insert(0, str(REPO_DIR))\n",
 
96
  "\n",
97
- "print(f\"Working directory: {Path.cwd()}\")\n",
98
- "print(f\"Python: {sys.version}\")"
99
  ]
100
  },
101
  {
102
  "cell_type": "markdown",
103
- "id": "sec3_header",
104
  "metadata": {},
105
  "source": [
106
- "---\n",
107
- "## 3. Validate PettingZoo Multi-Agent Environment\n",
108
  "\n",
109
- "The environment declared in `openenv.yaml` is `env.multi_agent_env:MultiAgentTradingEnv` \u2014 a PettingZoo AEC env with 3 agents that negotiate via inter-agent message passing."
110
  ]
111
  },
112
  {
113
  "cell_type": "code",
114
  "execution_count": null,
115
- "id": "validate_env_setup",
116
  "metadata": {},
117
  "outputs": [],
118
  "source": [
 
119
  "import numpy as np\n",
 
120
  "from env.multi_agent_env import (\n",
121
- " MultiAgentTradingEnv,\n",
122
- " RISK_MANAGER,\n",
123
- " PORTFOLIO_MGR,\n",
124
- " TRADER,\n",
125
  " ALL_AGENTS,\n",
126
  " BASE_OBS_SIZE,\n",
127
- " RM_MSG_SIZE,\n",
128
  " PM_MSG_SIZE,\n",
 
 
 
 
129
  ")\n",
130
  "\n",
 
 
 
131
  "env = MultiAgentTradingEnv(difficulty=\"easy\", max_steps=50)\n",
132
- "env.reset()\n",
133
- "\n",
134
- "print(\"=\"*60)\n",
135
- "print(\" QuantHive \u2014 Multi-Agent Trading Environment\")\n",
136
- "print(\"=\"*60)\n",
137
- "print(f\" Agents: {env.agents}\")\n",
138
- "print(f\" Turn order: RM \u2192 PM \u2192 Trader\")\n",
139
- "print(f\" RM obs: {env.observe(RISK_MANAGER).shape} (base: {BASE_OBS_SIZE})\")\n",
140
- "print(f\" PM obs: {env.observe(PORTFOLIO_MGR).shape} (base + RM msg: {BASE_OBS_SIZE}+{RM_MSG_SIZE})\")\n",
141
- "print(f\" Trader obs: {env.observe(TRADER).shape} (base + RM + PM: {BASE_OBS_SIZE}+{RM_MSG_SIZE}+{PM_MSG_SIZE})\")\n",
142
- "print(f\" RM action: Box(3) \u2014 [size_limit, allow_new, force_reduce]\")\n",
143
- "print(f\" PM action: Box(2) \u2014 [cap_alloc, override_strength]\")\n",
144
- "print(f\" Trader action: Dict \u2014 {{direction, size, sl, tp}}\")"
145
  ]
146
  },
147
  {
148
  "cell_type": "code",
149
  "execution_count": null,
150
- "id": "validate_aec_cycle",
151
  "metadata": {},
152
  "outputs": [],
153
  "source": [
154
- "# Run one full AEC cycle: RM \u2192 PM \u2192 Trader\n",
155
- "# RM sets a tight 20% size limit, allows new positions, no force-reduce\n",
 
156
  "rm_action = np.array([0.20, 1.0, 0.0], dtype=np.float32)\n",
157
  "env.step(rm_action)\n",
158
- "print(f\"\u2705 RM acted: size_limit=0.20, allow_new=yes, force_reduce=no\")\n",
159
  "\n",
160
- "# PM allocates 50% capital, no override\n",
161
- "pm_action = np.array([0.50, 0.0], dtype=np.float32)\n",
162
  "env.step(pm_action)\n",
163
- "print(f\"\u2705 PM acted: cap_alloc=0.50, override=0.0\")\n",
164
  "\n",
165
- "# Trader proposes a RECKLESS buy: size=0.85 (way over RM's 0.20 limit)\n",
166
  "trader_action = {\n",
167
  " \"direction\": 1,\n",
168
- " \"size\": np.array([0.85], dtype=np.float32),\n",
169
  " \"sl\": np.array([0.0], dtype=np.float32),\n",
170
  " \"tp\": np.array([0.0], dtype=np.float32),\n",
171
  "}\n",
172
  "env.step(trader_action)\n",
173
- "print(f\"\u2705 Trader acted: proposed BUY size=0.85\")\n",
174
- "\n",
175
- "# Inspect governance outcome\n",
176
- "info = env.infos[TRADER]\n",
177
- "gov = info[\"governance\"]\n",
178
- "\n",
179
- "print(f\"\\n{'='*60}\")\n",
180
- "print(f\" GOVERNANCE NEGOTIATION RESULT\")\n",
181
- "print(f\"{'='*60}\")\n",
182
- "print(f\" Proposed: direction={gov['proposed']['direction']}, size={gov['proposed']['size']:.2f}\")\n",
183
- "print(f\" Executed: direction={gov['executed']['direction']}, size={gov['executed']['size']:.2f}\")\n",
184
- "print(f\" Compliant: {gov['was_compliant']}\")\n",
185
- "print(f\" Interventions ({len(gov['interventions'])}):\") \n",
186
- "for iv in gov[\"interventions\"]:\n",
187
- " print(f\" \u2192 {iv['agent']}: {iv['type']}\")\n",
188
- "\n",
189
- "print(f\"\\n Per-Agent Rewards:\")\n",
190
- "for agent, r in info[\"rewards\"].items():\n",
191
- " print(f\" {agent:25s} {r:+.4f}\")\n",
192
- "\n",
193
- "print(f\"\\n Portfolio: ${info['portfolio_value']:,.2f} | PnL: {info['pnl_pct']:+.2%}\")"
194
- ]
195
- },
196
- {
197
- "cell_type": "markdown",
198
- "id": "sec4_header",
199
- "metadata": {},
200
- "source": [
201
- "---\n",
202
- "## 4. PettingZoo API Compliance Test\n",
203
  "\n",
204
- "PettingZoo ships a built-in compliance test that verifies our env follows the AEC protocol correctly."
 
 
 
205
  ]
206
  },
207
  {
208
  "cell_type": "code",
209
  "execution_count": null,
210
- "id": "pz_api_test",
211
  "metadata": {},
212
  "outputs": [],
213
  "source": [
214
  "from pettingzoo.test import api_test\n",
215
  "\n",
216
- "test_env = MultiAgentTradingEnv(difficulty=\"easy\", max_steps=50)\n",
217
- "api_test(test_env, num_cycles=50, verbose_progress=True)\n",
218
- "print(\"\\n\u2705 PettingZoo API compliance test PASSED (50 cycles)\")"
219
  ]
220
  },
221
  {
222
  "cell_type": "markdown",
223
- "id": "sec5_header",
224
  "metadata": {},
225
  "source": [
226
- "---\n",
227
- "## 5. Multi-Agent REINFORCE Training (Rule-Based Policies)\n",
228
- "\n",
229
- "This is the CPU-friendly training path. Three rule-based policies (RM, PM, Trader) are trained using alternating optimization with REINFORCE-style policy gradients.\n",
230
  "\n",
231
- "The alternating schedule:\n",
232
- "- Episodes 0-9: optimize Trader (RM/PM frozen)\n",
233
- "- Episodes 10-19: optimize Risk Manager (Trader/PM frozen)\n",
234
- "- Repeat"
235
  ]
236
  },
237
  {
238
  "cell_type": "code",
239
  "execution_count": null,
240
- "id": "reinforce_training",
241
  "metadata": {},
242
  "outputs": [],
243
  "source": [
244
  "from training.train_multi_agent import train\n",
245
  "\n",
246
  "metrics = train(\n",
247
- " n_episodes=100,\n",
248
- " max_steps_ep=200,\n",
249
  " gamma=0.99,\n",
250
  " alternating_freq=10,\n",
251
  " difficulty=\"easy\",\n",
252
  " output_dir=\"outputs/multi_agent\",\n",
253
- " save_every=50,\n",
254
- ")"
255
- ]
256
- },
257
- {
258
- "cell_type": "markdown",
259
- "id": "sec5b_header",
260
- "metadata": {},
261
- "source": [
262
- "### 5.1 Curriculum Verification\n",
263
  "\n",
264
- "Verify that the environment works across all difficulty levels."
 
 
 
 
265
  ]
266
  },
267
  {
268
  "cell_type": "code",
269
  "execution_count": null,
270
- "id": "curriculum_check",
271
  "metadata": {},
272
  "outputs": [],
273
  "source": [
274
- "from training.train_multi_agent import collect_rollout, RuleRiskManagerPolicy, RulePortfolioManagerPolicy, RuleTraderPolicy\n",
 
 
 
 
 
275
  "\n",
276
  "policies = {\n",
277
- " RISK_MANAGER: RuleRiskManagerPolicy(),\n",
278
  " PORTFOLIO_MGR: RulePortfolioManagerPolicy(),\n",
279
- " TRADER: RuleTraderPolicy(),\n",
280
  "}\n",
281
  "\n",
282
- "print(f\"{'Difficulty':<12} {'Episodes':<10} {'Mean Trader Return':<22} {'Mean PnL':<15} {'Mean DD':<12}\")\n",
283
- "print(\"-\" * 71)\n",
284
  "\n",
285
  "for diff in [\"easy\", \"medium\", \"hard\"]:\n",
286
  " returns, pnls, dds = [], [], []\n",
287
  " test_env = MultiAgentTradingEnv(difficulty=diff, max_steps=100)\n",
288
  " for _ in range(10):\n",
289
  " buffers, info = collect_rollout(test_env, policies, max_steps=100)\n",
290
- " trader_ret = float(np.mean(buffers[TRADER].discounted_returns()))\n",
291
- " returns.append(trader_ret)\n",
292
- " pnls.append(info.get(\"pnl_pct\", 0.0))\n",
293
- " dds.append(info.get(\"max_drawdown\", 0.0))\n",
294
- " print(f\"{diff:<12} {10:<10} {np.mean(returns):+.6f} {np.mean(pnls):+.4%} {np.mean(dds):.4%}\")"
295
  ]
296
  },
297
  {
298
  "cell_type": "markdown",
299
- "id": "sec6_header",
300
  "metadata": {},
301
  "source": [
302
- "---\n",
303
- "## 6. Generate Per-Agent Reward & Loss Plots\n",
304
  "\n",
305
- "Generates the required plots from training metrics and saves them to `plots/`."
306
  ]
307
  },
308
  {
309
  "cell_type": "code",
310
  "execution_count": null,
311
- "id": "generate_plots",
312
  "metadata": {},
313
  "outputs": [],
314
  "source": [
315
- "import json, matplotlib\n",
 
 
316
  "matplotlib.use(\"Agg\")\n",
317
  "import matplotlib.pyplot as plt\n",
318
  "\n",
319
  "metrics_path = Path(\"outputs/multi_agent/metrics_final.json\")\n",
 
 
320
  "if not metrics_path.exists():\n",
321
- " # Fallback: use the metrics dict from \u00a75 directly\n",
322
- " with open(metrics_path, \"w\") as f:\n",
323
- " json.dump(dict(metrics), f, indent=2)\n",
 
324
  "\n",
325
- "with open(metrics_path) as f:\n",
326
- " m = json.load(f)\n",
327
  "\n",
 
 
328
  "episodes = m[\"episode\"]\n",
329
  "n_eps = len(episodes)\n",
330
  "print(f\"Loaded {n_eps} episodes from {metrics_path}\")\n",
331
  "\n",
332
- "# \u2500\u2500 Per-Agent Reward Curves \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
333
- "fig, ax = plt.subplots(figsize=(12, 6))\n",
 
 
 
334
  "\n",
335
- "def smooth(vals, w=10):\n",
336
- " if len(vals) < w: return vals\n",
337
- " return np.convolve(vals, np.ones(w)/w, mode=\"valid\").tolist()\n",
338
- "\n",
339
- "w = max(1, n_eps // 15)\n",
340
- "trader_s = smooth(m[\"trader_return\"], w)\n",
341
- "rm_s = smooth(m[\"rm_return\"], w)\n",
342
- "pm_s = smooth(m[\"pm_return\"], w)\n",
343
- "ep_s = episodes[:len(trader_s)]\n",
344
- "\n",
345
- "ax.plot(ep_s, trader_s, label=\"Trader\", color=\"#2ecc71\", linewidth=2)\n",
346
- "ax.plot(ep_s, rm_s, label=\"Risk Manager\", color=\"#e74c3c\", linewidth=2)\n",
347
- "ax.plot(ep_s, pm_s, label=\"Portfolio Manager\", color=\"#3498db\", linewidth=2)\n",
348
- "ax.set_xlabel(\"Episode\"); ax.set_ylabel(\"Discounted Return\")\n",
349
- "ax.set_title(\"QuantHive: Per-Agent Reward Curves (Multi-Agent Training)\")\n",
350
- "ax.legend(); ax.grid(True, alpha=0.3)\n",
351
  "plt.tight_layout()\n",
352
- "fig.savefig(\"plots/reward_curve.png\", dpi=150)\n",
353
  "plt.show()\n",
354
- "print(\"Saved: plots/reward_curve.png\")\n",
355
  "\n",
356
- "# \u2500\u2500 Loss Curve (PnL convergence) \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
357
  "fig2, ax2 = plt.subplots(figsize=(12, 6))\n",
358
- "pnl_s = smooth(m[\"pnl_pct\"], w)\n",
359
- "ax2.plot(episodes[:len(pnl_s)], pnl_s, color=\"#e74c3c\", linewidth=2)\n",
360
- "ax2.axhline(y=0, color=\"gray\", linestyle=\"--\", alpha=0.5)\n",
361
  "pnl_arr = np.array(pnl_s)\n",
362
- "ax2.fill_between(episodes[:len(pnl_s)], 0, pnl_s,\n",
363
- " where=pnl_arr>0, color=\"#2ecc71\", alpha=0.2)\n",
364
- "ax2.fill_between(episodes[:len(pnl_s)], 0, pnl_s,\n",
365
- " where=pnl_arr<=0, color=\"#e74c3c\", alpha=0.2)\n",
366
- "ax2.set_xlabel(\"Episode\"); ax2.set_ylabel(\"PnL %\")\n",
367
- "ax2.set_title(\"QuantHive: PnL Over Training (Policy Convergence)\")\n",
 
368
  "ax2.grid(True, alpha=0.3)\n",
369
  "plt.tight_layout()\n",
370
- "fig2.savefig(\"plots/loss_curve.png\", dpi=150)\n",
371
  "plt.show()\n",
372
- "print(\"Saved: plots/loss_curve.png\")\n",
373
  "\n",
374
- "# \u2500\u2500 Baseline vs Trained \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
375
  "if n_eps >= 20:\n",
376
  " fig3, ax3 = plt.subplots(figsize=(10, 6))\n",
377
  " names = [\"Trader Return\", \"Grade\", \"Max Drawdown\", \"Sharpe\"]\n",
378
- " early = [np.mean(m[k][:20]) for k in [\"trader_return\", \"grade\", \"max_drawdown\", \"sharpe\"]]\n",
379
- " late = [np.mean(m[k][-20:]) for k in [\"trader_return\", \"grade\", \"max_drawdown\", \"sharpe\"]]\n",
380
  " x = np.arange(len(names))\n",
381
  " ax3.bar(x - 0.175, early, 0.35, label=\"First 20 eps\", color=\"#e74c3c\", alpha=0.8)\n",
382
- " ax3.bar(x + 0.175, late, 0.35, label=\"Last 20 eps\", color=\"#2ecc71\", alpha=0.8)\n",
383
- " ax3.set_ylabel(\"Value\"); ax3.set_title(\"QuantHive: Early vs Late Training\")\n",
384
- " ax3.set_xticks(x); ax3.set_xticklabels(names)\n",
385
- " ax3.legend(); ax3.grid(True, alpha=0.3, axis=\"y\")\n",
 
 
 
386
  " plt.tight_layout()\n",
387
- " fig3.savefig(\"plots/baseline_comparison.png\", dpi=150)\n",
388
  " plt.show()\n",
389
- " print(\"Saved: plots/baseline_comparison.png\")"
 
390
  ]
391
  },
392
  {
393
  "cell_type": "markdown",
394
- "id": "sec7_header",
395
  "metadata": {},
396
  "source": [
397
- "---\n",
398
- "## 7. Preview Governance-Aware GRPO Prompt\n",
399
  "\n",
400
- "The GRPO pipeline generates scenarios directly from `MultiAgentTradingEnv`.\n",
401
- "Each prompt includes the RM and PM governance messages \u2014 the Trader must learn to read and respect them."
402
  ]
403
  },
404
  {
405
  "cell_type": "code",
406
  "execution_count": null,
407
- "id": "preview_prompt",
408
  "metadata": {},
409
  "outputs": [],
410
  "source": [
411
- "from training.train_grpo_multiagent import generate_pz_scenarios, build_prompt_multiagent\n",
412
  "\n",
413
  "scenarios = generate_pz_scenarios(n=3, difficulty=\"easy\", max_env_steps=30)\n",
414
- "print(f\"Generated {len(scenarios)} scenarios from PZ env\\n\")\n",
415
- "\n",
416
- "for i, sc in enumerate(scenarios):\n",
417
- " print(f\"{'='*60}\")\n",
418
- " print(f\" Scenario {i+1}\")\n",
419
- " print(f\" RM: size_limit={sc['rm_size_limit']:.2f}, allow_new={sc['rm_allow_new']}, force_reduce={sc['rm_force_reduce']}\")\n",
420
- " print(f\" PM: cap_alloc={sc['pm_cap_alloc']:.2f}, override={sc['pm_override']:.2f}\")\n",
421
- " print(f\"{'='*60}\")\n",
422
- "\n",
423
- "# Show full prompt for one scenario\n",
424
- "print(\"\\n\" + \"=\"*60)\n",
425
- "print(\" FULL PROMPT (Scenario 1)\")\n",
426
- "print(\"=\"*60)\n",
427
- "print(build_prompt_multiagent(scenarios[0]))"
 
 
 
428
  ]
429
  },
430
  {
431
  "cell_type": "code",
432
  "execution_count": null,
433
- "id": "verify_verifiers",
434
  "metadata": {},
435
  "outputs": [],
436
  "source": [
437
- "# Verify the updated GRPO verifiers can parse governance constraints\n",
438
- "from training.train_grpo_multiagent import (\n",
439
- " risk_reward_func_multiagent,\n",
440
  " governance_reward_func_multiagent,\n",
 
441
  ")\n",
442
- "from env.reward import format_reward_func, alignment_reward_func, profit_reward_func\n",
443
  "\n",
444
  "test_prompt = build_prompt_multiagent(scenarios[0])\n",
 
445
  "\n",
446
- "# A compliant completion\n",
447
  "compliant = (\n",
448
- " '<thought>\\n'\n",
449
- " 'RSI is 0.28, indicating oversold conditions. EMA20 crossing above EMA50 suggests bullish momentum. '\n",
450
- " 'However, the Risk Manager restricts allocation to {limit:.2f} given current market conditions. '\n",
451
- " 'The Portfolio Manager allocated {cap:.0%} capital. I will propose a conservative position '\n",
452
- " 'within the governance constraints to avoid intervention.\\n'\n",
453
- " '</thought>\\n'\n",
454
- " '<action>\\n'\n",
455
- " '{{\"direction\": 1, \"size\": {size:.2f}, \"sl\": 49000, \"tp\": 52000}}\\n'\n",
456
- " '</action>'\n",
457
- ").format(\n",
458
- " limit=scenarios[0][\"rm_size_limit\"],\n",
459
- " cap=scenarios[0][\"pm_cap_alloc\"],\n",
460
- " size=min(scenarios[0][\"rm_size_limit\"] * 0.7, scenarios[0][\"pm_cap_alloc\"] * 0.7),\n",
461
- ")\n",
462
  "\n",
463
- "# A reckless completion\n",
464
  "reckless = (\n",
465
- " '<thought>\\nMarket is moving up. Going all in.\\n</thought>\\n'\n",
466
- " '<action>\\n{\"direction\": 1, \"size\": 0.95, \"sl\": 0, \"tp\": 0}\\n</action>'\n",
467
  ")\n",
468
  "\n",
469
  "prompts = [test_prompt, test_prompt]\n",
@@ -472,241 +419,188 @@
472
  "print(f\"{'Verifier':<25} {'Compliant':<12} {'Reckless':<12}\")\n",
473
  "print(\"-\" * 49)\n",
474
  "for name, func in [\n",
475
- " (\"Format\", format_reward_func),\n",
476
- " (\"Alignment\", alignment_reward_func),\n",
477
- " (\"Risk (dynamic RM)\", risk_reward_func_multiagent),\n",
478
- " (\"Profit\", profit_reward_func),\n",
479
- " (\"Governance (RM+PM)\", governance_reward_func_multiagent),\n",
480
  "]:\n",
481
  " scores = func(prompts, completions)\n",
482
- " print(f\"{name:<25} {scores[0]:<12.2f} {scores[1]:<12.2f}\")"
483
  ]
484
  },
485
  {
486
  "cell_type": "markdown",
487
- "id": "sec8_header",
488
  "metadata": {},
489
  "source": [
490
- "---\n",
491
- "## 8. GRPO Training \u2014 Qwen 2.5-1.5B (GPU Required)\n",
492
- "\n",
493
- "This section trains the Trader agent as a language model using **GRPO** (Group Relative Policy Optimization) via Unsloth + TRL.\n",
494
- "\n",
495
- "**Requirements**: CUDA GPU (Colab T4 is sufficient).\n",
496
  "\n",
497
- "The 5 verifiers:\n",
498
- "1. **Format** \u2014 valid `<thought>` + `<action>` XML tags\n",
499
- "2. **Alignment** \u2014 reasoning matches market signals (anti-hallucination)\n",
500
- "3. **Risk** \u2014 size \u2264 RM's dynamic `size_limit` (reads from governance in prompt)\n",
501
- "4. **Profit** \u2014 direction matches price trend\n",
502
- "5. **Governance** \u2014 would this action pass without intervention? Checks compliance against *both* RM and PM constraints"
503
  ]
504
  },
505
  {
506
  "cell_type": "code",
507
  "execution_count": null,
508
- "id": "grpo_training",
509
  "metadata": {},
510
  "outputs": [],
511
  "source": [
512
- "# \u26a1 Uncomment the entire cell to run GRPO training on a T4 GPU.\n",
513
- "# This takes ~20-30 minutes for 100 steps on T4.\n",
514
- "\n",
515
- "import torch\n",
516
- "assert torch.cuda.is_available(), \"GPU required for GRPO training\"\n",
517
- "print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n",
518
- "\n",
519
- "# # Install Unsloth if not already\n",
520
- "# # %pip install unsloth\n",
521
- "\n",
522
- "#\n",
523
- "# --- Phase 1: Generate scenarios from the PZ env ---\n",
524
- "from training.train_grpo_multiagent import (\n",
525
- " generate_pz_scenarios,\n",
526
- " build_prompt_multiagent,\n",
527
- " load_model,\n",
528
- " make_trainer,\n",
529
- " save_model,\n",
530
- " risk_reward_func_multiagent,\n",
531
- " governance_reward_func_multiagent,\n",
532
- ")\n",
533
- "from datasets import Dataset\n",
534
- "import json\n",
535
- "#\n",
536
- "print(\"Generating 256 scenarios from MultiAgentTradingEnv...\")\n",
537
- "scenarios = generate_pz_scenarios(n=256, difficulty=\"easy\", max_env_steps=50)\n",
538
- "prompts = [{\"prompt\": build_prompt_multiagent(sc)} for sc in scenarios]\n",
539
- "dataset = Dataset.from_list(prompts)\n",
540
- "print(f\"Dataset: {len(dataset)} prompts\")\n",
541
- "#\n",
542
- "# --- Phase 2: Load Qwen 2.5-1.5B with LoRA ---\n",
543
- "model, tokenizer = load_model(\n",
544
- " \"unsloth/Qwen2.5-1.5B-Instruct-bnb-4bit\",\n",
545
- " max_seq_length=1024,\n",
546
- ")\n",
547
- "#\n",
548
- "# --- Phase 3: Train with 5 verifiers ---\n",
549
- "from types import SimpleNamespace\n",
550
- "args = SimpleNamespace(\n",
551
- " output_dir=\"models/local_policy_grpo_multiagent\",\n",
552
- " learning_rate=5e-5,\n",
553
- " per_device_batch_size=2,\n",
554
- " gradient_accumulation_steps=2,\n",
555
- " max_steps=100,\n",
556
- " save_steps=25,\n",
557
- " logging_steps=1,\n",
558
- " max_prompt_length=768,\n",
559
- " max_completion_length=200,\n",
560
- " num_generations=4,\n",
561
- ")\n",
562
- "#\n",
563
- "trainer = make_trainer(model, tokenizer, dataset, args, torch)\n",
564
- "print(f\"Starting GRPO training ({args.max_steps} steps)...\")\n",
565
- "train_result = trainer.train()\n",
566
- "#\n",
567
- "# --- Phase 4: Extract metrics and generate plots ---\n",
568
- "history = trainer.state.log_history\n",
569
- "rewards = [x[\"reward\"] for x in history if \"reward\" in x]\n",
570
- "# losses = [x[\"loss\"] for x in history if \"loss\" in x]\n",
571
- "#\n",
572
- "from utils.plotting import plot_training_results\n",
573
- "plot_training_results(rewards, losses, output_dir=\"plots\")\n",
574
- "#\n",
575
- "# --- Phase 5: Save model ---\n",
576
- "save_model(model, tokenizer, args.output_dir)\n",
577
- "print(f\"Model saved to {args.output_dir}\")\n",
578
- "#\n",
579
- "# # Save metrics for later analysis\n",
580
- "with open(\"outputs/grpo_metrics.json\", \"w\") as f:\n",
581
- " json.dump({\"rewards\": rewards, \"losses\": losses}, f, indent=2)"
582
  ]
583
  },
584
  {
585
  "cell_type": "markdown",
586
- "id": "sec9_header",
587
  "metadata": {},
588
  "source": [
589
- "---\n",
590
- "## 9. Display All Training Evidence\n",
591
- "\n",
592
- "The repository ships committed training plots. This section displays them (whether generated in this notebook run or previously committed)."
593
  ]
594
  },
595
  {
596
  "cell_type": "code",
597
  "execution_count": null,
598
- "id": "display_plots",
599
  "metadata": {},
600
  "outputs": [],
601
  "source": [
602
- "from IPython.display import Image, display, Markdown\n",
603
  "\n",
604
  "plot_files = [\n",
605
- " (\"plots/reward_curve.png\", \"Per-Agent Reward Curves (RM, PM, Trader)\"),\n",
606
- " (\"plots/loss_curve.png\", \"PnL / Policy Loss Convergence\"),\n",
607
- " (\"plots/baseline_comparison.png\", \"Early vs Late Training Performance\"),\n",
608
- " (\"plots/grade_progression.png\", \"Portfolio Grade + Sharpe Over Training\"),\n",
609
  "]\n",
610
  "\n",
611
  "for path, title in plot_files:\n",
612
  " if Path(path).exists():\n",
613
  " display(Markdown(f\"### {title}\"))\n",
614
- " display(Image(filename=path, width=800))\n",
615
  " else:\n",
616
- " print(f\"\u26a0 {path} not found \u2014 run training (\u00a75/\u00a78) to generate\")"
617
  ]
618
  },
619
  {
620
  "cell_type": "markdown",
621
- "id": "sec10_header",
622
  "metadata": {},
623
  "source": [
624
- "---\n",
625
- "## 10. Submission Checklist"
626
  ]
627
  },
628
  {
629
  "cell_type": "code",
630
  "execution_count": null,
631
- "id": "checklist",
632
  "metadata": {},
633
  "outputs": [],
634
  "source": [
 
 
635
  "checks = [\n",
636
- " (\"openenv.yaml (PettingZoo entry_point)\", Path(\"openenv.yaml\")),\n",
637
- " (\"README.md\", Path(\"README.md\")),\n",
638
- " (\"WRITEUP.md\", Path(\"WRITEUP.md\")),\n",
639
- " (\"Multi-agent env\", Path(\"env/multi_agent_env.py\")),\n",
640
- " (\"REINFORCE trainer\", Path(\"training/train_multi_agent.py\")),\n",
641
- " (\"GRPO trainer (multi-agent)\", Path(\"training/train_grpo_multiagent.py\")),\n",
642
- " (\"Plot generator\", Path(\"training/plot_multiagent.py\")),\n",
643
- " (\"Training notebook\", Path(\"mate_training.ipynb\")),\n",
644
- " (\"Reward curve\", Path(\"plots/reward_curve.png\")),\n",
645
- " (\"Loss curve\", Path(\"plots/loss_curve.png\")),\n",
646
- " (\"Baseline comparison\", Path(\"plots/baseline_comparison.png\")),\n",
647
- " (\"Dockerfile\", Path(\"Dockerfile\")),\n",
648
- " (\"requirements-space.txt\", Path(\"requirements-space.txt\")),\n",
 
649
  "]\n",
650
  "\n",
651
- "print(f\"{'Deliverable':<40} {'Status':<10}\")\n",
652
- "print(\"-\" * 50)\n",
653
  "for name, path in checks:\n",
654
- " status = \"\u2705 OK\" if path.exists() else \"\u274c MISSING\"\n",
655
- " print(f\"{name:<40} {status}\")\n",
656
- "\n",
657
- "# Verify openenv.yaml points to PettingZoo\n",
658
- "import yaml\n",
659
- "try:\n",
660
- " with open(\"openenv.yaml\") as f:\n",
661
- " manifest = yaml.safe_load(f)\n",
662
- " ep = manifest.get(\"environment\", {}).get(\"entry_point\", \"\")\n",
663
- " is_pz = \"multi_agent_env\" in ep\n",
664
- " print(f\"\\nopenenv.yaml entry_point: {ep}\")\n",
665
- " print(f\"{'\u2705' if is_pz else '\u274c'} Points to {'PettingZoo' if is_pz else 'WRONG'} env\")\n",
666
- "except Exception:\n",
667
- " print(\"\\n\u26a0 Could not parse openenv.yaml (pyyaml may not be installed)\")"
668
- ]
669
- },
670
- {
671
- "cell_type": "markdown",
672
- "id": "footer",
673
- "metadata": {},
674
- "source": [
675
- "---\n",
676
- "\n",
677
- "## Summary\n",
678
- "\n",
679
- "This notebook establishes the full QuantHive training evidence chain:\n",
680
- "\n",
681
- "1. **Environment validated** \u2014 `MultiAgentTradingEnv` passes PettingZoo's official API test (50 cycles)\n",
682
- "2. **Governance works** \u2014 RM clamped a reckless 85% trade down to 20%, logged 3 interventions\n",
683
- "3. **Multi-agent training** \u2014 alternating optimization with per-agent reward curves\n",
684
- "4. **GRPO ready** \u2014 prompts include RM/PM constraints; verifiers check dynamic compliance\n",
685
- "5. **Curriculum works** \u2014 env runs at easy/medium/hard with decreasing Trader returns\n",
686
- "\n",
687
- "**Key differentiator**: GRPO verifiers #3 (Risk) and #5 (Governance) check compliance against the RM's *learned, dynamic* `size_limit` \u2014 not a hardcoded constant. This means the Trader must learn to **read and respect** governance messages.\n",
688
- "\n",
689
- "---\n",
690
- "*Built for the OpenEnv April '26 Hackathon | Theme 1: Multi-Agent Interactions* \n",
691
- "*Author: Arka Sarkar*"
692
  ]
693
  }
694
  ],
695
  "metadata": {
696
- "accelerator": "GPU",
697
- "colab": {
698
- "gpuType": "T4",
699
- "provenance": []
700
- },
701
  "kernelspec": {
702
  "display_name": "Python 3",
 
703
  "name": "python3"
704
  },
705
  "language_info": {
706
- "name": "python",
707
- "version": "3.12.0"
708
  }
709
  },
710
  "nbformat": 4,
711
  "nbformat_minor": 5
712
- }
 
2
  "cells": [
3
  {
4
  "cell_type": "markdown",
 
5
  "metadata": {},
6
  "source": [
7
+ "# QuantHive Multi-Agent Training Notebook\n",
8
  "\n",
9
+ "Rerunnable notebook for validating the PettingZoo environment, running the rule-based multi-agent trainer, previewing governance-aware prompts, and optionally launching GRPO training on a GPU runtime.\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  ]
11
  },
12
  {
13
  "cell_type": "markdown",
 
14
  "metadata": {},
15
  "source": [
16
+ "## 1. Bootstrap The Repo\n",
 
17
  "\n",
18
+ "Clone from GitHub when the notebook is running outside the repo. If the current working directory already looks like the repository, reuse it.\n"
19
  ]
20
  },
21
  {
22
  "cell_type": "code",
23
  "execution_count": null,
 
24
  "metadata": {},
25
  "outputs": [],
26
  "source": [
27
+ "import os\n",
28
+ "import subprocess\n",
29
+ "import sys\n",
30
+ "from pathlib import Path\n",
 
 
 
 
 
31
  "\n",
32
+ "REPO_URL = \"https://github.com/ARKAISW/multi-agent-trading-env.git\"\n",
33
+ "\n",
34
+ "def looks_like_repo(path: Path) -> bool:\n",
35
+ " return (\n",
36
+ " (path / \"openenv.yaml\").exists()\n",
37
+ " and (path / \"env\").exists()\n",
38
+ " and (path / \"training\").exists()\n",
39
+ " )\n",
40
+ "\n",
41
+ "start_dir = Path.cwd()\n",
42
+ "if looks_like_repo(start_dir):\n",
43
+ " REPO_DIR = start_dir\n",
44
+ "else:\n",
45
+ " clone_parent = Path(\"/content\") if Path(\"/content\").exists() else start_dir\n",
46
+ " REPO_DIR = clone_parent / \"multi-agent-trading-env\"\n",
47
+ " if not REPO_DIR.exists():\n",
48
+ " subprocess.check_call([\"git\", \"clone\", REPO_URL, str(REPO_DIR)])\n",
49
+ "\n",
50
+ "os.chdir(REPO_DIR)\n",
51
+ "if str(REPO_DIR) not in sys.path:\n",
52
+ " sys.path.insert(0, str(REPO_DIR))\n",
53
  "\n",
54
+ "commit = subprocess.check_output([\"git\", \"rev-parse\", \"--short\", \"HEAD\"], text=True).strip()\n",
55
+ "print(f\"Working directory: {Path.cwd()}\")\n",
56
+ "print(f\"Repo commit: {commit}\")\n",
57
+ "print(f\"Python: {sys.version.split()[0]}\")\n"
58
  ]
59
  },
60
  {
61
  "cell_type": "markdown",
 
62
  "metadata": {},
63
  "source": [
64
+ "## 2. Install Notebook Dependencies\n",
65
+ "\n",
66
+ "Install the lightweight stack needed for environment validation, rule-based training, plots, and prompt preview. The heavier GRPO packages stay in the optional GPU section.\n"
67
  ]
68
  },
69
  {
70
  "cell_type": "code",
71
  "execution_count": null,
 
72
  "metadata": {},
73
  "outputs": [],
74
  "source": [
75
+ "BASE_PACKAGES = [\n",
76
+ " \"openenv\",\n",
77
+ " \"pyyaml\",\n",
78
+ " \"pettingzoo>=1.24.0\",\n",
79
+ " \"gymnasium\",\n",
80
+ " \"numpy\",\n",
81
+ " \"pandas\",\n",
82
+ " \"matplotlib\",\n",
83
+ " \"scipy\",\n",
84
+ " \"torch\",\n",
85
+ " \"yfinance\",\n",
86
+ " \"ccxt\",\n",
87
+ "]\n",
88
  "\n",
89
+ "subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", *BASE_PACKAGES])\n",
90
+ "print(\"Installed base notebook dependencies.\")\n"
91
  ]
92
  },
93
  {
94
  "cell_type": "markdown",
 
95
  "metadata": {},
96
  "source": [
97
+ "## 3. Validate The PettingZoo Multi-Agent Environment\n",
 
98
  "\n",
99
+ "Probe the constructor first so the notebook does not assume the wrong signature, then instantiate and inspect observation and action shapes.\n"
100
  ]
101
  },
102
  {
103
  "cell_type": "code",
104
  "execution_count": null,
 
105
  "metadata": {},
106
  "outputs": [],
107
  "source": [
108
+ "import inspect\n",
109
  "import numpy as np\n",
110
+ "\n",
111
  "from env.multi_agent_env import (\n",
 
 
 
 
112
  " ALL_AGENTS,\n",
113
  " BASE_OBS_SIZE,\n",
114
+ " MultiAgentTradingEnv,\n",
115
  " PM_MSG_SIZE,\n",
116
+ " PORTFOLIO_MGR,\n",
117
+ " RISK_MANAGER,\n",
118
+ " RM_MSG_SIZE,\n",
119
+ " TRADER,\n",
120
  ")\n",
121
  "\n",
122
+ "print(\"MultiAgentTradingEnv signature:\")\n",
123
+ "print(inspect.signature(MultiAgentTradingEnv))\n",
124
+ "\n",
125
  "env = MultiAgentTradingEnv(difficulty=\"easy\", max_steps=50)\n",
126
+ "reset_result = env.reset(seed=7)\n",
127
+ "\n",
128
+ "print(\"\\nEnvironment summary\")\n",
129
+ "print(\"-\" * 60)\n",
130
+ "print(f\"reset() returned: {reset_result}\")\n",
131
+ "print(f\"Agents: {env.agents}\")\n",
132
+ "print(f\"Turn order: RM -> PM -> Trader\")\n",
133
+ "print(f\"RM obs: {env.observe(RISK_MANAGER).shape} (base={BASE_OBS_SIZE})\")\n",
134
+ "print(f\"PM obs: {env.observe(PORTFOLIO_MGR).shape} (base+rm={BASE_OBS_SIZE + RM_MSG_SIZE})\")\n",
135
+ "print(f\"Trader obs: {env.observe(TRADER).shape} (base+rm+pm={BASE_OBS_SIZE + RM_MSG_SIZE + PM_MSG_SIZE})\")\n",
136
+ "print(f\"RM action space: {env.action_space(RISK_MANAGER)}\")\n",
137
+ "print(f\"PM action space: {env.action_space(PORTFOLIO_MGR)}\")\n",
138
+ "print(f\"Trader action space: {env.action_space(TRADER)}\")\n"
139
  ]
140
  },
141
  {
142
  "cell_type": "code",
143
  "execution_count": null,
 
144
  "metadata": {},
145
  "outputs": [],
146
  "source": [
147
+ "env = MultiAgentTradingEnv(difficulty=\"easy\", max_steps=10)\n",
148
+ "env.reset(seed=7)\n",
149
+ "\n",
150
  "rm_action = np.array([0.20, 1.0, 0.0], dtype=np.float32)\n",
151
  "env.step(rm_action)\n",
 
152
  "\n",
153
+ "pm_action = np.array([0.35, 0.0], dtype=np.float32)\n",
 
154
  "env.step(pm_action)\n",
 
155
  "\n",
 
156
  "trader_action = {\n",
157
  " \"direction\": 1,\n",
158
+ " \"size\": np.array([0.10], dtype=np.float32),\n",
159
  " \"sl\": np.array([0.0], dtype=np.float32),\n",
160
  " \"tp\": np.array([0.0], dtype=np.float32),\n",
161
  "}\n",
162
  "env.step(trader_action)\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  "\n",
164
+ "print(\"Completed one AEC cycle.\")\n",
165
+ "print(f\"Current step: {env._current_step}\")\n",
166
+ "print(f\"Current agent selection: {env.agent_selection}\")\n",
167
+ "print(f\"Latest rewards: {env.rewards}\")\n"
168
  ]
169
  },
170
  {
171
  "cell_type": "code",
172
  "execution_count": null,
 
173
  "metadata": {},
174
  "outputs": [],
175
  "source": [
176
  "from pettingzoo.test import api_test\n",
177
  "\n",
178
+ "api_env = MultiAgentTradingEnv(difficulty=\"easy\", max_steps=20)\n",
179
+ "api_test(api_env, num_cycles=20, verbose_progress=True)\n",
180
+ "print(\"PettingZoo API test passed.\")\n"
181
  ]
182
  },
183
  {
184
  "cell_type": "markdown",
 
185
  "metadata": {},
186
  "source": [
187
+ "## 4. Run The Rule-Based Multi-Agent Trainer\n",
 
 
 
188
  "\n",
189
+ "This section exercises the multi-agent training loop without needing the GRPO stack.\n"
 
 
 
190
  ]
191
  },
192
  {
193
  "cell_type": "code",
194
  "execution_count": null,
 
195
  "metadata": {},
196
  "outputs": [],
197
  "source": [
198
  "from training.train_multi_agent import train\n",
199
  "\n",
200
  "metrics = train(\n",
201
+ " n_episodes=40,\n",
202
+ " max_steps_ep=150,\n",
203
  " gamma=0.99,\n",
204
  " alternating_freq=10,\n",
205
  " difficulty=\"easy\",\n",
206
  " output_dir=\"outputs/multi_agent\",\n",
207
+ " save_every=20,\n",
208
+ ")\n",
 
 
 
 
 
 
 
 
209
  "\n",
210
+ "print(f\"Saved training outputs to: {Path('outputs/multi_agent').resolve()}\")\n",
211
+ "if metrics.get(\"trader_return\"):\n",
212
+ " print(f\"Final trader return: {metrics['trader_return'][-1]:+.4f}\")\n",
213
+ "if metrics.get(\"grade\"):\n",
214
+ " print(f\"Final grade: {metrics['grade'][-1]:.4f}\")\n"
215
  ]
216
  },
217
  {
218
  "cell_type": "code",
219
  "execution_count": null,
 
220
  "metadata": {},
221
  "outputs": [],
222
  "source": [
223
+ "from training.train_multi_agent import (\n",
224
+ " RulePortfolioManagerPolicy,\n",
225
+ " RuleRiskManagerPolicy,\n",
226
+ " RuleTraderPolicy,\n",
227
+ " collect_rollout,\n",
228
+ ")\n",
229
  "\n",
230
  "policies = {\n",
231
+ " RISK_MANAGER: RuleRiskManagerPolicy(),\n",
232
  " PORTFOLIO_MGR: RulePortfolioManagerPolicy(),\n",
233
+ " TRADER: RuleTraderPolicy(),\n",
234
  "}\n",
235
  "\n",
236
+ "print(f\"{'Difficulty':<12} {'Episodes':<10} {'Mean Trader Return':<22} {'Mean PnL':<12} {'Mean DD':<12}\")\n",
237
+ "print(\"-\" * 74)\n",
238
  "\n",
239
  "for diff in [\"easy\", \"medium\", \"hard\"]:\n",
240
  " returns, pnls, dds = [], [], []\n",
241
  " test_env = MultiAgentTradingEnv(difficulty=diff, max_steps=100)\n",
242
  " for _ in range(10):\n",
243
  " buffers, info = collect_rollout(test_env, policies, max_steps=100)\n",
244
+ " returns.append(float(np.mean(buffers[TRADER].discounted_returns())))\n",
245
+ " pnls.append(float(info.get(\"pnl_pct\", 0.0)))\n",
246
+ " dds.append(float(info.get(\"max_drawdown\", 0.0)))\n",
247
+ " print(f\"{diff:<12} {10:<10} {np.mean(returns):+.6f} {np.mean(pnls):+.4%} {np.mean(dds):.4%}\")\n"
 
248
  ]
249
  },
250
  {
251
  "cell_type": "markdown",
 
252
  "metadata": {},
253
  "source": [
254
+ "## 5. Generate Training Plots\n",
 
255
  "\n",
256
+ "Use the saved metrics when available, or fall back to the in-memory `metrics` object from the training cell.\n"
257
  ]
258
  },
259
  {
260
  "cell_type": "code",
261
  "execution_count": null,
 
262
  "metadata": {},
263
  "outputs": [],
264
  "source": [
265
+ "import json\n",
266
+ "\n",
267
+ "import matplotlib\n",
268
  "matplotlib.use(\"Agg\")\n",
269
  "import matplotlib.pyplot as plt\n",
270
  "\n",
271
  "metrics_path = Path(\"outputs/multi_agent/metrics_final.json\")\n",
272
+ "metrics_path.parent.mkdir(parents=True, exist_ok=True)\n",
273
+ "\n",
274
  "if not metrics_path.exists():\n",
275
+ " if \"metrics\" not in globals():\n",
276
+ " raise RuntimeError(\"Run the training cell first so metrics are available.\")\n",
277
+ " with open(metrics_path, \"w\", encoding=\"utf-8\") as handle:\n",
278
+ " json.dump(dict(metrics), handle, indent=2)\n",
279
  "\n",
280
+ "with open(metrics_path, encoding=\"utf-8\") as handle:\n",
281
+ " m = json.load(handle)\n",
282
  "\n",
283
+ "plots_dir = Path(\"plots\")\n",
284
+ "plots_dir.mkdir(parents=True, exist_ok=True)\n",
285
  "episodes = m[\"episode\"]\n",
286
  "n_eps = len(episodes)\n",
287
  "print(f\"Loaded {n_eps} episodes from {metrics_path}\")\n",
288
  "\n",
289
+ "def smooth(values, window=10):\n",
290
+ " if len(values) < window:\n",
291
+ " return values\n",
292
+ " kernel = np.ones(window) / window\n",
293
+ " return np.convolve(values, kernel, mode=\"valid\").tolist()\n",
294
  "\n",
295
+ "window = max(1, n_eps // 15)\n",
296
+ "\n",
297
+ "fig, ax = plt.subplots(figsize=(12, 6))\n",
298
+ "trader_s = smooth(m[\"trader_return\"], window)\n",
299
+ "rm_s = smooth(m[\"rm_return\"], window)\n",
300
+ "pm_s = smooth(m[\"pm_return\"], window)\n",
301
+ "ep_s = episodes[: len(trader_s)]\n",
302
+ "ax.plot(ep_s, trader_s, label=\"Trader\", color=\"#2ecc71\", linewidth=2)\n",
303
+ "ax.plot(ep_s, rm_s, label=\"Risk Manager\", color=\"#e74c3c\", linewidth=2)\n",
304
+ "ax.plot(ep_s, pm_s, label=\"Portfolio Manager\", color=\"#3498db\", linewidth=2)\n",
305
+ "ax.set_xlabel(\"Episode\")\n",
306
+ "ax.set_ylabel(\"Discounted Return\")\n",
307
+ "ax.set_title(\"QuantHive: Per-Agent Reward Curves\")\n",
308
+ "ax.legend()\n",
309
+ "ax.grid(True, alpha=0.3)\n",
 
310
  "plt.tight_layout()\n",
311
+ "fig.savefig(plots_dir / \"reward_curve.png\", dpi=150)\n",
312
  "plt.show()\n",
 
313
  "\n",
 
314
  "fig2, ax2 = plt.subplots(figsize=(12, 6))\n",
315
+ "pnl_s = smooth(m[\"pnl_pct\"], window)\n",
 
 
316
  "pnl_arr = np.array(pnl_s)\n",
317
+ "ax2.plot(episodes[: len(pnl_s)], pnl_s, color=\"#e74c3c\", linewidth=2)\n",
318
+ "ax2.axhline(y=0, color=\"gray\", linestyle=\"--\", alpha=0.5)\n",
319
+ "ax2.fill_between(episodes[: len(pnl_s)], 0, pnl_s, where=pnl_arr > 0, color=\"#2ecc71\", alpha=0.2)\n",
320
+ "ax2.fill_between(episodes[: len(pnl_s)], 0, pnl_s, where=pnl_arr <= 0, color=\"#e74c3c\", alpha=0.2)\n",
321
+ "ax2.set_xlabel(\"Episode\")\n",
322
+ "ax2.set_ylabel(\"PnL %\")\n",
323
+ "ax2.set_title(\"QuantHive: PnL Over Training\")\n",
324
  "ax2.grid(True, alpha=0.3)\n",
325
  "plt.tight_layout()\n",
326
+ "fig2.savefig(plots_dir / \"loss_curve.png\", dpi=150)\n",
327
  "plt.show()\n",
 
328
  "\n",
 
329
  "if n_eps >= 20:\n",
330
  " fig3, ax3 = plt.subplots(figsize=(10, 6))\n",
331
  " names = [\"Trader Return\", \"Grade\", \"Max Drawdown\", \"Sharpe\"]\n",
332
+ " early = [np.mean(m[key][:20]) for key in [\"trader_return\", \"grade\", \"max_drawdown\", \"sharpe\"]]\n",
333
+ " late = [np.mean(m[key][-20:]) for key in [\"trader_return\", \"grade\", \"max_drawdown\", \"sharpe\"]]\n",
334
  " x = np.arange(len(names))\n",
335
  " ax3.bar(x - 0.175, early, 0.35, label=\"First 20 eps\", color=\"#e74c3c\", alpha=0.8)\n",
336
+ " ax3.bar(x + 0.175, late, 0.35, label=\"Last 20 eps\", color=\"#2ecc71\", alpha=0.8)\n",
337
+ " ax3.set_ylabel(\"Value\")\n",
338
+ " ax3.set_title(\"QuantHive: Early vs Late Training\")\n",
339
+ " ax3.set_xticks(x)\n",
340
+ " ax3.set_xticklabels(names)\n",
341
+ " ax3.legend()\n",
342
+ " ax3.grid(True, alpha=0.3, axis=\"y\")\n",
343
  " plt.tight_layout()\n",
344
+ " fig3.savefig(plots_dir / \"baseline_comparison.png\", dpi=150)\n",
345
  " plt.show()\n",
346
+ "\n",
347
+ "print(f\"Saved plots to: {plots_dir.resolve()}\")\n"
348
  ]
349
  },
350
  {
351
  "cell_type": "markdown",
 
352
  "metadata": {},
353
  "source": [
354
+ "## 6. Preview Governance-Aware GRPO Prompts\n",
 
355
  "\n",
356
+ "Import only the lightweight prompt helpers so prompt preview does not depend on the trainer stack.\n"
 
357
  ]
358
  },
359
  {
360
  "cell_type": "code",
361
  "execution_count": null,
 
362
  "metadata": {},
363
  "outputs": [],
364
  "source": [
365
+ "from training.prompt_utils import build_prompt_multiagent, generate_pz_scenarios\n",
366
  "\n",
367
  "scenarios = generate_pz_scenarios(n=3, difficulty=\"easy\", max_env_steps=30)\n",
368
+ "print(f\"Generated {len(scenarios)} scenarios.\\n\")\n",
369
+ "\n",
370
+ "for i, scenario in enumerate(scenarios, start=1):\n",
371
+ " print(\"=\" * 60)\n",
372
+ " print(f\"Scenario {i}\")\n",
373
+ " print(\n",
374
+ " f\"RM: size_limit={scenario['rm_size_limit']:.2f}, \"\n",
375
+ " f\"allow_new={scenario['rm_allow_new']}, force_reduce={scenario['rm_force_reduce']}\"\n",
376
+ " )\n",
377
+ " print(\n",
378
+ " f\"PM: cap_alloc={scenario['pm_cap_alloc']:.2f}, \"\n",
379
+ " f\"override={scenario['pm_override']:.2f}\"\n",
380
+ " )\n",
381
+ "\n",
382
+ "print(\"\\nFull prompt for scenario 1\")\n",
383
+ "print(\"=\" * 60)\n",
384
+ "print(build_prompt_multiagent(scenarios[0]))\n"
385
  ]
386
  },
387
  {
388
  "cell_type": "code",
389
  "execution_count": null,
 
390
  "metadata": {},
391
  "outputs": [],
392
  "source": [
393
+ "from env.reward import alignment_reward_func, format_reward_func, profit_reward_func\n",
394
+ "from training.grpo_verifiers_multiagent import (\n",
 
395
  " governance_reward_func_multiagent,\n",
396
+ " risk_reward_func_multiagent,\n",
397
  ")\n",
 
398
  "\n",
399
  "test_prompt = build_prompt_multiagent(scenarios[0])\n",
400
+ "effective_limit = min(scenarios[0][\"rm_size_limit\"], scenarios[0][\"pm_cap_alloc\"])\n",
401
  "\n",
 
402
  "compliant = (\n",
403
+ " \"<thought>\\n\"\n",
404
+ " \"Risk limits are active, so I will stay inside both the RM size cap and the PM allocation.\\n\"\n",
405
+ " \"</thought>\\n\"\n",
406
+ " \"<action>\\n\"\n",
407
+ " '{{\"direction\": 1, \"size\": %.2f, \"sl\": 49000, \"tp\": 52000}}\\n'\n",
408
+ " \"</action>\"\n",
409
+ ") % (effective_limit * 0.7)\n",
 
 
 
 
 
 
 
410
  "\n",
 
411
  "reckless = (\n",
412
+ " \"<thought>\\nMarket is moving up. Going all in.\\n</thought>\\n\"\n",
413
+ " \"<action>\\n{\\\"direction\\\": 1, \\\"size\\\": 0.95, \\\"sl\\\": 0, \\\"tp\\\": 0}\\n</action>\"\n",
414
  ")\n",
415
  "\n",
416
  "prompts = [test_prompt, test_prompt]\n",
 
419
  "print(f\"{'Verifier':<25} {'Compliant':<12} {'Reckless':<12}\")\n",
420
  "print(\"-\" * 49)\n",
421
  "for name, func in [\n",
422
+ " (\"Format\", format_reward_func),\n",
423
+ " (\"Alignment\", alignment_reward_func),\n",
424
+ " (\"Risk\", risk_reward_func_multiagent),\n",
425
+ " (\"Profit\", profit_reward_func),\n",
426
+ " (\"Governance\", governance_reward_func_multiagent),\n",
427
  "]:\n",
428
  " scores = func(prompts, completions)\n",
429
+ " print(f\"{name:<25} {scores[0]:<12.2f} {scores[1]:<12.2f}\")\n"
430
  ]
431
  },
432
  {
433
  "cell_type": "markdown",
 
434
  "metadata": {},
435
  "source": [
436
+ "## 7. Optional GPU GRPO Training\n",
 
 
 
 
 
437
  "\n",
438
+ "Leave `RUN_GRPO` as `False` unless the runtime has a GPU and you want to install the heavier TRL and Unsloth stack.\n"
 
 
 
 
 
439
  ]
440
  },
441
  {
442
  "cell_type": "code",
443
  "execution_count": null,
 
444
  "metadata": {},
445
  "outputs": [],
446
  "source": [
447
+ "RUN_GRPO = False\n",
448
+ "\n",
449
+ "if not RUN_GRPO:\n",
450
+ " print(\"Set RUN_GRPO = True after enabling a GPU runtime.\")\n",
451
+ "else:\n",
452
+ " import json\n",
453
+ " import torch\n",
454
+ " from types import SimpleNamespace\n",
455
+ "\n",
456
+ " assert torch.cuda.is_available(), \"GPU required for GRPO training\"\n",
457
+ " print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n",
458
+ "\n",
459
+ " extra_packages = [\n",
460
+ " \"datasets\",\n",
461
+ " \"transformers\",\n",
462
+ " \"trl\",\n",
463
+ " \"peft\",\n",
464
+ " \"accelerate\",\n",
465
+ " \"safetensors\",\n",
466
+ " \"unsloth\",\n",
467
+ " ]\n",
468
+ " subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", *extra_packages])\n",
469
+ "\n",
470
+ " from datasets import Dataset\n",
471
+ " from training.train_grpo_multiagent import load_model, make_trainer, save_model\n",
472
+ "\n",
473
+ " print(\"Generating 256 scenarios from MultiAgentTradingEnv...\")\n",
474
+ " scenarios = generate_pz_scenarios(n=256, difficulty=\"easy\", max_env_steps=50)\n",
475
+ " prompts = [{\"prompt\": build_prompt_multiagent(sc)} for sc in scenarios]\n",
476
+ " dataset = Dataset.from_list(prompts)\n",
477
+ " print(f\"Dataset: {len(dataset)} prompts\")\n",
478
+ "\n",
479
+ " model, tokenizer = load_model(\n",
480
+ " \"unsloth/Qwen2.5-1.5B-Instruct-bnb-4bit\",\n",
481
+ " max_seq_length=1024,\n",
482
+ " )\n",
483
+ "\n",
484
+ " args = SimpleNamespace(\n",
485
+ " output_dir=\"models/local_policy_grpo_multiagent\",\n",
486
+ " learning_rate=5e-5,\n",
487
+ " per_device_batch_size=2,\n",
488
+ " gradient_accumulation_steps=2,\n",
489
+ " max_steps=100,\n",
490
+ " save_steps=25,\n",
491
+ " logging_steps=1,\n",
492
+ " max_prompt_length=768,\n",
493
+ " max_completion_length=200,\n",
494
+ " num_generations=4,\n",
495
+ " )\n",
496
+ "\n",
497
+ " trainer = make_trainer(model, tokenizer, dataset, args, torch)\n",
498
+ " print(f\"Starting GRPO training ({args.max_steps} steps)...\")\n",
499
+ " trainer.train()\n",
500
+ "\n",
501
+ " history = trainer.state.log_history\n",
502
+ " rewards = [item[\"reward\"] for item in history if \"reward\" in item]\n",
503
+ " losses = [item[\"loss\"] for item in history if \"loss\" in item]\n",
504
+ " if not losses:\n",
505
+ " losses = [0.0]\n",
506
+ "\n",
507
+ " from utils.plotting import plot_training_results\n",
508
+ "\n",
509
+ " plot_training_results(rewards, losses, output_dir=\"plots\")\n",
510
+ " save_model(model, tokenizer, args.output_dir)\n",
511
+ "\n",
512
+ " Path(\"outputs\").mkdir(parents=True, exist_ok=True)\n",
513
+ " with open(\"outputs/grpo_metrics.json\", \"w\", encoding=\"utf-8\") as handle:\n",
514
+ " json.dump({\"rewards\": rewards, \"losses\": losses}, handle, indent=2)\n",
515
+ "\n",
516
+ " print(f\"Model saved to {args.output_dir}\")\n"
517
  ]
518
  },
519
  {
520
  "cell_type": "markdown",
 
521
  "metadata": {},
522
  "source": [
523
+ "## 8. Display Generated Artifacts\n"
 
 
 
524
  ]
525
  },
526
  {
527
  "cell_type": "code",
528
  "execution_count": null,
 
529
  "metadata": {},
530
  "outputs": [],
531
  "source": [
532
+ "from IPython.display import Image, Markdown, display\n",
533
  "\n",
534
  "plot_files = [\n",
535
+ " (\"plots/reward_curve.png\", \"Per-Agent Reward Curves\"),\n",
536
+ " (\"plots/loss_curve.png\", \"PnL Or Policy Loss Curve\"),\n",
537
+ " (\"plots/baseline_comparison.png\", \"Early vs Late Training\"),\n",
 
538
  "]\n",
539
  "\n",
540
  "for path, title in plot_files:\n",
541
  " if Path(path).exists():\n",
542
  " display(Markdown(f\"### {title}\"))\n",
543
+ " display(Image(filename=path, width=700))\n",
544
  " else:\n",
545
+ " print(f\"Missing: {path}\")\n"
546
  ]
547
  },
548
  {
549
  "cell_type": "markdown",
 
550
  "metadata": {},
551
  "source": [
552
+ "## 9. Submission Checklist\n"
 
553
  ]
554
  },
555
  {
556
  "cell_type": "code",
557
  "execution_count": null,
 
558
  "metadata": {},
559
  "outputs": [],
560
  "source": [
561
+ "import yaml\n",
562
+ "\n",
563
  "checks = [\n",
564
+ " (\"openenv.yaml\", Path(\"openenv.yaml\")),\n",
565
+ " (\"README.md\", Path(\"README.md\")),\n",
566
+ " (\"WRITEUP.md\", Path(\"WRITEUP.md\")),\n",
567
+ " (\"Multi-agent env\", Path(\"env/multi_agent_env.py\")),\n",
568
+ " (\"REINFORCE trainer\", Path(\"training/train_multi_agent.py\")),\n",
569
+ " (\"GRPO trainer\", Path(\"training/train_grpo_multiagent.py\")),\n",
570
+ " (\"Prompt helpers\", Path(\"training/prompt_utils.py\")),\n",
571
+ " (\"GRPO verifiers\", Path(\"training/grpo_verifiers_multiagent.py\")),\n",
572
+ " (\"Training notebook\", Path(\"mate_training.ipynb\")),\n",
573
+ " (\"Reward curve\", Path(\"plots/reward_curve.png\")),\n",
574
+ " (\"Loss curve\", Path(\"plots/loss_curve.png\")),\n",
575
+ " (\"Baseline comparison\", Path(\"plots/baseline_comparison.png\")),\n",
576
+ " (\"Dockerfile\", Path(\"Dockerfile\")),\n",
577
+ " (\"requirements-space.txt\", Path(\"requirements-space.txt\")),\n",
578
  "]\n",
579
  "\n",
580
+ "print(f\"{'Deliverable':<28} {'Status':<10}\")\n",
581
+ "print(\"-\" * 42)\n",
582
  "for name, path in checks:\n",
583
+ " status = \"OK\" if path.exists() else \"MISSING\"\n",
584
+ " print(f\"{name:<28} {status:<10}\")\n",
585
+ "\n",
586
+ "with open(\"openenv.yaml\", encoding=\"utf-8\") as handle:\n",
587
+ " manifest = yaml.safe_load(handle)\n",
588
+ "entry_point = manifest.get(\"environment\", {}).get(\"entry_point\", \"\")\n",
589
+ "print(f\"\\nopenenv.yaml entry_point: {entry_point}\")\n",
590
+ "print(f\"PettingZoo env configured: {'multi_agent_env' in entry_point}\")\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
591
  ]
592
  }
593
  ],
594
  "metadata": {
 
 
 
 
 
595
  "kernelspec": {
596
  "display_name": "Python 3",
597
+ "language": "python",
598
  "name": "python3"
599
  },
600
  "language_info": {
601
+ "name": "python"
 
602
  }
603
  },
604
  "nbformat": 4,
605
  "nbformat_minor": 5
606
+ }
training/grpo_verifiers_multiagent.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Lightweight verifier helpers for the multi-agent GRPO notebook and trainer.
3
+
4
+ These functions intentionally avoid importing the training stack so notebooks can
5
+ preview prompts and reward functions without loading model or trainer deps.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import json
11
+ import re
12
+
13
+ import numpy as np
14
+
15
+
16
+ def _extract_json_action(completion: str):
17
+ match = re.search(r"<action>\s*({.*?})\s*</action>", completion, re.DOTALL)
18
+ if not match:
19
+ return None
20
+ return json.loads(match.group(1))
21
+
22
+
23
+ def _extract_signal_value(prompt: str, key: str):
24
+ json_match = re.search(rf'"{key}"\s*:\s*(-?[\d\.]+)', prompt)
25
+ if json_match:
26
+ return float(json_match.group(1))
27
+
28
+ plain_match = re.search(rf"{key}\s*[:=]\s*(-?[\d\.]+)", prompt)
29
+ if plain_match:
30
+ return float(plain_match.group(1))
31
+
32
+ return None
33
+
34
+
35
+ def risk_reward_func_multiagent(prompts, completions, **kwargs) -> list[float]:
36
+ """Read the Risk Manager limit from the prompt and reward compliant sizing."""
37
+
38
+ rewards = []
39
+ for prompt, completion in zip(prompts, completions):
40
+ try:
41
+ limit = _extract_signal_value(prompt, "rm_size_limit")
42
+ if limit is None:
43
+ limit = _extract_signal_value(prompt, "position_limit")
44
+ if limit is None:
45
+ limit = 1.0
46
+
47
+ data = _extract_json_action(completion)
48
+ if data is None:
49
+ rewards.append(0.0)
50
+ continue
51
+
52
+ size = float(data.get("size", 0.0))
53
+ score = 0.7 if size <= limit else 0.0
54
+
55
+ try:
56
+ thought = completion.split("<thought>")[1].split("</thought>")[0].lower()
57
+ if any(kw in thought for kw in ["risk", "limit", "constraint", "size_limit"]):
58
+ score += 0.3
59
+ except (IndexError, AttributeError):
60
+ pass
61
+
62
+ rewards.append(score)
63
+ except Exception:
64
+ rewards.append(0.0)
65
+
66
+ return rewards
67
+
68
+
69
+ def governance_reward_func_multiagent(prompts, completions, **kwargs) -> list[float]:
70
+ """Score compliance against both Risk Manager and Portfolio Manager limits."""
71
+
72
+ rewards = []
73
+ for prompt, completion in zip(prompts, completions):
74
+ try:
75
+ data = _extract_json_action(completion)
76
+ if data is None:
77
+ rewards.append(0.0)
78
+ continue
79
+
80
+ size = float(data.get("size", 0.0))
81
+ direction = int(data.get("direction", 0))
82
+
83
+ limit = _extract_signal_value(prompt, "rm_size_limit")
84
+ if limit is None:
85
+ limit = _extract_signal_value(prompt, "position_limit")
86
+ if limit is None:
87
+ limit = 1.0
88
+
89
+ pm_cap = _extract_signal_value(prompt, "pm_cap_alloc")
90
+ effective_limit = min(limit, pm_cap) if pm_cap is not None else limit
91
+
92
+ score = 0.0
93
+ if size <= effective_limit:
94
+ score += 0.40
95
+ if 0 < size <= effective_limit * 0.8:
96
+ score += 0.20
97
+ else:
98
+ score -= 0.50
99
+
100
+ try:
101
+ thought = completion.split("<thought>")[1].split("</thought>")[0].lower()
102
+ governance_keywords = [
103
+ "risk",
104
+ "limit",
105
+ "constraint",
106
+ "compliance",
107
+ "conservative",
108
+ "governance",
109
+ "restrict",
110
+ "drawdown",
111
+ "cap",
112
+ "position limit",
113
+ "size_limit",
114
+ "risk manager",
115
+ "portfolio manager",
116
+ "allocation",
117
+ ]
118
+ if any(kw in thought for kw in governance_keywords):
119
+ score += 0.20
120
+ except (IndexError, AttributeError):
121
+ pass
122
+
123
+ if direction != 0:
124
+ score += 0.20
125
+
126
+ rewards.append(float(np.clip(score, 0.0, 1.0)))
127
+ except Exception:
128
+ rewards.append(0.0)
129
+
130
+ return rewards
131
+
132
+
133
+ __all__ = [
134
+ "governance_reward_func_multiagent",
135
+ "risk_reward_func_multiagent",
136
+ ]
training/train_grpo_multiagent.py CHANGED
@@ -1,178 +1,54 @@
1
  """
2
- PettingZoo-Compatible GRPO Training Pipeline for Qwen 2.5.
3
 
4
- Uses MultiAgentTradingEnv to generate scenarios where RM and PM
5
- send governance messages that become part of the Trader's prompt.
6
- The Trader is trained as a Qwen 2.5-1.5B model via Unsloth + TRL GRPOTrainer.
7
-
8
- RM and PM use rule-based policies during Trader training (alternating opt.).
9
  """
10
 
11
  from __future__ import annotations
12
 
13
- import os
14
- os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
15
- os.environ.setdefault("OMP_NUM_THREADS", "1")
16
-
17
  import argparse
18
  import inspect
19
  import json
 
20
  import random
21
  import sys
22
  from pathlib import Path
23
- from typing import Dict, List
24
 
25
  import numpy as np
 
 
 
 
26
 
27
  ROOT = Path(__file__).resolve().parents[1]
28
  if str(ROOT) not in sys.path:
29
  sys.path.insert(0, str(ROOT))
30
 
31
- from datasets import Dataset
32
-
33
- from env.multi_agent_env import (
34
- MultiAgentTradingEnv,
35
- RISK_MANAGER,
36
- PORTFOLIO_MGR,
37
- TRADER,
38
- )
39
  from env.reward import (
40
- format_reward_func,
41
  alignment_reward_func,
 
42
  profit_reward_func,
43
  )
44
- from training.train_multi_agent import (
45
- RuleRiskManagerPolicy,
46
- RulePortfolioManagerPolicy,
 
 
 
 
 
47
  )
48
 
49
 
50
- # ─── Constants ─────────────────────────────────────────────────────────────────
51
-
52
  DEFAULT_MODEL_NAME = "unsloth/Qwen2.5-1.5B-Instruct-bnb-4bit"
53
  DEFAULT_OUTPUT_DIR = "models/local_policy_grpo_multiagent"
54
 
55
- from training.prompt_utils import SYSTEM_PROMPT, generate_pz_scenarios, build_prompt_multiagent
56
-
57
-
58
-
59
- # ─── Updated GRPO Verifiers ───────────────────────────────────────────────────
60
-
61
- def _extract_json_action(completion: str):
62
- import re
63
- match = re.search(r"<action>\s*({.*?})\s*</action>", completion, re.DOTALL)
64
- if not match:
65
- return None
66
- return json.loads(match.group(1))
67
-
68
-
69
- def _extract_signal_value(prompt: str, key: str):
70
- import re
71
- json_match = re.search(rf'"{key}"\s*:\s*(-?[\d\.]+)', prompt)
72
- if json_match:
73
- return float(json_match.group(1))
74
- plain_match = re.search(rf"{key}\s*[:=]\s*(-?[\d\.]+)", prompt)
75
- if plain_match:
76
- return float(plain_match.group(1))
77
- return None
78
-
79
-
80
- def risk_reward_func_multiagent(prompts, completions, **kwargs) -> list[float]:
81
- """Updated risk verifier: reads RM's dynamic size_limit from the prompt."""
82
- rewards = []
83
- for prompt, completion in zip(prompts, completions):
84
- try:
85
- # Read RM's size_limit from the governance block
86
- limit = _extract_signal_value(prompt, "rm_size_limit")
87
- if limit is None:
88
- limit = _extract_signal_value(prompt, "position_limit")
89
- if limit is None:
90
- limit = 1.0
91
-
92
- data = _extract_json_action(completion)
93
- if data is not None:
94
- size = float(data.get("size", 0.0))
95
- score = 0.7 if size <= limit else 0.0
96
-
97
- try:
98
- thought = completion.split("<thought>")[1].split("</thought>")[0].lower()
99
- if any(kw in thought for kw in ["risk", "limit", "constraint", "size_limit"]):
100
- score += 0.3
101
- except (IndexError, AttributeError):
102
- pass
103
- rewards.append(score)
104
- else:
105
- rewards.append(0.0)
106
- except Exception:
107
- rewards.append(0.0)
108
- return rewards
109
-
110
-
111
- def governance_reward_func_multiagent(prompts, completions, **kwargs) -> list[float]:
112
- """Updated governance verifier: checks compliance against *learned* RM constraints.
113
-
114
- The key differentiator: the governance verifier now checks compliance against
115
- RM's size_limit from the prompt, not a hardcoded position_limit.
116
- """
117
- rewards = []
118
- for prompt, completion in zip(prompts, completions):
119
- try:
120
- data = _extract_json_action(completion)
121
- if data is None:
122
- rewards.append(0.0)
123
- continue
124
-
125
- size = float(data.get("size", 0.0))
126
- direction = int(data.get("direction", 0))
127
-
128
- # Use RM's dynamic limit
129
- limit = _extract_signal_value(prompt, "rm_size_limit")
130
- if limit is None:
131
- limit = _extract_signal_value(prompt, "position_limit")
132
- if limit is None:
133
- limit = 1.0
134
-
135
- # Also check PM cap
136
- pm_cap = _extract_signal_value(prompt, "pm_cap_alloc")
137
- effective_limit = min(limit, pm_cap) if pm_cap is not None else limit
138
-
139
- score = 0.0
140
-
141
- # Core compliance: within both RM limit and PM cap
142
- if size <= effective_limit:
143
- score += 0.40
144
- if 0 < size <= effective_limit * 0.8:
145
- score += 0.20
146
- else:
147
- score -= 0.50
148
-
149
- # Reasoning quality: governance-aware language
150
- try:
151
- thought = completion.split("<thought>")[1].split("</thought>")[0].lower()
152
- governance_keywords = [
153
- "risk", "limit", "constraint", "compliance", "conservative",
154
- "governance", "restrict", "drawdown", "cap", "position limit",
155
- "size_limit", "risk manager", "portfolio manager", "allocation",
156
- ]
157
- if any(kw in thought for kw in governance_keywords):
158
- score += 0.20
159
- except (IndexError, AttributeError):
160
- pass
161
-
162
- # Activity bonus
163
- if direction != 0:
164
- score += 0.20
165
-
166
- rewards.append(float(np.clip(score, 0.0, 1.0)))
167
- except Exception:
168
- rewards.append(0.0)
169
- return rewards
170
-
171
-
172
- # ─── Model Loading ─────────────────────────────────────────────────────────────
173
 
174
  def require_cuda():
175
  import torch
 
176
  if not torch.cuda.is_available():
177
  raise SystemExit("GRPO training requires CUDA.")
178
  return torch
@@ -180,6 +56,7 @@ def require_cuda():
180
 
181
  def load_model(model_name: str, max_seq_length: int):
182
  from unsloth import FastLanguageModel, PatchFastRL
 
183
  PatchFastRL("GRPO", "unsloth")
184
 
185
  model, tokenizer = FastLanguageModel.from_pretrained(
@@ -191,8 +68,15 @@ def load_model(model_name: str, max_seq_length: int):
191
  model = FastLanguageModel.get_peft_model(
192
  model,
193
  r=16,
194
- target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
195
- "gate_proj", "up_proj", "down_proj"],
 
 
 
 
 
 
 
196
  lora_alpha=16,
197
  lora_dropout=0,
198
  bias="none",
@@ -205,8 +89,6 @@ def load_model(model_name: str, max_seq_length: int):
205
  return model, tokenizer
206
 
207
 
208
- # ─── Trainer ───────────────────────────────────────────────────────────────────
209
-
210
  def make_trainer(model, tokenizer, dataset, args, torch_module):
211
  from trl.trainer.grpo_config import GRPOConfig
212
  from trl.trainer.grpo_trainer import GRPOTrainer
@@ -231,9 +113,9 @@ def make_trainer(model, tokenizer, dataset, args, torch_module):
231
  reward_funcs = [
232
  format_reward_func,
233
  alignment_reward_func,
234
- risk_reward_func_multiagent, # Updated: reads RM's dynamic size_limit
235
  profit_reward_func,
236
- governance_reward_func_multiagent, # Updated: checks compliance vs learned RM constraints
237
  ]
238
 
239
  trainer_kwargs = {
@@ -261,10 +143,8 @@ def save_model(model, tokenizer, output_dir: str) -> None:
261
  tokenizer.save_pretrained(output_dir)
262
 
263
 
264
- # ─── CLI ───────────────────────────────────────────────────────────────────────
265
-
266
  def parse_args():
267
- parser = argparse.ArgumentParser(description="Multi-Agent GRPO Training for Trader (Qwen 2.5)")
268
  parser.add_argument("--model-name", default=DEFAULT_MODEL_NAME)
269
  parser.add_argument("--output-dir", default=DEFAULT_OUTPUT_DIR)
270
  parser.add_argument("--difficulty", choices=["easy", "medium", "hard"], default="easy")
@@ -288,43 +168,40 @@ def main():
288
  random.seed(args.seed)
289
  np.random.seed(args.seed)
290
 
291
- # 1. Generate scenarios from the PettingZoo env
292
- print(f"Generating {args.num_scenarios} scenarios from MultiAgentTradingEnv (difficulty={args.difficulty})...")
 
 
293
  scenarios = generate_pz_scenarios(n=args.num_scenarios, difficulty=args.difficulty)
294
  print(f" Generated {len(scenarios)} scenarios.")
295
 
296
- # 2. Build dataset
297
  prompts = [{"prompt": build_prompt_multiagent(sc)} for sc in scenarios]
298
  dataset = Dataset.from_list(prompts)
299
 
300
- # 3. Load model
301
  torch_module = require_cuda()
302
  model, tokenizer = load_model(args.model_name, args.max_seq_length)
303
 
304
- # 4. Train
305
  trainer = make_trainer(model, tokenizer, dataset, args, torch_module)
306
  print(f"Starting multi-agent GRPO training on {len(dataset)} prompts...")
307
- train_result = trainer.train()
308
 
309
- # 5. Generate plots
310
  history = trainer.state.log_history
311
  rewards = [x["reward"] for x in history if "reward" in x]
312
  losses = [x["loss"] for x in history if "loss" in x]
313
 
314
  try:
315
  from utils.plotting import plot_training_results
 
316
  plot_training_results(rewards, losses)
317
- except Exception as e:
318
- print(f" Warning: could not generate plots: {e}")
319
 
320
- # 6. Save model
321
  print(f"Saving GRPO policy to {args.output_dir}...")
322
  save_model(model, tokenizer, args.output_dir)
323
 
324
- # 7. Save training metrics
325
  metrics_path = Path(args.output_dir) / "training_metrics.json"
326
- with open(metrics_path, "w") as f:
327
- json.dump({"rewards": rewards, "losses": losses}, f, indent=2)
328
 
329
  print("Multi-agent GRPO training complete.")
330
  print(f" Model saved to: {args.output_dir}")
 
1
  """
2
+ PettingZoo-compatible GRPO training pipeline for Qwen 2.5.
3
 
4
+ Uses MultiAgentTradingEnv-derived scenarios where the Risk Manager and
5
+ Portfolio Manager send governance messages that become part of the Trader
6
+ prompt. The Trader is then trained with Unsloth + TRL GRPOTrainer.
 
 
7
  """
8
 
9
  from __future__ import annotations
10
 
 
 
 
 
11
  import argparse
12
  import inspect
13
  import json
14
+ import os
15
  import random
16
  import sys
17
  from pathlib import Path
 
18
 
19
  import numpy as np
20
+ from datasets import Dataset
21
+
22
+ os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
23
+ os.environ.setdefault("OMP_NUM_THREADS", "1")
24
 
25
  ROOT = Path(__file__).resolve().parents[1]
26
  if str(ROOT) not in sys.path:
27
  sys.path.insert(0, str(ROOT))
28
 
 
 
 
 
 
 
 
 
29
  from env.reward import (
 
30
  alignment_reward_func,
31
+ format_reward_func,
32
  profit_reward_func,
33
  )
34
+ from training.grpo_verifiers_multiagent import (
35
+ governance_reward_func_multiagent,
36
+ risk_reward_func_multiagent,
37
+ )
38
+ from training.prompt_utils import (
39
+ SYSTEM_PROMPT,
40
+ build_prompt_multiagent,
41
+ generate_pz_scenarios,
42
  )
43
 
44
 
 
 
45
  DEFAULT_MODEL_NAME = "unsloth/Qwen2.5-1.5B-Instruct-bnb-4bit"
46
  DEFAULT_OUTPUT_DIR = "models/local_policy_grpo_multiagent"
47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
  def require_cuda():
50
  import torch
51
+
52
  if not torch.cuda.is_available():
53
  raise SystemExit("GRPO training requires CUDA.")
54
  return torch
 
56
 
57
  def load_model(model_name: str, max_seq_length: int):
58
  from unsloth import FastLanguageModel, PatchFastRL
59
+
60
  PatchFastRL("GRPO", "unsloth")
61
 
62
  model, tokenizer = FastLanguageModel.from_pretrained(
 
68
  model = FastLanguageModel.get_peft_model(
69
  model,
70
  r=16,
71
+ target_modules=[
72
+ "q_proj",
73
+ "k_proj",
74
+ "v_proj",
75
+ "o_proj",
76
+ "gate_proj",
77
+ "up_proj",
78
+ "down_proj",
79
+ ],
80
  lora_alpha=16,
81
  lora_dropout=0,
82
  bias="none",
 
89
  return model, tokenizer
90
 
91
 
 
 
92
  def make_trainer(model, tokenizer, dataset, args, torch_module):
93
  from trl.trainer.grpo_config import GRPOConfig
94
  from trl.trainer.grpo_trainer import GRPOTrainer
 
113
  reward_funcs = [
114
  format_reward_func,
115
  alignment_reward_func,
116
+ risk_reward_func_multiagent,
117
  profit_reward_func,
118
+ governance_reward_func_multiagent,
119
  ]
120
 
121
  trainer_kwargs = {
 
143
  tokenizer.save_pretrained(output_dir)
144
 
145
 
 
 
146
  def parse_args():
147
+ parser = argparse.ArgumentParser(description="Multi-agent GRPO training for Trader (Qwen 2.5)")
148
  parser.add_argument("--model-name", default=DEFAULT_MODEL_NAME)
149
  parser.add_argument("--output-dir", default=DEFAULT_OUTPUT_DIR)
150
  parser.add_argument("--difficulty", choices=["easy", "medium", "hard"], default="easy")
 
168
  random.seed(args.seed)
169
  np.random.seed(args.seed)
170
 
171
+ print(
172
+ f"Generating {args.num_scenarios} scenarios from MultiAgentTradingEnv "
173
+ f"(difficulty={args.difficulty})..."
174
+ )
175
  scenarios = generate_pz_scenarios(n=args.num_scenarios, difficulty=args.difficulty)
176
  print(f" Generated {len(scenarios)} scenarios.")
177
 
 
178
  prompts = [{"prompt": build_prompt_multiagent(sc)} for sc in scenarios]
179
  dataset = Dataset.from_list(prompts)
180
 
 
181
  torch_module = require_cuda()
182
  model, tokenizer = load_model(args.model_name, args.max_seq_length)
183
 
 
184
  trainer = make_trainer(model, tokenizer, dataset, args, torch_module)
185
  print(f"Starting multi-agent GRPO training on {len(dataset)} prompts...")
186
+ trainer.train()
187
 
 
188
  history = trainer.state.log_history
189
  rewards = [x["reward"] for x in history if "reward" in x]
190
  losses = [x["loss"] for x in history if "loss" in x]
191
 
192
  try:
193
  from utils.plotting import plot_training_results
194
+
195
  plot_training_results(rewards, losses)
196
+ except Exception as exc:
197
+ print(f" Warning: could not generate plots: {exc}")
198
 
 
199
  print(f"Saving GRPO policy to {args.output_dir}...")
200
  save_model(model, tokenizer, args.output_dir)
201
 
 
202
  metrics_path = Path(args.output_dir) / "training_metrics.json"
203
+ with open(metrics_path, "w", encoding="utf-8") as handle:
204
+ json.dump({"rewards": rewards, "losses": losses}, handle, indent=2)
205
 
206
  print("Multi-agent GRPO training complete.")
207
  print(f" Model saved to: {args.output_dir}")