| import re
|
| from typing import Dict, Iterator, List
|
|
|
| from memory import get_relevant_context, save_interaction
|
| from model import get_model_manager
|
| from tools import calculator_tool, datetime_tool, text_stats_tool, web_search_tool
|
|
|
|
|
class AgentRouter:
    """Routes user messages to deterministic tools, web search, or the LLM.

    Intent detection is keyword-based (`_detect_tool_intents`); tool outputs
    are either returned directly or handed to the model as extra context.
    """

    def __init__(self) -> None:
        # Shared model manager: provides generate/stream_generate/clean_response.
        self.model = get_model_manager()
|
|
|
| @staticmethod
|
| def _detect_tool_intents(message: str) -> List[str]:
|
| lower = message.lower()
|
|
|
| web_topic_keywords = [
|
| "news",
|
| "update",
|
| "updates",
|
| "trending",
|
| "stock",
|
| "price",
|
| "weather",
|
| "headline",
|
| "headlines",
|
| "happening",
|
| "happened",
|
| ]
|
| freshness_keywords = ["latest", "current", "recent", "today", "now", "this week", "this month"]
|
| datetime_keywords = [
|
| "time",
|
| "date",
|
| "timezone",
|
| "clock",
|
| "what time",
|
| "current time",
|
| "today's date",
|
| "todays date",
|
| ]
|
| calc_keywords = [
|
| "calculate",
|
| "compute",
|
| "solve",
|
| "math",
|
| "equation",
|
| "sum",
|
| "multiply",
|
| "divide",
|
| "plus",
|
| "minus",
|
| ]
|
| text_stats_keywords = [
|
| "word count",
|
| "count words",
|
| "character count",
|
| "text stats",
|
| "text statistics",
|
| "count characters",
|
| ]
|
|
|
| tokens = set(re.findall(r"\b\w+\b", lower))
|
|
|
| def has_phrase_or_token(keyword: str) -> bool:
|
| if " " in keyword:
|
| return keyword in lower
|
| return keyword in tokens
|
|
|
| intents: List[str] = []
|
|
|
| has_web_topic = any(has_phrase_or_token(k) for k in web_topic_keywords)
|
| has_freshness = any(has_phrase_or_token(k) for k in freshness_keywords)
|
| has_datetime = any(has_phrase_or_token(k) for k in datetime_keywords)
|
| has_calc = any(has_phrase_or_token(k) for k in calc_keywords)
|
|
|
|
|
| if has_web_topic or (has_freshness and not has_datetime and not has_calc):
|
| intents.append("web_search")
|
|
|
| if has_datetime:
|
| intents.append("datetime")
|
|
|
| if has_calc:
|
| intents.append("calculator")
|
|
|
|
|
| if "calculator" not in intents and re.search(r"[0-9][0-9\s\+\-\*/\(\)\.\^%]+", lower):
|
| intents.append("calculator")
|
|
|
| if any(has_phrase_or_token(k) for k in text_stats_keywords):
|
| intents.append("text_stats")
|
|
|
| return intents if intents else ["llm"]
|
|
|
| @staticmethod
|
| def _extract_expression(message: str) -> str:
|
| normalized = message.strip().replace("^", "**")
|
| normalized = re.sub(r"^\s*(calculate|compute|solve|what is|what's)\s+", "", normalized, flags=re.IGNORECASE)
|
|
|
| allowed_words = {
|
| "sqrt",
|
| "sin",
|
| "cos",
|
| "tan",
|
| "log",
|
| "log10",
|
| "exp",
|
| "fabs",
|
| "ceil",
|
| "floor",
|
| "pow",
|
| "pi",
|
| "e",
|
| }
|
|
|
| token_pattern = r"[A-Za-z_]+|\d+\.\d+|\d+|\*\*|[+\-*/()%.,]"
|
| raw_tokens = re.findall(token_pattern, normalized)
|
|
|
| expression_tokens: List[str] = []
|
| for token in raw_tokens:
|
| if re.fullmatch(r"[A-Za-z_]+", token):
|
| lowered = token.lower()
|
| if lowered in allowed_words:
|
| expression_tokens.append(lowered)
|
| else:
|
| expression_tokens.append(token)
|
|
|
| expression = "".join(expression_tokens).strip(" ,")
|
|
|
| if not expression:
|
| return normalized
|
|
|
| return expression
|
|
|
| def _run_tools(self, intents: List[str], message: str) -> Dict[str, str]:
|
| outputs: Dict[str, str] = {}
|
|
|
| for intent in intents:
|
| if intent == "datetime":
|
| outputs["datetime"] = datetime_tool()
|
| elif intent == "web_search":
|
| outputs["web_search"] = web_search_tool(message, max_results=5)
|
| elif intent == "calculator":
|
| expression = self._extract_expression(message)
|
| result = calculator_tool(expression)
|
| outputs["calculator"] = f"Expression: {expression}\nResult: {result}"
|
| elif intent == "text_stats":
|
| outputs["text_stats"] = text_stats_tool(message)
|
|
|
| return outputs
|
|
|
| @staticmethod
|
| def _friendly_direct_response(tool_outputs: Dict[str, str]) -> str:
|
| lines: List[str] = ["Sure, here you go:"]
|
|
|
| if "datetime" in tool_outputs:
|
| date_line = ""
|
| time_line = ""
|
| for line in tool_outputs["datetime"].splitlines():
|
| if line.startswith("Current date:"):
|
| date_line = line.replace("Current date:", "").strip()
|
| if line.startswith("Current time:"):
|
| time_line = line.replace("Current time:", "").strip()
|
| if date_line or time_line:
|
| lines.append(f"- Date and time: {date_line} {time_line}".strip())
|
|
|
| if "calculator" in tool_outputs:
|
| result_line = next(
|
| (line for line in tool_outputs["calculator"].splitlines() if line.startswith("Result:")),
|
| "Result: N/A",
|
| )
|
| result = result_line.replace("Result:", "").strip()
|
| lines.append(f"- Calculation result: {result}")
|
|
|
| if "text_stats" in tool_outputs:
|
| stats = tool_outputs["text_stats"].replace("\n", " | ")
|
| lines.append(f"- Text stats: {stats}")
|
|
|
| return "\n".join(lines)
|
|
|
| @staticmethod
|
| def _is_unhelpful_web_response(text: str) -> bool:
|
| lower = text.lower()
|
| bad_patterns = [
|
| "i don't have access to real-time",
|
| "i do not have access to real-time",
|
| "i can't access real-time",
|
| "cannot access real-time",
|
| "as an ai language model",
|
| "you can use any reliable news",
|
| ]
|
| return any(pattern in lower for pattern in bad_patterns)
|
|
|
| @staticmethod
|
| def _summarize_web_tool_output(tool_output: str, message: str) -> str:
|
| if tool_output.startswith("Web search unavailable"):
|
| return "Web search is currently unavailable. Please try again in a moment."
|
|
|
| if tool_output.startswith("No web results found"):
|
| return f"I could not find recent web results for: {message}."
|
|
|
| lines = [line.strip() for line in tool_output.splitlines() if line.strip()]
|
| bullets = []
|
|
|
| for line in lines[:5]:
|
|
|
|
|
| match = re.match(r"^\d+\.\s+(.*?)\s+\|\s+(.*?)\s+\|\s+Source:\s+(.*)$", line)
|
| if match:
|
| title, snippet, source = match.groups()
|
| bullets.append(f"- {title}: {snippet} (Source: {source})")
|
| else:
|
| bullets.append(f"- {line}")
|
|
|
| if not bullets:
|
| return "I found web results, but could not format them cleanly. Please retry."
|
|
|
| return "Here are the latest web results:\n" + "\n".join(bullets)
|
|
|
| @staticmethod
|
| def _extra_tools_summary(tool_outputs: Dict[str, str]) -> str:
|
| extra: List[str] = []
|
| if "datetime" in tool_outputs:
|
| extra.append(tool_outputs["datetime"])
|
| if "calculator" in tool_outputs:
|
| extra.append(tool_outputs["calculator"])
|
| if "text_stats" in tool_outputs:
|
| extra.append(tool_outputs["text_stats"])
|
|
|
| if not extra:
|
| return ""
|
|
|
| return "\n\nAdditional tool outputs:\n" + "\n\n".join(extra)
|
|
|
    def respond(self, user_id: str, message: str) -> Dict[str, object]:
        """Produce a complete (non-streaming) reply for *message*.

        Returns a dict with keys "response", "route_used", "tools_used".
        Routing, in order:
          1. no tool intent        -> plain LLM answer;
          2. only deterministic tools (datetime/calculator/text_stats)
                                   -> formatted tool output, no LLM call;
          3. web search involved   -> LLM summarizes the tool context, with a
             deterministic fallback if the model emits a canned refusal;
          4. anything else         -> tool-augmented LLM answer (defensive;
             see note below).
        Every successful path persists the exchange via save_interaction.
        """
        memory_context = get_relevant_context(user_id, message)
        intents = self._detect_tool_intents(message)

        # Case 1: pure conversational request, no tools involved.
        if intents == ["llm"]:
            response = self.model.generate(
                message=message,
                memory_context=memory_context,
                tool_context="",
            )
            save_interaction(user_id, message, response)
            return {
                "response": response,
                "route_used": "llm",
                "tools_used": [],
            }

        tool_outputs = self._run_tools(intents, message)
        tools_used = list(tool_outputs.keys())

        # Case 2: all tools are deterministic -> answer directly, skip the LLM.
        deterministic_only = set(tool_outputs.keys()).issubset({"datetime", "calculator", "text_stats"})
        if deterministic_only:
            response = self._friendly_direct_response(tool_outputs)
            save_interaction(user_id, message, response)
            route_used = "multi_tool_deterministic" if len(tools_used) > 1 else tools_used[0]
            return {
                "response": response,
                "route_used": route_used,
                "tools_used": tools_used,
            }

        # Concatenate all tool outputs into one context block for the model.
        tool_context_parts = []
        for tool_name, tool_output in tool_outputs.items():
            tool_context_parts.append(f"Tool used: {tool_name}\n{tool_output}")
        tool_context = "\n\n".join(tool_context_parts)

        # Case 3: web search present -> let the LLM summarize, but fall back to
        # a deterministic summary if it produces a "no real-time access" refusal.
        if "web_search" in tool_outputs:
            web_instruction = (
                "Answer using only the provided web results. "
                "Do not say you lack real-time access. "
                "Provide a concise, friendly summary with sources."
            )
            response = self.model.generate(
                message=f"{web_instruction}\n\nUser request: {message}",
                memory_context=memory_context,
                tool_context=tool_context,
            )

            if self._is_unhelpful_web_response(response):
                response = self._summarize_web_tool_output(tool_outputs["web_search"], message)

            # Append any deterministic tool outputs gathered alongside the search.
            extra = self._extra_tools_summary(tool_outputs)
            if extra:
                response = f"{response}{extra}".strip()

            save_interaction(user_id, message, response)
            route_used = "multi_tool_web" if len(tools_used) > 1 else "web_search"
            return {
                "response": response,
                "route_used": route_used,
                "tools_used": tools_used,
            }

        # Case 4: defensive fallback. With the current intent set this is
        # unreachable (a non-deterministic intent set always contains
        # web_search), but it is kept for future tool types.
        response = self.model.generate(
            message=message,
            memory_context=memory_context,
            tool_context=tool_context,
        )

        save_interaction(user_id, message, response)
        return {
            "response": response,
            "route_used": "tool_augmented_llm",
            "tools_used": tools_used,
        }
