AgenticAI-RAG / src /agents /aggregator_agent.py
GreymanT's picture
Upload 80 files
8bf4d58 verified
"""Aggregator agent that coordinates multiple specialized agents."""
import asyncio
import logging
from typing import List, Dict, Any, Optional

from openai import OpenAI

from src.agents.base_agent import BaseAgent
from src.agents.local_data_agent import LocalDataAgent
from src.agents.search_agent import SearchAgent
from src.agents.cloud_agent import CloudAgent
from src.core.config import get_settings
# Module-level logger; used below to report failures from delegated agents
# and from the response-synthesis step.
logger = logging.getLogger(__name__)
class AggregatorAgent(BaseAgent):
    """Agent that coordinates multiple specialized agents and aggregates responses.

    A query is routed to one or more sub-agents by simple keyword matching
    (see ``_select_agents``). Their individual results are then either passed
    through unchanged (single success) or merged into one answer by an LLM
    call (several successes).
    """

    # Sampling temperature used for the synthesis completion.
    _SYNTHESIS_TEMPERATURE = 0.7

    # Routing table: (agent name, keywords that select it). A keyword matches
    # when it occurs anywhere in the lower-cased query (substring match, so
    # e.g. "data" also matches "database"). Order defines the order in which
    # agents are queried and reported.
    _ROUTING_KEYWORDS = (
        ("local", ("document", "file", "local", "data")),
        (
            "search",
            (
                "current", "latest", "recent", "news",
                "web", "internet", "online", "search",
            ),
        ),
        ("cloud", ("cloud", "s3", "gcs", "storage", "remote")),
        (
            "snowflake",
            (
                "snowflake", "data warehouse", "sql", "database",
                "query", "table", "schema",
            ),
        ),
    )

    def __init__(self, use_planning: bool = True):
        """Initialize the aggregator and the specialized agents it delegates to.

        Args:
            use_planning: Passed to ``BaseAgent.__init__`` for the aggregator
                itself. The sub-agents are created with their own fixed
                ``use_planning`` values.
        """
        super().__init__(
            name="aggregator_agent",
            description=(
                "You are an aggregator agent that coordinates multiple specialized agents "
                "to answer complex questions. You route queries to appropriate agents and "
                "synthesize their responses into a comprehensive answer."
            ),
            use_memory=True,
            use_planning=use_planning,
            planning_type="cot",
        )

        # Initialize specialized agents
        self.local_agent = LocalDataAgent(use_planning=False)
        self.search_agent = SearchAgent(use_planning=True)
        self.cloud_agent = CloudAgent(use_planning=False)

        # Initialize Snowflake agent only if the settings report it as configured
        self.snowflake_agent = None
        settings = get_settings()
        if settings.has_snowflake():
            # Imported lazily so the Snowflake agent module is only loaded
            # when Snowflake is actually configured.
            from src.agents.snowflake_agent import SnowflakeAgent

            self.snowflake_agent = SnowflakeAgent(
                snowflake_config=settings.get_snowflake_config(),
                use_planning=False,
            )

        # Name -> agent registry of every agent that was created.
        self.agents: Dict[str, BaseAgent] = {
            "local": self.local_agent,
            "search": self.search_agent,
            "cloud": self.cloud_agent,
        }
        if self.snowflake_agent is not None:
            self.agents["snowflake"] = self.snowflake_agent

    async def retrieve_context(self, query: str) -> str:
        """
        Retrieve context by querying the relevant agents concurrently.

        A failure in one agent does not abort the others; its section of the
        result contains ``"Error: <message>"`` instead.

        Args:
            query: User query

        Returns:
            Aggregated context string with one labelled section per agent
        """
        # Determine which agents to query based on query content
        agents_to_query = self._select_agents(query)

        # Query selected agents in parallel. gather() returns results in the
        # order the awaitables were passed, so they line up with the dict keys.
        contexts = await asyncio.gather(
            *(
                self._retrieve_from_agent(agent_name, agent, query)
                for agent_name, agent in agents_to_query.items()
            )
        )

        # Combine results
        context_parts = ["Context from specialized agents:"]
        for agent_name, context in zip(agents_to_query, contexts):
            context_parts.append(f"\n--- {agent_name.upper()} AGENT ---")
            context_parts.append(context)
        return "\n".join(context_parts)

    async def _retrieve_from_agent(
        self,
        agent_name: str,
        agent: BaseAgent,
        query: str,
    ) -> str:
        """Fetch context from one agent, converting any failure into an error string."""
        try:
            context = await agent.retrieve_context(query)
        except Exception as e:
            logger.exception("Error querying %s agent: %s", agent_name, e)
            return f"Error: {e}"
        # The combined context is built with str.join, which requires strings.
        return context if isinstance(context, str) else str(context)

    def _select_agents(self, query: str) -> Dict[str, BaseAgent]:
        """
        Select which agents to query based on the query content.

        Each agent is selected when any of its routing keywords occurs in the
        lower-cased query. The Snowflake agent is only eligible when it was
        created in ``__init__``. If nothing matches, the local and search
        agents are used.

        Args:
            query: User query

        Returns:
            Dictionary of agent names to agents (never empty)
        """
        query_lower = (query or "").lower()
        candidates = {
            "local": self.local_agent,
            "search": self.search_agent,
            "cloud": self.cloud_agent,
            "snowflake": self.snowflake_agent,
        }

        selected: Dict[str, BaseAgent] = {}
        for agent_name, keywords in self._ROUTING_KEYWORDS:
            agent = candidates[agent_name]
            if agent is None:
                # Optional agent that was not configured.
                continue
            if any(keyword in query_lower for keyword in keywords):
                selected[agent_name] = agent

        # If no specific match, use local and search by default
        if not selected:
            selected["local"] = self.local_agent
            selected["search"] = self.search_agent
        return selected

    async def process(
        self,
        query: str,
        session_id: Optional[str] = None,
        context: Optional[str] = None,
    ) -> dict:
        """
        Process query by coordinating multiple agents.

        Args:
            query: User query
            session_id: Optional session ID, forwarded to every selected agent
            context: Optional additional context, forwarded to every selected agent

        Returns:
            Aggregated response dictionary (see ``_synthesize_responses``)
        """
        # Select agents to query
        agents_to_query = self._select_agents(query)

        # Get responses from selected agents.
        # NOTE(review): kept sequential on purpose - every agent receives the
        # same session_id and may update per-session state; confirm that is
        # safe before running these concurrently.
        agent_responses: Dict[str, dict] = {}
        for agent_name, agent in agents_to_query.items():
            try:
                response = await agent.process(query, session_id, context)
            except Exception as e:
                logger.exception("Error processing with %s agent: %s", agent_name, e)
                response = {
                    "success": False,
                    "error": str(e),
                }
            if not isinstance(response, dict):
                # Synthesis relies on dict access; record anything else as a failure.
                response = {
                    "success": False,
                    "error": f"Unexpected response type: {type(response).__name__}",
                }
            agent_responses[agent_name] = response

        # Synthesize responses
        return await self._synthesize_responses(
            query=query,
            agent_responses=agent_responses,
            session_id=session_id,
        )

    async def _synthesize_responses(
        self,
        query: str,
        agent_responses: Dict[str, dict],
        session_id: Optional[str],
    ) -> dict:
        """
        Synthesize responses from multiple agents.

        Outcomes:
            * no successful response  -> ``success: False`` with the collected errors
            * one successful response -> a copy of it, tagged ``aggregated_by="single_agent"``
            * several                 -> LLM-written answer, tagged ``"multiple_agents"``;
              if the LLM call fails or returns nothing, a copy of the first
              successful response tagged ``"fallback"``

        The dictionaries in ``agent_responses`` are not modified.

        Args:
            query: Original query
            agent_responses: Dictionary of agent responses
            session_id: Optional session ID (currently unused here)

        Returns:
            Synthesized response
        """
        # Collect successful responses
        successful_responses = {
            name: resp
            for name, resp in agent_responses.items()
            if isinstance(resp, dict) and resp.get("success", False)
        }

        if not successful_responses:
            # Nothing usable: report every agent's error message.
            error_messages = []
            for name, resp in agent_responses.items():
                if isinstance(resp, dict):
                    error_msg = resp.get("error", "Unknown error")
                else:
                    error_msg = "Unknown error"
                error_messages.append(f"{name}: {error_msg}")
            return {
                "success": False,
                "error": f"No agents provided successful responses. Errors: {'; '.join(error_messages)}",
                "agent_responses": agent_responses,
            }

        first_response = next(iter(successful_responses.values()))

        # If only one agent responded, return its response
        if len(successful_responses) == 1:
            return {**first_response, "aggregated_by": "single_agent"}

        # Multiple responses - synthesize using LLM
        synthesized_answer = None
        try:
            synthesis_prompt = self._build_synthesis_prompt(query, successful_responses)
            synthesized_answer = await self._request_synthesis(synthesis_prompt)
        except Exception as e:
            logger.exception("Error synthesizing responses: %s", e)

        if not synthesized_answer:
            # Fallback: return first successful response
            return {**first_response, "aggregated_by": "fallback"}

        return {
            "success": True,
            "answer": synthesized_answer,
            "agent": self.name,
            "aggregated_by": "multiple_agents",
            "source_agents": list(successful_responses.keys()),
            "agent_responses": successful_responses,
            "model": self.model,
        }

    @staticmethod
    def _build_synthesis_prompt(query: str, successful_responses: Dict[str, dict]) -> str:
        """Build the user prompt that asks the LLM to merge the agents' answers."""
        synthesis_parts = [
            "You are synthesizing responses from multiple specialized agents.",
            f"Original question: {query}",
            "",
            "Agent responses:",
        ]
        for agent_name, response in successful_responses.items():
            answer = response.get("answer")
            if answer is None:
                answer = "No answer provided"
            synthesis_parts.append(f"\n{agent_name.upper()} Agent:")
            # str.join requires strings; answers of other types are stringified.
            synthesis_parts.append(str(answer))
        synthesis_parts.extend([
            "",
            "Synthesize these responses into a comprehensive, coherent answer.",
            "If there are conflicts, note them. If information is complementary, combine it.",
        ])
        return "\n".join(synthesis_parts)

    async def _request_synthesis(self, synthesis_prompt: str) -> Optional[str]:
        """Send the synthesis prompt to the LLM and return the text of the first choice."""
        messages = [
            {"role": "system", "content": self.description},
            {"role": "user", "content": synthesis_prompt},
        ]

        def _create_completion():
            return self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=self._SYNTHESIS_TEMPERATURE,
            )

        # The completion call is made without await, i.e. it blocks until the
        # request finishes. Run it in the default executor so the event loop
        # stays free for other tasks in the meantime.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(None, _create_completion)
        return response.choices[0].message.content