ianalin123 committed on
Commit
5378254
·
1 Parent(s): c65943e

feat(v2): update train_grpo.py for step-level prompts and per_step_reward

Browse files
Files changed (1) hide show
  1. training/train_grpo.py +63 -6
training/train_grpo.py CHANGED
@@ -52,6 +52,43 @@ Structural rules:
52
 
53
  Output ONLY the JSON object wrapped in ```json ... ``` markers."""
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
  def build_prompt(task: dict) -> str:
57
  w = task["paper"]["width"]
@@ -71,7 +108,7 @@ def main():
71
  "--task", default="all",
72
  help="Comma-separated task names, or 'all' for all tasks",
73
  )
74
- parser.add_argument("--max_steps", type=int, default=600)
75
  parser.add_argument("--num_generations", type=int, default=2)
76
  parser.add_argument("--model", default="unsloth/Qwen2.5-3B-Instruct")
77
  parser.add_argument("--lr", type=float, default=2e-4)
@@ -103,7 +140,7 @@ def main():
103
  raise SystemExit(1)
104
 
105
  # --- Get task info from server ---
106
- ALL_TASKS = ["triangle", "half_fold", "quarter_fold", "letter_fold"]
107
  task_names = ALL_TASKS if args.task == "all" else [t.strip() for t in args.task.split(",")]
108
  tasks = {}
109
  for name in task_names:
@@ -113,7 +150,7 @@ def main():
113
  # --- Configure reward functions (OpenEnv pattern) ---
114
  from client import OrigamiEnv
115
  from origami_server.models import OrigamiAction
116
- from training.reward import extract_fold_json, flat_foldable_reward, valid_fold
117
  from unsloth import is_port_open, launch_openenv
118
 
119
  global port, openenv_process
@@ -150,6 +187,26 @@ def main():
150
  scores.append(-2.0)
151
  return scores
152
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  # --- Build dataset (same prompt repeated, like 2048) ---
154
  from datasets import Dataset
155
 
@@ -158,7 +215,7 @@ def main():
158
  samples_per_task = 200
159
  rows = []
160
  for tname, tinfo in tasks.items():
161
- prompt_text = build_prompt(tinfo)
162
  rows.extend([{
163
  "prompt": [
164
  {"role": "system", "content": "/no_think"},
@@ -244,7 +301,7 @@ def main():
244
  gradient_accumulation_steps=1,
245
  num_generations=args.num_generations,
246
  max_prompt_length=512,
247
- max_completion_length=max_seq_length - 512,
248
  max_steps=args.max_steps,
249
  save_steps=10,
250
  output_dir=os.environ.get("OUTPUT_DIR", "outputs"),
@@ -253,7 +310,7 @@ def main():
253
  trainer = GRPOTrainer(
254
  model=model,
255
  processing_class=tokenizer,
256
- reward_funcs=[valid_fold, flat_foldable_reward, shape_match_reward],
257
  args=training_args,
258
  train_dataset=dataset,
259
  )
 
52
 
53
  Output ONLY the JSON object wrapped in ```json ... ``` markers."""
54
 
55
# Prompt template for one incremental fold step. Filled in by build_step_prompt()
# via str.format; placeholders: {description}, {width}, {height}, {hx}/{hy}
# (paper half-width/half-height), {step}, {max_folds}, {crease_history},
# {intersections}. The doubled braces {{...}} on the last line render as a
# literal JSON object in the final prompt so the model sees the exact schema
# it must emit.
STEP_PROMPT_TEMPLATE = """You are an origami designer. Add the next fold crease.

Target: {description}
Paper: {width} x {height} unit square

CURRENT STATE (step {step} of {max_folds}):
Creases placed: {crease_history}

AVAILABLE ANCHOR POINTS:
Corners: (0,0) ({width},0) ({width},{height}) (0,{height})
Midpoints: (0,{hy}) ({hx},0) ({width},{hy}) ({hx},{height})
Intersections: {intersections}

Flat-foldability rules at every interior vertex:
- Kawasaki: alternating sector angles each sum to 180 degrees
- Maekawa: |mountain_count - valley_count| = 2
- BLB: smallest sector bounded by opposite M/V types

Output ONLY this JSON (no explanation):
{{"from": [x1, y1], "to": [x2, y2], "assignment": "M" or "V"}}"""
75
+
76
+
77
def build_step_prompt(task: dict, step: int = 0, crease_history: str = "none", intersections: str = "none") -> str:
    """Render STEP_PROMPT_TEMPLATE for one fold step of *task*.

    task must carry task["paper"]["width"/"height"] and task["description"];
    max_folds falls back to 1 when the task dict omits it. Half-dimensions are
    rounded to 4 decimals so midpoint anchors stay compact in the prompt.
    """
    paper = task["paper"]
    width, height = paper["width"], paper["height"]
    fields = {
        "description": task["description"],
        "width": width,
        "height": height,
        "hx": round(width / 2, 4),
        "hy": round(height / 2, 4),
        "step": step,
        "max_folds": task.get("max_folds", 1),
        "crease_history": crease_history,
        "intersections": intersections,
    }
    return STEP_PROMPT_TEMPLATE.format(**fields)
