Upload 8 files
Browse files- app.py +375 -0
- configuration.py +166 -0
- graph.py +574 -0
- prompts.py +180 -0
- state.py +72 -0
- supervisor_node.py +406 -0
- tools.py +58 -0
- utils.py +76 -0
app.py
ADDED
|
@@ -0,0 +1,375 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Web application for the Agent Supervisor with GAIA benchmark integration.
|
| 2 |
+
|
| 3 |
+
This module provides a Gradio web interface for interacting with the Agent Supervisor
|
| 4 |
+
and evaluating it against the GAIA benchmark.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import json
|
| 9 |
+
import uuid
|
| 10 |
+
import asyncio
|
| 11 |
+
import requests
|
| 12 |
+
import pandas as pd
|
| 13 |
+
import gradio as gr
|
| 14 |
+
|
| 15 |
+
from typing import Dict, List, Optional
|
| 16 |
+
from langchain_core.messages import HumanMessage
|
| 17 |
+
from langgraph.checkpoint.memory import MemorySaver
|
| 18 |
+
|
| 19 |
+
from react_agent.graph import create_agent_supervisor_graph, get_compiled_graph
|
| 20 |
+
|
| 21 |
+
# --- Constants ---
# Base URL of the HF Agents Course Unit 4 scoring service; exposes the
# /questions, /random-question, and /submit endpoints used below.
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
|
| 23 |
+
|
| 24 |
+
class GaiaAgent:
    """Agent implementation for the GAIA benchmark using the LangGraph supervisor."""

    def __init__(self, model_name=None, checkpointer=None):
        """Initialize the GAIA agent with LangGraph architecture.

        Args:
            model_name: Optional model name to override the configured default.
            checkpointer: Optional checkpointer for persistence. Defaults to an
                in-memory MemorySaver to avoid SQLite thread-safety issues.
        """
        print("Initializing GaiaAgent...")

        # Imported inside the method to defer the react_agent import until use.
        from react_agent.configuration import Configuration

        # Pull the default model and per-role models from the configuration.
        config = Configuration.from_context()
        default_model = config.model

        # MemorySaver keeps checkpoints in-process, sidestepping SQLite
        # cross-thread access errors seen with file-backed checkpointers.
        # (MemorySaver is imported at module level.)
        if checkpointer is None:
            checkpointer = MemorySaver()
            print("Using in-memory checkpointer to avoid thread safety issues")

        # Create and compile the supervisor graph once per agent instance.
        self.graph = get_compiled_graph(checkpointer=checkpointer)

        # Runtime configuration passed to every graph invocation.
        self.config = {
            "configurable": {
                # Caller-supplied model wins over the configured default.
                "model": model_name if model_name else default_model,
                # Role-specific models from Configuration.
                "researcher_model": config.researcher_model,
                "coder_model": config.coder_model,
                "planner_model": config.planner_model,
                "supervisor_model": config.supervisor_model,
                "critic_model": config.critic_model,
                "final_answer_model": config.final_answer_model,
                # Other execution settings from Configuration.
                "max_search_results": config.max_search_results,
                "recursion_limit": config.recursion_limit,
                "max_iterations": config.max_iterations,
                "allow_agent_to_extract_answers": config.allow_agent_to_extract_answers
            }
        }

        print(f"GaiaAgent initialized successfully with model: {self.config['configurable']['model']}")

    def __call__(self, question: str) -> str:
        """Process a question and return an answer formatted for GAIA benchmark.

        Args:
            question: The GAIA benchmark question text.

        Returns:
            The answer string with any "FINAL ANSWER:" prefix stripped, or an
            error message string if graph invocation fails twice.
        """
        print(f"Agent received question: {question[:100]}...")

        # Fresh thread_id so each question gets its own checkpoint history.
        thread_id = str(uuid.uuid4())
        self.config["configurable"]["thread_id"] = thread_id

        from react_agent.configuration import Configuration
        config = Configuration.from_context()

        # System prompt enforcing the terse GAIA answer format.
        system_prompt = """You are a general AI assistant. Answer the question concisely.
YOUR ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
If asked for a number, don't use commas or units like $ or % unless specified.
If asked for a string, don't use articles or abbreviations (e.g. for cities), and write digits as plain text unless specified otherwise.
Focus on brevity and correctness."""

        # Input state: the question plus per-run configurable overrides.
        input_state = {
            "messages": [HumanMessage(content=question)],
            "configurable": {
                "thread_id": thread_id,
                "system_prompt": system_prompt,
                "model": config.model  # Ensure model is also set in the state
            }
        }

        try:
            # Use invoke instead of stream to limit operations.
            try:
                final_state = self.graph.invoke(input_state, config=self.config)
            except Exception as e:
                # Most commonly a recursion-limit error: retry once with a
                # doubled limit before giving up.
                print(f"Initial invocation failed: {str(e)}")
                self.config["configurable"]["recursion_limit"] = config.recursion_limit * 2
                final_state = self.graph.invoke(input_state, config=self.config)

            # Prefer a non-empty explicit gaia_answer; otherwise fall back to
            # the last message. (Membership test alone would accept an empty
            # or None gaia_answer and discard a usable message answer.)
            if final_state.get("gaia_answer"):
                answer = final_state["gaia_answer"]
            else:
                messages = final_state.get("messages", [])
                answer = messages[-1].content if messages else "No answer generated."

            # Normalize to str before substring checks so non-string answers
            # (e.g. numbers) cannot raise a TypeError.
            answer = str(answer)
            if "FINAL ANSWER:" in answer:
                answer = answer.split("FINAL ANSWER:")[1].strip()

            print(f"Agent returning answer: {answer[:100]}...")
            return answer

        except Exception as e:
            error_msg = f"Error processing question: {str(e)}"
            print(error_msg)
            return error_msg
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def run_and_submit_all(profile: gr.OAuthProfile | None):
    """Fetch all questions, run the GaiaAgent on them, submit answers, and display the results.

    Args:
        profile: The logged-in Hugging Face OAuth profile, or None when not logged in.

    Returns:
        Tuple of (status message string, results DataFrame or None).
    """
    # --- Determine HF Space Runtime URL and Repo URL ---
    space_id = os.getenv("SPACE_ID")  # Used to link back to the codebase.

    if profile:
        username = f"{profile.username}"
        print(f"User logged in: {username}")
    else:
        print("User not logged in.")
        return "Please Login to Hugging Face with the button.", None

    api_url = DEFAULT_API_URL
    questions_url = f"{api_url}/questions"
    submit_url = f"{api_url}/submit"

    # 1. Instantiate Agent
    try:
        agent = GaiaAgent()
    except Exception as e:
        print(f"Error instantiating agent: {e}")
        return f"Error initializing agent: {e}", None

    # In the case of an app running as a Hugging Face space, this link points toward your codebase.
    agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"
    print(agent_code)

    # 2. Fetch Questions
    print(f"Fetching questions from: {questions_url}")
    try:
        response = requests.get(questions_url, timeout=15)
        response.raise_for_status()
        questions_data = response.json()
        if not questions_data:
            print("Fetched questions list is empty.")
            return "Fetched questions list is empty or invalid format.", None
        print(f"Fetched {len(questions_data)} questions.")
    except requests.exceptions.RequestException as e:
        print(f"Error fetching questions: {e}")
        return f"Error fetching questions: {e}", None
    except requests.exceptions.JSONDecodeError as e:
        # raise_for_status passed but the body was not valid JSON.
        print(f"Error decoding JSON response from questions endpoint: {e}")
        print(f"Response text: {response.text[:500]}")
        return f"Error decoding server response for questions: {e}", None
    except Exception as e:
        print(f"An unexpected error occurred fetching questions: {e}")
        return f"An unexpected error occurred fetching questions: {e}", None

    # 3. Run the Agent on every question, logging answers and failures.
    results_log = []
    answers_payload = []
    print(f"Running agent on {len(questions_data)} questions...")
    for item in questions_data:
        task_id = item.get("task_id")
        question_text = item.get("question")
        if not task_id or question_text is None:
            print(f"Skipping item with missing task_id or question: {item}")
            continue
        try:
            answer = agent(question_text)
            # The scoring API requires the "submitted_answer" key.
            answers_payload.append({
                "task_id": task_id,
                "submitted_answer": answer
            })
            results_log.append({
                "Task ID": task_id,
                "Question": question_text,
                "Answer": answer
            })
        except Exception as e:
            # A per-question failure is recorded but does not abort the run.
            print(f"Error running agent on task {task_id}: {e}")
            results_log.append({
                "Task ID": task_id,
                "Question": question_text,
                "Answer": f"AGENT ERROR: {e}"
            })

    if not answers_payload:
        print("Agent did not produce any answers to submit.")
        return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)

    # 4. Prepare Submission
    submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}
    status_update = f"Agent finished. Submitting {len(answers_payload)} answers for user '{username}'..."
    print(status_update)

    # Build the results table once; every outcome branch below reuses it
    # (previously recreated identically in each of the four except branches).
    results_df = pd.DataFrame(results_log)

    # 5. Submit
    print(f"Submitting {len(answers_payload)} answers to: {submit_url}")
    try:
        response = requests.post(submit_url, json=submission_data, timeout=60)
        response.raise_for_status()
        result_data = response.json()
        final_status = (
            f"Submission Successful!\n"
            f"User: {result_data.get('username')}\n"
            f"Overall Score: {result_data.get('score', 'N/A')}% "
            f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
            f"Message: {result_data.get('message', 'No message received.')}"
        )
        print("Submission successful.")
        return final_status, results_df
    except requests.exceptions.HTTPError as e:
        error_detail = f"Server responded with status {e.response.status_code}."
        try:
            error_json = e.response.json()
            error_detail += f" Detail: {error_json.get('detail', e.response.text)}"
        except requests.exceptions.JSONDecodeError:
            error_detail += f" Response: {e.response.text[:500]}"
        status_message = f"Submission Failed: {error_detail}"
        print(status_message)
        return status_message, results_df
    except requests.exceptions.Timeout:
        status_message = "Submission Failed: The request timed out."
        print(status_message)
        return status_message, results_df
    except requests.exceptions.RequestException as e:
        status_message = f"Submission Failed: Network error - {e}"
        print(status_message)
        return status_message, results_df
    except Exception as e:
        status_message = f"An unexpected error occurred during submission: {e}"
        print(status_message)
        return status_message, results_df
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
# Function to test a single random question
def test_random_question():
    """Fetch a random question from the API and run the agent on it."""
    random_question_url = f"{DEFAULT_API_URL}/random-question"

    try:
        # Pull one random question from the scoring API.
        resp = requests.get(random_question_url, timeout=15)
        resp.raise_for_status()
        payload = resp.json()

        if not payload:
            return "Error: Received empty response from random question endpoint.", None

        task_id = payload.get("task_id")
        question_text = payload.get("question")
        if not task_id or not question_text:
            return "Error: Invalid question format received.", None

        # Build a fresh agent and answer the single question.
        answer = GaiaAgent()(question_text)

        return "Test completed successfully.", {
            "Task ID": task_id,
            "Question": question_text,
            "Answer": answer,
        }

    except Exception as e:
        return f"Error testing random question: {str(e)}", None
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
# --- Build Gradio Interface using Blocks ---
# Two tabs: a full benchmark run with submission, and a single-question smoke test.
with gr.Blocks() as demo:
    gr.Markdown("# GAIA Benchmark Agent Evaluation")
    gr.Markdown(
        """
        **Instructions:**

        1. Log in to your Hugging Face account using the button below. This uses your HF username for submission.
        2. Click 'Run Evaluation & Submit All Answers' to fetch questions, run the agent, submit answers, and see the score.
        3. Alternatively, click 'Test on Random Question' to test the agent on a single random question.

        ---
        **Note:** Running the agent on all questions may take some time. Please be patient while the agent processes all the questions.
        """
    )

    # Login supplies the OAuth session that run_and_submit_all reads its profile from.
    gr.LoginButton()

    with gr.Tabs():
        with gr.TabItem("Full Evaluation"):
            run_button = gr.Button("Run Evaluation & Submit All Answers")
            status_output = gr.Textbox(label="Run Status / Submission Result", lines=5, interactive=False)
            results_table = gr.DataFrame(label="Questions and Agent Answers", wrap=True)

            # NOTE(review): fn takes a gr.OAuthProfile parameter but no inputs
            # are declared here — presumably Gradio injects the profile from
            # the login session; confirm against the Gradio OAuth docs.
            run_button.click(
                fn=run_and_submit_all,
                outputs=[status_output, results_table]
            )

        with gr.TabItem("Test Single Question"):
            test_button = gr.Button("Test on Random Question")
            test_status = gr.Textbox(label="Test Status", lines=2, interactive=False)
            test_result = gr.JSON(label="Question and Answer")

            test_button.click(
                fn=test_random_question,
                outputs=[test_status, test_result]
            )
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
if __name__ == "__main__":
    print("\n" + "-"*30 + " App Starting " + "-"*30)

    # Surface Space-related environment variables at startup for debugging.
    host_env = os.getenv("SPACE_HOST")
    repo_env = os.getenv("SPACE_ID")

    if host_env:
        print(f"✅ SPACE_HOST found: {host_env}")
        print(f" Runtime URL should be: https://{host_env}.hf.space")
    else:
        print("ℹ️ SPACE_HOST environment variable not found (running locally?).")

    if repo_env:
        # With a SPACE_ID we can print direct links to the hosted repo.
        print(f"✅ SPACE_ID found: {repo_env}")
        print(f" Repo URL: https://huggingface.co/spaces/{repo_env}")
        print(f" Repo Tree URL: https://huggingface.co/spaces/{repo_env}/tree/main")
    else:
        print("ℹ️ SPACE_ID environment variable not found (running locally?). Repo URL cannot be determined.")

    print("-"*(60 + len(" App Starting ")) + "\n")

    print("Launching Gradio Interface for GAIA Agent Evaluation...")
    demo.launch(debug=True, share=False)
