Spaces:
Sleeping
Sleeping
feat: Add AI metadata extraction, latency prediction, context-aware routing, and tool output schemas
d1e5882
| # ============================================================= | |
| # File: backend/api/services/agent_orchestrator.py | |
| # ============================================================= | |
| """ | |
| Agent Orchestrator (integrated with enterprise RedFlagDetector) | |
| Place at: backend/api/services/agent_orchestrator.py | |
| """ | |
| from __future__ import annotations | |
| import asyncio | |
| import json | |
| import os | |
| from typing import List, Dict, Any, Optional | |
| import logging | |
| from ..models.agent import AgentRequest, AgentDecision, AgentResponse | |
| from ..models.redflag import RedFlagMatch | |
| from .redflag_detector import RedFlagDetector | |
| from .intent_classifier import IntentClassifier | |
| from .tool_selector import ToolSelector | |
| from .llm_client import LLMClient | |
| from ..mcp_clients.mcp_client import MCPClient | |
| from .tool_scoring import ToolScoringService | |
| from ..storage.analytics_store import AnalyticsStore | |
| from .result_merger import merge_parallel_results, format_merged_context_for_prompt | |
| from .tool_metadata import validate_tool_output, get_tool_schema | |
| import time | |
# Module-level logger used by AgentOrchestrator for analytics diagnostics.
logger = logging.getLogger(__name__)
from dotenv import load_dotenv
# Side effect at import time: load variables from a .env file (python-dotenv)
# so the os.getenv(...) lookups in AgentOrchestrator.__init__ can see them.
load_dotenv()
| class AgentOrchestrator: | |
def __init__(self, rag_mcp_url: str, web_mcp_url: str, admin_mcp_url: str, llm_backend: str = "ollama"):
    """Wire up the MCP tool clients, the LLM client and the routing/safety services.

    Args:
        rag_mcp_url: Endpoint of the RAG MCP server.
        web_mcp_url: Endpoint of the web-search MCP server.
        admin_mcp_url: Endpoint of the admin MCP server. It is also handed to
            the RedFlagDetector so the detector can call back.
        llm_backend: Backend name forwarded to LLMClient (default "ollama").

    Connection settings are read from the environment: OLLAMA_URL,
    GROQ_API_KEY, OLLAMA_MODEL, SUPABASE_URL and SUPABASE_SERVICE_KEY.
    ANALYTICS_DISABLED set to 1/true/yes (case-insensitive) turns analytics off.
    """
    env = os.getenv

    self.mcp = MCPClient(rag_mcp_url, web_mcp_url, admin_mcp_url)

    ollama_url = env("OLLAMA_URL")
    groq_key = env("GROQ_API_KEY")
    ollama_model = env("OLLAMA_MODEL")
    self.llm = LLMClient(backend=llm_backend, url=ollama_url, api_key=groq_key, model=ollama_model)

    # The detector receives admin_mcp_url so it can call back into the admin server.
    supabase_url = env("SUPABASE_URL")
    supabase_key = env("SUPABASE_SERVICE_KEY")
    self.redflag = RedFlagDetector(
        supabase_url=supabase_url,
        supabase_key=supabase_key,
        admin_mcp_url=admin_mcp_url,
    )

    # Intent classification and tool selection share the same LLM client.
    shared_llm = self.llm
    self.intent = IntentClassifier(llm_client=shared_llm)
    self.selector = ToolSelector(llm_client=shared_llm)
    self.tool_scorer = ToolScoringService()

    # The analytics store is populated on demand by _get_analytics().
    self._analytics: Optional[AnalyticsStore] = None
    disabled_flag = env("ANALYTICS_DISABLED", "").lower()
    self._analytics_disabled = disabled_flag in {"1", "true", "yes"}
    self._analytics_failed = False

    self._log_analytics_backend_once()
def _log_analytics_backend_once(self) -> None:
    """Print which analytics backend is in use, only the first time it is called.

    The "already reported" flag is stored on the AgentOrchestrator class, not
    on the instance, so later instances skip the banner entirely. A None store
    is reported as "Supabase not configured". That includes the case where
    _get_analytics() hit an unexpected initialization error.
    """
    if getattr(AgentOrchestrator, "_analytics_backend_logged", False):
        return

    if self._analytics_disabled:
        banner = "⚠️ AgentOrchestrator Analytics: Disabled via ANALYTICS_DISABLED"
    else:
        store = self._get_analytics()
        if store is None:
            banner = "⚠️ AgentOrchestrator Analytics: Disabled (Supabase not configured)"
        elif store.use_supabase:
            banner = "✅ AgentOrchestrator Analytics: Using Supabase backend"
        else:
            banner = "⚠️ AgentOrchestrator Analytics: Using fallback backend"
    print(banner)

    AgentOrchestrator._analytics_backend_logged = True
