Luigi D'Addona commited on
Commit
5a54222
·
1 Parent(s): e47bd31

aggiunt tool arxiv_search

Browse files
Files changed (2) hide show
  1. agent.py +2 -2
  2. tools.py +14 -0
agent.py CHANGED
@@ -14,7 +14,7 @@ from langchain_google_genai import ChatGoogleGenerativeAI
14
 
15
  # Local imports
16
  from tools import get_search_tool, get_tavily_search_tool, get_wikipedia_tool, wikipedia_search, wikipedia_search_3,\
17
- execute_python_code_from_file, download_taskid_file, analyze_excel_file, get_analyze_mp3_tool
18
 
19
  # Nota: per i test in locale si usa il .env
20
  # su HuggingFace invece si usano le variabili definite in Settings/"Variables and secrets"
@@ -60,7 +60,7 @@ search_tool = get_tavily_search_tool()
60
  #wikipedia_tool = get_wikipedia_tool()
61
  analyze_mp3_tool = get_analyze_mp3_tool(chat)
62
 
63
- tools = [search_tool, wikipedia_search_3, execute_python_code_from_file, download_taskid_file, analyze_excel_file, analyze_mp3_tool]
64
 
65
  # Bind tools to the model
66
  chat_with_tools = chat.bind_tools(tools)
 
14
 
15
  # Local imports
16
  from tools import get_search_tool, get_tavily_search_tool, get_wikipedia_tool, wikipedia_search, wikipedia_search_3,\
17
+ execute_python_code_from_file, download_taskid_file, analyze_excel_file, get_analyze_mp3_tool, arxiv_search
18
 
19
  # Nota: per i test in locale si usa il .env
20
  # su HuggingFace invece si usano le variabili definite in Settings/"Variables and secrets"
 
60
  #wikipedia_tool = get_wikipedia_tool()
61
  analyze_mp3_tool = get_analyze_mp3_tool(chat)
62
 
63
+ tools = [search_tool, wikipedia_search_3, execute_python_code_from_file, download_taskid_file, analyze_excel_file, analyze_mp3_tool, arxiv_search]
64
 
65
  # Bind tools to the model
66
  chat_with_tools = chat.bind_tools(tools)
tools.py CHANGED
@@ -10,6 +10,7 @@ from langchain_community.tools import WikipediaQueryRun
10
  from langchain_community.document_loaders import WikipediaLoader
11
  import wikipedia
12
  from langchain_tavily import TavilySearch
 
13
  from langchain_core.tools import tool
14
  from langchain.tools import Tool
15
  from langchain_core.messages import HumanMessage
@@ -255,3 +256,16 @@ def get_analyze_mp3_tool(llm):
255
  return analyze_mp3_file
256
 
257
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  from langchain_community.document_loaders import WikipediaLoader
11
  import wikipedia
12
  from langchain_tavily import TavilySearch
13
+ from langchain_community.document_loaders import ArxivLoader
14
  from langchain_core.tools import tool
15
  from langchain.tools import Tool
16
  from langchain_core.messages import HumanMessage
 
256
  return analyze_mp3_file
257
 
258
 
259
+ @tool
260
+ def arxiv_search(query: str) -> str:
261
+ """Search Arxiv for a query and return maximum 3 result.
262
+ Args:
263
+ query: The search query."""
264
+ search_docs = ArxivLoader(query=query, load_max_docs=3).load()
265
+ formatted_search_docs = "\n\n---\n\n".join(
266
+ [
267
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
268
+ for doc in search_docs
269
+ ]
270
+ )
271
+ return {"arxiv_results": formatted_search_docs}