|
configuration.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Define the configurable parameters for the agent supervisor system."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from dataclasses import dataclass, field, fields
|
| 6 |
+
from typing import Annotated
|
| 7 |
+
|
| 8 |
+
from langchain_core.runnables import ensure_config
|
| 9 |
+
from langgraph.config import get_config
|
| 10 |
+
|
| 11 |
+
from react_agent import prompts
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclass(kw_only=True)
class Configuration:
    """The configuration for the agent supervisor system.

    Every field can be overridden at runtime through the "configurable"
    mapping of the LangGraph run config; see ``from_context``.
    """

    # Supervisor configuration
    supervisor_prompt: str = field(
        default=prompts.SUPERVISOR_PROMPT,
        metadata={
            "description": "The system prompt for the supervisor agent. "
            "This prompt guides how the supervisor delegates tasks to worker agents."
        },
    )

    # Planner configuration
    planner_prompt: str = field(
        default=prompts.PLANNER_PROMPT,
        metadata={
            "description": "The system prompt for the planner agent. "
            "This prompt guides how the planner creates structured plans."
        },
    )

    # Critic configuration
    critic_prompt: str = field(
        default=prompts.CRITIC_PROMPT,
        metadata={
            "description": "The system prompt for the critic agent. "
            "This prompt guides how the critic evaluates answers."
        },
    )

    # Worker agents configuration
    researcher_prompt: str = field(
        default=prompts.RESEARCHER_PROMPT,
        metadata={
            "description": "The system prompt for the researcher agent. "
            "This prompt defines the researcher's capabilities and limitations."
        },
    )

    coder_prompt: str = field(
        default=prompts.CODER_PROMPT,
        metadata={
            "description": "The system prompt for the coder agent. "
            "This prompt defines the coder's capabilities and approach to programming tasks."
        },
    )

    # Shared configuration
    system_prompt: str = field(
        default=prompts.SYSTEM_PROMPT,
        metadata={
            "description": "Legacy system prompt for backward compatibility. "
            "This prompt is used when running the agent in non-supervisor mode."
        },
    )

    # LLM Configuration - Default model for backward compatibility
    model: Annotated[str, {"__template_metadata__": {"kind": "llm"}}] = field(
        default="openai/gpt-4o-mini",
        metadata={
            "description": "The default large language model used by the agents (provider/model_name)."
        },
    )

    # Model for the researcher (information gathering) - use powerful model
    researcher_model: Annotated[str, {"__template_metadata__": {"kind": "llm"}}] = field(
        default="openai/gpt-4o-mini",
        metadata={
            "description": "The model used by the researcher agent for gathering information (provider/model_name)."
        },
    )

    # Model for the coder (code execution) - use Claude Sonnet
    coder_model: Annotated[str, {"__template_metadata__": {"kind": "llm"}}] = field(
        default="anthropic/claude-3-5-sonnet-20240620",
        metadata={
            "description": "The model used by the coder agent for programming tasks (provider/model_name)."
        },
    )

    # Model for lightweight reasoning tasks (planner, supervisor, critic)
    planner_model: Annotated[str, {"__template_metadata__": {"kind": "llm"}}] = field(
        default="google_genai/gemini-1.5-flash",
        metadata={
            "description": "The lightweight reasoning model used by the planner, supervisor, and critic (provider/model_name)."
        },
    )

    # Same model used for supervisor and critic (points to planner_model)
    supervisor_model: Annotated[str, {"__template_metadata__": {"kind": "llm"}}] = field(
        default="google_genai/gemini-1.5-flash",
        metadata={
            "description": "The model used by the supervisor for routing (provider/model_name)."
        },
    )

    critic_model: Annotated[str, {"__template_metadata__": {"kind": "llm"}}] = field(
        default="openai/gpt-4o-mini",
        metadata={
            "description": "The model used by the critic for evaluation (provider/model_name)."
        },
    )

    # Model for final answer generation - using Claude for precise formatting
    final_answer_model: Annotated[str, {"__template_metadata__": {"kind": "llm"}}] = field(
        default="anthropic/claude-3-5-sonnet-20240620",
        metadata={
            "description": "The model used for generating the final answers in GAIA benchmark format (provider/model_name)."
        },
    )

    # Tool Configuration
    max_search_results: int = field(
        default=5,
        metadata={
            "description": "The maximum number of search results to return."
        },
    )

    # Execution Configuration
    recursion_limit: int = field(
        default=50,
        metadata={
            "description": "Maximum number of recursion steps allowed in the LangGraph execution."
        },
    )

    max_iterations: int = field(
        default=12,
        metadata={
            "description": "Maximum number of iterations allowed to prevent infinite loops."
        },
    )

    allow_agent_to_extract_answers: bool = field(
        default=True,
        metadata={
            "description": "Whether to allow the agent to extract answers from context when formatting fails."
        },
    )

    @classmethod
    def from_context(cls) -> Configuration:
        """Create a Configuration instance from a RunnableConfig object.

        Reads the active LangGraph run config when called inside a graph
        invocation; when no run context is active (get_config raises
        RuntimeError), falls back to an empty config so every field keeps
        its declared default.
        """
        try:
            config = get_config()
        except RuntimeError:
            # Not inside a LangGraph invocation; use field defaults.
            config = None
        config = ensure_config(config)
        configurable = config.get("configurable") or {}
        # Only pass through keys matching declared, init-able dataclass fields;
        # unknown keys in "configurable" are silently ignored.
        _fields = {f.name for f in fields(cls) if f.init}
        return cls(**{k: v for k, v in configurable.items() if k in _fields})
