# ============================================================= # File: backend/api/services/agent_orchestrator.py # ============================================================= """ Agent Orchestrator (integrated with enterprise RedFlagDetector) Place at: backend/api/services/agent_orchestrator.py """ from __future__ import annotations import asyncio import json import os from typing import List, Dict, Any, Optional from ..models.agent import AgentRequest, AgentDecision, AgentResponse from ..models.redflag import RedFlagMatch from .redflag_detector import RedFlagDetector from .intent_classifier import IntentClassifier from .tool_selector import ToolSelector from .llm_client import LLMClient from ..mcp_clients.mcp_client import MCPClient from .tool_scoring import ToolScoringService from ..storage.analytics_store import AnalyticsStore import time class AgentOrchestrator: def __init__(self, rag_mcp_url: str, web_mcp_url: str, admin_mcp_url: str, llm_backend: str = "ollama"): self.mcp = MCPClient(rag_mcp_url, web_mcp_url, admin_mcp_url) self.llm = LLMClient(backend=llm_backend, url=os.getenv("OLLAMA_URL"), api_key=os.getenv("GROQ_API_KEY"), model=os.getenv("OLLAMA_MODEL")) # pass admin_mcp_url so detector can call back self.redflag = RedFlagDetector( supabase_url=os.getenv("SUPABASE_URL"), supabase_key=os.getenv("SUPABASE_SERVICE_KEY"), admin_mcp_url=admin_mcp_url ) self.intent = IntentClassifier(llm_client=self.llm) self.selector = ToolSelector(llm_client=self.llm) self.tool_scorer = ToolScoringService() self.analytics = AnalyticsStore() async def handle(self, req: AgentRequest) -> AgentResponse: start_time = time.time() reasoning_trace: List[Dict[str, Any]] = [] reasoning_trace.append({ "step": "request_received", "tenant_id": req.tenant_id, "user_id": req.user_id, "message_preview": req.message[:120] }) # 1) Red-flag check (async) matches: List[RedFlagMatch] = await self.redflag.check(req.tenant_id, req.message) reasoning_trace.append({ "step": "redflag_check", "match_count": len(matches), "matches": [m.__dict__ for m in matches] }) # Log red-flag violations for match in matches: self.analytics.log_redflag_violation( tenant_id=req.tenant_id, rule_id=match.rule_id, rule_pattern=match.pattern, severity=match.severity, matched_text=match.matched_text, confidence=match.confidence, message_preview=req.message[:200], user_id=req.user_id ) if matches: # Notify admin asynchronously (do not await blocking the response path if you prefer) # we await here to ensure admin receives the alert before responding try: await self.redflag.notify_admin(req.tenant_id, matches, source_payload={"message": req.message, "user_id": req.user_id}) except Exception: pass decision = AgentDecision( action="block", tool="admin", tool_input={"violations": [m.__dict__ for m in matches]}, reason="redflag_triggered" ) summary = "; ".join( f"{m.description or m.pattern} [severity: {m.severity}]" for m in matches ) or "Policy violation detected" total_latency_ms = int((time.time() - start_time) * 1000) self.analytics.log_agent_query( tenant_id=req.tenant_id, message_preview=req.message[:200], intent="admin", tools_used=["admin"], total_tokens=0, total_latency_ms=total_latency_ms, success=False, user_id=req.user_id ) return AgentResponse( text=f"⚠️ Request blocked by Admin Plan: {summary}. Please review your governance rules or contact an administrator.", decision=decision, tool_traces=[{"redflags": [m.__dict__ for m in matches]}], reasoning_trace=reasoning_trace ) # 2) Intent classification intent = await self.intent.classify(req.message) reasoning_trace.append({ "step": "intent_detection", "intent": intent }) # 2.5) Pre-fetch RAG results if available (for tool selector context) rag_prefetch = None rag_results = [] try: # Try to pre-fetch RAG to help tool selector make better decisions rag_start = time.time() rag_prefetch = await self.mcp.call_rag(req.tenant_id, req.message) rag_latency_ms = int((time.time() - rag_start) * 1000) if isinstance(rag_prefetch, dict): rag_results = rag_prefetch.get("results") or rag_prefetch.get("hits") or [] # Log RAG search event hits_count = len(rag_results) avg_score = None top_score = None if rag_results: scores = [h.get("score", 0.0) for h in rag_results if isinstance(h, dict) and "score" in h] if scores: avg_score = sum(scores) / len(scores) top_score = max(scores) self.analytics.log_rag_search( tenant_id=req.tenant_id, query=req.message[:500], hits_count=hits_count, avg_score=avg_score, top_score=top_score, latency_ms=rag_latency_ms ) # Log tool usage self.analytics.log_tool_usage( tenant_id=req.tenant_id, tool_name="rag", latency_ms=rag_latency_ms, success=True, user_id=req.user_id ) reasoning_trace.append({ "step": "rag_prefetch", "status": "ok", "hit_count": len(rag_results), "latency_ms": rag_latency_ms }) except Exception as pref_err: # If RAG fails, continue without it rag_latency_ms = 0 # 0 for failed self.analytics.log_tool_usage( tenant_id=req.tenant_id, tool_name="rag", latency_ms=rag_latency_ms, success=False, error_message=str(pref_err)[:200], user_id=req.user_id ) reasoning_trace.append({ "step": "rag_prefetch", "status": "error", "error": str(pref_err) }) rag_prefetch = None tool_scores = self.tool_scorer.score(req.message, intent, rag_results) reasoning_trace.append({ "step": "tool_scoring", "scores": tool_scores }) # 3) Tool selection (hybrid) - pass RAG results in context ctx = { "tenant_id": req.tenant_id, "rag_results": rag_results, "tool_scores": tool_scores } decision = await self.selector.select(intent, req.message, ctx) reasoning_trace.append({ "step": "tool_selection", "decision": decision.dict(), "context_scores": tool_scores }) tool_traces: List[Dict[str, Any]] = [] # 4) Handle multi-step tool execution if decision.action == "multi_step" and decision.tool_input: steps = decision.tool_input.get("steps", []) if steps: return await self._execute_multi_step( req, steps, decision, tool_traces, reasoning_trace, rag_prefetch ) # 5) Execute single tool tools_used = [] total_tokens = 0 if decision.action == "call_tool" and decision.tool: try: if decision.tool == "rag": rag_start = time.time() rag_resp = await self.mcp.call_rag(req.tenant_id, decision.tool_input.get("query") if decision.tool_input else req.message) rag_latency_ms = int((time.time() - rag_start) * 1000) tools_used.append("rag") tool_traces.append({"tool": "rag", "response": rag_resp}) hits = self._extract_hits(rag_resp) # Log RAG search and tool usage hits_count = len(hits) avg_score = None top_score = None if hits: scores = [h.get("score", 0.0) for h in hits if isinstance(h, dict) and "score" in h] if scores: avg_score = sum(scores) / len(scores) top_score = max(scores) self.analytics.log_rag_search( tenant_id=req.tenant_id, query=req.message[:500], hits_count=hits_count, avg_score=avg_score, top_score=top_score, latency_ms=rag_latency_ms ) self.analytics.log_tool_usage( tenant_id=req.tenant_id, tool_name="rag", latency_ms=rag_latency_ms, success=True, user_id=req.user_id ) reasoning_trace.append({ "step": "tool_execution", "tool": "rag", "hit_count": hits_count, "summary": self._summarize_hits(rag_resp, limit=2), "latency_ms": rag_latency_ms }) prompt = self._build_prompt_with_rag(req, rag_resp) llm_start = time.time() llm_out = await self.llm.simple_call(prompt, temperature=req.temperature) llm_latency_ms = int((time.time() - llm_start) * 1000) tools_used.append("llm") # Estimate tokens (rough: ~4 chars per token) estimated_tokens = len(llm_out) // 4 + len(prompt) // 4 total_tokens += estimated_tokens self.analytics.log_tool_usage( tenant_id=req.tenant_id, tool_name="llm", latency_ms=llm_latency_ms, tokens_used=estimated_tokens, success=True, user_id=req.user_id ) reasoning_trace.append({ "step": "llm_response", "mode": "rag_synthesis", "latency_ms": llm_latency_ms, "estimated_tokens": estimated_tokens }) total_latency_ms = int((time.time() - start_time) * 1000) self.analytics.log_agent_query( tenant_id=req.tenant_id, message_preview=req.message[:200], intent=intent, tools_used=tools_used, total_tokens=total_tokens, total_latency_ms=total_latency_ms, success=True, user_id=req.user_id ) return AgentResponse(text=llm_out, decision=decision, tool_traces=tool_traces, reasoning_trace=reasoning_trace) if decision.tool == "web": web_start = time.time() web_resp = await self.mcp.call_web(req.tenant_id, decision.tool_input.get("query") if decision.tool_input else req.message) web_latency_ms = int((time.time() - web_start) * 1000) tools_used.append("web") tool_traces.append({"tool": "web", "response": web_resp}) hits_count = len(self._extract_hits(web_resp)) self.analytics.log_tool_usage( tenant_id=req.tenant_id, tool_name="web", latency_ms=web_latency_ms, success=True, user_id=req.user_id ) reasoning_trace.append({ "step": "tool_execution", "tool": "web", "hit_count": hits_count, "summary": self._summarize_hits(web_resp, limit=2), "latency_ms": web_latency_ms }) prompt = self._build_prompt_with_web(req, web_resp) llm_start = time.time() llm_out = await self.llm.simple_call(prompt, temperature=req.temperature) llm_latency_ms = int((time.time() - llm_start) * 1000) tools_used.append("llm") estimated_tokens = len(llm_out) // 4 + len(prompt) // 4 total_tokens += estimated_tokens self.analytics.log_tool_usage( tenant_id=req.tenant_id, tool_name="llm", latency_ms=llm_latency_ms, tokens_used=estimated_tokens, success=True, user_id=req.user_id ) reasoning_trace.append({ "step": "llm_response", "mode": "web_synthesis", "latency_ms": llm_latency_ms, "estimated_tokens": estimated_tokens }) total_latency_ms = int((time.time() - start_time) * 1000) self.analytics.log_agent_query( tenant_id=req.tenant_id, message_preview=req.message[:200], intent=intent, tools_used=tools_used, total_tokens=total_tokens, total_latency_ms=total_latency_ms, success=True, user_id=req.user_id ) return AgentResponse(text=llm_out, decision=decision, tool_traces=tool_traces, reasoning_trace=reasoning_trace) if decision.tool == "admin": admin_start = time.time() admin_resp = await self.mcp.call_admin(req.tenant_id, decision.tool_input.get("query") if decision.tool_input else req.message) admin_latency_ms = int((time.time() - admin_start) * 1000) tools_used.append("admin") self.analytics.log_tool_usage( tenant_id=req.tenant_id, tool_name="admin", latency_ms=admin_latency_ms, success=True, user_id=req.user_id ) tool_traces.append({"tool": "admin", "response": admin_resp}) reasoning_trace.append({ "step": "tool_execution", "tool": "admin", "status": "completed", "latency_ms": admin_latency_ms }) total_latency_ms = int((time.time() - start_time) * 1000) self.analytics.log_agent_query( tenant_id=req.tenant_id, message_preview=req.message[:200], intent=intent, tools_used=tools_used, total_tokens=0, total_latency_ms=total_latency_ms, success=True, user_id=req.user_id ) return AgentResponse(text=json.dumps(admin_resp), decision=decision, tool_traces=tool_traces, reasoning_trace=reasoning_trace) if decision.tool == "llm": llm_start = time.time() llm_out = await self.llm.simple_call(req.message, temperature=req.temperature) llm_latency_ms = int((time.time() - llm_start) * 1000) tools_used.append("llm") estimated_tokens = len(llm_out) // 4 + len(req.message) // 4 total_tokens += estimated_tokens self.analytics.log_tool_usage( tenant_id=req.tenant_id, tool_name="llm", latency_ms=llm_latency_ms, tokens_used=estimated_tokens, success=True, user_id=req.user_id ) reasoning_trace.append({ "step": "llm_response", "mode": "direct", "latency_ms": llm_latency_ms, "estimated_tokens": estimated_tokens }) total_latency_ms = int((time.time() - start_time) * 1000) self.analytics.log_agent_query( tenant_id=req.tenant_id, message_preview=req.message[:200], intent=intent, tools_used=tools_used, total_tokens=total_tokens, total_latency_ms=total_latency_ms, success=True, user_id=req.user_id ) return AgentResponse(text=llm_out, decision=decision, reasoning_trace=reasoning_trace) except Exception as e: tool_traces.append({"tool": decision.tool, "error": str(e)}) try: fallback = await self.llm.simple_call(req.message, temperature=req.temperature) except Exception as llm_error: error_msg = str(llm_error) if "Cannot connect" in error_msg or "Ollama" in error_msg: fallback = ( f"I encountered an error while processing your request: {str(e)}\n\n" f"Additionally, the AI service (Ollama) is unavailable: {error_msg}\n\n" f"To fix:\n" f"1. Install Ollama from https://ollama.ai\n" f"2. Start: `ollama serve`\n" f"3. Pull model: `ollama pull {os.getenv('OLLAMA_MODEL', 'llama3.1:latest')}`" ) else: fallback = f"I encountered an error while processing your request: {str(e)}. Additionally, the AI service is unavailable: {error_msg}" return AgentResponse( text=fallback, decision=AgentDecision(action="respond", tool=None, tool_input=None, reason=f"tool_error_fallback: {e}"), tool_traces=tool_traces, reasoning_trace=reasoning_trace + [{ "step": "error", "tool": decision.tool, "error": str(e) }] ) # Default: direct LLM response try: llm_start = time.time() llm_out = await self.llm.simple_call(req.message, temperature=req.temperature) llm_latency_ms = int((time.time() - llm_start) * 1000) tools_used = ["llm"] estimated_tokens = len(llm_out) // 4 + len(req.message) // 4 self.analytics.log_tool_usage( tenant_id=req.tenant_id, tool_name="llm", latency_ms=llm_latency_ms, tokens_used=estimated_tokens, success=True, user_id=req.user_id ) except Exception as e: # If LLM fails, return a helpful error message error_msg = str(e) if "Cannot connect" in error_msg or "Ollama" in error_msg: llm_out = ( f"I couldn't connect to the AI service (Ollama). " f"Error: {error_msg}\n\n" f"To fix this:\n" f"1. Install Ollama from https://ollama.ai\n" f"2. Start Ollama: `ollama serve`\n" f"3. Pull the model: `ollama pull {os.getenv('OLLAMA_MODEL', 'llama3.1:latest')}`\n" f"4. Or set OLLAMA_URL and OLLAMA_MODEL in your .env file" ) else: llm_out = f"I apologize, but I'm unable to process your request right now. The AI service is unavailable: {error_msg}" self.analytics.log_tool_usage( tenant_id=req.tenant_id, tool_name="llm", success=False, error_message=error_msg[:200], user_id=req.user_id ) reasoning_trace.append({ "step": "error", "tool": "llm", "error": str(e) }) total_latency_ms = int((time.time() - start_time) * 1000) self.analytics.log_agent_query( tenant_id=req.tenant_id, message_preview=req.message[:200], intent=intent, tools_used=tools_used if 'tools_used' in locals() else [], total_tokens=estimated_tokens if 'estimated_tokens' in locals() else 0, total_latency_ms=total_latency_ms, success=True if 'llm_out' in locals() else False, user_id=req.user_id ) return AgentResponse( text=llm_out, decision=AgentDecision(action="respond", tool=None, tool_input=None, reason="default_llm"), reasoning_trace=reasoning_trace ) def _build_prompt_with_rag(self, req: AgentRequest, rag_resp: Dict[str, Any]) -> str: snippets = [] if isinstance(rag_resp, dict): hits = rag_resp.get("results") or rag_resp.get("hits") or [] for h in hits[:5]: txt = h.get("text") or h.get("content") or str(h) snippets.append(txt) snippet_text = "\n---\n".join(snippets) or "" prompt = ( f"You are an assistant helping tenant {req.tenant_id}. Use the following retrieved documents to answer the user's question.\n" f"Documents:\n{snippet_text}\n\n" f"User question: {req.message}\nProvide a concise, accurate answer and cite the source snippets where appropriate." ) return prompt async def _execute_multi_step(self, req: AgentRequest, steps: List[Dict[str, Any]], decision: AgentDecision, tool_traces: List[Dict[str, Any]], reasoning_trace: List[Dict[str, Any]], pre_fetched_rag: Optional[Dict[str, Any]] = None) -> AgentResponse: """ Execute multiple tools in sequence and synthesize results with LLM. """ rag_data = None web_data = None admin_data = None collected_data = [] parallel_tasks = {} rag_parallel_query = self._first_query_for_tool(steps, "rag", req.message) web_parallel_query = self._first_query_for_tool(steps, "web", req.message) if rag_parallel_query and web_parallel_query and rag_parallel_query == web_parallel_query: if not pre_fetched_rag: parallel_tasks["rag"] = asyncio.create_task(self.mcp.call_rag(req.tenant_id, rag_parallel_query)) parallel_tasks["web"] = asyncio.create_task(self.mcp.call_web(req.tenant_id, web_parallel_query)) # Execute each step in sequence for step_info in steps: tool_name = step_info.get("tool") step_input = step_info.get("input") or {} query = step_input.get("query") or req.message try: if tool_name == "rag": # Reuse pre-fetched RAG if available, otherwise fetch if pre_fetched_rag and query == rag_parallel_query: rag_resp = pre_fetched_rag tool_traces.append({"tool": "rag", "response": rag_resp, "note": "used_pre_fetched"}) elif parallel_tasks.get("rag") and query == rag_parallel_query: rag_resp = await parallel_tasks["rag"] tool_traces.append({"tool": "rag", "response": rag_resp, "note": "parallel"}) else: rag_resp = await self.mcp.call_rag(req.tenant_id, query) tool_traces.append({"tool": "rag", "response": rag_resp}) rag_data = rag_resp reasoning_trace.append({ "step": "tool_execution", "tool": "rag", "hit_count": len(self._extract_hits(rag_resp)), "summary": self._summarize_hits(rag_resp, limit=2) }) # Extract snippets for prompt if isinstance(rag_resp, dict): hits = rag_resp.get("results") or rag_resp.get("hits") or [] for h in hits[:5]: txt = h.get("text") or h.get("content") or str(h) collected_data.append(f"[RAG] {txt}") elif tool_name == "web": if parallel_tasks.get("web") and query == web_parallel_query: web_resp = await parallel_tasks["web"] tool_traces.append({"tool": "web", "response": web_resp, "note": "parallel"}) else: web_resp = await self.mcp.call_web(req.tenant_id, query) tool_traces.append({"tool": "web", "response": web_resp}) web_data = web_resp reasoning_trace.append({ "step": "tool_execution", "tool": "web", "hit_count": len(self._extract_hits(web_resp)), "summary": self._summarize_hits(web_resp, limit=2) }) # Extract snippets for prompt if isinstance(web_resp, dict): hits = web_resp.get("results") or web_resp.get("items") or [] for h in hits[:5]: title = h.get("title") or h.get("headline") or "" snippet = h.get("snippet") or h.get("summary") or h.get("text") or "" url = h.get("url") or h.get("link") or "" collected_data.append(f"[WEB] {title}\n{snippet}\nSource: {url}") elif tool_name == "admin": admin_resp = await self.mcp.call_admin(req.tenant_id, query) tool_traces.append({"tool": "admin", "response": admin_resp}) admin_data = admin_resp collected_data.append(f"[ADMIN] {json.dumps(admin_resp)}") reasoning_trace.append({ "step": "tool_execution", "tool": "admin", "status": "completed" }) elif tool_name == "llm": # LLM is always last - synthesize all collected data break except Exception as e: tool_traces.append({"tool": tool_name, "error": str(e)}) # Continue with other tools even if one fails reasoning_trace.append({ "step": "error", "tool": tool_name, "error": str(e) }) # Build comprehensive prompt with all collected data data_section = "\n---\n".join(collected_data) if collected_data else "" if data_section: prompt = ( f"You are an assistant helping tenant {req.tenant_id}.\n\n" f"The following information has been gathered from multiple sources:\n\n" f"{data_section}\n\n" f"User question: {req.message}\n\n" f"Provide a comprehensive, accurate answer using the information above. " f"Cite sources where appropriate (RAG for internal docs, WEB for online sources)." ) else: # No data collected, just answer the question prompt = req.message # Final LLM synthesis try: llm_out = await self.llm.simple_call(prompt, temperature=req.temperature) return AgentResponse( text=llm_out, decision=decision, tool_traces=tool_traces, reasoning_trace=reasoning_trace + [{ "step": "llm_response", "mode": "multi_step" }] ) except Exception as e: tool_traces.append({"tool": "llm", "error": str(e)}) error_msg = str(e) # Provide helpful error message if "Cannot connect" in error_msg or "Ollama" in error_msg: fallback = ( f"I couldn't connect to the AI service (Ollama). " f"Error: {error_msg}\n\n" f"To fix this:\n" f"1. Install Ollama from https://ollama.ai\n" f"2. Start Ollama: `ollama serve`\n" f"3. Pull the model: `ollama pull {os.getenv('OLLAMA_MODEL', 'llama3.1:latest')}`\n" f"4. Or set OLLAMA_URL and OLLAMA_MODEL in your .env file" ) else: fallback = f"I encountered an error while synthesizing the response: {error_msg}" return AgentResponse( text=fallback, decision=AgentDecision( action="respond", tool=None, tool_input=None, reason=f"multi_step_llm_error: {e}" ), tool_traces=tool_traces, reasoning_trace=reasoning_trace + [{ "step": "error", "tool": "llm", "error": str(e) }] ) def _build_prompt_with_web(self, req: AgentRequest, web_resp: Dict[str, Any]) -> str: snippets = [] if isinstance(web_resp, dict): hits = web_resp.get("results") or web_resp.get("items") or [] for h in hits[:5]: title = h.get("title") or h.get("headline") or "" snippet = h.get("snippet") or h.get("summary") or h.get("text") or "" url = h.get("url") or h.get("link") or "" snippets.append(f"{title}\n{snippet}\nSource: {url}") snippet_text = "\n---\n".join(snippets) or "" prompt = ( f"You are an assistant with access to recent web search results. Use the following results to answer.\n{snippet_text}\n\n" f"User question: {req.message}\nAnswer succinctly and indicate which results you used." ) return prompt @staticmethod def _extract_hits(resp: Optional[Dict[str, Any]]) -> List[Dict[str, Any]]: if not isinstance(resp, dict): return [] return resp.get("results") or resp.get("hits") or resp.get("items") or [] def _summarize_hits(self, resp: Optional[Dict[str, Any]], limit: int = 3) -> List[str]: hits = self._extract_hits(resp) summaries = [] for hit in hits[:limit]: if isinstance(hit, dict): snippet = hit.get("text") or hit.get("content") or hit.get("snippet") or "" else: snippet = str(hit) summaries.append(snippet[:160]) return summaries @staticmethod def _first_query_for_tool(steps: List[Dict[str, Any]], tool_name: str, default_query: str) -> Optional[str]: for step in steps: if step.get("tool") == tool_name: input_data = step.get("input") or {} return input_data.get("query") or default_query return None