IntegraChat / backend /api /services /agent_orchestrator.py
nothingworry's picture
feat: Add AI metadata extraction, latency prediction, context-aware routing, and tool output schemas
d1e5882
raw
history blame
79.1 kB
# =============================================================
# File: backend/api/services/agent_orchestrator.py
# =============================================================
"""
Agent Orchestrator (integrated with enterprise RedFlagDetector)
Place at: backend/api/services/agent_orchestrator.py
"""
from __future__ import annotations
import asyncio
import json
import os
from typing import List, Dict, Any, Optional
import logging
from ..models.agent import AgentRequest, AgentDecision, AgentResponse
from ..models.redflag import RedFlagMatch
from .redflag_detector import RedFlagDetector
from .intent_classifier import IntentClassifier
from .tool_selector import ToolSelector
from .llm_client import LLMClient
from ..mcp_clients.mcp_client import MCPClient
from .tool_scoring import ToolScoringService
from ..storage.analytics_store import AnalyticsStore
from .result_merger import merge_parallel_results, format_merged_context_for_prompt
from .tool_metadata import validate_tool_output, get_tool_schema
import time
logger = logging.getLogger(__name__)
from dotenv import load_dotenv
load_dotenv()
class AgentOrchestrator:
def __init__(self, rag_mcp_url: str, web_mcp_url: str, admin_mcp_url: str, llm_backend: str = "ollama"):
    """
    Wire up the clients and services the orchestrator depends on.

    Args:
        rag_mcp_url: URL handed to the MCP client for the RAG server.
        web_mcp_url: URL handed to the MCP client for the web-search server.
        admin_mcp_url: URL handed to the MCP client for the admin server; it is
            also given to the red-flag detector so it can call back.
        llm_backend: Backend name forwarded to the LLM client.

    Environment variables read: OLLAMA_URL, GROQ_API_KEY, OLLAMA_MODEL,
    SUPABASE_URL, SUPABASE_SERVICE_KEY and ANALYTICS_DISABLED.
    """
    self.mcp = MCPClient(rag_mcp_url, web_mcp_url, admin_mcp_url)

    llm_settings = {
        "backend": llm_backend,
        "url": os.getenv("OLLAMA_URL"),
        "api_key": os.getenv("GROQ_API_KEY"),
        "model": os.getenv("OLLAMA_MODEL"),
    }
    self.llm = LLMClient(**llm_settings)

    # The detector receives admin_mcp_url so it can call back to the admin server.
    detector_settings = {
        "supabase_url": os.getenv("SUPABASE_URL"),
        "supabase_key": os.getenv("SUPABASE_SERVICE_KEY"),
        "admin_mcp_url": admin_mcp_url,
    }
    self.redflag = RedFlagDetector(**detector_settings)

    # Intent classification and tool selection share the same LLM client.
    self.intent = IntentClassifier(llm_client=self.llm)
    self.selector = ToolSelector(llm_client=self.llm)
    self.tool_scorer = ToolScoringService()

    # Analytics store is created lazily by _get_analytics().
    self._analytics: Optional[AnalyticsStore] = None
    disabled_flag = os.getenv("ANALYTICS_DISABLED", "").lower()
    self._analytics_disabled = disabled_flag in {"1", "true", "yes"}
    self._analytics_failed = False
    self._log_analytics_backend_once()
def _log_analytics_backend_once(self) -> None:
if getattr(AgentOrchestrator, "_analytics_backend_logged", False):
return
if self._analytics_disabled:
print("⚠️ AgentOrchestrator Analytics: Disabled via ANALYTICS_DISABLED")
else:
store = self._get_analytics()
if store is None:
print("⚠️ AgentOrchestrator Analytics: Disabled (Supabase not configured)")
elif store.use_supabase:
print("✅ AgentOrchestrator Analytics: Using Supabase backend")
else:
print("⚠️ AgentOrchestrator Analytics: Using fallback backend")
AgentOrchestrator._analytics_backend_logged = True
def _get_analytics(self) -> Optional[AnalyticsStore]:
if self._analytics_disabled or self._analytics_failed:
return None
if self._analytics is not None:
return self._analytics
try:
self._analytics = AnalyticsStore()
except RuntimeError as exc:
logger.warning("AgentOrchestrator analytics disabled: %s", exc)
self._analytics_failed = True
self._analytics = None
except Exception as exc: # pragma: no cover - unexpected initialization failures
logger.debug("AgentOrchestrator analytics unexpected init failure: %s", exc)
self._analytics_failed = True
self._analytics = None
return self._analytics
def _analytics_log_tool_usage(self, **kwargs: Any) -> None:
analytics = self._get_analytics()
if not analytics:
return
try:
analytics.log_tool_usage(**kwargs)
except Exception as exc: # pragma: no cover - analytics failures should not break flow
logger.debug("AgentOrchestrator tool analytics failed: %s", exc)
def _analytics_log_agent_query(self, **kwargs: Any) -> None:
analytics = self._get_analytics()
if not analytics:
return
try:
analytics.log_agent_query(**kwargs)
except Exception as exc: # pragma: no cover
logger.debug("AgentOrchestrator agent query analytics failed: %s", exc)
def _analytics_log_rag_search(self, **kwargs: Any) -> None:
analytics = self._get_analytics()
if not analytics:
return
try:
analytics.log_rag_search(**kwargs)
except Exception as exc: # pragma: no cover
logger.debug("AgentOrchestrator RAG analytics failed: %s", exc)
def _analytics_log_redflag_violation(self, **kwargs: Any) -> None:
analytics = self._get_analytics()
if not analytics:
return
try:
analytics.log_redflag_violation(**kwargs)
except Exception as exc: # pragma: no cover
logger.debug("AgentOrchestrator redflag analytics failed: %s", exc)
async def handle(self, req: AgentRequest) -> AgentResponse:
    """
    Route one agent request through rule checks, tool selection and the LLM.

    Flow:
        1. Check the admin (red-flag) rules. Any match ends the request here:
           if every matched rule is a low-severity "brief response" rule a
           canned greeting is returned, otherwise the request is blocked.
        2. Classify intent, pre-fetch RAG results and score the tools.
        3. Ask the tool selector for a decision.
        4. ``multi_step`` decisions with steps go to ``_execute_multi_step``.
        5. ``call_tool`` decisions run a single tool (rag / web / admin / llm).
        6. Everything else falls back to a direct LLM answer.

    Args:
        req: Incoming request. ``tenant_id``, ``user_id``, ``message``,
            ``temperature`` and ``conversation_history`` are read.

    Returns:
        The AgentResponse for the request. A failure of the selected tool or
        of the LLM is converted into a fallback response instead of raising;
        errors from rule checking, intent classification or tool selection
        propagate to the caller.
    """
    start_time = time.time()
    reasoning_trace: List[Dict[str, Any]] = [{
        "step": "request_received",
        "tenant_id": req.tenant_id,
        "user_id": req.user_id,
        "message_preview": req.message[:120]
    }]

    # 1) FIRST: check admin rules - any match is answered according to the rule.
    matches: List[RedFlagMatch] = await self.redflag.check(req.tenant_id, req.message)
    reasoning_trace.append({
        "step": "admin_rules_check",
        "match_count": len(matches),
        "matches": [m.__dict__ for m in matches]
    })
    if matches:
        return await self._respond_to_rule_matches(req, matches, reasoning_trace, start_time)

    # 2) No rule matched: normal flow (intent classification, RAG, tools).
    intent = await self.intent.classify(req.message)
    reasoning_trace.append({
        "step": "intent_detection",
        "intent": intent
    })

    # 2.5) Pre-fetch RAG so the scorer and selector can see retrieval quality.
    rag_prefetch, rag_results = await self._prefetch_rag_context(req, reasoning_trace)

    tool_scores = self.tool_scorer.score(req.message, intent, rag_results)
    reasoning_trace.append({
        "step": "tool_scoring",
        "scores": tool_scores
    })

    # 3) Tool selection (hybrid) - RAG results, scores and recent memory as context.
    from backend.mcp_server.common.memory import get_recent_memory
    session_id = req.conversation_history[-1].get("session_id") if req.conversation_history else None
    recent_memory = get_recent_memory(session_id) if session_id else []
    ctx = {
        "tenant_id": req.tenant_id,
        "rag_results": rag_results,
        "tool_scores": tool_scores,
        "memory": recent_memory,  # Context-aware routing: recent tool outputs
        # Always empty here: a request with rule matches never reaches this point.
        "admin_violations": []
    }
    decision = await self.selector.select(intent, req.message, ctx)
    reasoning_trace.append({
        "step": "tool_selection",
        "decision": decision.dict(),
        "context_scores": tool_scores
    })
    tool_traces: List[Dict[str, Any]] = []

    # 4) Multi-step tool execution
    if decision.action == "multi_step" and decision.tool_input:
        steps = decision.tool_input.get("steps", [])
        if steps:
            return await self._execute_multi_step(
                req,
                steps,
                decision,
                tool_traces,
                reasoning_trace,
                rag_prefetch
            )

    # 5) Single tool execution
    if decision.action == "call_tool" and decision.tool:
        try:
            tool_response = await self._run_selected_tool(
                req, decision, intent, tool_traces, reasoning_trace, start_time
            )
            if tool_response is not None:
                return tool_response
            # Unknown tool name: fall through to the default LLM answer below.
        except Exception as e:
            tool_traces.append({"tool": decision.tool, "error": str(e)})
            try:
                fallback = await self.llm.simple_call(req.message, temperature=req.temperature)
            except Exception as llm_error:
                error_msg = str(llm_error)
                if "Cannot connect" in error_msg or "Ollama" in error_msg:
                    fallback = (
                        f"I encountered an error while processing your request: {str(e)}\n\n"
                        f"Additionally, the AI service (Ollama) is unavailable: {error_msg}\n\n"
                        f"To fix:\n"
                        f"1. Install Ollama from https://ollama.ai\n"
                        f"2. Start: `ollama serve`\n"
                        f"3. Pull model: `ollama pull {os.getenv('OLLAMA_MODEL', 'llama3.1:latest')}`"
                    )
                else:
                    fallback = f"I encountered an error while processing your request: {str(e)}. Additionally, the AI service is unavailable: {error_msg}"
            return AgentResponse(
                text=fallback,
                decision=AgentDecision(action="respond", tool=None, tool_input=None, reason=f"tool_error_fallback: {e}"),
                tool_traces=tool_traces,
                reasoning_trace=reasoning_trace + [{
                    "step": "error",
                    "tool": decision.tool,
                    "error": str(e)
                }]
            )

    # 6) Default: direct LLM response
    return await self._default_llm_response(req, intent, reasoning_trace, start_time)