| def _get_analytics(self) -> Optional[AnalyticsStore]: | |
| if self._analytics_disabled or self._analytics_failed: | |
| return None | |
| if self._analytics is not None: | |
| return self._analytics | |
| try: | |
| self._analytics = AnalyticsStore() | |
| except RuntimeError as exc: | |
| logger.warning("AgentOrchestrator analytics disabled: %s", exc) | |
| self._analytics_failed = True | |
| self._analytics = None | |
| except Exception as exc: # pragma: no cover - unexpected initialization failures | |
| logger.debug("AgentOrchestrator analytics unexpected init failure: %s", exc) | |
| self._analytics_failed = True | |
| self._analytics = None | |
| return self._analytics | |
| def _analytics_log_tool_usage(self, **kwargs: Any) -> None: | |
| analytics = self._get_analytics() | |
| if not analytics: | |
| return | |
| try: | |
| analytics.log_tool_usage(**kwargs) | |
| except Exception as exc: # pragma: no cover - analytics failures should not break flow | |
| logger.debug("AgentOrchestrator tool analytics failed: %s", exc) | |
| def _analytics_log_agent_query(self, **kwargs: Any) -> None: | |
| analytics = self._get_analytics() | |
| if not analytics: | |
| return | |
| try: | |
| analytics.log_agent_query(**kwargs) | |
| except Exception as exc: # pragma: no cover | |
| logger.debug("AgentOrchestrator agent query analytics failed: %s", exc) | |
| def _analytics_log_rag_search(self, **kwargs: Any) -> None: | |
| analytics = self._get_analytics() | |
| if not analytics: | |
| return | |
| try: | |
| analytics.log_rag_search(**kwargs) | |
| except Exception as exc: # pragma: no cover | |
| logger.debug("AgentOrchestrator RAG analytics failed: %s", exc) | |
| def _analytics_log_redflag_violation(self, **kwargs: Any) -> None: | |
| analytics = self._get_analytics() | |
| if not analytics: | |
| return | |
| try: | |
| analytics.log_redflag_violation(**kwargs) | |
| except Exception as exc: # pragma: no cover | |
| logger.debug("AgentOrchestrator redflag analytics failed: %s", exc) | |
async def handle(self, req: AgentRequest) -> AgentResponse:
    """Route a single user request through admin rules, tool selection and the LLM.

    Flow:
      1. Check the tenant's admin (red-flag) rules. If only low-severity
         "brief response" rules match, return a canned greeting. If any other
         rule matches, notify the admin and return an LLM-written refusal.
      2. Otherwise classify intent, pre-fetch RAG hits, score the tools and
         ask the ToolSelector for a decision.
      3. Run a multi-step plan (delegated to ``_execute_multi_step``), a single
         tool (``rag`` / ``web`` / ``admin`` / ``llm``), or fall back to a
         direct LLM answer.

    Analytics events are emitted along the way through the best-effort
    ``_analytics_log_*`` helpers.

    Args:
        req: Incoming request. ``tenant_id``, ``user_id``, ``message``,
            ``temperature`` and ``conversation_history`` are read.

    Returns:
        AgentResponse carrying the answer text, the decision taken, and the
        tool and reasoning traces.
    """
    import random
    import re

    start_time = time.time()
    reasoning_trace: List[Dict[str, Any]] = []
    reasoning_trace.append({
        "step": "request_received",
        "tenant_id": req.tenant_id,
        "user_id": req.user_id,
        "message_preview": req.message[:120]
    })

    # 1) FIRST: Check admin rules. If any rule matches, respond according to the rule.
    matches: List[RedFlagMatch] = await self.redflag.check(req.tenant_id, req.message)
    reasoning_trace.append({
        "step": "admin_rules_check",
        "match_count": len(matches),
        "matches": [m.__dict__ for m in matches]
    })
    if matches:
        # Log all rule matches
        for match in matches:
            self._analytics_log_redflag_violation(
                tenant_id=req.tenant_id,
                rule_id=match.rule_id,
                rule_pattern=match.pattern,
                severity=match.severity,
                matched_text=match.matched_text,
                confidence=match.confidence,
                message_preview=req.message[:200],
                user_id=req.user_id
            )

        # Split matches into low-severity "brief response" rules (greetings etc.)
        # and blocking rules (security, compliance, ...). Several hints contain
        # ".*" wildcards, so every hint is evaluated as a regular expression.
        brief_rule_hints = (
            r"greeting",
            r"brief",
            r"simple response",
            r"keep.*response.*brief",
            r"do not.*verbose",
            r"respond.*briefly",
        )
        brief_response_rules = []
        blocking_rules = []
        for match in matches:
            rule_text = (match.description or match.pattern or "").lower()
            is_brief_rule = match.severity == "low" and any(
                re.search(hint, rule_text) for hint in brief_rule_hints
            )
            if is_brief_rule:
                brief_response_rules.append(match)
            else:
                blocking_rules.append(match)

        # Brief response rules (greetings, etc.) return immediately, but only
        # when no blocking rule matched as well.
        if brief_response_rules and not blocking_rules:
            brief_responses = [
                "Hello! How can I help you today?",
                "Hi there! What can I assist you with?",
                "Hello! I'm here to help. What do you need?",
                "Hi! How can I assist you?"
            ]
            brief_response = random.choice(brief_responses)
            reasoning_trace.append({
                "step": "brief_response_rule_matched",
                "action": "brief_response",
                "matched_rules": [m.rule_id for m in brief_response_rules],
                "message": "Brief response rule matched, returning brief response (skipping normal flow)"
            })
            total_latency_ms = int((time.time() - start_time) * 1000)
            self._analytics_log_agent_query(
                tenant_id=req.tenant_id,
                message_preview=req.message[:200],
                intent="greeting",
                tools_used=[],
                total_tokens=len(brief_response) // 4,
                total_latency_ms=total_latency_ms,
                success=True,
                user_id=req.user_id
            )
            return AgentResponse(
                text=brief_response,
                decision=AgentDecision(action="respond", tool=None, tool_input=None, reason="brief_response_rule"),
                reasoning_trace=reasoning_trace
            )

        # Blocking rules (security, compliance, etc.): block and return immediately.
        if blocking_rules:
            # Notify the admin (awaited). This is best-effort: a notification
            # failure must not stop the user from receiving the refusal below.
            try:
                await self.redflag.notify_admin(req.tenant_id, blocking_rules, source_payload={"message": req.message, "user_id": req.user_id})
            except Exception as notify_err:
                logger.warning("AgentOrchestrator admin notification failed: %s", notify_err)
            decision = AgentDecision(
                action="block",
                tool="admin",
                tool_input={"violations": [m.__dict__ for m in blocking_rules]},
                reason="admin_rule_violation"
            )
            # Build a detailed prompt so the LLM can write a natural red-flag response.
            violations_details = []
            for i, m in enumerate(blocking_rules, 1):
                rule_name = m.description or m.pattern or "Policy violation"
                detail = f"{i}. **{rule_name}** (Severity: {m.severity})"
                if m.matched_text:
                    detail += f"\n - Detected phrase: \"{m.matched_text}\""
                violations_details.append(detail)
            violations_block = "\n".join(violations_details)
            llm_prompt = (
                f"A user made the following request: \"{req.message}\"\n"
                "However, this request violates company policies. The following policy violations were detected:\n"
                f"{violations_block}\n"
                "Your task: Write a clear, professional, and empathetic response to inform the user that:\n"
                "1. Their request cannot be processed due to policy violations\n"
                "2. Which specific policy was violated (mention it naturally)\n"
                "3. The incident has been logged for security review\n"
                "4. They should contact an administrator if they need assistance or believe this is an error\n"
                "Write a natural, conversational response (2-4 sentences) that feels helpful rather than robotic. Be professional but understanding.\n"
                "Response:"
            )
            # Generate the LLM response for the red flag.
            llm_ok = False
            try:
                # Slightly higher temperature (capped at 0.7) for a more natural response.
                llm_response = await self.llm.simple_call(llm_prompt, temperature=min(req.temperature + 0.2, 0.7))
                llm_response = llm_response.strip()
                # Add a warning emoji if one is not already present.
                if not llm_response.startswith(("⚠️", "🚨")):
                    llm_response = f"⚠️ {llm_response}"
                llm_ok = True
            except Exception:
                # Fall back to a fixed message. Only the rules that actually
                # blocked the request are listed.
                summary = "; ".join(
                    f"{m.description or m.pattern}"
                    for m in blocking_rules
                )
                llm_response = f"⚠️ I'm unable to process your request because it violates our company policy: {summary}. This incident has been logged. Please contact your administrator if you need assistance."
            total_latency_ms = int((time.time() - start_time) * 1000)
            # Log LLM usage for the red-flag response (rough estimate: ~4 chars per token).
            estimated_tokens = len(llm_response) // 4 + len(llm_prompt) // 4
            self._analytics_log_tool_usage(
                tenant_id=req.tenant_id,
                tool_name="llm",
                latency_ms=total_latency_ms,
                tokens_used=estimated_tokens,
                success=llm_ok,
                user_id=req.user_id
            )
            self._analytics_log_agent_query(
                tenant_id=req.tenant_id,
                message_preview=req.message[:200],
                intent="admin",
                tools_used=["admin", "llm"],
                total_tokens=estimated_tokens,
                total_latency_ms=total_latency_ms,
                success=False,
                user_id=req.user_id
            )
            return AgentResponse(
                text=llm_response,
                decision=decision,
                tool_traces=[{"redflags": [m.__dict__ for m in blocking_rules]}],
                reasoning_trace=reasoning_trace
            )

    # 2) ONLY IF NO RULES MATCHED: proceed with the normal flow (intent classification, RAG, etc.)
    intent = await self.intent.classify(req.message)
    reasoning_trace.append({
        "step": "intent_detection",
        "intent": intent
    })

    # 2.5) Pre-fetch RAG results. They give the tool selector extra context.
    rag_prefetch = None
    rag_results = []
    try:
        rag_start = time.time()
        rag_prefetch = await self.mcp.call_rag(req.tenant_id, req.message)
        rag_latency_ms = int((time.time() - rag_start) * 1000)
        if isinstance(rag_prefetch, dict):
            rag_results = rag_prefetch.get("results") or rag_prefetch.get("hits") or []
        # Log the RAG search event.
        hits_count = len(rag_results)
        avg_score = None
        top_score = None
        if rag_results:
            scores = [h.get("score", 0.0) for h in rag_results if isinstance(h, dict) and "score" in h]
            if scores:
                avg_score = sum(scores) / len(scores)
                top_score = max(scores)
        self._analytics_log_rag_search(
            tenant_id=req.tenant_id,
            query=req.message[:500],
            hits_count=hits_count,
            avg_score=avg_score,
            top_score=top_score,
            latency_ms=rag_latency_ms
        )
        self._analytics_log_tool_usage(
            tenant_id=req.tenant_id,
            tool_name="rag",
            latency_ms=rag_latency_ms,
            success=True,
            user_id=req.user_id
        )
        reasoning_trace.append({
            "step": "rag_prefetch",
            "status": "ok",
            "hit_count": len(rag_results),
            "latency_ms": rag_latency_ms
        })
    except Exception as pref_err:
        # If RAG fails, continue without it.
        rag_latency_ms = 0  # 0 for failed
        self._analytics_log_tool_usage(
            tenant_id=req.tenant_id,
            tool_name="rag",
            latency_ms=rag_latency_ms,
            success=False,
            error_message=str(pref_err)[:200],
            user_id=req.user_id
        )
        reasoning_trace.append({
            "step": "rag_prefetch",
            "status": "error",
            "error": str(pref_err)
        })
        rag_prefetch = None

    tool_scores = self.tool_scorer.score(req.message, intent, rag_results)
    reasoning_trace.append({
        "step": "tool_scoring",
        "scores": tool_scores
    })

    # 3) Tool selection (hybrid). RAG results, memory and admin violations go into the context.
    # Recent memory enables context-aware routing.
    from backend.mcp_server.common.memory import get_recent_memory
    session_id = req.conversation_history[-1].get("session_id") if req.conversation_history else None
    recent_memory = []
    if session_id:
        recent_memory = get_recent_memory(session_id)
    # Any admin-rule match already returned above, so no violations reach this point.
    admin_violations: List[Dict[str, Any]] = []
    ctx = {
        "tenant_id": req.tenant_id,
        "rag_results": rag_results,
        "tool_scores": tool_scores,
        "memory": recent_memory,  # Context-aware routing: recent tool outputs
        "admin_violations": admin_violations  # Context-aware routing: admin rule severity
    }
    decision = await self.selector.select(intent, req.message, ctx)
    reasoning_trace.append({
        "step": "tool_selection",
        "decision": decision.dict(),
        "context_scores": tool_scores
    })
    tool_traces: List[Dict[str, Any]] = []

    # 4) Multi-step tool execution
    if decision.action == "multi_step" and decision.tool_input:
        steps = decision.tool_input.get("steps", [])
        if steps:
            return await self._execute_multi_step(
                req,
                steps,
                decision,
                tool_traces,
                reasoning_trace,
                rag_prefetch
            )

    # 5) Execute a single tool
    tools_used = []
    total_tokens = 0
    if decision.action == "call_tool" and decision.tool:
        try:
            # Use the selector's "query" when it supplied a non-empty one;
            # otherwise fall back to the raw user message.
            tool_query = (decision.tool_input or {}).get("query") or req.message

            if decision.tool == "rag":
                # Autonomous retry with self-correction
                rag_start = time.time()
                rag_resp = await self.rag_with_repair(
                    query=tool_query,
                    tenant_id=req.tenant_id,
                    original_threshold=0.3,
                    reasoning_trace=reasoning_trace,
                    user_id=req.user_id
                )
                rag_latency_ms = int((time.time() - rag_start) * 1000)
                tools_used.append("rag")
                # Validate and format the RAG output so it conforms to the schema.
                rag_formatted = self._format_tool_output("rag", rag_resp, rag_latency_ms)
                tool_traces.append({"tool": "rag", "response": rag_formatted})
                hits = self._extract_hits(rag_formatted)
                hits_count = len(hits)
                avg_score = rag_formatted.get("avg_score")
                top_score = rag_formatted.get("top_score")
                reasoning_trace.append({
                    "step": "tool_execution",
                    "tool": "rag",
                    "hit_count": hits_count,
                    "top_score": top_score,
                    "avg_score": avg_score,
                    "summary": self._summarize_hits(rag_formatted, limit=2)
                })
                prompt = self._build_prompt_with_rag(req, rag_formatted)
                llm_start = time.time()
                llm_out = await self.llm.simple_call(prompt, temperature=req.temperature)
                llm_latency_ms = int((time.time() - llm_start) * 1000)
                tools_used.append("llm")
                # Estimate tokens (rough: ~4 chars per token)
                estimated_tokens = len(llm_out) // 4 + len(prompt) // 4
                total_tokens += estimated_tokens
                self._analytics_log_tool_usage(
                    tenant_id=req.tenant_id,
                    tool_name="llm",
                    latency_ms=llm_latency_ms,
                    tokens_used=estimated_tokens,
                    success=True,
                    user_id=req.user_id
                )
                reasoning_trace.append({
                    "step": "llm_response",
                    "mode": "rag_synthesis",
                    "latency_ms": llm_latency_ms,
                    "estimated_tokens": estimated_tokens
                })
                total_latency_ms = int((time.time() - start_time) * 1000)
                self._analytics_log_agent_query(
                    tenant_id=req.tenant_id,
                    message_preview=req.message[:200],
                    intent=intent,
                    tools_used=tools_used,
                    total_tokens=total_tokens,
                    total_latency_ms=total_latency_ms,
                    success=True,
                    user_id=req.user_id
                )
                return AgentResponse(text=llm_out, decision=decision, tool_traces=tool_traces, reasoning_trace=reasoning_trace)

            if decision.tool == "web":
                # Autonomous retry with query rewriting
                web_start = time.time()
                web_resp = await self.web_with_repair(
                    query=tool_query,
                    tenant_id=req.tenant_id,
                    reasoning_trace=reasoning_trace,
                    user_id=req.user_id
                )
                web_latency_ms = int((time.time() - web_start) * 1000)
                tools_used.append("web")
                # Validate and format the Web output so it conforms to the schema.
                web_formatted = self._format_tool_output("web", web_resp, web_latency_ms)
                tool_traces.append({"tool": "web", "response": web_formatted})
                hits_count = len(self._extract_hits(web_formatted))
                reasoning_trace.append({
                    "step": "tool_execution",
                    "tool": "web",
                    "hit_count": hits_count,
                    "summary": self._summarize_hits(web_formatted, limit=2)
                })
                prompt = self._build_prompt_with_web(req, web_formatted)
                llm_start = time.time()
                llm_out = await self.llm.simple_call(prompt, temperature=req.temperature)
                llm_latency_ms = int((time.time() - llm_start) * 1000)
                tools_used.append("llm")
                estimated_tokens = len(llm_out) // 4 + len(prompt) // 4
                total_tokens += estimated_tokens
                self._analytics_log_tool_usage(
                    tenant_id=req.tenant_id,
                    tool_name="llm",
                    latency_ms=llm_latency_ms,
                    tokens_used=estimated_tokens,
                    success=True,
                    user_id=req.user_id
                )
                reasoning_trace.append({
                    "step": "llm_response",
                    "mode": "web_synthesis",
                    "latency_ms": llm_latency_ms,
                    "estimated_tokens": estimated_tokens
                })
                total_latency_ms = int((time.time() - start_time) * 1000)
                self._analytics_log_agent_query(
                    tenant_id=req.tenant_id,
                    message_preview=req.message[:200],
                    intent=intent,
                    tools_used=tools_used,
                    total_tokens=total_tokens,
                    total_latency_ms=total_latency_ms,
                    success=True,
                    user_id=req.user_id
                )
                return AgentResponse(text=llm_out, decision=decision, tool_traces=tool_traces, reasoning_trace=reasoning_trace)

            if decision.tool == "admin":
                admin_start = time.time()
                admin_resp = await self.mcp.call_admin(req.tenant_id, tool_query)
                admin_latency_ms = int((time.time() - admin_start) * 1000)
                tools_used.append("admin")
                self._analytics_log_tool_usage(
                    tenant_id=req.tenant_id,
                    tool_name="admin",
                    latency_ms=admin_latency_ms,
                    success=True,
                    user_id=req.user_id
                )
                # Validate and format the Admin output so it conforms to the schema.
                admin_formatted = self._format_tool_output("admin", admin_resp, admin_latency_ms)
                tool_traces.append({"tool": "admin", "response": admin_formatted})
                reasoning_trace.append({
                    "step": "tool_execution",
                    "tool": "admin",
                    "status": "completed",
                    "latency_ms": admin_latency_ms
                })
                total_latency_ms = int((time.time() - start_time) * 1000)
                self._analytics_log_agent_query(
                    tenant_id=req.tenant_id,
                    message_preview=req.message[:200],
                    intent=intent,
                    tools_used=tools_used,
                    total_tokens=0,
                    total_latency_ms=total_latency_ms,
                    success=True,
                    user_id=req.user_id
                )
                return AgentResponse(text=json.dumps(admin_resp), decision=decision, tool_traces=tool_traces, reasoning_trace=reasoning_trace)

            if decision.tool == "llm":
                llm_start = time.time()
                llm_out = await self.llm.simple_call(req.message, temperature=req.temperature)
                llm_latency_ms = int((time.time() - llm_start) * 1000)
                tools_used.append("llm")
                estimated_tokens = len(llm_out) // 4 + len(req.message) // 4
                total_tokens += estimated_tokens
                self._analytics_log_tool_usage(
                    tenant_id=req.tenant_id,
                    tool_name="llm",
                    latency_ms=llm_latency_ms,
                    tokens_used=estimated_tokens,
                    success=True,
                    user_id=req.user_id
                )
                reasoning_trace.append({
                    "step": "llm_response",
                    "mode": "direct",
                    "latency_ms": llm_latency_ms,
                    "estimated_tokens": estimated_tokens
                })
                total_latency_ms = int((time.time() - start_time) * 1000)
                self._analytics_log_agent_query(
                    tenant_id=req.tenant_id,
                    message_preview=req.message[:200],
                    intent=intent,
                    tools_used=tools_used,
                    total_tokens=total_tokens,
                    total_latency_ms=total_latency_ms,
                    success=True,
                    user_id=req.user_id
                )
                return AgentResponse(text=llm_out, decision=decision, reasoning_trace=reasoning_trace)
        except Exception as e:
            tool_traces.append({"tool": decision.tool, "error": str(e)})
            try:
                fallback = await self.llm.simple_call(req.message, temperature=req.temperature)
            except Exception as llm_error:
                error_msg = str(llm_error)
                if "Cannot connect" in error_msg or "Ollama" in error_msg:
                    fallback = (
                        f"I encountered an error while processing your request: {str(e)}\n\n"
                        f"Additionally, the AI service (Ollama) is unavailable: {error_msg}\n\n"
                        f"To fix:\n"
                        f"1. Install Ollama from https://ollama.ai\n"
                        f"2. Start: `ollama serve`\n"
                        f"3. Pull model: `ollama pull {os.getenv('OLLAMA_MODEL', 'llama3.1:latest')}`"
                    )
                else:
                    fallback = f"I encountered an error while processing your request: {str(e)}. Additionally, the AI service is unavailable: {error_msg}"
            # Record the failed request like every other exit path does.
            total_latency_ms = int((time.time() - start_time) * 1000)
            self._analytics_log_agent_query(
                tenant_id=req.tenant_id,
                message_preview=req.message[:200],
                intent=intent,
                tools_used=tools_used,
                total_tokens=total_tokens,
                total_latency_ms=total_latency_ms,
                success=False,
                user_id=req.user_id
            )
            return AgentResponse(
                text=fallback,
                decision=AgentDecision(action="respond", tool=None, tool_input=None, reason=f"tool_error_fallback: {e}"),
                tool_traces=tool_traces,
                reasoning_trace=reasoning_trace + [{
                    "step": "error",
                    "tool": decision.tool,
                    "error": str(e)
                }]
            )

    # Default: direct LLM response. Success is tracked explicitly because
    # llm_out is assigned on both the success and the failure branch.
    llm_ok = False
    estimated_tokens = 0
    try:
        llm_start = time.time()
        llm_out = await self.llm.simple_call(req.message, temperature=req.temperature)
        llm_latency_ms = int((time.time() - llm_start) * 1000)
        tools_used = ["llm"]
        estimated_tokens = len(llm_out) // 4 + len(req.message) // 4
        llm_ok = True
        self._analytics_log_tool_usage(
            tenant_id=req.tenant_id,
            tool_name="llm",
            latency_ms=llm_latency_ms,
            tokens_used=estimated_tokens,
            success=True,
            user_id=req.user_id
        )
    except Exception as e:
        # If the LLM fails, return a helpful error message instead.
        error_msg = str(e)
        if "Cannot connect" in error_msg or "Ollama" in error_msg:
            llm_out = (
                f"I couldn't connect to the AI service (Ollama). "
                f"Error: {error_msg}\n\n"
                f"To fix this:\n"
                f"1. Install Ollama from https://ollama.ai\n"
                f"2. Start Ollama: `ollama serve`\n"
                f"3. Pull the model: `ollama pull {os.getenv('OLLAMA_MODEL', 'llama3.1:latest')}`\n"
                f"4. Or set OLLAMA_URL and OLLAMA_MODEL in your .env file"
            )
        else:
            llm_out = f"I apologize, but I'm unable to process your request right now. The AI service is unavailable: {error_msg}"
        self._analytics_log_tool_usage(
            tenant_id=req.tenant_id,
            tool_name="llm",
            success=False,
            error_message=error_msg[:200],
            user_id=req.user_id
        )
        reasoning_trace.append({
            "step": "error",
            "tool": "llm",
            "error": str(e)
        })
    total_latency_ms = int((time.time() - start_time) * 1000)
    self._analytics_log_agent_query(
        tenant_id=req.tenant_id,
        message_preview=req.message[:200],
        intent=intent,
        tools_used=tools_used,
        total_tokens=estimated_tokens,
        total_latency_ms=total_latency_ms,
        success=llm_ok,
        user_id=req.user_id
    )
    return AgentResponse(
        text=llm_out,
        decision=AgentDecision(action="respond", tool=None, tool_input=None, reason="default_llm"),
        reasoning_trace=reasoning_trace
    )
