# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Harness-oriented Terminus session adapter."""

from __future__ import annotations

import json
from typing import Any, Callable

from openenv.core.env_server.mcp_types import CallToolAction, Tool
from openenv.core.harness import (
    ResourceSessionFactory,
    StepEnvSessionAdapter,
    ToolResult,
    VerifyResult,
)

from .client import TerminusEnv

# The single tool exposed to the model: one call either runs a shell command
# or submits the final answer.
_TERMINUS_TOOLS: list[Tool] = [
    Tool(
        name="terminal",
        description=(
            "Run a shell command in the Terminus sandbox, or submit final_answer "
            "to trigger verification."
        ),
        input_schema={
            "type": "object",
            "properties": {
                "command": {
                    "type": "string",
                    "description": "Shell command to run in the sandbox.",
                },
                "final_answer": {
                    "type": "string",
                    "description": "Final answer to submit when the task is complete.",
                },
            },
            "additionalProperties": False,
        },
    )
]

# Per-task keys that, when present in a task dict, override the factory-level
# sandbox defaults passed to reset.
_SANDBOX_OVERRIDE_KEYS: tuple[str, ...] = (
    "sandbox_image",
    "sandbox_flavor",
    "sandbox_timeout",
    "hf_sandbox_image",
    "hf_sandbox_flavor",
    "hf_sandbox_timeout",
    "forward_hf_token",
)


def _task_field(task: Any, *names: str, default: Any = None) -> Any:
    """Return the first non-``None`` value among ``names`` in a task dict.

    Args:
        task: The task description; anything that is not a ``dict`` yields
            ``default``.
        *names: Candidate keys, tried in order.
        default: Value returned when no candidate key holds a non-``None``
            value.
    """
    if not isinstance(task, dict):
        return default
    for name in names:
        value = task.get(name)
        if value is not None:
            return value
    return default


def _coerce_commands(value: Any) -> list[str]:
    """Normalize a command specification into a list of non-blank strings.

    ``None`` gives an empty list, a single string gives a one-element list
    (or an empty list when it is blank), and any other iterable is converted
    item by item with ``str``. Blank entries are dropped; surviving entries
    are kept unstripped.
    """
    if value is None:
        return []
    if isinstance(value, str):
        return [value] if value.strip() else []
    # Skip ``None`` entries explicitly: ``str(None)`` would otherwise turn
    # them into the literal command "None".
    commands = (str(item) for item in value if item is not None)
    return [command for command in commands if command.strip()]


def _format_initial_prompt(result: Any, task: Any) -> str:
    """Build the text of the first user message for a rollout.

    Args:
        result: The reset result; ``result.observation.metadata`` is consulted
            for ``verify_commands`` and ``message``.
        task: A plain string instruction, a list of chat messages (the last
            ``role == "user"`` entry's content is used), a dict (instruction
            read from ``instruction``/``prompt``/``question``/``task``), or
            any other value, which is stringified.

    Returns:
        The prompt sections joined by blank lines.
    """
    setup_commands: list[str] = []
    verify_commands: list[str] = []

    if isinstance(task, str):
        instruction = task
    elif isinstance(task, list):
        user_messages = [
            item.get("content")
            for item in task
            if isinstance(item, dict) and item.get("role") == "user"
        ]
        # Without any user message, fall back to the stringified list.
        instruction = str(user_messages[-1] if user_messages else task)
    elif isinstance(task, dict):
        instruction = str(
            _task_field(task, "instruction", "prompt", "question", "task", default="")
        )
        setup_commands = _coerce_commands(_task_field(task, "setup", "setup_scripts"))
        verify_commands = _coerce_commands(
            _task_field(task, "verify", "verify_scripts")
        )
    else:
        instruction = str(task or "")

    metadata = getattr(result.observation, "metadata", {}) or {}
    # Verify commands reported in the observation metadata take precedence
    # over the ones declared on the task.
    verify_commands = _coerce_commands(
        metadata.get("verify_commands") or verify_commands
    )

    parts = []
    if instruction:
        parts.append(f"Task:\n{instruction}")
    else:
        parts.append("Task:\nUse the terminal tool to solve the current task.")

    reset_message = metadata.get("message")
    if reset_message:
        parts.append(f"Environment:\n{reset_message}")

    if setup_commands:
        parts.append(
            "Setup commands have already run:\n"
            + "\n".join(f"- {command}" for command in setup_commands)
        )

    if verify_commands:
        parts.append(
            "Verification commands will run after final_answer:\n"
            + "\n".join(f"- {command}" for command in verify_commands)
        )

    parts.append(
        "Use terminal(command=...) to inspect and modify the sandbox. "
        "When finished, call terminal(final_answer=...) exactly once so "
        "verification runs and emits the environment reward."
    )
    return "\n\n".join(parts)


