# =============================================================
# File: backend/api/services/agent_orchestrator.py
# =============================================================
"""
Agent Orchestrator (integrated with enterprise RedFlagDetector)
Place at: backend/api/services/agent_orchestrator.py
"""
from __future__ import annotations
import asyncio
import json
import os
from typing import List, Dict, Any, Optional
from ..models.agent import AgentRequest, AgentDecision, AgentResponse
from ..models.redflag import RedFlagMatch
from .redflag_detector import RedFlagDetector
from .intent_classifier import IntentClassifier
from .tool_selector import ToolSelector
from .llm_client import LLMClient
from ..mcp_clients.mcp_client import MCPClient
from .tool_scoring import ToolScoringService
class AgentOrchestrator:
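    """
    Routes each incoming chat request through red-flag policy checks, intent
    classification, tool scoring/selection, and MCP tool execution, then
    synthesizes the final answer with the configured LLM backend.
    """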
def __init__(self, rag_mcp_url: str, web_mcp_url: str, admin_mcp_url: str, llm_backend: str = "ollama"):
self.mcp = MCPClient(rag_mcp_url, web_mcp_url, admin_mcp_url)
self.llm = LLMClient(backend=llm_backend, url=os.getenv("OLLAMA_URL"), api_key=os.getenv("GROQ_API_KEY"), model=os.getenv("OLLAMA_MODEL"))
# pass admin_mcp_url so detector can call back
self.redflag = RedFlagDetector(
supabase_url=os.getenv("SUPABASE_URL"),
supabase_key=os.getenv("SUPABASE_SERVICE_KEY"),
admin_mcp_url=admin_mcp_url
)
self.intent = IntentClassifier(llm_client=self.llm)
self.selector = ToolSelector(llm_client=self.llm)
self.tool_scorer = ToolScoringService()
async def handle(self, req: AgentRequest) -> AgentResponse:
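        """
        Run the full pipeline for one request: red-flag screening, intent
        classification, RAG pre-fetch, tool scoring and selection, tool
        execution, and LLM synthesis. Every step is appended to the
        reasoning trace returned on the AgentResponse.
        """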
reasoning_trace: List[Dict[str, Any]] = []
reasoning_trace.append({
"step": "request_received",
"tenant_id": req.tenant_id,
"user_id": req.user_id,
"message_preview": req.message[:120]
})
# 1) Red-flag check (async)
matches: List[RedFlagMatch] = await self.redflag.check(req.tenant_id, req.message)
reasoning_trace.append({
"step": "redflag_check",
"match_count": len(matches),
"matches": [m.__dict__ for m in matches]
})
if matches:
            # Notify the admin before responding so the alert is guaranteed to be
            # delivered; this could be fired without awaiting if response latency
            # matters more than delivery.
try:
await self.redflag.notify_admin(req.tenant_id, matches, source_payload={"message": req.message, "user_id": req.user_id})
            except Exception:
                # Best-effort: a failed admin notification must not block the policy response.
                pass
decision = AgentDecision(
action="block",
tool="admin",
tool_input={"violations": [m.__dict__ for m in matches]},
reason="redflag_triggered"
)
return AgentResponse(
text="Your request has been blocked due to policy.",
decision=decision,
tool_traces=[{"redflags": [m.__dict__ for m in matches]}],
reasoning_trace=reasoning_trace
)
# 2) Intent classification
intent = await self.intent.classify(req.message)
reasoning_trace.append({
"step": "intent_detection",
"intent": intent
})
# 2.5) Pre-fetch RAG results if available (for tool selector context)
rag_prefetch = None
rag_results = []
try:
# Try to pre-fetch RAG to help tool selector make better decisions
rag_prefetch = await self.mcp.call_rag(req.tenant_id, req.message)
if isinstance(rag_prefetch, dict):
rag_results = rag_prefetch.get("results") or rag_prefetch.get("hits") or []
reasoning_trace.append({
"step": "rag_prefetch",
"status": "ok",
"hit_count": len(rag_results)
})
except Exception as pref_err:
# If RAG fails, continue without it
reasoning_trace.append({
"step": "rag_prefetch",
"status": "error",
"error": str(pref_err)
})
rag_prefetch = None
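        # Score candidate tools from the message, intent, and RAG pre-fetch;
        # the scores are passed to the selector as additional context.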
tool_scores = self.tool_scorer.score(req.message, intent, rag_results)
reasoning_trace.append({
"step": "tool_scoring",
"scores": tool_scores
})
# 3) Tool selection (hybrid) - pass RAG results in context
ctx = {
"tenant_id": req.tenant_id,
"rag_results": rag_results,
"tool_scores": tool_scores
}
decision = await self.selector.select(intent, req.message, ctx)
reasoning_trace.append({
"step": "tool_selection",
"decision": decision.dict(),
"context_scores": tool_scores
})
tool_traces: List[Dict[str, Any]] = []
# 4) Handle multi-step tool execution
if decision.action == "multi_step" and decision.tool_input:
steps = decision.tool_input.get("steps", [])
if steps:
return await self._execute_multi_step(
req,
steps,
decision,
tool_traces,
reasoning_trace,
rag_prefetch
)
# 5) Execute single tool
if decision.action == "call_tool" and decision.tool:
try:
if decision.tool == "rag":
                    query = (decision.tool_input or {}).get("query") or req.message  # fall back to the raw message if the selector gave no query
                    rag_resp = await self.mcp.call_rag(req.tenant_id, query)
tool_traces.append({"tool": "rag", "response": rag_resp})
reasoning_trace.append({
"step": "tool_execution",
"tool": "rag",
"hit_count": len(self._extract_hits(rag_resp)),
"summary": self._summarize_hits(rag_resp, limit=2)
})
prompt = self._build_prompt_with_rag(req, rag_resp)
llm_out = await self.llm.simple_call(prompt, temperature=req.temperature)
reasoning_trace.append({
"step": "llm_response",
"mode": "rag_synthesis"
})
return AgentResponse(text=llm_out, decision=decision, tool_traces=tool_traces, reasoning_trace=reasoning_trace)
if decision.tool == "web":
                    query = (decision.tool_input or {}).get("query") or req.message
                    web_resp = await self.mcp.call_web(req.tenant_id, query)
tool_traces.append({"tool": "web", "response": web_resp})
reasoning_trace.append({
"step": "tool_execution",
"tool": "web",
"hit_count": len(self._extract_hits(web_resp)),
"summary": self._summarize_hits(web_resp, limit=2)
})
prompt = self._build_prompt_with_web(req, web_resp)
llm_out = await self.llm.simple_call(prompt, temperature=req.temperature)
reasoning_trace.append({
"step": "llm_response",
"mode": "web_synthesis"
})
return AgentResponse(text=llm_out, decision=decision, tool_traces=tool_traces, reasoning_trace=reasoning_trace)
if decision.tool == "admin":
                    query = (decision.tool_input or {}).get("query") or req.message
                    admin_resp = await self.mcp.call_admin(req.tenant_id, query)
tool_traces.append({"tool": "admin", "response": admin_resp})
reasoning_trace.append({
"step": "tool_execution",
"tool": "admin",
"status": "completed"
})
return AgentResponse(text=json.dumps(admin_resp), decision=decision, tool_traces=tool_traces, reasoning_trace=reasoning_trace)
if decision.tool == "llm":
llm_out = await self.llm.simple_call(req.message, temperature=req.temperature)
reasoning_trace.append({
"step": "llm_response",
"mode": "direct"
})
return AgentResponse(text=llm_out, decision=decision, reasoning_trace=reasoning_trace)
except Exception as e:
tool_traces.append({"tool": decision.tool, "error": str(e)})
try:
fallback = await self.llm.simple_call(req.message, temperature=req.temperature)
except Exception as llm_error:
error_msg = str(llm_error)
if "Cannot connect" in error_msg or "Ollama" in error_msg:
                        fallback = (
                            f"I encountered an error while processing your request: {str(e)}\n\n"
                            + self._ollama_help(error_msg)
                        )
else:
fallback = f"I encountered an error while processing your request: {str(e)}. Additionally, the AI service is unavailable: {error_msg}"
return AgentResponse(
text=fallback,
decision=AgentDecision(action="respond", tool=None, tool_input=None, reason=f"tool_error_fallback: {e}"),
tool_traces=tool_traces,
reasoning_trace=reasoning_trace + [{
"step": "error",
"tool": decision.tool,
"error": str(e)
}]
)
# Default: direct LLM response
try:
llm_out = await self.llm.simple_call(req.message, temperature=req.temperature)
except Exception as e:
# If LLM fails, return a helpful error message
error_msg = str(e)
if "Cannot connect" in error_msg or "Ollama" in error_msg:
                llm_out = self._ollama_help(error_msg)
else:
llm_out = f"I apologize, but I'm unable to process your request right now. The AI service is unavailable: {error_msg}"
reasoning_trace.append({
"step": "error",
"tool": "llm",
"error": str(e)
})
return AgentResponse(
text=llm_out,
decision=AgentDecision(action="respond", tool=None, tool_input=None, reason="default_llm"),
reasoning_trace=reasoning_trace
)
def _build_prompt_with_rag(self, req: AgentRequest, rag_resp: Dict[str, Any]) -> str:
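        """Compose a prompt that grounds the LLM in the top-5 retrieved RAG snippets."""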
snippets = []
if isinstance(rag_resp, dict):
hits = rag_resp.get("results") or rag_resp.get("hits") or []
for h in hits[:5]:
txt = h.get("text") or h.get("content") or str(h)
snippets.append(txt)
        snippet_text = "\n---\n".join(snippets)
prompt = (
f"You are an assistant helping tenant {req.tenant_id}. Use the following retrieved documents to answer the user's question.\n"
f"Documents:\n{snippet_text}\n\n"
f"User question: {req.message}\nProvide a concise, accurate answer and cite the source snippets where appropriate."
)
return prompt
async def _execute_multi_step(self, req: AgentRequest, steps: List[Dict[str, Any]],
decision: AgentDecision, tool_traces: List[Dict[str, Any]],
reasoning_trace: List[Dict[str, Any]],
pre_fetched_rag: Optional[Dict[str, Any]] = None) -> AgentResponse:
"""
Execute multiple tools in sequence and synthesize results with LLM.
"""
        # Raw per-tool responses are captured alongside collected_data, which
        # holds the flattened snippets that feed the synthesis prompt.
        rag_data = None
        web_data = None
        admin_data = None
        collected_data = []
parallel_tasks = {}
rag_parallel_query = self._first_query_for_tool(steps, "rag", req.message)
web_parallel_query = self._first_query_for_tool(steps, "web", req.message)
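        # If the RAG and web steps share one query, start both MCP calls now
        # and await them inside the step loop below (skipping RAG when a
        # pre-fetched result can be reused).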
if rag_parallel_query and web_parallel_query and rag_parallel_query == web_parallel_query:
if not pre_fetched_rag:
parallel_tasks["rag"] = asyncio.create_task(self.mcp.call_rag(req.tenant_id, rag_parallel_query))
parallel_tasks["web"] = asyncio.create_task(self.mcp.call_web(req.tenant_id, web_parallel_query))
# Execute each step in sequence
for step_info in steps:
tool_name = step_info.get("tool")
step_input = step_info.get("input") or {}
query = step_input.get("query") or req.message
try:
if tool_name == "rag":
# Reuse pre-fetched RAG if available, otherwise fetch
if pre_fetched_rag and query == rag_parallel_query:
rag_resp = pre_fetched_rag
tool_traces.append({"tool": "rag", "response": rag_resp, "note": "used_pre_fetched"})
elif parallel_tasks.get("rag") and query == rag_parallel_query:
rag_resp = await parallel_tasks["rag"]
tool_traces.append({"tool": "rag", "response": rag_resp, "note": "parallel"})
else:
rag_resp = await self.mcp.call_rag(req.tenant_id, query)
tool_traces.append({"tool": "rag", "response": rag_resp})
rag_data = rag_resp
reasoning_trace.append({
"step": "tool_execution",
"tool": "rag",
"hit_count": len(self._extract_hits(rag_resp)),
"summary": self._summarize_hits(rag_resp, limit=2)
})
# Extract snippets for prompt
if isinstance(rag_resp, dict):
hits = rag_resp.get("results") or rag_resp.get("hits") or []
for h in hits[:5]:
txt = h.get("text") or h.get("content") or str(h)
collected_data.append(f"[RAG] {txt}")
elif tool_name == "web":
if parallel_tasks.get("web") and query == web_parallel_query:
web_resp = await parallel_tasks["web"]
tool_traces.append({"tool": "web", "response": web_resp, "note": "parallel"})
else:
web_resp = await self.mcp.call_web(req.tenant_id, query)
tool_traces.append({"tool": "web", "response": web_resp})
web_data = web_resp
reasoning_trace.append({
"step": "tool_execution",
"tool": "web",
"hit_count": len(self._extract_hits(web_resp)),
"summary": self._summarize_hits(web_resp, limit=2)
})
# Extract snippets for prompt
if isinstance(web_resp, dict):
hits = web_resp.get("results") or web_resp.get("items") or []
for h in hits[:5]:
title = h.get("title") or h.get("headline") or ""
snippet = h.get("snippet") or h.get("summary") or h.get("text") or ""
url = h.get("url") or h.get("link") or ""
collected_data.append(f"[WEB] {title}\n{snippet}\nSource: {url}")
elif tool_name == "admin":
admin_resp = await self.mcp.call_admin(req.tenant_id, query)
tool_traces.append({"tool": "admin", "response": admin_resp})
admin_data = admin_resp
collected_data.append(f"[ADMIN] {json.dumps(admin_resp)}")
reasoning_trace.append({
"step": "tool_execution",
"tool": "admin",
"status": "completed"
})
elif tool_name == "llm":
# LLM is always last - synthesize all collected data
break
except Exception as e:
tool_traces.append({"tool": tool_name, "error": str(e)})
# Continue with other tools even if one fails
reasoning_trace.append({
"step": "error",
"tool": tool_name,
"error": str(e)
})
        # Cancel any parallel task that no step consumed so it is not left
        # pending when this coroutine returns.
        for task in parallel_tasks.values():
            if not task.done():
                task.cancel()
        # Build a comprehensive prompt with all collected data
        data_section = "\n---\n".join(collected_data) if collected_data else ""
if data_section:
prompt = (
f"You are an assistant helping tenant {req.tenant_id}.\n\n"
f"The following information has been gathered from multiple sources:\n\n"
f"{data_section}\n\n"
f"User question: {req.message}\n\n"
f"Provide a comprehensive, accurate answer using the information above. "
f"Cite sources where appropriate (RAG for internal docs, WEB for online sources)."
)
else:
# No data collected, just answer the question
prompt = req.message
# Final LLM synthesis
try:
llm_out = await self.llm.simple_call(prompt, temperature=req.temperature)
return AgentResponse(
text=llm_out,
decision=decision,
tool_traces=tool_traces,
reasoning_trace=reasoning_trace + [{
"step": "llm_response",
"mode": "multi_step"
}]
)
except Exception as e:
tool_traces.append({"tool": "llm", "error": str(e)})
error_msg = str(e)
# Provide helpful error message
if "Cannot connect" in error_msg or "Ollama" in error_msg:
            fallback = self._ollama_help(error_msg)
else:
fallback = f"I encountered an error while synthesizing the response: {error_msg}"
return AgentResponse(
text=fallback,
decision=AgentDecision(
action="respond",
tool=None,
tool_input=None,
reason=f"multi_step_llm_error: {e}"
),
tool_traces=tool_traces,
reasoning_trace=reasoning_trace + [{
"step": "error",
"tool": "llm",
"error": str(e)
}]
)
def _build_prompt_with_web(self, req: AgentRequest, web_resp: Dict[str, Any]) -> str:
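        """Compose a prompt that grounds the LLM in the top-5 web search results."""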
snippets = []
if isinstance(web_resp, dict):
hits = web_resp.get("results") or web_resp.get("items") or []
for h in hits[:5]:
title = h.get("title") or h.get("headline") or ""
snippet = h.get("snippet") or h.get("summary") or h.get("text") or ""
url = h.get("url") or h.get("link") or ""
snippets.append(f"{title}\n{snippet}\nSource: {url}")
        snippet_text = "\n---\n".join(snippets)
prompt = (
f"You are an assistant with access to recent web search results. Use the following results to answer.\n{snippet_text}\n\n"
f"User question: {req.message}\nAnswer succinctly and indicate which results you used."
)
return prompt
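    @staticmethod
    def _ollama_help(error_msg: str) -> str:
        """Troubleshooting hint returned when the local Ollama backend is unreachable."""
        return (
            f"I couldn't connect to the AI service (Ollama). Error: {error_msg}\n\n"
            f"To fix this:\n"
            f"1. Install Ollama from https://ollama.ai\n"
            f"2. Start Ollama: `ollama serve`\n"
            f"3. Pull the model: `ollama pull {os.getenv('OLLAMA_MODEL', 'llama3.1:latest')}`\n"
            f"4. Or set OLLAMA_URL and OLLAMA_MODEL in your .env file"
        )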
@staticmethod
def _extract_hits(resp: Optional[Dict[str, Any]]) -> List[Dict[str, Any]]:
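        """Normalize the hit list across the response shapes tools return ('results', 'hits', or 'items')."""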
if not isinstance(resp, dict):
return []
return resp.get("results") or resp.get("hits") or resp.get("items") or []
def _summarize_hits(self, resp: Optional[Dict[str, Any]], limit: int = 3) -> List[str]:
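        """Return up to `limit` short (160-character) snippets for the reasoning trace."""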
hits = self._extract_hits(resp)
summaries = []
for hit in hits[:limit]:
if isinstance(hit, dict):
snippet = hit.get("text") or hit.get("content") or hit.get("snippet") or ""
else:
snippet = str(hit)
summaries.append(snippet[:160])
return summaries
@staticmethod
def _first_query_for_tool(steps: List[Dict[str, Any]], tool_name: str, default_query: str) -> Optional[str]:
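        """Return the query of the first step using `tool_name`, or None if no such step exists."""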
for step in steps:
if step.get("tool") == tool_name:
input_data = step.get("input") or {}
return input_data.get("query") or default_query
return None