|
graph.py
ADDED
|
@@ -0,0 +1,574 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Define an Agent Supervisor graph with specialized worker agents.
|
| 2 |
+
|
| 3 |
+
The supervisor routes tasks to specialized agents based on the query type.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from typing import Dict, List, Literal, Optional, Union, Type, cast
|
| 7 |
+
|
| 8 |
+
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
|
| 9 |
+
from langchain_core.prompts import ChatPromptTemplate
|
| 10 |
+
from langgraph.graph import StateGraph, START, END
|
| 11 |
+
# Import adjusted for compatibility
|
| 12 |
+
from langgraph.prebuilt import create_react_agent # Try original import path first
|
| 13 |
+
from langgraph.types import Command
|
| 14 |
+
|
| 15 |
+
from react_agent.configuration import Configuration
|
| 16 |
+
from react_agent.state import WORKERS, MEMBERS, ROUTING, VERDICTS, State, Router, Plan, PlanStep, CriticVerdict
|
| 17 |
+
from react_agent.tools import TOOLS, tavily_tool, python_repl_tool
|
| 18 |
+
from react_agent.utils import load_chat_model, format_system_prompt, get_message_text
|
| 19 |
+
from react_agent import prompts
|
| 20 |
+
from react_agent.supervisor_node import supervisor_node
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# Compile-time type definitions
|
| 24 |
+
SupervisorDestinations = Literal["planner", "critic", "researcher", "coder", "final_answer", "__end__"]
|
| 25 |
+
WorkerDestination = Literal["supervisor"]
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# Helper function to check if a message is from a user
|
| 29 |
+
def is_user_message(message):
    """Return True when *message* originates from a user.

    Accepts either a plain dict (OpenAI-style ``{"role": ..., "content": ...}``)
    or a LangChain ``HumanMessage`` object; anything else is not a user message.
    """
    if isinstance(message, dict):
        return message.get("role") == "user"
    # Non-dict messages count as user input only when they are HumanMessages.
    return isinstance(message, HumanMessage)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# Helper function to get message content
