Valtry commited on
Commit
cbfc437
·
verified ·
1 Parent(s): e4f9ea1

Upload 2 files

Browse files
Files changed (2) hide show
  1. agent.py +108 -48
  2. tools.py +12 -0
agent.py CHANGED
@@ -1,9 +1,9 @@
1
  import re
2
- from typing import Tuple
3
 
4
  from memory import get_relevant_context, save_interaction
5
  from model import get_model_manager
6
- from tools import calculator_tool, datetime_tool, web_search_tool
7
 
8
 
9
  class AgentRouter:
@@ -11,7 +11,7 @@ class AgentRouter:
11
  self.model = get_model_manager()
12
 
13
  @staticmethod
14
- def _classify(message: str) -> str:
15
  lower = message.lower()
16
 
17
  web_keywords = [
@@ -50,6 +50,14 @@ class AgentRouter:
50
  "plus",
51
  "minus",
52
  ]
 
 
 
 
 
 
 
 
53
 
54
  tokens = set(re.findall(r"\b\w+\b", lower))
55
 
@@ -58,24 +66,25 @@ class AgentRouter:
58
  return keyword in lower
59
  return keyword in tokens
60
 
61
- has_web_intent = any(has_phrase_or_token(k) for k in web_keywords)
62
- has_datetime_intent = any(has_phrase_or_token(k) for k in datetime_keywords)
63
 
64
- # Give web queries priority when both are present (e.g., "latest AI news today").
65
- if has_web_intent:
66
- return "web"
67
 
68
- if has_datetime_intent:
69
- return "datetime"
70
 
71
  if any(has_phrase_or_token(k) for k in calc_keywords):
72
- return "calculator"
73
 
74
  # Fallback detection for math-like expressions.
75
- if re.search(r"[0-9][0-9\s\+\-\*/\(\)\.\^%]+", lower):
76
- return "calculator"
 
 
 
77
 
78
- return "llm"
79
 
80
  @staticmethod
81
  def _extract_expression(message: str) -> str:
@@ -88,36 +97,58 @@ class AgentRouter:
88
  if cleaned.lower().startswith(prefix):
89
  cleaned = cleaned[len(prefix):].strip()
90
 
91
- return cleaned.strip() or message.strip()
 
 
 
92
 
93
- def _tool_context(self, route: str, message: str) -> Tuple[str, str]:
94
- if route == "datetime":
95
- tool_result = datetime_tool()
96
- return "datetime", tool_result
97
 
98
- if route == "web":
99
- tool_result = web_search_tool(message, max_results=5)
100
- return "web_search", tool_result
101
 
102
- if route == "calculator":
103
- expression = self._extract_expression(message)
104
- tool_result = calculator_tool(expression)
105
- return "calculator", f"Expression: {expression}\nResult: {tool_result}"
 
 
 
 
 
 
 
106
 
107
- return "llm", ""
108
 
109
  @staticmethod
110
- def _direct_tool_response(tool_name: str, tool_output: str) -> str:
111
- if tool_name == "datetime":
112
- return tool_output
113
-
114
- if tool_name == "calculator":
115
- lines = tool_output.splitlines()
116
- result_line = next((line for line in lines if line.startswith("Result:")), "Result: N/A")
 
 
 
 
 
 
 
 
 
 
 
 
117
  result = result_line.replace("Result:", "").strip()
118
- return f"The result is {result}."
 
 
 
 
119
 
120
- return tool_output
121
 
122
  @staticmethod
123
  def _is_unhelpful_web_response(text: str) -> bool:
@@ -158,27 +189,52 @@ class AgentRouter:
158
 
159
  return "Here are the latest web results:\n" + "\n".join(bullets)
160
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  def respond(self, user_id: str, message: str) -> str:
162
  memory_context = get_relevant_context(user_id, message)
163
- route = self._classify(message)
164
- tool_name, tool_output = self._tool_context(route, message)
165
 
166
- # For deterministic utility queries, respond directly from tools.
167
- if tool_name in {"datetime", "calculator"}:
168
- response = self._direct_tool_response(tool_name, tool_output)
 
 
 
169
  save_interaction(user_id, message, response)
170
  return response
171
 
172
- if tool_name == "llm":
173
- tool_context = ""
174
- else:
175
- tool_context = f"Tool used: {tool_name}\n{tool_output}"
176
 
177
- if tool_name == "web_search":
 
 
 
 
 
 
 
 
 
 
 
