added wiki search
Browse files- agent.py +21 -5
- requirements.txt +4 -1
agent.py
CHANGED
|
@@ -1,14 +1,13 @@
|
|
| 1 |
'''
|
| 2 |
TODO tools
|
| 3 |
-
- web_search
|
| 4 |
|
| 5 |
Facultative
|
| 6 |
- wiki_search
|
| 7 |
- arxiv_search
|
| 8 |
'''
|
| 9 |
import os
|
| 10 |
-
from smolagents import CodeAgent, tool, DuckDuckGoSearchTool, OpenAIServerModel
|
| 11 |
-
|
| 12 |
|
| 13 |
@tool
|
| 14 |
def add(a:int, b:int) -> int:
|
|
@@ -88,13 +87,30 @@ def rounder(a:float, n:int) -> float:
|
|
| 88 |
|
| 89 |
return round(a,n)
|
| 90 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
|
| 92 |
def get_agent() -> CodeAgent:
|
| 93 |
search_tool = DuckDuckGoSearchTool()
|
|
|
|
|
|
|
|
|
|
| 94 |
|
| 95 |
model = OpenAIServerModel(
|
| 96 |
model_id='codestral-latest',
|
| 97 |
api_base="https://codestral.mistral.ai/v1/",
|
| 98 |
-
api_key=
|
| 99 |
|
| 100 |
-
return CodeAgent(tools=[add, subtract, multiply, divide, modulus, rounder, search_tool], model=model)
|
|
|
|
| 1 |
'''
|
| 2 |
TODO tools
|
|
|
|
| 3 |
|
| 4 |
Facultative
|
| 5 |
- wiki_search
|
| 6 |
- arxiv_search
|
| 7 |
'''
|
| 8 |
import os
|
| 9 |
+
from smolagents import CodeAgent, tool, DuckDuckGoSearchTool, OpenAIServerModel, VisitWebpageTool, Tool
|
| 10 |
+
from langchain_community import WikipediaLoader
|
| 11 |
|
| 12 |
@tool
|
| 13 |
def add(a:int, b:int) -> int:
|
|
|
|
| 87 |
|
| 88 |
return round(a,n)
|
| 89 |
|
| 90 |
+
@tool
|
| 91 |
+
def wiki_search(query: str) -> str:
|
| 92 |
+
"""Search Wikipedia for a query and return maximum 2 results.
|
| 93 |
+
|
| 94 |
+
Args:
|
| 95 |
+
query: The search query."""
|
| 96 |
+
search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
|
| 97 |
+
formatted_search_docs = "\n\n---\n\n".join(
|
| 98 |
+
[
|
| 99 |
+
f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
|
| 100 |
+
for doc in search_docs
|
| 101 |
+
])
|
| 102 |
+
return {"wiki_results": formatted_search_docs}
|
| 103 |
+
|
| 104 |
|
| 105 |
def get_agent() -> CodeAgent:
|
| 106 |
search_tool = DuckDuckGoSearchTool()
|
| 107 |
+
web_page_tool = VisitWebpageTool()
|
| 108 |
+
|
| 109 |
+
api_key = os.getenv('CODESTRAL_API_KEY')
|
| 110 |
|
| 111 |
model = OpenAIServerModel(
|
| 112 |
model_id='codestral-latest',
|
| 113 |
api_base="https://codestral.mistral.ai/v1/",
|
| 114 |
+
api_key=api_key)
|
| 115 |
|
| 116 |
+
return CodeAgent(tools=[add, subtract, multiply, divide, modulus, rounder, search_tool, web_page_tool, wiki_search], model=model)
|
requirements.txt
CHANGED
|
@@ -1,3 +1,6 @@
|
|
| 1 |
gradio
|
| 2 |
requests
|
| 3 |
-
smolagents
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
gradio
|
| 2 |
requests
|
| 3 |
+
smolagents
|
| 4 |
+
smolagents[openai]
|
| 5 |
+
langchain_community
|
| 6 |
+
wikipedia
|