# Paper2Agent / test.py
# Uploaded by jcmiao ("Upload 18 files", commit 51f80a0, 5.66 kB).
import anyio
import argparse
import shutil
import subprocess
from claude_agent_sdk import (
ClaudeSDKClient,
ClaudeAgentOptions,
AgentDefinition,
AssistantMessage,
TextBlock,
ToolUseBlock,
ToolResultBlock
)
import os
import sys
from pathlib import Path
from utils import copy_project_resources, clone_github_repo, prepare_folder_structure
from prompts.tasks import step1_environment_setup_and_tutorial_discovery, step2_tutorial_execution, step3_tool_extraction_and_testing, step4_mcp_integration
# ANTHROPIC_API_KEY should be set via environment variable by the caller (app.py)
def _task_log_path(log_file_path, task_index):
    """Return the detailed-log path for task *task_index*.

    The file name of *log_file_path* is prefixed with ``Task_<n>_`` while its
    directory is kept, e.g. ``log.log`` -> ``Task_1_log.log`` and
    ``logs/run.log`` -> ``logs/Task_1_run.log``.
    """
    base = Path(log_file_path)
    return str(base.with_name(f"Task_{task_index}_{base.name}"))


def _write_message_log(log_file, message):
    """Append a detailed, human-readable record of one SDK message to *log_file*.

    Logs assistant text, every key/value of a tool call's input dict, and the
    content of tool results. Messages whose ``content`` is not a list of
    blocks are skipped.
    """
    content = getattr(message, "content", None)
    if not isinstance(content, list):
        return
    for block in content:
        if isinstance(block, TextBlock):
            # Only the assistant's own text is logged as "Claude" output.
            if isinstance(message, AssistantMessage):
                log_file.write(f"💭 Claude: {block.text}\n")
        elif isinstance(block, ToolUseBlock):
            tool_input = getattr(block, "input", None)
            if isinstance(tool_input, dict):
                for key, value in tool_input.items():
                    log_file.write(f"[ToolUseBlock] {key}: {str(value)}\n")
        elif isinstance(block, ToolResultBlock):
            # Tool results are content blocks inside a message, not messages
            # themselves, so they must be found while walking `content`.
            log_file.write(f" ✅ Result: {str(block.content)}\n")


def _print_progress(message):
    """Echo short (< 150 chars after stripping) assistant text replies to stdout for the UI."""
    if not isinstance(message, AssistantMessage):
        return
    for block in message.content:
        if isinstance(block, TextBlock):
            text = block.text.strip()
            if len(text) < 150:
                print(f"💭 {text}")


async def fully_automatic(tasks: list, task_descriptions: list = None, log_file_path: str = None):
    """Run each prompt in *tasks*, in order, through one Claude agent session.

    A single ``ClaudeSDKClient`` is opened (project settings, the current
    working directory, ``acceptEdits`` permission mode) and every task prompt is
    sent to it sequentially. Short assistant replies are printed to stdout as
    progress; full text, tool calls and tool results go to a per-task log file.

    Args:
        tasks: Prompt strings, one per pipeline step.
        task_descriptions: Optional labels; ``task_descriptions[i-1]`` is shown
            in the banner for task ``i``. Missing entries fall back to
            ``"Task i"``.
        log_file_path: Optional base log path. Detailed logs for task ``i`` are
            appended to ``Task_{i}_<file name>`` in the same directory; task
            failures are appended to ``log_file_path`` itself. When falsy,
            nothing is written to disk.

    A task that raises is reported and logged, and the remaining tasks still run.
    """
    options = ClaudeAgentOptions(
        allowed_tools=["Bash", "Edit", "Glob", "Grep", "NotebookEdit", "NotebookRead", "Read", "SlashCommand", "Task", "TodoWrite", "WebFetch", "WebSearch", "Write"],
        permission_mode='acceptEdits',
        # Resolved at call time: the agent works in the caller's current directory.
        cwd=str(Path.cwd()),
        setting_sources=["project"],
    )
    async with ClaudeSDKClient(options=options) as client:
        for i, task in enumerate(tasks, 1):
            label = (
                task_descriptions[i - 1]
                if task_descriptions and i <= len(task_descriptions)
                else f"Task {i}"
            )
            # Simple banner for the UI
            print(f"\n{'=' * 70}")
            print(f"🚀 Starting {label}")
            print("=" * 70 + "\n")
            task_log_path = _task_log_path(log_file_path, i) if log_file_path else None
            try:
                await client.query(task)
                async for message in client.receive_response():
                    if task_log_path:
                        # Re-opened per message so each entry reaches disk as soon
                        # as it arrives (presumably so the log can be followed live).
                        with open(task_log_path, 'a', encoding='utf-8') as log_file:
                            _write_message_log(log_file, message)
                    _print_progress(message)
                print(f"\n✅ Task {i} Completed\n")
            except Exception as e:
                # Per-task boundary: report the failure and move on to the next task.
                print(f"❌ Task {i} Failed: {e}\n")
                if log_file_path:
                    with open(log_file_path, 'a', encoding='utf-8') as log_file:
                        log_file.write(f"❌ Task {i} Failed: {e}\n")