|
| 39 |
+
def get_message_content(message):
    """Return the content carried by *message*, or "" when none is present.

    Handles dict-style messages (``{"content": ...}``) as well as LangChain
    message objects exposing a ``.content`` attribute.
    """
    if isinstance(message, dict):
        return message.get("content", "")
    # Object-style messages: fall back to "" when no content attribute exists.
    return getattr(message, "content", "")
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# --- Planner node ---------------------------------------------------------
|
| 49 |
+
|
| 50 |
+
def planner_node(state: State) -> Command[WorkerDestination]:
    """Planning LLM that creates a step-by-step execution plan.

    Asks the configured planner model for a structured ``Plan`` (an ordered
    list of steps, each naming a worker and an instruction), then hands
    control back to the supervisor with the plan and a reset step index.

    Args:
        state: The current state with messages

    Returns:
        Command to update the state with a plan
    """
    configuration = Configuration.from_context()
    # Use the specific planner model
    planner_llm = load_chat_model(configuration.planner_model)

    # Track steps (loop-protection counter carried through the state)
    steps_taken = state.get("steps_taken", 0)
    steps_taken += 1

    # Get the original user question (the latest user message)
    user_messages = [m for m in state["messages"] if is_user_message(m)]
    original_question = get_message_content(user_messages[-1]) if user_messages else "Help me"

    # Create a chat prompt template with proper formatting
    planner_prompt_template = ChatPromptTemplate.from_messages([
        ("system", prompts.PLANNER_PROMPT),
        ("user", "{question}")
    ])

    # Format the prompt with the necessary variables.
    # The extra keys (workers, worker_options, example_worker_*) fill the
    # matching placeholders inside prompts.PLANNER_PROMPT.
    formatted_messages = planner_prompt_template.format_messages(
        question=original_question,
        system_time=format_system_prompt("{system_time}"),
        workers=", ".join(WORKERS),
        worker_options=", ".join([f'"{w}"' for w in WORKERS]),
        example_worker_1=WORKERS[0] if WORKERS else "researcher",
        example_worker_2=WORKERS[1] if len(WORKERS) > 1 else "coder"
    )

    # Get structured output from the planner model
    plan = planner_llm.with_structured_output(Plan).invoke(formatted_messages)

    # Return with updated state
    return Command(
        goto="supervisor",
        update={
            "plan": plan,
            "current_step_index": 0,  # restart execution at the first plan step
            # Add a message to show the plan was created
            "messages": [
                HumanMessage(
                    content=f"Created plan with {len(plan['steps'])} steps",
                    name="planner"
                )
            ],
            "steps_taken": steps_taken
        }
    )
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
# --- Final Answer node -----------------------------------------------------
|
| 109 |
+
|
| 110 |
+
def final_answer_node(state: State) -> Command[Literal["__end__"]]:
    """Generate the final GAIA-formatted answer and terminate the graph.

    Reuses the supervisor's draft answer when it is already in the required
    ``FINAL ANSWER: <concise response>`` shape; otherwise composes an answer
    from the accumulated worker context via the final-answer model.  The
    concise response is then extracted and stored for the benchmark harness.

    Args:
        state: The current state with messages and context

    Returns:
        Command routed to END carrying ``gaia_answer``/``submitted_answer``
        and the ``final_answer_generated`` status.
    """
    import re

    def _extract_gaia(text):
        """Pull the concise response out of a 'FINAL ANSWER: ...' string.

        Returns None when the marker is absent (caller picks the fallback).
        """
        match = re.search(r"FINAL ANSWER:\s*(.*?)(?:\n|$)", text, re.IGNORECASE)
        return match.group(1).strip() if match else None

    configuration = Configuration.from_context()

    # Loop-protection counter carried through the state.
    steps_taken = state.get("steps_taken", 0)
    steps_taken += 1

    # When retries are exhausted the supervisor may already have produced a
    # formatted answer; in that case we must not call the model again.
    retry_exhausted = state.get("retry_exhausted", False)
    draft_answer = state.get("draft_answer")

    gaia_answer = ""

    if retry_exhausted and draft_answer and draft_answer.startswith("FINAL ANSWER:"):
        extracted = _extract_gaia(draft_answer)
        gaia_answer = extracted if extracted is not None else "unknown"
    else:
        # Use the specific final answer model
        final_llm = load_chat_model(configuration.final_answer_model)

        # The latest user message is treated as the original question.
        user_messages = [m for m in state["messages"] if is_user_message(m)]
        original_question = get_message_content(user_messages[-1]) if user_messages else "Help me"

        if draft_answer and draft_answer.startswith("FINAL ANSWER:"):
            # Supervisor already provided a properly formatted answer.
            raw_answer = draft_answer
        else:
            # Compose a prompt for the final answer from the shared context.
            context = state.get("context", {})

            final_prompt = ChatPromptTemplate.from_messages([
                ("system", prompts.FINAL_ANSWER_PROMPT),
                ("user", prompts.FINAL_ANSWER_USER_PROMPT)
            ])

            # Order context so research background precedes calculations,
            # followed by any other workers.
            context_list = []
            if "researcher" in context:
                context_list.append(f"Research information: {context['researcher']}")
            if "coder" in context:
                context_list.append(f"Calculation results: {context['coder']}")
            for worker, content in context.items():
                if worker not in ["researcher", "coder"]:
                    context_list.append(f"{worker.capitalize()}: {content}")

            formatted_messages = final_prompt.format_messages(
                question=original_question,
                context="\n\n".join(context_list)
            )

            raw_answer = final_llm.invoke(formatted_messages).content

        # Normalize to the GAIA format; fall back to the raw model output
        # when no "FINAL ANSWER:" marker is found.
        extracted = _extract_gaia(raw_answer)
        gaia_answer = extracted if extracted is not None else raw_answer

    # Optional last-resort extraction from worker context when the model
    # produced no usable answer (feature-gated by configuration).
    if configuration.allow_agent_to_extract_answers and (not gaia_answer or gaia_answer.lower() in ["unknown", "insufficient information"]):
        context = state.get("context", {})
        from react_agent.supervisor_node import extract_best_answer_from_context
        extracted_answer = extract_best_answer_from_context(context)
        if extracted_answer != "unknown":
            gaia_answer = extracted_answer

    # Set status to "final_answer_generated" to indicate we're done.
    return Command(
        goto=END,
        update={
            "messages": [
                AIMessage(
                    content=f"FINAL ANSWER: {gaia_answer}",
                    name="supervisor"
                )
            ],
            "next": "FINISH",  # Update next to indicate we're done
            "gaia_answer": gaia_answer,  # Store answer in GAIA-compatible format
            "submitted_answer": gaia_answer,  # Store as submitted_answer for GAIA benchmark
            "status": "final_answer_generated",  # Add status to indicate we're complete
            "steps_taken": steps_taken
        }
    )
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
# --- Critic node ----------------------------------------------------------
|
| 223 |
+
|
| 224 |
+
def critic_node(state: State) -> Command[Union[WorkerDestination, SupervisorDestinations]]:
    """Evaluate whether the draft answer fully satisfies the request.

    Runs the critic model with structured output; a CORRECT verdict routes
    to the final-answer node, anything else sends control back to the
    supervisor for another attempt.

    Args:
        state: The current state with messages and draft answer

    Returns:
        Command carrying the critic's verdict and a status message
    """
    configuration = Configuration.from_context()
    # The critic has its own dedicated model.
    critic_llm = load_chat_model(configuration.critic_model)

    # Loop-protection counter carried through the state.
    steps_taken = state.get("steps_taken", 0) + 1

    # The most recent user message is the question under review.
    user_messages = [m for m in state["messages"] if is_user_message(m)]
    if user_messages:
        original_question = get_message_content(user_messages[-1])
    else:
        original_question = "Help me"

    draft_answer = state.get("draft_answer", "No answer provided.")

    # Build and fill the critic prompt.
    prompt = ChatPromptTemplate.from_messages([
        ("system", prompts.CRITIC_PROMPT),
        ("user", prompts.CRITIC_USER_PROMPT),
    ])
    formatted_messages = prompt.format_messages(
        question=original_question,
        answer=draft_answer,
        system_time=format_system_prompt("{system_time}"),
        correct_verdict=VERDICTS[0] if VERDICTS else "CORRECT",
        retry_verdict=VERDICTS[1] if len(VERDICTS) > 1 else "RETRY",
    )

    # Structured verdict from the critic model.
    verdict = critic_llm.with_structured_output(CriticVerdict).invoke(formatted_messages)

    if verdict["verdict"] == VERDICTS[0]:  # CORRECT
        verdict_message = "Answer is complete, accurate, and properly formatted for GAIA."
        goto = "final_answer"
    else:
        verdict_message = f"Answer needs improvement. Reason: {verdict.get('reason', 'Unknown')}"
        goto = "supervisor"

    return Command(
        goto=goto,
        update={
            "critic_verdict": verdict,
            "messages": [
                HumanMessage(content=verdict_message, name="critic")
            ],
            "steps_taken": steps_taken,
        },
    )
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
# --- Worker agent factory -------------------------------------------------
|
| 291 |
+
|
| 292 |
+
def create_worker_node(worker_type: str):
    """Factory function to create a worker node of the specified type.

    Builds a ReAct agent for the worker (model, prompt, and tools chosen by
    ``worker_type``) and returns a closure that runs it inside the graph.

    Args:
        worker_type: The type of worker to create (must be in WORKERS)

    Returns:
        A function that processes requests for the specified worker type

    Raises:
        ValueError: If ``worker_type`` is not listed in WORKERS.
    """
    if worker_type not in WORKERS:
        raise ValueError(f"Unknown worker type: {worker_type}")

    configuration = Configuration.from_context()

    # Select the appropriate model for each worker type
    if worker_type == "researcher":
        llm = load_chat_model(configuration.researcher_model)
        worker_prompt = prompts.RESEARCHER_PROMPT
        worker_tools = [tavily_tool]
    elif worker_type == "coder":
        llm = load_chat_model(configuration.coder_model)
        worker_prompt = prompts.CODER_PROMPT
        worker_tools = [python_repl_tool]
    else:
        # Default case: look up a "<TYPE>_PROMPT" constant, fall back to the
        # generic system prompt, and expose every tool.
        llm = load_chat_model(configuration.model)
        worker_prompt = getattr(prompts, f"{worker_type.upper()}_PROMPT", prompts.SYSTEM_PROMPT)
        worker_tools = TOOLS

    # Create the agent (captured once per factory call and reused by the node)
    worker_agent = create_react_agent(
        llm,
        tools=worker_tools,
        prompt=format_system_prompt(worker_prompt)
    )

    # Define node function
    def worker_node(state: State) -> Command[WorkerDestination]:
        """Process requests using the specified worker.

        Args:
            state: The current conversation state

        Returns:
            Command to return to supervisor with results
        """
        # Track steps (loop-protection counter carried through the state)
        steps_taken = state.get("steps_taken", 0)
        steps_taken += 1

        # Get the last message from the supervisor, which contains our task
        task_message = None
        if state.get("messages"):
            for msg in reversed(state["messages"]):
                if hasattr(msg, "name") and msg.name == "supervisor":
                    task_message = msg
                    break

        # No supervisor task found: report the problem and hand control back.
        if not task_message:
            return Command(
                goto="supervisor",
                update={
                    "messages": [
                        HumanMessage(
                            content=f"Error: No task message found for {worker_type}",
                            name=worker_type
                        )
                    ],
                    "steps_taken": steps_taken
                }
            )

        # Create a new state with just the relevant messages for this worker
        # This prevents confusion from unrelated parts of the conversation
        agent_input = {
            "messages": [
                # Include the first user message for context
                state["messages"][0] if state["messages"] else HumanMessage(content="Help me"),
                # Include the task message
                task_message
            ]
        }

        # Invoke the agent with the clean input
        result = worker_agent.invoke(agent_input)

        # Extract the result from the agent response
        result_content = extract_worker_result(worker_type, result, state)

        # Store the worker's result in shared context (copy to avoid mutating
        # the prior state in place)
        context_update = state.get("context", {}).copy()
        context_update[worker_type] = result_content

        # Store in worker_results history
        worker_results = state.get("worker_results", {}).copy()
        if worker_type not in worker_results:
            worker_results[worker_type] = []
        worker_results[worker_type].append(result_content)

        # Increment the step index after worker completes
        current_step_index = state.get("current_step_index", 0)

        return Command(
            update={
                "messages": [
                    HumanMessage(content=result_content, name=worker_type)
                ],
                "current_step_index": current_step_index + 1,
                "context": context_update,
                "worker_results": worker_results,
                "steps_taken": steps_taken
            },
            goto="supervisor",
        )

    return worker_node
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
def extract_worker_result(worker_type: str, result: dict, state: State) -> str:
    """Extract a clean, useful result string from a worker agent's output.

    Handles the differing response shapes of the worker agents: coder output
    is mined for the actual execution result (stdout / "Result:"-style
    markers) rather than the code itself, while long researcher output is
    reduced to its summary section when one exists.

    Args:
        worker_type: The type of worker (researcher or coder)
        result: The raw result from the worker agent (expects a "messages"
            list whose last entry carries the output)
        state: The current state for context (currently unused; kept so the
            caller interface stays stable)

    Returns:
        A cleaned string with the relevant result information
    """
    import re  # single import shared by all branches below

    # Handle empty results
    if not result or "messages" not in result or not result["messages"]:
        return f"No output from {worker_type}"

    # Get the last message from the agent
    last_message = result["messages"][-1]

    # Default to extracting content directly
    if hasattr(last_message, "content") and last_message.content:
        result_content = last_message.content
    else:
        result_content = f"No content from {worker_type}"

    # Special handling based on worker type
    if worker_type == "coder":
        # For coder outputs, extract the actual result values from code execution
        if "```" in result_content:
            # Try to extract stdout from code execution
            stdout_match = re.search(r"Stdout:\s*(.*?)(?:\n\n|$)", result_content, re.DOTALL)
            if stdout_match:
                # Extract the actual execution output, not just the code
                execution_result = stdout_match.group(1).strip()
                if execution_result:
                    # A bare numeric result is returned verbatim
                    if re.match(r"^\d+(\.\d+)?$", execution_result):
                        return execution_result
                    else:
                        return f"Code executed with result: {execution_result}"

            # If we couldn't find stdout, look for "Result:" or similar indicators
            result_match = re.search(r"(?:Result|Output|Answer):\s*(.*?)(?:\n\n|$)", result_content, re.DOTALL)
            if result_match:
                return result_match.group(1).strip()

    elif worker_type == "researcher":
        # For researcher outputs, keep the full detailed response but, when it
        # is very long, prefer an explicit summary/conclusion section
        if len(result_content) > 800:
            summary_match = re.search(r"(?:Summary|Conclusion|To summarize|In summary):(.*?)(?:\n\n|$)",
                                      result_content, re.IGNORECASE | re.DOTALL)
            if summary_match:
                return summary_match.group(1).strip()

    # If no special handling was triggered, return the content as is
    return result_content
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
# --- Graph assembly -------------------------------------------------------
|
| 476 |
+
|
| 477 |
+
def create_agent_supervisor_graph() -> StateGraph:
    """Create the agent supervisor graph with all nodes and edges.

    Builds (but does not compile) the StateGraph: control nodes
    (planner, supervisor, critic, final_answer) plus one node per entry
    in WORKERS, wired so every worker path funnels back through the
    supervisor.

    Returns:
        The StateGraph builder (uncompiled), so the caller can attach a
        checkpointer at compile time.
    """
    # Initialize the graph with our State type
    builder = StateGraph(State)

    # Add control nodes
    builder.add_node("planner", planner_node)
    builder.add_node("supervisor", supervisor_node)
    builder.add_node("critic", critic_node)
    builder.add_node("final_answer", final_answer_node)

    # Add worker nodes dynamically based on WORKERS list
    for worker_type in WORKERS:
        builder.add_node(worker_type, create_worker_node(worker_type))

    # Define the workflow
    # NOTE(review): several nodes also route dynamically via Command(goto=...)
    # (e.g. critic returns goto="supervisor" or "final_answer"); the static
    # edges below duplicate that routing — confirm the intended LangGraph
    # semantics when both mechanisms are present.
    builder.add_edge(START, "supervisor")
    builder.add_edge("planner", "supervisor")
    builder.add_edge("critic", "supervisor")
    builder.add_edge("critic", "final_answer")  # Add edge from critic to final_answer
    builder.add_edge("final_answer", END)  # Final answer node goes to END
    builder.add_edge("supervisor", END)  # Allow the supervisor to end the workflow

    # Connect all workers to supervisor
    for worker_type in WORKERS:
        builder.add_edge(worker_type, "supervisor")

    # Return the builder, not a compiled graph
    # This allows the caller to compile with a checkpointer
    return builder
