Spaces:
Running
Running
Commit ·
56c400c
1
Parent(s): 9e670bb
feat(training): richer LLM prompt with fold history, metric deltas, approach hints
Browse files
- System prompt explains coordinate system, fold types, physics clearly
- User message now includes full fold history, Δcompactness, max_strain,
kawasaki_violations so model has real gradient signal each step
- 8 distinct approach hints give diversity across parallel episodes
- Pass fold_history into strategy_fn for richer context
- stop detection handles both text "stop" and JSON {"type":"stop"}
- Increase animation delay 0.4→0.5s for viewer
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- training/demo_llm.py +125 -39
training/demo_llm.py
CHANGED
|
@@ -6,7 +6,8 @@ Usage:
|
|
| 6 |
ANTHROPIC_API_KEY=sk-... python -m training.demo_llm
|
| 7 |
|
| 8 |
Each of the 8 episodes calls Claude (claude-haiku-4-5) once per fold step.
|
| 9 |
-
Claude
|
|
|
|
| 10 |
"""
|
| 11 |
from __future__ import annotations
|
| 12 |
|
|
@@ -31,55 +32,137 @@ NUM_EPISODES = 8
|
|
| 31 |
MODEL = "claude-haiku-4-5-20251001"
|
| 32 |
|
| 33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
# ββ LLM strategy factory βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 35 |
|
| 36 |
def make_llm_strategy(client: anthropic.Anthropic, task: dict, episode_num: int):
|
| 37 |
-
"""Return a strategy_fn for one episode.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
history: list[dict[str, Any]] = []
|
|
|
|
|
|
|
| 39 |
|
| 40 |
-
def strategy(paper_state: dict) -> dict:
|
| 41 |
fold_count = paper_state.get("fold_count", 0)
|
| 42 |
-
compactness = paper_state.get("compactness", 0)
|
| 43 |
bb = paper_state.get("bounding_box", [1, 1, 0])
|
|
|
|
|
|
|
|
|
|
| 44 |
target_box = task.get("target_box", [1, 0.5, 0.02])
|
| 45 |
max_folds = task.get("max_folds", 3)
|
| 46 |
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
}}
|
| 63 |
-
|
| 64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
history.append({"role": "user", "content": user_msg})
|
| 67 |
|
| 68 |
response = client.messages.create(
|
| 69 |
model=MODEL,
|
| 70 |
-
max_tokens=
|
|
|
|
| 71 |
messages=history,
|
| 72 |
)
|
| 73 |
reply = response.content[0].text.strip()
|
| 74 |
history.append({"role": "assistant", "content": reply})
|
| 75 |
|
| 76 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
match = re.search(r'\{[^{}]+\}', reply, re.DOTALL)
|
| 78 |
if not match:
|
| 79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
|
| 81 |
-
fold_dict = json.loads(match.group())
|
| 82 |
-
# Normalize: ensure required keys
|
| 83 |
fold_dict.setdefault("type", "valley")
|
| 84 |
fold_dict.setdefault("line", {"start": [0.0, 0.5], "end": [1.0, 0.5]})
|
| 85 |
fold_dict.setdefault("angle", 180.0)
|
|
@@ -115,13 +198,13 @@ def run_episode_llm(
|
|
| 115 |
if obs.done:
|
| 116 |
break
|
| 117 |
|
| 118 |
-
#
|
| 119 |
ps = dict(obs.paper_state)
|
| 120 |
-
ps.update(obs.metrics)
|
| 121 |
ps["fold_count"] = step_idx
|
| 122 |
|
| 123 |
try:
|
| 124 |
-
fold_dict = strategy_fn(ps)
|
| 125 |
except Exception as exc:
|
| 126 |
broadcast_fn(ep_id, {
|
| 127 |
"type": "episode_done", "episode_id": ep_id,
|
|
@@ -133,7 +216,7 @@ def run_episode_llm(
|
|
| 133 |
if fold_dict.get("type") == "stop":
|
| 134 |
break
|
| 135 |
|
| 136 |
-
time.sleep(0.
|
| 137 |
|
| 138 |
action = OrigamiAction(
|
| 139 |
fold_type=fold_dict["type"],
|
|
@@ -192,16 +275,19 @@ async def run_demo() -> None:
|
|
| 192 |
|
| 193 |
await asyncio.sleep(1.5) # wait for server startup
|
| 194 |
|
| 195 |
-
print(f"\n[llm-demo] Model: {MODEL}")
|
| 196 |
-
print(f"[llm-demo] Task: {TASK_NAME} β {task['description']}")
|
| 197 |
-
print(f"[llm-demo] Open: http://localhost:9001/viewer/training.html\n")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
|
| 199 |
await broadcast.start_batch(1, NUM_EPISODES)
|
| 200 |
|
| 201 |
ep_ids = [f"ep_{i:02d}" for i in range(NUM_EPISODES)]
|
| 202 |
strategies = [make_llm_strategy(client, task, i) for i in range(NUM_EPISODES)]
|
| 203 |
|
| 204 |
-
# Run all episodes concurrently (each makes its own Claude API calls)
|
| 205 |
results = await asyncio.gather(*[
|
| 206 |
asyncio.to_thread(run_episode_llm, fn, TASK_NAME, ep_id, broadcast.publish)
|
| 207 |
for fn, ep_id in zip(strategies, ep_ids)
|
|
@@ -213,10 +299,10 @@ async def run_demo() -> None:
|
|
| 213 |
await broadcast.finish_batch(1, scores, best_episode_id=ep_ids[best_idx])
|
| 214 |
|
| 215 |
print("\n[llm-demo] Results:")
|
| 216 |
-
for i, result in enumerate(results):
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
print("\n[llm-demo] Press Ctrl+C to stop.\n")
|
| 220 |
|
| 221 |
|
| 222 |
async def _main() -> None:
|
|
|
|
| 6 |
ANTHROPIC_API_KEY=sk-... python -m training.demo_llm
|
| 7 |
|
| 8 |
Each of the 8 episodes calls Claude (claude-haiku-4-5) once per fold step.
|
| 9 |
+
Claude receives the current paper state (metrics + fold history) and decides
|
| 10 |
+
the next fold action. Episodes run concurrently; all stream to the grid viewer.
|
| 11 |
"""
|
| 12 |
from __future__ import annotations
|
| 13 |
|
|
|
|
| 32 |
MODEL = "claude-haiku-4-5-20251001"
|
| 33 |
|
| 34 |
|
| 35 |
+
# ── System prompt ─────────────────────────────────────────────────────────────
|
| 36 |
+
|
| 37 |
+
SYSTEM_PROMPT = """\
|
| 38 |
+
You are an origami folding agent controlling a robotic paper-folding system.
|
| 39 |
+
|
| 40 |
+
COORDINATE SYSTEM
|
| 41 |
+
- Paper starts as a flat sheet; coordinates are normalized to the sheet's original size.
|
| 42 |
+
- x=0 is the left edge, x=1 is the right edge.
|
| 43 |
+
- y=0 is the bottom edge, y=1 is the top edge.
|
| 44 |
+
- Fold line endpoints must be on or outside the paper boundary (0.0β1.0 range).
|
| 45 |
+
- A fold line that runs off the edge is fine β it just doesn't affect paper outside the sheet.
|
| 46 |
+
|
| 47 |
+
FOLD TYPES
|
| 48 |
+
- "valley": folds the paper toward you (creates a V crease when viewed from above).
|
| 49 |
+
- "mountain": folds the paper away from you (creates a ^ crease).
|
| 50 |
+
- "stop": you are satisfied β no more folds needed.
|
| 51 |
+
|
| 52 |
+
PHYSICS
|
| 53 |
+
- angle=180 means a fully flat fold (paper halved).
|
| 54 |
+
- Smaller angles (e.g. 90) create partial folds.
|
| 55 |
+
- Each fold updates compactness, bounding_box, and strain readings.
|
| 56 |
+
- Kawasaki/Maekawa violations indicate geometrically invalid crease patterns.
|
| 57 |
+
|
| 58 |
+
RESPONSE FORMAT β output ONLY valid JSON, no markdown, no explanation:
|
| 59 |
+
{"type": "valley", "line": {"start": [x, y], "end": [x, y]}, "angle": 180}
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
# Eight distinct approach hints — gives diversity across the parallel episodes.
|
| 63 |
+
APPROACH_HINTS = [
|
| 64 |
+
"Try a single clean horizontal fold at the exact midline.",
|
| 65 |
+
"Try a single clean vertical fold at the exact midline.",
|
| 66 |
+
"Use two folds: first horizontal then vertical, to create quarters.",
|
| 67 |
+
"Use a diagonal fold from one corner to the opposite.",
|
| 68 |
+
"Try folding at y=0.333 and y=0.667 to create thirds.",
|
| 69 |
+
"Try a single fold but vary the position slightly off-center to explore.",
|
| 70 |
+
"Use a mountain fold instead of valley for the primary crease.",
|
| 71 |
+
"Try to reach the target box in as few folds as possible β stop early if done.",
|
| 72 |
+
]
|
| 73 |
+
|
| 74 |
+
|
| 75 |
# ── LLM strategy factory ──────────────────────────────────────────────────────
|
| 76 |
|
| 77 |
def make_llm_strategy(client: anthropic.Anthropic, task: dict, episode_num: int):
|
| 78 |
+
"""Return a strategy_fn for one episode.
|
| 79 |
+
|
| 80 |
+
Each episode has its own conversation history (multi-turn) and a unique
|
| 81 |
+
approach hint so the 8 concurrent episodes explore different strategies.
|
| 82 |
+
"""
|
| 83 |
history: list[dict[str, Any]] = []
|
| 84 |
+
hint = APPROACH_HINTS[episode_num % len(APPROACH_HINTS)]
|
| 85 |
+
prev_compactness: list[float] = [0.0] # mutable cell for delta tracking
|
| 86 |
|
| 87 |
+
def strategy(paper_state: dict, fold_history: list[dict]) -> dict:
|
| 88 |
fold_count = paper_state.get("fold_count", 0)
|
| 89 |
+
compactness = float(paper_state.get("compactness", 0))
|
| 90 |
bb = paper_state.get("bounding_box", [1, 1, 0])
|
| 91 |
+
fits = paper_state.get("fits_target_box", False)
|
| 92 |
+
strain = paper_state.get("max_strain", 0.0)
|
| 93 |
+
kaw = paper_state.get("kawasaki_violations", 0)
|
| 94 |
target_box = task.get("target_box", [1, 0.5, 0.02])
|
| 95 |
max_folds = task.get("max_folds", 3)
|
| 96 |
|
| 97 |
+
delta = compactness - prev_compactness[0]
|
| 98 |
+
prev_compactness[0] = compactness
|
| 99 |
+
|
| 100 |
+
# Summarise what has been done so far
|
| 101 |
+
history_lines = ""
|
| 102 |
+
if fold_history:
|
| 103 |
+
history_lines = "Folds applied so far:\n"
|
| 104 |
+
for i, f in enumerate(fold_history, 1):
|
| 105 |
+
t = f.get("type", "?")
|
| 106 |
+
ln = f.get("line", {})
|
| 107 |
+
s = ln.get("start", [0, 0])
|
| 108 |
+
e = ln.get("end", [1, 1])
|
| 109 |
+
ang = f.get("angle", 180)
|
| 110 |
+
history_lines += (
|
| 111 |
+
f" {i}. {t} fold "
|
| 112 |
+
f"from ({s[0]:.3f},{s[1]:.3f}) to ({e[0]:.3f},{e[1]:.3f}) "
|
| 113 |
+
f"angle={ang}\n"
|
| 114 |
+
)
|
| 115 |
+
else:
|
| 116 |
+
history_lines = "No folds applied yet β paper is flat.\n"
|
| 117 |
+
|
| 118 |
+
sign = "+" if delta >= 0 else ""
|
| 119 |
+
user_msg = (
|
| 120 |
+
f"Task: {task['description']}\n"
|
| 121 |
+
f"Sheet: {task['width']}Γ{task['height']} {task['material']}\n"
|
| 122 |
+
f"Target bounding box: {target_box} (must fit inside to succeed)\n"
|
| 123 |
+
f"Max folds remaining: {max_folds - fold_count}\n"
|
| 124 |
+
f"\n"
|
| 125 |
+
f"{history_lines}"
|
| 126 |
+
f"\n"
|
| 127 |
+
f"Current state after fold {fold_count}/{max_folds}:\n"
|
| 128 |
+
f" compactness : {compactness:.4f} (Ξ {sign}{delta:.4f})\n"
|
| 129 |
+
f" bounding_box: [{bb[0]:.4f}, {bb[1]:.4f}, {bb[2]:.5f}]\n"
|
| 130 |
+
f" fits_target : {'YES β' if fits else 'no'}\n"
|
| 131 |
+
f" max_strain : {strain:.5f}\n"
|
| 132 |
+
f" kaw_violations: {kaw}\n"
|
| 133 |
+
f"\n"
|
| 134 |
+
f"Approach hint: {hint}\n"
|
| 135 |
+
f"\n"
|
| 136 |
+
f"What is your next fold action? "
|
| 137 |
+
f"Return \"stop\" if the target is already achieved or no useful fold remains."
|
| 138 |
+
)
|
| 139 |
|
| 140 |
history.append({"role": "user", "content": user_msg})
|
| 141 |
|
| 142 |
response = client.messages.create(
|
| 143 |
model=MODEL,
|
| 144 |
+
max_tokens=150,
|
| 145 |
+
system=SYSTEM_PROMPT,
|
| 146 |
messages=history,
|
| 147 |
)
|
| 148 |
reply = response.content[0].text.strip()
|
| 149 |
history.append({"role": "assistant", "content": reply})
|
| 150 |
|
| 151 |
+
# Handle explicit "stop" text before JSON parse
|
| 152 |
+
if reply.lower().startswith("stop") or '"type": "stop"' in reply:
|
| 153 |
+
return {"type": "stop", "line": {"start": [0, 0.5], "end": [1, 0.5]}, "angle": 0.0}
|
| 154 |
+
|
| 155 |
+
# Extract JSON — handles markdown code fences
|
| 156 |
match = re.search(r'\{[^{}]+\}', reply, re.DOTALL)
|
| 157 |
if not match:
|
| 158 |
+
# Malformed response — default safe fold then stop next turn
|
| 159 |
+
return {"type": "valley", "line": {"start": [0.0, 0.5], "end": [1.0, 0.5]}, "angle": 180.0}
|
| 160 |
+
|
| 161 |
+
try:
|
| 162 |
+
fold_dict = json.loads(match.group())
|
| 163 |
+
except json.JSONDecodeError:
|
| 164 |
+
return {"type": "valley", "line": {"start": [0.0, 0.5], "end": [1.0, 0.5]}, "angle": 180.0}
|
| 165 |
|
|
|
|
|
|
|
| 166 |
fold_dict.setdefault("type", "valley")
|
| 167 |
fold_dict.setdefault("line", {"start": [0.0, 0.5], "end": [1.0, 0.5]})
|
| 168 |
fold_dict.setdefault("angle", 180.0)
|
|
|
|
| 198 |
if obs.done:
|
| 199 |
break
|
| 200 |
|
| 201 |
+
# Merge paper_state + metrics for the strategy
|
| 202 |
ps = dict(obs.paper_state)
|
| 203 |
+
ps.update(obs.metrics)
|
| 204 |
ps["fold_count"] = step_idx
|
| 205 |
|
| 206 |
try:
|
| 207 |
+
fold_dict = strategy_fn(ps, list(obs.fold_history))
|
| 208 |
except Exception as exc:
|
| 209 |
broadcast_fn(ep_id, {
|
| 210 |
"type": "episode_done", "episode_id": ep_id,
|
|
|
|
| 216 |
if fold_dict.get("type") == "stop":
|
| 217 |
break
|
| 218 |
|
| 219 |
+
time.sleep(0.5) # pace for viewer animation
|
| 220 |
|
| 221 |
action = OrigamiAction(
|
| 222 |
fold_type=fold_dict["type"],
|
|
|
|
| 275 |
|
| 276 |
await asyncio.sleep(1.5) # wait for server startup
|
| 277 |
|
| 278 |
+
print(f"\n[llm-demo] Model : {MODEL}")
|
| 279 |
+
print(f"[llm-demo] Task : {TASK_NAME} β {task['description']}")
|
| 280 |
+
print(f"[llm-demo] Open : http://localhost:9001/viewer/training.html\n")
|
| 281 |
+
print(f"[llm-demo] Episodes:")
|
| 282 |
+
for i, hint in enumerate(APPROACH_HINTS):
|
| 283 |
+
print(f" ep_{i:02d} hint: {hint}")
|
| 284 |
+
print()
|
| 285 |
|
| 286 |
await broadcast.start_batch(1, NUM_EPISODES)
|
| 287 |
|
| 288 |
ep_ids = [f"ep_{i:02d}" for i in range(NUM_EPISODES)]
|
| 289 |
strategies = [make_llm_strategy(client, task, i) for i in range(NUM_EPISODES)]
|
| 290 |
|
|
|
|
| 291 |
results = await asyncio.gather(*[
|
| 292 |
asyncio.to_thread(run_episode_llm, fn, TASK_NAME, ep_id, broadcast.publish)
|
| 293 |
for fn, ep_id in zip(strategies, ep_ids)
|
|
|
|
| 299 |
await broadcast.finish_batch(1, scores, best_episode_id=ep_ids[best_idx])
|
| 300 |
|
| 301 |
print("\n[llm-demo] Results:")
|
| 302 |
+
for i, (result, hint) in enumerate(zip(results, APPROACH_HINTS)):
|
| 303 |
+
marker = " β best" if i == best_idx else ""
|
| 304 |
+
print(f" ep_{i:02d} score={result['score']:+.2f} status={result['status']} hint: {hint}{marker}")
|
| 305 |
+
print(f"\n[llm-demo] Press Ctrl+C to stop.\n")
|
| 306 |
|
| 307 |
|
| 308 |
async def _main() -> None:
|