# =============================================================
# File: backend/api/services/agent_orchestrator.py
# =============================================================
"""
Agent Orchestrator (integrated with the enterprise RedFlagDetector).
Place at: backend/api/services/agent_orchestrator.py
"""
from __future__ import annotations

import asyncio
import json
import logging
import os
import random
import re
import time
from typing import List, Dict, Any, Optional

from dotenv import load_dotenv

from ..models.agent import AgentRequest, AgentDecision, AgentResponse
from ..models.redflag import RedFlagMatch
from ..mcp_clients.mcp_client import MCPClient
from ..storage.analytics_store import AnalyticsStore
from .redflag_detector import RedFlagDetector
from .intent_classifier import IntentClassifier
from .tool_selector import ToolSelector
from .llm_client import LLMClient
from .tool_scoring import ToolScoringService
from .result_merger import merge_parallel_results, format_merged_context_for_prompt
from .tool_metadata import validate_tool_output, get_tool_schema
from .query_cache import get_cache
from .query_expander import QueryExpander
from .context_engineer import ContextEngineer

logger = logging.getLogger(__name__)

load_dotenv()

class AgentOrchestrator:
    def __init__(self, rag_mcp_url: str, web_mcp_url: str, admin_mcp_url: str):
        self.mcp = MCPClient(rag_mcp_url, web_mcp_url, admin_mcp_url)
        # Groq-only LLM client
        self.llm = LLMClient(api_key=os.getenv("GROQ_API_KEY"), model=os.getenv("GROQ_MODEL"))
        # Pass admin_mcp_url so the detector can call back to the admin MCP
        self.redflag = RedFlagDetector(
            supabase_url=os.getenv("SUPABASE_URL"),
            supabase_key=os.getenv("SUPABASE_SERVICE_KEY"),
            admin_mcp_url=admin_mcp_url,
        )
        self.intent = IntentClassifier(llm_client=self.llm)
        self.selector = ToolSelector(llm_client=self.llm)
        self.tool_scorer = ToolScoringService()
        self.query_expander = QueryExpander(llm_client=self.llm)
        self.cache = get_cache()
        self.context_engineer = ContextEngineer(llm_client=self.llm)
        self._analytics: Optional[AnalyticsStore] = None
        self._analytics_disabled = os.getenv("ANALYTICS_DISABLED", "").lower() in {"1", "true", "yes"}
        self._analytics_failed = False
        self._log_analytics_backend_once()
    def _log_analytics_backend_once(self) -> None:
        if getattr(AgentOrchestrator, "_analytics_backend_logged", False):
            return
        if self._analytics_disabled:
            logger.info("Analytics: Disabled via ANALYTICS_DISABLED")
        else:
            store = self._get_analytics()
            if store is None:
                # Warn only when credentials are configured (initialization
                # genuinely failed); stay quiet when Supabase is simply absent.
                if os.getenv("SUPABASE_URL") and os.getenv("SUPABASE_SERVICE_KEY"):
                    logger.warning("Analytics: Disabled (Supabase initialization failed)")
                else:
                    logger.debug("Analytics: Disabled (Supabase not configured)")
            elif store.use_supabase:
                logger.info("Analytics: Using Supabase backend")
            else:
                logger.warning("Analytics: Using fallback backend")
        AgentOrchestrator._analytics_backend_logged = True

    def _get_analytics(self) -> Optional[AnalyticsStore]:
        if self._analytics_disabled or self._analytics_failed:
            return None
        if self._analytics is not None:
            return self._analytics
        try:
            self._analytics = AnalyticsStore()
        except RuntimeError as exc:
            # Warn only when credentials are configured (an actual error);
            # otherwise log at debug level, since failure is expected when
            # Supabase is not configured.
            if os.getenv("SUPABASE_URL") and os.getenv("SUPABASE_SERVICE_KEY"):
                logger.warning("Analytics disabled: %s", str(exc).split('\n')[0])  # first line only
            else:
                logger.debug("Analytics disabled: %s", str(exc).split('\n')[0])
            self._analytics_failed = True
            self._analytics = None
        except Exception as exc:  # pragma: no cover - unexpected initialization failures
            logger.debug("Analytics unexpected init failure: %s", exc)
            self._analytics_failed = True
            self._analytics = None
        return self._analytics

    def _analytics_log_tool_usage(self, **kwargs: Any) -> None:
        analytics = self._get_analytics()
        if not analytics:
            return
        try:
            analytics.log_tool_usage(**kwargs)
        except Exception as exc:  # pragma: no cover - analytics failures should not break flow
            logger.debug("AgentOrchestrator tool analytics failed: %s", exc)

    def _analytics_log_agent_query(self, **kwargs: Any) -> None:
        analytics = self._get_analytics()
        if not analytics:
            return
        try:
            analytics.log_agent_query(**kwargs)
        except Exception as exc:  # pragma: no cover
            logger.debug("AgentOrchestrator agent query analytics failed: %s", exc)

    def _analytics_log_rag_search(self, **kwargs: Any) -> None:
        analytics = self._get_analytics()
        if not analytics:
            return
        try:
            analytics.log_rag_search(**kwargs)
        except Exception as exc:  # pragma: no cover
            logger.debug("AgentOrchestrator RAG analytics failed: %s", exc)

    def _analytics_log_redflag_violation(self, **kwargs: Any) -> None:
        analytics = self._get_analytics()
        if not analytics:
            return
        try:
            analytics.log_redflag_violation(**kwargs)
        except Exception as exc:  # pragma: no cover
            logger.debug("AgentOrchestrator redflag analytics failed: %s", exc)

    def _cache_response(self, req: AgentRequest, response: AgentResponse, skip_cache: bool = False):
        """Cache a response if appropriate."""
        if skip_cache or req.message.startswith("admin:") or len(req.message) < 3:
            return
        try:
            self.cache.set(req.message, req.tenant_id, {
                "text": response.text,
                "decision": response.decision.dict() if response.decision else None,
                "tool_traces": response.tool_traces,
                "reasoning_trace": response.reasoning_trace,
            })
        except Exception as e:
            logger.debug(f"Failed to cache response: {e}")

    async def handle(self, req: AgentRequest) -> AgentResponse:
        start_time = time.time()
        reasoning_trace: List[Dict[str, Any]] = []
        reasoning_trace.append({
            "step": "request_received",
            "tenant_id": req.tenant_id,
            "user_id": req.user_id,
            "message_preview": req.message[:120],
        })
        # Context engineering: write the incoming query to the scratchpad
        self.context_engineer.write_to_scratchpad(
            f"User query: {req.message[:200]}",
            category="user_query",
        )
        # Check the cache first (skipped for admin queries and rule checks)
        cached_response = self.cache.get(req.message, req.tenant_id)
        if cached_response:
            reasoning_trace.append({
                "step": "cache_hit",
                "cached": True,
            })
            return AgentResponse(
                text=cached_response.get("text", ""),
                decision=cached_response.get("decision"),
                tool_traces=cached_response.get("tool_traces", []),
                reasoning_trace=reasoning_trace + cached_response.get("reasoning_trace", []),
            )
        # 1) FIRST: check admin rules - if any rule matches, respond according to the rule
        matches: List[RedFlagMatch] = await self.redflag.check(req.tenant_id, req.message)
        reasoning_trace.append({
            "step": "admin_rules_check",
            "match_count": len(matches),
            "matches": [m.__dict__ for m in matches],
        })
        if matches:
            # Log all rule matches
            for match in matches:
                self._analytics_log_redflag_violation(
                    tenant_id=req.tenant_id,
                    rule_id=match.rule_id,
                    rule_pattern=match.pattern,
                    severity=match.severity,
                    matched_text=match.matched_text,
                    confidence=match.confidence,
                    message_preview=req.message[:200],
                    user_id=req.user_id,
                )
            # Categorize rules: brief-response rules vs blocking rules.
            # Plain keywords use substring matching; the wildcard patterns are
            # regexes and must go through re.search (a substring test would
            # never match them).
            brief_response_rules = []
            blocking_rules = []
            brief_keywords = ("greeting", "brief", "simple response")
            brief_patterns = (r"keep.*response.*brief", r"do not.*verbose", r"respond.*briefly")
            for match in matches:
                rule_text = (match.description or match.pattern or "").lower()
                is_brief_rule = match.severity == "low" and (
                    any(k in rule_text for k in brief_keywords)
                    or any(re.search(p, rule_text) for p in brief_patterns)
                )
                if is_brief_rule:
                    brief_response_rules.append(match)
                else:
                    blocking_rules.append(match)
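            # Illustrative classifications (hedged, hypothetical rule texts):
            # a low-severity rule described as "Respond briefly to greetings"
            # would land in brief_response_rules, while "Block requests for
            # customer PII" (any severity) would land in blocking_rules.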
            # Handle brief-response rules (greetings, etc.) - return immediately
            if brief_response_rules and not blocking_rules:
                # Return a brief response without proceeding to the normal flow
                brief_responses = [
                    "Hello! How can I help you today?",
                    "Hi there! What can I assist you with?",
                    "Hello! I'm here to help. What do you need?",
                    "Hi! How can I assist you?",
                ]
                brief_response = random.choice(brief_responses)
                reasoning_trace.append({
                    "step": "brief_response_rule_matched",
                    "action": "brief_response",
                    "matched_rules": [m.rule_id for m in brief_response_rules],
                    "message": "Brief response rule matched, returning brief response (skipping normal flow)",
                })
                total_latency_ms = int((time.time() - start_time) * 1000)
                self._analytics_log_agent_query(
                    tenant_id=req.tenant_id,
                    message_preview=req.message[:200],
                    intent="greeting",
                    tools_used=[],
                    total_tokens=len(brief_response) // 4,
                    total_latency_ms=total_latency_ms,
                    success=True,
                    user_id=req.user_id,
                )
                return AgentResponse(
                    text=brief_response,
                    decision=AgentDecision(action="respond", tool=None, tool_input=None, reason="brief_response_rule"),
                    reasoning_trace=reasoning_trace,
                )
            # Handle blocking rules (security, compliance, etc.) - block and return immediately
            if blocking_rules:
                # Notify the admin asynchronously
                try:
                    await self.redflag.notify_admin(
                        req.tenant_id,
                        blocking_rules,
                        source_payload={"message": req.message, "user_id": req.user_id},
                    )
                except Exception:
                    pass
                decision = AgentDecision(
                    action="block",
                    tool="admin",
                    tool_input={"violations": [m.__dict__ for m in blocking_rules]},
                    reason="admin_rule_violation",
                )
                # Build a detailed prompt so the LLM can generate a natural red-flag response
                violations_details = []
                for i, m in enumerate(blocking_rules, 1):
                    rule_name = m.description or m.pattern or "Policy violation"
                    detail = f"{i}. **{rule_name}** (Severity: {m.severity})"
                    if m.matched_text:
                        detail += f"\n   - Detected phrase: \"{m.matched_text}\""
                    violations_details.append(detail)
                llm_prompt = f"""A user made the following request: "{req.message}"

However, this request violates company policies. The following policy violations were detected:

{chr(10).join(violations_details)}

Your task: Write a clear, professional, and empathetic response to inform the user that:
1. Their request cannot be processed due to policy violations
2. Which specific policy was violated (mention it naturally)
3. The incident has been logged for security review
4. They should contact an administrator if they need assistance or believe this is an error

Write a natural, conversational response (2-4 sentences) that feels helpful rather than robotic. Be professional but understanding.

Response:"""
                # Generate the LLM response for the red flag
                try:
                    # Slightly higher temperature for a more natural response
                    llm_response = await self.llm.simple_call(llm_prompt, temperature=min(req.temperature + 0.2, 0.7))
                    llm_response = llm_response.strip()
                    # Add a warning emoji if not already present
                    if not llm_response.startswith("⚠️") and not llm_response.startswith("🚨"):
                        llm_response = f"⚠️ {llm_response}"
                except Exception:
                    # Fall back to a simple message if the LLM fails; summarize
                    # only the blocking rules, not the brief-response ones
                    summary = "; ".join(m.description or m.pattern for m in blocking_rules)
                    llm_response = (
                        f"⚠️ I'm unable to process your request because it violates our company policy: {summary}. "
                        f"This incident has been logged. Please contact your administrator if you need assistance."
                    )
                total_latency_ms = int((time.time() - start_time) * 1000)
                # Log LLM usage for the red-flag response
                estimated_tokens = len(llm_response) // 4 + len(llm_prompt) // 4
                self._analytics_log_tool_usage(
                    tenant_id=req.tenant_id,
                    tool_name="llm",
                    latency_ms=total_latency_ms,
                    tokens_used=estimated_tokens,
                    success=True,
                    user_id=req.user_id,
                )
                self._analytics_log_agent_query(
                    tenant_id=req.tenant_id,
                    message_preview=req.message[:200],
                    intent="admin",
                    tools_used=["admin", "llm"],
                    total_tokens=estimated_tokens,
                    total_latency_ms=total_latency_ms,
                    success=False,
                    user_id=req.user_id,
                )
                response = AgentResponse(
                    text=llm_response,
                    decision=decision,
                    tool_traces=[{"redflags": [m.__dict__ for m in blocking_rules]}],
                    reasoning_trace=reasoning_trace,
                )
                # Do not cache admin rule violations
                return response
        # 2) ONLY IF NO RULES MATCHED: proceed with the normal flow (intent classification, RAG, etc.)
        # 2.1) Optional: the message could be rewritten here if it might violate rules
        # (preventive self-correction). This is a lighter check - rule matches were
        # already blocked above. It exists for edge cases where we want to
        # proactively improve the message.
        safe_message = req.message  # default to the original message
        intent = await self.intent.classify(req.message)
        reasoning_trace.append({
            "step": "intent_detection",
            "intent": intent,
        })
        # 2.5) Pre-fetch RAG results if available (for tool-selector context).
        # BUT: skip the RAG pre-fetch for news/current-events queries (they need
        # web search, not RAG).
        rag_prefetch = None
        rag_results = []
        # Detect news queries early to skip the RAG pre-fetch. Detection is
        # deliberately aggressive - check for the "news" keyword first.
        msg_lower = req.message.lower().strip()
        # Primary signal: if "news" is in the message, it is almost certainly a news query
        has_news_keyword = "news" in msg_lower
        # Exclude common phrases that contain "news" but are not news queries
        non_news_phrases = [
            "what is", "what's", "explain", "tell me about", "define",
            "how does", "how do", "what are", "what does", "what can",
        ]
        is_general_question = any(phrase in msg_lower for phrase in non_news_phrases)
        freshness_keywords = ["latest", "today", "current", "recent",
                              "now", "updates", "breaking", "trending", "happening",
                              "what's new", "what is new", "what happened"]
        news_patterns = [
            r"latest news", r"current news", r"today's news", r"breaking news",
            r"news about", r"news on", r"news of", r"what's happening",
            r"what happened", r"recent news", r"news update",
        ]
        # If the "news" keyword is present AND it is not a general question, treat it
        # as a news query; otherwise check the other freshness indicators.
        is_news_query = (has_news_keyword and not is_general_question) or \
            (any(k in msg_lower for k in freshness_keywords) and not is_general_question) or \
            any(re.search(p, msg_lower) for p in news_patterns)
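        # Illustrative classifications (hedged, hypothetical queries):
        #   "latest news about AI"         -> news query (keyword + pattern)
        #   "breaking updates on the EU"   -> news query (freshness keywords)
        #   "what is news aggregation"     -> general question ("what is" excluded)
        #   "explain current transformers" -> general question ("explain" excluded)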
        # LLM-based detection for edge cases where keyword detection is
        # uncertain: only consult the LLM for short queries
        if not is_news_query and len(msg_lower.split()) <= 5 and not is_general_question:
            try:
                llm_check_prompt = f"""Is the following query asking for current news or recent events? Answer only "yes" or "no".

Query: "{req.message}"

Answer:"""
                llm_response = await self.llm.simple_call(llm_check_prompt, temperature=0.0)
                if "yes" in llm_response.lower():
                    is_news_query = True
                    reasoning_trace.append({
                        "step": "news_query_detection_llm",
                        "detected": True,
                        "llm_confirmed": True,
                    })
            except Exception as e:
                logger.debug(f"LLM news detection failed: {e}")
        # Log the detection for debugging
        if is_news_query:
            reasoning_trace.append({
                "step": "news_query_detection",
                "detected": True,
                "message": req.message,
                "has_news_keyword": has_news_keyword,
                "matched_keywords": [k for k in freshness_keywords if k in msg_lower],
            })
        # Only pre-fetch RAG if this is NOT a news query
        if not is_news_query:
            try:
                # Pre-fetch RAG to help the tool selector make better decisions
                rag_start = time.time()
                rag_prefetch = await self.mcp.call_rag(req.tenant_id, req.message)
                rag_latency_ms = int((time.time() - rag_start) * 1000)
                if isinstance(rag_prefetch, dict):
                    rag_results = rag_prefetch.get("results") or rag_prefetch.get("hits") or []
                # Log the RAG search event
                hits_count = len(rag_results)
                avg_score = None
                top_score = None
                if rag_results:
                    scores = [h.get("score", 0.0) for h in rag_results if isinstance(h, dict) and "score" in h]
                    if scores:
                        avg_score = sum(scores) / len(scores)
                        top_score = max(scores)
                self._analytics_log_rag_search(
                    tenant_id=req.tenant_id,
                    query=req.message[:500],
                    hits_count=hits_count,
                    avg_score=avg_score,
                    top_score=top_score,
                    latency_ms=rag_latency_ms,
                )
                # Log tool usage
                self._analytics_log_tool_usage(
                    tenant_id=req.tenant_id,
                    tool_name="rag",
                    latency_ms=rag_latency_ms,
                    success=True,
                    user_id=req.user_id,
                )
                reasoning_trace.append({
                    "step": "rag_prefetch",
                    "status": "ok",
                    "hit_count": len(rag_results),
                    "latency_ms": rag_latency_ms,
                })
            except Exception as pref_err:
                # If RAG fails, continue without it
                rag_latency_ms = 0  # 0 indicates failure
                self._analytics_log_tool_usage(
                    tenant_id=req.tenant_id,
                    tool_name="rag",
                    latency_ms=rag_latency_ms,
                    success=False,
                    error_message=str(pref_err)[:200],
                    user_id=req.user_id,
                )
                reasoning_trace.append({
                    "step": "rag_prefetch",
                    "status": "error",
                    "error": str(pref_err),
                })
                rag_prefetch = None
        else:
            # News query detected - skip the RAG pre-fetch
            reasoning_trace.append({
                "step": "rag_prefetch",
                "status": "skipped",
                "reason": "news_query_detected",
            })
        tool_scores = self.tool_scorer.score(req.message, intent, rag_results)
        reasoning_trace.append({
            "step": "tool_scoring",
            "scores": tool_scores,
        })
        # 3) Tool selection (hybrid) - pass RAG results, memory, and admin violations in context.
        # Context engineering: compress the conversation history if it is too long
        # (Anthropic-style compaction). Tool-result clearing happens first (safest),
        # then full compaction if still needed.
        if req.conversation_history and len(req.conversation_history) > 10:
            # Check token usage
            total_chars = sum(len(str(m.get("content", ""))) for m in req.conversation_history)
            estimated_tokens = total_chars // 4
            # Compress when approaching the context limit (80% threshold)
            if estimated_tokens > 8000:  # ~80% of a typical 10k-token context
                # Capture the original length before reassigning the history,
                # so the trace reports the true pre-compression size
                original_length = len(req.conversation_history)
                compressed_history = await self.context_engineer.compress_if_needed(
                    req.conversation_history,
                    max_tokens=6000,  # target ~60% utilisation after compression
                    use_compaction=True,
                )
                req.conversation_history = compressed_history
                reasoning_trace.append({
                    "step": "context_compaction",
                    "original_length": original_length,
                    "compressed_length": len(compressed_history),
                    "compressed": len(compressed_history) < original_length,
                    "strategy": "anthropic_compaction",
                })
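        # Worked example of the heuristic above (hedged; the 4-chars-per-token
        # ratio is a rough approximation for English text, not a real
        # tokenizer): a 15-message history totalling 40,000 characters
        # estimates to 40,000 // 4 = 10,000 tokens, which exceeds the
        # 8,000-token (80%) threshold, so it is compacted toward the
        # 6,000-token target.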
        # Get recent memory for context-aware routing
        from backend.mcp_server.common.memory import get_recent

        session_id = req.conversation_history[-1].get("session_id") if req.conversation_history else None
        recent_memory = []
        if session_id:
            recent_memory = get_recent(session_id)
        # Context engineering: select the memories relevant to this query
        if recent_memory:
            selected_memories = await self.context_engineer.select_context(
                req.message,
                {"memories": recent_memory},
            )
            recent_memory = selected_memories.get("memories", recent_memory)
        # Get admin violations, if any
        admin_violations = []
        if hasattr(self, "redflag") and self.redflag:
            # Violations are detected during the red-flag check earlier in the
            # flow and handled separately above.
            pass
        # FORCE web search for news queries - bypass the tool selector entirely.
        # Also ensure rag_results is empty for news queries (double-check).
        if is_news_query:
            rag_results = []  # force empty - no RAG results for news queries
            # Enhance the query for better web-search results
            web_query = req.message
            # Handle ambiguous short queries like "latest news about Al" or
            # "atest news about Al" (user typos) by expanding common interpretations
            query_words = web_query.lower().split()
            if len(query_words) <= 4:
                # Extract the topic (the word after "about", or the last word)
                topic = None
                if "about" in query_words:
                    about_idx = query_words.index("about")
                    if about_idx + 1 < len(query_words):
                        topic = query_words[about_idx + 1]
                elif len(query_words) >= 2:
                    # The last word is probably the topic
                    topic = query_words[-1]
                # A very short topic (1-2 letters) is likely ambiguous - expand it
                if topic and len(topic) <= 2:
                    # Common expansions for "Al"/"AI"
                    if topic == "al":
                        web_query = f"{' '.join(query_words[:-1])} artificial intelligence AI"
                    elif topic == "ai":
                        web_query = f"{' '.join(query_words[:-1])} artificial intelligence"
                # If the query is still short, add the "news" keyword if it is missing
                if "news" not in web_query.lower() and len(web_query.split()) <= 3:
                    web_query = f"{web_query} news latest"
            decision = AgentDecision(
                action="call_tool",
                tool="web",
                tool_input={"query": web_query},
                reason=f"news_query_forced_web_search (original: {req.message})",
            )
            reasoning_trace.append({
                "step": "tool_selection",
                "decision": decision.dict(),
                "note": "news_query_bypassed_selector_forced_web",
                "rag_results_forced_empty": True,
                "web_query": web_query,
            })
        else:
            ctx = {
                "tenant_id": req.tenant_id,
                "rag_results": rag_results,
                "tool_scores": tool_scores,
                "memory": recent_memory,  # context-aware routing: recent tool outputs
                "admin_violations": admin_violations,  # context-aware routing: admin rule severity
            }
            decision = await self.selector.select(intent, req.message, ctx)
            reasoning_trace.append({
                "step": "tool_selection",
                "decision": decision.dict(),
                "context_scores": tool_scores,
            })
        tool_traces: List[Dict[str, Any]] = []
        # 4) Handle multi-step tool execution
        if decision.action == "multi_step" and decision.tool_input:
            steps = decision.tool_input.get("steps", [])
            if steps:
                return await self._execute_multi_step(
                    req,
                    steps,
                    decision,
                    tool_traces,
                    reasoning_trace,
                    rag_prefetch,
                )
        # 5) Execute a single tool
        tools_used = []
        total_tokens = 0
        if decision.action == "call_tool" and decision.tool:
            try:
                if decision.tool == "rag":
                    # Use autonomous retry with self-correction; fall back to
                    # the raw message if the selector supplied no query
                    rag_query = (decision.tool_input or {}).get("query") or req.message
                    rag_start = time.time()
                    rag_resp = await self.rag_with_repair(
                        query=rag_query,
                        tenant_id=req.tenant_id,
                        original_threshold=0.3,
                        reasoning_trace=reasoning_trace,
                        user_id=req.user_id,
                    )
                    rag_latency_ms = int((time.time() - rag_start) * 1000)
                    tools_used.append("rag")
                    # Validate and format the RAG output to conform to the schema
                    rag_formatted = self._format_tool_output("rag", rag_resp, rag_latency_ms)
                    # Context engineering: compress the tool output if needed
                    rag_formatted = await self.context_engineer.compressor.compress_tool_output("rag", rag_formatted)
                    tool_traces.append({"tool": "rag", "response": rag_formatted})
                    hits = self._extract_hits(rag_formatted)
                    # Collect scores for logging
                    hits_count = len(hits)
                    avg_score = rag_formatted.get("avg_score")
                    top_score = rag_formatted.get("top_score")
                    reasoning_trace.append({
                        "step": "tool_execution",
                        "tool": "rag",
                        "hit_count": hits_count,
                        "top_score": top_score,
                        "avg_score": avg_score,
                        "summary": self._summarize_hits(rag_formatted, limit=2),
                    })
                    prompt = self._build_prompt_with_rag(req, rag_formatted)
                    llm_start = time.time()
                    llm_out = await self.llm.simple_call(prompt, temperature=req.temperature)
                    llm_latency_ms = int((time.time() - llm_start) * 1000)
                    tools_used.append("llm")
                    # Estimate tokens (rough heuristic: ~4 characters per token)
                    estimated_tokens = len(llm_out) // 4 + len(prompt) // 4
                    total_tokens += estimated_tokens
                    self._analytics_log_tool_usage(
                        tenant_id=req.tenant_id,
                        tool_name="llm",
                        latency_ms=llm_latency_ms,
                        tokens_used=estimated_tokens,
                        success=True,
                        user_id=req.user_id,
                    )
                    reasoning_trace.append({
                        "step": "llm_response",
                        "mode": "rag_synthesis",
                        "latency_ms": llm_latency_ms,
                        "estimated_tokens": estimated_tokens,
                    })
                    total_latency_ms = int((time.time() - start_time) * 1000)
                    self._analytics_log_agent_query(
                        tenant_id=req.tenant_id,
                        message_preview=req.message[:200],
                        intent=intent,
                        tools_used=tools_used,
                        total_tokens=total_tokens,
                        total_latency_ms=total_latency_ms,
                        success=True,
                        user_id=req.user_id,
                    )
                    return AgentResponse(text=llm_out, decision=decision, tool_traces=tool_traces, reasoning_trace=reasoning_trace)
| if decision.tool == "web": | |
| # CRITICAL: For news queries, ensure RAG results are NEVER used | |
| msg_check_web = req.message.lower() | |
| is_news_web = "news" in msg_check_web or any(k in msg_check_web for k in ["latest", "breaking", "current", "recent", "today"]) | |
| if is_news_web: | |
| # Force clear any RAG context - news queries should NEVER use RAG | |
| rag_results = [] | |
| reasoning_trace.append({ | |
| "step": "web_tool_execution", | |
| "note": "news_query_confirmed_rag_results_cleared_before_web_search" | |
| }) | |
| # Use autonomous retry with query rewriting | |
| web_query = decision.tool_input.get("query") if decision.tool_input else req.message | |
| web_start = time.time() | |
| web_resp = await self.web_with_repair( | |
| query=web_query, | |
| tenant_id=req.tenant_id, | |
| reasoning_trace=reasoning_trace, | |
| user_id=req.user_id | |
| ) | |
| web_latency_ms = int((time.time() - web_start) * 1000) | |
| tools_used.append("web") | |
| # Validate and format Web output to conform to schema | |
| web_formatted = self._format_tool_output("web", web_resp, web_latency_ms) | |
| # Context Engineering: Compress tool output if needed | |
| web_formatted = await self.context_engineer.compressor.compress_tool_output("web", web_formatted) | |
| tool_traces.append({"tool": "web", "response": web_formatted}) | |
| hits_count = len(self._extract_hits(web_formatted)) | |
| reasoning_trace.append({ | |
| "step": "tool_execution", | |
| "tool": "web", | |
| "hit_count": hits_count, | |
| "summary": self._summarize_hits(web_formatted, limit=2), | |
| "is_news_query": is_news_web | |
| }) | |
| # ALWAYS use web prompt builder for web search results | |
| # Never use RAG prompt builder, even if web results are empty | |
| if hits_count == 0 and is_news_web: | |
| # Empty web results for news query - provide helpful guidance | |
| prompt = ( | |
| f"You are an assistant helping tenant {req.tenant_id}.\n\n" | |
| f"## User Question\n{req.message}\n\n" | |
| f"## Context\n" | |
| f"I searched for the latest news about this topic, but didn't find specific recent results in my web search.\n\n" | |
| f"## Your Task\n" | |
| f"Provide helpful information about what the user might be looking for. " | |
| f"If you have general knowledge about the topic, share it. " | |
| f"Be honest that I don't have access to the very latest breaking news right now, but provide what context you can. " | |
| f"Suggest that the user try:\n" | |
| f"- Checking major news websites directly (BBC, CNN, Reuters, etc.)\n" | |
| f"- Trying a more specific search query\n" | |
| f"- Using a news aggregator service\n\n" | |
| f"IMPORTANT: Do NOT say 'There is no mention of X in the provided context' - instead provide helpful general information or suggest where to find current news.\n\n" | |
| f"Provide a helpful response now:" | |
| ) | |
| else: | |
| # Use web prompt builder (never RAG) | |
| prompt = self._build_prompt_with_web(req, web_formatted) | |
| llm_start = time.time() | |
| llm_out = await self.llm.simple_call(prompt, temperature=req.temperature) | |
| llm_latency_ms = int((time.time() - llm_start) * 1000) | |
| tools_used.append("llm") | |
| estimated_tokens = len(llm_out) // 4 + len(prompt) // 4 | |
| total_tokens += estimated_tokens | |
| self._analytics_log_tool_usage( | |
| tenant_id=req.tenant_id, | |
| tool_name="llm", | |
| latency_ms=llm_latency_ms, | |
| tokens_used=estimated_tokens, | |
| success=True, | |
| user_id=req.user_id | |
| ) | |
| reasoning_trace.append({ | |
| "step": "llm_response", | |
| "mode": "web_synthesis", | |
| "latency_ms": llm_latency_ms, | |
| "estimated_tokens": estimated_tokens | |
| }) | |
| total_latency_ms = int((time.time() - start_time) * 1000) | |
| self._analytics_log_agent_query( | |
| tenant_id=req.tenant_id, | |
| message_preview=req.message[:200], | |
| intent=intent, | |
| tools_used=tools_used, | |
| total_tokens=total_tokens, | |
| total_latency_ms=total_latency_ms, | |
| success=True, | |
| user_id=req.user_id | |
| ) | |
| return AgentResponse(text=llm_out, decision=decision, tool_traces=tool_traces, reasoning_trace=reasoning_trace) | |
| if decision.tool == "admin": | |
| admin_start = time.time() | |
| admin_resp = await self.mcp.call_admin(req.tenant_id, decision.tool_input.get("query") if decision.tool_input else req.message) | |
| admin_latency_ms = int((time.time() - admin_start) * 1000) | |
| tools_used.append("admin") | |
| self._analytics_log_tool_usage( | |
| tenant_id=req.tenant_id, | |
| tool_name="admin", | |
| latency_ms=admin_latency_ms, | |
| success=True, | |
| user_id=req.user_id | |
| ) | |
| # Validate and format Admin output to conform to schema | |
| admin_formatted = self._format_tool_output("admin", admin_resp, admin_latency_ms) | |
| tool_traces.append({"tool": "admin", "response": admin_formatted}) | |
| reasoning_trace.append({ | |
| "step": "tool_execution", | |
| "tool": "admin", | |
| "status": "completed", | |
| "latency_ms": admin_latency_ms | |
| }) | |
| total_latency_ms = int((time.time() - start_time) * 1000) | |
| self._analytics_log_agent_query( | |
| tenant_id=req.tenant_id, | |
| message_preview=req.message[:200], | |
| intent=intent, | |
| tools_used=tools_used, | |
| total_tokens=0, | |
| total_latency_ms=total_latency_ms, | |
| success=True, | |
| user_id=req.user_id | |
| ) | |
| return AgentResponse(text=json.dumps(admin_resp), decision=decision, tool_traces=tool_traces, reasoning_trace=reasoning_trace) | |
| if decision.tool == "llm": | |
| # Check if this is a news query - if so, force web search instead | |
| msg_lower_llm = req.message.lower() | |
| freshness_keywords_llm = ["latest", "today", "news", "current", "recent", | |
| "now", "updates", "breaking", "trending", "happening"] | |
| news_patterns_llm = [ | |
| r"latest news", r"current news", r"today's news", r"breaking news", | |
| r"news about", r"news on", r"news of" | |
| ] | |
| is_news_query_llm = any(k in msg_lower_llm for k in freshness_keywords_llm) or \ | |
| any(re.search(p, msg_lower_llm) for p in news_patterns_llm) | |
| # Force web search for news queries even if tool selector chose "llm" | |
| if is_news_query_llm: | |
| try: | |
| web_query = req.message | |
| if len(web_query.split()) <= 4: | |
| if "news" not in msg_lower_llm: | |
| web_query = f"{web_query} news latest" | |
| web_start = time.time() | |
| web_resp = await self.web_with_repair( | |
| query=web_query, | |
| tenant_id=req.tenant_id, | |
| reasoning_trace=reasoning_trace, | |
| user_id=req.user_id | |
| ) | |
| web_latency_ms = int((time.time() - web_start) * 1000) | |
| tools_used.append("web") | |
| web_formatted = self._format_tool_output("web", web_resp, web_latency_ms) | |
| # Context Engineering: Compress tool output if needed | |
| web_formatted = await self.context_engineer.compressor.compress_tool_output("web", web_formatted) | |
| tool_traces.append({"tool": "web", "response": web_formatted}) | |
| hits_count = len(self._extract_hits(web_formatted)) | |
| reasoning_trace.append({ | |
| "step": "tool_execution", | |
| "tool": "web", | |
| "hit_count": hits_count, | |
| "note": "forced_web_for_news_in_llm_path" | |
| }) | |
| if hits_count == 0: | |
| prompt_for_llm = ( | |
| f"You are an assistant helping tenant {req.tenant_id}.\n\n" | |
| f"## User Question\n{req.message}\n\n" | |
| f"## Context\n" | |
| f"I attempted to search for the latest news about this topic, but didn't find specific recent results.\n\n" | |
| f"## Your Task\n" | |
| f"Provide helpful information about what the user might be looking for. " | |
| f"If you have general knowledge about the topic, share it. " | |
| f"Be honest that you don't have access to the very latest breaking news, but provide what context you can. " | |
| f"Suggest that the user try checking major news websites directly or using a more specific search query.\n\n" | |
| f"Provide a helpful response now:" | |
| ) | |
| else: | |
| prompt_for_llm = self._build_prompt_with_web(req, web_formatted) | |
| llm_start = time.time() | |
| llm_out = await self.llm.simple_call(prompt_for_llm, temperature=req.temperature) | |
| llm_latency_ms = int((time.time() - llm_start) * 1000) | |
| tools_used.append("llm") | |
| estimated_tokens = len(llm_out) // 4 + len(prompt_for_llm) // 4 | |
| total_tokens += estimated_tokens | |
| self._analytics_log_tool_usage( | |
| tenant_id=req.tenant_id, | |
| tool_name="llm", | |
| latency_ms=llm_latency_ms, | |
| tokens_used=estimated_tokens, | |
| success=True, | |
| user_id=req.user_id | |
| ) | |
| total_latency_ms = int((time.time() - start_time) * 1000) | |
| self._analytics_log_agent_query( | |
| tenant_id=req.tenant_id, | |
| message_preview=req.message[:200], | |
| intent=intent, | |
| tools_used=tools_used, | |
| total_tokens=total_tokens, | |
| total_latency_ms=total_latency_ms, | |
| success=True, | |
| user_id=req.user_id | |
| ) | |
| return AgentResponse(text=llm_out, decision=decision, tool_traces=tool_traces, reasoning_trace=reasoning_trace) | |
| except Exception as web_err: | |
| reasoning_trace.append({ | |
| "step": "web_search_forced_failed", | |
| "error": str(web_err)[:200] | |
| }) | |
| # Fall through to normal LLM path | |
                    # If the user is asking who the admin / owner is, try to ground the
                    # answer in tenant-specific RAG before falling back to a generic LLM reply.
                    user_text = req.message.lower()
                    # Normalize whitespace to make matching more robust
                    user_text_normalized = " ".join(user_text.split())
                    admin_phrases = [
                        "who is the admin",
                        "who's the admin",
                        "who is admin",
                        "who is the administrator",
                        "who's the administrator",
                        "who administers this platform",
                        "who administers the platform",
                        "who is the owner",
                        "who's the owner",
                        "who owns this platform",
                        "who owns the platform",
                        "who is the admin of integrachat",
                        "who's the admin of integrachat",
                    ]
                    use_rag_for_admin = any(p in user_text_normalized for p in admin_phrases) or (
                        "admin" in user_text and "who" in user_text
                    )
                    prompt_for_llm = req.message
                    if use_rag_for_admin:
                        try:
                            rag_start = time.time()
                            rag_resp = await self.rag_with_repair(
                                query=req.message,
                                tenant_id=req.tenant_id,
                                original_threshold=0.2,
                                reasoning_trace=reasoning_trace,
                                user_id=req.user_id,
                            )
                            rag_latency_ms = int((time.time() - rag_start) * 1000)
                            tools_used.append("rag")
                            rag_formatted = self._format_tool_output("rag", rag_resp, rag_latency_ms)
                            tool_traces.append({"tool": "rag", "response": rag_formatted})
                            hits = self._extract_hits(rag_formatted)
                            hits_count = len(hits)
                            avg_score = rag_formatted.get("avg_score")
                            top_score = rag_formatted.get("top_score")
                            self._analytics_log_tool_usage(
                                tenant_id=req.tenant_id,
                                tool_name="rag",
                                latency_ms=rag_latency_ms,
                                success=True,
                                user_id=req.user_id,
                            )
                            reasoning_trace.append({
                                "step": "tool_execution",
                                "tool": "rag",
                                "hit_count": hits_count,
                                "top_score": top_score,
                                "avg_score": avg_score,
                                "summary": self._summarize_hits(rag_formatted, limit=2),
                                "note": "admin_identity_override",
                            })
                            # For admin questions, answer directly from RAG and avoid any
                            # generic LLM behaviour. If there is at least one hit, return
                            # that snippet; otherwise return an explicit "don't know".
                            if hits:
                                best = hits[0]
                                admin_text = best.get("text") or best.get("content") or str(best)
                                llm_out = f"According to the tenant knowledge base, {admin_text.strip()}"
                            else:
                                llm_out = "I don't know who administers this platform based on the tenant data."
                            llm_latency_ms = 0
                            estimated_tokens = len(llm_out) // 4 + len(req.message) // 4
                            total_tokens += estimated_tokens
                            self._analytics_log_tool_usage(
                                tenant_id=req.tenant_id,
                                tool_name="llm",
                                latency_ms=llm_latency_ms,
                                tokens_used=estimated_tokens,
                                success=True,
                                user_id=req.user_id,
                            )
                            reasoning_trace.append({
                                "step": "llm_response",
                                "mode": "admin_from_rag_only",
                                "latency_ms": llm_latency_ms,
                                "estimated_tokens": estimated_tokens,
                            })
                            total_latency_ms = int((time.time() - start_time) * 1000)
                            self._analytics_log_agent_query(
                                tenant_id=req.tenant_id,
                                message_preview=req.message[:200],
                                intent=intent,
                                tools_used=tools_used,
                                total_tokens=total_tokens,
                                total_latency_ms=total_latency_ms,
                                success=True,
                                user_id=req.user_id,
                            )
                            return AgentResponse(text=llm_out, decision=decision, reasoning_trace=reasoning_trace)
                        except Exception as rag_err:
                            reasoning_trace.append({
                                "step": "rag_for_admin_fallback",
                                "status": "error",
                                "error": str(rag_err),
                            })
                    # For all other questions, if we already have RAG hits from pgvector
                    # (rag_results from the prefetch step), reuse them to ground the
                    # LLM response instead of answering purely from the model.
                    # BUT: skip RAG for news queries (they should use web search instead).
                    is_news_query_here = any(
                        k in req.message.lower()
                        for k in ["latest", "today", "news", "current", "recent",
                                  "breaking", "trending", "happening", "updates"]
                    )
                    news_patterns_here = [
                        r"latest news", r"current news", r"today's news", r"breaking news",
                        r"news about", r"news on", r"news of",
                    ]
                    is_news_query_here = is_news_query_here or any(
                        re.search(p, req.message.lower()) for p in news_patterns_here
                    )
                    # NEVER use RAG for news queries - force web search or use general knowledge
                    if not use_rag_for_admin and rag_results and not is_news_query_here:
                        try:
                            rag_prefetched_dict: Dict[str, Any] = {"results": rag_results}
                            prompt_for_llm = self._build_prompt_with_rag(req, rag_prefetched_dict)
                            reasoning_trace.append({
                                "step": "rag_context_for_llm",
                                "hit_count": len(rag_results),
                                "note": "used_prefetched_pgvector_hits",
                            })
                        except Exception as build_err:
                            reasoning_trace.append({
                                "step": "rag_context_for_llm",
                                "status": "error",
                                "error": str(build_err),
                            })
                    elif not use_rag_for_admin:
                        # No RAG results available - enhance the prompt so the model
                        # still gives its best answer. BUT: for news queries, steer it
                        # toward a helpful message about web search.
                        if is_news_query_here:
                            prompt_for_llm = (
                                f"You are an assistant helping tenant {req.tenant_id}.\n\n"
                                f"## User Question\n{req.message}\n\n"
                                f"## Context\n"
                                f"The user is asking for latest news. I attempted to search for current information but didn't find specific results.\n\n"
                                f"## Your Task\n"
                                f"Provide helpful information about what the user might be looking for. "
                                f"If you have general knowledge about the topic, share it. "
                                f"Be honest that you don't have access to the very latest breaking news, but provide what context you can. "
                                f"Suggest that the user try checking major news websites directly or using a more specific search query.\n\n"
                                f"IMPORTANT: Do NOT say 'There is no mention of X in the provided context' - instead provide helpful general information or suggest where to find current news."
                            )
                        else:
                            prompt_for_llm = (
                                f"You are an assistant helping tenant {req.tenant_id}.\n\n"
                                f"## User Question\n{req.message}\n\n"
                                f"## Your Task\n"
                                f"Provide the best possible answer to the user's question. "
                                f"Be clear, accurate, comprehensive, and helpful. "
                                f"Focus on giving the user exactly what they need: clear guidance, accurate facts, "
                                f"and practical steps whenever possible. "
                                f"If you're uncertain about tenant-specific details, acknowledge that and provide general guidance."
                            )
                    llm_start = time.time()
                    llm_out = await self.llm.simple_call(prompt_for_llm, temperature=req.temperature)
                    llm_latency_ms = int((time.time() - llm_start) * 1000)
                    tools_used.append("llm")
                    estimated_tokens = len(llm_out) // 4 + len(prompt_for_llm) // 4
                    total_tokens += estimated_tokens
                    self._analytics_log_tool_usage(
                        tenant_id=req.tenant_id,
                        tool_name="llm",
                        latency_ms=llm_latency_ms,
                        tokens_used=estimated_tokens,
                        success=True,
                        user_id=req.user_id,
                    )
                    reasoning_trace.append({
                        "step": "llm_response",
                        "mode": "direct",
                        "latency_ms": llm_latency_ms,
                        "estimated_tokens": estimated_tokens,
                    })
                    total_latency_ms = int((time.time() - start_time) * 1000)
                    self._analytics_log_agent_query(
                        tenant_id=req.tenant_id,
                        message_preview=req.message[:200],
                        intent=intent,
                        tools_used=tools_used,
                        total_tokens=total_tokens,
                        total_latency_ms=total_latency_ms,
                        success=True,
                        user_id=req.user_id,
                    )
                    return AgentResponse(text=llm_out, decision=decision, reasoning_trace=reasoning_trace)
            except Exception as e:
                tool_traces.append({"tool": decision.tool, "error": str(e)})
                try:
                    fallback = await self.llm.simple_call(req.message, temperature=req.temperature)
                except Exception as llm_error:
                    error_msg = str(llm_error)
                    if "Groq API key" in error_msg or "GROQ_API_KEY" in error_msg:
                        fallback = (
                            f"I encountered an error while processing your request: {str(e)}\n\n"
                            f"Additionally, the AI service (Groq) is unavailable: {error_msg}\n\n"
                            f"To fix:\n"
                            f"1. Get a free Groq API key from https://console.groq.com\n"
                            f"2. Set GROQ_API_KEY in your .env file or environment variables"
                        )
                    else:
                        fallback = (
                            f"I encountered an error while processing your request: {str(e)}. "
                            f"Additionally, the AI service is unavailable: {error_msg}"
                        )
                return AgentResponse(
                    text=fallback,
                    decision=AgentDecision(action="respond", tool=None, tool_input=None, reason=f"tool_error_fallback: {e}"),
                    tool_traces=tool_traces,
                    reasoning_trace=reasoning_trace + [{
                        "step": "error",
                        "tool": decision.tool,
                        "error": str(e),
                    }],
                )
        # Default: direct LLM response.
        # BUT: for news queries, try web search first even if the tool selector
        # did not route to it.
        msg_lower = req.message.lower()
        freshness_keywords = ["latest", "today", "news", "current", "recent",
                              "now", "updates", "breaking", "trending", "happening"]
        news_patterns = [
            r"latest news", r"current news", r"today's news", r"breaking news",
            r"news about", r"news on", r"news of",
        ]
        is_news_query_default = any(k in msg_lower for k in freshness_keywords) or \
            any(re.search(p, msg_lower) for p in news_patterns)
        # If it is a news query and we are in the default path, force web search
        if is_news_query_default and decision.action not in ("call_tool", "multi_step"):
            try:
                web_query = req.message
                if len(web_query.split()) <= 4:
                    if "news" not in msg_lower:
                        web_query = f"{web_query} news latest"
                web_start = time.time()
                web_resp = await self.web_with_repair(
                    query=web_query,
                    tenant_id=req.tenant_id,
                    reasoning_trace=reasoning_trace,
                    user_id=req.user_id,
                )
                web_latency_ms = int((time.time() - web_start) * 1000)
                tools_used.append("web")
                web_formatted = self._format_tool_output("web", web_resp, web_latency_ms)
                # Context engineering: compress the tool output if needed
                web_formatted = await self.context_engineer.compressor.compress_tool_output("web", web_formatted)
                tool_traces.append({"tool": "web", "response": web_formatted})
                hits_count = len(self._extract_hits(web_formatted))
                if hits_count > 0:
                    prompt = self._build_prompt_with_web(req, web_formatted)
                else:
                    # Web search returned no results - use a news-specific prompt
                    prompt = (
                        f"You are an assistant helping tenant {req.tenant_id}.\n\n"
                        f"## User Question\n{req.message}\n\n"
                        f"## Context\n"
                        f"The user is asking for latest news, but web search did not return specific results for this query.\n\n"
                        f"## Your Task\n"
                        f"Provide helpful information about what the user might be looking for. "
                        f"If you know general information about the topic, share it. "
                        f"Be honest that you don't have access to the very latest news, but provide what context you can. "
                        f"Suggest that the user try rephrasing the query or checking news websites directly for the most current information."
                    )
                llm_start = time.time()
                llm_out = await self.llm.simple_call(prompt, temperature=req.temperature)
                llm_latency_ms = int((time.time() - llm_start) * 1000)
                tools_used.append("llm")
                estimated_tokens = len(llm_out) // 4 + len(prompt) // 4
                self._analytics_log_tool_usage(
                    tenant_id=req.tenant_id,
                    tool_name="llm",
                    latency_ms=llm_latency_ms,
                    tokens_used=estimated_tokens,
                    success=True,
                    user_id=req.user_id,
                )
                total_latency_ms = int((time.time() - start_time) * 1000)
                self._analytics_log_agent_query(
                    tenant_id=req.tenant_id,
                    message_preview=req.message[:200],
                    intent=intent,
                    tools_used=tools_used,
                    total_tokens=estimated_tokens,
                    total_latency_ms=total_latency_ms,
                    success=True,
                    user_id=req.user_id,
                )
                return AgentResponse(
                    text=llm_out,
                    decision=AgentDecision(action="respond", tool="web", tool_input=None, reason="news_query_forced_web_search"),
                    tool_traces=tool_traces,
                    reasoning_trace=reasoning_trace,
                )
            except Exception as web_err:
                # If web search fails, fall through to the default LLM
                reasoning_trace.append({
                    "step": "web_search_fallback",
                    "error": str(web_err)[:200],
                })
        # Track success explicitly: the previous locals()-based guards reported
        # success=True even when only the fallback error text was produced
        llm_success = False
        estimated_tokens = 0
        try:
            llm_start = time.time()
            # For news queries in the default path, use a better prompt
            if is_news_query_default:
                prompt_for_default = (
                    f"You are an assistant helping tenant {req.tenant_id}.\n\n"
                    f"## User Question\n{req.message}\n\n"
                    f"## Your Task\n"
                    f"The user is asking for latest news. I don't have access to real-time web search results right now. "
                    f"Please provide helpful information about what they might be looking for, or suggest they check news websites directly for the most current information."
                )
            else:
                prompt_for_default = req.message
            llm_out = await self.llm.simple_call(prompt_for_default, temperature=req.temperature)
            llm_latency_ms = int((time.time() - llm_start) * 1000)
            tools_used = ["llm"]
            estimated_tokens = len(llm_out) // 4 + len(prompt_for_default) // 4
            llm_success = True
            self._analytics_log_tool_usage(
                tenant_id=req.tenant_id,
                tool_name="llm",
                latency_ms=llm_latency_ms,
                tokens_used=estimated_tokens,
                success=True,
                user_id=req.user_id,
            )
        except Exception as e:
            # If the LLM fails, return a helpful error message
            error_msg = str(e)
            if "Groq API key" in error_msg or "GROQ_API_KEY" in error_msg:
                llm_out = (
                    f"I couldn't connect to the AI service (Groq). "
                    f"Error: {error_msg}\n\n"
                    f"To fix this:\n"
                    f"1. Get a free Groq API key from https://console.groq.com\n"
                    f"2. Set GROQ_API_KEY in your .env file or environment variables\n"
                    f"3. Optionally set GROQ_MODEL (default: llama-3.1-8b-instant)"
                )
            else:
                llm_out = f"I apologize, but I'm unable to process your request right now. The AI service is unavailable: {error_msg}"
            self._analytics_log_tool_usage(
                tenant_id=req.tenant_id,
                tool_name="llm",
                success=False,
                error_message=error_msg[:200],
                user_id=req.user_id,
            )
            reasoning_trace.append({
                "step": "error",
                "tool": "llm",
                "error": str(e),
            })
        total_latency_ms = int((time.time() - start_time) * 1000)
        self._analytics_log_agent_query(
            tenant_id=req.tenant_id,
            message_preview=req.message[:200],
            intent=intent,
            tools_used=tools_used,
            total_tokens=estimated_tokens,
            total_latency_ms=total_latency_ms,
            success=llm_success,
            user_id=req.user_id,
        )
        response = AgentResponse(
            text=llm_out,
            decision=AgentDecision(action="respond", tool=None, tool_input=None, reason="default_llm"),
            reasoning_trace=reasoning_trace,
        )
        # Cache only successful responses, never the fallback error text
        self._cache_response(req, response, skip_cache=not llm_success)
        return response

    def _build_prompt_with_rag(self, req: AgentRequest, rag_resp: Dict[str, Any]) -> str:
        snippets = []
        scores_info = []
        if isinstance(rag_resp, dict):
            hits = rag_resp.get("results") or rag_resp.get("hits") or []
            # Sort by score if available and take the top 5
            sorted_hits = sorted(
                hits,
                key=lambda h: float(h.get("score", h.get("similarity", 0.0))),
                reverse=True,
            )[:5]
            for i, h in enumerate(sorted_hits, 1):
                txt = h.get("text") or h.get("content") or str(h)
                score = h.get("score") or h.get("similarity", 0.0)
                snippets.append(f"[Source {i}] {txt}")
                scores_info.append(f"Source {i}: relevance score {score:.3f}")
        snippet_text = "\n\n".join(snippets) if snippets else ""
        scores_text = "\n".join(scores_info) if scores_info else ""
        # Build the optional relevance-scores section separately to avoid
        # complex f-string expressions
        if scores_text:
            relevance_section = "## Relevance Scores\n" + scores_text + "\n\n"
        else:
            relevance_section = ""
        if not snippet_text:
            prompt = (
                f"You are an assistant helping tenant {req.tenant_id}.\n\n"
                f"## User Question\n{req.message}\n\n"
                f"## Context\n"
                f"No relevant documents were found in the knowledge base for this question.\n\n"
                f"## Important Rules\n"
                f"If the user asks a question that cannot be answered directly from the Knowledge Base, "
                f"then and ONLY then use the web-search tool to gather information. "
                f"When using web search, keep the response short, factual, and neutral. "
                f"Do NOT provide long legal, medical, or highly detailed professional explanations. "
                f"If the topic involves legal, medical, financial, or safety-critical advice, provide a brief general explanation "
                f"and tell the user to consult a qualified professional. "
                f"Never present external information as part of the official Knowledge Base.\n\n"
                f"## Your Task\n"
                f"Since no Knowledge Base documents were found, you may use web search as a fallback if needed. "
                f"Provide a brief, helpful answer. If you're uncertain about tenant-specific details, "
                f"acknowledge that and provide general guidance. "
                f"For legal, medical, financial, or safety-critical topics, keep responses brief and recommend consulting a professional."
            )
        else:
            # Context engineering: get structured scratchpad context (Anthropic-style note-taking)
            scratchpad_context = self.context_engineer.get_scratchpad_context(limit=5)
            scratchpad_section = f"\n## Structured Notes from Previous Steps\n{scratchpad_context}\n\n" if scratchpad_context else ""
            # Build the prompt with Anthropic's recommended structure:
            # clear sections with XML/Markdown headers for better organization
            prompt = (
                f"<system>\n"
                f"You are an assistant helping tenant {req.tenant_id}. "
                f"Your goal is to provide the most accurate, comprehensive, and helpful answer possible.\n"
                f"</system>\n\n"
                f"<background_information>\n"
                f"## KB-First Strategy\n"
                f"The Knowledge Base was checked first and relevant documents were found. "
                f"Use these documents as your PRIMARY and AUTHORITATIVE source. "
                f"Web search should ONLY be used as a fallback if the Knowledge Base cannot answer the question.\n"
                f"{scratchpad_section}"
                f"</background_information>\n\n"
                f"<knowledge_base_documents>\n"
                f"The following documents were retrieved from the tenant's knowledge base as relevant to the user's question:\n\n"
                f"{snippet_text}\n\n"
                f"{relevance_section}"
                f"</knowledge_base_documents>\n\n"
                f"<user_question>\n"
                f"{req.message}\n"
                f"</user_question>\n\n"
                f"<instructions>\n"
                f"## Your Task\n"
                f"1. **Primary Goal**: Answer the user's question using the information from the knowledge base documents above.\n"
                f"2. **KB Priority**: Base your answer PRIMARILY on the Knowledge Base. This is the authoritative source for tenant-specific information.\n"
                f"3. **Accuracy**: Base your answer primarily on the highest-scoring sources (most relevant documents).\n"
                f"4. **Comprehensiveness**: If multiple sources provide complementary information, synthesize them into a complete answer.\n"
                f"5. **Citation**: When referencing specific information, indicate which source(s) you used (e.g., 'According to Source 1...' or 'Sources 1 and 2 indicate...').\n"
                f"6. **Completeness**: If the documents don't fully answer the question, clearly state what information is available and what is missing.\n"
                f"7. **Clarity**: Write in a clear, professional, and easy-to-understand manner.\n"
                f"8. **Directness**: Get straight to the point - provide the answer the user needs without unnecessary preamble.\n"
                f"</instructions>\n\n"
                f"Provide your answer now:"
            )
        return prompt
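
    # Illustrative input shape for _build_prompt_with_rag (hedged; the field
    # names are inferred from the accessors above, the values are hypothetical):
    #   {"results": [{"text": "Refunds are processed within 14 days.",
    #                 "score": 0.87},
    #                {"content": "Contact billing@example.com for refunds.",
    #                 "similarity": 0.74}]}
    # Hits are sorted by score/similarity, the top 5 become "[Source N] ..."
    # snippets, and their scores feed the "## Relevance Scores" section.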
| async def _execute_multi_step(self, req: AgentRequest, steps: List[Dict[str, Any]], | |
| decision: AgentDecision, tool_traces: List[Dict[str, Any]], | |
| reasoning_trace: List[Dict[str, Any]], | |
| pre_fetched_rag: Optional[Dict[str, Any]] = None) -> AgentResponse: | |
| """ | |
| Execute multiple tools in sequence or parallel and synthesize results with LLM. | |
| Supports parallel execution when steps are marked with "parallel" flag. | |
| """ | |
| start_time = time.time() | |
| rag_data = None | |
| web_data = None | |
| admin_data = None | |
| collected_data = [] | |
| tools_used = [] | |
| total_tokens = 0 | |
| # Detect if this is a news query - if so, skip RAG steps entirely | |
| msg_lower = req.message.lower() | |
| freshness_keywords = ["latest", "today", "news", "current", "recent", | |
| "now", "updates", "breaking", "trending", "happening"] | |
| news_patterns = [ | |
| r"latest news", r"current news", r"today's news", r"breaking news", | |
| r"news about", r"news on", r"news of" | |
| ] | |
| is_news_query = any(k in msg_lower for k in freshness_keywords) or \ | |
| any(re.search(p, msg_lower) for p in news_patterns) | |
| # Filter out plain RAG steps for news queries; parallel steps are kept | |
| # so their web portion survives (their RAG entry is pruned below) | |
| if is_news_query: | |
| steps = [s for s in steps if s.get("tool") != "rag"] | |
| reasoning_trace.append({ | |
| "step": "multi_step_news_filter", | |
| "action": "removed_rag_steps", | |
| "remaining_steps": [s.get("tool") if isinstance(s, dict) and "tool" in s else "parallel" for s in steps] | |
| }) | |
| # Check if any step has parallel execution flag | |
| parallel_step = None | |
| for step_info in steps: | |
| if step_info.get("parallel"): | |
| parallel_step = step_info | |
| break | |
| # Handle parallel execution if detected | |
| if parallel_step and parallel_step.get("parallel"): | |
| parallel_config = parallel_step.get("parallel") | |
| parallel_tasks = {} | |
| start_time_parallel = time.time() | |
| # Prepare parallel tasks with retry logic | |
| # Skip RAG for news queries | |
| if "rag" in parallel_config and not is_news_query: | |
| rag_query = parallel_config["rag"] | |
| if pre_fetched_rag: | |
| # Use pre-fetched RAG if available - create a simple async function | |
| async def get_prefetched_rag(): | |
| return pre_fetched_rag | |
| parallel_tasks["rag"] = get_prefetched_rag() | |
| else: | |
| # Wrap with retry logic for parallel execution | |
| async def rag_with_retry_wrapper(): | |
| return await self.rag_with_repair( | |
| query=rag_query, | |
| tenant_id=req.tenant_id, | |
| original_threshold=0.3, | |
| reasoning_trace=reasoning_trace, | |
| user_id=req.user_id | |
| ) | |
| parallel_tasks["rag"] = rag_with_retry_wrapper() | |
| elif "rag" in parallel_config and is_news_query: | |
| # Remove RAG from parallel config for news queries | |
| parallel_config = {k: v for k, v in parallel_config.items() if k != "rag"} | |
| reasoning_trace.append({ | |
| "step": "parallel_news_filter", | |
| "action": "removed_rag_from_parallel", | |
| "remaining_tools": list(parallel_config.keys()) | |
| }) | |
| if "web" in parallel_config: | |
| web_query = parallel_config["web"] | |
| # Wrap with retry logic for parallel execution | |
| async def web_with_retry_wrapper(): | |
| return await self.web_with_repair( | |
| query=web_query, | |
| tenant_id=req.tenant_id, | |
| reasoning_trace=reasoning_trace, | |
| user_id=req.user_id | |
| ) | |
| parallel_tasks["web"] = web_with_retry_wrapper() | |
| # Execute tools in parallel | |
| if parallel_tasks: | |
| reasoning_trace.append({ | |
| "step": "parallel_execution", | |
| "tools": list(parallel_tasks.keys()), | |
| "mode": "parallel" | |
| }) | |
| parallel_results = await self.run_parallel_tools(parallel_tasks) | |
| parallel_latency_ms = int((time.time() - start_time_parallel) * 1000) | |
| # Process RAG results | |
| if "rag" in parallel_results: | |
| rag_result = parallel_results["rag"] | |
| if isinstance(rag_result, Exception): | |
| tool_traces.append({"tool": "rag", "error": str(rag_result), "note": "parallel"}) | |
| reasoning_trace.append({ | |
| "step": "tool_execution", | |
| "tool": "rag", | |
| "status": "error", | |
| "error": str(rag_result), | |
| "latency_ms": parallel_latency_ms | |
| }) | |
| self._analytics_log_tool_usage( | |
| tenant_id=req.tenant_id, | |
| tool_name="rag", | |
| latency_ms=parallel_latency_ms, | |
| success=False, | |
| error_message=str(rag_result)[:200], | |
| user_id=req.user_id | |
| ) | |
| else: | |
| rag_data = rag_result | |
| tools_used.append("rag") | |
| tool_traces.append({"tool": "rag", "response": rag_result, "note": "parallel"}) | |
| rag_hits = self._extract_hits(rag_result) | |
| hits_count = len(rag_hits) | |
| avg_score = None | |
| top_score = None | |
| if hits_count > 0: | |
| scores = [h.get("score", 0.0) for h in rag_hits if isinstance(h, dict) and "score" in h] | |
| if scores: | |
| avg_score = sum(scores) / len(scores) | |
| top_score = max(scores) | |
| self._analytics_log_rag_search( | |
| tenant_id=req.tenant_id, | |
| query=req.message[:500], | |
| hits_count=hits_count, | |
| avg_score=avg_score, | |
| top_score=top_score, | |
| latency_ms=parallel_latency_ms | |
| ) | |
| self._analytics_log_tool_usage( | |
| tenant_id=req.tenant_id, | |
| tool_name="rag", | |
| latency_ms=parallel_latency_ms, | |
| success=True, | |
| user_id=req.user_id | |
| ) | |
| reasoning_trace.append({ | |
| "step": "tool_execution", | |
| "tool": "rag", | |
| "hit_count": hits_count, | |
| "summary": self._summarize_hits(rag_result, limit=2), | |
| "latency_ms": parallel_latency_ms, | |
| "mode": "parallel" | |
| }) | |
| # Process Web results | |
| if "web" in parallel_results: | |
| web_result = parallel_results["web"] | |
| if isinstance(web_result, Exception): | |
| tool_traces.append({"tool": "web", "error": str(web_result), "note": "parallel"}) | |
| reasoning_trace.append({ | |
| "step": "tool_execution", | |
| "tool": "web", | |
| "status": "error", | |
| "error": str(web_result), | |
| "latency_ms": parallel_latency_ms | |
| }) | |
| self._analytics_log_tool_usage( | |
| tenant_id=req.tenant_id, | |
| tool_name="web", | |
| latency_ms=parallel_latency_ms, | |
| success=False, | |
| error_message=str(web_result)[:200], | |
| user_id=req.user_id | |
| ) | |
| else: | |
| web_data = web_result | |
| tools_used.append("web") | |
| tool_traces.append({"tool": "web", "response": web_result, "note": "parallel"}) | |
| hits_count = len(self._extract_hits(web_result)) | |
| self._analytics_log_tool_usage( | |
| tenant_id=req.tenant_id, | |
| tool_name="web", | |
| latency_ms=parallel_latency_ms, | |
| success=True, | |
| user_id=req.user_id | |
| ) | |
| reasoning_trace.append({ | |
| "step": "tool_execution", | |
| "tool": "web", | |
| "hit_count": hits_count, | |
| "summary": self._summarize_hits(web_result, limit=2), | |
| "latency_ms": parallel_latency_ms, | |
| "mode": "parallel" | |
| }) | |
| # Merge parallel results | |
| merged_context = merge_parallel_results(parallel_results) | |
| sources_list = list(set(e.get("source") for e in merged_context if e.get("source"))) if merged_context else [] | |
| reasoning_trace.append({ | |
| "step": "result_merger", | |
| "merged_items": len(merged_context), | |
| "sources": sources_list | |
| }) | |
| # Format merged context for prompt | |
| data_section = format_merged_context_for_prompt(merged_context, max_items=10) | |
| else: | |
| data_section = "" | |
| else: | |
| # Sequential execution (original logic) | |
| parallel_tasks = {} | |
| rag_parallel_query = self._first_query_for_tool(steps, "rag", req.message) | |
| web_parallel_query = self._first_query_for_tool(steps, "web", req.message) | |
| if rag_parallel_query and web_parallel_query and rag_parallel_query == web_parallel_query: | |
| if not pre_fetched_rag: | |
| parallel_tasks["rag"] = asyncio.create_task(self.mcp.call_rag(req.tenant_id, rag_parallel_query)) | |
| parallel_tasks["web"] = asyncio.create_task(self.mcp.call_web(req.tenant_id, web_parallel_query)) | |
| # Execute each step in sequence | |
| for step_info in steps: | |
| tool_name = step_info.get("tool") | |
| step_input = step_info.get("input") or {} | |
| query = step_input.get("query") or req.message | |
| try: | |
| if tool_name == "rag": | |
| # Skip RAG for news queries | |
| if is_news_query: | |
| reasoning_trace.append({ | |
| "step": "tool_execution", | |
| "tool": "rag", | |
| "status": "skipped", | |
| "reason": "news_query_detected" | |
| }) | |
| continue # Skip this RAG step | |
| # Reuse pre-fetched RAG if available, otherwise fetch with retry | |
| if pre_fetched_rag and query == rag_parallel_query: | |
| rag_resp = pre_fetched_rag | |
| tool_traces.append({"tool": "rag", "response": rag_resp, "note": "used_pre_fetched"}) | |
| elif parallel_tasks.get("rag") and query == rag_parallel_query: | |
| rag_resp = await parallel_tasks["rag"] | |
| tool_traces.append({"tool": "rag", "response": rag_resp, "note": "parallel"}) | |
| else: | |
| # Use autonomous retry with self-correction | |
| rag_resp = await self.rag_with_repair( | |
| query=query, | |
| tenant_id=req.tenant_id, | |
| original_threshold=0.3, | |
| reasoning_trace=reasoning_trace, | |
| user_id=req.user_id | |
| ) | |
| tool_traces.append({"tool": "rag", "response": rag_resp, "note": "with_retry"}) | |
| rag_data = rag_resp | |
| tools_used.append("rag") | |
| hits = self._extract_hits(rag_resp) | |
| reasoning_trace.append({ | |
| "step": "tool_execution", | |
| "tool": "rag", | |
| "hit_count": len(hits), | |
| "summary": self._summarize_hits(rag_resp, limit=2) | |
| }) | |
| # Extract snippets for prompt | |
| if isinstance(rag_resp, dict): | |
| for h in hits[:5]: | |
| txt = h.get("text") or h.get("content") or str(h) | |
| collected_data.append(f"[RAG] {txt}") | |
| elif tool_name == "web": | |
| if parallel_tasks.get("web") and query == web_parallel_query: | |
| web_resp = await parallel_tasks["web"] | |
| tool_traces.append({"tool": "web", "response": web_resp, "note": "parallel"}) | |
| else: | |
| # Use autonomous retry with query rewriting | |
| web_resp = await self.web_with_repair( | |
| query=query, | |
| tenant_id=req.tenant_id, | |
| reasoning_trace=reasoning_trace, | |
| user_id=req.user_id | |
| ) | |
| tool_traces.append({"tool": "web", "response": web_resp, "note": "with_retry"}) | |
| web_data = web_resp | |
| tools_used.append("web") | |
| hits = self._extract_hits(web_resp) | |
| reasoning_trace.append({ | |
| "step": "tool_execution", | |
| "tool": "web", | |
| "hit_count": len(hits), | |
| "summary": self._summarize_hits(web_resp, limit=2) | |
| }) | |
| # Extract snippets for prompt | |
| if isinstance(web_resp, dict): | |
| for h in hits[:5]: | |
| title = h.get("title") or h.get("headline") or "" | |
| snippet = h.get("snippet") or h.get("summary") or h.get("text") or "" | |
| url = h.get("url") or h.get("link") or "" | |
| collected_data.append(f"[WEB] {title}\n{snippet}\nSource: {url}") | |
| elif tool_name == "admin": | |
| admin_resp = await self.mcp.call_admin(req.tenant_id, query) | |
| tool_traces.append({"tool": "admin", "response": admin_resp}) | |
| admin_data = admin_resp | |
| tools_used.append("admin") | |
| collected_data.append(f"[ADMIN] {json.dumps(admin_resp)}") | |
| reasoning_trace.append({ | |
| "step": "tool_execution", | |
| "tool": "admin", | |
| "status": "completed" | |
| }) | |
| elif tool_name == "llm": | |
| # LLM is always last - synthesize all collected data | |
| break | |
| except Exception as e: | |
| tool_traces.append({"tool": tool_name, "error": str(e)}) | |
| # Continue with other tools even if one fails | |
| reasoning_trace.append({ | |
| "step": "error", | |
| "tool": tool_name, | |
| "error": str(e) | |
| }) | |
| # Build comprehensive prompt with all collected data | |
| data_section = "\n---\n".join(collected_data) if collected_data else "" | |
| # Build final response. For admin-identity style questions, bypass generic | |
| # multi-step LLM behaviour and answer directly from RAG data if available. | |
| user_text = req.message.lower() | |
| user_text_normalized = " ".join(user_text.split()) | |
| admin_phrases = [ | |
| "who is the admin", | |
| "who's the admin", | |
| "who is admin", | |
| "who is the administrator", | |
| "who's the administrator", | |
| "who administers this platform", | |
| "who administers the platform", | |
| "who is the owner", | |
| "who's the owner", | |
| "who owns this platform", | |
| "who owns the platform", | |
| "who is the admin of integrachat", | |
| "who's the admin of integrachat", | |
| ] | |
| if any(p in user_text_normalized for p in admin_phrases) or ("admin" in user_text and "who" in user_text): | |
| hits = self._extract_hits(rag_data) if rag_data else [] | |
| if hits: | |
| best = hits[0] | |
| admin_text = best.get("text") or best.get("content") or str(best) | |
| llm_out = f"According to the tenant knowledge base, {admin_text.strip()}" | |
| else: | |
| llm_out = "I don't know who administers this platform based on the tenant data." | |
| llm_latency_ms = 0 | |
| estimated_tokens = len(llm_out) // 4 + len(req.message) // 4 | |
| total_tokens += estimated_tokens | |
| tools_used.append("llm") | |
| self._analytics_log_tool_usage( | |
| tenant_id=req.tenant_id, | |
| tool_name="llm", | |
| latency_ms=llm_latency_ms, | |
| tokens_used=estimated_tokens, | |
| success=True, | |
| user_id=req.user_id | |
| ) | |
| total_latency_ms = int((time.time() - start_time) * 1000) | |
| self._analytics_log_agent_query( | |
| tenant_id=req.tenant_id, | |
| message_preview=req.message[:200], | |
| intent="multi_step", | |
| tools_used=tools_used, | |
| total_tokens=total_tokens, | |
| total_latency_ms=total_latency_ms, | |
| success=True, | |
| user_id=req.user_id | |
| ) | |
| return AgentResponse( | |
| text=llm_out, | |
| decision=decision, | |
| tool_traces=tool_traces, | |
| reasoning_trace=reasoning_trace + [{ | |
| "step": "llm_response", | |
| "mode": "multi_step_admin_from_rag_only", | |
| "latency_ms": llm_latency_ms, | |
| "estimated_tokens": estimated_tokens | |
| }] | |
| ) | |
| # Otherwise, build the normal multi-step synthesis prompt. | |
| if data_section: | |
| # Check if we have both RAG and web data | |
| has_rag = "[RAG]" in data_section | |
| has_web = "[WEB]" in data_section | |
| kb_first_note = "" | |
| web_fallback_note = "" | |
| if has_rag and has_web: | |
| kb_first_note = ( | |
| f"\n## KB-First Strategy\n" | |
| f"**Knowledge Base (RAG) was checked FIRST** and found relevant information. " | |
| f"This is the PRIMARY and AUTHORITATIVE source. " | |
| f"Web search results are provided as supplementary information only. " | |
| f"Prioritize Knowledge Base information over web search results.\n\n" | |
| ) | |
| web_fallback_note = ( | |
| f"\n## Web Search Rules\n" | |
| f"When using web search information as supplementary data:\n" | |
| f"- Keep web search details brief and factual\n" | |
| f"- For legal, medical, financial, or safety topics, add: 'For specific advice, consult a qualified professional.'\n" | |
| f"- Clearly distinguish between Knowledge Base (authoritative) and web search (supplementary) information\n\n" | |
| ) | |
| elif has_web and not has_rag: | |
| kb_first_note = ( | |
| f"\n## KB-First Strategy\n" | |
| f"The Knowledge Base was checked FIRST but no relevant information was found. " | |
| f"Web search results below are provided as a FALLBACK. " | |
| f"Keep the response short, factual, and neutral. " | |
| f"For legal, medical, financial, or safety topics, recommend consulting a qualified professional.\n\n" | |
| ) | |
| # Get structured scratchpad context (Anthropic's note-taking) | |
| scratchpad_context = self.context_engineer.get_scratchpad_context(limit=5) | |
| scratchpad_section = f"\n## Structured Notes\n{scratchpad_context}\n" if scratchpad_context else "" | |
| # Build prompt with Anthropic's structured format (XML-style sections) | |
| # Pre-build optional web guidance section to avoid complex nested f-strings | |
| if web_fallback_note: | |
| web_guidance_section = ( | |
| "<web_search_guidance>\n" | |
| f"{web_fallback_note.strip()}\n" | |
| "</web_search_guidance>\n\n" | |
| ) | |
| else: | |
| web_guidance_section = "" | |
| if has_web: | |
| brief_web_instruction = ( | |
| "11. **Brief Web Content**: If using web search, keep that portion of the response brief " | |
| "(2-4 sentences). Add professional disclaimers for legal/medical/financial topics.\n" | |
| ) | |
| else: | |
| brief_web_instruction = "" | |
| prompt = ( | |
| f"<system>\n" | |
| f"You are an assistant helping tenant {req.tenant_id}. " | |
| f"Your goal is to provide the most accurate, comprehensive, and helpful answer possible.\n" | |
| f"</system>\n\n" | |
| f"<background_information>\n" | |
| f"{kb_first_note.strip()}" | |
| f"{scratchpad_section}" | |
| f"</background_information>\n\n" | |
| f"<information_collected>\n" | |
| f"The following details have been gathered from reliable sources:\n\n" | |
| f"{data_section}\n" | |
| f"</information_collected>\n\n" | |
| f"{web_guidance_section}" | |
| f"<user_request>\n" | |
| f"{req.message}\n" | |
| f"</user_request>\n\n" | |
| f"<instructions>\n" | |
| f"## Your Task\n" | |
| f"1. **Primary Goal**: Use the information above to directly and completely address the user's request.\n" | |
| f"2. **Source Priority**: {'If both Knowledge Base (RAG) and web search results are present, prioritize Knowledge Base as the authoritative source. ' if has_rag and has_web else ''}Use web search information only to supplement or when KB has no relevant information.\n" | |
| f"3. **Synthesis**: Combine information from different sources when they provide complementary details.\n" | |
| f"4. **Prioritization**: If sources conflict, prioritize Knowledge Base information over web search results.\n" | |
| f"5. **Completeness**: Provide a comprehensive answer that covers all aspects of the user's question.\n" | |
| f"6. **Accuracy**: Base your answer on the provided information. If information is missing or uncertain, clearly state that.\n" | |
| f"7. **Clarity**: Write in a clear, professional, and easy-to-understand manner.\n" | |
| f"8. **Directness**: Get straight to the point - provide the answer the user needs without unnecessary preamble.\n" | |
| f"9. **Actionability**: If the question requires steps or actions, provide clear, actionable guidance.\n" | |
| f"10. **Citation**: When referencing specific sources, indicate which source(s) you used (e.g., '[RAG]', '[WEB]').\n" | |
| f"{brief_web_instruction}" | |
| f"</instructions>\n\n" | |
| f"If the information is incomplete, explain what can and cannot be concluded from the available data. " | |
| f"Focus on giving the user exactly what they need—clear guidance, accurate facts, and practical steps whenever possible.\n\n" | |
| f"Provide your comprehensive answer now:" | |
| ) | |
| else: | |
| # No data collected, provide best answer from general knowledge | |
| prompt = ( | |
| f"You are an assistant helping tenant {req.tenant_id}.\n\n" | |
| f"## User Question\n{req.message}\n\n" | |
| f"## Context\n" | |
| f"No specific information was found in the knowledge base or web search for this question.\n\n" | |
| f"## Your Task\n" | |
| f"Provide the best possible answer based on your general knowledge. " | |
| f"Be clear, accurate, comprehensive, and helpful. " | |
| f"If you're uncertain about tenant-specific details, acknowledge that and provide general guidance. " | |
| f"Focus on giving the user exactly what they need—clear guidance, accurate facts, and practical steps whenever possible." | |
| ) | |
| # Final LLM synthesis | |
| try: | |
| llm_start = time.time() | |
| llm_out = await self.llm.simple_call(prompt, temperature=req.temperature) | |
| llm_latency_ms = int((time.time() - llm_start) * 1000) | |
| tools_used.append("llm") | |
| estimated_tokens = len(llm_out) // 4 + len(prompt) // 4 | |
| total_tokens += estimated_tokens | |
| self._analytics_log_tool_usage( | |
| tenant_id=req.tenant_id, | |
| tool_name="llm", | |
| latency_ms=llm_latency_ms, | |
| tokens_used=estimated_tokens, | |
| success=True, | |
| user_id=req.user_id | |
| ) | |
| total_latency_ms = int((time.time() - start_time) * 1000) | |
| self._analytics_log_agent_query( | |
| tenant_id=req.tenant_id, | |
| message_preview=req.message[:200], | |
| intent="multi_step", | |
| tools_used=tools_used, | |
| total_tokens=total_tokens, | |
| total_latency_ms=total_latency_ms, | |
| success=True, | |
| user_id=req.user_id | |
| ) | |
| return AgentResponse( | |
| text=llm_out, | |
| decision=decision, | |
| tool_traces=tool_traces, | |
| reasoning_trace=reasoning_trace + [{ | |
| "step": "llm_response", | |
| "mode": "multi_step_parallel" if parallel_step else "multi_step", | |
| "latency_ms": llm_latency_ms, | |
| "estimated_tokens": estimated_tokens | |
| }] | |
| ) | |
| except Exception as e: | |
| tool_traces.append({"tool": "llm", "error": str(e)}) | |
| error_msg = str(e) | |
| # Provide helpful error message | |
| if "Groq API key" in error_msg or "GROQ_API_KEY" in error_msg: | |
| fallback = ( | |
| f"I couldn't connect to the AI service (Groq). " | |
| f"Error: {error_msg}\n\n" | |
| f"To fix this:\n" | |
| f"1. Get a free Groq API key from https://console.groq.com\n" | |
| f"2. Set GROQ_API_KEY in your .env file or environment variables\n" | |
| f"3. Optionally set GROQ_MODEL (default: llama-3.1-8b-instant)" | |
| ) | |
| else: | |
| fallback = f"I encountered an error while synthesizing the response: {error_msg}" | |
| return AgentResponse( | |
| text=fallback, | |
| decision=AgentDecision( | |
| action="respond", | |
| tool=None, | |
| tool_input=None, | |
| reason=f"multi_step_llm_error: {e}" | |
| ), | |
| tool_traces=tool_traces, | |
| reasoning_trace=reasoning_trace + [{ | |
| "step": "error", | |
| "tool": "llm", | |
| "error": str(e) | |
| }] | |
| ) | |
| # ============================================================= | |
| # AUTONOMOUS RETRY + SELF-CORRECTION SYSTEM | |
| # ============================================================= | |
| """ | |
| This system provides autonomous retry and self-correction capabilities | |
| for the agent orchestrator. It enables the agent to: | |
| 1. **Self-healing**: Tools that break automatically retry with adjusted parameters | |
| 2. **Resilient operations**: Handles low RAG scores, empty web results, and misfired rules | |
| 3. **Smart optimization**: Automatically rewrites queries, adjusts thresholds, and optimizes parameters | |
| 4. **Graceful degradation**: Returns a structured error instead of raising when all retries are exhausted | |
| Key features: | |
| - safe_tool_call(): Generic retry wrapper for any tool call | |
| - rag_with_repair(): RAG search with automatic threshold adjustment and query expansion | |
| - web_with_repair(): Web search with automatic query rewriting for empty results | |
| - rule_safe_message(): Message rewriting to comply with admin rules | |
| All retry attempts are logged to analytics for monitoring and debugging. | |
| """ | |
| async def safe_tool_call( | |
| self, | |
| tool_fn, | |
| params: Dict[str, Any], | |
| max_retries: int = 2, | |
| fallback_params: Optional[Dict[str, Any]] = None, | |
| tool_name: str = "unknown", | |
| tenant_id: Optional[str] = None, | |
| user_id: Optional[str] = None, | |
| reasoning_trace: Optional[List[Dict[str, Any]]] = None | |
| ) -> Dict[str, Any]: | |
| """ | |
| Wrapper for tool calls with automatic retry and self-correction. | |
| Args: | |
| tool_fn: Async function to call | |
| params: Parameters to pass to tool_fn | |
| max_retries: Maximum number of retry attempts | |
| fallback_params: Alternative parameters to try if initial attempt fails | |
| tool_name: Name of the tool (for logging) | |
| tenant_id: Tenant ID (for analytics) | |
| user_id: User ID (for analytics) | |
| reasoning_trace: Optional reasoning trace to append to | |
| Returns: | |
| Tool result dictionary, or {"error": "tool_failed_after_retries"} if all attempts fail | |
| """ | |
| for attempt in range(max_retries): | |
| try: | |
| result = await tool_fn(**params) | |
| if attempt > 0: | |
| # Log successful retry | |
| if reasoning_trace is not None: | |
| reasoning_trace.append({ | |
| "step": "retry_success", | |
| "tool": tool_name, | |
| "attempt": attempt + 1, | |
| "status": "recovered" | |
| }) | |
| if tenant_id: | |
| self._analytics_log_tool_usage( | |
| tenant_id=tenant_id, | |
| tool_name=f"{tool_name}_retry_{attempt+1}", | |
| latency_ms=0, | |
| success=True, | |
| user_id=user_id | |
| ) | |
| return result | |
| except Exception as e: | |
| error_msg = str(e) | |
| if reasoning_trace is not None: | |
| reasoning_trace.append({ | |
| "step": "retry_attempt", | |
| "tool": tool_name, | |
| "attempt": attempt + 1, | |
| "error": error_msg[:200] | |
| }) | |
| # Log failed attempt | |
| if tenant_id: | |
| self._analytics_log_tool_usage( | |
| tenant_id=tenant_id, | |
| tool_name=tool_name, | |
| latency_ms=0, | |
| success=False, | |
| error_message=error_msg[:200], | |
| user_id=user_id | |
| ) | |
| # Try alternate params if provided and not last attempt | |
| if fallback_params and attempt < max_retries - 1: | |
| params = {**params, **fallback_params} | |
| if reasoning_trace is not None: | |
| reasoning_trace.append({ | |
| "step": "retry_with_fallback_params", | |
| "tool": tool_name, | |
| "attempt": attempt + 2, | |
| "fallback_params": fallback_params | |
| }) | |
| # If last attempt, return error | |
| if attempt == max_retries - 1: | |
| return {"error": "tool_failed_after_retries", "error_message": error_msg} | |
| return {"error": "tool_failed_after_retries"} | |
| async def rag_with_repair( | |
| self, | |
| query: str, | |
| tenant_id: str, | |
| original_threshold: float = 0.3, | |
| reasoning_trace: Optional[List[Dict[str, Any]]] = None, | |
| user_id: Optional[str] = None | |
| ) -> Dict[str, Any]: | |
| """ | |
| RAG search with automatic self-correction for low scores. | |
| Strategy: | |
| 1. Try with original threshold | |
| 2. If top_score < 0.30, retry with lower threshold (0.15) | |
| 3. If still low (< 0.15), expand query and retry | |
| """ | |
| # Initial attempt | |
| rag_start = time.time() | |
| result = await self.mcp.call_rag(tenant_id, query, threshold=original_threshold) | |
| rag_latency_ms = int((time.time() - rag_start) * 1000) | |
| # Extract hits and calculate scores | |
| hits = self._extract_hits(result) | |
| top_score = None | |
| avg_score = None | |
| if hits: | |
| scores = [h.get("score", 0.0) for h in hits if isinstance(h, dict) and "score" in h] | |
| if scores: | |
| top_score = max(scores) | |
| avg_score = sum(scores) / len(scores) | |
| if reasoning_trace is not None: | |
| reasoning_trace.append({ | |
| "step": "rag_initial_search", | |
| "query": query[:200], | |
| "hits_count": len(hits), | |
| "top_score": top_score, | |
| "avg_score": avg_score, | |
| "threshold": original_threshold | |
| }) | |
| # Retry logic: low score → lower threshold | |
| if top_score is not None and top_score < 0.30 and original_threshold >= 0.15: | |
| if reasoning_trace is not None: | |
| reasoning_trace.append({ | |
| "step": "rag_retry_low_threshold", | |
| "reason": f"top_score {top_score:.3f} < 0.30, retrying with threshold=0.15" | |
| }) | |
| retry_start = time.time() | |
| result = await self.mcp.call_rag(tenant_id, query, threshold=0.15) | |
| retry_latency_ms = int((time.time() - retry_start) * 1000) | |
| rag_latency_ms += retry_latency_ms | |
| hits = self._extract_hits(result) | |
| if hits: | |
| scores = [h.get("score", 0.0) for h in hits if isinstance(h, dict) and "score" in h] | |
| if scores: | |
| top_score = max(scores) | |
| avg_score = sum(scores) / len(scores) | |
| # Log retry | |
| self._analytics_log_tool_usage( | |
| tenant_id=tenant_id, | |
| tool_name="rag_retry_low_threshold", | |
| latency_ms=retry_latency_ms, | |
| success=True, | |
| user_id=user_id | |
| ) | |
| # Final retry: expand query if score still too low | |
| if top_score is not None and top_score < 0.15: | |
| expanded_query = f"{query} (more details comprehensive explanation)" | |
| if reasoning_trace is not None: | |
| reasoning_trace.append({ | |
| "step": "rag_retry_expanded_query", | |
| "reason": f"top_score {top_score:.3f} < 0.15, retrying with expanded query", | |
| "original_query": query[:200], | |
| "expanded_query": expanded_query[:200] | |
| }) | |
| retry_start = time.time() | |
| result = await self.mcp.call_rag(tenant_id, expanded_query, threshold=0.15) | |
| retry_latency_ms = int((time.time() - retry_start) * 1000) | |
| rag_latency_ms += retry_latency_ms | |
| hits = self._extract_hits(result) | |
| if hits: | |
| scores = [h.get("score", 0.0) for h in hits if isinstance(h, dict) and "score" in h] | |
| if scores: | |
| top_score = max(scores) | |
| avg_score = sum(scores) / len(scores) | |
| # Log retry | |
| self._analytics_log_tool_usage( | |
| tenant_id=tenant_id, | |
| tool_name="rag_retry_expanded_query", | |
| latency_ms=retry_latency_ms, | |
| success=True, | |
| user_id=user_id | |
| ) | |
| if reasoning_trace is not None: | |
| reasoning_trace.append({ | |
| "step": "rag_expanded_query_result", | |
| "hits_count": len(hits), | |
| "top_score": top_score, | |
| "avg_score": avg_score | |
| }) | |
| # Log final RAG search | |
| if hits: | |
| self._analytics_log_rag_search( | |
| tenant_id=tenant_id, | |
| query=query[:500], | |
| hits_count=len(hits), | |
| avg_score=avg_score, | |
| top_score=top_score, | |
| latency_ms=rag_latency_ms | |
| ) | |
| return result | |
| async def web_with_repair( | |
| self, | |
| query: str, | |
| tenant_id: str, | |
| reasoning_trace: Optional[List[Dict[str, Any]]] = None, | |
| user_id: Optional[str] = None | |
| ) -> Dict[str, Any]: | |
| """ | |
| Web search with multi-query strategy and automatic query rewriting. | |
| Strategy: | |
| 1. Try original query | |
| 2. If empty, generate multiple query variations using query expander | |
| 3. Execute queries in parallel for better results | |
| 4. Merge results from all successful queries | |
| """ | |
| # Detect if this is a news query | |
| query_lower = query.lower() | |
| is_news_query = any(kw in query_lower for kw in ["news", "latest", "breaking", "current", "today", "recent", "update"]) | |
| # Initial attempt | |
| web_start = time.time() | |
| result = await self.mcp.call_web(tenant_id, query) | |
| web_latency_ms = int((time.time() - web_start) * 1000) | |
| hits = self._extract_hits(result) | |
| if reasoning_trace is not None: | |
| reasoning_trace.append({ | |
| "step": "web_initial_search", | |
| "query": query[:200], | |
| "hits_count": len(hits), | |
| "is_news_query": is_news_query | |
| }) | |
| # Multi-query strategy: if initial results are poor, try multiple variations in parallel | |
| if not result or len(hits) < 3: | |
| # Generate query variations | |
| if is_news_query: | |
| # Use query expander for news queries | |
| try: | |
| query_variations = self.query_expander.expand_news_query(query) | |
| except Exception: | |
| query_variations = [ | |
| f"{query} news", | |
| f"latest {query}", | |
| f"{query} latest news", | |
| f"breaking news {query}" | |
| ] | |
| else: | |
| # For general queries, try explanation-focused rewrites | |
| query_variations = [ | |
| f"best explanation of {query}", | |
| f"{query} facts summary", | |
| f"information about {query}", | |
| f"what is {query}" | |
| ] | |
| # Execute multiple queries in parallel; initialize the merge list up | |
| # front so the sequential fallback below never sees an unbound name | |
| all_hits = [] | |
| if len(query_variations) > 1: | |
| async def search_variation(q: str): | |
| try: | |
| return await self.mcp.call_web(tenant_id, q) | |
| except Exception as e: | |
| logger.debug(f"Web search failed for query '{q}': {e}") | |
| return None | |
| # Run all variations in parallel | |
| parallel_tasks = {q: search_variation(q) for q in query_variations[:3]} # Limit to 3 parallel | |
| parallel_results = await self.run_parallel_tools(parallel_tasks) | |
| # Merge results from all successful queries | |
| seen_urls = set() | |
| # Add original hits | |
| for hit in hits: | |
| url = hit.get("url") or hit.get("link", "") | |
| if url and url not in seen_urls: | |
| all_hits.append(hit) | |
| seen_urls.add(url) | |
| # Add hits from parallel queries | |
| for q, res in parallel_results.items(): | |
| if res and not isinstance(res, Exception): | |
| var_hits = self._extract_hits(res) | |
| for hit in var_hits: | |
| url = hit.get("url") or hit.get("link", "") | |
| if url and url not in seen_urls: | |
| all_hits.append(hit) | |
| seen_urls.add(url) | |
| # Update result with merged hits | |
| if all_hits: | |
| result = {"results": all_hits[:10]} # Limit to top 10 | |
| hits = all_hits[:10] | |
| if reasoning_trace is not None: | |
| reasoning_trace.append({ | |
| "step": "web_multi_query_merge", | |
| "variations_tried": len(query_variations), | |
| "total_hits_merged": len(all_hits), | |
| "final_hits_count": len(hits) | |
| }) | |
| # If parallel didn't help, try one more sequential attempt with best variation | |
| if not all_hits and len(query_variations) > 0: | |
| best_variation = query_variations[0] | |
| retry_start = time.time() | |
| try: | |
| result = await self.mcp.call_web(tenant_id, best_variation) | |
| retry_latency_ms = int((time.time() - retry_start) * 1000) | |
| web_latency_ms += retry_latency_ms | |
| hits = self._extract_hits(result) | |
| if hits: | |
| if reasoning_trace is not None: | |
| reasoning_trace.append({ | |
| "step": "web_sequential_fallback_success", | |
| "query": best_variation[:200], | |
| "hits_count": len(hits) | |
| }) | |
| except Exception as e: | |
| logger.debug(f"Final web search retry failed: {e}") | |
| # Log final web search | |
| self._analytics_log_tool_usage( | |
| tenant_id=tenant_id, | |
| tool_name="web", | |
| latency_ms=web_latency_ms, | |
| success=len(hits) > 0, | |
| user_id=user_id | |
| ) | |
| return result | |
| async def rule_safe_message( | |
| self, | |
| user_message: str, | |
| tenant_id: str, | |
| reasoning_trace: Optional[List[Dict[str, Any]]] = None | |
| ) -> str: | |
| """ | |
| Check admin rules and rewrite message if it violates policies. | |
| Strategy: | |
| 1. Check rules | |
| 2. If blocked, ask LLM to rewrite message to comply | |
| 3. Return safe version | |
| """ | |
| matches: List[RedFlagMatch] = await self.redflag.check(tenant_id, user_message) | |
| if not matches: | |
| return user_message | |
| # Check if any are blocking rules (not just brief response rules) | |
| blocking_rules = [] | |
| for match in matches: | |
| rule_text = (match.description or match.pattern or "").lower() | |
| is_brief_rule = ( | |
| match.severity == "low" and ( | |
| "greeting" in rule_text or | |
| "brief" in rule_text or | |
| "simple response" in rule_text | |
| ) | |
| ) | |
| if not is_brief_rule: | |
| blocking_rules.append(match) | |
| # Only rewrite if there are blocking rules | |
| if not blocking_rules: | |
| return user_message | |
| if reasoning_trace is not None: | |
| reasoning_trace.append({ | |
| "step": "rule_violation_detected", | |
| "blocking_rules_count": len(blocking_rules), | |
| "action": "attempting_rewrite" | |
| }) | |
| # Ask LLM to rewrite the message | |
| rewrite_prompt = f"""The following user message violates company policies. Rewrite it to be compliant while preserving the user's intent as much as possible. | |
| Original message: "{user_message}" | |
| Violated policies: | |
| {chr(10).join([f"- {m.description or m.pattern}" for m in blocking_rules[:3]])} | |
| Provide a rewritten version that: | |
| 1. Avoids the policy violations | |
| 2. Preserves the user's original intent | |
| 3. Remains professional and helpful | |
| Rewritten message:""" | |
| try: | |
| rewritten = await self.llm.simple_call(rewrite_prompt, temperature=0.3) | |
| rewritten = rewritten.strip().strip('"').strip("'") | |
| if reasoning_trace is not None: | |
| reasoning_trace.append({ | |
| "step": "rule_rewrite_completed", | |
| "original_length": len(user_message), | |
| "rewritten_length": len(rewritten), | |
| "rewritten_preview": rewritten[:200] | |
| }) | |
| # Verify the rewritten message doesn't trigger rules | |
| verify_matches = await self.redflag.check(tenant_id, rewritten) | |
| if not verify_matches or all( | |
| any(k in (m.description or m.pattern or "").lower() for k in ("greeting", "brief", "simple response")) | |
| for m in verify_matches | |
| ): | |
| return rewritten | |
| if reasoning_trace is not None: | |
| reasoning_trace.append({ | |
| "step": "rule_rewrite_still_violates", | |
| "action": "using_original_with_block" | |
| }) | |
| except Exception as e: | |
| if reasoning_trace is not None: | |
| reasoning_trace.append({ | |
| "step": "rule_rewrite_failed", | |
| "error": str(e)[:200] | |
| }) | |
| # Return original if rewrite failed or still violates | |
| return user_message | |
| def _build_prompt_with_web(self, req: AgentRequest, web_resp: Dict[str, Any]) -> str: | |
| snippets = [] | |
| if isinstance(web_resp, dict): | |
| hits = web_resp.get("results") or web_resp.get("items") or [] | |
| for i, h in enumerate(hits[:5], 1): | |
| title = h.get("title") or h.get("headline") or "" | |
| snippet = h.get("snippet") or h.get("summary") or h.get("text") or "" | |
| url = h.get("url") or h.get("link") or "" | |
| display_link = h.get("displayLink") or h.get("display_link") or "" | |
| # Derive a display source without risking IndexError on scheme-less URLs | |
| source_info = display_link or (url.split('/')[2] if url.count('/') >= 2 else url) or "Unknown source" | |
| snippets.append(f"[Result {i}] {title}\n{snippet}\nSource: {source_info} ({url})") | |
| snippet_text = "\n\n".join(snippets) if snippets else "" | |
| if not snippet_text: | |
| prompt = ( | |
| f"You are an assistant helping tenant {req.tenant_id}.\n\n" | |
| f"## User Question\n{req.message}\n\n" | |
| f"## Context\n" | |
| f"No relevant web search results were found for this question.\n\n" | |
| f"## Your Task\n" | |
| f"Provide the best possible answer based on your general knowledge. " | |
| f"Be clear, accurate, and helpful. If you're uncertain about specific details, " | |
| f"acknowledge that and provide general guidance." | |
| ) | |
| else: | |
| # Build prompt with Anthropic's recommended structure (clear sections with XML tags) | |
| prompt = ( | |
| f"<system>\n" | |
| f"You are an assistant helping tenant {req.tenant_id}. " | |
| f"The Knowledge Base was checked first but no relevant information was found. " | |
| f"Web search results are provided below as a fallback.\n" | |
| f"</system>\n\n" | |
| f"<background_information>\n" | |
| f"## Important Rules for Web Search Responses\n" | |
| f"1. **KB-First Approach**: Always check Knowledge Base first. Web search is ONLY a fallback when KB has no relevant information.\n" | |
| f"2. **Keep it Short**: When using web search, keep responses short, factual, and neutral. Do NOT provide long explanations.\n" | |
| f"3. **No Professional Advice**: Do NOT provide long legal, medical, or highly detailed professional explanations. " | |
| f"If the topic involves legal, medical, financial, or safety-critical advice, provide a brief general explanation " | |
| f"and tell the user to consult a qualified professional.\n" | |
| f"4. **Clear Source Distinction**: Never present external web search information as part of the official Knowledge Base. " | |
| f"Always clarify that this information comes from external sources.\n" | |
| f"5. **Safety First**: For safety-critical topics, always recommend consulting qualified professionals.\n" | |
| f"</background_information>\n\n" | |
| f"<web_search_results>\n" | |
| f"The following search results were found for the user's question:\n\n" | |
| f"{snippet_text}\n" | |
| f"</web_search_results>\n\n" | |
| f"<user_question>\n" | |
| f"{req.message}\n" | |
| f"</user_question>\n\n" | |
| f"<instructions>\n" | |
| f"## Your Task\n" | |
| f"1. **Primary Goal**: Provide a short, factual answer using the web search results above.\n" | |
| f"2. **Keep it Brief**: Limit your response to 2-4 sentences. Do NOT provide lengthy explanations.\n" | |
| f"3. **Accuracy**: Prioritize information from authoritative sources (recognized websites, official sources, etc.).\n" | |
| f"4. **Professional Disclaimers**: For legal, medical, financial, or safety topics, include: " | |
| f"'For specific advice, please consult a qualified professional.'\n" | |
| f"5. **Source Clarity**: Start by mentioning this information comes from web search, not the Knowledge Base.\n" | |
| f"6. **Citation**: Briefly indicate which source(s) you used.\n" | |
| f"</instructions>\n\n" | |
| f"Provide a short, helpful answer now:" | |
| ) | |
| return prompt | |
| def _format_tool_output(self, tool_name: str, output: Any, latency_ms: int) -> Dict[str, Any]: | |
| """ | |
| Format tool output to conform to strict JSON schema. | |
| Args: | |
| tool_name: Name of the tool (rag, web, admin, llm) | |
| output: Raw tool output | |
| latency_ms: Actual latency in milliseconds | |
| Returns: | |
| Formatted output conforming to tool schema | |
| """ | |
| if tool_name == "rag": | |
| # Format RAG output | |
| if isinstance(output, dict): | |
| results = output.get("results") or output.get("hits") or [] | |
| # Ensure each result has required fields | |
| formatted_results = [] | |
| for r in results: | |
| if isinstance(r, dict): | |
| formatted_results.append({ | |
| "text": r.get("text") or r.get("content") or str(r), | |
| "similarity": float(r.get("similarity") or r.get("score") or 0.0), | |
| "metadata": r.get("metadata") or {}, | |
| "doc_id": r.get("doc_id") or r.get("id") | |
| }) | |
| else: | |
| formatted_results.append({ | |
| "text": str(r), | |
| "similarity": 0.5, | |
| "metadata": {}, | |
| "doc_id": None | |
| }) | |
| # Calculate aggregate scores | |
| scores = [r["similarity"] for r in formatted_results if r["similarity"] > 0] | |
| avg_score = sum(scores) / len(scores) if scores else 0.0 | |
| top_score = max(scores) if scores else 0.0 | |
| return { | |
| "results": formatted_results, | |
| "query": output.get("query", ""), | |
| "tenant_id": output.get("tenant_id", ""), | |
| "hits_count": len(formatted_results), | |
| "avg_score": round(avg_score, 3), | |
| "top_score": round(top_score, 3), | |
| "latency_ms": latency_ms | |
| } | |
| else: | |
| # Fallback for non-dict output | |
| return { | |
| "results": [{"text": str(output), "similarity": 0.5, "metadata": {}, "doc_id": None}], | |
| "query": "", | |
| "tenant_id": "", | |
| "hits_count": 1, | |
| "avg_score": 0.5, | |
| "top_score": 0.5, | |
| "latency_ms": latency_ms | |
| } | |
| elif tool_name == "web": | |
| # Format Web output | |
| if isinstance(output, dict): | |
| results = output.get("results") or output.get("items") or [] | |
| formatted_results = [] | |
| for r in results: | |
| if isinstance(r, dict): | |
| formatted_results.append({ | |
| "title": r.get("title") or r.get("headline") or "", | |
| "snippet": r.get("snippet") or r.get("summary") or r.get("text") or "", | |
| "link": r.get("url") or r.get("link") or "", | |
| "displayLink": r.get("displayLink") or r.get("display_link") or "" | |
| }) | |
| else: | |
| formatted_results.append({ | |
| "title": "", | |
| "snippet": str(r), | |
| "link": "", | |
| "displayLink": "" | |
| }) | |
| return { | |
| "results": formatted_results, | |
| "query": output.get("query", ""), | |
| "total_results": output.get("total_results") or output.get("totalResults") or len(formatted_results), | |
| "latency_ms": latency_ms | |
| } | |
| else: | |
| return { | |
| "results": [], | |
| "query": "", | |
| "total_results": 0, | |
| "latency_ms": latency_ms | |
| } | |
| elif tool_name == "admin": | |
| # Format Admin output | |
| if isinstance(output, dict): | |
| violations = output.get("violations") or output.get("matches") or [] | |
| formatted_violations = [] | |
| for v in violations: | |
| if isinstance(v, dict): | |
| formatted_violations.append({ | |
| "rule_id": v.get("rule_id") or v.get("id") or "", | |
| "rule_pattern": v.get("rule_pattern") or v.get("pattern") or "", | |
| "severity": v.get("severity", "medium"), | |
| "matched_text": v.get("matched_text") or v.get("text") or "", | |
| "confidence": float(v.get("confidence", 1.0)), | |
| "message_preview": v.get("message_preview") or v.get("preview") or "" | |
| }) | |
| return { | |
| "violations": formatted_violations, | |
| "checked": output.get("checked", True), | |
| "rules_count": output.get("rules_count") or output.get("rulesCount") or len(formatted_violations), | |
| "latency_ms": latency_ms | |
| } | |
| else: | |
| return { | |
| "violations": [], | |
| "checked": True, | |
| "rules_count": 0, | |
| "latency_ms": latency_ms | |
| } | |
| elif tool_name == "llm": | |
| # Format LLM output | |
| if isinstance(output, str): | |
| return { | |
| "text": output, | |
| "tokens_used": len(output) // 4, # Rough estimate | |
| "latency_ms": latency_ms, | |
| "model": getattr(self.llm, 'model', 'unknown'), | |
| "temperature": 0.0 | |
| } | |
| elif isinstance(output, dict): | |
| return { | |
| "text": output.get("text") or output.get("response") or str(output), | |
| "tokens_used": output.get("tokens_used") or output.get("tokens") or 0, | |
| "latency_ms": latency_ms, | |
| "model": output.get("model") or getattr(self.llm, 'model', 'unknown'), | |
| "temperature": output.get("temperature", 0.0) | |
| } | |
| else: | |
| return { | |
| "text": str(output), | |
| "tokens_used": 0, | |
| "latency_ms": latency_ms, | |
| "model": getattr(self.llm, 'model', 'unknown'), | |
| "temperature": 0.0 | |
| } | |
| # Unknown tool - return as-is | |
| return output if isinstance(output, dict) else {"output": str(output), "latency_ms": latency_ms} | |
| def _extract_hits(self, resp: Optional[Dict[str, Any]]) -> List[Dict[str, Any]]: | |
| if not isinstance(resp, dict): | |
| return [] | |
| return resp.get("results") or resp.get("hits") or resp.get("items") or [] | |
| def _summarize_hits(self, resp: Optional[Dict[str, Any]], limit: int = 3) -> List[str]: | |
| hits = self._extract_hits(resp) | |
| summaries = [] | |
| for hit in hits[:limit]: | |
| if isinstance(hit, dict): | |
| snippet = hit.get("text") or hit.get("content") or hit.get("snippet") or "" | |
| else: | |
| snippet = str(hit) | |
| summaries.append(snippet[:160]) | |
| return summaries | |
| async def run_parallel_tools(self, tasks: Dict[str, Any]) -> Dict[str, Any]: | |
| """ | |
| Run multiple tools in parallel using asyncio.gather. | |
| Args: | |
| tasks: Dictionary mapping tool names to coroutines, e.g.: | |
| {"rag": rag_coro, "web": web_coro} | |
| Returns: | |
| Dictionary mapping tool names to results, e.g.: | |
| {"rag": rag_result, "web": web_result} | |
| Exceptions are returned as values if a tool fails. | |
| """ | |
| if not tasks: | |
| return {} | |
| names = list(tasks.keys()) | |
| coros = [tasks[name] for name in names] | |
| # Run all coroutines in parallel, return exceptions instead of raising | |
| results = await asyncio.gather(*coros, return_exceptions=True) | |
| return dict(zip(names, results)) | |
| def _first_query_for_tool(self, steps: List[Dict[str, Any]], tool_name: str, default_query: str) -> Optional[str]: | |
| for step in steps: | |
| if step.get("tool") == tool_name: | |
| input_data = step.get("input") or {} | |
| return input_data.get("query") or default_query | |
| return None | |