prasanna287 committed on
Commit
b9a4a95
·
1 Parent(s): 7fccd0c

Add MAX_CONCURRENT_ENVS, sync latest changes

Browse files
Dockerfile CHANGED
@@ -9,6 +9,8 @@ RUN pip install --no-cache-dir -r requirements.txt \
9
 
10
  COPY . /app
11
 
 
 
12
  EXPOSE 8000
13
 
14
  CMD ["uvicorn", "origami_server.app:app", "--host", "0.0.0.0", "--port", "8000"]
 
9
 
10
  COPY . /app
11
 
12
+ ENV MAX_CONCURRENT_ENVS=16
13
+
14
  EXPOSE 8000
15
 
16
  CMD ["uvicorn", "origami_server.app:app", "--host", "0.0.0.0", "--port", "8000"]
origami_server/app.py CHANGED
@@ -16,6 +16,7 @@ app = create_app(
16
  OrigamiAction,
17
  OrigamiObservation,
18
  env_name="origami_env",
 
19
  )
20
 
21
  from .tasks import TASKS
 
16
  OrigamiAction,
17
  OrigamiObservation,
18
  env_name="origami_env",
19
+ max_concurrent_envs=int(os.environ.get("MAX_CONCURRENT_ENVS", 1)),
20
  )
21
 
22
  from .tasks import TASKS
tests/test_origami.py CHANGED
@@ -9,7 +9,7 @@ from origami_server.engine.simulate import simulate
9
  from origami_server.environment import OrigamiEnvironment
10
  from origami_server.models import OrigamiAction
11
  from origami_server.tasks import TASKS, get_task, list_tasks
12
- from training.reward import extract_fold_json, shape_match, valid_fold
13
 
14
  # --- Fixtures ---
15
 
@@ -221,14 +221,20 @@ class TestRewards:
221
  assert scores[0] == 1.0
222
  assert scores[1] == -2.0
223
 
224
- def test_shape_match_reward(self):
225
- import json
 
 
226
 
227
- good = [[{"content": json.dumps(TRIANGLE_FOLD)}]]
228
- bad = [[{"content": "nope"}]]
229
- scores = shape_match(good + bad, task_name="triangle")
230
- assert scores[0] == 20.0
231
- assert scores[1] == -2.0
 
 
 
 
232
 
233
 
234
  # --- API ---
 
9
  from origami_server.environment import OrigamiEnvironment
10
  from origami_server.models import OrigamiAction
11
  from origami_server.tasks import TASKS, get_task, list_tasks
12
+ from training.reward import extract_fold_json, valid_fold
13
 
14
  # --- Fixtures ---
15
 
 
221
  assert scores[0] == 1.0
222
  assert scores[1] == -2.0
223
 
224
+ def test_shape_match_via_server(self):
225
+ """shape_match reward now goes through the server (WebSocket).
226
+ Test the same flow via TestClient's websocket to verify end-to-end."""
227
+ from fastapi.testclient import TestClient
228
 
229
+ from origami_server.app import app
230
+
231
+ client = TestClient(app)
232
+ with client.websocket_connect("/ws") as ws:
233
+ ws.send_json({"type": "reset", "data": {"task_name": "triangle"}})
234
+ ws.receive_json()
235
+ ws.send_json({"type": "step", "data": {"fold_data": TRIANGLE_FOLD}})
236
+ resp = ws.receive_json()
237
+ assert resp["data"]["reward"] == 20.0
238
 
239
 
240
  # --- API ---
training/reward.py CHANGED
@@ -1,21 +1,17 @@
1
  """GRPO reward functions for origami RL training.
2
 
3
- Two reward functions (matching the 2048 pattern):
4
- 1. valid_fold: Does the LLM output parse as valid FOLD JSON?
5
- 2. shape_match: Simulate and compare to target shape.
 
 
 
6
  """
7
 
8
  import json
9
  import re
10
  from typing import Any
11
 
12
- import numpy as np
13
-
14
- from origami_server.engine.fold_parser import validate_fold
15
- from origami_server.engine.shape_match import compute_shape_match
16
- from origami_server.engine.simulate import simulate
17
- from origami_server.tasks import get_task
18
-
19
 
20
  def extract_fold_json(response: str) -> dict | None:
21
  """Extract FOLD JSON from LLM response text.
@@ -55,6 +51,8 @@ def valid_fold(completions: list, **kwargs: Any) -> list[float]:
55
  +1.0 valid FOLD JSON with correct structure
56
  -0.5 parseable JSON but invalid FOLD structure
57
  -2.0 not parseable as JSON at all
 
 
58
  """
59
  scores = []
60
  for completion in completions:
@@ -65,58 +63,35 @@ def valid_fold(completions: list, **kwargs: Any) -> list[float]:
65
  scores.append(-2.0)
66
  continue
67
 
68
- is_valid, error = validate_fold(fold_data)
69
- if is_valid:
70
- scores.append(1.0)
71
- else:
72
  scores.append(-0.5)
 
73
 
74
- return scores
75
-
76
-
77
- def shape_match(
78
- completions: list,
79
- task_name: str = "triangle",
80
- **kwargs: Any,
81
- ) -> list[float]:
82
- """Reward 2: Simulate the fold and compare to target shape.
83
-
84
- Score = similarity × 20.0 (range: 0 to 20)
85
- -1.0 if simulation fails/diverges
86
- -2.0 if FOLD data is invalid
87
-
88
- This is the main reward signal — AlphaFold-style shape comparison.
89
- """
90
- task = get_task(task_name)
91
- target_fold = task["target_fold"]
92
-
93
- # Pre-compute target positions
94
- try:
95
- target_result = simulate(target_fold, crease_percent=1.0)
96
- target_positions = target_result.positions
97
- except Exception:
98
- # Target itself fails — all scores 0
99
- return [0.0] * len(completions)
100
 
101
- scores = []
102
- for completion in completions:
103
- response = completion[0]["content"]
104
- fold_data = extract_fold_json(response)
105
 
106
- if fold_data is None:
107
- scores.append(-2.0)
 
 
108
  continue
109
 
110
- is_valid, error = validate_fold(fold_data)
111
- if not is_valid:
112
- scores.append(-1.0)
 
 
 
 
113
  continue
114
 
115
- try:
116
- result = simulate(fold_data, crease_percent=1.0)
117
- similarity = compute_shape_match(result.positions, target_positions)
118
- scores.append(similarity * 20.0)
119
- except Exception:
120
- scores.append(-1.0)
121
 
122
  return scores
 
1
  """GRPO reward functions for origami RL training.
2
 
3
+ Follows the OpenEnv 2048 pattern exactly:
4
+ - launch_openenv() spawns/reuses the origami server
5
+ - Reward functions call the server via EnvClient
6
+ - Server computes simulation + shape matching, returns reward
7
+
8
+ These functions are also importable for use in notebooks.
9
  """
10
 
11
  import json
12
  import re
13
  from typing import Any
14
 
 
 
 
 
 
 
 
15
 
16
  def extract_fold_json(response: str) -> dict | None:
17
  """Extract FOLD JSON from LLM response text.
 
51
  +1.0 valid FOLD JSON with correct structure
52
  -0.5 parseable JSON but invalid FOLD structure
53
  -2.0 not parseable as JSON at all
54
+
55
+ Local check — no server needed.
56
  """
57
  scores = []
58
  for completion in completions:
 
63
  scores.append(-2.0)
64
  continue
65
 
66
+ # Basic structural validation
67
+ required = {"vertices_coords", "edges_vertices", "edges_assignment"}
68
+ if not required.issubset(fold_data.keys()):
 
69
  scores.append(-0.5)
70
+ continue
71
 
72
+ verts = fold_data.get("vertices_coords", [])
73
+ edges = fold_data.get("edges_vertices", [])
74
+ assigns = fold_data.get("edges_assignment", [])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
+ if len(edges) != len(assigns):
77
+ scores.append(-0.5)
78
+ continue
 
79
 
80
+ has_fold = any(a in ("M", "V") for a in assigns)
81
+ has_boundary = any(a == "B" for a in assigns)
82
+ if not has_fold or not has_boundary:
83
+ scores.append(-0.5)
84
  continue
85
 
86
+ n = len(verts)
87
+ valid_indices = all(
88
+ 0 <= e[0] < n and 0 <= e[1] < n and e[0] != e[1]
89
+ for e in edges
90
+ )
91
+ if not valid_indices:
92
+ scores.append(-0.5)
93
  continue
