Spaces:
Sleeping
Sleeping
File size: 4,748 Bytes
f303e4c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Synchronous wrapper for async EnvClient.
This module provides a SyncEnvClient that wraps an async EnvClient,
allowing synchronous usage while the underlying client uses async I/O.
Example:
>>> from openenv.core import GenericEnvClient
>>>
>>> # Create async client and get sync wrapper
>>> async_client = GenericEnvClient(base_url="http://localhost:8000")
>>> sync_client = async_client.sync()
>>>
>>> # Use synchronous API
>>> with sync_client:
... result = sync_client.reset()
... result = sync_client.step({"code": "print('hello')"})
"""
from __future__ import annotations
from typing import Any, Dict, Generic, TYPE_CHECKING, TypeVar
from .client_types import StepResult, StateT
from .utils import run_async_safely
if TYPE_CHECKING:
from .env_client import EnvClient
# Generic parameter for the action type accepted by ``SyncEnvClient.step()``.
# Declared locally in this module; the third parameter, ``StateT``, is instead
# imported from ``.client_types``.
ActT = TypeVar("ActT")
# Generic parameter for the observation type carried inside ``StepResult``
# values returned by ``reset()`` / ``step()``.
ObsT = TypeVar("ObsT")
class SyncEnvClient(Generic[ActT, ObsT, StateT]):
    """
    Blocking facade over an asynchronous ``EnvClient``.

    Each public operation forwards to the wrapped async client and drives
    the resulting coroutine to completion with ``run_async_safely()``. This
    lets plain synchronous code use the client without turning its own call
    chain into ``async`` functions.

    Example:
        >>> async_client = GenericEnvClient(base_url="http://localhost:8000")
        >>> sync_client = async_client.sync()
        >>>
        >>> # The context manager connects on entry and closes on exit
        >>> with sync_client:
        ...     result = sync_client.reset()
        ...     result = sync_client.step({"action": "test"})

    Attributes:
        _async: The async EnvClient instance that performs the real work.
    """

    def __init__(self, async_client: "EnvClient[ActT, ObsT, StateT]"):
        """
        Wrap ``async_client`` for synchronous use.

        Args:
            async_client: The async EnvClient whose coroutines this wrapper
                runs to completion.
        """
        self._async = async_client

    @property
    def async_client(self) -> "EnvClient[ActT, ObsT, StateT]":
        """The async EnvClient that this wrapper forwards to."""
        return self._async

    # -- connection lifecycle -------------------------------------------

    def connect(self) -> "SyncEnvClient[ActT, ObsT, StateT]":
        """
        Open the connection by running the async client's ``connect()``.

        Returns:
            This wrapper itself, so calls can be chained.
        """
        pending = self._async.connect()
        run_async_safely(pending)
        return self

    def disconnect(self) -> None:
        """Run the async client's ``disconnect()`` to completion."""
        pending = self._async.disconnect()
        run_async_safely(pending)

    # -- environment operations -----------------------------------------

    def reset(self, **kwargs: Any) -> StepResult[ObsT]:
        """
        Reset the environment through the async client's ``reset()``.

        Args:
            **kwargs: Passed through unchanged to the async ``reset()``.

        Returns:
            The StepResult from the async ``reset()``, holding the initial
            observation.
        """
        pending = self._async.reset(**kwargs)
        return run_async_safely(pending)

    def step(self, action: ActT, **kwargs: Any) -> StepResult[ObsT]:
        """
        Apply ``action`` through the async client's ``step()``.

        Args:
            action: The action to send to the environment.
            **kwargs: Extra options passed through unchanged to the async
                ``step()``.

        Returns:
            The StepResult from the async ``step()`` (observation, reward,
            done flag).
        """
        pending = self._async.step(action, **kwargs)
        return run_async_safely(pending)

    def state(self) -> StateT:
        """
        Fetch the environment state through the async client's ``state()``.

        Returns:
            Whatever state object the async ``state()`` produces.
        """
        pending = self._async.state()
        return run_async_safely(pending)

    def close(self) -> None:
        """Run the async client's ``close()`` to completion."""
        pending = self._async.close()
        run_async_safely(pending)

    # -- context-manager protocol ---------------------------------------

    def __enter__(self) -> "SyncEnvClient[ActT, ObsT, StateT]":
        """Connect and hand back this wrapper (``connect()`` returns self)."""
        return self.connect()

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Close the wrapped client; exceptions are never suppressed."""
        self.close()

    # -- hooks forwarded verbatim to the wrapped client -----------------

    def _step_payload(self, action: ActT) -> Dict[str, Any]:
        """Forward to the wrapped client's ``_step_payload()``."""
        inner = self._async
        return inner._step_payload(action)

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[ObsT]:
        """Forward to the wrapped client's ``_parse_result()``."""
        inner = self._async
        return inner._parse_result(payload)

    def _parse_state(self, payload: Dict[str, Any]) -> StateT:
        """Forward to the wrapped client's ``_parse_state()``."""
        inner = self._async
        return inner._parse_state(payload)
|