Spaces:
Running
Running
File size: 5,388 Bytes
035d186 00d49da 035d186 9de209d 035d186 9de209d 035d186 7e21458 035d186 9de209d 035d186 00d49da 035d186 7e21458 035d186 00d49da 035d186 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
"""
Collection of Inspect AI solvers used by the rubric task.
"""
from __future__ import annotations
import asyncio
import json
import os
import tempfile
from typing import Callable, Dict, List, Sequence
import litellm
from inspect_ai.model import ChatMessageAssistant, ModelOutput
from inspect_ai.solver import Solver, solver
from inspect_ai.solver._task_state import TaskState
from lmnr import Laminar, LaminarLiteLLMCallback
from eval.hf_agent_connector import AgentResponseGenerator
async def _run_subprocess(command: Sequence[str]) -> str:
process = await asyncio.create_subprocess_exec(
*command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
stdout, stderr = await process.communicate()
if process.returncode != 0:
raise RuntimeError(
f"Command {' '.join(command)} failed with code {process.returncode}:\n"
f"{stderr.decode().strip()}"
)
return stdout.decode().strip()
@solver(name="hf_agent")
def hf_agent(
    config_path: str = "agent/config_mcp_example.json",
    max_iterations: int = 10,
) -> Solver:
    """Solver that delegates generation to the Hugging Face agent runner.

    Args:
        config_path: Path to the agent's MCP configuration file.
        max_iterations: Upper bound on the agent's reasoning iterations.
    """
    # Wire up Laminar observability before any litellm traffic occurs.
    Laminar.initialize(project_api_key=os.environ.get("LMNR_API_KEY"))
    litellm.callbacks = [LaminarLiteLLMCallback()]
    print("✅ Laminar initialized")

    agent = AgentResponseGenerator(
        config_path=config_path,
        max_iterations=max_iterations,
    )

    async def solve(state: TaskState, generate) -> TaskState:
        reply = await agent.run(state.input_text)
        message = ChatMessageAssistant(
            content=reply,
            model=agent.model_name,
            source="generate",
        )
        state.messages.append(message)
        state.output = ModelOutput.from_message(message)
        state.completed = True
        return state

    return solve
@solver(name="claude_code")
def claude_code(
    output_format: str = "json",
    mcp_config: str | None = None,
) -> Solver:
    """Solver that shells out to the Claude Code CLI (`claude -p ...`).

    Args:
        output_format: CLI output format; one of "text", "json", "stream-json".
        mcp_config: Optional path to an MCP config file passed via
            ``--mcp-config``.

    Raises:
        ValueError: if ``output_format`` is not a supported value.
    """
    if output_format not in {"text", "json", "stream-json"}:
        raise ValueError("output_format must be one of: text, json, stream-json")

    async def solve(state: TaskState, generate) -> TaskState:
        prompt = state.input_text
        cmd: List[str] = ["claude", "-p", prompt, "--output-format", output_format]
        if mcp_config:
            cmd += ["--mcp-config", mcp_config]
        stdout = await _run_subprocess(cmd)
        response_text = stdout
        session_id = None
        if output_format in {"json", "stream-json"}:
            # stream-json may emit multiple JSON objects; take the last complete line
            lines = stdout.strip().splitlines()
            # Guard: empty/whitespace-only stdout would make [-1] raise an
            # IndexError the except clause below does not catch.
            if lines:
                try:
                    payload = json.loads(lines[-1])
                    response_text = (
                        payload.get("result") or payload.get("message", "") or stdout
                    )
                    session_id = payload.get("session_id")
                except (json.JSONDecodeError, AttributeError):
                    # Non-JSON or non-dict payload: fall back to raw output.
                    response_text = stdout
        assistant_message = ChatMessageAssistant(
            content=response_text,
            model="claude-code",
            source="generate",
            metadata={"session_id": session_id} if session_id else None,
        )
        state.messages.append(assistant_message)
        state.output = ModelOutput.from_message(assistant_message)
        state.completed = True
        return state

    return solve
@solver(name="claude_code+hf_mcp")
def claude_code_hf_mcp(
    output_format: str = "json",
    hf_token: str | None = None,
) -> Solver:
    """
    A solver that uses Claude Code with the Hugging Face MCP server.
    Requires HF_TOKEN in environment variables or passed as argument.
    """
    token = hf_token if hf_token else os.environ.get("HF_TOKEN")
    if not token:
        raise ValueError(
            "HF_TOKEN not found. Please set HF_TOKEN env var or pass it to the solver."
        )

    # MCP server entry pointing Claude Code at the HF endpoint, authenticated
    # with the bearer token.
    hf_server = {
        "type": "http",
        "url": "https://huggingface.co/mcp",
        "headers": {"Authorization": f"Bearer {token}"},
    }
    mcp_config = {"mcpServers": {"huggingface": hf_server}}

    async def solve(state: TaskState, generate) -> TaskState:
        # Persist the config to disk so the CLI can read it, then hand off
        # to the base claude_code solver.
        with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as tmp:
            json.dump(mcp_config, tmp, indent=2)
            tmp_path = tmp.name
        try:
            inner = claude_code(output_format=output_format, mcp_config=tmp_path)
            return await inner(state, generate)
        finally:
            # Always remove the token-bearing temp file, even on failure.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    return solve
# Maps public solver names (as used on the CLI / in task configs) to their
# factory functions; register new solver implementations here.
SOLVER_REGISTRY: Dict[str, Callable[..., Solver]] = {
    "hf_agent": hf_agent,
    "claude_code": claude_code,
    "claude_code+hf_mcp": claude_code_hf_mcp,
}
def get_solver(name: str, **kwargs) -> Solver:
    """Instantiate the solver registered under *name*, forwarding **kwargs.

    Raises:
        ValueError: if *name* is not present in ``SOLVER_REGISTRY`` (the
            message lists all registered names; chained from the KeyError).
    """
    try:
        build = SOLVER_REGISTRY[name]
    except KeyError as exc:
        options = ", ".join(sorted(SOLVER_REGISTRY))
        raise ValueError(f"Unknown solver '{name}'. Available: {options}") from exc
    return build(**kwargs)
|