def _is_brief_response_rule(self, match: RedFlagMatch) -> bool:
    """Return True when a low-severity rule only asks for a short or greeting-style reply."""
    import re  # local import keeps this helper self-contained

    if match.severity != "low":
        return False
    rule_text = (match.description or match.pattern or "").lower()
    if any(keyword in rule_text for keyword in ("greeting", "brief", "simple response")):
        return True
    # These are regular expressions; a plain substring test can never match them.
    regex_markers = (r"keep.*response.*brief", r"do not.*verbose", r"respond.*briefly")
    return any(re.search(marker, rule_text) for marker in regex_markers)

async def _respond_to_rule_matches(self, req: AgentRequest, matches: List[RedFlagMatch],
                                   reasoning_trace: List[Dict[str, Any]],
                                   start_time: float) -> AgentResponse:
    """Log every matched rule, then answer with a brief greeting or a block notice."""
    for match in matches:
        self._analytics_log_redflag_violation(
            tenant_id=req.tenant_id,
            rule_id=match.rule_id,
            rule_pattern=match.pattern,
            severity=match.severity,
            matched_text=match.matched_text,
            confidence=match.confidence,
            message_preview=req.message[:200],
            user_id=req.user_id
        )

    # Categorize rules: brief response rules vs blocking rules
    brief_response_rules = []
    blocking_rules = []
    for match in matches:
        if self._is_brief_response_rule(match):
            brief_response_rules.append(match)
        else:
            blocking_rules.append(match)

    # A blocking rule always wins over brief-response rules.
    if blocking_rules:
        return await self._blocked_rule_response(req, blocking_rules, reasoning_trace, start_time)
    return self._brief_rule_response(req, brief_response_rules, reasoning_trace, start_time)

def _brief_rule_response(self, req: AgentRequest, brief_response_rules: List[RedFlagMatch],
                         reasoning_trace: List[Dict[str, Any]],
                         start_time: float) -> AgentResponse:
    """Return a canned greeting for requests that matched only brief-response rules."""
    import random

    brief_response = random.choice([
        "Hello! How can I help you today?",
        "Hi there! What can I assist you with?",
        "Hello! I'm here to help. What do you need?",
        "Hi! How can I assist you?"
    ])
    reasoning_trace.append({
        "step": "brief_response_rule_matched",
        "action": "brief_response",
        "matched_rules": [m.rule_id for m in brief_response_rules],
        "message": "Brief response rule matched, returning brief response (skipping normal flow)"
    })
    total_latency_ms = int((time.time() - start_time) * 1000)
    self._analytics_log_agent_query(
        tenant_id=req.tenant_id,
        message_preview=req.message[:200],
        intent="greeting",
        tools_used=[],
        total_tokens=len(brief_response) // 4,
        total_latency_ms=total_latency_ms,
        success=True,
        user_id=req.user_id
    )
    return AgentResponse(
        text=brief_response,
        decision=AgentDecision(action="respond", tool=None, tool_input=None, reason="brief_response_rule"),
        reasoning_trace=reasoning_trace
    )

