ianalin123 commited on
Commit
5eca717
·
2 Parent(s): 8cc1585 2c8a058

Merge branch 'main' of https://huggingface.co/spaces/openenv-community/optigami

Browse files
Files changed (6) hide show
  1. train.py +29 -6
  2. train_origami.ipynb +184 -0
  3. trainer/mock_env.py +16 -0
  4. trainer/prompts.py +108 -66
  5. trainer/rewards.py +339 -14
  6. trainer/train.py +11 -4
train.py CHANGED
@@ -2,9 +2,14 @@
2
  OrigamiRL — GRPO Training Script
3
  Code-as-policy: model generates complete fold sequence, gets terminal reward.
4
 
 
 
 
5
  Usage:
6
  python train.py
7
- python train.py --model unsloth/Qwen2.5-7B-Instruct --epochs 3 --output origami-grpo
 
 
8
  """
9
  import argparse
10
  import json
@@ -13,10 +18,19 @@ import random
13
  from pathlib import Path
14
  from typing import Optional
15
 
 
 
 
 
 
 
 
16
 
17
  def parse_args():
18
  parser = argparse.ArgumentParser()
19
- parser.add_argument('--model', default='unsloth/Qwen2.5-7B-Instruct')
 
 
20
  parser.add_argument('--max_seq_length', type=int, default=2048)
21
  parser.add_argument('--epochs', type=int, default=3)
22
  parser.add_argument('--batch_size', type=int, default=2)
@@ -148,20 +162,29 @@ def main():
148
  return
149
 
150
  # Load model via unsloth
 
 
 
 
151
  try:
152
- from unsloth import FastLanguageModel
 
 
 
 
 
153
  except ImportError:
154
  print("ERROR: unsloth not installed. Run: pip install unsloth")
155
  print("Or run with --dry_run to test the reward function without a model.")
156
  return
157
 
158
- model, tokenizer = FastLanguageModel.from_pretrained(
159
  model_name=args.model,
160
  max_seq_length=args.max_seq_length,
161
  load_in_4bit=True,
162
  )
163
 
164
- model = FastLanguageModel.get_peft_model(
165
  model,
166
  r=32,
167
  target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
@@ -193,7 +216,7 @@ def main():
193
  num_generations=args.n_generations,
194
  temperature=1.0,
195
  logging_steps=1,
196
- report_to="wandb",
197
  run_name="origami-grpo",
198
  )
199
 
 
2
  OrigamiRL — GRPO Training Script
3
  Code-as-policy: model generates complete fold sequence, gets terminal reward.
4
 
5
+ Base model: SpatialThinker (Qwen2.5-VL-7B fine-tuned for spatial reasoning)
6
+ or any Unsloth-compatible model.
7
+
8
  Usage:
9
  python train.py
10
+ python train.py --model unsloth/Qwen2.5-VL-7B-Instruct --epochs 3
11
+ python train.py --model OX-PIXL/SpatialThinker-Qwen2.5-VL-7B --epochs 3
12
+ python train.py --dry_run # test rewards without GPU
13
  """
14
  import argparse
15
  import json
 
18
  from pathlib import Path
19
  from typing import Optional
20
 
21
+ # VL (vision-language) model identifiers — use FastVisionModel for these
22
+ _VL_MODEL_PATTERNS = ['VL', 'vl', 'Vision', 'vision', 'SpatialThinker', 'SpaceThinker']
23
+
24
+
25
+ def _is_vl_model(model_name: str) -> bool:
26
+ return any(p in model_name for p in _VL_MODEL_PATTERNS)
27
+
28
 
29
  def parse_args():
30
  parser = argparse.ArgumentParser()
31
+ parser.add_argument('--model', default='unsloth/Qwen2.5-VL-7B-Instruct',
32
+ help='Base model. Use unsloth/Qwen2.5-VL-7B-Instruct or '
33
+ 'OX-PIXL/SpatialThinker-Qwen2.5-VL-7B for spatial reasoning')
34
  parser.add_argument('--max_seq_length', type=int, default=2048)
35
  parser.add_argument('--epochs', type=int, default=3)
36
  parser.add_argument('--batch_size', type=int, default=2)
 
162
  return
163
 
164
  # Load model via unsloth
165
+ # VL models (SpatialThinker, Qwen2.5-VL) use FastVisionModel
166
+ # Text-only models use FastLanguageModel
167
+ is_vl = _is_vl_model(args.model)
168
+
169
  try:
170
+ if is_vl:
171
+ from unsloth import FastVisionModel as ModelLoader
172
+ print(f"Loading VL model (vision-language): {args.model}")
173
+ else:
174
+ from unsloth import FastLanguageModel as ModelLoader
175
+ print(f"Loading text model: {args.model}")
176
  except ImportError:
177
  print("ERROR: unsloth not installed. Run: pip install unsloth")
178
  print("Or run with --dry_run to test the reward function without a model.")
179
  return
180
 
181
+ model, tokenizer = ModelLoader.from_pretrained(
182
  model_name=args.model,
183
  max_seq_length=args.max_seq_length,
184
  load_in_4bit=True,
185
  )
186
 
187
+ model = ModelLoader.get_peft_model(
188
  model,
189
  r=32,
190
  target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
 
216
  num_generations=args.n_generations,
217
  temperature=1.0,
218
  logging_steps=1,
219
+ report_to="trackio",
220
  run_name="origami-grpo",
221
  )
222
 
