File size: 4,719 Bytes
81b02bf
 
 
 
 
 
 
b74674a
81b02bf
b74674a
 
81b02bf
b74674a
81b02bf
 
 
 
b74674a
81b02bf
 
 
 
 
 
 
 
 
 
 
 
b74674a
81b02bf
b74674a
81b02bf
b74674a
 
81b02bf
b74674a
 
 
 
 
 
81b02bf
b74674a
81b02bf
 
b74674a
 
81b02bf
 
b74674a
81b02bf
 
 
 
 
 
b74674a
81b02bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b74674a
81b02bf
 
 
 
 
 
 
 
 
 
 
b74674a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
REPL Environment clients.

`REPLEnv` is the standard async OpenEnv client for remote/server-backed usage.
Use `async with` / `await` directly, or call `.sync()` for synchronous code.

This module intentionally contains only the remote OpenEnv client.
"""

from __future__ import annotations

from typing import Any

try:
    from openenv.core.client_types import StepResult
    from openenv.core.env_client import EnvClient

    from .models import CodeBlockResult, REPLAction, REPLObservation, REPLState
except ImportError:
    from models import CodeBlockResult, REPLAction, REPLObservation, REPLState
    from openenv.core.client_types import StepResult
    from openenv.core.env_client import EnvClient


class REPLEnv(EnvClient[REPLAction, REPLObservation, REPLState]):
    """
    Async client for the remote REPL environment.

    Use this client when connecting to a running OpenEnv server over WebSocket.
    For synchronous code, call `.sync()` on an instance.

    Example:
        >>> async with REPLEnv(base_url="http://localhost:8000") as env:
        ...     result = await env.reset(context="Hello World", task_prompt="Count chars")
        ...     result = await env.execute("count = len(context)")
        ...     result = await env.execute("print(f'FINAL({count})')")
        ...     print(result.done)

        >>> with REPLEnv(base_url="http://localhost:8000").sync() as env:
        ...     result = env.reset(context="Hello World", task_prompt="Count chars")
        ...     result = env.execute("count = len(context)")
        ...     result = env.execute("print(f'FINAL({count})')")
        ...     print(result.done)
    """

    # Fallback for ``max_iterations`` when a server payload omits it. Shared by
    # ``_parse_result`` and ``_parse_state`` so the two parsers cannot drift.
    # NOTE(review): presumably mirrors the server-side default — confirm.
    _DEFAULT_MAX_ITERATIONS: int = 30

    def _step_payload(self, action: REPLAction) -> dict[str, Any]:
        """Serialize ``action`` into the JSON body sent to the server for a step.

        Args:
            action: The REPL action to send.

        Returns:
            A dict holding the action's ``code``, ``is_final`` and
            ``final_answer`` fields.
        """
        return {
            "code": action.code,
            "is_final": action.is_final,
            "final_answer": action.final_answer,
        }

    def _parse_result(self, payload: dict[str, Any]) -> StepResult[REPLObservation]:
        """Build a ``StepResult`` from a decoded server response.

        Missing fields fall back to defaults. Nested containers that are
        absent *or* explicitly ``null`` are treated as empty, so a payload such
        as ``{"observation": None}`` does not raise ``AttributeError``.

        Args:
            payload: Decoded response. The ``observation`` (with a nested
                ``result``), ``done`` and ``reward`` keys are read.

        Returns:
            A ``StepResult`` wrapping the parsed ``REPLObservation``. ``done``
            and ``reward`` are copied onto both the observation and the result.
        """
        # ``or {}`` rather than a ``.get`` default: ``dict.get`` only applies the
        # default when the key is missing, not when the server sends ``null``.
        obs_data = payload.get("observation") or {}
        result_data = obs_data.get("result") or {}
        done = payload.get("done", False)
        reward = payload.get("reward")

        observation = REPLObservation(
            result=CodeBlockResult(
                stdout=result_data.get("stdout", ""),
                stderr=result_data.get("stderr", ""),
                locals_snapshot=result_data.get("locals_snapshot") or {},
                execution_time=result_data.get("execution_time", 0.0),
                success=result_data.get("success", True),
                exception=result_data.get("exception"),
            ),
            context_preview=obs_data.get("context_preview"),
            context_length=obs_data.get("context_length", 0),
            available_variables=obs_data.get("available_variables") or [],
            iteration=obs_data.get("iteration", 0),
            max_iterations=obs_data.get(
                "max_iterations", self._DEFAULT_MAX_ITERATIONS
            ),
            done=done,
            reward=reward,
            metadata=obs_data.get("metadata") or {},
        )

        return StepResult(
            observation=observation,
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload: dict[str, Any]) -> REPLState:
        """Build a ``REPLState`` from a decoded server state response.

        Args:
            payload: Decoded state dict. Missing keys fall back to defaults; a
                missing or ``null`` ``namespace_keys`` becomes an empty list.

        Returns:
            The parsed ``REPLState``.
        """
        return REPLState(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
            context=payload.get("context"),
            task_prompt=payload.get("task_prompt"),
            iteration=payload.get("iteration", 0),
            max_iterations=payload.get(
                "max_iterations", self._DEFAULT_MAX_ITERATIONS
            ),
            namespace_keys=payload.get("namespace_keys") or [],
            final_answer=payload.get("final_answer"),
            total_execution_time=payload.get("total_execution_time", 0.0),
        )

    async def execute(self, code: str) -> StepResult[REPLObservation]:
        """Execute Python code in the REPL.

        Args:
            code: Source text sent as a ``REPLAction``.

        Returns:
            The ``StepResult`` produced by ``step``.
        """
        return await self.step(REPLAction(code=code))

    async def submit_final_answer(self, answer: str) -> StepResult[REPLObservation]:
        """Submit a final answer and terminate the episode.

        Sends an action with empty ``code``, ``is_final=True`` and
        ``final_answer=answer``.

        Args:
            answer: The final answer to report.

        Returns:
            The ``StepResult`` produced by ``step``.
        """
        return await self.step(REPLAction(code="", is_final=True, final_answer=answer))

    async def get_variable(self, name: str) -> StepResult[REPLObservation]:
        """Retrieve and print a variable from the REPL namespace.

        ``name`` is interpolated verbatim into ``print(repr(<name>))`` and
        executed remotely. Any Python expression therefore works (e.g.
        ``"obj.attr"``), but nothing is validated or escaped. The printed
        ``repr`` is expected in the result's ``stdout``.

        Args:
            name: Variable name (or expression) to evaluate.

        Returns:
            The ``StepResult`` produced by ``execute``.
        """
        return await self.execute(f"print(repr({name}))")

    async def list_variables(self) -> list[str]:
        """Return the current REPL namespace keys.

        Fetches the environment state via ``state()`` and returns its
        ``namespace_keys``.
        """
        return (await self.state()).namespace_keys