finrl_env-v2-1-0 / src /core /env_client.py
burtenshaw's picture
burtenshaw HF Staff
Upload folder using huggingface_hub
b03d60d verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Environment client for persistent sessions.
This module provides a WebSocket-based client that maintains a persistent connection
to an environment server, enabling efficient multi-step interactions without
the overhead of HTTP request/response cycles.
The client is async by default. For synchronous usage, use the `.sync()` method
to get a `SyncEnvClient` wrapper.
Example (async):
>>> async with GenericEnvClient(base_url="ws://localhost:8000") as env:
... result = await env.reset()
... result = await env.step({"code": "print('hello')"})
Example (sync wrapper):
>>> env = GenericEnvClient(base_url="ws://localhost:8000").sync()
>>> with env:
... result = env.reset()
... result = env.step({"code": "print('hello')"})
"""
from __future__ import annotations
import asyncio
import json
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, Generic, Optional, Type, TYPE_CHECKING, TypeVar
from .client_types import StepResult, StateT
from .containers.runtime import LocalDockerProvider, UVProvider
from .utils import convert_to_ws_url
if TYPE_CHECKING:
from .containers.runtime import ContainerProvider, RuntimeProvider
from .sync_client import SyncEnvClient
from websockets.asyncio.client import ClientConnection
from websockets.asyncio.client import connect as ws_connect
ActT = TypeVar("ActT")
ObsT = TypeVar("ObsT")
EnvClientT = TypeVar("EnvClientT", bound="EnvClient")
class EnvClient(ABC, Generic[ActT, ObsT, StateT]):
    """
    Async environment client for persistent sessions.

    This client maintains a persistent WebSocket connection to an environment
    server, enabling efficient multi-step interactions. Each client instance
    corresponds to a dedicated environment session on the server.

    The client is async by default. For synchronous usage, use the `.sync()`
    method to get a `SyncEnvClient` wrapper.

    Features:
        - Lower latency for sequential interactions
        - Session state is maintained server-side
        - Better suited for long-running episodes
        - Async by default for modern Python async/await patterns

    Subclasses must implement ``_step_payload``, ``_parse_result`` and
    ``_parse_state`` to translate between typed objects and JSON payloads.

    Example (async):
        >>> from envs.coding_env.client import CodingEnv
        >>>
        >>> # Connect to a server using async context manager
        >>> async with CodingEnv(base_url="ws://localhost:8000") as env:
        ...     result = await env.reset(seed=42)
        ...     while not result.done:
        ...         action = agent.predict(result.observation)
        ...         result = await env.step(action)

    Example (sync wrapper):
        >>> env = CodingEnv(base_url="ws://localhost:8000").sync()
        >>> with env:
        ...     result = env.reset(seed=42)
        ...     result = env.step(action)
    """

    def __init__(
        self,
        base_url: str,
        connect_timeout_s: float = 10.0,
        message_timeout_s: float = 60.0,
        max_message_size_mb: float = 100.0,
        provider: Optional["ContainerProvider | RuntimeProvider"] = None,
        mode: Optional[str] = None,
    ):
        """
        Initialize environment client.

        Args:
            base_url: Base URL of the environment server (http:// or ws://).
                Will be converted to ws:// if http:// is provided.
            connect_timeout_s: Timeout for establishing WebSocket connection
            message_timeout_s: Timeout for receiving responses to messages
            max_message_size_mb: Maximum WebSocket message size in megabytes.
                Default 100MB to handle large observations (screenshots, DOM, etc.)
            provider: Optional container/runtime provider for lifecycle management.
                Can be a ContainerProvider (Docker) or RuntimeProvider (UV).
            mode: Communication mode: 'simulation' for Gym-style API (default) or
                'production' for MCP JSON-RPC protocol. Can also be set via the
                OPENENV_CLIENT_MODE environment variable. Constructor parameter
                takes precedence over environment variable. Case-insensitive.

        Raises:
            ValueError: If the resolved mode is neither 'simulation' nor
                'production'.
        """
        # Determine mode (constructor > env var > default)
        if mode is None:
            mode = os.environ.get("OPENENV_CLIENT_MODE", "simulation")

        # Normalize and validate mode
        mode = mode.lower()
        if mode not in ("simulation", "production"):
            raise ValueError(
                f"Invalid mode: '{mode}'. Must be 'simulation' or 'production'. "
                f"Set via constructor parameter or OPENENV_CLIENT_MODE environment variable."
            )

        # Store mode via object.__setattr__ so the write-once guard in
        # __setattr__ below is never involved in the initial assignment.
        object.__setattr__(self, "_mode", mode)

        # Convert HTTP URL to WebSocket URL; the session endpoint lives at /ws.
        ws_url = convert_to_ws_url(base_url)
        self._ws_url = f"{ws_url}/ws"

        self._connect_timeout = connect_timeout_s
        self._message_timeout = message_timeout_s
        # Convert MB to bytes
        self._max_message_size = int(max_message_size_mb * 1024 * 1024)
        self._provider = provider
        self._ws: Optional[ClientConnection] = None

    def __setattr__(self, name: str, value: Any) -> None:
        """Prevent modification of _mode after initialization."""
        if name == "_mode" and hasattr(self, "_mode"):
            raise AttributeError("Cannot modify mode after initialization")
        super().__setattr__(name, value)

    async def connect(self) -> "EnvClient":
        """
        Establish WebSocket connection to the server.

        For URLs that mention localhost / 127.0.0.1, ``NO_PROXY`` is
        temporarily extended for the duration of the connection attempt and
        restored afterwards. Note that this mutates process-wide environment
        state while the attempt is in flight.

        Returns:
            self for method chaining

        Raises:
            ConnectionError: If connection cannot be established
        """
        if self._ws is not None:
            return self

        # Bypass proxy for localhost connections
        url_lower = self._ws_url.lower()
        is_localhost = "localhost" in url_lower or "127.0.0.1" in url_lower
        old_no_proxy = os.environ.get("NO_PROXY")

        if is_localhost:
            current_no_proxy = old_no_proxy or ""
            # Compare whole comma-separated entries rather than substrings so
            # that e.g. NO_PROXY="localhost" still gets 127.0.0.1 appended and
            # an unrelated entry such as "mylocalhost.com" does not suppress
            # the bypass.
            listed = {entry.strip().lower() for entry in current_no_proxy.split(",")}
            missing = [
                host for host in ("localhost", "127.0.0.1") if host not in listed
            ]
            if missing:
                parts = ([current_no_proxy] if current_no_proxy else []) + missing
                os.environ["NO_PROXY"] = ",".join(parts)

        try:
            self._ws = await ws_connect(
                self._ws_url,
                open_timeout=self._connect_timeout,
                max_size=self._max_message_size,
            )
        except Exception as e:
            raise ConnectionError(f"Failed to connect to {self._ws_url}: {e}") from e
        finally:
            # Restore original NO_PROXY value
            if is_localhost:
                if old_no_proxy is None:
                    os.environ.pop("NO_PROXY", None)
                else:
                    os.environ["NO_PROXY"] = old_no_proxy

        return self

    async def disconnect(self) -> None:
        """Close the WebSocket connection.

        Both the protocol-level ``close`` message and the socket close are
        best effort: failures are ignored and the connection handle is
        always cleared.
        """
        if self._ws is None:
            return

        try:
            # Tell the server the session is over before closing the socket.
            await self._send({"type": "close"})
        except Exception:
            pass  # Best effort

        try:
            await self._ws.close()
        except Exception:
            pass  # Best effort

        self._ws = None

    async def _ensure_connected(self) -> None:
        """Ensure WebSocket connection is established."""
        if self._ws is None:
            await self.connect()

    async def _send(self, message: Dict[str, Any]) -> None:
        """Send a message over the WebSocket, connecting first if needed."""
        await self._ensure_connected()
        assert self._ws is not None  # narrowed by _ensure_connected()
        await self._ws.send(json.dumps(message))

    async def _receive(self) -> Dict[str, Any]:
        """Receive and parse a message from the WebSocket.

        Waits at most ``message_timeout_s`` seconds (as given to the
        constructor) for the next frame and decodes it as JSON.
        """
        assert self._ws is not None
        raw = await asyncio.wait_for(self._ws.recv(), timeout=self._message_timeout)
        return json.loads(raw)

    async def _send_and_receive(self, message: Dict[str, Any]) -> Dict[str, Any]:
        """Send a message and wait for response.

        Raises:
            RuntimeError: If the server replies with a message whose ``type``
                is ``"error"``.
        """
        await self._send(message)
        response = await self._receive()

        # Check for error response
        if response.get("type") == "error":
            # Tolerate an explicit `"data": null` in the error frame.
            error_data = response.get("data") or {}
            raise RuntimeError(
                f"Server error: {error_data.get('message', 'Unknown error')} "
                f"(code: {error_data.get('code', 'UNKNOWN')})"
            )

        return response

    @staticmethod
    def _stop_provider(
        provider: "ContainerProvider | RuntimeProvider",
        *,
        suppress_errors: bool = False,
    ) -> None:
        """Stop the container/process managed by ``provider``.

        Handles both provider flavours: objects exposing ``stop_container``
        and objects exposing ``stop``.

        Args:
            provider: Provider whose managed resource should be stopped.
            suppress_errors: When True, failures while stopping are ignored so
                that an exception already in flight is not masked.
        """
        try:
            if hasattr(provider, "stop_container"):
                provider.stop_container()
            elif hasattr(provider, "stop"):
                provider.stop()
        except Exception:
            if not suppress_errors:
                raise

    @classmethod
    async def from_docker_image(
        cls: Type[EnvClientT],
        image: str,
        provider: Optional["ContainerProvider"] = None,
        **kwargs: Any,
    ) -> EnvClientT:
        """
        Create an environment client by spinning up a Docker container.

        If the server never becomes ready or the connection fails, the
        container that was just started is stopped (best effort) before the
        error is re-raised, so a failed call does not leave it running.

        Args:
            image: Docker image name to run (e.g., "coding-env:latest")
            provider: Container provider to use (defaults to LocalDockerProvider)
            **kwargs: Additional arguments to pass to provider.start_container()

        Returns:
            Connected client instance
        """
        if provider is None:
            provider = LocalDockerProvider()

        # Start container
        base_url = provider.start_container(image, **kwargs)

        try:
            # Wait for server to be ready
            provider.wait_for_ready(base_url)

            # Create and connect client
            client = cls(base_url=base_url, provider=provider)
            await client.connect()
        except BaseException:
            # No client is handed back on failure, so nobody could call
            # close(); stop the container here instead of leaking it.
            cls._stop_provider(provider, suppress_errors=True)
            raise

        return client

    @classmethod
    async def from_env(
        cls: Type[EnvClientT],
        repo_id: str,
        *,
        use_docker: bool = True,
        provider: Optional["ContainerProvider | RuntimeProvider"] = None,
        **provider_kwargs: Any,
    ) -> EnvClientT:
        """
        Create a client from a Hugging Face Space.

        If startup fails after the container/process has been launched
        (readiness wait or connection error), it is stopped (best effort)
        before the error is re-raised.

        Args:
            repo_id: Hugging Face space identifier ``{org}/{space}``.
            use_docker: When ``True`` (default) pull from the HF registry and
                launch via :class:`LocalDockerProvider`. When ``False`` run the
                space locally with :class:`UVProvider`.
            provider: Optional provider instance to reuse. Must be a
                :class:`ContainerProvider` when ``use_docker=True`` and a
                :class:`RuntimeProvider` otherwise.
            provider_kwargs: Additional keyword arguments forwarded to
                either the container provider's ``start_container`` (docker)
                or to the ``UVProvider`` constructor/start (uv). When
                ``use_docker=False``, the ``project_path`` argument can be
                used to override the default git URL
                (``git+https://huggingface.co/spaces/{repo_id}``).

        Returns:
            Connected client instance

        Raises:
            ValueError: If ``use_docker=False``, a provider instance is
                supplied, and extra ``provider_kwargs`` remain.

        Examples:
            >>> # Pull and run from HF Docker registry
            >>> env = await MyEnv.from_env("openenv/echo-env")
            >>>
            >>> # Run locally with UV (clones the space)
            >>> env = await MyEnv.from_env("openenv/echo-env", use_docker=False)
            >>>
            >>> # Run from a local checkout
            >>> env = await MyEnv.from_env(
            ...     "openenv/echo-env",
            ...     use_docker=False,
            ...     project_path="/path/to/local/checkout"
            ... )
        """
        # Extract start args that apply to both providers
        start_args = {
            key: provider_kwargs.pop(key)
            for key in ("port", "env_vars", "workers")
            if key in provider_kwargs
        }

        if use_docker:
            # Docker mode: pull from HF registry
            docker_provider = provider or LocalDockerProvider()
            tag = provider_kwargs.pop("tag", "latest")
            image = f"registry.hf.space/{repo_id.replace('/', '-')}:{tag}"
            base_url = docker_provider.start_container(
                image, **start_args, **provider_kwargs
            )

            try:
                docker_provider.wait_for_ready(base_url)
                client = cls(base_url=base_url, provider=docker_provider)
                await client.connect()
            except BaseException:
                # Do not leak the container when startup fails.
                cls._stop_provider(docker_provider, suppress_errors=True)
                raise

            return client

        # UV mode: clone and run with uv
        if provider is None:
            uv_kwargs = dict(provider_kwargs)
            project_path = uv_kwargs.pop("project_path", None)
            if project_path is None:
                project_path = f"git+https://huggingface.co/spaces/{repo_id}"
            provider = UVProvider(project_path=project_path, **uv_kwargs)
        elif provider_kwargs:
            raise ValueError(
                "provider_kwargs cannot be used when supplying a provider instance"
            )

        base_url = provider.start(**start_args)

        try:
            provider.wait_for_ready()
            client = cls(base_url=base_url, provider=provider)
            await client.connect()
        except BaseException:
            # Do not leak the spawned process when startup fails.
            cls._stop_provider(provider, suppress_errors=True)
            raise

        return client

    @abstractmethod
    def _step_payload(self, action: ActT) -> Dict[str, Any]:
        """Convert an Action object to the JSON data expected by the env server."""
        raise NotImplementedError

    @abstractmethod
    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[ObsT]:
        """Convert a JSON response from the env server to StepResult[ObsT]."""
        raise NotImplementedError

    @abstractmethod
    def _parse_state(self, payload: Dict[str, Any]) -> StateT:
        """Convert a JSON response from the state endpoint to a State object."""
        raise NotImplementedError

    async def reset(self, **kwargs: Any) -> StepResult[ObsT]:
        """
        Reset the environment with optional parameters.

        Args:
            **kwargs: Optional parameters passed to the environment's reset method.
                Common parameters include:
                - seed: Random seed for reproducibility
                - episode_id: Custom episode identifier

        Returns:
            StepResult containing initial observation
        """
        message = {
            "type": "reset",
            "data": kwargs,
        }
        response = await self._send_and_receive(message)
        return self._parse_result(response.get("data", {}))

    async def step(self, action: ActT, **kwargs: Any) -> StepResult[ObsT]:
        """
        Execute an action in the environment.

        Args:
            action: The action to execute
            **kwargs: Optional parameters (currently ignored)

        Returns:
            StepResult containing observation, reward, and done status
        """
        message = {
            "type": "step",
            "data": self._step_payload(action),
        }
        response = await self._send_and_receive(message)
        return self._parse_result(response.get("data", {}))

    async def state(self) -> StateT:
        """
        Get the current environment state from the server.

        Returns:
            State object with environment state information
        """
        message = {"type": "state"}
        response = await self._send_and_receive(message)
        return self._parse_state(response.get("data", {}))

    async def close(self) -> None:
        """
        Close the WebSocket connection and clean up resources.

        If this client was created via from_docker_image() or from_env(),
        this will also stop and remove the associated container/process.
        The provider is stopped even if disconnecting is interrupted.
        """
        try:
            await self.disconnect()
        finally:
            if self._provider is not None:
                # Handle both ContainerProvider and RuntimeProvider
                self._stop_provider(self._provider)

    async def __aenter__(self) -> "EnvClient":
        """Enter async context manager, ensuring connection is established."""
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        """Exit async context manager, closing connection."""
        await self.close()

    def __enter__(self) -> "EnvClient":
        """Sync context manager entry - raises error suggesting async usage."""
        raise TypeError(
            "EnvClient is async by default. Use 'async with' instead of 'with', "
            "or call .sync() to get a synchronous wrapper:\n"
            "  async with client:  # async usage\n"
            "  with client.sync():  # sync wrapper"
        )

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Sync context manager exit - should not be reached."""
        pass  # pragma: no cover

    def sync(self) -> "SyncEnvClient":
        """
        Return a synchronous wrapper around this async client.

        Use this method when you need synchronous access to the environment
        without async/await syntax. This is useful for:
        - Integration with synchronous codebases
        - Interactive/REPL usage
        - Stopping async from "infecting" the call stack

        Returns:
            SyncEnvClient wrapper that provides synchronous methods

        Example:
            >>> # Create async client and get sync wrapper
            >>> async_client = GenericEnvClient(base_url="http://localhost:8000")
            >>> sync_client = async_client.sync()
            >>>
            >>> # Use synchronous API
            >>> with sync_client:
            ...     result = sync_client.reset()
            ...     result = sync_client.step({"code": "print('hello')"})
        """
        # Imported at call time; presumably to avoid a circular import with
        # sync_client (confirm against that module).
        from .sync_client import SyncEnvClient

        return SyncEnvClient(self)