91
+
92
 
93
  def build_prompt(task: dict) -> str:
94
  w = task["paper"]["width"]
 
108
  "--task", default="all",
109
  help="Comma-separated task names, or 'all' for all tasks",
110
  )
111
+ parser.add_argument("--max_steps", type=int, default=1200)
112
  parser.add_argument("--num_generations", type=int, default=2)
113
  parser.add_argument("--model", default="unsloth/Qwen2.5-3B-Instruct")
114
  parser.add_argument("--lr", type=float, default=2e-4)
 
140
  raise SystemExit(1)
141
 
142
  # --- Get task info from server ---
143
+ ALL_TASKS = ["triangle", "half_fold", "quarter_fold", "letter_fold", "waterbomb_base", "map_fold"]
144
  task_names = ALL_TASKS if args.task == "all" else [t.strip() for t in args.task.split(",")]
145
  tasks = {}
146
  for name in task_names:
 
150
  # --- Configure reward functions (OpenEnv pattern) ---
151
  from client import OrigamiEnv
152
  from origami_server.models import OrigamiAction
153
+ from training.reward import extract_fold_json, extract_crease_json, flat_foldable_reward, valid_fold, valid_crease
154
  from unsloth import is_port_open, launch_openenv
155
 
156
  global port, openenv_process
 
187
  scores.append(-2.0)
188
  return scores
189
 
190
def per_step_reward(completions, task_name, **kwargs):
    """GRPO reward function: score each completion by executing its crease in the env.

    completions: chat-format generations; completion[0]["content"] is the raw
        model text expected to contain the crease JSON.
    task_name: dataset column, one task name per completion.
    Returns one float reward per completion.
    """
    # Shared server handle declared global in main(); reused across calls so the
    # OpenEnv server is not relaunched for every reward batch.
    global port, openenv_process
    scores = []
    for completion, tname in zip(completions, task_name):
        response = completion[0]["content"]
        crease = extract_crease_json(response)
        if crease is None:
            # Unparseable model output: hardest penalty, no env round-trip needed.
            scores.append(-2.0)
            continue
        try:
            # Ensure the env server is up, then reset to a fresh episode so the
            # single proposed crease is always judged from step 0 of the task.
            port, openenv_process = launch_openenv(port, openenv_process)
            openenv_process.reset(task_name=tname)
            result = openenv_process.step(OrigamiAction(crease=crease))
            # NOTE(review): assumes launch_openenv's second return value exposes
            # reset()/step() and that result.reward may be None — confirm
            # against the unsloth OpenEnv API.
            scores.append(result.reward if result.reward is not None else 0.0)
        except TimeoutError:
            # Env hang/timeout: milder penalty than malformed output.
            scores.append(-1.0)
        except Exception:
            # Any other server/env failure is scored like an invalid crease.
            scores.append(-2.0)
    return scores
209
+
210
  # --- Build dataset (same prompt repeated, like 2048) ---
211
  from datasets import Dataset
212
 
 
215
  samples_per_task = 200
216
  rows = []
217
  for tname, tinfo in tasks.items():
218
+ prompt_text = build_step_prompt(tinfo)
219
  rows.extend([{
220
  "prompt": [
221
  {"role": "system", "content": "/no_think"},
 
301
  gradient_accumulation_steps=1,
302
  num_generations=args.num_generations,
303
  max_prompt_length=512,
304
+ max_completion_length=128,
305
  max_steps=args.max_steps,
306
  save_steps=10,
307
  output_dir=os.environ.get("OUTPUT_DIR", "outputs"),
 
310
  trainer = GRPOTrainer(
311
  model=model,
312
  processing_class=tokenizer,
313
+ reward_funcs=[valid_crease, per_step_reward],
314
  args=training_args,
315
  train_dataset=dataset,
316
  )