File size: 5,388 Bytes
035d186
 
 
 
 
 
 
 
00d49da
 
035d186
 
9de209d
035d186
 
 
9de209d
035d186
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e21458
 
035d186
 
 
9de209d
 
 
 
 
035d186
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
00d49da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
035d186
7e21458
035d186
00d49da
035d186
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
"""
Collection of Inspect AI solvers used by the rubric task.
"""

from __future__ import annotations

import asyncio
import json
import os
import tempfile
from typing import Callable, Dict, List, Sequence

import litellm
from inspect_ai.model import ChatMessageAssistant, ModelOutput
from inspect_ai.solver import Solver, solver
from inspect_ai.solver._task_state import TaskState
from lmnr import Laminar, LaminarLiteLLMCallback

from eval.hf_agent_connector import AgentResponseGenerator


async def _run_subprocess(command: Sequence[str]) -> str:
    process = await asyncio.create_subprocess_exec(
        *command,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    stdout, stderr = await process.communicate()
    if process.returncode != 0:
        raise RuntimeError(
            f"Command {' '.join(command)} failed with code {process.returncode}:\n"
            f"{stderr.decode().strip()}"
        )
    return stdout.decode().strip()


@solver(name="hf_agent")
def hf_agent(
    config_path: str = "agent/config_mcp_example.json",
    max_iterations: int = 10,
) -> Solver:
    """Solver backed by the Hugging Face agent connector.

    Args:
        config_path: Path to the agent's MCP configuration file.
        max_iterations: Maximum number of agent iterations per sample.
    """
    # Wire up Laminar tracing for LiteLLM calls (observability only;
    # done once, at solver construction time).
    Laminar.initialize(project_api_key=os.environ.get("LMNR_API_KEY"))
    litellm.callbacks = [LaminarLiteLLMCallback()]
    print("✅ Laminar initialized")

    agent = AgentResponseGenerator(
        config_path=config_path,
        max_iterations=max_iterations,
    )

    async def solve(state: TaskState, generate) -> TaskState:
        reply = await agent.run(state.input_text)
        message = ChatMessageAssistant(
            content=reply,
            model=agent.model_name,
            source="generate",
        )
        state.messages.append(message)
        state.output = ModelOutput.from_message(message)
        state.completed = True
        return state

    return solve


@solver(name="claude_code")
def claude_code(
    output_format: str = "json",
    mcp_config: str | None = None,
) -> Solver:
    """Solver that shells out to the `claude` CLI for each sample.

    Args:
        output_format: CLI output format; one of "text", "json", "stream-json".
        mcp_config: Optional path to an MCP configuration file forwarded to
            the CLI via --mcp-config.

    Raises:
        ValueError: if *output_format* is not a supported value.
    """
    if output_format not in {"text", "json", "stream-json"}:
        raise ValueError("output_format must be one of: text, json, stream-json")

    async def solve(state: TaskState, generate) -> TaskState:
        prompt = state.input_text

        cmd: List[str] = ["claude", "-p", prompt, "--output-format", output_format]
        if mcp_config:
            cmd += ["--mcp-config", mcp_config]

        stdout = await _run_subprocess(cmd)
        response_text = stdout
        session_id = None

        if output_format in {"json", "stream-json"}:
            # stream-json may emit multiple JSON objects; take the last complete
            # line. Guard against empty CLI output, which previously raised an
            # uncaught IndexError on splitlines()[-1].
            lines = stdout.strip().splitlines()
            if lines:
                try:
                    payload = json.loads(lines[-1])
                    response_text = (
                        payload.get("result") or payload.get("message", "") or stdout
                    )
                    session_id = payload.get("session_id")
                except (json.JSONDecodeError, AttributeError):
                    # Last line was not a JSON object; fall back to raw output.
                    response_text = stdout

        assistant_message = ChatMessageAssistant(
            content=response_text,
            model="claude-code",
            source="generate",
            metadata={"session_id": session_id} if session_id else None,
        )
        state.messages.append(assistant_message)
        state.output = ModelOutput.from_message(assistant_message)
        state.completed = True
        return state

    return solve


@solver(name="claude_code+hf_mcp")
def claude_code_hf_mcp(
    output_format: str = "json",
    hf_token: str | None = None,
) -> Solver:
    """Claude Code solver preconfigured with the Hugging Face MCP server.

    The token is taken from *hf_token* when given, otherwise from the
    HF_TOKEN environment variable.

    Raises:
        ValueError: if no token is available from either source.
    """
    token = hf_token if hf_token else os.environ.get("HF_TOKEN")
    if not token:
        raise ValueError(
            "HF_TOKEN not found. Please set HF_TOKEN env var or pass it to the solver."
        )

    # MCP server entry pointing at the authenticated Hugging Face endpoint.
    hf_server = {
        "type": "http",
        "url": "https://huggingface.co/mcp",
        "headers": {"Authorization": f"Bearer {token}"},
    }
    mcp_config = {"mcpServers": {"huggingface": hf_server}}

    async def solve(state: TaskState, generate) -> TaskState:
        # The CLI expects the MCP config on disk, so persist it temporarily.
        with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as tmp:
            json.dump(mcp_config, tmp, indent=2)
            tmp_path = tmp.name

        try:
            # Reuse the plain claude_code solver with the generated config file.
            inner = claude_code(output_format=output_format, mcp_config=tmp_path)
            return await inner(state, generate)
        finally:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    return solve


# Maps public solver names (as referenced by task configs) to their factories.
SOLVER_REGISTRY: Dict[str, Callable[..., Solver]] = {
    "hf_agent": hf_agent,
    "claude_code": claude_code,
    "claude_code+hf_mcp": claude_code_hf_mcp,
}


def get_solver(name: str, **kwargs) -> Solver:
    """Instantiate the solver registered under *name*, forwarding **kwargs.

    Raises:
        ValueError: if *name* is not in SOLVER_REGISTRY; the message lists
            all registered solver names.
    """
    try:
        make_solver = SOLVER_REGISTRY[name]
    except KeyError as err:
        available = ", ".join(sorted(SOLVER_REGISTRY))
        raise ValueError(f"Unknown solver '{name}'. Available: {available}") from err
    return make_solver(**kwargs)