Spaces:
Running
Running
Commit ·
2c8a058
1
Parent(s): 85a3e59
Add GRPO training notebook + Trackio integration + SpatialThinker support
Browse files- train_origami.ipynb: Full Colab/Northflank notebook for GRPO training
with model selection (SpatialThinker vs vanilla Qwen2.5-VL), per-component
reward logging, evaluation harness, and A/B comparison
- train.py: Switch from W&B to Trackio, add VL model auto-detection
for FastVisionModel vs FastLanguageModel
- trainer/: Updated reward functions and prompts for env/ system
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- train.py +29 -6
- train_origami.ipynb +184 -0
- trainer/mock_env.py +16 -0
- trainer/prompts.py +108 -66
- trainer/rewards.py +339 -14
- trainer/train.py +11 -4
train.py
CHANGED
|
@@ -2,9 +2,14 @@
|
|
| 2 |
OrigamiRL — GRPO Training Script
|
| 3 |
Code-as-policy: model generates complete fold sequence, gets terminal reward.
|
| 4 |
|
|
|
|
|
|
|
|
|
|
| 5 |
Usage:
|
| 6 |
python train.py
|
| 7 |
-
python train.py --model unsloth/Qwen2.5-7B-Instruct --epochs 3
|
|
|
|
|
|
|
| 8 |
"""
|
| 9 |
import argparse
|
| 10 |
import json
|
|
@@ -13,10 +18,19 @@ import random
|
|
| 13 |
from pathlib import Path
|
| 14 |
from typing import Optional
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
def parse_args():
|
| 18 |
parser = argparse.ArgumentParser()
|
| 19 |
-
parser.add_argument('--model', default='unsloth/Qwen2.5-7B-Instruct'
|
|
|
|
|
|
|
| 20 |
parser.add_argument('--max_seq_length', type=int, default=2048)
|
| 21 |
parser.add_argument('--epochs', type=int, default=3)
|
| 22 |
parser.add_argument('--batch_size', type=int, default=2)
|
|
@@ -148,20 +162,29 @@ def main():
|
|
| 148 |
return
|
| 149 |
|
| 150 |
# Load model via unsloth
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
try:
|
| 152 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
except ImportError:
|
| 154 |
print("ERROR: unsloth not installed. Run: pip install unsloth")
|
| 155 |
print("Or run with --dry_run to test the reward function without a model.")
|
| 156 |
return
|
| 157 |
|
| 158 |
-
model, tokenizer =
|
| 159 |
model_name=args.model,
|
| 160 |
max_seq_length=args.max_seq_length,
|
| 161 |
load_in_4bit=True,
|
| 162 |
)
|
| 163 |
|
| 164 |
-
model =
|
| 165 |
model,
|
| 166 |
r=32,
|
| 167 |
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
|
|
@@ -193,7 +216,7 @@ def main():
|
|
| 193 |
num_generations=args.n_generations,
|
| 194 |
temperature=1.0,
|
| 195 |
logging_steps=1,
|
| 196 |
-
report_to="
|
| 197 |
run_name="origami-grpo",
|
| 198 |
)
|
| 199 |
|
|
|
|
| 2 |
OrigamiRL — GRPO Training Script
|
| 3 |
Code-as-policy: model generates complete fold sequence, gets terminal reward.
|
| 4 |
|
| 5 |
+
Base model: SpatialThinker (Qwen2.5-VL-7B fine-tuned for spatial reasoning)
|
| 6 |
+
or any Unsloth-compatible model.
|
| 7 |
+
|
| 8 |
Usage:
|
| 9 |
python train.py
|
| 10 |
+
python train.py --model unsloth/Qwen2.5-VL-7B-Instruct --epochs 3
|
| 11 |
+
python train.py --model OX-PIXL/SpatialThinker-Qwen2.5-VL-7B --epochs 3
|
| 12 |
+
python train.py --dry_run # test rewards without GPU
|
| 13 |
"""
|
| 14 |
import argparse
|
| 15 |
import json
|
|
|
|
| 18 |
from pathlib import Path
|
| 19 |
from typing import Optional
|
| 20 |
|
| 21 |
+
# VL (vision-language) model identifiers — use FastVisionModel for these
|
| 22 |
+
_VL_MODEL_PATTERNS = ['VL', 'vl', 'Vision', 'vision', 'SpatialThinker', 'SpaceThinker']
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _is_vl_model(model_name: str) -> bool:
|
| 26 |
+
return any(p in model_name for p in _VL_MODEL_PATTERNS)
|
| 27 |
+
|
| 28 |
|
| 29 |
def parse_args():
|
| 30 |
parser = argparse.ArgumentParser()
|
| 31 |
+
parser.add_argument('--model', default='unsloth/Qwen2.5-VL-7B-Instruct',
|
| 32 |
+
help='Base model. Use unsloth/Qwen2.5-VL-7B-Instruct or '
|
| 33 |
+
'OX-PIXL/SpatialThinker-Qwen2.5-VL-7B for spatial reasoning')
|
| 34 |
parser.add_argument('--max_seq_length', type=int, default=2048)
|
| 35 |
parser.add_argument('--epochs', type=int, default=3)
|
| 36 |
parser.add_argument('--batch_size', type=int, default=2)
|
|
|
|
| 162 |
return
|
| 163 |
|
| 164 |
# Load model via unsloth
|
| 165 |
+
# VL models (SpatialThinker, Qwen2.5-VL) use FastVisionModel
|
| 166 |
+
# Text-only models use FastLanguageModel
|
| 167 |
+
is_vl = _is_vl_model(args.model)
|
| 168 |
+
|
| 169 |
try:
|
| 170 |
+
if is_vl:
|
| 171 |
+
from unsloth import FastVisionModel as ModelLoader
|
| 172 |
+
print(f"Loading VL model (vision-language): {args.model}")
|
| 173 |
+
else:
|
| 174 |
+
from unsloth import FastLanguageModel as ModelLoader
|
| 175 |
+
print(f"Loading text model: {args.model}")
|
| 176 |
except ImportError:
|
| 177 |
print("ERROR: unsloth not installed. Run: pip install unsloth")
|
| 178 |
print("Or run with --dry_run to test the reward function without a model.")
|
| 179 |
return
|
| 180 |
|
| 181 |
+
model, tokenizer = ModelLoader.from_pretrained(
|
| 182 |
model_name=args.model,
|
| 183 |
max_seq_length=args.max_seq_length,
|
| 184 |
load_in_4bit=True,
|
| 185 |
)
|
| 186 |
|
| 187 |
+
model = ModelLoader.get_peft_model(
|
| 188 |
model,
|
| 189 |
r=32,
|
| 190 |
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
|
|
|
|
| 216 |
num_generations=args.n_generations,
|
| 217 |
temperature=1.0,
|
| 218 |
logging_steps=1,
|
| 219 |
+
report_to="trackio",
|
| 220 |
run_name="origami-grpo",
|
| 221 |
)
|
| 222 |
|
train_origami.ipynb
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"id": "8smrrb11v84",
|
| 6 |
+
"source": "# Optigami — Origami RL Training (GRPO)\n\n**Train an LLM to generate valid origami fold sequences using verifiable geometric rewards.**\n\nArchitecture:\n- **Environment**: `env/` — CreaseGraph + Kawasaki/Maekawa/BLB verifiers + target matching\n- **Policy model**: SpatialThinker (Qwen2.5-VL-7B) or vanilla Qwen2.5-VL-7B\n- **Training**: Unsloth GRPO — model generates complete fold sequences, gets terminal reward\n- **Tracking**: Trackio — real-time reward curves on HF Spaces\n\n| Reward Component | Weight | What it measures |\n|---|---|---|\n| `progress` | 0.45 | Geometric crease coverage vs target |\n| `economy` | 0.10 | Penalty for excess creases |\n| `kawasaki` | 0.08 | Kawasaki theorem satisfaction |\n| `maekawa` | 0.07 | Maekawa theorem satisfaction |\n| `blb` | 0.05 | Big-Little-Big lemma |\n| `anchored` | 0.05 | Valid anchor point usage |\n| `completion` | +10.0 | Bonus when target reached |",
|
| 7 |
+
"metadata": {}
|
| 8 |
+
},
|
| 9 |
+
{
|
| 10 |
+
"cell_type": "markdown",
|
| 11 |
+
"id": "kn1k9d357j",
|
| 12 |
+
"source": "## 1. Setup\n\n**GPU**: H100 80GB (Northflank/CoreWeave) or A100/T4 (Colab)\n\nInstall dependencies. Unsloth handles efficient model loading + LoRA.",
|
| 13 |
+
"metadata": {}
|
| 14 |
+
},
|
| 15 |
+
{
|
| 16 |
+
"cell_type": "code",
|
| 17 |
+
"id": "d10vqzep5b6",
|
| 18 |
+
"source": "%%capture\n!pip install unsloth trackio shapely numpy datasets\n!pip install --upgrade trl transformers\n\n# Check GPU\nimport torch\nprint(f\"GPU: {torch.cuda.get_device_name(0)}\")\nprint(f\"VRAM: {torch.cuda.get_device_properties(0).total_mem / 1e9:.1f} GB\")",
|
| 19 |
+
"metadata": {},
|
| 20 |
+
"execution_count": null,
|
| 21 |
+
"outputs": []
|
| 22 |
+
},
|
| 23 |
+
{
|
| 24 |
+
"cell_type": "markdown",
|
| 25 |
+
"id": "y6wsagz8h",
|
| 26 |
+
"source": "## 2. Configuration\n\nChoose base model and hyperparameters. Two options:\n- **SpatialThinker** (`OX-PIXL/SpatialThinker-Qwen2.5-VL-7B`): Pre-trained for spatial reasoning via RL\n- **Vanilla Qwen2.5-VL** (`unsloth/Qwen2.5-VL-7B-Instruct`): Standard vision-language model\n\nWe'll compare both to see which learns origami folding faster.",
|
| 27 |
+
"metadata": {}
|
| 28 |
+
},
|
| 29 |
+
{
|
| 30 |
+
"cell_type": "code",
|
| 31 |
+
"id": "dh1zapl0w5s",
|
| 32 |
+
"source": "# ── Config ──────────────────────────────────────────────────────────────────\n# Toggle MODEL_NAME to switch between SpatialThinker and vanilla Qwen2.5-VL\n\nMODEL_NAME = \"OX-PIXL/SpatialThinker-Qwen2.5-VL-7B\"\n# MODEL_NAME = \"unsloth/Qwen2.5-VL-7B-Instruct\" # uncomment for vanilla\n\nMAX_SEQ_LENGTH = 2048\nLORA_R = 32\nLORA_ALPHA = 32\nEPOCHS = 3\nBATCH_SIZE = 2\nGRAD_ACCUM = 4\nLR = 5e-6\nN_GENERATIONS = 8 # completions sampled per prompt (GRPO group size)\nMAX_FOLDS = 8 # max folds per episode\nLEVEL = 1 # target difficulty (1=simple, 2=medium, 3=hard)\nMAX_COMPLETION_LEN = 512\nOUTPUT_DIR = \"origami-grpo\"\n\n# Trackio — set your HF Space ID for live dashboard\nTRACKIO_SPACE_ID = None # e.g. \"your-username/optigami-training\"\n\nprint(f\"Model: {MODEL_NAME}\")\nprint(f\"Config: {EPOCHS} epochs, batch={BATCH_SIZE}, grad_accum={GRAD_ACCUM}, lr={LR}\")\nprint(f\"GRPO: {N_GENERATIONS} generations, max_folds={MAX_FOLDS}, level={LEVEL}\")",
|
| 33 |
+
"metadata": {},
|
| 34 |
+
"execution_count": null,
|
| 35 |
+
"outputs": []
|
| 36 |
+
},
|
| 37 |
+
{
|
| 38 |
+
"cell_type": "markdown",
|
| 39 |
+
"id": "o5hhfbp0wb",
|
| 40 |
+
"source": "## 3. Clone Repo & Test Environment\n\nClone the optigami repo (skip if running locally) and verify the environment works.",
|
| 41 |
+
"metadata": {}
|
| 42 |
+
},
|
| 43 |
+
{
|
| 44 |
+
"cell_type": "code",
|
| 45 |
+
"id": "94cemjucczl",
|
| 46 |
+
"source": "import os\n\n# Clone repo if not already present (Colab/Northflank)\nif not os.path.exists(\"env/environment.py\"):\n !git clone https://huggingface.co/spaces/openenv-community/optigami /content/optigami 2>/dev/null || true\n os.chdir(\"/content/optigami\")\n\n# Verify env/ is accessible\nfrom env.environment import OrigamiEnvironment\nfrom env.rewards import compute_reward\nfrom env.prompts import parse_fold_list\n\nenv = OrigamiEnvironment(mode=\"code_as_policy\", max_steps=MAX_FOLDS)\nprint(f\"Available targets: {env.available_targets()}\")\nprint(f\"Environment mode: {env.mode}\")",
|
| 47 |
+
"metadata": {},
|
| 48 |
+
"execution_count": null,
|
| 49 |
+
"outputs": []
|
| 50 |
+
},
|
| 51 |
+
{
|
| 52 |
+
"cell_type": "code",
|
| 53 |
+
"id": "2j9mccejyfx",
|
| 54 |
+
"source": "# ── Dry run: test reward function ───────────────────────────────────────────\n# Verify rewards work before loading the model\n\nimport copy\n\ndef make_reward_fn(env_template, max_folds):\n \"\"\"Reward function: clone env, run completion, return total reward.\"\"\"\n def reward_fn(completions, prompts=None, **kwargs):\n rewards = []\n target_names = kwargs.get(\"target_names\", [None] * len(completions))\n for completion, target_name in zip(completions, target_names):\n try:\n e = env_template.clone()\n e.reset(target_name=target_name)\n _, reward_dict, _, _ = e.step(completion)\n rewards.append(float(reward_dict[\"total\"]))\n except Exception:\n rewards.append(-0.1)\n return rewards\n return reward_fn\n\nreward_fn = make_reward_fn(env, MAX_FOLDS)\n\ntest_completions = [\n '<folds>[{\"instruction\": \"Valley fold along horizontal center\", \"from\": [0, 0.5], \"to\": [1, 0.5], \"assignment\": \"V\"}]</folds>',\n '<folds>[{\"instruction\": \"Bad fold\", \"from\": [0.3, 0.3], \"to\": [0.7, 0.7], \"assignment\": \"V\"}]</folds>',\n 'not valid JSON',\n]\ntarget_names = [\"half_horizontal\"] * 3\nrewards = reward_fn(test_completions, target_names=target_names)\n\nfor comp, r in zip([\"perfect fold\", \"partial fold\", \"garbage\"], rewards):\n print(f\" {comp}: reward = {r:.3f}\")\nprint(\"\\nReward function OK ✓\")",
|
| 55 |
+
"metadata": {},
|
| 56 |
+
"execution_count": null,
|
| 57 |
+
"outputs": []
|
| 58 |
+
},
|
| 59 |
+
{
|
| 60 |
+
"cell_type": "markdown",
|
| 61 |
+
"id": "46gs2p1cy4",
|
| 62 |
+
"source": "## 4. Load Model + LoRA\n\nLoad the VL model with Unsloth's `FastVisionModel` (4-bit quantized) and apply LoRA adapters.",
|
| 63 |
+
"metadata": {}
|
| 64 |
+
},
|
| 65 |
+
{
|
| 66 |
+
"cell_type": "code",
|
| 67 |
+
"id": "82f76od6d2k",
|
| 68 |
+
"source": "from unsloth import FastVisionModel\n\nmodel, tokenizer = FastVisionModel.from_pretrained(\n model_name=MODEL_NAME,\n max_seq_length=MAX_SEQ_LENGTH,\n load_in_4bit=True,\n)\n\nmodel = FastVisionModel.get_peft_model(\n model,\n r=LORA_R,\n target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n \"gate_proj\", \"up_proj\", \"down_proj\"],\n lora_alpha=LORA_ALPHA,\n lora_dropout=0,\n use_gradient_checkpointing=\"unsloth\",\n)\n\nprint(f\"Model loaded: {MODEL_NAME}\")\nprint(f\"LoRA rank: {LORA_R}, alpha: {LORA_ALPHA}\")\nprint(f\"Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}\")",
|
| 69 |
+
"metadata": {},
|
| 70 |
+
"execution_count": null,
|
| 71 |
+
"outputs": []
|
| 72 |
+
},
|
| 73 |
+
{
|
| 74 |
+
"cell_type": "markdown",
|
| 75 |
+
"id": "67dyfrj23y",
|
| 76 |
+
"source": "## 5. Build Dataset\n\nGenerate prompts from all level-appropriate targets. Each prompt embeds the target crease pattern description and asks the model to output `<folds>[...]</folds>`.",
|
| 77 |
+
"metadata": {}
|
| 78 |
+
},
|
| 79 |
+
{
|
| 80 |
+
"cell_type": "code",
|
| 81 |
+
"id": "1msqpzj5fwu",
|
| 82 |
+
"source": "import random\nfrom datasets import Dataset\n\ndef build_dataset(env, level=1):\n \"\"\"Build training dataset of prompts from env targets.\"\"\"\n all_names = env.available_targets()\n level_names = [\n n for n in all_names\n if env._targets[n].get(\"level\", 1) == level\n ]\n if not level_names:\n level_names = all_names\n\n items = []\n for name in level_names:\n obs = env.reset(target_name=name)\n items.append({\"prompt\": obs[\"prompt\"], \"target_name\": name})\n\n # Repeat each target 10x; ensure at least 50 examples\n repeat = max(10, (50 + len(items) - 1) // len(items))\n items = items * repeat\n random.shuffle(items)\n return items\n\ndataset_items = build_dataset(env, level=LEVEL)\nhf_dataset = Dataset.from_list(dataset_items)\n\nprint(f\"Dataset: {len(dataset_items)} examples\")\nprint(f\"Targets in dataset: {sorted(set(d['target_name'] for d in dataset_items))}\")\nprint(f\"\\nSample prompt (first 300 chars):\\n{dataset_items[0]['prompt'][:300]}...\")",
|
| 83 |
+
"metadata": {},
|
| 84 |
+
"execution_count": null,
|
| 85 |
+
"outputs": []
|
| 86 |
+
},
|
| 87 |
+
{
|
| 88 |
+
"cell_type": "markdown",
|
| 89 |
+
"id": "7n3r3nsw8ae",
|
| 90 |
+
"source": "## 6. Trackio Setup\n\nInitialize Trackio for real-time training visualization. Trackio is a free W&B alternative that deploys a Gradio dashboard to HF Spaces.",
|
| 91 |
+
"metadata": {}
|
| 92 |
+
},
|
| 93 |
+
{
|
| 94 |
+
"cell_type": "code",
|
| 95 |
+
"id": "bgru9rsw95b",
|
| 96 |
+
"source": "import trackio\n\n# Initialize Trackio run\ntrackio_kwargs = {\n \"project_name\": \"optigami\",\n \"run_name\": f\"grpo-{MODEL_NAME.split('/')[-1]}-level{LEVEL}\",\n}\nif TRACKIO_SPACE_ID:\n trackio_kwargs[\"space_id\"] = TRACKIO_SPACE_ID\n\ntrackio.init(**trackio_kwargs)\nprint(f\"Trackio initialized: {trackio_kwargs['run_name']}\")\nif TRACKIO_SPACE_ID:\n print(f\"Dashboard: https://huggingface.co/spaces/{TRACKIO_SPACE_ID}\")",
|
| 97 |
+
"metadata": {},
|
| 98 |
+
"execution_count": null,
|
| 99 |
+
"outputs": []
|
| 100 |
+
},
|
| 101 |
+
{
|
| 102 |
+
"cell_type": "markdown",
|
| 103 |
+
"id": "n8aqymlszo",
|
| 104 |
+
"source": "## 7. GRPO Training\n\nRun GRPO with Trackio logging. The trainer samples `N_GENERATIONS` completions per prompt, computes rewards via the environment, and updates the policy using group-relative advantages.",
|
| 105 |
+
"metadata": {}
|
| 106 |
+
},
|
| 107 |
+
{
|
| 108 |
+
"cell_type": "code",
|
| 109 |
+
"id": "ci4imd9ws7v",
|
| 110 |
+
"source": "from trl import GRPOConfig, GRPOTrainer\n\n# ── Per-component reward functions for detailed logging ─────────────────────\nREWARD_COMPONENTS = [\"kawasaki\", \"maekawa\", \"blb\", \"progress\", \"economy\", \"completion\"]\n\ndef make_component_fn(env_template, component):\n \"\"\"Create a reward function that returns a single component's value.\"\"\"\n def component_fn(completions, target_name=None, **kwargs):\n target_names = target_name if isinstance(target_name, list) else [target_name] * len(completions)\n rewards = []\n for completion, tn in zip(completions, target_names):\n try:\n e = env_template.clone()\n e.reset(target_name=tn)\n _, reward_dict, _, _ = e.step(completion)\n rewards.append(float(reward_dict.get(component, 0.0)))\n except Exception:\n rewards.append(0.0)\n return rewards\n component_fn.__name__ = f\"reward_{component}\"\n return component_fn\n\n# Main reward function (returns total reward)\ndef wrapped_reward_fn(completions, target_name=None, **kwargs):\n \"\"\"Main reward function — extracts target_name from batch columns.\"\"\"\n target_names = target_name if isinstance(target_name, list) else [target_name] * len(completions)\n return reward_fn(completions, target_names=target_names)\n\n# Build list of all reward functions: [total, kawasaki, maekawa, blb, progress, economy, completion]\nall_reward_fns = [wrapped_reward_fn] + [\n make_component_fn(env, c) for c in REWARD_COMPONENTS\n]\n\n# ── GRPO Config ─────────────────────────────────────────────────────────────\nconfig = GRPOConfig(\n output_dir=OUTPUT_DIR,\n num_train_epochs=EPOCHS,\n per_device_train_batch_size=BATCH_SIZE,\n gradient_accumulation_steps=GRAD_ACCUM,\n learning_rate=LR,\n max_completion_length=MAX_COMPLETION_LEN,\n num_generations=N_GENERATIONS,\n temperature=1.0,\n logging_steps=1,\n save_steps=50,\n report_to=\"trackio\",\n run_name=f\"grpo-{MODEL_NAME.split('/')[-1]}-level{LEVEL}\",\n)\n\ntrainer = GRPOTrainer(\n model=model,\n config=config,\n train_dataset=hf_dataset,\n reward_funcs=all_reward_fns,\n tokenizer=tokenizer,\n)\n\nprint(f\"Trainer ready. Starting GRPO training...\")\nprint(f\" Model: {MODEL_NAME}\")\nprint(f\" Dataset: {len(hf_dataset)} examples\")\nprint(f\" Reward functions: total + {REWARD_COMPONENTS}\")\nprint(f\" Logging to: Trackio\")",
|
| 111 |
+
"metadata": {},
|
| 112 |
+
"execution_count": null,
|
| 113 |
+
"outputs": []
|
| 114 |
+
},
|
| 115 |
+
{
|
| 116 |
+
"cell_type": "code",
|
| 117 |
+
"id": "mwy1i7d509",
|
| 118 |
+
"source": "# ── Train! ──────────────────────────────────────────────────────────────────\ntrainer.train()\n\n# Save model + tokenizer\nmodel.save_pretrained(OUTPUT_DIR)\ntokenizer.save_pretrained(OUTPUT_DIR)\nprint(f\"\\nTraining complete. Model saved to {OUTPUT_DIR}/\")",
|
| 119 |
+
"metadata": {},
|
| 120 |
+
"execution_count": null,
|
| 121 |
+
"outputs": []
|
| 122 |
+
},
|
| 123 |
+
{
|
| 124 |
+
"cell_type": "markdown",
|
| 125 |
+
"id": "54iu59w1ipb",
|
| 126 |
+
"source": "## 8. Evaluation\n\nRun the trained model on all targets and measure solve rates + average rewards.",
|
| 127 |
+
"metadata": {}
|
| 128 |
+
},
|
| 129 |
+
{
|
| 130 |
+
"cell_type": "code",
|
| 131 |
+
"id": "7loiy5mrsi3",
|
| 132 |
+
"source": "# ── Evaluate trained model on all targets ───────────────────────────────────\nfrom transformers import TextStreamer\n\nFastVisionModel.for_inference(model)\n\neval_results = {}\nn_samples = 5 # generations per target\n\nfor target_name in env.available_targets():\n obs = env.reset(target_name=target_name)\n prompt = obs[\"prompt\"]\n\n # Tokenize prompt\n messages = [{\"role\": \"user\", \"content\": prompt}]\n input_ids = tokenizer.apply_chat_template(messages, return_tensors=\"pt\").to(model.device)\n\n target_rewards = []\n target_solved = 0\n\n for _ in range(n_samples):\n outputs = model.generate(\n input_ids=input_ids,\n max_new_tokens=MAX_COMPLETION_LEN,\n temperature=0.7,\n do_sample=True,\n )\n completion = tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)\n\n # Score completion\n e = env.clone()\n e.reset(target_name=target_name)\n try:\n _, reward_dict, _, _ = e.step(completion)\n total = float(reward_dict[\"total\"])\n solved = reward_dict.get(\"completion\", 0) > 0\n except Exception:\n total = -0.1\n solved = False\n\n target_rewards.append(total)\n if solved:\n target_solved += 1\n\n eval_results[target_name] = {\n \"avg_reward\": sum(target_rewards) / len(target_rewards),\n \"max_reward\": max(target_rewards),\n \"solve_rate\": target_solved / n_samples,\n }\n print(f\" {target_name:20s} avg={eval_results[target_name]['avg_reward']:.3f} \"\n f\"max={eval_results[target_name]['max_reward']:.3f} \"\n f\"solved={target_solved}/{n_samples}\")\n\n# Summary\navg_solve = sum(r[\"solve_rate\"] for r in eval_results.values()) / len(eval_results)\navg_reward = sum(r[\"avg_reward\"] for r in eval_results.values()) / len(eval_results)\nprint(f\"\\nOverall: avg_reward={avg_reward:.3f}, solve_rate={avg_solve:.1%}\")\n\n# Log to Trackio\ntrackio.log({\"eval/avg_reward\": avg_reward, \"eval/solve_rate\": avg_solve})\nfor name, res in eval_results.items():\n trackio.log({f\"eval/{name}_reward\": res[\"avg_reward\"], f\"eval/{name}_solved\": res[\"solve_rate\"]})",
|
| 133 |
+
"metadata": {},
|
| 134 |
+
"execution_count": null,
|
| 135 |
+
"outputs": []
|
| 136 |
+
},
|
| 137 |
+
{
|
| 138 |
+
"cell_type": "markdown",
|
| 139 |
+
"id": "c9dr4ht05r",
|
| 140 |
+
"source": "## 9. A/B Comparison: SpatialThinker vs Vanilla Qwen2.5-VL\n\nTo compare both models, run this notebook twice:\n1. First run with `MODEL_NAME = \"OX-PIXL/SpatialThinker-Qwen2.5-VL-7B\"`\n2. Second run with `MODEL_NAME = \"unsloth/Qwen2.5-VL-7B-Instruct\"`\n\nBoth runs log to the same Trackio project (`optigami`) with different run names, so you can overlay the reward curves directly in the dashboard.\n\nThe cell below loads saved eval results from both runs for comparison (run after both trainings complete).",
|
| 141 |
+
"metadata": {}
|
| 142 |
+
},
|
| 143 |
+
{
|
| 144 |
+
"cell_type": "code",
|
| 145 |
+
"id": "qwwd4wyuhnq",
|
| 146 |
+
"source": "# ── Save eval results for comparison ────────────────────────────────────────\nimport json\n\nmodel_tag = MODEL_NAME.split(\"/\")[-1]\neval_path = f\"eval_results_{model_tag}_level{LEVEL}.json\"\n\nwith open(eval_path, \"w\") as f:\n json.dump(eval_results, f, indent=2)\nprint(f\"Eval results saved to {eval_path}\")\n\n# ── Compare (run after both models are trained) ────────────────────────────\nspatial_path = f\"eval_results_SpatialThinker-Qwen2.5-VL-7B_level{LEVEL}.json\"\nvanilla_path = f\"eval_results_Qwen2.5-VL-7B-Instruct_level{LEVEL}.json\"\n\nif os.path.exists(spatial_path) and os.path.exists(vanilla_path):\n with open(spatial_path) as f:\n spatial = json.load(f)\n with open(vanilla_path) as f:\n vanilla = json.load(f)\n\n print(f\"\\n{'Target':<22} {'SpatialThinker':>16} {'Vanilla Qwen':>16} {'Delta':>10}\")\n print(\"-\" * 66)\n for target in sorted(set(list(spatial.keys()) + list(vanilla.keys()))):\n s_r = spatial.get(target, {}).get(\"avg_reward\", 0)\n v_r = vanilla.get(target, {}).get(\"avg_reward\", 0)\n delta = s_r - v_r\n print(f\" {target:<20} {s_r:>14.3f} {v_r:>14.3f} {delta:>+8.3f}\")\n\n s_avg = sum(r[\"avg_reward\"] for r in spatial.values()) / len(spatial)\n v_avg = sum(r[\"avg_reward\"] for r in vanilla.values()) / len(vanilla)\n print(f\"\\n {'OVERALL':<20} {s_avg:>14.3f} {v_avg:>14.3f} {s_avg - v_avg:>+8.3f}\")\n\n s_solve = sum(r[\"solve_rate\"] for r in spatial.values()) / len(spatial)\n v_solve = sum(r[\"solve_rate\"] for r in vanilla.values()) / len(vanilla)\n print(f\" {'Solve Rate':<20} {s_solve:>13.1%} {v_solve:>13.1%} {s_solve - v_solve:>+7.1%}\")\nelse:\n print(f\"Run both models to compare. Looking for:\\n {spatial_path}\\n {vanilla_path}\")",
|
| 147 |
+
"metadata": {},
|
| 148 |
+
"execution_count": null,
|
| 149 |
+
"outputs": []
|
| 150 |
+
},
|
| 151 |
+
{
|
| 152 |
+
"cell_type": "markdown",
|
| 153 |
+
"id": "812csd43vxk",
|
| 154 |
+
"source": "## 10. Push to HuggingFace Hub (optional)\n\nUpload the trained LoRA adapter to HF for deployment or further fine-tuning.",
|
| 155 |
+
"metadata": {}
|
| 156 |
+
},
|
| 157 |
+
{
|
| 158 |
+
"cell_type": "code",
|
| 159 |
+
"id": "h38kp70n16q",
|
| 160 |
+
"source": "# ── Push to HF Hub (uncomment and set your repo) ───────────────────────────\n# from huggingface_hub import login\n# login(token=\"hf_...\") # or use HF_TOKEN env var\n#\n# HF_REPO = \"your-username/optigami-grpo-spatialthinker\"\n# model.push_to_hub(HF_REPO)\n# tokenizer.push_to_hub(HF_REPO)\n# print(f\"Model pushed to https://huggingface.co/{HF_REPO}\")\n\ntrackio.finish()\nprint(\"Done! Check your Trackio dashboard for training curves.\")",
|
| 161 |
+
"metadata": {},
|
| 162 |
+
"execution_count": null,
|
| 163 |
+
"outputs": []
|
| 164 |
+
}
|
| 165 |
+
],
|
| 166 |
+
"metadata": {
|
| 167 |
+
"kernelspec": {
|
| 168 |
+
"display_name": "Python 3",
|
| 169 |
+
"language": "python",
|
| 170 |
+
"name": "python3"
|
| 171 |
+
},
|
| 172 |
+
"language_info": {
|
| 173 |
+
"name": "python",
|
| 174 |
+
"version": "3.10.0"
|
| 175 |
+
},
|
| 176 |
+
"colab": {
|
| 177 |
+
"provenance": [],
|
| 178 |
+
"gpuType": "A100"
|
| 179 |
+
},
|
| 180 |
+
"accelerator": "GPU"
|
| 181 |
+
},
|
| 182 |
+
"nbformat": 4,
|
| 183 |
+
"nbformat_minor": 5
|
| 184 |
+
}
|
trainer/mock_env.py
CHANGED
|
@@ -135,6 +135,22 @@ def apply_fold_mock(state: PaperState, fold: dict) -> tuple[PaperState, str | No
|
|
| 135 |
if fold_type not in ("valley", "mountain"):
|
| 136 |
return state, f"Unknown fold type: {fold_type}"
|
| 137 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
if not (0 < angle_deg <= 180):
|
| 139 |
return state, f"Angle must be in (0, 180], got {angle_deg}"
|
| 140 |
|
|
|
|
| 135 |
if fold_type not in ("valley", "mountain"):
|
| 136 |
return state, f"Unknown fold type: {fold_type}"
|
| 137 |
|
| 138 |
+
# angle=0 means "no fold" — return unchanged copy
|
| 139 |
+
if angle_deg == 0:
|
| 140 |
+
return PaperState(
|
| 141 |
+
vertices=state.vertices.copy(), edges=state.edges.copy(),
|
| 142 |
+
faces=[f[:] for f in state.faces],
|
| 143 |
+
assignments=state.assignments[:], fold_angles=state.fold_angles.copy(),
|
| 144 |
+
rest_lengths=state.rest_lengths.copy(), strain=state.strain.copy(),
|
| 145 |
+
energy=state.energy, face_orders=state.face_orders[:],
|
| 146 |
+
num_layers=state.num_layers, material=state.material,
|
| 147 |
+
bounding_box=state.bounding_box.copy(),
|
| 148 |
+
deployment_ratio=state.deployment_ratio, is_valid=state.is_valid,
|
| 149 |
+
kawasaki_violation=state.kawasaki_violation,
|
| 150 |
+
maekawa_violation=state.maekawa_violation,
|
| 151 |
+
self_intersections=state.self_intersections,
|
| 152 |
+
), None
|
| 153 |
+
|
| 154 |
if not (0 < angle_deg <= 180):
|
| 155 |
return state, f"Angle must be in (0, 180], got {angle_deg}"
|
| 156 |
|
trainer/prompts.py
CHANGED
|
@@ -1,49 +1,99 @@
|
|
| 1 |
"""
|
| 2 |
Prompt templates for origami fold strategy generation.
|
| 3 |
|
| 4 |
-
|
| 5 |
-
a
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
"""
|
| 7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
SYSTEM_PROMPT = """\
|
| 9 |
-
You are an origami engineer
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
-
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
Rules:
|
| 16 |
- Only use native Python (no imports except math, itertools, functools)
|
| 17 |
- Each fold: {"type": "valley"|"mountain", "line": {"start": [x,y], "end": [x,y]}, "angle": 0-180}
|
| 18 |
-
- Fold lines must
|
|
|
|
|
|
|
|
|
|
| 19 |
- Fewer folds is better (efficiency matters)
|
| 20 |
-
- Respect material strain limits
|
| 21 |
-
- Output ONLY the function in ```python ... ``` backticks\
|
| 22 |
"""
|
| 23 |
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
TASK_TEMPLATES = {
|
| 26 |
"half_fold": {
|
| 27 |
"name": "half_fold",
|
| 28 |
"prompt": """\
|
| 29 |
TASK: Fold a {width}m x {height}m {material} sheet in half to minimize one dimension.
|
| 30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
MATERIAL: {material} (thickness: {thickness_mm}mm, max strain: {max_strain_pct}%)
|
| 32 |
CONSTRAINTS: Maximum {max_folds} fold operations.
|
| 33 |
-
TARGET: Deployment ratio <= 0.5
|
| 34 |
-
|
| 35 |
-
CURRENT STATE:
|
| 36 |
-
Sheet: {width}m x {height}m, flat (0 folds applied)
|
| 37 |
-
Bounding box: {width}m x {height}m x 0.0m
|
| 38 |
-
|
| 39 |
-
Write a fold_strategy(paper_state) function that returns a list of fold operations.
|
| 40 |
-
Each fold: {{"type": "valley"|"mountain", "line": {{"start": [x,y], "end": [x,y]}}, "angle": 0-180}}
|
| 41 |
-
|
| 42 |
-
```python
|
| 43 |
-
def fold_strategy(paper_state):
|
| 44 |
-
# Your code here
|
| 45 |
-
return [...]
|
| 46 |
-
```""",
|
| 47 |
"target_ratio": 0.5,
|
| 48 |
"max_folds": 3,
|
| 49 |
},
|
|
@@ -53,21 +103,14 @@ def fold_strategy(paper_state):
|
|
| 53 |
"prompt": """\
|
| 54 |
TASK: Fold a {width}m x {height}m {material} sheet into thirds (like a letter).
|
| 55 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
MATERIAL: {material} (thickness: {thickness_mm}mm, max strain: {max_strain_pct}%)
|
| 57 |
CONSTRAINTS: Maximum {max_folds} fold operations.
|
| 58 |
-
TARGET: Deployment ratio <= 0.33
|
| 59 |
-
|
| 60 |
-
CURRENT STATE:
|
| 61 |
-
Sheet: {width}m x {height}m, flat (0 folds applied)
|
| 62 |
-
|
| 63 |
-
Write a fold_strategy(paper_state) function that returns a list of fold operations.
|
| 64 |
-
Each fold: {{"type": "valley"|"mountain", "line": {{"start": [x,y], "end": [x,y]}}, "angle": 0-180}}
|
| 65 |
-
|
| 66 |
-
```python
|
| 67 |
-
def fold_strategy(paper_state):
|
| 68 |
-
# Your code here
|
| 69 |
-
return [...]
|
| 70 |
-
```""",
|
| 71 |
"target_ratio": 0.33,
|
| 72 |
"max_folds": 5,
|
| 73 |
},
|
|
@@ -78,30 +121,22 @@ def fold_strategy(paper_state):
|
|
| 78 |
TASK: Fold a {width}m x {height}m Mylar sheet to minimize packed volume for a solar panel.
|
| 79 |
The folded panel must be deployable (unfold cleanly to near-original area).
|
| 80 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
MATERIAL: Mylar (thickness: 0.05mm, Young's modulus: 4 GPa, max strain: 3%)
|
| 82 |
CONSTRAINTS:
|
| 83 |
- Maximum {max_folds} fold operations
|
| 84 |
- Must pack into bounding box <= 15cm x 15cm x 5cm
|
| 85 |
-
- Must deploy to >= 80% of original area
|
| 86 |
- No self-intersections
|
| 87 |
|
| 88 |
-
TARGET: Deployment ratio <= 0.05 (95%
|
| 89 |
-
|
| 90 |
-
CURRENT STATE:
|
| 91 |
-
Sheet: {width}m x {height}m, flat (0 folds applied)
|
| 92 |
-
Bounding box: {width}m x {height}m x 0.0m
|
| 93 |
|
| 94 |
-
HINT:
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
Write a fold_strategy(paper_state) function that returns a list of fold operations.
|
| 98 |
-
Each fold: {{"type": "valley"|"mountain", "line": {{"start": [x,y], "end": [x,y]}}, "angle": 0-180}}
|
| 99 |
-
|
| 100 |
-
```python
|
| 101 |
-
def fold_strategy(paper_state):
|
| 102 |
-
# Your code here
|
| 103 |
-
return [...]
|
| 104 |
-
```""",
|
| 105 |
"target_ratio": 0.05,
|
| 106 |
"max_folds": 20,
|
| 107 |
},
|
|
@@ -111,29 +146,26 @@ def fold_strategy(paper_state):
|
|
| 111 |
"prompt": """\
|
| 112 |
TASK: Fold a {width}m x {height}m Nitinol sheet into a compact cylinder for a medical stent.
|
| 113 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
MATERIAL: Nitinol (thickness: 0.1mm, Young's modulus: 75 GPa, max strain: 8%)
|
| 115 |
CONSTRAINTS:
|
| 116 |
- Maximum {max_folds} fold operations
|
| 117 |
-
- Compressed diameter: 3mm
|
| 118 |
-
- Deployed diameter: 10mm
|
| 119 |
-
- Must be radially deployable
|
| 120 |
-
|
| 121 |
-
TARGET: Minimize packed cross-section while maintaining deployability.
|
| 122 |
-
|
| 123 |
-
Write a fold_strategy(paper_state) function that returns a list of fold operations.
|
| 124 |
|
| 125 |
-
|
| 126 |
-
def fold_strategy(paper_state):
|
| 127 |
-
# Your code here
|
| 128 |
-
return [...]
|
| 129 |
-
```""",
|
| 130 |
"target_ratio": 0.1,
|
| 131 |
"max_folds": 15,
|
| 132 |
},
|
| 133 |
}
|
| 134 |
|
| 135 |
|
| 136 |
-
#
|
|
|
|
|
|
|
|
|
|
| 137 |
TASK_CONFIGS = {
|
| 138 |
"half_fold": {
|
| 139 |
"width": 1.0, "height": 1.0, "material": "paper",
|
|
@@ -158,6 +190,16 @@ def build_prompt(task_name: str = "half_fold", **overrides) -> str:
|
|
| 158 |
"""Build a complete user prompt for a given task."""
|
| 159 |
task = TASK_TEMPLATES[task_name]
|
| 160 |
config = {**TASK_CONFIGS[task_name], **overrides}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
return task["prompt"].format(**config)
|
| 162 |
|
| 163 |
|
|
|
|
| 1 |
"""
|
| 2 |
Prompt templates for origami fold strategy generation.
|
| 3 |
|
| 4 |
+
Inspired by SpatialThinker (arXiv 2511.07403): the model must produce
|
| 5 |
+
a structured spatial representation BEFORE generating code.
|
| 6 |
+
|
| 7 |
+
Output format (4 stages):
|
| 8 |
+
<observe> — Describe the paper geometry and constraints
|
| 9 |
+
<plan> — Structured fold plan with coordinates and reasoning
|
| 10 |
+
<code> — The fold_strategy() function
|
| 11 |
+
<verify> — Predict expected outcome (deployment ratio, fold count)
|
| 12 |
+
|
| 13 |
+
Dense rewards check each stage independently, not just code execution.
|
| 14 |
"""
|
| 15 |
|
| 16 |
+
# ---------------------------------------------------------------------------
|
| 17 |
+
# System prompt — defines the structured output format
|
| 18 |
+
# ---------------------------------------------------------------------------
|
| 19 |
+
|
| 20 |
SYSTEM_PROMPT = """\
|
| 21 |
+
You are an origami engineer specializing in computational fold design.
|
| 22 |
+
You solve folding tasks by reasoning spatially about paper geometry.
|
| 23 |
+
|
| 24 |
+
You MUST respond in exactly this 4-stage format:
|
| 25 |
+
|
| 26 |
+
<observe>
|
| 27 |
+
Describe the paper: dimensions, material, coordinate system.
|
| 28 |
+
Identify key geometric features (center, edges, diagonals, symmetry axes).
|
| 29 |
+
Note constraints (max strain, max folds, target ratio).
|
| 30 |
+
</observe>
|
| 31 |
+
|
| 32 |
+
<plan>
|
| 33 |
+
{
|
| 34 |
+
"strategy": "description of overall approach",
|
| 35 |
+
"folds": [
|
| 36 |
+
{
|
| 37 |
+
"description": "what this fold does",
|
| 38 |
+
"type": "valley or mountain",
|
| 39 |
+
"line_start": [x, y],
|
| 40 |
+
"line_end": [x, y],
|
| 41 |
+
"angle": 180,
|
| 42 |
+
"reasoning": "why these coordinates"
|
| 43 |
+
}
|
| 44 |
+
],
|
| 45 |
+
"expected_ratio": 0.5,
|
| 46 |
+
"expected_folds": 1
|
| 47 |
+
}
|
| 48 |
+
</plan>
|
| 49 |
+
|
| 50 |
+
<code>
|
| 51 |
+
```python
|
| 52 |
+
def fold_strategy(paper_state):
|
| 53 |
+
# Implementation matching the plan above
|
| 54 |
+
return [...]
|
| 55 |
+
```
|
| 56 |
+
</code>
|
| 57 |
|
| 58 |
+
<verify>
|
| 59 |
+
Expected deployment ratio: X.XX
|
| 60 |
+
Expected fold count: N
|
| 61 |
+
Expected max strain: X.XXXX
|
| 62 |
+
Potential issues: ...
|
| 63 |
+
</verify>
|
| 64 |
|
| 65 |
Rules:
|
| 66 |
- Only use native Python (no imports except math, itertools, functools)
|
| 67 |
- Each fold: {"type": "valley"|"mountain", "line": {"start": [x,y], "end": [x,y]}, "angle": 0-180}
|
| 68 |
+
- Fold lines must cross the paper boundary (intersect at least 2 edges)
|
| 69 |
+
- Valley = fold toward you (+Z), Mountain = fold away (-Z)
|
| 70 |
+
- angle=180 = fully folded, smaller = partial fold
|
| 71 |
+
- Each fold changes the geometry — later folds operate on already-folded paper
|
| 72 |
- Fewer folds is better (efficiency matters)
|
| 73 |
+
- Respect material strain limits\
|
|
|
|
| 74 |
"""
|
| 75 |
|
| 76 |
|
| 77 |
+
# ---------------------------------------------------------------------------
|
| 78 |
+
# Task templates — each includes spatial context
|
| 79 |
+
# ---------------------------------------------------------------------------
|
| 80 |
+
|
| 81 |
TASK_TEMPLATES = {
|
| 82 |
"half_fold": {
|
| 83 |
"name": "half_fold",
|
| 84 |
"prompt": """\
|
| 85 |
TASK: Fold a {width}m x {height}m {material} sheet in half to minimize one dimension.
|
| 86 |
|
| 87 |
+
PAPER GEOMETRY:
|
| 88 |
+
Corners: (0,0), ({width},0), ({width},{height}), (0,{height})
|
| 89 |
+
Center: ({cx},{cy})
|
| 90 |
+
Horizontal midline: y={cy} from (0,{cy}) to ({width},{cy})
|
| 91 |
+
Vertical midline: x={cx} from ({cx},0) to ({cx},{height})
|
| 92 |
+
Diagonals: (0,0)→({width},{height}) and ({width},0)→(0,{height})
|
| 93 |
+
|
| 94 |
MATERIAL: {material} (thickness: {thickness_mm}mm, max strain: {max_strain_pct}%)
|
| 95 |
CONSTRAINTS: Maximum {max_folds} fold operations.
|
| 96 |
+
TARGET: Deployment ratio <= 0.5""",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
"target_ratio": 0.5,
|
| 98 |
"max_folds": 3,
|
| 99 |
},
|
|
|
|
| 103 |
"prompt": """\
|
| 104 |
TASK: Fold a {width}m x {height}m {material} sheet into thirds (like a letter).
|
| 105 |
|
| 106 |
+
PAPER GEOMETRY:
|
| 107 |
+
Corners: (0,0), ({width},0), ({width},{height}), (0,{height})
|
| 108 |
+
Third lines: y={t1:.4f} and y={t2:.4f}
|
| 109 |
+
Center: ({cx},{cy})
|
| 110 |
+
|
| 111 |
MATERIAL: {material} (thickness: {thickness_mm}mm, max strain: {max_strain_pct}%)
|
| 112 |
CONSTRAINTS: Maximum {max_folds} fold operations.
|
| 113 |
+
TARGET: Deployment ratio <= 0.33""",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
"target_ratio": 0.33,
|
| 115 |
"max_folds": 5,
|
| 116 |
},
|
|
|
|
| 121 |
TASK: Fold a {width}m x {height}m Mylar sheet to minimize packed volume for a solar panel.
|
| 122 |
The folded panel must be deployable (unfold cleanly to near-original area).
|
| 123 |
|
| 124 |
+
PAPER GEOMETRY:
|
| 125 |
+
Corners: (0,0), ({width},0), ({width},{height}), (0,{height})
|
| 126 |
+
Center: ({cx},{cy})
|
| 127 |
+
Area: {area}m²
|
| 128 |
+
|
| 129 |
MATERIAL: Mylar (thickness: 0.05mm, Young's modulus: 4 GPa, max strain: 3%)
|
| 130 |
CONSTRAINTS:
|
| 131 |
- Maximum {max_folds} fold operations
|
| 132 |
- Must pack into bounding box <= 15cm x 15cm x 5cm
|
|
|
|
| 133 |
- No self-intersections
|
| 134 |
|
| 135 |
+
TARGET: Deployment ratio <= 0.05 (95% area reduction)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
|
| 137 |
+
HINT: Tessellated patterns (alternating M/V folds in a grid) achieve high
|
| 138 |
+
compaction with single-DOF deployment. Consider dividing the sheet into
|
| 139 |
+
a regular grid of panels.""",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
"target_ratio": 0.05,
|
| 141 |
"max_folds": 20,
|
| 142 |
},
|
|
|
|
| 146 |
"prompt": """\
|
| 147 |
TASK: Fold a {width}m x {height}m Nitinol sheet into a compact cylinder for a medical stent.
|
| 148 |
|
| 149 |
+
PAPER GEOMETRY:
|
| 150 |
+
Corners: (0,0), ({width},0), ({width},{height}), (0,{height})
|
| 151 |
+
Center: ({cx},{cy})
|
| 152 |
+
|
| 153 |
MATERIAL: Nitinol (thickness: 0.1mm, Young's modulus: 75 GPa, max strain: 8%)
|
| 154 |
CONSTRAINTS:
|
| 155 |
- Maximum {max_folds} fold operations
|
| 156 |
+
- Compressed diameter: 3mm, Deployed diameter: 10mm
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
|
| 158 |
+
TARGET: Deployment ratio <= 0.1""",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
"target_ratio": 0.1,
|
| 160 |
"max_folds": 15,
|
| 161 |
},
|
| 162 |
}
|
| 163 |
|
| 164 |
|
| 165 |
+
# ---------------------------------------------------------------------------
|
| 166 |
+
# Config and builders
|
| 167 |
+
# ---------------------------------------------------------------------------
|
| 168 |
+
|
| 169 |
TASK_CONFIGS = {
|
| 170 |
"half_fold": {
|
| 171 |
"width": 1.0, "height": 1.0, "material": "paper",
|
|
|
|
| 190 |
"""Build a complete user prompt for a given task."""
|
| 191 |
task = TASK_TEMPLATES[task_name]
|
| 192 |
config = {**TASK_CONFIGS[task_name], **overrides}
|
| 193 |
+
|
| 194 |
+
# Add computed geometry values
|
| 195 |
+
w = config["width"]
|
| 196 |
+
h = config["height"]
|
| 197 |
+
config["cx"] = w / 2
|
| 198 |
+
config["cy"] = h / 2
|
| 199 |
+
config["area"] = w * h
|
| 200 |
+
config["t1"] = h / 3
|
| 201 |
+
config["t2"] = 2 * h / 3
|
| 202 |
+
|
| 203 |
return task["prompt"].format(**config)
|
| 204 |
|
| 205 |
|
trainer/rewards.py
CHANGED
|
@@ -1,17 +1,22 @@
|
|
| 1 |
"""
|
| 2 |
Reward functions for origami GRPO training.
|
| 3 |
|
| 4 |
-
|
| 5 |
-
1.
|
| 6 |
-
2.
|
| 7 |
-
3.
|
|
|
|
| 8 |
|
| 9 |
-
|
| 10 |
-
|
|
|
|
|
|
|
| 11 |
"""
|
| 12 |
|
| 13 |
import ast
|
|
|
|
| 14 |
import sys
|
|
|
|
| 15 |
import math
|
| 16 |
import traceback
|
| 17 |
from typing import Callable
|
|
@@ -60,23 +65,57 @@ except ImportError:
|
|
| 60 |
# ---------------------------------------------------------------------------
|
| 61 |
|
| 62 |
def extract_function(text: str) -> str | None:
|
| 63 |
-
"""Extract
|
| 64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
return None
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
|
|
|
| 70 |
# Find the def statement
|
| 71 |
-
def_idx =
|
| 72 |
if def_idx == -1:
|
| 73 |
return None
|
| 74 |
-
fx =
|
| 75 |
if fx.startswith("def fold_strategy("):
|
| 76 |
return fx
|
| 77 |
return None
|
| 78 |
|
| 79 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
def check_imports_stdlib_only(code: str) -> tuple[bool, str]:
|
| 81 |
"""Check that code only imports from Python stdlib."""
|
| 82 |
try:
|
|
@@ -386,3 +425,289 @@ def fold_quality(completions, **kwargs) -> list[float]:
|
|
| 386 |
scores.append(-3.0)
|
| 387 |
|
| 388 |
return scores
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
"""
|
| 2 |
Reward functions for origami GRPO training.
|
| 3 |
|
| 4 |
+
SpatialThinker-style dense rewards (arXiv 2511.07403):
|
| 5 |
+
1. format_reward (0.10) — All 4 tags present, valid JSON plan, valid function
|
| 6 |
+
2. spatial_reward (0.20) — Fold coordinates in plan are within bounds, lines valid
|
| 7 |
+
3. execution_reward (0.50) — Physical validity + fold quality (code execution)
|
| 8 |
+
4. consistency_reward(0.20) — Plan matches code, verify matches actual results
|
| 9 |
|
| 10 |
+
Plus legacy rewards for backwards compatibility:
|
| 11 |
+
- code_valid, physically_valid, fold_quality
|
| 12 |
+
|
| 13 |
+
Lexicographic gating: if code doesn't parse, downstream rewards are 0.
|
| 14 |
"""
|
| 15 |
|
| 16 |
import ast
|
| 17 |
+
import re
|
| 18 |
import sys
|
| 19 |
+
import json
|
| 20 |
import math
|
| 21 |
import traceback
|
| 22 |
from typing import Callable
|
|
|
|
| 65 |
# ---------------------------------------------------------------------------
|
| 66 |
|
| 67 |
def extract_function(text: str) -> str | None:
|
| 68 |
+
"""Extract fold_strategy() from <code> blocks or triple-backtick code blocks."""
|
| 69 |
+
# Try <code> block first (SpatialThinker format)
|
| 70 |
+
code_match = re.search(r'<code>(.*?)</code>', text, re.DOTALL)
|
| 71 |
+
if code_match:
|
| 72 |
+
code_block = code_match.group(1).strip()
|
| 73 |
+
elif text.count("```") >= 2:
|
| 74 |
+
first = text.find("```") + 3
|
| 75 |
+
second = text.find("```", first)
|
| 76 |
+
code_block = text[first:second].strip()
|
| 77 |
+
else:
|
| 78 |
return None
|
| 79 |
+
|
| 80 |
+
code_block = code_block.removeprefix("```python\n").removeprefix("```python\r\n")
|
| 81 |
+
code_block = code_block.removeprefix("python\n").removeprefix("python\r\n")
|
| 82 |
+
code_block = code_block.rstrip("`").strip()
|
| 83 |
+
|
| 84 |
# Find the def statement
|
| 85 |
+
def_idx = code_block.find("def ")
|
| 86 |
if def_idx == -1:
|
| 87 |
return None
|
| 88 |
+
fx = code_block[def_idx:]
|
| 89 |
if fx.startswith("def fold_strategy("):
|
| 90 |
return fx
|
| 91 |
return None
|
| 92 |
|
| 93 |
|
| 94 |
+
def extract_section(text: str, tag: str) -> str | None:
|
| 95 |
+
"""Extract content between <tag>...</tag>."""
|
| 96 |
+
match = re.search(rf'<{tag}>(.*?)</{tag}>', text, re.DOTALL)
|
| 97 |
+
return match.group(1).strip() if match else None
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def extract_plan_json(text: str) -> dict | None:
|
| 101 |
+
"""Extract and parse the JSON fold plan from <plan> block."""
|
| 102 |
+
plan_text = extract_section(text, "plan")
|
| 103 |
+
if not plan_text:
|
| 104 |
+
return None
|
| 105 |
+
try:
|
| 106 |
+
return json.loads(plan_text)
|
| 107 |
+
except json.JSONDecodeError:
|
| 108 |
+
# Try to find JSON object within the plan text
|
| 109 |
+
brace_start = plan_text.find("{")
|
| 110 |
+
brace_end = plan_text.rfind("}")
|
| 111 |
+
if brace_start >= 0 and brace_end > brace_start:
|
| 112 |
+
try:
|
| 113 |
+
return json.loads(plan_text[brace_start:brace_end + 1])
|
| 114 |
+
except json.JSONDecodeError:
|
| 115 |
+
pass
|
| 116 |
+
return None
|
| 117 |
+
|
| 118 |
+
|
| 119 |
def check_imports_stdlib_only(code: str) -> tuple[bool, str]:
|
| 120 |
"""Check that code only imports from Python stdlib."""
|
| 121 |
try:
|
|
|
|
| 425 |
scores.append(-3.0)
|
| 426 |
|
| 427 |
return scores
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
# ---------------------------------------------------------------------------
|
| 431 |
+
# SpatialThinker Dense Rewards (weight 0.10 + 0.20 + 0.50 + 0.20 = 1.0)
|
| 432 |
+
# ---------------------------------------------------------------------------
|
| 433 |
+
|
| 434 |
+
REQUIRED_TAGS = ["observe", "plan", "code", "verify"]
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
def format_reward(completions, **kwargs) -> list[float]:
|
| 438 |
+
"""
|
| 439 |
+
SpatialThinker format reward (weight: 0.10).
|
| 440 |
+
|
| 441 |
+
Checks that the response has all 4 structured tags, valid JSON in <plan>,
|
| 442 |
+
and a parseable function in <code>.
|
| 443 |
+
|
| 444 |
+
Score range: [0.0, 1.0]
|
| 445 |
+
"""
|
| 446 |
+
scores = []
|
| 447 |
+
for completion in completions:
|
| 448 |
+
response = completion[0]["content"]
|
| 449 |
+
score = 0.0
|
| 450 |
+
|
| 451 |
+
# Check each required tag (0.15 each = 0.60 for all 4)
|
| 452 |
+
tags_present = 0
|
| 453 |
+
for tag in REQUIRED_TAGS:
|
| 454 |
+
if extract_section(response, tag) is not None:
|
| 455 |
+
tags_present += 1
|
| 456 |
+
score += 0.15 * tags_present
|
| 457 |
+
|
| 458 |
+
# Valid JSON in <plan> (0.20)
|
| 459 |
+
plan = extract_plan_json(response)
|
| 460 |
+
if plan is not None:
|
| 461 |
+
score += 0.20
|
| 462 |
+
# Plan has required fields (0.05 bonus)
|
| 463 |
+
if "folds" in plan and isinstance(plan["folds"], list):
|
| 464 |
+
score += 0.05
|
| 465 |
+
|
| 466 |
+
# Valid function in <code> (0.15)
|
| 467 |
+
fn = extract_function(response)
|
| 468 |
+
if fn is not None:
|
| 469 |
+
score += 0.15
|
| 470 |
+
|
| 471 |
+
scores.append(score)
|
| 472 |
+
return scores
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
def spatial_reward(completions, **kwargs) -> list[float]:
|
| 476 |
+
"""
|
| 477 |
+
SpatialThinker spatial plan quality reward (weight: 0.20).
|
| 478 |
+
|
| 479 |
+
Checks that fold coordinates in <plan> are geometrically valid:
|
| 480 |
+
- Within paper bounds
|
| 481 |
+
- Line endpoints form valid fold lines (cross the paper)
|
| 482 |
+
- Fold types are valid
|
| 483 |
+
- Expected ratio/count are reasonable
|
| 484 |
+
|
| 485 |
+
Score range: [0.0, 1.0]
|
| 486 |
+
"""
|
| 487 |
+
w = _current_task["width"]
|
| 488 |
+
h = _current_task["height"]
|
| 489 |
+
|
| 490 |
+
scores = []
|
| 491 |
+
for completion in completions:
|
| 492 |
+
response = completion[0]["content"]
|
| 493 |
+
plan = extract_plan_json(response)
|
| 494 |
+
|
| 495 |
+
if plan is None:
|
| 496 |
+
scores.append(0.0)
|
| 497 |
+
continue
|
| 498 |
+
|
| 499 |
+
score = 0.0
|
| 500 |
+
folds = plan.get("folds", [])
|
| 501 |
+
|
| 502 |
+
if not folds:
|
| 503 |
+
scores.append(0.0)
|
| 504 |
+
continue
|
| 505 |
+
|
| 506 |
+
# Score each fold in the plan
|
| 507 |
+
valid_folds = 0
|
| 508 |
+
for fold in folds:
|
| 509 |
+
fold_score = 0.0
|
| 510 |
+
|
| 511 |
+
# Has required fields
|
| 512 |
+
has_type = fold.get("type") in ("valley", "mountain")
|
| 513 |
+
has_start = isinstance(fold.get("line_start"), list) and len(fold.get("line_start", [])) == 2
|
| 514 |
+
has_end = isinstance(fold.get("line_end"), list) and len(fold.get("line_end", [])) == 2
|
| 515 |
+
|
| 516 |
+
if has_type:
|
| 517 |
+
fold_score += 0.25
|
| 518 |
+
if has_start and has_end:
|
| 519 |
+
fold_score += 0.25
|
| 520 |
+
# Coordinates within paper bounds (with small tolerance)
|
| 521 |
+
sx, sy = fold["line_start"]
|
| 522 |
+
ex, ey = fold["line_end"]
|
| 523 |
+
tol = 0.01
|
| 524 |
+
in_bounds = (
|
| 525 |
+
-tol <= sx <= w + tol and -tol <= sy <= h + tol and
|
| 526 |
+
-tol <= ex <= w + tol and -tol <= ey <= h + tol
|
| 527 |
+
)
|
| 528 |
+
if in_bounds:
|
| 529 |
+
fold_score += 0.25
|
| 530 |
+
|
| 531 |
+
# Start != end (not a degenerate line)
|
| 532 |
+
dist = math.sqrt((ex - sx)**2 + (ey - sy)**2)
|
| 533 |
+
if dist > 0.01:
|
| 534 |
+
fold_score += 0.25
|
| 535 |
+
|
| 536 |
+
if fold_score > 0.5:
|
| 537 |
+
valid_folds += 1
|
| 538 |
+
|
| 539 |
+
# Proportion of valid folds
|
| 540 |
+
score = valid_folds / len(folds) if folds else 0.0
|
| 541 |
+
|
| 542 |
+
# Bonus: expected_ratio is reasonable (0.0 to 1.0)
|
| 543 |
+
expected = plan.get("expected_ratio")
|
| 544 |
+
if isinstance(expected, (int, float)) and 0.0 < expected <= 1.0:
|
| 545 |
+
score = min(1.0, score + 0.1)
|
| 546 |
+
|
| 547 |
+
scores.append(min(1.0, score))
|
| 548 |
+
return scores
|
| 549 |
+
|
| 550 |
+
|
| 551 |
+
def execution_reward(completions, **kwargs) -> list[float]:
|
| 552 |
+
"""
|
| 553 |
+
SpatialThinker execution/accuracy reward (weight: 0.50).
|
| 554 |
+
|
| 555 |
+
Combines code validity, physical validity, and fold quality into
|
| 556 |
+
one normalized score. This is the main reward signal.
|
| 557 |
+
|
| 558 |
+
Score range: [0.0, 1.0]
|
| 559 |
+
"""
|
| 560 |
+
scores = []
|
| 561 |
+
for completion in completions:
|
| 562 |
+
response = completion[0]["content"]
|
| 563 |
+
function_code = extract_function(response)
|
| 564 |
+
|
| 565 |
+
# Gate: no function → 0
|
| 566 |
+
if function_code is None:
|
| 567 |
+
scores.append(0.0)
|
| 568 |
+
continue
|
| 569 |
+
|
| 570 |
+
ok, info = check_imports_stdlib_only(function_code)
|
| 571 |
+
if not ok:
|
| 572 |
+
scores.append(0.0)
|
| 573 |
+
continue
|
| 574 |
+
|
| 575 |
+
try:
|
| 576 |
+
strategy_fn = create_sandboxed_function(function_code)
|
| 577 |
+
except Exception:
|
| 578 |
+
scores.append(0.0)
|
| 579 |
+
continue
|
| 580 |
+
|
| 581 |
+
try:
|
| 582 |
+
paper = _create_sheet(
|
| 583 |
+
_current_task["width"],
|
| 584 |
+
_current_task["height"],
|
| 585 |
+
_current_task["material"],
|
| 586 |
+
)
|
| 587 |
+
original = paper
|
| 588 |
+
final_state, applied, error = execute_fold_strategy(
|
| 589 |
+
strategy_fn, paper, _current_task["max_folds"]
|
| 590 |
+
)
|
| 591 |
+
|
| 592 |
+
if error or len(applied) == 0:
|
| 593 |
+
scores.append(0.0)
|
| 594 |
+
continue
|
| 595 |
+
|
| 596 |
+
val = validate_paper(final_state)
|
| 597 |
+
metrics = compute_metrics(final_state, original)
|
| 598 |
+
deploy_ratio = metrics.get("deployment_ratio", 1.0)
|
| 599 |
+
max_strain = metrics.get("max_strain", 0.0)
|
| 600 |
+
|
| 601 |
+
# Physical validity component (0-0.3)
|
| 602 |
+
phys = 0.3
|
| 603 |
+
if not val.is_valid:
|
| 604 |
+
phys -= 0.1 * val.kawasaki_violation
|
| 605 |
+
phys -= 0.1 * val.maekawa_violation
|
| 606 |
+
if val.self_intersection_count > 0:
|
| 607 |
+
phys -= 0.15
|
| 608 |
+
mat_limit = _current_task["material"].max_strain
|
| 609 |
+
if max_strain > mat_limit:
|
| 610 |
+
phys -= 0.05
|
| 611 |
+
phys = max(0.0, phys)
|
| 612 |
+
|
| 613 |
+
# Quality component (0-0.5)
|
| 614 |
+
compactness = 1.0 - deploy_ratio
|
| 615 |
+
quality = 0.5 * compactness
|
| 616 |
+
|
| 617 |
+
# Target bonus (0-0.2)
|
| 618 |
+
target = 0.0
|
| 619 |
+
if deploy_ratio <= _current_task["target_ratio"]:
|
| 620 |
+
target = 0.2
|
| 621 |
+
|
| 622 |
+
score = phys + quality + target
|
| 623 |
+
scores.append(min(1.0, score))
|
| 624 |
+
|
| 625 |
+
except Exception:
|
| 626 |
+
scores.append(0.0)
|
| 627 |
+
|
| 628 |
+
return scores
|
| 629 |
+
|
| 630 |
+
|
| 631 |
+
def consistency_reward(completions, **kwargs) -> list[float]:
|
| 632 |
+
"""
|
| 633 |
+
SpatialThinker consistency reward (weight: 0.20).
|
| 634 |
+
|
| 635 |
+
Checks that <plan> matches <code> and <verify> matches actual results.
|
| 636 |
+
- Plan fold count matches code fold count
|
| 637 |
+
- Verify predictions close to actual metrics
|
| 638 |
+
|
| 639 |
+
Score range: [0.0, 1.0]
|
| 640 |
+
"""
|
| 641 |
+
scores = []
|
| 642 |
+
for completion in completions:
|
| 643 |
+
response = completion[0]["content"]
|
| 644 |
+
plan = extract_plan_json(response)
|
| 645 |
+
verify = extract_section(response, "verify")
|
| 646 |
+
function_code = extract_function(response)
|
| 647 |
+
|
| 648 |
+
# Need at least plan + code to check consistency
|
| 649 |
+
if plan is None or function_code is None:
|
| 650 |
+
scores.append(0.0)
|
| 651 |
+
continue
|
| 652 |
+
|
| 653 |
+
score = 0.0
|
| 654 |
+
|
| 655 |
+
# 1. Plan fold count vs code fold count (0.4)
|
| 656 |
+
plan_folds = plan.get("folds", [])
|
| 657 |
+
plan_count = len(plan_folds)
|
| 658 |
+
|
| 659 |
+
try:
|
| 660 |
+
strategy_fn = create_sandboxed_function(function_code)
|
| 661 |
+
paper = _create_sheet(
|
| 662 |
+
_current_task["width"],
|
| 663 |
+
_current_task["height"],
|
| 664 |
+
_current_task["material"],
|
| 665 |
+
)
|
| 666 |
+
original = paper
|
| 667 |
+
final_state, applied, error = execute_fold_strategy(
|
| 668 |
+
strategy_fn, paper, _current_task["max_folds"]
|
| 669 |
+
)
|
| 670 |
+
if error or len(applied) == 0:
|
| 671 |
+
scores.append(0.0)
|
| 672 |
+
continue
|
| 673 |
+
|
| 674 |
+
actual_count = len(applied)
|
| 675 |
+
if plan_count == actual_count:
|
| 676 |
+
score += 0.4
|
| 677 |
+
elif abs(plan_count - actual_count) <= 1:
|
| 678 |
+
score += 0.2
|
| 679 |
+
|
| 680 |
+
# 2. Verify predictions vs actual (0.6)
|
| 681 |
+
if verify:
|
| 682 |
+
metrics = compute_metrics(final_state, original)
|
| 683 |
+
actual_ratio = metrics.get("deployment_ratio", 1.0)
|
| 684 |
+
|
| 685 |
+
# Extract predicted ratio from verify text
|
| 686 |
+
ratio_match = re.search(
|
| 687 |
+
r'deployment\s*ratio[:\s]*([\d.]+)', verify, re.IGNORECASE)
|
| 688 |
+
if ratio_match:
|
| 689 |
+
predicted_ratio = float(ratio_match.group(1))
|
| 690 |
+
error_pct = abs(predicted_ratio - actual_ratio)
|
| 691 |
+
if error_pct < 0.05:
|
| 692 |
+
score += 0.4
|
| 693 |
+
elif error_pct < 0.15:
|
| 694 |
+
score += 0.2
|
| 695 |
+
elif error_pct < 0.3:
|
| 696 |
+
score += 0.1
|
| 697 |
+
|
| 698 |
+
# Extract predicted fold count
|
| 699 |
+
count_match = re.search(
|
| 700 |
+
r'fold\s*count[:\s]*(\d+)', verify, re.IGNORECASE)
|
| 701 |
+
if count_match:
|
| 702 |
+
predicted_count = int(count_match.group(1))
|
| 703 |
+
if predicted_count == actual_count:
|
| 704 |
+
score += 0.2
|
| 705 |
+
elif abs(predicted_count - actual_count) <= 1:
|
| 706 |
+
score += 0.1
|
| 707 |
+
|
| 708 |
+
except Exception:
|
| 709 |
+
scores.append(0.0)
|
| 710 |
+
continue
|
| 711 |
+
|
| 712 |
+
scores.append(min(1.0, score))
|
| 713 |
+
return scores
|
trainer/train.py
CHANGED
|
@@ -19,7 +19,10 @@ if PROJECT_ROOT not in sys.path:
|
|
| 19 |
sys.path.insert(0, PROJECT_ROOT)
|
| 20 |
|
| 21 |
from trainer.prompts import build_prompt, SYSTEM_PROMPT, get_task_target_ratio, get_task_max_folds
|
| 22 |
-
from trainer.rewards import
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
try:
|
| 25 |
from engine.materials import get_material
|
|
@@ -167,14 +170,18 @@ def main():
|
|
| 167 |
# ========================================================================
|
| 168 |
# 6. Create trainer and start training
|
| 169 |
# ========================================================================
|
|
|
|
|
|
|
| 170 |
trainer = GRPOTrainer(
|
| 171 |
model=model,
|
| 172 |
processing_class=tokenizer,
|
| 173 |
reward_funcs=[
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
|
|
|
| 177 |
],
|
|
|
|
| 178 |
args=training_args,
|
| 179 |
train_dataset=dataset,
|
| 180 |
)
|
|
|
|
| 19 |
sys.path.insert(0, PROJECT_ROOT)
|
| 20 |
|
| 21 |
from trainer.prompts import build_prompt, SYSTEM_PROMPT, get_task_target_ratio, get_task_max_folds
|
| 22 |
+
from trainer.rewards import (
|
| 23 |
+
code_valid, physically_valid, fold_quality, set_task_config,
|
| 24 |
+
format_reward, spatial_reward, execution_reward, consistency_reward,
|
| 25 |
+
)
|
| 26 |
|
| 27 |
try:
|
| 28 |
from engine.materials import get_material
|
|
|
|
| 170 |
# ========================================================================
|
| 171 |
# 6. Create trainer and start training
|
| 172 |
# ========================================================================
|
| 173 |
+
# SpatialThinker dense rewards (weighted: 0.10 + 0.20 + 0.50 + 0.20)
|
| 174 |
+
# These replace the legacy 3-reward setup with structured spatial reasoning
|
| 175 |
trainer = GRPOTrainer(
|
| 176 |
model=model,
|
| 177 |
processing_class=tokenizer,
|
| 178 |
reward_funcs=[
|
| 179 |
+
format_reward, # 0.10 — 4-stage format compliance
|
| 180 |
+
spatial_reward, # 0.20 — fold plan geometric validity
|
| 181 |
+
execution_reward, # 0.50 — code execution + physical quality
|
| 182 |
+
consistency_reward, # 0.20 — plan↔code↔verify agreement
|
| 183 |
],
|
| 184 |
+
reward_weights=[0.10, 0.20, 0.50, 0.20],
|
| 185 |
args=training_args,
|
| 186 |
train_dataset=dataset,
|
| 187 |
)
|