Spaces:
Running
Running
| """Agent for finding and editing deadlines of a given conference using the Claude Agent SDK. | |
| Implements a 3-stage pipeline: | |
| 1. Retrieval: N agents independently search the web for conference information | |
| 2. Aggregation: A majority-vote agent synthesizes the retrieval results | |
| 3. Push: An agent writes the updated YAML and pushes directly to main | |
| Usage: | |
| ```bash | |
| uv run --env-file keys.env -m agents.agent --conference_name <name> --num-retrieval-agents 5 | |
| uv run --env-file keys.env -m agents.agent --conference_name <name> --dry-run | |
| ``` | |
| """ | |
| import argparse | |
| import asyncio | |
| import json | |
| from datetime import datetime | |
| import os | |
| from pathlib import Path | |
| import aiofiles | |
| from claude_agent_sdk import ( | |
| AssistantMessage, | |
| ClaudeAgentOptions, | |
| ResultMessage, | |
| TextBlock, | |
| ToolResultBlock, | |
| ToolUseBlock, | |
| UserMessage, | |
| query, | |
| ) | |
| from claude_agent_sdk.types import McpHttpServerConfig | |
# Directory containing this script; prompt templates are resolved relative to it.
SCRIPT_DIR = Path(__file__).parent
# Repository root. Setting USE_CWD_AS_PROJECT_ROOT (to any non-empty value)
# makes the current working directory the root (useful when the script is run
# from a checkout elsewhere, e.g. on Modal); otherwise it is the script's parent.
PROJECT_ROOT = (
    Path(os.getcwd())
    if os.environ.get("USE_CWD_AS_PROJECT_ROOT")
    else SCRIPT_DIR.parent
)
| # --- Structured output schemas --- | |
# JSON schema for the structured output of each Stage-1 retrieval agent.
# All four fields are required so the aggregation stage can compare results.
RETRIEVAL_RESULT_SCHEMA = {
    "type": "object",
    "properties": {
        "requires_update": {
            "type": "boolean",
            "description": "Whether the conference data needs an update",
        },
        "reasoning": {
            "type": "string",
            "description": "Explanation of why the data does or does not need an update",
        },
        "updated_yaml": {
            "type": "string",
            "description": "The full updated YAML content",
        },
        "source_urls": {
            "type": "array",
            "items": {"type": "string"},
            "description": "URLs used as sources for the information",
        },
    },
    "required": ["requires_update", "reasoning", "updated_yaml", "source_urls"],
}
# JSON schema for the Stage-2 aggregation (majority-vote) agent's output.
# Mirrors RETRIEVAL_RESULT_SCHEMA but describes the synthesized consensus.
AGGREGATION_RESULT_SCHEMA = {
    "type": "object",
    "properties": {
        "reasoning": {
            "type": "string",
            "description": (
                "Explanation of how the majority vote was performed, how the results "
                "were compared, where the colleagues agreed/disagreed, and how the "
                "synthesis was derived"
            ),
        },
        "requires_update": {
            "type": "boolean",
            "description": "Whether the conference data needs updating (based on majority agreement)",
        },
        "updated_yaml": {
            "type": "string",
            "description": "The synthesized updated YAML content",
        },
        "source_urls": {
            "type": "array",
            "items": {"type": "string"},
            "description": (
                "Combined source URLs from the retrieval results that support "
                "the synthesized output"
            ),
        },
    },
    "required": ["reasoning", "requires_update", "updated_yaml", "source_urls"],
}
# JSON schema for the Stage-3 push agent's output. Only "pushed" is required;
# "commit_sha" is present only when a commit was actually made.
PUSH_RESULT_SCHEMA = {
    "type": "object",
    "properties": {
        "pushed": {
            "type": "boolean",
            "description": "Whether the update was committed and pushed to main",
        },
        "commit_sha": {
            "type": "string",
            "description": "The SHA of the commit that was pushed (if any)",
        },
    },
    "required": ["pushed"],
}
| # --- Utilities --- | |
def format_date_verbose(dt: datetime) -> str:
    """Spell out a date verbosely, e.g. "Monday, the 3rd of June, 2024"."""
    day = dt.day
    # 11th/12th/13th are irregular; otherwise the suffix follows the last digit.
    if day in (11, 12, 13):
        suffix = "th"
    elif day % 10 == 1:
        suffix = "st"
    elif day % 10 == 2:
        suffix = "nd"
    elif day % 10 == 3:
        suffix = "rd"
    else:
        suffix = "th"
    weekday = dt.strftime("%A")
    month = dt.strftime("%B")
    return f"{weekday}, the {day}{suffix} of {month}, {dt.year}"
