# IntegraChat: backend/api/services/agent_orchestrator.py
# Last commit 611e2c1 by nothingworry: "Migrate admin rules and analytics to Supabase" (file size 50.5 kB)
# =============================================================
# File: backend/api/services/agent_orchestrator.py
# =============================================================
"""
Agent Orchestrator (integrated with enterprise RedFlagDetector)
Place at: backend/api/services/agent_orchestrator.py
"""
from __future__ import annotations

import asyncio
import json
import os
import random
import re
import time
from typing import Any, Callable, Dict, List, Optional, Tuple

from ..models.agent import AgentRequest, AgentDecision, AgentResponse
from ..models.redflag import RedFlagMatch
from .redflag_detector import RedFlagDetector
from .intent_classifier import IntentClassifier
from .tool_selector import ToolSelector
from .llm_client import LLMClient
from ..mcp_clients.mcp_client import MCPClient
from .tool_scoring import ToolScoringService
from ..storage.analytics_store import AnalyticsStore
from .result_merger import merge_parallel_results, format_merged_context_for_prompt
class AgentOrchestrator:
    """Coordinate admin-rule checks, intent detection, tool calls and LLM synthesis.

    ``handle`` is the entry point. Admin (red-flag) rules are evaluated before
    anything else: when only low-severity "brief response" rules match, a
    canned greeting is returned; any other matching rule blocks the request
    with an LLM-written refusal. Only when no rule matches does the normal flow
    run: intent classification, a RAG pre-fetch that feeds tool scoring, tool
    selection, then single-tool, multi-step or direct LLM answering. Every
    path records analytics (best-effort) and returns an ``AgentResponse`` with
    a reasoning trace.
    """

    # Regular expressions searched in the lower-cased description/pattern of a
    # low-severity rule to recognise "keep the reply brief" rules. They must be
    # regex-searched: entries such as "do not.*verbose" contain wildcards that
    # a plain substring test can never match.
    _BRIEF_RULE_PATTERNS = tuple(
        re.compile(pattern)
        for pattern in (
            r"greeting",
            r"brief",
            r"simple response",
            r"keep.*response.*brief",
            r"do not.*verbose",
            r"respond.*briefly",
        )
    )

    # Canned replies used when only brief-response rules matched.
    _BRIEF_RESPONSES = (
        "Hello! How can I help you today?",
        "Hi there! What can I assist you with?",
        "Hello! I'm here to help. What do you need?",
        "Hi! How can I assist you?",
    )

    # Tools the single-tool path knows how to run.
    _SINGLE_TOOLS = ("rag", "web", "admin", "llm")

    def __init__(self, rag_mcp_url: str, web_mcp_url: str, admin_mcp_url: str, llm_backend: str = "ollama"):
        """Create the MCP/LLM clients, the red-flag detector and helper services.

        Args:
            rag_mcp_url: RAG MCP server URL (passed to MCPClient).
            web_mcp_url: web-search MCP server URL (passed to MCPClient).
            admin_mcp_url: admin MCP server URL; also given to RedFlagDetector.
            llm_backend: backend name passed to LLMClient.

        LLM and Supabase settings come from the OLLAMA_URL, GROQ_API_KEY,
        OLLAMA_MODEL, SUPABASE_URL and SUPABASE_SERVICE_KEY environment variables.
        """
        self.mcp = MCPClient(rag_mcp_url, web_mcp_url, admin_mcp_url)
        self.llm = LLMClient(
            backend=llm_backend,
            url=os.getenv("OLLAMA_URL"),
            api_key=os.getenv("GROQ_API_KEY"),
            model=os.getenv("OLLAMA_MODEL"),
        )
        # pass admin_mcp_url so detector can call back
        self.redflag = RedFlagDetector(
            supabase_url=os.getenv("SUPABASE_URL"),
            supabase_key=os.getenv("SUPABASE_SERVICE_KEY"),
            admin_mcp_url=admin_mcp_url
        )
        self.intent = IntentClassifier(llm_client=self.llm)
        self.selector = ToolSelector(llm_client=self.llm)
        self.tool_scorer = ToolScoringService()
        self.analytics = AnalyticsStore()
        # Announce the analytics backend once per process (flag lives on the class).
        if not hasattr(AgentOrchestrator, '_analytics_backend_logged'):
            if self.analytics.use_supabase:
                print("✅ AgentOrchestrator Analytics: Using Supabase backend")
            else:
                print("⚠️ AgentOrchestrator Analytics: Using SQLite backend")
            AgentOrchestrator._analytics_backend_logged = True

    # ------------------------------------------------------------------
    # Entry point
    # ------------------------------------------------------------------
    async def handle(self, req: AgentRequest) -> AgentResponse:
        """Answer one agent request.

        Args:
            req: the incoming request (reads tenant_id, user_id, message, temperature).

        Returns:
            An AgentResponse whose ``reasoning_trace`` lists the pipeline steps taken.
        """
        start_time = time.perf_counter()
        reasoning_trace: List[Dict[str, Any]] = [{
            "step": "request_received",
            "tenant_id": req.tenant_id,
            "user_id": req.user_id,
            "message_preview": req.message[:120]
        }]

        # 1) FIRST: check admin rules - if any rule matches, respond according to the rule.
        matches: List[RedFlagMatch] = await self.redflag.check(req.tenant_id, req.message)
        reasoning_trace.append({
            "step": "admin_rules_check",
            "match_count": len(matches),
            "matches": [m.__dict__ for m in matches]
        })
        if matches:
            for match in matches:
                self._best_effort(
                    self.analytics.log_redflag_violation,
                    tenant_id=req.tenant_id,
                    rule_id=match.rule_id,
                    rule_pattern=match.pattern,
                    severity=match.severity,
                    matched_text=match.matched_text,
                    confidence=match.confidence,
                    message_preview=req.message[:200],
                    user_id=req.user_id
                )
            brief_rules, blocking_rules = self._categorize_matches(matches)
            # A blocking rule always wins over a brief-response rule.
            if blocking_rules:
                return await self._respond_blocked(req, blocking_rules, reasoning_trace, start_time)
            return self._respond_brief(req, brief_rules, reasoning_trace, start_time)

        # 2) ONLY IF NO RULES MATCHED: normal flow (intent classification, RAG, etc.)
        intent = await self.intent.classify(req.message)
        reasoning_trace.append({"step": "intent_detection", "intent": intent})

        # 2.5) Pre-fetch RAG results so tool scoring/selection can see them.
        rag_prefetch, rag_results = await self._prefetch_rag(req, reasoning_trace)

        tool_scores = self.tool_scorer.score(req.message, intent, rag_results)
        reasoning_trace.append({"step": "tool_scoring", "scores": tool_scores})

        # 3) Tool selection (hybrid) - pass RAG results in context
        ctx = {
            "tenant_id": req.tenant_id,
            "rag_results": rag_results,
            "tool_scores": tool_scores
        }
        decision = await self.selector.select(intent, req.message, ctx)
        reasoning_trace.append({
            "step": "tool_selection",
            "decision": decision.dict(),
            "context_scores": tool_scores
        })
        tool_traces: List[Dict[str, Any]] = []

        # 4) Multi-step tool execution
        if decision.action == "multi_step" and decision.tool_input:
            steps = decision.tool_input.get("steps", [])
            if steps:
                return await self._execute_multi_step(
                    req,
                    steps,
                    decision,
                    tool_traces,
                    reasoning_trace,
                    rag_prefetch,
                    start_time=start_time,
                )

        # 5) Single tool; None means the selector named a tool we do not run.
        if decision.action == "call_tool" and decision.tool:
            response = await self._execute_single_tool(
                req, decision, intent, tool_traces, reasoning_trace, start_time
            )
            if response is not None:
                return response

        # Default: direct LLM response
        return await self._respond_direct(req, intent, reasoning_trace, start_time)

    # ------------------------------------------------------------------
    # Admin-rule responses
    # ------------------------------------------------------------------
    @classmethod
    def _categorize_matches(cls, matches: List[RedFlagMatch]) -> Tuple[List[RedFlagMatch], List[RedFlagMatch]]:
        """Split rule matches into (brief-response rules, blocking rules).

        A match is a brief-response rule when its severity is "low" and its
        lower-cased description (or pattern) matches one of
        ``_BRIEF_RULE_PATTERNS``; every other match is blocking.
        """
        brief: List[RedFlagMatch] = []
        blocking: List[RedFlagMatch] = []
        for match in matches:
            rule_text = (match.description or match.pattern or "").lower()
            is_brief = match.severity == "low" and any(
                pattern.search(rule_text) for pattern in cls._BRIEF_RULE_PATTERNS
            )
            (brief if is_brief else blocking).append(match)
        return brief, blocking

    def _respond_brief(self, req: AgentRequest, brief_rules: List[RedFlagMatch],
                       reasoning_trace: List[Dict[str, Any]], start_time: float) -> AgentResponse:
        """Return a random canned greeting, skipping the normal flow entirely."""
        brief_response = random.choice(self._BRIEF_RESPONSES)
        reasoning_trace.append({
            "step": "brief_response_rule_matched",
            "action": "brief_response",
            "matched_rules": [m.rule_id for m in brief_rules],
            "message": "Brief response rule matched, returning brief response (skipping normal flow)"
        })
        self._log_query(req, "greeting", [], len(brief_response) // 4, start_time, success=True)
        return AgentResponse(
            text=brief_response,
            decision=AgentDecision(action="respond", tool=None, tool_input=None, reason="brief_response_rule"),
            reasoning_trace=reasoning_trace
        )

    async def _respond_blocked(self, req: AgentRequest, blocking_rules: List[RedFlagMatch],
                               reasoning_trace: List[Dict[str, Any]], start_time: float) -> AgentResponse:
        """Notify admins, then return an LLM-written refusal naming the violated rules.

        If the LLM call fails, a fixed refusal listing the blocking rules is used instead.
        """
        # Best-effort: a notification failure must not prevent the refusal,
        # but it is recorded rather than silently dropped.
        try:
            await self.redflag.notify_admin(
                req.tenant_id, blocking_rules,
                source_payload={"message": req.message, "user_id": req.user_id}
            )
        except Exception as notify_err:
            reasoning_trace.append({
                "step": "admin_notification",
                "status": "error",
                "error": str(notify_err)
            })

        decision = AgentDecision(
            action="block",
            tool="admin",
            tool_input={"violations": [m.__dict__ for m in blocking_rules]},
            reason="admin_rule_violation"
        )

        violations_details = []
        for i, m in enumerate(blocking_rules, 1):
            rule_name = m.description or m.pattern or "Policy violation"
            detail = f"{i}. **{rule_name}** (Severity: {m.severity})"
            if m.matched_text:
                detail += f"\n - Detected phrase: \"{m.matched_text}\""
            violations_details.append(detail)
        llm_prompt = "\n".join([
            f'A user made the following request: "{req.message}"',
            "However, this request violates company policies. The following policy violations were detected:",
            "\n".join(violations_details),
            "Your task: Write a clear, professional, and empathetic response to inform the user that:",
            "1. Their request cannot be processed due to policy violations",
            "2. Which specific policy was violated (mention it naturally)",
            "3. The incident has been logged for security review",
            "4. They should contact an administrator if they need assistance or believe this is an error",
            "Write a natural, conversational response (2-4 sentences) that feels helpful rather than robotic. Be professional but understanding.",
            "Response:",
        ])

        llm_error: Optional[Exception] = None
        llm_start = time.perf_counter()
        try:
            llm_response = (await self.llm.simple_call(
                llm_prompt,
                # Slightly warmer than the request for a more natural tone, capped at 0.7.
                temperature=min(req.temperature + 0.2, 0.7),
            )).strip()
            if not llm_response.startswith(("⚠️", "🚨")):
                llm_response = f"⚠️ {llm_response}"
        except Exception as err:
            llm_error = err
            # Only the blocking rules caused the refusal, so only they are listed.
            summary = "; ".join(f"{m.description or m.pattern}" for m in blocking_rules)
            llm_response = f"⚠️ I'm unable to process your request because it violates our company policy: {summary}. This incident has been logged. Please contact your administrator if you need assistance."
        llm_latency_ms = self._elapsed_ms(llm_start)

        # Rough token estimate (~4 chars/token); a failed call consumed none.
        estimated_tokens = 0 if llm_error is not None else len(llm_response) // 4 + len(llm_prompt) // 4
        self._log_tool(
            req, "llm",
            latency_ms=llm_latency_ms,
            tokens_used=estimated_tokens if llm_error is None else None,
            error=llm_error,
        )
        self._log_query(req, "admin", ["admin", "llm"], estimated_tokens, start_time, success=False)
        return AgentResponse(
            text=llm_response,
            decision=decision,
            tool_traces=[{"redflags": [m.__dict__ for m in blocking_rules]}],
            reasoning_trace=reasoning_trace
        )

    # ------------------------------------------------------------------
    # Normal flow
    # ------------------------------------------------------------------
    async def _prefetch_rag(self, req: AgentRequest,
                            reasoning_trace: List[Dict[str, Any]]) -> Tuple[Optional[Dict[str, Any]], List[Any]]:
        """Run the RAG search for the raw message ahead of tool selection.

        Returns ``(raw_response, hits)``. Failures are non-fatal: they are
        logged and traced, and ``(None, [])`` is returned.
        """
        rag_start = time.perf_counter()
        try:
            rag_prefetch = await self.mcp.call_rag(req.tenant_id, req.message)
            rag_latency_ms = self._elapsed_ms(rag_start)
            rag_results: List[Any] = []
            if isinstance(rag_prefetch, dict):
                rag_results = rag_prefetch.get("results") or rag_prefetch.get("hits") or []
            self._log_rag_search(req, req.message, rag_results, rag_latency_ms)
            self._log_tool(req, "rag", latency_ms=rag_latency_ms)
            reasoning_trace.append({
                "step": "rag_prefetch",
                "status": "ok",
                "hit_count": len(rag_results),
                "latency_ms": rag_latency_ms
            })
            return rag_prefetch, rag_results
        except Exception as pref_err:
            # If RAG fails, continue without it (record the real time spent).
            self._log_tool(req, "rag", latency_ms=self._elapsed_ms(rag_start), error=pref_err)
            reasoning_trace.append({
                "step": "rag_prefetch",
                "status": "error",
                "error": str(pref_err)
            })
            return None, []

    async def _execute_single_tool(self, req: AgentRequest, decision: AgentDecision, intent: Any,
                                   tool_traces: List[Dict[str, Any]],
                                   reasoning_trace: List[Dict[str, Any]],
                                   start_time: float) -> Optional[AgentResponse]:
        """Run the single tool chosen by the selector and build the response.

        "rag"/"web" results are synthesized by the LLM, "admin" output is
        returned as JSON, and "llm" answers the message directly. Returns None
        when ``decision.tool`` is not one of these, so the caller can fall back.
        On any failure a plain LLM answer (or an explanatory error text) is returned.
        """
        tool = decision.tool
        if tool not in self._SINGLE_TOOLS:
            return None
        query = self._tool_query(decision.tool_input, req.message)
        tools_used: List[str] = []
        stage = tool  # the call currently in flight, for failure analytics
        try:
            if tool == "admin":
                admin_start = time.perf_counter()
                admin_resp = await self.mcp.call_admin(req.tenant_id, query)
                admin_latency_ms = self._elapsed_ms(admin_start)
                tools_used.append("admin")
                self._log_tool(req, "admin", latency_ms=admin_latency_ms)
                tool_traces.append({"tool": "admin", "response": admin_resp})
                reasoning_trace.append({
                    "step": "tool_execution",
                    "tool": "admin",
                    "status": "completed",
                    "latency_ms": admin_latency_ms
                })
                text = json.dumps(admin_resp)
                self._log_query(req, intent, tools_used, 0, start_time, success=True)
                return AgentResponse(text=text, decision=decision, tool_traces=tool_traces, reasoning_trace=reasoning_trace)

            if tool == "llm":
                prompt, mode = req.message, "direct"
            elif tool == "rag":
                tool_start = time.perf_counter()
                rag_resp = await self.mcp.call_rag(req.tenant_id, query)
                tool_latency_ms = self._elapsed_ms(tool_start)
                tools_used.append("rag")
                tool_traces.append({"tool": "rag", "response": rag_resp})
                hits = self._extract_hits(rag_resp)
                self._log_rag_search(req, query, hits, tool_latency_ms)
                self._log_tool(req, "rag", latency_ms=tool_latency_ms)
                reasoning_trace.append({
                    "step": "tool_execution",
                    "tool": "rag",
                    "hit_count": len(hits),
                    "summary": self._summarize_hits(rag_resp, limit=2),
                    "latency_ms": tool_latency_ms
                })
                prompt, mode = self._build_prompt_with_rag(req, rag_resp), "rag_synthesis"
            else:  # tool == "web"
                tool_start = time.perf_counter()
                web_resp = await self.mcp.call_web(req.tenant_id, query)
                tool_latency_ms = self._elapsed_ms(tool_start)
                tools_used.append("web")
                tool_traces.append({"tool": "web", "response": web_resp})
                self._log_tool(req, "web", latency_ms=tool_latency_ms)
                reasoning_trace.append({
                    "step": "tool_execution",
                    "tool": "web",
                    "hit_count": len(self._extract_hits(web_resp)),
                    "summary": self._summarize_hits(web_resp, limit=2),
                    "latency_ms": tool_latency_ms
                })
                prompt, mode = self._build_prompt_with_web(req, web_resp), "web_synthesis"

            stage = "llm"
            llm_start = time.perf_counter()
            llm_out = await self.llm.simple_call(prompt, temperature=req.temperature)
            llm_latency_ms = self._elapsed_ms(llm_start)
            tools_used.append("llm")
            # Estimate tokens (rough: ~4 chars per token)
            estimated_tokens = len(llm_out) // 4 + len(prompt) // 4
            self._log_tool(req, "llm", latency_ms=llm_latency_ms, tokens_used=estimated_tokens)
            reasoning_trace.append({
                "step": "llm_response",
                "mode": mode,
                "latency_ms": llm_latency_ms,
                "estimated_tokens": estimated_tokens
            })
            self._log_query(req, intent, tools_used, estimated_tokens, start_time, success=True)
            if tool == "llm":
                # Direct answers carry no tool traces.
                return AgentResponse(text=llm_out, decision=decision, reasoning_trace=reasoning_trace)
            return AgentResponse(text=llm_out, decision=decision, tool_traces=tool_traces, reasoning_trace=reasoning_trace)
        except Exception as e:
            tool_traces.append({"tool": tool, "error": str(e)})
            self._log_tool(req, stage, error=e)
            try:
                fallback = await self.llm.simple_call(req.message, temperature=req.temperature)
            except Exception as llm_error:
                error_msg = str(llm_error)
                if self._is_ollama_connection_error(error_msg):
                    fallback = (
                        f"I encountered an error while processing your request: {str(e)}\n\n"
                        f"Additionally, the AI service (Ollama) is unavailable: {error_msg}\n\n"
                        f"To fix:\n"
                        f"1. Install Ollama from https://ollama.ai\n"
                        f"2. Start: `ollama serve`\n"
                        f"3. Pull model: `ollama pull {os.getenv('OLLAMA_MODEL', 'llama3.1:latest')}`"
                    )
                else:
                    fallback = f"I encountered an error while processing your request: {str(e)}. Additionally, the AI service is unavailable: {error_msg}"
            self._log_query(req, intent, tools_used, 0, start_time, success=False)
            return AgentResponse(
                text=fallback,
                decision=AgentDecision(action="respond", tool=None, tool_input=None, reason=f"tool_error_fallback: {e}"),
                tool_traces=tool_traces,
                reasoning_trace=reasoning_trace + [{
                    "step": "error",
                    "tool": tool,
                    "error": str(e)
                }]
            )

    async def _respond_direct(self, req: AgentRequest, intent: Any,
                              reasoning_trace: List[Dict[str, Any]], start_time: float) -> AgentResponse:
        """Answer the message with a plain LLM call.

        If the LLM fails, the returned text explains the failure (with Ollama
        setup steps for connection errors) and the query is logged as failed.
        """
        llm_start = time.perf_counter()
        try:
            llm_out = await self.llm.simple_call(req.message, temperature=req.temperature)
        except Exception as e:
            # If LLM fails, return a helpful error message
            error_msg = str(e)
            if self._is_ollama_connection_error(error_msg):
                llm_out = self._ollama_setup_help(error_msg)
            else:
                llm_out = f"I apologize, but I'm unable to process your request right now. The AI service is unavailable: {error_msg}"
            self._log_tool(req, "llm", error=e)
            reasoning_trace.append({
                "step": "error",
                "tool": "llm",
                "error": str(e)
            })
            self._log_query(req, intent, [], 0, start_time, success=False)
        else:
            llm_latency_ms = self._elapsed_ms(llm_start)
            estimated_tokens = len(llm_out) // 4 + len(req.message) // 4
            self._log_tool(req, "llm", latency_ms=llm_latency_ms, tokens_used=estimated_tokens)
            self._log_query(req, intent, ["llm"], estimated_tokens, start_time, success=True)
        return AgentResponse(
            text=llm_out,
            decision=AgentDecision(action="respond", tool=None, tool_input=None, reason="default_llm"),
            reasoning_trace=reasoning_trace
        )

    # ------------------------------------------------------------------
    # Prompt builders
    # ------------------------------------------------------------------
    def _build_prompt_with_rag(self, req: AgentRequest, rag_resp: Dict[str, Any]) -> str:
        """Build the RAG-synthesis prompt from up to five retrieved snippets."""
        snippet_text = "\n---\n".join(self._rag_snippets(rag_resp))
        return (
            f"You are an assistant helping tenant {req.tenant_id}. Use the following retrieved documents to answer the user's question.\n"
            f"Documents:\n{snippet_text}\n\n"
            f"User question: {req.message}\nProvide a concise, accurate answer and cite the source snippets where appropriate."
        )

    def _build_prompt_with_web(self, req: AgentRequest, web_resp: Dict[str, Any]) -> str:
        """Build the web-synthesis prompt from up to five search results."""
        snippet_text = "\n---\n".join(self._web_snippets(web_resp))
        return (
            f"You are an assistant with access to recent web search results.\n\n"
            f"## Search Results\n"
            f"{snippet_text}\n\n"
            f"## User Question\n"
            f"{req.message}\n\n"
            f"## Your Task\n"
            f"Provide a clear, accurate, and succinct answer based on the search results above. "
            f"Indicate which results you used in your reasoning."
        )

    @staticmethod
    def _rag_snippets(resp: Any, limit: int = 5) -> List[str]:
        """Return the text of the first ``limit`` RAG hits (``results`` or ``hits`` key).

        Dict hits yield their "text"/"content" (or their str() form); other hits are str()-ed.
        """
        if not isinstance(resp, dict):
            return []
        hits = resp.get("results") or resp.get("hits") or []
        if not isinstance(hits, (list, tuple)):
            return []
        snippets = []
        for hit in hits[:limit]:
            if isinstance(hit, dict):
                snippets.append(hit.get("text") or hit.get("content") or str(hit))
            else:
                snippets.append(str(hit))
        return snippets

    @staticmethod
    def _web_snippets(resp: Any, limit: int = 5) -> List[str]:
        """Format the first ``limit`` web results (``results`` or ``items`` key) as
        "title\\nsnippet\\nSource: url"; non-dict results are str()-ed."""
        if not isinstance(resp, dict):
            return []
        hits = resp.get("results") or resp.get("items") or []
        if not isinstance(hits, (list, tuple)):
            return []
        snippets = []
        for hit in hits[:limit]:
            if not isinstance(hit, dict):
                snippets.append(str(hit))
                continue
            title = hit.get("title") or hit.get("headline") or ""
            snippet = hit.get("snippet") or hit.get("summary") or hit.get("text") or ""
            url = hit.get("url") or hit.get("link") or ""
            snippets.append(f"{title}\n{snippet}\nSource: {url}")
        return snippets

    # ------------------------------------------------------------------
    # Multi-step execution
    # ------------------------------------------------------------------
    async def _execute_multi_step(self, req: AgentRequest, steps: List[Dict[str, Any]],
                                  decision: AgentDecision, tool_traces: List[Dict[str, Any]],
                                  reasoning_trace: List[Dict[str, Any]],
                                  pre_fetched_rag: Optional[Dict[str, Any]] = None,
                                  start_time: Optional[float] = None) -> AgentResponse:
        """
        Execute multiple tools in sequence or parallel and synthesize results with LLM.

        Parallel mode is used when a step carries a ``"parallel"`` dict mapping
        "rag"/"web" to queries; otherwise steps run in order up to an "llm" step.

        Args:
            pre_fetched_rag: RAG response already fetched for ``req.message``;
                reused only for that exact query.
            start_time: ``time.perf_counter()`` value at which the request
                started, so the logged latency covers the whole request;
                defaults to now.
        """
        if start_time is None:
            start_time = time.perf_counter()
        tools_used: List[str] = []

        # The first step with a truthy "parallel" entry selects parallel mode,
        # which is only usable when it is a dict naming "rag" and/or "web".
        parallel_config = next((s.get("parallel") for s in steps if s.get("parallel")), None)
        run_parallel = isinstance(parallel_config, dict) and (
            "rag" in parallel_config or "web" in parallel_config
        )
        if run_parallel:
            data_section = await self._run_parallel_steps(
                req, parallel_config, tool_traces, reasoning_trace, tools_used, pre_fetched_rag
            )
        else:
            data_section = await self._run_sequential_steps(
                req, steps, tool_traces, reasoning_trace, tools_used, pre_fetched_rag
            )

        if data_section:
            prompt = (
                f"You are an assistant helping tenant {req.tenant_id}.\n\n"
                f"## Information Collected\n"
                f"The following details have been gathered from multiple reliable sources:\n"
                f"{data_section}\n\n"
                f"## User Request\n"
                f"{req.message}\n\n"
                f"## Your Task\n"
                f"Use the information above to directly address the user's request. "
                f"Focus on giving the user exactly what they need—clear guidance, accurate facts, "
                f"and practical steps whenever possible. If the information is incomplete, explain "
                f"what can and cannot be concluded from the available data."
            )
        else:
            # No data collected, just answer the question
            prompt = req.message

        # Final LLM synthesis
        llm_start = time.perf_counter()
        try:
            llm_out = await self.llm.simple_call(prompt, temperature=req.temperature)
        except Exception as e:
            tool_traces.append({"tool": "llm", "error": str(e)})
            error_msg = str(e)
            if self._is_ollama_connection_error(error_msg):
                fallback = self._ollama_setup_help(error_msg)
            else:
                fallback = f"I encountered an error while synthesizing the response: {error_msg}"
            self._log_tool(req, "llm", error=e)
            self._log_query(req, "multi_step", tools_used, 0, start_time, success=False)
            return AgentResponse(
                text=fallback,
                decision=AgentDecision(
                    action="respond",
                    tool=None,
                    tool_input=None,
                    reason=f"multi_step_llm_error: {e}"
                ),
                tool_traces=tool_traces,
                reasoning_trace=reasoning_trace + [{
                    "step": "error",
                    "tool": "llm",
                    "error": str(e)
                }]
            )
        llm_latency_ms = self._elapsed_ms(llm_start)
        tools_used.append("llm")
        estimated_tokens = len(llm_out) // 4 + len(prompt) // 4
        self._log_tool(req, "llm", latency_ms=llm_latency_ms, tokens_used=estimated_tokens)
        self._log_query(req, "multi_step", tools_used, estimated_tokens, start_time, success=True)
        return AgentResponse(
            text=llm_out,
            decision=decision,
            tool_traces=tool_traces,
            reasoning_trace=reasoning_trace + [{
                "step": "llm_response",
                "mode": "multi_step_parallel" if run_parallel else "multi_step",
                "latency_ms": llm_latency_ms,
                "estimated_tokens": estimated_tokens
            }]
        )

    async def _run_parallel_steps(self, req: AgentRequest, parallel_config: Dict[str, Any],
                                  tool_traces: List[Dict[str, Any]],
                                  reasoning_trace: List[Dict[str, Any]],
                                  tools_used: List[str],
                                  pre_fetched_rag: Optional[Dict[str, Any]]) -> str:
        """Run the RAG and/or web searches named in ``parallel_config`` concurrently.

        Appends to ``tool_traces``, ``reasoning_trace`` and ``tools_used`` in
        place and returns the merged results formatted for the synthesis prompt.
        """
        queries: Dict[str, str] = {}
        parallel_tasks: Dict[str, Any] = {}
        reused_prefetch = False
        if "rag" in parallel_config:
            queries["rag"] = self._query_or_default(parallel_config["rag"], req.message)
            if pre_fetched_rag and queries["rag"] == req.message:
                # The pre-fetch searched exactly this query: reuse it, no new call.
                parallel_tasks["rag"] = self._resolved(pre_fetched_rag)
                reused_prefetch = True
            else:
                parallel_tasks["rag"] = self.mcp.call_rag(req.tenant_id, queries["rag"])
        if "web" in parallel_config:
            queries["web"] = self._query_or_default(parallel_config["web"], req.message)
            parallel_tasks["web"] = self.mcp.call_web(req.tenant_id, queries["web"])

        reasoning_trace.append({
            "step": "parallel_execution",
            "tools": list(parallel_tasks.keys()),
            "mode": "parallel"
        })
        parallel_start = time.perf_counter()
        parallel_results = await self.run_parallel_tools(parallel_tasks)
        parallel_latency_ms = self._elapsed_ms(parallel_start)

        for tool_name in ("rag", "web"):
            if tool_name not in parallel_results:
                continue
            result = parallel_results[tool_name]
            if isinstance(result, BaseException):
                tool_traces.append({"tool": tool_name, "error": str(result), "note": "parallel"})
                reasoning_trace.append({
                    "step": "tool_execution",
                    "tool": tool_name,
                    "status": "error",
                    "error": str(result),
                    "latency_ms": parallel_latency_ms
                })
                self._log_tool(req, tool_name, latency_ms=parallel_latency_ms, error=result)
                continue
            tools_used.append(tool_name)
            tool_traces.append({"tool": tool_name, "response": result, "note": "parallel"})
            hits = self._extract_hits(result)
            # A reused pre-fetch was already logged when it was fetched.
            if not (tool_name == "rag" and reused_prefetch):
                if tool_name == "rag":
                    self._log_rag_search(req, queries["rag"], hits, parallel_latency_ms)
                self._log_tool(req, tool_name, latency_ms=parallel_latency_ms)
            reasoning_trace.append({
                "step": "tool_execution",
                "tool": tool_name,
                "hit_count": len(hits),
                "summary": self._summarize_hits(result, limit=2),
                "latency_ms": parallel_latency_ms,
                "mode": "parallel"
            })

        merged_context = merge_parallel_results(parallel_results)
        # dict.fromkeys de-duplicates while keeping first-seen order (set() does not).
        sources_list = list(dict.fromkeys(
            e.get("source") for e in merged_context if e.get("source")
        )) if merged_context else []
        reasoning_trace.append({
            "step": "result_merger",
            "merged_items": len(merged_context),
            "sources": sources_list
        })
        return format_merged_context_for_prompt(merged_context, max_items=10)

    async def _run_sequential_steps(self, req: AgentRequest, steps: List[Dict[str, Any]],
                                    tool_traces: List[Dict[str, Any]],
                                    reasoning_trace: List[Dict[str, Any]],
                                    tools_used: List[str],
                                    pre_fetched_rag: Optional[Dict[str, Any]]) -> str:
        """Execute steps in order until an "llm" step, collecting prompt snippets.

        When the first RAG and web steps share a query, both searches are
        started up front so they overlap. A failing step is traced and skipped.
        Appends to the trace lists and ``tools_used`` in place; returns the
        snippets joined by "\\n---\\n" (empty string when nothing was collected).
        """
        rag_query = self._first_query_for_tool(steps, "rag", req.message)
        web_query = self._first_query_for_tool(steps, "web", req.message)
        pending: Dict[str, asyncio.Task] = {}
        if rag_query and web_query and rag_query == web_query:
            if not (pre_fetched_rag and rag_query == req.message):
                pending["rag"] = asyncio.create_task(self.mcp.call_rag(req.tenant_id, rag_query))
            pending["web"] = asyncio.create_task(self.mcp.call_web(req.tenant_id, web_query))

        collected: List[str] = []
        try:
            for step_info in steps:
                tool_name = step_info.get("tool")
                if tool_name == "llm":
                    # LLM is always last - synthesis happens after this loop.
                    break
                query = self._tool_query(step_info.get("input"), req.message)
                try:
                    if tool_name == "rag":
                        step_start = time.perf_counter()
                        if pre_fetched_rag and query == req.message:
                            # Already fetched (and logged) for exactly this query.
                            rag_resp = pre_fetched_rag
                            tool_traces.append({"tool": "rag", "response": rag_resp, "note": "used_pre_fetched"})
                        else:
                            if "rag" in pending and query == rag_query:
                                rag_resp = await pending["rag"]
                                tool_traces.append({"tool": "rag", "response": rag_resp, "note": "parallel"})
                            else:
                                rag_resp = await self.mcp.call_rag(req.tenant_id, query)
                                tool_traces.append({"tool": "rag", "response": rag_resp})
                            step_latency_ms = self._elapsed_ms(step_start)
                            self._log_rag_search(req, query, self._extract_hits(rag_resp), step_latency_ms)
                            self._log_tool(req, "rag", latency_ms=step_latency_ms)
                        tools_used.append("rag")
                        reasoning_trace.append({
                            "step": "tool_execution",
                            "tool": "rag",
                            "hit_count": len(self._extract_hits(rag_resp)),
                            "summary": self._summarize_hits(rag_resp, limit=2)
                        })
                        collected.extend(f"[RAG] {txt}" for txt in self._rag_snippets(rag_resp))
                    elif tool_name == "web":
                        step_start = time.perf_counter()
                        if "web" in pending and query == web_query:
                            web_resp = await pending["web"]
                            tool_traces.append({"tool": "web", "response": web_resp, "note": "parallel"})
                        else:
                            web_resp = await self.mcp.call_web(req.tenant_id, query)
                            tool_traces.append({"tool": "web", "response": web_resp})
                        self._log_tool(req, "web", latency_ms=self._elapsed_ms(step_start))
                        tools_used.append("web")
                        reasoning_trace.append({
                            "step": "tool_execution",
                            "tool": "web",
                            "hit_count": len(self._extract_hits(web_resp)),
                            "summary": self._summarize_hits(web_resp, limit=2)
                        })
                        collected.extend(f"[WEB] {item}" for item in self._web_snippets(web_resp))
                    elif tool_name == "admin":
                        step_start = time.perf_counter()
                        admin_resp = await self.mcp.call_admin(req.tenant_id, query)
                        tool_traces.append({"tool": "admin", "response": admin_resp})
                        self._log_tool(req, "admin", latency_ms=self._elapsed_ms(step_start))
                        tools_used.append("admin")
                        collected.append(f"[ADMIN] {json.dumps(admin_resp)}")
                        reasoning_trace.append({
                            "step": "tool_execution",
                            "tool": "admin",
                            "status": "completed"
                        })
                except Exception as e:
                    # Continue with other tools even if one fails
                    tool_traces.append({"tool": tool_name, "error": str(e)})
                    reasoning_trace.append({
                        "step": "error",
                        "tool": tool_name,
                        "error": str(e)
                    })
                    self._log_tool(req, str(tool_name), error=e)
        finally:
            # Never leave a pre-started search running or its failure unobserved,
            # e.g. when an "llm" step ends the loop before the task was awaited.
            for task in pending.values():
                if not task.done():
                    task.cancel()
                elif not task.cancelled():
                    task.exception()
        return "\n---\n".join(collected)

    # ------------------------------------------------------------------
    # Small helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _extract_hits(resp: Optional[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Return the hit list from a tool response ("results", "hits" or "items"), else []."""
        if not isinstance(resp, dict):
            return []
        return resp.get("results") or resp.get("hits") or resp.get("items") or []

    def _summarize_hits(self, resp: Optional[Dict[str, Any]], limit: int = 3) -> List[str]:
        """Return up to ``limit`` hit snippets, each truncated to 160 characters."""
        hits = self._extract_hits(resp)
        summaries = []
        for hit in hits[:limit]:
            if isinstance(hit, dict):
                snippet = hit.get("text") or hit.get("content") or hit.get("snippet") or ""
            else:
                snippet = str(hit)
            summaries.append(snippet[:160])
        return summaries

    async def run_parallel_tools(self, tasks: Dict[str, Any]) -> Dict[str, Any]:
        """
        Run multiple tools in parallel using asyncio.gather.

        Args:
            tasks: Dictionary mapping tool names to coroutines, e.g.:
                {"rag": rag_coro, "web": web_coro}
        Returns:
            Dictionary mapping tool names to results, e.g.:
                {"rag": rag_result, "web": web_result}
            Exceptions are returned as values if a tool fails.
        """
        if not tasks:
            return {}
        names = list(tasks)
        # return_exceptions=True: one failing tool must not cancel the others.
        results = await asyncio.gather(*(tasks[name] for name in names), return_exceptions=True)
        return dict(zip(names, results))

    @staticmethod
    def _first_query_for_tool(steps: List[Dict[str, Any]], tool_name: str, default_query: str) -> Optional[str]:
        """Return the query of the first step using ``tool_name`` (``default_query``
        when that step has none), or None when no step uses the tool."""
        for step in steps:
            if step.get("tool") == tool_name:
                return AgentOrchestrator._tool_query(step.get("input"), default_query)
        return None

    @staticmethod
    def _tool_query(tool_input: Any, default_query: str) -> str:
        """Return ``tool_input["query"]`` when it is a dict with a non-empty query, else ``default_query``."""
        if isinstance(tool_input, dict):
            return tool_input.get("query") or default_query
        return default_query

    @staticmethod
    def _query_or_default(value: Any, default_query: str) -> str:
        """Return ``value`` when it is a non-empty string, else ``default_query``."""
        return value if isinstance(value, str) and value else default_query

    @staticmethod
    async def _resolved(value: Any) -> Any:
        """Wrap an already-available value in a coroutine so it can join asyncio.gather."""
        return value

    @staticmethod
    def _elapsed_ms(start: float) -> int:
        """Milliseconds since ``start`` (a ``time.perf_counter()`` value)."""
        return int((time.perf_counter() - start) * 1000)

    @staticmethod
    def _is_ollama_connection_error(error_msg: str) -> bool:
        """Heuristic: does the LLM error text indicate Ollama is unreachable?"""
        return "Cannot connect" in error_msg or "Ollama" in error_msg

    @staticmethod
    def _ollama_setup_help(error_msg: str) -> str:
        """User-facing text explaining how to get a local Ollama server running."""
        return (
            f"I couldn't connect to the AI service (Ollama). "
            f"Error: {error_msg}\n\n"
            f"To fix this:\n"
            f"1. Install Ollama from https://ollama.ai\n"
            f"2. Start Ollama: `ollama serve`\n"
            f"3. Pull the model: `ollama pull {os.getenv('OLLAMA_MODEL', 'llama3.1:latest')}`\n"
            f"4. Or set OLLAMA_URL and OLLAMA_MODEL in your .env file"
        )

    # ------------------------------------------------------------------
    # Analytics (best-effort)
    # ------------------------------------------------------------------
    @staticmethod
    def _best_effort(log_call: Callable[..., Any], **kwargs: Any) -> None:
        """Invoke an analytics logging call without letting it affect the reply.

        A failing analytics store is reported on stdout and otherwise ignored,
        so it can neither crash a request nor replace a good answer with an error.
        """
        try:
            log_call(**kwargs)
        except Exception as err:
            print(f"⚠️ AgentOrchestrator analytics call failed: {err}")

    def _log_tool(self, req: AgentRequest, tool_name: str, *,
                  latency_ms: Optional[int] = None,
                  tokens_used: Optional[int] = None,
                  error: Optional[BaseException] = None) -> None:
        """Record one tool invocation; passing ``error`` marks it as failed.

        ``latency_ms``/``tokens_used`` are forwarded only when supplied, and the
        error text is truncated to 200 characters.
        """
        kwargs: Dict[str, Any] = {"tenant_id": req.tenant_id, "tool_name": tool_name}
        if latency_ms is not None:
            kwargs["latency_ms"] = latency_ms
        if tokens_used is not None:
            kwargs["tokens_used"] = tokens_used
        kwargs["success"] = error is None
        if error is not None:
            kwargs["error_message"] = str(error)[:200]
        kwargs["user_id"] = req.user_id
        self._best_effort(self.analytics.log_tool_usage, **kwargs)

    def _log_rag_search(self, req: AgentRequest, query: Any, hits: Any, latency_ms: int) -> None:
        """Record a RAG search with its hit count and average/top numeric ``score``."""
        scores = [
            h["score"] for h in hits
            if isinstance(h, dict) and isinstance(h.get("score"), (int, float))
        ]
        self._best_effort(
            self.analytics.log_rag_search,
            tenant_id=req.tenant_id,
            query=str(query)[:500],
            hits_count=len(hits),
            avg_score=sum(scores) / len(scores) if scores else None,
            top_score=max(scores) if scores else None,
            latency_ms=latency_ms
        )

    def _log_query(self, req: AgentRequest, intent: Any, tools_used: List[str],
                   total_tokens: int, start_time: float, success: bool) -> None:
        """Record the overall outcome of a request, timed from ``start_time``."""
        self._best_effort(
            self.analytics.log_agent_query,
            tenant_id=req.tenant_id,
            message_preview=req.message[:200],
            intent=intent,
            tools_used=tools_used,
            total_tokens=total_tokens,
            total_latency_ms=self._elapsed_ms(start_time),
            success=success,
            user_id=req.user_id
        )