from dataclasses import dataclass, field
import json
import re
from typing import Dict, Any, Optional, List

from .tool_metadata import (
    get_tool_latency_estimate,
    estimate_path_latency,
    get_fastest_path,
    validate_tool_output
)


@dataclass
class ToolSelector:
    """Plans a multi-step tool execution pipeline (admin / rag / web / llm).

    ``select`` inspects the user's intent, message text, and prior context
    (RAG pre-fetch results, conversation memory, admin violations) and
    returns an ``AgentDecision`` describing an ordered — possibly parallel —
    list of tool steps, always terminated by an LLM synthesis step.
    """

    # Optional LLM client used to refine the plan. Annotated with typing.Any
    # (the original used the builtin function `any`, which is not a type).
    llm_client: Any = None

    async def select(self, intent: str, text: str, ctx):
        """Build a multi-step tool plan for `text`.

        Args:
            intent: Pre-classified intent label (e.g. "admin").
            text: Raw user message.
            ctx: Context dict; reads "tool_scores", "rag_results",
                "memory", and "admin_violations".

        Returns:
            AgentDecision with action="multi_step" and the planned steps.
        """
        msg = text.lower().strip()
        tool_scores = ctx.get("tool_scores", {})
        rag_score = tool_scores.get("rag_fitness", 0.0)
        web_score = tool_scores.get("web_fitness", 0.0)

        # Context-aware routing: check previous outputs.
        rag_results = ctx.get("rag_results", [])
        memory = ctx.get("memory", [])  # Recent tool outputs from conversation memory
        admin_violations = ctx.get("admin_violations", [])

        # Context-aware decisions derived from prior outputs.
        context_hints = self._analyze_context(rag_results, memory, admin_violations, tool_scores)

        # ---------------------------------
        # 1. Detect ADMIN RULES FIRST
        # ---------------------------------
        if intent == "admin":
            # If severe violation, skip agent reasoning entirely.
            if context_hints.get("skip_agent_reasoning"):
                return _multi_step([
                    step("admin", {"query": text})
                ], "admin critical violation → immediate block (latency: ~10ms)")

            # Estimate latency for the admin → llm path.
            admin_latency = get_tool_latency_estimate("admin", {"query_length": len(text)})
            llm_latency = get_tool_latency_estimate("llm", {"query_length": len(text)})
            total_latency = admin_latency + llm_latency
            return _multi_step([
                step("admin", {"query": text}),
                step("llm", {"query": text})
            ], f"admin safety rule triggered → llm (est. latency: {total_latency}ms)")

        steps = []
        needs_rag = False
        needs_web = False
        # BUG FIX: defined unconditionally. Previously this was only assigned
        # inside the non-news branch, so news queries raised NameError when
        # the planner prompt or the final LLM step read it.
        rag_has_data = len(rag_results) > 0

        # ---------------------------------
        # 2. PRIORITY: Check for news/current events queries FIRST
        # ---------------------------------
        # This must happen BEFORE the RAG check to prevent news queries
        # from using RAG.
        freshness_keywords = [
            "latest", "today", "news", "current", "recent", "now", "updates",
            "breaking", "trending", "happening", "what's new", "what is new",
            "what happened"
        ]
        news_patterns = [
            r"latest news", r"current news", r"today's news", r"breaking news",
            r"news about", r"news on", r"news of", r"what's happening",
            r"what happened", r"recent news", r"news update"
        ]
        is_news_query = (
            any(k in msg for k in freshness_keywords)
            or any(re.search(p, msg) for p in news_patterns)
        )

        # If it's a news query, skip RAG entirely and go straight to web.
        if is_news_query:
            needs_web = True
            needs_rag = False  # News queries should NEVER use RAG

            # For news queries, enhance the query to be more specific.
            web_query = text
            if len(text.split()) <= 4:  # Short queries like "latest news about AI"
                if "news" not in msg:
                    web_query = f"{text} news latest"
                elif "about" not in msg and "on" not in msg:
                    # e.g. "latest news AI" → "latest news about AI"
                    web_query = f"latest news about {text.replace('latest', '').replace('news', '').strip()}"

            web_latency = get_tool_latency_estimate("web", {
                "query_length": len(web_query),
                "query_complexity": "high" if len(web_query.split()) > 10 else "medium"
            })
            steps.append(step("web", {"query": web_query, "_estimated_latency_ms": web_latency}))

        # ---------------------------------
        # 3. Check RAG results (pre-fetch) with context-aware routing
        # ---------------------------------
        # Only check RAG if it's NOT a news query.
        if not is_news_query:
            # If RAG returned a high similarity score, consider skipping web.
            rag_high_score = False
            if rag_results:
                top_score = max((r.get("similarity", 0) for r in rag_results), default=0)
                rag_high_score = top_score >= 0.8
            if rag_high_score and context_hints.get("skip_web_if_rag_high"):
                needs_web = False

            # If the agent already has relevant memory, skip RAG.
            has_relevant_memory = context_hints.get("has_relevant_memory", False)
            if has_relevant_memory and context_hints.get("skip_rag_if_memory"):
                needs_rag = False
            else:
                # RAG patterns: internal knowledge, company-specific, documentation.
                rag_patterns = [
                    r"company", r"internal", r"documentation", r"our ", r"your ",
                    r"knowledge base", r"private", r"internal docs", r"corporate",
                    r"admin", r"administrator"
                ]
                # "who is" / "what is" only count toward RAG for non-news queries.
                if not is_news_query:
                    rag_patterns.extend([r"who is", r"what is"])

                if rag_has_data or rag_score >= 0.55 or any(re.search(p, msg) for p in rag_patterns):
                    needs_rag = True
                    if not any(s.get("tool") == "rag" for s in steps):
                        rag_latency = get_tool_latency_estimate("rag", {"query_length": len(text)})
                        steps.append(step("rag", {"query": text, "_estimated_latency_ms": rag_latency}))

        # ---------------------------------
        # 4. Fact lookup / definition → Web (with context-aware routing)
        # ---------------------------------
        # Only check fact patterns if it's NOT a news query (news handled above).
        if not is_news_query:
            # Skip web if RAG already provided high-quality results.
            rag_high_score = False
            if rag_results:
                top_score = max((r.get("similarity", 0) for r in rag_results), default=0)
                rag_high_score = top_score >= 0.8

            if not (rag_high_score and context_hints.get("skip_web_if_rag_high")):
                fact_patterns = [
                    r"what is ", r"who is ", r"where is ", r"tell me about ",
                    r"define ", r"explain ", r"history of ", r"information about",
                    r"details about"
                ]
                if web_score >= 0.55 or any(re.search(p, msg) for p in fact_patterns):
                    needs_web = True
                    # Avoid duplicate web steps.
                    if not any(s.get("tool") == "web" for s in steps):
                        web_latency = get_tool_latency_estimate("web", {
                            "query_length": len(text),
                            "query_complexity": "high" if len(text.split()) > 10 else "medium"
                        })
                        steps.append(step("web", {"query": text, "_estimated_latency_ms": web_latency}))

        # ---------------------------------
        # 5. Complex queries that need multiple sources
        # ---------------------------------
        complex_patterns = [
            r"compare", r"difference between", r"versus", r"vs", r"both",
            r"and also", r"as well as", r"in addition"
        ]
        needs_multiple = any(re.search(p, msg) for p in complex_patterns)

        # ---------------------------------
        # 6. Use LLM to enhance plan if we have partial steps or complex query
        # ---------------------------------
        # Parallel execution when both RAG and Web are needed and either the
        # query is complex or both fitness scores are high.
        should_parallel = needs_rag and needs_web and (
            needs_multiple or (rag_score >= 0.55 and web_score >= 0.55)
        )

        if self.llm_client and (needs_multiple or (needs_rag and needs_web) or len(steps) == 0):
            plan_prompt = f"""
You are an enterprise MCP agent. You can select MULTIPLE tools in sequence OR in parallel to provide comprehensive answers.

TOOLS:
- rag → private knowledge retrieval (use for internal/company docs)
- web → online factual lookup (use for public facts, current info)
- llm → final reasoning and synthesis (always include at end)

Current context:
- RAG available: {rag_has_data}
- User message: "{text}"
- Tool scores: {json.dumps(tool_scores)}

Determine which tools are needed. You can select:
- Just LLM (simple questions)
- RAG + LLM (internal knowledge questions)
- Web + LLM (public fact questions)
- RAG + Web + LLM (comprehensive questions needing both sources)

IMPORTANT: If the query needs BOTH internal docs (RAG) AND current/live info (Web), you can mark them for parallel execution by using a "parallel" step.

Return a JSON list describing the steps. For parallel execution, use:
[
  {{
    "parallel": {{
      "rag": "query for internal docs",
      "web": "query for live info"
    }},
    "reason": "Need both internal and live information simultaneously"
  }},
  {{"tool": "llm", "reason": "Synthesize all information"}}
]

For sequential execution, use:
[
  {{"tool": "rag", "reason": "Need internal documentation"}},
  {{"tool": "web", "reason": "Need current public information"}},
  {{"tool": "llm", "reason": "Synthesize all information"}}
]

Only return the JSON array. Do not include markdown formatting.
"""
            try:
                out = await self.llm_client.simple_call(plan_prompt)
                # Clean the output in case the LLM adds markdown fences.
                out = out.strip()
                if out.startswith("```json"):
                    out = out[7:]
                if out.startswith("```"):
                    out = out[3:]
                if out.endswith("```"):
                    out = out[:-3]
                out = out.strip()

                steps_json = json.loads(out)

                # Check if the LLM returned a parallel step.
                has_parallel = any("parallel" in s for s in steps_json)
                if has_parallel:
                    # Extract the parallel step and convert to our format.
                    parallel_step = None
                    other_steps = []
                    for s in steps_json:
                        if "parallel" in s:
                            parallel_step = {"parallel": s["parallel"]}
                        elif s.get("tool") != "llm":
                            other_steps.append(step(s["tool"], {"query": text}))
                    if parallel_step:
                        steps = [parallel_step] + other_steps
                    else:
                        # Fallback: convert to regular steps.
                        steps = [
                            step(s["tool"], {"query": text})
                            for s in steps_json if s.get("tool") != "llm"
                        ]
                else:
                    # Replace steps with LLM-planned steps (excluding LLM,
                    # which is always appended at the end).
                    steps = [
                        step(s["tool"], {"query": text})
                        for s in steps_json if s.get("tool") != "llm"
                    ]
            except Exception:
                # If LLM planning fails, fall back to the heuristic plan;
                # create a parallel step manually when warranted.
                if should_parallel and needs_rag and needs_web:
                    steps = [{
                        "parallel": {
                            "rag": text,
                            "web": text
                        }
                    }]

        # ---------------------------------
        # 7. If we have both RAG and Web but no parallel step, create one
        # ---------------------------------
        if should_parallel and needs_rag and needs_web:
            has_parallel = any("parallel" in s for s in steps)
            if not has_parallel:
                # Replace sequential RAG and Web steps with a parallel step,
                # preserving any per-tool query rewrites already made.
                rag_query = text
                web_query = text
                for s in steps:
                    if s.get("tool") == "rag":
                        rag_query = s.get("input", {}).get("query", text)
                    elif s.get("tool") == "web":
                        web_query = s.get("input", {}).get("query", text)

                new_steps = [{
                    "parallel": {
                        "rag": rag_query,
                        "web": web_query
                    }
                }]
                # Keep other non-RAG/Web steps.
                for s in steps:
                    if s.get("tool") not in ["rag", "web"]:
                        new_steps.append(s)
                steps = new_steps

        # ---------------------------------
        # 8. Always end with LLM synthesis
        # ---------------------------------
        # BUG FIX: append the synthesis step whenever the last step is not an
        # LLM step. The previous condition excluded parallel steps, so plans
        # ending in a parallel(RAG+Web) step were left without synthesis.
        if not steps or not (isinstance(steps[-1], dict) and steps[-1].get("tool") == "llm"):
            steps.append(step("llm", {
                "rag_data": rag_results if rag_has_data else None,
                "query": text
            }))

        # Optimize tool order for latency (fastest first), keeping LLM last.
        if len(steps) > 1:
            llm_step = None
            other_steps = []
            for s in steps:
                if isinstance(s, dict) and s.get("tool") == "llm":
                    llm_step = s
                else:
                    other_steps.append(s)
            # Sort remaining steps by estimated latency (unknown → 1000ms).
            other_steps.sort(key=lambda s: s.get("input", {}).get("_estimated_latency_ms", 1000))
            steps = other_steps
            if llm_step:
                steps.append(llm_step)

        # Calculate total estimated latency and build the human-readable reason.
        tool_names = []
        total_latency = 0
        for s in steps:
            if "parallel" in s:
                tool_names.append("parallel(RAG+Web)")
                # Parallel execution: bounded by the slower branch.
                rag_lat = get_tool_latency_estimate("rag")
                web_lat = get_tool_latency_estimate("web")
                total_latency += max(rag_lat, web_lat)
            elif isinstance(s, dict) and "tool" in s:
                tool_name = s["tool"]
                tool_names.append(tool_name)
                est_latency = s.get("input", {}).get("_estimated_latency_ms")
                if est_latency:
                    total_latency += est_latency
                else:
                    total_latency += get_tool_latency_estimate(tool_name)

        # Build reason with latency and context hints.
        context_info = []
        if context_hints.get("skip_web_if_rag_high"):
            context_info.append("RAG high score → skip web")
        if context_hints.get("skip_rag_if_memory"):
            context_info.append("memory available → skip RAG")
        if context_hints.get("skip_agent_reasoning"):
            context_info.append("critical violation → skip reasoning")
        context_str = f" | context: {', '.join(context_info)}" if context_info else ""

        reason = f"multi-tool plan: {' → '.join(tool_names)} | est. latency: {total_latency}ms | scores={tool_scores}{context_str}"
        return _multi_step(steps, reason)

    def _analyze_context(
        self,
        rag_results: List[Dict],
        memory: List[Dict],
        admin_violations: List[Dict],
        tool_scores: Dict[str, float]
    ) -> Dict[str, Any]:
        """
        Analyze context from previous outputs to make routing decisions.
        Returns context hints for intelligent tool selection.
        """
        hints = {}

        # Check RAG results quality.
        if rag_results:
            top_score = max((r.get("similarity", 0) for r in rag_results), default=0)
            if top_score >= 0.8:
                hints["skip_web_if_rag_high"] = True
                hints["rag_high_confidence"] = True

        # Check if relevant memory exists.
        if memory:
            # Check if memory contains relevant RAG results.
            has_rag_memory = any(
                m.get("tool") == "rag" and m.get("result", {}).get("results")
                for m in memory[-5:]  # Check last 5 memory entries
            )
            if has_rag_memory:
                hints["has_relevant_memory"] = True
                # Only skip RAG if memory is very recent and high quality.
                recent_memory = memory[-1] if memory else {}
                if recent_memory.get("tool") == "rag":
                    mem_results = recent_memory.get("result", {}).get("results", [])
                    if mem_results:
                        mem_top_score = max((r.get("similarity", 0) for r in mem_results), default=0)
                        if mem_top_score >= 0.75:
                            hints["skip_rag_if_memory"] = True

        # Check admin violations severity.
        if admin_violations:
            max_severity = max(
                (v.get("severity", "low") for v in admin_violations),
                key=lambda s: ["low", "medium", "high", "critical"].index(s)
                if s in ["low", "medium", "high", "critical"] else 0
            )
            if max_severity in ["high", "critical"]:
                hints["skip_agent_reasoning"] = True
                hints["critical_violation"] = True

        return hints


def step(tool, input_data):
    """Build a single sequential tool step dict."""
    return {"tool": tool, "input": input_data}


def _multi_step(steps, reason):
    """Wrap planned steps into an AgentDecision with action="multi_step"."""
    from ..models.agent import AgentDecision
    return AgentDecision(
        action="multi_step",
        tool=None,
        tool_input={"steps": steps},
        reason=reason
    )