# IntegraChat — backend/api/services/tool_selector.py
# Reasoning traces, smarter tools, deterministic backend tests. (rev ef83e66)
import json
import re
from dataclasses import dataclass, field
from typing import Any
@dataclass
class ToolSelector:
    """Plans an ordered multi-tool pipeline (rag / web / llm) for a user query.

    Selection is heuristic first: pre-fetched RAG hits, fitness scores carried
    in ``ctx``, and keyword/regex cues in the message. For complex queries an
    optional LLM client may replan the steps. Every plan is terminated by an
    "llm" synthesis step.
    """

    # Optional async client exposing ``simple_call(prompt) -> str``; when None,
    # planning is purely heuristic. (Was annotated with the builtin ``any``.)
    llm_client: Any = None

    async def select(self, intent: str, text: str, ctx):
        """Build a multi-step AgentDecision for *text*.

        Args:
            intent: Pre-classified intent label; "admin" bypasses normal planning.
            text: Raw user message.
            ctx: Context mapping. Reads "tool_scores" (dict of fitness floats)
                and "rag_results" (pre-fetched retrieval hits), both optional.

        Returns:
            AgentDecision with ``action="multi_step"`` and the ordered steps.
        """
        msg = text.lower().strip()
        tool_scores = ctx.get("tool_scores", {})
        rag_score = tool_scores.get("rag_fitness", 0.0)
        web_score = tool_scores.get("web_fitness", 0.0)

        # ---------------------------------
        # 1. Admin safety rule wins over everything else.
        # ---------------------------------
        if intent == "admin":
            return _multi_step([
                step("admin", {"query": text}),
                step("llm", {"query": text}),
            ], "admin safety rule triggered → llm")

        steps = []
        needs_rag = False
        needs_web = False

        # ---------------------------------
        # 2. RAG: fires on pre-fetched hits, a strong fitness score, or
        #    internal-knowledge wording. "our" is word-bounded so that
        #    "hour "/"four " no longer trigger it ("your " stays listed).
        # ---------------------------------
        rag_results = ctx.get("rag_results", [])
        rag_has_data = len(rag_results) > 0
        rag_patterns = [
            r"company", r"internal", r"documentation", r"\bour ", r"your ",
            r"knowledge base", r"private", r"internal docs", r"corporate",
        ]
        if rag_has_data or rag_score >= 0.55 or any(re.search(p, msg) for p in rag_patterns):
            needs_rag = True
            if not any(s["tool"] == "rag" for s in steps):
                steps.append(step("rag", {"query": text}))

        # ---------------------------------
        # 3. Fact lookup / definition → web.
        # ---------------------------------
        fact_patterns = [
            r"what is ", r"who is ", r"where is ",
            r"tell me about ", r"define ", r"explain ",
            r"history of ", r"information about", r"details about",
        ]
        if web_score >= 0.55 or any(re.search(p, msg) for p in fact_patterns):
            needs_web = True
            steps.append(step("web", {"query": text}))

        # ---------------------------------
        # 4. Freshness wording → web. Word-bounded so substrings do not fire
        #    ("know" no longer matches "now", "renews" no longer "news").
        # ---------------------------------
        freshness_keywords = ["latest", "today", "news", "current", "recent",
                              "now", "updates", "breaking", "trending"]
        if any(re.search(rf"\b{re.escape(k)}\b", msg) for k in freshness_keywords):
            needs_web = True
            # Avoid duplicate web steps (step 3 may already have added one).
            if not any(s["tool"] == "web" for s in steps):
                steps.append(step("web", {"query": text}))

        # ---------------------------------
        # 5. Complex queries that likely need multiple sources.
        #    "vs" is word-bounded so "canvas"/"obvious" cannot match it.
        # ---------------------------------
        complex_patterns = [
            r"compare", r"difference between", r"versus", r"\bvs\b",
            r"both", r"and also", r"as well as", r"in addition",
        ]
        needs_multiple = any(re.search(p, msg) for p in complex_patterns)

        # ---------------------------------
        # 6. Let the LLM refine the plan for complex or ambiguous cases.
        # ---------------------------------
        if self.llm_client and (needs_multiple or (needs_rag and needs_web) or len(steps) == 0):
            plan_prompt = f"""
You are an enterprise MCP agent.
You can select MULTIPLE tools in sequence to provide comprehensive answers.
TOOLS:
- rag → private knowledge retrieval (use for internal/company docs)
- web → online factual lookup (use for public facts, current info)
- llm → final reasoning and synthesis (always include at end)
Current context:
- RAG available: {rag_has_data}
- User message: "{text}"
- Tool scores: {json.dumps(tool_scores)}
Determine which tools are needed. You can select:
- Just LLM (simple questions)
- RAG + LLM (internal knowledge questions)
- Web + LLM (public fact questions)
- RAG + Web + LLM (comprehensive questions needing both sources)
Return a JSON list describing the steps, e.g.:
[
{{"tool": "rag", "reason": "Need internal documentation"}},
{{"tool": "web", "reason": "Need current public information"}},
{{"tool": "llm", "reason": "Synthesize all information"}}
]
Only return the JSON array. Do not include markdown formatting.
"""
            try:
                out = await self.llm_client.simple_call(plan_prompt)
                # Strip markdown fences the model sometimes adds despite instructions.
                out = out.strip()
                if out.startswith("```json"):
                    out = out[7:]
                if out.startswith("```"):
                    out = out[3:]
                if out.endswith("```"):
                    out = out[:-3]
                out = out.strip()
                steps_json = json.loads(out)
                # Adopt the LLM plan, dropping any "llm" entries — synthesis is
                # always (re)appended as the final step below.
                steps = [
                    step(s["tool"], {"query": text})
                    for s in steps_json if s.get("tool") != "llm"
                ]
            except Exception:
                # Best-effort: if the call or JSON parsing fails, silently keep
                # whatever heuristic steps were already planned.
                pass

        # ---------------------------------
        # 7. Always end with an LLM synthesis step.
        # ---------------------------------
        if not steps or steps[-1]["tool"] != "llm":
            steps.append(step("llm", {
                "rag_data": rag_results if rag_has_data else None,
                "query": text,
            }))

        # Human-readable trace of the chosen tool sequence.
        tool_names = [s["tool"] for s in steps]
        reason = f"multi-tool plan: {' → '.join(tool_names)} | scores={tool_scores}"
        return _multi_step(steps, reason)
def step(tool, input_data):
    """Return a single plan entry: ``{"tool": <name>, "input": <payload>}``."""
    return dict(tool=tool, input=input_data)
def _multi_step(steps, reason):
    """Wrap planned *steps* in an AgentDecision with ``action="multi_step"``."""
    # NOTE(review): local import — presumably deferred to avoid a circular
    # import at module load time; confirm against the package layout.
    from ..models.agent import AgentDecision

    decision = AgentDecision(
        action="multi_step",
        tool=None,
        tool_input={"steps": steps},
        reason=reason,
    )
    return decision