Spaces:
Running
Running
Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +4 -0
- Dockerfile +75 -0
- README.md +177 -4
- __init__.py +14 -0
- assets/cartpole.png +3 -0
- assets/quadruped.png +3 -0
- client.py +375 -0
- envs/dm_control_env/README.md +167 -0
- envs/dm_control_env/__init__.py +14 -0
- envs/dm_control_env/assets/cartpole.png +3 -0
- envs/dm_control_env/assets/quadruped.png +3 -0
- envs/dm_control_env/client.py +375 -0
- envs/dm_control_env/examples/__init__.py +1 -0
- envs/dm_control_env/examples/cartpole_control.py +333 -0
- envs/dm_control_env/examples/hopper_control.py +374 -0
- envs/dm_control_env/examples/list_environments.py +41 -0
- envs/dm_control_env/examples/quadruped_control.py +373 -0
- envs/dm_control_env/models.py +186 -0
- envs/dm_control_env/openenv.yaml +6 -0
- envs/dm_control_env/pyproject.toml +48 -0
- envs/dm_control_env/server/Dockerfile +73 -0
- envs/dm_control_env/server/__init__.py +5 -0
- envs/dm_control_env/server/app.py +78 -0
- envs/dm_control_env/server/dm_control_environment.py +428 -0
- examples/__init__.py +1 -0
- examples/cartpole_control.py +333 -0
- examples/hopper_control.py +374 -0
- examples/list_environments.py +41 -0
- examples/quadruped_control.py +373 -0
- models.py +186 -0
- openenv.yaml +6 -0
- pyproject.toml +48 -0
- server/Dockerfile +73 -0
- server/__init__.py +5 -0
- server/app.py +78 -0
- server/dm_control_environment.py +428 -0
- src/__init__.py +7 -0
- src/openenv.egg-info/PKG-INFO +337 -0
- src/openenv.egg-info/SOURCES.txt +142 -0
- src/openenv.egg-info/dependency_links.txt +1 -0
- src/openenv.egg-info/entry_points.txt +2 -0
- src/openenv.egg-info/requires.txt +32 -0
- src/openenv.egg-info/top_level.txt +2 -0
- src/openenv/__init__.py +23 -0
- src/openenv/auto/__init__.py +39 -0
- src/openenv/auto/_discovery.py +584 -0
- src/openenv/auto/auto_action.py +276 -0
- src/openenv/auto/auto_env.py +896 -0
- src/openenv/cli/__init__.py +9 -0
- src/openenv/cli/__main__.py +62 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
assets/cartpole.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
assets/quadruped.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
envs/dm_control_env/assets/cartpole.png filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
envs/dm_control_env/assets/quadruped.png filter=lfs diff=lfs merge=lfs -text
|
Dockerfile
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# Multi-stage build for dm_control environment
|
| 8 |
+
# Uses pip for package installation
|
| 9 |
+
|
| 10 |
+
FROM python:3.11-slim AS builder
|
| 11 |
+
|
| 12 |
+
WORKDIR /app
|
| 13 |
+
|
| 14 |
+
# Install build dependencies including OpenGL for MuJoCo
|
| 15 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 16 |
+
build-essential \
|
| 17 |
+
git \
|
| 18 |
+
libgl1 \
|
| 19 |
+
libglx-mesa0 \
|
| 20 |
+
libglew-dev \
|
| 21 |
+
libosmesa6-dev \
|
| 22 |
+
libgl1-mesa-dev \
|
| 23 |
+
libglfw3 \
|
| 24 |
+
patchelf \
|
| 25 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 26 |
+
|
| 27 |
+
# Copy environment code
|
| 28 |
+
COPY . /app/env
|
| 29 |
+
|
| 30 |
+
WORKDIR /app/env
|
| 31 |
+
|
| 32 |
+
# Install dependencies using pip
|
| 33 |
+
RUN pip install --upgrade pip && \
|
| 34 |
+
pip install --no-cache-dir -e .
|
| 35 |
+
|
| 36 |
+
# Final runtime stage
|
| 37 |
+
FROM python:3.11-slim
|
| 38 |
+
|
| 39 |
+
WORKDIR /app
|
| 40 |
+
|
| 41 |
+
# Install runtime dependencies (OpenGL for MuJoCo rendering, curl for healthcheck)
|
| 42 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 43 |
+
curl \
|
| 44 |
+
libgl1 \
|
| 45 |
+
libglx-mesa0 \
|
| 46 |
+
libglew-dev \
|
| 47 |
+
libosmesa6-dev \
|
| 48 |
+
libglfw3 \
|
| 49 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 50 |
+
|
| 51 |
+
# Copy installed packages from builder
|
| 52 |
+
COPY --from=builder /usr/local/lib/python3.11/site-packages /usr/local/lib/python3.11/site-packages
|
| 53 |
+
COPY --from=builder /usr/local/bin /usr/local/bin
|
| 54 |
+
|
| 55 |
+
# Copy the environment code
|
| 56 |
+
COPY . /app/env
|
| 57 |
+
|
| 58 |
+
# Set PYTHONPATH so imports work correctly
|
| 59 |
+
ENV PYTHONPATH="/app/env"
|
| 60 |
+
|
| 61 |
+
# Set MuJoCo to use OSMesa for headless rendering
|
| 62 |
+
ENV MUJOCO_GL="osmesa"
|
| 63 |
+
|
| 64 |
+
# Expose port
|
| 65 |
+
EXPOSE 8000
|
| 66 |
+
|
| 67 |
+
# Health check
|
| 68 |
+
HEALTHCHECK --interval=30s --timeout=3s --start-period=10s --retries=3 \
|
| 69 |
+
CMD curl -f http://localhost:8000/health || exit 1
|
| 70 |
+
|
| 71 |
+
# Run the FastAPI server
|
| 72 |
+
# Use exec to replace the shell with uvicorn so it receives SIGINT/SIGTERM directly
|
| 73 |
+
CMD ["sh", "-c", "cd /app/env && exec uvicorn server.app:app --host 0.0.0.0 --port 8000"]
|
| 74 |
+
|
| 75 |
+
ENV ENABLE_WEB_INTERFACE=true
|
README.md
CHANGED
|
@@ -1,10 +1,183 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
colorFrom: blue
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
---
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: dm_control Environment Server
|
| 3 |
+
emoji: 🤖
|
| 4 |
colorFrom: blue
|
| 5 |
+
colorTo: green
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
+
app_port: 8000
|
| 9 |
+
base_path: /web
|
| 10 |
+
tags:
|
| 11 |
+
- openenv
|
| 12 |
---
|
| 13 |
|
| 14 |
+
## Hugging Face Space Deployment
|
| 15 |
+
|
| 16 |
+
This Space is built from OpenEnv environment `dm_control_env`.
|
| 17 |
+
|
| 18 |
+
- Space URL: `https://huggingface.co/spaces/openenv/dm_control_env-v2-1-0`
|
| 19 |
+
- OpenEnv pinned ref: `v2.1.0`
|
| 20 |
+
- Hub tag: `openenv`
|
| 21 |
+
|
| 22 |
+
### Connecting from Code
|
| 23 |
+
|
| 24 |
+
```python
|
| 25 |
+
from envs.dm_control_env import Env
|
| 26 |
+
|
| 27 |
+
env = Env(base_url="https://huggingface.co/spaces/openenv/dm_control_env-v2-1-0")
|
| 28 |
+
```
|
| 29 |
+
|
| 30 |
+
# dm_control OpenEnv Environment
|
| 31 |
+
|
| 32 |
+
A generic OpenEnv environment for [dm_control.suite](https://github.com/google-deepmind/dm_control), providing access to all MuJoCo-based continuous control tasks.
|
| 33 |
+
|
| 34 |
+
<p align="center">
|
| 35 |
+
<img src="assets/cartpole.png" width="45%" alt="Cartpole Balance"/>
|
| 36 |
+
<img src="assets/quadruped.png" width="45%" alt="Quadruped Walk"/>
|
| 37 |
+
</p>
|
| 38 |
+
|
| 39 |
+
## Supported Environments
|
| 40 |
+
|
| 41 |
+
| Domain | Tasks |
|
| 42 |
+
|--------|-------|
|
| 43 |
+
| cartpole | balance, swingup, swingup_sparse |
|
| 44 |
+
| walker | stand, walk, run |
|
| 45 |
+
| humanoid | stand, walk, run |
|
| 46 |
+
| cheetah | run |
|
| 47 |
+
| hopper | stand, hop |
|
| 48 |
+
| reacher | easy, hard |
|
| 49 |
+
| pendulum | swingup |
|
| 50 |
+
| finger | spin, turn_easy, turn_hard |
|
| 51 |
+
| fish | upright, swim |
|
| 52 |
+
| ball_in_cup | catch |
|
| 53 |
+
| And more... | See `dm_control.suite.BENCHMARKING` |
|
| 54 |
+
|
| 55 |
+
## Quick Start
|
| 56 |
+
|
| 57 |
+
### Using the Client
|
| 58 |
+
|
| 59 |
+
```python
|
| 60 |
+
from envs.dm_control_env import DMControlEnv, DMControlAction
|
| 61 |
+
|
| 62 |
+
# Connect to a running server
|
| 63 |
+
with DMControlEnv(base_url="http://localhost:8000") as env:
|
| 64 |
+
# Reset with default (cartpole/balance)
|
| 65 |
+
result = env.reset()
|
| 66 |
+
print(f"Observations: {result.observation.observations.keys()}")
|
| 67 |
+
|
| 68 |
+
# Take actions
|
| 69 |
+
for _ in range(100):
|
| 70 |
+
action = DMControlAction(values=[0.5]) # Push cart right
|
| 71 |
+
result = env.step(action)
|
| 72 |
+
print(f"Reward: {result.reward}, Done: {result.done}")
|
| 73 |
+
|
| 74 |
+
if result.done:
|
| 75 |
+
result = env.reset()
|
| 76 |
+
```
|
| 77 |
+
|
| 78 |
+
### Switching Environments
|
| 79 |
+
|
| 80 |
+
```python
|
| 81 |
+
# Start with cartpole
|
| 82 |
+
result = env.reset(domain_name="cartpole", task_name="balance")
|
| 83 |
+
|
| 84 |
+
# Switch to walker (on next reset)
|
| 85 |
+
result = env.reset(domain_name="walker", task_name="walk")
|
| 86 |
+
# Note: walker has 6 action dimensions
|
| 87 |
+
action = DMControlAction(values=[0.0] * 6)
|
| 88 |
+
result = env.step(action)
|
| 89 |
+
```
|
| 90 |
+
|
| 91 |
+
### Running the Server
|
| 92 |
+
|
| 93 |
+
```bash
|
| 94 |
+
# From OpenEnv root
|
| 95 |
+
cd envs/dm_control_env
|
| 96 |
+
uvicorn server.app:app --host 0.0.0.0 --port 8000
|
| 97 |
+
|
| 98 |
+
# Or using uv
|
| 99 |
+
uv run --project . server
|
| 100 |
+
```
|
| 101 |
+
|
| 102 |
+
### Using Docker
|
| 103 |
+
|
| 104 |
+
```bash
|
| 105 |
+
# Build
|
| 106 |
+
docker build -t dm_control:latest -f server/Dockerfile .
|
| 107 |
+
|
| 108 |
+
# Run
|
| 109 |
+
docker run -p 8000:8000 dm_control:latest
|
| 110 |
+
```
|
| 111 |
+
|
| 112 |
+
## API
|
| 113 |
+
|
| 114 |
+
### Action
|
| 115 |
+
|
| 116 |
+
```python
|
| 117 |
+
class DMControlAction(Action):
|
| 118 |
+
values: List[float] # Continuous action values
|
| 119 |
+
```
|
| 120 |
+
|
| 121 |
+
Action dimensions vary by environment:
|
| 122 |
+
- cartpole: 1 (force on cart)
|
| 123 |
+
- walker: 6 (joint torques)
|
| 124 |
+
- humanoid: 21 (joint torques)
|
| 125 |
+
|
| 126 |
+
### Observation
|
| 127 |
+
|
| 128 |
+
```python
|
| 129 |
+
class DMControlObservation(Observation):
|
| 130 |
+
observations: Dict[str, List[float]] # Named observation arrays
|
| 131 |
+
pixels: Optional[str] # Base64 PNG (if render=True)
|
| 132 |
+
reward: float
|
| 133 |
+
done: bool
|
| 134 |
+
```
|
| 135 |
+
|
| 136 |
+
### State
|
| 137 |
+
|
| 138 |
+
```python
|
| 139 |
+
class DMControlState(State):
|
| 140 |
+
domain_name: str
|
| 141 |
+
task_name: str
|
| 142 |
+
action_spec: Dict[str, Any]
|
| 143 |
+
observation_spec: Dict[str, Any]
|
| 144 |
+
physics_timestep: float
|
| 145 |
+
control_timestep: float
|
| 146 |
+
episode_id: str
|
| 147 |
+
step_count: int
|
| 148 |
+
```
|
| 149 |
+
|
| 150 |
+
## Examples
|
| 151 |
+
|
| 152 |
+
See the `examples/` directory:
|
| 153 |
+
- `cartpole_control.py` - Interactive cartpole control with arrow keys
|
| 154 |
+
- `hopper_control.py` - Interactive hopper control with spacebar for random forces
|
| 155 |
+
- `quadruped_control.py` - Interactive quadruped control with spacebar for random forces
|
| 156 |
+
- `list_environments.py` - Print all available environments
|
| 157 |
+
|
| 158 |
+
All examples support consistent CLI arguments:
|
| 159 |
+
|
| 160 |
+
```bash
|
| 161 |
+
# Default: interactive mode with minimal pygame window
|
| 162 |
+
python examples/cartpole_control.py
|
| 163 |
+
|
| 164 |
+
# Visual mode with rendered MuJoCo frames
|
| 165 |
+
python examples/cartpole_control.py --visual
|
| 166 |
+
|
| 167 |
+
# Headless mode (no pygame, automated control)
|
| 168 |
+
python examples/cartpole_control.py --headless --max-steps 500
|
| 169 |
+
|
| 170 |
+
# Select a different task
|
| 171 |
+
python examples/cartpole_control.py --task swingup
|
| 172 |
+
python examples/hopper_control.py --task stand
|
| 173 |
+
python examples/quadruped_control.py --task run
|
| 174 |
+
```
|
| 175 |
+
|
| 176 |
+
## Environment Variables
|
| 177 |
+
|
| 178 |
+
| Variable | Default | Description |
|
| 179 |
+
|----------|---------|-------------|
|
| 180 |
+
| `DMCONTROL_DOMAIN` | cartpole | Default domain |
|
| 181 |
+
| `DMCONTROL_TASK` | balance | Default task |
|
| 182 |
+
| `DMCONTROL_RENDER_HEIGHT` | 480 | Render height |
|
| 183 |
+
| `DMCONTROL_RENDER_WIDTH` | 640 | Render width |
|
__init__.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""dm_control OpenEnv Environment.
|
| 2 |
+
|
| 3 |
+
A generic OpenEnv environment for dm_control.suite supporting all domains/tasks.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from .models import DMControlAction, DMControlObservation, DMControlState
|
| 7 |
+
from .client import DMControlEnv
|
| 8 |
+
|
| 9 |
+
__all__ = [
|
| 10 |
+
"DMControlAction",
|
| 11 |
+
"DMControlObservation",
|
| 12 |
+
"DMControlState",
|
| 13 |
+
"DMControlEnv",
|
| 14 |
+
]
|
assets/cartpole.png
ADDED
|
Git LFS Details
|
assets/quadruped.png
ADDED
|
Git LFS Details
|
client.py
ADDED
|
@@ -0,0 +1,375 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
dm_control Environment Client.
|
| 9 |
+
|
| 10 |
+
This module provides the client for connecting to a dm_control
|
| 11 |
+
Environment server via WebSocket for persistent sessions.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 15 |
+
|
| 16 |
+
try:
|
| 17 |
+
from openenv.core.client_types import StepResult
|
| 18 |
+
from openenv.core.env_client import EnvClient
|
| 19 |
+
|
| 20 |
+
from .models import (
|
| 21 |
+
AVAILABLE_ENVIRONMENTS,
|
| 22 |
+
DMControlAction,
|
| 23 |
+
DMControlObservation,
|
| 24 |
+
DMControlState,
|
| 25 |
+
)
|
| 26 |
+
except ImportError:
|
| 27 |
+
from openenv.core.client_types import StepResult
|
| 28 |
+
from openenv.core.env_client import EnvClient
|
| 29 |
+
|
| 30 |
+
try:
|
| 31 |
+
from models import (
|
| 32 |
+
AVAILABLE_ENVIRONMENTS,
|
| 33 |
+
DMControlAction,
|
| 34 |
+
DMControlObservation,
|
| 35 |
+
DMControlState,
|
| 36 |
+
)
|
| 37 |
+
except ImportError:
|
| 38 |
+
try:
|
| 39 |
+
from dm_control_env.models import (
|
| 40 |
+
AVAILABLE_ENVIRONMENTS,
|
| 41 |
+
DMControlAction,
|
| 42 |
+
DMControlObservation,
|
| 43 |
+
DMControlState,
|
| 44 |
+
)
|
| 45 |
+
except ImportError:
|
| 46 |
+
from envs.dm_control_env.models import (
|
| 47 |
+
AVAILABLE_ENVIRONMENTS,
|
| 48 |
+
DMControlAction,
|
| 49 |
+
DMControlObservation,
|
| 50 |
+
DMControlState,
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class DMControlEnv(EnvClient[DMControlAction, DMControlObservation, DMControlState]):
|
| 55 |
+
"""
|
| 56 |
+
Client for dm_control.suite environments.
|
| 57 |
+
|
| 58 |
+
This client maintains a persistent WebSocket connection to the environment
|
| 59 |
+
server, enabling efficient multi-step interactions with lower latency.
|
| 60 |
+
|
| 61 |
+
Supported Environments (via dm_control.suite):
|
| 62 |
+
- cartpole: balance, swingup, swingup_sparse
|
| 63 |
+
- walker: stand, walk, run
|
| 64 |
+
- humanoid: stand, walk, run
|
| 65 |
+
- cheetah: run
|
| 66 |
+
- hopper: stand, hop
|
| 67 |
+
- reacher: easy, hard
|
| 68 |
+
- And many more...
|
| 69 |
+
|
| 70 |
+
Example:
|
| 71 |
+
>>> # Connect to a running server
|
| 72 |
+
>>> with DMControlEnv(base_url="http://localhost:8000") as client:
|
| 73 |
+
... result = client.reset()
|
| 74 |
+
... print(f"Observations: {result.observation.observations.keys()}")
|
| 75 |
+
...
|
| 76 |
+
... # Take action (cartpole: push right)
|
| 77 |
+
... result = client.step(DMControlAction(values=[0.5]))
|
| 78 |
+
... print(f"Reward: {result.reward}")
|
| 79 |
+
|
| 80 |
+
Example switching environments:
|
| 81 |
+
>>> client = DMControlEnv(base_url="http://localhost:8000")
|
| 82 |
+
>>> # Start with cartpole balance
|
| 83 |
+
>>> result = client.reset(domain_name="cartpole", task_name="balance")
|
| 84 |
+
>>> # ... train on cartpole ...
|
| 85 |
+
>>> # Switch to walker walk
|
| 86 |
+
>>> result = client.reset(domain_name="walker", task_name="walk")
|
| 87 |
+
>>> # ... train on walker ...
|
| 88 |
+
"""
|
| 89 |
+
|
| 90 |
+
def __init__(
|
| 91 |
+
self,
|
| 92 |
+
base_url: str,
|
| 93 |
+
connect_timeout_s: float = 10.0,
|
| 94 |
+
message_timeout_s: float = 60.0,
|
| 95 |
+
provider: Optional[Any] = None,
|
| 96 |
+
):
|
| 97 |
+
"""
|
| 98 |
+
Initialize dm_control environment client.
|
| 99 |
+
|
| 100 |
+
Args:
|
| 101 |
+
base_url: Base URL of the environment server (http:// or ws://).
|
| 102 |
+
connect_timeout_s: Timeout for establishing WebSocket connection.
|
| 103 |
+
message_timeout_s: Timeout for receiving responses.
|
| 104 |
+
provider: Optional container/runtime provider for lifecycle management.
|
| 105 |
+
"""
|
| 106 |
+
super().__init__(
|
| 107 |
+
base_url=base_url,
|
| 108 |
+
connect_timeout_s=connect_timeout_s,
|
| 109 |
+
message_timeout_s=message_timeout_s,
|
| 110 |
+
provider=provider,
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
def _step_payload(self, action: DMControlAction) -> Dict:
|
| 114 |
+
"""
|
| 115 |
+
Convert DMControlAction to JSON payload for step request.
|
| 116 |
+
|
| 117 |
+
Args:
|
| 118 |
+
action: DMControlAction instance
|
| 119 |
+
|
| 120 |
+
Returns:
|
| 121 |
+
Dictionary representation suitable for JSON encoding
|
| 122 |
+
"""
|
| 123 |
+
payload: Dict[str, Any] = {"values": action.values}
|
| 124 |
+
|
| 125 |
+
if action.metadata:
|
| 126 |
+
payload["metadata"] = action.metadata
|
| 127 |
+
|
| 128 |
+
return payload
|
| 129 |
+
|
| 130 |
+
def _parse_result(self, payload: Dict) -> StepResult[DMControlObservation]:
|
| 131 |
+
"""
|
| 132 |
+
Parse server response into StepResult[DMControlObservation].
|
| 133 |
+
|
| 134 |
+
Args:
|
| 135 |
+
payload: JSON response from server
|
| 136 |
+
|
| 137 |
+
Returns:
|
| 138 |
+
StepResult with DMControlObservation
|
| 139 |
+
"""
|
| 140 |
+
obs_data = payload.get("observation", {})
|
| 141 |
+
|
| 142 |
+
observation = DMControlObservation(
|
| 143 |
+
observations=obs_data.get("observations", {}),
|
| 144 |
+
pixels=obs_data.get("pixels"),
|
| 145 |
+
done=payload.get("done", False),
|
| 146 |
+
reward=payload.get("reward"),
|
| 147 |
+
metadata=obs_data.get("metadata", {}),
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
return StepResult(
|
| 151 |
+
observation=observation,
|
| 152 |
+
reward=payload.get("reward"),
|
| 153 |
+
done=payload.get("done", False),
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
def _parse_state(self, payload: Dict) -> DMControlState:
|
| 157 |
+
"""
|
| 158 |
+
Parse server response into DMControlState object.
|
| 159 |
+
|
| 160 |
+
Args:
|
| 161 |
+
payload: JSON response from /state endpoint
|
| 162 |
+
|
| 163 |
+
Returns:
|
| 164 |
+
DMControlState object with environment information
|
| 165 |
+
"""
|
| 166 |
+
return DMControlState(
|
| 167 |
+
episode_id=payload.get("episode_id"),
|
| 168 |
+
step_count=payload.get("step_count", 0),
|
| 169 |
+
domain_name=payload.get("domain_name", ""),
|
| 170 |
+
task_name=payload.get("task_name", ""),
|
| 171 |
+
action_spec=payload.get("action_spec", {}),
|
| 172 |
+
observation_spec=payload.get("observation_spec", {}),
|
| 173 |
+
physics_timestep=payload.get("physics_timestep", 0.002),
|
| 174 |
+
control_timestep=payload.get("control_timestep", 0.02),
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
def reset(
|
| 178 |
+
self,
|
| 179 |
+
domain_name: Optional[str] = None,
|
| 180 |
+
task_name: Optional[str] = None,
|
| 181 |
+
seed: Optional[int] = None,
|
| 182 |
+
render: bool = False,
|
| 183 |
+
**kwargs,
|
| 184 |
+
) -> StepResult[DMControlObservation]:
|
| 185 |
+
"""
|
| 186 |
+
Reset the environment.
|
| 187 |
+
|
| 188 |
+
Args:
|
| 189 |
+
domain_name: Optionally switch to a different domain.
|
| 190 |
+
task_name: Optionally switch to a different task.
|
| 191 |
+
seed: Random seed for reproducibility.
|
| 192 |
+
render: If True, include pixel observations in response.
|
| 193 |
+
**kwargs: Additional arguments passed to server.
|
| 194 |
+
|
| 195 |
+
Returns:
|
| 196 |
+
StepResult with initial observation.
|
| 197 |
+
"""
|
| 198 |
+
reset_kwargs = dict(kwargs)
|
| 199 |
+
if domain_name is not None:
|
| 200 |
+
reset_kwargs["domain_name"] = domain_name
|
| 201 |
+
if task_name is not None:
|
| 202 |
+
reset_kwargs["task_name"] = task_name
|
| 203 |
+
if seed is not None:
|
| 204 |
+
reset_kwargs["seed"] = seed
|
| 205 |
+
reset_kwargs["render"] = render
|
| 206 |
+
|
| 207 |
+
return super().reset(**reset_kwargs)
|
| 208 |
+
|
| 209 |
+
def step(
|
| 210 |
+
self,
|
| 211 |
+
action: DMControlAction,
|
| 212 |
+
render: bool = False,
|
| 213 |
+
**kwargs,
|
| 214 |
+
) -> StepResult[DMControlObservation]:
|
| 215 |
+
"""
|
| 216 |
+
Execute one step in the environment.
|
| 217 |
+
|
| 218 |
+
Args:
|
| 219 |
+
action: DMControlAction with continuous action values.
|
| 220 |
+
render: If True, include pixel observations in response.
|
| 221 |
+
**kwargs: Additional arguments passed to server.
|
| 222 |
+
|
| 223 |
+
Returns:
|
| 224 |
+
StepResult with new observation, reward, and done flag.
|
| 225 |
+
"""
|
| 226 |
+
# Note: render flag needs to be passed differently
|
| 227 |
+
# For now, the server remembers the render setting from reset
|
| 228 |
+
return super().step(action, **kwargs)
|
| 229 |
+
|
| 230 |
+
@staticmethod
|
| 231 |
+
def available_environments() -> List[Tuple[str, str]]:
|
| 232 |
+
"""
|
| 233 |
+
List available dm_control environments.
|
| 234 |
+
|
| 235 |
+
Returns:
|
| 236 |
+
List of (domain_name, task_name) tuples.
|
| 237 |
+
"""
|
| 238 |
+
return AVAILABLE_ENVIRONMENTS
|
| 239 |
+
|
| 240 |
+
@classmethod
|
| 241 |
+
def from_direct(
|
| 242 |
+
cls,
|
| 243 |
+
domain_name: str = "cartpole",
|
| 244 |
+
task_name: str = "balance",
|
| 245 |
+
render_height: int = 480,
|
| 246 |
+
render_width: int = 640,
|
| 247 |
+
port: int = 8765,
|
| 248 |
+
) -> "DMControlEnv":
|
| 249 |
+
"""
|
| 250 |
+
Create a dm_control environment client with an embedded local server.
|
| 251 |
+
|
| 252 |
+
This method starts a local uvicorn server in a subprocess and returns
|
| 253 |
+
a client connected to it.
|
| 254 |
+
|
| 255 |
+
Args:
|
| 256 |
+
domain_name: Default domain to use.
|
| 257 |
+
task_name: Default task to use.
|
| 258 |
+
render_height: Height of rendered images.
|
| 259 |
+
render_width: Width of rendered images.
|
| 260 |
+
port: Port for the local server.
|
| 261 |
+
|
| 262 |
+
Returns:
|
| 263 |
+
DMControlEnv client connected to the local server.
|
| 264 |
+
|
| 265 |
+
Example:
|
| 266 |
+
>>> client = DMControlEnv.from_direct(domain_name="walker", task_name="walk")
|
| 267 |
+
>>> try:
|
| 268 |
+
... result = client.reset()
|
| 269 |
+
... for _ in range(100):
|
| 270 |
+
... result = client.step(DMControlAction(values=[0.0] * 6))
|
| 271 |
+
... finally:
|
| 272 |
+
... client.close()
|
| 273 |
+
"""
|
| 274 |
+
import os
|
| 275 |
+
import subprocess
|
| 276 |
+
import sys
|
| 277 |
+
import time
|
| 278 |
+
|
| 279 |
+
import requests
|
| 280 |
+
|
| 281 |
+
try:
|
| 282 |
+
from pathlib import Path
|
| 283 |
+
|
| 284 |
+
client_dir = Path(__file__).parent
|
| 285 |
+
server_app = "envs.dm_control_env.server.app:app"
|
| 286 |
+
cwd = client_dir.parent.parent
|
| 287 |
+
|
| 288 |
+
if not (cwd / "envs" / "dm_control_env" / "server" / "app.py").exists():
|
| 289 |
+
if (client_dir / "server" / "app.py").exists():
|
| 290 |
+
server_app = "server.app:app"
|
| 291 |
+
cwd = client_dir
|
| 292 |
+
except Exception:
|
| 293 |
+
server_app = "envs.dm_control_env.server.app:app"
|
| 294 |
+
cwd = None
|
| 295 |
+
|
| 296 |
+
env = {
|
| 297 |
+
**os.environ,
|
| 298 |
+
"DMCONTROL_DOMAIN": domain_name,
|
| 299 |
+
"DMCONTROL_TASK": task_name,
|
| 300 |
+
"DMCONTROL_RENDER_HEIGHT": str(render_height),
|
| 301 |
+
"DMCONTROL_RENDER_WIDTH": str(render_width),
|
| 302 |
+
"NO_PROXY": "localhost,127.0.0.1",
|
| 303 |
+
"no_proxy": "localhost,127.0.0.1",
|
| 304 |
+
}
|
| 305 |
+
|
| 306 |
+
if cwd:
|
| 307 |
+
src_path = str(cwd / "src")
|
| 308 |
+
existing_path = env.get("PYTHONPATH", "")
|
| 309 |
+
env["PYTHONPATH"] = (
|
| 310 |
+
f"{src_path}:{cwd}:{existing_path}"
|
| 311 |
+
if existing_path
|
| 312 |
+
else f"{src_path}:{cwd}"
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
cmd = [
|
| 316 |
+
sys.executable,
|
| 317 |
+
"-m",
|
| 318 |
+
"uvicorn",
|
| 319 |
+
server_app,
|
| 320 |
+
"--host",
|
| 321 |
+
"127.0.0.1",
|
| 322 |
+
"--port",
|
| 323 |
+
str(port),
|
| 324 |
+
]
|
| 325 |
+
|
| 326 |
+
server_process = subprocess.Popen(
|
| 327 |
+
cmd,
|
| 328 |
+
env=env,
|
| 329 |
+
stdout=subprocess.PIPE,
|
| 330 |
+
stderr=subprocess.STDOUT,
|
| 331 |
+
cwd=str(cwd) if cwd else None,
|
| 332 |
+
)
|
| 333 |
+
|
| 334 |
+
base_url = f"http://127.0.0.1:{port}"
|
| 335 |
+
healthy = False
|
| 336 |
+
for _ in range(30):
|
| 337 |
+
try:
|
| 338 |
+
response = requests.get(
|
| 339 |
+
f"{base_url}/health",
|
| 340 |
+
timeout=2,
|
| 341 |
+
proxies={"http": None, "https": None},
|
| 342 |
+
)
|
| 343 |
+
if response.status_code == 200:
|
| 344 |
+
healthy = True
|
| 345 |
+
break
|
| 346 |
+
except requests.exceptions.RequestException:
|
| 347 |
+
pass
|
| 348 |
+
time.sleep(1)
|
| 349 |
+
|
| 350 |
+
if not healthy:
|
| 351 |
+
server_process.kill()
|
| 352 |
+
raise RuntimeError(
|
| 353 |
+
f"Failed to start local dm_control server on port {port}. "
|
| 354 |
+
"Check that the port is available and dependencies are installed."
|
| 355 |
+
)
|
| 356 |
+
|
| 357 |
+
class DirectModeProvider:
|
| 358 |
+
"""Provider that manages the embedded server subprocess."""
|
| 359 |
+
|
| 360 |
+
def __init__(self, process: subprocess.Popen):
|
| 361 |
+
self._process = process
|
| 362 |
+
|
| 363 |
+
def stop(self):
|
| 364 |
+
"""Stop the embedded server."""
|
| 365 |
+
if self._process:
|
| 366 |
+
self._process.terminate()
|
| 367 |
+
try:
|
| 368 |
+
self._process.wait(timeout=10)
|
| 369 |
+
except subprocess.TimeoutExpired:
|
| 370 |
+
self._process.kill()
|
| 371 |
+
self._process = None
|
| 372 |
+
|
| 373 |
+
provider = DirectModeProvider(server_process)
|
| 374 |
+
client = cls(base_url=base_url, provider=provider)
|
| 375 |
+
return client
|
envs/dm_control_env/README.md
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: dm_control Environment Server
|
| 3 |
+
emoji: 🤖
|
| 4 |
+
colorFrom: green
|
| 5 |
+
colorTo: blue
|
| 6 |
+
sdk: docker
|
| 7 |
+
pinned: false
|
| 8 |
+
app_port: 8000
|
| 9 |
+
base_path: /web
|
| 10 |
+
tags:
|
| 11 |
+
- openenv
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
# dm_control OpenEnv Environment
|
| 15 |
+
|
| 16 |
+
A generic OpenEnv environment for [dm_control.suite](https://github.com/google-deepmind/dm_control), providing access to all MuJoCo-based continuous control tasks.
|
| 17 |
+
|
| 18 |
+
<p align="center">
|
| 19 |
+
<img src="assets/cartpole.png" width="45%" alt="Cartpole Balance"/>
|
| 20 |
+
<img src="assets/quadruped.png" width="45%" alt="Quadruped Walk"/>
|
| 21 |
+
</p>
|
| 22 |
+
|
| 23 |
+
## Supported Environments
|
| 24 |
+
|
| 25 |
+
| Domain | Tasks |
|
| 26 |
+
|--------|-------|
|
| 27 |
+
| cartpole | balance, swingup, swingup_sparse |
|
| 28 |
+
| walker | stand, walk, run |
|
| 29 |
+
| humanoid | stand, walk, run |
|
| 30 |
+
| cheetah | run |
|
| 31 |
+
| hopper | stand, hop |
|
| 32 |
+
| reacher | easy, hard |
|
| 33 |
+
| pendulum | swingup |
|
| 34 |
+
| finger | spin, turn_easy, turn_hard |
|
| 35 |
+
| fish | upright, swim |
|
| 36 |
+
| ball_in_cup | catch |
|
| 37 |
+
| And more... | See `dm_control.suite.BENCHMARKING` |
|
| 38 |
+
|
| 39 |
+
## Quick Start
|
| 40 |
+
|
| 41 |
+
### Using the Client
|
| 42 |
+
|
| 43 |
+
```python
|
| 44 |
+
from envs.dm_control_env import DMControlEnv, DMControlAction
|
| 45 |
+
|
| 46 |
+
# Connect to a running server
|
| 47 |
+
with DMControlEnv(base_url="http://localhost:8000") as env:
|
| 48 |
+
# Reset with default (cartpole/balance)
|
| 49 |
+
result = env.reset()
|
| 50 |
+
print(f"Observations: {result.observation.observations.keys()}")
|
| 51 |
+
|
| 52 |
+
# Take actions
|
| 53 |
+
for _ in range(100):
|
| 54 |
+
action = DMControlAction(values=[0.5]) # Push cart right
|
| 55 |
+
result = env.step(action)
|
| 56 |
+
print(f"Reward: {result.reward}, Done: {result.done}")
|
| 57 |
+
|
| 58 |
+
if result.done:
|
| 59 |
+
result = env.reset()
|
| 60 |
+
```
|
| 61 |
+
|
| 62 |
+
### Switching Environments
|
| 63 |
+
|
| 64 |
+
```python
|
| 65 |
+
# Start with cartpole
|
| 66 |
+
result = env.reset(domain_name="cartpole", task_name="balance")
|
| 67 |
+
|
| 68 |
+
# Switch to walker (on next reset)
|
| 69 |
+
result = env.reset(domain_name="walker", task_name="walk")
|
| 70 |
+
# Note: walker has 6 action dimensions
|
| 71 |
+
action = DMControlAction(values=[0.0] * 6)
|
| 72 |
+
result = env.step(action)
|
| 73 |
+
```
|
| 74 |
+
|
| 75 |
+
### Running the Server
|
| 76 |
+
|
| 77 |
+
```bash
|
| 78 |
+
# From OpenEnv root
|
| 79 |
+
cd envs/dm_control_env
|
| 80 |
+
uvicorn server.app:app --host 0.0.0.0 --port 8000
|
| 81 |
+
|
| 82 |
+
# Or using uv
|
| 83 |
+
uv run --project . server
|
| 84 |
+
```
|
| 85 |
+
|
| 86 |
+
### Using Docker
|
| 87 |
+
|
| 88 |
+
```bash
|
| 89 |
+
# Build
|
| 90 |
+
docker build -t dm_control:latest -f server/Dockerfile .
|
| 91 |
+
|
| 92 |
+
# Run
|
| 93 |
+
docker run -p 8000:8000 dm_control:latest
|
| 94 |
+
```
|
| 95 |
+
|
| 96 |
+
## API
|
| 97 |
+
|
| 98 |
+
### Action
|
| 99 |
+
|
| 100 |
+
```python
|
| 101 |
+
class DMControlAction(Action):
|
| 102 |
+
values: List[float] # Continuous action values
|
| 103 |
+
```
|
| 104 |
+
|
| 105 |
+
Action dimensions vary by environment:
|
| 106 |
+
- cartpole: 1 (force on cart)
|
| 107 |
+
- walker: 6 (joint torques)
|
| 108 |
+
- humanoid: 21 (joint torques)
|
| 109 |
+
|
| 110 |
+
### Observation
|
| 111 |
+
|
| 112 |
+
```python
|
| 113 |
+
class DMControlObservation(Observation):
|
| 114 |
+
observations: Dict[str, List[float]] # Named observation arrays
|
| 115 |
+
pixels: Optional[str] # Base64 PNG (if render=True)
|
| 116 |
+
reward: float
|
| 117 |
+
done: bool
|
| 118 |
+
```
|
| 119 |
+
|
| 120 |
+
### State
|
| 121 |
+
|
| 122 |
+
```python
|
| 123 |
+
class DMControlState(State):
|
| 124 |
+
domain_name: str
|
| 125 |
+
task_name: str
|
| 126 |
+
action_spec: Dict[str, Any]
|
| 127 |
+
observation_spec: Dict[str, Any]
|
| 128 |
+
physics_timestep: float
|
| 129 |
+
control_timestep: float
|
| 130 |
+
episode_id: str
|
| 131 |
+
step_count: int
|
| 132 |
+
```
|
| 133 |
+
|
| 134 |
+
## Examples
|
| 135 |
+
|
| 136 |
+
See the `examples/` directory:
|
| 137 |
+
- `cartpole_control.py` - Interactive cartpole control with arrow keys
|
| 138 |
+
- `hopper_control.py` - Interactive hopper control with spacebar for random forces
|
| 139 |
+
- `quadruped_control.py` - Interactive quadruped control with spacebar for random forces
|
| 140 |
+
- `list_environments.py` - Print all available environments
|
| 141 |
+
|
| 142 |
+
All examples support consistent CLI arguments:
|
| 143 |
+
|
| 144 |
+
```bash
|
| 145 |
+
# Default: interactive mode with minimal pygame window
|
| 146 |
+
python examples/cartpole_control.py
|
| 147 |
+
|
| 148 |
+
# Visual mode with rendered MuJoCo frames
|
| 149 |
+
python examples/cartpole_control.py --visual
|
| 150 |
+
|
| 151 |
+
# Headless mode (no pygame, automated control)
|
| 152 |
+
python examples/cartpole_control.py --headless --max-steps 500
|
| 153 |
+
|
| 154 |
+
# Select a different task
|
| 155 |
+
python examples/cartpole_control.py --task swingup
|
| 156 |
+
python examples/hopper_control.py --task stand
|
| 157 |
+
python examples/quadruped_control.py --task run
|
| 158 |
+
```
|
| 159 |
+
|
| 160 |
+
## Environment Variables
|
| 161 |
+
|
| 162 |
+
| Variable | Default | Description |
|
| 163 |
+
|----------|---------|-------------|
|
| 164 |
+
| `DMCONTROL_DOMAIN` | cartpole | Default domain |
|
| 165 |
+
| `DMCONTROL_TASK` | balance | Default task |
|
| 166 |
+
| `DMCONTROL_RENDER_HEIGHT` | 480 | Render height |
|
| 167 |
+
| `DMCONTROL_RENDER_WIDTH` | 640 | Render width |
|
envs/dm_control_env/__init__.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""dm_control OpenEnv Environment.
|
| 2 |
+
|
| 3 |
+
A generic OpenEnv environment for dm_control.suite supporting all domains/tasks.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from .models import DMControlAction, DMControlObservation, DMControlState
|
| 7 |
+
from .client import DMControlEnv
|
| 8 |
+
|
| 9 |
+
__all__ = [
|
| 10 |
+
"DMControlAction",
|
| 11 |
+
"DMControlObservation",
|
| 12 |
+
"DMControlState",
|
| 13 |
+
"DMControlEnv",
|
| 14 |
+
]
|
envs/dm_control_env/assets/cartpole.png
ADDED
|
Git LFS Details
|
envs/dm_control_env/assets/quadruped.png
ADDED
|
Git LFS Details
|
envs/dm_control_env/client.py
ADDED
|
@@ -0,0 +1,375 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
dm_control Environment Client.
|
| 9 |
+
|
| 10 |
+
This module provides the client for connecting to a dm_control
|
| 11 |
+
Environment server via WebSocket for persistent sessions.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 15 |
+
|
| 16 |
+
try:
|
| 17 |
+
from openenv.core.client_types import StepResult
|
| 18 |
+
from openenv.core.env_client import EnvClient
|
| 19 |
+
|
| 20 |
+
from .models import (
|
| 21 |
+
AVAILABLE_ENVIRONMENTS,
|
| 22 |
+
DMControlAction,
|
| 23 |
+
DMControlObservation,
|
| 24 |
+
DMControlState,
|
| 25 |
+
)
|
| 26 |
+
except ImportError:
|
| 27 |
+
from openenv.core.client_types import StepResult
|
| 28 |
+
from openenv.core.env_client import EnvClient
|
| 29 |
+
|
| 30 |
+
try:
|
| 31 |
+
from models import (
|
| 32 |
+
AVAILABLE_ENVIRONMENTS,
|
| 33 |
+
DMControlAction,
|
| 34 |
+
DMControlObservation,
|
| 35 |
+
DMControlState,
|
| 36 |
+
)
|
| 37 |
+
except ImportError:
|
| 38 |
+
try:
|
| 39 |
+
from dm_control_env.models import (
|
| 40 |
+
AVAILABLE_ENVIRONMENTS,
|
| 41 |
+
DMControlAction,
|
| 42 |
+
DMControlObservation,
|
| 43 |
+
DMControlState,
|
| 44 |
+
)
|
| 45 |
+
except ImportError:
|
| 46 |
+
from envs.dm_control_env.models import (
|
| 47 |
+
AVAILABLE_ENVIRONMENTS,
|
| 48 |
+
DMControlAction,
|
| 49 |
+
DMControlObservation,
|
| 50 |
+
DMControlState,
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class DMControlEnv(EnvClient[DMControlAction, DMControlObservation, DMControlState]):
|
| 55 |
+
"""
|
| 56 |
+
Client for dm_control.suite environments.
|
| 57 |
+
|
| 58 |
+
This client maintains a persistent WebSocket connection to the environment
|
| 59 |
+
server, enabling efficient multi-step interactions with lower latency.
|
| 60 |
+
|
| 61 |
+
Supported Environments (via dm_control.suite):
|
| 62 |
+
- cartpole: balance, swingup, swingup_sparse
|
| 63 |
+
- walker: stand, walk, run
|
| 64 |
+
- humanoid: stand, walk, run
|
| 65 |
+
- cheetah: run
|
| 66 |
+
- hopper: stand, hop
|
| 67 |
+
- reacher: easy, hard
|
| 68 |
+
- And many more...
|
| 69 |
+
|
| 70 |
+
Example:
|
| 71 |
+
>>> # Connect to a running server
|
| 72 |
+
>>> with DMControlEnv(base_url="http://localhost:8000") as client:
|
| 73 |
+
... result = client.reset()
|
| 74 |
+
... print(f"Observations: {result.observation.observations.keys()}")
|
| 75 |
+
...
|
| 76 |
+
... # Take action (cartpole: push right)
|
| 77 |
+
... result = client.step(DMControlAction(values=[0.5]))
|
| 78 |
+
... print(f"Reward: {result.reward}")
|
| 79 |
+
|
| 80 |
+
Example switching environments:
|
| 81 |
+
>>> client = DMControlEnv(base_url="http://localhost:8000")
|
| 82 |
+
>>> # Start with cartpole balance
|
| 83 |
+
>>> result = client.reset(domain_name="cartpole", task_name="balance")
|
| 84 |
+
>>> # ... train on cartpole ...
|
| 85 |
+
>>> # Switch to walker walk
|
| 86 |
+
>>> result = client.reset(domain_name="walker", task_name="walk")
|
| 87 |
+
>>> # ... train on walker ...
|
| 88 |
+
"""
|
| 89 |
+
|
| 90 |
+
def __init__(
|
| 91 |
+
self,
|
| 92 |
+
base_url: str,
|
| 93 |
+
connect_timeout_s: float = 10.0,
|
| 94 |
+
message_timeout_s: float = 60.0,
|
| 95 |
+
provider: Optional[Any] = None,
|
| 96 |
+
):
|
| 97 |
+
"""
|
| 98 |
+
Initialize dm_control environment client.
|
| 99 |
+
|
| 100 |
+
Args:
|
| 101 |
+
base_url: Base URL of the environment server (http:// or ws://).
|
| 102 |
+
connect_timeout_s: Timeout for establishing WebSocket connection.
|
| 103 |
+
message_timeout_s: Timeout for receiving responses.
|
| 104 |
+
provider: Optional container/runtime provider for lifecycle management.
|
| 105 |
+
"""
|
| 106 |
+
super().__init__(
|
| 107 |
+
base_url=base_url,
|
| 108 |
+
connect_timeout_s=connect_timeout_s,
|
| 109 |
+
message_timeout_s=message_timeout_s,
|
| 110 |
+
provider=provider,
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
def _step_payload(self, action: DMControlAction) -> Dict:
|
| 114 |
+
"""
|
| 115 |
+
Convert DMControlAction to JSON payload for step request.
|
| 116 |
+
|
| 117 |
+
Args:
|
| 118 |
+
action: DMControlAction instance
|
| 119 |
+
|
| 120 |
+
Returns:
|
| 121 |
+
Dictionary representation suitable for JSON encoding
|
| 122 |
+
"""
|
| 123 |
+
payload: Dict[str, Any] = {"values": action.values}
|
| 124 |
+
|
| 125 |
+
if action.metadata:
|
| 126 |
+
payload["metadata"] = action.metadata
|
| 127 |
+
|
| 128 |
+
return payload
|
| 129 |
+
|
| 130 |
+
def _parse_result(self, payload: Dict) -> StepResult[DMControlObservation]:
|
| 131 |
+
"""
|
| 132 |
+
Parse server response into StepResult[DMControlObservation].
|
| 133 |
+
|
| 134 |
+
Args:
|
| 135 |
+
payload: JSON response from server
|
| 136 |
+
|
| 137 |
+
Returns:
|
| 138 |
+
StepResult with DMControlObservation
|
| 139 |
+
"""
|
| 140 |
+
obs_data = payload.get("observation", {})
|
| 141 |
+
|
| 142 |
+
observation = DMControlObservation(
|
| 143 |
+
observations=obs_data.get("observations", {}),
|
| 144 |
+
pixels=obs_data.get("pixels"),
|
| 145 |
+
done=payload.get("done", False),
|
| 146 |
+
reward=payload.get("reward"),
|
| 147 |
+
metadata=obs_data.get("metadata", {}),
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
return StepResult(
|
| 151 |
+
observation=observation,
|
| 152 |
+
reward=payload.get("reward"),
|
| 153 |
+
done=payload.get("done", False),
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
def _parse_state(self, payload: Dict) -> DMControlState:
|
| 157 |
+
"""
|
| 158 |
+
Parse server response into DMControlState object.
|
| 159 |
+
|
| 160 |
+
Args:
|
| 161 |
+
payload: JSON response from /state endpoint
|
| 162 |
+
|
| 163 |
+
Returns:
|
| 164 |
+
DMControlState object with environment information
|
| 165 |
+
"""
|
| 166 |
+
return DMControlState(
|
| 167 |
+
episode_id=payload.get("episode_id"),
|
| 168 |
+
step_count=payload.get("step_count", 0),
|
| 169 |
+
domain_name=payload.get("domain_name", ""),
|
| 170 |
+
task_name=payload.get("task_name", ""),
|
| 171 |
+
action_spec=payload.get("action_spec", {}),
|
| 172 |
+
observation_spec=payload.get("observation_spec", {}),
|
| 173 |
+
physics_timestep=payload.get("physics_timestep", 0.002),
|
| 174 |
+
control_timestep=payload.get("control_timestep", 0.02),
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
def reset(
|
| 178 |
+
self,
|
| 179 |
+
domain_name: Optional[str] = None,
|
| 180 |
+
task_name: Optional[str] = None,
|
| 181 |
+
seed: Optional[int] = None,
|
| 182 |
+
render: bool = False,
|
| 183 |
+
**kwargs,
|
| 184 |
+
) -> StepResult[DMControlObservation]:
|
| 185 |
+
"""
|
| 186 |
+
Reset the environment.
|
| 187 |
+
|
| 188 |
+
Args:
|
| 189 |
+
domain_name: Optionally switch to a different domain.
|
| 190 |
+
task_name: Optionally switch to a different task.
|
| 191 |
+
seed: Random seed for reproducibility.
|
| 192 |
+
render: If True, include pixel observations in response.
|
| 193 |
+
**kwargs: Additional arguments passed to server.
|
| 194 |
+
|
| 195 |
+
Returns:
|
| 196 |
+
StepResult with initial observation.
|
| 197 |
+
"""
|
| 198 |
+
reset_kwargs = dict(kwargs)
|
| 199 |
+
if domain_name is not None:
|
| 200 |
+
reset_kwargs["domain_name"] = domain_name
|
| 201 |
+
if task_name is not None:
|
| 202 |
+
reset_kwargs["task_name"] = task_name
|
| 203 |
+
if seed is not None:
|
| 204 |
+
reset_kwargs["seed"] = seed
|
| 205 |
+
reset_kwargs["render"] = render
|
| 206 |
+
|
| 207 |
+
return super().reset(**reset_kwargs)
|
| 208 |
+
|
| 209 |
+
def step(
|
| 210 |
+
self,
|
| 211 |
+
action: DMControlAction,
|
| 212 |
+
render: bool = False,
|
| 213 |
+
**kwargs,
|
| 214 |
+
) -> StepResult[DMControlObservation]:
|
| 215 |
+
"""
|
| 216 |
+
Execute one step in the environment.
|
| 217 |
+
|
| 218 |
+
Args:
|
| 219 |
+
action: DMControlAction with continuous action values.
|
| 220 |
+
render: If True, include pixel observations in response.
|
| 221 |
+
**kwargs: Additional arguments passed to server.
|
| 222 |
+
|
| 223 |
+
Returns:
|
| 224 |
+
StepResult with new observation, reward, and done flag.
|
| 225 |
+
"""
|
| 226 |
+
# Note: render flag needs to be passed differently
|
| 227 |
+
# For now, the server remembers the render setting from reset
|
| 228 |
+
return super().step(action, **kwargs)
|
| 229 |
+
|
| 230 |
+
@staticmethod
|
| 231 |
+
def available_environments() -> List[Tuple[str, str]]:
|
| 232 |
+
"""
|
| 233 |
+
List available dm_control environments.
|
| 234 |
+
|
| 235 |
+
Returns:
|
| 236 |
+
List of (domain_name, task_name) tuples.
|
| 237 |
+
"""
|
| 238 |
+
return AVAILABLE_ENVIRONMENTS
|
| 239 |
+
|
| 240 |
+
@classmethod
|
| 241 |
+
def from_direct(
|
| 242 |
+
cls,
|
| 243 |
+
domain_name: str = "cartpole",
|
| 244 |
+
task_name: str = "balance",
|
| 245 |
+
render_height: int = 480,
|
| 246 |
+
render_width: int = 640,
|
| 247 |
+
port: int = 8765,
|
| 248 |
+
) -> "DMControlEnv":
|
| 249 |
+
"""
|
| 250 |
+
Create a dm_control environment client with an embedded local server.
|
| 251 |
+
|
| 252 |
+
This method starts a local uvicorn server in a subprocess and returns
|
| 253 |
+
a client connected to it.
|
| 254 |
+
|
| 255 |
+
Args:
|
| 256 |
+
domain_name: Default domain to use.
|
| 257 |
+
task_name: Default task to use.
|
| 258 |
+
render_height: Height of rendered images.
|
| 259 |
+
render_width: Width of rendered images.
|
| 260 |
+
port: Port for the local server.
|
| 261 |
+
|
| 262 |
+
Returns:
|
| 263 |
+
DMControlEnv client connected to the local server.
|
| 264 |
+
|
| 265 |
+
Example:
|
| 266 |
+
>>> client = DMControlEnv.from_direct(domain_name="walker", task_name="walk")
|
| 267 |
+
>>> try:
|
| 268 |
+
... result = client.reset()
|
| 269 |
+
... for _ in range(100):
|
| 270 |
+
... result = client.step(DMControlAction(values=[0.0] * 6))
|
| 271 |
+
... finally:
|
| 272 |
+
... client.close()
|
| 273 |
+
"""
|
| 274 |
+
import os
|
| 275 |
+
import subprocess
|
| 276 |
+
import sys
|
| 277 |
+
import time
|
| 278 |
+
|
| 279 |
+
import requests
|
| 280 |
+
|
| 281 |
+
try:
|
| 282 |
+
from pathlib import Path
|
| 283 |
+
|
| 284 |
+
client_dir = Path(__file__).parent
|
| 285 |
+
server_app = "envs.dm_control_env.server.app:app"
|
| 286 |
+
cwd = client_dir.parent.parent
|
| 287 |
+
|
| 288 |
+
if not (cwd / "envs" / "dm_control_env" / "server" / "app.py").exists():
|
| 289 |
+
if (client_dir / "server" / "app.py").exists():
|
| 290 |
+
server_app = "server.app:app"
|
| 291 |
+
cwd = client_dir
|
| 292 |
+
except Exception:
|
| 293 |
+
server_app = "envs.dm_control_env.server.app:app"
|
| 294 |
+
cwd = None
|
| 295 |
+
|
| 296 |
+
env = {
|
| 297 |
+
**os.environ,
|
| 298 |
+
"DMCONTROL_DOMAIN": domain_name,
|
| 299 |
+
"DMCONTROL_TASK": task_name,
|
| 300 |
+
"DMCONTROL_RENDER_HEIGHT": str(render_height),
|
| 301 |
+
"DMCONTROL_RENDER_WIDTH": str(render_width),
|
| 302 |
+
"NO_PROXY": "localhost,127.0.0.1",
|
| 303 |
+
"no_proxy": "localhost,127.0.0.1",
|
| 304 |
+
}
|
| 305 |
+
|
| 306 |
+
if cwd:
|
| 307 |
+
src_path = str(cwd / "src")
|
| 308 |
+
existing_path = env.get("PYTHONPATH", "")
|
| 309 |
+
env["PYTHONPATH"] = (
|
| 310 |
+
f"{src_path}:{cwd}:{existing_path}"
|
| 311 |
+
if existing_path
|
| 312 |
+
else f"{src_path}:{cwd}"
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
cmd = [
|
| 316 |
+
sys.executable,
|
| 317 |
+
"-m",
|
| 318 |
+
"uvicorn",
|
| 319 |
+
server_app,
|
| 320 |
+
"--host",
|
| 321 |
+
"127.0.0.1",
|
| 322 |
+
"--port",
|
| 323 |
+
str(port),
|
| 324 |
+
]
|
| 325 |
+
|
| 326 |
+
server_process = subprocess.Popen(
|
| 327 |
+
cmd,
|
| 328 |
+
env=env,
|
| 329 |
+
stdout=subprocess.PIPE,
|
| 330 |
+
stderr=subprocess.STDOUT,
|
| 331 |
+
cwd=str(cwd) if cwd else None,
|
| 332 |
+
)
|
| 333 |
+
|
| 334 |
+
base_url = f"http://127.0.0.1:{port}"
|
| 335 |
+
healthy = False
|
| 336 |
+
for _ in range(30):
|
| 337 |
+
try:
|
| 338 |
+
response = requests.get(
|
| 339 |
+
f"{base_url}/health",
|
| 340 |
+
timeout=2,
|
| 341 |
+
proxies={"http": None, "https": None},
|
| 342 |
+
)
|
| 343 |
+
if response.status_code == 200:
|
| 344 |
+
healthy = True
|
| 345 |
+
break
|
| 346 |
+
except requests.exceptions.RequestException:
|
| 347 |
+
pass
|
| 348 |
+
time.sleep(1)
|
| 349 |
+
|
| 350 |
+
if not healthy:
|
| 351 |
+
server_process.kill()
|
| 352 |
+
raise RuntimeError(
|
| 353 |
+
f"Failed to start local dm_control server on port {port}. "
|
| 354 |
+
"Check that the port is available and dependencies are installed."
|
| 355 |
+
)
|
| 356 |
+
|
| 357 |
+
class DirectModeProvider:
|
| 358 |
+
"""Provider that manages the embedded server subprocess."""
|
| 359 |
+
|
| 360 |
+
def __init__(self, process: subprocess.Popen):
|
| 361 |
+
self._process = process
|
| 362 |
+
|
| 363 |
+
def stop(self):
|
| 364 |
+
"""Stop the embedded server."""
|
| 365 |
+
if self._process:
|
| 366 |
+
self._process.terminate()
|
| 367 |
+
try:
|
| 368 |
+
self._process.wait(timeout=10)
|
| 369 |
+
except subprocess.TimeoutExpired:
|
| 370 |
+
self._process.kill()
|
| 371 |
+
self._process = None
|
| 372 |
+
|
| 373 |
+
provider = DirectModeProvider(server_process)
|
| 374 |
+
client = cls(base_url=base_url, provider=provider)
|
| 375 |
+
return client
|
envs/dm_control_env/examples/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""dm_control examples."""
|
envs/dm_control_env/examples/cartpole_control.py
ADDED
|
@@ -0,0 +1,333 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Interactive cartpole control via OpenEnv.
|
| 3 |
+
|
| 4 |
+
This example demonstrates using the dm_control OpenEnv client with
|
| 5 |
+
the cartpole environment. Use arrow keys to control the cart.
|
| 6 |
+
|
| 7 |
+
Controls:
|
| 8 |
+
LEFT/RIGHT arrows: Apply force to move cart
|
| 9 |
+
R: Reset environment
|
| 10 |
+
ESC or Q: Quit
|
| 11 |
+
|
| 12 |
+
Requirements:
|
| 13 |
+
pip install pygame
|
| 14 |
+
|
| 15 |
+
Usage:
|
| 16 |
+
1. Start the server: uvicorn server.app:app --host 0.0.0.0 --port 8000
|
| 17 |
+
2. Run this script: python examples/cartpole_control.py
|
| 18 |
+
|
| 19 |
+
For visual mode (requires working MuJoCo rendering):
|
| 20 |
+
python examples/cartpole_control.py --visual
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
import argparse
|
| 24 |
+
import random
|
| 25 |
+
import sys
|
| 26 |
+
from pathlib import Path
|
| 27 |
+
|
| 28 |
+
# Add parent directory to path for imports
|
| 29 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 30 |
+
|
| 31 |
+
from client import DMControlEnv
|
| 32 |
+
from models import DMControlAction
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def run_headless(env: DMControlEnv, task: str = "balance", max_steps: int = 500):
|
| 36 |
+
"""Run cartpole control in headless mode."""
|
| 37 |
+
print("\n=== Headless Mode (OpenEnv Step/Observation Pattern) ===")
|
| 38 |
+
print("This mode demonstrates the OpenEnv API with the cartpole.\n")
|
| 39 |
+
|
| 40 |
+
# Reset environment using OpenEnv pattern
|
| 41 |
+
result = env.reset(domain_name="cartpole", task_name=task)
|
| 42 |
+
print(f"Initial observations: {list(result.observation.observations.keys())}")
|
| 43 |
+
print(f" position: {result.observation.observations.get('position', [])}")
|
| 44 |
+
print(f" velocity: {result.observation.observations.get('velocity', [])}")
|
| 45 |
+
|
| 46 |
+
total_reward = 0.0
|
| 47 |
+
step_count = 0
|
| 48 |
+
|
| 49 |
+
print("\nRunning with random actions to demonstrate step/observation pattern...\n")
|
| 50 |
+
|
| 51 |
+
while not result.done and step_count < max_steps:
|
| 52 |
+
# Random action in [-1, 1]
|
| 53 |
+
action_value = random.uniform(-1.0, 1.0)
|
| 54 |
+
|
| 55 |
+
# Step the environment using OpenEnv pattern
|
| 56 |
+
action = DMControlAction(values=[action_value])
|
| 57 |
+
result = env.step(action)
|
| 58 |
+
|
| 59 |
+
# Access observation and reward from result
|
| 60 |
+
total_reward += result.reward or 0.0
|
| 61 |
+
step_count += 1
|
| 62 |
+
|
| 63 |
+
# Print progress periodically
|
| 64 |
+
if step_count % 50 == 0:
|
| 65 |
+
pos = result.observation.observations.get("position", [])
|
| 66 |
+
vel = result.observation.observations.get("velocity", [])
|
| 67 |
+
print(
|
| 68 |
+
f"Step {step_count}: reward={result.reward:.3f}, "
|
| 69 |
+
f"total={total_reward:.2f}, done={result.done}"
|
| 70 |
+
)
|
| 71 |
+
print(f" position={pos}, velocity={vel}")
|
| 72 |
+
|
| 73 |
+
print(f"\nEpisode finished: {step_count} steps, total reward: {total_reward:.2f}")
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def run_interactive(env: DMControlEnv, task: str = "balance"):
|
| 77 |
+
"""Run interactive control with keyboard input via pygame."""
|
| 78 |
+
import pygame
|
| 79 |
+
|
| 80 |
+
print("\n=== Interactive Mode (OpenEnv Step/Observation Pattern) ===")
|
| 81 |
+
print("Use LEFT/RIGHT arrows to control cart, R to reset, ESC to quit.\n")
|
| 82 |
+
|
| 83 |
+
# Reset environment using OpenEnv pattern
|
| 84 |
+
result = env.reset(domain_name="cartpole", task_name=task)
|
| 85 |
+
print(f"Initial observations: {list(result.observation.observations.keys())}")
|
| 86 |
+
|
| 87 |
+
# Initialize pygame for keyboard input (minimal window)
|
| 88 |
+
pygame.init()
|
| 89 |
+
screen = pygame.display.set_mode((400, 100))
|
| 90 |
+
pygame.display.set_caption("Cartpole Control - Arrow keys to move, R to reset")
|
| 91 |
+
clock = pygame.time.Clock()
|
| 92 |
+
|
| 93 |
+
# Font for display
|
| 94 |
+
font = pygame.font.Font(None, 24)
|
| 95 |
+
|
| 96 |
+
running = True
|
| 97 |
+
total_reward = 0.0
|
| 98 |
+
step_count = 0
|
| 99 |
+
|
| 100 |
+
print("\nControls:")
|
| 101 |
+
print(" LEFT/RIGHT arrows: Move cart")
|
| 102 |
+
print(" R: Reset environment")
|
| 103 |
+
print(" ESC or Q: Quit\n")
|
| 104 |
+
|
| 105 |
+
while running:
|
| 106 |
+
# Handle events
|
| 107 |
+
for event in pygame.event.get():
|
| 108 |
+
if event.type == pygame.QUIT:
|
| 109 |
+
running = False
|
| 110 |
+
elif event.type == pygame.KEYDOWN:
|
| 111 |
+
if event.key in (pygame.K_ESCAPE, pygame.K_q):
|
| 112 |
+
running = False
|
| 113 |
+
elif event.key == pygame.K_r:
|
| 114 |
+
result = env.reset(domain_name="cartpole", task_name=task)
|
| 115 |
+
total_reward = 0.0
|
| 116 |
+
step_count = 0
|
| 117 |
+
print("Environment reset")
|
| 118 |
+
|
| 119 |
+
# Check for held keys (for continuous control)
|
| 120 |
+
keys = pygame.key.get_pressed()
|
| 121 |
+
if keys[pygame.K_LEFT]:
|
| 122 |
+
action_value = -1.0
|
| 123 |
+
elif keys[pygame.K_RIGHT]:
|
| 124 |
+
action_value = 1.0
|
| 125 |
+
else:
|
| 126 |
+
action_value = 0.0
|
| 127 |
+
|
| 128 |
+
# Step the environment using OpenEnv pattern
|
| 129 |
+
action = DMControlAction(values=[action_value])
|
| 130 |
+
result = env.step(action)
|
| 131 |
+
|
| 132 |
+
# Track reward from result
|
| 133 |
+
total_reward += result.reward or 0.0
|
| 134 |
+
step_count += 1
|
| 135 |
+
|
| 136 |
+
# Check if episode is done
|
| 137 |
+
if result.done:
|
| 138 |
+
print(
|
| 139 |
+
f"Episode finished! Steps: {step_count}, "
|
| 140 |
+
f"Total reward: {total_reward:.2f}"
|
| 141 |
+
)
|
| 142 |
+
# Auto-reset on done
|
| 143 |
+
result = env.reset(domain_name="cartpole", task_name=task)
|
| 144 |
+
total_reward = 0.0
|
| 145 |
+
step_count = 0
|
| 146 |
+
|
| 147 |
+
# Update display
|
| 148 |
+
direction = (
|
| 149 |
+
"<--" if action_value < 0 else ("-->" if action_value > 0 else "---")
|
| 150 |
+
)
|
| 151 |
+
screen.fill((30, 30, 30))
|
| 152 |
+
text = font.render(
|
| 153 |
+
f"Step: {step_count} | Reward: {total_reward:.1f} | {direction}",
|
| 154 |
+
True,
|
| 155 |
+
(255, 255, 255),
|
| 156 |
+
)
|
| 157 |
+
screen.blit(text, (10, 40))
|
| 158 |
+
pygame.display.flip()
|
| 159 |
+
|
| 160 |
+
# Print progress periodically
|
| 161 |
+
if step_count % 200 == 0 and step_count > 0:
|
| 162 |
+
print(f"Step {step_count}: Total reward: {total_reward:.2f}")
|
| 163 |
+
|
| 164 |
+
# Cap at 30 FPS
|
| 165 |
+
clock.tick(30)
|
| 166 |
+
|
| 167 |
+
pygame.quit()
|
| 168 |
+
print(f"Session ended. Final reward: {total_reward:.2f}")
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def run_visual(env: DMControlEnv, task: str = "balance"):
|
| 172 |
+
"""Run with pygame visualization showing rendered frames."""
|
| 173 |
+
import base64
|
| 174 |
+
import io
|
| 175 |
+
|
| 176 |
+
import pygame
|
| 177 |
+
|
| 178 |
+
print("\n=== Visual Mode (OpenEnv Step/Observation Pattern) ===")
|
| 179 |
+
|
| 180 |
+
# Reset environment with rendering enabled
|
| 181 |
+
result = env.reset(domain_name="cartpole", task_name=task, render=True)
|
| 182 |
+
print(f"Initial observations: {list(result.observation.observations.keys())}")
|
| 183 |
+
|
| 184 |
+
# Get first frame to determine window size
|
| 185 |
+
if result.observation.pixels is None:
|
| 186 |
+
print("Error: Server did not return rendered pixels.")
|
| 187 |
+
print("Make sure the server supports render=True")
|
| 188 |
+
print("\nTry running in interactive mode (default) instead.")
|
| 189 |
+
sys.exit(1)
|
| 190 |
+
|
| 191 |
+
# Decode base64 PNG to pygame surface
|
| 192 |
+
png_data = base64.b64decode(result.observation.pixels)
|
| 193 |
+
frame = pygame.image.load(io.BytesIO(png_data))
|
| 194 |
+
frame_size = frame.get_size()
|
| 195 |
+
|
| 196 |
+
# Initialize pygame
|
| 197 |
+
pygame.init()
|
| 198 |
+
screen = pygame.display.set_mode(frame_size)
|
| 199 |
+
pygame.display.set_caption(
|
| 200 |
+
"Cartpole (OpenEnv) - Arrow Keys to Move, R to Reset, ESC to Quit"
|
| 201 |
+
)
|
| 202 |
+
clock = pygame.time.Clock()
|
| 203 |
+
|
| 204 |
+
print("Controls:")
|
| 205 |
+
print(" LEFT/RIGHT arrows: Move cart")
|
| 206 |
+
print(" R: Reset environment")
|
| 207 |
+
print(" ESC or Q: Quit")
|
| 208 |
+
|
| 209 |
+
running = True
|
| 210 |
+
total_reward = 0.0
|
| 211 |
+
step_count = 0
|
| 212 |
+
|
| 213 |
+
while running:
|
| 214 |
+
# Handle events
|
| 215 |
+
for event in pygame.event.get():
|
| 216 |
+
if event.type == pygame.QUIT:
|
| 217 |
+
running = False
|
| 218 |
+
elif event.type == pygame.KEYDOWN:
|
| 219 |
+
if event.key in (pygame.K_ESCAPE, pygame.K_q):
|
| 220 |
+
running = False
|
| 221 |
+
elif event.key == pygame.K_r:
|
| 222 |
+
result = env.reset(
|
| 223 |
+
domain_name="cartpole", task_name=task, render=True
|
| 224 |
+
)
|
| 225 |
+
total_reward = 0.0
|
| 226 |
+
step_count = 0
|
| 227 |
+
print("Environment reset")
|
| 228 |
+
|
| 229 |
+
# Check for held keys (for continuous control)
|
| 230 |
+
keys = pygame.key.get_pressed()
|
| 231 |
+
if keys[pygame.K_LEFT]:
|
| 232 |
+
action_value = -1.0
|
| 233 |
+
elif keys[pygame.K_RIGHT]:
|
| 234 |
+
action_value = 1.0
|
| 235 |
+
else:
|
| 236 |
+
action_value = 0.0
|
| 237 |
+
|
| 238 |
+
# Step the environment using OpenEnv pattern
|
| 239 |
+
action = DMControlAction(values=[action_value])
|
| 240 |
+
result = env.step(action, render=True)
|
| 241 |
+
|
| 242 |
+
# Track reward from result
|
| 243 |
+
total_reward += result.reward or 0.0
|
| 244 |
+
step_count += 1
|
| 245 |
+
|
| 246 |
+
# Check if episode is done
|
| 247 |
+
if result.done:
|
| 248 |
+
print(
|
| 249 |
+
f"Episode finished! Steps: {step_count}, "
|
| 250 |
+
f"Total reward: {total_reward:.2f}"
|
| 251 |
+
)
|
| 252 |
+
result = env.reset(domain_name="cartpole", task_name=task, render=True)
|
| 253 |
+
total_reward = 0.0
|
| 254 |
+
step_count = 0
|
| 255 |
+
|
| 256 |
+
# Render the frame from observation pixels
|
| 257 |
+
if result.observation.pixels:
|
| 258 |
+
png_data = base64.b64decode(result.observation.pixels)
|
| 259 |
+
frame = pygame.image.load(io.BytesIO(png_data))
|
| 260 |
+
screen.blit(frame, (0, 0))
|
| 261 |
+
pygame.display.flip()
|
| 262 |
+
|
| 263 |
+
# Print progress periodically
|
| 264 |
+
if step_count % 200 == 0 and step_count > 0:
|
| 265 |
+
print(f"Step {step_count}: Total reward: {total_reward:.2f}")
|
| 266 |
+
|
| 267 |
+
# Cap at 30 FPS
|
| 268 |
+
clock.tick(30)
|
| 269 |
+
|
| 270 |
+
pygame.quit()
|
| 271 |
+
print(f"Session ended. Final reward: {total_reward:.2f}")
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def main():
|
| 275 |
+
parser = argparse.ArgumentParser(
|
| 276 |
+
description="Interactive cartpole control via OpenEnv"
|
| 277 |
+
)
|
| 278 |
+
parser.add_argument(
|
| 279 |
+
"--visual",
|
| 280 |
+
action="store_true",
|
| 281 |
+
help="Enable pygame visualization with rendered frames",
|
| 282 |
+
)
|
| 283 |
+
parser.add_argument(
|
| 284 |
+
"--headless",
|
| 285 |
+
action="store_true",
|
| 286 |
+
help="Run in headless mode (no pygame, automated control)",
|
| 287 |
+
)
|
| 288 |
+
parser.add_argument(
|
| 289 |
+
"--max-steps",
|
| 290 |
+
type=int,
|
| 291 |
+
default=500,
|
| 292 |
+
help="Maximum steps for headless mode (default: 500)",
|
| 293 |
+
)
|
| 294 |
+
parser.add_argument(
|
| 295 |
+
"--task",
|
| 296 |
+
type=str,
|
| 297 |
+
default="balance",
|
| 298 |
+
choices=["balance", "balance_sparse", "swingup", "swingup_sparse"],
|
| 299 |
+
help="Cartpole task (default: balance)",
|
| 300 |
+
)
|
| 301 |
+
args = parser.parse_args()
|
| 302 |
+
|
| 303 |
+
server_url = "http://localhost:8000"
|
| 304 |
+
print(f"Connecting to {server_url}...")
|
| 305 |
+
|
| 306 |
+
try:
|
| 307 |
+
with DMControlEnv(base_url=server_url) as env:
|
| 308 |
+
print("Connected!")
|
| 309 |
+
|
| 310 |
+
# Get environment state
|
| 311 |
+
state = env.state()
|
| 312 |
+
print(f"Domain: {state.domain_name}, Task: {state.task_name}")
|
| 313 |
+
print(f"Action spec: {state.action_spec}")
|
| 314 |
+
|
| 315 |
+
if args.headless:
|
| 316 |
+
run_headless(env, task=args.task, max_steps=args.max_steps)
|
| 317 |
+
elif args.visual:
|
| 318 |
+
run_visual(env, task=args.task)
|
| 319 |
+
else:
|
| 320 |
+
run_interactive(env, task=args.task)
|
| 321 |
+
|
| 322 |
+
except ConnectionError as e:
|
| 323 |
+
print(f"Failed to connect: {e}")
|
| 324 |
+
print("\nMake sure the server is running:")
|
| 325 |
+
print(" cd OpenEnv")
|
| 326 |
+
print(
|
| 327 |
+
" PYTHONPATH=src:envs uvicorn envs.dm_control_env.server.app:app --port 8000"
|
| 328 |
+
)
|
| 329 |
+
sys.exit(1)
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
if __name__ == "__main__":
|
| 333 |
+
main()
|
envs/dm_control_env/examples/hopper_control.py
ADDED
|
@@ -0,0 +1,374 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Interactive hopper control via OpenEnv.
|
| 3 |
+
|
| 4 |
+
This example demonstrates using the dm_control OpenEnv client with
|
| 5 |
+
the hopper environment. Press SPACE to apply random forces to the joints.
|
| 6 |
+
|
| 7 |
+
Controls:
|
| 8 |
+
SPACE: Apply random force to all joints
|
| 9 |
+
R: Reset environment
|
| 10 |
+
ESC or Q: Quit
|
| 11 |
+
|
| 12 |
+
Requirements:
|
| 13 |
+
pip install pygame
|
| 14 |
+
|
| 15 |
+
Usage:
|
| 16 |
+
1. Start the server: uvicorn server.app:app --host 0.0.0.0 --port 8000
|
| 17 |
+
2. Run this script: python examples/hopper_control.py
|
| 18 |
+
|
| 19 |
+
For visual mode (requires working MuJoCo rendering):
|
| 20 |
+
python examples/hopper_control.py --visual
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
import argparse
|
| 24 |
+
import random
|
| 25 |
+
import sys
|
| 26 |
+
from pathlib import Path
|
| 27 |
+
|
| 28 |
+
# Add parent directory to path for imports
|
| 29 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 30 |
+
|
| 31 |
+
from client import DMControlEnv
|
| 32 |
+
from models import DMControlAction
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def get_action_dim(env: DMControlEnv) -> int:
|
| 36 |
+
"""Get the action dimension from the environment state."""
|
| 37 |
+
state = env.state()
|
| 38 |
+
action_spec = state.action_spec
|
| 39 |
+
if action_spec and "shape" in action_spec:
|
| 40 |
+
shape = action_spec["shape"]
|
| 41 |
+
if isinstance(shape, list) and len(shape) > 0:
|
| 42 |
+
return shape[0]
|
| 43 |
+
# Hopper default: 4 actuators (hip, knee, ankle, toe)
|
| 44 |
+
return 4
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def generate_random_action(action_dim: int, magnitude: float = 1.0) -> DMControlAction:
|
| 48 |
+
"""Generate a random action with values in [-magnitude, magnitude]."""
|
| 49 |
+
values = [random.uniform(-magnitude, magnitude) for _ in range(action_dim)]
|
| 50 |
+
return DMControlAction(values=values)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def generate_zero_action(action_dim: int) -> DMControlAction:
|
| 54 |
+
"""Generate a zero action (no force applied)."""
|
| 55 |
+
return DMControlAction(values=[0.0] * action_dim)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def run_headless(env: DMControlEnv, task: str = "hop", max_steps: int = 1000):
|
| 59 |
+
"""Run hopper control in headless mode."""
|
| 60 |
+
print("\n=== Headless Mode (OpenEnv Step/Observation Pattern) ===")
|
| 61 |
+
print("This mode demonstrates the OpenEnv API with the hopper.\n")
|
| 62 |
+
|
| 63 |
+
# Reset environment using OpenEnv pattern
|
| 64 |
+
result = env.reset(domain_name="hopper", task_name=task)
|
| 65 |
+
print(f"Initial observations: {list(result.observation.observations.keys())}")
|
| 66 |
+
|
| 67 |
+
# Get action dimension
|
| 68 |
+
action_dim = get_action_dim(env)
|
| 69 |
+
print(f"Action dimension: {action_dim}")
|
| 70 |
+
|
| 71 |
+
total_reward = 0.0
|
| 72 |
+
step_count = 0
|
| 73 |
+
|
| 74 |
+
print("\nRunning with periodic random forces...")
|
| 75 |
+
print("Every 30 steps, a random force burst is applied.\n")
|
| 76 |
+
|
| 77 |
+
while not result.done and step_count < max_steps:
|
| 78 |
+
# Apply random force every 30 steps, otherwise zero action
|
| 79 |
+
if step_count % 30 < 5:
|
| 80 |
+
# Random force burst for 5 steps
|
| 81 |
+
action = generate_random_action(action_dim, magnitude=0.8)
|
| 82 |
+
else:
|
| 83 |
+
# No force
|
| 84 |
+
action = generate_zero_action(action_dim)
|
| 85 |
+
|
| 86 |
+
# Step the environment using OpenEnv pattern
|
| 87 |
+
result = env.step(action)
|
| 88 |
+
|
| 89 |
+
# Access observation and reward from result
|
| 90 |
+
total_reward += result.reward or 0.0
|
| 91 |
+
step_count += 1
|
| 92 |
+
|
| 93 |
+
# Print progress periodically
|
| 94 |
+
if step_count % 100 == 0:
|
| 95 |
+
# Get some observation values
|
| 96 |
+
position = result.observation.observations.get("position", [])
|
| 97 |
+
velocity = result.observation.observations.get("velocity", [])
|
| 98 |
+
print(
|
| 99 |
+
f"Step {step_count}: reward={result.reward:.3f}, "
|
| 100 |
+
f"total={total_reward:.2f}, done={result.done}"
|
| 101 |
+
)
|
| 102 |
+
if position:
|
| 103 |
+
print(f" position: {position[:3]}")
|
| 104 |
+
if velocity:
|
| 105 |
+
print(f" velocity: {velocity[:3]}")
|
| 106 |
+
|
| 107 |
+
print(f"\nEpisode finished: {step_count} steps, total reward: {total_reward:.2f}")
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def run_interactive(env: DMControlEnv, task: str = "hop"):
|
| 111 |
+
"""Run interactive control with keyboard input via pygame."""
|
| 112 |
+
import pygame
|
| 113 |
+
|
| 114 |
+
print("\n=== Interactive Mode (OpenEnv Step/Observation Pattern) ===")
|
| 115 |
+
print("Press SPACE to apply random force, R to reset, ESC to quit.\n")
|
| 116 |
+
|
| 117 |
+
# Reset environment using OpenEnv pattern
|
| 118 |
+
result = env.reset(domain_name="hopper", task_name=task)
|
| 119 |
+
print(f"Initial observations: {list(result.observation.observations.keys())}")
|
| 120 |
+
|
| 121 |
+
# Get action dimension
|
| 122 |
+
action_dim = get_action_dim(env)
|
| 123 |
+
print(f"Action dimension: {action_dim}")
|
| 124 |
+
|
| 125 |
+
# Initialize pygame for keyboard input (minimal window)
|
| 126 |
+
pygame.init()
|
| 127 |
+
screen = pygame.display.set_mode((400, 100))
|
| 128 |
+
pygame.display.set_caption("Hopper Control - SPACE for random force, R to reset")
|
| 129 |
+
clock = pygame.time.Clock()
|
| 130 |
+
|
| 131 |
+
# Font for display
|
| 132 |
+
font = pygame.font.Font(None, 24)
|
| 133 |
+
|
| 134 |
+
running = True
|
| 135 |
+
total_reward = 0.0
|
| 136 |
+
step_count = 0
|
| 137 |
+
apply_random_force = False
|
| 138 |
+
|
| 139 |
+
print("\nControls:")
|
| 140 |
+
print(" SPACE: Apply random force to joints")
|
| 141 |
+
print(" R: Reset environment")
|
| 142 |
+
print(" ESC or Q: Quit\n")
|
| 143 |
+
|
| 144 |
+
while running:
|
| 145 |
+
# Handle events
|
| 146 |
+
for event in pygame.event.get():
|
| 147 |
+
if event.type == pygame.QUIT:
|
| 148 |
+
running = False
|
| 149 |
+
elif event.type == pygame.KEYDOWN:
|
| 150 |
+
if event.key in (pygame.K_ESCAPE, pygame.K_q):
|
| 151 |
+
running = False
|
| 152 |
+
elif event.key == pygame.K_r:
|
| 153 |
+
result = env.reset(domain_name="hopper", task_name=task)
|
| 154 |
+
total_reward = 0.0
|
| 155 |
+
step_count = 0
|
| 156 |
+
print("Environment reset")
|
| 157 |
+
|
| 158 |
+
# Check for held keys
|
| 159 |
+
keys = pygame.key.get_pressed()
|
| 160 |
+
apply_random_force = keys[pygame.K_SPACE]
|
| 161 |
+
|
| 162 |
+
# Generate action based on input
|
| 163 |
+
if apply_random_force:
|
| 164 |
+
action = generate_random_action(action_dim, magnitude=2.0)
|
| 165 |
+
else:
|
| 166 |
+
action = generate_zero_action(action_dim)
|
| 167 |
+
|
| 168 |
+
# Step the environment using OpenEnv pattern
|
| 169 |
+
result = env.step(action)
|
| 170 |
+
|
| 171 |
+
# Track reward from result
|
| 172 |
+
total_reward += result.reward or 0.0
|
| 173 |
+
step_count += 1
|
| 174 |
+
|
| 175 |
+
# Check if episode is done
|
| 176 |
+
if result.done:
|
| 177 |
+
print(
|
| 178 |
+
f"Episode finished! Steps: {step_count}, "
|
| 179 |
+
f"Total reward: {total_reward:.2f}"
|
| 180 |
+
)
|
| 181 |
+
# Auto-reset on done
|
| 182 |
+
result = env.reset(domain_name="hopper", task_name=task)
|
| 183 |
+
total_reward = 0.0
|
| 184 |
+
step_count = 0
|
| 185 |
+
|
| 186 |
+
# Update display
|
| 187 |
+
screen.fill((30, 30, 30))
|
| 188 |
+
status = "FORCE!" if apply_random_force else "idle"
|
| 189 |
+
text = font.render(
|
| 190 |
+
f"Step: {step_count} | Reward: {total_reward:.1f} | {status}",
|
| 191 |
+
True,
|
| 192 |
+
(255, 255, 255),
|
| 193 |
+
)
|
| 194 |
+
screen.blit(text, (10, 40))
|
| 195 |
+
pygame.display.flip()
|
| 196 |
+
|
| 197 |
+
# Print progress periodically
|
| 198 |
+
if step_count % 200 == 0 and step_count > 0:
|
| 199 |
+
print(f"Step {step_count}: Total reward: {total_reward:.2f}")
|
| 200 |
+
|
| 201 |
+
# Cap at 30 FPS
|
| 202 |
+
clock.tick(30)
|
| 203 |
+
|
| 204 |
+
pygame.quit()
|
| 205 |
+
print(f"Session ended. Final reward: {total_reward:.2f}")
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def run_visual(env: DMControlEnv, task: str = "hop"):
|
| 209 |
+
"""Run with pygame visualization showing rendered frames."""
|
| 210 |
+
import base64
|
| 211 |
+
import io
|
| 212 |
+
|
| 213 |
+
import pygame
|
| 214 |
+
|
| 215 |
+
print("\n=== Visual Mode (OpenEnv Step/Observation Pattern) ===")
|
| 216 |
+
|
| 217 |
+
# Reset environment with rendering enabled
|
| 218 |
+
result = env.reset(domain_name="hopper", task_name=task, render=True)
|
| 219 |
+
print(f"Initial observations: {list(result.observation.observations.keys())}")
|
| 220 |
+
|
| 221 |
+
# Get action dimension
|
| 222 |
+
action_dim = get_action_dim(env)
|
| 223 |
+
print(f"Action dimension: {action_dim}")
|
| 224 |
+
|
| 225 |
+
# Get first frame to determine window size
|
| 226 |
+
if result.observation.pixels is None:
|
| 227 |
+
print("Error: Server did not return rendered pixels.")
|
| 228 |
+
print("Make sure the server supports render=True")
|
| 229 |
+
print("\nTry running in interactive mode (default) instead.")
|
| 230 |
+
sys.exit(1)
|
| 231 |
+
|
| 232 |
+
# Decode base64 PNG to pygame surface
|
| 233 |
+
png_data = base64.b64decode(result.observation.pixels)
|
| 234 |
+
frame = pygame.image.load(io.BytesIO(png_data))
|
| 235 |
+
frame_size = frame.get_size()
|
| 236 |
+
|
| 237 |
+
# Initialize pygame
|
| 238 |
+
pygame.init()
|
| 239 |
+
screen = pygame.display.set_mode(frame_size)
|
| 240 |
+
pygame.display.set_caption(
|
| 241 |
+
"Hopper (OpenEnv) - SPACE for random force, R to Reset, ESC to Quit"
|
| 242 |
+
)
|
| 243 |
+
clock = pygame.time.Clock()
|
| 244 |
+
|
| 245 |
+
print("Controls:")
|
| 246 |
+
print(" SPACE: Apply random force to joints")
|
| 247 |
+
print(" R: Reset environment")
|
| 248 |
+
print(" ESC or Q: Quit")
|
| 249 |
+
|
| 250 |
+
running = True
|
| 251 |
+
total_reward = 0.0
|
| 252 |
+
step_count = 0
|
| 253 |
+
|
| 254 |
+
while running:
|
| 255 |
+
# Handle events
|
| 256 |
+
for event in pygame.event.get():
|
| 257 |
+
if event.type == pygame.QUIT:
|
| 258 |
+
running = False
|
| 259 |
+
elif event.type == pygame.KEYDOWN:
|
| 260 |
+
if event.key in (pygame.K_ESCAPE, pygame.K_q):
|
| 261 |
+
running = False
|
| 262 |
+
elif event.key == pygame.K_r:
|
| 263 |
+
result = env.reset(
|
| 264 |
+
domain_name="hopper", task_name=task, render=True
|
| 265 |
+
)
|
| 266 |
+
total_reward = 0.0
|
| 267 |
+
step_count = 0
|
| 268 |
+
print("Environment reset")
|
| 269 |
+
|
| 270 |
+
# Check for held keys
|
| 271 |
+
keys = pygame.key.get_pressed()
|
| 272 |
+
apply_random_force = keys[pygame.K_SPACE]
|
| 273 |
+
|
| 274 |
+
# Generate action based on input
|
| 275 |
+
if apply_random_force:
|
| 276 |
+
action = generate_random_action(action_dim, magnitude=2.0)
|
| 277 |
+
else:
|
| 278 |
+
action = generate_zero_action(action_dim)
|
| 279 |
+
|
| 280 |
+
# Step the environment using OpenEnv pattern
|
| 281 |
+
result = env.step(action, render=True)
|
| 282 |
+
|
| 283 |
+
# Track reward from result
|
| 284 |
+
total_reward += result.reward or 0.0
|
| 285 |
+
step_count += 1
|
| 286 |
+
|
| 287 |
+
# Check if episode is done
|
| 288 |
+
if result.done:
|
| 289 |
+
print(
|
| 290 |
+
f"Episode finished! Steps: {step_count}, "
|
| 291 |
+
f"Total reward: {total_reward:.2f}"
|
| 292 |
+
)
|
| 293 |
+
result = env.reset(domain_name="hopper", task_name=task, render=True)
|
| 294 |
+
total_reward = 0.0
|
| 295 |
+
step_count = 0
|
| 296 |
+
|
| 297 |
+
# Render the frame from observation pixels
|
| 298 |
+
if result.observation.pixels:
|
| 299 |
+
png_data = base64.b64decode(result.observation.pixels)
|
| 300 |
+
frame = pygame.image.load(io.BytesIO(png_data))
|
| 301 |
+
screen.blit(frame, (0, 0))
|
| 302 |
+
pygame.display.flip()
|
| 303 |
+
|
| 304 |
+
# Print progress periodically
|
| 305 |
+
if step_count % 200 == 0 and step_count > 0:
|
| 306 |
+
print(f"Step {step_count}: Total reward: {total_reward:.2f}")
|
| 307 |
+
|
| 308 |
+
# Cap at 30 FPS
|
| 309 |
+
clock.tick(30)
|
| 310 |
+
|
| 311 |
+
pygame.quit()
|
| 312 |
+
print(f"Session ended. Final reward: {total_reward:.2f}")
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
def main():
|
| 316 |
+
parser = argparse.ArgumentParser(
|
| 317 |
+
description="Interactive hopper control via OpenEnv"
|
| 318 |
+
)
|
| 319 |
+
parser.add_argument(
|
| 320 |
+
"--visual",
|
| 321 |
+
action="store_true",
|
| 322 |
+
help="Enable pygame visualization with rendered frames",
|
| 323 |
+
)
|
| 324 |
+
parser.add_argument(
|
| 325 |
+
"--headless",
|
| 326 |
+
action="store_true",
|
| 327 |
+
help="Run in headless mode (no pygame, automated control)",
|
| 328 |
+
)
|
| 329 |
+
parser.add_argument(
|
| 330 |
+
"--max-steps",
|
| 331 |
+
type=int,
|
| 332 |
+
default=1000,
|
| 333 |
+
help="Maximum steps for headless mode (default: 1000)",
|
| 334 |
+
)
|
| 335 |
+
parser.add_argument(
|
| 336 |
+
"--task",
|
| 337 |
+
type=str,
|
| 338 |
+
default="hop",
|
| 339 |
+
choices=["stand", "hop"],
|
| 340 |
+
help="Hopper task (default: hop)",
|
| 341 |
+
)
|
| 342 |
+
args = parser.parse_args()
|
| 343 |
+
|
| 344 |
+
server_url = "http://localhost:8000"
|
| 345 |
+
print(f"Connecting to {server_url}...")
|
| 346 |
+
|
| 347 |
+
try:
|
| 348 |
+
with DMControlEnv(base_url=server_url) as env:
|
| 349 |
+
print("Connected!")
|
| 350 |
+
|
| 351 |
+
# Get environment state
|
| 352 |
+
state = env.state()
|
| 353 |
+
print(f"Domain: {state.domain_name}, Task: {state.task_name}")
|
| 354 |
+
print(f"Action spec: {state.action_spec}")
|
| 355 |
+
|
| 356 |
+
if args.headless:
|
| 357 |
+
run_headless(env, task=args.task, max_steps=args.max_steps)
|
| 358 |
+
elif args.visual:
|
| 359 |
+
run_visual(env, task=args.task)
|
| 360 |
+
else:
|
| 361 |
+
run_interactive(env, task=args.task)
|
| 362 |
+
|
| 363 |
+
except ConnectionError as e:
|
| 364 |
+
print(f"Failed to connect: {e}")
|
| 365 |
+
print("\nMake sure the server is running:")
|
| 366 |
+
print(" cd OpenEnv")
|
| 367 |
+
print(
|
| 368 |
+
" PYTHONPATH=src:envs uvicorn envs.dm_control_env.server.app:app --port 8000"
|
| 369 |
+
)
|
| 370 |
+
sys.exit(1)
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
if __name__ == "__main__":
|
| 374 |
+
main()
|
envs/dm_control_env/examples/list_environments.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""List all available dm_control.suite environments.
|
| 3 |
+
|
| 4 |
+
This utility prints all available domain/task combinations from dm_control.suite.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from dm_control import suite
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def main():
|
| 11 |
+
print("Available dm_control.suite environments:")
|
| 12 |
+
print("=" * 50)
|
| 13 |
+
|
| 14 |
+
# Group by domain
|
| 15 |
+
domains = {}
|
| 16 |
+
for domain, task in suite.BENCHMARKING:
|
| 17 |
+
if domain not in domains:
|
| 18 |
+
domains[domain] = []
|
| 19 |
+
domains[domain].append(task)
|
| 20 |
+
|
| 21 |
+
for domain in sorted(domains.keys()):
|
| 22 |
+
tasks = sorted(domains[domain])
|
| 23 |
+
print(f"\n{domain}:")
|
| 24 |
+
for task in tasks:
|
| 25 |
+
# Load env to get action spec
|
| 26 |
+
try:
|
| 27 |
+
env = suite.load(domain_name=domain, task_name=task)
|
| 28 |
+
action_spec = env.action_spec()
|
| 29 |
+
action_dim = action_spec.shape[0]
|
| 30 |
+
obs_keys = list(env.observation_spec().keys())
|
| 31 |
+
env.close()
|
| 32 |
+
print(f" - {task:20s} (action_dim={action_dim}, obs={obs_keys})")
|
| 33 |
+
except Exception as e:
|
| 34 |
+
print(f" - {task:20s} (error: {e})")
|
| 35 |
+
|
| 36 |
+
print("\n" + "=" * 50)
|
| 37 |
+
print(f"Total: {len(suite.BENCHMARKING)} environments")
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
if __name__ == "__main__":
|
| 41 |
+
main()
|
envs/dm_control_env/examples/quadruped_control.py
ADDED
|
@@ -0,0 +1,373 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Interactive quadruped control via OpenEnv.
|
| 3 |
+
|
| 4 |
+
This example demonstrates using the dm_control OpenEnv client with
|
| 5 |
+
the quadruped environment. Press SPACE to apply random forces to the joints.
|
| 6 |
+
|
| 7 |
+
Controls:
|
| 8 |
+
SPACE: Apply random force to all joints
|
| 9 |
+
R: Reset environment
|
| 10 |
+
ESC or Q: Quit
|
| 11 |
+
|
| 12 |
+
Requirements:
|
| 13 |
+
pip install pygame
|
| 14 |
+
|
| 15 |
+
Usage:
|
| 16 |
+
1. Start the server: uvicorn server.app:app --host 0.0.0.0 --port 8000
|
| 17 |
+
2. Run this script: python examples/quadruped_control.py
|
| 18 |
+
|
| 19 |
+
For visual mode (requires working MuJoCo rendering):
|
| 20 |
+
python examples/quadruped_control.py --visual
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
import argparse
|
| 24 |
+
import random
|
| 25 |
+
import sys
|
| 26 |
+
from pathlib import Path
|
| 27 |
+
|
| 28 |
+
# Add parent directory to path for imports
|
| 29 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 30 |
+
|
| 31 |
+
from client import DMControlEnv
|
| 32 |
+
from models import DMControlAction
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def get_action_dim(env: DMControlEnv) -> int:
|
| 36 |
+
"""Get the action dimension from the environment state."""
|
| 37 |
+
state = env.state()
|
| 38 |
+
action_spec = state.action_spec
|
| 39 |
+
if action_spec and "shape" in action_spec:
|
| 40 |
+
shape = action_spec["shape"]
|
| 41 |
+
if isinstance(shape, list) and len(shape) > 0:
|
| 42 |
+
return shape[0]
|
| 43 |
+
# Quadruped default: 12 actuators (3 per leg x 4 legs)
|
| 44 |
+
return 12
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def generate_random_action(action_dim: int, magnitude: float = 1.0) -> DMControlAction:
|
| 48 |
+
"""Generate a random action with values in [-magnitude, magnitude]."""
|
| 49 |
+
values = [random.uniform(-magnitude, magnitude) for _ in range(action_dim)]
|
| 50 |
+
return DMControlAction(values=values)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def generate_zero_action(action_dim: int) -> DMControlAction:
|
| 54 |
+
"""Generate a zero action (no force applied)."""
|
| 55 |
+
return DMControlAction(values=[0.0] * action_dim)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def run_headless(env: DMControlEnv, max_steps: int = 1000):
|
| 59 |
+
"""Run quadruped control in headless mode."""
|
| 60 |
+
print("\n=== Headless Mode (OpenEnv Step/Observation Pattern) ===")
|
| 61 |
+
print("This mode demonstrates the OpenEnv API with the quadruped.\n")
|
| 62 |
+
|
| 63 |
+
# Reset environment using OpenEnv pattern
|
| 64 |
+
result = env.reset(domain_name="quadruped", task_name="walk")
|
| 65 |
+
print(f"Initial observations: {list(result.observation.observations.keys())}")
|
| 66 |
+
|
| 67 |
+
# Get action dimension
|
| 68 |
+
action_dim = get_action_dim(env)
|
| 69 |
+
print(f"Action dimension: {action_dim}")
|
| 70 |
+
|
| 71 |
+
total_reward = 0.0
|
| 72 |
+
step_count = 0
|
| 73 |
+
|
| 74 |
+
print("\nRunning with periodic random forces...")
|
| 75 |
+
print("Every 50 steps, a random force burst is applied.\n")
|
| 76 |
+
|
| 77 |
+
while not result.done and step_count < max_steps:
|
| 78 |
+
# Apply random force every 50 steps, otherwise zero action
|
| 79 |
+
if step_count % 50 < 10:
|
| 80 |
+
# Random force burst for 10 steps
|
| 81 |
+
action = generate_random_action(action_dim, magnitude=0.5)
|
| 82 |
+
else:
|
| 83 |
+
# No force
|
| 84 |
+
action = generate_zero_action(action_dim)
|
| 85 |
+
|
| 86 |
+
# Step the environment using OpenEnv pattern
|
| 87 |
+
result = env.step(action)
|
| 88 |
+
|
| 89 |
+
# Access observation and reward from result
|
| 90 |
+
total_reward += result.reward or 0.0
|
| 91 |
+
step_count += 1
|
| 92 |
+
|
| 93 |
+
# Print progress periodically
|
| 94 |
+
if step_count % 100 == 0:
|
| 95 |
+
# Get some observation values
|
| 96 |
+
egocentric_state = result.observation.observations.get(
|
| 97 |
+
"egocentric_state", []
|
| 98 |
+
)
|
| 99 |
+
print(
|
| 100 |
+
f"Step {step_count}: reward={result.reward:.3f}, "
|
| 101 |
+
f"total={total_reward:.2f}, done={result.done}"
|
| 102 |
+
)
|
| 103 |
+
if egocentric_state:
|
| 104 |
+
print(f" egocentric_state (first 5): {egocentric_state[:5]}")
|
| 105 |
+
|
| 106 |
+
print(f"\nEpisode finished: {step_count} steps, total reward: {total_reward:.2f}")
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def run_interactive(env: DMControlEnv):
|
| 110 |
+
"""Run interactive control with keyboard input via pygame."""
|
| 111 |
+
import pygame
|
| 112 |
+
|
| 113 |
+
print("\n=== Interactive Mode (OpenEnv Step/Observation Pattern) ===")
|
| 114 |
+
print("Press SPACE to apply random force, R to reset, ESC to quit.\n")
|
| 115 |
+
|
| 116 |
+
# Reset environment using OpenEnv pattern
|
| 117 |
+
result = env.reset(domain_name="quadruped", task_name="walk")
|
| 118 |
+
print(f"Initial observations: {list(result.observation.observations.keys())}")
|
| 119 |
+
|
| 120 |
+
# Get action dimension
|
| 121 |
+
action_dim = get_action_dim(env)
|
| 122 |
+
print(f"Action dimension: {action_dim}")
|
| 123 |
+
|
| 124 |
+
# Initialize pygame for keyboard input (minimal window)
|
| 125 |
+
pygame.init()
|
| 126 |
+
screen = pygame.display.set_mode((400, 100))
|
| 127 |
+
pygame.display.set_caption("Quadruped Control - SPACE for random force, R to reset")
|
| 128 |
+
clock = pygame.time.Clock()
|
| 129 |
+
|
| 130 |
+
# Draw instructions on the window
|
| 131 |
+
font = pygame.font.Font(None, 24)
|
| 132 |
+
|
| 133 |
+
running = True
|
| 134 |
+
total_reward = 0.0
|
| 135 |
+
step_count = 0
|
| 136 |
+
apply_random_force = False
|
| 137 |
+
|
| 138 |
+
print("\nControls:")
|
| 139 |
+
print(" SPACE: Apply random force to joints")
|
| 140 |
+
print(" R: Reset environment")
|
| 141 |
+
print(" ESC or Q: Quit\n")
|
| 142 |
+
|
| 143 |
+
while running:
|
| 144 |
+
# Handle events
|
| 145 |
+
for event in pygame.event.get():
|
| 146 |
+
if event.type == pygame.QUIT:
|
| 147 |
+
running = False
|
| 148 |
+
elif event.type == pygame.KEYDOWN:
|
| 149 |
+
if event.key in (pygame.K_ESCAPE, pygame.K_q):
|
| 150 |
+
running = False
|
| 151 |
+
elif event.key == pygame.K_r:
|
| 152 |
+
result = env.reset(domain_name="quadruped", task_name="walk")
|
| 153 |
+
total_reward = 0.0
|
| 154 |
+
step_count = 0
|
| 155 |
+
print("Environment reset")
|
| 156 |
+
|
| 157 |
+
# Check for held keys
|
| 158 |
+
keys = pygame.key.get_pressed()
|
| 159 |
+
apply_random_force = keys[pygame.K_SPACE]
|
| 160 |
+
|
| 161 |
+
# Generate action based on input
|
| 162 |
+
if apply_random_force:
|
| 163 |
+
action = generate_random_action(action_dim, magnitude=2.0)
|
| 164 |
+
else:
|
| 165 |
+
action = generate_zero_action(action_dim)
|
| 166 |
+
|
| 167 |
+
# Step the environment using OpenEnv pattern
|
| 168 |
+
result = env.step(action)
|
| 169 |
+
|
| 170 |
+
# Track reward from result
|
| 171 |
+
total_reward += result.reward or 0.0
|
| 172 |
+
step_count += 1
|
| 173 |
+
|
| 174 |
+
# Check if episode is done
|
| 175 |
+
if result.done:
|
| 176 |
+
print(
|
| 177 |
+
f"Episode finished! Steps: {step_count}, "
|
| 178 |
+
f"Total reward: {total_reward:.2f}"
|
| 179 |
+
)
|
| 180 |
+
# Auto-reset on done
|
| 181 |
+
result = env.reset(domain_name="quadruped", task_name="walk")
|
| 182 |
+
total_reward = 0.0
|
| 183 |
+
step_count = 0
|
| 184 |
+
|
| 185 |
+
# Update display
|
| 186 |
+
screen.fill((30, 30, 30))
|
| 187 |
+
status = "FORCE!" if apply_random_force else "idle"
|
| 188 |
+
text = font.render(
|
| 189 |
+
f"Step: {step_count} | Reward: {total_reward:.1f} | {status}",
|
| 190 |
+
True,
|
| 191 |
+
(255, 255, 255),
|
| 192 |
+
)
|
| 193 |
+
screen.blit(text, (10, 40))
|
| 194 |
+
pygame.display.flip()
|
| 195 |
+
|
| 196 |
+
# Print progress periodically
|
| 197 |
+
if step_count % 200 == 0 and step_count > 0:
|
| 198 |
+
print(f"Step {step_count}: Total reward: {total_reward:.2f}")
|
| 199 |
+
|
| 200 |
+
# Cap at 30 FPS
|
| 201 |
+
clock.tick(30)
|
| 202 |
+
|
| 203 |
+
pygame.quit()
|
| 204 |
+
print(f"Session ended. Final reward: {total_reward:.2f}")
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def run_visual(env: DMControlEnv):
|
| 208 |
+
"""Run with pygame visualization showing rendered frames."""
|
| 209 |
+
import base64
|
| 210 |
+
import io
|
| 211 |
+
|
| 212 |
+
import pygame
|
| 213 |
+
|
| 214 |
+
print("\n=== Visual Mode (OpenEnv Step/Observation Pattern) ===")
|
| 215 |
+
|
| 216 |
+
# Reset environment with rendering enabled
|
| 217 |
+
result = env.reset(domain_name="quadruped", task_name="walk", render=True)
|
| 218 |
+
print(f"Initial observations: {list(result.observation.observations.keys())}")
|
| 219 |
+
|
| 220 |
+
# Get action dimension
|
| 221 |
+
action_dim = get_action_dim(env)
|
| 222 |
+
print(f"Action dimension: {action_dim}")
|
| 223 |
+
|
| 224 |
+
# Get first frame to determine window size
|
| 225 |
+
if result.observation.pixels is None:
|
| 226 |
+
print("Error: Server did not return rendered pixels.")
|
| 227 |
+
print("Make sure the server supports render=True")
|
| 228 |
+
print("\nTry running in interactive mode (default) instead.")
|
| 229 |
+
sys.exit(1)
|
| 230 |
+
|
| 231 |
+
# Decode base64 PNG to pygame surface
|
| 232 |
+
png_data = base64.b64decode(result.observation.pixels)
|
| 233 |
+
frame = pygame.image.load(io.BytesIO(png_data))
|
| 234 |
+
frame_size = frame.get_size()
|
| 235 |
+
|
| 236 |
+
# Initialize pygame
|
| 237 |
+
pygame.init()
|
| 238 |
+
screen = pygame.display.set_mode(frame_size)
|
| 239 |
+
pygame.display.set_caption(
|
| 240 |
+
"Quadruped (OpenEnv) - SPACE for random force, R to Reset, ESC to Quit"
|
| 241 |
+
)
|
| 242 |
+
clock = pygame.time.Clock()
|
| 243 |
+
|
| 244 |
+
print("Controls:")
|
| 245 |
+
print(" SPACE: Apply random force to joints")
|
| 246 |
+
print(" R: Reset environment")
|
| 247 |
+
print(" ESC or Q: Quit")
|
| 248 |
+
|
| 249 |
+
running = True
|
| 250 |
+
total_reward = 0.0
|
| 251 |
+
step_count = 0
|
| 252 |
+
|
| 253 |
+
while running:
|
| 254 |
+
# Handle events
|
| 255 |
+
for event in pygame.event.get():
|
| 256 |
+
if event.type == pygame.QUIT:
|
| 257 |
+
running = False
|
| 258 |
+
elif event.type == pygame.KEYDOWN:
|
| 259 |
+
if event.key in (pygame.K_ESCAPE, pygame.K_q):
|
| 260 |
+
running = False
|
| 261 |
+
elif event.key == pygame.K_r:
|
| 262 |
+
result = env.reset(
|
| 263 |
+
domain_name="quadruped", task_name="walk", render=True
|
| 264 |
+
)
|
| 265 |
+
total_reward = 0.0
|
| 266 |
+
step_count = 0
|
| 267 |
+
print("Environment reset")
|
| 268 |
+
|
| 269 |
+
# Check for held keys
|
| 270 |
+
keys = pygame.key.get_pressed()
|
| 271 |
+
apply_random_force = keys[pygame.K_SPACE]
|
| 272 |
+
|
| 273 |
+
# Generate action based on input
|
| 274 |
+
if apply_random_force:
|
| 275 |
+
action = generate_random_action(action_dim, magnitude=2.0)
|
| 276 |
+
else:
|
| 277 |
+
action = generate_zero_action(action_dim)
|
| 278 |
+
|
| 279 |
+
# Step the environment using OpenEnv pattern
|
| 280 |
+
result = env.step(action, render=True)
|
| 281 |
+
|
| 282 |
+
# Track reward from result
|
| 283 |
+
total_reward += result.reward or 0.0
|
| 284 |
+
step_count += 1
|
| 285 |
+
|
| 286 |
+
# Check if episode is done
|
| 287 |
+
if result.done:
|
| 288 |
+
print(
|
| 289 |
+
f"Episode finished! Steps: {step_count}, "
|
| 290 |
+
f"Total reward: {total_reward:.2f}"
|
| 291 |
+
)
|
| 292 |
+
result = env.reset(domain_name="quadruped", task_name="walk", render=True)
|
| 293 |
+
total_reward = 0.0
|
| 294 |
+
step_count = 0
|
| 295 |
+
|
| 296 |
+
# Render the frame from observation pixels
|
| 297 |
+
if result.observation.pixels:
|
| 298 |
+
png_data = base64.b64decode(result.observation.pixels)
|
| 299 |
+
frame = pygame.image.load(io.BytesIO(png_data))
|
| 300 |
+
screen.blit(frame, (0, 0))
|
| 301 |
+
pygame.display.flip()
|
| 302 |
+
|
| 303 |
+
# Print progress periodically
|
| 304 |
+
if step_count % 200 == 0 and step_count > 0:
|
| 305 |
+
print(f"Step {step_count}: Total reward: {total_reward:.2f}")
|
| 306 |
+
|
| 307 |
+
# Cap at 30 FPS
|
| 308 |
+
clock.tick(30)
|
| 309 |
+
|
| 310 |
+
pygame.quit()
|
| 311 |
+
print(f"Session ended. Final reward: {total_reward:.2f}")
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
def main():
|
| 315 |
+
parser = argparse.ArgumentParser(
|
| 316 |
+
description="Interactive quadruped control via OpenEnv"
|
| 317 |
+
)
|
| 318 |
+
parser.add_argument(
|
| 319 |
+
"--visual",
|
| 320 |
+
action="store_true",
|
| 321 |
+
help="Enable pygame visualization with rendered frames",
|
| 322 |
+
)
|
| 323 |
+
parser.add_argument(
|
| 324 |
+
"--headless",
|
| 325 |
+
action="store_true",
|
| 326 |
+
help="Run in headless mode (no pygame, automated control)",
|
| 327 |
+
)
|
| 328 |
+
parser.add_argument(
|
| 329 |
+
"--max-steps",
|
| 330 |
+
type=int,
|
| 331 |
+
default=1000,
|
| 332 |
+
help="Maximum steps for headless mode (default: 1000)",
|
| 333 |
+
)
|
| 334 |
+
parser.add_argument(
|
| 335 |
+
"--task",
|
| 336 |
+
type=str,
|
| 337 |
+
default="walk",
|
| 338 |
+
choices=["walk", "run", "escape", "fetch"],
|
| 339 |
+
help="Quadruped task (default: walk)",
|
| 340 |
+
)
|
| 341 |
+
args = parser.parse_args()
|
| 342 |
+
|
| 343 |
+
server_url = "http://localhost:8000"
|
| 344 |
+
print(f"Connecting to {server_url}...")
|
| 345 |
+
|
| 346 |
+
try:
|
| 347 |
+
with DMControlEnv(base_url=server_url) as env:
|
| 348 |
+
print("Connected!")
|
| 349 |
+
|
| 350 |
+
# Get environment state
|
| 351 |
+
state = env.state()
|
| 352 |
+
print(f"Domain: {state.domain_name}, Task: {state.task_name}")
|
| 353 |
+
print(f"Action spec: {state.action_spec}")
|
| 354 |
+
|
| 355 |
+
if args.headless:
|
| 356 |
+
run_headless(env, max_steps=args.max_steps)
|
| 357 |
+
elif args.visual:
|
| 358 |
+
run_visual(env)
|
| 359 |
+
else:
|
| 360 |
+
run_interactive(env)
|
| 361 |
+
|
| 362 |
+
except ConnectionError as e:
|
| 363 |
+
print(f"Failed to connect: {e}")
|
| 364 |
+
print("\nMake sure the server is running:")
|
| 365 |
+
print(" cd OpenEnv")
|
| 366 |
+
print(
|
| 367 |
+
" PYTHONPATH=src:envs uvicorn envs.dm_control_env.server.app:app --port 8000"
|
| 368 |
+
)
|
| 369 |
+
sys.exit(1)
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
if __name__ == "__main__":
|
| 373 |
+
main()
|
envs/dm_control_env/models.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Data models for the dm_control OpenEnv Environment.
|
| 9 |
+
|
| 10 |
+
This environment wraps dm_control.suite, providing access to all MuJoCo-based
|
| 11 |
+
continuous control tasks (cartpole, walker, humanoid, cheetah, etc.).
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from typing import Any, Dict, List, Optional
|
| 15 |
+
|
| 16 |
+
from pydantic import Field
|
| 17 |
+
|
| 18 |
+
try:
|
| 19 |
+
from openenv.core.env_server.types import Action, Observation, State
|
| 20 |
+
except ImportError:
|
| 21 |
+
from openenv.core.env_server.types import Action, Observation, State
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class DMControlAction(Action):
|
| 25 |
+
"""
|
| 26 |
+
Action for dm_control environments.
|
| 27 |
+
|
| 28 |
+
All dm_control.suite environments use continuous actions represented as
|
| 29 |
+
a list of float values. The size and bounds depend on the specific
|
| 30 |
+
domain/task combination.
|
| 31 |
+
|
| 32 |
+
Example (cartpole - 1D action):
|
| 33 |
+
>>> action = DMControlAction(values=[0.5]) # Push cart right
|
| 34 |
+
|
| 35 |
+
Example (walker - 6D action):
|
| 36 |
+
>>> action = DMControlAction(values=[0.1, -0.2, 0.3, 0.0, -0.1, 0.2])
|
| 37 |
+
|
| 38 |
+
Attributes:
|
| 39 |
+
values: List of continuous action values. Shape and bounds depend on
|
| 40 |
+
the loaded environment's action_spec.
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
values: List[float] = Field(
|
| 44 |
+
default_factory=list,
|
| 45 |
+
description="Continuous action values matching the environment's action_spec",
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class DMControlObservation(Observation):
|
| 50 |
+
"""
|
| 51 |
+
Observation from dm_control environments.
|
| 52 |
+
|
| 53 |
+
dm_control environments return observations as a dictionary of named arrays.
|
| 54 |
+
Common observation keys include 'position', 'velocity', 'orientations', etc.
|
| 55 |
+
The exact keys depend on the domain/task combination.
|
| 56 |
+
|
| 57 |
+
Example observation keys by domain:
|
| 58 |
+
- cartpole: 'position' (cos/sin of angle), 'velocity'
|
| 59 |
+
- walker: 'orientations', 'height', 'velocity'
|
| 60 |
+
- humanoid: 'joint_angles', 'head_height', 'extremities', 'torso_vertical', 'com_velocity'
|
| 61 |
+
|
| 62 |
+
Attributes:
|
| 63 |
+
observations: Dictionary mapping observation names to their values.
|
| 64 |
+
Each value is a flattened list of floats.
|
| 65 |
+
pixels: Optional base64-encoded PNG image of the rendered scene.
|
| 66 |
+
Only included when render=True is passed to reset/step.
|
| 67 |
+
"""
|
| 68 |
+
|
| 69 |
+
observations: Dict[str, List[float]] = Field(
|
| 70 |
+
default_factory=dict,
|
| 71 |
+
description="Named observation arrays from the environment",
|
| 72 |
+
)
|
| 73 |
+
pixels: Optional[str] = Field(
|
| 74 |
+
default=None,
|
| 75 |
+
description="Base64-encoded PNG image (when render=True)",
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class DMControlState(State):
|
| 80 |
+
"""
|
| 81 |
+
Extended state for dm_control environments.
|
| 82 |
+
|
| 83 |
+
Provides metadata about the currently loaded environment including
|
| 84 |
+
the domain/task names and action/observation specifications.
|
| 85 |
+
|
| 86 |
+
Attributes:
|
| 87 |
+
episode_id: Unique identifier for the current episode.
|
| 88 |
+
step_count: Number of steps taken in the current episode.
|
| 89 |
+
domain_name: The dm_control domain (e.g., 'cartpole', 'walker').
|
| 90 |
+
task_name: The specific task (e.g., 'balance', 'walk').
|
| 91 |
+
action_spec: Specification of the action space including shape and bounds.
|
| 92 |
+
observation_spec: Specification of the observation space.
|
| 93 |
+
physics_timestep: The physics simulation timestep in seconds.
|
| 94 |
+
control_timestep: The control timestep (time between actions) in seconds.
|
| 95 |
+
"""
|
| 96 |
+
|
| 97 |
+
domain_name: str = Field(
|
| 98 |
+
default="cartpole",
|
| 99 |
+
description="The dm_control domain name",
|
| 100 |
+
)
|
| 101 |
+
task_name: str = Field(
|
| 102 |
+
default="balance",
|
| 103 |
+
description="The task name within the domain",
|
| 104 |
+
)
|
| 105 |
+
action_spec: Dict[str, Any] = Field(
|
| 106 |
+
default_factory=dict,
|
| 107 |
+
description="Specification of the action space (shape, dtype, bounds)",
|
| 108 |
+
)
|
| 109 |
+
observation_spec: Dict[str, Any] = Field(
|
| 110 |
+
default_factory=dict,
|
| 111 |
+
description="Specification of the observation space",
|
| 112 |
+
)
|
| 113 |
+
physics_timestep: float = Field(
|
| 114 |
+
default=0.002,
|
| 115 |
+
description="Physics simulation timestep in seconds",
|
| 116 |
+
)
|
| 117 |
+
control_timestep: float = Field(
|
| 118 |
+
default=0.02,
|
| 119 |
+
description="Control timestep (time between actions) in seconds",
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
# Available dm_control.suite environments
|
| 124 |
+
# Format: (domain_name, task_name)
|
| 125 |
+
AVAILABLE_ENVIRONMENTS = [
|
| 126 |
+
# Cartpole
|
| 127 |
+
("cartpole", "balance"),
|
| 128 |
+
("cartpole", "balance_sparse"),
|
| 129 |
+
("cartpole", "swingup"),
|
| 130 |
+
("cartpole", "swingup_sparse"),
|
| 131 |
+
# Pendulum
|
| 132 |
+
("pendulum", "swingup"),
|
| 133 |
+
# Point mass
|
| 134 |
+
("point_mass", "easy"),
|
| 135 |
+
("point_mass", "hard"),
|
| 136 |
+
# Reacher
|
| 137 |
+
("reacher", "easy"),
|
| 138 |
+
("reacher", "hard"),
|
| 139 |
+
# Ball in cup
|
| 140 |
+
("ball_in_cup", "catch"),
|
| 141 |
+
# Finger
|
| 142 |
+
("finger", "spin"),
|
| 143 |
+
("finger", "turn_easy"),
|
| 144 |
+
("finger", "turn_hard"),
|
| 145 |
+
# Fish
|
| 146 |
+
("fish", "upright"),
|
| 147 |
+
("fish", "swim"),
|
| 148 |
+
# Cheetah
|
| 149 |
+
("cheetah", "run"),
|
| 150 |
+
# Walker
|
| 151 |
+
("walker", "stand"),
|
| 152 |
+
("walker", "walk"),
|
| 153 |
+
("walker", "run"),
|
| 154 |
+
# Hopper
|
| 155 |
+
("hopper", "stand"),
|
| 156 |
+
("hopper", "hop"),
|
| 157 |
+
# Swimmer
|
| 158 |
+
("swimmer", "swimmer6"),
|
| 159 |
+
("swimmer", "swimmer15"),
|
| 160 |
+
# Humanoid
|
| 161 |
+
("humanoid", "stand"),
|
| 162 |
+
("humanoid", "walk"),
|
| 163 |
+
("humanoid", "run"),
|
| 164 |
+
# Manipulator
|
| 165 |
+
("manipulator", "bring_ball"),
|
| 166 |
+
("manipulator", "bring_peg"),
|
| 167 |
+
("manipulator", "insert_ball"),
|
| 168 |
+
("manipulator", "insert_peg"),
|
| 169 |
+
# Acrobot
|
| 170 |
+
("acrobot", "swingup"),
|
| 171 |
+
("acrobot", "swingup_sparse"),
|
| 172 |
+
# Stacker
|
| 173 |
+
("stacker", "stack_2"),
|
| 174 |
+
("stacker", "stack_4"),
|
| 175 |
+
# Dog
|
| 176 |
+
("dog", "stand"),
|
| 177 |
+
("dog", "walk"),
|
| 178 |
+
("dog", "trot"),
|
| 179 |
+
("dog", "run"),
|
| 180 |
+
("dog", "fetch"),
|
| 181 |
+
# Quadruped
|
| 182 |
+
("quadruped", "walk"),
|
| 183 |
+
("quadruped", "run"),
|
| 184 |
+
("quadruped", "escape"),
|
| 185 |
+
("quadruped", "fetch"),
|
| 186 |
+
]
|
envs/dm_control_env/openenv.yaml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
spec_version: 1
|
| 2 |
+
name: dm_control_env
|
| 3 |
+
type: space
|
| 4 |
+
runtime: fastapi
|
| 5 |
+
app: server.app:app
|
| 6 |
+
port: 8000
|
envs/dm_control_env/pyproject.toml
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
[build-system]
|
| 8 |
+
requires = ["setuptools>=45", "wheel"]
|
| 9 |
+
build-backend = "setuptools.build_meta"
|
| 10 |
+
|
| 11 |
+
[project]
|
| 12 |
+
name = "openenv-dmcontrol-env"
|
| 13 |
+
version = "0.1.0"
|
| 14 |
+
description = "dm_control Environment for OpenEnv - wraps MuJoCo-based continuous control tasks (cartpole, walker, humanoid, etc.)"
|
| 15 |
+
requires-python = ">=3.10"
|
| 16 |
+
dependencies = [
|
| 17 |
+
# Core OpenEnv dependencies
|
| 18 |
+
"openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git@v2.1.0",
|
| 19 |
+
"fastapi>=0.115.0",
|
| 20 |
+
"pydantic>=2.0.0",
|
| 21 |
+
"uvicorn>=0.24.0",
|
| 22 |
+
"requests>=2.31.0",
|
| 23 |
+
# dm_control dependencies
|
| 24 |
+
"mujoco>=3.0.0",
|
| 25 |
+
"dm_control>=1.0.0",
|
| 26 |
+
"numpy>=1.20.0",
|
| 27 |
+
# Optional: for pixel observations
|
| 28 |
+
"pillow>=9.0.0",
|
| 29 |
+
]
|
| 30 |
+
|
| 31 |
+
[project.optional-dependencies]
|
| 32 |
+
dev = [
|
| 33 |
+
"pytest>=8.0.0",
|
| 34 |
+
"pytest-cov>=4.0.0",
|
| 35 |
+
]
|
| 36 |
+
interactive = [
|
| 37 |
+
# For interactive examples with keyboard control
|
| 38 |
+
"pygame>=2.0.0",
|
| 39 |
+
]
|
| 40 |
+
|
| 41 |
+
[project.scripts]
|
| 42 |
+
# Server entry point - enables running via: uv run --project . server
|
| 43 |
+
server = "dm_control_env.server.app:main"
|
| 44 |
+
|
| 45 |
+
[tool.setuptools]
|
| 46 |
+
include-package-data = true
|
| 47 |
+
packages = ["dm_control_env", "dm_control_env.server"]
|
| 48 |
+
package-dir = { "dm_control_env" = ".", "dm_control_env.server" = "server" }
|
envs/dm_control_env/server/Dockerfile
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# Multi-stage build for dm_control environment
|
| 8 |
+
# Uses pip for package installation
|
| 9 |
+
|
| 10 |
+
FROM python:3.11-slim AS builder
|
| 11 |
+
|
| 12 |
+
WORKDIR /app
|
| 13 |
+
|
| 14 |
+
# Install build dependencies including OpenGL for MuJoCo
|
| 15 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 16 |
+
build-essential \
|
| 17 |
+
git \
|
| 18 |
+
libgl1 \
|
| 19 |
+
libglx-mesa0 \
|
| 20 |
+
libglew-dev \
|
| 21 |
+
libosmesa6-dev \
|
| 22 |
+
libgl1-mesa-dev \
|
| 23 |
+
libglfw3 \
|
| 24 |
+
patchelf \
|
| 25 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 26 |
+
|
| 27 |
+
# Copy environment code
|
| 28 |
+
COPY . /app/env
|
| 29 |
+
|
| 30 |
+
WORKDIR /app/env
|
| 31 |
+
|
| 32 |
+
# Install dependencies using pip
|
| 33 |
+
RUN pip install --upgrade pip && \
|
| 34 |
+
pip install --no-cache-dir -e .
|
| 35 |
+
|
| 36 |
+
# Final runtime stage
|
| 37 |
+
FROM python:3.11-slim
|
| 38 |
+
|
| 39 |
+
WORKDIR /app
|
| 40 |
+
|
| 41 |
+
# Install runtime dependencies (OpenGL for MuJoCo rendering, curl for healthcheck)
|
| 42 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 43 |
+
curl \
|
| 44 |
+
libgl1 \
|
| 45 |
+
libglx-mesa0 \
|
| 46 |
+
libglew-dev \
|
| 47 |
+
libosmesa6-dev \
|
| 48 |
+
libglfw3 \
|
| 49 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 50 |
+
|
| 51 |
+
# Copy installed packages from builder
|
| 52 |
+
COPY --from=builder /usr/local/lib/python3.11/site-packages /usr/local/lib/python3.11/site-packages
|
| 53 |
+
COPY --from=builder /usr/local/bin /usr/local/bin
|
| 54 |
+
|
| 55 |
+
# Copy the environment code
|
| 56 |
+
COPY . /app/env
|
| 57 |
+
|
| 58 |
+
# Set PYTHONPATH so imports work correctly
|
| 59 |
+
ENV PYTHONPATH="/app/env"
|
| 60 |
+
|
| 61 |
+
# Set MuJoCo to use OSMesa for headless rendering
|
| 62 |
+
ENV MUJOCO_GL="osmesa"
|
| 63 |
+
|
| 64 |
+
# Expose port
|
| 65 |
+
EXPOSE 8000
|
| 66 |
+
|
| 67 |
+
# Health check
|
| 68 |
+
HEALTHCHECK --interval=30s --timeout=3s --start-period=10s --retries=3 \
|
| 69 |
+
CMD curl -f http://localhost:8000/health || exit 1
|
| 70 |
+
|
| 71 |
+
# Run the FastAPI server
|
| 72 |
+
# Use exec to replace the shell with uvicorn so it receives SIGINT/SIGTERM directly
|
| 73 |
+
CMD ["sh", "-c", "cd /app/env && exec uvicorn server.app:app --host 0.0.0.0 --port 8000"]
|
envs/dm_control_env/server/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""dm_control OpenEnv server module."""
|
| 2 |
+
|
| 3 |
+
from .dm_control_environment import DMControlEnvironment
|
| 4 |
+
|
| 5 |
+
__all__ = ["DMControlEnvironment"]
|
envs/dm_control_env/server/app.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
FastAPI application for the dm_control Environment.
|
| 9 |
+
|
| 10 |
+
This module creates an HTTP server that exposes dm_control.suite environments
|
| 11 |
+
over HTTP and WebSocket endpoints, compatible with EnvClient.
|
| 12 |
+
|
| 13 |
+
Usage:
|
| 14 |
+
# Development (with auto-reload):
|
| 15 |
+
uvicorn server.app:app --reload --host 0.0.0.0 --port 8000
|
| 16 |
+
|
| 17 |
+
# Production:
|
| 18 |
+
uvicorn server.app:app --host 0.0.0.0 --port 8000
|
| 19 |
+
|
| 20 |
+
# Or run directly:
|
| 21 |
+
uv run --project . server
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
try:
|
| 25 |
+
from openenv.core.env_server.http_server import create_app
|
| 26 |
+
|
| 27 |
+
from ..models import DMControlAction, DMControlObservation
|
| 28 |
+
from .dm_control_environment import DMControlEnvironment
|
| 29 |
+
except ImportError:
|
| 30 |
+
from openenv.core.env_server.http_server import create_app
|
| 31 |
+
|
| 32 |
+
try:
|
| 33 |
+
import sys
|
| 34 |
+
from pathlib import Path
|
| 35 |
+
|
| 36 |
+
_parent = str(Path(__file__).parent.parent)
|
| 37 |
+
if _parent not in sys.path:
|
| 38 |
+
sys.path.insert(0, _parent)
|
| 39 |
+
from models import DMControlAction, DMControlObservation
|
| 40 |
+
from server.dm_control_environment import DMControlEnvironment
|
| 41 |
+
except ImportError:
|
| 42 |
+
try:
|
| 43 |
+
from dm_control_env.models import DMControlAction, DMControlObservation
|
| 44 |
+
from dm_control_env.server.dm_control_environment import (
|
| 45 |
+
DMControlEnvironment,
|
| 46 |
+
)
|
| 47 |
+
except ImportError:
|
| 48 |
+
from envs.dm_control_env.models import DMControlAction, DMControlObservation
|
| 49 |
+
from envs.dm_control_env.server.dm_control_environment import (
|
| 50 |
+
DMControlEnvironment,
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
# Create the app with web interface
|
| 54 |
+
# Pass the class (factory) for concurrent session support
|
| 55 |
+
app = create_app(
|
| 56 |
+
DMControlEnvironment,
|
| 57 |
+
DMControlAction,
|
| 58 |
+
DMControlObservation,
|
| 59 |
+
env_name="dm_control_env",
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def main():
|
| 64 |
+
"""
|
| 65 |
+
Entry point for direct execution via uv run or python -m.
|
| 66 |
+
|
| 67 |
+
This function enables running the server without Docker:
|
| 68 |
+
uv run --project . server
|
| 69 |
+
python -m envs.dm_control_env.server.app
|
| 70 |
+
openenv serve dm_control_env
|
| 71 |
+
"""
|
| 72 |
+
import uvicorn
|
| 73 |
+
|
| 74 |
+
uvicorn.run(app, host="0.0.0.0", port=8000)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
if __name__ == "__main__":
|
| 78 |
+
main()
|
envs/dm_control_env/server/dm_control_environment.py
ADDED
|
@@ -0,0 +1,428 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
dm_control Environment Implementation.
|
| 9 |
+
|
| 10 |
+
Wraps dm_control.suite environments (cartpole, walker, humanoid, etc.)
|
| 11 |
+
with the OpenEnv interface for standardized reinforcement learning.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import base64
|
| 15 |
+
import io
|
| 16 |
+
import os
|
| 17 |
+
import sys
|
| 18 |
+
from typing import Any, Dict, Optional
|
| 19 |
+
from uuid import uuid4
|
| 20 |
+
|
| 21 |
+
# Configure MuJoCo rendering backend before importing dm_control
|
| 22 |
+
# On macOS, we don't set MUJOCO_GL - use default (glfw) which works
|
| 23 |
+
# when running synchronously in the main thread (see reset_async/step_async)
|
| 24 |
+
# On Linux, use egl for headless rendering
|
| 25 |
+
if "MUJOCO_GL" not in os.environ and sys.platform != "darwin":
|
| 26 |
+
os.environ.setdefault("MUJOCO_GL", "egl")
|
| 27 |
+
|
| 28 |
+
import numpy as np
|
| 29 |
+
|
| 30 |
+
try:
|
| 31 |
+
from openenv.core.env_server.interfaces import Environment
|
| 32 |
+
|
| 33 |
+
from ..models import DMControlAction, DMControlObservation, DMControlState
|
| 34 |
+
except ImportError:
|
| 35 |
+
from openenv.core.env_server.interfaces import Environment
|
| 36 |
+
|
| 37 |
+
try:
|
| 38 |
+
import sys
|
| 39 |
+
from pathlib import Path
|
| 40 |
+
|
| 41 |
+
_parent = str(Path(__file__).parent.parent)
|
| 42 |
+
if _parent not in sys.path:
|
| 43 |
+
sys.path.insert(0, _parent)
|
| 44 |
+
from models import DMControlAction, DMControlObservation, DMControlState
|
| 45 |
+
except ImportError:
|
| 46 |
+
try:
|
| 47 |
+
from dm_control_env.models import (
|
| 48 |
+
DMControlAction,
|
| 49 |
+
DMControlObservation,
|
| 50 |
+
DMControlState,
|
| 51 |
+
)
|
| 52 |
+
except ImportError:
|
| 53 |
+
from envs.dm_control_env.models import (
|
| 54 |
+
DMControlAction,
|
| 55 |
+
DMControlObservation,
|
| 56 |
+
DMControlState,
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class DMControlEnvironment(Environment):
|
| 61 |
+
"""
|
| 62 |
+
Wraps dm_control.suite environments with the OpenEnv interface.
|
| 63 |
+
|
| 64 |
+
This environment supports all dm_control.suite domains and tasks including
|
| 65 |
+
cartpole, walker, humanoid, cheetah, and more.
|
| 66 |
+
|
| 67 |
+
Features:
|
| 68 |
+
- Dynamic environment switching via reset(domain_name="...", task_name="...")
|
| 69 |
+
- Support for all continuous control tasks
|
| 70 |
+
- Optional visual observations (base64-encoded images)
|
| 71 |
+
- Configurable via constructor or environment variables
|
| 72 |
+
|
| 73 |
+
Example:
|
| 74 |
+
>>> env = DMControlEnvironment()
|
| 75 |
+
>>> obs = env.reset() # Default: cartpole/balance
|
| 76 |
+
>>> print(obs.observations)
|
| 77 |
+
>>>
|
| 78 |
+
>>> # Take an action
|
| 79 |
+
>>> obs = env.step(DMControlAction(values=[0.5])) # Push cart right
|
| 80 |
+
>>> print(obs.reward)
|
| 81 |
+
|
| 82 |
+
Example with different environment:
|
| 83 |
+
>>> env = DMControlEnvironment(domain_name="walker", task_name="walk")
|
| 84 |
+
>>> obs = env.reset()
|
| 85 |
+
>>>
|
| 86 |
+
>>> # Or switch environment on reset
|
| 87 |
+
>>> obs = env.reset(domain_name="cheetah", task_name="run")
|
| 88 |
+
"""
|
| 89 |
+
|
| 90 |
+
# dm_control environments are isolated and thread-safe
|
| 91 |
+
SUPPORTS_CONCURRENT_SESSIONS = True
|
| 92 |
+
|
| 93 |
+
def __init__(
|
| 94 |
+
self,
|
| 95 |
+
domain_name: Optional[str] = None,
|
| 96 |
+
task_name: Optional[str] = None,
|
| 97 |
+
render_height: Optional[int] = None,
|
| 98 |
+
render_width: Optional[int] = None,
|
| 99 |
+
):
|
| 100 |
+
"""
|
| 101 |
+
Initialize the dm_control environment.
|
| 102 |
+
|
| 103 |
+
Args:
|
| 104 |
+
domain_name: The dm_control domain to load.
|
| 105 |
+
Env var: DMCONTROL_DOMAIN (default: cartpole)
|
| 106 |
+
task_name: The task within the domain.
|
| 107 |
+
Env var: DMCONTROL_TASK (default: balance)
|
| 108 |
+
render_height: Height of rendered images (when render=True).
|
| 109 |
+
Env var: DMCONTROL_RENDER_HEIGHT (default: 480)
|
| 110 |
+
render_width: Width of rendered images (when render=True).
|
| 111 |
+
Env var: DMCONTROL_RENDER_WIDTH (default: 640)
|
| 112 |
+
"""
|
| 113 |
+
self._env = None
|
| 114 |
+
|
| 115 |
+
self._domain_name = domain_name or os.environ.get(
|
| 116 |
+
"DMCONTROL_DOMAIN", "cartpole"
|
| 117 |
+
)
|
| 118 |
+
self._task_name = task_name or os.environ.get("DMCONTROL_TASK", "balance")
|
| 119 |
+
self._render_height = (
|
| 120 |
+
render_height
|
| 121 |
+
if render_height is not None
|
| 122 |
+
else int(os.environ.get("DMCONTROL_RENDER_HEIGHT", "480"))
|
| 123 |
+
)
|
| 124 |
+
self._render_width = (
|
| 125 |
+
render_width
|
| 126 |
+
if render_width is not None
|
| 127 |
+
else int(os.environ.get("DMCONTROL_RENDER_WIDTH", "640"))
|
| 128 |
+
)
|
| 129 |
+
self._include_pixels = False
|
| 130 |
+
|
| 131 |
+
self._state = DMControlState(
|
| 132 |
+
episode_id=str(uuid4()),
|
| 133 |
+
step_count=0,
|
| 134 |
+
domain_name=self._domain_name,
|
| 135 |
+
task_name=self._task_name,
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
def _load_environment(self, domain_name: str, task_name: str) -> None:
|
| 139 |
+
"""Load or switch to a dm_control environment."""
|
| 140 |
+
if self._env is not None:
|
| 141 |
+
try:
|
| 142 |
+
self._env.close()
|
| 143 |
+
except Exception:
|
| 144 |
+
pass
|
| 145 |
+
|
| 146 |
+
try:
|
| 147 |
+
from dm_control import suite
|
| 148 |
+
except ImportError as e:
|
| 149 |
+
raise ImportError(
|
| 150 |
+
"dm_control is required. Install with: pip install dm_control"
|
| 151 |
+
) from e
|
| 152 |
+
except Exception as e:
|
| 153 |
+
# MuJoCo/OpenGL initialization can fail on macOS
|
| 154 |
+
error_msg = str(e)
|
| 155 |
+
if sys.platform == "darwin":
|
| 156 |
+
raise RuntimeError(
|
| 157 |
+
f"Failed to import dm_control (MuJoCo error): {error_msg}\n\n"
|
| 158 |
+
"On macOS, try one of these solutions:\n"
|
| 159 |
+
"1. Install osmesa: brew install mesa\n"
|
| 160 |
+
"2. Run with MUJOCO_GL=glfw (requires display)\n"
|
| 161 |
+
"3. Run with MUJOCO_GL=egl (if EGL is available)"
|
| 162 |
+
) from e
|
| 163 |
+
raise
|
| 164 |
+
|
| 165 |
+
try:
|
| 166 |
+
self._env = suite.load(domain_name=domain_name, task_name=task_name)
|
| 167 |
+
except Exception as e:
|
| 168 |
+
error_msg = str(e).lower()
|
| 169 |
+
# Check for MuJoCo/OpenGL errors
|
| 170 |
+
if "gl" in error_msg or "render" in error_msg or "display" in error_msg:
|
| 171 |
+
if sys.platform == "darwin":
|
| 172 |
+
raise RuntimeError(
|
| 173 |
+
f"MuJoCo initialization failed: {e}\n\n"
|
| 174 |
+
"On macOS, try one of these solutions:\n"
|
| 175 |
+
"1. Install osmesa: brew install mesa\n"
|
| 176 |
+
"2. Run with MUJOCO_GL=glfw (requires display)\n"
|
| 177 |
+
"3. Set PYOPENGL_PLATFORM=osmesa"
|
| 178 |
+
) from e
|
| 179 |
+
# Check if it's an invalid environment error
|
| 180 |
+
try:
|
| 181 |
+
available = [(d, t) for d, t in suite.BENCHMARKING]
|
| 182 |
+
raise ValueError(
|
| 183 |
+
f"Failed to load {domain_name}/{task_name}. "
|
| 184 |
+
f"Available environments: {available[:10]}... "
|
| 185 |
+
f"(use dm_control.suite.BENCHMARKING for full list)"
|
| 186 |
+
) from e
|
| 187 |
+
except Exception:
|
| 188 |
+
raise
|
| 189 |
+
|
| 190 |
+
self._domain_name = domain_name
|
| 191 |
+
self._task_name = task_name
|
| 192 |
+
|
| 193 |
+
self._state.domain_name = domain_name
|
| 194 |
+
self._state.task_name = task_name
|
| 195 |
+
self._state.action_spec = self._get_action_spec_info()
|
| 196 |
+
self._state.observation_spec = self._get_observation_spec_info()
|
| 197 |
+
self._state.physics_timestep = self._env.physics.timestep()
|
| 198 |
+
self._state.control_timestep = self._env.control_timestep()
|
| 199 |
+
|
| 200 |
+
def _get_action_spec_info(self) -> Dict[str, Any]:
|
| 201 |
+
"""Get information about the action space."""
|
| 202 |
+
spec = self._env.action_spec()
|
| 203 |
+
return {
|
| 204 |
+
"shape": list(spec.shape),
|
| 205 |
+
"dtype": str(spec.dtype),
|
| 206 |
+
"minimum": spec.minimum.tolist(),
|
| 207 |
+
"maximum": spec.maximum.tolist(),
|
| 208 |
+
"name": spec.name,
|
| 209 |
+
}
|
| 210 |
+
|
| 211 |
+
def _get_observation_spec_info(self) -> Dict[str, Any]:
|
| 212 |
+
"""Get information about the observation space."""
|
| 213 |
+
specs = self._env.observation_spec()
|
| 214 |
+
obs_info = {}
|
| 215 |
+
for name, spec in specs.items():
|
| 216 |
+
obs_info[name] = {
|
| 217 |
+
"shape": list(spec.shape),
|
| 218 |
+
"dtype": str(spec.dtype),
|
| 219 |
+
}
|
| 220 |
+
return obs_info
|
| 221 |
+
|
| 222 |
+
def _get_observation(
|
| 223 |
+
self,
|
| 224 |
+
time_step,
|
| 225 |
+
include_pixels: bool = False,
|
| 226 |
+
) -> DMControlObservation:
|
| 227 |
+
"""Convert dm_control TimeStep to DMControlObservation."""
|
| 228 |
+
import dm_env
|
| 229 |
+
|
| 230 |
+
observations = {}
|
| 231 |
+
for name, value in time_step.observation.items():
|
| 232 |
+
observations[name] = np.asarray(value).flatten().tolist()
|
| 233 |
+
|
| 234 |
+
pixels = None
|
| 235 |
+
if include_pixels:
|
| 236 |
+
try:
|
| 237 |
+
frame = self._env.physics.render(
|
| 238 |
+
height=self._render_height,
|
| 239 |
+
width=self._render_width,
|
| 240 |
+
camera_id=0,
|
| 241 |
+
)
|
| 242 |
+
from PIL import Image
|
| 243 |
+
|
| 244 |
+
img = Image.fromarray(frame)
|
| 245 |
+
buffer = io.BytesIO()
|
| 246 |
+
img.save(buffer, format="PNG")
|
| 247 |
+
pixels = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
| 248 |
+
except Exception:
|
| 249 |
+
pass
|
| 250 |
+
|
| 251 |
+
done = time_step.step_type == dm_env.StepType.LAST
|
| 252 |
+
reward = float(time_step.reward) if time_step.reward is not None else 0.0
|
| 253 |
+
|
| 254 |
+
return DMControlObservation(
|
| 255 |
+
observations=observations,
|
| 256 |
+
pixels=pixels,
|
| 257 |
+
reward=reward,
|
| 258 |
+
done=done,
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
def reset(
|
| 262 |
+
self,
|
| 263 |
+
domain_name: Optional[str] = None,
|
| 264 |
+
task_name: Optional[str] = None,
|
| 265 |
+
seed: Optional[int] = None,
|
| 266 |
+
render: bool = False,
|
| 267 |
+
**kwargs,
|
| 268 |
+
) -> DMControlObservation:
|
| 269 |
+
"""
|
| 270 |
+
Reset the environment and return initial observation.
|
| 271 |
+
|
| 272 |
+
Args:
|
| 273 |
+
domain_name: Optionally switch to a different domain.
|
| 274 |
+
task_name: Optionally switch to a different task.
|
| 275 |
+
seed: Random seed for reproducibility.
|
| 276 |
+
render: If True, include pixel observations.
|
| 277 |
+
**kwargs: Additional arguments (ignored).
|
| 278 |
+
|
| 279 |
+
Returns:
|
| 280 |
+
DMControlObservation with initial state.
|
| 281 |
+
"""
|
| 282 |
+
self._include_pixels = render
|
| 283 |
+
|
| 284 |
+
target_domain = domain_name or self._domain_name
|
| 285 |
+
target_task = task_name or self._task_name
|
| 286 |
+
|
| 287 |
+
if (
|
| 288 |
+
self._env is None
|
| 289 |
+
or target_domain != self._domain_name
|
| 290 |
+
or target_task != self._task_name
|
| 291 |
+
):
|
| 292 |
+
self._load_environment(target_domain, target_task)
|
| 293 |
+
|
| 294 |
+
if seed is not None:
|
| 295 |
+
np.random.seed(seed)
|
| 296 |
+
|
| 297 |
+
time_step = self._env.reset()
|
| 298 |
+
|
| 299 |
+
self._state = DMControlState(
|
| 300 |
+
episode_id=str(uuid4()),
|
| 301 |
+
step_count=0,
|
| 302 |
+
domain_name=self._domain_name,
|
| 303 |
+
task_name=self._task_name,
|
| 304 |
+
action_spec=self._state.action_spec,
|
| 305 |
+
observation_spec=self._state.observation_spec,
|
| 306 |
+
physics_timestep=self._state.physics_timestep,
|
| 307 |
+
control_timestep=self._state.control_timestep,
|
| 308 |
+
)
|
| 309 |
+
|
| 310 |
+
return self._get_observation(time_step, include_pixels=render)
|
| 311 |
+
|
| 312 |
+
def step(
|
| 313 |
+
self,
|
| 314 |
+
action: DMControlAction,
|
| 315 |
+
render: bool = False,
|
| 316 |
+
**kwargs,
|
| 317 |
+
) -> DMControlObservation:
|
| 318 |
+
"""
|
| 319 |
+
Execute one step in the environment.
|
| 320 |
+
|
| 321 |
+
Args:
|
| 322 |
+
action: DMControlAction with continuous action values.
|
| 323 |
+
render: If True, include pixel observations.
|
| 324 |
+
|
| 325 |
+
Returns:
|
| 326 |
+
DMControlObservation with new state, reward, and done flag.
|
| 327 |
+
"""
|
| 328 |
+
if self._env is None:
|
| 329 |
+
raise RuntimeError("Environment not initialized. Call reset() first.")
|
| 330 |
+
|
| 331 |
+
action_array = np.array(action.values, dtype=np.float64)
|
| 332 |
+
|
| 333 |
+
action_spec = self._env.action_spec()
|
| 334 |
+
expected_shape = action_spec.shape
|
| 335 |
+
if action_array.shape != expected_shape:
|
| 336 |
+
if action_array.size == np.prod(expected_shape):
|
| 337 |
+
action_array = action_array.reshape(expected_shape)
|
| 338 |
+
else:
|
| 339 |
+
raise ValueError(
|
| 340 |
+
f"Action shape {action_array.shape} doesn't match "
|
| 341 |
+
f"expected shape {expected_shape}"
|
| 342 |
+
)
|
| 343 |
+
|
| 344 |
+
action_array = np.clip(action_array, action_spec.minimum, action_spec.maximum)
|
| 345 |
+
|
| 346 |
+
time_step = self._env.step(action_array)
|
| 347 |
+
self._state.step_count += 1
|
| 348 |
+
|
| 349 |
+
return self._get_observation(
|
| 350 |
+
time_step, include_pixels=render or self._include_pixels
|
| 351 |
+
)
|
| 352 |
+
|
| 353 |
+
async def reset_async(
|
| 354 |
+
self,
|
| 355 |
+
domain_name: Optional[str] = None,
|
| 356 |
+
task_name: Optional[str] = None,
|
| 357 |
+
seed: Optional[int] = None,
|
| 358 |
+
render: bool = False,
|
| 359 |
+
**kwargs,
|
| 360 |
+
) -> DMControlObservation:
|
| 361 |
+
"""Async version of reset.
|
| 362 |
+
|
| 363 |
+
On macOS, runs synchronously to avoid MuJoCo threading crashes.
|
| 364 |
+
On other platforms, runs in a thread pool.
|
| 365 |
+
"""
|
| 366 |
+
if sys.platform == "darwin":
|
| 367 |
+
# On macOS, MuJoCo crashes when run in a background thread
|
| 368 |
+
# Run synchronously (blocks event loop but avoids crash)
|
| 369 |
+
return self.reset(
|
| 370 |
+
domain_name=domain_name,
|
| 371 |
+
task_name=task_name,
|
| 372 |
+
seed=seed,
|
| 373 |
+
render=render,
|
| 374 |
+
**kwargs,
|
| 375 |
+
)
|
| 376 |
+
else:
|
| 377 |
+
import asyncio
|
| 378 |
+
|
| 379 |
+
return await asyncio.to_thread(
|
| 380 |
+
self.reset,
|
| 381 |
+
domain_name=domain_name,
|
| 382 |
+
task_name=task_name,
|
| 383 |
+
seed=seed,
|
| 384 |
+
render=render,
|
| 385 |
+
**kwargs,
|
| 386 |
+
)
|
| 387 |
+
|
| 388 |
+
async def step_async(
|
| 389 |
+
self,
|
| 390 |
+
action: DMControlAction,
|
| 391 |
+
render: bool = False,
|
| 392 |
+
**kwargs,
|
| 393 |
+
) -> DMControlObservation:
|
| 394 |
+
"""Async version of step.
|
| 395 |
+
|
| 396 |
+
On macOS, runs synchronously to avoid MuJoCo threading crashes.
|
| 397 |
+
On other platforms, runs in a thread pool.
|
| 398 |
+
"""
|
| 399 |
+
if sys.platform == "darwin":
|
| 400 |
+
# On macOS, MuJoCo crashes when run in a background thread
|
| 401 |
+
# Run synchronously (blocks event loop but avoids crash)
|
| 402 |
+
return self.step(action, render=render, **kwargs)
|
| 403 |
+
else:
|
| 404 |
+
import asyncio
|
| 405 |
+
|
| 406 |
+
return await asyncio.to_thread(self.step, action, render=render, **kwargs)
|
| 407 |
+
|
| 408 |
+
@property
|
| 409 |
+
def state(self) -> DMControlState:
|
| 410 |
+
"""Get the current environment state."""
|
| 411 |
+
return self._state
|
| 412 |
+
|
| 413 |
+
def close(self) -> None:
|
| 414 |
+
"""Close the dm_control environment."""
|
| 415 |
+
env = getattr(self, "_env", None)
|
| 416 |
+
if env is not None:
|
| 417 |
+
try:
|
| 418 |
+
env.close()
|
| 419 |
+
except Exception:
|
| 420 |
+
pass
|
| 421 |
+
self._env = None
|
| 422 |
+
|
| 423 |
+
def __del__(self):
|
| 424 |
+
"""Cleanup on deletion."""
|
| 425 |
+
try:
|
| 426 |
+
self.close()
|
| 427 |
+
except Exception:
|
| 428 |
+
pass
|
examples/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""dm_control examples."""
|
examples/cartpole_control.py
ADDED
|
@@ -0,0 +1,333 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Interactive cartpole control via OpenEnv.
|
| 3 |
+
|
| 4 |
+
This example demonstrates using the dm_control OpenEnv client with
|
| 5 |
+
the cartpole environment. Use arrow keys to control the cart.
|
| 6 |
+
|
| 7 |
+
Controls:
|
| 8 |
+
LEFT/RIGHT arrows: Apply force to move cart
|
| 9 |
+
R: Reset environment
|
| 10 |
+
ESC or Q: Quit
|
| 11 |
+
|
| 12 |
+
Requirements:
|
| 13 |
+
pip install pygame
|
| 14 |
+
|
| 15 |
+
Usage:
|
| 16 |
+
1. Start the server: uvicorn server.app:app --host 0.0.0.0 --port 8000
|
| 17 |
+
2. Run this script: python examples/cartpole_control.py
|
| 18 |
+
|
| 19 |
+
For visual mode (requires working MuJoCo rendering):
|
| 20 |
+
python examples/cartpole_control.py --visual
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
import argparse
|
| 24 |
+
import random
|
| 25 |
+
import sys
|
| 26 |
+
from pathlib import Path
|
| 27 |
+
|
| 28 |
+
# Add parent directory to path for imports
|
| 29 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 30 |
+
|
| 31 |
+
from client import DMControlEnv
|
| 32 |
+
from models import DMControlAction
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def run_headless(env: DMControlEnv, task: str = "balance", max_steps: int = 500):
|
| 36 |
+
"""Run cartpole control in headless mode."""
|
| 37 |
+
print("\n=== Headless Mode (OpenEnv Step/Observation Pattern) ===")
|
| 38 |
+
print("This mode demonstrates the OpenEnv API with the cartpole.\n")
|
| 39 |
+
|
| 40 |
+
# Reset environment using OpenEnv pattern
|
| 41 |
+
result = env.reset(domain_name="cartpole", task_name=task)
|
| 42 |
+
print(f"Initial observations: {list(result.observation.observations.keys())}")
|
| 43 |
+
print(f" position: {result.observation.observations.get('position', [])}")
|
| 44 |
+
print(f" velocity: {result.observation.observations.get('velocity', [])}")
|
| 45 |
+
|
| 46 |
+
total_reward = 0.0
|
| 47 |
+
step_count = 0
|
| 48 |
+
|
| 49 |
+
print("\nRunning with random actions to demonstrate step/observation pattern...\n")
|
| 50 |
+
|
| 51 |
+
while not result.done and step_count < max_steps:
|
| 52 |
+
# Random action in [-1, 1]
|
| 53 |
+
action_value = random.uniform(-1.0, 1.0)
|
| 54 |
+
|
| 55 |
+
# Step the environment using OpenEnv pattern
|
| 56 |
+
action = DMControlAction(values=[action_value])
|
| 57 |
+
result = env.step(action)
|
| 58 |
+
|
| 59 |
+
# Access observation and reward from result
|
| 60 |
+
total_reward += result.reward or 0.0
|
| 61 |
+
step_count += 1
|
| 62 |
+
|
| 63 |
+
# Print progress periodically
|
| 64 |
+
if step_count % 50 == 0:
|
| 65 |
+
pos = result.observation.observations.get("position", [])
|
| 66 |
+
vel = result.observation.observations.get("velocity", [])
|
| 67 |
+
print(
|
| 68 |
+
f"Step {step_count}: reward={result.reward:.3f}, "
|
| 69 |
+
f"total={total_reward:.2f}, done={result.done}"
|
| 70 |
+
)
|
| 71 |
+
print(f" position={pos}, velocity={vel}")
|
| 72 |
+
|
| 73 |
+
print(f"\nEpisode finished: {step_count} steps, total reward: {total_reward:.2f}")
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def run_interactive(env: DMControlEnv, task: str = "balance"):
|
| 77 |
+
"""Run interactive control with keyboard input via pygame."""
|
| 78 |
+
import pygame
|
| 79 |
+
|
| 80 |
+
print("\n=== Interactive Mode (OpenEnv Step/Observation Pattern) ===")
|
| 81 |
+
print("Use LEFT/RIGHT arrows to control cart, R to reset, ESC to quit.\n")
|
| 82 |
+
|
| 83 |
+
# Reset environment using OpenEnv pattern
|
| 84 |
+
result = env.reset(domain_name="cartpole", task_name=task)
|
| 85 |
+
print(f"Initial observations: {list(result.observation.observations.keys())}")
|
| 86 |
+
|
| 87 |
+
# Initialize pygame for keyboard input (minimal window)
|
| 88 |
+
pygame.init()
|
| 89 |
+
screen = pygame.display.set_mode((400, 100))
|
| 90 |
+
pygame.display.set_caption("Cartpole Control - Arrow keys to move, R to reset")
|
| 91 |
+
clock = pygame.time.Clock()
|
| 92 |
+
|
| 93 |
+
# Font for display
|
| 94 |
+
font = pygame.font.Font(None, 24)
|
| 95 |
+
|
| 96 |
+
running = True
|
| 97 |
+
total_reward = 0.0
|
| 98 |
+
step_count = 0
|
| 99 |
+
|
| 100 |
+
print("\nControls:")
|
| 101 |
+
print(" LEFT/RIGHT arrows: Move cart")
|
| 102 |
+
print(" R: Reset environment")
|
| 103 |
+
print(" ESC or Q: Quit\n")
|
| 104 |
+
|
| 105 |
+
while running:
|
| 106 |
+
# Handle events
|
| 107 |
+
for event in pygame.event.get():
|
| 108 |
+
if event.type == pygame.QUIT:
|
| 109 |
+
running = False
|
| 110 |
+
elif event.type == pygame.KEYDOWN:
|
| 111 |
+
if event.key in (pygame.K_ESCAPE, pygame.K_q):
|
| 112 |
+
running = False
|
| 113 |
+
elif event.key == pygame.K_r:
|
| 114 |
+
result = env.reset(domain_name="cartpole", task_name=task)
|
| 115 |
+
total_reward = 0.0
|
| 116 |
+
step_count = 0
|
| 117 |
+
print("Environment reset")
|
| 118 |
+
|
| 119 |
+
# Check for held keys (for continuous control)
|
| 120 |
+
keys = pygame.key.get_pressed()
|
| 121 |
+
if keys[pygame.K_LEFT]:
|
| 122 |
+
action_value = -1.0
|
| 123 |
+
elif keys[pygame.K_RIGHT]:
|
| 124 |
+
action_value = 1.0
|
| 125 |
+
else:
|
| 126 |
+
action_value = 0.0
|
| 127 |
+
|
| 128 |
+
# Step the environment using OpenEnv pattern
|
| 129 |
+
action = DMControlAction(values=[action_value])
|
| 130 |
+
result = env.step(action)
|
| 131 |
+
|
| 132 |
+
# Track reward from result
|
| 133 |
+
total_reward += result.reward or 0.0
|
| 134 |
+
step_count += 1
|
| 135 |
+
|
| 136 |
+
# Check if episode is done
|
| 137 |
+
if result.done:
|
| 138 |
+
print(
|
| 139 |
+
f"Episode finished! Steps: {step_count}, "
|
| 140 |
+
f"Total reward: {total_reward:.2f}"
|
| 141 |
+
)
|
| 142 |
+
# Auto-reset on done
|
| 143 |
+
result = env.reset(domain_name="cartpole", task_name=task)
|
| 144 |
+
total_reward = 0.0
|
| 145 |
+
step_count = 0
|
| 146 |
+
|
| 147 |
+
# Update display
|
| 148 |
+
direction = (
|
| 149 |
+
"<--" if action_value < 0 else ("-->" if action_value > 0 else "---")
|
| 150 |
+
)
|
| 151 |
+
screen.fill((30, 30, 30))
|
| 152 |
+
text = font.render(
|
| 153 |
+
f"Step: {step_count} | Reward: {total_reward:.1f} | {direction}",
|
| 154 |
+
True,
|
| 155 |
+
(255, 255, 255),
|
| 156 |
+
)
|
| 157 |
+
screen.blit(text, (10, 40))
|
| 158 |
+
pygame.display.flip()
|
| 159 |
+
|
| 160 |
+
# Print progress periodically
|
| 161 |
+
if step_count % 200 == 0 and step_count > 0:
|
| 162 |
+
print(f"Step {step_count}: Total reward: {total_reward:.2f}")
|
| 163 |
+
|
| 164 |
+
# Cap at 30 FPS
|
| 165 |
+
clock.tick(30)
|
| 166 |
+
|
| 167 |
+
pygame.quit()
|
| 168 |
+
print(f"Session ended. Final reward: {total_reward:.2f}")
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def run_visual(env: DMControlEnv, task: str = "balance"):
|
| 172 |
+
"""Run with pygame visualization showing rendered frames."""
|
| 173 |
+
import base64
|
| 174 |
+
import io
|
| 175 |
+
|
| 176 |
+
import pygame
|
| 177 |
+
|
| 178 |
+
print("\n=== Visual Mode (OpenEnv Step/Observation Pattern) ===")
|
| 179 |
+
|
| 180 |
+
# Reset environment with rendering enabled
|
| 181 |
+
result = env.reset(domain_name="cartpole", task_name=task, render=True)
|
| 182 |
+
print(f"Initial observations: {list(result.observation.observations.keys())}")
|
| 183 |
+
|
| 184 |
+
# Get first frame to determine window size
|
| 185 |
+
if result.observation.pixels is None:
|
| 186 |
+
print("Error: Server did not return rendered pixels.")
|
| 187 |
+
print("Make sure the server supports render=True")
|
| 188 |
+
print("\nTry running in interactive mode (default) instead.")
|
| 189 |
+
sys.exit(1)
|
| 190 |
+
|
| 191 |
+
# Decode base64 PNG to pygame surface
|
| 192 |
+
png_data = base64.b64decode(result.observation.pixels)
|
| 193 |
+
frame = pygame.image.load(io.BytesIO(png_data))
|
| 194 |
+
frame_size = frame.get_size()
|
| 195 |
+
|
| 196 |
+
# Initialize pygame
|
| 197 |
+
pygame.init()
|
| 198 |
+
screen = pygame.display.set_mode(frame_size)
|
| 199 |
+
pygame.display.set_caption(
|
| 200 |
+
"Cartpole (OpenEnv) - Arrow Keys to Move, R to Reset, ESC to Quit"
|
| 201 |
+
)
|
| 202 |
+
clock = pygame.time.Clock()
|
| 203 |
+
|
| 204 |
+
print("Controls:")
|
| 205 |
+
print(" LEFT/RIGHT arrows: Move cart")
|
| 206 |
+
print(" R: Reset environment")
|
| 207 |
+
print(" ESC or Q: Quit")
|
| 208 |
+
|
| 209 |
+
running = True
|
| 210 |
+
total_reward = 0.0
|
| 211 |
+
step_count = 0
|
| 212 |
+
|
| 213 |
+
while running:
|
| 214 |
+
# Handle events
|
| 215 |
+
for event in pygame.event.get():
|
| 216 |
+
if event.type == pygame.QUIT:
|
| 217 |
+
running = False
|
| 218 |
+
elif event.type == pygame.KEYDOWN:
|
| 219 |
+
if event.key in (pygame.K_ESCAPE, pygame.K_q):
|
| 220 |
+
running = False
|
| 221 |
+
elif event.key == pygame.K_r:
|
| 222 |
+
result = env.reset(
|
| 223 |
+
domain_name="cartpole", task_name=task, render=True
|
| 224 |
+
)
|
| 225 |
+
total_reward = 0.0
|
| 226 |
+
step_count = 0
|
| 227 |
+
print("Environment reset")
|
| 228 |
+
|
| 229 |
+
# Check for held keys (for continuous control)
|
| 230 |
+
keys = pygame.key.get_pressed()
|
| 231 |
+
if keys[pygame.K_LEFT]:
|
| 232 |
+
action_value = -1.0
|
| 233 |
+
elif keys[pygame.K_RIGHT]:
|
| 234 |
+
action_value = 1.0
|
| 235 |
+
else:
|
| 236 |
+
action_value = 0.0
|
| 237 |
+
|
| 238 |
+
# Step the environment using OpenEnv pattern
|
| 239 |
+
action = DMControlAction(values=[action_value])
|
| 240 |
+
result = env.step(action, render=True)
|
| 241 |
+
|
| 242 |
+
# Track reward from result
|
| 243 |
+
total_reward += result.reward or 0.0
|
| 244 |
+
step_count += 1
|
| 245 |
+
|
| 246 |
+
# Check if episode is done
|
| 247 |
+
if result.done:
|
| 248 |
+
print(
|
| 249 |
+
f"Episode finished! Steps: {step_count}, "
|
| 250 |
+
f"Total reward: {total_reward:.2f}"
|
| 251 |
+
)
|
| 252 |
+
result = env.reset(domain_name="cartpole", task_name=task, render=True)
|
| 253 |
+
total_reward = 0.0
|
| 254 |
+
step_count = 0
|
| 255 |
+
|
| 256 |
+
# Render the frame from observation pixels
|
| 257 |
+
if result.observation.pixels:
|
| 258 |
+
png_data = base64.b64decode(result.observation.pixels)
|
| 259 |
+
frame = pygame.image.load(io.BytesIO(png_data))
|
| 260 |
+
screen.blit(frame, (0, 0))
|
| 261 |
+
pygame.display.flip()
|
| 262 |
+
|
| 263 |
+
# Print progress periodically
|
| 264 |
+
if step_count % 200 == 0 and step_count > 0:
|
| 265 |
+
print(f"Step {step_count}: Total reward: {total_reward:.2f}")
|
| 266 |
+
|
| 267 |
+
# Cap at 30 FPS
|
| 268 |
+
clock.tick(30)
|
| 269 |
+
|
| 270 |
+
pygame.quit()
|
| 271 |
+
print(f"Session ended. Final reward: {total_reward:.2f}")
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def main():
|
| 275 |
+
parser = argparse.ArgumentParser(
|
| 276 |
+
description="Interactive cartpole control via OpenEnv"
|
| 277 |
+
)
|
| 278 |
+
parser.add_argument(
|
| 279 |
+
"--visual",
|
| 280 |
+
action="store_true",
|
| 281 |
+
help="Enable pygame visualization with rendered frames",
|
| 282 |
+
)
|
| 283 |
+
parser.add_argument(
|
| 284 |
+
"--headless",
|
| 285 |
+
action="store_true",
|
| 286 |
+
help="Run in headless mode (no pygame, automated control)",
|
| 287 |
+
)
|
| 288 |
+
parser.add_argument(
|
| 289 |
+
"--max-steps",
|
| 290 |
+
type=int,
|
| 291 |
+
default=500,
|
| 292 |
+
help="Maximum steps for headless mode (default: 500)",
|
| 293 |
+
)
|
| 294 |
+
parser.add_argument(
|
| 295 |
+
"--task",
|
| 296 |
+
type=str,
|
| 297 |
+
default="balance",
|
| 298 |
+
choices=["balance", "balance_sparse", "swingup", "swingup_sparse"],
|
| 299 |
+
help="Cartpole task (default: balance)",
|
| 300 |
+
)
|
| 301 |
+
args = parser.parse_args()
|
| 302 |
+
|
| 303 |
+
server_url = "http://localhost:8000"
|
| 304 |
+
print(f"Connecting to {server_url}...")
|
| 305 |
+
|
| 306 |
+
try:
|
| 307 |
+
with DMControlEnv(base_url=server_url) as env:
|
| 308 |
+
print("Connected!")
|
| 309 |
+
|
| 310 |
+
# Get environment state
|
| 311 |
+
state = env.state()
|
| 312 |
+
print(f"Domain: {state.domain_name}, Task: {state.task_name}")
|
| 313 |
+
print(f"Action spec: {state.action_spec}")
|
| 314 |
+
|
| 315 |
+
if args.headless:
|
| 316 |
+
run_headless(env, task=args.task, max_steps=args.max_steps)
|
| 317 |
+
elif args.visual:
|
| 318 |
+
run_visual(env, task=args.task)
|
| 319 |
+
else:
|
| 320 |
+
run_interactive(env, task=args.task)
|
| 321 |
+
|
| 322 |
+
except ConnectionError as e:
|
| 323 |
+
print(f"Failed to connect: {e}")
|
| 324 |
+
print("\nMake sure the server is running:")
|
| 325 |
+
print(" cd OpenEnv")
|
| 326 |
+
print(
|
| 327 |
+
" PYTHONPATH=src:envs uvicorn envs.dm_control_env.server.app:app --port 8000"
|
| 328 |
+
)
|
| 329 |
+
sys.exit(1)
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
if __name__ == "__main__":
|
| 333 |
+
main()
|
examples/hopper_control.py
ADDED
|
@@ -0,0 +1,374 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Interactive hopper control via OpenEnv.
|
| 3 |
+
|
| 4 |
+
This example demonstrates using the dm_control OpenEnv client with
|
| 5 |
+
the hopper environment. Press SPACE to apply random forces to the joints.
|
| 6 |
+
|
| 7 |
+
Controls:
|
| 8 |
+
SPACE: Apply random force to all joints
|
| 9 |
+
R: Reset environment
|
| 10 |
+
ESC or Q: Quit
|
| 11 |
+
|
| 12 |
+
Requirements:
|
| 13 |
+
pip install pygame
|
| 14 |
+
|
| 15 |
+
Usage:
|
| 16 |
+
1. Start the server: uvicorn server.app:app --host 0.0.0.0 --port 8000
|
| 17 |
+
2. Run this script: python examples/hopper_control.py
|
| 18 |
+
|
| 19 |
+
For visual mode (requires working MuJoCo rendering):
|
| 20 |
+
python examples/hopper_control.py --visual
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
import argparse
|
| 24 |
+
import random
|
| 25 |
+
import sys
|
| 26 |
+
from pathlib import Path
|
| 27 |
+
|
| 28 |
+
# Add parent directory to path for imports
|
| 29 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 30 |
+
|
| 31 |
+
from client import DMControlEnv
|
| 32 |
+
from models import DMControlAction
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def get_action_dim(env: DMControlEnv) -> int:
|
| 36 |
+
"""Get the action dimension from the environment state."""
|
| 37 |
+
state = env.state()
|
| 38 |
+
action_spec = state.action_spec
|
| 39 |
+
if action_spec and "shape" in action_spec:
|
| 40 |
+
shape = action_spec["shape"]
|
| 41 |
+
if isinstance(shape, list) and len(shape) > 0:
|
| 42 |
+
return shape[0]
|
| 43 |
+
# Hopper default: 4 actuators (hip, knee, ankle, toe)
|
| 44 |
+
return 4
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def generate_random_action(action_dim: int, magnitude: float = 1.0) -> DMControlAction:
|
| 48 |
+
"""Generate a random action with values in [-magnitude, magnitude]."""
|
| 49 |
+
values = [random.uniform(-magnitude, magnitude) for _ in range(action_dim)]
|
| 50 |
+
return DMControlAction(values=values)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def generate_zero_action(action_dim: int) -> DMControlAction:
|
| 54 |
+
"""Generate a zero action (no force applied)."""
|
| 55 |
+
return DMControlAction(values=[0.0] * action_dim)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def run_headless(env: DMControlEnv, task: str = "hop", max_steps: int = 1000):
|
| 59 |
+
"""Run hopper control in headless mode."""
|
| 60 |
+
print("\n=== Headless Mode (OpenEnv Step/Observation Pattern) ===")
|
| 61 |
+
print("This mode demonstrates the OpenEnv API with the hopper.\n")
|
| 62 |
+
|
| 63 |
+
# Reset environment using OpenEnv pattern
|
| 64 |
+
result = env.reset(domain_name="hopper", task_name=task)
|
| 65 |
+
print(f"Initial observations: {list(result.observation.observations.keys())}")
|
| 66 |
+
|
| 67 |
+
# Get action dimension
|
| 68 |
+
action_dim = get_action_dim(env)
|
| 69 |
+
print(f"Action dimension: {action_dim}")
|
| 70 |
+
|
| 71 |
+
total_reward = 0.0
|
| 72 |
+
step_count = 0
|
| 73 |
+
|
| 74 |
+
print("\nRunning with periodic random forces...")
|
| 75 |
+
print("Every 30 steps, a random force burst is applied.\n")
|
| 76 |
+
|
| 77 |
+
while not result.done and step_count < max_steps:
|
| 78 |
+
# Apply random force every 30 steps, otherwise zero action
|
| 79 |
+
if step_count % 30 < 5:
|
| 80 |
+
# Random force burst for 5 steps
|
| 81 |
+
action = generate_random_action(action_dim, magnitude=0.8)
|
| 82 |
+
else:
|
| 83 |
+
# No force
|
| 84 |
+
action = generate_zero_action(action_dim)
|
| 85 |
+
|
| 86 |
+
# Step the environment using OpenEnv pattern
|
| 87 |
+
result = env.step(action)
|
| 88 |
+
|
| 89 |
+
# Access observation and reward from result
|
| 90 |
+
total_reward += result.reward or 0.0
|
| 91 |
+
step_count += 1
|
| 92 |
+
|
| 93 |
+
# Print progress periodically
|
| 94 |
+
if step_count % 100 == 0:
|
| 95 |
+
# Get some observation values
|
| 96 |
+
position = result.observation.observations.get("position", [])
|
| 97 |
+
velocity = result.observation.observations.get("velocity", [])
|
| 98 |
+
print(
|
| 99 |
+
f"Step {step_count}: reward={result.reward:.3f}, "
|
| 100 |
+
f"total={total_reward:.2f}, done={result.done}"
|
| 101 |
+
)
|
| 102 |
+
if position:
|
| 103 |
+
print(f" position: {position[:3]}")
|
| 104 |
+
if velocity:
|
| 105 |
+
print(f" velocity: {velocity[:3]}")
|
| 106 |
+
|
| 107 |
+
print(f"\nEpisode finished: {step_count} steps, total reward: {total_reward:.2f}")
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def run_interactive(env: DMControlEnv, task: str = "hop"):
|
| 111 |
+
"""Run interactive control with keyboard input via pygame."""
|
| 112 |
+
import pygame
|
| 113 |
+
|
| 114 |
+
print("\n=== Interactive Mode (OpenEnv Step/Observation Pattern) ===")
|
| 115 |
+
print("Press SPACE to apply random force, R to reset, ESC to quit.\n")
|
| 116 |
+
|
| 117 |
+
# Reset environment using OpenEnv pattern
|
| 118 |
+
result = env.reset(domain_name="hopper", task_name=task)
|
| 119 |
+
print(f"Initial observations: {list(result.observation.observations.keys())}")
|
| 120 |
+
|
| 121 |
+
# Get action dimension
|
| 122 |
+
action_dim = get_action_dim(env)
|
| 123 |
+
print(f"Action dimension: {action_dim}")
|
| 124 |
+
|
| 125 |
+
# Initialize pygame for keyboard input (minimal window)
|
| 126 |
+
pygame.init()
|
| 127 |
+
screen = pygame.display.set_mode((400, 100))
|
| 128 |
+
pygame.display.set_caption("Hopper Control - SPACE for random force, R to reset")
|
| 129 |
+
clock = pygame.time.Clock()
|
| 130 |
+
|
| 131 |
+
# Font for display
|
| 132 |
+
font = pygame.font.Font(None, 24)
|
| 133 |
+
|
| 134 |
+
running = True
|
| 135 |
+
total_reward = 0.0
|
| 136 |
+
step_count = 0
|
| 137 |
+
apply_random_force = False
|
| 138 |
+
|
| 139 |
+
print("\nControls:")
|
| 140 |
+
print(" SPACE: Apply random force to joints")
|
| 141 |
+
print(" R: Reset environment")
|
| 142 |
+
print(" ESC or Q: Quit\n")
|
| 143 |
+
|
| 144 |
+
while running:
|
| 145 |
+
# Handle events
|
| 146 |
+
for event in pygame.event.get():
|
| 147 |
+
if event.type == pygame.QUIT:
|
| 148 |
+
running = False
|
| 149 |
+
elif event.type == pygame.KEYDOWN:
|
| 150 |
+
if event.key in (pygame.K_ESCAPE, pygame.K_q):
|
| 151 |
+
running = False
|
| 152 |
+
elif event.key == pygame.K_r:
|
| 153 |
+
result = env.reset(domain_name="hopper", task_name=task)
|
| 154 |
+
total_reward = 0.0
|
| 155 |
+
step_count = 0
|
| 156 |
+
print("Environment reset")
|
| 157 |
+
|
| 158 |
+
# Check for held keys
|
| 159 |
+
keys = pygame.key.get_pressed()
|
| 160 |
+
apply_random_force = keys[pygame.K_SPACE]
|
| 161 |
+
|
| 162 |
+
# Generate action based on input
|
| 163 |
+
if apply_random_force:
|
| 164 |
+
action = generate_random_action(action_dim, magnitude=2.0)
|
| 165 |
+
else:
|
| 166 |
+
action = generate_zero_action(action_dim)
|
| 167 |
+
|
| 168 |
+
# Step the environment using OpenEnv pattern
|
| 169 |
+
result = env.step(action)
|
| 170 |
+
|
| 171 |
+
# Track reward from result
|
| 172 |
+
total_reward += result.reward or 0.0
|
| 173 |
+
step_count += 1
|
| 174 |
+
|
| 175 |
+
# Check if episode is done
|
| 176 |
+
if result.done:
|
| 177 |
+
print(
|
| 178 |
+
f"Episode finished! Steps: {step_count}, "
|
| 179 |
+
f"Total reward: {total_reward:.2f}"
|
| 180 |
+
)
|
| 181 |
+
# Auto-reset on done
|
| 182 |
+
result = env.reset(domain_name="hopper", task_name=task)
|
| 183 |
+
total_reward = 0.0
|
| 184 |
+
step_count = 0
|
| 185 |
+
|
| 186 |
+
# Update display
|
| 187 |
+
screen.fill((30, 30, 30))
|
| 188 |
+
status = "FORCE!" if apply_random_force else "idle"
|
| 189 |
+
text = font.render(
|
| 190 |
+
f"Step: {step_count} | Reward: {total_reward:.1f} | {status}",
|
| 191 |
+
True,
|
| 192 |
+
(255, 255, 255),
|
| 193 |
+
)
|
| 194 |
+
screen.blit(text, (10, 40))
|
| 195 |
+
pygame.display.flip()
|
| 196 |
+
|
| 197 |
+
# Print progress periodically
|
| 198 |
+
if step_count % 200 == 0 and step_count > 0:
|
| 199 |
+
print(f"Step {step_count}: Total reward: {total_reward:.2f}")
|
| 200 |
+
|
| 201 |
+
# Cap at 30 FPS
|
| 202 |
+
clock.tick(30)
|
| 203 |
+
|
| 204 |
+
pygame.quit()
|
| 205 |
+
print(f"Session ended. Final reward: {total_reward:.2f}")
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def run_visual(env: DMControlEnv, task: str = "hop"):
|
| 209 |
+
"""Run with pygame visualization showing rendered frames."""
|
| 210 |
+
import base64
|
| 211 |
+
import io
|
| 212 |
+
|
| 213 |
+
import pygame
|
| 214 |
+
|
| 215 |
+
print("\n=== Visual Mode (OpenEnv Step/Observation Pattern) ===")
|
| 216 |
+
|
| 217 |
+
# Reset environment with rendering enabled
|
| 218 |
+
result = env.reset(domain_name="hopper", task_name=task, render=True)
|
| 219 |
+
print(f"Initial observations: {list(result.observation.observations.keys())}")
|
| 220 |
+
|
| 221 |
+
# Get action dimension
|
| 222 |
+
action_dim = get_action_dim(env)
|
| 223 |
+
print(f"Action dimension: {action_dim}")
|
| 224 |
+
|
| 225 |
+
# Get first frame to determine window size
|
| 226 |
+
if result.observation.pixels is None:
|
| 227 |
+
print("Error: Server did not return rendered pixels.")
|
| 228 |
+
print("Make sure the server supports render=True")
|
| 229 |
+
print("\nTry running in interactive mode (default) instead.")
|
| 230 |
+
sys.exit(1)
|
| 231 |
+
|
| 232 |
+
# Decode base64 PNG to pygame surface
|
| 233 |
+
png_data = base64.b64decode(result.observation.pixels)
|
| 234 |
+
frame = pygame.image.load(io.BytesIO(png_data))
|
| 235 |
+
frame_size = frame.get_size()
|
| 236 |
+
|
| 237 |
+
# Initialize pygame
|
| 238 |
+
pygame.init()
|
| 239 |
+
screen = pygame.display.set_mode(frame_size)
|
| 240 |
+
pygame.display.set_caption(
|
| 241 |
+
"Hopper (OpenEnv) - SPACE for random force, R to Reset, ESC to Quit"
|
| 242 |
+
)
|
| 243 |
+
clock = pygame.time.Clock()
|
| 244 |
+
|
| 245 |
+
print("Controls:")
|
| 246 |
+
print(" SPACE: Apply random force to joints")
|
| 247 |
+
print(" R: Reset environment")
|
| 248 |
+
print(" ESC or Q: Quit")
|
| 249 |
+
|
| 250 |
+
running = True
|
| 251 |
+
total_reward = 0.0
|
| 252 |
+
step_count = 0
|
| 253 |
+
|
| 254 |
+
while running:
|
| 255 |
+
# Handle events
|
| 256 |
+
for event in pygame.event.get():
|
| 257 |
+
if event.type == pygame.QUIT:
|
| 258 |
+
running = False
|
| 259 |
+
elif event.type == pygame.KEYDOWN:
|
| 260 |
+
if event.key in (pygame.K_ESCAPE, pygame.K_q):
|
| 261 |
+
running = False
|
| 262 |
+
elif event.key == pygame.K_r:
|
| 263 |
+
result = env.reset(
|
| 264 |
+
domain_name="hopper", task_name=task, render=True
|
| 265 |
+
)
|
| 266 |
+
total_reward = 0.0
|
| 267 |
+
step_count = 0
|
| 268 |
+
print("Environment reset")
|
| 269 |
+
|
| 270 |
+
# Check for held keys
|
| 271 |
+
keys = pygame.key.get_pressed()
|
| 272 |
+
apply_random_force = keys[pygame.K_SPACE]
|
| 273 |
+
|
| 274 |
+
# Generate action based on input
|
| 275 |
+
if apply_random_force:
|
| 276 |
+
action = generate_random_action(action_dim, magnitude=2.0)
|
| 277 |
+
else:
|
| 278 |
+
action = generate_zero_action(action_dim)
|
| 279 |
+
|
| 280 |
+
# Step the environment using OpenEnv pattern
|
| 281 |
+
result = env.step(action, render=True)
|
| 282 |
+
|
| 283 |
+
# Track reward from result
|
| 284 |
+
total_reward += result.reward or 0.0
|
| 285 |
+
step_count += 1
|
| 286 |
+
|
| 287 |
+
# Check if episode is done
|
| 288 |
+
if result.done:
|
| 289 |
+
print(
|
| 290 |
+
f"Episode finished! Steps: {step_count}, "
|
| 291 |
+
f"Total reward: {total_reward:.2f}"
|
| 292 |
+
)
|
| 293 |
+
result = env.reset(domain_name="hopper", task_name=task, render=True)
|
| 294 |
+
total_reward = 0.0
|
| 295 |
+
step_count = 0
|
| 296 |
+
|
| 297 |
+
# Render the frame from observation pixels
|
| 298 |
+
if result.observation.pixels:
|
| 299 |
+
png_data = base64.b64decode(result.observation.pixels)
|
| 300 |
+
frame = pygame.image.load(io.BytesIO(png_data))
|
| 301 |
+
screen.blit(frame, (0, 0))
|
| 302 |
+
pygame.display.flip()
|
| 303 |
+
|
| 304 |
+
# Print progress periodically
|
| 305 |
+
if step_count % 200 == 0 and step_count > 0:
|
| 306 |
+
print(f"Step {step_count}: Total reward: {total_reward:.2f}")
|
| 307 |
+
|
| 308 |
+
# Cap at 30 FPS
|
| 309 |
+
clock.tick(30)
|
| 310 |
+
|
| 311 |
+
pygame.quit()
|
| 312 |
+
print(f"Session ended. Final reward: {total_reward:.2f}")
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
def main():
|
| 316 |
+
parser = argparse.ArgumentParser(
|
| 317 |
+
description="Interactive hopper control via OpenEnv"
|
| 318 |
+
)
|
| 319 |
+
parser.add_argument(
|
| 320 |
+
"--visual",
|
| 321 |
+
action="store_true",
|
| 322 |
+
help="Enable pygame visualization with rendered frames",
|
| 323 |
+
)
|
| 324 |
+
parser.add_argument(
|
| 325 |
+
"--headless",
|
| 326 |
+
action="store_true",
|
| 327 |
+
help="Run in headless mode (no pygame, automated control)",
|
| 328 |
+
)
|
| 329 |
+
parser.add_argument(
|
| 330 |
+
"--max-steps",
|
| 331 |
+
type=int,
|
| 332 |
+
default=1000,
|
| 333 |
+
help="Maximum steps for headless mode (default: 1000)",
|
| 334 |
+
)
|
| 335 |
+
parser.add_argument(
|
| 336 |
+
"--task",
|
| 337 |
+
type=str,
|
| 338 |
+
default="hop",
|
| 339 |
+
choices=["stand", "hop"],
|
| 340 |
+
help="Hopper task (default: hop)",
|
| 341 |
+
)
|
| 342 |
+
args = parser.parse_args()
|
| 343 |
+
|
| 344 |
+
server_url = "http://localhost:8000"
|
| 345 |
+
print(f"Connecting to {server_url}...")
|
| 346 |
+
|
| 347 |
+
try:
|
| 348 |
+
with DMControlEnv(base_url=server_url) as env:
|
| 349 |
+
print("Connected!")
|
| 350 |
+
|
| 351 |
+
# Get environment state
|
| 352 |
+
state = env.state()
|
| 353 |
+
print(f"Domain: {state.domain_name}, Task: {state.task_name}")
|
| 354 |
+
print(f"Action spec: {state.action_spec}")
|
| 355 |
+
|
| 356 |
+
if args.headless:
|
| 357 |
+
run_headless(env, task=args.task, max_steps=args.max_steps)
|
| 358 |
+
elif args.visual:
|
| 359 |
+
run_visual(env, task=args.task)
|
| 360 |
+
else:
|
| 361 |
+
run_interactive(env, task=args.task)
|
| 362 |
+
|
| 363 |
+
except ConnectionError as e:
|
| 364 |
+
print(f"Failed to connect: {e}")
|
| 365 |
+
print("\nMake sure the server is running:")
|
| 366 |
+
print(" cd OpenEnv")
|
| 367 |
+
print(
|
| 368 |
+
" PYTHONPATH=src:envs uvicorn envs.dm_control_env.server.app:app --port 8000"
|
| 369 |
+
)
|
| 370 |
+
sys.exit(1)
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
if __name__ == "__main__":
|
| 374 |
+
main()
|
examples/list_environments.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""List all available dm_control.suite environments.
|
| 3 |
+
|
| 4 |
+
This utility prints all available domain/task combinations from dm_control.suite.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from dm_control import suite
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def main():
|
| 11 |
+
print("Available dm_control.suite environments:")
|
| 12 |
+
print("=" * 50)
|
| 13 |
+
|
| 14 |
+
# Group by domain
|
| 15 |
+
domains = {}
|
| 16 |
+
for domain, task in suite.BENCHMARKING:
|
| 17 |
+
if domain not in domains:
|
| 18 |
+
domains[domain] = []
|
| 19 |
+
domains[domain].append(task)
|
| 20 |
+
|
| 21 |
+
for domain in sorted(domains.keys()):
|
| 22 |
+
tasks = sorted(domains[domain])
|
| 23 |
+
print(f"\n{domain}:")
|
| 24 |
+
for task in tasks:
|
| 25 |
+
# Load env to get action spec
|
| 26 |
+
try:
|
| 27 |
+
env = suite.load(domain_name=domain, task_name=task)
|
| 28 |
+
action_spec = env.action_spec()
|
| 29 |
+
action_dim = action_spec.shape[0]
|
| 30 |
+
obs_keys = list(env.observation_spec().keys())
|
| 31 |
+
env.close()
|
| 32 |
+
print(f" - {task:20s} (action_dim={action_dim}, obs={obs_keys})")
|
| 33 |
+
except Exception as e:
|
| 34 |
+
print(f" - {task:20s} (error: {e})")
|
| 35 |
+
|
| 36 |
+
print("\n" + "=" * 50)
|
| 37 |
+
print(f"Total: {len(suite.BENCHMARKING)} environments")
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
if __name__ == "__main__":
|
| 41 |
+
main()
|
examples/quadruped_control.py
ADDED
|
@@ -0,0 +1,373 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Interactive quadruped control via OpenEnv.
|
| 3 |
+
|
| 4 |
+
This example demonstrates using the dm_control OpenEnv client with
|
| 5 |
+
the quadruped environment. Press SPACE to apply random forces to the joints.
|
| 6 |
+
|
| 7 |
+
Controls:
|
| 8 |
+
SPACE: Apply random force to all joints
|
| 9 |
+
R: Reset environment
|
| 10 |
+
ESC or Q: Quit
|
| 11 |
+
|
| 12 |
+
Requirements:
|
| 13 |
+
pip install pygame
|
| 14 |
+
|
| 15 |
+
Usage:
|
| 16 |
+
1. Start the server: uvicorn server.app:app --host 0.0.0.0 --port 8000
|
| 17 |
+
2. Run this script: python examples/quadruped_control.py
|
| 18 |
+
|
| 19 |
+
For visual mode (requires working MuJoCo rendering):
|
| 20 |
+
python examples/quadruped_control.py --visual
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
import argparse
|
| 24 |
+
import random
|
| 25 |
+
import sys
|
| 26 |
+
from pathlib import Path
|
| 27 |
+
|
| 28 |
+
# Add parent directory to path for imports
|
| 29 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 30 |
+
|
| 31 |
+
from client import DMControlEnv
|
| 32 |
+
from models import DMControlAction
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def get_action_dim(env: DMControlEnv) -> int:
|
| 36 |
+
"""Get the action dimension from the environment state."""
|
| 37 |
+
state = env.state()
|
| 38 |
+
action_spec = state.action_spec
|
| 39 |
+
if action_spec and "shape" in action_spec:
|
| 40 |
+
shape = action_spec["shape"]
|
| 41 |
+
if isinstance(shape, list) and len(shape) > 0:
|
| 42 |
+
return shape[0]
|
| 43 |
+
# Quadruped default: 12 actuators (3 per leg x 4 legs)
|
| 44 |
+
return 12
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def generate_random_action(action_dim: int, magnitude: float = 1.0) -> DMControlAction:
|
| 48 |
+
"""Generate a random action with values in [-magnitude, magnitude]."""
|
| 49 |
+
values = [random.uniform(-magnitude, magnitude) for _ in range(action_dim)]
|
| 50 |
+
return DMControlAction(values=values)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def generate_zero_action(action_dim: int) -> DMControlAction:
|
| 54 |
+
"""Generate a zero action (no force applied)."""
|
| 55 |
+
return DMControlAction(values=[0.0] * action_dim)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def run_headless(env: DMControlEnv, max_steps: int = 1000):
|
| 59 |
+
"""Run quadruped control in headless mode."""
|
| 60 |
+
print("\n=== Headless Mode (OpenEnv Step/Observation Pattern) ===")
|
| 61 |
+
print("This mode demonstrates the OpenEnv API with the quadruped.\n")
|
| 62 |
+
|
| 63 |
+
# Reset environment using OpenEnv pattern
|
| 64 |
+
result = env.reset(domain_name="quadruped", task_name="walk")
|
| 65 |
+
print(f"Initial observations: {list(result.observation.observations.keys())}")
|
| 66 |
+
|
| 67 |
+
# Get action dimension
|
| 68 |
+
action_dim = get_action_dim(env)
|
| 69 |
+
print(f"Action dimension: {action_dim}")
|
| 70 |
+
|
| 71 |
+
total_reward = 0.0
|
| 72 |
+
step_count = 0
|
| 73 |
+
|
| 74 |
+
print("\nRunning with periodic random forces...")
|
| 75 |
+
print("Every 50 steps, a random force burst is applied.\n")
|
| 76 |
+
|
| 77 |
+
while not result.done and step_count < max_steps:
|
| 78 |
+
# Apply random force every 50 steps, otherwise zero action
|
| 79 |
+
if step_count % 50 < 10:
|
| 80 |
+
# Random force burst for 10 steps
|
| 81 |
+
action = generate_random_action(action_dim, magnitude=0.5)
|
| 82 |
+
else:
|
| 83 |
+
# No force
|
| 84 |
+
action = generate_zero_action(action_dim)
|
| 85 |
+
|
| 86 |
+
# Step the environment using OpenEnv pattern
|
| 87 |
+
result = env.step(action)
|
| 88 |
+
|
| 89 |
+
# Access observation and reward from result
|
| 90 |
+
total_reward += result.reward or 0.0
|
| 91 |
+
step_count += 1
|
| 92 |
+
|
| 93 |
+
# Print progress periodically
|
| 94 |
+
if step_count % 100 == 0:
|
| 95 |
+
# Get some observation values
|
| 96 |
+
egocentric_state = result.observation.observations.get(
|
| 97 |
+
"egocentric_state", []
|
| 98 |
+
)
|
| 99 |
+
print(
|
| 100 |
+
f"Step {step_count}: reward={result.reward:.3f}, "
|
| 101 |
+
f"total={total_reward:.2f}, done={result.done}"
|
| 102 |
+
)
|
| 103 |
+
if egocentric_state:
|
| 104 |
+
print(f" egocentric_state (first 5): {egocentric_state[:5]}")
|
| 105 |
+
|
| 106 |
+
print(f"\nEpisode finished: {step_count} steps, total reward: {total_reward:.2f}")
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def run_interactive(env: DMControlEnv):
|
| 110 |
+
"""Run interactive control with keyboard input via pygame."""
|
| 111 |
+
import pygame
|
| 112 |
+
|
| 113 |
+
print("\n=== Interactive Mode (OpenEnv Step/Observation Pattern) ===")
|
| 114 |
+
print("Press SPACE to apply random force, R to reset, ESC to quit.\n")
|
| 115 |
+
|
| 116 |
+
# Reset environment using OpenEnv pattern
|
| 117 |
+
result = env.reset(domain_name="quadruped", task_name="walk")
|
| 118 |
+
print(f"Initial observations: {list(result.observation.observations.keys())}")
|
| 119 |
+
|
| 120 |
+
# Get action dimension
|
| 121 |
+
action_dim = get_action_dim(env)
|
| 122 |
+
print(f"Action dimension: {action_dim}")
|
| 123 |
+
|
| 124 |
+
# Initialize pygame for keyboard input (minimal window)
|
| 125 |
+
pygame.init()
|
| 126 |
+
screen = pygame.display.set_mode((400, 100))
|
| 127 |
+
pygame.display.set_caption("Quadruped Control - SPACE for random force, R to reset")
|
| 128 |
+
clock = pygame.time.Clock()
|
| 129 |
+
|
| 130 |
+
# Draw instructions on the window
|
| 131 |
+
font = pygame.font.Font(None, 24)
|
| 132 |
+
|
| 133 |
+
running = True
|
| 134 |
+
total_reward = 0.0
|
| 135 |
+
step_count = 0
|
| 136 |
+
apply_random_force = False
|
| 137 |
+
|
| 138 |
+
print("\nControls:")
|
| 139 |
+
print(" SPACE: Apply random force to joints")
|
| 140 |
+
print(" R: Reset environment")
|
| 141 |
+
print(" ESC or Q: Quit\n")
|
| 142 |
+
|
| 143 |
+
while running:
|
| 144 |
+
# Handle events
|
| 145 |
+
for event in pygame.event.get():
|
| 146 |
+
if event.type == pygame.QUIT:
|
| 147 |
+
running = False
|
| 148 |
+
elif event.type == pygame.KEYDOWN:
|
| 149 |
+
if event.key in (pygame.K_ESCAPE, pygame.K_q):
|
| 150 |
+
running = False
|
| 151 |
+
elif event.key == pygame.K_r:
|
| 152 |
+
result = env.reset(domain_name="quadruped", task_name="walk")
|
| 153 |
+
total_reward = 0.0
|
| 154 |
+
step_count = 0
|
| 155 |
+
print("Environment reset")
|
| 156 |
+
|
| 157 |
+
# Check for held keys
|
| 158 |
+
keys = pygame.key.get_pressed()
|
| 159 |
+
apply_random_force = keys[pygame.K_SPACE]
|
| 160 |
+
|
| 161 |
+
# Generate action based on input
|
| 162 |
+
if apply_random_force:
|
| 163 |
+
action = generate_random_action(action_dim, magnitude=2.0)
|
| 164 |
+
else:
|
| 165 |
+
action = generate_zero_action(action_dim)
|
| 166 |
+
|
| 167 |
+
# Step the environment using OpenEnv pattern
|
| 168 |
+
result = env.step(action)
|
| 169 |
+
|
| 170 |
+
# Track reward from result
|
| 171 |
+
total_reward += result.reward or 0.0
|
| 172 |
+
step_count += 1
|
| 173 |
+
|
| 174 |
+
# Check if episode is done
|
| 175 |
+
if result.done:
|
| 176 |
+
print(
|
| 177 |
+
f"Episode finished! Steps: {step_count}, "
|
| 178 |
+
f"Total reward: {total_reward:.2f}"
|
| 179 |
+
)
|
| 180 |
+
# Auto-reset on done
|
| 181 |
+
result = env.reset(domain_name="quadruped", task_name="walk")
|
| 182 |
+
total_reward = 0.0
|
| 183 |
+
step_count = 0
|
| 184 |
+
|
| 185 |
+
# Update display
|
| 186 |
+
screen.fill((30, 30, 30))
|
| 187 |
+
status = "FORCE!" if apply_random_force else "idle"
|
| 188 |
+
text = font.render(
|
| 189 |
+
f"Step: {step_count} | Reward: {total_reward:.1f} | {status}",
|
| 190 |
+
True,
|
| 191 |
+
(255, 255, 255),
|
| 192 |
+
)
|
| 193 |
+
screen.blit(text, (10, 40))
|
| 194 |
+
pygame.display.flip()
|
| 195 |
+
|
| 196 |
+
# Print progress periodically
|
| 197 |
+
if step_count % 200 == 0 and step_count > 0:
|
| 198 |
+
print(f"Step {step_count}: Total reward: {total_reward:.2f}")
|
| 199 |
+
|
| 200 |
+
# Cap at 30 FPS
|
| 201 |
+
clock.tick(30)
|
| 202 |
+
|
| 203 |
+
pygame.quit()
|
| 204 |
+
print(f"Session ended. Final reward: {total_reward:.2f}")
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def run_visual(env: DMControlEnv):
|
| 208 |
+
"""Run with pygame visualization showing rendered frames."""
|
| 209 |
+
import base64
|
| 210 |
+
import io
|
| 211 |
+
|
| 212 |
+
import pygame
|
| 213 |
+
|
| 214 |
+
print("\n=== Visual Mode (OpenEnv Step/Observation Pattern) ===")
|
| 215 |
+
|
| 216 |
+
# Reset environment with rendering enabled
|
| 217 |
+
result = env.reset(domain_name="quadruped", task_name="walk", render=True)
|
| 218 |
+
print(f"Initial observations: {list(result.observation.observations.keys())}")
|
| 219 |
+
|
| 220 |
+
# Get action dimension
|
| 221 |
+
action_dim = get_action_dim(env)
|
| 222 |
+
print(f"Action dimension: {action_dim}")
|
| 223 |
+
|
| 224 |
+
# Get first frame to determine window size
|
| 225 |
+
if result.observation.pixels is None:
|
| 226 |
+
print("Error: Server did not return rendered pixels.")
|
| 227 |
+
print("Make sure the server supports render=True")
|
| 228 |
+
print("\nTry running in interactive mode (default) instead.")
|
| 229 |
+
sys.exit(1)
|
| 230 |
+
|
| 231 |
+
# Decode base64 PNG to pygame surface
|
| 232 |
+
png_data = base64.b64decode(result.observation.pixels)
|
| 233 |
+
frame = pygame.image.load(io.BytesIO(png_data))
|
| 234 |
+
frame_size = frame.get_size()
|
| 235 |
+
|
| 236 |
+
# Initialize pygame
|
| 237 |
+
pygame.init()
|
| 238 |
+
screen = pygame.display.set_mode(frame_size)
|
| 239 |
+
pygame.display.set_caption(
|
| 240 |
+
"Quadruped (OpenEnv) - SPACE for random force, R to Reset, ESC to Quit"
|
| 241 |
+
)
|
| 242 |
+
clock = pygame.time.Clock()
|
| 243 |
+
|
| 244 |
+
print("Controls:")
|
| 245 |
+
print(" SPACE: Apply random force to joints")
|
| 246 |
+
print(" R: Reset environment")
|
| 247 |
+
print(" ESC or Q: Quit")
|
| 248 |
+
|
| 249 |
+
running = True
|
| 250 |
+
total_reward = 0.0
|
| 251 |
+
step_count = 0
|
| 252 |
+
|
| 253 |
+
while running:
|
| 254 |
+
# Handle events
|
| 255 |
+
for event in pygame.event.get():
|
| 256 |
+
if event.type == pygame.QUIT:
|
| 257 |
+
running = False
|
| 258 |
+
elif event.type == pygame.KEYDOWN:
|
| 259 |
+
if event.key in (pygame.K_ESCAPE, pygame.K_q):
|
| 260 |
+
running = False
|
| 261 |
+
elif event.key == pygame.K_r:
|
| 262 |
+
result = env.reset(
|
| 263 |
+
domain_name="quadruped", task_name="walk", render=True
|
| 264 |
+
)
|
| 265 |
+
total_reward = 0.0
|
| 266 |
+
step_count = 0
|
| 267 |
+
print("Environment reset")
|
| 268 |
+
|
| 269 |
+
# Check for held keys
|
| 270 |
+
keys = pygame.key.get_pressed()
|
| 271 |
+
apply_random_force = keys[pygame.K_SPACE]
|
| 272 |
+
|
| 273 |
+
# Generate action based on input
|
| 274 |
+
if apply_random_force:
|
| 275 |
+
action = generate_random_action(action_dim, magnitude=2.0)
|
| 276 |
+
else:
|
| 277 |
+
action = generate_zero_action(action_dim)
|
| 278 |
+
|
| 279 |
+
# Step the environment using OpenEnv pattern
|
| 280 |
+
result = env.step(action, render=True)
|
| 281 |
+
|
| 282 |
+
# Track reward from result
|
| 283 |
+
total_reward += result.reward or 0.0
|
| 284 |
+
step_count += 1
|
| 285 |
+
|
| 286 |
+
# Check if episode is done
|
| 287 |
+
if result.done:
|
| 288 |
+
print(
|
| 289 |
+
f"Episode finished! Steps: {step_count}, "
|
| 290 |
+
f"Total reward: {total_reward:.2f}"
|
| 291 |
+
)
|
| 292 |
+
result = env.reset(domain_name="quadruped", task_name="walk", render=True)
|
| 293 |
+
total_reward = 0.0
|
| 294 |
+
step_count = 0
|
| 295 |
+
|
| 296 |
+
# Render the frame from observation pixels
|
| 297 |
+
if result.observation.pixels:
|
| 298 |
+
png_data = base64.b64decode(result.observation.pixels)
|
| 299 |
+
frame = pygame.image.load(io.BytesIO(png_data))
|
| 300 |
+
screen.blit(frame, (0, 0))
|
| 301 |
+
pygame.display.flip()
|
| 302 |
+
|
| 303 |
+
# Print progress periodically
|
| 304 |
+
if step_count % 200 == 0 and step_count > 0:
|
| 305 |
+
print(f"Step {step_count}: Total reward: {total_reward:.2f}")
|
| 306 |
+
|
| 307 |
+
# Cap at 30 FPS
|
| 308 |
+
clock.tick(30)
|
| 309 |
+
|
| 310 |
+
pygame.quit()
|
| 311 |
+
print(f"Session ended. Final reward: {total_reward:.2f}")
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
def main():
|
| 315 |
+
parser = argparse.ArgumentParser(
|
| 316 |
+
description="Interactive quadruped control via OpenEnv"
|
| 317 |
+
)
|
| 318 |
+
parser.add_argument(
|
| 319 |
+
"--visual",
|
| 320 |
+
action="store_true",
|
| 321 |
+
help="Enable pygame visualization with rendered frames",
|
| 322 |
+
)
|
| 323 |
+
parser.add_argument(
|
| 324 |
+
"--headless",
|
| 325 |
+
action="store_true",
|
| 326 |
+
help="Run in headless mode (no pygame, automated control)",
|
| 327 |
+
)
|
| 328 |
+
parser.add_argument(
|
| 329 |
+
"--max-steps",
|
| 330 |
+
type=int,
|
| 331 |
+
default=1000,
|
| 332 |
+
help="Maximum steps for headless mode (default: 1000)",
|
| 333 |
+
)
|
| 334 |
+
parser.add_argument(
|
| 335 |
+
"--task",
|
| 336 |
+
type=str,
|
| 337 |
+
default="walk",
|
| 338 |
+
choices=["walk", "run", "escape", "fetch"],
|
| 339 |
+
help="Quadruped task (default: walk)",
|
| 340 |
+
)
|
| 341 |
+
args = parser.parse_args()
|
| 342 |
+
|
| 343 |
+
server_url = "http://localhost:8000"
|
| 344 |
+
print(f"Connecting to {server_url}...")
|
| 345 |
+
|
| 346 |
+
try:
|
| 347 |
+
with DMControlEnv(base_url=server_url) as env:
|
| 348 |
+
print("Connected!")
|
| 349 |
+
|
| 350 |
+
# Get environment state
|
| 351 |
+
state = env.state()
|
| 352 |
+
print(f"Domain: {state.domain_name}, Task: {state.task_name}")
|
| 353 |
+
print(f"Action spec: {state.action_spec}")
|
| 354 |
+
|
| 355 |
+
if args.headless:
|
| 356 |
+
run_headless(env, max_steps=args.max_steps)
|
| 357 |
+
elif args.visual:
|
| 358 |
+
run_visual(env)
|
| 359 |
+
else:
|
| 360 |
+
run_interactive(env)
|
| 361 |
+
|
| 362 |
+
except ConnectionError as e:
|
| 363 |
+
print(f"Failed to connect: {e}")
|
| 364 |
+
print("\nMake sure the server is running:")
|
| 365 |
+
print(" cd OpenEnv")
|
| 366 |
+
print(
|
| 367 |
+
" PYTHONPATH=src:envs uvicorn envs.dm_control_env.server.app:app --port 8000"
|
| 368 |
+
)
|
| 369 |
+
sys.exit(1)
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
if __name__ == "__main__":
|
| 373 |
+
main()
|
models.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Data models for the dm_control OpenEnv Environment.
|
| 9 |
+
|
| 10 |
+
This environment wraps dm_control.suite, providing access to all MuJoCo-based
|
| 11 |
+
continuous control tasks (cartpole, walker, humanoid, cheetah, etc.).
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from typing import Any, Dict, List, Optional
|
| 15 |
+
|
| 16 |
+
from pydantic import Field
|
| 17 |
+
|
| 18 |
+
try:
|
| 19 |
+
from openenv.core.env_server.types import Action, Observation, State
|
| 20 |
+
except ImportError:
|
| 21 |
+
from openenv.core.env_server.types import Action, Observation, State
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class DMControlAction(Action):
|
| 25 |
+
"""
|
| 26 |
+
Action for dm_control environments.
|
| 27 |
+
|
| 28 |
+
All dm_control.suite environments use continuous actions represented as
|
| 29 |
+
a list of float values. The size and bounds depend on the specific
|
| 30 |
+
domain/task combination.
|
| 31 |
+
|
| 32 |
+
Example (cartpole - 1D action):
|
| 33 |
+
>>> action = DMControlAction(values=[0.5]) # Push cart right
|
| 34 |
+
|
| 35 |
+
Example (walker - 6D action):
|
| 36 |
+
>>> action = DMControlAction(values=[0.1, -0.2, 0.3, 0.0, -0.1, 0.2])
|
| 37 |
+
|
| 38 |
+
Attributes:
|
| 39 |
+
values: List of continuous action values. Shape and bounds depend on
|
| 40 |
+
the loaded environment's action_spec.
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
values: List[float] = Field(
|
| 44 |
+
default_factory=list,
|
| 45 |
+
description="Continuous action values matching the environment's action_spec",
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class DMControlObservation(Observation):
|
| 50 |
+
"""
|
| 51 |
+
Observation from dm_control environments.
|
| 52 |
+
|
| 53 |
+
dm_control environments return observations as a dictionary of named arrays.
|
| 54 |
+
Common observation keys include 'position', 'velocity', 'orientations', etc.
|
| 55 |
+
The exact keys depend on the domain/task combination.
|
| 56 |
+
|
| 57 |
+
Example observation keys by domain:
|
| 58 |
+
- cartpole: 'position' (cos/sin of angle), 'velocity'
|
| 59 |
+
- walker: 'orientations', 'height', 'velocity'
|
| 60 |
+
- humanoid: 'joint_angles', 'head_height', 'extremities', 'torso_vertical', 'com_velocity'
|
| 61 |
+
|
| 62 |
+
Attributes:
|
| 63 |
+
observations: Dictionary mapping observation names to their values.
|
| 64 |
+
Each value is a flattened list of floats.
|
| 65 |
+
pixels: Optional base64-encoded PNG image of the rendered scene.
|
| 66 |
+
Only included when render=True is passed to reset/step.
|
| 67 |
+
"""
|
| 68 |
+
|
| 69 |
+
observations: Dict[str, List[float]] = Field(
|
| 70 |
+
default_factory=dict,
|
| 71 |
+
description="Named observation arrays from the environment",
|
| 72 |
+
)
|
| 73 |
+
pixels: Optional[str] = Field(
|
| 74 |
+
default=None,
|
| 75 |
+
description="Base64-encoded PNG image (when render=True)",
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class DMControlState(State):
|
| 80 |
+
"""
|
| 81 |
+
Extended state for dm_control environments.
|
| 82 |
+
|
| 83 |
+
Provides metadata about the currently loaded environment including
|
| 84 |
+
the domain/task names and action/observation specifications.
|
| 85 |
+
|
| 86 |
+
Attributes:
|
| 87 |
+
episode_id: Unique identifier for the current episode.
|
| 88 |
+
step_count: Number of steps taken in the current episode.
|
| 89 |
+
domain_name: The dm_control domain (e.g., 'cartpole', 'walker').
|
| 90 |
+
task_name: The specific task (e.g., 'balance', 'walk').
|
| 91 |
+
action_spec: Specification of the action space including shape and bounds.
|
| 92 |
+
observation_spec: Specification of the observation space.
|
| 93 |
+
physics_timestep: The physics simulation timestep in seconds.
|
| 94 |
+
control_timestep: The control timestep (time between actions) in seconds.
|
| 95 |
+
"""
|
| 96 |
+
|
| 97 |
+
domain_name: str = Field(
|
| 98 |
+
default="cartpole",
|
| 99 |
+
description="The dm_control domain name",
|
| 100 |
+
)
|
| 101 |
+
task_name: str = Field(
|
| 102 |
+
default="balance",
|
| 103 |
+
description="The task name within the domain",
|
| 104 |
+
)
|
| 105 |
+
action_spec: Dict[str, Any] = Field(
|
| 106 |
+
default_factory=dict,
|
| 107 |
+
description="Specification of the action space (shape, dtype, bounds)",
|
| 108 |
+
)
|
| 109 |
+
observation_spec: Dict[str, Any] = Field(
|
| 110 |
+
default_factory=dict,
|
| 111 |
+
description="Specification of the observation space",
|
| 112 |
+
)
|
| 113 |
+
physics_timestep: float = Field(
|
| 114 |
+
default=0.002,
|
| 115 |
+
description="Physics simulation timestep in seconds",
|
| 116 |
+
)
|
| 117 |
+
control_timestep: float = Field(
|
| 118 |
+
default=0.02,
|
| 119 |
+
description="Control timestep (time between actions) in seconds",
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
# Available dm_control.suite environments
|
| 124 |
+
# Format: (domain_name, task_name)
|
| 125 |
+
AVAILABLE_ENVIRONMENTS = [
|
| 126 |
+
# Cartpole
|
| 127 |
+
("cartpole", "balance"),
|
| 128 |
+
("cartpole", "balance_sparse"),
|
| 129 |
+
("cartpole", "swingup"),
|
| 130 |
+
("cartpole", "swingup_sparse"),
|
| 131 |
+
# Pendulum
|
| 132 |
+
("pendulum", "swingup"),
|
| 133 |
+
# Point mass
|
| 134 |
+
("point_mass", "easy"),
|
| 135 |
+
("point_mass", "hard"),
|
| 136 |
+
# Reacher
|
| 137 |
+
("reacher", "easy"),
|
| 138 |
+
("reacher", "hard"),
|
| 139 |
+
# Ball in cup
|
| 140 |
+
("ball_in_cup", "catch"),
|
| 141 |
+
# Finger
|
| 142 |
+
("finger", "spin"),
|
| 143 |
+
("finger", "turn_easy"),
|
| 144 |
+
("finger", "turn_hard"),
|
| 145 |
+
# Fish
|
| 146 |
+
("fish", "upright"),
|
| 147 |
+
("fish", "swim"),
|
| 148 |
+
# Cheetah
|
| 149 |
+
("cheetah", "run"),
|
| 150 |
+
# Walker
|
| 151 |
+
("walker", "stand"),
|
| 152 |
+
("walker", "walk"),
|
| 153 |
+
("walker", "run"),
|
| 154 |
+
# Hopper
|
| 155 |
+
("hopper", "stand"),
|
| 156 |
+
("hopper", "hop"),
|
| 157 |
+
# Swimmer
|
| 158 |
+
("swimmer", "swimmer6"),
|
| 159 |
+
("swimmer", "swimmer15"),
|
| 160 |
+
# Humanoid
|
| 161 |
+
("humanoid", "stand"),
|
| 162 |
+
("humanoid", "walk"),
|
| 163 |
+
("humanoid", "run"),
|
| 164 |
+
# Manipulator
|
| 165 |
+
("manipulator", "bring_ball"),
|
| 166 |
+
("manipulator", "bring_peg"),
|
| 167 |
+
("manipulator", "insert_ball"),
|
| 168 |
+
("manipulator", "insert_peg"),
|
| 169 |
+
# Acrobot
|
| 170 |
+
("acrobot", "swingup"),
|
| 171 |
+
("acrobot", "swingup_sparse"),
|
| 172 |
+
# Stacker
|
| 173 |
+
("stacker", "stack_2"),
|
| 174 |
+
("stacker", "stack_4"),
|
| 175 |
+
# Dog
|
| 176 |
+
("dog", "stand"),
|
| 177 |
+
("dog", "walk"),
|
| 178 |
+
("dog", "trot"),
|
| 179 |
+
("dog", "run"),
|
| 180 |
+
("dog", "fetch"),
|
| 181 |
+
# Quadruped
|
| 182 |
+
("quadruped", "walk"),
|
| 183 |
+
("quadruped", "run"),
|
| 184 |
+
("quadruped", "escape"),
|
| 185 |
+
("quadruped", "fetch"),
|
| 186 |
+
]
|
openenv.yaml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
spec_version: 1
|
| 2 |
+
name: dm_control_env
|
| 3 |
+
type: space
|
| 4 |
+
runtime: fastapi
|
| 5 |
+
app: server.app:app
|
| 6 |
+
port: 8000
|
pyproject.toml
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
[build-system]
|
| 8 |
+
requires = ["setuptools>=45", "wheel"]
|
| 9 |
+
build-backend = "setuptools.build_meta"
|
| 10 |
+
|
| 11 |
+
[project]
|
| 12 |
+
name = "openenv-dmcontrol-env"
|
| 13 |
+
version = "0.1.0"
|
| 14 |
+
description = "dm_control Environment for OpenEnv - wraps MuJoCo-based continuous control tasks (cartpole, walker, humanoid, etc.)"
|
| 15 |
+
requires-python = ">=3.10"
|
| 16 |
+
dependencies = [
|
| 17 |
+
# Core OpenEnv dependencies
|
| 18 |
+
"openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git@v2.1.0",
|
| 19 |
+
"fastapi>=0.115.0",
|
| 20 |
+
"pydantic>=2.0.0",
|
| 21 |
+
"uvicorn>=0.24.0",
|
| 22 |
+
"requests>=2.31.0",
|
| 23 |
+
# dm_control dependencies
|
| 24 |
+
"mujoco>=3.0.0",
|
| 25 |
+
"dm_control>=1.0.0",
|
| 26 |
+
"numpy>=1.20.0",
|
| 27 |
+
# Optional: for pixel observations
|
| 28 |
+
"pillow>=9.0.0",
|
| 29 |
+
]
|
| 30 |
+
|
| 31 |
+
[project.optional-dependencies]
|
| 32 |
+
dev = [
|
| 33 |
+
"pytest>=8.0.0",
|
| 34 |
+
"pytest-cov>=4.0.0",
|
| 35 |
+
]
|
| 36 |
+
interactive = [
|
| 37 |
+
# For interactive examples with keyboard control
|
| 38 |
+
"pygame>=2.0.0",
|
| 39 |
+
]
|
| 40 |
+
|
| 41 |
+
[project.scripts]
|
| 42 |
+
# Server entry point - enables running via: uv run --project . server
|
| 43 |
+
server = "dm_control_env.server.app:main"
|
| 44 |
+
|
| 45 |
+
[tool.setuptools]
|
| 46 |
+
include-package-data = true
|
| 47 |
+
packages = ["dm_control_env", "dm_control_env.server"]
|
| 48 |
+
package-dir = { "dm_control_env" = ".", "dm_control_env.server" = "server" }
|
server/Dockerfile
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# Multi-stage build for dm_control environment
|
| 8 |
+
# Uses pip for package installation
|
| 9 |
+
|
| 10 |
+
FROM python:3.11-slim AS builder
|
| 11 |
+
|
| 12 |
+
WORKDIR /app
|
| 13 |
+
|
| 14 |
+
# Install build dependencies including OpenGL for MuJoCo
|
| 15 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 16 |
+
build-essential \
|
| 17 |
+
git \
|
| 18 |
+
libgl1 \
|
| 19 |
+
libglx-mesa0 \
|
| 20 |
+
libglew-dev \
|
| 21 |
+
libosmesa6-dev \
|
| 22 |
+
libgl1-mesa-dev \
|
| 23 |
+
libglfw3 \
|
| 24 |
+
patchelf \
|
| 25 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 26 |
+
|
| 27 |
+
# Copy environment code
|
| 28 |
+
COPY . /app/env
|
| 29 |
+
|
| 30 |
+
WORKDIR /app/env
|
| 31 |
+
|
| 32 |
+
# Install dependencies using pip
|
| 33 |
+
RUN pip install --upgrade pip && \
|
| 34 |
+
pip install --no-cache-dir -e .
|
| 35 |
+
|
| 36 |
+
# Final runtime stage
|
| 37 |
+
FROM python:3.11-slim
|
| 38 |
+
|
| 39 |
+
WORKDIR /app
|
| 40 |
+
|
| 41 |
+
# Install runtime dependencies (OpenGL for MuJoCo rendering, curl for healthcheck)
|
| 42 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 43 |
+
curl \
|
| 44 |
+
libgl1 \
|
| 45 |
+
libglx-mesa0 \
|
| 46 |
+
libglew-dev \
|
| 47 |
+
libosmesa6-dev \
|
| 48 |
+
libglfw3 \
|
| 49 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 50 |
+
|
| 51 |
+
# Copy installed packages from builder
|
| 52 |
+
COPY --from=builder /usr/local/lib/python3.11/site-packages /usr/local/lib/python3.11/site-packages
|
| 53 |
+
COPY --from=builder /usr/local/bin /usr/local/bin
|
| 54 |
+
|
| 55 |
+
# Copy the environment code
|
| 56 |
+
COPY . /app/env
|
| 57 |
+
|
| 58 |
+
# Set PYTHONPATH so imports work correctly
|
| 59 |
+
ENV PYTHONPATH="/app/env"
|
| 60 |
+
|
| 61 |
+
# Set MuJoCo to use OSMesa for headless rendering
|
| 62 |
+
ENV MUJOCO_GL="osmesa"
|
| 63 |
+
|
| 64 |
+
# Expose port
|
| 65 |
+
EXPOSE 8000
|
| 66 |
+
|
| 67 |
+
# Health check
|
| 68 |
+
HEALTHCHECK --interval=30s --timeout=3s --start-period=10s --retries=3 \
|
| 69 |
+
CMD curl -f http://localhost:8000/health || exit 1
|
| 70 |
+
|
| 71 |
+
# Run the FastAPI server
|
| 72 |
+
# Use exec to replace the shell with uvicorn so it receives SIGINT/SIGTERM directly
|
| 73 |
+
CMD ["sh", "-c", "cd /app/env && exec uvicorn server.app:app --host 0.0.0.0 --port 8000"]
|
server/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""dm_control OpenEnv server module."""
|
| 2 |
+
|
| 3 |
+
from .dm_control_environment import DMControlEnvironment
|
| 4 |
+
|
| 5 |
+
__all__ = ["DMControlEnvironment"]
|
server/app.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
FastAPI application for the dm_control Environment.
|
| 9 |
+
|
| 10 |
+
This module creates an HTTP server that exposes dm_control.suite environments
|
| 11 |
+
over HTTP and WebSocket endpoints, compatible with EnvClient.
|
| 12 |
+
|
| 13 |
+
Usage:
|
| 14 |
+
# Development (with auto-reload):
|
| 15 |
+
uvicorn server.app:app --reload --host 0.0.0.0 --port 8000
|
| 16 |
+
|
| 17 |
+
# Production:
|
| 18 |
+
uvicorn server.app:app --host 0.0.0.0 --port 8000
|
| 19 |
+
|
| 20 |
+
# Or run directly:
|
| 21 |
+
uv run --project . server
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
try:
|
| 25 |
+
from openenv.core.env_server.http_server import create_app
|
| 26 |
+
|
| 27 |
+
from ..models import DMControlAction, DMControlObservation
|
| 28 |
+
from .dm_control_environment import DMControlEnvironment
|
| 29 |
+
except ImportError:
|
| 30 |
+
from openenv.core.env_server.http_server import create_app
|
| 31 |
+
|
| 32 |
+
try:
|
| 33 |
+
import sys
|
| 34 |
+
from pathlib import Path
|
| 35 |
+
|
| 36 |
+
_parent = str(Path(__file__).parent.parent)
|
| 37 |
+
if _parent not in sys.path:
|
| 38 |
+
sys.path.insert(0, _parent)
|
| 39 |
+
from models import DMControlAction, DMControlObservation
|
| 40 |
+
from server.dm_control_environment import DMControlEnvironment
|
| 41 |
+
except ImportError:
|
| 42 |
+
try:
|
| 43 |
+
from dm_control_env.models import DMControlAction, DMControlObservation
|
| 44 |
+
from dm_control_env.server.dm_control_environment import (
|
| 45 |
+
DMControlEnvironment,
|
| 46 |
+
)
|
| 47 |
+
except ImportError:
|
| 48 |
+
from envs.dm_control_env.models import DMControlAction, DMControlObservation
|
| 49 |
+
from envs.dm_control_env.server.dm_control_environment import (
|
| 50 |
+
DMControlEnvironment,
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
# Create the app with web interface
|
| 54 |
+
# Pass the class (factory) for concurrent session support
|
| 55 |
+
app = create_app(
|
| 56 |
+
DMControlEnvironment,
|
| 57 |
+
DMControlAction,
|
| 58 |
+
DMControlObservation,
|
| 59 |
+
env_name="dm_control_env",
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def main():
|
| 64 |
+
"""
|
| 65 |
+
Entry point for direct execution via uv run or python -m.
|
| 66 |
+
|
| 67 |
+
This function enables running the server without Docker:
|
| 68 |
+
uv run --project . server
|
| 69 |
+
python -m envs.dm_control_env.server.app
|
| 70 |
+
openenv serve dm_control_env
|
| 71 |
+
"""
|
| 72 |
+
import uvicorn
|
| 73 |
+
|
| 74 |
+
uvicorn.run(app, host="0.0.0.0", port=8000)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
if __name__ == "__main__":
|
| 78 |
+
main()
|
server/dm_control_environment.py
ADDED
|
@@ -0,0 +1,428 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
dm_control Environment Implementation.
|
| 9 |
+
|
| 10 |
+
Wraps dm_control.suite environments (cartpole, walker, humanoid, etc.)
|
| 11 |
+
with the OpenEnv interface for standardized reinforcement learning.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import base64
|
| 15 |
+
import io
|
| 16 |
+
import os
|
| 17 |
+
import sys
|
| 18 |
+
from typing import Any, Dict, Optional
|
| 19 |
+
from uuid import uuid4
|
| 20 |
+
|
| 21 |
+
# Configure MuJoCo rendering backend before importing dm_control
|
| 22 |
+
# On macOS, we don't set MUJOCO_GL - use default (glfw) which works
|
| 23 |
+
# when running synchronously in the main thread (see reset_async/step_async)
|
| 24 |
+
# On Linux, use egl for headless rendering
|
| 25 |
+
if "MUJOCO_GL" not in os.environ and sys.platform != "darwin":
|
| 26 |
+
os.environ.setdefault("MUJOCO_GL", "egl")
|
| 27 |
+
|
| 28 |
+
import numpy as np
|
| 29 |
+
|
| 30 |
+
try:
|
| 31 |
+
from openenv.core.env_server.interfaces import Environment
|
| 32 |
+
|
| 33 |
+
from ..models import DMControlAction, DMControlObservation, DMControlState
|
| 34 |
+
except ImportError:
|
| 35 |
+
from openenv.core.env_server.interfaces import Environment
|
| 36 |
+
|
| 37 |
+
try:
|
| 38 |
+
import sys
|
| 39 |
+
from pathlib import Path
|
| 40 |
+
|
| 41 |
+
_parent = str(Path(__file__).parent.parent)
|
| 42 |
+
if _parent not in sys.path:
|
| 43 |
+
sys.path.insert(0, _parent)
|
| 44 |
+
from models import DMControlAction, DMControlObservation, DMControlState
|
| 45 |
+
except ImportError:
|
| 46 |
+
try:
|
| 47 |
+
from dm_control_env.models import (
|
| 48 |
+
DMControlAction,
|
| 49 |
+
DMControlObservation,
|
| 50 |
+
DMControlState,
|
| 51 |
+
)
|
| 52 |
+
except ImportError:
|
| 53 |
+
from envs.dm_control_env.models import (
|
| 54 |
+
DMControlAction,
|
| 55 |
+
DMControlObservation,
|
| 56 |
+
DMControlState,
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class DMControlEnvironment(Environment):
|
| 61 |
+
"""
|
| 62 |
+
Wraps dm_control.suite environments with the OpenEnv interface.
|
| 63 |
+
|
| 64 |
+
This environment supports all dm_control.suite domains and tasks including
|
| 65 |
+
cartpole, walker, humanoid, cheetah, and more.
|
| 66 |
+
|
| 67 |
+
Features:
|
| 68 |
+
- Dynamic environment switching via reset(domain_name="...", task_name="...")
|
| 69 |
+
- Support for all continuous control tasks
|
| 70 |
+
- Optional visual observations (base64-encoded images)
|
| 71 |
+
- Configurable via constructor or environment variables
|
| 72 |
+
|
| 73 |
+
Example:
|
| 74 |
+
>>> env = DMControlEnvironment()
|
| 75 |
+
>>> obs = env.reset() # Default: cartpole/balance
|
| 76 |
+
>>> print(obs.observations)
|
| 77 |
+
>>>
|
| 78 |
+
>>> # Take an action
|
| 79 |
+
>>> obs = env.step(DMControlAction(values=[0.5])) # Push cart right
|
| 80 |
+
>>> print(obs.reward)
|
| 81 |
+
|
| 82 |
+
Example with different environment:
|
| 83 |
+
>>> env = DMControlEnvironment(domain_name="walker", task_name="walk")
|
| 84 |
+
>>> obs = env.reset()
|
| 85 |
+
>>>
|
| 86 |
+
>>> # Or switch environment on reset
|
| 87 |
+
>>> obs = env.reset(domain_name="cheetah", task_name="run")
|
| 88 |
+
"""
|
| 89 |
+
|
| 90 |
+
# dm_control environments are isolated and thread-safe
|
| 91 |
+
SUPPORTS_CONCURRENT_SESSIONS = True
|
| 92 |
+
|
| 93 |
+
def __init__(
|
| 94 |
+
self,
|
| 95 |
+
domain_name: Optional[str] = None,
|
| 96 |
+
task_name: Optional[str] = None,
|
| 97 |
+
render_height: Optional[int] = None,
|
| 98 |
+
render_width: Optional[int] = None,
|
| 99 |
+
):
|
| 100 |
+
"""
|
| 101 |
+
Initialize the dm_control environment.
|
| 102 |
+
|
| 103 |
+
Args:
|
| 104 |
+
domain_name: The dm_control domain to load.
|
| 105 |
+
Env var: DMCONTROL_DOMAIN (default: cartpole)
|
| 106 |
+
task_name: The task within the domain.
|
| 107 |
+
Env var: DMCONTROL_TASK (default: balance)
|
| 108 |
+
render_height: Height of rendered images (when render=True).
|
| 109 |
+
Env var: DMCONTROL_RENDER_HEIGHT (default: 480)
|
| 110 |
+
render_width: Width of rendered images (when render=True).
|
| 111 |
+
Env var: DMCONTROL_RENDER_WIDTH (default: 640)
|
| 112 |
+
"""
|
| 113 |
+
self._env = None
|
| 114 |
+
|
| 115 |
+
self._domain_name = domain_name or os.environ.get(
|
| 116 |
+
"DMCONTROL_DOMAIN", "cartpole"
|
| 117 |
+
)
|
| 118 |
+
self._task_name = task_name or os.environ.get("DMCONTROL_TASK", "balance")
|
| 119 |
+
self._render_height = (
|
| 120 |
+
render_height
|
| 121 |
+
if render_height is not None
|
| 122 |
+
else int(os.environ.get("DMCONTROL_RENDER_HEIGHT", "480"))
|
| 123 |
+
)
|
| 124 |
+
self._render_width = (
|
| 125 |
+
render_width
|
| 126 |
+
if render_width is not None
|
| 127 |
+
else int(os.environ.get("DMCONTROL_RENDER_WIDTH", "640"))
|
| 128 |
+
)
|
| 129 |
+
self._include_pixels = False
|
| 130 |
+
|
| 131 |
+
self._state = DMControlState(
|
| 132 |
+
episode_id=str(uuid4()),
|
| 133 |
+
step_count=0,
|
| 134 |
+
domain_name=self._domain_name,
|
| 135 |
+
task_name=self._task_name,
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
def _load_environment(self, domain_name: str, task_name: str) -> None:
|
| 139 |
+
"""Load or switch to a dm_control environment."""
|
| 140 |
+
if self._env is not None:
|
| 141 |
+
try:
|
| 142 |
+
self._env.close()
|
| 143 |
+
except Exception:
|
| 144 |
+
pass
|
| 145 |
+
|
| 146 |
+
try:
|
| 147 |
+
from dm_control import suite
|
| 148 |
+
except ImportError as e:
|
| 149 |
+
raise ImportError(
|
| 150 |
+
"dm_control is required. Install with: pip install dm_control"
|
| 151 |
+
) from e
|
| 152 |
+
except Exception as e:
|
| 153 |
+
# MuJoCo/OpenGL initialization can fail on macOS
|
| 154 |
+
error_msg = str(e)
|
| 155 |
+
if sys.platform == "darwin":
|
| 156 |
+
raise RuntimeError(
|
| 157 |
+
f"Failed to import dm_control (MuJoCo error): {error_msg}\n\n"
|
| 158 |
+
"On macOS, try one of these solutions:\n"
|
| 159 |
+
"1. Install osmesa: brew install mesa\n"
|
| 160 |
+
"2. Run with MUJOCO_GL=glfw (requires display)\n"
|
| 161 |
+
"3. Run with MUJOCO_GL=egl (if EGL is available)"
|
| 162 |
+
) from e
|
| 163 |
+
raise
|
| 164 |
+
|
| 165 |
+
try:
|
| 166 |
+
self._env = suite.load(domain_name=domain_name, task_name=task_name)
|
| 167 |
+
except Exception as e:
|
| 168 |
+
error_msg = str(e).lower()
|
| 169 |
+
# Check for MuJoCo/OpenGL errors
|
| 170 |
+
if "gl" in error_msg or "render" in error_msg or "display" in error_msg:
|
| 171 |
+
if sys.platform == "darwin":
|
| 172 |
+
raise RuntimeError(
|
| 173 |
+
f"MuJoCo initialization failed: {e}\n\n"
|
| 174 |
+
"On macOS, try one of these solutions:\n"
|
| 175 |
+
"1. Install osmesa: brew install mesa\n"
|
| 176 |
+
"2. Run with MUJOCO_GL=glfw (requires display)\n"
|
| 177 |
+
"3. Set PYOPENGL_PLATFORM=osmesa"
|
| 178 |
+
) from e
|
| 179 |
+
# Check if it's an invalid environment error
|
| 180 |
+
try:
|
| 181 |
+
available = [(d, t) for d, t in suite.BENCHMARKING]
|
| 182 |
+
raise ValueError(
|
| 183 |
+
f"Failed to load {domain_name}/{task_name}. "
|
| 184 |
+
f"Available environments: {available[:10]}... "
|
| 185 |
+
f"(use dm_control.suite.BENCHMARKING for full list)"
|
| 186 |
+
) from e
|
| 187 |
+
except Exception:
|
| 188 |
+
raise
|
| 189 |
+
|
| 190 |
+
self._domain_name = domain_name
|
| 191 |
+
self._task_name = task_name
|
| 192 |
+
|
| 193 |
+
self._state.domain_name = domain_name
|
| 194 |
+
self._state.task_name = task_name
|
| 195 |
+
self._state.action_spec = self._get_action_spec_info()
|
| 196 |
+
self._state.observation_spec = self._get_observation_spec_info()
|
| 197 |
+
self._state.physics_timestep = self._env.physics.timestep()
|
| 198 |
+
self._state.control_timestep = self._env.control_timestep()
|
| 199 |
+
|
| 200 |
+
def _get_action_spec_info(self) -> Dict[str, Any]:
|
| 201 |
+
"""Get information about the action space."""
|
| 202 |
+
spec = self._env.action_spec()
|
| 203 |
+
return {
|
| 204 |
+
"shape": list(spec.shape),
|
| 205 |
+
"dtype": str(spec.dtype),
|
| 206 |
+
"minimum": spec.minimum.tolist(),
|
| 207 |
+
"maximum": spec.maximum.tolist(),
|
| 208 |
+
"name": spec.name,
|
| 209 |
+
}
|
| 210 |
+
|
| 211 |
+
def _get_observation_spec_info(self) -> Dict[str, Any]:
|
| 212 |
+
"""Get information about the observation space."""
|
| 213 |
+
specs = self._env.observation_spec()
|
| 214 |
+
obs_info = {}
|
| 215 |
+
for name, spec in specs.items():
|
| 216 |
+
obs_info[name] = {
|
| 217 |
+
"shape": list(spec.shape),
|
| 218 |
+
"dtype": str(spec.dtype),
|
| 219 |
+
}
|
| 220 |
+
return obs_info
|
| 221 |
+
|
| 222 |
+
def _get_observation(
|
| 223 |
+
self,
|
| 224 |
+
time_step,
|
| 225 |
+
include_pixels: bool = False,
|
| 226 |
+
) -> DMControlObservation:
|
| 227 |
+
"""Convert dm_control TimeStep to DMControlObservation."""
|
| 228 |
+
import dm_env
|
| 229 |
+
|
| 230 |
+
observations = {}
|
| 231 |
+
for name, value in time_step.observation.items():
|
| 232 |
+
observations[name] = np.asarray(value).flatten().tolist()
|
| 233 |
+
|
| 234 |
+
pixels = None
|
| 235 |
+
if include_pixels:
|
| 236 |
+
try:
|
| 237 |
+
frame = self._env.physics.render(
|
| 238 |
+
height=self._render_height,
|
| 239 |
+
width=self._render_width,
|
| 240 |
+
camera_id=0,
|
| 241 |
+
)
|
| 242 |
+
from PIL import Image
|
| 243 |
+
|
| 244 |
+
img = Image.fromarray(frame)
|
| 245 |
+
buffer = io.BytesIO()
|
| 246 |
+
img.save(buffer, format="PNG")
|
| 247 |
+
pixels = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
| 248 |
+
except Exception:
|
| 249 |
+
pass
|
| 250 |
+
|
| 251 |
+
done = time_step.step_type == dm_env.StepType.LAST
|
| 252 |
+
reward = float(time_step.reward) if time_step.reward is not None else 0.0
|
| 253 |
+
|
| 254 |
+
return DMControlObservation(
|
| 255 |
+
observations=observations,
|
| 256 |
+
pixels=pixels,
|
| 257 |
+
reward=reward,
|
| 258 |
+
done=done,
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
def reset(
|
| 262 |
+
self,
|
| 263 |
+
domain_name: Optional[str] = None,
|
| 264 |
+
task_name: Optional[str] = None,
|
| 265 |
+
seed: Optional[int] = None,
|
| 266 |
+
render: bool = False,
|
| 267 |
+
**kwargs,
|
| 268 |
+
) -> DMControlObservation:
|
| 269 |
+
"""
|
| 270 |
+
Reset the environment and return initial observation.
|
| 271 |
+
|
| 272 |
+
Args:
|
| 273 |
+
domain_name: Optionally switch to a different domain.
|
| 274 |
+
task_name: Optionally switch to a different task.
|
| 275 |
+
seed: Random seed for reproducibility.
|
| 276 |
+
render: If True, include pixel observations.
|
| 277 |
+
**kwargs: Additional arguments (ignored).
|
| 278 |
+
|
| 279 |
+
Returns:
|
| 280 |
+
DMControlObservation with initial state.
|
| 281 |
+
"""
|
| 282 |
+
self._include_pixels = render
|
| 283 |
+
|
| 284 |
+
target_domain = domain_name or self._domain_name
|
| 285 |
+
target_task = task_name or self._task_name
|
| 286 |
+
|
| 287 |
+
if (
|
| 288 |
+
self._env is None
|
| 289 |
+
or target_domain != self._domain_name
|
| 290 |
+
or target_task != self._task_name
|
| 291 |
+
):
|
| 292 |
+
self._load_environment(target_domain, target_task)
|
| 293 |
+
|
| 294 |
+
if seed is not None:
|
| 295 |
+
np.random.seed(seed)
|
| 296 |
+
|
| 297 |
+
time_step = self._env.reset()
|
| 298 |
+
|
| 299 |
+
self._state = DMControlState(
|
| 300 |
+
episode_id=str(uuid4()),
|
| 301 |
+
step_count=0,
|
| 302 |
+
domain_name=self._domain_name,
|
| 303 |
+
task_name=self._task_name,
|
| 304 |
+
action_spec=self._state.action_spec,
|
| 305 |
+
observation_spec=self._state.observation_spec,
|
| 306 |
+
physics_timestep=self._state.physics_timestep,
|
| 307 |
+
control_timestep=self._state.control_timestep,
|
| 308 |
+
)
|
| 309 |
+
|
| 310 |
+
return self._get_observation(time_step, include_pixels=render)
|
| 311 |
+
|
| 312 |
+
def step(
|
| 313 |
+
self,
|
| 314 |
+
action: DMControlAction,
|
| 315 |
+
render: bool = False,
|
| 316 |
+
**kwargs,
|
| 317 |
+
) -> DMControlObservation:
|
| 318 |
+
"""
|
| 319 |
+
Execute one step in the environment.
|
| 320 |
+
|
| 321 |
+
Args:
|
| 322 |
+
action: DMControlAction with continuous action values.
|
| 323 |
+
render: If True, include pixel observations.
|
| 324 |
+
|
| 325 |
+
Returns:
|
| 326 |
+
DMControlObservation with new state, reward, and done flag.
|
| 327 |
+
"""
|
| 328 |
+
if self._env is None:
|
| 329 |
+
raise RuntimeError("Environment not initialized. Call reset() first.")
|
| 330 |
+
|
| 331 |
+
action_array = np.array(action.values, dtype=np.float64)
|
| 332 |
+
|
| 333 |
+
action_spec = self._env.action_spec()
|
| 334 |
+
expected_shape = action_spec.shape
|
| 335 |
+
if action_array.shape != expected_shape:
|
| 336 |
+
if action_array.size == np.prod(expected_shape):
|
| 337 |
+
action_array = action_array.reshape(expected_shape)
|
| 338 |
+
else:
|
| 339 |
+
raise ValueError(
|
| 340 |
+
f"Action shape {action_array.shape} doesn't match "
|
| 341 |
+
f"expected shape {expected_shape}"
|
| 342 |
+
)
|
| 343 |
+
|
| 344 |
+
action_array = np.clip(action_array, action_spec.minimum, action_spec.maximum)
|
| 345 |
+
|
| 346 |
+
time_step = self._env.step(action_array)
|
| 347 |
+
self._state.step_count += 1
|
| 348 |
+
|
| 349 |
+
return self._get_observation(
|
| 350 |
+
time_step, include_pixels=render or self._include_pixels
|
| 351 |
+
)
|
| 352 |
+
|
| 353 |
+
async def reset_async(
|
| 354 |
+
self,
|
| 355 |
+
domain_name: Optional[str] = None,
|
| 356 |
+
task_name: Optional[str] = None,
|
| 357 |
+
seed: Optional[int] = None,
|
| 358 |
+
render: bool = False,
|
| 359 |
+
**kwargs,
|
| 360 |
+
) -> DMControlObservation:
|
| 361 |
+
"""Async version of reset.
|
| 362 |
+
|
| 363 |
+
On macOS, runs synchronously to avoid MuJoCo threading crashes.
|
| 364 |
+
On other platforms, runs in a thread pool.
|
| 365 |
+
"""
|
| 366 |
+
if sys.platform == "darwin":
|
| 367 |
+
# On macOS, MuJoCo crashes when run in a background thread
|
| 368 |
+
# Run synchronously (blocks event loop but avoids crash)
|
| 369 |
+
return self.reset(
|
| 370 |
+
domain_name=domain_name,
|
| 371 |
+
task_name=task_name,
|
| 372 |
+
seed=seed,
|
| 373 |
+
render=render,
|
| 374 |
+
**kwargs,
|
| 375 |
+
)
|
| 376 |
+
else:
|
| 377 |
+
import asyncio
|
| 378 |
+
|
| 379 |
+
return await asyncio.to_thread(
|
| 380 |
+
self.reset,
|
| 381 |
+
domain_name=domain_name,
|
| 382 |
+
task_name=task_name,
|
| 383 |
+
seed=seed,
|
| 384 |
+
render=render,
|
| 385 |
+
**kwargs,
|
| 386 |
+
)
|
| 387 |
+
|
| 388 |
+
async def step_async(
|
| 389 |
+
self,
|
| 390 |
+
action: DMControlAction,
|
| 391 |
+
render: bool = False,
|
| 392 |
+
**kwargs,
|
| 393 |
+
) -> DMControlObservation:
|
| 394 |
+
"""Async version of step.
|
| 395 |
+
|
| 396 |
+
On macOS, runs synchronously to avoid MuJoCo threading crashes.
|
| 397 |
+
On other platforms, runs in a thread pool.
|
| 398 |
+
"""
|
| 399 |
+
if sys.platform == "darwin":
|
| 400 |
+
# On macOS, MuJoCo crashes when run in a background thread
|
| 401 |
+
# Run synchronously (blocks event loop but avoids crash)
|
| 402 |
+
return self.step(action, render=render, **kwargs)
|
| 403 |
+
else:
|
| 404 |
+
import asyncio
|
| 405 |
+
|
| 406 |
+
return await asyncio.to_thread(self.step, action, render=render, **kwargs)
|
| 407 |
+
|
| 408 |
+
@property
|
| 409 |
+
def state(self) -> DMControlState:
|
| 410 |
+
"""Get the current environment state."""
|
| 411 |
+
return self._state
|
| 412 |
+
|
| 413 |
+
def close(self) -> None:
|
| 414 |
+
"""Close the dm_control environment."""
|
| 415 |
+
env = getattr(self, "_env", None)
|
| 416 |
+
if env is not None:
|
| 417 |
+
try:
|
| 418 |
+
env.close()
|
| 419 |
+
except Exception:
|
| 420 |
+
pass
|
| 421 |
+
self._env = None
|
| 422 |
+
|
| 423 |
+
def __del__(self):
|
| 424 |
+
"""Cleanup on deletion."""
|
| 425 |
+
try:
|
| 426 |
+
self.close()
|
| 427 |
+
except Exception:
|
| 428 |
+
pass
|
src/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""EnvTorch: Standardized agentic execution environments."""
|
src/openenv.egg-info/PKG-INFO
ADDED
|
@@ -0,0 +1,337 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.4
|
| 2 |
+
Name: openenv
|
| 3 |
+
Version: 0.2.0
|
| 4 |
+
Summary: A unified framework for reinforcement learning environments
|
| 5 |
+
Requires-Python: >=3.10
|
| 6 |
+
Description-Content-Type: text/markdown
|
| 7 |
+
License-File: LICENSE
|
| 8 |
+
Requires-Dist: fastapi>=0.104.0
|
| 9 |
+
Requires-Dist: pydantic>=2.0.0
|
| 10 |
+
Requires-Dist: uvicorn>=0.24.0
|
| 11 |
+
Requires-Dist: requests>=2.25.0
|
| 12 |
+
Requires-Dist: typer>=0.9.0
|
| 13 |
+
Requires-Dist: rich>=13.0.0
|
| 14 |
+
Requires-Dist: pyyaml>=6.0
|
| 15 |
+
Requires-Dist: huggingface_hub>=0.20.0
|
| 16 |
+
Requires-Dist: openai>=2.7.2
|
| 17 |
+
Requires-Dist: tomli>=2.3.0
|
| 18 |
+
Requires-Dist: tomli-w>=1.2.0
|
| 19 |
+
Requires-Dist: websockets>=15.0.1
|
| 20 |
+
Provides-Extra: core
|
| 21 |
+
Requires-Dist: fastapi>=0.104.0; extra == "core"
|
| 22 |
+
Requires-Dist: pydantic>=2.0.0; extra == "core"
|
| 23 |
+
Requires-Dist: uvicorn>=0.24.0; extra == "core"
|
| 24 |
+
Requires-Dist: requests>=2.25.0; extra == "core"
|
| 25 |
+
Requires-Dist: websockets>=15.0.1; extra == "core"
|
| 26 |
+
Provides-Extra: cli
|
| 27 |
+
Requires-Dist: typer>=0.9.0; extra == "cli"
|
| 28 |
+
Requires-Dist: rich>=13.0.0; extra == "cli"
|
| 29 |
+
Requires-Dist: pyyaml>=6.0; extra == "cli"
|
| 30 |
+
Requires-Dist: huggingface_hub>=0.20.0; extra == "cli"
|
| 31 |
+
Requires-Dist: openai>=2.7.2; extra == "cli"
|
| 32 |
+
Requires-Dist: tomli>=2.3.0; extra == "cli"
|
| 33 |
+
Requires-Dist: tomli-w>=1.2.0; extra == "cli"
|
| 34 |
+
Provides-Extra: all
|
| 35 |
+
Requires-Dist: openenv[core]; extra == "all"
|
| 36 |
+
Requires-Dist: openenv[cli]; extra == "all"
|
| 37 |
+
Dynamic: license-file
|
| 38 |
+
|
| 39 |
+
# <img width="35" height="35" alt="image" src="https://github.com/user-attachments/assets/2700a971-e5d6-4036-b03f-2f89c9791609" /> OpenEnv: Agentic Execution Environments
|
| 40 |
+
|
| 41 |
+
An e2e framework for creating, deploying and using isolated execution environments for agentic RL training, built using Gymnasium style simple APIs.
|
| 42 |
+
|
| 43 |
+
[](https://pypi.org/project/openenv/)
|
| 44 |
+
[](https://discord.gg/YsTYBh6PD9)
|
| 45 |
+
[](https://colab.research.google.com/github/meta-pytorch/OpenEnv/blob/main/examples/OpenEnv_Tutorial.ipynb)
|
| 46 |
+
[](https://meta-pytorch.org/OpenEnv/)
|
| 47 |
+
|
| 48 |
+
---
|
| 49 |
+
|
| 50 |
+
**🚀 Featured Example:** Train LLMs to play BlackJack using [torchforge](https://github.com/meta-pytorch/torchforge) (PyTorch's agentic RL framework): [`examples/grpo_blackjack/`](examples/grpo_blackjack/)
|
| 51 |
+
|
| 52 |
+
## OpenEnv on partner platforms:
|
| 53 |
+
|
| 54 |
+
- [Lightning AI Studio](https://lightning.ai/environments?section=featured)
|
| 55 |
+
- [TRL example](https://huggingface.co/docs/trl/main/en/openenv)
|
| 56 |
+
- [Unsloth Google Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/OpenEnv_gpt_oss_(20B)_Reinforcement_Learning_2048_Game.ipynb)
|
| 57 |
+
- [ART example](https://art.openpipe.ai/integrations/openenv-integration)
|
| 58 |
+
- [Oumi example](https://github.com/oumi-ai/oumi/blob/main/notebooks/Oumi%20-%20OpenEnv%20GRPO%20with%20trl.ipynb)
|
| 59 |
+
|
| 60 |
+
## Overview
|
| 61 |
+
|
| 62 |
+
OpenEnv provides a standard for interacting with agentic execution environments via simple Gymnasium style APIs - `step()`, `reset()`, `state()`. Users of agentic execution environments can interact with the environment during RL training loops using these simple APIs.
|
| 63 |
+
|
| 64 |
+
In addition to making it easier for researchers and RL framework writers, we also provide tools for environment creators making it easier for them to create richer environments and make them available over familiar protocols like HTTP and packaged using canonical technologies like docker. Environment creators can use the OpenEnv framework to create environments that are isolated, secure, and easy to deploy and use.
|
| 65 |
+
|
| 66 |
+
The OpenEnv CLI (`openenv`) provides commands to initialize new environments and deploy them to Hugging Face Spaces.
|
| 67 |
+
|
| 68 |
+
> ⚠️ **Early Development Warning** OpenEnv is currently in an experimental
|
| 69 |
+
> stage. You should expect bugs, incomplete features, and APIs that may change
|
| 70 |
+
> in future versions. The project welcomes bugfixes, but to make sure things are
|
| 71 |
+
> well coordinated you should discuss any significant change before starting the
|
| 72 |
+
> work. It's recommended that you signal your intention to contribute in the
|
| 73 |
+
> issue tracker, either by filing a new issue or by claiming an existing one.
|
| 74 |
+
|
| 75 |
+
### RFCs
|
| 76 |
+
|
| 77 |
+
Below is a list of active and historical RFCs for OpenEnv. RFCs are proposals for major changes or features. Please review and contribute!
|
| 78 |
+
|
| 79 |
+
- [RFC 001: Baseline API and Interface Specifications](https://github.com/meta-pytorch/OpenEnv/pull/26)
|
| 80 |
+
|
| 81 |
+
## Architecture
|
| 82 |
+
|
| 83 |
+
### Component Overview
|
| 84 |
+
|
| 85 |
+
```
|
| 86 |
+
┌─────────────────────────────────────────────────────────┐
|
| 87 |
+
│ Client Application │
|
| 88 |
+
│ ┌────────────────┐ ┌──────────────────┐ │
|
| 89 |
+
│ │ EchoEnv │ │ CodingEnv │ │
|
| 90 |
+
│ │ (HTTPEnvClient)│ �� (HTTPEnvClient) │ │
|
| 91 |
+
│ └────────┬───────┘ └────────┬─────────┘ │
|
| 92 |
+
└───────────┼───────────────────────────────┼─────────────┘
|
| 93 |
+
│ HTTP │ HTTP
|
| 94 |
+
│ (reset, step, state) │
|
| 95 |
+
┌───────────▼───────────────────────────────▼─────────────┐
|
| 96 |
+
│ Docker Containers (Isolated) │
|
| 97 |
+
│ ┌──────────────────────┐ ┌──────────────────────┐ │
|
| 98 |
+
│ │ FastAPI Server │ │ FastAPI Server │ │
|
| 99 |
+
│ │ EchoEnvironment │ │ PythonCodeActEnv │ │
|
| 100 |
+
│ │ (Environment base) │ │ (Environment base) │ │
|
| 101 |
+
│ └──────────────────────┘ └──────────────────────┘ │
|
| 102 |
+
└─────────────────────────────────────────────────────────┘
|
| 103 |
+
```
|
| 104 |
+
|
| 105 |
+
### Core Components
|
| 106 |
+
|
| 107 |
+
#### 1. Web Interface
|
| 108 |
+
|
| 109 |
+
OpenEnv includes a built-in web interface for interactive environment exploration and debugging. The web interface provides:
|
| 110 |
+
|
| 111 |
+
- **Two-Pane Layout**: HumanAgent interaction on the left, state observation on the right
|
| 112 |
+
- **Real-time Updates**: WebSocket-based live updates without page refresh
|
| 113 |
+
- **Dynamic Forms**: Automatically generated action forms based on environment Action types
|
| 114 |
+
- **Action History**: Complete log of all actions taken and their results
|
| 115 |
+
|
| 116 |
+
The web interface is **conditionally enabled** based on environment variables:
|
| 117 |
+
|
| 118 |
+
- **Local Development**: Disabled by default for lightweight development
|
| 119 |
+
- **Manual Override**: Enable with `ENABLE_WEB_INTERFACE=true`
|
| 120 |
+
|
| 121 |
+
To use the web interface:
|
| 122 |
+
|
| 123 |
+
```python
|
| 124 |
+
from openenv.core.env_server import create_web_interface_app
|
| 125 |
+
from your_env.models import YourAction, YourObservation
|
| 126 |
+
from your_env.server.your_environment import YourEnvironment
|
| 127 |
+
|
| 128 |
+
env = YourEnvironment()
|
| 129 |
+
app = create_web_interface_app(env, YourAction, YourObservation)
|
| 130 |
+
```
|
| 131 |
+
|
| 132 |
+
When enabled, open `http://localhost:8000/web` in your browser to interact with the environment.
|
| 133 |
+
|
| 134 |
+
#### 2. Environment (Server-Side)
|
| 135 |
+
Base class for implementing environment logic:
|
| 136 |
+
- **`reset()`**: Initialize a new episode, returns initial `Observation`
|
| 137 |
+
- **`step(action)`**: Execute an `Action`, returns resulting `Observation`
|
| 138 |
+
- **`state()`**: Access episode metadata (`State` with episode_id, step_count, etc.)
|
| 139 |
+
|
| 140 |
+
#### 3. HTTPEnvClient (Client-Side)
|
| 141 |
+
Base class for HTTP communication:
|
| 142 |
+
- Handles HTTP requests to environment server
|
| 143 |
+
- Contains a utility to spin up a docker container locally for the corresponding environment
|
| 144 |
+
- Type-safe action/observation parsing
|
| 145 |
+
|
| 146 |
+
#### 4. Container Providers
|
| 147 |
+
Manage container deployment:
|
| 148 |
+
- `LocalDockerProvider`: Run containers on local Docker daemon
|
| 149 |
+
- `KubernetesProvider`: Deploy to K8s clusters (future)
|
| 150 |
+
|
| 151 |
+
#### 5. Models
|
| 152 |
+
Type-safe data structures:
|
| 153 |
+
- `Action`: Base class for environment actions
|
| 154 |
+
- `Observation`: Base class for environment observations
|
| 155 |
+
- `State`: Episode state tracking
|
| 156 |
+
- `StepResult`: Combines observation, reward, done flag
|
| 157 |
+
|
| 158 |
+
## Project Structure
|
| 159 |
+
|
| 160 |
+
### For Environment Creators
|
| 161 |
+
|
| 162 |
+
Use the CLI to quickly scaffold a new environment:
|
| 163 |
+
|
| 164 |
+
```bash
|
| 165 |
+
openenv init my_env
|
| 166 |
+
```
|
| 167 |
+
|
| 168 |
+
This creates the following structure:
|
| 169 |
+
|
| 170 |
+
```
|
| 171 |
+
my_env/
|
| 172 |
+
├── .dockerignore # Docker build exclusions
|
| 173 |
+
├── __init__.py # Export YourAction, YourObservation, YourEnv
|
| 174 |
+
├── models.py # Define Action, Observation, State dataclasses
|
| 175 |
+
├── client.py # Implement YourEnv(HTTPEnvClient)
|
| 176 |
+
├── README.md # Document your environment
|
| 177 |
+
├── openenv.yaml # Environment manifest
|
| 178 |
+
├── pyproject.toml # Dependencies and package configuration
|
| 179 |
+
├── outputs/ # Runtime outputs (logs, evals) - gitignored
|
| 180 |
+
│ ├── logs/
|
| 181 |
+
│ └── evals/
|
| 182 |
+
└── server/
|
| 183 |
+
├── your_environment.py # Implement YourEnvironment(Environment)
|
| 184 |
+
├── app.py # Create FastAPI app
|
| 185 |
+
├── requirements.txt # Dependencies for Docker (can be generated)
|
| 186 |
+
└── Dockerfile # Define container image
|
| 187 |
+
```
|
| 188 |
+
|
| 189 |
+
#### Dependency Management
|
| 190 |
+
|
| 191 |
+
OpenEnv uses `pyproject.toml` as the primary dependency specification:
|
| 192 |
+
|
| 193 |
+
- **Environment-level `pyproject.toml`**: Each environment defines its own dependencies
|
| 194 |
+
- **Root-level `pyproject.toml`**: Contains shared core dependencies (fastapi, pydantic, uvicorn)
|
| 195 |
+
- **Server `requirements.txt`**: Can be auto-generated from `pyproject.toml` for Docker builds
|
| 196 |
+
|
| 197 |
+
**Development Workflow:**
|
| 198 |
+
|
| 199 |
+
```bash
|
| 200 |
+
# Install environment in editable mode
|
| 201 |
+
cd my_env
|
| 202 |
+
pip install -e .
|
| 203 |
+
|
| 204 |
+
# Or using uv (faster)
|
| 205 |
+
uv pip install -e .
|
| 206 |
+
|
| 207 |
+
# Run server locally without Docker
|
| 208 |
+
uv run server --host 0.0.0.0 --port 8000
|
| 209 |
+
```
|
| 210 |
+
|
| 211 |
+
**Benefits:**
|
| 212 |
+
- ✅ **Client-side extensions**: Modify client classes locally without repo changes
|
| 213 |
+
- ✅ **Better dependency management**: Clear separation between environments
|
| 214 |
+
- ✅ **Flexible workflows**: Use pip, uv, or Docker for different scenarios
|
| 215 |
+
- ✅ **CI/CD ready**: Automated dependency generation and validation
|
| 216 |
+
|
| 217 |
+
See [`envs/README.md`](envs/README.md) for a complete guide on building environments.
|
| 218 |
+
|
| 219 |
+
### For Environment Users
|
| 220 |
+
|
| 221 |
+
To use an environment:
|
| 222 |
+
1. Import from `envs.your_env`: `from envs.echo_env import EchoAction, EchoEnv`
|
| 223 |
+
2. Create client: `client = EchoEnv.from_docker_image("echo-env:latest")`
|
| 224 |
+
3. Interact: `client.reset()`, `client.step(action)`, `client.state()`
|
| 225 |
+
4. Cleanup: `client.close()`
|
| 226 |
+
|
| 227 |
+
See example scripts in `examples/` directory.
|
| 228 |
+
|
| 229 |
+
## CLI Commands
|
| 230 |
+
|
| 231 |
+
The OpenEnv CLI provides commands to manage environments:
|
| 232 |
+
|
| 233 |
+
- **`openenv init <env_name>`** - Initialize a new environment from template
|
| 234 |
+
- **`openenv push [--repo-id <repo>] [--private]`** - Deploy environment to Hugging Face Spaces
|
| 235 |
+
|
| 236 |
+
### Quick Start
|
| 237 |
+
|
| 238 |
+
```bash
|
| 239 |
+
# Create a new environment
|
| 240 |
+
openenv init my_game_env
|
| 241 |
+
|
| 242 |
+
# Deploy to Hugging Face (will prompt for login if needed)
|
| 243 |
+
cd my_game_env
|
| 244 |
+
openenv push
|
| 245 |
+
```
|
| 246 |
+
|
| 247 |
+
For detailed options: `openenv init --help` and `openenv push --help`.
|
| 248 |
+
|
| 249 |
+
## Design Principles
|
| 250 |
+
|
| 251 |
+
1. **Separation of Concerns**: Clear client-server boundaries
|
| 252 |
+
2. **Type Safety**: Strongly-typed actions, observations, and state
|
| 253 |
+
3. **Container Isolation**: Each environment runs in its own container
|
| 254 |
+
4. **Simple APIs**: Minimal, intuitive interfaces
|
| 255 |
+
|
| 256 |
+
## Quick Start
|
| 257 |
+
|
| 258 |
+
### Using the Echo Environment(Example)
|
| 259 |
+
|
| 260 |
+
```python
|
| 261 |
+
from envs.echo_env import EchoAction, EchoEnv
|
| 262 |
+
|
| 263 |
+
# Automatically start container and connect
|
| 264 |
+
client = EchoEnv.from_docker_image("echo-env:latest")
|
| 265 |
+
|
| 266 |
+
# Reset the environment
|
| 267 |
+
result = client.reset()
|
| 268 |
+
print(result.observation.echoed_message) # "Echo environment ready!"
|
| 269 |
+
|
| 270 |
+
# Send messages
|
| 271 |
+
result = client.step(EchoAction(message="Hello, World!"))
|
| 272 |
+
print(result.observation.echoed_message) # "Hello, World!"
|
| 273 |
+
print(result.reward) # 1.3 (based on message length)
|
| 274 |
+
|
| 275 |
+
# Cleanup
|
| 276 |
+
client.close() # Stops and removes container
|
| 277 |
+
```
|
| 278 |
+
|
| 279 |
+
## Requirements
|
| 280 |
+
|
| 281 |
+
- Python 3.11+
|
| 282 |
+
- Docker Desktop or Docker Engine
|
| 283 |
+
- FastAPI >= 0.104.0
|
| 284 |
+
- Uvicorn >= 0.24.0
|
| 285 |
+
- Requests >= 2.25.0
|
| 286 |
+
- smolagents (for coding environment)
|
| 287 |
+
|
| 288 |
+
## Supported RL Tools
|
| 289 |
+
The goal of this project is to support a broad set of open and closed tools to help standardize the agentic RL community. If you have a project that supports OpenEnv environments, please put up a PR to add your tool name along with a link to your documentation.
|
| 290 |
+
|
| 291 |
+
### torchforge
|
| 292 |
+
See GRPO BlackJack training example: [`examples/grpo_blackjack/`](examples/grpo_blackjack/)
|
| 293 |
+
|
| 294 |
+
### TRL
|
| 295 |
+
See the [TRL example](https://huggingface.co/docs/trl/main/en/openenv) on how to integrate OpenEnv environments with GRPO training.
|
| 296 |
+
|
| 297 |
+
### Unsloth
|
| 298 |
+
See the 2048 game example based on gpt-oss: [Colab notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/OpenEnv_gpt_oss_(20B)_Reinforcement_Learning_2048_Game.ipynb)
|
| 299 |
+
|
| 300 |
+
### SkyRL
|
| 301 |
+
See the [SkyRL example](https://skyrl.readthedocs.io/en/latest/examples/openenv.html) on how to train on OpenEnv environments with SkyRL.
|
| 302 |
+
|
| 303 |
+
### ART
|
| 304 |
+
See the [ART example](https://art.openpipe.ai/integrations/openenv-integration) on how OpenEnv environments can be used to train models with ART.
|
| 305 |
+
|
| 306 |
+
### Oumi
|
| 307 |
+
See the [Oumi example](https://github.com/oumi-ai/oumi/blob/main/notebooks/Oumi%20-%20OpenEnv%20GRPO%20with%20trl.ipynb) on how OpenEnv environments can be used to train models with Oumi.
|
| 308 |
+
|
| 309 |
+
## Example Environments
|
| 310 |
+
|
| 311 |
+
### Echo Environment
|
| 312 |
+
A simple environment that echoes back messages with metadata. Perfect for:
|
| 313 |
+
- Testing the HTTP server infrastructure
|
| 314 |
+
- Learning the framework basics
|
| 315 |
+
- Verifying container deployment
|
| 316 |
+
|
| 317 |
+
See: [`envs/echo_env/README.md`](envs/echo_env/README.md)
|
| 318 |
+
|
| 319 |
+
### Coding Environment
|
| 320 |
+
Executes arbitrary Python code in a sandboxed environment. Features:
|
| 321 |
+
- Safe code execution using smolagents
|
| 322 |
+
- Capture stdout, stderr, and exit codes
|
| 323 |
+
- Persistent execution context within episodes
|
| 324 |
+
- Error handling with detailed messages
|
| 325 |
+
|
| 326 |
+
See: [`envs/coding_env/README.md`](envs/coding_env/README.md)
|
| 327 |
+
|
| 328 |
+
## Community Support & Acknowledgments
|
| 329 |
+
This is an open and community-centric project. If you would like to add your name here, please put up a pull request and tag @jspisak for review. Ty!!
|
| 330 |
+
|
| 331 |
+
Supporters include: Meta-PyTorch, Hugging Face, [Patronus AI](https://patronus.ai), [Surge AI](https://surgehq.ai), [LastMile AI](https://www.lastmileai.dev), Unsloth AI, Reflection AI, vLLM, SkyRL (UC-Berkeley), LightningAI, Axolotl AI, Stanford Scaling Intelligence Lab, Mithril, [OpenMined](https://openmined.org/), [Fleet AI](https://fleetai.com), [Halluminate](https://halluminate.ai/), [Turing](https://www.turing.com/) ..
|
| 332 |
+
|
| 333 |
+
And we'd also like to acknowledge the team at Farama Foundation as the OpenEnv API was heavily inspired by the work you all have done on Gymnasium. Cheers!
|
| 334 |
+
|
| 335 |
+
## License
|
| 336 |
+
|
| 337 |
+
BSD 3-Clause License (see [LICENSE](./LICENSE) file)
|
src/openenv.egg-info/SOURCES.txt
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
LICENSE
|
| 2 |
+
README.md
|
| 3 |
+
pyproject.toml
|
| 4 |
+
envs/atari_env/__init__.py
|
| 5 |
+
envs/atari_env/client.py
|
| 6 |
+
envs/atari_env/models.py
|
| 7 |
+
envs/atari_env/server/__init__.py
|
| 8 |
+
envs/atari_env/server/app.py
|
| 9 |
+
envs/atari_env/server/atari_environment.py
|
| 10 |
+
envs/browsergym_env/__init__.py
|
| 11 |
+
envs/browsergym_env/client.py
|
| 12 |
+
envs/browsergym_env/models.py
|
| 13 |
+
envs/browsergym_env/server/__init__.py
|
| 14 |
+
envs/browsergym_env/server/app.py
|
| 15 |
+
envs/browsergym_env/server/browsergym_environment.py
|
| 16 |
+
envs/chat_env/__init__.py
|
| 17 |
+
envs/chat_env/client.py
|
| 18 |
+
envs/chat_env/models.py
|
| 19 |
+
envs/chat_env/server/__init__.py
|
| 20 |
+
envs/chat_env/server/app.py
|
| 21 |
+
envs/chat_env/server/chat_environment.py
|
| 22 |
+
envs/chat_env/server/test_chat_env.py
|
| 23 |
+
envs/coding_env/__init__.py
|
| 24 |
+
envs/coding_env/client.py
|
| 25 |
+
envs/coding_env/models.py
|
| 26 |
+
envs/coding_env/server/__init__.py
|
| 27 |
+
envs/coding_env/server/app.py
|
| 28 |
+
envs/coding_env/server/python_codeact_env.py
|
| 29 |
+
envs/coding_env/server/python_executor.py
|
| 30 |
+
envs/coding_env/server/transforms.py
|
| 31 |
+
envs/connect4_env/__init__.py
|
| 32 |
+
envs/connect4_env/client.py
|
| 33 |
+
envs/connect4_env/models.py
|
| 34 |
+
envs/connect4_env/server/__init__.py
|
| 35 |
+
envs/connect4_env/server/app.py
|
| 36 |
+
envs/connect4_env/server/connect4_environment.py
|
| 37 |
+
envs/dipg_safety_env/__init__.py
|
| 38 |
+
envs/dipg_safety_env/client.py
|
| 39 |
+
envs/dipg_safety_env/models.py
|
| 40 |
+
envs/dipg_safety_env/server/__init__.py
|
| 41 |
+
envs/dipg_safety_env/server/app.py
|
| 42 |
+
envs/dipg_safety_env/server/dipg_environment.py
|
| 43 |
+
envs/echo_env/__init__.py
|
| 44 |
+
envs/echo_env/client.py
|
| 45 |
+
envs/echo_env/models.py
|
| 46 |
+
envs/echo_env/build/lib/server/__init__.py
|
| 47 |
+
envs/echo_env/build/lib/server/app.py
|
| 48 |
+
envs/echo_env/build/lib/server/echo_environment.py
|
| 49 |
+
envs/echo_env/server/__init__.py
|
| 50 |
+
envs/echo_env/server/app.py
|
| 51 |
+
envs/echo_env/server/echo_environment.py
|
| 52 |
+
envs/finrl_env/__init__.py
|
| 53 |
+
envs/finrl_env/client.py
|
| 54 |
+
envs/finrl_env/models.py
|
| 55 |
+
envs/finrl_env/server/__init__.py
|
| 56 |
+
envs/finrl_env/server/app.py
|
| 57 |
+
envs/finrl_env/server/finrl_environment.py
|
| 58 |
+
envs/git_env/__init__.py
|
| 59 |
+
envs/git_env/client.py
|
| 60 |
+
envs/git_env/models.py
|
| 61 |
+
envs/git_env/server/__init__.py
|
| 62 |
+
envs/git_env/server/app.py
|
| 63 |
+
envs/git_env/server/git_task_environment.py
|
| 64 |
+
envs/openspiel_env/__init__.py
|
| 65 |
+
envs/openspiel_env/client.py
|
| 66 |
+
envs/openspiel_env/models.py
|
| 67 |
+
envs/openspiel_env/server/__init__.py
|
| 68 |
+
envs/openspiel_env/server/app.py
|
| 69 |
+
envs/openspiel_env/server/openspiel_environment.py
|
| 70 |
+
envs/openspiel_env/server/opponent_policies.py
|
| 71 |
+
envs/play/build/lib/server/__init__.py
|
| 72 |
+
envs/play/build/lib/server/app.py
|
| 73 |
+
envs/play/build/lib/server/play_environment.py
|
| 74 |
+
envs/sumo_rl_env/__init__.py
|
| 75 |
+
envs/sumo_rl_env/client.py
|
| 76 |
+
envs/sumo_rl_env/models.py
|
| 77 |
+
envs/sumo_rl_env/server/__init__.py
|
| 78 |
+
envs/sumo_rl_env/server/app.py
|
| 79 |
+
envs/sumo_rl_env/server/sumo_environment.py
|
| 80 |
+
envs/textarena_env/__init__.py
|
| 81 |
+
envs/textarena_env/client.py
|
| 82 |
+
envs/textarena_env/models.py
|
| 83 |
+
envs/textarena_env/rewards.py
|
| 84 |
+
envs/textarena_env/build/lib/server/__init__.py
|
| 85 |
+
envs/textarena_env/build/lib/server/app.py
|
| 86 |
+
envs/textarena_env/build/lib/server/environment.py
|
| 87 |
+
envs/textarena_env/server/__init__.py
|
| 88 |
+
envs/textarena_env/server/app.py
|
| 89 |
+
envs/textarena_env/server/environment.py
|
| 90 |
+
src/openenv/__init__.py
|
| 91 |
+
src/openenv.egg-info/PKG-INFO
|
| 92 |
+
src/openenv.egg-info/SOURCES.txt
|
| 93 |
+
src/openenv.egg-info/dependency_links.txt
|
| 94 |
+
src/openenv.egg-info/entry_points.txt
|
| 95 |
+
src/openenv.egg-info/requires.txt
|
| 96 |
+
src/openenv.egg-info/top_level.txt
|
| 97 |
+
src/openenv/cli/__init__.py
|
| 98 |
+
src/openenv/cli/__main__.py
|
| 99 |
+
src/openenv/cli/_cli_utils.py
|
| 100 |
+
src/openenv/cli/_validation.py
|
| 101 |
+
src/openenv/cli/commands/__init__.py
|
| 102 |
+
src/openenv/cli/commands/build.py
|
| 103 |
+
src/openenv/cli/commands/init.py
|
| 104 |
+
src/openenv/cli/commands/push.py
|
| 105 |
+
src/openenv/cli/commands/serve.py
|
| 106 |
+
src/openenv/cli/commands/validate.py
|
| 107 |
+
src/openenv/cli/templates/__init__.py
|
| 108 |
+
src/openenv/cli/templates/__pycache__/__init__.cpython-311.pyc
|
| 109 |
+
src/openenv/cli/templates/__pycache__/__init__.cpython-313.pyc
|
| 110 |
+
src/openenv/cli/templates/openenv_env/README.md
|
| 111 |
+
src/openenv/cli/templates/openenv_env/__init__.py
|
| 112 |
+
src/openenv/cli/templates/openenv_env/client.py
|
| 113 |
+
src/openenv/cli/templates/openenv_env/models.py
|
| 114 |
+
src/openenv/cli/templates/openenv_env/openenv.yaml
|
| 115 |
+
src/openenv/cli/templates/openenv_env/pyproject.toml
|
| 116 |
+
src/openenv/cli/templates/openenv_env/server/Dockerfile
|
| 117 |
+
src/openenv/cli/templates/openenv_env/server/__ENV_NAME___environment.py
|
| 118 |
+
src/openenv/cli/templates/openenv_env/server/__init__.py
|
| 119 |
+
src/openenv/cli/templates/openenv_env/server/app.py
|
| 120 |
+
src/openenv/cli/templates/openenv_env/server/requirements.txt
|
| 121 |
+
src/openenv/core/__init__.py
|
| 122 |
+
src/openenv/core/client_types.py
|
| 123 |
+
src/openenv/core/env_client.py
|
| 124 |
+
src/openenv/core/utils.py
|
| 125 |
+
src/openenv/core/containers/__init__.py
|
| 126 |
+
src/openenv/core/containers/test_local_docker_provider.py
|
| 127 |
+
src/openenv/core/containers/runtime/__init__.py
|
| 128 |
+
src/openenv/core/containers/runtime/providers.py
|
| 129 |
+
src/openenv/core/containers/runtime/uv_provider.py
|
| 130 |
+
src/openenv/core/env_server/__init__.py
|
| 131 |
+
src/openenv/core/env_server/base_transforms.py
|
| 132 |
+
src/openenv/core/env_server/exceptions.py
|
| 133 |
+
src/openenv/core/env_server/http_server.py
|
| 134 |
+
src/openenv/core/env_server/interfaces.py
|
| 135 |
+
src/openenv/core/env_server/route_config.py
|
| 136 |
+
src/openenv/core/env_server/serialization.py
|
| 137 |
+
src/openenv/core/env_server/types.py
|
| 138 |
+
src/openenv/core/env_server/web_interface.py
|
| 139 |
+
src/openenv/core/tools/__init__.py
|
| 140 |
+
src/openenv/core/tools/git_server_client.py
|
| 141 |
+
src/openenv/core/tools/local_python_executor.py
|
| 142 |
+
src/openenv_core/__init__.py
|
src/openenv.egg-info/dependency_links.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
src/openenv.egg-info/entry_points.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[console_scripts]
|
| 2 |
+
openenv = openenv.cli.__main__:main
|
src/openenv.egg-info/requires.txt
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi>=0.104.0
|
| 2 |
+
pydantic>=2.0.0
|
| 3 |
+
uvicorn>=0.24.0
|
| 4 |
+
requests>=2.25.0
|
| 5 |
+
typer>=0.9.0
|
| 6 |
+
rich>=13.0.0
|
| 7 |
+
pyyaml>=6.0
|
| 8 |
+
huggingface_hub>=0.20.0
|
| 9 |
+
openai>=2.7.2
|
| 10 |
+
tomli>=2.3.0
|
| 11 |
+
tomli-w>=1.2.0
|
| 12 |
+
websockets>=15.0.1
|
| 13 |
+
|
| 14 |
+
[all]
|
| 15 |
+
openenv[core]
|
| 16 |
+
openenv[cli]
|
| 17 |
+
|
| 18 |
+
[cli]
|
| 19 |
+
typer>=0.9.0
|
| 20 |
+
rich>=13.0.0
|
| 21 |
+
pyyaml>=6.0
|
| 22 |
+
huggingface_hub>=0.20.0
|
| 23 |
+
openai>=2.7.2
|
| 24 |
+
tomli>=2.3.0
|
| 25 |
+
tomli-w>=1.2.0
|
| 26 |
+
|
| 27 |
+
[core]
|
| 28 |
+
fastapi>=0.104.0
|
| 29 |
+
pydantic>=2.0.0
|
| 30 |
+
uvicorn>=0.24.0
|
| 31 |
+
requests>=2.25.0
|
| 32 |
+
websockets>=15.0.1
|
src/openenv.egg-info/top_level.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
openenv
|
| 2 |
+
openenv_core
|
src/openenv/__init__.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Unified OpenEnv package bundling the CLI and core runtime.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from importlib import metadata
|
| 6 |
+
|
| 7 |
+
from .auto import AutoAction, AutoEnv
|
| 8 |
+
from .core import GenericEnvClient, GenericAction, SyncEnvClient
|
| 9 |
+
|
| 10 |
+
__all__ = [
|
| 11 |
+
"core",
|
| 12 |
+
"cli",
|
| 13 |
+
"AutoEnv",
|
| 14 |
+
"AutoAction",
|
| 15 |
+
"GenericEnvClient",
|
| 16 |
+
"GenericAction",
|
| 17 |
+
"SyncEnvClient",
|
| 18 |
+
]
|
| 19 |
+
|
| 20 |
+
try:
|
| 21 |
+
__version__ = metadata.version("openenv") # type: ignore[arg-type]
|
| 22 |
+
except metadata.PackageNotFoundError: # pragma: no cover - local dev
|
| 23 |
+
__version__ = "0.0.0"
|
src/openenv/auto/__init__.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
OpenEnv Auto Module
|
| 9 |
+
===================
|
| 10 |
+
|
| 11 |
+
Provides HuggingFace-style auto-discovery API for OpenEnv environments.
|
| 12 |
+
|
| 13 |
+
This module enables automatic environment and action class loading without
|
| 14 |
+
manual imports:
|
| 15 |
+
|
| 16 |
+
>>> from openenv import AutoEnv, AutoAction
|
| 17 |
+
>>>
|
| 18 |
+
>>> # Load environment from installed package or HuggingFace Hub
|
| 19 |
+
>>> env = AutoEnv.from_name("coding-env")
|
| 20 |
+
>>>
|
| 21 |
+
>>> # Get action class
|
| 22 |
+
>>> CodeAction = AutoAction.from_name("coding")
|
| 23 |
+
>>> action = CodeAction(code="print('Hello!')")
|
| 24 |
+
|
| 25 |
+
Classes:
|
| 26 |
+
AutoEnv: Automatic environment client selection and instantiation
|
| 27 |
+
AutoAction: Automatic action class selection
|
| 28 |
+
|
| 29 |
+
The auto-discovery system works by:
|
| 30 |
+
1. Discovering installed openenv-* packages via importlib.metadata
|
| 31 |
+
2. Loading environment manifests (openenv.yaml) from package resources
|
| 32 |
+
3. Supporting HuggingFace Hub repositories for remote environments
|
| 33 |
+
4. Caching discovery results for performance
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
from .auto_action import AutoAction
|
| 37 |
+
from .auto_env import AutoEnv
|
| 38 |
+
|
| 39 |
+
__all__ = ["AutoEnv", "AutoAction"]
|
src/openenv/auto/_discovery.py
ADDED
|
@@ -0,0 +1,584 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Environment Auto-Discovery System
|
| 9 |
+
==================================
|
| 10 |
+
|
| 11 |
+
This module provides automatic discovery of OpenEnv environments by:
|
| 12 |
+
1. Discovering installed openenv-* packages using importlib.metadata
|
| 13 |
+
2. Loading manifests (openenv.yaml) from package resources
|
| 14 |
+
3. Caching results for performance
|
| 15 |
+
4. Supporting HuggingFace Hub downloads
|
| 16 |
+
|
| 17 |
+
This enables AutoEnv to work without coupling to src/envs/ directory.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
import importlib
|
| 21 |
+
import importlib.metadata
|
| 22 |
+
import importlib.resources
|
| 23 |
+
import json
|
| 24 |
+
import logging
|
| 25 |
+
import re
|
| 26 |
+
import tempfile
|
| 27 |
+
from dataclasses import dataclass, asdict
|
| 28 |
+
from pathlib import Path
|
| 29 |
+
from typing import Dict, Optional, Type, Any
|
| 30 |
+
|
| 31 |
+
import yaml
|
| 32 |
+
|
| 33 |
+
logger = logging.getLogger(__name__)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@dataclass
|
| 37 |
+
class EnvironmentInfo:
|
| 38 |
+
"""
|
| 39 |
+
Rich information about a discovered environment.
|
| 40 |
+
|
| 41 |
+
Attributes:
|
| 42 |
+
env_key: Environment key (e.g., "echo", "coding")
|
| 43 |
+
name: Full environment name (e.g., "echo_env")
|
| 44 |
+
package_name: Package name (e.g., "openenv-echo_env")
|
| 45 |
+
version: Version string
|
| 46 |
+
description: Human-readable description
|
| 47 |
+
client_module_path: Full module path to client (e.g., "echo_env.client")
|
| 48 |
+
client_class_name: Client class name (e.g., "EchoEnv")
|
| 49 |
+
action_class_name: Action class name (e.g., "EchoAction")
|
| 50 |
+
observation_class_name: Observation class name (e.g., "EchoObservation")
|
| 51 |
+
default_image: Default Docker image name (e.g., "echo-env:latest")
|
| 52 |
+
spec_version: OpenEnv spec version (from openenv.yaml)
|
| 53 |
+
manifest: Original manifest data
|
| 54 |
+
"""
|
| 55 |
+
|
| 56 |
+
env_key: str
|
| 57 |
+
name: str
|
| 58 |
+
package_name: str
|
| 59 |
+
version: str
|
| 60 |
+
description: str
|
| 61 |
+
client_module_path: str
|
| 62 |
+
client_class_name: str
|
| 63 |
+
action_class_name: str
|
| 64 |
+
observation_class_name: str
|
| 65 |
+
default_image: str
|
| 66 |
+
spec_version: Optional[int] = None
|
| 67 |
+
manifest: Optional[Dict[str, Any]] = None
|
| 68 |
+
|
| 69 |
+
def get_client_class(self) -> Type:
|
| 70 |
+
"""
|
| 71 |
+
Dynamically import and return the client class.
|
| 72 |
+
|
| 73 |
+
Returns:
|
| 74 |
+
Client class (e.g., EchoEnv)
|
| 75 |
+
|
| 76 |
+
Raises:
|
| 77 |
+
ImportError: If module or class cannot be imported
|
| 78 |
+
"""
|
| 79 |
+
try:
|
| 80 |
+
module = importlib.import_module(self.client_module_path)
|
| 81 |
+
return getattr(module, self.client_class_name)
|
| 82 |
+
except ImportError as e:
|
| 83 |
+
raise ImportError(
|
| 84 |
+
f"Failed to import {self.client_class_name} from {self.client_module_path}: {e}\n"
|
| 85 |
+
f"Make sure the package '{self.package_name}' is installed: "
|
| 86 |
+
f"pip install {self.package_name}"
|
| 87 |
+
) from e
|
| 88 |
+
except AttributeError as e:
|
| 89 |
+
raise ImportError(
|
| 90 |
+
f"Class {self.client_class_name} not found in {self.client_module_path}: {e}"
|
| 91 |
+
) from e
|
| 92 |
+
|
| 93 |
+
def get_action_class(self) -> Type:
|
| 94 |
+
"""
|
| 95 |
+
Dynamically import and return the action class.
|
| 96 |
+
|
| 97 |
+
Returns:
|
| 98 |
+
Action class (e.g., EchoAction)
|
| 99 |
+
|
| 100 |
+
Raises:
|
| 101 |
+
ImportError: If module or class cannot be imported
|
| 102 |
+
"""
|
| 103 |
+
try:
|
| 104 |
+
module = importlib.import_module(self.client_module_path)
|
| 105 |
+
return getattr(module, self.action_class_name)
|
| 106 |
+
except ImportError as e:
|
| 107 |
+
raise ImportError(
|
| 108 |
+
f"Failed to import {self.action_class_name} from {self.client_module_path}: {e}\n"
|
| 109 |
+
f"Make sure the package '{self.package_name}' is installed: "
|
| 110 |
+
f"pip install {self.package_name}"
|
| 111 |
+
) from e
|
| 112 |
+
except AttributeError as e:
|
| 113 |
+
raise ImportError(
|
| 114 |
+
f"Class {self.action_class_name} not found in {self.client_module_path}: {e}"
|
| 115 |
+
) from e
|
| 116 |
+
|
| 117 |
+
def get_observation_class(self) -> Type:
|
| 118 |
+
"""
|
| 119 |
+
Dynamically import and return the observation class.
|
| 120 |
+
|
| 121 |
+
Returns:
|
| 122 |
+
Observation class (e.g., EchoObservation)
|
| 123 |
+
|
| 124 |
+
Raises:
|
| 125 |
+
ImportError: If module or class cannot be imported
|
| 126 |
+
"""
|
| 127 |
+
try:
|
| 128 |
+
module = importlib.import_module(self.client_module_path)
|
| 129 |
+
return getattr(module, self.observation_class_name)
|
| 130 |
+
except ImportError as e:
|
| 131 |
+
raise ImportError(
|
| 132 |
+
f"Failed to import {self.observation_class_name} from {self.client_module_path}: {e}\n"
|
| 133 |
+
f"Make sure the package '{self.package_name}' is installed: "
|
| 134 |
+
f"pip install {self.package_name}"
|
| 135 |
+
) from e
|
| 136 |
+
except AttributeError as e:
|
| 137 |
+
raise ImportError(
|
| 138 |
+
f"Class {self.observation_class_name} not found in {self.client_module_path}: {e}"
|
| 139 |
+
) from e
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def _normalize_env_name(name: str) -> str:
|
| 143 |
+
"""
|
| 144 |
+
Normalize environment name to standard format.
|
| 145 |
+
|
| 146 |
+
Args:
|
| 147 |
+
name: Input name (e.g., "echo", "echo-env", "echo_env")
|
| 148 |
+
|
| 149 |
+
Returns:
|
| 150 |
+
Normalized name (e.g., "echo_env")
|
| 151 |
+
|
| 152 |
+
Examples:
|
| 153 |
+
>>> _normalize_env_name("echo")
|
| 154 |
+
'echo_env'
|
| 155 |
+
>>> _normalize_env_name("echo-env")
|
| 156 |
+
'echo_env'
|
| 157 |
+
>>> _normalize_env_name("echo_env")
|
| 158 |
+
'echo_env'
|
| 159 |
+
"""
|
| 160 |
+
# Remove common suffixes
|
| 161 |
+
name = re.sub(r"[-_]env$", "", name)
|
| 162 |
+
# Convert hyphens to underscores
|
| 163 |
+
name = name.replace("-", "_")
|
| 164 |
+
# Add _env suffix if not present
|
| 165 |
+
if not name.endswith("_env"):
|
| 166 |
+
name = f"{name}_env"
|
| 167 |
+
return name
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def _is_hub_url(name: str) -> bool:
|
| 171 |
+
"""
|
| 172 |
+
Check if name is a HuggingFace Hub URL or repo ID.
|
| 173 |
+
|
| 174 |
+
Args:
|
| 175 |
+
name: Input name
|
| 176 |
+
|
| 177 |
+
Returns:
|
| 178 |
+
True if it looks like a Hub URL
|
| 179 |
+
|
| 180 |
+
Examples:
|
| 181 |
+
>>> _is_hub_url("meta-pytorch/echo_env")
|
| 182 |
+
True
|
| 183 |
+
>>> _is_hub_url("https://huggingface.co/meta-pytorch/echo_env")
|
| 184 |
+
True
|
| 185 |
+
>>> _is_hub_url("echo")
|
| 186 |
+
False
|
| 187 |
+
"""
|
| 188 |
+
# Contains org/repo pattern or huggingface.co domain
|
| 189 |
+
return "/" in name or "huggingface.co" in name
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def _infer_class_name(env_name: str, class_type: str) -> str:
|
| 193 |
+
"""
|
| 194 |
+
Infer class name from environment name using simple conventions.
|
| 195 |
+
|
| 196 |
+
Args:
|
| 197 |
+
env_name: Environment name (e.g., "echo_env")
|
| 198 |
+
class_type: Type of class ("client", "action", "observation")
|
| 199 |
+
|
| 200 |
+
Returns:
|
| 201 |
+
Inferred class name
|
| 202 |
+
|
| 203 |
+
Examples:
|
| 204 |
+
>>> _infer_class_name("echo_env", "client")
|
| 205 |
+
'EchoEnv'
|
| 206 |
+
>>> _infer_class_name("echo_env", "action")
|
| 207 |
+
'EchoAction'
|
| 208 |
+
"""
|
| 209 |
+
# Remove _env suffix for base name
|
| 210 |
+
base_name = env_name.replace("_env", "")
|
| 211 |
+
|
| 212 |
+
# Convert to PascalCase
|
| 213 |
+
pascal_name = "".join(word.capitalize() for word in base_name.split("_"))
|
| 214 |
+
|
| 215 |
+
# Add suffix based on type
|
| 216 |
+
if class_type == "client":
|
| 217 |
+
return f"{pascal_name}Env"
|
| 218 |
+
elif class_type == "action":
|
| 219 |
+
return f"{pascal_name}Action"
|
| 220 |
+
elif class_type == "observation":
|
| 221 |
+
return f"{pascal_name}Observation"
|
| 222 |
+
else:
|
| 223 |
+
raise ValueError(f"Unknown class type: {class_type}")
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def _load_manifest_from_package(
|
| 227 |
+
package_name: str, module_name: str
|
| 228 |
+
) -> Optional[Dict[str, Any]]:
|
| 229 |
+
"""
|
| 230 |
+
Load openenv.yaml manifest from an installed package.
|
| 231 |
+
|
| 232 |
+
Args:
|
| 233 |
+
package_name: Package name (e.g., "openenv-echo_env")
|
| 234 |
+
module_name: Module name (e.g., "echo_env")
|
| 235 |
+
|
| 236 |
+
Returns:
|
| 237 |
+
Parsed manifest dictionary, or None if not found
|
| 238 |
+
|
| 239 |
+
"""
|
| 240 |
+
try:
|
| 241 |
+
# Try to read openenv.yaml from package
|
| 242 |
+
if hasattr(importlib.resources, "files"):
|
| 243 |
+
# Python 3.9+
|
| 244 |
+
package_files = importlib.resources.files(module_name)
|
| 245 |
+
if (package_files / "openenv.yaml").is_file():
|
| 246 |
+
manifest_text = (package_files / "openenv.yaml").read_text()
|
| 247 |
+
return yaml.safe_load(manifest_text)
|
| 248 |
+
else:
|
| 249 |
+
# Python 3.7-3.8 fallback
|
| 250 |
+
with importlib.resources.open_text(module_name, "openenv.yaml") as f:
|
| 251 |
+
return yaml.safe_load(f)
|
| 252 |
+
except (FileNotFoundError, ModuleNotFoundError, AttributeError):
|
| 253 |
+
logger.debug(f"No openenv.yaml found in {module_name}")
|
| 254 |
+
return None
|
| 255 |
+
except Exception as e:
|
| 256 |
+
logger.warning(f"Failed to load openenv.yaml from {module_name}: {e}")
|
| 257 |
+
return None
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def _create_env_info_from_package(
|
| 261 |
+
package_name: str, module_name: str, version: str
|
| 262 |
+
) -> Optional[EnvironmentInfo]:
|
| 263 |
+
"""
|
| 264 |
+
Create EnvironmentInfo from an installed package.
|
| 265 |
+
|
| 266 |
+
Args:
|
| 267 |
+
package_name: Package name (e.g., "openenv-echo_env")
|
| 268 |
+
module_name: Module name (e.g., "echo_env")
|
| 269 |
+
version: Package version
|
| 270 |
+
|
| 271 |
+
Returns:
|
| 272 |
+
EnvironmentInfo instance, or None if invalid
|
| 273 |
+
"""
|
| 274 |
+
# Load manifest
|
| 275 |
+
manifest = _load_manifest_from_package(package_name, module_name)
|
| 276 |
+
|
| 277 |
+
# Get environment name
|
| 278 |
+
if manifest and "name" in manifest:
|
| 279 |
+
env_name = manifest["name"]
|
| 280 |
+
else:
|
| 281 |
+
# Infer from module name
|
| 282 |
+
env_name = module_name
|
| 283 |
+
|
| 284 |
+
# Normalize to ensure _env suffix
|
| 285 |
+
if not env_name.endswith("_env"):
|
| 286 |
+
env_name = f"{env_name}_env"
|
| 287 |
+
|
| 288 |
+
# Determine env_key (e.g., "echo_env" → "echo")
|
| 289 |
+
env_key = env_name.replace("_env", "") if env_name.endswith("_env") else env_name
|
| 290 |
+
|
| 291 |
+
# Get description
|
| 292 |
+
description = (
|
| 293 |
+
manifest.get("description", f"{env_name} environment")
|
| 294 |
+
if manifest
|
| 295 |
+
else f"{env_name} environment"
|
| 296 |
+
)
|
| 297 |
+
|
| 298 |
+
# Get spec version
|
| 299 |
+
spec_version = manifest.get("spec_version") if manifest else None
|
| 300 |
+
|
| 301 |
+
# Determine class names
|
| 302 |
+
# Check if manifest has custom class names (custom format)
|
| 303 |
+
if manifest and "action" in manifest and "observation" in manifest:
|
| 304 |
+
# Custom format (like coding_env)
|
| 305 |
+
client_class_name = _infer_class_name(env_name, "client")
|
| 306 |
+
action_class_name = manifest.get(
|
| 307 |
+
"action", _infer_class_name(env_name, "action")
|
| 308 |
+
)
|
| 309 |
+
observation_class_name = manifest.get(
|
| 310 |
+
"observation", _infer_class_name(env_name, "observation")
|
| 311 |
+
)
|
| 312 |
+
else:
|
| 313 |
+
# Use conventions
|
| 314 |
+
client_class_name = _infer_class_name(env_name, "client")
|
| 315 |
+
action_class_name = _infer_class_name(env_name, "action")
|
| 316 |
+
observation_class_name = _infer_class_name(env_name, "observation")
|
| 317 |
+
|
| 318 |
+
# Module path is just module_name.client
|
| 319 |
+
client_module_path = f"{module_name}.client"
|
| 320 |
+
|
| 321 |
+
# Determine default Docker image name
|
| 322 |
+
image_name = env_name.replace("_", "-")
|
| 323 |
+
default_image = f"{image_name}:latest"
|
| 324 |
+
|
| 325 |
+
return EnvironmentInfo(
|
| 326 |
+
env_key=env_key,
|
| 327 |
+
name=env_name,
|
| 328 |
+
package_name=package_name,
|
| 329 |
+
version=version,
|
| 330 |
+
description=description,
|
| 331 |
+
client_module_path=client_module_path,
|
| 332 |
+
client_class_name=client_class_name,
|
| 333 |
+
action_class_name=action_class_name,
|
| 334 |
+
observation_class_name=observation_class_name,
|
| 335 |
+
default_image=default_image,
|
| 336 |
+
spec_version=spec_version,
|
| 337 |
+
manifest=manifest,
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
class EnvironmentDiscovery:
|
| 342 |
+
"""
|
| 343 |
+
Auto-discovery system for OpenEnv environments using installed packages.
|
| 344 |
+
|
| 345 |
+
This class discovers installed openenv-* packages and loads their metadata.
|
| 346 |
+
"""
|
| 347 |
+
|
| 348 |
+
def __init__(self):
|
| 349 |
+
"""Initialize discovery system."""
|
| 350 |
+
self._cache: Optional[Dict[str, EnvironmentInfo]] = None
|
| 351 |
+
self._cache_file = Path(tempfile.gettempdir()) / "openenv_discovery_cache.json"
|
| 352 |
+
|
| 353 |
+
def _discover_installed_packages(self) -> Dict[str, EnvironmentInfo]:
|
| 354 |
+
"""
|
| 355 |
+
Discover all installed openenv-* packages.
|
| 356 |
+
|
| 357 |
+
Returns:
|
| 358 |
+
Dictionary mapping env_key to EnvironmentInfo
|
| 359 |
+
"""
|
| 360 |
+
environments = {}
|
| 361 |
+
|
| 362 |
+
# Invalidate import caches to ensure we pick up newly installed packages
|
| 363 |
+
importlib.invalidate_caches()
|
| 364 |
+
|
| 365 |
+
# Get all installed packages
|
| 366 |
+
try:
|
| 367 |
+
distributions = importlib.metadata.distributions()
|
| 368 |
+
except Exception as e:
|
| 369 |
+
logger.warning(f"Failed to get installed packages: {e}")
|
| 370 |
+
return environments
|
| 371 |
+
|
| 372 |
+
# Filter for openenv-* packages (exclude openenv-core)
|
| 373 |
+
for dist in distributions:
|
| 374 |
+
package_name = dist.metadata["Name"]
|
| 375 |
+
|
| 376 |
+
if not package_name.startswith("openenv-"):
|
| 377 |
+
continue
|
| 378 |
+
|
| 379 |
+
if package_name == "openenv-core":
|
| 380 |
+
continue
|
| 381 |
+
|
| 382 |
+
# Get module name (e.g., "openenv-echo_env" → "echo_env")
|
| 383 |
+
module_name = package_name.replace("openenv-", "").replace("-", "_")
|
| 384 |
+
|
| 385 |
+
# Get version
|
| 386 |
+
version = dist.version
|
| 387 |
+
|
| 388 |
+
try:
|
| 389 |
+
# Create environment info
|
| 390 |
+
env_info = _create_env_info_from_package(
|
| 391 |
+
package_name, module_name, version
|
| 392 |
+
)
|
| 393 |
+
|
| 394 |
+
if env_info:
|
| 395 |
+
environments[env_info.env_key] = env_info
|
| 396 |
+
logger.debug(
|
| 397 |
+
f"Discovered environment: {env_info.env_key} ({package_name})"
|
| 398 |
+
)
|
| 399 |
+
|
| 400 |
+
except Exception as e:
|
| 401 |
+
logger.warning(f"Failed to load environment from {package_name}: {e}")
|
| 402 |
+
continue
|
| 403 |
+
|
| 404 |
+
return environments
|
| 405 |
+
|
| 406 |
+
def _load_cache(self) -> Optional[Dict[str, EnvironmentInfo]]:
|
| 407 |
+
"""
|
| 408 |
+
Load cached discovery results.
|
| 409 |
+
|
| 410 |
+
Returns:
|
| 411 |
+
Dictionary of env_key -> EnvironmentInfo, or None if cache invalid
|
| 412 |
+
"""
|
| 413 |
+
if not self._cache_file.exists():
|
| 414 |
+
return None
|
| 415 |
+
|
| 416 |
+
try:
|
| 417 |
+
with open(self._cache_file, "r") as f:
|
| 418 |
+
cache_data = json.load(f)
|
| 419 |
+
|
| 420 |
+
# Reconstruct EnvironmentInfo objects
|
| 421 |
+
cache = {}
|
| 422 |
+
for env_key, env_data in cache_data.items():
|
| 423 |
+
cache[env_key] = EnvironmentInfo(**env_data)
|
| 424 |
+
|
| 425 |
+
return cache
|
| 426 |
+
except Exception as e:
|
| 427 |
+
logger.warning(f"Failed to load discovery cache: {e}")
|
| 428 |
+
return None
|
| 429 |
+
|
| 430 |
+
def _save_cache(self, environments: Dict[str, EnvironmentInfo]) -> None:
|
| 431 |
+
"""
|
| 432 |
+
Save discovery results to cache.
|
| 433 |
+
|
| 434 |
+
Args:
|
| 435 |
+
environments: Dictionary of env_key -> EnvironmentInfo
|
| 436 |
+
"""
|
| 437 |
+
try:
|
| 438 |
+
cache_data = {}
|
| 439 |
+
for env_key, env_info in environments.items():
|
| 440 |
+
cache_data[env_key] = asdict(env_info)
|
| 441 |
+
|
| 442 |
+
with open(self._cache_file, "w") as f:
|
| 443 |
+
json.dump(cache_data, f, indent=2)
|
| 444 |
+
|
| 445 |
+
except Exception as e:
|
| 446 |
+
logger.warning(f"Failed to save discovery cache: {e}")
|
| 447 |
+
|
| 448 |
+
def discover(self, use_cache: bool = True) -> Dict[str, EnvironmentInfo]:
|
| 449 |
+
"""
|
| 450 |
+
Discover all installed OpenEnv environments.
|
| 451 |
+
|
| 452 |
+
Args:
|
| 453 |
+
use_cache: If True, try to load from cache first
|
| 454 |
+
|
| 455 |
+
Returns:
|
| 456 |
+
Dictionary mapping env_key to EnvironmentInfo
|
| 457 |
+
|
| 458 |
+
Examples:
|
| 459 |
+
>>> discovery = EnvironmentDiscovery()
|
| 460 |
+
>>> envs = discovery.discover()
|
| 461 |
+
>>> print(envs.keys())
|
| 462 |
+
dict_keys(['echo', 'coding', ...])
|
| 463 |
+
"""
|
| 464 |
+
# Try to load from memory cache first
|
| 465 |
+
if use_cache and self._cache is not None:
|
| 466 |
+
return self._cache
|
| 467 |
+
|
| 468 |
+
# Try to load from file cache
|
| 469 |
+
if use_cache:
|
| 470 |
+
cached = self._load_cache()
|
| 471 |
+
if cached is not None:
|
| 472 |
+
self._cache = cached
|
| 473 |
+
return self._cache
|
| 474 |
+
|
| 475 |
+
# Discover from installed packages
|
| 476 |
+
environments = self._discover_installed_packages()
|
| 477 |
+
|
| 478 |
+
# Save to cache
|
| 479 |
+
self._save_cache(environments)
|
| 480 |
+
self._cache = environments
|
| 481 |
+
|
| 482 |
+
return environments
|
| 483 |
+
|
| 484 |
+
def get_environment(self, env_key: str) -> Optional[EnvironmentInfo]:
|
| 485 |
+
"""
|
| 486 |
+
Get information about a specific environment.
|
| 487 |
+
|
| 488 |
+
Args:
|
| 489 |
+
env_key: Environment key (e.g., "echo", "coding")
|
| 490 |
+
|
| 491 |
+
Returns:
|
| 492 |
+
EnvironmentInfo if found, None otherwise
|
| 493 |
+
|
| 494 |
+
Examples:
|
| 495 |
+
>>> discovery = EnvironmentDiscovery()
|
| 496 |
+
>>> env = discovery.get_environment("echo")
|
| 497 |
+
>>> print(env.client_class_name)
|
| 498 |
+
'EchoEnv'
|
| 499 |
+
"""
|
| 500 |
+
environments = self.discover()
|
| 501 |
+
return environments.get(env_key)
|
| 502 |
+
|
| 503 |
+
def get_environment_by_name(self, name: str) -> Optional[EnvironmentInfo]:
|
| 504 |
+
"""
|
| 505 |
+
Get environment info by flexible name matching.
|
| 506 |
+
|
| 507 |
+
Args:
|
| 508 |
+
name: Environment name (e.g., "echo", "echo-env", "echo_env")
|
| 509 |
+
|
| 510 |
+
Returns:
|
| 511 |
+
EnvironmentInfo if found, None otherwise
|
| 512 |
+
"""
|
| 513 |
+
# Normalize name to env_key
|
| 514 |
+
normalized = _normalize_env_name(name)
|
| 515 |
+
env_key = normalized.replace("_env", "")
|
| 516 |
+
|
| 517 |
+
return self.get_environment(env_key)
|
| 518 |
+
|
| 519 |
+
def list_environments(self) -> None:
|
| 520 |
+
"""
|
| 521 |
+
Print a formatted list of all discovered environments.
|
| 522 |
+
|
| 523 |
+
Examples:
|
| 524 |
+
>>> discovery = EnvironmentDiscovery()
|
| 525 |
+
>>> discovery.list_environments()
|
| 526 |
+
Available OpenEnv Environments:
|
| 527 |
+
----------------------------------------------------------------------
|
| 528 |
+
echo : Echo Environment (v0.1.0) - openenv-echo_env
|
| 529 |
+
coding : Coding Environment (v0.1.0) - openenv-coding_env
|
| 530 |
+
...
|
| 531 |
+
"""
|
| 532 |
+
environments = self.discover()
|
| 533 |
+
|
| 534 |
+
print("Available OpenEnv Environments:")
|
| 535 |
+
print("-" * 70)
|
| 536 |
+
|
| 537 |
+
if not environments:
|
| 538 |
+
print(" No OpenEnv environments found.")
|
| 539 |
+
print(" Install environments with: pip install openenv-<env-name>")
|
| 540 |
+
else:
|
| 541 |
+
for env_key in sorted(environments.keys()):
|
| 542 |
+
env = environments[env_key]
|
| 543 |
+
print(f" {env_key:<15}: {env.description} (v{env.version})")
|
| 544 |
+
print(f" Package: {env.package_name}")
|
| 545 |
+
|
| 546 |
+
print("-" * 70)
|
| 547 |
+
print(f"Total: {len(environments)} environments")
|
| 548 |
+
|
| 549 |
+
def clear_cache(self) -> None:
|
| 550 |
+
"""Clear the discovery cache."""
|
| 551 |
+
if self._cache_file.exists():
|
| 552 |
+
self._cache_file.unlink()
|
| 553 |
+
self._cache = None
|
| 554 |
+
|
| 555 |
+
|
| 556 |
+
# Global discovery instance
|
| 557 |
+
_global_discovery: Optional[EnvironmentDiscovery] = None
|
| 558 |
+
|
| 559 |
+
|
| 560 |
+
def get_discovery() -> EnvironmentDiscovery:
|
| 561 |
+
"""
|
| 562 |
+
Get or create the global discovery instance.
|
| 563 |
+
|
| 564 |
+
Returns:
|
| 565 |
+
Global EnvironmentDiscovery instance
|
| 566 |
+
|
| 567 |
+
Examples:
|
| 568 |
+
>>> discovery = get_discovery()
|
| 569 |
+
>>> envs = discovery.discover()
|
| 570 |
+
"""
|
| 571 |
+
global _global_discovery
|
| 572 |
+
|
| 573 |
+
if _global_discovery is None:
|
| 574 |
+
_global_discovery = EnvironmentDiscovery()
|
| 575 |
+
|
| 576 |
+
return _global_discovery
|
| 577 |
+
|
| 578 |
+
|
| 579 |
+
def reset_discovery() -> None:
|
| 580 |
+
"""Reset the global discovery instance (useful for testing)."""
|
| 581 |
+
global _global_discovery
|
| 582 |
+
if _global_discovery is not None:
|
| 583 |
+
_global_discovery.clear_cache()
|
| 584 |
+
_global_discovery = None
|
src/openenv/auto/auto_action.py
ADDED
|
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
AutoAction - Automatic Action Class Selection
|
| 9 |
+
==============================================
|
| 10 |
+
|
| 11 |
+
AutoAction provides a HuggingFace-style API for automatically retrieving the
|
| 12 |
+
correct Action class from installed packages or HuggingFace Hub.
|
| 13 |
+
|
| 14 |
+
This module simplifies working with environment actions by automatically
|
| 15 |
+
detecting and returning the appropriate Action class without requiring
|
| 16 |
+
manual imports.
|
| 17 |
+
|
| 18 |
+
Example:
|
| 19 |
+
>>> from openenv import AutoEnv, AutoAction
|
| 20 |
+
>>>
|
| 21 |
+
>>> # Get Action class from environment name
|
| 22 |
+
>>> CodeAction = AutoAction.from_env("coding")
|
| 23 |
+
>>> action = CodeAction(code="print('Hello!')")
|
| 24 |
+
>>>
|
| 25 |
+
>>> # From HuggingFace Hub
|
| 26 |
+
>>> CodeAction = AutoAction.from_env("meta-pytorch/coding-env")
|
| 27 |
+
>>>
|
| 28 |
+
>>> # Use with AutoEnv
|
| 29 |
+
>>> env = AutoEnv.from_env("coding-env")
|
| 30 |
+
>>> result = env.step(action)
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
from __future__ import annotations
|
| 34 |
+
|
| 35 |
+
import logging
|
| 36 |
+
from typing import Type, Dict, Any
|
| 37 |
+
|
| 38 |
+
from ._discovery import get_discovery, _is_hub_url
|
| 39 |
+
from .auto_env import AutoEnv
|
| 40 |
+
|
| 41 |
+
logger = logging.getLogger(__name__)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class AutoAction:
|
| 45 |
+
"""
|
| 46 |
+
AutoAction automatically retrieves the correct Action class based on
|
| 47 |
+
environment names or HuggingFace Hub repositories.
|
| 48 |
+
|
| 49 |
+
This class follows the HuggingFace AutoModel pattern, making it easy to
|
| 50 |
+
get the right Action class without needing to know which module to import.
|
| 51 |
+
|
| 52 |
+
The class provides factory methods that look up the Action class and
|
| 53 |
+
return the class (not an instance) for you to instantiate.
|
| 54 |
+
|
| 55 |
+
Example:
|
| 56 |
+
>>> # From installed package
|
| 57 |
+
>>> CodeAction = AutoAction.from_env("coding")
|
| 58 |
+
>>> action = CodeAction(code="print('test')")
|
| 59 |
+
>>>
|
| 60 |
+
>>> # From HuggingFace Hub
|
| 61 |
+
>>> CodeAction = AutoAction.from_env("meta-pytorch/coding-env")
|
| 62 |
+
>>> action = CodeAction(code="print('test')")
|
| 63 |
+
>>>
|
| 64 |
+
>>> # Use with AutoEnv for a complete workflow
|
| 65 |
+
>>> env = AutoEnv.from_env("coding-env")
|
| 66 |
+
>>> ActionClass = AutoAction.from_env("coding-env")
|
| 67 |
+
>>> action = ActionClass(code="print('Hello, AutoAction!')")
|
| 68 |
+
>>> result = env.step(action)
|
| 69 |
+
|
| 70 |
+
Note:
|
| 71 |
+
AutoAction is not meant to be instantiated directly. Use the class
|
| 72 |
+
method from_env() instead.
|
| 73 |
+
"""
|
| 74 |
+
|
| 75 |
+
def __init__(self):
|
| 76 |
+
"""AutoAction should not be instantiated directly. Use class methods instead."""
|
| 77 |
+
raise TypeError(
|
| 78 |
+
"AutoAction is a factory class and should not be instantiated directly. "
|
| 79 |
+
"Use AutoAction.from_hub() or AutoAction.from_env() instead."
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
@classmethod
|
| 83 |
+
def from_env(cls, name: str, skip_install: bool = False) -> Type:
|
| 84 |
+
"""
|
| 85 |
+
Get the Action class from environment name or HuggingFace Hub repository.
|
| 86 |
+
|
| 87 |
+
This method automatically:
|
| 88 |
+
1. Checks if the name is a HuggingFace Hub URL/repo ID
|
| 89 |
+
2. If Hub: downloads and installs the environment package
|
| 90 |
+
3. If local: looks up the installed openenv-* package
|
| 91 |
+
4. Imports and returns the Action class
|
| 92 |
+
|
| 93 |
+
Args:
|
| 94 |
+
name: Environment name or HuggingFace Hub repo ID
|
| 95 |
+
Examples:
|
| 96 |
+
- "coding" / "coding-env" / "coding_env"
|
| 97 |
+
- "meta-pytorch/coding-env" (Hub repo ID)
|
| 98 |
+
- "https://huggingface.co/meta-pytorch/coding-env" (Hub URL)
|
| 99 |
+
skip_install: If True, skip package installation and return
|
| 100 |
+
GenericAction class instead. Use this when working with
|
| 101 |
+
GenericEnvClient to avoid installing remote packages.
|
| 102 |
+
|
| 103 |
+
Returns:
|
| 104 |
+
Action class (not an instance!). Returns GenericAction when
|
| 105 |
+
skip_install=True.
|
| 106 |
+
|
| 107 |
+
Raises:
|
| 108 |
+
ValueError: If environment not found (only when skip_install=False)
|
| 109 |
+
ImportError: If environment package is not installed (only when skip_install=False)
|
| 110 |
+
|
| 111 |
+
Examples:
|
| 112 |
+
>>> # From installed package
|
| 113 |
+
>>> CodeAction = AutoAction.from_env("coding-env")
|
| 114 |
+
>>> action = CodeAction(code="print('Hello!')")
|
| 115 |
+
>>>
|
| 116 |
+
>>> # From HuggingFace Hub
|
| 117 |
+
>>> CodeAction = AutoAction.from_env("meta-pytorch/coding-env")
|
| 118 |
+
>>> action = CodeAction(code="print('Hello!')")
|
| 119 |
+
>>>
|
| 120 |
+
>>> # Skip installation, use GenericAction (for GenericEnvClient)
|
| 121 |
+
>>> ActionClass = AutoAction.from_env("user/repo", skip_install=True)
|
| 122 |
+
>>> action = ActionClass(code="print('Hello!')") # Returns GenericAction
|
| 123 |
+
>>>
|
| 124 |
+
>>> # Different name formats
|
| 125 |
+
>>> EchoAction = AutoAction.from_env("echo")
|
| 126 |
+
>>> EchoAction = AutoAction.from_env("echo-env")
|
| 127 |
+
>>> EchoAction = AutoAction.from_env("echo_env")
|
| 128 |
+
"""
|
| 129 |
+
# If skip_install is True, return GenericAction without any package lookup
|
| 130 |
+
if skip_install:
|
| 131 |
+
from openenv.core.generic_client import GenericAction
|
| 132 |
+
|
| 133 |
+
logger.info(
|
| 134 |
+
f"Returning GenericAction for '{name}' (skip_install=True). "
|
| 135 |
+
f"Use keyword arguments to create actions: GenericAction(code='...')"
|
| 136 |
+
)
|
| 137 |
+
return GenericAction
|
| 138 |
+
|
| 139 |
+
# Check if it's a HuggingFace Hub URL or repo ID
|
| 140 |
+
if _is_hub_url(name):
|
| 141 |
+
# Ensure package is installed (reuse AutoEnv logic, downloads only if needed)
|
| 142 |
+
env_name = AutoEnv._ensure_package_from_hub(name)
|
| 143 |
+
else:
|
| 144 |
+
env_name = name
|
| 145 |
+
|
| 146 |
+
# Get environment info from discovery
|
| 147 |
+
discovery = get_discovery()
|
| 148 |
+
env_info = discovery.get_environment_by_name(env_name)
|
| 149 |
+
|
| 150 |
+
if not env_info:
|
| 151 |
+
# Environment not found - provide helpful error message
|
| 152 |
+
available_envs = discovery.discover()
|
| 153 |
+
|
| 154 |
+
if not available_envs:
|
| 155 |
+
raise ValueError(
|
| 156 |
+
"No OpenEnv environments found.\n"
|
| 157 |
+
"Install an environment with: pip install openenv-<env-name>\n"
|
| 158 |
+
"Or specify a HuggingFace Hub repository: AutoAction.from_env('openenv/echo_env')"
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
# Try to suggest similar environment names
|
| 162 |
+
from difflib import get_close_matches
|
| 163 |
+
|
| 164 |
+
env_keys = list(available_envs.keys())
|
| 165 |
+
suggestions = get_close_matches(env_name, env_keys, n=3, cutoff=0.6)
|
| 166 |
+
|
| 167 |
+
error_msg = f"Unknown environment '{env_name}'.\n"
|
| 168 |
+
if suggestions:
|
| 169 |
+
error_msg += f"Did you mean: {', '.join(suggestions)}?\n"
|
| 170 |
+
error_msg += f"Available environments: {', '.join(sorted(env_keys))}"
|
| 171 |
+
|
| 172 |
+
raise ValueError(error_msg)
|
| 173 |
+
|
| 174 |
+
# Get the action class
|
| 175 |
+
try:
|
| 176 |
+
action_class = env_info.get_action_class()
|
| 177 |
+
return action_class
|
| 178 |
+
except ImportError as e:
|
| 179 |
+
raise ImportError(
|
| 180 |
+
f"Failed to import action class for '{env_name}'.\n"
|
| 181 |
+
f"Package '{env_info.package_name}' appears to be installed but the module cannot be imported.\n"
|
| 182 |
+
f"Try reinstalling: pip install --force-reinstall {env_info.package_name}\n"
|
| 183 |
+
f"Original error: {e}"
|
| 184 |
+
) from e
|
| 185 |
+
|
| 186 |
+
@classmethod
|
| 187 |
+
def from_hub(cls, env_name: str, skip_install: bool = False) -> Type:
|
| 188 |
+
"""
|
| 189 |
+
Get the Action class from environment name.
|
| 190 |
+
|
| 191 |
+
This is an alias for from_env() for backward compatibility and clarity.
|
| 192 |
+
|
| 193 |
+
Args:
|
| 194 |
+
env_name: Environment name (e.g., "coding", "echo")
|
| 195 |
+
skip_install: If True, skip package installation and return
|
| 196 |
+
GenericAction class instead.
|
| 197 |
+
|
| 198 |
+
Returns:
|
| 199 |
+
Action class (not an instance!)
|
| 200 |
+
|
| 201 |
+
Examples:
|
| 202 |
+
>>> CodeAction = AutoAction.from_hub("coding")
|
| 203 |
+
>>> action = CodeAction(code="print('Hello!')")
|
| 204 |
+
"""
|
| 205 |
+
return cls.from_env(env_name, skip_install=skip_install)
|
| 206 |
+
|
| 207 |
+
@classmethod
|
| 208 |
+
def get_action_info(cls, name: str) -> Dict[str, Any]:
|
| 209 |
+
"""
|
| 210 |
+
Get detailed information about an action class.
|
| 211 |
+
|
| 212 |
+
Args:
|
| 213 |
+
name: Environment name
|
| 214 |
+
|
| 215 |
+
Returns:
|
| 216 |
+
Dictionary with action class metadata
|
| 217 |
+
|
| 218 |
+
Raises:
|
| 219 |
+
ValueError: If environment not found
|
| 220 |
+
|
| 221 |
+
Examples:
|
| 222 |
+
>>> info = AutoAction.get_action_info("coding")
|
| 223 |
+
>>> print(info['action_class'])
|
| 224 |
+
'CodingAction'
|
| 225 |
+
>>> print(info['module'])
|
| 226 |
+
'coding_env.client'
|
| 227 |
+
"""
|
| 228 |
+
discovery = get_discovery()
|
| 229 |
+
env_info = discovery.get_environment_by_name(name)
|
| 230 |
+
|
| 231 |
+
if not env_info:
|
| 232 |
+
raise ValueError(f"Unknown environment: {name}")
|
| 233 |
+
|
| 234 |
+
return {
|
| 235 |
+
"env_key": env_info.env_key,
|
| 236 |
+
"env_name": env_info.name,
|
| 237 |
+
"package": env_info.package_name,
|
| 238 |
+
"action_class": env_info.action_class_name,
|
| 239 |
+
"observation_class": env_info.observation_class_name,
|
| 240 |
+
"module": env_info.client_module_path,
|
| 241 |
+
}
|
| 242 |
+
|
| 243 |
+
@classmethod
|
| 244 |
+
def list_actions(cls) -> None:
|
| 245 |
+
"""
|
| 246 |
+
Print a formatted list of all available action classes.
|
| 247 |
+
|
| 248 |
+
This discovers all installed openenv-* packages and displays
|
| 249 |
+
their action class information in a user-friendly format.
|
| 250 |
+
|
| 251 |
+
Examples:
|
| 252 |
+
>>> AutoAction.list_actions()
|
| 253 |
+
Available Action Classes:
|
| 254 |
+
----------------------------------------------------------------------
|
| 255 |
+
echo : EchoAction (from openenv-echo-env)
|
| 256 |
+
coding : CodingAction (from openenv-coding_env)
|
| 257 |
+
----------------------------------------------------------------------
|
| 258 |
+
Total: 2 action classes
|
| 259 |
+
"""
|
| 260 |
+
discovery = get_discovery()
|
| 261 |
+
environments = discovery.discover()
|
| 262 |
+
|
| 263 |
+
print("Available Action Classes:")
|
| 264 |
+
print("-" * 70)
|
| 265 |
+
|
| 266 |
+
if not environments:
|
| 267 |
+
print(" No OpenEnv environments found.")
|
| 268 |
+
print(" Install environments with: pip install openenv-<env-name>")
|
| 269 |
+
else:
|
| 270 |
+
for env_key in sorted(environments.keys()):
|
| 271 |
+
env = environments[env_key]
|
| 272 |
+
print(f" {env_key:<15}: {env.action_class_name}")
|
| 273 |
+
print(f" Package: {env.package_name}")
|
| 274 |
+
|
| 275 |
+
print("-" * 70)
|
| 276 |
+
print(f"Total: {len(environments)} action classes")
|
src/openenv/auto/auto_env.py
ADDED
|
@@ -0,0 +1,896 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
AutoEnv - Automatic Environment Selection
|
| 9 |
+
==========================================
|
| 10 |
+
|
| 11 |
+
AutoEnv provides a HuggingFace-style API for automatically selecting and
|
| 12 |
+
instantiating the correct environment client from installed packages or
|
| 13 |
+
HuggingFace Hub.
|
| 14 |
+
|
| 15 |
+
This module simplifies environment creation by automatically detecting the
|
| 16 |
+
environment type from the name and instantiating the appropriate client class.
|
| 17 |
+
|
| 18 |
+
Example:
|
| 19 |
+
>>> from openenv import AutoEnv, AutoAction
|
| 20 |
+
>>>
|
| 21 |
+
>>> # From installed package
|
| 22 |
+
>>> env = AutoEnv.from_env("coding-env")
|
| 23 |
+
>>>
|
| 24 |
+
>>> # From HuggingFace Hub
|
| 25 |
+
>>> env = AutoEnv.from_env("meta-pytorch/coding-env")
|
| 26 |
+
>>>
|
| 27 |
+
>>> # With configuration
|
| 28 |
+
>>> env = AutoEnv.from_env("coding", env_vars={"DEBUG": "1"})
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
from __future__ import annotations
|
| 32 |
+
|
| 33 |
+
import importlib
|
| 34 |
+
import logging
|
| 35 |
+
import os
|
| 36 |
+
import shutil
|
| 37 |
+
import subprocess
|
| 38 |
+
import sys
|
| 39 |
+
import requests
|
| 40 |
+
from typing import Any, Optional, TYPE_CHECKING, Dict
|
| 41 |
+
|
| 42 |
+
from ._discovery import get_discovery, _is_hub_url
|
| 43 |
+
from openenv.core.utils import run_async_safely
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
if TYPE_CHECKING:
|
| 47 |
+
from openenv.core.containers.runtime import ContainerProvider
|
| 48 |
+
from openenv.core.env_client import EnvClient
|
| 49 |
+
|
| 50 |
+
logger = logging.getLogger(__name__)
|
| 51 |
+
|
| 52 |
+
# Cache for repo ID → env_name mapping to avoid redundant downloads
|
| 53 |
+
_hub_env_name_cache: Dict[str, str] = {}
|
| 54 |
+
|
| 55 |
+
# Environment variable to skip user confirmation for remote installs
|
| 56 |
+
OPENENV_TRUST_REMOTE_CODE = "OPENENV_TRUST_REMOTE_CODE"
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _has_uv() -> bool:
|
| 60 |
+
"""Check if uv is available in the system."""
|
| 61 |
+
return shutil.which("uv") is not None
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def _get_pip_command() -> list[str]:
|
| 65 |
+
"""
|
| 66 |
+
Get the appropriate pip command (uv pip or pip).
|
| 67 |
+
|
| 68 |
+
Returns:
|
| 69 |
+
List of command parts for pip installation
|
| 70 |
+
"""
|
| 71 |
+
if _has_uv():
|
| 72 |
+
return ["uv", "pip"]
|
| 73 |
+
return [sys.executable, "-m", "pip"]
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def _confirm_remote_install(repo_id: str) -> bool:
|
| 77 |
+
"""
|
| 78 |
+
Ask user for confirmation before installing remote code.
|
| 79 |
+
|
| 80 |
+
This is a security measure since we're executing code from the internet.
|
| 81 |
+
|
| 82 |
+
Args:
|
| 83 |
+
repo_id: The HuggingFace repo ID being installed
|
| 84 |
+
|
| 85 |
+
Returns:
|
| 86 |
+
True if user confirms, False otherwise
|
| 87 |
+
"""
|
| 88 |
+
# Check environment variable for automated/CI environments
|
| 89 |
+
if os.environ.get(OPENENV_TRUST_REMOTE_CODE, "").lower() in ("1", "true", "yes"):
|
| 90 |
+
logger.info("Skipping confirmation (OPENENV_TRUST_REMOTE_CODE is set)")
|
| 91 |
+
return True
|
| 92 |
+
|
| 93 |
+
# Check if we're in an interactive terminal
|
| 94 |
+
if not sys.stdin.isatty():
|
| 95 |
+
logger.warning(
|
| 96 |
+
"Cannot prompt for confirmation in non-interactive mode. "
|
| 97 |
+
"Set OPENENV_TRUST_REMOTE_CODE=1 to allow remote installs."
|
| 98 |
+
)
|
| 99 |
+
return False
|
| 100 |
+
|
| 101 |
+
print(f"\n{'=' * 60}")
|
| 102 |
+
print("⚠️ SECURITY WARNING: Remote Code Installation")
|
| 103 |
+
print(f"{'=' * 60}")
|
| 104 |
+
print("You are about to install code from a remote repository:")
|
| 105 |
+
print(f" Repository: {repo_id}")
|
| 106 |
+
print(f" Source: https://huggingface.co/spaces/{repo_id}")
|
| 107 |
+
print("\nThis will execute code from the internet on your machine.")
|
| 108 |
+
print("Only proceed if you trust the source.")
|
| 109 |
+
print(f"{'=' * 60}\n")
|
| 110 |
+
|
| 111 |
+
try:
|
| 112 |
+
response = input("Do you want to proceed? [y/N]: ").strip().lower()
|
| 113 |
+
return response in ("y", "yes")
|
| 114 |
+
except (EOFError, KeyboardInterrupt):
|
| 115 |
+
print("\nInstallation cancelled.")
|
| 116 |
+
return False
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class AutoEnv:
|
| 120 |
+
"""
|
| 121 |
+
AutoEnv automatically selects and instantiates the correct environment client
|
| 122 |
+
based on environment names or HuggingFace Hub repositories.
|
| 123 |
+
|
| 124 |
+
This class follows the HuggingFace AutoModel pattern, making it easy to work
|
| 125 |
+
with different environments without needing to import specific client classes.
|
| 126 |
+
|
| 127 |
+
The class provides factory methods that:
|
| 128 |
+
1. Check if name is a HuggingFace Hub URL/repo ID
|
| 129 |
+
2. If Hub: download and install the environment package
|
| 130 |
+
3. If local: look up the installed openenv-* package
|
| 131 |
+
4. Import and instantiate the client class
|
| 132 |
+
|
| 133 |
+
Example:
|
| 134 |
+
>>> # From installed package
|
| 135 |
+
>>> env = AutoEnv.from_env("coding-env")
|
| 136 |
+
>>>
|
| 137 |
+
>>> # From HuggingFace Hub
|
| 138 |
+
>>> env = AutoEnv.from_env("meta-pytorch/coding-env")
|
| 139 |
+
>>>
|
| 140 |
+
>>> # List available environments
|
| 141 |
+
>>> AutoEnv.list_environments()
|
| 142 |
+
|
| 143 |
+
Note:
|
| 144 |
+
AutoEnv is not meant to be instantiated directly. Use the class method
|
| 145 |
+
from_env() instead.
|
| 146 |
+
"""
|
| 147 |
+
|
| 148 |
+
def __init__(self):
|
| 149 |
+
"""AutoEnv should not be instantiated directly. Use class methods instead."""
|
| 150 |
+
raise TypeError(
|
| 151 |
+
"AutoEnv is a factory class and should not be instantiated directly. "
|
| 152 |
+
"Use AutoEnv.from_hub() or AutoEnv.from_env() instead."
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
@classmethod
|
| 156 |
+
def _resolve_space_url(cls, repo_id: str) -> str:
|
| 157 |
+
"""
|
| 158 |
+
Resolve HuggingFace Space repo ID to Space URL.
|
| 159 |
+
|
| 160 |
+
Args:
|
| 161 |
+
repo_id: HuggingFace repo ID (e.g., "wukaixingxp/coding-env-test")
|
| 162 |
+
|
| 163 |
+
Returns:
|
| 164 |
+
Space URL (e.g., "https://wukaixingxp-coding-env-test.hf.space")
|
| 165 |
+
|
| 166 |
+
Examples:
|
| 167 |
+
>>> AutoEnv._resolve_space_url("wukaixingxp/coding-env-test")
|
| 168 |
+
'https://wukaixingxp-coding-env-test.hf.space'
|
| 169 |
+
"""
|
| 170 |
+
# Clean up repo_id if it's a full URL
|
| 171 |
+
if "huggingface.co" in repo_id:
|
| 172 |
+
# Extract org/repo from URL
|
| 173 |
+
# https://huggingface.co/wukaixingxp/coding-env-test -> wukaixingxp/coding-env-test
|
| 174 |
+
parts = repo_id.split("/")
|
| 175 |
+
if len(parts) >= 2:
|
| 176 |
+
repo_id = f"{parts[-2]}/{parts[-1]}"
|
| 177 |
+
|
| 178 |
+
# Convert user/space-name to user-space-name.hf.space
|
| 179 |
+
space_slug = repo_id.replace("/", "-")
|
| 180 |
+
return f"https://{space_slug}.hf.space"
|
| 181 |
+
|
| 182 |
+
@classmethod
|
| 183 |
+
def _is_local_url(cls, url: str) -> bool:
|
| 184 |
+
"""
|
| 185 |
+
Check if a URL points to a local server.
|
| 186 |
+
|
| 187 |
+
Args:
|
| 188 |
+
url: URL to check
|
| 189 |
+
|
| 190 |
+
Returns:
|
| 191 |
+
True if URL is localhost or 127.0.0.1, False otherwise
|
| 192 |
+
|
| 193 |
+
Examples:
|
| 194 |
+
>>> AutoEnv._is_local_url("http://localhost:8000")
|
| 195 |
+
True
|
| 196 |
+
>>> AutoEnv._is_local_url("http://127.0.0.1:8000")
|
| 197 |
+
True
|
| 198 |
+
>>> AutoEnv._is_local_url("https://example.com")
|
| 199 |
+
False
|
| 200 |
+
"""
|
| 201 |
+
url_lower = url.lower()
|
| 202 |
+
return "localhost" in url_lower or "127.0.0.1" in url_lower
|
| 203 |
+
|
| 204 |
+
@classmethod
|
| 205 |
+
def _check_server_availability(cls, base_url: str, timeout: float = 2.0) -> bool:
|
| 206 |
+
"""
|
| 207 |
+
Check if a server at the given URL is running and accessible.
|
| 208 |
+
|
| 209 |
+
Args:
|
| 210 |
+
base_url: Server base URL to check
|
| 211 |
+
timeout: Request timeout in seconds
|
| 212 |
+
|
| 213 |
+
Returns:
|
| 214 |
+
True if server is accessible, False otherwise
|
| 215 |
+
|
| 216 |
+
Examples:
|
| 217 |
+
>>> AutoEnv._check_server_availability("http://localhost:8000")
|
| 218 |
+
True # if server is running
|
| 219 |
+
"""
|
| 220 |
+
try:
|
| 221 |
+
# Bypass proxy for localhost to avoid proxy issues
|
| 222 |
+
proxies = None
|
| 223 |
+
if cls._is_local_url(base_url):
|
| 224 |
+
proxies = {"http": None, "https": None}
|
| 225 |
+
|
| 226 |
+
# Try to access the health endpoint
|
| 227 |
+
response = requests.get(
|
| 228 |
+
f"{base_url}/health", timeout=timeout, proxies=proxies
|
| 229 |
+
)
|
| 230 |
+
if response.status_code == 200:
|
| 231 |
+
return True
|
| 232 |
+
|
| 233 |
+
# If health endpoint doesn't exist, try root endpoint
|
| 234 |
+
response = requests.get(base_url, timeout=timeout, proxies=proxies)
|
| 235 |
+
return response.status_code == 200
|
| 236 |
+
except (requests.RequestException, Exception) as e:
|
| 237 |
+
logger.debug(f"Server {base_url} not accessible: {e}")
|
| 238 |
+
return False
|
| 239 |
+
|
| 240 |
+
@classmethod
|
| 241 |
+
def _check_space_availability(cls, space_url: str, timeout: float = 5.0) -> bool:
|
| 242 |
+
"""
|
| 243 |
+
Check if HuggingFace Space is running and accessible.
|
| 244 |
+
|
| 245 |
+
Args:
|
| 246 |
+
space_url: Space URL to check
|
| 247 |
+
timeout: Request timeout in seconds
|
| 248 |
+
|
| 249 |
+
Returns:
|
| 250 |
+
True if Space is accessible, False otherwise
|
| 251 |
+
|
| 252 |
+
Examples:
|
| 253 |
+
>>> AutoEnv._check_space_availability("https://wukaixingxp-coding-env-test.hf.space")
|
| 254 |
+
True
|
| 255 |
+
"""
|
| 256 |
+
try:
|
| 257 |
+
# Try to access the health endpoint
|
| 258 |
+
response = requests.get(f"{space_url}/health", timeout=timeout)
|
| 259 |
+
if response.status_code == 200:
|
| 260 |
+
return True
|
| 261 |
+
|
| 262 |
+
# If health endpoint doesn't exist, try root endpoint
|
| 263 |
+
response = requests.get(space_url, timeout=timeout)
|
| 264 |
+
return response.status_code == 200
|
| 265 |
+
except (requests.RequestException, Exception) as e:
|
| 266 |
+
logger.debug(f"Space {space_url} not accessible: {e}")
|
| 267 |
+
return False
|
| 268 |
+
|
| 269 |
+
@classmethod
|
| 270 |
+
def _get_hub_git_url(cls, repo_id: str) -> str:
|
| 271 |
+
"""
|
| 272 |
+
Get the git URL for a HuggingFace Space.
|
| 273 |
+
|
| 274 |
+
Args:
|
| 275 |
+
repo_id: HuggingFace repo ID (e.g., "wukaixingxp/coding-env-test")
|
| 276 |
+
|
| 277 |
+
Returns:
|
| 278 |
+
Git URL for pip installation (e.g., "git+https://huggingface.co/spaces/wukaixingxp/coding-env-test")
|
| 279 |
+
"""
|
| 280 |
+
# Clean up repo_id if it's a full URL
|
| 281 |
+
if "huggingface.co" in repo_id:
|
| 282 |
+
parts = repo_id.split("/")
|
| 283 |
+
if len(parts) >= 2:
|
| 284 |
+
repo_id = f"{parts[-2]}/{parts[-1]}"
|
| 285 |
+
|
| 286 |
+
return f"git+https://huggingface.co/spaces/{repo_id}"
|
| 287 |
+
|
| 288 |
+
@classmethod
|
| 289 |
+
def _install_from_hub(cls, repo_id: str, trust_remote_code: bool = False) -> str:
|
| 290 |
+
"""
|
| 291 |
+
Install environment package directly from HuggingFace Hub using git+.
|
| 292 |
+
|
| 293 |
+
This is the preferred method as it avoids downloading the entire repo
|
| 294 |
+
and uses pip/uv's native git support.
|
| 295 |
+
|
| 296 |
+
Args:
|
| 297 |
+
repo_id: HuggingFace repo ID (e.g., "wukaixingxp/coding-env-test")
|
| 298 |
+
trust_remote_code: If True, skip user confirmation
|
| 299 |
+
|
| 300 |
+
Returns:
|
| 301 |
+
Package name that was installed
|
| 302 |
+
|
| 303 |
+
Raises:
|
| 304 |
+
ValueError: If installation fails or user declines
|
| 305 |
+
"""
|
| 306 |
+
# Security check - confirm with user before installing remote code
|
| 307 |
+
if not trust_remote_code and not _confirm_remote_install(repo_id):
|
| 308 |
+
raise ValueError(
|
| 309 |
+
"Installation cancelled by user.\n"
|
| 310 |
+
"To allow remote installs without prompting, set OPENENV_TRUST_REMOTE_CODE=1"
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
git_url = cls._get_hub_git_url(repo_id)
|
| 314 |
+
pip_cmd = _get_pip_command()
|
| 315 |
+
pip_name = "uv pip" if pip_cmd[0] == "uv" else "pip"
|
| 316 |
+
|
| 317 |
+
logger.info(f"Installing from HuggingFace Space using {pip_name}: {repo_id}")
|
| 318 |
+
logger.info(f"Command: {' '.join(pip_cmd)} install {git_url}")
|
| 319 |
+
|
| 320 |
+
try:
|
| 321 |
+
result = subprocess.run(
|
| 322 |
+
[*pip_cmd, "install", git_url],
|
| 323 |
+
check=True,
|
| 324 |
+
capture_output=True,
|
| 325 |
+
text=True,
|
| 326 |
+
)
|
| 327 |
+
|
| 328 |
+
# Try to extract package name from pip output
|
| 329 |
+
# Look for "Successfully installed <package-name>-<version>"
|
| 330 |
+
for line in result.stdout.split("\n"):
|
| 331 |
+
if "Successfully installed" in line:
|
| 332 |
+
# Parse package name from the line
|
| 333 |
+
parts = line.replace("Successfully installed", "").strip().split()
|
| 334 |
+
for part in parts:
|
| 335 |
+
if part.startswith("openenv-"):
|
| 336 |
+
# Remove version suffix (e.g., "openenv-coding_env-0.1.0" -> "openenv-coding_env")
|
| 337 |
+
# Check if last segment looks like a version number
|
| 338 |
+
last_segment = part.rsplit("-", 1)[-1]
|
| 339 |
+
if last_segment.replace(".", "").isdigit():
|
| 340 |
+
package_name = "-".join(part.rsplit("-", 1)[:-1])
|
| 341 |
+
else:
|
| 342 |
+
package_name = part
|
| 343 |
+
logger.info(f"Successfully installed: {package_name}")
|
| 344 |
+
return package_name
|
| 345 |
+
|
| 346 |
+
# Fallback: try to determine package name from repo_id
|
| 347 |
+
# Convention: repo name like "coding-env-test" -> package "openenv-coding_env"
|
| 348 |
+
env_name = repo_id.split("/")[-1] # Get repo name from "user/repo"
|
| 349 |
+
env_name = env_name.replace("-", "_")
|
| 350 |
+
if not env_name.endswith("_env"):
|
| 351 |
+
env_name = f"{env_name}_env"
|
| 352 |
+
package_name = f"openenv-{env_name}"
|
| 353 |
+
|
| 354 |
+
logger.info(f"Installed (inferred package name): {package_name}")
|
| 355 |
+
return package_name
|
| 356 |
+
|
| 357 |
+
except subprocess.CalledProcessError as e:
|
| 358 |
+
error_msg = e.stderr or e.stdout or str(e)
|
| 359 |
+
raise ValueError(
|
| 360 |
+
f"Failed to install environment from HuggingFace Space: {repo_id}\n"
|
| 361 |
+
f"Command: {' '.join(pip_cmd)} install {git_url}\n"
|
| 362 |
+
f"Error: {error_msg}\n"
|
| 363 |
+
f"Make sure the repository exists and contains a valid Python package."
|
| 364 |
+
) from e
|
| 365 |
+
|
| 366 |
+
@classmethod
|
| 367 |
+
def _is_package_installed(cls, package_name: str) -> bool:
|
| 368 |
+
"""
|
| 369 |
+
Check if a package is already installed.
|
| 370 |
+
|
| 371 |
+
Args:
|
| 372 |
+
package_name: Package name (e.g., "openenv-coding_env")
|
| 373 |
+
|
| 374 |
+
Returns:
|
| 375 |
+
True if installed, False otherwise
|
| 376 |
+
"""
|
| 377 |
+
try:
|
| 378 |
+
import importlib.metadata
|
| 379 |
+
|
| 380 |
+
importlib.metadata.distribution(package_name)
|
| 381 |
+
return True
|
| 382 |
+
except importlib.metadata.PackageNotFoundError:
|
| 383 |
+
return False
|
| 384 |
+
|
| 385 |
+
@classmethod
|
| 386 |
+
def _ensure_package_from_hub(
|
| 387 |
+
cls, name: str, trust_remote_code: bool = False
|
| 388 |
+
) -> str:
|
| 389 |
+
"""
|
| 390 |
+
Ensure package from HuggingFace Hub is installed.
|
| 391 |
+
|
| 392 |
+
Uses git+ URLs for direct installation without downloading the entire repo.
|
| 393 |
+
Prompts user for confirmation before installing remote code.
|
| 394 |
+
|
| 395 |
+
Args:
|
| 396 |
+
name: HuggingFace repo ID (e.g., "wukaixingxp/coding-env-test")
|
| 397 |
+
trust_remote_code: If True, skip user confirmation
|
| 398 |
+
|
| 399 |
+
Returns:
|
| 400 |
+
Environment name (e.g., "coding_env")
|
| 401 |
+
"""
|
| 402 |
+
global _hub_env_name_cache
|
| 403 |
+
|
| 404 |
+
# Check if we already resolved this repo ID
|
| 405 |
+
if name in _hub_env_name_cache:
|
| 406 |
+
env_name = _hub_env_name_cache[name]
|
| 407 |
+
logger.debug(f"Using cached env name for {name}: {env_name}")
|
| 408 |
+
return env_name
|
| 409 |
+
|
| 410 |
+
# Try to infer expected package name from repo ID
|
| 411 |
+
# Convention: repo "user/coding-env" -> package "openenv-coding_env"
|
| 412 |
+
repo_name = name.split("/")[-1] if "/" in name else name
|
| 413 |
+
expected_env_name = repo_name.replace("-", "_")
|
| 414 |
+
if not expected_env_name.endswith("_env"):
|
| 415 |
+
expected_env_name = f"{expected_env_name}_env"
|
| 416 |
+
expected_package_name = f"openenv-{expected_env_name}"
|
| 417 |
+
|
| 418 |
+
# Check if already installed
|
| 419 |
+
if cls._is_package_installed(expected_package_name):
|
| 420 |
+
logger.info(f"Package already installed: {expected_package_name}")
|
| 421 |
+
# Clear and refresh discovery cache to make sure it's detected
|
| 422 |
+
get_discovery().clear_cache()
|
| 423 |
+
get_discovery().discover(use_cache=False)
|
| 424 |
+
# Cache the result
|
| 425 |
+
_hub_env_name_cache[name] = expected_env_name
|
| 426 |
+
return expected_env_name
|
| 427 |
+
|
| 428 |
+
# Not installed, install using git+ URL
|
| 429 |
+
logger.info(f"Package not found locally, installing from Hub: {name}")
|
| 430 |
+
|
| 431 |
+
# Track existing packages before installation
|
| 432 |
+
get_discovery().clear_cache()
|
| 433 |
+
existing_envs = set(get_discovery().discover(use_cache=False).keys())
|
| 434 |
+
|
| 435 |
+
# Install the package
|
| 436 |
+
cls._install_from_hub(name, trust_remote_code=trust_remote_code)
|
| 437 |
+
|
| 438 |
+
# Clear discovery cache to pick up the newly installed package
|
| 439 |
+
try:
|
| 440 |
+
importlib.invalidate_caches()
|
| 441 |
+
except Exception:
|
| 442 |
+
pass
|
| 443 |
+
get_discovery().clear_cache()
|
| 444 |
+
discovered_envs = get_discovery().discover(use_cache=False)
|
| 445 |
+
|
| 446 |
+
# Find the newly installed environment by comparing before/after
|
| 447 |
+
new_envs = set(discovered_envs.keys()) - existing_envs
|
| 448 |
+
|
| 449 |
+
if new_envs:
|
| 450 |
+
# Use the first newly discovered environment
|
| 451 |
+
env_name = next(iter(new_envs))
|
| 452 |
+
logger.info(f"Found newly installed environment: '{env_name}'")
|
| 453 |
+
else:
|
| 454 |
+
# Fallback: try to find by matching module patterns
|
| 455 |
+
# Look for any env that might match the repo name pattern
|
| 456 |
+
repo_name = name.split("/")[-1] if "/" in name else name
|
| 457 |
+
repo_base = (
|
| 458 |
+
repo_name.replace("-", "_").replace("_env", "").replace("_test", "")
|
| 459 |
+
)
|
| 460 |
+
|
| 461 |
+
env_name = None
|
| 462 |
+
for env_key, env_info in discovered_envs.items():
|
| 463 |
+
# Check if env_key is a prefix/substring match
|
| 464 |
+
if env_key in repo_base or repo_base.startswith(env_key):
|
| 465 |
+
env_name = env_key
|
| 466 |
+
logger.info(
|
| 467 |
+
f"Found matching environment '{env_name}' for repo '{name}'"
|
| 468 |
+
)
|
| 469 |
+
break
|
| 470 |
+
|
| 471 |
+
if env_name is None:
|
| 472 |
+
# Last resort: use inferred name from repo
|
| 473 |
+
env_name = repo_name.replace("-", "_")
|
| 474 |
+
if not env_name.endswith("_env"):
|
| 475 |
+
env_name = f"{env_name}_env"
|
| 476 |
+
# Strip to get env_key
|
| 477 |
+
env_name = env_name.replace("_env", "")
|
| 478 |
+
logger.warning(
|
| 479 |
+
f"Could not find newly installed environment for repo '{name}', "
|
| 480 |
+
f"using inferred name: {env_name}"
|
| 481 |
+
)
|
| 482 |
+
|
| 483 |
+
# Cache the result to avoid redundant installs
|
| 484 |
+
_hub_env_name_cache[name] = env_name
|
| 485 |
+
|
| 486 |
+
return env_name
|
| 487 |
+
|
| 488 |
+
@classmethod
|
| 489 |
+
def from_env(
|
| 490 |
+
cls,
|
| 491 |
+
name: str,
|
| 492 |
+
base_url: Optional[str] = None,
|
| 493 |
+
docker_image: Optional[str] = None,
|
| 494 |
+
container_provider: Optional[ContainerProvider] = None,
|
| 495 |
+
wait_timeout: float = 30.0,
|
| 496 |
+
env_vars: Optional[Dict[str, str]] = None,
|
| 497 |
+
trust_remote_code: bool = False,
|
| 498 |
+
skip_install: bool = False,
|
| 499 |
+
**kwargs: Any,
|
| 500 |
+
) -> "EnvClient":
|
| 501 |
+
"""
|
| 502 |
+
Create an environment client from a name or HuggingFace Hub repository.
|
| 503 |
+
|
| 504 |
+
This method automatically:
|
| 505 |
+
1. Checks if the name is a HuggingFace Hub URL/repo ID
|
| 506 |
+
2. If Hub: installs the environment package using git+ URL
|
| 507 |
+
3. If local: looks up the installed openenv-* package
|
| 508 |
+
4. Imports the client class and instantiates it
|
| 509 |
+
|
| 510 |
+
Args:
|
| 511 |
+
name: Environment name or HuggingFace Hub repo ID
|
| 512 |
+
Examples:
|
| 513 |
+
- "coding" / "coding-env" / "coding_env"
|
| 514 |
+
- "meta-pytorch/coding-env" (Hub repo ID)
|
| 515 |
+
- "https://huggingface.co/meta-pytorch/coding-env" (Hub URL)
|
| 516 |
+
base_url: Optional base URL for HTTP connection
|
| 517 |
+
docker_image: Optional Docker image name (overrides default)
|
| 518 |
+
container_provider: Optional container provider
|
| 519 |
+
wait_timeout: Timeout for container startup (seconds)
|
| 520 |
+
env_vars: Optional environment variables for the container
|
| 521 |
+
trust_remote_code: If True, skip user confirmation when installing
|
| 522 |
+
from HuggingFace Hub. Can also be set via OPENENV_TRUST_REMOTE_CODE
|
| 523 |
+
environment variable.
|
| 524 |
+
skip_install: If True, skip package installation and return a
|
| 525 |
+
GenericEnvClient for remote environments. Useful when you only
|
| 526 |
+
want to connect to a running server without installing any
|
| 527 |
+
remote code. When True:
|
| 528 |
+
- If base_url is provided: connects directly using GenericEnvClient
|
| 529 |
+
- If HF Space is running: connects to Space using GenericEnvClient
|
| 530 |
+
- If HF Space is not running: uses Docker from HF registry
|
| 531 |
+
**kwargs: Additional arguments passed to the client class
|
| 532 |
+
|
| 533 |
+
Returns:
|
| 534 |
+
Instance of the environment client class
|
| 535 |
+
|
| 536 |
+
Raises:
|
| 537 |
+
ValueError: If environment not found or cannot be loaded
|
| 538 |
+
ImportError: If environment package is not installed
|
| 539 |
+
|
| 540 |
+
Examples:
|
| 541 |
+
>>> # From installed package
|
| 542 |
+
>>> env = AutoEnv.from_env("coding-env")
|
| 543 |
+
>>>
|
| 544 |
+
>>> # From HuggingFace Hub
|
| 545 |
+
>>> env = AutoEnv.from_env("meta-pytorch/coding-env")
|
| 546 |
+
>>>
|
| 547 |
+
>>> # With custom Docker image
|
| 548 |
+
>>> env = AutoEnv.from_env("coding", docker_image="my-coding-env:v2")
|
| 549 |
+
>>>
|
| 550 |
+
>>> # With environment variables
|
| 551 |
+
>>> env = AutoEnv.from_env(
|
| 552 |
+
... "dipg",
|
| 553 |
+
... env_vars={"DIPG_DATASET_PATH": "/data/dipg"}
|
| 554 |
+
... )
|
| 555 |
+
>>>
|
| 556 |
+
>>> # Skip package installation, use GenericEnvClient
|
| 557 |
+
>>> env = AutoEnv.from_env(
|
| 558 |
+
... "user/my-env",
|
| 559 |
+
... skip_install=True
|
| 560 |
+
... )
|
| 561 |
+
"""
|
| 562 |
+
from openenv.core import GenericEnvClient
|
| 563 |
+
|
| 564 |
+
# Handle skip_install mode - return GenericEnvClient without package installation
|
| 565 |
+
if skip_install:
|
| 566 |
+
# If base_url is provided, connect directly
|
| 567 |
+
if base_url:
|
| 568 |
+
if cls._check_server_availability(base_url):
|
| 569 |
+
logger.info(
|
| 570 |
+
f"Using GenericEnvClient for {base_url} (skip_install=True)"
|
| 571 |
+
)
|
| 572 |
+
return GenericEnvClient(base_url=base_url, **kwargs)
|
| 573 |
+
else:
|
| 574 |
+
raise ConnectionError(
|
| 575 |
+
f"Server not available at {base_url}. "
|
| 576 |
+
f"Please ensure the server is running."
|
| 577 |
+
)
|
| 578 |
+
|
| 579 |
+
# If it's a Hub URL, try to connect to Space or use Docker
|
| 580 |
+
if _is_hub_url(name):
|
| 581 |
+
space_url = cls._resolve_space_url(name)
|
| 582 |
+
logger.info(f"Checking if HuggingFace Space is accessible: {space_url}")
|
| 583 |
+
|
| 584 |
+
if cls._check_space_availability(space_url):
|
| 585 |
+
logger.info(
|
| 586 |
+
f"Using GenericEnvClient for Space {space_url} (skip_install=True)"
|
| 587 |
+
)
|
| 588 |
+
return GenericEnvClient(base_url=space_url, **kwargs)
|
| 589 |
+
else:
|
| 590 |
+
# Space not running, use Docker from HF registry
|
| 591 |
+
logger.info(
|
| 592 |
+
f"Space not running at {space_url}, "
|
| 593 |
+
f"using GenericEnvClient with HF Docker registry"
|
| 594 |
+
)
|
| 595 |
+
return run_async_safely(
|
| 596 |
+
GenericEnvClient.from_env(
|
| 597 |
+
name,
|
| 598 |
+
use_docker=True,
|
| 599 |
+
provider=container_provider,
|
| 600 |
+
env_vars=env_vars or {},
|
| 601 |
+
**kwargs,
|
| 602 |
+
)
|
| 603 |
+
)
|
| 604 |
+
|
| 605 |
+
# For local environments with skip_install, we need docker_image
|
| 606 |
+
if docker_image:
|
| 607 |
+
logger.info(
|
| 608 |
+
f"Using GenericEnvClient with Docker image {docker_image} "
|
| 609 |
+
f"(skip_install=True)"
|
| 610 |
+
)
|
| 611 |
+
return run_async_safely(
|
| 612 |
+
GenericEnvClient.from_docker_image(
|
| 613 |
+
image=docker_image,
|
| 614 |
+
provider=container_provider,
|
| 615 |
+
wait_timeout=wait_timeout,
|
| 616 |
+
env_vars=env_vars or {},
|
| 617 |
+
**kwargs,
|
| 618 |
+
)
|
| 619 |
+
)
|
| 620 |
+
else:
|
| 621 |
+
raise ValueError(
|
| 622 |
+
f"Cannot use skip_install=True for local environment '{name}' "
|
| 623 |
+
f"without providing base_url or docker_image. "
|
| 624 |
+
f"For local environments, either:\n"
|
| 625 |
+
f" 1. Provide base_url to connect to a running server\n"
|
| 626 |
+
f" 2. Provide docker_image to start a container\n"
|
| 627 |
+
f" 3. Set skip_install=False to use the installed package"
|
| 628 |
+
)
|
| 629 |
+
|
| 630 |
+
# Check if it's a HuggingFace Hub URL or repo ID
|
| 631 |
+
if _is_hub_url(name):
|
| 632 |
+
# Try to connect to Space directly first
|
| 633 |
+
space_url = cls._resolve_space_url(name)
|
| 634 |
+
logger.info(f"Checking if HuggingFace Space is accessible: {space_url}")
|
| 635 |
+
|
| 636 |
+
space_is_available = cls._check_space_availability(space_url)
|
| 637 |
+
|
| 638 |
+
if space_is_available and base_url is None:
|
| 639 |
+
# Space is accessible! We'll connect directly without Docker
|
| 640 |
+
logger.info(f"Space is accessible at: {space_url}")
|
| 641 |
+
logger.info("Installing package for client code (no Docker needed)...")
|
| 642 |
+
|
| 643 |
+
# Ensure package is installed (uses git+ URL)
|
| 644 |
+
env_name = cls._ensure_package_from_hub(
|
| 645 |
+
name, trust_remote_code=trust_remote_code
|
| 646 |
+
)
|
| 647 |
+
|
| 648 |
+
# Set base_url to connect to remote Space
|
| 649 |
+
base_url = space_url
|
| 650 |
+
logger.info("Will connect to remote Space (no local Docker)")
|
| 651 |
+
else:
|
| 652 |
+
# Space not accessible or user provided explicit base_url
|
| 653 |
+
if not space_is_available:
|
| 654 |
+
logger.info(f"Space not accessible at {space_url}")
|
| 655 |
+
logger.info("Falling back to local Docker mode...")
|
| 656 |
+
|
| 657 |
+
# Ensure package is installed (uses git+ URL)
|
| 658 |
+
env_name = cls._ensure_package_from_hub(
|
| 659 |
+
name, trust_remote_code=trust_remote_code
|
| 660 |
+
)
|
| 661 |
+
else:
|
| 662 |
+
env_name = name
|
| 663 |
+
|
| 664 |
+
# Get environment info from discovery
|
| 665 |
+
discovery = get_discovery()
|
| 666 |
+
env_info = discovery.get_environment_by_name(env_name)
|
| 667 |
+
|
| 668 |
+
if not env_info:
|
| 669 |
+
# Environment not found - provide helpful error message
|
| 670 |
+
available_envs = discovery.discover()
|
| 671 |
+
|
| 672 |
+
if not available_envs:
|
| 673 |
+
raise ValueError(
|
| 674 |
+
"No OpenEnv environments found.\n"
|
| 675 |
+
"Install an environment with: pip install openenv-<env-name>\n"
|
| 676 |
+
"Or specify a HuggingFace Hub repository: AutoEnv.from_env('openenv/echo_env')"
|
| 677 |
+
)
|
| 678 |
+
|
| 679 |
+
# Try to suggest similar environment names
|
| 680 |
+
from difflib import get_close_matches
|
| 681 |
+
|
| 682 |
+
env_keys = list(available_envs.keys())
|
| 683 |
+
suggestions = get_close_matches(env_name, env_keys, n=3, cutoff=0.6)
|
| 684 |
+
|
| 685 |
+
error_msg = f"Unknown environment '{env_name}'.\n"
|
| 686 |
+
if suggestions:
|
| 687 |
+
error_msg += f"Did you mean: {', '.join(suggestions)}?\n"
|
| 688 |
+
error_msg += f"Available environments: {', '.join(sorted(env_keys))}"
|
| 689 |
+
|
| 690 |
+
raise ValueError(error_msg)
|
| 691 |
+
|
| 692 |
+
# Get the client class
|
| 693 |
+
try:
|
| 694 |
+
client_class = env_info.get_client_class()
|
| 695 |
+
except ImportError as e:
|
| 696 |
+
raise ImportError(
|
| 697 |
+
f"Failed to import environment client for '{env_name}'.\n"
|
| 698 |
+
f"Package '{env_info.package_name}' appears to be installed but the module cannot be imported.\n"
|
| 699 |
+
f"Try reinstalling: pip install --force-reinstall {env_info.package_name}\n"
|
| 700 |
+
f"Original error: {e}"
|
| 701 |
+
) from e
|
| 702 |
+
|
| 703 |
+
# Determine Docker image to use
|
| 704 |
+
if docker_image is None:
|
| 705 |
+
docker_image = env_info.default_image
|
| 706 |
+
|
| 707 |
+
# Create client instance
|
| 708 |
+
try:
|
| 709 |
+
if base_url:
|
| 710 |
+
# Check if the server at base_url is available
|
| 711 |
+
is_local = cls._is_local_url(base_url)
|
| 712 |
+
server_available = cls._check_server_availability(base_url)
|
| 713 |
+
|
| 714 |
+
if server_available:
|
| 715 |
+
# Server is running, connect directly
|
| 716 |
+
logger.info(
|
| 717 |
+
f"✅ Server available at {base_url}, connecting directly"
|
| 718 |
+
)
|
| 719 |
+
return client_class(base_url=base_url, provider=None, **kwargs)
|
| 720 |
+
elif is_local:
|
| 721 |
+
# Local server not running, auto-start Docker container
|
| 722 |
+
logger.info(f"❌ Server not available at {base_url}")
|
| 723 |
+
logger.info(f"🐳 Auto-starting Docker container: {docker_image}")
|
| 724 |
+
return run_async_safely(
|
| 725 |
+
client_class.from_docker_image(
|
| 726 |
+
image=docker_image,
|
| 727 |
+
provider=container_provider,
|
| 728 |
+
wait_timeout=wait_timeout,
|
| 729 |
+
env_vars=env_vars or {},
|
| 730 |
+
**kwargs,
|
| 731 |
+
)
|
| 732 |
+
)
|
| 733 |
+
else:
|
| 734 |
+
# Remote server not available, cannot auto-start
|
| 735 |
+
raise ConnectionError(
|
| 736 |
+
f"Remote server not available at {base_url}. "
|
| 737 |
+
f"Please ensure the server is running."
|
| 738 |
+
)
|
| 739 |
+
else:
|
| 740 |
+
# No base_url provided, start new Docker container
|
| 741 |
+
return run_async_safely(
|
| 742 |
+
client_class.from_docker_image(
|
| 743 |
+
image=docker_image,
|
| 744 |
+
provider=container_provider,
|
| 745 |
+
wait_timeout=wait_timeout,
|
| 746 |
+
env_vars=env_vars or {},
|
| 747 |
+
**kwargs,
|
| 748 |
+
)
|
| 749 |
+
)
|
| 750 |
+
except Exception as e:
|
| 751 |
+
raise ValueError(
|
| 752 |
+
f"Failed to create environment client for '{env_name}'.\n"
|
| 753 |
+
f"Client class: {client_class.__name__}\n"
|
| 754 |
+
f"Docker image: {docker_image}\n"
|
| 755 |
+
f"Error: {e}"
|
| 756 |
+
) from e
|
| 757 |
+
|
| 758 |
+
@classmethod
|
| 759 |
+
def from_hub(
|
| 760 |
+
cls,
|
| 761 |
+
name: str,
|
| 762 |
+
base_url: Optional[str] = None,
|
| 763 |
+
docker_image: Optional[str] = None,
|
| 764 |
+
container_provider: Optional["ContainerProvider"] = None,
|
| 765 |
+
wait_timeout: float = 30.0,
|
| 766 |
+
env_vars: Optional[Dict[str, str]] = None,
|
| 767 |
+
trust_remote_code: bool = False,
|
| 768 |
+
skip_install: bool = False,
|
| 769 |
+
**kwargs: Any,
|
| 770 |
+
) -> "EnvClient":
|
| 771 |
+
"""
|
| 772 |
+
Create an environment client from a name or HuggingFace Hub repository.
|
| 773 |
+
|
| 774 |
+
This is an alias for from_env() for backward compatibility.
|
| 775 |
+
|
| 776 |
+
Args:
|
| 777 |
+
name: Environment name or HuggingFace Hub repo ID
|
| 778 |
+
base_url: Optional base URL for HTTP connection
|
| 779 |
+
docker_image: Optional Docker image name (overrides default)
|
| 780 |
+
container_provider: Optional container provider
|
| 781 |
+
wait_timeout: Timeout for container startup (seconds)
|
| 782 |
+
env_vars: Optional environment variables for the container
|
| 783 |
+
trust_remote_code: If True, skip user confirmation when installing
|
| 784 |
+
from HuggingFace Hub
|
| 785 |
+
skip_install: If True, skip package installation and return a
|
| 786 |
+
GenericEnvClient for remote environments
|
| 787 |
+
**kwargs: Additional arguments passed to the client class
|
| 788 |
+
|
| 789 |
+
Returns:
|
| 790 |
+
Instance of the environment client class
|
| 791 |
+
|
| 792 |
+
Examples:
|
| 793 |
+
>>> env = AutoEnv.from_hub("coding-env")
|
| 794 |
+
>>> env = AutoEnv.from_hub("meta-pytorch/coding-env")
|
| 795 |
+
"""
|
| 796 |
+
return cls.from_env(
|
| 797 |
+
name=name,
|
| 798 |
+
base_url=base_url,
|
| 799 |
+
docker_image=docker_image,
|
| 800 |
+
container_provider=container_provider,
|
| 801 |
+
wait_timeout=wait_timeout,
|
| 802 |
+
env_vars=env_vars,
|
| 803 |
+
trust_remote_code=trust_remote_code,
|
| 804 |
+
skip_install=skip_install,
|
| 805 |
+
**kwargs,
|
| 806 |
+
)
|
| 807 |
+
|
| 808 |
+
@classmethod
|
| 809 |
+
def get_env_class(cls, name: str):
|
| 810 |
+
"""
|
| 811 |
+
Get the environment client class without instantiating it.
|
| 812 |
+
|
| 813 |
+
Args:
|
| 814 |
+
name: Environment name
|
| 815 |
+
|
| 816 |
+
Returns:
|
| 817 |
+
The environment client class
|
| 818 |
+
|
| 819 |
+
Raises:
|
| 820 |
+
ValueError: If environment not found
|
| 821 |
+
|
| 822 |
+
Examples:
|
| 823 |
+
>>> CodingEnv = AutoEnv.get_env_class("coding")
|
| 824 |
+
>>> # Now you can instantiate it yourself
|
| 825 |
+
>>> env = CodingEnv(base_url="http://localhost:8000")
|
| 826 |
+
"""
|
| 827 |
+
discovery = get_discovery()
|
| 828 |
+
env_info = discovery.get_environment_by_name(name)
|
| 829 |
+
|
| 830 |
+
if not env_info:
|
| 831 |
+
raise ValueError(f"Unknown environment: {name}")
|
| 832 |
+
|
| 833 |
+
return env_info.get_client_class()
|
| 834 |
+
|
| 835 |
+
@classmethod
|
| 836 |
+
def get_env_info(cls, name: str) -> Dict[str, Any]:
|
| 837 |
+
"""
|
| 838 |
+
Get detailed information about an environment.
|
| 839 |
+
|
| 840 |
+
Args:
|
| 841 |
+
name: Environment name
|
| 842 |
+
|
| 843 |
+
Returns:
|
| 844 |
+
Dictionary with environment metadata
|
| 845 |
+
|
| 846 |
+
Raises:
|
| 847 |
+
ValueError: If environment not found
|
| 848 |
+
|
| 849 |
+
Examples:
|
| 850 |
+
>>> info = AutoEnv.get_env_info("coding")
|
| 851 |
+
>>> print(info['description'])
|
| 852 |
+
'Coding environment for OpenEnv'
|
| 853 |
+
>>> print(info['default_image'])
|
| 854 |
+
'coding-env:latest'
|
| 855 |
+
"""
|
| 856 |
+
discovery = get_discovery()
|
| 857 |
+
env_info = discovery.get_environment_by_name(name)
|
| 858 |
+
|
| 859 |
+
if not env_info:
|
| 860 |
+
raise ValueError(f"Unknown environment: {name}")
|
| 861 |
+
|
| 862 |
+
return {
|
| 863 |
+
"env_key": env_info.env_key,
|
| 864 |
+
"name": env_info.name,
|
| 865 |
+
"package": env_info.package_name,
|
| 866 |
+
"version": env_info.version,
|
| 867 |
+
"description": env_info.description,
|
| 868 |
+
"env_class": env_info.client_class_name,
|
| 869 |
+
"action_class": env_info.action_class_name,
|
| 870 |
+
"observation_class": env_info.observation_class_name,
|
| 871 |
+
"module": env_info.client_module_path,
|
| 872 |
+
"default_image": env_info.default_image,
|
| 873 |
+
"spec_version": env_info.spec_version,
|
| 874 |
+
}
|
| 875 |
+
|
| 876 |
+
@classmethod
|
| 877 |
+
def list_environments(cls) -> None:
|
| 878 |
+
"""
|
| 879 |
+
Print a formatted list of all available environments.
|
| 880 |
+
|
| 881 |
+
This discovers all installed openenv-* packages and displays
|
| 882 |
+
their metadata in a user-friendly format.
|
| 883 |
+
|
| 884 |
+
Examples:
|
| 885 |
+
>>> AutoEnv.list_environments()
|
| 886 |
+
Available OpenEnv Environments:
|
| 887 |
+
----------------------------------------------------------------------
|
| 888 |
+
echo : Echo Environment (v0.1.0)
|
| 889 |
+
Package: openenv-echo-env
|
| 890 |
+
coding : Coding Environment (v0.1.0)
|
| 891 |
+
Package: openenv-coding_env
|
| 892 |
+
----------------------------------------------------------------------
|
| 893 |
+
Total: 2 environments
|
| 894 |
+
"""
|
| 895 |
+
discovery = get_discovery()
|
| 896 |
+
discovery.list_environments()
|
src/openenv/cli/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""OpenEnv CLI package."""
|
| 8 |
+
|
| 9 |
+
__version__ = "0.1.0"
|
src/openenv/cli/__main__.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
OpenEnv CLI entry point.
|
| 9 |
+
|
| 10 |
+
This module provides the main entry point for the OpenEnv command-line interface,
|
| 11 |
+
following the Hugging Face CLI pattern.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import sys
|
| 15 |
+
|
| 16 |
+
import typer
|
| 17 |
+
|
| 18 |
+
from openenv.cli.commands import build, fork, init, push, serve, validate
|
| 19 |
+
|
| 20 |
+
# Create the main CLI app
|
| 21 |
+
app = typer.Typer(
|
| 22 |
+
name="openenv",
|
| 23 |
+
help="OpenEnv - An e2e framework for creating, deploying and using isolated execution environments for agentic RL training",
|
| 24 |
+
no_args_is_help=True,
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
# Register commands
|
| 28 |
+
app.command(name="init", help="Initialize a new OpenEnv environment")(init.init)
|
| 29 |
+
app.command(name="build", help="Build Docker images for OpenEnv environments")(
|
| 30 |
+
build.build
|
| 31 |
+
)
|
| 32 |
+
app.command(
|
| 33 |
+
name="validate", help="Validate environment structure and deployment readiness"
|
| 34 |
+
)(validate.validate)
|
| 35 |
+
app.command(
|
| 36 |
+
name="push",
|
| 37 |
+
help="Push an OpenEnv environment to Hugging Face Spaces or custom registry",
|
| 38 |
+
)(push.push)
|
| 39 |
+
app.command(name="serve", help="Serve environments locally (TODO: Phase 4)")(
|
| 40 |
+
serve.serve
|
| 41 |
+
)
|
| 42 |
+
app.command(
|
| 43 |
+
name="fork",
|
| 44 |
+
help="Fork (duplicate) a Hugging Face Space to your account",
|
| 45 |
+
)(fork.fork)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# Entry point for setuptools
|
| 49 |
+
def main() -> None:
|
| 50 |
+
"""Main entry point for the CLI."""
|
| 51 |
+
try:
|
| 52 |
+
app()
|
| 53 |
+
except KeyboardInterrupt:
|
| 54 |
+
print("\nOperation cancelled by user.")
|
| 55 |
+
sys.exit(130)
|
| 56 |
+
except Exception as e:
|
| 57 |
+
print(f"Error: {e}", file=sys.stderr)
|
| 58 |
+
sys.exit(1)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
if __name__ == "__main__":
|
| 62 |
+
main()
|