File size: 6,091 Bytes
2f235a0
 
 
 
 
 
20a1017
2f235a0
 
 
 
 
ef83e66
 
 
 
2f235a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef83e66
2f235a0
ef83e66
2f235a0
 
 
 
 
 
 
 
 
 
ef83e66
2f235a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef83e66
2f235a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef83e66
2f235a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
from dataclasses import dataclass, field
import json
import re
from typing import Any


@dataclass
class ToolSelector:
    """Plan a multi-step tool sequence (rag / web / llm) for a user message.

    Combines keyword/regex heuristics, pre-computed fitness scores found in
    ``ctx["tool_scores"]``, and — when an ``llm_client`` is available — an
    LLM planning call. The resulting plan always ends with an "llm"
    synthesis step.
    """

    # Optional async client exposing simple_call(prompt) -> str.
    # When None, the LLM planning stage (step 6) is skipped.
    # NOTE: was annotated `any` (the builtin function); `Any` is the intended type.
    llm_client: Any = None

    async def select(self, intent: str, text: str, ctx):
        """Return a multi_step AgentDecision describing which tools to run.

        Args:
            intent: pre-classified intent label; ``"admin"`` short-circuits
                into a fixed admin β†’ llm plan.
            text: raw user message (used verbatim as each tool's query).
            ctx: dict-like context; reads ``"tool_scores"`` (floats keyed by
                ``rag_fitness`` / ``web_fitness``) and ``"rag_results"``
                (pre-fetched RAG hits, presumably a list — TODO confirm).
        """
        msg = text.lower().strip()
        tool_scores = ctx.get("tool_scores", {})
        rag_score = tool_scores.get("rag_fitness", 0.0)
        web_score = tool_scores.get("web_fitness", 0.0)

        # ---------------------------------
        # 1. Detect ADMIN RULES FIRST
        # ---------------------------------
        if intent == "admin":
            return _multi_step([
                step("admin", {"query": text}),
                step("llm", {"query": text})
            ], "admin safety rule triggered β†’ llm")

        steps = []
        needs_rag = False
        needs_web = False

        # ---------------------------------
        # 2. Check RAG results (pre-fetch)
        # ---------------------------------
        rag_results = ctx.get("rag_results", [])
        rag_has_data = len(rag_results) > 0

        # RAG patterns: internal knowledge, company-specific, documentation
        rag_patterns = [
            r"company", r"internal", r"documentation", r"our ", r"your ",
            r"knowledge base", r"private", r"internal docs", r"corporate"
        ]
        if rag_has_data or rag_score >= 0.55 or any(re.search(p, msg) for p in rag_patterns):
            needs_rag = True
            # steps is empty here, so this is always the first (and only) rag step
            steps.append(step("rag", {"query": text}))

        # ---------------------------------
        # 3. Fact lookup / definition β†’ Web
        # ---------------------------------
        fact_patterns = [
            r"what is ", r"who is ", r"where is ",
            r"tell me about ", r"define ", r"explain ",
            r"history of ", r"information about", r"details about"
        ]
        if web_score >= 0.55 or any(re.search(p, msg) for p in fact_patterns):
            needs_web = True
            steps.append(step("web", {"query": text}))

        # ---------------------------------
        # 4. Freshness heuristic β†’ Web
        # ---------------------------------
        freshness_keywords = ["latest", "today", "news", "current", "recent", 
                             "now", "updates", "breaking", "trending"]
        if any(k in msg for k in freshness_keywords):
            needs_web = True
            # Avoid duplicate web steps (step 3 may already have added one)
            if not any(s["tool"] == "web" for s in steps):
                steps.append(step("web", {"query": text}))

        # ---------------------------------
        # 5. Complex queries that need multiple sources
        # ---------------------------------
        # NOTE(review): r"vs" has no word boundary, so it also matches inside
        # words (e.g. "versus" — harmless here — but also "canvas").
        complex_patterns = [
            r"compare", r"difference between", r"versus", r"vs",
            r"both", r"and also", r"as well as", r"in addition"
        ]
        needs_multiple = any(re.search(p, msg) for p in complex_patterns)

        # ---------------------------------
        # 6. Use LLM to enhance plan if we have partial steps or complex query
        # ---------------------------------
        if self.llm_client and (needs_multiple or (needs_rag and needs_web) or len(steps) == 0):
            plan_prompt = f"""
You are an enterprise MCP agent. 
You can select MULTIPLE tools in sequence to provide comprehensive answers.

TOOLS:
- rag        β†’ private knowledge retrieval (use for internal/company docs)
- web        β†’ online factual lookup (use for public facts, current info)
- llm        β†’ final reasoning and synthesis (always include at end)

Current context:
- RAG available: {rag_has_data}
- User message: "{text}"
- Tool scores: {json.dumps(tool_scores)}

Determine which tools are needed. You can select:
- Just LLM (simple questions)
- RAG + LLM (internal knowledge questions)
- Web + LLM (public fact questions)
- RAG + Web + LLM (comprehensive questions needing both sources)

Return a JSON list describing the steps, e.g.:

[
  {{"tool": "rag", "reason": "Need internal documentation"}},
  {{"tool": "web", "reason": "Need current public information"}},
  {{"tool": "llm", "reason": "Synthesize all information"}}
]

Only return the JSON array. Do not include markdown formatting.
"""
            try:
                out = await self.llm_client.simple_call(plan_prompt)
                # Clean the output in case LLM adds markdown
                out = out.strip()
                if out.startswith("```json"):
                    out = out[7:]
                if out.startswith("```"):
                    out = out[3:]
                if out.endswith("```"):
                    out = out[:-3]
                out = out.strip()

                steps_json = json.loads(out)

                # Replace steps with LLM-planned steps (excluding LLM, we'll add it at end)
                steps = [
                    step(s["tool"], {"query": text})
                    for s in steps_json if s.get("tool") != "llm"
                ]
            except Exception:
                # Best-effort: if LLM planning fails (client error, bad JSON),
                # keep the heuristic steps gathered above unchanged.
                pass

        # ---------------------------------
        # 7. Always end with LLM synthesis
        # ---------------------------------
        if not steps or steps[-1]["tool"] != "llm":
            steps.append(step("llm", {
                "rag_data": rag_results if rag_has_data else None,
                "query": text
            }))

        # Build reason string showing the tool sequence
        tool_names = [s["tool"] for s in steps]
        reason = f"multi-tool plan: {' β†’ '.join(tool_names)} | scores={tool_scores}"

        return _multi_step(steps, reason)



def step(tool, input_data):
    """Build a single plan entry pairing a tool name with its input payload."""
    return dict(tool=tool, input=input_data)


def _multi_step(steps, reason):
    """Wrap *steps* in a "multi_step" AgentDecision.

    The import is deferred to call time — presumably to avoid a circular
    import at module load; verify against the package layout.
    """
    from ..models.agent import AgentDecision

    decision = AgentDecision(
        action="multi_step",
        tool=None,
        tool_input={"steps": steps},
        reason=reason,
    )
    return decision