# =============================================================
# File: backend/api/services/agent_orchestrator.py
# =============================================================
"""
Agent Orchestrator (integrated with the enterprise RedFlagDetector).
Place at: backend/api/services/agent_orchestrator.py
"""
from __future__ import annotations

import asyncio
import json
import os
from typing import List, Dict, Any, Optional

from ..models.agent import AgentRequest, AgentDecision, AgentResponse
from ..models.redflag import RedFlagMatch
from .redflag_detector import RedFlagDetector
from .intent_classifier import IntentClassifier
from .tool_selector import ToolSelector
from .llm_client import LLMClient
from ..mcp_clients.mcp_client import MCPClient
from .tool_scoring import ToolScoringService


class AgentOrchestrator:
    def __init__(
        self,
        rag_mcp_url: str,
        web_mcp_url: str,
        admin_mcp_url: str,
        llm_backend: str = "ollama",
    ):
        self.mcp = MCPClient(rag_mcp_url, web_mcp_url, admin_mcp_url)
        self.llm = LLMClient(
            backend=llm_backend,
            url=os.getenv("OLLAMA_URL"),
            api_key=os.getenv("GROQ_API_KEY"),
            model=os.getenv("OLLAMA_MODEL"),
        )
        # Pass admin_mcp_url so the detector can call back to the admin service.
        self.redflag = RedFlagDetector(
            supabase_url=os.getenv("SUPABASE_URL"),
            supabase_key=os.getenv("SUPABASE_SERVICE_KEY"),
            admin_mcp_url=admin_mcp_url,
        )
        self.intent = IntentClassifier(llm_client=self.llm)
        self.selector = ToolSelector(llm_client=self.llm)
        self.tool_scorer = ToolScoringService()

    async def handle(self, req: AgentRequest) -> AgentResponse:
        reasoning_trace: List[Dict[str, Any]] = []
        reasoning_trace.append({
            "step": "request_received",
            "tenant_id": req.tenant_id,
            "user_id": req.user_id,
            "message_preview": req.message[:120],
        })

        # 1) Red-flag check (async)
        matches: List[RedFlagMatch] = await self.redflag.check(req.tenant_id, req.message)
        reasoning_trace.append({
            "step": "redflag_check",
            "match_count": len(matches),
            "matches": [m.__dict__ for m in matches],
        })
        if matches:
            # We await the admin notification so the alert lands before we
            # respond; switch to asyncio.create_task() if you prefer not to
            # block the response path.
            try:
                await self.redflag.notify_admin(
                    req.tenant_id,
                    matches,
                    source_payload={"message": req.message, "user_id": req.user_id},
                )
            except Exception:
                # A notification failure must not prevent the block response.
                pass
            decision = AgentDecision(
                action="block",
                tool="admin",
                tool_input={"violations": [m.__dict__ for m in matches]},
                reason="redflag_triggered",
            )
            return AgentResponse(
                text="Your request has been blocked due to policy.",
                decision=decision,
                tool_traces=[{"redflags": [m.__dict__ for m in matches]}],
                reasoning_trace=reasoning_trace,
            )

        # 2) Intent classification
        intent = await self.intent.classify(req.message)
        reasoning_trace.append({
            "step": "intent_detection",
            "intent": intent,
        })

        # 2.5) Pre-fetch RAG results (if available) to give the tool selector context.
        rag_prefetch = None
        rag_results = []
        try:
            rag_prefetch = await self.mcp.call_rag(req.tenant_id, req.message)
            if isinstance(rag_prefetch, dict):
                rag_results = rag_prefetch.get("results") or rag_prefetch.get("hits") or []
            reasoning_trace.append({
                "step": "rag_prefetch",
                "status": "ok",
                "hit_count": len(rag_results),
            })
        except Exception as pref_err:
            # If the RAG pre-fetch fails, continue without it.
            reasoning_trace.append({
                "step": "rag_prefetch",
                "status": "error",
                "error": str(pref_err),
            })
            rag_prefetch = None

        tool_scores = self.tool_scorer.score(req.message, intent, rag_results)
        reasoning_trace.append({
            "step": "tool_scoring",
            "scores": tool_scores,
        })

        # 3) Tool selection (hybrid) - pass RAG results and scores in context.
        ctx = {
            "tenant_id": req.tenant_id,
            "rag_results": rag_results,
            "tool_scores": tool_scores,
        }
        decision = await self.selector.select(intent, req.message, ctx)
        reasoning_trace.append({
            "step": "tool_selection",
            "decision": decision.dict(),
            "context_scores": tool_scores,
        })
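
        # Decision shapes handled below, inferred from the branches (not an
        # exhaustive AgentDecision contract):
        #   action="multi_step": tool_input={"steps": [{"tool": ..., "input": {...}}, ...]}
        #   action="call_tool":  tool is one of "rag", "web", "admin", "llm"
        #   anything else:       falls through to the direct LLM response below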
        tool_traces: List[Dict[str, Any]] = []

        # 4) Handle multi-step tool execution.
        if decision.action == "multi_step" and decision.tool_input:
            steps = decision.tool_input.get("steps", [])
            if steps:
                return await self._execute_multi_step(
                    req,
                    steps,
                    decision,
                    tool_traces,
                    reasoning_trace,
                    rag_prefetch,
                )

        # 5) Execute a single tool.
        if decision.action == "call_tool" and decision.tool:
            try:
                if decision.tool == "rag":
                    query = (decision.tool_input or {}).get("query") or req.message
                    rag_resp = await self.mcp.call_rag(req.tenant_id, query)
                    tool_traces.append({"tool": "rag", "response": rag_resp})
                    reasoning_trace.append({
                        "step": "tool_execution",
                        "tool": "rag",
                        "hit_count": len(self._extract_hits(rag_resp)),
                        "summary": self._summarize_hits(rag_resp, limit=2),
                    })
                    prompt = self._build_prompt_with_rag(req, rag_resp)
                    llm_out = await self.llm.simple_call(prompt, temperature=req.temperature)
                    reasoning_trace.append({
                        "step": "llm_response",
                        "mode": "rag_synthesis",
                    })
                    return AgentResponse(text=llm_out, decision=decision, tool_traces=tool_traces, reasoning_trace=reasoning_trace)
                if decision.tool == "web":
                    query = (decision.tool_input or {}).get("query") or req.message
                    web_resp = await self.mcp.call_web(req.tenant_id, query)
                    tool_traces.append({"tool": "web", "response": web_resp})
                    reasoning_trace.append({
                        "step": "tool_execution",
                        "tool": "web",
                        "hit_count": len(self._extract_hits(web_resp)),
                        "summary": self._summarize_hits(web_resp, limit=2),
                    })
                    prompt = self._build_prompt_with_web(req, web_resp)
                    llm_out = await self.llm.simple_call(prompt, temperature=req.temperature)
                    reasoning_trace.append({
                        "step": "llm_response",
                        "mode": "web_synthesis",
                    })
                    return AgentResponse(text=llm_out, decision=decision, tool_traces=tool_traces, reasoning_trace=reasoning_trace)
                if decision.tool == "admin":
                    query = (decision.tool_input or {}).get("query") or req.message
                    admin_resp = await self.mcp.call_admin(req.tenant_id, query)
                    tool_traces.append({"tool": "admin", "response": admin_resp})
                    reasoning_trace.append({
                        "step": "tool_execution",
                        "tool": "admin",
                        "status": "completed",
                    })
                    return AgentResponse(text=json.dumps(admin_resp), decision=decision, tool_traces=tool_traces, reasoning_trace=reasoning_trace)
                if decision.tool == "llm":
                    llm_out = await self.llm.simple_call(req.message, temperature=req.temperature)
                    reasoning_trace.append({
                        "step": "llm_response",
                        "mode": "direct",
                    })
                    return AgentResponse(text=llm_out, decision=decision, reasoning_trace=reasoning_trace)
            except Exception as e:
                tool_traces.append({"tool": decision.tool, "error": str(e)})
                try:
                    fallback = await self.llm.simple_call(req.message, temperature=req.temperature)
                except Exception as llm_error:
                    error_msg = str(llm_error)
                    if "Cannot connect" in error_msg or "Ollama" in error_msg:
                        fallback = (
                            f"I encountered an error while processing your request: {str(e)}\n\n"
                            f"Additionally, the AI service (Ollama) is unavailable: {error_msg}\n\n"
                            f"To fix:\n"
                            f"1. Install Ollama from https://ollama.ai\n"
                            f"2. Start: `ollama serve`\n"
                            f"3. Pull model: `ollama pull {os.getenv('OLLAMA_MODEL', 'llama3.1:latest')}`"
                        )
                    else:
                        fallback = (
                            f"I encountered an error while processing your request: {str(e)}. "
                            f"Additionally, the AI service is unavailable: {error_msg}"
                        )
                return AgentResponse(
                    text=fallback,
                    decision=AgentDecision(action="respond", tool=None, tool_input=None, reason=f"tool_error_fallback: {e}"),
                    tool_traces=tool_traces,
                    reasoning_trace=reasoning_trace + [{
                        "step": "error",
                        "tool": decision.tool,
                        "error": str(e),
                    }],
                )

        # Default: direct LLM response.
        try:
            llm_out = await self.llm.simple_call(req.message, temperature=req.temperature)
        except Exception as e:
            # If the LLM fails, return a helpful error message instead of raising.
            error_msg = str(e)
            if "Cannot connect" in error_msg or "Ollama" in error_msg:
                llm_out = (
                    f"I couldn't connect to the AI service (Ollama). "
                    f"Error: {error_msg}\n\n"
                    f"To fix this:\n"
                    f"1. Install Ollama from https://ollama.ai\n"
                    f"2. Start Ollama: `ollama serve`\n"
                    f"3. Pull the model: `ollama pull {os.getenv('OLLAMA_MODEL', 'llama3.1:latest')}`\n"
                    f"4. Or set OLLAMA_URL and OLLAMA_MODEL in your .env file"
                )
            else:
                llm_out = f"I apologize, but I'm unable to process your request right now. The AI service is unavailable: {error_msg}"
            reasoning_trace.append({
                "step": "error",
                "tool": "llm",
                "error": str(e),
            })
        return AgentResponse(
            text=llm_out,
            decision=AgentDecision(action="respond", tool=None, tool_input=None, reason="default_llm"),
            reasoning_trace=reasoning_trace,
        )

    def _build_prompt_with_rag(self, req: AgentRequest, rag_resp: Dict[str, Any]) -> str:
        snippets = []
        if isinstance(rag_resp, dict):
            hits = rag_resp.get("results") or rag_resp.get("hits") or []
            for h in hits[:5]:
                if isinstance(h, dict):
                    txt = h.get("text") or h.get("content") or str(h)
                else:
                    txt = str(h)
                snippets.append(txt)
        snippet_text = "\n---\n".join(snippets) or ""
        prompt = (
            f"You are an assistant helping tenant {req.tenant_id}. Use the following retrieved documents to answer the user's question.\n"
            f"Documents:\n{snippet_text}\n\n"
            f"User question: {req.message}\nProvide a concise, accurate answer and cite the source snippets where appropriate."
        )
        return prompt

    async def _execute_multi_step(
        self,
        req: AgentRequest,
        steps: List[Dict[str, Any]],
        decision: AgentDecision,
        tool_traces: List[Dict[str, Any]],
        reasoning_trace: List[Dict[str, Any]],
        pre_fetched_rag: Optional[Dict[str, Any]] = None,
    ) -> AgentResponse:
        """
        Execute multiple tools in sequence and synthesize the results with the LLM.
        """
        rag_data = None
        web_data = None
        admin_data = None
        collected_data = []

        # If RAG and web would run with the same query, kick them off in parallel.
        parallel_tasks = {}
        rag_parallel_query = self._first_query_for_tool(steps, "rag", req.message)
        web_parallel_query = self._first_query_for_tool(steps, "web", req.message)
        if rag_parallel_query and web_parallel_query and rag_parallel_query == web_parallel_query:
            if not pre_fetched_rag:
                parallel_tasks["rag"] = asyncio.create_task(self.mcp.call_rag(req.tenant_id, rag_parallel_query))
            parallel_tasks["web"] = asyncio.create_task(self.mcp.call_web(req.tenant_id, web_parallel_query))

        # Execute each step in sequence.
        for step_info in steps:
            tool_name = step_info.get("tool")
            step_input = step_info.get("input") or {}
            query = step_input.get("query") or req.message
            try:
                if tool_name == "rag":
                    # Reuse the pre-fetched RAG response if available; otherwise fetch.
                    if pre_fetched_rag and query == rag_parallel_query:
                        rag_resp = pre_fetched_rag
                        tool_traces.append({"tool": "rag", "response": rag_resp, "note": "used_pre_fetched"})
                    elif parallel_tasks.get("rag") and query == rag_parallel_query:
                        rag_resp = await parallel_tasks["rag"]
                        tool_traces.append({"tool": "rag", "response": rag_resp, "note": "parallel"})
                    else:
                        rag_resp = await self.mcp.call_rag(req.tenant_id, query)
                        tool_traces.append({"tool": "rag", "response": rag_resp})
                    rag_data = rag_resp
                    reasoning_trace.append({
                        "step": "tool_execution",
                        "tool": "rag",
                        "hit_count": len(self._extract_hits(rag_resp)),
                        "summary": self._summarize_hits(rag_resp, limit=2),
                    })
                    # Extract snippets for the synthesis prompt.
                    if isinstance(rag_resp, dict):
                        hits = rag_resp.get("results") or rag_resp.get("hits") or []
                        for h in hits[:5]:
                            txt = h.get("text") or h.get("content") or str(h)
                            collected_data.append(f"[RAG] {txt}")
                elif tool_name == "web":
                    if parallel_tasks.get("web") and query == web_parallel_query:
                        web_resp = await parallel_tasks["web"]
                        tool_traces.append({"tool": "web", "response": web_resp, "note": "parallel"})
                    else:
                        web_resp = await self.mcp.call_web(req.tenant_id, query)
                        tool_traces.append({"tool": "web", "response": web_resp})
                    web_data = web_resp
                    reasoning_trace.append({
                        "step": "tool_execution",
                        "tool": "web",
                        "hit_count": len(self._extract_hits(web_resp)),
                        "summary": self._summarize_hits(web_resp, limit=2),
                    })
                    # Extract snippets for the synthesis prompt.
                    if isinstance(web_resp, dict):
                        hits = web_resp.get("results") or web_resp.get("items") or []
                        for h in hits[:5]:
                            title = h.get("title") or h.get("headline") or ""
                            snippet = h.get("snippet") or h.get("summary") or h.get("text") or ""
                            url = h.get("url") or h.get("link") or ""
                            collected_data.append(f"[WEB] {title}\n{snippet}\nSource: {url}")
                elif tool_name == "admin":
                    admin_resp = await self.mcp.call_admin(req.tenant_id, query)
                    tool_traces.append({"tool": "admin", "response": admin_resp})
                    admin_data = admin_resp
                    collected_data.append(f"[ADMIN] {json.dumps(admin_resp)}")
                    reasoning_trace.append({
                        "step": "tool_execution",
                        "tool": "admin",
                        "status": "completed",
                    })
                elif tool_name == "llm":
                    # The LLM step is always last: stop collecting and synthesize.
                    break
            except Exception as e:
                # Continue with the remaining tools even if one fails.
                tool_traces.append({"tool": tool_name, "error": str(e)})
                reasoning_trace.append({
                    "step": "error",
                    "tool": tool_name,
                    "error": str(e),
                })

        # Build a comprehensive prompt with all collected data.
        data_section = "\n---\n".join(collected_data) if collected_data else ""
        if data_section:
            prompt = (
                f"You are an assistant helping tenant {req.tenant_id}.\n\n"
                f"The following information has been gathered from multiple sources:\n\n"
                f"{data_section}\n\n"
                f"User question: {req.message}\n\n"
                f"Provide a comprehensive, accurate answer using the information above. "
                f"Cite sources where appropriate (RAG for internal docs, WEB for online sources)."
            )
        else:
            # No data was collected; just answer the question directly.
            prompt = req.message

        # Final LLM synthesis.
        try:
            llm_out = await self.llm.simple_call(prompt, temperature=req.temperature)
            return AgentResponse(
                text=llm_out,
                decision=decision,
                tool_traces=tool_traces,
                reasoning_trace=reasoning_trace + [{
                    "step": "llm_response",
                    "mode": "multi_step",
                }],
            )
        except Exception as e:
            tool_traces.append({"tool": "llm", "error": str(e)})
            error_msg = str(e)
            # Provide a helpful error message.
            if "Cannot connect" in error_msg or "Ollama" in error_msg:
                fallback = (
                    f"I couldn't connect to the AI service (Ollama). "
                    f"Error: {error_msg}\n\n"
                    f"To fix this:\n"
                    f"1. Install Ollama from https://ollama.ai\n"
                    f"2. Start Ollama: `ollama serve`\n"
                    f"3. Pull the model: `ollama pull {os.getenv('OLLAMA_MODEL', 'llama3.1:latest')}`\n"
                    f"4. Or set OLLAMA_URL and OLLAMA_MODEL in your .env file"
                )
            else:
                fallback = f"I encountered an error while synthesizing the response: {error_msg}"
            return AgentResponse(
                text=fallback,
                decision=AgentDecision(
                    action="respond",
                    tool=None,
                    tool_input=None,
                    reason=f"multi_step_llm_error: {e}",
                ),
                tool_traces=tool_traces,
                reasoning_trace=reasoning_trace + [{
                    "step": "error",
                    "tool": "llm",
                    "error": str(e),
                }],
            )

    def _build_prompt_with_web(self, req: AgentRequest, web_resp: Dict[str, Any]) -> str:
        snippets = []
        if isinstance(web_resp, dict):
            hits = web_resp.get("results") or web_resp.get("items") or []
            for h in hits[:5]:
                title = h.get("title") or h.get("headline") or ""
                snippet = h.get("snippet") or h.get("summary") or h.get("text") or ""
                url = h.get("url") or h.get("link") or ""
                snippets.append(f"{title}\n{snippet}\nSource: {url}")
        snippet_text = "\n---\n".join(snippets) or ""
        prompt = (
            f"You are an assistant with access to recent web search results. Use the following results to answer.\n{snippet_text}\n\n"
            f"User question: {req.message}\nAnswer succinctly and indicate which results you used."
        )
        return prompt

    def _extract_hits(self, resp: Optional[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Return the hit list from a tool response, whichever key it uses."""
        if not isinstance(resp, dict):
            return []
        return resp.get("results") or resp.get("hits") or resp.get("items") or []

    def _summarize_hits(self, resp: Optional[Dict[str, Any]], limit: int = 3) -> List[str]:
        """Return short text previews (up to `limit` hits) for the reasoning trace."""
        hits = self._extract_hits(resp)
        summaries = []
        for hit in hits[:limit]:
            if isinstance(hit, dict):
                snippet = hit.get("text") or hit.get("content") or hit.get("snippet") or ""
            else:
                snippet = str(hit)
            summaries.append(snippet[:160])
        return summaries

    def _first_query_for_tool(self, steps: List[Dict[str, Any]], tool_name: str, default_query: str) -> Optional[str]:
        """Return the query of the first step that uses `tool_name`, or None if absent."""
        for step in steps:
            if step.get("tool") == tool_name:
                input_data = step.get("input") or {}
                return input_data.get("query") or default_query
        return None
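

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the service wiring). The MCP URLs
    # below are illustrative assumptions, and AgentRequest is assumed to accept
    # the keyword fields that handle() reads above.
    async def _demo() -> None:
        orchestrator = AgentOrchestrator(
            rag_mcp_url="http://localhost:8001",    # assumed RAG MCP endpoint
            web_mcp_url="http://localhost:8002",    # assumed web MCP endpoint
            admin_mcp_url="http://localhost:8003",  # assumed admin MCP endpoint
        )
        req = AgentRequest(
            tenant_id="demo-tenant",
            user_id="demo-user",
            message="Summarize our refund policy.",
            temperature=0.2,
        )
        resp = await orchestrator.handle(req)
        print(resp.text)

    asyncio.run(_demo())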