thivy committed on
Commit
f6ff4be
·
1 Parent(s): 024e027

feat: :sparkles: create search supervisor

Browse files
Files changed (3) hide show
  1. agents.py +48 -1
  2. qa_graph.py +43 -73
  3. tools.py +22 -2
agents.py CHANGED
@@ -1,4 +1,14 @@
1
- from tools import general_tools, file_agent_tools, data_agent_tools, math_agent_tools, analyze_video_tools, youtube_transcript_tools
 
 
 
 
 
 
 
 
 
 
2
  from langgraph.prebuilt import create_react_agent
3
  from langgraph.checkpoint.memory import MemorySaver
4
  from langchain_openai import ChatOpenAI
@@ -59,6 +69,27 @@ transcript_agent = create_react_agent(
59
  prompt="You analyze audio/speech content in videos. Use tools to get transcripts."
60
  )
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  excel_prompt = """You are a supervisor. You coordinate file_reader, calculator, and data_processor to solve problems step by step.
63
  Do not do calculations or file reading yourself, use the tools.
64
  Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER].
@@ -76,6 +107,15 @@ If you are asked for a number, don't use comma to write your number neither use
76
  If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.
77
  If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
78
  """
 
 
 
 
 
 
 
 
 
79
  # Supervisor
80
  excel_supervisor = create_supervisor(
81
  [file_agent, math_agent, data_agent],
@@ -88,4 +128,11 @@ video_supervisor = create_supervisor(
88
  [video_agent, transcript_agent],
89
  model=llm,
90
  prompt=video_analyzer_prompt
 
 
 
 
 
 
 
91
  ).compile()
 
1
+ from tools import (
2
+ general_tools,
3
+ file_agent_tools,
4
+ data_agent_tools,
5
+ math_agent_tools,
6
+ analyze_video_tools,
7
+ youtube_transcript_tools,
8
+ google_search,
9
+ wiki_search,
10
+ arxiv_search
11
+ )
12
  from langgraph.prebuilt import create_react_agent
13
  from langgraph.checkpoint.memory import MemorySaver
14
  from langchain_openai import ChatOpenAI
 
69
  prompt="You analyze audio/speech content in videos. Use tools to get transcripts."
70
  )
71
 
72
+ wiki_agent = create_react_agent(
73
+ model=llm,
74
+ tools=[wiki_search],
75
+ name="wiki_analyst",
76
+ prompt="You search information from wikipedia."
77
+ )
78
+
79
+ google_agent = create_react_agent(
80
+ model=llm,
81
+ tools=[google_search],
82
+ name="google_search_analyst",
83
+ prompt="You search information from google search."
84
+ )
85
+
86
+ arxiv_agent = create_react_agent(
87
+ model=llm,
88
+ tools=[arxiv_search],
89
+ name="arxiv_analyst",
90
+ prompt="You search information from arxiv."
91
+ )
92
+
93
  excel_prompt = """You are a supervisor. You coordinate file_reader, calculator, and data_processor to solve problems step by step.
94
  Do not do calculations or file reading yourself, use the tools.
95
  Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER].
 
107
  If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.
108
  If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
109
  """
110
+
111
+ search_analyzer_prompt = """You coordinate different search agents to answer questions.
112
+ Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER].
113
+ YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
114
+ If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise.
115
+ If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.
116
+ If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
117
+ """
118
+
119
  # Supervisor
120
  excel_supervisor = create_supervisor(
121
  [file_agent, math_agent, data_agent],
 
128
  [video_agent, transcript_agent],
129
  model=llm,
130
  prompt=video_analyzer_prompt
131
+ ).compile()
132
+
133
+ # search supervisor
134
+ search_supervisor = create_supervisor(
135
+ [wiki_agent, google_agent, arxiv_agent],
136
+ model=llm,
137
+ prompt=search_analyzer_prompt
138
  ).compile()
qa_graph.py CHANGED
@@ -32,6 +32,21 @@ def get_file_type(file_path: str) -> str:
32
  else:
33
  return "unknown"
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  def ask_question(question: str, thread_id: str = "default") -> str:
36
  """Ask the agent a question."""
37
  config = {"configurable": {"thread_id": thread_id}, "recursion_limit": 100}
@@ -87,65 +102,6 @@ def ask_question_youtube(question: Question) -> str:
87
  print(result)
88
  return result["messages"][-1].content
