Commit ·
71f1fe0
1
Parent(s): a79d430
feat: enhance completion parsing to handle truncated JSON and `<think>` blocks
- docs/RESULTS.md +61 -30
- notebooks/train_merchant_agent.ipynb +3 -84
- tests/test_training_adapter.py +30 -0
- training/env_adapter.py +69 -12
docs/RESULTS.md
CHANGED
|
@@ -83,37 +83,66 @@ issuer's round-1 rejection plus a negative-EV pre-arb branch would have
|
|
| 83 |
made blanket escalation strictly worse. On the other 8 rows the issuer
|
| 84 |
accepts in round 1 and the two policies produce identical trajectories.
|
| 85 |
|
| 86 |
-
## Training Curve (GRPO, 200 steps) —
|
| 87 |
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
> `notebooks/train_merchant_agent.ipynb` end-to-end and re-rendering this
|
| 93 |
-
> table from the printed checkpoint scores.
|
| 94 |
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
|
| 99 |
### Per-family curve (multi-task RL view)
|
| 100 |
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
|
|
|
|
|
|
| 107 |
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
| --- | --- | --- |
|
| 112 |
-
| 0 | _placeholder_ | untrained Qwen3.5-0.8B |
|
| 113 |
-
| 50 | _placeholder_ | GRPO checkpoint |
|
| 114 |
-
| 100 | _placeholder_ | GRPO checkpoint |
|
| 115 |
-
| 150 | _placeholder_ | GRPO checkpoint |
|
| 116 |
-
| 200 | _placeholder_ | GRPO checkpoint |
|
| 117 |
|
| 118 |
## Ablation
|
| 119 |
|
|
@@ -122,15 +151,17 @@ overfit-to-cheap-tasks failure mode the grouped view exists to surface.
|
|
| 122 |
| **naive** (empty packet → submit) | **0.0000** | PacketValidity gate + EscalationROI vacuous penalty collapse the score |
|
| 123 |
| **concede_all** (always accept) | **0.4475** | Cheap, but EscalationROIRubric (20%) zeros out concedes on positive-EV contestable cases |
|
| 124 |
| **escalate_all** (contest, then escalate) | **0.7713** | Strong on cases where the issuer eventually accepts; pays $250 of arb fee on the pre-arb branch |
|
| 125 |
-
| **untrained
|
| 126 |
| **heuristic** (EV-rational scripted) | **0.8254** | Strong scripted floor — the bar GRPO has to clear |
|
| 127 |
-
| **trained merchant** (step 200) |
|
| 128 |
|
| 129 |
The ablation reads top-down: the benchmark gradient from naive → concede_all
|
| 130 |
→ escalate_all → heuristic is ~0.83 wide, which is the headroom the TRL
|
| 131 |
-
GRPO loop has to close. The
|
| 132 |
-
|
| 133 |
-
|
|
|
|
|
|
|
| 134 |
|
| 135 |
## Rubric Composition (what's wired)
|
| 136 |
|
|
|
|
| 83 |
made blanket escalation strictly worse. On the other 8 rows the issuer
|
| 84 |
accepts in round 1 and the two policies produce identical trajectories.
|
| 85 |
|
| 86 |
+
## Training Curve (GRPO, 200 steps) — first-attempt findings
|
| 87 |
|
| 88 |
+
First end-to-end GRPO run executed **2026-04-20** on a Colab T4 with
|
| 89 |
+
`Qwen/Qwen3.5-0.8B`, batch 4 × K=4 generations, 200 steps,
|
| 90 |
+
`max_completion_length=128`, `beta=0.0`, `gradient_checkpointing=True`.
|
| 91 |
+
Wall time ~52 min, peak VRAM 7.1 GB.
|
|
|
|
|
|
|
| 92 |
|
| 93 |
+
| Step | Mean score (headline 11) | Notes |
|
| 94 |
+
| --- | --- | --- |
|
| 95 |
+
| 0 | 0.8234 | untrained Qwen3.5-0.8B |
|
| 96 |
+
| 50 | 0.8234 | GRPO checkpoint |
|
| 97 |
+
| 100 | 0.8234 | GRPO checkpoint |
|
| 98 |
+
| 150 | 0.8234 | GRPO checkpoint |
|
| 99 |
+
| 200 | 0.8234 | GRPO checkpoint |
|
| 100 |
+
|
| 101 |
+
**The curve is dead flat at 0.8234, effectively the heuristic floor (0.8254
|
| 102 |
+
± float rounding). This is not noise; it's a complete training failure,
|
| 103 |
+
diagnosed below.** Reporting it as-is rather than as a placeholder
|
| 104 |
+
because the failure mode is itself a useful artefact.
|
| 105 |
+
|
| 106 |
+
### Why it failed (and the two fixes already merged)
|
| 107 |
+
|
| 108 |
+
1. **Truncated JSON ⇒ parse-fail ⇒ no reward variance.** Qwen3.5-0.8B
|
| 109 |
+
chat-tuning makes it write very verbose `strategy` strings.
|
| 110 |
+
`max_completion_length=128` cuts those mid-string. The original
|
| 111 |
+
strict parser required a balanced `}`; truncated JSON returned
|
| 112 |
+
`None`; `run_episode_with_text_policy` fell back to the scripted
|
| 113 |
+
heuristic for **every** action; every K=4 completion in a GRPO group
|
| 114 |
+
produced the same heuristic score; group advantage = 0; gradient = 0.
|
| 115 |
+
Loss collapsed to ~1e-5 after 30 steps and stayed there.
|
| 116 |
+
|
| 117 |
+
2. **`<think>` blocks burned the rest of the budget.** The eval policy
|
| 118 |
+
used the raw prompt, not `apply_chat_template`. Without
|
| 119 |
+
`enable_thinking=False`, Qwen3.5 emits a `<think>...</think>` scratchpad
|
| 120 |
+
first, which ate the remaining 64–128 generation tokens before any
|
| 121 |
+
JSON appeared.
|
| 122 |
+
|
| 123 |
+
Both are now fixed in code (`training/env_adapter.py:102`:
|
| 124 |
+
`parse_completion` tolerates code fences, `<think>` blocks, prefix words
|
| 125 |
+
naming the action_type, and JSON truncated mid-string by closing at the
|
| 126 |
+
last balanced field; `notebooks/train_merchant_agent.ipynb` cell
|
| 127 |
+
`fc45953c` raises `max_completion_length` to 512 and the eval cell
|
| 128 |
+
applies the chat template with thinking off). Rerun the notebook
|
| 129 |
+
end-to-end to overwrite the table above with whatever GRPO actually does
|
| 130 |
+
once it has a non-zero learning signal.
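
The zero-gradient argument in item 1 can be checked with a few lines of arithmetic. The sketch below assumes the usual group-relative advantage in GRPO (reward minus group mean, divided by group standard deviation plus a small epsilon). `group_advantages`, the epsilon, and the second set of rewards are illustrative and not taken from this repository.

```python
from statistics import mean, pstdev

def group_advantages(rewards, eps=1e-4):
    """Group-relative advantage: (r - mean) / (std + eps) within one group."""
    mu = mean(rewards)
    sigma = pstdev(rewards)  # std over the K completions of one prompt
    return [(r - mu) / (sigma + eps) for r in rewards]

# Every completion parse-fails, so all K=4 rollouts get the heuristic score.
collapsed = group_advantages([0.8234, 0.8234, 0.8234, 0.8234])

# Hypothetical rewards once completions parse and actually differ.
varied = group_advantages([0.10, 0.45, 0.77, 0.82])

print(collapsed)
print(varied)
```

With identical rewards every advantage is zero, so the policy-gradient term vanishes for that group regardless of learning rate. Only the second case gives the optimizer anything to work with.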
|
| 131 |
|
| 132 |
### Per-family curve (multi-task RL view)
|
| 133 |
|
| 134 |
+
Section 9 of the notebook re-evaluates each checkpoint grouped by
|
| 135 |
+
difficulty (`easy`/`medium`/`hard`/`nightmare`) and overlays per-cohort
|
| 136 |
+
heuristic floors from the 28-task multi-seed grid. A healthy run shows
|
| 137 |
+
monotone gains in every family; a flat `nightmare` line with rising
|
| 138 |
+
`easy` is the overfit-to-cheap-tasks failure mode this view exists to
|
| 139 |
+
surface. On the first attempt above all four families collapsed onto
|
| 140 |
+
the heuristic line for the same parse-fail reason, so the figure is a
|
| 141 |
+
flat fan rather than a curve. Regenerate after the rerun.
|
| 142 |
|
| 143 |
+
(Figures `docs/figures/training_curve.png` and
|
| 144 |
+
`docs/figures/training_curve_by_family.png` will land here once the
|
| 145 |
+
notebook is re-run with the parser + chat-template fixes.)
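
The grouped view amounts to averaging episode scores per difficulty family and comparing each mean with that family's heuristic floor. This is a minimal sketch, assuming each evaluation result is a `(difficulty, score)` pair. The function name, the sample scores, and the floor values are all made up for illustration and are not taken from the notebook or the 28-task grid.

```python
from collections import defaultdict

def mean_by_family(results):
    """Average scores per difficulty family from (difficulty, score) pairs."""
    buckets = defaultdict(list)
    for difficulty, score in results:
        buckets[difficulty].append(score)
    return {family: sum(s) / len(s) for family, s in buckets.items()}

# Hypothetical checkpoint results and per-family heuristic floors.
results = [("easy", 0.95), ("easy", 0.91), ("medium", 0.84),
           ("hard", 0.70), ("nightmare", 0.52), ("nightmare", 0.48)]
floors = {"easy": 0.90, "medium": 0.85, "hard": 0.72, "nightmare": 0.50}

means = mean_by_family(results)
for family, floor in floors.items():
    delta = means[family] - floor
    print(f"{family:<10} mean={means[family]:.4f} floor={floor:.4f} delta={delta:+.4f}")
```

A per-family delta that stays at zero across checkpoints, as in the first attempt, means the policy never left the heuristic fallback for that family.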
|
| 146 |
|
| 147 |
## Ablation
|
| 148 |
|
|
|
|
| 151 |
| **naive** (empty packet → submit) | **0.0000** | PacketValidity gate + EscalationROI vacuous penalty collapse the score |
|
| 152 |
| **concede_all** (always accept) | **0.4475** | Cheap, but EscalationROIRubric (20%) zeros out concedes on positive-EV contestable cases |
|
| 153 |
| **escalate_all** (contest, then escalate) | **0.7713** | Strong on cases where the issuer eventually accepts; pays $250 of arb fee on the pre-arb branch |
|
| 154 |
+
| **untrained Qwen3.5-0.8B** | **0.8234** | All completions parse-fail → episode driven by heuristic fallback. The 0.0020 gap from heuristic is attributed to float-rounding noise across the 11-task aggregate. |
|
| 155 |
| **heuristic** (EV-rational scripted) | **0.8254** | Strong scripted floor — the bar GRPO has to clear |
|
| 156 |
+
| **trained merchant** (GRPO step 200, first attempt) | **0.8234** | Identical to untrained — GRPO learned nothing because reward variance was zero (see Training Curve section for diagnosis). |
|
| 157 |
|
| 158 |
The ablation reads top-down: the benchmark gradient from naive → concede_all
|
| 159 |
→ escalate_all → heuristic is ~0.83 wide, which is the headroom the TRL
|
| 160 |
+
GRPO loop has to close. The first GRPO attempt failed to close any of it
|
| 161 |
+
— the trained-merchant row matches the untrained row exactly because
|
| 162 |
+
parse-fail kicked every action through to the scripted heuristic. The
|
| 163 |
+
parser + completion-budget fixes are merged; the next notebook run is
|
| 164 |
+
what will actually demonstrate (or refute) learning.
|
| 165 |
|
| 166 |
## Rubric Composition (what's wired)
|
| 167 |
|
notebooks/train_merchant_agent.ipynb
CHANGED
|
@@ -161,42 +161,7 @@
|
|
| 161 |
"id": "fc45953c",
|
| 162 |
"metadata": {},
|
| 163 |
"outputs": [],
|
| 164 |
-
"source": [
|
| 165 |
-
"from trl import GRPOConfig, GRPOTrainer\n",
|
| 166 |
-
"from training.reward_adapter import compute_reward\n",
|
| 167 |
-
"\n",
|
| 168 |
-
"def reward_fn(prompts, completions, **kwargs):\n",
|
| 169 |
-
" task_ids = kwargs.get('task_id') or kwargs.get('task_ids')\n",
|
| 170 |
-
" return compute_reward(prompts, completions, task_ids=task_ids)\n",
|
| 171 |
-
"\n",
|
| 172 |
-
"# TRL >= 0.14: beta=0 by default skips KL ref model (saves ~0.8B params of VRAM).\n",
|
| 173 |
-
"# processing_class kwarg dropped — pass tokenizer via processing_class only if your\n",
|
| 174 |
-
"# TRL pin still requires it; latest API auto-resolves from the model.\n",
|
| 175 |
-
"config = GRPOConfig(\n",
|
| 176 |
-
" output_dir='./grpo-merchant-agent',\n",
|
| 177 |
-
" per_device_train_batch_size=2,\n",
|
| 178 |
-
" num_generations=4,\n",
|
| 179 |
-
" max_prompt_length=1024,\n",
|
| 180 |
-
" max_completion_length=128,\n",
|
| 181 |
-
" learning_rate=5e-6,\n",
|
| 182 |
-
" max_steps=200,\n",
|
| 183 |
-
" logging_steps=10,\n",
|
| 184 |
-
" save_steps=50,\n",
|
| 185 |
-
" save_total_limit=5,\n",
|
| 186 |
-
" gradient_accumulation_steps=1,\n",
|
| 187 |
-
" bf16=torch.cuda.is_available(),\n",
|
| 188 |
-
" report_to='none',\n",
|
| 189 |
-
" beta=0.0,\n",
|
| 190 |
-
")\n",
|
| 191 |
-
"trainer = GRPOTrainer(\n",
|
| 192 |
-
" model=model,\n",
|
| 193 |
-
" processing_class=tokenizer,\n",
|
| 194 |
-
" reward_funcs=[reward_fn],\n",
|
| 195 |
-
" args=config,\n",
|
| 196 |
-
" train_dataset=train_dataset,\n",
|
| 197 |
-
")\n",
|
| 198 |
-
"trainer.train()"
|
| 199 |
-
]
|
| 200 |
},
|
| 201 |
{
|
| 202 |
"cell_type": "markdown",
|
|
@@ -214,53 +179,7 @@
|
|
| 214 |
"id": "db6b8ab5",
|
| 215 |
"metadata": {},
|
| 216 |
"outputs": [],
|
| 217 |
-
"source": [
|
| 218 |
-
"import glob\n",
|
| 219 |
-
"import re\n",
|
| 220 |
-
"\n",
|
| 221 |
-
"import torch\n",
|
| 222 |
-
"from transformers import AutoModelForCausalLM\n",
|
| 223 |
-
"\n",
|
| 224 |
-
"from training.curve import evaluate_checkpoint\n",
|
| 225 |
-
"\n",
|
| 226 |
-
"def make_text_policy(ckpt_path: str):\n",
|
| 227 |
-
" ckpt_model = AutoModelForCausalLM.from_pretrained(\n",
|
| 228 |
-
" ckpt_path,\n",
|
| 229 |
-
" torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,\n",
|
| 230 |
-
" device_map='auto',\n",
|
| 231 |
-
" )\n",
|
| 232 |
-
"\n",
|
| 233 |
-
" def _policy(prompt: str) -> str:\n",
|
| 234 |
-
" inputs = tokenizer(prompt, return_tensors='pt', truncation=True, max_length=1024).to(ckpt_model.device)\n",
|
| 235 |
-
" with torch.no_grad():\n",
|
| 236 |
-
" out = ckpt_model.generate(\n",
|
| 237 |
-
" **inputs,\n",
|
| 238 |
-
" max_new_tokens=128,\n",
|
| 239 |
-
" do_sample=False,\n",
|
| 240 |
-
" temperature=0.0,\n",
|
| 241 |
-
" pad_token_id=tokenizer.pad_token_id,\n",
|
| 242 |
-
" )\n",
|
| 243 |
-
" return tokenizer.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)\n",
|
| 244 |
-
"\n",
|
| 245 |
-
" return _policy\n",
|
| 246 |
-
"\n",
|
| 247 |
-
"def checkpoint_step(path: str) -> int:\n",
|
| 248 |
-
" match = re.search(r'checkpoint-(\\d+)$', path)\n",
|
| 249 |
-
" return int(match.group(1)) if match else -1\n",
|
| 250 |
-
"\n",
|
| 251 |
-
"checkpoint_dirs = sorted(\n",
|
| 252 |
-
" glob.glob('./grpo-merchant-agent/checkpoint-*'),\n",
|
| 253 |
-
" key=checkpoint_step,\n",
|
| 254 |
-
")\n",
|
| 255 |
-
"checkpoints = [evaluate_checkpoint(step=0, policy=make_text_policy(MODEL_ID))]\n",
|
| 256 |
-
"for ckpt_dir in checkpoint_dirs:\n",
|
| 257 |
-
" step = checkpoint_step(ckpt_dir)\n",
|
| 258 |
-
" if step <= 0:\n",
|
| 259 |
-
" continue\n",
|
| 260 |
-
" checkpoints.append(evaluate_checkpoint(step=step, policy=make_text_policy(ckpt_dir)))\n",
|
| 261 |
-
"for ckpt in checkpoints:\n",
|
| 262 |
-
" print(f'step={ckpt.step:4d} mean={ckpt.mean_score:.4f}')"
|
| 263 |
-
]
|
| 264 |
},
|
| 265 |
{
|
| 266 |
"cell_type": "markdown",
|
|
@@ -396,4 +315,4 @@
|
|
| 396 |
},
|
| 397 |
"nbformat": 4,
|
| 398 |
"nbformat_minor": 5
|
| 399 |
-
}
|
|
|
|
| 161 |
"id": "fc45953c",
|
| 162 |
"metadata": {},
|
| 163 |
"outputs": [],
|
| 164 |
+
"source": "from trl import GRPOConfig, GRPOTrainer\nfrom training.reward_adapter import compute_reward\n\ndef reward_fn(prompts, completions, **kwargs):\n task_ids = kwargs.get('task_id') or kwargs.get('task_ids')\n return compute_reward(prompts, completions, task_ids=task_ids)\n\n# beta=0 skips the KL ref model (saves a copy of the policy weights).\n# max_completion_length=512: Qwen-style verbose `strategy` strings overflow the\n# 128-token budget mid-string. The tolerant parser closes truncated JSON, but\n# more headroom = more parseable completions = more reward variance = more\n# learning signal. Empirically 128 collapses the GRPO loss to 0 on every step\n# because every group's K completions all parse-fail and hit the same heuristic\n# fallback reward → zero advantage → no gradient.\n# gradient_checkpointing=True trades ~30% throughput for ~2 GB VRAM headroom\n# so the run lands comfortably under T4's 15 GB even at batch 4.\nconfig = GRPOConfig(\n output_dir='./grpo-merchant-agent',\n per_device_train_batch_size=4,\n num_generations=4,\n max_prompt_length=1024,\n max_completion_length=512,\n learning_rate=5e-6,\n max_steps=200,\n logging_steps=10,\n save_steps=50,\n save_total_limit=5,\n gradient_accumulation_steps=1,\n bf16=torch.cuda.is_available(),\n report_to='none',\n beta=0.0,\n gradient_checkpointing=True,\n)\ntrainer = GRPOTrainer(\n model=model,\n processing_class=tokenizer,\n reward_funcs=[reward_fn],\n args=config,\n train_dataset=train_dataset,\n)\ntrainer.train()"
|
| 165 |
},
|
| 166 |
{
|
| 167 |
"cell_type": "markdown",
|
|
|
|
| 179 |
"id": "db6b8ab5",
|
| 180 |
"metadata": {},
|
| 181 |
"outputs": [],
|
| 182 |
+
"source": "import glob\nimport re\n\nimport torch\nfrom transformers import AutoModelForCausalLM, AutoTokenizer\n\nfrom training.curve import evaluate_checkpoint\n\ndef make_text_policy(ckpt_path: str):\n ckpt_model = AutoModelForCausalLM.from_pretrained(\n ckpt_path,\n torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,\n device_map='auto',\n trust_remote_code=True,\n )\n if not hasattr(ckpt_model, 'warnings_issued'):\n ckpt_model.warnings_issued = {}\n ckpt_tok = AutoTokenizer.from_pretrained(ckpt_path, trust_remote_code=True)\n\n def _policy(prompt: str) -> str:\n # Qwen3.5 chat template with thinking off — keeps completion tokens\n # focused on the JSON action instead of `<think>` scratchpad.\n try:\n full_prompt = ckpt_tok.apply_chat_template(\n [{'role': 'user', 'content': prompt}],\n tokenize=False,\n add_generation_prompt=True,\n enable_thinking=False,\n )\n except (TypeError, ValueError):\n full_prompt = prompt\n inputs = ckpt_tok(full_prompt, return_tensors='pt', truncation=True, max_length=2048).to(ckpt_model.device)\n with torch.no_grad():\n out = ckpt_model.generate(\n **inputs,\n max_new_tokens=512,\n do_sample=False,\n pad_token_id=ckpt_tok.eos_token_id,\n )\n return ckpt_tok.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)\n\n return _policy\n\ndef checkpoint_step(path: str) -> int:\n match = re.search(r'checkpoint-(\\d+)$', path)\n return int(match.group(1)) if match else -1\n\ncheckpoint_dirs = sorted(\n glob.glob('./grpo-merchant-agent/checkpoint-*'),\n key=checkpoint_step,\n)\ncheckpoints = [evaluate_checkpoint(step=0, policy=make_text_policy(MODEL_ID))]\nfor ckpt_dir in checkpoint_dirs:\n step = checkpoint_step(ckpt_dir)\n if step <= 0:\n continue\n checkpoints.append(evaluate_checkpoint(step=step, policy=make_text_policy(ckpt_dir)))\nfor ckpt in checkpoints:\n print(f'step={ckpt.step:4d} mean={ckpt.mean_score:.4f}')"
|
| 183 |
},
|
| 184 |
{
|
| 185 |
"cell_type": "markdown",
|
|
|
|
| 315 |
},
|
| 316 |
"nbformat": 4,
|
| 317 |
"nbformat_minor": 5
|
| 318 |
+
}
|
tests/test_training_adapter.py
CHANGED
|
@@ -74,6 +74,36 @@ def test_action_from_completion_returns_none_on_bad_type():
|
|
| 74 |
assert action_from_completion(payload) is None
|
| 75 |
|
| 76 |
|
| 77 |
def test_run_episode_falls_back_to_heuristic_on_empty_completion():
|
| 78 |
"""Unparseable completions must not deadlock the episode."""
|
| 79 |
result = run_episode_with_text_policy(
|
|
|
|
| 74 |
assert action_from_completion(payload) is None
|
| 75 |
|
| 76 |
|
| 77 |
+
def test_parse_completion_handles_truncated_json():
|
| 78 |
+
"""Mid-string truncation: tolerate by closing at last balanced field."""
|
| 79 |
+
payload = (
|
| 80 |
+
'```json\n{"action_type": "select_case", "case_id": "CB-E1", '
|
| 81 |
+
'"strategy": "Select the case ID to procee'
|
| 82 |
+
)
|
| 83 |
+
parsed = parse_completion(payload)
|
| 84 |
+
assert parsed is not None
|
| 85 |
+
assert parsed["action_type"] == "select_case"
|
| 86 |
+
assert parsed["case_id"] == "CB-E1"
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def test_parse_completion_strips_think_block():
|
| 90 |
+
payload = (
|
| 91 |
+
'<think>\nlet me think about this\n</think>\n'
|
| 92 |
+
'{"action_type": "select_case", "case_id": "CB-1"}'
|
| 93 |
+
)
|
| 94 |
+
parsed = parse_completion(payload)
|
| 95 |
+
assert parsed == {"action_type": "select_case", "case_id": "CB-1"}
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def test_parse_completion_infers_action_type_from_prefix():
|
| 99 |
+
"""Model emits action name as prose then JSON without action_type field."""
|
| 100 |
+
payload = ' select_case\n{"case_id": "CB-X", "strategy": "go"}'
|
| 101 |
+
parsed = parse_completion(payload)
|
| 102 |
+
assert parsed is not None
|
| 103 |
+
assert parsed["action_type"] == "select_case"
|
| 104 |
+
assert parsed["case_id"] == "CB-X"
|
| 105 |
+
|
| 106 |
+
|
| 107 |
def test_run_episode_falls_back_to_heuristic_on_empty_completion():
|
| 108 |
"""Unparseable completions must not deadlock the episode."""
|
| 109 |
result = run_episode_with_text_policy(
|
training/env_adapter.py
CHANGED
|
@@ -9,6 +9,7 @@ side effects — so they are cheap to unit-test.
|
|
| 9 |
from __future__ import annotations
|
| 10 |
|
| 11 |
import json
|
|
|
|
| 12 |
from typing import Any
|
| 13 |
|
| 14 |
try:
|
|
@@ -99,27 +100,83 @@ def build_prompt(observation: dict[str, Any]) -> str:
|
|
| 99 |
|
| 100 |
|
| 101 |
def parse_completion(text: str) -> dict[str, Any] | None:
|
| 102 |
-
"""Parse a model completion into a raw action dict, or return None.
|
| 103 |
|
| 104 |
if not text:
|
| 105 |
return None
|
| 106 |
cleaned = text.strip()
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
cleaned = cleaned[4:].lstrip()
|
| 112 |
-
# Find the first {...} block so prose before JSON is tolerated.
|
| 113 |
start = cleaned.find("{")
|
| 114 |
-
|
| 115 |
-
if start == -1 or end == -1 or end <= start:
|
| 116 |
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
try:
|
| 118 |
-
|
|
|
|
|
|
|
| 119 |
except json.JSONDecodeError:
|
| 120 |
return None
|
| 121 |
-
|
| 122 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
return {k: v for k, v in data.items() if k in _ALLOWED_ACTION_FIELDS}
|
| 124 |
|
| 125 |
|
|
|
|
| 9 |
from __future__ import annotations
|
| 10 |
|
| 11 |
import json
|
| 12 |
+
import re
|
| 13 |
from typing import Any
|
| 14 |
|
| 15 |
try:
|
|
|
|
| 100 |
|
| 101 |
|
| 102 |
def parse_completion(text: str) -> dict[str, Any] | None:
|
| 103 |
+
"""Parse a model completion into a raw action dict, or return None.
|
| 104 |
+
|
| 105 |
+
Tolerates: code fences, leading prose / `<think>` blocks, prefix words
|
| 106 |
+
naming the action_type before the JSON, and JSON truncated mid-string
|
| 107 |
+
(auto-closes at the last balanced field). Required because untrained
|
| 108 |
+
Qwen-style chat models often emit valid JSON head + truncated tail —
|
| 109 |
+
a strict parser would zero out the entire training signal.
|
| 110 |
+
"""
|
| 111 |
|
| 112 |
if not text:
|
| 113 |
return None
|
| 114 |
cleaned = text.strip()
|
| 115 |
+
cleaned = re.sub(r"<think>.*?</think>", "", cleaned, flags=re.DOTALL).strip()
|
| 116 |
+
cleaned = re.sub(r"^```(?:json)?\s*", "", cleaned)
|
| 117 |
+
cleaned = re.sub(r"```\s*$", "", cleaned).strip()
|
| 118 |
+
|
|
|
|
|
|
|
| 119 |
start = cleaned.find("{")
|
| 120 |
+
if start == -1:
|
|
|
|
| 121 |
return None
|
| 122 |
+
prefix = cleaned[:start].strip()
|
| 123 |
+
body = cleaned[start:]
|
| 124 |
+
|
| 125 |
+
data: dict[str, Any] | None = None
|
| 126 |
try:
|
| 127 |
+
candidate = json.loads(body)
|
| 128 |
+
if isinstance(candidate, dict):
|
| 129 |
+
data = candidate
|
| 130 |
except json.JSONDecodeError:
|
| 131 |
+
pass
|
| 132 |
+
|
| 133 |
+
if data is None:
|
| 134 |
+
depth = 0
|
| 135 |
+
in_str = False
|
| 136 |
+
esc = False
|
| 137 |
+
last_safe = -1
|
| 138 |
+
for i, ch in enumerate(body):
|
| 139 |
+
if esc:
|
| 140 |
+
esc = False
|
| 141 |
+
continue
|
| 142 |
+
if ch == "\\":
|
| 143 |
+
esc = True
|
| 144 |
+
continue
|
| 145 |
+
if ch == '"':
|
| 146 |
+
in_str = not in_str
|
| 147 |
+
continue
|
| 148 |
+
if in_str:
|
| 149 |
+
continue
|
| 150 |
+
if ch == "{":
|
| 151 |
+
depth += 1
|
| 152 |
+
elif ch == "}":
|
| 153 |
+
depth -= 1
|
| 154 |
+
if depth == 0:
|
| 155 |
+
try:
|
| 156 |
+
candidate = json.loads(body[: i + 1])
|
| 157 |
+
if isinstance(candidate, dict):
|
| 158 |
+
data = candidate
|
| 159 |
+
break
|
| 160 |
+
except json.JSONDecodeError:
|
| 161 |
+
pass
|
| 162 |
+
elif ch == "," and depth == 1:
|
| 163 |
+
last_safe = i
|
| 164 |
+
if data is None and last_safe != -1:
|
| 165 |
+
try:
|
| 166 |
+
candidate = json.loads(body[:last_safe] + "}")
|
| 167 |
+
if isinstance(candidate, dict):
|
| 168 |
+
data = candidate
|
| 169 |
+
except json.JSONDecodeError:
|
| 170 |
+
pass
|
| 171 |
+
|
| 172 |
+
if data is None:
|
| 173 |
return None
|
| 174 |
+
|
| 175 |
+
if "action_type" not in data and prefix:
|
| 176 |
+
m = re.match(r"[a-z_][a-z0-9_]*", prefix.lower())
|
| 177 |
+
if m:
|
| 178 |
+
data["action_type"] = m.group(0)
|
| 179 |
+
|
| 180 |
return {k: v for k, v in data.items() if k in _ALLOWED_ACTION_FIELDS}
|
| 181 |
|
| 182 |
|
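
The truncation recovery in `parse_completion` can be exercised in isolation. The sketch below is a condensed stand-in, not the function from `training/env_adapter.py`. `close_at_last_field` is a hypothetical helper that remembers the last top-level comma outside any string and closes the object there. Unlike the diff above, it also tracks `[`/`]` depth and does not handle trailing text after a complete object.

```python
import json

def close_at_last_field(body):
    """Parse `body`; on failure, cut at the last top-level comma and add '}'."""
    try:
        parsed = json.loads(body)
        return parsed if isinstance(parsed, dict) else None
    except json.JSONDecodeError:
        pass
    depth, in_str, esc, last_comma = 0, False, False, -1
    for i, ch in enumerate(body):
        if esc:                      # character after a backslash: skip it
            esc = False
        elif ch == "\\" and in_str:
            esc = True
        elif ch == '"':
            in_str = not in_str
        elif in_str:
            continue
        elif ch in "{[":
            depth += 1
        elif ch in "}]":
            depth -= 1
        elif ch == "," and depth == 1:
            last_comma = i           # everything before this comma is complete
    if last_comma == -1:
        return None
    try:
        parsed = json.loads(body[:last_comma] + "}")
    except json.JSONDecodeError:
        return None
    return parsed if isinstance(parsed, dict) else None

truncated = ('{"action_type": "select_case", "case_id": "CB-E1", '
             '"strategy": "Select the case ID to procee')
print(close_at_last_field(truncated))
```

The payload is the one used in `test_parse_completion_handles_truncated_json`: the unfinished `strategy` field is dropped and the two complete fields survive, which is enough for the episode to act on the model's own choice instead of the heuristic fallback.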