File size: 2,886 Bytes
5d897b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""TB2 Environment Client."""

from __future__ import annotations

from typing import Any


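# Support both package-relative and standalone (top-level) imports of the
# local ``models`` module; the openenv imports are the same in either case.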
try:
    # In-repo imports (when running from OpenEnv repository)
    from openenv.core.client_types import StepResult
    from openenv.core.env_client import EnvClient

    from .models import Tbench2Action, Tbench2Observation, Tbench2State
except ImportError:
    # Standalone imports (when environment is standalone with openenv from pip)
    from openenv.core.client_types import StepResult
    from openenv.core.env_client import EnvClient

    from models import Tbench2Action, Tbench2Observation, Tbench2State


class Tbench2Env(EnvClient[Tbench2Action, Tbench2Observation, Tbench2State]):
    """Client for the TB2 environment.

    Provides the three conversion hooks defined below: turning an action into
    a request dict, and rebuilding a step result and a state object from
    response dicts.
    """

    def _step_payload(self, action: Tbench2Action) -> dict[str, Any]:
        """Return a dict holding the fields of ``action`` to send with a step.

        Args:
            action: The action whose attributes are copied into the payload.

        Returns:
            A dict keyed by attribute name, in the order listed below.
        """
        field_names = (
            "action_type",
            "command",
            "session_id",
            "block",
            "wait_seconds",
            "file_path",
            "content",
        )
        return {name: getattr(action, name) for name in field_names}

    def _parse_result(self, payload: dict[str, Any]) -> StepResult[Tbench2Observation]:
        """Build a ``StepResult`` from a response dict.

        Observation fields are read from ``payload["observation"]`` (an empty
        dict is used when that key is absent), each with a fallback value.
        ``reward`` and ``done`` are read from the top level of ``payload`` and
        are set on both the observation and the returned ``StepResult``.

        Args:
            payload: The response dict to decode.

        Returns:
            A ``StepResult`` wrapping the reconstructed observation.
        """
        raw_obs = payload.get("observation", {})
        reward = payload.get("reward")
        is_done = payload.get("done", False)

        # (field name, fallback used when the field is missing) pairs.
        obs_fields = (
            ("instruction", ""),
            ("output", ""),
            ("success", True),
            ("error", ""),
            ("task_id", ""),
            ("task_path", ""),
            ("session_id", None),
            ("action_type", ""),
            ("info", {}),
            ("metadata", {}),
        )
        obs_kwargs = {
            name: raw_obs.get(name, fallback) for name, fallback in obs_fields
        }
        observation = Tbench2Observation(reward=reward, done=is_done, **obs_kwargs)
        return StepResult(observation=observation, reward=reward, done=is_done)

    def _parse_state(self, payload: dict[str, Any]) -> Tbench2State:
        """Build a ``Tbench2State`` from a response dict.

        Args:
            payload: The response dict; every field is read from its top level.

        Returns:
            A ``Tbench2State`` whose missing fields take the fallbacks below.
        """
        # (field name, fallback used when the field is missing) pairs.
        state_fields = (
            ("episode_id", None),
            ("step_count", 0),
            ("task_id", ""),
            ("task_path", ""),
            ("terminal_ready", False),
            ("last_action_type", ""),
            ("last_command", ""),
            ("last_output", ""),
        )
        state_kwargs = {
            name: payload.get(name, fallback) for name, fallback in state_fields
        }
        return Tbench2State(**state_kwargs)