|
|
|
| @staticmethod
|
| def _split_stream_chunks(text: str, chunk_size: int = 18) -> Iterator[str]:
|
| if not text:
|
| return
|
| words = text.split()
|
| if not words:
|
| return
|
|
|
| buf = []
|
| for word in words:
|
| buf.append(word)
|
| if len(buf) >= chunk_size:
|
| yield " ".join(buf) + " "
|
| buf = []
|
| if buf:
|
| yield " ".join(buf)
|
|
|
    def stream_respond(self, user_id: str, message: str) -> Iterator[Dict[str, object]]:
        """Stream a reply as dict events: zero or more {"type": "chunk", ...}
        followed by exactly one {"type": "done", ...}.

        Mirrors respond()'s routing, except that the web-search route streams
        a deterministic summary of the tool output directly — the model is not
        consulted on that path. Each route saves the final text via
        save_interaction before yielding its "done" event.
        """
        memory_context = get_relevant_context(user_id, message)
        intents = self._detect_tool_intents(message)

        # Pure LLM route: relay model deltas as they arrive.
        if intents == ["llm"]:
            accumulated = ""
            for delta in self.model.stream_generate(message=message, memory_context=memory_context, tool_context=""):
                accumulated += delta
                yield {
                    "type": "chunk",
                    "delta": delta,
                    "route_used": "llm",
                    "tools_used": [],
                }

            final_text = self.model.clean_response(accumulated)
            # If cleanup strips everything, regenerate non-streaming so the
            # "done" event still carries a usable response.
            if not final_text:
                final_text = self.model.generate(message=message, memory_context=memory_context, tool_context="")

            save_interaction(user_id, message, final_text)
            yield {
                "type": "done",
                "response": final_text,
                "route_used": "llm",
                "tools_used": [],
            }
            return

        tool_outputs = self._run_tools(intents, message)
        tools_used = list(tool_outputs.keys())
        deterministic_only = set(tool_outputs.keys()).issubset({"datetime", "calculator", "text_stats"})

        # Deterministic tools only: fake streaming by chunking the final text.
        if deterministic_only:
            final_text = self._friendly_direct_response(tool_outputs)
            route_used = "multi_tool_deterministic" if len(tools_used) > 1 else tools_used[0]

            for delta in self._split_stream_chunks(final_text):
                yield {
                    "type": "chunk",
                    "delta": delta,
                    "route_used": route_used,
                    "tools_used": tools_used,
                }

            save_interaction(user_id, message, final_text)
            yield {
                "type": "done",
                "response": final_text,
                "route_used": route_used,
                "tools_used": tools_used,
            }
            return

        # Concatenate all tool outputs into one context block for the model.
        tool_context_parts = []
        for tool_name, tool_output in tool_outputs.items():
            tool_context_parts.append(f"Tool used: {tool_name}\n{tool_output}")
        tool_context = "\n\n".join(tool_context_parts)

        # Web route: unlike respond(), stream the deterministic summary
        # directly instead of asking the model to summarize.
        if "web_search" in tool_outputs:
            base_text = self._summarize_web_tool_output(tool_outputs["web_search"], message)
            extra = self._extra_tools_summary(tool_outputs)
            final_text = f"{base_text}{extra}".strip() if extra else base_text
            route_used = "multi_tool_web" if len(tools_used) > 1 else "web_search"

            for delta in self._split_stream_chunks(final_text):
                yield {
                    "type": "chunk",
                    "delta": delta,
                    "route_used": route_used,
                    "tools_used": tools_used,
                }

            save_interaction(user_id, message, final_text)
            yield {
                "type": "done",
                "response": final_text,
                "route_used": route_used,
                "tools_used": tools_used,
            }
            return

        # Defensive fallback: stream a tool-augmented LLM answer. With the
        # current intent set this is unreachable (a non-deterministic intent
        # set always contains web_search) but it is kept for future tools.
        accumulated = ""
        for delta in self.model.stream_generate(message=message, memory_context=memory_context, tool_context=tool_context):
            accumulated += delta
            yield {
                "type": "chunk",
                "delta": delta,
                "route_used": "tool_augmented_llm",
                "tools_used": tools_used,
            }

        final_text = self.model.clean_response(accumulated)
        if not final_text:
            final_text = self.model.generate(message=message, memory_context=memory_context, tool_context=tool_context)

        save_interaction(user_id, message, final_text)
        yield {
            "type": "done",
            "response": final_text,
            "route_used": "tool_augmented_llm",
            "tools_used": tools_used,
        }
|
|
|
|
|
# Module-level singleton; note __init__ calls get_model_manager() at import time.
agent_router = AgentRouter()
|
|
|