# =============================================================
# File: backend/api/services/agent_orchestrator.py
# =============================================================
"""
Agent Orchestrator (integrated with the enterprise RedFlagDetector).
Place at: backend/api/services/agent_orchestrator.py
"""
from __future__ import annotations

import asyncio
import json
import os
import random
import re
import time
from typing import List, Dict, Any, Optional

from ..models.agent import AgentRequest, AgentDecision, AgentResponse
from ..models.redflag import RedFlagMatch
from .redflag_detector import RedFlagDetector
from .intent_classifier import IntentClassifier
from .tool_selector import ToolSelector
from .llm_client import LLMClient
from ..mcp_clients.mcp_client import MCPClient
from .tool_scoring import ToolScoringService
from ..storage.analytics_store import AnalyticsStore
from .result_merger import merge_parallel_results, format_merged_context_for_prompt


class AgentOrchestrator:
    def __init__(self, rag_mcp_url: str, web_mcp_url: str, admin_mcp_url: str, llm_backend: str = "ollama"):
        self.mcp = MCPClient(rag_mcp_url, web_mcp_url, admin_mcp_url)
        self.llm = LLMClient(
            backend=llm_backend,
            url=os.getenv("OLLAMA_URL"),
            api_key=os.getenv("GROQ_API_KEY"),
            model=os.getenv("OLLAMA_MODEL")
        )
        # Pass admin_mcp_url so the detector can call back to the admin MCP.
        self.redflag = RedFlagDetector(
            supabase_url=os.getenv("SUPABASE_URL"),
            supabase_key=os.getenv("SUPABASE_SERVICE_KEY"),
            admin_mcp_url=admin_mcp_url
        )
        self.intent = IntentClassifier(llm_client=self.llm)
        self.selector = ToolSelector(llm_client=self.llm)
        self.tool_scorer = ToolScoringService()
        self.analytics = AnalyticsStore()
        # Log which analytics backend is in use (only once at startup).
        if not hasattr(AgentOrchestrator, '_analytics_backend_logged'):
            if self.analytics.use_supabase:
                print("✅ AgentOrchestrator Analytics: Using Supabase backend")
            else:
                print("⚠️ AgentOrchestrator Analytics: Using SQLite backend")
            AgentOrchestrator._analytics_backend_logged = True

    async def handle(self, req: AgentRequest) -> AgentResponse:
        start_time = time.time()
        reasoning_trace: List[Dict[str, Any]] = []
        reasoning_trace.append({
            "step": "request_received",
            "tenant_id": req.tenant_id,
            "user_id": req.user_id,
            "message_preview": req.message[:120]
        })

        # 1) FIRST: check admin rules -- if any rule matches, respond according to the rule.
        matches: List[RedFlagMatch] = await self.redflag.check(req.tenant_id, req.message)
        reasoning_trace.append({
            "step": "admin_rules_check",
            "match_count": len(matches),
            "matches": [m.__dict__ for m in matches]
        })

        if matches:
            # Log all rule matches.
            for match in matches:
                self.analytics.log_redflag_violation(
                    tenant_id=req.tenant_id,
                    rule_id=match.rule_id,
                    rule_pattern=match.pattern,
                    severity=match.severity,
                    matched_text=match.matched_text,
                    confidence=match.confidence,
                    message_preview=req.message[:200],
                    user_id=req.user_id
                )

            # Categorize rules: brief-response rules vs blocking rules.
            brief_response_rules = []
            blocking_rules = []
            brief_substrings = ["greeting", "brief", "simple response"]
            brief_patterns = [r"keep.*response.*brief", r"do not.*verbose", r"respond.*briefly"]
            for match in matches:
                rule_text = (match.description or match.pattern or "").lower()
                is_brief_rule = (
                    match.severity == "low" and (
                        any(s in rule_text for s in brief_substrings)
                        or any(re.search(p, rule_text) for p in brief_patterns)
                    )
                )
                if is_brief_rule:
                    brief_response_rules.append(match)
                else:
                    blocking_rules.append(match)

            # Handle brief-response rules (greetings, etc.) -- return immediately.
            if brief_response_rules and not blocking_rules:
                # Return a brief response without proceeding to the normal flow.
                brief_responses = [
                    "Hello! How can I help you today?",
                    "Hi there! What can I assist you with?",
                    "Hello! I'm here to help. What do you need?",
                    "Hi! How can I assist you?"
                ]
                brief_response = random.choice(brief_responses)
                reasoning_trace.append({
                    "step": "brief_response_rule_matched",
                    "action": "brief_response",
                    "matched_rules": [m.rule_id for m in brief_response_rules],
                    "message": "Brief response rule matched, returning brief response (skipping normal flow)"
                })
                total_latency_ms = int((time.time() - start_time) * 1000)
                self.analytics.log_agent_query(
                    tenant_id=req.tenant_id,
                    message_preview=req.message[:200],
                    intent="greeting",
                    tools_used=[],
                    total_tokens=len(brief_response) // 4,
                    total_latency_ms=total_latency_ms,
                    success=True,
                    user_id=req.user_id
                )
                return AgentResponse(
                    text=brief_response,
                    decision=AgentDecision(action="respond", tool=None, tool_input=None, reason="brief_response_rule"),
                    reasoning_trace=reasoning_trace
                )

            # Handle blocking rules (security, compliance, etc.) -- block and return immediately.
            if blocking_rules:
                # Notify the admin (best effort; failures are ignored).
                try:
                    await self.redflag.notify_admin(
                        req.tenant_id,
                        blocking_rules,
                        source_payload={"message": req.message, "user_id": req.user_id}
                    )
                except Exception:
                    pass
                decision = AgentDecision(
                    action="block",
                    tool="admin",
                    tool_input={"violations": [m.__dict__ for m in blocking_rules]},
                    reason="admin_rule_violation"
                )
                # Build a detailed prompt so the LLM can phrase a natural red-flag response.
                violations_details = []
                for i, m in enumerate(blocking_rules, 1):
                    rule_name = m.description or m.pattern or "Policy violation"
                    detail = f"{i}. **{rule_name}** (Severity: {m.severity})"
                    if m.matched_text:
                        detail += f"\n - Detected phrase: \"{m.matched_text}\""
                    violations_details.append(detail)

                llm_prompt = f"""A user made the following request: "{req.message}"

However, this request violates company policies. The following policy violations were detected:

{chr(10).join(violations_details)}

Your task: Write a clear, professional, and empathetic response to inform the user that:
1. Their request cannot be processed due to policy violations
2. Which specific policy was violated (mention it naturally)
3. The incident has been logged for security review
4. They should contact an administrator if they need assistance or believe this is an error

Write a natural, conversational response (2-4 sentences) that feels helpful rather than robotic. Be professional but understanding.

Response:"""

                # Generate the LLM response for the red flag.
                try:
                    # Slightly higher temperature for a more natural response.
                    llm_response = await self.llm.simple_call(llm_prompt, temperature=min(req.temperature + 0.2, 0.7))
                    llm_response = llm_response.strip()
                    # Add a warning emoji if not already present.
                    if not llm_response.startswith("⚠️") and not llm_response.startswith("🚨"):
                        llm_response = f"⚠️ {llm_response}"
                except Exception:
                    # Fall back to a simple message if the LLM fails.
                    summary = "; ".join(
                        f"{m.description or m.pattern}"
                        for m in matches
                    )
                    llm_response = (
                        f"⚠️ I'm unable to process your request because it violates our company policy: {summary}. "
                        f"This incident has been logged. Please contact your administrator if you need assistance."
                    )

                total_latency_ms = int((time.time() - start_time) * 1000)
                # Log LLM usage for the red-flag response.
                estimated_tokens = len(llm_response) // 4 + len(llm_prompt) // 4
                self.analytics.log_tool_usage(
                    tenant_id=req.tenant_id,
                    tool_name="llm",
                    latency_ms=total_latency_ms,
                    tokens_used=estimated_tokens,
                    success=True,
                    user_id=req.user_id
                )
                self.analytics.log_agent_query(
                    tenant_id=req.tenant_id,
                    message_preview=req.message[:200],
                    intent="admin",
                    tools_used=["admin", "llm"],
                    total_tokens=estimated_tokens,
                    total_latency_ms=total_latency_ms,
                    success=False,
                    user_id=req.user_id
                )
                return AgentResponse(
                    text=llm_response,
                    decision=decision,
                    tool_traces=[{"redflags": [m.__dict__ for m in blocking_rules]}],
                    reasoning_trace=reasoning_trace
                )

        # 2) ONLY IF NO RULES MATCHED: proceed with the normal flow (intent classification, RAG, etc.).
        intent = await self.intent.classify(req.message)
        reasoning_trace.append({
            "step": "intent_detection",
            "intent": intent
        })

        # 2.5) Pre-fetch RAG results if available (context for the tool selector).
        rag_prefetch = None
        rag_results = []
        try:
            # Try to pre-fetch RAG to help the tool selector make better decisions.
            rag_start = time.time()
            rag_prefetch = await self.mcp.call_rag(req.tenant_id, req.message)
            rag_latency_ms = int((time.time() - rag_start) * 1000)
            if isinstance(rag_prefetch, dict):
                rag_results = rag_prefetch.get("results") or rag_prefetch.get("hits") or []
            # Log the RAG search event.
            hits_count = len(rag_results)
            avg_score = None
            top_score = None
            if rag_results:
                scores = [h.get("score", 0.0) for h in rag_results if isinstance(h, dict) and "score" in h]
                if scores:
                    avg_score = sum(scores) / len(scores)
                    top_score = max(scores)
            self.analytics.log_rag_search(
                tenant_id=req.tenant_id,
                query=req.message[:500],
                hits_count=hits_count,
                avg_score=avg_score,
                top_score=top_score,
                latency_ms=rag_latency_ms
            )
            # Log tool usage.
            self.analytics.log_tool_usage(
                tenant_id=req.tenant_id,
                tool_name="rag",
                latency_ms=rag_latency_ms,
                success=True,
                user_id=req.user_id
            )
            reasoning_trace.append({
                "step": "rag_prefetch",
                "status": "ok",
                "hit_count": len(rag_results),
                "latency_ms": rag_latency_ms
            })
        except Exception as pref_err:
            # If RAG fails, continue without it.
            rag_latency_ms = 0  # 0 for failed calls
            self.analytics.log_tool_usage(
                tenant_id=req.tenant_id,
                tool_name="rag",
                latency_ms=rag_latency_ms,
                success=False,
                error_message=str(pref_err)[:200],
                user_id=req.user_id
            )
            reasoning_trace.append({
                "step": "rag_prefetch",
                "status": "error",
                "error": str(pref_err)
            })
            rag_prefetch = None

        tool_scores = self.tool_scorer.score(req.message, intent, rag_results)
        reasoning_trace.append({
            "step": "tool_scoring",
            "scores": tool_scores
        })

        # 3) Tool selection (hybrid) -- pass the RAG results in the context.
        ctx = {
            "tenant_id": req.tenant_id,
            "rag_results": rag_results,
            "tool_scores": tool_scores
        }
        decision = await self.selector.select(intent, req.message, ctx)
        reasoning_trace.append({
            "step": "tool_selection",
            "decision": decision.dict(),
            "context_scores": tool_scores
        })

        tool_traces: List[Dict[str, Any]] = []

        # 4) Handle multi-step tool execution.
        if decision.action == "multi_step" and decision.tool_input:
            steps = decision.tool_input.get("steps", [])
            if steps:
                return await self._execute_multi_step(
                    req,
                    steps,
                    decision,
                    tool_traces,
                    reasoning_trace,
                    rag_prefetch
                )

        # 5) Execute a single tool.
        tools_used = []
        total_tokens = 0
        if decision.action == "call_tool" and decision.tool:
            try:
                if decision.tool == "rag":
                    rag_start = time.time()
                    rag_resp = await self.mcp.call_rag(
                        req.tenant_id,
                        decision.tool_input.get("query") if decision.tool_input else req.message
                    )
                    rag_latency_ms = int((time.time() - rag_start) * 1000)
                    tools_used.append("rag")
                    tool_traces.append({"tool": "rag", "response": rag_resp})
                    hits = self._extract_hits(rag_resp)
                    # Log the RAG search and tool usage.
                    hits_count = len(hits)
                    avg_score = None
                    top_score = None
                    if hits:
                        scores = [h.get("score", 0.0) for h in hits if isinstance(h, dict) and "score" in h]
                        if scores:
                            avg_score = sum(scores) / len(scores)
                            top_score = max(scores)
                    self.analytics.log_rag_search(
                        tenant_id=req.tenant_id,
                        query=req.message[:500],
                        hits_count=hits_count,
                        avg_score=avg_score,
                        top_score=top_score,
                        latency_ms=rag_latency_ms
                    )
                    self.analytics.log_tool_usage(
                        tenant_id=req.tenant_id,
                        tool_name="rag",
                        latency_ms=rag_latency_ms,
                        success=True,
                        user_id=req.user_id
                    )
                    reasoning_trace.append({
                        "step": "tool_execution",
                        "tool": "rag",
                        "hit_count": hits_count,
                        "summary": self._summarize_hits(rag_resp, limit=2),
                        "latency_ms": rag_latency_ms
                    })
                    prompt = self._build_prompt_with_rag(req, rag_resp)
                    llm_start = time.time()
                    llm_out = await self.llm.simple_call(prompt, temperature=req.temperature)
                    llm_latency_ms = int((time.time() - llm_start) * 1000)
                    tools_used.append("llm")
                    # Estimate tokens (rough heuristic: ~4 characters per token).
                    estimated_tokens = len(llm_out) // 4 + len(prompt) // 4
                    total_tokens += estimated_tokens
                    self.analytics.log_tool_usage(
                        tenant_id=req.tenant_id,
                        tool_name="llm",
                        latency_ms=llm_latency_ms,
                        tokens_used=estimated_tokens,
                        success=True,
                        user_id=req.user_id
                    )
                    reasoning_trace.append({
                        "step": "llm_response",
                        "mode": "rag_synthesis",
                        "latency_ms": llm_latency_ms,
                        "estimated_tokens": estimated_tokens
                    })
                    total_latency_ms = int((time.time() - start_time) * 1000)
                    self.analytics.log_agent_query(
                        tenant_id=req.tenant_id,
                        message_preview=req.message[:200],
                        intent=intent,
                        tools_used=tools_used,
                        total_tokens=total_tokens,
                        total_latency_ms=total_latency_ms,
                        success=True,
                        user_id=req.user_id
                    )
                    return AgentResponse(text=llm_out, decision=decision, tool_traces=tool_traces, reasoning_trace=reasoning_trace)

                if decision.tool == "web":
                    web_start = time.time()
                    web_resp = await self.mcp.call_web(
                        req.tenant_id,
                        decision.tool_input.get("query") if decision.tool_input else req.message
                    )
                    web_latency_ms = int((time.time() - web_start) * 1000)
                    tools_used.append("web")
                    tool_traces.append({"tool": "web", "response": web_resp})
                    hits_count = len(self._extract_hits(web_resp))
                    self.analytics.log_tool_usage(
                        tenant_id=req.tenant_id,
                        tool_name="web",
                        latency_ms=web_latency_ms,
                        success=True,
                        user_id=req.user_id
                    )
                    reasoning_trace.append({
                        "step": "tool_execution",
                        "tool": "web",
                        "hit_count": hits_count,
                        "summary": self._summarize_hits(web_resp, limit=2),
                        "latency_ms": web_latency_ms
                    })
                    prompt = self._build_prompt_with_web(req, web_resp)
                    llm_start = time.time()
                    llm_out = await self.llm.simple_call(prompt, temperature=req.temperature)
                    llm_latency_ms = int((time.time() - llm_start) * 1000)
                    tools_used.append("llm")
                    estimated_tokens = len(llm_out) // 4 + len(prompt) // 4
                    total_tokens += estimated_tokens
                    self.analytics.log_tool_usage(
                        tenant_id=req.tenant_id,
                        tool_name="llm",
                        latency_ms=llm_latency_ms,
                        tokens_used=estimated_tokens,
                        success=True,
                        user_id=req.user_id
                    )
                    reasoning_trace.append({
                        "step": "llm_response",
                        "mode": "web_synthesis",
                        "latency_ms": llm_latency_ms,
                        "estimated_tokens": estimated_tokens
                    })
                    total_latency_ms = int((time.time() - start_time) * 1000)
                    self.analytics.log_agent_query(
                        tenant_id=req.tenant_id,
                        message_preview=req.message[:200],
                        intent=intent,
                        tools_used=tools_used,
                        total_tokens=total_tokens,
                        total_latency_ms=total_latency_ms,
                        success=True,
                        user_id=req.user_id
                    )
                    return AgentResponse(text=llm_out, decision=decision, tool_traces=tool_traces, reasoning_trace=reasoning_trace)

                if decision.tool == "admin":
                    admin_start = time.time()
                    admin_resp = await self.mcp.call_admin(
                        req.tenant_id,
                        decision.tool_input.get("query") if decision.tool_input else req.message
                    )
                    admin_latency_ms = int((time.time() - admin_start) * 1000)
                    tools_used.append("admin")
                    self.analytics.log_tool_usage(
                        tenant_id=req.tenant_id,
                        tool_name="admin",
                        latency_ms=admin_latency_ms,
                        success=True,
                        user_id=req.user_id
                    )
                    tool_traces.append({"tool": "admin", "response": admin_resp})
                    reasoning_trace.append({
                        "step": "tool_execution",
                        "tool": "admin",
                        "status": "completed",
                        "latency_ms": admin_latency_ms
                    })
                    total_latency_ms = int((time.time() - start_time) * 1000)
                    self.analytics.log_agent_query(
                        tenant_id=req.tenant_id,
                        message_preview=req.message[:200],
                        intent=intent,
                        tools_used=tools_used,
                        total_tokens=0,
                        total_latency_ms=total_latency_ms,
                        success=True,
                        user_id=req.user_id
                    )
                    return AgentResponse(text=json.dumps(admin_resp), decision=decision, tool_traces=tool_traces, reasoning_trace=reasoning_trace)

                if decision.tool == "llm":
                    llm_start = time.time()
                    llm_out = await self.llm.simple_call(req.message, temperature=req.temperature)
                    llm_latency_ms = int((time.time() - llm_start) * 1000)
                    tools_used.append("llm")
                    estimated_tokens = len(llm_out) // 4 + len(req.message) // 4
                    total_tokens += estimated_tokens
                    self.analytics.log_tool_usage(
                        tenant_id=req.tenant_id,
                        tool_name="llm",
                        latency_ms=llm_latency_ms,
                        tokens_used=estimated_tokens,
                        success=True,
                        user_id=req.user_id
                    )
                    reasoning_trace.append({
                        "step": "llm_response",
                        "mode": "direct",
                        "latency_ms": llm_latency_ms,
                        "estimated_tokens": estimated_tokens
                    })
                    total_latency_ms = int((time.time() - start_time) * 1000)
                    self.analytics.log_agent_query(
                        tenant_id=req.tenant_id,
                        message_preview=req.message[:200],
                        intent=intent,
                        tools_used=tools_used,
                        total_tokens=total_tokens,
                        total_latency_ms=total_latency_ms,
                        success=True,
                        user_id=req.user_id
                    )
                    return AgentResponse(text=llm_out, decision=decision, reasoning_trace=reasoning_trace)
            except Exception as e:
                tool_traces.append({"tool": decision.tool, "error": str(e)})
                try:
                    fallback = await self.llm.simple_call(req.message, temperature=req.temperature)
                except Exception as llm_error:
                    error_msg = str(llm_error)
                    if "Cannot connect" in error_msg or "Ollama" in error_msg:
                        fallback = (
                            f"I encountered an error while processing your request: {str(e)}\n\n"
                            f"Additionally, the AI service (Ollama) is unavailable: {error_msg}\n\n"
                            f"To fix:\n"
                            f"1. Install Ollama from https://ollama.ai\n"
                            f"2. Start: `ollama serve`\n"
                            f"3. Pull model: `ollama pull {os.getenv('OLLAMA_MODEL', 'llama3.1:latest')}`"
                        )
                    else:
                        fallback = (
                            f"I encountered an error while processing your request: {str(e)}. "
                            f"Additionally, the AI service is unavailable: {error_msg}"
                        )
                return AgentResponse(
                    text=fallback,
                    decision=AgentDecision(action="respond", tool=None, tool_input=None, reason=f"tool_error_fallback: {e}"),
                    tool_traces=tool_traces,
                    reasoning_trace=reasoning_trace + [{
                        "step": "error",
                        "tool": decision.tool,
                        "error": str(e)
                    }]
                )

        # Default: direct LLM response.
        try:
            llm_start = time.time()
            llm_out = await self.llm.simple_call(req.message, temperature=req.temperature)
            llm_latency_ms = int((time.time() - llm_start) * 1000)
            tools_used = ["llm"]
            estimated_tokens = len(llm_out) // 4 + len(req.message) // 4
            self.analytics.log_tool_usage(
                tenant_id=req.tenant_id,
                tool_name="llm",
                latency_ms=llm_latency_ms,
                tokens_used=estimated_tokens,
                success=True,
                user_id=req.user_id
            )
        except Exception as e:
            # If the LLM fails, return a helpful error message.
            error_msg = str(e)
            if "Cannot connect" in error_msg or "Ollama" in error_msg:
                llm_out = (
                    f"I couldn't connect to the AI service (Ollama). "
                    f"Error: {error_msg}\n\n"
                    f"To fix this:\n"
                    f"1. Install Ollama from https://ollama.ai\n"
                    f"2. Start Ollama: `ollama serve`\n"
                    f"3. Pull the model: `ollama pull {os.getenv('OLLAMA_MODEL', 'llama3.1:latest')}`\n"
                    f"4. Or set OLLAMA_URL and OLLAMA_MODEL in your .env file"
                )
            else:
                llm_out = f"I apologize, but I'm unable to process your request right now. The AI service is unavailable: {error_msg}"
            self.analytics.log_tool_usage(
                tenant_id=req.tenant_id,
                tool_name="llm",
                success=False,
                error_message=error_msg[:200],
                user_id=req.user_id
            )
            reasoning_trace.append({
                "step": "error",
                "tool": "llm",
                "error": str(e)
            })

        total_latency_ms = int((time.time() - start_time) * 1000)
        self.analytics.log_agent_query(
            tenant_id=req.tenant_id,
            message_preview=req.message[:200],
            intent=intent,
            tools_used=tools_used if 'tools_used' in locals() else [],
            total_tokens=estimated_tokens if 'estimated_tokens' in locals() else 0,
            total_latency_ms=total_latency_ms,
            success=True if 'llm_out' in locals() else False,
            user_id=req.user_id
        )
        return AgentResponse(
            text=llm_out,
            decision=AgentDecision(action="respond", tool=None, tool_input=None, reason="default_llm"),
            reasoning_trace=reasoning_trace
        )
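
    # Illustrative call pattern for handle() -- a sketch only; it assumes AgentRequest is
    # constructed with the keyword fields read above (tenant_id, user_id, message,
    # temperature), which is not confirmed by this file:
    #
    #   orchestrator = AgentOrchestrator(rag_url, web_url, admin_url)
    #   resp = await orchestrator.handle(AgentRequest(
    #       tenant_id="acme", user_id="u-123",
    #       message="What is our refund policy?", temperature=0.2,
    #   ))
    #   print(resp.text)              # final answer (or block/fallback message)
    #   print(resp.reasoning_trace)   # the step-by-step trace built above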

    def _build_prompt_with_rag(self, req: AgentRequest, rag_resp: Dict[str, Any]) -> str:
        snippets = []
        if isinstance(rag_resp, dict):
            hits = rag_resp.get("results") or rag_resp.get("hits") or []
            for h in hits[:5]:
                txt = h.get("text") or h.get("content") or str(h)
                snippets.append(txt)
        snippet_text = "\n---\n".join(snippets) or ""
        prompt = (
            f"You are an assistant helping tenant {req.tenant_id}. Use the following retrieved documents to answer the user's question.\n"
            f"Documents:\n{snippet_text}\n\n"
            f"User question: {req.message}\nProvide a concise, accurate answer and cite the source snippets where appropriate."
        )
        return prompt
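
    # Expected rag_resp shape -- a sketch inferred from the keys read above; real MCP
    # payloads may carry additional fields:
    #   {"results": [{"text": "...", "score": 0.87}, ...]}
    #   ("hits" is accepted in place of "results", and "content" in place of "text")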

    async def _execute_multi_step(self, req: AgentRequest, steps: List[Dict[str, Any]],
                                  decision: AgentDecision, tool_traces: List[Dict[str, Any]],
                                  reasoning_trace: List[Dict[str, Any]],
                                  pre_fetched_rag: Optional[Dict[str, Any]] = None) -> AgentResponse:
        """
        Execute multiple tools in sequence or in parallel and synthesize the results with the LLM.
        Supports parallel execution when a step is marked with the "parallel" flag.
        """
        start_time = time.time()
        rag_data = None
        web_data = None
        admin_data = None
        collected_data = []
        tools_used = []
        total_tokens = 0

        # Check whether any step carries a parallel-execution flag.
        parallel_step = None
        for step_info in steps:
            if step_info.get("parallel"):
                parallel_step = step_info
                break

        # Handle parallel execution if detected.
        if parallel_step and parallel_step.get("parallel"):
            parallel_config = parallel_step.get("parallel")
            parallel_tasks = {}
            start_time_parallel = time.time()

            # Prepare the parallel tasks.
            if "rag" in parallel_config:
                rag_query = parallel_config["rag"]
                if pre_fetched_rag:
                    # Use the pre-fetched RAG result if available -- wrap it in a trivial coroutine.
                    async def get_prefetched_rag():
                        return pre_fetched_rag
                    parallel_tasks["rag"] = get_prefetched_rag()
                else:
                    parallel_tasks["rag"] = self.mcp.call_rag(req.tenant_id, rag_query)
            if "web" in parallel_config:
                web_query = parallel_config["web"]
                parallel_tasks["web"] = self.mcp.call_web(req.tenant_id, web_query)

            # Execute the tools in parallel.
            if parallel_tasks:
                reasoning_trace.append({
                    "step": "parallel_execution",
                    "tools": list(parallel_tasks.keys()),
                    "mode": "parallel"
                })
                parallel_results = await self.run_parallel_tools(parallel_tasks)
                parallel_latency_ms = int((time.time() - start_time_parallel) * 1000)

                # Process the RAG results.
                if "rag" in parallel_results:
                    rag_result = parallel_results["rag"]
                    if isinstance(rag_result, Exception):
                        tool_traces.append({"tool": "rag", "error": str(rag_result), "note": "parallel"})
                        reasoning_trace.append({
                            "step": "tool_execution",
                            "tool": "rag",
                            "status": "error",
                            "error": str(rag_result),
                            "latency_ms": parallel_latency_ms
                        })
                        self.analytics.log_tool_usage(
                            tenant_id=req.tenant_id,
                            tool_name="rag",
                            latency_ms=parallel_latency_ms,
                            success=False,
                            error_message=str(rag_result)[:200],
                            user_id=req.user_id
                        )
                    else:
                        rag_data = rag_result
                        tools_used.append("rag")
                        tool_traces.append({"tool": "rag", "response": rag_result, "note": "parallel"})
                        hits_count = len(self._extract_hits(rag_result))
                        avg_score = None
                        top_score = None
                        if hits_count > 0:
                            scores = [h.get("score", 0.0) for h in self._extract_hits(rag_result) if isinstance(h, dict) and "score" in h]
                            if scores:
                                avg_score = sum(scores) / len(scores)
                                top_score = max(scores)
                        self.analytics.log_rag_search(
                            tenant_id=req.tenant_id,
                            query=req.message[:500],
                            hits_count=hits_count,
                            avg_score=avg_score,
                            top_score=top_score,
                            latency_ms=parallel_latency_ms
                        )
                        self.analytics.log_tool_usage(
                            tenant_id=req.tenant_id,
                            tool_name="rag",
                            latency_ms=parallel_latency_ms,
                            success=True,
                            user_id=req.user_id
                        )
                        reasoning_trace.append({
                            "step": "tool_execution",
                            "tool": "rag",
                            "hit_count": hits_count,
                            "summary": self._summarize_hits(rag_result, limit=2),
                            "latency_ms": parallel_latency_ms,
                            "mode": "parallel"
                        })

                # Process the web results.
                if "web" in parallel_results:
                    web_result = parallel_results["web"]
                    if isinstance(web_result, Exception):
                        tool_traces.append({"tool": "web", "error": str(web_result), "note": "parallel"})
                        reasoning_trace.append({
                            "step": "tool_execution",
                            "tool": "web",
                            "status": "error",
                            "error": str(web_result),
                            "latency_ms": parallel_latency_ms
                        })
                        self.analytics.log_tool_usage(
                            tenant_id=req.tenant_id,
                            tool_name="web",
                            latency_ms=parallel_latency_ms,
                            success=False,
                            error_message=str(web_result)[:200],
                            user_id=req.user_id
                        )
                    else:
                        web_data = web_result
                        tools_used.append("web")
                        tool_traces.append({"tool": "web", "response": web_result, "note": "parallel"})
                        hits_count = len(self._extract_hits(web_result))
                        self.analytics.log_tool_usage(
                            tenant_id=req.tenant_id,
                            tool_name="web",
                            latency_ms=parallel_latency_ms,
                            success=True,
                            user_id=req.user_id
                        )
                        reasoning_trace.append({
                            "step": "tool_execution",
                            "tool": "web",
                            "hit_count": hits_count,
                            "summary": self._summarize_hits(web_result, limit=2),
                            "latency_ms": parallel_latency_ms,
                            "mode": "parallel"
                        })

                # Merge the parallel results.
                merged_context = merge_parallel_results(parallel_results)
                sources_list = list(set(e.get("source") for e in merged_context if e.get("source"))) if merged_context else []
                reasoning_trace.append({
                    "step": "result_merger",
                    "merged_items": len(merged_context),
                    "sources": sources_list
                })
                # Format the merged context for the prompt.
                data_section = format_merged_context_for_prompt(merged_context, max_items=10)
            else:
                data_section = ""
        else:
            # Sequential execution (original logic).
            parallel_tasks = {}
            rag_parallel_query = self._first_query_for_tool(steps, "rag", req.message)
            web_parallel_query = self._first_query_for_tool(steps, "web", req.message)
            if rag_parallel_query and web_parallel_query and rag_parallel_query == web_parallel_query:
                if not pre_fetched_rag:
                    parallel_tasks["rag"] = asyncio.create_task(self.mcp.call_rag(req.tenant_id, rag_parallel_query))
                parallel_tasks["web"] = asyncio.create_task(self.mcp.call_web(req.tenant_id, web_parallel_query))

            # Execute each step in sequence.
            for step_info in steps:
                tool_name = step_info.get("tool")
                step_input = step_info.get("input") or {}
                query = step_input.get("query") or req.message
                try:
                    if tool_name == "rag":
                        # Reuse the pre-fetched RAG result if available, otherwise fetch.
                        if pre_fetched_rag and query == rag_parallel_query:
                            rag_resp = pre_fetched_rag
                            tool_traces.append({"tool": "rag", "response": rag_resp, "note": "used_pre_fetched"})
                        elif parallel_tasks.get("rag") and query == rag_parallel_query:
                            rag_resp = await parallel_tasks["rag"]
                            tool_traces.append({"tool": "rag", "response": rag_resp, "note": "parallel"})
                        else:
                            rag_resp = await self.mcp.call_rag(req.tenant_id, query)
                            tool_traces.append({"tool": "rag", "response": rag_resp})
                        rag_data = rag_resp
                        tools_used.append("rag")
                        reasoning_trace.append({
                            "step": "tool_execution",
                            "tool": "rag",
                            "hit_count": len(self._extract_hits(rag_resp)),
                            "summary": self._summarize_hits(rag_resp, limit=2)
                        })
                        # Extract snippets for the prompt.
                        if isinstance(rag_resp, dict):
                            hits = rag_resp.get("results") or rag_resp.get("hits") or []
                            for h in hits[:5]:
                                txt = h.get("text") or h.get("content") or str(h)
                                collected_data.append(f"[RAG] {txt}")
                    elif tool_name == "web":
                        if parallel_tasks.get("web") and query == web_parallel_query:
                            web_resp = await parallel_tasks["web"]
                            tool_traces.append({"tool": "web", "response": web_resp, "note": "parallel"})
                        else:
                            web_resp = await self.mcp.call_web(req.tenant_id, query)
                            tool_traces.append({"tool": "web", "response": web_resp})
                        web_data = web_resp
                        tools_used.append("web")
                        reasoning_trace.append({
                            "step": "tool_execution",
                            "tool": "web",
                            "hit_count": len(self._extract_hits(web_resp)),
                            "summary": self._summarize_hits(web_resp, limit=2)
                        })
                        # Extract snippets for the prompt.
                        if isinstance(web_resp, dict):
                            hits = web_resp.get("results") or web_resp.get("items") or []
                            for h in hits[:5]:
                                title = h.get("title") or h.get("headline") or ""
                                snippet = h.get("snippet") or h.get("summary") or h.get("text") or ""
                                url = h.get("url") or h.get("link") or ""
                                collected_data.append(f"[WEB] {title}\n{snippet}\nSource: {url}")
                    elif tool_name == "admin":
                        admin_resp = await self.mcp.call_admin(req.tenant_id, query)
                        tool_traces.append({"tool": "admin", "response": admin_resp})
                        admin_data = admin_resp
                        tools_used.append("admin")
                        collected_data.append(f"[ADMIN] {json.dumps(admin_resp)}")
                        reasoning_trace.append({
                            "step": "tool_execution",
                            "tool": "admin",
                            "status": "completed"
                        })
                    elif tool_name == "llm":
                        # The LLM step is always last -- break out and synthesize all collected data below.
                        break
                except Exception as e:
                    tool_traces.append({"tool": tool_name, "error": str(e)})
                    # Continue with the other tools even if one fails.
                    reasoning_trace.append({
                        "step": "error",
                        "tool": tool_name,
                        "error": str(e)
                    })

            # Build a comprehensive prompt with all collected data.
            data_section = "\n---\n".join(collected_data) if collected_data else ""

        # Build the final prompt.
        if data_section:
            prompt = (
                f"You are an assistant helping tenant {req.tenant_id}.\n\n"
                f"## Information Collected\n"
                f"The following details have been gathered from multiple reliable sources:\n"
                f"{data_section}\n\n"
                f"## User Request\n"
                f"{req.message}\n\n"
                f"## Your Task\n"
                f"Use the information above to directly address the user's request. "
                f"Focus on giving the user exactly what they need—clear guidance, accurate facts, "
                f"and practical steps whenever possible. If the information is incomplete, explain "
                f"what can and cannot be concluded from the available data."
            )
        else:
            # No data collected; just answer the question.
            prompt = req.message

        # Final LLM synthesis.
        try:
            llm_start = time.time()
            llm_out = await self.llm.simple_call(prompt, temperature=req.temperature)
            llm_latency_ms = int((time.time() - llm_start) * 1000)
            tools_used.append("llm")
            estimated_tokens = len(llm_out) // 4 + len(prompt) // 4
            total_tokens += estimated_tokens
            self.analytics.log_tool_usage(
                tenant_id=req.tenant_id,
                tool_name="llm",
                latency_ms=llm_latency_ms,
                tokens_used=estimated_tokens,
                success=True,
                user_id=req.user_id
            )
            total_latency_ms = int((time.time() - start_time) * 1000)
            self.analytics.log_agent_query(
                tenant_id=req.tenant_id,
                message_preview=req.message[:200],
                intent="multi_step",
                tools_used=tools_used,
                total_tokens=total_tokens,
                total_latency_ms=total_latency_ms,
                success=True,
                user_id=req.user_id
            )
            return AgentResponse(
                text=llm_out,
                decision=decision,
                tool_traces=tool_traces,
                reasoning_trace=reasoning_trace + [{
                    "step": "llm_response",
                    "mode": "multi_step_parallel" if parallel_step else "multi_step",
                    "latency_ms": llm_latency_ms,
                    "estimated_tokens": estimated_tokens
                }]
            )
        except Exception as e:
            tool_traces.append({"tool": "llm", "error": str(e)})
            error_msg = str(e)
            # Provide a helpful error message.
            if "Cannot connect" in error_msg or "Ollama" in error_msg:
                fallback = (
                    f"I couldn't connect to the AI service (Ollama). "
                    f"Error: {error_msg}\n\n"
                    f"To fix this:\n"
                    f"1. Install Ollama from https://ollama.ai\n"
                    f"2. Start Ollama: `ollama serve`\n"
                    f"3. Pull the model: `ollama pull {os.getenv('OLLAMA_MODEL', 'llama3.1:latest')}`\n"
                    f"4. Or set OLLAMA_URL and OLLAMA_MODEL in your .env file"
                )
            else:
                fallback = f"I encountered an error while synthesizing the response: {error_msg}"
            return AgentResponse(
                text=fallback,
                decision=AgentDecision(
                    action="respond",
                    tool=None,
                    tool_input=None,
                    reason=f"multi_step_llm_error: {e}"
                ),
                tool_traces=tool_traces,
                reasoning_trace=reasoning_trace + [{
                    "step": "error",
                    "tool": "llm",
                    "error": str(e)
                }]
            )
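
    # Illustrative `steps` payloads for _execute_multi_step -- a sketch inferred from the
    # keys read above; the actual schema is produced by ToolSelector:
    #
    #   Sequential: [{"tool": "rag", "input": {"query": "..."}},
    #                {"tool": "web", "input": {"query": "..."}},
    #                {"tool": "llm"}]
    #
    #   Parallel:   [{"parallel": {"rag": "query for internal docs", "web": "query for the web"}}]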

    def _build_prompt_with_web(self, req: AgentRequest, web_resp: Dict[str, Any]) -> str:
        snippets = []
        if isinstance(web_resp, dict):
            hits = web_resp.get("results") or web_resp.get("items") or []
            for h in hits[:5]:
                title = h.get("title") or h.get("headline") or ""
                snippet = h.get("snippet") or h.get("summary") or h.get("text") or ""
                url = h.get("url") or h.get("link") or ""
                snippets.append(f"{title}\n{snippet}\nSource: {url}")
        snippet_text = "\n---\n".join(snippets) or ""
        # prompt = (
        #     f"You are an assistant with access to recent web search results. Use the following results to answer.\n{snippet_text}\n\n"
        #     f"User question: {req.message}\nAnswer succinctly and indicate which results you used."
        # )
        prompt = (
            f"You are an assistant with access to recent web search results.\n\n"
            f"## Search Results\n"
            f"{snippet_text}\n\n"
            f"## User Question\n"
            f"{req.message}\n\n"
            f"## Your Task\n"
            f"Provide a clear, accurate, and succinct answer based on the search results above. "
            f"Indicate which results you used in your reasoning."
        )
        return prompt
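
    # Expected web_resp shape -- a sketch inferred from the keys read above:
    #   {"results": [{"title": "...", "snippet": "...", "url": "https://..."}, ...]}
    #   ("items", "headline", "summary"/"text", and "link" are accepted as alternatives)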

    def _extract_hits(self, resp: Optional[Dict[str, Any]]) -> List[Dict[str, Any]]:
        if not isinstance(resp, dict):
            return []
        return resp.get("results") or resp.get("hits") or resp.get("items") or []

    def _summarize_hits(self, resp: Optional[Dict[str, Any]], limit: int = 3) -> List[str]:
        hits = self._extract_hits(resp)
        summaries = []
        for hit in hits[:limit]:
            if isinstance(hit, dict):
                snippet = hit.get("text") or hit.get("content") or hit.get("snippet") or ""
            else:
                snippet = str(hit)
            summaries.append(snippet[:160])
        return summaries

    async def run_parallel_tools(self, tasks: Dict[str, Any]) -> Dict[str, Any]:
        """
        Run multiple tools in parallel using asyncio.gather.

        Args:
            tasks: Dictionary mapping tool names to coroutines, e.g.:
                {"rag": rag_coro, "web": web_coro}

        Returns:
            Dictionary mapping tool names to results, e.g.:
                {"rag": rag_result, "web": web_result}
            Exceptions are returned as values if a tool fails.
        """
        if not tasks:
            return {}
        names = list(tasks.keys())
        coros = [tasks[name] for name in names]
        # Run all coroutines in parallel, returning exceptions instead of raising.
        results = await asyncio.gather(*coros, return_exceptions=True)
        return {names[i]: results[i] for i in range(len(names))}
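
    # Example (illustrative; mirrors the docstring above):
    #   results = await self.run_parallel_tools({
    #       "rag": self.mcp.call_rag(tenant_id, query),
    #       "web": self.mcp.call_web(tenant_id, query),
    #   })
    #   # results["rag"] is either the tool's response dict or the Exception it raised.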

    def _first_query_for_tool(self, steps: List[Dict[str, Any]], tool_name: str, default_query: str) -> Optional[str]:
        for step in steps:
            if step.get("tool") == tool_name:
                input_data = step.get("input") or {}
                return input_data.get("query") or default_query
        return None
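

# ---------------------------------------------------------------------------
# Manual smoke test -- an illustrative sketch only. The RAG_MCP_URL / WEB_MCP_URL /
# ADMIN_MCP_URL environment variables, the localhost defaults, and the AgentRequest
# keyword arguments below are assumptions; adjust them to your deployment.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    async def _demo() -> None:
        orchestrator = AgentOrchestrator(
            rag_mcp_url=os.getenv("RAG_MCP_URL", "http://localhost:8001"),
            web_mcp_url=os.getenv("WEB_MCP_URL", "http://localhost:8002"),
            admin_mcp_url=os.getenv("ADMIN_MCP_URL", "http://localhost:8003"),
        )
        resp = await orchestrator.handle(AgentRequest(
            tenant_id="demo-tenant",
            user_id="demo-user",
            message="What is our refund policy?",
            temperature=0.2,
        ))
        print(resp.text)

    asyncio.run(_demo())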