# Tool selection / routing logic for the MCP agent.
# (Removed "Spaces: Sleeping" status banner — page-export residue, not source.)
| from dataclasses import dataclass, field | |
| import json | |
| import re | |
| from typing import Dict, Any, Optional, List | |
| from .tool_metadata import ( | |
| get_tool_latency_estimate, | |
| estimate_path_latency, | |
| get_fastest_path, | |
| validate_tool_output | |
| ) | |
class ToolSelector:
    """Latency- and context-aware planner mapping a user query to tool steps.

    Routing order inside :meth:`select`:
      1. ``admin`` intent        -> safety pipeline (admin [+ llm])
      2. news / current events   -> web only (RAG is never used for news)
      3. internal knowledge      -> rag
      4. public fact lookup      -> web
      5. complex queries         -> optional LLM-generated plan, possibly a
                                    parallel rag+web step
      6. plans normally end with an llm synthesis step
    """

    # Optional LLM client used to generate multi-tool plans; when None the
    # selector relies on heuristics only. (Was annotated with the builtin
    # ``any``; ``typing.Any`` is the intended annotation.)
    llm_client: Any = None

    async def select(self, intent: str, text: str, ctx):
        """Build a multi-step tool plan for *text*.

        Args:
            intent: Pre-classified intent label (e.g. ``"admin"``).
            text: Raw user message.
            ctx: Request context mapping; keys read: ``"tool_scores"``,
                ``"rag_results"``, ``"memory"``, ``"admin_violations"``.

        Returns:
            An ``AgentDecision`` (built by ``_multi_step``) whose
            ``tool_input["steps"]`` is the ordered plan. Each step is either
            ``{"tool": ..., "input": ...}`` or ``{"parallel": {...}}``.
        """
        msg = text.lower().strip()
        tool_scores = ctx.get("tool_scores", {})
        rag_score = tool_scores.get("rag_fitness", 0.0)
        web_score = tool_scores.get("web_fitness", 0.0)

        # Context-aware routing: previous outputs from this conversation.
        rag_results = ctx.get("rag_results", [])
        memory = ctx.get("memory", [])  # recent tool outputs from conversation memory
        admin_violations = ctx.get("admin_violations", [])
        context_hints = self._analyze_context(rag_results, memory, admin_violations, tool_scores)

        # BUGFIX: bind these up-front. ``rag_has_data`` was previously only
        # assigned on the non-news path, so news queries raised NameError when
        # building the final llm step (and inside the planning prompt below).
        rag_has_data = len(rag_results) > 0
        rag_high_score = False
        if rag_results:
            top_score = max((r.get("similarity", 0) for r in rag_results), default=0)
            rag_high_score = top_score >= 0.8

        # ---------------------------------
        # 1. Admin rules take absolute priority.
        # ---------------------------------
        if intent == "admin":
            # Severe violation: block immediately, skip agent reasoning.
            if context_hints.get("skip_agent_reasoning"):
                return _multi_step(
                    [step("admin", {"query": text})],
                    "admin critical violation β immediate block (latency: ~10ms)",
                )
            admin_latency = get_tool_latency_estimate("admin", {"query_length": len(text)})
            llm_latency = get_tool_latency_estimate("llm", {"query_length": len(text)})
            total_latency = admin_latency + llm_latency
            return _multi_step(
                [
                    step("admin", {"query": text}),
                    step("llm", {"query": text}),
                ],
                f"admin safety rule triggered β llm (est. latency: {total_latency}ms)",
            )

        steps = []
        needs_rag = False
        needs_web = False

        # ---------------------------------
        # 2. PRIORITY: news / current-events detection.
        # Must run BEFORE the RAG check so news queries never use RAG.
        # ---------------------------------
        freshness_keywords = ["latest", "today", "news", "current", "recent",
                              "now", "updates", "breaking", "trending", "happening",
                              "what's new", "what is new", "what happened"]
        news_patterns = [
            r"latest news", r"current news", r"today's news", r"breaking news",
            r"news about", r"news on", r"news of", r"what's happening",
            r"what happened", r"recent news", r"news update",
        ]
        is_news_query = (
            any(k in msg for k in freshness_keywords)
            or any(re.search(p, msg) for p in news_patterns)
        )

        if is_news_query:
            needs_web = True
            needs_rag = False  # news queries must NEVER use RAG
            # Short queries (e.g. "latest news about AI") are expanded for
            # better web-search recall.
            web_query = text
            if len(text.split()) <= 4:
                if "news" not in msg:
                    web_query = f"{text} news latest"
                elif "about" not in msg and "on" not in msg:
                    # "latest news AI" -> "latest news about AI"
                    web_query = f"latest news about {text.replace('latest', '').replace('news', '').strip()}"
            web_latency = get_tool_latency_estimate("web", {
                "query_length": len(web_query),
                "query_complexity": "high" if len(web_query.split()) > 10 else "medium",
            })
            steps.append(step("web", {"query": web_query, "_estimated_latency_ms": web_latency}))

        # ---------------------------------
        # 3. RAG routing (skipped entirely for news queries).
        # ---------------------------------
        if not is_news_query:
            # High-confidence pre-fetched RAG result can suppress web search.
            if rag_high_score and context_hints.get("skip_web_if_rag_high"):
                needs_web = False

            # Recent relevant memory may make a fresh RAG call redundant.
            has_relevant_memory = context_hints.get("has_relevant_memory", False)
            if has_relevant_memory and context_hints.get("skip_rag_if_memory"):
                needs_rag = False
            else:
                # Patterns indicating internal / company-specific knowledge.
                rag_patterns = [
                    r"company", r"internal", r"documentation", r"our ", r"your ",
                    r"knowledge base", r"private", r"internal docs", r"corporate",
                    r"admin", r"administrator",
                ]
                # "who is"/"what is" count as RAG-worthy only for non-news
                # queries — guaranteed here by the enclosing branch (the
                # original re-checked ``not is_news_query`` redundantly).
                rag_patterns.extend([r"who is", r"what is"])
                if rag_has_data or rag_score >= 0.55 or any(re.search(p, msg) for p in rag_patterns):
                    needs_rag = True
                    if not any(s.get("tool") == "rag" for s in steps):
                        rag_latency = get_tool_latency_estimate("rag", {"query_length": len(text)})
                        steps.append(step("rag", {"query": text, "_estimated_latency_ms": rag_latency}))

        # ---------------------------------
        # 4. Fact lookup / definition -> web (non-news only; news handled above).
        # ---------------------------------
        if not is_news_query:
            # Skip web if RAG already provided high-quality results.
            if not (rag_high_score and context_hints.get("skip_web_if_rag_high")):
                fact_patterns = [
                    r"what is ", r"who is ", r"where is ",
                    r"tell me about ", r"define ", r"explain ",
                    r"history of ", r"information about", r"details about",
                ]
                if web_score >= 0.55 or any(re.search(p, msg) for p in fact_patterns):
                    needs_web = True
                    # Avoid duplicate web steps.
                    if not any(s.get("tool") == "web" for s in steps):
                        web_latency = get_tool_latency_estimate("web", {
                            "query_length": len(text),
                            "query_complexity": "high" if len(text.split()) > 10 else "medium",
                        })
                        steps.append(step("web", {"query": text, "_estimated_latency_ms": web_latency}))

        # ---------------------------------
        # 5. Complex queries that may need multiple sources.
        # ---------------------------------
        complex_patterns = [
            r"compare", r"difference between", r"versus", r"vs",
            r"both", r"and also", r"as well as", r"in addition",
        ]
        needs_multiple = any(re.search(p, msg) for p in complex_patterns)

        # ---------------------------------
        # 6. Optional LLM-generated plan for complex / ambiguous cases.
        # ---------------------------------
        # Parallel rag+web only when both are needed AND the query looks
        # complex or both fitness scores are strong.
        should_parallel = needs_rag and needs_web and (
            needs_multiple or (rag_score >= 0.55 and web_score >= 0.55)
        )
        if self.llm_client and (needs_multiple or (needs_rag and needs_web) or len(steps) == 0):
            plan_prompt = f"""
You are an enterprise MCP agent.
You can select MULTIPLE tools in sequence OR in parallel to provide comprehensive answers.
TOOLS:
- rag β private knowledge retrieval (use for internal/company docs)
- web β online factual lookup (use for public facts, current info)
- llm β final reasoning and synthesis (always include at end)
Current context:
- RAG available: {rag_has_data}
- User message: "{text}"
- Tool scores: {json.dumps(tool_scores)}
Determine which tools are needed. You can select:
- Just LLM (simple questions)
- RAG + LLM (internal knowledge questions)
- Web + LLM (public fact questions)
- RAG + Web + LLM (comprehensive questions needing both sources)
IMPORTANT: If the query needs BOTH internal docs (RAG) AND current/live info (Web),
you can mark them for parallel execution by using a "parallel" step.
Return a JSON list describing the steps. For parallel execution, use:
[
{{
"parallel": {{
"rag": "query for internal docs",
"web": "query for live info"
}},
"reason": "Need both internal and live information simultaneously"
}},
{{"tool": "llm", "reason": "Synthesize all information"}}
]
For sequential execution, use:
[
{{"tool": "rag", "reason": "Need internal documentation"}},
{{"tool": "web", "reason": "Need current public information"}},
{{"tool": "llm", "reason": "Synthesize all information"}}
]
Only return the JSON array. Do not include markdown formatting.
"""
            try:
                out = await self.llm_client.simple_call(plan_prompt)
                # Strip markdown fences the LLM may add despite instructions.
                out = out.strip()
                if out.startswith("```json"):
                    out = out[7:]
                if out.startswith("```"):
                    out = out[3:]
                if out.endswith("```"):
                    out = out[:-3]
                out = out.strip()
                steps_json = json.loads(out)

                # Did the LLM return a parallel step?
                has_parallel = any("parallel" in s for s in steps_json)
                if has_parallel:
                    parallel_step = None
                    other_steps = []
                    for s in steps_json:
                        if "parallel" in s:
                            parallel_step = {"parallel": s["parallel"]}
                        elif s.get("tool") != "llm":
                            other_steps.append(step(s["tool"], {"query": text}))
                    if parallel_step:
                        steps = [parallel_step] + other_steps
                    else:
                        # Fallback: convert to regular sequential steps.
                        steps = [
                            step(s["tool"], {"query": text})
                            for s in steps_json if s.get("tool") != "llm"
                        ]
                else:
                    # Replace heuristic steps with the LLM-planned ones
                    # (llm synthesis is appended at the end regardless).
                    steps = [
                        step(s["tool"], {"query": text})
                        for s in steps_json if s.get("tool") != "llm"
                    ]
            except Exception:
                # LLM planning failed (bad JSON, transport error, ...):
                # fall back to a manual parallel step when both sources are
                # needed, otherwise keep whatever heuristic steps we built.
                if should_parallel and needs_rag and needs_web:
                    steps = [{
                        "parallel": {
                            "rag": text,
                            "web": text,
                        }
                    }]

        # ---------------------------------
        # 7. Collapse sequential rag+web into one parallel step when allowed.
        # ---------------------------------
        if should_parallel and needs_rag and needs_web:
            has_parallel = any("parallel" in s for s in steps)
            if not has_parallel:
                # Reuse per-tool queries from the existing steps if present.
                rag_query = text
                web_query = text
                for s in steps:
                    if s.get("tool") == "rag":
                        rag_query = s.get("input", {}).get("query", text)
                    elif s.get("tool") == "web":
                        web_query = s.get("input", {}).get("query", text)
                new_steps = [{
                    "parallel": {
                        "rag": rag_query,
                        "web": web_query,
                    }
                }]
                # Keep the remaining non-RAG/Web steps in order.
                for s in steps:
                    if s.get("tool") not in ["rag", "web"]:
                        new_steps.append(s)
                steps = new_steps

        # ---------------------------------
        # 8. End with an LLM synthesis step.
        # NOTE(review): a plan ending in a parallel step deliberately gets no
        # llm step appended here — presumably the parallel executor handles
        # synthesis; confirm against the executor.
        # ---------------------------------
        if not steps or (isinstance(steps[-1], dict) and steps[-1].get("tool") != "llm" and "parallel" not in steps[-1]):
            steps.append(step("llm", {
                "rag_data": rag_results if rag_has_data else None,
                "query": text,
            }))

        # Reorder non-LLM steps fastest-first; llm synthesis stays last.
        if len(steps) > 1:
            llm_step = None
            other_steps = []
            for s in steps:
                if isinstance(s, dict) and s.get("tool") == "llm":
                    llm_step = s
                else:
                    other_steps.append(s)
            # Parallel steps carry no estimate and sort with the 1000ms default.
            other_steps.sort(key=lambda s: s.get("input", {}).get("_estimated_latency_ms", 1000))
            steps = other_steps
            if llm_step:
                steps.append(llm_step)

        # Total estimated latency for the reason string.
        tool_names = []
        total_latency = 0
        for s in steps:
            if "parallel" in s:
                tool_names.append("parallel(RAG+Web)")
                # Parallel execution is bounded by the slower branch.
                rag_lat = get_tool_latency_estimate("rag")
                web_lat = get_tool_latency_estimate("web")
                total_latency += max(rag_lat, web_lat)
            elif isinstance(s, dict) and "tool" in s:
                tool_name = s["tool"]
                tool_names.append(tool_name)
                est_latency = s.get("input", {}).get("_estimated_latency_ms")
                if est_latency:
                    total_latency += est_latency
                else:
                    total_latency += get_tool_latency_estimate(tool_name)

        # Build the human-readable reason with latency and context hints.
        context_info = []
        if context_hints.get("skip_web_if_rag_high"):
            context_info.append("RAG high score β skip web")
        if context_hints.get("skip_rag_if_memory"):
            context_info.append("memory available β skip RAG")
        if context_hints.get("skip_agent_reasoning"):
            context_info.append("critical violation β skip reasoning")
        context_str = f" | context: {', '.join(context_info)}" if context_info else ""
        reason = f"multi-tool plan: {' β '.join(tool_names)} | est. latency: {total_latency}ms | scores={tool_scores}{context_str}"
        return _multi_step(steps, reason)