94
 
95
+ scores.append(1.0)
 
 
 
 
 
96
 
97
  return scores
training/train_grpo.py CHANGED
@@ -1,20 +1,28 @@
1
  """GRPO training script for origami RL.
2
 
3
- Follows the 2048 OpenEnv + Unsloth pattern:
4
- - LLM generates FOLD JSON crease patterns
5
- - Two reward functions: valid_fold + shape_match
 
6
  - GRPOTrainer from TRL handles the RL loop
7
 
8
- Usage (local/Colab):
 
 
 
 
9
  python -m training.train_grpo --task triangle --max_steps 600
10
 
11
- Usage (Northflank — env vars set in Dockerfile.train):
12
- python -m training.train_grpo --task $TASK --model $MODEL --max_steps $MAX_STEPS
13
  """
14
 
15
  import argparse
 
16
  import os
17
 
 
 
18
  PROMPT_TEMPLATE = """You are an origami designer. Generate a FOLD-format crease pattern
19
  that, when folded, produces the target shape described below.
20
 
@@ -49,60 +57,109 @@ def main():
49
  parser = argparse.ArgumentParser(description="GRPO training for origami RL")
50
  parser.add_argument("--task", default="triangle", help="Task name")
51
  parser.add_argument("--max_steps", type=int, default=600)
52
- parser.add_argument("--num_generations", type=int, default=4)
53
- parser.add_argument("--model", default="Qwen/Qwen2.5-3B-Instruct")
54
  parser.add_argument("--lr", type=float, default=2e-4)
 
 
 
 
 
55
  args = parser.parse_args()
56
 
57
- # --- These imports are heavy, only load when actually training ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  from datasets import Dataset
59
- from trl import GRPOConfig, GRPOTrainer
60
 
61
- from origami_server.tasks import get_task
62
- from training.reward import shape_match, valid_fold
 
63
 
64
- # Try Unsloth first (CUDA), fall back to HF+PEFT
65
  try:
66
  from unsloth import FastLanguageModel
67
  USE_UNSLOTH = True
68
  except ImportError:
69
  USE_UNSLOTH = False
70
 
71
- task = get_task(args.task)
72
- prompt_text = build_prompt(task)
73
-
74
- # Build dataset (1000 copies of same prompt, like 2048)
75
- dataset = Dataset.from_list(
76
- [
77
- {
78
- "prompt": [{"role": "user", "content": prompt_text}],
79
- "answer": 0,
80
- }
81
- ]
82
- * 1000
83
- )
84
 
85
- # Load model with LoRA
86
  if USE_UNSLOTH:
 
87
  model, tokenizer = FastLanguageModel.from_pretrained(
88
  model_name=args.model,
89
  load_in_4bit=True,
90
- max_seq_length=2048,
 
91
  )
92
  model = FastLanguageModel.get_peft_model(
93
  model,
94
- r=8,
95
  target_modules=[
96
  "q_proj", "k_proj", "v_proj", "o_proj",
97
  "gate_proj", "up_proj", "down_proj",
98
  ],
99
- lora_alpha=16,
100
  use_gradient_checkpointing="unsloth",
 
101
  )
102
  else:
103
  import torch
104
- from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
105
  from peft import LoraConfig, get_peft_model
 
106
 
107
  bnb_config = BitsAndBytesConfig(
108
  load_in_4bit=True,
@@ -118,19 +175,23 @@ def main():
118
  torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
119
  )
120
  model = get_peft_model(model, LoraConfig(
121
- r=8, lora_alpha=16, task_type="CAUSAL_LM",
122
- target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
123
- "gate_proj", "up_proj", "down_proj"],
 
 
 
 
124
  ))
125
 
126
  if tokenizer.pad_token is None:
127
  tokenizer.pad_token = tokenizer.eos_token
128
 
129
- # Wrap shape_match to inject task_name
130
- def shape_match_reward(completions, **kwargs):
131
- return shape_match(completions, task_name=args.task, **kwargs)
 
132
 
133
- # GRPO config
134
  training_args = GRPOConfig(
135
  temperature=1.0,
136
  learning_rate=args.lr,
@@ -142,8 +203,8 @@ def main():
142
  per_device_train_batch_size=1,
143
  gradient_accumulation_steps=1,
144
  num_generations=args.num_generations,
145
- max_prompt_length=1024,
146
- max_completion_length=1024,
147
  max_steps=args.max_steps,
148
  save_steps=100,
149
  output_dir=os.environ.get("OUTPUT_DIR", "outputs"),
@@ -157,9 +218,10 @@ def main():
157
  train_dataset=dataset,
158
  )
159
 
 
160
  trainer.train()
161
 
162
- # Save the LoRA adapter
163
  save_path = os.path.join(
164
  os.environ.get("OUTPUT_DIR", "outputs"),
165
  f"origami-{args.task}-lora-final",
 
1
  """GRPO training script for origami RL.
2
 
3
+ Follows the OpenEnv 2048 pattern exactly:
4
+ - Environment runs as a FastAPI server (origami_server.app)
5
+ - Training connects via WebSocket client (OrigamiEnv)
6
+ - Reward functions call the server, never import engine code
7
  - GRPOTrainer from TRL handles the RL loop
8
 
9
+ Usage:
10
+ # 1. Start the environment server first:
11
+ uvicorn origami_server.app:app --host 0.0.0.0 --port 8000
12
+
13
+ # 2. Run training (connects to server):
14
  python -m training.train_grpo --task triangle --max_steps 600
15
 
16
+ # Or specify server URL:
17
+ python -m training.train_grpo --server http://gpu-host:8000
18
  """
19
 
20
  import argparse
21
+ import functools
22
  import os
23
 
24
+ import requests
25
+
26
  PROMPT_TEMPLATE = """You are an origami designer. Generate a FOLD-format crease pattern
27
  that, when folded, produces the target shape described below.
28
 
 
57
  parser = argparse.ArgumentParser(description="GRPO training for origami RL")
58
  parser.add_argument("--task", default="triangle", help="Task name")
59
  parser.add_argument("--max_steps", type=int, default=600)
60
+ parser.add_argument("--num_generations", type=int, default=2)
61
+ parser.add_argument("--model", default="unsloth/Qwen3-14B")
62
  parser.add_argument("--lr", type=float, default=2e-4)
63
+ parser.add_argument("--lora_rank", type=int, default=4)
64
+ parser.add_argument(
65
+ "--server", default="http://localhost:8000",
66
+ help="URL of the origami environment server",
67
+ )
68
  args = parser.parse_args()
69
 
70
+ # --- Verify server is running ---
71
+ print(f"Connecting to environment server at {args.server}...")
72
+ try:
73
+ r = requests.get(f"{args.server}/health", timeout=5)
74
+ assert r.status_code == 200
75
+ print("Server is healthy.")
76
+ except Exception as e:
77
+ print(f"ERROR: Cannot connect to server at {args.server}")
78
+ print(f"Start it first: uvicorn origami_server.app:app --port 8000")
79
+ raise SystemExit(1)
80
+
81
+ # --- Get task info from server ---
82
+ task = requests.get(f"{args.server}/tasks/{args.task}").json()
83
+ prompt_text = build_prompt(task)
84
+ print(f"Task: {task['name']} β€” {task['description']}")
85
+
86
+ # --- Configure reward functions (OpenEnv pattern) ---
87
+ from client import OrigamiEnv
88
+ from origami_server.models import OrigamiAction
89
+ from training.reward import extract_fold_json, valid_fold
90
+ from unsloth import is_port_open, launch_openenv
91
+
92
+ global port, openenv_process
93
+ port = int(args.server.split(":")[-1]) if ":" in args.server else 8000
94
+ openenv_process = None
95
+
96
+ launch_openenv = functools.partial(
97
+ launch_openenv,
98
+ working_directory=os.getcwd(),
99
+ server="origami_server.app:app",
100
+ environment={**os.environ, "PYTHONPATH": os.getcwd()},
101
+ openenv_class=OrigamiEnv,
102
+ )
103
+
104
+ def shape_match_reward(completions, **kwargs):
105
+ global port, openenv_process
106
+ scores = []
107
+ for completion in completions:
108
+ response = completion[0]["content"]
109
+ fold_data = extract_fold_json(response)
110
+ if fold_data is None:
111
+ scores.append(0.0)
112
+ continue
113
+ try:
114
+ port, openenv_process = launch_openenv(port, openenv_process)
115
+ openenv_process.reset(task_name=args.task)
116
+ result = openenv_process.step(OrigamiAction(fold_data=fold_data))
117
+ scores.append(result.reward if result.reward is not None else 0.0)
118
+ except TimeoutError:
119
+ scores.append(-1.0)
120
+ except Exception:
121
+ scores.append(-3.0)
122
+ return scores
123
+
124
+ # --- Build dataset (same prompt repeated, like 2048) ---
125
  from datasets import Dataset
 
