Spaces:
Sleeping
Sleeping
File size: 6,091 Bytes
2f235a0 20a1017 2f235a0 ef83e66 2f235a0 ef83e66 2f235a0 ef83e66 2f235a0 ef83e66 2f235a0 ef83e66 2f235a0 ef83e66 2f235a0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
import json
import re
from dataclasses import dataclass, field
from typing import Any
@dataclass
class ToolSelector:
    """Heuristic + LLM-assisted planner that picks an ordered tool sequence.

    Builds a ``multi_step`` AgentDecision (via the module-level ``_multi_step``
    helper) over three tools:

    - ``rag`` -- private / internal knowledge retrieval
    - ``web`` -- public fact lookup and fresh information
    - ``llm`` -- final reasoning and synthesis (always the last step)
    """

    # Optional client exposing ``async simple_call(prompt) -> str``; when None,
    # only the keyword/score heuristics below are used.
    llm_client: Any = None

    async def select(self, intent: str, text: str, ctx):
        """Build a multi-step tool plan for the user message *text*.

        Parameters
        ----------
        intent : str
            Upstream intent label; ``"admin"`` short-circuits to a fixed
            admin -> llm plan.
        text : str
            Raw user message (passed through verbatim as each step's query).
        ctx : Mapping
            Shared context; reads ``tool_scores`` (per-tool fitness floats)
            and ``rag_results`` (pre-fetched retrieval hits) when present.

        Returns
        -------
        AgentDecision
            A ``multi_step`` decision whose step list always ends with ``llm``.
        """
        msg = text.lower().strip()
        tool_scores = ctx.get("tool_scores", {})
        rag_score = tool_scores.get("rag_fitness", 0.0)
        web_score = tool_scores.get("web_fitness", 0.0)
        llm_score = tool_scores.get("llm_only", 0.0)  # read for completeness; not used by the heuristics yet

        # ---------------------------------
        # 1. Detect ADMIN RULES FIRST
        # ---------------------------------
        if intent == "admin":
            return _multi_step([
                step("admin", {"query": text}),
                step("llm", {"query": text})
            ], "admin safety rule triggered -> llm")

        steps = []
        needs_rag = False
        needs_web = False

        # ---------------------------------
        # 2. Check RAG results (pre-fetch)
        # ---------------------------------
        rag_results = ctx.get("rag_results", [])
        rag_has_data = len(rag_results) > 0
        # RAG patterns: internal knowledge, company-specific, documentation
        rag_patterns = [
            r"company", r"internal", r"documentation", r"our ", r"your ",
            r"knowledge base", r"private", r"internal docs", r"corporate"
        ]
        if rag_has_data or rag_score >= 0.55 or any(re.search(p, msg) for p in rag_patterns):
            needs_rag = True
            # steps is empty here, so this is always the first (only) rag step
            steps.append(step("rag", {"query": text}))

        # ---------------------------------
        # 3. Fact lookup / definition -> Web
        # ---------------------------------
        fact_patterns = [
            r"what is ", r"who is ", r"where is ",
            r"tell me about ", r"define ", r"explain ",
            r"history of ", r"information about", r"details about"
        ]
        if web_score >= 0.55 or any(re.search(p, msg) for p in fact_patterns):
            needs_web = True
            steps.append(step("web", {"query": text}))

        # ---------------------------------
        # 4. Freshness heuristic -> Web
        # ---------------------------------
        freshness_keywords = ["latest", "today", "news", "current", "recent",
                              "now", "updates", "breaking", "trending"]
        if any(k in msg for k in freshness_keywords):
            needs_web = True
            # Avoid duplicate web steps (step 3 may already have added one)
            if not any(s["tool"] == "web" for s in steps):
                steps.append(step("web", {"query": text}))

        # ---------------------------------
        # 5. Complex queries that need multiple sources
        # ---------------------------------
        # NOTE: "vs" is word-bounded so it no longer matches inside words
        # like "canvas" or "advisor".
        complex_patterns = [
            r"compare", r"difference between", r"versus", r"\bvs\b",
            r"both", r"and also", r"as well as", r"in addition"
        ]
        needs_multiple = any(re.search(p, msg) for p in complex_patterns)

        # ---------------------------------
        # 6. Use LLM to enhance plan if we have partial steps or complex query
        # ---------------------------------
        if self.llm_client and (needs_multiple or (needs_rag and needs_web) or not steps):
            plan_prompt = f"""
You are an enterprise MCP agent.
You can select MULTIPLE tools in sequence to provide comprehensive answers.
TOOLS:
- rag -> private knowledge retrieval (use for internal/company docs)
- web -> online factual lookup (use for public facts, current info)
- llm -> final reasoning and synthesis (always include at end)
Current context:
- RAG available: {rag_has_data}
- User message: "{text}"
- Tool scores: {json.dumps(tool_scores)}
Determine which tools are needed. You can select:
- Just LLM (simple questions)
- RAG + LLM (internal knowledge questions)
- Web + LLM (public fact questions)
- RAG + Web + LLM (comprehensive questions needing both sources)
Return a JSON list describing the steps, e.g.:
[
{{"tool": "rag", "reason": "Need internal documentation"}},
{{"tool": "web", "reason": "Need current public information"}},
{{"tool": "llm", "reason": "Synthesize all information"}}
]
Only return the JSON array. Do not include markdown formatting.
"""
            try:
                out = await self.llm_client.simple_call(plan_prompt)
                # Clean the output in case LLM adds markdown fencing
                out = out.strip()
                if out.startswith("```json"):
                    out = out[7:]
                if out.startswith("```"):
                    out = out[3:]
                if out.endswith("```"):
                    out = out[:-3]
                steps_json = json.loads(out.strip())
                # Replace heuristic steps with LLM-planned ones; the llm step
                # is excluded here because step 7 always appends it last.
                steps = [
                    step(s["tool"], {"query": text})
                    for s in steps_json if s.get("tool") != "llm"
                ]
            except Exception:
                # LLM planning failed (network/parse error); keep whatever
                # heuristic steps we already have. If none, step 7 still
                # guarantees an llm step, so the plan is never empty.
                pass

        # ---------------------------------
        # 7. Always end with LLM synthesis
        # ---------------------------------
        if not steps or steps[-1]["tool"] != "llm":
            steps.append(step("llm", {
                "rag_data": rag_results if rag_has_data else None,
                "query": text
            }))

        # Build reason string showing the tool sequence
        tool_names = [s["tool"] for s in steps]
        reason = f"multi-tool plan: {' -> '.join(tool_names)} | scores={tool_scores}"
        return _multi_step(steps, reason)
def step(tool, input_data):
    """Build a single plan step: the tool to run and the payload it receives."""
    entry = dict(tool=tool, input=input_data)
    return entry
def _multi_step(steps, reason):
    """Package an ordered step list and a reason into a multi_step AgentDecision."""
    # Imported lazily to avoid a circular import at module load time.
    from ..models.agent import AgentDecision

    payload = {"steps": steps}
    decision = AgentDecision(
        action="multi_step",
        tool=None,
        tool_input=payload,
        reason=reason,
    )
    return decision
|