|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""TB2 Environment Client.""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
from typing import Any |
|
|
|
|
|
|
|
|
|
|
|
try: |
|
|
|
|
|
from openenv.core.client_types import StepResult |
|
|
from openenv.core.env_client import EnvClient |
|
|
|
|
|
from .models import Tbench2Action, Tbench2Observation, Tbench2State |
|
|
except ImportError: |
|
|
|
|
|
from openenv.core.client_types import StepResult |
|
|
from openenv.core.env_client import EnvClient |
|
|
|
|
|
from models import Tbench2Action, Tbench2Observation, Tbench2State |
|
|
|
|
|
|
|
|
class Tbench2Env(EnvClient[Tbench2Action, Tbench2Observation, Tbench2State]):
    """HTTP client for talking to a TB2 environment.

    Implements ``_step_payload``, ``_parse_result`` and ``_parse_state`` to
    convert between the TB2 model objects and plain ``dict`` payloads.
    """

    def _step_payload(self, action: Tbench2Action) -> dict[str, Any]:
        """Build the step payload dict from the fields of *action*.

        Args:
            action: The action whose attributes are copied into the payload.

        Returns:
            A dict holding the action's ``action_type``, ``command``,
            ``session_id``, ``block``, ``wait_seconds``, ``file_path`` and
            ``content`` values under keys of the same names.
        """
        # Listed in the order the keys should appear in the payload.
        field_names = (
            "action_type",
            "command",
            "session_id",
            "block",
            "wait_seconds",
            "file_path",
            "content",
        )
        return {name: getattr(action, name) for name in field_names}

    def _parse_result(self, payload: dict[str, Any]) -> StepResult[Tbench2Observation]:
        """Turn a step response payload into a ``StepResult``.

        Observation fields are read from ``payload["observation"]`` (an empty
        dict when that key is absent); ``reward`` and ``done`` are read from
        the top level of *payload* and set on both the observation and the
        returned result.

        Args:
            payload: Response dict for a step.

        Returns:
            A ``StepResult`` wrapping the parsed ``Tbench2Observation``.
        """
        raw_obs = payload.get("observation", {})
        reward = payload.get("reward")
        done = payload.get("done", False)

        # Missing keys fall back to the defaults given here; ``session_id``
        # falls back to None.
        obs_fields: dict[str, Any] = {
            "instruction": raw_obs.get("instruction", ""),
            "output": raw_obs.get("output", ""),
            "success": raw_obs.get("success", True),
            "error": raw_obs.get("error", ""),
            "task_id": raw_obs.get("task_id", ""),
            "task_path": raw_obs.get("task_path", ""),
            "session_id": raw_obs.get("session_id"),
            "action_type": raw_obs.get("action_type", ""),
            "info": raw_obs.get("info", {}),
            "reward": reward,
            "done": done,
            "metadata": raw_obs.get("metadata", {}),
        }
        return StepResult(
            observation=Tbench2Observation(**obs_fields),
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload: dict[str, Any]) -> Tbench2State:
        """Turn a state response payload into a ``Tbench2State``.

        Args:
            payload: Response dict describing the environment state.

        Returns:
            A ``Tbench2State`` populated from *payload*; keys that are absent
            use the defaults passed below (``episode_id`` defaults to None).
        """
        read = payload.get
        return Tbench2State(
            episode_id=read("episode_id"),
            step_count=read("step_count", 0),
            task_id=read("task_id", ""),
            task_path=read("task_path", ""),
            terminal_ready=read("terminal_ready", False),
            last_action_type=read("last_action_type", ""),
            last_command=read("last_command", ""),
            last_output=read("last_output", ""),
        )
|
|
|