Spaces:
Sleeping
Sleeping
Commit Β·
72390e9
1
Parent(s): b01521d
Rename server/ to origami_server/ to avoid module name conflict with uvicorn.server
Browse files- Dockerfile +1 -1
- client.py +1 -1
- models.py +2 -2
- openenv.yaml +1 -1
- {server β origami_server}/__init__.py +0 -0
- {server β origami_server}/app.py +0 -0
- {server β origami_server}/engine/__init__.py +0 -0
- {server β origami_server}/engine/fold_parser.py +0 -0
- {server β origami_server}/engine/shape_match.py +0 -0
- {server β origami_server}/engine/simulate.py +0 -0
- {server β origami_server}/environment.py +0 -0
- {server β origami_server}/models.py +0 -0
- {server β origami_server}/tasks.py +0 -0
- pyproject.toml +1 -1
- server/requirements.txt +0 -6
- tests/test_origami.py +7 -7
- training/reward.py +4 -4
- training/train_grpo.py +1 -1
Dockerfile
CHANGED
|
@@ -11,4 +11,4 @@ COPY . /app
|
|
| 11 |
|
| 12 |
EXPOSE 8000
|
| 13 |
|
| 14 |
-
CMD ["uvicorn", "
|
|
|
|
| 11 |
|
| 12 |
EXPOSE 8000
|
| 13 |
|
| 14 |
+
CMD ["uvicorn", "origami_server.app:app", "--host", "0.0.0.0", "--port", "8000"]
|
client.py
CHANGED
|
@@ -5,7 +5,7 @@ from typing import Any, Dict
|
|
| 5 |
from openenv.core.client_types import StepResult
|
| 6 |
from openenv.core.env_client import EnvClient
|
| 7 |
|
| 8 |
-
from
|
| 9 |
|
| 10 |
|
| 11 |
class OrigamiEnv(EnvClient[OrigamiAction, OrigamiObservation, OrigamiState]):
|
|
|
|
| 5 |
from openenv.core.client_types import StepResult
|
| 6 |
from openenv.core.env_client import EnvClient
|
| 7 |
|
| 8 |
+
from origami_server.models import OrigamiAction, OrigamiObservation, OrigamiState
|
| 9 |
|
| 10 |
|
| 11 |
class OrigamiEnv(EnvClient[OrigamiAction, OrigamiObservation, OrigamiState]):
|
models.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
-
"""Re-export models from
|
| 2 |
|
| 3 |
-
from
|
| 4 |
|
| 5 |
__all__ = ["OrigamiAction", "OrigamiObservation", "OrigamiState"]
|
|
|
|
| 1 |
+
"""Re-export models from origami_server.models for OpenEnv client usage."""
|
| 2 |
|
| 3 |
+
from origami_server.models import OrigamiAction, OrigamiObservation, OrigamiState
|
| 4 |
|
| 5 |
__all__ = ["OrigamiAction", "OrigamiObservation", "OrigamiState"]
|
openenv.yaml
CHANGED
|
@@ -2,5 +2,5 @@ spec_version: 1
|
|
| 2 |
name: origami_env
|
| 3 |
type: space
|
| 4 |
runtime: fastapi
|
| 5 |
-
app:
|
| 6 |
port: 8000
|
|
|
|
| 2 |
name: origami_env
|
| 3 |
type: space
|
| 4 |
runtime: fastapi
|
| 5 |
+
app: origami_server.app:app
|
| 6 |
port: 8000
|
{server β origami_server}/__init__.py
RENAMED
|
File without changes
|
{server β origami_server}/app.py
RENAMED
|
File without changes
|
{server β origami_server}/engine/__init__.py
RENAMED
|
File without changes
|
{server β origami_server}/engine/fold_parser.py
RENAMED
|
File without changes
|
{server β origami_server}/engine/shape_match.py
RENAMED
|
File without changes
|
{server β origami_server}/engine/simulate.py
RENAMED
|
File without changes
|
{server β origami_server}/environment.py
RENAMED
|
File without changes
|
{server β origami_server}/models.py
RENAMED
|
File without changes
|
{server β origami_server}/tasks.py
RENAMED
|
File without changes
|
pyproject.toml
CHANGED
|
@@ -18,7 +18,7 @@ dependencies = [
|
|
| 18 |
]
|
| 19 |
|
| 20 |
[project.scripts]
|
| 21 |
-
server = "
|
| 22 |
|
| 23 |
[project.optional-dependencies]
|
| 24 |
training = [
|
|
|
|
| 18 |
]
|
| 19 |
|
| 20 |
[project.scripts]
|
| 21 |
+
server = "origami_server.app:main"
|
| 22 |
|
| 23 |
[project.optional-dependencies]
|
| 24 |
training = [
|
server/requirements.txt
DELETED
|
@@ -1,6 +0,0 @@
|
|
| 1 |
-
openenv-core>=0.2.1
|
| 2 |
-
numpy>=1.24
|
| 3 |
-
scipy>=1.10
|
| 4 |
-
pydantic>=2.0
|
| 5 |
-
fastapi>=0.100
|
| 6 |
-
uvicorn>=0.22
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/test_origami.py
CHANGED
|
@@ -3,12 +3,12 @@
|
|
| 3 |
import numpy as np
|
| 4 |
import pytest
|
| 5 |
|
| 6 |
-
from
|
| 7 |
-
from
|
| 8 |
-
from
|
| 9 |
-
from
|
| 10 |
-
from
|
| 11 |
-
from
|
| 12 |
from training.reward import extract_fold_json, shape_match, valid_fold
|
| 13 |
|
| 14 |
# --- Fixtures ---
|
|
@@ -239,7 +239,7 @@ class TestAPI:
|
|
| 239 |
def client(self):
|
| 240 |
from fastapi.testclient import TestClient
|
| 241 |
|
| 242 |
-
from
|
| 243 |
|
| 244 |
return TestClient(app)
|
| 245 |
|
|
|
|
| 3 |
import numpy as np
|
| 4 |
import pytest
|
| 5 |
|
| 6 |
+
from origami_server.engine.fold_parser import parse_fold, validate_fold
|
| 7 |
+
from origami_server.engine.shape_match import compute_shape_match
|
| 8 |
+
from origami_server.engine.simulate import simulate
|
| 9 |
+
from origami_server.environment import OrigamiEnvironment
|
| 10 |
+
from origami_server.models import OrigamiAction
|
| 11 |
+
from origami_server.tasks import TASKS, get_task, list_tasks
|
| 12 |
from training.reward import extract_fold_json, shape_match, valid_fold
|
| 13 |
|
| 14 |
# --- Fixtures ---
|
|
|
|
| 239 |
def client(self):
|
| 240 |
from fastapi.testclient import TestClient
|
| 241 |
|
| 242 |
+
from origami_server.app import app
|
| 243 |
|
| 244 |
return TestClient(app)
|
| 245 |
|
training/reward.py
CHANGED
|
@@ -11,10 +11,10 @@ from typing import Any
|
|
| 11 |
|
| 12 |
import numpy as np
|
| 13 |
|
| 14 |
-
from
|
| 15 |
-
from
|
| 16 |
-
from
|
| 17 |
-
from
|
| 18 |
|
| 19 |
|
| 20 |
def extract_fold_json(response: str) -> dict | None:
|
|
|
|
| 11 |
|
| 12 |
import numpy as np
|
| 13 |
|
| 14 |
+
from origami_server.engine.fold_parser import validate_fold
|
| 15 |
+
from origami_server.engine.shape_match import compute_shape_match
|
| 16 |
+
from origami_server.engine.simulate import simulate
|
| 17 |
+
from origami_server.tasks import get_task
|
| 18 |
|
| 19 |
|
| 20 |
def extract_fold_json(response: str) -> dict | None:
|
training/train_grpo.py
CHANGED
|
@@ -55,7 +55,7 @@ def main():
|
|
| 55 |
from trl import GRPOConfig, GRPOTrainer
|
| 56 |
from unsloth import FastLanguageModel
|
| 57 |
|
| 58 |
-
from
|
| 59 |
from training.reward import shape_match, valid_fold
|
| 60 |
|
| 61 |
task = get_task(args.task)
|
|
|
|
| 55 |
from trl import GRPOConfig, GRPOTrainer
|
| 56 |
from unsloth import FastLanguageModel
|
| 57 |
|
| 58 |
+
from origami_server.tasks import get_task
|
| 59 |
from training.reward import shape_match, valid_fold
|
| 60 |
|
| 61 |
task = get_task(args.task)
|