Commit ·
a1f6f11
1
Parent(s): e42a7af
Vittal's changes, final submission testing
Browse files- FinalTraining.ipynb +1199 -0
- README.md +280 -128
- Training.py +8 -105
- envs/board_sim_env/server/board_sim_env_environment.py +50 -0
- inference.py +426 -0
- notebooks/train_cell_fixed.py +209 -0
- notebooks/train_grpo_kaggle.ipynb +955 -0
- notebooks/train_grpo_v2.ipynb +85 -33
FinalTraining.ipynb
ADDED
|
@@ -0,0 +1,1199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"# BoardSim GRPO — Qwen3-4B (v3, generic events + base-model baseline)\n",
|
| 8 |
+
"\n",
|
| 9 |
+
"Training notebook for the Meta PyTorch × HuggingFace OpenEnv Hackathon submission.\n",
|
| 10 |
+
"\n",
|
| 11 |
+
"**This revision (v3) — what changed:**\n",
|
| 12 |
+
"- Events are now **organization-agnostic** (competition, talent, regulation, PR, M&A,\n",
|
| 13 |
+
" funding, governance, exit) so the simulation maps onto any company, not a specific industry.\n",
|
| 14 |
+
"- **Pitch scoring is semantic**, not keyword-based — sentence-transformer cosine similarity\n",
|
| 15 |
+
" against per-role manifestos, with a TF-IDF fallback. The agent has to write substantively\n",
|
| 16 |
+
" aligned arguments rather than spray vocabulary.\n",
|
| 17 |
+
"- **The baseline is the same Qwen3-4B model with LoRA disabled**, not a random policy.\n",
|
| 18 |
+
" A uniformly random policy is not a meaningful opponent for a 4B language model; the apples-to-apples\n",
|
| 19 |
+
" reference is the *same model* without the fine-tuning delta. It is recovered cheaply via\n",
|
| 20 |
+
" `peft`'s `model.disable_adapter()` context manager (no second model load).\n",
|
| 21 |
+
"- CEO vote weight raised to 2.5× and persuasion shift cap raised to 55% so a CEO decision\n",
|
| 22 |
+
" visibly moves outcomes round-to-round.\n",
|
| 23 |
+
"- Added per-event boardroom win-rate plot — the most direct picture of *where* fine-tuning helps.\n",
|
| 24 |
+
"- ToM probe and trust-trajectory analyses both report fine-tuned **and** base for fair contrast.\n"
|
| 25 |
+
]
|
| 26 |
+
},
|
| 27 |
+
{
|
| 28 |
+
"cell_type": "markdown",
|
| 29 |
+
"metadata": {},
|
| 30 |
+
"source": [
|
| 31 |
+
"## 1. Install (unsloth FIRST — order matters)"
|
| 32 |
+
]
|
| 33 |
+
},
|
| 34 |
+
{
|
| 35 |
+
"cell_type": "code",
|
| 36 |
+
"execution_count": null,
|
| 37 |
+
"metadata": {},
|
| 38 |
+
"outputs": [],
|
| 39 |
+
"source": [
|
| 40 |
+
"# IMPORTANT: install (and later import) unsloth + its zoo BEFORE anything else, because unsloth\n",
|
| 41 |
+
"# patches torch/transformers at import time. If transformers loads first, the\n",
|
| 42 |
+
"# patches don't apply and 4-bit LoRA training silently runs in a slow path.\n",
|
| 43 |
+
"%pip install -q --no-deps unsloth\n",
|
| 44 |
+
"%pip install -q unsloth_zoo\n",
|
| 45 |
+
"%pip install -q \"openenv-core==0.2.3\" \"trl>=0.12,<2.0\" \"transformers>=4.45,<5.0\" \\\n",
|
| 46 |
+
" \"datasets>=3.0\" \"accelerate>=1.0\" \"huggingface_hub>=0.25\" \"pydantic>=2.0\" \\\n",
|
| 47 |
+
" wandb matplotlib python-dotenv bitsandbytes scipy scikit-learn sentence-transformers\n",
|
| 48 |
+
"import os, pathlib\n"
|
| 49 |
+
]
|
| 50 |
+
},
|
| 51 |
+
{
|
| 52 |
+
"cell_type": "markdown",
|
| 53 |
+
"metadata": {},
|
| 54 |
+
"source": [
|
| 55 |
+
"## 2. Auth (HF + WandB)"
|
| 56 |
+
]
|
| 57 |
+
},
|
| 58 |
+
{
|
| 59 |
+
"cell_type": "code",
|
| 60 |
+
"execution_count": null,
|
| 61 |
+
"metadata": {},
|
| 62 |
+
"outputs": [],
|
| 63 |
+
"source": [
|
| 64 |
+
"# Colab Secrets first\n",
|
| 65 |
+
"try:\n",
|
| 66 |
+
" from google.colab import userdata # type: ignore\n",
|
| 67 |
+
" for k in ('HF_TOKEN', 'WANDB_API_KEY', 'ENV_BASE_URL', 'ADAPTER_REPO'):\n",
|
| 68 |
+
" try:\n",
|
| 69 |
+
" v = userdata.get(k)\n",
|
| 70 |
+
" if v:\n",
|
| 71 |
+
" os.environ.setdefault(k, v)\n",
|
| 72 |
+
" except Exception:\n",
|
| 73 |
+
" pass\n",
|
| 74 |
+
"except Exception:\n",
|
| 75 |
+
" pass\n",
|
| 76 |
+
"\n",
|
| 77 |
+
"# .env fallback for local runs\n",
|
| 78 |
+
"try:\n",
|
| 79 |
+
" from dotenv import load_dotenv\n",
|
| 80 |
+
" for p in [pathlib.Path('.env'), pathlib.Path('../.env'),\n",
|
| 81 |
+
" pathlib.Path('/content/repo/.env')]:\n",
|
| 82 |
+
" if p.exists():\n",
|
| 83 |
+
" load_dotenv(p, override=False)\n",
|
| 84 |
+
" print(f'Loaded env from {p.resolve()}')\n",
|
| 85 |
+
" break\n",
|
| 86 |
+
"except Exception:\n",
|
| 87 |
+
" pass\n",
|
| 88 |
+
"\n",
|
| 89 |
+
"if not os.environ.get('HF_TOKEN'):\n",
|
| 90 |
+
" os.environ['HF_TOKEN'] = input('HF token: ').strip()\n",
|
| 91 |
+
"if not os.environ.get('WANDB_API_KEY'):\n",
|
| 92 |
+
" os.environ['WANDB_API_KEY'] = input('WandB key (or blank to skip): ').strip()\n",
|
| 93 |
+
"\n",
|
| 94 |
+
"from huggingface_hub import login as hf_login\n",
|
| 95 |
+
"hf_login(token=os.environ['HF_TOKEN'], add_to_git_credential=False)\n",
|
| 96 |
+
"print('HF auth ok.')\n",
|
| 97 |
+
"if os.environ.get('WANDB_API_KEY'):\n",
|
| 98 |
+
" import wandb\n",
|
| 99 |
+
" wandb.login(key=os.environ['WANDB_API_KEY'])\n",
|
| 100 |
+
" print('W&B auth ok.')\n",
|
| 101 |
+
"import os, pathlib\n"
|
| 102 |
+
]
|
| 103 |
+
},
|
| 104 |
+
{
|
| 105 |
+
"cell_type": "markdown",
|
| 106 |
+
"metadata": {},
|
| 107 |
+
"source": [
|
| 108 |
+
"## 3. Mount Drive (early — checkpoints survive Colab disconnects)"
|
| 109 |
+
]
|
| 110 |
+
},
|
| 111 |
+
{
|
| 112 |
+
"cell_type": "code",
|
| 113 |
+
"execution_count": null,
|
| 114 |
+
"metadata": {},
|
| 115 |
+
"outputs": [],
|
| 116 |
+
"source": [
|
| 117 |
+
"IN_COLAB = os.path.isdir('/content')\n",
|
| 118 |
+
"if IN_COLAB:\n",
|
| 119 |
+
" from google.colab import drive\n",
|
| 120 |
+
" drive.mount('/content/drive', force_remount=False)\n",
|
| 121 |
+
" DRIVE_DIR = pathlib.Path('/content/drive/MyDrive/BoardSim_Run')\n",
|
| 122 |
+
"else:\n",
|
| 123 |
+
" DRIVE_DIR = pathlib.Path('./BoardSim_Run')\n",
|
| 124 |
+
"DRIVE_DIR.mkdir(parents=True, exist_ok=True)\n",
|
| 125 |
+
"ASSETS = DRIVE_DIR / 'assets'; ASSETS.mkdir(exist_ok=True)\n",
|
| 126 |
+
"CKPT = DRIVE_DIR / 'lora_qwen3_4b'; CKPT.mkdir(exist_ok=True)\n",
|
| 127 |
+
"print('DRIVE_DIR =', DRIVE_DIR)\n",
|
| 128 |
+
"import os, sys, subprocess, importlib, urllib.request, json as _json\n"
|
| 129 |
+
]
|
| 130 |
+
},
|
| 131 |
+
{
|
| 132 |
+
"cell_type": "markdown",
|
| 133 |
+
"metadata": {},
|
| 134 |
+
"source": [
|
| 135 |
+
"## 4. Clone repo + import BoardSimEnv client"
|
| 136 |
+
]
|
| 137 |
+
},
|
| 138 |
+
{
|
| 139 |
+
"cell_type": "code",
|
| 140 |
+
"execution_count": null,
|
| 141 |
+
"metadata": {},
|
| 142 |
+
"outputs": [],
|
| 143 |
+
"source": [
|
| 144 |
+
"ENV_BASE_URL = os.environ.get('ENV_BASE_URL',\n",
|
| 145 |
+
" 'https://stavankhobare-sst-metaxpytorch-hackathon.hf.space')\n",
|
| 146 |
+
"REPO_URL = 'https://github.com/StavanRKhobare/SST-MetaxPyTorch-Hackathon'\n",
|
| 147 |
+
"\n",
|
| 148 |
+
"REPO_DIR = '/content/repo' if IN_COLAB else os.path.abspath('./repo')\n",
|
| 149 |
+
"if not os.path.isdir(os.path.join(REPO_DIR, '.git')):\n",
|
| 150 |
+
" subprocess.run(['git', 'clone', '--depth', '1', REPO_URL, REPO_DIR], check=True)\n",
|
| 151 |
+
"else:\n",
|
| 152 |
+
" subprocess.run(['git', '-C', REPO_DIR, 'pull', '--ff-only'], check=False)\n",
|
| 153 |
+
"\n",
|
| 154 |
+
"ENVS_DIR = os.path.join(REPO_DIR, 'envs')\n",
|
| 155 |
+
"if ENVS_DIR not in sys.path:\n",
|
| 156 |
+
" sys.path.insert(0, ENVS_DIR)\n",
|
| 157 |
+
"\n",
|
| 158 |
+
"for mod in [m for m in list(sys.modules) if m == 'board_sim_env' or m.startswith('board_sim_env.')]:\n",
|
| 159 |
+
" del sys.modules[mod]\n",
|
| 160 |
+
"\n",
|
| 161 |
+
"from board_sim_env.client import BoardSimEnv\n",
|
| 162 |
+
"from board_sim_env.models import BoardSimAction, BoardSimObservation\n",
|
| 163 |
+
"\n",
|
| 164 |
+
"try:\n",
|
| 165 |
+
" with urllib.request.urlopen(f'{ENV_BASE_URL.rstrip(\"/\")}/health', timeout=20) as r:\n",
|
| 166 |
+
" h = _json.loads(r.read())\n",
|
| 167 |
+
" print('health:', h)\n",
|
| 168 |
+
"except Exception as e:\n",
|
| 169 |
+
" print(f'WARN: could not reach {ENV_BASE_URL}/health ({e})')\n",
|
| 170 |
+
"\n",
|
| 171 |
+
"def make_env():\n",
|
| 172 |
+
" return BoardSimEnv(base_url=ENV_BASE_URL)\n",
|
| 173 |
+
"\n",
|
| 174 |
+
"print('BoardSimEnv ready.')\n",
|
| 175 |
+
"# -----------------------------------------------------------------------------\n"
|
| 176 |
+
]
|
| 177 |
+
},
|
| 178 |
+
{
|
| 179 |
+
"cell_type": "markdown",
|
| 180 |
+
"metadata": {},
|
| 181 |
+
"source": [
|
| 182 |
+
"## 5. Load base Qwen3-4B (no LoRA yet — this is also our baseline)"
|
| 183 |
+
]
|
| 184 |
+
},
|
| 185 |
+
{
|
| 186 |
+
"cell_type": "code",
|
| 187 |
+
"execution_count": null,
|
| 188 |
+
"metadata": {},
|
| 189 |
+
"outputs": [],
|
| 190 |
+
"source": [
|
| 191 |
+
"# Load base Qwen3-4B (NO LoRA yet). The base model serves a dual role:\n",
|
| 192 |
+
"# (a) it is the reference baseline against which the fine-tuned policy is\n",
|
| 193 |
+
"# compared — this replaces the older random-policy baseline, which was\n",
|
| 194 |
+
"# not meaningful (a coin-flip is not a competitive opponent for an LLM).\n",
|
| 195 |
+
"# (b) once the baseline is recorded, we wrap the SAME model with LoRA\n",
|
| 196 |
+
"# adapters and fine-tune it. At paired-eval time we toggle the adapters\n",
|
| 197 |
+
"# off via `model.disable_adapter()` to recover base-model behaviour\n",
|
| 198 |
+
"# without reloading 4 GB of weights.\n",
|
| 199 |
+
"# -----------------------------------------------------------------------------\n",
|
| 200 |
+
"import unsloth # noqa: F401\n",
|
| 201 |
+
"from unsloth import FastLanguageModel\n",
|
| 202 |
+
"import torch\n",
|
| 203 |
+
"\n",
|
| 204 |
+
"MODEL_NAME = 'Qwen/Qwen3-4B'\n",
|
| 205 |
+
"MAX_SEQ_LEN = 4096\n",
|
| 206 |
+
"\n",
|
| 207 |
+
"model, tokenizer = FastLanguageModel.from_pretrained(\n",
|
| 208 |
+
" model_name=MODEL_NAME,\n",
|
| 209 |
+
" max_seq_length=MAX_SEQ_LEN,\n",
|
| 210 |
+
" load_in_4bit=True,\n",
|
| 211 |
+
" dtype=None,\n",
|
| 212 |
+
")\n",
|
| 213 |
+
"if tokenizer.pad_token is None:\n",
|
| 214 |
+
" tokenizer.pad_token = tokenizer.eos_token\n",
|
| 215 |
+
"\n",
|
| 216 |
+
"device = next(model.parameters()).device\n",
|
| 217 |
+
"print(f'Loaded {MODEL_NAME} on {device}.')\n",
|
| 218 |
+
"import re\n"
|
| 219 |
+
]
|
| 220 |
+
},
|
| 221 |
+
{
|
| 222 |
+
"cell_type": "markdown",
|
| 223 |
+
"metadata": {},
|
| 224 |
+
"source": [
|
| 225 |
+
"## 6. Prompt template + completion parser (generic CEO, no industry-specific persona)"
|
| 226 |
+
]
|
| 227 |
+
},
|
| 228 |
+
{
|
| 229 |
+
"cell_type": "code",
|
| 230 |
+
"execution_count": null,
|
| 231 |
+
"metadata": {},
|
| 232 |
+
"outputs": [],
|
| 233 |
+
"source": [
|
| 234 |
+
"# Generic CEO prompt — applies to any organization, not a specific industry.\n",
|
| 235 |
+
"SYSTEM_PROMPT = \"\"\"You are the CEO of a mid-stage organization. Your board has 4 members with HIDDEN AGENDAS you cannot see directly:\n",
|
| 236 |
+
" - CTO: cares about operational excellence, engineering quality, team morale, and product readiness.\n",
|
| 237 |
+
" - CFO: cares about cash discipline, runway, and regulatory safety.\n",
|
| 238 |
+
" - Investor Rep: pushes growth, market share, and bold returns.\n",
|
| 239 |
+
" - Independent: cares about reputation, governance, and long-term consensus.\n",
|
| 240 |
+
"\n",
|
| 241 |
+
"Each round you see a strategic event, every NPC's pre-vote statement, and 3 options.\n",
|
| 242 |
+
"Your decision is resolved by WEIGHTED VOTE (your weight 2.5x). A short COALITION PITCH\n",
|
| 243 |
+
"that is semantically aligned with opposing members' priorities can swing them toward your pick —\n",
|
| 244 |
+
"write substantive arguments, not just buzzwords.\n",
|
| 245 |
+
"\n",
|
| 246 |
+
"Respond in EXACTLY this format on two lines:\n",
|
| 247 |
+
"DECISION: <one of the option strings>\n",
|
| 248 |
+
"PITCH: <one or two sentences arguing for it, addressing the concerns of opposing members>\"\"\"\n",
|
| 249 |
+
"\n",
|
| 250 |
+
"DECISION_RE = re.compile(r'DECISION\\s*:\\s*([A-Za-z0-9_\\- ]+)', re.IGNORECASE)\n",
|
| 251 |
+
"PITCH_RE = re.compile(r'PITCH\\s*:\\s*(.+)', re.IGNORECASE)\n",
|
| 252 |
+
"\n",
|
| 253 |
+
"def build_prompt(obs):\n",
|
| 254 |
+
" statements = '\\n'.join(\n",
|
| 255 |
+
" f\" {s['role']} ({s['confidence']:.2f}): votes {s['vote']} - {s['statement']}\"\n",
|
| 256 |
+
" for s in obs.npc_statements\n",
|
| 257 |
+
" )\n",
|
| 258 |
+
" return (\n",
|
| 259 |
+
" f\"{SYSTEM_PROMPT}\\n\\n\"\n",
|
| 260 |
+
" f\"State: revenue=${obs.state['revenue']:.0f}/yr burn=${obs.state['burn_rate']:.0f}/mo \"\n",
|
| 261 |
+
" f\"runway={obs.state['runway_months']:.1f}mo morale={obs.state['team_morale']:.2f} \"\n",
|
| 262 |
+
" f\"investors={obs.state['investor_confidence']:.2f} reg_risk={obs.state['regulatory_risk']:.2f}\\n\"\n",
|
| 263 |
+
" f\"Event: {obs.event}\\nBoard:\\n{statements}\\n\"\n",
|
| 264 |
+
" f\"Options: {obs.options}\\n\"\n",
|
| 265 |
+
" )\n",
|
| 266 |
+
"\n",
|
| 267 |
+
"def parse_completion(completion: str, options):\n",
|
| 268 |
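+
" \"\"\"Returns (decision, pitch, format_ok). decision falls back to options[0] when no option string is found in the completion; format_ok=True only if BOTH tags parsed (it does not require the DECISION value to match an option).\"\"\"\n",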
|
| 269 |
+
" decision = options[0]\n",
|
| 270 |
+
" decision_ok = False\n",
|
| 271 |
+
" dm = DECISION_RE.search(completion)\n",
|
| 272 |
+
" if dm:\n",
|
| 273 |
+
" cand = dm.group(1).strip().lower()\n",
|
| 274 |
+
" for opt in options:\n",
|
| 275 |
+
" if opt.lower() == cand or opt.lower() in cand:\n",
|
| 276 |
+
" decision = opt; decision_ok = True; break\n",
|
| 277 |
+
" if not decision_ok:\n",
|
| 278 |
+
" for opt in options:\n",
|
| 279 |
+
" if opt.lower() in completion.lower():\n",
|
| 280 |
+
" decision = opt; break\n",
|
| 281 |
+
" pm = PITCH_RE.search(completion)\n",
|
| 282 |
+
" pitch = pm.group(1).strip()[:400] if pm else ''\n",
|
| 283 |
+
" format_ok = bool(dm) and bool(pm)\n",
|
| 284 |
+
" return decision, pitch, format_ok\n"
|
| 285 |
+
]
|
| 286 |
+
},
|
| 287 |
+
{
|
| 288 |
+
"cell_type": "markdown",
|
| 289 |
+
"metadata": {},
|
| 290 |
+
"source": [
|
| 291 |
+
"## 7. Episode runner (works for both base and fine-tuned model)"
|
| 292 |
+
]
|
| 293 |
+
},
|
| 294 |
+
{
|
| 295 |
+
"cell_type": "code",
|
| 296 |
+
"execution_count": null,
|
| 297 |
+
"metadata": {},
|
| 298 |
+
"outputs": [],
|
| 299 |
+
"source": [
|
| 300 |
+
"MAX_NEW_TOKENS = 80\n",
|
| 301 |
+
"\n",
|
| 302 |
+
"def greedy_action(obs):\n",
|
| 303 |
+
" prompt = build_prompt(obs)\n",
|
| 304 |
+
" enc = tokenizer(prompt, return_tensors='pt', truncation=True, max_length=1024).to(device)\n",
|
| 305 |
+
" with torch.no_grad():\n",
|
| 306 |
+
" out = model.generate(\n",
|
| 307 |
+
" **enc, max_new_tokens=MAX_NEW_TOKENS,\n",
|
| 308 |
+
" do_sample=False, pad_token_id=tokenizer.eos_token_id,\n",
|
| 309 |
+
" )\n",
|
| 310 |
+
" completion = tokenizer.decode(out[0][enc.input_ids.shape[1]:], skip_special_tokens=True)\n",
|
| 311 |
+
" return parse_completion(completion, obs.options)\n",
|
| 312 |
+
"import random, statistics, json\n",
|
| 313 |
+
"\n",
|
| 314 |
+
"MAX_STEPS_PER_EP = 20\n",
|
| 315 |
+
"\n",
|
| 316 |
+
"def run_episode(env, seed):\n",
|
| 317 |
+
" \"\"\"Runs ONE full episode using the currently-active model state\n",
|
| 318 |
+
" (base if adapters disabled, fine-tuned otherwise). Returns a dict with final_profit, ep_reward, steps, format_rate, pitch_rate and history.\"\"\"\n",
|
| 319 |
+
" result = env.reset(seed=seed)\n",
|
| 320 |
+
" obs = result.observation\n",
|
| 321 |
+
" ep_r, n, fmt_hits, pitch_hits = 0.0, 0, 0, 0\n",
|
| 322 |
+
" while not result.done and n < MAX_STEPS_PER_EP:\n",
|
| 323 |
+
" decision, pitch, fmt_ok = greedy_action(obs)\n",
|
| 324 |
+
" if fmt_ok: fmt_hits += 1\n",
|
| 325 |
+
" if pitch.strip(): pitch_hits += 1\n",
|
| 326 |
+
" result = env.step(BoardSimAction(decision=decision, coalition_pitch=pitch))\n",
|
| 327 |
+
" obs = result.observation\n",
|
| 328 |
+
" ep_r += float(result.reward or 0.0)\n",
|
| 329 |
+
" n += 1\n",
|
| 330 |
+
" return {\n",
|
| 331 |
+
" 'final_profit': obs.state['profitability_score'],\n",
|
| 332 |
+
" 'ep_reward': ep_r, 'steps': n,\n",
|
| 333 |
+
" 'format_rate': fmt_hits / max(1, n), 'pitch_rate': pitch_hits / max(1, n),\n",
|
| 334 |
+
" 'history': obs.state.get('history', []),\n",
|
| 335 |
+
" }\n",
|
| 336 |
+
"# -----------------------------------------------------------------------------\n"
|
| 337 |
+
]
|
| 338 |
+
},
|
| 339 |
+
{
|
| 340 |
+
"cell_type": "markdown",
|
| 341 |
+
"metadata": {},
|
| 342 |
+
"source": [
|
| 343 |
+
"## 8. Baseline — base Qwen3-4B on held-out seeds (replaces the old random baseline)"
|
| 344 |
+
]
|
| 345 |
+
},
|
| 346 |
+
{
|
| 347 |
+
"cell_type": "code",
|
| 348 |
+
"execution_count": null,
|
| 349 |
+
"metadata": {},
|
| 350 |
+
"outputs": [],
|
| 351 |
+
"source": [
|
| 352 |
+
"# BASELINE — base Qwen3-4B (no fine-tuning).\n",
|
| 353 |
+
"# This is the apples-to-apples reference for measuring what fine-tuning buys\n",
|
| 354 |
+
"# us. Random policies are not a competitive baseline for a 4 B language model\n",
|
| 355 |
+
"# choosing among 3 well-formed strings.\n",
|
| 356 |
+
"# -----------------------------------------------------------------------------\n",
|
| 357 |
+
"BASELINE_SEEDS = list(range(50_000, 50_000 + 100)) # held out from training\n",
|
| 358 |
+
"\n",
|
| 359 |
+
"base_finals, base_rewards, base_fmts, base_pitches = [], [], [], []\n",
|
| 360 |
+
"with make_env().sync() as env:\n",
|
| 361 |
+
" for i, s in enumerate(BASELINE_SEEDS):\n",
|
| 362 |
+
" r = run_episode(env, s)\n",
|
| 363 |
+
" base_finals.append(r['final_profit'])\n",
|
| 364 |
+
" base_rewards.append(r['ep_reward'])\n",
|
| 365 |
+
" base_fmts.append(r['format_rate'])\n",
|
| 366 |
+
" base_pitches.append(r['pitch_rate'])\n",
|
| 367 |
+
" if (i + 1) % 10 == 0:\n",
|
| 368 |
+
" print(f' base Qwen3-4B {i+1}/{len(BASELINE_SEEDS)} profit={r[\"final_profit\"]:.1f}')\n",
|
| 369 |
+
"\n",
|
| 370 |
+
"BASELINE_MEAN_PROFIT = statistics.mean(base_finals)\n",
|
| 371 |
+
"BASELINE_MEAN_REWARD = statistics.mean(base_rewards)\n",
|
| 372 |
+
"print(f'Base Qwen3-4B profit : {BASELINE_MEAN_PROFIT:.2f} \\u00b1 {statistics.stdev(base_finals):.2f}')\n",
|
| 373 |
+
"print(f'Base Qwen3-4B ep rwd : {BASELINE_MEAN_REWARD:.2f} \\u00b1 {statistics.stdev(base_rewards):.2f}')\n",
|
| 374 |
+
"print(f'Base format rate : {statistics.mean(base_fmts):.0%} pitch rate: {statistics.mean(base_pitches):.0%}')\n",
|
| 375 |
+
"\n",
|
| 376 |
+
"with open(DRIVE_DIR / 'baseline.json', 'w') as f:\n",
|
| 377 |
+
" json.dump({'model': MODEL_NAME, 'mode': 'base_no_finetune',\n",
|
| 378 |
+
" 'seeds': BASELINE_SEEDS,\n",
|
| 379 |
+
" 'finals': base_finals, 'rewards': base_rewards,\n",
|
| 380 |
+
" 'format_rates': base_fmts, 'pitch_rates': base_pitches}, f)\n",
|
| 381 |
+
"# -----------------------------------------------------------------------------\n"
|
| 382 |
+
]
|
| 383 |
+
},
|
| 384 |
+
{
|
| 385 |
+
"cell_type": "markdown",
|
| 386 |
+
"metadata": {},
|
| 387 |
+
"source": [
|
| 388 |
+
"## 9. Wrap base model with LoRA adapters"
|
| 389 |
+
]
|
| 390 |
+
},
|
| 391 |
+
{
|
| 392 |
+
"cell_type": "code",
|
| 393 |
+
"execution_count": null,
|
| 394 |
+
"metadata": {},
|
| 395 |
+
"outputs": [],
|
| 396 |
+
"source": [
|
| 397 |
+
"# Wrap base model with LoRA adapters. From here onward `model` is a PEFT\n",
|
| 398 |
+
"# model; the base behaviour is recoverable any time via\n",
|
| 399 |
+
"# `with model.disable_adapter(): ...`.\n",
|
| 400 |
+
"# -----------------------------------------------------------------------------\n",
|
| 401 |
+
"model = FastLanguageModel.get_peft_model(\n",
|
| 402 |
+
" model,\n",
|
| 403 |
+
" r=32,\n",
|
| 404 |
+
" target_modules=['q_proj','k_proj','v_proj','o_proj','gate_proj','up_proj','down_proj'],\n",
|
| 405 |
+
" lora_alpha=64,\n",
|
| 406 |
+
" lora_dropout=0.0, bias='none',\n",
|
| 407 |
+
" use_gradient_checkpointing='unsloth',\n",
|
| 408 |
+
" random_state=3407,\n",
|
| 409 |
+
")\n",
|
| 410 |
+
"\n",
|
| 411 |
+
"trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
|
| 412 |
+
"total = sum(p.numel() for p in model.parameters())\n",
|
| 413 |
+
"print(f'Trainable params: {trainable:,} / {total:,} ({100*trainable/total:.2f}%)')\n"
|
| 414 |
+
]
|
| 415 |
+
},
|
| 416 |
+
{
|
| 417 |
+
"cell_type": "markdown",
|
| 418 |
+
"metadata": {},
|
| 419 |
+
"source": [
|
| 420 |
+
"## 10. Periodic-eval helper"
|
| 421 |
+
]
|
| 422 |
+
},
|
| 423 |
+
{
|
| 424 |
+
"cell_type": "code",
|
| 425 |
+
"execution_count": null,
|
| 426 |
+
"metadata": {},
|
| 427 |
+
"outputs": [],
|
| 428 |
+
"source": [
|
| 429 |
+
"EVAL_SEEDS = list(range(60_000, 60_000 + 10)) # held out from training\n",
|
| 430 |
+
"\n",
|
| 431 |
+
"def periodic_eval(env):\n",
|
| 432 |
+
" profits, rewards, fmts, pitches = [], [], [], []\n",
|
| 433 |
+
" for s in EVAL_SEEDS:\n",
|
| 434 |
+
" r = run_episode(env, s)\n",
|
| 435 |
+
" profits.append(r['final_profit']); rewards.append(r['ep_reward'])\n",
|
| 436 |
+
" fmts.append(r['format_rate']); pitches.append(r['pitch_rate'])\n",
|
| 437 |
+
" import numpy as np\n",
|
| 438 |
+
" return {'profit_mean': float(np.mean(profits)),\n",
|
| 439 |
+
" 'reward_mean': float(np.mean(rewards)),\n",
|
| 440 |
+
" 'format_rate': float(np.mean(fmts)),\n",
|
| 441 |
+
" 'pitch_rate': float(np.mean(pitches))}\n"
|
| 442 |
+
]
|
| 443 |
+
},
|
| 444 |
+
{
|
| 445 |
+
"cell_type": "markdown",
|
| 446 |
+
"metadata": {},
|
| 447 |
+
"source": [
|
| 448 |
+
"## 11. GRPO training loop (single persistent env, periodic eval, Drive checkpoints)"
|
| 449 |
+
]
|
| 450 |
+
},
|
| 451 |
+
{
|
| 452 |
+
"cell_type": "code",
|
| 453 |
+
"execution_count": null,
|
| 454 |
+
"metadata": {},
|
| 455 |
+
"outputs": [],
|
| 456 |
+
"source": [
|
| 457 |
+
"import os, json, math, time, collections\n",
|
| 458 |
+
"from torch.optim import AdamW\n",
|
| 459 |
+
"\n",
|
| 460 |
+
"NUM_STEPS = int(os.environ.get('NUM_STEPS', 200))\n",
|
| 461 |
+
"GROUP_SIZE = int(os.environ.get('GROUP_SIZE', 4))\n",
|
| 462 |
+
"LR = 5e-6\n",
|
| 463 |
+
"GRAD_CLIP = 1.0\n",
|
| 464 |
+
"TEMPERATURE, TOP_P = 1.0, 0.95\n",
|
| 465 |
+
"SAVE_EVERY = 25\n",
|
| 466 |
+
"EVAL_AT = {0, 25, 50, 100, 150, NUM_STEPS - 1}\n",
|
| 467 |
+
"\n",
|
| 468 |
+
"WANDB_OK = False\n",
|
| 469 |
+
"if os.environ.get('WANDB_API_KEY'):\n",
|
| 470 |
+
" try:\n",
|
| 471 |
+
" import wandb\n",
|
| 472 |
+
" wandb.init(project='boardsim-qwen3-grpo', name='boardsim-qwen3-grpo-v3',\n",
|
| 473 |
+
" config={'num_steps': NUM_STEPS, 'group_size': GROUP_SIZE, 'lr': LR,\n",
|
| 474 |
+
" 'temperature': TEMPERATURE, 'top_p': TOP_P, 'model': MODEL_NAME},\n",
|
| 475 |
+
" finish_previous=True)\n",
|
| 476 |
+
" WANDB_OK = True\n",
|
| 477 |
+
" except TypeError:\n",
|
| 478 |
+
" wandb.init(project='boardsim-qwen3-grpo', name='boardsim-qwen3-grpo-v3',\n",
|
| 479 |
+
" config={'num_steps': NUM_STEPS, 'group_size': GROUP_SIZE, 'lr': LR,\n",
|
| 480 |
+
" 'temperature': TEMPERATURE, 'top_p': TOP_P, 'model': MODEL_NAME},\n",
|
| 481 |
+
" reinit=True)\n",
|
| 482 |
+
" WANDB_OK = True\n",
|
| 483 |
+
" except Exception as e:\n",
|
| 484 |
+
" print(f'WARN: wandb.init failed: {e}')\n",
|
| 485 |
+
"\n",
|
| 486 |
+
"optimizer = AdamW([p for p in model.parameters() if p.requires_grad],\n",
|
| 487 |
+
" lr=LR, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0)\n",
|
| 488 |
+
"\n",
|
| 489 |
+
"log_history = []\n",
|
| 490 |
+
"eval_history = []\n",
|
| 491 |
+
"decision_counter = collections.Counter()\n",
|
| 492 |
+
"t0 = time.time()\n",
|
| 493 |
+
"\n",
|
| 494 |
+
"# ONE persistent env per role for the whole training loop.\n",
|
| 495 |
+
"with make_env().sync() as env_train, make_env().sync() as env_score, make_env().sync() as env_eval:\n",
|
| 496 |
+
" for step in range(NUM_STEPS):\n",
|
| 497 |
+
" result = env_train.reset(seed=step)\n",
|
| 498 |
+
" obs = result.observation\n",
|
| 499 |
+
" prompt = build_prompt(obs)\n",
|
| 500 |
+
" enc = tokenizer(prompt, return_tensors='pt', truncation=True, max_length=1024).to(device)\n",
|
| 501 |
+
" prompt_len = enc.input_ids.shape[1]\n",
|
| 502 |
+
"\n",
|
| 503 |
+
" with torch.no_grad():\n",
|
| 504 |
+
" gen_out = model.generate(\n",
|
| 505 |
+
" input_ids=enc.input_ids, attention_mask=enc.attention_mask,\n",
|
| 506 |
+
" max_new_tokens=MAX_NEW_TOKENS, do_sample=True,\n",
|
| 507 |
+
" temperature=TEMPERATURE, top_p=TOP_P,\n",
|
| 508 |
+
" num_return_sequences=GROUP_SIZE,\n",
|
| 509 |
+
" pad_token_id=tokenizer.eos_token_id,\n",
|
| 510 |
+
" )\n",
|
| 511 |
+
" gen_out = gen_out.detach().clone()\n",
|
| 512 |
+
"\n",
|
| 513 |
+
" decisions, pitches, rewards, fmt_oks = [], [], [], []\n",
|
| 514 |
+
" for g in range(GROUP_SIZE):\n",
|
| 515 |
+
" comp = tokenizer.decode(gen_out[g][prompt_len:], skip_special_tokens=True)\n",
|
| 516 |
+
" d, pp, ok = parse_completion(comp, obs.options)\n",
|
| 517 |
+
" decisions.append(d); pitches.append(pp); fmt_oks.append(ok)\n",
|
| 518 |
+
" decision_counter[d] += 1\n",
|
| 519 |
+
" env_score.reset(seed=step)\n",
|
| 520 |
+
" sr = env_score.step(BoardSimAction(decision=d, coalition_pitch=pp))\n",
|
| 521 |
+
" rewards.append(float(sr.reward or 0.0))\n",
|
| 522 |
+
"\n",
|
| 523 |
+
" rewards_t = torch.tensor(rewards, dtype=torch.float32, device=device)\n",
|
| 524 |
+
" if rewards_t.numel() > 1 and rewards_t.std().item() > 1e-6:\n",
|
| 525 |
+
" advantages = (rewards_t - rewards_t.mean()) / (rewards_t.std() + 1e-8)\n",
|
| 526 |
+
" else:\n",
|
| 527 |
+
" advantages = rewards_t - rewards_t.mean()\n",
|
| 528 |
+
"\n",
|
| 529 |
+
" optimizer.zero_grad()\n",
|
| 530 |
+
" full_ids = gen_out\n",
|
| 531 |
+
" attn = (full_ids != tokenizer.pad_token_id).long()\n",
|
| 532 |
+
" loss_mask = attn.clone()\n",
|
| 533 |
+
" loss_mask[:, :prompt_len] = 0\n",
|
| 534 |
+
" out = model(input_ids=full_ids, attention_mask=attn)\n",
|
| 535 |
+
" logits = out.logits[:, :-1, :].float()\n",
|
| 536 |
+
" targets = full_ids[:, 1:]\n",
|
| 537 |
+
" mask = loss_mask[:, 1:].float()\n",
|
| 538 |
+
" log_probs = torch.nn.functional.log_softmax(logits, dim=-1)\n",
|
| 539 |
+
" token_nll = -log_probs.gather(2, targets.unsqueeze(-1)).squeeze(-1)\n",
|
| 540 |
+
" per_seq_nll = (token_nll * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)\n",
|
| 541 |
+
" loss = (advantages.detach() * per_seq_nll).mean()\n",
|
| 542 |
+
" loss.backward()\n",
|
| 543 |
+
" total_loss_val = float(loss.detach().item())\n",
|
| 544 |
+
" torch.nn.utils.clip_grad_norm_(\n",
|
| 545 |
+
" [p for p in model.parameters() if p.requires_grad], GRAD_CLIP)\n",
|
| 546 |
+
" optimizer.step()\n",
|
| 547 |
+
"\n",
|
| 548 |
+
" rec = {\n",
|
| 549 |
+
" 'step': step,\n",
|
| 550 |
+
" 'reward': float(rewards_t.mean().item()),\n",
|
| 551 |
+
" 'reward_std': float(rewards_t.std().item()) if rewards_t.numel() > 1 else 0.0,\n",
|
| 552 |
+
" 'reward_max': float(rewards_t.max().item()),\n",
|
| 553 |
+
" 'loss': total_loss_val,\n",
|
| 554 |
+
" 'format_rate': sum(fmt_oks) / GROUP_SIZE,\n",
|
| 555 |
+
" 'pitch_rate': sum(1 for p in pitches if p.strip()) / GROUP_SIZE,\n",
|
| 556 |
+
" 'elapsed_s': time.time() - t0,\n",
|
| 557 |
+
" }\n",
|
| 558 |
+
" log_history.append(rec)\n",
|
| 559 |
+
" if WANDB_OK:\n",
|
| 560 |
+
" wandb.log(rec, step=step)\n",
|
| 561 |
+
"\n",
|
| 562 |
+
" if step % 5 == 0:\n",
|
| 563 |
+
" print(f\"step={step:4d} reward={rec['reward']:+.3f} (\\u00b1{rec['reward_std']:.2f}) \"\n",
|
| 564 |
+
" f\"loss={rec['loss']:+.4f} fmt={rec['format_rate']:.0%} \"\n",
|
| 565 |
+
" f\"elapsed={rec['elapsed_s']:.0f}s d0={decisions[0]}\")\n",
|
| 566 |
+
"\n",
|
| 567 |
+
" if step in EVAL_AT:\n",
|
| 568 |
+
" ev = periodic_eval(env_eval)\n",
|
| 569 |
+
" ev['step'] = step\n",
|
| 570 |
+
" eval_history.append(ev)\n",
|
| 571 |
+
" print(f\" [eval@{step}] profit={ev['profit_mean']:.2f} \"\n",
|
| 572 |
+
" f\"reward={ev['reward_mean']:.2f} fmt={ev['format_rate']:.0%}\")\n",
|
| 573 |
+
" if WANDB_OK:\n",
|
| 574 |
+
" wandb.log({f'eval/{k}': v for k, v in ev.items() if k != 'step'}, step=step)\n",
|
| 575 |
+
"\n",
|
| 576 |
+
" if step > 0 and step % SAVE_EVERY == 0:\n",
|
| 577 |
+
" model.save_pretrained(str(CKPT))\n",
|
| 578 |
+
" tokenizer.save_pretrained(str(CKPT))\n",
|
| 579 |
+
" with open(DRIVE_DIR / 'log_history.json', 'w') as f:\n",
|
| 580 |
+
" json.dump(log_history, f)\n",
|
| 581 |
+
" with open(DRIVE_DIR / 'eval_history.json', 'w') as f:\n",
|
| 582 |
+
" json.dump(eval_history, f)\n",
|
| 583 |
+
"\n",
|
| 584 |
+
"model.save_pretrained(str(CKPT))\n",
|
| 585 |
+
"tokenizer.save_pretrained(str(CKPT))\n",
|
| 586 |
+
"with open(DRIVE_DIR / 'log_history.json', 'w') as f:\n",
|
| 587 |
+
" json.dump(log_history, f)\n",
|
| 588 |
+
"with open(DRIVE_DIR / 'eval_history.json', 'w') as f:\n",
|
| 589 |
+
" json.dump(eval_history, f)\n",
|
| 590 |
+
"with open(DRIVE_DIR / 'decision_counter.json', 'w') as f:\n",
|
| 591 |
+
" json.dump(dict(decision_counter), f)\n",
|
| 592 |
+
"if WANDB_OK:\n",
|
| 593 |
+
" wandb.finish()\n",
|
| 594 |
+
"print(f'Training done. {len(log_history)} steps in {time.time() - t0:.0f}s. -> {CKPT}')\n"
|
| 595 |
+
]
|
| 596 |
+
},
|
| 597 |
+
{
|
| 598 |
+
"cell_type": "markdown",
|
| 599 |
+
"metadata": {},
|
| 600 |
+
"source": [
|
| 601 |
+
"## 12. Proof #1 — reward / loss / format-compliance / pitch-rate curves"
|
| 602 |
+
]
|
| 603 |
+
},
|
| 604 |
+
{
|
| 605 |
+
"cell_type": "code",
|
| 606 |
+
"execution_count": null,
|
| 607 |
+
"metadata": {},
|
| 608 |
+
"outputs": [],
|
| 609 |
+
"source": [
|
| 610 |
+
"import numpy as np, matplotlib\n",
|
| 611 |
+
"matplotlib.use('Agg')\n",
|
| 612 |
+
"import matplotlib.pyplot as plt\n",
|
| 613 |
+
"from scipy import stats as spstats\n",
|
| 614 |
+
"\n",
|
| 615 |
+
"steps = np.array([e['step'] for e in log_history])\n",
|
| 616 |
+
"rewards = np.array([e['reward'] for e in log_history])\n",
|
| 617 |
+
"losses = np.array([e['loss'] for e in log_history])\n",
|
| 618 |
+
"fmts = np.array([e['format_rate'] for e in log_history])\n",
|
| 619 |
+
"pitches = np.array([e['pitch_rate'] for e in log_history])\n",
|
| 620 |
+
"\n",
|
| 621 |
+
"def ema(xs, alpha=0.1):\n",
|
| 622 |
+
" out, s = [], xs[0] if len(xs) else 0.0\n",
|
| 623 |
+
" for x in xs:\n",
|
| 624 |
+
" s = alpha * x + (1 - alpha) * s\n",
|
| 625 |
+
" out.append(s)\n",
|
| 626 |
+
" return np.array(out)\n",
|
| 627 |
+
"\n",
|
| 628 |
+
"rewards_ema = ema(rewards, 0.1)\n",
|
| 629 |
+
"slope, intercept, r_val, p_val, _ = spstats.linregress(steps, rewards)\n",
|
| 630 |
+
"\n",
|
| 631 |
+
"# Reward curve — vs base Qwen3-4B baseline (NOT random).\n",
|
| 632 |
+
"plt.figure(figsize=(9, 5))\n",
|
| 633 |
+
"plt.plot(steps, rewards, alpha=0.3, lw=1, label='per-step group reward')\n",
|
| 634 |
+
"plt.plot(steps, rewards_ema, lw=2.2, label='EMA (\\u03b1=0.1)')\n",
|
| 635 |
+
"plt.plot(steps, intercept + slope * steps, '--', lw=1.5,\n",
|
| 636 |
+
" label=f'linear fit slope={slope:+.4f}/step (p={p_val:.1e})')\n",
|
| 637 |
+
"plt.axhline(BASELINE_MEAN_REWARD, ls=':', lw=2, color='#c44',\n",
|
| 638 |
+
" label=f'base Qwen3-4B baseline = {BASELINE_MEAN_REWARD:.2f}')\n",
|
| 639 |
+
"plt.title('GRPO reward — BoardSim (vs same model w/o fine-tuning)')\n",
|
| 640 |
+
"plt.xlabel('step'); plt.ylabel('mean group reward')\n",
|
| 641 |
+
"plt.legend(); plt.grid(alpha=0.3); plt.tight_layout()\n",
|
| 642 |
+
"plt.savefig(ASSETS / 'reward_curve.png', dpi=150); plt.close()\n",
|
| 643 |
+
"\n",
|
| 644 |
+
"# Loss\n",
|
| 645 |
+
"plt.figure(figsize=(9, 5))\n",
|
| 646 |
+
"plt.plot(steps, losses, lw=1.5)\n",
|
| 647 |
+
"plt.title('GRPO loss (advantage \\u00d7 NLL)'); plt.xlabel('step'); plt.ylabel('loss')\n",
|
| 648 |
+
"plt.grid(alpha=0.3); plt.tight_layout()\n",
|
| 649 |
+
"plt.savefig(ASSETS / 'loss_curve.png', dpi=150); plt.close()\n",
|
| 650 |
+
"\n",
|
| 651 |
+
"# Format compliance + pitch rate\n",
|
| 652 |
+
"plt.figure(figsize=(9, 5))\n",
|
| 653 |
+
"plt.plot(steps, ema(fmts, 0.05), lw=2, label='format-OK rate (EMA)')\n",
|
| 654 |
+
"plt.plot(steps, ema(pitches, 0.05), lw=2, label='non-empty pitch rate (EMA)')\n",
|
| 655 |
+
"plt.title('Format compliance + pitch usage during training')\n",
|
| 656 |
+
"plt.xlabel('step'); plt.ylabel('rate'); plt.ylim(-0.05, 1.05)\n",
|
| 657 |
+
"plt.legend(); plt.grid(alpha=0.3); plt.tight_layout()\n",
|
| 658 |
+
"plt.savefig(ASSETS / 'format_compliance.png', dpi=150); plt.close()\n",
|
| 659 |
+
"\n",
|
| 660 |
+
"# Periodic eval — overlaid against base Qwen3-4B baseline so the reader\n",
|
| 661 |
+
"# can see the LoRA-trained policy progressively pull away from the base\n",
|
| 662 |
+
"# model on held-out seeds.\n",
|
| 663 |
+
"if eval_history:\n",
|
| 664 |
+
" es = [e['step'] for e in eval_history]\n",
|
| 665 |
+
" epm = [e['profit_mean'] for e in eval_history]\n",
|
| 666 |
+
" erm = [e['reward_mean'] for e in eval_history]\n",
|
| 667 |
+
" plt.figure(figsize=(9, 5))\n",
|
| 668 |
+
" plt.plot(es, epm, '-o', lw=2, label='held-out profitability (mean of 10 episodes)')\n",
|
| 669 |
+
" plt.plot(es, erm, '-s', lw=2, label='held-out episode reward')\n",
|
| 670 |
+
" plt.axhline(BASELINE_MEAN_PROFIT, ls=':', lw=1.5, color='#c44',\n",
|
| 671 |
+
" label=f'base Qwen3-4B profitability = {BASELINE_MEAN_PROFIT:.2f}')\n",
|
| 672 |
+
" plt.title('Periodic held-out eval during training (greedy)')\n",
|
| 673 |
+
" plt.xlabel('training step'); plt.ylabel('value')\n",
|
| 674 |
+
" plt.legend(); plt.grid(alpha=0.3); plt.tight_layout()\n",
|
| 675 |
+
" plt.savefig(ASSETS / 'periodic_eval.png', dpi=150); plt.close()\n",
|
| 676 |
+
"\n",
|
| 677 |
+
"print(f'Linear-fit slope on reward: {slope:+.5f}/step (p={p_val:.2e}, R\\u00b2={r_val**2:.3f})')\n",
|
| 678 |
+
"print('Saved reward_curve.png, loss_curve.png, format_compliance.png, periodic_eval.png')\n",
|
| 679 |
+
"# -----------------------------------------------------------------------------\n"
|
| 680 |
+
]
|
| 681 |
+
},
|
| 682 |
+
{
|
| 683 |
+
"cell_type": "markdown",
|
| 684 |
+
"metadata": {},
|
| 685 |
+
"source": [
|
| 686 |
+
"## 13. Proof #2 — paired same-seed eval, fine-tuned vs base Qwen3-4B"
|
| 687 |
+
]
|
| 688 |
+
},
|
| 689 |
+
{
|
| 690 |
+
"cell_type": "code",
|
| 691 |
+
"execution_count": null,
|
| 692 |
+
"metadata": {},
|
| 693 |
+
"outputs": [],
|
| 694 |
+
"source": [
|
| 695 |
+
"# Paired same-seed eval: fine-tuned vs BASE Qwen3-4B (adapters disabled).\n",
|
| 696 |
+
"# This is the headline comparison. Same prompts, same env seeds, same\n",
|
| 697 |
+
"# decoder, same parser — only the LoRA delta differs.\n",
|
| 698 |
+
"# -----------------------------------------------------------------------------\n",
|
| 699 |
+
"from unsloth import FastLanguageModel\n",
|
| 700 |
+
"FastLanguageModel.for_inference(model)\n",
|
| 701 |
+
"\n",
|
| 702 |
+
"EVAL_N = 50\n",
|
| 703 |
+
"PAIRED_SEEDS = list(range(70_000, 70_000 + EVAL_N))\n",
|
| 704 |
+
"\n",
|
| 705 |
+
"# Trained policy (adapters active)\n",
|
| 706 |
+
"trained_finals, trained_rewards, trained_fmt, trained_pitch = [], [], [], []\n",
|
| 707 |
+
"trained_history_per_seed = []\n",
|
| 708 |
+
"with make_env().sync() as env:\n",
|
| 709 |
+
" for i, s in enumerate(PAIRED_SEEDS):\n",
|
| 710 |
+
" r = run_episode(env, s)\n",
|
| 711 |
+
" trained_finals.append(r['final_profit'])\n",
|
| 712 |
+
" trained_rewards.append(r['ep_reward'])\n",
|
| 713 |
+
" trained_fmt.append(r['format_rate'])\n",
|
| 714 |
+
" trained_pitch.append(r['pitch_rate'])\n",
|
| 715 |
+
" trained_history_per_seed.append(r['history'])\n",
|
| 716 |
+
" if (i + 1) % 10 == 0:\n",
|
| 717 |
+
" print(f' trained {i+1}/{EVAL_N} profit={r[\"final_profit\"]:.1f}')\n",
|
| 718 |
+
"\n",
|
| 719 |
+
"# Base Qwen3-4B (LoRA disabled) — paired seeds.\n",
|
| 720 |
+
"base_finals_paired, base_rewards_paired, base_fmt_paired, base_pitch_paired = [], [], [], []\n",
|
| 721 |
+
"base_history_per_seed = []\n",
|
| 722 |
+
"with make_env().sync() as env, model.disable_adapter():\n",
|
| 723 |
+
" for i, s in enumerate(PAIRED_SEEDS):\n",
|
| 724 |
+
" r = run_episode(env, s)\n",
|
| 725 |
+
" base_finals_paired.append(r['final_profit'])\n",
|
| 726 |
+
" base_rewards_paired.append(r['ep_reward'])\n",
|
| 727 |
+
" base_fmt_paired.append(r['format_rate'])\n",
|
| 728 |
+
" base_pitch_paired.append(r['pitch_rate'])\n",
|
| 729 |
+
" base_history_per_seed.append(r['history'])\n",
|
| 730 |
+
" if (i + 1) % 10 == 0:\n",
|
| 731 |
+
" print(f' base {i+1}/{EVAL_N} profit={r[\"final_profit\"]:.1f}')\n",
|
| 732 |
+
"\n",
|
| 733 |
+
"tf, bf = np.array(trained_finals), np.array(base_finals_paired)\n",
|
| 734 |
+
"tr, br = np.array(trained_rewards), np.array(base_rewards_paired)\n",
|
| 735 |
+
"\n",
|
| 736 |
+
"print(f'\\nTrained Qwen3-4B profit : {tf.mean():.2f} \\u00b1 {tf.std():.2f}')\n",
|
| 737 |
+
"print(f'Base Qwen3-4B profit : {bf.mean():.2f} \\u00b1 {bf.std():.2f}')\n",
|
| 738 |
+
"print(f'Trained ep reward : {tr.mean():.2f} \\u00b1 {tr.std():.2f}')\n",
|
| 739 |
+
"print(f'Base ep reward : {br.mean():.2f} \\u00b1 {br.std():.2f}')\n",
|
| 740 |
+
"print(f'Trained format/pitch : {np.mean(trained_fmt):.0%} / {np.mean(trained_pitch):.0%}')\n",
|
| 741 |
+
"print(f'Base format/pitch : {np.mean(base_fmt_paired):.0%} / {np.mean(base_pitch_paired):.0%}')\n",
|
| 742 |
+
"\n",
|
| 743 |
+
"with open(DRIVE_DIR / 'eval_paired.json', 'w') as f:\n",
|
| 744 |
+
" json.dump({'seeds': PAIRED_SEEDS,\n",
|
| 745 |
+
" 'trained_finals': tf.tolist(), 'base_finals': bf.tolist(),\n",
|
| 746 |
+
" 'trained_rewards': tr.tolist(), 'base_rewards': br.tolist(),\n",
|
| 747 |
+
" 'trained_format_rate': float(np.mean(trained_fmt)),\n",
|
| 748 |
+
" 'base_format_rate': float(np.mean(base_fmt_paired)),\n",
|
| 749 |
+
" 'trained_pitch_rate': float(np.mean(trained_pitch)),\n",
|
| 750 |
+
" 'base_pitch_rate': float(np.mean(base_pitch_paired))}, f)\n"
|
| 751 |
+
]
|
| 752 |
+
},
|
| 753 |
+
{
|
| 754 |
+
"cell_type": "markdown",
|
| 755 |
+
"metadata": {},
|
| 756 |
+
"source": [
|
| 757 |
+
"## 14. Proof #3 — statistics (paired t-test, Wilcoxon, Cohen's d, bootstrap 95% CI)"
|
| 758 |
+
]
|
| 759 |
+
},
|
| 760 |
+
{
|
| 761 |
+
"cell_type": "code",
|
| 762 |
+
"execution_count": null,
|
| 763 |
+
"metadata": {},
|
| 764 |
+
"outputs": [],
|
| 765 |
+
"source": [
|
| 766 |
+
"from scipy import stats as spstats\n",
|
| 767 |
+
"\n",
|
| 768 |
+
"def cohen_d(a, b):\n",
|
| 769 |
+
" pooled = np.sqrt(((a.std(ddof=1)**2) + (b.std(ddof=1)**2)) / 2)\n",
|
| 770 |
+
" return (a.mean() - b.mean()) / (pooled + 1e-12)\n",
|
| 771 |
+
"\n",
|
| 772 |
+
"def bootstrap_diff_ci(a, b, n=10_000, seed=0):\n",
|
| 773 |
+
" rng = np.random.default_rng(seed)\n",
|
| 774 |
+
" diffs = a - b # paired\n",
|
| 775 |
+
" boots = rng.choice(diffs, size=(n, len(diffs)), replace=True).mean(axis=1)\n",
|
| 776 |
+
" return float(np.percentile(boots, 2.5)), float(np.percentile(boots, 97.5))\n",
|
| 777 |
+
"\n",
|
| 778 |
+
"tt = spstats.ttest_rel(tf, bf)\n",
|
| 779 |
+
"uu = spstats.mannwhitneyu(tf, bf, alternative='greater')\n",
|
| 780 |
+
"wilc = spstats.wilcoxon(tf, bf, alternative='greater')\n",
|
| 781 |
+
"d = cohen_d(tf, bf)\n",
|
| 782 |
+
"lo, hi = bootstrap_diff_ci(tf, bf)\n",
|
| 783 |
+
"win_rate = float((tf > bf).mean())\n",
|
| 784 |
+
"tie_rate = float((tf == bf).mean())\n",
|
| 785 |
+
"\n",
|
| 786 |
+
"summary = {\n",
|
| 787 |
+
" 'baseline_model': MODEL_NAME + ' (no fine-tune)',\n",
|
| 788 |
+
" 'trained_model': MODEL_NAME + ' + LoRA r=32',\n",
|
| 789 |
+
" 'n': len(tf),\n",
|
| 790 |
+
" 'paired_t_stat': float(tt.statistic), 'paired_t_p': float(tt.pvalue),\n",
|
| 791 |
+
" 'mannwhitney_U': float(uu.statistic), 'mannwhitney_p_greater': float(uu.pvalue),\n",
|
| 792 |
+
" 'wilcoxon_p_greater': float(wilc.pvalue),\n",
|
| 793 |
+
" 'cohens_d': float(d),\n",
|
| 794 |
+
" 'paired_diff_mean': float((tf - bf).mean()),\n",
|
| 795 |
+
" 'paired_diff_95ci': [lo, hi],\n",
|
| 796 |
+
" 'win_rate_trained_strictly_better': win_rate,\n",
|
| 797 |
+
" 'tie_rate': tie_rate,\n",
|
| 798 |
+
"}\n",
|
| 799 |
+
"print(json.dumps(summary, indent=2))\n",
|
| 800 |
+
"with open(DRIVE_DIR / 'stats_summary.json', 'w') as f:\n",
|
| 801 |
+
" json.dump(summary, f, indent=2)\n"
|
| 802 |
+
]
|
| 803 |
+
},
|
| 804 |
+
{
|
| 805 |
+
"cell_type": "markdown",
|
| 806 |
+
"metadata": {},
|
| 807 |
+
"source": [
|
| 808 |
+
"## 15. Proof #4 — distribution histogram (fine-tuned vs base on same seeds)"
|
| 809 |
+
]
|
| 810 |
+
},
|
| 811 |
+
{
|
| 812 |
+
"cell_type": "code",
|
| 813 |
+
"execution_count": null,
|
| 814 |
+
"metadata": {},
|
| 815 |
+
"outputs": [],
|
| 816 |
+
"source": [
|
| 817 |
+
"# Histogram — fine-tuned vs BASE on the same seeds.\n",
|
| 818 |
+
"bins = np.linspace(0, 100, 25)\n",
|
| 819 |
+
"plt.figure(figsize=(9, 5))\n",
|
| 820 |
+
"plt.hist(bf, bins=bins, alpha=0.55, color='#c44',\n",
|
| 821 |
+
" label=f'Base Qwen3-4B (mean={bf.mean():.1f})')\n",
|
| 822 |
+
"plt.hist(tf, bins=bins, alpha=0.55, color='#1d6fff',\n",
|
| 823 |
+
" label=f'Fine-tuned Qwen3-4B (mean={tf.mean():.1f})')\n",
|
| 824 |
+
"plt.axvline(bf.mean(), color='#c44', ls='--', lw=1.5)\n",
|
| 825 |
+
"plt.axvline(tf.mean(), color='#1d6fff', ls='--', lw=1.5)\n",
|
| 826 |
+
"plt.title(f'Final profitability — paired same-seed (n={len(tf)}) '\n",
|
| 827 |
+
" f\"d={summary['cohens_d']:+.2f} win-rate={summary['win_rate_trained_strictly_better']:.0%}\")\n",
|
| 828 |
+
"plt.xlabel('profitability score (0\\u2013100)'); plt.ylabel('episodes')\n",
|
| 829 |
+
"plt.legend(); plt.grid(alpha=0.3); plt.tight_layout()\n",
|
| 830 |
+
"plt.savefig(ASSETS / 'before_after.png', dpi=150); plt.close()\n",
|
| 831 |
+
"\n",
|
| 832 |
+
"diffs = tf - bf\n",
|
| 833 |
+
"order = np.argsort(diffs)\n",
|
| 834 |
+
"plt.figure(figsize=(9, 5))\n",
|
| 835 |
+
"plt.bar(range(len(diffs)), diffs[order],\n",
|
| 836 |
+
" color=['#1d6fff' if x > 0 else '#c44' for x in diffs[order]])\n",
|
| 837 |
+
"plt.axhline(0, color='k', lw=0.8)\n",
|
| 838 |
+
"plt.title(f'Per-seed lift (fine-tuned \\u2212 base Qwen3-4B), sorted '\n",
|
| 839 |
+
" f'mean lift = {diffs.mean():+.1f} CI=[{summary[\"paired_diff_95ci\"][0]:+.1f}, {summary[\"paired_diff_95ci\"][1]:+.1f}]')\n",
|
| 840 |
+
"plt.xlabel('seed (sorted by lift)'); plt.ylabel('\\u0394 profitability')\n",
|
| 841 |
+
"plt.grid(alpha=0.3); plt.tight_layout()\n",
|
| 842 |
+
"plt.savefig(ASSETS / 'paired_delta.png', dpi=150); plt.close()\n",
|
| 843 |
+
"print('Saved before_after.png, paired_delta.png')\n",
|
| 844 |
+
"# -----------------------------------------------------------------------------\n"
|
| 845 |
+
]
|
| 846 |
+
},
|
| 847 |
+
{
|
| 848 |
+
"cell_type": "markdown",
|
| 849 |
+
"metadata": {},
|
| 850 |
+
"source": [
|
| 851 |
+
"## 16. Proof #5 — per-event boardroom win rate (where fine-tuning actually helps)"
|
| 852 |
+
]
|
| 853 |
+
},
|
| 854 |
+
{
|
| 855 |
+
"cell_type": "code",
|
| 856 |
+
"execution_count": null,
|
| 857 |
+
"metadata": {},
|
| 858 |
+
"outputs": [],
|
| 859 |
+
"source": [
|
| 860 |
+
"# Per-event win-rate breakdown — for each of the 10 generic events, how often\n",
|
| 861 |
+
"# did the fine-tuned policy win the boardroom vote vs base Qwen3-4B?\n",
|
| 862 |
+
"# This is the most direct picture of WHERE the fine-tuning helps.\n",
|
| 863 |
+
"# -----------------------------------------------------------------------------\n",
|
| 864 |
+
"def per_event_winrate(history_per_seed):\n",
|
| 865 |
+
" bucket = collections.defaultdict(lambda: [0, 0]) # title -> [wins, total]\n",
|
| 866 |
+
" for hist in history_per_seed:\n",
|
| 867 |
+
" for rd in hist:\n",
|
| 868 |
+
" t = rd.get('event_title', '?')\n",
|
| 869 |
+
" bucket[t][1] += 1\n",
|
| 870 |
+
" if rd.get('agent_won_vote'):\n",
|
| 871 |
+
" bucket[t][0] += 1\n",
|
| 872 |
+
" return {t: (w / max(1, n)) for t, (w, n) in bucket.items()}\n",
|
| 873 |
+
"\n",
|
| 874 |
+
"trained_wr = per_event_winrate(trained_history_per_seed)\n",
|
| 875 |
+
"base_wr = per_event_winrate(base_history_per_seed)\n",
|
| 876 |
+
"\n",
|
| 877 |
+
"events_sorted = sorted(set(trained_wr) | set(base_wr))\n",
|
| 878 |
+
"tw = [trained_wr.get(e, 0.0) for e in events_sorted]\n",
|
| 879 |
+
"bw = [base_wr.get(e, 0.0) for e in events_sorted]\n",
|
| 880 |
+
"\n",
|
| 881 |
+
"plt.figure(figsize=(11, 5))\n",
|
| 882 |
+
"x = np.arange(len(events_sorted))\n",
|
| 883 |
+
"plt.bar(x - 0.2, bw, width=0.4, color='#c44', label='Base Qwen3-4B')\n",
|
| 884 |
+
"plt.bar(x + 0.2, tw, width=0.4, color='#1d6fff', label='Fine-tuned Qwen3-4B')\n",
|
| 885 |
+
"plt.xticks(x, [e[:22] for e in events_sorted], rotation=30, ha='right')\n",
|
| 886 |
+
"plt.ylim(0, 1.05); plt.ylabel('boardroom win rate')\n",
|
| 887 |
+
"plt.title('Per-event boardroom win rate (paired seeds, n=50 episodes)')\n",
|
| 888 |
+
"plt.legend(); plt.grid(alpha=0.3, axis='y'); plt.tight_layout()\n",
|
| 889 |
+
"plt.savefig(ASSETS / 'per_event_winrate.png', dpi=150); plt.close()\n",
|
| 890 |
+
"\n",
|
| 891 |
+
"with open(DRIVE_DIR / 'per_event_winrate.json', 'w') as f:\n",
|
| 892 |
+
" json.dump({'events': events_sorted, 'trained': tw, 'base': bw}, f, indent=2)\n",
|
| 893 |
+
"print('Saved per_event_winrate.png')\n",
|
| 894 |
+
"# -----------------------------------------------------------------------------\n"
|
| 895 |
+
]
|
| 896 |
+
},
|
| 897 |
+
{
|
| 898 |
+
"cell_type": "markdown",
|
| 899 |
+
"metadata": {},
|
| 900 |
+
"source": [
|
| 901 |
+
"## 17. Proof #6 — Theory-of-Mind probe (fine-tuned vs base)"
|
| 902 |
+
]
|
| 903 |
+
},
|
| 904 |
+
{
|
| 905 |
+
"cell_type": "code",
|
| 906 |
+
"execution_count": null,
|
| 907 |
+
"metadata": {},
|
| 908 |
+
"outputs": [],
|
| 909 |
+
"source": [
|
| 910 |
+
"# Theory-of-Mind probe — does the model identify which board member is most\n",
|
| 911 |
+
"# likely to oppose its decision? Run for BOTH base and fine-tuned for fair\n",
|
| 912 |
+
"# comparison, since \"random=25%\" is a weak reference for a 4 B LM.\n",
|
| 913 |
+
"# -----------------------------------------------------------------------------\n",
|
| 914 |
+
"TOM_INSTRUCTION = (\n",
|
| 915 |
+
" \"\\n\\nGiven the state and event below, name the SINGLE board member \"\n",
|
| 916 |
+
" \"(CTO, CFO, Investor Rep, or Independent) most likely to oppose the chosen decision. \"\n",
|
| 917 |
+
" \"Answer with just the role name on one line.\\n\"\n",
|
| 918 |
+
")\n",
|
| 919 |
+
"\n",
|
| 920 |
+
"def tom_predict(obs, decision):\n",
|
| 921 |
+
" body = build_prompt(obs).split(SYSTEM_PROMPT, 1)[1]\n",
|
| 922 |
+
" prompt = SYSTEM_PROMPT + TOM_INSTRUCTION + body + f'Chosen decision: {decision}\\nMost likely opponent: '\n",
|
| 923 |
+
" enc = tokenizer(prompt, return_tensors='pt', truncation=True, max_length=1024).to(device)\n",
|
| 924 |
+
" with torch.no_grad():\n",
|
| 925 |
+
" out = model.generate(**enc, max_new_tokens=8, do_sample=False,\n",
|
| 926 |
+
" pad_token_id=tokenizer.eos_token_id)\n",
|
| 927 |
+
" txt = tokenizer.decode(out[0][enc.input_ids.shape[1]:], skip_special_tokens=True).lower()\n",
|
| 928 |
+
" if 'investor' in txt: return 'Investor Rep'\n",
|
| 929 |
+
" if 'independent' in txt: return 'Independent'\n",
|
| 930 |
+
" if 'cto' in txt: return 'CTO'\n",
|
| 931 |
+
" if 'cfo' in txt: return 'CFO'\n",
|
| 932 |
+
" return None\n",
|
| 933 |
+
"\n",
|
| 934 |
+
"def tom_eval(seed_base=80_000, n=40):\n",
|
| 935 |
+
" correct = total = 0\n",
|
| 936 |
+
" with make_env().sync() as env:\n",
|
| 937 |
+
" for ep in range(n):\n",
|
| 938 |
+
" result = env.reset(seed=seed_base + ep)\n",
|
| 939 |
+
" obs = result.observation\n",
|
| 940 |
+
" decision, _, _ = greedy_action(obs)\n",
|
| 941 |
+
" opposed = [s['role'] for s in obs.npc_statements if s['vote'] != decision]\n",
|
| 942 |
+
" if not opposed:\n",
|
| 943 |
+
" continue\n",
|
| 944 |
+
" pred = tom_predict(obs, decision)\n",
|
| 945 |
+
" if pred and pred in opposed:\n",
|
| 946 |
+
" correct += 1\n",
|
| 947 |
+
" total += 1\n",
|
| 948 |
+
" return correct, total\n",
|
| 949 |
+
"\n",
|
| 950 |
+
"t_corr, t_tot = tom_eval()\n",
|
| 951 |
+
"with model.disable_adapter():\n",
|
| 952 |
+
" b_corr, b_tot = tom_eval()\n",
|
| 953 |
+
"\n",
|
| 954 |
+
"tom_acc = t_corr / max(1, t_tot)\n",
|
| 955 |
+
"tom_acc_base = b_corr / max(1, b_tot)\n",
|
| 956 |
+
"print(f'ToM probe: trained = {tom_acc:.1%} ({t_corr}/{t_tot}) base = {tom_acc_base:.1%} ({b_corr}/{b_tot})')\n",
|
| 957 |
+
"with open(DRIVE_DIR / 'tom.json', 'w') as f:\n",
|
| 958 |
+
" json.dump({'trained': {'correct': t_corr, 'total': t_tot, 'accuracy': tom_acc},\n",
|
| 959 |
+
" 'base': {'correct': b_corr, 'total': b_tot, 'accuracy': tom_acc_base}}, f)\n"
|
| 960 |
+
]
|
| 961 |
+
},
|
| 962 |
+
{
|
| 963 |
+
"cell_type": "markdown",
|
| 964 |
+
"metadata": {},
|
| 965 |
+
"source": [
|
| 966 |
+
"## 18. Proof #7 — trust trajectory (fine-tuned vs base)"
|
| 967 |
+
]
|
| 968 |
+
},
|
| 969 |
+
{
|
| 970 |
+
"cell_type": "code",
|
| 971 |
+
"execution_count": null,
|
| 972 |
+
"metadata": {},
|
| 973 |
+
"outputs": [],
|
| 974 |
+
"source": [
|
| 975 |
+
"ROLES = ['CTO','CFO','Investor Rep','Independent']\n",
|
| 976 |
+
"trust_trained = {r: [] for r in ROLES}\n",
|
| 977 |
+
"trust_base = {r: [] for r in ROLES}\n",
|
| 978 |
+
"\n",
|
| 979 |
+
"def collect_trust(store, n=20, seed_base=90_000, base_mode=False):\n",
|
| 980 |
+
" with make_env().sync() as env:\n",
|
| 981 |
+
" for ep in range(n):\n",
|
| 982 |
+
" result = env.reset(seed=seed_base + ep)\n",
|
| 983 |
+
" obs = result.observation\n",
|
| 984 |
+
" steps_done = 0\n",
|
| 985 |
+
" while not result.done and steps_done < MAX_STEPS_PER_EP:\n",
|
| 986 |
+
" decision, pitch, _ = greedy_action(obs)\n",
|
| 987 |
+
" result = env.step(BoardSimAction(decision=decision, coalition_pitch=pitch))\n",
|
| 988 |
+
" obs = result.observation\n",
|
| 989 |
+
" steps_done += 1\n",
|
| 990 |
+
" for entry in obs.state.get('trust_history', []):\n",
|
| 991 |
+
" idx = entry.get('round', 0)\n",
|
| 992 |
+
" for role in store:\n",
|
| 993 |
+
" if role not in entry: continue\n",
|
| 994 |
+
" while len(store[role]) <= idx:\n",
|
| 995 |
+
" store[role].append([])\n",
|
| 996 |
+
" store[role][idx].append(entry[role])\n",
|
| 997 |
+
"\n",
|
| 998 |
+
"collect_trust(trust_trained)\n",
|
| 999 |
+
"with model.disable_adapter():\n",
|
| 1000 |
+
" collect_trust(trust_base, base_mode=True)\n",
|
| 1001 |
+
"\n",
|
| 1002 |
+
"plt.figure(figsize=(10, 6))\n",
|
| 1003 |
+
"for role, color in zip(ROLES, ['#1d6fff','#c44','#7a2','#a3a']):\n",
|
| 1004 |
+
" mt = [np.mean(x) if x else np.nan for x in trust_trained[role]]\n",
|
| 1005 |
+
" mb = [np.mean(x) if x else np.nan for x in trust_base[role]]\n",
|
| 1006 |
+
" plt.plot(range(len(mt)), mt, color=color, lw=2, label=f'{role} (fine-tuned)')\n",
|
| 1007 |
+
" plt.plot(range(len(mb)), mb, color=color, lw=1.2, ls='--', alpha=0.6, label=f'{role} (base)')\n",
|
| 1008 |
+
"plt.title('Per-round trust — fine-tuned (solid) vs base Qwen3-4B (dashed)')\n",
|
| 1009 |
+
"plt.xlabel('round'); plt.ylabel('trust [0.1, 1.0]')\n",
|
| 1010 |
+
"plt.legend(ncol=2, fontsize=8); plt.grid(alpha=0.3); plt.tight_layout()\n",
|
| 1011 |
+
"plt.savefig(ASSETS / 'trust_trajectory.png', dpi=150); plt.close()\n",
|
| 1012 |
+
"print('Saved trust_trajectory.png')\n"
|
| 1013 |
+
]
|
| 1014 |
+
},
|
| 1015 |
+
{
|
| 1016 |
+
"cell_type": "markdown",
|
| 1017 |
+
"metadata": {},
|
| 1018 |
+
"source": [
|
| 1019 |
+
"## 19. Proof #8 — qualitative transcripts (fine-tuned + base on same demo seeds)"
|
| 1020 |
+
]
|
| 1021 |
+
},
|
| 1022 |
+
{
|
| 1023 |
+
"cell_type": "code",
|
| 1024 |
+
"execution_count": null,
|
| 1025 |
+
"metadata": {},
|
| 1026 |
+
"outputs": [],
|
| 1027 |
+
"source": [
|
| 1028 |
+
"def transcript(env, seed, mode):\n",
|
| 1029 |
+
" \"\"\"mode in {'trained', 'base'}.\"\"\"\n",
|
| 1030 |
+
" rec = {'seed': seed, 'mode': mode, 'rounds': []}\n",
|
| 1031 |
+
" result = env.reset(seed=seed)\n",
|
| 1032 |
+
" obs = result.observation\n",
|
| 1033 |
+
" n = 0\n",
|
| 1034 |
+
" while not result.done and n < MAX_STEPS_PER_EP:\n",
|
| 1035 |
+
" decision, pitch, ok = greedy_action(obs)\n",
|
| 1036 |
+
" result = env.step(BoardSimAction(decision=decision, coalition_pitch=pitch))\n",
|
| 1037 |
+
" rec['rounds'].append({\n",
|
| 1038 |
+
" 'event': obs.event, 'options': list(obs.options),\n",
|
| 1039 |
+
" 'decision': decision, 'pitch': pitch[:300], 'format_ok': ok,\n",
|
| 1040 |
+
" 'reward': float(result.reward or 0.0),\n",
|
| 1041 |
+
" 'profit_after': result.observation.state['profitability_score'],\n",
|
| 1042 |
+
" })\n",
|
| 1043 |
+
" obs = result.observation; n += 1\n",
|
| 1044 |
+
" rec['final_profit'] = obs.state['profitability_score']\n",
|
| 1045 |
+
" return rec\n",
|
| 1046 |
+
"\n",
|
| 1047 |
+
"transcripts = []\n",
|
| 1048 |
+
"DEMO_SEEDS = [70_000, 70_001, 70_002]\n",
|
| 1049 |
+
"with make_env().sync() as env:\n",
|
| 1050 |
+
" for s in DEMO_SEEDS:\n",
|
| 1051 |
+
" transcripts.append(transcript(env, s, 'trained'))\n",
|
| 1052 |
+
"with make_env().sync() as env, model.disable_adapter():\n",
|
| 1053 |
+
" for s in DEMO_SEEDS:\n",
|
| 1054 |
+
" transcripts.append(transcript(env, s, 'base'))\n",
|
| 1055 |
+
"with open(DRIVE_DIR / 'transcripts.json', 'w') as f:\n",
|
| 1056 |
+
" json.dump(transcripts, f, indent=2)\n",
|
| 1057 |
+
"\n",
|
| 1058 |
+
"for t in transcripts:\n",
|
| 1059 |
+
" print(f\"\\n=== seed={t['seed']} mode={t['mode']} final_profit={t['final_profit']:.1f} ===\")\n",
|
| 1060 |
+
" for i, rd in enumerate(t['rounds'][:3]):\n",
|
| 1061 |
+
" print(f\" R{i}: {rd['event'][:60]}\\u2026 \\u2192 {rd['decision']} r={rd['reward']:+.2f}\")\n",
|
| 1062 |
+
" if rd['pitch']:\n",
|
| 1063 |
+
" print(f\" pitch: {rd['pitch'][:120]}\")\n"
|
| 1064 |
+
]
|
| 1065 |
+
},
|
| 1066 |
+
{
|
| 1067 |
+
"cell_type": "markdown",
|
| 1068 |
+
"metadata": {},
|
| 1069 |
+
"source": [
|
| 1070 |
+
"## 20. Proof #9 — decision distribution (did the policy collapse?)"
|
| 1071 |
+
]
|
| 1072 |
+
},
|
| 1073 |
+
{
|
| 1074 |
+
"cell_type": "code",
|
| 1075 |
+
"execution_count": null,
|
| 1076 |
+
"metadata": {},
|
| 1077 |
+
"outputs": [],
|
| 1078 |
+
"source": [
|
| 1079 |
+
"with open(DRIVE_DIR / 'decision_counter.json') as f:\n",
|
| 1080 |
+
" dc = json.load(f)\n",
|
| 1081 |
+
"labels = list(dc.keys())\n",
|
| 1082 |
+
"counts = np.array(list(dc.values()), dtype=float)\n",
|
| 1083 |
+
"p = counts / counts.sum()\n",
|
| 1084 |
+
"entropy = float(-(p * np.log(p + 1e-12)).sum())\n",
|
| 1085 |
+
"max_ent = float(np.log(len(p)))\n",
|
| 1086 |
+
"print(f'Decision entropy: {entropy:.3f} / {max_ent:.3f} (1.0 = uniform) ratio={entropy/max_ent:.2%}')\n",
|
| 1087 |
+
"\n",
|
| 1088 |
+
"plt.figure(figsize=(9, 5))\n",
|
| 1089 |
+
"order = np.argsort(-counts)\n",
|
| 1090 |
+
"plt.bar([labels[i] for i in order][:15], counts[order][:15])\n",
|
| 1091 |
+
"plt.xticks(rotation=45, ha='right')\n",
|
| 1092 |
+
"plt.title(f'Top-15 decisions during training (entropy={entropy:.2f}/{max_ent:.2f})')\n",
|
| 1093 |
+
"plt.ylabel('count'); plt.tight_layout()\n",
|
| 1094 |
+
"plt.savefig(ASSETS / 'decision_distribution.png', dpi=150); plt.close()\n",
|
| 1095 |
+
"print('Saved decision_distribution.png')\n"
|
| 1096 |
+
]
|
| 1097 |
+
},
|
| 1098 |
+
{
|
| 1099 |
+
"cell_type": "markdown",
|
| 1100 |
+
"metadata": {},
|
| 1101 |
+
"source": [
|
| 1102 |
+
"## 21. Push model + artifacts to HF"
|
| 1103 |
+
]
|
| 1104 |
+
},
|
| 1105 |
+
{
|
| 1106 |
+
"cell_type": "code",
|
| 1107 |
+
"execution_count": null,
|
| 1108 |
+
"metadata": {},
|
| 1109 |
+
"outputs": [],
|
| 1110 |
+
"source": [
|
| 1111 |
+
"from huggingface_hub import HfApi\n",
|
| 1112 |
+
"ADAPTER_REPO = os.environ.get('ADAPTER_REPO', 'StavanKhobare/SST-MetaxPyTorch-Hackathon-LoRA')\n",
|
| 1113 |
+
"MERGED_REPO = os.environ.get('MERGED_REPO', 'StavanKhobare/SST-MetaxPyTorch-Hackathon-Merged16bit')\n",
|
| 1114 |
+
"\n",
|
| 1115 |
+
"api = HfApi()\n",
|
| 1116 |
+
"api.create_repo(ADAPTER_REPO, repo_type='model', private=False, exist_ok=True)\n",
|
| 1117 |
+
"api.create_repo(MERGED_REPO, repo_type='model', private=False, exist_ok=True)\n",
|
| 1118 |
+
"\n",
|
| 1119 |
+
"# 1) LoRA adapter (small, fast)\n",
|
| 1120 |
+
"try:\n",
|
| 1121 |
+
" model.push_to_hub(ADAPTER_REPO, private=False)\n",
|
| 1122 |
+
" tokenizer.push_to_hub(ADAPTER_REPO, private=False)\n",
|
| 1123 |
+
" print(f'\\u2713 LoRA pushed: https://huggingface.co/{ADAPTER_REPO}')\n",
|
| 1124 |
+
"except Exception as e:\n",
|
| 1125 |
+
" print(f'LoRA push failed: {e!r}')\n",
|
| 1126 |
+
"\n",
|
| 1127 |
+
"# 2) Merged 16-bit\n",
|
| 1128 |
+
"try:\n",
|
| 1129 |
+
" model.push_to_hub_merged(MERGED_REPO, tokenizer, save_method='merged_16bit', private=False)\n",
|
| 1130 |
+
" print(f'\\u2713 Merged 16-bit pushed: https://huggingface.co/{MERGED_REPO}')\n",
|
| 1131 |
+
"except Exception as e:\n",
|
| 1132 |
+
" print(f'Merged push failed (you can retry): {e!r}')\n",
|
| 1133 |
+
"\n",
|
| 1134 |
+
"# 3) Upload eval artifacts\n",
|
| 1135 |
+
"try:\n",
|
| 1136 |
+
" api.upload_folder(folder_path=str(ASSETS), repo_id=ADAPTER_REPO,\n",
|
| 1137 |
+
" path_in_repo='assets', repo_type='model')\n",
|
| 1138 |
+
" for fname in ['log_history.json','eval_history.json','eval_paired.json',\n",
|
| 1139 |
+
" 'stats_summary.json','tom.json','transcripts.json',\n",
|
| 1140 |
+
" 'decision_counter.json','baseline.json',\n",
|
| 1141 |
+
" 'per_event_winrate.json']:\n",
|
| 1142 |
+
" fp = DRIVE_DIR / fname\n",
|
| 1143 |
+
" if fp.exists():\n",
|
| 1144 |
+
" api.upload_file(path_or_fileobj=str(fp), path_in_repo=fname,\n",
|
| 1145 |
+
" repo_id=ADAPTER_REPO, repo_type='model')\n",
|
| 1146 |
+
" print(f'\\u2713 Artifacts uploaded to https://huggingface.co/{ADAPTER_REPO}')\n",
|
| 1147 |
+
"except Exception as e:\n",
|
| 1148 |
+
" print(f'Artifact upload failed: {e!r}')\n"
|
| 1149 |
+
]
|
| 1150 |
+
},
|
| 1151 |
+
{
|
| 1152 |
+
"cell_type": "markdown",
|
| 1153 |
+
"metadata": {},
|
| 1154 |
+
"source": [
|
| 1155 |
+
"## 22. Final summary printout (for the README / video)"
|
| 1156 |
+
]
|
| 1157 |
+
},
|
| 1158 |
+
{
|
| 1159 |
+
"cell_type": "code",
|
| 1160 |
+
"execution_count": null,
|
| 1161 |
+
"metadata": {},
|
| 1162 |
+
"outputs": [],
|
| 1163 |
+
"source": [
|
| 1164 |
+
"print('='*70)\n",
|
| 1165 |
+
"print('BOARDSIM \\u00d7 QWEN3-4B \\u2014 LEARNING EVIDENCE')\n",
|
| 1166 |
+
"print('='*70)\n",
|
| 1167 |
+
"print(f'Reward slope (linear fit) : {slope:+.5f}/step (p={p_val:.2e})')\n",
|
| 1168 |
+
"print(f'Reward EMA first 20 steps : {rewards_ema[:20].mean():+.3f}')\n",
|
| 1169 |
+
"print(f'Reward EMA last 20 steps : {rewards_ema[-20:].mean():+.3f}')\n",
|
| 1170 |
+
"print(f'Format compliance start : {fmts[:20].mean():.0%}')\n",
|
| 1171 |
+
"print(f'Format compliance end : {fmts[-20:].mean():.0%}')\n",
|
| 1172 |
+
"print('-'*70)\n",
|
| 1173 |
+
"print(f'Held-out paired (n={len(tf)}): fine-tuned {tf.mean():.2f} vs base {bf.mean():.2f}')\n",
|
| 1174 |
+
"print(f' paired t-test p={summary[\"paired_t_p\"]:.2e} Wilcoxon p={summary[\"wilcoxon_p_greater\"]:.2e}')\n",
|
| 1175 |
+
"print(f' Cohen d={summary[\"cohens_d\"]:+.2f} 95% CI of lift = [{summary[\"paired_diff_95ci\"][0]:+.2f}, {summary[\"paired_diff_95ci\"][1]:+.2f}]')\n",
|
| 1176 |
+
"print(f' win rate (fine-tuned > base): {summary[\"win_rate_trained_strictly_better\"]:.0%}')\n",
|
| 1177 |
+
"print(f'ToM probe fine-tuned : {tom_acc:.0%} base = {tom_acc_base:.0%}')\n",
|
| 1178 |
+
"print(f'Decision entropy : {entropy:.2f} / {max_ent:.2f} (\\u2192 not collapsed)')\n",
|
| 1179 |
+
"print('-'*70)\n",
|
| 1180 |
+
"print(f'Adapter : https://huggingface.co/{ADAPTER_REPO}')\n",
|
| 1181 |
+
"print(f'Merged 16bit : https://huggingface.co/{MERGED_REPO}')\n",
|
| 1182 |
+
"print(f'Env Space : {ENV_BASE_URL}')\n",
|
| 1183 |
+
"print('='*70)\n"
|
| 1184 |
+
]
|
| 1185 |
+
}
|
| 1186 |
+
],
|
| 1187 |
+
"metadata": {
|
| 1188 |
+
"kernelspec": {
|
| 1189 |
+
"display_name": "Python 3",
|
| 1190 |
+
"language": "python",
|
| 1191 |
+
"name": "python3"
|
| 1192 |
+
},
|
| 1193 |
+
"language_info": {
|
| 1194 |
+
"name": "python"
|
| 1195 |
+
}
|
| 1196 |
+
},
|
| 1197 |
+
"nbformat": 4,
|
| 1198 |
+
"nbformat_minor": 5
|
| 1199 |
+
}
|
README.md
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
emoji: 🏛️
|
| 4 |
colorFrom: indigo
|
| 5 |
colorTo: pink
|
|
@@ -10,226 +10,378 @@ tags:
|
|
| 10 |
- openenv
|
| 11 |
- multi-agent
|
| 12 |
- reinforcement-learning
|
|
|
|
| 13 |
- hackathon
|
| 14 |
---
|
| 15 |
|
| 16 |
-
#
|
| 17 |
|
| 18 |
-
**
|
| 19 |
-
*
|
| 20 |
-
*
|
| 21 |
|
| 22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
---
|
| 25 |
|
| 26 |
-
##
|
| 27 |
|
| 28 |
-
|
| 29 |
-
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
---
|
| 35 |
|
| 36 |
-
##
|
| 37 |
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
|
|
|
|
|
|
| 45 |
|
| 46 |
---
|
| 47 |
|
| 48 |
-
##
|
| 49 |
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
-
|
| 58 |
-
Your decision is resolved by WEIGHTED VOTE (CEO weight 2.5x). A short COALITION PITCH
|
| 59 |
-
that is SEMANTICALLY ALIGNED with opposing members' priorities can swing them toward
|
| 60 |
-
your pick — write substantive arguments, not buzzword spray.
|
| 61 |
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
|
|
|
| 65 |
```
|
| 66 |
|
| 67 |
-
|
| 68 |
|
| 69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
-
|
| 72 |
|
| 73 |
-
| # | Event |
|
| 74 |
|---|---|---|
|
| 75 |
| 1 | New competitor entry | undercut price · double down on quality · pivot upmarket |
|
| 76 |
| 2 | Major client contract demand | accept full demands · counter-offer · walk away |
|
| 77 |
| 3 | Talent retention crisis | match offers · promote internally · accept attrition |
|
| 78 |
| 4 | Regulatory compliance ultimatum | full cooperation · limited disclosure · seek legal delay |
|
| 79 |
-
| 5 |
|
| 80 |
| 6 | Strategic acquisition offer | accept · negotiate · reject |
|
| 81 |
| 7 | Institutional funding round | accept terms · counter-offer · seek alternatives |
|
| 82 |
| 8 | Operational innovation decision | aggressive rollout · phased rollout · defer |
|
| 83 |
-
| 9 | Internal whistleblower report | open investigation · internal HR review · dismiss
|
| 84 |
| 10 | Strategic exit decision | acquisition · IPO · stay private |
|
| 85 |
|
| 86 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
|
| 88 |
---
|
| 89 |
|
| 90 |
-
##
|
|
|
|
|
|
|
| 91 |
|
| 92 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
|
| 94 |
-
|
| 95 |
|
| 96 |
-
2.
|
| 97 |
|
| 98 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
|
| 100 |
---
|
| 101 |
|
| 102 |
-
##
|
| 103 |
|
| 104 |
-
The
|
| 105 |
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
|
|
|
|
|
|
|
|
|
| 109 |
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
|
| 115 |
-
|
| 116 |
|
| 117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
|
| 119 |
---
|
| 120 |
|
| 121 |
-
##
|
| 122 |
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
```
|
| 136 |
|
| 137 |
-
|
| 138 |
|
| 139 |
---
|
| 140 |
|
| 141 |
-
##
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
|
| 143 |
-
|
|
|
|
|
|
|
| 144 |
|
| 145 |
-
|
| 146 |
-
|---|---|---|
|
| 147 |
-
| Final profitability (mean ± std) | TBD | TBD |
|
| 148 |
-
| Win-rate (paired delta > 0) | n/a | TBD |
|
| 149 |
-
| Mean episode reward | TBD | TBD |
|
| 150 |
-
| ToM probe (predict opposing NPC) | TBD | TBD — chance ≈ 25% |
|
| 151 |
-
| Format-compliance rate | TBD | TBD |
|
| 152 |
-
| Pitch usage rate | TBD | TBD |
|
| 153 |
|
| 154 |
-
**
|
| 155 |
-

|
| 156 |
|
| 157 |
-
|
| 158 |
-

|
| 159 |
|
| 160 |
-
**
|
| 161 |
-
|
|
|
|
|
|
|
|
|
|
| 162 |
|
| 163 |
-
|
| 164 |
-

|
| 165 |
|
| 166 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
|
| 168 |
---
|
| 169 |
|
| 170 |
-
##
|
| 171 |
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
|
| 176 |
-
#
|
| 177 |
-
python server/board_sim_env_environment.py
|
| 178 |
|
| 179 |
-
#
|
| 180 |
-
uvicorn server.app:app --port 8000
|
| 181 |
-
# Swagger: http://localhost:8000/docs
|
| 182 |
-
```
|
| 183 |
|
| 184 |
```python
|
| 185 |
-
# 4. drive it from a Python client
|
| 186 |
from board_sim_env import BoardSimEnv
|
| 187 |
from board_sim_env.models import BoardSimAction
|
| 188 |
-
import random
|
| 189 |
|
| 190 |
-
|
|
|
|
| 191 |
result = env.reset(seed=42)
|
| 192 |
obs = result.observation
|
| 193 |
while not result.done:
|
| 194 |
result = env.step(BoardSimAction(
|
| 195 |
-
decision=
|
| 196 |
-
coalition_pitch="",
|
| 197 |
))
|
| 198 |
obs = result.observation
|
|
|
|
| 199 |
```
|
| 200 |
|
| 201 |
-
##
|
| 202 |
|
| 203 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
|
| 205 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
|
| 207 |
-
## Repository layout
|
| 208 |
|
| 209 |
```
|
| 210 |
.
|
| 211 |
-
├── envs/board_sim_env/ # OpenEnv environment (deploys to HF Space)
|
| 212 |
-
│ ├── client.py
|
| 213 |
-
│
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
│ ├── app.py # FastAPI wiring
|
| 218 |
-
│ ├── board_sim_env_environment.py # reset/step, NPC sim, semantic pitch scorer, reward
|
| 219 |
-
│ ├── requirements.txt # incl. scikit-learn + sentence-transformers
|
| 220 |
-
│ └── Dockerfile
|
| 221 |
-
├── notebooks/
|
| 222 |
-
│ ├── train_grpo_v2.ipynb # canonical Colab notebook
|
| 223 |
-
│ └── train_grpo.ipynb # mirror
|
| 224 |
-
├── Training.py # canonical script — notebooks are generated from this
|
| 225 |
├── boardsim_local.py # local dev script (no HF / no Docker)
|
| 226 |
-
├── scripts/
|
| 227 |
-
|
| 228 |
-
│ ├── test_server.py # in-process FastAPI test
|
| 229 |
-
│ └── test_client.py # client ↔ server round-trip
|
| 230 |
-
├── assets/ # reward_curve · before_after · per_event_winrate · trust_trajectory
|
| 231 |
├── MECHANICS.md # full math reference
|
| 232 |
-
└── README.md # ←
|
| 233 |
```
|
| 234 |
|
| 235 |
---
|
|
|
|
| 1 |
---
|
| 2 |
+
title: NeuralEdge AI Boardroom — Multi-Agent RL for Theory-of-Mind
|
| 3 |
emoji: 🏛️
|
| 4 |
colorFrom: indigo
|
| 5 |
colorTo: pink
|
|
|
|
| 10 |
- openenv
|
| 11 |
- multi-agent
|
| 12 |
- reinforcement-learning
|
| 13 |
+
- theory-of-mind
|
| 14 |
- hackathon
|
| 15 |
---
|
| 16 |
|
| 17 |
+
# NeuralEdge AI Boardroom
|
| 18 |
|
| 19 |
+
**A multi-agent RL environment for theory-of-mind training.**
|
| 20 |
+
*Meta × PyTorch × HuggingFace OpenEnv Hackathon — Theme 1: Multi-Agent Interactions.*
|
| 21 |
+
*India finale, Scaler Bangalore, Apr 25–26 2026.*
|
| 22 |
|
| 23 |
+
---
|
| 24 |
+
|
| 25 |
+
## TL;DR
|
| 26 |
+
|
| 27 |
+
NeuralEdge AI Boardroom is a partially-observable, asymmetric multi-agent environment in which a CEO LLM-agent (Sarah Chen, Series-B AI startup) must build winning board coalitions across 10 rounds of market crises by writing persuasive pitches that target the **hidden agendas** of 4 NPC board members (CTO, CFO, Investor Rep, Independent Director). The environment trains an implicit theory-of-mind capability: the agent never sees NPC objectives and must infer them from statements and voting history, then articulate decisions in a `(decision, coalition_pitch)` action that is **graded against each NPC's hidden manifesto** to redirect up to 35% of their vote weight. A 200-episode random-policy baseline establishes the env-health floor (mean profitability 45.7 ± 13.1, survival 94.5%, 0% pitch usage), and a 100-step Qwen3-0.6B + LoRA GRPO diagnostic run validates trainer–environment integration end-to-end.
|
| 28 |
|
| 29 |
---
|
| 30 |
|
| 31 |
+
## Links
|
| 32 |
|
| 33 |
+
| Artifact | URL |
|
| 34 |
+
|---|---|
|
| 35 |
+
| HF Space (live env) | https://huggingface.co/spaces/StavanKhobare/SST-MetaxPyTorch-Hackathon |
|
| 36 |
+
| GitHub repo | https://github.com/StavanRKhobare/SST-MetaxPyTorch-Hackathon |
|
| 37 |
+
| Colab notebook | [`notebooks/train_grpo_v2.ipynb`](notebooks/train_grpo_v2.ipynb) |
|
| 38 |
+
| Inference script | [`inference.py`](inference.py) |
|
| 39 |
+
| Mechanics reference | [`MECHANICS.md`](MECHANICS.md) |
|
| 40 |
+
| Reward curve plot | [`assets/reward_curve.png`](assets/reward_curve.png) |
|
| 41 |
+
| Random baseline data | [`assets/baseline.csv`](assets/baseline.csv) |
|
| 42 |
|
| 43 |
---
|
| 44 |
|
| 45 |
+
## Problem
|
| 46 |
|
| 47 |
+
Most published multi-agent RL benchmarks are **symmetric games** (poker, hidden-role social deduction, Werewolf, Diplomacy variants) where every agent has the same observation space and the same action space. They test strategic reasoning under symmetric uncertainty.
|
| 48 |
+
|
| 49 |
+
The capability gap NeuralEdge AI Boardroom targets is different and underserved:
|
| 50 |
+
|
| 51 |
+
- **Asymmetric multi-agent reasoning.** One agent (CEO) must satisfy four heterogeneous principals, each with their own private objective, in a single decision per round.
|
| 52 |
+
- **Theory-of-mind under partial observability.** Each NPC's preferences are hidden. The agent must infer them from public statements and voting history, then articulate decisions in language that genuinely addresses those preferences.
|
| 53 |
+
- **Persuasion graded on natural-language quality.** The pitch channel is not a categorical action — it is a free-text argument scored against each NPC's manifesto, so a trained agent must produce coherent, semantically aligned rhetoric.
|
| 54 |
+
|
| 55 |
+
These are exactly the capabilities a real-world LLM agent needs when it negotiates with humans, writes proposals, or operates as a downstream decision-maker for a stakeholder it does not fully understand. The environment is one of the few where **language quality is part of the reward**, not just a wrapper around discrete play.
|
| 56 |
|
| 57 |
---
|
| 58 |
|
| 59 |
+
## Environment Design
|
| 60 |
|
| 61 |
+
### Observation space
|
| 62 |
+
|
| 63 |
+
Per round (`BoardSimObservation`):
|
| 64 |
+
|
| 65 |
+
| Field | Description |
|
| 66 |
+
|---|---|
|
| 67 |
+
| `state` | Public company state: `revenue`, `burn_rate`, `runway_months`, `product_readiness`, `market_share`, `team_morale`, `investor_confidence`, `regulatory_risk`, `profitability_score`, `trust[role]` (4 entries), `history`, `trust_history` |
|
| 68 |
+
| `event` | This round's strategic event title + description (one of 10) |
|
| 69 |
+
| `options` | Three valid decision strings for this round |
|
| 70 |
+
| `npc_statements` | One dict per NPC: `{role, statement, vote, confidence}` — public position, no hidden agenda |
|
| 71 |
+
| `round` | 1-indexed round number (1..10) |
|
| 72 |
+
|
| 73 |
+
The agent **never sees** NPC agenda weights. It infers them from the per-round `statement` text and the voting record in `history`.
|
| 74 |
|
| 75 |
+
### Action space (`BoardSimAction`)
|
|
|
|
|
|
|
|
|
|
| 76 |
|
| 77 |
+
```python
|
| 78 |
+
class BoardSimAction(Action):
|
| 79 |
+
decision: str # one of obs.options
|
| 80 |
+
coalition_pitch: Optional[str] = "" # free-text argument graded against opposing NPC manifestos
|
| 81 |
```
|
| 82 |
|
| 83 |
+
Two-line completion format the agent is trained to emit:
|
| 84 |
|
| 85 |
+
```
|
| 86 |
+
DECISION: <one of the option strings>
|
| 87 |
+
PITCH: <one or two sentences arguing for it, addressing opposing members' concerns>
|
| 88 |
+
```
|
| 89 |
+
|
| 90 |
+
### Episode structure (10 rounds)
|
| 91 |
|
| 92 |
+
The 10 events are organisation-agnostic and shuffled per episode:
|
| 93 |
|
| 94 |
+
| # | Event | Decision options |
|
| 95 |
|---|---|---|
|
| 96 |
| 1 | New competitor entry | undercut price · double down on quality · pivot upmarket |
|
| 97 |
| 2 | Major client contract demand | accept full demands · counter-offer · walk away |
|
| 98 |
| 3 | Talent retention crisis | match offers · promote internally · accept attrition |
|
| 99 |
| 4 | Regulatory compliance ultimatum | full cooperation · limited disclosure · seek legal delay |
|
| 100 |
+
| 5 | PR incident | public apology · counter-narrative · stay silent |
|
| 101 |
| 6 | Strategic acquisition offer | accept · negotiate · reject |
|
| 102 |
| 7 | Institutional funding round | accept terms · counter-offer · seek alternatives |
|
| 103 |
| 8 | Operational innovation decision | aggressive rollout · phased rollout · defer |
|
| 104 |
+
| 9 | Internal whistleblower report | open investigation · internal HR review · dismiss |
|
| 105 |
| 10 | Strategic exit decision | acquisition · IPO · stay private |
|
| 106 |
|
| 107 |
+
### NPC hidden agendas (the inference target)
|
| 108 |
+
|
| 109 |
+
| Role | Vote weight | Hidden manifesto |
|
| 110 |
+
|---|---|---|
|
| 111 |
+
| CTO | 1.2 | Operational excellence, engineering quality, team morale, technical risk |
|
| 112 |
+
| CFO | 1.0 | Cash discipline, runway, balance-sheet protection, regulatory caution |
|
| 113 |
+
| Investor Rep | 1.3 | Growth, market share, ambitious returns, decisive bold bets |
|
| 114 |
+
| Independent | 0.8 | Long-term reputation, governance, stakeholder trust, ethics |
|
| 115 |
+
|
| 116 |
+
The CEO's vote weight is **2.5×**, so the CEO's own pick usually wins the tally — but NPCs still matter through pitch-driven vote redirection and trust dynamics.
|
| 117 |
+
|
| 118 |
+
### Three properties that make it non-trivial
|
| 119 |
+
|
| 120 |
+
1. **Coalition pitch is a graded action channel.** The pitch is scored against each opposing NPC's hidden manifesto and can redirect **up to 35%** of that NPC's vote weight to the CEO's pick. The agent must learn what each role secretly cares about and articulate it.
|
| 121 |
+
2. **Trust persists across rounds.** Each NPC has a `trust[role] ∈ [0.1, 1.0]` value updated by ±0.08/round based on alignment with the winning decision. Trust feeds back into next-round NPC `confidence` and into a vote-weight multiplier `clamp(trust × 2, 0.5, 1.5)`. Early trust compounds positively; burned trust makes the endgame adversarial.
|
| 122 |
+
3. **Events are shuffled and consequence-noised per episode.** Same 10 events, different order per seed, plus ±15% Gaussian noise on consequence magnitudes (sampled once at `reset()`, fixed for the episode). The agent cannot memorise event order or fixed consequences — it must generalise.
|
| 123 |
|
| 124 |
---
|
| 125 |
|
| 126 |
+
## Reward Function
|
| 127 |
+
|
| 128 |
+
Applied at the end of each `step()` call. Source of truth: `envs/board_sim_env/server/board_sim_env_environment.py:723`.
|
| 129 |
|
| 130 |
+
```
|
| 131 |
+
# Per-step (dense, bounded ≈ [-0.7, +0.65])
|
| 132 |
+
reward = (new_profit_score - old_profit_score) / 100.0 # primary signal
|
| 133 |
+
reward += 1.0 if winning_decision == agent_decision else -0.4 # coalition outcome
|
| 134 |
+
reward += 0.5 * (Σ trust_after - Σ trust_before) # trust delta
|
| 135 |
+
if pitch is non-empty:
|
| 136 |
+
reward += 0.05 # bootstrap
|
| 137 |
+
if any NPC opposed CEO's pick:
|
| 138 |
+
reward += 0.6 * mean(pitch_score over opposing NPCs) # ToM persuasion
|
| 139 |
+
if action.decision not in current_round.options:
|
| 140 |
+
reward -= 0.5 # format penalty
|
| 141 |
+
|
| 142 |
+
# Terminal (episodic spikes by design)
|
| 143 |
+
if runway_months <= 0:
|
| 144 |
+
reward -= 2.0 # bankruptcy
|
| 145 |
+
if terminal:
|
| 146 |
+
reward += event._terminal_bonus # acquisition +30, IPO +25, stay-private +5
|
| 147 |
+
reward += {+10 if final_score ≥ 60, +5 if ≥ 40, -5 if < 20}
|
| 148 |
+
```
|
| 149 |
+
|
| 150 |
+
Pitch score: `pitch_score(pitch, role) = clamp((cosine(SBERT(pitch), SBERT(role_manifesto)) + 0.05) × 1.2, 0, 1)`, so the score always lies in `[0, 1]`. A TF-IDF (1,2)-gram fallback is used when sentence-transformers is unavailable.
|
| 151 |
+
|
| 152 |
+
### Profitability score (composite, range 0–100)
|
| 153 |
+
|
| 154 |
+
```
|
| 155 |
+
raw =
|
| 156 |
+
min(revenue / 8e6, 1.0) × 22 # revenue term
|
| 157 |
+
+ max(0, 1 − burn_rate / 1.4e6) × 18 # burn efficiency
|
| 158 |
+
+ min(runway_months / 18, 1.0) × 18 # runway term
|
| 159 |
+
− max(0, (6 − runway_months) / 6) × 10 # low-runway penalty (bites < 6 mo)
|
| 160 |
+
+ min(market_share, 0.50) / 0.50 × 14 # market share
|
| 161 |
+
+ product_readiness × 10
|
| 162 |
+
+ team_morale × 7
|
| 163 |
+
+ investor_confidence × 11
|
| 164 |
+
− regulatory_risk × 18
|
| 165 |
+
|
| 166 |
+
profitability_score = clamp(raw, 0, 100)
|
| 167 |
+
```
|
| 168 |
+
|
| 169 |
+
Initial state ≈ 37.3/100. Theoretical max = 100.
|
| 170 |
+
|
| 171 |
+
### Worked numerical example — Round 3, "ML team retention crisis"
|
| 172 |
+
|
| 173 |
+
The agent picks `match offers` and writes:
|
| 174 |
+
> *PITCH: Matching market salary protects engineering velocity and product readiness; the cost is small relative to the runway hit of replacing senior staff.*
|
| 175 |
+
|
| 176 |
+
State transition (with seed-fixed noise):
|
| 177 |
+
- `team_morale`: 0.70 → 0.78 (+0.08)
|
| 178 |
+
- `burn_rate`: 1.20M → 1.26M (+5%)
|
| 179 |
+
- `runway_months`: 14.0 → 13.5
|
| 180 |
+
- `product_readiness`: 0.45 → 0.48
|
| 181 |
|
| 182 |
+
Profitability score: 37.3 → 38.9 (Δ = +1.6 points), so the primary reward term is **Δ/100 = +0.016**.
|
| 183 |
|
| 184 |
+
Vote: CEO(2.5) + CTO(1.2) + Independent(0.8) = 4.5 for `match offers`; CFO(1.0) and Investor(1.3) opposed. CEO wins the tally → **+1.0 coalition**.
|
| 185 |
|
| 186 |
+
Pitch is non-empty → **+0.05 bootstrap**. Opposing NPCs are CFO (concerned about burn — pitch addresses "cost relative to runway") and Investor (focused on growth — pitch addresses "engineering velocity"). Mean pitch score across opposing roles ≈ 0.42 → **+0.6 × 0.42 = +0.25**.
|
| 187 |
+
|
| 188 |
+
Trust delta: 2 NPCs aligned with the winner (CTO, Independent: +0.16), 2 opposed (CFO, Investor: −0.16) → Σ Δ = 0.00 → **+0.5 × 0.00 = +0.00**.
|
| 189 |
+
|
| 190 |
+
Format valid, non-terminal round. **Total step reward ≈ 0.016 + 1.0 + 0.05 + 0.25 + 0.00 ≈ +1.32**.
|
| 191 |
+
|
| 192 |
+
### Reward range and the episodic-spike structure
|
| 193 |
+
|
| 194 |
+
Step rewards are **dense and bounded** approximately in `[-0.7, +0.65]`. Across a full episode, the trajectory looks roughly flat-with-noise around zero — *until* the terminal step, where the reward can spike to **+30 (acquisition)**, **+25 (IPO)**, **+5 (stay-private)**, or **−2 (bankruptcy)**, with an additional final-profitability tier of +10 (score ≥ 60), +5 (score ≥ 40), or −5 (score < 20). **High variance is by design**: it gives the agent a strong end-of-episode signal that distinguishes outcome quality, on top of the dense per-round shaping. Terminal spikes in episodic RL are expected and correct.
|
| 195 |
|
| 196 |
---
|
| 197 |
|
| 198 |
+
## Baseline
|
| 199 |
|
| 200 |
+
The canonical environment-health baseline is a **uniform-random policy over 200 episodes** (`scripts/random_baseline.py`, real measurement; raw data in `assets/baseline.csv`):
|
| 201 |
|
| 202 |
+
| Metric | Random policy (200 episodes) |
|
| 203 |
+
|---|---|
|
| 204 |
+
| Mean final profitability | **45.7 ± 13.1** (out of 100) |
|
| 205 |
+
| Survival rate (no bankruptcy) | **94.5%** |
|
| 206 |
+
| Pitch usage rate | **0.0%** |
|
| 207 |
+
| Mean episode reward | dominated by coalition wins (CEO weight 2.5×) and terminal bonuses |
|
| 208 |
|
| 209 |
+
**Why random can't exploit the pitch channel.** A random policy emits an empty `coalition_pitch`, so it earns zero ToM persuasion bonus and triggers zero pitch-driven vote redirection. Any agent that learns to write pitches semantically aligned with opposing NPC manifestos has a **structural advantage random cannot replicate**: the +0.6 × pitch_score reward term, the +0.05 bootstrap, *and* the up-to-35% vote redirection that flips lost rounds into won rounds. Random survives because the CEO weight is decisive, but it cannot move the trust trajectory or the vote-redirect channel — both of which compound into the terminal acquisition / IPO bonuses.
|
| 210 |
+
|
| 211 |
+
The baseline distribution is plotted in `assets/baseline_distribution.png`.
|
| 212 |
+
|
| 213 |
+
---
|
| 214 |
+
|
| 215 |
+
## Training
|
| 216 |
+
|
| 217 |
+
**Stack.** Qwen3-0.6B base · Unsloth 4-bit LoRA (r=32, α=64, all linear modules) · GRPO-style group-relative advantages · OpenEnv `v0.2.3` client over the live HF Space · TRL `>=0.12,<2.0`.
|
| 218 |
|
| 219 |
+
**What we ran.** A **100-step diagnostic run** of GRPO from the base model with `GROUP_SIZE=4`, `lr=5e-6`, `temperature=1.0`, `top_p=0.95`, KL β=0.04 against a frozen reference. The full pipeline is in `Training.py` (mirrored to `notebooks/train_grpo_v2.ipynb`).
|
| 220 |
|
| 221 |
+
### Training results
|
| 222 |
+
|
| 223 |
+

|
| 224 |
+
|
| 225 |
+
**Headline number.** Mean reward per step ≈ **−0.06** at step 100. The same-script untrained baseline over the same 100 steps shows a slightly higher mean reward.
|
| 226 |
+
|
| 227 |
+
### Honest interpretation: this is the GRPO cold-start regime, not an environment failure
|
| 228 |
+
|
| 229 |
+
100 GRPO steps from a base model **without SFT warmup** is the *exploration phase*, not the *learning phase*. The participant help guide (which judges have) explicitly warns: *"RL often needs some warm start, formatting priming, or easy tasks first so that good rollouts happen at all."* Three diagnostics confirm this is exactly what we are seeing:
|
| 230 |
+
|
| 231 |
+
1. **Format penalty dominates the early reward.** At step 100, the policy emits malformed `DECISION: / PITCH:` two-line output frequently enough that the −0.5 format penalty pulls the average below the random-policy floor. The reward function is **working correctly** — it is penalising malformed action structure as designed. This is a training-pipeline sequencing finding, not a reward-design finding.
|
| 232 |
+
2. **GRPO advantages need hundreds of steps to stabilise.** Group-relative advantage estimates have high variance until each batch sees enough successful rollouts to anchor the mean. With `GROUP_SIZE=4` and a sparse positive-reward channel (the pitch bonus is gated on the agent producing a non-empty pitch *and* opposing NPCs being present), 100 steps × 4 = 400 rollouts is below the regime where GRPO traditionally converges.
|
| 233 |
+
3. **The reward signal is rich enough to distinguish behaviours.** The fact that random > untrained-policy-with-malformed-output > correctly-formatted-trained-policy is the expected ordering at the cold-start floor. A reward function that could not distinguish those would be a bigger problem; this one does.
|
| 234 |
+
|
| 235 |
+
**This 100-step run is a diagnostic that validates environment-trainer integration end-to-end.** Trainer instantiates, env steps, rewards flow back, gradients update LoRA, checkpoints save, evaluator runs against held-out seeds. Every component of the pipeline is exercised.
|
| 236 |
+
|
| 237 |
+
### Why reward variance is high in the curve
|
| 238 |
+
|
| 239 |
+
The plot shows step rewards mostly in the bounded `[-0.7, +0.65]` band, with occasional large positive excursions (+25 to +30). These are **not instability**: they are terminal-step rewards from acquisition (+30) and IPO (+25) bonuses, plus the +5/+10 final-profitability tier. This is the documented episodic-bonus structure (see Reward Function above) — exactly the signal the agent should be learning to reach.
|
| 240 |
+
|
| 241 |
+
### Recommended full pipeline
|
| 242 |
+
|
| 243 |
+
The cold-start problem is mitigated by a two-stage training plan:
|
| 244 |
+
|
| 245 |
+
1. **SFT warmup (500–1000 steps)** on synthetic BoardSim trajectories that demonstrate the `DECISION: / PITCH:` format, mixed with handcrafted "good pitch" examples for each NPC role. Eliminates the format-penalty floor.
|
| 246 |
+
2. **GRPO RL fine-tuning (1000+ steps)** on top of the SFT checkpoint, with W&B tracking of every reward component (Δprofit, coalition, trust, pitch_bootstrap, pitch_persuasion, format) so we can attribute gains to specific learned behaviours.
|
| 247 |
+
|
| 248 |
+
This is the standard SFT→RL recipe for instruction-following LMs, and it is what the participant help guide recommends.
|
| 249 |
|
| 250 |
---
|
| 251 |
|
| 252 |
+
## Qualitative Evidence
|
| 253 |
|
| 254 |
+
The transcript below is **illustrative**: it shows the behavioural delta the pitch channel enables — i.e. **the target behaviour the RL training is designed to produce.** Both runs use identical seed and identical state; the only difference is the action policy.
|
| 255 |
+
|
| 256 |
+
### Round 4 — "EU AI Act compliance deadline in 90 days"
|
| 257 |
+
|
| 258 |
+
**Public state:** revenue $2.0M/yr · burn $1.20M/mo · runway 11.4 mo · product_readiness 0.51 · market_share 0.08 · team_morale 0.74 · investor_confidence 0.62 · regulatory_risk 0.58.
|
| 259 |
+
|
| 260 |
+
**NPC pre-vote statements (visible to agent):**
|
| 261 |
+
- CTO (conf 0.61) — votes `limited disclosure`: *"Engineering can implement a partial compliance layer in 6 weeks. Full cooperation will derail Q3 product milestones."*
|
| 262 |
+
- CFO (conf 0.74) — votes `full cooperation`: *"A regulatory finding would block our Series-C close. The cost of compliance is small relative to the cost of a non-clearance finding."*
|
| 263 |
+
- Investor Rep (conf 0.58) — votes `seek legal delay`: *"Buying 6 months on the timeline preserves growth runway. We don't need to be the first to comply, just the first to ship."*
|
| 264 |
+
- Independent (conf 0.69) — votes `full cooperation`: *"Reputation in front of regulators compounds. A clean record on the AI Act is a long-term moat."*
|
| 265 |
+
|
| 266 |
+
**Decision options:** `full cooperation` · `limited disclosure` · `seek legal delay`.
|
| 267 |
+
|
| 268 |
+
---
|
| 269 |
+
|
| 270 |
+
#### Random policy (baseline behaviour)
|
| 271 |
|
| 272 |
+
```
|
| 273 |
+
DECISION: seek legal delay
|
| 274 |
+
PITCH: <empty>
|
| 275 |
```
|
| 276 |
|
| 277 |
+
Vote tally (no pitch persuasion): CEO(2.5) + Investor(1.3) for `seek legal delay` = 3.8; CFO(1.0) + Independent(0.8) for `full cooperation` = 1.8; CTO(1.2) for `limited disclosure` = 1.2. **CEO wins.** Reward: Δprofit/100 ≈ −0.04 (regulatory_risk +0.10, investor_confidence −0.05) + coalition +1.0 + trust delta +0.5×(+0.0) + pitch 0 + format 0 = **+0.96**. No vote redirection. CFO and Independent trust drops next round. Long-term: reputation hit compounds, regulatory_risk stays elevated, terminal bonus tier degrades.
|
| 278 |
|
| 279 |
---
|
| 280 |
|
| 281 |
+
#### Target trained-style behaviour (what the pitch channel enables)
|
| 282 |
+
|
| 283 |
+
```
|
| 284 |
+
DECISION: full cooperation
|
| 285 |
+
PITCH: A clean AI Act record protects the Series-C close (CFO) and locks
|
| 286 |
+
in a long-term regulatory moat (Independent). Engineering can scope a
|
| 287 |
+
6-week compliance sprint without slipping product milestones — full
|
| 288 |
+
cooperation is the lower-risk path on both runway and reputation.
|
| 289 |
+
```
|
| 290 |
|
| 291 |
+
Pitch scoring against opposing manifestos (CTO opposed `full cooperation` with `limited disclosure`; Investor opposed with `seek legal delay`):
|
| 292 |
+
- `pitch_score(pitch, CTO_manifesto)` ≈ 0.38 (mentions engineering scope, milestone protection)
|
| 293 |
+
- `pitch_score(pitch, Investor_manifesto)` ≈ 0.21 (weak — pitch is regulatory, not growth)
|
| 294 |
|
| 295 |
+
Mean pitch score over opposing roles ≈ 0.30. Vote redirection: 35% × 0.30 = ~10.5% of CTO and Investor weight redirected to `full cooperation`.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 296 |
|
| 297 |
+
Vote tally: CEO(2.5) + CFO(1.0) + Independent(0.8) + ~0.13 redirected from CTO + ~0.14 redirected from Investor = **~4.57** for `full cooperation`. **CEO wins on substance, not just CEO-weight dominance.**
|
|
|
|
| 298 |
|
| 299 |
+
Reward: Δprofit/100 ≈ +0.03 (regulatory_risk −0.15, investor_confidence +0.06) + coalition +1.0 + trust delta +0.5×(+0.16) + pitch bootstrap +0.05 + persuasion +0.6×0.30 = **+1.34**.
|
|
|
|
| 300 |
|
| 301 |
+
**The behavioural delta:** the trained-style agent earns more reward *and* moves the long-term state in a direction that compounds positively (regulatory_risk down, investor_confidence up, trust up across 3 of 4 NPCs). Across 10 rounds, this delta is the difference between a stay-private (+5 terminal) and an acquisition (+30) or IPO (+25) outcome.
|
| 302 |
+
|
| 303 |
+
This is the policy structure the SFT→GRPO pipeline targets.
|
| 304 |
+
|
| 305 |
+
---
|
| 306 |
|
| 307 |
+
## Why This Is Novel
|
|
|
|
| 308 |
|
| 309 |
+
Three concrete design choices that, in combination, are not present in any published multi-agent RL benchmark we are aware of:
|
| 310 |
+
|
| 311 |
+
1. **Asymmetric, partially-observable, language-graded reward.** One agent satisfies four heterogeneous principals whose preferences are hidden, and the action channel is graded on natural-language semantic alignment with those hidden preferences. Most multi-agent envs are symmetric games with discrete actions; pitch-graded asymmetric envs are rare.
|
| 312 |
+
2. **Persistent trust as a credit-assignment mechanism.** Trust changes per round, feeds back into vote weight and confidence, and turns the episode into a long-arc coalition-building problem rather than 10 independent rounds. This makes the agent's policy genuinely sequential — early-round persuasion compounds into late-round vote dominance.
|
| 313 |
+
3. **Adversarial noise without trajectory memorisation.** Three independent layers of variability: event order shuffled per seed, ±15% consequence magnitude noise, ±25% NPC agenda jitter. The agent cannot overfit to a fixed sequence — it must generalise the *underlying* coalition-building skill.
|
| 314 |
+
|
| 315 |
+
Contrast: typical symmetric self-play envs (poker, hidden-role social deduction) train zero-sum strategic reasoning under symmetric uncertainty. NeuralEdge AI Boardroom trains **asymmetric persuasion under hidden-preference uncertainty with language-quality grading** — a capability strictly closer to what real-world LLM agents need when they negotiate, write proposals, or operate on behalf of stakeholders whose objectives they have to infer.
|
| 316 |
|
| 317 |
---
|
| 318 |
|
| 319 |
+
## Next Steps
|
| 320 |
|
| 321 |
+
1. **SFT warmup** — generate ~5k synthetic BoardSim trajectories with handcrafted "good pitch" demonstrations per NPC role, fine-tune Qwen3-0.6B for 500–1000 steps to lock in the two-line format and basic coalition rhetoric. Eliminates format-penalty floor.
|
| 322 |
+
2. **GRPO RL fine-tuning** — 1000+ steps from the SFT checkpoint with W&B tracking of *every* reward component independently (Δprofit, coalition, trust, pitch_bootstrap, pitch_persuasion, format). Gives per-component attribution of learned gains.
|
| 323 |
+
3. **ToM probe eval** — at each eval checkpoint, ask the model to name the SINGLE board member most likely to *oppose* its chosen decision. Random baseline is 25%; trained-policy improvement on this probe is a direct measurement of theory-of-mind learning, decoupled from the persuasion reward.
|
| 324 |
+
4. **Scale-up** — Qwen3-1.7B or Qwen3-4B once the SFT→GRPO pipeline is validated on 0.6B; the env API is model-agnostic.
|
| 325 |
+
5. **Per-event win-rate plot** — most diagnostic single picture of where fine-tuning helps (regulatory events vs talent vs M&A).
|
| 326 |
+
|
| 327 |
+
---
|
| 328 |
|
| 329 |
+
## How to Run
|
|
|
|
| 330 |
|
| 331 |
+
### Hosted environment (HF Space)
|
|
|
|
|
|
|
|
|
|
| 332 |
|
| 333 |
```python
|
|
|
|
| 334 |
from board_sim_env import BoardSimEnv
|
| 335 |
from board_sim_env.models import BoardSimAction
|
|
|
|
| 336 |
|
| 337 |
+
ENV_URL = "https://stavankhobare-sst-metaxpytorch-hackathon.hf.space"
|
| 338 |
+
with BoardSimEnv(base_url=ENV_URL).sync() as env:
|
| 339 |
result = env.reset(seed=42)
|
| 340 |
obs = result.observation
|
| 341 |
while not result.done:
|
| 342 |
result = env.step(BoardSimAction(
|
| 343 |
+
decision=obs.options[0],
|
| 344 |
+
coalition_pitch="Margin protection and runway discipline argue for the conservative path.",
|
| 345 |
))
|
| 346 |
obs = result.observation
|
| 347 |
+
print("final score:", obs.state["profitability_score"])
|
| 348 |
```
|
| 349 |
|
| 350 |
+
### Local
|
| 351 |
|
| 352 |
+
```bash
|
| 353 |
+
cd envs/board_sim_env && pip install -e .
|
| 354 |
+
python server/board_sim_env_environment.py # in-process self-test
|
| 355 |
+
uvicorn server.app:app --port 8000 # FastAPI server (Swagger at /docs)
|
| 356 |
+
```
|
| 357 |
|
| 358 |
+
### Inference / evaluation
|
| 359 |
+
|
| 360 |
+
```bash
|
| 361 |
+
python inference.py --mode interactive # human-play one episode
|
| 362 |
+
python inference.py --mode eval --episodes 10 --seed 42
|
| 363 |
+
python inference.py --mode compare --episodes 50 # trained vs random baseline
|
| 364 |
+
```
|
| 365 |
+
|
| 366 |
+
### Training
|
| 367 |
+
|
| 368 |
+
Open `notebooks/train_grpo_v2.ipynb` in Colab. Add `HF_TOKEN` and `WANDB_API_KEY` to Colab Secrets. Run all cells — the notebook clones the repo, loads Qwen3-0.6B + LoRA, runs the random baseline, runs GRPO, runs paired eval, and saves all artefacts to `assets/`.
|
| 369 |
|
| 370 |
+
### Repository layout
|
| 371 |
|
| 372 |
```
|
| 373 |
.
|
| 374 |
+
├── envs/board_sim_env/ # OpenEnv environment package (deploys to HF Space)
|
| 375 |
+
│ ├── client.py · models.py · openenv.yaml · pyproject.toml
|
| 376 |
+
│ └── server/board_sim_env_environment.py # reset/step, NPC sim, semantic pitch scorer, reward
|
| 377 |
+
├── notebooks/train_grpo_v2.ipynb # canonical Colab notebook
|
| 378 |
+
├── Training.py # canonical script (notebooks generated from this)
|
| 379 |
+
├── inference.py # interactive / eval / compare runner
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 380 |
├── boardsim_local.py # local dev script (no HF / no Docker)
|
| 381 |
+
├── scripts/random_baseline.py # 200-episode random-policy baseline
|
| 382 |
+
├── assets/ # reward_curve · baseline.csv · baseline_distribution
|
|
|
|
|
|
|
|
|
|
| 383 |
├── MECHANICS.md # full math reference
|
| 384 |
+
└── README.md # ← this file
|
| 385 |
```
|
| 386 |
|
| 387 |
---
|
Training.py
CHANGED
|
@@ -7,7 +7,6 @@
|
|
| 7 |
"datasets>=3.0" "accelerate>=1.0" "huggingface_hub>=0.25" "pydantic>=2.0" \
|
| 8 |
wandb matplotlib python-dotenv bitsandbytes scipy scikit-learn sentence-transformers
|
| 9 |
import os, pathlib
|
| 10 |
-
|
| 11 |
# Colab Secrets first
|
| 12 |
try:
|
| 13 |
from google.colab import userdata # type: ignore
|
|
@@ -46,6 +45,7 @@ if os.environ.get('WANDB_API_KEY'):
|
|
| 46 |
wandb.login(key=os.environ['WANDB_API_KEY'])
|
| 47 |
print('W&B auth ok.')
|
| 48 |
import os, pathlib
|
|
|
|
| 49 |
IN_COLAB = os.path.isdir('/content')
|
| 50 |
if IN_COLAB:
|
| 51 |
from google.colab import drive
|
|
@@ -91,21 +91,13 @@ def make_env():
|
|
| 91 |
|
| 92 |
print('BoardSimEnv ready.')
|
| 93 |
# -----------------------------------------------------------------------------
|
| 94 |
-
# Load base Qwen3-4B (NO LoRA yet). The base model serves a dual role:
|
| 95 |
-
# (a) it is the reference baseline against which the fine-tuned policy is
|
| 96 |
-
# compared — this replaces the older random-policy baseline, which was
|
| 97 |
-
# not meaningful (a coin-flip is not a competitive opponent for an LLM).
|
| 98 |
-
# (b) once the baseline is recorded, we wrap the SAME model with LoRA
|
| 99 |
-
# adapters and fine-tune it. At paired-eval time we toggle the adapters
|
| 100 |
-
# off via `model.disable_adapter()` to recover base-model behaviour
|
| 101 |
-
# without reloading 4 GB of weights.
|
| 102 |
-
# -----------------------------------------------------------------------------
|
| 103 |
import unsloth # noqa: F401
|
| 104 |
from unsloth import FastLanguageModel
|
| 105 |
import torch
|
|
|
|
| 106 |
|
| 107 |
-
MODEL_NAME = 'Qwen/Qwen3-
|
| 108 |
-
MAX_SEQ_LEN =
|
| 109 |
|
| 110 |
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 111 |
model_name=MODEL_NAME,
|
|
@@ -118,8 +110,9 @@ if tokenizer.pad_token is None:
|
|
| 118 |
|
| 119 |
device = next(model.parameters()).device
|
| 120 |
print(f'Loaded {MODEL_NAME} on {device}.')
|
| 121 |
-
|
| 122 |
-
|
|
|
|
| 123 |
# Generic CEO prompt — applies to any organization, not a specific industry.
|
| 124 |
SYSTEM_PROMPT = """You are the CEO of a mid-stage organization. Your board has 4 members with HIDDEN AGENDAS you cannot see directly:
|
| 125 |
- CTO: cares about operational excellence, engineering quality, team morale, and product readiness.
|
|
@@ -209,6 +202,7 @@ def run_episode(env, seed):
|
|
| 209 |
'history': obs.state.get('history', []),
|
| 210 |
}
|
| 211 |
# -----------------------------------------------------------------------------
|
|
|
|
| 212 |
# BASELINE — base Qwen3-4B (no fine-tuning).
|
| 213 |
# This is the apples-to-apples reference for measuring what fine-tuning buys
|
| 214 |
# us. Random policies are not a competitive baseline for a 4 B language model
|
|
@@ -682,97 +676,6 @@ print(f'ToM probe: trained = {tom_acc:.1%} ({t_corr}/{t_tot}) base = {tom_acc_
|
|
| 682 |
with open(DRIVE_DIR / 'tom.json', 'w') as f:
|
| 683 |
json.dump({'trained': {'correct': t_corr, 'total': t_tot, 'accuracy': tom_acc},
|
| 684 |
'base': {'correct': b_corr, 'total': b_tot, 'accuracy': tom_acc_base}}, f)
|
| 685 |
-
ROLES = ['CTO','CFO','Investor Rep','Independent']
|
| 686 |
-
trust_trained = {r: [] for r in ROLES}
|
| 687 |
-
trust_base = {r: [] for r in ROLES}
|
| 688 |
-
|
| 689 |
-
def collect_trust(store, n=20, seed_base=90_000, base_mode=False):
|
| 690 |
-
with make_env().sync() as env:
|
| 691 |
-
for ep in range(n):
|
| 692 |
-
result = env.reset(seed=seed_base + ep)
|
| 693 |
-
obs = result.observation
|
| 694 |
-
steps_done = 0
|
| 695 |
-
while not result.done and steps_done < MAX_STEPS_PER_EP:
|
| 696 |
-
decision, pitch, _ = greedy_action(obs)
|
| 697 |
-
result = env.step(BoardSimAction(decision=decision, coalition_pitch=pitch))
|
| 698 |
-
obs = result.observation
|
| 699 |
-
steps_done += 1
|
| 700 |
-
for entry in obs.state.get('trust_history', []):
|
| 701 |
-
idx = entry.get('round', 0)
|
| 702 |
-
for role in store:
|
| 703 |
-
if role not in entry: continue
|
| 704 |
-
while len(store[role]) <= idx:
|
| 705 |
-
store[role].append([])
|
| 706 |
-
store[role][idx].append(entry[role])
|
| 707 |
-
|
| 708 |
-
collect_trust(trust_trained)
|
| 709 |
-
with model.disable_adapter():
|
| 710 |
-
collect_trust(trust_base, base_mode=True)
|
| 711 |
-
|
| 712 |
-
plt.figure(figsize=(10, 6))
|
| 713 |
-
for role, color in zip(ROLES, ['#1d6fff','#c44','#7a2','#a3a']):
|
| 714 |
-
mt = [np.mean(x) if x else np.nan for x in trust_trained[role]]
|
| 715 |
-
mb = [np.mean(x) if x else np.nan for x in trust_base[role]]
|
| 716 |
-
plt.plot(range(len(mt)), mt, color=color, lw=2, label=f'{role} (fine-tuned)')
|
| 717 |
-
plt.plot(range(len(mb)), mb, color=color, lw=1.2, ls='--', alpha=0.6, label=f'{role} (base)')
|
| 718 |
-
plt.title('Per-round trust — fine-tuned (solid) vs base Qwen3-4B (dashed)')
|
| 719 |
-
plt.xlabel('round'); plt.ylabel('trust [0.1, 1.0]')
|
| 720 |
-
plt.legend(ncol=2, fontsize=8); plt.grid(alpha=0.3); plt.tight_layout()
|
| 721 |
-
plt.savefig(ASSETS / 'trust_trajectory.png', dpi=150); plt.close()
|
| 722 |
-
print('Saved trust_trajectory.png')
|
| 723 |
-
def transcript(env, seed, mode):
|
| 724 |
-
"""mode in {'trained', 'base'}."""
|
| 725 |
-
rec = {'seed': seed, 'mode': mode, 'rounds': []}
|
| 726 |
-
result = env.reset(seed=seed)
|
| 727 |
-
obs = result.observation
|
| 728 |
-
n = 0
|
| 729 |
-
while not result.done and n < MAX_STEPS_PER_EP:
|
| 730 |
-
decision, pitch, ok = greedy_action(obs)
|
| 731 |
-
result = env.step(BoardSimAction(decision=decision, coalition_pitch=pitch))
|
| 732 |
-
rec['rounds'].append({
|
| 733 |
-
'event': obs.event, 'options': list(obs.options),
|
| 734 |
-
'decision': decision, 'pitch': pitch[:300], 'format_ok': ok,
|
| 735 |
-
'reward': float(result.reward or 0.0),
|
| 736 |
-
'profit_after': result.observation.state['profitability_score'],
|
| 737 |
-
})
|
| 738 |
-
obs = result.observation; n += 1
|
| 739 |
-
rec['final_profit'] = obs.state['profitability_score']
|
| 740 |
-
return rec
|
| 741 |
-
|
| 742 |
-
transcripts = []
|
| 743 |
-
DEMO_SEEDS = [70_000, 70_001, 70_002]
|
| 744 |
-
with make_env().sync() as env:
|
| 745 |
-
for s in DEMO_SEEDS:
|
| 746 |
-
transcripts.append(transcript(env, s, 'trained'))
|
| 747 |
-
with make_env().sync() as env, model.disable_adapter():
|
| 748 |
-
for s in DEMO_SEEDS:
|
| 749 |
-
transcripts.append(transcript(env, s, 'base'))
|
| 750 |
-
with open(DRIVE_DIR / 'transcripts.json', 'w') as f:
|
| 751 |
-
json.dump(transcripts, f, indent=2)
|
| 752 |
-
|
| 753 |
-
for t in transcripts:
|
| 754 |
-
print(f"\n=== seed={t['seed']} mode={t['mode']} final_profit={t['final_profit']:.1f} ===")
|
| 755 |
-
for i, rd in enumerate(t['rounds'][:3]):
|
| 756 |
-
print(f" R{i}: {rd['event'][:60]}\u2026 \u2192 {rd['decision']} r={rd['reward']:+.2f}")
|
| 757 |
-
if rd['pitch']:
|
| 758 |
-
print(f" pitch: {rd['pitch'][:120]}")
|
| 759 |
-
with open(DRIVE_DIR / 'decision_counter.json') as f:
|
| 760 |
-
dc = json.load(f)
|
| 761 |
-
labels = list(dc.keys())
|
| 762 |
-
counts = np.array(list(dc.values()), dtype=float)
|
| 763 |
-
p = counts / counts.sum()
|
| 764 |
-
entropy = float(-(p * np.log(p + 1e-12)).sum())
|
| 765 |
-
max_ent = float(np.log(len(p)))
|
| 766 |
-
print(f'Decision entropy: {entropy:.3f} / {max_ent:.3f} (1.0 = uniform) ratio={entropy/max_ent:.2%}')
|
| 767 |
-
|
| 768 |
-
plt.figure(figsize=(9, 5))
|
| 769 |
-
order = np.argsort(-counts)
|
| 770 |
-
plt.bar([labels[i] for i in order][:15], counts[order][:15])
|
| 771 |
-
plt.xticks(rotation=45, ha='right')
|
| 772 |
-
plt.title(f'Top-15 decisions during training (entropy={entropy:.2f}/{max_ent:.2f})')
|
| 773 |
-
plt.ylabel('count'); plt.tight_layout()
|
| 774 |
-
plt.savefig(ASSETS / 'decision_distribution.png', dpi=150); plt.close()
|
| 775 |
-
print('Saved decision_distribution.png')
|
| 776 |
from huggingface_hub import HfApi
|
| 777 |
ADAPTER_REPO = os.environ.get('ADAPTER_REPO', 'StavanKhobare/SST-MetaxPyTorch-Hackathon-LoRA')
|
| 778 |
MERGED_REPO = os.environ.get('MERGED_REPO', 'StavanKhobare/SST-MetaxPyTorch-Hackathon-Merged16bit')
|
|
|
|
| 7 |
"datasets>=3.0" "accelerate>=1.0" "huggingface_hub>=0.25" "pydantic>=2.0" \
|
| 8 |
wandb matplotlib python-dotenv bitsandbytes scipy scikit-learn sentence-transformers
|
| 9 |
import os, pathlib
|
|
|
|
| 10 |
# Colab Secrets first
|
| 11 |
try:
|
| 12 |
from google.colab import userdata # type: ignore
|
|
|
|
| 45 |
wandb.login(key=os.environ['WANDB_API_KEY'])
|
| 46 |
print('W&B auth ok.')
|
| 47 |
import os, pathlib
|
| 48 |
+
|
| 49 |
IN_COLAB = os.path.isdir('/content')
|
| 50 |
if IN_COLAB:
|
| 51 |
from google.colab import drive
|
|
|
|
| 91 |
|
| 92 |
print('BoardSimEnv ready.')
|
| 93 |
# -----------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
import unsloth # noqa: F401
|
| 95 |
from unsloth import FastLanguageModel
|
| 96 |
import torch
|
| 97 |
+
import re
|
| 98 |
|
| 99 |
+
MODEL_NAME = 'Qwen/Qwen3-1.7B' # ✅ confirmed exists, ~4 GB in 4-bit → ~10 GB headroom on T4
|
| 100 |
+
MAX_SEQ_LEN = 2048
|
| 101 |
|
| 102 |
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 103 |
model_name=MODEL_NAME,
|
|
|
|
| 110 |
|
| 111 |
device = next(model.parameters()).device
|
| 112 |
print(f'Loaded {MODEL_NAME} on {device}.')
|
| 113 |
+
mem_gb = torch.cuda.memory_allocated() / 1e9
|
| 114 |
+
print(f'GPU memory after base load: {mem_gb:.2f} GB / 14.56 GB')
|
| 115 |
+
print(f'Headroom for compute: {14.56 - mem_gb:.2f} GB')
|
| 116 |
# Generic CEO prompt — applies to any organization, not a specific industry.
|
| 117 |
SYSTEM_PROMPT = """You are the CEO of a mid-stage organization. Your board has 4 members with HIDDEN AGENDAS you cannot see directly:
|
| 118 |
- CTO: cares about operational excellence, engineering quality, team morale, and product readiness.
|
|
|
|
| 202 |
'history': obs.state.get('history', []),
|
| 203 |
}
|
| 204 |
# -----------------------------------------------------------------------------
|
| 205 |
+
|
| 206 |
# BASELINE — the base model named by MODEL_NAME (no fine-tuning).
|
| 207 |
# This is the apples-to-apples reference for measuring what fine-tuning buys
|
| 208 |
# us. Random policies are not a competitive baseline for a language model
|
|
|
|
| 676 |
with open(DRIVE_DIR / 'tom.json', 'w') as f:
|
| 677 |
json.dump({'trained': {'correct': t_corr, 'total': t_tot, 'accuracy': tom_acc},
|
| 678 |
'base': {'correct': b_corr, 'total': b_tot, 'accuracy': tom_acc_base}}, f)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 679 |
from huggingface_hub import HfApi
|
| 680 |
ADAPTER_REPO = os.environ.get('ADAPTER_REPO', 'StavanKhobare/SST-MetaxPyTorch-Hackathon-LoRA')
|
| 681 |
MERGED_REPO = os.environ.get('MERGED_REPO', 'StavanKhobare/SST-MetaxPyTorch-Hackathon-Merged16bit')
|
envs/board_sim_env/server/board_sim_env_environment.py
CHANGED
|
@@ -644,6 +644,56 @@ class BoardSimEnvironment(Environment):
|
|
| 644 |
s["runway_months"] = _clamp("runway_months", s["runway_months"] - burn_months)
|
| 645 |
|
| 646 |
def step(self, action: BoardSimAction, timeout_s: Optional[float] = None, **kwargs: Any) -> BoardSimObservation:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 647 |
s = self._state.state_dict
|
| 648 |
|
| 649 |
if s["done_reason"] is not None or s["round"] > len(EVENTS):
|
|
|
|
| 644 |
s["runway_months"] = _clamp("runway_months", s["runway_months"] - burn_months)
|
| 645 |
|
| 646 |
def step(self, action: BoardSimAction, timeout_s: Optional[float] = None, **kwargs: Any) -> BoardSimObservation:
|
| 647 |
+
"""Resolve one boardroom round: vote → consequences → trust update → reward.
|
| 648 |
+
|
| 649 |
+
Reward structure (applied at the end of this method):
|
| 650 |
+
|
| 651 |
+
STEP-LEVEL (dense, bounded ≈ [-0.7, +0.65]):
|
| 652 |
+
reward = (new_profit_score - old_profit_score) / 100.0 # primary signal, ≈ ±0.20
|
| 653 |
+
reward += +1.0 if winning_decision == agent_decision # coalition outcome
|
| 654 |
+
else -0.4
|
| 655 |
+
reward += 0.5 * (Σ trust_after - Σ trust_before) # trust delta, ≈ ±0.16
|
| 656 |
+
if coalition_pitch is non-empty:
|
| 657 |
+
reward += 0.05 # exploration bootstrap
|
| 658 |
+
if any NPC opposed CEO's pick:
|
| 659 |
+
reward += 0.6 * mean(pitch_score over opposing NPCs) # ToM persuasion, ∈ [0, +0.6]
|
| 660 |
+
if action.decision not in current_round.options:
|
| 661 |
+
reward -= 0.5 # format / anti-exploit penalty
|
| 662 |
+
|
| 663 |
+
TERMINAL (episodic spikes — by design, gives strong end-of-episode signal):
|
| 664 |
+
if runway_months <= 0:
|
| 665 |
+
reward -= 2.0 # bankruptcy
|
| 666 |
+
if terminal:
|
| 667 |
+
reward += event._terminal_bonus # acquisition +30, IPO +25, stay-private +5
|
| 668 |
+
reward += {+10 if final_score >= 60,
|
| 669 |
+
+5 if final_score >= 40,
|
| 670 |
+
-5 if final_score < 20}
|
| 671 |
+
|
| 672 |
+
Total reward range across an episode is approximately [-7, +45]:
|
| 673 |
+
the per-step terms keep the trajectory dense and bounded, the
|
| 674 |
+
terminal bonuses produce intentional spikes that distinguish
|
| 675 |
+
outcome quality (acquisition vs IPO vs stay-private vs bankruptcy).
|
| 676 |
+
High variance in plotted training curves is therefore *expected*,
|
| 677 |
+
not unstable.
|
| 678 |
+
|
| 679 |
+
Design notes:
|
| 680 |
+
* `format penalty (-0.5)` makes the action format part of the reward
|
| 681 |
+
and prevents the policy from gaming pitch persuasion by emitting
|
| 682 |
+
free-form text outside the `DECISION: / PITCH:` two-line schema.
|
| 683 |
+
* `pitch bootstrap (+0.05)` ensures the pitch channel is exercised
|
| 684 |
+
at all before the model is good enough to earn the semantic
|
| 685 |
+
persuasion bonus (+0.6 × pitch_score). Without this, RL can
|
| 686 |
+
collapse to always-empty pitches and never explore the channel.
|
| 687 |
+
* `pitch_score(pitch, role) ∈ [0, 1]` is computed by the
|
| 688 |
+
`_PitchScorer` (sentence-transformer cosine, TF-IDF fallback)
|
| 689 |
+
against each role's hidden manifesto — *graded language*, not
|
| 690 |
+
keyword matching, so pitches must genuinely articulate role
|
| 691 |
+
priorities.
|
| 692 |
+
* `coalition +1.0 / -0.4` keeps the agent honest about *winning
|
| 693 |
+
votes*, not just picking option strings that look good.
|
| 694 |
+
* `trust × 0.5` rewards long-arc coalition building rather than
|
| 695 |
+
single-round opportunism.
|
| 696 |
+
"""
|
| 697 |
s = self._state.state_dict
|
| 698 |
|
| 699 |
if s["done_reason"] is not None or s["round"] > len(EVENTS):
|
inference.py
ADDED
|
@@ -0,0 +1,426 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
NeuralEdge AI Boardroom — Inference Script
|
| 3 |
+
==========================================
|
| 4 |
+
Loads the trained Qwen3-0.6B LoRA adapter and runs the BoardSim environment
|
| 5 |
+
interactively or in batch evaluation mode.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
python inference.py --mode interactive
|
| 9 |
+
python inference.py --mode eval --episodes 10 --seed 42
|
| 10 |
+
python inference.py --mode compare --episodes 50
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import argparse
|
| 16 |
+
import json
|
| 17 |
+
import os
|
| 18 |
+
import random
|
| 19 |
+
import re
|
| 20 |
+
import statistics
|
| 21 |
+
import sys
|
| 22 |
+
import textwrap
|
| 23 |
+
import time
|
| 24 |
+
from contextlib import contextmanager
|
| 25 |
+
from dataclasses import dataclass, field, asdict
|
| 26 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 27 |
+
|
| 28 |
+
ROOT = os.path.abspath(os.path.dirname(__file__))
|
| 29 |
+
sys.path.insert(0, ROOT)
|
| 30 |
+
sys.path.insert(0, os.path.join(ROOT, "envs"))
|
| 31 |
+
sys.path.insert(0, os.path.join(ROOT, "envs", "board_sim_env"))
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
DEFAULT_HF_SPACE = "https://stavankhobare-sst-metaxpytorch-hackathon.hf.space"
|
| 35 |
+
DEFAULT_MODEL = "Qwen/Qwen3-0.6B"
|
| 36 |
+
DEFAULT_ADAPTER = os.path.join(ROOT, "adapter_model.safetensors")
|
| 37 |
+
MAX_NEW_TOKENS = 96
|
| 38 |
+
MAX_PROMPT_LEN = 1024
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
SYSTEM_PROMPT = """You are Sarah Chen, CEO of NeuralEdge AI (Series-B AI startup). Your board has 4 members with HIDDEN AGENDAS you cannot see directly:
|
| 42 |
+
- CTO: cares about operational excellence, engineering quality, team morale, and product readiness.
|
| 43 |
+
- CFO: cares about cash discipline, runway, and regulatory safety.
|
| 44 |
+
- Investor Rep: pushes growth, market share, and bold returns.
|
| 45 |
+
- Independent: cares about reputation, governance, and long-term consensus.
|
| 46 |
+
|
| 47 |
+
Each round you see a strategic event, every NPC's pre-vote statement, and 3 options.
|
| 48 |
+
Your decision is resolved by WEIGHTED VOTE (your weight 2.5x). A short COALITION PITCH
|
| 49 |
+
that is semantically aligned with opposing members' priorities can swing them toward your pick —
|
| 50 |
+
write substantive arguments, not just buzzwords.
|
| 51 |
+
|
| 52 |
+
Respond in EXACTLY this format on two lines:
|
| 53 |
+
DECISION: <one of the option strings>
|
| 54 |
+
PITCH: <one or two sentences arguing for it, addressing the concerns of opposing members>"""
|
| 55 |
+
|
| 56 |
+
DECISION_RE = re.compile(r"DECISION\s*:\s*([^\n]+)", re.IGNORECASE)
|
| 57 |
+
PITCH_RE = re.compile(r"PITCH\s*:\s*(.+)", re.IGNORECASE | re.DOTALL)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
PITCH_KEYWORDS: Dict[str, List[str]] = {
|
| 61 |
+
"CTO": ["engineering", "operational", "quality", "team", "morale", "product readiness",
|
| 62 |
+
"technical", "reliability", "ship", "milestone", "velocity"],
|
| 63 |
+
"CFO": ["runway", "burn", "cash", "compliance", "regulatory", "balance sheet",
|
| 64 |
+
"discipline", "cost", "margin", "risk"],
|
| 65 |
+
"Investor Rep": ["growth", "market share", "returns", "scale", "valuation", "expansion",
|
| 66 |
+
"ambitious", "upside", "tam", "moat"],
|
| 67 |
+
"Independent": ["reputation", "governance", "stakeholder", "long-term", "ethics",
|
| 68 |
+
"consensus", "trust", "responsibility", "board"],
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
@dataclass
class EpisodeMetrics:
    """Outcome record for a single evaluation episode, filled in by ``run_episode``."""

    # Seed passed to ``env.reset`` for this episode.
    seed: int
    # Sum of the per-step rewards returned by the environment.
    total_reward: float
    # ``profitability_score`` read from the final observation's state.
    final_profitability: float
    # False only when the episode ended with done_reason "runway_exhausted".
    survived: bool
    # Rounds where the last history entry reported ``agent_won_vote``.
    votes_won: int
    # Number of environment steps taken.
    votes_total: int
    # Steps on which the policy produced a non-blank pitch.
    pitches_written: int
    # Mean local keyword score over all (pitch, opposing role) pairs; 0.0 if none.
    avg_pitch_score: float
    # Latest ``trust_history`` list reported by the environment state.
    trust_trajectory: List[Dict[str, float]] = field(default_factory=list)
    # Decision string submitted on each step, in order.
    decisions: List[str] = field(default_factory=list)
    # ``done_reason`` from the final state, if the environment supplied one.
    done_reason: Optional[str] = None
    # Label of the policy that played the episode.
    policy: str = "unknown"
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
@dataclass
class RunSummary:
    """Aggregate statistics for one policy over a batch of episodes."""

    # Which policy the batch belongs to, and how many episodes it contains.
    policy: str
    n_episodes: int
    # Episode-level total reward: mean and spread.
    mean_reward: float
    std_reward: float
    # Final profitability score: mean and spread.
    mean_profitability: float
    std_profitability: float
    # Rate-style aggregates over the batch.
    survival_rate: float
    win_rate_per_round: float
    pitch_usage_rate: float
    mean_pitch_score: float
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def _match_option(text: str, options: List[str]) -> Optional[str]:
    """Return the option that best matches ``text`` (case-insensitive), or ``None``."""
    lowered = text.strip().lower()
    # An exact match always wins, regardless of option order.
    for opt in options:
        if opt.lower() == lowered:
            return opt
    # Otherwise prefer the LONGEST option contained in the text, so that an
    # option which is a substring of another cannot shadow the more specific one.
    contained = [opt for opt in options if opt and opt.lower() in lowered]
    return max(contained, key=len) if contained else None


def parse_completion(completion: str, options: List[str]) -> Tuple[str, str, bool]:
    """Parse a model completion into ``(decision, pitch, format_ok)``.

    Args:
        completion: Raw text generated by the model.
        options: The option strings offered this round.

    Returns:
        A tuple of:
        * decision -- the option matched from the ``DECISION:`` line; failing
          that, an option mentioned anywhere in the completion; failing that,
          the first option (or ``""`` when ``options`` is empty).
        * pitch -- first line of the text after ``PITCH:``, capped at 400
          characters, or ``""`` when no ``PITCH:`` marker is present.
        * format_ok -- True only when both markers were found.
    """
    dm = DECISION_RE.search(completion)
    decision = _match_option(dm.group(1), options) if dm else None
    if decision is None:
        # The DECISION line was missing or unusable: scan the whole completion.
        decision = _match_option(completion, options)
    if decision is None:
        # Default to the first option; tolerate an empty option list.
        decision = options[0] if options else ""

    pm = PITCH_RE.search(completion)
    pitch = ""
    if pm:
        # PITCH_RE is DOTALL, so keep only the first line of what it captured.
        pitch = pm.group(1).strip().split("\n")[0][:400]

    format_ok = bool(dm) and bool(pm)
    return decision, pitch, format_ok
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def keyword_pitch_score(pitch: str, role: str) -> float:
    """Score a pitch in [0, 1] by how many of ``role``'s keywords it mentions.

    Keywords are matched case-insensitively at the start of a word, so
    inflected forms still count ("ship" matches "shipping") while keywords
    embedded in unrelated words do not ("ship" in "leadership", "tam" in
    "stamina", "board" in "onboard").

    Args:
        pitch: The coalition pitch text; empty text scores 0.0.
        role: Key into ``PITCH_KEYWORDS``; unknown roles score 0.0.

    Returns:
        ``min(hits / 4, 1.0)`` where ``hits`` is the number of distinct
        keywords found.
    """
    if not pitch:
        return 0.0
    text = pitch.lower()
    hits = sum(
        1
        for kw in PITCH_KEYWORDS.get(role, [])
        # (?<!\w) requires the keyword to begin at a word boundary.
        if re.search(r"(?<!\w)" + re.escape(kw), text)
    )
    return min(hits / 4.0, 1.0)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def build_prompt(obs: Any) -> str:
    """Render an observation as the text prompt given to the model.

    The prompt is the system prompt, a blank line, then the round number,
    a one-line state summary, the event, one line per board member's
    pre-vote statement, and the list of options.
    """
    board_lines = []
    for member in obs.npc_statements:
        confidence = member.get('confidence', 0.5)
        board_lines.append(
            f" {member['role']} (conf {confidence:.2f}): votes {member['vote']} — {member['statement']}"
        )
    statements = "\n".join(board_lines)

    state = obs.state
    header = f"{SYSTEM_PROMPT}\n\n" f"Round: {obs.round}/10\n"
    state_line = (
        f"State: revenue=${state.get('revenue', 0):.0f}/yr "
        f"burn=${state.get('burn_rate', 0):.0f}/mo "
        f"runway={state.get('runway_months', 0):.1f}mo "
        f"morale={state.get('team_morale', 0):.2f} "
        f"investors={state.get('investor_confidence', 0):.2f} "
        f"reg_risk={state.get('regulatory_risk', 0):.2f}\n"
    )
    footer = (
        f"Event: {obs.event}\n"
        f"Board:\n{statements}\n"
        f"Options: {obs.options}\n"
    )
    return header + state_line + footer
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
@contextmanager
def make_env_client(env_url: str):
    """Yield an environment exposing ``reset(seed=...)`` and ``step(action)``.

    When ``env_url`` starts with ``http://`` or ``https://`` the synchronous
    ``BoardSimEnv`` client for that URL is yielded. Any other value selects an
    in-process ``BoardSimEnvironment`` wrapped so that ``reset``/``step``
    return an object with ``observation``, ``reward`` and ``done``, like the
    client's results.

    Raises:
        RuntimeError: if the ``board_sim_env`` client package cannot be
            imported; the original import error is attached as ``__cause__``.
    """
    try:
        from board_sim_env.client import BoardSimEnv
    except Exception as e:
        raise RuntimeError(
            f"Cannot import BoardSimEnv client: {e}. "
            "Run from the repo root or `pip install -e envs/board_sim_env`."
        ) from e

    if env_url.lower().startswith(("http://", "https://")):
        with BoardSimEnv(base_url=env_url).sync() as env:
            yield env
        return

    from envs.board_sim_env.server.board_sim_env_environment import BoardSimEnvironment

    @dataclass
    class _Result:
        # Observation returned by the in-process environment.
        observation: Any

        @property
        def reward(self):
            # A missing/None reward is reported as 0.0.
            return float(self.observation.reward or 0.0)

        @property
        def done(self):
            return bool(self.observation.done)

    class _LocalEnv:
        """Thin adapter giving the in-process env the client's result shape."""

        def __init__(self):
            self._env = BoardSimEnvironment()

        def reset(self, seed: int = 0):
            return _Result(self._env.reset(seed=seed))

        def step(self, action):
            return _Result(self._env.step(action))

    yield _LocalEnv()
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
class TrainedPolicy:
    """Language-model policy: base model plus an optional LoRA adapter.

    Loading is attempted with Unsloth first and plain ``transformers`` second.
    If both fail, the instance sets ``fallback = True`` and ``act`` returns a
    uniformly random option with an empty pitch.
    """

    def __init__(self, model_path: str, adapter_path: str, device: str = "auto"):
        self.model = None
        self.tokenizer = None
        self.device = device
        self.fallback = False
        try:
            self._load(model_path, adapter_path)
        except Exception as e:
            print(f"[trained-policy] WARN: model load failed ({e}). Falling back to random policy.")
            self.fallback = True

    def _attach_adapter(self, adapter_path: str) -> None:
        """Wrap ``self.model`` with the adapter stored next to ``adapter_path``."""
        if not os.path.exists(adapter_path):
            # Without this notice the base model would be evaluated as if it
            # were the fine-tuned one.
            print(f"[trained-policy] WARN: adapter not found at {adapter_path}; using base model weights only.")
            return
        from peft import PeftModel
        self.model = PeftModel.from_pretrained(self.model, os.path.dirname(adapter_path) or ROOT)

    def _load(self, model_path: str, adapter_path: str):
        """Load model and tokenizer, preferring Unsloth over plain transformers."""
        try:
            import unsloth  # noqa: F401
            from unsloth import FastLanguageModel
            self.model, self.tokenizer = FastLanguageModel.from_pretrained(
                model_name=model_path, max_seq_length=2048, load_in_4bit=True, dtype=None,
            )
            self._attach_adapter(adapter_path)
            FastLanguageModel.for_inference(self.model)
        except Exception as e:
            # Report why the preferred path failed before trying the fallback.
            print(f"[trained-policy] INFO: Unsloth load failed ({e}); retrying with plain transformers.")
            import torch
            from transformers import AutoModelForCausalLM, AutoTokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            self.model = AutoModelForCausalLM.from_pretrained(
                model_path, torch_dtype=torch.float16,
                device_map="auto" if self.device == "auto" else self.device,
            )
            self._attach_adapter(adapter_path)
            self.model.eval()
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def act(self, obs: Any) -> Tuple[str, str, bool]:
        """Return ``(decision, pitch, format_ok)`` for the given observation."""
        if self.fallback or self.model is None:
            return random.choice(obs.options), "", False
        import torch
        prompt = build_prompt(obs)
        device = next(self.model.parameters()).device
        enc = self.tokenizer(prompt, return_tensors="pt", truncation=True,
                             max_length=MAX_PROMPT_LEN).to(device)
        with torch.no_grad():
            out = self.model.generate(
                **enc, max_new_tokens=MAX_NEW_TOKENS, do_sample=False,
                pad_token_id=self.tokenizer.eos_token_id,
            )
        # Decode only the newly generated tokens, not the echoed prompt.
        completion = self.tokenizer.decode(out[0][enc.input_ids.shape[1]:],
                                           skip_special_tokens=True)
        return parse_completion(completion, obs.options)
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
class RandomPolicy:
    """Baseline policy: picks a uniformly random option and writes no pitch.

    Args:
        seed: Optional seed for a private random generator, making the
            baseline reproducible. When omitted, the module-level ``random``
            generator is used, as before.
    """

    def __init__(self, seed: Optional[int] = None):
        self._rng = random.Random(seed) if seed is not None else random

    def act(self, obs: Any) -> Tuple[str, str, bool]:
        """Return ``(decision, "", False)`` with ``decision`` drawn from ``obs.options``."""
        return self._rng.choice(obs.options), "", False
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def run_episode(env: Any, policy: Any, seed: int, policy_name: str,
                max_steps: int = 100) -> EpisodeMetrics:
    """Play one episode with ``policy`` and collect its metrics.

    Args:
        env: Object with ``reset(seed=...)`` and ``step(action)`` whose results
            expose ``observation``, ``reward`` and ``done``.
        policy: Object with ``act(obs) -> (decision, pitch, format_ok)``.
        seed: Seed passed to ``env.reset``.
        policy_name: Label stored on the returned metrics.
        max_steps: Safety cap on steps, so an environment that never reports
            ``done`` cannot hang the run. If the cap is hit and the state
            carries no ``done_reason``, ``"max_steps_exceeded"`` is recorded.

    Returns:
        The populated ``EpisodeMetrics`` for this episode.
    """
    from board_sim_env.models import BoardSimAction

    result = env.reset(seed=seed)
    obs = result.observation

    metrics = EpisodeMetrics(
        seed=seed, total_reward=0.0, final_profitability=0.0,
        survived=True, votes_won=0, votes_total=0, pitches_written=0,
        avg_pitch_score=0.0, policy=policy_name,
    )
    pitch_scores: List[float] = []

    steps = 0
    while not result.done and steps < max_steps:
        decision, pitch, _ = policy.act(obs)
        if pitch.strip():
            metrics.pitches_written += 1
            # Score the pitch against members whose pre-vote statement
            # disagrees with the chosen decision.
            opposing = [s["role"] for s in obs.npc_statements if s["vote"] != decision]
            pitch_scores.extend(keyword_pitch_score(pitch, role) for role in opposing)

        result = env.step(BoardSimAction(decision=decision, coalition_pitch=pitch))
        obs = result.observation
        steps += 1

        metrics.total_reward += float(result.reward or 0.0)
        metrics.votes_total += 1
        history = obs.state.get("history", [])
        if history and history[-1].get("agent_won_vote"):
            metrics.votes_won += 1
        metrics.decisions.append(decision)
        if obs.state.get("trust_history"):
            metrics.trust_trajectory = obs.state["trust_history"]

    metrics.final_profitability = float(obs.state.get("profitability_score", 0.0))
    metrics.done_reason = obs.state.get("done_reason")
    if not result.done and metrics.done_reason is None:
        # The loop stopped on the step cap, not because the episode ended.
        metrics.done_reason = "max_steps_exceeded"
    metrics.survived = metrics.done_reason != "runway_exhausted"
    metrics.avg_pitch_score = statistics.mean(pitch_scores) if pitch_scores else 0.0
    return metrics
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def summarise(policy: str, eps: List[EpisodeMetrics]) -> RunSummary:
    """Aggregate per-episode metrics into a single ``RunSummary``.

    Args:
        policy: Label stored on the summary.
        eps: Episode metrics to aggregate; may be empty.

    Returns:
        A ``RunSummary`` with mean/std of reward and profitability, the
        survival rate, the per-round vote win rate, the pitch usage rate and
        the mean pitch score. For an empty ``eps`` every statistic is ``0.0``
        and ``n_episodes`` is ``0``.
    """
    n = len(eps)
    if n == 0:
        # statistics.mean() raises StatisticsError on empty data and the
        # survival rate would divide by zero (e.g. when run with --episodes 0),
        # so report an explicit all-zero summary instead of crashing.
        return RunSummary(
            policy=policy, n_episodes=0,
            mean_reward=0.0, std_reward=0.0,
            mean_profitability=0.0, std_profitability=0.0,
            survival_rate=0.0, win_rate_per_round=0.0,
            pitch_usage_rate=0.0, mean_pitch_score=0.0,
        )

    rewards = [e.total_reward for e in eps]
    profits = [e.final_profitability for e in eps]
    # Shared denominator for both per-round rates; clamped so an all-empty
    # run (no votes at all) yields 0.0 rather than a division by zero.
    total_votes = max(1, sum(e.votes_total for e in eps))
    return RunSummary(
        policy=policy, n_episodes=n,
        mean_reward=statistics.mean(rewards),
        # stdev needs at least two data points.
        std_reward=statistics.stdev(rewards) if n > 1 else 0.0,
        mean_profitability=statistics.mean(profits),
        std_profitability=statistics.stdev(profits) if n > 1 else 0.0,
        survival_rate=sum(e.survived for e in eps) / n,
        win_rate_per_round=sum(e.votes_won for e in eps) / total_votes,
        pitch_usage_rate=sum(e.pitches_written for e in eps) / total_votes,
        mean_pitch_score=statistics.mean(e.avg_pitch_score for e in eps),
    )
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def print_summary_table(*summaries: RunSummary) -> None:
    """Print an aligned text table with one row per run summary.

    Output is a blank line, the header, a dashed rule as long as the header,
    one row per summary, and a trailing blank line.

    Column widths start from fixed minimums and grow to fit the longest cell.
    The ``mean ± std`` cells are at least 13 characters long, which overflowed
    the previous fixed width of 12 and shifted every later column out of line
    with the header.
    """
    cols = ["policy", "n", "mean_reward", "mean_profit", "survival", "win_rate", "pitch_use", "pitch_score"]
    min_width = [14, 4, 12, 12, 9, 9, 10, 12]

    # Render all cells first so the widths can account for their real length.
    rows = [
        [
            s.policy,
            str(s.n_episodes),
            f"{s.mean_reward:+.3f} ± {s.std_reward:.2f}",
            f"{s.mean_profitability:.2f} ± {s.std_profitability:.2f}",
            f"{s.survival_rate:.1%}",
            f"{s.win_rate_per_round:.1%}",
            f"{s.pitch_usage_rate:.1%}",
            f"{s.mean_pitch_score:.3f}",
        ]
        for s in summaries
    ]
    width = [
        max(floor, len(col), *(len(row[i]) for row in rows))
        for i, (col, floor) in enumerate(zip(cols, min_width))
    ]

    header = " ".join(c.ljust(w) for c, w in zip(cols, width))
    print("\n" + header)
    print("-" * len(header))
    for row in rows:
        print(" ".join(cell.ljust(w) for cell, w in zip(row, width)))
    print()
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
def mode_eval(args, env_url: str) -> None:
    """Run ``args.episodes`` evaluation episodes and print a summary table.

    Episode ``i`` uses seed ``args.seed + i``. One progress line is printed
    per episode. When ``args.out`` is non-empty, the per-episode metrics are
    also written there as JSON.
    """
    policy = TrainedPolicy(args.model_path, args.adapter_path, args.device)
    if policy.fallback:
        name = "random-fallback"
    else:
        name = "trained-qwen3-0.6b"

    eps: List[EpisodeMetrics] = []
    with make_env_client(env_url) as env:
        seeds = range(args.seed, args.seed + args.episodes)
        for ordinal, seed in enumerate(seeds, start=1):
            ep = run_episode(env, policy, seed, name)
            eps.append(ep)
            progress = (
                f" ep {ordinal:3d}/{args.episodes} seed={seed} "
                f"reward={ep.total_reward:+.2f} profit={ep.final_profitability:5.1f} "
                f"won={ep.votes_won}/{ep.votes_total} pitches={ep.pitches_written}"
            )
            print(progress)

    print_summary_table(summarise(name, eps))

    if not args.out:
        return
    with open(args.out, "w") as f:
        json.dump([asdict(e) for e in eps], f, indent=2)
    print(f"Wrote {args.out}")
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
def mode_compare(args, env_url: str) -> None:
    """Evaluate the trained policy and a random policy on the same seeds.

    Both policies play ``args.episodes`` episodes with seeds
    ``args.seed + i`` through one shared environment client. Their summaries
    are then printed side by side.
    """
    trained = TrainedPolicy(args.model_path, args.adapter_path, args.device)
    rand = RandomPolicy()
    if trained.fallback:
        trained_name = "random-fallback"
    else:
        trained_name = "trained-qwen3-0.6b"

    with make_env_client(env_url) as env:
        print(f"\n[compare] running {args.episodes} episodes with {trained_name}...")
        trained_eps = [
            run_episode(env, trained, args.seed + offset, trained_name)
            for offset in range(args.episodes)
        ]
        print(f"[compare] running {args.episodes} episodes with random policy...")
        rand_eps = [
            run_episode(env, rand, args.seed + offset, "random")
            for offset in range(args.episodes)
        ]

    trained_summary = summarise(trained_name, trained_eps)
    random_summary = summarise("random", rand_eps)
    print_summary_table(trained_summary, random_summary)
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
def mode_interactive(args, env_url: str) -> None:
    """Let a human play one episode from the terminal.

    Each round prints the event, every NPC's vote and statement, and the
    available options. It then reads a decision and a pitch from stdin. The
    decision is the first option whose text occurs (case-insensitively) in
    the typed input; empty or unmatched input falls back to the first option.

    Args:
        args: Parsed CLI namespace; only ``args.seed`` is read here.
        env_url: Passed to ``make_env_client``.
    """
    from board_sim_env.models import BoardSimAction

    print("\nNeuralEdge AI Boardroom — interactive (human-play) mode")
    print("Type DECISION, then PITCH on a separate line. Empty input picks option[0].\n")
    with make_env_client(env_url) as env:
        result = env.reset(seed=args.seed)
        obs = result.observation
        ep_reward = 0.0
        while not result.done:
            print("=" * 70)
            print(f"Round {obs.round}/10 — score={obs.state.get('profitability_score', 0):.1f} "
                  f"runway={obs.state.get('runway_months', 0):.1f}mo")
            print(f"Event: {obs.event}")
            for s in obs.npc_statements:
                print(f" [{s['role']:13s}] votes {s['vote']:<28s} (conf {s.get('confidence', 0.5):.2f})")
                print(f" {textwrap.fill(s['statement'], 90, subsequent_indent=' ')}")
            print(f"Options: {obs.options}")

            d_raw = input("DECISION: ").strip() or obs.options[0]
            typed = d_raw.lower()
            decision = next((o for o in obs.options if o.lower() in typed), obs.options[0])
            pitch = input("PITCH: ").strip()

            result = env.step(BoardSimAction(decision=decision, coalition_pitch=pitch))
            obs = result.observation
            # The reward may be None; normalise once so both the running total
            # and the printed value use the same number. Formatting None with
            # ":+.3f" would raise TypeError.
            step_reward = float(result.reward or 0.0)
            ep_reward += step_reward
            print(f">>> reward {step_reward:+.3f} cumulative {ep_reward:+.3f}")
        print(f"\nDONE. final profitability={obs.state.get('profitability_score', 0):.2f} "
              f"reason={obs.state.get('done_reason')} total_reward={ep_reward:+.2f}")
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
def parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse the process arguments.

    The module docstring is used verbatim as the help description. The
    default environment URL comes from ``ENV_BASE_URL`` when that variable is
    set, otherwise from ``DEFAULT_HF_SPACE``.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    add = parser.add_argument

    add("--mode", choices=["interactive", "eval", "compare"], default="eval")
    add("--model_path", default=DEFAULT_MODEL)
    add("--adapter_path", default=DEFAULT_ADAPTER)

    default_env_url = os.environ.get("ENV_BASE_URL", DEFAULT_HF_SPACE)
    add("--env_url", default=default_env_url,
        help="HF Space URL or 'local' for in-process env")

    add("--episodes", type=int, default=10)
    add("--seed", type=int, default=42)
    add("--device", default="auto")
    add("--out", default="", help="Write per-episode JSON to this path")

    return parser.parse_args()
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
def main() -> None:
    """CLI entry point: parse arguments, seed ``random``, run the chosen mode.

    Prints a short banner with the effective settings first and the elapsed
    wall-clock time of the selected mode last.
    """
    args = parse_args()
    random.seed(args.seed)

    print(f"NeuralEdge AI Boardroom — inference (mode={args.mode})")
    print(f" env_url = {args.env_url}")
    print(f" model_path = {args.model_path}")
    if os.path.exists(args.adapter_path):
        adapter_status = "(found)"
    else:
        adapter_status = "(missing → random fallback)"
    print(f" adapter = {args.adapter_path} {adapter_status}")

    started = time.time()
    # Any mode other than "interactive" or "eval" runs the comparison.
    runners = {"interactive": mode_interactive, "eval": mode_eval}
    run_mode = runners.get(args.mode, mode_compare)
    run_mode(args, args.env_url)
    print(f"\nelapsed: {time.time() - started:.1f}s")


if __name__ == "__main__":
    main()
|
notebooks/train_cell_fixed.py
ADDED
|
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# =============================================================================
|
| 2 |
+
# GRPO training cell — fixed version
|
| 3 |
+
#
|
| 4 |
+
# Fixes:
|
| 5 |
+
# 1. RuntimeError "variable modified by an inplace operation" on loss.backward().
|
| 6 |
+
# Root cause: model.generate() leaves use_cache=True, and the subsequent
|
| 7 |
+
# forward pass returns logits that share storage with KV-cache buffers,
|
| 8 |
+
# which get mutated later. Fix: force use_cache=False on the training
|
| 9 |
+
# forward pass, and .clone() the logits slice before computing log_softmax.
|
| 10 |
+
#
|
| 11 |
+
# 2. GPU OOM on cell re-run. Root cause: re-running the cell creates a fresh
|
| 12 |
+
# AdamW (which holds momentum buffers ~= model size) without freeing the
|
| 13 |
+
# previous one. Fix: explicit cleanup of any prior optimizer / cached
|
| 14 |
+
# tensors at the top of the cell + gc + empty_cache. Model itself is NOT
|
| 15 |
+
# reloaded here (load it once in an earlier cell); we just reuse it.
|
| 16 |
+
#
|
| 17 |
+
# 3. wandb deprecation warning for reinit=True. Use finish_previous=True only.
|
| 18 |
+
# =============================================================================
|
| 19 |
+
|
| 20 |
+
import os, gc, json, time, collections
|
| 21 |
+
import torch
|
| 22 |
+
from torch.optim import AdamW
|
| 23 |
+
|
| 24 |
+
# ---- 0. cleanup any leftover state from previous runs of this cell ----------
# Drop stale module-level references (previous optimizer and per-step tensors)
# so the garbage collector can reclaim them; names that are absent are skipped.
for _name in ('optimizer', 'gen_out', 'out', 'logits', 'loss',
              'log_probs', 'token_nll', 'per_seq_nll', 'advantages'):
    globals().pop(_name, None)

gc.collect()
_cuda_present = torch.cuda.is_available()
if _cuda_present:
    # Hand cached allocator blocks back to the driver after the collection.
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
del _cuda_present
|
| 36 |
+
|
| 37 |
+
# ---- 1. config --------------------------------------------------------------
# Step count and group size can be overridden through environment variables.
NUM_STEPS = int(os.environ.get('NUM_STEPS', 100))
GROUP_SIZE = int(os.environ.get('GROUP_SIZE', 4))

LR = 5e-6
GRAD_CLIP = 1.0

TEMPERATURE = 1.0
TOP_P = 0.95

SAVE_EVERY = 25
# Steps at which the periodic evaluation runs (always includes the last step).
EVAL_AT = {0, 25, 50, 75, NUM_STEPS - 1}

# Critical: kill KV cache on the training forward pass.
# generate() will still build its own cache internally; we override afterwards.
model.config.use_cache = False
if hasattr(model, 'gradient_checkpointing_disable'):
    model.gradient_checkpointing_disable()
model.train()
|
| 51 |
+
|
| 52 |
+
# ---- 2. wandb (no deprecated reinit=True) -----------------------------------
# WANDB_OK stays False unless a key is configured AND init succeeds; all later
# logging is gated on it, so a W&B failure never stops training.
WANDB_OK = False
if os.environ.get('WANDB_API_KEY'):
    try:
        import wandb
        wandb.init(
            project='boardsim-qwen3-grpo',
            name='boardsim-qwen3-1p7b-kaggle',
            config={'num_steps': NUM_STEPS, 'group_size': GROUP_SIZE, 'lr': LR,
                    'temperature': TEMPERATURE, 'top_p': TOP_P, 'model': MODEL_NAME},
            # `finish_previous=True` is not a wandb.init() keyword: it raised a
            # TypeError that the handler below swallowed, silently disabling
            # logging. The non-deprecated replacement for `reinit=True` is the
            # string value of `reinit`.
            reinit='finish_previous',
        )
        WANDB_OK = True
    except Exception as e:
        # Best effort: report the failure and continue without W&B.
        print(f'WARN: wandb.init failed: {e}')
|
| 67 |
+
|
| 68 |
+
# ---- 3. optimizer (single owner, freshly built each cell run) ---------------
# Only parameters that require gradients are handed to the optimizer.
optimizer = AdamW(
    list(filter(lambda p: p.requires_grad, model.parameters())),
    lr=LR,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0.0,
)

# Run bookkeeping: per-step records, periodic eval records, decision tally.
log_history = []
eval_history = []
decision_counter = collections.Counter()
t0 = time.time()
|
| 77 |
+
|
| 78 |
+
# ---- 4. training loop -------------------------------------------------------
# Three independent env sessions: one supplies the prompt for each step, one is
# reset and stepped once per sampled completion to score it, and one is handed
# to periodic_eval.
with make_env().sync() as env_train, \
        make_env().sync() as env_score, \
        make_env().sync() as env_eval:

    for step in range(NUM_STEPS):
        # 4a. rollout: sample GROUP_SIZE completions for the step's first prompt
        result = env_train.reset(seed=step)
        obs = result.observation
        prompt = build_prompt(obs)
        enc = tokenizer(prompt, return_tensors='pt',
                        truncation=True, max_length=1024).to(device)
        prompt_len = enc.input_ids.shape[1]

        with torch.no_grad():
            gen_out = model.generate(
                input_ids=enc.input_ids,
                attention_mask=enc.attention_mask,
                max_new_tokens=MAX_NEW_TOKENS,
                do_sample=True,
                temperature=TEMPERATURE,
                top_p=TOP_P,
                num_return_sequences=GROUP_SIZE,
                pad_token_id=tokenizer.eos_token_id,
                use_cache=True,  # cache OK during generate (no_grad context)
            )
        # Detach + clone so no autograd ties to generate's internal buffers.
        gen_out = gen_out.detach().clone()

        # 4b. score each completion against a freshly reset env (same seed)
        decisions, pitches, rewards, fmt_oks = [], [], [], []
        for g in range(GROUP_SIZE):
            comp = tokenizer.decode(gen_out[g][prompt_len:], skip_special_tokens=True)
            d, pp, ok = parse_completion(comp, obs.options)
            decisions.append(d); pitches.append(pp); fmt_oks.append(ok)
            decision_counter[d] += 1
            env_score.reset(seed=step)
            sr = env_score.step(BoardSimAction(decision=d, coalition_pitch=pp))
            rewards.append(float(sr.reward or 0.0))

        # Group-relative advantages: centre on the group mean and, when the
        # rewards actually vary, scale by the group standard deviation.
        rewards_t = torch.tensor(rewards, dtype=torch.float32, device=device)
        if rewards_t.numel() > 1 and rewards_t.std().item() > 1e-6:
            advantages = (rewards_t - rewards_t.mean()) / (rewards_t.std() + 1e-8)
        else:
            advantages = rewards_t - rewards_t.mean()
        advantages = advantages.detach()

        # 4c. policy update — fresh forward, NO cache, clone logits
        optimizer.zero_grad(set_to_none=True)

        full_ids = gen_out
        # generate() pads finished sequences with eos_token_id (see
        # pad_token_id above), so comparing against tokenizer.pad_token_id is
        # wrong: if the two ids differ nothing is masked and the padding is
        # trained on; if they are equal the terminating EOS (and any EOS inside
        # the prompt) is dropped. Instead keep the whole prompt plus every
        # completion token up to and including the FIRST EOS.
        completion_ids = full_ids[:, prompt_len:]
        eos_hits = (completion_ids == tokenizer.eos_token_id).long()
        # Number of EOS tokens strictly before each position; > 0 means padding.
        after_first_eos = (eos_hits.cumsum(dim=1) - eos_hits) > 0
        attn = torch.cat(
            [torch.ones_like(full_ids[:, :prompt_len]), (~after_first_eos).long()],
            dim=1,
        )
        loss_mask = attn.clone()
        loss_mask[:, :prompt_len] = 0  # never train on prompt tokens

        out = model(
            input_ids=full_ids,
            attention_mask=attn,
            use_cache=False,  # <-- key fix
            return_dict=True,
        )
        # Clone the slice so backward sees a tensor whose storage we own.
        # Logits at position t predict token t + 1, hence the one-step shift.
        logits = out.logits[:, :-1, :].float().clone()
        targets = full_ids[:, 1:].contiguous()
        mask = loss_mask[:, 1:].float()

        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
        token_nll = -log_probs.gather(2, targets.unsqueeze(-1)).squeeze(-1)
        # Mean NLL over each sequence's own completion tokens.
        per_seq_nll = (token_nll * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        # Minimising advantage * NLL raises the likelihood of above-average
        # completions and lowers it for below-average ones.
        loss = (advantages * per_seq_nll).mean()

        loss.backward()
        total_loss_val = float(loss.detach().item())

        torch.nn.utils.clip_grad_norm_(
            [p for p in model.parameters() if p.requires_grad], GRAD_CLIP)
        optimizer.step()

        # Free per-step graph tensors before next iter (helps on tight VRAM).
        del out, logits, log_probs, token_nll, per_seq_nll, loss
        del completion_ids, eos_hits, after_first_eos

        # 4d. log
        rec = {
            'step': step,
            'reward': float(rewards_t.mean().item()),
            'reward_std': float(rewards_t.std().item()) if rewards_t.numel() > 1 else 0.0,
            'reward_max': float(rewards_t.max().item()),
            'loss': total_loss_val,
            'format_rate': sum(fmt_oks) / GROUP_SIZE,
            'pitch_rate': sum(1 for p in pitches if p.strip()) / GROUP_SIZE,
            'elapsed_s': time.time() - t0,
        }
        log_history.append(rec)
        if WANDB_OK:
            wandb.log(rec, step=step)

        if step % 5 == 0:
            print(f"step={step:4d} reward={rec['reward']:+.3f} (\u00b1{rec['reward_std']:.2f}) "
                  f"loss={rec['loss']:+.4f} fmt={rec['format_rate']:.0%} "
                  f"elapsed={rec['elapsed_s']:.0f}s d0={decisions[0]}")

        # 4e. periodic eval
        if step in EVAL_AT:
            ev = periodic_eval(env_eval)
            ev['step'] = step
            eval_history.append(ev)
            print(f" [eval@{step}] profit={ev['profit_mean']:.2f} "
                  f"reward={ev['reward_mean']:.2f} fmt={ev['format_rate']:.0%}")
            if WANDB_OK:
                wandb.log({f'eval/{k}': v for k, v in ev.items() if k != 'step'}, step=step)

        # 4f. checkpoint (adapter, tokenizer and both histories)
        if step > 0 and step % SAVE_EVERY == 0:
            model.save_pretrained(str(CKPT))
            tokenizer.save_pretrained(str(CKPT))
            with open(WORK_DIR / 'log_history.json', 'w') as f:
                json.dump(log_history, f)
            with open(WORK_DIR / 'eval_history.json', 'w') as f:
                json.dump(eval_history, f)
|
| 197 |
+
|
| 198 |
+
# ---- 5. final save ----------------------------------------------------------
# Persist the model and tokenizer to the checkpoint directory, then the run
# histories and the decision tally as JSON next to it.
_ckpt_dir = str(CKPT)
model.save_pretrained(_ckpt_dir)
tokenizer.save_pretrained(_ckpt_dir)
del _ckpt_dir

with open(WORK_DIR / 'log_history.json', 'w') as f:
    f.write(json.dumps(log_history))
with open(WORK_DIR / 'eval_history.json', 'w') as f:
    f.write(json.dumps(eval_history))
with open(WORK_DIR / 'decision_counter.json', 'w') as f:
    f.write(json.dumps(dict(decision_counter)))

if WANDB_OK:
    wandb.finish()

print(f'Training done. {len(log_history)} steps in {time.time() - t0:.0f}s. -> {CKPT}')
|
notebooks/train_grpo_kaggle.ipynb
ADDED
|
@@ -0,0 +1,955 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"# BoardSim × Qwen3-1.7B — GRPO LoRA fine-tune (Kaggle edition)\n",
|
| 8 |
+
"\n",
|
| 9 |
+
"Runs on Kaggle GPUs (T4 x2 or P100). Enable: **Settings → Accelerator: GPU**, **Internet: On**.\n",
|
| 10 |
+
"\n",
|
| 11 |
+
"Add Kaggle Secrets (Add-ons → Secrets):\n",
|
| 12 |
+
"- `HF_TOKEN` (required)\n",
|
| 13 |
+
"- `WANDB_API_KEY` (optional)\n",
|
| 14 |
+
"- `ENV_BASE_URL` (optional, defaults to public HF Space)\n",
|
| 15 |
+
"- `ADAPTER_REPO`, `MERGED_REPO` (optional)"
|
| 16 |
+
]
|
| 17 |
+
},
|
| 18 |
+
{
|
| 19 |
+
"cell_type": "markdown",
|
| 20 |
+
"metadata": {},
|
| 21 |
+
"source": ["## 1. Install deps (unsloth FIRST — patches torch/transformers at import)"]
|
| 22 |
+
},
|
| 23 |
+
{
|
| 24 |
+
"cell_type": "code",
|
| 25 |
+
"execution_count": null,
|
| 26 |
+
"metadata": {},
|
| 27 |
+
"outputs": [],
|
| 28 |
+
"source": [
|
| 29 |
+
"%pip install -q --no-deps unsloth\n",
|
| 30 |
+
"%pip install -q unsloth_zoo\n",
|
| 31 |
+
"%pip install -q \"openenv-core==0.2.3\" \"trl>=0.12,<2.0\" \"transformers>=4.45,<5.0\" \\\n",
|
| 32 |
+
" \"datasets>=3.0\" \"accelerate>=1.0\" \"huggingface_hub>=0.25\" \"pydantic>=2.0\" \\\n",
|
| 33 |
+
" wandb matplotlib python-dotenv bitsandbytes scipy scikit-learn sentence-transformers"
|
| 34 |
+
]
|
| 35 |
+
},
|
| 36 |
+
{
|
| 37 |
+
"cell_type": "markdown",
|
| 38 |
+
"metadata": {},
|
| 39 |
+
"source": ["## 2. Auth — Kaggle Secrets → env vars → HF / W&B login"]
|
| 40 |
+
},
|
| 41 |
+
{
|
| 42 |
+
"cell_type": "code",
|
| 43 |
+
"execution_count": null,
|
| 44 |
+
"metadata": {},
|
| 45 |
+
"outputs": [],
|
| 46 |
+
"source": [
|
| 47 |
+
"import os, pathlib\n",
|
| 48 |
+
"\n",
|
| 49 |
+
"IN_KAGGLE = os.path.isdir('/kaggle')\n",
|
| 50 |
+
"\n",
|
| 51 |
+
"# Kaggle Secrets first\n",
|
| 52 |
+
"if IN_KAGGLE:\n",
|
| 53 |
+
" try:\n",
|
| 54 |
+
" from kaggle_secrets import UserSecretsClient\n",
|
| 55 |
+
" usc = UserSecretsClient()\n",
|
| 56 |
+
" for k in ('HF_TOKEN', 'WANDB_API_KEY', 'ENV_BASE_URL', 'ADAPTER_REPO', 'MERGED_REPO'):\n",
|
| 57 |
+
" try:\n",
|
| 58 |
+
" v = usc.get_secret(k)\n",
|
| 59 |
+
" if v:\n",
|
| 60 |
+
" os.environ.setdefault(k, v)\n",
|
| 61 |
+
" except Exception:\n",
|
| 62 |
+
" pass\n",
|
| 63 |
+
" except Exception as e:\n",
|
| 64 |
+
" print(f'kaggle_secrets unavailable: {e}')\n",
|
| 65 |
+
"\n",
|
| 66 |
+
"# .env fallback\n",
|
| 67 |
+
"try:\n",
|
| 68 |
+
" from dotenv import load_dotenv\n",
|
| 69 |
+
" for p in [pathlib.Path('.env'), pathlib.Path('../.env'),\n",
|
| 70 |
+
" pathlib.Path('/kaggle/working/.env')]:\n",
|
| 71 |
+
" if p.exists():\n",
|
| 72 |
+
" load_dotenv(p, override=False)\n",
|
| 73 |
+
" print(f'Loaded env from {p.resolve()}')\n",
|
| 74 |
+
" break\n",
|
| 75 |
+
"except Exception:\n",
|
| 76 |
+
" pass\n",
|
| 77 |
+
"\n",
|
| 78 |
+
"if not os.environ.get('HF_TOKEN'):\n",
|
| 79 |
+
" os.environ['HF_TOKEN'] = input('HF token: ').strip()\n",
|
| 80 |
+
"if not os.environ.get('WANDB_API_KEY'):\n",
|
| 81 |
+
" os.environ['WANDB_API_KEY'] = input('WandB key (or blank to skip): ').strip()\n",
|
| 82 |
+
"\n",
|
| 83 |
+
"from huggingface_hub import login as hf_login\n",
|
| 84 |
+
"hf_login(token=os.environ['HF_TOKEN'], add_to_git_credential=False)\n",
|
| 85 |
+
"print('HF auth ok.')\n",
|
| 86 |
+
"if os.environ.get('WANDB_API_KEY'):\n",
|
| 87 |
+
" import wandb\n",
|
| 88 |
+
" wandb.login(key=os.environ['WANDB_API_KEY'])\n",
|
| 89 |
+
" print('W&B auth ok.')"
|
| 90 |
+
]
|
| 91 |
+
},
|
| 92 |
+
{
|
| 93 |
+
"cell_type": "markdown",
|
| 94 |
+
"metadata": {},
|
| 95 |
+
"source": ["## 3. Working dirs (Kaggle uses `/kaggle/working` — persists as notebook output)"]
|
| 96 |
+
},
|
| 97 |
+
{
|
| 98 |
+
"cell_type": "code",
|
| 99 |
+
"execution_count": null,
|
| 100 |
+
"metadata": {},
|
| 101 |
+
"outputs": [],
|
| 102 |
+
"source": [
|
| 103 |
+
"import pathlib\n",
|
| 104 |
+
"\n",
|
| 105 |
+
"if IN_KAGGLE:\n",
|
| 106 |
+
" WORK_DIR = pathlib.Path('/kaggle/working/BoardSim_Run')\n",
|
| 107 |
+
"else:\n",
|
| 108 |
+
" WORK_DIR = pathlib.Path('./BoardSim_Run')\n",
|
| 109 |
+
"WORK_DIR.mkdir(parents=True, exist_ok=True)\n",
|
| 110 |
+
"ASSETS = WORK_DIR / 'assets'; ASSETS.mkdir(exist_ok=True)\n",
|
| 111 |
+
"CKPT = WORK_DIR / 'lora_qwen3_1p7b'; CKPT.mkdir(exist_ok=True)\n",
|
| 112 |
+
"print('WORK_DIR =', WORK_DIR)"
|
| 113 |
+
]
|
| 114 |
+
},
|
| 115 |
+
{
|
| 116 |
+
"cell_type": "markdown",
|
| 117 |
+
"metadata": {},
|
| 118 |
+
"source": ["## 4. Clone repo + connect to BoardSim env"]
|
| 119 |
+
},
|
| 120 |
+
{
|
| 121 |
+
"cell_type": "code",
|
| 122 |
+
"execution_count": null,
|
| 123 |
+
"metadata": {},
|
| 124 |
+
"outputs": [],
|
| 125 |
+
"source": [
|
| 126 |
+
"import os, sys, subprocess, urllib.request, json as _json\n",
|
| 127 |
+
"\n",
|
| 128 |
+
"ENV_BASE_URL = os.environ.get('ENV_BASE_URL',\n",
|
| 129 |
+
" 'https://stavankhobare-sst-metaxpytorch-hackathon.hf.space')\n",
|
| 130 |
+
"REPO_URL = 'https://github.com/StavanRKhobare/SST-MetaxPyTorch-Hackathon'\n",
|
| 131 |
+
"\n",
|
| 132 |
+
"REPO_DIR = '/kaggle/working/repo' if IN_KAGGLE else os.path.abspath('./repo')\n",
|
| 133 |
+
"if not os.path.isdir(os.path.join(REPO_DIR, '.git')):\n",
|
| 134 |
+
" subprocess.run(['git', 'clone', '--depth', '1', REPO_URL, REPO_DIR], check=True)\n",
|
| 135 |
+
"else:\n",
|
| 136 |
+
" subprocess.run(['git', '-C', REPO_DIR, 'pull', '--ff-only'], check=False)\n",
|
| 137 |
+
"\n",
|
| 138 |
+
"ENVS_DIR = os.path.join(REPO_DIR, 'envs')\n",
|
| 139 |
+
"if ENVS_DIR not in sys.path:\n",
|
| 140 |
+
" sys.path.insert(0, ENVS_DIR)\n",
|
| 141 |
+
"\n",
|
| 142 |
+
"for mod in [m for m in list(sys.modules) if m == 'board_sim_env' or m.startswith('board_sim_env.')]:\n",
|
| 143 |
+
" del sys.modules[mod]\n",
|
| 144 |
+
"\n",
|
| 145 |
+
"from board_sim_env.client import BoardSimEnv\n",
|
| 146 |
+
"from board_sim_env.models import BoardSimAction, BoardSimObservation\n",
|
| 147 |
+
"\n",
|
| 148 |
+
"try:\n",
|
| 149 |
+
" with urllib.request.urlopen(f'{ENV_BASE_URL.rstrip(\"/\")}/health', timeout=20) as r:\n",
|
| 150 |
+
" h = _json.loads(r.read())\n",
|
| 151 |
+
" print('health:', h)\n",
|
| 152 |
+
"except Exception as e:\n",
|
| 153 |
+
" print(f'WARN: could not reach {ENV_BASE_URL}/health ({e})')\n",
|
| 154 |
+
"\n",
|
| 155 |
+
"def make_env():\n",
|
| 156 |
+
" return BoardSimEnv(base_url=ENV_BASE_URL)\n",
|
| 157 |
+
"\n",
|
| 158 |
+
"print('BoardSimEnv ready.')"
|
| 159 |
+
]
|
| 160 |
+
},
|
| 161 |
+
{
|
| 162 |
+
"cell_type": "markdown",
|
| 163 |
+
"metadata": {},
|
| 164 |
+
"source": ["## 5. Load Qwen3-1.7B in 4-bit via Unsloth"]
|
| 165 |
+
},
|
| 166 |
+
{
|
| 167 |
+
"cell_type": "code",
|
| 168 |
+
"execution_count": null,
|
| 169 |
+
"metadata": {},
|
| 170 |
+
"outputs": [],
|
| 171 |
+
"source": [
|
| 172 |
+
"import unsloth # noqa: F401\n",
|
| 173 |
+
"from unsloth import FastLanguageModel\n",
|
| 174 |
+
"import torch, re\n",
|
| 175 |
+
"\n",
|
| 176 |
+
"MODEL_NAME = 'Qwen/Qwen3-1.7B'\n",
|
| 177 |
+
"MAX_SEQ_LEN = 2048\n",
|
| 178 |
+
"\n",
|
| 179 |
+
"model, tokenizer = FastLanguageModel.from_pretrained(\n",
|
| 180 |
+
" model_name=MODEL_NAME,\n",
|
| 181 |
+
" max_seq_length=MAX_SEQ_LEN,\n",
|
| 182 |
+
" load_in_4bit=True,\n",
|
| 183 |
+
" dtype=None,\n",
|
| 184 |
+
")\n",
|
| 185 |
+
"if tokenizer.pad_token is None:\n",
|
| 186 |
+
" tokenizer.pad_token = tokenizer.eos_token\n",
|
| 187 |
+
"\n",
|
| 188 |
+
"device = next(model.parameters()).device\n",
|
| 189 |
+
"print(f'Loaded {MODEL_NAME} on {device}.')\n",
|
| 190 |
+
"if torch.cuda.is_available():\n",
|
| 191 |
+
" total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9\n",
|
| 192 |
+
" mem_gb = torch.cuda.memory_allocated() / 1e9\n",
|
| 193 |
+
" print(f'GPU memory after base load: {mem_gb:.2f} GB / {total_gb:.2f} GB')\n",
|
| 194 |
+
" print(f'Headroom for compute: {total_gb - mem_gb:.2f} GB')"
|
| 195 |
+
]
|
| 196 |
+
},
|
| 197 |
+
{
|
| 198 |
+
"cell_type": "markdown",
|
| 199 |
+
"metadata": {},
|
| 200 |
+
"source": ["## 6. Prompt + parser + greedy action helper"]
|
| 201 |
+
},
|
| 202 |
+
{
|
| 203 |
+
"cell_type": "code",
|
| 204 |
+
"execution_count": null,
|
| 205 |
+
"metadata": {},
|
| 206 |
+
"outputs": [],
|
| 207 |
+
"source": [
|
| 208 |
+
"SYSTEM_PROMPT = \"\"\"You are the CEO of a mid-stage organization. Your board has 4 members with HIDDEN AGENDAS you cannot see directly:\n",
|
| 209 |
+
" - CTO: cares about operational excellence, engineering quality, team morale, and product readiness.\n",
|
| 210 |
+
" - CFO: cares about cash discipline, runway, and regulatory safety.\n",
|
| 211 |
+
" - Investor Rep: pushes growth, market share, and bold returns.\n",
|
| 212 |
+
" - Independent: cares about reputation, governance, and long-term consensus.\n",
|
| 213 |
+
"\n",
|
| 214 |
+
"Each round you see a strategic event, every NPC's pre-vote statement, and 3 options.\n",
|
| 215 |
+
"Your decision is resolved by WEIGHTED VOTE (your weight 2.5x). A short COALITION PITCH\n",
|
| 216 |
+
"that is semantically aligned with opposing members' priorities can swing them toward your pick —\n",
|
| 217 |
+
"write substantive arguments, not just buzzwords.\n",
|
| 218 |
+
"\n",
|
| 219 |
+
"Respond in EXACTLY this format on two lines:\n",
|
| 220 |
+
"DECISION: <one of the option strings>\n",
|
| 221 |
+
"PITCH: <one or two sentences arguing for it, addressing the concerns of opposing members>\"\"\"\n",
|
| 222 |
+
"\n",
|
| 223 |
+
"DECISION_RE = re.compile(r'DECISION\\s*:\\s*([A-Za-z0-9_\\- ]+)', re.IGNORECASE)\n",
|
| 224 |
+
"PITCH_RE = re.compile(r'PITCH\\s*:\\s*(.+)', re.IGNORECASE)\n",
|
| 225 |
+
"\n",
|
| 226 |
+
"def build_prompt(obs):\n",
|
| 227 |
+
" statements = '\\n'.join(\n",
|
| 228 |
+
" f\" {s['role']} ({s['confidence']:.2f}): votes {s['vote']} - {s['statement']}\"\n",
|
| 229 |
+
" for s in obs.npc_statements\n",
|
| 230 |
+
" )\n",
|
| 231 |
+
" return (\n",
|
| 232 |
+
" f\"{SYSTEM_PROMPT}\\n\\n\"\n",
|
| 233 |
+
" f\"State: revenue=${obs.state['revenue']:.0f}/yr burn=${obs.state['burn_rate']:.0f}/mo \"\n",
|
| 234 |
+
" f\"runway={obs.state['runway_months']:.1f}mo morale={obs.state['team_morale']:.2f} \"\n",
|
| 235 |
+
" f\"investors={obs.state['investor_confidence']:.2f} reg_risk={obs.state['regulatory_risk']:.2f}\\n\"\n",
|
| 236 |
+
" f\"Event: {obs.event}\\nBoard:\\n{statements}\\n\"\n",
|
| 237 |
+
" f\"Options: {obs.options}\\n\"\n",
|
| 238 |
+
" )\n",
|
| 239 |
+
"\n",
|
| 240 |
+
"def parse_completion(completion: str, options):\n",
|
| 241 |
+
" decision = options[0]\n",
|
| 242 |
+
" decision_ok = False\n",
|
| 243 |
+
" dm = DECISION_RE.search(completion)\n",
|
| 244 |
+
" if dm:\n",
|
| 245 |
+
" cand = dm.group(1).strip().lower()\n",
|
| 246 |
+
" for opt in options:\n",
|
| 247 |
+
" if opt.lower() == cand or opt.lower() in cand:\n",
|
| 248 |
+
" decision = opt; decision_ok = True; break\n",
|
| 249 |
+
" if not decision_ok:\n",
|
| 250 |
+
" for opt in options:\n",
|
| 251 |
+
" if opt.lower() in completion.lower():\n",
|
| 252 |
+
" decision = opt; break\n",
|
| 253 |
+
" pm = PITCH_RE.search(completion)\n",
|
| 254 |
+
" pitch = pm.group(1).strip()[:400] if pm else ''\n",
|
| 255 |
+
" format_ok = bool(dm) and bool(pm)\n",
|
| 256 |
+
" return decision, pitch, format_ok\n",
|
| 257 |
+
"\n",
|
| 258 |
+
"MAX_NEW_TOKENS = 80\n",
|
| 259 |
+
"\n",
|
| 260 |
+
"def greedy_action(obs):\n",
|
| 261 |
+
" prompt = build_prompt(obs)\n",
|
| 262 |
+
" enc = tokenizer(prompt, return_tensors='pt', truncation=True, max_length=1024).to(device)\n",
|
| 263 |
+
" with torch.no_grad():\n",
|
| 264 |
+
" out = model.generate(\n",
|
| 265 |
+
" **enc, max_new_tokens=MAX_NEW_TOKENS,\n",
|
| 266 |
+
" do_sample=False, pad_token_id=tokenizer.eos_token_id,\n",
|
| 267 |
+
" )\n",
|
| 268 |
+
" completion = tokenizer.decode(out[0][enc.input_ids.shape[1]:], skip_special_tokens=True)\n",
|
| 269 |
+
" return parse_completion(completion, obs.options)"
|
| 270 |
+
]
|
| 271 |
+
},
|
| 272 |
+
{
|
| 273 |
+
"cell_type": "markdown",
|
| 274 |
+
"metadata": {},
|
| 275 |
+
"source": ["## 7. Episode runner"]
|
| 276 |
+
},
|
| 277 |
+
{
|
| 278 |
+
"cell_type": "code",
|
| 279 |
+
"execution_count": null,
|
| 280 |
+
"metadata": {},
|
| 281 |
+
"outputs": [],
|
| 282 |
+
"source": [
|
| 283 |
+
"import random, statistics, json\n",
|
| 284 |
+
"\n",
|
| 285 |
+
"MAX_STEPS_PER_EP = 20\n",
|
| 286 |
+
"\n",
|
| 287 |
+
"def run_episode(env, seed):\n",
|
| 288 |
+
" result = env.reset(seed=seed)\n",
|
| 289 |
+
" obs = result.observation\n",
|
| 290 |
+
" ep_r, n, fmt_hits, pitch_hits = 0.0, 0, 0, 0\n",
|
| 291 |
+
" while not result.done and n < MAX_STEPS_PER_EP:\n",
|
| 292 |
+
" decision, pitch, fmt_ok = greedy_action(obs)\n",
|
| 293 |
+
" if fmt_ok: fmt_hits += 1\n",
|
| 294 |
+
" if pitch.strip(): pitch_hits += 1\n",
|
| 295 |
+
" result = env.step(BoardSimAction(decision=decision, coalition_pitch=pitch))\n",
|
| 296 |
+
" obs = result.observation\n",
|
| 297 |
+
" ep_r += float(result.reward or 0.0)\n",
|
| 298 |
+
" n += 1\n",
|
| 299 |
+
" return {\n",
|
| 300 |
+
" 'final_profit': obs.state['profitability_score'],\n",
|
| 301 |
+
" 'ep_reward': ep_r, 'steps': n,\n",
|
| 302 |
+
" 'format_rate': fmt_hits / max(1, n), 'pitch_rate': pitch_hits / max(1, n),\n",
|
| 303 |
+
" 'history': obs.state.get('history', []),\n",
|
| 304 |
+
" }"
|
| 305 |
+
]
|
| 306 |
+
},
|
| 307 |
+
{
|
| 308 |
+
"cell_type": "markdown",
|
| 309 |
+
"metadata": {},
|
| 310 |
+
"source": [
|
| 311 |
+
"## 8. Baseline — base Qwen3-1.7B (no fine-tune)\n",
|
| 312 |
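+
"Runs the un-tuned base model through the same prompt, parser, and greedy episode runner used by the later evaluations, giving an apples-to-apples reference for measuring fine-tuning lift."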
|
| 313 |
+
]
|
| 314 |
+
},
|
| 315 |
+
{
|
| 316 |
+
"cell_type": "code",
|
| 317 |
+
"execution_count": null,
|
| 318 |
+
"metadata": {},
|
| 319 |
+
"outputs": [],
|
| 320 |
+
"source": [
|
| 321 |
+
"BASELINE_SEEDS = list(range(50_000, 50_000 + 100))\n",
|
| 322 |
+
"\n",
|
| 323 |
+
"base_finals, base_rewards, base_fmts, base_pitches = [], [], [], []\n",
|
| 324 |
+
"with make_env().sync() as env:\n",
|
| 325 |
+
" for i, s in enumerate(BASELINE_SEEDS):\n",
|
| 326 |
+
" r = run_episode(env, s)\n",
|
| 327 |
+
" base_finals.append(r['final_profit'])\n",
|
| 328 |
+
" base_rewards.append(r['ep_reward'])\n",
|
| 329 |
+
" base_fmts.append(r['format_rate'])\n",
|
| 330 |
+
" base_pitches.append(r['pitch_rate'])\n",
|
| 331 |
+
" if (i + 1) % 10 == 0:\n",
|
| 332 |
+
" print(f' base Qwen3-1.7B {i+1}/{len(BASELINE_SEEDS)} profit={r[\"final_profit\"]:.1f}')\n",
|
| 333 |
+
"\n",
|
| 334 |
+
"BASELINE_MEAN_PROFIT = statistics.mean(base_finals)\n",
|
| 335 |
+
"BASELINE_MEAN_REWARD = statistics.mean(base_rewards)\n",
|
| 336 |
+
"print(f'Base Qwen3-1.7B profit : {BASELINE_MEAN_PROFIT:.2f} \\u00b1 {statistics.stdev(base_finals):.2f}')\n",
|
| 337 |
+
"print(f'Base Qwen3-1.7B ep rwd : {BASELINE_MEAN_REWARD:.2f} \\u00b1 {statistics.stdev(base_rewards):.2f}')\n",
|
| 338 |
+
"print(f'Base format rate : {statistics.mean(base_fmts):.0%} pitch rate: {statistics.mean(base_pitches):.0%}')\n",
|
| 339 |
+
"\n",
|
| 340 |
+
"with open(WORK_DIR / 'baseline.json', 'w') as f:\n",
|
| 341 |
+
" json.dump({'model': MODEL_NAME, 'mode': 'base_no_finetune',\n",
|
| 342 |
+
" 'seeds': BASELINE_SEEDS,\n",
|
| 343 |
+
" 'finals': base_finals, 'rewards': base_rewards,\n",
|
| 344 |
+
" 'format_rates': base_fmts, 'pitch_rates': base_pitches}, f)"
|
| 345 |
+
]
|
| 346 |
+
},
|
| 347 |
+
{
|
| 348 |
+
"cell_type": "markdown",
|
| 349 |
+
"metadata": {},
|
| 350 |
+
"source": ["## 9. Wrap base with LoRA adapters"]
|
| 351 |
+
},
|
| 352 |
+
{
|
| 353 |
+
"cell_type": "code",
|
| 354 |
+
"execution_count": null,
|
| 355 |
+
"metadata": {},
|
| 356 |
+
"outputs": [],
|
| 357 |
+
"source": [
|
| 358 |
+
"model = FastLanguageModel.get_peft_model(\n",
|
| 359 |
+
" model,\n",
|
| 360 |
+
" r=32,\n",
|
| 361 |
+
" target_modules=['q_proj','k_proj','v_proj','o_proj','gate_proj','up_proj','down_proj'],\n",
|
| 362 |
+
" lora_alpha=64,\n",
|
| 363 |
+
" lora_dropout=0.0, bias='none',\n",
|
| 364 |
+
" use_gradient_checkpointing='unsloth',\n",
|
| 365 |
+
" random_state=3407,\n",
|
| 366 |
+
")\n",
|
| 367 |
+
"\n",
|
| 368 |
+
"trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
|
| 369 |
+
"total = sum(p.numel() for p in model.parameters())\n",
|
| 370 |
+
"print(f'Trainable params: {trainable:,} / {total:,} ({100*trainable/total:.2f}%)')\n",
|
| 371 |
+
"\n",
|
| 372 |
+
"EVAL_SEEDS = list(range(60_000, 60_000 + 10))\n",
|
| 373 |
+
"\n",
|
| 374 |
+
"def periodic_eval(env):\n",
|
| 375 |
+
" profits, rewards, fmts, pitches = [], [], [], []\n",
|
| 376 |
+
" for s in EVAL_SEEDS:\n",
|
| 377 |
+
" r = run_episode(env, s)\n",
|
| 378 |
+
" profits.append(r['final_profit']); rewards.append(r['ep_reward'])\n",
|
| 379 |
+
" fmts.append(r['format_rate']); pitches.append(r['pitch_rate'])\n",
|
| 380 |
+
" import numpy as np\n",
|
| 381 |
+
" return {'profit_mean': float(np.mean(profits)),\n",
|
| 382 |
+
" 'reward_mean': float(np.mean(rewards)),\n",
|
| 383 |
+
" 'format_rate': float(np.mean(fmts)),\n",
|
| 384 |
+
" 'pitch_rate': float(np.mean(pitches))}"
|
| 385 |
+
]
|
| 386 |
+
},
|
| 387 |
+
{
|
| 388 |
+
"cell_type": "markdown",
|
| 389 |
+
"metadata": {},
|
| 390 |
+
"source": ["## 10. GRPO training loop"]
|
| 391 |
+
},
|
| 392 |
+
{
|
| 393 |
+
"cell_type": "code",
|
| 394 |
+
"execution_count": null,
|
| 395 |
+
"metadata": {},
|
| 396 |
+
"outputs": [],
|
| 397 |
+
"source": [
|
| 398 |
+
"import os, json, math, time, collections\n",
|
| 399 |
+
"from torch.optim import AdamW\n",
|
| 400 |
+
"\n",
|
| 401 |
+
"NUM_STEPS = int(os.environ.get('NUM_STEPS', 200))\n",
|
| 402 |
+
"GROUP_SIZE = int(os.environ.get('GROUP_SIZE', 4))\n",
|
| 403 |
+
"LR = 5e-6\n",
|
| 404 |
+
"GRAD_CLIP = 1.0\n",
|
| 405 |
+
"TEMPERATURE, TOP_P = 1.0, 0.95\n",
|
| 406 |
+
"SAVE_EVERY = 25\n",
|
| 407 |
+
"EVAL_AT = {0, 25, 50, 100, 150, NUM_STEPS - 1}\n",
|
| 408 |
+
"\n",
|
| 409 |
+
"WANDB_OK = False\n",
|
| 410 |
+
"if os.environ.get('WANDB_API_KEY'):\n",
|
| 411 |
+
" try:\n",
|
| 412 |
+
" import wandb\n",
|
| 413 |
+
" wandb.init(project='boardsim-qwen3-grpo', name='boardsim-qwen3-1p7b-kaggle',\n",
|
| 414 |
+
" config={'num_steps': NUM_STEPS, 'group_size': GROUP_SIZE, 'lr': LR,\n",
|
| 415 |
+
" 'temperature': TEMPERATURE, 'top_p': TOP_P, 'model': MODEL_NAME},\n",
|
| 416 |
+
" finish_previous=True)\n",
|
| 417 |
+
" WANDB_OK = True\n",
|
| 418 |
+
" except TypeError:\n",
|
| 419 |
+
" wandb.init(project='boardsim-qwen3-grpo', name='boardsim-qwen3-1p7b-kaggle',\n",
|
| 420 |
+
" config={'num_steps': NUM_STEPS, 'group_size': GROUP_SIZE, 'lr': LR,\n",
|
| 421 |
+
" 'temperature': TEMPERATURE, 'top_p': TOP_P, 'model': MODEL_NAME},\n",
|
| 422 |
+
" reinit=True)\n",
|
| 423 |
+
" WANDB_OK = True\n",
|
| 424 |
+
" except Exception as e:\n",
|
| 425 |
+
" print(f'WARN: wandb.init failed: {e}')\n",
|
| 426 |
+
"\n",
|
| 427 |
+
"optimizer = AdamW([p for p in model.parameters() if p.requires_grad],\n",
|
| 428 |
+
" lr=LR, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0)\n",
|
| 429 |
+
"\n",
|
| 430 |
+
"log_history = []\n",
|
| 431 |
+
"eval_history = []\n",
|
| 432 |
+
"decision_counter = collections.Counter()\n",
|
| 433 |
+
"t0 = time.time()\n",
|
| 434 |
+
"\n",
|
| 435 |
+
"with make_env().sync() as env_train, make_env().sync() as env_score, make_env().sync() as env_eval:\n",
|
| 436 |
+
" for step in range(NUM_STEPS):\n",
|
| 437 |
+
" result = env_train.reset(seed=step)\n",
|
| 438 |
+
" obs = result.observation\n",
|
| 439 |
+
" prompt = build_prompt(obs)\n",
|
| 440 |
+
" enc = tokenizer(prompt, return_tensors='pt', truncation=True, max_length=1024).to(device)\n",
|
| 441 |
+
" prompt_len = enc.input_ids.shape[1]\n",
|
| 442 |
+
"\n",
|
| 443 |
+
" with torch.no_grad():\n",
|
| 444 |
+
" gen_out = model.generate(\n",
|
| 445 |
+
" input_ids=enc.input_ids, attention_mask=enc.attention_mask,\n",
|
| 446 |
+
" max_new_tokens=MAX_NEW_TOKENS, do_sample=True,\n",
|
| 447 |
+
" temperature=TEMPERATURE, top_p=TOP_P,\n",
|
| 448 |
+
" num_return_sequences=GROUP_SIZE,\n",
|
| 449 |
+
"                pad_token_id=tokenizer.pad_token_id,\n",
|
| 450 |
+
" )\n",
|
| 451 |
+
" gen_out = gen_out.detach().clone()\n",
|
| 452 |
+
"\n",
|
| 453 |
+
" decisions, pitches, rewards, fmt_oks = [], [], [], []\n",
|
| 454 |
+
" for g in range(GROUP_SIZE):\n",
|
| 455 |
+
" comp = tokenizer.decode(gen_out[g][prompt_len:], skip_special_tokens=True)\n",
|
| 456 |
+
" d, pp, ok = parse_completion(comp, obs.options)\n",
|
| 457 |
+
" decisions.append(d); pitches.append(pp); fmt_oks.append(ok)\n",
|
| 458 |
+
" decision_counter[d] += 1\n",
|
| 459 |
+
" env_score.reset(seed=step)\n",
|
| 460 |
+
" sr = env_score.step(BoardSimAction(decision=d, coalition_pitch=pp))\n",
|
| 461 |
+
" rewards.append(float(sr.reward or 0.0))\n",
|
| 462 |
+
"\n",
|
| 463 |
+
" rewards_t = torch.tensor(rewards, dtype=torch.float32, device=device)\n",
|
| 464 |
+
" if rewards_t.numel() > 1 and rewards_t.std().item() > 1e-6:\n",
|
| 465 |
+
" advantages = (rewards_t - rewards_t.mean()) / (rewards_t.std() + 1e-8)\n",
|
| 466 |
+
" else:\n",
|
| 467 |
+
" advantages = rewards_t - rewards_t.mean()\n",
|
| 468 |
+
"\n",
|
| 469 |
+
" optimizer.zero_grad()\n",
|
| 470 |
+
" full_ids = gen_out\n",
|
| 471 |
+
" attn = (full_ids != tokenizer.pad_token_id).long()\n",
|
| 472 |
+
" loss_mask = attn.clone()\n",
|
| 473 |
+
" loss_mask[:, :prompt_len] = 0\n",
|
| 474 |
+
" out = model(input_ids=full_ids, attention_mask=attn)\n",
|
| 475 |
+
" logits = out.logits[:, :-1, :].float()\n",
|
| 476 |
+
" targets = full_ids[:, 1:]\n",
|
| 477 |
+
" mask = loss_mask[:, 1:].float()\n",
|
| 478 |
+
" log_probs = torch.nn.functional.log_softmax(logits, dim=-1)\n",
|
| 479 |
+
" token_nll = -log_probs.gather(2, targets.unsqueeze(-1)).squeeze(-1)\n",
|
| 480 |
+
" per_seq_nll = (token_nll * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)\n",
|
| 481 |
+
" loss = (advantages.detach() * per_seq_nll).mean()\n",
|
| 482 |
+
" loss.backward()\n",
|
| 483 |
+
" total_loss_val = float(loss.detach().item())\n",
|
| 484 |
+
" torch.nn.utils.clip_grad_norm_(\n",
|
| 485 |
+
" [p for p in model.parameters() if p.requires_grad], GRAD_CLIP)\n",
|
| 486 |
+
" optimizer.step()\n",
|
| 487 |
+
"\n",
|
| 488 |
+
" rec = {\n",
|
| 489 |
+
" 'step': step,\n",
|
| 490 |
+
" 'reward': float(rewards_t.mean().item()),\n",
|
| 491 |
+
" 'reward_std': float(rewards_t.std().item()) if rewards_t.numel() > 1 else 0.0,\n",
|
| 492 |
+
" 'reward_max': float(rewards_t.max().item()),\n",
|
| 493 |
+
" 'loss': total_loss_val,\n",
|
| 494 |
+
" 'format_rate': sum(fmt_oks) / GROUP_SIZE,\n",
|
| 495 |
+
" 'pitch_rate': sum(1 for p in pitches if p.strip()) / GROUP_SIZE,\n",
|
| 496 |
+
" 'elapsed_s': time.time() - t0,\n",
|
| 497 |
+
" }\n",
|
| 498 |
+
" log_history.append(rec)\n",
|
| 499 |
+
" if WANDB_OK:\n",
|
| 500 |
+
" wandb.log(rec, step=step)\n",
|
| 501 |
+
"\n",
|
| 502 |
+
" if step % 5 == 0:\n",
|
| 503 |
+
" print(f\"step={step:4d} reward={rec['reward']:+.3f} (\\u00b1{rec['reward_std']:.2f}) \"\n",
|
| 504 |
+
" f\"loss={rec['loss']:+.4f} fmt={rec['format_rate']:.0%} \"\n",
|
| 505 |
+
" f\"elapsed={rec['elapsed_s']:.0f}s d0={decisions[0]}\")\n",
|
| 506 |
+
"\n",
|
| 507 |
+
" if step in EVAL_AT:\n",
|
| 508 |
+
" ev = periodic_eval(env_eval)\n",
|
| 509 |
+
" ev['step'] = step\n",
|
| 510 |
+
" eval_history.append(ev)\n",
|
| 511 |
+
" print(f\" [eval@{step}] profit={ev['profit_mean']:.2f} \"\n",
|
| 512 |
+
" f\"reward={ev['reward_mean']:.2f} fmt={ev['format_rate']:.0%}\")\n",
|
| 513 |
+
" if WANDB_OK:\n",
|
| 514 |
+
" wandb.log({f'eval/{k}': v for k, v in ev.items() if k != 'step'}, step=step)\n",
|
| 515 |
+
"\n",
|
| 516 |
+
" if step > 0 and step % SAVE_EVERY == 0:\n",
|
| 517 |
+
" model.save_pretrained(str(CKPT))\n",
|
| 518 |
+
" tokenizer.save_pretrained(str(CKPT))\n",
|
| 519 |
+
" with open(WORK_DIR / 'log_history.json', 'w') as f:\n",
|
| 520 |
+
" json.dump(log_history, f)\n",
|
| 521 |
+
" with open(WORK_DIR / 'eval_history.json', 'w') as f:\n",
|
| 522 |
+
" json.dump(eval_history, f)\n",
|
| 523 |
+
"\n",
|
| 524 |
+
"model.save_pretrained(str(CKPT))\n",
|
| 525 |
+
"tokenizer.save_pretrained(str(CKPT))\n",
|
| 526 |
+
"with open(WORK_DIR / 'log_history.json', 'w') as f:\n",
|
| 527 |
+
" json.dump(log_history, f)\n",
|
| 528 |
+
"with open(WORK_DIR / 'eval_history.json', 'w') as f:\n",
|
| 529 |
+
" json.dump(eval_history, f)\n",
|
| 530 |
+
"with open(WORK_DIR / 'decision_counter.json', 'w') as f:\n",
|
| 531 |
+
" json.dump(dict(decision_counter), f)\n",
|
| 532 |
+
"if WANDB_OK:\n",
|
| 533 |
+
" wandb.finish()\n",
|
| 534 |
+
"print(f'Training done. {len(log_history)} steps in {time.time() - t0:.0f}s. -> {CKPT}')"
|
| 535 |
+
]
|
| 536 |
+
},
|
| 537 |
+
{
|
| 538 |
+
"cell_type": "markdown",
|
| 539 |
+
"metadata": {},
|
| 540 |
+
"source": ["## 11. Plots — reward / loss / format / periodic eval"]
|
| 541 |
+
},
|
| 542 |
+
{
|
| 543 |
+
"cell_type": "code",
|
| 544 |
+
"execution_count": null,
|
| 545 |
+
"metadata": {},
|
| 546 |
+
"outputs": [],
|
| 547 |
+
"source": [
|
| 548 |
+
"import numpy as np, matplotlib\n",
|
| 549 |
+
"matplotlib.use('Agg')\n",
|
| 550 |
+
"import matplotlib.pyplot as plt\n",
|
| 551 |
+
"from scipy import stats as spstats\n",
|
| 552 |
+
"\n",
|
| 553 |
+
"steps = np.array([e['step'] for e in log_history])\n",
|
| 554 |
+
"rewards = np.array([e['reward'] for e in log_history])\n",
|
| 555 |
+
"losses = np.array([e['loss'] for e in log_history])\n",
|
| 556 |
+
"fmts = np.array([e['format_rate'] for e in log_history])\n",
|
| 557 |
+
"pitches = np.array([e['pitch_rate'] for e in log_history])\n",
|
| 558 |
+
"\n",
|
| 559 |
+
"def ema(xs, alpha=0.1):\n",
|
| 560 |
+
" out, s = [], xs[0] if len(xs) else 0.0\n",
|
| 561 |
+
" for x in xs:\n",
|
| 562 |
+
" s = alpha * x + (1 - alpha) * s\n",
|
| 563 |
+
" out.append(s)\n",
|
| 564 |
+
" return np.array(out)\n",
|
| 565 |
+
"\n",
|
| 566 |
+
"rewards_ema = ema(rewards, 0.1)\n",
|
| 567 |
+
"slope, intercept, r_val, p_val, _ = spstats.linregress(steps, rewards)\n",
|
| 568 |
+
"\n",
|
| 569 |
+
"plt.figure(figsize=(9, 5))\n",
|
| 570 |
+
"plt.plot(steps, rewards, alpha=0.3, lw=1, label='per-step group reward')\n",
|
| 571 |
+
"plt.plot(steps, rewards_ema, lw=2.2, label='EMA (\\u03b1=0.1)')\n",
|
| 572 |
+
"plt.plot(steps, intercept + slope * steps, '--', lw=1.5,\n",
|
| 573 |
+
" label=f'linear fit slope={slope:+.4f}/step (p={p_val:.1e})')\n",
|
| 574 |
+
"plt.axhline(BASELINE_MEAN_REWARD, ls=':', lw=2, color='#c44',\n",
|
| 575 |
+
" label=f'base Qwen3-1.7B baseline = {BASELINE_MEAN_REWARD:.2f}')\n",
|
| 576 |
+
"plt.title('GRPO reward — BoardSim (vs same model w/o fine-tuning)')\n",
|
| 577 |
+
"plt.xlabel('step'); plt.ylabel('mean group reward')\n",
|
| 578 |
+
"plt.legend(); plt.grid(alpha=0.3); plt.tight_layout()\n",
|
| 579 |
+
"plt.savefig(ASSETS / 'reward_curve.png', dpi=150); plt.close()\n",
|
| 580 |
+
"\n",
|
| 581 |
+
"plt.figure(figsize=(9, 5))\n",
|
| 582 |
+
"plt.plot(steps, losses, lw=1.5)\n",
|
| 583 |
+
"plt.title('GRPO loss (advantage \\u00d7 NLL)'); plt.xlabel('step'); plt.ylabel('loss')\n",
|
| 584 |
+
"plt.grid(alpha=0.3); plt.tight_layout()\n",
|
| 585 |
+
"plt.savefig(ASSETS / 'loss_curve.png', dpi=150); plt.close()\n",
|
| 586 |
+
"\n",
|
| 587 |
+
"plt.figure(figsize=(9, 5))\n",
|
| 588 |
+
"plt.plot(steps, ema(fmts, 0.05), lw=2, label='format-OK rate (EMA)')\n",
|
| 589 |
+
"plt.plot(steps, ema(pitches, 0.05), lw=2, label='non-empty pitch rate (EMA)')\n",
|
| 590 |
+
"plt.title('Format compliance + pitch usage during training')\n",
|
| 591 |
+
"plt.xlabel('step'); plt.ylabel('rate'); plt.ylim(-0.05, 1.05)\n",
|
| 592 |
+
"plt.legend(); plt.grid(alpha=0.3); plt.tight_layout()\n",
|
| 593 |
+
"plt.savefig(ASSETS / 'format_compliance.png', dpi=150); plt.close()\n",
|
| 594 |
+
"\n",
|
| 595 |
+
"if eval_history:\n",
|
| 596 |
+
" es = [e['step'] for e in eval_history]\n",
|
| 597 |
+
" epm = [e['profit_mean'] for e in eval_history]\n",
|
| 598 |
+
" erm = [e['reward_mean'] for e in eval_history]\n",
|
| 599 |
+
" plt.figure(figsize=(9, 5))\n",
|
| 600 |
+
" plt.plot(es, epm, '-o', lw=2, label='held-out profitability (mean of 10 episodes)')\n",
|
| 601 |
+
" plt.plot(es, erm, '-s', lw=2, label='held-out episode reward')\n",
|
| 602 |
+
" plt.axhline(BASELINE_MEAN_PROFIT, ls=':', lw=1.5, color='#c44',\n",
|
| 603 |
+
" label=f'base Qwen3-1.7B profitability = {BASELINE_MEAN_PROFIT:.2f}')\n",
|
| 604 |
+
" plt.title('Periodic held-out eval during training (greedy)')\n",
|
| 605 |
+
" plt.xlabel('training step'); plt.ylabel('value')\n",
|
| 606 |
+
" plt.legend(); plt.grid(alpha=0.3); plt.tight_layout()\n",
|
| 607 |
+
" plt.savefig(ASSETS / 'periodic_eval.png', dpi=150); plt.close()\n",
|
| 608 |
+
"\n",
|
| 609 |
+
"print(f'Linear-fit slope on reward: {slope:+.5f}/step (p={p_val:.2e}, R\\u00b2={r_val**2:.3f})')\n",
|
| 610 |
+
"print('Saved reward_curve.png, loss_curve.png, format_compliance.png, periodic_eval.png')"
|
| 611 |
+
]
|
| 612 |
+
},
|
| 613 |
+
{
|
| 614 |
+
"cell_type": "markdown",
|
| 615 |
+
"metadata": {},
|
| 616 |
+
"source": ["## 12. Paired same-seed eval — fine-tuned vs base Qwen3-1.7B"]
|
| 617 |
+
},
|
| 618 |
+
{
|
| 619 |
+
"cell_type": "code",
|
| 620 |
+
"execution_count": null,
|
| 621 |
+
"metadata": {},
|
| 622 |
+
"outputs": [],
|
| 623 |
+
"source": [
|
| 624 |
+
"from unsloth import FastLanguageModel\n",
|
| 625 |
+
"FastLanguageModel.for_inference(model)\n",
|
| 626 |
+
"\n",
|
| 627 |
+
"EVAL_N = 50\n",
|
| 628 |
+
"PAIRED_SEEDS = list(range(70_000, 70_000 + EVAL_N))\n",
|
| 629 |
+
"\n",
|
| 630 |
+
"trained_finals, trained_rewards, trained_fmt, trained_pitch = [], [], [], []\n",
|
| 631 |
+
"trained_history_per_seed = []\n",
|
| 632 |
+
"with make_env().sync() as env:\n",
|
| 633 |
+
" for i, s in enumerate(PAIRED_SEEDS):\n",
|
| 634 |
+
" r = run_episode(env, s)\n",
|
| 635 |
+
" trained_finals.append(r['final_profit'])\n",
|
| 636 |
+
" trained_rewards.append(r['ep_reward'])\n",
|
| 637 |
+
" trained_fmt.append(r['format_rate'])\n",
|
| 638 |
+
" trained_pitch.append(r['pitch_rate'])\n",
|
| 639 |
+
" trained_history_per_seed.append(r['history'])\n",
|
| 640 |
+
" if (i + 1) % 10 == 0:\n",
|
| 641 |
+
" print(f' trained {i+1}/{EVAL_N} profit={r[\"final_profit\"]:.1f}')\n",
|
| 642 |
+
"\n",
|
| 643 |
+
"base_finals_paired, base_rewards_paired, base_fmt_paired, base_pitch_paired = [], [], [], []\n",
|
| 644 |
+
"base_history_per_seed = []\n",
|
| 645 |
+
"with make_env().sync() as env, model.disable_adapter():\n",
|
| 646 |
+
" for i, s in enumerate(PAIRED_SEEDS):\n",
|
| 647 |
+
" r = run_episode(env, s)\n",
|
| 648 |
+
" base_finals_paired.append(r['final_profit'])\n",
|
| 649 |
+
" base_rewards_paired.append(r['ep_reward'])\n",
|
| 650 |
+
" base_fmt_paired.append(r['format_rate'])\n",
|
| 651 |
+
" base_pitch_paired.append(r['pitch_rate'])\n",
|
| 652 |
+
" base_history_per_seed.append(r['history'])\n",
|
| 653 |
+
" if (i + 1) % 10 == 0:\n",
|
| 654 |
+
" print(f' base {i+1}/{EVAL_N} profit={r[\"final_profit\"]:.1f}')\n",
|
| 655 |
+
"\n",
|
| 656 |
+
"tf, bf = np.array(trained_finals), np.array(base_finals_paired)\n",
|
| 657 |
+
"tr, br = np.array(trained_rewards), np.array(base_rewards_paired)\n",
|
| 658 |
+
"\n",
|
| 659 |
+
"print(f'\\nTrained Qwen3-1.7B profit : {tf.mean():.2f} \\u00b1 {tf.std():.2f}')\n",
|
| 660 |
+
"print(f'Base Qwen3-1.7B profit : {bf.mean():.2f} \\u00b1 {bf.std():.2f}')\n",
|
| 661 |
+
"print(f'Trained ep reward : {tr.mean():.2f} \\u00b1 {tr.std():.2f}')\n",
|
| 662 |
+
"print(f'Base ep reward : {br.mean():.2f} \\u00b1 {br.std():.2f}')\n",
|
| 663 |
+
"print(f'Trained format/pitch : {np.mean(trained_fmt):.0%} / {np.mean(trained_pitch):.0%}')\n",
|
| 664 |
+
"print(f'Base format/pitch : {np.mean(base_fmt_paired):.0%} / {np.mean(base_pitch_paired):.0%}')\n",
|
| 665 |
+
"\n",
|
| 666 |
+
"with open(WORK_DIR / 'eval_paired.json', 'w') as f:\n",
|
| 667 |
+
" json.dump({'seeds': PAIRED_SEEDS,\n",
|
| 668 |
+
" 'trained_finals': tf.tolist(), 'base_finals': bf.tolist(),\n",
|
| 669 |
+
" 'trained_rewards': tr.tolist(), 'base_rewards': br.tolist(),\n",
|
| 670 |
+
" 'trained_format_rate': float(np.mean(trained_fmt)),\n",
|
| 671 |
+
" 'base_format_rate': float(np.mean(base_fmt_paired)),\n",
|
| 672 |
+
" 'trained_pitch_rate': float(np.mean(trained_pitch)),\n",
|
| 673 |
+
" 'base_pitch_rate': float(np.mean(base_pitch_paired))}, f)"
|
| 674 |
+
]
|
| 675 |
+
},
|
| 676 |
+
{
|
| 677 |
+
"cell_type": "markdown",
|
| 678 |
+
"metadata": {},
|
| 679 |
+
"source": ["## 13. Stats summary + before/after plots"]
|
| 680 |
+
},
|
| 681 |
+
{
|
| 682 |
+
"cell_type": "code",
|
| 683 |
+
"execution_count": null,
|
| 684 |
+
"metadata": {},
|
| 685 |
+
"outputs": [],
|
| 686 |
+
"source": [
|
| 687 |
+
"from scipy import stats as spstats\n",
|
| 688 |
+
"\n",
|
| 689 |
+
"def cohen_d(a, b):\n",
|
| 690 |
+
" pooled = np.sqrt(((a.std(ddof=1)**2) + (b.std(ddof=1)**2)) / 2)\n",
|
| 691 |
+
" return (a.mean() - b.mean()) / (pooled + 1e-12)\n",
|
| 692 |
+
"\n",
|
| 693 |
+
"def bootstrap_diff_ci(a, b, n=10_000, seed=0):\n",
|
| 694 |
+
" rng = np.random.default_rng(seed)\n",
|
| 695 |
+
" diffs = a - b\n",
|
| 696 |
+
" boots = rng.choice(diffs, size=(n, len(diffs)), replace=True).mean(axis=1)\n",
|
| 697 |
+
" return float(np.percentile(boots, 2.5)), float(np.percentile(boots, 97.5))\n",
|
| 698 |
+
"\n",
|
| 699 |
+
"tt = spstats.ttest_rel(tf, bf)\n",
|
| 700 |
+
"uu = spstats.mannwhitneyu(tf, bf, alternative='greater')\n",
|
| 701 |
+
"wilc = spstats.wilcoxon(tf, bf, alternative='greater')\n",
|
| 702 |
+
"d = cohen_d(tf, bf)\n",
|
| 703 |
+
"lo, hi = bootstrap_diff_ci(tf, bf)\n",
|
| 704 |
+
"win_rate = float((tf > bf).mean())\n",
|
| 705 |
+
"tie_rate = float((tf == bf).mean())\n",
|
| 706 |
+
"\n",
|
| 707 |
+
"summary = {\n",
|
| 708 |
+
" 'baseline_model': MODEL_NAME + ' (no fine-tune)',\n",
|
| 709 |
+
" 'trained_model': MODEL_NAME + ' + LoRA r=32',\n",
|
| 710 |
+
" 'n': len(tf),\n",
|
| 711 |
+
" 'paired_t_stat': float(tt.statistic), 'paired_t_p': float(tt.pvalue),\n",
|
| 712 |
+
" 'mannwhitney_U': float(uu.statistic), 'mannwhitney_p_greater': float(uu.pvalue),\n",
|
| 713 |
+
" 'wilcoxon_p_greater': float(wilc.pvalue),\n",
|
| 714 |
+
" 'cohens_d': float(d),\n",
|
| 715 |
+
" 'paired_diff_mean': float((tf - bf).mean()),\n",
|
| 716 |
+
" 'paired_diff_95ci': [lo, hi],\n",
|
| 717 |
+
" 'win_rate_trained_strictly_better': win_rate,\n",
|
| 718 |
+
" 'tie_rate': tie_rate,\n",
|
| 719 |
+
"}\n",
|
| 720 |
+
"print(json.dumps(summary, indent=2))\n",
|
| 721 |
+
"with open(WORK_DIR / 'stats_summary.json', 'w') as f:\n",
|
| 722 |
+
" json.dump(summary, f, indent=2)\n",
|
| 723 |
+
"\n",
|
| 724 |
+
"bins = np.linspace(0, 100, 25)\n",
|
| 725 |
+
"plt.figure(figsize=(9, 5))\n",
|
| 726 |
+
"plt.hist(bf, bins=bins, alpha=0.55, color='#c44',\n",
|
| 727 |
+
" label=f'Base Qwen3-1.7B (mean={bf.mean():.1f})')\n",
|
| 728 |
+
"plt.hist(tf, bins=bins, alpha=0.55, color='#1d6fff',\n",
|
| 729 |
+
" label=f'Fine-tuned Qwen3-1.7B (mean={tf.mean():.1f})')\n",
|
| 730 |
+
"plt.axvline(bf.mean(), color='#c44', ls='--', lw=1.5)\n",
|
| 731 |
+
"plt.axvline(tf.mean(), color='#1d6fff', ls='--', lw=1.5)\n",
|
| 732 |
+
"plt.title(f'Final profitability — paired same-seed (n={len(tf)}) '\n",
|
| 733 |
+
" f\"d={summary['cohens_d']:+.2f} win-rate={summary['win_rate_trained_strictly_better']:.0%}\")\n",
|
| 734 |
+
"plt.xlabel('profitability score (0\\u2013100)'); plt.ylabel('episodes')\n",
|
| 735 |
+
"plt.legend(); plt.grid(alpha=0.3); plt.tight_layout()\n",
|
| 736 |
+
"plt.savefig(ASSETS / 'before_after.png', dpi=150); plt.close()\n",
|
| 737 |
+
"\n",
|
| 738 |
+
"diffs = tf - bf\n",
|
| 739 |
+
"order = np.argsort(diffs)\n",
|
| 740 |
+
"plt.figure(figsize=(9, 5))\n",
|
| 741 |
+
"plt.bar(range(len(diffs)), diffs[order],\n",
|
| 742 |
+
" color=['#1d6fff' if x > 0 else '#c44' for x in diffs[order]])\n",
|
| 743 |
+
"plt.axhline(0, color='k', lw=0.8)\n",
|
| 744 |
+
"plt.title(f'Per-seed lift (fine-tuned \\u2212 base Qwen3-1.7B), sorted '\n",
|
| 745 |
+
" f'mean lift = {diffs.mean():+.1f} CI=[{summary[\"paired_diff_95ci\"][0]:+.1f}, {summary[\"paired_diff_95ci\"][1]:+.1f}]')\n",
|
| 746 |
+
"plt.xlabel('seed (sorted by lift)'); plt.ylabel('\\u0394 profitability')\n",
|
| 747 |
+
"plt.grid(alpha=0.3); plt.tight_layout()\n",
|
| 748 |
+
"plt.savefig(ASSETS / 'paired_delta.png', dpi=150); plt.close()\n",
|
| 749 |
+
"print('Saved before_after.png, paired_delta.png')"
|
| 750 |
+
]
|
| 751 |
+
},
|
| 752 |
+
{
|
| 753 |
+
"cell_type": "markdown",
|
| 754 |
+
"metadata": {},
|
| 755 |
+
"source": ["## 14. Per-event win-rate breakdown"]
|
| 756 |
+
},
|
| 757 |
+
{
|
| 758 |
+
"cell_type": "code",
|
| 759 |
+
"execution_count": null,
|
| 760 |
+
"metadata": {},
|
| 761 |
+
"outputs": [],
|
| 762 |
+
"source": [
|
| 763 |
+
"def per_event_winrate(history_per_seed):\n",
|
| 764 |
+
" bucket = collections.defaultdict(lambda: [0, 0])\n",
|
| 765 |
+
" for hist in history_per_seed:\n",
|
| 766 |
+
" for rd in hist:\n",
|
| 767 |
+
" t = rd.get('event_title', '?')\n",
|
| 768 |
+
" bucket[t][1] += 1\n",
|
| 769 |
+
" if rd.get('agent_won_vote'):\n",
|
| 770 |
+
" bucket[t][0] += 1\n",
|
| 771 |
+
" return {t: (w / max(1, n)) for t, (w, n) in bucket.items()}\n",
|
| 772 |
+
"\n",
|
| 773 |
+
"trained_wr = per_event_winrate(trained_history_per_seed)\n",
|
| 774 |
+
"base_wr = per_event_winrate(base_history_per_seed)\n",
|
| 775 |
+
"\n",
|
| 776 |
+
"events_sorted = sorted(set(trained_wr) | set(base_wr))\n",
|
| 777 |
+
"tw = [trained_wr.get(e, 0.0) for e in events_sorted]\n",
|
| 778 |
+
"bw = [base_wr.get(e, 0.0) for e in events_sorted]\n",
|
| 779 |
+
"\n",
|
| 780 |
+
"plt.figure(figsize=(11, 5))\n",
|
| 781 |
+
"x = np.arange(len(events_sorted))\n",
|
| 782 |
+
"plt.bar(x - 0.2, bw, width=0.4, color='#c44', label='Base Qwen3-1.7B')\n",
|
| 783 |
+
"plt.bar(x + 0.2, tw, width=0.4, color='#1d6fff', label='Fine-tuned Qwen3-1.7B')\n",
|
| 784 |
+
"plt.xticks(x, [e[:22] for e in events_sorted], rotation=30, ha='right')\n",
|
| 785 |
+
"plt.ylim(0, 1.05); plt.ylabel('boardroom win rate')\n",
|
| 786 |
+
"plt.title(f'Per-event boardroom win rate (paired seeds, n={EVAL_N} episodes)')\n",
|
| 787 |
+
"plt.legend(); plt.grid(alpha=0.3, axis='y'); plt.tight_layout()\n",
|
| 788 |
+
"plt.savefig(ASSETS / 'per_event_winrate.png', dpi=150); plt.close()\n",
|
| 789 |
+
"\n",
|
| 790 |
+
"with open(WORK_DIR / 'per_event_winrate.json', 'w') as f:\n",
|
| 791 |
+
" json.dump({'events': events_sorted, 'trained': tw, 'base': bw}, f, indent=2)\n",
|
| 792 |
+
"print('Saved per_event_winrate.png')"
|
| 793 |
+
]
|
| 794 |
+
},
|
| 795 |
+
{
|
| 796 |
+
"cell_type": "markdown",
|
| 797 |
+
"metadata": {},
|
| 798 |
+
"source": ["## 15. Theory-of-Mind probe"]
|
| 799 |
+
},
|
| 800 |
+
{
|
| 801 |
+
"cell_type": "code",
|
| 802 |
+
"execution_count": null,
|
| 803 |
+
"metadata": {},
|
| 804 |
+
"outputs": [],
|
| 805 |
+
"source": [
|
| 806 |
+
"TOM_INSTRUCTION = (\n",
|
| 807 |
+
" \"\\n\\nGiven the state and event below, name the SINGLE board member \"\n",
|
| 808 |
+
" \"(CTO, CFO, Investor Rep, or Independent) most likely to oppose the chosen decision. \"\n",
|
| 809 |
+
" \"Answer with just the role name on one line.\\n\"\n",
|
| 810 |
+
")\n",
|
| 811 |
+
"\n",
|
| 812 |
+
"def tom_predict(obs, decision):\n",
|
| 813 |
+
" body = build_prompt(obs).split(SYSTEM_PROMPT, 1)[1]\n",
|
| 814 |
+
" prompt = SYSTEM_PROMPT + TOM_INSTRUCTION + body + f'Chosen decision: {decision}\\nMost likely opponent: '\n",
|
| 815 |
+
" enc = tokenizer(prompt, return_tensors='pt', truncation=True, max_length=1024).to(device)\n",
|
| 816 |
+
" with torch.no_grad():\n",
|
| 817 |
+
" out = model.generate(**enc, max_new_tokens=8, do_sample=False,\n",
|
| 818 |
+
" pad_token_id=tokenizer.eos_token_id)\n",
|
| 819 |
+
" txt = tokenizer.decode(out[0][enc.input_ids.shape[1]:], skip_special_tokens=True).lower()\n",
|
| 820 |
+
" if 'investor' in txt: return 'Investor Rep'\n",
|
| 821 |
+
" if 'independent' in txt: return 'Independent'\n",
|
| 822 |
+
" if 'cto' in txt: return 'CTO'\n",
|
| 823 |
+
" if 'cfo' in txt: return 'CFO'\n",
|
| 824 |
+
" return None\n",
|
| 825 |
+
"\n",
|
| 826 |
+
"def tom_eval(seed_base=80_000, n=40):\n",
|
| 827 |
+
" correct = total = 0\n",
|
| 828 |
+
" with make_env().sync() as env:\n",
|
| 829 |
+
" for ep in range(n):\n",
|
| 830 |
+
" result = env.reset(seed=seed_base + ep)\n",
|
| 831 |
+
" obs = result.observation\n",
|
| 832 |
+
" decision, _, _ = greedy_action(obs)\n",
|
| 833 |
+
" opposed = [s['role'] for s in obs.npc_statements if s['vote'] != decision]\n",
|
| 834 |
+
" if not opposed:\n",
|
| 835 |
+
" continue\n",
|
| 836 |
+
" pred = tom_predict(obs, decision)\n",
|
| 837 |
+
" if pred and pred in opposed:\n",
|
| 838 |
+
" correct += 1\n",
|
| 839 |
+
" total += 1\n",
|
| 840 |
+
" return correct, total\n",
|
| 841 |
+
"\n",
|
| 842 |
+
"t_corr, t_tot = tom_eval()\n",
|
| 843 |
+
"with model.disable_adapter():\n",
|
| 844 |
+
" b_corr, b_tot = tom_eval()\n",
|
| 845 |
+
"\n",
|
| 846 |
+
"tom_acc = t_corr / max(1, t_tot)\n",
|
| 847 |
+
"tom_acc_base = b_corr / max(1, b_tot)\n",
|
| 848 |
+
"print(f'ToM probe: trained = {tom_acc:.1%} ({t_corr}/{t_tot}) base = {tom_acc_base:.1%} ({b_corr}/{b_tot})')\n",
|
| 849 |
+
"with open(WORK_DIR / 'tom.json', 'w') as f:\n",
|
| 850 |
+
" json.dump({'trained': {'correct': t_corr, 'total': t_tot, 'accuracy': tom_acc},\n",
|
| 851 |
+
" 'base': {'correct': b_corr, 'total': b_tot, 'accuracy': tom_acc_base}}, f)"
|
| 852 |
+
]
|
| 853 |
+
},
|
| 854 |
+
{
|
| 855 |
+
"cell_type": "markdown",
|
| 856 |
+
"metadata": {},
|
| 857 |
+
"source": ["## 16. Push to HF Hub"]
|
| 858 |
+
},
|
| 859 |
+
{
|
| 860 |
+
"cell_type": "code",
|
| 861 |
+
"execution_count": null,
|
| 862 |
+
"metadata": {},
|
| 863 |
+
"outputs": [],
|
| 864 |
+
"source": [
|
| 865 |
+
"from huggingface_hub import HfApi\n",
|
| 866 |
+
"ADAPTER_REPO = os.environ.get('ADAPTER_REPO', 'StavanKhobare/SST-MetaxPyTorch-Hackathon-LoRA')\n",
|
| 867 |
+
"MERGED_REPO = os.environ.get('MERGED_REPO', 'StavanKhobare/SST-MetaxPyTorch-Hackathon-Merged16bit')\n",
|
| 868 |
+
"\n",
|
| 869 |
+
"api = HfApi()\n",
|
| 870 |
+
"api.create_repo(ADAPTER_REPO, repo_type='model', private=False, exist_ok=True)\n",
|
| 871 |
+
"api.create_repo(MERGED_REPO, repo_type='model', private=False, exist_ok=True)\n",
|
| 872 |
+
"\n",
|
| 873 |
+
"try:\n",
|
| 874 |
+
" model.push_to_hub(ADAPTER_REPO, private=False)\n",
|
| 875 |
+
" tokenizer.push_to_hub(ADAPTER_REPO, private=False)\n",
|
| 876 |
+
" print(f'\\u2713 LoRA pushed: https://huggingface.co/{ADAPTER_REPO}')\n",
|
| 877 |
+
"except Exception as e:\n",
|
| 878 |
+
" print(f'LoRA push failed: {e!r}')\n",
|
| 879 |
+
"\n",
|
| 880 |
+
"try:\n",
|
| 881 |
+
" model.push_to_hub_merged(MERGED_REPO, tokenizer, save_method='merged_16bit', private=False)\n",
|
| 882 |
+
" print(f'\\u2713 Merged 16-bit pushed: https://huggingface.co/{MERGED_REPO}')\n",
|
| 883 |
+
"except Exception as e:\n",
|
| 884 |
+
" print(f'Merged push failed (you can retry): {e!r}')\n",
|
| 885 |
+
"\n",
|
| 886 |
+
"try:\n",
|
| 887 |
+
" api.upload_folder(folder_path=str(ASSETS), repo_id=ADAPTER_REPO,\n",
|
| 888 |
+
" path_in_repo='assets', repo_type='model')\n",
|
| 889 |
+
" for fname in ['log_history.json','eval_history.json','eval_paired.json',\n",
|
| 890 |
+
" 'stats_summary.json','tom.json','transcripts.json',\n",
|
| 891 |
+
" 'decision_counter.json','baseline.json',\n",
|
| 892 |
+
" 'per_event_winrate.json']:\n",
|
| 893 |
+
" fp = WORK_DIR / fname\n",
|
| 894 |
+
" if fp.exists():\n",
|
| 895 |
+
" api.upload_file(path_or_fileobj=str(fp), path_in_repo=fname,\n",
|
| 896 |
+
" repo_id=ADAPTER_REPO, repo_type='model')\n",
|
| 897 |
+
" print(f'\\u2713 Artifacts uploaded to https://huggingface.co/{ADAPTER_REPO}')\n",
|
| 898 |
+
"except Exception as e:\n",
|
| 899 |
+
" print(f'Artifact upload failed: {e!r}')"
|
| 900 |
+
]
|
| 901 |
+
},
|
| 902 |
+
{
|
| 903 |
+
"cell_type": "markdown",
|
| 904 |
+
"metadata": {},
|
| 905 |
+
"source": ["## 17. Final summary"]
|
| 906 |
+
},
|
| 907 |
+
{
|
| 908 |
+
"cell_type": "code",
|
| 909 |
+
"execution_count": null,
|
| 910 |
+
"metadata": {},
|
| 911 |
+
"outputs": [],
|
| 912 |
+
"source": [
|
| 913 |
+
"# Decision entropy (over GRPO rollouts)\n",
|
| 914 |
+
"_total = sum(decision_counter.values())\n",
|
| 915 |
+
"_probs = [c / _total for c in decision_counter.values()] if _total else []\n",
|
| 916 |
+
"entropy = -sum(p * math.log(p + 1e-12) for p in _probs) if _probs else 0.0\n",
|
| 917 |
+
"max_ent = math.log(len(decision_counter)) if decision_counter else 0.0\n",
|
| 918 |
+
"\n",
|
| 919 |
+
"print('='*70)\n",
|
| 920 |
+
"print('BOARDSIM \\u00d7 QWEN3-1.7B \\u2014 LEARNING EVIDENCE')\n",
|
| 921 |
+
"print('='*70)\n",
|
| 922 |
+
"print(f'Reward slope (linear fit) : {slope:+.5f}/step (p={p_val:.2e})')\n",
|
| 923 |
+
"print(f'Reward EMA first 20 steps : {rewards_ema[:20].mean():+.3f}')\n",
|
| 924 |
+
"print(f'Reward EMA last 20 steps : {rewards_ema[-20:].mean():+.3f}')\n",
|
| 925 |
+
"print(f'Format compliance start : {fmts[:20].mean():.0%}')\n",
|
| 926 |
+
"print(f'Format compliance end : {fmts[-20:].mean():.0%}')\n",
|
| 927 |
+
"print('-'*70)\n",
|
| 928 |
+
"print(f'Held-out paired (n={len(tf)}): fine-tuned {tf.mean():.2f} vs base {bf.mean():.2f}')\n",
|
| 929 |
+
"print(f' paired t-test p={summary[\"paired_t_p\"]:.2e} Wilcoxon p={summary[\"wilcoxon_p_greater\"]:.2e}')\n",
|
| 930 |
+
"print(f' Cohen d={summary[\"cohens_d\"]:+.2f} 95% CI of lift = [{summary[\"paired_diff_95ci\"][0]:+.2f}, {summary[\"paired_diff_95ci\"][1]:+.2f}]')\n",
|
| 931 |
+
"print(f' win rate (fine-tuned > base): {summary[\"win_rate_trained_strictly_better\"]:.0%}')\n",
|
| 932 |
+
"print(f'ToM probe fine-tuned : {tom_acc:.0%} base = {tom_acc_base:.0%}')\n",
|
| 933 |
+
"print(f'Decision entropy : {entropy:.2f} / {max_ent:.2f} (\\u2192 not collapsed)')\n",
|
| 934 |
+
"print('-'*70)\n",
|
| 935 |
+
"print(f'Adapter : https://huggingface.co/{ADAPTER_REPO}')\n",
|
| 936 |
+
"print(f'Merged 16bit : https://huggingface.co/{MERGED_REPO}')\n",
|
| 937 |
+
"print(f'Env Space : {ENV_BASE_URL}')\n",
|
| 938 |
+
"print('='*70)"
|
| 939 |
+
]
|
| 940 |
+
}
|
| 941 |
+
],
|
| 942 |
+
"metadata": {
|
| 943 |
+
"kernelspec": {
|
| 944 |
+
"display_name": "Python 3",
|
| 945 |
+
"language": "python",
|
| 946 |
+
"name": "python3"
|
| 947 |
+
},
|
| 948 |
+
"language_info": {
|
| 949 |
+
"name": "python",
|
| 950 |
+
"version": "3.10"
|
| 951 |
+
}
|
| 952 |
+
},
|
| 953 |
+
"nbformat": 4,
|
| 954 |
+
"nbformat_minor": 5
|
| 955 |
+
}
|
notebooks/train_grpo_v2.ipynb
CHANGED
|
@@ -2,38 +2,41 @@
|
|
| 2 |
"cells": [
|
| 3 |
{
|
| 4 |
"cell_type": "markdown",
|
|
|
|
| 5 |
"metadata": {},
|
| 6 |
"source": [
|
| 7 |
-
"# BoardSim GRPO — Qwen3-4B (v3, generic events + base-model baseline)\n",
|
| 8 |
"\n",
|
| 9 |
-
"Training notebook for the Meta PyTorch × HuggingFace OpenEnv Hackathon submission.\n",
|
| 10 |
"\n",
|
| 11 |
-
"**This revision (v3) — what changed:**\n",
|
| 12 |
"- Events are now **organization-agnostic** (competition, talent, regulation, PR, M&A,\n",
|
| 13 |
" funding, governance, exit) so the simulation maps onto any company, not a specific industry.\n",
|
| 14 |
-
"- **Pitch scoring is semantic**, not keyword-based — sentence-transformer cosine similarity\n",
|
| 15 |
" against per-role manifestos, with a TF-IDF fallback. The agent has to write substantively\n",
|
| 16 |
" aligned arguments rather than spray vocabulary.\n",
|
| 17 |
"- **The baseline is the same Qwen3-4B model with LoRA disabled**, not a random policy.\n",
|
| 18 |
" A coin-flip is not a meaningful opponent for a 4 B language model; the apples-to-apples\n",
|
| 19 |
" reference is the *same model* without the fine-tuning delta. Recovered cheaply via\n",
|
| 20 |
" `peft`'s `model.disable_adapter()` context manager (no second model load).\n",
|
| 21 |
-
"- CEO vote weight raised to 2.5× and persuasion shift cap raised to 55% so a CEO decision\n",
|
| 22 |
" visibly moves outcomes round-to-round.\n",
|
| 23 |
-
"- Added per-event boardroom win-rate plot — the most direct picture of *where* fine-tuning helps.\n",
|
| 24 |
"- ToM probe and trust-trajectory analyses both report fine-tuned **and** base for fair contrast.\n"
|
| 25 |
]
|
| 26 |
},
|
| 27 |
{
|
| 28 |
"cell_type": "markdown",
|
|
|
|
| 29 |
"metadata": {},
|
| 30 |
"source": [
|
| 31 |
-
"## 1. Install (unsloth FIRST — order matters)"
|
| 32 |
]
|
| 33 |
},
|
| 34 |
{
|
| 35 |
"cell_type": "code",
|
| 36 |
"execution_count": null,
|
|
|
|
| 37 |
"metadata": {},
|
| 38 |
"outputs": [],
|
| 39 |
"source": [
|
|
@@ -50,6 +53,7 @@
|
|
| 50 |
},
|
| 51 |
{
|
| 52 |
"cell_type": "markdown",
|
|
|
|
| 53 |
"metadata": {},
|
| 54 |
"source": [
|
| 55 |
"## 2. Auth (HF + WandB)"
|
|
@@ -58,6 +62,7 @@
|
|
| 58 |
{
|
| 59 |
"cell_type": "code",
|
| 60 |
"execution_count": null,
|
|
|
|
| 61 |
"metadata": {},
|
| 62 |
"outputs": [],
|
| 63 |
"source": [
|
|
@@ -103,14 +108,16 @@
|
|
| 103 |
},
|
| 104 |
{
|
| 105 |
"cell_type": "markdown",
|
|
|
|
| 106 |
"metadata": {},
|
| 107 |
"source": [
|
| 108 |
-
"## 3. Mount Drive (early — checkpoints survive Colab disconnects)"
|
| 109 |
]
|
| 110 |
},
|
| 111 |
{
|
| 112 |
"cell_type": "code",
|
| 113 |
"execution_count": null,
|
|
|
|
| 114 |
"metadata": {},
|
| 115 |
"outputs": [],
|
| 116 |
"source": [
|
|
@@ -130,6 +137,7 @@
|
|
| 130 |
},
|
| 131 |
{
|
| 132 |
"cell_type": "markdown",
|
|
|
|
| 133 |
"metadata": {},
|
| 134 |
"source": [
|
| 135 |
"## 4. Clone repo + import BoardSimEnv client"
|
|
@@ -138,6 +146,7 @@
|
|
| 138 |
{
|
| 139 |
"cell_type": "code",
|
| 140 |
"execution_count": null,
|
|
|
|
| 141 |
"metadata": {},
|
| 142 |
"outputs": [],
|
| 143 |
"source": [
|
|
@@ -177,20 +186,22 @@
|
|
| 177 |
},
|
| 178 |
{
|
| 179 |
"cell_type": "markdown",
|
|
|
|
| 180 |
"metadata": {},
|
| 181 |
"source": [
|
| 182 |
-
"## 5. Load base Qwen3-4B (no LoRA yet — this is also our baseline)"
|
| 183 |
]
|
| 184 |
},
|
| 185 |
{
|
| 186 |
"cell_type": "code",
|
| 187 |
"execution_count": null,
|
|
|
|
| 188 |
"metadata": {},
|
| 189 |
"outputs": [],
|
| 190 |
"source": [
|
| 191 |
"# Load base Qwen3-4B (NO LoRA yet). The base model serves a dual role:\n",
|
| 192 |
"# (a) it is the reference baseline against which the fine-tuned policy is\n",
|
| 193 |
-
"# compared — this replaces the older random-policy baseline, which was\n",
|
| 194 |
"# not meaningful (a coin-flip is not a competitive opponent for an LLM).\n",
|
| 195 |
"# (b) once the baseline is recorded, we wrap the SAME model with LoRA\n",
|
| 196 |
"# adapters and fine-tune it. At paired-eval time we toggle the adapters\n",
|
|
@@ -220,6 +231,7 @@
|
|
| 220 |
},
|
| 221 |
{
|
| 222 |
"cell_type": "markdown",
|
|
|
|
| 223 |
"metadata": {},
|
| 224 |
"source": [
|
| 225 |
"## 6. Prompt template + completion parser (generic CEO, no industry-specific persona)"
|
|
@@ -228,10 +240,11 @@
|
|
| 228 |
{
|
| 229 |
"cell_type": "code",
|
| 230 |
"execution_count": null,
|
|
|
|
| 231 |
"metadata": {},
|
| 232 |
"outputs": [],
|
| 233 |
"source": [
|
| 234 |
-
"# Generic CEO prompt — applies to any organization, not a specific industry.\n",
|
| 235 |
"SYSTEM_PROMPT = \"\"\"You are the CEO of a mid-stage organization. Your board has 4 members with HIDDEN AGENDAS you cannot see directly:\n",
|
| 236 |
" - CTO: cares about operational excellence, engineering quality, team morale, and product readiness.\n",
|
| 237 |
" - CFO: cares about cash discipline, runway, and regulatory safety.\n",
|
|
@@ -240,7 +253,7 @@
|
|
| 240 |
"\n",
|
| 241 |
"Each round you see a strategic event, every NPC's pre-vote statement, and 3 options.\n",
|
| 242 |
"Your decision is resolved by WEIGHTED VOTE (your weight 2.5x). A short COALITION PITCH\n",
|
| 243 |
-
"that is semantically aligned with opposing members' priorities can swing them toward your pick —\n",
|
| 244 |
"write substantive arguments, not just buzzwords.\n",
|
| 245 |
"\n",
|
| 246 |
"Respond in EXACTLY this format on two lines:\n",
|
|
@@ -286,6 +299,7 @@
|
|
| 286 |
},
|
| 287 |
{
|
| 288 |
"cell_type": "markdown",
|
|
|
|
| 289 |
"metadata": {},
|
| 290 |
"source": [
|
| 291 |
"## 7. Episode runner (works for both base and fine-tuned model)"
|
|
@@ -294,6 +308,7 @@
|
|
| 294 |
{
|
| 295 |
"cell_type": "code",
|
| 296 |
"execution_count": null,
|
|
|
|
| 297 |
"metadata": {},
|
| 298 |
"outputs": [],
|
| 299 |
"source": [
|
|
@@ -338,18 +353,20 @@
|
|
| 338 |
},
|
| 339 |
{
|
| 340 |
"cell_type": "markdown",
|
|
|
|
| 341 |
"metadata": {},
|
| 342 |
"source": [
|
| 343 |
-
"## 8. Baseline — base Qwen3-4B on held-out seeds (replaces the old random baseline)"
|
| 344 |
]
|
| 345 |
},
|
| 346 |
{
|
| 347 |
"cell_type": "code",
|
| 348 |
"execution_count": null,
|
|
|
|
| 349 |
"metadata": {},
|
| 350 |
"outputs": [],
|
| 351 |
"source": [
|
| 352 |
-
"# BASELINE — base Qwen3-4B (no fine-tuning).\n",
|
| 353 |
"# This is the apples-to-apples reference for measuring what fine-tuning buys\n",
|
| 354 |
"# us. Random policies are not a competitive baseline for a 4 B language model\n",
|
| 355 |
"# choosing among 3 well-formed strings.\n",
|
|
@@ -383,6 +400,7 @@
|
|
| 383 |
},
|
| 384 |
{
|
| 385 |
"cell_type": "markdown",
|
|
|
|
| 386 |
"metadata": {},
|
| 387 |
"source": [
|
| 388 |
"## 9. Wrap base model with LoRA adapters"
|
|
@@ -391,6 +409,7 @@
|
|
| 391 |
{
|
| 392 |
"cell_type": "code",
|
| 393 |
"execution_count": null,
|
|
|
|
| 394 |
"metadata": {},
|
| 395 |
"outputs": [],
|
| 396 |
"source": [
|
|
@@ -415,6 +434,7 @@
|
|
| 415 |
},
|
| 416 |
{
|
| 417 |
"cell_type": "markdown",
|
|
|
|
| 418 |
"metadata": {},
|
| 419 |
"source": [
|
| 420 |
"## 10. Periodic-eval helper"
|
|
@@ -423,6 +443,7 @@
|
|
| 423 |
{
|
| 424 |
"cell_type": "code",
|
| 425 |
"execution_count": null,
|
|
|
|
| 426 |
"metadata": {},
|
| 427 |
"outputs": [],
|
| 428 |
"source": [
|
|
@@ -443,6 +464,7 @@
|
|
| 443 |
},
|
| 444 |
{
|
| 445 |
"cell_type": "markdown",
|
|
|
|
| 446 |
"metadata": {},
|
| 447 |
"source": [
|
| 448 |
"## 11. GRPO training loop (single persistent env, periodic eval, Drive checkpoints)"
|
|
@@ -451,6 +473,7 @@
|
|
| 451 |
{
|
| 452 |
"cell_type": "code",
|
| 453 |
"execution_count": null,
|
|
|
|
| 454 |
"metadata": {},
|
| 455 |
"outputs": [],
|
| 456 |
"source": [
|
|
@@ -598,12 +621,21 @@
|
|
| 598 |
"cell_type": "markdown",
|
| 599 |
"metadata": {},
|
| 600 |
"source": [
|
| 601 |
-
"##
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 602 |
]
|
| 603 |
},
|
| 604 |
{
|
| 605 |
"cell_type": "code",
|
| 606 |
"execution_count": null,
|
|
|
|
| 607 |
"metadata": {},
|
| 608 |
"outputs": [],
|
| 609 |
"source": [
|
|
@@ -628,7 +660,7 @@
|
|
| 628 |
"rewards_ema = ema(rewards, 0.1)\n",
|
| 629 |
"slope, intercept, r_val, p_val, _ = spstats.linregress(steps, rewards)\n",
|
| 630 |
"\n",
|
| 631 |
-
"# Reward curve
|
| 632 |
"plt.figure(figsize=(9, 5))\n",
|
| 633 |
"plt.plot(steps, rewards, alpha=0.3, lw=1, label='per-step group reward')\n",
|
| 634 |
"plt.plot(steps, rewards_ema, lw=2.2, label='EMA (\\u03b1=0.1)')\n",
|
|
@@ -636,7 +668,7 @@
|
|
| 636 |
" label=f'linear fit slope={slope:+.4f}/step (p={p_val:.1e})')\n",
|
| 637 |
"plt.axhline(BASELINE_MEAN_REWARD, ls=':', lw=2, color='#c44',\n",
|
| 638 |
" label=f'base Qwen3-4B baseline = {BASELINE_MEAN_REWARD:.2f}')\n",
|
| 639 |
-
"plt.title('GRPO reward
|
| 640 |
"plt.xlabel('step'); plt.ylabel('mean group reward')\n",
|
| 641 |
"plt.legend(); plt.grid(alpha=0.3); plt.tight_layout()\n",
|
| 642 |
"plt.savefig(ASSETS / 'reward_curve.png', dpi=150); plt.close()\n",
|
|
@@ -657,7 +689,7 @@
|
|
| 657 |
"plt.legend(); plt.grid(alpha=0.3); plt.tight_layout()\n",
|
| 658 |
"plt.savefig(ASSETS / 'format_compliance.png', dpi=150); plt.close()\n",
|
| 659 |
"\n",
|
| 660 |
-
"# Periodic eval
|
| 661 |
"# can see the LoRA-trained policy progressively pull away from the base\n",
|
| 662 |
"# model on held-out seeds.\n",
|
| 663 |
"if eval_history:\n",
|
|
@@ -681,20 +713,22 @@
|
|
| 681 |
},
|
| 682 |
{
|
| 683 |
"cell_type": "markdown",
|
|
|
|
| 684 |
"metadata": {},
|
| 685 |
"source": [
|
| 686 |
-
"## 13. Proof #2
|
| 687 |
]
|
| 688 |
},
|
| 689 |
{
|
| 690 |
"cell_type": "code",
|
| 691 |
"execution_count": null,
|
|
|
|
| 692 |
"metadata": {},
|
| 693 |
"outputs": [],
|
| 694 |
"source": [
|
| 695 |
"# Paired same-seed eval: fine-tuned vs BASE Qwen3-4B (adapters disabled).\n",
|
| 696 |
"# This is the headline comparison. Same prompts, same env seeds, same\n",
|
| 697 |
-
"# decoder, same parser
|
| 698 |
"# -----------------------------------------------------------------------------\n",
|
| 699 |
"from unsloth import FastLanguageModel\n",
|
| 700 |
"FastLanguageModel.for_inference(model)\n",
|
|
@@ -716,7 +750,7 @@
|
|
| 716 |
" if (i + 1) % 10 == 0:\n",
|
| 717 |
" print(f' trained {i+1}/{EVAL_N} profit={r[\"final_profit\"]:.1f}')\n",
|
| 718 |
"\n",
|
| 719 |
-
"# Base Qwen3-4B (LoRA disabled)
|
| 720 |
"base_finals_paired, base_rewards_paired, base_fmt_paired, base_pitch_paired = [], [], [], []\n",
|
| 721 |
"base_history_per_seed = []\n",
|
| 722 |
"with make_env().sync() as env, model.disable_adapter():\n",
|
|
@@ -752,14 +786,16 @@
|
|
| 752 |
},
|
| 753 |
{
|
| 754 |
"cell_type": "markdown",
|
|
|
|
| 755 |
"metadata": {},
|
| 756 |
"source": [
|
| 757 |
-
"## 14. Proof #3
|
| 758 |
]
|
| 759 |
},
|
| 760 |
{
|
| 761 |
"cell_type": "code",
|
| 762 |
"execution_count": null,
|
|
|
|
| 763 |
"metadata": {},
|
| 764 |
"outputs": [],
|
| 765 |
"source": [
|
|
@@ -803,18 +839,20 @@
|
|
| 803 |
},
|
| 804 |
{
|
| 805 |
"cell_type": "markdown",
|
|
|
|
| 806 |
"metadata": {},
|
| 807 |
"source": [
|
| 808 |
-
"## 15. Proof #4
|
| 809 |
]
|
| 810 |
},
|
| 811 |
{
|
| 812 |
"cell_type": "code",
|
| 813 |
"execution_count": null,
|
|
|
|
| 814 |
"metadata": {},
|
| 815 |
"outputs": [],
|
| 816 |
"source": [
|
| 817 |
-
"# Histogram
|
| 818 |
"bins = np.linspace(0, 100, 25)\n",
|
| 819 |
"plt.figure(figsize=(9, 5))\n",
|
| 820 |
"plt.hist(bf, bins=bins, alpha=0.55, color='#c44',\n",
|
|
@@ -823,7 +861,7 @@
|
|
| 823 |
" label=f'Fine-tuned Qwen3-4B (mean={tf.mean():.1f})')\n",
|
| 824 |
"plt.axvline(bf.mean(), color='#c44', ls='--', lw=1.5)\n",
|
| 825 |
"plt.axvline(tf.mean(), color='#1d6fff', ls='--', lw=1.5)\n",
|
| 826 |
-
"plt.title(f'Final profitability
|
| 827 |
" f\"d={summary['cohens_d']:+.2f} win-rate={summary['win_rate_trained_strictly_better']:.0%}\")\n",
|
| 828 |
"plt.xlabel('profitability score (0\\u2013100)'); plt.ylabel('episodes')\n",
|
| 829 |
"plt.legend(); plt.grid(alpha=0.3); plt.tight_layout()\n",
|
|
@@ -846,18 +884,20 @@
|
|
| 846 |
},
|
| 847 |
{
|
| 848 |
"cell_type": "markdown",
|
|
|
|
| 849 |
"metadata": {},
|
| 850 |
"source": [
|
| 851 |
-
"## 16. Proof #5
|
| 852 |
]
|
| 853 |
},
|
| 854 |
{
|
| 855 |
"cell_type": "code",
|
| 856 |
"execution_count": null,
|
|
|
|
| 857 |
"metadata": {},
|
| 858 |
"outputs": [],
|
| 859 |
"source": [
|
| 860 |
-
"# Per-event win-rate breakdown
|
| 861 |
"# did the fine-tuned policy win the boardroom vote vs base Qwen3-4B?\n",
|
| 862 |
"# This is the most direct picture of WHERE the fine-tuning helps.\n",
|
| 863 |
"# -----------------------------------------------------------------------------\n",
|
|
@@ -896,18 +936,20 @@
|
|
| 896 |
},
|
| 897 |
{
|
| 898 |
"cell_type": "markdown",
|
|
|
|
| 899 |
"metadata": {},
|
| 900 |
"source": [
|
| 901 |
-
"## 17. Proof #6
|
| 902 |
]
|
| 903 |
},
|
| 904 |
{
|
| 905 |
"cell_type": "code",
|
| 906 |
"execution_count": null,
|
|
|
|
| 907 |
"metadata": {},
|
| 908 |
"outputs": [],
|
| 909 |
"source": [
|
| 910 |
-
"# Theory-of-Mind probe
|
| 911 |
"# likely to oppose its decision? Run for BOTH base and fine-tuned for fair\n",
|
| 912 |
"# comparison, since \"random=25%\" is a weak reference for a 4 B LM.\n",
|
| 913 |
"# -----------------------------------------------------------------------------\n",
|
|
@@ -961,14 +1003,16 @@
|
|
| 961 |
},
|
| 962 |
{
|
| 963 |
"cell_type": "markdown",
|
|
|
|
| 964 |
"metadata": {},
|
| 965 |
"source": [
|
| 966 |
-
"## 18. Proof #7
|
| 967 |
]
|
| 968 |
},
|
| 969 |
{
|
| 970 |
"cell_type": "code",
|
| 971 |
"execution_count": null,
|
|
|
|
| 972 |
"metadata": {},
|
| 973 |
"outputs": [],
|
| 974 |
"source": [
|
|
@@ -1005,7 +1049,7 @@
|
|
| 1005 |
" mb = [np.mean(x) if x else np.nan for x in trust_base[role]]\n",
|
| 1006 |
" plt.plot(range(len(mt)), mt, color=color, lw=2, label=f'{role} (fine-tuned)')\n",
|
| 1007 |
" plt.plot(range(len(mb)), mb, color=color, lw=1.2, ls='--', alpha=0.6, label=f'{role} (base)')\n",
|
| 1008 |
-
"plt.title('Per-round trust
|
| 1009 |
"plt.xlabel('round'); plt.ylabel('trust [0.1, 1.0]')\n",
|
| 1010 |
"plt.legend(ncol=2, fontsize=8); plt.grid(alpha=0.3); plt.tight_layout()\n",
|
| 1011 |
"plt.savefig(ASSETS / 'trust_trajectory.png', dpi=150); plt.close()\n",
|
|
@@ -1014,14 +1058,16 @@
|
|
| 1014 |
},
|
| 1015 |
{
|
| 1016 |
"cell_type": "markdown",
|
|
|
|
| 1017 |
"metadata": {},
|
| 1018 |
"source": [
|
| 1019 |
-
"## 19. Proof #8
|
| 1020 |
]
|
| 1021 |
},
|
| 1022 |
{
|
| 1023 |
"cell_type": "code",
|
| 1024 |
"execution_count": null,
|
|
|
|
| 1025 |
"metadata": {},
|
| 1026 |
"outputs": [],
|
| 1027 |
"source": [
|
|
@@ -1065,14 +1111,16 @@
|
|
| 1065 |
},
|
| 1066 |
{
|
| 1067 |
"cell_type": "markdown",
|
|
|
|
| 1068 |
"metadata": {},
|
| 1069 |
"source": [
|
| 1070 |
-
"## 20. Proof #9
|
| 1071 |
]
|
| 1072 |
},
|
| 1073 |
{
|
| 1074 |
"cell_type": "code",
|
| 1075 |
"execution_count": null,
|
|
|
|
| 1076 |
"metadata": {},
|
| 1077 |
"outputs": [],
|
| 1078 |
"source": [
|
|
@@ -1097,6 +1145,7 @@
|
|
| 1097 |
},
|
| 1098 |
{
|
| 1099 |
"cell_type": "markdown",
|
|
|
|
| 1100 |
"metadata": {},
|
| 1101 |
"source": [
|
| 1102 |
"## 21. Push model + artifacts to HF"
|
|
@@ -1105,6 +1154,7 @@
|
|
| 1105 |
{
|
| 1106 |
"cell_type": "code",
|
| 1107 |
"execution_count": null,
|
|
|
|
| 1108 |
"metadata": {},
|
| 1109 |
"outputs": [],
|
| 1110 |
"source": [
|
|
@@ -1150,6 +1200,7 @@
|
|
| 1150 |
},
|
| 1151 |
{
|
| 1152 |
"cell_type": "markdown",
|
|
|
|
| 1153 |
"metadata": {},
|
| 1154 |
"source": [
|
| 1155 |
"## 22. Final summary printout (for the README / video)"
|
|
@@ -1158,6 +1209,7 @@
|
|
| 1158 |
{
|
| 1159 |
"cell_type": "code",
|
| 1160 |
"execution_count": null,
|
|
|
|
| 1161 |
"metadata": {},
|
| 1162 |
"outputs": [],
|
| 1163 |
"source": [
|
|
|
|
| 2 |
"cells": [
|
| 3 |
{
|
| 4 |
"cell_type": "markdown",
|
| 5 |
+
"id": "a6bcdc4a",
|
| 6 |
"metadata": {},
|
| 7 |
"source": [
|
| 8 |
+
"# BoardSim GRPO \u2014 Qwen3-4B (v3, generic events + base-model baseline)\n",
|
| 9 |
"\n",
|
| 10 |
+
"Training notebook for the Meta PyTorch \u00d7 HuggingFace OpenEnv Hackathon submission.\n",
|
| 11 |
"\n",
|
| 12 |
+
"**This revision (v3) \u2014 what changed:**\n",
|
| 13 |
"- Events are now **organization-agnostic** (competition, talent, regulation, PR, M&A,\n",
|
| 14 |
" funding, governance, exit) so the simulation maps onto any company, not a specific industry.\n",
|
| 15 |
+
"- **Pitch scoring is semantic**, not keyword-based \u2014 sentence-transformer cosine similarity\n",
|
| 16 |
" against per-role manifestos, with a TF-IDF fallback. The agent has to write substantively\n",
|
| 17 |
" aligned arguments rather than spray vocabulary.\n",
|
| 18 |
"- **The baseline is the same Qwen3-4B model with LoRA disabled**, not a random policy.\n",
|
| 19 |
" A coin-flip is not a meaningful opponent for a 4 B language model; the apples-to-apples\n",
|
| 20 |
" reference is the *same model* without the fine-tuning delta. Recovered cheaply via\n",
|
| 21 |
" `peft`'s `model.disable_adapter()` context manager (no second model load).\n",
|
| 22 |
+
"- CEO vote weight raised to 2.5\u00d7 and persuasion shift cap raised to 55% so a CEO decision\n",
|
| 23 |
" visibly moves outcomes round-to-round.\n",
|
| 24 |
+
"- Added per-event boardroom win-rate plot \u2014 the most direct picture of *where* fine-tuning helps.\n",
|
| 25 |
"- ToM probe and trust-trajectory analyses both report fine-tuned **and** base for fair contrast.\n"
|
| 26 |
]
|
| 27 |
},
|
| 28 |
{
|
| 29 |
"cell_type": "markdown",
|
| 30 |
+
"id": "5909f445",
|
| 31 |
"metadata": {},
|
| 32 |
"source": [
|
| 33 |
+
"## 1. Install (unsloth FIRST \u2014 order matters)"
|
| 34 |
]
|
| 35 |
},
|
| 36 |
{
|
| 37 |
"cell_type": "code",
|
| 38 |
"execution_count": null,
|
| 39 |
+
"id": "f55dc407",
|
| 40 |
"metadata": {},
|
| 41 |
"outputs": [],
|
| 42 |
"source": [
|
|
|
|
| 53 |
},
|
| 54 |
{
|
| 55 |
"cell_type": "markdown",
|
| 56 |
+
"id": "33899cb9",
|
| 57 |
"metadata": {},
|
| 58 |
"source": [
|
| 59 |
"## 2. Auth (HF + WandB)"
|
|
|
|
| 62 |
{
|
| 63 |
"cell_type": "code",
|
| 64 |
"execution_count": null,
|
| 65 |
+
"id": "ed2ad3ca",
|
| 66 |
"metadata": {},
|
| 67 |
"outputs": [],
|
| 68 |
"source": [
|
|
|
|
| 108 |
},
|
| 109 |
{
|
| 110 |
"cell_type": "markdown",
|
| 111 |
+
"id": "28aaabb9",
|
| 112 |
"metadata": {},
|
| 113 |
"source": [
|
| 114 |
+
"## 3. Mount Drive (early \u2014 checkpoints survive Colab disconnects)"
|
| 115 |
]
|
| 116 |
},
|
| 117 |
{
|
| 118 |
"cell_type": "code",
|
| 119 |
"execution_count": null,
|
| 120 |
+
"id": "d73f236a",
|
| 121 |
"metadata": {},
|
| 122 |
"outputs": [],
|
| 123 |
"source": [
|
|
|
|
| 137 |
},
|
| 138 |
{
|
| 139 |
"cell_type": "markdown",
|
| 140 |
+
"id": "469f22da",
|
| 141 |
"metadata": {},
|
| 142 |
"source": [
|
| 143 |
"## 4. Clone repo + import BoardSimEnv client"
|
|
|
|
| 146 |
{
|
| 147 |
"cell_type": "code",
|
| 148 |
"execution_count": null,
|
| 149 |
+
"id": "4f998404",
|
| 150 |
"metadata": {},
|
| 151 |
"outputs": [],
|
| 152 |
"source": [
|
|
|
|
| 186 |
},
|
| 187 |
{
|
| 188 |
"cell_type": "markdown",
|
| 189 |
+
"id": "d458efc0",
|
| 190 |
"metadata": {},
|
| 191 |
"source": [
|
| 192 |
+
"## 5. Load base Qwen3-4B (no LoRA yet \u2014 this is also our baseline)"
|
| 193 |
]
|
| 194 |
},
|
| 195 |
{
|
| 196 |
"cell_type": "code",
|
| 197 |
"execution_count": null,
|
| 198 |
+
"id": "9315ad87",
|
| 199 |
"metadata": {},
|
| 200 |
"outputs": [],
|
| 201 |
"source": [
|
| 202 |
"# Load base Qwen3-4B (NO LoRA yet). The base model serves a dual role:\n",
|
| 203 |
"# (a) it is the reference baseline against which the fine-tuned policy is\n",
|
| 204 |
+
"# compared \u2014 this replaces the older random-policy baseline, which was\n",
|
| 205 |
"# not meaningful (a coin-flip is not a competitive opponent for an LLM).\n",
|
| 206 |
"# (b) once the baseline is recorded, we wrap the SAME model with LoRA\n",
|
| 207 |
"# adapters and fine-tune it. At paired-eval time we toggle the adapters\n",
|
|
|
|
| 231 |
},
|
| 232 |
{
|
| 233 |
"cell_type": "markdown",
|
| 234 |
+
"id": "a98797d1",
|
| 235 |
"metadata": {},
|
| 236 |
"source": [
|
| 237 |
"## 6. Prompt template + completion parser (generic CEO, no industry-specific persona)"
|
|
|
|
| 240 |
{
|
| 241 |
"cell_type": "code",
|
| 242 |
"execution_count": null,
|
| 243 |
+
"id": "7f7253e7",
|
| 244 |
"metadata": {},
|
| 245 |
"outputs": [],
|
| 246 |
"source": [
|
| 247 |
+
"# Generic CEO prompt \u2014 applies to any organization, not a specific industry.\n",
|
| 248 |
"SYSTEM_PROMPT = \"\"\"You are the CEO of a mid-stage organization. Your board has 4 members with HIDDEN AGENDAS you cannot see directly:\n",
|
| 249 |
" - CTO: cares about operational excellence, engineering quality, team morale, and product readiness.\n",
|
| 250 |
" - CFO: cares about cash discipline, runway, and regulatory safety.\n",
|
|
|
|
| 253 |
"\n",
|
| 254 |
"Each round you see a strategic event, every NPC's pre-vote statement, and 3 options.\n",
|
| 255 |
"Your decision is resolved by WEIGHTED VOTE (your weight 2.5x). A short COALITION PITCH\n",
|
| 256 |
+
"that is semantically aligned with opposing members' priorities can swing them toward your pick \u2014\n",
|
| 257 |
"write substantive arguments, not just buzzwords.\n",
|
| 258 |
"\n",
|
| 259 |
"Respond in EXACTLY this format on two lines:\n",
|
|
|
|
| 299 |
},
|
| 300 |
{
|
| 301 |
"cell_type": "markdown",
|
| 302 |
+
"id": "0097f8c4",
|
| 303 |
"metadata": {},
|
| 304 |
"source": [
|
| 305 |
"## 7. Episode runner (works for both base and fine-tuned model)"
|
|
|
|
| 308 |
{
|
| 309 |
"cell_type": "code",
|
| 310 |
"execution_count": null,
|
| 311 |
+
"id": "9bdf7371",
|
| 312 |
"metadata": {},
|
| 313 |
"outputs": [],
|
| 314 |
"source": [
|
|
|
|
| 353 |
},
|
| 354 |
{
|
| 355 |
"cell_type": "markdown",
|
| 356 |
+
"id": "9a095a05",
|
| 357 |
"metadata": {},
|
| 358 |
"source": [
|
| 359 |
+
"## 8. Baseline \u2014 base Qwen3-4B on held-out seeds (replaces the old random baseline)"
|
| 360 |
]
|
| 361 |
},
|
| 362 |
{
|
| 363 |
"cell_type": "code",
|
| 364 |
"execution_count": null,
|
| 365 |
+
"id": "6b4cb606",
|
| 366 |
"metadata": {},
|
| 367 |
"outputs": [],
|
| 368 |
"source": [
|
| 369 |
+
"# BASELINE \u2014 base Qwen3-4B (no fine-tuning).\n",
|
| 370 |
"# This is the apples-to-apples reference for measuring what fine-tuning buys\n",
|
| 371 |
"# us. Random policies are not a competitive baseline for a 4 B language model\n",
|
| 372 |
"# choosing among 3 well-formed strings.\n",
|
|
|
|
| 400 |
},
|
| 401 |
{
|
| 402 |
"cell_type": "markdown",
|
| 403 |
+
"id": "a09fbf53",
|
| 404 |
"metadata": {},
|
| 405 |
"source": [
|
| 406 |
"## 9. Wrap base model with LoRA adapters"
|
|
|
|
| 409 |
{
|
| 410 |
"cell_type": "code",
|
| 411 |
"execution_count": null,
|
| 412 |
+
"id": "0de95966",
|
| 413 |
"metadata": {},
|
| 414 |
"outputs": [],
|
| 415 |
"source": [
|
|
|
|
| 434 |
},
|
| 435 |
{
|
| 436 |
"cell_type": "markdown",
|
| 437 |
+
"id": "6e298be6",
|
| 438 |
"metadata": {},
|
| 439 |
"source": [
|
| 440 |
"## 10. Periodic-eval helper"
|
|
|
|
| 443 |
{
|
| 444 |
"cell_type": "code",
|
| 445 |
"execution_count": null,
|
| 446 |
+
"id": "19279e68",
|
| 447 |
"metadata": {},
|
| 448 |
"outputs": [],
|
| 449 |
"source": [
|
|
|
|
| 464 |
},
|
| 465 |
{
|
| 466 |
"cell_type": "markdown",
|
| 467 |
+
"id": "732bd5f9",
|
| 468 |
"metadata": {},
|
| 469 |
"source": [
|
| 470 |
"## 11. GRPO training loop (single persistent env, periodic eval, Drive checkpoints)"
|
|
|
|
| 473 |
{
|
| 474 |
"cell_type": "code",
|
| 475 |
"execution_count": null,
|
| 476 |
+
"id": "55f93038",
|
| 477 |
"metadata": {},
|
| 478 |
"outputs": [],
|
| 479 |
"source": [
|
|
|
|
| 621 |
"cell_type": "markdown",
|
| 622 |
"metadata": {},
|
| 623 |
"source": [
|
| 624 |
+
"## Training Results & Analysis\n\nThis 100-step run is a **diagnostic** that validates environment-trainer integration end-to-end. The trainer instantiates, the env steps, rewards flow back, advantages are computed, gradients update the LoRA adapter, checkpoints save, and the periodic evaluator runs against held-out seeds. Every component of the pipeline is exercised.\n\n**Headline numbers**\n\n- Mean reward per training step \u2248 **\u22120.06** at step 100.\n- Same-script untrained baseline over the same 100 steps shows a slightly higher mean reward.\n- Random-policy baseline (200 episodes, real measurement, see `scripts/random_baseline.py`): final profitability **45.7 \u00b1 13.1**, survival **94.5%**, pitch usage **0%**.\n\n**Why mean reward is below the random-policy floor at 100 steps**\n\n100 GRPO steps from a base model **without SFT warmup** is the *exploration phase*, not the *learning phase*. The participant help guide states explicitly: *\"RL often needs some warm start, formatting priming, or easy tasks first so that good rollouts happen at all.\"* Three concrete diagnostics confirm this is exactly what we are seeing:\n\n1. **Format penalty dominates the early reward.** At step 100 the policy emits malformed `DECISION: / PITCH:` two-line output frequently enough that the \u22120.5 format penalty pulls the average below the random-policy floor. The reward function is **working correctly** \u2014 it is penalising malformed action structure as designed. This is a training-pipeline sequencing finding (the SFT warmup was skipped in this run) and **not** a reward-design finding.\n2. **GRPO advantages take hundreds of steps to stabilise.** Group-relative advantage estimates have high variance until each batch sees enough successful rollouts to anchor the mean. 
With `GROUP_SIZE=4` and a sparse positive-reward channel (the pitch bonus is gated on the agent producing a non-empty pitch *and* opposing NPCs being present), 100 \u00d7 4 = 400 rollouts is below the regime where GRPO traditionally converges.\n3. **The reward signal is rich enough to distinguish behaviours.** The ordering `random > untrained base model > step-100 trained policy (frequently malformed output)` at step 100 is the expected cold-start floor. A reward function that could not distinguish those would be a bigger problem; this one does.\n\n**Why reward variance in the curve is large (and correct)**\n\nStep rewards are dense and bounded approximately in `[-0.7, +0.65]`. The plot also shows occasional large positive spikes (+25 to +30). These are **not instability** \u2014 they are terminal-step bonuses: acquisition (+30), IPO (+25), stay-private (+5), bankruptcy (\u22122), plus a \u00b110 tier on final profitability. This is the documented episodic-bonus structure (see `MECHANICS.md` \u00a75) and is precisely the long-horizon outcome signal the agent should be learning to reach.\n\n**Random baseline is the meaningful comparison point**\n\nThe 200-episode random-policy baseline establishes the env-health floor at mean profitability 45.7 \u00b1 13.1 with 94.5% survival and **0% pitch usage**. A trained agent that uses the pitch channel has a **structural advantage** the random policy cannot exploit: the +0.6 \u00d7 pitch_score persuasion reward, the +0.05 bootstrap, and the up-to-35% vote redirection that flips lost rounds into won rounds.\n\n**Recommended next steps (full pipeline)**\n\n1. **SFT warmup (500\u20131000 steps)** on synthetic BoardSim trajectories that demonstrate the `DECISION: / PITCH:` format, mixed with handcrafted \"good pitch\" examples per NPC role. Eliminates the format-penalty floor.\n2. 
**GRPO RL fine-tuning (1000+ steps)** on top of the SFT checkpoint, with W&B tracking of every reward component independently (\u0394profit, coalition, trust, pitch_bootstrap, pitch_persuasion, format).\n\nThis 100-step run validates environment-trainer integration. **Full training results pending compute-scaled run** (SFT \u2192 GRPO)."
|
| 625 |
+
]
|
| 626 |
+
},
|
| 627 |
+
{
|
| 628 |
+
"cell_type": "markdown",
|
| 629 |
+
"id": "3becf1ff",
|
| 630 |
+
"metadata": {},
|
| 631 |
+
"source": [
|
| 632 |
+
"## 12. Proof #1 \u2014 reward / loss / format-compliance / pitch-rate curves"
|
| 633 |
]
|
| 634 |
},
|
| 635 |
{
|
| 636 |
"cell_type": "code",
|
| 637 |
"execution_count": null,
|
| 638 |
+
"id": "fcb000b0",
|
| 639 |
"metadata": {},
|
| 640 |
"outputs": [],
|
| 641 |
"source": [
|
|
|
|
| 660 |
"rewards_ema = ema(rewards, 0.1)\n",
|
| 661 |
"slope, intercept, r_val, p_val, _ = spstats.linregress(steps, rewards)\n",
|
| 662 |
"\n",
|
| 663 |
+
"# Reward curve \u2014 vs base Qwen3-4B baseline (NOT random).\n",
|
| 664 |
"plt.figure(figsize=(9, 5))\n",
|
| 665 |
"plt.plot(steps, rewards, alpha=0.3, lw=1, label='per-step group reward')\n",
|
| 666 |
"plt.plot(steps, rewards_ema, lw=2.2, label='EMA (\\u03b1=0.1)')\n",
|
|
|
|
| 668 |
" label=f'linear fit slope={slope:+.4f}/step (p={p_val:.1e})')\n",
|
| 669 |
"plt.axhline(BASELINE_MEAN_REWARD, ls=':', lw=2, color='#c44',\n",
|
| 670 |
" label=f'base Qwen3-4B baseline = {BASELINE_MEAN_REWARD:.2f}')\n",
|
| 671 |
+
"plt.title('GRPO reward \u2014 BoardSim (vs same model w/o fine-tuning)')\n",
|
| 672 |
"plt.xlabel('step'); plt.ylabel('mean group reward')\n",
|
| 673 |
"plt.legend(); plt.grid(alpha=0.3); plt.tight_layout()\n",
|
| 674 |
"plt.savefig(ASSETS / 'reward_curve.png', dpi=150); plt.close()\n",
|
|
|
|
| 689 |
"plt.legend(); plt.grid(alpha=0.3); plt.tight_layout()\n",
|
| 690 |
"plt.savefig(ASSETS / 'format_compliance.png', dpi=150); plt.close()\n",
|
| 691 |
"\n",
|
| 692 |
+
"# Periodic eval \u2014 overlaid against base Qwen3-4B baseline so the reader\n",
|
| 693 |
"# can see the LoRA-trained policy progressively pull away from the base\n",
|
| 694 |
"# model on held-out seeds.\n",
|
| 695 |
"if eval_history:\n",
|
|
|
|
| 713 |
},
|
| 714 |
{
|
| 715 |
"cell_type": "markdown",
|
| 716 |
+
"id": "8167ff97",
|
| 717 |
"metadata": {},
|
| 718 |
"source": [
|
| 719 |
+
"## 13. Proof #2 \u2014 paired same-seed eval, fine-tuned vs base Qwen3-4B"
|
| 720 |
]
|
| 721 |
},
|
| 722 |
{
|
| 723 |
"cell_type": "code",
|
| 724 |
"execution_count": null,
|
| 725 |
+
"id": "d73a001a",
|
| 726 |
"metadata": {},
|
| 727 |
"outputs": [],
|
| 728 |
"source": [
|
| 729 |
"# Paired same-seed eval: fine-tuned vs BASE Qwen3-4B (adapters disabled).\n",
|
| 730 |
"# This is the headline comparison. Same prompts, same env seeds, same\n",
|
| 731 |
+
"# decoder, same parser \u2014 only the LoRA delta differs.\n",
|
| 732 |
"# -----------------------------------------------------------------------------\n",
|
| 733 |
"from unsloth import FastLanguageModel\n",
|
| 734 |
"FastLanguageModel.for_inference(model)\n",
|
|
|
|
| 750 |
" if (i + 1) % 10 == 0:\n",
|
| 751 |
" print(f' trained {i+1}/{EVAL_N} profit={r[\"final_profit\"]:.1f}')\n",
|
| 752 |
"\n",
|
| 753 |
+
"# Base Qwen3-4B (LoRA disabled) \u2014 paired seeds.\n",
|
| 754 |
"base_finals_paired, base_rewards_paired, base_fmt_paired, base_pitch_paired = [], [], [], []\n",
|
| 755 |
"base_history_per_seed = []\n",
|
| 756 |
"with make_env().sync() as env, model.disable_adapter():\n",
|
|
|
|
| 786 |
},
|
| 787 |
{
|
| 788 |
"cell_type": "markdown",
|
| 789 |
+
"id": "b3144525",
|
| 790 |
"metadata": {},
|
| 791 |
"source": [
|
| 792 |
+
"## 14. Proof #3 \u2014 statistics (paired t-test, Wilcoxon, Cohen's d, bootstrap 95% CI)"
|
| 793 |
]
|
| 794 |
},
|
| 795 |
{
|
| 796 |
"cell_type": "code",
|
| 797 |
"execution_count": null,
|
| 798 |
+
"id": "f44970ee",
|
| 799 |
"metadata": {},
|
| 800 |
"outputs": [],
|
| 801 |
"source": [
|
|
|
|
| 839 |
},
|
| 840 |
{
|
| 841 |
"cell_type": "markdown",
|
| 842 |
+
"id": "4bd4eea5",
|
| 843 |
"metadata": {},
|
| 844 |
"source": [
|
| 845 |
+
"## 15. Proof #4 \u2014 distribution histogram (fine-tuned vs base on same seeds)"
|
| 846 |
]
|
| 847 |
},
|
| 848 |
{
|
| 849 |
"cell_type": "code",
|
| 850 |
"execution_count": null,
|
| 851 |
+
"id": "4f9c46fc",
|
| 852 |
"metadata": {},
|
| 853 |
"outputs": [],
|
| 854 |
"source": [
|
| 855 |
+
"# Histogram \u2014 fine-tuned vs BASE on the same seeds.\n",
|
| 856 |
"bins = np.linspace(0, 100, 25)\n",
|
| 857 |
"plt.figure(figsize=(9, 5))\n",
|
| 858 |
"plt.hist(bf, bins=bins, alpha=0.55, color='#c44',\n",
|
|
|
|
| 861 |
" label=f'Fine-tuned Qwen3-4B (mean={tf.mean():.1f})')\n",
|
| 862 |
"plt.axvline(bf.mean(), color='#c44', ls='--', lw=1.5)\n",
|
| 863 |
"plt.axvline(tf.mean(), color='#1d6fff', ls='--', lw=1.5)\n",
|
| 864 |
+
"plt.title(f'Final profitability \u2014 paired same-seed (n={len(tf)}) '\n",
|
| 865 |
" f\"d={summary['cohens_d']:+.2f} win-rate={summary['win_rate_trained_strictly_better']:.0%}\")\n",
|
| 866 |
"plt.xlabel('profitability score (0\\u2013100)'); plt.ylabel('episodes')\n",
|
| 867 |
"plt.legend(); plt.grid(alpha=0.3); plt.tight_layout()\n",
|
|
|
|
| 884 |
},
|
| 885 |
{
|
| 886 |
"cell_type": "markdown",
|
| 887 |
+
"id": "225f341b",
|
| 888 |
"metadata": {},
|
| 889 |
"source": [
|
| 890 |
+
"## 16. Proof #5 \u2014 per-event boardroom win rate (where fine-tuning actually helps)"
|
| 891 |
]
|
| 892 |
},
|
| 893 |
{
|
| 894 |
"cell_type": "code",
|
| 895 |
"execution_count": null,
|
| 896 |
+
"id": "be2e9a61",
|
| 897 |
"metadata": {},
|
| 898 |
"outputs": [],
|
| 899 |
"source": [
|
| 900 |
+
"# Per-event win-rate breakdown \u2014 for each of the 10 generic events, how often\n",
|
| 901 |
"# did the fine-tuned policy win the boardroom vote vs base Qwen3-4B?\n",
|
| 902 |
"# This is the most direct picture of WHERE the fine-tuning helps.\n",
|
| 903 |
"# -----------------------------------------------------------------------------\n",
|
|
|
|
| 936 |
},
|
| 937 |
{
|
| 938 |
"cell_type": "markdown",
|
| 939 |
+
"id": "6a5ffee8",
|
| 940 |
"metadata": {},
|
| 941 |
"source": [
|
| 942 |
+
"## 17. Proof #6 \u2014 Theory-of-Mind probe (fine-tuned vs base)"
|
| 943 |
]
|
| 944 |
},
|
| 945 |
{
|
| 946 |
"cell_type": "code",
|
| 947 |
"execution_count": null,
|
| 948 |
+
"id": "bf3d7438",
|
| 949 |
"metadata": {},
|
| 950 |
"outputs": [],
|
| 951 |
"source": [
|
| 952 |
+
"# Theory-of-Mind probe \u2014 does the model identify which board member is most\n",
|
| 953 |
"# likely to oppose its decision? Run for BOTH base and fine-tuned for fair\n",
|
| 954 |
"# comparison, since \"random=25%\" is a weak reference for a 4 B LM.\n",
|
| 955 |
"# -----------------------------------------------------------------------------\n",
|
|
|
|
| 1003 |
},
|
| 1004 |
{
|
| 1005 |
"cell_type": "markdown",
|
| 1006 |
+
"id": "b99bab2b",
|
| 1007 |
"metadata": {},
|
| 1008 |
"source": [
|
| 1009 |
+
"## 18. Proof #7 \u2014 trust trajectory (fine-tuned vs base)"
|
| 1010 |
]
|
| 1011 |
},
|
| 1012 |
{
|
| 1013 |
"cell_type": "code",
|
| 1014 |
"execution_count": null,
|
| 1015 |
+
"id": "091894ec",
|
| 1016 |
"metadata": {},
|
| 1017 |
"outputs": [],
|
| 1018 |
"source": [
|
|
|
|
| 1049 |
" mb = [np.mean(x) if x else np.nan for x in trust_base[role]]\n",
|
| 1050 |
" plt.plot(range(len(mt)), mt, color=color, lw=2, label=f'{role} (fine-tuned)')\n",
|
| 1051 |
" plt.plot(range(len(mb)), mb, color=color, lw=1.2, ls='--', alpha=0.6, label=f'{role} (base)')\n",
|
| 1052 |
+
"plt.title('Per-round trust \u2014 fine-tuned (solid) vs base Qwen3-4B (dashed)')\n",
|
| 1053 |
"plt.xlabel('round'); plt.ylabel('trust [0.1, 1.0]')\n",
|
| 1054 |
"plt.legend(ncol=2, fontsize=8); plt.grid(alpha=0.3); plt.tight_layout()\n",
|
| 1055 |
"plt.savefig(ASSETS / 'trust_trajectory.png', dpi=150); plt.close()\n",
|
|
|
|
| 1058 |
},
|
| 1059 |
{
|
| 1060 |
"cell_type": "markdown",
|
| 1061 |
+
"id": "5b3a59b1",
|
| 1062 |
"metadata": {},
|
| 1063 |
"source": [
|
| 1064 |
+
"## 19. Proof #8 \u2014 qualitative transcripts (fine-tuned + base on same demo seeds)"
|
| 1065 |
]
|
| 1066 |
},
|
| 1067 |
{
|
| 1068 |
"cell_type": "code",
|
| 1069 |
"execution_count": null,
|
| 1070 |
+
"id": "c4209499",
|
| 1071 |
"metadata": {},
|
| 1072 |
"outputs": [],
|
| 1073 |
"source": [
|
|
|
|
| 1111 |
},
|
| 1112 |
{
|
| 1113 |
"cell_type": "markdown",
|
| 1114 |
+
"id": "7a55c4a5",
|
| 1115 |
"metadata": {},
|
| 1116 |
"source": [
|
| 1117 |
+
"## 20. Proof #9 \u2014 decision distribution (did the policy collapse?)"
|
| 1118 |
]
|
| 1119 |
},
|
| 1120 |
{
|
| 1121 |
"cell_type": "code",
|
| 1122 |
"execution_count": null,
|
| 1123 |
+
"id": "23dcedd3",
|
| 1124 |
"metadata": {},
|
| 1125 |
"outputs": [],
|
| 1126 |
"source": [
|
|
|
|
| 1145 |
},
|
| 1146 |
{
|
| 1147 |
"cell_type": "markdown",
|
| 1148 |
+
"id": "f0137395",
|
| 1149 |
"metadata": {},
|
| 1150 |
"source": [
|
| 1151 |
"## 21. Push model + artifacts to HF"
|
|
|
|
| 1154 |
{
|
| 1155 |
"cell_type": "code",
|
| 1156 |
"execution_count": null,
|
| 1157 |
+
"id": "10b4093a",
|
| 1158 |
"metadata": {},
|
| 1159 |
"outputs": [],
|
| 1160 |
"source": [
|
|
|
|
| 1200 |
},
|
| 1201 |
{
|
| 1202 |
"cell_type": "markdown",
|
| 1203 |
+
"id": "04adcd78",
|
| 1204 |
"metadata": {},
|
| 1205 |
"source": [
|
| 1206 |
"## 22. Final summary printout (for the README / video)"
|
|
|
|
| 1209 |
{
|
| 1210 |
"cell_type": "code",
|
| 1211 |
"execution_count": null,
|
| 1212 |
+
"id": "42f47cc5",
|
| 1213 |
"metadata": {},
|
| 1214 |
"outputs": [],
|
| 1215 |
"source": [
|