| def _build_prompt_with_rag(self, req: AgentRequest, rag_resp: Dict[str, Any]) -> str: | |
| snippets = [] | |
| if isinstance(rag_resp, dict): | |
| hits = rag_resp.get("results") or rag_resp.get("hits") or [] | |
| for h in hits[:5]: | |
| txt = h.get("text") or h.get("content") or str(h) | |
| snippets.append(txt) | |
| snippet_text = "\n---\n".join(snippets) or "" | |
| prompt = ( | |
| f"You are an assistant helping tenant {req.tenant_id}. Use the following retrieved documents to answer the user's question.\n" | |
| f"Documents:\n{snippet_text}\n\n" | |
| f"User question: {req.message}\nProvide a concise, accurate answer and cite the source snippets where appropriate." | |
| ) | |
| return prompt | |
| async def _execute_multi_step(self, req: AgentRequest, steps: List[Dict[str, Any]], | |
| decision: AgentDecision, tool_traces: List[Dict[str, Any]], | |
| reasoning_trace: List[Dict[str, Any]], | |
| pre_fetched_rag: Optional[Dict[str, Any]] = None) -> AgentResponse: | |
| """ | |
| Execute multiple tools in sequence or parallel and synthesize results with LLM. | |
| Supports parallel execution when steps are marked with "parallel" flag. | |
| """ | |
| start_time = time.time() | |
| rag_data = None | |
| web_data = None | |
| admin_data = None | |
| collected_data = [] | |
| tools_used = [] | |
| total_tokens = 0 | |
| # Check if any step has parallel execution flag | |
| parallel_step = None | |
| for step_info in steps: | |
| if step_info.get("parallel"): | |
| parallel_step = step_info | |
| break | |
| # Handle parallel execution if detected | |
| if parallel_step and parallel_step.get("parallel"): | |
| parallel_config = parallel_step.get("parallel") | |
| parallel_tasks = {} | |
| start_time_parallel = time.time() | |
| # Prepare parallel tasks with retry logic | |
| if "rag" in parallel_config: | |
| rag_query = parallel_config["rag"] | |
| if pre_fetched_rag: | |
| # Use pre-fetched RAG if available - create a simple async function | |
| async def get_prefetched_rag(): | |
| return pre_fetched_rag | |
| parallel_tasks["rag"] = get_prefetched_rag() | |
| else: | |
| # Wrap with retry logic for parallel execution | |
| async def rag_with_retry_wrapper(): | |
| return await self.rag_with_repair( | |
| query=rag_query, | |
| tenant_id=req.tenant_id, | |
| original_threshold=0.3, | |
| reasoning_trace=reasoning_trace, | |
| user_id=req.user_id | |
| ) | |
| parallel_tasks["rag"] = rag_with_retry_wrapper() | |
| if "web" in parallel_config: | |
| web_query = parallel_config["web"] | |
| # Wrap with retry logic for parallel execution | |
| async def web_with_retry_wrapper(): | |
| return await self.web_with_repair( | |
| query=web_query, | |
| tenant_id=req.tenant_id, | |
| reasoning_trace=reasoning_trace, | |
| user_id=req.user_id | |
| ) | |
| parallel_tasks["web"] = web_with_retry_wrapper() | |
| # Execute tools in parallel | |
| if parallel_tasks: | |
| reasoning_trace.append({ | |
| "step": "parallel_execution", | |
| "tools": list(parallel_tasks.keys()), | |
| "mode": "parallel" | |
| }) | |
| parallel_results = await self.run_parallel_tools(parallel_tasks) | |
| parallel_latency_ms = int((time.time() - start_time_parallel) * 1000) | |
| # Process RAG results | |
| if "rag" in parallel_results: | |
| rag_result = parallel_results["rag"] | |
| if isinstance(rag_result, Exception): | |
| tool_traces.append({"tool": "rag", "error": str(rag_result), "note": "parallel"}) | |
| reasoning_trace.append({ | |
| "step": "tool_execution", | |
| "tool": "rag", | |
| "status": "error", | |
| "error": str(rag_result), | |
| "latency_ms": parallel_latency_ms | |
| }) | |
| self._analytics_log_tool_usage( | |
| tenant_id=req.tenant_id, | |
| tool_name="rag", | |
| latency_ms=parallel_latency_ms, | |
| success=False, | |
| error_message=str(rag_result)[:200], | |
| user_id=req.user_id | |
| ) | |
| else: | |
| rag_data = rag_result | |
| tools_used.append("rag") | |
| tool_traces.append({"tool": "rag", "response": rag_result, "note": "parallel"}) | |
| hits_count = len(self._extract_hits(rag_result)) | |
| avg_score = None | |
| top_score = None | |
| if hits_count > 0: | |
| scores = [h.get("score", 0.0) for h in self._extract_hits(rag_result) if isinstance(h, dict) and "score" in h] | |
| if scores: | |
| avg_score = sum(scores) / len(scores) | |
| top_score = max(scores) | |
| self._analytics_log_rag_search( | |
| tenant_id=req.tenant_id, | |
| query=req.message[:500], | |
| hits_count=hits_count, | |
| avg_score=avg_score, | |
| top_score=top_score, | |
| latency_ms=parallel_latency_ms | |
| ) | |
| self._analytics_log_tool_usage( | |
| tenant_id=req.tenant_id, | |
| tool_name="rag", | |
| latency_ms=parallel_latency_ms, | |
| success=True, | |
| user_id=req.user_id | |
| ) | |
| reasoning_trace.append({ | |
| "step": "tool_execution", | |
| "tool": "rag", | |
| "hit_count": hits_count, | |
| "summary": self._summarize_hits(rag_result, limit=2), | |
| "latency_ms": parallel_latency_ms, | |
| "mode": "parallel" | |
| }) | |
| # Process Web results | |
| if "web" in parallel_results: | |
| web_result = parallel_results["web"] | |
| if isinstance(web_result, Exception): | |
| tool_traces.append({"tool": "web", "error": str(web_result), "note": "parallel"}) | |
| reasoning_trace.append({ | |
| "step": "tool_execution", | |
| "tool": "web", | |
| "status": "error", | |
| "error": str(web_result), | |
| "latency_ms": parallel_latency_ms | |
| }) | |
| self._analytics_log_tool_usage( | |
| tenant_id=req.tenant_id, | |
| tool_name="web", | |
| latency_ms=parallel_latency_ms, | |
| success=False, | |
| error_message=str(web_result)[:200], | |
| user_id=req.user_id | |
| ) | |
| else: | |
| web_data = web_result | |
| tools_used.append("web") | |
| tool_traces.append({"tool": "web", "response": web_result, "note": "parallel"}) | |
| hits_count = len(self._extract_hits(web_result)) | |
| self._analytics_log_tool_usage( | |
| tenant_id=req.tenant_id, | |
| tool_name="web", | |
| latency_ms=parallel_latency_ms, | |
| success=True, | |
| user_id=req.user_id | |
| ) | |
| reasoning_trace.append({ | |
| "step": "tool_execution", | |
| "tool": "web", | |
| "hit_count": hits_count, | |
| "summary": self._summarize_hits(web_result, limit=2), | |
| "latency_ms": parallel_latency_ms, | |
| "mode": "parallel" | |
| }) | |
| # Merge parallel results | |
| merged_context = merge_parallel_results(parallel_results) | |
| sources_list = list(set(e.get("source") for e in merged_context if e.get("source"))) if merged_context else [] | |
| reasoning_trace.append({ | |
| "step": "result_merger", | |
| "merged_items": len(merged_context), | |
| "sources": sources_list | |
| }) | |
| # Format merged context for prompt | |
| data_section = format_merged_context_for_prompt(merged_context, max_items=10) | |
| else: | |
| data_section = "" | |
| else: | |
| # Sequential execution (original logic) | |
| parallel_tasks = {} | |
| rag_parallel_query = self._first_query_for_tool(steps, "rag", req.message) | |
| web_parallel_query = self._first_query_for_tool(steps, "web", req.message) | |
| if rag_parallel_query and web_parallel_query and rag_parallel_query == web_parallel_query: | |
| if not pre_fetched_rag: | |
| parallel_tasks["rag"] = asyncio.create_task(self.mcp.call_rag(req.tenant_id, rag_parallel_query)) | |
| parallel_tasks["web"] = asyncio.create_task(self.mcp.call_web(req.tenant_id, web_parallel_query)) | |
| # Execute each step in sequence | |
| for step_info in steps: | |
| tool_name = step_info.get("tool") | |
| step_input = step_info.get("input") or {} | |
| query = step_input.get("query") or req.message | |
| try: | |
| if tool_name == "rag": | |
| # Reuse pre-fetched RAG if available, otherwise fetch with retry | |
| if pre_fetched_rag and query == rag_parallel_query: | |
| rag_resp = pre_fetched_rag | |
| tool_traces.append({"tool": "rag", "response": rag_resp, "note": "used_pre_fetched"}) | |
| elif parallel_tasks.get("rag") and query == rag_parallel_query: | |
| rag_resp = await parallel_tasks["rag"] | |
| tool_traces.append({"tool": "rag", "response": rag_resp, "note": "parallel"}) | |
| else: | |
| # Use autonomous retry with self-correction | |
| rag_resp = await self.rag_with_repair( | |
| query=query, | |
| tenant_id=req.tenant_id, | |
| original_threshold=0.3, | |
| reasoning_trace=reasoning_trace, | |
| user_id=req.user_id | |
| ) | |
| tool_traces.append({"tool": "rag", "response": rag_resp, "note": "with_retry"}) | |
| rag_data = rag_resp | |
| tools_used.append("rag") | |
| hits = self._extract_hits(rag_resp) | |
| reasoning_trace.append({ | |
| "step": "tool_execution", | |
| "tool": "rag", | |
| "hit_count": len(hits), | |
| "summary": self._summarize_hits(rag_resp, limit=2) | |
| }) | |
| # Extract snippets for prompt | |
| if isinstance(rag_resp, dict): | |
| for h in hits[:5]: | |
| txt = h.get("text") or h.get("content") or str(h) | |
| collected_data.append(f"[RAG] {txt}") | |
| elif tool_name == "web": | |
| if parallel_tasks.get("web") and query == web_parallel_query: | |
| web_resp = await parallel_tasks["web"] | |
| tool_traces.append({"tool": "web", "response": web_resp, "note": "parallel"}) | |
| else: | |
| # Use autonomous retry with query rewriting | |
| web_resp = await self.web_with_repair( | |
| query=query, | |
| tenant_id=req.tenant_id, | |
| reasoning_trace=reasoning_trace, | |
| user_id=req.user_id | |
| ) | |
| tool_traces.append({"tool": "web", "response": web_resp, "note": "with_retry"}) | |
| web_data = web_resp | |
| tools_used.append("web") | |
| hits = self._extract_hits(web_resp) | |
| reasoning_trace.append({ | |
| "step": "tool_execution", | |
| "tool": "web", | |
| "hit_count": len(hits), | |
| "summary": self._summarize_hits(web_resp, limit=2) | |
| }) | |
| # Extract snippets for prompt | |
| if isinstance(web_resp, dict): | |
| for h in hits[:5]: | |
| title = h.get("title") or h.get("headline") or "" | |
| snippet = h.get("snippet") or h.get("summary") or h.get("text") or "" | |
| url = h.get("url") or h.get("link") or "" | |
| collected_data.append(f"[WEB] {title}\n{snippet}\nSource: {url}") | |
| elif tool_name == "admin": | |
| admin_resp = await self.mcp.call_admin(req.tenant_id, query) | |
| tool_traces.append({"tool": "admin", "response": admin_resp}) | |
| admin_data = admin_resp | |
| tools_used.append("admin") | |
| collected_data.append(f"[ADMIN] {json.dumps(admin_resp)}") | |
| reasoning_trace.append({ | |
| "step": "tool_execution", | |
| "tool": "admin", | |
| "status": "completed" | |
| }) | |
| elif tool_name == "llm": | |
| # LLM is always last - synthesize all collected data | |
| break | |
| except Exception as e: | |
| tool_traces.append({"tool": tool_name, "error": str(e)}) | |
| # Continue with other tools even if one fails | |
| reasoning_trace.append({ | |
| "step": "error", | |
| "tool": tool_name, | |
| "error": str(e) | |
| }) | |
| # Build comprehensive prompt with all collected data | |
| data_section = "\n---\n".join(collected_data) if collected_data else "" | |
| # Build final prompt | |
| if data_section: | |
| prompt = ( | |
| f"You are an assistant helping tenant {req.tenant_id}.\n\n" | |
| f"## Information Collected\n" | |
| f"The following details have been gathered from multiple reliable sources:\n" | |
| f"{data_section}\n\n" | |
| f"## User Request\n" | |
| f"{req.message}\n\n" | |
| f"## Your Task\n" | |
| f"Use the information above to directly address the user's request. " | |
| f"Focus on giving the user exactly what they need—clear guidance, accurate facts, " | |
| f"and practical steps whenever possible. If the information is incomplete, explain " | |
| f"what can and cannot be concluded from the available data." | |
| ) | |
| else: | |
| # No data collected, just answer the question | |
| prompt = req.message | |
| # Final LLM synthesis | |
| try: | |
| llm_start = time.time() | |
| llm_out = await self.llm.simple_call(prompt, temperature=req.temperature) | |
| llm_latency_ms = int((time.time() - llm_start) * 1000) | |
| tools_used.append("llm") | |
| estimated_tokens = len(llm_out) // 4 + len(prompt) // 4 | |
| total_tokens += estimated_tokens | |
| self._analytics_log_tool_usage( | |
| tenant_id=req.tenant_id, | |
| tool_name="llm", | |
| latency_ms=llm_latency_ms, | |
| tokens_used=estimated_tokens, | |
| success=True, | |
| user_id=req.user_id | |
| ) | |
| total_latency_ms = int((time.time() - start_time) * 1000) | |
| self._analytics_log_agent_query( | |
| tenant_id=req.tenant_id, | |
| message_preview=req.message[:200], | |
| intent="multi_step", | |
| tools_used=tools_used, | |
| total_tokens=total_tokens, | |
| total_latency_ms=total_latency_ms, | |
| success=True, | |
| user_id=req.user_id | |
| ) | |
| return AgentResponse( | |
| text=llm_out, | |
| decision=decision, | |
| tool_traces=tool_traces, | |
| reasoning_trace=reasoning_trace + [{ | |
| "step": "llm_response", | |
| "mode": "multi_step_parallel" if parallel_step else "multi_step", | |
| "latency_ms": llm_latency_ms, | |
| "estimated_tokens": estimated_tokens | |
| }] | |
| ) | |
| except Exception as e: | |
| tool_traces.append({"tool": "llm", "error": str(e)}) | |
| error_msg = str(e) | |
| # Provide helpful error message | |
| if "Cannot connect" in error_msg or "Ollama" in error_msg: | |
| fallback = ( | |
| f"I couldn't connect to the AI service (Ollama). " | |
| f"Error: {error_msg}\n\n" | |
| f"To fix this:\n" | |
| f"1. Install Ollama from https://ollama.ai\n" | |
| f"2. Start Ollama: `ollama serve`\n" | |
| f"3. Pull the model: `ollama pull {os.getenv('OLLAMA_MODEL', 'llama3.1:latest')}`\n" | |
| f"4. Or set OLLAMA_URL and OLLAMA_MODEL in your .env file" | |
| ) | |
| else: | |
| fallback = f"I encountered an error while synthesizing the response: {error_msg}" | |
| return AgentResponse( | |
| text=fallback, | |
| decision=AgentDecision( | |
| action="respond", | |
| tool=None, | |
| tool_input=None, | |
| reason=f"multi_step_llm_error: {e}" | |
| ), | |
| tool_traces=tool_traces, | |
| reasoning_trace=reasoning_trace + [{ | |
| "step": "error", | |
| "tool": "llm", | |
| "error": str(e) | |
| }] | |
| ) | |
| # ============================================================= | |
| # AUTONOMOUS RETRY + SELF-CORRECTION SYSTEM | |
| # ============================================================= | |
| """ | |
| This system provides autonomous retry and self-correction capabilities | |
| for the agent orchestrator. It enables the agent to: | |
| 1. **Self-healing**: Tools that break automatically retry with adjusted parameters | |
| 2. **Resilient operations**: Handles low RAG scores, empty web results, and misfired rules | |
| 3. **Smart optimization**: Automatically rewrites queries, adjusts thresholds, and optimizes parameters | |
| 4. **Bounded retries**: Each strategy makes a small, fixed number of extra attempts (lowered thresholds, expanded or rewritten queries, fallback parameters), so retries never loop indefinitely | |
| Key features: | |
| - safe_tool_call(): Generic retry wrapper for any tool call | |
| - rag_with_repair(): RAG search with automatic threshold adjustment and query expansion | |
| - web_with_repair(): Web search with automatic query rewriting for empty results | |
| - rule_safe_message(): Message rewriting to comply with admin rules | |
| All retry attempts are logged to analytics for monitoring and debugging. | |
| """ | |
async def safe_tool_call(
    self,
    tool_fn,
    params: Dict[str, Any],
    max_retries: int = 2,
    fallback_params: Optional[Dict[str, Any]] = None,
    tool_name: str = "unknown",
    tenant_id: Optional[str] = None,
    user_id: Optional[str] = None,
    reasoning_trace: Optional[List[Dict[str, Any]]] = None
) -> Dict[str, Any]:
    """
    Call an async tool with automatic retry and optional parameter fallback.

    Args:
        tool_fn: Async callable, invoked as ``await tool_fn(**params)``.
        params: Keyword arguments for the first attempt.
        max_retries: Total number of attempts (not extra retries). Values
            below 1 are treated as 1, so the tool is always called once.
        fallback_params: Overrides merged into ``params`` after a failed
            attempt, when another attempt remains.
        tool_name: Tool label used in reasoning-trace entries and analytics.
        tenant_id: When set, failed attempts and recovered retries are
            logged to tool-usage analytics with their measured latency.
        user_id: User ID forwarded to analytics.
        reasoning_trace: Optional list that receives retry/recovery steps.

    Returns:
        The tool result on success. Otherwise
        ``{"error": "tool_failed_after_retries", "error_message": <last error>}``.
    """
    attempts = max(1, max_retries)  # always call the tool at least once
    error_msg = ""
    for attempt in range(attempts):
        started = time.perf_counter()
        try:
            result = await tool_fn(**params)
        except Exception as e:
            latency_ms = int((time.perf_counter() - started) * 1000)
            error_msg = str(e)
            if reasoning_trace is not None:
                reasoning_trace.append({
                    "step": "retry_attempt",
                    "tool": tool_name,
                    "attempt": attempt + 1,
                    "error": error_msg[:200]
                })
            if tenant_id:
                self._analytics_log_tool_usage(
                    tenant_id=tenant_id,
                    tool_name=tool_name,
                    latency_ms=latency_ms,
                    success=False,
                    error_message=error_msg[:200],
                    user_id=user_id
                )
            # Switch to the fallback parameters for the next attempt, if one remains
            if fallback_params and attempt < attempts - 1:
                params = {**params, **fallback_params}
                if reasoning_trace is not None:
                    reasoning_trace.append({
                        "step": "retry_with_fallback_params",
                        "tool": tool_name,
                        "attempt": attempt + 2,
                        "fallback_params": fallback_params
                    })
            continue

        latency_ms = int((time.perf_counter() - started) * 1000)
        # Logging sits outside the try so an analytics failure is not miscounted as a tool failure
        if attempt > 0:
            if reasoning_trace is not None:
                reasoning_trace.append({
                    "step": "retry_success",
                    "tool": tool_name,
                    "attempt": attempt + 1,
                    "status": "recovered"
                })
            if tenant_id:
                self._analytics_log_tool_usage(
                    tenant_id=tenant_id,
                    tool_name=f"{tool_name}_retry_{attempt+1}",
                    latency_ms=latency_ms,
                    success=True,
                    user_id=user_id
                )
        return result

    return {"error": "tool_failed_after_retries", "error_message": error_msg}
