# ============================================================= # File: backend/api/services/agent_orchestrator.py # ============================================================= """ Agent Orchestrator (integrated with enterprise RedFlagDetector) Place at: backend/api/services/agent_orchestrator.py """ from __future__ import annotations import asyncio import json import os from typing import List, Dict, Any, Optional import logging from ..models.agent import AgentRequest, AgentDecision, AgentResponse from ..models.redflag import RedFlagMatch from .redflag_detector import RedFlagDetector from .intent_classifier import IntentClassifier from .tool_selector import ToolSelector from .llm_client import LLMClient from ..mcp_clients.mcp_client import MCPClient from .tool_scoring import ToolScoringService from ..storage.analytics_store import AnalyticsStore from .result_merger import merge_parallel_results, format_merged_context_for_prompt import time logger = logging.getLogger(__name__) from dotenv import load_dotenv load_dotenv() class AgentOrchestrator: def __init__(self, rag_mcp_url: str, web_mcp_url: str, admin_mcp_url: str, llm_backend: str = "ollama"): self.mcp = MCPClient(rag_mcp_url, web_mcp_url, admin_mcp_url) self.llm = LLMClient(backend=llm_backend, url=os.getenv("OLLAMA_URL"), api_key=os.getenv("GROQ_API_KEY"), model=os.getenv("OLLAMA_MODEL")) # pass admin_mcp_url so detector can call back self.redflag = RedFlagDetector( supabase_url=os.getenv("SUPABASE_URL"), supabase_key=os.getenv("SUPABASE_SERVICE_KEY"), admin_mcp_url=admin_mcp_url ) self.intent = IntentClassifier(llm_client=self.llm) self.selector = ToolSelector(llm_client=self.llm) self.tool_scorer = ToolScoringService() self._analytics: Optional[AnalyticsStore] = None self._analytics_disabled = os.getenv("ANALYTICS_DISABLED", "").lower() in {"1", "true", "yes"} self._analytics_failed = False self._log_analytics_backend_once() def _log_analytics_backend_once(self) -> None: if getattr(AgentOrchestrator, "_analytics_backend_logged", False): return if self._analytics_disabled: print("⚠️ AgentOrchestrator Analytics: Disabled via ANALYTICS_DISABLED") else: store = self._get_analytics() if store is None: print("⚠️ AgentOrchestrator Analytics: Disabled (Supabase not configured)") elif store.use_supabase: print("✅ AgentOrchestrator Analytics: Using Supabase backend") else: print("⚠️ AgentOrchestrator Analytics: Using fallback backend") AgentOrchestrator._analytics_backend_logged = True def _get_analytics(self) -> Optional[AnalyticsStore]: if self._analytics_disabled or self._analytics_failed: return None if self._analytics is not None: return self._analytics try: self._analytics = AnalyticsStore() except RuntimeError as exc: logger.warning("AgentOrchestrator analytics disabled: %s", exc) self._analytics_failed = True self._analytics = None except Exception as exc: # pragma: no cover - unexpected initialization failures logger.debug("AgentOrchestrator analytics unexpected init failure: %s", exc) self._analytics_failed = True self._analytics = None return self._analytics def _analytics_log_tool_usage(self, **kwargs: Any) -> None: analytics = self._get_analytics() if not analytics: return try: analytics.log_tool_usage(**kwargs) except Exception as exc: # pragma: no cover - analytics failures should not break flow logger.debug("AgentOrchestrator tool analytics failed: %s", exc) def _analytics_log_agent_query(self, **kwargs: Any) -> None: analytics = self._get_analytics() if not analytics: return try: analytics.log_agent_query(**kwargs) except Exception as exc: # pragma: no cover logger.debug("AgentOrchestrator agent query analytics failed: %s", exc) def _analytics_log_rag_search(self, **kwargs: Any) -> None: analytics = self._get_analytics() if not analytics: return try: analytics.log_rag_search(**kwargs) except Exception as exc: # pragma: no cover logger.debug("AgentOrchestrator RAG analytics failed: %s", exc) def _analytics_log_redflag_violation(self, **kwargs: Any) -> None: analytics = self._get_analytics() if not analytics: return try: analytics.log_redflag_violation(**kwargs) except Exception as exc: # pragma: no cover logger.debug("AgentOrchestrator redflag analytics failed: %s", exc) async def handle(self, req: AgentRequest) -> AgentResponse: start_time = time.time() reasoning_trace: List[Dict[str, Any]] = [] reasoning_trace.append({ "step": "request_received", "tenant_id": req.tenant_id, "user_id": req.user_id, "message_preview": req.message[:120] }) # 1) FIRST: Check admin rules - if any rule matches, respond according to rule matches: List[RedFlagMatch] = await self.redflag.check(req.tenant_id, req.message) reasoning_trace.append({ "step": "admin_rules_check", "match_count": len(matches), "matches": [m.__dict__ for m in matches] }) if matches: # Log all rule matches for match in matches: self._analytics_log_redflag_violation( tenant_id=req.tenant_id, rule_id=match.rule_id, rule_pattern=match.pattern, severity=match.severity, matched_text=match.matched_text, confidence=match.confidence, message_preview=req.message[:200], user_id=req.user_id ) # Categorize rules: brief response rules vs blocking rules brief_response_rules = [] blocking_rules = [] for match in matches: rule_text = (match.description or match.pattern or "").lower() is_brief_rule = ( match.severity == "low" and ( "greeting" in rule_text or "brief" in rule_text or "simple response" in rule_text or "keep.*response.*brief" in rule_text or "do not.*verbose" in rule_text or "respond.*briefly" in rule_text ) ) if is_brief_rule: brief_response_rules.append(match) else: blocking_rules.append(match) # Handle brief response rules (greetings, etc.) - return immediately if brief_response_rules and not blocking_rules: # Return brief response without proceeding to normal flow brief_responses = [ "Hello! How can I help you today?", "Hi there! What can I assist you with?", "Hello! I'm here to help. What do you need?", "Hi! How can I assist you?" ] import random brief_response = random.choice(brief_responses) reasoning_trace.append({ "step": "brief_response_rule_matched", "action": "brief_response", "matched_rules": [m.rule_id for m in brief_response_rules], "message": "Brief response rule matched, returning brief response (skipping normal flow)" }) total_latency_ms = int((time.time() - start_time) * 1000) self._analytics_log_agent_query( tenant_id=req.tenant_id, message_preview=req.message[:200], intent="greeting", tools_used=[], total_tokens=len(brief_response) // 4, total_latency_ms=total_latency_ms, success=True, user_id=req.user_id ) return AgentResponse( text=brief_response, decision=AgentDecision(action="respond", tool=None, tool_input=None, reason="brief_response_rule"), reasoning_trace=reasoning_trace ) # Handle blocking rules (security, compliance, etc.) - block and return immediately if blocking_rules: # Notify admin asynchronously try: await self.redflag.notify_admin(req.tenant_id, blocking_rules, source_payload={"message": req.message, "user_id": req.user_id}) except Exception: pass decision = AgentDecision( action="block", tool="admin", tool_input={"violations": [m.__dict__ for m in blocking_rules]}, reason="admin_rule_violation" ) # Build detailed prompt for LLM to generate natural red flag response violations_details = [] for i, m in enumerate(blocking_rules, 1): rule_name = m.description or m.pattern or "Policy violation" detail = f"{i}. **{rule_name}** (Severity: {m.severity})" if m.matched_text: detail += f"\n - Detected phrase: \"{m.matched_text}\"" violations_details.append(detail) llm_prompt = f"""A user made the following request: "{req.message}" However, this request violates company policies. The following policy violations were detected: {chr(10).join(violations_details)} Your task: Write a clear, professional, and empathetic response to inform the user that: 1. Their request cannot be processed due to policy violations 2. Which specific policy was violated (mention it naturally) 3. The incident has been logged for security review 4. They should contact an administrator if they need assistance or believe this is an error Write a natural, conversational response (2-4 sentences) that feels helpful rather than robotic. Be professional but understanding. Response:""" # Generate LLM response for red flag try: llm_response = await self.llm.simple_call(llm_prompt, temperature=min(req.temperature + 0.2, 0.7)) # Slightly higher temp for more natural response llm_response = llm_response.strip() # Add warning emoji if not present if not llm_response.startswith("⚠️") and not llm_response.startswith("🚨"): llm_response = f"⚠️ {llm_response}" except Exception as e: # Fallback to a simple message if LLM fails summary = "; ".join( f"{m.description or m.pattern}" for m in matches ) llm_response = f"⚠️ I'm unable to process your request because it violates our company policy: {summary}. This incident has been logged. Please contact your administrator if you need assistance." total_latency_ms = int((time.time() - start_time) * 1000) # Log LLM usage for red flag response estimated_tokens = len(llm_response) // 4 + len(llm_prompt) // 4 self._analytics_log_tool_usage( tenant_id=req.tenant_id, tool_name="llm", latency_ms=total_latency_ms, tokens_used=estimated_tokens, success=True, user_id=req.user_id ) self._analytics_log_agent_query( tenant_id=req.tenant_id, message_preview=req.message[:200], intent="admin", tools_used=["admin", "llm"], total_tokens=estimated_tokens, total_latency_ms=total_latency_ms, success=False, user_id=req.user_id ) return AgentResponse( text=llm_response, decision=decision, tool_traces=[{"redflags": [m.__dict__ for m in blocking_rules]}], reasoning_trace=reasoning_trace ) # 2) ONLY IF NO RULES MATCHED: Proceed with normal flow (intent classification, RAG, etc.) # 2.1) Optional: Try to rewrite message if it might violate rules (preventive self-correction) # Note: This is a lighter check - we already blocked above if rules matched # This is for edge cases where we want to proactively improve the message safe_message = req.message # Default to original intent = await self.intent.classify(req.message) reasoning_trace.append({ "step": "intent_detection", "intent": intent }) # 2.5) Pre-fetch RAG results if available (for tool selector context) rag_prefetch = None rag_results = [] try: # Try to pre-fetch RAG to help tool selector make better decisions rag_start = time.time() rag_prefetch = await self.mcp.call_rag(req.tenant_id, req.message) rag_latency_ms = int((time.time() - rag_start) * 1000) if isinstance(rag_prefetch, dict): rag_results = rag_prefetch.get("results") or rag_prefetch.get("hits") or [] # Log RAG search event hits_count = len(rag_results) avg_score = None top_score = None if rag_results: scores = [h.get("score", 0.0) for h in rag_results if isinstance(h, dict) and "score" in h] if scores: avg_score = sum(scores) / len(scores) top_score = max(scores) self._analytics_log_rag_search( tenant_id=req.tenant_id, query=req.message[:500], hits_count=hits_count, avg_score=avg_score, top_score=top_score, latency_ms=rag_latency_ms ) # Log tool usage self._analytics_log_tool_usage( tenant_id=req.tenant_id, tool_name="rag", latency_ms=rag_latency_ms, success=True, user_id=req.user_id ) reasoning_trace.append({ "step": "rag_prefetch", "status": "ok", "hit_count": len(rag_results), "latency_ms": rag_latency_ms }) except Exception as pref_err: # If RAG fails, continue without it rag_latency_ms = 0 # 0 for failed self._analytics_log_tool_usage( tenant_id=req.tenant_id, tool_name="rag", latency_ms=rag_latency_ms, success=False, error_message=str(pref_err)[:200], user_id=req.user_id ) reasoning_trace.append({ "step": "rag_prefetch", "status": "error", "error": str(pref_err) }) rag_prefetch = None tool_scores = self.tool_scorer.score(req.message, intent, rag_results) reasoning_trace.append({ "step": "tool_scoring", "scores": tool_scores }) # 3) Tool selection (hybrid) - pass RAG results in context ctx = { "tenant_id": req.tenant_id, "rag_results": rag_results, "tool_scores": tool_scores } decision = await self.selector.select(intent, req.message, ctx) reasoning_trace.append({ "step": "tool_selection", "decision": decision.dict(), "context_scores": tool_scores }) tool_traces: List[Dict[str, Any]] = [] # 4) Handle multi-step tool execution if decision.action == "multi_step" and decision.tool_input: steps = decision.tool_input.get("steps", []) if steps: return await self._execute_multi_step( req, steps, decision, tool_traces, reasoning_trace, rag_prefetch ) # 5) Execute single tool tools_used = [] total_tokens = 0 if decision.action == "call_tool" and decision.tool: try: if decision.tool == "rag": # Use autonomous retry with self-correction rag_query = decision.tool_input.get("query") if decision.tool_input else req.message rag_resp = await self.rag_with_repair( query=rag_query, tenant_id=req.tenant_id, original_threshold=0.3, reasoning_trace=reasoning_trace, user_id=req.user_id ) tools_used.append("rag") tool_traces.append({"tool": "rag", "response": rag_resp}) hits = self._extract_hits(rag_resp) # Calculate scores for logging hits_count = len(hits) avg_score = None top_score = None if hits: scores = [h.get("score", 0.0) for h in hits if isinstance(h, dict) and "score" in h] if scores: avg_score = sum(scores) / len(scores) top_score = max(scores) reasoning_trace.append({ "step": "tool_execution", "tool": "rag", "hit_count": hits_count, "top_score": top_score, "avg_score": avg_score, "summary": self._summarize_hits(rag_resp, limit=2) }) prompt = self._build_prompt_with_rag(req, rag_resp) llm_start = time.time() llm_out = await self.llm.simple_call(prompt, temperature=req.temperature) llm_latency_ms = int((time.time() - llm_start) * 1000) tools_used.append("llm") # Estimate tokens (rough: ~4 chars per token) estimated_tokens = len(llm_out) // 4 + len(prompt) // 4 total_tokens += estimated_tokens self._analytics_log_tool_usage( tenant_id=req.tenant_id, tool_name="llm", latency_ms=llm_latency_ms, tokens_used=estimated_tokens, success=True, user_id=req.user_id ) reasoning_trace.append({ "step": "llm_response", "mode": "rag_synthesis", "latency_ms": llm_latency_ms, "estimated_tokens": estimated_tokens }) total_latency_ms = int((time.time() - start_time) * 1000) self._analytics_log_agent_query( tenant_id=req.tenant_id, message_preview=req.message[:200], intent=intent, tools_used=tools_used, total_tokens=total_tokens, total_latency_ms=total_latency_ms, success=True, user_id=req.user_id ) return AgentResponse(text=llm_out, decision=decision, tool_traces=tool_traces, reasoning_trace=reasoning_trace) if decision.tool == "web": # Use autonomous retry with query rewriting web_query = decision.tool_input.get("query") if decision.tool_input else req.message web_resp = await self.web_with_repair( query=web_query, tenant_id=req.tenant_id, reasoning_trace=reasoning_trace, user_id=req.user_id ) tools_used.append("web") tool_traces.append({"tool": "web", "response": web_resp}) hits_count = len(self._extract_hits(web_resp)) reasoning_trace.append({ "step": "tool_execution", "tool": "web", "hit_count": hits_count, "summary": self._summarize_hits(web_resp, limit=2) }) prompt = self._build_prompt_with_web(req, web_resp) llm_start = time.time() llm_out = await self.llm.simple_call(prompt, temperature=req.temperature) llm_latency_ms = int((time.time() - llm_start) * 1000) tools_used.append("llm") estimated_tokens = len(llm_out) // 4 + len(prompt) // 4 total_tokens += estimated_tokens self._analytics_log_tool_usage( tenant_id=req.tenant_id, tool_name="llm", latency_ms=llm_latency_ms, tokens_used=estimated_tokens, success=True, user_id=req.user_id ) reasoning_trace.append({ "step": "llm_response", "mode": "web_synthesis", "latency_ms": llm_latency_ms, "estimated_tokens": estimated_tokens }) total_latency_ms = int((time.time() - start_time) * 1000) self._analytics_log_agent_query( tenant_id=req.tenant_id, message_preview=req.message[:200], intent=intent, tools_used=tools_used, total_tokens=total_tokens, total_latency_ms=total_latency_ms, success=True, user_id=req.user_id ) return AgentResponse(text=llm_out, decision=decision, tool_traces=tool_traces, reasoning_trace=reasoning_trace) if decision.tool == "admin": admin_start = time.time() admin_resp = await self.mcp.call_admin(req.tenant_id, decision.tool_input.get("query") if decision.tool_input else req.message) admin_latency_ms = int((time.time() - admin_start) * 1000) tools_used.append("admin") self._analytics_log_tool_usage( tenant_id=req.tenant_id, tool_name="admin", latency_ms=admin_latency_ms, success=True, user_id=req.user_id ) tool_traces.append({"tool": "admin", "response": admin_resp}) reasoning_trace.append({ "step": "tool_execution", "tool": "admin", "status": "completed", "latency_ms": admin_latency_ms }) total_latency_ms = int((time.time() - start_time) * 1000) self._analytics_log_agent_query( tenant_id=req.tenant_id, message_preview=req.message[:200], intent=intent, tools_used=tools_used, total_tokens=0, total_latency_ms=total_latency_ms, success=True, user_id=req.user_id ) return AgentResponse(text=json.dumps(admin_resp), decision=decision, tool_traces=tool_traces, reasoning_trace=reasoning_trace) if decision.tool == "llm": llm_start = time.time() llm_out = await self.llm.simple_call(req.message, temperature=req.temperature) llm_latency_ms = int((time.time() - llm_start) * 1000) tools_used.append("llm") estimated_tokens = len(llm_out) // 4 + len(req.message) // 4 total_tokens += estimated_tokens self._analytics_log_tool_usage( tenant_id=req.tenant_id, tool_name="llm", latency_ms=llm_latency_ms, tokens_used=estimated_tokens, success=True, user_id=req.user_id ) reasoning_trace.append({ "step": "llm_response", "mode": "direct", "latency_ms": llm_latency_ms, "estimated_tokens": estimated_tokens }) total_latency_ms = int((time.time() - start_time) * 1000) self._analytics_log_agent_query( tenant_id=req.tenant_id, message_preview=req.message[:200], intent=intent, tools_used=tools_used, total_tokens=total_tokens, total_latency_ms=total_latency_ms, success=True, user_id=req.user_id ) return AgentResponse(text=llm_out, decision=decision, reasoning_trace=reasoning_trace) except Exception as e: tool_traces.append({"tool": decision.tool, "error": str(e)}) try: fallback = await self.llm.simple_call(req.message, temperature=req.temperature) except Exception as llm_error: error_msg = str(llm_error) if "Cannot connect" in error_msg or "Ollama" in error_msg: fallback = ( f"I encountered an error while processing your request: {str(e)}\n\n" f"Additionally, the AI service (Ollama) is unavailable: {error_msg}\n\n" f"To fix:\n" f"1. Install Ollama from https://ollama.ai\n" f"2. Start: `ollama serve`\n" f"3. Pull model: `ollama pull {os.getenv('OLLAMA_MODEL', 'llama3.1:latest')}`" ) else: fallback = f"I encountered an error while processing your request: {str(e)}. Additionally, the AI service is unavailable: {error_msg}" return AgentResponse( text=fallback, decision=AgentDecision(action="respond", tool=None, tool_input=None, reason=f"tool_error_fallback: {e}"), tool_traces=tool_traces, reasoning_trace=reasoning_trace + [{ "step": "error", "tool": decision.tool, "error": str(e) }] ) # Default: direct LLM response try: llm_start = time.time() llm_out = await self.llm.simple_call(req.message, temperature=req.temperature) llm_latency_ms = int((time.time() - llm_start) * 1000) tools_used = ["llm"] estimated_tokens = len(llm_out) // 4 + len(req.message) // 4 self._analytics_log_tool_usage( tenant_id=req.tenant_id, tool_name="llm", latency_ms=llm_latency_ms, tokens_used=estimated_tokens, success=True, user_id=req.user_id ) except Exception as e: # If LLM fails, return a helpful error message error_msg = str(e) if "Cannot connect" in error_msg or "Ollama" in error_msg: llm_out = ( f"I couldn't connect to the AI service (Ollama). " f"Error: {error_msg}\n\n" f"To fix this:\n" f"1. Install Ollama from https://ollama.ai\n" f"2. Start Ollama: `ollama serve`\n" f"3. Pull the model: `ollama pull {os.getenv('OLLAMA_MODEL', 'llama3.1:latest')}`\n" f"4. Or set OLLAMA_URL and OLLAMA_MODEL in your .env file" ) else: llm_out = f"I apologize, but I'm unable to process your request right now. The AI service is unavailable: {error_msg}" self._analytics_log_tool_usage( tenant_id=req.tenant_id, tool_name="llm", success=False, error_message=error_msg[:200], user_id=req.user_id ) reasoning_trace.append({ "step": "error", "tool": "llm", "error": str(e) }) total_latency_ms = int((time.time() - start_time) * 1000) self._analytics_log_agent_query( tenant_id=req.tenant_id, message_preview=req.message[:200], intent=intent, tools_used=tools_used if 'tools_used' in locals() else [], total_tokens=estimated_tokens if 'estimated_tokens' in locals() else 0, total_latency_ms=total_latency_ms, success=True if 'llm_out' in locals() else False, user_id=req.user_id ) return AgentResponse( text=llm_out, decision=AgentDecision(action="respond", tool=None, tool_input=None, reason="default_llm"), reasoning_trace=reasoning_trace ) def _build_prompt_with_rag(self, req: AgentRequest, rag_resp: Dict[str, Any]) -> str: snippets = [] if isinstance(rag_resp, dict): hits = rag_resp.get("results") or rag_resp.get("hits") or [] for h in hits[:5]: txt = h.get("text") or h.get("content") or str(h) snippets.append(txt) snippet_text = "\n---\n".join(snippets) or "" prompt = ( f"You are an assistant helping tenant {req.tenant_id}. Use the following retrieved documents to answer the user's question.\n" f"Documents:\n{snippet_text}\n\n" f"User question: {req.message}\nProvide a concise, accurate answer and cite the source snippets where appropriate." ) return prompt async def _execute_multi_step(self, req: AgentRequest, steps: List[Dict[str, Any]], decision: AgentDecision, tool_traces: List[Dict[str, Any]], reasoning_trace: List[Dict[str, Any]], pre_fetched_rag: Optional[Dict[str, Any]] = None) -> AgentResponse: """ Execute multiple tools in sequence or parallel and synthesize results with LLM. Supports parallel execution when steps are marked with "parallel" flag. """ start_time = time.time() rag_data = None web_data = None admin_data = None collected_data = [] tools_used = [] total_tokens = 0 # Check if any step has parallel execution flag parallel_step = None for step_info in steps: if step_info.get("parallel"): parallel_step = step_info break # Handle parallel execution if detected if parallel_step and parallel_step.get("parallel"): parallel_config = parallel_step.get("parallel") parallel_tasks = {} start_time_parallel = time.time() # Prepare parallel tasks with retry logic if "rag" in parallel_config: rag_query = parallel_config["rag"] if pre_fetched_rag: # Use pre-fetched RAG if available - create a simple async function async def get_prefetched_rag(): return pre_fetched_rag parallel_tasks["rag"] = get_prefetched_rag() else: # Wrap with retry logic for parallel execution async def rag_with_retry_wrapper(): return await self.rag_with_repair( query=rag_query, tenant_id=req.tenant_id, original_threshold=0.3, reasoning_trace=reasoning_trace, user_id=req.user_id ) parallel_tasks["rag"] = rag_with_retry_wrapper() if "web" in parallel_config: web_query = parallel_config["web"] # Wrap with retry logic for parallel execution async def web_with_retry_wrapper(): return await self.web_with_repair( query=web_query, tenant_id=req.tenant_id, reasoning_trace=reasoning_trace, user_id=req.user_id ) parallel_tasks["web"] = web_with_retry_wrapper() # Execute tools in parallel if parallel_tasks: reasoning_trace.append({ "step": "parallel_execution", "tools": list(parallel_tasks.keys()), "mode": "parallel" }) parallel_results = await self.run_parallel_tools(parallel_tasks) parallel_latency_ms = int((time.time() - start_time_parallel) * 1000) # Process RAG results if "rag" in parallel_results: rag_result = parallel_results["rag"] if isinstance(rag_result, Exception): tool_traces.append({"tool": "rag", "error": str(rag_result), "note": "parallel"}) reasoning_trace.append({ "step": "tool_execution", "tool": "rag", "status": "error", "error": str(rag_result), "latency_ms": parallel_latency_ms }) self._analytics_log_tool_usage( tenant_id=req.tenant_id, tool_name="rag", latency_ms=parallel_latency_ms, success=False, error_message=str(rag_result)[:200], user_id=req.user_id ) else: rag_data = rag_result tools_used.append("rag") tool_traces.append({"tool": "rag", "response": rag_result, "note": "parallel"}) hits_count = len(self._extract_hits(rag_result)) avg_score = None top_score = None if hits_count > 0: scores = [h.get("score", 0.0) for h in self._extract_hits(rag_result) if isinstance(h, dict) and "score" in h] if scores: avg_score = sum(scores) / len(scores) top_score = max(scores) self._analytics_log_rag_search( tenant_id=req.tenant_id, query=req.message[:500], hits_count=hits_count, avg_score=avg_score, top_score=top_score, latency_ms=parallel_latency_ms ) self._analytics_log_tool_usage( tenant_id=req.tenant_id, tool_name="rag", latency_ms=parallel_latency_ms, success=True, user_id=req.user_id ) reasoning_trace.append({ "step": "tool_execution", "tool": "rag", "hit_count": hits_count, "summary": self._summarize_hits(rag_result, limit=2), "latency_ms": parallel_latency_ms, "mode": "parallel" }) # Process Web results if "web" in parallel_results: web_result = parallel_results["web"] if isinstance(web_result, Exception): tool_traces.append({"tool": "web", "error": str(web_result), "note": "parallel"}) reasoning_trace.append({ "step": "tool_execution", "tool": "web", "status": "error", "error": str(web_result), "latency_ms": parallel_latency_ms }) self._analytics_log_tool_usage( tenant_id=req.tenant_id, tool_name="web", latency_ms=parallel_latency_ms, success=False, error_message=str(web_result)[:200], user_id=req.user_id ) else: web_data = web_result tools_used.append("web") tool_traces.append({"tool": "web", "response": web_result, "note": "parallel"}) hits_count = len(self._extract_hits(web_result)) self._analytics_log_tool_usage( tenant_id=req.tenant_id, tool_name="web", latency_ms=parallel_latency_ms, success=True, user_id=req.user_id ) reasoning_trace.append({ "step": "tool_execution", "tool": "web", "hit_count": hits_count, "summary": self._summarize_hits(web_result, limit=2), "latency_ms": parallel_latency_ms, "mode": "parallel" }) # Merge parallel results merged_context = merge_parallel_results(parallel_results) sources_list = list(set(e.get("source") for e in merged_context if e.get("source"))) if merged_context else [] reasoning_trace.append({ "step": "result_merger", "merged_items": len(merged_context), "sources": sources_list }) # Format merged context for prompt data_section = format_merged_context_for_prompt(merged_context, max_items=10) else: data_section = "" else: # Sequential execution (original logic) parallel_tasks = {} rag_parallel_query = self._first_query_for_tool(steps, "rag", req.message) web_parallel_query = self._first_query_for_tool(steps, "web", req.message) if rag_parallel_query and web_parallel_query and rag_parallel_query == web_parallel_query: if not pre_fetched_rag: parallel_tasks["rag"] = asyncio.create_task(self.mcp.call_rag(req.tenant_id, rag_parallel_query)) parallel_tasks["web"] = asyncio.create_task(self.mcp.call_web(req.tenant_id, web_parallel_query)) # Execute each step in sequence for step_info in steps: tool_name = step_info.get("tool") step_input = step_info.get("input") or {} query = step_input.get("query") or req.message try: if tool_name == "rag": # Reuse pre-fetched RAG if available, otherwise fetch with retry if pre_fetched_rag and query == rag_parallel_query: rag_resp = pre_fetched_rag tool_traces.append({"tool": "rag", "response": rag_resp, "note": "used_pre_fetched"}) elif parallel_tasks.get("rag") and query == rag_parallel_query: rag_resp = await parallel_tasks["rag"] tool_traces.append({"tool": "rag", "response": rag_resp, "note": "parallel"}) else: # Use autonomous retry with self-correction rag_resp = await self.rag_with_repair( query=query, tenant_id=req.tenant_id, original_threshold=0.3, reasoning_trace=reasoning_trace, user_id=req.user_id ) tool_traces.append({"tool": "rag", "response": rag_resp, "note": "with_retry"}) rag_data = rag_resp tools_used.append("rag") hits = self._extract_hits(rag_resp) reasoning_trace.append({ "step": "tool_execution", "tool": "rag", "hit_count": len(hits), "summary": self._summarize_hits(rag_resp, limit=2) }) # Extract snippets for prompt if isinstance(rag_resp, dict): for h in hits[:5]: txt = h.get("text") or h.get("content") or str(h) collected_data.append(f"[RAG] {txt}") elif tool_name == "web": if parallel_tasks.get("web") and query == web_parallel_query: web_resp = await parallel_tasks["web"] tool_traces.append({"tool": "web", "response": web_resp, "note": "parallel"}) else: # Use autonomous retry with query rewriting web_resp = await self.web_with_repair( query=query, tenant_id=req.tenant_id, reasoning_trace=reasoning_trace, user_id=req.user_id ) tool_traces.append({"tool": "web", "response": web_resp, "note": "with_retry"}) web_data = web_resp tools_used.append("web") hits = self._extract_hits(web_resp) reasoning_trace.append({ "step": "tool_execution", "tool": "web", "hit_count": len(hits), "summary": self._summarize_hits(web_resp, limit=2) }) # Extract snippets for prompt if isinstance(web_resp, dict): for h in hits[:5]: title = h.get("title") or h.get("headline") or "" snippet = h.get("snippet") or h.get("summary") or h.get("text") or "" url = h.get("url") or h.get("link") or "" collected_data.append(f"[WEB] {title}\n{snippet}\nSource: {url}") elif tool_name == "admin": admin_resp = await self.mcp.call_admin(req.tenant_id, query) tool_traces.append({"tool": "admin", "response": admin_resp}) admin_data = admin_resp tools_used.append("admin") collected_data.append(f"[ADMIN] {json.dumps(admin_resp)}") reasoning_trace.append({ "step": "tool_execution", "tool": "admin", "status": "completed" }) elif tool_name == "llm": # LLM is always last - synthesize all collected data break except Exception as e: tool_traces.append({"tool": tool_name, "error": str(e)}) # Continue with other tools even if one fails reasoning_trace.append({ "step": "error", "tool": tool_name, "error": str(e) }) # Build comprehensive prompt with all collected data data_section = "\n---\n".join(collected_data) if collected_data else "" # Build final prompt if data_section: prompt = ( f"You are an assistant helping tenant {req.tenant_id}.\n\n" f"## Information Collected\n" f"The following details have been gathered from multiple reliable sources:\n" f"{data_section}\n\n" f"## User Request\n" f"{req.message}\n\n" f"## Your Task\n" f"Use the information above to directly address the user's request. " f"Focus on giving the user exactly what they need—clear guidance, accurate facts, " f"and practical steps whenever possible. If the information is incomplete, explain " f"what can and cannot be concluded from the available data." ) else: # No data collected, just answer the question prompt = req.message # Final LLM synthesis try: llm_start = time.time() llm_out = await self.llm.simple_call(prompt, temperature=req.temperature) llm_latency_ms = int((time.time() - llm_start) * 1000) tools_used.append("llm") estimated_tokens = len(llm_out) // 4 + len(prompt) // 4 total_tokens += estimated_tokens self._analytics_log_tool_usage( tenant_id=req.tenant_id, tool_name="llm", latency_ms=llm_latency_ms, tokens_used=estimated_tokens, success=True, user_id=req.user_id ) total_latency_ms = int((time.time() - start_time) * 1000) self._analytics_log_agent_query( tenant_id=req.tenant_id, message_preview=req.message[:200], intent="multi_step", tools_used=tools_used, total_tokens=total_tokens, total_latency_ms=total_latency_ms, success=True, user_id=req.user_id ) return AgentResponse( text=llm_out, decision=decision, tool_traces=tool_traces, reasoning_trace=reasoning_trace + [{ "step": "llm_response", "mode": "multi_step_parallel" if parallel_step else "multi_step", "latency_ms": llm_latency_ms, "estimated_tokens": estimated_tokens }] ) except Exception as e: tool_traces.append({"tool": "llm", "error": str(e)}) error_msg = str(e) # Provide helpful error message if "Cannot connect" in error_msg or "Ollama" in error_msg: fallback = ( f"I couldn't connect to the AI service (Ollama). " f"Error: {error_msg}\n\n" f"To fix this:\n" f"1. Install Ollama from https://ollama.ai\n" f"2. Start Ollama: `ollama serve`\n" f"3. Pull the model: `ollama pull {os.getenv('OLLAMA_MODEL', 'llama3.1:latest')}`\n" f"4. Or set OLLAMA_URL and OLLAMA_MODEL in your .env file" ) else: fallback = f"I encountered an error while synthesizing the response: {error_msg}" return AgentResponse( text=fallback, decision=AgentDecision( action="respond", tool=None, tool_input=None, reason=f"multi_step_llm_error: {e}" ), tool_traces=tool_traces, reasoning_trace=reasoning_trace + [{ "step": "error", "tool": "llm", "error": str(e) }] ) # ============================================================= # AUTONOMOUS RETRY + SELF-CORRECTION SYSTEM # ============================================================= """ This system provides autonomous retry and self-correction capabilities for the agent orchestrator. It enables the agent to: 1. **Self-healing**: Tools that break automatically retry with adjusted parameters 2. **Resilient operations**: Handles low RAG scores, empty web results, and misfired rules 3. **Smart optimization**: Automatically rewrites queries, adjusts thresholds, and optimizes parameters 4. **Enterprise-grade reliability**: Matches enterprise behavior with comprehensive retry strategies Key features: - safe_tool_call(): Generic retry wrapper for any tool call - rag_with_repair(): RAG search with automatic threshold adjustment and query expansion - web_with_repair(): Web search with automatic query rewriting for empty results - rule_safe_message(): Message rewriting to comply with admin rules All retry attempts are logged to analytics for monitoring and debugging. """ async def safe_tool_call( self, tool_fn, params: Dict[str, Any], max_retries: int = 2, fallback_params: Optional[Dict[str, Any]] = None, tool_name: str = "unknown", tenant_id: Optional[str] = None, user_id: Optional[str] = None, reasoning_trace: Optional[List[Dict[str, Any]]] = None ) -> Dict[str, Any]: """ Wrapper for tool calls with automatic retry and self-correction. Args: tool_fn: Async function to call params: Parameters to pass to tool_fn max_retries: Maximum number of retry attempts fallback_params: Alternative parameters to try if initial attempt fails tool_name: Name of the tool (for logging) tenant_id: Tenant ID (for analytics) user_id: User ID (for analytics) reasoning_trace: Optional reasoning trace to append to Returns: Tool result dictionary, or {"error": "tool_failed_after_retries"} if all attempts fail """ for attempt in range(max_retries): try: result = await tool_fn(**params) if attempt > 0: # Log successful retry if reasoning_trace is not None: reasoning_trace.append({ "step": "retry_success", "tool": tool_name, "attempt": attempt + 1, "status": "recovered" }) if tenant_id: self._analytics_log_tool_usage( tenant_id=tenant_id, tool_name=f"{tool_name}_retry_{attempt+1}", latency_ms=0, success=True, user_id=user_id ) return result except Exception as e: error_msg = str(e) if reasoning_trace is not None: reasoning_trace.append({ "step": "retry_attempt", "tool": tool_name, "attempt": attempt + 1, "error": error_msg[:200] }) # Log failed attempt if tenant_id: self._analytics_log_tool_usage( tenant_id=tenant_id, tool_name=tool_name, latency_ms=0, success=False, error_message=error_msg[:200], user_id=user_id ) # Try alternate params if provided and not last attempt if fallback_params and attempt < max_retries - 1: params = {**params, **fallback_params} if reasoning_trace is not None: reasoning_trace.append({ "step": "retry_with_fallback_params", "tool": tool_name, "attempt": attempt + 2, "fallback_params": fallback_params }) # If last attempt, return error if attempt == max_retries - 1: return {"error": "tool_failed_after_retries", "error_message": error_msg} return {"error": "tool_failed_after_retries"} async def rag_with_repair( self, query: str, tenant_id: str, original_threshold: float = 0.3, reasoning_trace: Optional[List[Dict[str, Any]]] = None, user_id: Optional[str] = None ) -> Dict[str, Any]: """ RAG search with automatic self-correction for low scores. Strategy: 1. Try with original threshold 2. If top_score < 0.30, retry with lower threshold (0.15) 3. If still low (< 0.15), expand query and retry """ # Initial attempt rag_start = time.time() result = await self.mcp.call_rag(tenant_id, query, threshold=original_threshold) rag_latency_ms = int((time.time() - rag_start) * 1000) # Extract hits and calculate scores hits = self._extract_hits(result) top_score = None avg_score = None if hits: scores = [h.get("score", 0.0) for h in hits if isinstance(h, dict) and "score" in h] if scores: top_score = max(scores) avg_score = sum(scores) / len(scores) if reasoning_trace is not None: reasoning_trace.append({ "step": "rag_initial_search", "query": query[:200], "hits_count": len(hits), "top_score": top_score, "avg_score": avg_score, "threshold": original_threshold }) # Retry logic: low score → lower threshold if top_score is not None and top_score < 0.30 and original_threshold >= 0.15: if reasoning_trace is not None: reasoning_trace.append({ "step": "rag_retry_low_threshold", "reason": f"top_score {top_score:.3f} < 0.30, retrying with threshold=0.15" }) retry_start = time.time() result = await self.mcp.call_rag(tenant_id, query, threshold=0.15) retry_latency_ms = int((time.time() - retry_start) * 1000) rag_latency_ms += retry_latency_ms hits = self._extract_hits(result) if hits: scores = [h.get("score", 0.0) for h in hits if isinstance(h, dict) and "score" in h] if scores: top_score = max(scores) avg_score = sum(scores) / len(scores) # Log retry self._analytics_log_tool_usage( tenant_id=tenant_id, tool_name="rag_retry_low_threshold", latency_ms=retry_latency_ms, success=True, user_id=user_id ) # Final retry: expand query if score still too low if top_score is not None and top_score < 0.15: expanded_query = f"{query} (more details comprehensive explanation)" if reasoning_trace is not None: reasoning_trace.append({ "step": "rag_retry_expanded_query", "reason": f"top_score {top_score:.3f} < 0.15, retrying with expanded query", "original_query": query[:200], "expanded_query": expanded_query[:200] }) retry_start = time.time() result = await self.mcp.call_rag(tenant_id, expanded_query, threshold=0.15) retry_latency_ms = int((time.time() - retry_start) * 1000) rag_latency_ms += retry_latency_ms hits = self._extract_hits(result) if hits: scores = [h.get("score", 0.0) for h in hits if isinstance(h, dict) and "score" in h] if scores: top_score = max(scores) avg_score = sum(scores) / len(scores) # Log retry self._analytics_log_tool_usage( tenant_id=tenant_id, tool_name="rag_retry_expanded_query", latency_ms=retry_latency_ms, success=True, user_id=user_id ) if reasoning_trace is not None: reasoning_trace.append({ "step": "rag_expanded_query_result", "hits_count": len(hits), "top_score": top_score, "avg_score": avg_score }) # Log final RAG search if hits: self._analytics_log_rag_search( tenant_id=tenant_id, query=query[:500], hits_count=len(hits), avg_score=avg_score, top_score=top_score, latency_ms=rag_latency_ms ) return result async def web_with_repair( self, query: str, tenant_id: str, reasoning_trace: Optional[List[Dict[str, Any]]] = None, user_id: Optional[str] = None ) -> Dict[str, Any]: """ Web search with automatic query rewriting for empty results. Strategy: 1. Try original query 2. If empty, try "best explanation of {query}" 3. If still empty, try "{query} facts summary" """ # Initial attempt web_start = time.time() result = await self.mcp.call_web(tenant_id, query) web_latency_ms = int((time.time() - web_start) * 1000) hits = self._extract_hits(result) if reasoning_trace is not None: reasoning_trace.append({ "step": "web_initial_search", "query": query[:200], "hits_count": len(hits) }) # Retry logic: empty results → rewrite query if not result or len(hits) == 0: rewritten_queries = [ f"best explanation of {query}", f"{query} facts summary" ] for i, rewritten in enumerate(rewritten_queries): if reasoning_trace is not None: reasoning_trace.append({ "step": "web_retry_rewritten", "attempt": i + 1, "original_query": query[:200], "rewritten_query": rewritten[:200] }) retry_start = time.time() result = await self.mcp.call_web(tenant_id, rewritten) retry_latency_ms = int((time.time() - retry_start) * 1000) web_latency_ms += retry_latency_ms hits = self._extract_hits(result) # Log retry self._analytics_log_tool_usage( tenant_id=tenant_id, tool_name=f"web_retry_rewrite_{i+1}", latency_ms=retry_latency_ms, success=True, user_id=user_id ) if hits: if reasoning_trace is not None: reasoning_trace.append({ "step": "web_retry_success", "rewritten_query": rewritten[:200], "hits_count": len(hits) }) break # Log final web search self._analytics_log_tool_usage( tenant_id=tenant_id, tool_name="web", latency_ms=web_latency_ms, success=len(hits) > 0, user_id=user_id ) return result async def rule_safe_message( self, user_message: str, tenant_id: str, reasoning_trace: Optional[List[Dict[str, Any]]] = None ) -> str: """ Check admin rules and rewrite message if it violates policies. Strategy: 1. Check rules 2. If blocked, ask LLM to rewrite message to comply 3. Return safe version """ matches: List[RedFlagMatch] = await self.redflag.check(tenant_id, user_message) if not matches: return user_message # Check if any are blocking rules (not just brief response rules) blocking_rules = [] for match in matches: rule_text = (match.description or match.pattern or "").lower() is_brief_rule = ( match.severity == "low" and ( "greeting" in rule_text or "brief" in rule_text or "simple response" in rule_text ) ) if not is_brief_rule: blocking_rules.append(match) # Only rewrite if there are blocking rules if not blocking_rules: return user_message if reasoning_trace is not None: reasoning_trace.append({ "step": "rule_violation_detected", "blocking_rules_count": len(blocking_rules), "action": "attempting_rewrite" }) # Ask LLM to rewrite the message rewrite_prompt = f"""The following user message violates company policies. Rewrite it to be compliant while preserving the user's intent as much as possible. Original message: "{user_message}" Violated policies: {chr(10).join([f"- {m.description or m.pattern}" for m in blocking_rules[:3]])} Provide a rewritten version that: 1. Avoids the policy violations 2. Preserves the user's original intent 3. Remains professional and helpful Rewritten message:""" try: rewritten = await self.llm.simple_call(rewrite_prompt, temperature=0.3) rewritten = rewritten.strip().strip('"').strip("'") if reasoning_trace is not None: reasoning_trace.append({ "step": "rule_rewrite_completed", "original_length": len(user_message), "rewritten_length": len(rewritten), "rewritten_preview": rewritten[:200] }) # Verify the rewritten message doesn't trigger rules verify_matches = await self.redflag.check(tenant_id, rewritten) if not verify_matches or all( (m.description or m.pattern or "").lower() in ["greeting", "brief", "simple response"] for m in verify_matches ): return rewritten if reasoning_trace is not None: reasoning_trace.append({ "step": "rule_rewrite_still_violates", "action": "using_original_with_block" }) except Exception as e: if reasoning_trace is not None: reasoning_trace.append({ "step": "rule_rewrite_failed", "error": str(e)[:200] }) # Return original if rewrite failed or still violates return user_message def _build_prompt_with_web(self, req: AgentRequest, web_resp: Dict[str, Any]) -> str: snippets = [] if isinstance(web_resp, dict): hits = web_resp.get("results") or web_resp.get("items") or [] for h in hits[:5]: title = h.get("title") or h.get("headline") or "" snippet = h.get("snippet") or h.get("summary") or h.get("text") or "" url = h.get("url") or h.get("link") or "" snippets.append(f"{title}\n{snippet}\nSource: {url}") snippet_text = "\n---\n".join(snippets) or "" # prompt = ( # f"You are an assistant with access to recent web search results. Use the following results to answer.\n{snippet_text}\n\n" # f"User question: {req.message}\nAnswer succinctly and indicate which results you used." # ) prompt = ( f"You are an assistant with access to recent web search results.\n\n" f"## Search Results\n" f"{snippet_text}\n\n" f"## User Question\n" f"{req.message}\n\n" f"## Your Task\n" f"Provide a clear, accurate, and succinct answer based on the search results above. " f"Indicate which results you used in your reasoning." ) return prompt @staticmethod def _extract_hits(resp: Optional[Dict[str, Any]]) -> List[Dict[str, Any]]: if not isinstance(resp, dict): return [] return resp.get("results") or resp.get("hits") or resp.get("items") or [] def _summarize_hits(self, resp: Optional[Dict[str, Any]], limit: int = 3) -> List[str]: hits = self._extract_hits(resp) summaries = [] for hit in hits[:limit]: if isinstance(hit, dict): snippet = hit.get("text") or hit.get("content") or hit.get("snippet") or "" else: snippet = str(hit) summaries.append(snippet[:160]) return summaries async def run_parallel_tools(self, tasks: Dict[str, Any]) -> Dict[str, Any]: """ Run multiple tools in parallel using asyncio.gather. Args: tasks: Dictionary mapping tool names to coroutines, e.g.: {"rag": rag_coro, "web": web_coro} Returns: Dictionary mapping tool names to results, e.g.: {"rag": rag_result, "web": web_result} Exceptions are returned as values if a tool fails. """ if not tasks: return {} names = list(tasks.keys()) coros = [tasks[name] for name in names] # Run all coroutines in parallel, return exceptions instead of raising results = await asyncio.gather(*coros, return_exceptions=True) return {names[i]: results[i] for i in range(len(names))} @staticmethod def _first_query_for_tool(steps: List[Dict[str, Any]], tool_name: str, default_query: str) -> Optional[str]: for step in steps: if step.get("tool") == tool_name: input_data = step.get("input") or {} return input_data.get("query") or default_query return None