ianalin123 Claude Sonnet 4.6 committed on
Commit
56c400c
·
1 Parent(s): 9e670bb

feat(training): richer LLM prompt with fold history, metric deltas, approach hints

Browse files

- System prompt explains coordinate system, fold types, physics clearly
- User message now includes full fold history, Δcompactness, max_strain,
kawasaki_violations so model has real gradient signal each step
- 8 distinct approach hints give diversity across parallel episodes
- Pass fold_history into strategy_fn for richer context
- stop detection handles both text "stop" and JSON {"type":"stop"}
- Increase animation delay 0.4→0.5s for viewer

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Files changed (1) hide show
  1. training/demo_llm.py +125 -39
training/demo_llm.py CHANGED
@@ -6,7 +6,8 @@ Usage:
6
  ANTHROPIC_API_KEY=sk-... python -m training.demo_llm
7
 
8
  Each of the 8 episodes calls Claude (claude-haiku-4-5) once per fold step.
9
- Claude sees the current paper_state metrics and decides the next fold.
 
10
  """
11
  from __future__ import annotations
12
 
@@ -31,55 +32,137 @@ NUM_EPISODES = 8
31
  MODEL = "claude-haiku-4-5-20251001"
32
 
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  # ── LLM strategy factory ───────────────────────────────────────────────────────
35
 
36
  def make_llm_strategy(client: anthropic.Anthropic, task: dict, episode_num: int):
37
- """Return a strategy_fn for one episode. Each episode gets its own call history."""
 
 
 
 
38
  history: list[dict[str, Any]] = []
 
 
39
 
40
- def strategy(paper_state: dict) -> dict:
41
  fold_count = paper_state.get("fold_count", 0)
42
- compactness = paper_state.get("compactness", 0)
43
  bb = paper_state.get("bounding_box", [1, 1, 0])
 
 
 
44
  target_box = task.get("target_box", [1, 0.5, 0.02])
45
  max_folds = task.get("max_folds", 3)
46
 
47
- user_msg = f"""You are folding a {task['width']}x{task['height']} sheet of {task['material']}.
48
- Task: {task['description']}
49
- Target box to fit inside: {target_box}
50
- Max folds allowed: {max_folds}
51
-
52
- Current state (fold {fold_count}/{max_folds}):
53
- compactness: {compactness:.3f} (1.0 = fully packed, 0.0 = flat)
54
- bounding_box: [{bb[0]:.3f}, {bb[1]:.3f}, {bb[2]:.4f}]
55
- fits_target_box: {paper_state.get('fits_target_box', False)}
56
-
57
- Choose the next fold. Respond with ONLY valid JSON, no other text:
58
- {{
59
- "type": "valley" or "mountain" or "stop",
60
- "line": {{"start": [x, y], "end": [x, y]}},
61
- "angle": 180
62
- }}
63
-
64
- Coordinates are normalized 0-1. Use "stop" if done."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
  history.append({"role": "user", "content": user_msg})
67
 
68
  response = client.messages.create(
69
  model=MODEL,
70
- max_tokens=120,
 
71
  messages=history,
72
  )
73
  reply = response.content[0].text.strip()
74
  history.append({"role": "assistant", "content": reply})
75
 
76
- # Extract JSON β€” handle markdown code blocks
 
 
 
 
77
  match = re.search(r'\{[^{}]+\}', reply, re.DOTALL)
78
  if not match:
79
- return {"type": "stop", "line": {"start": [0, 0.5], "end": [1, 0.5]}, "angle": 0.0}
 
 
 
 
 
 
80
 
81
- fold_dict = json.loads(match.group())
82
- # Normalize: ensure required keys
83
  fold_dict.setdefault("type", "valley")
84
  fold_dict.setdefault("line", {"start": [0.0, 0.5], "end": [1.0, 0.5]})
85
  fold_dict.setdefault("angle", 180.0)
