Spaces:
Sleeping
Sleeping
| """Main orchestrator for coordinating all components.""" | |
| import logging | |
| from typing import Dict, Any, Optional | |
| from enum import Enum | |
| from src.core.config import get_settings | |
| from src.retrieval.vector_store import get_vector_store | |
| from src.agents.local_data_agent import LocalDataAgent | |
| from src.agents.search_agent import SearchAgent | |
| from src.agents.cloud_agent import CloudAgent | |
| from src.agents.aggregator_agent import AggregatorAgent | |
| from src.agents.snowflake_agent import SnowflakeAgent | |
| from src.tools.calculator import get_calculator | |
| from src.tools.web_search import get_web_search | |
| from src.tools.database_query import get_database_query | |
| from openai import OpenAI | |
| logger = logging.getLogger(__name__) | |
class Tier(Enum):
    """Processing tiers a query can be routed to.

    The member *values* are the lowercase strings callers pass as the
    ``tier`` argument; ``Tier("basic")`` yields ``Tier.BASIC_RAG``.
    """

    # Retrieval followed by a single generation step.
    BASIC_RAG = "basic"

    # A single agent that has tools attached.
    AGENT_WITH_TOOLS = "agent"

    # Agentic RAG that involves multiple agents.
    ADVANCED_AGENTIC = "advanced"
