# IntegraChat: backend/api/services/tool_selector.py
# Last commit 6d531e9 by nothingworry: "Multi-Tool Parallel Execution"
import json
import logging
import re
from dataclasses import dataclass, field
from typing import Any
@dataclass
class ToolSelector:
    """Plan which tools (rag / web / llm / admin) should answer a user message.

    ``select`` combines keyword heuristics, pre-computed tool-fitness scores
    and, when ``llm_client`` is set, an LLM-generated plan. The plan is a list
    of steps, each either sequential (``{"tool": name, "input": {...}}``) or
    parallel (``{"parallel": {"rag": query, "web": query}}``), wrapped in a
    multi-step ``AgentDecision``.
    """

    # Optional planner client. Only ``await llm_client.simple_call(prompt)``
    # is used; its result is expected to be a string.
    llm_client: Any = None

    # --- Class-level constants (unannotated, so NOT dataclass fields) ---

    # A fitness score at or above this value selects the tool on its own.
    _SCORE_THRESHOLD = 0.55

    # Tools the planner LLM may schedule. The final "llm" step is added by
    # select() itself; any other tool name in the LLM's reply is dropped.
    _PLANNABLE_TOOLS = ("rag", "web")

    # Keyword triggers, matched against the lower-cased message. Every
    # alternation starts at a word boundary (\b) so that substrings inside
    # other words do not fire: "now" in "know"/"knowledge", "company" in
    # "accompany", "our " in "hour "/"four ", "vs" in "devs".
    _RAG_RE = re.compile(
        r"\b(?:company|internal|documentation|our |your |knowledge base|private"
        r"|internal docs|corporate|admin|administrator|who is|what is)"
    )
    _FACT_RE = re.compile(
        r"\b(?:what is |who is |where is |tell me about |define |explain "
        r"|history of |information about|details about)"
    )
    _FRESHNESS_RE = re.compile(
        r"\b(?:latest|today|news|current|recent|now|updates|breaking|trending)"
    )
    _COMPLEX_RE = re.compile(
        r"\b(?:compare|difference between|versus|vs|both|and also|as well as"
        r"|in addition)"
    )

    # Leading/trailing markdown code fences (``` or ```json, any case) that
    # the planner LLM sometimes wraps around its JSON despite instructions.
    _FENCE_RE = re.compile(r"^\s*```[A-Za-z]*\s*|\s*```\s*$")

    async def select(self, intent: str, text: str, ctx):
        """Build a multi-step tool plan for *text*.

        Args:
            intent: Upstream intent label. ``"admin"`` short-circuits to an
                admin -> llm plan before any other rule runs.
            text: The user's message; passed verbatim as every step's query.
            ctx: Mapping read with ``.get``: ``"tool_scores"`` (dict with
                optional ``"rag_fitness"`` / ``"web_fitness"`` numbers) and
                ``"rag_results"`` (pre-fetched RAG hits; non-empty means RAG
                already has data). Missing or ``None`` values count as empty.

        Returns:
            The result of ``_multi_step``: an ``AgentDecision`` with
            ``action="multi_step"`` and the plan in ``tool_input["steps"]``.
        """
        msg = text.lower().strip()
        tool_scores = ctx.get("tool_scores") or {}
        rag_score = tool_scores.get("rag_fitness", 0.0)
        web_score = tool_scores.get("web_fitness", 0.0)

        # 1. Admin rules win over every other heuristic.
        if intent == "admin":
            return _multi_step([
                step("admin", {"query": text}),
                step("llm", {"query": text}),
            ], "admin safety rule triggered → llm")

        steps = []

        # 2. Internal knowledge -> RAG: pre-fetched hits, a high RAG score,
        #    or an internal/company keyword.
        rag_results = ctx.get("rag_results") or []
        rag_has_data = len(rag_results) > 0
        needs_rag = (
            rag_has_data
            or rag_score >= self._SCORE_THRESHOLD
            or self._RAG_RE.search(msg) is not None
        )
        if needs_rag:
            steps.append(step("rag", {"query": text}))

        # 3-4. Fact lookup / definition, or a freshness keyword -> at most
        #      one web step.
        needs_web = (
            web_score >= self._SCORE_THRESHOLD
            or self._FACT_RE.search(msg) is not None
            or self._FRESHNESS_RE.search(msg) is not None
        )
        if needs_web:
            steps.append(step("web", {"query": text}))

        # 5. Comparative / multi-part questions benefit from several sources.
        needs_multiple = self._COMPLEX_RE.search(msg) is not None

        # Run RAG and web concurrently when both are needed and either the
        # question is multi-part or both fitness scores are high.
        both_scores_high = (
            rag_score >= self._SCORE_THRESHOLD
            and web_score >= self._SCORE_THRESHOLD
        )
        should_parallel = (
            needs_rag and needs_web and (needs_multiple or both_scores_high)
        )

        # 6. Let the planner LLM refine multi-source, multi-part or empty plans.
        if self.llm_client and (needs_multiple or (needs_rag and needs_web) or not steps):
            planned = await self._plan_with_llm(text, tool_scores, rag_has_data)
            if planned is not None:
                steps = planned
            elif should_parallel:
                # Planning failed: fall back to a parallel RAG+web step.
                steps = [{"parallel": {"rag": text, "web": text}}]
            # Otherwise planning failed and the heuristic steps are kept.

        # 7. Heuristics call for parallel RAG+web but the plan is sequential:
        #    fold its rag/web steps into one parallel step.
        if should_parallel and not any("parallel" in s for s in steps):
            steps = self._to_parallel(steps, text)

        # 8. Append the final LLM synthesis step unless the plan already ends
        #    with an llm step.
        # NOTE(review): a plan whose last step is "parallel" gets no trailing
        # llm step (behaviour kept as before), although the planner prompt
        # says llm is always included at the end. Confirm the executor
        # synthesises after a parallel step.
        last = steps[-1] if steps else None
        if last is None or (last.get("tool") != "llm" and "parallel" not in last):
            steps.append(step("llm", {
                "rag_data": rag_results if rag_has_data else None,
                "query": text,
            }))

        tool_names = [self._step_label(s) for s in steps]
        reason = f"multi-tool plan: {' → '.join(tool_names)} | scores={tool_scores}"
        return _multi_step(steps, reason)

    async def _plan_with_llm(self, text, tool_scores, rag_has_data):
        """Ask ``llm_client`` for a plan.

        Returns the steps produced by ``_parse_llm_plan``, or ``None`` when
        building the prompt, calling the client or parsing the reply fails.
        Failures are logged rather than raised so the caller can fall back to
        its heuristic plan.
        """
        try:
            prompt = self._build_plan_prompt(text, tool_scores, rag_has_data)
            raw = await self.llm_client.simple_call(prompt)
            return self._parse_llm_plan(raw, text)
        except Exception:
            # Best-effort boundary: any client or parse error degrades to the
            # heuristic plan instead of failing the whole request.
            logging.getLogger(__name__).warning(
                "LLM tool planning failed; falling back to heuristic plan",
                exc_info=True,
            )
            return None

    @staticmethod
    def _build_plan_prompt(text, tool_scores, rag_has_data):
        """Render the planner prompt.

        The user message is JSON-quoted so embedded quotes cannot break out
        of its delimiters. Literal JSON braces are doubled for the f-string.
        """
        return f"""
You are an enterprise MCP agent.
You can select MULTIPLE tools in sequence OR in parallel to provide comprehensive answers.
TOOLS:
- rag → private knowledge retrieval (use for internal/company docs)
- web → online factual lookup (use for public facts, current info)
- llm → final reasoning and synthesis (always include at end)
Current context:
- RAG available: {rag_has_data}
- User message: {json.dumps(text, ensure_ascii=False)}
- Tool scores: {json.dumps(tool_scores)}
Determine which tools are needed. You can select:
- Just LLM (simple questions)
- RAG + LLM (internal knowledge questions)
- Web + LLM (public fact questions)
- RAG + Web + LLM (comprehensive questions needing both sources)
IMPORTANT: If the query needs BOTH internal docs (RAG) AND current/live info (Web),
you can mark them for parallel execution by using a "parallel" step.
Return a JSON list describing the steps. For parallel execution, use:
[
{{
"parallel": {{
"rag": "query for internal docs",
"web": "query for live info"
}},
"reason": "Need both internal and live information simultaneously"
}},
{{"tool": "llm", "reason": "Synthesize all information"}}
]
For sequential execution, use:
[
{{"tool": "rag", "reason": "Need internal documentation"}},
{{"tool": "web", "reason": "Need current public information"}},
{{"tool": "llm", "reason": "Synthesize all information"}}
]
Only return the JSON array. Do not include markdown formatting.
"""

    @classmethod
    def _parse_llm_plan(cls, raw, text):
        """Convert the planner's reply into executable steps, minus llm steps.

        Non-dict entries and tools outside ``_PLANNABLE_TOOLS`` are skipped.
        Sequential steps always use *text* as their query. The last valid
        "parallel" entry (if any) is placed first, keeping only its string
        "rag"/"web" queries.

        Raises:
            ValueError: if *raw* is not a string, is not valid JSON
                (``json.JSONDecodeError`` subclasses ``ValueError``) or is not
                a JSON array.
        """
        if not isinstance(raw, str):
            raise ValueError(f"planner returned {type(raw).__name__}, expected str")
        plan = json.loads(cls._FENCE_RE.sub("", raw.strip()))
        if not isinstance(plan, list):
            raise ValueError("planner reply is not a JSON array")

        parallel = None
        sequential = []
        for entry in plan:
            if not isinstance(entry, dict):
                continue
            if "parallel" in entry:
                queries = cls._clean_parallel(entry["parallel"])
                if queries:
                    parallel = {"parallel": queries}
            elif entry.get("tool") in cls._PLANNABLE_TOOLS:
                sequential.append(step(entry["tool"], {"query": text}))
        return ([parallel] if parallel else []) + sequential

    @classmethod
    def _clean_parallel(cls, value):
        """Keep only non-empty string "rag"/"web" queries of a parallel payload."""
        if not isinstance(value, dict):
            return {}
        return {
            tool: value[tool]
            for tool in cls._PLANNABLE_TOOLS
            if isinstance(value.get(tool), str) and value[tool].strip()
        }

    @staticmethod
    def _to_parallel(steps, text):
        """Replace sequential rag/web steps with one leading parallel step.

        Each branch reuses the query of the last matching sequential step, or
        *text* when there is none. All other steps follow in their original
        order.
        """
        queries = {"rag": text, "web": text}
        rest = []
        for s in steps:
            tool = s.get("tool")
            if tool in queries:
                queries[tool] = s.get("input", {}).get("query", text)
            else:
                rest.append(s)
        return [{"parallel": queries}] + rest

    @staticmethod
    def _step_label(s):
        """Short name of a plan step, used in the decision's reason string."""
        if "parallel" in s:
            labels = {"rag": "RAG", "web": "Web"}
            parts = [labels[k] for k in labels if k in s["parallel"]]
            return f"parallel({'+'.join(parts)})"
        return s["tool"]
def step(tool, input_data):
    """Return a sequential plan step that runs *tool* with *input_data*.

    >>> step("web", {"query": "q"})
    {'tool': 'web', 'input': {'query': 'q'}}
    """
    return dict(tool=tool, input=input_data)
def _multi_step(steps, reason):
    """Wrap *steps* in an ``AgentDecision`` whose action is ``"multi_step"``.

    The plan travels in ``tool_input["steps"]``; ``tool`` is set to ``None``
    and *reason* is passed through unchanged.
    """
    # Function-scope import kept from the original; presumably avoids an
    # import cycle between services and models — TODO confirm.
    from ..models.agent import AgentDecision

    decision_kwargs = {
        "action": "multi_step",
        "tool": None,
        "tool_input": {"steps": steps},
        "reason": reason,
    }
    return AgentDecision(**decision_kwargs)