178
  web_instruction = (
179
  "Answer using only the provided web results. "
180
  "Do not say you lack real-time access. "
181
- "Provide a concise summary with sources."
182
  )
183
  response = self.model.generate(
184
  message=f"{web_instruction}\n\nUser request: {message}",
@@ -187,7 +243,11 @@ class AgentRouter:
187
  )
188
 
189
  if self._is_unhelpful_web_response(response):
190
- response = self._summarize_web_tool_output(tool_output, message)
 
 
 
 
191
 
192
  save_interaction(user_id, message, response)
193
  return response
 
1
  import re
2
+ from typing import Dict, List
3
 
4
  from memory import get_relevant_context, save_interaction
5
  from model import get_model_manager
6
+ from tools import calculator_tool, datetime_tool, text_stats_tool, web_search_tool
7
 
8
 
9
  class AgentRouter:
 
11
  self.model = get_model_manager()
12
 
13
  @staticmethod
14
+ def _detect_tool_intents(message: str) -> List[str]:
15
  lower = message.lower()
16
 
17
  web_keywords = [
 
50
  "plus",
51
  "minus",
52
  ]
53
+ text_stats_keywords = [
54
+ "word count",
55
+ "count words",
56
+ "character count",
57
+ "text stats",
58
+ "text statistics",
59
+ "count characters",
60
+ ]
61
 
62
  tokens = set(re.findall(r"\b\w+\b", lower))
63
 
 
66
  return keyword in lower
67
  return keyword in tokens
68
 
69
+ intents: List[str] = []
 
70
 
71
+ if any(has_phrase_or_token(k) for k in web_keywords):
72
+ intents.append("web_search")
 
73
 
74
+ if any(has_phrase_or_token(k) for k in datetime_keywords):
75
+ intents.append("datetime")
76
 
77
  if any(has_phrase_or_token(k) for k in calc_keywords):
78
+ intents.append("calculator")
79
 
80
  # Fallback detection for math-like expressions.
81
+ if "calculator" not in intents and re.search(r"[0-9][0-9\s\+\-\*/\(\)\.\^%]+", lower):
82
+ intents.append("calculator")
83
+
84
+ if any(has_phrase_or_token(k) for k in text_stats_keywords):
85
+ intents.append("text_stats")
86
 
87
+ return intents if intents else ["llm"]
88
 
89
  @staticmethod
90
  def _extract_expression(message: str) -> str:
 
97
  if cleaned.lower().startswith(prefix):
98
  cleaned = cleaned[len(prefix):].strip()
99
 
100
+ matches = re.findall(r"[0-9\s\+\-\*/\(\)\.\^%]+", cleaned)
101
+ ranked = [m.strip() for m in matches if re.search(r"\d", m) and re.search(r"[\+\-\*/\^%]", m)]
102
+ if ranked:
103
+ return max(ranked, key=len).replace("^", "**")
104
 
105
+ return cleaned.strip() or message.strip()
 
 
 
106
 
107
+ def _run_tools(self, intents: List[str], message: str) -> Dict[str, str]:
108
+ outputs: Dict[str, str] = {}
 
109
 
110
+ for intent in intents:
111
+ if intent == "datetime":
112
+ outputs["datetime"] = datetime_tool()
113
+ elif intent == "web_search":
114
+ outputs["web_search"] = web_search_tool(message, max_results=5)
115
+ elif intent == "calculator":
116
+ expression = self._extract_expression(message)
117
+ result = calculator_tool(expression)
118
+ outputs["calculator"] = f"Expression: {expression}\nResult: {result}"
119
+ elif intent == "text_stats":
120
+ outputs["text_stats"] = text_stats_tool(message)
121
 
122
+ return outputs
123
 
124
  @staticmethod
125
+ def _friendly_direct_response(tool_outputs: Dict[str, str]) -> str:
126
+ lines: List[str] = ["Sure, here you go:"]
127
+
128
+ if "datetime" in tool_outputs:
129
+ date_line = ""
130
+ time_line = ""
131
+ for line in tool_outputs["datetime"].splitlines():
132
+ if line.startswith("Current date:"):
133
+ date_line = line.replace("Current date:", "").strip()
134
+ if line.startswith("Current time:"):
135
+ time_line = line.replace("Current time:", "").strip()
136
+ if date_line or time_line:
137
+ lines.append(f"- Date and time: {date_line} {time_line}".strip())
138
+
139
+ if "calculator" in tool_outputs:
140
+ result_line = next(
141
+ (line for line in tool_outputs["calculator"].splitlines() if line.startswith("Result:")),
142
+ "Result: N/A",
143
+ )
144
  result = result_line.replace("Result:", "").strip()
