{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# \ud83d\udc1d QuantHive \u2014 Multi-Agent GRPO Training Notebook\n",
"\n",
"**Project:** QuantHive \u2014 Decentralized Multi-Agent Trading Governance \n",
"**Model:** Qwen 2.5 1.5B (4-bit quantized via Unsloth) \n",
"**Method:** GRPO (Group Relative Policy Optimization) with 5 independent reward verifiers \n",
"**Framework:** PettingZoo AEC + OpenEnv + TRL + Unsloth \n",
"\n",
"This notebook trains the **Trader agent** to propose compliant, profitable trades while respecting governance constraints set by the **Risk Manager** and **Portfolio Manager** agents.\n",
"\n",
"---\n",
"\n",
"## Architecture\n",
"\n",
"```\n",
"\u250c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2510 constraints \u250c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2510 allocation \u250c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2510\n",
"\u2502 Risk Manager \u2502 \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u25ba\u2502 Portfolio Manager \u2502\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u25ba\u2502 Trader \u2502\n",
"\u2502 (rule-based) \u2502 size_limit, \u2502 (rule-based) \u2502 cap_alloc, \u2502 (LLM) \u2502\n",
"\u2502 \u2502 allow_new, \u2502 \u2502 override \u2502 Qwen2.5 \u2502\n",
"\u2502 \u2502 force_reduce \u2502 \u2502 \u2502 1.5B \u2502\n",
"\u2514\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2518 \u2514\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2518 \u2514\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2518\n",
" \u25b2 \u25b2 \u2502\n",
" \u2502 PettingZoo AEC Turn Order \u2502\n",
" \u2514\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500 Market Feedback \u25c4\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2518\n",
"```\n",
"\n",
"**5 Reward Verifiers:**\n",
"1. `format_reward_func` \u2014 Enforces `......` format\n",
"2. `alignment_reward_func` \u2014 Anti-hallucination: reasoning must match market signals\n",
"3. `risk_reward_func_multiagent` \u2014 Compliance with Risk Manager size limits\n",
"4. `profit_reward_func` \u2014 Directional correctness vs. price trend\n",
"5. `governance_reward_func_multiagent` \u2014 Full governance compliance (RM + PM limits)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Environment Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# \u2500\u2500\u2500 Check GPU availability \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
"import torch\n",
"if not torch.cuda.is_available():\n",
" raise RuntimeError(\n",
" \"\u274c GPU not available. Go to Runtime \u2192 Change runtime type \u2192 T4 GPU\"\n",
" )\n",
"gpu = torch.cuda.get_device_name(0)\n",
"mem = torch.cuda.get_device_properties(0).total_memory / 1e9\n",
"print(f\"\u2705 GPU: {gpu} | VRAM: {mem:.1f} GB\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# \u2500\u2500\u2500 Install Unsloth (must come first \u2014 patches TRL internals) \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
"%pip install --no-deps \"unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git\"\n",
"%pip install --no-deps unsloth_zoo xformers trl peft accelerate bitsandbytes triton\n",
"%pip install datasets transformers sentencepiece protobuf openenv\n",
"%pip install pettingzoo>=1.24.0 gymnasium pandas numpy matplotlib"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# \u2500\u2500\u2500 Clone the QuantHive repository \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
"import os\n",
"\n",
"REPO_URL = \"https://github.com/ARKAISW/multi-agent-trading-env.git\"\n",
"REPO_DIR = \"/content/multi-agent-trading-env\"\n",
"\n",
"if os.path.exists(REPO_DIR):\n",
" print(f\"\ud83d\udcc1 Repo already cloned at {REPO_DIR}\")\n",
"else:\n",
" !git clone {REPO_URL} {REPO_DIR}\n",
" print(f\"\u2705 Cloned to {REPO_DIR}\")\n",
"\n",
"os.chdir(REPO_DIR)\n",
"print(f\"\ud83d\udcc2 Working directory: {os.getcwd()}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# \u2500\u2500\u2500 Add repo root to Python path so all imports work \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
"import sys\n",
"sys.path.insert(0, REPO_DIR)\n",
"\n",
"# Verify critical imports\n",
"from env.multi_agent_env import MultiAgentTradingEnv, RISK_MANAGER, PORTFOLIO_MGR, TRADER\n",
"from env.reward import format_reward_func, alignment_reward_func, profit_reward_func\n",
"from training.grpo_verifiers_multiagent import (\n",
" governance_reward_func_multiagent,\n",
" risk_reward_func_multiagent,\n",
")\n",
"from training.prompt_utils import (\n",
" SYSTEM_PROMPT,\n",
" build_prompt_multiagent,\n",
" generate_pz_scenarios,\n",
")\n",
"print(\"\u2705 All project imports successful\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. PettingZoo Environment Demo\n",
"\n",
"Quick sanity check \u2014 run the multi-agent AEC loop with rule-based policies."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from training.train_multi_agent import (\n",
" RuleRiskManagerPolicy,\n",
" RulePortfolioManagerPolicy,\n",
" RuleTraderPolicy,\n",
" collect_rollout,\n",
")\n",
"\n",
"# Create environment instance\n",
"env = MultiAgentTradingEnv(difficulty=\"easy\", max_steps=50)\n",
"policies = {\n",
" RISK_MANAGER: RuleRiskManagerPolicy(),\n",
" PORTFOLIO_MGR: RulePortfolioManagerPolicy(),\n",
" TRADER: RuleTraderPolicy(),\n",
"}\n",
"\n",
"# Run one episode\n",
"buffers, info = collect_rollout(env, policies, max_steps=50)\n",
"\n",
"print(\"=\" * 60)\n",
"print(\" PettingZoo AEC Environment \u2014 Single Episode\")\n",
"print(\"=\" * 60)\n",
"print(f\" Trader steps: {len(buffers[TRADER])}\")\n",
"print(f\" RM steps: {len(buffers[RISK_MANAGER])}\")\n",
"print(f\" PM steps: {len(buffers[PORTFOLIO_MGR])}\")\n",
"print(f\" Final PnL: {info.get('pnl_pct', 0):.2%}\")\n",
"print(f\" Max Drawdown: {info.get('max_drawdown', 0):.2%}\")\n",
"print(f\" Sharpe: {info.get('sharpe_ratio', 0):.3f}\")\n",
"print(f\" Grade: {info.get('grade', 0):.4f}\")\n",
"print(\"=\" * 60)\n",
"print(\"\u2705 Environment functional!\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Generate Training Scenarios\n",
"\n",
"Run the PettingZoo env with rule-based RM/PM policies to collect realistic market scenarios. Each scenario contains:\n",
"- Base market observation (24 dims)\n",
"- Risk Manager constraints (`rm_size_limit`, `rm_allow_new`, `rm_force_reduce`)\n",
"- Portfolio Manager allocation (`pm_cap_alloc`, `pm_override`)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"import numpy as np\n",
"\n",
"SEED = 3407\n",
"random.seed(SEED)\n",
"np.random.seed(SEED)\n",
"\n",
"# \u2500\u2500\u2500 Configuration \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
"NUM_SCENARIOS = 500 # Number of training prompts\n",
"DIFFICULTY = \"easy\" # Start easy for faster convergence\n",
"MAX_ENV_STEPS = 100 # Steps per PZ episode for scenario generation\n",
"\n",
"print(f\"Generating {NUM_SCENARIOS} scenarios (difficulty={DIFFICULTY})...\")\n",
"scenarios = generate_pz_scenarios(\n",
" n=NUM_SCENARIOS,\n",
" difficulty=DIFFICULTY,\n",
" max_env_steps=MAX_ENV_STEPS,\n",
")\n",
"print(f\"\u2705 Generated {len(scenarios)} scenarios\")\n",
"\n",
"# Preview one scenario\n",
"import json\n",
"print(\"\\n\u2500\u2500 Sample Scenario \u2500\u2500\")\n",
"print(json.dumps(scenarios[0], indent=2, default=str))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# \u2500\u2500\u2500 Build prompts and HF Dataset \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
"from datasets import Dataset\n",
"\n",
"prompts = [{\"prompt\": build_prompt_multiagent(sc)} for sc in scenarios]\n",
"dataset = Dataset.from_list(prompts)\n",
"\n",
"print(f\"\u2705 Dataset ready: {len(dataset)} prompts\")\n",
"print(\"\\n\u2500\u2500 Sample Prompt (first 500 chars) \u2500\u2500\")\n",
"print(dataset[0][\"prompt\"][:500])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. Load Model with Unsloth\n",
"\n",
"Load Qwen 2.5 1.5B with 4-bit quantization and apply LoRA adapters for parameter-efficient training."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Disable Unsloth's compiled GRPO trainer (causes SymFloat errors on Colab/Kaggle)\n",
"import os\n",
"os.environ['UNSLOTH_DISABLE_COMPILE'] = '1'\n",
"os.environ['UNSLOTH_COMPILE_DISABLE'] = '1'\n",
"os.environ['DISABLE_UNSLOTH_COMPILE'] = '1'\n",
"from unsloth import FastLanguageModel\n",
"\n",
"\n",
"# \u2500\u2500\u2500 Model Configuration \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
"MODEL_NAME = \"unsloth/Qwen2.5-1.5B-Instruct-bnb-4bit\"\n",
"MAX_SEQ_LENGTH = 1024\n",
"\n",
"print(f\"Loading {MODEL_NAME}...\")\n",
"model, tokenizer = FastLanguageModel.from_pretrained(\n",
" model_name=MODEL_NAME,\n",
" max_seq_length=MAX_SEQ_LENGTH,\n",
" dtype=None, # Auto-detect (bf16 if supported)\n",
" load_in_4bit=True, # 4-bit quantization via bitsandbytes\n",
")\n",
"\n",
"# Apply LoRA adapters\n",
"model = FastLanguageModel.get_peft_model(\n",
" model,\n",
" r=16,\n",
" target_modules=[\n",
" \"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
" \"gate_proj\", \"up_proj\", \"down_proj\",\n",
" ],\n",
" lora_alpha=16,\n",
" lora_dropout=0,\n",
" bias=\"none\",\n",
" use_gradient_checkpointing=\"unsloth\",\n",
" random_state=SEED,\n",
" use_rslora=False,\n",
")\n",
"\n",
"if tokenizer.pad_token is None:\n",
" tokenizer.pad_token = tokenizer.eos_token\n",
"\n",
"trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"total = sum(p.numel() for p in model.parameters())\n",
"print(f\"\u2705 Model loaded\")\n",
"print(f\" Trainable: {trainable:,} / {total:,} ({100*trainable/total:.2f}%)\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5. GRPO Training\n",
"\n",
"Train using Group Relative Policy Optimization with **5 independent reward verifiers**:\n",
"\n",
"| Verifier | What it checks | Max Score |\n",
"|---|---|---|\n",
"| `format_reward_func` | `` + `` tags, JSON validity | 1.0 |\n",
"| `alignment_reward_func` | Reasoning matches market signals (anti-hallucination) | 1.0 |\n",
"| `risk_reward_func_multiagent` | Compliance with RM size limit | 1.0 |\n",
"| `profit_reward_func` | Directional correctness vs. price trend | 1.0 |\n",
"| `governance_reward_func_multiagent` | Full RM + PM governance compliance | 1.0 |"
]
},
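{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough illustration of what *group relative* means, the sketch below sums the five verifier scores for each completion and standardizes the totals within one prompt's group of `NUM_GENERATIONS` completions. The function name `group_relative_advantages`, the example scores, and the `eps` constant are assumptions made for illustration; this is not TRL's or QuantHive's actual code, and the real trainer may weight or normalize rewards differently.\n",
"\n",
"```python\n",
"from statistics import mean, pstdev\n",
"\n",
"\n",
"def group_relative_advantages(group_rewards, eps=1e-4):\n",
"    '''Standardize total rewards within one prompt's group of completions.'''\n",
"    mu = mean(group_rewards)\n",
"    sigma = pstdev(group_rewards)  # population std, a simplification\n",
"    return [(r - mu) / (sigma + eps) for r in group_rewards]\n",
"\n",
"\n",
"# Hypothetical per-verifier scores for 4 completions of one prompt\n",
"# (columns: format, alignment, risk, profit, governance).\n",
"scores = [\n",
"    [1.0, 0.5, 1.0, 0.0, 1.0],\n",
"    [1.0, 1.0, 1.0, 1.0, 1.0],\n",
"    [0.0, 0.0, 0.5, 0.0, 0.0],\n",
"    [1.0, 0.5, 0.0, 1.0, 0.5],\n",
"]\n",
"totals = [sum(row) for row in scores]\n",
"advantages = group_relative_advantages(totals)\n",
"\n",
"for total, adv in zip(totals, advantages):\n",
"    print(f'total={total:.1f}  advantage={adv:+.2f}')\n",
"```\n",
"\n",
"Completions that beat their group's average receive a positive advantage and are reinforced, while those below it are pushed down, so no separate value model is needed."
]
},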
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import inspect\n",
"from trl.trainer.grpo_config import GRPOConfig\n",
"from trl.trainer.grpo_trainer import GRPOTrainer\n",
"\n",
"# \u2500\u2500\u2500 Training Hyperparameters \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
"OUTPUT_DIR = \"./grpo_output\"\n",
"MAX_STEPS = 250 # Training steps (increase for longer training)\n",
"BATCH_SIZE = 4 # Per-device batch size\n",
"GRAD_ACCUM = 2 # Gradient accumulation steps\n",
"NUM_GENERATIONS = 4 # GRPO generations per prompt\n",
"LEARNING_RATE = 5e-5\n",
"MAX_PROMPT_LENGTH = 768\n",
"MAX_COMPLETION_LEN = 200\n",
"SAVE_STEPS = 50\n",
"LOGGING_STEPS = 1\n",
"\n",
"# \u2500\u2500\u2500 GRPO Config \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
"training_args = GRPOConfig(\n",
" output_dir=OUTPUT_DIR,\n",
" learning_rate=LEARNING_RATE,\n",
" per_device_train_batch_size=BATCH_SIZE,\n",
" gradient_accumulation_steps=GRAD_ACCUM,\n",
" num_train_epochs=1,\n",
" max_steps=MAX_STEPS,\n",
" save_steps=SAVE_STEPS,\n",
" logging_steps=LOGGING_STEPS,\n",
" bf16=torch.cuda.is_bf16_supported(),\n",
" fp16=not torch.cuda.is_bf16_supported(),\n",
" num_generations=NUM_GENERATIONS,\n",
" report_to=\"none\",\n",
")\n",
"\n",
"# \u2500\u2500\u2500 Reward Functions \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
"reward_funcs = [\n",
" format_reward_func,\n",
" alignment_reward_func,\n",
" risk_reward_func_multiagent,\n",
" profit_reward_func,\n",
" governance_reward_func_multiagent,\n",
"]\n",
"\n",
"# \u2500\u2500\u2500 Build Trainer (handles TRL version differences) \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
"trainer_kwargs = {\n",
" \"model\": model,\n",
" \"reward_funcs\": reward_funcs,\n",
" \"args\": training_args,\n",
" \"train_dataset\": dataset,\n",
"}\n",
"\n",
"# TRL version compatibility: older = \"tokenizer\", newer = \"processing_class\"\n",
"sig = inspect.signature(GRPOTrainer.__init__)\n",
"if \"processing_class\" in sig.parameters:\n",
" trainer_kwargs[\"processing_class\"] = tokenizer\n",
"elif \"tokenizer\" in sig.parameters:\n",
" trainer_kwargs[\"tokenizer\"] = tokenizer\n",
"\n",
"trainer = GRPOTrainer(**trainer_kwargs)\n",
"\n",
"\n",
"# Delete Unsloth compiled cache to force vanilla TRL\n",
"import shutil\n",
"cache_dir = os.path.join(os.getcwd(), 'unsloth_compiled_cache')\n",
"if os.path.exists(cache_dir):\n",
" shutil.rmtree(cache_dir)\n",
" print('\ud83d\uddd1\ufe0f Deleted unsloth_compiled_cache')\n",
"\n",
"# Safety net: inject Unsloth-expected attrs into training_args\n",
"_unsloth_defaults = {\n",
" 'unsloth_num_chunks': -1,\n",
" 'unsloth_grpo_mini_batch': None,\n",
" 'loss_type': 'GRPO',\n",
" 'importance_sampling_level': 0,\n",
" 'unsloth_logit_chunk_multiplier': 1,\n",
"}\n",
"for _k, _v in _unsloth_defaults.items():\n",
" if not hasattr(training_args, _k): setattr(training_args, _k, _v)\n",
"\n",
"print(f\"\u2705 GRPOTrainer initialized\")\n",
"print(f\" Steps: {MAX_STEPS} | Batch: {BATCH_SIZE} | Generations: {NUM_GENERATIONS}\")\n",
"print(f\" Reward functions: {len(reward_funcs)}\")"
]
},
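{
"cell_type": "markdown",
"metadata": {},
"source": [
"The hyperparameters in the config cell imply a simple training budget, worked through in the sketch below. It assumes a single GPU and assumes that TRL counts `per_device_train_batch_size` in completions rather than prompts, which is how recent TRL releases document it; `NUM_DEVICES` is a name introduced here for illustration.\n",
"\n",
"```python\n",
"# Values copied from the hyperparameter and scenario cells.\n",
"MAX_STEPS = 250\n",
"BATCH_SIZE = 4\n",
"GRAD_ACCUM = 2\n",
"NUM_GENERATIONS = 4\n",
"NUM_SCENARIOS = 500\n",
"NUM_DEVICES = 1  # assumption: one Colab GPU\n",
"\n",
"completions_per_step = BATCH_SIZE * GRAD_ACCUM * NUM_DEVICES\n",
"if completions_per_step % NUM_GENERATIONS != 0:\n",
"    raise ValueError('effective batch must be divisible by NUM_GENERATIONS')\n",
"\n",
"prompts_per_step = completions_per_step // NUM_GENERATIONS\n",
"prompts_seen = prompts_per_step * MAX_STEPS\n",
"passes_over_data = prompts_seen / NUM_SCENARIOS\n",
"\n",
"print(f'completions per optimizer step: {completions_per_step}')\n",
"print(f'unique prompts per optimizer step: {prompts_per_step}')\n",
"print(f'prompts consumed in {MAX_STEPS} steps: {prompts_seen}')\n",
"print(f'passes over the {NUM_SCENARIOS}-scenario dataset: {passes_over_data:.2f}')\n",
"```\n",
"\n",
"Under these assumptions the 250 steps consume each of the 500 scenarios about once, so raising `MAX_STEPS` without also raising `NUM_SCENARIOS` means revisiting prompts."
]
},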
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# \u2500\u2500\u2500 Run Training \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
"print(\"\ud83d\ude80 Starting GRPO training...\")\n",
"print(f\" This may take 30-60+ minutes on a T4 GPU.\")\n",
"print()\n",
"\n",
"trainer.train()\n",
"\n",
"print(\"\\n\u2705 Training complete!\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 6. Training Results & Plots\n",
"\n",
"Extract loss and reward curves from the training history and generate publication-ready plots."
]
},
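{
"cell_type": "markdown",
"metadata": {},
"source": [
"The plots in this section overlay a moving average computed with `np.convolve(..., mode='valid')`. The sketch below, using made-up reward values and a hypothetical helper named `moving_average`, shows why the smoothed series is shorter than the raw one and why it is plotted against `steps[window-1:]`.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def moving_average(values, window):\n",
"    '''Trailing moving average; returns len(values) - window + 1 points.'''\n",
"    kernel = np.ones(window) / window\n",
"    return np.convolve(values, kernel, mode='valid')\n",
"\n",
"\n",
"steps = list(range(1, 11))  # hypothetical logging steps\n",
"rewards = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]  # made-up values\n",
"window = 3\n",
"\n",
"smoothed = moving_average(rewards, window)\n",
"# Each average is plotted at the last step of the window it covers.\n",
"aligned_steps = steps[window - 1:]\n",
"\n",
"for step, value in zip(aligned_steps, smoothed):\n",
"    print(step, round(float(value), 2))\n",
"```\n",
"\n",
"Because `mode='valid'` only emits points where the window fully overlaps the data, the first `window - 1` steps have no smoothed value."
]
},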
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import matplotlib\n",
"matplotlib.rcParams.update({\"font.size\": 12, \"figure.dpi\": 120})\n",
"\n",
"# \u2500\u2500\u2500 Extract metrics from trainer log history \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
"history = trainer.state.log_history\n",
"steps = [x[\"step\"] for x in history if \"loss\" in x]\n",
"losses = [x[\"loss\"] for x in history if \"loss\" in x]\n",
"rewards = [x[\"reward\"] for x in history if \"reward\" in x]\n",
"r_steps = [x[\"step\"] for x in history if \"reward\" in x]\n",
"\n",
"print(f\"Collected {len(losses)} loss entries, {len(rewards)} reward entries\")\n",
"\n",
"# \u2500\u2500\u2500 Plot 1: Loss Curve \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
"fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))\n",
"\n",
"ax1.plot(steps, losses, color=\"#e74c3c\", linewidth=1.5, alpha=0.7, label=\"Raw\")\n",
"# Smoothed line\n",
"if len(losses) > 10:\n",
" window = min(20, len(losses) // 3)\n",
" smoothed = np.convolve(losses, np.ones(window)/window, mode=\"valid\")\n",
" ax1.plot(steps[window-1:], smoothed, color=\"#c0392b\", linewidth=2.5, label=f\"MA-{window}\")\n",
"ax1.set_xlabel(\"Training Step\")\n",
"ax1.set_ylabel(\"Loss\")\n",
"ax1.set_title(\"GRPO Training Loss\")\n",
"ax1.legend()\n",
"ax1.grid(True, alpha=0.3)\n",
"\n",
"# \u2500\u2500\u2500 Plot 2: Reward Curve \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
"ax2.plot(r_steps, rewards, color=\"#27ae60\", linewidth=1.5, alpha=0.7, label=\"Raw\")\n",
"if len(rewards) > 10:\n",
" window = min(20, len(rewards) // 3)\n",
" smoothed = np.convolve(rewards, np.ones(window)/window, mode=\"valid\")\n",
" ax2.plot(r_steps[window-1:], smoothed, color=\"#1e8449\", linewidth=2.5, label=f\"MA-{window}\")\n",
"ax2.set_xlabel(\"Training Step\")\n",
"ax2.set_ylabel(\"Mean Reward\")\n",
"ax2.set_title(\"GRPO Mean Reward (5 Verifiers)\")\n",
"ax2.legend()\n",
"ax2.grid(True, alpha=0.3)\n",
"\n",
"plt.suptitle(\"QuantHive Multi-Agent GRPO Training \u2014 Qwen 2.5 1.5B\", fontsize=14, fontweight=\"bold\")\n",
"plt.tight_layout()\n",
"plt.savefig(\"training_curves.png\", bbox_inches=\"tight\")\n",
"plt.show()\n",
"print(\"\u2705 Saved training_curves.png\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# \u2500\u2500\u2500 Save training metrics to JSON \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
"metrics = {\n",
" \"losses\": losses,\n",
" \"rewards\": rewards,\n",
" \"final_loss\": losses[-1] if losses else None,\n",
" \"final_reward\": rewards[-1] if rewards else None,\n",
" \"best_reward\": max(rewards) if rewards else None,\n",
" \"total_steps\": MAX_STEPS,\n",
" \"model\": MODEL_NAME,\n",
" \"num_scenarios\": NUM_SCENARIOS,\n",
"}\n",
"\n",
"with open(\"training_metrics.json\", \"w\") as f:\n",
" json.dump(metrics, f, indent=2)\n",
"\n",
"print(\"\ud83d\udcca Training Summary:\")\n",
"print(f\" Final Loss: {metrics['final_loss']:.4f}\" if metrics['final_loss'] else \" No loss data\")\n",
"print(f\" Final Reward: {metrics['final_reward']:.4f}\" if metrics['final_reward'] else \" No reward data\")\n",
"print(f\" Best Reward: {metrics['best_reward']:.4f}\" if metrics['best_reward'] else \" No reward data\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 7. Save Model\n",
"\n",
"Save the trained LoRA adapters. Optionally push to Hugging Face Hub."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# \u2500\u2500\u2500 Save LoRA adapters locally \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
"ADAPTER_DIR = \"./quanthive-trader-grpo-lora\"\n",
"\n",
"model.save_pretrained(ADAPTER_DIR)\n",
"tokenizer.save_pretrained(ADAPTER_DIR)\n",
"print(f\"\u2705 LoRA adapters saved to {ADAPTER_DIR}\")\n",
"\n",
"# Optional: Save full merged model (16-bit) \u2014 uses more disk\n",
"# model.save_pretrained_merged(\"./quanthive-merged-16bit\", tokenizer, save_method=\"merged_16bit\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# \u2500\u2500\u2500 (Optional) Push to Hugging Face Hub \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
"# Uncomment and set your token + repo to push adapters to HF\n",
"\n",
"# HF_TOKEN = \"hf_YOUR_TOKEN_HERE\" # or use huggingface-cli login\n",
"# HF_REPO = \"ARKAISW/quanthive-trader-grpo-lora\"\n",
"#\n",
"# model.push_to_hub(HF_REPO, token=HF_TOKEN)\n",
"# tokenizer.push_to_hub(HF_REPO, token=HF_TOKEN)\n",
"# print(f\"\u2705 Pushed to https://huggingface.co/{HF_REPO}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 8. Inference Test \u2014 Before vs After\n",
"\n",
"Compare the trained model's outputs against the base model to demonstrate learning."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# \u2500\u2500\u2500 Generate a test scenario \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
"test_scenarios = generate_pz_scenarios(n=3, difficulty=\"easy\")\n",
"test_prompt = build_prompt_multiagent(test_scenarios[0])\n",
"\n",
"# Switch to inference mode\n",
"FastLanguageModel.for_inference(model)\n",
"\n",
"# Tokenize\n",
"inputs = tokenizer(\n",
" test_prompt,\n",
" return_tensors=\"pt\",\n",
" truncation=True,\n",
" max_length=MAX_PROMPT_LENGTH,\n",
").to(\"cuda\")\n",
"\n",
"# Generate\n",
"with torch.no_grad():\n",
" output_ids = model.generate(\n",
" **inputs,\n",
" max_new_tokens=MAX_COMPLETION_LEN,\n",
" do_sample=True,\n",
" temperature=0.7,\n",
" top_p=0.9,\n",
" )\n",
"\n",
"# Decode only the generated part\n",
"generated = tokenizer.decode(\n",
" output_ids[0][inputs[\"input_ids\"].shape[1]:],\n",
" skip_special_tokens=True,\n",
")\n",
"\n",
"print(\"=\" * 60)\n",
"print(\" \ud83e\uddea Trained Model Inference\")\n",
"print(\"=\" * 60)\n",
"print(f\"\\nGovernance context:\")\n",
"print(f\" RM size_limit: {test_scenarios[0]['rm_size_limit']:.3f}\")\n",
"print(f\" PM cap_alloc: {test_scenarios[0]['pm_cap_alloc']:.3f}\")\n",
"print(f\"\\nModel output:\")\n",
"print(generated)\n",
"print(\"=\" * 60)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# \u2500\u2500\u2500 Score the generated output with all 5 reward verifiers \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
"test_prompts = [test_prompt]\n",
"test_completions = [generated]\n",
"\n",
"scores = {\n",
" \"format\": format_reward_func(test_prompts, test_completions)[0],\n",
" \"alignment\": alignment_reward_func(test_prompts, test_completions)[0],\n",
" \"risk\": risk_reward_func_multiagent(test_prompts, test_completions)[0],\n",
" \"profit\": profit_reward_func(test_prompts, test_completions)[0],\n",
" \"governance\": governance_reward_func_multiagent(test_prompts, test_completions)[0],\n",
"}\n",
"\n",
"print(\"\\n\ud83d\udcca Reward Verifier Scores:\")\n",
"print(\"-\" * 40)\n",
"for name, score in scores.items():\n",
" bar = \"\u2588\" * int(score * 20) + \"\u2591\" * (20 - int(score * 20))\n",
" print(f\" {name:15s} {bar} {score:.2f}\")\n",
"print(\"-\" * 40)\n",
"print(f\" {'TOTAL':15s} {' '*20} {sum(scores.values()):.2f} / 5.00\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 9. Baseline Comparison\n",
"\n",
"Run the same scenarios through the environment with the trained model vs. random baseline to demonstrate improvement."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# \u2500\u2500\u2500 Score multiple scenarios: Trained vs Random \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
"N_EVAL = 20\n",
"eval_scenarios = generate_pz_scenarios(n=N_EVAL, difficulty=\"easy\")\n",
"\n",
"trained_scores = {k: [] for k in [\"format\", \"alignment\", \"risk\", \"profit\", \"governance\"]}\n",
"random_scores = {k: [] for k in [\"format\", \"alignment\", \"risk\", \"profit\", \"governance\"]}\n",
"\n",
"print(f\"Evaluating {N_EVAL} scenarios...\")\n",
"\n",
"for i, sc in enumerate(eval_scenarios):\n",
" prompt = build_prompt_multiagent(sc)\n",
" inputs = tokenizer(prompt, return_tensors=\"pt\", truncation=True, max_length=MAX_PROMPT_LENGTH).to(\"cuda\")\n",
"\n",
" # Trained model generation\n",
" with torch.no_grad():\n",
" out = model.generate(**inputs, max_new_tokens=MAX_COMPLETION_LEN, do_sample=True, temperature=0.7, top_p=0.9)\n",
" gen = tokenizer.decode(out[0][inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True)\n",
"\n",
" # Random baseline: random action string\n",
" rand_completion = f'Random trade decision.\\n{{\"direction\": {random.choice([0,1,2])}, \"size\": {random.uniform(0.01, 0.9):.2f}, \"sl\": 0, \"tp\": 0}}'\n",
"\n",
" # Score both\n",
" for key, func in [(\"format\", format_reward_func), (\"alignment\", alignment_reward_func),\n",
" (\"risk\", risk_reward_func_multiagent), (\"profit\", profit_reward_func),\n",
" (\"governance\", governance_reward_func_multiagent)]:\n",
" trained_scores[key].append(func([prompt], [gen])[0])\n",
" random_scores[key].append(func([prompt], [rand_completion])[0])\n",
"\n",
" if (i + 1) % 5 == 0:\n",
" print(f\" [{i+1}/{N_EVAL}] done\")\n",
"\n",
"print(\"\u2705 Evaluation complete\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# \u2500\u2500\u2500 Plot: Trained vs Random Baseline \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
"categories = list(trained_scores.keys())\n",
"trained_means = [np.mean(trained_scores[k]) for k in categories]\n",
"random_means = [np.mean(random_scores[k]) for k in categories]\n",
"\n",
"x = np.arange(len(categories))\n",
"width = 0.35\n",
"\n",
"fig, ax = plt.subplots(figsize=(10, 6))\n",
"bars1 = ax.bar(x - width/2, random_means, width, label=\"Random Baseline\", color=\"#e74c3c\", alpha=0.8)\n",
"bars2 = ax.bar(x + width/2, trained_means, width, label=\"GRPO-Trained\", color=\"#27ae60\", alpha=0.8)\n",
"\n",
"ax.set_xlabel(\"Reward Verifier\")\n",
"ax.set_ylabel(\"Mean Score\")\n",
"ax.set_title(\"QuantHive: Trained Agent vs Random Baseline\")\n",
"ax.set_xticks(x)\n",
"ax.set_xticklabels([c.capitalize() for c in categories])\n",
"ax.legend()\n",
"ax.set_ylim(0, 1.1)\n",
"ax.grid(axis=\"y\", alpha=0.3)\n",
"\n",
"# Add value labels\n",
"for bar in bars1:\n",
" ax.text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.02,\n",
" f'{bar.get_height():.2f}', ha='center', va='bottom', fontsize=9)\n",
"for bar in bars2:\n",
" ax.text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.02,\n",
" f'{bar.get_height():.2f}', ha='center', va='bottom', fontsize=9)\n",
"\n",
"plt.tight_layout()\n",
"plt.savefig(\"baseline_comparison.png\", bbox_inches=\"tight\")\n",
"plt.show()\n",
"\n",
"# Summary table\n",
"print(\"\\n\ud83d\udcca Comparison Summary:\")\n",
"print(f\"{'Verifier':15s} {'Random':>8s} {'Trained':>8s} {'\u0394':>8s}\")\n",
"print(\"-\" * 45)\n",
"for cat, rm, tm in zip(categories, random_means, trained_means):\n",
" delta = tm - rm\n",
" print(f\"{cat:15s} {rm:8.3f} {tm:8.3f} {delta:+8.3f}\")\n",
"print(\"-\" * 45)\n",
"print(f\"{'TOTAL':15s} {sum(random_means):8.3f} {sum(trained_means):8.3f} {sum(trained_means)-sum(random_means):+8.3f}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## \u2705 Done!\n",
"\n",
"**Files generated:**\n",
"- `training_curves.png` \u2014 Loss and reward curves\n",
"- `baseline_comparison.png` \u2014 Trained vs. random baseline\n",
"- `training_metrics.json` \u2014 Raw metrics\n",
"- `quanthive-trader-grpo-lora/` \u2014 LoRA adapter weights\n",
"\n",
"**Next steps:**\n",
"1. Download `training_curves.png` and `baseline_comparison.png` and commit to your repo's `plots/` directory\n",
"2. Push LoRA adapters to Hugging Face Hub (uncomment the push cell above)\n",
"3. For longer training runs with more compute, see the HF Jobs guide in the README\n",
"\n",
"---\n",
"\n",
"**GitHub:** [ARKAISW/multi-agent-trading-env](https://github.com/ARKAISW/multi-agent-trading-env) \n",
"**HF Space:** [ARKAISW/QuantHive](https://huggingface.co/spaces/ARKAISW/QuantHive)"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "T4",
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.10.0"
}
},
"nbformat": 4,
"nbformat_minor": 0
}