async def rag_with_repair(
    self,
    query: str,
    tenant_id: str,
    original_threshold: float = 0.3,
    reasoning_trace: Optional[List[Dict[str, Any]]] = None,
    user_id: Optional[str] = None
) -> Dict[str, Any]:
    """
    RAG search with automatic self-correction for low scores.

    Strategy:
    1. Search with ``original_threshold``.
    2. If the top score is below 0.30 and ``original_threshold`` is above
       0.15, search again with threshold=0.15.
    3. If the best top score is still below 0.15, search again with an
       expanded query (threshold=0.15).

    A retry only replaces the current result when it has hits and its top
    score is not lower. A retry that comes back empty or weaker therefore
    never discards hits found earlier.

    Args:
        query: Search text.
        tenant_id: Tenant passed to ``self.mcp.call_rag``.
        original_threshold: Threshold for the first search.
        reasoning_trace: Optional list that receives search/retry steps.
        user_id: User ID forwarded to tool-usage analytics.

    Returns:
        The best RAG response observed across all attempts. Exceptions from
        ``self.mcp.call_rag`` are not caught here.
    """
    retry_threshold = 0.15
    low_score = 0.30

    def _score(res):
        # (hits, top_score, avg_score); scores are None when no hit carries a "score"
        hits = self._extract_hits(res)
        scores = [h.get("score", 0.0) for h in hits if isinstance(h, dict) and "score" in h]
        if not scores:
            return hits, None, None
        return hits, max(scores), sum(scores) / len(scores)

    def _is_better(cand_hits, cand_top, cur_hits, cur_top):
        if not cand_hits:
            return False
        if not cur_hits:
            return True
        if cand_top is None:
            return False
        return cur_top is None or cand_top >= cur_top

    # Initial attempt
    rag_start = time.time()
    result = await self.mcp.call_rag(tenant_id, query, threshold=original_threshold)
    rag_latency_ms = int((time.time() - rag_start) * 1000)
    hits, top_score, avg_score = _score(result)
    if reasoning_trace is not None:
        reasoning_trace.append({
            "step": "rag_initial_search",
            "query": query[:200],
            "hits_count": len(hits),
            "top_score": top_score,
            "avg_score": avg_score,
            "threshold": original_threshold
        })

    # Retry 1: weak top score -> lower threshold (skipped when it would repeat the same call)
    if top_score is not None and top_score < low_score and original_threshold > retry_threshold:
        if reasoning_trace is not None:
            reasoning_trace.append({
                "step": "rag_retry_low_threshold",
                "reason": f"top_score {top_score:.3f} < 0.30, retrying with threshold=0.15"
            })
        retry_start = time.time()
        retry_result = await self.mcp.call_rag(tenant_id, query, threshold=retry_threshold)
        retry_latency_ms = int((time.time() - retry_start) * 1000)
        rag_latency_ms += retry_latency_ms
        r_hits, r_top, r_avg = _score(retry_result)
        if _is_better(r_hits, r_top, hits, top_score):
            result, hits, top_score, avg_score = retry_result, r_hits, r_top, r_avg
        self._analytics_log_tool_usage(
            tenant_id=tenant_id,
            tool_name="rag_retry_low_threshold",
            latency_ms=retry_latency_ms,
            success=bool(r_hits),
            user_id=user_id
        )

    # Retry 2: still very weak -> expand the query
    if top_score is not None and top_score < retry_threshold:
        expanded_query = f"{query} (more details comprehensive explanation)"
        if reasoning_trace is not None:
            reasoning_trace.append({
                "step": "rag_retry_expanded_query",
                "reason": f"top_score {top_score:.3f} < 0.15, retrying with expanded query",
                "original_query": query[:200],
                "expanded_query": expanded_query[:200]
            })
        retry_start = time.time()
        expanded_result = await self.mcp.call_rag(tenant_id, expanded_query, threshold=retry_threshold)
        retry_latency_ms = int((time.time() - retry_start) * 1000)
        rag_latency_ms += retry_latency_ms
        e_hits, e_top, e_avg = _score(expanded_result)
        self._analytics_log_tool_usage(
            tenant_id=tenant_id,
            tool_name="rag_retry_expanded_query",
            latency_ms=retry_latency_ms,
            success=bool(e_hits),
            user_id=user_id
        )
        if reasoning_trace is not None:
            reasoning_trace.append({
                "step": "rag_expanded_query_result",
                "hits_count": len(e_hits),
                "top_score": e_top,
                "avg_score": e_avg
            })
        if _is_better(e_hits, e_top, hits, top_score):
            result, hits, top_score, avg_score = expanded_result, e_hits, e_top, e_avg

    # Log the search that is actually returned
    if hits:
        self._analytics_log_rag_search(
            tenant_id=tenant_id,
            query=query[:500],
            hits_count=len(hits),
            avg_score=avg_score,
            top_score=top_score,
            latency_ms=rag_latency_ms
        )
    return result