89
 
90
- test = [
91
- {
92
- "task_id": "8e867cd7-cff9-4e6c-867a-ff5ddc2550be",
93
- "question": "How many studio albums were published by Mercedes Sosa between 2000 and 2009 (included)? You can use the latest 2022 version of english wikipedia.",
94
- "Level": "1",
95
- "file_name": ""
96
- },
97
- # {
98
- # "task_id": "cca530fc-4052-43b2-b130-b30968d8aa44",
99
- # "question": "Review the chess position provided in the image. It is black's turn. Provide the correct next move for black which guarantees a win. Please provide your response in algebraic notation.",
100
- # "Level": "1",
101
- # "file_name": "cca530fc-4052-43b2-b130-b30968d8aa44.png"
102
- # },
103
- # {
104
- # "task_id": "1f975693-876d-457b-a649-393859e79bf3",
105
- # "question": "Hi, I was out sick from my classes on Friday, so I'm trying to figure out what I need to study for my Calculus mid-term next week. My friend from class sent me an audio recording of Professor Willowbrook giving out the recommended reading for the test, but my headphones are broken :(\n\nCould you please listen to the recording for me and tell me the page numbers I'm supposed to go over? I've attached a file called Homework.mp3 that has the recording. Please provide just the page numbers as a comma-delimited list. And please provide the list in ascending order.",
106
- # "Level": "1",
107
- # "file_name": "1f975693-876d-457b-a649-393859e79bf3.mp3"
108
- # },
109
- # {
110
- # "task_id": "7bd855d8-463d-4ed5-93ca-5fe35145f733",
111
- # "question": "The attached Excel file contains the sales of menu items for a local fast-food chain. What were the total sales that the chain made from food (not including drinks)? Express your answer in USD with two decimal places.",
112
- # "Level": "1",
113
- # "file_name": "7bd855d8-463d-4ed5-93ca-5fe35145f733.xlsx"
114
- # },
115
- # {
116
- # "task_id": "f918266a-b3e0-4914-865d-4faa564f1aef",
117
- # "question": "What is the final numeric output from the attached Python code?",
118
- # "Level": "1",
119
- # "file_name": "f918266a-b3e0-4914-865d-4faa564f1aef.py"
120
- # },
121
- # {
122
- # "task_id": "cabe07ed-9eca-40ea-8ead-410ef5e83f91",
123
- # "question": "What is the surname of the equine veterinarian mentioned in 1.E Exercises from the chemistry materials licensed by Marisa Alviar-Agnew & Henry Agnew under the CK-12 license in LibreText's Introductory Chemistry materials as compiled 08/21/2023?",
124
- # "Level": "1",
125
- # "file_name": ""
126
- # },
127
- # {
128
- # "task_id": "9d191bce-651d-4746-be2d-7ef8ecadb9c2",
129
- # "question": "Examine the video at https://www.youtube.com/watch?v=1htKBjuUWec.\n\nWhat does Teal'c say in response to the question \"Isn't that hot?\"",
130
- # "Level": "1",
131
- # "file_name": ""
132
- # },
133
- # {
134
- # "task_id": "a1e91b78-d3d8-4675-bb8d-62741b4b68a6",
135
- # "question": "In the video https://www.youtube.com/watch?v=L1vXCYZAYYM, what is the highest number of bird species to be on camera simultaneously?",
136
- # "Level": "1",
137
- # "file_name": ""
138
- # },
139
- ]
140
-
141
- questions = [Question(**item) for item in test]
142
- for q in questions:
143
- print(q.question)
144
- print(q.file_name)
145
- print(q.local_file_path)
146
-
147
-
148
-
149
  # State
150
  class State(TypedDict):
151
  question: Question
@@ -159,7 +115,7 @@ def ask_question_node(state: State) -> dict:
159
  thread_id = f"test_{question_obj.task_id}"
160
 
161
  # Call your existing function
162
- answer = ask_question(question_obj.question, thread_id)
163
 
164
  # Return dict to update state
165
  return {"answer": answer}
@@ -229,17 +185,31 @@ def build_graph():
229
  return react_graph
230
 
231
  if __name__ == "__main__":
232
- for i, question in enumerate(questions):
233
- print(f"\n{i}. {question.question}")
234
-
235
- react_graph = build_graph()
236
- # Invoke the graph and capture the result
237
- result = react_graph.invoke({
238
- "question": question,
239
- "decision": "",
240
- "answer": ""
241
- })
242
- print(result)
 
 
 
 
 