145
+ lines.append(f"- Calculation result: {result}")
146
+
147
+ if "text_stats" in tool_outputs:
148
+ stats = tool_outputs["text_stats"].replace("\n", " | ")
149
+ lines.append(f"- Text stats: {stats}")
150
 
151
+ return "\n".join(lines)
152
 
153
  @staticmethod
154
  def _is_unhelpful_web_response(text: str) -> bool:
 
189
 
190
  return "Here are the latest web results:\n" + "\n".join(bullets)
191
 
192
+ @staticmethod
193
+ def _extra_tools_summary(tool_outputs: Dict[str, str]) -> str:
194
+ extra: List[str] = []
195
+ if "datetime" in tool_outputs:
196
+ extra.append(tool_outputs["datetime"])
197
+ if "calculator" in tool_outputs:
198
+ extra.append(tool_outputs["calculator"])
199
+ if "text_stats" in tool_outputs:
200
+ extra.append(tool_outputs["text_stats"])
201
+
202
+ if not extra:
203
+ return ""
204
+
205
+ return "\n\nAdditional tool outputs:\n" + "\n\n".join(extra)
206
+
207
  def respond(self, user_id: str, message: str) -> str:
208
  memory_context = get_relevant_context(user_id, message)
209
+ intents = self._detect_tool_intents(message)
 
210
 
211
+ if intents == ["llm"]:
212
+ response = self.model.generate(
213
+ message=message,
214
+ memory_context=memory_context,
215
+ tool_context="",
216
+ )
217
  save_interaction(user_id, message, response)
218
  return response
219
 
220
+ tool_outputs = self._run_tools(intents, message)
 
 
 
221
 
222
+ deterministic_only = set(tool_outputs.keys()).issubset({"datetime", "calculator", "text_stats"})
223
+ if deterministic_only:
224
+ response = self._friendly_direct_response(tool_outputs)
225
+ save_interaction(user_id, message, response)
226
+ return response
227
+
228
+ tool_context_parts = []
229
+ for tool_name, tool_output in tool_outputs.items():
230
+ tool_context_parts.append(f"Tool used: {tool_name}\n{tool_output}")
231
+ tool_context = "\n\n".join(tool_context_parts)
232
+
233
+ if "web_search" in tool_outputs:
234
  web_instruction = (
235
  "Answer using only the provided web results. "
236
  "Do not say you lack real-time access. "
237
+ "Provide a concise, friendly summary with sources."
238
  )
239
  response = self.model.generate(
240
  message=f"{web_instruction}\n\nUser request: {message}",
 
243
  )
244
 
245
  if self._is_unhelpful_web_response(response):
246
+ response = self._summarize_web_tool_output(tool_outputs["web_search"], message)
247
+
248
+ extra = self._extra_tools_summary(tool_outputs)
249
+ if extra:
250
+ response = f"{response}{extra}".strip()
251
 
252
  save_interaction(user_id, message, response)
253
  return response
tools.py CHANGED
@@ -1,6 +1,7 @@
1
  import ast
2
  import datetime as dt
3
  import math
 
4
  from functools import lru_cache
5
  from typing import Dict, List
6
 
@@ -127,3 +128,14 @@ def datetime_tool() -> str:
127
  f"Current date: {now.strftime('%Y-%m-%d')}\n"
128
  f"Current time: {now.strftime('%H:%M:%S %Z')}"
129
  )
 
 
 
 
 
 
 
 
 
 
 
 
1
  import ast
2
  import datetime as dt
3
  import math
4
+ import re
5
  from functools import lru_cache
6
  from typing import Dict, List
7
 
 
128
  f"Current date: {now.strftime('%Y-%m-%d')}\n"
129
  f"Current time: {now.strftime('%H:%M:%S %Z')}"
130
  )
131
+
132
+
133
+ def text_stats_tool(text: str) -> str:
134
+ cleaned = text.strip()
135
+ if not cleaned:
136
+ return "Words: 0\nCharacters: 0\nSentences: 0"
137
+
138
+ words = len(re.findall(r"\b\w+\b", cleaned))
139
+ chars = len(cleaned)
140
+ sentences = len([s for s in re.split(r"[.!?]+", cleaned) if s.strip()])
141
+ return f"Words: {words}\nCharacters: {chars}\nSentences: {sentences}"