async def read_prompt(filename: str) -> str:
    """Return the text of a prompt template located next to this script."""
    async with aiofiles.open(SCRIPT_DIR / filename, "r", encoding="utf-8") as fh:
        return await fh.read()
async def read_app_readme() -> str:
    """Return the contents of the project-root README.md."""
    async with aiofiles.open(
        PROJECT_ROOT / "README.md", "r", encoding="utf-8"
    ) as fh:
        return await fh.read()
async def load_conference_data(conference_name: str) -> str:
    """Return the raw YAML for a conference, or "" if the file is missing."""
    path = PROJECT_ROOT / "src" / "data" / "conferences" / f"{conference_name}.yml"
    if path.exists():
        async with aiofiles.open(path, "r", encoding="utf-8") as fh:
            return await fh.read()
    # A missing file is tolerated: downstream prompts handle "no existing data".
    print(f"Warning: Conference file not found at {path}")
    return ""
def _get_settings_path() -> str:
    """Resolve settings.local.json: project-local first, else the user's home copy."""
    project_settings = PROJECT_ROOT / ".claude" / "settings.local.json"
    if project_settings.exists():
        return str(project_settings)
    return str(Path.home() / ".claude" / "settings.local.json")
def _get_exa_mcp_servers() -> dict[str, McpHttpServerConfig]:
    """Build the Exa MCP server config, or {} when disabled or unconfigured.

    DISABLE_EXA_MCP (any of "1"/"true"/"yes", case-insensitive) switches Exa
    off explicitly; otherwise an EXA_API_KEY in the environment enables it.
    """
    # Explicit opt-out wins over any configured API key.
    if os.environ.get("DISABLE_EXA_MCP", "").lower() in ("1", "true", "yes"):
        print("Exa MCP disabled via DISABLE_EXA_MCP environment variable")
        return {}
    exa_api_key = os.environ.get("EXA_API_KEY", "")
    if not exa_api_key:
        print("EXA_API_KEY not found, Exa MCP will not be available")
        return {}
    # Log only the key length, never the key itself.
    print(f"EXA_API_KEY found (length: {len(exa_api_key)})")
    return {
        "exa": McpHttpServerConfig(
            type="http",
            url=f"https://mcp.exa.ai/mcp?exaApiKey={exa_api_key}",
        )
    }