126
 
127
+ dataset = Dataset.from_list(
128
+ [{"prompt": [{"role": "user", "content": prompt_text}]}] * 1000
129
+ )
130
 
131
+ # --- Load model with QLoRA ---
132
  try:
133
  from unsloth import FastLanguageModel
134
  USE_UNSLOTH = True
135
  except ImportError:
136
  USE_UNSLOTH = False
137
 
138
+ max_seq_length = 768 # FOLD JSON is compact
 
 
 
 
 
 
 
 
 
 
 
 
139
 
 
140
  if USE_UNSLOTH:
141
+ print(f"Loading {args.model} with Unsloth QLoRA (rank={args.lora_rank})...")
142
  model, tokenizer = FastLanguageModel.from_pretrained(
143
  model_name=args.model,
144
  load_in_4bit=True,
145
+ max_seq_length=max_seq_length,
146
+ offload_embedding=True, # Needed for 14B on limited VRAM
147
  )
148
  model = FastLanguageModel.get_peft_model(
149
  model,
150
+ r=args.lora_rank,
151
  target_modules=[
152
  "q_proj", "k_proj", "v_proj", "o_proj",
153
  "gate_proj", "up_proj", "down_proj",
154
  ],
155
+ lora_alpha=args.lora_rank * 2,
156
  use_gradient_checkpointing="unsloth",
157
+ random_state=3407,
158
  )
159
  else:
160
  import torch
 
161
  from peft import LoraConfig, get_peft_model
162
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
163
 
164
  bnb_config = BitsAndBytesConfig(
165
  load_in_4bit=True,
 
175
  torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
176
  )
177
  model = get_peft_model(model, LoraConfig(
178
+ r=args.lora_rank,
179
+ lora_alpha=args.lora_rank * 2,
180
+ task_type="CAUSAL_LM",
181
+ target_modules=[
182
+ "q_proj", "k_proj", "v_proj", "o_proj",
183
+ "gate_proj", "up_proj", "down_proj",
184
+ ],
185
  ))
186
 
187
  if tokenizer.pad_token is None:
188
  tokenizer.pad_token = tokenizer.eos_token
189
 
190
+ model.print_trainable_parameters()
191
+
192
+ # --- GRPO config (matches 2048 pattern) ---
193
+ from trl import GRPOConfig, GRPOTrainer
194
 
 
195
  training_args = GRPOConfig(
196
  temperature=1.0,
197
  learning_rate=args.lr,
 
203
  per_device_train_batch_size=1,
204
  gradient_accumulation_steps=1,
205
  num_generations=args.num_generations,
206
+ max_prompt_length=512,
207
+ max_completion_length=max_seq_length - 512,
208
  max_steps=args.max_steps,
209
  save_steps=100,
210
  output_dir=os.environ.get("OUTPUT_DIR", "outputs"),
 
218
  train_dataset=dataset,
219
  )
220
 
221
+ print(f"Training: {args.max_steps} steps, {args.num_generations} generations/step")
222
  trainer.train()
223
 