async def web_with_repair(
    self,
    query: str,
    tenant_id: str,
    reasoning_trace: Optional[List[Dict[str, Any]]] = None,
    user_id: Optional[str] = None
) -> Dict[str, Any]:
    """
    Web search with automatic query rewriting when no results come back.

    Strategy:
    1. Search with the original query.
    2. If there are no hits, try "best explanation of {query}".
    3. If still empty, try "{query} facts summary".

    Each rewrite attempt is logged to tool-usage analytics as
    ``web_retry_rewrite_<n>``, with ``success`` reflecting whether it returned
    hits. A final ``web`` entry records the total latency and overall outcome.
    Exceptions from ``self.mcp.call_web`` are not caught here.

    Returns:
        The response of the last search performed: the first one with hits,
        or the final rewrite if none produced hits.
    """
    # Initial attempt
    web_start = time.time()
    result = await self.mcp.call_web(tenant_id, query)
    web_latency_ms = int((time.time() - web_start) * 1000)
    hits = self._extract_hits(result)
    if reasoning_trace is not None:
        reasoning_trace.append({
            "step": "web_initial_search",
            "query": query[:200],
            "hits_count": len(hits)
        })
    # Retry logic: empty results -> rewrite query
    if not result or len(hits) == 0:
        rewritten_queries = [
            f"best explanation of {query}",
            f"{query} facts summary"
        ]
        for i, rewritten in enumerate(rewritten_queries):
            if reasoning_trace is not None:
                reasoning_trace.append({
                    "step": "web_retry_rewritten",
                    "attempt": i + 1,
                    "original_query": query[:200],
                    "rewritten_query": rewritten[:200]
                })
            retry_start = time.time()
            result = await self.mcp.call_web(tenant_id, rewritten)
            retry_latency_ms = int((time.time() - retry_start) * 1000)
            web_latency_ms += retry_latency_ms
            hits = self._extract_hits(result)
            # An empty rewrite is a failed attempt, consistent with the final log below
            self._analytics_log_tool_usage(
                tenant_id=tenant_id,
                tool_name=f"web_retry_rewrite_{i+1}",
                latency_ms=retry_latency_ms,
                success=bool(hits),
                user_id=user_id
            )
            if hits:
                if reasoning_trace is not None:
                    reasoning_trace.append({
                        "step": "web_retry_success",
                        "rewritten_query": rewritten[:200],
                        "hits_count": len(hits)
                    })
                break
    # Log final web search
    self._analytics_log_tool_usage(
        tenant_id=tenant_id,
        tool_name="web",
        latency_ms=web_latency_ms,
        success=len(hits) > 0,
        user_id=user_id
    )
    return result