|
| 511 |
+
|
| 512 |
+
|
| 513 |
+
# --- Graph instantiation (with flexible checkpointing) -----------------------------
|
| 514 |
+
|
| 515 |
+
def get_compiled_graph(checkpointer=None):
    """Get a compiled graph with optional checkpointer.

    Compiles the supervisor graph and applies the recursion/iteration
    limits from the runtime Configuration.

    Args:
        checkpointer: Optional checkpointer for persistence

    Returns:
        Compiled StateGraph ready for execution
    """
    # Get configuration (supplies recursion/iteration limits below).
    configuration = Configuration.from_context()

    builder = create_agent_supervisor_graph()

    # Compile the graph. The two branches avoid passing an explicit
    # checkpointer=None, which not all langgraph versions accept.
    # (The previously defined should_end/count_steps helpers were dead code —
    # never attached to the graph — and have been removed.)
    if checkpointer:
        graph = builder.compile(
            checkpointer=checkpointer,
            name="Structured Reasoning Loop"
        )
    else:
        graph = builder.compile(
            name="Structured Reasoning Loop"
        )

    # Configure the graph with recursion limit and max iterations to bound
    # runaway supervisor/worker loops.
    graph = graph.with_config({
        "recursion_limit": configuration.recursion_limit,
        "max_iterations": configuration.max_iterations
    })

    return graph
