AgenticAI-RAG / src /core /orchestrator.py
GreymanT's picture
Upload 80 files
8bf4d58 verified
"""Main orchestrator for coordinating all components."""
import logging
from typing import Dict, Any, Optional
from enum import Enum
from src.core.config import get_settings
from src.retrieval.vector_store import get_vector_store
from src.agents.local_data_agent import LocalDataAgent
from src.agents.search_agent import SearchAgent
from src.agents.cloud_agent import CloudAgent
from src.agents.aggregator_agent import AggregatorAgent
from src.agents.snowflake_agent import SnowflakeAgent
from src.tools.calculator import get_calculator
from src.tools.web_search import get_web_search
from src.tools.database_query import get_database_query
from openai import OpenAI
logger = logging.getLogger(__name__)
class Tier(Enum):
"""System tiers."""
BASIC_RAG = "basic"
AGENT_WITH_TOOLS = "agent"
ADVANCED_AGENTIC = "advanced"
class Orchestrator:
"""Main orchestrator for the RAG system."""
def __init__(self):
"""Initialize orchestrator."""
self.settings = get_settings()
self.client = OpenAI(**self.settings.get_openai_client_kwargs())
self.model = self.settings.openai_model
# Initialize components
self.vector_store = get_vector_store()
# Initialize agents (lazy loading)
self._local_agent: Optional[LocalDataAgent] = None
self._search_agent: Optional[SearchAgent] = None
self._cloud_agent: Optional[CloudAgent] = None
self._snowflake_agent: Optional[SnowflakeAgent] = None
self._aggregator_agent: Optional[AggregatorAgent] = None
# Initialize tools
self.calculator = get_calculator()
self.web_search = get_web_search()
self.database_query = get_database_query()
async def process_query(
self,
query: str,
tier: str = "basic",
session_id: Optional[str] = None,
) -> Dict[str, Any]:
"""
Process a query using the specified tier.
Args:
query: User query
tier: System tier ("basic", "agent", or "advanced")
session_id: Optional session ID for memory
Returns:
Response dictionary
"""
try:
tier_enum = Tier(tier.lower())
if tier_enum == Tier.BASIC_RAG:
return await self._process_basic_rag(query, session_id)
elif tier_enum == Tier.AGENT_WITH_TOOLS:
return await self._process_agent_with_tools(query, session_id)
elif tier_enum == Tier.ADVANCED_AGENTIC:
return await self._process_advanced_agentic(query, session_id)
else:
raise ValueError(f"Unknown tier: {tier}")
except ValueError as e:
logger.error(f"Invalid tier: {e}")
return {
"success": False,
"error": f"Invalid tier: {tier}",
}
except Exception as e:
logger.error(f"Error processing query: {e}")
return {
"success": False,
"error": str(e),
}
async def _process_basic_rag(
self,
query: str,
session_id: Optional[str],
) -> Dict[str, Any]:
"""Process query using basic RAG (retrieval + generation)."""
try:
# Check if OpenAI API key is configured
if not self.settings.openai_api_key:
return {
"success": False,
"error": "OpenAI API key not configured. Please set OPENAI_API_KEY in your .env file.",
"tier": "basic",
}
# Retrieve relevant documents
results = self.vector_store.search(query=query, n_results=5)
# Build context - use retrieved documents if available, otherwise use empty context
if results["documents"]:
context_parts = ["Retrieved documents:"]
for i, (doc, metadata) in enumerate(
zip(results["documents"], results["metadatas"]), 1
):
source = metadata.get("source", "Unknown")
context_parts.append(f"\n[{i}] Source: {source}")
# Ensure doc is a string
doc_str = str(doc) if doc else ""
context_parts.append(f"Content: {doc_str[:500]}...")
context = "\n".join(context_parts)
sources = [
{"id": id, "metadata": meta}
for id, meta in zip(results["ids"], results["metadatas"])
]
else:
context = "No relevant documents found in the knowledge base."
sources = []
# Generate response using LLM
messages = [
{
"role": "system",
"content": "You are a helpful assistant that answers questions based on the provided context.",
},
{
"role": "user",
"content": f"Context:\n{context}\n\nQuestion: {query}",
},
]
try:
response = self.client.chat.completions.create(
model=self.model,
messages=messages,
temperature=0.7,
)
answer = response.choices[0].message.content
except Exception as api_error:
error_msg = str(api_error)
if "quota" in error_msg.lower() or "429" in error_msg:
raise Exception("OpenAI API quota exceeded. Please check your billing and plan details.")
elif "api key" in error_msg.lower() or "401" in error_msg:
raise Exception("Invalid OpenAI API key. Please check your .env file.")
else:
raise Exception(f"OpenAI API error: {error_msg}")
return {
"success": True,
"answer": answer,
"tier": "basic",
"sources": sources,
"model": self.model,
}
except Exception as e:
logger.error(f"Error in basic RAG: {e}", exc_info=True)
return {
"success": False,
"error": f"Error processing query: {str(e)}",
"tier": "basic",
}
async def _process_agent_with_tools(
self,
query: str,
session_id: Optional[str],
) -> Dict[str, Any]:
"""Process query using agent with tools."""
try:
# Check if OpenAI API key is configured
if not self.settings.openai_api_key:
return {
"success": False,
"error": "OpenAI API key not configured. Please set OPENAI_API_KEY in your .env file.",
"tier": "agent",
}
# Use local agent with tools enabled
if not self._local_agent:
self._local_agent = LocalDataAgent(use_planning=True)
# Add tools to agent
self._local_agent.add_tool(
tool=self.calculator.get_tool_schema(),
tool_function=lambda expression: self.calculator.calculate(expression),
)
if self.settings.has_web_search():
async def web_search_tool(query: str, max_results: int = 5):
return await self.web_search.search(query, max_results)
self._local_agent.add_tool(
tool=self.web_search.get_tool_schema(),
tool_function=web_search_tool,
)
if self.settings.database_url:
def db_query_tool(sql: str, limit: int = 100):
return self.database_query.query(sql, limit)
self._local_agent.add_tool(
tool=self.database_query.get_tool_schema(),
tool_function=db_query_tool,
)
# Process query
response = await self._local_agent.process(query, session_id)
return {
**response,
"tier": "agent",
}
except Exception as e:
logger.error(f"Error in agent with tools: {e}", exc_info=True)
return {
"success": False,
"error": f"Error processing query: {str(e)}",
"tier": "agent",
}
async def _process_advanced_agentic(
self,
query: str,
session_id: Optional[str],
) -> Dict[str, Any]:
"""Process query using advanced agentic RAG with multiple agents."""
try:
# Check if OpenAI API key is configured
if not self.settings.openai_api_key:
return {
"success": False,
"error": "OpenAI API key not configured. Please set OPENAI_API_KEY in your .env file.",
"tier": "advanced",
}
# Use aggregator agent
if not self._aggregator_agent:
self._aggregator_agent = AggregatorAgent(use_planning=True)
# Add Snowflake agent if configured
if self.settings.has_snowflake() and not self._snowflake_agent:
snowflake_config = self.settings.get_snowflake_config()
self._snowflake_agent = SnowflakeAgent(
snowflake_config=snowflake_config,
use_planning=False
)
# Note: AggregatorAgent will automatically discover SnowflakeAgent
# through its agent selection logic
# Process query
response = await self._aggregator_agent.process(query, session_id)
return {
**response,
"tier": "advanced",
}
except Exception as e:
logger.error(f"Error in advanced agentic: {e}", exc_info=True)
return {
"success": False,
"error": f"Error processing query: {str(e)}",
"tier": "advanced",
}
def get_agent_status(self) -> Dict[str, Any]:
"""Get status of all agents."""
status = {
"tiers_available": ["basic", "agent", "advanced"],
"agents": {},
}
if self._local_agent:
status["agents"]["local"] = self._local_agent.get_status()
if self._search_agent:
status["agents"]["search"] = self._search_agent.get_status()
if self._cloud_agent:
status["agents"]["cloud"] = self._cloud_agent.get_status()
if self._snowflake_agent:
status["agents"]["snowflake"] = self._snowflake_agent.get_status()
if self._aggregator_agent:
status["agents"]["aggregator"] = self._aggregator_agent.get_status()
return status
def get_system_info(self) -> Dict[str, Any]:
"""Get system information."""
return {
"vector_store": {
"document_count": self.vector_store.count(),
"collection_name": self.settings.chroma_collection_name,
},
"tools": {
"calculator": True,
"web_search": self.settings.has_web_search(),
"database": bool(self.settings.database_url),
"snowflake": self.settings.has_snowflake(),
},
"memory": {
"short_term_enabled": True,
"long_term_enabled": self.settings.long_term_memory_enabled,
},
"model": self.model,
}
# Global instance
_orchestrator: Optional[Orchestrator] = None
def get_orchestrator() -> Orchestrator:
"""Get or create the global orchestrator instance."""
global _orchestrator
if _orchestrator is None:
_orchestrator = Orchestrator()
return _orchestrator