Spaces:
Sleeping
Sleeping
Commit ·
5378254
1
Parent(s): c65943e
feat(v2): update train_grpo.py for step-level prompts and per_step_reward
Browse files
- training/train_grpo.py +63 -6
training/train_grpo.py
CHANGED
|
@@ -52,6 +52,43 @@ Structural rules:
|
|
| 52 |
|
| 53 |
Output ONLY the JSON object wrapped in ```json ... ``` markers."""
|
| 54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
def build_prompt(task: dict) -> str:
|
| 57 |
w = task["paper"]["width"]
|
|
@@ -71,7 +108,7 @@ def main():
|
|
| 71 |
"--task", default="all",
|
| 72 |
help="Comma-separated task names, or 'all' for all tasks",
|
| 73 |
)
|
| 74 |
-
parser.add_argument("--max_steps", type=int, default=
|
| 75 |
parser.add_argument("--num_generations", type=int, default=2)
|
| 76 |
parser.add_argument("--model", default="unsloth/Qwen2.5-3B-Instruct")
|
| 77 |
parser.add_argument("--lr", type=float, default=2e-4)
|
|
@@ -103,7 +140,7 @@ def main():
|
|
| 103 |
raise SystemExit(1)
|
| 104 |
|
| 105 |
# --- Get task info from server ---
|
| 106 |
-
ALL_TASKS = ["triangle", "half_fold", "quarter_fold", "letter_fold"]
|
| 107 |
task_names = ALL_TASKS if args.task == "all" else [t.strip() for t in args.task.split(",")]
|
| 108 |
tasks = {}
|
| 109 |
for name in task_names:
|
|
@@ -113,7 +150,7 @@ def main():
|
|
| 113 |
# --- Configure reward functions (OpenEnv pattern) ---
|
| 114 |
from client import OrigamiEnv
|
| 115 |
from origami_server.models import OrigamiAction
|
| 116 |
-
from training.reward import extract_fold_json, flat_foldable_reward, valid_fold
|
| 117 |
from unsloth import is_port_open, launch_openenv
|
| 118 |
|
| 119 |
global port, openenv_process
|
|
@@ -150,6 +187,26 @@ def main():
|
|
| 150 |
scores.append(-2.0)
|
| 151 |
return scores
|
| 152 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
# --- Build dataset (same prompt repeated, like 2048) ---
|
| 154 |
from datasets import Dataset
|
| 155 |
|
|
@@ -158,7 +215,7 @@ def main():
|
|
| 158 |
samples_per_task = 200
|
| 159 |
rows = []
|
| 160 |
for tname, tinfo in tasks.items():
|
| 161 |
-
prompt_text =
|
| 162 |
rows.extend([{
|
| 163 |
"prompt": [
|
| 164 |
{"role": "system", "content": "/no_think"},
|
|
@@ -244,7 +301,7 @@ def main():
|
|
| 244 |
gradient_accumulation_steps=1,
|
| 245 |
num_generations=args.num_generations,
|
| 246 |
max_prompt_length=512,
|
| 247 |
-
max_completion_length=
|
| 248 |
max_steps=args.max_steps,
|
| 249 |
save_steps=10,
|
| 250 |
output_dir=os.environ.get("OUTPUT_DIR", "outputs"),
|
|
@@ -253,7 +310,7 @@ def main():
|
|
| 253 |
trainer = GRPOTrainer(
|
| 254 |
model=model,
|
| 255 |
processing_class=tokenizer,
|
| 256 |
-
reward_funcs=[
|
| 257 |
args=training_args,
|
| 258 |
train_dataset=dataset,
|
| 259 |
)
|
|
|
|
| 52 |
|
| 53 |
Output ONLY the JSON object wrapped in ```json ... ``` markers."""
|
| 54 |
|
| 55 |
+
# Prompt for a single fold step. Placeholders are filled by build_step_prompt();
# the doubled braces on the last line render as a literal JSON object example.
STEP_PROMPT_TEMPLATE = """You are an origami designer. Add the next fold crease.

Target: {description}
Paper: {width} x {height} unit square

CURRENT STATE (step {step} of {max_folds}):
Creases placed: {crease_history}

AVAILABLE ANCHOR POINTS:
Corners: (0,0) ({width},0) ({width},{height}) (0,{height})
Midpoints: (0,{hy}) ({hx},0) ({width},{hy}) ({hx},{height})
Intersections: {intersections}

Flat-foldability rules at every interior vertex:
- Kawasaki: alternating sector angles each sum to 180 degrees
- Maekawa: |mountain_count - valley_count| = 2
- BLB: smallest sector bounded by opposite M/V types

Output ONLY this JSON (no explanation):
{{"from": [x1, y1], "to": [x2, y2], "assignment": "M" or "V"}}"""


def build_step_prompt(task: dict, step: int = 0, crease_history: str = "none", intersections: str = "none") -> str:
    """Render the step-level prompt asking for the next crease of ``task``.

    Args:
        task: Task info mapping. Must contain ``task["paper"]["width"]``,
            ``task["paper"]["height"]`` and ``task["description"]``. An
            optional ``"max_folds"`` entry is shown as the step total and
            defaults to 1 when absent.
        step: Step number inserted verbatim into the "step X of N" header.
        crease_history: Text listing the creases placed so far. An empty
            or ``None`` value is rendered as ``"none"``.
        intersections: Text listing the available intersection anchor
            points. An empty or ``None`` value is rendered as ``"none"``.

    Returns:
        The formatted prompt string.

    Raises:
        KeyError: If ``task`` lacks ``"paper"``, ``"description"``, or the
            paper's ``"width"`` / ``"height"``.
    """
    paper = task["paper"]
    w = paper["width"]
    h = paper["height"]
    return STEP_PROMPT_TEMPLATE.format(
        description=task["description"],
        width=w,
        height=h,
        # Edge midpoints, rounded so long float tails do not bloat the prompt.
        hx=round(w / 2, 4),
        hy=round(h / 2, 4),
        step=step,
        max_folds=task.get("max_folds", 1),
        # Callers that build these from an empty list (e.g. ", ".join([]))
        # would otherwise leave a dangling label in the prompt; fall back to
        # the same placeholder the defaults use.
        crease_history=crease_history or "none",
        intersections=intersections or "none",
    )
|
| 91 |
+
|
| 92 |
|
| 93 |
def build_prompt(task: dict) -> str:
|
| 94 |
w = task["paper"]["width"]
|
|
|
|
| 108 |
"--task", default="all",
|
| 109 |
help="Comma-separated task names, or 'all' for all tasks",
|
| 110 |
)
|
| 111 |
+
parser.add_argument("--max_steps", type=int, default=1200)
|
| 112 |
parser.add_argument("--num_generations", type=int, default=2)
|
| 113 |
parser.add_argument("--model", default="unsloth/Qwen2.5-3B-Instruct")
|
| 114 |
parser.add_argument("--lr", type=float, default=2e-4)
|
|
|
|
| 140 |
raise SystemExit(1)
|
| 141 |
|
| 142 |
# --- Get task info from server ---
|
| 143 |
+
ALL_TASKS = ["triangle", "half_fold", "quarter_fold", "letter_fold", "waterbomb_base", "map_fold"]
|
| 144 |
task_names = ALL_TASKS if args.task == "all" else [t.strip() for t in args.task.split(",")]
|
| 145 |
tasks = {}
|
| 146 |
for name in task_names:
|
|
|
|
| 150 |
# --- Configure reward functions (OpenEnv pattern) ---
|
| 151 |
from client import OrigamiEnv
|
| 152 |
from origami_server.models import OrigamiAction
|
| 153 |
+
from training.reward import extract_fold_json, extract_crease_json, flat_foldable_reward, valid_fold, valid_crease
|
| 154 |
from unsloth import is_port_open, launch_openenv
|
| 155 |
|
| 156 |
global port, openenv_process
|
|
|
|
| 187 |
scores.append(-2.0)
|
| 188 |
return scores
|
| 189 |
|
| 190 |
+
def per_step_reward(completions, task_name, **kwargs):
    """Score each completion by applying its crease to the environment.

    For every (completion, task name) pair the crease is parsed from the
    first message's ``"content"``, the environment is reset for that task,
    and the crease is submitted as a single step.

    Args:
        completions: One entry per sample; each is indexed as
            ``completion[0]["content"]`` to obtain the response text.
        task_name: Task names, paired positionally with ``completions``.
        **kwargs: Extra columns passed by the trainer; unused here.

    Returns:
        A list with one score per pair:
        ``-2.0`` when no crease could be extracted or an unexpected
        exception occurred, ``-1.0`` on ``TimeoutError``, otherwise the
        step's ``reward`` (``0.0`` when the reward is ``None``).
    """
    global port, openenv_process
    import logging  # local import: only this block needs it

    log = logging.getLogger(__name__)
    scores = []
    for completion, tname in zip(completions, task_name):
        response = completion[0]["content"]
        crease = extract_crease_json(response)
        if crease is None:
            # Unparseable output: same penalty as a hard failure below.
            scores.append(-2.0)
            continue
        try:
            # launch_openenv hands back the (port, process) pair to keep using;
            # NOTE(review): presumably it reuses a live server - confirm.
            port, openenv_process = launch_openenv(port, openenv_process)
            openenv_process.reset(task_name=tname)
            result = openenv_process.step(OrigamiAction(crease=crease))
            scores.append(result.reward if result.reward is not None else 0.0)
        except TimeoutError:
            log.warning("per_step_reward: environment timed out for task %r", tname)
            scores.append(-1.0)
        except Exception:
            # Still best-effort (training must not crash), but record the
            # failure: otherwise a broken environment looks exactly like a
            # batch of unparseable completions.
            log.exception("per_step_reward: environment step failed for task %r", tname)
            scores.append(-2.0)
    return scores
|
| 209 |
+
|
| 210 |
# --- Build dataset (same prompt repeated, like 2048) ---
|
| 211 |
from datasets import Dataset
|
| 212 |
|
|
|
|
| 215 |
samples_per_task = 200
|
| 216 |
rows = []
|
| 217 |
for tname, tinfo in tasks.items():
|
| 218 |
+
prompt_text = build_step_prompt(tinfo)
|
| 219 |
rows.extend([{
|
| 220 |
"prompt": [
|
| 221 |
{"role": "system", "content": "/no_think"},
|
|
|
|
| 301 |
gradient_accumulation_steps=1,
|
| 302 |
num_generations=args.num_generations,
|
| 303 |
max_prompt_length=512,
|
| 304 |
+
max_completion_length=128,
|
| 305 |
max_steps=args.max_steps,
|
| 306 |
save_steps=10,
|
| 307 |
output_dir=os.environ.get("OUTPUT_DIR", "outputs"),
|
|
|
|
| 310 |
trainer = GRPOTrainer(
|
| 311 |
model=model,
|
| 312 |
processing_class=tokenizer,
|
| 313 |
+
reward_funcs=[valid_crease, per_step_reward],
|
| 314 |
args=training_args,
|
| 315 |
train_dataset=dataset,
|
| 316 |
)
|