{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# AutoDataLab++ — GRPO + RLVR for the Chief of Staff\n", "\n", "Trains a small instruct LLM to act as the **Chief of Staff** in our OpenEnv multi-agent environment. Inspired by the TRL OpenEnv Wordle GRPO notebook ([source](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/openenv_wordle_grpo.ipynb)), but adapted to a multi-turn orchestration task with a verifiable terminal grader (RLVR).\n", "\n", "**Prereqs**\n", "1. Deploy `autodatalab-plus/` to a Hugging Face Space (see `SPACE_README.md`).\n", "2. Runtime → Change runtime type → **T4 GPU** (free) for 1.5B, or **A100** for 3B/7B.\n", "3. Paste your **HF token (write)** in the cell marked `### >>> PASTE HF TOKEN HERE <<<`.\n", "4. Paste your Space URL in the cell marked `### >>> PASTE BASE_URL HERE <<<`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Install (one-time, ~3 min)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%pip install -q --upgrade \"transformers>=4.45\" \"trl>=0.12\" \"peft>=0.13\" \"accelerate>=1.0\" \"bitsandbytes>=0.44\" \"datasets>=3.0\" \"requests>=2.32\" matplotlib" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Auth & config" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "### >>> PASTE HF TOKEN HERE <<<\n", "HF_TOKEN = \"hf_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\"\n", "os.environ[\"HF_TOKEN\"] = HF_TOKEN\n", "os.environ[\"HUGGING_FACE_HUB_TOKEN\"] = HF_TOKEN\n", "\n", "from huggingface_hub import login\n", "login(token=HF_TOKEN, add_to_git_credential=False)\n", "\n", "### >>> PASTE BASE_URL HERE <<<\n", "BASE_URL = \"https://-autodatalab-plus.hf.space\"\n", "\n", "# Pick model. 
Defaults are GPU-aware:\n", "#   T4 (16 GB)   -> Qwen/Qwen2.5-1.5B-Instruct (recommended, fastest, lowest token cost)\n", "#   A100 (40 GB) -> Qwen/Qwen2.5-3B-Instruct or Qwen/Qwen2.5-7B-Instruct\n", "MODEL_ID = \"Qwen/Qwen2.5-1.5B-Instruct\"\n", "USE_4BIT = True  # QLoRA; flip to False on A100 if you have headroom\n", "\n", "import requests\n", "r = requests.get(f\"{BASE_URL}/health\", timeout=10)\n", "r.raise_for_status()\n", "print(\"Space healthy:\", r.json())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. OpenEnv HTTP client (thin wrapper around `/reset` + `/step`)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import json, requests\n", "from typing import Any\n", "\n", "class CoSEnvClient:\n", "    def __init__(self, base_url: str, timeout: float = 30.0):\n", "        self.base_url = base_url.rstrip(\"/\")\n", "        self.timeout = timeout\n", "        self.episode_id: str | None = None\n", "\n", "    def reset(self, task: str = \"easy_brief\", use_rag: bool = False) -> dict[str, Any]:\n", "        r = requests.post(f\"{self.base_url}/reset\", json={\"task\": task, \"use_rag\": use_rag}, timeout=self.timeout)\n", "        r.raise_for_status()\n", "        obs = r.json()\n", "        self.episode_id = obs[\"episode_id\"]\n", "        return obs\n", "\n", "    def step(self, action: dict[str, Any]) -> dict[str, Any]:\n", "        assert self.episode_id is not None, \"call reset() first\"\n", "        body = {\"episode_id\": self.episode_id, \"action\": action}\n", "        r = requests.post(f\"{self.base_url}/step\", json=body, timeout=self.timeout)\n", "        r.raise_for_status()\n", "        return r.json()\n", "\n", "# smoke test\n", "c = CoSEnvClient(BASE_URL)\n", "obs0 = c.reset(\"easy_brief\")\n", "print(\"reset ok | task:\", obs0[\"task_name\"], \"| max_steps:\", obs0[\"max_steps\"])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. Prompt + action parsing\n", "\n", "The CoS must emit ONE JSON action per turn. Keep prompts short to stay token-cheap." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "SYSTEM_PROMPT = (\n", "    \"You are the Chief of Staff in AutoDataLab++. You orchestrate four specialists: \"\n", "    \"analyst, finance, strategy, hr. Reply with STRICT JSON only.\\n\"\n", "    \"Schema: {\\\"action_type\\\": one of [consult, ask, summarize, submit, noop], \"\n", "    \"\\\"expert_id\\\": one of [analyst, finance, hr, strategy] or null}.\\n\"\n", "    \"Rules: consult each required expert at most once -> summarize -> submit.\"\n", ")\n", "\n", "def render_obs(obs: dict) -> str:\n", "    return (\n", "        f\"task={obs['task_name']} step={obs['step_count']}/{obs['max_steps']} \"\n", "        f\"rag={obs['rag_enabled']} \"\n", "        f\"consulted={obs['consulted_experts']} \"\n", "        f\"brief_done={obs.get('current_brief') is not None} \"\n", "        f\"available={obs['available_experts']}\"\n", "    )\n", "\n", "import json, re\n", "_JSON_RE = re.compile(r\"\\{[^{}]*\\}\", re.S)\n", "VALID_ACTIONS = {\"consult\",\"ask\",\"summarize\",\"submit\",\"noop\"}\n", "VALID_EXPERTS = {\"analyst\",\"finance\",\"hr\",\"strategy\"}\n", "\n", "def parse_action(text: str) -> dict:\n", "    m = _JSON_RE.search(text or \"\")\n", "    if not m:\n", "        return {\"action_type\": \"noop\"}\n", "    try:\n", "        a = json.loads(m.group(0))\n", "    except Exception:\n", "        return {\"action_type\": \"noop\"}\n", "    at = a.get(\"action_type\")\n", "    if at not in VALID_ACTIONS:\n", "        return {\"action_type\": \"noop\"}\n", "    eid = a.get(\"expert_id\")\n", "    if eid is not None and eid not in VALID_EXPERTS:\n", "        eid = None\n", "    return {\"action_type\": at, \"expert_id\": eid}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 5. 
Load model (QLoRA-ready)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig\n", "from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training\n", "\n", "tok = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)\n", "if tok.pad_token is None:\n", "    tok.pad_token = tok.eos_token\n", "\n", "bnb = None\n", "if USE_4BIT:\n", "    bnb = BitsAndBytesConfig(\n", "        load_in_4bit=True,\n", "        bnb_4bit_quant_type=\"nf4\",\n", "        bnb_4bit_compute_dtype=torch.bfloat16,\n", "        bnb_4bit_use_double_quant=True,\n", "    )\n", "\n", "model = AutoModelForCausalLM.from_pretrained(\n", "    MODEL_ID, token=HF_TOKEN, device_map=\"auto\",\n", "    quantization_config=bnb, torch_dtype=torch.bfloat16,\n", ")\n", "if USE_4BIT:\n", "    model = prepare_model_for_kbit_training(model)\n", "\n", "lora = LoraConfig(\n", "    r=16, lora_alpha=32, lora_dropout=0.05, bias=\"none\", task_type=\"CAUSAL_LM\",\n", "    target_modules=[\"q_proj\",\"k_proj\",\"v_proj\",\"o_proj\"],\n", ")\n", "model = get_peft_model(model, lora)\n", "model.print_trainable_parameters()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 6. RLVR rollout: env returns the verifiable reward\n", "\n", "For each completion we roll out a full episode against the Space, read the env's `terminal_grader_score`, and add a small fixed shaping term (+0.05 or -0.05) based on the first action. This is the **GRPO reward function**.\n", "\n", "We use `GRPOTrainer` with **`num_generations=4`** group samples per prompt; the env is queried once per sample. To keep token spend low we cap completions at **48** tokens via `max_completion_length` (a CoS JSON action is ~20 tokens)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from datasets import Dataset\n", "\n", "TASKS = [\"easy_brief\", \"medium_brief\", \"hard_brief\", \"expert_brief\"]\n", "\n", "def make_prompt(task: str, use_rag: bool) -> dict:\n", "    # Reset once only to render the initial observation into a chat prompt.\n", "    client = CoSEnvClient(BASE_URL)\n", "    obs = client.reset(task=task, use_rag=use_rag)\n", "    user = render_obs(obs)\n", "    msgs = [{\"role\":\"system\",\"content\":SYSTEM_PROMPT},{\"role\":\"user\",\"content\":user}]\n", "    text = tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)\n", "    # The episode opened here is not reused: the reward fn resets a fresh episode\n", "    # per rollout (simpler + safer), so only task and use_rag travel with the prompt.\n", "    return {\"prompt\": text, \"task\": task, \"use_rag\": use_rag}\n", "\n", "rows = []\n", "for _ in range(48):  # 48 repeats * 4 tasks = 192 prompts; * 4 generations = 768 rollouts per epoch\n", "    for t in TASKS:\n", "        rows.append(make_prompt(t, use_rag=False))\n", "train_ds = Dataset.from_list(rows)\n", "print(\"prompts:\", len(train_ds))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Reward function: each completion is the model's FIRST action. To save tokens we\n", "# do not re-prompt the model for later steps; after step 1 a deterministic\n", "# continuation (oracle-style: consult remaining required experts -> summarize ->\n", "# submit) finishes the episode, so the reward purely measures whether the model's\n", "# FIRST action was a sensible orchestration choice. This is the same trick the\n", "# Wordle notebook uses (reward credits the policy's move, environment provides\n", "# verifiable scoring).\n", "\n", 
"REQUIRED = {\n", "    \"easy_brief\":   [\"analyst\",\"finance\"],\n", "    \"medium_brief\": [\"analyst\",\"finance\",\"strategy\"],\n", "    \"hard_brief\":   [\"analyst\",\"finance\",\"strategy\",\"hr\"],\n", "    \"expert_brief\": [\"analyst\",\"finance\",\"strategy\",\"hr\"],\n", "}\n", "\n", "def deterministic_continuation(client: CoSEnvClient, obs: dict, task: str) -> float:\n", "    # Consult any still-missing required experts, then summarize+submit.\n", "    while not obs[\"done\"] and obs[\"step_count\"] < obs[\"max_steps\"]:\n", "        missing = [e for e in REQUIRED[task] if e not in obs[\"consulted_experts\"]]\n", "        if missing:\n", "            act = {\"action_type\":\"consult\",\"expert_id\":missing[0]}\n", "        elif obs.get(\"current_brief\") is None:\n", "            act = {\"action_type\":\"summarize\"}\n", "        else:\n", "            act = {\"action_type\":\"submit\"}\n", "        obs = client.step(act)\n", "    return float(obs.get(\"terminal_grader_score\") or 0.0)\n", "\n", "def reward_fn(prompts, completions, task, use_rag, **kw):\n", "    rewards = []\n", "    for comp, t, rag in zip(completions, task, use_rag):\n", "        # 1) parse model's first action\n", "        action = parse_action(comp)\n", "        # 2) replay episode against the Space\n", "        client = CoSEnvClient(BASE_URL)\n", "        obs = client.reset(task=t, use_rag=bool(rag))\n", "        try:\n", "            obs = client.step(action)\n", "            terminal = deterministic_continuation(client, obs, t)\n", "        except Exception:\n", "            terminal = 0.0\n", "        # tiny shaping: reward a required consult and penalize obviously-bad first\n", "        # moves so gradients are non-zero early\n", "        shaping = 0.0\n", "        if action[\"action_type\"] == \"consult\" and action[\"expert_id\"] in REQUIRED[t]:\n", "            shaping += 0.05\n", "        if action[\"action_type\"] in {\"submit\",\"summarize\"}:\n", "            shaping -= 0.05  # never optimal as first move\n", "        rewards.append(terminal + shaping)\n", "    return rewards" ] }, 
{ "cell_type": "markdown", "metadata": {}, "source": [ "## 7. GRPO training" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from trl import GRPOConfig, GRPOTrainer\n", "\n", "cfg = GRPOConfig(\n", "    output_dir=\"./cos_grpo_out\",\n", "    learning_rate=5e-6,\n", "    per_device_train_batch_size=1,\n", "    gradient_accumulation_steps=4,\n", "    num_generations=4,  # group size for GRPO\n", "    # Note: trl.GRPOConfig has no max_prompt_length; keep prompts short in the dataset (render_obs).\n", "    max_completion_length=48,  # JSON action is short -> token-cheap\n", "    num_train_epochs=1,\n", "    logging_steps=2,\n", "    save_steps=50,\n", "    bf16=True,\n", "    gradient_checkpointing=True,\n", "    report_to=\"none\",\n", "    temperature=0.7,\n", "    beta=0.04,  # KL to reference\n", ")\n", "\n", "trainer = GRPOTrainer(\n", "    model=model,\n", "    processing_class=tok,\n", "    reward_funcs=[reward_fn],\n", "    args=cfg,\n", "    train_dataset=train_ds,\n", ")\n", "trainer.train()\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 8. 
Reward curve + before/after eval" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "hist = [h for h in trainer.state.log_history if \"reward\" in h]\n", "if hist:\n", "    xs = [h[\"step\"] for h in hist]\n", "    ys = [h[\"reward\"] for h in hist]\n", "    plt.figure(figsize=(8,4))\n", "    plt.plot(xs, ys, label=\"mean reward\")\n", "    plt.xlabel(\"step\"); plt.ylabel(\"reward\"); plt.title(\"GRPO reward curve\"); plt.grid(alpha=.3); plt.legend()\n", "    plt.savefig(\"reward_curve.png\", dpi=130, bbox_inches=\"tight\")\n", "    plt.show()\n", "    print(\"saved reward_curve.png\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@torch.no_grad()\n", "def eval_policy(n_per_task: int = 5) -> dict:\n", "    out = {}\n", "    for t in TASKS:\n", "        scores = []\n", "        for _ in range(n_per_task):\n", "            client = CoSEnvClient(BASE_URL)\n", "            obs = client.reset(task=t, use_rag=False)\n", "            msgs = [{\"role\":\"system\",\"content\":SYSTEM_PROMPT},{\"role\":\"user\",\"content\":render_obs(obs)}]\n", "            text = tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)\n", "            ids = tok(text, return_tensors=\"pt\").to(model.device)\n", "            gen = model.generate(**ids, max_new_tokens=48, do_sample=False, pad_token_id=tok.pad_token_id)\n", "            comp = tok.decode(gen[0, ids.input_ids.shape[1]:], skip_special_tokens=True)\n", "            action = parse_action(comp)\n", "            try:\n", "                obs = client.step(action)\n", "                term = deterministic_continuation(client, obs, t)\n", "            except Exception:\n", "                term = 0.0\n", "            scores.append(term)\n", "        out[t] = round(sum(scores)/len(scores), 4)\n", "    return out\n", "\n", "print(\"after-training eval:\", eval_policy(5))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 9. 
Save & (optionally) push the LoRA adapter to the Hub" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "trainer.save_model(\"./cos_grpo_out/final\")\n", "# trainer.push_to_hub(\"/autodatalab-cos-qwen2_5-1_5b-grpo\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.11" } }, "nbformat": 4, "nbformat_minor": 5 }