| | """ |
| | Collection of Inspect AI solvers used by the rubric task. |
| | """ |
| |
|
| | from __future__ import annotations |
| |
|
| | import asyncio |
| | import json |
| | import os |
| | import tempfile |
| | from typing import Callable, Dict, List, Sequence |
| |
|
| | import litellm |
| | from inspect_ai.model import ChatMessageAssistant, ModelOutput |
| | from inspect_ai.solver import Solver, solver |
| | from inspect_ai.solver._task_state import TaskState |
| | from lmnr import Laminar, LaminarLiteLLMCallback |
| |
|
| | from eval.hf_agent_connector import AgentResponseGenerator |
| |
|
| |
|
| | async def _run_subprocess(command: Sequence[str]) -> str: |
| | process = await asyncio.create_subprocess_exec( |
| | *command, |
| | stdout=asyncio.subprocess.PIPE, |
| | stderr=asyncio.subprocess.PIPE, |
| | ) |
| | stdout, stderr = await process.communicate() |
| | if process.returncode != 0: |
| | raise RuntimeError( |
| | f"Command {' '.join(command)} failed with code {process.returncode}:\n" |
| | f"{stderr.decode().strip()}" |
| | ) |
| | return stdout.decode().strip() |
| |
|
| |
|
@solver(name="hf_agent")
def hf_agent(
    config_path: str = "agent/config_mcp_example.json",
    max_iterations: int = 10,
) -> Solver:
    """Solver that delegates generation to a Hugging Face agent runner.

    Args:
        config_path: Path to the agent's MCP configuration JSON file.
        max_iterations: Maximum number of agent loop iterations per sample.

    Returns:
        An Inspect AI solver that runs the agent on each sample's input text.
    """
    # Wire Laminar tracing into every LiteLLM call the agent makes.
    Laminar.initialize(project_api_key=os.environ.get("LMNR_API_KEY"))
    litellm.callbacks = [LaminarLiteLLMCallback()]
    print("✅ Laminar initialized")

    # One agent instance is shared across all samples handled by this solver.
    agent = AgentResponseGenerator(
        config_path=config_path,
        max_iterations=max_iterations,
    )

    async def solve(state: TaskState, generate) -> TaskState:
        reply = await agent.run(state.input_text)
        message = ChatMessageAssistant(
            content=reply,
            model=agent.model_name,
            source="generate",
        )
        state.messages.append(message)
        state.output = ModelOutput.from_message(message)
        state.completed = True
        return state

    return solve
| |
|
| |
|
@solver(name="claude_code")
def claude_code(
    output_format: str = "json",
    mcp_config: str | None = None,
) -> Solver:
    """Solver that shells out to the `claude` CLI for each sample.

    Args:
        output_format: CLI output format; one of "text", "json", or
            "stream-json".
        mcp_config: Optional path to an MCP config file, forwarded to the CLI
            via `--mcp-config`.

    Returns:
        An Inspect AI solver that sends the sample's input text as the prompt
        and records the CLI's response as the assistant message.

    Raises:
        ValueError: if *output_format* is not a supported value.
    """
    if output_format not in {"text", "json", "stream-json"}:
        raise ValueError("output_format must be one of: text, json, stream-json")

    async def solve(state: TaskState, generate) -> TaskState:
        prompt = state.input_text

        cmd: List[str] = ["claude", "-p", prompt, "--output-format", output_format]
        if mcp_config:
            cmd += ["--mcp-config", mcp_config]

        stdout = await _run_subprocess(cmd)
        response_text = stdout
        session_id = None

        if output_format in {"json", "stream-json"}:
            # For "stream-json" the final line carries the result payload;
            # for "json" the whole output is a single JSON document.
            lines = stdout.strip().splitlines()
            # BUG FIX: empty CLI output used to raise IndexError on
            # splitlines()[-1], which was not caught below. Guard it and
            # fall back to the raw stdout instead.
            if lines:
                try:
                    payload = json.loads(lines[-1])
                    response_text = (
                        payload.get("result") or payload.get("message", "") or stdout
                    )
                    session_id = payload.get("session_id")
                except (json.JSONDecodeError, AttributeError):
                    # Non-JSON last line, or a JSON value that is not a dict
                    # (no .get): keep the raw output as the response.
                    response_text = stdout

        assistant_message = ChatMessageAssistant(
            content=response_text,
            model="claude-code",
            source="generate",
            # Only attach metadata when the CLI actually reported a session.
            metadata={"session_id": session_id} if session_id else None,
        )
        state.messages.append(assistant_message)
        state.output = ModelOutput.from_message(assistant_message)
        state.completed = True
        return state

    return solve
| |
|
| |
|
@solver(name="claude_code+hf_mcp")
def claude_code_hf_mcp(
    output_format: str = "json",
    hf_token: str | None = None,
) -> Solver:
    """
    A solver that uses Claude Code with the Hugging Face MCP server.
    Requires HF_TOKEN in environment variables or passed as argument.
    """
    token = hf_token or os.environ.get("HF_TOKEN")
    if not token:
        raise ValueError(
            "HF_TOKEN not found. Please set HF_TOKEN env var or pass it to the solver."
        )

    # MCP server definition pointing Claude Code at the HF endpoint,
    # authenticated with the resolved token.
    hf_server_config = {
        "mcpServers": {
            "huggingface": {
                "type": "http",
                "url": "https://huggingface.co/mcp",
                "headers": {"Authorization": f"Bearer {token}"},
            }
        }
    }

    async def solve(state: TaskState, generate) -> TaskState:
        # The claude CLI takes MCP config as a file path, so persist the
        # config to a temporary JSON file for the duration of this call.
        with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as handle:
            json.dump(hf_server_config, handle, indent=2)
            config_path = handle.name

        try:
            # Delegate the actual CLI invocation to the claude_code solver.
            inner = claude_code(output_format=output_format, mcp_config=config_path)
            return await inner(state, generate)
        finally:
            # Always remove the temp file so the bearer token does not
            # linger on disk.
            if os.path.exists(config_path):
                os.remove(config_path)

    return solve
| |
|
| |
|
# Maps solver names (as referenced by callers of get_solver) to the factory
# functions that build them; keys must match the @solver(name=...) values.
SOLVER_REGISTRY: Dict[str, Callable[..., Solver]] = {
    "hf_agent": hf_agent,
    "claude_code": claude_code,
    "claude_code+hf_mcp": claude_code_hf_mcp,
}
| |
|
| |
|
def get_solver(name: str, **kwargs) -> Solver:
    """Instantiate the registered solver called *name*.

    Args:
        name: Key into SOLVER_REGISTRY.
        **kwargs: Forwarded verbatim to the solver factory.

    Raises:
        ValueError: if *name* is not registered; the message lists the
            available solver names (chained from the underlying KeyError).
    """
    try:
        make_solver = SOLVER_REGISTRY[name]
    except KeyError as missing:
        options = ", ".join(sorted(SOLVER_REGISTRY))
        raise ValueError(f"Unknown solver '{name}'. Available: {options}") from missing
    return make_solver(**kwargs)
| |
|