# ml-agent / eval / solvers.py
# NOTE(review): the lines below were Hugging Face Hub page chrome captured by
# the export ("akseljoonas (HF Staff) — Initial commit: ML Agent with Xet
# storage for binaries", commit 8cfacd3); converted to comments so the module
# parses as Python.
"""
Collection of Inspect AI solvers used by the rubric task.
"""
from __future__ import annotations
import asyncio
import json
import os
import tempfile
from typing import Callable, Dict, List, Sequence
import litellm
from inspect_ai.model import ChatMessageAssistant, ModelOutput
from inspect_ai.solver import Solver, solver
from inspect_ai.solver._task_state import TaskState
from lmnr import Laminar, LaminarLiteLLMCallback
from eval.hf_agent_connector import AgentResponseGenerator
async def _run_subprocess(command: Sequence[str]) -> str:
    """Run *command* as an async subprocess and return its stripped stdout.

    Args:
        command: The argv sequence to execute (no shell involved).

    Returns:
        The process's stdout, decoded and stripped of surrounding whitespace.

    Raises:
        RuntimeError: when the process exits non-zero; the message carries
            the exit code and the stripped stderr text.
    """
    process = await asyncio.create_subprocess_exec(
        *command,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    stdout, stderr = await process.communicate()
    # Success path first: hand back decoded stdout immediately.
    if process.returncode == 0:
        return stdout.decode().strip()
    raise RuntimeError(
        f"Command {' '.join(command)} failed with code {process.returncode}:\n"
        f"{stderr.decode().strip()}"
    )
@solver(name="hf_agent")
def hf_agent(
    config_path: str = "agent/config_mcp_example.json",
    max_iterations: int = 10,
) -> Solver:
    """Solver that routes each sample through an AgentResponseGenerator.

    Args:
        config_path: Path to the agent's MCP config JSON.
        max_iterations: Upper bound on agent tool-use iterations per sample.
    """
    # Wire up Laminar observability before the agent handles any traffic.
    Laminar.initialize(project_api_key=os.environ.get("LMNR_API_KEY"))
    litellm.callbacks = [LaminarLiteLLMCallback()]
    print("✅ Laminar initialized")

    agent = AgentResponseGenerator(
        config_path=config_path,
        max_iterations=max_iterations,
    )

    async def solve(state: TaskState, generate) -> TaskState:
        reply = await agent.run(state.input_text)
        message = ChatMessageAssistant(
            content=reply,
            model=agent.model_name,
            source="generate",
        )
        state.messages.append(message)
        state.output = ModelOutput.from_message(message)
        state.completed = True
        return state

    return solve
@solver(name="claude_code")
def claude_code(
    output_format: str = "json",
    mcp_config: str | None = None,
) -> Solver:
    """Solver that shells out to the `claude` CLI in print (`-p`) mode.

    Args:
        output_format: Passed to the CLI's ``--output-format`` flag; must be
            one of "text", "json" or "stream-json".
        mcp_config: Optional path to an MCP config JSON file, forwarded via
            ``--mcp-config``.

    Raises:
        ValueError: if *output_format* is not a supported value.
    """
    if output_format not in {"text", "json", "stream-json"}:
        raise ValueError("output_format must be one of: text, json, stream-json")

    async def solve(state: TaskState, generate) -> TaskState:
        prompt = state.input_text
        cmd: List[str] = ["claude", "-p", prompt, "--output-format", output_format]
        if mcp_config:
            cmd += ["--mcp-config", mcp_config]
        stdout = await _run_subprocess(cmd)
        response_text = stdout
        session_id = None
        if output_format in {"json", "stream-json"}:
            # stream-json may emit multiple JSON objects; take the last complete line
            lines = stdout.strip().splitlines()
            # BUGFIX: empty CLI output used to make `splitlines()[-1]` raise
            # IndexError; skip JSON parsing entirely when there is no output.
            if lines:
                try:
                    payload = json.loads(lines[-1])
                    response_text = (
                        payload.get("result") or payload.get("message", "") or stdout
                    )
                    session_id = payload.get("session_id")
                except (json.JSONDecodeError, AttributeError):
                    # Last line was not a JSON object — fall back to raw stdout.
                    response_text = stdout
        assistant_message = ChatMessageAssistant(
            content=response_text,
            model="claude-code",
            source="generate",
            metadata={"session_id": session_id} if session_id else None,
        )
        state.messages.append(assistant_message)
        state.output = ModelOutput.from_message(assistant_message)
        state.completed = True
        return state

    return solve
@solver(name="claude_code+hf_mcp")
def claude_code_hf_mcp(
    output_format: str = "json",
    hf_token: str | None = None,
) -> Solver:
    """
    A solver that uses Claude Code with the Hugging Face MCP server.
    Requires HF_TOKEN in environment variables or passed as argument.
    """
    token = hf_token or os.environ.get("HF_TOKEN")
    if not token:
        raise ValueError(
            "HF_TOKEN not found. Please set HF_TOKEN env var or pass it to the solver."
        )

    # MCP server entry pointing Claude Code at the Hugging Face MCP endpoint,
    # authenticated with the caller's token.
    hf_server = {
        "type": "http",
        "url": "https://huggingface.co/mcp",
        "headers": {"Authorization": f"Bearer {token}"},
    }
    mcp_config = {"mcpServers": {"huggingface": hf_server}}

    async def solve(state: TaskState, generate) -> TaskState:
        # The claude CLI reads MCP config from a file, so persist it briefly.
        handle = tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False)
        with handle:
            json.dump(mcp_config, handle, indent=2)
        config_path = handle.name
        try:
            # Delegate the actual CLI invocation to the base claude_code solver.
            inner = claude_code(output_format=output_format, mcp_config=config_path)
            return await inner(state, generate)
        finally:
            # Remove the token-bearing temp file no matter how solve() exits.
            if os.path.exists(config_path):
                os.remove(config_path)

    return solve
# Maps solver names (as referenced by task configuration / get_solver) to the
# factory functions defined above; register new solvers here.
SOLVER_REGISTRY: Dict[str, Callable[..., Solver]] = {
    "hf_agent": hf_agent,
    "claude_code": claude_code,
    "claude_code+hf_mcp": claude_code_hf_mcp,
}
def get_solver(name: str, **kwargs) -> Solver:
    """Build the solver registered under *name*, forwarding **kwargs to it.

    Raises:
        ValueError: when *name* is not in SOLVER_REGISTRY; the message lists
            every available solver name.
    """
    try:
        factory = SOLVER_REGISTRY[name]
    except KeyError as err:
        # Chain the KeyError so tracebacks show the failed registry lookup.
        available = ", ".join(sorted(SOLVER_REGISTRY))
        raise ValueError(f"Unknown solver '{name}'. Available: {available}") from err
    return factory(**kwargs)