Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files
- Dockerfile.notebook +5 -7
- README.md +10 -2
- generate_episodes.py +9 -4
- notebooks/crisisinbox_grpo_simple.ipynb +301 -20
- server/app.py +11 -1
- server/crisis_inbox_environment.py +3 -0
- training/crisisinbox_training.py +44 -18
Dockerfile.notebook
CHANGED
|
@@ -13,14 +13,12 @@ RUN pip install --no-cache-dir \
|
|
| 13 |
peft \
|
| 14 |
huggingface_hub
|
| 15 |
|
| 16 |
-
# Copy
|
| 17 |
-
COPY
|
| 18 |
-
COPY generate_episodes.py .
|
| 19 |
-
COPY models.py .
|
| 20 |
-
COPY messages.py .
|
| 21 |
-
COPY drift_events.py .
|
| 22 |
COPY notebooks/ ./notebooks/
|
| 23 |
|
| 24 |
EXPOSE 8888
|
| 25 |
|
| 26 |
-
|
|
|
|
|
|
|
|
|
| 13 |
peft \
|
| 14 |
huggingface_hub
|
| 15 |
|
| 16 |
+
# Copy source files needed for training
|
| 17 |
+
COPY generate_episodes.py models.py messages.py drift_events.py ./
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
COPY notebooks/ ./notebooks/
|
| 19 |
|
| 20 |
EXPOSE 8888
|
| 21 |
|
| 22 |
+
# Download training data at startup (build env may block outbound network), then launch Jupyter
|
| 23 |
+
CMD python -c "from huggingface_hub import hf_hub_download; hf_hub_download(repo_id='eptan/crisis-inbox-episodes', filename='episodes.json', repo_type='dataset', local_dir='/app'); import os; os.rename('/app/episodes.json', '/app/.episodes.json')" && \
|
| 24 |
+
jupyter notebook --ip=0.0.0.0 --port=8888 --no-browser --allow-root --NotebookApp.token=''
|
README.md
CHANGED
|
@@ -166,8 +166,8 @@ crisis-inbox/
|
|
| 166 |
│ ├── app.py # FastAPI app with MCPAction workaround
|
| 167 |
│ └── Dockerfile # HF Spaces deployment
|
| 168 |
├── notebooks/
|
| 169 |
-
│ └──
|
| 170 |
-
├── episodes.json
|
| 171 |
├── generate_episodes.py # Episode generator script
|
| 172 |
├── pyproject.toml # Package config
|
| 173 |
├── openenv.yaml # OpenEnv environment spec
|
|
@@ -182,6 +182,14 @@ crisis-inbox/
|
|
| 182 |
- **Training:** Unsloth GRPO via Google Colab
|
| 183 |
- **Model:** Qwen2.5-0.5B-Instruct
|
| 184 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
## Team
|
| 186 |
|
| 187 |
Built at the OpenEnv Hackathon @ Shack15, SF — March 7-8, 2026
|
|
|
|
| 166 |
│ ├── app.py # FastAPI app with MCPAction workaround
|
| 167 |
│ └── Dockerfile # HF Spaces deployment
|
| 168 |
├── notebooks/
|
| 169 |
+
│ └── crisisinbox_grpo_simple.ipynb # GRPO training notebook (Colab)
|
| 170 |
+
├── .episodes.json # Pre-generated training episodes (gitignored)
|
| 171 |
├── generate_episodes.py # Episode generator script
|
| 172 |
├── pyproject.toml # Package config
|
| 173 |
├── openenv.yaml # OpenEnv environment spec
|
|
|
|
| 182 |
- **Training:** Unsloth GRPO via Google Colab
|
| 183 |
- **Model:** Qwen2.5-0.5B-Instruct
|
| 184 |
|
| 185 |
+
### GRPO training (Colab)
|
| 186 |
+
|
| 187 |
+
Open the notebook with the latest fixes (context length, reward signature, left-padding, batch size) in Google Colab (T4 GPU runtime):
|
| 188 |
+
|
| 189 |
+
**[Open in Colab](https://colab.research.google.com/github/eptan/crisis-inbox/blob/main/notebooks/crisisinbox_grpo_simple.ipynb)**
|
| 190 |
+
|
| 191 |
+
Push your local changes to the `main` branch so the link above serves the updated notebook.
|
| 192 |
+
|
| 193 |
## Team
|
| 194 |
|
| 195 |
Built at the OpenEnv Hackathon @ Shack15, SF — March 7-8, 2026
|
generate_episodes.py
CHANGED
|
@@ -270,7 +270,12 @@ def generate_episodes(num_episodes: int = 50, start_seed: int = 1000) -> list:
|
|
| 270 |
seed = start_seed + i
|
| 271 |
print(f" Episode {i + 1}/{num_episodes} (seed={seed})...", end=" ")
|
| 272 |
episode = build_episode(seed)
|
| 273 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 274 |
n_msg = episode["total_messages"]
|
| 275 |
drifts = ", ".join(episode["drift_events"])
|
| 276 |
print(f"{n_msg} messages, {n_dp} decision points, drifts: [{drifts}]")
|
|
@@ -282,7 +287,7 @@ def save_episodes(episodes: list, filename: str = "episodes.json"):
|
|
| 282 |
"""Save episodes to JSON file."""
|
| 283 |
with open(filename, "w") as f:
|
| 284 |
json.dump(episodes, f, indent=2)
|
| 285 |
-
total_prompts = sum(len(ep
|
| 286 |
print(f"\nSaved {len(episodes)} episodes ({total_prompts} training prompts) to {filename}")
|
| 287 |
|
| 288 |
|
|
@@ -292,7 +297,7 @@ if __name__ == "__main__":
|
|
| 292 |
parser = argparse.ArgumentParser(description="Generate CrisisInbox training episodes")
|
| 293 |
parser.add_argument("-n", "--num-episodes", type=int, default=50, help="Number of episodes")
|
| 294 |
parser.add_argument("-s", "--start-seed", type=int, default=1000, help="Starting seed")
|
| 295 |
-
parser.add_argument("-o", "--output", type=str, default="episodes.json", help="Output file")
|
| 296 |
parser.add_argument("--sample", type=int, default=5, help="Also save N sample episodes")
|
| 297 |
args = parser.parse_args()
|
| 298 |
|
|
@@ -301,5 +306,5 @@ if __name__ == "__main__":
|
|
| 301 |
save_episodes(episodes, args.output)
|
| 302 |
|
| 303 |
if args.sample > 0:
|
| 304 |
-
sample_file = "sample_episodes.json"
|
| 305 |
save_episodes(episodes[:args.sample], sample_file)
|
|
|
|
| 270 |
seed = start_seed + i
|
| 271 |
print(f" Episode {i + 1}/{num_episodes} (seed={seed})...", end=" ")
|
| 272 |
episode = build_episode(seed)
|
| 273 |
+
# Some episodes may not have decision points; skip them.
|
| 274 |
+
decision_points = episode.get("decision_points")
|
| 275 |
+
if not decision_points:
|
| 276 |
+
print("skipped (no decision_points)")
|
| 277 |
+
continue
|
| 278 |
+
n_dp = len(decision_points)
|
| 279 |
n_msg = episode["total_messages"]
|
| 280 |
drifts = ", ".join(episode["drift_events"])
|
| 281 |
print(f"{n_msg} messages, {n_dp} decision points, drifts: [{drifts}]")
|
|
|
|
| 287 |
"""Save episodes to JSON file."""
|
| 288 |
with open(filename, "w") as f:
|
| 289 |
json.dump(episodes, f, indent=2)
|
| 290 |
+
total_prompts = sum(len(ep.get("decision_points", [])) for ep in episodes)
|
| 291 |
print(f"\nSaved {len(episodes)} episodes ({total_prompts} training prompts) to {filename}")
|
| 292 |
|
| 293 |
|
|
|
|
| 297 |
parser = argparse.ArgumentParser(description="Generate CrisisInbox training episodes")
|
| 298 |
parser.add_argument("-n", "--num-episodes", type=int, default=50, help="Number of episodes")
|
| 299 |
parser.add_argument("-s", "--start-seed", type=int, default=1000, help="Starting seed")
|
| 300 |
+
parser.add_argument("-o", "--output", type=str, default=".episodes.json", help="Output file")
|
| 301 |
parser.add_argument("--sample", type=int, default=5, help="Also save N sample episodes")
|
| 302 |
args = parser.parse_args()
|
| 303 |
|
|
|
|
| 306 |
save_episodes(episodes, args.output)
|
| 307 |
|
| 308 |
if args.sample > 0:
|
| 309 |
+
sample_file = ".sample_episodes.json"
|
| 310 |
save_episodes(episodes[:args.sample], sample_file)
|
notebooks/crisisinbox_grpo_simple.ipynb
CHANGED
|
@@ -3,95 +3,376 @@
|
|
| 3 |
{
|
| 4 |
"cell_type": "markdown",
|
| 5 |
"metadata": {},
|
| 6 |
-
"source":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
},
|
| 8 |
{
|
| 9 |
"cell_type": "code",
|
| 10 |
-
"execution_count": null,
|
| 11 |
"metadata": {},
|
| 12 |
-
"source": "# Install dependencies\n!pip install unsloth trl transformers datasets accelerate peft -q\n!pip install huggingface_hub -q\
|
|
|
|
| 13 |
"outputs": []
|
| 14 |
},
|
| 15 |
{
|
| 16 |
"cell_type": "code",
|
| 17 |
-
"source": "# === GPU PROFILE ===\n# Change this one variable to switch between T4 and H100 configs.\n# Everything else adapts automatically.\n\nimport torch\n\nif torch.cuda.is_available():\n vram_gb = torch.cuda.get_device_properties(0).total_mem / 1e9\n gpu_name = torch.cuda.get_device_name(0)\n print(f\"GPU: {gpu_name} ({vram_gb:.0f} GB)\")\nelse:\n vram_gb = 0\n print(\"No GPU detected — config will default to smallest profile\")\n\n# Auto-select profile based on VRAM, or override manually\nif vram_gb >= 40: # H100, A100\n PROFILE = \"h100\"\n MODEL_NAME = \"unsloth/Qwen2.5-3B-Instruct\"\n MAX_SEQ_LENGTH = 4096\n MAX_PROMPT_LENGTH = 3584\n MAX_COMPLETION_LENGTH = 512\n BATCH_SIZE = 4\n GRAD_ACCUM = 2\n NUM_GENERATIONS = 8\nelif vram_gb >= 14: # T4, L4\n PROFILE = \"t4\"\n MODEL_NAME = \"unsloth/Qwen2.5-0.5B-Instruct\"\n MAX_SEQ_LENGTH = 2048\n MAX_PROMPT_LENGTH = 1792\n MAX_COMPLETION_LENGTH = 256\n BATCH_SIZE = 2\n GRAD_ACCUM = 4\n NUM_GENERATIONS = 4\nelse:\n PROFILE = \"cpu\"\n MODEL_NAME = \"unsloth/Qwen2.5-0.5B-Instruct\"\n MAX_SEQ_LENGTH = 2048\n MAX_PROMPT_LENGTH = 1792\n MAX_COMPLETION_LENGTH = 256\n BATCH_SIZE = 1\n GRAD_ACCUM = 8\n NUM_GENERATIONS = 2\n\nprint(f\"Profile: {PROFILE} | Model: {MODEL_NAME} | Context: {MAX_SEQ_LENGTH}\")",
|
| 18 |
"metadata": {},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
"execution_count": null,
|
| 20 |
"outputs": []
|
| 21 |
},
|
| 22 |
{
|
| 23 |
"cell_type": "code",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
"execution_count": null,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
"metadata": {},
|
| 26 |
-
"source": "import json\nimport re\nimport random\nfrom datasets import Dataset\n\n# Load episodes\
|
|
|
|
| 27 |
"outputs": []
|
| 28 |
},
|
| 29 |
{
|
| 30 |
"cell_type": "markdown",
|
| 31 |
-
"
|
| 32 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
},
|
| 34 |
{
|
| 35 |
"cell_type": "code",
|
| 36 |
-
"source": "def score_action(completion: str, prompt_data: dict) -> float:\n \"\"\"\n Score a model completion against the inbox state.\n \n The model should output: respond_to_message(msg_id, \"response text\")\n We parse the message_id and response, then score based on the reward function.\n \"\"\"\n messages = prompt_data[\"messages\"]\n hour = prompt_data[\"hour\"]\n superseded = prompt_data.get(\"superseded\", {})\n \n # Parse the model output for message_id\n msg_id = None\n response_text = \"\"\n \n # Try to parse respond_to_message(msg_id, response)\n match = re.search(r'respond_to_message\\s*\\(\\s*[\"\\']?(msg_\\d+)[\"\\']?\\s*,\\s*[\"\\'](.+?)[\"\\']', completion, re.DOTALL)\n if match:\n msg_id = match.group(1)\n response_text = match.group(2)\n else:\n # Try simpler format: just a message ID mentioned\n id_match = re.search(r'(msg_\\d+)', completion)\n if id_match:\n msg_id = id_match.group(1)\n # No explicit response text — penalize via quality check below\n response_text = \"\"\n \n if not msg_id:\n return -1.0 # couldn't parse any action\n \n # Find the message in the inbox\n target_msg = None\n for msg in messages:\n if msg[\"id\"] == msg_id:\n target_msg = msg\n break\n \n if target_msg is None:\n return -0.5 # referenced a message not in inbox\n \n # Base reward by urgency\n urgency_rewards = {\"critical\": 10.0, \"high\": 5.0, \"medium\": 3.0, \"low\": 1.0}\n reward = urgency_rewards.get(target_msg[\"urgency\"], 1.0)\n \n # Deadline timing\n deadline = target_msg.get(\"deadline_hours\")\n if deadline is not None:\n if hour <= deadline:\n time_remaining_frac = (deadline - hour) / max(deadline, 1.0)\n reward *= 1.0 + 0.5 * time_remaining_frac\n else:\n reward *= 0.25 # late penalty\n \n # Response quality\n if len(response_text.strip()) < 10:\n reward *= 0.5\n \n # Drift adaptation bonus\n if target_msg.get(\"drift_flag\"):\n reward *= 1.5\n \n # Stale info penalty\n if target_msg[\"id\"] in superseded:\n reward *= 0.5\n \n # Penalize choosing 
low-urgency when unhandled critical messages exist\n unhandled_critical = any(\n m[\"urgency\"] == \"critical\" and not m.get(\"handled\") and not m.get(\"superseded\")\n for m in messages\n )\n if unhandled_critical and target_msg[\"urgency\"] in (\"low\", \"medium\"):\n reward *= 0.3\n \n return round(reward, 2)\n\n\n# Test the reward function\ntest_data = prompts[0]\nprint(\"Testing reward function on first decision point:\")\nprint(f\" Hour: {test_data['hour']}, Messages: {test_data['visible_count']}\")\n\n# Simulate good action (pick critical message)\ncritical_msgs = [m for m in test_data[\"messages\"] if m[\"urgency\"] == \"critical\"]\nif critical_msgs:\n good_action = f'respond_to_message(\"{critical_msgs[0][\"id\"]}\", \"Acknowledged. Evacuating immediately with documents and medication.\")'\n good_score = score_action(good_action, test_data)\n print(f\" Good action (critical msg): {good_score:.2f} pts\")\n\n# Simulate bad action (pick low-urgency message)\nlow_msgs = [m for m in test_data[\"messages\"] if m[\"urgency\"] == \"low\"]\nif low_msgs:\n bad_action = f'respond_to_message(\"{low_msgs[0][\"id\"]}\", \"ok\")'\n bad_score = score_action(bad_action, test_data)\n print(f\" Bad action (low msg, short response): {bad_score:.2f} pts\")\n\n# Simulate unparseable action\njunk_score = score_action(\"I think we should do something\", test_data)\nprint(f\" Unparseable action: {junk_score:.2f} pts\")",
|
| 37 |
"metadata": {},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
"execution_count": null,
|
| 39 |
"outputs": []
|
| 40 |
},
|
| 41 |
{
|
| 42 |
"cell_type": "markdown",
|
| 43 |
-
"
|
| 44 |
-
"
|
|
|
|
|
|
|
| 45 |
},
|
| 46 |
{
|
| 47 |
"cell_type": "code",
|
| 48 |
-
"source": "from unsloth import FastLanguageModel\n\nmodel, tokenizer = FastLanguageModel.from_pretrained(\n model_name=MODEL_NAME,\n max_seq_length=MAX_SEQ_LENGTH,\n load_in_4bit=True,\n)\n\n# Add LoRA adapters — bigger r for bigger models\nlora_r = 32 if PROFILE == \"h100\" else 16\n\nmodel = FastLanguageModel.get_peft_model(\n model,\n r=lora_r,\n target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n \"gate_proj\", \"up_proj\", \"down_proj\"],\n lora_alpha=lora_r,\n lora_dropout=0,\n bias=\"none\",\n use_gradient_checkpointing=\"unsloth\",\n)\nprint(f\"Model loaded: {MODEL_NAME} | LoRA r={lora_r} | ctx={MAX_SEQ_LENGTH}\")",
|
| 49 |
"metadata": {},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
"execution_count": null,
|
| 51 |
"outputs": []
|
| 52 |
},
|
| 53 |
{
|
| 54 |
"cell_type": "code",
|
| 55 |
-
"source": "# Build the training dataset\n# Each row needs a \"prompt\" field formatted as chat messages\ntrain_data = []\nfor p in prompts:\n train_data.append({\n \"prompt\": [\n {\"role\": \"user\", \"content\": p[\"prompt\"]},\n ],\n # Store metadata for reward calculation (not used by trainer directly)\n \"_hour\": p[\"hour\"],\n \"_episode_id\": p[\"episode_id\"],\n })\n\n# Shuffle and split\nrandom.seed(42)\nrandom.shuffle(train_data)\n\ndataset = Dataset.from_list(train_data)\nprint(f\"Training dataset: {len(dataset)} prompts\")\nprint(f\"Sample prompt length: {len(train_data[0]['prompt'][0]['content'])} chars\")",
|
| 56 |
"metadata": {},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
"execution_count": null,
|
| 58 |
"outputs": []
|
| 59 |
},
|
| 60 |
{
|
| 61 |
"cell_type": "markdown",
|
| 62 |
-
"
|
| 63 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
},
|
| 65 |
{
|
| 66 |
"cell_type": "code",
|
| 67 |
-
"source": "from trl import GRPOConfig, GRPOTrainer\n\n# Build a lookup from (episode_id, hour) -> prompt metadata for reward scoring\nprompt_lookup = {}\nfor p in prompts:\n key = (p[\"episode_id\"], p[\"hour\"])\n prompt_lookup[key] = p\n\n\ndef reward_fn(prompts, completions, _episode_id, _hour, **kwargs):\n \"\"\"\n GRPO reward function. Scores each completion against its inbox state.\n\n TRL passes extra dataset columns as keyword arguments, so _episode_id and\n _hour come directly from the dataset — no need to reverse-lookup from text.\n \"\"\"\n rewards = []\n for completion, ep_id, hour in zip(completions, _episode_id, _hour):\n key = (ep_id, hour)\n prompt_data = prompt_lookup.get(key)\n\n if prompt_data is None:\n rewards.append(0.0)\n continue\n\n if isinstance(completion, list):\n comp_text = completion[-1][\"content\"] if completion else \"\"\n else:\n comp_text = str(completion)\n\n score = score_action(comp_text, prompt_data)\n rewards.append(score)\n\n return rewards\n\n\nprint(f\"Prompt lookup: {len(prompt_lookup)} unique keys (expect {len(prompts)})\")\n\n# GRPO training config — all values from GPU profile\ntraining_args = GRPOConfig(\n output_dir=\"crisisinbox-grpo-output\",\n num_train_epochs=3,\n per_device_train_batch_size=BATCH_SIZE,\n gradient_accumulation_steps=GRAD_ACCUM,\n learning_rate=5e-6,\n max_completion_length=MAX_COMPLETION_LENGTH,\n max_prompt_length=MAX_PROMPT_LENGTH,\n num_generations=NUM_GENERATIONS,\n logging_steps=10,\n save_steps=100,\n report_to=\"none\",\n bf16=True,\n)\n\ntrainer = GRPOTrainer(\n model=model,\n processing_class=tokenizer,\n reward_funcs=reward_fn,\n args=training_args,\n train_dataset=dataset,\n)\n\nprint(f\"Trainer configured — batch={BATCH_SIZE}, gen={NUM_GENERATIONS}, prompt≤{MAX_PROMPT_LENGTH}tok\")",
|
| 68 |
"metadata": {},
|
|
|
|
| 69 |
"execution_count": null,
|
| 70 |
"outputs": []
|
| 71 |
},
|
| 72 |
{
|
| 73 |
"cell_type": "code",
|
| 74 |
-
"source": "# Train!\ntrainer.train()\nprint(\"Training complete\")",
|
| 75 |
"metadata": {},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
"execution_count": null,
|
| 77 |
"outputs": []
|
| 78 |
},
|
| 79 |
{
|
| 80 |
"cell_type": "markdown",
|
| 81 |
-
"
|
| 82 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
},
|
| 84 |
{
|
| 85 |
"cell_type": "code",
|
| 86 |
-
"source": "# Evaluate on a few test prompts\nFastLanguageModel.for_inference(model)\n\neval_prompts = random.sample(prompts, min(10, len(prompts)))\ntotal_score = 0\n\nprint(f\"=== Trained Model Evaluation ({MODEL_NAME}) ===\\n\")\nfor p in eval_prompts:\n messages = [{\"role\": \"user\", \"content\": p[\"prompt\"]}]\n inputs = tokenizer.apply_chat_template(messages, return_tensors=\"pt\", add_generation_prompt=True).to(\"cuda\")\n\n with torch.no_grad():\n output = model.generate(inputs, max_new_tokens=MAX_COMPLETION_LENGTH, temperature=0.7, do_sample=True)\n\n completion = tokenizer.decode(output[0][inputs.shape[1]:], skip_special_tokens=True)\n score = score_action(completion, p)\n total_score += score\n\n # Show a summary\n msg_match = re.search(r'(msg_\\d+)', completion)\n chosen_id = msg_match.group(1) if msg_match else \"none\"\n chosen_msg = next((m for m in p[\"messages\"] if m[\"id\"] == chosen_id), None)\n urgency = chosen_msg[\"urgency\"] if chosen_msg else \"?\"\n\n print(f\"Hour {p['hour']:5.1f} | Chose: {chosen_id} ({urgency:8s}) | Score: {score:+.1f}\")\n\nprint(f\"\\nAverage score: {total_score / len(eval_prompts):.2f}\")",
|
| 87 |
"metadata": {},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
"execution_count": null,
|
| 89 |
"outputs": []
|
| 90 |
},
|
| 91 |
{
|
| 92 |
"cell_type": "code",
|
| 93 |
-
"source": "# Save the trained model\nmodel.save_pretrained(\"crisisinbox-grpo-trained\")\ntokenizer.save_pretrained(\"crisisinbox-grpo-trained\")\nprint(\"Model saved to crisisinbox-grpo-trained/\")",
|
| 94 |
"metadata": {},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
"execution_count": null,
|
| 96 |
"outputs": []
|
| 97 |
}
|
|
|
|
| 3 |
{
|
| 4 |
"cell_type": "markdown",
|
| 5 |
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"# CrisisInbox GRPO Training\n",
|
| 8 |
+
"\n",
|
| 9 |
+
"Train a small LLM (Qwen2.5-0.5B) to triage crisis inbox messages using Group Relative Policy Optimization.\n",
|
| 10 |
+
"\n",
|
| 11 |
+
"**What this does:**\n",
|
| 12 |
+
"1. Loads pre-generated episode data (inbox snapshots at decision points)\n",
|
| 13 |
+
"2. For each prompt, the model generates an action (which message to handle + response)\n",
|
| 14 |
+
"3. A reward function scores the action based on urgency, deadline, drift adaptation\n",
|
| 15 |
+
"4. GRPO updates the model to prefer higher-reward actions\n",
|
| 16 |
+
"\n",
|
| 17 |
+
"Open in Google Colab with **T4 GPU** runtime."
|
| 18 |
+
]
|
| 19 |
},
|
| 20 |
{
|
| 21 |
"cell_type": "code",
|
|
|
|
| 22 |
"metadata": {},
|
| 23 |
+
"source": "# Install dependencies\n!pip install unsloth trl transformers datasets accelerate peft -q\n!pip install huggingface_hub -q\nprint(\"Setup complete\")",
|
| 24 |
+
"execution_count": null,
|
| 25 |
"outputs": []
|
| 26 |
},
|
| 27 |
{
|
| 28 |
"cell_type": "code",
|
|
|
|
| 29 |
"metadata": {},
|
| 30 |
+
"source": [
|
| 31 |
+
"# Avoid logging crash: transformers sometimes passes a Warning type to logger.warning(),\n",
|
| 32 |
+
"# which breaks %-style formatting. Patch so we don't pass that through.\n",
|
| 33 |
+
"import logging\n",
|
| 34 |
+
"import warnings\n",
|
| 35 |
+
"\n",
|
| 36 |
+
"def _patch_transformers_logging():\n",
|
| 37 |
+
" try:\n",
|
| 38 |
+
" import transformers.utils.logging as trans_log\n",
|
| 39 |
+
" _orig = trans_log.logger.warning\n",
|
| 40 |
+
" def _safe_warning(msg, *args, **kwargs):\n",
|
| 41 |
+
" # If first extra arg is a Warning type (e.g. FutureWarning), drop it for % formatting\n",
|
| 42 |
+
" if args and isinstance(args[0], type) and issubclass(args[0], Warning):\n",
|
| 43 |
+
" args = ()\n",
|
| 44 |
+
" return _orig(msg, *args, **kwargs)\n",
|
| 45 |
+
" trans_log.logger.warning = _safe_warning\n",
|
| 46 |
+
" except Exception:\n",
|
| 47 |
+
" pass\n",
|
| 48 |
+
" warnings.filterwarnings(\"ignore\", message=\".*attention mask API.*\", category=FutureWarning)\n",
|
| 49 |
+
"\n",
|
| 50 |
+
"_patch_transformers_logging()"
|
| 51 |
+
],
|
| 52 |
"execution_count": null,
|
| 53 |
"outputs": []
|
| 54 |
},
|
| 55 |
{
|
| 56 |
"cell_type": "code",
|
| 57 |
+
"metadata": {},
|
| 58 |
+
"source": [
|
| 59 |
+
"import torch\n",
|
| 60 |
+
"\n",
|
| 61 |
+
"# Print GPU info (PyTorch uses total_memory, not total_mem)\n",
|
| 62 |
+
"if torch.cuda.is_available():\n",
|
| 63 |
+
" props = torch.cuda.get_device_properties(0)\n",
|
| 64 |
+
" total_bytes = getattr(props, \"total_memory\", None) or getattr(props, \"total_mem\", 0)\n",
|
| 65 |
+
" vram_gb = total_bytes / 1e9 if total_bytes else 0\n",
|
| 66 |
+
" if vram_gb == 0 and hasattr(torch.cuda, \"mem_get_info\"):\n",
|
| 67 |
+
" _, total_bytes = torch.cuda.mem_get_info(0)\n",
|
| 68 |
+
" vram_gb = total_bytes / 1e9\n",
|
| 69 |
+
" print(f\"GPU: {torch.cuda.get_device_name(0)} ({vram_gb:.1f} GB)\")\n",
|
| 70 |
+
"else:\n",
|
| 71 |
+
" print(\"No GPU available.\")"
|
| 72 |
+
],
|
| 73 |
"execution_count": null,
|
| 74 |
+
"outputs": []
|
| 75 |
+
},
|
| 76 |
+
{
|
| 77 |
+
"cell_type": "code",
|
| 78 |
"metadata": {},
|
| 79 |
+
"source": "import json\nimport re\nimport random\nimport os\nfrom datasets import Dataset\nfrom huggingface_hub import hf_hub_download\n\n# Load episodes from HF dataset\nEPISODES_FILE = \".episodes.json\"\nif not os.path.exists(EPISODES_FILE):\n print(\"Downloading episodes from HF...\")\n hf_hub_download(\n repo_id=\"eptan/crisis-inbox-episodes\",\n filename=\"episodes.json\",\n repo_type=\"dataset\",\n local_dir=\".\",\n local_dir_use_symlinks=False,\n )\n os.rename(\"episodes.json\", EPISODES_FILE)\n\nwith open(EPISODES_FILE) as f:\n episodes = json.load(f)\n\n# Flatten to individual training prompts\nprompts = []\nfor ep in episodes:\n for dp in ep[\"decision_points\"]:\n prompts.append({\n \"prompt\": dp[\"prompt\"],\n \"hour\": dp[\"hour\"],\n \"visible_count\": dp[\"visible_count\"],\n \"episode_id\": ep[\"episode_id\"],\n \"seed\": ep[\"seed\"],\n \"drift_events\": ep[\"drift_events\"],\n \"superseded\": ep.get(\"superseded_messages\", {}),\n \"messages\": dp[\"visible_messages\"],\n })\n\nprint(f\"Loaded {len(episodes)} episodes -> {len(prompts)} training prompts\")\nprint(f\"Average {len(prompts)/len(episodes):.1f} decision points per episode\")",
|
| 80 |
+
"execution_count": null,
|
| 81 |
"outputs": []
|
| 82 |
},
|
| 83 |
{
|
| 84 |
"cell_type": "markdown",
|
| 85 |
+
"metadata": {},
|
| 86 |
+
"source": [
|
| 87 |
+
"## Reward Function\n",
|
| 88 |
+
"\n",
|
| 89 |
+
"Scores agent actions based on:\n",
|
| 90 |
+
"- **Urgency base** (critical=10, high=5, medium=3, low=1)\n",
|
| 91 |
+
"- **Deadline timing** (early=bonus, late=penalty)\n",
|
| 92 |
+
"- **Drift adaptation** (+50% for handling policy-change messages)\n",
|
| 93 |
+
"- **Stale info penalty** (-50% for acting on superseded messages)\n",
|
| 94 |
+
"- **Response quality** (penalty for short/empty responses)"
|
| 95 |
+
]
|
| 96 |
},
|
| 97 |
{
|
| 98 |
"cell_type": "code",
|
|
|
|
| 99 |
"metadata": {},
|
| 100 |
+
"source": [
|
| 101 |
+
"def score_action(completion: str, prompt_data: dict) -> float:\n",
|
| 102 |
+
" \"\"\"\n",
|
| 103 |
+
" Score a model completion against the inbox state.\n",
|
| 104 |
+
" \n",
|
| 105 |
+
" The model should output: respond_to_message(msg_id, \"response text\")\n",
|
| 106 |
+
" We parse the message_id and response, then score based on the reward function.\n",
|
| 107 |
+
" \"\"\"\n",
|
| 108 |
+
" messages = prompt_data[\"messages\"]\n",
|
| 109 |
+
" hour = prompt_data[\"hour\"]\n",
|
| 110 |
+
" superseded = prompt_data.get(\"superseded\", {})\n",
|
| 111 |
+
" \n",
|
| 112 |
+
" # Parse the model output for message_id\n",
|
| 113 |
+
" msg_id = None\n",
|
| 114 |
+
" response_text = \"\"\n",
|
| 115 |
+
" \n",
|
| 116 |
+
" # Try to parse respond_to_message(msg_id, response)\n",
|
| 117 |
+
" match = re.search(r'respond_to_message\\s*\\(\\s*[\"\\']?(msg_\\d+)[\"\\']?\\s*,\\s*[\"\\'](.+?)[\"\\']', completion, re.DOTALL)\n",
|
| 118 |
+
" if match:\n",
|
| 119 |
+
" msg_id = match.group(1)\n",
|
| 120 |
+
" response_text = match.group(2)\n",
|
| 121 |
+
" else:\n",
|
| 122 |
+
" # Try simpler format: just a message ID mentioned\n",
|
| 123 |
+
" id_match = re.search(r'(msg_\\d+)', completion)\n",
|
| 124 |
+
" if id_match:\n",
|
| 125 |
+
" msg_id = id_match.group(1)\n",
|
| 126 |
+
" response_text = completion\n",
|
| 127 |
+
" \n",
|
| 128 |
+
" if not msg_id:\n",
|
| 129 |
+
" return -1.0 # couldn't parse any action\n",
|
| 130 |
+
" \n",
|
| 131 |
+
" # Find the message in the inbox\n",
|
| 132 |
+
" target_msg = None\n",
|
| 133 |
+
" for msg in messages:\n",
|
| 134 |
+
" if msg[\"id\"] == msg_id:\n",
|
| 135 |
+
" target_msg = msg\n",
|
| 136 |
+
" break\n",
|
| 137 |
+
" \n",
|
| 138 |
+
" if target_msg is None:\n",
|
| 139 |
+
" return -0.5 # referenced a message not in inbox\n",
|
| 140 |
+
" \n",
|
| 141 |
+
" # Base reward by urgency\n",
|
| 142 |
+
" urgency_rewards = {\"critical\": 10.0, \"high\": 5.0, \"medium\": 3.0, \"low\": 1.0}\n",
|
| 143 |
+
" reward = urgency_rewards.get(target_msg[\"urgency\"], 1.0)\n",
|
| 144 |
+
" \n",
|
| 145 |
+
" # Deadline timing\n",
|
| 146 |
+
" deadline = target_msg.get(\"deadline_hours\")\n",
|
| 147 |
+
" if deadline is not None:\n",
|
| 148 |
+
" if hour <= deadline:\n",
|
| 149 |
+
" time_remaining_frac = (deadline - hour) / max(deadline, 1.0)\n",
|
| 150 |
+
" reward *= 1.0 + 0.5 * time_remaining_frac\n",
|
| 151 |
+
" else:\n",
|
| 152 |
+
" reward *= 0.25 # late penalty\n",
|
| 153 |
+
" \n",
|
| 154 |
+
" # Response quality\n",
|
| 155 |
+
" if len(response_text.strip()) < 10:\n",
|
| 156 |
+
" reward *= 0.5\n",
|
| 157 |
+
" \n",
|
| 158 |
+
" # Drift adaptation bonus\n",
|
| 159 |
+
" if target_msg.get(\"drift_flag\"):\n",
|
| 160 |
+
" reward *= 1.5\n",
|
| 161 |
+
" \n",
|
| 162 |
+
" # Stale info penalty\n",
|
| 163 |
+
" if target_msg[\"id\"] in superseded:\n",
|
| 164 |
+
" reward *= 0.5\n",
|
| 165 |
+
" \n",
|
| 166 |
+
"    # Penalty: choosing a low- or medium-urgency message while a critical one is in the inbox\n",
|
| 167 |
+
" has_critical = any(m[\"urgency\"] == \"critical\" for m in messages)\n",
|
| 168 |
+
" if has_critical and target_msg[\"urgency\"] in (\"low\", \"medium\"):\n",
|
| 169 |
+
" reward *= 0.3 # strong penalty for ignoring critical messages\n",
|
| 170 |
+
" \n",
|
| 171 |
+
" return round(reward, 2)\n",
|
| 172 |
+
"\n",
|
| 173 |
+
"\n",
|
| 174 |
+
"# Test the reward function\n",
|
| 175 |
+
"test_data = prompts[0]\n",
|
| 176 |
+
"print(\"Testing reward function on first decision point:\")\n",
|
| 177 |
+
"print(f\" Hour: {test_data['hour']}, Messages: {test_data['visible_count']}\")\n",
|
| 178 |
+
"\n",
|
| 179 |
+
"# Simulate good action (pick critical message)\n",
|
| 180 |
+
"critical_msgs = [m for m in test_data[\"messages\"] if m[\"urgency\"] == \"critical\"]\n",
|
| 181 |
+
"if critical_msgs:\n",
|
| 182 |
+
" good_action = f'respond_to_message(\"{critical_msgs[0][\"id\"]}\", \"Acknowledged. Evacuating immediately with documents and medication.\")'\n",
|
| 183 |
+
" good_score = score_action(good_action, test_data)\n",
|
| 184 |
+
" print(f\" Good action (critical msg): {good_score:.2f} pts\")\n",
|
| 185 |
+
"\n",
|
| 186 |
+
"# Simulate bad action (pick low-urgency message)\n",
|
| 187 |
+
"low_msgs = [m for m in test_data[\"messages\"] if m[\"urgency\"] == \"low\"]\n",
|
| 188 |
+
"if low_msgs:\n",
|
| 189 |
+
" bad_action = f'respond_to_message(\"{low_msgs[0][\"id\"]}\", \"ok\")'\n",
|
| 190 |
+
" bad_score = score_action(bad_action, test_data)\n",
|
| 191 |
+
" print(f\" Bad action (low msg, short response): {bad_score:.2f} pts\")\n",
|
| 192 |
+
"\n",
|
| 193 |
+
"# Simulate unparseable action\n",
|
| 194 |
+
"junk_score = score_action(\"I think we should do something\", test_data)\n",
|
| 195 |
+
"print(f\" Unparseable action: {junk_score:.2f} pts\")"
|
| 196 |
+
],
|
| 197 |
"execution_count": null,
|
| 198 |
"outputs": []
|
| 199 |
},
|
| 200 |
{
|
| 201 |
"cell_type": "markdown",
|
| 202 |
+
"metadata": {},
|
| 203 |
+
"source": [
|
| 204 |
+
"## Load Model & Configure GRPO"
|
| 205 |
+
]
|
| 206 |
},
|
| 207 |
{
|
| 208 |
"cell_type": "code",
|
|
|
|
| 209 |
"metadata": {},
|
| 210 |
+
"source": [
|
| 211 |
+
"from unsloth import FastLanguageModel\n",
|
| 212 |
+
"import torch\n",
|
| 213 |
+
"\n",
|
| 214 |
+
"# Load Qwen2.5-0.5B — small enough for T4 GPU\n",
|
| 215 |
+
"# Use a longer context window so prompt + completion\n",
|
| 216 |
+
"# comfortably fit without attention mask shape issues.\n",
|
| 217 |
+
"model, tokenizer = FastLanguageModel.from_pretrained(\n",
|
| 218 |
+
" model_name=\"unsloth/Qwen2.5-0.5B-Instruct\",\n",
|
| 219 |
+
" max_seq_length=4096,\n",
|
| 220 |
+
" dtype=None,\n",
|
| 221 |
+
" load_in_4bit=True,\n",
|
| 222 |
+
")\n",
|
| 223 |
+
"\n",
|
| 224 |
+
"# Add LoRA adapters for efficient fine-tuning\n",
|
| 225 |
+
"model = FastLanguageModel.get_peft_model(\n",
|
| 226 |
+
" model,\n",
|
| 227 |
+
" r=16,\n",
|
| 228 |
+
" target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
|
| 229 |
+
" \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
|
| 230 |
+
" lora_alpha=16,\n",
|
| 231 |
+
" lora_dropout=0,\n",
|
| 232 |
+
" bias=\"none\",\n",
|
| 233 |
+
" use_gradient_checkpointing=\"unsloth\",\n",
|
| 234 |
+
")\n",
|
| 235 |
+
"\n",
|
| 236 |
+
"# GRPO expects left-padding so completion positions align across the batch\n",
|
| 237 |
+
"# (avoids completion_mask vs log-probs shape mismatch in masked_batch_mean).\n",
|
| 238 |
+
"if tokenizer.pad_token_id is None:\n",
|
| 239 |
+
" tokenizer.pad_token_id = tokenizer.eos_token_id\n",
|
| 240 |
+
"tokenizer.padding_side = \"left\"\n",
|
| 241 |
+
"\n",
|
| 242 |
+
"print(\"Model loaded with LoRA adapters\")"
|
| 243 |
+
],
|
| 244 |
"execution_count": null,
|
| 245 |
"outputs": []
|
| 246 |
},
|
| 247 |
{
|
| 248 |
"cell_type": "code",
|
|
|
|
| 249 |
"metadata": {},
|
| 250 |
+
"source": [
|
| 251 |
+
"# Build the training dataset\n",
|
| 252 |
+
"# Each row needs a \"prompt\" field formatted as chat messages.\n",
|
| 253 |
+
"# Use a conservative max length so every batch has identical shape (avoids mask mismatch).\n",
|
| 254 |
+
"MAX_PROMPT_LENGTH = 1024 # must match GRPOConfig max_prompt_length below\n",
|
| 255 |
+
"\n",
|
| 256 |
+
"train_data = []\n",
|
| 257 |
+
"for p in prompts:\n",
|
| 258 |
+
" msgs = [{\"role\": \"user\", \"content\": p[\"prompt\"]}]\n",
|
| 259 |
+
" tok = tokenizer.apply_chat_template(msgs, return_tensors=\"pt\", add_generation_prompt=True)\n",
|
| 260 |
+
" # apply_chat_template can return a tensor or BatchEncoding. Do NOT use hasattr(tok, \"shape\")\n",
|
| 261 |
+
" # (BatchEncoding.__getattr__ raises when attribute is missing). Use dict-like check instead.\n",
|
| 262 |
+
" try:\n",
|
| 263 |
+
" ids = tok[\"input_ids\"]\n",
|
| 264 |
+
" except (TypeError, KeyError, IndexError):\n",
|
| 265 |
+
" ids = tok\n",
|
| 266 |
+
" n_tokens = ids.shape[1] if ids.dim() > 1 else ids.shape[0]\n",
|
| 267 |
+
" if n_tokens > MAX_PROMPT_LENGTH:\n",
|
| 268 |
+
" continue # skip overlong prompts so batch shapes stay consistent\n",
|
| 269 |
+
" train_data.append({\n",
|
| 270 |
+
" \"prompt\": msgs,\n",
|
| 271 |
+
" \"_hour\": p[\"hour\"],\n",
|
| 272 |
+
" \"_episode_id\": p[\"episode_id\"],\n",
|
| 273 |
+
" })\n",
|
| 274 |
+
"\n",
|
| 275 |
+
"# Shuffle and split\n",
|
| 276 |
+
"random.seed(42)\n",
|
| 277 |
+
"random.shuffle(train_data)\n",
|
| 278 |
+
"\n",
|
| 279 |
+
"dataset = Dataset.from_list(train_data)\n",
|
| 280 |
+
"print(f\"Training dataset: {len(dataset)} prompts (after dropping prompts > {MAX_PROMPT_LENGTH} tokens)\")\n",
|
| 281 |
+
"print(f\"Sample prompt length: {len(train_data[0]['prompt'][0]['content'])} chars\")"
|
| 282 |
+
],
|
| 283 |
"execution_count": null,
|
| 284 |
"outputs": []
|
| 285 |
},
|
| 286 |
{
|
| 287 |
"cell_type": "markdown",
|
| 288 |
+
"metadata": {},
|
| 289 |
+
"source": [
|
| 290 |
+
"## GRPO Training Loop\n",
|
| 291 |
+
"\n",
|
| 292 |
+
"The reward function scores each completion by:\n",
|
| 293 |
+
"1. Parsing which message the model chose to handle\n",
|
| 294 |
+
"2. Checking urgency, deadline timing, drift flags\n",
|
| 295 |
+
"3. Penalizing bad choices (low-urgency when critical exists, stale info)"
|
| 296 |
+
]
|
| 297 |
},
|
| 298 |
{
|
| 299 |
"cell_type": "code",
|
|
|
|
| 300 |
"metadata": {},
|
| 301 |
+
"source": "from trl import GRPOConfig, GRPOTrainer\n\n# Build a lookup from (episode_id, hour) -> prompt metadata for reward scoring\nprompt_lookup = {(p[\"episode_id\"], p[\"hour\"]): p for p in prompts}\n\n\ndef _column_value(column, i):\n    \"\"\"Return entry i of a per-sample dataset column as a plain Python value.\"\"\"\n    # str/bytes also define __getitem__; indexing them would yield a single character.\n    if isinstance(column, (str, bytes)) or not hasattr(column, '__getitem__'):\n        value = column\n    else:\n        value = column[i]\n    # Convert tensor/numpy scalars to Python scalars so they match the lookup keys\n    return value.item() if hasattr(value, 'item') else value\n\n\ndef reward_fn(prompts, completions, _episode_id=None, _hour=None, **kwargs):\n    \"\"\"GRPO reward function. Scores each completion against its inbox state.\n\n    TRL passes extra dataset columns (_episode_id, _hour) as keyword args.\n    Returns one float per completion; 0.0 when the (episode_id, hour) key\n    is not found in prompt_lookup.\n    \"\"\"\n    rewards = []\n    for i, (_prompt_msgs, completion) in enumerate(zip(prompts, completions)):\n        # Look up prompt data by (episode_id, hour) from dataset columns\n        prompt_data = None\n        if _episode_id is not None and _hour is not None:\n            key = (_column_value(_episode_id, i), _column_value(_hour, i))\n            prompt_data = prompt_lookup.get(key)\n\n        if prompt_data is None:\n            rewards.append(0.0)\n            continue\n\n        # Extract completion text (trainer may pass token ids or message dicts)\n        if isinstance(completion, list):\n            if completion and isinstance(completion[0], (int, float)):\n                comp_text = tokenizer.decode(completion, skip_special_tokens=True)\n            else:\n                comp_text = completion[-1].get(\"content\", \"\") if completion else \"\"\n        else:\n            comp_text = str(completion)\n\n        rewards.append(score_action(comp_text, prompt_data))\n\n    return rewards\n\n\n# GRPO training config: conservative batch/length to avoid mask shape mismatch.\ntraining_args = GRPOConfig(\n    output_dir=\"crisisinbox-grpo-output\",\n    num_train_epochs=3,\n    per_device_train_batch_size=1,\n    gradient_accumulation_steps=4,\n    steps_per_generation=4,\n    learning_rate=5e-6,\n    max_completion_length=256,\n    max_prompt_length=MAX_PROMPT_LENGTH,  # same cap used to filter train_data\n    num_generations=2,\n    logging_steps=10,\n    save_steps=100,\n    report_to=\"none\",\n    bf16=False,\n)\n\ntrainer = GRPOTrainer(\n    model=model,\n    processing_class=tokenizer,\n    reward_funcs=reward_fn,\n    args=training_args,\n    train_dataset=dataset,\n)\n\nprint(f\"Trainer configured — {len(prompt_lookup)} unique (episode_id, hour) keys\")\nprint(\"Ready to train\")",
|
| 302 |
"execution_count": null,
|
| 303 |
"outputs": []
|
| 304 |
},
|
| 305 |
{
|
| 306 |
"cell_type": "code",
|
|
|
|
| 307 |
"metadata": {},
|
| 308 |
+
"source": [
|
| 309 |
+
"# Train!\n",
|
| 310 |
+
"trainer.train()\n",
|
| 311 |
+
"print(\"Training complete\")"
|
| 312 |
+
],
|
| 313 |
"execution_count": null,
|
| 314 |
"outputs": []
|
| 315 |
},
|
| 316 |
{
|
| 317 |
"cell_type": "markdown",
|
| 318 |
+
"metadata": {},
|
| 319 |
+
"source": [
|
| 320 |
+
"## Evaluate: Before vs After\n",
|
| 321 |
+
"\n",
|
| 322 |
+
"Compare the trained model's action choices against the base model on the same prompts."
|
| 323 |
+
]
|
| 324 |
},
|
| 325 |
{
|
| 326 |
"cell_type": "code",
|
|
|
|
| 327 |
"metadata": {},
|
| 328 |
+
"source": [
|
| 329 |
+
"# Evaluate on a few test prompts\n",
|
| 330 |
+
"FastLanguageModel.for_inference(model)\n",
|
| 331 |
+
"\n",
|
| 332 |
+
"eval_prompts = random.sample(prompts, min(10, len(prompts)))\n",
|
| 333 |
+
"total_score = 0\n",
|
| 334 |
+
"\n",
|
| 335 |
+
"print(\"=== Trained Model Evaluation ===\\n\")\n",
|
| 336 |
+
"for p in eval_prompts:\n",
|
| 337 |
+
" messages = [{\"role\": \"user\", \"content\": p[\"prompt\"]}]\n",
|
| 338 |
+
" raw = tokenizer.apply_chat_template(messages, return_tensors=\"pt\", add_generation_prompt=True)\n",
|
| 339 |
+
" # Do NOT use hasattr(raw, \"shape\") — BatchEncoding.__getattr__ raises. Use try/except.\n",
|
| 340 |
+
" try:\n",
|
| 341 |
+
" inputs = {k: v.to(\"cuda\") for k, v in raw.items()}\n",
|
| 342 |
+
" prompt_len = inputs[\"input_ids\"].shape[1]\n",
|
| 343 |
+
" except (TypeError, AttributeError):\n",
|
| 344 |
+
" inputs = {\"input_ids\": raw.to(\"cuda\")}\n",
|
| 345 |
+
" prompt_len = inputs[\"input_ids\"].shape[1]\n",
|
| 346 |
+
"\n",
|
| 347 |
+
" with torch.no_grad():\n",
|
| 348 |
+
" output = model.generate(**inputs, max_new_tokens=200, temperature=0.7, do_sample=True)\n",
|
| 349 |
+
"\n",
|
| 350 |
+
" completion = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)\n",
|
| 351 |
+
" score = score_action(completion, p)\n",
|
| 352 |
+
" total_score += score\n",
|
| 353 |
+
"\n",
|
| 354 |
+
" # Show a summary\n",
|
| 355 |
+
" msg_match = re.search(r'(msg_\\d+)', completion)\n",
|
| 356 |
+
" chosen_id = msg_match.group(1) if msg_match else \"none\"\n",
|
| 357 |
+
" chosen_msg = next((m for m in p[\"messages\"] if m[\"id\"] == chosen_id), None)\n",
|
| 358 |
+
" urgency = chosen_msg[\"urgency\"] if chosen_msg else \"?\"\n",
|
| 359 |
+
"\n",
|
| 360 |
+
" print(f\"Hour {p['hour']:5.1f} | Chose: {chosen_id} ({urgency:8s}) | Score: {score:+.1f}\")\n",
|
| 361 |
+
"\n",
|
| 362 |
+
"print(f\"\\nAverage score: {total_score / len(eval_prompts):.2f}\")"
|
| 363 |
+
],
|
| 364 |
"execution_count": null,
|
| 365 |
"outputs": []
|
| 366 |
},
|
| 367 |
{
|
| 368 |
"cell_type": "code",
|
|
|
|
| 369 |
"metadata": {},
|
| 370 |
+
"source": [
|
| 371 |
+
"# Save the trained model\n",
|
| 372 |
+
"model.save_pretrained(\"crisisinbox-grpo-trained\")\n",
|
| 373 |
+
"tokenizer.save_pretrained(\"crisisinbox-grpo-trained\")\n",
|
| 374 |
+
"print(\"Model saved to crisisinbox-grpo-trained/\")"
|
| 375 |
+
],
|
| 376 |
"execution_count": null,
|
| 377 |
"outputs": []
|
| 378 |
}
|
server/app.py
CHANGED
|
@@ -23,7 +23,17 @@ except ImportError:
|
|
| 23 |
|
| 24 |
|
| 25 |
class MCPAction(Action):
|
| 26 |
-
"""Action class that deserializes both ListToolsAction and CallToolAction.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
model_config = Action.model_config.copy()
|
| 29 |
model_config["extra"] = "allow"
|
|
|
|
| 23 |
|
| 24 |
|
| 25 |
class MCPAction(Action):
|
| 26 |
+
"""Action class that deserializes both ListToolsAction and CallToolAction.
|
| 27 |
+
|
| 28 |
+
OpenEnv 0.2.1's WS handler passes a single action_cls to
|
| 29 |
+
deserialize_action(), but MCPToolClient sends both ListToolsAction
|
| 30 |
+
and CallToolAction through the "step" message path. Since the base
|
| 31 |
+
Action model uses extra="forbid", a fixed action_cls can't handle
|
| 32 |
+
both shapes. We override model_validate to route by the "type" field
|
| 33 |
+
so that MCPEnvironment.step() receives the correct Action subclass.
|
| 34 |
+
extra="allow" is needed because the two action types have different
|
| 35 |
+
field sets.
|
| 36 |
+
"""
|
| 37 |
|
| 38 |
model_config = Action.model_config.copy()
|
| 39 |
model_config["extra"] = "allow"
|
server/crisis_inbox_environment.py
CHANGED
|
@@ -34,6 +34,9 @@ class CrisisInboxEnvironment(MCPEnvironment):
|
|
| 34 |
"""
|
| 35 |
Simulates a 48-hour post-disaster inbox triage scenario.
|
| 36 |
|
|
|
|
|
|
|
|
|
|
| 37 |
The agent receives messages from family, employers, government agencies,
|
| 38 |
insurance companies, and service providers. It must prioritize safety,
|
| 39 |
meet deadlines, and adapt to changing rules (schema drift).
|
|
|
|
| 34 |
"""
|
| 35 |
Simulates a 48-hour post-disaster inbox triage scenario.
|
| 36 |
|
| 37 |
+
Note: SUPPORTS_CONCURRENT_SESSIONS is False (default) because the
|
| 38 |
+
environment holds mutable per-episode state (_all_messages, _handled, etc).
|
| 39 |
+
|
| 40 |
The agent receives messages from family, employers, government agencies,
|
| 41 |
insurance companies, and service providers. It must prioritize safety,
|
| 42 |
meet deadlines, and adapt to changing rules (schema drift).
|
training/crisisinbox_training.py
CHANGED
|
@@ -4,8 +4,7 @@ Person B: ML Pipeline
|
|
| 4 |
|
| 5 |
Run this in Google Colab:
|
| 6 |
1. Upload this file
|
| 7 |
-
2.
|
| 8 |
-
3. Run: python crisisinbox_training.py
|
| 9 |
"""
|
| 10 |
|
| 11 |
import torch
|
|
@@ -16,12 +15,14 @@ from datasets import Dataset
|
|
| 16 |
from unsloth import FastLanguageModel
|
| 17 |
from trl import GRPOConfig, GRPOTrainer
|
| 18 |
|
| 19 |
-
# Download episodes from
|
| 20 |
print("Loading episodes...")
|
| 21 |
-
import
|
| 22 |
-
|
| 23 |
-
"
|
| 24 |
-
"episodes.json"
|
|
|
|
|
|
|
| 25 |
)
|
| 26 |
|
| 27 |
with open("episodes.json", "r") as f:
|
|
@@ -33,6 +34,9 @@ print(f"✓ Loaded {len(EPISODES)} episodes")
|
|
| 33 |
# PROMPT BUILDING
|
| 34 |
# =============================================================================
|
| 35 |
|
|
|
|
|
|
|
|
|
|
| 36 |
CRISIS_SYSTEM_PROMPT = """
|
| 37 |
You are an assistant helping a working parent during a wildfire.
|
| 38 |
You must triage messages, act on safety-critical items first,
|
|
@@ -56,10 +60,17 @@ def build_crisis_prompt(episode):
|
|
| 56 |
msgs_str.append(
|
| 57 |
f"[t={m['time']}h] {urgency} From {m['sender']} via {m['channel']}: {m['content']}{deadline_info}"
|
| 58 |
)
|
| 59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
drift_str = []
|
| 61 |
for d in episode.get("schema_events", []):
|
| 62 |
drift_str.append(f"[t={d['time']}h] POLICY UPDATE: {d['kind']} -> {d.get('new_value', 'changed')}")
|
|
|
|
|
|
|
|
|
|
| 63 |
|
| 64 |
user_content = (
|
| 65 |
"Here is your 48-hour message history:\n\n"
|
|
@@ -80,22 +91,34 @@ def build_crisis_prompt(episode):
|
|
| 80 |
|
| 81 |
def parse_plan(model_output):
|
| 82 |
"""Parse <plan> tag output into list of action dicts."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
actions = []
|
| 84 |
-
|
| 85 |
-
plan_match = re.search(r
|
| 86 |
-
|
|
|
|
| 87 |
return []
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
lines = plan_content.split('\n')
|
| 92 |
for line in lines:
|
| 93 |
line = line.strip()
|
| 94 |
if not line or not line[0].isdigit():
|
| 95 |
continue
|
| 96 |
|
| 97 |
# Extract time: [time=min X]
|
| 98 |
-
time_match = re.search(r
|
| 99 |
time_min = int(time_match.group(1)) if time_match else 0
|
| 100 |
|
| 101 |
# Extract action description
|
|
@@ -292,7 +315,9 @@ MODEL_NAME = "unsloth/Qwen2.5-0.5B-Instruct"
|
|
| 292 |
|
| 293 |
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 294 |
model_name=MODEL_NAME,
|
| 295 |
-
|
|
|
|
|
|
|
| 296 |
dtype=None,
|
| 297 |
load_in_4bit=True,
|
| 298 |
)
|
|
@@ -355,7 +380,8 @@ training_args = GRPOConfig(
|
|
| 355 |
per_device_train_batch_size=2,
|
| 356 |
gradient_accumulation_steps=4,
|
| 357 |
num_generations=4,
|
| 358 |
-
|
|
|
|
| 359 |
temperature=0.7,
|
| 360 |
learning_rate=1e-5,
|
| 361 |
logging_steps=10,
|
|
|
|
| 4 |
|
| 5 |
Run this in Google Colab:
|
| 6 |
1. Upload this file
|
| 7 |
+
2. Run: python crisisinbox_training.py
|
|
|
|
| 8 |
"""
|
| 9 |
|
| 10 |
import torch
|
|
|
|
| 15 |
from unsloth import FastLanguageModel
|
| 16 |
from trl import GRPOConfig, GRPOTrainer
|
| 17 |
|
| 18 |
+
# Download episodes from HF dataset
|
| 19 |
print("Loading episodes...")
|
| 20 |
+
from huggingface_hub import hf_hub_download
|
| 21 |
+
hf_hub_download(
|
| 22 |
+
repo_id="eptan/crisis-inbox-episodes",
|
| 23 |
+
filename="episodes.json",
|
| 24 |
+
repo_type="dataset",
|
| 25 |
+
local_dir=".",
|
| 26 |
)
|
| 27 |
|
| 28 |
with open("episodes.json", "r") as f:
|
|
|
|
| 34 |
# PROMPT BUILDING
|
| 35 |
# =============================================================================
|
| 36 |
|
| 37 |
+
MAX_MESSAGES = 40
|
| 38 |
+
MAX_DRIFT_EVENTS = 20
|
| 39 |
+
|
| 40 |
CRISIS_SYSTEM_PROMPT = """
|
| 41 |
You are an assistant helping a working parent during a wildfire.
|
| 42 |
You must triage messages, act on safety-critical items first,
|
|
|
|
| 60 |
msgs_str.append(
|
| 61 |
f"[t={m['time']}h] {urgency} From {m['sender']} via {m['channel']}: {m['content']}{deadline_info}"
|
| 62 |
)
|
| 63 |
+
|
| 64 |
+
# Keep only the most recent messages to avoid overlong sequences.
|
| 65 |
+
if len(msgs_str) > MAX_MESSAGES:
|
| 66 |
+
msgs_str = msgs_str[-MAX_MESSAGES:]
|
| 67 |
+
|
| 68 |
drift_str = []
|
| 69 |
for d in episode.get("schema_events", []):
|
| 70 |
drift_str.append(f"[t={d['time']}h] POLICY UPDATE: {d['kind']} -> {d.get('new_value', 'changed')}")
|
| 71 |
+
|
| 72 |
+
if len(drift_str) > MAX_DRIFT_EVENTS:
|
| 73 |
+
drift_str = drift_str[-MAX_DRIFT_EVENTS:]
|
| 74 |
|
| 75 |
user_content = (
|
| 76 |
"Here is your 48-hour message history:\n\n"
|
|
|
|
| 91 |
|
| 92 |
def parse_plan(model_output):
|
| 93 |
"""Parse <plan> tag output into list of action dicts."""
|
| 94 |
+
if model_output is None:
|
| 95 |
+
return []
|
| 96 |
+
|
| 97 |
+
# TRL/Unsloth can return completions as lists (token ids, strings, or message dicts).
|
| 98 |
+
# Normalize to a single string before regex parsing.
|
| 99 |
+
if isinstance(model_output, list):
|
| 100 |
+
if model_output and isinstance(model_output[0], dict) and "content" in model_output[0]:
|
| 101 |
+
model_output = "\n".join(str(m.get("content", "")) for m in model_output)
|
| 102 |
+
else:
|
| 103 |
+
model_output = "\n".join(map(str, model_output))
|
| 104 |
+
else:
|
| 105 |
+
model_output = str(model_output)
|
| 106 |
+
|
| 107 |
actions = []
|
| 108 |
+
|
| 109 |
+
plan_match = re.search(r"<plan>(.*?)</plan>", model_output, re.DOTALL | re.IGNORECASE)
|
| 110 |
+
plan_content = plan_match.group(1).strip() if plan_match else model_output.strip()
|
| 111 |
+
if not plan_content:
|
| 112 |
return []
|
| 113 |
+
|
| 114 |
+
lines = plan_content.split("\n")
|
|
|
|
|
|
|
| 115 |
for line in lines:
|
| 116 |
line = line.strip()
|
| 117 |
if not line or not line[0].isdigit():
|
| 118 |
continue
|
| 119 |
|
| 120 |
# Extract time: [time=min X]
|
| 121 |
+
time_match = re.search(r"time=min (\d+)", line, re.IGNORECASE)
|
| 122 |
time_min = int(time_match.group(1)) if time_match else 0
|
| 123 |
|
| 124 |
# Extract action description
|
|
|
|
| 315 |
|
| 316 |
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 317 |
model_name=MODEL_NAME,
|
| 318 |
+
# Allow longer combined prompt + completion to avoid
|
| 319 |
+
# attention mask shape mismatches during training.
|
| 320 |
+
max_seq_length=4096,
|
| 321 |
dtype=None,
|
| 322 |
load_in_4bit=True,
|
| 323 |
)
|
|
|
|
| 380 |
per_device_train_batch_size=2,
|
| 381 |
gradient_accumulation_steps=4,
|
| 382 |
num_generations=4,
|
| 383 |
+
# Keep completions modest so prompt+completion stay well within max_seq_length.
|
| 384 |
+
max_completion_length=256,
|
| 385 |
temperature=0.7,
|
| 386 |
learning_rate=1e-5,
|
| 387 |
logging_steps=10,
|