Nguyen Nhu Trung committed on
Commit
e6f45d4
·
1 Parent(s): 69b065f

add tavily search

Browse files
Files changed (2) hide show
  1. agent.py +47 -9
  2. requirements.txt +2 -1
agent.py CHANGED
@@ -12,6 +12,7 @@ from langchain_core.messages import SystemMessage, HumanMessage
12
  from langchain_core.tools import tool
13
  from langchain_core.output_parsers import StrOutputParser
14
  from langchain_core.tools import Tool
 
15
  from langchain_experimental.utilities import PythonREPL
16
  import assemblyai as aai
17
 
@@ -77,14 +78,29 @@ def wiki_search(query: str) -> str:
77
  ])
78
  return {"wiki_results": formatted_search_docs}
79
 
 
 
 
 
 
 
 
 
80
  @tool
81
  def web_search(query: str) -> str:
82
- """Search DuckDuckGo for a query and return maximum 5 results.
83
 
84
  Args:
85
  query: The search query."""
86
- search_docs = DuckDuckGoSearchResults(max_results=5).invoke(query)
87
- return {"web_results": search_docs}
 
 
 
 
 
 
 
88
 
89
 
90
  system_prompt = "You are a helpful assistant"
@@ -98,13 +114,32 @@ tools = [
98
  transcribe_audio
99
  ]
100
 
101
- llm = ChatGroq(model="qwen-qwq-32b", temperature=0)
102
  llm_with_tools = llm.bind_tools(tools)
103
 
104
  def assistant(state: MessagesState):
105
  """Assistant node"""
106
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
107
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  builder = StateGraph(MessagesState)
109
  builder.add_node("assistant", assistant)
110
  builder.add_node("tools", ToolNode(tools))
@@ -120,11 +155,14 @@ graph = builder.compile()
120
  def get_answer(query):
121
  messages = [HumanMessage(content=query)]
122
  results = graph.invoke({"messages": messages})
123
- return results["messages"][-1].content
124
 
125
  if __name__ == "__main__":
126
  question = "In the video https://www.youtube.com/watch?v=L1vXCYZAYYM, what is the highest number of bird species to be on camera simultaneously?"
127
- messages = [HumanMessage(content=question)]
128
- messages = graph.invoke({"messages": messages})
129
- for m in messages["messages"]:
130
- m.pretty_print()
 
 
 
 
12
  from langchain_core.tools import tool
13
  from langchain_core.output_parsers import StrOutputParser
14
  from langchain_core.tools import Tool
15
+ from langchain_community.tools.tavily_search import TavilySearchResults
16
  from langchain_experimental.utilities import PythonREPL
17
  import assemblyai as aai
18
 
 
78
  ])
79
  return {"wiki_results": formatted_search_docs}
80
 
81
+ # @tool
82
+ # def web_search(query: str) -> str:
83
+ # """Search DuckDuckGo for a query and return maximum 5 results.
84
+
85
+ # Args:
86
+ # query: The search query."""
87
+ # search_docs = DuckDuckGoSearchResults(max_results=5).invoke(query)
88
+ # return {"web_results": search_docs}
89
  @tool
90
  def web_search(query: str) -> str:
91
+ """Search Tavily for a query and return maximum 3 results.
92
 
93
  Args:
94
  query: The search query."""
95
+ search_docs = TavilySearchResults(max_results=3).invoke(query=query)
96
+ formatted_search_docs = "\n\n---\n\n".join(
97
+ [
98
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
99
+ for doc in search_docs
100
+ ])
101
+ return {"web_results": formatted_search_docs}
102
+
103
+
104
 
105
 
106
  system_prompt = "You are a helpful assistant"
 
114
  transcribe_audio
115
  ]
116
 
117
+ llm = ChatGroq(model="qwen-qwq-32b", temperature=0.7)
118
  llm_with_tools = llm.bind_tools(tools)
119
 
120
  def assistant(state: MessagesState):
121
  """Assistant node"""
122
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
123
 
124
def final_answer(question, answer):
    """Ask the LLM to reduce a verbose answer to just the final answer.

    Args:
        question: The original user question.
        answer: The (possibly long-winded) answer produced by the agent.

    Returns:
        The model's reply with any echoed ``**OUTPUT:**`` / ``**ANSWER:**``
        markers removed and surrounding whitespace stripped.
    """
    sys_prompt = """
    You are an assistant. You help to extract the final answer based given an answer
    of a question.

    Example:
    **QUESTION:** what is 1+1?
    **ANSWER:** the result is 2
    **OUTPUT:** 2
    """
    # The separators must be real newlines: "\Q", "\A" and "\O" are invalid
    # escape sequences that leave literal backslashes in the prompt (and
    # trigger a SyntaxWarning on recent Python versions).
    conversation = (
        f"Extract final answer from\nQUESTION:{question}\nANSWER:{answer}\nOUTPUT:"
    )
    print(f"conversation: {conversation}")
    messages = [
        {"role": "system", "content": sys_prompt},
        {"role": "user", "content": conversation},
    ]
    response = llm.invoke(messages)
    # The model sometimes echoes the prompt's markers; drop them, then strip
    # the whitespace they leave behind so callers get a clean value.
    extracted = response.content.replace("**OUTPUT:**", "").replace("**ANSWER:**", "")
    return extracted.strip()
142
+
143
  builder = StateGraph(MessagesState)
144
  builder.add_node("assistant", assistant)
145
  builder.add_node("tools", ToolNode(tools))
 
155
  def get_answer(query):
156
  messages = [HumanMessage(content=query)]
157
  results = graph.invoke({"messages": messages})
158
+ return final_answer(query, results["messages"][-1].content)
159
 
160
  if __name__ == "__main__":
161
  question = "In the video https://www.youtube.com/watch?v=L1vXCYZAYYM, what is the highest number of bird species to be on camera simultaneously?"
162
+ question = "Hi, I was out sick from my classes on Friday, so I'm trying to figure out what I need to study for my Calculus mid-term next week. My friend from class sent me an audio recording of Professor Willowbrook giving out the recommended reading for the test, but my headphones are broken :(\n\nCould you please listen to the recording for me and tell me the page numbers I'm supposed to go over? I've attached a file called Homework.mp3 that has the recording. Please provide just the page numbers as a comma-delimited list. And please provide the list in ascending order."
163
+ question = "What is the first name of the only Malko Competition recipient from the 20th Century (after 1977) whose nationality on record is a country that no longer exists?"
164
+ # getmessages = [HumanMessage(content=question)]
165
+ # messages = graph.invoke({"messages": messages})
166
+ # for m in messages["messages"]:
167
+ # m.pretty_print()
168
+ print(f"FINAL ANSWER: {get_answer(question)}")
requirements.txt CHANGED
@@ -7,4 +7,5 @@ langchain-community
7
  wikipedia
8
  duckduckgo-search
9
  langchain-experimental
10
- assemblyai
 
 
7
  wikipedia
8
  duckduckgo-search
9
  langchain-experimental
10
+ assemblyai
11
+ langchain-tavily