async def rule_safe_message(
    self,
    user_message: str,
    tenant_id: str,
    reasoning_trace: Optional[List[Dict[str, Any]]] = None
) -> str:
    """
    Check admin rules and, if a blocking rule matches, try an LLM rewrite.

    A match counts as non-blocking (a "brief response" rule) when its severity
    is "low" and its description/pattern mentions "greeting", "brief" or
    "simple response". The same test is applied when verifying the rewrite.

    Strategy:
    1. Check rules; return the message unchanged if nothing blocking matches.
    2. Ask the LLM to rewrite the message to comply.
    3. Return the rewrite only if it is non-empty and triggers no blocking
       rule; otherwise return the original message.
    """
    def _is_brief_rule(match) -> bool:
        rule_text = (match.description or match.pattern or "").lower()
        return match.severity == "low" and (
            "greeting" in rule_text
            or "brief" in rule_text
            or "simple response" in rule_text
        )

    matches: List[RedFlagMatch] = await self.redflag.check(tenant_id, user_message)
    if not matches:
        return user_message
    # Only rewrite if there are blocking rules (not just brief response rules)
    blocking_rules = [m for m in matches if not _is_brief_rule(m)]
    if not blocking_rules:
        return user_message
    if reasoning_trace is not None:
        reasoning_trace.append({
            "step": "rule_violation_detected",
            "blocking_rules_count": len(blocking_rules),
            "action": "attempting_rewrite"
        })
    policy_lines = "\n".join(f"- {m.description or m.pattern}" for m in blocking_rules[:3])
    rewrite_prompt = f"""The following user message violates company policies. Rewrite it to be compliant while preserving the user's intent as much as possible.
Original message: "{user_message}"
Violated policies:
{policy_lines}
Provide a rewritten version that:
1. Avoids the policy violations
2. Preserves the user's original intent
3. Remains professional and helpful
Rewritten message:"""
    try:
        rewritten = await self.llm.simple_call(rewrite_prompt, temperature=0.3)
        rewritten = rewritten.strip().strip('"').strip("'")
        if reasoning_trace is not None:
            reasoning_trace.append({
                "step": "rule_rewrite_completed",
                "original_length": len(user_message),
                "rewritten_length": len(rewritten),
                "rewritten_preview": rewritten[:200]
            })
        if not rewritten:
            # An empty rewrite would silently drop the user's message
            if reasoning_trace is not None:
                reasoning_trace.append({
                    "step": "rule_rewrite_failed",
                    "error": "empty_rewrite"
                })
            return user_message
        # Verify with the same brief-rule test used for the original message
        verify_matches = await self.redflag.check(tenant_id, rewritten)
        if all(_is_brief_rule(m) for m in verify_matches or []):
            return rewritten
        if reasoning_trace is not None:
            reasoning_trace.append({
                "step": "rule_rewrite_still_violates",
                "action": "using_original_with_block"
            })
    except Exception as e:
        if reasoning_trace is not None:
            reasoning_trace.append({
                "step": "rule_rewrite_failed",
                "error": str(e)[:200]
            })
    # Return original if rewrite failed or still violates
    return user_message
