SavirD commited on
Commit
fc6b5c1
·
verified ·
1 Parent(s): c6c487e

Upload folder using huggingface_hub

Browse files
Dockerfile ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Multi-stage build using openenv-base
8
+ # This Dockerfile is flexible and works for both:
9
+ # - In-repo environments (with local OpenEnv sources)
10
+ # - Standalone environments (with openenv from PyPI/Git)
11
+ # The build script (openenv build) handles context detection and sets appropriate build args.
12
+
13
+ ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
14
+ FROM ${BASE_IMAGE} AS builder
15
+
16
+ WORKDIR /app
17
+
18
+ # Ensure git is available (required for installing dependencies from VCS)
19
+ RUN apt-get update && \
20
+ apt-get install -y --no-install-recommends git && \
21
+ rm -rf /var/lib/apt/lists/*
22
+
23
+ # Build argument to control whether we're building standalone or in-repo
24
+ ARG BUILD_MODE=in-repo
25
+ ARG ENV_NAME=my_env
26
+
27
+ # Copy environment code (always at root of build context)
28
+ COPY . /app/env
29
+
30
+ # For in-repo builds, openenv is already vendored in the build context
31
+ # For standalone builds, openenv will be installed via pyproject.toml
32
+ WORKDIR /app/env
33
+
34
+ # Ensure uv is available (for local builds where base image lacks it)
35
+ RUN if ! command -v uv >/dev/null 2>&1; then \
36
+ curl -LsSf https://astral.sh/uv/install.sh | sh && \
37
+ mv /root/.local/bin/uv /usr/local/bin/uv && \
38
+ mv /root/.local/bin/uvx /usr/local/bin/uvx; \
39
+ fi
40
+
41
+ # Install dependencies using uv sync
42
+ # If uv.lock exists, use it; otherwise resolve on the fly
43
+ RUN --mount=type=cache,target=/root/.cache/uv \
44
+ if [ -f uv.lock ]; then \
45
+ uv sync --frozen --no-install-project --no-editable; \
46
+ else \
47
+ uv sync --no-install-project --no-editable; \
48
+ fi
49
+
50
+ RUN --mount=type=cache,target=/root/.cache/uv \
51
+ if [ -f uv.lock ]; then \
52
+ uv sync --frozen --no-editable; \
53
+ else \
54
+ uv sync --no-editable; \
55
+ fi
56
+
57
+ # Final runtime stage
58
+ FROM ${BASE_IMAGE}
59
+
60
+ WORKDIR /app
61
+
62
+ # Copy the virtual environment from builder
63
+ COPY --from=builder /app/env/.venv /app/.venv
64
+
65
+ # Copy the environment code
66
+ COPY --from=builder /app/env /app/env
67
+
68
+ # Set PATH to use the virtual environment
69
+ ENV PATH="/app/.venv/bin:$PATH"
70
+
71
+ # Set PYTHONPATH so imports work correctly
72
+ ENV PYTHONPATH="/app/env:$PYTHONPATH"
73
+
74
+ # Health check
75
+ HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
76
+ CMD curl -f http://localhost:8000/health || exit 1
77
+
78
+ # Run the FastAPI server
79
+ # The module path is constructed to work with the /app/env structure
80
+ ENV ENABLE_WEB_INTERFACE=true
81
+ CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"]
README.md CHANGED
@@ -1,10 +1,255 @@
1
  ---
2
- title: My Env
3
- emoji: 🦀
4
- colorFrom: gray
5
- colorTo: green
6
  sdk: docker
7
  pinned: false
 
 
 
 
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: My Env Environment Server
3
+ emoji: 🎴
4
+ colorFrom: blue
5
+ colorTo: pink
6
  sdk: docker
7
  pinned: false
8
+ app_port: 8000
9
+ base_path: /web
10
+ tags:
11
+ - openenv
12
  ---
13
 
14
+ # My Env Environment
15
+
16
+ A simple test environment that echoes back messages — useful for exercising the env APIs and demonstrating environment usage patterns. (Note: this package also ships a meta-optimizer environment, which is what `server/app.py` actually serves.)
17
+
18
+ ## Quick Start
19
+
20
+ The simplest way to use the My Env environment is through the `MyEnv` class:
21
+
22
+ ```python
23
+ from my_env import MyAction, MyEnv
24
+
25
+ try:
26
+ # Create environment from Docker image
27
+ my_envenv = MyEnv.from_docker_image("my_env-env:latest")
28
+
29
+ # Reset
30
+ result = my_envenv.reset()
31
+ print(f"Reset: {result.observation.echoed_message}")
32
+
33
+ # Send multiple messages
34
+ messages = ["Hello, World!", "Testing echo", "Final message"]
35
+
36
+ for msg in messages:
37
+ result = my_envenv.step(MyAction(message=msg))
38
+ print(f"Sent: '{msg}'")
39
+ print(f" → Echoed: '{result.observation.echoed_message}'")
40
+ print(f" → Length: {result.observation.message_length}")
41
+ print(f" → Reward: {result.reward}")
42
+
43
+ finally:
44
+ # Always clean up
45
+ my_envenv.close()
46
+ ```
47
+
48
+ That's it! The `MyEnv.from_docker_image()` method handles:
49
+ - Starting the Docker container
50
+ - Waiting for the server to be ready
51
+ - Connecting to the environment
52
+ - Container cleanup when you call `close()`
53
+
54
+ ## Building the Docker Image
55
+
56
+ Before using the environment, you need to build the Docker image:
57
+
58
+ ```bash
59
+ # From project root
60
+ docker build -t my_env-env:latest -f server/Dockerfile .
61
+ ```
62
+
63
+ ## Deploying to Hugging Face Spaces
64
+
65
+ You can easily deploy your OpenEnv environment to Hugging Face Spaces using the `openenv push` command:
66
+
67
+ ```bash
68
+ # From the environment directory (where openenv.yaml is located)
69
+ openenv push
70
+
71
+ # Or specify options
72
+ openenv push --namespace my-org --private
73
+ ```
74
+
75
+ The `openenv push` command will:
76
+ 1. Validate that the directory is an OpenEnv environment (checks for `openenv.yaml`)
77
+ 2. Prepare a custom build for Hugging Face Docker space (enables web interface)
78
+ 3. Upload to Hugging Face (ensuring you're logged in)
79
+
80
+ ### Prerequisites
81
+
82
+ - Authenticate with Hugging Face: The command will prompt for login if not already authenticated
83
+
84
+ ### Options
85
+
86
+ - `--directory`, `-d`: Directory containing the OpenEnv environment (defaults to current directory)
87
+ - `--repo-id`, `-r`: Repository ID in format 'username/repo-name' (defaults to 'username/env-name' from openenv.yaml)
88
+ - `--base-image`, `-b`: Base Docker image to use (overrides Dockerfile FROM)
89
+ - `--private`: Deploy the space as private (default: public)
90
+
91
+ ### Examples
92
+
93
+ ```bash
94
+ # Push to your personal namespace (defaults to username/env-name from openenv.yaml)
95
+ openenv push
96
+
97
+ # Push to a specific repository
98
+ openenv push --repo-id my-org/my-env
99
+
100
+ # Push with a custom base image
101
+ openenv push --base-image ghcr.io/meta-pytorch/openenv-base:latest
102
+
103
+ # Push as a private space
104
+ openenv push --private
105
+
106
+ # Combine options
107
+ openenv push --repo-id my-org/my-env --base-image custom-base:latest --private
108
+ ```
109
+
110
+ After deployment, your space will be available at:
111
+ `https://huggingface.co/spaces/<repo-id>`
112
+
113
+ The deployed space includes:
114
+ - **Web Interface** at `/web` - Interactive UI for exploring the environment
115
+ - **API Documentation** at `/docs` - Full OpenAPI/Swagger interface
116
+ - **Health Check** at `/health` - Container health monitoring
117
+ - **WebSocket** at `/ws` - Persistent session endpoint for low-latency interactions
118
+
119
+ ## Environment Details
120
+
121
+ ### Action
122
+ **MyAction**: Contains a single field
123
+ - `message` (str) - The message to echo back
124
+
125
+ ### Observation
126
+ **MyObservation**: Contains the echo response and metadata
127
+ - `echoed_message` (str) - The message echoed back
128
+ - `message_length` (int) - Length of the message
129
+ - `reward` (float) - Reward based on message length (length × 0.1)
130
+ - `done` (bool) - Always False for echo environment
131
+ - `metadata` (dict) - Additional info like step count
132
+
133
+ ### Reward
134
+ The reward is calculated as: `message_length × 0.1`
135
+ - "Hi" → reward: 0.2
136
+ - "Hello, World!" → reward: 1.3
137
+ - Empty message → reward: 0.0
138
+
139
+ ## Advanced Usage
140
+
141
+ ### Connecting to an Existing Server
142
+
143
+ If you already have a My Env environment server running, you can connect directly:
144
+
145
+ ```python
146
+ from my_env import MyEnv
147
+
148
+ # Connect to existing server
149
+ my_envenv = MyEnv(base_url="<ENV_HTTP_URL_HERE>")
150
+
151
+ # Use as normal
152
+ result = my_envenv.reset()
153
+ result = my_envenv.step(MyAction(message="Hello!"))
154
+ ```
155
+
156
+ Note: When connecting to an existing server, `my_envenv.close()` will NOT stop the server.
157
+
158
+ ### Using the Context Manager
159
+
160
+ The client supports context manager usage for automatic connection management:
161
+
162
+ ```python
163
+ from my_env import MyAction, MyEnv
164
+
165
+ # Connect with context manager (auto-connects and closes)
166
+ with MyEnv(base_url="http://localhost:8000") as env:
167
+ result = env.reset()
168
+ print(f"Reset: {result.observation.echoed_message}")
169
+ # Multiple steps with low latency
170
+ for msg in ["Hello", "World", "!"]:
171
+ result = env.step(MyAction(message=msg))
172
+ print(f"Echoed: {result.observation.echoed_message}")
173
+ ```
174
+
175
+ The client uses WebSocket connections for:
176
+ - **Lower latency**: No HTTP connection overhead per request
177
+ - **Persistent session**: Server maintains your environment state
178
+ - **Efficient for episodes**: Better for many sequential steps
179
+
180
+ ### Concurrent WebSocket Sessions
181
+
182
+ The server supports multiple concurrent WebSocket connections. To enable this,
183
+ modify `server/app.py` to use factory mode:
184
+
185
+ ```python
186
+ # In server/app.py - use factory mode for concurrent sessions
187
+ app = create_app(
188
+ MyEnvironment, # Pass class, not instance
189
+ MyAction,
190
+ MyObservation,
191
+ max_concurrent_envs=4, # Allow 4 concurrent sessions
192
+ )
193
+ ```
194
+
195
+ Then multiple clients can connect simultaneously:
196
+
197
+ ```python
198
+ from my_env import MyAction, MyEnv
199
+ from concurrent.futures import ThreadPoolExecutor
200
+
201
+ def run_episode(client_id: int):
202
+ with MyEnv(base_url="http://localhost:8000") as env:
203
+ result = env.reset()
204
+ for i in range(10):
205
+ result = env.step(MyAction(message=f"Client {client_id}, step {i}"))
206
+ return client_id, result.observation.message_length
207
+
208
+ # Run 4 episodes concurrently
209
+ with ThreadPoolExecutor(max_workers=4) as executor:
210
+ results = list(executor.map(run_episode, range(4)))
211
+ ```
212
+
213
+ ## Development & Testing
214
+
215
+ ### Direct Environment Testing
216
+
217
+ Test the environment logic directly without starting the HTTP server:
218
+
219
+ ```bash
220
+ # From the server directory
221
+ python3 server/my_env_environment.py
222
+ ```
223
+
224
+ This verifies that:
225
+ - Environment resets correctly
226
+ - Step executes actions properly
227
+ - State tracking works
228
+ - Rewards are calculated correctly
229
+
230
+ ### Running Locally
231
+
232
+ Run the server locally for development:
233
+
234
+ ```bash
235
+ uvicorn server.app:app --reload
236
+ ```
237
+
238
+ ## Project Structure
239
+
240
+ ```
241
+ my_env/
242
+ ├── .dockerignore # Docker build exclusions
243
+ ├── __init__.py # Module exports
244
+ ├── README.md # This file
245
+ ├── openenv.yaml # OpenEnv manifest
246
+ ├── pyproject.toml # Project metadata and dependencies
247
+ ├── uv.lock # Locked dependencies (generated)
248
+ ├── client.py # MyEnv client
249
+ ├── models.py # Action and Observation models
250
+ └── server/
251
+ ├── __init__.py # Server module exports
252
+ ├── my_env_environment.py # Core environment logic
253
+ ├── app.py # FastAPI application (HTTP + WebSocket endpoints)
254
+ └── Dockerfile # Container image definition
255
+ ```
__init__.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Meta-Optimizer and My Env environments."""
8
+
9
+ from .client import MetaOptimizerEnv
10
+ from .models import (
11
+ MetaOptimizerAction,
12
+ MetaOptimizerObservation,
13
+ MyAction,
14
+ MyObservation,
15
+ )
16
+
17
+ __all__ = [
18
+ "MetaOptimizerEnv",
19
+ "MetaOptimizerAction",
20
+ "MetaOptimizerObservation",
21
+ "MyAction",
22
+ "MyObservation",
23
+ ]
client.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Meta-Optimizer Environment Client (OpenEnv WebSocket client)."""
8
+
9
+ from typing import Dict
10
+
11
+ from openenv.core.client_types import StepResult
12
+ from openenv.core.env_server.types import State
13
+ from openenv.core import EnvClient
14
+
15
+ from .models import MetaOptimizerAction, MetaOptimizerObservation
16
+
17
+
18
class MetaOptimizerEnv(
    EnvClient[MetaOptimizerAction, MetaOptimizerObservation, State]
):
    """
    WebSocket client for the Meta-Optimizer Environment.

    Connects to the meta-optimizer server. Call reset(seed=..., task_id=...)
    with task_id=None to sample one of the 50 training tasks, or with a
    task_id from EVAL_TASK_IDS for held-out evaluation.
    """

    def _step_payload(self, action: MetaOptimizerAction) -> Dict:
        # Serialize the four per-step optimizer controls for the wire.
        field_names = (
            "lr_scale",
            "momentum_coef",
            "grad_clip_threshold",
            "weight_decay_this_step",
        )
        return {name: getattr(action, name) for name in field_names}

    def _parse_result(
        self, payload: Dict
    ) -> StepResult[MetaOptimizerObservation]:
        # Reward/done are mirrored onto both the observation and StepResult.
        obs_data = payload.get("observation", {})
        reward = payload.get("reward")
        done = payload.get("done", False)
        observation = MetaOptimizerObservation(
            loss=obs_data.get("loss", 0.0),
            step_count=obs_data.get("step_count", 0),
            grad_norm=obs_data.get("grad_norm"),
            steps_to_threshold=obs_data.get("steps_to_threshold"),
            done=done,
            reward=reward,
            metadata=obs_data.get("metadata", {}),
        )
        return StepResult(observation=observation, reward=reward, done=done)

    def _parse_state(self, payload: Dict) -> State:
        # Minimal session state: which episode, and how far into it we are.
        return State(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
        )
env_gym.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Gymnasium wrapper for MetaOptimizerEnvironment for use with Stable-Baselines3 (e.g. SAC).
9
+ """
10
+
11
+ import math
12
+ from typing import Any, Dict, Optional, Tuple
13
+
14
+ import gymnasium as gym
15
+ import numpy as np
16
+
17
+ from my_env.models import MetaOptimizerAction
18
+ from my_env.server.meta_optimizer_environment import MetaOptimizerEnvironment
19
+ from my_env.server.tasks import get_task
20
+
21
+ # Bounds for normalization / clipping
22
+ LOSS_LOG_MAX = 2.0 # log10(loss+1e-8) capped for obs
23
+ GRAD_NORM_SCALE = 10.0
24
+
25
+
26
+ def obs_to_vector(obs: Any, max_steps: int) -> np.ndarray:
27
+ """Convert MetaOptimizerObservation to a fixed-size vector for SAC."""
28
+ loss = getattr(obs, "loss", 0.0) or 0.0
29
+ step_count = getattr(obs, "step_count", 0)
30
+ grad_norm = getattr(obs, "grad_norm", None)
31
+ # Normalize: log loss (bounded), step ratio, grad norm scale
32
+ loss_feat = min(math.log10(loss + 1e-8), LOSS_LOG_MAX) / LOSS_LOG_MAX
33
+ step_feat = step_count / max(1, max_steps)
34
+ grad_feat = (grad_norm / GRAD_NORM_SCALE) if grad_norm is not None else 0.0
35
+ grad_feat = min(max(grad_feat, 0.0), 1.0)
36
+ return np.array([loss_feat, step_feat, grad_feat], dtype=np.float32)
37
+
38
+
39
def vector_to_action(vec: np.ndarray) -> MetaOptimizerAction:
    """Map [0,1]^4 to action bounds: lr [1e-4, 1], momentum [0,1], clip [0, 2], wd [0, 1e-3]."""
    # Clip each raw component into [0, 1] before rescaling to its bound.
    u0, u1, u2, u3 = (float(np.clip(vec[i], 0, 1)) for i in range(4))
    return MetaOptimizerAction(
        lr_scale=1e-4 + (1.0 - 1e-4) * u0,  # affine map onto [1e-4, 1.0]
        momentum_coef=u1,  # already in [0, 1]
        grad_clip_threshold=2.0 * u2,  # [0, 2]
        weight_decay_this_step=1e-3 * u3,  # [0, 1e-3]
    )
51
+
52
+
53
class MetaOptimizerGymEnv(gym.Env):
    """Gymnasium adapter around MetaOptimizerEnvironment for SAC training.

    Each reset samples a task from Distribution A (task_id 0..49 by default,
    or from the ``task_ids`` list passed to the constructor).
    """

    def __init__(
        self,
        max_steps: int = 100,
        loss_threshold: float = 0.1,
        task_ids: Optional[list] = None,
    ):
        super().__init__()
        self._max_steps = max_steps
        self._loss_threshold = loss_threshold
        self._task_ids = task_ids or list(range(50))
        self._env = MetaOptimizerEnvironment(
            max_steps=max_steps,
            loss_threshold=loss_threshold,
        )
        # Features: normalized loss, normalized step count, normalized grad norm.
        self.observation_space = gym.spaces.Box(
            low=0.0, high=1.0, shape=(3,), dtype=np.float32
        )
        # Raw action in [0,1]^4; vector_to_action rescales each component to
        # its real bound (lr, momentum, grad clip, weight decay).
        self.action_space = gym.spaces.Box(
            low=0.0, high=1.0, shape=(4,), dtype=np.float32
        )

    def reset(
        self, *, seed: Optional[int] = None, options: Optional[Dict] = None
    ) -> Tuple[np.ndarray, Dict]:
        import random

        # Seeded resets draw the task via a dedicated numpy Generator so the
        # episode is reproducible; unseeded resets use the global `random`.
        if seed is None:
            task_id = random.choice(self._task_ids)
        else:
            self._np_random = np.random.default_rng(seed)
            idx = self._np_random.integers(0, len(self._task_ids))
            task_id = self._task_ids[idx]
        first_obs = self._env.reset(seed=seed, task_id=task_id)
        return obs_to_vector(first_obs, self._max_steps), {"task_id": task_id}

    def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, bool, Dict]:
        obs = self._env.step(vector_to_action(action))
        reward = 0.0 if obs.reward is None else float(obs.reward)
        truncated = False
        info = {
            "loss": obs.loss,
            "step_count": obs.step_count,
            "steps_to_threshold": obs.steps_to_threshold,
        }
        return obs_to_vector(obs, self._max_steps), reward, bool(obs.done), truncated, info
models.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Data models for the My Env Environment.
9
+
10
+ The my_env environment is a simple test environment that echoes back messages.
11
+ Meta-optimizer models support the meta-learning RL optimizer environment.
12
+ """
13
+
14
+ from pydantic import Field
15
+
16
+ from openenv.core.env_server.types import Action, Observation
17
+
18
+
19
class MyAction(Action):
    """Echo-environment action: carries the message the server will echo."""

    message: str = Field(..., description="Message to echo back")


class MyObservation(Observation):
    """Echo-environment observation: the echoed message plus its length."""

    echoed_message: str = Field(default="", description="The echoed message")
    message_length: int = Field(default=0, description="Length of the echoed message")


# --- Meta-optimizer environment (meta-learning RL optimizer) ---


class MetaOptimizerAction(Action):
    """Per-step optimizer controls chosen by the meta-optimizer agent."""

    lr_scale: float = Field(
        ...,
        ge=1e-4,
        le=1.0,
        description="Learning rate scale for this step (e.g. 1e-4 to 1.0)",
    )
    momentum_coef: float = Field(
        ...,
        ge=0.0,
        le=1.0,
        description="Momentum coefficient (0 = no momentum, 1 = full carry)",
    )
    grad_clip_threshold: float = Field(
        ...,
        ge=0.0,
        description="Gradient clipping threshold (0 = no clipping)",
    )
    weight_decay_this_step: float = Field(
        ...,
        ge=0.0,
        description="Weight decay (L2) scale for this step (0 = no weight decay)",
    )


class MetaOptimizerObservation(Observation):
    """Training telemetry returned after each optimizer step."""

    loss: float = Field(..., description="Current loss after last update")
    step_count: int = Field(..., description="Current step in the episode")
    grad_norm: float | None = Field(
        default=None,
        description="Global gradient norm before last update (if available)",
    )
    steps_to_threshold: int | None = Field(
        default=None,
        description="Step at which loss first reached threshold (None if not yet reached)",
    )
75
+
openenv.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ spec_version: 1
2
+ name: my_env
3
+ type: space
4
+ runtime: fastapi
5
+ app: server.app:app
6
+ port: 8000
7
+
openenv_my_env.egg-info/PKG-INFO ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.4
2
+ Name: openenv-my_env
3
+ Version: 0.1.0
4
+ Summary: My Env environment for OpenEnv
5
+ Requires-Python: >=3.10
6
+ Requires-Dist: openenv-core[core]>=0.2.0
7
+ Requires-Dist: torch>=2.0.0
8
+ Requires-Dist: matplotlib>=3.5.0
9
+ Requires-Dist: seaborn>=0.12.0
10
+ Requires-Dist: gymnasium>=0.29.0
11
+ Requires-Dist: stable-baselines3>=2.0.0
12
+ Requires-Dist: numpy>=1.20.0
13
+ Provides-Extra: dev
14
+ Requires-Dist: pytest>=8.0.0; extra == "dev"
15
+ Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
openenv_my_env.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ README.md
2
+ __init__.py
3
+ client.py
4
+ models.py
5
+ pyproject.toml
6
+ ./__init__.py
7
+ ./client.py
8
+ ./env_gym.py
9
+ ./models.py
10
+ openenv_my_env.egg-info/PKG-INFO
11
+ openenv_my_env.egg-info/SOURCES.txt
12
+ openenv_my_env.egg-info/dependency_links.txt
13
+ openenv_my_env.egg-info/entry_points.txt
14
+ openenv_my_env.egg-info/requires.txt
15
+ openenv_my_env.egg-info/top_level.txt
16
+ server/__init__.py
17
+ server/app.py
18
+ server/meta_optimizer_environment.py
19
+ server/my_env_environment.py
20
+ server/tasks.py
openenv_my_env.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
 
 
1
+
openenv_my_env.egg-info/entry_points.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ [console_scripts]
2
+ server = my_env.server.app:main
openenv_my_env.egg-info/requires.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ openenv-core[core]>=0.2.0
2
+ torch>=2.0.0
3
+ matplotlib>=3.5.0
4
+ seaborn>=0.12.0
5
+ gymnasium>=0.29.0
6
+ stable-baselines3>=2.0.0
7
+ numpy>=1.20.0
8
+
9
+ [dev]
10
+ pytest>=8.0.0
11
+ pytest-cov>=4.0.0
openenv_my_env.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ my_env
pyproject.toml ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ [build-system]
8
+ requires = ["setuptools>=45", "wheel"]
9
+ build-backend = "setuptools.build_meta"
10
+
11
+ [project]
12
+ name = "openenv-my_env"
13
+ version = "0.1.0"
14
+ description = "My Env environment for OpenEnv"
15
+ requires-python = ">=3.10"
16
+ dependencies = [
17
+ # Core OpenEnv runtime (provides FastAPI server + HTTP client types)
18
+ # install from github
19
+ # "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git",
20
+ "openenv-core[core]>=0.2.0",
21
+ # Environment-specific dependencies
22
+ # Add all dependencies needed for your environment here
23
+ # Examples:
24
+ # "numpy>=1.19.0",
25
+ "torch>=2.0.0",
26
+ "matplotlib>=3.5.0",
27
+ "seaborn>=0.12.0",
28
+ "gymnasium>=0.29.0",
29
+ "stable-baselines3>=2.0.0",
30
+ "numpy>=1.20.0",
31
+ # "openspiel>=1.0.0",
32
+ # "smolagents>=1.22.0,<2",
33
+ ]
34
+
35
+ [project.optional-dependencies]
36
+ dev = [
37
+ "pytest>=8.0.0",
38
+ "pytest-cov>=4.0.0",
39
+ ]
40
+
41
+ [project.scripts]
42
+ # Server entry point - enables running via: uv run --project . server
43
+ # or: python -m my_env.server.app
44
+ server = "my_env.server.app:main"
45
+
46
+ [tool.setuptools]
47
+ include-package-data = true
48
+ packages = ["my_env", "my_env.server"]
49
+ package-dir = { "my_env" = ".", "my_env.server" = "server" }
server/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """My Env environment server components."""
8
+
9
+ from .my_env_environment import MyEnvironment
10
+
11
+ __all__ = ["MyEnvironment"]
server/app.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ FastAPI application for the Meta-Optimizer Environment.
9
+
10
+ This module creates an HTTP server that exposes the MetaOptimizerEnvironment
11
+ over HTTP and WebSocket endpoints, compatible with EnvClient.
12
+
13
+ Endpoints:
14
+ - POST /reset: Reset the environment (optionally with task_id for eval)
15
+ - POST /step: Execute an action (lr_scale, momentum_coef, grad_clip_threshold, weight_decay_this_step)
16
+ - GET /state: Get current environment state
17
+ - GET /schema: Get action/observation schemas
18
+ - WS /ws: WebSocket endpoint for persistent sessions
19
+
20
+ Usage:
21
+ # Development (with auto-reload):
22
+ uvicorn server.app:app --reload --host 0.0.0.0 --port 8000
23
+
24
+ # Production:
25
+ uvicorn server.app:app --host 0.0.0.0 --port 8000 --workers 4
26
+
27
+ # Or run directly:
28
+ python -m server.app
29
+ """
30
+
31
# openenv provides the FastAPI app factory; fail with a clear message if
# the dependency set has not been installed.
try:
    from openenv.core.env_server.http_server import create_app
except Exception as e:  # pragma: no cover
    # Fixed: the original message embedded literal newlines inside the
    # quoted command ("'\n uv sync\n'"); keep it to one readable sentence.
    raise ImportError(
        "openenv is required for the web interface. "
        "Install dependencies with 'uv sync'."
    ) from e

# Import from the package so the server works when run via
# `uv run server` (my_env.server.app).
from my_env.models import MetaOptimizerAction, MetaOptimizerObservation

from .meta_optimizer_environment import MetaOptimizerEnvironment


# Create the app with web interface and README integration; factory mode
# (class, not instance) allows up to four concurrent WebSocket sessions.
app = create_app(
    MetaOptimizerEnvironment,
    MetaOptimizerAction,
    MetaOptimizerObservation,
    env_name="meta_optimizer",
    max_concurrent_envs=4,
)
51
+
52
+
53
def main(host: str = "0.0.0.0", port: int = 8000):
    """
    Entry point for direct execution via uv run or python -m.

    This function enables running the server without Docker:
        uv run --project . server
        uv run --project . server --port 8001
        python -m my_env.server.app

    Args:
        host: Host address to bind to (default: "0.0.0.0")
        port: Port number to listen on (default: 8000)

    For production deployments, consider using uvicorn directly with
    multiple workers:
        uvicorn my_env.server.app:app --workers 4
    """
    import uvicorn

    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Meta-Optimizer environment server")
    # Fixed: previously only --port was configurable even though main()
    # accepts a host parameter.
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8000)
    args = parser.parse_args()
    main(host=args.host, port=args.port)
server/meta_optimizer_environment.py ADDED
@@ -0,0 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Meta-optimizer environment: train an RL agent to act as an optimizer on random regression tasks.
9
+
10
+ Supports 50 training tasks, held-out eval, rich action space (LR, momentum, grad clip, weight decay),
11
+ and convergence-speed reward. Action log is exposed for emergent-behavior visualization.
12
+ """
13
+
14
+ import math
15
+ import random
16
+ from typing import Any, Dict, List, Optional
17
+ from uuid import uuid4
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+
22
+ from openenv.core.env_server.interfaces import Environment
23
+ from openenv.core.env_server.types import State
24
+
25
+ from my_env.models import MetaOptimizerAction, MetaOptimizerObservation
26
+ from .tasks import TRAIN_TASK_IDS, get_task, task_spec_from_dict, TaskSpec
27
+
28
# Defaults
LOSS_THRESHOLD = 0.1  # MSE below this value counts as "converged"
MAX_STEPS = 100  # fixed episode length / training-step budget
BATCH_SIZE = 32  # samples drawn per regression batch
32
+
33
+
34
+ def _build_model(spec: TaskSpec) -> nn.Module:
35
+ """Build a 2-layer MLP for the given task spec."""
36
+ torch.manual_seed(spec.arch_seed)
37
+ return nn.Sequential(
38
+ nn.Linear(spec.input_dim, spec.hidden_dim),
39
+ nn.ReLU(),
40
+ nn.Linear(spec.hidden_dim, spec.output_dim),
41
+ )
42
+
43
+
44
+ def _get_batch(spec: TaskSpec, step: int, device: torch.device):
45
+ """Sinusoidal regression: X in [0,1], y = amplitude * sin(2*pi*freq*x + phase) + noise."""
46
+ g = torch.Generator(device=device)
47
+ g.manual_seed(spec.data_seed + step)
48
+ X = torch.rand(BATCH_SIZE, spec.input_dim, device=device, generator=g)
49
+ # y = amplitude * sin(2*pi*freq*x + phase); x is first column
50
+ x = X[:, 0:1]
51
+ y = spec.amplitude * torch.sin(2 * math.pi * spec.freq * x + spec.phase)
52
+ y = y + 0.05 * torch.randn_like(y, device=device, generator=g)
53
+ return X, y
54
+
55
+
56
def run_adam_baseline(
    task_id: Optional[int] = None,
    task_spec: Optional[Dict[str, Any]] = None,
    max_steps: int = MAX_STEPS,
    loss_threshold: float = LOSS_THRESHOLD,
    lr: float = 1e-2,
    seed: Optional[int] = None,
    return_metrics: bool = False,
):
    """
    Train one task with Adam and report convergence speed.

    Exactly one of task_id / task_spec must be supplied. Returns the number
    of steps needed to drop below loss_threshold (max_steps if never
    reached), or a full metrics dict when return_metrics=True.
    """
    if (task_id is None) == (task_spec is None):
        raise ValueError("Provide exactly one of task_id or task_spec")
    if seed is not None:
        torch.manual_seed(seed)
    device = torch.device("cpu")
    if task_spec is not None:
        spec = task_spec_from_dict(task_spec)
    else:
        spec = get_task(task_id)
    model = _build_model(spec).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    history: List[float] = []
    hit_step: Optional[int] = None
    for step in range(max_steps):
        X, y = _get_batch(spec, step, device)
        model.train()
        optimizer.zero_grad()
        nn.functional.mse_loss(model(X), y).backward()
        optimizer.step()
        # Re-evaluate on the same batch after the parameter update.
        with torch.no_grad():
            post_loss = nn.functional.mse_loss(model(X), y).item()
        history.append(post_loss)
        if hit_step is None and post_loss < loss_threshold:
            hit_step = step + 1

    steps_result = max_steps if hit_step is None else hit_step
    if not return_metrics:
        return steps_result
    final_loss = history[-1] if history else float("inf")
    tail = min(10, len(history))
    tail_mean = sum(history[-tail:]) / tail if history else final_loss
    return {
        "steps_to_threshold": steps_result,
        "success": hit_step is not None,
        "final_loss": final_loss,
        "mean_last_10_loss": tail_mean,
        "loss_auc": sum(history) / len(history) if history else final_loss,
        "loss_trajectory": history,
    }
103
+
104
+
105
def run_sgd_baseline(
    task_id: Optional[int] = None,
    task_spec: Optional[Dict[str, Any]] = None,
    max_steps: int = MAX_STEPS,
    loss_threshold: float = LOSS_THRESHOLD,
    lr: float = 1e-2,
    momentum: float = 0.9,
    seed: Optional[int] = None,
    return_metrics: bool = False,
):
    """
    Train one task with SGD (optional momentum) and report convergence speed.

    Exactly one of task_id / task_spec must be supplied. Returns the number
    of steps needed to drop below loss_threshold (max_steps if never
    reached), or a full metrics dict when return_metrics=True.
    """
    if (task_id is None) == (task_spec is None):
        raise ValueError("Provide exactly one of task_id or task_spec")
    if seed is not None:
        torch.manual_seed(seed)
    device = torch.device("cpu")
    if task_spec is not None:
        spec = task_spec_from_dict(task_spec)
    else:
        spec = get_task(task_id)
    model = _build_model(spec).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)

    history: List[float] = []
    hit_step: Optional[int] = None
    for step in range(max_steps):
        X, y = _get_batch(spec, step, device)
        model.train()
        optimizer.zero_grad()
        nn.functional.mse_loss(model(X), y).backward()
        optimizer.step()
        # Re-evaluate on the same batch after the parameter update.
        with torch.no_grad():
            post_loss = nn.functional.mse_loss(model(X), y).item()
        history.append(post_loss)
        if hit_step is None and post_loss < loss_threshold:
            hit_step = step + 1

    steps_result = max_steps if hit_step is None else hit_step
    if not return_metrics:
        return steps_result
    final_loss = history[-1] if history else float("inf")
    tail = min(10, len(history))
    tail_mean = sum(history[-tail:]) / tail if history else final_loss
    return {
        "steps_to_threshold": steps_result,
        "success": hit_step is not None,
        "final_loss": final_loss,
        "mean_last_10_loss": tail_mean,
        "loss_auc": sum(history) / len(history) if history else final_loss,
        "loss_trajectory": history,
    }
153
+
154
+
155
def run_meta_optimizer_trajectory(
    task_id: Optional[int] = None,
    task_spec: Optional[Dict[str, Any]] = None,
    max_steps: int = MAX_STEPS,
    loss_threshold: float = LOSS_THRESHOLD,
    seed: Optional[int] = None,
    policy_callable: Optional[Any] = None,
) -> Dict[str, Any]:
    """
    Roll out one MetaOptimizerEnvironment episode under a policy and return
    convergence metrics.

    The policy maps an observation to a MetaOptimizerAction; when omitted, a
    fixed default (lr 0.02, momentum 0.9, clip 1.0, no weight decay) is used.
    Exactly one of task_id / task_spec must be supplied.
    """
    if (task_id is None) == (task_spec is None):
        raise ValueError("Provide exactly one of task_id or task_spec")
    if seed is not None:
        random.seed(seed)
        torch.manual_seed(seed)

    def _fixed_policy(_obs):  # type: ignore
        return MetaOptimizerAction(
            lr_scale=0.02,
            momentum_coef=0.9,
            grad_clip_threshold=1.0,
            weight_decay_this_step=0.0,
        )

    env = MetaOptimizerEnvironment(max_steps=max_steps, loss_threshold=loss_threshold)
    obs = env.reset(seed=seed, task_id=task_id, task_spec=task_spec)
    policy = policy_callable if policy_callable is not None else _fixed_policy

    losses: List[float] = [obs.loss]
    while not obs.done:
        obs = env.step(policy(obs))
        losses.append(obs.loss)

    reached = obs.steps_to_threshold
    tail = min(10, len(losses))
    return {
        "steps_to_threshold": reached if reached is not None else max_steps,
        "success": reached is not None,
        "final_loss": obs.loss,
        "mean_last_10_loss": sum(losses[-tail:]) / tail,
        "loss_auc": sum(losses) / len(losses),
        "loss_trajectory": losses,
    }
198
+
199
+
200
class MetaOptimizerEnvironment(Environment[MetaOptimizerAction, MetaOptimizerObservation, State]):
    """
    Meta-learning optimizer environment.

    Each step the agent chooses the optimizer hyperparameters (LR scale,
    momentum coefficient, gradient-clip threshold, per-step weight decay)
    used to update a small MLP on a sinusoidal regression task. Reward is
    sparse: 0.0 until the episode ends, then -(steps to reach the loss
    threshold, or max_steps if never reached), so faster convergence yields
    a higher return. Tasks come from the internal registry (50 train tasks
    plus held-out eval ids) or from a client-supplied task_spec dict.
    """

    # Episode state lives entirely on the instance, so each session can run
    # its own instance concurrently.
    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(
        self,
        loss_threshold: float = LOSS_THRESHOLD,
        max_steps: int = MAX_STEPS,
        **kwargs: Any,
    ):
        """
        Args:
            loss_threshold: MSE below which the task counts as solved.
            max_steps: Fixed episode length (episodes never end early).
            **kwargs: Forwarded to the Environment base class.
        """
        super().__init__(**kwargs)
        self.loss_threshold = loss_threshold
        self.max_steps = max_steps
        self._device = torch.device("cpu")

        # Episode state (set in reset)
        self._task_spec: Optional[TaskSpec] = None
        self._model: Optional[nn.Module] = None
        # Per-parameter momentum buffers for the manual SGD update in step().
        self._velocities: Optional[List[torch.Tensor]] = None
        self._step_count: int = 0
        self._current_loss: float = 0.0
        # Step index at which loss first dropped below the threshold (None until then).
        self._steps_to_threshold: Optional[int] = None
        # Per-step record of the agent's hyperparameter choices (for viz/eval).
        self._action_log: List[Dict[str, Any]] = []
        self._episode_id: Optional[str] = None

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        task_id: Optional[int] = None,
        task_spec: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> MetaOptimizerObservation:
        """
        Start a new episode.

        Task selection precedence: explicit task_spec dict > explicit
        task_id > random choice among TRAIN_TASK_IDS.

        Args:
            seed: Optional seed for both `random` and torch RNGs.
            episode_id: Optional id; a fresh UUID is generated when omitted.
            task_id: Registry task id (0..49 train, 50+ eval).
            task_spec: Client-supplied task definition dict.

        Returns:
            Initial observation carrying the pre-training loss (reward=None).
        """
        if seed is not None:
            random.seed(seed)
            torch.manual_seed(seed)
        if task_spec is not None:
            self._task_spec = task_spec_from_dict(task_spec)
        else:
            tid = task_id if task_id is not None else random.choice(TRAIN_TASK_IDS)
            self._task_spec = get_task(tid)
        self._model = _build_model(self._task_spec).to(self._device)
        self._velocities = [torch.zeros_like(p) for p in self._model.parameters()]
        self._step_count = 0
        self._steps_to_threshold = None
        self._action_log = []
        self._episode_id = episode_id or str(uuid4())

        # Initial loss (no update yet); uses batch index 0, so step() offsets
        # its batch index by 1.
        X, y = _get_batch(self._task_spec, 0, self._device)
        with torch.no_grad():
            out = self._model(X)
            self._current_loss = nn.functional.mse_loss(out, y).item()

        return self._observation(reward=None, grad_norm=None)

    def step(
        self,
        action: MetaOptimizerAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> MetaOptimizerObservation:
        """
        Apply one manual SGD-with-momentum update using the agent's chosen
        hyperparameters and return the resulting observation.

        Args:
            action: Per-step hyperparameters (lr_scale, momentum_coef,
                grad_clip_threshold, weight_decay_this_step).
            timeout_s: Unused; accepted for interface compatibility.

        Returns:
            Observation with the post-update loss. Reward is 0.0 on all
            intermediate steps; on the final step it is -(steps to
            threshold, or max_steps if the threshold was never reached).
        """
        assert self._model is not None and self._task_spec is not None
        lr = action.lr_scale
        momentum = action.momentum_coef
        clip = action.grad_clip_threshold
        wd = action.weight_decay_this_step

        # Record the choice for emergent-behavior visualization.
        self._action_log.append({
            "step": self._step_count,
            "lr_scale": lr,
            "momentum_coef": momentum,
            "grad_clip_threshold": clip,
            "weight_decay_this_step": wd,
        })

        # Batch 0 was consumed by reset() for the initial loss, hence the +1.
        X, y = _get_batch(self._task_spec, self._step_count + 1, self._device)
        self._model.train()
        out = self._model(X)
        loss = nn.functional.mse_loss(out, y)
        self._model.zero_grad()
        loss.backward()

        # Clone gradients so grad_norm reports the PRE-clip global L2 norm
        # and clipping mutates only the clones.
        grads = [p.grad.clone() for p in self._model.parameters()]
        grad_norm = sum(g.pow(2).sum() for g in grads).sqrt().item()

        # Global-norm gradient clipping: rescale all grads if the joint L2
        # norm exceeds the agent-chosen threshold (clip <= 0 disables it).
        if clip > 0:
            total_norm = sum(g.pow(2).sum() for g in grads).sqrt()
            if total_norm > clip:
                scale = clip / (total_norm + 1e-8)
                grads = [g * scale for g in grads]

        # Manual SGD with momentum: v <- momentum*v + g; p <- p - lr*v.
        # The wd branch is p <- p - wd*p, i.e. multiplying p by (1 - wd):
        # weight decay applied directly to the weights, not via the gradient.
        with torch.no_grad():
            for i, p in enumerate(self._model.parameters()):
                g = grads[i]
                v = self._velocities[i]
                v.mul_(momentum).add_(g)
                p.sub_(v, alpha=lr)
                if wd > 0:
                    p.sub_(p, alpha=wd)

        # Post-update loss on the SAME batch drives threshold tracking.
        with torch.no_grad():
            new_out = self._model(X)
            self._current_loss = nn.functional.mse_loss(new_out, y).item()

        self._step_count += 1
        if self._steps_to_threshold is None and self._current_loss < self.loss_threshold:
            self._steps_to_threshold = self._step_count

        # Episodes always run the full max_steps; reaching the threshold
        # early only improves the terminal reward.
        done = self._step_count >= self.max_steps
        if done:
            reward = -(self._steps_to_threshold if self._steps_to_threshold is not None else self.max_steps)
        else:
            reward = 0.0

        return self._observation(reward=reward, grad_norm=grad_norm, done=done)

    def _observation(
        self,
        reward: Optional[float] = None,
        grad_norm: Optional[float] = None,
        done: bool = False,
    ) -> MetaOptimizerObservation:
        """Assemble an observation; the full action log ships in metadata only on the final step."""
        meta: Dict[str, Any] = {}
        if self._steps_to_threshold is not None:
            meta["steps_to_threshold"] = self._steps_to_threshold
        if done and self._action_log:
            meta["action_log"] = self._action_log
        return MetaOptimizerObservation(
            loss=self._current_loss,
            step_count=self._step_count,
            grad_norm=grad_norm,
            steps_to_threshold=self._steps_to_threshold,
            done=done,
            reward=reward,
            metadata=meta,
        )

    @property
    def state(self) -> State:
        """Lightweight session state: episode id plus steps taken so far."""
        return State(
            episode_id=self._episode_id,
            step_count=self._step_count,
        )

    def get_episode_action_log(self) -> List[Dict[str, Any]]:
        """Return a copy of the action log for the current episode (for in-process viz or eval)."""
        return list(self._action_log)
server/my_env_environment.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ My Env Environment Implementation.
9
+
10
+ A simple test environment that echoes back messages sent to it.
11
+ Perfect for testing HTTP server infrastructure.
12
+ """
13
+
14
+ from uuid import uuid4
15
+
16
+ from openenv.core.env_server.interfaces import Environment
17
+ from openenv.core.env_server.types import State
18
+
19
+ from my_env.models import MyAction, MyObservation
20
+
21
+
22
class MyEnvironment(Environment):
    """
    Minimal echo environment for exercising the HTTP server plumbing.

    Each step returns the incoming message together with its length; reward
    grows with message length. Only episode id and step count are tracked.

    Example:
        >>> env = MyEnvironment()
        >>> obs = env.reset()
        >>> print(obs.echoed_message)  # "My Env environment ready!"
        >>>
        >>> obs = env.step(MyAction(message="Hello"))
        >>> print(obs.echoed_message)  # "Hello"
        >>> print(obs.message_length)  # 5
    """

    # Enable concurrent WebSocket sessions: state is fully per-instance, so
    # each client can safely get its own environment (factory mode in app.py).
    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self):
        """Create a fresh episode and zero the reset counter."""
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._reset_count = 0

    def reset(self) -> MyObservation:
        """
        Begin a new episode.

        Returns:
            MyObservation carrying the ready message.
        """
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._reset_count += 1
        return MyObservation(
            echoed_message="My Env environment ready!",
            message_length=0,
            done=False,
            reward=0.0,
        )

    def step(self, action: MyAction) -> MyObservation:  # type: ignore[override]
        """
        Echo the action's message back.

        Args:
            action: MyAction whose message is echoed.

        Returns:
            MyObservation with the echoed text, its length, and a reward of
            0.1 per character.
        """
        self._state.step_count += 1
        text = action.message
        size = len(text)
        return MyObservation(
            echoed_message=text,
            message_length=size,
            done=False,
            reward=size * 0.1,
            metadata={"original_message": text, "step": self._state.step_count},
        )

    @property
    def state(self) -> State:
        """
        Current environment state.

        Returns:
            State with episode_id and step_count.
        """
        return self._state
server/requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ openenv[core]>=0.2.0
2
+ fastapi>=0.115.0
3
+ uvicorn>=0.24.0
4
+
5
+
6
+
server/tasks.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Task registry for meta-learning.
9
+
10
+ Tasks can be from the internal registry (get_task(task_id)) or provided from outside
11
+ via task_spec_from_dict() — the client sends the task definition to the environment.
12
+ """
13
+
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List
16
+
17
+ import math
18
+
19
+ # Distribution A: 50 training tasks (low-freq sinusoids)
20
+ TRAIN_TASK_IDS: List[int] = list(range(50))
21
+
22
+ # Distribution B: held-out eval tasks (high-freq sinusoids — different distribution)
23
+ EVAL_TASK_IDS: List[int] = [50, 51]
24
+
25
+ # Bounds for each distribution (freq, amplitude, phase)
26
+ DIST_A_FREQ = (1.0, 3.0)
27
+ DIST_A_AMP = (0.5, 2.0)
28
+ DIST_B_FREQ = (4.0, 6.0)
29
+ DIST_B_AMP = (0.3, 1.5)
30
+
31
+
32
@dataclass
class TaskSpec:
    """Spec for one sinusoidal regression task."""

    task_id: int  # registry id (0..49 train, 50+ eval) or caller-chosen for external specs
    input_dim: int  # 1 for scalar sinusoid input (extra columns act as distractors)
    hidden_dim: int  # width of the MLP hidden layer
    output_dim: int  # regression target dimension
    data_seed: int  # seeds per-step batch generation
    arch_seed: int  # seeds model weight initialization
    # Sinusoidal target: y = amplitude * sin(2*pi*freq*x + phase)
    amplitude: float
    freq: float
    phase: float
    distribution: str  # "A" (train), "B" (eval), or "external" (client-supplied)
47
+
48
+
49
def get_task(task_id: int) -> TaskSpec:
    """
    Deterministically derive the TaskSpec for *task_id*.

    Task IDs 0..49 draw from Distribution A (train, low-frequency sinusoids);
    IDs 50 and above draw from Distribution B (held-out eval).

    Raises:
        ValueError: if task_id is negative.
    """
    if task_id < 0:
        raise ValueError(f"task_id must be >= 0, got {task_id}")
    # Cheap per-task mixer plus independent seeds for data and architecture.
    mixer = task_id * 7919 + 1
    data_seed = task_id * 31337
    arch_seed = task_id * 131 + 7
    hidden_dim = 32 + (mixer % 33)

    if task_id < 50:
        # Distribution A
        (f_lo, f_hi), (a_lo, a_hi), distribution = DIST_A_FREQ, DIST_A_AMP, "A"
    else:
        # Distribution B
        (f_lo, f_hi), (a_lo, a_hi), distribution = DIST_B_FREQ, DIST_B_AMP, "B"

    # Map the mixer into each parameter range — deterministic but varied per task.
    freq = f_lo + (mixer % 1000) / 1000.0 * (f_hi - f_lo)
    amplitude = a_lo + ((mixer * 3) % 1000) / 1000.0 * (a_hi - a_lo)
    phase = ((mixer * 7) % 1000) / 1000.0 * 2 * math.pi

    return TaskSpec(
        task_id=task_id,
        input_dim=1,
        hidden_dim=hidden_dim,
        output_dim=1,
        data_seed=data_seed,
        arch_seed=arch_seed,
        amplitude=amplitude,
        freq=freq,
        phase=phase,
        distribution=distribution,
    )
89
+
90
+
91
def task_spec_from_dict(d: Dict[str, Any]) -> TaskSpec:
    """
    Build a TaskSpec from an external dict (sent by the client).

    The task is defined outside the env; this function only parses it.

    Expected keys for type "sinusoid":
        type="sinusoid", amplitude, freq, phase (required);
        task_id (default 0), input_dim (default 1), hidden_dim (default 32),
        data_seed / arch_seed (derived from task_id when absent),
        distribution (default "external").

    Raises:
        ValueError: for an unknown task type.
        KeyError: if amplitude, freq, or phase is missing.
    """
    task_type = d.get("type", "sinusoid")
    if task_type != "sinusoid":
        raise ValueError(f"Unknown task type: {task_type}")
    # Coerce to int so the seed arithmetic below is well-defined even when the
    # client sends the id as a string or float (previously left uncoerced).
    task_id = int(d.get("task_id", 0))
    return TaskSpec(
        task_id=task_id,
        input_dim=int(d.get("input_dim", 1)),
        hidden_dim=int(d.get("hidden_dim", 32)),
        output_dim=1,
        data_seed=int(d.get("data_seed", task_id * 31337)),
        arch_seed=int(d.get("arch_seed", task_id * 131 + 7)),
        amplitude=float(d["amplitude"]),
        freq=float(d["freq"]),
        phase=float(d["phase"]),
        distribution=d.get("distribution", "external"),
    )
uv.lock ADDED
The diff for this file is too large to render. See raw diff