mreso committed on
Commit
9b47d4a
·
verified ·
1 Parent(s): c05c664

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/cartpole.png filter=lfs diff=lfs merge=lfs -text
37
+ assets/quadruped.png filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Multi-stage build for dm_control environment
8
+ # Uses pip for package installation
9
+
10
+ FROM python:3.11-slim AS builder
11
+
12
+ WORKDIR /app
13
+
14
+ # Install build dependencies including OpenGL for MuJoCo
15
+ RUN apt-get update && apt-get install -y --no-install-recommends \
16
+ build-essential \
17
+ git \
18
+ libgl1 \
19
+ libglx-mesa0 \
20
+ libglew-dev \
21
+ libosmesa6-dev \
22
+ libgl1-mesa-dev \
23
+ libglfw3 \
24
+ patchelf \
25
+ && rm -rf /var/lib/apt/lists/*
26
+
27
+ # Copy environment code
28
+ COPY . /app/env
29
+
30
+ WORKDIR /app/env
31
+
32
+ # Install dependencies using pip
33
+ RUN pip install --upgrade pip && \
34
+ pip install --no-cache-dir -e .
35
+
36
+ # Final runtime stage
37
+ FROM python:3.11-slim
38
+
39
+ WORKDIR /app
40
+
41
+ # Install runtime dependencies (OpenGL for MuJoCo rendering, curl for healthcheck)
42
+ RUN apt-get update && apt-get install -y --no-install-recommends \
43
+ curl \
44
+ libgl1 \
45
+ libglx-mesa0 \
46
+ libglew-dev \
47
+ libosmesa6-dev \
48
+ libglfw3 \
49
+ && rm -rf /var/lib/apt/lists/*
50
+
51
+ # Copy installed packages from builder
52
+ COPY --from=builder /usr/local/lib/python3.11/site-packages /usr/local/lib/python3.11/site-packages
53
+ COPY --from=builder /usr/local/bin /usr/local/bin
54
+
55
+ # Copy the environment code
56
+ COPY . /app/env
57
+
58
+ # Set PYTHONPATH so imports work correctly
59
+ ENV PYTHONPATH="/app/env"
60
+
61
+ # Set MuJoCo to use OSMesa for headless rendering
62
+ ENV MUJOCO_GL="osmesa"
63
+
64
+ # Expose port
65
+ EXPOSE 8000
66
+
67
+ # Health check
68
+ HEALTHCHECK --interval=30s --timeout=3s --start-period=10s --retries=3 \
69
+ CMD curl -f http://localhost:8000/health || exit 1
70
+
71
+ # Run the FastAPI server
72
+ # Use exec to replace the shell with uvicorn so it receives SIGINT/SIGTERM directly
73
+ ENV ENABLE_WEB_INTERFACE=true
74
+ CMD ["sh", "-c", "cd /app/env && exec uvicorn server.app:app --host 0.0.0.0 --port 8000"]
README.md CHANGED
@@ -1,10 +1,167 @@
1
  ---
2
- title: Dm Control Env
3
- emoji: 🐠
4
- colorFrom: blue
5
- colorTo: yellow
6
  sdk: docker
7
  pinned: false
 
 
 
 
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: dm_control Environment Server
3
+ emoji: 🤖
4
+ colorFrom: green
5
+ colorTo: blue
6
  sdk: docker
7
  pinned: false
8
+ app_port: 8000
9
+ base_path: /web
10
+ tags:
11
+ - openenv
12
  ---
13
 
14
+ # dm_control OpenEnv Environment
15
+
16
+ A generic OpenEnv environment for [dm_control.suite](https://github.com/google-deepmind/dm_control), providing access to all MuJoCo-based continuous control tasks.
17
+
18
+ <p align="center">
19
+ <img src="assets/cartpole.png" width="45%" alt="Cartpole Balance"/>
20
+ <img src="assets/quadruped.png" width="45%" alt="Quadruped Walk"/>
21
+ </p>
22
+
23
+ ## Supported Environments
24
+
25
+ | Domain | Tasks |
26
+ |--------|-------|
27
+ | cartpole | balance, swingup, swingup_sparse |
28
+ | walker | stand, walk, run |
29
+ | humanoid | stand, walk, run |
30
+ | cheetah | run |
31
+ | hopper | stand, hop |
32
+ | reacher | easy, hard |
33
+ | pendulum | swingup |
34
+ | finger | spin, turn_easy, turn_hard |
35
+ | fish | upright, swim |
36
+ | ball_in_cup | catch |
37
+ | And more... | See `dm_control.suite.BENCHMARKING` |
38
+
39
+ ## Quick Start
40
+
41
+ ### Using the Client
42
+
43
+ ```python
44
+ from envs.dm_control_env import DMControlEnv, DMControlAction
45
+
46
+ # Connect to a running server
47
+ with DMControlEnv(base_url="http://localhost:8000") as env:
48
+ # Reset with default (cartpole/balance)
49
+ result = env.reset()
50
+ print(f"Observations: {result.observation.observations.keys()}")
51
+
52
+ # Take actions
53
+ for _ in range(100):
54
+ action = DMControlAction(values=[0.5]) # Push cart right
55
+ result = env.step(action)
56
+ print(f"Reward: {result.reward}, Done: {result.done}")
57
+
58
+ if result.done:
59
+ result = env.reset()
60
+ ```
61
+
62
+ ### Switching Environments
63
+
64
+ ```python
65
+ # Start with cartpole
66
+ result = env.reset(domain_name="cartpole", task_name="balance")
67
+
68
+ # Switch to walker (on next reset)
69
+ result = env.reset(domain_name="walker", task_name="walk")
70
+ # Note: walker has 6 action dimensions
71
+ action = DMControlAction(values=[0.0] * 6)
72
+ result = env.step(action)
73
+ ```
74
+
75
+ ### Running the Server
76
+
77
+ ```bash
78
+ # From OpenEnv root
79
+ cd envs/dm_control_env
80
+ uvicorn server.app:app --host 0.0.0.0 --port 8000
81
+
82
+ # Or using uv
83
+ uv run --project . server
84
+ ```
85
+
86
+ ### Using Docker
87
+
88
+ ```bash
89
+ # Build
90
+ docker build -t dm_control:latest -f server/Dockerfile .
91
+
92
+ # Run
93
+ docker run -p 8000:8000 dm_control:latest
94
+ ```
95
+
96
+ ## API
97
+
98
+ ### Action
99
+
100
+ ```python
101
+ class DMControlAction(Action):
102
+ values: List[float] # Continuous action values
103
+ ```
104
+
105
+ Action dimensions vary by environment:
106
+ - cartpole: 1 (force on cart)
107
+ - walker: 6 (joint torques)
108
+ - humanoid: 21 (joint torques)
109
+
110
+ ### Observation
111
+
112
+ ```python
113
+ class DMControlObservation(Observation):
114
+ observations: Dict[str, List[float]] # Named observation arrays
115
+ pixels: Optional[str] # Base64 PNG (if render=True)
116
+ reward: float
117
+ done: bool
118
+ ```
119
+
120
+ ### State
121
+
122
+ ```python
123
+ class DMControlState(State):
124
+ domain_name: str
125
+ task_name: str
126
+ action_spec: Dict[str, Any]
127
+ observation_spec: Dict[str, Any]
128
+ physics_timestep: float
129
+ control_timestep: float
130
+ episode_id: str
131
+ step_count: int
132
+ ```
133
+
134
+ ## Examples
135
+
136
+ See the `examples/` directory:
137
+ - `cartpole_control.py` - Interactive cartpole control with arrow keys
138
+ - `hopper_control.py` - Interactive hopper control with spacebar for random forces
139
+ - `quadruped_control.py` - Interactive quadruped control with spacebar for random forces
140
+ - `list_environments.py` - Print all available environments
141
+
142
+ All examples support consistent CLI arguments:
143
+
144
+ ```bash
145
+ # Default: interactive mode with minimal pygame window
146
+ python examples/cartpole_control.py
147
+
148
+ # Visual mode with rendered MuJoCo frames
149
+ python examples/cartpole_control.py --visual
150
+
151
+ # Headless mode (no pygame, automated control)
152
+ python examples/cartpole_control.py --headless --max-steps 500
153
+
154
+ # Select a different task
155
+ python examples/cartpole_control.py --task swingup
156
+ python examples/hopper_control.py --task stand
157
+ python examples/quadruped_control.py --task run
158
+ ```
159
+
160
+ ## Environment Variables
161
+
162
+ | Variable | Default | Description |
163
+ |----------|---------|-------------|
164
+ | `DMCONTROL_DOMAIN` | cartpole | Default domain |
165
+ | `DMCONTROL_TASK` | balance | Default task |
166
+ | `DMCONTROL_RENDER_HEIGHT` | 480 | Render height |
167
+ | `DMCONTROL_RENDER_WIDTH` | 640 | Render width |
__init__.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """dm_control OpenEnv Environment.
2
+
3
+ A generic OpenEnv environment for dm_control.suite supporting all domains/tasks.
4
+ """
5
+
6
+ from .models import DMControlAction, DMControlObservation, DMControlState
7
+ from .client import DMControlEnv
8
+
9
+ __all__ = [
10
+ "DMControlAction",
11
+ "DMControlObservation",
12
+ "DMControlState",
13
+ "DMControlEnv",
14
+ ]
assets/cartpole.png ADDED

Git LFS Details

  • SHA256: 734d8a585b7538741767bd612db2068b511fda10bd685404ed190f9e9ff6b74d
  • Pointer size: 131 Bytes
  • Size of remote file: 329 kB
assets/quadruped.png ADDED

Git LFS Details

  • SHA256: 939ce03bc22ecc8ebcf1beea7418ade07b20ffe07c8c0515c5e3ff294287284d
  • Pointer size: 131 Bytes
  • Size of remote file: 553 kB
client.py ADDED
@@ -0,0 +1,375 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ dm_control Environment Client.
9
+
10
+ This module provides the client for connecting to a dm_control
11
+ Environment server via WebSocket for persistent sessions.
12
+ """
13
+
14
+ from typing import Any, Dict, List, Optional, Tuple
15
+
16
+ try:
17
+ from openenv.core.client_types import StepResult
18
+ from openenv.core.env_client import EnvClient
19
+
20
+ from .models import (
21
+ AVAILABLE_ENVIRONMENTS,
22
+ DMControlAction,
23
+ DMControlObservation,
24
+ DMControlState,
25
+ )
26
+ except ImportError:
27
+ from openenv.core.client_types import StepResult
28
+ from openenv.core.env_client import EnvClient
29
+
30
+ try:
31
+ from models import (
32
+ AVAILABLE_ENVIRONMENTS,
33
+ DMControlAction,
34
+ DMControlObservation,
35
+ DMControlState,
36
+ )
37
+ except ImportError:
38
+ try:
39
+ from dm_control_env.models import (
40
+ AVAILABLE_ENVIRONMENTS,
41
+ DMControlAction,
42
+ DMControlObservation,
43
+ DMControlState,
44
+ )
45
+ except ImportError:
46
+ from envs.dm_control_env.models import (
47
+ AVAILABLE_ENVIRONMENTS,
48
+ DMControlAction,
49
+ DMControlObservation,
50
+ DMControlState,
51
+ )
52
+
53
+
54
+ class DMControlEnv(EnvClient[DMControlAction, DMControlObservation, DMControlState]):
55
+ """
56
+ Client for dm_control.suite environments.
57
+
58
+ This client maintains a persistent WebSocket connection to the environment
59
+ server, enabling efficient multi-step interactions with lower latency.
60
+
61
+ Supported Environments (via dm_control.suite):
62
+ - cartpole: balance, swingup, swingup_sparse
63
+ - walker: stand, walk, run
64
+ - humanoid: stand, walk, run
65
+ - cheetah: run
66
+ - hopper: stand, hop
67
+ - reacher: easy, hard
68
+ - And many more...
69
+
70
+ Example:
71
+ >>> # Connect to a running server
72
+ >>> with DMControlEnv(base_url="http://localhost:8000") as client:
73
+ ... result = client.reset()
74
+ ... print(f"Observations: {result.observation.observations.keys()}")
75
+ ...
76
+ ... # Take action (cartpole: push right)
77
+ ... result = client.step(DMControlAction(values=[0.5]))
78
+ ... print(f"Reward: {result.reward}")
79
+
80
+ Example switching environments:
81
+ >>> client = DMControlEnv(base_url="http://localhost:8000")
82
+ >>> # Start with cartpole balance
83
+ >>> result = client.reset(domain_name="cartpole", task_name="balance")
84
+ >>> # ... train on cartpole ...
85
+ >>> # Switch to walker walk
86
+ >>> result = client.reset(domain_name="walker", task_name="walk")
87
+ >>> # ... train on walker ...
88
+ """
89
+
90
def __init__(
    self,
    base_url: str,
    connect_timeout_s: float = 10.0,
    message_timeout_s: float = 60.0,
    provider: Optional[Any] = None,
):
    """
    Create a client bound to a dm_control environment server.

    Every argument is forwarded unchanged to the ``EnvClient`` base class.

    Args:
        base_url: Base URL of the environment server (http:// or ws://).
        connect_timeout_s: Timeout for establishing the WebSocket connection.
        message_timeout_s: Timeout for receiving responses.
        provider: Optional container/runtime provider for lifecycle management.
    """
    client_options = {
        "base_url": base_url,
        "connect_timeout_s": connect_timeout_s,
        "message_timeout_s": message_timeout_s,
        "provider": provider,
    }
    super().__init__(**client_options)
112
+
113
def _step_payload(self, action: DMControlAction) -> Dict:
    """
    Serialize a DMControlAction into the JSON body of a step request.

    Args:
        action: The action to send.

    Returns:
        A dict with the action's ``values``; ``metadata`` is included only
        when the action carries a non-empty (truthy) metadata value.
    """
    body: Dict[str, Any] = {"values": action.values}

    # Guard clause: nothing more to add when there is no metadata.
    if not action.metadata:
        return body

    body["metadata"] = action.metadata
    return body
129
+
130
def _parse_result(self, payload: Dict) -> StepResult[DMControlObservation]:
    """
    Turn a reset/step response into a StepResult.

    ``reward`` and ``done`` are read from the top level of the payload and
    copied onto both the observation and the StepResult; the remaining
    observation fields come from the nested ``"observation"`` mapping.

    Args:
        payload: JSON response from the server.

    Returns:
        StepResult wrapping a DMControlObservation.
    """
    reward = payload.get("reward")
    done = payload.get("done", False)
    raw_observation = payload.get("observation", {})

    parsed_observation = DMControlObservation(
        observations=raw_observation.get("observations", {}),
        pixels=raw_observation.get("pixels"),
        done=done,
        reward=reward,
        metadata=raw_observation.get("metadata", {}),
    )

    return StepResult(observation=parsed_observation, reward=reward, done=done)
155
+
156
def _parse_state(self, payload: Dict) -> DMControlState:
    """
    Build a DMControlState from a state response.

    Each field is looked up in ``payload`` by name; a missing key falls back
    to the default listed in the table below.

    Args:
        payload: JSON response from the /state endpoint.

    Returns:
        DMControlState populated from the payload.
    """
    # Field name -> value used when the server omits the key.
    field_defaults: Dict[str, Any] = {
        "episode_id": None,
        "step_count": 0,
        "domain_name": "",
        "task_name": "",
        "action_spec": {},
        "observation_spec": {},
        "physics_timestep": 0.002,
        "control_timestep": 0.02,
    }
    state_fields = {
        name: payload.get(name, fallback)
        for name, fallback in field_defaults.items()
    }
    return DMControlState(**state_fields)
176
+
177
def reset(
    self,
    domain_name: Optional[str] = None,
    task_name: Optional[str] = None,
    seed: Optional[int] = None,
    render: bool = False,
    **kwargs,
) -> StepResult[DMControlObservation]:
    """
    Reset the environment, optionally switching domain/task.

    ``domain_name``, ``task_name`` and ``seed`` are forwarded only when they
    are not None; ``render`` is always forwarded.

    Args:
        domain_name: Optionally switch to a different domain.
        task_name: Optionally switch to a different task.
        seed: Random seed for reproducibility.
        render: If True, include pixel observations in response.
        **kwargs: Additional arguments passed to server.

    Returns:
        StepResult with initial observation.
    """
    optional_settings = {
        "domain_name": domain_name,
        "task_name": task_name,
        "seed": seed,
    }
    provided_settings = {
        key: value for key, value in optional_settings.items() if value is not None
    }
    request_args = {**kwargs, **provided_settings, "render": render}

    return super().reset(**request_args)
208
+
209
def step(
    self,
    action: DMControlAction,
    render: bool = False,
    **kwargs,
) -> StepResult[DMControlObservation]:
    """
    Execute one step in the environment.

    Args:
        action: DMControlAction with continuous action values.
        render: Currently ignored by this method: it is accepted but NOT
            forwarded to ``super().step()``. Per the note in the body, the
            server is expected to reuse the render setting given to the
            last ``reset`` call, so pass ``render=True`` to ``reset`` to
            receive pixels.
        **kwargs: Additional arguments forwarded to the base client's
            ``step``.

    Returns:
        StepResult with new observation, reward, and done flag.
    """
    # Note: render flag needs to be passed differently
    # For now, the server remembers the render setting from reset
    return super().step(action, **kwargs)
229
+
230
@staticmethod
def available_environments() -> List[Tuple[str, str]]:
    """
    List available dm_control environments.

    Returns the ``AVAILABLE_ENVIRONMENTS`` object imported from the models
    module as-is (no copy is made), so callers should treat it as read-only.

    Returns:
        List of (domain_name, task_name) tuples.
    """
    return AVAILABLE_ENVIRONMENTS
239
+
240
@classmethod
def from_direct(
    cls,
    domain_name: str = "cartpole",
    task_name: str = "balance",
    render_height: int = 480,
    render_width: int = 640,
    port: int = 8765,
) -> "DMControlEnv":
    """
    Create a dm_control environment client with an embedded local server.

    Starts a local uvicorn server in a subprocess, waits (up to roughly
    30 attempts, one second apart) for its ``/health`` endpoint to answer
    with HTTP 200, and returns a client connected to it. The subprocess is
    wrapped in a provider object whose ``stop()`` terminates it.

    Server output is written to an anonymous temporary file rather than a
    pipe: nothing reads the output while the server runs, and an undrained
    pipe would eventually block the server once the OS buffer fills.

    Args:
        domain_name: Default domain to use.
        task_name: Default task to use.
        render_height: Height of rendered images.
        render_width: Width of rendered images.
        port: Port for the local server.

    Returns:
        DMControlEnv client connected to the local server.

    Raises:
        RuntimeError: If the server exits early or never becomes healthy.
            The tail of the captured server output is appended to the
            message when available.

    Example:
        >>> client = DMControlEnv.from_direct(domain_name="walker", task_name="walk")
        >>> try:
        ...     result = client.reset()
        ...     for _ in range(100):
        ...         result = client.step(DMControlAction(values=[0.0] * 6))
        ... finally:
        ...     client.close()
    """
    import os
    import subprocess
    import sys
    import tempfile
    import time
    from pathlib import Path

    import requests

    # Work out which module path / working directory the server needs.
    # Best effort: any failure falls back to the repository-style module path.
    server_app = "envs.dm_control_env.server.app:app"
    try:
        client_dir = Path(__file__).parent
        cwd = client_dir.parent.parent

        if not (cwd / "envs" / "dm_control_env" / "server" / "app.py").exists():
            if (client_dir / "server" / "app.py").exists():
                server_app = "server.app:app"
                cwd = client_dir
    except Exception:
        server_app = "envs.dm_control_env.server.app:app"
        cwd = None

    env = {
        **os.environ,
        "DMCONTROL_DOMAIN": domain_name,
        "DMCONTROL_TASK": task_name,
        "DMCONTROL_RENDER_HEIGHT": str(render_height),
        "DMCONTROL_RENDER_WIDTH": str(render_width),
        "NO_PROXY": "localhost,127.0.0.1",
        "no_proxy": "localhost,127.0.0.1",
    }

    if cwd:
        path_entries = [str(cwd / "src"), str(cwd)]
        existing_path = env.get("PYTHONPATH", "")
        if existing_path:
            path_entries.append(existing_path)
        # os.pathsep keeps this correct on platforms where ':' is not the separator.
        env["PYTHONPATH"] = os.pathsep.join(path_entries)

    cmd = [
        sys.executable,
        "-m",
        "uvicorn",
        server_app,
        "--host",
        "127.0.0.1",
        "--port",
        str(port),
    ]

    log_file = tempfile.TemporaryFile()
    try:
        server_process = subprocess.Popen(
            cmd,
            env=env,
            stdout=log_file,
            stderr=subprocess.STDOUT,
            cwd=str(cwd) if cwd else None,
        )
    except Exception:
        log_file.close()
        raise

    base_url = f"http://127.0.0.1:{port}"
    healthy = False
    for _ in range(30):
        # If our server already exited (bad import, port in use, ...), stop
        # polling: waiting longer cannot help, and a 200 from some other
        # process on this port must not be mistaken for our server.
        if server_process.poll() is not None:
            break
        try:
            response = requests.get(
                f"{base_url}/health",
                timeout=2,
                proxies={"http": None, "https": None},
            )
            if response.status_code == 200 and server_process.poll() is None:
                healthy = True
                break
        except requests.exceptions.RequestException:
            pass
        time.sleep(1)

    if not healthy:
        if server_process.poll() is None:
            server_process.kill()
        server_process.wait()  # reap the child so it does not linger as a zombie
        log_file.seek(0)
        output_tail = (
            log_file.read()[-2000:].decode("utf-8", errors="replace").strip()
        )
        log_file.close()
        message = (
            f"Failed to start local dm_control server on port {port}. "
            "Check that the port is available and dependencies are installed."
        )
        if output_tail:
            message += f"\nServer output:\n{output_tail}"
        raise RuntimeError(message)

    class DirectModeProvider:
        """Provider that manages the embedded server subprocess."""

        def __init__(self, process: subprocess.Popen, log_handle):
            self._process = process
            self._log_handle = log_handle

        def stop(self):
            """Stop the embedded server and release its log file."""
            if self._process:
                self._process.terminate()
                try:
                    self._process.wait(timeout=10)
                except subprocess.TimeoutExpired:
                    self._process.kill()
                    self._process.wait()
                self._process = None
            if self._log_handle:
                self._log_handle.close()
                self._log_handle = None

    provider = DirectModeProvider(server_process, log_file)
    try:
        return cls(base_url=base_url, provider=provider)
    except Exception:
        # Do not leave an orphaned server behind if the client cannot be built.
        provider.stop()
        raise
examples/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """dm_control examples."""
examples/cartpole_control.py ADDED
@@ -0,0 +1,333 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Interactive cartpole control via OpenEnv.
3
+
4
+ This example demonstrates using the dm_control OpenEnv client with
5
+ the cartpole environment. Use arrow keys to control the cart.
6
+
7
+ Controls:
8
+ LEFT/RIGHT arrows: Apply force to move cart
9
+ R: Reset environment
10
+ ESC or Q: Quit
11
+
12
+ Requirements:
13
+ pip install pygame
14
+
15
+ Usage:
16
+ 1. Start the server: uvicorn server.app:app --host 0.0.0.0 --port 8000
17
+ 2. Run this script: python examples/cartpole_control.py
18
+
19
+ For visual mode (requires working MuJoCo rendering):
20
+ python examples/cartpole_control.py --visual
21
+ """
22
+
23
+ import argparse
24
+ import random
25
+ import sys
26
+ from pathlib import Path
27
+
28
+ # Add parent directory to path for imports
29
+ sys.path.insert(0, str(Path(__file__).parent.parent))
30
+
31
+ from client import DMControlEnv
32
+ from models import DMControlAction
33
+
34
+
35
def run_headless(env: DMControlEnv, task: str = "balance", max_steps: int = 500):
    """Run cartpole control in headless mode.

    Resets the cartpole environment for ``task`` and then applies uniformly
    random actions in [-1, 1] until the episode reports ``done`` or
    ``max_steps`` steps have been taken, printing progress every 50 steps.

    Args:
        env: Connected dm_control client.
        task: Cartpole task name passed to ``reset``.
        max_steps: Upper bound on the number of steps taken.
    """
    print("\n=== Headless Mode (OpenEnv Step/Observation Pattern) ===")
    print("This mode demonstrates the OpenEnv API with the cartpole.\n")

    # Reset environment using OpenEnv pattern
    result = env.reset(domain_name="cartpole", task_name=task)
    print(f"Initial observations: {list(result.observation.observations.keys())}")
    print(f" position: {result.observation.observations.get('position', [])}")
    print(f" velocity: {result.observation.observations.get('velocity', [])}")

    total_reward = 0.0
    step_count = 0

    print("\nRunning with random actions to demonstrate step/observation pattern...\n")

    while not result.done and step_count < max_steps:
        # Random action in [-1, 1]
        action_value = random.uniform(-1.0, 1.0)

        # Step the environment using OpenEnv pattern
        result = env.step(DMControlAction(values=[action_value]))

        # The reward may be None; normalise once so both the running total
        # and the formatted progress line below are safe.
        step_reward = result.reward or 0.0
        total_reward += step_reward
        step_count += 1

        # Print progress periodically
        if step_count % 50 == 0:
            pos = result.observation.observations.get("position", [])
            vel = result.observation.observations.get("velocity", [])
            print(
                f"Step {step_count}: reward={step_reward:.3f}, "
                f"total={total_reward:.2f}, done={result.done}"
            )
            print(f" position={pos}, velocity={vel}")

    print(f"\nEpisode finished: {step_count} steps, total reward: {total_reward:.2f}")
74
+
75
+
76
def run_interactive(env: DMControlEnv, task: str = "balance"):
    """Run interactive control with keyboard input via pygame.

    Opens a small pygame window used only for keyboard input and a status
    line. Each frame (capped at 30 FPS) the held LEFT/RIGHT arrow key is
    turned into an action of -1.0 / 1.0 (0.0 when neither is held) and sent
    to the environment. R resets the episode, ESC or Q quits, and a finished
    episode is reset automatically.

    Args:
        env: Connected dm_control client.
        task: Cartpole task name passed to ``reset``.
    """
    import pygame

    print("\n=== Interactive Mode (OpenEnv Step/Observation Pattern) ===")
    print("Use LEFT/RIGHT arrows to control cart, R to reset, ESC to quit.\n")

    # Reset environment using OpenEnv pattern
    result = env.reset(domain_name="cartpole", task_name=task)
    print(f"Initial observations: {list(result.observation.observations.keys())}")

    total_reward = 0.0
    step_count = 0

    # Initialize pygame for keyboard input (minimal window)
    pygame.init()
    try:
        screen = pygame.display.set_mode((400, 100))
        pygame.display.set_caption("Cartpole Control - Arrow keys to move, R to reset")
        clock = pygame.time.Clock()

        # Font for display
        font = pygame.font.Font(None, 24)

        print("\nControls:")
        print(" LEFT/RIGHT arrows: Move cart")
        print(" R: Reset environment")
        print(" ESC or Q: Quit\n")

        running = True
        while running:
            # Handle events
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    running = False
                elif event.type == pygame.KEYDOWN:
                    if event.key in (pygame.K_ESCAPE, pygame.K_q):
                        running = False
                    elif event.key == pygame.K_r:
                        result = env.reset(domain_name="cartpole", task_name=task)
                        total_reward = 0.0
                        step_count = 0
                        print("Environment reset")

            # Check for held keys (for continuous control); LEFT wins if both are held
            keys = pygame.key.get_pressed()
            if keys[pygame.K_LEFT]:
                action_value = -1.0
            elif keys[pygame.K_RIGHT]:
                action_value = 1.0
            else:
                action_value = 0.0

            # Step the environment using OpenEnv pattern
            result = env.step(DMControlAction(values=[action_value]))

            # Track reward from result (reward may be None)
            total_reward += result.reward or 0.0
            step_count += 1

            # Check if episode is done
            if result.done:
                print(
                    f"Episode finished! Steps: {step_count}, "
                    f"Total reward: {total_reward:.2f}"
                )
                # Auto-reset on done
                result = env.reset(domain_name="cartpole", task_name=task)
                total_reward = 0.0
                step_count = 0

            # Update display
            direction = (
                "<--" if action_value < 0 else ("-->" if action_value > 0 else "---")
            )
            screen.fill((30, 30, 30))
            text = font.render(
                f"Step: {step_count} | Reward: {total_reward:.1f} | {direction}",
                True,
                (255, 255, 255),
            )
            screen.blit(text, (10, 40))
            pygame.display.flip()

            # Print progress periodically
            if step_count % 200 == 0 and step_count > 0:
                print(f"Step {step_count}: Total reward: {total_reward:.2f}")

            # Cap at 30 FPS
            clock.tick(30)
    finally:
        # Always shut pygame down, even if a step/reset call raised.
        pygame.quit()

    print(f"Session ended. Final reward: {total_reward:.2f}")
169
+
170
+
171
def run_visual(env: DMControlEnv, task: str = "balance"):
    """Run with pygame visualization showing rendered frames.

    Resets the environment with ``render=True`` and exits the process with
    status 1 if the server returns no pixels. Otherwise the first frame
    (a base64-encoded PNG) determines the window size, and each loop
    iteration (capped at 30 FPS) sends a keyboard-derived action and draws
    the frame returned with the observation, when there is one.

    Args:
        env: Connected dm_control client.
        task: Cartpole task name passed to ``reset``.
    """
    import base64
    import io

    import pygame

    print("\n=== Visual Mode (OpenEnv Step/Observation Pattern) ===")

    # Reset environment with rendering enabled
    result = env.reset(domain_name="cartpole", task_name=task, render=True)
    print(f"Initial observations: {list(result.observation.observations.keys())}")

    # Get first frame to determine window size
    if result.observation.pixels is None:
        print("Error: Server did not return rendered pixels.")
        print("Make sure the server supports render=True")
        print("\nTry running in interactive mode (default) instead.")
        sys.exit(1)

    # Decode base64 PNG to pygame surface
    png_data = base64.b64decode(result.observation.pixels)
    frame = pygame.image.load(io.BytesIO(png_data))
    frame_size = frame.get_size()

    total_reward = 0.0
    step_count = 0

    # Initialize pygame
    pygame.init()
    try:
        screen = pygame.display.set_mode(frame_size)
        pygame.display.set_caption(
            "Cartpole (OpenEnv) - Arrow Keys to Move, R to Reset, ESC to Quit"
        )
        clock = pygame.time.Clock()

        print("Controls:")
        print(" LEFT/RIGHT arrows: Move cart")
        print(" R: Reset environment")
        print(" ESC or Q: Quit")

        running = True
        while running:
            # Handle events
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    running = False
                elif event.type == pygame.KEYDOWN:
                    if event.key in (pygame.K_ESCAPE, pygame.K_q):
                        running = False
                    elif event.key == pygame.K_r:
                        result = env.reset(
                            domain_name="cartpole", task_name=task, render=True
                        )
                        total_reward = 0.0
                        step_count = 0
                        print("Environment reset")

            # Check for held keys (for continuous control); LEFT wins if both are held
            keys = pygame.key.get_pressed()
            if keys[pygame.K_LEFT]:
                action_value = -1.0
            elif keys[pygame.K_RIGHT]:
                action_value = 1.0
            else:
                action_value = 0.0

            # Step the environment using OpenEnv pattern
            action = DMControlAction(values=[action_value])
            result = env.step(action, render=True)

            # Track reward from result (reward may be None)
            total_reward += result.reward or 0.0
            step_count += 1

            # Check if episode is done
            if result.done:
                print(
                    f"Episode finished! Steps: {step_count}, "
                    f"Total reward: {total_reward:.2f}"
                )
                result = env.reset(domain_name="cartpole", task_name=task, render=True)
                total_reward = 0.0
                step_count = 0

            # Render the frame from observation pixels (skipped when none were returned)
            if result.observation.pixels:
                png_data = base64.b64decode(result.observation.pixels)
                frame = pygame.image.load(io.BytesIO(png_data))
                screen.blit(frame, (0, 0))
                pygame.display.flip()

            # Print progress periodically
            if step_count % 200 == 0 and step_count > 0:
                print(f"Step {step_count}: Total reward: {total_reward:.2f}")

            # Cap at 30 FPS
            clock.tick(30)
    finally:
        # Always shut pygame down, even if a step/reset call raised.
        pygame.quit()

    print(f"Session ended. Final reward: {total_reward:.2f}")
272
+
273
+
274
def main():
    """Parse CLI options, connect to the local server and run the chosen mode.

    ``--headless`` takes precedence over ``--visual``; with neither flag the
    interactive mode runs. A ConnectionError prints start-up hints and exits
    with status 1.
    """
    cli = argparse.ArgumentParser(
        description="Interactive cartpole control via OpenEnv"
    )
    cli.add_argument(
        "--visual",
        action="store_true",
        help="Enable pygame visualization with rendered frames",
    )
    cli.add_argument(
        "--headless",
        action="store_true",
        help="Run in headless mode (no pygame, automated control)",
    )
    cli.add_argument(
        "--max-steps",
        type=int,
        default=500,
        help="Maximum steps for headless mode (default: 500)",
    )
    cli.add_argument(
        "--task",
        type=str,
        default="balance",
        choices=["balance", "balance_sparse", "swingup", "swingup_sparse"],
        help="Cartpole task (default: balance)",
    )
    options = cli.parse_args()

    server_url = "http://localhost:8000"
    print(f"Connecting to {server_url}...")

    try:
        with DMControlEnv(base_url=server_url) as env:
            print("Connected!")

            # Show what the server is currently configured for
            current = env.state()
            print(f"Domain: {current.domain_name}, Task: {current.task_name}")
            print(f"Action spec: {current.action_spec}")

            if options.headless:
                run_headless(env, task=options.task, max_steps=options.max_steps)
            else:
                run_mode = run_visual if options.visual else run_interactive
                run_mode(env, task=options.task)

    except ConnectionError as e:
        print(f"Failed to connect: {e}")
        hint_lines = (
            "\nMake sure the server is running:",
            " cd OpenEnv",
            " PYTHONPATH=src:envs uvicorn envs.dm_control_env.server.app:app --port 8000",
        )
        for hint in hint_lines:
            print(hint)
        sys.exit(1)
330
+
331
+
332
+ if __name__ == "__main__":
333
+ main()
examples/hopper_control.py ADDED
@@ -0,0 +1,374 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Interactive hopper control via OpenEnv.
3
+
4
+ This example demonstrates using the dm_control OpenEnv client with
5
+ the hopper environment. Press SPACE to apply random forces to the joints.
6
+
7
+ Controls:
8
+ SPACE: Apply random force to all joints
9
+ R: Reset environment
10
+ ESC or Q: Quit
11
+
12
+ Requirements:
13
+ pip install pygame
14
+
15
+ Usage:
16
+ 1. Start the server: uvicorn server.app:app --host 0.0.0.0 --port 8000
17
+ 2. Run this script: python examples/hopper_control.py
18
+
19
+ For visual mode (requires working MuJoCo rendering):
20
+ python examples/hopper_control.py --visual
21
+ """
22
+
23
+ import argparse
24
+ import random
25
+ import sys
26
+ from pathlib import Path
27
+
28
+ # Add parent directory to path for imports
29
+ sys.path.insert(0, str(Path(__file__).parent.parent))
30
+
31
+ from client import DMControlEnv
32
+ from models import DMControlAction
33
+
34
+
35
def get_action_dim(env: DMControlEnv) -> int:
    """Return how many action values the currently loaded environment takes.

    The value is the first entry of the ``shape`` list in the action spec
    reported by ``env.state()``. When the spec is empty, has no ``shape``
    key, or the shape is not a non-empty list, the hopper fallback is used.
    """
    # Fallback used when the server does not report a usable shape
    # (hopper: 4 actuators).
    fallback_dim = 4

    spec = env.state().action_spec
    if not spec or "shape" not in spec:
        return fallback_dim

    dims = spec["shape"]
    if not isinstance(dims, list) or len(dims) == 0:
        return fallback_dim
    return dims[0]
45
+
46
+
47
def generate_random_action(action_dim: int, magnitude: float = 1.0) -> DMControlAction:
    """Build an action of ``action_dim`` values drawn uniformly from [-magnitude, magnitude]."""
    samples = []
    for _ in range(action_dim):
        samples.append(random.uniform(-magnitude, magnitude))
    return DMControlAction(values=samples)
51
+
52
+
53
def generate_zero_action(action_dim: int) -> DMControlAction:
    """Build an action of ``action_dim`` zeros, i.e. no force on any actuator."""
    zeros = [0.0 for _ in range(action_dim)]
    return DMControlAction(values=zeros)
56
+
57
+
58
def run_headless(env: DMControlEnv, task: str = "hop", max_steps: int = 1000):
    """Run one automated hopper episode without any window.

    The episode is reset, then stepped until the server reports ``done`` or
    ``max_steps`` steps have been taken. During the first 5 steps of every
    30-step window a random action (magnitude 0.8) is sent; otherwise a zero
    action is sent. Progress is printed every 100 steps.

    Args:
        env: Connected client used for ``reset``/``step``/``state`` calls.
        task: Hopper task name passed to ``env.reset``.
        max_steps: Upper bound on the number of steps taken.
    """
    print("\n=== Headless Mode (OpenEnv Step/Observation Pattern) ===")
    print("This mode demonstrates the OpenEnv API with the hopper.\n")

    # Reset environment using OpenEnv pattern
    result = env.reset(domain_name="hopper", task_name=task)
    print(f"Initial observations: {list(result.observation.observations.keys())}")

    action_dim = get_action_dim(env)
    print(f"Action dimension: {action_dim}")

    total_reward = 0.0
    step_count = 0

    print("\nRunning with periodic random forces...")
    print("Every 30 steps, a random force burst is applied.\n")

    while not result.done and step_count < max_steps:
        if step_count % 30 < 5:
            # Random force burst for 5 steps
            action = generate_random_action(action_dim, magnitude=0.8)
        else:
            action = generate_zero_action(action_dim)

        # Step the environment using OpenEnv pattern
        result = env.step(action)

        # A step may carry no reward (None). Normalise once so that both the
        # running total and the ":.3f" formatting below are safe.
        reward = result.reward or 0.0
        total_reward += reward
        step_count += 1

        if step_count % 100 == 0:
            observations = result.observation.observations
            position = observations.get("position", [])
            velocity = observations.get("velocity", [])
            print(
                f"Step {step_count}: reward={reward:.3f}, "
                f"total={total_reward:.2f}, done={result.done}"
            )
            if position:
                print(f" position: {position[:3]}")
            if velocity:
                print(f" velocity: {velocity[:3]}")

    print(f"\nEpisode finished: {step_count} steps, total reward: {total_reward:.2f}")
108
+
109
+
110
def run_interactive(env: DMControlEnv, task: str = "hop"):
    """Drive the hopper from the keyboard through a small pygame status window.

    While SPACE is held a random action (magnitude 2.0) is sent each frame,
    otherwise a zero action is sent. R resets the episode; ESC, Q or closing
    the window ends the session. Finished episodes are reset automatically.
    The loop is capped at 30 iterations per second.

    Args:
        env: Connected client used for ``reset``/``step``/``state`` calls.
        task: Hopper task name passed to ``env.reset``.
    """
    import pygame

    print("\n=== Interactive Mode (OpenEnv Step/Observation Pattern) ===")
    print("Press SPACE to apply random force, R to reset, ESC to quit.\n")

    # Reset environment using OpenEnv pattern
    result = env.reset(domain_name="hopper", task_name=task)
    print(f"Initial observations: {list(result.observation.observations.keys())}")

    action_dim = get_action_dim(env)
    print(f"Action dimension: {action_dim}")

    total_reward = 0.0
    step_count = 0

    pygame.init()
    # try/finally: shut pygame down even when a reset/step call raises
    # mid-session, instead of leaving the window and SDL state behind.
    try:
        # Minimal window; it exists mainly so pygame can deliver key events.
        screen = pygame.display.set_mode((400, 100))
        pygame.display.set_caption(
            "Hopper Control - SPACE for random force, R to reset"
        )
        clock = pygame.time.Clock()
        font = pygame.font.Font(None, 24)

        print("\nControls:")
        print(" SPACE: Apply random force to joints")
        print(" R: Reset environment")
        print(" ESC or Q: Quit\n")

        running = True
        while running:
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    running = False
                elif event.type == pygame.KEYDOWN:
                    if event.key in (pygame.K_ESCAPE, pygame.K_q):
                        running = False
                    elif event.key == pygame.K_r:
                        result = env.reset(domain_name="hopper", task_name=task)
                        total_reward = 0.0
                        step_count = 0
                        print("Environment reset")

            # SPACE is polled as a held key so the force lasts as long as
            # the key stays down, not just on the key-down event.
            apply_random_force = pygame.key.get_pressed()[pygame.K_SPACE]

            if apply_random_force:
                action = generate_random_action(action_dim, magnitude=2.0)
            else:
                action = generate_zero_action(action_dim)

            # Step the environment using OpenEnv pattern
            result = env.step(action)

            total_reward += result.reward or 0.0
            step_count += 1

            if result.done:
                print(
                    f"Episode finished! Steps: {step_count}, "
                    f"Total reward: {total_reward:.2f}"
                )
                # Auto-reset on done
                result = env.reset(domain_name="hopper", task_name=task)
                total_reward = 0.0
                step_count = 0

            # Redraw the one-line status readout.
            screen.fill((30, 30, 30))
            status = "FORCE!" if apply_random_force else "idle"
            text = font.render(
                f"Step: {step_count} | Reward: {total_reward:.1f} | {status}",
                True,
                (255, 255, 255),
            )
            screen.blit(text, (10, 40))
            pygame.display.flip()

            if step_count % 200 == 0 and step_count > 0:
                print(f"Step {step_count}: Total reward: {total_reward:.2f}")

            # Cap at 30 FPS
            clock.tick(30)
    finally:
        pygame.quit()

    print(f"Session ended. Final reward: {total_reward:.2f}")
206
+
207
+
208
def run_visual(env: DMControlEnv, task: str = "hop"):
    """Drive the hopper from the keyboard while showing server-rendered frames.

    Every reset/step is requested with ``render=True``; the returned
    ``observation.pixels`` value is base64-decoded and loaded as an image
    that is drawn into a window sized to the first frame. Controls match the
    interactive mode: hold SPACE for random actions, R to reset, ESC/Q to quit.
    Exits with status 1 when the first reset carries no pixels.

    Args:
        env: Connected client used for ``reset``/``step``/``state`` calls.
        task: Hopper task name passed to ``env.reset``.
    """
    import base64
    import io

    import pygame

    def decode_frame(encoded):
        """Turn the base64 image string from the server into a pygame surface."""
        return pygame.image.load(io.BytesIO(base64.b64decode(encoded)))

    print("\n=== Visual Mode (OpenEnv Step/Observation Pattern) ===")

    # Reset environment with rendering enabled
    result = env.reset(domain_name="hopper", task_name=task, render=True)
    print(f"Initial observations: {list(result.observation.observations.keys())}")

    action_dim = get_action_dim(env)
    print(f"Action dimension: {action_dim}")

    # The first frame fixes the window size. An empty string is treated like
    # a missing frame so it gets the same clear message instead of a decoder
    # error.
    if not result.observation.pixels:
        print("Error: Server did not return rendered pixels.")
        print("Make sure the server supports render=True")
        print("\nTry running in interactive mode (default) instead.")
        sys.exit(1)

    frame_size = decode_frame(result.observation.pixels).get_size()

    total_reward = 0.0
    step_count = 0

    pygame.init()
    # try/finally: shut pygame down even when a reset/step call or a frame
    # decode raises mid-session.
    try:
        screen = pygame.display.set_mode(frame_size)
        pygame.display.set_caption(
            "Hopper (OpenEnv) - SPACE for random force, R to Reset, ESC to Quit"
        )
        clock = pygame.time.Clock()

        print("Controls:")
        print(" SPACE: Apply random force to joints")
        print(" R: Reset environment")
        print(" ESC or Q: Quit")

        running = True
        while running:
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    running = False
                elif event.type == pygame.KEYDOWN:
                    if event.key in (pygame.K_ESCAPE, pygame.K_q):
                        running = False
                    elif event.key == pygame.K_r:
                        result = env.reset(
                            domain_name="hopper", task_name=task, render=True
                        )
                        total_reward = 0.0
                        step_count = 0
                        print("Environment reset")

            # SPACE is polled as a held key so the force lasts while it is down.
            apply_random_force = pygame.key.get_pressed()[pygame.K_SPACE]

            if apply_random_force:
                action = generate_random_action(action_dim, magnitude=2.0)
            else:
                action = generate_zero_action(action_dim)

            # Step the environment using OpenEnv pattern
            result = env.step(action, render=True)

            total_reward += result.reward or 0.0
            step_count += 1

            if result.done:
                print(
                    f"Episode finished! Steps: {step_count}, "
                    f"Total reward: {total_reward:.2f}"
                )
                result = env.reset(domain_name="hopper", task_name=task, render=True)
                total_reward = 0.0
                step_count = 0

            # Only redraw when this result actually carries a frame.
            if result.observation.pixels:
                screen.blit(decode_frame(result.observation.pixels), (0, 0))
                pygame.display.flip()

            if step_count % 200 == 0 and step_count > 0:
                print(f"Step {step_count}: Total reward: {total_reward:.2f}")

            # Cap at 30 FPS
            clock.tick(30)
    finally:
        pygame.quit()

    print(f"Session ended. Final reward: {total_reward:.2f}")
313
+
314
+
315
def main():
    """Parse the command line, connect to the local server and run one mode.

    ``--headless`` takes precedence over ``--visual``; with neither flag the
    keyboard-driven interactive mode is used. A ``ConnectionError`` raised
    while connecting or running prints start-up hints and exits with status 1.
    """
    cli = argparse.ArgumentParser(
        description="Interactive hopper control via OpenEnv"
    )
    cli.add_argument(
        "--visual",
        help="Enable pygame visualization with rendered frames",
        action="store_true",
    )
    cli.add_argument(
        "--headless",
        help="Run in headless mode (no pygame, automated control)",
        action="store_true",
    )
    cli.add_argument(
        "--max-steps",
        help="Maximum steps for headless mode (default: 1000)",
        type=int,
        default=1000,
    )
    cli.add_argument(
        "--task",
        help="Hopper task (default: hop)",
        type=str,
        default="hop",
        choices=["stand", "hop"],
    )
    options = cli.parse_args()

    server_url = "http://localhost:8000"
    print(f"Connecting to {server_url}...")

    try:
        with DMControlEnv(base_url=server_url) as env:
            print("Connected!")

            # Report what the server currently has loaded before any reset.
            info = env.state()
            print(f"Domain: {info.domain_name}, Task: {info.task_name}")
            print(f"Action spec: {info.action_spec}")

            if options.headless:
                run_headless(env, task=options.task, max_steps=options.max_steps)
            elif not options.visual:
                run_interactive(env, task=options.task)
            else:
                run_visual(env, task=options.task)

    except ConnectionError as err:
        print(f"Failed to connect: {err}")
        print("\nMake sure the server is running:")
        print(" cd OpenEnv")
        print(
            " PYTHONPATH=src:envs uvicorn envs.dm_control_env.server.app:app --port 8000"
        )
        sys.exit(1)


if __name__ == "__main__":
    main()
examples/list_environments.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """List all available dm_control.suite environments.
3
+
4
+ This utility prints all available domain/task combinations from dm_control.suite.
5
+ """
6
+
7
+ from dm_control import suite
8
+
9
+
10
def main():
    """Print every dm_control.suite domain/task pair with its action size.

    Tasks are grouped by domain and printed in sorted order. Each environment
    is loaded once to read its action dimension and observation keys; a
    failure for one task is reported on that task's line and does not stop
    the listing.
    """
    # ALL_TASKS rather than BENCHMARKING: this tool advertises *all*
    # available environments, and BENCHMARKING is only a tagged subset.
    all_tasks = suite.ALL_TASKS

    print("Available dm_control.suite environments:")
    print("=" * 50)

    # Group by domain
    domains = {}
    for domain, task in all_tasks:
        domains.setdefault(domain, []).append(task)

    for domain in sorted(domains):
        print(f"\n{domain}:")
        for task in sorted(domains[domain]):
            # Deliberately broad: this is a best-effort listing, so any load
            # or spec error is shown inline rather than aborting the run.
            try:
                env = suite.load(domain_name=domain, task_name=task)
                try:
                    action_dim = env.action_spec().shape[0]
                    obs_keys = list(env.observation_spec().keys())
                finally:
                    # Release the environment even if reading a spec failed.
                    env.close()
                print(f" - {task:20s} (action_dim={action_dim}, obs={obs_keys})")
            except Exception as e:
                print(f" - {task:20s} (error: {e})")

    print("\n" + "=" * 50)
    print(f"Total: {len(all_tasks)} environments")


if __name__ == "__main__":
    main()
examples/quadruped_control.py ADDED
@@ -0,0 +1,373 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Interactive quadruped control via OpenEnv.
3
+
4
+ This example demonstrates using the dm_control OpenEnv client with
5
+ the quadruped environment. Press SPACE to apply random forces to the joints.
6
+
7
+ Controls:
8
+ SPACE: Apply random force to all joints
9
+ R: Reset environment
10
+ ESC or Q: Quit
11
+
12
+ Requirements:
13
+ pip install pygame
14
+
15
+ Usage:
16
+ 1. Start the server: uvicorn server.app:app --host 0.0.0.0 --port 8000
17
+ 2. Run this script: python examples/quadruped_control.py
18
+
19
+ For visual mode (requires working MuJoCo rendering):
20
+ python examples/quadruped_control.py --visual
21
+ """
22
+
23
+ import argparse
24
+ import random
25
+ import sys
26
+ from pathlib import Path
27
+
28
+ # Add parent directory to path for imports
29
+ sys.path.insert(0, str(Path(__file__).parent.parent))
30
+
31
+ from client import DMControlEnv
32
+ from models import DMControlAction
33
+
34
+
35
def get_action_dim(env: DMControlEnv) -> int:
    """Return how many action values the currently loaded environment takes.

    The value is the first entry of the ``shape`` list in the action spec
    reported by ``env.state()``. When the spec is empty, has no ``shape``
    key, or the shape is not a non-empty list, the quadruped fallback is used.
    """
    # Fallback used when the server does not report a usable shape
    # (quadruped: 12 actuators).
    fallback_dim = 12

    spec = env.state().action_spec
    if not spec or "shape" not in spec:
        return fallback_dim

    dims = spec["shape"]
    if not isinstance(dims, list) or len(dims) == 0:
        return fallback_dim
    return dims[0]
45
+
46
+
47
def generate_random_action(action_dim: int, magnitude: float = 1.0) -> DMControlAction:
    """Build an action of ``action_dim`` values drawn uniformly from [-magnitude, magnitude]."""
    samples = []
    for _ in range(action_dim):
        samples.append(random.uniform(-magnitude, magnitude))
    return DMControlAction(values=samples)
51
+
52
+
53
def generate_zero_action(action_dim: int) -> DMControlAction:
    """Build an action of ``action_dim`` zeros, i.e. no force on any actuator."""
    zeros = [0.0 for _ in range(action_dim)]
    return DMControlAction(values=zeros)
56
+
57
+
58
def run_headless(env: DMControlEnv, max_steps: int = 1000, task: str = "walk"):
    """Run one automated quadruped episode without any window.

    The episode is reset, then stepped until the server reports ``done`` or
    ``max_steps`` steps have been taken. During the first 10 steps of every
    50-step window a random action (magnitude 0.5) is sent; otherwise a zero
    action is sent. Progress is printed every 100 steps.

    Args:
        env: Connected client used for ``reset``/``step``/``state`` calls.
        max_steps: Upper bound on the number of steps taken.
        task: Quadruped task name passed to ``env.reset``. Added after
            ``max_steps`` so existing positional callers keep working.
    """
    print("\n=== Headless Mode (OpenEnv Step/Observation Pattern) ===")
    print("This mode demonstrates the OpenEnv API with the quadruped.\n")

    # Reset environment using OpenEnv pattern
    result = env.reset(domain_name="quadruped", task_name=task)
    print(f"Initial observations: {list(result.observation.observations.keys())}")

    action_dim = get_action_dim(env)
    print(f"Action dimension: {action_dim}")

    total_reward = 0.0
    step_count = 0

    print("\nRunning with periodic random forces...")
    print("Every 50 steps, a random force burst is applied.\n")

    while not result.done and step_count < max_steps:
        if step_count % 50 < 10:
            # Random force burst for 10 steps
            action = generate_random_action(action_dim, magnitude=0.5)
        else:
            action = generate_zero_action(action_dim)

        # Step the environment using OpenEnv pattern
        result = env.step(action)

        # A step may carry no reward (None). Normalise once so that both the
        # running total and the ":.3f" formatting below are safe.
        reward = result.reward or 0.0
        total_reward += reward
        step_count += 1

        if step_count % 100 == 0:
            egocentric_state = result.observation.observations.get(
                "egocentric_state", []
            )
            print(
                f"Step {step_count}: reward={reward:.3f}, "
                f"total={total_reward:.2f}, done={result.done}"
            )
            if egocentric_state:
                print(f" egocentric_state (first 5): {egocentric_state[:5]}")

    print(f"\nEpisode finished: {step_count} steps, total reward: {total_reward:.2f}")


def run_interactive(env: DMControlEnv, task: str = "walk"):
    """Drive the quadruped from the keyboard through a small pygame status window.

    While SPACE is held a random action (magnitude 2.0) is sent each frame,
    otherwise a zero action is sent. R resets the episode; ESC, Q or closing
    the window ends the session. Finished episodes are reset automatically.
    The loop is capped at 30 iterations per second.

    Args:
        env: Connected client used for ``reset``/``step``/``state`` calls.
        task: Quadruped task name passed to ``env.reset``.
    """
    import pygame

    print("\n=== Interactive Mode (OpenEnv Step/Observation Pattern) ===")
    print("Press SPACE to apply random force, R to reset, ESC to quit.\n")

    # Reset environment using OpenEnv pattern
    result = env.reset(domain_name="quadruped", task_name=task)
    print(f"Initial observations: {list(result.observation.observations.keys())}")

    action_dim = get_action_dim(env)
    print(f"Action dimension: {action_dim}")

    total_reward = 0.0
    step_count = 0

    pygame.init()
    # try/finally: shut pygame down even when a reset/step call raises
    # mid-session, instead of leaving the window and SDL state behind.
    try:
        # Minimal window; it exists mainly so pygame can deliver key events.
        screen = pygame.display.set_mode((400, 100))
        pygame.display.set_caption(
            "Quadruped Control - SPACE for random force, R to reset"
        )
        clock = pygame.time.Clock()
        font = pygame.font.Font(None, 24)

        print("\nControls:")
        print(" SPACE: Apply random force to joints")
        print(" R: Reset environment")
        print(" ESC or Q: Quit\n")

        running = True
        while running:
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    running = False
                elif event.type == pygame.KEYDOWN:
                    if event.key in (pygame.K_ESCAPE, pygame.K_q):
                        running = False
                    elif event.key == pygame.K_r:
                        result = env.reset(domain_name="quadruped", task_name=task)
                        total_reward = 0.0
                        step_count = 0
                        print("Environment reset")

            # SPACE is polled as a held key so the force lasts while it is down.
            apply_random_force = pygame.key.get_pressed()[pygame.K_SPACE]

            if apply_random_force:
                action = generate_random_action(action_dim, magnitude=2.0)
            else:
                action = generate_zero_action(action_dim)

            # Step the environment using OpenEnv pattern
            result = env.step(action)

            total_reward += result.reward or 0.0
            step_count += 1

            if result.done:
                print(
                    f"Episode finished! Steps: {step_count}, "
                    f"Total reward: {total_reward:.2f}"
                )
                # Auto-reset on done
                result = env.reset(domain_name="quadruped", task_name=task)
                total_reward = 0.0
                step_count = 0

            # Redraw the one-line status readout.
            screen.fill((30, 30, 30))
            status = "FORCE!" if apply_random_force else "idle"
            text = font.render(
                f"Step: {step_count} | Reward: {total_reward:.1f} | {status}",
                True,
                (255, 255, 255),
            )
            screen.blit(text, (10, 40))
            pygame.display.flip()

            if step_count % 200 == 0 and step_count > 0:
                print(f"Step {step_count}: Total reward: {total_reward:.2f}")

            # Cap at 30 FPS
            clock.tick(30)
    finally:
        pygame.quit()

    print(f"Session ended. Final reward: {total_reward:.2f}")


def run_visual(env: DMControlEnv, task: str = "walk"):
    """Drive the quadruped from the keyboard while showing server-rendered frames.

    Every reset/step is requested with ``render=True``; the returned
    ``observation.pixels`` value is base64-decoded and loaded as an image
    that is drawn into a window sized to the first frame. Controls match the
    interactive mode: hold SPACE for random actions, R to reset, ESC/Q to quit.
    Exits with status 1 when the first reset carries no pixels.

    Args:
        env: Connected client used for ``reset``/``step``/``state`` calls.
        task: Quadruped task name passed to ``env.reset``.
    """
    import base64
    import io

    import pygame

    def decode_frame(encoded):
        """Turn the base64 image string from the server into a pygame surface."""
        return pygame.image.load(io.BytesIO(base64.b64decode(encoded)))

    print("\n=== Visual Mode (OpenEnv Step/Observation Pattern) ===")

    # Reset environment with rendering enabled
    result = env.reset(domain_name="quadruped", task_name=task, render=True)
    print(f"Initial observations: {list(result.observation.observations.keys())}")

    action_dim = get_action_dim(env)
    print(f"Action dimension: {action_dim}")

    # The first frame fixes the window size. An empty string is treated like
    # a missing frame so it gets the same clear message instead of a decoder
    # error.
    if not result.observation.pixels:
        print("Error: Server did not return rendered pixels.")
        print("Make sure the server supports render=True")
        print("\nTry running in interactive mode (default) instead.")
        sys.exit(1)

    frame_size = decode_frame(result.observation.pixels).get_size()

    total_reward = 0.0
    step_count = 0

    pygame.init()
    # try/finally: shut pygame down even when a reset/step call or a frame
    # decode raises mid-session.
    try:
        screen = pygame.display.set_mode(frame_size)
        pygame.display.set_caption(
            "Quadruped (OpenEnv) - SPACE for random force, R to Reset, ESC to Quit"
        )
        clock = pygame.time.Clock()

        print("Controls:")
        print(" SPACE: Apply random force to joints")
        print(" R: Reset environment")
        print(" ESC or Q: Quit")

        running = True
        while running:
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    running = False
                elif event.type == pygame.KEYDOWN:
                    if event.key in (pygame.K_ESCAPE, pygame.K_q):
                        running = False
                    elif event.key == pygame.K_r:
                        result = env.reset(
                            domain_name="quadruped", task_name=task, render=True
                        )
                        total_reward = 0.0
                        step_count = 0
                        print("Environment reset")

            # SPACE is polled as a held key so the force lasts while it is down.
            apply_random_force = pygame.key.get_pressed()[pygame.K_SPACE]

            if apply_random_force:
                action = generate_random_action(action_dim, magnitude=2.0)
            else:
                action = generate_zero_action(action_dim)

            # Step the environment using OpenEnv pattern
            result = env.step(action, render=True)

            total_reward += result.reward or 0.0
            step_count += 1

            if result.done:
                print(
                    f"Episode finished! Steps: {step_count}, "
                    f"Total reward: {total_reward:.2f}"
                )
                result = env.reset(
                    domain_name="quadruped", task_name=task, render=True
                )
                total_reward = 0.0
                step_count = 0

            # Only redraw when this result actually carries a frame.
            if result.observation.pixels:
                screen.blit(decode_frame(result.observation.pixels), (0, 0))
                pygame.display.flip()

            if step_count % 200 == 0 and step_count > 0:
                print(f"Step {step_count}: Total reward: {total_reward:.2f}")

            # Cap at 30 FPS
            clock.tick(30)
    finally:
        pygame.quit()

    print(f"Session ended. Final reward: {total_reward:.2f}")


def main():
    """Parse the command line, connect to the local server and run one mode.

    ``--headless`` takes precedence over ``--visual``; with neither flag the
    keyboard-driven interactive mode is used. The ``--task`` choice is
    forwarded to whichever mode runs. A ``ConnectionError`` raised while
    connecting or running prints start-up hints and exits with status 1.
    """
    parser = argparse.ArgumentParser(
        description="Interactive quadruped control via OpenEnv"
    )
    parser.add_argument(
        "--visual",
        action="store_true",
        help="Enable pygame visualization with rendered frames",
    )
    parser.add_argument(
        "--headless",
        action="store_true",
        help="Run in headless mode (no pygame, automated control)",
    )
    parser.add_argument(
        "--max-steps",
        type=int,
        default=1000,
        help="Maximum steps for headless mode (default: 1000)",
    )
    parser.add_argument(
        "--task",
        type=str,
        default="walk",
        choices=["walk", "run", "escape", "fetch"],
        help="Quadruped task (default: walk)",
    )
    args = parser.parse_args()

    server_url = "http://localhost:8000"
    print(f"Connecting to {server_url}...")

    try:
        with DMControlEnv(base_url=server_url) as env:
            print("Connected!")

            # Report what the server currently has loaded before any reset.
            state = env.state()
            print(f"Domain: {state.domain_name}, Task: {state.task_name}")
            print(f"Action spec: {state.action_spec}")

            # Forward --task; it used to be parsed and then ignored, so
            # every mode silently ran "walk".
            if args.headless:
                run_headless(env, max_steps=args.max_steps, task=args.task)
            elif args.visual:
                run_visual(env, task=args.task)
            else:
                run_interactive(env, task=args.task)

    except ConnectionError as e:
        print(f"Failed to connect: {e}")
        print("\nMake sure the server is running:")
        print(" cd OpenEnv")
        print(
            " PYTHONPATH=src:envs uvicorn envs.dm_control_env.server.app:app --port 8000"
        )
        sys.exit(1)


if __name__ == "__main__":
    main()
models.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Data models for the dm_control OpenEnv Environment.
9
+
10
+ This environment wraps dm_control.suite, providing access to all MuJoCo-based
11
+ continuous control tasks (cartpole, walker, humanoid, cheetah, etc.).
12
+ """
13
+
14
+ from typing import Any, Dict, List, Optional
15
+
16
+ from pydantic import Field
17
+
18
# A single import: retrying the identical module path under
# ``except ImportError`` could never succeed where the first attempt failed.
from openenv.core.env_server.types import Action, Observation, State
22
+
23
+
24
class DMControlAction(Action):
    """
    Continuous control command sent to a dm_control environment.

    The command is a flat list of floats. How many entries it needs, and
    which range each entry may take, is defined by the action_spec of the
    domain/task that is currently loaded.

    Example (cartpole - 1D action):
        >>> action = DMControlAction(values=[0.5])  # Push cart right

    Example (walker - 6D action):
        >>> action = DMControlAction(values=[0.1, -0.2, 0.3, 0.0, -0.1, 0.2])

    Attributes:
        values: The continuous action values; defaults to an empty list.
    """

    # default_factory gives every instance its own fresh list.
    values: List[float] = Field(
        description="Continuous action values matching the environment's action_spec",
        default_factory=list,
    )
47
+
48
+
49
class DMControlObservation(Observation):
    """
    What a dm_control environment reports after a reset or a step.

    Observations arrive as named arrays; which names are present depends on
    the domain/task combination. Typical names are 'position', 'velocity'
    and 'orientations'.

    Example observation keys by domain:
        - cartpole: 'position' (cos/sin of angle), 'velocity'
        - walker: 'orientations', 'height', 'velocity'
        - humanoid: 'joint_angles', 'head_height', 'extremities', 'torso_vertical', 'com_velocity'

    Attributes:
        observations: Maps each observation name to a flattened list of
            floats; defaults to an empty dict.
        pixels: Base64-encoded PNG of the rendered scene, or None. Only
            populated when render=True is passed to reset/step.
    """

    # default_factory gives every instance its own fresh dict.
    observations: Dict[str, List[float]] = Field(
        description="Named observation arrays from the environment",
        default_factory=dict,
    )
    pixels: Optional[str] = Field(
        description="Base64-encoded PNG image (when render=True)",
        default=None,
    )
77
+
78
+
79
class DMControlState(State):
    """
    Metadata describing the dm_control environment that is currently loaded.

    On top of the fields declared here, the original documentation lists
    ``episode_id`` and ``step_count`` as attributes; they are not declared
    in this class, so presumably they come from the ``State`` base class.

    Attributes:
        domain_name: The dm_control domain (e.g., 'cartpole', 'walker').
        task_name: The specific task (e.g., 'balance', 'walk').
        action_spec: Specification of the action space including shape and bounds.
        observation_spec: Specification of the observation space.
        physics_timestep: The physics simulation timestep in seconds.
        control_timestep: The control timestep (time between actions) in seconds.
    """

    # Which environment is loaded; the defaults name cartpole/balance.
    domain_name: str = Field(
        description="The dm_control domain name",
        default="cartpole",
    )
    task_name: str = Field(
        description="The task name within the domain",
        default="balance",
    )

    # Free-form spec dictionaries; empty until populated.
    action_spec: Dict[str, Any] = Field(
        description="Specification of the action space (shape, dtype, bounds)",
        default_factory=dict,
    )
    observation_spec: Dict[str, Any] = Field(
        description="Specification of the observation space",
        default_factory=dict,
    )

    # Timing, both in seconds.
    physics_timestep: float = Field(
        description="Physics simulation timestep in seconds",
        default=0.002,
    )
    control_timestep: float = Field(
        description="Control timestep (time between actions) in seconds",
        default=0.02,
    )
121
+
122
+
123
+ # Available dm_control.suite environments
124
+ # Format: (domain_name, task_name)
125
# Task names per domain. Dict insertion order and tuple order together
# determine the order of AVAILABLE_ENVIRONMENTS below.
_TASKS_BY_DOMAIN = {
    "cartpole": ("balance", "balance_sparse", "swingup", "swingup_sparse"),
    "pendulum": ("swingup",),
    "point_mass": ("easy", "hard"),
    "reacher": ("easy", "hard"),
    "ball_in_cup": ("catch",),
    "finger": ("spin", "turn_easy", "turn_hard"),
    "fish": ("upright", "swim"),
    "cheetah": ("run",),
    "walker": ("stand", "walk", "run"),
    "hopper": ("stand", "hop"),
    "swimmer": ("swimmer6", "swimmer15"),
    "humanoid": ("stand", "walk", "run"),
    "manipulator": ("bring_ball", "bring_peg", "insert_ball", "insert_peg"),
    "acrobot": ("swingup", "swingup_sparse"),
    "stacker": ("stack_2", "stack_4"),
    "dog": ("stand", "walk", "trot", "run", "fetch"),
    "quadruped": ("walk", "run", "escape", "fetch"),
}

# Flattened into (domain_name, task_name) pairs.
AVAILABLE_ENVIRONMENTS = [
    (domain, task)
    for domain, tasks in _TASKS_BY_DOMAIN.items()
    for task in tasks
]
openenv.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ spec_version: 1
2
+ name: dm_control_env
3
+ type: space
4
+ runtime: fastapi
5
+ app: server.app:app
6
+ port: 8000
pyproject.toml ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ [build-system]
8
+ requires = ["setuptools>=45", "wheel"]
9
+ build-backend = "setuptools.build_meta"
10
+
11
+ [project]
12
+ name = "openenv-dmcontrol-env"
13
+ version = "0.1.0"
14
+ description = "dm_control Environment for OpenEnv - wraps MuJoCo-based continuous control tasks (cartpole, walker, humanoid, etc.)"
15
+ requires-python = ">=3.10"
16
+ dependencies = [
17
+ # Core OpenEnv dependencies
18
+ "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git",
19
+ "fastapi>=0.115.0",
20
+ "pydantic>=2.0.0",
21
+ "uvicorn>=0.24.0",
22
+ "requests>=2.31.0",
23
+ # dm_control dependencies
24
+ "mujoco>=3.0.0",
25
+ "dm_control>=1.0.0",
26
+ "numpy>=1.20.0",
27
+ # Image encoding for pixel observations (only used when render=True)
28
+ "pillow>=9.0.0",
29
+ ]
30
+
31
+ [project.optional-dependencies]
32
+ dev = [
33
+ "pytest>=8.0.0",
34
+ "pytest-cov>=4.0.0",
35
+ ]
36
+ interactive = [
37
+ # For interactive examples with keyboard control
38
+ "pygame>=2.0.0",
39
+ ]
40
+
41
+ [project.scripts]
42
+ # Server entry point - enables running via: uv run --project . server
43
+ server = "dm_control_env.server.app:main"
44
+
45
+ [tool.setuptools]
46
+ include-package-data = true
47
+ packages = ["dm_control_env", "dm_control_env.server"]
48
+ package-dir = { "dm_control_env" = ".", "dm_control_env.server" = "server" }
server/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """dm_control OpenEnv server module."""
2
+
3
+ from .dm_control_environment import DMControlEnvironment
4
+
5
+ __all__ = ["DMControlEnvironment"]
server/app.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ FastAPI application for the dm_control Environment.
9
+
10
+ This module creates an HTTP server that exposes dm_control.suite environments
11
+ over HTTP and WebSocket endpoints, compatible with EnvClient.
12
+
13
+ Usage:
14
+ # Development (with auto-reload):
15
+ uvicorn server.app:app --reload --host 0.0.0.0 --port 8000
16
+
17
+ # Production:
18
+ uvicorn server.app:app --host 0.0.0.0 --port 8000
19
+
20
+ # Or run directly:
21
+ uv run --project . server
22
+ """
23
+
24
+ try:
25
+ from openenv.core.env_server.http_server import create_app
26
+
27
+ from ..models import DMControlAction, DMControlObservation
28
+ from .dm_control_environment import DMControlEnvironment
29
+ except ImportError:
30
+ from openenv.core.env_server.http_server import create_app
31
+
32
+ try:
33
+ import sys
34
+ from pathlib import Path
35
+
36
+ _parent = str(Path(__file__).parent.parent)
37
+ if _parent not in sys.path:
38
+ sys.path.insert(0, _parent)
39
+ from models import DMControlAction, DMControlObservation
40
+ from server.dm_control_environment import DMControlEnvironment
41
+ except ImportError:
42
+ try:
43
+ from dm_control_env.models import DMControlAction, DMControlObservation
44
+ from dm_control_env.server.dm_control_environment import (
45
+ DMControlEnvironment,
46
+ )
47
+ except ImportError:
48
+ from envs.dm_control_env.models import DMControlAction, DMControlObservation
49
+ from envs.dm_control_env.server.dm_control_environment import (
50
+ DMControlEnvironment,
51
+ )
52
+
53
# Build the application object served by uvicorn (includes the web interface).
# The environment *class* is handed over rather than an instance so that it
# can act as a factory for concurrent sessions.
app = create_app(
    DMControlEnvironment, DMControlAction, DMControlObservation, env_name="dm_control_env"
)
61
+
62
+
63
def main(host: str = "0.0.0.0", port: int = 8000) -> None:
    """
    Entry point for direct execution via uv run or python -m.

    This function enables running the server without Docker:
        uv run --project . server
        python -m envs.dm_control_env.server.app
        openenv serve dm_control_env

    Args:
        host: Network interface to bind to. The default binds to all interfaces.
        port: TCP port to listen on. Defaults to 8000.

    Calling ``main()`` without arguments (as the ``server`` console script
    does) serves on 0.0.0.0:8000.
    """
    # Imported lazily so that importing this module (e.g. ``uvicorn server.app:app``)
    # does not require going through this entry point.
    import uvicorn

    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    main()
server/dm_control_environment.py ADDED
@@ -0,0 +1,428 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ dm_control Environment Implementation.
9
+
10
+ Wraps dm_control.suite environments (cartpole, walker, humanoid, etc.)
11
+ with the OpenEnv interface for standardized reinforcement learning.
12
+ """
13
+
14
+ import base64
15
+ import io
16
+ import os
17
+ import sys
18
+ from typing import Any, Dict, Optional
19
+ from uuid import uuid4
20
+
21
# Choose the MuJoCo rendering backend. This has to happen before dm_control
# is imported.
# - macOS: MUJOCO_GL is left alone so the default backend (glfw) is used; it
#   works when running synchronously in the main thread
#   (see reset_async/step_async).
# - every other platform: default to EGL for headless rendering, unless the
#   caller already exported MUJOCO_GL (setdefault never overrides it).
if sys.platform != "darwin":
    os.environ.setdefault("MUJOCO_GL", "egl")
27
+
28
+ import numpy as np
29
+
30
+ try:
31
+ from openenv.core.env_server.interfaces import Environment
32
+
33
+ from ..models import DMControlAction, DMControlObservation, DMControlState
34
+ except ImportError:
35
+ from openenv.core.env_server.interfaces import Environment
36
+
37
+ try:
38
+ import sys
39
+ from pathlib import Path
40
+
41
+ _parent = str(Path(__file__).parent.parent)
42
+ if _parent not in sys.path:
43
+ sys.path.insert(0, _parent)
44
+ from models import DMControlAction, DMControlObservation, DMControlState
45
+ except ImportError:
46
+ try:
47
+ from dm_control_env.models import (
48
+ DMControlAction,
49
+ DMControlObservation,
50
+ DMControlState,
51
+ )
52
+ except ImportError:
53
+ from envs.dm_control_env.models import (
54
+ DMControlAction,
55
+ DMControlObservation,
56
+ DMControlState,
57
+ )
58
+
59
+
60
class DMControlEnvironment(Environment):
    """
    Wraps dm_control.suite environments with the OpenEnv interface.

    This environment supports all dm_control.suite domains and tasks including
    cartpole, walker, humanoid, cheetah, and more.

    Features:
    - Dynamic environment switching via reset(domain_name="...", task_name="...")
    - Support for all continuous control tasks
    - Optional visual observations (base64-encoded images)
    - Configurable via constructor or environment variables
    - Seeded episodes via reset(seed=...)

    Example:
        >>> env = DMControlEnvironment()
        >>> obs = env.reset()  # Default: cartpole/balance
        >>> print(obs.observations)
        >>>
        >>> # Take an action
        >>> obs = env.step(DMControlAction(values=[0.5]))  # Push cart right
        >>> print(obs.reward)

    Example with different environment:
        >>> env = DMControlEnvironment(domain_name="walker", task_name="walk")
        >>> obs = env.reset()
        >>>
        >>> # Or switch environment on reset
        >>> obs = env.reset(domain_name="cheetah", task_name="run")
    """

    # dm_control environments are isolated and thread-safe
    SUPPORTS_CONCURRENT_SESSIONS = True

    def __init__(
        self,
        domain_name: Optional[str] = None,
        task_name: Optional[str] = None,
        render_height: Optional[int] = None,
        render_width: Optional[int] = None,
    ):
        """
        Initialize the dm_control environment.

        The simulator itself is created lazily on the first reset().

        Args:
            domain_name: The dm_control domain to load.
                Env var: DMCONTROL_DOMAIN (default: cartpole)
            task_name: The task within the domain.
                Env var: DMCONTROL_TASK (default: balance)
            render_height: Height of rendered images (when render=True).
                Env var: DMCONTROL_RENDER_HEIGHT (default: 480)
            render_width: Width of rendered images (when render=True).
                Env var: DMCONTROL_RENDER_WIDTH (default: 640)
        """
        self._env = None

        # Explicit arguments win over environment variables, which win over
        # the built-in defaults.
        self._domain_name = domain_name or os.environ.get(
            "DMCONTROL_DOMAIN", "cartpole"
        )
        self._task_name = task_name or os.environ.get("DMCONTROL_TASK", "balance")
        self._render_height = (
            render_height
            if render_height is not None
            else int(os.environ.get("DMCONTROL_RENDER_HEIGHT", "480"))
        )
        self._render_width = (
            render_width
            if render_width is not None
            else int(os.environ.get("DMCONTROL_RENDER_WIDTH", "640"))
        )
        # Remembered from the last reset(render=...) and applied to later steps.
        self._include_pixels = False

        self._state = DMControlState(
            episode_id=str(uuid4()),
            step_count=0,
            domain_name=self._domain_name,
            task_name=self._task_name,
        )

    def _load_environment(self, domain_name: str, task_name: str) -> None:
        """Load or switch to a dm_control environment.

        The new environment is created before the current one is released, so
        a failed load leaves the previously loaded environment (if any)
        untouched and usable.

        Args:
            domain_name: dm_control.suite domain to load.
            task_name: Task within that domain.

        Raises:
            ImportError: If dm_control is not installed.
            RuntimeError: On macOS, if importing dm_control fails for a reason
                other than a missing package, or if loading fails with a
                message that looks like an OpenGL/rendering problem.
            ValueError: If the environment cannot be loaded for any other
                reason (for example an unknown domain/task pair).
        """
        try:
            from dm_control import suite
        except ImportError as e:
            raise ImportError(
                "dm_control is required. Install with: pip install dm_control"
            ) from e
        except Exception as e:
            # MuJoCo/OpenGL initialization can fail on macOS
            if sys.platform == "darwin":
                raise RuntimeError(
                    f"Failed to import dm_control (MuJoCo error): {e}\n\n"
                    "On macOS, try one of these solutions:\n"
                    "1. Install osmesa: brew install mesa\n"
                    "2. Run with MUJOCO_GL=glfw (requires display)\n"
                    "3. Run with MUJOCO_GL=egl (if EGL is available)"
                ) from e
            raise

        try:
            new_env = suite.load(domain_name=domain_name, task_name=task_name)
        except Exception as e:
            error_msg = str(e).lower()
            # Heuristic check for MuJoCo/OpenGL errors
            looks_like_gl_error = any(
                hint in error_msg for hint in ("gl", "render", "display")
            )
            if looks_like_gl_error and sys.platform == "darwin":
                raise RuntimeError(
                    f"MuJoCo initialization failed: {e}\n\n"
                    "On macOS, try one of these solutions:\n"
                    "1. Install osmesa: brew install mesa\n"
                    "2. Run with MUJOCO_GL=glfw (requires display)\n"
                    "3. Set PYOPENGL_PLATFORM=osmesa"
                ) from e
            # Otherwise report it as an invalid environment. ALL_TASKS is the
            # complete task list; BENCHMARKING is only a subset of it.
            available = list(suite.ALL_TASKS)
            raise ValueError(
                f"Failed to load {domain_name}/{task_name}. "
                f"Available environments: {available[:10]}... "
                f"(use dm_control.suite.ALL_TASKS for full list)"
            ) from e

        # Swap only after the new environment exists, then release the old one.
        previous_env, self._env = self._env, new_env
        if previous_env is not None:
            try:
                previous_env.close()
            except Exception:
                # Best-effort cleanup: a failing close() must not undo the switch.
                pass

        self._domain_name = domain_name
        self._task_name = task_name

        self._state.domain_name = domain_name
        self._state.task_name = task_name
        self._state.action_spec = self._get_action_spec_info()
        self._state.observation_spec = self._get_observation_spec_info()
        self._state.physics_timestep = self._env.physics.timestep()
        self._state.control_timestep = self._env.control_timestep()

    def _get_action_spec_info(self) -> Dict[str, Any]:
        """Get information about the action space.

        Returns:
            Dict with the keys "shape", "dtype", "minimum", "maximum" and
            "name", taken from the loaded environment's action spec and
            converted to plain lists/strings.
        """
        spec = self._env.action_spec()
        return {
            "shape": list(spec.shape),
            "dtype": str(spec.dtype),
            "minimum": spec.minimum.tolist(),
            "maximum": spec.maximum.tolist(),
            "name": spec.name,
        }

    def _get_observation_spec_info(self) -> Dict[str, Any]:
        """Get information about the observation space.

        Returns:
            Dict mapping each observation name to {"shape": [...], "dtype": "..."}.
        """
        return {
            name: {"shape": list(spec.shape), "dtype": str(spec.dtype)}
            for name, spec in self._env.observation_spec().items()
        }

    def _render_pixels(self) -> Optional[str]:
        """Render camera 0 and return the frame as a base64-encoded PNG string.

        Rendering is best-effort: if rendering or encoding fails, the failure
        is logged as a warning and None is returned so the caller can still
        deliver the non-visual observation.
        """
        try:
            frame = self._env.physics.render(
                height=self._render_height,
                width=self._render_width,
                camera_id=0,
            )
            from PIL import Image

            buffer = io.BytesIO()
            Image.fromarray(frame).save(buffer, format="PNG")
            return base64.b64encode(buffer.getvalue()).decode("utf-8")
        except Exception as exc:
            import logging

            logging.getLogger(__name__).warning(
                "Rendering pixel observation failed; returning no pixels: %s", exc
            )
            return None

    def _get_observation(
        self,
        time_step,
        include_pixels: bool = False,
    ) -> DMControlObservation:
        """Convert dm_control TimeStep to DMControlObservation.

        Args:
            time_step: TimeStep returned by the wrapped environment's
                reset() or step().
            include_pixels: If True, try to attach a rendered frame.

        Returns:
            DMControlObservation whose observations are flattened to lists of
            numbers. reward is 0.0 when the TimeStep carries no reward, and
            done is True on the last step of an episode.
        """
        observations = {
            name: np.asarray(value).flatten().tolist()
            for name, value in time_step.observation.items()
        }

        pixels = self._render_pixels() if include_pixels else None

        done = bool(time_step.last())
        reward = float(time_step.reward) if time_step.reward is not None else 0.0

        return DMControlObservation(
            observations=observations,
            pixels=pixels,
            reward=reward,
            done=done,
        )

    def reset(
        self,
        domain_name: Optional[str] = None,
        task_name: Optional[str] = None,
        seed: Optional[int] = None,
        render: bool = False,
        **kwargs,
    ) -> DMControlObservation:
        """
        Reset the environment and return initial observation.

        Args:
            domain_name: Optionally switch to a different domain.
            task_name: Optionally switch to a different task.
            seed: Random seed for reproducibility. It re-seeds the random
                generator owned by the loaded task; numpy's global random
                state is not modified.
            render: If True, include pixel observations. The choice is
                remembered and also applies to subsequent step() calls.
            **kwargs: Additional arguments (ignored).

        Returns:
            DMControlObservation with initial state.
        """
        self._include_pixels = render

        target_domain = domain_name or self._domain_name
        target_task = task_name or self._task_name

        if (
            self._env is None
            or target_domain != self._domain_name
            or target_task != self._task_name
        ):
            self._load_environment(target_domain, target_task)

        if seed is not None:
            # Suite tasks sample their initial state from their own
            # RandomState, which the global np.random.seed() does not reach.
            self._env.task.random.seed(seed)

        time_step = self._env.reset()

        # New episode: fresh id and step counter, environment metadata carried over.
        self._state = DMControlState(
            episode_id=str(uuid4()),
            step_count=0,
            domain_name=self._domain_name,
            task_name=self._task_name,
            action_spec=self._state.action_spec,
            observation_spec=self._state.observation_spec,
            physics_timestep=self._state.physics_timestep,
            control_timestep=self._state.control_timestep,
        )

        return self._get_observation(time_step, include_pixels=render)

    def step(
        self,
        action: DMControlAction,
        render: bool = False,
        **kwargs,
    ) -> DMControlObservation:
        """
        Execute one step in the environment.

        Args:
            action: DMControlAction with continuous action values. Values are
                reshaped to the action spec's shape when the element count
                matches, and clipped to the spec's minimum/maximum.
            render: If True, include pixel observations (also enabled when the
                last reset() was called with render=True).
            **kwargs: Additional arguments (ignored).

        Returns:
            DMControlObservation with new state, reward, and done flag.

        Raises:
            RuntimeError: If reset() has not been called yet.
            ValueError: If the number of action values does not match the
                action spec.
        """
        if self._env is None:
            raise RuntimeError("Environment not initialized. Call reset() first.")

        action_array = np.array(action.values, dtype=np.float64)

        action_spec = self._env.action_spec()
        expected_shape = action_spec.shape
        if action_array.shape != expected_shape:
            if action_array.size != np.prod(expected_shape):
                raise ValueError(
                    f"Action shape {action_array.shape} doesn't match "
                    f"expected shape {expected_shape}"
                )
            action_array = action_array.reshape(expected_shape)

        action_array = np.clip(action_array, action_spec.minimum, action_spec.maximum)

        time_step = self._env.step(action_array)
        self._state.step_count += 1

        return self._get_observation(
            time_step, include_pixels=render or self._include_pixels
        )

    async def reset_async(
        self,
        domain_name: Optional[str] = None,
        task_name: Optional[str] = None,
        seed: Optional[int] = None,
        render: bool = False,
        **kwargs,
    ) -> DMControlObservation:
        """Async version of reset.

        On macOS, runs synchronously to avoid MuJoCo threading crashes.
        On other platforms, runs in a thread pool.
        """
        if sys.platform == "darwin":
            # On macOS, MuJoCo crashes when run in a background thread
            # Run synchronously (blocks event loop but avoids crash)
            return self.reset(
                domain_name=domain_name,
                task_name=task_name,
                seed=seed,
                render=render,
                **kwargs,
            )

        import asyncio

        return await asyncio.to_thread(
            self.reset,
            domain_name=domain_name,
            task_name=task_name,
            seed=seed,
            render=render,
            **kwargs,
        )

    async def step_async(
        self,
        action: DMControlAction,
        render: bool = False,
        **kwargs,
    ) -> DMControlObservation:
        """Async version of step.

        On macOS, runs synchronously to avoid MuJoCo threading crashes.
        On other platforms, runs in a thread pool.
        """
        if sys.platform == "darwin":
            # On macOS, MuJoCo crashes when run in a background thread
            # Run synchronously (blocks event loop but avoids crash)
            return self.step(action, render=render, **kwargs)

        import asyncio

        return await asyncio.to_thread(self.step, action, render=render, **kwargs)

    @property
    def state(self) -> DMControlState:
        """Get the current environment state."""
        return self._state

    def close(self) -> None:
        """Close the dm_control environment.

        Safe to call repeatedly and on partially constructed instances;
        errors raised by the wrapped environment's close() are ignored.
        """
        # getattr: __del__ may run on an instance whose __init__ never finished.
        env = getattr(self, "_env", None)
        if env is not None:
            try:
                env.close()
            except Exception:
                # Best-effort cleanup; nothing useful can be done on failure.
                pass
            self._env = None

    def __del__(self):
        """Cleanup on deletion."""
        try:
            self.close()
        except Exception:
            # Never let cleanup errors escape from a finalizer.
            pass
uv.lock ADDED
The diff for this file is too large to render. See raw diff