Spaces:
Running
Running
Commit ·
bc52096
1
Parent(s): 654687c
Add RL training environment with OpenEnv backend
Browse files- origami_server/: Full OpenEnv environment with physics simulator
- engine/: FOLD parser, analytical fold simulator (BFS + rotation
transforms), chamfer-distance shape matching
- environment.py: OrigamiEnvironment (reset/step/state)
- tasks.py: 7 tasks (triangle, half_fold, quarter_fold, letter_fold,
waterbomb, accordion, miura_ori)
- app.py: FastAPI server with /tasks endpoints
- training/: GRPO training scripts for Colab
- reward.py: valid_fold + shape_match reward functions
- train_grpo.py: Unsloth + TRL GRPOTrainer pipeline
- Dual-service Docker: Next.js UI on :7860, OpenEnv API on :8000
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- .dockerignore +1 -0
- Dockerfile +29 -14
- README.md +1 -1
- origami_server/__init__.py +0 -0
- origami_server/app.py +62 -0
- origami_server/engine/__init__.py +3 -0
- origami_server/engine/fold_parser.py +176 -0
- origami_server/engine/shape_match.py +95 -0
- origami_server/engine/simulate.py +150 -0
- origami_server/environment.py +158 -0
- origami_server/models.py +48 -0
- origami_server/tasks.py +218 -0
- requirements.txt +6 -0
- start.sh +6 -0
- training/__init__.py +0 -0
- training/reward.py +119 -0
- training/train_grpo.py +128 -0
.dockerignore
CHANGED
|
@@ -3,3 +3,4 @@ node_modules
|
|
| 3 |
.git
|
| 4 |
__pycache__
|
| 5 |
*.md
|
|
|
|
|
|
| 3 |
.git
|
| 4 |
__pycache__
|
| 5 |
*.md
|
| 6 |
+
outputs/
|
Dockerfile
CHANGED
|
@@ -1,39 +1,54 @@
|
|
| 1 |
FROM node:20-alpine AS base
|
| 2 |
|
| 3 |
-
# ---
|
| 4 |
FROM base AS deps
|
| 5 |
WORKDIR /app
|
| 6 |
COPY package.json package-lock.json ./
|
| 7 |
RUN npm ci
|
| 8 |
|
| 9 |
-
# ---
|
| 10 |
FROM base AS builder
|
| 11 |
WORKDIR /app
|
| 12 |
COPY --from=deps /app/node_modules ./node_modules
|
| 13 |
-
COPY
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
ENV NEXT_TELEMETRY_DISABLED=1
|
| 16 |
RUN npm run build
|
| 17 |
|
| 18 |
-
# ---
|
| 19 |
-
FROM
|
| 20 |
WORKDIR /app
|
| 21 |
|
|
|
|
|
|
|
|
|
|
| 22 |
ENV NODE_ENV=production
|
| 23 |
ENV NEXT_TELEMETRY_DISABLED=1
|
| 24 |
ENV PORT=7860
|
| 25 |
ENV HOSTNAME=0.0.0.0
|
| 26 |
|
| 27 |
-
|
| 28 |
-
RUN adduser --system --uid 1001 nextjs
|
| 29 |
-
|
| 30 |
COPY --from=builder /app/public ./public
|
| 31 |
-
COPY --from=builder
|
| 32 |
-
COPY --from=builder
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
-
|
|
|
|
|
|
|
| 35 |
|
| 36 |
EXPOSE 7860
|
|
|
|
| 37 |
|
| 38 |
-
|
| 39 |
-
CMD ["node", "server.js"]
|
|
|
|
| 1 |
FROM node:20-alpine AS base
|
| 2 |
|
| 3 |
+
# --- Node dependencies ---
|
| 4 |
FROM base AS deps
|
| 5 |
WORKDIR /app
|
| 6 |
COPY package.json package-lock.json ./
|
| 7 |
RUN npm ci
|
| 8 |
|
| 9 |
+
# --- Next.js build ---
|
| 10 |
FROM base AS builder
|
| 11 |
WORKDIR /app
|
| 12 |
COPY --from=deps /app/node_modules ./node_modules
|
| 13 |
+
COPY app/ ./app/
|
| 14 |
+
COPY components/ ./components/
|
| 15 |
+
COPY hooks/ ./hooks/
|
| 16 |
+
COPY lib/ ./lib/
|
| 17 |
+
COPY public/ ./public/
|
| 18 |
+
COPY package.json package-lock.json tsconfig.json next.config.ts postcss.config.mjs eslint.config.mjs .eslintrc.json next-env.d.ts metadata.json ./
|
| 19 |
ENV NEXT_TELEMETRY_DISABLED=1
|
| 20 |
RUN npm run build
|
| 21 |
|
| 22 |
+
# --- Final runner: Node + Python ---
|
| 23 |
+
FROM node:20-alpine AS runner
|
| 24 |
WORKDIR /app
|
| 25 |
|
| 26 |
+
# Install Python + pip for the OpenEnv API
|
| 27 |
+
RUN apk add --no-cache python3 py3-pip py3-numpy py3-scipy
|
| 28 |
+
|
| 29 |
ENV NODE_ENV=production
|
| 30 |
ENV NEXT_TELEMETRY_DISABLED=1
|
| 31 |
ENV PORT=7860
|
| 32 |
ENV HOSTNAME=0.0.0.0
|
| 33 |
|
| 34 |
+
# Copy Next.js standalone build
|
|
|
|
|
|
|
| 35 |
COPY --from=builder /app/public ./public
|
| 36 |
+
COPY --from=builder /app/.next/standalone ./
|
| 37 |
+
COPY --from=builder /app/.next/static ./.next/static
|
| 38 |
+
|
| 39 |
+
# Copy Python RL environment
|
| 40 |
+
COPY origami_server/ ./origami_server/
|
| 41 |
+
COPY training/ ./training/
|
| 42 |
+
COPY requirements.txt ./
|
| 43 |
+
|
| 44 |
+
# Install Python deps
|
| 45 |
+
RUN pip3 install --no-cache-dir --break-system-packages -r requirements.txt
|
| 46 |
|
| 47 |
+
# Start script: run both services
|
| 48 |
+
COPY start.sh ./
|
| 49 |
+
RUN chmod +x start.sh
|
| 50 |
|
| 51 |
EXPOSE 7860
|
| 52 |
+
EXPOSE 8000
|
| 53 |
|
| 54 |
+
CMD ["./start.sh"]
|
|
|
README.md
CHANGED
|
@@ -7,5 +7,5 @@ sdk: docker
|
|
| 7 |
app_port: 7860
|
| 8 |
pinned: false
|
| 9 |
license: mit
|
| 10 |
-
short_description: Interactive 3D origami simulator
|
| 11 |
---
|
|
|
|
| 7 |
app_port: 7860
|
| 8 |
pinned: false
|
| 9 |
license: mit
|
| 10 |
+
short_description: Interactive 3D origami simulator + RL training environment
|
| 11 |
---
|
origami_server/__init__.py
ADDED
|
File without changes
|
origami_server/app.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FastAPI entry point — OpenEnv create_app() + custom endpoints."""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
from fastapi.responses import HTMLResponse
|
| 7 |
+
from fastapi.staticfiles import StaticFiles
|
| 8 |
+
|
| 9 |
+
from openenv.core.env_server.http_server import create_app
|
| 10 |
+
|
| 11 |
+
from .environment import OrigamiEnvironment
|
| 12 |
+
from .models import OrigamiAction, OrigamiObservation
|
| 13 |
+
|
| 14 |
+
app = create_app(
|
| 15 |
+
OrigamiEnvironment,
|
| 16 |
+
OrigamiAction,
|
| 17 |
+
OrigamiObservation,
|
| 18 |
+
env_name="origami_env",
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
from .tasks import TASKS
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@app.get("/tasks")
|
| 25 |
+
def get_tasks():
|
| 26 |
+
return {
|
| 27 |
+
name: {
|
| 28 |
+
"name": task["name"],
|
| 29 |
+
"description": task["description"],
|
| 30 |
+
"difficulty": task["difficulty"],
|
| 31 |
+
"paper": task["paper"],
|
| 32 |
+
"target_fold": task["target_fold"],
|
| 33 |
+
}
|
| 34 |
+
for name, task in TASKS.items()
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@app.get("/tasks/{task_name}")
|
| 39 |
+
def get_task_detail(task_name: str):
|
| 40 |
+
if task_name not in TASKS:
|
| 41 |
+
from fastapi import HTTPException
|
| 42 |
+
|
| 43 |
+
raise HTTPException(status_code=404, detail=f"Task '{task_name}' not found")
|
| 44 |
+
task = TASKS[task_name]
|
| 45 |
+
return {
|
| 46 |
+
"name": task["name"],
|
| 47 |
+
"description": task["description"],
|
| 48 |
+
"difficulty": task["difficulty"],
|
| 49 |
+
"paper": task["paper"],
|
| 50 |
+
"target_fold": task["target_fold"],
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def main():
|
| 55 |
+
import uvicorn
|
| 56 |
+
|
| 57 |
+
port = int(os.environ.get("PORT", 8000))
|
| 58 |
+
uvicorn.run(app, host="0.0.0.0", port=port)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
if __name__ == "__main__":
|
| 62 |
+
main()
|
origami_server/engine/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .simulate import simulate
|
| 2 |
+
from .fold_parser import parse_fold, validate_fold
|
| 3 |
+
from .shape_match import compute_shape_match
|
origami_server/engine/fold_parser.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FOLD JSON parsing and validation.
|
| 2 |
+
|
| 3 |
+
Validates LLM-generated FOLD crease patterns before simulation.
|
| 4 |
+
FOLD spec: https://github.com/edemaine/fold
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from typing import Any
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def validate_fold(fold_data: dict[str, Any]) -> tuple[bool, str]:
|
| 13 |
+
"""Validate a FOLD JSON object. Returns (is_valid, error_message)."""
|
| 14 |
+
|
| 15 |
+
# Required fields
|
| 16 |
+
for key in ("vertices_coords", "edges_vertices", "edges_assignment"):
|
| 17 |
+
if key not in fold_data:
|
| 18 |
+
return False, f"Missing required field: {key}"
|
| 19 |
+
|
| 20 |
+
verts = fold_data["vertices_coords"]
|
| 21 |
+
edges = fold_data["edges_vertices"]
|
| 22 |
+
assignments = fold_data["edges_assignment"]
|
| 23 |
+
|
| 24 |
+
# Must have at least 3 vertices (a triangle)
|
| 25 |
+
if len(verts) < 3:
|
| 26 |
+
return False, f"Need at least 3 vertices, got {len(verts)}"
|
| 27 |
+
|
| 28 |
+
# Must have at least 3 edges
|
| 29 |
+
if len(edges) < 3:
|
| 30 |
+
return False, f"Need at least 3 edges, got {len(edges)}"
|
| 31 |
+
|
| 32 |
+
# Edges and assignments must match length
|
| 33 |
+
if len(edges) != len(assignments):
|
| 34 |
+
return False, (
|
| 35 |
+
f"edges_vertices ({len(edges)}) and "
|
| 36 |
+
f"edges_assignment ({len(assignments)}) must match length"
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
# Fold angles must match if present
|
| 40 |
+
if "edges_foldAngle" in fold_data:
|
| 41 |
+
angles = fold_data["edges_foldAngle"]
|
| 42 |
+
if len(angles) != len(edges):
|
| 43 |
+
return False, (
|
| 44 |
+
f"edges_foldAngle ({len(angles)}) must match "
|
| 45 |
+
f"edges_vertices ({len(edges)})"
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
# Validate vertex coordinates (2D or 3D)
|
| 49 |
+
num_verts = len(verts)
|
| 50 |
+
for i, v in enumerate(verts):
|
| 51 |
+
if not isinstance(v, (list, tuple)) or len(v) < 2:
|
| 52 |
+
return False, f"Vertex {i} must be [x, y] or [x, y, z], got {v}"
|
| 53 |
+
|
| 54 |
+
# Validate edge indices
|
| 55 |
+
for i, e in enumerate(edges):
|
| 56 |
+
if not isinstance(e, (list, tuple)) or len(e) != 2:
|
| 57 |
+
return False, f"Edge {i} must be [v1, v2], got {e}"
|
| 58 |
+
v1, v2 = e
|
| 59 |
+
if v1 < 0 or v1 >= num_verts or v2 < 0 or v2 >= num_verts:
|
| 60 |
+
return False, f"Edge {i} references invalid vertex: {e}"
|
| 61 |
+
if v1 == v2:
|
| 62 |
+
return False, f"Edge {i} is degenerate (same vertex): {e}"
|
| 63 |
+
|
| 64 |
+
# Validate assignments
|
| 65 |
+
valid_assignments = {"M", "V", "B", "F", "U", "C"}
|
| 66 |
+
for i, a in enumerate(assignments):
|
| 67 |
+
if a not in valid_assignments:
|
| 68 |
+
return False, f"Edge {i} has invalid assignment '{a}'"
|
| 69 |
+
|
| 70 |
+
# Must have at least one fold crease (M or V)
|
| 71 |
+
has_fold = any(a in ("M", "V") for a in assignments)
|
| 72 |
+
if not has_fold:
|
| 73 |
+
return False, "No fold creases (M or V) found"
|
| 74 |
+
|
| 75 |
+
# Must have boundary edges
|
| 76 |
+
has_boundary = any(a == "B" for a in assignments)
|
| 77 |
+
if not has_boundary:
|
| 78 |
+
return False, "No boundary edges (B) found"
|
| 79 |
+
|
| 80 |
+
return True, ""
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def parse_fold(fold_data: dict[str, Any]) -> dict[str, np.ndarray]:
|
| 84 |
+
"""Parse validated FOLD JSON into numpy arrays for simulation.
|
| 85 |
+
|
| 86 |
+
Returns dict with:
|
| 87 |
+
vertices: (N, 3) float64 — vertex positions (z=0 for 2D input)
|
| 88 |
+
edges: (E, 2) int — edge vertex indices
|
| 89 |
+
assignments: list[str] — edge type per edge
|
| 90 |
+
fold_angles: (E,) float64 — target fold angle per edge (radians)
|
| 91 |
+
faces: (F, 3) int — triangulated face vertex indices
|
| 92 |
+
"""
|
| 93 |
+
verts = fold_data["vertices_coords"]
|
| 94 |
+
|
| 95 |
+
# Ensure 3D (add z=0 if 2D)
|
| 96 |
+
vertices = np.zeros((len(verts), 3), dtype=np.float64)
|
| 97 |
+
for i, v in enumerate(verts):
|
| 98 |
+
vertices[i, 0] = v[0]
|
| 99 |
+
vertices[i, 1] = v[1]
|
| 100 |
+
if len(v) > 2:
|
| 101 |
+
vertices[i, 2] = v[2]
|
| 102 |
+
|
| 103 |
+
edges = np.array(fold_data["edges_vertices"], dtype=np.int32)
|
| 104 |
+
assignments = list(fold_data["edges_assignment"])
|
| 105 |
+
|
| 106 |
+
# Fold angles: default based on assignment if not provided
|
| 107 |
+
if "edges_foldAngle" in fold_data:
|
| 108 |
+
fold_angles = np.array(fold_data["edges_foldAngle"], dtype=np.float64)
|
| 109 |
+
else:
|
| 110 |
+
fold_angles = np.zeros(len(edges), dtype=np.float64)
|
| 111 |
+
for i, a in enumerate(assignments):
|
| 112 |
+
if a == "V":
|
| 113 |
+
fold_angles[i] = 180.0
|
| 114 |
+
elif a == "M":
|
| 115 |
+
fold_angles[i] = -180.0
|
| 116 |
+
|
| 117 |
+
# Convert degrees to radians for simulation
|
| 118 |
+
fold_angles_rad = np.radians(fold_angles)
|
| 119 |
+
|
| 120 |
+
# Triangulate faces
|
| 121 |
+
if "faces_vertices" in fold_data:
|
| 122 |
+
raw_faces = fold_data["faces_vertices"]
|
| 123 |
+
faces = _triangulate_faces(raw_faces)
|
| 124 |
+
else:
|
| 125 |
+
faces = _compute_faces(vertices, edges)
|
| 126 |
+
|
| 127 |
+
return {
|
| 128 |
+
"vertices": vertices,
|
| 129 |
+
"edges": edges,
|
| 130 |
+
"assignments": assignments,
|
| 131 |
+
"fold_angles": fold_angles_rad,
|
| 132 |
+
"faces": faces,
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def _triangulate_faces(raw_faces: list[list[int]]) -> np.ndarray:
|
| 137 |
+
"""Fan-triangulate polygon faces into triangles."""
|
| 138 |
+
triangles = []
|
| 139 |
+
for face in raw_faces:
|
| 140 |
+
if len(face) < 3:
|
| 141 |
+
continue
|
| 142 |
+
for i in range(1, len(face) - 1):
|
| 143 |
+
triangles.append([face[0], face[i], face[i + 1]])
|
| 144 |
+
if not triangles:
|
| 145 |
+
return np.zeros((0, 3), dtype=np.int32)
|
| 146 |
+
return np.array(triangles, dtype=np.int32)
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def _compute_faces(vertices: np.ndarray, edges: np.ndarray) -> np.ndarray:
|
| 150 |
+
"""Compute triangulated faces from vertices and edges using adjacency."""
|
| 151 |
+
from collections import defaultdict
|
| 152 |
+
|
| 153 |
+
adj = defaultdict(set)
|
| 154 |
+
for v1, v2 in edges:
|
| 155 |
+
adj[v1].add(v2)
|
| 156 |
+
adj[v2].add(v1)
|
| 157 |
+
|
| 158 |
+
triangles = set()
|
| 159 |
+
for v1, v2 in edges:
|
| 160 |
+
common = adj[v1] & adj[v2]
|
| 161 |
+
for v3 in common:
|
| 162 |
+
tri = tuple(sorted([v1, v2, v3]))
|
| 163 |
+
triangles.add(tri)
|
| 164 |
+
|
| 165 |
+
if not triangles:
|
| 166 |
+
# Fallback: create faces using Delaunay on 2D projection
|
| 167 |
+
from scipy.spatial import Delaunay
|
| 168 |
+
|
| 169 |
+
pts_2d = vertices[:, :2]
|
| 170 |
+
try:
|
| 171 |
+
tri = Delaunay(pts_2d)
|
| 172 |
+
return tri.simplices.astype(np.int32)
|
| 173 |
+
except Exception:
|
| 174 |
+
return np.zeros((0, 3), dtype=np.int32)
|
| 175 |
+
|
| 176 |
+
return np.array(list(triangles), dtype=np.int32)
|
origami_server/engine/shape_match.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Shape matching for reward computation.
|
| 2 |
+
|
| 3 |
+
Computes similarity between the LLM's folded shape and the target shape.
|
| 4 |
+
Like AlphaFold's RMSD but for origami vertex positions.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
from scipy.spatial.distance import cdist
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def compute_shape_match(
|
| 12 |
+
predicted: np.ndarray,
|
| 13 |
+
target: np.ndarray,
|
| 14 |
+
) -> float:
|
| 15 |
+
"""Compute shape similarity between predicted and target positions.
|
| 16 |
+
|
| 17 |
+
Uses chamfer distance normalized by bounding box diagonal.
|
| 18 |
+
Aligns shapes by centering before comparison.
|
| 19 |
+
|
| 20 |
+
Args:
|
| 21 |
+
predicted: (N, 3) predicted vertex positions.
|
| 22 |
+
target: (M, 3) target vertex positions.
|
| 23 |
+
|
| 24 |
+
Returns:
|
| 25 |
+
Similarity score in [0, 1]. 1.0 = perfect match.
|
| 26 |
+
"""
|
| 27 |
+
if len(predicted) == 0 or len(target) == 0:
|
| 28 |
+
return 0.0
|
| 29 |
+
|
| 30 |
+
# Center both point clouds
|
| 31 |
+
pred_centered = predicted - predicted.mean(axis=0)
|
| 32 |
+
target_centered = target - target.mean(axis=0)
|
| 33 |
+
|
| 34 |
+
# Try multiple rotations and pick best match
|
| 35 |
+
best_score = 0.0
|
| 36 |
+
for rotation in _get_alignment_rotations():
|
| 37 |
+
rotated = pred_centered @ rotation.T
|
| 38 |
+
score = _chamfer_similarity(rotated, target_centered)
|
| 39 |
+
best_score = max(best_score, score)
|
| 40 |
+
|
| 41 |
+
return best_score
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _chamfer_similarity(a: np.ndarray, b: np.ndarray) -> float:
|
| 45 |
+
"""Chamfer distance converted to similarity score."""
|
| 46 |
+
d = cdist(a, b)
|
| 47 |
+
|
| 48 |
+
# Forward: for each point in a, min distance to b
|
| 49 |
+
forward = d.min(axis=1).mean()
|
| 50 |
+
# Backward: for each point in b, min distance to a
|
| 51 |
+
backward = d.min(axis=0).mean()
|
| 52 |
+
chamfer = (forward + backward) / 2.0
|
| 53 |
+
|
| 54 |
+
# Normalize by bounding box diagonal of target
|
| 55 |
+
all_pts = np.vstack([a, b])
|
| 56 |
+
bbox_diag = np.linalg.norm(all_pts.max(axis=0) - all_pts.min(axis=0))
|
| 57 |
+
if bbox_diag < 1e-12:
|
| 58 |
+
return 1.0 if chamfer < 1e-12 else 0.0
|
| 59 |
+
|
| 60 |
+
similarity = max(0.0, 1.0 - chamfer / bbox_diag)
|
| 61 |
+
return similarity
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def _get_alignment_rotations() -> list[np.ndarray]:
|
| 65 |
+
"""Generate rotation matrices for alignment search.
|
| 66 |
+
|
| 67 |
+
Identity + 90 deg rotations around each axis + mirrors (15 total).
|
| 68 |
+
"""
|
| 69 |
+
I = np.eye(3)
|
| 70 |
+
rotations = [I]
|
| 71 |
+
|
| 72 |
+
# 90 deg rotations around Z axis
|
| 73 |
+
for k in range(1, 4):
|
| 74 |
+
angle = k * np.pi / 2
|
| 75 |
+
c, s = np.cos(angle), np.sin(angle)
|
| 76 |
+
rotations.append(np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]]))
|
| 77 |
+
|
| 78 |
+
# 90 deg rotations around X axis
|
| 79 |
+
for k in range(1, 4):
|
| 80 |
+
angle = k * np.pi / 2
|
| 81 |
+
c, s = np.cos(angle), np.sin(angle)
|
| 82 |
+
rotations.append(np.array([[1, 0, 0], [0, c, -s], [0, s, c]]))
|
| 83 |
+
|
| 84 |
+
# 90 deg rotations around Y axis
|
| 85 |
+
for k in range(1, 4):
|
| 86 |
+
angle = k * np.pi / 2
|
| 87 |
+
c, s = np.cos(angle), np.sin(angle)
|
| 88 |
+
rotations.append(np.array([[c, 0, s], [0, 1, 0], [-s, 0, c]]))
|
| 89 |
+
|
| 90 |
+
# Mirrors
|
| 91 |
+
rotations.append(np.diag([-1.0, 1.0, 1.0]))
|
| 92 |
+
rotations.append(np.diag([1.0, -1.0, 1.0]))
|
| 93 |
+
rotations.append(np.diag([1.0, 1.0, -1.0]))
|
| 94 |
+
|
| 95 |
+
return rotations
|
origami_server/engine/simulate.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Origami fold simulator — analytical rotation with cumulative transforms.
|
| 2 |
+
|
| 3 |
+
BFS from face 0 through the face adjacency graph. Each face accumulates
|
| 4 |
+
a rotation transform (R, t) such that: folded_pos = R @ flat_pos + t.
|
| 5 |
+
When crossing a fold edge, the fold rotation is composed with the parent
|
| 6 |
+
face's transform. Non-fold edges inherit the parent's transform directly.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from dataclasses import dataclass
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
from scipy.spatial.transform import Rotation
|
| 13 |
+
|
| 14 |
+
from .fold_parser import parse_fold
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@dataclass
|
| 18 |
+
class SimResult:
|
| 19 |
+
"""Result of a fold simulation."""
|
| 20 |
+
|
| 21 |
+
positions: np.ndarray # (N, 3) final vertex positions
|
| 22 |
+
converged: bool
|
| 23 |
+
steps_taken: int
|
| 24 |
+
max_strain: float
|
| 25 |
+
total_energy: float
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def simulate(
|
| 29 |
+
fold_data: dict,
|
| 30 |
+
crease_percent: float = 1.0,
|
| 31 |
+
max_steps: int = 500,
|
| 32 |
+
params: dict | None = None,
|
| 33 |
+
) -> SimResult:
|
| 34 |
+
"""Simulate a FOLD crease pattern and return final 3D positions."""
|
| 35 |
+
parsed = parse_fold(fold_data)
|
| 36 |
+
flat_pos = parsed["vertices"].copy()
|
| 37 |
+
edges = parsed["edges"]
|
| 38 |
+
assignments = parsed["assignments"]
|
| 39 |
+
fold_angles = parsed["fold_angles"]
|
| 40 |
+
faces = parsed["faces"]
|
| 41 |
+
positions = flat_pos.copy()
|
| 42 |
+
|
| 43 |
+
if len(faces) == 0:
|
| 44 |
+
return SimResult(
|
| 45 |
+
positions=positions, converged=True,
|
| 46 |
+
steps_taken=0, max_strain=0.0, total_energy=0.0,
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
face_adj = _build_face_adjacency(faces)
|
| 50 |
+
|
| 51 |
+
crease_map: dict[tuple[int, int], float] = {}
|
| 52 |
+
for i, (v1, v2) in enumerate(edges):
|
| 53 |
+
key = (min(int(v1), int(v2)), max(int(v1), int(v2)))
|
| 54 |
+
if assignments[i] in ("M", "V"):
|
| 55 |
+
crease_map[key] = fold_angles[i] * crease_percent
|
| 56 |
+
|
| 57 |
+
n_faces = len(faces)
|
| 58 |
+
face_R = [None] * n_faces
|
| 59 |
+
face_t = [None] * n_faces
|
| 60 |
+
|
| 61 |
+
face_R[0] = np.eye(3)
|
| 62 |
+
face_t[0] = np.zeros(3)
|
| 63 |
+
|
| 64 |
+
visited = [False] * n_faces
|
| 65 |
+
visited[0] = True
|
| 66 |
+
|
| 67 |
+
placed: set[int] = set()
|
| 68 |
+
for vi in faces[0]:
|
| 69 |
+
placed.add(int(vi))
|
| 70 |
+
|
| 71 |
+
queue = [0]
|
| 72 |
+
while queue:
|
| 73 |
+
fi = queue.pop(0)
|
| 74 |
+
face = faces[fi]
|
| 75 |
+
|
| 76 |
+
for j in range(len(face)):
|
| 77 |
+
v1, v2 = int(face[j]), int(face[(j + 1) % len(face)])
|
| 78 |
+
edge_key = (min(v1, v2), max(v1, v2))
|
| 79 |
+
|
| 80 |
+
for fj in face_adj.get(edge_key, []):
|
| 81 |
+
if visited[fj]:
|
| 82 |
+
continue
|
| 83 |
+
visited[fj] = True
|
| 84 |
+
queue.append(fj)
|
| 85 |
+
|
| 86 |
+
angle = crease_map.get(edge_key, 0.0)
|
| 87 |
+
|
| 88 |
+
if abs(angle) > 1e-10:
|
| 89 |
+
p1 = positions[v1].copy()
|
| 90 |
+
axis = positions[v2] - p1
|
| 91 |
+
axis_len = np.linalg.norm(axis)
|
| 92 |
+
if axis_len > 1e-12:
|
| 93 |
+
axis_unit = axis / axis_len
|
| 94 |
+
fold_rot = Rotation.from_rotvec(
|
| 95 |
+
angle * axis_unit,
|
| 96 |
+
).as_matrix()
|
| 97 |
+
else:
|
| 98 |
+
fold_rot = np.eye(3)
|
| 99 |
+
|
| 100 |
+
face_R[fj] = fold_rot @ face_R[fi]
|
| 101 |
+
face_t[fj] = fold_rot @ (face_t[fi] - p1) + p1
|
| 102 |
+
else:
|
| 103 |
+
face_R[fj] = face_R[fi].copy()
|
| 104 |
+
face_t[fj] = face_t[fi].copy()
|
| 105 |
+
|
| 106 |
+
for vi in faces[fj]:
|
| 107 |
+
vi_int = int(vi)
|
| 108 |
+
if vi_int not in placed:
|
| 109 |
+
positions[vi_int] = face_R[fj] @ flat_pos[vi_int] + face_t[fj]
|
| 110 |
+
placed.add(vi_int)
|
| 111 |
+
|
| 112 |
+
max_strain = _compute_strain(positions, parsed)
|
| 113 |
+
|
| 114 |
+
return SimResult(
|
| 115 |
+
positions=positions,
|
| 116 |
+
converged=True,
|
| 117 |
+
steps_taken=1,
|
| 118 |
+
max_strain=max_strain,
|
| 119 |
+
total_energy=0.0,
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def _build_face_adjacency(
|
| 124 |
+
faces: np.ndarray,
|
| 125 |
+
) -> dict[tuple[int, int], list[int]]:
|
| 126 |
+
"""Map each edge (sorted vertex pair) to list of face indices."""
|
| 127 |
+
adj: dict[tuple[int, int], list[int]] = {}
|
| 128 |
+
for fi, face in enumerate(faces):
|
| 129 |
+
n = len(face)
|
| 130 |
+
for j in range(n):
|
| 131 |
+
v1, v2 = int(face[j]), int(face[(j + 1) % n])
|
| 132 |
+
key = (min(v1, v2), max(v1, v2))
|
| 133 |
+
if key not in adj:
|
| 134 |
+
adj[key] = []
|
| 135 |
+
adj[key].append(fi)
|
| 136 |
+
return adj
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def _compute_strain(positions: np.ndarray, parsed: dict) -> float:
|
| 140 |
+
"""Compute max axial strain across all edges."""
|
| 141 |
+
edges = parsed["edges"]
|
| 142 |
+
vertices_flat = parsed["vertices"]
|
| 143 |
+
max_strain = 0.0
|
| 144 |
+
for v1, v2 in edges:
|
| 145 |
+
rest = np.linalg.norm(vertices_flat[v2] - vertices_flat[v1])
|
| 146 |
+
curr = np.linalg.norm(positions[v2] - positions[v1])
|
| 147 |
+
if rest > 1e-12:
|
| 148 |
+
strain = abs(curr - rest) / rest
|
| 149 |
+
max_strain = max(max_strain, strain)
|
| 150 |
+
return max_strain
|
origami_server/environment.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Origami RL Environment — OpenEnv Environment subclass.
|
| 2 |
+
|
| 3 |
+
Single-shot episodes: LLM submits a FOLD crease pattern, physics simulates it,
|
| 4 |
+
reward = shape similarity to target. Like AlphaFold for origami.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import uuid
|
| 8 |
+
from typing import Any, Optional
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
from openenv.core import Environment
|
| 12 |
+
|
| 13 |
+
from .engine.fold_parser import validate_fold
|
| 14 |
+
from .engine.shape_match import compute_shape_match
|
| 15 |
+
from .engine.simulate import SimResult, simulate
|
| 16 |
+
from .models import OrigamiAction, OrigamiObservation, OrigamiState
|
| 17 |
+
from .tasks import get_task
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class OrigamiEnvironment(
|
| 21 |
+
Environment[OrigamiAction, OrigamiObservation, OrigamiState]
|
| 22 |
+
):
|
| 23 |
+
"""Origami folding environment.
|
| 24 |
+
|
| 25 |
+
Episode flow:
|
| 26 |
+
1. reset(task_name="triangle") -> returns task description + target info
|
| 27 |
+
2. step(OrigamiAction(fold_data={...})) -> simulates, scores, returns done=True
|
| 28 |
+
|
| 29 |
+
Single action per episode. The action IS the complete crease pattern.
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
SUPPORTS_CONCURRENT_SESSIONS = True
|
| 33 |
+
|
| 34 |
+
def __init__(self, **kwargs: Any):
|
| 35 |
+
super().__init__(**kwargs)
|
| 36 |
+
self._state = OrigamiState()
|
| 37 |
+
self._task: dict = {}
|
| 38 |
+
self._target_positions: np.ndarray = np.zeros((0, 3))
|
| 39 |
+
|
| 40 |
+
def reset(
|
| 41 |
+
self,
|
| 42 |
+
seed: Optional[int] = None,
|
| 43 |
+
episode_id: Optional[str] = None,
|
| 44 |
+
**kwargs: Any,
|
| 45 |
+
) -> OrigamiObservation:
|
| 46 |
+
"""Start a new episode with a target shape task."""
|
| 47 |
+
self._state = OrigamiState(
|
| 48 |
+
episode_id=episode_id or str(uuid.uuid4()),
|
| 49 |
+
step_count=0,
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
task_name = kwargs.get("task_name", "triangle")
|
| 53 |
+
self._task = get_task(task_name)
|
| 54 |
+
self._state.task_name = self._task["name"]
|
| 55 |
+
|
| 56 |
+
target_fold = self._task["target_fold"]
|
| 57 |
+
try:
|
| 58 |
+
target_result = simulate(target_fold, crease_percent=1.0)
|
| 59 |
+
self._target_positions = target_result.positions
|
| 60 |
+
except Exception:
|
| 61 |
+
self._target_positions = np.zeros((0, 3))
|
| 62 |
+
|
| 63 |
+
return OrigamiObservation(
|
| 64 |
+
done=False,
|
| 65 |
+
reward=None,
|
| 66 |
+
task=self._task_info(),
|
| 67 |
+
fold_data={},
|
| 68 |
+
final_positions=[],
|
| 69 |
+
target_positions=self._target_positions.tolist(),
|
| 70 |
+
shape_similarity=0.0,
|
| 71 |
+
max_strain=0.0,
|
| 72 |
+
is_stable=True,
|
| 73 |
+
error=None,
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
def step(
|
| 77 |
+
self,
|
| 78 |
+
action: OrigamiAction,
|
| 79 |
+
timeout_s: Optional[float] = None,
|
| 80 |
+
**kwargs: Any,
|
| 81 |
+
) -> OrigamiObservation:
|
| 82 |
+
"""Evaluate the LLM's crease pattern.
|
| 83 |
+
|
| 84 |
+
1. Validate FOLD data
|
| 85 |
+
2. Run physics simulation (creasePercent=1.0)
|
| 86 |
+
3. Compare final shape to target
|
| 87 |
+
4. Return observation with reward = similarity * 20
|
| 88 |
+
"""
|
| 89 |
+
self._state.step_count += 1
|
| 90 |
+
fold_data = action.fold_data
|
| 91 |
+
|
| 92 |
+
is_valid, error_msg = validate_fold(fold_data)
|
| 93 |
+
if not is_valid:
|
| 94 |
+
self._state.is_stable = False
|
| 95 |
+
return OrigamiObservation(
|
| 96 |
+
done=True,
|
| 97 |
+
reward=-2.0,
|
| 98 |
+
task=self._task_info(),
|
| 99 |
+
fold_data=fold_data,
|
| 100 |
+
final_positions=[],
|
| 101 |
+
target_positions=self._target_positions.tolist(),
|
| 102 |
+
shape_similarity=0.0,
|
| 103 |
+
max_strain=0.0,
|
| 104 |
+
is_stable=False,
|
| 105 |
+
error=f"Invalid FOLD data: {error_msg}",
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
try:
|
| 109 |
+
result: SimResult = simulate(fold_data, crease_percent=1.0)
|
| 110 |
+
except Exception as e:
|
| 111 |
+
self._state.is_stable = False
|
| 112 |
+
return OrigamiObservation(
|
| 113 |
+
done=True,
|
| 114 |
+
reward=-2.0,
|
| 115 |
+
task=self._task_info(),
|
| 116 |
+
fold_data=fold_data,
|
| 117 |
+
final_positions=[],
|
| 118 |
+
target_positions=self._target_positions.tolist(),
|
| 119 |
+
shape_similarity=0.0,
|
| 120 |
+
max_strain=0.0,
|
| 121 |
+
is_stable=False,
|
| 122 |
+
error=f"Simulation error: {str(e)}",
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
similarity = compute_shape_match(
|
| 126 |
+
result.positions, self._target_positions
|
| 127 |
+
)
|
| 128 |
+
reward = similarity * 20.0
|
| 129 |
+
|
| 130 |
+
self._state.shape_similarity = similarity
|
| 131 |
+
self._state.is_stable = result.converged
|
| 132 |
+
|
| 133 |
+
return OrigamiObservation(
|
| 134 |
+
done=True,
|
| 135 |
+
reward=reward,
|
| 136 |
+
task=self._task_info(),
|
| 137 |
+
fold_data=fold_data,
|
| 138 |
+
final_positions=result.positions.tolist(),
|
| 139 |
+
target_positions=self._target_positions.tolist(),
|
| 140 |
+
shape_similarity=similarity,
|
| 141 |
+
max_strain=result.max_strain,
|
| 142 |
+
is_stable=result.converged,
|
| 143 |
+
error=None,
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
@property
|
| 147 |
+
def state(self) -> OrigamiState:
|
| 148 |
+
return self._state
|
| 149 |
+
|
| 150 |
+
def _task_info(self) -> dict:
|
| 151 |
+
if not self._task:
|
| 152 |
+
return {}
|
| 153 |
+
return {
|
| 154 |
+
"name": self._task.get("name", ""),
|
| 155 |
+
"description": self._task.get("description", ""),
|
| 156 |
+
"difficulty": self._task.get("difficulty", 0),
|
| 157 |
+
"paper": self._task.get("paper", {}),
|
| 158 |
+
}
|
origami_server/models.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""OpenEnv types for the Origami RL environment.
|
| 2 |
+
|
| 3 |
+
OrigamiAction: LLM submits a FOLD crease pattern.
|
| 4 |
+
OrigamiObservation: Result of simulating that pattern against a target.
|
| 5 |
+
OrigamiState: Internal episode state.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from typing import Any, Optional
|
| 9 |
+
|
| 10 |
+
from openenv.core import Action, Observation, State
|
| 11 |
+
from pydantic import Field
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class OrigamiAction(Action):
|
| 15 |
+
"""LLM submits a FOLD crease pattern as its action.
|
| 16 |
+
|
| 17 |
+
The fold_data dict must contain:
|
| 18 |
+
- vertices_coords: [[x, y], ...] — 2D vertex positions on flat paper
|
| 19 |
+
- edges_vertices: [[v1, v2], ...] — edge connectivity
|
| 20 |
+
- edges_assignment: ["B"|"M"|"V", ...] — boundary/mountain/valley
|
| 21 |
+
- edges_foldAngle: [angle, ...] — target fold angles in degrees
|
| 22 |
+
(optional — defaults from assignment: M=-180, V=+180, B=0)
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
fold_data: dict[str, Any] = Field(
|
| 26 |
+
..., description="FOLD-format crease pattern JSON"
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class OrigamiObservation(Observation):
|
| 31 |
+
"""Result of simulating the LLM's crease pattern."""
|
| 32 |
+
|
| 33 |
+
task: dict[str, Any] = Field(default_factory=dict)
|
| 34 |
+
fold_data: dict[str, Any] = Field(default_factory=dict)
|
| 35 |
+
final_positions: list[list[float]] = Field(default_factory=list)
|
| 36 |
+
target_positions: list[list[float]] = Field(default_factory=list)
|
| 37 |
+
shape_similarity: float = 0.0
|
| 38 |
+
max_strain: float = 0.0
|
| 39 |
+
is_stable: bool = True
|
| 40 |
+
error: Optional[str] = None
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class OrigamiState(State):
|
| 44 |
+
"""Internal state for an origami episode."""
|
| 45 |
+
|
| 46 |
+
task_name: str = ""
|
| 47 |
+
shape_similarity: float = 0.0
|
| 48 |
+
is_stable: bool = True
|
origami_server/tasks.py
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Task definitions for origami RL training.
|
| 2 |
+
|
| 3 |
+
Each task defines a target shape as a reference FOLD crease pattern.
|
| 4 |
+
The LLM must discover a crease pattern that folds into the same shape.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
TASKS: dict[str, dict] = {
|
| 8 |
+
"triangle": {
|
| 9 |
+
"name": "triangle",
|
| 10 |
+
"description": "Fold the paper in half diagonally to make a triangle",
|
| 11 |
+
"difficulty": 1,
|
| 12 |
+
"paper": {"width": 1.0, "height": 1.0},
|
| 13 |
+
"target_fold": {
|
| 14 |
+
"vertices_coords": [[0, 0], [1, 0], [1, 1], [0, 1]],
|
| 15 |
+
"edges_vertices": [[0, 1], [1, 2], [2, 3], [3, 0], [0, 2]],
|
| 16 |
+
"edges_assignment": ["B", "B", "B", "B", "V"],
|
| 17 |
+
"edges_foldAngle": [0, 0, 0, 0, 180],
|
| 18 |
+
"faces_vertices": [[0, 1, 2], [0, 2, 3]],
|
| 19 |
+
},
|
| 20 |
+
},
|
| 21 |
+
"half_fold": {
|
| 22 |
+
"name": "half_fold",
|
| 23 |
+
"description": "Fold the paper in half horizontally",
|
| 24 |
+
"difficulty": 1,
|
| 25 |
+
"paper": {"width": 1.0, "height": 1.0},
|
| 26 |
+
"target_fold": {
|
| 27 |
+
"vertices_coords": [
|
| 28 |
+
[0, 0], [1, 0], [1, 1], [0, 1], [0, 0.5], [1, 0.5],
|
| 29 |
+
],
|
| 30 |
+
"edges_vertices": [
|
| 31 |
+
[0, 1], [1, 5], [5, 2], [2, 3], [3, 4], [4, 0],
|
| 32 |
+
[4, 5],
|
| 33 |
+
],
|
| 34 |
+
"edges_assignment": ["B", "B", "B", "B", "B", "B", "V"],
|
| 35 |
+
"edges_foldAngle": [0, 0, 0, 0, 0, 0, 180],
|
| 36 |
+
"faces_vertices": [[0, 1, 5, 4], [4, 5, 2, 3]],
|
| 37 |
+
},
|
| 38 |
+
},
|
| 39 |
+
"quarter_fold": {
|
| 40 |
+
"name": "quarter_fold",
|
| 41 |
+
"description": "Fold the paper into quarters (two perpendicular folds)",
|
| 42 |
+
"difficulty": 2,
|
| 43 |
+
"paper": {"width": 1.0, "height": 1.0},
|
| 44 |
+
"target_fold": {
|
| 45 |
+
"vertices_coords": [
|
| 46 |
+
[0, 0], [0.5, 0], [1, 0],
|
| 47 |
+
[0, 0.5], [0.5, 0.5], [1, 0.5],
|
| 48 |
+
[0, 1], [0.5, 1], [1, 1],
|
| 49 |
+
],
|
| 50 |
+
"edges_vertices": [
|
| 51 |
+
[0, 1], [1, 2], [2, 5], [5, 8], [8, 7], [7, 6], [6, 3], [3, 0],
|
| 52 |
+
[1, 4], [4, 7],
|
| 53 |
+
[3, 4], [4, 5],
|
| 54 |
+
],
|
| 55 |
+
"edges_assignment": [
|
| 56 |
+
"B", "B", "B", "B", "B", "B", "B", "B",
|
| 57 |
+
"V", "V", "V", "V",
|
| 58 |
+
],
|
| 59 |
+
"edges_foldAngle": [
|
| 60 |
+
0, 0, 0, 0, 0, 0, 0, 0,
|
| 61 |
+
180, 180, 180, 180,
|
| 62 |
+
],
|
| 63 |
+
"faces_vertices": [
|
| 64 |
+
[0, 1, 4, 3],
|
| 65 |
+
[1, 2, 5, 4],
|
| 66 |
+
[3, 4, 7, 6],
|
| 67 |
+
[4, 5, 8, 7],
|
| 68 |
+
],
|
| 69 |
+
},
|
| 70 |
+
},
|
| 71 |
+
"letter_fold": {
|
| 72 |
+
"name": "letter_fold",
|
| 73 |
+
"description": "Tri-fold the paper like a letter (two parallel folds)",
|
| 74 |
+
"difficulty": 2,
|
| 75 |
+
"paper": {"width": 1.0, "height": 1.0},
|
| 76 |
+
"target_fold": {
|
| 77 |
+
"vertices_coords": [
|
| 78 |
+
[0, 0], [1, 0],
|
| 79 |
+
[0, 1/3], [1, 1/3],
|
| 80 |
+
[0, 2/3], [1, 2/3],
|
| 81 |
+
[0, 1], [1, 1],
|
| 82 |
+
],
|
| 83 |
+
"edges_vertices": [
|
| 84 |
+
[0, 1], [1, 3], [3, 5], [5, 7], [7, 6], [6, 4], [4, 2], [2, 0],
|
| 85 |
+
[2, 3],
|
| 86 |
+
[4, 5],
|
| 87 |
+
],
|
| 88 |
+
"edges_assignment": [
|
| 89 |
+
"B", "B", "B", "B", "B", "B", "B", "B",
|
| 90 |
+
"V", "M",
|
| 91 |
+
],
|
| 92 |
+
"edges_foldAngle": [
|
| 93 |
+
0, 0, 0, 0, 0, 0, 0, 0,
|
| 94 |
+
180, -180,
|
| 95 |
+
],
|
| 96 |
+
"faces_vertices": [
|
| 97 |
+
[0, 1, 3, 2],
|
| 98 |
+
[2, 3, 5, 4],
|
| 99 |
+
[4, 5, 7, 6],
|
| 100 |
+
],
|
| 101 |
+
},
|
| 102 |
+
},
|
| 103 |
+
# --- Optigami patterns (from lib/patterns.ts) ---
|
| 104 |
+
"waterbomb": {
|
| 105 |
+
"name": "waterbomb",
|
| 106 |
+
"description": "Create a waterbomb base with valley folds on diagonals and mountain folds on midlines of a square sheet",
|
| 107 |
+
"difficulty": 3,
|
| 108 |
+
"paper": {"width": 2.0, "height": 2.0},
|
| 109 |
+
"target_fold": {
|
| 110 |
+
"vertices_coords": [
|
| 111 |
+
[-1, 1], [0, 1], [1, 1],
|
| 112 |
+
[-1, 0], [0, 0], [1, 0],
|
| 113 |
+
[-1, -1], [0, -1], [1, -1],
|
| 114 |
+
],
|
| 115 |
+
"edges_vertices": [
|
| 116 |
+
[0, 1], [1, 2], [2, 5], [5, 8], [8, 7], [7, 6], [6, 3], [3, 0],
|
| 117 |
+
[0, 4], [2, 4], [6, 4], [8, 4],
|
| 118 |
+
[1, 4], [3, 4], [5, 4], [7, 4],
|
| 119 |
+
],
|
| 120 |
+
"edges_assignment": [
|
| 121 |
+
"B", "B", "B", "B", "B", "B", "B", "B",
|
| 122 |
+
"V", "V", "V", "V",
|
| 123 |
+
"M", "M", "M", "M",
|
| 124 |
+
],
|
| 125 |
+
"edges_foldAngle": [
|
| 126 |
+
0, 0, 0, 0, 0, 0, 0, 0,
|
| 127 |
+
180, 180, 180, 180,
|
| 128 |
+
-180, -180, -180, -180,
|
| 129 |
+
],
|
| 130 |
+
"faces_vertices": [
|
| 131 |
+
[0, 1, 4], [1, 2, 4],
|
| 132 |
+
[2, 5, 4], [5, 8, 4],
|
| 133 |
+
[8, 7, 4], [7, 6, 4],
|
| 134 |
+
[6, 3, 4], [3, 0, 4],
|
| 135 |
+
],
|
| 136 |
+
},
|
| 137 |
+
},
|
| 138 |
+
"accordion": {
|
| 139 |
+
"name": "accordion",
|
| 140 |
+
"description": "Make an accordion (zig-zag) fold with alternating mountain and valley creases like a paper fan",
|
| 141 |
+
"difficulty": 2,
|
| 142 |
+
"paper": {"width": 2.0, "height": 2.0},
|
| 143 |
+
"target_fold": {
|
| 144 |
+
"vertices_coords": [
|
| 145 |
+
[-1, 1], [-0.5, 1], [0, 1], [0.5, 1], [1, 1],
|
| 146 |
+
[-1, -1], [-0.5, -1], [0, -1], [0.5, -1], [1, -1],
|
| 147 |
+
],
|
| 148 |
+
"edges_vertices": [
|
| 149 |
+
[0, 1], [1, 2], [2, 3], [3, 4],
|
| 150 |
+
[5, 6], [6, 7], [7, 8], [8, 9],
|
| 151 |
+
[0, 5], [4, 9],
|
| 152 |
+
[1, 6], [2, 7], [3, 8],
|
| 153 |
+
],
|
| 154 |
+
"edges_assignment": [
|
| 155 |
+
"B", "B", "B", "B",
|
| 156 |
+
"B", "B", "B", "B",
|
| 157 |
+
"B", "B",
|
| 158 |
+
"V", "M", "V",
|
| 159 |
+
],
|
| 160 |
+
"edges_foldAngle": [
|
| 161 |
+
0, 0, 0, 0,
|
| 162 |
+
0, 0, 0, 0,
|
| 163 |
+
0, 0,
|
| 164 |
+
180, -180, 180,
|
| 165 |
+
],
|
| 166 |
+
"faces_vertices": [
|
| 167 |
+
[0, 1, 6, 5], [1, 2, 7, 6],
|
| 168 |
+
[2, 3, 8, 7], [3, 4, 9, 8],
|
| 169 |
+
],
|
| 170 |
+
},
|
| 171 |
+
},
|
| 172 |
+
"miura_ori": {
|
| 173 |
+
"name": "miura_ori",
|
| 174 |
+
"description": "Create a Miura-ori tessellation fold pattern on a 2x2 grid with offset zigzag vertices",
|
| 175 |
+
"difficulty": 3,
|
| 176 |
+
"paper": {"width": 2.0, "height": 2.0},
|
| 177 |
+
"target_fold": {
|
| 178 |
+
"vertices_coords": [
|
| 179 |
+
[-1, 1], [0, 1.2], [1, 1],
|
| 180 |
+
[-1, 0], [0, 0.2], [1, 0],
|
| 181 |
+
[-1, -1], [0, -0.8], [1, -1],
|
| 182 |
+
],
|
| 183 |
+
"edges_vertices": [
|
| 184 |
+
[0, 1], [1, 2], [2, 5], [5, 8], [8, 7], [7, 6], [6, 3], [3, 0],
|
| 185 |
+
[1, 4], [4, 7],
|
| 186 |
+
[3, 4], [4, 5],
|
| 187 |
+
],
|
| 188 |
+
"edges_assignment": [
|
| 189 |
+
"B", "B", "B", "B", "B", "B", "B", "B",
|
| 190 |
+
"M", "M",
|
| 191 |
+
"V", "M",
|
| 192 |
+
],
|
| 193 |
+
"edges_foldAngle": [
|
| 194 |
+
0, 0, 0, 0, 0, 0, 0, 0,
|
| 195 |
+
-180, -180,
|
| 196 |
+
180, -180,
|
| 197 |
+
],
|
| 198 |
+
"faces_vertices": [
|
| 199 |
+
[0, 1, 4, 3], [1, 2, 5, 4],
|
| 200 |
+
[3, 4, 7, 6], [4, 5, 8, 7],
|
| 201 |
+
],
|
| 202 |
+
},
|
| 203 |
+
},
|
| 204 |
+
}
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def get_task(name: str | None = None) -> dict:
|
| 208 |
+
"""Get a task by name. Defaults to 'triangle'."""
|
| 209 |
+
if name is None:
|
| 210 |
+
name = "triangle"
|
| 211 |
+
if name not in TASKS:
|
| 212 |
+
raise ValueError(f"Unknown task '{name}'. Available: {list(TASKS.keys())}")
|
| 213 |
+
return TASKS[name]
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def list_tasks() -> list[str]:
|
| 217 |
+
"""List all available task names."""
|
| 218 |
+
return list(TASKS.keys())
|
requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi>=0.100.0
|
| 2 |
+
numpy>=1.24.0
|
| 3 |
+
openenv-core[core]>=0.2.1
|
| 4 |
+
pydantic>=2.0.0
|
| 5 |
+
scipy>=1.10
|
| 6 |
+
uvicorn>=0.23.0
|
start.sh
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/sh
|
| 2 |
+
# Start the OpenEnv API in the background on port 8000
|
| 3 |
+
python3 -m uvicorn origami_server.app:app --host 0.0.0.0 --port 8000 &
|
| 4 |
+
|
| 5 |
+
# Start Next.js frontend on port 7860 (foreground)
|
| 6 |
+
node server.js
|
training/__init__.py
ADDED
|
File without changes
|
training/reward.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""GRPO reward functions for origami RL training.
|
| 2 |
+
|
| 3 |
+
Two reward functions (matching the 2048 pattern):
|
| 4 |
+
1. valid_fold: Does the LLM output parse as valid FOLD JSON?
|
| 5 |
+
2. shape_match: Simulate and compare to target shape.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import json
|
| 9 |
+
import re
|
| 10 |
+
from typing import Any
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
|
| 14 |
+
from origami_server.engine.fold_parser import validate_fold
|
| 15 |
+
from origami_server.engine.shape_match import compute_shape_match
|
| 16 |
+
from origami_server.engine.simulate import simulate
|
| 17 |
+
from origami_server.tasks import get_task
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def extract_fold_json(response: str) -> dict | None:
|
| 21 |
+
"""Extract FOLD JSON from LLM response text.
|
| 22 |
+
|
| 23 |
+
Looks for JSON between ```json ... ``` or raw JSON object.
|
| 24 |
+
"""
|
| 25 |
+
# Try fenced code block first
|
| 26 |
+
match = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", response, re.DOTALL)
|
| 27 |
+
if match:
|
| 28 |
+
try:
|
| 29 |
+
return json.loads(match.group(1))
|
| 30 |
+
except json.JSONDecodeError:
|
| 31 |
+
pass
|
| 32 |
+
|
| 33 |
+
# Try raw JSON object
|
| 34 |
+
match = re.search(r"\{[^{}]*\"vertices_coords\"[^{}]*\}", response, re.DOTALL)
|
| 35 |
+
if match:
|
| 36 |
+
try:
|
| 37 |
+
return json.loads(match.group(0))
|
| 38 |
+
except json.JSONDecodeError:
|
| 39 |
+
pass
|
| 40 |
+
|
| 41 |
+
# Try parsing the whole response
|
| 42 |
+
try:
|
| 43 |
+
data = json.loads(response.strip())
|
| 44 |
+
if isinstance(data, dict) and "vertices_coords" in data:
|
| 45 |
+
return data
|
| 46 |
+
except (json.JSONDecodeError, ValueError):
|
| 47 |
+
pass
|
| 48 |
+
|
| 49 |
+
return None
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def valid_fold(completions: list, **kwargs: Any) -> list[float]:
|
| 53 |
+
"""Reward 1: Does the LLM output parse as valid FOLD JSON?
|
| 54 |
+
|
| 55 |
+
+1.0 valid FOLD JSON with correct structure
|
| 56 |
+
-0.5 parseable JSON but invalid FOLD structure
|
| 57 |
+
-2.0 not parseable as JSON at all
|
| 58 |
+
"""
|
| 59 |
+
scores = []
|
| 60 |
+
for completion in completions:
|
| 61 |
+
response = completion[0]["content"]
|
| 62 |
+
fold_data = extract_fold_json(response)
|
| 63 |
+
|
| 64 |
+
if fold_data is None:
|
| 65 |
+
scores.append(-2.0)
|
| 66 |
+
continue
|
| 67 |
+
|
| 68 |
+
is_valid, error = validate_fold(fold_data)
|
| 69 |
+
if is_valid:
|
| 70 |
+
scores.append(1.0)
|
| 71 |
+
else:
|
| 72 |
+
scores.append(-0.5)
|
| 73 |
+
|
| 74 |
+
return scores
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def shape_match(
|
| 78 |
+
completions: list,
|
| 79 |
+
task_name: str = "triangle",
|
| 80 |
+
**kwargs: Any,
|
| 81 |
+
) -> list[float]:
|
| 82 |
+
"""Reward 2: Simulate the fold and compare to target shape.
|
| 83 |
+
|
| 84 |
+
Score = similarity * 20.0 (range: 0 to 20)
|
| 85 |
+
-1.0 if simulation fails/diverges
|
| 86 |
+
-2.0 if FOLD data is invalid
|
| 87 |
+
"""
|
| 88 |
+
task = get_task(task_name)
|
| 89 |
+
target_fold = task["target_fold"]
|
| 90 |
+
|
| 91 |
+
# Pre-compute target positions
|
| 92 |
+
try:
|
| 93 |
+
target_result = simulate(target_fold, crease_percent=1.0)
|
| 94 |
+
target_positions = target_result.positions
|
| 95 |
+
except Exception:
|
| 96 |
+
return [0.0] * len(completions)
|
| 97 |
+
|
| 98 |
+
scores = []
|
| 99 |
+
for completion in completions:
|
| 100 |
+
response = completion[0]["content"]
|
| 101 |
+
fold_data = extract_fold_json(response)
|
| 102 |
+
|
| 103 |
+
if fold_data is None:
|
| 104 |
+
scores.append(-2.0)
|
| 105 |
+
continue
|
| 106 |
+
|
| 107 |
+
is_valid, error = validate_fold(fold_data)
|
| 108 |
+
if not is_valid:
|
| 109 |
+
scores.append(-1.0)
|
| 110 |
+
continue
|
| 111 |
+
|
| 112 |
+
try:
|
| 113 |
+
result = simulate(fold_data, crease_percent=1.0)
|
| 114 |
+
similarity = compute_shape_match(result.positions, target_positions)
|
| 115 |
+
scores.append(similarity * 20.0)
|
| 116 |
+
except Exception:
|
| 117 |
+
scores.append(-1.0)
|
| 118 |
+
|
| 119 |
+
return scores
|
training/train_grpo.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""GRPO training script for origami RL.
|
| 2 |
+
|
| 3 |
+
Follows the OpenEnv + Unsloth pattern:
|
| 4 |
+
- LLM generates FOLD JSON crease patterns
|
| 5 |
+
- Two reward functions: valid_fold + shape_match
|
| 6 |
+
- GRPOTrainer from TRL handles the RL loop
|
| 7 |
+
|
| 8 |
+
Usage (Colab):
|
| 9 |
+
python -m training.train_grpo --task triangle --max_steps 600
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import argparse
|
| 13 |
+
|
| 14 |
+
PROMPT_TEMPLATE = """You are an origami designer. Generate a FOLD-format crease pattern
|
| 15 |
+
that, when folded, produces the target shape described below.
|
| 16 |
+
|
| 17 |
+
Target: {description}
|
| 18 |
+
Paper size: {width} x {height}
|
| 19 |
+
|
| 20 |
+
Output a JSON object with these exact fields:
|
| 21 |
+
- vertices_coords: [[x, y], ...] — 2D positions on the flat paper (0 to {width} for x, 0 to {height} for y)
|
| 22 |
+
- edges_vertices: [[v1, v2], ...] — pairs of vertex indices forming edges
|
| 23 |
+
- edges_assignment: ["B"|"M"|"V", ...] — B=boundary, M=mountain fold, V=valley fold
|
| 24 |
+
- edges_foldAngle: [angle, ...] — fold angles in degrees (M: negative like -180, V: positive like 180, B: 0)
|
| 25 |
+
|
| 26 |
+
Rules:
|
| 27 |
+
- Boundary edges (B) must outline the paper rectangle
|
| 28 |
+
- At least one fold crease (M or V) must exist
|
| 29 |
+
- Mountain fold angles are negative (-180 to 0)
|
| 30 |
+
- Valley fold angles are positive (0 to 180)
|
| 31 |
+
- All vertex indices in edges must be valid (0 to N-1)
|
| 32 |
+
|
| 33 |
+
Output ONLY the JSON object wrapped in ```json ... ``` markers."""
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def build_prompt(task: dict) -> str:
|
| 37 |
+
return PROMPT_TEMPLATE.format(
|
| 38 |
+
description=task["description"],
|
| 39 |
+
width=task["paper"]["width"],
|
| 40 |
+
height=task["paper"]["height"],
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def main():
|
| 45 |
+
parser = argparse.ArgumentParser(description="GRPO training for origami RL")
|
| 46 |
+
parser.add_argument("--task", default="triangle", help="Task name")
|
| 47 |
+
parser.add_argument("--max_steps", type=int, default=600)
|
| 48 |
+
parser.add_argument("--num_generations", type=int, default=4)
|
| 49 |
+
parser.add_argument("--model", default="unsloth/Qwen3-8B-unsloth-bnb-4bit")
|
| 50 |
+
parser.add_argument("--lr", type=float, default=2e-4)
|
| 51 |
+
args = parser.parse_args()
|
| 52 |
+
|
| 53 |
+
# --- These imports are heavy, only load when actually training ---
|
| 54 |
+
from datasets import Dataset
|
| 55 |
+
from trl import GRPOConfig, GRPOTrainer
|
| 56 |
+
from unsloth import FastLanguageModel
|
| 57 |
+
|
| 58 |
+
from origami_server.tasks import get_task
|
| 59 |
+
from training.reward import shape_match, valid_fold
|
| 60 |
+
|
| 61 |
+
task = get_task(args.task)
|
| 62 |
+
prompt_text = build_prompt(task)
|
| 63 |
+
|
| 64 |
+
# Build dataset (1000 copies of same prompt, like 2048)
|
| 65 |
+
dataset = Dataset.from_list(
|
| 66 |
+
[
|
| 67 |
+
{
|
| 68 |
+
"prompt": [{"role": "user", "content": prompt_text}],
|
| 69 |
+
"answer": 0,
|
| 70 |
+
}
|
| 71 |
+
]
|
| 72 |
+
* 1000
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
# Load model with LoRA
|
| 76 |
+
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 77 |
+
model_name=args.model,
|
| 78 |
+
load_in_4bit=True,
|
| 79 |
+
max_seq_length=2048,
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
model = FastLanguageModel.get_peft_model(
|
| 83 |
+
model,
|
| 84 |
+
r=8,
|
| 85 |
+
target_modules=[
|
| 86 |
+
"q_proj", "k_proj", "v_proj", "o_proj",
|
| 87 |
+
"gate_proj", "up_proj", "down_proj",
|
| 88 |
+
],
|
| 89 |
+
lora_alpha=16,
|
| 90 |
+
use_gradient_checkpointing="unsloth",
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
# Wrap shape_match to inject task_name
|
| 94 |
+
def shape_match_reward(completions, **kwargs):
|
| 95 |
+
return shape_match(completions, task_name=args.task, **kwargs)
|
| 96 |
+
|
| 97 |
+
# GRPO config
|
| 98 |
+
training_args = GRPOConfig(
|
| 99 |
+
temperature=1.0,
|
| 100 |
+
learning_rate=args.lr,
|
| 101 |
+
weight_decay=0.001,
|
| 102 |
+
warmup_ratio=0.1,
|
| 103 |
+
lr_scheduler_type="linear",
|
| 104 |
+
optim="adamw_8bit",
|
| 105 |
+
logging_steps=1,
|
| 106 |
+
per_device_train_batch_size=1,
|
| 107 |
+
gradient_accumulation_steps=1,
|
| 108 |
+
num_generations=args.num_generations,
|
| 109 |
+
max_prompt_length=1024,
|
| 110 |
+
max_completion_length=1024,
|
| 111 |
+
max_steps=args.max_steps,
|
| 112 |
+
save_steps=100,
|
| 113 |
+
output_dir="outputs",
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
trainer = GRPOTrainer(
|
| 117 |
+
model=model,
|
| 118 |
+
processing_class=tokenizer,
|
| 119 |
+
reward_funcs=[valid_fold, shape_match_reward],
|
| 120 |
+
args=training_args,
|
| 121 |
+
train_dataset=dataset,
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
trainer.train()
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
if __name__ == "__main__":
|
| 128 |
+
main()
|