| | import asyncio |
| | import json |
| | import os |
| | import threading |
| | from pathlib import Path |
| | from typing import Any |
| |
|
| | from claude_agent_sdk import ( |
| | AssistantMessage, |
| | ClaudeAgentOptions, |
| | ResultMessage, |
| | SystemMessage, |
| | TextBlock, |
| | ToolResultBlock, |
| | ToolUseBlock, |
| | UserMessage, |
| | query, |
| | ) |
| | from dotenv import load_dotenv |
| |
|
# Module-wide lock serialising appends to the output JSONL file (see write_result).
file_lock = threading.Lock()

# Load variables from a .env file into os.environ; solve_task reads HF_TOKEN from there.
load_dotenv()
| |
|
| |
|
def _tool_result_text(content: Any) -> str:
    """Render a ToolResultBlock's content as a string.

    Strings pass through unchanged, lists are JSON-encoded, falsy values
    (e.g. None) become "" and anything else is converted with str().
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        return json.dumps(content)
    return str(content) if content else ""


def convert_message_to_chat_format(message: Any) -> dict | None:
    """Convert an SDK message to a chat-format dict with role/content[/tool_calls].

    - SystemMessage with subtype "init": a system prompt listing the tool names
      found in ``message.data["tools"]``. Other system subtypes yield None.
    - AssistantMessage: all TextBlock text concatenated as "content". If the message
      has any ToolUseBlocks, a "tool_calls" list is added. Each call's "arguments"
      is the SDK's ``block.input`` object as-is, not a JSON string.
    - UserMessage with str content: passed through as a user message.
    - UserMessage with list content: ToolResultBlocks are rendered as a JSON array
      wrapped in <tool_response> tags. Any TextBlock text follows on a new line.
      Without tool results, only the concatenated text is returned.
    - ResultMessage, unrecognised user content, and any other object: None.
    """
    if isinstance(message, SystemMessage):
        if message.subtype != "init":
            return None
        tools = message.data.get("tools", [])
        tools_desc = "\n".join(f"- {tool}" for tool in tools)
        return {
            "role": "system",
            "content": f"You are a helpful assistant with access to the following tools:\n{tools_desc}",
        }

    if isinstance(message, AssistantMessage):
        text_parts = []
        tool_calls = []
        for block in message.content:
            if isinstance(block, TextBlock):
                text_parts.append(block.text)
            elif isinstance(block, ToolUseBlock):
                tool_calls.append(
                    {
                        "id": block.id,
                        "function": {
                            "name": block.name,
                            "arguments": block.input,
                        },
                    }
                )
        result = {"role": "assistant", "content": "".join(text_parts)}
        if tool_calls:
            result["tool_calls"] = tool_calls
        return result

    if isinstance(message, UserMessage):
        content = message.content
        if isinstance(content, str):
            return {"role": "user", "content": content}
        if not isinstance(content, list):
            return None

        tool_results = []
        text_parts = []
        for block in content:
            if isinstance(block, ToolResultBlock):
                tool_results.append(
                    {
                        "tool_use_id": block.tool_use_id,
                        "content": _tool_result_text(block.content),
                        "is_error": block.is_error,
                    }
                )
            elif isinstance(block, TextBlock):
                text_parts.append(block.text)
        text_content = "".join(text_parts)

        if not tool_results:
            return {"role": "user", "content": text_content}

        rendered = f"<tool_response>\n{json.dumps(tool_results, indent=2)}\n</tool_response>"
        if text_content:
            # Keep text that accompanies tool results instead of discarding it.
            rendered += f"\n{text_content}"
        return {"role": "user", "content": rendered}

    # ResultMessage carries run metadata, not conversation content; unknown types are skipped.
    return None
| |
|
| |
|
def _build_agent_options() -> "ClaudeAgentOptions":
    """Build the options used for every task's agent run.

    The agent runs in the current working directory with permission mode
    "bypassPermissions". The Write, Edit, Bash, Glob and Grep tools are
    disallowed, and the Hugging Face MCP server is attached over HTTP.

    Raises:
        KeyError: if HF_TOKEN is not set in the environment.
    """
    return ClaudeAgentOptions(
        cwd=os.getcwd(),
        permission_mode="bypassPermissions",
        disallowed_tools=["Write", "Edit", "Bash", "Glob", "Grep"],
        mcp_servers={
            "huggingface": {
                "type": "http",
                "url": "https://huggingface.co/mcp",
                "headers": {"Authorization": f"Bearer {os.environ['HF_TOKEN']}"},
            }
        },
    )


def _task_result(
    question: str,
    difficulty: str,
    solution: str | None,
    messages: list,
    error: str | None,
) -> dict:
    """Assemble the per-task result record written to the output file."""
    return {
        "question": question,
        "difficulty": difficulty,
        "solution": solution,
        "messages": messages,
        "error": error,
    }


async def solve_task(
    question: str,
    difficulty: str,
    task_idx: int,
    total: int,
    semaphore: asyncio.Semaphore,
) -> dict:
    """Solve a single task using Claude Agent SDK, bounded by *semaphore*.

    Every streamed message that converts to chat format is recorded. The
    solution is the text of the last TextBlock seen in an AssistantMessage.
    It is replaced by ``ResultMessage.result`` when that is non-empty.

    Returns:
        A dict with question, difficulty, solution, messages and error. On
        success, error is None. If the agent reports an error, or an Exception
        is raised, error describes it and solution is None. Exception
        subclasses are not propagated.
    """
    async with semaphore:
        print(f"[{task_idx}/{total}] Starting: {question[:60]}...")

        messages = []
        solution = None
        agent_error = None

        try:
            stream = query(prompt=question, options=_build_agent_options())
            try:
                async for message in stream:
                    chat_msg = convert_message_to_chat_format(message)
                    if chat_msg:
                        messages.append(chat_msg)

                    if isinstance(message, AssistantMessage):
                        for block in message.content:
                            if isinstance(block, TextBlock):
                                solution = block.text
                    elif isinstance(message, ResultMessage):
                        if message.is_error:
                            agent_error = f"Agent error: {message.subtype}"
                            break
                        if message.result:
                            solution = message.result
            finally:
                # Leaving the loop early (break or exception) would leave the stream
                # suspended until it is garbage-collected and finalized elsewhere;
                # close it here, in this task, so its cleanup runs deterministically.
                aclose = getattr(stream, "aclose", None)
                if aclose is not None:
                    await aclose()
        except Exception as e:
            # Include the type: str(KeyError("HF_TOKEN")) alone is just "'HF_TOKEN'".
            error = f"{type(e).__name__}: {e}"
            print(f"[{task_idx}/{total}] ✗ Error: {error}")
            return _task_result(question, difficulty, None, messages, error)

        if agent_error is not None:
            print(f"[{task_idx}/{total}] ✗ {agent_error}")
            return _task_result(question, difficulty, None, messages, agent_error)

        print(f"[{task_idx}/{total}] ✓ Done: {question[:60]}...")
        return _task_result(question, difficulty, solution, messages, None)
| |
|
| |
|
def write_result(output_path: Path, result: dict):
    """Append *result* to *output_path* as one JSON line.

    file_lock is held while the file is open, so concurrent callers of this
    function never interleave their lines.
    """
    with file_lock, open(output_path, "a") as handle:
        handle.write(f"{json.dumps(result)}\n")
| |
|
| |
|
async def main(max_concurrent: int = 5):
    """Solve every task in filled_tasks.jsonl and stream results to solved_tasks.jsonl.

    Both files live next to this script. The output file is truncated first.
    Each result is appended as soon as its task finishes, so line order follows
    completion order, not input order.

    Args:
        max_concurrent: maximum number of agents running at once.

    Raises:
        ValueError: if max_concurrent is less than 1.
    """
    if max_concurrent < 1:
        raise ValueError(f"max_concurrent must be >= 1, got {max_concurrent}")

    base_dir = Path(__file__).parent

    tasks_path = base_dir / "filled_tasks.jsonl"
    # Explicit UTF-8 so non-ASCII questions decode identically on every platform;
    # blank lines (e.g. a trailing empty line) are skipped instead of crashing json.loads.
    with open(tasks_path, encoding="utf-8") as f:
        tasks = [json.loads(line) for line in f if line.strip()]

    output_path = base_dir / "solved_tasks.jsonl"
    output_path.write_text("")

    semaphore = asyncio.Semaphore(max_concurrent)

    total = len(tasks)
    print(f"Processing {total} tasks with {max_concurrent} concurrent agents...")

    async def process_and_save(task: dict, idx: int) -> dict:
        result = await solve_task(
            task["question"], task["difficulty"], idx, total, semaphore
        )
        write_result(output_path, result)
        return result

    results = await asyncio.gather(
        *(process_and_save(task, idx) for idx, task in enumerate(tasks, start=1)),
        return_exceptions=True,
    )

    # gather(return_exceptions=True) returns exceptions in place of results (same order
    # as the inputs); report them so failures such as a missing task key are visible.
    for idx, outcome in enumerate(results, start=1):
        if isinstance(outcome, BaseException):
            print(f"[{idx}/{total}] ✗ Unhandled error: {type(outcome).__name__}: {outcome}")

    successful = sum(
        1 for r in results if isinstance(r, dict) and r.get("error") is None
    )
    print(f"\nCompleted: {successful}/{total} successful")
    print(f"Results saved to {output_path}")


if __name__ == "__main__":
    asyncio.run(main())
| |
|