| # --- Shared agent runner --- | |
# Maximum attempts per agent when the SDK silently exits with no result.
MAX_RETRIES = 3
SILENT_EXIT_THRESHOLD = 2  # message_count <= this with empty result triggers retry
async def _run_agent_once(
    system_prompt: str,
    user_prompt: str,
    output_schema: dict,
    agent_label: str = "agent",
    mcp_servers: dict[str, McpHttpServerConfig] | None = None,
) -> tuple[dict, float, int]:
    """Run a single agent query attempt.

    Streams SDK messages and logs every one prefixed with ``agent_label`` so
    interleaved output from multiple agents stays attributable. Exceptions are
    captured into the result dict (key ``"error"``) instead of propagating, so
    the retry wrapper `_run_agent` can decide how to proceed.

    Args:
        system_prompt: System prompt passed to the SDK.
        user_prompt: User prompt passed to the SDK.
        output_schema: JSON schema the agent's structured output must follow.
        agent_label: Prefix for log lines (e.g. "retrieval-1", "aggregation").
        mcp_servers: Optional MCP server configuration.

    Returns:
        A tuple of (structured output dict, cost in USD, message count).
        The dict is empty when no structured output was produced.
    """
    def on_stderr(data: str):
        print(f"[{agent_label}][stderr] {data.strip()}")
    options_kwargs: dict = {
        "system_prompt": system_prompt,
        "permission_mode": "bypassPermissions",
        "settings": _get_settings_path(),
        "stderr": on_stderr,
        "output_format": {
            "type": "json_schema",
            "schema": output_schema,
        },
    }
    if mcp_servers:
        options_kwargs["mcp_servers"] = mcp_servers
    options = ClaudeAgentOptions(**options_kwargs)
    # tool_use id -> subagent name (from Task tool inputs) and tool name, so
    # later tool results and subagent messages can be labelled meaningfully.
    subagent_names: dict[str, str] = {}
    tool_names: dict[str, str] = {}
    message_count = 0
    result: dict = {}
    cost_usd = 0.0
    try:
        async for message in query(prompt=user_prompt, options=options):
            message_count += 1
            msg_type = type(message).__name__
            print(f"[{agent_label}] Message {message_count}: {msg_type}")
            if isinstance(message, AssistantMessage):
                # Subagent messages carry a parent_tool_use_id; label them with
                # the subagent name recorded when the Task tool was invoked.
                if message.parent_tool_use_id is None:
                    agent_prefix = f"[{agent_label}]"
                else:
                    subagent_name = subagent_names.get(
                        message.parent_tool_use_id, "subagent"
                    )
                    agent_prefix = f"[{agent_label}/{subagent_name}]"
                for block in message.content:
                    if isinstance(block, TextBlock):
                        print(f"{agent_prefix} Claude: {block.text}")
                    elif isinstance(block, ToolUseBlock):
                        print(f"{agent_prefix} Tool: {block.name}({block.input})")
                        tool_names[block.id] = block.name
                        if block.name == "Task" and isinstance(block.input, dict):
                            subagent_names[block.id] = block.input.get(
                                "subagent_type", "subagent"
                            )
            elif isinstance(message, UserMessage):
                # Tool results arrive as UserMessages; log a truncated preview.
                if isinstance(message.content, list):
                    for block in message.content:
                        if isinstance(block, ToolResultBlock):
                            tool_name = tool_names.get(
                                block.tool_use_id, "unknown"
                            )
                            content_str = (
                                str(block.content) if block.content else "(empty)"
                            )
                            if len(content_str) > 500:
                                content_str = content_str[:500] + "... (truncated)"
                            error_indicator = " [ERROR]" if block.is_error else ""
                            print(
                                f"[{agent_label}][result]{error_indicator} "
                                f"{tool_name}: {content_str}"
                            )
            elif isinstance(message, ResultMessage):
                if hasattr(message, "error") and message.error:
                    print(f"[{agent_label}][result] ERROR: {message.error}")
                if message.total_cost_usd and message.total_cost_usd > 0:
                    cost_usd = message.total_cost_usd
                    print(f"[{agent_label}] Cost: ${cost_usd:.4f}")
                if (
                    hasattr(message, "structured_output")
                    and message.structured_output
                ):
                    result = message.structured_output
                    print(f"[{agent_label}][structured_output] {result}")
    except Exception as e:
        # Report and swallow: the caller inspects the "error" key / empty result.
        print(f"[{agent_label}] Error: {type(e).__name__}: {e}")
        import traceback
        traceback.print_exc()
        result["error"] = str(e)
    print(f"[{agent_label}] Completed. Total messages: {message_count}")
    return result, cost_usd, message_count
