terminus_env / harness.py
burtenshaw's picture
burtenshaw HF Staff
Upload folder using huggingface_hub
3661dd4 verified
Raw
History Blame Contribute Delete
10.9 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Harness-oriented Terminus session adapter."""
from __future__ import annotations
import json
from typing import Any, Callable
from openenv.core.env_server.mcp_types import CallToolAction, Tool
from openenv.core.harness import (
ResourceSessionFactory,
StepEnvSessionAdapter,
ToolResult,
VerifyResult,
)
from .client import TerminusEnv
# The only tool exposed to the agent. The single "terminal" tool carries both
# shell execution (``command``) and submission (``final_answer``); the schema
# sets ``additionalProperties`` to False so no other argument is accepted.
_TERMINUS_TOOLS: list[Tool] = [
    Tool(
        name="terminal",
        description=(
            "Run a shell command in the Terminus sandbox, or submit final_answer "
            "to trigger verification."
        ),
        input_schema=dict(
            type="object",
            properties=dict(
                command=dict(
                    type="string",
                    description="Shell command to run in the sandbox.",
                ),
                final_answer=dict(
                    type="string",
                    description="Final answer to submit when the task is complete.",
                ),
            ),
            additionalProperties=False,
        ),
    )
]
def _task_field(task: Any, *names: str, default: Any = None) -> Any:
    """Return the first non-``None`` value stored under one of ``names``.

    Keys are tried in the order given, so callers list the preferred key
    first and its aliases after it. A task that is not a dict has no fields
    and always yields ``default``, as does a dict where every listed key is
    missing or ``None``.
    """
    if not isinstance(task, dict):
        return default
    candidates = (task.get(key) for key in names)
    return next((found for found in candidates if found is not None), default)
def _coerce_commands(value: Any) -> list[str]:
    """Normalize a command spec into a list of non-blank command strings.

    Accepts ``None`` (no commands), a single string, or an iterable of items.
    Items are stringified. Blank or whitespace-only entries are dropped, and
    kept entries are returned unstripped.

    ``None`` entries inside an iterable are skipped. Without this they would
    be stringified into the literal command ``"None"``. Each item is
    converted with ``str()`` only once.
    """
    if value is None:
        return []
    if isinstance(value, str):
        return [value] if value.strip() else []
    texts = (str(item) for item in value if item is not None)
    return [text for text in texts if text.strip()]
def _split_task(task: Any) -> tuple[str, list[str], list[str]]:
    """Normalize a task into ``(instruction, setup_commands, verify_commands)``.

    Supported task shapes:

    - ``str``: the string itself is the instruction.
    - ``list``: a chat transcript. The last ``user`` message with non-``None``
      content is the instruction. A non-empty list without one falls back to
      ``str(task)``, and an empty list yields no instruction.
    - ``dict``: the instruction comes from ``instruction``/``prompt``/
      ``question``/``task``. Commands come from ``setup``/``setup_scripts``
      and ``verify``/``verify_scripts``.
    - anything else: ``str(task)``, or ``""`` for falsy values.

    Only dict tasks contribute setup/verify commands.
    """
    if isinstance(task, str):
        return task, [], []
    if isinstance(task, list):
        user_contents = [
            item.get("content")
            for item in task
            if isinstance(item, dict)
            and item.get("role") == "user"
            and item.get("content") is not None
        ]
        if user_contents:
            return str(user_contents[-1]), [], []
        # Keep the str() fallback for non-empty lists. An empty list would
        # otherwise become the literal instruction "[]" instead of letting the
        # caller use its default prompt.
        return (str(task) if task else ""), [], []
    if isinstance(task, dict):
        instruction = str(
            _task_field(task, "instruction", "prompt", "question", "task", default="")
        )
        setup = _coerce_commands(_task_field(task, "setup", "setup_scripts"))
        verify = _coerce_commands(_task_field(task, "verify", "verify_scripts"))
        return instruction, setup, verify
    return str(task or ""), [], []


def _format_initial_prompt(result: Any, task: Any) -> str:
    """Build the first user message of a Terminus rollout.

    Joins these sections with blank lines:

    1. the task instruction, or a generic instruction when it is empty;
    2. the reset observation's ``message`` metadata, if present;
    3. the task's setup commands;
    4. the verification commands;
    5. a closing note on how to use the ``terminal`` tool.

    Args:
        result: Reset result. Only ``result.observation.metadata`` is read.
        task: Task spec. This can be a string, a chat message list, a dict,
            or any other value.

    Returns:
        The prompt text.
    """
    instruction, setup_commands, verify_commands = _split_task(task)
    metadata = getattr(result.observation, "metadata", {}) or {}
    # Verify commands reported by the reset observation take precedence over
    # the ones declared on the task.
    verify_commands = _coerce_commands(
        metadata.get("verify_commands") or verify_commands
    )
    parts = []
    if instruction:
        parts.append(f"Task:\n{instruction}")
    else:
        parts.append("Task:\nUse the terminal tool to solve the current task.")
    reset_message = metadata.get("message")
    if reset_message:
        parts.append(f"Environment:\n{reset_message}")
    # NOTE(review): only the task's own setup commands are listed here.
    # Factory-level default setup commands merged by `_build_reset_kwargs`
    # are not shown. Confirm this is intended.
    if setup_commands:
        parts.append(
            "Setup commands have already run:\n"
            + "\n".join(f"- {command}" for command in setup_commands)
        )
    if verify_commands:
        parts.append(
            "Verification commands will run after final_answer:\n"
            + "\n".join(f"- {command}" for command in verify_commands)
        )
    parts.append(
        "Use terminal(command=...) to inspect and modify the sandbox. "
        "When finished, call terminal(final_answer=...) exactly once so "
        "verification runs and emits the environment reward."
    )
    return "\n\n".join(parts)
def _extract_tool_output(observation: Any) -> Any:
    """Pull a model-readable payload out of a tool-call observation.

    ``observation.result`` may be a plain dict or an object exposing
    ``data``/``content`` attributes. Sources are tried in this order:

    1. a non-``None`` ``data`` value (structured tool output);
    2. the ``text`` of every ``content`` item, stringified and joined with
       newlines;
    3. ``None``, when the result has a ``data`` slot that is ``None`` and no
       text content exists;
    4. otherwise, the raw result unchanged.

    Returns ``None`` when the observation has no result at all.
    """
    result = getattr(observation, "result", None)
    if result is None:
        return None

    if isinstance(result, dict):
        has_data = "data" in result
        data = result.get("data")
        content = result.get("content")

        def text_of(item: Any) -> Any:
            return item.get("text") if isinstance(item, dict) else None

    else:
        has_data = hasattr(result, "data")
        data = getattr(result, "data", None)
        content = getattr(result, "content", None)

        def text_of(item: Any) -> Any:
            return getattr(item, "text", None)

    # Structured output wins. A `data` slot that is present but None must not
    # hide the text content, which may carry the actual tool output.
    if data is not None:
        return data
    if isinstance(content, list):
        texts = [str(text) for text in map(text_of, content) if text is not None]
        if texts:
            return "\n".join(texts)
    # An empty `data` slot means "no output". Anything else is passed through
    # untouched.
    return None if has_data else result
def _tool_error_message(observation: Any) -> str | None:
    """Return a readable error string for an observation, or ``None``.

    Resolution order:

    1. the error's ``message`` attribute, when it is not ``None``;
    2. for dict errors, the ``"message"`` entry, or the whole dict when that
       entry is missing or falsy;
    3. ``str(error)`` for anything else.
    """
    err = getattr(observation, "error", None)
    if err is None:
        return None
    attr_message = getattr(err, "message", None)
    if attr_message is None:
        if isinstance(err, dict):
            return str(err.get("message") or err)
        return str(err)
    return str(attr_message)
def _state_to_data(state: Any) -> Any:
    """Convert an environment state into plain data when possible.

    Any object with a ``model_dump`` attribute (e.g. a Pydantic model) is
    replaced by the result of calling it. ``None`` and all other values are
    returned unchanged.
    """
    return state.model_dump() if hasattr(state, "model_dump") else state
def _build_tool_result(
    tool_name: str,
    arguments: dict[str, Any],
    result: Any,
    state: Any,
) -> ToolResult:
    """Wrap one environment step as a harness ``ToolResult``.

    Args:
        tool_name: Name of the tool that was invoked.
        arguments: Arguments passed to the tool. A shallow copy is stored.
        result: Step result exposing ``observation``, ``reward`` and ``done``.
        state: Environment state, serialized into the result metadata.

    Returns:
        A ``ToolResult``. Its ``data`` echoes the call, the extracted output,
        the reward and the done flag, plus ``error`` when a non-empty error
        message was reported.
    """
    observation = result.observation
    output = _extract_tool_output(observation)
    error = _tool_error_message(observation)

    payload: dict[str, Any] = dict(
        tool_name=tool_name,
        arguments=dict(arguments),
        output=output,
        reward=result.reward,
        done=result.done,
    )
    if error:
        payload["error"] = error

    step_metadata = dict(reward=result.reward, state=_state_to_data(state))
    return ToolResult(
        data=payload,
        done=bool(result.done),
        error=error,
        metadata=step_metadata,
    )
def _build_verify(
    transcript: list[dict[str, Any]],
    final_state: Any | None,
    last_result: Any | None,
    state: Any,
) -> VerifyResult:
    """Summarize a finished rollout as a harness ``VerifyResult``.

    The environment reward and done flag come from the last step result.
    When there is no step result, the reward is ``None`` and done is False.
    Metrics are read from attributes of ``state``, with 0 or an empty list
    used for missing attributes. When ``state`` is ``None`` but a step
    result exists, ``step_count`` falls back to the transcript length.

    NOTE(review): the ``setup_commands`` metric counts
    ``state.setup_results``. Confirm that naming is intended.
    """
    if last_result is None:
        reward, done = None, False
    else:
        reward, done = last_result.reward, bool(last_result.done)
    state_data = _state_to_data(state)

    def count(attr: str) -> int:
        return len(getattr(state, attr, []) or [])

    if state is None and last_result is not None:
        step_count = len(transcript)
    else:
        step_count = getattr(state, "step_count", 0)

    metrics = {
        "done": done,
        "step_count": step_count,
        "commands": count("commands"),
        "verify_commands": count("verify_commands"),
        "setup_commands": count("setup_results"),
        "submitted_answer": getattr(state, "submitted_answer", None) is not None,
        "sandbox_id": getattr(state, "sandbox_id", None),
    }
    artifacts = {
        "final_state": state_data,
        "final_rollout": final_state,
        "transcript_length": len(transcript),
    }
    return VerifyResult(
        env_reward=reward,
        done=done,
        metrics=metrics,
        artifacts=artifacts,
    )
def _build_reset_kwargs(
    task: Any,
    default_setup: list[str],
    default_verify: list[str],
    default_sandbox: dict[str, Any],
) -> dict[str, Any]:
    """Assemble keyword arguments for the environment ``reset`` call.

    The result is built in this order:

    1. a copy of ``default_sandbox``;
    2. for dict tasks, any recognized sandbox option keys copied from the
       task, overriding the defaults;
    3. ``setup`` and ``verify`` command lists, each included only when
       non-empty.

    For dict tasks, a non-``None`` ``setup``/``setup_scripts`` or
    ``verify``/``verify_scripts`` entry replaces the corresponding default.
    Input containers are never mutated.
    """
    sandbox_options = (
        "sandbox_image",
        "sandbox_flavor",
        "sandbox_timeout",
        "hf_sandbox_image",
        "hf_sandbox_flavor",
        "hf_sandbox_timeout",
        "forward_hf_token",
    )
    kwargs: dict[str, Any] = dict(default_sandbox)
    setup_cmds = list(default_setup)
    verify_cmds = list(default_verify)

    if isinstance(task, dict):
        setup_cmds = _coerce_commands(
            _task_field(task, "setup", "setup_scripts", default=setup_cmds)
        )
        verify_cmds = _coerce_commands(
            _task_field(task, "verify", "verify_scripts", default=verify_cmds)
        )
        kwargs.update(
            (option, task[option]) for option in sandbox_options if option in task
        )

    for key, commands in (("setup", setup_cmds), ("verify", verify_cmds)):
        if commands:
            kwargs[key] = commands
    return kwargs
class TerminusSessionFactory(ResourceSessionFactory):
    """Create Terminus-backed resource sessions for harness rollouts.

    Each ``create`` call builds a fresh client through ``client_factory`` and
    wraps it in a ``StepEnvSessionAdapter`` that exposes the single
    ``terminal`` tool. The factory-level setup/verify commands and sandbox
    options are defaults that dict tasks may override.
    """

    def __init__(
        self,
        client_factory: Callable[[], TerminusEnv],
        *,
        default_setup: list[str] | None = None,
        default_verify: list[str] | None = None,
        sandbox: dict[str, Any] | None = None,
    ):
        self._client_factory = client_factory
        # Store copies so that later changes to the caller's containers do
        # not leak into sessions created by this factory.
        self._default_setup = list(default_setup) if default_setup else []
        self._default_verify = list(default_verify) if default_verify else []
        self._sandbox = dict(sandbox) if sandbox else {}

    def create(
        self,
        task: Any = None,
        seed: int | None = None,
        episode_id: str | None = None,
    ) -> StepEnvSessionAdapter:
        """Start a new session for ``task``.

        The reset kwargs are computed before the client is built, then both
        are handed to the adapter together with the tool spec and the
        builder callbacks.
        """

        def to_action(name: str, arguments: dict[str, Any]) -> CallToolAction:
            # Copy the arguments so the action does not alias the caller's dict.
            return CallToolAction(tool_name=name, arguments=dict(arguments))

        def opening_messages(result: Any, current_task: Any) -> list[dict[str, Any]]:
            prompt = _format_initial_prompt(result, current_task)
            return [{"role": "user", "content": prompt}]

        kwargs_for_reset = _build_reset_kwargs(
            task,
            self._default_setup,
            self._default_verify,
            self._sandbox,
        )
        client = self._client_factory()
        return StepEnvSessionAdapter(
            client=client,
            task=task,
            seed=seed,
            episode_id=episode_id,
            tool_specs=list(_TERMINUS_TOOLS),
            action_builder=to_action,
            initial_messages_builder=opening_messages,
            tool_result_builder=_build_tool_result,
            verify_builder=_build_verify,
            reset_kwargs=kwargs_for_reset,
        )
def build_terminal_tool_call(response_text: str, *, call_id: str = "terminal-0"):
    """Parse a simple JSON terminal call from model text.

    The preferred format is one JSON object containing ``command`` and/or
    ``final_answer``. The object may be wrapped in a Markdown code fence,
    and a leading ``json`` language tag is ignored regardless of case. Every
    key that is present with a non-``null`` value is kept and stringified.

    The fallbacks below let the environment, not this parser, decide whether
    a rollout earns reward:

    - Text that is not valid JSON, or JSON that is not an object, is sent
      verbatim as ``command``. The original, unstripped ``response_text`` is
      used.
    - A JSON object with neither key, or with only ``null`` values, becomes
      the no-op ``command=""``.

    Args:
        response_text: Raw model output.
        call_id: Identifier assigned to the resulting tool call.

    Returns:
        An ``openenv.core.llm_client.ToolCall`` targeting the ``terminal``
        tool.
    """
    from openenv.core.llm_client import ToolCall

    text = response_text.strip()
    if text.startswith("```"):
        # Drop the surrounding backtick fence, then an optional language tag.
        # The tag match ignores case so that "```JSON" fences also parse.
        text = text.strip("`").strip()
        if text[:4].lower() == "json":
            text = text[4:].strip()
    try:
        payload = json.loads(text)
    except json.JSONDecodeError:
        payload = None
    if not isinstance(payload, dict):
        payload = {"command": response_text}
    arguments = {
        key: str(payload[key])
        for key in ("command", "final_answer")
        if payload.get(key) is not None
    }
    if not arguments:
        arguments = {"command": ""}
    return ToolCall(id=call_id, name="terminal", args=arguments)
# Names exported by `from <module> import *`: the session factory and the
# model-text parser.
__all__ = ["TerminusSessionFactory", "build_terminal_tool_call"]