train_origami.ipynb ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "8smrrb11v84",
6
+ "source": "# Optigami — Origami RL Training (GRPO)\n\n**Train an LLM to generate valid origami fold sequences using verifiable geometric rewards.**\n\nArchitecture:\n- **Environment**: `env/` — CreaseGraph + Kawasaki/Maekawa/BLB verifiers + target matching\n- **Policy model**: SpatialThinker (Qwen2.5-VL-7B) or vanilla Qwen2.5-VL-7B\n- **Training**: Unsloth GRPO — model generates complete fold sequences, gets terminal reward\n- **Tracking**: Trackio — real-time reward curves on HF Spaces\n\n| Reward Component | Weight | What it measures |\n|---|---|---|\n| `progress` | 0.45 | Geometric crease coverage vs target |\n| `economy` | 0.10 | Penalty for excess creases |\n| `kawasaki` | 0.08 | Kawasaki theorem satisfaction |\n| `maekawa` | 0.07 | Maekawa theorem satisfaction |\n| `blb` | 0.05 | Big-Little-Big lemma |\n| `anchored` | 0.05 | Valid anchor point usage |\n| `completion` | +10.0 | Bonus when target reached |",
7
+ "metadata": {}
8
+ },
9
+ {
10
+ "cell_type": "markdown",
11
+ "id": "kn1k9d357j",
12
+ "source": "## 1. Setup\n\n**GPU**: H100 80GB (Northflank/CoreWeave) or A100/T4 (Colab)\n\nInstall dependencies. Unsloth handles efficient model loading + LoRA.",
13
+ "metadata": {}
14
+ },
15
+ {
16
+ "cell_type": "code",
17
+ "id": "d10vqzep5b6",
18
+ "source": "%%capture\n!pip install unsloth trackio shapely numpy datasets\n!pip install --upgrade trl transformers\n\n# Check GPU\nimport torch\nprint(f\"GPU: {torch.cuda.get_device_name(0)}\")\nprint(f\"VRAM: {torch.cuda.get_device_properties(0).total_mem / 1e9:.1f} GB\")",
19
+ "metadata": {},
20
+ "execution_count": null,
21
+ "outputs": []
22
+ },
23
+ {
24
+ "cell_type": "markdown",
25
+ "id": "y6wsagz8h",
26
+ "source": "## 2. Configuration\n\nChoose base model and hyperparameters. Two options:\n- **SpatialThinker** (`OX-PIXL/SpatialThinker-Qwen2.5-VL-7B`): Pre-trained for spatial reasoning via RL\n- **Vanilla Qwen2.5-VL** (`unsloth/Qwen2.5-VL-7B-Instruct`): Standard vision-language model\n\nWe'll compare both to see which learns origami folding faster.",
27
+ "metadata": {}
28
+ },
29
+ {
30
+ "cell_type": "code",
31
+ "id": "dh1zapl0w5s",
32
+ "source": "# ── Config ──────────────────────────────────────────────────────────────────\n# Toggle MODEL_NAME to switch between SpatialThinker and vanilla Qwen2.5-VL\n\nMODEL_NAME = \"OX-PIXL/SpatialThinker-Qwen2.5-VL-7B\"\n# MODEL_NAME = \"unsloth/Qwen2.5-VL-7B-Instruct\" # uncomment for vanilla\n\nMAX_SEQ_LENGTH = 2048\nLORA_R = 32\nLORA_ALPHA = 32\nEPOCHS = 3\nBATCH_SIZE = 2\nGRAD_ACCUM = 4\nLR = 5e-6\nN_GENERATIONS = 8 # completions sampled per prompt (GRPO group size)\nMAX_FOLDS = 8 # max folds per episode\nLEVEL = 1 # target difficulty (1=simple, 2=medium, 3=hard)\nMAX_COMPLETION_LEN = 512\nOUTPUT_DIR = \"origami-grpo\"\n\n# Trackio — set your HF Space ID for live dashboard\nTRACKIO_SPACE_ID = None # e.g. \"your-username/optigami-training\"\n\nprint(f\"Model: {MODEL_NAME}\")\nprint(f\"Config: {EPOCHS} epochs, batch={BATCH_SIZE}, grad_accum={GRAD_ACCUM}, lr={LR}\")\nprint(f\"GRPO: {N_GENERATIONS} generations, max_folds={MAX_FOLDS}, level={LEVEL}\")",
33
+ "metadata": {},
34
+ "execution_count": null,
35
+ "outputs": []
36
+ },
37
+ {
38
+ "cell_type": "markdown",
39
+ "id": "o5hhfbp0wb",
40
+ "source": "## 3. Clone Repo & Test Environment\n\nClone the optigami repo (skip if running locally) and verify the environment works.",
41
+ "metadata": {}
42
+ },
43
+ {
44
+ "cell_type": "code",
45
+ "id": "94cemjucczl",
46
+ "source": "import os\n\n# Clone repo if not already present (Colab/Northflank)\nif not os.path.exists(\"env/environment.py\"):\n !git clone https://huggingface.co/spaces/openenv-community/optigami /content/optigami 2>/dev/null || true\n os.chdir(\"/content/optigami\")\n\n# Verify env/ is accessible\nfrom env.environment import OrigamiEnvironment\nfrom env.rewards import compute_reward\nfrom env.prompts import parse_fold_list\n\nenv = OrigamiEnvironment(mode=\"code_as_policy\", max_steps=MAX_FOLDS)\nprint(f\"Available targets: {env.available_targets()}\")\nprint(f\"Environment mode: {env.mode}\")",
47
+ "metadata": {},
48
+ "execution_count": null,
49
+ "outputs": []
50
+ },
51
+ {
52
+ "cell_type": "code",
53
+ "id": "2j9mccejyfx",
54
+ "source": "# ── Dry run: test reward function ───────────────────────────────────────────\n# Verify rewards work before loading the model\n\nimport copy\n\ndef make_reward_fn(env_template, max_folds):\n \"\"\"Reward function: clone env, run completion, return total reward.\"\"\"\n def reward_fn(completions, prompts=None, **kwargs):\n rewards = []\n target_names = kwargs.get(\"target_names\", [None] * len(completions))\n for completion, target_name in zip(completions, target_names):\n try:\n e = env_template.clone()\n e.reset(target_name=target_name)\n _, reward_dict, _, _ = e.step(completion)\n rewards.append(float(reward_dict[\"total\"]))\n except Exception:\n rewards.append(-0.1)\n return rewards\n return reward_fn\n\nreward_fn = make_reward_fn(env, MAX_FOLDS)\n\ntest_completions = [\n '<folds>[{\"instruction\": \"Valley fold along horizontal center\", \"from\": [0, 0.5], \"to\": [1, 0.5], \"assignment\": \"V\"}]</folds>',\n '<folds>[{\"instruction\": \"Bad fold\", \"from\": [0.3, 0.3], \"to\": [0.7, 0.7], \"assignment\": \"V\"}]</folds>',\n 'not valid JSON',\n]\ntarget_names = [\"half_horizontal\"] * 3\nrewards = reward_fn(test_completions, target_names=target_names)\n\nfor comp, r in zip([\"perfect fold\", \"partial fold\", \"garbage\"], rewards):\n print(f\" {comp}: reward = {r:.3f}\")\nprint(\"\\nReward function OK ✓\")",
55
+ "metadata": {},
56
+ "execution_count": null,
57
+ "outputs": []
58
+ },
59
+ {
60
+ "cell_type": "markdown",
61
+ "id": "46gs2p1cy4",
62
+ "source": "## 4. Load Model + LoRA\n\nLoad the VL model with Unsloth's `FastVisionModel` (4-bit quantized) and apply LoRA adapters.",
63
+ "metadata": {}
64
+ },
65
+ {
66
+ "cell_type": "code",
67
+ "id": "82f76od6d2k",
68
+ "source": "from unsloth import FastVisionModel\n\nmodel, tokenizer = FastVisionModel.from_pretrained(\n model_name=MODEL_NAME,\n max_seq_length=MAX_SEQ_LENGTH,\n load_in_4bit=True,\n)\n\nmodel = FastVisionModel.get_peft_model(\n model,\n r=LORA_R,\n target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n \"gate_proj\", \"up_proj\", \"down_proj\"],\n lora_alpha=LORA_ALPHA,\n lora_dropout=0,\n use_gradient_checkpointing=\"unsloth\",\n)\n\nprint(f\"Model loaded: {MODEL_NAME}\")\nprint(f\"LoRA rank: {LORA_R}, alpha: {LORA_ALPHA}\")\nprint(f\"Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}\")",
69
+ "metadata": {},
70
+ "execution_count": null,
71
+ "outputs": []
72
+ },
73
+ {
74
+ "cell_type": "markdown",
75
+ "id": "67dyfrj23y",
76
+ "source": "## 5. Build Dataset\n\nGenerate prompts from all level-appropriate targets. Each prompt embeds the target crease pattern description and asks the model to output `<folds>[...]</folds>`.",
77
+ "metadata": {}
78
+ },
79
+ {
80
+ "cell_type": "code",
81
+ "id": "1msqpzj5fwu",
82
+ "source": "import random\nfrom datasets import Dataset\n\ndef build_dataset(env, level=1):\n \"\"\"Build training dataset of prompts from env targets.\"\"\"\n all_names = env.available_targets()\n level_names = [\n n for n in all_names\n if env._targets[n].get(\"level\", 1) == level\n ]\n if not level_names:\n level_names = all_names\n\n items = []\n for name in level_names:\n obs = env.reset(target_name=name)\n items.append({\"prompt\": obs[\"prompt\"], \"target_name\": name})\n\n # Repeat each target 10x; ensure at least 50 examples\n repeat = max(10, (50 + len(items) - 1) // len(items))\n items = items * repeat\n random.shuffle(items)\n return items\n\ndataset_items = build_dataset(env, level=LEVEL)\nhf_dataset = Dataset.from_list(dataset_items)\n\nprint(f\"Dataset: {len(dataset_items)} examples\")\nprint(f\"Targets in dataset: {sorted(set(d['target_name'] for d in dataset_items))}\")\nprint(f\"\\nSample prompt (first 300 chars):\\n{dataset_items[0]['prompt'][:300]}...\")",
83
+ "metadata": {},
84
+ "execution_count": null,
85
+ "outputs": []
86
+ },
87
+ {
88
+ "cell_type": "markdown",
89
+ "id": "7n3r3nsw8ae",
90
+ "source": "## 6. Trackio Setup\n\nInitialize Trackio for real-time training visualization. Trackio is a free W&B alternative that deploys a Gradio dashboard to HF Spaces.",
91
+ "metadata": {}
92
+ },
93
+ {
94
+ "cell_type": "code",
95
+ "id": "bgru9rsw95b",
96
+ "source": "import trackio\n\n# Initialize Trackio run\ntrackio_kwargs = {\n \"project_name\": \"optigami\",\n \"run_name\": f\"grpo-{MODEL_NAME.split('/')[-1]}-level{LEVEL}\",\n}\nif TRACKIO_SPACE_ID:\n trackio_kwargs[\"space_id\"] = TRACKIO_SPACE_ID\n\ntrackio.init(**trackio_kwargs)\nprint(f\"Trackio initialized: {trackio_kwargs['run_name']}\")\nif TRACKIO_SPACE_ID:\n print(f\"Dashboard: https://huggingface.co/spaces/{TRACKIO_SPACE_ID}\")",
97
+ "metadata": {},
98
+ "execution_count": null,
99
+ "outputs": []
100
+ },
101
+ {
102
+ "cell_type": "markdown",
103
+ "id": "n8aqymlszo",
104
+ "source": "## 7. GRPO Training\n\nRun GRPO with Trackio logging. The trainer samples `N_GENERATIONS` completions per prompt, computes rewards via the environment, and updates the policy using group-relative advantages.",
105
+ "metadata": {}
106
+ },
107
+ {
108
+ "cell_type": "code",
109
+ "id": "ci4imd9ws7v",
110
+ "source": "from trl import GRPOConfig, GRPOTrainer\n\n# ── Per-component reward functions for detailed logging ─────────────────────\nREWARD_COMPONENTS = [\"kawasaki\", \"maekawa\", \"blb\", \"progress\", \"economy\", \"completion\"]\n\ndef make_component_fn(env_template, component):\n \"\"\"Create a reward function that returns a single component's value.\"\"\"\n def component_fn(completions, target_name=None, **kwargs):\n target_names = target_name if isinstance(target_name, list) else [target_name] * len(completions)\n rewards = []\n for completion, tn in zip(completions, target_names):\n try:\n e = env_template.clone()\n e.reset(target_name=tn)\n _, reward_dict, _, _ = e.step(completion)\n rewards.append(float(reward_dict.get(component, 0.0)))\n except Exception:\n rewards.append(0.0)\n return rewards\n component_fn.__name__ = f\"reward_{component}\"\n return component_fn\n\n# Main reward function (returns total reward)\ndef wrapped_reward_fn(completions, target_name=None, **kwargs):\n \"\"\"Main reward function — extracts target_name from batch columns.\"\"\"\n target_names = target_name if isinstance(target_name, list) else [target_name] * len(completions)\n return reward_fn(completions, target_names=target_names)\n\n# Build list of all reward functions: [total, kawasaki, maekawa, blb, progress, economy, completion]\nall_reward_fns = [wrapped_reward_fn] + [\n make_component_fn(env, c) for c in REWARD_COMPONENTS\n]\n\n# ── GRPO Config ─────────────────────────────────────────────────────────────\nconfig = GRPOConfig(\n output_dir=OUTPUT_DIR,\n num_train_epochs=EPOCHS,\n per_device_train_batch_size=BATCH_SIZE,\n gradient_accumulation_steps=GRAD_ACCUM,\n learning_rate=LR,\n max_completion_length=MAX_COMPLETION_LEN,\n num_generations=N_GENERATIONS,\n temperature=1.0,\n logging_steps=1,\n save_steps=50,\n report_to=\"trackio\",\n run_name=f\"grpo-{MODEL_NAME.split('/')[-1]}-level{LEVEL}\",\n)\n\ntrainer = GRPOTrainer(\n model=model,\n config=config,\n 
train_dataset=hf_dataset,\n reward_funcs=all_reward_fns,\n tokenizer=tokenizer,\n)\n\nprint(f\"Trainer ready. Starting GRPO training...\")\nprint(f\" Model: {MODEL_NAME}\")\nprint(f\" Dataset: {len(hf_dataset)} examples\")\nprint(f\" Reward functions: total + {REWARD_COMPONENTS}\")\nprint(f\" Logging to: Trackio\")",
111
+ "metadata": {},
112
+ "execution_count": null,
113
+ "outputs": []
114
+ },
115
+ {
116
+ "cell_type": "code",
117
+ "id": "mwy1i7d509",
118
+ "source": "# ── Train! ──────────────────────────────────────────────────────────────────\ntrainer.train()\n\n# Save model + tokenizer\nmodel.save_pretrained(OUTPUT_DIR)\ntokenizer.save_pretrained(OUTPUT_DIR)\nprint(f\"\\nTraining complete. Model saved to {OUTPUT_DIR}/\")",
119
+ "metadata": {},
120
+ "execution_count": null,
121
+ "outputs": []
122
+ },
123
+ {
124
+ "cell_type": "markdown",
125
+ "id": "54iu59w1ipb",
126
+ "source": "## 8. Evaluation\n\nRun the trained model on all targets and measure solve rates + average rewards.",
127
+ "metadata": {}
128
+ },
129
+ {
130
+ "cell_type": "code",
131
+ "id": "7loiy5mrsi3",
132
+ "source": "# ── Evaluate trained model on all targets ───────────────────────────────────\nfrom transformers import TextStreamer\n\nFastVisionModel.for_inference(model)\n\neval_results = {}\nn_samples = 5 # generations per target\n\nfor target_name in env.available_targets():\n obs = env.reset(target_name=target_name)\n prompt = obs[\"prompt\"]\n\n # Tokenize prompt\n messages = [{\"role\": \"user\", \"content\": prompt}]\n input_ids = tokenizer.apply_chat_template(messages, return_tensors=\"pt\").to(model.device)\n\n target_rewards = []\n target_solved = 0\n\n for _ in range(n_samples):\n outputs = model.generate(\n input_ids=input_ids,\n max_new_tokens=MAX_COMPLETION_LEN,\n temperature=0.7,\n do_sample=True,\n )\n completion = tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)\n\n # Score completion\n e = env.clone()\n e.reset(target_name=target_name)\n try:\n _, reward_dict, _, _ = e.step(completion)\n total = float(reward_dict[\"total\"])\n solved = reward_dict.get(\"completion\", 0) > 0\n except Exception:\n total = -0.1\n solved = False\n\n target_rewards.append(total)\n if solved:\n target_solved += 1\n\n eval_results[target_name] = {\n \"avg_reward\": sum(target_rewards) / len(target_rewards),\n \"max_reward\": max(target_rewards),\n \"solve_rate\": target_solved / n_samples,\n }\n print(f\" {target_name:20s} avg={eval_results[target_name]['avg_reward']:.3f} \"\n f\"max={eval_results[target_name]['max_reward']:.3f} \"\n f\"solved={target_solved}/{n_samples}\")\n\n# Summary\navg_solve = sum(r[\"solve_rate\"] for r in eval_results.values()) / len(eval_results)\navg_reward = sum(r[\"avg_reward\"] for r in eval_results.values()) / len(eval_results)\nprint(f\"\\nOverall: avg_reward={avg_reward:.3f}, solve_rate={avg_solve:.1%}\")\n\n# Log to Trackio\ntrackio.log({\"eval/avg_reward\": avg_reward, \"eval/solve_rate\": avg_solve})\nfor name, res in eval_results.items():\n trackio.log({f\"eval/{name}_reward\": res[\"avg_reward\"], 
f\"eval/{name}_solved\": res[\"solve_rate\"]})",
133
+ "metadata": {},
134
+ "execution_count": null,
135
+ "outputs": []
136
+ },
137
+ {
138
+ "cell_type": "markdown",
139
+ "id": "c9dr4ht05r",
140
+ "source": "## 9. A/B Comparison: SpatialThinker vs Vanilla Qwen2.5-VL\n\nTo compare both models, run this notebook twice:\n1. First run with `MODEL_NAME = \"OX-PIXL/SpatialThinker-Qwen2.5-VL-7B\"`\n2. Second run with `MODEL_NAME = \"unsloth/Qwen2.5-VL-7B-Instruct\"`\n\nBoth runs log to the same Trackio project (`optigami`) with different run names, so you can overlay the reward curves directly in the dashboard.\n\nThe cell below loads saved eval results from both runs for comparison (run after both trainings complete).",
141
+ "metadata": {}
142
+ },
143
+ {
144
+ "cell_type": "code",
145
+ "id": "qwwd4wyuhnq",
146
+ "source": "# ── Save eval results for comparison ────────────────────────────────────────\nimport json\n\nmodel_tag = MODEL_NAME.split(\"/\")[-1]\neval_path = f\"eval_results_{model_tag}_level{LEVEL}.json\"\n\nwith open(eval_path, \"w\") as f:\n json.dump(eval_results, f, indent=2)\nprint(f\"Eval results saved to {eval_path}\")\n\n# ── Compare (run after both models are trained) ────────────────────────────\nspatial_path = f\"eval_results_SpatialThinker-Qwen2.5-VL-7B_level{LEVEL}.json\"\nvanilla_path = f\"eval_results_Qwen2.5-VL-7B-Instruct_level{LEVEL}.json\"\n\nif os.path.exists(spatial_path) and os.path.exists(vanilla_path):\n with open(spatial_path) as f:\n spatial = json.load(f)\n with open(vanilla_path) as f:\n vanilla = json.load(f)\n\n print(f\"\\n{'Target':<22} {'SpatialThinker':>16} {'Vanilla Qwen':>16} {'Delta':>10}\")\n print(\"-\" * 66)\n for target in sorted(set(list(spatial.keys()) + list(vanilla.keys()))):\n s_r = spatial.get(target, {}).get(\"avg_reward\", 0)\n v_r = vanilla.get(target, {}).get(\"avg_reward\", 0)\n delta = s_r - v_r\n print(f\" {target:<20} {s_r:>14.3f} {v_r:>14.3f} {delta:>+8.3f}\")\n\n s_avg = sum(r[\"avg_reward\"] for r in spatial.values()) / len(spatial)\n v_avg = sum(r[\"avg_reward\"] for r in vanilla.values()) / len(vanilla)\n print(f\"\\n {'OVERALL':<20} {s_avg:>14.3f} {v_avg:>14.3f} {s_avg - v_avg:>+8.3f}\")\n\n s_solve = sum(r[\"solve_rate\"] for r in spatial.values()) / len(spatial)\n v_solve = sum(r[\"solve_rate\"] for r in vanilla.values()) / len(vanilla)\n print(f\" {'Solve Rate':<20} {s_solve:>13.1%} {v_solve:>13.1%} {s_solve - v_solve:>+7.1%}\")\nelse:\n print(f\"Run both models to compare. Looking for:\\n {spatial_path}\\n {vanilla_path}\")",
147
+ "metadata": {},
148
+ "execution_count": null,
149
+ "outputs": []
150
+ },
151
+ {
152
+ "cell_type": "markdown",
153
+ "id": "812csd43vxk",
154
+ "source": "## 10. Push to HuggingFace Hub (optional)\n\nUpload the trained LoRA adapter to HF for deployment or further fine-tuning.",
155
+ "metadata": {}
156
+ },
157
+ {
158
+ "cell_type": "code",
159
+ "id": "h38kp70n16q",
160
+ "source": "# ── Push to HF Hub (uncomment and set your repo) ───────────────────────────\n# from huggingface_hub import login\n# login(token=\"hf_...\") # or use HF_TOKEN env var\n#\n# HF_REPO = \"your-username/optigami-grpo-spatialthinker\"\n# model.push_to_hub(HF_REPO)\n# tokenizer.push_to_hub(HF_REPO)\n# print(f\"Model pushed to https://huggingface.co/{HF_REPO}\")\n\ntrackio.finish()\nprint(\"Done! Check your Trackio dashboard for training curves.\")",
161
+ "metadata": {},
162
+ "execution_count": null,
163
+ "outputs": []
164
+ }
165
+ ],
166
+ "metadata": {
167
+ "kernelspec": {
168
+ "display_name": "Python 3",
169
+ "language": "python",
170
+ "name": "python3"
171
+ },
172
+ "language_info": {
173
+ "name": "python",
174
+ "version": "3.10.0"
175
+ },
176
+ "colab": {
177
+ "provenance": [],
178
+ "gpuType": "A100"
179
+ },
180
+ "accelerator": "GPU"
181
+ },
182
+ "nbformat": 4,
183
+ "nbformat_minor": 5
184
+ }
trainer/mock_env.py CHANGED
@@ -135,6 +135,22 @@ def apply_fold_mock(state: PaperState, fold: dict) -> tuple[PaperState, str | No
135
  if fold_type not in ("valley", "mountain"):
136
  return state, f"Unknown fold type: {fold_type}"
137
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  if not (0 < angle_deg <= 180):
139
  return state, f"Angle must be in (0, 180], got {angle_deg}"
140
 
 
135
  if fold_type not in ("valley", "mountain"):
136
  return state, f"Unknown fold type: {fold_type}"
137
 
138
+ # angle=0 means "no fold" — return unchanged copy
139
+ if angle_deg == 0:
140
+ return PaperState(
141
+ vertices=state.vertices.copy(), edges=state.edges.copy(),
142
+ faces=[f[:] for f in state.faces],
143
+ assignments=state.assignments[:], fold_angles=state.fold_angles.copy(),
144
+ rest_lengths=state.rest_lengths.copy(), strain=state.strain.copy(),
145
+ energy=state.energy, face_orders=state.face_orders[:],
146
+ num_layers=state.num_layers, material=state.material,
147
+ bounding_box=state.bounding_box.copy(),
148
+ deployment_ratio=state.deployment_ratio, is_valid=state.is_valid,
149
+ kawasaki_violation=state.kawasaki_violation,
150
+ maekawa_violation=state.maekawa_violation,
151
+ self_intersections=state.self_intersections,
152
+ ), None
153
+
154
  if not (0 < angle_deg <= 180):
155
  return state, f"Angle must be in (0, 180], got {angle_deg}"
156
 
trainer/prompts.py CHANGED
@@ -1,49 +1,99 @@
1
  """
2
  Prompt templates for origami fold strategy generation.
3
 
4
- The LLM receives a task description and paper state, then generates
5
- a fold_strategy(paper_state) function that returns fold operations.
 
 
 
 
 
 
 
 
6
  """
7
 
 
 
 
 
8
  SYSTEM_PROMPT = """\
9
- You are an origami engineer. You design fold patterns for real-world applications \
10
- like solar panel packing, deployable shelters, and medical stents.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
- You will be given a folding task with material constraints. Write a Python function \
13
- `fold_strategy(paper_state)` that returns a list of fold operations to achieve the goal.
 
 
 
 
14
 
15
  Rules:
16
  - Only use native Python (no imports except math, itertools, functools)
17
  - Each fold: {"type": "valley"|"mountain", "line": {"start": [x,y], "end": [x,y]}, "angle": 0-180}
18
- - Fold lines must intersect the paper boundaries
 
 
 
19
  - Fewer folds is better (efficiency matters)
20
- - Respect material strain limits
21
- - Output ONLY the function in ```python ... ``` backticks\
22
  """
23
 
24
 
 
 
 
 
25
  TASK_TEMPLATES = {
26
  "half_fold": {
27
  "name": "half_fold",
28
  "prompt": """\
29
  TASK: Fold a {width}m x {height}m {material} sheet in half to minimize one dimension.
30
 
 
 
 
 
 
 
 
31
  MATERIAL: {material} (thickness: {thickness_mm}mm, max strain: {max_strain_pct}%)
32
  CONSTRAINTS: Maximum {max_folds} fold operations.
33
- TARGET: Deployment ratio <= 0.5 (folded area is half or less of original)
34
-
35
- CURRENT STATE:
36
- Sheet: {width}m x {height}m, flat (0 folds applied)
37
- Bounding box: {width}m x {height}m x 0.0m
38
-
39
- Write a fold_strategy(paper_state) function that returns a list of fold operations.
40
- Each fold: {{"type": "valley"|"mountain", "line": {{"start": [x,y], "end": [x,y]}}, "angle": 0-180}}
41
-
42
- ```python
43
- def fold_strategy(paper_state):
44
- # Your code here
45
- return [...]
46
- ```""",
47
  "target_ratio": 0.5,
48
  "max_folds": 3,
49
  },
@@ -53,21 +103,14 @@ def fold_strategy(paper_state):
53
  "prompt": """\
54
  TASK: Fold a {width}m x {height}m {material} sheet into thirds (like a letter).
55
 
 
 
 
 
 
56
  MATERIAL: {material} (thickness: {thickness_mm}mm, max strain: {max_strain_pct}%)
57
  CONSTRAINTS: Maximum {max_folds} fold operations.
58
- TARGET: Deployment ratio <= 0.33
59
-
60
- CURRENT STATE:
61
- Sheet: {width}m x {height}m, flat (0 folds applied)
62
-
63
- Write a fold_strategy(paper_state) function that returns a list of fold operations.
64
- Each fold: {{"type": "valley"|"mountain", "line": {{"start": [x,y], "end": [x,y]}}, "angle": 0-180}}
65
-
66
- ```python
67
- def fold_strategy(paper_state):
68
- # Your code here
69
- return [...]
70
- ```""",
71
  "target_ratio": 0.33,
72
  "max_folds": 5,
73
  },
@@ -78,30 +121,22 @@ def fold_strategy(paper_state):
78
  TASK: Fold a {width}m x {height}m Mylar sheet to minimize packed volume for a solar panel.
79
  The folded panel must be deployable (unfold cleanly to near-original area).
80
 
 
 
 
 
 
81
  MATERIAL: Mylar (thickness: 0.05mm, Young's modulus: 4 GPa, max strain: 3%)
82
  CONSTRAINTS:
83
  - Maximum {max_folds} fold operations
84
  - Must pack into bounding box <= 15cm x 15cm x 5cm
85
- - Must deploy to >= 80% of original area
86
  - No self-intersections
87
 
88
- TARGET: Deployment ratio <= 0.05 (95% volume reduction)
89
-
90
- CURRENT STATE:
91
- Sheet: {width}m x {height}m, flat (0 folds applied)
92
- Bounding box: {width}m x {height}m x 0.0m
93
 
94
- HINT: Consider tessellated patterns like Miura-ori alternating mountain and valley
95
- folds in a grid create a highly compact, single-DOF deployable structure.
96
-
97
- Write a fold_strategy(paper_state) function that returns a list of fold operations.
98
- Each fold: {{"type": "valley"|"mountain", "line": {{"start": [x,y], "end": [x,y]}}, "angle": 0-180}}
99
-
100
- ```python
101
- def fold_strategy(paper_state):
102
- # Your code here
103
- return [...]
104
- ```""",
105
  "target_ratio": 0.05,
106
  "max_folds": 20,
107
  },
@@ -111,29 +146,26 @@ def fold_strategy(paper_state):
111
  "prompt": """\
112
  TASK: Fold a {width}m x {height}m Nitinol sheet into a compact cylinder for a medical stent.
113
 
 
 
 
 
114
  MATERIAL: Nitinol (thickness: 0.1mm, Young's modulus: 75 GPa, max strain: 8%)
115
  CONSTRAINTS:
116
  - Maximum {max_folds} fold operations
117
- - Compressed diameter: 3mm
118
- - Deployed diameter: 10mm
119
- - Must be radially deployable
120
-
121
- TARGET: Minimize packed cross-section while maintaining deployability.
122
-
123
- Write a fold_strategy(paper_state) function that returns a list of fold operations.
124
 
125
- ```python
126
- def fold_strategy(paper_state):
127
- # Your code here
128
- return [...]
129
- ```""",
130
  "target_ratio": 0.1,
131
  "max_folds": 15,
132
  },
133
  }
134
 
135
 
136
- # Default task configs for each level
 
 
 
137
  TASK_CONFIGS = {
138
  "half_fold": {
139
  "width": 1.0, "height": 1.0, "material": "paper",
@@ -158,6 +190,16 @@ def build_prompt(task_name: str = "half_fold", **overrides) -> str:
158
  """Build a complete user prompt for a given task."""
159
  task = TASK_TEMPLATES[task_name]
160
  config = {**TASK_CONFIGS[task_name], **overrides}
 
 
 
 
 
 
 
 
 
 
161
  return task["prompt"].format(**config)
162
 
163
 
 
1
  """
2
  Prompt templates for origami fold strategy generation.
3
 
4
+ Inspired by SpatialThinker (arXiv 2511.07403): the model must produce
5
+ a structured spatial representation BEFORE generating code.
6
+
7
+ Output format (4 stages):
8
+ <observe> — Describe the paper geometry and constraints
9
+ <plan> — Structured fold plan with coordinates and reasoning
10
+ <code> — The fold_strategy() function
11
+ <verify> — Predict expected outcome (deployment ratio, fold count)
12
+
13
+ Dense rewards check each stage independently, not just code execution.
14
  """
15
 
16
+ # ---------------------------------------------------------------------------
17
+ # System prompt — defines the structured output format
18
+ # ---------------------------------------------------------------------------
19
+
20
  SYSTEM_PROMPT = """\
21
+ You are an origami engineer specializing in computational fold design.
22
+ You solve folding tasks by reasoning spatially about paper geometry.
23
+
24
+ You MUST respond in exactly this 4-stage format:
25
+
26
+ <observe>
27
+ Describe the paper: dimensions, material, coordinate system.
28
+ Identify key geometric features (center, edges, diagonals, symmetry axes).
29
+ Note constraints (max strain, max folds, target ratio).
30
+ </observe>
31
+
32
+ <plan>
33
+ {
34
+ "strategy": "description of overall approach",
35
+ "folds": [
36
+ {
37
+ "description": "what this fold does",
38
+ "type": "valley or mountain",
39
+ "line_start": [x, y],
40
+ "line_end": [x, y],
41
+ "angle": 180,
42
+ "reasoning": "why these coordinates"
43
+ }
44
+ ],
45
+ "expected_ratio": 0.5,
46
+ "expected_folds": 1
47
+ }
48
+ </plan>
49
+
50
+ <code>
51
+ ```python
52
+ def fold_strategy(paper_state):
53
+ # Implementation matching the plan above
54
+ return [...]
55
+ ```
56
+ </code>
57
 
58
+ <verify>
59
+ Expected deployment ratio: X.XX
60
+ Expected fold count: N
61
+ Expected max strain: X.XXXX
62
+ Potential issues: ...
63
+ </verify>
64
 
65
  Rules:
66
  - Only use native Python (no imports except math, itertools, functools)
67
  - Each fold: {"type": "valley"|"mountain", "line": {"start": [x,y], "end": [x,y]}, "angle": 0-180}
68
+ - Fold lines must cross the paper boundary (intersect at least 2 edges)
69
+ - Valley = fold toward you (+Z), Mountain = fold away (-Z)
70
+ - angle=180 = fully folded, smaller = partial fold
71
+ - Each fold changes the geometry — later folds operate on already-folded paper
72
  - Fewer folds is better (efficiency matters)
73
+ - Respect material strain limits\
 
74
  """
75
 
76
 
77
+ # ---------------------------------------------------------------------------
78
+ # Task templates — each includes spatial context
79
+ # ---------------------------------------------------------------------------
80
+
81
  TASK_TEMPLATES = {
82
  "half_fold": {
83
  "name": "half_fold",
84
  "prompt": """\
85
  TASK: Fold a {width}m x {height}m {material} sheet in half to minimize one dimension.
86
 
87
+ PAPER GEOMETRY:
88
+ Corners: (0,0), ({width},0), ({width},{height}), (0,{height})
89
+ Center: ({cx},{cy})
90
+ Horizontal midline: y={cy} from (0,{cy}) to ({width},{cy})
91
+ Vertical midline: x={cx} from ({cx},0) to ({cx},{height})
92
+ Diagonals: (0,0)→({width},{height}) and ({width},0)→(0,{height})
93
+
94
  MATERIAL: {material} (thickness: {thickness_mm}mm, max strain: {max_strain_pct}%)
95
  CONSTRAINTS: Maximum {max_folds} fold operations.
96
+ TARGET: Deployment ratio <= 0.5""",
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  "target_ratio": 0.5,
98
  "max_folds": 3,
99
  },
 
103
  "prompt": """\
104
  TASK: Fold a {width}m x {height}m {material} sheet into thirds (like a letter).
105
 
106
+ PAPER GEOMETRY:
107
+ Corners: (0,0), ({width},0), ({width},{height}), (0,{height})
108
+ Third lines: y={t1:.4f} and y={t2:.4f}
109
+ Center: ({cx},{cy})
110
+
111
  MATERIAL: {material} (thickness: {thickness_mm}mm, max strain: {max_strain_pct}%)
112
  CONSTRAINTS: Maximum {max_folds} fold operations.
113
+ TARGET: Deployment ratio <= 0.33""",
 
 
 
 
 
 
 
 
 
 
 
 
114
  "target_ratio": 0.33,
115
  "max_folds": 5,
116
  },
 
121
  TASK: Fold a {width}m x {height}m Mylar sheet to minimize packed volume for a solar panel.
122
  The folded panel must be deployable (unfold cleanly to near-original area).
123
 
124
+ PAPER GEOMETRY:
125
+ Corners: (0,0), ({width},0), ({width},{height}), (0,{height})
126
+ Center: ({cx},{cy})
127
+ Area: {area}m²
128
+
129
  MATERIAL: Mylar (thickness: 0.05mm, Young's modulus: 4 GPa, max strain: 3%)
130
  CONSTRAINTS:
131
  - Maximum {max_folds} fold operations
132
  - Must pack into bounding box <= 15cm x 15cm x 5cm
 
133
  - No self-intersections
134
 
135
+ TARGET: Deployment ratio <= 0.05 (95% area reduction)
 
 
 
 
136
 
137
+ HINT: Tessellated patterns (alternating M/V folds in a grid) achieve high
138
+ compaction with single-DOF deployment. Consider dividing the sheet into
139
+ a regular grid of panels.""",
 
 
 
 
 
 
 
 
140
  "target_ratio": 0.05,
141
  "max_folds": 20,
142
  },
 
146
  "prompt": """\
147
  TASK: Fold a {width}m x {height}m Nitinol sheet into a compact cylinder for a medical stent.
148
 
149
+ PAPER GEOMETRY:
150
+ Corners: (0,0), ({width},0), ({width},{height}), (0,{height})
151
+ Center: ({cx},{cy})
152
+
153
  MATERIAL: Nitinol (thickness: 0.1mm, Young's modulus: 75 GPa, max strain: 8%)
154
  CONSTRAINTS:
155
  - Maximum {max_folds} fold operations
156
+ - Compressed diameter: 3mm, Deployed diameter: 10mm
 
 
 
 
 
 
157
 
158
+ TARGET: Deployment ratio <= 0.1""",
 
 
 
 
159
  "target_ratio": 0.1,
160
  "max_folds": 15,
161
  },
162
  }
