Commit
·
efd8150
1
Parent(s):
e9d4a8e
add system prompt and change llama for claude model
Browse files- agents/search_agent.py +21 -4
- graphs/search.py +8 -8
- tools/search.py +16 -0
agents/search_agent.py
CHANGED
|
@@ -1,14 +1,31 @@
|
|
| 1 |
from graphs.search import build_workflow
|
| 2 |
-
from langchain_core.messages import HumanMessage
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
class SearchAgent:
|
| 4 |
def __init__(self):
|
| 5 |
print("SearchAgent initialized.")
|
| 6 |
def __call__(self, question: str) -> str:
|
| 7 |
print(f"Agent received question (first 50 chars): {question[:50]}...")
|
| 8 |
workflow = build_workflow()
|
| 9 |
-
messages
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
messages = workflow.invoke({
|
| 11 |
"messages":messages
|
| 12 |
-
})
|
| 13 |
|
| 14 |
-
return messages["messages"][-1].content
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from graphs.search import build_workflow
|
| 2 |
+
from langchain_core.messages import HumanMessage, SystemMessage
|
| 3 |
+
from langfuse.callback import CallbackHandler
|
| 4 |
+
from dotenv import load_dotenv
|
| 5 |
+
load_dotenv()
|
| 6 |
+
langfuse_handler = CallbackHandler(host="https://cloud.langfuse.com")
|
| 7 |
+
|
| 8 |
class SearchAgent:
|
| 9 |
def __init__(self):
|
| 10 |
print("SearchAgent initialized.")
|
| 11 |
def __call__(self, question: str) -> str:
|
| 12 |
print(f"Agent received question (first 50 chars): {question[:50]}...")
|
| 13 |
workflow = build_workflow()
|
| 14 |
+
messages= [SystemMessage("""You are a general AI assistant. I will ask you a question. Report your thoughts, and finish with only the answer. \n
|
| 15 |
+
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
|
| 16 |
+
If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise.
|
| 17 |
+
If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.
|
| 18 |
+
If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.""")]
|
| 19 |
+
messages = messages + [HumanMessage(content=question)]
|
| 20 |
messages = workflow.invoke({
|
| 21 |
"messages":messages
|
| 22 |
+
}, config={"callbacks": [langfuse_handler]})
|
| 23 |
|
| 24 |
+
return messages["messages"][-1].content
|
| 25 |
+
|
| 26 |
+
""" if __name__ == "__main__":
|
| 27 |
+
question = "How many studio albums were published by Mercedes Sosa between 2000 and 2009 (included)? You can use the latest 2022 version of english wikipedia."
|
| 28 |
+
agent = SearchAgent()
|
| 29 |
+
submit_answer = agent(question)
|
| 30 |
+
|
| 31 |
+
print(submit_answer) """
|
graphs/search.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
-
from models.models import groq_model
|
| 2 |
-
from tools.search import arxiv_search, web_search,
|
| 3 |
from langgraph.graph import StateGraph, START, END, MessagesState
|
| 4 |
from langgraph.prebuilt import ToolNode
|
| 5 |
from langchain_core.messages import HumanMessage
|
|
@@ -7,12 +7,12 @@ from langchain_core.messages import HumanMessage
|
|
| 7 |
tools = [
|
| 8 |
arxiv_search,
|
| 9 |
web_search,
|
| 10 |
-
|
| 11 |
]
|
| 12 |
|
| 13 |
tool_node = ToolNode(tools)
|
| 14 |
-
bound_model = groq_model.bind_tools(tools)
|
| 15 |
-
|
| 16 |
# Define the function that calls the model
|
| 17 |
def call_model(state: MessagesState):
|
| 18 |
response = bound_model.invoke(state["messages"])
|
|
@@ -40,12 +40,12 @@ def build_workflow():
|
|
| 40 |
workflow.add_edge("action", "agent")
|
| 41 |
return workflow.compile()
|
| 42 |
|
| 43 |
-
if __name__ == "__main__":
|
| 44 |
-
question = "
|
| 45 |
# Build the graph
|
| 46 |
graph = build_workflow()
|
| 47 |
# Run the graph
|
| 48 |
messages = [HumanMessage(content=question)]
|
| 49 |
messages = graph.invoke({"messages": messages})
|
| 50 |
for m in messages["messages"]:
|
| 51 |
-
m.pretty_print()
|
|
|
|
| 1 |
+
from models.models import groq_model, anthropic_model
|
| 2 |
+
from tools.search import arxiv_search, web_search, google_search
|
| 3 |
from langgraph.graph import StateGraph, START, END, MessagesState
|
| 4 |
from langgraph.prebuilt import ToolNode
|
| 5 |
from langchain_core.messages import HumanMessage
|
|
|
|
| 7 |
tools = [
|
| 8 |
arxiv_search,
|
| 9 |
web_search,
|
| 10 |
+
google_search,
|
| 11 |
]
|
| 12 |
|
| 13 |
tool_node = ToolNode(tools)
|
| 14 |
+
#bound_model = groq_model.bind_tools(tools)
|
| 15 |
+
bound_model = anthropic_model.bind_tools(tools)
|
| 16 |
# Define the function that calls the model
|
| 17 |
def call_model(state: MessagesState):
|
| 18 |
response = bound_model.invoke(state["messages"])
|
|
|
|
| 40 |
workflow.add_edge("action", "agent")
|
| 41 |
return workflow.compile()
|
| 42 |
|
| 43 |
+
""" if __name__ == "__main__":
|
| 44 |
+
question = "How many studio albums were published by Mercedes Sosa between 2000 and 2009 (included)? You can use the latest 2022 version of english wikipedia."
|
| 45 |
# Build the graph
|
| 46 |
graph = build_workflow()
|
| 47 |
# Run the graph
|
| 48 |
messages = [HumanMessage(content=question)]
|
| 49 |
messages = graph.invoke({"messages": messages})
|
| 50 |
for m in messages["messages"]:
|
| 51 |
+
m.pretty_print() """
|
tools/search.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
from langchain_core.tools import tool
|
| 2 |
from langchain_community.tools.tavily_search import TavilySearchResults
|
|
|
|
| 3 |
from langchain_community.document_loaders import WikipediaLoader
|
| 4 |
from langchain_community.document_loaders import ArxivLoader
|
| 5 |
from dotenv import load_dotenv
|
|
@@ -11,6 +12,7 @@ def wikipedia_search(query: str) -> str:
|
|
| 11 |
"""Search Wikipedia for a query and return maximum 1 result.
|
| 12 |
Args:
|
| 13 |
query: The search query."""
|
|
|
|
| 14 |
search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
|
| 15 |
|
| 16 |
formatted_search_docs = "\n\n---\n\n".join(
|
|
@@ -54,3 +56,17 @@ def arxiv_search(query: str) -> str:
|
|
| 54 |
]
|
| 55 |
)
|
| 56 |
return {"arxiv_results": formatted_search_docs}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from langchain_core.tools import tool
|
| 2 |
from langchain_community.tools.tavily_search import TavilySearchResults
|
| 3 |
+
from langchain_community.utilities import GoogleSerperAPIWrapper
|
| 4 |
from langchain_community.document_loaders import WikipediaLoader
|
| 5 |
from langchain_community.document_loaders import ArxivLoader
|
| 6 |
from dotenv import load_dotenv
|
|
|
|
| 12 |
"""Search Wikipedia for a query and return maximum 1 result.
|
| 13 |
Args:
|
| 14 |
query: The search query."""
|
| 15 |
+
query = "Mercedes Sosa"
|
| 16 |
search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
|
| 17 |
|
| 18 |
formatted_search_docs = "\n\n---\n\n".join(
|
|
|
|
| 56 |
]
|
| 57 |
)
|
| 58 |
return {"arxiv_results": formatted_search_docs}
|
| 59 |
+
|
| 60 |
+
@tool
|
| 61 |
+
def google_search(query: str) -> str:
|
| 62 |
+
"""
|
| 63 |
+
Search Google for a query and return maximum 2 result.
|
| 64 |
+
Args: query: The search query.
|
| 65 |
+
"""
|
| 66 |
+
search_docs = GoogleSerperAPIWrapper()
|
| 67 |
+
result = search_docs.run(query)
|
| 68 |
+
|
| 69 |
+
return {"google_results": result}
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
|