repl / client.py
sergiopaniego's picture
Upload folder using huggingface_hub
07fffa0 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
REPL Environment Client.
This module provides a unified client for the REPL Environment that works
with both remote servers (via WebSocket) and local execution (no server needed).
Examples:
# Connect to remote server with your HF token for sub-LLM calls
env = REPLEnv(base_url="https://my-server.hf.space")
result = env.reset(
context="...",
task_prompt="...",
hf_token=os.environ["HF_TOKEN"], # Server uses this for llm_query
)
# Run locally (no server)
env = REPLEnv()
# Local with LLM support
env = REPLEnv(llm_query_fn=my_llm, llm_batch_fn=my_batch)
# All use the same interface
result = env.execute("x = len(context)")
env.close()
"""
from __future__ import annotations
from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING
# Support both in-repo and standalone imports
try:
from openenv.core.client_types import StepResult
from openenv.core.env_client import EnvClient
from .models import REPLAction, REPLObservation, REPLState, CodeBlockResult
except ImportError:
from openenv.core.client_types import StepResult
from openenv.core.env_client import EnvClient
from models import REPLAction, REPLObservation, REPLState, CodeBlockResult
if TYPE_CHECKING:
from .server.repl_environment import REPLEnvironment
class REPLEnv:
"""
Unified client for the REPL Environment.
Works with both remote servers and local execution, providing the same
interface regardless of where the code runs.
Examples:
>>> # Connect to a running server
>>> with REPLEnv(base_url="http://localhost:8000") as env:
... result = env.reset(context="Hello World", task_prompt="Count chars")
... result = env.execute("count = len(context)")
... result = env.execute("print(f'FINAL({count})')")
... print(result.done) # True
>>> # Run locally without a server
>>> with REPLEnv() as env:
... result = env.reset(context="Hello World", task_prompt="Count chars")
... result = env.execute("count = len(context)")
... print(result.observation.result.success) # True
>>> # Local with LLM support for recursive calls
>>> def my_llm(prompt: str) -> str:
... return "LLM response"
>>> with REPLEnv(llm_query_fn=my_llm) as env:
... result = env.reset(context="...")
... result = env.execute("response = llm_query('Summarize: ' + context)")
>>> # From Docker image
>>> env = REPLEnv.from_docker_image("repl-env:latest")
>>> # From HuggingFace Hub
>>> env = REPLEnv.from_hub("openenv/repl-env")
"""
def __init__(
self,
base_url: Optional[str] = None,
*,
# Local-only options (ignored when base_url is set)
llm_query_fn: Optional[Callable[[str], str]] = None,
llm_batch_fn: Optional[Callable[[List[str]], List[str]]] = None,
max_output_length: int = 8192,
context_preview_length: int = 500,
reward_on_success: float = 1.0,
reward_on_iteration: float = 0.0,
reward_on_failure: float = -0.1,
reward_on_error: float = -0.05,
# Connection options (ignored when running locally)
connect_timeout_s: float = 10.0,
message_timeout_s: float = 60.0,
):
"""
Initialize REPL environment.
Args:
base_url: Server URL. If None, runs locally without a server.
llm_query_fn: Function for llm_query() calls (local mode only).
llm_batch_fn: Function for llm_query_batched() calls (local mode only).
max_output_length: Max stdout/stderr chars per execution (local only).
context_preview_length: Chars to show in context preview (local only).
reward_on_success: Reward when final answer submitted (local only).
reward_on_iteration: Reward per iteration step (local only).
reward_on_failure: Reward when max iterations reached (local only).
reward_on_error: Reward when code execution fails (local only).
connect_timeout_s: WebSocket connection timeout (remote only).
message_timeout_s: Message response timeout (remote only).
"""
self._base_url = base_url
self._local_env: Optional[REPLEnvironment] = None
self._remote_client: Optional[_RemoteREPLClient] = None
# Store local-mode options
self._llm_query_fn = llm_query_fn
self._llm_batch_fn = llm_batch_fn
self._max_output_length = max_output_length
self._context_preview_length = context_preview_length
self._reward_on_success = reward_on_success
self._reward_on_iteration = reward_on_iteration
self._reward_on_failure = reward_on_failure
self._reward_on_error = reward_on_error
# Store remote-mode options
self._connect_timeout_s = connect_timeout_s
self._message_timeout_s = message_timeout_s
# Provider for container/runtime lifecycle (set by factory methods)
self._provider = None
def _ensure_initialized(self) -> None:
"""Initialize the appropriate backend (local or remote)."""
if self._local_env is not None or self._remote_client is not None:
return
if self._base_url is None:
# Local mode: create REPLEnvironment directly
from .server.repl_environment import REPLEnvironment
self._local_env = REPLEnvironment(
max_output_length=self._max_output_length,
context_preview_length=self._context_preview_length,
reward_on_success=self._reward_on_success,
reward_on_iteration=self._reward_on_iteration,
reward_on_failure=self._reward_on_failure,
reward_on_error=self._reward_on_error,
llm_query_fn=self._llm_query_fn,
llm_batch_fn=self._llm_batch_fn,
)
else:
# Remote mode: create WebSocket client
self._remote_client = _RemoteREPLClient(
base_url=self._base_url,
connect_timeout_s=self._connect_timeout_s,
message_timeout_s=self._message_timeout_s,
provider=self._provider,
)
self._remote_client.connect()
def reset(
self,
*,
context: str = "",
task_prompt: str = "",
max_iterations: int = 30,
seed: Optional[int] = None,
episode_id: Optional[str] = None,
hf_token: Optional[str] = None,
llm_model: Optional[str] = None,
) -> StepResult[REPLObservation]:
"""
Reset the environment for a new episode.
Args:
context: Text content to analyze (accessible as `context` variable).
task_prompt: Description of the task to solve.
max_iterations: Maximum code execution steps before timeout.
seed: Optional random seed for reproducibility.
episode_id: Optional custom episode identifier.
hf_token: Optional HuggingFace token for llm_query/llm_query_batched.
When provided, the server uses this token for sub-LLM calls
instead of its own configured token.
llm_model: Optional model name for LLM functions (default: Qwen3-Coder-480B).
Returns:
StepResult with initial observation.
"""
self._ensure_initialized()
if self._local_env is not None:
# Local mode
self._local_env.max_iterations = max_iterations
obs = self._local_env.reset(
seed=seed,
episode_id=episode_id,
context=context,
task_prompt=task_prompt,
hf_token=hf_token,
llm_model=llm_model,
)
return self._wrap_observation(obs)
else:
# Remote mode
assert self._remote_client is not None
return self._remote_client.reset(
context=context,
task_prompt=task_prompt,
max_iterations=max_iterations,
seed=seed,
episode_id=episode_id,
hf_token=hf_token,
llm_model=llm_model,
)
def step(self, action: REPLAction) -> StepResult[REPLObservation]:
"""
Execute a REPL action.
Args:
action: REPLAction containing code to execute.
Returns:
StepResult with execution observation.
"""
self._ensure_initialized()
if self._local_env is not None:
obs = self._local_env.step(action)
return self._wrap_observation(obs)
else:
assert self._remote_client is not None
return self._remote_client.step(action)
def execute(self, code: str) -> StepResult[REPLObservation]:
"""
Execute Python code in the REPL.
Convenience method that wraps step() with a code-only action.
Args:
code: Python code to execute.
Returns:
StepResult with execution observation.
"""
return self.step(REPLAction(code=code))
def submit_final_answer(self, answer: str) -> StepResult[REPLObservation]:
"""
Submit a final answer and terminate the episode.
Args:
answer: The final answer string.
Returns:
StepResult with done=True.
"""
return self.step(REPLAction(code="", is_final=True, final_answer=answer))
def get_variable(self, name: str) -> StepResult[REPLObservation]:
"""
Retrieve and print a variable from the REPL namespace.
Args:
name: Variable name to retrieve.
Returns:
StepResult with variable value in stdout.
"""
return self.execute(f"print(repr({name}))")
def state(self) -> REPLState:
"""
Get current environment state.
Returns:
REPLState with current environment information.
"""
self._ensure_initialized()
if self._local_env is not None:
return self._local_env.state
else:
assert self._remote_client is not None
return self._remote_client.state()
def list_variables(self) -> List[str]:
"""
Get list of available variables in the current session.
Returns:
List of variable names.
"""
return self.state().namespace_keys
def close(self) -> None:
"""Clean up resources."""
if self._local_env is not None:
self._local_env.close()
self._local_env = None
if self._remote_client is not None:
self._remote_client.close()
self._remote_client = None
def _wrap_observation(self, obs: REPLObservation) -> StepResult[REPLObservation]:
"""Wrap a local REPLObservation in a StepResult."""
return StepResult(
observation=obs,
reward=obs.reward,
done=obs.done,
)
# Context manager support
def __enter__(self) -> "REPLEnv":
"""Enter context manager."""
self._ensure_initialized()
return self
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
"""Exit context manager."""
self.close()
# Factory methods
@classmethod
def from_docker_image(
cls,
image: str,
**kwargs: Any,
) -> "REPLEnv":
"""
Create a REPL environment by spinning up a Docker container.
Args:
image: Docker image name to run (e.g., "repl-env:latest").
**kwargs: Additional arguments passed to container start.
Returns:
Connected REPLEnv instance.
"""
from openenv.core.containers.runtime import LocalDockerProvider
provider = LocalDockerProvider()
base_url = provider.start_container(image, **kwargs)
provider.wait_for_ready(base_url)
env = cls(base_url=base_url)
env._provider = provider
env._ensure_initialized()
return env
@classmethod
def from_hub(
cls,
repo_id: str,
*,
use_docker: bool = True,
**kwargs: Any,
) -> "REPLEnv":
"""
Create a REPL environment from a HuggingFace Space.
Args:
repo_id: HuggingFace space identifier (e.g., "openenv/repl-env").
use_docker: If True, pull from HF registry. If False, run with UV.
**kwargs: Additional arguments passed to provider.
Returns:
Connected REPLEnv instance.
"""
if use_docker:
from openenv.core.containers.runtime import LocalDockerProvider
provider = LocalDockerProvider()
tag = kwargs.pop("tag", "latest")
image = f"registry.hf.space/{repo_id.replace('/', '-')}:{tag}"
base_url = provider.start_container(image, **kwargs)
provider.wait_for_ready(base_url)
else:
from openenv.core.containers.runtime import UVProvider
project_path = kwargs.pop(
"project_path", f"git+https://huggingface.co/spaces/{repo_id}"
)
provider = UVProvider(project_path=project_path, **kwargs)
base_url = provider.start()
provider.wait_for_ready()
env = cls(base_url=base_url)
env._provider = provider
env._ensure_initialized()
return env
class _RemoteREPLClient(EnvClient[REPLAction, REPLObservation, REPLState]):
"""
Internal WebSocket client for remote REPL connections.
This is the original EnvClient-based implementation, now used internally
by REPLEnv for remote mode.
"""
def _step_payload(self, action: REPLAction) -> Dict:
"""Convert REPLAction to JSON payload for step request."""
return {
"code": action.code,
"is_final": action.is_final,
"final_answer": action.final_answer,
}
def _parse_result(self, payload: Dict) -> StepResult[REPLObservation]:
"""Parse server response into StepResult[REPLObservation]."""
obs_data = payload.get("observation", {})
result_data = obs_data.get("result", {})
observation = REPLObservation(
result=CodeBlockResult(
stdout=result_data.get("stdout", ""),
stderr=result_data.get("stderr", ""),
locals_snapshot=result_data.get("locals_snapshot", {}),
execution_time=result_data.get("execution_time", 0.0),
success=result_data.get("success", True),
exception=result_data.get("exception"),
),
context_preview=obs_data.get("context_preview"),
context_length=obs_data.get("context_length", 0),
available_variables=obs_data.get("available_variables", []),
iteration=obs_data.get("iteration", 0),
max_iterations=obs_data.get("max_iterations", 30),
done=payload.get("done", False),
reward=payload.get("reward"),
metadata=obs_data.get("metadata", {}),
)
return StepResult(
observation=observation,
reward=payload.get("reward"),
done=payload.get("done", False),
)
def _parse_state(self, payload: Dict) -> REPLState:
"""Parse server response into REPLState object."""
return REPLState(
episode_id=payload.get("episode_id"),
step_count=payload.get("step_count", 0),
context=payload.get("context"),
task_prompt=payload.get("task_prompt"),
iteration=payload.get("iteration", 0),
max_iterations=payload.get("max_iterations", 30),
namespace_keys=payload.get("namespace_keys", []),
final_answer=payload.get("final_answer"),
total_execution_time=payload.get("total_execution_time", 0.0),
)