async def _run_agent(
    system_prompt: str,
    user_prompt: str,
    output_schema: dict,
    agent_label: str = "agent",
    mcp_servers: dict[str, McpHttpServerConfig] | None = None,
) -> tuple[dict, float]:
    """Run a single agent query with structured output and automatic retries.

    The Claude Agent SDK on Modal can silently exit after only the
    SystemMessage, producing an empty result. This wrapper detects that (low
    message count with an empty result) and retries up to MAX_RETRIES times,
    backing off linearly between attempts.

    Args:
        system_prompt: The system prompt for the agent.
        user_prompt: The user prompt for the agent.
        output_schema: JSON schema for structured output.
        agent_label: Label for log messages (e.g. "retrieval-1", "aggregation").
        mcp_servers: Optional MCP server configuration.

    Returns:
        A tuple of (structured output dict, cost in USD accumulated across
        all attempts).
    """
    total_cost = 0.0
    result: dict = {}
    for attempt in range(1, MAX_RETRIES + 1):
        result, cost, message_count = await _run_agent_once(
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            output_schema=output_schema,
            agent_label=agent_label,
            mcp_servers=mcp_servers,
        )
        total_cost += cost
        is_silent_exit = (
            message_count <= SILENT_EXIT_THRESHOLD
            and not result
        )
        if not is_silent_exit:
            break
        if attempt == MAX_RETRIES:
            # Out of attempts; return the (empty) result with a warning.
            print(
                f"[{agent_label}] WARNING: SDK silent exit persisted after "
                f"{MAX_RETRIES} attempts"
            )
            break
        print(
            f"[{agent_label}] SDK silent exit detected "
            f"({message_count} messages, empty result). "
            f"Retrying ({attempt}/{MAX_RETRIES})..."
        )
        # Linear backoff: 2s, 4s, ... before the next attempt.
        await asyncio.sleep(2 * attempt)
    return result, total_cost
| # --- Stage 1: Information Retrieval --- | |
async def run_retrieval_agent(
    conference_name: str, agent_index: int = 1
) -> tuple[dict, float]:
    """Run one retrieval agent that searches the web for conference info.

    Args:
        conference_name: Name of the conference.
        agent_index: 1-based index of this agent (for logging).

    Returns:
        Tuple of (structured retrieval result dict, cost in USD).
    """
    existing_yaml = await load_conference_data(conference_name)
    readme = await read_app_readme()
    system_prompt = (
        await read_prompt("prompts/retrieval_system_prompt.md")
    ).format(
        conference_name=conference_name,
        date=format_date_verbose(datetime.now()),
        app_readme=readme,
    )
    user_prompt = (
        await read_prompt("prompts/retrieval_user_prompt.md")
    ).format(
        conference_name=conference_name,
        conference_data=existing_yaml or "No existing data found.",
    )
    servers = _get_exa_mcp_servers()
    return await _run_agent(
        system_prompt=system_prompt,
        user_prompt=user_prompt,
        output_schema=RETRIEVAL_RESULT_SCHEMA,
        agent_label=f"retrieval-{agent_index}",
        mcp_servers=servers or None,
    )
async def run_retrieval_agents(
    conference_name: str, n: int = 3
) -> tuple[list[dict], float]:
    """Run N retrieval agents one after another (sequentially).

    Args:
        conference_name: Name of the conference.
        n: Number of retrieval agents to run.

    Returns:
        Tuple of (list of N retrieval result dicts, total cost in USD).
    """
    collected: list[dict] = []
    spend = 0.0
    for idx in range(1, n + 1):
        print(f"\n--- Retrieval Agent {idx}/{n} ---")
        outcome, cost = await run_retrieval_agent(conference_name, agent_index=idx)
        collected.append(outcome)
        spend += cost
    return collected, spend
| # --- Stage 2: Aggregation (Majority Vote) --- | |
async def run_aggregation_agent(
    conference_name: str, retrieval_results: list[dict]
) -> tuple[dict, float]:
    """Majority-vote over the Stage-1 retrieval results.

    Uses a dedicated aggregation system prompt and has access to Exa MCP
    so the agent can independently verify factual claims when the retrieval
    agents disagree.

    Args:
        conference_name: Name of the conference.
        retrieval_results: List of retrieval result dicts from stage 1.

    Returns:
        Tuple of (aggregated result dict with consensus decision, cost in USD).
    """
    existing_yaml = await load_conference_data(conference_name)
    system_prompt = (
        await read_prompt("prompts/aggregation_system_prompt.md")
    ).format(
        conference_name=conference_name,
        date=format_date_verbose(datetime.now()),
    )
    # Render each colleague's result as a fenced-JSON section.
    sections = [
        f"### Agent {i} result\n\n```json\n{json.dumps(item, indent=2)}\n```\n\n"
        for i, item in enumerate(retrieval_results, 1)
    ]
    user_prompt = (
        await read_prompt("prompts/aggregation_user_prompt.md")
    ).format(
        conference_name=conference_name,
        conference_data=existing_yaml or "No existing data found.",
        num_agents=len(retrieval_results),
        retrieval_results="".join(sections),
    )
    servers = _get_exa_mcp_servers()
    return await _run_agent(
        system_prompt=system_prompt,
        user_prompt=user_prompt,
        output_schema=AGGREGATION_RESULT_SCHEMA,
        agent_label="aggregation",
        mcp_servers=servers or None,
    )