async def _blocked_rule_response(self, req: AgentRequest, blocking_rules: List[RedFlagMatch],
                                 reasoning_trace: List[Dict[str, Any]],
                                 start_time: float) -> AgentResponse:
    """Notify the admin about blocking rule matches and return an LLM-worded refusal."""
    # Notification is best-effort: a failure must not stop the block, but it is logged.
    try:
        await self.redflag.notify_admin(
            req.tenant_id,
            blocking_rules,
            source_payload={"message": req.message, "user_id": req.user_id}
        )
    except Exception as exc:
        logger.warning("AgentOrchestrator admin notification failed: %s", exc)

    decision = AgentDecision(
        action="block",
        tool="admin",
        tool_input={"violations": [m.__dict__ for m in blocking_rules]},
        reason="admin_rule_violation"
    )

    # Build detailed prompt for the LLM to generate a natural red flag response
    violations_details = []
    for i, m in enumerate(blocking_rules, 1):
        rule_name = m.description or m.pattern or "Policy violation"
        detail = f"{i}. **{rule_name}** (Severity: {m.severity})"
        if m.matched_text:
            detail += f'\n - Detected phrase: "{m.matched_text}"'
        violations_details.append(detail)
    llm_prompt = "\n".join([
        f'A user made the following request: "{req.message}"',
        "However, this request violates company policies. The following policy violations were detected:",
        "\n".join(violations_details),
        "Your task: Write a clear, professional, and empathetic response to inform the user that:",
        "1. Their request cannot be processed due to policy violations",
        "2. Which specific policy was violated (mention it naturally)",
        "3. The incident has been logged for security review",
        "4. They should contact an administrator if they need assistance or believe this is an error",
        "Write a natural, conversational response (2-4 sentences) that feels helpful rather than robotic. Be professional but understanding.",
        "Response:",
    ])

    llm_ok = True
    try:
        # Slightly higher temperature for a more natural response
        llm_response = await self.llm.simple_call(llm_prompt, temperature=min(req.temperature + 0.2, 0.7))
        llm_response = llm_response.strip()
        # Add warning emoji if not present
        if not llm_response.startswith(("⚠️", "🚨")):
            llm_response = f"⚠️ {llm_response}"
    except Exception as exc:
        # Fall back to a fixed message naming only the rules that caused the block.
        llm_ok = False
        logger.debug("AgentOrchestrator red flag LLM response failed: %s", exc)
        summary = "; ".join(
            f"{m.description or m.pattern}"
            for m in blocking_rules
        )
        llm_response = f"⚠️ I'm unable to process your request because it violates our company policy: {summary}. This incident has been logged. Please contact your administrator if you need assistance."

    total_latency_ms = int((time.time() - start_time) * 1000)
    # Estimate tokens (rough: ~4 chars per token)
    estimated_tokens = len(llm_response) // 4 + len(llm_prompt) // 4
    self._analytics_log_tool_usage(
        tenant_id=req.tenant_id,
        tool_name="llm",
        latency_ms=total_latency_ms,
        tokens_used=estimated_tokens,
        success=llm_ok,
        user_id=req.user_id
    )
    # The query itself is recorded as unsuccessful because it was blocked.
    self._analytics_log_agent_query(
        tenant_id=req.tenant_id,
        message_preview=req.message[:200],
        intent="admin",
        tools_used=["admin", "llm"],
        total_tokens=estimated_tokens,
        total_latency_ms=total_latency_ms,
        success=False,
        user_id=req.user_id
    )
    return AgentResponse(
        text=llm_response,
        decision=decision,
        tool_traces=[{"redflags": [m.__dict__ for m in blocking_rules]}],
        reasoning_trace=reasoning_trace
    )

async def _prefetch_rag_context(self, req: AgentRequest, reasoning_trace: List[Dict[str, Any]]):
    """Run one RAG lookup for the raw message; return (raw response or None, list of hits)."""
    rag_prefetch = None
    rag_results = []
    try:
        rag_start = time.time()
        rag_prefetch = await self.mcp.call_rag(req.tenant_id, req.message)
        rag_latency_ms = int((time.time() - rag_start) * 1000)
        if isinstance(rag_prefetch, dict):
            rag_results = rag_prefetch.get("results") or rag_prefetch.get("hits") or []

        # Only hits that actually carry a score contribute to the statistics.
        scores = [h.get("score", 0.0) for h in rag_results if isinstance(h, dict) and "score" in h]
        avg_score = sum(scores) / len(scores) if scores else None
        top_score = max(scores) if scores else None
        self._analytics_log_rag_search(
            tenant_id=req.tenant_id,
            query=req.message[:500],
            hits_count=len(rag_results),
            avg_score=avg_score,
            top_score=top_score,
            latency_ms=rag_latency_ms
        )
        self._analytics_log_tool_usage(
            tenant_id=req.tenant_id,
            tool_name="rag",
            latency_ms=rag_latency_ms,
            success=True,
            user_id=req.user_id
        )
        reasoning_trace.append({
            "step": "rag_prefetch",
            "status": "ok",
            "hit_count": len(rag_results),
            "latency_ms": rag_latency_ms
        })
    except Exception as pref_err:
        # If RAG fails, continue without it (latency is recorded as 0 for failures).
        self._analytics_log_tool_usage(
            tenant_id=req.tenant_id,
            tool_name="rag",
            latency_ms=0,
            success=False,
            error_message=str(pref_err)[:200],
            user_id=req.user_id
        )
        reasoning_trace.append({
            "step": "rag_prefetch",
            "status": "error",
            "error": str(pref_err)
        })
        rag_prefetch = None
    return rag_prefetch, rag_results

async def _answer_with_llm(self, req: AgentRequest, intent: Any, prompt: str, mode: str,
                           tools_used: List[str], reasoning_trace: List[Dict[str, Any]],
                           start_time: float) -> str:
    """Call the LLM with ``prompt``, record usage, trace and query analytics, return its text."""
    llm_start = time.time()
    llm_out = await self.llm.simple_call(prompt, temperature=req.temperature)
    llm_latency_ms = int((time.time() - llm_start) * 1000)
    tools_used.append("llm")
    # Estimate tokens (rough: ~4 chars per token)
    estimated_tokens = len(llm_out) // 4 + len(prompt) // 4
    self._analytics_log_tool_usage(
        tenant_id=req.tenant_id,
        tool_name="llm",
        latency_ms=llm_latency_ms,
        tokens_used=estimated_tokens,
        success=True,
        user_id=req.user_id
    )
    reasoning_trace.append({
        "step": "llm_response",
        "mode": mode,
        "latency_ms": llm_latency_ms,
        "estimated_tokens": estimated_tokens
    })
    self._analytics_log_agent_query(
        tenant_id=req.tenant_id,
        message_preview=req.message[:200],
        intent=intent,
        tools_used=tools_used,
        total_tokens=estimated_tokens,
        total_latency_ms=int((time.time() - start_time) * 1000),
        success=True,
        user_id=req.user_id
    )
    return llm_out

async def _run_selected_tool(self, req: AgentRequest, decision: AgentDecision, intent: Any,
                             tool_traces: List[Dict[str, Any]],
                             reasoning_trace: List[Dict[str, Any]],
                             start_time: float) -> Optional[AgentResponse]:
    """Run the single tool named by ``decision``; return None when the tool name is unknown."""
    # Fall back to the raw message when the selector supplied no usable query.
    tool_input = decision.tool_input or {}
    query = tool_input.get("query") or req.message
    tools_used: List[str] = []

    if decision.tool == "rag":
        # Autonomous retry with self-correction
        rag_start = time.time()
        rag_resp = await self.rag_with_repair(
            query=query,
            tenant_id=req.tenant_id,
            original_threshold=0.3,
            reasoning_trace=reasoning_trace,
            user_id=req.user_id
        )
        rag_latency_ms = int((time.time() - rag_start) * 1000)
        tools_used.append("rag")
        # Validate and format RAG output to conform to schema
        rag_formatted = self._format_tool_output("rag", rag_resp, rag_latency_ms)
        tool_traces.append({"tool": "rag", "response": rag_formatted})
        reasoning_trace.append({
            "step": "tool_execution",
            "tool": "rag",
            "hit_count": len(self._extract_hits(rag_formatted)),
            "top_score": rag_formatted.get("top_score"),
            "avg_score": rag_formatted.get("avg_score"),
            "summary": self._summarize_hits(rag_formatted, limit=2)
        })
        prompt = self._build_prompt_with_rag(req, rag_formatted)
        llm_out = await self._answer_with_llm(
            req, intent, prompt, "rag_synthesis", tools_used, reasoning_trace, start_time
        )
        return AgentResponse(text=llm_out, decision=decision, tool_traces=tool_traces, reasoning_trace=reasoning_trace)

    if decision.tool == "web":
        # Autonomous retry with query rewriting
        web_start = time.time()
        web_resp = await self.web_with_repair(
            query=query,
            tenant_id=req.tenant_id,
            reasoning_trace=reasoning_trace,
            user_id=req.user_id
        )
        web_latency_ms = int((time.time() - web_start) * 1000)
        tools_used.append("web")
        # Validate and format Web output to conform to schema
        web_formatted = self._format_tool_output("web", web_resp, web_latency_ms)
        tool_traces.append({"tool": "web", "response": web_formatted})
        reasoning_trace.append({
            "step": "tool_execution",
            "tool": "web",
            "hit_count": len(self._extract_hits(web_formatted)),
            "summary": self._summarize_hits(web_formatted, limit=2)
        })
        prompt = self._build_prompt_with_web(req, web_formatted)
        llm_out = await self._answer_with_llm(
            req, intent, prompt, "web_synthesis", tools_used, reasoning_trace, start_time
        )
        return AgentResponse(text=llm_out, decision=decision, tool_traces=tool_traces, reasoning_trace=reasoning_trace)

    if decision.tool == "admin":
        admin_start = time.time()
        admin_resp = await self.mcp.call_admin(req.tenant_id, query)
        admin_latency_ms = int((time.time() - admin_start) * 1000)
        tools_used.append("admin")
        self._analytics_log_tool_usage(
            tenant_id=req.tenant_id,
            tool_name="admin",
            latency_ms=admin_latency_ms,
            success=True,
            user_id=req.user_id
        )
        # Validate and format Admin output to conform to schema
        admin_formatted = self._format_tool_output("admin", admin_resp, admin_latency_ms)
        tool_traces.append({"tool": "admin", "response": admin_formatted})
        reasoning_trace.append({
            "step": "tool_execution",
            "tool": "admin",
            "status": "completed",
            "latency_ms": admin_latency_ms
        })
        self._analytics_log_agent_query(
            tenant_id=req.tenant_id,
            message_preview=req.message[:200],
            intent=intent,
            tools_used=tools_used,
            total_tokens=0,
            total_latency_ms=int((time.time() - start_time) * 1000),
            success=True,
            user_id=req.user_id
        )
        # The admin result is returned verbatim (as JSON), without LLM synthesis.
        return AgentResponse(text=json.dumps(admin_resp), decision=decision, tool_traces=tool_traces, reasoning_trace=reasoning_trace)

    if decision.tool == "llm":
        llm_out = await self._answer_with_llm(
            req, intent, req.message, "direct", tools_used, reasoning_trace, start_time
        )
        return AgentResponse(text=llm_out, decision=decision, reasoning_trace=reasoning_trace)

    return None

