Spaces:
Sleeping
Sleeping
ernani
commited on
Commit
·
30c3969
1
Parent(s):
57fe70d
adding math tool - fixing prompts
Browse files- manage_agents.py +75 -10
- tools.py +39 -0
manage_agents.py
CHANGED
|
@@ -17,6 +17,7 @@ from tools import (
|
|
| 17 |
ExcelTool,
|
| 18 |
WebSearchTool,
|
| 19 |
ArvixSearchTool,
|
|
|
|
| 20 |
PythonTool,
|
| 21 |
ContentProcessingError
|
| 22 |
)
|
|
@@ -45,7 +46,8 @@ class ContentTypeAgent:
|
|
| 45 |
"excel": ExcelTool(),
|
| 46 |
"web": WebSearchTool(),
|
| 47 |
"arvix": ArvixSearchTool(),
|
| 48 |
-
"python": PythonTool()
|
|
|
|
| 49 |
}
|
| 50 |
|
| 51 |
self.type_identification_prompt = PromptTemplate(
|
|
@@ -61,6 +63,7 @@ class ContentTypeAgent:
|
|
| 61 |
- excel: If the question refers to an Excel file or contains a task ID for Excel
|
| 62 |
- web: If the question is about a research paper, academic work, or requires searching the broader web (including news, journals, or scholarly articles)
|
| 63 |
- python: If the question refers to a Python file or contains a task ID for Python
|
|
|
|
| 64 |
|
| 65 |
Consider these special cases:
|
| 66 |
1. If the question asks to search in Wikipedia, use "wiki"
|
|
@@ -143,6 +146,10 @@ class ContentTypeAgent:
|
|
| 143 |
if content_type == "wiki" and 'wikipedia' not in question_lower:
|
| 144 |
content_type = "web"
|
| 145 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
return content_type, question, task_id
|
| 147 |
|
| 148 |
class ProcessContentAgent:
|
|
@@ -302,6 +309,7 @@ class StateGraphAgent:
|
|
| 302 |
self.excel_tool = ExcelTool()
|
| 303 |
self.python_tool = PythonTool()
|
| 304 |
self.arvix_tool = ArvixSearchTool()
|
|
|
|
| 305 |
|
| 306 |
# Create a dictionary of tools for easy access
|
| 307 |
self.tools = {
|
|
@@ -312,7 +320,8 @@ class StateGraphAgent:
|
|
| 312 |
"audio": self.audio_tool,
|
| 313 |
"excel": self.excel_tool,
|
| 314 |
"python": self.python_tool,
|
| 315 |
-
"arvix": self.arvix_tool
|
|
|
|
| 316 |
}
|
| 317 |
|
| 318 |
# Tool usage tracking
|
|
@@ -479,7 +488,24 @@ class StateGraphAgent:
|
|
| 479 |
"required": ["task_id", "question"]
|
| 480 |
}
|
| 481 |
}
|
| 482 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 483 |
}
|
| 484 |
|
| 485 |
def _identify_content_type(self, question, file_name, task_id):
|
|
@@ -590,6 +616,10 @@ class StateGraphAgent:
|
|
| 590 |
question = args.get("question", "")
|
| 591 |
self.last_used_tool = "python"
|
| 592 |
result = self.python_tool._run(task_id, question=question)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 593 |
else:
|
| 594 |
result = f"Unknown tool: {tool_name}"
|
| 595 |
|
|
@@ -670,16 +700,31 @@ class StateGraphAgent:
|
|
| 670 |
try:
|
| 671 |
# Reset tool tracking
|
| 672 |
self.last_used_tool = None
|
| 673 |
-
|
| 674 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 675 |
direct_answer = self._direct_answer_attempt(question)
|
| 676 |
if direct_answer:
|
| 677 |
self.last_used_tool = "direct"
|
| 678 |
return direct_answer
|
| 679 |
|
| 680 |
-
# Identify content type
|
| 681 |
-
content_type, content_parameter = self._identify_content_type(question, file_name, task_id)
|
| 682 |
-
|
| 683 |
# For file-based questions, use the appropriate tool directly
|
| 684 |
if file_name:
|
| 685 |
if content_type in self.file_tools:
|
|
@@ -900,6 +945,7 @@ class MainAgent:
|
|
| 900 |
self.audio_tool = AudioTool()
|
| 901 |
self.excel_tool = ExcelTool()
|
| 902 |
self.python_tool = PythonTool()
|
|
|
|
| 903 |
|
| 904 |
# Create a dictionary of tools for easy access
|
| 905 |
self.tools = {
|
|
@@ -911,6 +957,7 @@ class MainAgent:
|
|
| 911 |
"audio": self.audio_tool,
|
| 912 |
"excel": self.excel_tool,
|
| 913 |
"python": self.python_tool,
|
|
|
|
| 914 |
}
|
| 915 |
|
| 916 |
# Tool usage tracking
|
|
@@ -1301,6 +1348,9 @@ class MainAgent:
|
|
| 1301 |
task_id = args.get("task_id", "")
|
| 1302 |
question = args.get("question", "")
|
| 1303 |
result = self.python_tool._run(task_id, question=question)
|
|
|
|
|
|
|
|
|
|
| 1304 |
else:
|
| 1305 |
result = f"Unknown tool: {tool_name}"
|
| 1306 |
|
|
@@ -1333,7 +1383,7 @@ class MainAgent:
|
|
| 1333 |
|
| 1334 |
question = state.messages[0].content
|
| 1335 |
query = f"""Analyze the question, understand the instructions, the context.
|
| 1336 |
-
If you can answer this directly
|
| 1337 |
Otherwise respond with 'TOOLS_REQUIRED'
|
| 1338 |
|
| 1339 |
Question: {question}
|
|
@@ -1397,11 +1447,26 @@ class MainAgent:
|
|
| 1397 |
- If the answer is a list, output only the list as requested (e.g., comma-separated, one per line, etc.).
|
| 1398 |
- If the answer is: how many wheels does the car have?, output only the number, not a sentence.
|
| 1399 |
|
|
|
|
|
|
|
| 1400 |
Context:
|
| 1401 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1402 |
|
| 1403 |
Question:
|
| 1404 |
{question}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1405 |
"""
|
| 1406 |
response = self.llm.invoke(answer_prompt)
|
| 1407 |
answer = response.content if hasattr(response, 'content') else str(response)
|
|
|
|
| 17 |
ExcelTool,
|
| 18 |
WebSearchTool,
|
| 19 |
ArvixSearchTool,
|
| 20 |
+
MathTool,
|
| 21 |
PythonTool,
|
| 22 |
ContentProcessingError
|
| 23 |
)
|
|
|
|
| 46 |
"excel": ExcelTool(),
|
| 47 |
"web": WebSearchTool(),
|
| 48 |
"arvix": ArvixSearchTool(),
|
| 49 |
+
"python": PythonTool(),
|
| 50 |
+
"math": MathTool()
|
| 51 |
}
|
| 52 |
|
| 53 |
self.type_identification_prompt = PromptTemplate(
|
|
|
|
| 63 |
- excel: If the question refers to an Excel file or contains a task ID for Excel
|
| 64 |
- web: If the question is about a research paper, academic work, or requires searching the broader web (including news, journals, or scholarly articles)
|
| 65 |
- python: If the question refers to a Python file or contains a task ID for Python
|
| 66 |
+
- math: If the question involves an operation table, commutativity, associativity, or similar algebraic property
|
| 67 |
|
| 68 |
Consider these special cases:
|
| 69 |
1. If the question asks to search in Wikipedia, use "wiki"
|
|
|
|
| 146 |
if content_type == "wiki" and 'wikipedia' not in question_lower:
|
| 147 |
content_type = "web"
|
| 148 |
|
| 149 |
+
# Add detection for math/table/commutativity questions
|
| 150 |
+
if any(word in question_lower for word in ["commutative", "associative", "operation table", "table defining *", "counter-examples", "algebraic property"]):
|
| 151 |
+
content_type = "math"
|
| 152 |
+
|
| 153 |
return content_type, question, task_id
|
| 154 |
|
| 155 |
class ProcessContentAgent:
|
|
|
|
| 309 |
self.excel_tool = ExcelTool()
|
| 310 |
self.python_tool = PythonTool()
|
| 311 |
self.arvix_tool = ArvixSearchTool()
|
| 312 |
+
self.math_tool = MathTool()
|
| 313 |
|
| 314 |
# Create a dictionary of tools for easy access
|
| 315 |
self.tools = {
|
|
|
|
| 320 |
"audio": self.audio_tool,
|
| 321 |
"excel": self.excel_tool,
|
| 322 |
"python": self.python_tool,
|
| 323 |
+
"arvix": self.arvix_tool,
|
| 324 |
+
"math": self.math_tool
|
| 325 |
}
|
| 326 |
|
| 327 |
# Tool usage tracking
|
|
|
|
| 488 |
"required": ["task_id", "question"]
|
| 489 |
}
|
| 490 |
}
|
| 491 |
+
},
|
| 492 |
+
"math": {
|
| 493 |
+
"type": "function",
|
| 494 |
+
"function": {
|
| 495 |
+
"name": "analyze_math",
|
| 496 |
+
"description": "Analyze an operation table for algebraic properties (e.g., commutativity)",
|
| 497 |
+
"parameters": {
|
| 498 |
+
"type": "object",
|
| 499 |
+
"properties": {
|
| 500 |
+
"question": {
|
| 501 |
+
"type": "string",
|
| 502 |
+
"description": "The question containing the operation table and set S"
|
| 503 |
+
}
|
| 504 |
+
},
|
| 505 |
+
"required": ["question"]
|
| 506 |
+
}
|
| 507 |
+
}
|
| 508 |
+
},
|
| 509 |
}
|
| 510 |
|
| 511 |
def _identify_content_type(self, question, file_name, task_id):
|
|
|
|
| 616 |
question = args.get("question", "")
|
| 617 |
self.last_used_tool = "python"
|
| 618 |
result = self.python_tool._run(task_id, question=question)
|
| 619 |
+
elif tool_name == "analyze_math":
|
| 620 |
+
question = args.get("question", "")
|
| 621 |
+
self.last_used_tool = "math"
|
| 622 |
+
result = self.math_tool._run(question)
|
| 623 |
else:
|
| 624 |
result = f"Unknown tool: {tool_name}"
|
| 625 |
|
|
|
|
| 700 |
try:
|
| 701 |
# Reset tool tracking
|
| 702 |
self.last_used_tool = None
|
| 703 |
+
|
| 704 |
+
# Detect math/table/commutativity questions and always use MathTool
|
| 705 |
+
math_keywords = [
|
| 706 |
+
"commutative", "associative", "operation table", "table defining *", "counter-examples", "algebraic property"
|
| 707 |
+
]
|
| 708 |
+
question_lower = question.lower()
|
| 709 |
+
if any(word in question_lower for word in math_keywords):
|
| 710 |
+
self.last_used_tool = "math"
|
| 711 |
+
return self.math_tool._run(question)
|
| 712 |
+
|
| 713 |
+
# Identify content type
|
| 714 |
+
content_type, content_parameter = self._identify_content_type(question, file_name, task_id)
|
| 715 |
+
|
| 716 |
+
# If it's a math/table question, use MathTool directly (redundant, but safe)
|
| 717 |
+
if content_type == "math":
|
| 718 |
+
self.last_used_tool = "math"
|
| 719 |
+
result = self.math_tool._run(question)
|
| 720 |
+
return result
|
| 721 |
+
|
| 722 |
+
# Otherwise, try direct answer first
|
| 723 |
direct_answer = self._direct_answer_attempt(question)
|
| 724 |
if direct_answer:
|
| 725 |
self.last_used_tool = "direct"
|
| 726 |
return direct_answer
|
| 727 |
|
|
|
|
|
|
|
|
|
|
| 728 |
# For file-based questions, use the appropriate tool directly
|
| 729 |
if file_name:
|
| 730 |
if content_type in self.file_tools:
|
|
|
|
| 945 |
self.audio_tool = AudioTool()
|
| 946 |
self.excel_tool = ExcelTool()
|
| 947 |
self.python_tool = PythonTool()
|
| 948 |
+
self.math_tool = MathTool()
|
| 949 |
|
| 950 |
# Create a dictionary of tools for easy access
|
| 951 |
self.tools = {
|
|
|
|
| 957 |
"audio": self.audio_tool,
|
| 958 |
"excel": self.excel_tool,
|
| 959 |
"python": self.python_tool,
|
| 960 |
+
"math": self.math_tool
|
| 961 |
}
|
| 962 |
|
| 963 |
# Tool usage tracking
|
|
|
|
| 1348 |
task_id = args.get("task_id", "")
|
| 1349 |
question = args.get("question", "")
|
| 1350 |
result = self.python_tool._run(task_id, question=question)
|
| 1351 |
+
elif tool_name == "analyze_math":
|
| 1352 |
+
question = args.get("question", "")
|
| 1353 |
+
result = self.math_tool._run(question)
|
| 1354 |
else:
|
| 1355 |
result = f"Unknown tool: {tool_name}"
|
| 1356 |
|
|
|
|
| 1383 |
|
| 1384 |
question = state.messages[0].content
|
| 1385 |
query = f"""Analyze the question, understand the instructions, the context.
|
| 1386 |
+
If you can answer this directly without using any tools, provide the answer.
|
| 1387 |
Otherwise respond with 'TOOLS_REQUIRED'
|
| 1388 |
|
| 1389 |
Question: {question}
|
|
|
|
| 1447 |
- If the answer is a list, output only the list as requested (e.g., comma-separated, one per line, etc.).
|
| 1448 |
- If the answer is: how many wheels does the car have?, output only the number, not a sentence.
|
| 1449 |
|
| 1450 |
+
Example:
|
| 1451 |
+
What's the number of the racing car?
|
| 1452 |
Context:
|
| 1453 |
+
Racing car number: 1234567890
|
| 1454 |
+
Answer:
|
| 1455 |
+
1234567890
|
| 1456 |
+
|
| 1457 |
+
How many birds are in the picture?
|
| 1458 |
+
Context:
|
| 1459 |
+
There are 10 birds in the picture.
|
| 1460 |
+
Answer:
|
| 1461 |
+
10
|
| 1462 |
|
| 1463 |
Question:
|
| 1464 |
{question}
|
| 1465 |
+
Context:
|
| 1466 |
+
{tool_output}
|
| 1467 |
+
|
| 1468 |
+
If the question asks what was told in the context, answer it directly. You don't need the name of the person, just the answer.
|
| 1469 |
+
The answer should be exactly what was said in the context.
|
| 1470 |
"""
|
| 1471 |
response = self.llm.invoke(answer_prompt)
|
| 1472 |
answer = response.content if hasattr(response, 'content') else str(response)
|
tools.py
CHANGED
|
@@ -800,3 +800,42 @@ class WebSearchTool(BaseTool):
|
|
| 800 |
return search_result[:10000]
|
| 801 |
except Exception as e:
|
| 802 |
return f"Error searching the web: {str(e)}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 800 |
return search_result[:10000]
|
| 801 |
except Exception as e:
|
| 802 |
return f"Error searching the web: {str(e)}"
|
| 803 |
+
|
| 804 |
+
class MathTool(BaseTool):
|
| 805 |
+
"""Tool for analyzing operation tables for algebraic properties (e.g., commutativity)."""
|
| 806 |
+
name: str = "math_tool"
|
| 807 |
+
description: str = "Analyze operation tables for algebraic properties (e.g., commutativity)."
|
| 808 |
+
|
| 809 |
+
def _run(self, question: str) -> str:
|
| 810 |
+
import re
|
| 811 |
+
# Extract the set S
|
| 812 |
+
set_match = re.search(r'\{([a-zA-Z0-9_,\s]+)\}', question)
|
| 813 |
+
if not set_match:
|
| 814 |
+
return "Could not find set."
|
| 815 |
+
S = [x.strip() for x in set_match.group(1).split(',')]
|
| 816 |
+
# Extract the markdown table (find all lines that start with | and have at least 2 |'s)
|
| 817 |
+
table_lines = [line for line in question.splitlines() if line.strip().startswith('|') and line.count('|') > 2]
|
| 818 |
+
if not table_lines:
|
| 819 |
+
return "Could not find operation table."
|
| 820 |
+
# Remove separator row (contains only dashes and pipes)
|
| 821 |
+
table_lines = [line for line in table_lines if not set(line.replace('|', '').strip()) <= set('-')]
|
| 822 |
+
if not table_lines:
|
| 823 |
+
return "Could not find operation table after removing separator."
|
| 824 |
+
# Parse header
|
| 825 |
+
header = [cell.strip() for cell in table_lines[0].strip('|').split('|')][1:]
|
| 826 |
+
table = {}
|
| 827 |
+
for line in table_lines[1:]:
|
| 828 |
+
row = [cell.strip() for cell in line.strip('|').split('|')]
|
| 829 |
+
row_label = row[0]
|
| 830 |
+
table[row_label] = {col: val for col, val in zip(header, row[1:])}
|
| 831 |
+
# Check commutativity
|
| 832 |
+
involved = set()
|
| 833 |
+
for x in S:
|
| 834 |
+
for y in S:
|
| 835 |
+
if x != y:
|
| 836 |
+
xy = table[x][y]
|
| 837 |
+
yx = table[y][x]
|
| 838 |
+
if xy != yx:
|
| 839 |
+
involved.update([x, y, xy, yx])
|
| 840 |
+
involved = sorted([z for z in involved if z in S])
|
| 841 |
+
return ', '.join(involved)
|