243
 
244
- print(f"Answer: {result['answer']}")
245
- print("-" * 50)
 
 
 
 
 
 
 
 
 
 
32
  else:
33
  return "unknown"
34
 
35
+ def answer_qery(question: str, thread_id: str = "default") -> str:
36
+ """Ask the agent a question."""
37
+ config = {"configurable": {"thread_id": thread_id}, "recursion_limit": 50}
38
+
39
+ try:
40
+ result = video_supervisor.invoke({
41
+ "messages": [
42
+ {"role": "user", "content": question}
43
+ ]
44
+ })
45
+ print(result)
46
+ return result["messages"][-1].content
47
+ except Exception as e:
48
+ return f"Error: {str(e)}"
49
+
50
  def ask_question(question: str, thread_id: str = "default") -> str:
51
  """Ask the agent a question."""
52
  config = {"configurable": {"thread_id": thread_id}, "recursion_limit": 100}
 
102
  print(result)
103
  return result["messages"][-1].content
104
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  # State
106
  class State(TypedDict):
107
  question: Question
 
115
  thread_id = f"test_{question_obj.task_id}"
116
 
117
  # Call your existing function
118
+ answer = answer_qery(question_obj.question, thread_id)
119
 
120
  # Return dict to update state
121
  return {"answer": answer}
 
185
  return react_graph
186
 
187
  if __name__ == "__main__":
188
+ test = [
189
+ {
190
+ "task_id": "8e867cd7-cff9-4e6c-867a-ff5ddc2550be",
191
+ "question": "How many studio albums were published by Mercedes Sosa between 2000 and 2009 (included)? You can use the latest 2022 version of english wikipedia.",
192
+ "Level": "1",
193
+ "file_name": ""
194
+ },
195
+ ]
196
+
197
+ questions = [Question(**item) for item in test]
198
+ for q in questions:
199
+ print(q.question)
200
+ print(q.file_name)
201
+ print(q.local_file_path)
202
+ for i, question in enumerate(questions):
203
+ print(f"\n{i}. {question.question}")
204
 
205
+ react_graph = build_graph()
206
+ # Invoke the graph and capture the result
207
+ result = react_graph.invoke({
208
+ "question": question,
209
+ "decision": "",
210
+ "answer": ""
211
+ })
212
+ print(result)
213
+
214
+ print(f"Answer: {result['answer']}")
215
+ print("-" * 50)
tools.py CHANGED
@@ -340,10 +340,22 @@ def google_search():
340
  )
341
  google_search = GoogleSearchRun(api_wrapper=api_wrapper)
342
  return google_search
 
 
 
 
 
 
 
 
 
 
 
 
 
 
343
  def general_tools():
344
  tools = [
345
- ArxivQueryRun(api_wrapper=ArxivAPIWrapper()),
346
- google_search,
347
  analyze_image,
348
  read_python_file,
349
  transcribe_audio,
@@ -369,3 +381,11 @@ def math_agent_tools():
369
  def data_agent_tools():
370
  tools = [extract_values, filter_rows]
371
  return tools
 
 
 
 
 
 
 
 
 
340
  )
341
  google_search = GoogleSearchRun(api_wrapper=api_wrapper)
342
  return google_search
343
+
344
+ @tool
345
+ def wiki_search():
346
+ """Google search tool"""
347
+ api_wrapper = WikipediaAPIWrapper()
348
+ search = WikipediaQueryRun(api_wrapper=api_wrapper)
349
+ return search
350
+
351
+ @tool
352
+ def arxiv_search():
353
+ """Google search tool"""
354
+ api_wrapper = ArxivAPIWrapper()
355
+ search = ArxivQueryRun(api_wrapper=api_wrapper)
356
+ return search
357
  def general_tools():
358
  tools = [
 
 
359
  analyze_image,
360
  read_python_file,
361
  transcribe_audio,
 
381
  def data_agent_tools():
382
  tools = [extract_values, filter_rows]
383
  return tools
384
+
385
+ def search_agen_tools():
386
+ tools = [
387
+ google_search,
388
+ ArxivQueryRun(api_wrapper=ArxivAPIWrapper()),
389
+ WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper())
390
+ ]
391
+ return tools