# IntegraChat / backend/api/services/tool_selector.py
# Author: nothingworry
# Commit ddc5c21: feat: add caching, query expansion, improved streaming, and enhanced error handling
import json
import logging
import re
from dataclasses import dataclass, field
from typing import Dict, Any, Optional, List

from .tool_metadata import (
    get_tool_latency_estimate,
    estimate_path_latency,
    get_fastest_path,
    validate_tool_output
)
@dataclass
class ToolSelector:
    """Plan which tools should answer a user message.

    ``select`` combines keyword heuristics, the fitness scores found in
    ``ctx["tool_scores"]``, hints derived from earlier tool outputs (see
    ``_analyze_context``) and, when ``llm_client`` is set, an LLM-generated
    plan. It returns whatever ``_multi_step(steps, reason)`` builds, where
    ``steps`` is an ordered list of ``{"tool": ..., "input": ...}`` dicts
    and/or ``{"parallel": {"rag": query, "web": query}}`` dicts.
    """

    # Optional planner client. ``select`` awaits ``simple_call(prompt)`` on it
    # and expects the JSON plan back as a string; ``None`` disables LLM planning.
    llm_client: Optional[Any] = None

    # Class constants below carry no annotation, so @dataclass ignores them.
    # A tool whose fitness score reaches this value is selected on score alone.
    _FITNESS_THRESHOLD = 0.55
    # Top pre-fetched RAG similarity at/above which web lookups are skipped.
    _RAG_HIGH_SIMILARITY = 0.8
    # Top similarity the most recent remembered RAG result needs to skip RAG.
    _MEMORY_HIGH_SIMILARITY = 0.75
    # Admin-violation severities, lowest to highest; unknown values rank lowest.
    _SEVERITY_ORDER = ("low", "medium", "high", "critical")
    # Tools an LLM-generated plan may route to (the llm step is added separately).
    _PLANNABLE_TOOLS = ("rag", "web")

    # Freshness cues, anchored at a word START: "currently"/"recently" still
    # count, but "know", "knowledge", "snow" or "concurrent" no longer force the
    # news path (plain substring matching let "now" hit them).
    _FRESHNESS_RE = re.compile(
        r"\b(?:latest|today|news|current|recent|now|updates|breaking|"
        r"trending|happening|what's new|what is new|what happened)"
    )
    _NEWS_PATTERNS = (
        r"latest news", r"current news", r"today's news", r"breaking news",
        r"news about", r"news on", r"news of", r"what's happening",
        r"what happened", r"recent news", r"news update",
    )
    # Cues that the answer lives in internal / company documents.
    _RAG_PATTERNS = (
        r"company", r"internal", r"documentation", r"our ", r"your ",
        r"knowledge base", r"private", r"internal docs", r"corporate",
        r"admin", r"administrator", r"who is", r"what is",
    )
    # Cues for a public fact / definition lookup.
    _FACT_PATTERNS = (
        r"what is ", r"who is ", r"where is ",
        r"tell me about ", r"define ", r"explain ",
        r"history of ", r"information about", r"details about",
    )
    # Cues that the question needs several sources at once.
    _COMPLEX_PATTERNS = (
        r"compare", r"difference between", r"versus", r"vs",
        r"both", r"and also", r"as well as", r"in addition",
    )
    # Planner prompt; filled with str.format, hence the doubled JSON braces.
    _PLAN_PROMPT_TEMPLATE = """
You are an enterprise MCP agent.
You can select MULTIPLE tools in sequence OR in parallel to provide comprehensive answers.
TOOLS:
- rag → private knowledge retrieval (use for internal/company docs)
- web → online factual lookup (use for public facts, current info)
- llm → final reasoning and synthesis (always include at end)
Current context:
- RAG available: {rag_has_data}
- User message: {message}
- Tool scores: {tool_scores}
Determine which tools are needed. You can select:
- Just LLM (simple questions)
- RAG + LLM (internal knowledge questions)
- Web + LLM (public fact questions)
- RAG + Web + LLM (comprehensive questions needing both sources)
IMPORTANT: If the query needs BOTH internal docs (RAG) AND current/live info (Web),
you can mark them for parallel execution by using a "parallel" step.
Return a JSON list describing the steps. For parallel execution, use:
[
{{
"parallel": {{
"rag": "query for internal docs",
"web": "query for live info"
}},
"reason": "Need both internal and live information simultaneously"
}},
{{"tool": "llm", "reason": "Synthesize all information"}}
]
For sequential execution, use:
[
{{"tool": "rag", "reason": "Need internal documentation"}},
{{"tool": "web", "reason": "Need current public information"}},
{{"tool": "llm", "reason": "Synthesize all information"}}
]
Only return the JSON array. Do not include markdown formatting.
"""

    async def select(self, intent: str, text: str, ctx: Dict[str, Any]):
        """Build a multi-step tool plan for ``text``.

        Args:
            intent: Upstream intent label; ``"admin"`` short-circuits to an
                admin (+ llm) plan.
            text: Raw user message.
            ctx: Routing context. Keys read here: ``tool_scores``
                (``rag_fitness`` / ``web_fitness`` floats), ``rag_results``
                (pre-fetched hits carrying a ``similarity``), ``memory``
                (recent tool outputs) and ``admin_violations``.

        Returns:
            The decision produced by ``_multi_step(steps, reason)``.
        """
        msg = text.lower().strip()
        # ``or`` also covers keys that are present but set to None.
        tool_scores = ctx.get("tool_scores") or {}
        rag_score = tool_scores.get("rag_fitness", 0.0)
        web_score = tool_scores.get("web_fitness", 0.0)
        rag_results = ctx.get("rag_results") or []
        memory = ctx.get("memory") or []  # recent tool outputs from conversation memory
        admin_violations = ctx.get("admin_violations") or []

        context_hints = self._analyze_context(rag_results, memory, admin_violations, tool_scores)
        # Bound before any branching: the news path, the planner prompt and the
        # final llm step all read it.
        rag_has_data = len(rag_results) > 0

        # 1. Admin rules take precedence over everything else.
        if intent == "admin":
            if context_hints.get("skip_agent_reasoning"):
                # Severe violation: block immediately without llm reasoning.
                return _multi_step(
                    [step("admin", {"query": text})],
                    "admin critical violation → immediate block (latency: ~10ms)",
                )
            total_latency = (
                get_tool_latency_estimate("admin", {"query_length": len(text)})
                + get_tool_latency_estimate("llm", {"query_length": len(text)})
            )
            return _multi_step(
                [step("admin", {"query": text}), step("llm", {"query": text})],
                f"admin safety rule triggered → llm (est. latency: {total_latency}ms)",
            )

        steps: List[Dict[str, Any]] = []
        needs_rag = False
        needs_web = False

        # 2. News / current-events queries go straight to web and never use RAG.
        is_news_query = self._is_news_query(msg)
        if is_news_query:
            needs_web = True
            steps.append(self._web_step(self._expand_news_query(text)))
        else:
            # 3. RAG, unless recent high-quality RAG output is already in memory.
            memory_covers_rag = (
                context_hints.get("has_relevant_memory")
                and context_hints.get("skip_rag_if_memory")
            )
            if not memory_covers_rag and (
                rag_has_data
                or rag_score >= self._FITNESS_THRESHOLD
                or self._matches_any(self._RAG_PATTERNS, msg)
            ):
                needs_rag = True
                steps.append(self._rag_step(text))

            # 4. Fact lookups → web, unless pre-fetched RAG is already strong
            # (_analyze_context sets this hint from the top RAG similarity).
            if not context_hints.get("skip_web_if_rag_high") and (
                web_score >= self._FITNESS_THRESHOLD
                or self._matches_any(self._FACT_PATTERNS, msg)
            ):
                needs_web = True
                steps.append(self._web_step(text))

        # 5. Complex queries that need multiple sources.
        needs_multiple = self._matches_any(self._COMPLEX_PATTERNS, msg)
        should_parallel = needs_rag and needs_web and (
            needs_multiple
            or (rag_score >= self._FITNESS_THRESHOLD and web_score >= self._FITNESS_THRESHOLD)
        )

        # 6. Let the LLM refine multi-source plans (or plan from scratch).
        if self.llm_client and (needs_multiple or (needs_rag and needs_web) or not steps):
            try:
                steps = await self._plan_with_llm(text, rag_has_data, tool_scores)
            except Exception:
                # Best effort: keep the heuristic plan. Stage 7 still merges
                # RAG + Web into a parallel step when should_parallel holds.
                logging.getLogger(__name__).warning(
                    "LLM tool planning failed; falling back to heuristic plan",
                    exc_info=True,
                )

        # 7. Run RAG and Web concurrently when both are needed.
        if should_parallel and not any("parallel" in s for s in steps):
            steps = self._merge_into_parallel(steps, text)

        # 8. End with llm synthesis.
        # NOTE(review): a trailing parallel step gets no llm step appended here;
        # presumably the executor synthesizes after parallel branches -- confirm,
        # since the planner prompt asks for a final llm step.
        if not steps or (
            isinstance(steps[-1], dict)
            and steps[-1].get("tool") != "llm"
            and "parallel" not in steps[-1]
        ):
            steps.append(step("llm", {
                "rag_data": rag_results if rag_has_data else None,
                "query": text,
            }))

        if len(steps) > 1:
            steps = self._order_by_latency(steps)

        return _multi_step(steps, self._build_reason(steps, tool_scores, context_hints))

    @staticmethod
    def _matches_any(patterns, msg: str) -> bool:
        """Return True if any regex in ``patterns`` matches ``msg``."""
        return any(re.search(pattern, msg) for pattern in patterns)

    @staticmethod
    def _top_similarity(results) -> float:
        """Highest ``similarity`` among the dict entries of ``results`` (0 if none).

        Missing or ``None`` similarities count as 0, and non-dict entries are
        skipped instead of raising.
        """
        return max(
            ((r.get("similarity") or 0) for r in results if isinstance(r, dict)),
            default=0,
        )

    @classmethod
    def _is_news_query(cls, msg: str) -> bool:
        """True when the lower-cased message asks for news / current events."""
        return bool(cls._FRESHNESS_RE.search(msg)) or cls._matches_any(cls._NEWS_PATTERNS, msg)

    @staticmethod
    def _expand_news_query(text: str) -> str:
        """Make a short (<= 4 word) news query more specific for web search.

        "Latest News AI" -> "latest news about AI"; "today weather" ->
        "today weather news latest". Longer queries, and queries that already
        contain the word "about" or "on", are returned unchanged.
        """
        msg = text.lower()
        if len(text.split()) > 4:
            return text
        if "news" not in msg:
            return f"{text} news latest"
        if re.search(r"\b(?:about|on)\b", msg):
            return text
        # Whole-word, case-insensitive removal so "Latest News AI" -> "AI".
        subject = " ".join(
            re.sub(r"\b(?:latest|news)\b", " ", text, flags=re.IGNORECASE).split()
        )
        return f"latest news about {subject}" if subject else text

    @staticmethod
    def _rag_step(query: str) -> Dict[str, Any]:
        """RAG step annotated with its estimated latency."""
        latency = get_tool_latency_estimate("rag", {"query_length": len(query)})
        return step("rag", {"query": query, "_estimated_latency_ms": latency})

    @staticmethod
    def _web_step(query: str) -> Dict[str, Any]:
        """Web step annotated with its estimated latency."""
        latency = get_tool_latency_estimate("web", {
            "query_length": len(query),
            "query_complexity": "high" if len(query.split()) > 10 else "medium",
        })
        return step("web", {"query": query, "_estimated_latency_ms": latency})

    async def _plan_with_llm(self, text: str, rag_has_data: bool, tool_scores) -> List[Dict[str, Any]]:
        """Ask ``llm_client`` for a plan; raises on any malformed response."""
        prompt = self._PLAN_PROMPT_TEMPLATE.format(
            rag_has_data=rag_has_data,
            # JSON-quoted so embedded quotes/newlines stay inside the field.
            message=json.dumps(text, ensure_ascii=False),
            tool_scores=json.dumps(tool_scores),
        )
        raw = await self.llm_client.simple_call(prompt)
        return self._parse_llm_plan(raw, text)

    @classmethod
    def _parse_llm_plan(cls, raw: str, text: str) -> List[Dict[str, Any]]:
        """Convert the planner's JSON reply into plan steps (llm steps dropped).

        A ``parallel`` entry (the last one wins) is placed first, followed by
        the sequential rag/web steps in reply order. Raises ValueError,
        KeyError or json.JSONDecodeError on anything unexpected, so the caller
        falls back to the heuristic plan instead of routing to unknown tools.
        """
        out = raw.strip()
        # Strip markdown fences the LLM may add despite instructions.
        if out.startswith("```json"):
            out = out[7:]
        if out.startswith("```"):
            out = out[3:]
        if out.endswith("```"):
            out = out[:-3]
        plan = json.loads(out.strip())
        if not isinstance(plan, list) or not all(isinstance(entry, dict) for entry in plan):
            raise ValueError("LLM plan must be a JSON array of objects")

        parallel_step = None
        sequential = []
        for entry in plan:
            if "parallel" in entry:
                branches = entry["parallel"]
                if (
                    not isinstance(branches, dict)
                    or not branches
                    or not set(branches) <= set(cls._PLANNABLE_TOOLS)
                    or not all(isinstance(q, str) for q in branches.values())
                ):
                    raise ValueError(f"invalid parallel step in LLM plan: {branches!r}")
                parallel_step = {"parallel": branches}
                continue
            tool = entry["tool"]
            if tool == "llm":
                continue  # select() appends the synthesis step itself
            if tool not in cls._PLANNABLE_TOOLS:
                raise ValueError(f"LLM plan requested unknown tool {tool!r}")
            sequential.append(step(tool, {"query": text}))
        return ([parallel_step] if parallel_step else []) + sequential

    @staticmethod
    def _merge_into_parallel(steps: List[Dict[str, Any]], text: str) -> List[Dict[str, Any]]:
        """Replace rag/web steps with one parallel step, keeping their queries.

        The last rag/web step's query wins; ``text`` is used when a branch has
        no step. Other steps follow the parallel step in their original order.
        """
        queries = {"rag": text, "web": text}
        for s in steps:
            if s.get("tool") in queries:
                queries[s["tool"]] = s.get("input", {}).get("query", text)
        merged: List[Dict[str, Any]] = [{"parallel": queries}]
        merged.extend(s for s in steps if s.get("tool") not in ("rag", "web"))
        return merged

    @staticmethod
    def _order_by_latency(steps: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Stable-sort non-llm steps by estimated latency; the llm step goes last.

        Steps without ``_estimated_latency_ms`` (including parallel steps)
        sort as 1000 ms. Only the last llm step is kept.
        """
        llm_step = None
        others = []
        for s in steps:
            if isinstance(s, dict) and s.get("tool") == "llm":
                llm_step = s
            else:
                others.append(s)
        others.sort(key=lambda s: s.get("input", {}).get("_estimated_latency_ms", 1000))
        return others + ([llm_step] if llm_step else [])

    @staticmethod
    def _build_reason(steps, tool_scores, context_hints) -> str:
        """Summarize the plan: tool chain, estimated latency, scores and hints."""
        tool_names = []
        total_latency = 0
        for s in steps:
            if "parallel" in s:
                tool_names.append("parallel(RAG+Web)")
                # Branches overlap in time, so only the slower one counts.
                total_latency += max(
                    get_tool_latency_estimate("rag"),
                    get_tool_latency_estimate("web"),
                )
            elif isinstance(s, dict) and "tool" in s:
                tool_name = s["tool"]
                tool_names.append(tool_name)
                est_latency = s.get("input", {}).get("_estimated_latency_ms")
                total_latency += est_latency if est_latency else get_tool_latency_estimate(tool_name)

        context_info = []
        if context_hints.get("skip_web_if_rag_high"):
            context_info.append("RAG high score → skip web")
        if context_hints.get("skip_rag_if_memory"):
            context_info.append("memory available → skip RAG")
        if context_hints.get("skip_agent_reasoning"):
            context_info.append("critical violation → skip reasoning")
        context_str = f" | context: {', '.join(context_info)}" if context_info else ""
        chain = " → ".join(tool_names)
        return f"multi-tool plan: {chain} | est. latency: {total_latency}ms | scores={tool_scores}{context_str}"

    @staticmethod
    def _rag_results_of(entry) -> Any:
        """The ``result["results"]`` of a rag memory entry, else ``[]``.

        Non-dict entries, non-rag tools and missing/non-dict results all
        yield ``[]`` instead of raising.
        """
        if not isinstance(entry, dict) or entry.get("tool") != "rag":
            return []
        result = entry.get("result")
        if not isinstance(result, dict):
            return []
        return result.get("results") or []

    def _analyze_context(
        self,
        rag_results: List[Dict],
        memory: List[Dict],
        admin_violations: List[Dict],
        tool_scores: Dict[str, float]
    ) -> Dict[str, Any]:
        """Derive routing hints from earlier tool outputs.

        Args:
            rag_results: Pre-fetched RAG hits, each possibly carrying ``similarity``.
            memory: Recent tool outputs (``{"tool": ..., "result": {...}}``).
            admin_violations: Detected violations, each possibly with ``severity``.
            tool_scores: Accepted for interface compatibility; currently unused.

        Returns:
            A dict with any of these flags set to True:
            ``skip_web_if_rag_high`` / ``rag_high_confidence`` (top RAG similarity
            >= 0.8), ``has_relevant_memory`` (a rag result with hits in the last
            5 memory entries), ``skip_rag_if_memory`` (additionally, the latest
            entry is rag with top similarity >= 0.75), and
            ``skip_agent_reasoning`` / ``critical_violation`` (highest severity
            is "high" or "critical").
        """
        hints: Dict[str, Any] = {}

        if rag_results and self._top_similarity(rag_results) >= self._RAG_HIGH_SIMILARITY:
            hints["skip_web_if_rag_high"] = True
            hints["rag_high_confidence"] = True

        if memory and any(self._rag_results_of(m) for m in memory[-5:]):
            hints["has_relevant_memory"] = True
            # Only skip RAG when the most recent entry is a strong RAG result.
            latest_results = self._rag_results_of(memory[-1])
            if latest_results and self._top_similarity(latest_results) >= self._MEMORY_HIGH_SIMILARITY:
                hints["skip_rag_if_memory"] = True

        if admin_violations:
            order = self._SEVERITY_ORDER
            severities = [
                v.get("severity", "low") for v in admin_violations if isinstance(v, dict)
            ]
            max_severity = max(
                severities,
                key=lambda s: order.index(s) if s in order else 0,
                default="low",
            )
            if max_severity in ("high", "critical"):
                hints["skip_agent_reasoning"] = True
                hints["critical_violation"] = True

        return hints
def step(tool, input_data):
    """Build one plan step: a dict with ``tool`` then ``input`` keys.

    ``input_data`` is stored as-is (not copied).
    """
    return dict(tool=tool, input=input_data)
def _multi_step(steps, reason):
    """Wrap an ordered list of plan steps in a ``multi_step`` AgentDecision.

    The steps are passed through as ``tool_input["steps"]`` and ``tool`` is
    left as None. The model import happens at call time -- presumably to
    avoid an import cycle with the models package (NOTE(review): confirm).
    """
    from ..models.agent import AgentDecision

    decision_kwargs = {
        "action": "multi_step",
        "tool": None,
        "tool_input": {"steps": steps},
        "reason": reason,
    }
    return AgentDecision(**decision_kwargs)