def _repo_name_from_url(url):
    """Return the repository name from a clone URL.

    Trailing slashes and a ``.git`` suffix are removed, so
    ``https://github.com/u/repo/`` and ``https://github.com/u/repo.git`` both
    give ``"repo"``. An empty URL gives ``""``.
    """
    # Strip trailing "/" first: os.path.basename("…/repo/") would return "".
    name = os.path.basename(url.rstrip("/"))
    if name.endswith(".git"):
        name = name[:-len(".git")]
    return name


def main():
    """Command-line entry point: prepare ``Results/``, clone the repo and run the pipeline.

    Flags:
        --github_url  Repository URL to clone (default "").
        --tutorials   Tutorial filter passed to the discovery step (default "").
        --api         API key passed to the tutorial-execution and tool-extraction
                      prompts (default "").
    """
    parser = argparse.ArgumentParser(description="Script for running tasks with configurable options.")
    parser.add_argument('--github_url', dest='github_repo_url', default="", help='GitHub repository URL')
    parser.add_argument('--tutorials', dest='tutorial_filter', default="", help='Tutorial filter')
    # NOTE(review): a key passed on the command line is visible in process
    # listings and shell history; consider reading it from the environment.
    parser.add_argument('--api', dest='api_key', default="", help='API key')
    args = parser.parse_args()

    github_repo_url = args.github_repo_url
    folder_name = "Results"
    tutorial_filter = args.tutorial_filter
    api_key = args.api_key

    repo_name = _repo_name_from_url(github_repo_url)

    os.makedirs(folder_name, exist_ok=True)
    # step 1: copy .claude, templates, tools to the project directory
    copy_project_resources(folder_name)
    # step 2: prepare the folder structure
    prepare_folder_structure(folder_name)
    # Everything after this point (the clone, the log files and the agent's
    # Path.cwd()-based working directory) is relative to the results folder.
    os.chdir(folder_name)
    # step 3: clone the github repository
    clone_github_repo(github_repo_url, repo_name)

    # Each description is kept next to its prompt so the two lists cannot drift apart.
    pipeline = [
        ("Task 1: Environment Setup and Tutorial Discovery",
         step1_environment_setup_and_tutorial_discovery(repo_name, tutorial_filter)),
        ("Task 2: Tutorial Execution",
         step2_tutorial_execution(repo_name, api_key)),
        ("Task 3: Tool Extraction and Testing",
         step3_tool_extraction_and_testing(repo_name, api_key)),
        ("Task 4: MCP Integration",
         step4_mcp_integration(repo_name)),
    ]
    task_descriptions = [desc for desc, _ in pipeline]
    tasks = [prompt for _, prompt in pipeline]

    print("\n" + "=" * 70)
    print("📋 Pipeline Tasks:")
    for i, desc in enumerate(task_descriptions, 1):
        print(f" {i}. {desc}")
    print("=" * 70 + "\n")

    # Relative to the results folder because of the chdir above.
    log_file_path = "log.log"
    anyio.run(fully_automatic, tasks, task_descriptions, log_file_path)


if __name__ == "__main__":
    main()