|
| 571 |
+
|
| 572 |
+
|
| 573 |
+
# Initialize a default non-checkpointed graph (for backward compatibility).
# NOTE(review): this runs at import time and calls Configuration.from_context()
# inside get_compiled_graph() — presumably the configuration context is always
# available at import; verify if module import starts failing.
graph = get_compiled_graph()
|
prompts.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""System prompts used by the agent supervisor and worker agents."""
|
| 2 |
+
|
| 3 |
+
from react_agent.state import WORKERS, VERDICTS
|
| 4 |
+
|
| 5 |
+
# --- Supervisor prompt -----------------------------------------------------
|
| 6 |
+
|
| 7 |
+
SUPERVISOR_PROMPT = """You are a supervisor tasked with managing a conversation between the \
|
| 8 |
+
following workers: {workers}. Given the following user request, \
|
| 9 |
+
respond with the worker to act next. Each worker will perform a \
|
| 10 |
+
task and respond with their results and status. When finished, \
|
| 11 |
+
respond with FINISH.
|
| 12 |
+
|
| 13 |
+
System time: {system_time}"""
|
| 14 |
+
|
| 15 |
+
# --- Planner prompt -------------------------------------------------------
|
| 16 |
+
|
| 17 |
+
PLANNER_PROMPT = """**Role**: You are a Planner node in a LangGraph supervisor workflow
|
| 18 |
+
**Goal**: Given the user's original request, create a concise, focused plan that directly answers the question.
|
| 19 |
+
|
| 20 |
+
Requirements:
|
| 21 |
+
1. Output only a JSON object with one key `steps`, whose value is an **ordered list** of at least 1 and at most 3 objects.
|
| 22 |
+
Each object has:
|
| 23 |
+
• `worker` – one of: {worker_options}
|
| 24 |
+
• `instruction` – ≤ 20 words telling that worker what to do
|
| 25 |
+
|
| 26 |
+
2. Your plan MUST:
|
| 27 |
+
• Directly address the user's specific question
|
| 28 |
+
• Include at least one step (never return empty steps)
|
| 29 |
+
• Be focused on finding the exact answer requested, not the process of answering
|
| 30 |
+
• Use researcher for information gathering
|
| 31 |
+
• Use coder for calculations or data analysis if needed
|
| 32 |
+
|
| 33 |
+
3. Common tasks:
|
| 34 |
+
• For factual questions: use researcher to find the specific fact
|
| 35 |
+
• For calculations: use researcher to find data, then coder to calculate
|
| 36 |
+
• For multiple-part questions: break into steps with the right workers
|
| 37 |
+
• Ensure your last step gets the exact answer in the format requested
|
| 38 |
+
|
| 39 |
+
Example:
|
| 40 |
+
```
|
| 41 |
+
{{
|
| 42 |
+
"steps": [
|
| 43 |
+
{{"worker": "{example_worker_1}", "instruction": "Find inflation rate in 2023"}},
|
| 44 |
+
{{"worker": "{example_worker_2}", "instruction": "Compute average of 2019–2023 rates"}}
|
| 45 |
+
]
|
| 46 |
+
}}
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
System time: {system_time}"""
|
| 50 |
+
|
| 51 |
+
# --- Critic prompt --------------------------------------------------------
|
| 52 |
+
|
| 53 |
+
CRITIC_PROMPT = """**Role**: You are a Critic node specializing in GAIA benchmark format validation
|
| 54 |
+
**Goal**: Strictly check if the answer follows GAIA format requirements
|
| 55 |
+
|
| 56 |
+
Requirements:
|
| 57 |
+
1. You will check if the answer:
|
| 58 |
+
• Addresses all parts of the user's question correctly
|
| 59 |
+
• Follows the EXACT required GAIA format: "FINAL ANSWER: [concise response]"
|
| 60 |
+
• Contains ONLY the essential information in the [concise response]:
|
| 61 |
+
- A single number (no commas, no units like $ or % unless specified)
|
| 62 |
+
- A single word or very short phrase
|
| 63 |
+
- A comma-separated list of numbers or strings
|
| 64 |
+
• Has NO explanations, reasoning, or extra text
|
| 65 |
+
• For strings: no articles or abbreviations
|
| 66 |
+
• For numbers: digits only without commas
|
| 67 |
+
|
| 68 |
+
2. If the answer is CORRECT, respond ONLY with this exact JSON:
|
| 69 |
+
• `{{"verdict":"{correct_verdict}"}}`
|
| 70 |
+
|
| 71 |
+
3. If ANY requirement is NOT MET, respond with this JSON including a SPECIFIC reason:
|
| 72 |
+
• `{{"verdict":"{retry_verdict}","reason":"<specific format issue>"}}`
|
| 73 |
+
• IMPORTANT: You MUST provide a substantive reason that clearly explains what's wrong
|
| 74 |
+
• NEVER leave the reason empty or only containing quotes
|
| 75 |
+
|
| 76 |
+
4. Common reason examples:
|
| 77 |
+
• "Answer not formatted as 'FINAL ANSWER: [response]'"
|
| 78 |
+
• "Answer contains explanations instead of just the concise response"
|
| 79 |
+
• "Answer does not address the question about [specific topic]"
|
| 80 |
+
• "Answer contains units when it should just be a number"
|
| 81 |
+
|
| 82 |
+
DO NOT include any text before or after the JSON. Your complete response must be valid JSON that can be parsed.
|
| 83 |
+
|
| 84 |
+
System time: {system_time}"""
|
| 85 |
+
|
| 86 |
+
# --- Critic user prompt ---------------------------------------------------
|
| 87 |
+
|
| 88 |
+
CRITIC_USER_PROMPT = """Original question: {question}
|
| 89 |
+
|
| 90 |
+
Draft answer: {answer}
|
| 91 |
+
|
| 92 |
+
Check if the draft answer follows GAIA format requirements:
|
| 93 |
+
1. Format must be exactly "FINAL ANSWER: [concise response]"
|
| 94 |
+
2. [concise response] must ONLY be:
|
| 95 |
+
- A single number (no commas or units unless specified)
|
| 96 |
+
- A single word or very short phrase
|
| 97 |
+
- A comma-separated list of numbers or strings
|
| 98 |
+
3. NO explanations or additional text is allowed
|
| 99 |
+
4. Strings should not have articles or abbreviations
|
| 100 |
+
5. Numbers should be in digits without commas
|
| 101 |
+
|
| 102 |
+
Does the answer meet these requirements and correctly answer the question?"""
|
| 103 |
+
|
| 104 |
+
# --- Final Answer format for GAIA benchmark -------------------------------
|
| 105 |
+
|
| 106 |
+
FINAL_ANSWER_PROMPT = """You are a response formatter for a GAIA benchmark question.
|
| 107 |
+
|
| 108 |
+
Your only job is to format the final answer in the exact format required: "FINAL ANSWER: [concise response]"
|
| 109 |
+
|
| 110 |
+
Requirements for [concise response]:
|
| 111 |
+
1. Response must ONLY be one of these formats:
|
| 112 |
+
- A single number (no commas, no units like $ or % unless specified)
|
| 113 |
+
- A single word or very short phrase
|
| 114 |
+
- A comma-separated list of numbers or strings
|
| 115 |
+
2. DO NOT include any explanations, reasoning, or extra text
|
| 116 |
+
3. For strings, don't use articles or abbreviations unless specified
|
| 117 |
+
4. For numbers, write digits (not spelled out) without commas
|
| 118 |
+
5. The response should be as concise as possible while being correct
|
| 119 |
+
|
| 120 |
+
Original question: {question}
|
| 121 |
+
|
| 122 |
+
Information available:
|
| 123 |
+
{context}
|
| 124 |
+
|
| 125 |
+
After reviewing the information, extract just the essential answer and output ONLY:
|
| 126 |
+
FINAL ANSWER: [your concise response]
|
| 127 |
+
"""
|
| 128 |
+
|
| 129 |
+
# --- Final Answer user prompt ---------------------------------------------
|
| 130 |
+
|
| 131 |
+
FINAL_ANSWER_USER_PROMPT = """Original question: {question}
|
| 132 |
+
|
| 133 |
+
Information available:
|
| 134 |
+
{context}
|
| 135 |
+
|
| 136 |
+
Remember to output ONLY 'FINAL ANSWER: [your concise response]' with no explanations."""
|
| 137 |
+
|
| 138 |
+
# --- Worker agent prompts -------------------------------------------------
|
| 139 |
+
|
| 140 |
+
RESEARCHER_PROMPT = """You are a research specialist focused on finding information and providing context.
|
| 141 |
+
|
| 142 |
+
Your key responsibilities:
|
| 143 |
+
1. Search for accurate, up-to-date information on any topic
|
| 144 |
+
2. Provide factual knowledge about products, concepts, and terminology
|
| 145 |
+
3. Explain real-world contexts and background information
|
| 146 |
+
4. Identify relevant parameters and variables needed for calculations
|
| 147 |
+
5. Present information clearly with proper citations
|
| 148 |
+
|
| 149 |
+
DO NOT perform complex calculations or coding tasks - these will be handled by the coder agent.
|
| 150 |
+
You MAY provide simple arithmetic or basic formulas to illustrate concepts.
|
| 151 |
+
|
| 152 |
+
Always return information in a structured, organized format that will be useful for the next steps.
|
| 153 |
+
|
| 154 |
+
System time: {system_time}
|
| 155 |
+
"""
|
| 156 |
+
|
| 157 |
+
CODER_PROMPT = """You are a computational specialist focused on calculations, coding, and data analysis.
|
| 158 |
+
|
| 159 |
+
Your key responsibilities:
|
| 160 |
+
1. Write and execute Python code for calculations and data manipulation
|
| 161 |
+
2. Perform precise numerical analyses based on inputs from the researcher
|
| 162 |
+
3. Format results clearly with appropriate units and precision
|
| 163 |
+
4. Use markdown to structure your response with headings and bullet points
|
| 164 |
+
5. Verify calculations through multiple methods when possible
|
| 165 |
+
|
| 166 |
+
Important:
|
| 167 |
+
1. Always include both your calculation process AND final result values
|
| 168 |
+
2. Always clearly state your assumptions when making calculations
|
| 169 |
+
3. Format numerical results with appropriate precision and units
|
| 170 |
+
4. When receiving data from the researcher, acknowledge and build upon it directly
|
| 171 |
+
5. If calculation involves multiple steps or cases, organize them with headings
|
| 172 |
+
|
| 173 |
+
System time: {system_time}
|
| 174 |
+
"""
|
| 175 |
+
|
| 176 |
+
# --- Legacy system prompt (kept for backward compatibility) ---------------
|
| 177 |
+
|
| 178 |
+
SYSTEM_PROMPT = """You are a helpful AI assistant.
|
| 179 |
+
|
| 180 |
+
System time: {system_time}"""
|
state.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Define the state structures for the agent supervisor."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Dict, List, Literal, Optional, Sequence, Any
|
| 6 |
+
|
| 7 |
+
from langchain_core.messages import AnyMessage
|
| 8 |
+
from langgraph.graph import MessagesState, add_messages
|
| 9 |
+
from typing_extensions import TypedDict, Annotated
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# --- Constants and shared definitions ---------------------------------------

# Define worker types (specialized agents that perform tasks)
WORKERS = ["researcher", "coder"]

# Define all member types (including control nodes)
MEMBERS = WORKERS + ["planner", "critic", "supervisor"]

# Define status/routing options.
# NOTE: other modules index these lists positionally (VERDICTS[0] == CORRECT,
# VERDICTS[1] == RETRY in supervisor_node.py), so the order is part of the
# contract -- do not reorder.
VERDICTS = ["CORRECT", "RETRY"]
ROUTING = ["FINISH"] + WORKERS
OPTIONS = ROUTING + VERDICTS
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# --- Router for supervisor decisions ---------------------------------------

class Router(TypedDict):
    """Determines which worker to route to next or if the task is complete.

    The supervisor returns this structure to navigate the workflow.
    Valid values are defined in the ROUTING list.
    """
    # NOTE(review): unpacking a runtime list into Literal requires Python
    # 3.11+ and is opaque to most static type checkers -- confirm the
    # project's minimum Python version supports this.
    next: Literal[*ROUTING]
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# --- Plan structure for the Planner node -----------------------------------

class PlanStep(TypedDict):
    """A single step in the plan created by the Planner."""
    # Which specialized agent executes this step (one of WORKERS).
    worker: Literal[*WORKERS]
    # Natural-language task description handed to that worker.
    instruction: str
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class Plan(TypedDict):
    """The complete plan produced by the Planner node."""
    # Ordered list of steps; the supervisor walks them via
    # State.current_step_index.
    steps: List[PlanStep]
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
# --- Critic verdict structure ----------------------------------------------

class CriticVerdict(TypedDict):
    """The verdict from the Critic on whether the answer is satisfactory."""
    # One of VERDICTS: "CORRECT" approves the draft, "RETRY" triggers
    # replanning in the supervisor.
    verdict: Literal[*VERDICTS]
    # Human-readable explanation, mainly meaningful for RETRY verdicts.
    reason: Optional[str]
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# --- State for the agent supervisor ----------------------------------------

class State(MessagesState):
    """State for the agent supervisor workflow.

    Extends MessagesState which provides message history tracking.
    Adds fields to track routing information, plan, and critic verdict.
    """
    # Name of the next node to route to (kept for backward compatibility).
    next: str
    # NOTE(review): MessagesState is TypedDict-based, so these class-level
    # "= ..." values are NOT per-instance defaults at runtime; readers in
    # supervisor_node.py consistently use state.get(key, fallback) instead.
    # The mutable {} literals below are therefore never shared between
    # states, but confirm before relying on any default declared here.
    plan: Optional[Plan] = None
    current_step_index: Optional[int] = None
    draft_answer: Optional[str] = None
    critic_verdict: Optional[CriticVerdict] = None
    context: Dict[str, Any] = {}  # Shared context accessible to all agents
    worker_results: Dict[str, List[str]] = {}  # Store results from each worker
|
supervisor_node.py
ADDED
|
@@ -0,0 +1,406 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Supervisor node implementation for the agent supervisor system."""
|
| 2 |
+
|
| 3 |
+
from typing import Dict, List, Literal, Optional, Union, Type, cast
|
| 4 |
+
|
| 5 |
+
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
|
| 6 |
+
from langchain_core.prompts import ChatPromptTemplate
|
| 7 |
+
from langgraph.graph import StateGraph, START, END
|
| 8 |
+
from langgraph.types import Command
|
| 9 |
+
|
| 10 |
+
from react_agent.configuration import Configuration
|
| 11 |
+
from react_agent.state import WORKERS, MEMBERS, ROUTING, VERDICTS, State, Router
|
| 12 |
+
from react_agent.utils import load_chat_model, format_system_prompt, get_message_text
|
| 13 |
+
from react_agent import prompts
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# Compile-time type definitions

# Every node the supervisor may hand control to, including the special
# "__end__" terminal node used by LangGraph.
SupervisorDestinations = Literal["planner", "critic", "researcher", "coder", "final_answer", "__end__"]
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _best_answer_command(
    state: State, notice: str, steps_taken: Optional[int] = None
) -> Command[SupervisorDestinations]:
    """Route to final_answer with the best answer salvageable from context.

    Used when the step budget or retry budget is exhausted.

    Args:
        state: Current workflow state (its "context" is mined for an answer).
        notice: Supervisor message explaining why we are terminating.
        steps_taken: When given, written back into the state update.

    Returns:
        A Command targeting the final_answer node.
    """
    answer = extract_best_answer_from_context(state.get("context", {}))
    update = {
        "messages": [HumanMessage(content=notice, name="supervisor")],
        "draft_answer": f"FINAL ANSWER: {answer}",
        "retry_exhausted": True,  # Flag to indicate we've exhausted retries
    }
    if steps_taken is not None:
        update["steps_taken"] = steps_taken
    return Command(goto="final_answer", update=update)


def _relevant_context_note(state: State, worker: str) -> str:
    """Build a short instruction suffix describing context relevant to `worker`.

    Returns "" when nothing relevant exists; otherwise a
    "\\n\\nRelevant context: ..." suffix summarizing prior worker output.
    """
    if not state.get("context"):
        return ""

    relevant_context = {}

    # For the coder, researcher findings carry parameters and factual data.
    if worker == "coder" and "researcher" in state["context"]:
        relevant_context["researcher"] = state["context"]["researcher"]

    # For the researcher, only short coder results (likely bare numbers) are
    # worth forwarding -- not full code snippets.
    if worker == "researcher" and "coder" in state["context"]:
        coder_content = state["context"]["coder"]
        if len(coder_content) < 100:
            relevant_context["coder"] = coder_content

    context_items = []
    for key, value in relevant_context.items():
        if len(value) > 200:
            # Trim to the first sentence within the first 200 characters.
            summary = value[:200]
            if '.' in summary:
                summary = summary.split('.')[0] + '.'
            context_items.append(f"Previous {key} found: {summary}...")
        else:
            context_items.append(f"Previous {key} found: {value}")

    if not context_items:
        return ""
    return "\n\nRelevant context: " + "\n".join(context_items)


def supervisor_node(state: State) -> Command[SupervisorDestinations]:
    """Supervising LLM that decides which specialized agent should act next.

    Routing logic, in order:
      1. Force a final answer once the step budget or retry budget is spent.
      2. Ask the planner for a plan when none exists (or the plan is empty).
      3. Act on a critic verdict: finish on CORRECT, replan (or fast-path to
         final_answer for pure formatting issues) on RETRY.
      4. Otherwise dispatch the current plan step to its worker, or send the
         compiled draft to the critic when all steps are done.

    Args:
        state: The current state with messages

    Returns:
        Command with routing information
    """
    # Get configuration to use supervisor_model
    configuration = Configuration.from_context()

    # Track steps to prevent infinite loops
    steps_taken = state.get("steps_taken", 0) + 1
    state_updates = {"steps_taken": steps_taken}

    max_retries = 2  # Maximum number of allowed retries

    # Check if we've hit our step limit (buffer of 5 below the hard limit).
    if steps_taken >= configuration.recursion_limit - 5:
        return _best_answer_command(
            state,
            f"Maximum steps ({steps_taken}) reached. Extracting best answer from available information.",
            steps_taken=steps_taken,
        )

    # Safety check - prevent infinite loops after too many retry cycles.
    if state.get("retry_count", 0) > max_retries:
        return _best_answer_command(
            state,
            f"Maximum retries ({max_retries}) reached. Extracting best answer from available information.",
            steps_taken=steps_taken,
        )

    # Check if we need a plan
    if not state.get("plan"):
        return Command(goto="planner", update={**state_updates})

    # Validate that the plan has at least one step
    plan = state.get("plan")
    if not plan.get("steps"):
        # Plan has no steps, go back to planner with explicit instructions
        return Command(
            goto="planner",
            update={
                "messages": [
                    HumanMessage(
                        content="Previous plan had 0 steps. Please create a plan with at least 1 step to solve the user's question.",
                        name="supervisor",
                    )
                ],
                "plan": None,
                **state_updates,
            },
        )

    # Check if we have a critic verdict that requires action
    critic_verdict = state.get("critic_verdict")
    if critic_verdict:
        if critic_verdict.get("verdict") == VERDICTS[0]:  # CORRECT
            # Final answer is approved; generate a polished response before
            # ending.
            return Command(
                goto="final_answer",
                update={
                    "messages": [
                        HumanMessage(
                            content="Answer approved by critic. Generating final response.",
                            name="supervisor",
                        )
                    ]
                },
            )
        elif critic_verdict.get("verdict") == VERDICTS[1]:  # RETRY
            # Read the retry count BEFORE incrementing it.
            current_retry_count = state.get("retry_count", 0)

            # At the retry ceiling: proceed with the best available answer.
            if current_retry_count >= max_retries:
                return _best_answer_command(
                    state,
                    f"Maximum retries ({max_retries}) reached. Proceeding with best available answer.",
                )

            # Keep context and worker results across the replanning cycle.
            context = state.get("context", {})
            worker_results = state.get("worker_results", {})

            # Get the critic's reason for rejection, if any.
            # BUGFIX: the old check compared against a literal double quote
            # ("\"") instead of detecting an effectively-empty reason; treat
            # missing, whitespace-only, or bare-quote reasons as absent.
            reason = critic_verdict.get("reason", "")
            if not reason or reason.strip() in ("", '"'):
                reason = "Answer did not meet format requirements"

            # Check if this is purely a formatting issue.
            # BUGFIX: compare case-insensitively on both sides; previously
            # only `reason` was lowercased, so "FINAL ANSWER" never matched.
            format_issues = [
                "format", "concise", "explanation", "not formatted",
                "instead of just", "contains explanations", "FINAL ANSWER",
            ]
            reason_lower = reason.lower()
            is_format_issue = any(issue.lower() in reason_lower for issue in format_issues)

            # If we already have enough information but the format is wrong,
            # skip replanning and go straight to the final answer.
            # (The old `current_retry_count >= 0` guard was a tautology and
            # has been dropped -- behavior is unchanged.)
            if is_format_issue and has_sufficient_information(state):
                return Command(
                    goto="final_answer",
                    update={
                        "messages": [
                            HumanMessage(
                                content="We have sufficient information but formatting issues. Generating properly formatted answer.",
                                name="supervisor",
                            )
                        ],
                        # Still increment retry count.
                        "retry_count": current_retry_count + 1,
                    },
                )

            # Replan: reset plan-related fields but keep accumulated context.
            next_retry_count = current_retry_count + 1
            return Command(
                goto="planner",
                update={
                    "plan": None,
                    "current_step_index": None,
                    "draft_answer": None,
                    "critic_verdict": None,
                    # Keep the context and worker_results across the retry.
                    "context": context,
                    "worker_results": worker_results,
                    # Store the INCREMENTED retry count.
                    "retry_count": next_retry_count,
                    "messages": [
                        HumanMessage(
                            content=f"Retrying with new plan (retry #{next_retry_count}). Reason: {reason}",
                            name="supervisor",
                        )
                    ],
                },
            )

    # Get the current step from the plan
    plan = state["plan"]
    current_step_index = state.get("current_step_index", 0)

    # Check if we've completed all steps
    if current_step_index >= len(plan["steps"]):
        context = state.get("context", {})

        # Combine the most recent worker outputs as the draft answer.
        worker_sections = [
            f"**{worker.title()}**: {context[worker]}"
            for worker in WORKERS
            if worker in context
        ]
        draft_content = "\n\n".join(worker_sections)

        # Send to the critic for evaluation
        return Command(
            goto="critic",
            update={
                "draft_answer": draft_content,
                "messages": [
                    HumanMessage(
                        content="All steps completed. Evaluating the answer.",
                        name="supervisor",
                    )
                ],
            },
        )

    # Dispatch the current step to its worker.
    current_step = plan["steps"][current_step_index]
    worker = current_step["worker"]
    instruction = current_step["instruction"]

    # Enhance the instruction with relevant prior context, if any.
    enhanced_instruction = instruction + _relevant_context_note(state, worker)

    # Add guidance based on worker type.
    if worker == "coder":
        enhanced_instruction += "\nProvide both your calculation method AND the final result value."
    elif worker == "researcher":
        enhanced_instruction += "\nFocus on gathering factual information related to the task."

    messages_update = [
        HumanMessage(
            content=f"Step {current_step_index + 1}: {enhanced_instruction}",
            name="supervisor",
        )
    ]

    # Cast worker to the destination literal type to satisfy type checking.
    worker_destination = cast(SupervisorDestinations, worker)

    # Move to the appropriate worker.
    return Command(
        goto=worker_destination,
        update={
            "messages": messages_update,
            "next": worker,  # For backward compatibility
            **state_updates,
        },
    )
|
| 309 |
+
|
| 310 |
+
def extract_best_answer_from_context(context):
|
| 311 |
+
"""Extract the best available answer from context.
|
| 312 |
+
|
| 313 |
+
This is a generic function to extract answers from any type of question context.
|
| 314 |
+
It progressively tries different strategies to find a suitable answer.
|
| 315 |
+
|
| 316 |
+
Args:
|
| 317 |
+
context: The state context containing worker outputs
|
| 318 |
+
|
| 319 |
+
Returns:
|
| 320 |
+
Best answer found or "unknown" if nothing suitable is found
|
| 321 |
+
"""
|
| 322 |
+
answer = "unknown"
|
| 323 |
+
|
| 324 |
+
# First check if the coder already provided a properly formatted answer
|
| 325 |
+
if "coder" in context:
|
| 326 |
+
coder_content = context["coder"]
|
| 327 |
+
|
| 328 |
+
# Look for "FINAL ANSWER: X" pattern in the coder output
|
| 329 |
+
import re
|
| 330 |
+
answer_match = re.search(r"FINAL ANSWER:\s*(.*?)(?:\n|$)", coder_content, re.IGNORECASE)
|
| 331 |
+
if answer_match:
|
| 332 |
+
return answer_match.group(1).strip()
|
| 333 |
+
|
| 334 |
+
# If no answer in coder output, check researcher content
|
| 335 |
+
if "researcher" in context:
|
| 336 |
+
researcher_content = context["researcher"]
|
| 337 |
+
|
| 338 |
+
# Look for lists in the researcher content (common pattern)
|
| 339 |
+
import re
|
| 340 |
+
|
| 341 |
+
# Look for bulleted list items
|
| 342 |
+
list_items = re.findall(r"[-•*]\s+([^:\n]+)", researcher_content)
|
| 343 |
+
if list_items:
|
| 344 |
+
# Format as comma-separated list
|
| 345 |
+
answer = ",".join(item.strip() for item in list_items)
|
| 346 |
+
return answer
|
| 347 |
+
|
| 348 |
+
# Look for emphasized/bold items which might be key information
|
| 349 |
+
bold_items = re.findall(r"\*\*([^*]+)\*\*", researcher_content)
|
| 350 |
+
if bold_items:
|
| 351 |
+
# Join the important items as a comma-separated list
|
| 352 |
+
processed_items = []
|
| 353 |
+
for item in bold_items:
|
| 354 |
+
# Remove common filler words and clean up the item
|
| 355 |
+
clean_item = re.sub(r'(^|\s)(a|an|the|is|are|was|were|be|been)(\s|$)', ' ', item)
|
| 356 |
+
clean_item = clean_item.strip()
|
| 357 |
+
if clean_item and len(clean_item) < 30: # Only include reasonably short items
|
| 358 |
+
processed_items.append(clean_item)
|
| 359 |
+
|
| 360 |
+
if processed_items:
|
| 361 |
+
answer = ",".join(processed_items)
|
| 362 |
+
return answer
|
| 363 |
+
|
| 364 |
+
# If we still don't have an answer, try to extract common entities
|
| 365 |
+
combined_content = ""
|
| 366 |
+
for worker_type, content in context.items():
|
| 367 |
+
combined_content += " " + content
|
| 368 |
+
|
| 369 |
+
# Look for numbers in the content
|
| 370 |
+
import re
|
| 371 |
+
numbers = re.findall(r'\b(\d+(?:\.\d+)?)\b', combined_content)
|
| 372 |
+
if numbers:
|
| 373 |
+
answer = numbers[0] # Use the first number found
|
| 374 |
+
|
| 375 |
+
return answer
|
| 376 |
+
|
| 377 |
+
def has_sufficient_information(state):
    """Heuristically decide whether enough worker output exists to answer.

    Args:
        state: The current conversation state (read-only).

    Returns:
        Boolean indicating if we have sufficient information.
    """
    context = state.get("context", {})

    # Both specialists have reported -> almost certainly enough material.
    if "researcher" in context and "coder" in context:
        return True

    # A substantial researcher write-up alone may also suffice.
    if "researcher" in context and len(context["researcher"]) > 150:
        return True

    # Any output containing list/definition markers suggests structured data:
    # "- " / "•" (bullets), "*" (emphasis or bullet), ":" (definition).
    structure_markers = ("- ", "•", "*", ":")
    for content in context.values():
        if content and any(marker in content for marker in structure_markers):
            return True

    return False
|
tools.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""This module provides tools for the agent supervisor.
|
| 2 |
+
|
| 3 |
+
It includes:
|
| 4 |
+
- Web Search: For general web results using Tavily.
|
| 5 |
+
- Python REPL: For executing Python code (Use with caution!).
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from typing import Annotated, List, Any, Callable, Optional, cast
|
| 9 |
+
|
| 10 |
+
# Core Tools & Utilities
|
| 11 |
+
from langchain_core.tools import tool
|
| 12 |
+
|
| 13 |
+
# Experimental Tools (Use with caution)
|
| 14 |
+
from langchain_experimental.utilities import PythonREPL
|
| 15 |
+
|
| 16 |
+
# Use TavilySearchResults from langchain_community like in the notebook
|
| 17 |
+
from langchain_community.tools.tavily_search import TavilySearchResults
|
| 18 |
+
from react_agent.configuration import Configuration
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# Create Tavily tool using configuration from context (more consistent approach)
def create_tavily_tool():
    """Create the Tavily search tool with configuration from context.

    Reads ``max_search_results`` from the active Configuration.

    Returns:
        Configured TavilySearchResults tool
    """
    configuration = Configuration.from_context()
    return TavilySearchResults(max_results=configuration.max_search_results)

# Initialize the tool
# NOTE(review): instantiated at import time, so Configuration.from_context()
# must be usable when this module is first imported -- confirm.
tavily_tool = create_tavily_tool()
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# --- Python REPL Tool ---
# WARNING: Executes arbitrary Python code locally. Be extremely careful
# about exposing this tool, especially in production environments.
# Shared REPL instance used by python_repl_tool below.
repl = PythonREPL()
|
| 39 |
+
|
| 40 |
+
@tool
def python_repl_tool(
    code: Annotated[str, "The python code to execute. Use print(...) to see output."],
):
    """Execute Python code in the shared REPL and report the result.

    If you want to see the output of a value, you should print it out with
    `print(...)`. This is visible to the user.

    Args:
        code: The Python source to run.

    Returns:
        A markdown-formatted summary of the executed code and its stdout,
        or an error description if execution failed.
    """
    try:
        result = repl.run(code)
    except BaseException as e:  # surface ANY failure back to the agent
        return f"Failed to execute. Error: {repr(e)}"
    # Filter out potentially sensitive REPL implementation details.
    # BUGFIX: the previous f-string used "\`" -- an invalid escape sequence
    # that leaked literal backslashes into the markdown code fence; plain
    # ``` is the intended fence.
    result_str = f"Successfully executed:\n```python\n{code}\n```\nStdout: {result}"
    return result_str
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# --- Tool List ---

