# AI-Agent / agent.py
# Author: Valtry
# Commit d70c8a7 ("Upload 2 files", verified)
import re
from typing import Dict, Iterator, List
from memory import get_relevant_context, save_interaction
from model import get_model_manager
from tools import calculator_tool, datetime_tool, text_stats_tool, web_search_tool
class AgentRouter:
    """Route chat messages to the LLM, to deterministic tools, or to web search.

    ``respond`` returns one result dict; ``stream_respond`` yields
    ``{"type": "chunk", ...}`` events followed by a single
    ``{"type": "done", ...}`` event. Both persist the final answer with
    ``save_interaction``.
    """

    # Tools whose output is shown to the user directly, without an LLM pass.
    _DETERMINISTIC_TOOLS = frozenset({"datetime", "calculator", "text_stats"})

    def __init__(self) -> None:
        # Model wrapper providing generate / stream_generate / clean_response.
        self.model = get_model_manager()

    @staticmethod
    def _detect_tool_intents(message: str) -> List[str]:
        """Return the tool intents implied by *message*, or ``["llm"]`` if none apply.

        Single-word keywords are matched against whole word tokens (so "update"
        does not match "date"); multi-word keywords are matched as substrings of
        the lower-cased message. Intents are appended in the order
        "web_search", "datetime", "calculator", "text_stats".
        """
        lower = message.lower()
        web_topic_keywords = [
            "news",
            "update",
            "updates",
            "trending",
            "stock",
            "price",
            "weather",
            "headline",
            "headlines",
            "happening",
            "happened",
        ]
        freshness_keywords = ["latest", "current", "recent", "today", "now", "this week", "this month"]
        datetime_keywords = [
            "time",
            "date",
            "timezone",
            "clock",
            "what time",
            "current time",
            "today's date",
            "todays date",
        ]
        calc_keywords = [
            "calculate",
            "compute",
            "solve",
            "math",
            "equation",
            "sum",
            "multiply",
            "divide",
            "plus",
            "minus",
        ]
        text_stats_keywords = [
            "word count",
            "count words",
            "character count",
            "text stats",
            "text statistics",
            "count characters",
        ]
        tokens = set(re.findall(r"\b\w+\b", lower))

        def has_phrase_or_token(keyword: str) -> bool:
            # Phrases need substring matching; single words must be whole tokens.
            if " " in keyword:
                return keyword in lower
            return keyword in tokens

        intents: List[str] = []
        has_web_topic = any(has_phrase_or_token(k) for k in web_topic_keywords)
        has_freshness = any(has_phrase_or_token(k) for k in freshness_keywords)
        has_datetime = any(has_phrase_or_token(k) for k in datetime_keywords)
        has_calc = any(has_phrase_or_token(k) for k in calc_keywords)
        # Avoid misrouting "current time" style prompts to web search.
        if has_web_topic or (has_freshness and not has_datetime and not has_calc):
            intents.append("web_search")
        if has_datetime:
            intents.append("datetime")
        if has_calc:
            intents.append("calculator")
        # Fallback detection for math-like expressions: a number followed by an
        # arithmetic operator and another operand ("2+2", "3 * (4-1)", "2^10"),
        # or a call to a supported math function ("sqrt(16)"). A bare number
        # ("top 10 movies", "iPhone 15") is deliberately NOT enough.
        elif re.search(
            r"\d\s*(?:\*\*|[-+*/^%])\s*[\d(.]"
            r"|\b(?:sqrt|sin|cos|tan|log|log10|exp|fabs|ceil|floor|pow)\s*\(",
            lower,
        ):
            intents.append("calculator")
        if any(has_phrase_or_token(k) for k in text_stats_keywords):
            intents.append("text_stats")
        return intents if intents else ["llm"]

    @staticmethod
    def _extract_expression(message: str) -> str:
        """Pull an arithmetic expression out of a free-text calculator request.

        The following transformations are applied:

        - ``^`` becomes ``**``.
        - A leading "calculate" / "compute" / "solve" / "what is" / "what's"
          is removed.
        - The English operators "plus", "minus", "times", "multiplied (by)",
          "divided (by)" and "mod"/"modulo" become symbols.
        - Every other word except the whitelisted math names is discarded.
        - Two operands left adjacent after dropping words ("3 and 5") are
          separated by a space rather than glued into one number ("35").
          The calculator then rejects them instead of returning a wrong answer.
        - A sentence-final "." is dropped.

        If nothing usable remains, the stripped message (with ``^`` replaced)
        is returned instead.
        """
        normalized = message.strip().replace("^", "**")
        normalized = re.sub(r"^\s*(calculate|compute|solve|what is|what's)\s+", "", normalized, flags=re.IGNORECASE)
        allowed_words = {
            "sqrt",
            "sin",
            "cos",
            "tan",
            "log",
            "log10",
            "exp",
            "fabs",
            "ceil",
            "floor",
            "pow",
            "pi",
            "e",
        }
        word_operators = {
            "plus": "+",
            "minus": "-",
            "times": "*",
            "multiplied": "*",
            "divided": "/",
            "mod": "%",
            "modulo": "%",
        }
        # Scientific notation ("1e-3") is matched first so its exponent marker
        # is not split off as the constant ``e``; ``log10`` is matched before
        # generic words so its digits stay attached to the function name.
        token_pattern = (
            r"\d+(?:\.\d+)?[eE][+\-]?\d+"
            r"|(?i:log10)(?!\d)"
            r"|[A-Za-z_]+"
            r"|\d+\.\d+|\d+|\*\*|[+\-*/()%.,]"
        )
        parts: List[str] = []
        prev_is_operand = False
        for token in re.findall(token_pattern, normalized):
            if token[0].isalpha() or token[0] == "_":
                lowered = token.lower()
                if lowered in word_operators:
                    token = word_operators[lowered]
                elif lowered in allowed_words:
                    token = lowered
                else:
                    continue  # Filler word ("the", "of", "and", "by", ...).
            is_operand = token[0].isdigit() or token[0].isalpha()
            if is_operand and prev_is_operand:
                parts.append(" ")
            parts.append(token)
            prev_is_operand = is_operand
        expression = "".join(parts).strip(" ,").rstrip(".")
        return expression or normalized

    def _run_tools(self, intents: List[str], message: str) -> Dict[str, str]:
        """Run the tool for each recognised intent and collect the outputs.

        Keys of the returned dict are intent names in the order given; intents
        without a tool (such as "llm") are ignored.
        """
        outputs: Dict[str, str] = {}
        for intent in intents:
            if intent == "datetime":
                outputs["datetime"] = datetime_tool()
            elif intent == "web_search":
                outputs["web_search"] = web_search_tool(message, max_results=5)
            elif intent == "calculator":
                expression = self._extract_expression(message)
                result = calculator_tool(expression)
                outputs["calculator"] = f"Expression: {expression}\nResult: {result}"
            elif intent == "text_stats":
                outputs["text_stats"] = text_stats_tool(message)
        return outputs

    @staticmethod
    def _friendly_direct_response(tool_outputs: Dict[str, str]) -> str:
        """Render deterministic tool outputs as a short bullet list.

        The output is built from these parts of *tool_outputs*:

        - the "Current date:" / "Current time:" lines of the datetime output;
        - the "Result:" line of the calculator output;
        - the text_stats output, flattened onto one line.
        """
        lines: List[str] = ["Sure, here you go:"]
        if "datetime" in tool_outputs:
            date_part = ""
            time_part = ""
            for raw_line in tool_outputs["datetime"].splitlines():
                line = raw_line.strip()
                if line.startswith("Current date:"):
                    date_part = line[len("Current date:"):].strip()
                elif line.startswith("Current time:"):
                    time_part = line[len("Current time:"):].strip()
            # Join only the parts that were found, so a missing date does not
            # leave a double space after the colon.
            when = " ".join(part for part in (date_part, time_part) if part)
            if when:
                lines.append(f"- Date and time: {when}")
        if "calculator" in tool_outputs:
            result_line = next(
                (line for line in tool_outputs["calculator"].splitlines() if line.startswith("Result:")),
                "Result: N/A",
            )
            result = result_line[len("Result:"):].strip()
            lines.append(f"- Calculation result: {result}")
        if "text_stats" in tool_outputs:
            stats = tool_outputs["text_stats"].replace("\n", " | ")
            lines.append(f"- Text stats: {stats}")
        return "\n".join(lines)

    @staticmethod
    def _is_unhelpful_web_response(text: str) -> bool:
        """Return True if a model answer to a web query is empty or a canned refusal."""
        if not text.strip():
            return True
        # Models often emit a typographic apostrophe ("don’t"); normalise it so
        # the plain-ASCII patterns below still match.
        lower = text.lower().replace("\u2019", "'")
        bad_patterns = [
            "i don't have access to real-time",
            "i do not have access to real-time",
            "i can't access real-time",
            "cannot access real-time",
            "as an ai language model",
            "you can use any reliable news",
        ]
        return any(pattern in lower for pattern in bad_patterns)

    @staticmethod
    def _summarize_web_tool_output(tool_output: str, message: str) -> str:
        """Turn raw web_search_tool output into a bullet list without using the LLM.

        Formatting rules:

        - At most the first five non-empty lines are used.
        - A line shaped like ``"1. Title | snippet | Source: url"`` becomes
          ``"- Title: snippet (Source: url)"``.
        - Any other line is bulleted verbatim.
        """
        if tool_output.startswith("Web search unavailable"):
            return "Web search is currently unavailable. Please try again in a moment."
        if tool_output.startswith("No web results found"):
            # Drop the user's trailing punctuation so the sentence does not end in "?.".
            query = message.strip().rstrip(".?!").strip()
            return f"I could not find recent web results for: {query}."
        lines = [line.strip() for line in tool_output.splitlines() if line.strip()]
        bullets: List[str] = []
        for line in lines[:5]:
            # Expected line format:
            # 1. Title | snippet text | Source: https://...
            match = re.match(r"^\d+\.\s+(.*?)\s+\|\s+(.*?)\s+\|\s+Source:\s+(.*)$", line)
            if match:
                title, snippet, source = match.groups()
                bullets.append(f"- {title}: {snippet} (Source: {source})")
            else:
                bullets.append(f"- {line}")
        if not bullets:
            return "I found web results, but could not format them cleanly. Please retry."
        return "Here are the latest web results:\n" + "\n".join(bullets)

    @staticmethod
    def _extra_tools_summary(tool_outputs: Dict[str, str]) -> str:
        """Return the non-web tool outputs as a trailing section, or "" if there are none."""
        extra: List[str] = []
        if "datetime" in tool_outputs:
            extra.append(tool_outputs["datetime"])
        if "calculator" in tool_outputs:
            extra.append(tool_outputs["calculator"])
        if "text_stats" in tool_outputs:
            extra.append(tool_outputs["text_stats"])
        if not extra:
            return ""
        return "\n\nAdditional tool outputs:\n" + "\n\n".join(extra)

    @staticmethod
    def _build_tool_context(tool_outputs: Dict[str, str]) -> str:
        """Format every tool output as the ``tool_context`` string handed to the model."""
        return "\n\n".join(f"Tool used: {name}\n{output}" for name, output in tool_outputs.items())

    def respond(self, user_id: str, message: str) -> Dict[str, object]:
        """Answer *message* for *user_id* in one shot.

        Returns a dict with "response" (answer text), "route_used" and
        "tools_used". The exchange is saved with ``save_interaction`` first.
        """
        memory_context = get_relevant_context(user_id, message)
        intents = self._detect_tool_intents(message)
        if intents == ["llm"]:
            response = self.model.generate(
                message=message,
                memory_context=memory_context,
                tool_context="",
            )
            save_interaction(user_id, message, response)
            return {
                "response": response,
                "route_used": "llm",
                "tools_used": [],
            }

        tool_outputs = self._run_tools(intents, message)
        tools_used = list(tool_outputs.keys())
        if set(tool_outputs).issubset(self._DETERMINISTIC_TOOLS):
            response = self._friendly_direct_response(tool_outputs)
            save_interaction(user_id, message, response)
            route_used = "multi_tool_deterministic" if len(tools_used) > 1 else tools_used[0]
            return {
                "response": response,
                "route_used": route_used,
                "tools_used": tools_used,
            }

        tool_context = self._build_tool_context(tool_outputs)
        if "web_search" in tool_outputs:
            web_instruction = (
                "Answer using only the provided web results. "
                "Do not say you lack real-time access. "
                "Provide a concise, friendly summary with sources."
            )
            response = self.model.generate(
                message=f"{web_instruction}\n\nUser request: {message}",
                memory_context=memory_context,
                tool_context=tool_context,
            )
            # Fall back to a deterministic summary when the model refuses or returns nothing.
            if self._is_unhelpful_web_response(response):
                response = self._summarize_web_tool_output(tool_outputs["web_search"], message)
            extra = self._extra_tools_summary(tool_outputs)
            if extra:
                response = f"{response}{extra}".strip()
            save_interaction(user_id, message, response)
            route_used = "multi_tool_web" if len(tools_used) > 1 else "web_search"
            return {
                "response": response,
                "route_used": route_used,
                "tools_used": tools_used,
            }

        response = self.model.generate(
            message=message,
            memory_context=memory_context,
            tool_context=tool_context,
        )
        save_interaction(user_id, message, response)
        return {
            "response": response,
            "route_used": "tool_augmented_llm",
            "tools_used": tools_used,
        }

    @staticmethod
    def _split_stream_chunks(text: str, chunk_size: int = 18) -> Iterator[str]:
        """Yield *text* in pieces of up to *chunk_size* words for pseudo-streaming.

        Each word keeps the whitespace that follows it, and the first piece
        keeps any leading whitespace. As a result ``"".join(chunks) == text``
        and line breaks survive. Empty or whitespace-only text yields nothing.
        A *chunk_size* below 1 is treated as 1.
        """
        # Each match is: optional leading whitespace, one word, trailing whitespace.
        words = re.findall(r"\s*\S+\s*", text)
        step = max(chunk_size, 1)
        for start in range(0, len(words), step):
            yield "".join(words[start:start + step])

    def _stream_fixed_text(
        self,
        user_id: str,
        message: str,
        final_text: str,
        route_used: str,
        tools_used: List[str],
    ) -> Iterator[Dict[str, object]]:
        """Stream already-computed *final_text* as chunk events, save it, then emit "done"."""
        for delta in self._split_stream_chunks(final_text):
            yield {
                "type": "chunk",
                "delta": delta,
                "route_used": route_used,
                "tools_used": tools_used,
            }
        save_interaction(user_id, message, final_text)
        yield {
            "type": "done",
            "response": final_text,
            "route_used": route_used,
            "tools_used": tools_used,
        }

    def _stream_model(
        self,
        user_id: str,
        message: str,
        memory_context: object,  # whatever get_relevant_context returned
        tool_context: str,
        route_used: str,
        tools_used: List[str],
    ) -> Iterator[Dict[str, object]]:
        """Relay model deltas as chunk events, then save and emit the cleaned final text.

        If ``clean_response`` leaves nothing, a non-streaming ``generate`` call
        with the same inputs supplies the final text instead.
        """
        pieces: List[str] = []
        for delta in self.model.stream_generate(
            message=message, memory_context=memory_context, tool_context=tool_context
        ):
            pieces.append(delta)
            yield {
                "type": "chunk",
                "delta": delta,
                "route_used": route_used,
                "tools_used": tools_used,
            }
        final_text = self.model.clean_response("".join(pieces))
        if not final_text:
            final_text = self.model.generate(
                message=message, memory_context=memory_context, tool_context=tool_context
            )
        save_interaction(user_id, message, final_text)
        yield {
            "type": "done",
            "response": final_text,
            "route_used": route_used,
            "tools_used": tools_used,
        }

    def stream_respond(self, user_id: str, message: str) -> Iterator[Dict[str, object]]:
        """Answer *message* for *user_id* as a stream of events.

        Yields "chunk" events (with "delta") followed by one "done" event (with
        the full "response"); every event carries "route_used" and "tools_used".
        """
        memory_context = get_relevant_context(user_id, message)
        intents = self._detect_tool_intents(message)
        if intents == ["llm"]:
            yield from self._stream_model(user_id, message, memory_context, "", "llm", [])
            return

        tool_outputs = self._run_tools(intents, message)
        tools_used = list(tool_outputs.keys())
        if set(tool_outputs).issubset(self._DETERMINISTIC_TOOLS):
            final_text = self._friendly_direct_response(tool_outputs)
            route_used = "multi_tool_deterministic" if len(tools_used) > 1 else tools_used[0]
            yield from self._stream_fixed_text(user_id, message, final_text, route_used, tools_used)
            return

        if "web_search" in tool_outputs:
            # Stream deterministic web summaries for reliability.
            final_text = self._summarize_web_tool_output(tool_outputs["web_search"], message)
            extra = self._extra_tools_summary(tool_outputs)
            if extra:
                final_text = f"{final_text}{extra}".strip()
            route_used = "multi_tool_web" if len(tools_used) > 1 else "web_search"
            yield from self._stream_fixed_text(user_id, message, final_text, route_used, tools_used)
            return

        tool_context = self._build_tool_context(tool_outputs)
        yield from self._stream_model(
            user_id, message, memory_context, tool_context, "tool_augmented_llm", tools_used
        )
# Shared module-level router; constructing it calls get_model_manager() when this module is imported.
agent_router: AgentRouter = AgentRouter()