async def _default_llm_response(self, req: AgentRequest, intent: Any,
                                reasoning_trace: List[Dict[str, Any]],
                                start_time: float) -> AgentResponse:
    """Answer the raw message with the LLM; on failure return a help text instead of raising."""
    tools_used: List[str] = []
    estimated_tokens = 0
    llm_ok = False
    try:
        llm_start = time.time()
        llm_out = await self.llm.simple_call(req.message, temperature=req.temperature)
        llm_latency_ms = int((time.time() - llm_start) * 1000)
        tools_used = ["llm"]
        estimated_tokens = len(llm_out) // 4 + len(req.message) // 4
        self._analytics_log_tool_usage(
            tenant_id=req.tenant_id,
            tool_name="llm",
            latency_ms=llm_latency_ms,
            tokens_used=estimated_tokens,
            success=True,
            user_id=req.user_id
        )
        llm_ok = True
    except Exception as e:
        # If the LLM fails, return a helpful error message
        error_msg = str(e)
        if "Cannot connect" in error_msg or "Ollama" in error_msg:
            llm_out = (
                f"I couldn't connect to the AI service (Ollama). "
                f"Error: {error_msg}\n\n"
                f"To fix this:\n"
                f"1. Install Ollama from https://ollama.ai\n"
                f"2. Start Ollama: `ollama serve`\n"
                f"3. Pull the model: `ollama pull {os.getenv('OLLAMA_MODEL', 'llama3.1:latest')}`\n"
                f"4. Or set OLLAMA_URL and OLLAMA_MODEL in your .env file"
            )
        else:
            llm_out = f"I apologize, but I'm unable to process your request right now. The AI service is unavailable: {error_msg}"
        self._analytics_log_tool_usage(
            tenant_id=req.tenant_id,
            tool_name="llm",
            success=False,
            error_message=error_msg[:200],
            user_id=req.user_id
        )
        reasoning_trace.append({
            "step": "error",
            "tool": "llm",
            "error": str(e)
        })

    total_latency_ms = int((time.time() - start_time) * 1000)
    self._analytics_log_agent_query(
        tenant_id=req.tenant_id,
        message_preview=req.message[:200],
        intent=intent,
        tools_used=tools_used,
        total_tokens=estimated_tokens,
        total_latency_ms=total_latency_ms,
        # Reflects whether the LLM actually answered, not whether text was returned.
        success=llm_ok,
        user_id=req.user_id
    )
    return AgentResponse(
        text=llm_out,
        decision=AgentDecision(action="respond", tool=None, tool_input=None, reason="default_llm"),
        reasoning_trace=reasoning_trace
    )