# The list of tools available to the agent supervisor.
# Order is not significant; both tools are bound to worker agents elsewhere.
TOOLS: List[Callable[..., Any]] = [tavily_tool, python_repl_tool]
|
utils.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Utility & helper functions."""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
from dotenv import load_dotenv
|
| 5 |
+
from langchain.chat_models import init_chat_model
|
| 6 |
+
from langchain_core.language_models import BaseChatModel
|
| 7 |
+
from langchain_core.messages import BaseMessage
|
| 8 |
+
import asyncio
|
| 9 |
+
from datetime import UTC, datetime
|
| 10 |
+
from react_agent.state import WORKERS, MEMBERS, ROUTING, VERDICTS
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# Load environment variables from .env file
|
| 14 |
+
load_dotenv()
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def get_message_text(msg: BaseMessage) -> str:
    """Get the text content of a message.

    Handles plain string content, dict content with a ``"text"`` key, and
    list content mixing strings and dicts (multi-part message format).
    """
    content = msg.content
    if isinstance(content, str):
        return content
    if isinstance(content, dict):
        return content.get("text", "")
    # Multi-part content: concatenate the textual pieces.
    parts = []
    for piece in content:
        if isinstance(piece, str):
            parts.append(piece)
        else:
            parts.append(piece.get("text") or "")
    return "".join(parts).strip()
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def format_system_prompt(prompt_template: str) -> str:
    """Format a system prompt template with current system time and available agents.

    Supported placeholders: ``system_time``, ``workers``, ``members``,
    ``worker_options``, ``example_worker_1``, ``example_worker_2``,
    ``correct_verdict``, ``retry_verdict``.

    Args:
        prompt_template: The prompt template to format

    Returns:
        The formatted prompt with system time and agent information
    """
    substitutions = {
        "system_time": datetime.now(tz=UTC).isoformat(),
        "workers": ", ".join(WORKERS),
        "members": ", ".join(MEMBERS),
        "worker_options": ", ".join(f'"{w}"' for w in WORKERS),
        # Example workers/verdicts fall back to sane names if the shared
        # lists are ever shortened.
        "example_worker_1": WORKERS[0] if WORKERS else "researcher",
        "example_worker_2": WORKERS[1] if len(WORKERS) > 1 else "coder",
        "correct_verdict": VERDICTS[0] if VERDICTS else "CORRECT",
        "retry_verdict": VERDICTS[1] if len(VERDICTS) > 1 else "RETRY",
    }
    return prompt_template.format(**substitutions)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def load_chat_model(fully_specified_name: str) -> BaseChatModel:
    """Load a chat model from a fully specified name.

    Args:
        fully_specified_name (str): String in the format 'provider/model'.

    Returns:
        An initialized chat model instance.

    Raises:
        ValueError: If the provider is 'google_genai' but GOOGLE_API_KEY
            is not set in the environment.
    """
    # Split only on the first '/': model names may themselves contain '/'.
    provider, model = fully_specified_name.split("/", maxsplit=1)

    # Special handling for Google Genai models to ensure they're configured for async
    if provider == "google_genai":
        # Imported lazily so the dependency is only needed for this provider.
        from langchain_google_genai import ChatGoogleGenerativeAI

        # Make sure we have the API key
        if not os.environ.get("GOOGLE_API_KEY"):
            raise ValueError("GOOGLE_API_KEY environment variable is required for google_genai models")

        return ChatGoogleGenerativeAI(model=model)
    else:
        return init_chat_model(model, model_provider=provider)
|