| def _build_prompt_with_web(self, req: AgentRequest, web_resp: Dict[str, Any]) -> str: | |
| snippets = [] | |
| if isinstance(web_resp, dict): | |
| hits = web_resp.get("results") or web_resp.get("items") or [] | |
| for h in hits[:5]: | |
| title = h.get("title") or h.get("headline") or "" | |
| snippet = h.get("snippet") or h.get("summary") or h.get("text") or "" | |
| url = h.get("url") or h.get("link") or "" | |
| snippets.append(f"{title}\n{snippet}\nSource: {url}") | |
| snippet_text = "\n---\n".join(snippets) or "" | |
| # prompt = ( | |
| # f"You are an assistant with access to recent web search results. Use the following results to answer.\n{snippet_text}\n\n" | |
| # f"User question: {req.message}\nAnswer succinctly and indicate which results you used." | |
| # ) | |
| prompt = ( | |
| f"You are an assistant with access to recent web search results.\n\n" | |
| f"## Search Results\n" | |
| f"{snippet_text}\n\n" | |
| f"## User Question\n" | |
| f"{req.message}\n\n" | |
| f"## Your Task\n" | |
| f"Provide a clear, accurate, and succinct answer based on the search results above. " | |
| f"Indicate which results you used in your reasoning." | |
| ) | |
| return prompt | |
| def _format_tool_output(self, tool_name: str, output: Any, latency_ms: int) -> Dict[str, Any]: | |
| """ | |
| Format tool output to conform to strict JSON schema. | |
| Args: | |
| tool_name: Name of the tool (rag, web, admin, llm) | |
| output: Raw tool output | |
| latency_ms: Actual latency in milliseconds | |
| Returns: | |
| Formatted output conforming to tool schema | |
| """ | |
| if tool_name == "rag": | |
| # Format RAG output | |
| if isinstance(output, dict): | |
| results = output.get("results") or output.get("hits") or [] | |
| # Ensure each result has required fields | |
| formatted_results = [] | |
| for r in results: | |
| if isinstance(r, dict): | |
| formatted_results.append({ | |
| "text": r.get("text") or r.get("content") or str(r), | |
| "similarity": float(r.get("similarity") or r.get("score") or 0.0), | |
| "metadata": r.get("metadata") or {}, | |
| "doc_id": r.get("doc_id") or r.get("id") | |
| }) | |
| else: | |
| formatted_results.append({ | |
| "text": str(r), | |
| "similarity": 0.5, | |
| "metadata": {}, | |
| "doc_id": None | |
| }) | |
| # Calculate aggregate scores | |
| scores = [r["similarity"] for r in formatted_results if r["similarity"] > 0] | |
| avg_score = sum(scores) / len(scores) if scores else 0.0 | |
| top_score = max(scores) if scores else 0.0 | |
| return { | |
| "results": formatted_results, | |
| "query": output.get("query", ""), | |
| "tenant_id": output.get("tenant_id", ""), | |
| "hits_count": len(formatted_results), | |
| "avg_score": round(avg_score, 3), | |
| "top_score": round(top_score, 3), | |
| "latency_ms": latency_ms | |
| } | |
| else: | |
| # Fallback for non-dict output | |
| return { | |
| "results": [{"text": str(output), "similarity": 0.5, "metadata": {}, "doc_id": None}], | |
| "query": "", | |
| "tenant_id": "", | |
| "hits_count": 1, | |
| "avg_score": 0.5, | |
| "top_score": 0.5, | |
| "latency_ms": latency_ms | |
| } | |
| elif tool_name == "web": | |
| # Format Web output | |
| if isinstance(output, dict): | |
| results = output.get("results") or output.get("items") or [] | |
| formatted_results = [] | |
| for r in results: | |
| if isinstance(r, dict): | |
| formatted_results.append({ | |
| "title": r.get("title") or r.get("headline") or "", | |
| "snippet": r.get("snippet") or r.get("summary") or r.get("text") or "", | |
| "link": r.get("url") or r.get("link") or "", | |
| "displayLink": r.get("displayLink") or r.get("display_link") or "" | |
| }) | |
| else: | |
| formatted_results.append({ | |
| "title": "", | |
| "snippet": str(r), | |
| "link": "", | |
| "displayLink": "" | |
| }) | |
| return { | |
| "results": formatted_results, | |
| "query": output.get("query", ""), | |
| "total_results": output.get("total_results") or output.get("totalResults") or len(formatted_results), | |
| "latency_ms": latency_ms | |
| } | |
| else: | |
| return { | |
| "results": [], | |
| "query": "", | |
| "total_results": 0, | |
| "latency_ms": latency_ms | |
| } | |
| elif tool_name == "admin": | |
| # Format Admin output | |
| if isinstance(output, dict): | |
| violations = output.get("violations") or output.get("matches") or [] | |
| formatted_violations = [] | |
| for v in violations: | |
| if isinstance(v, dict): | |
| formatted_violations.append({ | |
| "rule_id": v.get("rule_id") or v.get("id") or "", | |
| "rule_pattern": v.get("rule_pattern") or v.get("pattern") or "", | |
| "severity": v.get("severity", "medium"), | |
| "matched_text": v.get("matched_text") or v.get("text") or "", | |
| "confidence": float(v.get("confidence", 1.0)), | |
| "message_preview": v.get("message_preview") or v.get("preview") or "" | |
| }) | |
| return { | |
| "violations": formatted_violations, | |
| "checked": output.get("checked", True), | |
| "rules_count": output.get("rules_count") or output.get("rulesCount") or len(formatted_violations), | |
| "latency_ms": latency_ms | |
| } | |
| else: | |
| return { | |
| "violations": [], | |
| "checked": True, | |
| "rules_count": 0, | |
| "latency_ms": latency_ms | |
| } | |
| elif tool_name == "llm": | |
| # Format LLM output | |
| if isinstance(output, str): | |
| return { | |
| "text": output, | |
| "tokens_used": len(output) // 4, # Rough estimate | |
| "latency_ms": latency_ms, | |
| "model": getattr(self.llm, 'model', 'unknown'), | |
| "temperature": 0.0 | |
| } | |
| elif isinstance(output, dict): | |
| return { | |
| "text": output.get("text") or output.get("response") or str(output), | |
| "tokens_used": output.get("tokens_used") or output.get("tokens") or 0, | |
| "latency_ms": latency_ms, | |
| "model": output.get("model") or getattr(self.llm, 'model', 'unknown'), | |
| "temperature": output.get("temperature", 0.0) | |
| } | |
| else: | |
| return { | |
| "text": str(output), | |
| "tokens_used": 0, | |
| "latency_ms": latency_ms, | |
| "model": getattr(self.llm, 'model', 'unknown'), | |
| "temperature": 0.0 | |
| } | |
| # Unknown tool - return as-is | |
| return output if isinstance(output, dict) else {"output": str(output), "latency_ms": latency_ms} | |
| def _extract_hits(resp: Optional[Dict[str, Any]]) -> List[Dict[str, Any]]: | |
| if not isinstance(resp, dict): | |
| return [] | |
| return resp.get("results") or resp.get("hits") or resp.get("items") or [] | |
| def _summarize_hits(self, resp: Optional[Dict[str, Any]], limit: int = 3) -> List[str]: | |
| hits = self._extract_hits(resp) | |
| summaries = [] | |
| for hit in hits[:limit]: | |
| if isinstance(hit, dict): | |
| snippet = hit.get("text") or hit.get("content") or hit.get("snippet") or "" | |
| else: | |
| snippet = str(hit) | |
| summaries.append(snippet[:160]) | |
| return summaries | |
async def run_parallel_tools(self, tasks: Dict[str, Any]) -> Dict[str, Any]:
    """
    Await several tool coroutines concurrently and key the outcomes by tool name.

    Args:
        tasks: Mapping of tool name -> awaitable, e.g. {"rag": rag_coro, "web": web_coro}.

    Returns:
        Mapping of tool name -> outcome, in the same key order as ``tasks``.
        A tool that raised contributes its exception object as the value
        instead of propagating it.
    """
    if not tasks:
        return {}
    labels = list(tasks)
    # return_exceptions=True keeps one failing tool from cancelling the others
    outcomes = await asyncio.gather(*tasks.values(), return_exceptions=True)
    return dict(zip(labels, outcomes))
| def _first_query_for_tool(steps: List[Dict[str, Any]], tool_name: str, default_query: str) -> Optional[str]: | |
| for step in steps: | |
| if step.get("tool") == tool_name: | |
| input_data = step.get("input") or {} | |
| return input_data.get("query") or default_query | |
| return None | |