def _build_prompt_with_rag(self, req: AgentRequest, rag_resp: Dict[str, Any]) -> str:
snippets = []
if isinstance(rag_resp, dict):
hits = rag_resp.get("results") or rag_resp.get("hits") or []
for h in hits[:5]:
txt = h.get("text") or h.get("content") or str(h)
snippets.append(txt)
snippet_text = "\n---\n".join(snippets) or ""
prompt = (
f"You are an assistant helping tenant {req.tenant_id}. Use the following retrieved documents to answer the user's question.\n"
f"Documents:\n{snippet_text}\n\n"
f"User question: {req.message}\nProvide a concise, accurate answer and cite the source snippets where appropriate."
)
return prompt
async def _execute_multi_step(self, req: AgentRequest, steps: List[Dict[str, Any]],
decision: AgentDecision, tool_traces: List[Dict[str, Any]],
reasoning_trace: List[Dict[str, Any]],
pre_fetched_rag: Optional[Dict[str, Any]] = None) -> AgentResponse:
"""
Execute multiple tools in sequence or parallel and synthesize results with LLM.
Supports parallel execution when steps are marked with "parallel" flag.

Args:
req: The request being served; tenant_id, user_id, message and temperature are read.
steps: Step dicts from the tool selector. The keys read here are "parallel"
(a mapping whose "rag" / "web" entries are used as queries), "tool" and
"input" (whose "query" entry is used, falling back to req.message).
decision: The selector decision, echoed back in the successful response.
tool_traces: List that per-tool responses and errors are appended to.
reasoning_trace: List that per-step trace entries are appended to.
pre_fetched_rag: Optional RAG response fetched earlier; reused instead of
calling RAG again where the code below allows it.

Returns:
An AgentResponse holding the LLM synthesis. If the final LLM call fails, a
response with a fallback text and reason "multi_step_llm_error: ..." is
returned instead of raising.
"""
start_time = time.time()
# NOTE(review): rag_data, web_data and admin_data are assigned below but never
# read again inside this method.
rag_data = None
web_data = None
admin_data = None
collected_data = []
tools_used = []
total_tokens = 0
# Check if any step has parallel execution flag
# Only the first step carrying a truthy "parallel" entry is used.
parallel_step = None
for step_info in steps:
if step_info.get("parallel"):
parallel_step = step_info
break
# Handle parallel execution if detected
if parallel_step and parallel_step.get("parallel"):
parallel_config = parallel_step.get("parallel")
parallel_tasks = {}
start_time_parallel = time.time()
# Prepare parallel tasks with retry logic
# The values stored in parallel_tasks are un-awaited coroutine objects; they
# are handed to run_parallel_tools() below.
if "rag" in parallel_config:
rag_query = parallel_config["rag"]
if pre_fetched_rag:
# Use pre-fetched RAG if available - create a simple async function
async def get_prefetched_rag():
return pre_fetched_rag
parallel_tasks["rag"] = get_prefetched_rag()
else:
# Wrap with retry logic for parallel execution
async def rag_with_retry_wrapper():
return await self.rag_with_repair(
query=rag_query,
tenant_id=req.tenant_id,
original_threshold=0.3,
reasoning_trace=reasoning_trace,
user_id=req.user_id
)
parallel_tasks["rag"] = rag_with_retry_wrapper()
if "web" in parallel_config:
web_query = parallel_config["web"]
# Wrap with retry logic for parallel execution
async def web_with_retry_wrapper():
return await self.web_with_repair(
query=web_query,
tenant_id=req.tenant_id,
reasoning_trace=reasoning_trace,
user_id=req.user_id
)
parallel_tasks["web"] = web_with_retry_wrapper()
# Execute tools in parallel
if parallel_tasks:
reasoning_trace.append({
"step": "parallel_execution",
"tools": list(parallel_tasks.keys()),
"mode": "parallel"
})
parallel_results = await self.run_parallel_tools(parallel_tasks)
# One wall-clock latency for the whole batch; it is reported for every tool below.
parallel_latency_ms = int((time.time() - start_time_parallel) * 1000)
# Process RAG results
# A result that is an Exception instance is treated as that tool's failure.
if "rag" in parallel_results:
rag_result = parallel_results["rag"]
if isinstance(rag_result, Exception):
tool_traces.append({"tool": "rag", "error": str(rag_result), "note": "parallel"})
reasoning_trace.append({
"step": "tool_execution",
"tool": "rag",
"status": "error",
"error": str(rag_result),
"latency_ms": parallel_latency_ms
})
self._analytics_log_tool_usage(
tenant_id=req.tenant_id,
tool_name="rag",
latency_ms=parallel_latency_ms,
success=False,
error_message=str(rag_result)[:200],
user_id=req.user_id
)
else:
rag_data = rag_result
tools_used.append("rag")
tool_traces.append({"tool": "rag", "response": rag_result, "note": "parallel"})
hits_count = len(self._extract_hits(rag_result))
avg_score = None
top_score = None
if hits_count > 0:
# Only dict hits that carry a "score" key contribute to the statistics.
scores = [h.get("score", 0.0) for h in self._extract_hits(rag_result) if isinstance(h, dict) and "score" in h]
if scores:
avg_score = sum(scores) / len(scores)
top_score = max(scores)
# The logged query is the user's message, not parallel_config["rag"].
self._analytics_log_rag_search(
tenant_id=req.tenant_id,
query=req.message[:500],
hits_count=hits_count,
avg_score=avg_score,
top_score=top_score,
latency_ms=parallel_latency_ms
)
self._analytics_log_tool_usage(
tenant_id=req.tenant_id,
tool_name="rag",
latency_ms=parallel_latency_ms,
success=True,
user_id=req.user_id
)
reasoning_trace.append({
"step": "tool_execution",
"tool": "rag",
"hit_count": hits_count,
"summary": self._summarize_hits(rag_result, limit=2),
"latency_ms": parallel_latency_ms,
"mode": "parallel"
})
# Process Web results
if "web" in parallel_results:
web_result = parallel_results["web"]
if isinstance(web_result, Exception):
tool_traces.append({"tool": "web", "error": str(web_result), "note": "parallel"})
reasoning_trace.append({
"step": "tool_execution",
"tool": "web",
"status": "error",
"error": str(web_result),
"latency_ms": parallel_latency_ms
})
self._analytics_log_tool_usage(
tenant_id=req.tenant_id,
tool_name="web",
latency_ms=parallel_latency_ms,
success=False,
error_message=str(web_result)[:200],
user_id=req.user_id
)
else:
web_data = web_result
tools_used.append("web")
tool_traces.append({"tool": "web", "response": web_result, "note": "parallel"})
hits_count = len(self._extract_hits(web_result))
self._analytics_log_tool_usage(
tenant_id=req.tenant_id,
tool_name="web",
latency_ms=parallel_latency_ms,
success=True,
user_id=req.user_id
)
reasoning_trace.append({
"step": "tool_execution",
"tool": "web",
"hit_count": hits_count,
"summary": self._summarize_hits(web_result, limit=2),
"latency_ms": parallel_latency_ms,
"mode": "parallel"
})
# Merge parallel results
# NOTE(review): parallel_results is passed as-is, including any Exception
# values - presumably merge_parallel_results() tolerates them; confirm.
merged_context = merge_parallel_results(parallel_results)
sources_list = list(set(e.get("source") for e in merged_context if e.get("source"))) if merged_context else []
reasoning_trace.append({
"step": "result_merger",
"merged_items": len(merged_context),
"sources": sources_list
})
# Format merged context for prompt
data_section = format_merged_context_for_prompt(merged_context, max_items=10)
else:
# The parallel config named neither "rag" nor "web": nothing to put in the prompt.
data_section = ""
else:
# Sequential execution (original logic)
# When the first rag step and the first web step share the same query, the
# calls are started up front as asyncio tasks and awaited when the
# corresponding step is reached in the loop below.
# NOTE(review): these tasks call self.mcp directly (no repair/retry wrapper)
# and are only awaited if a matching step is reached - confirm that a task
# that is never awaited is acceptable here.
parallel_tasks = {}
rag_parallel_query = self._first_query_for_tool(steps, "rag", req.message)
web_parallel_query = self._first_query_for_tool(steps, "web", req.message)
if rag_parallel_query and web_parallel_query and rag_parallel_query == web_parallel_query:
if not pre_fetched_rag:
parallel_tasks["rag"] = asyncio.create_task(self.mcp.call_rag(req.tenant_id, rag_parallel_query))
parallel_tasks["web"] = asyncio.create_task(self.mcp.call_web(req.tenant_id, web_parallel_query))
# Execute each step in sequence
for step_info in steps:
tool_name = step_info.get("tool")
step_input = step_info.get("input") or {}
query = step_input.get("query") or req.message
try:
if tool_name == "rag":
# Reuse pre-fetched RAG if available, otherwise fetch with retry
if pre_fetched_rag and query == rag_parallel_query:
rag_resp = pre_fetched_rag
tool_traces.append({"tool": "rag", "response": rag_resp, "note": "used_pre_fetched"})
elif parallel_tasks.get("rag") and query == rag_parallel_query:
rag_resp = await parallel_tasks["rag"]
tool_traces.append({"tool": "rag", "response": rag_resp, "note": "parallel"})
else:
# Use autonomous retry with self-correction
rag_resp = await self.rag_with_repair(
query=query,
tenant_id=req.tenant_id,
original_threshold=0.3,
reasoning_trace=reasoning_trace,
user_id=req.user_id
)
tool_traces.append({"tool": "rag", "response": rag_resp, "note": "with_retry"})
rag_data = rag_resp
tools_used.append("rag")
hits = self._extract_hits(rag_resp)
reasoning_trace.append({
"step": "tool_execution",
"tool": "rag",
"hit_count": len(hits),
"summary": self._summarize_hits(rag_resp, limit=2)
})
# Extract snippets for prompt
# At most five hits are used; each is tagged "[RAG]" in the prompt.
if isinstance(rag_resp, dict):
for h in hits[:5]:
txt = h.get("text") or h.get("content") or str(h)
collected_data.append(f"[RAG] {txt}")
elif tool_name == "web":
if parallel_tasks.get("web") and query == web_parallel_query:
web_resp = await parallel_tasks["web"]
tool_traces.append({"tool": "web", "response": web_resp, "note": "parallel"})
else:
# Use autonomous retry with query rewriting
web_resp = await self.web_with_repair(
query=query,
tenant_id=req.tenant_id,
reasoning_trace=reasoning_trace,
user_id=req.user_id
)
tool_traces.append({"tool": "web", "response": web_resp, "note": "with_retry"})
web_data = web_resp
tools_used.append("web")
hits = self._extract_hits(web_resp)
reasoning_trace.append({
"step": "tool_execution",
"tool": "web",
"hit_count": len(hits),
"summary": self._summarize_hits(web_resp, limit=2)
})
# Extract snippets for prompt
# At most five hits are used; several alternative key names are accepted
# for title, snippet and URL.
if isinstance(web_resp, dict):
for h in hits[:5]:
title = h.get("title") or h.get("headline") or ""
snippet = h.get("snippet") or h.get("summary") or h.get("text") or ""
url = h.get("url") or h.get("link") or ""
collected_data.append(f"[WEB] {title}\n{snippet}\nSource: {url}")
elif tool_name == "admin":
admin_resp = await self.mcp.call_admin(req.tenant_id, query)
tool_traces.append({"tool": "admin", "response": admin_resp})
admin_data = admin_resp
tools_used.append("admin")
collected_data.append(f"[ADMIN] {json.dumps(admin_resp)}")
reasoning_trace.append({
"step": "tool_execution",
"tool": "admin",
"status": "completed"
})
elif tool_name == "llm":
# LLM is always last - synthesize all collected data
# Any steps listed after an "llm" step are therefore not executed.
break
except Exception as e:
tool_traces.append({"tool": tool_name, "error": str(e)})
# Continue with other tools even if one fails
reasoning_trace.append({
"step": "error",
"tool": tool_name,
"error": str(e)
})
# Build comprehensive prompt with all collected data
data_section = "\n---\n".join(collected_data) if collected_data else ""
# Build final prompt
if data_section:
prompt = (
f"You are an assistant helping tenant {req.tenant_id}.\n\n"
f"## Information Collected\n"
f"The following details have been gathered from multiple reliable sources:\n"
f"{data_section}\n\n"
f"## User Request\n"
f"{req.message}\n\n"
f"## Your Task\n"
f"Use the information above to directly address the user's request. "
f"Focus on giving the user exactly what they need—clear guidance, accurate facts, "
f"and practical steps whenever possible. If the information is incomplete, explain "
f"what can and cannot be concluded from the available data."
)
else:
# No data collected, just answer the question
prompt = req.message
# Final LLM synthesis
try:
llm_start = time.time()
llm_out = await self.llm.simple_call(prompt, temperature=req.temperature)
llm_latency_ms = int((time.time() - llm_start) * 1000)
tools_used.append("llm")
# Rough token estimate: about four characters per token.
estimated_tokens = len(llm_out) // 4 + len(prompt) // 4
total_tokens += estimated_tokens
self._analytics_log_tool_usage(
tenant_id=req.tenant_id,
tool_name="llm",
latency_ms=llm_latency_ms,
tokens_used=estimated_tokens,
success=True,
user_id=req.user_id
)
total_latency_ms = int((time.time() - start_time) * 1000)
# The intent is always logged as the literal "multi_step" for this path.
self._analytics_log_agent_query(
tenant_id=req.tenant_id,
message_preview=req.message[:200],
intent="multi_step",
tools_used=tools_used,
total_tokens=total_tokens,
total_latency_ms=total_latency_ms,
success=True,
user_id=req.user_id
)
# The final trace entry is added to a copy; the caller's reasoning_trace list
# does not receive it.
return AgentResponse(
text=llm_out,
decision=decision,
tool_traces=tool_traces,
reasoning_trace=reasoning_trace + [{
"step": "llm_response",
"mode": "multi_step_parallel" if parallel_step else "multi_step",
"latency_ms": llm_latency_ms,
"estimated_tokens": estimated_tokens
}]
)
except Exception as e:
tool_traces.append({"tool": "llm", "error": str(e)})
error_msg = str(e)
# Provide helpful error message
# Connection-style errors get setup instructions; no analytics are logged on this path.
if "Cannot connect" in error_msg or "Ollama" in error_msg:
fallback = (
f"I couldn't connect to the AI service (Ollama). "
f"Error: {error_msg}\n\n"
f"To fix this:\n"
f"1. Install Ollama from https://ollama.ai\n"
f"2. Start Ollama: `ollama serve`\n"
f"3. Pull the model: `ollama pull {os.getenv('OLLAMA_MODEL', 'llama3.1:latest')}`\n"
f"4. Or set OLLAMA_URL and OLLAMA_MODEL in your .env file"
)
else:
fallback = f"I encountered an error while synthesizing the response: {error_msg}"
return AgentResponse(
text=fallback,
decision=AgentDecision(
action="respond",
tool=None,
tool_input=None,
reason=f"multi_step_llm_error: {e}"
),
tool_traces=tool_traces,
reasoning_trace=reasoning_trace + [{
"step": "error",
"tool": "llm",
"error": str(e)
}]
)
# =============================================================
# AUTONOMOUS RETRY + SELF-CORRECTION SYSTEM
# =============================================================
"""
This system provides autonomous retry and self-correction capabilities
for the agent orchestrator. It enables the agent to:
1. **Self-healing**: Tool calls that fail are retried automatically with adjusted parameters
2. **Resilient operations**: Handles low RAG scores, empty web results, and misfired rules
3. **Smart optimization**: Automatically rewrites queries, adjusts thresholds, and optimizes parameters
4. **Enterprise-grade reliability**: Matches enterprise behavior with comprehensive retry strategies
Key features:
- safe_tool_call(): Generic retry wrapper for any tool call
- rag_with_repair(): RAG search with automatic threshold adjustment and query expansion
- web_with_repair(): Web search with automatic query rewriting for empty results
- rule_safe_message(): Message rewriting to comply with admin rules
All retry attempts are logged to analytics for monitoring and debugging.
"""
async def safe_tool_call(
self,
tool_fn,
params: Dict[str, Any],
max_retries: int = 2,
fallback_params: Optional[Dict[str, Any]] = None,
tool_name: str = "unknown",
tenant_id: Optional[str] = None,
user_id: Optional[str] = None,
reasoning_trace: Optional[List[Dict[str, Any]]] = None
) -> Dict[str, Any]:
"""
Wrapper for tool calls with automatic retry and self-correction.
Args:
tool_fn: Async function to call
params: Parameters to pass to tool_fn
max_retries: Maximum number of retry attempts
fallback_params: Alternative parameters to try if initial attempt fails
tool_name: Name of the tool (for logging)
tenant_id: Tenant ID (for analytics)
user_id: User ID (for analytics)
reasoning_trace: Optional reasoning trace to append to
Returns:
Tool result dictionary, or {"error": "tool_failed_after_retries"} if all attempts fail
"""
for attempt in range(max_retries):
try:
result = await tool_fn(**params)
if attempt > 0:
# Log successful retry
if reasoning_trace is not None:
reasoning_trace.append({
"step": "retry_success",
"tool": tool_name,
"attempt": attempt + 1,
"status": "recovered"
})
if tenant_id:
self._analytics_log_tool_usage(
tenant_id=tenant_id,
tool_name=f"{tool_name}_retry_{attempt+1}",
latency_ms=0,
success=True,
user_id=user_id
)
return result
except Exception as e:
error_msg = str(e)
if reasoning_trace is not None:
reasoning_trace.append({
"step": "retry_attempt",
"tool": tool_name,
"attempt": attempt + 1,
"error": error_msg[:200]
})
# Log failed attempt
if tenant_id:
self._analytics_log_tool_usage(
tenant_id=tenant_id,
tool_name=tool_name,
latency_ms=0,
success=False,
error_message=error_msg[:200],
user_id=user_id
)
# Try alternate params if provided and not last attempt
if fallback_params and attempt < max_retries - 1:
params = {**params, **fallback_params}
if reasoning_trace is not None:
reasoning_trace.append({
"step": "retry_with_fallback_params",
"tool": tool_name,
"attempt": attempt + 2,
"fallback_params": fallback_params
})
# If last attempt, return error
if attempt == max_retries - 1:
return {"error": "tool_failed_after_retries", "error_message": error_msg}
return {"error": "tool_failed_after_retries"}
async def rag_with_repair(
self,
query: str,
tenant_id: str,
original_threshold: float = 0.3,
reasoning_trace: Optional[List[Dict[str, Any]]] = None,
user_id: Optional[str] = None
) -> Dict[str, Any]:
"""
RAG search with automatic self-correction for low scores.
Strategy:
1. Try with original threshold
2. If top_score < 0.30, retry with lower threshold (0.15)
3. If still low (< 0.15), expand query and retry
"""
# Initial attempt
rag_start = time.time()
result = await self.mcp.call_rag(tenant_id, query, threshold=original_threshold)
rag_latency_ms = int((time.time() - rag_start) * 1000)
# Extract hits and calculate scores
hits = self._extract_hits(result)
top_score = None
avg_score = None
if hits:
scores = [h.get("score", 0.0) for h in hits if isinstance(h, dict) and "score" in h]
if scores:
top_score = max(scores)
avg_score = sum(scores) / len(scores)
if reasoning_trace is not None:
reasoning_trace.append({
"step": "rag_initial_search",
"query": query[:200],
"hits_count": len(hits),
"top_score": top_score,
"avg_score": avg_score,
"threshold": original_threshold
})
# Retry logic: low score → lower threshold
if top_score is not None and top_score < 0.30 and original_threshold >= 0.15:
if reasoning_trace is not None:
reasoning_trace.append({
"step": "rag_retry_low_threshold",
"reason": f"top_score {top_score:.3f} < 0.30, retrying with threshold=0.15"
})
retry_start = time.time()
result = await self.mcp.call_rag(tenant_id, query, threshold=0.15)
retry_latency_ms = int((time.time() - retry_start) * 1000)
rag_latency_ms += retry_latency_ms
hits = self._extract_hits(result)
if hits:
scores = [h.get("score", 0.0) for h in hits if isinstance(h, dict) and "score" in h]
if scores:
top_score = max(scores)
avg_score = sum(scores) / len(scores)
# Log retry
self._analytics_log_tool_usage(
tenant_id=tenant_id,
tool_name="rag_retry_low_threshold",
latency_ms=retry_latency_ms,
success=True,
user_id=user_id
)
# Final retry: expand query if score still too low
if top_score is not None and top_score < 0.15:
expanded_query = f"{query} (more details comprehensive explanation)"
if reasoning_trace is not None:
reasoning_trace.append({
"step": "rag_retry_expanded_query",
"reason": f"top_score {top_score:.3f} < 0.15, retrying with expanded query",
"original_query": query[:200],
"expanded_query": expanded_query[:200]
})
retry_start = time.time()
result = await self.mcp.call_rag(tenant_id, expanded_query, threshold=0.15)
retry_latency_ms = int((time.time() - retry_start) * 1000)
rag_latency_ms += retry_latency_ms
hits = self._extract_hits(result)
if hits:
scores = [h.get("score", 0.0) for h in hits if isinstance(h, dict) and "score" in h]
if scores:
top_score = max(scores)
avg_score = sum(scores) / len(scores)
# Log retry
self._analytics_log_tool_usage(
tenant_id=tenant_id,
tool_name="rag_retry_expanded_query",
latency_ms=retry_latency_ms,
success=True,
user_id=user_id
)
if reasoning_trace is not None:
reasoning_trace.append({
"step": "rag_expanded_query_result",
"hits_count": len(hits),
"top_score": top_score,
"avg_score": avg_score
})
# Log final RAG search
if hits:
self._analytics_log_rag_search(
tenant_id=tenant_id,
query=query[:500],
hits_count=len(hits),
avg_score=avg_score,
top_score=top_score,
latency_ms=rag_latency_ms
)
return result
async def web_with_repair(
self,
query: str,
tenant_id: str,
reasoning_trace: Optional[List[Dict[str, Any]]] = None,
user_id: Optional[str] = None
) -> Dict[str, Any]:
"""
Web search with automatic query rewriting for empty results.
Strategy:
1. Try original query
2. If empty, try "best explanation of {query}"
3. If still empty, try "{query} facts summary"
"""
# Initial attempt
web_start = time.time()
result = await self.mcp.call_web(tenant_id, query)
web_latency_ms = int((time.time() - web_start) * 1000)
hits = self._extract_hits(result)
if reasoning_trace is not None:
reasoning_trace.append({
"step": "web_initial_search",
"query": query[:200],
"hits_count": len(hits)
})
# Retry logic: empty results → rewrite query
if not result or len(hits) == 0:
rewritten_queries = [
f"best explanation of {query}",
f"{query} facts summary"
]
for i, rewritten in enumerate(rewritten_queries):
if reasoning_trace is not None:
reasoning_trace.append({
"step": "web_retry_rewritten",
"attempt": i + 1,
"original_query": query[:200],
"rewritten_query": rewritten[:200]
})
retry_start = time.time()
result = await self.mcp.call_web(tenant_id, rewritten)
retry_latency_ms = int((time.time() - retry_start) * 1000)
web_latency_ms += retry_latency_ms
hits = self._extract_hits(result)
# Log retry
self._analytics_log_tool_usage(
tenant_id=tenant_id,
tool_name=f"web_retry_rewrite_{i+1}",
latency_ms=retry_latency_ms,
success=True,
user_id=user_id
)
if hits:
if reasoning_trace is not None:
reasoning_trace.append({
"step": "web_retry_success",
"rewritten_query": rewritten[:200],
"hits_count": len(hits)
})
break
# Log final web search
self._analytics_log_tool_usage(
tenant_id=tenant_id,
tool_name="web",
latency_ms=web_latency_ms,
success=len(hits) > 0,
user_id=user_id
)
return result
async def rule_safe_message(
self,
user_message: str,
tenant_id: str,
reasoning_trace: Optional[List[Dict[str, Any]]] = None
) -> str:
"""
Check admin rules and rewrite message if it violates policies.
Strategy:
1. Check rules
2. If blocked, ask LLM to rewrite message to comply
3. Return safe version
"""
matches: List[RedFlagMatch] = await self.redflag.check(tenant_id, user_message)
if not matches:
return user_message
# Check if any are blocking rules (not just brief response rules)
blocking_rules = []
for match in matches:
rule_text = (match.description or match.pattern or "").lower()
is_brief_rule = (
match.severity == "low" and (
"greeting" in rule_text or
"brief" in rule_text or
"simple response" in rule_text
)
)
if not is_brief_rule:
blocking_rules.append(match)
# Only rewrite if there are blocking rules
if not blocking_rules:
return user_message
if reasoning_trace is not None:
reasoning_trace.append({
"step": "rule_violation_detected",
"blocking_rules_count": len(blocking_rules),
"action": "attempting_rewrite"
})
# Ask LLM to rewrite the message
rewrite_prompt = f"""The following user message violates company policies. Rewrite it to be compliant while preserving the user's intent as much as possible.
Original message: "{user_message}"
Violated policies:
{chr(10).join([f"- {m.description or m.pattern}" for m in blocking_rules[:3]])}
Provide a rewritten version that:
1. Avoids the policy violations
2. Preserves the user's original intent
3. Remains professional and helpful
Rewritten message:"""
try:
rewritten = await self.llm.simple_call(rewrite_prompt, temperature=0.3)
rewritten = rewritten.strip().strip('"').strip("'")
if reasoning_trace is not None:
reasoning_trace.append({
"step": "rule_rewrite_completed",
"original_length": len(user_message),
"rewritten_length": len(rewritten),
"rewritten_preview": rewritten[:200]
})
# Verify the rewritten message doesn't trigger rules
verify_matches = await self.redflag.check(tenant_id, rewritten)
if not verify_matches or all(
(m.description or m.pattern or "").lower() in ["greeting", "brief", "simple response"]
for m in verify_matches
):
return rewritten
if reasoning_trace is not None:
reasoning_trace.append({
"step": "rule_rewrite_still_violates",
"action": "using_original_with_block"
})
except Exception as e:
if reasoning_trace is not None:
reasoning_trace.append({
"step": "rule_rewrite_failed",
"error": str(e)[:200]
})
# Return original if rewrite failed or still violates
return user_message
def _build_prompt_with_web(self, req: AgentRequest, web_resp: Dict[str, Any]) -> str:
snippets = []
if isinstance(web_resp, dict):
hits = web_resp.get("results") or web_resp.get("items") or []
for h in hits[:5]:
title = h.get("title") or h.get("headline") or ""
snippet = h.get("snippet") or h.get("summary") or h.get("text") or ""
url = h.get("url") or h.get("link") or ""
snippets.append(f"{title}\n{snippet}\nSource: {url}")
snippet_text = "\n---\n".join(snippets) or ""
# prompt = (
# f"You are an assistant with access to recent web search results. Use the following results to answer.\n{snippet_text}\n\n"
# f"User question: {req.message}\nAnswer succinctly and indicate which results you used."
# )
prompt = (
f"You are an assistant with access to recent web search results.\n\n"
f"## Search Results\n"
f"{snippet_text}\n\n"
f"## User Question\n"
f"{req.message}\n\n"
f"## Your Task\n"
f"Provide a clear, accurate, and succinct answer based on the search results above. "
f"Indicate which results you used in your reasoning."
)
return prompt
def _format_tool_output(self, tool_name: str, output: Any, latency_ms: int) -> Dict[str, Any]:
"""
Format tool output to conform to strict JSON schema.
Args:
tool_name: Name of the tool (rag, web, admin, llm)
output: Raw tool output
latency_ms: Actual latency in milliseconds
Returns:
Formatted output conforming to tool schema
"""
if tool_name == "rag":
# Format RAG output
if isinstance(output, dict):
results = output.get("results") or output.get("hits") or []
# Ensure each result has required fields
formatted_results = []
for r in results:
if isinstance(r, dict):
formatted_results.append({
"text": r.get("text") or r.get("content") or str(r),
"similarity": float(r.get("similarity") or r.get("score") or 0.0),
"metadata": r.get("metadata") or {},
"doc_id": r.get("doc_id") or r.get("id")
})
else:
formatted_results.append({
"text": str(r),
"similarity": 0.5,
"metadata": {},
"doc_id": None
})
# Calculate aggregate scores
scores = [r["similarity"] for r in formatted_results if r["similarity"] > 0]
avg_score = sum(scores) / len(scores) if scores else 0.0
top_score = max(scores) if scores else 0.0
return {
"results": formatted_results,
"query": output.get("query", ""),
"tenant_id": output.get("tenant_id", ""),
"hits_count": len(formatted_results),
"avg_score": round(avg_score, 3),
"top_score": round(top_score, 3),
"latency_ms": latency_ms
}
else:
# Fallback for non-dict output
return {
"results": [{"text": str(output), "similarity": 0.5, "metadata": {}, "doc_id": None}],
"query": "",
"tenant_id": "",
"hits_count": 1,
"avg_score": 0.5,
"top_score": 0.5,
"latency_ms": latency_ms
}
elif tool_name == "web":
# Format Web output
if isinstance(output, dict):
results = output.get("results") or output.get("items") or []
formatted_results = []
for r in results:
if isinstance(r, dict):
formatted_results.append({
"title": r.get("title") or r.get("headline") or "",
"snippet": r.get("snippet") or r.get("summary") or r.get("text") or "",
"link": r.get("url") or r.get("link") or "",
"displayLink": r.get("displayLink") or r.get("display_link") or ""
})
else:
formatted_results.append({
"title": "",
"snippet": str(r),
"link": "",
"displayLink": ""
})
return {
"results": formatted_results,
"query": output.get("query", ""),
"total_results": output.get("total_results") or output.get("totalResults") or len(formatted_results),
"latency_ms": latency_ms
}
else:
return {
"results": [],
"query": "",
"total_results": 0,
"latency_ms": latency_ms
}
elif tool_name == "admin":
# Format Admin output
if isinstance(output, dict):
violations = output.get("violations") or output.get("matches") or []
formatted_violations = []
for v in violations:
if isinstance(v, dict):
formatted_violations.append({
"rule_id": v.get("rule_id") or v.get("id") or "",
"rule_pattern": v.get("rule_pattern") or v.get("pattern") or "",
"severity": v.get("severity", "medium"),
"matched_text": v.get("matched_text") or v.get("text") or "",
"confidence": float(v.get("confidence", 1.0)),
"message_preview": v.get("message_preview") or v.get("preview") or ""
})
return {
"violations": formatted_violations,
"checked": output.get("checked", True),
"rules_count": output.get("rules_count") or output.get("rulesCount") or len(formatted_violations),
"latency_ms": latency_ms
}
else:
return {
"violations": [],
"checked": True,
"rules_count": 0,
"latency_ms": latency_ms
}
elif tool_name == "llm":
# Format LLM output
if isinstance(output, str):
return {
"text": output,
"tokens_used": len(output) // 4, # Rough estimate
"latency_ms": latency_ms,
"model": getattr(self.llm, 'model', 'unknown'),
"temperature": 0.0
}
elif isinstance(output, dict):
return {
"text": output.get("text") or output.get("response") or str(output),
"tokens_used": output.get("tokens_used") or output.get("tokens") or 0,
"latency_ms": latency_ms,
"model": output.get("model") or getattr(self.llm, 'model', 'unknown'),
"temperature": output.get("temperature", 0.0)
}
else:
return {
"text": str(output),
"tokens_used": 0,
"latency_ms": latency_ms,
"model": getattr(self.llm, 'model', 'unknown'),
"temperature": 0.0
}
# Unknown tool - return as-is
return output if isinstance(output, dict) else {"output": str(output), "latency_ms": latency_ms}
@staticmethod
def _extract_hits(resp: Optional[Dict[str, Any]]) -> List[Dict[str, Any]]:
if not isinstance(resp, dict):
return []
return resp.get("results") or resp.get("hits") or resp.get("items") or []
def _summarize_hits(self, resp: Optional[Dict[str, Any]], limit: int = 3) -> List[str]:
hits = self._extract_hits(resp)
summaries = []
for hit in hits[:limit]:
if isinstance(hit, dict):
snippet = hit.get("text") or hit.get("content") or hit.get("snippet") or ""
else:
snippet = str(hit)
summaries.append(snippet[:160])
return summaries
async def run_parallel_tools(self, tasks: Dict[str, Any]) -> Dict[str, Any]:
"""
Run multiple tools in parallel using asyncio.gather.
Args:
tasks: Dictionary mapping tool names to coroutines, e.g.:
{"rag": rag_coro, "web": web_coro}
Returns:
Dictionary mapping tool names to results, e.g.:
{"rag": rag_result, "web": web_result}
Exceptions are returned as values if a tool fails.
"""
if not tasks:
return {}
names = list(tasks.keys())
coros = [tasks[name] for name in names]
# Run all coroutines in parallel, return exceptions instead of raising
results = await asyncio.gather(*coros, return_exceptions=True)
return {names[i]: results[i] for i in range(len(names))}
@staticmethod
def _first_query_for_tool(steps: List[Dict[str, Any]], tool_name: str, default_query: str) -> Optional[str]:
for step in steps:
if step.get("tool") == tool_name:
input_data = step.get("input") or {}
return input_data.get("query") or default_query
return None