class Orchestrator:
    """Route queries to one of three processing tiers and report system state.

    Tiers (see ``Tier``):

    * ``"basic"``    - vector-store retrieval followed by one chat completion.
    * ``"agent"``    - a ``LocalDataAgent`` with calculator and, when
      configured, web-search and database tools.
    * ``"advanced"`` - an ``AggregatorAgent`` (plus a ``SnowflakeAgent`` when
      Snowflake is configured).

    Every public entry point returns a plain ``dict``; failures are reported
    as ``{"success": False, "error": ...}`` rather than raised.
    """

    # Number of documents requested from the vector store for basic RAG.
    _RAG_TOP_K = 5
    # Maximum characters of each retrieved document placed in the prompt.
    _MAX_DOC_CHARS = 500
    # Sampling temperature used for the basic RAG completion.
    _TEMPERATURE = 0.7

    def __init__(self):
        """Load settings and create the OpenAI client, vector store and tools.

        Agents are not created here; each tier handler builds the agent it
        needs on first use.
        """
        self.settings = get_settings()
        self.client = OpenAI(**self.settings.get_openai_client_kwargs())
        self.model = self.settings.openai_model

        # Initialize components
        self.vector_store = get_vector_store()

        # Agents are created lazily by the tier handlers.
        self._local_agent: Optional[LocalDataAgent] = None
        self._search_agent: Optional[SearchAgent] = None
        self._cloud_agent: Optional[CloudAgent] = None
        self._snowflake_agent: Optional[SnowflakeAgent] = None
        self._aggregator_agent: Optional[AggregatorAgent] = None

        # Initialize tools
        self.calculator = get_calculator()
        self.web_search = get_web_search()
        self.database_query = get_database_query()

    async def process_query(
        self,
        query: str,
        tier: str = "basic",
        session_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Process a query using the specified tier.

        Args:
            query: User query.
            tier: System tier ("basic", "agent", or "advanced"); matched
                case-insensitively against the ``Tier`` values.
            session_id: Optional session ID, forwarded to the tier handler.

        Returns:
            The handler's response dictionary, or
            ``{"success": False, "error": ...}`` when the tier is invalid or
            the handler raises.
        """
        # Parse the tier on its own so that only a bad tier name is reported
        # as "Invalid tier" - never an unrelated error from a handler.
        try:
            tier_enum = Tier(tier.lower())
        except (ValueError, AttributeError) as e:
            # ValueError: unknown tier name. AttributeError: tier is not a str.
            logger.error("Invalid tier: %s", e)
            return {
                "success": False,
                "error": f"Invalid tier: {tier}",
            }

        handlers = {
            Tier.BASIC_RAG: self._process_basic_rag,
            Tier.AGENT_WITH_TOOLS: self._process_agent_with_tools,
            Tier.ADVANCED_AGENTIC: self._process_advanced_agentic,
        }
        try:
            return await handlers[tier_enum](query, session_id)
        except Exception as e:
            logger.error("Error processing query: %s", e, exc_info=True)
            return {
                "success": False,
                "error": str(e),
            }

    def _missing_api_key_response(self, tier: str) -> Optional[Dict[str, Any]]:
        """Return an error response if no OpenAI API key is set, else ``None``."""
        if self.settings.openai_api_key:
            return None
        return {
            "success": False,
            "error": "OpenAI API key not configured. Please set OPENAI_API_KEY in your .env file.",
            "tier": tier,
        }

    @classmethod
    def _build_context(cls, results: Dict[str, Any]):
        """Turn vector-store search results into ``(context_text, sources)``.

        ``results`` is indexed by ``"documents"``, ``"metadatas"`` and
        ``"ids"``. When no documents were returned, the context is a fixed
        "nothing found" sentence and ``sources`` is empty.
        """
        documents = results["documents"]
        if not documents:
            return "No relevant documents found in the knowledge base.", []

        metadatas = results["metadatas"]
        context_parts = ["Retrieved documents:"]
        for i, (doc, metadata) in enumerate(zip(documents, metadatas), 1):
            # A document may carry no metadata at all; treat that as empty.
            source = (metadata or {}).get("source", "Unknown")
            context_parts.append(f"\n[{i}] Source: {source}")
            # Ensure doc is a string
            doc_str = str(doc) if doc else ""
            excerpt = doc_str[: cls._MAX_DOC_CHARS]
            if len(doc_str) > cls._MAX_DOC_CHARS:
                # Mark the excerpt as cut only when text was actually dropped.
                excerpt += "..."
            context_parts.append(f"Content: {excerpt}")

        sources = [
            {"id": doc_id, "metadata": meta}
            for doc_id, meta in zip(results["ids"], metadatas)
        ]
        return "\n".join(context_parts), sources

    async def _generate_answer(self, messages) -> Any:
        """Request a chat completion and return the first choice's content.

        The OpenAI client used here is synchronous, so the request runs in
        the event loop's default executor to avoid blocking other coroutines.

        Raises:
            RuntimeError: with a user-oriented message for quota/429 errors,
                API-key/401 errors, or any other API failure. The original
                exception is attached as ``__cause__``.
        """
        # Function-scope import keeps this block self-contained.
        import asyncio

        def _create_completion():
            return self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=self._TEMPERATURE,
            )

        try:
            loop = asyncio.get_running_loop()
            response = await loop.run_in_executor(None, _create_completion)
            return response.choices[0].message.content
        except Exception as api_error:
            error_msg = str(api_error)
            lowered = error_msg.lower()
            if "quota" in lowered or "429" in error_msg:
                raise RuntimeError(
                    "OpenAI API quota exceeded. Please check your billing and plan details."
                ) from api_error
            if "api key" in lowered or "401" in error_msg:
                raise RuntimeError(
                    "Invalid OpenAI API key. Please check your .env file."
                ) from api_error
            raise RuntimeError(f"OpenAI API error: {error_msg}") from api_error

    async def _process_basic_rag(
        self,
        query: str,
        session_id: Optional[str],
    ) -> Dict[str, Any]:
        """Process query using basic RAG (retrieval + generation).

        Args:
            query: User query; used for both retrieval and the prompt.
            session_id: Accepted for parity with the other tier handlers;
                not used by this tier.

        Returns:
            On success: ``success``, ``answer``, ``tier``, ``sources`` and
            ``model``. On failure: ``success=False``, ``error`` and ``tier``.
        """
        try:
            missing_key = self._missing_api_key_response("basic")
            if missing_key is not None:
                return missing_key

            # Retrieve relevant documents
            results = self.vector_store.search(
                query=query, n_results=self._RAG_TOP_K
            )
            context, sources = self._build_context(results)

            # Generate response using LLM
            messages = [
                {
                    "role": "system",
                    "content": "You are a helpful assistant that answers questions based on the provided context.",
                },
                {
                    "role": "user",
                    "content": f"Context:\n{context}\n\nQuestion: {query}",
                },
            ]
            answer = await self._generate_answer(messages)

            return {
                "success": True,
                "answer": answer,
                "tier": "basic",
                "sources": sources,
                "model": self.model,
            }
        except Exception as e:
            logger.error("Error in basic RAG: %s", e, exc_info=True)
            return {
                "success": False,
                "error": f"Error processing query: {str(e)}",
                "tier": "basic",
            }

    def _get_local_agent(self) -> "LocalDataAgent":
        """Return the tool-equipped local agent, building it on first use.

        The agent is stored on ``self`` only after every tool has been
        registered, so a failure part-way through setup does not leave a
        half-configured agent cached for later calls.
        """
        if self._local_agent is not None:
            return self._local_agent

        agent = LocalDataAgent(use_planning=True)

        # Calculator is always available.
        agent.add_tool(
            tool=self.calculator.get_tool_schema(),
            tool_function=lambda expression: self.calculator.calculate(expression),
        )

        if self.settings.has_web_search():
            async def web_search_tool(query: str, max_results: int = 5):
                return await self.web_search.search(query, max_results)

            agent.add_tool(
                tool=self.web_search.get_tool_schema(),
                tool_function=web_search_tool,
            )

        if self.settings.database_url:
            # NOTE(review): the SQL text is supplied by the agent; this relies
            # on database_query.query to restrict what may run - confirm.
            def db_query_tool(sql: str, limit: int = 100):
                return self.database_query.query(sql, limit)

            agent.add_tool(
                tool=self.database_query.get_tool_schema(),
                tool_function=db_query_tool,
            )

        self._local_agent = agent
        return agent

    async def _process_agent_with_tools(
        self,
        query: str,
        session_id: Optional[str],
    ) -> Dict[str, Any]:
        """Process query using agent with tools.

        Args:
            query: User query passed to the local agent.
            session_id: Optional session ID passed to the local agent.

        Returns:
            The agent's response with ``"tier": "agent"`` added, or an error
            dictionary if the API key is missing or anything raises.
        """
        try:
            missing_key = self._missing_api_key_response("agent")
            if missing_key is not None:
                return missing_key

            agent = self._get_local_agent()

            # Process query
            response = await agent.process(query, session_id)
            return {
                **response,
                "tier": "agent",
            }
        except Exception as e:
            logger.error("Error in agent with tools: %s", e, exc_info=True)
            return {
                "success": False,
                "error": f"Error processing query: {str(e)}",
                "tier": "agent",
            }

    async def _process_advanced_agentic(
        self,
        query: str,
        session_id: Optional[str],
    ) -> Dict[str, Any]:
        """Process query using advanced agentic RAG with multiple agents.

        Args:
            query: User query passed to the aggregator agent.
            session_id: Optional session ID passed to the aggregator agent.

        Returns:
            The aggregator's response with ``"tier": "advanced"`` added, or
            an error dictionary if the API key is missing or anything raises.
        """
        try:
            missing_key = self._missing_api_key_response("advanced")
            if missing_key is not None:
                return missing_key

            # Use aggregator agent
            if self._aggregator_agent is None:
                self._aggregator_agent = AggregatorAgent(use_planning=True)

            # Add Snowflake agent if configured
            if self.settings.has_snowflake() and self._snowflake_agent is None:
                snowflake_config = self.settings.get_snowflake_config()
                self._snowflake_agent = SnowflakeAgent(
                    snowflake_config=snowflake_config,
                    use_planning=False,
                )
                # Note: AggregatorAgent will automatically discover SnowflakeAgent
                # through its agent selection logic

            # Process query
            response = await self._aggregator_agent.process(query, session_id)
            return {
                **response,
                "tier": "advanced",
            }
        except Exception as e:
            logger.error("Error in advanced agentic: %s", e, exc_info=True)
            return {
                "success": False,
                "error": f"Error processing query: {str(e)}",
                "tier": "advanced",
            }

    def get_agent_status(self) -> Dict[str, Any]:
        """Get status of all agents.

        Returns:
            ``tiers_available`` (the three tier names) and ``agents``, which
            maps each agent that has been created so far to the result of
            its ``get_status()``. Agents not yet created are omitted.
        """
        agents = {
            "local": self._local_agent,
            "search": self._search_agent,
            "cloud": self._cloud_agent,
            "snowflake": self._snowflake_agent,
            "aggregator": self._aggregator_agent,
        }
        return {
            "tiers_available": ["basic", "agent", "advanced"],
            "agents": {
                name: agent.get_status()
                for name, agent in agents.items()
                if agent is not None
            },
        }

    def get_system_info(self) -> Dict[str, Any]:
        """Get system information.

        Returns:
            A dictionary describing the vector store (document count and
            collection name), which tools are enabled, memory flags, and the
            configured model name.
        """
        return {
            "vector_store": {
                "document_count": self.vector_store.count(),
                "collection_name": self.settings.chroma_collection_name,
            },
            "tools": {
                "calculator": True,
                "web_search": self.settings.has_web_search(),
                "database": bool(self.settings.database_url),
                "snowflake": self.settings.has_snowflake(),
            },
            "memory": {
                "short_term_enabled": True,
                "long_term_enabled": self.settings.long_term_memory_enabled,
            },
            "model": self.model,
        }
# Module-wide Orchestrator, built the first time get_orchestrator() is called.
_orchestrator: Optional[Orchestrator] = None


def get_orchestrator() -> Orchestrator:
    """Return the shared ``Orchestrator``, constructing it on first call.

    Subsequent calls return the same instance that the first call created.
    """
    global _orchestrator
    if _orchestrator is not None:
        return _orchestrator
    _orchestrator = Orchestrator()
    return _orchestrator