def _extract_tool_output(observation: Any) -> Any:
    """Pull the most useful payload out of a tool-call observation.

    Lookup order on ``observation.result``: a ``data`` attribute, a ``"data"``
    key (for dicts), then the ``text`` entries of a ``content`` list joined
    with newlines. When none of these apply the raw result is returned;
    a missing result yields ``None``.
    """
    result = getattr(observation, "result", None)
    if result is None:
        return None
    if hasattr(result, "data"):
        return result.data

    if isinstance(result, dict):
        if "data" in result:
            return result["data"]
        content = result.get("content")
        if isinstance(content, list):
            texts = [
                str(item.get("text"))
                for item in content
                if isinstance(item, dict) and item.get("text") is not None
            ]
            if texts:
                return "\n".join(texts)
        return result

    content = getattr(result, "content", None)
    if isinstance(content, list):
        # Stringify like the dict branch above so a non-``str`` ``text``
        # value cannot make ``"\n".join`` raise ``TypeError``.
        texts = [
            str(text)
            for text in (getattr(item, "text", None) for item in content)
            if text is not None
        ]
        if texts:
            return "\n".join(texts)
    return result


def _tool_error_message(observation: Any) -> str | None:
    """Return the observation's error as a string, or ``None`` if absent.

    Prefers a ``message`` attribute, then a ``"message"`` key when the error
    is a dict, and finally the stringified error itself.
    """
    error = getattr(observation, "error", None)
    if error is None:
        return None
    message = getattr(error, "message", None)
    if message is not None:
        return str(message)
    if isinstance(error, dict):
        return str(error.get("message") or error)
    return str(error)


def _state_to_data(state: Any) -> Any:
    """Convert a state object to plain data via ``model_dump`` when available.

    ``None`` and objects without a ``model_dump`` method are returned as-is.
    """
    if state is None:
        return None
    if hasattr(state, "model_dump"):
        return state.model_dump()
    return state


def _build_tool_result(
    tool_name: str,
    arguments: dict[str, Any],
    result: Any,
    state: Any,
) -> ToolResult:
    """Package one environment step into a harness ``ToolResult``.

    Args:
        tool_name: Name of the tool that was invoked.
        arguments: Arguments the tool was invoked with (copied, not aliased).
        result: Step result exposing ``observation``, ``reward`` and ``done``.
        state: Environment state after the step; stored in the metadata after
            conversion with ``_state_to_data``.
    """
    output = _extract_tool_output(result.observation)
    error = _tool_error_message(result.observation)
    data = {
        "tool_name": tool_name,
        "arguments": dict(arguments),
        "output": output,
        "reward": result.reward,
        "done": result.done,
    }
    # The "error" key is only present in the payload when there is an error.
    if error:
        data["error"] = error
    return ToolResult(
        data=data,
        done=bool(result.done),
        error=error,
        metadata={
            "reward": result.reward,
            "state": _state_to_data(state),
        },
    )


def _build_verify(
    transcript: list[dict[str, Any]],
    final_state: Any | None,
    last_result: Any | None,
    state: Any,
) -> VerifyResult:
    """Summarize a finished rollout as a harness ``VerifyResult``.

    Args:
        transcript: Recorded rollout entries; only its length is used.
        final_state: Stored unchanged under the ``final_rollout`` artifact.
        last_result: Last step result, source of the reward and ``done`` flag;
            ``None`` yields ``env_reward=None`` and ``done=False``.
        state: Environment state used to derive the metrics; missing
            attributes fall back to zero / ``None``.
    """
    reward = None if last_result is None else last_result.reward
    done = False if last_result is None else bool(last_result.done)
    state_data = _state_to_data(state)
    metrics = {
        "done": done,
        "step_count": getattr(state, "step_count", 0),
        "commands": len(getattr(state, "commands", []) or []),
        "verify_commands": len(getattr(state, "verify_commands", []) or []),
        # The setup metric is the number of entries in ``state.setup_results``.
        "setup_commands": len(getattr(state, "setup_results", []) or []),
        "submitted_answer": getattr(state, "submitted_answer", None) is not None,
        "sandbox_id": getattr(state, "sandbox_id", None),
    }
    # With no state available, approximate the step count from the transcript.
    if state is None and last_result is not None:
        metrics["step_count"] = len(transcript)
    return VerifyResult(
        env_reward=reward,
        done=done,
        metrics=metrics,
        artifacts={
            "final_state": state_data,
            "final_rollout": final_state,
            "transcript_length": len(transcript),
        },
    )


