eptan committed on
Commit
ccffbf6
·
verified ·
1 Parent(s): f25d4f0

Upload folder using huggingface_hub

Browse files
Dockerfile.notebook CHANGED
@@ -13,14 +13,12 @@ RUN pip install --no-cache-dir \
13
  peft \
14
  huggingface_hub
15
 
16
- # Copy everything needed for training
17
- COPY episodes.json .
18
- COPY generate_episodes.py .
19
- COPY models.py .
20
- COPY messages.py .
21
- COPY drift_events.py .
22
  COPY notebooks/ ./notebooks/
23
 
24
  EXPOSE 8888
25
 
26
- CMD ["jupyter", "notebook", "--ip=0.0.0.0", "--port=8888", "--no-browser", "--allow-root", "--NotebookApp.token=''"]
 
 
 
13
  peft \
14
  huggingface_hub
15
 
16
+ # Copy source files needed for training
17
+ COPY generate_episodes.py models.py messages.py drift_events.py ./
 
 
 
 
18
  COPY notebooks/ ./notebooks/
19
 
20
  EXPOSE 8888
21
 
22
+ # Download training data at startup (build env may block outbound network), then launch Jupyter
23
+ CMD python -c "from huggingface_hub import hf_hub_download; hf_hub_download(repo_id='eptan/crisis-inbox-episodes', filename='episodes.json', repo_type='dataset', local_dir='/app'); import os; os.rename('/app/episodes.json', '/app/.episodes.json')" && \
24
+ jupyter notebook --ip=0.0.0.0 --port=8888 --no-browser --allow-root --NotebookApp.token=''
README.md CHANGED
@@ -166,8 +166,8 @@ crisis-inbox/
166
  │ ├── app.py # FastAPI app with MCPAction workaround
167
  │ └── Dockerfile # HF Spaces deployment
168
  ├── notebooks/
169
- │ └── crisisinbox_grpo.ipynb # GRPO training notebook
170
- ├── episodes.json # Pre-generated training episodes
171
  ├── generate_episodes.py # Episode generator script
172
  ├── pyproject.toml # Package config
173
  ├── openenv.yaml # OpenEnv environment spec
@@ -182,6 +182,14 @@ crisis-inbox/
182
  - **Training:** Unsloth GRPO via Google Colab
183
  - **Model:** Qwen2.5-0.5B-Instruct
184
 
 
 
 
 
 
 
 
 
185
  ## Team
186
 
187
  Built at the OpenEnv Hackathon @ Shack15, SF — March 7-8, 2026
 
166
  │ ├── app.py # FastAPI app with MCPAction workaround
167
  │ └── Dockerfile # HF Spaces deployment
168
  ├── notebooks/
169
+ │ └── crisisinbox_grpo_simple.ipynb # GRPO training notebook (Colab)
170
+ ├── .episodes.json # Pre-generated training episodes (gitignored)
171
  ├── generate_episodes.py # Episode generator script
172
  ├── pyproject.toml # Package config
173
  ├── openenv.yaml # OpenEnv environment spec
 
182
  - **Training:** Unsloth GRPO via Google Colab
183
  - **Model:** Qwen2.5-0.5B-Instruct
184
 
185
+ ### GRPO training (Colab)
186
+
187
+ Open the notebook with the latest fixes (context length, reward signature, left-padding, batch size) in Google Colab (T4 GPU runtime):
188
+
189
+ **[Open in Colab](https://colab.research.google.com/github/eptan/crisis-inbox/blob/main/notebooks/crisisinbox_grpo_simple.ipynb)**
190
+
191
+ The link above always opens the notebook from the `main` branch on GitHub, so push your local changes to `main` before using it.
192
+
193
  ## Team
194
 
195
  Built at the OpenEnv Hackathon @ Shack15, SF — March 7-8, 2026
generate_episodes.py CHANGED
@@ -270,7 +270,12 @@ def generate_episodes(num_episodes: int = 50, start_seed: int = 1000) -> list:
270
  seed = start_seed + i
271
  print(f" Episode {i + 1}/{num_episodes} (seed={seed})...", end=" ")
272
  episode = build_episode(seed)
273
- n_dp = len(episode["decision_points"])
 
 
 
 
 
274
  n_msg = episode["total_messages"]
275
  drifts = ", ".join(episode["drift_events"])
276
  print(f"{n_msg} messages, {n_dp} decision points, drifts: [{drifts}]")
@@ -282,7 +287,7 @@ def save_episodes(episodes: list, filename: str = "episodes.json"):
282
  """Save episodes to JSON file."""
283
  with open(filename, "w") as f:
284
  json.dump(episodes, f, indent=2)
285
- total_prompts = sum(len(ep["decision_points"]) for ep in episodes)
286
  print(f"\nSaved {len(episodes)} episodes ({total_prompts} training prompts) to {filename}")
287
 
288
 
@@ -292,7 +297,7 @@ if __name__ == "__main__":
292
  parser = argparse.ArgumentParser(description="Generate CrisisInbox training episodes")
293
  parser.add_argument("-n", "--num-episodes", type=int, default=50, help="Number of episodes")
294
  parser.add_argument("-s", "--start-seed", type=int, default=1000, help="Starting seed")
295
- parser.add_argument("-o", "--output", type=str, default="episodes.json", help="Output file")
296
  parser.add_argument("--sample", type=int, default=5, help="Also save N sample episodes")
297
  args = parser.parse_args()
298
 
@@ -301,5 +306,5 @@ if __name__ == "__main__":
301
  save_episodes(episodes, args.output)
302
 
303
  if args.sample > 0:
304
- sample_file = "sample_episodes.json"
305
  save_episodes(episodes[:args.sample], sample_file)
 
270
  seed = start_seed + i
271
  print(f" Episode {i + 1}/{num_episodes} (seed={seed})...", end=" ")
272
  episode = build_episode(seed)
273
+ # Some episodes may not have decision points; skip them.
274
+ decision_points = episode.get("decision_points")
275
+ if not decision_points:
276
+ print("skipped (no decision_points)")
277
+ continue
278
+ n_dp = len(decision_points)
279
  n_msg = episode["total_messages"]
280
  drifts = ", ".join(episode["drift_events"])
281
  print(f"{n_msg} messages, {n_dp} decision points, drifts: [{drifts}]")
 
287
  """Save episodes to JSON file."""
288
  with open(filename, "w") as f:
289
  json.dump(episodes, f, indent=2)
290
+ total_prompts = sum(len(ep.get("decision_points", [])) for ep in episodes)
291
  print(f"\nSaved {len(episodes)} episodes ({total_prompts} training prompts) to {filename}")
292
 
293
 
 
297
  parser = argparse.ArgumentParser(description="Generate CrisisInbox training episodes")
298
  parser.add_argument("-n", "--num-episodes", type=int, default=50, help="Number of episodes")
299
  parser.add_argument("-s", "--start-seed", type=int, default=1000, help="Starting seed")
300
+ parser.add_argument("-o", "--output", type=str, default=".episodes.json", help="Output file")
301
  parser.add_argument("--sample", type=int, default=5, help="Also save N sample episodes")
302
  args = parser.parse_args()
303
 
 
306
  save_episodes(episodes, args.output)
307
 
308
  if args.sample > 0:
309
+ sample_file = ".sample_episodes.json"
310
  save_episodes(episodes[:args.sample], sample_file)
notebooks/crisisinbox_grpo_simple.ipynb CHANGED
@@ -3,95 +3,376 @@
3
  {
4
  "cell_type": "markdown",
5
  "metadata": {},
6
- "source": "# CrisisInbox GRPO Training\n\nTrain a small LLM to triage crisis inbox messages using Group Relative Policy Optimization.\n\n**What this does:**\n1. Loads pre-generated episode data (inbox snapshots at decision points)\n2. For each prompt, the model generates an action (which message to handle + response)\n3. A reward function scores the action based on urgency, deadline, drift adaptation\n4. GRPO updates the model to prefer higher-reward actions\n\n**GPU profiles:**\n- **T4 / free Colab**: Qwen2.5-0.5B, 2048 ctx, 4-bit — runs in ~30 min\n- **H100 / A100**: Qwen2.5-3B, 4096 ctx, 4-bit — better quality, ~20 min"
 
 
 
 
 
 
 
 
 
 
 
 
7
  },
8
  {
9
  "cell_type": "code",
10
- "execution_count": null,
11
  "metadata": {},
12
- "source": "# Install dependencies\n!pip install unsloth trl transformers datasets accelerate peft -q\n!pip install huggingface_hub -q\n\n# Download episode data\n# Option 1: From HF dataset (recommended)\n# Option 2: From GitHub repo\n# Option 3: Generate locally with `python generate_episodes.py -n 100`\n\nimport os\nif not os.path.exists(\"episodes.json\"):\n print(\"Downloading episodes.json from GitHub...\")\n !wget -q --show-progress https://raw.githubusercontent.com/eptan/crisis-inbox/main/episodes.json\n if not os.path.exists(\"episodes.json\"):\n print(\"ERROR: Download failed. Upload episodes.json manually or generate with:\")\n print(\" !git clone https://github.com/eptan/crisis-inbox.git && cd crisis-inbox && python generate_episodes.py -n 100\")\nelse:\n print(\"episodes.json already exists, skipping download\")\n\nprint(\"Setup complete\")",
 
13
  "outputs": []
14
  },
15
  {
16
  "cell_type": "code",
17
- "source": "# === GPU PROFILE ===\n# Change this one variable to switch between T4 and H100 configs.\n# Everything else adapts automatically.\n\nimport torch\n\nif torch.cuda.is_available():\n vram_gb = torch.cuda.get_device_properties(0).total_mem / 1e9\n gpu_name = torch.cuda.get_device_name(0)\n print(f\"GPU: {gpu_name} ({vram_gb:.0f} GB)\")\nelse:\n vram_gb = 0\n print(\"No GPU detected — config will default to smallest profile\")\n\n# Auto-select profile based on VRAM, or override manually\nif vram_gb >= 40: # H100, A100\n PROFILE = \"h100\"\n MODEL_NAME = \"unsloth/Qwen2.5-3B-Instruct\"\n MAX_SEQ_LENGTH = 4096\n MAX_PROMPT_LENGTH = 3584\n MAX_COMPLETION_LENGTH = 512\n BATCH_SIZE = 4\n GRAD_ACCUM = 2\n NUM_GENERATIONS = 8\nelif vram_gb >= 14: # T4, L4\n PROFILE = \"t4\"\n MODEL_NAME = \"unsloth/Qwen2.5-0.5B-Instruct\"\n MAX_SEQ_LENGTH = 2048\n MAX_PROMPT_LENGTH = 1792\n MAX_COMPLETION_LENGTH = 256\n BATCH_SIZE = 2\n GRAD_ACCUM = 4\n NUM_GENERATIONS = 4\nelse:\n PROFILE = \"cpu\"\n MODEL_NAME = \"unsloth/Qwen2.5-0.5B-Instruct\"\n MAX_SEQ_LENGTH = 2048\n MAX_PROMPT_LENGTH = 1792\n MAX_COMPLETION_LENGTH = 256\n BATCH_SIZE = 1\n GRAD_ACCUM = 8\n NUM_GENERATIONS = 2\n\nprint(f\"Profile: {PROFILE} | Model: {MODEL_NAME} | Context: {MAX_SEQ_LENGTH}\")",
18
  "metadata": {},
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  "execution_count": null,
20
  "outputs": []
21
  },
22
  {
23
  "cell_type": "code",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  "execution_count": null,
 
 
 
 
25
  "metadata": {},
26
- "source": "import json\nimport re\nimport random\nfrom datasets import Dataset\n\n# Load episodes\nwith open(\"episodes.json\") as f:\n episodes = json.load(f)\n\n# Check format — old format has 'messages'/'tasks', new format has 'decision_points'\nif episodes and \"decision_points\" not in episodes[0]:\n old_keys = list(episodes[0].keys())\n raise ValueError(\n f\"episodes.json is in the old format (keys: {old_keys}).\\n\"\n f\"Regenerate with: python generate_episodes.py -n 100\\n\"\n f\"The old format used 'messages'/'tasks'/'schema_events'; \"\n f\"the notebook requires 'decision_points' from generate_episodes.py.\"\n )\n\n# Flatten to individual training prompts\nprompts = []\nfor ep in episodes:\n for dp in ep[\"decision_points\"]:\n prompts.append({\n \"prompt\": dp[\"prompt\"],\n \"hour\": dp[\"hour\"],\n \"visible_count\": dp[\"visible_count\"],\n \"episode_id\": ep[\"episode_id\"],\n \"seed\": ep[\"seed\"],\n \"drift_events\": ep[\"drift_events\"],\n \"superseded\": ep.get(\"superseded_messages\", {}),\n \"messages\": dp[\"visible_messages\"],\n })\n\nif not prompts:\n raise ValueError(\"No decision_points found in episodes; cannot train.\")\n\nprint(f\"Loaded {len(episodes)} episodes -> {len(prompts)} training prompts\")\nprint(f\"Average {len(prompts)/len(episodes):.1f} decision points per episode\")",
 
27
  "outputs": []
28
  },
29
  {
30
  "cell_type": "markdown",
31
- "source": "## Reward Function\n\nScores agent actions based on:\n- **Urgency base** (critical=10, high=5, medium=3, low=1)\n- **Deadline timing** (early=bonus, late=penalty)\n- **Drift adaptation** (+50% for handling policy-change messages)\n- **Stale info penalty** (-50% for acting on superseded messages)\n- **Response quality** (penalty for short/empty responses)",
32
- "metadata": {}
 
 
 
 
 
 
 
 
 
33
  },
34
  {
35
  "cell_type": "code",
36
- "source": "def score_action(completion: str, prompt_data: dict) -> float:\n \"\"\"\n Score a model completion against the inbox state.\n \n The model should output: respond_to_message(msg_id, \"response text\")\n We parse the message_id and response, then score based on the reward function.\n \"\"\"\n messages = prompt_data[\"messages\"]\n hour = prompt_data[\"hour\"]\n superseded = prompt_data.get(\"superseded\", {})\n \n # Parse the model output for message_id\n msg_id = None\n response_text = \"\"\n \n # Try to parse respond_to_message(msg_id, response)\n match = re.search(r'respond_to_message\\s*\\(\\s*[\"\\']?(msg_\\d+)[\"\\']?\\s*,\\s*[\"\\'](.+?)[\"\\']', completion, re.DOTALL)\n if match:\n msg_id = match.group(1)\n response_text = match.group(2)\n else:\n # Try simpler format: just a message ID mentioned\n id_match = re.search(r'(msg_\\d+)', completion)\n if id_match:\n msg_id = id_match.group(1)\n # No explicit response text — penalize via quality check below\n response_text = \"\"\n \n if not msg_id:\n return -1.0 # couldn't parse any action\n \n # Find the message in the inbox\n target_msg = None\n for msg in messages:\n if msg[\"id\"] == msg_id:\n target_msg = msg\n break\n \n if target_msg is None:\n return -0.5 # referenced a message not in inbox\n \n # Base reward by urgency\n urgency_rewards = {\"critical\": 10.0, \"high\": 5.0, \"medium\": 3.0, \"low\": 1.0}\n reward = urgency_rewards.get(target_msg[\"urgency\"], 1.0)\n \n # Deadline timing\n deadline = target_msg.get(\"deadline_hours\")\n if deadline is not None:\n if hour <= deadline:\n time_remaining_frac = (deadline - hour) / max(deadline, 1.0)\n reward *= 1.0 + 0.5 * time_remaining_frac\n else:\n reward *= 0.25 # late penalty\n \n # Response quality\n if len(response_text.strip()) < 10:\n reward *= 0.5\n \n # Drift adaptation bonus\n if target_msg.get(\"drift_flag\"):\n reward *= 1.5\n \n # Stale info penalty\n if target_msg[\"id\"] in superseded:\n reward *= 0.5\n \n # Penalize choosing 
low-urgency when unhandled critical messages exist\n unhandled_critical = any(\n m[\"urgency\"] == \"critical\" and not m.get(\"handled\") and not m.get(\"superseded\")\n for m in messages\n )\n if unhandled_critical and target_msg[\"urgency\"] in (\"low\", \"medium\"):\n reward *= 0.3\n \n return round(reward, 2)\n\n\n# Test the reward function\ntest_data = prompts[0]\nprint(\"Testing reward function on first decision point:\")\nprint(f\" Hour: {test_data['hour']}, Messages: {test_data['visible_count']}\")\n\n# Simulate good action (pick critical message)\ncritical_msgs = [m for m in test_data[\"messages\"] if m[\"urgency\"] == \"critical\"]\nif critical_msgs:\n good_action = f'respond_to_message(\"{critical_msgs[0][\"id\"]}\", \"Acknowledged. Evacuating immediately with documents and medication.\")'\n good_score = score_action(good_action, test_data)\n print(f\" Good action (critical msg): {good_score:.2f} pts\")\n\n# Simulate bad action (pick low-urgency message)\nlow_msgs = [m for m in test_data[\"messages\"] if m[\"urgency\"] == \"low\"]\nif low_msgs:\n bad_action = f'respond_to_message(\"{low_msgs[0][\"id\"]}\", \"ok\")'\n bad_score = score_action(bad_action, test_data)\n print(f\" Bad action (low msg, short response): {bad_score:.2f} pts\")\n\n# Simulate unparseable action\njunk_score = score_action(\"I think we should do something\", test_data)\nprint(f\" Unparseable action: {junk_score:.2f} pts\")",
37
  "metadata": {},
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  "execution_count": null,
39
  "outputs": []
40
  },
41
  {
42
  "cell_type": "markdown",
43
- "source": "## Load Model & Configure GRPO",
44
- "metadata": {}
 
 
45
  },
46
  {
47
  "cell_type": "code",
48
- "source": "from unsloth import FastLanguageModel\n\nmodel, tokenizer = FastLanguageModel.from_pretrained(\n model_name=MODEL_NAME,\n max_seq_length=MAX_SEQ_LENGTH,\n load_in_4bit=True,\n)\n\n# Add LoRA adapters — bigger r for bigger models\nlora_r = 32 if PROFILE == \"h100\" else 16\n\nmodel = FastLanguageModel.get_peft_model(\n model,\n r=lora_r,\n target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n \"gate_proj\", \"up_proj\", \"down_proj\"],\n lora_alpha=lora_r,\n lora_dropout=0,\n bias=\"none\",\n use_gradient_checkpointing=\"unsloth\",\n)\nprint(f\"Model loaded: {MODEL_NAME} | LoRA r={lora_r} | ctx={MAX_SEQ_LENGTH}\")",
49
  "metadata": {},
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  "execution_count": null,
51
  "outputs": []
52
  },
53
  {
54
  "cell_type": "code",
55
- "source": "# Build the training dataset\n# Each row needs a \"prompt\" field formatted as chat messages\ntrain_data = []\nfor p in prompts:\n train_data.append({\n \"prompt\": [\n {\"role\": \"user\", \"content\": p[\"prompt\"]},\n ],\n # Store metadata for reward calculation (not used by trainer directly)\n \"_hour\": p[\"hour\"],\n \"_episode_id\": p[\"episode_id\"],\n })\n\n# Shuffle and split\nrandom.seed(42)\nrandom.shuffle(train_data)\n\ndataset = Dataset.from_list(train_data)\nprint(f\"Training dataset: {len(dataset)} prompts\")\nprint(f\"Sample prompt length: {len(train_data[0]['prompt'][0]['content'])} chars\")",
56
  "metadata": {},
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  "execution_count": null,
58
  "outputs": []
59
  },
60
  {
61
  "cell_type": "markdown",
62
- "source": "## GRPO Training Loop\n\nThe reward function scores each completion by:\n1. Parsing which message the model chose to handle\n2. Checking urgency, deadline timing, drift flags\n3. Penalizing bad choices (low-urgency when critical exists, stale info)",
63
- "metadata": {}
 
 
 
 
 
 
 
64
  },
65
  {
66
  "cell_type": "code",
67
- "source": "from trl import GRPOConfig, GRPOTrainer\n\n# Build a lookup from (episode_id, hour) -> prompt metadata for reward scoring\nprompt_lookup = {}\nfor p in prompts:\n key = (p[\"episode_id\"], p[\"hour\"])\n prompt_lookup[key] = p\n\n\ndef reward_fn(prompts, completions, _episode_id, _hour, **kwargs):\n \"\"\"\n GRPO reward function. Scores each completion against its inbox state.\n\n TRL passes extra dataset columns as keyword arguments, so _episode_id and\n _hour come directly from the dataset — no need to reverse-lookup from text.\n \"\"\"\n rewards = []\n for completion, ep_id, hour in zip(completions, _episode_id, _hour):\n key = (ep_id, hour)\n prompt_data = prompt_lookup.get(key)\n\n if prompt_data is None:\n rewards.append(0.0)\n continue\n\n if isinstance(completion, list):\n comp_text = completion[-1][\"content\"] if completion else \"\"\n else:\n comp_text = str(completion)\n\n score = score_action(comp_text, prompt_data)\n rewards.append(score)\n\n return rewards\n\n\nprint(f\"Prompt lookup: {len(prompt_lookup)} unique keys (expect {len(prompts)})\")\n\n# GRPO training config — all values from GPU profile\ntraining_args = GRPOConfig(\n output_dir=\"crisisinbox-grpo-output\",\n num_train_epochs=3,\n per_device_train_batch_size=BATCH_SIZE,\n gradient_accumulation_steps=GRAD_ACCUM,\n learning_rate=5e-6,\n max_completion_length=MAX_COMPLETION_LENGTH,\n max_prompt_length=MAX_PROMPT_LENGTH,\n num_generations=NUM_GENERATIONS,\n logging_steps=10,\n save_steps=100,\n report_to=\"none\",\n bf16=True,\n)\n\ntrainer = GRPOTrainer(\n model=model,\n processing_class=tokenizer,\n reward_funcs=reward_fn,\n args=training_args,\n train_dataset=dataset,\n)\n\nprint(f\"Trainer configured — batch={BATCH_SIZE}, gen={NUM_GENERATIONS}, prompt≤{MAX_PROMPT_LENGTH}tok\")",
68
  "metadata": {},
 
69
  "execution_count": null,
70
  "outputs": []
71
  },
72
  {
73
  "cell_type": "code",
74
- "source": "# Train!\ntrainer.train()\nprint(\"Training complete\")",
75
  "metadata": {},
 
 
 
 
 
76
  "execution_count": null,
77
  "outputs": []
78
  },
79
  {
80
  "cell_type": "markdown",
81
- "source": "## Evaluate Trained Model\n\nSample prompts and check whether the model picks high-urgency messages and produces well-formatted actions.",
82
- "metadata": {}
 
 
 
 
83
  },
84
  {
85
  "cell_type": "code",
86
- "source": "# Evaluate on a few test prompts\nFastLanguageModel.for_inference(model)\n\neval_prompts = random.sample(prompts, min(10, len(prompts)))\ntotal_score = 0\n\nprint(f\"=== Trained Model Evaluation ({MODEL_NAME}) ===\\n\")\nfor p in eval_prompts:\n messages = [{\"role\": \"user\", \"content\": p[\"prompt\"]}]\n inputs = tokenizer.apply_chat_template(messages, return_tensors=\"pt\", add_generation_prompt=True).to(\"cuda\")\n\n with torch.no_grad():\n output = model.generate(inputs, max_new_tokens=MAX_COMPLETION_LENGTH, temperature=0.7, do_sample=True)\n\n completion = tokenizer.decode(output[0][inputs.shape[1]:], skip_special_tokens=True)\n score = score_action(completion, p)\n total_score += score\n\n # Show a summary\n msg_match = re.search(r'(msg_\\d+)', completion)\n chosen_id = msg_match.group(1) if msg_match else \"none\"\n chosen_msg = next((m for m in p[\"messages\"] if m[\"id\"] == chosen_id), None)\n urgency = chosen_msg[\"urgency\"] if chosen_msg else \"?\"\n\n print(f\"Hour {p['hour']:5.1f} | Chose: {chosen_id} ({urgency:8s}) | Score: {score:+.1f}\")\n\nprint(f\"\\nAverage score: {total_score / len(eval_prompts):.2f}\")",
87
  "metadata": {},
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  "execution_count": null,
89
  "outputs": []
90
  },
91
  {
92
  "cell_type": "code",
93
- "source": "# Save the trained model\nmodel.save_pretrained(\"crisisinbox-grpo-trained\")\ntokenizer.save_pretrained(\"crisisinbox-grpo-trained\")\nprint(\"Model saved to crisisinbox-grpo-trained/\")",
94
  "metadata": {},
 
 
 
 
 
 
95
  "execution_count": null,
96
  "outputs": []
97
  }
 
3
  {
4
  "cell_type": "markdown",
5
  "metadata": {},
6
+ "source": [
7
+ "# CrisisInbox GRPO Training\n",
8
+ "\n",
9
+ "Train a small LLM (Qwen2.5-0.5B) to triage crisis inbox messages using Group Relative Policy Optimization.\n",
10
+ "\n",
11
+ "**What this does:**\n",
12
+ "1. Loads pre-generated episode data (inbox snapshots at decision points)\n",
13
+ "2. For each prompt, the model generates an action (which message to handle + response)\n",
14
+ "3. A reward function scores the action based on urgency, deadline, drift adaptation\n",
15
+ "4. GRPO updates the model to prefer higher-reward actions\n",
16
+ "\n",
17
+ "Open in Google Colab with **T4 GPU** runtime."
18
+ ]
19
  },
20
  {
21
  "cell_type": "code",
 
22
  "metadata": {},
23
+ "source": "# Install dependencies\n!pip install unsloth trl transformers datasets accelerate peft -q\n!pip install huggingface_hub -q\nprint(\"Setup complete\")",
24
+ "execution_count": null,
25
  "outputs": []
26
  },
27
  {
28
  "cell_type": "code",
 
29
  "metadata": {},
30
+ "source": [
31
+ "# Avoid logging crash: transformers sometimes passes a Warning type to logger.warning(),\n",
32
+ "# which breaks %-style formatting. Patch so we don't pass that through.\n",
33
+ "import logging\n",
34
+ "import warnings\n",
35
+ "\n",
36
+ "def _patch_transformers_logging():\n",
37
+ " try:\n",
38
+ " import transformers.utils.logging as trans_log\n",
39
+ " _orig = trans_log.logger.warning\n",
40
+ " def _safe_warning(msg, *args, **kwargs):\n",
41
+ " # If first extra arg is a Warning type (e.g. FutureWarning), drop it for % formatting\n",
42
+ " if args and isinstance(args[0], type) and issubclass(args[0], Warning):\n",
43
+ " args = ()\n",
44
+ " return _orig(msg, *args, **kwargs)\n",
45
+ " trans_log.logger.warning = _safe_warning\n",
46
+ " except Exception:\n",
47
+ " pass\n",
48
+ " warnings.filterwarnings(\"ignore\", message=\".*attention mask API.*\", category=FutureWarning)\n",
49
+ "\n",
50
+ "_patch_transformers_logging()"
51
+ ],
52
  "execution_count": null,
53
  "outputs": []
54
  },
55
  {
56
  "cell_type": "code",
57
+ "metadata": {},
58
+ "source": [
59
+ "import torch\n",
60
+ "\n",
61
+ "# Print GPU info (PyTorch uses total_memory, not total_mem)\n",
62
+ "if torch.cuda.is_available():\n",
63
+ " props = torch.cuda.get_device_properties(0)\n",
64
+ " total_bytes = getattr(props, \"total_memory\", None) or getattr(props, \"total_mem\", 0)\n",
65
+ " vram_gb = total_bytes / 1e9 if total_bytes else 0\n",
66
+ " if vram_gb == 0 and hasattr(torch.cuda, \"mem_get_info\"):\n",
67
+ " _, total_bytes = torch.cuda.mem_get_info(0)\n",
68
+ " vram_gb = total_bytes / 1e9\n",
69
+ " print(f\"GPU: {torch.cuda.get_device_name(0)} ({vram_gb:.1f} GB)\")\n",
70
+ "else:\n",
71
+ " print(\"No GPU available.\")"
72
+ ],
73
  "execution_count": null,
74
+ "outputs": []
75
+ },
76
+ {
77
+ "cell_type": "code",
78
  "metadata": {},
79
+ "source": "import json\nimport re\nimport random\nimport os\nfrom datasets import Dataset\nfrom huggingface_hub import hf_hub_download\n\n# Load episodes from HF dataset\nEPISODES_FILE = \".episodes.json\"\nif not os.path.exists(EPISODES_FILE):\n print(\"Downloading episodes from HF...\")\n hf_hub_download(\n repo_id=\"eptan/crisis-inbox-episodes\",\n filename=\"episodes.json\",\n repo_type=\"dataset\",\n local_dir=\".\",\n local_dir_use_symlinks=False,\n )\n os.rename(\"episodes.json\", EPISODES_FILE)\n\nwith open(EPISODES_FILE) as f:\n episodes = json.load(f)\n\n# Flatten to individual training prompts\nprompts = []\nfor ep in episodes:\n for dp in ep[\"decision_points\"]:\n prompts.append({\n \"prompt\": dp[\"prompt\"],\n \"hour\": dp[\"hour\"],\n \"visible_count\": dp[\"visible_count\"],\n \"episode_id\": ep[\"episode_id\"],\n \"seed\": ep[\"seed\"],\n \"drift_events\": ep[\"drift_events\"],\n \"superseded\": ep.get(\"superseded_messages\", {}),\n \"messages\": dp[\"visible_messages\"],\n })\n\nprint(f\"Loaded {len(episodes)} episodes -> {len(prompts)} training prompts\")\nprint(f\"Average {len(prompts)/len(episodes):.1f} decision points per episode\")",
80
+ "execution_count": null,
81
  "outputs": []
82
  },
83
  {
84
  "cell_type": "markdown",
85
+ "metadata": {},
86
+ "source": [
87
+ "## Reward Function\n",
88
+ "\n",
89
+ "Scores agent actions based on:\n",
90
+ "- **Urgency base** (critical=10, high=5, medium=3, low=1)\n",
91
+ "- **Deadline timing** (early=bonus, late=penalty)\n",
92
+ "- **Drift adaptation** (+50% for handling policy-change messages)\n",
93
+ "- **Stale info penalty** (-50% for acting on superseded messages)\n",
94
+ "- **Response quality** (penalty for short/empty responses)"
95
+ ]
96
  },
97
  {
98
  "cell_type": "code",
 
99
  "metadata": {},
100
+ "source": [
101
+ "def score_action(completion: str, prompt_data: dict) -> float:\n",
102
+ " \"\"\"\n",
103
+ " Score a model completion against the inbox state.\n",
104
+ " \n",
105
+ " The model should output: respond_to_message(msg_id, \"response text\")\n",
106
+ " We parse the message_id and response, then score based on the reward function.\n",
107
+ " \"\"\"\n",
108
+ " messages = prompt_data[\"messages\"]\n",
109
+ " hour = prompt_data[\"hour\"]\n",
110
+ " superseded = prompt_data.get(\"superseded\", {})\n",
111
+ " \n",
112
+ " # Parse the model output for message_id\n",
113
+ " msg_id = None\n",
114
+ " response_text = \"\"\n",
115
+ " \n",
116
+ " # Try to parse respond_to_message(msg_id, response)\n",
117
+ " match = re.search(r'respond_to_message\\s*\\(\\s*[\"\\']?(msg_\\d+)[\"\\']?\\s*,\\s*[\"\\'](.+?)[\"\\']', completion, re.DOTALL)\n",
118
+ " if match:\n",
119
+ " msg_id = match.group(1)\n",
120
+ " response_text = match.group(2)\n",
121
+ " else:\n",
122
+ " # Try simpler format: just a message ID mentioned\n",
123
+ " id_match = re.search(r'(msg_\\d+)', completion)\n",
124
+ " if id_match:\n",
125
+ " msg_id = id_match.group(1)\n",
126
+ " response_text = completion\n",
127
+ " \n",
128
+ " if not msg_id:\n",
129
+ " return -1.0 # couldn't parse any action\n",
130
+ " \n",
131
+ " # Find the message in the inbox\n",
132
+ " target_msg = None\n",
133
+ " for msg in messages:\n",
134
+ " if msg[\"id\"] == msg_id:\n",
135
+ " target_msg = msg\n",
136
+ " break\n",
137
+ " \n",
138
+ " if target_msg is None:\n",
139
+ " return -0.5 # referenced a message not in inbox\n",
140
+ " \n",
141
+ " # Base reward by urgency\n",
142
+ " urgency_rewards = {\"critical\": 10.0, \"high\": 5.0, \"medium\": 3.0, \"low\": 1.0}\n",
143
+ " reward = urgency_rewards.get(target_msg[\"urgency\"], 1.0)\n",
144
+ " \n",
145
+ " # Deadline timing\n",
146
+ " deadline = target_msg.get(\"deadline_hours\")\n",
147
+ " if deadline is not None:\n",
148
+ " if hour <= deadline:\n",
149
+ " time_remaining_frac = (deadline - hour) / max(deadline, 1.0)\n",
150
+ " reward *= 1.0 + 0.5 * time_remaining_frac\n",
151
+ " else:\n",
152
+ " reward *= 0.25 # late penalty\n",
153
+ " \n",
154
+ " # Response quality\n",
155
+ " if len(response_text.strip()) < 10:\n",
156
+ " reward *= 0.5\n",
157
+ " \n",
158
+ " # Drift adaptation bonus\n",
159
+ " if target_msg.get(\"drift_flag\"):\n",
160
+ " reward *= 1.5\n",
161
+ " \n",
162
+ " # Stale info penalty\n",
163
+ " if target_msg[\"id\"] in superseded:\n",
164
+ " reward *= 0.5\n",
165
+ " \n",
166
+ " # Bonus: penalize choosing low-urgency when critical exists\n",
167
+ " has_critical = any(m[\"urgency\"] == \"critical\" for m in messages)\n",
168
+ " if has_critical and target_msg[\"urgency\"] in (\"low\", \"medium\"):\n",
169
+ " reward *= 0.3 # strong penalty for ignoring critical messages\n",
170
+ " \n",
171
+ " return round(reward, 2)\n",
172
+ "\n",
173
+ "\n",
174
+ "# Test the reward function\n",
175
+ "test_data = prompts[0]\n",
176
+ "print(\"Testing reward function on first decision point:\")\n",
177
+ "print(f\" Hour: {test_data['hour']}, Messages: {test_data['visible_count']}\")\n",
178
+ "\n",
179
+ "# Simulate good action (pick critical message)\n",
180
+ "critical_msgs = [m for m in test_data[\"messages\"] if m[\"urgency\"] == \"critical\"]\n",
181
+ "if critical_msgs:\n",
182
+ " good_action = f'respond_to_message(\"{critical_msgs[0][\"id\"]}\", \"Acknowledged. Evacuating immediately with documents and medication.\")'\n",
183
+ " good_score = score_action(good_action, test_data)\n",
184
+ " print(f\" Good action (critical msg): {good_score:.2f} pts\")\n",
185
+ "\n",
186
+ "# Simulate bad action (pick low-urgency message)\n",
187
+ "low_msgs = [m for m in test_data[\"messages\"] if m[\"urgency\"] == \"low\"]\n",
188
+ "if low_msgs:\n",
189
+ " bad_action = f'respond_to_message(\"{low_msgs[0][\"id\"]}\", \"ok\")'\n",
190
+ " bad_score = score_action(bad_action, test_data)\n",
191
+ " print(f\" Bad action (low msg, short response): {bad_score:.2f} pts\")\n",
192
+ "\n",
193
+ "# Simulate unparseable action\n",
194
+ "junk_score = score_action(\"I think we should do something\", test_data)\n",
195
+ "print(f\" Unparseable action: {junk_score:.2f} pts\")"
196
+ ],
197
  "execution_count": null,
198
  "outputs": []
199
  },
200
  {
201
  "cell_type": "markdown",
202
+ "metadata": {},
203
+ "source": [
204
+ "## Load Model & Configure GRPO"
205
+ ]
206
  },
207
  {
208
  "cell_type": "code",
 
209
  "metadata": {},
210
+ "source": [
211
+ "from unsloth import FastLanguageModel\n",
212
+ "import torch\n",
213
+ "\n",
214
+ "# Load Qwen2.5-0.5B — small enough for T4 GPU\n",
215
+ "# Use a longer context window so prompt + completion\n",
216
+ "# comfortably fit without attention mask shape issues.\n",
217
+ "model, tokenizer = FastLanguageModel.from_pretrained(\n",
218
+ " model_name=\"unsloth/Qwen2.5-0.5B-Instruct\",\n",
219
+ " max_seq_length=4096,\n",
220
+ " dtype=None,\n",
221
+ " load_in_4bit=True,\n",
222
+ ")\n",
223
+ "\n",
224
+ "# Add LoRA adapters for efficient fine-tuning\n",
225
+ "model = FastLanguageModel.get_peft_model(\n",
226
+ " model,\n",
227
+ " r=16,\n",
228
+ " target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
229
+ " \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
230
+ " lora_alpha=16,\n",
231
+ " lora_dropout=0,\n",
232
+ " bias=\"none\",\n",
233
+ " use_gradient_checkpointing=\"unsloth\",\n",
234
+ ")\n",
235
+ "\n",
236
+ "# GRPO expects left-padding so completion positions align across the batch\n",
237
+ "# (avoids completion_mask vs log-probs shape mismatch in masked_batch_mean).\n",
238
+ "if tokenizer.pad_token_id is None:\n",
239
+ " tokenizer.pad_token_id = tokenizer.eos_token_id\n",
240
+ "tokenizer.padding_side = \"left\"\n",
241
+ "\n",
242
+ "print(\"Model loaded with LoRA adapters\")"
243
+ ],
244
  "execution_count": null,
245
  "outputs": []
246
  },
247
  {
248
  "cell_type": "code",
 
249
  "metadata": {},
250
+ "source": [
251
+ "# Build the training dataset\n",
252
+ "# Each row needs a \"prompt\" field formatted as chat messages.\n",
253
+ "# Use a conservative max length so every batch has identical shape (avoids mask mismatch).\n",
254
+ "MAX_PROMPT_LENGTH = 1024 # must match GRPOConfig max_prompt_length below\n",
255
+ "\n",
256
+ "train_data = []\n",
257
+ "for p in prompts:\n",
258
+ " msgs = [{\"role\": \"user\", \"content\": p[\"prompt\"]}]\n",
259
+ " tok = tokenizer.apply_chat_template(msgs, return_tensors=\"pt\", add_generation_prompt=True)\n",
260
+ " # apply_chat_template can return a tensor or BatchEncoding. Do NOT use hasattr(tok, \"shape\")\n",
261
+ " # (BatchEncoding.__getattr__ raises when attribute is missing). Use dict-like check instead.\n",
262
+ " try:\n",
263
+ " ids = tok[\"input_ids\"]\n",
264
+ " except (TypeError, KeyError):\n",
265
+ " ids = tok\n",
266
+ " n_tokens = ids.shape[1] if ids.dim() > 1 else ids.shape[0]\n",
267
+ " if n_tokens > MAX_PROMPT_LENGTH:\n",
268
+ " continue # skip overlong prompts so batch shapes stay consistent\n",
269
+ " train_data.append({\n",
270
+ " \"prompt\": msgs,\n",
271
+ " \"_hour\": p[\"hour\"],\n",
272
+ " \"_episode_id\": p[\"episode_id\"],\n",
273
+ " })\n",
274
+ "\n",
275
+ "# Shuffle deterministically (seeded); no train/eval split is made here\n",
276
+ "random.seed(42)\n",
277
+ "random.shuffle(train_data)\n",
278
+ "\n",
279
+ "dataset = Dataset.from_list(train_data)\n",
280
+ "print(f\"Training dataset: {len(dataset)} prompts (after dropping prompts > {MAX_PROMPT_LENGTH} tokens)\")\n",
281
+ "print(f\"Sample prompt length: {len(train_data[0]['prompt'][0]['content'])} chars\")"
282
+ ],
283
  "execution_count": null,
284
  "outputs": []
285
  },
286
  {
287
  "cell_type": "markdown",
288
+ "metadata": {},
289
+ "source": [
290
+ "## GRPO Training Loop\n",
291
+ "\n",
292
+ "The reward function scores each completion by:\n",
293
+ "1. Parsing which message the model chose to handle\n",
294
+ "2. Checking urgency, deadline timing, drift flags\n",
295
+ "3. Penalizing bad choices (low-urgency when critical exists, stale info)"
296
+ ]
297
  },
298
  {
299
  "cell_type": "code",
 
300
  "metadata": {},
301
+ "source": "from trl import GRPOConfig, GRPOTrainer\n\n# Build a lookup from (episode_id, hour) -> prompt metadata for reward scoring\nprompt_lookup = {}\nfor p in prompts:\n key = (p[\"episode_id\"], p[\"hour\"])\n prompt_lookup[key] = p\n\n\ndef reward_fn(prompts, completions, _episode_id=None, _hour=None, **kwargs):\n \"\"\"GRPO reward function. Scores each completion against its inbox state.\n TRL passes extra dataset columns as keyword args.\"\"\"\n rewards = []\n for i, (prompt_msgs, completion) in enumerate(zip(prompts, completions)):\n # Look up prompt data by (episode_id, hour) from dataset columns\n prompt_data = None\n if _episode_id is not None and _hour is not None:\n ep_id = _episode_id[i] if hasattr(_episode_id, '__getitem__') else _episode_id\n hour = _hour[i] if hasattr(_hour, '__getitem__') else _hour\n # Convert tensor/numpy to Python scalar if needed\n if hasattr(hour, 'item'):\n hour = hour.item()\n prompt_data = prompt_lookup.get((ep_id, hour))\n\n if prompt_data is None:\n rewards.append(0.0)\n continue\n\n # Extract completion text (trainer may pass token ids or message dicts)\n if isinstance(completion, list):\n if completion and isinstance(completion[0], (int, float)):\n comp_text = tokenizer.decode(completion, skip_special_tokens=True)\n else:\n comp_text = completion[-1].get(\"content\", \"\") if completion else \"\"\n else:\n comp_text = str(completion)\n\n score = score_action(comp_text, prompt_data)\n rewards.append(score)\n\n return rewards\n\n\n# GRPO training config: conservative batch/length to avoid mask shape mismatch.\ntraining_args = GRPOConfig(\n output_dir=\"crisisinbox-grpo-output\",\n num_train_epochs=3,\n per_device_train_batch_size=1,\n gradient_accumulation_steps=4,\n steps_per_generation=4,\n learning_rate=5e-6,\n max_completion_length=256,\n max_prompt_length=1024,\n num_generations=2,\n logging_steps=10,\n save_steps=100,\n report_to=\"none\",\n bf16=False,\n)\n\ntrainer = GRPOTrainer(\n model=model,\n 
processing_class=tokenizer,\n reward_funcs=reward_fn,\n args=training_args,\n train_dataset=dataset,\n)\n\nprint(f\"Trainer configured — {len(prompt_lookup)} unique (episode_id, hour) keys\")\nprint(\"Ready to train\")",
302
  "execution_count": null,
303
  "outputs": []
304
  },
305
  {
306
  "cell_type": "code",
 
307
  "metadata": {},
308
+ "source": [
309
+ "# Train!\n",
310
+ "trainer.train()\n",
311
+ "print(\"Training complete\")"
312
+ ],
313
  "execution_count": null,
314
  "outputs": []
315
  },
316
  {
317
  "cell_type": "markdown",
318
+ "metadata": {},
319
+ "source": [
320
+ "## Evaluate: Before vs After\n",
321
+ "\n",
322
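+ "Compare the trained model's action choices against the base model on the same prompts. The cell below scores only the trained model, so the base-model baseline has to be collected separately."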
323
+ ]
324
  },
325
  {
326
  "cell_type": "code",
 
327
  "metadata": {},
328
+ "source": [
329
+ "# Evaluate on up to 10 prompts sampled at random from `prompts` (the same pool the training set was built from)\n",
330
+ "FastLanguageModel.for_inference(model)\n",
331
+ "\n",
332
+ "eval_prompts = random.sample(prompts, min(10, len(prompts)))\n",
333
+ "total_score = 0\n",
334
+ "\n",
335
+ "print(\"=== Trained Model Evaluation ===\\n\")\n",
336
+ "for p in eval_prompts:\n",
337
+ " messages = [{\"role\": \"user\", \"content\": p[\"prompt\"]}]\n",
338
+ " raw = tokenizer.apply_chat_template(messages, return_tensors=\"pt\", add_generation_prompt=True)\n",
339
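+ " # Do NOT use hasattr(raw, \"shape\") — BatchEncoding.__getattr__ raises. Use try/except: the dict branch handles a BatchEncoding, the fallback handles a bare tensor. NOTE(review): in the dict branch `inputs` is a dict, so the generate call below likely needs `**inputs` — confirm.\n",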
340
+ " try:\n",
341
+ " inputs = {k: v.to(\"cuda\") for k, v in raw.items()}\n",
342
+ " prompt_len = inputs[\"input_ids\"].shape[1]\n",
343
+ " except (TypeError, AttributeError):\n",
344
+ " inputs = raw.to(\"cuda\")\n",
345
+ " prompt_len = inputs.shape[1]\n",
346
+ "\n",
347
+ " with torch.no_grad():\n",
348
+ " output = model.generate(inputs, max_new_tokens=200, temperature=0.7, do_sample=True)\n",
349
+ "\n",
350
+ " completion = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)\n",
351
+ " score = score_action(completion, p)\n",
352
+ " total_score += score\n",
353
+ "\n",
354
+ " # Show a summary\n",
355
+ " msg_match = re.search(r'(msg_\\d+)', completion)\n",
356
+ " chosen_id = msg_match.group(1) if msg_match else \"none\"\n",
357
+ " chosen_msg = next((m for m in p[\"messages\"] if m[\"id\"] == chosen_id), None)\n",
358
+ " urgency = chosen_msg[\"urgency\"] if chosen_msg else \"?\"\n",
359
+ "\n",
360
+ " print(f\"Hour {p['hour']:5.1f} | Chose: {chosen_id} ({urgency:8s}) | Score: {score:+.1f}\")\n",
361
+ "\n",
362
+ "print(f\"\\nAverage score: {total_score / len(eval_prompts):.2f}\")"
363
+ ],
364
  "execution_count": null,
365
  "outputs": []
366
  },
367
  {
368
  "cell_type": "code",
 
369
  "metadata": {},
370
+ "source": [
371
+ "# Save the trained model\n",
372
+ "model.save_pretrained(\"crisisinbox-grpo-trained\")\n",
373
+ "tokenizer.save_pretrained(\"crisisinbox-grpo-trained\")\n",
374
+ "print(\"Model saved to crisisinbox-grpo-trained/\")"
375
+ ],
376
  "execution_count": null,
377
  "outputs": []
378
  }
server/app.py CHANGED
@@ -23,7 +23,17 @@ except ImportError:
23
 
24
 
25
  class MCPAction(Action):
26
- """Action class that deserializes both ListToolsAction and CallToolAction."""
 
 
 
 
 
 
 
 
 
 
27
 
28
  model_config = Action.model_config.copy()
29
  model_config["extra"] = "allow"
 
23
 
24
 
25
  class MCPAction(Action):
26
+ """Action class that deserializes both ListToolsAction and CallToolAction.
27
+
28
+ OpenEnv 0.2.1's WS handler passes a single action_cls to
29
+ deserialize_action(), but MCPToolClient sends both ListToolsAction
30
+ and CallToolAction through the "step" message path. Since the base
31
+ Action model uses extra="forbid", a fixed action_cls can't handle
32
+ both shapes. We override model_validate to route by the "type" field
33
+ so that MCPEnvironment.step() receives the correct Action subclass.
34
+ extra="allow" is needed because the two action types have different
35
+ field sets.
36
+ """
37
 
38
  model_config = Action.model_config.copy()
39
  model_config["extra"] = "allow"
server/crisis_inbox_environment.py CHANGED
@@ -34,6 +34,9 @@ class CrisisInboxEnvironment(MCPEnvironment):
34
  """
35
  Simulates a 48-hour post-disaster inbox triage scenario.
36
 
 
 
 
37
  The agent receives messages from family, employers, government agencies,
38
  insurance companies, and service providers. It must prioritize safety,
39
  meet deadlines, and adapt to changing rules (schema drift).
 
34
  """
35
  Simulates a 48-hour post-disaster inbox triage scenario.
36
 
37
+ Note: SUPPORTS_CONCURRENT_SESSIONS is False (default) because the
38
+ environment holds mutable per-episode state (_all_messages, _handled, etc).
39
+
40
  The agent receives messages from family, employers, government agencies,
41
  insurance companies, and service providers. It must prioritize safety,
42
  meet deadlines, and adapt to changing rules (schema drift).
training/crisisinbox_training.py CHANGED
@@ -4,8 +4,7 @@ Person B: ML Pipeline
4
 
5
  Run this in Google Colab:
6
  1. Upload this file
7
- 2. Upload episodes.json from repo
8
- 3. Run: python crisisinbox_training.py
9
  """
10
 
11
  import torch
@@ -16,12 +15,14 @@ from datasets import Dataset
16
  from unsloth import FastLanguageModel
17
  from trl import GRPOConfig, GRPOTrainer
18
 
19
- # Download episodes from GitHub repo
20
  print("Loading episodes...")
21
- import urllib.request
22
- urllib.request.urlretrieve(
23
- "https://raw.githubusercontent.com/eptan/crisis-inbox/main/episodes.json",
24
- "episodes.json"
 
 
25
  )
26
 
27
  with open("episodes.json", "r") as f:
@@ -33,6 +34,9 @@ print(f"✓ Loaded {len(EPISODES)} episodes")
33
  # PROMPT BUILDING
34
  # =============================================================================
35
 
 
 
 
36
  CRISIS_SYSTEM_PROMPT = """
37
  You are an assistant helping a working parent during a wildfire.
38
  You must triage messages, act on safety-critical items first,
@@ -56,10 +60,17 @@ def build_crisis_prompt(episode):
56
  msgs_str.append(
57
  f"[t={m['time']}h] {urgency} From {m['sender']} via {m['channel']}: {m['content']}{deadline_info}"
58
  )
59
-
 
 
 
 
60
  drift_str = []
61
  for d in episode.get("schema_events", []):
62
  drift_str.append(f"[t={d['time']}h] POLICY UPDATE: {d['kind']} -> {d.get('new_value', 'changed')}")
 
 
 
63
 
64
  user_content = (
65
  "Here is your 48-hour message history:\n\n"
@@ -80,22 +91,34 @@ def build_crisis_prompt(episode):
80
 
81
  def parse_plan(model_output):
82
  """Parse <plan> tag output into list of action dicts."""
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  actions = []
84
-
85
- plan_match = re.search(r'<plan>(.*?)</plan>', model_output, re.DOTALL | re.IGNORECASE)
86
- if not plan_match:
 
87
  return []
88
-
89
- plan_content = plan_match.group(1).strip()
90
-
91
- lines = plan_content.split('\n')
92
  for line in lines:
93
  line = line.strip()
94
  if not line or not line[0].isdigit():
95
  continue
96
 
97
  # Extract time: [time=min X]
98
- time_match = re.search(r'time=min (\d+)', line)
99
  time_min = int(time_match.group(1)) if time_match else 0
100
 
101
  # Extract action description
@@ -292,7 +315,9 @@ MODEL_NAME = "unsloth/Qwen2.5-0.5B-Instruct"
292
 
293
  model, tokenizer = FastLanguageModel.from_pretrained(
294
  model_name=MODEL_NAME,
295
- max_seq_length=2048,
 
 
296
  dtype=None,
297
  load_in_4bit=True,
298
  )
@@ -355,7 +380,8 @@ training_args = GRPOConfig(
355
  per_device_train_batch_size=2,
356
  gradient_accumulation_steps=4,
357
  num_generations=4,
358
- max_completion_length=512,
 
359
  temperature=0.7,
360
  learning_rate=1e-5,
361
  logging_steps=10,
 
4
 
5
  Run this in Google Colab:
6
  1. Upload this file
7
+ 2. Run: python crisisinbox_training.py
 
8
  """
9
 
10
  import torch
 
15
  from unsloth import FastLanguageModel
16
  from trl import GRPOConfig, GRPOTrainer
17
 
18
+ # Download episodes from HF dataset
19
  print("Loading episodes...")
20
+ from huggingface_hub import hf_hub_download
21
+ hf_hub_download(
22
+ repo_id="eptan/crisis-inbox-episodes",
23
+ filename="episodes.json",
24
+ repo_type="dataset",
25
+ local_dir=".",
26
  )
27
 
28
  with open("episodes.json", "r") as f:
 
34
  # PROMPT BUILDING
35
  # =============================================================================
36
 
37
+ MAX_MESSAGES = 40
38
+ MAX_DRIFT_EVENTS = 20
39
+
40
  CRISIS_SYSTEM_PROMPT = """
41
  You are an assistant helping a working parent during a wildfire.
42
  You must triage messages, act on safety-critical items first,
 
60
  msgs_str.append(
61
  f"[t={m['time']}h] {urgency} From {m['sender']} via {m['channel']}: {m['content']}{deadline_info}"
62
  )
63
+
64
+ # Keep only the most recent messages to avoid overlong sequences.
65
+ if len(msgs_str) > MAX_MESSAGES:
66
+ msgs_str = msgs_str[-MAX_MESSAGES:]
67
+
68
  drift_str = []
69
  for d in episode.get("schema_events", []):
70
  drift_str.append(f"[t={d['time']}h] POLICY UPDATE: {d['kind']} -> {d.get('new_value', 'changed')}")
71
+
72
+ if len(drift_str) > MAX_DRIFT_EVENTS:
73
+ drift_str = drift_str[-MAX_DRIFT_EVENTS:]
74
 
75
  user_content = (
76
  "Here is your 48-hour message history:\n\n"
 
91
 
92
  def parse_plan(model_output):
93
  """Parse <plan> tag output into list of action dicts."""
94
+ if model_output is None:
95
+ return []
96
+
97
+ # TRL/Unsloth can return completions as lists (token ids, strings, or message dicts).
98
+ # Normalize to a single string before regex parsing.
99
+ if isinstance(model_output, list):
100
+ if model_output and isinstance(model_output[0], dict) and "content" in model_output[0]:
101
+ model_output = "\n".join(str(m.get("content", "")) for m in model_output)
102
+ else:
103
+ model_output = "\n".join(map(str, model_output))
104
+ else:
105
+ model_output = str(model_output)
106
+
107
  actions = []
108
+
109
+ plan_match = re.search(r"<plan>(.*?)</plan>", model_output, re.DOTALL | re.IGNORECASE)
110
+ plan_content = plan_match.group(1).strip() if plan_match else model_output.strip()
111
+ if not plan_content:
112
  return []
113
+
114
+ lines = plan_content.split("\n")
 
 
115
  for line in lines:
116
  line = line.strip()
117
  if not line or not line[0].isdigit():
118
  continue
119
 
120
  # Extract time: [time=min X]
121
+ time_match = re.search(r"time=min (\d+)", line, re.IGNORECASE)
122
  time_min = int(time_match.group(1)) if time_match else 0
123
 
124
  # Extract action description
 
315
 
316
  model, tokenizer = FastLanguageModel.from_pretrained(
317
  model_name=MODEL_NAME,
318
+ # Allow longer combined prompt + completion to avoid
319
+ # attention mask shape mismatches during training.
320
+ max_seq_length=4096,
321
  dtype=None,
322
  load_in_4bit=True,
323
  )
 
380
  per_device_train_batch_size=2,
381
  gradient_accumulation_steps=4,
382
  num_generations=4,
383
+ # Keep completions modest so prompt+completion stay well within max_seq_length.
384
+ max_completion_length=256,
385
  temperature=0.7,
386
  learning_rate=1e-5,
387
  logging_steps=10,