163
 
164
 
165
+ # ---------------------------------------------------------------------------
166
+ # Config and builders
167
+ # ---------------------------------------------------------------------------
168
+
169
  TASK_CONFIGS = {
170
  "half_fold": {
171
  "width": 1.0, "height": 1.0, "material": "paper",
 
190
  """Build a complete user prompt for a given task."""
191
  task = TASK_TEMPLATES[task_name]
192
  config = {**TASK_CONFIGS[task_name], **overrides}
193
+
194
+ # Add computed geometry values
195
+ w = config["width"]
196
+ h = config["height"]
197
+ config["cx"] = w / 2
198
+ config["cy"] = h / 2
199
+ config["area"] = w * h
200
+ config["t1"] = h / 3
201
+ config["t2"] = 2 * h / 3
202
+
203
  return task["prompt"].format(**config)
204
 
205
 
trainer/rewards.py CHANGED
@@ -1,17 +1,22 @@
1
  """
2
  Reward functions for origami GRPO training.
3
 
4
- Three reward functions following the 2048 pattern:
5
- 1. code_valid — Does the generated code parse and produce fold instructions?
6
- 2. physically_valid — Are the folds geometrically/physically valid?
7
- 3. fold_quality — How good is the folding solution (compactness, efficiency)?
 
8
 
9
- Lexicographic gating (from SpatialThinker): if code doesn't parse,
10
- all downstream rewards are 0. This prevents reward hacking.
 
 
11
  """
12
 
13
  import ast
 
14
  import sys
 
15
  import math
16
  import traceback
17
  from typing import Callable
@@ -60,23 +65,57 @@ except ImportError:
60
  # ---------------------------------------------------------------------------
61
 
62
  def extract_function(text: str) -> str | None:
63
- """Extract a Python function from triple-backtick code blocks."""
64
- if text.count("```") < 2:
 
 
 
 
 
 
 
 
65
  return None
66
- first = text.find("```") + 3
67
- second = text.find("```", first)
68
- fx = text[first:second].strip()
69
- fx = fx.removeprefix("python\n").removeprefix("python\r\n")
 
70
  # Find the def statement
71
- def_idx = fx.find("def ")
72
  if def_idx == -1:
73
  return None
74
- fx = fx[def_idx:]
75
  if fx.startswith("def fold_strategy("):
76
  return fx
77
  return None
78
 
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  def check_imports_stdlib_only(code: str) -> tuple[bool, str]:
81
  """Check that code only imports from Python stdlib."""
82
  try:
@@ -386,3 +425,289 @@ def fold_quality(completions, **kwargs) -> list[float]:
386
  scores.append(-3.0)
387
 
388
  return scores
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
  Reward functions for origami GRPO training.
3
 
4
+ SpatialThinker-style dense rewards (arXiv 2511.07403):
5
+ 1. format_reward (0.10) — All 4 tags present, valid JSON plan, valid function
6
+ 2. spatial_reward (0.20) — Fold coordinates in plan are within bounds, lines valid
7
+ 3. execution_reward (0.50) — Physical validity + fold quality (code execution)
8
+ 4. consistency_reward (0.20) — Plan matches code, verify matches actual results
9
 
10
+ Plus legacy rewards for backwards compatibility:
11
+ - code_valid, physically_valid, fold_quality
12
+
13
+ Lexicographic gating: if code doesn't parse, downstream rewards are 0.
14
  """
15
 
16
  import ast
17
+ import re
18
  import sys
19
+ import json
20
  import math
21
  import traceback
22
  from typing import Callable
 
65
  # ---------------------------------------------------------------------------
66
 
67
  def extract_function(text: str) -> str | None:
68
+ """Extract fold_strategy() from <code> blocks or triple-backtick code blocks."""
69
+ # Try <code> block first (SpatialThinker format)
70
+ code_match = re.search(r'<code>(.*?)</code>', text, re.DOTALL)
71
+ if code_match:
72
+ code_block = code_match.group(1).strip()
73
+ elif text.count("```") >= 2:
74
+ first = text.find("```") + 3
75
+ second = text.find("```", first)
76
+ code_block = text[first:second].strip()
77
+ else:
78
  return None
79
+
80
+ code_block = code_block.removeprefix("```python\n").removeprefix("```python\r\n")
81
+ code_block = code_block.removeprefix("python\n").removeprefix("python\r\n")
82
+ code_block = code_block.rstrip("`").strip()
83
+
84
  # Find the def statement
85
+ def_idx = code_block.find("def ")
86
  if def_idx == -1:
87
  return None
88
+ fx = code_block[def_idx:]
89
  if fx.startswith("def fold_strategy("):
90
  return fx
91
  return None
92
 
93
 
94
+ def extract_section(text: str, tag: str) -> str | None:
95
+ """Extract content between <tag>...</tag>."""
96
+ match = re.search(rf'<{tag}>(.*?)</{tag}>', text, re.DOTALL)
97
+ return match.group(1).strip() if match else None
98
+
99
+
100
+ def extract_plan_json(text: str) -> dict | None:
101
+ """Extract and parse the JSON fold plan from <plan> block."""
102
+ plan_text = extract_section(text, "plan")
103
+ if not plan_text:
104
+ return None
105
+ try:
106
+ return json.loads(plan_text)
107
+ except json.JSONDecodeError:
108
+ # Try to find JSON object within the plan text
109
+ brace_start = plan_text.find("{")
110
+ brace_end = plan_text.rfind("}")
111
+ if brace_start >= 0 and brace_end > brace_start:
112
+ try:
113
+ return json.loads(plan_text[brace_start:brace_end + 1])
114
+ except json.JSONDecodeError:
115
+ pass
116
+ return None
117
+
118
+
119
  def check_imports_stdlib_only(code: str) -> tuple[bool, str]:
120
  """Check that code only imports from Python stdlib."""
121
  try:
 
425
  scores.append(-3.0)
426
 
427
  return scores
428
+
429
+
430
+ # ---------------------------------------------------------------------------
431
+ # SpatialThinker Dense Rewards (weight 0.10 + 0.20 + 0.50 + 0.20 = 1.0)
432
+ # ---------------------------------------------------------------------------
433
+
434
+ REQUIRED_TAGS = ["observe", "plan", "code", "verify"]
435
+
436
+
437
+ def format_reward(completions, **kwargs) -> list[float]:
438
+ """
439
+ SpatialThinker format reward (weight: 0.10).
440
+
441
+ Checks that the response has all 4 structured tags, valid JSON in <plan>,
442
+ and a parseable function in <code>.
443
+
444
+ Score range: [0.0, 1.0]
445
+ """
446
+ scores = []
447
+ for completion in completions:
448
+ response = completion[0]["content"]
449
+ score = 0.0
450
+
451
+ # Check each required tag (0.15 each = 0.60 for all 4)
452
+ tags_present = 0
453
+ for tag in REQUIRED_TAGS:
454
+ if extract_section(response, tag) is not None:
455
+ tags_present += 1
456
+ score += 0.15 * tags_present
457
+
458
+ # Valid JSON in <plan> (0.20)
459
+ plan = extract_plan_json(response)
460
+ if plan is not None:
461
+ score += 0.20
462
+ # Plan has required fields (0.05 bonus)
463
+ if "folds" in plan and isinstance(plan["folds"], list):
464
+ score += 0.05
465
+
466
+ # Valid function in <code> (0.15)
467
+ fn = extract_function(response)
468
+ if fn is not None:
469
+ score += 0.15
470
+
471
+ scores.append(score)
472
+ return scores
473
+
474
+
475
+ def spatial_reward(completions, **kwargs) -> list[float]:
476
+ """
477
+ SpatialThinker spatial plan quality reward (weight: 0.20).
478
+
479
+ Checks that fold coordinates in <plan> are geometrically valid:
480
+ - Within paper bounds
481
+ - Line endpoints form valid fold lines (cross the paper)
482
+ - Fold types are valid
483
+ - Expected ratio/count are reasonable
484
+
485
+ Score range: [0.0, 1.0]
486
+ """
487
+ w = _current_task["width"]
488
+ h = _current_task["height"]
489
+
490
+ scores = []
491
+ for completion in completions:
492
+ response = completion[0]["content"]
493
+ plan = extract_plan_json(response)
494
+
495
+ if plan is None:
496
+ scores.append(0.0)
497
+ continue
498
+
499
+ score = 0.0
500
+ folds = plan.get("folds", [])
501
+
502
+ if not folds:
503
+ scores.append(0.0)
504
+ continue
505
+
506
+ # Score each fold in the plan
507
+ valid_folds = 0
508
+ for fold in folds:
509
+ fold_score = 0.0
510
+
511
+ # Has required fields
512
+ has_type = fold.get("type") in ("valley", "mountain")
513
+ has_start = isinstance(fold.get("line_start"), list) and len(fold.get("line_start", [])) == 2
514
+ has_end = isinstance(fold.get("line_end"), list) and len(fold.get("line_end", [])) == 2
515
+
516
+ if has_type:
517
+ fold_score += 0.25
518
+ if has_start and has_end:
519
+ fold_score += 0.25
520
+ # Coordinates within paper bounds (with small tolerance)
521
+ sx, sy = fold["line_start"]
522
+ ex, ey = fold["line_end"]
523
+ tol = 0.01
524
+ in_bounds = (
525
+ -tol <= sx <= w + tol and -tol <= sy <= h + tol and
526
+ -tol <= ex <= w + tol and -tol <= ey <= h + tol
527
+ )
528
+ if in_bounds:
529
+ fold_score += 0.25
530
+
531
+ # Start != end (not a degenerate line)
532
+ dist = math.sqrt((ex - sx)**2 + (ey - sy)**2)
533
+ if dist > 0.01:
534
+ fold_score += 0.25
535
+
536
+ if fold_score > 0.5:
537
+ valid_folds += 1
538
+
539
+ # Proportion of valid folds
540
+ score = valid_folds / len(folds) if folds else 0.0
541
+
542
+ # Bonus: expected_ratio is reasonable (0.0 to 1.0)
543
+ expected = plan.get("expected_ratio")
544
+ if isinstance(expected, (int, float)) and 0.0 < expected <= 1.0:
545
+ score = min(1.0, score + 0.1)
546
+
547
+ scores.append(min(1.0, score))
548
+ return scores
549
+
550
+
551
+ def execution_reward(completions, **kwargs) -> list[float]:
552
+ """
553
+ SpatialThinker execution/accuracy reward (weight: 0.50).
554
+
555
+ Combines code validity, physical validity, and fold quality into
556
+ one normalized score. This is the main reward signal.
557
+
558
+ Score range: [0.0, 1.0]
559
+ """
560
+ scores = []
561
+ for completion in completions:
562
+ response = completion[0]["content"]
563
+ function_code = extract_function(response)
564
+
565
+ # Gate: no function → 0
566
+ if function_code is None:
567
+ scores.append(0.0)
568
+ continue
569
+
570
+ ok, info = check_imports_stdlib_only(function_code)
571
+ if not ok:
572
+ scores.append(0.0)
573
+ continue
574
+
575
+ try:
576
+ strategy_fn = create_sandboxed_function(function_code)
577
+ except Exception:
578
+ scores.append(0.0)
579
+ continue
580
+
581
+ try:
582
+ paper = _create_sheet(
583
+ _current_task["width"],
584
+ _current_task["height"],
585
+ _current_task["material"],
586
+ )
587
+ original = paper
588
+ final_state, applied, error = execute_fold_strategy(
589
+ strategy_fn, paper, _current_task["max_folds"]
590
+ )
591
+
592
+ if error or len(applied) == 0:
593
+ scores.append(0.0)
594
+ continue
595
+
596
+ val = validate_paper(final_state)
597
+ metrics = compute_metrics(final_state, original)
598
+ deploy_ratio = metrics.get("deployment_ratio", 1.0)
599
+ max_strain = metrics.get("max_strain", 0.0)
600
+
601
+ # Physical validity component (0-0.3)
602
+ phys = 0.3
603
+ if not val.is_valid:
604
+ phys -= 0.1 * val.kawasaki_violation
605
+ phys -= 0.1 * val.maekawa_violation
606
+ if val.self_intersection_count > 0:
607
+ phys -= 0.15
608
+ mat_limit = _current_task["material"].max_strain
609
+ if max_strain > mat_limit:
610
+ phys -= 0.05
611
+ phys = max(0.0, phys)
612
+
613
+ # Quality component (0-0.5)
614
+ compactness = 1.0 - deploy_ratio
615
+ quality = 0.5 * compactness
616
+
617
+ # Target bonus (0-0.2)
618
+ target = 0.0
619
+ if deploy_ratio <= _current_task["target_ratio"]:
620
+ target = 0.2
621
+
622
+ score = phys + quality + target
623
+ scores.append(min(1.0, score))
624
+
625
+ except Exception:
626
+ scores.append(0.0)
627
+
628
+ return scores
629
+
630
+
631
+ def consistency_reward(completions, **kwargs) -> list[float]:
632
+ """
633
+ SpatialThinker consistency reward (weight: 0.20).
634
+
635
+ Checks that <plan> matches <code> and <verify> matches actual results.
636
+ - Plan fold count matches code fold count
637
+ - Verify predictions close to actual metrics
638
+
639
+ Score range: [0.0, 1.0]
640
+ """
641
+ scores = []
642
+ for completion in completions:
643
+ response = completion[0]["content"]
644
+ plan = extract_plan_json(response)
645
+ verify = extract_section(response, "verify")
646
+ function_code = extract_function(response)
647
+
648
+ # Need at least plan + code to check consistency
649
+ if plan is None or function_code is None:
650
+ scores.append(0.0)
651
+ continue
652
+
653
+ score = 0.0
654
+
655
+ # 1. Plan fold count vs code fold count (0.4)
656
+ plan_folds = plan.get("folds", [])
657
+ plan_count = len(plan_folds)
658
+
659
+ try:
660
+ strategy_fn = create_sandboxed_function(function_code)
661
+ paper = _create_sheet(
662
+ _current_task["width"],
663
+ _current_task["height"],
664
+ _current_task["material"],
665
+ )
666
+ original = paper
667
+ final_state, applied, error = execute_fold_strategy(
668
+ strategy_fn, paper, _current_task["max_folds"]
669
+ )
670
+ if error or len(applied) == 0:
671
+ scores.append(0.0)
672
+ continue
673
+
674
+ actual_count = len(applied)
675
+ if plan_count == actual_count:
676
+ score += 0.4
677
+ elif abs(plan_count - actual_count) <= 1:
678
+ score += 0.2
679
+
680
+ # 2. Verify predictions vs actual (0.6)
681
+ if verify:
682
+ metrics = compute_metrics(final_state, original)
683
+ actual_ratio = metrics.get("deployment_ratio", 1.0)
684
+
685
+ # Extract predicted ratio from verify text
686
+ ratio_match = re.search(
687
+ r'deployment\s*ratio[:\s]*([\d.]+)', verify, re.IGNORECASE)
688
+ if ratio_match:
689
+ predicted_ratio = float(ratio_match.group(1))
690
+ error_pct = abs(predicted_ratio - actual_ratio)
691
+ if error_pct < 0.05:
692
+ score += 0.4
693
+ elif error_pct < 0.15:
694
+ score += 0.2
695
+ elif error_pct < 0.3:
696
+ score += 0.1
697
+
698
+ # Extract predicted fold count
699
+ count_match = re.search(
700
+ r'fold\s*count[:\s]*(\d+)', verify, re.IGNORECASE)
701
+ if count_match:
702
+ predicted_count = int(count_match.group(1))
703
+ if predicted_count == actual_count:
704
+ score += 0.2
705
+ elif abs(predicted_count - actual_count) <= 1:
706
+ score += 0.1
707
+
708
+ except Exception:
709
+ scores.append(0.0)
710
+ continue
711
+
712
+ scores.append(min(1.0, score))
713
+ return scores
trainer/train.py CHANGED
@@ -19,7 +19,10 @@ if PROJECT_ROOT not in sys.path:
19
  sys.path.insert(0, PROJECT_ROOT)
20
 
21
  from trainer.prompts import build_prompt, SYSTEM_PROMPT, get_task_target_ratio, get_task_max_folds
22
- from trainer.rewards import code_valid, physically_valid, fold_quality, set_task_config
 
 
 
23
 
24
  try:
25
  from engine.materials import get_material
@@ -167,14 +170,18 @@ def main():
167
  # ========================================================================
168
  # 6. Create trainer and start training
169
  # ========================================================================
 
 
170
  trainer = GRPOTrainer(
171
  model=model,
172
  processing_class=tokenizer,
173
  reward_funcs=[
174
- code_valid, # Reward 1: valid Python?
175
- physically_valid, # Reward 2: physically possible folds?
176
- fold_quality, # Reward 3: how good is the solution?
 
177
  ],
 
178
  args=training_args,
179
  train_dataset=dataset,
180
  )
 
19
  sys.path.insert(0, PROJECT_ROOT)
20
 
21
  from trainer.prompts import build_prompt, SYSTEM_PROMPT, get_task_target_ratio, get_task_max_folds
22
+ from trainer.rewards import (
23
+ code_valid, physically_valid, fold_quality, set_task_config,
24
+ format_reward, spatial_reward, execution_reward, consistency_reward,
25
+ )
26
 
27
  try:
28
  from engine.materials import get_material
 
170
  # ========================================================================
171
  # 6. Create trainer and start training
172
  # ========================================================================
173
+ # SpatialThinker dense rewards (weighted: 0.10 + 0.20 + 0.50 + 0.20)
174
+ # These replace the legacy 3-reward setup with structured spatial reasoning
175
  trainer = GRPOTrainer(
176
  model=model,
177
  processing_class=tokenizer,
178
  reward_funcs=[
179
+ format_reward, # 0.10 — 4-stage format compliance
180
+ spatial_reward, # 0.20 — fold plan geometric validity
181
+ execution_reward, # 0.50 — code execution + physical quality
182
+ consistency_reward, # 0.20 — plan↔code↔verify agreement
183
  ],
184
+ reward_weights=[0.10, 0.20, 0.50, 0.20],
185
  args=training_args,
186
  train_dataset=dataset,
187
  )