Spaces:
Running
Running
Deploy from feature/updates branch (PR #132)
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- Dockerfile +6 -2
- core/env_server/http_server.py +0 -257
- core/env_server/types.py +0 -57
- core/http_env_client.py +0 -203
- core/pyproject.toml +0 -47
- core/uv.lock +0 -0
- {core → src/core}/README.md +11 -11
- {core → src/core}/__init__.py +4 -7
- {core → src/core}/client_types.py +3 -2
- src/core/containers/__init__.py +7 -0
- {core → src/core}/containers/images/Dockerfile +0 -0
- {core → src/core}/containers/images/README.md +3 -3
- {core → src/core}/containers/runtime/__init__.py +12 -2
- {core → src/core}/containers/runtime/providers.py +374 -6
- src/core/containers/runtime/uv_provider.py +224 -0
- {core → src/core}/containers/test_local_docker_provider.py +9 -6
- src/core/env_client.py +361 -0
- {core → src/core}/env_server/__init__.py +62 -1
- {core → src/core}/env_server/base_transforms.py +1 -1
- src/core/env_server/exceptions.py +105 -0
- src/core/env_server/http_server.py +935 -0
- {core → src/core}/env_server/interfaces.py +194 -118
- src/core/env_server/route_config.py +57 -0
- src/core/env_server/serialization.py +137 -0
- src/core/env_server/types.py +341 -0
- {core → src/core}/env_server/web_interface.py +355 -347
- {core → src/core}/tools/__init__.py +1 -1
- {core → src/core}/tools/git_server_client.py +9 -2
- {core → src/core}/tools/local_python_executor.py +10 -3
- src/core/utils.py +27 -0
- src/openenv/__init__.py +12 -0
- {core/containers → src/openenv/cli}/__init__.py +3 -1
- src/openenv/cli/__main__.py +58 -0
- src/openenv/cli/_cli_utils.py +76 -0
- src/openenv/cli/_validation.py +159 -0
- src/openenv/cli/commands/__init__.py +11 -0
- src/openenv/cli/commands/build.py +453 -0
- src/openenv/cli/commands/init.py +501 -0
- src/openenv/cli/commands/push.py +541 -0
- src/openenv/cli/commands/serve.py +92 -0
- src/openenv/cli/commands/validate.py +108 -0
- src/openenv/cli/templates/__init__.py +7 -0
- src/openenv/cli/templates/openenv_env/.dockerignore +15 -0
- src/openenv/cli/templates/openenv_env/README.md +255 -0
- src/openenv/cli/templates/openenv_env/__init__.py +16 -0
- src/openenv/cli/templates/openenv_env/client.py +99 -0
- src/openenv/cli/templates/openenv_env/models.py +28 -0
- src/openenv/cli/templates/openenv_env/openenv.yaml +7 -0
- src/openenv/cli/templates/openenv_env/pyproject.toml +45 -0
- src/openenv/cli/templates/openenv_env/server/Dockerfile +80 -0
Dockerfile
CHANGED
|
@@ -1,13 +1,17 @@
|
|
|
|
|
| 1 |
ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
|
| 2 |
FROM ${BASE_IMAGE}
|
| 3 |
|
| 4 |
-
|
| 5 |
-
COPY src/
|
| 6 |
COPY README.md /app/README.md
|
| 7 |
|
|
|
|
| 8 |
ENV ENABLE_WEB_INTERFACE=true
|
| 9 |
|
|
|
|
| 10 |
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
|
| 11 |
CMD curl -f http://localhost:8000/health || exit 1
|
| 12 |
|
|
|
|
| 13 |
CMD ["sh", "-lc", "python -m uvicorn envs.wildfire_env.server.app:app --host 0.0.0.0 --port ${PORT:-8000} --proxy-headers --forwarded-allow-ips='*'"]
|
|
|
|
| 1 |
+
# Use OpenEnv base image
|
| 2 |
ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
|
| 3 |
FROM ${BASE_IMAGE}
|
| 4 |
|
| 5 |
+
# Copy all source files
|
| 6 |
+
COPY src/ /app/src/
|
| 7 |
COPY README.md /app/README.md
|
| 8 |
|
| 9 |
+
# Set environment variables
|
| 10 |
ENV ENABLE_WEB_INTERFACE=true
|
| 11 |
|
| 12 |
+
# Health check
|
| 13 |
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
|
| 14 |
CMD curl -f http://localhost:8000/health || exit 1
|
| 15 |
|
| 16 |
+
# Run the wildfire environment server
|
| 17 |
CMD ["sh", "-lc", "python -m uvicorn envs.wildfire_env.server.app:app --host 0.0.0.0 --port ${PORT:-8000} --proxy-headers --forwarded-allow-ips='*'"]
|
core/env_server/http_server.py
DELETED
|
@@ -1,257 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the BSD-style license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
"""
|
| 8 |
-
HTTP server wrapper for Environment instances.
|
| 9 |
-
|
| 10 |
-
This module provides utilities to wrap any Environment subclass and expose it
|
| 11 |
-
over HTTP endpoints that HTTPEnvClient can consume.
|
| 12 |
-
"""
|
| 13 |
-
|
| 14 |
-
from __future__ import annotations
|
| 15 |
-
|
| 16 |
-
import asyncio
|
| 17 |
-
import os
|
| 18 |
-
from concurrent.futures import ThreadPoolExecutor
|
| 19 |
-
from dataclasses import asdict
|
| 20 |
-
from typing import Any, Dict, Type
|
| 21 |
-
|
| 22 |
-
from .interfaces import Environment
|
| 23 |
-
from .types import Action, Observation
|
| 24 |
-
from fastapi import Body, FastAPI
|
| 25 |
-
|
| 26 |
-
class HTTPEnvServer:
|
| 27 |
-
"""
|
| 28 |
-
HTTP server wrapper for Environment instances.
|
| 29 |
-
|
| 30 |
-
This class wraps an Environment and exposes its reset(), step(), and state
|
| 31 |
-
methods as HTTP endpoints compatible with HTTPEnvClient.
|
| 32 |
-
|
| 33 |
-
The server expects:
|
| 34 |
-
- Action deserialization: Converts JSON dict to Action subclass
|
| 35 |
-
- Observation serialization: Converts Observation subclass to JSON dict
|
| 36 |
-
|
| 37 |
-
Example:
|
| 38 |
-
>>> from core.env_server import HTTPEnvServer
|
| 39 |
-
>>> from envs.coding_env.server import CodeExecutionEnvironment
|
| 40 |
-
>>>
|
| 41 |
-
>>> env = CodeExecutionEnvironment()
|
| 42 |
-
>>> server = HTTPEnvServer(env)
|
| 43 |
-
>>>
|
| 44 |
-
>>> # Register routes with FastAPI
|
| 45 |
-
>>> from fastapi import FastAPI
|
| 46 |
-
>>> app = FastAPI()
|
| 47 |
-
>>> server.register_routes(app)
|
| 48 |
-
"""
|
| 49 |
-
|
| 50 |
-
def __init__(
|
| 51 |
-
self,
|
| 52 |
-
env: Environment,
|
| 53 |
-
action_cls: Type[Action],
|
| 54 |
-
observation_cls: Type[Observation],
|
| 55 |
-
):
|
| 56 |
-
"""
|
| 57 |
-
Initialize HTTP server wrapper.
|
| 58 |
-
|
| 59 |
-
Args:
|
| 60 |
-
env: The Environment instance to wrap
|
| 61 |
-
action_cls: The Action subclass this environment expects
|
| 62 |
-
observation_cls: The Observation subclass this environment returns
|
| 63 |
-
"""
|
| 64 |
-
self.env = env
|
| 65 |
-
self.action_cls = action_cls
|
| 66 |
-
self.observation_cls = observation_cls
|
| 67 |
-
# Create thread pool for running sync code in async context
|
| 68 |
-
# This is needed for environments using sync libraries (e.g., Playwright sync API)
|
| 69 |
-
self._executor = ThreadPoolExecutor(max_workers=1)
|
| 70 |
-
|
| 71 |
-
def register_routes(self, app: Any) -> None:
|
| 72 |
-
"""
|
| 73 |
-
Register HTTP routes on a FastAPI application.
|
| 74 |
-
|
| 75 |
-
Args:
|
| 76 |
-
app: FastAPI application instance
|
| 77 |
-
"""
|
| 78 |
-
|
| 79 |
-
if not isinstance(app, FastAPI):
|
| 80 |
-
raise TypeError("app must be a FastAPI instance")
|
| 81 |
-
|
| 82 |
-
@app.post("/reset")
|
| 83 |
-
async def reset(request: Dict[str, Any] = Body(default={})) -> Dict[str, Any]:
|
| 84 |
-
"""Reset endpoint - returns initial observation."""
|
| 85 |
-
# TODO: Handle seed, episode_id from request if provided
|
| 86 |
-
# Run sync environment code in thread pool to avoid blocking asyncio loop
|
| 87 |
-
loop = asyncio.get_event_loop()
|
| 88 |
-
observation = await loop.run_in_executor(self._executor, self.env.reset)
|
| 89 |
-
return self._serialize_observation(observation)
|
| 90 |
-
|
| 91 |
-
@app.post("/step")
|
| 92 |
-
async def step(request: Dict[str, Any]) -> Dict[str, Any]:
|
| 93 |
-
"""Step endpoint - executes action and returns observation."""
|
| 94 |
-
# Support both {"action": {...}} and direct action fields
|
| 95 |
-
action_data = request.get("action", request)
|
| 96 |
-
# TODO: Handle timeout_s, request_id, episode_id from request if provided
|
| 97 |
-
|
| 98 |
-
# Deserialize action
|
| 99 |
-
action = self._deserialize_action(action_data)
|
| 100 |
-
|
| 101 |
-
# Execute step in thread pool to avoid blocking asyncio loop
|
| 102 |
-
loop = asyncio.get_event_loop()
|
| 103 |
-
observation = await loop.run_in_executor(
|
| 104 |
-
self._executor, self.env.step, action
|
| 105 |
-
)
|
| 106 |
-
|
| 107 |
-
# Return serialized observation
|
| 108 |
-
return self._serialize_observation(observation)
|
| 109 |
-
|
| 110 |
-
@app.get("/state")
|
| 111 |
-
async def get_state() -> Dict[str, Any]:
|
| 112 |
-
"""State endpoint - returns current environment state."""
|
| 113 |
-
state = self.env.state
|
| 114 |
-
return asdict(state)
|
| 115 |
-
|
| 116 |
-
@app.get("/health")
|
| 117 |
-
async def health() -> Dict[str, str]:
|
| 118 |
-
"""Health check endpoint."""
|
| 119 |
-
return {"status": "healthy"}
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
def _deserialize_action(self, action_data: Dict[str, Any]) -> Action:
|
| 123 |
-
"""
|
| 124 |
-
Convert JSON dict to Action instance.
|
| 125 |
-
|
| 126 |
-
Args:
|
| 127 |
-
action_data: Dictionary containing action data
|
| 128 |
-
|
| 129 |
-
Returns:
|
| 130 |
-
Action instance
|
| 131 |
-
|
| 132 |
-
Note:
|
| 133 |
-
This is a simple implementation. Subclasses may need to override
|
| 134 |
-
for more complex deserialization logic.
|
| 135 |
-
"""
|
| 136 |
-
# Remove metadata if present (it will be set via kw_only field)
|
| 137 |
-
metadata = action_data.pop("metadata", {})
|
| 138 |
-
action = self.action_cls(**action_data)
|
| 139 |
-
action.metadata = metadata
|
| 140 |
-
return action
|
| 141 |
-
|
| 142 |
-
def _serialize_observation(self, observation: Observation) -> Dict[str, Any]:
|
| 143 |
-
"""
|
| 144 |
-
Convert Observation instance to JSON-compatible dict.
|
| 145 |
-
|
| 146 |
-
Args:
|
| 147 |
-
observation: Observation instance
|
| 148 |
-
|
| 149 |
-
Returns:
|
| 150 |
-
Dictionary compatible with HTTPEnvClient._parse_result()
|
| 151 |
-
|
| 152 |
-
The format matches what HTTPEnvClient expects:
|
| 153 |
-
{
|
| 154 |
-
"observation": {...}, # Observation fields
|
| 155 |
-
"reward": float | None,
|
| 156 |
-
"done": bool,
|
| 157 |
-
}
|
| 158 |
-
"""
|
| 159 |
-
obs_dict = asdict(observation)
|
| 160 |
-
|
| 161 |
-
# Convert numpy arrays to lists for JSON serialization
|
| 162 |
-
def _convert_numpy(obj):
|
| 163 |
-
"""Recursively convert numpy arrays to lists."""
|
| 164 |
-
if hasattr(obj, '__array__'): # numpy array
|
| 165 |
-
return obj.tolist()
|
| 166 |
-
elif isinstance(obj, dict):
|
| 167 |
-
return {k: _convert_numpy(v) for k, v in obj.items()}
|
| 168 |
-
elif isinstance(obj, (list, tuple)):
|
| 169 |
-
return type(obj)(_convert_numpy(item) for item in obj)
|
| 170 |
-
return obj
|
| 171 |
-
|
| 172 |
-
obs_dict = _convert_numpy(obs_dict)
|
| 173 |
-
|
| 174 |
-
# Extract reward and done (these are part of StepResult on client side)
|
| 175 |
-
reward = obs_dict.pop("reward", None)
|
| 176 |
-
done = obs_dict.pop("done", False)
|
| 177 |
-
obs_dict.pop("metadata", None) # Remove metadata from observation
|
| 178 |
-
|
| 179 |
-
# Return in HTTPEnvClient expected format
|
| 180 |
-
return {
|
| 181 |
-
"observation": obs_dict,
|
| 182 |
-
"reward": reward,
|
| 183 |
-
"done": done,
|
| 184 |
-
}
|
| 185 |
-
|
| 186 |
-
def create_app(
|
| 187 |
-
env: Environment,
|
| 188 |
-
action_cls: Type[Action],
|
| 189 |
-
observation_cls: Type[Observation],
|
| 190 |
-
env_name: Optional[str] = None,
|
| 191 |
-
) -> Any:
|
| 192 |
-
"""
|
| 193 |
-
Create a FastAPI application with or without web interface.
|
| 194 |
-
|
| 195 |
-
This function creates a FastAPI app with the web interface enabled by default,
|
| 196 |
-
including README integration for better user experience.
|
| 197 |
-
|
| 198 |
-
Args:
|
| 199 |
-
env: The Environment instance to serve
|
| 200 |
-
action_cls: The Action subclass this environment expects
|
| 201 |
-
observation_cls: The Observation subclass this environment returns
|
| 202 |
-
env_name: Optional environment name for README loading
|
| 203 |
-
|
| 204 |
-
Returns:
|
| 205 |
-
FastAPI application instance with or without web interface and README integration
|
| 206 |
-
"""
|
| 207 |
-
# Check if web interface should be enabled
|
| 208 |
-
# This can be controlled via environment variable or build argument
|
| 209 |
-
enable_web = (
|
| 210 |
-
os.getenv("ENABLE_WEB_INTERFACE", "false").lower() in ("true", "1", "yes")
|
| 211 |
-
)
|
| 212 |
-
|
| 213 |
-
if enable_web:
|
| 214 |
-
# Import web interface only when needed
|
| 215 |
-
from .web_interface import create_web_interface_app
|
| 216 |
-
return create_web_interface_app(env, action_cls, observation_cls, env_name)
|
| 217 |
-
else:
|
| 218 |
-
# Use standard FastAPI app without web interface
|
| 219 |
-
return create_fastapi_app(env, action_cls, observation_cls)
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
def create_fastapi_app(
|
| 223 |
-
env: Environment,
|
| 224 |
-
action_cls: Type[Action],
|
| 225 |
-
observation_cls: Type[Observation],
|
| 226 |
-
) -> Any:
|
| 227 |
-
"""
|
| 228 |
-
Create a FastAPI application with routes for the given environment.
|
| 229 |
-
|
| 230 |
-
Args:
|
| 231 |
-
env: The Environment instance to serve
|
| 232 |
-
action_cls: The Action subclass this environment expects
|
| 233 |
-
observation_cls: The Observation subclass this environment returns
|
| 234 |
-
|
| 235 |
-
Returns:
|
| 236 |
-
FastAPI application instance with routes registered
|
| 237 |
-
|
| 238 |
-
Example:
|
| 239 |
-
>>> from envs.coding_env.server import CodeExecutionEnvironment
|
| 240 |
-
>>> from envs.coding_env.models import CodeAction, CodeObservation
|
| 241 |
-
>>>
|
| 242 |
-
>>> env = CodeExecutionEnvironment()
|
| 243 |
-
>>> app = create_fastapi_app(env, CodeAction, CodeObservation)
|
| 244 |
-
>>>
|
| 245 |
-
>>> # Run with: uvicorn module:app --host 0.0.0.0 --port 8000
|
| 246 |
-
"""
|
| 247 |
-
try:
|
| 248 |
-
from fastapi import FastAPI
|
| 249 |
-
except ImportError:
|
| 250 |
-
raise ImportError(
|
| 251 |
-
"FastAPI is required. Install with: pip install fastapi uvicorn"
|
| 252 |
-
)
|
| 253 |
-
|
| 254 |
-
app = FastAPI(title="Environment HTTP Server")
|
| 255 |
-
server = HTTPEnvServer(env, action_cls, observation_cls)
|
| 256 |
-
server.register_routes(app)
|
| 257 |
-
return app
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
core/env_server/types.py
DELETED
|
@@ -1,57 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the BSD-style license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
from dataclasses import dataclass, field
|
| 8 |
-
from typing import Any, Dict, List, Optional, Union
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
# Type aliases
|
| 12 |
-
Scalar = Union[int, float, bool]
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
@dataclass(kw_only=True)
|
| 16 |
-
class Action:
|
| 17 |
-
"""Base class for all environment actions."""
|
| 18 |
-
|
| 19 |
-
metadata: Dict[str, Any] = field(default_factory=dict)
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
@dataclass(kw_only=True)
|
| 23 |
-
class Observation:
|
| 24 |
-
"""Base class for all environment observations."""
|
| 25 |
-
|
| 26 |
-
done: bool = False
|
| 27 |
-
reward: Union[bool, int, float, None] = None
|
| 28 |
-
metadata: Dict[str, Any] = field(default_factory=dict)
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
@dataclass
|
| 32 |
-
class State:
|
| 33 |
-
"""Base class for environment state."""
|
| 34 |
-
|
| 35 |
-
episode_id: Optional[str] = None
|
| 36 |
-
step_count: int = 0
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
@dataclass
|
| 40 |
-
class CodeExecResult:
|
| 41 |
-
"""Result of code execution containing stdout, stderr, and exit code."""
|
| 42 |
-
|
| 43 |
-
stdout: str
|
| 44 |
-
stderr: str
|
| 45 |
-
exit_code: int
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
@dataclass
|
| 49 |
-
class EnvironmentMetadata:
|
| 50 |
-
"""Metadata about an environment for documentation and UI purposes."""
|
| 51 |
-
|
| 52 |
-
name: str
|
| 53 |
-
description: str
|
| 54 |
-
readme_content: Optional[str] = None
|
| 55 |
-
version: Optional[str] = None
|
| 56 |
-
author: Optional[str] = None
|
| 57 |
-
documentation_url: Optional[str] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
core/http_env_client.py
DELETED
|
@@ -1,203 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
core/runner_env.py
|
| 3 |
-
Minimal HTTP-based environment client.
|
| 4 |
-
- Talks to a single env worker exposing: POST /reset, POST /step
|
| 5 |
-
|
| 6 |
-
Future hooks (commented below) for:
|
| 7 |
-
- episode_id, seed on reset
|
| 8 |
-
- request_id on step
|
| 9 |
-
- custom headers (auth/trace)
|
| 10 |
-
"""
|
| 11 |
-
|
| 12 |
-
from __future__ import annotations
|
| 13 |
-
|
| 14 |
-
from abc import ABC, abstractmethod
|
| 15 |
-
from typing import Any, Dict, Generic, Optional, Type, TYPE_CHECKING, TypeVar
|
| 16 |
-
|
| 17 |
-
import requests
|
| 18 |
-
|
| 19 |
-
from .client_types import StepResult
|
| 20 |
-
from .containers.runtime import LocalDockerProvider
|
| 21 |
-
|
| 22 |
-
if TYPE_CHECKING:
|
| 23 |
-
from .containers.runtime import ContainerProvider
|
| 24 |
-
|
| 25 |
-
ActT = TypeVar("ActT")
|
| 26 |
-
ObsT = TypeVar("ObsT")
|
| 27 |
-
EnvClientT = TypeVar("EnvClientT", bound="HTTPEnvClient")
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
class HTTPEnvClient(ABC, Generic[ActT, ObsT]):
|
| 31 |
-
def __init__(
|
| 32 |
-
self,
|
| 33 |
-
base_url: str,
|
| 34 |
-
request_timeout_s: float = 15.0,
|
| 35 |
-
default_headers: Optional[Dict[str, str]] = None,
|
| 36 |
-
provider: Optional["ContainerProvider"] = None,
|
| 37 |
-
):
|
| 38 |
-
self._base = base_url.rstrip("/")
|
| 39 |
-
self._timeout = float(request_timeout_s)
|
| 40 |
-
self._http = requests.Session()
|
| 41 |
-
self._headers = default_headers or {}
|
| 42 |
-
self._provider = provider
|
| 43 |
-
|
| 44 |
-
@classmethod
|
| 45 |
-
def from_docker_image(
|
| 46 |
-
cls: Type[EnvClientT],
|
| 47 |
-
image: str,
|
| 48 |
-
provider: Optional["ContainerProvider"] = None,
|
| 49 |
-
**kwargs: Any,
|
| 50 |
-
) -> EnvClientT:
|
| 51 |
-
"""
|
| 52 |
-
Create an environment client by spinning up a Docker container locally.
|
| 53 |
-
|
| 54 |
-
This is a development utility that:
|
| 55 |
-
1. Starts a Docker container from the specified image
|
| 56 |
-
2. Waits for the server to be ready
|
| 57 |
-
3. Creates and returns a client instance connected to the container
|
| 58 |
-
|
| 59 |
-
Note: The container lifecycle management is left to the user or higher-level
|
| 60 |
-
orchestration. The container will keep running until manually stopped.
|
| 61 |
-
|
| 62 |
-
Args:
|
| 63 |
-
image: Docker image name to run (e.g., "echo-env:latest")
|
| 64 |
-
provider: Container provider to use (defaults to LocalDockerProvider)
|
| 65 |
-
**kwargs: Additional arguments to pass to provider.start_container()
|
| 66 |
-
(e.g., env_vars, port)
|
| 67 |
-
|
| 68 |
-
Returns:
|
| 69 |
-
An instance of the client class connected to the running container
|
| 70 |
-
|
| 71 |
-
Example:
|
| 72 |
-
>>> from envs.coding_env.client import CodingEnv
|
| 73 |
-
>>> from envs.coding_env.models import CodeAction
|
| 74 |
-
>>>
|
| 75 |
-
>>> # Create environment from image
|
| 76 |
-
>>> env = CodingEnv.from_docker_image("coding-env:latest")
|
| 77 |
-
>>>
|
| 78 |
-
>>> # Create environment with custom env vars
|
| 79 |
-
>>> env = CodingEnv.from_docker_image(
|
| 80 |
-
... "coding-env:latest",
|
| 81 |
-
... env_vars={"MY_VAR": "value"}
|
| 82 |
-
... )
|
| 83 |
-
>>>
|
| 84 |
-
>>> # Use the environment
|
| 85 |
-
>>> result = env.reset()
|
| 86 |
-
>>> print(result.observation)
|
| 87 |
-
>>>
|
| 88 |
-
>>> step_result = env.step(CodeAction(code="print('hello')"))
|
| 89 |
-
>>> print(step_result.observation.stdout)
|
| 90 |
-
>>>
|
| 91 |
-
>>> # Cleanup (optional)
|
| 92 |
-
>>> env.close()
|
| 93 |
-
"""
|
| 94 |
-
|
| 95 |
-
# Use default provider if none provided
|
| 96 |
-
if provider is None:
|
| 97 |
-
provider = LocalDockerProvider()
|
| 98 |
-
|
| 99 |
-
# 1. Start container with optional kwargs (e.g., env_vars, port)
|
| 100 |
-
base_url = provider.start_container(image, **kwargs)
|
| 101 |
-
|
| 102 |
-
# 2. Wait for server to be ready
|
| 103 |
-
provider.wait_for_ready(base_url)
|
| 104 |
-
|
| 105 |
-
# 3. Create and return client instance with provider reference
|
| 106 |
-
return cls(base_url=base_url, provider=provider)
|
| 107 |
-
|
| 108 |
-
@classmethod
|
| 109 |
-
def from_hub(cls: Type[EnvClientT], repo_id: str, provider: Optional["ContainerProvider"] = None, **kwargs: Any) -> EnvClientT:
|
| 110 |
-
"""
|
| 111 |
-
Create an environment client by pulling from a Hugging Face model hub.
|
| 112 |
-
"""
|
| 113 |
-
|
| 114 |
-
if provider is None:
|
| 115 |
-
provider = LocalDockerProvider()
|
| 116 |
-
|
| 117 |
-
if "tag" in kwargs:
|
| 118 |
-
tag = kwargs["tag"]
|
| 119 |
-
else:
|
| 120 |
-
tag = "latest"
|
| 121 |
-
|
| 122 |
-
base_url = f"registry.hf.space/{repo_id.replace('/', '-')}:{tag}"
|
| 123 |
-
|
| 124 |
-
return cls.from_docker_image(image=base_url, provider=provider)
|
| 125 |
-
|
| 126 |
-
@abstractmethod
|
| 127 |
-
def _step_payload(self, action: ActT) -> dict:
|
| 128 |
-
"""Convert an Action object to the JSON body expected by the env server."""
|
| 129 |
-
raise NotImplementedError
|
| 130 |
-
|
| 131 |
-
@abstractmethod
|
| 132 |
-
def _parse_result(self, payload: dict) -> StepResult[ObsT]:
|
| 133 |
-
"""Convert a JSON response from the env server to StepResult[ObsT]."""
|
| 134 |
-
raise NotImplementedError
|
| 135 |
-
|
| 136 |
-
@abstractmethod
|
| 137 |
-
def _parse_state(self, payload: dict) -> Any:
|
| 138 |
-
"""Convert a JSON response from the state endpoint to a State object."""
|
| 139 |
-
raise NotImplementedError
|
| 140 |
-
|
| 141 |
-
# ---------- Environment Server Interface Methods ----------
|
| 142 |
-
def reset(self) -> StepResult[ObsT]:
|
| 143 |
-
body: Dict[str, Any] = {}
|
| 144 |
-
# TODO: later:
|
| 145 |
-
# body["seed"] = seed
|
| 146 |
-
# body["episode_id"] = episode_id
|
| 147 |
-
r = self._http.post(
|
| 148 |
-
f"{self._base}/reset",
|
| 149 |
-
json=body,
|
| 150 |
-
headers=self._headers,
|
| 151 |
-
timeout=self._timeout,
|
| 152 |
-
)
|
| 153 |
-
r.raise_for_status()
|
| 154 |
-
return self._parse_result(r.json())
|
| 155 |
-
|
| 156 |
-
def step(self, action: ActT) -> StepResult[ObsT]:
|
| 157 |
-
body: Dict[str, Any] = {
|
| 158 |
-
"action": self._step_payload(action),
|
| 159 |
-
"timeout_s": int(self._timeout),
|
| 160 |
-
}
|
| 161 |
-
# TODO: later:
|
| 162 |
-
# body["request_id"] = str(uuid.uuid4())
|
| 163 |
-
# body["episode_id"] = current_episode_id
|
| 164 |
-
r = self._http.post(
|
| 165 |
-
f"{self._base}/step",
|
| 166 |
-
json=body,
|
| 167 |
-
headers=self._headers,
|
| 168 |
-
timeout=self._timeout,
|
| 169 |
-
)
|
| 170 |
-
r.raise_for_status()
|
| 171 |
-
return self._parse_result(r.json())
|
| 172 |
-
|
| 173 |
-
def state(self) -> Any:
|
| 174 |
-
"""
|
| 175 |
-
Get the current environment state from the server.
|
| 176 |
-
|
| 177 |
-
Returns:
|
| 178 |
-
State object with environment state information (e.g., episode_id, step_count)
|
| 179 |
-
|
| 180 |
-
Example:
|
| 181 |
-
>>> client = EchoEnv.from_docker_image("echo-env:latest")
|
| 182 |
-
>>> result = client.reset()
|
| 183 |
-
>>> state = client.state()
|
| 184 |
-
>>> print(state.episode_id)
|
| 185 |
-
>>> print(state.step_count)
|
| 186 |
-
"""
|
| 187 |
-
r = self._http.get(
|
| 188 |
-
f"{self._base}/state",
|
| 189 |
-
headers=self._headers,
|
| 190 |
-
timeout=self._timeout,
|
| 191 |
-
)
|
| 192 |
-
r.raise_for_status()
|
| 193 |
-
return self._parse_state(r.json())
|
| 194 |
-
|
| 195 |
-
def close(self) -> None:
|
| 196 |
-
"""
|
| 197 |
-
Close the environment and clean up resources.
|
| 198 |
-
|
| 199 |
-
If this client was created via from_docker_image(), this will stop
|
| 200 |
-
and remove the associated container.
|
| 201 |
-
"""
|
| 202 |
-
if self._provider is not None:
|
| 203 |
-
self._provider.stop_container()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
core/pyproject.toml
DELETED
|
@@ -1,47 +0,0 @@
|
|
| 1 |
-
[build-system]
|
| 2 |
-
requires = ["setuptools>=45", "wheel"]
|
| 3 |
-
build-backend = "setuptools.build_meta"
|
| 4 |
-
|
| 5 |
-
[project]
|
| 6 |
-
name = "openenv-core"
|
| 7 |
-
version = "0.1.0"
|
| 8 |
-
description = "Core components for OpenEnv - HTTP-based agentic environments"
|
| 9 |
-
readme = "README.md"
|
| 10 |
-
requires-python = ">=3.10"
|
| 11 |
-
license = {text = "BSD-3-Clause"}
|
| 12 |
-
authors = [
|
| 13 |
-
{name = "Meta Platforms, Inc.", email = "opensource@meta.com"}
|
| 14 |
-
]
|
| 15 |
-
keywords = ["environment", "agent", "http", "docker", "fastapi"]
|
| 16 |
-
|
| 17 |
-
dependencies = [
|
| 18 |
-
"fastapi>=0.104.0",
|
| 19 |
-
"pydantic>=2.0.0",
|
| 20 |
-
"uvicorn[standard]>=0.24.0",
|
| 21 |
-
"requests>=2.25.0",
|
| 22 |
-
]
|
| 23 |
-
|
| 24 |
-
[project.optional-dependencies]
|
| 25 |
-
dev = [
|
| 26 |
-
"pytest>=7.0.0",
|
| 27 |
-
"black>=23.0.0",
|
| 28 |
-
"ruff>=0.1.0",
|
| 29 |
-
"mypy>=1.0.0",
|
| 30 |
-
]
|
| 31 |
-
|
| 32 |
-
[project.urls]
|
| 33 |
-
Homepage = "https://github.com/facebookresearch/OpenEnv"
|
| 34 |
-
Repository = "https://github.com/facebookresearch/OpenEnv"
|
| 35 |
-
Documentation = "https://github.com/facebookresearch/OpenEnv/blob/main/README.md"
|
| 36 |
-
"Bug Tracker" = "https://github.com/facebookresearch/OpenEnv/issues"
|
| 37 |
-
|
| 38 |
-
[tool.setuptools]
|
| 39 |
-
py-modules = ["openenv_core.__init__", "openenv_core.http_env_client", "openenv_core.client_types"]
|
| 40 |
-
packages = [
|
| 41 |
-
"openenv_core",
|
| 42 |
-
"openenv_core.containers",
|
| 43 |
-
"openenv_core.containers.runtime",
|
| 44 |
-
"openenv_core.env_server",
|
| 45 |
-
"openenv_core.tools"
|
| 46 |
-
]
|
| 47 |
-
package-dir = {"openenv_core" = "."}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
core/uv.lock
DELETED
|
The diff for this file is too large to render.
See raw diff
|
|
|
{core → src/core}/README.md
RENAMED
|
@@ -6,7 +6,7 @@ In addition to making it easier for researchers and RL framework writers, we als
|
|
| 6 |
|
| 7 |
|
| 8 |
## Overview
|
| 9 |
-
`openenv
|
| 10 |
|
| 11 |
> ⚠️ **Early Development Warning** OpenEnv is currently in an experimental
|
| 12 |
> stage. You should expect bugs, incomplete features, and APIs that may change
|
|
@@ -22,8 +22,8 @@ Core components for OpenEnv - a framework for building HTTP-based agentic enviro
|
|
| 22 |
|
| 23 |
## Features
|
| 24 |
|
| 25 |
-
- **
|
| 26 |
-
- **HTTPEnvServer**: FastAPI-based server wrapper for exposing environments over HTTP
|
| 27 |
- **Container Providers**: Pluggable architecture for running containers (Docker, Kubernetes, etc.)
|
| 28 |
- **Type System**: Strongly-typed Action/Observation/State interfaces
|
| 29 |
- **Web Interface**: Optional web UI for interacting with environments
|
|
@@ -31,12 +31,12 @@ Core components for OpenEnv - a framework for building HTTP-based agentic enviro
|
|
| 31 |
## Installation
|
| 32 |
|
| 33 |
```bash
|
| 34 |
-
pip install openenv
|
| 35 |
```
|
| 36 |
|
| 37 |
For development:
|
| 38 |
```bash
|
| 39 |
-
pip install openenv
|
| 40 |
```
|
| 41 |
|
| 42 |
## Quick Start
|
|
@@ -44,7 +44,7 @@ pip install openenv-core[dev]
|
|
| 44 |
### Creating an Environment Client
|
| 45 |
|
| 46 |
```python
|
| 47 |
-
from
|
| 48 |
from dataclasses import dataclass
|
| 49 |
|
| 50 |
@dataclass
|
|
@@ -55,7 +55,7 @@ class MyAction:
|
|
| 55 |
class MyObservation:
|
| 56 |
response: str
|
| 57 |
|
| 58 |
-
class MyEnvClient(
|
| 59 |
def _step_payload(self, action: MyAction) -> dict:
|
| 60 |
return {"text": action.text}
|
| 61 |
|
|
@@ -80,7 +80,7 @@ env.close()
|
|
| 80 |
### Creating an Environment Server
|
| 81 |
|
| 82 |
```python
|
| 83 |
-
from
|
| 84 |
from dataclasses import dataclass
|
| 85 |
|
| 86 |
@dataclass
|
|
@@ -118,7 +118,7 @@ OpenEnv Core supports multiple container providers:
|
|
| 118 |
### Local Docker Provider
|
| 119 |
|
| 120 |
```python
|
| 121 |
-
from
|
| 122 |
|
| 123 |
provider = LocalDockerProvider()
|
| 124 |
base_url = provider.start_container("my-env:latest")
|
|
@@ -130,7 +130,7 @@ provider.stop_container()
|
|
| 130 |
### Kubernetes Provider (Coming Soon)
|
| 131 |
|
| 132 |
```python
|
| 133 |
-
from
|
| 134 |
|
| 135 |
provider = KubernetesProvider(namespace="envs")
|
| 136 |
base_url = provider.start_container("my-env:latest")
|
|
@@ -141,7 +141,7 @@ provider.stop_container()
|
|
| 141 |
|
| 142 |
## API Reference
|
| 143 |
|
| 144 |
-
###
|
| 145 |
|
| 146 |
Base class for environment clients with these abstract methods:
|
| 147 |
|
|
|
|
| 6 |
|
| 7 |
|
| 8 |
## Overview
|
| 9 |
+
`openenv.core` provides the foundational building blocks for creating and interacting with containerized environments over HTTP. It enables you to build agent environments that can be deployed as Docker containers and accessed via a simple HTTP API.
|
| 10 |
|
| 11 |
> ⚠️ **Early Development Warning** OpenEnv is currently in an experimental
|
| 12 |
> stage. You should expect bugs, incomplete features, and APIs that may change
|
|
|
|
| 22 |
|
| 23 |
## Features
|
| 24 |
|
| 25 |
+
- **EnvClient**: Generic client for interacting with remote environments
|
| 26 |
+
- **HTTPEnvServer**: FastAPI-based server wrapper for exposing environments over HTTP/WebSocket
|
| 27 |
- **Container Providers**: Pluggable architecture for running containers (Docker, Kubernetes, etc.)
|
| 28 |
- **Type System**: Strongly-typed Action/Observation/State interfaces
|
| 29 |
- **Web Interface**: Optional web UI for interacting with environments
|
|
|
|
| 31 |
## Installation
|
| 32 |
|
| 33 |
```bash
|
| 34 |
+
pip install "openenv[core]"
|
| 35 |
```
|
| 36 |
|
| 37 |
For development:
|
| 38 |
```bash
|
| 39 |
+
pip install "openenv[core]"
|
| 40 |
```
|
| 41 |
|
| 42 |
## Quick Start
|
|
|
|
| 44 |
### Creating an Environment Client
|
| 45 |
|
| 46 |
```python
|
| 47 |
+
from openenv.core import EnvClient, StepResult
|
| 48 |
from dataclasses import dataclass
|
| 49 |
|
| 50 |
@dataclass
|
|
|
|
| 55 |
class MyObservation:
|
| 56 |
response: str
|
| 57 |
|
| 58 |
+
class MyEnvClient(EnvClient[MyAction, MyObservation]):
|
| 59 |
def _step_payload(self, action: MyAction) -> dict:
|
| 60 |
return {"text": action.text}
|
| 61 |
|
|
|
|
| 80 |
### Creating an Environment Server
|
| 81 |
|
| 82 |
```python
|
| 83 |
+
from openenv.core.env_server import Environment, HTTPEnvServer, create_app
|
| 84 |
from dataclasses import dataclass
|
| 85 |
|
| 86 |
@dataclass
|
|
|
|
| 118 |
### Local Docker Provider
|
| 119 |
|
| 120 |
```python
|
| 121 |
+
from openenv.core.containers.runtime import LocalDockerProvider
|
| 122 |
|
| 123 |
provider = LocalDockerProvider()
|
| 124 |
base_url = provider.start_container("my-env:latest")
|
|
|
|
| 130 |
### Kubernetes Provider (Coming Soon)
|
| 131 |
|
| 132 |
```python
|
| 133 |
+
from openenv.core.containers.runtime import KubernetesProvider
|
| 134 |
|
| 135 |
provider = KubernetesProvider(namespace="envs")
|
| 136 |
base_url = provider.start_container("my-env:latest")
|
|
|
|
| 141 |
|
| 142 |
## API Reference
|
| 143 |
|
| 144 |
+
### EnvClient
|
| 145 |
|
| 146 |
Base class for environment clients with these abstract methods:
|
| 147 |
|
{core → src/core}/__init__.py
RENAMED
|
@@ -7,13 +7,10 @@
|
|
| 7 |
"""Core components for agentic environments."""
|
| 8 |
|
| 9 |
# Re-export main components from submodules for convenience
|
| 10 |
-
from .env_server import *
|
| 11 |
-
from .
|
| 12 |
-
from .
|
| 13 |
|
| 14 |
# Note: MCP module doesn't export anything yet
|
| 15 |
|
| 16 |
-
__all__ = [
|
| 17 |
-
"HTTPEnvClient",
|
| 18 |
-
"StepResult",
|
| 19 |
-
]
|
|
|
|
| 7 |
"""Core components for agentic environments."""
|
| 8 |
|
| 9 |
# Re-export main components from submodules for convenience
|
| 10 |
+
from .env_server import * # noqa: F403
|
| 11 |
+
from . import env_server
|
| 12 |
+
from .env_client import EnvClient
|
| 13 |
|
| 14 |
# Note: MCP module doesn't export anything yet
|
| 15 |
|
| 16 |
+
__all__ = ["EnvClient"] + env_server.__all__ # type: ignore
|
|
|
|
|
|
|
|
|
{core → src/core}/client_types.py
RENAMED
|
@@ -1,9 +1,10 @@
|
|
| 1 |
# Type definitions for EnvTorch
|
| 2 |
from dataclasses import dataclass
|
| 3 |
-
from typing import
|
| 4 |
|
| 5 |
# Generic type for observations
|
| 6 |
-
ObsT = TypeVar("ObsT")
|
|
|
|
| 7 |
|
| 8 |
|
| 9 |
@dataclass
|
|
|
|
| 1 |
# Type definitions for EnvTorch
|
| 2 |
from dataclasses import dataclass
|
| 3 |
+
from typing import Generic, Optional, TypeVar
|
| 4 |
|
| 5 |
# Generic type for observations
|
| 6 |
+
ObsT = TypeVar("ObsT")
|
| 7 |
+
StateT = TypeVar("StateT")
|
| 8 |
|
| 9 |
|
| 10 |
@dataclass
|
src/core/containers/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Container management for environment servers."""
|
{core → src/core}/containers/images/Dockerfile
RENAMED
|
File without changes
|
{core → src/core}/containers/images/README.md
RENAMED
|
@@ -48,7 +48,7 @@ FROM openenv-base:latest
|
|
| 48 |
|
| 49 |
# Copy only environment-specific files
|
| 50 |
COPY src/core/ /app/src/core/
|
| 51 |
-
COPY
|
| 52 |
|
| 53 |
# Run the server
|
| 54 |
CMD ["uvicorn", "envs.my_env.server.app:app", "--host", "0.0.0.0", "--port", "8000"]
|
|
@@ -69,7 +69,7 @@ CMD ["uvicorn", "envs.my_env.server.app:app", "--host", "0.0.0.0", "--port", "80
|
|
| 69 |
docker build -t openenv-base:latest -f src/core/containers/images/Dockerfile .
|
| 70 |
|
| 71 |
# Step 2: Build echo environment (uses base)
|
| 72 |
-
docker build -t echo-env:latest -f
|
| 73 |
|
| 74 |
# Step 3: Run echo environment
|
| 75 |
docker run -p 8000:8000 echo-env:latest
|
|
@@ -88,5 +88,5 @@ When dependencies need updating:
|
|
| 88 |
docker build -t openenv-base:latest -f src/core/containers/images/Dockerfile .
|
| 89 |
|
| 90 |
# Rebuild environments (they automatically use new base)
|
| 91 |
-
docker build -t echo-env:latest -f
|
| 92 |
```
|
|
|
|
| 48 |
|
| 49 |
# Copy only environment-specific files
|
| 50 |
COPY src/core/ /app/src/core/
|
| 51 |
+
COPY envs/my_env/ /app/envs/my_env/
|
| 52 |
|
| 53 |
# Run the server
|
| 54 |
CMD ["uvicorn", "envs.my_env.server.app:app", "--host", "0.0.0.0", "--port", "8000"]
|
|
|
|
| 69 |
docker build -t openenv-base:latest -f src/core/containers/images/Dockerfile .
|
| 70 |
|
| 71 |
# Step 2: Build echo environment (uses base)
|
| 72 |
+
docker build -t echo-env:latest -f envs/echo_env/server/Dockerfile .
|
| 73 |
|
| 74 |
# Step 3: Run echo environment
|
| 75 |
docker run -p 8000:8000 echo-env:latest
|
|
|
|
| 88 |
docker build -t openenv-base:latest -f src/core/containers/images/Dockerfile .
|
| 89 |
|
| 90 |
# Rebuild environments (they automatically use new base)
|
| 91 |
+
docker build -t echo-env:latest -f envs/echo_env/server/Dockerfile .
|
| 92 |
```
|
{core → src/core}/containers/runtime/__init__.py
RENAMED
|
@@ -6,10 +6,20 @@
|
|
| 6 |
|
| 7 |
"""Container runtime providers."""
|
| 8 |
|
| 9 |
-
from .providers import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
__all__ = [
|
| 12 |
"ContainerProvider",
|
|
|
|
| 13 |
"LocalDockerProvider",
|
| 14 |
"KubernetesProvider",
|
| 15 |
-
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
"""Container runtime providers."""
|
| 8 |
|
| 9 |
+
from .providers import (
|
| 10 |
+
ContainerProvider,
|
| 11 |
+
DockerSwarmProvider,
|
| 12 |
+
KubernetesProvider,
|
| 13 |
+
LocalDockerProvider,
|
| 14 |
+
RuntimeProvider,
|
| 15 |
+
)
|
| 16 |
+
from .uv_provider import UVProvider
|
| 17 |
|
| 18 |
__all__ = [
|
| 19 |
"ContainerProvider",
|
| 20 |
+
"DockerSwarmProvider",
|
| 21 |
"LocalDockerProvider",
|
| 22 |
"KubernetesProvider",
|
| 23 |
+
"RuntimeProvider",
|
| 24 |
+
"UVProvider",
|
| 25 |
+
]
|
{core → src/core}/containers/runtime/providers.py
RENAMED
|
@@ -8,13 +8,13 @@
|
|
| 8 |
Container provider abstractions for running environment servers.
|
| 9 |
|
| 10 |
This module provides a pluggable architecture for different container providers
|
| 11 |
-
(local Docker, Kubernetes, cloud providers, etc.) to be used with
|
| 12 |
"""
|
| 13 |
|
| 14 |
from __future__ import annotations
|
| 15 |
|
| 16 |
from abc import ABC, abstractmethod
|
| 17 |
-
from typing import Any, Dict, Optional
|
| 18 |
|
| 19 |
|
| 20 |
class ContainerProvider(ABC):
|
|
@@ -118,7 +118,11 @@ class LocalDockerProvider(ContainerProvider):
|
|
| 118 |
capture_output=True,
|
| 119 |
timeout=5,
|
| 120 |
)
|
| 121 |
-
except (
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
raise RuntimeError(
|
| 123 |
"Docker is not available. Please install Docker Desktop or Docker Engine."
|
| 124 |
)
|
|
@@ -154,10 +158,13 @@ class LocalDockerProvider(ContainerProvider):
|
|
| 154 |
|
| 155 |
# Build docker run command
|
| 156 |
cmd = [
|
| 157 |
-
"docker",
|
|
|
|
| 158 |
"-d", # Detached
|
| 159 |
-
"--name",
|
| 160 |
-
|
|
|
|
|
|
|
| 161 |
]
|
| 162 |
|
| 163 |
# Add environment variables
|
|
@@ -277,6 +284,304 @@ class LocalDockerProvider(ContainerProvider):
|
|
| 277 |
return f"{clean_image}-{timestamp}"
|
| 278 |
|
| 279 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 280 |
class KubernetesProvider(ContainerProvider):
|
| 281 |
"""
|
| 282 |
Container provider for Kubernetes clusters.
|
|
@@ -290,4 +595,67 @@ class KubernetesProvider(ContainerProvider):
|
|
| 290 |
>>> # Pod running in k8s, accessible via service or port-forward
|
| 291 |
>>> provider.stop_container()
|
| 292 |
"""
|
|
|
|
| 293 |
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
Container provider abstractions for running environment servers.
|
| 9 |
|
| 10 |
This module provides a pluggable architecture for different container providers
|
| 11 |
+
(local Docker, Kubernetes, cloud providers, etc.) to be used with EnvClient.
|
| 12 |
"""
|
| 13 |
|
| 14 |
from __future__ import annotations
|
| 15 |
|
| 16 |
from abc import ABC, abstractmethod
|
| 17 |
+
from typing import Any, Dict, Optional, Sequence
|
| 18 |
|
| 19 |
|
| 20 |
class ContainerProvider(ABC):
|
|
|
|
| 118 |
capture_output=True,
|
| 119 |
timeout=5,
|
| 120 |
)
|
| 121 |
+
except (
|
| 122 |
+
subprocess.CalledProcessError,
|
| 123 |
+
FileNotFoundError,
|
| 124 |
+
subprocess.TimeoutExpired,
|
| 125 |
+
):
|
| 126 |
raise RuntimeError(
|
| 127 |
"Docker is not available. Please install Docker Desktop or Docker Engine."
|
| 128 |
)
|
|
|
|
| 158 |
|
| 159 |
# Build docker run command
|
| 160 |
cmd = [
|
| 161 |
+
"docker",
|
| 162 |
+
"run",
|
| 163 |
"-d", # Detached
|
| 164 |
+
"--name",
|
| 165 |
+
self._container_name,
|
| 166 |
+
"-p",
|
| 167 |
+
f"{port}:8000", # Map port
|
| 168 |
]
|
| 169 |
|
| 170 |
# Add environment variables
|
|
|
|
| 284 |
return f"{clean_image}-{timestamp}"
|
| 285 |
|
| 286 |
|
| 287 |
+
class DockerSwarmProvider(ContainerProvider):
|
| 288 |
+
"""
|
| 289 |
+
Container provider that uses Docker Swarm services for local concurrency.
|
| 290 |
+
|
| 291 |
+
This provider creates a replicated Swarm service backed by the local Docker
|
| 292 |
+
engine. The built-in load-balancer fans requests across the replicas,
|
| 293 |
+
allowing multiple container instances to run concurrently on the developer
|
| 294 |
+
workstation (mirroring the workflow described in the Docker stack docs).
|
| 295 |
+
"""
|
| 296 |
+
|
| 297 |
+
def __init__(
|
| 298 |
+
self,
|
| 299 |
+
*,
|
| 300 |
+
auto_init_swarm: bool = True,
|
| 301 |
+
overlay_network: Optional[str] = None,
|
| 302 |
+
):
|
| 303 |
+
"""
|
| 304 |
+
Args:
|
| 305 |
+
auto_init_swarm: Whether to call ``docker swarm init`` when Swarm
|
| 306 |
+
is not active. Otherwise, user must manually initialize Swarm.
|
| 307 |
+
overlay_network: Optional overlay network name for the service.
|
| 308 |
+
When provided, the network is created with
|
| 309 |
+
``docker network create --driver overlay --attachable`` if it
|
| 310 |
+
does not already exist.
|
| 311 |
+
"""
|
| 312 |
+
self._service_name: Optional[str] = None
|
| 313 |
+
self._service_id: Optional[str] = None
|
| 314 |
+
self._published_port: Optional[int] = None
|
| 315 |
+
self._overlay_network = overlay_network
|
| 316 |
+
self._auto_init_swarm = auto_init_swarm
|
| 317 |
+
|
| 318 |
+
self._ensure_docker_available()
|
| 319 |
+
self._ensure_swarm_initialized()
|
| 320 |
+
if self._overlay_network:
|
| 321 |
+
self._ensure_overlay_network(self._overlay_network)
|
| 322 |
+
|
| 323 |
+
def start_container(
|
| 324 |
+
self,
|
| 325 |
+
image: str,
|
| 326 |
+
port: Optional[int] = None,
|
| 327 |
+
env_vars: Optional[Dict[str, str]] = None,
|
| 328 |
+
**kwargs: Any,
|
| 329 |
+
) -> str:
|
| 330 |
+
"""
|
| 331 |
+
Start (or scale) a Swarm service for the given image.
|
| 332 |
+
|
| 333 |
+
Supported kwargs:
|
| 334 |
+
replicas (int): Number of container replicas (default: 2).
|
| 335 |
+
cpu_limit (float | str): CPU limit passed to ``--limit-cpu``.
|
| 336 |
+
memory_limit (str): Memory limit passed to ``--limit-memory``.
|
| 337 |
+
constraints (Sequence[str]): Placement constraints.
|
| 338 |
+
labels (Dict[str, str]): Service labels.
|
| 339 |
+
command (Sequence[str] | str): Override container command.
|
| 340 |
+
"""
|
| 341 |
+
import shlex
|
| 342 |
+
import subprocess
|
| 343 |
+
import time
|
| 344 |
+
|
| 345 |
+
allowed_kwargs = {
|
| 346 |
+
"replicas",
|
| 347 |
+
"cpu_limit",
|
| 348 |
+
"memory_limit",
|
| 349 |
+
"constraints",
|
| 350 |
+
"labels",
|
| 351 |
+
"command",
|
| 352 |
+
}
|
| 353 |
+
unknown = set(kwargs) - allowed_kwargs
|
| 354 |
+
if unknown:
|
| 355 |
+
raise ValueError(f"Unsupported kwargs for DockerSwarmProvider: {unknown}")
|
| 356 |
+
|
| 357 |
+
replicas = int(kwargs.get("replicas", 2))
|
| 358 |
+
cpu_limit = kwargs.get("cpu_limit")
|
| 359 |
+
memory_limit = kwargs.get("memory_limit")
|
| 360 |
+
constraints: Optional[Sequence[str]] = kwargs.get("constraints")
|
| 361 |
+
labels: Optional[Dict[str, str]] = kwargs.get("labels")
|
| 362 |
+
command_override = kwargs.get("command")
|
| 363 |
+
|
| 364 |
+
if port is None:
|
| 365 |
+
port = self._find_available_port()
|
| 366 |
+
|
| 367 |
+
self._service_name = self._generate_service_name(image)
|
| 368 |
+
self._published_port = port
|
| 369 |
+
|
| 370 |
+
cmd = [
|
| 371 |
+
"docker",
|
| 372 |
+
"service",
|
| 373 |
+
"create",
|
| 374 |
+
"--detach",
|
| 375 |
+
"--name",
|
| 376 |
+
self._service_name,
|
| 377 |
+
"--replicas",
|
| 378 |
+
str(max(1, replicas)),
|
| 379 |
+
"--publish",
|
| 380 |
+
f"{port}:8000",
|
| 381 |
+
]
|
| 382 |
+
|
| 383 |
+
if self._overlay_network:
|
| 384 |
+
cmd.extend(["--network", self._overlay_network])
|
| 385 |
+
|
| 386 |
+
if env_vars:
|
| 387 |
+
for key, value in env_vars.items():
|
| 388 |
+
cmd.extend(["--env", f"{key}={value}"])
|
| 389 |
+
|
| 390 |
+
if cpu_limit is not None:
|
| 391 |
+
cmd.extend(["--limit-cpu", str(cpu_limit)])
|
| 392 |
+
|
| 393 |
+
if memory_limit is not None:
|
| 394 |
+
cmd.extend(["--limit-memory", str(memory_limit)])
|
| 395 |
+
|
| 396 |
+
if constraints:
|
| 397 |
+
for constraint in constraints:
|
| 398 |
+
cmd.extend(["--constraint", constraint])
|
| 399 |
+
|
| 400 |
+
if labels:
|
| 401 |
+
for key, value in labels.items():
|
| 402 |
+
cmd.extend(["--label", f"{key}={value}"])
|
| 403 |
+
|
| 404 |
+
cmd.append(image)
|
| 405 |
+
|
| 406 |
+
if command_override:
|
| 407 |
+
if isinstance(command_override, str):
|
| 408 |
+
cmd.extend(shlex.split(command_override))
|
| 409 |
+
else:
|
| 410 |
+
cmd.extend(command_override)
|
| 411 |
+
|
| 412 |
+
try:
|
| 413 |
+
result = subprocess.run(
|
| 414 |
+
cmd,
|
| 415 |
+
capture_output=True,
|
| 416 |
+
text=True,
|
| 417 |
+
check=True,
|
| 418 |
+
)
|
| 419 |
+
self._service_id = result.stdout.strip()
|
| 420 |
+
except subprocess.CalledProcessError as e:
|
| 421 |
+
error_msg = (
|
| 422 |
+
"Failed to start Docker Swarm service.\n"
|
| 423 |
+
f"Command: {' '.join(cmd)}\n"
|
| 424 |
+
f"Exit code: {e.returncode}\n"
|
| 425 |
+
f"Stdout: {e.stdout}\n"
|
| 426 |
+
f"Stderr: {e.stderr}"
|
| 427 |
+
)
|
| 428 |
+
raise RuntimeError(error_msg) from e
|
| 429 |
+
|
| 430 |
+
# Give Swarm a brief moment to schedule the tasks.
|
| 431 |
+
time.sleep(1.0)
|
| 432 |
+
|
| 433 |
+
return f"http://localhost:{port}"
|
| 434 |
+
|
| 435 |
+
def stop_container(self) -> None:
|
| 436 |
+
"""
|
| 437 |
+
Remove the Swarm service (and keep the Swarm manager running).
|
| 438 |
+
"""
|
| 439 |
+
if not self._service_name:
|
| 440 |
+
return
|
| 441 |
+
|
| 442 |
+
import subprocess
|
| 443 |
+
|
| 444 |
+
try:
|
| 445 |
+
subprocess.run(
|
| 446 |
+
["docker", "service", "rm", self._service_name],
|
| 447 |
+
capture_output=True,
|
| 448 |
+
check=True,
|
| 449 |
+
timeout=10,
|
| 450 |
+
)
|
| 451 |
+
except subprocess.CalledProcessError:
|
| 452 |
+
# Service may already be gone; ignore.
|
| 453 |
+
pass
|
| 454 |
+
finally:
|
| 455 |
+
self._service_name = None
|
| 456 |
+
self._service_id = None
|
| 457 |
+
self._published_port = None
|
| 458 |
+
|
| 459 |
+
def wait_for_ready(self, base_url: str, timeout_s: float = 30.0) -> None:
|
| 460 |
+
"""
|
| 461 |
+
Wait for at least one replica to become healthy by polling /health.
|
| 462 |
+
|
| 463 |
+
Note: With Swarm's load balancer, requests round-robin across replicas,
|
| 464 |
+
so this only verifies that at least one replica is responding. Some
|
| 465 |
+
replicas may still be starting when this returns.
|
| 466 |
+
"""
|
| 467 |
+
import time
|
| 468 |
+
import requests
|
| 469 |
+
|
| 470 |
+
deadline = time.time() + timeout_s
|
| 471 |
+
health_url = f"{base_url}/health"
|
| 472 |
+
|
| 473 |
+
while time.time() < deadline:
|
| 474 |
+
try:
|
| 475 |
+
response = requests.get(health_url, timeout=2.0)
|
| 476 |
+
if response.status_code == 200:
|
| 477 |
+
return
|
| 478 |
+
except requests.RequestException:
|
| 479 |
+
pass
|
| 480 |
+
|
| 481 |
+
time.sleep(0.5)
|
| 482 |
+
|
| 483 |
+
raise TimeoutError(
|
| 484 |
+
f"Swarm service at {base_url} did not become ready within {timeout_s}s"
|
| 485 |
+
)
|
| 486 |
+
|
| 487 |
+
def _ensure_docker_available(self) -> None:
|
| 488 |
+
import subprocess
|
| 489 |
+
|
| 490 |
+
try:
|
| 491 |
+
subprocess.run(
|
| 492 |
+
["docker", "version"],
|
| 493 |
+
check=True,
|
| 494 |
+
capture_output=True,
|
| 495 |
+
timeout=5,
|
| 496 |
+
)
|
| 497 |
+
except (
|
| 498 |
+
subprocess.CalledProcessError,
|
| 499 |
+
FileNotFoundError,
|
| 500 |
+
subprocess.TimeoutExpired,
|
| 501 |
+
) as exc:
|
| 502 |
+
raise RuntimeError(
|
| 503 |
+
"Docker is not available. Please install Docker Desktop or Docker Engine."
|
| 504 |
+
) from exc
|
| 505 |
+
|
| 506 |
+
def _ensure_swarm_initialized(self) -> None:
|
| 507 |
+
import subprocess
|
| 508 |
+
|
| 509 |
+
try:
|
| 510 |
+
result = subprocess.run(
|
| 511 |
+
["docker", "info", "--format", "{{.Swarm.LocalNodeState}}"],
|
| 512 |
+
capture_output=True,
|
| 513 |
+
text=True,
|
| 514 |
+
check=True,
|
| 515 |
+
timeout=5,
|
| 516 |
+
)
|
| 517 |
+
state = result.stdout.strip().lower()
|
| 518 |
+
if state == "active":
|
| 519 |
+
return
|
| 520 |
+
except subprocess.CalledProcessError:
|
| 521 |
+
state = "unknown"
|
| 522 |
+
|
| 523 |
+
if not self._auto_init_swarm:
|
| 524 |
+
raise RuntimeError(
|
| 525 |
+
f"Docker Swarm is not active (state={state}). Enable Swarm manually or pass auto_init_swarm=True."
|
| 526 |
+
)
|
| 527 |
+
|
| 528 |
+
try:
|
| 529 |
+
subprocess.run(
|
| 530 |
+
["docker", "swarm", "init"],
|
| 531 |
+
check=True,
|
| 532 |
+
capture_output=True,
|
| 533 |
+
timeout=10,
|
| 534 |
+
)
|
| 535 |
+
except subprocess.CalledProcessError as e:
|
| 536 |
+
raise RuntimeError("Failed to initialize Docker Swarm") from e
|
| 537 |
+
|
| 538 |
+
def _ensure_overlay_network(self, network: str) -> None:
|
| 539 |
+
import subprocess
|
| 540 |
+
|
| 541 |
+
inspect = subprocess.run(
|
| 542 |
+
["docker", "network", "inspect", network],
|
| 543 |
+
capture_output=True,
|
| 544 |
+
text=True,
|
| 545 |
+
check=False,
|
| 546 |
+
)
|
| 547 |
+
if inspect.returncode == 0:
|
| 548 |
+
return
|
| 549 |
+
|
| 550 |
+
try:
|
| 551 |
+
subprocess.run(
|
| 552 |
+
[
|
| 553 |
+
"docker",
|
| 554 |
+
"network",
|
| 555 |
+
"create",
|
| 556 |
+
"--driver",
|
| 557 |
+
"overlay",
|
| 558 |
+
"--attachable",
|
| 559 |
+
network,
|
| 560 |
+
],
|
| 561 |
+
check=True,
|
| 562 |
+
capture_output=True,
|
| 563 |
+
timeout=10,
|
| 564 |
+
)
|
| 565 |
+
except subprocess.CalledProcessError as e:
|
| 566 |
+
raise RuntimeError(f"Failed to create overlay network '{network}'") from e
|
| 567 |
+
|
| 568 |
+
def _find_available_port(self) -> int:
|
| 569 |
+
import socket
|
| 570 |
+
|
| 571 |
+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
| 572 |
+
s.bind(("", 0))
|
| 573 |
+
s.listen(1)
|
| 574 |
+
port = s.getsockname()[1]
|
| 575 |
+
return port
|
| 576 |
+
|
| 577 |
+
def _generate_service_name(self, image: str) -> str:
|
| 578 |
+
import time
|
| 579 |
+
|
| 580 |
+
clean_image = image.split("/")[-1].split(":")[0]
|
| 581 |
+
timestamp = int(time.time() * 1000)
|
| 582 |
+
return f"{clean_image}-swarm-{timestamp}"
|
| 583 |
+
|
| 584 |
+
|
| 585 |
class KubernetesProvider(ContainerProvider):
|
| 586 |
"""
|
| 587 |
Container provider for Kubernetes clusters.
|
|
|
|
| 595 |
>>> # Pod running in k8s, accessible via service or port-forward
|
| 596 |
>>> provider.stop_container()
|
| 597 |
"""
|
| 598 |
+
|
| 599 |
pass
|
| 600 |
+
|
| 601 |
+
|
| 602 |
+
class RuntimeProvider(ABC):
|
| 603 |
+
"""
|
| 604 |
+
Abstract base class for runtime providers that are not container providers.
|
| 605 |
+
Providers implement this interface to support different runtime platforms:
|
| 606 |
+
- UVProvider: Runs environments via `uv run`
|
| 607 |
+
|
| 608 |
+
The provider manages a single runtime lifecycle and provides the base URL
|
| 609 |
+
for connecting to it.
|
| 610 |
+
|
| 611 |
+
Example:
|
| 612 |
+
>>> provider = UVProvider(project_path="/path/to/env")
|
| 613 |
+
>>> base_url = provider.start()
|
| 614 |
+
>>> print(base_url) # http://localhost:8000
|
| 615 |
+
>>> provider.stop()
|
| 616 |
+
"""
|
| 617 |
+
|
| 618 |
+
@abstractmethod
|
| 619 |
+
def start(
|
| 620 |
+
self,
|
| 621 |
+
port: Optional[int] = None,
|
| 622 |
+
env_vars: Optional[Dict[str, str]] = None,
|
| 623 |
+
**kwargs: Any,
|
| 624 |
+
) -> str:
|
| 625 |
+
"""
|
| 626 |
+
Start a runtime from the specified image.
|
| 627 |
+
|
| 628 |
+
Args:
|
| 629 |
+
image: Runtime image name
|
| 630 |
+
port: Port to expose (if None, provider chooses)
|
| 631 |
+
env_vars: Environment variables for the runtime
|
| 632 |
+
**kwargs: Additional runtime options
|
| 633 |
+
"""
|
| 634 |
+
|
| 635 |
+
@abstractmethod
|
| 636 |
+
def stop(self) -> None:
|
| 637 |
+
"""
|
| 638 |
+
Stop the runtime.
|
| 639 |
+
"""
|
| 640 |
+
pass
|
| 641 |
+
|
| 642 |
+
@abstractmethod
|
| 643 |
+
def wait_for_ready(self, timeout_s: float = 30.0) -> None:
|
| 644 |
+
"""
|
| 645 |
+
Wait for the runtime to be ready to accept requests.
|
| 646 |
+
"""
|
| 647 |
+
pass
|
| 648 |
+
|
| 649 |
+
def __enter__(self) -> "RuntimeProvider":
|
| 650 |
+
"""
|
| 651 |
+
Enter the runtime provider.
|
| 652 |
+
"""
|
| 653 |
+
self.start()
|
| 654 |
+
return self
|
| 655 |
+
|
| 656 |
+
def __exit__(self, exc_type, exc, tb) -> None:
|
| 657 |
+
"""
|
| 658 |
+
Exit the runtime provider.
|
| 659 |
+
"""
|
| 660 |
+
self.stop()
|
| 661 |
+
return False
|
src/core/containers/runtime/uv_provider.py
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Providers for launching ASGI applications via ``uv run``."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import socket
|
| 7 |
+
import subprocess
|
| 8 |
+
import time
|
| 9 |
+
from typing import Dict, Optional
|
| 10 |
+
|
| 11 |
+
import requests
|
| 12 |
+
|
| 13 |
+
from .providers import RuntimeProvider
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def _check_uv_installed() -> None:
|
| 17 |
+
try:
|
| 18 |
+
subprocess.check_output(["uv", "--version"])
|
| 19 |
+
except FileNotFoundError as exc:
|
| 20 |
+
raise RuntimeError(
|
| 21 |
+
"`uv` executable not found. Install uv from https://docs.astral.sh and ensure it is on PATH."
|
| 22 |
+
) from exc
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _find_free_port() -> int:
|
| 26 |
+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
| 27 |
+
sock.bind(("", 0))
|
| 28 |
+
sock.listen(1)
|
| 29 |
+
return sock.getsockname()[1]
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _create_uv_command(
|
| 33 |
+
*,
|
| 34 |
+
host: str,
|
| 35 |
+
port: int,
|
| 36 |
+
reload: bool,
|
| 37 |
+
workers: int,
|
| 38 |
+
app: str,
|
| 39 |
+
project_path: str,
|
| 40 |
+
) -> list[str]:
|
| 41 |
+
command: list[str] = ["uv", "run", "--isolated", "--project", project_path]
|
| 42 |
+
|
| 43 |
+
command.append("--")
|
| 44 |
+
command.extend(
|
| 45 |
+
[
|
| 46 |
+
"uvicorn",
|
| 47 |
+
app,
|
| 48 |
+
"--host",
|
| 49 |
+
host,
|
| 50 |
+
"--port",
|
| 51 |
+
str(port),
|
| 52 |
+
"--workers",
|
| 53 |
+
str(workers),
|
| 54 |
+
]
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
if reload:
|
| 58 |
+
command.append("--reload")
|
| 59 |
+
|
| 60 |
+
return command
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def _poll_health(health_url: str, timeout_s: float) -> None:
|
| 64 |
+
"""Poll a health endpoint until it returns HTTP 200 or times out."""
|
| 65 |
+
|
| 66 |
+
deadline = time.time() + timeout_s
|
| 67 |
+
while time.time() < deadline:
|
| 68 |
+
try:
|
| 69 |
+
timeout = max(0.0001, min(deadline - time.time(), 2.0))
|
| 70 |
+
response = requests.get(health_url, timeout=timeout)
|
| 71 |
+
if response.status_code == 200:
|
| 72 |
+
return
|
| 73 |
+
except requests.RequestException:
|
| 74 |
+
continue
|
| 75 |
+
|
| 76 |
+
time.sleep(0.5)
|
| 77 |
+
|
| 78 |
+
raise TimeoutError(f"Server did not become ready within {timeout_s:.1f} seconds")
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class UVProvider(RuntimeProvider):
|
| 82 |
+
"""
|
| 83 |
+
RuntimeProvider implementation backed by ``uv run``.
|
| 84 |
+
|
| 85 |
+
Args:
|
| 86 |
+
project_path: Local path to a uv project (passed to ``uv run --project``)
|
| 87 |
+
app: ASGI application path for uvicorn (defaults to ``server.app:app``)
|
| 88 |
+
host: Host interface to bind to (defaults to ``0.0.0.0``)
|
| 89 |
+
reload: Whether to enable uvicorn's reload mode
|
| 90 |
+
env_vars: Environment variables to pass through to the spawned process
|
| 91 |
+
context_timeout_s: How long to wait for the environment to become ready
|
| 92 |
+
|
| 93 |
+
Example:
|
| 94 |
+
>>> provider = UVProvider(project_path="/path/to/env")
|
| 95 |
+
>>> base_url = provider.start()
|
| 96 |
+
>>> print(base_url) # http://localhost:8000
|
| 97 |
+
>>> # Use the environment via base_url
|
| 98 |
+
>>> provider.stop()
|
| 99 |
+
"""
|
| 100 |
+
|
| 101 |
+
def __init__(
|
| 102 |
+
self,
|
| 103 |
+
*,
|
| 104 |
+
project_path: str,
|
| 105 |
+
app: str = "server.app:app",
|
| 106 |
+
host: str = "0.0.0.0",
|
| 107 |
+
reload: bool = False,
|
| 108 |
+
env_vars: Optional[Dict[str, str]] = None,
|
| 109 |
+
context_timeout_s: float = 60.0,
|
| 110 |
+
):
|
| 111 |
+
"""Initialize the UVProvider."""
|
| 112 |
+
self.project_path = os.path.abspath(project_path)
|
| 113 |
+
self.app = app
|
| 114 |
+
self.host = host
|
| 115 |
+
self.reload = reload
|
| 116 |
+
self.env_vars = env_vars
|
| 117 |
+
self.context_timeout_s = context_timeout_s
|
| 118 |
+
_check_uv_installed()
|
| 119 |
+
self._process = None
|
| 120 |
+
self._base_url = None
|
| 121 |
+
|
| 122 |
+
def start(
|
| 123 |
+
self,
|
| 124 |
+
port: Optional[int] = None,
|
| 125 |
+
env_vars: Optional[Dict[str, str]] = None,
|
| 126 |
+
workers: int = 1,
|
| 127 |
+
**_: Dict[str, str],
|
| 128 |
+
) -> str:
|
| 129 |
+
"""
|
| 130 |
+
Start the environment via `uv run`.
|
| 131 |
+
|
| 132 |
+
Args:
|
| 133 |
+
port: The port to bind the environment to
|
| 134 |
+
env_vars: Environment variables to pass to the environment
|
| 135 |
+
workers: The number of workers to use
|
| 136 |
+
|
| 137 |
+
Returns:
|
| 138 |
+
The base URL of the environment
|
| 139 |
+
|
| 140 |
+
Raises:
|
| 141 |
+
RuntimeError: If the environment is already running
|
| 142 |
+
"""
|
| 143 |
+
if self._process is not None and self._process.poll() is None:
|
| 144 |
+
raise RuntimeError("UVProvider is already running")
|
| 145 |
+
|
| 146 |
+
bind_port = port or _find_free_port()
|
| 147 |
+
|
| 148 |
+
command = _create_uv_command(
|
| 149 |
+
host=self.host,
|
| 150 |
+
port=bind_port,
|
| 151 |
+
reload=self.reload,
|
| 152 |
+
workers=workers,
|
| 153 |
+
app=self.app,
|
| 154 |
+
project_path=self.project_path,
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
env = os.environ.copy()
|
| 158 |
+
|
| 159 |
+
if self.env_vars:
|
| 160 |
+
env.update(self.env_vars)
|
| 161 |
+
if env_vars:
|
| 162 |
+
env.update(env_vars)
|
| 163 |
+
|
| 164 |
+
try:
|
| 165 |
+
self._process = subprocess.Popen(command, env=env)
|
| 166 |
+
except OSError as exc:
|
| 167 |
+
raise RuntimeError(f"Failed to launch `uv run`: {exc}") from exc
|
| 168 |
+
|
| 169 |
+
client_host = "127.0.0.1" if self.host in {"0.0.0.0", "::"} else self.host
|
| 170 |
+
self._base_url = f"http://{client_host}:{bind_port}"
|
| 171 |
+
return self._base_url
|
| 172 |
+
|
| 173 |
+
def wait_for_ready(self, timeout_s: float = 60.0) -> None:
|
| 174 |
+
"""
|
| 175 |
+
Wait for the environment to become ready.
|
| 176 |
+
|
| 177 |
+
Args:
|
| 178 |
+
timeout_s: The timeout to wait for the environment to become ready
|
| 179 |
+
|
| 180 |
+
Raises:
|
| 181 |
+
RuntimeError: If the environment is not running
|
| 182 |
+
TimeoutError: If the environment does not become ready within the timeout
|
| 183 |
+
"""
|
| 184 |
+
if self._process and self._process.poll() is not None:
|
| 185 |
+
code = self._process.returncode
|
| 186 |
+
raise RuntimeError(f"uv process exited prematurely with code {code}")
|
| 187 |
+
|
| 188 |
+
_poll_health(f"{self._base_url}/health", timeout_s=timeout_s)
|
| 189 |
+
|
| 190 |
+
def stop(self) -> None:
|
| 191 |
+
"""
|
| 192 |
+
Stop the environment.
|
| 193 |
+
|
| 194 |
+
Raises:
|
| 195 |
+
RuntimeError: If the environment is not running
|
| 196 |
+
"""
|
| 197 |
+
if self._process is None:
|
| 198 |
+
return
|
| 199 |
+
|
| 200 |
+
if self._process.poll() is None:
|
| 201 |
+
self._process.terminate()
|
| 202 |
+
try:
|
| 203 |
+
self._process.wait(timeout=10.0)
|
| 204 |
+
except subprocess.TimeoutExpired:
|
| 205 |
+
self._process.kill()
|
| 206 |
+
self._process.wait(timeout=5.0)
|
| 207 |
+
|
| 208 |
+
self._process = None
|
| 209 |
+
self._base_url = None
|
| 210 |
+
|
| 211 |
+
@property
|
| 212 |
+
def base_url(self) -> str:
|
| 213 |
+
"""
|
| 214 |
+
The base URL of the environment.
|
| 215 |
+
|
| 216 |
+
Returns:
|
| 217 |
+
The base URL of the environment
|
| 218 |
+
|
| 219 |
+
Raises:
|
| 220 |
+
RuntimeError: If the environment is not running
|
| 221 |
+
"""
|
| 222 |
+
if self._base_url is None:
|
| 223 |
+
raise RuntimeError("UVProvider has not been started")
|
| 224 |
+
return self._base_url
|
{core → src/core}/containers/test_local_docker_provider.py
RENAMED
|
@@ -17,7 +17,8 @@ sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
|
| 17 |
|
| 18 |
import requests
|
| 19 |
|
| 20 |
-
from core.containers.runtime import LocalDockerProvider
|
|
|
|
| 21 |
|
| 22 |
# TODO: Remove this test or make it a functional test sicne this will be tested in e2e test for echo env
|
| 23 |
def test_local_docker_provider():
|
|
@@ -87,7 +88,9 @@ def test_local_docker_provider():
|
|
| 87 |
print(f" Length: {data['observation']['message_length']}")
|
| 88 |
print(f" Reward: {data['reward']}")
|
| 89 |
assert response.status_code == 200
|
| 90 |
-
assert
|
|
|
|
|
|
|
| 91 |
assert data["observation"]["message_length"] == 31
|
| 92 |
print("✓ Step test passed\n")
|
| 93 |
|
|
@@ -107,11 +110,11 @@ def test_local_docker_provider():
|
|
| 107 |
for i in range(3):
|
| 108 |
response = requests.post(
|
| 109 |
f"{base_url}/step",
|
| 110 |
-
json={"action": {"message": f"Message {i+1}"}},
|
| 111 |
headers={"Content-Type": "application/json"},
|
| 112 |
)
|
| 113 |
assert response.status_code == 200
|
| 114 |
-
print(f" Step {i+1}: ✓")
|
| 115 |
|
| 116 |
# Check state updated
|
| 117 |
response = requests.get(f"{base_url}/state")
|
|
@@ -130,6 +133,7 @@ def test_local_docker_provider():
|
|
| 130 |
except Exception as e:
|
| 131 |
print(f"\n❌ Test failed: {e}")
|
| 132 |
import traceback
|
|
|
|
| 133 |
traceback.print_exc()
|
| 134 |
return False
|
| 135 |
|
|
@@ -197,8 +201,7 @@ def test_provider_with_env_vars():
|
|
| 197 |
|
| 198 |
print("Starting container with environment variables...")
|
| 199 |
base_url = provider.start_container(
|
| 200 |
-
"echo-env:latest",
|
| 201 |
-
env_vars={"DEBUG": "true", "LOG_LEVEL": "info"}
|
| 202 |
)
|
| 203 |
print(f"✓ Started at: {base_url}")
|
| 204 |
|
|
|
|
| 17 |
|
| 18 |
import requests
|
| 19 |
|
| 20 |
+
from openenv.core.containers.runtime import LocalDockerProvider
|
| 21 |
+
|
| 22 |
|
| 23 |
# TODO: Remove this test or make it a functional test sicne this will be tested in e2e test for echo env
|
| 24 |
def test_local_docker_provider():
|
|
|
|
| 88 |
print(f" Length: {data['observation']['message_length']}")
|
| 89 |
print(f" Reward: {data['reward']}")
|
| 90 |
assert response.status_code == 200
|
| 91 |
+
assert (
|
| 92 |
+
data["observation"]["echoed_message"] == "Hello from LocalDockerProvider!"
|
| 93 |
+
)
|
| 94 |
assert data["observation"]["message_length"] == 31
|
| 95 |
print("✓ Step test passed\n")
|
| 96 |
|
|
|
|
| 110 |
for i in range(3):
|
| 111 |
response = requests.post(
|
| 112 |
f"{base_url}/step",
|
| 113 |
+
json={"action": {"message": f"Message {i + 1}"}},
|
| 114 |
headers={"Content-Type": "application/json"},
|
| 115 |
)
|
| 116 |
assert response.status_code == 200
|
| 117 |
+
print(f" Step {i + 1}: ✓")
|
| 118 |
|
| 119 |
# Check state updated
|
| 120 |
response = requests.get(f"{base_url}/state")
|
|
|
|
| 133 |
except Exception as e:
|
| 134 |
print(f"\n❌ Test failed: {e}")
|
| 135 |
import traceback
|
| 136 |
+
|
| 137 |
traceback.print_exc()
|
| 138 |
return False
|
| 139 |
|
|
|
|
| 201 |
|
| 202 |
print("Starting container with environment variables...")
|
| 203 |
base_url = provider.start_container(
|
| 204 |
+
"echo-env:latest", env_vars={"DEBUG": "true", "LOG_LEVEL": "info"}
|
|
|
|
| 205 |
)
|
| 206 |
print(f"✓ Started at: {base_url}")
|
| 207 |
|
src/core/env_client.py
ADDED
|
@@ -0,0 +1,361 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Environment client for persistent sessions.
|
| 9 |
+
|
| 10 |
+
This module provides a WebSocket-based client that maintains a persistent connection
|
| 11 |
+
to an environment server, enabling efficient multi-step interactions without
|
| 12 |
+
the overhead of HTTP request/response cycles.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import json
|
| 18 |
+
from abc import ABC, abstractmethod
|
| 19 |
+
from typing import Any, Dict, Generic, Optional, Type, TYPE_CHECKING, TypeVar
|
| 20 |
+
|
| 21 |
+
from .client_types import StepResult, StateT
|
| 22 |
+
from .containers.runtime import LocalDockerProvider, UVProvider
|
| 23 |
+
from .utils import convert_to_ws_url
|
| 24 |
+
|
| 25 |
+
if TYPE_CHECKING:
|
| 26 |
+
from .containers.runtime import ContainerProvider, RuntimeProvider
|
| 27 |
+
from websockets.sync.client import ClientConnection
|
| 28 |
+
|
| 29 |
+
from websockets.sync.client import connect as ws_connect
|
| 30 |
+
|
| 31 |
+
ActT = TypeVar("ActT")
|
| 32 |
+
ObsT = TypeVar("ObsT")
|
| 33 |
+
EnvClientT = TypeVar("EnvClientT", bound="EnvClient")
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class EnvClient(ABC, Generic[ActT, ObsT, StateT]):
|
| 37 |
+
"""
|
| 38 |
+
Environment client for persistent sessions.
|
| 39 |
+
|
| 40 |
+
This client maintains a persistent WebSocket connection to an environment
|
| 41 |
+
server, enabling efficient multi-step interactions. Each client instance
|
| 42 |
+
corresponds to a dedicated environment session on the server.
|
| 43 |
+
|
| 44 |
+
Features:
|
| 45 |
+
- Lower latency for sequential interactions
|
| 46 |
+
- Session state is maintained server-side
|
| 47 |
+
- Better suited for long-running episodes
|
| 48 |
+
|
| 49 |
+
Example:
|
| 50 |
+
>>> from envs.coding_env.client import CodingEnv
|
| 51 |
+
>>>
|
| 52 |
+
>>> # Connect to a server
|
| 53 |
+
>>> with CodingEnv(base_url="ws://localhost:8000") as env:
|
| 54 |
+
... result = env.reset(seed=42)
|
| 55 |
+
... while not result.done:
|
| 56 |
+
... action = agent.predict(result.observation)
|
| 57 |
+
... result = env.step(action)
|
| 58 |
+
"""
|
| 59 |
+
|
| 60 |
+
def __init__(
|
| 61 |
+
self,
|
| 62 |
+
base_url: str,
|
| 63 |
+
connect_timeout_s: float = 10.0,
|
| 64 |
+
message_timeout_s: float = 60.0,
|
| 65 |
+
provider: Optional["ContainerProvider | RuntimeProvider"] = None,
|
| 66 |
+
):
|
| 67 |
+
"""
|
| 68 |
+
Initialize environment client.
|
| 69 |
+
|
| 70 |
+
Args:
|
| 71 |
+
base_url: Base URL of the environment server (http:// or ws://).
|
| 72 |
+
Will be converted to ws:// if http:// is provided.
|
| 73 |
+
connect_timeout_s: Timeout for establishing WebSocket connection
|
| 74 |
+
message_timeout_s: Timeout for receiving responses to messages
|
| 75 |
+
provider: Optional container/runtime provider for lifecycle management.
|
| 76 |
+
Can be a ContainerProvider (Docker) or RuntimeProvider (UV).
|
| 77 |
+
"""
|
| 78 |
+
# Convert HTTP URL to WebSocket URL
|
| 79 |
+
ws_url = convert_to_ws_url(base_url)
|
| 80 |
+
|
| 81 |
+
self._ws_url = f"{ws_url}/ws"
|
| 82 |
+
self._connect_timeout = connect_timeout_s
|
| 83 |
+
self._message_timeout = message_timeout_s
|
| 84 |
+
self._provider = provider
|
| 85 |
+
self._ws: Optional[ClientConnection] = None
|
| 86 |
+
|
| 87 |
+
def connect(self) -> "EnvClient":
|
| 88 |
+
"""
|
| 89 |
+
Establish WebSocket connection to the server.
|
| 90 |
+
|
| 91 |
+
Returns:
|
| 92 |
+
self for method chaining
|
| 93 |
+
|
| 94 |
+
Raises:
|
| 95 |
+
ConnectionError: If connection cannot be established
|
| 96 |
+
"""
|
| 97 |
+
if self._ws is not None:
|
| 98 |
+
return self
|
| 99 |
+
|
| 100 |
+
try:
|
| 101 |
+
self._ws = ws_connect(
|
| 102 |
+
self._ws_url,
|
| 103 |
+
open_timeout=self._connect_timeout,
|
| 104 |
+
)
|
| 105 |
+
except Exception as e:
|
| 106 |
+
raise ConnectionError(f"Failed to connect to {self._ws_url}: {e}") from e
|
| 107 |
+
|
| 108 |
+
return self
|
| 109 |
+
|
| 110 |
+
def disconnect(self) -> None:
|
| 111 |
+
"""Close the WebSocket connection."""
|
| 112 |
+
if self._ws is not None:
|
| 113 |
+
try:
|
| 114 |
+
# Send close message
|
| 115 |
+
self._send({"type": "close"})
|
| 116 |
+
except Exception:
|
| 117 |
+
pass # Best effort
|
| 118 |
+
try:
|
| 119 |
+
self._ws.close()
|
| 120 |
+
except Exception:
|
| 121 |
+
pass
|
| 122 |
+
self._ws = None
|
| 123 |
+
|
| 124 |
+
def _ensure_connected(self) -> None:
|
| 125 |
+
"""Ensure WebSocket connection is established."""
|
| 126 |
+
if self._ws is None:
|
| 127 |
+
self.connect()
|
| 128 |
+
|
| 129 |
+
def _send(self, message: Dict[str, Any]) -> None:
|
| 130 |
+
"""Send a message over the WebSocket."""
|
| 131 |
+
self._ensure_connected()
|
| 132 |
+
assert self._ws is not None
|
| 133 |
+
self._ws.send(json.dumps(message))
|
| 134 |
+
|
| 135 |
+
def _receive(self) -> Dict[str, Any]:
|
| 136 |
+
"""Receive and parse a message from the WebSocket."""
|
| 137 |
+
assert self._ws is not None
|
| 138 |
+
raw = self._ws.recv(timeout=self._message_timeout)
|
| 139 |
+
return json.loads(raw)
|
| 140 |
+
|
| 141 |
+
def _send_and_receive(self, message: Dict[str, Any]) -> Dict[str, Any]:
|
| 142 |
+
"""Send a message and wait for response."""
|
| 143 |
+
self._send(message)
|
| 144 |
+
response = self._receive()
|
| 145 |
+
|
| 146 |
+
# Check for error response
|
| 147 |
+
if response.get("type") == "error":
|
| 148 |
+
error_data = response.get("data", {})
|
| 149 |
+
raise RuntimeError(
|
| 150 |
+
f"Server error: {error_data.get('message', 'Unknown error')} "
|
| 151 |
+
f"(code: {error_data.get('code', 'UNKNOWN')})"
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
return response
|
| 155 |
+
|
| 156 |
+
@classmethod
|
| 157 |
+
def from_docker_image(
|
| 158 |
+
cls: Type[EnvClientT],
|
| 159 |
+
image: str,
|
| 160 |
+
provider: Optional["ContainerProvider"] = None,
|
| 161 |
+
**kwargs: Any,
|
| 162 |
+
) -> EnvClientT:
|
| 163 |
+
"""
|
| 164 |
+
Create an environment client by spinning up a Docker container.
|
| 165 |
+
|
| 166 |
+
Args:
|
| 167 |
+
image: Docker image name to run (e.g., "coding-env:latest")
|
| 168 |
+
provider: Container provider to use (defaults to LocalDockerProvider)
|
| 169 |
+
**kwargs: Additional arguments to pass to provider.start_container()
|
| 170 |
+
|
| 171 |
+
Returns:
|
| 172 |
+
Connected client instance
|
| 173 |
+
"""
|
| 174 |
+
if provider is None:
|
| 175 |
+
provider = LocalDockerProvider()
|
| 176 |
+
|
| 177 |
+
# Start container
|
| 178 |
+
base_url = provider.start_container(image, **kwargs)
|
| 179 |
+
|
| 180 |
+
# Wait for server to be ready
|
| 181 |
+
provider.wait_for_ready(base_url)
|
| 182 |
+
|
| 183 |
+
# Create and connect client
|
| 184 |
+
client = cls(base_url=base_url, provider=provider)
|
| 185 |
+
client.connect()
|
| 186 |
+
|
| 187 |
+
return client
|
| 188 |
+
|
| 189 |
+
@classmethod
|
| 190 |
+
def from_hub(
|
| 191 |
+
cls: Type[EnvClientT],
|
| 192 |
+
repo_id: str,
|
| 193 |
+
*,
|
| 194 |
+
use_docker: bool = True,
|
| 195 |
+
provider: Optional["ContainerProvider | RuntimeProvider"] = None,
|
| 196 |
+
**provider_kwargs: Any,
|
| 197 |
+
) -> EnvClientT:
|
| 198 |
+
"""
|
| 199 |
+
Create a client from a Hugging Face Space.
|
| 200 |
+
|
| 201 |
+
Args:
|
| 202 |
+
repo_id: Hugging Face space identifier ``{org}/{space}``.
|
| 203 |
+
use_docker: When ``True`` (default) pull from the HF registry and
|
| 204 |
+
launch via :class:`LocalDockerProvider`. When ``False`` run the
|
| 205 |
+
space locally with :class:`UVProvider`.
|
| 206 |
+
provider: Optional provider instance to reuse. Must be a
|
| 207 |
+
:class:`ContainerProvider` when ``use_docker=True`` and a
|
| 208 |
+
:class:`RuntimeProvider` otherwise.
|
| 209 |
+
provider_kwargs: Additional keyword arguments forwarded to
|
| 210 |
+
either the container provider's ``start_container`` (docker)
|
| 211 |
+
or to the ``UVProvider`` constructor/start (uv). When
|
| 212 |
+
``use_docker=False``, the ``project_path`` argument can be
|
| 213 |
+
used to override the default git URL
|
| 214 |
+
(``git+https://huggingface.co/spaces/{repo_id}``).
|
| 215 |
+
|
| 216 |
+
Returns:
|
| 217 |
+
Connected client instance
|
| 218 |
+
|
| 219 |
+
Examples:
|
| 220 |
+
>>> # Pull and run from HF Docker registry
|
| 221 |
+
>>> env = MyEnv.from_hub("openenv/echo-env")
|
| 222 |
+
>>>
|
| 223 |
+
>>> # Run locally with UV (clones the space)
|
| 224 |
+
>>> env = MyEnv.from_hub("openenv/echo-env", use_docker=False)
|
| 225 |
+
>>>
|
| 226 |
+
>>> # Run from a local checkout
|
| 227 |
+
>>> env = MyEnv.from_hub(
|
| 228 |
+
... "openenv/echo-env",
|
| 229 |
+
... use_docker=False,
|
| 230 |
+
... project_path="/path/to/local/checkout"
|
| 231 |
+
... )
|
| 232 |
+
"""
|
| 233 |
+
# Extract start args that apply to both providers
|
| 234 |
+
start_args = {}
|
| 235 |
+
for key in ("port", "env_vars", "workers"):
|
| 236 |
+
if key in provider_kwargs:
|
| 237 |
+
start_args[key] = provider_kwargs.pop(key)
|
| 238 |
+
|
| 239 |
+
if use_docker:
|
| 240 |
+
# Docker mode: pull from HF registry
|
| 241 |
+
docker_provider = provider or LocalDockerProvider()
|
| 242 |
+
tag = provider_kwargs.pop("tag", "latest")
|
| 243 |
+
image = f"registry.hf.space/{repo_id.replace('/', '-')}:{tag}"
|
| 244 |
+
base_url = docker_provider.start_container(
|
| 245 |
+
image, **start_args, **provider_kwargs
|
| 246 |
+
)
|
| 247 |
+
docker_provider.wait_for_ready(base_url)
|
| 248 |
+
|
| 249 |
+
client = cls(base_url=base_url, provider=docker_provider)
|
| 250 |
+
client.connect()
|
| 251 |
+
return client
|
| 252 |
+
else:
|
| 253 |
+
# UV mode: clone and run with uv
|
| 254 |
+
if provider is None:
|
| 255 |
+
uv_kwargs = dict(provider_kwargs)
|
| 256 |
+
project_path = uv_kwargs.pop("project_path", None)
|
| 257 |
+
if project_path is None:
|
| 258 |
+
project_path = f"git+https://huggingface.co/spaces/{repo_id}"
|
| 259 |
+
|
| 260 |
+
provider = UVProvider(project_path=project_path, **uv_kwargs)
|
| 261 |
+
else:
|
| 262 |
+
if provider_kwargs:
|
| 263 |
+
raise ValueError(
|
| 264 |
+
"provider_kwargs cannot be used when supplying a provider instance"
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
base_url = provider.start(**start_args)
|
| 268 |
+
provider.wait_for_ready()
|
| 269 |
+
|
| 270 |
+
client = cls(base_url=base_url, provider=provider)
|
| 271 |
+
client.connect()
|
| 272 |
+
return client
|
| 273 |
+
|
| 274 |
+
@abstractmethod
|
| 275 |
+
def _step_payload(self, action: ActT) -> Dict[str, Any]:
|
| 276 |
+
"""Convert an Action object to the JSON data expected by the env server."""
|
| 277 |
+
raise NotImplementedError
|
| 278 |
+
|
| 279 |
+
@abstractmethod
|
| 280 |
+
def _parse_result(self, payload: Dict[str, Any]) -> StepResult[ObsT]:
|
| 281 |
+
"""Convert a JSON response from the env server to StepResult[ObsT]."""
|
| 282 |
+
raise NotImplementedError
|
| 283 |
+
|
| 284 |
+
@abstractmethod
|
| 285 |
+
def _parse_state(self, payload: Dict[str, Any]) -> StateT:
|
| 286 |
+
"""Convert a JSON response from the state endpoint to a State object."""
|
| 287 |
+
raise NotImplementedError
|
| 288 |
+
|
| 289 |
+
def reset(self, **kwargs: Any) -> StepResult[ObsT]:
|
| 290 |
+
"""
|
| 291 |
+
Reset the environment with optional parameters.
|
| 292 |
+
|
| 293 |
+
Args:
|
| 294 |
+
**kwargs: Optional parameters passed to the environment's reset method.
|
| 295 |
+
Common parameters include:
|
| 296 |
+
- seed: Random seed for reproducibility
|
| 297 |
+
- episode_id: Custom episode identifier
|
| 298 |
+
|
| 299 |
+
Returns:
|
| 300 |
+
StepResult containing initial observation
|
| 301 |
+
"""
|
| 302 |
+
message = {
|
| 303 |
+
"type": "reset",
|
| 304 |
+
"data": kwargs,
|
| 305 |
+
}
|
| 306 |
+
response = self._send_and_receive(message)
|
| 307 |
+
return self._parse_result(response.get("data", {}))
|
| 308 |
+
|
| 309 |
+
def step(self, action: ActT, **kwargs: Any) -> StepResult[ObsT]:
|
| 310 |
+
"""
|
| 311 |
+
Execute an action in the environment.
|
| 312 |
+
|
| 313 |
+
Args:
|
| 314 |
+
action: The action to execute
|
| 315 |
+
**kwargs: Optional parameters (currently ignored)
|
| 316 |
+
|
| 317 |
+
Returns:
|
| 318 |
+
StepResult containing observation, reward, and done status
|
| 319 |
+
"""
|
| 320 |
+
message = {
|
| 321 |
+
"type": "step",
|
| 322 |
+
"data": self._step_payload(action),
|
| 323 |
+
}
|
| 324 |
+
response = self._send_and_receive(message)
|
| 325 |
+
return self._parse_result(response.get("data", {}))
|
| 326 |
+
|
| 327 |
+
def state(self) -> StateT:
|
| 328 |
+
"""
|
| 329 |
+
Get the current environment state from the server.
|
| 330 |
+
|
| 331 |
+
Returns:
|
| 332 |
+
State object with environment state information
|
| 333 |
+
"""
|
| 334 |
+
message = {"type": "state"}
|
| 335 |
+
response = self._send_and_receive(message)
|
| 336 |
+
return self._parse_state(response.get("data", {}))
|
| 337 |
+
|
| 338 |
+
def close(self) -> None:
|
| 339 |
+
"""
|
| 340 |
+
Close the WebSocket connection and clean up resources.
|
| 341 |
+
|
| 342 |
+
If this client was created via from_docker_image() or from_hub(),
|
| 343 |
+
this will also stop and remove the associated container/process.
|
| 344 |
+
"""
|
| 345 |
+
self.disconnect()
|
| 346 |
+
|
| 347 |
+
if self._provider is not None:
|
| 348 |
+
# Handle both ContainerProvider and RuntimeProvider
|
| 349 |
+
if hasattr(self._provider, "stop_container"):
|
| 350 |
+
self._provider.stop_container()
|
| 351 |
+
elif hasattr(self._provider, "stop"):
|
| 352 |
+
self._provider.stop()
|
| 353 |
+
|
| 354 |
+
def __enter__(self) -> "EnvClient":
|
| 355 |
+
"""Enter context manager, ensuring connection is established."""
|
| 356 |
+
self.connect()
|
| 357 |
+
return self
|
| 358 |
+
|
| 359 |
+
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
|
| 360 |
+
"""Exit context manager, closing connection."""
|
| 361 |
+
self.close()
|
{core → src/core}/env_server/__init__.py
RENAMED
|
@@ -9,7 +9,39 @@
|
|
| 9 |
from .base_transforms import CompositeTransform, NullTransform
|
| 10 |
from .http_server import HTTPEnvServer, create_app, create_fastapi_app
|
| 11 |
from .interfaces import Environment, Message, ModelTokenizer, Transform
|
| 12 |
-
from .
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
from .web_interface import create_web_interface_app, WebInterfaceManager
|
| 14 |
|
| 15 |
__all__ = [
|
|
@@ -22,6 +54,29 @@ __all__ = [
|
|
| 22 |
"Action",
|
| 23 |
"Observation",
|
| 24 |
"State",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
# Base transforms
|
| 26 |
"CompositeTransform",
|
| 27 |
"NullTransform",
|
|
@@ -32,4 +87,10 @@ __all__ = [
|
|
| 32 |
# Web Interface
|
| 33 |
"create_web_interface_app",
|
| 34 |
"WebInterfaceManager",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
]
|
|
|
|
| 9 |
from .base_transforms import CompositeTransform, NullTransform
|
| 10 |
from .http_server import HTTPEnvServer, create_app, create_fastapi_app
|
| 11 |
from .interfaces import Environment, Message, ModelTokenizer, Transform
|
| 12 |
+
from .route_config import GetEndpointConfig
|
| 13 |
+
from .serialization import (
|
| 14 |
+
deserialize_action,
|
| 15 |
+
deserialize_action_with_preprocessing,
|
| 16 |
+
serialize_observation,
|
| 17 |
+
)
|
| 18 |
+
from .types import (
|
| 19 |
+
Action,
|
| 20 |
+
Observation,
|
| 21 |
+
State,
|
| 22 |
+
SchemaResponse,
|
| 23 |
+
HealthResponse,
|
| 24 |
+
BaseMessage,
|
| 25 |
+
WSIncomingMessage,
|
| 26 |
+
WSResetMessage,
|
| 27 |
+
WSStepMessage,
|
| 28 |
+
WSStateMessage,
|
| 29 |
+
WSCloseMessage,
|
| 30 |
+
WSObservationResponse,
|
| 31 |
+
WSStateResponse,
|
| 32 |
+
WSErrorResponse,
|
| 33 |
+
ConcurrencyConfig,
|
| 34 |
+
ServerCapacityStatus,
|
| 35 |
+
SessionInfo,
|
| 36 |
+
)
|
| 37 |
+
from .exceptions import (
|
| 38 |
+
OpenEnvError,
|
| 39 |
+
ConcurrencyConfigurationError,
|
| 40 |
+
SessionCapacityError,
|
| 41 |
+
SessionNotFoundError,
|
| 42 |
+
SessionCreationError,
|
| 43 |
+
EnvironmentFactoryError,
|
| 44 |
+
)
|
| 45 |
from .web_interface import create_web_interface_app, WebInterfaceManager
|
| 46 |
|
| 47 |
__all__ = [
|
|
|
|
| 54 |
"Action",
|
| 55 |
"Observation",
|
| 56 |
"State",
|
| 57 |
+
"SchemaResponse",
|
| 58 |
+
"HealthResponse",
|
| 59 |
+
# WebSocket message types
|
| 60 |
+
"BaseMessage",
|
| 61 |
+
"WSIncomingMessage",
|
| 62 |
+
"WSResetMessage",
|
| 63 |
+
"WSStepMessage",
|
| 64 |
+
"WSStateMessage",
|
| 65 |
+
"WSCloseMessage",
|
| 66 |
+
"WSObservationResponse",
|
| 67 |
+
"WSStateResponse",
|
| 68 |
+
"WSErrorResponse",
|
| 69 |
+
# Concurrency types
|
| 70 |
+
"ConcurrencyConfig",
|
| 71 |
+
"ServerCapacityStatus",
|
| 72 |
+
"SessionInfo",
|
| 73 |
+
# Exceptions
|
| 74 |
+
"OpenEnvError",
|
| 75 |
+
"ConcurrencyConfigurationError",
|
| 76 |
+
"SessionCapacityError",
|
| 77 |
+
"SessionNotFoundError",
|
| 78 |
+
"SessionCreationError",
|
| 79 |
+
"EnvironmentFactoryError",
|
| 80 |
# Base transforms
|
| 81 |
"CompositeTransform",
|
| 82 |
"NullTransform",
|
|
|
|
| 87 |
# Web Interface
|
| 88 |
"create_web_interface_app",
|
| 89 |
"WebInterfaceManager",
|
| 90 |
+
# Serialization utilities
|
| 91 |
+
"deserialize_action",
|
| 92 |
+
"deserialize_action_with_preprocessing",
|
| 93 |
+
"serialize_observation",
|
| 94 |
+
# Route configuration
|
| 95 |
+
"GetEndpointConfig",
|
| 96 |
]
|
{core → src/core}/env_server/base_transforms.py
RENAMED
|
@@ -26,4 +26,4 @@ class NullTransform(Transform):
|
|
| 26 |
"""Default transform that passes through unchanged."""
|
| 27 |
|
| 28 |
def __call__(self, observation: Observation) -> Observation:
|
| 29 |
-
return observation
|
|
|
|
| 26 |
"""Default transform that passes through unchanged."""
|
| 27 |
|
| 28 |
def __call__(self, observation: Observation) -> Observation:
|
| 29 |
+
return observation
|
src/core/env_server/exceptions.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Custom exceptions for environment server operations."""
|
| 8 |
+
|
| 9 |
+
from typing import Optional
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class OpenEnvError(Exception):
|
| 13 |
+
"""Base exception for all OpenEnv errors."""
|
| 14 |
+
|
| 15 |
+
pass
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class ConcurrencyConfigurationError(OpenEnvError):
|
| 19 |
+
"""
|
| 20 |
+
Raised when an environment is misconfigured for concurrent sessions.
|
| 21 |
+
|
| 22 |
+
This error is raised during server startup when max_concurrent_envs > 1
|
| 23 |
+
is specified for an environment that is not marked as SUPPORTS_CONCURRENT_SESSIONS.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
def __init__(
|
| 27 |
+
self,
|
| 28 |
+
environment_name: str,
|
| 29 |
+
max_concurrent_envs: int,
|
| 30 |
+
message: Optional[str] = None,
|
| 31 |
+
):
|
| 32 |
+
self.environment_name = environment_name
|
| 33 |
+
self.max_concurrent_envs = max_concurrent_envs
|
| 34 |
+
|
| 35 |
+
if message is None:
|
| 36 |
+
message = (
|
| 37 |
+
f"Environment '{environment_name}' is not marked as SUPPORTS_CONCURRENT_SESSIONS. "
|
| 38 |
+
f"Cannot run with max_concurrent_envs={max_concurrent_envs}. "
|
| 39 |
+
f"Either set max_concurrent_envs=1 or ensure the environment "
|
| 40 |
+
f"properly isolates session state and set SUPPORTS_CONCURRENT_SESSIONS=True."
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
super().__init__(message)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class SessionCapacityError(OpenEnvError):
|
| 47 |
+
"""
|
| 48 |
+
Raised when the server cannot accept new sessions due to capacity limits.
|
| 49 |
+
|
| 50 |
+
This error is raised when a new WebSocket connection is attempted but
|
| 51 |
+
the server has already reached max_concurrent_envs active sessions.
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
def __init__(
|
| 55 |
+
self,
|
| 56 |
+
active_sessions: int,
|
| 57 |
+
max_sessions: int,
|
| 58 |
+
message: Optional[str] = None,
|
| 59 |
+
):
|
| 60 |
+
self.active_sessions = active_sessions
|
| 61 |
+
self.max_sessions = max_sessions
|
| 62 |
+
|
| 63 |
+
if message is None:
|
| 64 |
+
message = (
|
| 65 |
+
f"Server at capacity: {active_sessions}/{max_sessions} sessions active. "
|
| 66 |
+
f"Cannot accept new connections."
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
super().__init__(message)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class SessionNotFoundError(OpenEnvError):
|
| 73 |
+
"""Raised when attempting to access a session that does not exist."""
|
| 74 |
+
|
| 75 |
+
def __init__(self, session_id: str, message: Optional[str] = None):
|
| 76 |
+
self.session_id = session_id
|
| 77 |
+
|
| 78 |
+
if message is None:
|
| 79 |
+
message = f"Session '{session_id}' not found."
|
| 80 |
+
|
| 81 |
+
super().__init__(message)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class SessionCreationError(OpenEnvError):
|
| 85 |
+
"""Raised when a session cannot be created."""
|
| 86 |
+
|
| 87 |
+
def __init__(self, reason: str, message: Optional[str] = None):
|
| 88 |
+
self.reason = reason
|
| 89 |
+
|
| 90 |
+
if message is None:
|
| 91 |
+
message = f"Failed to create session: {reason}"
|
| 92 |
+
|
| 93 |
+
super().__init__(message)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class EnvironmentFactoryError(OpenEnvError):
|
| 97 |
+
"""Raised when the environment factory fails to create an instance."""
|
| 98 |
+
|
| 99 |
+
def __init__(self, factory_name: str, message: Optional[str] = None):
|
| 100 |
+
self.factory_name = factory_name
|
| 101 |
+
|
| 102 |
+
if message is None:
|
| 103 |
+
message = f"Environment factory '{factory_name}' failed to create instance."
|
| 104 |
+
|
| 105 |
+
super().__init__(message)
|
src/core/env_server/http_server.py
ADDED
|
@@ -0,0 +1,935 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
HTTP server wrapper for Environment instances.
|
| 9 |
+
|
| 10 |
+
This module provides utilities to wrap any Environment subclass and expose it
|
| 11 |
+
over HTTP and WebSocket endpoints that EnvClient can consume.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import asyncio
|
| 17 |
+
import inspect
|
| 18 |
+
import json
|
| 19 |
+
import os
|
| 20 |
+
import time
|
| 21 |
+
import uuid
|
| 22 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 23 |
+
from typing import Any, Callable, Dict, Optional, Type, Union
|
| 24 |
+
|
| 25 |
+
from fastapi import Body, FastAPI, HTTPException, WebSocket, WebSocketDisconnect, status
|
| 26 |
+
from pydantic import ValidationError
|
| 27 |
+
|
| 28 |
+
from .interfaces import Environment
|
| 29 |
+
from .route_config import (
|
| 30 |
+
GetEndpointConfig,
|
| 31 |
+
register_get_endpoints,
|
| 32 |
+
)
|
| 33 |
+
from .serialization import deserialize_action, serialize_observation
|
| 34 |
+
from .types import (
|
| 35 |
+
Action,
|
| 36 |
+
Observation,
|
| 37 |
+
ResetRequest,
|
| 38 |
+
ResetResponse,
|
| 39 |
+
State,
|
| 40 |
+
StepRequest,
|
| 41 |
+
StepResponse,
|
| 42 |
+
EnvironmentMetadata,
|
| 43 |
+
SchemaResponse,
|
| 44 |
+
HealthResponse,
|
| 45 |
+
WSResetMessage,
|
| 46 |
+
WSStepMessage,
|
| 47 |
+
WSStateMessage,
|
| 48 |
+
WSCloseMessage,
|
| 49 |
+
WSObservationResponse,
|
| 50 |
+
WSStateResponse,
|
| 51 |
+
WSErrorResponse,
|
| 52 |
+
ConcurrencyConfig,
|
| 53 |
+
ServerCapacityStatus,
|
| 54 |
+
SessionInfo,
|
| 55 |
+
)
|
| 56 |
+
from .exceptions import (
|
| 57 |
+
ConcurrencyConfigurationError,
|
| 58 |
+
SessionCapacityError,
|
| 59 |
+
EnvironmentFactoryError,
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class HTTPEnvServer:
|
| 64 |
+
"""
|
| 65 |
+
HTTP server wrapper for Environment instances.
|
| 66 |
+
|
| 67 |
+
This class wraps an Environment and exposes its reset(), step(), and state
|
| 68 |
+
methods as HTTP and WebSocket endpoints compatible with EnvClient.
|
| 69 |
+
|
| 70 |
+
The server expects:
|
| 71 |
+
- Action deserialization: Converts JSON dict to Action subclass
|
| 72 |
+
- Observation serialization: Converts Observation subclass to JSON dict
|
| 73 |
+
|
| 74 |
+
Example:
|
| 75 |
+
>>> from core.env_server import HTTPEnvServer
|
| 76 |
+
>>> from envs.coding_env.server import CodeExecutionEnvironment
|
| 77 |
+
>>> from envs.coding_env.models import CodeAction, CodeObservation
|
| 78 |
+
>>>
|
| 79 |
+
>>> # Pass environment class (factory pattern)
|
| 80 |
+
>>> server = HTTPEnvServer(
|
| 81 |
+
... env=CodeExecutionEnvironment,
|
| 82 |
+
... action_cls=CodeAction,
|
| 83 |
+
... observation_cls=CodeObservation,
|
| 84 |
+
... max_concurrent_envs=4,
|
| 85 |
+
... )
|
| 86 |
+
>>>
|
| 87 |
+
>>> # Register routes with FastAPI
|
| 88 |
+
>>> from fastapi import FastAPI
|
| 89 |
+
>>> app = FastAPI()
|
| 90 |
+
>>> server.register_routes(app)
|
| 91 |
+
"""
|
| 92 |
+
|
| 93 |
+
def __init__(
|
| 94 |
+
self,
|
| 95 |
+
env: Callable[[], Environment],
|
| 96 |
+
action_cls: Type[Action],
|
| 97 |
+
observation_cls: Type[Observation],
|
| 98 |
+
max_concurrent_envs: Optional[int] = None,
|
| 99 |
+
concurrency_config: Optional[ConcurrencyConfig] = None,
|
| 100 |
+
):
|
| 101 |
+
"""
|
| 102 |
+
Initialize HTTP server wrapper.
|
| 103 |
+
|
| 104 |
+
Args:
|
| 105 |
+
env: Environment factory (callable) that creates new instances.
|
| 106 |
+
Will be called to create a new environment for each WebSocket session.
|
| 107 |
+
action_cls: The Action subclass this environment expects
|
| 108 |
+
observation_cls: The Observation subclass this environment returns
|
| 109 |
+
max_concurrent_envs: Maximum number of concurrent WebSocket sessions.
|
| 110 |
+
Mutually exclusive with concurrency_config.
|
| 111 |
+
concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings.
|
| 112 |
+
Mutually exclusive with max_concurrent_envs.
|
| 113 |
+
|
| 114 |
+
Raises:
|
| 115 |
+
ValueError: If both max_concurrent_envs and concurrency_config are provided.
|
| 116 |
+
ConcurrencyConfigurationError: If max_concurrent_envs > 1 for an
|
| 117 |
+
environment that is not marked as SUPPORTS_CONCURRENT_SESSIONS.
|
| 118 |
+
"""
|
| 119 |
+
# Validate that env is callable
|
| 120 |
+
if not callable(env):
|
| 121 |
+
raise TypeError(
|
| 122 |
+
f"env must be a callable (class or factory function), got {type(env)}. "
|
| 123 |
+
f"Pass the environment class (e.g., MyEnvironment) not an instance (e.g., MyEnvironment())."
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
self._env_factory: Callable[[], Environment] = env
|
| 127 |
+
|
| 128 |
+
# Handle concurrency configuration
|
| 129 |
+
if max_concurrent_envs is not None and concurrency_config is not None:
|
| 130 |
+
raise ValueError(
|
| 131 |
+
"Cannot specify both 'max_concurrent_envs' and 'concurrency_config'. "
|
| 132 |
+
"Please use only one method to configure concurrency."
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
if concurrency_config is not None:
|
| 136 |
+
self._concurrency_config = concurrency_config
|
| 137 |
+
elif max_concurrent_envs is not None:
|
| 138 |
+
self._concurrency_config = ConcurrencyConfig(
|
| 139 |
+
max_concurrent_envs=max_concurrent_envs,
|
| 140 |
+
session_timeout=None,
|
| 141 |
+
)
|
| 142 |
+
else:
|
| 143 |
+
# Default configuration
|
| 144 |
+
self._concurrency_config = ConcurrencyConfig(
|
| 145 |
+
max_concurrent_envs=1,
|
| 146 |
+
session_timeout=None,
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
self._max_concurrent_envs = self._concurrency_config.max_concurrent_envs
|
| 150 |
+
|
| 151 |
+
# Validate concurrency configuration
|
| 152 |
+
self._validate_concurrency_safety()
|
| 153 |
+
|
| 154 |
+
self.action_cls = action_cls
|
| 155 |
+
self.observation_cls = observation_cls
|
| 156 |
+
|
| 157 |
+
# Session management for WebSocket connections
|
| 158 |
+
self._sessions: Dict[str, Environment] = {}
|
| 159 |
+
self._session_executors: Dict[str, ThreadPoolExecutor] = {}
|
| 160 |
+
self._session_info: Dict[str, SessionInfo] = {}
|
| 161 |
+
self._session_lock = asyncio.Lock()
|
| 162 |
+
|
| 163 |
+
# Create thread pool for running sync code in async context
|
| 164 |
+
# This is needed for environments using sync libraries (e.g., Playwright)
|
| 165 |
+
self._executor = ThreadPoolExecutor(max_workers=32)
|
| 166 |
+
|
| 167 |
+
def _validate_concurrency_safety(self) -> None:
|
| 168 |
+
"""
|
| 169 |
+
Validate that the environment supports the configured concurrency level.
|
| 170 |
+
|
| 171 |
+
Raises:
|
| 172 |
+
ConcurrencyConfigurationError: If max_concurrent_envs > 1 for an
|
| 173 |
+
environment that is not marked as SUPPORTS_CONCURRENT_SESSIONS.
|
| 174 |
+
"""
|
| 175 |
+
if self._max_concurrent_envs <= 1:
|
| 176 |
+
return
|
| 177 |
+
|
| 178 |
+
if inspect.isclass(self._env_factory):
|
| 179 |
+
env_cls = self._env_factory
|
| 180 |
+
else:
|
| 181 |
+
_temp_env = self._env_factory()
|
| 182 |
+
env_cls = type(_temp_env)
|
| 183 |
+
_temp_env.close()
|
| 184 |
+
del _temp_env
|
| 185 |
+
|
| 186 |
+
if not getattr(env_cls, "SUPPORTS_CONCURRENT_SESSIONS", False):
|
| 187 |
+
raise ConcurrencyConfigurationError(
|
| 188 |
+
environment_name=env_cls.__name__,
|
| 189 |
+
max_concurrent_envs=self._max_concurrent_envs,
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
def get_capacity_status(self) -> ServerCapacityStatus:
|
| 193 |
+
"""
|
| 194 |
+
Get the current capacity status of the server.
|
| 195 |
+
|
| 196 |
+
Returns:
|
| 197 |
+
ServerCapacityStatus with current session counts and availability.
|
| 198 |
+
"""
|
| 199 |
+
return ServerCapacityStatus.from_counts(
|
| 200 |
+
active=len(self._sessions),
|
| 201 |
+
max_sessions=self._max_concurrent_envs,
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
async def _run_sync_in_thread_pool(
|
| 205 |
+
self, func: Callable[..., Observation], *args, **kwargs
|
| 206 |
+
) -> Observation:
|
| 207 |
+
"""Run a synchronous function in the thread pool executor."""
|
| 208 |
+
loop = asyncio.get_event_loop()
|
| 209 |
+
return await loop.run_in_executor(self._executor, lambda: func(*args, **kwargs))
|
| 210 |
+
|
| 211 |
+
def _get_valid_kwargs(
|
| 212 |
+
self,
|
| 213 |
+
sig: inspect.Signature,
|
| 214 |
+
kwargs: Dict[str, Any],
|
| 215 |
+
skip_params: Optional[set[str]] = None,
|
| 216 |
+
) -> Dict[str, Any]:
|
| 217 |
+
"""Filter kwargs to only include parameters accepted by the function signature."""
|
| 218 |
+
if skip_params is None:
|
| 219 |
+
skip_params = set()
|
| 220 |
+
|
| 221 |
+
valid_kwargs = {}
|
| 222 |
+
|
| 223 |
+
has_kwargs = any(
|
| 224 |
+
p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
for k, v in kwargs.items():
|
| 228 |
+
if k in sig.parameters or has_kwargs:
|
| 229 |
+
if k not in skip_params:
|
| 230 |
+
valid_kwargs[k] = v
|
| 231 |
+
|
| 232 |
+
return valid_kwargs
|
| 233 |
+
|
| 234 |
+
async def _create_session(self) -> tuple[str, Environment]:
|
| 235 |
+
"""
|
| 236 |
+
Create a new WebSocket session with its own environment instance.
|
| 237 |
+
|
| 238 |
+
Returns:
|
| 239 |
+
Tuple of (session_id, environment)
|
| 240 |
+
|
| 241 |
+
Raises:
|
| 242 |
+
SessionCapacityError: If max concurrent sessions reached
|
| 243 |
+
EnvironmentFactoryError: If the factory fails to create an environment
|
| 244 |
+
"""
|
| 245 |
+
async with self._session_lock:
|
| 246 |
+
if len(self._sessions) >= self._max_concurrent_envs:
|
| 247 |
+
raise SessionCapacityError(
|
| 248 |
+
active_sessions=len(self._sessions),
|
| 249 |
+
max_sessions=self._max_concurrent_envs,
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
session_id = str(uuid.uuid4())
|
| 253 |
+
current_time = time.time()
|
| 254 |
+
|
| 255 |
+
try:
|
| 256 |
+
env = self._env_factory()
|
| 257 |
+
except Exception as e:
|
| 258 |
+
factory_name = getattr(
|
| 259 |
+
self._env_factory, "__name__", str(self._env_factory)
|
| 260 |
+
)
|
| 261 |
+
raise EnvironmentFactoryError(factory_name) from e
|
| 262 |
+
|
| 263 |
+
self._sessions[session_id] = env
|
| 264 |
+
|
| 265 |
+
self._session_executors[session_id] = ThreadPoolExecutor(max_workers=1)
|
| 266 |
+
|
| 267 |
+
# Track session metadata
|
| 268 |
+
self._session_info[session_id] = SessionInfo(
|
| 269 |
+
session_id=session_id,
|
| 270 |
+
created_at=current_time,
|
| 271 |
+
last_activity_at=current_time,
|
| 272 |
+
step_count=0,
|
| 273 |
+
environment_type=type(env).__name__,
|
| 274 |
+
)
|
| 275 |
+
|
| 276 |
+
return session_id, env
|
| 277 |
+
|
| 278 |
+
async def _destroy_session(self, session_id: str) -> None:
|
| 279 |
+
"""
|
| 280 |
+
Destroy a WebSocket session and cleanup resources.
|
| 281 |
+
|
| 282 |
+
Args:
|
| 283 |
+
session_id: The session ID to destroy
|
| 284 |
+
"""
|
| 285 |
+
async with self._session_lock:
|
| 286 |
+
if session_id in self._sessions:
|
| 287 |
+
env = self._sessions.pop(session_id)
|
| 288 |
+
env.close()
|
| 289 |
+
|
| 290 |
+
if session_id in self._session_executors:
|
| 291 |
+
executor = self._session_executors.pop(session_id)
|
| 292 |
+
executor.shutdown(wait=False)
|
| 293 |
+
|
| 294 |
+
self._session_info.pop(session_id, None)
|
| 295 |
+
|
| 296 |
+
def _update_session_activity(
|
| 297 |
+
self, session_id: str, increment_step: bool = False
|
| 298 |
+
) -> None:
|
| 299 |
+
"""
|
| 300 |
+
Update session activity timestamp and optionally increment step count.
|
| 301 |
+
|
| 302 |
+
Args:
|
| 303 |
+
session_id: The session ID to update
|
| 304 |
+
increment_step: If True, increment the step count
|
| 305 |
+
"""
|
| 306 |
+
if session_id in self._session_info:
|
| 307 |
+
self._session_info[session_id].last_activity_at = time.time()
|
| 308 |
+
if increment_step:
|
| 309 |
+
self._session_info[session_id].step_count += 1
|
| 310 |
+
|
| 311 |
+
def get_session_info(self, session_id: str) -> Optional[SessionInfo]:
|
| 312 |
+
"""
|
| 313 |
+
Get information about a specific session.
|
| 314 |
+
|
| 315 |
+
Args:
|
| 316 |
+
session_id: The session ID to query
|
| 317 |
+
|
| 318 |
+
Returns:
|
| 319 |
+
SessionInfo if the session exists, None otherwise
|
| 320 |
+
"""
|
| 321 |
+
return self._session_info.get(session_id)
|
| 322 |
+
|
| 323 |
+
async def _run_in_session_executor(
|
| 324 |
+
self, session_id: str, func: Callable[..., Observation], *args, **kwargs
|
| 325 |
+
) -> Observation:
|
| 326 |
+
"""Run a synchronous function in the session's thread pool executor."""
|
| 327 |
+
executor = self._session_executors.get(session_id, self._executor)
|
| 328 |
+
loop = asyncio.get_event_loop()
|
| 329 |
+
return await loop.run_in_executor(executor, lambda: func(*args, **kwargs))
|
| 330 |
+
|
| 331 |
+
@property
|
| 332 |
+
def active_sessions(self) -> int:
|
| 333 |
+
"""Return the number of active WebSocket sessions."""
|
| 334 |
+
return len(self._sessions)
|
| 335 |
+
|
| 336 |
+
@property
|
| 337 |
+
def max_concurrent_envs(self) -> int:
|
| 338 |
+
"""Return the maximum number of concurrent environments."""
|
| 339 |
+
return self._max_concurrent_envs
|
| 340 |
+
|
| 341 |
+
@property
|
| 342 |
+
def is_concurrency_safe(self) -> bool:
|
| 343 |
+
"""Return whether the environment is marked as concurrency safe."""
|
| 344 |
+
import inspect
|
| 345 |
+
|
| 346 |
+
if inspect.isclass(self._env_factory):
|
| 347 |
+
return getattr(self._env_factory, "SUPPORTS_CONCURRENT_SESSIONS", False)
|
| 348 |
+
else:
|
| 349 |
+
_temp_env = self._env_factory()
|
| 350 |
+
result = getattr(_temp_env, "SUPPORTS_CONCURRENT_SESSIONS", False)
|
| 351 |
+
_temp_env.close()
|
| 352 |
+
del _temp_env
|
| 353 |
+
return result
|
| 354 |
+
|
| 355 |
+
@property
|
| 356 |
+
def concurrency_config(self) -> ConcurrencyConfig:
|
| 357 |
+
"""Return the concurrency configuration."""
|
| 358 |
+
return self._concurrency_config
|
| 359 |
+
|
| 360 |
+
def register_routes(self, app: FastAPI) -> None:
|
| 361 |
+
"""
|
| 362 |
+
Register HTTP routes on a FastAPI application.
|
| 363 |
+
|
| 364 |
+
Args:
|
| 365 |
+
app: FastAPI application instance
|
| 366 |
+
"""
|
| 367 |
+
|
| 368 |
+
# Helper function to handle reset endpoint
|
| 369 |
+
async def reset_handler(
|
| 370 |
+
request: ResetRequest = Body(default_factory=ResetRequest),
|
| 371 |
+
) -> ResetResponse:
|
| 372 |
+
"""Reset endpoint - returns initial observation."""
|
| 373 |
+
_env = self._env_factory()
|
| 374 |
+
|
| 375 |
+
try:
|
| 376 |
+
kwargs = request.model_dump(exclude_unset=True)
|
| 377 |
+
|
| 378 |
+
is_async = _env.reset_async.__func__ is not Environment.reset_async
|
| 379 |
+
|
| 380 |
+
if is_async:
|
| 381 |
+
sig = inspect.signature(_env.reset_async)
|
| 382 |
+
else:
|
| 383 |
+
sig = inspect.signature(_env.reset)
|
| 384 |
+
valid_kwargs = self._get_valid_kwargs(sig, kwargs)
|
| 385 |
+
|
| 386 |
+
if is_async:
|
| 387 |
+
observation = await _env.reset_async(**valid_kwargs)
|
| 388 |
+
else:
|
| 389 |
+
observation = await self._run_sync_in_thread_pool(
|
| 390 |
+
_env.reset, **valid_kwargs
|
| 391 |
+
)
|
| 392 |
+
return ResetResponse(**serialize_observation(observation))
|
| 393 |
+
finally:
|
| 394 |
+
_env.close()
|
| 395 |
+
|
| 396 |
+
# Helper function to handle step endpoint
|
| 397 |
+
async def step_handler(request: StepRequest) -> StepResponse:
|
| 398 |
+
"""Step endpoint - executes action and returns observation."""
|
| 399 |
+
action_data = request.action
|
| 400 |
+
|
| 401 |
+
try:
|
| 402 |
+
action = deserialize_action(action_data, self.action_cls)
|
| 403 |
+
except ValidationError as e:
|
| 404 |
+
raise HTTPException(
|
| 405 |
+
status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, detail=e.errors()
|
| 406 |
+
)
|
| 407 |
+
|
| 408 |
+
_env = self._env_factory()
|
| 409 |
+
|
| 410 |
+
try:
|
| 411 |
+
kwargs = request.model_dump(exclude_unset=True, exclude={"action"})
|
| 412 |
+
|
| 413 |
+
is_async = _env.step_async.__func__ is not Environment.step_async
|
| 414 |
+
|
| 415 |
+
if is_async:
|
| 416 |
+
sig = inspect.signature(_env.step_async)
|
| 417 |
+
else:
|
| 418 |
+
sig = inspect.signature(_env.step)
|
| 419 |
+
valid_kwargs = self._get_valid_kwargs(
|
| 420 |
+
sig, kwargs, skip_params={"action"}
|
| 421 |
+
)
|
| 422 |
+
|
| 423 |
+
if is_async:
|
| 424 |
+
observation = await _env.step_async(action, **valid_kwargs)
|
| 425 |
+
else:
|
| 426 |
+
observation = await self._run_sync_in_thread_pool(
|
| 427 |
+
_env.step, action, **valid_kwargs
|
| 428 |
+
)
|
| 429 |
+
|
| 430 |
+
return StepResponse(**serialize_observation(observation))
|
| 431 |
+
finally:
|
| 432 |
+
_env.close()
|
| 433 |
+
|
| 434 |
+
# Register routes using the helpers
|
| 435 |
+
@app.post(
|
| 436 |
+
"/reset",
|
| 437 |
+
response_model=ResetResponse,
|
| 438 |
+
tags=["Environment Control"],
|
| 439 |
+
summary="Reset the environment",
|
| 440 |
+
description="""
|
| 441 |
+
Reset the environment to its initial state and return the first observation.
|
| 442 |
+
|
| 443 |
+
You can optionally provide a seed for reproducibility and an episode_id for tracking.
|
| 444 |
+
""",
|
| 445 |
+
responses={
|
| 446 |
+
200: {
|
| 447 |
+
"description": "Environment reset successfully",
|
| 448 |
+
"content": {
|
| 449 |
+
"application/json": {
|
| 450 |
+
"example": {
|
| 451 |
+
"observation": {"status": "ready", "data": {}},
|
| 452 |
+
"reward": None,
|
| 453 |
+
"done": False,
|
| 454 |
+
}
|
| 455 |
+
}
|
| 456 |
+
},
|
| 457 |
+
}
|
| 458 |
+
},
|
| 459 |
+
)
|
| 460 |
+
async def reset(
|
| 461 |
+
request: ResetRequest = Body(default_factory=ResetRequest),
|
| 462 |
+
) -> ResetResponse:
|
| 463 |
+
return await reset_handler(request)
|
| 464 |
+
|
| 465 |
+
@app.post(
|
| 466 |
+
"/step",
|
| 467 |
+
response_model=StepResponse,
|
| 468 |
+
tags=["Environment Control"],
|
| 469 |
+
summary="Execute an action in the environment",
|
| 470 |
+
description="""
|
| 471 |
+
Execute an action in the environment and receive the resulting observation.
|
| 472 |
+
|
| 473 |
+
The action must conform to the environment's action schema, which can be
|
| 474 |
+
retrieved from the `/schema` endpoint. If the action is invalid,
|
| 475 |
+
the endpoint will return HTTP 422 with detailed validation errors.
|
| 476 |
+
|
| 477 |
+
The response includes:
|
| 478 |
+
- **observation**: The environment's response to the action
|
| 479 |
+
- **reward**: Optional reward signal (float or None)
|
| 480 |
+
- **done**: Boolean indicating if the episode has terminated
|
| 481 |
+
""",
|
| 482 |
+
responses={
|
| 483 |
+
200: {
|
| 484 |
+
"description": "Action executed successfully",
|
| 485 |
+
"content": {
|
| 486 |
+
"application/json": {
|
| 487 |
+
"example": {
|
| 488 |
+
"observation": {"status": "success", "data": {}},
|
| 489 |
+
"reward": 1.0,
|
| 490 |
+
"done": False,
|
| 491 |
+
}
|
| 492 |
+
}
|
| 493 |
+
},
|
| 494 |
+
},
|
| 495 |
+
422: {
|
| 496 |
+
"description": "Validation error - invalid action format or values",
|
| 497 |
+
"content": {
|
| 498 |
+
"application/json": {
|
| 499 |
+
"example": {
|
| 500 |
+
"detail": [
|
| 501 |
+
{
|
| 502 |
+
"type": "string_too_short",
|
| 503 |
+
"loc": ["body", "action", "message"],
|
| 504 |
+
"msg": "String should have at least 1 character",
|
| 505 |
+
"input": "",
|
| 506 |
+
}
|
| 507 |
+
]
|
| 508 |
+
}
|
| 509 |
+
}
|
| 510 |
+
},
|
| 511 |
+
},
|
| 512 |
+
500: {"description": "Internal server error during action execution"},
|
| 513 |
+
},
|
| 514 |
+
)
|
| 515 |
+
async def step(request: StepRequest) -> StepResponse:
|
| 516 |
+
return await step_handler(request)
|
| 517 |
+
|
| 518 |
+
def get_state_handler() -> State:
|
| 519 |
+
_env = self._env_factory()
|
| 520 |
+
try:
|
| 521 |
+
return _env.state
|
| 522 |
+
finally:
|
| 523 |
+
_env.close()
|
| 524 |
+
|
| 525 |
+
def get_metadata_handler() -> EnvironmentMetadata:
|
| 526 |
+
_env = self._env_factory()
|
| 527 |
+
try:
|
| 528 |
+
return _env.get_metadata()
|
| 529 |
+
finally:
|
| 530 |
+
_env.close()
|
| 531 |
+
|
| 532 |
+
get_endpoints = [
|
| 533 |
+
GetEndpointConfig(
|
| 534 |
+
path="/state",
|
| 535 |
+
handler=get_state_handler,
|
| 536 |
+
response_model=State,
|
| 537 |
+
tag="State Management",
|
| 538 |
+
summary="Get current environment state",
|
| 539 |
+
description="""
|
| 540 |
+
Retrieve the current internal state of the environment.
|
| 541 |
+
|
| 542 |
+
The structure of the state object is defined by the environment's State model.
|
| 543 |
+
""",
|
| 544 |
+
),
|
| 545 |
+
GetEndpointConfig(
|
| 546 |
+
path="/metadata",
|
| 547 |
+
handler=get_metadata_handler,
|
| 548 |
+
response_model=EnvironmentMetadata,
|
| 549 |
+
tag="Environment Info",
|
| 550 |
+
summary="Get environment metadata",
|
| 551 |
+
description="""
|
| 552 |
+
Get metadata about this environment.
|
| 553 |
+
|
| 554 |
+
Returns information about the environment including name, description,
|
| 555 |
+
version, author, and documentation links.
|
| 556 |
+
""",
|
| 557 |
+
),
|
| 558 |
+
GetEndpointConfig(
|
| 559 |
+
path="/health",
|
| 560 |
+
handler=lambda: HealthResponse(status="healthy"),
|
| 561 |
+
response_model=HealthResponse,
|
| 562 |
+
tag="Health",
|
| 563 |
+
summary="Health check",
|
| 564 |
+
description="Check if the environment server is running and healthy.",
|
| 565 |
+
),
|
| 566 |
+
]
|
| 567 |
+
register_get_endpoints(app, get_endpoints)
|
| 568 |
+
|
| 569 |
+
# Register combined schema endpoint
|
| 570 |
+
@app.get(
|
| 571 |
+
"/schema",
|
| 572 |
+
response_model=SchemaResponse,
|
| 573 |
+
tags=["Schema"],
|
| 574 |
+
summary="Get all JSON schemas",
|
| 575 |
+
description="""
|
| 576 |
+
Get JSON schemas for actions, observations, and state in a single response.
|
| 577 |
+
|
| 578 |
+
Returns a combined schema object containing:
|
| 579 |
+
- **action**: JSON schema for actions accepted by this environment
|
| 580 |
+
- **observation**: JSON schema for observations returned by this environment
|
| 581 |
+
- **state**: JSON schema for environment state objects
|
| 582 |
+
|
| 583 |
+
This is more efficient than calling individual schema endpoints and provides
|
| 584 |
+
all schema information needed to interact with the environment.
|
| 585 |
+
""",
|
| 586 |
+
responses={
|
| 587 |
+
200: {
|
| 588 |
+
"description": "Combined schemas retrieved successfully",
|
| 589 |
+
"content": {
|
| 590 |
+
"application/json": {
|
| 591 |
+
"example": {
|
| 592 |
+
"action": {
|
| 593 |
+
"type": "object",
|
| 594 |
+
"properties": {"message": {"type": "string"}},
|
| 595 |
+
},
|
| 596 |
+
"observation": {
|
| 597 |
+
"type": "object",
|
| 598 |
+
"properties": {"response": {"type": "string"}},
|
| 599 |
+
},
|
| 600 |
+
"state": {
|
| 601 |
+
"type": "object",
|
| 602 |
+
"properties": {"step_count": {"type": "integer"}},
|
| 603 |
+
},
|
| 604 |
+
}
|
| 605 |
+
}
|
| 606 |
+
},
|
| 607 |
+
}
|
| 608 |
+
},
|
| 609 |
+
)
|
| 610 |
+
async def get_schemas() -> SchemaResponse:
|
| 611 |
+
"""Return all schemas in one response."""
|
| 612 |
+
return SchemaResponse(
|
| 613 |
+
action=self.action_cls.model_json_schema(),
|
| 614 |
+
observation=self.observation_cls.model_json_schema(),
|
| 615 |
+
state=State.model_json_schema(),
|
| 616 |
+
)
|
| 617 |
+
|
| 618 |
+
# Register WebSocket endpoint for persistent sessions
|
| 619 |
+
@app.websocket("/ws")
|
| 620 |
+
async def websocket_endpoint(websocket: WebSocket):
|
| 621 |
+
"""
|
| 622 |
+
WebSocket endpoint for persistent environment sessions.
|
| 623 |
+
|
| 624 |
+
Each WebSocket connection gets its own environment instance.
|
| 625 |
+
|
| 626 |
+
Message Protocol:
|
| 627 |
+
- Client sends: WSResetMessage | WSStepMessage | WSStateMessage | WSCloseMessage
|
| 628 |
+
- Server responds: WSObservationResponse | WSStateResponse | WSErrorResponse
|
| 629 |
+
"""
|
| 630 |
+
await websocket.accept()
|
| 631 |
+
|
| 632 |
+
session_id = None
|
| 633 |
+
session_env = None
|
| 634 |
+
|
| 635 |
+
try:
|
| 636 |
+
# Create session with dedicated environment
|
| 637 |
+
session_id, session_env = await self._create_session()
|
| 638 |
+
|
| 639 |
+
while True:
|
| 640 |
+
# Receive message from client
|
| 641 |
+
raw_message = await websocket.receive_text()
|
| 642 |
+
|
| 643 |
+
try:
|
| 644 |
+
message_dict = json.loads(raw_message)
|
| 645 |
+
except json.JSONDecodeError as e:
|
| 646 |
+
error_resp = WSErrorResponse(
|
| 647 |
+
data={
|
| 648 |
+
"message": f"Invalid JSON: {e}",
|
| 649 |
+
"code": "INVALID_JSON",
|
| 650 |
+
}
|
| 651 |
+
)
|
| 652 |
+
await websocket.send_text(error_resp.model_dump_json())
|
| 653 |
+
continue
|
| 654 |
+
|
| 655 |
+
msg_type = message_dict.get("type", "")
|
| 656 |
+
|
| 657 |
+
try:
|
| 658 |
+
match msg_type:
|
| 659 |
+
case "reset":
|
| 660 |
+
msg = WSResetMessage(**message_dict)
|
| 661 |
+
|
| 662 |
+
is_async = (
|
| 663 |
+
session_env.reset_async.__func__
|
| 664 |
+
is not Environment.reset_async
|
| 665 |
+
)
|
| 666 |
+
|
| 667 |
+
if is_async:
|
| 668 |
+
sig = inspect.signature(session_env.reset_async)
|
| 669 |
+
valid_kwargs = self._get_valid_kwargs(sig, msg.data)
|
| 670 |
+
observation = await session_env.reset_async(
|
| 671 |
+
**valid_kwargs
|
| 672 |
+
)
|
| 673 |
+
else:
|
| 674 |
+
sig = inspect.signature(session_env.reset)
|
| 675 |
+
valid_kwargs = self._get_valid_kwargs(sig, msg.data)
|
| 676 |
+
observation = await self._run_in_session_executor(
|
| 677 |
+
session_id, session_env.reset, **valid_kwargs
|
| 678 |
+
)
|
| 679 |
+
|
| 680 |
+
self._update_session_activity(session_id)
|
| 681 |
+
|
| 682 |
+
response = WSObservationResponse(
|
| 683 |
+
data=serialize_observation(observation)
|
| 684 |
+
)
|
| 685 |
+
|
| 686 |
+
case "step":
|
| 687 |
+
msg = WSStepMessage(**message_dict)
|
| 688 |
+
action = deserialize_action(msg.data, self.action_cls)
|
| 689 |
+
|
| 690 |
+
is_async = (
|
| 691 |
+
session_env.step_async.__func__
|
| 692 |
+
is not Environment.step_async
|
| 693 |
+
)
|
| 694 |
+
|
| 695 |
+
if is_async:
|
| 696 |
+
observation = await session_env.step_async(action)
|
| 697 |
+
else:
|
| 698 |
+
observation = await self._run_in_session_executor(
|
| 699 |
+
session_id, session_env.step, action
|
| 700 |
+
)
|
| 701 |
+
|
| 702 |
+
self._update_session_activity(
|
| 703 |
+
session_id, increment_step=True
|
| 704 |
+
)
|
| 705 |
+
|
| 706 |
+
response = WSObservationResponse(
|
| 707 |
+
data=serialize_observation(observation)
|
| 708 |
+
)
|
| 709 |
+
|
| 710 |
+
case "state":
|
| 711 |
+
msg = WSStateMessage(**message_dict)
|
| 712 |
+
state = session_env.state
|
| 713 |
+
if hasattr(state, "model_dump"):
|
| 714 |
+
state_data = state.model_dump()
|
| 715 |
+
else:
|
| 716 |
+
state_data = dict(state) if state else {}
|
| 717 |
+
|
| 718 |
+
response = WSStateResponse(data=state_data)
|
| 719 |
+
|
| 720 |
+
case "close":
|
| 721 |
+
msg = WSCloseMessage(**message_dict)
|
| 722 |
+
break
|
| 723 |
+
|
| 724 |
+
case _:
|
| 725 |
+
response = WSErrorResponse(
|
| 726 |
+
data={
|
| 727 |
+
"message": f"Unknown message type: {msg_type}",
|
| 728 |
+
"code": "UNKNOWN_TYPE",
|
| 729 |
+
}
|
| 730 |
+
)
|
| 731 |
+
|
| 732 |
+
await websocket.send_text(response.model_dump_json())
|
| 733 |
+
|
| 734 |
+
except ValidationError as e:
|
| 735 |
+
error_resp = WSErrorResponse(
|
| 736 |
+
data={
|
| 737 |
+
"message": "Invalid message",
|
| 738 |
+
"code": "VALIDATION_ERROR",
|
| 739 |
+
"errors": e.errors(),
|
| 740 |
+
}
|
| 741 |
+
)
|
| 742 |
+
await websocket.send_text(error_resp.model_dump_json())
|
| 743 |
+
except Exception as e:
|
| 744 |
+
error_resp = WSErrorResponse(
|
| 745 |
+
data={"message": str(e), "code": "EXECUTION_ERROR"}
|
| 746 |
+
)
|
| 747 |
+
await websocket.send_text(error_resp.model_dump_json())
|
| 748 |
+
|
| 749 |
+
except WebSocketDisconnect:
|
| 750 |
+
pass
|
| 751 |
+
except SessionCapacityError as e:
|
| 752 |
+
error_resp = WSErrorResponse(
|
| 753 |
+
data={
|
| 754 |
+
"message": str(e),
|
| 755 |
+
"code": "CAPACITY_REACHED",
|
| 756 |
+
"active_sessions": e.active_sessions,
|
| 757 |
+
"max_sessions": e.max_sessions,
|
| 758 |
+
}
|
| 759 |
+
)
|
| 760 |
+
await websocket.send_text(error_resp.model_dump_json())
|
| 761 |
+
except EnvironmentFactoryError as e:
|
| 762 |
+
error_resp = WSErrorResponse(
|
| 763 |
+
data={
|
| 764 |
+
"message": str(e),
|
| 765 |
+
"code": "FACTORY_ERROR",
|
| 766 |
+
"factory_name": e.factory_name,
|
| 767 |
+
}
|
| 768 |
+
)
|
| 769 |
+
await websocket.send_text(error_resp.model_dump_json())
|
| 770 |
+
except Exception as e:
|
| 771 |
+
error_resp = WSErrorResponse(
|
| 772 |
+
data={"message": str(e), "code": "SESSION_ERROR"}
|
| 773 |
+
)
|
| 774 |
+
await websocket.send_text(error_resp.model_dump_json())
|
| 775 |
+
finally:
|
| 776 |
+
if session_id:
|
| 777 |
+
await self._destroy_session(session_id)
|
| 778 |
+
try:
|
| 779 |
+
await websocket.close()
|
| 780 |
+
except RuntimeError:
|
| 781 |
+
pass
|
| 782 |
+
|
| 783 |
+
|
| 784 |
+
def create_app(
|
| 785 |
+
env: Callable[[], Environment],
|
| 786 |
+
action_cls: Type[Action],
|
| 787 |
+
observation_cls: Type[Observation],
|
| 788 |
+
env_name: Optional[str] = None,
|
| 789 |
+
max_concurrent_envs: Optional[int] = None,
|
| 790 |
+
concurrency_config: Optional[ConcurrencyConfig] = None,
|
| 791 |
+
) -> FastAPI:
|
| 792 |
+
"""
|
| 793 |
+
Create a FastAPI application with or without web interface.
|
| 794 |
+
|
| 795 |
+
This function creates a FastAPI app with the web interface enabled by default,
|
| 796 |
+
including README integration for better user experience.
|
| 797 |
+
|
| 798 |
+
Args:
|
| 799 |
+
env: Environment factory (callable) that creates new instances
|
| 800 |
+
action_cls: The Action subclass this environment expects
|
| 801 |
+
observation_cls: The Observation subclass this environment returns
|
| 802 |
+
env_name: Optional environment name for README loading
|
| 803 |
+
max_concurrent_envs: Maximum concurrent WebSocket sessions.
|
| 804 |
+
Mutually exclusive with concurrency_config.
|
| 805 |
+
concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings.
|
| 806 |
+
Mutually exclusive with max_concurrent_envs.
|
| 807 |
+
|
| 808 |
+
Returns:
|
| 809 |
+
FastAPI application instance with or without web interface and README integration
|
| 810 |
+
"""
|
| 811 |
+
# Check if web interface should be enabled
|
| 812 |
+
# This can be controlled via environment variable or build argument
|
| 813 |
+
enable_web = os.getenv("ENABLE_WEB_INTERFACE", "false").lower() in (
|
| 814 |
+
"true",
|
| 815 |
+
"1",
|
| 816 |
+
"yes",
|
| 817 |
+
)
|
| 818 |
+
|
| 819 |
+
if enable_web:
|
| 820 |
+
# Import web interface only when needed
|
| 821 |
+
from .web_interface import create_web_interface_app
|
| 822 |
+
|
| 823 |
+
return create_web_interface_app(
|
| 824 |
+
env,
|
| 825 |
+
action_cls,
|
| 826 |
+
observation_cls,
|
| 827 |
+
env_name,
|
| 828 |
+
max_concurrent_envs,
|
| 829 |
+
concurrency_config,
|
| 830 |
+
)
|
| 831 |
+
else:
|
| 832 |
+
# Use standard FastAPI app without web interface
|
| 833 |
+
return create_fastapi_app(
|
| 834 |
+
env, action_cls, observation_cls, max_concurrent_envs, concurrency_config
|
| 835 |
+
)
|
| 836 |
+
|
| 837 |
+
|
| 838 |
+
def create_fastapi_app(
|
| 839 |
+
env: Callable[[], Environment],
|
| 840 |
+
action_cls: Type[Action],
|
| 841 |
+
observation_cls: Type[Observation],
|
| 842 |
+
max_concurrent_envs: Optional[int] = None,
|
| 843 |
+
concurrency_config: Optional[ConcurrencyConfig] = None,
|
| 844 |
+
) -> FastAPI:
|
| 845 |
+
"""
|
| 846 |
+
Create a FastAPI application with comprehensive documentation.
|
| 847 |
+
|
| 848 |
+
Args:
|
| 849 |
+
env: Environment factory (callable) that creates new instances
|
| 850 |
+
action_cls: The Action subclass this environment expects
|
| 851 |
+
observation_cls: The Observation subclass this environment returns
|
| 852 |
+
max_concurrent_envs: Maximum concurrent WebSocket sessions.
|
| 853 |
+
Mutually exclusive with concurrency_config.
|
| 854 |
+
concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings.
|
| 855 |
+
Mutually exclusive with max_concurrent_envs.
|
| 856 |
+
|
| 857 |
+
Returns:
|
| 858 |
+
FastAPI application instance
|
| 859 |
+
"""
|
| 860 |
+
try:
|
| 861 |
+
from fastapi import FastAPI
|
| 862 |
+
except ImportError:
|
| 863 |
+
raise ImportError(
|
| 864 |
+
"FastAPI is required. Install with: pip install fastapi uvicorn"
|
| 865 |
+
)
|
| 866 |
+
|
| 867 |
+
app = FastAPI(
|
| 868 |
+
title="OpenEnv Environment HTTP API",
|
| 869 |
+
version="1.0.0",
|
| 870 |
+
description="""
|
| 871 |
+
# OpenEnv Environment HTTP API
|
| 872 |
+
|
| 873 |
+
HTTP API for interacting with OpenEnv environments through a standardized interface.
|
| 874 |
+
|
| 875 |
+
## Features
|
| 876 |
+
|
| 877 |
+
* **Environment Reset**: Initialize or restart episodes
|
| 878 |
+
* **Action Execution**: Send actions and receive observations
|
| 879 |
+
* **State Inspection**: Query current environment state
|
| 880 |
+
* **Schema Access**: Retrieve JSON schemas for actions and observations
|
| 881 |
+
|
| 882 |
+
## Workflow
|
| 883 |
+
|
| 884 |
+
1. Call `/reset` to start a new episode and get initial observation
|
| 885 |
+
2. Call `/step` repeatedly with actions to interact with environment
|
| 886 |
+
3. Episode ends when observation returns `done: true`
|
| 887 |
+
4. Call `/state` anytime to inspect current environment state
|
| 888 |
+
|
| 889 |
+
## Documentation
|
| 890 |
+
|
| 891 |
+
* **Swagger UI**: Available at `/docs`
|
| 892 |
+
* **ReDoc**: Available at `/redoc`
|
| 893 |
+
* **OpenAPI Schema**: Available at `/openapi.json`
|
| 894 |
+
""",
|
| 895 |
+
openapi_tags=[
|
| 896 |
+
{
|
| 897 |
+
"name": "Environment Control",
|
| 898 |
+
"description": "Core operations for environment interaction (reset, step)",
|
| 899 |
+
},
|
| 900 |
+
{
|
| 901 |
+
"name": "State Management",
|
| 902 |
+
"description": "Operations for inspecting environment state",
|
| 903 |
+
},
|
| 904 |
+
{
|
| 905 |
+
"name": "Environment Info",
|
| 906 |
+
"description": "Information about the environment",
|
| 907 |
+
},
|
| 908 |
+
{
|
| 909 |
+
"name": "Schema",
|
| 910 |
+
"description": "JSON Schema endpoints for actions, observations, and state",
|
| 911 |
+
},
|
| 912 |
+
{"name": "Health", "description": "Service health and status checks"},
|
| 913 |
+
],
|
| 914 |
+
docs_url="/docs",
|
| 915 |
+
redoc_url="/redoc",
|
| 916 |
+
openapi_url="/openapi.json",
|
| 917 |
+
contact={
|
| 918 |
+
"name": "OpenEnv Team",
|
| 919 |
+
"url": "https://github.com/meta-pytorch/OpenEnv",
|
| 920 |
+
},
|
| 921 |
+
license_info={
|
| 922 |
+
"name": "BSD-3-Clause",
|
| 923 |
+
"url": "https://github.com/meta-pytorch/OpenEnv/blob/main/LICENSE",
|
| 924 |
+
},
|
| 925 |
+
)
|
| 926 |
+
|
| 927 |
+
server = HTTPEnvServer(
|
| 928 |
+
env,
|
| 929 |
+
action_cls,
|
| 930 |
+
observation_cls,
|
| 931 |
+
max_concurrent_envs,
|
| 932 |
+
concurrency_config=concurrency_config,
|
| 933 |
+
)
|
| 934 |
+
server.register_routes(app)
|
| 935 |
+
return app
|
{core → src/core}/env_server/interfaces.py
RENAMED
|
@@ -1,118 +1,194 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the BSD-style license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
from abc import ABC, abstractmethod
|
| 8 |
-
from typing import Any, Protocol, TypedDict
|
| 9 |
-
|
| 10 |
-
from .types import Action, Observation, State
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
"""
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from abc import ABC, abstractmethod
|
| 8 |
+
from typing import Any, Generic, Optional, Protocol, TypedDict, TypeVar
|
| 9 |
+
|
| 10 |
+
from .types import Action, Observation, State, EnvironmentMetadata
|
| 11 |
+
|
| 12 |
+
ActT = TypeVar("ActT", bound=Action)
|
| 13 |
+
ObsT = TypeVar("ObsT", bound=Observation)
|
| 14 |
+
StateT = TypeVar("StateT", bound=State)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class Message(TypedDict):
|
| 18 |
+
"""A message in a conversation.
|
| 19 |
+
|
| 20 |
+
Compatible with Huggingface chat template format.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
role: str
|
| 24 |
+
content: str
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class ModelTokenizer(Protocol):
|
| 28 |
+
"""Protocol for tokenizers that support chat templates.
|
| 29 |
+
|
| 30 |
+
This protocol defines the interface that tokenizers must implement
|
| 31 |
+
to work with chat-based environments. It's compatible with
|
| 32 |
+
Huggingface transformers tokenizers.
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
def apply_chat_template(
|
| 36 |
+
self,
|
| 37 |
+
conversation: list[Message],
|
| 38 |
+
tokenize: bool = True,
|
| 39 |
+
return_tensors: str | None = None,
|
| 40 |
+
**kwargs: Any,
|
| 41 |
+
) -> Any:
|
| 42 |
+
"""Apply a chat template to format and optionally tokenize a conversation.
|
| 43 |
+
|
| 44 |
+
Args:
|
| 45 |
+
conversation: List of message dictionaries with 'role' and 'content'
|
| 46 |
+
tokenize: Whether to tokenize the output
|
| 47 |
+
return_tensors: Format for returned tensors ('pt' for PyTorch)
|
| 48 |
+
**kwargs: Additional arguments
|
| 49 |
+
|
| 50 |
+
Returns:
|
| 51 |
+
Formatted and optionally tokenized conversation
|
| 52 |
+
"""
|
| 53 |
+
...
|
| 54 |
+
|
| 55 |
+
def decode(
|
| 56 |
+
self, token_ids: Any, skip_special_tokens: bool = False, **kwargs: Any
|
| 57 |
+
) -> str:
|
| 58 |
+
"""Decode token IDs back to text.
|
| 59 |
+
|
| 60 |
+
Args:
|
| 61 |
+
token_ids: Token IDs to decode
|
| 62 |
+
skip_special_tokens: Whether to skip special tokens in output
|
| 63 |
+
**kwargs: Additional arguments
|
| 64 |
+
|
| 65 |
+
Returns:
|
| 66 |
+
Decoded text string
|
| 67 |
+
"""
|
| 68 |
+
...
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class Transform(ABC, Generic[ObsT]):
|
| 72 |
+
"""Transform observations to add rewards, metrics, or other modifications.
|
| 73 |
+
|
| 74 |
+
Transforms follow the TorchRL pattern where they take an observation
|
| 75 |
+
and return a (potentially modified) observation. This allows for
|
| 76 |
+
flexible reward computation and observation augmentation.
|
| 77 |
+
"""
|
| 78 |
+
|
| 79 |
+
@abstractmethod
|
| 80 |
+
def __call__(self, observation: ObsT) -> ObsT:
|
| 81 |
+
"""Transform an observation.
|
| 82 |
+
|
| 83 |
+
Args:
|
| 84 |
+
observation: The input observation
|
| 85 |
+
|
| 86 |
+
Returns:
|
| 87 |
+
The transformed observation
|
| 88 |
+
"""
|
| 89 |
+
pass
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class Environment(ABC, Generic[ActT, ObsT, StateT]):
|
| 93 |
+
"""Base class for all environment servers following Gym/Gymnasium API.
|
| 94 |
+
|
| 95 |
+
Args:
|
| 96 |
+
transform: Optional transform to apply to observations
|
| 97 |
+
|
| 98 |
+
Class Attributes:
|
| 99 |
+
SUPPORTS_CONCURRENT_SESSIONS: Whether this environment supports concurrent sessions.
|
| 100 |
+
When True, multiple WebSocket connections can each have their own
|
| 101 |
+
environment instance (up to max_concurrent_envs). When False (default),
|
| 102 |
+
the environment should only be used with a single session at a time.
|
| 103 |
+
|
| 104 |
+
Set this to True in your Environment subclass if:
|
| 105 |
+
- The environment uses proper session isolation (e.g., unique working dirs)
|
| 106 |
+
- No shared mutable state exists between instances
|
| 107 |
+
- External resources (databases, APIs) can handle concurrent access
|
| 108 |
+
"""
|
| 109 |
+
|
| 110 |
+
# Class-level flag indicating whether this environment supports concurrent sessions
|
| 111 |
+
SUPPORTS_CONCURRENT_SESSIONS: bool = False
|
| 112 |
+
|
| 113 |
+
def __init__(self, transform: Optional[Transform[ObsT]] = None):
|
| 114 |
+
self.transform = transform
|
| 115 |
+
|
| 116 |
+
@abstractmethod
|
| 117 |
+
def reset(
|
| 118 |
+
self,
|
| 119 |
+
seed: Optional[int] = None,
|
| 120 |
+
episode_id: Optional[str] = None,
|
| 121 |
+
**kwargs: Any,
|
| 122 |
+
) -> ObsT:
|
| 123 |
+
"""Reset the environment and return initial observation."""
|
| 124 |
+
pass
|
| 125 |
+
|
| 126 |
+
async def reset_async(
|
| 127 |
+
self,
|
| 128 |
+
seed: Optional[int] = None,
|
| 129 |
+
episode_id: Optional[str] = None,
|
| 130 |
+
**kwargs: Any,
|
| 131 |
+
) -> ObsT:
|
| 132 |
+
"""Async version of reset. Default implementation calls sync reset.
|
| 133 |
+
|
| 134 |
+
Override to provide true async implementation.
|
| 135 |
+
"""
|
| 136 |
+
return self.reset(seed=seed, episode_id=episode_id, **kwargs)
|
| 137 |
+
|
| 138 |
+
@abstractmethod
|
| 139 |
+
def step(
|
| 140 |
+
self,
|
| 141 |
+
action: ActT,
|
| 142 |
+
timeout_s: Optional[float] = None,
|
| 143 |
+
**kwargs: Any,
|
| 144 |
+
) -> ObsT:
|
| 145 |
+
"""Take a step in the environment."""
|
| 146 |
+
pass
|
| 147 |
+
|
| 148 |
+
async def step_async(
|
| 149 |
+
self,
|
| 150 |
+
action: ActT,
|
| 151 |
+
timeout_s: Optional[float] = None,
|
| 152 |
+
**kwargs: Any,
|
| 153 |
+
) -> ObsT:
|
| 154 |
+
"""Async version of step. Default implementation calls sync step.
|
| 155 |
+
|
| 156 |
+
Override to provide true async implementation.
|
| 157 |
+
"""
|
| 158 |
+
return self.step(action, timeout_s=timeout_s, **kwargs)
|
| 159 |
+
|
| 160 |
+
@property
|
| 161 |
+
@abstractmethod
|
| 162 |
+
def state(self) -> StateT:
|
| 163 |
+
"""Get the current environment state."""
|
| 164 |
+
pass
|
| 165 |
+
|
| 166 |
+
def get_metadata(self) -> EnvironmentMetadata:
|
| 167 |
+
"""
|
| 168 |
+
Get metadata about this environment.
|
| 169 |
+
|
| 170 |
+
Override this method to provide custom metadata for the environment.
|
| 171 |
+
Default implementation returns basic metadata derived from class name.
|
| 172 |
+
|
| 173 |
+
Returns:
|
| 174 |
+
EnvironmentMetadata with environment information
|
| 175 |
+
"""
|
| 176 |
+
return EnvironmentMetadata(
|
| 177 |
+
name=self.__class__.__name__,
|
| 178 |
+
description=f"{self.__class__.__name__} environment",
|
| 179 |
+
version="1.0.0",
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
def _apply_transform(self, observation: ObsT) -> ObsT:
|
| 183 |
+
"""Apply transform if one is provided."""
|
| 184 |
+
if self.transform is not None:
|
| 185 |
+
return self.transform(observation)
|
| 186 |
+
return observation
|
| 187 |
+
|
| 188 |
+
def close(self) -> None:
|
| 189 |
+
"""Clean up resources used by the environment.
|
| 190 |
+
|
| 191 |
+
Override this method to implement custom cleanup logic.
|
| 192 |
+
Called when the environment is being destroyed or reset.
|
| 193 |
+
"""
|
| 194 |
+
pass
|
src/core/env_server/route_config.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Route configuration utilities for declarative FastAPI route registration.
|
| 9 |
+
|
| 10 |
+
This module provides utilities to reduce boilerplate in route registration
|
| 11 |
+
by using configuration objects instead of repeated function calls.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from dataclasses import dataclass
|
| 15 |
+
from typing import Callable, List, Type
|
| 16 |
+
|
| 17 |
+
from fastapi import FastAPI
|
| 18 |
+
from pydantic import BaseModel
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@dataclass
|
| 22 |
+
class GetEndpointConfig:
|
| 23 |
+
"""Configuration for a simple GET endpoint."""
|
| 24 |
+
|
| 25 |
+
path: str
|
| 26 |
+
handler: Callable[[], BaseModel | dict]
|
| 27 |
+
response_model: Type[BaseModel] | type[dict]
|
| 28 |
+
tag: str
|
| 29 |
+
summary: str
|
| 30 |
+
description: str
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def register_get_endpoints(app: FastAPI, configs: List[GetEndpointConfig]) -> None:
|
| 34 |
+
"""
|
| 35 |
+
Register multiple GET endpoints from configuration.
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
app: FastAPI application instance
|
| 39 |
+
configs: List of GET endpoint configurations
|
| 40 |
+
"""
|
| 41 |
+
for config in configs:
|
| 42 |
+
# Capture handler in a closure to avoid non-serializable default parameter
|
| 43 |
+
def make_endpoint(
|
| 44 |
+
handler: Callable[[], BaseModel | dict],
|
| 45 |
+
) -> Callable[[], BaseModel | dict]:
|
| 46 |
+
async def endpoint() -> BaseModel | dict:
|
| 47 |
+
return handler()
|
| 48 |
+
|
| 49 |
+
return endpoint
|
| 50 |
+
|
| 51 |
+
app.get(
|
| 52 |
+
config.path,
|
| 53 |
+
response_model=config.response_model,
|
| 54 |
+
tags=[config.tag],
|
| 55 |
+
summary=config.summary,
|
| 56 |
+
description=config.description,
|
| 57 |
+
)(make_endpoint(config.handler))
|
src/core/env_server/serialization.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Shared serialization and deserialization utilities for OpenEnv HTTP servers.
|
| 9 |
+
|
| 10 |
+
This module provides common utilities for converting between JSON dictionaries
|
| 11 |
+
and Pydantic models (Action/Observation) to eliminate code duplication across
|
| 12 |
+
HTTP server and web interface implementations.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from typing import Any, Dict, Type
|
| 16 |
+
|
| 17 |
+
from .types import Action, Observation
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def deserialize_action(action_data: Dict[str, Any], action_cls: Type[Action]) -> Action:
|
| 21 |
+
"""
|
| 22 |
+
Convert JSON dict to Action instance using Pydantic validation.
|
| 23 |
+
|
| 24 |
+
This is a basic deserialization that works for most environments.
|
| 25 |
+
For special cases (e.g., tensor fields, custom type conversions),
|
| 26 |
+
use deserialize_action_with_preprocessing().
|
| 27 |
+
|
| 28 |
+
Args:
|
| 29 |
+
action_data: Dictionary containing action data
|
| 30 |
+
action_cls: The Action subclass to instantiate
|
| 31 |
+
|
| 32 |
+
Returns:
|
| 33 |
+
Action instance
|
| 34 |
+
|
| 35 |
+
Raises:
|
| 36 |
+
ValidationError: If action_data is invalid for the action class
|
| 37 |
+
|
| 38 |
+
Note:
|
| 39 |
+
This uses Pydantic's model_validate() for automatic validation.
|
| 40 |
+
"""
|
| 41 |
+
return action_cls.model_validate(action_data)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def deserialize_action_with_preprocessing(
|
| 45 |
+
action_data: Dict[str, Any], action_cls: Type[Action]
|
| 46 |
+
) -> Action:
|
| 47 |
+
"""
|
| 48 |
+
Convert JSON dict to Action instance with preprocessing for special types.
|
| 49 |
+
|
| 50 |
+
This version handles common type conversions needed for web interfaces:
|
| 51 |
+
- Converting lists/strings to tensors for 'tokens' field
|
| 52 |
+
- Converting string action_id to int
|
| 53 |
+
- Other custom preprocessing as needed
|
| 54 |
+
|
| 55 |
+
Args:
|
| 56 |
+
action_data: Dictionary containing action data
|
| 57 |
+
action_cls: The Action subclass to instantiate
|
| 58 |
+
|
| 59 |
+
Returns:
|
| 60 |
+
Action instance
|
| 61 |
+
|
| 62 |
+
Raises:
|
| 63 |
+
ValidationError: If action_data is invalid for the action class
|
| 64 |
+
"""
|
| 65 |
+
processed_data = {}
|
| 66 |
+
|
| 67 |
+
for key, value in action_data.items():
|
| 68 |
+
if key == "tokens" and isinstance(value, (list, str)):
|
| 69 |
+
# Convert list or string to tensor
|
| 70 |
+
if isinstance(value, str):
|
| 71 |
+
# If it's a string, try to parse it as a list of numbers
|
| 72 |
+
try:
|
| 73 |
+
import json
|
| 74 |
+
|
| 75 |
+
value = json.loads(value)
|
| 76 |
+
except Exception:
|
| 77 |
+
# If parsing fails, treat as empty list
|
| 78 |
+
value = []
|
| 79 |
+
if isinstance(value, list):
|
| 80 |
+
try:
|
| 81 |
+
import torch # type: ignore
|
| 82 |
+
|
| 83 |
+
processed_data[key] = torch.tensor(value, dtype=torch.long)
|
| 84 |
+
except ImportError:
|
| 85 |
+
# If torch not available, keep as list
|
| 86 |
+
processed_data[key] = value
|
| 87 |
+
else:
|
| 88 |
+
processed_data[key] = value
|
| 89 |
+
elif key == "action_id" and isinstance(value, str):
|
| 90 |
+
# Convert action_id from string to int
|
| 91 |
+
try:
|
| 92 |
+
processed_data[key] = int(value)
|
| 93 |
+
except ValueError:
|
| 94 |
+
# If conversion fails, keep original value
|
| 95 |
+
processed_data[key] = value
|
| 96 |
+
else:
|
| 97 |
+
processed_data[key] = value
|
| 98 |
+
|
| 99 |
+
return action_cls.model_validate(processed_data)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def serialize_observation(observation: Observation) -> Dict[str, Any]:
|
| 103 |
+
"""
|
| 104 |
+
Convert Observation instance to JSON-compatible dict using Pydantic.
|
| 105 |
+
|
| 106 |
+
Args:
|
| 107 |
+
observation: Observation instance
|
| 108 |
+
|
| 109 |
+
Returns:
|
| 110 |
+
Dictionary compatible with EnvClient._parse_result()
|
| 111 |
+
|
| 112 |
+
The format matches what EnvClient expects:
|
| 113 |
+
{
|
| 114 |
+
"observation": {...}, # Observation fields
|
| 115 |
+
"reward": float | None,
|
| 116 |
+
"done": bool,
|
| 117 |
+
}
|
| 118 |
+
"""
|
| 119 |
+
# Use Pydantic's model_dump() for serialization
|
| 120 |
+
obs_dict = observation.model_dump(
|
| 121 |
+
exclude={
|
| 122 |
+
"reward",
|
| 123 |
+
"done",
|
| 124 |
+
"metadata",
|
| 125 |
+
} # Exclude these from observation dict
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
# Extract reward and done directly from the observation
|
| 129 |
+
reward = observation.reward
|
| 130 |
+
done = observation.done
|
| 131 |
+
|
| 132 |
+
# Return in EnvClient expected format
|
| 133 |
+
return {
|
| 134 |
+
"observation": obs_dict,
|
| 135 |
+
"reward": reward,
|
| 136 |
+
"done": done,
|
| 137 |
+
}
|
src/core/env_server/types.py
ADDED
|
@@ -0,0 +1,341 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from typing import Any, Dict, Optional, Union, Literal, Annotated
|
| 8 |
+
from pydantic import BaseModel, Field, ConfigDict, model_validator
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# Type aliases
|
| 12 |
+
Scalar = Union[int, float, bool]
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class Action(BaseModel):
|
| 16 |
+
"""Base class for all environment actions.
|
| 17 |
+
|
| 18 |
+
All action subclasses should inherit from this base class.
|
| 19 |
+
Uses Pydantic for automatic validation and serialization.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
model_config = ConfigDict(
|
| 23 |
+
extra="forbid", # Reject unknown fields
|
| 24 |
+
validate_assignment=True, # Validate on field assignment
|
| 25 |
+
arbitrary_types_allowed=True, # Allow numpy arrays, torch tensors, etc.
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
metadata: Dict[str, Any] = Field(
|
| 29 |
+
default_factory=dict, description="Additional metadata for the action"
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class Observation(BaseModel):
|
| 34 |
+
"""Base class for all environment observations.
|
| 35 |
+
|
| 36 |
+
All observation subclasses should inherit from this base class.
|
| 37 |
+
Uses Pydantic for automatic validation and serialization.
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
model_config = ConfigDict(
|
| 41 |
+
extra="forbid",
|
| 42 |
+
validate_assignment=True,
|
| 43 |
+
arbitrary_types_allowed=True,
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
done: bool = Field(default=False, description="Whether the episode has terminated")
|
| 47 |
+
reward: bool | int | float | None = Field(
|
| 48 |
+
default=None, description="Reward signal from the last action"
|
| 49 |
+
)
|
| 50 |
+
metadata: Dict[str, Any] = Field(
|
| 51 |
+
default_factory=dict, description="Additional metadata for the observation"
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class ResetRequest(BaseModel):
|
| 56 |
+
"""Request model for environment reset."""
|
| 57 |
+
|
| 58 |
+
model_config = ConfigDict(
|
| 59 |
+
extra="allow", # Allow extra fields for custom reset parameters
|
| 60 |
+
json_schema_extra={"examples": [{"seed": 42, "episode_id": "episode-001"}, {}]},
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
seed: Optional[int] = Field(
|
| 64 |
+
default=None, ge=0, description="Random seed for reproducible episodes"
|
| 65 |
+
)
|
| 66 |
+
episode_id: Optional[str] = Field(
|
| 67 |
+
default=None, max_length=255, description="Custom episode identifier"
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class ResetResponse(BaseModel):
|
| 72 |
+
"""Response model for environment reset."""
|
| 73 |
+
|
| 74 |
+
model_config = ConfigDict(extra="forbid")
|
| 75 |
+
|
| 76 |
+
observation: Dict[str, Any] = Field(
|
| 77 |
+
..., description="Initial observation from the environment"
|
| 78 |
+
)
|
| 79 |
+
reward: Optional[float] = Field(
|
| 80 |
+
default=None, description="Initial reward (typically None at reset)"
|
| 81 |
+
)
|
| 82 |
+
done: bool = Field(
|
| 83 |
+
default=False, description="Whether episode is already done (typically False)"
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class StepRequest(BaseModel):
|
| 88 |
+
"""Request model for environment step."""
|
| 89 |
+
|
| 90 |
+
model_config = ConfigDict(
|
| 91 |
+
extra="allow", # Allow extra fields for custom step parameters
|
| 92 |
+
json_schema_extra={
|
| 93 |
+
"examples": [
|
| 94 |
+
{"action": {"value": 1}, "timeout_s": 30.0},
|
| 95 |
+
{"action": {"value": 1}, "render": True, "verbose": False},
|
| 96 |
+
]
|
| 97 |
+
},
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
action: Dict[str, Any] = Field(
|
| 101 |
+
...,
|
| 102 |
+
description="Action to execute, must conform to environment's action schema",
|
| 103 |
+
)
|
| 104 |
+
timeout_s: Optional[float] = Field(
|
| 105 |
+
default=None,
|
| 106 |
+
gt=0,
|
| 107 |
+
description="Optional timeout in seconds for action execution",
|
| 108 |
+
)
|
| 109 |
+
request_id: Optional[str] = Field(
|
| 110 |
+
default=None,
|
| 111 |
+
max_length=255,
|
| 112 |
+
description="Optional request identifier for tracking",
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class StepResponse(BaseModel):
|
| 117 |
+
"""Response model for environment step."""
|
| 118 |
+
|
| 119 |
+
model_config = ConfigDict(extra="forbid")
|
| 120 |
+
|
| 121 |
+
observation: Dict[str, Any] = Field(
|
| 122 |
+
..., description="Observation resulting from the action"
|
| 123 |
+
)
|
| 124 |
+
reward: Optional[float] = Field(
|
| 125 |
+
default=None, description="Reward signal from the action"
|
| 126 |
+
)
|
| 127 |
+
done: bool = Field(default=False, description="Whether the episode has terminated")
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class BaseMessage(BaseModel):
|
| 131 |
+
"""Base class for WebSocket messages with shared configuration."""
|
| 132 |
+
|
| 133 |
+
model_config = ConfigDict(
|
| 134 |
+
extra="forbid",
|
| 135 |
+
validate_assignment=True,
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
class State(BaseModel):
|
| 140 |
+
"""Base class for environment state.
|
| 141 |
+
|
| 142 |
+
Represents internal environment state, separate from observations.
|
| 143 |
+
"""
|
| 144 |
+
|
| 145 |
+
model_config = ConfigDict(
|
| 146 |
+
extra="allow", # Allow extra fields for flexibility
|
| 147 |
+
validate_assignment=True,
|
| 148 |
+
arbitrary_types_allowed=True,
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
episode_id: Optional[str] = Field(
|
| 152 |
+
default=None, description="Unique identifier for the current episode"
|
| 153 |
+
)
|
| 154 |
+
step_count: int = Field(
|
| 155 |
+
default=0,
|
| 156 |
+
ge=0, # Greater than or equal to 0
|
| 157 |
+
description="Number of steps taken in the current episode",
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
class CodeExecResult(BaseMessage):
|
| 162 |
+
"""Result of code execution containing stdout, stderr, and exit code."""
|
| 163 |
+
|
| 164 |
+
stdout: str = Field(description="Standard output from code execution")
|
| 165 |
+
stderr: str = Field(description="Standard error from code execution")
|
| 166 |
+
exit_code: int = Field(description="Exit code from code execution")
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class EnvironmentMetadata(BaseMessage):
|
| 170 |
+
"""Metadata about an environment for documentation and UI purposes."""
|
| 171 |
+
|
| 172 |
+
name: str = Field(description="Name of the environment")
|
| 173 |
+
description: str = Field(description="Description of what the environment does")
|
| 174 |
+
readme_content: Optional[str] = Field(
|
| 175 |
+
default=None, description="Content of the README file for the environment"
|
| 176 |
+
)
|
| 177 |
+
version: Optional[str] = Field(
|
| 178 |
+
default=None, description="Version of the environment"
|
| 179 |
+
)
|
| 180 |
+
author: Optional[str] = Field(default=None, description="Author of the environment")
|
| 181 |
+
documentation_url: Optional[str] = Field(
|
| 182 |
+
default=None, description="URL to the environment's documentation"
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
class SchemaResponse(BaseMessage):
|
| 187 |
+
"""Response model for the combined schema endpoint."""
|
| 188 |
+
|
| 189 |
+
action: Dict[str, Any] = Field(
|
| 190 |
+
description="JSON schema for actions accepted by this environment"
|
| 191 |
+
)
|
| 192 |
+
observation: Dict[str, Any] = Field(
|
| 193 |
+
description="JSON schema for observations returned by this environment"
|
| 194 |
+
)
|
| 195 |
+
state: Dict[str, Any] = Field(
|
| 196 |
+
description="JSON schema for environment state objects"
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
class HealthResponse(BaseMessage):
|
| 201 |
+
"""Response model for health check endpoint."""
|
| 202 |
+
|
| 203 |
+
status: str = Field(description="Health status of the environment server")
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
class WSResetMessage(BaseMessage):
|
| 207 |
+
"""WebSocket message to reset the environment."""
|
| 208 |
+
|
| 209 |
+
type: Literal["reset"] = Field(default="reset", description="Message type")
|
| 210 |
+
data: Dict[str, Any] = Field(
|
| 211 |
+
default_factory=dict,
|
| 212 |
+
description="Optional reset parameters (seed, episode_id, etc.)",
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
class WSStepMessage(BaseMessage):
|
| 217 |
+
"""WebSocket message to execute a step."""
|
| 218 |
+
|
| 219 |
+
type: Literal["step"] = Field(default="step", description="Message type")
|
| 220 |
+
data: Dict[str, Any] = Field(
|
| 221 |
+
..., description="Action data conforming to environment's action schema"
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
class WSStateMessage(BaseMessage):
|
| 226 |
+
"""WebSocket message to request current state."""
|
| 227 |
+
|
| 228 |
+
type: Literal["state"] = Field(default="state", description="Message type")
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
class WSCloseMessage(BaseMessage):
|
| 232 |
+
"""WebSocket message to close the session."""
|
| 233 |
+
|
| 234 |
+
type: Literal["close"] = Field(default="close", description="Message type")
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
# Discriminated union for incoming WebSocket messages
|
| 238 |
+
WSIncomingMessage = Annotated[
|
| 239 |
+
WSResetMessage | WSStepMessage | WSStateMessage | WSCloseMessage,
|
| 240 |
+
Field(discriminator="type"),
|
| 241 |
+
]
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
class WSObservationResponse(BaseModel):
|
| 245 |
+
"""WebSocket response containing an observation."""
|
| 246 |
+
|
| 247 |
+
model_config = ConfigDict(extra="forbid")
|
| 248 |
+
|
| 249 |
+
type: str = Field(default="observation", description="Response type")
|
| 250 |
+
data: Dict[str, Any] = Field(description="Observation data")
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
class WSStateResponse(BaseModel):
|
| 254 |
+
"""WebSocket response containing environment state."""
|
| 255 |
+
|
| 256 |
+
model_config = ConfigDict(extra="forbid")
|
| 257 |
+
|
| 258 |
+
type: str = Field(default="state", description="Response type")
|
| 259 |
+
data: Dict[str, Any] = Field(description="State data")
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
class WSErrorResponse(BaseModel):
|
| 263 |
+
"""WebSocket response for errors."""
|
| 264 |
+
|
| 265 |
+
model_config = ConfigDict(extra="forbid")
|
| 266 |
+
|
| 267 |
+
type: str = Field(default="error", description="Response type")
|
| 268 |
+
data: Dict[str, Any] = Field(description="Error details including message and code")
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
class ConcurrencyConfig(BaseMessage):
|
| 272 |
+
"""Configuration for concurrent environment sessions."""
|
| 273 |
+
|
| 274 |
+
max_concurrent_envs: int = Field(
|
| 275 |
+
default=1,
|
| 276 |
+
ge=1,
|
| 277 |
+
description="Maximum number of concurrent WebSocket sessions allowed",
|
| 278 |
+
)
|
| 279 |
+
session_timeout: Optional[float] = Field(
|
| 280 |
+
default=None,
|
| 281 |
+
gt=0,
|
| 282 |
+
description="Timeout in seconds for inactive sessions. None means no timeout.",
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
class ServerCapacityStatus(BaseMessage):
|
| 287 |
+
"""Status of server capacity for concurrent sessions."""
|
| 288 |
+
|
| 289 |
+
active_sessions: int = Field(
|
| 290 |
+
ge=0,
|
| 291 |
+
description="Number of currently active sessions",
|
| 292 |
+
)
|
| 293 |
+
max_sessions: int = Field(
|
| 294 |
+
ge=1,
|
| 295 |
+
description="Maximum number of allowed sessions",
|
| 296 |
+
)
|
| 297 |
+
|
| 298 |
+
@model_validator(mode="after")
|
| 299 |
+
def check_capacity_bounds(self) -> "ServerCapacityStatus":
|
| 300 |
+
if self.active_sessions > self.max_sessions:
|
| 301 |
+
raise ValueError(
|
| 302 |
+
f"active_sessions ({self.active_sessions}) cannot exceed "
|
| 303 |
+
f"max_sessions ({self.max_sessions})"
|
| 304 |
+
)
|
| 305 |
+
return self
|
| 306 |
+
|
| 307 |
+
@property
|
| 308 |
+
def available_slots(self) -> int:
|
| 309 |
+
"""Number of available session slots."""
|
| 310 |
+
return self.max_sessions - self.active_sessions
|
| 311 |
+
|
| 312 |
+
@property
|
| 313 |
+
def is_at_capacity(self) -> bool:
|
| 314 |
+
"""Whether the server has reached maximum capacity."""
|
| 315 |
+
return self.available_slots == 0
|
| 316 |
+
|
| 317 |
+
@classmethod
|
| 318 |
+
def from_counts(cls, active: int, max_sessions: int) -> "ServerCapacityStatus":
|
| 319 |
+
"""Create status from active and max session counts."""
|
| 320 |
+
return cls(
|
| 321 |
+
active_sessions=active,
|
| 322 |
+
max_sessions=max_sessions,
|
| 323 |
+
)
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
class SessionInfo(BaseMessage):
|
| 327 |
+
"""Information about an active session."""
|
| 328 |
+
|
| 329 |
+
session_id: str = Field(description="Unique identifier for the session")
|
| 330 |
+
created_at: float = Field(description="Unix timestamp when the session was created")
|
| 331 |
+
last_activity_at: float = Field(
|
| 332 |
+
description="Unix timestamp of the last activity in the session"
|
| 333 |
+
)
|
| 334 |
+
step_count: int = Field(
|
| 335 |
+
default=0,
|
| 336 |
+
ge=0,
|
| 337 |
+
description="Number of steps executed in this session",
|
| 338 |
+
)
|
| 339 |
+
environment_type: str = Field(
|
| 340 |
+
description="Environment type for this session (e.g. `CodingEnv`)"
|
| 341 |
+
)
|
{core → src/core}/env_server/web_interface.py
RENAMED
|
@@ -13,55 +13,60 @@ including a two-pane layout for HumanAgent interaction and state observation.
|
|
| 13 |
|
| 14 |
from __future__ import annotations
|
| 15 |
|
|
|
|
| 16 |
import json
|
| 17 |
-
import
|
| 18 |
-
from dataclasses import asdict, dataclass
|
| 19 |
from typing import Any, Dict, List, Optional, Type
|
| 20 |
from datetime import datetime
|
| 21 |
|
| 22 |
-
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
| 23 |
-
from fastapi.responses import HTMLResponse
|
| 24 |
-
from
|
| 25 |
-
from pydantic import BaseModel
|
| 26 |
|
| 27 |
from .interfaces import Environment
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
from .types import Action, Observation, State, EnvironmentMetadata
|
| 29 |
|
| 30 |
|
| 31 |
-
def load_environment_metadata(
|
|
|
|
|
|
|
| 32 |
"""
|
| 33 |
Load environment metadata including README content.
|
| 34 |
-
|
| 35 |
Args:
|
| 36 |
env: The environment instance
|
| 37 |
env_name: Optional environment name for README file lookup
|
| 38 |
-
|
| 39 |
Returns:
|
| 40 |
EnvironmentMetadata with loaded information
|
| 41 |
"""
|
| 42 |
# Try to get metadata from environment if it has a method for it
|
| 43 |
-
if hasattr(env,
|
| 44 |
return env.get_metadata()
|
| 45 |
-
|
| 46 |
# Default metadata
|
| 47 |
metadata = EnvironmentMetadata(
|
| 48 |
name=env_name or env.__class__.__name__,
|
| 49 |
description=f"{env.__class__.__name__} environment",
|
| 50 |
-
version="1.0.0"
|
| 51 |
)
|
| 52 |
-
|
| 53 |
# Try to load README from file system
|
| 54 |
readme_content = _load_readme_from_filesystem(env_name)
|
| 55 |
if readme_content:
|
| 56 |
metadata.readme_content = readme_content
|
| 57 |
-
|
| 58 |
return metadata
|
| 59 |
|
| 60 |
|
| 61 |
def _load_readme_from_filesystem(env_name: Optional[str]) -> Optional[str]:
|
| 62 |
"""
|
| 63 |
Load README content from the filesystem.
|
| 64 |
-
|
| 65 |
Tries multiple locations:
|
| 66 |
1. Container filesystem: /app/README.md
|
| 67 |
2. Local development: src/envs/{env_name}/README.md
|
|
@@ -69,59 +74,71 @@ def _load_readme_from_filesystem(env_name: Optional[str]) -> Optional[str]:
|
|
| 69 |
"""
|
| 70 |
import os
|
| 71 |
from pathlib import Path
|
| 72 |
-
|
| 73 |
# Try container filesystem first
|
| 74 |
container_readme = Path("/app/README.md")
|
| 75 |
if container_readme.exists():
|
| 76 |
try:
|
| 77 |
-
return container_readme.read_text(encoding=
|
| 78 |
except Exception:
|
| 79 |
pass
|
| 80 |
-
|
| 81 |
# Try environment variable path
|
| 82 |
custom_path = os.environ.get("ENV_README_PATH")
|
| 83 |
if custom_path and Path(custom_path).exists():
|
| 84 |
try:
|
| 85 |
-
return Path(custom_path).read_text(encoding=
|
| 86 |
except Exception:
|
| 87 |
pass
|
| 88 |
-
|
| 89 |
# Try local development path
|
| 90 |
if env_name:
|
| 91 |
local_readme = Path(f"src/envs/{env_name}/README.md")
|
| 92 |
if local_readme.exists():
|
| 93 |
try:
|
| 94 |
-
return local_readme.read_text(encoding=
|
| 95 |
except Exception:
|
| 96 |
pass
|
| 97 |
-
|
| 98 |
return None
|
| 99 |
|
| 100 |
|
| 101 |
-
|
| 102 |
-
class ActionLog:
|
| 103 |
"""Log entry for an action taken."""
|
| 104 |
-
timestamp: str
|
| 105 |
-
action: Dict[str, Any]
|
| 106 |
-
observation: Dict[str, Any]
|
| 107 |
-
reward: Optional[float]
|
| 108 |
-
done: bool
|
| 109 |
-
step_count: int
|
| 110 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
|
| 112 |
-
|
| 113 |
-
class EpisodeState:
|
| 114 |
"""Current episode state for the web interface."""
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
|
| 121 |
|
| 122 |
class WebInterfaceManager:
|
| 123 |
"""Manages the web interface for an environment."""
|
| 124 |
-
|
| 125 |
def __init__(
|
| 126 |
self,
|
| 127 |
env: Environment,
|
|
@@ -134,147 +151,127 @@ class WebInterfaceManager:
|
|
| 134 |
self.observation_cls = observation_cls
|
| 135 |
self.metadata = metadata or EnvironmentMetadata(
|
| 136 |
name=env.__class__.__name__,
|
| 137 |
-
description=f"{env.__class__.__name__} environment"
|
| 138 |
)
|
| 139 |
self.episode_state = EpisodeState(
|
| 140 |
episode_id=None,
|
| 141 |
step_count=0,
|
| 142 |
current_observation=None,
|
| 143 |
-
action_logs=[]
|
| 144 |
)
|
| 145 |
self.connected_clients: List[WebSocket] = []
|
| 146 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
async def connect_websocket(self, websocket: WebSocket):
|
| 148 |
"""Connect a new WebSocket client."""
|
| 149 |
await websocket.accept()
|
| 150 |
self.connected_clients.append(websocket)
|
| 151 |
-
|
| 152 |
# Send current state to the new client
|
| 153 |
await self._send_state_update()
|
| 154 |
-
|
| 155 |
async def disconnect_websocket(self, websocket: WebSocket):
|
| 156 |
"""Disconnect a WebSocket client."""
|
| 157 |
if websocket in self.connected_clients:
|
| 158 |
self.connected_clients.remove(websocket)
|
| 159 |
-
|
| 160 |
async def _send_state_update(self):
|
| 161 |
"""Send current state to all connected clients."""
|
| 162 |
if not self.connected_clients:
|
| 163 |
return
|
| 164 |
-
|
| 165 |
state_data = {
|
| 166 |
"type": "state_update",
|
| 167 |
-
"episode_state":
|
| 168 |
}
|
| 169 |
-
|
| 170 |
# Send to all connected clients
|
| 171 |
disconnected_clients = []
|
| 172 |
for client in self.connected_clients:
|
| 173 |
try:
|
| 174 |
await client.send_text(json.dumps(state_data))
|
| 175 |
-
except:
|
| 176 |
disconnected_clients.append(client)
|
| 177 |
-
|
| 178 |
# Remove disconnected clients
|
| 179 |
for client in disconnected_clients:
|
| 180 |
self.connected_clients.remove(client)
|
| 181 |
-
|
| 182 |
async def reset_environment(self) -> Dict[str, Any]:
|
| 183 |
"""Reset the environment and update state."""
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
# Update episode state
|
| 188 |
self.episode_state.episode_id = state.episode_id
|
| 189 |
self.episode_state.step_count = 0
|
| 190 |
-
self.episode_state.current_observation =
|
| 191 |
self.episode_state.action_logs = []
|
| 192 |
self.episode_state.is_reset = True
|
| 193 |
-
|
| 194 |
# Send state update
|
| 195 |
await self._send_state_update()
|
| 196 |
-
|
| 197 |
-
return
|
| 198 |
-
|
| 199 |
-
"reward": observation.reward,
|
| 200 |
-
"done": observation.done,
|
| 201 |
-
}
|
| 202 |
-
|
| 203 |
async def step_environment(self, action_data: Dict[str, Any]) -> Dict[str, Any]:
|
| 204 |
"""Execute a step in the environment and update state."""
|
| 205 |
-
# Deserialize action
|
| 206 |
-
action =
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 212 |
# Create action log
|
| 213 |
action_log = ActionLog(
|
| 214 |
timestamp=datetime.now().isoformat(),
|
| 215 |
-
action=
|
| 216 |
-
observation=
|
| 217 |
reward=observation.reward,
|
| 218 |
done=observation.done,
|
| 219 |
-
step_count=state.step_count
|
| 220 |
)
|
| 221 |
-
|
| 222 |
# Update episode state
|
| 223 |
self.episode_state.episode_id = state.episode_id
|
| 224 |
self.episode_state.step_count = state.step_count
|
| 225 |
-
self.episode_state.current_observation =
|
| 226 |
self.episode_state.action_logs.append(action_log)
|
| 227 |
self.episode_state.is_reset = False
|
| 228 |
-
|
| 229 |
# Send state update
|
| 230 |
await self._send_state_update()
|
| 231 |
-
|
| 232 |
-
return
|
| 233 |
-
|
| 234 |
-
"reward": observation.reward,
|
| 235 |
-
"done": observation.done,
|
| 236 |
-
}
|
| 237 |
-
|
| 238 |
def get_state(self) -> Dict[str, Any]:
|
| 239 |
"""Get current environment state."""
|
| 240 |
-
state = self.env.state
|
| 241 |
-
return
|
| 242 |
-
|
| 243 |
-
def _deserialize_action(self, action_data: Dict[str, Any]) -> Action:
|
| 244 |
-
"""Convert JSON dict to Action instance."""
|
| 245 |
-
metadata = action_data.pop("metadata", {})
|
| 246 |
-
|
| 247 |
-
# Handle tensor fields that come from JSON as lists
|
| 248 |
-
processed_data = {}
|
| 249 |
-
for key, value in action_data.items():
|
| 250 |
-
if key == "tokens" and isinstance(value, (list, str)):
|
| 251 |
-
# Convert list or string to tensor
|
| 252 |
-
if isinstance(value, str):
|
| 253 |
-
# If it's a string, try to parse it as a list of numbers
|
| 254 |
-
try:
|
| 255 |
-
import json
|
| 256 |
-
value = json.loads(value)
|
| 257 |
-
except:
|
| 258 |
-
# If parsing fails, treat as empty list
|
| 259 |
-
value = []
|
| 260 |
-
if isinstance(value, list):
|
| 261 |
-
import torch
|
| 262 |
-
processed_data[key] = torch.tensor(value, dtype=torch.long)
|
| 263 |
-
else:
|
| 264 |
-
processed_data[key] = value
|
| 265 |
-
elif key == "action_id" and isinstance(value, str):
|
| 266 |
-
# Convert action_id from string to int
|
| 267 |
-
try:
|
| 268 |
-
processed_data[key] = int(value)
|
| 269 |
-
except ValueError:
|
| 270 |
-
# If conversion fails, keep original value
|
| 271 |
-
processed_data[key] = value
|
| 272 |
-
else:
|
| 273 |
-
processed_data[key] = value
|
| 274 |
-
|
| 275 |
-
action = self.action_cls(**processed_data)
|
| 276 |
-
action.metadata = metadata
|
| 277 |
-
return action
|
| 278 |
|
| 279 |
|
| 280 |
def create_web_interface_app(
|
|
@@ -285,41 +282,45 @@ def create_web_interface_app(
|
|
| 285 |
) -> FastAPI:
|
| 286 |
"""
|
| 287 |
Create a FastAPI application with web interface for the given environment.
|
| 288 |
-
|
| 289 |
Args:
|
| 290 |
env: The Environment instance to serve
|
| 291 |
action_cls: The Action subclass this environment expects
|
| 292 |
observation_cls: The Observation subclass this environment returns
|
| 293 |
env_name: Optional environment name for README loading
|
| 294 |
-
|
| 295 |
Returns:
|
| 296 |
FastAPI application instance with web interface
|
| 297 |
"""
|
| 298 |
from .http_server import create_fastapi_app
|
| 299 |
-
|
| 300 |
# Create the base environment app
|
| 301 |
app = create_fastapi_app(env, action_cls, observation_cls)
|
| 302 |
-
|
| 303 |
# Load environment metadata
|
| 304 |
metadata = load_environment_metadata(env, env_name)
|
| 305 |
-
|
| 306 |
# Create web interface manager
|
| 307 |
web_manager = WebInterfaceManager(env, action_cls, observation_cls, metadata)
|
| 308 |
-
|
| 309 |
# Add web interface routes
|
| 310 |
@app.get("/web", response_class=HTMLResponse)
|
| 311 |
async def web_interface():
|
| 312 |
"""Serve the web interface."""
|
| 313 |
return get_web_interface_html(action_cls, web_manager.metadata)
|
| 314 |
-
|
| 315 |
@app.get("/web/metadata")
|
| 316 |
async def web_metadata():
|
| 317 |
"""Get environment metadata."""
|
| 318 |
-
return
|
| 319 |
-
|
| 320 |
-
@app.websocket("/ws")
|
| 321 |
-
async def
|
| 322 |
-
"""WebSocket endpoint for real-time updates.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 323 |
await web_manager.connect_websocket(websocket)
|
| 324 |
try:
|
| 325 |
while True:
|
|
@@ -327,12 +328,12 @@ def create_web_interface_app(
|
|
| 327 |
await websocket.receive_text()
|
| 328 |
except WebSocketDisconnect:
|
| 329 |
await web_manager.disconnect_websocket(websocket)
|
| 330 |
-
|
| 331 |
@app.post("/web/reset")
|
| 332 |
async def web_reset():
|
| 333 |
"""Reset endpoint for web interface."""
|
| 334 |
return await web_manager.reset_environment()
|
| 335 |
-
|
| 336 |
@app.post("/web/step")
|
| 337 |
async def web_step(request: Dict[str, Any]):
|
| 338 |
"""Step endpoint for web interface."""
|
|
@@ -344,31 +345,37 @@ def create_web_interface_app(
|
|
| 344 |
action_data = {"tokens": action.tokens.tolist()}
|
| 345 |
else:
|
| 346 |
action_data = request.get("action", {})
|
| 347 |
-
|
| 348 |
return await web_manager.step_environment(action_data)
|
| 349 |
-
|
| 350 |
@app.get("/web/state")
|
| 351 |
async def web_state():
|
| 352 |
"""State endpoint for web interface."""
|
| 353 |
return web_manager.get_state()
|
| 354 |
-
|
| 355 |
return app
|
| 356 |
|
| 357 |
|
| 358 |
-
def get_web_interface_html(
|
|
|
|
|
|
|
| 359 |
"""Generate the HTML for the web interface."""
|
| 360 |
-
|
| 361 |
# Check if this is a chat environment by looking for tokens field
|
| 362 |
is_chat_env = False
|
| 363 |
-
if hasattr(action_cls,
|
| 364 |
-
for field_name, field_info in action_cls.
|
| 365 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
| 366 |
is_chat_env = True
|
| 367 |
break
|
| 368 |
-
|
| 369 |
# Get action fields for dynamic form generation with enhanced metadata
|
| 370 |
action_fields = _extract_action_fields(action_cls)
|
| 371 |
-
|
| 372 |
return f"""
|
| 373 |
<!DOCTYPE html>
|
| 374 |
<html lang="en">
|
|
@@ -971,7 +978,7 @@ def get_web_interface_html(action_cls: Type[Action], metadata: Optional[Environm
|
|
| 971 |
|
| 972 |
connectWebSocket() {{
|
| 973 |
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
|
| 974 |
-
const wsUrl = `${{protocol}}//${{window.location.host}}/ws`;
|
| 975 |
|
| 976 |
this.ws = new WebSocket(wsUrl);
|
| 977 |
|
|
@@ -1259,19 +1266,22 @@ def get_web_interface_html(action_cls: Type[Action], metadata: Optional[Environm
|
|
| 1259 |
</script>
|
| 1260 |
</body>
|
| 1261 |
</html>
|
| 1262 |
-
""".replace(
|
|
|
|
|
|
|
|
|
|
| 1263 |
|
| 1264 |
|
| 1265 |
-
def _generate_instructions_section(
|
|
|
|
|
|
|
| 1266 |
"""Generate the instructions section with environment documentation."""
|
| 1267 |
if not metadata or not metadata.readme_content:
|
| 1268 |
-
return
|
| 1269 |
-
|
| 1270 |
-
# Convert markdown to HTML (basic conversion)
|
| 1271 |
-
import re
|
| 1272 |
html_content = _markdown_to_html(metadata.readme_content)
|
| 1273 |
-
|
| 1274 |
-
return f
|
| 1275 |
<!-- Instructions Section -->
|
| 1276 |
<div class="instructions-section">
|
| 1277 |
<div class="instructions-header">
|
|
@@ -1284,194 +1294,178 @@ def _generate_instructions_section(metadata: Optional[EnvironmentMetadata]) -> s
|
|
| 1284 |
</div>
|
| 1285 |
</div>
|
| 1286 |
</div>
|
| 1287 |
-
|
| 1288 |
|
| 1289 |
|
| 1290 |
def _extract_action_fields(action_cls: Type[Action]) -> List[Dict[str, Any]]:
|
| 1291 |
"""Extract enhanced field metadata from Action class for form generation."""
|
| 1292 |
-
|
| 1293 |
-
|
| 1294 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1295 |
action_fields = []
|
| 1296 |
-
|
| 1297 |
-
|
| 1298 |
-
|
| 1299 |
-
for field_name, field_info in action_cls.__dataclass_fields__.items():
|
| 1300 |
-
if field_name == 'metadata':
|
| 1301 |
continue
|
| 1302 |
-
|
| 1303 |
-
|
| 1304 |
-
|
| 1305 |
-
|
| 1306 |
-
|
| 1307 |
-
|
| 1308 |
-
|
| 1309 |
-
|
| 1310 |
-
|
| 1311 |
-
|
| 1312 |
-
|
| 1313 |
-
|
| 1314 |
-
|
| 1315 |
-
|
| 1316 |
-
|
| 1317 |
-
|
| 1318 |
-
|
| 1319 |
-
|
| 1320 |
-
|
| 1321 |
-
|
| 1322 |
-
|
| 1323 |
-
|
| 1324 |
-
|
|
|
|
|
|
|
| 1325 |
return action_fields
|
| 1326 |
|
| 1327 |
|
| 1328 |
-
def
|
| 1329 |
-
|
| 1330 |
-
|
| 1331 |
-
|
| 1332 |
-
|
| 1333 |
-
metadata = {}
|
| 1334 |
-
|
| 1335 |
-
# Extract description from field docstring or annotation
|
| 1336 |
-
if hasattr(field_info, 'metadata') and field_info.metadata:
|
| 1337 |
-
# Check for custom metadata
|
| 1338 |
-
for meta in field_info.metadata:
|
| 1339 |
-
if isinstance(meta, dict):
|
| 1340 |
-
metadata.update(meta)
|
| 1341 |
-
|
| 1342 |
-
# Extract type information
|
| 1343 |
-
field_type = field_info.type
|
| 1344 |
-
origin = get_origin(field_type)
|
| 1345 |
-
|
| 1346 |
-
# Handle Literal types for dropdown choices
|
| 1347 |
-
if origin is Literal:
|
| 1348 |
-
args = get_args(field_type)
|
| 1349 |
-
metadata['choices'] = list(args)
|
| 1350 |
-
|
| 1351 |
-
# Handle Optional types
|
| 1352 |
-
if origin is Union:
|
| 1353 |
-
args = get_args(field_type)
|
| 1354 |
-
if len(args) == 2 and type(None) in args:
|
| 1355 |
-
# This is Optional[SomeType]
|
| 1356 |
-
non_none_type = args[0] if args[1] is type(None) else args[1]
|
| 1357 |
-
metadata['optional'] = True
|
| 1358 |
-
# Recursively check the non-None type for choices
|
| 1359 |
-
if get_origin(non_none_type) is Literal:
|
| 1360 |
-
metadata['choices'] = list(get_args(non_none_type))
|
| 1361 |
-
else:
|
| 1362 |
-
# Regular Union type
|
| 1363 |
-
metadata['choices'] = [str(arg) for arg in args if arg is not type(None)]
|
| 1364 |
-
|
| 1365 |
-
# Handle numeric constraints
|
| 1366 |
-
if field_type in (int, float):
|
| 1367 |
-
# Check for common constraint patterns in field name
|
| 1368 |
-
if 'count' in field_name.lower() or 'num' in field_name.lower():
|
| 1369 |
-
metadata['min_value'] = 0
|
| 1370 |
-
if 'id' in field_name.lower():
|
| 1371 |
-
metadata['min_value'] = 0
|
| 1372 |
-
|
| 1373 |
-
# Generate placeholder text
|
| 1374 |
-
if 'message' in field_name.lower():
|
| 1375 |
-
metadata['placeholder'] = f'Enter {field_name.replace("_", " ")}...'
|
| 1376 |
-
elif 'code' in field_name.lower():
|
| 1377 |
-
metadata['placeholder'] = 'Enter Python code here...'
|
| 1378 |
-
elif 'tokens' in field_name.lower():
|
| 1379 |
-
metadata['placeholder'] = 'Enter comma-separated token IDs (e.g., 1,2,3,4,5)'
|
| 1380 |
-
else:
|
| 1381 |
-
metadata['placeholder'] = f'Enter {field_name.replace("_", " ")}...'
|
| 1382 |
-
|
| 1383 |
-
# Generate help text based on field name and type
|
| 1384 |
-
if 'action_id' in field_name.lower():
|
| 1385 |
-
metadata['help_text'] = 'The action ID to execute in the environment'
|
| 1386 |
-
elif 'game_name' in field_name.lower():
|
| 1387 |
-
metadata['help_text'] = 'Name of the game or environment'
|
| 1388 |
-
elif 'tokens' in field_name.lower():
|
| 1389 |
-
metadata['help_text'] = 'Token IDs as a comma-separated list of integers'
|
| 1390 |
-
elif 'code' in field_name.lower():
|
| 1391 |
-
metadata['help_text'] = 'Python code to execute in the environment'
|
| 1392 |
-
elif 'message' in field_name.lower():
|
| 1393 |
-
metadata['help_text'] = 'Text message to send'
|
| 1394 |
-
|
| 1395 |
-
return metadata
|
| 1396 |
|
|
|
|
|
|
|
|
|
|
| 1397 |
|
| 1398 |
-
|
| 1399 |
-
"""Determine the appropriate HTML input type for a field type."""
|
| 1400 |
-
import typing
|
| 1401 |
-
from typing import get_origin, get_args, Literal, Union
|
| 1402 |
-
|
| 1403 |
-
# Handle direct types
|
| 1404 |
-
if field_type == str:
|
| 1405 |
-
return "text"
|
| 1406 |
-
elif field_type == int:
|
| 1407 |
-
return "number"
|
| 1408 |
-
elif field_type == float:
|
| 1409 |
-
return "number"
|
| 1410 |
-
elif field_type == bool:
|
| 1411 |
-
return "checkbox"
|
| 1412 |
-
|
| 1413 |
-
# Handle complex types
|
| 1414 |
-
origin = get_origin(field_type)
|
| 1415 |
-
|
| 1416 |
-
if origin is Literal:
|
| 1417 |
return "select"
|
| 1418 |
-
|
| 1419 |
-
|
| 1420 |
-
|
| 1421 |
-
|
| 1422 |
-
|
| 1423 |
-
|
| 1424 |
-
|
| 1425 |
-
|
| 1426 |
-
|
| 1427 |
-
|
| 1428 |
-
|
| 1429 |
-
|
| 1430 |
-
|
|
|
|
|
|
|
| 1431 |
return "text"
|
| 1432 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1433 |
|
| 1434 |
def _markdown_to_html(markdown: str) -> str:
|
| 1435 |
"""Convert basic markdown to HTML for README display."""
|
| 1436 |
import html
|
| 1437 |
import re
|
| 1438 |
-
|
| 1439 |
# Escape HTML first
|
| 1440 |
html_content = html.escape(markdown)
|
| 1441 |
-
|
| 1442 |
# Convert headers
|
| 1443 |
-
html_content = re.sub(
|
| 1444 |
-
|
| 1445 |
-
|
| 1446 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1447 |
# Convert code blocks
|
| 1448 |
-
html_content = re.sub(
|
| 1449 |
-
|
| 1450 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1451 |
# Convert bold and italic
|
| 1452 |
-
html_content = re.sub(r
|
| 1453 |
-
html_content = re.sub(r
|
| 1454 |
-
|
| 1455 |
# Convert lists
|
| 1456 |
-
html_content = re.sub(
|
| 1457 |
-
|
| 1458 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1459 |
# Convert line breaks
|
| 1460 |
-
html_content = html_content.replace(
|
| 1461 |
-
|
| 1462 |
return html_content
|
| 1463 |
|
| 1464 |
|
| 1465 |
-
def _generate_action_interface(
|
|
|
|
|
|
|
| 1466 |
"""Generate either a chat interface or action form based on environment type."""
|
| 1467 |
if is_chat_env:
|
| 1468 |
return _generate_chat_interface()
|
| 1469 |
else:
|
| 1470 |
return _generate_action_form(action_fields)
|
| 1471 |
|
|
|
|
| 1472 |
def _generate_chat_interface() -> str:
|
| 1473 |
"""Generate a chat-style interface for chat environments."""
|
| 1474 |
-
return
|
| 1475 |
<!-- Chat Interface -->
|
| 1476 |
<div class="chat-interface">
|
| 1477 |
<h3>Chat Interface</h3>
|
|
@@ -1495,11 +1489,12 @@ def _generate_chat_interface() -> str:
|
|
| 1495 |
</div>
|
| 1496 |
</div>
|
| 1497 |
</div>
|
| 1498 |
-
|
|
|
|
| 1499 |
|
| 1500 |
def _generate_action_form(action_fields: List[Dict[str, Any]]) -> str:
|
| 1501 |
"""Generate a traditional action form for non-chat environments."""
|
| 1502 |
-
return f
|
| 1503 |
<!-- Action Form -->
|
| 1504 |
<div class="action-form">
|
| 1505 |
<h3>Take Action</h3>
|
|
@@ -1508,106 +1503,119 @@ def _generate_action_form(action_fields: List[Dict[str, Any]]) -> str:
|
|
| 1508 |
<button type="submit" class="btn" id="step-btn">Step</button>
|
| 1509 |
</form>
|
| 1510 |
</div>
|
| 1511 |
-
|
|
|
|
| 1512 |
|
| 1513 |
def _generate_action_form_fields(action_fields: List[Dict[str, Any]]) -> str:
|
| 1514 |
"""Generate HTML form fields for action input with enhanced metadata."""
|
| 1515 |
if not action_fields:
|
| 1516 |
-
return
|
| 1517 |
-
|
| 1518 |
fields_html = []
|
| 1519 |
for field in action_fields:
|
| 1520 |
field_html = _generate_single_field(field)
|
| 1521 |
fields_html.append(field_html)
|
| 1522 |
-
|
| 1523 |
-
return
|
| 1524 |
|
| 1525 |
|
| 1526 |
def _generate_single_field(field: Dict[str, Any]) -> str:
|
| 1527 |
"""Generate HTML for a single form field with enhanced metadata."""
|
| 1528 |
-
field_name = field[
|
| 1529 |
-
field_type = field[
|
| 1530 |
-
required = field[
|
| 1531 |
-
placeholder = field.get(
|
| 1532 |
-
help_text = field.get(
|
| 1533 |
-
choices = field.get(
|
| 1534 |
-
min_value = field.get(
|
| 1535 |
-
max_value = field.get(
|
| 1536 |
-
default_value = field.get(
|
| 1537 |
-
|
|
|
|
|
|
|
|
|
|
| 1538 |
# Build label with required indicator
|
| 1539 |
-
label_text = field_name.replace(
|
| 1540 |
if required:
|
| 1541 |
label_text += ' <span style="color: red;">*</span>'
|
| 1542 |
-
|
| 1543 |
# Build input attributes
|
| 1544 |
input_attrs = []
|
| 1545 |
if required:
|
| 1546 |
-
input_attrs.append(
|
| 1547 |
if placeholder:
|
| 1548 |
input_attrs.append(f'placeholder="{placeholder}"')
|
| 1549 |
if min_value is not None:
|
| 1550 |
input_attrs.append(f'min="{min_value}"')
|
| 1551 |
if max_value is not None:
|
| 1552 |
input_attrs.append(f'max="{max_value}"')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1553 |
if default_value is not None:
|
| 1554 |
input_attrs.append(f'value="{default_value}"')
|
| 1555 |
-
|
| 1556 |
-
attrs_str =
|
| 1557 |
-
|
| 1558 |
-
if field_type ==
|
|
|
|
| 1559 |
return f'''
|
| 1560 |
<div class="form-group">
|
| 1561 |
<label>
|
| 1562 |
-
<input type="checkbox" name="{field_name}" value="true" {
|
| 1563 |
{label_text}
|
| 1564 |
</label>
|
| 1565 |
-
{f'<small class="help-text">{help_text}</small>' if help_text else
|
| 1566 |
</div>
|
| 1567 |
'''
|
| 1568 |
-
|
| 1569 |
-
elif field_type ==
|
| 1570 |
options_html = []
|
| 1571 |
if not required:
|
| 1572 |
options_html.append(f'<option value="">-- Select {label_text} --</option>')
|
| 1573 |
-
|
| 1574 |
for choice in choices:
|
| 1575 |
-
selected =
|
| 1576 |
-
options_html.append(
|
| 1577 |
-
|
|
|
|
|
|
|
| 1578 |
return f'''
|
| 1579 |
<div class="form-group">
|
| 1580 |
<label for="{field_name}">{label_text}:</label>
|
| 1581 |
<select name="{field_name}" id="{field_name}" {attrs_str}>
|
| 1582 |
-
{
|
| 1583 |
</select>
|
| 1584 |
-
{f'<small class="help-text">{help_text}</small>' if help_text else
|
| 1585 |
</div>
|
| 1586 |
'''
|
| 1587 |
-
|
| 1588 |
-
elif field_type ==
|
| 1589 |
return f'''
|
| 1590 |
<div class="form-group">
|
| 1591 |
<label for="{field_name}">{label_text} (comma-separated integers):</label>
|
| 1592 |
<input type="text" name="{field_name}" id="{field_name}" {attrs_str}>
|
| 1593 |
-
<small class="help-text">{help_text or
|
| 1594 |
</div>
|
| 1595 |
'''
|
| 1596 |
-
|
| 1597 |
-
elif field_type ==
|
| 1598 |
return f'''
|
| 1599 |
<div class="form-group">
|
| 1600 |
<label for="{field_name}">{label_text}:</label>
|
| 1601 |
-
<textarea name="{field_name}" id="{field_name}" rows="3" {attrs_str}></textarea>
|
| 1602 |
-
{f'<small class="help-text">{help_text}</small>' if help_text else
|
| 1603 |
</div>
|
| 1604 |
'''
|
| 1605 |
-
|
| 1606 |
else:
|
| 1607 |
return f'''
|
| 1608 |
<div class="form-group">
|
| 1609 |
<label for="{field_name}">{label_text}:</label>
|
| 1610 |
<input type="{field_type}" name="{field_name}" id="{field_name}" {attrs_str}>
|
| 1611 |
-
{f'<small class="help-text">{help_text}</small>' if help_text else
|
| 1612 |
</div>
|
| 1613 |
'''
|
|
|
|
| 13 |
|
| 14 |
from __future__ import annotations
|
| 15 |
|
| 16 |
+
import asyncio
|
| 17 |
import json
|
| 18 |
+
from concurrent.futures import ThreadPoolExecutor
|
|
|
|
| 19 |
from typing import Any, Dict, List, Optional, Type
|
| 20 |
from datetime import datetime
|
| 21 |
|
| 22 |
+
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
| 23 |
+
from fastapi.responses import HTMLResponse
|
| 24 |
+
from pydantic import BaseModel, Field, ConfigDict
|
|
|
|
| 25 |
|
| 26 |
from .interfaces import Environment
|
| 27 |
+
from .serialization import (
|
| 28 |
+
deserialize_action_with_preprocessing,
|
| 29 |
+
serialize_observation,
|
| 30 |
+
)
|
| 31 |
from .types import Action, Observation, State, EnvironmentMetadata
|
| 32 |
|
| 33 |
|
| 34 |
+
def load_environment_metadata(
|
| 35 |
+
env: Environment, env_name: Optional[str] = None
|
| 36 |
+
) -> EnvironmentMetadata:
|
| 37 |
"""
|
| 38 |
Load environment metadata including README content.
|
| 39 |
+
|
| 40 |
Args:
|
| 41 |
env: The environment instance
|
| 42 |
env_name: Optional environment name for README file lookup
|
| 43 |
+
|
| 44 |
Returns:
|
| 45 |
EnvironmentMetadata with loaded information
|
| 46 |
"""
|
| 47 |
# Try to get metadata from environment if it has a method for it
|
| 48 |
+
if hasattr(env, "get_metadata"):
|
| 49 |
return env.get_metadata()
|
| 50 |
+
|
| 51 |
# Default metadata
|
| 52 |
metadata = EnvironmentMetadata(
|
| 53 |
name=env_name or env.__class__.__name__,
|
| 54 |
description=f"{env.__class__.__name__} environment",
|
| 55 |
+
version="1.0.0",
|
| 56 |
)
|
| 57 |
+
|
| 58 |
# Try to load README from file system
|
| 59 |
readme_content = _load_readme_from_filesystem(env_name)
|
| 60 |
if readme_content:
|
| 61 |
metadata.readme_content = readme_content
|
| 62 |
+
|
| 63 |
return metadata
|
| 64 |
|
| 65 |
|
| 66 |
def _load_readme_from_filesystem(env_name: Optional[str]) -> Optional[str]:
|
| 67 |
"""
|
| 68 |
Load README content from the filesystem.
|
| 69 |
+
|
| 70 |
Tries multiple locations:
|
| 71 |
1. Container filesystem: /app/README.md
|
| 72 |
2. Local development: src/envs/{env_name}/README.md
|
|
|
|
| 74 |
"""
|
| 75 |
import os
|
| 76 |
from pathlib import Path
|
| 77 |
+
|
| 78 |
# Try container filesystem first
|
| 79 |
container_readme = Path("/app/README.md")
|
| 80 |
if container_readme.exists():
|
| 81 |
try:
|
| 82 |
+
return container_readme.read_text(encoding="utf-8")
|
| 83 |
except Exception:
|
| 84 |
pass
|
| 85 |
+
|
| 86 |
# Try environment variable path
|
| 87 |
custom_path = os.environ.get("ENV_README_PATH")
|
| 88 |
if custom_path and Path(custom_path).exists():
|
| 89 |
try:
|
| 90 |
+
return Path(custom_path).read_text(encoding="utf-8")
|
| 91 |
except Exception:
|
| 92 |
pass
|
| 93 |
+
|
| 94 |
# Try local development path
|
| 95 |
if env_name:
|
| 96 |
local_readme = Path(f"src/envs/{env_name}/README.md")
|
| 97 |
if local_readme.exists():
|
| 98 |
try:
|
| 99 |
+
return local_readme.read_text(encoding="utf-8")
|
| 100 |
except Exception:
|
| 101 |
pass
|
| 102 |
+
|
| 103 |
return None
|
| 104 |
|
| 105 |
|
| 106 |
+
class ActionLog(BaseModel):
|
|
|
|
| 107 |
"""Log entry for an action taken."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
|
| 109 |
+
model_config = ConfigDict(extra="forbid", validate_assignment=True)
|
| 110 |
+
|
| 111 |
+
timestamp: str = Field(description="Timestamp when action was taken")
|
| 112 |
+
action: Dict[str, Any] = Field(description="Action that was taken")
|
| 113 |
+
observation: Dict[str, Any] = Field(description="Observation returned from action")
|
| 114 |
+
reward: Optional[float] = Field(
|
| 115 |
+
default=None, description="Reward received from action"
|
| 116 |
+
)
|
| 117 |
+
done: bool = Field(description="Whether the episode is done after this action")
|
| 118 |
+
step_count: int = Field(description="Step count when this action was taken")
|
| 119 |
|
| 120 |
+
|
| 121 |
+
class EpisodeState(BaseModel):
|
| 122 |
"""Current episode state for the web interface."""
|
| 123 |
+
|
| 124 |
+
model_config = ConfigDict(extra="forbid", validate_assignment=True)
|
| 125 |
+
|
| 126 |
+
episode_id: Optional[str] = Field(default=None, description="Current episode ID")
|
| 127 |
+
step_count: int = Field(description="Current step count in episode")
|
| 128 |
+
current_observation: Optional[Dict[str, Any]] = Field(
|
| 129 |
+
default=None, description="Current observation"
|
| 130 |
+
)
|
| 131 |
+
action_logs: List[ActionLog] = Field(
|
| 132 |
+
default_factory=list, description="List of action logs"
|
| 133 |
+
)
|
| 134 |
+
is_reset: bool = Field(
|
| 135 |
+
default=True, description="Whether the episode has been reset"
|
| 136 |
+
)
|
| 137 |
|
| 138 |
|
| 139 |
class WebInterfaceManager:
|
| 140 |
"""Manages the web interface for an environment."""
|
| 141 |
+
|
| 142 |
def __init__(
|
| 143 |
self,
|
| 144 |
env: Environment,
|
|
|
|
| 151 |
self.observation_cls = observation_cls
|
| 152 |
self.metadata = metadata or EnvironmentMetadata(
|
| 153 |
name=env.__class__.__name__,
|
| 154 |
+
description=f"{env.__class__.__name__} environment",
|
| 155 |
)
|
| 156 |
self.episode_state = EpisodeState(
|
| 157 |
episode_id=None,
|
| 158 |
step_count=0,
|
| 159 |
current_observation=None,
|
| 160 |
+
action_logs=[],
|
| 161 |
)
|
| 162 |
self.connected_clients: List[WebSocket] = []
|
| 163 |
+
# Thread pool for running sync code (e.g., Playwright sync API) in async context
|
| 164 |
+
self._executor = ThreadPoolExecutor(max_workers=1)
|
| 165 |
+
|
| 166 |
+
async def _run_sync_in_thread_pool(self, func, *args, **kwargs):
|
| 167 |
+
"""Run a synchronous function in the thread pool executor.
|
| 168 |
+
|
| 169 |
+
This is needed for environments using sync libraries (e.g., Playwright sync API)
|
| 170 |
+
that cannot be called directly from an async context.
|
| 171 |
+
"""
|
| 172 |
+
loop = asyncio.get_event_loop()
|
| 173 |
+
return await loop.run_in_executor(self._executor, lambda: func(*args, **kwargs))
|
| 174 |
+
|
| 175 |
async def connect_websocket(self, websocket: WebSocket):
|
| 176 |
"""Connect a new WebSocket client."""
|
| 177 |
await websocket.accept()
|
| 178 |
self.connected_clients.append(websocket)
|
| 179 |
+
|
| 180 |
# Send current state to the new client
|
| 181 |
await self._send_state_update()
|
| 182 |
+
|
| 183 |
async def disconnect_websocket(self, websocket: WebSocket):
|
| 184 |
"""Disconnect a WebSocket client."""
|
| 185 |
if websocket in self.connected_clients:
|
| 186 |
self.connected_clients.remove(websocket)
|
| 187 |
+
|
| 188 |
async def _send_state_update(self):
|
| 189 |
"""Send current state to all connected clients."""
|
| 190 |
if not self.connected_clients:
|
| 191 |
return
|
| 192 |
+
|
| 193 |
state_data = {
|
| 194 |
"type": "state_update",
|
| 195 |
+
"episode_state": self.episode_state.model_dump(),
|
| 196 |
}
|
| 197 |
+
|
| 198 |
# Send to all connected clients
|
| 199 |
disconnected_clients = []
|
| 200 |
for client in self.connected_clients:
|
| 201 |
try:
|
| 202 |
await client.send_text(json.dumps(state_data))
|
| 203 |
+
except Exception:
|
| 204 |
disconnected_clients.append(client)
|
| 205 |
+
|
| 206 |
# Remove disconnected clients
|
| 207 |
for client in disconnected_clients:
|
| 208 |
self.connected_clients.remove(client)
|
| 209 |
+
|
| 210 |
async def reset_environment(self) -> Dict[str, Any]:
|
| 211 |
"""Reset the environment and update state."""
|
| 212 |
+
# Run sync reset in thread pool to avoid blocking event loop
|
| 213 |
+
# and to support environments using sync libraries (e.g., Playwright)
|
| 214 |
+
observation: Observation = await self._run_sync_in_thread_pool(self.env.reset)
|
| 215 |
+
state: State = self.env.state
|
| 216 |
+
|
| 217 |
+
# Serialize observation once using shared utility
|
| 218 |
+
serialized = serialize_observation(observation)
|
| 219 |
+
|
| 220 |
# Update episode state
|
| 221 |
self.episode_state.episode_id = state.episode_id
|
| 222 |
self.episode_state.step_count = 0
|
| 223 |
+
self.episode_state.current_observation = serialized["observation"]
|
| 224 |
self.episode_state.action_logs = []
|
| 225 |
self.episode_state.is_reset = True
|
| 226 |
+
|
| 227 |
# Send state update
|
| 228 |
await self._send_state_update()
|
| 229 |
+
|
| 230 |
+
return serialized
|
| 231 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
| 232 |
async def step_environment(self, action_data: Dict[str, Any]) -> Dict[str, Any]:
|
| 233 |
"""Execute a step in the environment and update state."""
|
| 234 |
+
# Deserialize action with preprocessing for web interface special cases
|
| 235 |
+
action: Action = deserialize_action_with_preprocessing(
|
| 236 |
+
action_data, self.action_cls
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
# Run sync step in thread pool to avoid blocking event loop
|
| 240 |
+
# and to support environments using sync libraries (e.g., Playwright)
|
| 241 |
+
observation: Observation = await self._run_sync_in_thread_pool(
|
| 242 |
+
self.env.step, action
|
| 243 |
+
)
|
| 244 |
+
state: State = self.env.state
|
| 245 |
+
|
| 246 |
+
# Serialize observation once using shared utility
|
| 247 |
+
serialized = serialize_observation(observation)
|
| 248 |
+
|
| 249 |
# Create action log
|
| 250 |
action_log = ActionLog(
|
| 251 |
timestamp=datetime.now().isoformat(),
|
| 252 |
+
action=action.model_dump(exclude={"metadata"}),
|
| 253 |
+
observation=serialized["observation"],
|
| 254 |
reward=observation.reward,
|
| 255 |
done=observation.done,
|
| 256 |
+
step_count=state.step_count,
|
| 257 |
)
|
| 258 |
+
|
| 259 |
# Update episode state
|
| 260 |
self.episode_state.episode_id = state.episode_id
|
| 261 |
self.episode_state.step_count = state.step_count
|
| 262 |
+
self.episode_state.current_observation = serialized["observation"]
|
| 263 |
self.episode_state.action_logs.append(action_log)
|
| 264 |
self.episode_state.is_reset = False
|
| 265 |
+
|
| 266 |
# Send state update
|
| 267 |
await self._send_state_update()
|
| 268 |
+
|
| 269 |
+
return serialized
|
| 270 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
| 271 |
def get_state(self) -> Dict[str, Any]:
|
| 272 |
"""Get current environment state."""
|
| 273 |
+
state: State = self.env.state
|
| 274 |
+
return state.model_dump()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 275 |
|
| 276 |
|
| 277 |
def create_web_interface_app(
|
|
|
|
| 282 |
) -> FastAPI:
|
| 283 |
"""
|
| 284 |
Create a FastAPI application with web interface for the given environment.
|
| 285 |
+
|
| 286 |
Args:
|
| 287 |
env: The Environment instance to serve
|
| 288 |
action_cls: The Action subclass this environment expects
|
| 289 |
observation_cls: The Observation subclass this environment returns
|
| 290 |
env_name: Optional environment name for README loading
|
| 291 |
+
|
| 292 |
Returns:
|
| 293 |
FastAPI application instance with web interface
|
| 294 |
"""
|
| 295 |
from .http_server import create_fastapi_app
|
| 296 |
+
|
| 297 |
# Create the base environment app
|
| 298 |
app = create_fastapi_app(env, action_cls, observation_cls)
|
| 299 |
+
|
| 300 |
# Load environment metadata
|
| 301 |
metadata = load_environment_metadata(env, env_name)
|
| 302 |
+
|
| 303 |
# Create web interface manager
|
| 304 |
web_manager = WebInterfaceManager(env, action_cls, observation_cls, metadata)
|
| 305 |
+
|
| 306 |
# Add web interface routes
|
| 307 |
@app.get("/web", response_class=HTMLResponse)
|
| 308 |
async def web_interface():
|
| 309 |
"""Serve the web interface."""
|
| 310 |
return get_web_interface_html(action_cls, web_manager.metadata)
|
| 311 |
+
|
| 312 |
@app.get("/web/metadata")
|
| 313 |
async def web_metadata():
|
| 314 |
"""Get environment metadata."""
|
| 315 |
+
return web_manager.metadata.model_dump()
|
| 316 |
+
|
| 317 |
+
@app.websocket("/ws/ui")
|
| 318 |
+
async def websocket_ui_endpoint(websocket: WebSocket):
|
| 319 |
+
"""WebSocket endpoint for web UI real-time updates.
|
| 320 |
+
|
| 321 |
+
Note: Uses /ws/ui to avoid conflict with /ws in http_server.py
|
| 322 |
+
which is used for concurrent environment sessions.
|
| 323 |
+
"""
|
| 324 |
await web_manager.connect_websocket(websocket)
|
| 325 |
try:
|
| 326 |
while True:
|
|
|
|
| 328 |
await websocket.receive_text()
|
| 329 |
except WebSocketDisconnect:
|
| 330 |
await web_manager.disconnect_websocket(websocket)
|
| 331 |
+
|
| 332 |
@app.post("/web/reset")
|
| 333 |
async def web_reset():
|
| 334 |
"""Reset endpoint for web interface."""
|
| 335 |
return await web_manager.reset_environment()
|
| 336 |
+
|
| 337 |
@app.post("/web/step")
|
| 338 |
async def web_step(request: Dict[str, Any]):
|
| 339 |
"""Step endpoint for web interface."""
|
|
|
|
| 345 |
action_data = {"tokens": action.tokens.tolist()}
|
| 346 |
else:
|
| 347 |
action_data = request.get("action", {})
|
| 348 |
+
|
| 349 |
return await web_manager.step_environment(action_data)
|
| 350 |
+
|
| 351 |
@app.get("/web/state")
|
| 352 |
async def web_state():
|
| 353 |
"""State endpoint for web interface."""
|
| 354 |
return web_manager.get_state()
|
| 355 |
+
|
| 356 |
return app
|
| 357 |
|
| 358 |
|
| 359 |
+
def get_web_interface_html(
|
| 360 |
+
action_cls: Type[Action], metadata: Optional[EnvironmentMetadata] = None
|
| 361 |
+
) -> str:
|
| 362 |
"""Generate the HTML for the web interface."""
|
| 363 |
+
|
| 364 |
# Check if this is a chat environment by looking for tokens field
|
| 365 |
is_chat_env = False
|
| 366 |
+
if hasattr(action_cls, "model_fields"):
|
| 367 |
+
for field_name, field_info in action_cls.model_fields.items():
|
| 368 |
+
if (
|
| 369 |
+
field_name == "tokens"
|
| 370 |
+
and hasattr(field_info.annotation, "__name__")
|
| 371 |
+
and "Tensor" in field_info.annotation.__name__
|
| 372 |
+
):
|
| 373 |
is_chat_env = True
|
| 374 |
break
|
| 375 |
+
|
| 376 |
# Get action fields for dynamic form generation with enhanced metadata
|
| 377 |
action_fields = _extract_action_fields(action_cls)
|
| 378 |
+
|
| 379 |
return f"""
|
| 380 |
<!DOCTYPE html>
|
| 381 |
<html lang="en">
|
|
|
|
| 978 |
|
| 979 |
connectWebSocket() {{
|
| 980 |
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
|
| 981 |
+
const wsUrl = `${{protocol}}//${{window.location.host}}/ws/ui`;
|
| 982 |
|
| 983 |
this.ws = new WebSocket(wsUrl);
|
| 984 |
|
|
|
|
| 1266 |
</script>
|
| 1267 |
</body>
|
| 1268 |
</html>
|
| 1269 |
+
""".replace(
|
| 1270 |
+
"{_generate_action_form_fields(action_fields)}",
|
| 1271 |
+
_generate_action_form_fields(action_fields),
|
| 1272 |
+
)
|
| 1273 |
|
| 1274 |
|
| 1275 |
+
def _generate_instructions_section(
|
| 1276 |
+
metadata: Optional[EnvironmentMetadata],
|
| 1277 |
+
) -> str:
|
| 1278 |
"""Generate the instructions section with environment documentation."""
|
| 1279 |
if not metadata or not metadata.readme_content:
|
| 1280 |
+
return ""
|
| 1281 |
+
|
|
|
|
|
|
|
| 1282 |
html_content = _markdown_to_html(metadata.readme_content)
|
| 1283 |
+
|
| 1284 |
+
return f"""
|
| 1285 |
<!-- Instructions Section -->
|
| 1286 |
<div class="instructions-section">
|
| 1287 |
<div class="instructions-header">
|
|
|
|
| 1294 |
</div>
|
| 1295 |
</div>
|
| 1296 |
</div>
|
| 1297 |
+
"""
|
| 1298 |
|
| 1299 |
|
| 1300 |
def _extract_action_fields(action_cls: Type[Action]) -> List[Dict[str, Any]]:
|
| 1301 |
"""Extract enhanced field metadata from Action class for form generation."""
|
| 1302 |
+
# Use Pydantic's JSON schema generation for robust metadata extraction
|
| 1303 |
+
try:
|
| 1304 |
+
schema = action_cls.model_json_schema()
|
| 1305 |
+
except AttributeError:
|
| 1306 |
+
# Fallback for non-Pydantic v2 models or if something goes wrong
|
| 1307 |
+
return []
|
| 1308 |
+
|
| 1309 |
+
properties = schema.get("properties", {})
|
| 1310 |
+
required_fields = schema.get("required", [])
|
| 1311 |
+
|
| 1312 |
action_fields = []
|
| 1313 |
+
|
| 1314 |
+
for field_name, field_info in properties.items():
|
| 1315 |
+
if field_name == "metadata":
|
|
|
|
|
|
|
| 1316 |
continue
|
| 1317 |
+
|
| 1318 |
+
# JSON schema "type" can be a string or list/undefined
|
| 1319 |
+
# Determine our internal input type
|
| 1320 |
+
input_type = _determine_input_type_from_schema(field_info, field_name)
|
| 1321 |
+
|
| 1322 |
+
is_required = field_name in required_fields
|
| 1323 |
+
|
| 1324 |
+
action_fields.append(
|
| 1325 |
+
{
|
| 1326 |
+
"name": field_name,
|
| 1327 |
+
"type": input_type,
|
| 1328 |
+
"required": is_required,
|
| 1329 |
+
"description": field_info.get("description", ""),
|
| 1330 |
+
"default_value": field_info.get("default"),
|
| 1331 |
+
"choices": field_info.get("enum"),
|
| 1332 |
+
"min_value": field_info.get("minimum"),
|
| 1333 |
+
"max_value": field_info.get("maximum"),
|
| 1334 |
+
"min_length": field_info.get("minLength"),
|
| 1335 |
+
"max_length": field_info.get("maxLength"),
|
| 1336 |
+
"pattern": field_info.get("pattern"),
|
| 1337 |
+
"placeholder": _generate_placeholder(field_name, field_info),
|
| 1338 |
+
"help_text": _generate_help_text(field_name, field_info),
|
| 1339 |
+
}
|
| 1340 |
+
)
|
| 1341 |
+
|
| 1342 |
return action_fields
|
| 1343 |
|
| 1344 |
|
| 1345 |
+
def _determine_input_type_from_schema(
|
| 1346 |
+
field_info: Dict[str, Any], field_name: str
|
| 1347 |
+
) -> str:
|
| 1348 |
+
"""Determine the appropriate HTML input type from JSON schema info."""
|
| 1349 |
+
schema_type = field_info.get("type")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1350 |
|
| 1351 |
+
# Check for specific tensor field convention
|
| 1352 |
+
if "tokens" in field_name.lower():
|
| 1353 |
+
return "tensor"
|
| 1354 |
|
| 1355 |
+
if "enum" in field_info:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1356 |
return "select"
|
| 1357 |
+
|
| 1358 |
+
if schema_type == "boolean":
|
| 1359 |
+
return "checkbox"
|
| 1360 |
+
|
| 1361 |
+
if schema_type == "integer" or schema_type == "number":
|
| 1362 |
+
return "number"
|
| 1363 |
+
|
| 1364 |
+
if schema_type == "string":
|
| 1365 |
+
# Check if it should be a textarea
|
| 1366 |
+
if (
|
| 1367 |
+
field_info.get("maxLength", 0) > 100
|
| 1368 |
+
or "message" in field_name.lower()
|
| 1369 |
+
or "code" in field_name.lower()
|
| 1370 |
+
):
|
| 1371 |
+
return "textarea"
|
| 1372 |
return "text"
|
| 1373 |
|
| 1374 |
+
# Default fallback
|
| 1375 |
+
return "text"
|
| 1376 |
+
|
| 1377 |
+
|
| 1378 |
+
def _generate_placeholder(field_name: str, field_info: Dict[str, Any]) -> str:
|
| 1379 |
+
"""Generate placeholder text."""
|
| 1380 |
+
if "message" in field_name.lower():
|
| 1381 |
+
return f"Enter {field_name.replace('_', ' ')}..."
|
| 1382 |
+
elif "code" in field_name.lower():
|
| 1383 |
+
return "Enter Python code here..."
|
| 1384 |
+
elif "tokens" in field_name.lower():
|
| 1385 |
+
return "Enter comma-separated token IDs (e.g., 1,2,3,4,5)"
|
| 1386 |
+
else:
|
| 1387 |
+
return f"Enter {field_name.replace('_', ' ')}..."
|
| 1388 |
+
|
| 1389 |
+
|
| 1390 |
+
def _generate_help_text(field_name: str, field_info: Dict[str, Any]) -> str:
|
| 1391 |
+
"""Generate help text."""
|
| 1392 |
+
description = field_info.get("description", "")
|
| 1393 |
+
if description:
|
| 1394 |
+
return description
|
| 1395 |
+
|
| 1396 |
+
if "action_id" in field_name.lower():
|
| 1397 |
+
return "The action ID to execute in environment"
|
| 1398 |
+
elif "game_name" in field_name.lower():
|
| 1399 |
+
return "Name of game or environment"
|
| 1400 |
+
elif "tokens" in field_name.lower():
|
| 1401 |
+
return "Token IDs as a comma-separated list of integers"
|
| 1402 |
+
elif "code" in field_name.lower():
|
| 1403 |
+
return "Python code to execute in environment"
|
| 1404 |
+
elif "message" in field_name.lower():
|
| 1405 |
+
return "Text message to send"
|
| 1406 |
+
|
| 1407 |
+
return ""
|
| 1408 |
+
|
| 1409 |
|
| 1410 |
def _markdown_to_html(markdown: str) -> str:
|
| 1411 |
"""Convert basic markdown to HTML for README display."""
|
| 1412 |
import html
|
| 1413 |
import re
|
| 1414 |
+
|
| 1415 |
# Escape HTML first
|
| 1416 |
html_content = html.escape(markdown)
|
| 1417 |
+
|
| 1418 |
# Convert headers
|
| 1419 |
+
html_content = re.sub(
|
| 1420 |
+
r"^# (.*?)$", r"<h1>\1</h1>", html_content, flags=re.MULTILINE
|
| 1421 |
+
)
|
| 1422 |
+
html_content = re.sub(
|
| 1423 |
+
r"^## (.*?)$", r"<h2>\1</h2>", html_content, flags=re.MULTILINE
|
| 1424 |
+
)
|
| 1425 |
+
html_content = re.sub(
|
| 1426 |
+
r"^### (.*?)$", r"<h3>\1</h3>", html_content, flags=re.MULTILINE
|
| 1427 |
+
)
|
| 1428 |
+
|
| 1429 |
# Convert code blocks
|
| 1430 |
+
html_content = re.sub(
|
| 1431 |
+
r"```(.*?)\n(.*?)\n```",
|
| 1432 |
+
r"<pre><code>\2</code></pre>",
|
| 1433 |
+
html_content,
|
| 1434 |
+
flags=re.DOTALL,
|
| 1435 |
+
)
|
| 1436 |
+
html_content = re.sub(r"`([^`]+)`", r"<code>\1</code>", html_content)
|
| 1437 |
+
|
| 1438 |
# Convert bold and italic
|
| 1439 |
+
html_content = re.sub(r"\*\*(.*?)\*\*", r"<strong>\1</strong>", html_content)
|
| 1440 |
+
html_content = re.sub(r"\*(.*?)\*", r"<em>\1</em>", html_content)
|
| 1441 |
+
|
| 1442 |
# Convert lists
|
| 1443 |
+
html_content = re.sub(
|
| 1444 |
+
r"^- (.*?)$", r"<li>\1</li>", html_content, flags=re.MULTILINE
|
| 1445 |
+
)
|
| 1446 |
+
html_content = re.sub(
|
| 1447 |
+
r"(<li>.*</li>)", r"<ul>\1</ul>", html_content, flags=re.DOTALL
|
| 1448 |
+
)
|
| 1449 |
+
|
| 1450 |
# Convert line breaks
|
| 1451 |
+
html_content = html_content.replace("\n", "<br>")
|
| 1452 |
+
|
| 1453 |
return html_content
|
| 1454 |
|
| 1455 |
|
| 1456 |
+
def _generate_action_interface(
|
| 1457 |
+
action_fields: List[Dict[str, Any]], is_chat_env: bool
|
| 1458 |
+
) -> str:
|
| 1459 |
"""Generate either a chat interface or action form based on environment type."""
|
| 1460 |
if is_chat_env:
|
| 1461 |
return _generate_chat_interface()
|
| 1462 |
else:
|
| 1463 |
return _generate_action_form(action_fields)
|
| 1464 |
|
| 1465 |
+
|
| 1466 |
def _generate_chat_interface() -> str:
|
| 1467 |
"""Generate a chat-style interface for chat environments."""
|
| 1468 |
+
return """
|
| 1469 |
<!-- Chat Interface -->
|
| 1470 |
<div class="chat-interface">
|
| 1471 |
<h3>Chat Interface</h3>
|
|
|
|
| 1489 |
</div>
|
| 1490 |
</div>
|
| 1491 |
</div>
|
| 1492 |
+
"""
|
| 1493 |
+
|
| 1494 |
|
| 1495 |
def _generate_action_form(action_fields: List[Dict[str, Any]]) -> str:
|
| 1496 |
"""Generate a traditional action form for non-chat environments."""
|
| 1497 |
+
return f"""
|
| 1498 |
<!-- Action Form -->
|
| 1499 |
<div class="action-form">
|
| 1500 |
<h3>Take Action</h3>
|
|
|
|
| 1503 |
<button type="submit" class="btn" id="step-btn">Step</button>
|
| 1504 |
</form>
|
| 1505 |
</div>
|
| 1506 |
+
"""
|
| 1507 |
+
|
| 1508 |
|
| 1509 |
def _generate_action_form_fields(action_fields: List[Dict[str, Any]]) -> str:
|
| 1510 |
"""Generate HTML form fields for action input with enhanced metadata."""
|
| 1511 |
if not action_fields:
|
| 1512 |
+
return "<p>No action fields available</p>"
|
| 1513 |
+
|
| 1514 |
fields_html = []
|
| 1515 |
for field in action_fields:
|
| 1516 |
field_html = _generate_single_field(field)
|
| 1517 |
fields_html.append(field_html)
|
| 1518 |
+
|
| 1519 |
+
return "\n".join(fields_html)
|
| 1520 |
|
| 1521 |
|
| 1522 |
def _generate_single_field(field: Dict[str, Any]) -> str:
|
| 1523 |
"""Generate HTML for a single form field with enhanced metadata."""
|
| 1524 |
+
field_name = field["name"]
|
| 1525 |
+
field_type = field["type"]
|
| 1526 |
+
required = field["required"]
|
| 1527 |
+
placeholder = field.get("placeholder", "")
|
| 1528 |
+
help_text = field.get("help_text", "")
|
| 1529 |
+
choices = field.get("choices", [])
|
| 1530 |
+
min_value = field.get("min_value")
|
| 1531 |
+
max_value = field.get("max_value")
|
| 1532 |
+
default_value = field.get("default_value")
|
| 1533 |
+
min_length = field.get("min_length")
|
| 1534 |
+
max_length = field.get("max_length")
|
| 1535 |
+
pattern = field.get("pattern")
|
| 1536 |
+
|
| 1537 |
# Build label with required indicator
|
| 1538 |
+
label_text = field_name.replace("_", " ").title()
|
| 1539 |
if required:
|
| 1540 |
label_text += ' <span style="color: red;">*</span>'
|
| 1541 |
+
|
| 1542 |
# Build input attributes
|
| 1543 |
input_attrs = []
|
| 1544 |
if required:
|
| 1545 |
+
input_attrs.append("required")
|
| 1546 |
if placeholder:
|
| 1547 |
input_attrs.append(f'placeholder="{placeholder}"')
|
| 1548 |
if min_value is not None:
|
| 1549 |
input_attrs.append(f'min="{min_value}"')
|
| 1550 |
if max_value is not None:
|
| 1551 |
input_attrs.append(f'max="{max_value}"')
|
| 1552 |
+
if min_length is not None:
|
| 1553 |
+
input_attrs.append(f'minlength="{min_length}"')
|
| 1554 |
+
if max_length is not None:
|
| 1555 |
+
input_attrs.append(f'maxlength="{max_length}"')
|
| 1556 |
+
if pattern is not None:
|
| 1557 |
+
input_attrs.append(f'pattern="{pattern}"')
|
| 1558 |
if default_value is not None:
|
| 1559 |
input_attrs.append(f'value="{default_value}"')
|
| 1560 |
+
|
| 1561 |
+
attrs_str = " ".join(input_attrs)
|
| 1562 |
+
|
| 1563 |
+
if field_type == "checkbox":
|
| 1564 |
+
checked = "checked" if default_value is True else ""
|
| 1565 |
return f'''
|
| 1566 |
<div class="form-group">
|
| 1567 |
<label>
|
| 1568 |
+
<input type="checkbox" name="{field_name}" value="true" {checked}>
|
| 1569 |
{label_text}
|
| 1570 |
</label>
|
| 1571 |
+
{f'<small class="help-text">{help_text}</small>' if help_text else ""}
|
| 1572 |
</div>
|
| 1573 |
'''
|
| 1574 |
+
|
| 1575 |
+
elif field_type == "select":
|
| 1576 |
options_html = []
|
| 1577 |
if not required:
|
| 1578 |
options_html.append(f'<option value="">-- Select {label_text} --</option>')
|
| 1579 |
+
|
| 1580 |
for choice in choices:
|
| 1581 |
+
selected = "selected" if str(choice) == str(default_value) else ""
|
| 1582 |
+
options_html.append(
|
| 1583 |
+
f'<option value="{choice}" {selected}>{choice}</option>'
|
| 1584 |
+
)
|
| 1585 |
+
|
| 1586 |
return f'''
|
| 1587 |
<div class="form-group">
|
| 1588 |
<label for="{field_name}">{label_text}:</label>
|
| 1589 |
<select name="{field_name}" id="{field_name}" {attrs_str}>
|
| 1590 |
+
{"".join(options_html)}
|
| 1591 |
</select>
|
| 1592 |
+
{f'<small class="help-text">{help_text}</small>' if help_text else ""}
|
| 1593 |
</div>
|
| 1594 |
'''
|
| 1595 |
+
|
| 1596 |
+
elif field_type == "tensor":
|
| 1597 |
return f'''
|
| 1598 |
<div class="form-group">
|
| 1599 |
<label for="{field_name}">{label_text} (comma-separated integers):</label>
|
| 1600 |
<input type="text" name="{field_name}" id="{field_name}" {attrs_str}>
|
| 1601 |
+
<small class="help-text">{help_text or "Enter token IDs as comma-separated integers (e.g., 1,2,3,4,5)"}</small>
|
| 1602 |
</div>
|
| 1603 |
'''
|
| 1604 |
+
|
| 1605 |
+
elif field_type == "textarea":
|
| 1606 |
return f'''
|
| 1607 |
<div class="form-group">
|
| 1608 |
<label for="{field_name}">{label_text}:</label>
|
| 1609 |
+
<textarea name="{field_name}" id="{field_name}" rows="3" {attrs_str}>{default_value if default_value is not None else ""}</textarea>
|
| 1610 |
+
{f'<small class="help-text">{help_text}</small>' if help_text else ""}
|
| 1611 |
</div>
|
| 1612 |
'''
|
| 1613 |
+
|
| 1614 |
else:
|
| 1615 |
return f'''
|
| 1616 |
<div class="form-group">
|
| 1617 |
<label for="{field_name}">{label_text}:</label>
|
| 1618 |
<input type="{field_type}" name="{field_name}" id="{field_name}" {attrs_str}>
|
| 1619 |
+
{f'<small class="help-text">{help_text}</small>' if help_text else ""}
|
| 1620 |
</div>
|
| 1621 |
'''
|
{core → src/core}/tools/__init__.py
RENAMED
|
@@ -13,4 +13,4 @@ __all__ = [
|
|
| 13 |
"PyExecutor",
|
| 14 |
"GitServerClient",
|
| 15 |
"RepoInfo",
|
| 16 |
-
]
|
|
|
|
| 13 |
"PyExecutor",
|
| 14 |
"GitServerClient",
|
| 15 |
"RepoInfo",
|
| 16 |
+
]
|
{core → src/core}/tools/git_server_client.py
RENAMED
|
@@ -100,7 +100,9 @@ class GitServerClient:
|
|
| 100 |
gitconfig_path.write_text(git_config)
|
| 101 |
|
| 102 |
# Git credentials
|
| 103 |
-
git_credentials =
|
|
|
|
|
|
|
| 104 |
gitcreds_path = home_dir / ".git-credentials"
|
| 105 |
gitcreds_path.write_text(git_credentials)
|
| 106 |
gitcreds_path.chmod(0o600)
|
|
@@ -272,7 +274,12 @@ class GitServerClient:
|
|
| 272 |
raise RuntimeError(f"Checkout failed: {result.stderr}")
|
| 273 |
|
| 274 |
result = subprocess.run(
|
| 275 |
-
[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 276 |
cwd=str(repo_path),
|
| 277 |
capture_output=True,
|
| 278 |
text=True,
|
|
|
|
| 100 |
gitconfig_path.write_text(git_config)
|
| 101 |
|
| 102 |
# Git credentials
|
| 103 |
+
git_credentials = (
|
| 104 |
+
f"http://{self.username}:{self.password}@{self.domain}:{self.port}\n"
|
| 105 |
+
)
|
| 106 |
gitcreds_path = home_dir / ".git-credentials"
|
| 107 |
gitcreds_path.write_text(git_credentials)
|
| 108 |
gitcreds_path.chmod(0o600)
|
|
|
|
| 274 |
raise RuntimeError(f"Checkout failed: {result.stderr}")
|
| 275 |
|
| 276 |
result = subprocess.run(
|
| 277 |
+
[
|
| 278 |
+
"git",
|
| 279 |
+
"reset",
|
| 280 |
+
"--hard",
|
| 281 |
+
f"origin/{commit}" if commit != "main" else commit,
|
| 282 |
+
],
|
| 283 |
cwd=str(repo_path),
|
| 284 |
capture_output=True,
|
| 285 |
text=True,
|
{core → src/core}/tools/local_python_executor.py
RENAMED
|
@@ -28,7 +28,7 @@ from typing import Any
|
|
| 28 |
|
| 29 |
from smolagents import LocalPythonExecutor
|
| 30 |
|
| 31 |
-
from core.env_server.types import CodeExecResult
|
| 32 |
|
| 33 |
logger = logging.getLogger(__name__)
|
| 34 |
logger.addHandler(logging.NullHandler())
|
|
@@ -69,7 +69,10 @@ class PyExecutor:
|
|
| 69 |
except Exception:
|
| 70 |
# If the LocalPythonExecutor implementation doesn't support
|
| 71 |
# send_tools or fails, log and continue — the executor is still usable.
|
| 72 |
-
logger.debug(
|
|
|
|
|
|
|
|
|
|
| 73 |
|
| 74 |
def run(self, code: str) -> CodeExecResult:
|
| 75 |
"""Execute Python code and return a CodeExecResult.
|
|
@@ -127,7 +130,11 @@ class PyExecutor:
|
|
| 127 |
# Determine exit code if provided
|
| 128 |
try:
|
| 129 |
if hasattr(exec_result, "exit_code"):
|
| 130 |
-
exit_code =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
elif hasattr(exec_result, "success"):
|
| 132 |
# Some versions use `success` boolean
|
| 133 |
exit_code = 0 if exec_result.success else 1
|
|
|
|
| 28 |
|
| 29 |
from smolagents import LocalPythonExecutor
|
| 30 |
|
| 31 |
+
from openenv.core.env_server.types import CodeExecResult
|
| 32 |
|
| 33 |
logger = logging.getLogger(__name__)
|
| 34 |
logger.addHandler(logging.NullHandler())
|
|
|
|
| 69 |
except Exception:
|
| 70 |
# If the LocalPythonExecutor implementation doesn't support
|
| 71 |
# send_tools or fails, log and continue — the executor is still usable.
|
| 72 |
+
logger.debug(
|
| 73 |
+
"LocalPythonExecutor.send_tools failed; continuing without extra tools",
|
| 74 |
+
exc_info=True,
|
| 75 |
+
)
|
| 76 |
|
| 77 |
def run(self, code: str) -> CodeExecResult:
|
| 78 |
"""Execute Python code and return a CodeExecResult.
|
|
|
|
| 130 |
# Determine exit code if provided
|
| 131 |
try:
|
| 132 |
if hasattr(exec_result, "exit_code"):
|
| 133 |
+
exit_code = (
|
| 134 |
+
int(exec_result.exit_code)
|
| 135 |
+
if exec_result.exit_code is not None
|
| 136 |
+
else 0
|
| 137 |
+
)
|
| 138 |
elif hasattr(exec_result, "success"):
|
| 139 |
# Some versions use `success` boolean
|
| 140 |
exit_code = 0 if exec_result.success else 1
|
src/core/utils.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Utility functions for OpenEnv core."""
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def convert_to_ws_url(url: str) -> str:
|
| 11 |
+
"""
|
| 12 |
+
Convert an HTTP/HTTPS URL to a WS/WSS URL.
|
| 13 |
+
|
| 14 |
+
Args:
|
| 15 |
+
url: The URL to convert.
|
| 16 |
+
|
| 17 |
+
Returns:
|
| 18 |
+
The converted WebSocket URL.
|
| 19 |
+
"""
|
| 20 |
+
ws_url = url.rstrip("/")
|
| 21 |
+
if ws_url.startswith("http://"):
|
| 22 |
+
ws_url = "ws://" + ws_url[7:]
|
| 23 |
+
elif ws_url.startswith("https://"):
|
| 24 |
+
ws_url = "wss://" + ws_url[8:]
|
| 25 |
+
elif not ws_url.startswith("ws://") and not ws_url.startswith("wss://"):
|
| 26 |
+
ws_url = "ws://" + ws_url
|
| 27 |
+
return ws_url
|
src/openenv/__init__.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Unified OpenEnv package bundling the CLI and core runtime.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from importlib import metadata
|
| 6 |
+
|
| 7 |
+
__all__ = ["core", "cli"]
|
| 8 |
+
|
| 9 |
+
try:
|
| 10 |
+
__version__ = metadata.version("openenv") # type: ignore[arg-type]
|
| 11 |
+
except metadata.PackageNotFoundError: # pragma: no cover - local dev
|
| 12 |
+
__version__ = "0.0.0"
|
{core/containers → src/openenv/cli}/__init__.py
RENAMED
|
@@ -4,4 +4,6 @@
|
|
| 4 |
# This source code is licensed under the BSD-style license found in the
|
| 5 |
# LICENSE file in the root directory of this source tree.
|
| 6 |
|
| 7 |
-
"""
|
|
|
|
|
|
|
|
|
| 4 |
# This source code is licensed under the BSD-style license found in the
|
| 5 |
# LICENSE file in the root directory of this source tree.
|
| 6 |
|
| 7 |
+
"""OpenEnv CLI package."""
|
| 8 |
+
|
| 9 |
+
__version__ = "0.1.0"
|
src/openenv/cli/__main__.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
OpenEnv CLI entry point.
|
| 9 |
+
|
| 10 |
+
This module provides the main entry point for the OpenEnv command-line interface,
|
| 11 |
+
following the Hugging Face CLI pattern.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import sys
|
| 15 |
+
|
| 16 |
+
import typer
|
| 17 |
+
|
| 18 |
+
from openenv.cli.commands import build, init, push, serve, validate
|
| 19 |
+
|
| 20 |
+
# Create the main CLI app
|
| 21 |
+
app = typer.Typer(
|
| 22 |
+
name="openenv",
|
| 23 |
+
help="OpenEnv - An e2e framework for creating, deploying and using isolated execution environments for agentic RL training",
|
| 24 |
+
no_args_is_help=True,
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
# Register commands
|
| 28 |
+
app.command(name="init", help="Initialize a new OpenEnv environment")(init.init)
|
| 29 |
+
app.command(name="build", help="Build Docker images for OpenEnv environments")(
|
| 30 |
+
build.build
|
| 31 |
+
)
|
| 32 |
+
app.command(
|
| 33 |
+
name="validate", help="Validate environment structure and deployment readiness"
|
| 34 |
+
)(validate.validate)
|
| 35 |
+
app.command(
|
| 36 |
+
name="push",
|
| 37 |
+
help="Push an OpenEnv environment to Hugging Face Spaces or custom registry",
|
| 38 |
+
)(push.push)
|
| 39 |
+
app.command(name="serve", help="Serve environments locally (TODO: Phase 4)")(
|
| 40 |
+
serve.serve
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# Entry point for setuptools
|
| 45 |
+
def main() -> None:
|
| 46 |
+
"""Main entry point for the CLI."""
|
| 47 |
+
try:
|
| 48 |
+
app()
|
| 49 |
+
except KeyboardInterrupt:
|
| 50 |
+
print("\nOperation cancelled by user.")
|
| 51 |
+
sys.exit(130)
|
| 52 |
+
except Exception as e:
|
| 53 |
+
print(f"Error: {e}", file=sys.stderr)
|
| 54 |
+
sys.exit(1)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
if __name__ == "__main__":
|
| 58 |
+
main()
|
src/openenv/cli/_cli_utils.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""CLI utilities for OpenEnv command-line interface."""
|
| 8 |
+
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import List
|
| 11 |
+
|
| 12 |
+
from rich.console import Console
|
| 13 |
+
|
| 14 |
+
# Create a console instance for CLI output
|
| 15 |
+
console = Console()
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def validate_env_structure(env_dir: Path, strict: bool = False) -> List[str]:
|
| 19 |
+
"""
|
| 20 |
+
Validate that the directory follows OpenEnv environment structure.
|
| 21 |
+
|
| 22 |
+
Args:
|
| 23 |
+
env_dir: Path to environment directory
|
| 24 |
+
strict: If True, enforce all optional requirements
|
| 25 |
+
|
| 26 |
+
Returns:
|
| 27 |
+
List of validation warnings (empty if all checks pass)
|
| 28 |
+
|
| 29 |
+
Raises:
|
| 30 |
+
FileNotFoundError: If required files are missing
|
| 31 |
+
"""
|
| 32 |
+
warnings = []
|
| 33 |
+
|
| 34 |
+
# Required files
|
| 35 |
+
required_files = [
|
| 36 |
+
"openenv.yaml",
|
| 37 |
+
"__init__.py",
|
| 38 |
+
"client.py",
|
| 39 |
+
"models.py",
|
| 40 |
+
"README.md",
|
| 41 |
+
]
|
| 42 |
+
|
| 43 |
+
for file in required_files:
|
| 44 |
+
if not (env_dir / file).exists():
|
| 45 |
+
raise FileNotFoundError(f"Required file missing: {file}")
|
| 46 |
+
|
| 47 |
+
# Required directories
|
| 48 |
+
server_dir = env_dir / "server"
|
| 49 |
+
if not server_dir.exists() or not server_dir.is_dir():
|
| 50 |
+
raise FileNotFoundError("Required directory missing: server/")
|
| 51 |
+
|
| 52 |
+
# Server directory required files
|
| 53 |
+
server_required = [
|
| 54 |
+
"server/__init__.py",
|
| 55 |
+
"server/app.py",
|
| 56 |
+
"server/Dockerfile",
|
| 57 |
+
]
|
| 58 |
+
|
| 59 |
+
for file in server_required:
|
| 60 |
+
if not (env_dir / file).exists():
|
| 61 |
+
raise FileNotFoundError(f"Required file missing: {file}")
|
| 62 |
+
|
| 63 |
+
# Check for dependency management (pyproject.toml required)
|
| 64 |
+
has_pyproject = (env_dir / "pyproject.toml").exists()
|
| 65 |
+
|
| 66 |
+
if not has_pyproject:
|
| 67 |
+
raise FileNotFoundError(
|
| 68 |
+
"No dependency specification found. 'pyproject.toml' is required."
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
# Warnings for recommended structure
|
| 72 |
+
|
| 73 |
+
if not (env_dir / "outputs").exists():
|
| 74 |
+
warnings.append("Recommended directory missing: outputs/")
|
| 75 |
+
|
| 76 |
+
return warnings
|
src/openenv/cli/_validation.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Validation utilities for multi-mode deployment readiness.
|
| 9 |
+
|
| 10 |
+
This module provides functions to check if environments are properly
|
| 11 |
+
configured for multi-mode deployment (Docker, direct Python, notebooks, clusters).
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import subprocess
|
| 15 |
+
import tomllib
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def validate_multi_mode_deployment(env_path: Path) -> tuple[bool, list[str]]:
|
| 20 |
+
"""
|
| 21 |
+
Validate that an environment is ready for multi-mode deployment.
|
| 22 |
+
|
| 23 |
+
Checks:
|
| 24 |
+
1. pyproject.toml exists
|
| 25 |
+
2. uv.lock exists and is up-to-date
|
| 26 |
+
3. pyproject.toml has [project.scripts] with server entry point
|
| 27 |
+
4. server/app.py has a main() function
|
| 28 |
+
5. Required dependencies are present
|
| 29 |
+
|
| 30 |
+
Returns:
|
| 31 |
+
Tuple of (is_valid, list of issues found)
|
| 32 |
+
"""
|
| 33 |
+
issues = []
|
| 34 |
+
|
| 35 |
+
# Check pyproject.toml exists
|
| 36 |
+
pyproject_path = env_path / "pyproject.toml"
|
| 37 |
+
if not pyproject_path.exists():
|
| 38 |
+
issues.append("Missing pyproject.toml")
|
| 39 |
+
return False, issues
|
| 40 |
+
|
| 41 |
+
# Check uv.lock exists
|
| 42 |
+
lockfile_path = env_path / "uv.lock"
|
| 43 |
+
if not lockfile_path.exists():
|
| 44 |
+
issues.append("Missing uv.lock - run 'uv lock' to generate it")
|
| 45 |
+
else:
|
| 46 |
+
# Check if uv.lock is up-to-date (optional, can be expensive)
|
| 47 |
+
# We can add a check using `uv lock --check` if needed
|
| 48 |
+
try:
|
| 49 |
+
result = subprocess.run(
|
| 50 |
+
["uv", "lock", "--check", "--directory", str(env_path)],
|
| 51 |
+
capture_output=True,
|
| 52 |
+
text=True,
|
| 53 |
+
timeout=5,
|
| 54 |
+
)
|
| 55 |
+
if result.returncode != 0:
|
| 56 |
+
issues.append(
|
| 57 |
+
"uv.lock is out of date with pyproject.toml - run 'uv lock' to update"
|
| 58 |
+
)
|
| 59 |
+
except (subprocess.TimeoutExpired, FileNotFoundError):
|
| 60 |
+
# If uv is not available or times out, skip this check
|
| 61 |
+
pass
|
| 62 |
+
|
| 63 |
+
# Parse pyproject.toml
|
| 64 |
+
try:
|
| 65 |
+
with open(pyproject_path, "rb") as f:
|
| 66 |
+
pyproject = tomllib.load(f)
|
| 67 |
+
except Exception as e:
|
| 68 |
+
issues.append(f"Failed to parse pyproject.toml: {e}")
|
| 69 |
+
return False, issues
|
| 70 |
+
|
| 71 |
+
# Check [project.scripts] section
|
| 72 |
+
scripts = pyproject.get("project", {}).get("scripts", {})
|
| 73 |
+
if "server" not in scripts:
|
| 74 |
+
issues.append("Missing [project.scripts] server entry point")
|
| 75 |
+
|
| 76 |
+
# Check server entry point format
|
| 77 |
+
server_entry = scripts.get("server", "")
|
| 78 |
+
if server_entry and ":main" not in server_entry:
|
| 79 |
+
issues.append(
|
| 80 |
+
f"Server entry point should reference main function, got: {server_entry}"
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
# Check required dependencies
|
| 84 |
+
deps = [dep.lower() for dep in pyproject.get("project", {}).get("dependencies", [])]
|
| 85 |
+
has_openenv = any(
|
| 86 |
+
dep.startswith("openenv") and not dep.startswith("openenv-core") for dep in deps
|
| 87 |
+
)
|
| 88 |
+
has_legacy_core = any(dep.startswith("openenv-core") for dep in deps)
|
| 89 |
+
|
| 90 |
+
if not (has_openenv or has_legacy_core):
|
| 91 |
+
issues.append("Missing required dependency: openenv>=0.2.0")
|
| 92 |
+
elif has_legacy_core and not has_openenv:
|
| 93 |
+
issues.append(
|
| 94 |
+
"Dependency on openenv-core is deprecated; use openenv>=0.2.0 instead"
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
# Check server/app.py exists
|
| 98 |
+
server_app = env_path / "server" / "app.py"
|
| 99 |
+
if not server_app.exists():
|
| 100 |
+
issues.append("Missing server/app.py")
|
| 101 |
+
else:
|
| 102 |
+
# Check for main() function (flexible - with or without parameters)
|
| 103 |
+
app_content = server_app.read_text(encoding="utf-8")
|
| 104 |
+
if "def main(" not in app_content:
|
| 105 |
+
issues.append("server/app.py missing main() function")
|
| 106 |
+
|
| 107 |
+
# Check if main() is callable
|
| 108 |
+
if "__name__" not in app_content or "main()" not in app_content:
|
| 109 |
+
issues.append(
|
| 110 |
+
"server/app.py main() function not callable (missing if __name__ == '__main__')"
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
return len(issues) == 0, issues
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def get_deployment_modes(env_path: Path) -> dict[str, bool]:
|
| 117 |
+
"""
|
| 118 |
+
Check which deployment modes are supported by the environment.
|
| 119 |
+
|
| 120 |
+
Returns:
|
| 121 |
+
Dictionary with deployment mode names and whether they're supported
|
| 122 |
+
"""
|
| 123 |
+
modes = {
|
| 124 |
+
"docker": False,
|
| 125 |
+
"openenv_serve": False,
|
| 126 |
+
"uv_run": False,
|
| 127 |
+
"python_module": False,
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
# Check Docker
|
| 131 |
+
dockerfile = env_path / "server" / "Dockerfile"
|
| 132 |
+
modes["docker"] = dockerfile.exists()
|
| 133 |
+
|
| 134 |
+
# Check multi-mode deployment readiness
|
| 135 |
+
is_valid, _ = validate_multi_mode_deployment(env_path)
|
| 136 |
+
if is_valid:
|
| 137 |
+
modes["openenv_serve"] = True
|
| 138 |
+
modes["uv_run"] = True
|
| 139 |
+
modes["python_module"] = True
|
| 140 |
+
|
| 141 |
+
return modes
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def format_validation_report(env_name: str, is_valid: bool, issues: list[str]) -> str:
|
| 145 |
+
"""
|
| 146 |
+
Format a validation report for display.
|
| 147 |
+
|
| 148 |
+
Returns:
|
| 149 |
+
Formatted report string
|
| 150 |
+
"""
|
| 151 |
+
if is_valid:
|
| 152 |
+
return f"[OK] {env_name}: Ready for multi-mode deployment"
|
| 153 |
+
|
| 154 |
+
report = [f"[FAIL] {env_name}: Not ready for multi-mode deployment", ""]
|
| 155 |
+
report.append("Issues found:")
|
| 156 |
+
for issue in issues:
|
| 157 |
+
report.append(f" - {issue}")
|
| 158 |
+
|
| 159 |
+
return "\n".join(report)
|
src/openenv/cli/commands/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""OpenEnv CLI commands."""
|
| 8 |
+
|
| 9 |
+
from . import build, init, push, serve, validate
|
| 10 |
+
|
| 11 |
+
__all__ = ["build", "init", "push", "serve", "validate"]
|
src/openenv/cli/commands/build.py
ADDED
|
@@ -0,0 +1,453 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Build Docker images for OpenEnv environments."""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import shutil
|
| 12 |
+
import subprocess
|
| 13 |
+
import tempfile
|
| 14 |
+
import sys
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
from typing import Annotated
|
| 17 |
+
|
| 18 |
+
import typer
|
| 19 |
+
|
| 20 |
+
from .._cli_utils import console
|
| 21 |
+
|
| 22 |
+
app = typer.Typer(help="Build Docker images for OpenEnv environments")
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _detect_build_context(env_path: Path) -> tuple[str, Path, Path | None]:
|
| 26 |
+
"""
|
| 27 |
+
Detect whether we're building a standalone or in-repo environment.
|
| 28 |
+
|
| 29 |
+
Returns:
|
| 30 |
+
tuple: (build_mode, build_context_path, repo_root)
|
| 31 |
+
- build_mode: "standalone" or "in-repo"
|
| 32 |
+
- build_context_path: Path to use as Docker build context
|
| 33 |
+
- repo_root: Path to repo root (None for standalone)
|
| 34 |
+
"""
|
| 35 |
+
# Ensure env_path is absolute for proper comparison
|
| 36 |
+
env_path = env_path.absolute()
|
| 37 |
+
|
| 38 |
+
# Check if we're in a git repository
|
| 39 |
+
current = env_path
|
| 40 |
+
repo_root = None
|
| 41 |
+
|
| 42 |
+
# Walk up to find .git directory
|
| 43 |
+
for parent in [current] + list(current.parents):
|
| 44 |
+
if (parent / ".git").exists():
|
| 45 |
+
repo_root = parent
|
| 46 |
+
break
|
| 47 |
+
|
| 48 |
+
if repo_root is None:
|
| 49 |
+
# Not in a git repo = standalone
|
| 50 |
+
return "standalone", env_path, None
|
| 51 |
+
|
| 52 |
+
# Check if environment is under envs/ (in-repo pattern)
|
| 53 |
+
try:
|
| 54 |
+
rel_path = env_path.relative_to(repo_root)
|
| 55 |
+
rel_str = str(rel_path)
|
| 56 |
+
if (
|
| 57 |
+
rel_str.startswith("envs/")
|
| 58 |
+
or rel_str.startswith("envs\\")
|
| 59 |
+
or rel_str.startswith("envs/")
|
| 60 |
+
):
|
| 61 |
+
# In-repo environment
|
| 62 |
+
return "in-repo", repo_root, repo_root
|
| 63 |
+
except ValueError:
|
| 64 |
+
pass
|
| 65 |
+
|
| 66 |
+
# Otherwise, it's standalone (environment outside repo structure)
|
| 67 |
+
return "standalone", env_path, None
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def _prepare_standalone_build(env_path: Path, temp_dir: Path) -> Path:
|
| 71 |
+
"""
|
| 72 |
+
Prepare a standalone environment for building.
|
| 73 |
+
|
| 74 |
+
For standalone builds:
|
| 75 |
+
1. Copy environment to temp directory
|
| 76 |
+
2. Ensure pyproject.toml depends on openenv
|
| 77 |
+
|
| 78 |
+
Returns:
|
| 79 |
+
Path to the prepared build directory
|
| 80 |
+
"""
|
| 81 |
+
console.print("[cyan]Preparing standalone build...[/cyan]")
|
| 82 |
+
|
| 83 |
+
# Copy environment to temp directory
|
| 84 |
+
build_dir = temp_dir / env_path.name
|
| 85 |
+
shutil.copytree(env_path, build_dir, symlinks=True)
|
| 86 |
+
|
| 87 |
+
console.print(f"[cyan]Copied environment to:[/cyan] {build_dir}")
|
| 88 |
+
|
| 89 |
+
# Check if pyproject.toml has openenv dependency
|
| 90 |
+
pyproject_path = build_dir / "pyproject.toml"
|
| 91 |
+
if pyproject_path.exists():
|
| 92 |
+
with open(pyproject_path, "rb") as f:
|
| 93 |
+
try:
|
| 94 |
+
import tomli
|
| 95 |
+
|
| 96 |
+
pyproject = tomli.load(f)
|
| 97 |
+
deps = pyproject.get("project", {}).get("dependencies", [])
|
| 98 |
+
|
| 99 |
+
# Check if openenv dependency is declared
|
| 100 |
+
has_openenv = any(dep.startswith("openenv") for dep in deps)
|
| 101 |
+
|
| 102 |
+
if not has_openenv:
|
| 103 |
+
console.print(
|
| 104 |
+
"[yellow]Warning:[/yellow] pyproject.toml doesn't list the openenv dependency",
|
| 105 |
+
)
|
| 106 |
+
console.print(
|
| 107 |
+
"[yellow]You may need to add:[/yellow] openenv>=0.2.0",
|
| 108 |
+
)
|
| 109 |
+
except ImportError:
|
| 110 |
+
console.print(
|
| 111 |
+
"[yellow]Warning:[/yellow] tomli not available, skipping dependency check",
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
return build_dir
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def _prepare_inrepo_build(env_path: Path, repo_root: Path, temp_dir: Path) -> Path:
|
| 118 |
+
"""
|
| 119 |
+
Prepare an in-repo environment for building.
|
| 120 |
+
|
| 121 |
+
For in-repo builds:
|
| 122 |
+
1. Create temp directory with environment and core
|
| 123 |
+
2. Set up structure that matches expected layout
|
| 124 |
+
|
| 125 |
+
Returns:
|
| 126 |
+
Path to the prepared build directory
|
| 127 |
+
"""
|
| 128 |
+
console.print("[cyan]Preparing in-repo build...[/cyan]")
|
| 129 |
+
|
| 130 |
+
# Copy environment to temp directory
|
| 131 |
+
build_dir = temp_dir / env_path.name
|
| 132 |
+
shutil.copytree(env_path, build_dir, symlinks=True)
|
| 133 |
+
|
| 134 |
+
# Copy OpenEnv package to temp directory
|
| 135 |
+
package_src = repo_root / "src" / "openenv"
|
| 136 |
+
if package_src.exists():
|
| 137 |
+
package_dest = build_dir / "openenv"
|
| 138 |
+
shutil.copytree(package_src, package_dest, symlinks=True)
|
| 139 |
+
console.print(f"[cyan]Copied OpenEnv package to:[/cyan] {package_dest}")
|
| 140 |
+
|
| 141 |
+
# Update pyproject.toml to reference local OpenEnv copy
|
| 142 |
+
pyproject_path = build_dir / "pyproject.toml"
|
| 143 |
+
if pyproject_path.exists():
|
| 144 |
+
with open(pyproject_path, "rb") as f:
|
| 145 |
+
try:
|
| 146 |
+
import tomli
|
| 147 |
+
|
| 148 |
+
pyproject = tomli.load(f)
|
| 149 |
+
deps = pyproject.get("project", {}).get("dependencies", [])
|
| 150 |
+
|
| 151 |
+
# Replace openenv/openenv-core with local reference
|
| 152 |
+
new_deps = []
|
| 153 |
+
for dep in deps:
|
| 154 |
+
if (
|
| 155 |
+
dep.startswith("openenv-core")
|
| 156 |
+
or dep.startswith("openenv_core")
|
| 157 |
+
or dep.startswith("openenv")
|
| 158 |
+
):
|
| 159 |
+
# Skip - we'll use local core
|
| 160 |
+
continue
|
| 161 |
+
new_deps.append(dep)
|
| 162 |
+
|
| 163 |
+
# Write back with local core reference
|
| 164 |
+
pyproject["project"]["dependencies"] = new_deps + [
|
| 165 |
+
"openenv @ file:///app/env/openenv"
|
| 166 |
+
]
|
| 167 |
+
|
| 168 |
+
# Write updated pyproject.toml
|
| 169 |
+
with open(pyproject_path, "wb") as out_f:
|
| 170 |
+
import tomli_w
|
| 171 |
+
|
| 172 |
+
tomli_w.dump(pyproject, out_f)
|
| 173 |
+
|
| 174 |
+
console.print(
|
| 175 |
+
"[cyan]Updated pyproject.toml to use local core[/cyan]"
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
# Remove old lockfile since dependencies changed
|
| 179 |
+
lockfile = build_dir / "uv.lock"
|
| 180 |
+
if lockfile.exists():
|
| 181 |
+
lockfile.unlink()
|
| 182 |
+
console.print("[cyan]Removed outdated uv.lock[/cyan]")
|
| 183 |
+
|
| 184 |
+
except ImportError:
|
| 185 |
+
console.print(
|
| 186 |
+
"[yellow]Warning:[/yellow] tomli/tomli_w not available, using pyproject.toml as-is",
|
| 187 |
+
)
|
| 188 |
+
else:
|
| 189 |
+
console.print(
|
| 190 |
+
"[yellow]Warning:[/yellow] OpenEnv package not found, building without it"
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
console.print(f"[cyan]Build directory prepared:[/cyan] {build_dir}")
|
| 194 |
+
return build_dir
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def _run_command(
|
| 198 |
+
cmd: list[str],
|
| 199 |
+
cwd: Path | None = None,
|
| 200 |
+
check: bool = True,
|
| 201 |
+
) -> subprocess.CompletedProcess:
|
| 202 |
+
"""Run a shell command and handle errors."""
|
| 203 |
+
console.print(f"[bold cyan]Running:[/bold cyan] {' '.join(cmd)}")
|
| 204 |
+
try:
|
| 205 |
+
result = subprocess.run(
|
| 206 |
+
cmd, cwd=cwd, check=check, capture_output=True, text=True
|
| 207 |
+
)
|
| 208 |
+
if result.stdout:
|
| 209 |
+
console.print(result.stdout)
|
| 210 |
+
if result.stderr:
|
| 211 |
+
print(result.stderr, file=sys.stderr)
|
| 212 |
+
return result
|
| 213 |
+
except subprocess.CalledProcessError as e:
|
| 214 |
+
print(f"Error running command: {e}", file=sys.stderr)
|
| 215 |
+
if e.stdout:
|
| 216 |
+
console.print(e.stdout)
|
| 217 |
+
if e.stderr:
|
| 218 |
+
print(e.stderr, file=sys.stderr)
|
| 219 |
+
if check:
|
| 220 |
+
raise typer.Exit(1) from e
|
| 221 |
+
return e
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def _build_docker_image(
|
| 225 |
+
env_path: Path,
|
| 226 |
+
tag: str | None = None,
|
| 227 |
+
context_path: Path | None = None,
|
| 228 |
+
dockerfile: Path | None = None,
|
| 229 |
+
build_args: dict[str, str] | None = None,
|
| 230 |
+
no_cache: bool = False,
|
| 231 |
+
) -> bool:
|
| 232 |
+
"""Build Docker image for the environment with smart context detection."""
|
| 233 |
+
|
| 234 |
+
# Detect build context (standalone vs in-repo)
|
| 235 |
+
build_mode, detected_context, repo_root = _detect_build_context(env_path)
|
| 236 |
+
|
| 237 |
+
console.print(f"[bold cyan]Build mode detected:[/bold cyan] {build_mode}")
|
| 238 |
+
|
| 239 |
+
# Use detected context unless explicitly overridden
|
| 240 |
+
if context_path is None:
|
| 241 |
+
context_path = detected_context
|
| 242 |
+
|
| 243 |
+
# Create temporary build directory
|
| 244 |
+
with tempfile.TemporaryDirectory() as temp_dir_str:
|
| 245 |
+
temp_dir = Path(temp_dir_str)
|
| 246 |
+
|
| 247 |
+
# Prepare build directory based on mode
|
| 248 |
+
if build_mode == "standalone":
|
| 249 |
+
build_dir = _prepare_standalone_build(env_path, temp_dir)
|
| 250 |
+
else: # in-repo
|
| 251 |
+
build_dir = _prepare_inrepo_build(env_path, repo_root, temp_dir)
|
| 252 |
+
|
| 253 |
+
# Determine Dockerfile path
|
| 254 |
+
if dockerfile is None:
|
| 255 |
+
# Look for Dockerfile in server/ subdirectory
|
| 256 |
+
dockerfile = build_dir / "server" / "Dockerfile"
|
| 257 |
+
if not dockerfile.exists():
|
| 258 |
+
# Fallback to root of build directory
|
| 259 |
+
dockerfile = build_dir / "Dockerfile"
|
| 260 |
+
|
| 261 |
+
if not dockerfile.exists():
|
| 262 |
+
console.print(
|
| 263 |
+
f"[bold red]Error:[/bold red] Dockerfile not found at {dockerfile}",
|
| 264 |
+
)
|
| 265 |
+
return False
|
| 266 |
+
|
| 267 |
+
# Generate tag if not provided
|
| 268 |
+
if tag is None:
|
| 269 |
+
env_name = env_path.name
|
| 270 |
+
if env_name.endswith("_env"):
|
| 271 |
+
env_name = env_name[:-4]
|
| 272 |
+
tag = f"openenv-{env_name}"
|
| 273 |
+
|
| 274 |
+
console.print(f"[bold cyan]Building Docker image:[/bold cyan] {tag}")
|
| 275 |
+
console.print(f"[bold cyan]Build context:[/bold cyan] {build_dir}")
|
| 276 |
+
console.print(f"[bold cyan]Dockerfile:[/bold cyan] {dockerfile}")
|
| 277 |
+
|
| 278 |
+
# Prepare build args
|
| 279 |
+
if build_args is None:
|
| 280 |
+
build_args = {}
|
| 281 |
+
|
| 282 |
+
# Add build mode and env name to build args
|
| 283 |
+
build_args["BUILD_MODE"] = build_mode
|
| 284 |
+
build_args["ENV_NAME"] = env_path.name.replace("_env", "")
|
| 285 |
+
|
| 286 |
+
# Build Docker command
|
| 287 |
+
cmd = ["docker", "build", "-t", tag, "-f", str(dockerfile)]
|
| 288 |
+
|
| 289 |
+
if no_cache:
|
| 290 |
+
cmd.append("--no-cache")
|
| 291 |
+
|
| 292 |
+
for key, value in build_args.items():
|
| 293 |
+
cmd.extend(["--build-arg", f"{key}={value}"])
|
| 294 |
+
|
| 295 |
+
cmd.append(str(build_dir))
|
| 296 |
+
|
| 297 |
+
result = _run_command(cmd, check=False)
|
| 298 |
+
return result.returncode == 0
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
def _push_docker_image(tag: str, registry: str | None = None) -> bool:
|
| 302 |
+
"""Push Docker image to registry."""
|
| 303 |
+
if registry:
|
| 304 |
+
full_tag = f"{registry}/{tag}"
|
| 305 |
+
console.print(f"[bold cyan]Tagging image as {full_tag}[/bold cyan]")
|
| 306 |
+
_run_command(["docker", "tag", tag, full_tag])
|
| 307 |
+
tag = full_tag
|
| 308 |
+
|
| 309 |
+
console.print(f"[bold cyan]Pushing image:[/bold cyan] {tag}")
|
| 310 |
+
result = _run_command(["docker", "push", tag], check=False)
|
| 311 |
+
return result.returncode == 0
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
@app.command()
|
| 315 |
+
def build(
|
| 316 |
+
env_path: Annotated[
|
| 317 |
+
str | None,
|
| 318 |
+
typer.Argument(
|
| 319 |
+
help="Path to the environment directory (default: current directory)"
|
| 320 |
+
),
|
| 321 |
+
] = None,
|
| 322 |
+
tag: Annotated[
|
| 323 |
+
str | None,
|
| 324 |
+
typer.Option(
|
| 325 |
+
"--tag",
|
| 326 |
+
"-t",
|
| 327 |
+
help="Docker image tag (default: openenv-<env_name>)",
|
| 328 |
+
),
|
| 329 |
+
] = None,
|
| 330 |
+
context: Annotated[
|
| 331 |
+
str | None,
|
| 332 |
+
typer.Option(
|
| 333 |
+
"--context",
|
| 334 |
+
"-c",
|
| 335 |
+
help="Build context path (default: <env_path>/server)",
|
| 336 |
+
),
|
| 337 |
+
] = None,
|
| 338 |
+
dockerfile: Annotated[
|
| 339 |
+
str | None,
|
| 340 |
+
typer.Option(
|
| 341 |
+
"--dockerfile",
|
| 342 |
+
"-f",
|
| 343 |
+
help="Path to Dockerfile (default: <context>/Dockerfile)",
|
| 344 |
+
),
|
| 345 |
+
] = None,
|
| 346 |
+
no_cache: Annotated[
|
| 347 |
+
bool,
|
| 348 |
+
typer.Option(
|
| 349 |
+
"--no-cache",
|
| 350 |
+
help="Build without using cache",
|
| 351 |
+
),
|
| 352 |
+
] = False,
|
| 353 |
+
build_arg: Annotated[
|
| 354 |
+
list[str] | None,
|
| 355 |
+
typer.Option(
|
| 356 |
+
"--build-arg",
|
| 357 |
+
help="Build arguments (can be used multiple times, format: KEY=VALUE)",
|
| 358 |
+
),
|
| 359 |
+
] = None,
|
| 360 |
+
) -> None:
|
| 361 |
+
"""
|
| 362 |
+
Build Docker images for OpenEnv environments.
|
| 363 |
+
|
| 364 |
+
This command builds Docker images using the environment's pyproject.toml
|
| 365 |
+
and uv for dependency management. Run from the environment root directory.
|
| 366 |
+
|
| 367 |
+
Examples:
|
| 368 |
+
# Build from environment root (recommended)
|
| 369 |
+
$ cd my_env
|
| 370 |
+
$ openenv build
|
| 371 |
+
|
| 372 |
+
# Build with custom tag
|
| 373 |
+
$ openenv build -t my-custom-tag
|
| 374 |
+
|
| 375 |
+
# Build without cache
|
| 376 |
+
$ openenv build --no-cache
|
| 377 |
+
|
| 378 |
+
# Build with custom build arguments
|
| 379 |
+
$ openenv build --build-arg VERSION=1.0 --build-arg ENV=prod
|
| 380 |
+
|
| 381 |
+
# Build from different directory
|
| 382 |
+
$ openenv build envs/echo_env
|
| 383 |
+
"""
|
| 384 |
+
# Determine environment path (default to current directory)
|
| 385 |
+
if env_path is None:
|
| 386 |
+
env_path_obj = Path.cwd()
|
| 387 |
+
else:
|
| 388 |
+
env_path_obj = Path(env_path)
|
| 389 |
+
|
| 390 |
+
# Validate environment path
|
| 391 |
+
if not env_path_obj.exists():
|
| 392 |
+
print(
|
| 393 |
+
f"Error: Environment path does not exist: {env_path_obj}",
|
| 394 |
+
file=sys.stderr,
|
| 395 |
+
)
|
| 396 |
+
raise typer.Exit(1)
|
| 397 |
+
|
| 398 |
+
if not env_path_obj.is_dir():
|
| 399 |
+
print(
|
| 400 |
+
f"Error: Environment path is not a directory: {env_path_obj}",
|
| 401 |
+
file=sys.stderr,
|
| 402 |
+
)
|
| 403 |
+
raise typer.Exit(1)
|
| 404 |
+
|
| 405 |
+
# Check for openenv.yaml to confirm this is an environment directory
|
| 406 |
+
openenv_yaml = env_path_obj / "openenv.yaml"
|
| 407 |
+
if not openenv_yaml.exists():
|
| 408 |
+
print(
|
| 409 |
+
f"Error: Not an OpenEnv environment directory (missing openenv.yaml): {env_path_obj}",
|
| 410 |
+
file=sys.stderr,
|
| 411 |
+
)
|
| 412 |
+
print(
|
| 413 |
+
"Hint: Run this command from the environment root directory or specify the path",
|
| 414 |
+
file=sys.stderr,
|
| 415 |
+
)
|
| 416 |
+
raise typer.Exit(1)
|
| 417 |
+
|
| 418 |
+
console.print(f"[bold]Building Docker image for:[/bold] {env_path_obj.name}")
|
| 419 |
+
console.print("=" * 60)
|
| 420 |
+
|
| 421 |
+
# Parse build args
|
| 422 |
+
build_args = {}
|
| 423 |
+
if build_arg:
|
| 424 |
+
for arg in build_arg:
|
| 425 |
+
if "=" in arg:
|
| 426 |
+
key, value = arg.split("=", 1)
|
| 427 |
+
build_args[key] = value
|
| 428 |
+
else:
|
| 429 |
+
print(
|
| 430 |
+
f"Warning: Invalid build arg format: {arg}",
|
| 431 |
+
file=sys.stderr,
|
| 432 |
+
)
|
| 433 |
+
|
| 434 |
+
# Convert string paths to Path objects
|
| 435 |
+
context_path_obj = Path(context) if context else None
|
| 436 |
+
dockerfile_path_obj = Path(dockerfile) if dockerfile else None
|
| 437 |
+
|
| 438 |
+
# Build Docker image
|
| 439 |
+
success = _build_docker_image(
|
| 440 |
+
env_path=env_path_obj,
|
| 441 |
+
tag=tag,
|
| 442 |
+
context_path=context_path_obj,
|
| 443 |
+
dockerfile=dockerfile_path_obj,
|
| 444 |
+
build_args=build_args if build_args else None,
|
| 445 |
+
no_cache=no_cache,
|
| 446 |
+
)
|
| 447 |
+
|
| 448 |
+
if not success:
|
| 449 |
+
print("Γ£ù Docker build failed", file=sys.stderr)
|
| 450 |
+
raise typer.Exit(1)
|
| 451 |
+
|
| 452 |
+
console.print("[bold green]Γ£ô Docker build successful[/bold green]")
|
| 453 |
+
console.print("\n[bold green]Done![/bold green]")
|
src/openenv/cli/commands/init.py
ADDED
|
@@ -0,0 +1,501 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Initialize a new OpenEnv environment."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import random
|
| 7 |
+
import shutil
|
| 8 |
+
import subprocess
|
| 9 |
+
from importlib import resources
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Annotated, Dict, List, Optional, Tuple
|
| 12 |
+
|
| 13 |
+
import typer
|
| 14 |
+
|
| 15 |
+
from .._cli_utils import console
|
| 16 |
+
|
| 17 |
+
# Commands are registered in __main__.py
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _snake_to_pascal(snake_str: str) -> str:
|
| 21 |
+
"""Convert snake_case to PascalCase (e.g., 'my_env' -> 'MyEnv')."""
|
| 22 |
+
return "".join(word.capitalize() for word in snake_str.split("_"))
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _get_env_prefix(env_name: str) -> str:
|
| 26 |
+
"""Extract the prefix for class names (e.g., 'my_env' -> 'My', 'test_env' -> 'Test')."""
|
| 27 |
+
# Remove trailing '_env' if present
|
| 28 |
+
if env_name.endswith("_env"):
|
| 29 |
+
base = env_name[:-4] # Remove '_env'
|
| 30 |
+
else:
|
| 31 |
+
base = env_name
|
| 32 |
+
|
| 33 |
+
# If empty or just one part, use the whole thing
|
| 34 |
+
if not base or "_" not in base:
|
| 35 |
+
return base.capitalize() if base else env_name.capitalize()
|
| 36 |
+
|
| 37 |
+
# PascalCase all parts except the last
|
| 38 |
+
parts = base.split("_")
|
| 39 |
+
return "".join(word.capitalize() for word in parts)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _snake_to_camel(snake_str: str) -> str:
|
| 43 |
+
"""Convert snake_case to camelCase (e.g., 'my_env' -> 'myEnv')."""
|
| 44 |
+
parts = snake_str.split("_")
|
| 45 |
+
return parts[0] + "".join(word.capitalize() for word in parts[1:])
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _snake_to_title(snake_str: str) -> str:
|
| 49 |
+
"""Convert snake_case to Title Case (e.g., 'my_env' -> 'My Env')."""
|
| 50 |
+
return " ".join(word.capitalize() for word in snake_str.split("_"))
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def _validate_env_name(name: str) -> str:
|
| 54 |
+
"""Validate environment name (must be valid Python identifier in snake_case)."""
|
| 55 |
+
if not name:
|
| 56 |
+
raise typer.BadParameter("Environment name cannot be empty")
|
| 57 |
+
|
| 58 |
+
# Check if it's a valid Python identifier
|
| 59 |
+
if not name.isidentifier():
|
| 60 |
+
raise typer.BadParameter(
|
| 61 |
+
f"Environment name '{name}' is not a valid Python identifier. Use snake_case (e.g., 'my_env', 'game_env')."
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
# Check if it starts with a number
|
| 65 |
+
if name[0].isdigit():
|
| 66 |
+
raise typer.BadParameter(
|
| 67 |
+
f"Environment name '{name}' cannot start with a number."
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
return name
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def _get_random_hf_space_config() -> Dict[str, str]:
|
| 74 |
+
"""
|
| 75 |
+
Get random Hugging Face Space configuration values.
|
| 76 |
+
|
| 77 |
+
Returns:
|
| 78 |
+
Dictionary with 'emoji', 'colorFrom', and 'colorTo' keys
|
| 79 |
+
"""
|
| 80 |
+
# Valid emojis (emoji-only characters)
|
| 81 |
+
emojis = [
|
| 82 |
+
"🎮",
|
| 83 |
+
"🎯",
|
| 84 |
+
"🚀",
|
| 85 |
+
"🌟",
|
| 86 |
+
"🎨",
|
| 87 |
+
"🎪",
|
| 88 |
+
"🎭",
|
| 89 |
+
"🎬",
|
| 90 |
+
"🎤",
|
| 91 |
+
"🎧",
|
| 92 |
+
"🎵",
|
| 93 |
+
"🎶",
|
| 94 |
+
"🎸",
|
| 95 |
+
"🎹",
|
| 96 |
+
"🥁",
|
| 97 |
+
"🎺",
|
| 98 |
+
"🎻",
|
| 99 |
+
"🎼",
|
| 100 |
+
"🎯",
|
| 101 |
+
"🎲",
|
| 102 |
+
"🎳",
|
| 103 |
+
"🎰",
|
| 104 |
+
"🎴",
|
| 105 |
+
"🃏",
|
| 106 |
+
"🀄",
|
| 107 |
+
"🎴",
|
| 108 |
+
"🎨",
|
| 109 |
+
"🖼️",
|
| 110 |
+
"🎬",
|
| 111 |
+
"🎭",
|
| 112 |
+
"🎪",
|
| 113 |
+
"🎤",
|
| 114 |
+
"🎧",
|
| 115 |
+
"🎵",
|
| 116 |
+
"🎶",
|
| 117 |
+
"🎸",
|
| 118 |
+
"🎹",
|
| 119 |
+
"🎺",
|
| 120 |
+
"🎻",
|
| 121 |
+
"🥁",
|
| 122 |
+
"🎯",
|
| 123 |
+
"🎲",
|
| 124 |
+
"🎳",
|
| 125 |
+
"🎰",
|
| 126 |
+
"🏀",
|
| 127 |
+
"⚽",
|
| 128 |
+
"🏈",
|
| 129 |
+
"⚾",
|
| 130 |
+
"🎾",
|
| 131 |
+
"🏐",
|
| 132 |
+
"🏉",
|
| 133 |
+
"🎱",
|
| 134 |
+
"🏓",
|
| 135 |
+
"🏸",
|
| 136 |
+
"🥅",
|
| 137 |
+
"🏒",
|
| 138 |
+
"🏑",
|
| 139 |
+
"🏏",
|
| 140 |
+
"⛳",
|
| 141 |
+
"🏹",
|
| 142 |
+
"🎣",
|
| 143 |
+
"🥊",
|
| 144 |
+
"🥋",
|
| 145 |
+
"🎽",
|
| 146 |
+
"🏅",
|
| 147 |
+
"🎖️",
|
| 148 |
+
"🏆",
|
| 149 |
+
"🥇",
|
| 150 |
+
"🥈",
|
| 151 |
+
"🥉",
|
| 152 |
+
"🔊",
|
| 153 |
+
"🔉",
|
| 154 |
+
"🔈",
|
| 155 |
+
"🔇",
|
| 156 |
+
"📢",
|
| 157 |
+
"📣",
|
| 158 |
+
"📯",
|
| 159 |
+
"🔔",
|
| 160 |
+
"🔕",
|
| 161 |
+
"📻",
|
| 162 |
+
"📡",
|
| 163 |
+
"💻",
|
| 164 |
+
"🖥️",
|
| 165 |
+
"🖨️",
|
| 166 |
+
"⌨️",
|
| 167 |
+
"🖱️",
|
| 168 |
+
"🖲️",
|
| 169 |
+
"🕹️",
|
| 170 |
+
"🗜️",
|
| 171 |
+
"💾",
|
| 172 |
+
"💿",
|
| 173 |
+
"📀",
|
| 174 |
+
"📼",
|
| 175 |
+
"📷",
|
| 176 |
+
"📸",
|
| 177 |
+
"📹",
|
| 178 |
+
"🎥",
|
| 179 |
+
"📽️",
|
| 180 |
+
"🎞️",
|
| 181 |
+
"📞",
|
| 182 |
+
"☎️",
|
| 183 |
+
"📟",
|
| 184 |
+
"📠",
|
| 185 |
+
"📺",
|
| 186 |
+
"📻",
|
| 187 |
+
"🎙️",
|
| 188 |
+
"🎚️",
|
| 189 |
+
"🎛️",
|
| 190 |
+
"⏱️",
|
| 191 |
+
"⏲️",
|
| 192 |
+
"⏰",
|
| 193 |
+
"🕰️",
|
| 194 |
+
"⌚",
|
| 195 |
+
"📱",
|
| 196 |
+
"📲",
|
| 197 |
+
"💻",
|
| 198 |
+
"⌨️",
|
| 199 |
+
"🖥️",
|
| 200 |
+
"🖨️",
|
| 201 |
+
"🖱️",
|
| 202 |
+
]
|
| 203 |
+
|
| 204 |
+
# Valid colors from HF Spaces config reference
|
| 205 |
+
colors = ["red", "yellow", "green", "blue", "indigo", "purple", "pink", "gray"]
|
| 206 |
+
|
| 207 |
+
return {
|
| 208 |
+
"emoji": random.choice(emojis),
|
| 209 |
+
"colorFrom": random.choice(colors),
|
| 210 |
+
"colorTo": random.choice(colors),
|
| 211 |
+
}
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def _create_template_replacements(env_name: str) -> Dict[str, str]:
|
| 215 |
+
"""
|
| 216 |
+
Create comprehensive template replacement dictionary.
|
| 217 |
+
|
| 218 |
+
Supports all naming conventions:
|
| 219 |
+
- PascalCase for class names
|
| 220 |
+
- camelCase for variable names
|
| 221 |
+
- snake_case for module names, file paths
|
| 222 |
+
"""
|
| 223 |
+
env_pascal = _snake_to_pascal(env_name)
|
| 224 |
+
env_prefix = _get_env_prefix(env_name)
|
| 225 |
+
env_camel = _snake_to_camel(env_name)
|
| 226 |
+
env_title = _snake_to_title(env_name)
|
| 227 |
+
|
| 228 |
+
# Get random HF Space config values
|
| 229 |
+
hf_config = _get_random_hf_space_config()
|
| 230 |
+
|
| 231 |
+
replacements = {
|
| 232 |
+
# Template placeholders (MUST come first - full class names before partial)
|
| 233 |
+
"__ENV_CLASS_NAME__Environment": f"{env_prefix}Environment",
|
| 234 |
+
"__ENV_CLASS_NAME__Action": f"{env_prefix}Action",
|
| 235 |
+
"__ENV_CLASS_NAME__Observation": f"{env_prefix}Observation",
|
| 236 |
+
"__ENV_CLASS_NAME__Env": f"{env_prefix}Env",
|
| 237 |
+
# Template placeholders (partial - must come after full replacements)
|
| 238 |
+
"__ENV_NAME__": env_name,
|
| 239 |
+
"__ENV_CLASS_NAME__": env_prefix, # Use prefix, not full PascalCase
|
| 240 |
+
"__ENV_TITLE_NAME__": env_title,
|
| 241 |
+
"__ENV_CAMEL_NAME__": env_camel,
|
| 242 |
+
# Hugging Face Space config placeholders
|
| 243 |
+
"__HF_EMOJI__": hf_config["emoji"],
|
| 244 |
+
"__HF_COLOR_FROM__": hf_config["colorFrom"],
|
| 245 |
+
"__HF_COLOR_TO__": hf_config["colorTo"],
|
| 246 |
+
}
|
| 247 |
+
|
| 248 |
+
return replacements
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def _replace_in_content(content: str, replacements: Dict[str, str]) -> str:
|
| 252 |
+
"""Replace all occurrences in content using case-sensitive replacements."""
|
| 253 |
+
result = content
|
| 254 |
+
# Sort by length (longest first) to avoid partial replacements
|
| 255 |
+
for old, new in sorted(replacements.items(), key=lambda x: len(x[0]), reverse=True):
|
| 256 |
+
result = result.replace(old, new)
|
| 257 |
+
return result
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def _should_rename_file(filename: str, env_name: str) -> Tuple[bool, str]:
|
| 261 |
+
"""
|
| 262 |
+
Check if a file should be renamed and return the new name.
|
| 263 |
+
|
| 264 |
+
Handles template placeholders in filenames like:
|
| 265 |
+
- `__ENV_NAME___environment.py` → `<env_name>_environment.py`
|
| 266 |
+
"""
|
| 267 |
+
# Check for template placeholder
|
| 268 |
+
if "__ENV_NAME__" in filename:
|
| 269 |
+
new_name = filename.replace("__ENV_NAME__", env_name)
|
| 270 |
+
return True, new_name
|
| 271 |
+
|
| 272 |
+
return False, filename
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
def _copy_and_template_file(
|
| 276 |
+
src_path: Path,
|
| 277 |
+
dest_path: Path,
|
| 278 |
+
replacements: Dict[str, str],
|
| 279 |
+
) -> None:
|
| 280 |
+
"""Copy a file and apply template replacements."""
|
| 281 |
+
dest_path.parent.mkdir(parents=True, exist_ok=True)
|
| 282 |
+
|
| 283 |
+
try:
|
| 284 |
+
# Read source file
|
| 285 |
+
content = src_path.read_bytes()
|
| 286 |
+
|
| 287 |
+
# Try to decode as text and apply replacements
|
| 288 |
+
try:
|
| 289 |
+
text = content.decode("utf-8")
|
| 290 |
+
# Normalize line endings to LF before applying replacements
|
| 291 |
+
text = text.replace("\r\n", "\n").replace("\r", "\n")
|
| 292 |
+
text = _replace_in_content(text, replacements)
|
| 293 |
+
dest_path.write_text(text, encoding="utf-8", newline="\n")
|
| 294 |
+
except UnicodeDecodeError:
|
| 295 |
+
# Binary file, just copy
|
| 296 |
+
dest_path.write_bytes(content)
|
| 297 |
+
except Exception as e:
|
| 298 |
+
raise RuntimeError(
|
| 299 |
+
f"Failed to copy template file {src_path} to {dest_path}: {e}"
|
| 300 |
+
) from e
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
def _copy_template_directory(
|
| 304 |
+
template_pkg: str,
|
| 305 |
+
template_dir: str,
|
| 306 |
+
dest_dir: Path,
|
| 307 |
+
replacements: Dict[str, str],
|
| 308 |
+
env_name: str,
|
| 309 |
+
) -> List[Path]:
|
| 310 |
+
"""Recursively copy template directory and apply replacements."""
|
| 311 |
+
created_files: List[Path] = []
|
| 312 |
+
|
| 313 |
+
# Get the package path using importlib.resources but avoid importing the template package
|
| 314 |
+
# We'll use the package's __file__ to get the directory path
|
| 315 |
+
import importlib
|
| 316 |
+
|
| 317 |
+
try:
|
| 318 |
+
# Import the parent package (not the template package itself)
|
| 319 |
+
if "." in template_pkg:
|
| 320 |
+
parent_pkg = ".".join(template_pkg.split(".")[:-1])
|
| 321 |
+
pkg = importlib.import_module(parent_pkg)
|
| 322 |
+
template_path = Path(pkg.__file__).parent / template_pkg.split(".")[-1]
|
| 323 |
+
else:
|
| 324 |
+
pkg = importlib.import_module(template_pkg.split(".")[0])
|
| 325 |
+
template_path = Path(pkg.__file__).parent / template_pkg.split(".")[-1]
|
| 326 |
+
except Exception:
|
| 327 |
+
# Fallback: try to use resources.files but handle import errors
|
| 328 |
+
try:
|
| 329 |
+
base = resources.files(template_pkg.split(".")[0])
|
| 330 |
+
template_path = base.joinpath(*template_pkg.split(".")[1:])
|
| 331 |
+
if not template_path.exists():
|
| 332 |
+
raise FileNotFoundError(f"Template directory not found: {template_pkg}")
|
| 333 |
+
except Exception as e:
|
| 334 |
+
raise FileNotFoundError(
|
| 335 |
+
f"Template directory not found: {template_pkg}"
|
| 336 |
+
) from e
|
| 337 |
+
|
| 338 |
+
if template_dir:
|
| 339 |
+
template_path = template_path / template_dir
|
| 340 |
+
|
| 341 |
+
if not template_path.exists() or not template_path.is_dir():
|
| 342 |
+
raise FileNotFoundError(
|
| 343 |
+
f"Template directory not found: {template_pkg}.{template_dir}"
|
| 344 |
+
)
|
| 345 |
+
|
| 346 |
+
# Walk through all files in template directory using Path
|
| 347 |
+
for item in template_path.rglob("*"):
|
| 348 |
+
if item.is_file():
|
| 349 |
+
rel_path = item.relative_to(template_path)
|
| 350 |
+
dest_path = dest_dir / rel_path
|
| 351 |
+
|
| 352 |
+
# Apply filename templating
|
| 353 |
+
should_rename, new_name = _should_rename_file(dest_path.name, env_name)
|
| 354 |
+
if should_rename:
|
| 355 |
+
dest_path = dest_path.parent / new_name
|
| 356 |
+
|
| 357 |
+
# Copy and apply replacements
|
| 358 |
+
_copy_and_template_file(item, dest_path, replacements)
|
| 359 |
+
created_files.append(dest_path)
|
| 360 |
+
|
| 361 |
+
return created_files
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
def _generate_uv_lock(env_dir: Path) -> bool:
|
| 365 |
+
"""Generate uv.lock from pyproject.toml using uv."""
|
| 366 |
+
pyproject_path = env_dir / "pyproject.toml"
|
| 367 |
+
|
| 368 |
+
if not pyproject_path.exists():
|
| 369 |
+
return False
|
| 370 |
+
|
| 371 |
+
try:
|
| 372 |
+
cmd = [
|
| 373 |
+
"uv",
|
| 374 |
+
"lock",
|
| 375 |
+
"--directory",
|
| 376 |
+
str(env_dir),
|
| 377 |
+
]
|
| 378 |
+
|
| 379 |
+
result = subprocess.run(cmd, capture_output=True, text=True, check=True)
|
| 380 |
+
|
| 381 |
+
if result.stdout:
|
| 382 |
+
console.print(result.stdout)
|
| 383 |
+
|
| 384 |
+
return True
|
| 385 |
+
|
| 386 |
+
except subprocess.CalledProcessError as e:
|
| 387 |
+
console.print(
|
| 388 |
+
f"[yellow]Warning: Could not generate uv.lock: {e.stderr}[/yellow]"
|
| 389 |
+
)
|
| 390 |
+
return False
|
| 391 |
+
except FileNotFoundError:
|
| 392 |
+
console.print(
|
| 393 |
+
"[yellow]Warning: 'uv' not found. Install it to generate uv.lock[/yellow]"
|
| 394 |
+
)
|
| 395 |
+
return False
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
def init(
|
| 399 |
+
env_name: Annotated[
|
| 400 |
+
str,
|
| 401 |
+
typer.Argument(
|
| 402 |
+
help="Name of the environment to create (snake_case, e.g., 'my_env')"
|
| 403 |
+
),
|
| 404 |
+
],
|
| 405 |
+
output_dir: Annotated[
|
| 406 |
+
Optional[str],
|
| 407 |
+
typer.Option(
|
| 408 |
+
"--output-dir",
|
| 409 |
+
"-o",
|
| 410 |
+
help="Output directory (defaults to current working directory)",
|
| 411 |
+
),
|
| 412 |
+
] = None,
|
| 413 |
+
) -> None:
|
| 414 |
+
"""
|
| 415 |
+
Initialize a new OpenEnv environment.
|
| 416 |
+
|
| 417 |
+
Creates a new directory with the environment name and generates all necessary
|
| 418 |
+
files based on the OpenEnv template structure.
|
| 419 |
+
|
| 420 |
+
Example:
|
| 421 |
+
$ openenv init my_game_env
|
| 422 |
+
$ openenv init my_env --output-dir /path/to/projects
|
| 423 |
+
"""
|
| 424 |
+
# Validate environment name
|
| 425 |
+
env_name = _validate_env_name(env_name)
|
| 426 |
+
|
| 427 |
+
# Determine output directory
|
| 428 |
+
base_dir = Path(output_dir).resolve() if output_dir else Path.cwd().resolve()
|
| 429 |
+
env_dir = base_dir / env_name
|
| 430 |
+
|
| 431 |
+
# Check if directory already exists
|
| 432 |
+
if env_dir.exists():
|
| 433 |
+
if env_dir.is_file():
|
| 434 |
+
raise typer.BadParameter(f"Path '{env_dir}' exists and is a file")
|
| 435 |
+
if any(env_dir.iterdir()):
|
| 436 |
+
raise typer.BadParameter(
|
| 437 |
+
f"Directory '{env_dir}' already exists and is not empty. "
|
| 438 |
+
"Please choose a different name or remove the existing directory."
|
| 439 |
+
)
|
| 440 |
+
|
| 441 |
+
try:
|
| 442 |
+
# Create template replacements
|
| 443 |
+
replacements = _create_template_replacements(env_name)
|
| 444 |
+
|
| 445 |
+
# Create environment directory
|
| 446 |
+
env_dir.mkdir(parents=True, exist_ok=True)
|
| 447 |
+
|
| 448 |
+
console.print(
|
| 449 |
+
f"[bold cyan]Creating OpenEnv environment '{env_name}'...[/bold cyan]"
|
| 450 |
+
)
|
| 451 |
+
|
| 452 |
+
# Copy template files from template structure
|
| 453 |
+
template_pkg = "openenv.cli.templates.openenv_env"
|
| 454 |
+
created_files = _copy_template_directory(
|
| 455 |
+
template_pkg,
|
| 456 |
+
"",
|
| 457 |
+
env_dir,
|
| 458 |
+
replacements,
|
| 459 |
+
env_name,
|
| 460 |
+
)
|
| 461 |
+
|
| 462 |
+
console.print(f"[bold green]✓[/bold green] Created {len(created_files)} files")
|
| 463 |
+
|
| 464 |
+
# Generate uv.lock
|
| 465 |
+
console.print("\n[bold]Generating uv.lock...[/bold]")
|
| 466 |
+
if _generate_uv_lock(env_dir):
|
| 467 |
+
console.print("[green]✓[/green] Generated uv.lock")
|
| 468 |
+
else:
|
| 469 |
+
console.print("[yellow]⚠[/yellow] Could not generate uv.lock automatically")
|
| 470 |
+
console.print(" You can generate it manually with:")
|
| 471 |
+
console.print(f" cd {env_dir} && uv lock")
|
| 472 |
+
|
| 473 |
+
console.print(
|
| 474 |
+
f"\n[bold green]Environment created successfully at: {env_dir}[/bold green]"
|
| 475 |
+
)
|
| 476 |
+
console.print("\n[bold]Next steps:[/bold]")
|
| 477 |
+
console.print(f" cd {env_dir}")
|
| 478 |
+
console.print(
|
| 479 |
+
f" # Edit your environment implementation in server/{env_name}_environment.py"
|
| 480 |
+
)
|
| 481 |
+
console.print(" # Edit your models in models.py")
|
| 482 |
+
console.print(" # Install dependencies: uv sync")
|
| 483 |
+
console.print("\n # To integrate into OpenEnv repo:")
|
| 484 |
+
console.print(f" # 1. Copy this directory to <repo_root>/envs/{env_name}_env")
|
| 485 |
+
console.print(
|
| 486 |
+
f" # 2. Build from repo root: docker build -t {env_name}_env:latest -f envs/{env_name}_env/server/Dockerfile ."
|
| 487 |
+
)
|
| 488 |
+
console.print(
|
| 489 |
+
f" # 3. Run your image: docker run -p 8000:8000 {env_name}_env:latest"
|
| 490 |
+
)
|
| 491 |
+
|
| 492 |
+
except Exception as e:
|
| 493 |
+
# Cleanup on error
|
| 494 |
+
if env_dir.exists() and env_dir.is_dir():
|
| 495 |
+
try:
|
| 496 |
+
shutil.rmtree(env_dir)
|
| 497 |
+
except Exception:
|
| 498 |
+
pass
|
| 499 |
+
|
| 500 |
+
console.print(f"[bold red]Error:[/bold red] {e}")
|
| 501 |
+
raise typer.Exit(1) from e
|
src/openenv/cli/commands/push.py
ADDED
|
@@ -0,0 +1,541 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Push an OpenEnv environment to Hugging Face Spaces."""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import shutil
|
| 12 |
+
import tempfile
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from typing import Annotated
|
| 15 |
+
import sys
|
| 16 |
+
import typer
|
| 17 |
+
import yaml
|
| 18 |
+
from huggingface_hub import HfApi, login, whoami
|
| 19 |
+
|
| 20 |
+
from .._cli_utils import console, validate_env_structure
|
| 21 |
+
|
| 22 |
+
app = typer.Typer(help="Push an OpenEnv environment to Hugging Face Spaces")
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _validate_openenv_directory(directory: Path) -> tuple[str, dict]:
|
| 26 |
+
"""
|
| 27 |
+
Validate that the directory is an OpenEnv environment.
|
| 28 |
+
|
| 29 |
+
Returns:
|
| 30 |
+
Tuple of (env_name, manifest_data)
|
| 31 |
+
"""
|
| 32 |
+
# Use the comprehensive validation function
|
| 33 |
+
try:
|
| 34 |
+
warnings = validate_env_structure(directory)
|
| 35 |
+
for warning in warnings:
|
| 36 |
+
console.print(f"[bold yellow]ΓÜá[/bold yellow] {warning}")
|
| 37 |
+
except FileNotFoundError as e:
|
| 38 |
+
raise typer.BadParameter(f"Invalid OpenEnv environment structure: {e}") from e
|
| 39 |
+
|
| 40 |
+
# Load and validate manifest
|
| 41 |
+
manifest_path = directory / "openenv.yaml"
|
| 42 |
+
try:
|
| 43 |
+
with open(manifest_path, "r") as f:
|
| 44 |
+
manifest = yaml.safe_load(f)
|
| 45 |
+
except Exception as e:
|
| 46 |
+
raise typer.BadParameter(f"Failed to parse openenv.yaml: {e}") from e
|
| 47 |
+
|
| 48 |
+
if not isinstance(manifest, dict):
|
| 49 |
+
raise typer.BadParameter("openenv.yaml must be a YAML dictionary")
|
| 50 |
+
|
| 51 |
+
env_name = manifest.get("name")
|
| 52 |
+
if not env_name:
|
| 53 |
+
raise typer.BadParameter("openenv.yaml must contain a 'name' field")
|
| 54 |
+
|
| 55 |
+
return env_name, manifest
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def _ensure_hf_authenticated() -> str:
|
| 59 |
+
"""
|
| 60 |
+
Ensure user is authenticated with Hugging Face.
|
| 61 |
+
|
| 62 |
+
Returns:
|
| 63 |
+
Username of authenticated user
|
| 64 |
+
"""
|
| 65 |
+
try:
|
| 66 |
+
# Try to get current user
|
| 67 |
+
user_info = whoami()
|
| 68 |
+
# Handle both dict and object return types
|
| 69 |
+
if isinstance(user_info, dict):
|
| 70 |
+
username = (
|
| 71 |
+
user_info.get("name")
|
| 72 |
+
or user_info.get("fullname")
|
| 73 |
+
or user_info.get("username")
|
| 74 |
+
)
|
| 75 |
+
else:
|
| 76 |
+
# If it's an object, try to get name attribute
|
| 77 |
+
username = (
|
| 78 |
+
getattr(user_info, "name", None)
|
| 79 |
+
or getattr(user_info, "fullname", None)
|
| 80 |
+
or getattr(user_info, "username", None)
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
if not username:
|
| 84 |
+
raise ValueError("Could not extract username from whoami response")
|
| 85 |
+
|
| 86 |
+
console.print(f"[bold green]Γ£ô[/bold green] Authenticated as: {username}")
|
| 87 |
+
return username
|
| 88 |
+
except Exception:
|
| 89 |
+
# Not authenticated, prompt for login
|
| 90 |
+
console.print(
|
| 91 |
+
"[bold yellow]Not authenticated with Hugging Face. Please login...[/bold yellow]"
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
try:
|
| 95 |
+
login()
|
| 96 |
+
# Verify login worked
|
| 97 |
+
user_info = whoami()
|
| 98 |
+
# Handle both dict and object return types
|
| 99 |
+
if isinstance(user_info, dict):
|
| 100 |
+
username = (
|
| 101 |
+
user_info.get("name")
|
| 102 |
+
or user_info.get("fullname")
|
| 103 |
+
or user_info.get("username")
|
| 104 |
+
)
|
| 105 |
+
else:
|
| 106 |
+
username = (
|
| 107 |
+
getattr(user_info, "name", None)
|
| 108 |
+
or getattr(user_info, "fullname", None)
|
| 109 |
+
or getattr(user_info, "username", None)
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
if not username:
|
| 113 |
+
raise ValueError("Could not extract username from whoami response")
|
| 114 |
+
|
| 115 |
+
console.print(f"[bold green]Γ£ô[/bold green] Authenticated as: {username}")
|
| 116 |
+
return username
|
| 117 |
+
except Exception as e:
|
| 118 |
+
raise typer.BadParameter(
|
| 119 |
+
f"Hugging Face authentication failed: {e}. Please run login manually."
|
| 120 |
+
) from e
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def _prepare_staging_directory(
|
| 124 |
+
env_dir: Path,
|
| 125 |
+
env_name: str,
|
| 126 |
+
staging_dir: Path,
|
| 127 |
+
base_image: str | None = None,
|
| 128 |
+
enable_interface: bool = True,
|
| 129 |
+
) -> None:
|
| 130 |
+
"""
|
| 131 |
+
Prepare files for deployment.
|
| 132 |
+
|
| 133 |
+
This includes:
|
| 134 |
+
- Copying necessary files
|
| 135 |
+
- Modifying Dockerfile to optionally enable web interface and update base image
|
| 136 |
+
- Ensuring README has proper HF frontmatter (if interface enabled)
|
| 137 |
+
"""
|
| 138 |
+
# Create staging directory structure
|
| 139 |
+
staging_dir.mkdir(parents=True, exist_ok=True)
|
| 140 |
+
|
| 141 |
+
# Copy all files from env directory
|
| 142 |
+
for item in env_dir.iterdir():
|
| 143 |
+
# Skip hidden files and common ignore patterns
|
| 144 |
+
if item.name.startswith(".") or item.name in ["__pycache__", ".git"]:
|
| 145 |
+
continue
|
| 146 |
+
|
| 147 |
+
dest = staging_dir / item.name
|
| 148 |
+
if item.is_dir():
|
| 149 |
+
shutil.copytree(item, dest, dirs_exist_ok=True)
|
| 150 |
+
else:
|
| 151 |
+
shutil.copy2(item, dest)
|
| 152 |
+
|
| 153 |
+
# Ensure Dockerfile is at repository root (required by Hugging Face)
|
| 154 |
+
dockerfile_server_path = staging_dir / "server" / "Dockerfile"
|
| 155 |
+
dockerfile_root_path = staging_dir / "Dockerfile"
|
| 156 |
+
dockerfile_path: Path | None = None
|
| 157 |
+
|
| 158 |
+
if dockerfile_server_path.exists():
|
| 159 |
+
if dockerfile_root_path.exists():
|
| 160 |
+
dockerfile_root_path.unlink()
|
| 161 |
+
dockerfile_server_path.rename(dockerfile_root_path)
|
| 162 |
+
console.print(
|
| 163 |
+
"[bold cyan]Moved Dockerfile to repository root for deployment[/bold cyan]"
|
| 164 |
+
)
|
| 165 |
+
dockerfile_path = dockerfile_root_path
|
| 166 |
+
elif dockerfile_root_path.exists():
|
| 167 |
+
dockerfile_path = dockerfile_root_path
|
| 168 |
+
|
| 169 |
+
# Modify Dockerfile to optionally enable web interface and update base image
|
| 170 |
+
if dockerfile_path and dockerfile_path.exists():
|
| 171 |
+
dockerfile_content = dockerfile_path.read_text()
|
| 172 |
+
lines = dockerfile_content.split("\n")
|
| 173 |
+
new_lines = []
|
| 174 |
+
cmd_found = False
|
| 175 |
+
base_image_updated = False
|
| 176 |
+
web_interface_env_exists = "ENABLE_WEB_INTERFACE" in dockerfile_content
|
| 177 |
+
last_instruction = None
|
| 178 |
+
|
| 179 |
+
for line in lines:
|
| 180 |
+
stripped = line.strip()
|
| 181 |
+
token = stripped.split(maxsplit=1)[0] if stripped else ""
|
| 182 |
+
current_instruction = token.upper()
|
| 183 |
+
|
| 184 |
+
is_healthcheck_continuation = last_instruction == "HEALTHCHECK"
|
| 185 |
+
|
| 186 |
+
# Update base image if specified
|
| 187 |
+
if base_image and stripped.startswith("FROM") and not base_image_updated:
|
| 188 |
+
new_lines.append(f"FROM {base_image}")
|
| 189 |
+
base_image_updated = True
|
| 190 |
+
last_instruction = "FROM"
|
| 191 |
+
continue
|
| 192 |
+
|
| 193 |
+
if (
|
| 194 |
+
stripped.startswith("CMD")
|
| 195 |
+
and not cmd_found
|
| 196 |
+
and not web_interface_env_exists
|
| 197 |
+
and enable_interface
|
| 198 |
+
and not is_healthcheck_continuation
|
| 199 |
+
):
|
| 200 |
+
new_lines.append("ENV ENABLE_WEB_INTERFACE=true")
|
| 201 |
+
cmd_found = True
|
| 202 |
+
|
| 203 |
+
new_lines.append(line)
|
| 204 |
+
|
| 205 |
+
if current_instruction:
|
| 206 |
+
last_instruction = current_instruction
|
| 207 |
+
|
| 208 |
+
if not cmd_found and not web_interface_env_exists and enable_interface:
|
| 209 |
+
new_lines.append("ENV ENABLE_WEB_INTERFACE=true")
|
| 210 |
+
|
| 211 |
+
if base_image and not base_image_updated:
|
| 212 |
+
new_lines.insert(0, f"FROM {base_image}")
|
| 213 |
+
|
| 214 |
+
dockerfile_path.write_text("\n".join(new_lines))
|
| 215 |
+
|
| 216 |
+
changes = []
|
| 217 |
+
if base_image and base_image_updated:
|
| 218 |
+
changes.append("updated base image")
|
| 219 |
+
if enable_interface and not web_interface_env_exists:
|
| 220 |
+
changes.append("enabled web interface")
|
| 221 |
+
if changes:
|
| 222 |
+
console.print(
|
| 223 |
+
f"[bold green]Γ£ô[/bold green] Updated Dockerfile: {', '.join(changes)}"
|
| 224 |
+
)
|
| 225 |
+
else:
|
| 226 |
+
console.print(
|
| 227 |
+
"[bold yellow]ΓÜá[/bold yellow] No Dockerfile found at server/Dockerfile"
|
| 228 |
+
)
|
| 229 |
+
|
| 230 |
+
# Ensure README has proper HF frontmatter (only if interface enabled)
|
| 231 |
+
if enable_interface:
|
| 232 |
+
readme_path = staging_dir / "README.md"
|
| 233 |
+
if readme_path.exists():
|
| 234 |
+
readme_content = readme_path.read_text()
|
| 235 |
+
if "base_path: /web" not in readme_content:
|
| 236 |
+
# Check if frontmatter exists
|
| 237 |
+
if readme_content.startswith("---"):
|
| 238 |
+
# Add base_path to existing frontmatter
|
| 239 |
+
lines = readme_content.split("\n")
|
| 240 |
+
new_lines = []
|
| 241 |
+
_in_frontmatter = True
|
| 242 |
+
for i, line in enumerate(lines):
|
| 243 |
+
new_lines.append(line)
|
| 244 |
+
if line.strip() == "---" and i > 0:
|
| 245 |
+
# End of frontmatter, add base_path before this line
|
| 246 |
+
if "base_path:" not in "\n".join(new_lines):
|
| 247 |
+
new_lines.insert(-1, "base_path: /web")
|
| 248 |
+
_in_frontmatter = False
|
| 249 |
+
readme_path.write_text("\n".join(new_lines))
|
| 250 |
+
else:
|
| 251 |
+
# No frontmatter, add it
|
| 252 |
+
frontmatter = f"""---
|
| 253 |
+
title: {env_name.replace("_", " ").title()} Environment Server
|
| 254 |
+
emoji: 🔊
|
| 255 |
+
colorFrom: '#00C9FF'
|
| 256 |
+
colorTo: '#1B2845'
|
| 257 |
+
sdk: docker
|
| 258 |
+
pinned: false
|
| 259 |
+
app_port: 8000
|
| 260 |
+
base_path: /web
|
| 261 |
+
tags:
|
| 262 |
+
- openenv
|
| 263 |
+
---
|
| 264 |
+
|
| 265 |
+
"""
|
| 266 |
+
readme_path.write_text(frontmatter + readme_content)
|
| 267 |
+
console.print(
|
| 268 |
+
"[bold green]Γ£ô[/bold green] Updated README with HF Space frontmatter"
|
| 269 |
+
)
|
| 270 |
+
else:
|
| 271 |
+
console.print("[bold yellow]ΓÜá[/bold yellow] No README.md found")
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def _create_hf_space(
|
| 275 |
+
repo_id: str,
|
| 276 |
+
api: HfApi,
|
| 277 |
+
private: bool = False,
|
| 278 |
+
) -> None:
|
| 279 |
+
"""Create a Hugging Face Space if it doesn't exist."""
|
| 280 |
+
console.print(f"[bold cyan]Creating/verifying space: {repo_id}[/bold cyan]")
|
| 281 |
+
|
| 282 |
+
try:
|
| 283 |
+
api.create_repo(
|
| 284 |
+
repo_id=repo_id,
|
| 285 |
+
repo_type="space",
|
| 286 |
+
space_sdk="docker",
|
| 287 |
+
private=private,
|
| 288 |
+
exist_ok=True,
|
| 289 |
+
)
|
| 290 |
+
console.print(f"[bold green]Γ£ô[/bold green] Space {repo_id} is ready")
|
| 291 |
+
except Exception as e:
|
| 292 |
+
# Space might already exist, which is okay with exist_ok=True
|
| 293 |
+
# But if there's another error, log it
|
| 294 |
+
console.print(f"[bold yellow]ΓÜá[/bold yellow] Space creation: {e}")
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
def _upload_to_hf_space(
|
| 298 |
+
repo_id: str,
|
| 299 |
+
staging_dir: Path,
|
| 300 |
+
api: HfApi,
|
| 301 |
+
private: bool = False,
|
| 302 |
+
) -> None:
|
| 303 |
+
"""Upload files to Hugging Face Space."""
|
| 304 |
+
console.print(f"[bold cyan]Uploading files to {repo_id}...[/bold cyan]")
|
| 305 |
+
|
| 306 |
+
try:
|
| 307 |
+
api.upload_folder(
|
| 308 |
+
folder_path=str(staging_dir),
|
| 309 |
+
repo_id=repo_id,
|
| 310 |
+
repo_type="space",
|
| 311 |
+
ignore_patterns=[".git", "__pycache__", "*.pyc"],
|
| 312 |
+
)
|
| 313 |
+
console.print("[bold green]Γ£ô[/bold green] Upload completed successfully")
|
| 314 |
+
console.print(
|
| 315 |
+
f"[bold]Space URL:[/bold] https://huggingface.co/spaces/{repo_id}"
|
| 316 |
+
)
|
| 317 |
+
except Exception as e:
|
| 318 |
+
console.print(f"[bold red]Γ£ù[/bold red] Upload failed: {e}")
|
| 319 |
+
raise typer.Exit(1) from e
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
@app.command()
|
| 323 |
+
def push(
|
| 324 |
+
directory: Annotated[
|
| 325 |
+
str | None,
|
| 326 |
+
typer.Argument(
|
| 327 |
+
help="Directory containing the OpenEnv environment (default: current directory)"
|
| 328 |
+
),
|
| 329 |
+
] = None,
|
| 330 |
+
repo_id: Annotated[
|
| 331 |
+
str | None,
|
| 332 |
+
typer.Option(
|
| 333 |
+
"--repo-id",
|
| 334 |
+
"-r",
|
| 335 |
+
help="Repository ID in format 'username/repo-name' (defaults to 'username/env-name' from openenv.yaml)",
|
| 336 |
+
),
|
| 337 |
+
] = None,
|
| 338 |
+
base_image: Annotated[
|
| 339 |
+
str | None,
|
| 340 |
+
typer.Option(
|
| 341 |
+
"--base-image",
|
| 342 |
+
"-b",
|
| 343 |
+
help="Base Docker image to use (overrides Dockerfile FROM)",
|
| 344 |
+
),
|
| 345 |
+
] = None,
|
| 346 |
+
interface: Annotated[
|
| 347 |
+
bool,
|
| 348 |
+
typer.Option(
|
| 349 |
+
"--interface",
|
| 350 |
+
help="Enable web interface (default: True if no registry specified)",
|
| 351 |
+
),
|
| 352 |
+
] = None,
|
| 353 |
+
no_interface: Annotated[
|
| 354 |
+
bool,
|
| 355 |
+
typer.Option(
|
| 356 |
+
"--no-interface",
|
| 357 |
+
help="Disable web interface",
|
| 358 |
+
),
|
| 359 |
+
] = False,
|
| 360 |
+
registry: Annotated[
|
| 361 |
+
str | None,
|
| 362 |
+
typer.Option(
|
| 363 |
+
"--registry",
|
| 364 |
+
help="Custom registry URL (e.g., docker.io/username). Disables web interface by default.",
|
| 365 |
+
),
|
| 366 |
+
] = None,
|
| 367 |
+
private: Annotated[
|
| 368 |
+
bool,
|
| 369 |
+
typer.Option(
|
| 370 |
+
"--private",
|
| 371 |
+
help="Deploy the space as private",
|
| 372 |
+
),
|
| 373 |
+
] = False,
|
| 374 |
+
) -> None:
|
| 375 |
+
"""
|
| 376 |
+
Push an OpenEnv environment to Hugging Face Spaces or a custom Docker registry.
|
| 377 |
+
|
| 378 |
+
This command:
|
| 379 |
+
1. Validates that the directory is an OpenEnv environment (openenv.yaml present)
|
| 380 |
+
2. Builds and pushes to Hugging Face Spaces or custom Docker registry
|
| 381 |
+
3. Optionally enables web interface for deployment
|
| 382 |
+
|
| 383 |
+
The web interface is enabled by default when pushing to HuggingFace Spaces,
|
| 384 |
+
but disabled by default when pushing to a custom Docker registry.
|
| 385 |
+
|
| 386 |
+
Examples:
|
| 387 |
+
# Push to HuggingFace Spaces from current directory (web interface enabled)
|
| 388 |
+
$ cd my_env
|
| 389 |
+
$ openenv push
|
| 390 |
+
|
| 391 |
+
# Push to HuggingFace without web interface
|
| 392 |
+
$ openenv push --no-interface
|
| 393 |
+
|
| 394 |
+
# Push to Docker Hub
|
| 395 |
+
$ openenv push --registry docker.io/myuser
|
| 396 |
+
|
| 397 |
+
# Push to GitHub Container Registry
|
| 398 |
+
$ openenv push --registry ghcr.io/myorg
|
| 399 |
+
|
| 400 |
+
# Push to custom registry with web interface
|
| 401 |
+
$ openenv push --registry myregistry.io/path1/path2 --interface
|
| 402 |
+
|
| 403 |
+
# Push to specific HuggingFace repo
|
| 404 |
+
$ openenv push --repo-id my-org/my-env
|
| 405 |
+
|
| 406 |
+
# Push privately with custom base image
|
| 407 |
+
$ openenv push --private --base-image ghcr.io/meta-pytorch/openenv-base:latest
|
| 408 |
+
"""
|
| 409 |
+
# Handle interface flag logic
|
| 410 |
+
if no_interface and interface:
|
| 411 |
+
console.print(
|
| 412 |
+
"[bold red]Error:[/bold red] Cannot specify both --interface and --no-interface",
|
| 413 |
+
file=sys.stderr,
|
| 414 |
+
)
|
| 415 |
+
raise typer.Exit(1)
|
| 416 |
+
|
| 417 |
+
# Determine if web interface should be enabled
|
| 418 |
+
if no_interface:
|
| 419 |
+
enable_interface = False
|
| 420 |
+
elif interface is not None:
|
| 421 |
+
enable_interface = interface
|
| 422 |
+
elif registry is not None:
|
| 423 |
+
# Custom registry: disable interface by default
|
| 424 |
+
enable_interface = False
|
| 425 |
+
else:
|
| 426 |
+
# HuggingFace: enable interface by default
|
| 427 |
+
enable_interface = True
|
| 428 |
+
|
| 429 |
+
# Determine directory
|
| 430 |
+
if directory:
|
| 431 |
+
env_dir = Path(directory).resolve()
|
| 432 |
+
else:
|
| 433 |
+
env_dir = Path.cwd().resolve()
|
| 434 |
+
|
| 435 |
+
if not env_dir.exists() or not env_dir.is_dir():
|
| 436 |
+
raise typer.BadParameter(f"Directory does not exist: {env_dir}")
|
| 437 |
+
|
| 438 |
+
# Check for openenv.yaml to confirm this is an environment directory
|
| 439 |
+
openenv_yaml = env_dir / "openenv.yaml"
|
| 440 |
+
if not openenv_yaml.exists():
|
| 441 |
+
console.print(
|
| 442 |
+
f"[bold red]Error:[/bold red] Not an OpenEnv environment directory (missing openenv.yaml): {env_dir}",
|
| 443 |
+
)
|
| 444 |
+
console.print(
|
| 445 |
+
"[yellow]Hint:[/yellow] Run this command from the environment root directory",
|
| 446 |
+
)
|
| 447 |
+
raise typer.Exit(1)
|
| 448 |
+
|
| 449 |
+
# Validate OpenEnv environment
|
| 450 |
+
console.print(
|
| 451 |
+
f"[bold cyan]Validating OpenEnv environment in {env_dir}...[/bold cyan]"
|
| 452 |
+
)
|
| 453 |
+
env_name, manifest = _validate_openenv_directory(env_dir)
|
| 454 |
+
console.print(f"[bold green]Γ£ô[/bold green] Found OpenEnv environment: {env_name}")
|
| 455 |
+
|
| 456 |
+
# Handle custom registry push
|
| 457 |
+
if registry:
|
| 458 |
+
console.print("[bold cyan]Preparing to push to custom registry...[/bold cyan]")
|
| 459 |
+
if enable_interface:
|
| 460 |
+
console.print("[bold cyan]Web interface will be enabled[/bold cyan]")
|
| 461 |
+
|
| 462 |
+
# Import build functions
|
| 463 |
+
from .build import _build_docker_image, _push_docker_image
|
| 464 |
+
|
| 465 |
+
# Prepare build args for custom registry deployment
|
| 466 |
+
build_args = {}
|
| 467 |
+
if enable_interface:
|
| 468 |
+
build_args["ENABLE_WEB_INTERFACE"] = "true"
|
| 469 |
+
|
| 470 |
+
# Build Docker image from the environment directory
|
| 471 |
+
tag = f"{registry}/{env_name}"
|
| 472 |
+
console.print(f"[bold cyan]Building Docker image: {tag}[/bold cyan]")
|
| 473 |
+
|
| 474 |
+
success = _build_docker_image(
|
| 475 |
+
env_path=env_dir,
|
| 476 |
+
tag=tag,
|
| 477 |
+
build_args=build_args if build_args else None,
|
| 478 |
+
)
|
| 479 |
+
|
| 480 |
+
if not success:
|
| 481 |
+
console.print("[bold red]Γ£ù Docker build failed[/bold red]")
|
| 482 |
+
raise typer.Exit(1)
|
| 483 |
+
|
| 484 |
+
console.print("[bold green]Γ£ô Docker build successful[/bold green]")
|
| 485 |
+
|
| 486 |
+
# Push to registry
|
| 487 |
+
console.print(f"[bold cyan]Pushing to registry: {registry}[/bold cyan]")
|
| 488 |
+
|
| 489 |
+
success = _push_docker_image(
|
| 490 |
+
tag, registry=None
|
| 491 |
+
) # Tag already includes registry
|
| 492 |
+
|
| 493 |
+
if not success:
|
| 494 |
+
console.print("[bold red]Γ£ù Docker push failed[/bold red]")
|
| 495 |
+
raise typer.Exit(1)
|
| 496 |
+
|
| 497 |
+
console.print("\n[bold green]Γ£ô Deployment complete![/bold green]")
|
| 498 |
+
console.print(f"[bold]Image:[/bold] {tag}")
|
| 499 |
+
return
|
| 500 |
+
|
| 501 |
+
# Ensure authentication for HuggingFace
|
| 502 |
+
username = _ensure_hf_authenticated()
|
| 503 |
+
|
| 504 |
+
# Determine repo_id
|
| 505 |
+
if not repo_id:
|
| 506 |
+
repo_id = f"{username}/{env_name}"
|
| 507 |
+
|
| 508 |
+
# Validate repo_id format
|
| 509 |
+
if "/" not in repo_id or repo_id.count("/") != 1:
|
| 510 |
+
raise typer.BadParameter(
|
| 511 |
+
f"Invalid repo-id format: {repo_id}. Expected format: 'username/repo-name'"
|
| 512 |
+
)
|
| 513 |
+
|
| 514 |
+
# Initialize Hugging Face API
|
| 515 |
+
api = HfApi()
|
| 516 |
+
|
| 517 |
+
# Prepare staging directory
|
| 518 |
+
deployment_type = (
|
| 519 |
+
"with web interface" if enable_interface else "without web interface"
|
| 520 |
+
)
|
| 521 |
+
console.print(
|
| 522 |
+
f"[bold cyan]Preparing files for Hugging Face deployment ({deployment_type})...[/bold cyan]"
|
| 523 |
+
)
|
| 524 |
+
with tempfile.TemporaryDirectory() as tmpdir:
|
| 525 |
+
staging_dir = Path(tmpdir) / "staging"
|
| 526 |
+
_prepare_staging_directory(
|
| 527 |
+
env_dir,
|
| 528 |
+
env_name,
|
| 529 |
+
staging_dir,
|
| 530 |
+
base_image=base_image,
|
| 531 |
+
enable_interface=enable_interface,
|
| 532 |
+
)
|
| 533 |
+
|
| 534 |
+
# Create/verify space
|
| 535 |
+
_create_hf_space(repo_id, api, private=private)
|
| 536 |
+
|
| 537 |
+
# Upload files
|
| 538 |
+
_upload_to_hf_space(repo_id, staging_dir, api, private=private)
|
| 539 |
+
|
| 540 |
+
console.print("\n[bold green]Γ£ô Deployment complete![/bold green]")
|
| 541 |
+
console.print(f"Visit your space at: https://huggingface.co/spaces/{repo_id}")
|
src/openenv/cli/commands/serve.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Serve OpenEnv environments locally (TO BE IMPLEMENTED)."""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from typing import Annotated, Optional
|
| 13 |
+
|
| 14 |
+
import typer
|
| 15 |
+
|
| 16 |
+
from .._cli_utils import console
|
| 17 |
+
|
| 18 |
+
# Commands are registered in __main__.py
|
| 19 |
+
|
| 20 |
+
def serve(
|
| 21 |
+
env_path: Annotated[
|
| 22 |
+
Optional[str],
|
| 23 |
+
typer.Argument(
|
| 24 |
+
help="Path to the environment directory (default: current directory)"
|
| 25 |
+
),
|
| 26 |
+
] = None,
|
| 27 |
+
port: Annotated[
|
| 28 |
+
int,
|
| 29 |
+
typer.Option("--port", help="Port to serve on"),
|
| 30 |
+
] = 8000,
|
| 31 |
+
host: Annotated[
|
| 32 |
+
str,
|
| 33 |
+
typer.Option("--host", help="Host to bind to"),
|
| 34 |
+
] = "0.0.0.0",
|
| 35 |
+
reload: Annotated[
|
| 36 |
+
bool,
|
| 37 |
+
typer.Option("--reload", help="Enable auto-reload on code changes"),
|
| 38 |
+
] = False,
|
| 39 |
+
) -> None:
|
| 40 |
+
"""
|
| 41 |
+
Serve an OpenEnv environment locally.
|
| 42 |
+
|
| 43 |
+
TODO: This command is currently not implemented and has been deferred for later.
|
| 44 |
+
|
| 45 |
+
Planned functionality:
|
| 46 |
+
- Run environment server locally without Docker
|
| 47 |
+
- Support multiple deployment modes (local, notebook, cluster)
|
| 48 |
+
- Auto-reload for development
|
| 49 |
+
- Integration with environment's [project.scripts] entry point
|
| 50 |
+
|
| 51 |
+
For now, use Docker-based serving:
|
| 52 |
+
1. Build the environment: openenv build
|
| 53 |
+
2. Run the container: docker run -p 8000:8000 <image-name>
|
| 54 |
+
|
| 55 |
+
Or use uv directly:
|
| 56 |
+
uv run --project . server --port 8000
|
| 57 |
+
"""
|
| 58 |
+
console.print("[bold yellow]⚠ This command is not yet implemented[/bold yellow]\n")
|
| 59 |
+
|
| 60 |
+
console.print(
|
| 61 |
+
"The [bold cyan]openenv serve[/bold cyan] command has been deferred for later."
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
console.print("[bold]Alternative approaches:[/bold]\n")
|
| 65 |
+
|
| 66 |
+
console.print("[cyan]Option 1: Docker-based serving (recommended)[/cyan]")
|
| 67 |
+
console.print(" 1. Build the environment:")
|
| 68 |
+
console.print(" [dim]$ openenv build[/dim]")
|
| 69 |
+
console.print(" 2. Run the Docker container:")
|
| 70 |
+
console.print(
|
| 71 |
+
f" [dim]$ docker run -p {port}:{port} openenv-<env-name>:latest[/dim]\n"
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
console.print("[cyan]Option 2: Direct execution with uv[/cyan]")
|
| 75 |
+
|
| 76 |
+
# Determine environment path
|
| 77 |
+
if env_path is None:
|
| 78 |
+
env_path_obj = Path.cwd()
|
| 79 |
+
else:
|
| 80 |
+
env_path_obj = Path(env_path)
|
| 81 |
+
|
| 82 |
+
# Check for openenv.yaml
|
| 83 |
+
openenv_yaml = env_path_obj / "openenv.yaml"
|
| 84 |
+
if openenv_yaml.exists():
|
| 85 |
+
console.print(" From your environment directory:")
|
| 86 |
+
console.print(f" [dim]$ cd {env_path_obj}[/dim]")
|
| 87 |
+
console.print(f" [dim]$ uv run --project . server --port {port}[/dim]\n")
|
| 88 |
+
else:
|
| 89 |
+
console.print(" From an environment directory with pyproject.toml:")
|
| 90 |
+
console.print(f" [dim]$ uv run --project . server --port {port}[/dim]\n")
|
| 91 |
+
|
| 92 |
+
raise typer.Exit(0)
|
src/openenv/cli/commands/validate.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
OpenEnv validate command.
|
| 9 |
+
|
| 10 |
+
This module provides the 'openenv validate' command to check if environments
|
| 11 |
+
are properly configured for multi-mode deployment.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
|
| 16 |
+
import typer
|
| 17 |
+
|
| 18 |
+
from openenv.cli._validation import (
|
| 19 |
+
format_validation_report,
|
| 20 |
+
get_deployment_modes,
|
| 21 |
+
validate_multi_mode_deployment,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def validate(
|
| 26 |
+
env_path: str | None = typer.Argument(
|
| 27 |
+
None, help="Path to the environment directory (default: current directory)"
|
| 28 |
+
),
|
| 29 |
+
verbose: bool = typer.Option(
|
| 30 |
+
False, "--verbose", "-v", help="Show detailed information"
|
| 31 |
+
),
|
| 32 |
+
) -> None:
|
| 33 |
+
"""
|
| 34 |
+
Validate an environment for standardized structure and deployment readiness.
|
| 35 |
+
|
| 36 |
+
This command checks if an environment is properly configured with:
|
| 37 |
+
- Required files (pyproject.toml, openenv.yaml, server/app.py, etc.)
|
| 38 |
+
- Docker deployment support
|
| 39 |
+
- uv run server capability
|
| 40 |
+
- python -m module execution
|
| 41 |
+
|
| 42 |
+
Examples:
|
| 43 |
+
# Validate current directory (recommended)
|
| 44 |
+
$ cd my_env
|
| 45 |
+
$ openenv validate
|
| 46 |
+
|
| 47 |
+
# Validate with detailed output
|
| 48 |
+
$ openenv validate --verbose
|
| 49 |
+
|
| 50 |
+
# Validate specific environment
|
| 51 |
+
$ openenv validate envs/echo_env
|
| 52 |
+
"""
|
| 53 |
+
# Determine environment path (default to current directory)
|
| 54 |
+
if env_path is None:
|
| 55 |
+
env_path_obj = Path.cwd()
|
| 56 |
+
else:
|
| 57 |
+
env_path_obj = Path(env_path)
|
| 58 |
+
|
| 59 |
+
if not env_path_obj.exists():
|
| 60 |
+
typer.echo(f"Error: Path does not exist: {env_path_obj}", err=True)
|
| 61 |
+
raise typer.Exit(1)
|
| 62 |
+
|
| 63 |
+
if not env_path_obj.is_dir():
|
| 64 |
+
typer.echo(f"Error: Path is not a directory: {env_path_obj}", err=True)
|
| 65 |
+
raise typer.Exit(1)
|
| 66 |
+
|
| 67 |
+
# Check for openenv.yaml to confirm this is an environment directory
|
| 68 |
+
openenv_yaml = env_path_obj / "openenv.yaml"
|
| 69 |
+
if not openenv_yaml.exists():
|
| 70 |
+
typer.echo(
|
| 71 |
+
f"Error: Not an OpenEnv environment directory (missing openenv.yaml): {env_path_obj}",
|
| 72 |
+
err=True,
|
| 73 |
+
)
|
| 74 |
+
typer.echo(
|
| 75 |
+
"Hint: Run this command from the environment root directory or specify the path",
|
| 76 |
+
err=True,
|
| 77 |
+
)
|
| 78 |
+
raise typer.Exit(1)
|
| 79 |
+
|
| 80 |
+
env_name = env_path_obj.name
|
| 81 |
+
if env_name.endswith("_env"):
|
| 82 |
+
base_name = env_name[:-4]
|
| 83 |
+
else:
|
| 84 |
+
base_name = env_name
|
| 85 |
+
|
| 86 |
+
# Run validation
|
| 87 |
+
is_valid, issues = validate_multi_mode_deployment(env_path_obj)
|
| 88 |
+
|
| 89 |
+
# Show validation report
|
| 90 |
+
report = format_validation_report(base_name, is_valid, issues)
|
| 91 |
+
typer.echo(report)
|
| 92 |
+
|
| 93 |
+
# Show deployment modes if verbose
|
| 94 |
+
if verbose:
|
| 95 |
+
typer.echo("\nSupported deployment modes:")
|
| 96 |
+
modes = get_deployment_modes(env_path_obj)
|
| 97 |
+
for mode, supported in modes.items():
|
| 98 |
+
status = "[YES]" if supported else "[NO]"
|
| 99 |
+
typer.echo(f" {status} {mode}")
|
| 100 |
+
|
| 101 |
+
if is_valid:
|
| 102 |
+
typer.echo("\nUsage examples:")
|
| 103 |
+
typer.echo(f" cd {env_path_obj.name} && uv run server")
|
| 104 |
+
typer.echo(f" cd {env_path_obj.name} && openenv build")
|
| 105 |
+
typer.echo(f" cd {env_path_obj.name} && openenv push")
|
| 106 |
+
|
| 107 |
+
if not is_valid:
|
| 108 |
+
raise typer.Exit(1)
|
src/openenv/cli/templates/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""OpenEnv CLI templates package."""
|
src/openenv/cli/templates/openenv_env/.dockerignore
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.venv
|
| 2 |
+
.git
|
| 3 |
+
.gitignore
|
| 4 |
+
.env
|
| 5 |
+
__pycache__/
|
| 6 |
+
*.pyc
|
| 7 |
+
*.pyo
|
| 8 |
+
*.pyd
|
| 9 |
+
*.pyw
|
| 10 |
+
*.pyz
|
| 11 |
+
*.pywz
|
| 12 |
+
*.pyzw
|
| 13 |
+
*.pyzwz
|
| 14 |
+
|
| 15 |
+
|
src/openenv/cli/templates/openenv_env/README.md
ADDED
|
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: __ENV_TITLE_NAME__ Environment Server
|
| 3 |
+
emoji: __HF_EMOJI__
|
| 4 |
+
colorFrom: __HF_COLOR_FROM__
|
| 5 |
+
colorTo: __HF_COLOR_TO__
|
| 6 |
+
sdk: docker
|
| 7 |
+
pinned: false
|
| 8 |
+
app_port: 8000
|
| 9 |
+
base_path: /web
|
| 10 |
+
tags:
|
| 11 |
+
- openenv
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
# __ENV_TITLE_NAME__ Environment
|
| 15 |
+
|
| 16 |
+
A simple test environment that echoes back messages. Perfect for testing the env APIs as well as demonstrating environment usage patterns.
|
| 17 |
+
|
| 18 |
+
## Quick Start
|
| 19 |
+
|
| 20 |
+
The simplest way to use the __ENV_TITLE_NAME__ environment is through the `__ENV_CLASS_NAME__Env` class:
|
| 21 |
+
|
| 22 |
+
```python
|
| 23 |
+
from __ENV_NAME__ import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Env
|
| 24 |
+
|
| 25 |
+
try:
|
| 26 |
+
# Create environment from Docker image
|
| 27 |
+
__ENV_NAME__env = __ENV_CLASS_NAME__Env.from_docker_image("__ENV_NAME__-env:latest")
|
| 28 |
+
|
| 29 |
+
# Reset
|
| 30 |
+
result = __ENV_NAME__env.reset()
|
| 31 |
+
print(f"Reset: {result.observation.echoed_message}")
|
| 32 |
+
|
| 33 |
+
# Send multiple messages
|
| 34 |
+
messages = ["Hello, World!", "Testing echo", "Final message"]
|
| 35 |
+
|
| 36 |
+
for msg in messages:
|
| 37 |
+
result = __ENV_NAME__env.step(__ENV_CLASS_NAME__Action(message=msg))
|
| 38 |
+
print(f"Sent: '{msg}'")
|
| 39 |
+
print(f" → Echoed: '{result.observation.echoed_message}'")
|
| 40 |
+
print(f" → Length: {result.observation.message_length}")
|
| 41 |
+
print(f" → Reward: {result.reward}")
|
| 42 |
+
|
| 43 |
+
finally:
|
| 44 |
+
# Always clean up
|
| 45 |
+
__ENV_NAME__env.close()
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
That's it! The `__ENV_CLASS_NAME__Env.from_docker_image()` method handles:
|
| 49 |
+
- Starting the Docker container
|
| 50 |
+
- Waiting for the server to be ready
|
| 51 |
+
- Connecting to the environment
|
| 52 |
+
- Container cleanup when you call `close()`
|
| 53 |
+
|
| 54 |
+
## Building the Docker Image
|
| 55 |
+
|
| 56 |
+
Before using the environment, you need to build the Docker image:
|
| 57 |
+
|
| 58 |
+
```bash
|
| 59 |
+
# From project root
|
| 60 |
+
docker build -t __ENV_NAME__-env:latest -f server/Dockerfile .
|
| 61 |
+
```
|
| 62 |
+
|
| 63 |
+
## Deploying to Hugging Face Spaces
|
| 64 |
+
|
| 65 |
+
You can easily deploy your OpenEnv environment to Hugging Face Spaces using the `openenv push` command:
|
| 66 |
+
|
| 67 |
+
```bash
|
| 68 |
+
# From the environment directory (where openenv.yaml is located)
|
| 69 |
+
openenv push
|
| 70 |
+
|
| 71 |
+
# Or specify options
|
| 72 |
+
openenv push --namespace my-org --private
|
| 73 |
+
```
|
| 74 |
+
|
| 75 |
+
The `openenv push` command will:
|
| 76 |
+
1. Validate that the directory is an OpenEnv environment (checks for `openenv.yaml`)
|
| 77 |
+
2. Prepare a custom build for Hugging Face Docker space (enables web interface)
|
| 78 |
+
3. Upload to Hugging Face (ensuring you're logged in)
|
| 79 |
+
|
| 80 |
+
### Prerequisites
|
| 81 |
+
|
| 82 |
+
- Authenticate with Hugging Face: The command will prompt for login if not already authenticated
|
| 83 |
+
|
| 84 |
+
### Options
|
| 85 |
+
|
| 86 |
+
- `--directory`, `-d`: Directory containing the OpenEnv environment (defaults to current directory)
|
| 87 |
+
- `--repo-id`, `-r`: Repository ID in format 'username/repo-name' (defaults to 'username/env-name' from openenv.yaml)
|
| 88 |
+
- `--base-image`, `-b`: Base Docker image to use (overrides Dockerfile FROM)
|
| 89 |
+
- `--private`: Deploy the space as private (default: public)
|
| 90 |
+
|
| 91 |
+
### Examples
|
| 92 |
+
|
| 93 |
+
```bash
|
| 94 |
+
# Push to your personal namespace (defaults to username/env-name from openenv.yaml)
|
| 95 |
+
openenv push
|
| 96 |
+
|
| 97 |
+
# Push to a specific repository
|
| 98 |
+
openenv push --repo-id my-org/my-env
|
| 99 |
+
|
| 100 |
+
# Push with a custom base image
|
| 101 |
+
openenv push --base-image ghcr.io/meta-pytorch/openenv-base:latest
|
| 102 |
+
|
| 103 |
+
# Push as a private space
|
| 104 |
+
openenv push --private
|
| 105 |
+
|
| 106 |
+
# Combine options
|
| 107 |
+
openenv push --repo-id my-org/my-env --base-image custom-base:latest --private
|
| 108 |
+
```
|
| 109 |
+
|
| 110 |
+
After deployment, your space will be available at:
|
| 111 |
+
`https://huggingface.co/spaces/<repo-id>`
|
| 112 |
+
|
| 113 |
+
The deployed space includes:
|
| 114 |
+
- **Web Interface** at `/web` - Interactive UI for exploring the environment
|
| 115 |
+
- **API Documentation** at `/docs` - Full OpenAPI/Swagger interface
|
| 116 |
+
- **Health Check** at `/health` - Container health monitoring
|
| 117 |
+
- **WebSocket** at `/ws` - Persistent session endpoint for low-latency interactions
|
| 118 |
+
|
| 119 |
+
## Environment Details
|
| 120 |
+
|
| 121 |
+
### Action
|
| 122 |
+
**__ENV_CLASS_NAME__Action**: Contains a single field
|
| 123 |
+
- `message` (str) - The message to echo back
|
| 124 |
+
|
| 125 |
+
### Observation
|
| 126 |
+
**__ENV_CLASS_NAME__Observation**: Contains the echo response and metadata
|
| 127 |
+
- `echoed_message` (str) - The message echoed back
|
| 128 |
+
- `message_length` (int) - Length of the message
|
| 129 |
+
- `reward` (float) - Reward based on message length (length × 0.1)
|
| 130 |
+
- `done` (bool) - Always False for echo environment
|
| 131 |
+
- `metadata` (dict) - Additional info like step count
|
| 132 |
+
|
| 133 |
+
### Reward
|
| 134 |
+
The reward is calculated as: `message_length × 0.1`
|
| 135 |
+
- "Hi" → reward: 0.2
|
| 136 |
+
- "Hello, World!" → reward: 1.3
|
| 137 |
+
- Empty message → reward: 0.0
|
| 138 |
+
|
| 139 |
+
## Advanced Usage
|
| 140 |
+
|
| 141 |
+
### Connecting to an Existing Server
|
| 142 |
+
|
| 143 |
+
If you already have a __ENV_TITLE_NAME__ environment server running, you can connect directly:
|
| 144 |
+
|
| 145 |
+
```python
|
| 146 |
+
from __ENV_NAME__ import __ENV_CLASS_NAME__Env
|
| 147 |
+
|
| 148 |
+
# Connect to existing server
|
| 149 |
+
__ENV_NAME__env = __ENV_CLASS_NAME__Env(base_url="<ENV_HTTP_URL_HERE>")
|
| 150 |
+
|
| 151 |
+
# Use as normal
|
| 152 |
+
result = __ENV_NAME__env.reset()
|
| 153 |
+
result = __ENV_NAME__env.step(__ENV_CLASS_NAME__Action(message="Hello!"))
|
| 154 |
+
```
|
| 155 |
+
|
| 156 |
+
Note: When connecting to an existing server, `__ENV_NAME__env.close()` will NOT stop the server.
|
| 157 |
+
|
| 158 |
+
### Using the Context Manager
|
| 159 |
+
|
| 160 |
+
The client supports context manager usage for automatic connection management:
|
| 161 |
+
|
| 162 |
+
```python
|
| 163 |
+
from __ENV_NAME__ import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Env
|
| 164 |
+
|
| 165 |
+
# Connect with context manager (auto-connects and closes)
|
| 166 |
+
with __ENV_CLASS_NAME__Env(base_url="http://localhost:8000") as env:
|
| 167 |
+
result = env.reset()
|
| 168 |
+
print(f"Reset: {result.observation.echoed_message}")
|
| 169 |
+
# Multiple steps with low latency
|
| 170 |
+
for msg in ["Hello", "World", "!"]:
|
| 171 |
+
result = env.step(__ENV_CLASS_NAME__Action(message=msg))
|
| 172 |
+
print(f"Echoed: {result.observation.echoed_message}")
|
| 173 |
+
```
|
| 174 |
+
|
| 175 |
+
The client uses WebSocket connections for:
|
| 176 |
+
- **Lower latency**: No HTTP connection overhead per request
|
| 177 |
+
- **Persistent session**: Server maintains your environment state
|
| 178 |
+
- **Efficient for episodes**: Better for many sequential steps
|
| 179 |
+
|
| 180 |
+
### Concurrent WebSocket Sessions
|
| 181 |
+
|
| 182 |
+
The server supports multiple concurrent WebSocket connections. To enable this,
|
| 183 |
+
modify `server/app.py` to use factory mode:
|
| 184 |
+
|
| 185 |
+
```python
|
| 186 |
+
# In server/app.py - use factory mode for concurrent sessions
|
| 187 |
+
app = create_app(
|
| 188 |
+
__ENV_CLASS_NAME__Environment, # Pass class, not instance
|
| 189 |
+
__ENV_CLASS_NAME__Action,
|
| 190 |
+
__ENV_CLASS_NAME__Observation,
|
| 191 |
+
max_concurrent_envs=4, # Allow 4 concurrent sessions
|
| 192 |
+
)
|
| 193 |
+
```
|
| 194 |
+
|
| 195 |
+
Then multiple clients can connect simultaneously:
|
| 196 |
+
|
| 197 |
+
```python
|
| 198 |
+
from __ENV_NAME__ import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Env
|
| 199 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 200 |
+
|
| 201 |
+
def run_episode(client_id: int):
|
| 202 |
+
with __ENV_CLASS_NAME__Env(base_url="http://localhost:8000") as env:
|
| 203 |
+
result = env.reset()
|
| 204 |
+
for i in range(10):
|
| 205 |
+
result = env.step(__ENV_CLASS_NAME__Action(message=f"Client {client_id}, step {i}"))
|
| 206 |
+
return client_id, result.observation.message_length
|
| 207 |
+
|
| 208 |
+
# Run 4 episodes concurrently
|
| 209 |
+
with ThreadPoolExecutor(max_workers=4) as executor:
|
| 210 |
+
results = list(executor.map(run_episode, range(4)))
|
| 211 |
+
```
|
| 212 |
+
|
| 213 |
+
## Development & Testing
|
| 214 |
+
|
| 215 |
+
### Direct Environment Testing
|
| 216 |
+
|
| 217 |
+
Test the environment logic directly without starting the HTTP server:
|
| 218 |
+
|
| 219 |
+
```bash
|
| 220 |
+
# From the server directory
|
| 221 |
+
python3 server/__ENV_NAME___environment.py
|
| 222 |
+
```
|
| 223 |
+
|
| 224 |
+
This verifies that:
|
| 225 |
+
- Environment resets correctly
|
| 226 |
+
- Step executes actions properly
|
| 227 |
+
- State tracking works
|
| 228 |
+
- Rewards are calculated correctly
|
| 229 |
+
|
| 230 |
+
### Running Locally
|
| 231 |
+
|
| 232 |
+
Run the server locally for development:
|
| 233 |
+
|
| 234 |
+
```bash
|
| 235 |
+
uvicorn server.app:app --reload
|
| 236 |
+
```
|
| 237 |
+
|
| 238 |
+
## Project Structure
|
| 239 |
+
|
| 240 |
+
```
|
| 241 |
+
__ENV_NAME__/
|
| 242 |
+
├── .dockerignore # Docker build exclusions
|
| 243 |
+
├── __init__.py # Module exports
|
| 244 |
+
├── README.md # This file
|
| 245 |
+
├── openenv.yaml # OpenEnv manifest
|
| 246 |
+
├── pyproject.toml # Project metadata and dependencies
|
| 247 |
+
├── uv.lock # Locked dependencies (generated)
|
| 248 |
+
├── client.py # __ENV_CLASS_NAME__Env client
|
| 249 |
+
├── models.py # Action and Observation models
|
| 250 |
+
└── server/
|
| 251 |
+
├── __init__.py # Server module exports
|
| 252 |
+
├── __ENV_NAME___environment.py # Core environment logic
|
| 253 |
+
├── app.py # FastAPI application (HTTP + WebSocket endpoints)
|
| 254 |
+
└── Dockerfile # Container image definition
|
| 255 |
+
```
|
src/openenv/cli/templates/openenv_env/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""__ENV_TITLE_NAME__ Environment."""
|
| 8 |
+
|
| 9 |
+
from .client import __ENV_CLASS_NAME__Env
|
| 10 |
+
from .models import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation
|
| 11 |
+
|
| 12 |
+
__all__ = [
|
| 13 |
+
"__ENV_CLASS_NAME__Action",
|
| 14 |
+
"__ENV_CLASS_NAME__Observation",
|
| 15 |
+
"__ENV_CLASS_NAME__Env",
|
| 16 |
+
]
|
src/openenv/cli/templates/openenv_env/client.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""__ENV_TITLE_NAME__ Environment Client."""
|
| 8 |
+
|
| 9 |
+
from typing import Dict
|
| 10 |
+
|
| 11 |
+
from openenv.core.client_types import StepResult
|
| 12 |
+
from openenv.core.env_server.types import State
|
| 13 |
+
from openenv.core import EnvClient
|
| 14 |
+
|
| 15 |
+
from .models import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class __ENV_CLASS_NAME__Env(
|
| 19 |
+
EnvClient[__ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation]
|
| 20 |
+
):
|
| 21 |
+
"""
|
| 22 |
+
Client for the __ENV_TITLE_NAME__ Environment.
|
| 23 |
+
|
| 24 |
+
This client maintains a persistent WebSocket connection to the environment server,
|
| 25 |
+
enabling efficient multi-step interactions with lower latency.
|
| 26 |
+
Each client instance has its own dedicated environment session on the server.
|
| 27 |
+
|
| 28 |
+
Example:
|
| 29 |
+
>>> # Connect to a running server
|
| 30 |
+
>>> with __ENV_CLASS_NAME__Env(base_url="http://localhost:8000") as client:
|
| 31 |
+
... result = client.reset()
|
| 32 |
+
... print(result.observation.echoed_message)
|
| 33 |
+
...
|
| 34 |
+
... result = client.step(__ENV_CLASS_NAME__Action(message="Hello!"))
|
| 35 |
+
... print(result.observation.echoed_message)
|
| 36 |
+
|
| 37 |
+
Example with Docker:
|
| 38 |
+
>>> # Automatically start container and connect
|
| 39 |
+
>>> client = __ENV_CLASS_NAME__Env.from_docker_image("__ENV_NAME__-env:latest")
|
| 40 |
+
>>> try:
|
| 41 |
+
... result = client.reset()
|
| 42 |
+
... result = client.step(__ENV_CLASS_NAME__Action(message="Test"))
|
| 43 |
+
... finally:
|
| 44 |
+
... client.close()
|
| 45 |
+
"""
|
| 46 |
+
|
| 47 |
+
def _step_payload(self, action: __ENV_CLASS_NAME__Action) -> Dict:
|
| 48 |
+
"""
|
| 49 |
+
Convert __ENV_CLASS_NAME__Action to JSON payload for step message.
|
| 50 |
+
|
| 51 |
+
Args:
|
| 52 |
+
action: __ENV_CLASS_NAME__Action instance
|
| 53 |
+
|
| 54 |
+
Returns:
|
| 55 |
+
Dictionary representation suitable for JSON encoding
|
| 56 |
+
"""
|
| 57 |
+
return {
|
| 58 |
+
"message": action.message,
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
def _parse_result(self, payload: Dict) -> StepResult[__ENV_CLASS_NAME__Observation]:
|
| 62 |
+
"""
|
| 63 |
+
Parse server response into StepResult[__ENV_CLASS_NAME__Observation].
|
| 64 |
+
|
| 65 |
+
Args:
|
| 66 |
+
payload: JSON response data from server
|
| 67 |
+
|
| 68 |
+
Returns:
|
| 69 |
+
StepResult with __ENV_CLASS_NAME__Observation
|
| 70 |
+
"""
|
| 71 |
+
obs_data = payload.get("observation", {})
|
| 72 |
+
observation = __ENV_CLASS_NAME__Observation(
|
| 73 |
+
echoed_message=obs_data.get("echoed_message", ""),
|
| 74 |
+
message_length=obs_data.get("message_length", 0),
|
| 75 |
+
done=payload.get("done", False),
|
| 76 |
+
reward=payload.get("reward"),
|
| 77 |
+
metadata=obs_data.get("metadata", {}),
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
return StepResult(
|
| 81 |
+
observation=observation,
|
| 82 |
+
reward=payload.get("reward"),
|
| 83 |
+
done=payload.get("done", False),
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
def _parse_state(self, payload: Dict) -> State:
|
| 87 |
+
"""
|
| 88 |
+
Parse server response into State object.
|
| 89 |
+
|
| 90 |
+
Args:
|
| 91 |
+
payload: JSON response from state request
|
| 92 |
+
|
| 93 |
+
Returns:
|
| 94 |
+
State object with episode_id and step_count
|
| 95 |
+
"""
|
| 96 |
+
return State(
|
| 97 |
+
episode_id=payload.get("episode_id"),
|
| 98 |
+
step_count=payload.get("step_count", 0),
|
| 99 |
+
)
|
src/openenv/cli/templates/openenv_env/models.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Data models for the __ENV_TITLE_NAME__ Environment.
|
| 9 |
+
|
| 10 |
+
The __ENV_NAME__ environment is a simple test environment that echoes back messages.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from pydantic import Field
|
| 14 |
+
|
| 15 |
+
from openenv.core.env_server.types import Action, Observation
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class __ENV_CLASS_NAME__Action(Action):
|
| 19 |
+
"""Action for the __ENV_TITLE_NAME__ environment - just a message to echo."""
|
| 20 |
+
|
| 21 |
+
message: str = Field(..., description="Message to echo back")
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class __ENV_CLASS_NAME__Observation(Observation):
|
| 25 |
+
"""Observation from the __ENV_TITLE_NAME__ environment - the echoed message."""
|
| 26 |
+
|
| 27 |
+
echoed_message: str = Field(default="", description="The echoed message")
|
| 28 |
+
message_length: int = Field(default=0, description="Length of the echoed message")
|
src/openenv/cli/templates/openenv_env/openenv.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
spec_version: 1
|
| 2 |
+
name: __ENV_NAME__
|
| 3 |
+
type: space
|
| 4 |
+
runtime: fastapi
|
| 5 |
+
app: server.app:app
|
| 6 |
+
port: 8000
|
| 7 |
+
|
src/openenv/cli/templates/openenv_env/pyproject.toml
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
[build-system]
|
| 8 |
+
requires = ["setuptools>=45", "wheel"]
|
| 9 |
+
build-backend = "setuptools.build_meta"
|
| 10 |
+
|
| 11 |
+
[project]
|
| 12 |
+
name = "openenv-__ENV_NAME__"
|
| 13 |
+
version = "0.1.0"
|
| 14 |
+
description = "__ENV_TITLE_NAME__ environment for OpenEnv"
|
| 15 |
+
requires-python = ">=3.10"
|
| 16 |
+
dependencies = [
|
| 17 |
+
# Core OpenEnv runtime (provides FastAPI server + HTTP client types)
|
| 18 |
+
# install from github
|
| 19 |
+
# "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git",
|
| 20 |
+
"openenv-core[core]>=0.2.0",
|
| 21 |
+
# Environment-specific dependencies
|
| 22 |
+
# Add all dependencies needed for your environment here
|
| 23 |
+
# Examples:
|
| 24 |
+
# "numpy>=1.19.0",
|
| 25 |
+
# "torch>=2.0.0",
|
| 26 |
+
# "gymnasium>=0.29.0",
|
| 27 |
+
# "openspiel>=1.0.0",
|
| 28 |
+
# "smolagents>=1.22.0,<2",
|
| 29 |
+
]
|
| 30 |
+
|
| 31 |
+
[project.optional-dependencies]
|
| 32 |
+
dev = [
|
| 33 |
+
"pytest>=8.0.0",
|
| 34 |
+
"pytest-cov>=4.0.0",
|
| 35 |
+
]
|
| 36 |
+
|
| 37 |
+
[project.scripts]
|
| 38 |
+
# Server entry point - enables running via: uv run --project . server
|
| 39 |
+
# or: python -m __ENV_NAME__.server.app
|
| 40 |
+
server = "__ENV_NAME__.server.app:main"
|
| 41 |
+
|
| 42 |
+
[tool.setuptools]
|
| 43 |
+
include-package-data = true
|
| 44 |
+
packages = ["__ENV_NAME__", "__ENV_NAME__.server"]
|
| 45 |
+
package-dir = { "__ENV_NAME__" = ".", "__ENV_NAME__.server" = "server" }
|
src/openenv/cli/templates/openenv_env/server/Dockerfile
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# Multi-stage build using openenv-base
|
| 8 |
+
# This Dockerfile is flexible and works for both:
|
| 9 |
+
# - In-repo environments (with local OpenEnv sources)
|
| 10 |
+
# - Standalone environments (with openenv from PyPI/Git)
|
| 11 |
+
# The build script (openenv build) handles context detection and sets appropriate build args.
|
| 12 |
+
|
| 13 |
+
ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
|
| 14 |
+
FROM ${BASE_IMAGE} AS builder
|
| 15 |
+
|
| 16 |
+
WORKDIR /app
|
| 17 |
+
|
| 18 |
+
# Ensure git is available (required for installing dependencies from VCS)
|
| 19 |
+
RUN apt-get update && \
|
| 20 |
+
apt-get install -y --no-install-recommends git && \
|
| 21 |
+
rm -rf /var/lib/apt/lists/*
|
| 22 |
+
|
| 23 |
+
# Build argument to control whether we're building standalone or in-repo
|
| 24 |
+
ARG BUILD_MODE=in-repo
|
| 25 |
+
ARG ENV_NAME=__ENV_NAME__
|
| 26 |
+
|
| 27 |
+
# Copy environment code (always at root of build context)
|
| 28 |
+
COPY . /app/env
|
| 29 |
+
|
| 30 |
+
# For in-repo builds, openenv is already vendored in the build context
|
| 31 |
+
# For standalone builds, openenv will be installed via pyproject.toml
|
| 32 |
+
WORKDIR /app/env
|
| 33 |
+
|
| 34 |
+
# Ensure uv is available (for local builds where base image lacks it)
|
| 35 |
+
RUN if ! command -v uv >/dev/null 2>&1; then \
|
| 36 |
+
curl -LsSf https://astral.sh/uv/install.sh | sh && \
|
| 37 |
+
mv /root/.local/bin/uv /usr/local/bin/uv && \
|
| 38 |
+
mv /root/.local/bin/uvx /usr/local/bin/uvx; \
|
| 39 |
+
fi
|
| 40 |
+
|
| 41 |
+
# Install dependencies using uv sync
|
| 42 |
+
# If uv.lock exists, use it; otherwise resolve on the fly
|
| 43 |
+
RUN --mount=type=cache,target=/root/.cache/uv \
|
| 44 |
+
if [ -f uv.lock ]; then \
|
| 45 |
+
uv sync --frozen --no-install-project --no-editable; \
|
| 46 |
+
else \
|
| 47 |
+
uv sync --no-install-project --no-editable; \
|
| 48 |
+
fi
|
| 49 |
+
|
| 50 |
+
RUN --mount=type=cache,target=/root/.cache/uv \
|
| 51 |
+
if [ -f uv.lock ]; then \
|
| 52 |
+
uv sync --frozen --no-editable; \
|
| 53 |
+
else \
|
| 54 |
+
uv sync --no-editable; \
|
| 55 |
+
fi
|
| 56 |
+
|
| 57 |
+
# Final runtime stage
|
| 58 |
+
FROM ${BASE_IMAGE}
|
| 59 |
+
|
| 60 |
+
WORKDIR /app
|
| 61 |
+
|
| 62 |
+
# Copy the virtual environment from builder
|
| 63 |
+
COPY --from=builder /app/env/.venv /app/.venv
|
| 64 |
+
|
| 65 |
+
# Copy the environment code
|
| 66 |
+
COPY --from=builder /app/env /app/env
|
| 67 |
+
|
| 68 |
+
# Set PATH to use the virtual environment
|
| 69 |
+
ENV PATH="/app/.venv/bin:$PATH"
|
| 70 |
+
|
| 71 |
+
# Set PYTHONPATH so imports work correctly
|
| 72 |
+
ENV PYTHONPATH="/app/env:$PYTHONPATH"
|
| 73 |
+
|
| 74 |
+
# Health check
|
| 75 |
+
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
|
| 76 |
+
CMD curl -f http://localhost:8000/health || exit 1
|
| 77 |
+
|
| 78 |
+
# Run the FastAPI server
|
| 79 |
+
# The module path is constructed to work with the /app/env structure
|
| 80 |
+
CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"]
|