Spaces:
Sleeping
Sleeping
Commit ·
9ae534f
1
Parent(s): bd5c0bd
Remove openenv-core dep to fix HF build timeout
Browse files
The openenv-core package pulls in 50+ transitive dependencies that
cause the Docker build to exceed the free-tier time limit.
Replaced with plain FastAPI endpoints (/reset, /step, /tasks).
Training scripts still reference origami_server directly.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- README.md +1 -1
- origami_server/app.py +135 -29
- origami_server/environment.py +0 -158
- origami_server/models.py +12 -18
- requirements.txt +0 -1
README.md
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
---
|
| 2 |
title: Optigami
|
| 3 |
-
emoji: "\
|
| 4 |
colorFrom: indigo
|
| 5 |
colorTo: purple
|
| 6 |
sdk: docker
|
|
|
|
| 1 |
---
|
| 2 |
title: Optigami
|
| 3 |
+
emoji: "\U0001F9A2"
|
| 4 |
colorFrom: indigo
|
| 5 |
colorTo: purple
|
| 6 |
sdk: docker
|
origami_server/app.py
CHANGED
|
@@ -1,59 +1,165 @@
|
|
| 1 |
-
"""FastAPI entry point —
|
| 2 |
|
| 3 |
import os
|
| 4 |
-
|
|
|
|
| 5 |
|
| 6 |
-
|
| 7 |
-
from fastapi
|
|
|
|
| 8 |
|
| 9 |
-
from
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
-
|
| 12 |
-
from .models import OrigamiAction, OrigamiObservation
|
| 13 |
|
| 14 |
-
app = create_app(
|
| 15 |
-
OrigamiEnvironment,
|
| 16 |
-
OrigamiAction,
|
| 17 |
-
OrigamiObservation,
|
| 18 |
-
env_name="origami_env",
|
| 19 |
-
)
|
| 20 |
|
| 21 |
-
|
| 22 |
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
"name": task["name"],
|
| 29 |
"description": task["description"],
|
| 30 |
"difficulty": task["difficulty"],
|
| 31 |
"paper": task["paper"],
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
}
|
| 34 |
-
for name,
|
| 35 |
}
|
| 36 |
|
| 37 |
|
| 38 |
@app.get("/tasks/{task_name}")
|
| 39 |
def get_task_detail(task_name: str):
|
| 40 |
if task_name not in TASKS:
|
| 41 |
-
from fastapi import HTTPException
|
| 42 |
-
|
| 43 |
raise HTTPException(status_code=404, detail=f"Task '{task_name}' not found")
|
| 44 |
-
|
| 45 |
return {
|
| 46 |
-
"name":
|
| 47 |
-
"description":
|
| 48 |
-
"difficulty":
|
| 49 |
-
"paper":
|
| 50 |
-
"target_fold":
|
| 51 |
}
|
| 52 |
|
| 53 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
def main():
|
| 55 |
import uvicorn
|
| 56 |
-
|
| 57 |
port = int(os.environ.get("PORT", 8000))
|
| 58 |
uvicorn.run(app, host="0.0.0.0", port=port)
|
| 59 |
|
|
|
|
| 1 |
+
"""FastAPI entry point — standalone RL API (no openenv-core dependency)."""
|
| 2 |
|
| 3 |
import os
|
| 4 |
+
import uuid
|
| 5 |
+
from typing import Any, Optional
|
| 6 |
|
| 7 |
+
import numpy as np
|
| 8 |
+
from fastapi import FastAPI, HTTPException
|
| 9 |
+
from pydantic import BaseModel, Field
|
| 10 |
|
| 11 |
+
from .engine.fold_parser import validate_fold
|
| 12 |
+
from .engine.shape_match import compute_shape_match
|
| 13 |
+
from .engine.simulate import simulate
|
| 14 |
+
from .tasks import TASKS, get_task
|
| 15 |
|
| 16 |
+
app = FastAPI(title="Optigami RL Environment", version="0.1.0")
|
|
|
|
| 17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
+
# --- Request/Response models ---
|
| 20 |
|
| 21 |
+
class ResetRequest(BaseModel):
|
| 22 |
+
task_name: str = "triangle"
|
| 23 |
+
seed: Optional[int] = None
|
| 24 |
|
| 25 |
+
class StepRequest(BaseModel):
|
| 26 |
+
fold_data: dict[str, Any] = Field(..., description="FOLD-format crease pattern JSON")
|
| 27 |
+
|
| 28 |
+
class ObservationResponse(BaseModel):
|
| 29 |
+
done: bool = False
|
| 30 |
+
reward: Optional[float] = None
|
| 31 |
+
task: dict[str, Any] = {}
|
| 32 |
+
fold_data: dict[str, Any] = {}
|
| 33 |
+
final_positions: list[list[float]] = []
|
| 34 |
+
target_positions: list[list[float]] = []
|
| 35 |
+
shape_similarity: float = 0.0
|
| 36 |
+
max_strain: float = 0.0
|
| 37 |
+
is_stable: bool = True
|
| 38 |
+
error: Optional[str] = None
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# --- State ---
|
| 42 |
+
|
| 43 |
+
_state: dict[str, Any] = {}
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# --- Endpoints ---
|
| 47 |
+
|
| 48 |
+
@app.post("/reset", response_model=ObservationResponse)
|
| 49 |
+
def reset(req: ResetRequest):
|
| 50 |
+
task = get_task(req.task_name)
|
| 51 |
+
|
| 52 |
+
if req.seed is not None:
|
| 53 |
+
np.random.seed(req.seed)
|
| 54 |
+
|
| 55 |
+
target_fold = task["target_fold"]
|
| 56 |
+
try:
|
| 57 |
+
target_result = simulate(target_fold, crease_percent=1.0)
|
| 58 |
+
target_positions = target_result.positions.tolist()
|
| 59 |
+
except Exception:
|
| 60 |
+
target_positions = []
|
| 61 |
+
|
| 62 |
+
_state["task"] = task
|
| 63 |
+
_state["target_positions"] = target_positions
|
| 64 |
+
_state["episode_id"] = str(uuid.uuid4())
|
| 65 |
+
|
| 66 |
+
return ObservationResponse(
|
| 67 |
+
done=False,
|
| 68 |
+
reward=None,
|
| 69 |
+
task={
|
| 70 |
"name": task["name"],
|
| 71 |
"description": task["description"],
|
| 72 |
"difficulty": task["difficulty"],
|
| 73 |
"paper": task["paper"],
|
| 74 |
+
},
|
| 75 |
+
target_positions=target_positions,
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
@app.post("/step", response_model=ObservationResponse)
|
| 80 |
+
def step(req: StepRequest):
|
| 81 |
+
task = _state.get("task")
|
| 82 |
+
if not task:
|
| 83 |
+
raise HTTPException(status_code=400, detail="Call /reset first")
|
| 84 |
+
|
| 85 |
+
target_positions = _state.get("target_positions", [])
|
| 86 |
+
fold_data = req.fold_data
|
| 87 |
+
|
| 88 |
+
is_valid, error_msg = validate_fold(fold_data)
|
| 89 |
+
if not is_valid:
|
| 90 |
+
return ObservationResponse(
|
| 91 |
+
done=True,
|
| 92 |
+
reward=-2.0,
|
| 93 |
+
task={"name": task["name"], "description": task["description"]},
|
| 94 |
+
fold_data=fold_data,
|
| 95 |
+
target_positions=target_positions,
|
| 96 |
+
error=f"Invalid FOLD data: {error_msg}",
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
try:
|
| 100 |
+
result = simulate(fold_data, crease_percent=1.0)
|
| 101 |
+
except Exception as e:
|
| 102 |
+
return ObservationResponse(
|
| 103 |
+
done=True,
|
| 104 |
+
reward=-2.0,
|
| 105 |
+
task={"name": task["name"], "description": task["description"]},
|
| 106 |
+
fold_data=fold_data,
|
| 107 |
+
target_positions=target_positions,
|
| 108 |
+
error=f"Simulation error: {str(e)}",
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
target_arr = np.array(target_positions) if target_positions else np.zeros((0, 3))
|
| 112 |
+
similarity = compute_shape_match(result.positions, target_arr)
|
| 113 |
+
reward = similarity * 20.0
|
| 114 |
+
|
| 115 |
+
return ObservationResponse(
|
| 116 |
+
done=True,
|
| 117 |
+
reward=reward,
|
| 118 |
+
task={"name": task["name"], "description": task["description"]},
|
| 119 |
+
fold_data=fold_data,
|
| 120 |
+
final_positions=result.positions.tolist(),
|
| 121 |
+
target_positions=target_positions,
|
| 122 |
+
shape_similarity=similarity,
|
| 123 |
+
max_strain=result.max_strain,
|
| 124 |
+
is_stable=result.converged,
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
@app.get("/tasks")
|
| 129 |
+
def list_tasks():
|
| 130 |
+
return {
|
| 131 |
+
name: {
|
| 132 |
+
"name": t["name"],
|
| 133 |
+
"description": t["description"],
|
| 134 |
+
"difficulty": t["difficulty"],
|
| 135 |
+
"paper": t["paper"],
|
| 136 |
+
"target_fold": t["target_fold"],
|
| 137 |
}
|
| 138 |
+
for name, t in TASKS.items()
|
| 139 |
}
|
| 140 |
|
| 141 |
|
| 142 |
@app.get("/tasks/{task_name}")
|
| 143 |
def get_task_detail(task_name: str):
|
| 144 |
if task_name not in TASKS:
|
|
|
|
|
|
|
| 145 |
raise HTTPException(status_code=404, detail=f"Task '{task_name}' not found")
|
| 146 |
+
t = TASKS[task_name]
|
| 147 |
return {
|
| 148 |
+
"name": t["name"],
|
| 149 |
+
"description": t["description"],
|
| 150 |
+
"difficulty": t["difficulty"],
|
| 151 |
+
"paper": t["paper"],
|
| 152 |
+
"target_fold": t["target_fold"],
|
| 153 |
}
|
| 154 |
|
| 155 |
|
| 156 |
+
@app.get("/health")
|
| 157 |
+
def health():
|
| 158 |
+
return {"status": "healthy"}
|
| 159 |
+
|
| 160 |
+
|
| 161 |
def main():
|
| 162 |
import uvicorn
|
|
|
|
| 163 |
port = int(os.environ.get("PORT", 8000))
|
| 164 |
uvicorn.run(app, host="0.0.0.0", port=port)
|
| 165 |
|
origami_server/environment.py
DELETED
|
@@ -1,158 +0,0 @@
|
|
| 1 |
-
"""Origami RL Environment — OpenEnv Environment subclass.
|
| 2 |
-
|
| 3 |
-
Single-shot episodes: LLM submits a FOLD crease pattern, physics simulates it,
|
| 4 |
-
reward = shape similarity to target. Like AlphaFold for origami.
|
| 5 |
-
"""
|
| 6 |
-
|
| 7 |
-
import uuid
|
| 8 |
-
from typing import Any, Optional
|
| 9 |
-
|
| 10 |
-
import numpy as np
|
| 11 |
-
from openenv.core import Environment
|
| 12 |
-
|
| 13 |
-
from .engine.fold_parser import validate_fold
|
| 14 |
-
from .engine.shape_match import compute_shape_match
|
| 15 |
-
from .engine.simulate import SimResult, simulate
|
| 16 |
-
from .models import OrigamiAction, OrigamiObservation, OrigamiState
|
| 17 |
-
from .tasks import get_task
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
class OrigamiEnvironment(
|
| 21 |
-
Environment[OrigamiAction, OrigamiObservation, OrigamiState]
|
| 22 |
-
):
|
| 23 |
-
"""Origami folding environment.
|
| 24 |
-
|
| 25 |
-
Episode flow:
|
| 26 |
-
1. reset(task_name="triangle") -> returns task description + target info
|
| 27 |
-
2. step(OrigamiAction(fold_data={...})) -> simulates, scores, returns done=True
|
| 28 |
-
|
| 29 |
-
Single action per episode. The action IS the complete crease pattern.
|
| 30 |
-
"""
|
| 31 |
-
|
| 32 |
-
SUPPORTS_CONCURRENT_SESSIONS = True
|
| 33 |
-
|
| 34 |
-
def __init__(self, **kwargs: Any):
|
| 35 |
-
super().__init__(**kwargs)
|
| 36 |
-
self._state = OrigamiState()
|
| 37 |
-
self._task: dict = {}
|
| 38 |
-
self._target_positions: np.ndarray = np.zeros((0, 3))
|
| 39 |
-
|
| 40 |
-
def reset(
|
| 41 |
-
self,
|
| 42 |
-
seed: Optional[int] = None,
|
| 43 |
-
episode_id: Optional[str] = None,
|
| 44 |
-
**kwargs: Any,
|
| 45 |
-
) -> OrigamiObservation:
|
| 46 |
-
"""Start a new episode with a target shape task."""
|
| 47 |
-
self._state = OrigamiState(
|
| 48 |
-
episode_id=episode_id or str(uuid.uuid4()),
|
| 49 |
-
step_count=0,
|
| 50 |
-
)
|
| 51 |
-
|
| 52 |
-
task_name = kwargs.get("task_name", "triangle")
|
| 53 |
-
self._task = get_task(task_name)
|
| 54 |
-
self._state.task_name = self._task["name"]
|
| 55 |
-
|
| 56 |
-
target_fold = self._task["target_fold"]
|
| 57 |
-
try:
|
| 58 |
-
target_result = simulate(target_fold, crease_percent=1.0)
|
| 59 |
-
self._target_positions = target_result.positions
|
| 60 |
-
except Exception:
|
| 61 |
-
self._target_positions = np.zeros((0, 3))
|
| 62 |
-
|
| 63 |
-
return OrigamiObservation(
|
| 64 |
-
done=False,
|
| 65 |
-
reward=None,
|
| 66 |
-
task=self._task_info(),
|
| 67 |
-
fold_data={},
|
| 68 |
-
final_positions=[],
|
| 69 |
-
target_positions=self._target_positions.tolist(),
|
| 70 |
-
shape_similarity=0.0,
|
| 71 |
-
max_strain=0.0,
|
| 72 |
-
is_stable=True,
|
| 73 |
-
error=None,
|
| 74 |
-
)
|
| 75 |
-
|
| 76 |
-
def step(
|
| 77 |
-
self,
|
| 78 |
-
action: OrigamiAction,
|
| 79 |
-
timeout_s: Optional[float] = None,
|
| 80 |
-
**kwargs: Any,
|
| 81 |
-
) -> OrigamiObservation:
|
| 82 |
-
"""Evaluate the LLM's crease pattern.
|
| 83 |
-
|
| 84 |
-
1. Validate FOLD data
|
| 85 |
-
2. Run physics simulation (creasePercent=1.0)
|
| 86 |
-
3. Compare final shape to target
|
| 87 |
-
4. Return observation with reward = similarity * 20
|
| 88 |
-
"""
|
| 89 |
-
self._state.step_count += 1
|
| 90 |
-
fold_data = action.fold_data
|
| 91 |
-
|
| 92 |
-
is_valid, error_msg = validate_fold(fold_data)
|
| 93 |
-
if not is_valid:
|
| 94 |
-
self._state.is_stable = False
|
| 95 |
-
return OrigamiObservation(
|
| 96 |
-
done=True,
|
| 97 |
-
reward=-2.0,
|
| 98 |
-
task=self._task_info(),
|
| 99 |
-
fold_data=fold_data,
|
| 100 |
-
final_positions=[],
|
| 101 |
-
target_positions=self._target_positions.tolist(),
|
| 102 |
-
shape_similarity=0.0,
|
| 103 |
-
max_strain=0.0,
|
| 104 |
-
is_stable=False,
|
| 105 |
-
error=f"Invalid FOLD data: {error_msg}",
|
| 106 |
-
)
|
| 107 |
-
|
| 108 |
-
try:
|
| 109 |
-
result: SimResult = simulate(fold_data, crease_percent=1.0)
|
| 110 |
-
except Exception as e:
|
| 111 |
-
self._state.is_stable = False
|
| 112 |
-
return OrigamiObservation(
|
| 113 |
-
done=True,
|
| 114 |
-
reward=-2.0,
|
| 115 |
-
task=self._task_info(),
|
| 116 |
-
fold_data=fold_data,
|
| 117 |
-
final_positions=[],
|
| 118 |
-
target_positions=self._target_positions.tolist(),
|
| 119 |
-
shape_similarity=0.0,
|
| 120 |
-
max_strain=0.0,
|
| 121 |
-
is_stable=False,
|
| 122 |
-
error=f"Simulation error: {str(e)}",
|
| 123 |
-
)
|
| 124 |
-
|
| 125 |
-
similarity = compute_shape_match(
|
| 126 |
-
result.positions, self._target_positions
|
| 127 |
-
)
|
| 128 |
-
reward = similarity * 20.0
|
| 129 |
-
|
| 130 |
-
self._state.shape_similarity = similarity
|
| 131 |
-
self._state.is_stable = result.converged
|
| 132 |
-
|
| 133 |
-
return OrigamiObservation(
|
| 134 |
-
done=True,
|
| 135 |
-
reward=reward,
|
| 136 |
-
task=self._task_info(),
|
| 137 |
-
fold_data=fold_data,
|
| 138 |
-
final_positions=result.positions.tolist(),
|
| 139 |
-
target_positions=self._target_positions.tolist(),
|
| 140 |
-
shape_similarity=similarity,
|
| 141 |
-
max_strain=result.max_strain,
|
| 142 |
-
is_stable=result.converged,
|
| 143 |
-
error=None,
|
| 144 |
-
)
|
| 145 |
-
|
| 146 |
-
@property
|
| 147 |
-
def state(self) -> OrigamiState:
|
| 148 |
-
return self._state
|
| 149 |
-
|
| 150 |
-
def _task_info(self) -> dict:
|
| 151 |
-
if not self._task:
|
| 152 |
-
return {}
|
| 153 |
-
return {
|
| 154 |
-
"name": self._task.get("name", ""),
|
| 155 |
-
"description": self._task.get("description", ""),
|
| 156 |
-
"difficulty": self._task.get("difficulty", 0),
|
| 157 |
-
"paper": self._task.get("paper", {}),
|
| 158 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
origami_server/models.py
CHANGED
|
@@ -1,35 +1,27 @@
|
|
| 1 |
-
"""
|
| 2 |
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
OrigamiState: Internal episode state.
|
| 6 |
"""
|
| 7 |
|
| 8 |
from typing import Any, Optional
|
| 9 |
|
| 10 |
-
from
|
| 11 |
-
from pydantic import Field
|
| 12 |
|
| 13 |
|
| 14 |
-
class OrigamiAction(
|
| 15 |
-
"""LLM submits a FOLD crease pattern as its action.
|
| 16 |
-
|
| 17 |
-
The fold_data dict must contain:
|
| 18 |
-
- vertices_coords: [[x, y], ...] — 2D vertex positions on flat paper
|
| 19 |
-
- edges_vertices: [[v1, v2], ...] — edge connectivity
|
| 20 |
-
- edges_assignment: ["B"|"M"|"V", ...] — boundary/mountain/valley
|
| 21 |
-
- edges_foldAngle: [angle, ...] — target fold angles in degrees
|
| 22 |
-
(optional — defaults from assignment: M=-180, V=+180, B=0)
|
| 23 |
-
"""
|
| 24 |
|
| 25 |
fold_data: dict[str, Any] = Field(
|
| 26 |
..., description="FOLD-format crease pattern JSON"
|
| 27 |
)
|
| 28 |
|
| 29 |
|
| 30 |
-
class OrigamiObservation(
|
| 31 |
"""Result of simulating the LLM's crease pattern."""
|
| 32 |
|
|
|
|
|
|
|
| 33 |
task: dict[str, Any] = Field(default_factory=dict)
|
| 34 |
fold_data: dict[str, Any] = Field(default_factory=dict)
|
| 35 |
final_positions: list[list[float]] = Field(default_factory=list)
|
|
@@ -40,9 +32,11 @@ class OrigamiObservation(Observation):
|
|
| 40 |
error: Optional[str] = None
|
| 41 |
|
| 42 |
|
| 43 |
-
class OrigamiState(
|
| 44 |
"""Internal state for an origami episode."""
|
| 45 |
|
|
|
|
|
|
|
| 46 |
task_name: str = ""
|
| 47 |
shape_similarity: float = 0.0
|
| 48 |
is_stable: bool = True
|
|
|
|
| 1 |
+
"""Pydantic types for the Origami RL environment.
|
| 2 |
|
| 3 |
+
These are used by the training scripts and can also be used with openenv-core
|
| 4 |
+
when that package is available (e.g. on Colab).
|
|
|
|
| 5 |
"""
|
| 6 |
|
| 7 |
from typing import Any, Optional
|
| 8 |
|
| 9 |
+
from pydantic import BaseModel, Field
|
|
|
|
| 10 |
|
| 11 |
|
| 12 |
+
class OrigamiAction(BaseModel):
|
| 13 |
+
"""LLM submits a FOLD crease pattern as its action."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
fold_data: dict[str, Any] = Field(
|
| 16 |
..., description="FOLD-format crease pattern JSON"
|
| 17 |
)
|
| 18 |
|
| 19 |
|
| 20 |
+
class OrigamiObservation(BaseModel):
|
| 21 |
"""Result of simulating the LLM's crease pattern."""
|
| 22 |
|
| 23 |
+
done: bool = False
|
| 24 |
+
reward: Optional[float] = None
|
| 25 |
task: dict[str, Any] = Field(default_factory=dict)
|
| 26 |
fold_data: dict[str, Any] = Field(default_factory=dict)
|
| 27 |
final_positions: list[list[float]] = Field(default_factory=list)
|
|
|
|
| 32 |
error: Optional[str] = None
|
| 33 |
|
| 34 |
|
| 35 |
+
class OrigamiState(BaseModel):
|
| 36 |
"""Internal state for an origami episode."""
|
| 37 |
|
| 38 |
+
episode_id: Optional[str] = None
|
| 39 |
+
step_count: int = 0
|
| 40 |
task_name: str = ""
|
| 41 |
shape_similarity: float = 0.0
|
| 42 |
is_stable: bool = True
|
requirements.txt
CHANGED
|
@@ -1,6 +1,5 @@
|
|
| 1 |
fastapi>=0.100.0
|
| 2 |
numpy>=1.24.0
|
| 3 |
-
openenv-core[core]>=0.2.1
|
| 4 |
pydantic>=2.0.0
|
| 5 |
scipy>=1.10
|
| 6 |
uvicorn>=0.23.0
|
|
|
|
| 1 |
fastapi>=0.100.0
|
| 2 |
numpy>=1.24.0
|
|
|
|
| 3 |
pydantic>=2.0.0
|
| 4 |
scipy>=1.10
|
| 5 |
uvicorn>=0.23.0
|