| # --- Stage 3: Push to Main --- | |
async def run_push_agent(
    conference_name: str,
    verified_yaml: str,
    changes_summary: str,
    source_urls: list[str],
) -> tuple[dict, float]:
    """Write the updated YAML and push it directly to main via an agent.

    Args:
        conference_name: Name of the conference.
        verified_yaml: The verified updated YAML content to write.
        changes_summary: Summary of what changed.
        source_urls: Source URLs supporting the update.

    Returns:
        Tuple of (push result dict with pushed and commit_sha, cost in USD).
    """
    current_yaml = await load_conference_data(conference_name)
    if source_urls:
        formatted_source_urls = "\n".join(f"- {url}" for url in source_urls)
    else:
        formatted_source_urls = "- No source URLs were provided."
    system_prompt = (
        await read_prompt("prompts/pr_system_prompt.md")
    ).format(
        conference_name=conference_name,
    )
    user_prompt = (
        await read_prompt("prompts/pr_user_prompt.md")
    ).format(
        conference_name=conference_name,
        updated_yaml=verified_yaml,
        changes_summary=changes_summary,
        source_urls=formatted_source_urls,
        current_yaml=current_yaml or "(file does not exist yet)",
    )
    # No MCP servers: the push stage works purely with local files and git.
    return await _run_agent(
        system_prompt=system_prompt,
        user_prompt=user_prompt,
        output_schema=PUSH_RESULT_SCHEMA,
        agent_label="push",
        mcp_servers=None,
    )