def _build_reset_kwargs(
    task: Any,
    default_setup: list[str],
    default_verify: list[str],
    default_sandbox: dict[str, Any],
) -> dict[str, Any]:
    """Assemble the keyword arguments passed to the environment reset.

    Starts from a copy of ``default_sandbox``. When ``task`` is a dict, its
    ``setup``/``setup_scripts`` and ``verify``/``verify_scripts`` entries
    replace the defaults and any recognised sandbox key overrides the
    corresponding default. ``setup`` and ``verify`` are only included when
    they are non-empty.
    """
    reset_kwargs: dict[str, Any] = dict(default_sandbox)
    setup = list(default_setup)
    verify = list(default_verify)

    if isinstance(task, dict):
        setup = _coerce_commands(
            _task_field(task, "setup", "setup_scripts", default=setup)
        )
        verify = _coerce_commands(
            _task_field(task, "verify", "verify_scripts", default=verify)
        )
        for key in _SANDBOX_OVERRIDE_KEYS:
            if key in task:
                reset_kwargs[key] = task[key]

    if setup:
        reset_kwargs["setup"] = setup
    if verify:
        reset_kwargs["verify"] = verify
    return reset_kwargs


class TerminusSessionFactory(ResourceSessionFactory):
    """Create Terminus-backed resource sessions for harness rollouts."""

    def __init__(
        self,
        client_factory: Callable[[], TerminusEnv],
        *,
        default_setup: list[str] | None = None,
        default_verify: list[str] | None = None,
        sandbox: dict[str, Any] | None = None,
    ):
        """Store the client factory and the per-session defaults.

        Args:
            client_factory: Zero-argument callable returning a ``TerminusEnv``
                client; invoked once per ``create`` call.
            default_setup: Setup commands used when a task declares none.
            default_verify: Verify commands used when a task declares none.
            sandbox: Default sandbox options forwarded as reset kwargs.
        """
        self._client_factory = client_factory
        # Copy the inputs so later caller-side mutation cannot affect sessions.
        self._default_setup = list(default_setup or [])
        self._default_verify = list(default_verify or [])
        self._sandbox = dict(sandbox or {})

    def create(
        self,
        task: Any = None,
        seed: int | None = None,
        episode_id: str | None = None,
    ) -> StepEnvSessionAdapter:
        """Build a session adapter wired to a client from the factory.

        Args:
            task: Task description; see ``_build_reset_kwargs`` and
                ``_format_initial_prompt`` for the accepted shapes.
            seed: Forwarded to the adapter.
            episode_id: Forwarded to the adapter.
        """
        reset_kwargs = _build_reset_kwargs(
            task,
            self._default_setup,
            self._default_verify,
            self._sandbox,
        )
        return StepEnvSessionAdapter(
            client=self._client_factory(),
            task=task,
            seed=seed,
            episode_id=episode_id,
            tool_specs=list(_TERMINUS_TOOLS),
            action_builder=lambda name, arguments: CallToolAction(
                tool_name=name,
                arguments=dict(arguments),
            ),
            initial_messages_builder=lambda result, current_task: [
                {
                    "role": "user",
                    "content": _format_initial_prompt(result, current_task),
                }
            ],
            tool_result_builder=_build_tool_result,
            verify_builder=_build_verify,
            reset_kwargs=reset_kwargs,
        )


def build_terminal_tool_call(response_text: str, *, call_id: str = "terminal-0"):
    """Parse a simple JSON terminal call from model text.

    The preferred format is one JSON object containing ``command`` or
    ``final_answer``, optionally wrapped in a Markdown code fence. Text that
    does not parse as a JSON object is forwarded verbatim as the shell
    ``command``, so the environment, not this parser, decides whether a
    rollout earns reward. A JSON object with neither key set produces an
    empty ``command``.

    Args:
        response_text: Raw model output.
        call_id: Identifier assigned to the returned tool call.

    Returns:
        A ``ToolCall`` for the ``terminal`` tool with string arguments.
    """
    # Imported lazily, as in the original code, so importing this module does
    # not require the LLM client module.
    from openenv.core.llm_client import ToolCall

    text = response_text.strip()
    if text.startswith("```"):
        # Drop the surrounding backticks and an optional "json" language tag.
        text = text.strip("`").strip()
        if text.startswith("json"):
            text = text[4:].strip()

    try:
        payload = json.loads(text)
    except json.JSONDecodeError:
        payload = {"command": response_text}

    # Valid JSON that is not an object (list, number, ...) is treated the
    # same way as unparseable text.
    if not isinstance(payload, dict):
        payload = {"command": response_text}

    arguments = {
        key: str(payload[key])
        for key in ("command", "final_answer")
        if payload.get(key) is not None
    }
    if not arguments:
        arguments = {"command": ""}
    return ToolCall(id=call_id, name="terminal", args=arguments)


__all__ = [
    "TerminusSessionFactory",
    "build_terminal_tool_call",
]