224
+ # Save LoRA adapter
225
  save_path = os.path.join(
226
  os.environ.get("OUTPUT_DIR", "outputs"),
227
  f"origami-{args.task}-lora-final",
training/train_origami.ipynb CHANGED
@@ -3,7 +3,7 @@
3
  {
4
  "cell_type": "markdown",
5
  "id": "p8uwc5bkc4n",
6
- "source": "# Origami RL — GRPO Training Notebook\n\nTrain an LLM to generate valid FOLD-format crease patterns that fold into target shapes.\n\n**Pipeline:**\n1. LLM receives a prompt describing a target shape (e.g. \"fold diagonally into a triangle\")\n2. LLM generates a FOLD JSON crease pattern\n3. Physics simulator folds the paper analytically\n4. Reward = shape similarity (chamfer distance) to target × 20\n\n**Reward functions:**\n- `valid_fold`: +1.0 valid FOLD JSON, −0.5 parseable but invalid, −2.0 unparseable\n- `shape_match`: similarity × 20.0 (0–20), −1.0 sim fails, −2.0 invalid FOLD\n\n**Algorithm:** GRPO (Group Relative Policy Optimization) via TRL + Unsloth LoRA",
7
  "metadata": {}
8
  },
9
  {
@@ -15,7 +15,7 @@
15
  {
16
  "cell_type": "code",
17
  "id": "ulhu8a5p5ti",
18
- "source": "# Run this cell once to install all dependencies\n# For Colab: unsloth has a specific install process\nimport sys\nIN_COLAB = \"google.colab\" in sys.modules\n\nif IN_COLAB:\n # Unsloth's recommended Colab install\n !pip install --no-deps \"unsloth[colab-new]\"\n !pip install --no-deps trl datasets peft accelerate bitsandbytes xformers\nelse:\n !pip install -q \"trl>=0.7\" \"datasets>=2.14\" unsloth torch transformers accelerate bitsandbytes\n\n# Core origami env deps (numpy, scipy, pydantic)\n!pip install -q numpy scipy pydantic",
19
  "metadata": {},
20
  "execution_count": null,
21
  "outputs": []
@@ -23,13 +23,13 @@
23
  {
24
  "cell_type": "markdown",
25
  "id": "qcetkmcq1hf",
26
- "source": "## 2. Setup Python Path & Imports",
27
  "metadata": {}
28
  },
29
  {
30
  "cell_type": "code",
31
  "id": "3hr273dhqiv",
32
- "source": "import os\nimport sys\nimport json\n\n# Add the repo root to Python path so origami_server and training modules are importable\nREPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(\"__file__\"), \"..\"))\nif REPO_ROOT not in sys.path:\n sys.path.insert(0, REPO_ROOT)\n\nprint(f\"Repo root: {REPO_ROOT}\")\nprint(f\"Python: {sys.version}\")",
33
  "metadata": {},
34
  "execution_count": null,
35
  "outputs": []
@@ -37,7 +37,7 @@
37
  {
38
  "cell_type": "code",
39
  "id": "bnm2w57r3lc",
40
- "source": "import numpy as np\n\n# Verify origami env modules load correctly\nfrom origami_server.tasks import TASKS, get_task, list_tasks\nfrom origami_server.engine.fold_parser import validate_fold, parse_fold\nfrom origami_server.engine.simulate import simulate\nfrom origami_server.engine.shape_match import compute_shape_match\nfrom training.reward import valid_fold, shape_match, extract_fold_json\n\nprint(f\"Available tasks: {list_tasks()}\")\nprint(\"All origami modules loaded successfully.\")",
41
  "metadata": {},
42
  "execution_count": null,
43
  "outputs": []
@@ -45,43 +45,13 @@
45
  {
46
  "cell_type": "markdown",
47
  "id": "lcaus7mtuj",
48
- "source": "## 3. Explore the Environment\n\nSanity-check the simulator and reward functions before training.",
49
  "metadata": {}
50
  },
51
  {
52
  "cell_type": "code",
53
  "id": "hlqp4y30m87",
54
- "source": "# Print all tasks with their details\nfor name, task in TASKS.items():\n print(f\"\\n{'='*50}\")\n print(f\"Task: {task['name']}\")\n print(f\"Description: {task['description']}\")\n print(f\"Difficulty: {task['difficulty']}\")\n print(f\"Paper: {task['paper']}\")\n fold = task[\"target_fold\"]\n n_verts = len(fold[\"vertices_coords\"])\n n_edges = len(fold[\"edges_vertices\"])\n n_folds = sum(1 for a in fold[\"edges_assignment\"] if a in (\"M\", \"V\"))\n print(f\"Vertices: {n_verts}, Edges: {n_edges}, Fold creases: {n_folds}\")",
55
- "metadata": {},
56
- "execution_count": null,
57
- "outputs": []
58
- },
59
- {
60
- "cell_type": "code",
61
- "id": "dwqqus8mhlj",
62
- "source": "# Test the simulator on each task\nfor name in list_tasks():\n task = get_task(name)\n target_fold = task[\"target_fold\"]\n \n # Simulate flat (0%), half (50%), and fully folded (100%)\n r_flat = simulate(target_fold, crease_percent=0.0)\n r_half = simulate(target_fold, crease_percent=0.5)\n r_full = simulate(target_fold, crease_percent=1.0)\n \n z_half = r_half.positions[:, 2].max() - r_half.positions[:, 2].min()\n \n # Shape match: target vs itself should be 1.0\n self_sim = compute_shape_match(r_full.positions, r_full.positions)\n \n print(f\"{name:15s} | converged={r_full.converged} | strain={r_full.max_strain:.6f} | \"\n f\"z_range@50%={z_half:.3f} | self_similarity={self_sim:.3f}\")",
63
- "metadata": {},
64
- "execution_count": null,
65
- "outputs": []
66
- },
67
- {
68
- "cell_type": "code",
69
- "id": "p1weq9kv5q",
70
- "source": "# Test reward functions with mock LLM outputs\ntriangle_fold = TASKS[\"triangle\"][\"target_fold\"]\n\n# Simulate what the reward functions see during training:\n# completions = list of [{\"content\": \"...LLM response...\"}]\ngood_response = json.dumps(triangle_fold)\nbad_json = \"I think we should fold it like this...\"\ninvalid_fold = json.dumps({\"vertices_coords\": [[0, 0]], \"edges_vertices\": [], \"edges_assignment\": []})\n\ncompletions = [\n [{\"content\": f\"```json\\n{good_response}\\n```\"}], # correct answer in fenced block\n [{\"content\": bad_json}], # garbage\n [{\"content\": invalid_fold}], # parseable but invalid FOLD\n]\n\nprint(\"valid_fold rewards:\", valid_fold(completions))\nprint(\"shape_match rewards:\", shape_match(completions, task_name=\"triangle\"))\nprint()\nprint(\"Expected: valid_fold = [1.0, -2.0, -0.5]\")\nprint(\"Expected: shape_match = [20.0, -2.0, -1.0]\")",
71
- "metadata": {},
72
- "execution_count": null,
73
- "outputs": []
74
- },
75
- {
76
- "cell_type": "markdown",
77
- "id": "45l0n1hgvr",
78
- "source": "## 4. Visualize Tasks\n\n2D crease patterns for each task (matplotlib).",
79
- "metadata": {}
80
- },
81
- {
82
- "cell_type": "code",
83
- "id": "fkopb9lgg7i",
84
- "source": "import matplotlib.pyplot as plt\nfrom mpl_toolkits.mplot3d import Axes3D\nfrom mpl_toolkits.mplot3d.art3d import Poly3DCollection\n\nEDGE_COLORS = {\"M\": \"red\", \"V\": \"blue\", \"B\": \"black\"}\nEDGE_STYLES = {\"M\": \"--\", \"V\": \":\", \"B\": \"-\"}\n\nfig, axes = plt.subplots(2, 4, figsize=(16, 8))\n\nfor idx, (name, task) in enumerate(TASKS.items()):\n fold = task[\"target_fold\"]\n verts = np.array(fold[\"vertices_coords\"])\n \n # Row 1: 2D crease pattern\n ax = axes[0, idx]\n ax.set_title(f\"{name}\\n{task['description']}\", fontsize=9)\n ax.set_aspect(\"equal\")\n ax.set_xlim(-0.1, 1.1)\n ax.set_ylim(-0.1, 1.1)\n ax.grid(True, alpha=0.2)\n \n for i, (e, a) in enumerate(zip(fold[\"edges_vertices\"], fold[\"edges_assignment\"])):\n v1, v2 = verts[e[0]], verts[e[1]]\n color = EDGE_COLORS.get(a, \"gray\")\n style = EDGE_STYLES.get(a, \"-\")\n lw = 2.5 if a == \"B\" else 1.8\n ax.plot([v1[0], v2[0]], [v1[1], v2[1]], color=color, linestyle=style, linewidth=lw)\n \n ax.scatter(verts[:, 0], verts[:, 1], c=\"black\", s=15, zorder=5)\n \n # Row 2: 3D folded shape\n ax3 = fig.add_subplot(2, 4, idx + 5, projection=\"3d\")\n result = simulate(fold, crease_percent=1.0)\n pos = result.positions\n \n if \"faces_vertices\" in fold:\n for face in fold[\"faces_vertices\"]:\n tri_verts = [pos[vi] for vi in face]\n poly = Poly3DCollection([tri_verts], alpha=0.3, facecolor=\"lightskyblue\", edgecolor=\"steelblue\")\n ax3.add_collection3d(poly)\n \n for i, (e, a) in enumerate(zip(fold[\"edges_vertices\"], fold[\"edges_assignment\"])):\n p1, p2 = pos[e[0]], pos[e[1]]\n color = EDGE_COLORS.get(a, \"gray\")\n ax3.plot([p1[0], p2[0]], [p1[1], p2[1]], [p1[2], p2[2]], color=color, linewidth=1.2)\n \n ax3.scatter(pos[:, 0], pos[:, 1], pos[:, 2], c=\"black\", s=10, zorder=5)\n ax3.set_title(f\"Folded (3D)\", fontsize=9)\n ax3.set_xlim(-0.2, 1.2)\n ax3.set_ylim(-0.2, 1.2)\n ax3.set_zlim(-0.6, 0.6)\n \n # Remove the empty 2D subplot that was in row 2\n axes[1, 
idx].remove()\n\nplt.tight_layout()\nplt.show()",
85
  "metadata": {},
86
  "execution_count": null,
87
  "outputs": []
@@ -89,13 +59,13 @@
89
  {
90
  "cell_type": "markdown",
91
  "id": "a14w2fkoewq",
92
- "source": "## 5. Training Configuration",
93
  "metadata": {}
94
  },
95
  {
96
  "cell_type": "code",
97
  "id": "2phdejbobq3",
98
- "source": "# ============================================================\n# Training hyperparameters β€” edit these before launching\n# ============================================================\n\nTASK_NAME = \"triangle\" # \"triangle\", \"half_fold\", \"quarter_fold\", \"letter_fold\"\nMODEL_NAME = \"Qwen/Qwen2.5-3B-Instruct\" # Change to your preferred model\nMAX_STEPS = 600 # Total GRPO training steps\nNUM_GENERATIONS = 4 # Completions per prompt per step\nLEARNING_RATE = 2e-4\nLORA_R = 8 # LoRA rank\nLORA_ALPHA = 16 # LoRA alpha\nMAX_PROMPT_LENGTH = 1024\nMAX_COMPLETION_LENGTH = 1024\nDATASET_SIZE = 1000 # Number of prompt copies (same prompt repeated)\nOUTPUT_DIR = \"outputs\"\nSAVE_STEPS = 100",
99
  "metadata": {},
100
  "execution_count": null,
101
  "outputs": []
@@ -103,13 +73,13 @@
103
  {
104
  "cell_type": "markdown",
105
  "id": "feal20fr8j5",
106
- "source": "## 6. Build the Prompt & Dataset",
107
  "metadata": {}
108
  },
109
  {
110
  "cell_type": "code",
111
  "id": "uo7zh1dwp6r",
112
- "source": "from training.train_grpo import PROMPT_TEMPLATE, build_prompt\n\ntask = get_task(TASK_NAME)\nprompt_text = build_prompt(task)\n\nprint(\"=\"*60)\nprint(\"PROMPT THAT THE LLM WILL SEE:\")\nprint(\"=\"*60)\nprint(prompt_text)",
113
  "metadata": {},
114
  "execution_count": null,
115
  "outputs": []
@@ -117,7 +87,7 @@
117
  {
118
  "cell_type": "code",
119
  "id": "900vyqwb8g",
120
- "source": "from datasets import Dataset\n\n# GRPO pattern: same prompt repeated many times, the RL loop generates\n# multiple completions per prompt and uses relative rewards to update policy\ndataset = Dataset.from_list(\n [\n {\n \"prompt\": [{\"role\": \"user\", \"content\": prompt_text}],\n \"answer\": 0, # placeholder, not used by GRPO\n }\n ]\n * DATASET_SIZE\n)\n\nprint(f\"Dataset size: {len(dataset)}\")\nprint(f\"Sample prompt (first 100 chars): {dataset[0]['prompt'][0]['content'][:100]}...\")",
121
  "metadata": {},
122
  "execution_count": null,
123
  "outputs": []
@@ -125,48 +95,54 @@
125
  {
126
  "cell_type": "markdown",
127
  "id": "xn6n1hpx2aa",
128
- "source": "## 7. Load Model + LoRA\n\nUses Unsloth for fast 4-bit LoRA fine-tuning. Falls back to standard HuggingFace if Unsloth isn't available.",
129
  "metadata": {}
130
  },
131
  {
132
  "cell_type": "code",
133
  "id": "vkfaeuu9dq",
134
- "source": "import torch\nprint(f\"CUDA available: {torch.cuda.is_available()}\")\nif torch.cuda.is_available():\n print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n print(f\"VRAM: {torch.cuda.get_device_properties(0).total_mem / 1e9:.1f} GB\")\nelif hasattr(torch.backends, \"mps\") and torch.backends.mps.is_available():\n print(\"Apple MPS (Metal) available β€” note: Unsloth requires CUDA, will use HF fallback\")\nelse:\n print(\"No GPU detected β€” training will be very slow\")",
135
  "metadata": {},
136
  "execution_count": null,
137
  "outputs": []
138
  },
 
 
 
 
 
 
139
  {
140
  "cell_type": "code",
141
- "id": "xwlkfw3xxoo",
142
- "source": "USE_UNSLOTH = False\n\ntry:\n from unsloth import FastLanguageModel\n USE_UNSLOTH = True\n print(\"Using Unsloth for fast LoRA loading\")\nexcept ImportError:\n print(\"Unsloth not available, using standard HuggingFace + PEFT\")\n\nif USE_UNSLOTH:\n model, tokenizer = FastLanguageModel.from_pretrained(\n model_name=MODEL_NAME,\n load_in_4bit=True,\n max_seq_length=MAX_PROMPT_LENGTH + MAX_COMPLETION_LENGTH,\n )\n model = FastLanguageModel.get_peft_model(\n model,\n r=LORA_R,\n target_modules=[\n \"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n \"gate_proj\", \"up_proj\", \"down_proj\",\n ],\n lora_alpha=LORA_ALPHA,\n use_gradient_checkpointing=\"unsloth\",\n )\nelse:\n from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig\n from peft import LoraConfig, get_peft_model\n\n bnb_config = BitsAndBytesConfig(\n load_in_4bit=True,\n bnb_4bit_quant_type=\"nf4\",\n bnb_4bit_compute_dtype=torch.bfloat16,\n ) if torch.cuda.is_available() else None\n\n tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n model = AutoModelForCausalLM.from_pretrained(\n MODEL_NAME,\n quantization_config=bnb_config,\n device_map=\"auto\" if torch.cuda.is_available() else \"cpu\",\n torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,\n )\n\n lora_config = LoraConfig(\n r=LORA_R,\n lora_alpha=LORA_ALPHA,\n target_modules=[\n \"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n \"gate_proj\", \"up_proj\", \"down_proj\",\n ],\n task_type=\"CAUSAL_LM\",\n )\n model = get_peft_model(model, lora_config)\n\nif tokenizer.pad_token is None:\n tokenizer.pad_token = tokenizer.eos_token\n\nmodel.print_trainable_parameters()",
143
  "metadata": {},
144
  "execution_count": null,
145
  "outputs": []
146
  },
147
  {
148
  "cell_type": "markdown",
149
- "id": "3f7ritml396",
150
- "source": "## 8. Setup GRPO Trainer",
151
  "metadata": {}
152
  },
153
  {
154
  "cell_type": "code",
155
- "id": "4dqsw30e9nq",
156
- "source": "from trl import GRPOConfig, GRPOTrainer\n\n# Wrap shape_match to inject the task name\ndef shape_match_reward(completions, **kwargs):\n return shape_match(completions, task_name=TASK_NAME, **kwargs)\n\ntraining_args = GRPOConfig(\n temperature=1.0,\n learning_rate=LEARNING_RATE,\n weight_decay=0.001,\n warmup_ratio=0.1,\n lr_scheduler_type=\"linear\",\n optim=\"adamw_8bit\" if torch.cuda.is_available() else \"adamw_torch\",\n logging_steps=1,\n per_device_train_batch_size=1,\n gradient_accumulation_steps=1,\n num_generations=NUM_GENERATIONS,\n max_prompt_length=MAX_PROMPT_LENGTH,\n max_completion_length=MAX_COMPLETION_LENGTH,\n max_steps=MAX_STEPS,\n save_steps=SAVE_STEPS,\n output_dir=OUTPUT_DIR,\n report_to=\"none\", # Set to \"wandb\" if you want W&B logging\n)\n\ntrainer = GRPOTrainer(\n model=model,\n processing_class=tokenizer,\n reward_funcs=[valid_fold, shape_match_reward],\n args=training_args,\n train_dataset=dataset,\n)\n\nprint(f\"Trainer ready. Task: {TASK_NAME}, Model: {MODEL_NAME}\")\nprint(f\"Max steps: {MAX_STEPS}, Generations per step: {NUM_GENERATIONS}\")\nprint(f\"Reward functions: valid_fold + shape_match\")",
157
  "metadata": {},
158
  "execution_count": null,
159
  "outputs": []
160
  },
161
  {
162
  "cell_type": "markdown",
163
- "id": "62lvkfoyu1p",
164
  "source": "## 9. Train!",
165
  "metadata": {}
166
  },
167
  {
168
  "cell_type": "code",
169
- "id": "eohisxhna96",
170
  "source": "trainer.train()",
171
  "metadata": {},
172
  "execution_count": null,
@@ -181,7 +157,7 @@
181
  {
182
  "cell_type": "code",
183
  "id": "t3d4tu6o5mc",
184
- "source": "SAVE_PATH = f\"origami-{TASK_NAME}-lora\"\n\n# Save LoRA adapter\nmodel.save_pretrained(SAVE_PATH)\ntokenizer.save_pretrained(SAVE_PATH)\nprint(f\"LoRA adapter saved to {SAVE_PATH}/\")\n\n# Optional: merge LoRA into base model and save full model\n# merged_path = f\"origami-{TASK_NAME}-merged\"\n# if USE_UNSLOTH:\n# model.save_pretrained_merged(merged_path, tokenizer)\n# else:\n# merged_model = model.merge_and_unload()\n# merged_model.save_pretrained(merged_path)\n# tokenizer.save_pretrained(merged_path)\n# print(f\"Merged model saved to {merged_path}/\")",
185
  "metadata": {},
186
  "execution_count": null,
187
  "outputs": []
@@ -189,13 +165,13 @@
189
  {
190
  "cell_type": "markdown",
191
  "id": "q18eizy1ok",
192
- "source": "## 11. Evaluate — Generate & Score Completions\n\nTest the trained model by generating crease patterns and scoring them.",
193
  "metadata": {}
194
  },
195
  {
196
  "cell_type": "code",
197
  "id": "on56augj41",
198
- "source": "# Put model in inference mode\nif USE_UNSLOTH:\n FastLanguageModel.for_inference(model)\n\nNUM_EVAL_SAMPLES = 8\n\n# Build chat messages\nmessages = [{\"role\": \"user\", \"content\": prompt_text}]\ninput_ids = tokenizer.apply_chat_template(\n messages, tokenize=True, add_generation_prompt=True, return_tensors=\"pt\"\n).to(model.device)\n\nprint(f\"Generating {NUM_EVAL_SAMPLES} completions...\")\nprint(f\"Input length: {input_ids.shape[1]} tokens\\n\")\n\neval_completions = []\nfor i in range(NUM_EVAL_SAMPLES):\n with torch.no_grad():\n output = model.generate(\n input_ids,\n max_new_tokens=MAX_COMPLETION_LENGTH,\n temperature=0.7,\n top_p=0.9,\n do_sample=True,\n pad_token_id=tokenizer.pad_token_id,\n )\n response = tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)\n eval_completions.append([{\"content\": response}])\n \n # Quick score\n fold_data = extract_fold_json(response)\n if fold_data is None:\n status = \"UNPARSEABLE\"\n else:\n is_valid, err = validate_fold(fold_data)\n if not is_valid:\n status = f\"INVALID: {err}\"\n else:\n try:\n result = simulate(fold_data, crease_percent=1.0)\n target_result = simulate(task[\"target_fold\"], crease_percent=1.0)\n sim = compute_shape_match(result.positions, target_result.positions)\n status = f\"similarity={sim:.3f} (reward={sim * 20:.1f})\"\n except Exception as e:\n status = f\"SIM ERROR: {e}\"\n \n print(f\" Sample {i+1}: {status}\")\n\n# Compute aggregate reward scores\nprint(f\"\\nAggregate rewards:\")\nvf_scores = valid_fold(eval_completions)\nsm_scores = shape_match(eval_completions, task_name=TASK_NAME)\nprint(f\" valid_fold: mean={np.mean(vf_scores):.2f}, scores={vf_scores}\")\nprint(f\" shape_match: mean={np.mean(sm_scores):.2f}, scores={sm_scores}\")",
199
  "metadata": {},
200
  "execution_count": null,
201
  "outputs": []
@@ -203,13 +179,13 @@
203
  {
204
  "cell_type": "markdown",
205
  "id": "tb1y8hszrk",
206
- "source": "## 12. Visualize a Generated Fold\n\nPick the best completion and visualize its crease pattern + 3D fold vs the target.",
207
  "metadata": {}
208
  },
209
  {
210
  "cell_type": "code",
211
  "id": "0zo3krbkiqej",
212
- "source": "# Find the best valid completion\nbest_idx = int(np.argmax(sm_scores))\nbest_response = eval_completions[best_idx][0][\"content\"]\nbest_fold = extract_fold_json(best_response)\n\nif best_fold is None or sm_scores[best_idx] <= 0:\n print(\"No valid completions to visualize.\")\nelse:\n is_valid, _ = validate_fold(best_fold)\n if not is_valid:\n print(\"Best completion has invalid FOLD structure.\")\n else:\n pred_result = simulate(best_fold, crease_percent=1.0)\n target_result = simulate(task[\"target_fold\"], crease_percent=1.0)\n \n fig = plt.figure(figsize=(14, 5))\n \n # 1) Generated 2D crease pattern\n ax1 = fig.add_subplot(131)\n ax1.set_title(f\"Generated Crease Pattern\\n(sample {best_idx+1})\", fontsize=10)\n ax1.set_aspect(\"equal\")\n verts = np.array(best_fold[\"vertices_coords\"])\n for i, (e, a) in enumerate(zip(best_fold[\"edges_vertices\"], best_fold[\"edges_assignment\"])):\n v1, v2 = verts[e[0]], verts[e[1]]\n color = EDGE_COLORS.get(a, \"gray\")\n style = EDGE_STYLES.get(a, \"-\")\n ax1.plot([v1[0], v2[0]], [v1[1], v2[1]], color=color, linestyle=style, linewidth=2)\n ax1.scatter(verts[:, 0], verts[:, 1], c=\"black\", s=20, zorder=5)\n ax1.grid(True, alpha=0.2)\n \n # 2) Generated 3D fold\n ax2 = fig.add_subplot(132, projection=\"3d\")\n ax2.set_title(f\"Generated 3D Fold\\nsimilarity={sm_scores[best_idx]/20:.3f}\", fontsize=10)\n pos = pred_result.positions\n for i, (e, a) in enumerate(zip(best_fold[\"edges_vertices\"], best_fold[\"edges_assignment\"])):\n p1, p2 = pos[e[0]], pos[e[1]]\n color = EDGE_COLORS.get(a, \"gray\")\n ax2.plot([p1[0], p2[0]], [p1[1], p2[1]], [p1[2], p2[2]], color=color, linewidth=1.5)\n ax2.scatter(pos[:, 0], pos[:, 1], pos[:, 2], c=\"black\", s=15, zorder=5)\n \n # 3) Target 3D fold\n ax3 = fig.add_subplot(133, projection=\"3d\")\n ax3.set_title(\"Target 3D Fold\", fontsize=10)\n tpos = target_result.positions\n tfold = task[\"target_fold\"]\n for i, (e, a) in enumerate(zip(tfold[\"edges_vertices\"], 
tfold[\"edges_assignment\"])):\n p1, p2 = tpos[e[0]], tpos[e[1]]\n color = EDGE_COLORS.get(a, \"gray\")\n ax3.plot([p1[0], p2[0]], [p1[1], p2[1]], [p1[2], p2[2]], color=color, linewidth=1.5)\n ax3.scatter(tpos[:, 0], tpos[:, 1], tpos[:, 2], c=\"black\", s=15, zorder=5)\n \n plt.tight_layout()\n plt.show()\n \n print(f\"\\nBest generated FOLD JSON:\")\n print(json.dumps(best_fold, indent=2))",
213
  "metadata": {},
214
  "execution_count": null,
215
  "outputs": []
@@ -217,7 +193,7 @@
217
  {
218
  "cell_type": "markdown",
219
  "id": "qlakksqmoe",
220
- "source": "## 13. Plot Training Logs",
221
  "metadata": {}
222
  },
223
  {
 
3
  {
4
  "cell_type": "markdown",
5
  "id": "p8uwc5bkc4n",
6
+ "source": "# Origami RL — GRPO Training\n\nTrain an LLM to generate FOLD crease patterns using OpenEnv + Unsloth + TRL.\n\nFollows the [2048 OpenEnv notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/OpenEnv_gpt_oss_(20B)_Reinforcement_Learning_2048_Game.ipynb) pattern exactly:\n1. `launch_openenv()` spawns the origami environment server\n2. LLM generates FOLD JSON crease patterns\n3. Reward functions call the server via OpenEnv client\n4. GRPO updates policy based on relative rewards",
7
  "metadata": {}
8
  },
9
  {
 
15
  {
16
  "cell_type": "code",
17
  "id": "ulhu8a5p5ti",
18
+ "source": "%%capture\nimport os, importlib.util\n!pip install --upgrade -qqq uv\nif importlib.util.find_spec(\"torch\") is None or \"COLAB_\" in \"\".join(os.environ.keys()):\n try: import numpy; get_numpy = f\"numpy=={numpy.__version__}\"\n except: get_numpy = \"numpy\"\n !uv pip install -qqq \\\n \"torch>=2.8.0\" \"triton>=3.4.0\" {get_numpy} torchvision bitsandbytes \"transformers==4.56.2\" trackio \\\n \"unsloth_zoo[base] @ git+https://github.com/unslothai/unsloth-zoo\" \\\n \"unsloth[base] @ git+https://github.com/unslothai/unsloth\"\nelif importlib.util.find_spec(\"unsloth\") is None:\n !uv pip install -qqq unsloth trackio\n!uv pip install --upgrade --no-deps transformers==4.56.2 tokenizers trl==0.22.2 unsloth unsloth_zoo\n!pip install -qqq fastapi uvicorn requests numpy scipy pydantic",
19
  "metadata": {},
20
  "execution_count": null,
21
  "outputs": []
 
23
  {
24
  "cell_type": "markdown",
25
  "id": "qcetkmcq1hf",
26
+ "source": "## 2. Clone Origami Env + Setup Paths",
27
  "metadata": {}
28
  },
29
  {
30
  "cell_type": "code",
31
  "id": "3hr273dhqiv",
32
+ "source": "%%capture\n# Clone the origami env repo (skip if running locally)\nimport subprocess, sys, os\nfrom pathlib import Path\n\nREPO_URL = \"https://github.com/YOUR_USERNAME/origami_env.git\" # TODO: update with your repo\nLOCAL_DIR = \"origami_env\"\n\nif not Path(LOCAL_DIR).exists():\n # Running on Colab β€” clone the repo\n !git clone {REPO_URL} {LOCAL_DIR} > /dev/null 2>&1\n !pip install -e {LOCAL_DIR} > /dev/null 2>&1\n\n# Add repo to Python path\nworking_directory = str(Path(LOCAL_DIR).absolute()) if Path(LOCAL_DIR).exists() else str(Path.cwd().parent.absolute())\nsys.path.insert(0, working_directory)\nprint(f\"Working directory: {working_directory}\")",
33
  "metadata": {},
34
  "execution_count": null,
35
  "outputs": []
 
37
  {
38
  "cell_type": "code",
39
  "id": "bnm2w57r3lc",
40
+ "source": "# Import OpenEnv client + models (same pattern as 2048 notebook)\nfrom client import OrigamiEnv\nfrom origami_server.models import OrigamiAction, OrigamiObservation, OrigamiState\nprint(\"Origami OpenEnv modules loaded.\")",
41
  "metadata": {},
42
  "execution_count": null,
43
  "outputs": []
 
45
  {
46
  "cell_type": "markdown",
47
  "id": "lcaus7mtuj",
48
+ "source": "## 3. Load Model + QLoRA",
49
  "metadata": {}
50
  },
51
  {
52
  "cell_type": "code",
53
  "id": "hlqp4y30m87",
54
+ "source": "from unsloth import FastLanguageModel\nimport torch\n\nmax_seq_length = 768\nlora_rank = 4\n\nmodel, tokenizer = FastLanguageModel.from_pretrained(\n model_name = \"unsloth/Qwen3-14B\",\n load_in_4bit = True,\n max_seq_length = max_seq_length,\n offload_embedding = True,\n)",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  "metadata": {},
56
  "execution_count": null,
57
  "outputs": []
 
59
  {
60
  "cell_type": "markdown",
61
  "id": "a14w2fkoewq",
62
+ "source": "## 4. LoRA Adapter",
63
  "metadata": {}
64
  },
65
  {
66
  "cell_type": "code",
67
  "id": "2phdejbobq3",
68
+ "source": "model = FastLanguageModel.get_peft_model(\n model,\n r = lora_rank,\n target_modules = [\n \"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n \"gate_proj\", \"up_proj\", \"down_proj\",\n ],\n lora_alpha = lora_rank * 2,\n use_gradient_checkpointing = \"unsloth\",\n random_state = 3407,\n)",
69
  "metadata": {},
70
  "execution_count": null,
71
  "outputs": []
 
73
  {
74
  "cell_type": "markdown",
75
  "id": "feal20fr8j5",
76
+ "source": "## 5. Launch OpenEnv Server",
77
  "metadata": {}
78
  },
79
  {
80
  "cell_type": "code",
81
  "id": "uo7zh1dwp6r",
82
+ "source": "# Launch origami environment server (same pattern as 2048 notebook)\nglobal port\nglobal openenv_process\nport = 8000\nopenenv_process = None\nserver = \"origami_server.app:app\"\nenvironment = {\n **os.environ,\n \"PYTHONPATH\": working_directory,\n}\n\n# Augment Unsloth's launch_openenv with our config\nimport functools\nfrom unsloth import is_port_open, launch_openenv\nlaunch_openenv = functools.partial(\n launch_openenv,\n working_directory = working_directory,\n server = server,\n environment = environment,\n openenv_class = OrigamiEnv,\n)",
83
  "metadata": {},
84
  "execution_count": null,
85
  "outputs": []
 
87
  {
88
  "cell_type": "code",
89
  "id": "900vyqwb8g",
90
+ "source": "# Test the connection β€” reset and inspect\nport, openenv_process = launch_openenv(port, openenv_process)\nresult = openenv_process.reset(task_name=\"triangle\")\nprint(f\"Server running on port {port}\")\nprint(f\"Observation: done={result.done}, reward={result.reward}\")\nprint(f\"Task: {result.observation.task}\")",
91
  "metadata": {},
92
  "execution_count": null,
93
  "outputs": []
 
95
  {
96
  "cell_type": "markdown",
97
  "id": "xn6n1hpx2aa",
98
+ "source": "## 6. Prompt + Dataset",
99
  "metadata": {}
100
  },
101
  {
102
  "cell_type": "code",
103
  "id": "vkfaeuu9dq",
104
+ "source": "import requests\n\nTASK_NAME = \"triangle\" # \"triangle\", \"half_fold\", \"quarter_fold\", \"letter_fold\"\n\n# Fetch task params from the server (paper size, description, etc.)\ntask_info = requests.get(f\"http://localhost:{port}/tasks/{TASK_NAME}\").json()\n\nPROMPT_TEMPLATE = \"\"\"You are an origami designer. Generate a FOLD-format crease pattern\nthat, when folded, produces the target shape described below.\n\nTarget: {description}\nPaper size: {width} x {height}\n\nOutput a JSON object with these exact fields:\n- vertices_coords: [[x, y], ...] β€” 2D positions on the flat paper (0 to {width} for x, 0 to {height} for y)\n- edges_vertices: [[v1, v2], ...] β€” pairs of vertex indices forming edges\n- edges_assignment: [\"B\"|\"M\"|\"V\", ...] β€” B=boundary, M=mountain fold, V=valley fold\n- edges_foldAngle: [angle, ...] β€” fold angles in degrees (V: 180, M: -180, B: 0)\n\nRules:\n- Boundary edges (B) must outline the paper rectangle\n- At least one fold crease (M or V) must exist\n- All vertex indices must be valid (0 to N-1)\n\nOutput ONLY the JSON object wrapped in ```json ... ``` markers.\"\"\"\n\nprompt = PROMPT_TEMPLATE.format(\n description=task_info[\"description\"],\n width=task_info[\"paper\"][\"width\"],\n height=task_info[\"paper\"][\"height\"],\n).strip()\n\n# Build dataset β€” same prompt repeated 1000x (identical to 2048 pattern)\nfrom datasets import Dataset\ndataset = Dataset.from_list([{\n \"prompt\": [{\"role\": \"user\", \"content\": prompt}],\n}] * 1000)\n\nprint(f\"Task: {task_info['name']} β€” {task_info['description']}\")\nprint(f\"Paper: {task_info['paper']['width']} x {task_info['paper']['height']}\")\nprint(f\"Difficulty: {task_info['difficulty']}\")\nprint(f\"Dataset: {len(dataset)} rows\")\nprint(f\"\\nPrompt:\\n{prompt[:200]}...\")",
105
  "metadata": {},
106
  "execution_count": null,
107
  "outputs": []
108
  },
109
+ {
110
+ "cell_type": "markdown",
111
+ "id": "3f7ritml396",
112
+ "source": "## 7. Reward Functions\n\nTwo reward functions (same pattern as the 2048 notebook):\n- `valid_fold` — local JSON structure check (fast, no server call)\n- `shape_match` — calls the origami server via `launch_openenv`, submits the fold, returns similarity × 20",
113
+ "metadata": {}
114
+ },
115
  {
116
  "cell_type": "code",
117
+ "id": "4dqsw30e9nq",
118
+ "source": "import json, re\n\n# --- Reward 1: valid_fold (local check, no server needed) ---\n\ndef extract_fold_json(response):\n \"\"\"Extract FOLD JSON from LLM response text.\"\"\"\n # Try fenced code block\n match = re.search(r\"```(?:json)?\\s*(\\{.*?\\})\\s*```\", response, re.DOTALL)\n if match:\n try: return json.loads(match.group(1))\n except json.JSONDecodeError: pass\n # Try raw JSON with vertices_coords\n match = re.search(r\"\\{[^{}]*\\\"vertices_coords\\\"[^{}]*\\}\", response, re.DOTALL)\n if match:\n try: return json.loads(match.group(0))\n except json.JSONDecodeError: pass\n # Try whole response\n try:\n data = json.loads(response.strip())\n if isinstance(data, dict) and \"vertices_coords\" in data:\n return data\n except (json.JSONDecodeError, ValueError): pass\n return None\n\ndef valid_fold(completions, **kwargs):\n \"\"\"Does the LLM output parse as valid FOLD JSON?\n +1.0 valid, -0.5 parseable but invalid, -2.0 unparseable.\"\"\"\n scores = []\n for completion in completions:\n response = completion[0][\"content\"]\n fold_data = extract_fold_json(response)\n if fold_data is None:\n scores.append(-2.0); continue\n required = {\"vertices_coords\", \"edges_vertices\", \"edges_assignment\"}\n if not required.issubset(fold_data.keys()):\n scores.append(-0.5); continue\n verts = fold_data.get(\"vertices_coords\", [])\n edges = fold_data.get(\"edges_vertices\", [])\n assigns = fold_data.get(\"edges_assignment\", [])\n if len(edges) != len(assigns):\n scores.append(-0.5); continue\n if not any(a in (\"M\", \"V\") for a in assigns) or not any(a == \"B\" for a in assigns):\n scores.append(-0.5); continue\n n = len(verts)\n if not all(0 <= e[0] < n and 0 <= e[1] < n and e[0] != e[1] for e in edges):\n scores.append(-0.5); continue\n scores.append(1.0)\n return scores\n\n# --- Reward 2: shape_match (calls server via launch_openenv) ---\n\ndef shape_match(completions, **kwargs):\n \"\"\"Submit fold to origami server, get shape similarity reward.\n Calls 
launch_openenv to ensure server is running, then reset + step.\"\"\"\n global port, openenv_process\n scores = []\n for completion in completions:\n response = completion[0][\"content\"]\n fold_data = extract_fold_json(response)\n if fold_data is None:\n scores.append(0.0)\n continue\n try:\n port, openenv_process = launch_openenv(port, openenv_process)\n openenv_process.reset(task_name=TASK_NAME)\n result = openenv_process.step(OrigamiAction(fold_data=fold_data))\n reward = result.reward if result.reward is not None else 0.0\n scores.append(reward)\n except TimeoutError:\n scores.append(-1.0)\n except Exception as e:\n scores.append(-3.0)\n return scores\n\n# Quick test\ntest_good = [[{\"content\": json.dumps({\n \"vertices_coords\": [[0,0],[1,0],[1,1],[0,1]],\n \"edges_vertices\": [[0,1],[1,2],[2,3],[3,0],[0,2]],\n \"edges_assignment\": [\"B\",\"B\",\"B\",\"B\",\"V\"],\n \"edges_foldAngle\": [0,0,0,0,180]\n})}]]\ntest_bad = [[{\"content\": \"not json\"}]]\nprint(f\"valid_fold β€” good: {valid_fold(test_good)}, bad: {valid_fold(test_bad)}\")\nprint(f\"shape_match β€” good: {shape_match(test_good)}\")",
119
  "metadata": {},
120
  "execution_count": null,
121
  "outputs": []
122
  },
123
  {
124
  "cell_type": "markdown",
125
+ "id": "62lvkfoyu1p",
126
+ "source": "## 8. GRPO Trainer",
127
  "metadata": {}
128
  },
129
  {
130
  "cell_type": "code",
131
+ "id": "eohisxhna96",
132
+ "source": "from trl import GRPOConfig, GRPOTrainer\n\ntraining_args = GRPOConfig(\n temperature = 1.0,\n learning_rate = 2e-4,\n weight_decay = 0.001,\n warmup_ratio = 0.1,\n lr_scheduler_type = \"linear\",\n optim = \"adamw_8bit\",\n logging_steps = 1,\n per_device_train_batch_size = 1,\n gradient_accumulation_steps = 1,\n num_generations = 2,\n max_prompt_length = 512,\n max_completion_length = max_seq_length - 512,\n max_steps = 600,\n save_steps = 100,\n output_dir = \"outputs\",\n report_to = \"none\",\n)\n\ntrainer = GRPOTrainer(\n model = model,\n processing_class = tokenizer,\n reward_funcs = [valid_fold, shape_match],\n args = training_args,\n train_dataset = dataset,\n)\n\nprint(f\"Trainer ready: {training_args.max_steps} steps, {training_args.num_generations} generations/step\")",
133
  "metadata": {},
134
  "execution_count": null,
135
  "outputs": []
136
  },
137
  {
138
  "cell_type": "markdown",
139
+ "id": "ve98mq6rgot",
140
  "source": "## 9. Train!",
141
  "metadata": {}
142
  },
143
  {
144
  "cell_type": "code",
145
+ "id": "8il1yknetfg",
146
  "source": "trainer.train()",
147
  "metadata": {},
148
  "execution_count": null,
 
157
  {
158
  "cell_type": "code",
159
  "id": "t3d4tu6o5mc",
160
+ "source": "save_path = f\"origami-{TASK_NAME}-lora\"\nmodel.save_pretrained(save_path)\ntokenizer.save_pretrained(save_path)\nprint(f\"LoRA adapter saved to {save_path}/\")",
161
  "metadata": {},
162
  "execution_count": null,
163
  "outputs": []
 
165
  {
166
  "cell_type": "markdown",
167
  "id": "q18eizy1ok",
168
+ "source": "## 11. Evaluate — Generate & Score",
169
  "metadata": {}
170
  },
171
  {
172
  "cell_type": "code",
173
  "id": "on56augj41",
174
+ "source": "import numpy as np\nFastLanguageModel.for_inference(model)\n\nNUM_EVAL = 8\nmessages = [{\"role\": \"user\", \"content\": prompt}]\ninput_ids = tokenizer.apply_chat_template(\n messages, tokenize=True, add_generation_prompt=True, return_tensors=\"pt\"\n).to(model.device)\n\nprint(f\"Generating {NUM_EVAL} completions (input: {input_ids.shape[1]} tokens)...\\n\")\n\neval_completions = []\nfor i in range(NUM_EVAL):\n with torch.no_grad():\n output = model.generate(\n input_ids,\n max_new_tokens=max_seq_length - 512,\n temperature=0.7, top_p=0.9, do_sample=True,\n pad_token_id=tokenizer.pad_token_id,\n )\n response = tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)\n eval_completions.append([{\"content\": response}])\n fold = extract_fold_json(response)\n status = f\"parsed ({len(fold.get('vertices_coords', []))} verts)\" if fold else \"UNPARSEABLE\"\n print(f\" Sample {i+1}: {status}\")\n\nprint(f\"\\nScoring via server...\")\nvf_scores = valid_fold(eval_completions)\nsm_scores = shape_match(eval_completions)\nprint(f\" valid_fold: mean={np.mean(vf_scores):.2f}, scores={vf_scores}\")\nprint(f\" shape_match: mean={np.mean(sm_scores):.2f}, scores={sm_scores}\")",
175
  "metadata": {},
176
  "execution_count": null,
177
  "outputs": []
 
179
  {
180
  "cell_type": "markdown",
181
  "id": "tb1y8hszrk",
182
+ "source": "## 12. Visualize Best Result",
183
  "metadata": {}
184
  },
185
  {
186
  "cell_type": "code",
187
  "id": "0zo3krbkiqej",
188
+ "source": "import matplotlib.pyplot as plt\nimport requests\n\nEDGE_COLORS = {\"M\": \"red\", \"V\": \"blue\", \"B\": \"black\"}\nEDGE_STYLES = {\"M\": \"--\", \"V\": \":\", \"B\": \"-\"}\n\nbest_idx = int(np.argmax(sm_scores))\nbest_fold = extract_fold_json(eval_completions[best_idx][0][\"content\"])\n\nif best_fold is None or sm_scores[best_idx] <= 0:\n print(\"No valid completions to visualize.\")\nelse:\n fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))\n\n # Generated crease pattern\n ax1.set_title(f\"Generated (sample {best_idx+1})\\nreward={sm_scores[best_idx]:.1f}\", fontsize=10)\n ax1.set_aspect(\"equal\")\n verts = np.array(best_fold[\"vertices_coords\"])\n for e, a in zip(best_fold[\"edges_vertices\"], best_fold[\"edges_assignment\"]):\n v1, v2 = verts[e[0]], verts[e[1]]\n ax1.plot([v1[0], v2[0]], [v1[1], v2[1]],\n color=EDGE_COLORS.get(a, \"gray\"),\n linestyle=EDGE_STYLES.get(a, \"-\"), linewidth=2)\n ax1.scatter(verts[:, 0], verts[:, 1], c=\"black\", s=20, zorder=5)\n ax1.grid(True, alpha=0.2)\n\n # Target crease pattern (from server)\n ax2.set_title(\"Target\", fontsize=10)\n ax2.set_aspect(\"equal\")\n port, openenv_process = launch_openenv(port, openenv_process)\n # Get target from server via HTTP\n target_resp = requests.get(f\"http://localhost:{port}/tasks/{TASK_NAME}\")\n target = target_resp.json()[\"target_fold\"]\n tverts = np.array(target[\"vertices_coords\"])\n for e, a in zip(target[\"edges_vertices\"], target[\"edges_assignment\"]):\n v1, v2 = tverts[e[0]], tverts[e[1]]\n ax2.plot([v1[0], v2[0]], [v1[1], v2[1]],\n color=EDGE_COLORS.get(a, \"gray\"),\n linestyle=EDGE_STYLES.get(a, \"-\"), linewidth=2)\n ax2.scatter(tverts[:, 0], tverts[:, 1], c=\"black\", s=20, zorder=5)\n ax2.grid(True, alpha=0.2)\n\n plt.tight_layout()\n plt.show()\n print(f\"\\nBest FOLD JSON:\\n{json.dumps(best_fold, indent=2)}\")",
189
  "metadata": {},
190
  "execution_count": null,
191
  "outputs": []
 
193
  {
194
  "cell_type": "markdown",
195
  "id": "qlakksqmoe",
196
+ "source": "## 13. Training Logs",
197
  "metadata": {}
198
  },
199
  {