| # --- Orchestrator --- | |
async def find_conference_deadlines(
    conference_name: str,
    num_retrieval_agents: int = 3,
    dry_run: bool = False,
) -> dict:
    """Orchestrate the 3-stage pipeline for a conference.

    Stage 1: Run N retrieval agents sequentially to gather information.
    Stage 2: Run an aggregation agent to perform majority vote.
    Stage 3: If an update is needed, write the YAML and push directly to main.

    Args:
        conference_name: Name of the conference.
        num_retrieval_agents: Number of retrieval agents to run (default 3).
        dry_run: If True, skip pushing and just print the aggregated result.

    Returns:
        Final result dict with pushed, reasoning, and total_cost_usd
        (plus updated_yaml when the push stage was skipped, or commit_sha
        when it ran).
    """
    total_cost = 0.0
    print(f"Processing conference: {conference_name}")
    if dry_run:
        print("DRY RUN: push will be skipped")
    pipeline_suffix = "" if dry_run else " -> push"
    print(
        f"Pipeline: {num_retrieval_agents} retrieval agents "
        f"-> aggregation{pipeline_suffix}"
    )
    # === Stage 1: Information Retrieval ===
    print(f"\n{'=' * 60}")
    print(
        f"=== Stage 1: Information Retrieval "
        f"({num_retrieval_agents} agents) ==="
    )
    print(f"{'=' * 60}")
    retrieval_results, retrieval_cost = await run_retrieval_agents(
        conference_name, n=num_retrieval_agents
    )
    total_cost += retrieval_cost
    for i, result in enumerate(retrieval_results, 1):
        requires_update = result.get("requires_update", "unknown")
        print(f"  Agent {i}: requires_update={requires_update}")
    print(f"  Retrieval stage cost: ${retrieval_cost:.4f}")
    # === Stage 2: Aggregation (Majority Vote) ===
    print(f"\n{'=' * 60}")
    print("=== Stage 2: Aggregation (Majority Vote) ===")
    print(f"{'=' * 60}")
    aggregation_result, aggregation_cost = await run_aggregation_agent(
        conference_name, retrieval_results
    )
    total_cost += aggregation_cost
    requires_update = aggregation_result.get("requires_update", False)
    # Fallback: if aggregation returned empty (SDK silent exit) but retrieval
    # agents unanimously agreed on an update, use the first retrieval result.
    valid_results = [r for r in retrieval_results if r.get("requires_update") is not None]
    # Require at least 2 valid results so a single agent can't self-confirm.
    all_agree_update = (
        len(valid_results) >= 2
        and all(r.get("requires_update") is True for r in valid_results)
    )
    if not aggregation_result and all_agree_update:
        print(
            "\nWARNING: Aggregation agent returned empty (SDK silent exit). "
            "All retrieval agents unanimously agreed on update — using "
            "first retrieval result as fallback."
        )
        aggregation_result = valid_results[0]
        requires_update = True
    print(f"\nAggregation result: requires_update={requires_update}")
    # Only preview the first 200 chars of reasoning to keep logs readable.
    reasoning_preview = aggregation_result.get("reasoning", "N/A")[:200]
    print(f"Reasoning: {reasoning_preview}")
    print(f"  Aggregation stage cost: ${aggregation_cost:.4f}")
    if not requires_update:
        # Nothing to do: report the decision and stop before the push stage.
        print("\nNo update needed. Skipping push.")
        print(f"\nTotal pipeline cost: ${total_cost:.4f}")
        return {
            "pushed": False,
            "reasoning": aggregation_result.get("reasoning", ""),
            "updated_yaml": aggregation_result.get("updated_yaml", ""),
            "total_cost_usd": total_cost,
        }
    if dry_run:
        # An update is warranted, but the caller asked us not to push.
        print("\nDRY RUN: Update needed but skipping push.")
        print(f"Updated YAML:\n{aggregation_result.get('updated_yaml', '')}")
        print(f"\nTotal pipeline cost: ${total_cost:.4f}")
        return {
            "pushed": False,
            "reasoning": aggregation_result.get("reasoning", ""),
            "updated_yaml": aggregation_result.get("updated_yaml", ""),
            "total_cost_usd": total_cost,
        }
    # === Stage 3: Push to Main ===
    print(f"\n{'=' * 60}")
    print("=== Stage 3: Push to Main ===")
    print(f"{'=' * 60}")
    verified_yaml = aggregation_result.get("updated_yaml", "")
    changes_summary = aggregation_result.get("reasoning", "")
    source_urls = aggregation_result.get("source_urls", [])
    push_result, push_cost = await run_push_agent(
        conference_name, verified_yaml, changes_summary, source_urls
    )
    total_cost += push_cost
    print(f"  Push stage cost: ${push_cost:.4f}")
    print(f"\nTotal pipeline cost: ${total_cost:.4f}")
    return {
        "pushed": push_result.get("pushed", False),
        "commit_sha": push_result.get("commit_sha"),
        "reasoning": aggregation_result.get("reasoning", ""),
        "total_cost_usd": total_cost,
    }
if __name__ == "__main__":
    # CLI entry point: parse arguments, run the async pipeline to completion,
    # and print the final result dict.
    parser = argparse.ArgumentParser(
        description="Find conference deadlines using Claude Agent SDK"
    )
    parser.add_argument(
        "--conference_name",
        type=str,
        required=True,
        help="The name of the conference to find the deadlines of",
    )
    parser.add_argument(
        "--num-retrieval-agents",
        type=int,
        default=3,
        help="Number of retrieval agents to run (default: 3)",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Run retrieval and aggregation only, skip pushing to main",
    )
    args = parser.parse_args()
    result = asyncio.run(
        find_conference_deadlines(
            args.conference_name,
            num_retrieval_agents=args.num_retrieval_agents,
            dry_run=args.dry_run,
        )
    )
    print(f"\nResult: {result}")