Spaces:
Runtime error
Runtime error
Commit ·
b9a4a95
1
Parent(s): 7fccd0c
Add MAX_CONCURRENT_ENVS, sync latest changes
Browse files- Dockerfile +2 -0
- origami_server/app.py +1 -0
- tests/test_origami.py +14 -8
- training/reward.py +30 -55
- training/train_grpo.py +103 -41
- training/train_origami.ipynb +34 -58
Dockerfile
CHANGED
|
@@ -9,6 +9,8 @@ RUN pip install --no-cache-dir -r requirements.txt \
|
|
| 9 |
|
| 10 |
COPY . /app
|
| 11 |
|
|
|
|
|
|
|
| 12 |
EXPOSE 8000
|
| 13 |
|
| 14 |
CMD ["uvicorn", "origami_server.app:app", "--host", "0.0.0.0", "--port", "8000"]
|
|
|
|
| 9 |
|
| 10 |
COPY . /app
|
| 11 |
|
| 12 |
+
ENV MAX_CONCURRENT_ENVS=16
|
| 13 |
+
|
| 14 |
EXPOSE 8000
|
| 15 |
|
| 16 |
CMD ["uvicorn", "origami_server.app:app", "--host", "0.0.0.0", "--port", "8000"]
|
origami_server/app.py
CHANGED
|
@@ -16,6 +16,7 @@ app = create_app(
|
|
| 16 |
OrigamiAction,
|
| 17 |
OrigamiObservation,
|
| 18 |
env_name="origami_env",
|
|
|
|
| 19 |
)
|
| 20 |
|
| 21 |
from .tasks import TASKS
|
|
|
|
| 16 |
OrigamiAction,
|
| 17 |
OrigamiObservation,
|
| 18 |
env_name="origami_env",
|
| 19 |
+
max_concurrent_envs=int(os.environ.get("MAX_CONCURRENT_ENVS", 1)),
|
| 20 |
)
|
| 21 |
|
| 22 |
from .tasks import TASKS
|
tests/test_origami.py
CHANGED
|
@@ -9,7 +9,7 @@ from origami_server.engine.simulate import simulate
|
|
| 9 |
from origami_server.environment import OrigamiEnvironment
|
| 10 |
from origami_server.models import OrigamiAction
|
| 11 |
from origami_server.tasks import TASKS, get_task, list_tasks
|
| 12 |
-
from training.reward import extract_fold_json,
|
| 13 |
|
| 14 |
# --- Fixtures ---
|
| 15 |
|
|
@@ -221,14 +221,20 @@ class TestRewards:
|
|
| 221 |
assert scores[0] == 1.0
|
| 222 |
assert scores[1] == -2.0
|
| 223 |
|
| 224 |
-
def
|
| 225 |
-
|
|
|
|
|
|
|
| 226 |
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 232 |
|
| 233 |
|
| 234 |
# --- API ---
|
|
|
|
| 9 |
from origami_server.environment import OrigamiEnvironment
|
| 10 |
from origami_server.models import OrigamiAction
|
| 11 |
from origami_server.tasks import TASKS, get_task, list_tasks
|
| 12 |
+
from training.reward import extract_fold_json, valid_fold
|
| 13 |
|
| 14 |
# --- Fixtures ---
|
| 15 |
|
|
|
|
| 221 |
assert scores[0] == 1.0
|
| 222 |
assert scores[1] == -2.0
|
| 223 |
|
| 224 |
+
def test_shape_match_via_server(self):
|
| 225 |
+
"""shape_match reward now goes through the server (WebSocket).
|
| 226 |
+
Test the same flow via TestClient's websocket to verify end-to-end."""
|
| 227 |
+
from fastapi.testclient import TestClient
|
| 228 |
|
| 229 |
+
from origami_server.app import app
|
| 230 |
+
|
| 231 |
+
client = TestClient(app)
|
| 232 |
+
with client.websocket_connect("/ws") as ws:
|
| 233 |
+
ws.send_json({"type": "reset", "data": {"task_name": "triangle"}})
|
| 234 |
+
ws.receive_json()
|
| 235 |
+
ws.send_json({"type": "step", "data": {"fold_data": TRIANGLE_FOLD}})
|
| 236 |
+
resp = ws.receive_json()
|
| 237 |
+
assert resp["data"]["reward"] == 20.0
|
| 238 |
|
| 239 |
|
| 240 |
# --- API ---
|
training/reward.py
CHANGED
|
@@ -1,21 +1,17 @@
|
|
| 1 |
"""GRPO reward functions for origami RL training.
|
| 2 |
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
|
|
|
|
|
|
|
|
|
| 6 |
"""
|
| 7 |
|
| 8 |
import json
|
| 9 |
import re
|
| 10 |
from typing import Any
|
| 11 |
|
| 12 |
-
import numpy as np
|
| 13 |
-
|
| 14 |
-
from origami_server.engine.fold_parser import validate_fold
|
| 15 |
-
from origami_server.engine.shape_match import compute_shape_match
|
| 16 |
-
from origami_server.engine.simulate import simulate
|
| 17 |
-
from origami_server.tasks import get_task
|
| 18 |
-
|
| 19 |
|
| 20 |
def extract_fold_json(response: str) -> dict | None:
|
| 21 |
"""Extract FOLD JSON from LLM response text.
|
|
@@ -55,6 +51,8 @@ def valid_fold(completions: list, **kwargs: Any) -> list[float]:
|
|
| 55 |
+1.0 valid FOLD JSON with correct structure
|
| 56 |
-0.5 parseable JSON but invalid FOLD structure
|
| 57 |
-2.0 not parseable as JSON at all
|
|
|
|
|
|
|
| 58 |
"""
|
| 59 |
scores = []
|
| 60 |
for completion in completions:
|
|
@@ -65,58 +63,35 @@ def valid_fold(completions: list, **kwargs: Any) -> list[float]:
|
|
| 65 |
scores.append(-2.0)
|
| 66 |
continue
|
| 67 |
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
else:
|
| 72 |
scores.append(-0.5)
|
|
|
|
| 73 |
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
def shape_match(
|
| 78 |
-
completions: list,
|
| 79 |
-
task_name: str = "triangle",
|
| 80 |
-
**kwargs: Any,
|
| 81 |
-
) -> list[float]:
|
| 82 |
-
"""Reward 2: Simulate the fold and compare to target shape.
|
| 83 |
-
|
| 84 |
-
Score = similarity × 20.0 (range: 0 to 20)
|
| 85 |
-
-1.0 if simulation fails/diverges
|
| 86 |
-
-2.0 if FOLD data is invalid
|
| 87 |
-
|
| 88 |
-
This is the main reward signal — AlphaFold-style shape comparison.
|
| 89 |
-
"""
|
| 90 |
-
task = get_task(task_name)
|
| 91 |
-
target_fold = task["target_fold"]
|
| 92 |
-
|
| 93 |
-
# Pre-compute target positions
|
| 94 |
-
try:
|
| 95 |
-
target_result = simulate(target_fold, crease_percent=1.0)
|
| 96 |
-
target_positions = target_result.positions
|
| 97 |
-
except Exception:
|
| 98 |
-
# Target itself fails — all scores 0
|
| 99 |
-
return [0.0] * len(completions)
|
| 100 |
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
fold_data = extract_fold_json(response)
|
| 105 |
|
| 106 |
-
|
| 107 |
-
|
|
|
|
|
|
|
| 108 |
continue
|
| 109 |
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
continue
|
| 114 |
|
| 115 |
-
|
| 116 |
-
result = simulate(fold_data, crease_percent=1.0)
|
| 117 |
-
similarity = compute_shape_match(result.positions, target_positions)
|
| 118 |
-
scores.append(similarity * 20.0)
|
| 119 |
-
except Exception:
|
| 120 |
-
scores.append(-1.0)
|
| 121 |
|
| 122 |
return scores
|
|
|
|
| 1 |
"""GRPO reward functions for origami RL training.
|
| 2 |
|
| 3 |
+
Follows the OpenEnv 2048 pattern exactly:
|
| 4 |
+
- launch_openenv() spawns/reuses the origami server
|
| 5 |
+
- Reward functions call the server via EnvClient
|
| 6 |
+
- Server computes simulation + shape matching, returns reward
|
| 7 |
+
|
| 8 |
+
These functions are also importable for use in notebooks.
|
| 9 |
"""
|
| 10 |
|
| 11 |
import json
|
| 12 |
import re
|
| 13 |
from typing import Any
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
def extract_fold_json(response: str) -> dict | None:
|
| 17 |
"""Extract FOLD JSON from LLM response text.
|
|
|
|
| 51 |
+1.0 valid FOLD JSON with correct structure
|
| 52 |
-0.5 parseable JSON but invalid FOLD structure
|
| 53 |
-2.0 not parseable as JSON at all
|
| 54 |
+
|
| 55 |
+
Local check — no server needed.
|
| 56 |
"""
|
| 57 |
scores = []
|
| 58 |
for completion in completions:
|
|
|
|
| 63 |
scores.append(-2.0)
|
| 64 |
continue
|
| 65 |
|
| 66 |
+
# Basic structural validation
|
| 67 |
+
required = {"vertices_coords", "edges_vertices", "edges_assignment"}
|
| 68 |
+
if not required.issubset(fold_data.keys()):
|
|
|
|
| 69 |
scores.append(-0.5)
|
| 70 |
+
continue
|
| 71 |
|
| 72 |
+
verts = fold_data.get("vertices_coords", [])
|
| 73 |
+
edges = fold_data.get("edges_vertices", [])
|
| 74 |
+
assigns = fold_data.get("edges_assignment", [])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
|
| 76 |
+
if len(edges) != len(assigns):
|
| 77 |
+
scores.append(-0.5)
|
| 78 |
+
continue
|
|
|
|
| 79 |
|
| 80 |
+
has_fold = any(a in ("M", "V") for a in assigns)
|
| 81 |
+
has_boundary = any(a == "B" for a in assigns)
|
| 82 |
+
if not has_fold or not has_boundary:
|
| 83 |
+
scores.append(-0.5)
|
| 84 |
continue
|
| 85 |
|
| 86 |
+
n = len(verts)
|
| 87 |
+
valid_indices = all(
|
| 88 |
+
0 <= e[0] < n and 0 <= e[1] < n and e[0] != e[1]
|
| 89 |
+
for e in edges
|
| 90 |
+
)
|
| 91 |
+
if not valid_indices:
|
| 92 |
+
scores.append(-0.5)
|
| 93 |
continue
|
| 94 |
|
| 95 |
+
scores.append(1.0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
|
| 97 |
return scores
|
training/train_grpo.py
CHANGED
|
@@ -1,20 +1,28 @@
|
|
| 1 |
"""GRPO training script for origami RL.
|
| 2 |
|
| 3 |
-
Follows the
|
| 4 |
-
-
|
| 5 |
-
-
|
|
|
|
| 6 |
- GRPOTrainer from TRL handles the RL loop
|
| 7 |
|
| 8 |
-
Usage
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
python -m training.train_grpo --task triangle --max_steps 600
|
| 10 |
|
| 11 |
-
|
| 12 |
-
python -m training.train_grpo --
|
| 13 |
"""
|
| 14 |
|
| 15 |
import argparse
|
|
|
|
| 16 |
import os
|
| 17 |
|
|
|
|
|
|
|
| 18 |
PROMPT_TEMPLATE = """You are an origami designer. Generate a FOLD-format crease pattern
|
| 19 |
that, when folded, produces the target shape described below.
|
| 20 |
|
|
@@ -49,60 +57,109 @@ def main():
|
|
| 49 |
parser = argparse.ArgumentParser(description="GRPO training for origami RL")
|
| 50 |
parser.add_argument("--task", default="triangle", help="Task name")
|
| 51 |
parser.add_argument("--max_steps", type=int, default=600)
|
| 52 |
-
parser.add_argument("--num_generations", type=int, default=
|
| 53 |
-
parser.add_argument("--model", default="
|
| 54 |
parser.add_argument("--lr", type=float, default=2e-4)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
args = parser.parse_args()
|
| 56 |
|
| 57 |
-
# ---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
from datasets import Dataset
|
| 59 |
-
from trl import GRPOConfig, GRPOTrainer
|
| 60 |
|
| 61 |
-
|
| 62 |
-
|
|
|
|
| 63 |
|
| 64 |
-
#
|
| 65 |
try:
|
| 66 |
from unsloth import FastLanguageModel
|
| 67 |
USE_UNSLOTH = True
|
| 68 |
except ImportError:
|
| 69 |
USE_UNSLOTH = False
|
| 70 |
|
| 71 |
-
|
| 72 |
-
prompt_text = build_prompt(task)
|
| 73 |
-
|
| 74 |
-
# Build dataset (1000 copies of same prompt, like 2048)
|
| 75 |
-
dataset = Dataset.from_list(
|
| 76 |
-
[
|
| 77 |
-
{
|
| 78 |
-
"prompt": [{"role": "user", "content": prompt_text}],
|
| 79 |
-
"answer": 0,
|
| 80 |
-
}
|
| 81 |
-
]
|
| 82 |
-
* 1000
|
| 83 |
-
)
|
| 84 |
|
| 85 |
-
# Load model with LoRA
|
| 86 |
if USE_UNSLOTH:
|
|
|
|
| 87 |
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 88 |
model_name=args.model,
|
| 89 |
load_in_4bit=True,
|
| 90 |
-
max_seq_length=
|
|
|
|
| 91 |
)
|
| 92 |
model = FastLanguageModel.get_peft_model(
|
| 93 |
model,
|
| 94 |
-
r=
|
| 95 |
target_modules=[
|
| 96 |
"q_proj", "k_proj", "v_proj", "o_proj",
|
| 97 |
"gate_proj", "up_proj", "down_proj",
|
| 98 |
],
|
| 99 |
-
lora_alpha=
|
| 100 |
use_gradient_checkpointing="unsloth",
|
|
|
|
| 101 |
)
|
| 102 |
else:
|
| 103 |
import torch
|
| 104 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
| 105 |
from peft import LoraConfig, get_peft_model
|
|
|
|
| 106 |
|
| 107 |
bnb_config = BitsAndBytesConfig(
|
| 108 |
load_in_4bit=True,
|
|
@@ -118,19 +175,23 @@ def main():
|
|
| 118 |
torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
|
| 119 |
)
|
| 120 |
model = get_peft_model(model, LoraConfig(
|
| 121 |
-
r=
|
| 122 |
-
|
| 123 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
))
|
| 125 |
|
| 126 |
if tokenizer.pad_token is None:
|
| 127 |
tokenizer.pad_token = tokenizer.eos_token
|
| 128 |
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
|
|
|
| 132 |
|
| 133 |
-
# GRPO config
|
| 134 |
training_args = GRPOConfig(
|
| 135 |
temperature=1.0,
|
| 136 |
learning_rate=args.lr,
|
|
@@ -142,8 +203,8 @@ def main():
|
|
| 142 |
per_device_train_batch_size=1,
|
| 143 |
gradient_accumulation_steps=1,
|
| 144 |
num_generations=args.num_generations,
|
| 145 |
-
max_prompt_length=
|
| 146 |
-
max_completion_length=
|
| 147 |
max_steps=args.max_steps,
|
| 148 |
save_steps=100,
|
| 149 |
output_dir=os.environ.get("OUTPUT_DIR", "outputs"),
|
|
@@ -157,9 +218,10 @@ def main():
|
|
| 157 |
train_dataset=dataset,
|
| 158 |
)
|
| 159 |
|
|
|
|
| 160 |
trainer.train()
|
| 161 |
|
| 162 |
-
# Save
|
| 163 |
save_path = os.path.join(
|
| 164 |
os.environ.get("OUTPUT_DIR", "outputs"),
|
| 165 |
f"origami-{args.task}-lora-final",
|
|
|
|
| 1 |
"""GRPO training script for origami RL.
|
| 2 |
|
| 3 |
+
Follows the OpenEnv 2048 pattern exactly:
|
| 4 |
+
- Environment runs as a FastAPI server (origami_server.app)
|
| 5 |
+
- Training connects via WebSocket client (OrigamiEnv)
|
| 6 |
+
- Reward functions call the server, never import engine code
|
| 7 |
- GRPOTrainer from TRL handles the RL loop
|
| 8 |
|
| 9 |
+
Usage:
|
| 10 |
+
# 1. Start the environment server first:
|
| 11 |
+
uvicorn origami_server.app:app --host 0.0.0.0 --port 8000
|
| 12 |
+
|
| 13 |
+
# 2. Run training (connects to server):
|
| 14 |
python -m training.train_grpo --task triangle --max_steps 600
|
| 15 |
|
| 16 |
+
# Or specify server URL:
|
| 17 |
+
python -m training.train_grpo --server http://gpu-host:8000
|
| 18 |
"""
|
| 19 |
|
| 20 |
import argparse
|
| 21 |
+
import functools
|
| 22 |
import os
|
| 23 |
|
| 24 |
+
import requests
|
| 25 |
+
|
| 26 |
PROMPT_TEMPLATE = """You are an origami designer. Generate a FOLD-format crease pattern
|
| 27 |
that, when folded, produces the target shape described below.
|
| 28 |
|
|
|
|
| 57 |
parser = argparse.ArgumentParser(description="GRPO training for origami RL")
|
| 58 |
parser.add_argument("--task", default="triangle", help="Task name")
|
| 59 |
parser.add_argument("--max_steps", type=int, default=600)
|
| 60 |
+
parser.add_argument("--num_generations", type=int, default=2)
|
| 61 |
+
parser.add_argument("--model", default="unsloth/Qwen3-14B")
|
| 62 |
parser.add_argument("--lr", type=float, default=2e-4)
|
| 63 |
+
parser.add_argument("--lora_rank", type=int, default=4)
|
| 64 |
+
parser.add_argument(
|
| 65 |
+
"--server", default="http://localhost:8000",
|
| 66 |
+
help="URL of the origami environment server",
|
| 67 |
+
)
|
| 68 |
args = parser.parse_args()
|
| 69 |
|
| 70 |
+
# --- Verify server is running ---
|
| 71 |
+
print(f"Connecting to environment server at {args.server}...")
|
| 72 |
+
try:
|
| 73 |
+
r = requests.get(f"{args.server}/health", timeout=5)
|
| 74 |
+
assert r.status_code == 200
|
| 75 |
+
print("Server is healthy.")
|
| 76 |
+
except Exception as e:
|
| 77 |
+
print(f"ERROR: Cannot connect to server at {args.server}")
|
| 78 |
+
print(f"Start it first: uvicorn origami_server.app:app --port 8000")
|
| 79 |
+
raise SystemExit(1)
|
| 80 |
+
|
| 81 |
+
# --- Get task info from server ---
|
| 82 |
+
task = requests.get(f"{args.server}/tasks/{args.task}").json()
|
| 83 |
+
prompt_text = build_prompt(task)
|
| 84 |
+
print(f"Task: {task['name']} — {task['description']}")
|
| 85 |
+
|
| 86 |
+
# --- Configure reward functions (OpenEnv pattern) ---
|
| 87 |
+
from client import OrigamiEnv
|
| 88 |
+
from origami_server.models import OrigamiAction
|
| 89 |
+
from training.reward import extract_fold_json, valid_fold
|
| 90 |
+
from unsloth import is_port_open, launch_openenv
|
| 91 |
+
|
| 92 |
+
global port, openenv_process
|
| 93 |
+
port = int(args.server.split(":")[-1]) if ":" in args.server else 8000
|
| 94 |
+
openenv_process = None
|
| 95 |
+
|
| 96 |
+
launch_openenv = functools.partial(
|
| 97 |
+
launch_openenv,
|
| 98 |
+
working_directory=os.getcwd(),
|
| 99 |
+
server="origami_server.app:app",
|
| 100 |
+
environment={**os.environ, "PYTHONPATH": os.getcwd()},
|
| 101 |
+
openenv_class=OrigamiEnv,
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
def shape_match_reward(completions, **kwargs):
|
| 105 |
+
global port, openenv_process
|
| 106 |
+
scores = []
|
| 107 |
+
for completion in completions:
|
| 108 |
+
response = completion[0]["content"]
|
| 109 |
+
fold_data = extract_fold_json(response)
|
| 110 |
+
if fold_data is None:
|
| 111 |
+
scores.append(0.0)
|
| 112 |
+
continue
|
| 113 |
+
try:
|
| 114 |
+
port, openenv_process = launch_openenv(port, openenv_process)
|
| 115 |
+
openenv_process.reset(task_name=args.task)
|
| 116 |
+
result = openenv_process.step(OrigamiAction(fold_data=fold_data))
|
| 117 |
+
scores.append(result.reward if result.reward is not None else 0.0)
|
| 118 |
+
except TimeoutError:
|
| 119 |
+
scores.append(-1.0)
|
| 120 |
+
except Exception:
|
| 121 |
+
scores.append(-3.0)
|
| 122 |
+
return scores
|
| 123 |
+
|
| 124 |
+
# --- Build dataset (same prompt repeated, like 2048) ---
|
| 125 |
from datasets import Dataset
|
|
|
|
| 126 |
|
| 127 |
+
dataset = Dataset.from_list(
|
| 128 |
+
[{"prompt": [{"role": "user", "content": prompt_text}]}] * 1000
|
| 129 |
+
)
|
| 130 |
|
| 131 |
+
# --- Load model with QLoRA ---
|
| 132 |
try:
|
| 133 |
from unsloth import FastLanguageModel
|
| 134 |
USE_UNSLOTH = True
|
| 135 |
except ImportError:
|
| 136 |
USE_UNSLOTH = False
|
| 137 |
|
| 138 |
+
max_seq_length = 768 # FOLD JSON is compact
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
|
|
|
|
| 140 |
if USE_UNSLOTH:
|
| 141 |
+
print(f"Loading {args.model} with Unsloth QLoRA (rank={args.lora_rank})...")
|
| 142 |
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 143 |
model_name=args.model,
|
| 144 |
load_in_4bit=True,
|
| 145 |
+
max_seq_length=max_seq_length,
|
| 146 |
+
offload_embedding=True, # Needed for 14B on limited VRAM
|
| 147 |
)
|
| 148 |
model = FastLanguageModel.get_peft_model(
|
| 149 |
model,
|
| 150 |
+
r=args.lora_rank,
|
| 151 |
target_modules=[
|
| 152 |
"q_proj", "k_proj", "v_proj", "o_proj",
|
| 153 |
"gate_proj", "up_proj", "down_proj",
|
| 154 |
],
|
| 155 |
+
lora_alpha=args.lora_rank * 2,
|
| 156 |
use_gradient_checkpointing="unsloth",
|
| 157 |
+
random_state=3407,
|
| 158 |
)
|
| 159 |
else:
|
| 160 |
import torch
|
|
|
|
| 161 |
from peft import LoraConfig, get_peft_model
|
| 162 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
| 163 |
|
| 164 |
bnb_config = BitsAndBytesConfig(
|
| 165 |
load_in_4bit=True,
|
|
|
|
| 175 |
torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
|
| 176 |
)
|
| 177 |
model = get_peft_model(model, LoraConfig(
|
| 178 |
+
r=args.lora_rank,
|
| 179 |
+
lora_alpha=args.lora_rank * 2,
|
| 180 |
+
task_type="CAUSAL_LM",
|
| 181 |
+
target_modules=[
|
| 182 |
+
"q_proj", "k_proj", "v_proj", "o_proj",
|
| 183 |
+
"gate_proj", "up_proj", "down_proj",
|
| 184 |
+
],
|
| 185 |
))
|
| 186 |
|
| 187 |
if tokenizer.pad_token is None:
|
| 188 |
tokenizer.pad_token = tokenizer.eos_token
|
| 189 |
|
| 190 |
+
model.print_trainable_parameters()
|
| 191 |
+
|
| 192 |
+
# --- GRPO config (matches 2048 pattern) ---
|
| 193 |
+
from trl import GRPOConfig, GRPOTrainer
|
| 194 |
|
|
|
|
| 195 |
training_args = GRPOConfig(
|
| 196 |
temperature=1.0,
|
| 197 |
learning_rate=args.lr,
|
|
|
|
| 203 |
per_device_train_batch_size=1,
|
| 204 |
gradient_accumulation_steps=1,
|
| 205 |
num_generations=args.num_generations,
|
| 206 |
+
max_prompt_length=512,
|
| 207 |
+
max_completion_length=max_seq_length - 512,
|
| 208 |
max_steps=args.max_steps,
|
| 209 |
save_steps=100,
|
| 210 |
output_dir=os.environ.get("OUTPUT_DIR", "outputs"),
|
|
|
|
| 218 |
train_dataset=dataset,
|
| 219 |
)
|
| 220 |
|
| 221 |
+
print(f"Training: {args.max_steps} steps, {args.num_generations} generations/step")
|
| 222 |
trainer.train()
|
| 223 |
|
| 224 |
+
# Save LoRA adapter
|
| 225 |
save_path = os.path.join(
|
| 226 |
os.environ.get("OUTPUT_DIR", "outputs"),
|
| 227 |
f"origami-{args.task}-lora-final",
|
training/train_origami.ipynb
CHANGED
|
@@ -3,7 +3,7 @@
|
|
| 3 |
{
|
| 4 |
"cell_type": "markdown",
|
| 5 |
"id": "p8uwc5bkc4n",
|
| 6 |
-
"source": "# Origami RL — GRPO Training
|
| 7 |
"metadata": {}
|
| 8 |
},
|
| 9 |
{
|
|
@@ -15,7 +15,7 @@
|
|
| 15 |
{
|
| 16 |
"cell_type": "code",
|
| 17 |
"id": "ulhu8a5p5ti",
|
| 18 |
-
"source": "
|
| 19 |
"metadata": {},
|
| 20 |
"execution_count": null,
|
| 21 |
"outputs": []
|
|
@@ -23,13 +23,13 @@
|
|
| 23 |
{
|
| 24 |
"cell_type": "markdown",
|
| 25 |
"id": "qcetkmcq1hf",
|
| 26 |
-
"source": "## 2.
|
| 27 |
"metadata": {}
|
| 28 |
},
|
| 29 |
{
|
| 30 |
"cell_type": "code",
|
| 31 |
"id": "3hr273dhqiv",
|
| 32 |
-
"source": "
|
| 33 |
"metadata": {},
|
| 34 |
"execution_count": null,
|
| 35 |
"outputs": []
|
|
@@ -37,7 +37,7 @@
|
|
| 37 |
{
|
| 38 |
"cell_type": "code",
|
| 39 |
"id": "bnm2w57r3lc",
|
| 40 |
-
"source": "
|
| 41 |
"metadata": {},
|
| 42 |
"execution_count": null,
|
| 43 |
"outputs": []
|
|
@@ -45,43 +45,13 @@
|
|
| 45 |
{
|
| 46 |
"cell_type": "markdown",
|
| 47 |
"id": "lcaus7mtuj",
|
| 48 |
-
"source": "## 3.
|
| 49 |
"metadata": {}
|
| 50 |
},
|
| 51 |
{
|
| 52 |
"cell_type": "code",
|
| 53 |
"id": "hlqp4y30m87",
|
| 54 |
-
"source": "
|
| 55 |
-
"metadata": {},
|
| 56 |
-
"execution_count": null,
|
| 57 |
-
"outputs": []
|
| 58 |
-
},
|
| 59 |
-
{
|
| 60 |
-
"cell_type": "code",
|
| 61 |
-
"id": "dwqqus8mhlj",
|
| 62 |
-
"source": "# Test the simulator on each task\nfor name in list_tasks():\n task = get_task(name)\n target_fold = task[\"target_fold\"]\n \n # Simulate flat (0%), half (50%), and fully folded (100%)\n r_flat = simulate(target_fold, crease_percent=0.0)\n r_half = simulate(target_fold, crease_percent=0.5)\n r_full = simulate(target_fold, crease_percent=1.0)\n \n z_half = r_half.positions[:, 2].max() - r_half.positions[:, 2].min()\n \n # Shape match: target vs itself should be 1.0\n self_sim = compute_shape_match(r_full.positions, r_full.positions)\n \n print(f\"{name:15s} | converged={r_full.converged} | strain={r_full.max_strain:.6f} | \"\n f\"z_range@50%={z_half:.3f} | self_similarity={self_sim:.3f}\")",
|
| 63 |
-
"metadata": {},
|
| 64 |
-
"execution_count": null,
|
| 65 |
-
"outputs": []
|
| 66 |
-
},
|
| 67 |
-
{
|
| 68 |
-
"cell_type": "code",
|
| 69 |
-
"id": "p1weq9kv5q",
|
| 70 |
-
"source": "# Test reward functions with mock LLM outputs\ntriangle_fold = TASKS[\"triangle\"][\"target_fold\"]\n\n# Simulate what the reward functions see during training:\n# completions = list of [{\"content\": \"...LLM response...\"}]\ngood_response = json.dumps(triangle_fold)\nbad_json = \"I think we should fold it like this...\"\ninvalid_fold = json.dumps({\"vertices_coords\": [[0, 0]], \"edges_vertices\": [], \"edges_assignment\": []})\n\ncompletions = [\n [{\"content\": f\"```json\\n{good_response}\\n```\"}], # correct answer in fenced block\n [{\"content\": bad_json}], # garbage\n [{\"content\": invalid_fold}], # parseable but invalid FOLD\n]\n\nprint(\"valid_fold rewards:\", valid_fold(completions))\nprint(\"shape_match rewards:\", shape_match(completions, task_name=\"triangle\"))\nprint()\nprint(\"Expected: valid_fold = [1.0, -2.0, -0.5]\")\nprint(\"Expected: shape_match = [20.0, -2.0, -1.0]\")",
|
| 71 |
-
"metadata": {},
|
| 72 |
-
"execution_count": null,
|
| 73 |
-
"outputs": []
|
| 74 |
-
},
|
| 75 |
-
{
|
| 76 |
-
"cell_type": "markdown",
|
| 77 |
-
"id": "45l0n1hgvr",
|
| 78 |
-
"source": "## 4. Visualize Tasks\n\n2D crease patterns for each task (matplotlib).",
|
| 79 |
-
"metadata": {}
|
| 80 |
-
},
|
| 81 |
-
{
|
| 82 |
-
"cell_type": "code",
|
| 83 |
-
"id": "fkopb9lgg7i",
|
| 84 |
-
"source": "import matplotlib.pyplot as plt\nfrom mpl_toolkits.mplot3d import Axes3D\nfrom mpl_toolkits.mplot3d.art3d import Poly3DCollection\n\nEDGE_COLORS = {\"M\": \"red\", \"V\": \"blue\", \"B\": \"black\"}\nEDGE_STYLES = {\"M\": \"--\", \"V\": \":\", \"B\": \"-\"}\n\nfig, axes = plt.subplots(2, 4, figsize=(16, 8))\n\nfor idx, (name, task) in enumerate(TASKS.items()):\n fold = task[\"target_fold\"]\n verts = np.array(fold[\"vertices_coords\"])\n \n # Row 1: 2D crease pattern\n ax = axes[0, idx]\n ax.set_title(f\"{name}\\n{task['description']}\", fontsize=9)\n ax.set_aspect(\"equal\")\n ax.set_xlim(-0.1, 1.1)\n ax.set_ylim(-0.1, 1.1)\n ax.grid(True, alpha=0.2)\n \n for i, (e, a) in enumerate(zip(fold[\"edges_vertices\"], fold[\"edges_assignment\"])):\n v1, v2 = verts[e[0]], verts[e[1]]\n color = EDGE_COLORS.get(a, \"gray\")\n style = EDGE_STYLES.get(a, \"-\")\n lw = 2.5 if a == \"B\" else 1.8\n ax.plot([v1[0], v2[0]], [v1[1], v2[1]], color=color, linestyle=style, linewidth=lw)\n \n ax.scatter(verts[:, 0], verts[:, 1], c=\"black\", s=15, zorder=5)\n \n # Row 2: 3D folded shape\n ax3 = fig.add_subplot(2, 4, idx + 5, projection=\"3d\")\n result = simulate(fold, crease_percent=1.0)\n pos = result.positions\n \n if \"faces_vertices\" in fold:\n for face in fold[\"faces_vertices\"]:\n tri_verts = [pos[vi] for vi in face]\n poly = Poly3DCollection([tri_verts], alpha=0.3, facecolor=\"lightskyblue\", edgecolor=\"steelblue\")\n ax3.add_collection3d(poly)\n \n for i, (e, a) in enumerate(zip(fold[\"edges_vertices\"], fold[\"edges_assignment\"])):\n p1, p2 = pos[e[0]], pos[e[1]]\n color = EDGE_COLORS.get(a, \"gray\")\n ax3.plot([p1[0], p2[0]], [p1[1], p2[1]], [p1[2], p2[2]], color=color, linewidth=1.2)\n \n ax3.scatter(pos[:, 0], pos[:, 1], pos[:, 2], c=\"black\", s=10, zorder=5)\n ax3.set_title(f\"Folded (3D)\", fontsize=9)\n ax3.set_xlim(-0.2, 1.2)\n ax3.set_ylim(-0.2, 1.2)\n ax3.set_zlim(-0.6, 0.6)\n \n # Remove the empty 2D subplot that was in row 2\n axes[1, 
idx].remove()\n\nplt.tight_layout()\nplt.show()",
|
| 85 |
"metadata": {},
|
| 86 |
"execution_count": null,
|
| 87 |
"outputs": []
|
|
@@ -89,13 +59,13 @@
|
|
| 89 |
{
|
| 90 |
"cell_type": "markdown",
|
| 91 |
"id": "a14w2fkoewq",
|
| 92 |
-
"source": "##
|
| 93 |
"metadata": {}
|
| 94 |
},
|
| 95 |
{
|
| 96 |
"cell_type": "code",
|
| 97 |
"id": "2phdejbobq3",
|
| 98 |
-
"source": "
|
| 99 |
"metadata": {},
|
| 100 |
"execution_count": null,
|
| 101 |
"outputs": []
|
|
@@ -103,13 +73,13 @@
|
|
| 103 |
{
|
| 104 |
"cell_type": "markdown",
|
| 105 |
"id": "feal20fr8j5",
|
| 106 |
-
"source": "##
|
| 107 |
"metadata": {}
|
| 108 |
},
|
| 109 |
{
|
| 110 |
"cell_type": "code",
|
| 111 |
"id": "uo7zh1dwp6r",
|
| 112 |
-
"source": "
|
| 113 |
"metadata": {},
|
| 114 |
"execution_count": null,
|
| 115 |
"outputs": []
|
|
@@ -117,7 +87,7 @@
|
|
| 117 |
{
|
| 118 |
"cell_type": "code",
|
| 119 |
"id": "900vyqwb8g",
|
| 120 |
-
"source": "
|
| 121 |
"metadata": {},
|
| 122 |
"execution_count": null,
|
| 123 |
"outputs": []
|
|
@@ -125,48 +95,54 @@
|
|
| 125 |
{
|
| 126 |
"cell_type": "markdown",
|
| 127 |
"id": "xn6n1hpx2aa",
|
| 128 |
-
"source": "##
|
| 129 |
"metadata": {}
|
| 130 |
},
|
| 131 |
{
|
| 132 |
"cell_type": "code",
|
| 133 |
"id": "vkfaeuu9dq",
|
| 134 |
-
"source": "import
|
| 135 |
"metadata": {},
|
| 136 |
"execution_count": null,
|
| 137 |
"outputs": []
|
| 138 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
{
|
| 140 |
"cell_type": "code",
|
| 141 |
-
"id": "
|
| 142 |
-
"source": "
|
| 143 |
"metadata": {},
|
| 144 |
"execution_count": null,
|
| 145 |
"outputs": []
|
| 146 |
},
|
| 147 |
{
|
| 148 |
"cell_type": "markdown",
|
| 149 |
-
"id": "
|
| 150 |
-
"source": "## 8.
|
| 151 |
"metadata": {}
|
| 152 |
},
|
| 153 |
{
|
| 154 |
"cell_type": "code",
|
| 155 |
-
"id": "
|
| 156 |
-
"source": "from trl import GRPOConfig, GRPOTrainer\n\
|
| 157 |
"metadata": {},
|
| 158 |
"execution_count": null,
|
| 159 |
"outputs": []
|
| 160 |
},
|
| 161 |
{
|
| 162 |
"cell_type": "markdown",
|
| 163 |
-
"id": "
|
| 164 |
"source": "## 9. Train!",
|
| 165 |
"metadata": {}
|
| 166 |
},
|
| 167 |
{
|
| 168 |
"cell_type": "code",
|
| 169 |
-
"id": "
|
| 170 |
"source": "trainer.train()",
|
| 171 |
"metadata": {},
|
| 172 |
"execution_count": null,
|
|
@@ -181,7 +157,7 @@
|
|
| 181 |
{
|
| 182 |
"cell_type": "code",
|
| 183 |
"id": "t3d4tu6o5mc",
|
| 184 |
-
"source": "
|
| 185 |
"metadata": {},
|
| 186 |
"execution_count": null,
|
| 187 |
"outputs": []
|
|
@@ -189,13 +165,13 @@
|
|
| 189 |
{
|
| 190 |
"cell_type": "markdown",
|
| 191 |
"id": "q18eizy1ok",
|
| 192 |
-
"source": "## 11. Evaluate — Generate & Score
|
| 193 |
"metadata": {}
|
| 194 |
},
|
| 195 |
{
|
| 196 |
"cell_type": "code",
|
| 197 |
"id": "on56augj41",
|
| 198 |
-
"source": "
|
| 199 |
"metadata": {},
|
| 200 |
"execution_count": null,
|
| 201 |
"outputs": []
|
|
@@ -203,13 +179,13 @@
|
|
| 203 |
{
|
| 204 |
"cell_type": "markdown",
|
| 205 |
"id": "tb1y8hszrk",
|
| 206 |
-
"source": "## 12. Visualize
|
| 207 |
"metadata": {}
|
| 208 |
},
|
| 209 |
{
|
| 210 |
"cell_type": "code",
|
| 211 |
"id": "0zo3krbkiqej",
|
| 212 |
-
"source": "
|
| 213 |
"metadata": {},
|
| 214 |
"execution_count": null,
|
| 215 |
"outputs": []
|
|
@@ -217,7 +193,7 @@
|
|
| 217 |
{
|
| 218 |
"cell_type": "markdown",
|
| 219 |
"id": "qlakksqmoe",
|
| 220 |
-
"source": "## 13.
|
| 221 |
"metadata": {}
|
| 222 |
},
|
| 223 |
{
|
|
|
|
| 3 |
{
|
| 4 |
"cell_type": "markdown",
|
| 5 |
"id": "p8uwc5bkc4n",
|
| 6 |
+
"source": "# Origami RL — GRPO Training\n\nTrain an LLM to generate FOLD crease patterns using OpenEnv + Unsloth + TRL.\n\nFollows the [2048 OpenEnv notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/OpenEnv_gpt_oss_(20B)_Reinforcement_Learning_2048_Game.ipynb) pattern exactly:\n1. `launch_openenv()` spawns the origami environment server\n2. LLM generates FOLD JSON crease patterns\n3. Reward functions call the server via OpenEnv client\n4. GRPO updates policy based on relative rewards",
|
| 7 |
"metadata": {}
|
| 8 |
},
|
| 9 |
{
|
|
|
|
| 15 |
{
|
| 16 |
"cell_type": "code",
|
| 17 |
"id": "ulhu8a5p5ti",
|
| 18 |
+
"source": "%%capture\nimport os, importlib.util\n!pip install --upgrade -qqq uv\nif importlib.util.find_spec(\"torch\") is None or \"COLAB_\" in \"\".join(os.environ.keys()):\n try: import numpy; get_numpy = f\"numpy=={numpy.__version__}\"\n except: get_numpy = \"numpy\"\n !uv pip install -qqq \\\n \"torch>=2.8.0\" \"triton>=3.4.0\" {get_numpy} torchvision bitsandbytes \"transformers==4.56.2\" trackio \\\n \"unsloth_zoo[base] @ git+https://github.com/unslothai/unsloth-zoo\" \\\n \"unsloth[base] @ git+https://github.com/unslothai/unsloth\"\nelif importlib.util.find_spec(\"unsloth\") is None:\n !uv pip install -qqq unsloth trackio\n!uv pip install --upgrade --no-deps transformers==4.56.2 tokenizers trl==0.22.2 unsloth unsloth_zoo\n!pip install -qqq fastapi uvicorn requests numpy scipy pydantic",
|
| 19 |
"metadata": {},
|
| 20 |
"execution_count": null,
|
| 21 |
"outputs": []
|
|
|
|
| 23 |
{
|
| 24 |
"cell_type": "markdown",
|
| 25 |
"id": "qcetkmcq1hf",
|
| 26 |
+
"source": "## 2. Clone Origami Env + Setup Paths",
|
| 27 |
"metadata": {}
|
| 28 |
},
|
| 29 |
{
|
| 30 |
"cell_type": "code",
|
| 31 |
"id": "3hr273dhqiv",
|
| 32 |
+
"source": "%%capture\n# Clone the origami env repo (skip if running locally)\nimport subprocess, sys, os\nfrom pathlib import Path\n\nREPO_URL = \"https://github.com/YOUR_USERNAME/origami_env.git\" # TODO: update with your repo\nLOCAL_DIR = \"origami_env\"\n\nif not Path(LOCAL_DIR).exists():\n # Running on Colab β clone the repo\n !git clone {REPO_URL} {LOCAL_DIR} > /dev/null 2>&1\n !pip install -e {LOCAL_DIR} > /dev/null 2>&1\n\n# Add repo to Python path\nworking_directory = str(Path(LOCAL_DIR).absolute()) if Path(LOCAL_DIR).exists() else str(Path.cwd().parent.absolute())\nsys.path.insert(0, working_directory)\nprint(f\"Working directory: {working_directory}\")",
|
| 33 |
"metadata": {},
|
| 34 |
"execution_count": null,
|
| 35 |
"outputs": []
|
|
|
|
| 37 |
{
|
| 38 |
"cell_type": "code",
|
| 39 |
"id": "bnm2w57r3lc",
|
| 40 |
+
"source": "# Import OpenEnv client + models (same pattern as 2048 notebook)\nfrom client import OrigamiEnv\nfrom origami_server.models import OrigamiAction, OrigamiObservation, OrigamiState\nprint(\"Origami OpenEnv modules loaded.\")",
|
| 41 |
"metadata": {},
|
| 42 |
"execution_count": null,
|
| 43 |
"outputs": []
|
|
|
|
| 45 |
{
|
| 46 |
"cell_type": "markdown",
|
| 47 |
"id": "lcaus7mtuj",
|
| 48 |
+
"source": "## 3. Load Model + QLoRA",
|
| 49 |
"metadata": {}
|
| 50 |
},
|
| 51 |
{
|
| 52 |
"cell_type": "code",
|
| 53 |
"id": "hlqp4y30m87",
|
| 54 |
+
"source": "from unsloth import FastLanguageModel\nimport torch\n\nmax_seq_length = 768\nlora_rank = 4\n\nmodel, tokenizer = FastLanguageModel.from_pretrained(\n model_name = \"unsloth/Qwen3-14B\",\n load_in_4bit = True,\n max_seq_length = max_seq_length,\n offload_embedding = True,\n)",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
"metadata": {},
|
| 56 |
"execution_count": null,
|
| 57 |
"outputs": []
|
|
|
|
| 59 |
{
|
| 60 |
"cell_type": "markdown",
|
| 61 |
"id": "a14w2fkoewq",
|
| 62 |
+
"source": "## 4. LoRA Adapter",
|
| 63 |
"metadata": {}
|
| 64 |
},
|
| 65 |
{
|
| 66 |
"cell_type": "code",
|
| 67 |
"id": "2phdejbobq3",
|
| 68 |
+
"source": "model = FastLanguageModel.get_peft_model(\n model,\n r = lora_rank,\n target_modules = [\n \"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n \"gate_proj\", \"up_proj\", \"down_proj\",\n ],\n lora_alpha = lora_rank * 2,\n use_gradient_checkpointing = \"unsloth\",\n random_state = 3407,\n)",
|
| 69 |
"metadata": {},
|
| 70 |
"execution_count": null,
|
| 71 |
"outputs": []
|
|
|
|
| 73 |
{
|
| 74 |
"cell_type": "markdown",
|
| 75 |
"id": "feal20fr8j5",
|
| 76 |
+
"source": "## 5. Launch OpenEnv Server",
|
| 77 |
"metadata": {}
|
| 78 |
},
|
| 79 |
{
|
| 80 |
"cell_type": "code",
|
| 81 |
"id": "uo7zh1dwp6r",
|
| 82 |
+
"source": "# Launch origami environment server (same pattern as 2048 notebook)\nglobal port\nglobal openenv_process\nport = 8000\nopenenv_process = None\nserver = \"origami_server.app:app\"\nenvironment = {\n **os.environ,\n \"PYTHONPATH\": working_directory,\n}\n\n# Augment Unsloth's launch_openenv with our config\nimport functools\nfrom unsloth import is_port_open, launch_openenv\nlaunch_openenv = functools.partial(\n launch_openenv,\n working_directory = working_directory,\n server = server,\n environment = environment,\n openenv_class = OrigamiEnv,\n)",
|
| 83 |
"metadata": {},
|
| 84 |
"execution_count": null,
|
| 85 |
"outputs": []
|
|
|
|
| 87 |
{
|
| 88 |
"cell_type": "code",
|
| 89 |
"id": "900vyqwb8g",
|
| 90 |
+
"source": "# Test the connection β reset and inspect\nport, openenv_process = launch_openenv(port, openenv_process)\nresult = openenv_process.reset(task_name=\"triangle\")\nprint(f\"Server running on port {port}\")\nprint(f\"Observation: done={result.done}, reward={result.reward}\")\nprint(f\"Task: {result.observation.task}\")",
|
| 91 |
"metadata": {},
|
| 92 |
"execution_count": null,
|
| 93 |
"outputs": []
|
|
|
|
| 95 |
{
|
| 96 |
"cell_type": "markdown",
|
| 97 |
"id": "xn6n1hpx2aa",
|
| 98 |
+
"source": "## 6. Prompt + Dataset",
|
| 99 |
"metadata": {}
|
| 100 |
},
|
| 101 |
{
|
| 102 |
"cell_type": "code",
|
| 103 |
"id": "vkfaeuu9dq",
|
| 104 |
+
"source": "import requests\n\nTASK_NAME = \"triangle\" # \"triangle\", \"half_fold\", \"quarter_fold\", \"letter_fold\"\n\n# Fetch task params from the server (paper size, description, etc.)\ntask_info = requests.get(f\"http://localhost:{port}/tasks/{TASK_NAME}\").json()\n\nPROMPT_TEMPLATE = \"\"\"You are an origami designer. Generate a FOLD-format crease pattern\nthat, when folded, produces the target shape described below.\n\nTarget: {description}\nPaper size: {width} x {height}\n\nOutput a JSON object with these exact fields:\n- vertices_coords: [[x, y], ...] β 2D positions on the flat paper (0 to {width} for x, 0 to {height} for y)\n- edges_vertices: [[v1, v2], ...] β pairs of vertex indices forming edges\n- edges_assignment: [\"B\"|\"M\"|\"V\", ...] β B=boundary, M=mountain fold, V=valley fold\n- edges_foldAngle: [angle, ...] β fold angles in degrees (V: 180, M: -180, B: 0)\n\nRules:\n- Boundary edges (B) must outline the paper rectangle\n- At least one fold crease (M or V) must exist\n- All vertex indices must be valid (0 to N-1)\n\nOutput ONLY the JSON object wrapped in ```json ... ``` markers.\"\"\"\n\nprompt = PROMPT_TEMPLATE.format(\n description=task_info[\"description\"],\n width=task_info[\"paper\"][\"width\"],\n height=task_info[\"paper\"][\"height\"],\n).strip()\n\n# Build dataset β same prompt repeated 1000x (identical to 2048 pattern)\nfrom datasets import Dataset\ndataset = Dataset.from_list([{\n \"prompt\": [{\"role\": \"user\", \"content\": prompt}],\n}] * 1000)\n\nprint(f\"Task: {task_info['name']} β {task_info['description']}\")\nprint(f\"Paper: {task_info['paper']['width']} x {task_info['paper']['height']}\")\nprint(f\"Difficulty: {task_info['difficulty']}\")\nprint(f\"Dataset: {len(dataset)} rows\")\nprint(f\"\\nPrompt:\\n{prompt[:200]}...\")",
|
| 105 |
"metadata": {},
|
| 106 |
"execution_count": null,
|
| 107 |
"outputs": []
|
| 108 |
},
|
| 109 |
+
{
|
| 110 |
+
"cell_type": "markdown",
|
| 111 |
+
"id": "3f7ritml396",
|
| 112 |
+
"source": "## 7. Reward Functions\n\nTwo reward functions (same pattern as 2048 notebook):\n- `valid_fold` — local JSON structure check (fast, no server call)\n- `shape_match` — calls the origami server via `launch_openenv`, submits the fold, returns similarity × 20",
|
| 113 |
+
"metadata": {}
|
| 114 |
+
},
|
| 115 |
{
|
| 116 |
"cell_type": "code",
|
| 117 |
+
"id": "4dqsw30e9nq",
|
| 118 |
+
"source": "import json, re\n\n# --- Reward 1: valid_fold (local check, no server needed) ---\n\ndef extract_fold_json(response):\n \"\"\"Extract FOLD JSON from LLM response text.\"\"\"\n # Try fenced code block\n match = re.search(r\"```(?:json)?\\s*(\\{.*?\\})\\s*```\", response, re.DOTALL)\n if match:\n try: return json.loads(match.group(1))\n except json.JSONDecodeError: pass\n # Try raw JSON with vertices_coords\n match = re.search(r\"\\{[^{}]*\\\"vertices_coords\\\"[^{}]*\\}\", response, re.DOTALL)\n if match:\n try: return json.loads(match.group(0))\n except json.JSONDecodeError: pass\n # Try whole response\n try:\n data = json.loads(response.strip())\n if isinstance(data, dict) and \"vertices_coords\" in data:\n return data\n except (json.JSONDecodeError, ValueError): pass\n return None\n\ndef valid_fold(completions, **kwargs):\n \"\"\"Does the LLM output parse as valid FOLD JSON?\n +1.0 valid, -0.5 parseable but invalid, -2.0 unparseable.\"\"\"\n scores = []\n for completion in completions:\n response = completion[0][\"content\"]\n fold_data = extract_fold_json(response)\n if fold_data is None:\n scores.append(-2.0); continue\n required = {\"vertices_coords\", \"edges_vertices\", \"edges_assignment\"}\n if not required.issubset(fold_data.keys()):\n scores.append(-0.5); continue\n verts = fold_data.get(\"vertices_coords\", [])\n edges = fold_data.get(\"edges_vertices\", [])\n assigns = fold_data.get(\"edges_assignment\", [])\n if len(edges) != len(assigns):\n scores.append(-0.5); continue\n if not any(a in (\"M\", \"V\") for a in assigns) or not any(a == \"B\" for a in assigns):\n scores.append(-0.5); continue\n n = len(verts)\n if not all(0 <= e[0] < n and 0 <= e[1] < n and e[0] != e[1] for e in edges):\n scores.append(-0.5); continue\n scores.append(1.0)\n return scores\n\n# --- Reward 2: shape_match (calls server via launch_openenv) ---\n\ndef shape_match(completions, **kwargs):\n \"\"\"Submit fold to origami server, get shape similarity reward.\n Calls 
launch_openenv to ensure server is running, then reset + step.\"\"\"\n global port, openenv_process\n scores = []\n for completion in completions:\n response = completion[0][\"content\"]\n fold_data = extract_fold_json(response)\n if fold_data is None:\n scores.append(0.0)\n continue\n try:\n port, openenv_process = launch_openenv(port, openenv_process)\n openenv_process.reset(task_name=TASK_NAME)\n result = openenv_process.step(OrigamiAction(fold_data=fold_data))\n reward = result.reward if result.reward is not None else 0.0\n scores.append(reward)\n except TimeoutError:\n scores.append(-1.0)\n except Exception as e:\n scores.append(-3.0)\n return scores\n\n# Quick test\ntest_good = [[{\"content\": json.dumps({\n \"vertices_coords\": [[0,0],[1,0],[1,1],[0,1]],\n \"edges_vertices\": [[0,1],[1,2],[2,3],[3,0],[0,2]],\n \"edges_assignment\": [\"B\",\"B\",\"B\",\"B\",\"V\"],\n \"edges_foldAngle\": [0,0,0,0,180]\n})}]]\ntest_bad = [[{\"content\": \"not json\"}]]\nprint(f\"valid_fold β good: {valid_fold(test_good)}, bad: {valid_fold(test_bad)}\")\nprint(f\"shape_match β good: {shape_match(test_good)}\")",
|
| 119 |
"metadata": {},
|
| 120 |
"execution_count": null,
|
| 121 |
"outputs": []
|
| 122 |
},
|
| 123 |
{
|
| 124 |
"cell_type": "markdown",
|
| 125 |
+
"id": "62lvkfoyu1p",
|
| 126 |
+
"source": "## 8. GRPO Trainer",
|
| 127 |
"metadata": {}
|
| 128 |
},
|
| 129 |
{
|
| 130 |
"cell_type": "code",
|
| 131 |
+
"id": "eohisxhna96",
|
| 132 |
+
"source": "from trl import GRPOConfig, GRPOTrainer\n\ntraining_args = GRPOConfig(\n temperature = 1.0,\n learning_rate = 2e-4,\n weight_decay = 0.001,\n warmup_ratio = 0.1,\n lr_scheduler_type = \"linear\",\n optim = \"adamw_8bit\",\n logging_steps = 1,\n per_device_train_batch_size = 1,\n gradient_accumulation_steps = 1,\n num_generations = 2,\n max_prompt_length = 512,\n max_completion_length = max_seq_length - 512,\n max_steps = 600,\n save_steps = 100,\n output_dir = \"outputs\",\n report_to = \"none\",\n)\n\ntrainer = GRPOTrainer(\n model = model,\n processing_class = tokenizer,\n reward_funcs = [valid_fold, shape_match],\n args = training_args,\n train_dataset = dataset,\n)\n\nprint(f\"Trainer ready: {training_args.max_steps} steps, {training_args.num_generations} generations/step\")",
|
| 133 |
"metadata": {},
|
| 134 |
"execution_count": null,
|
| 135 |
"outputs": []
|
| 136 |
},
|
| 137 |
{
|
| 138 |
"cell_type": "markdown",
|
| 139 |
+
"id": "ve98mq6rgot",
|
| 140 |
"source": "## 9. Train!",
|
| 141 |
"metadata": {}
|
| 142 |
},
|
| 143 |
{
|
| 144 |
"cell_type": "code",
|
| 145 |
+
"id": "8il1yknetfg",
|
| 146 |
"source": "trainer.train()",
|
| 147 |
"metadata": {},
|
| 148 |
"execution_count": null,
|
|
|
|
| 157 |
{
|
| 158 |
"cell_type": "code",
|
| 159 |
"id": "t3d4tu6o5mc",
|
| 160 |
+
"source": "save_path = f\"origami-{TASK_NAME}-lora\"\nmodel.save_pretrained(save_path)\ntokenizer.save_pretrained(save_path)\nprint(f\"LoRA adapter saved to {save_path}/\")",
|
| 161 |
"metadata": {},
|
| 162 |
"execution_count": null,
|
| 163 |
"outputs": []
|
|
|
|
| 165 |
{
|
| 166 |
"cell_type": "markdown",
|
| 167 |
"id": "q18eizy1ok",
|
| 168 |
+
"source": "## 11. Evaluate — Generate & Score",
|
| 169 |
"metadata": {}
|
| 170 |
},
|
| 171 |
{
|
| 172 |
"cell_type": "code",
|
| 173 |
"id": "on56augj41",
|
| 174 |
+
"source": "import numpy as np\nFastLanguageModel.for_inference(model)\n\nNUM_EVAL = 8\nmessages = [{\"role\": \"user\", \"content\": prompt}]\ninput_ids = tokenizer.apply_chat_template(\n messages, tokenize=True, add_generation_prompt=True, return_tensors=\"pt\"\n).to(model.device)\n\nprint(f\"Generating {NUM_EVAL} completions (input: {input_ids.shape[1]} tokens)...\\n\")\n\neval_completions = []\nfor i in range(NUM_EVAL):\n with torch.no_grad():\n output = model.generate(\n input_ids,\n max_new_tokens=max_seq_length - 512,\n temperature=0.7, top_p=0.9, do_sample=True,\n pad_token_id=tokenizer.pad_token_id,\n )\n response = tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)\n eval_completions.append([{\"content\": response}])\n fold = extract_fold_json(response)\n status = f\"parsed ({len(fold.get('vertices_coords', []))} verts)\" if fold else \"UNPARSEABLE\"\n print(f\" Sample {i+1}: {status}\")\n\nprint(f\"\\nScoring via server...\")\nvf_scores = valid_fold(eval_completions)\nsm_scores = shape_match(eval_completions)\nprint(f\" valid_fold: mean={np.mean(vf_scores):.2f}, scores={vf_scores}\")\nprint(f\" shape_match: mean={np.mean(sm_scores):.2f}, scores={sm_scores}\")",
|
| 175 |
"metadata": {},
|
| 176 |
"execution_count": null,
|
| 177 |
"outputs": []
|
|
|
|
| 179 |
{
|
| 180 |
"cell_type": "markdown",
|
| 181 |
"id": "tb1y8hszrk",
|
| 182 |
+
"source": "## 12. Visualize Best Result",
|
| 183 |
"metadata": {}
|
| 184 |
},
|
| 185 |
{
|
| 186 |
"cell_type": "code",
|
| 187 |
"id": "0zo3krbkiqej",
|
| 188 |
+
"source": "import matplotlib.pyplot as plt\nimport requests\n\nEDGE_COLORS = {\"M\": \"red\", \"V\": \"blue\", \"B\": \"black\"}\nEDGE_STYLES = {\"M\": \"--\", \"V\": \":\", \"B\": \"-\"}\n\nbest_idx = int(np.argmax(sm_scores))\nbest_fold = extract_fold_json(eval_completions[best_idx][0][\"content\"])\n\nif best_fold is None or sm_scores[best_idx] <= 0:\n print(\"No valid completions to visualize.\")\nelse:\n fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))\n\n # Generated crease pattern\n ax1.set_title(f\"Generated (sample {best_idx+1})\\nreward={sm_scores[best_idx]:.1f}\", fontsize=10)\n ax1.set_aspect(\"equal\")\n verts = np.array(best_fold[\"vertices_coords\"])\n for e, a in zip(best_fold[\"edges_vertices\"], best_fold[\"edges_assignment\"]):\n v1, v2 = verts[e[0]], verts[e[1]]\n ax1.plot([v1[0], v2[0]], [v1[1], v2[1]],\n color=EDGE_COLORS.get(a, \"gray\"),\n linestyle=EDGE_STYLES.get(a, \"-\"), linewidth=2)\n ax1.scatter(verts[:, 0], verts[:, 1], c=\"black\", s=20, zorder=5)\n ax1.grid(True, alpha=0.2)\n\n # Target crease pattern (from server)\n ax2.set_title(\"Target\", fontsize=10)\n ax2.set_aspect(\"equal\")\n port, openenv_process = launch_openenv(port, openenv_process)\n # Get target from server via HTTP\n target_resp = requests.get(f\"http://localhost:{port}/tasks/{TASK_NAME}\")\n target = target_resp.json()[\"target_fold\"]\n tverts = np.array(target[\"vertices_coords\"])\n for e, a in zip(target[\"edges_vertices\"], target[\"edges_assignment\"]):\n v1, v2 = tverts[e[0]], tverts[e[1]]\n ax2.plot([v1[0], v2[0]], [v1[1], v2[1]],\n color=EDGE_COLORS.get(a, \"gray\"),\n linestyle=EDGE_STYLES.get(a, \"-\"), linewidth=2)\n ax2.scatter(tverts[:, 0], tverts[:, 1], c=\"black\", s=20, zorder=5)\n ax2.grid(True, alpha=0.2)\n\n plt.tight_layout()\n plt.show()\n print(f\"\\nBest FOLD JSON:\\n{json.dumps(best_fold, indent=2)}\")",
|
| 189 |
"metadata": {},
|
| 190 |
"execution_count": null,
|
| 191 |
"outputs": []
|
|
|
|
| 193 |
{
|
| 194 |
"cell_type": "markdown",
|
| 195 |
"id": "qlakksqmoe",
|
| 196 |
+
"source": "## 13. Training Logs",
|
| 197 |
"metadata": {}
|
| 198 |
},
|
| 199 |
{
|