| def _analyze_context( | |
| self, | |
| rag_results: List[Dict], | |
| memory: List[Dict], | |
| admin_violations: List[Dict], | |
| tool_scores: Dict[str, float] | |
| ) -> Dict[str, Any]: | |
| """ | |
| Analyze context from previous outputs to make routing decisions. | |
| Returns context hints for intelligent tool selection. | |
| """ | |
| hints = {} | |
| # Check RAG results quality | |
| if rag_results: | |
| top_score = max((r.get("similarity", 0) for r in rag_results), default=0) | |
| if top_score >= 0.8: | |
| hints["skip_web_if_rag_high"] = True | |
| hints["rag_high_confidence"] = True | |
| # Check if relevant memory exists | |
| if memory: | |
| # Check if memory contains relevant RAG results | |
| has_rag_memory = any( | |
| m.get("tool") == "rag" and m.get("result", {}).get("results") | |
| for m in memory[-5:] # Check last 5 memory entries | |
| ) | |
| if has_rag_memory: | |
| hints["has_relevant_memory"] = True | |
| # Only skip RAG if memory is very recent and high quality | |
| recent_memory = memory[-1] if memory else {} | |
| if recent_memory.get("tool") == "rag": | |
| mem_results = recent_memory.get("result", {}).get("results", []) | |
| if mem_results: | |
| mem_top_score = max((r.get("similarity", 0) for r in mem_results), default=0) | |
| if mem_top_score >= 0.75: | |
| hints["skip_rag_if_memory"] = True | |
| # Check admin violations severity | |
| if admin_violations: | |
| max_severity = max( | |
| (v.get("severity", "low") for v in admin_violations), | |
| key=lambda s: ["low", "medium", "high", "critical"].index(s) if s in ["low", "medium", "high", "critical"] else 0 | |
| ) | |
| if max_severity in ["high", "critical"]: | |
| hints["skip_agent_reasoning"] = True | |
| hints["critical_violation"] = True | |
| return hints | |
def step(tool, input_data):
    """Wrap a tool name and its input payload as one plan step."""
    plan_entry = dict(tool=tool, input=input_data)
    return plan_entry
def _multi_step(steps, reason):
    """Package *steps* and a human-readable *reason* into an AgentDecision.

    The decision's action is always "multi_step", with the plan carried in
    ``tool_input["steps"]``.
    """
    # NOTE(review): imported inside the function — presumably to avoid a
    # circular import at module load time; confirm before moving to top level.
    from ..models.agent import AgentDecision

    decision = AgentDecision(
        action="multi_step",
        tool=None,
        tool_input={"steps": steps},
        reason=reason,
    )
    return decision