@@ -115,13 +198,13 @@ def run_episode_llm(
115
  if obs.done:
116
  break
117
 
118
- # Build a flat paper_state dict for the LLM (add metrics inline)
119
  ps = dict(obs.paper_state)
120
- ps.update(obs.metrics) # compactness, fits_target_box, etc.
121
  ps["fold_count"] = step_idx
122
 
123
  try:
124
- fold_dict = strategy_fn(ps)
125
  except Exception as exc:
126
  broadcast_fn(ep_id, {
127
  "type": "episode_done", "episode_id": ep_id,
@@ -133,7 +216,7 @@ def run_episode_llm(
133
  if fold_dict.get("type") == "stop":
134
  break
135
 
136
- time.sleep(0.4) # pace for viewer animation
137
 
138
  action = OrigamiAction(
139
  fold_type=fold_dict["type"],
@@ -192,16 +275,19 @@ async def run_demo() -> None:
192
 
193
  await asyncio.sleep(1.5) # wait for server startup
194
 
195
- print(f"\n[llm-demo] Model: {MODEL}")
196
- print(f"[llm-demo] Task: {TASK_NAME} β€” {task['description']}")
197
- print(f"[llm-demo] Open: http://localhost:9001/viewer/training.html\n")
 
 
 
 
198
 
199
  await broadcast.start_batch(1, NUM_EPISODES)
200
 
201
  ep_ids = [f"ep_{i:02d}" for i in range(NUM_EPISODES)]
202
  strategies = [make_llm_strategy(client, task, i) for i in range(NUM_EPISODES)]
203
 
204
- # Run all episodes concurrently (each makes its own Claude API calls)
205
  results = await asyncio.gather(*[
206
  asyncio.to_thread(run_episode_llm, fn, TASK_NAME, ep_id, broadcast.publish)
207
  for fn, ep_id in zip(strategies, ep_ids)
@@ -213,10 +299,10 @@ async def run_demo() -> None:
213
  await broadcast.finish_batch(1, scores, best_episode_id=ep_ids[best_idx])
214
 
215
  print("\n[llm-demo] Results:")
216
- for i, result in enumerate(results):
217
- print(f" ep_{i:02d} score={result['score']:+.2f} status={result['status']}")
218
- print(f"\n[llm-demo] Best: ep_{best_idx:02d} (score={scores[best_idx]:+.2f})")
219
- print("\n[llm-demo] Press Ctrl+C to stop.\n")
220
 
221
 
222
  async def _main() -> None:
 
6
  ANTHROPIC_API_KEY=sk-... python -m training.demo_llm
7
 
8
  Each of the 8 episodes calls Claude (claude-haiku-4-5) once per fold step.
9
+ Claude receives the current paper state (metrics + fold history) and decides
10
+ the next fold action. Episodes run concurrently; all stream to the grid viewer.
11
  """
12
  from __future__ import annotations
13
 
 
32
  MODEL = "claude-haiku-4-5-20251001"
33
 
34
 
35
+ # ── System prompt ──────────────────────────────────────────────────────────────
36
+
37
+ SYSTEM_PROMPT = """\
38
+ You are an origami folding agent controlling a robotic paper-folding system.
39
+
40
+ COORDINATE SYSTEM
41
+ - Paper starts as a flat sheet; coordinates are normalized to the sheet's original size.
42
+ - x=0 is the left edge, x=1 is the right edge.
43
+ - y=0 is the bottom edge, y=1 is the top edge.
44
+ - Fold line endpoints must be on or outside the paper boundary (0.0–1.0 range).
45
+ - A fold line that runs off the edge is fine β€” it just doesn't affect paper outside the sheet.
46
+
47
+ FOLD TYPES
48
+ - "valley": folds the paper toward you (creates a V crease when viewed from above).
49
+ - "mountain": folds the paper away from you (creates a ^ crease).
50
+ - "stop": you are satisfied β€” no more folds needed.
51
+
52
+ PHYSICS
53
+ - angle=180 means a fully flat fold (paper halved).
54
+ - Smaller angles (e.g. 90) create partial folds.
55
+ - Each fold updates compactness, bounding_box, and strain readings.
56
+ - Kawasaki/Maekawa violations indicate geometrically invalid crease patterns.
57
+
58
+ RESPONSE FORMAT β€” output ONLY valid JSON, no markdown, no explanation:
59
+ {"type": "valley", "line": {"start": [x, y], "end": [x, y]}, "angle": 180}
60
+ """
61
+
62
+ # Eight distinct approach hints β€” gives diversity across the parallel episodes.
63
+ APPROACH_HINTS = [
64
+ "Try a single clean horizontal fold at the exact midline.",
65
+ "Try a single clean vertical fold at the exact midline.",
66
+ "Use two folds: first horizontal then vertical, to create quarters.",
67
+ "Use a diagonal fold from one corner to the opposite.",
68
+ "Try folding at y=0.333 and y=0.667 to create thirds.",
69
+ "Try a single fold but vary the position slightly off-center to explore.",
70
+ "Use a mountain fold instead of valley for the primary crease.",
71
+ "Try to reach the target box in as few folds as possible β€” stop early if done.",
72
+ ]
73
+
74
+
75
  # ── LLM strategy factory ───────────────────────────────────────────────────────
76
 
77
  def make_llm_strategy(client: anthropic.Anthropic, task: dict, episode_num: int):
78
+ """Return a strategy_fn for one episode.
79
+
80
+ Each episode has its own conversation history (multi-turn) and a unique
81
+ approach hint so the 8 concurrent episodes explore different strategies.
82
+ """
83
  history: list[dict[str, Any]] = []
84
+ hint = APPROACH_HINTS[episode_num % len(APPROACH_HINTS)]
85
+ prev_compactness: list[float] = [0.0] # mutable cell for delta tracking
86
 
87
+ def strategy(paper_state: dict, fold_history: list[dict]) -> dict:
88
  fold_count = paper_state.get("fold_count", 0)
89
+ compactness = float(paper_state.get("compactness", 0))
90
  bb = paper_state.get("bounding_box", [1, 1, 0])
91
+ fits = paper_state.get("fits_target_box", False)
92
+ strain = paper_state.get("max_strain", 0.0)
93
+ kaw = paper_state.get("kawasaki_violations", 0)
94
  target_box = task.get("target_box", [1, 0.5, 0.02])
95
  max_folds = task.get("max_folds", 3)
96
 
97
+ delta = compactness - prev_compactness[0]
98
+ prev_compactness[0] = compactness
99
+
100
+ # Summarise what has been done so far
101
+ history_lines = ""
102
+ if fold_history:
103
+ history_lines = "Folds applied so far:\n"
104
+ for i, f in enumerate(fold_history, 1):
105
+ t = f.get("type", "?")
106
+ ln = f.get("line", {})
107
+ s = ln.get("start", [0, 0])
108
+ e = ln.get("end", [1, 1])
109
+ ang = f.get("angle", 180)
110
+ history_lines += (
111
+ f" {i}. {t} fold "
112
+ f"from ({s[0]:.3f},{s[1]:.3f}) to ({e[0]:.3f},{e[1]:.3f}) "
113
+ f"angle={ang}\n"
114
+ )
115
+ else:
116
+ history_lines = "No folds applied yet β€” paper is flat.\n"
117
+
118
+ sign = "+" if delta >= 0 else ""
119
+ user_msg = (
120
+ f"Task: {task['description']}\n"
121
+ f"Sheet: {task['width']}Γ—{task['height']} {task['material']}\n"
122
+ f"Target bounding box: {target_box} (must fit inside to succeed)\n"
123
+ f"Max folds remaining: {max_folds - fold_count}\n"
124
+ f"\n"
125
+ f"{history_lines}"
126
+ f"\n"
127
+ f"Current state after fold {fold_count}/{max_folds}:\n"
128
+ f" compactness : {compactness:.4f} (Ξ” {sign}{delta:.4f})\n"
129
+ f" bounding_box: [{bb[0]:.4f}, {bb[1]:.4f}, {bb[2]:.5f}]\n"
130
+ f" fits_target : {'YES βœ“' if fits else 'no'}\n"
131
+ f" max_strain : {strain:.5f}\n"
132
+ f" kaw_violations: {kaw}\n"
133
+ f"\n"
134
+ f"Approach hint: {hint}\n"
135
+ f"\n"
136
+ f"What is your next fold action? "
137
+ f"Return \"stop\" if the target is already achieved or no useful fold remains."
138
+ )
139
 
140
  history.append({"role": "user", "content": user_msg})
141
 
142
  response = client.messages.create(
143
  model=MODEL,
144
+ max_tokens=150,
145
+ system=SYSTEM_PROMPT,
146
  messages=history,
147
  )
148
  reply = response.content[0].text.strip()
149
  history.append({"role": "assistant", "content": reply})
150
 
151
+ # Handle explicit "stop" text before JSON parse
152
+ if reply.lower().startswith("stop") or '"type": "stop"' in reply:
153
+ return {"type": "stop", "line": {"start": [0, 0.5], "end": [1, 0.5]}, "angle": 0.0}
154
+
155
+ # Extract JSON β€” handles markdown code fences
156
  match = re.search(r'\{[^{}]+\}', reply, re.DOTALL)
157
  if not match:
158
+ # Malformed response β€” default safe fold then stop next turn
159
+ return {"type": "valley", "line": {"start": [0.0, 0.5], "end": [1.0, 0.5]}, "angle": 180.0}
160
+
161
+ try:
162
+ fold_dict = json.loads(match.group())
163
+ except json.JSONDecodeError:
164
+ return {"type": "valley", "line": {"start": [0.0, 0.5], "end": [1.0, 0.5]}, "angle": 180.0}
165
 
 
 
166
  fold_dict.setdefault("type", "valley")
167
  fold_dict.setdefault("line", {"start": [0.0, 0.5], "end": [1.0, 0.5]})
168
  fold_dict.setdefault("angle", 180.0)
 
198
  if obs.done:
199
  break
200
 
201
+ # Merge paper_state + metrics for the strategy
202
  ps = dict(obs.paper_state)
203
+ ps.update(obs.metrics)
204
  ps["fold_count"] = step_idx
205
 
206
  try:
207
+ fold_dict = strategy_fn(ps, list(obs.fold_history))
208
  except Exception as exc:
209
  broadcast_fn(ep_id, {
210
  "type": "episode_done", "episode_id": ep_id,
 
216
  if fold_dict.get("type") == "stop":
217
  break
218
 
219
+ time.sleep(0.5) # pace for viewer animation
220
 
221
  action = OrigamiAction(
222
  fold_type=fold_dict["type"],
 
275
 
276
  await asyncio.sleep(1.5) # wait for server startup
277
 
278
+ print(f"\n[llm-demo] Model : {MODEL}")
279
+ print(f"[llm-demo] Task : {TASK_NAME} β€” {task['description']}")
280
+ print(f"[llm-demo] Open : http://localhost:9001/viewer/training.html\n")
281
+ print(f"[llm-demo] Episodes:")
282
+ for i, hint in enumerate(APPROACH_HINTS):
283
+ print(f" ep_{i:02d} hint: {hint}")
284
+ print()
285
 
286
  await broadcast.start_batch(1, NUM_EPISODES)
287
 
288
  ep_ids = [f"ep_{i:02d}" for i in range(NUM_EPISODES)]
289
  strategies = [make_llm_strategy(client, task, i) for i in range(NUM_EPISODES)]
290
 
 
291
  results = await asyncio.gather(*[
292
  asyncio.to_thread(run_episode_llm, fn, TASK_NAME, ep_id, broadcast.publish)
293
  for fn, ep_id in zip(strategies, ep_ids)
 
299
  await broadcast.finish_batch(1, scores, best_episode_id=ep_ids[best_idx])
300
 
301
  print("\n[llm-demo] Results:")
302
+ for i, (result, hint) in enumerate(zip(results, APPROACH_HINTS)):
303
+ marker = " ← best" if i == best_idx else ""
304
+ print(f" ep_{i:02d} score={result['score']:+.2f} status={result['status']} hint: {hint}{marker}")
305
+ print(f"\n[llm-demo] Press Ctrl+C to stop.\n")
306
 
307
 
308
  async def _main() -> None: