# ml-agent / eval / claude_batch_solve.py
# Author: akseljoonas (HF Staff)
# Initial commit 8cfacd3: ML Agent with Xet storage for binaries
import asyncio
import json
import os
import threading
from pathlib import Path
from typing import Any
from claude_agent_sdk import (
AssistantMessage,
ClaudeAgentOptions,
ResultMessage,
SystemMessage,
TextBlock,
ToolResultBlock,
ToolUseBlock,
UserMessage,
query,
)
from dotenv import load_dotenv
# Load variables from a local .env file into os.environ (python-dotenv);
# presumably this is where HF_TOKEN, read in solve_task, comes from — confirm.
load_dotenv()
# Thread-safe file writing: write_result holds this lock around each append
# to the results file.
file_lock = threading.Lock()
def convert_message_to_chat_format(message: Any) -> dict | None:
    """Convert an SDK message to a chat-format dict (role/content/tool_calls).

    - ``SystemMessage`` with subtype ``"init"`` becomes a system prompt listing
      the tools from ``message.data["tools"]``; other subtypes are dropped.
    - ``AssistantMessage`` text blocks are concatenated into ``content``;
      each ``ToolUseBlock`` becomes a ``tool_calls`` entry
      ``{"id": ..., "function": {"name": ..., "arguments": ...}}`` where
      ``arguments`` is ``block.input`` as-is (not JSON-encoded).
    - ``UserMessage`` string content is passed through. List content is
      rendered as a ``<tool_response>`` JSON payload when it contains tool
      results (followed by any text blocks), otherwise as its concatenated text.

    Args:
        message: A message yielded by ``claude_agent_sdk.query``.

    Returns:
        The chat-format dict, or ``None`` when the message carries no
        conversational content (non-init system messages, ``ResultMessage``,
        ``UserMessage`` with non-str/non-list content, or unknown types).
    """
    if isinstance(message, SystemMessage):
        if message.subtype != "init":
            return None
        tools = message.data.get("tools", [])
        tools_desc = "\n".join(f"- {tool}" for tool in tools)
        return {
            "role": "system",
            "content": f"You are a helpful assistant with access to the following tools:\n{tools_desc}",
        }

    if isinstance(message, AssistantMessage):
        text_parts = []
        tool_calls = []
        for block in message.content:
            if isinstance(block, TextBlock):
                text_parts.append(block.text)
            elif isinstance(block, ToolUseBlock):
                tool_calls.append(
                    {
                        "id": block.id,
                        "function": {
                            "name": block.name,
                            "arguments": block.input,
                        },
                    }
                )
        result = {"role": "assistant", "content": "".join(text_parts)}
        if tool_calls:
            result["tool_calls"] = tool_calls
        return result

    if isinstance(message, UserMessage):
        if isinstance(message.content, str):
            return {"role": "user", "content": message.content}
        if not isinstance(message.content, list):
            return None

        tool_results = []
        text_parts = []
        for block in message.content:
            if isinstance(block, ToolResultBlock):
                # Normalize the tool output to a string: lists are JSON-encoded,
                # other truthy values str()-ed, falsy values (e.g. None) -> "".
                if isinstance(block.content, str):
                    content = block.content
                elif isinstance(block.content, list):
                    content = json.dumps(block.content)
                else:
                    content = str(block.content) if block.content else ""
                tool_results.append(
                    {
                        "tool_use_id": block.tool_use_id,
                        "content": content,
                        "is_error": block.is_error,
                    }
                )
            elif isinstance(block, TextBlock):
                text_parts.append(block.text)

        text_content = "".join(text_parts)
        if not tool_results:
            return {"role": "user", "content": text_content}

        payload = f"<tool_response>\n{json.dumps(tool_results, indent=2)}\n</tool_response>"
        # Text sent alongside tool results was previously discarded; keep it
        # after the payload so the transcript is complete.
        if text_content:
            payload += f"\n{text_content}"
        return {"role": "user", "content": payload}

    # ResultMessage carries run metadata (is_error, subtype, result), not a
    # conversation turn; it and any unrecognized type produce no chat entry.
    return None
def _task_record(
    question: str,
    difficulty: str,
    solution: str | None,
    messages: list,
    error: str | None,
) -> dict:
    """Build the per-task output record that solve_task returns."""
    return {
        "question": question,
        "difficulty": difficulty,
        "solution": solution,
        "messages": messages,
        "error": error,
    }


async def solve_task(
    question: str,
    difficulty: str,
    task_idx: int,
    total: int,
    semaphore: asyncio.Semaphore,
) -> dict:
    """Solve a single task using Claude Agent SDK.

    Runs ``question`` through ``query`` with the Hugging Face MCP server
    (authenticated via the ``HF_TOKEN`` environment variable) and with the
    Write, Edit, Bash, Glob and Grep tools disallowed. At most as many tasks
    as ``semaphore`` permits run concurrently.

    Args:
        question: Prompt sent to the agent.
        difficulty: Copied verbatim into the returned record.
        task_idx: 1-based position of this task, used only in progress output.
        total: Total number of tasks, used only in progress output.
        semaphore: Limits how many agents run at the same time.

    Returns:
        A dict with keys ``question``, ``difficulty``, ``solution``,
        ``messages`` (chat-format transcript) and ``error``. ``solution`` is
        the final ``ResultMessage.result`` when non-empty, else the last
        assistant text block; it is ``None`` when ``error`` is set. Any
        ``Exception`` raised while querying is caught and reported in ``error``.
    """
    async with semaphore:
        print(f"[{task_idx}/{total}] Starting: {question[:60]}...")
        messages = []
        solution = None
        agent_error = None
        try:
            options = ClaudeAgentOptions(
                cwd=os.getcwd(),
                permission_mode="bypassPermissions",
                disallowed_tools=["Write", "Edit", "Bash", "Glob", "Grep"],
                mcp_servers={
                    "huggingface": {
                        "type": "http",
                        "url": "https://huggingface.co/mcp",
                        "headers": {
                            "Authorization": f"Bearer {os.environ['HF_TOKEN']}"
                        },
                    }
                },
            )
            # Consume the stream to the end instead of returning from inside
            # the loop: leaving `async for` early abandons the async generator
            # and defers its cleanup to a finalizer running in another task.
            async for message in query(prompt=question, options=options):
                chat_msg = convert_message_to_chat_format(message)
                if chat_msg:
                    messages.append(chat_msg)
                if isinstance(message, AssistantMessage):
                    # Remember the most recent assistant text as a fallback answer.
                    for block in message.content:
                        if isinstance(block, TextBlock):
                            solution = block.text
                elif isinstance(message, ResultMessage):
                    if message.is_error:
                        agent_error = f"Agent error: {message.subtype}"
                    elif message.result:
                        solution = message.result
        except Exception as e:
            if agent_error is None:
                print(f"[{task_idx}/{total}] ✗ Error: {e}")
                return _task_record(question, difficulty, None, messages, str(e))
            # The agent already reported its own failure; an exception raised
            # while the stream winds down is secondary, so keep the agent error.

        if agent_error is not None:
            print(f"[{task_idx}/{total}] ✗ {agent_error}")
            return _task_record(question, difficulty, None, messages, agent_error)

        print(f"[{task_idx}/{total}] ✓ Done: {question[:60]}...")
        return _task_record(question, difficulty, solution, messages, None)
def write_result(output_path: Path, result: dict):
    """Append ``result`` to ``output_path`` as one JSON line.

    The module-level ``file_lock`` is held for the whole open/write/close so
    concurrent callers never interleave partial lines.
    """
    with file_lock, open(output_path, "a") as sink:
        sink.write(json.dumps(result) + "\n")
async def main(max_concurrent: int = 5):
    """Solve every task in filled_tasks.jsonl and save results to solved_tasks.jsonl.

    Both files live next to this script. The output file is truncated first,
    then each task's record is appended as soon as that task finishes, so
    records appear in completion order rather than input order.

    Args:
        max_concurrent: Maximum number of agents running at once.
    """
    base_dir = Path(__file__).parent

    # Each non-blank line of the input is one JSON task object; blank lines
    # (e.g. a trailing newline) are skipped rather than crashing json.loads.
    tasks_path = base_dir / "filled_tasks.jsonl"
    with open(tasks_path, encoding="utf-8") as f:
        tasks = [json.loads(line) for line in f if line.strip()]

    # Output file - clear it first
    output_path = base_dir / "solved_tasks.jsonl"
    output_path.write_text("")

    semaphore = asyncio.Semaphore(max_concurrent)
    total = len(tasks)
    print(f"Processing {total} tasks with {max_concurrent} concurrent agents...")

    async def process_and_save(task: dict, idx: int):
        result = await solve_task(
            task["question"], task["difficulty"], idx, total, semaphore
        )
        write_result(output_path, result)
        return result

    # All coroutines are scheduled at once; the semaphore inside solve_task
    # limits how many actually run in parallel.
    results = await asyncio.gather(
        *(process_and_save(task, idx) for idx, task in enumerate(tasks, start=1)),
        return_exceptions=True,
    )

    # Failures outside solve_task's own handler (e.g. a task missing
    # "question"/"difficulty", or a failed write) come back as exception
    # objects; report them instead of counting them silently.
    for idx, outcome in enumerate(results, start=1):
        if isinstance(outcome, BaseException):
            print(f"[{idx}/{total}] ✗ Unhandled error: {outcome!r}")

    successful = sum(
        1 for r in results if isinstance(r, dict) and r.get("error") is None
    )
    print(f"\nCompleted: {successful}/{total} successful")
    print(f"Results saved to {output_path}")


if __name__ == "__main__":
    asyncio.run(main())