sebastianfrench commited on
Commit
efd8150
·
1 Parent(s): e9d4a8e

add system prompt and change llama for claude model

Browse files
Files changed (3) hide show
  1. agents/search_agent.py +21 -4
  2. graphs/search.py +8 -8
  3. tools/search.py +16 -0
agents/search_agent.py CHANGED
@@ -1,14 +1,31 @@
1
  from graphs.search import build_workflow
2
- from langchain_core.messages import HumanMessage
 
 
 
 
 
3
  class SearchAgent:
4
  def __init__(self):
5
  print("SearchAgent initialized.")
6
  def __call__(self, question: str) -> str:
7
  print(f"Agent received question (first 50 chars): {question[:50]}...")
8
  workflow = build_workflow()
9
- messages = [HumanMessage(question)]
 
 
 
 
 
10
  messages = workflow.invoke({
11
  "messages":messages
12
- })
13
 
14
- return messages["messages"][-1].content
 
 
 
 
 
 
 
 
1
  from graphs.search import build_workflow
2
+ from langchain_core.messages import HumanMessage, SystemMessage
3
+ from langfuse.callback import CallbackHandler
4
+ from dotenv import load_dotenv
5
+ load_dotenv()
6
+ langfuse_handler = CallbackHandler(host="https://cloud.langfuse.com")
7
+
8
  class SearchAgent:
9
  def __init__(self):
10
  print("SearchAgent initialized.")
11
  def __call__(self, question: str) -> str:
12
  print(f"Agent received question (first 50 chars): {question[:50]}...")
13
  workflow = build_workflow()
14
+ messages= [SystemMessage("""You are a general AI assistant. I will ask you a question. Report your thoughts, and finish with only the answer. \n
15
+ YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
16
+ If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise.
17
+ If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.
18
+ If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.""")]
19
+ messages = messages + [HumanMessage(content=question)]
20
  messages = workflow.invoke({
21
  "messages":messages
22
+ }, config={"callbacks": [langfuse_handler]})
23
 
24
+ return messages["messages"][-1].content
25
+
26
+ """ if __name__ == "__main__":
27
+ question = "How many studio albums were published by Mercedes Sosa between 2000 and 2009 (included)? You can use the latest 2022 version of english wikipedia."
28
+ agent = SearchAgent()
29
+ submit_answer = agent(question)
30
+
31
+ print(submit_answer) """
graphs/search.py CHANGED
@@ -1,5 +1,5 @@
1
- from models.models import groq_model
2
- from tools.search import arxiv_search, web_search, wikipedia_search
3
  from langgraph.graph import StateGraph, START, END, MessagesState
4
  from langgraph.prebuilt import ToolNode
5
  from langchain_core.messages import HumanMessage
@@ -7,12 +7,12 @@ from langchain_core.messages import HumanMessage
7
  tools = [
8
  arxiv_search,
9
  web_search,
10
- wikipedia_search
11
  ]
12
 
13
  tool_node = ToolNode(tools)
14
- bound_model = groq_model.bind_tools(tools)
15
-
16
  # Define the function that calls the model
17
  def call_model(state: MessagesState):
18
  response = bound_model.invoke(state["messages"])
@@ -40,12 +40,12 @@ def build_workflow():
40
  workflow.add_edge("action", "agent")
41
  return workflow.compile()
42
 
43
- if __name__ == "__main__":
44
- question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
45
  # Build the graph
46
  graph = build_workflow()
47
  # Run the graph
48
  messages = [HumanMessage(content=question)]
49
  messages = graph.invoke({"messages": messages})
50
  for m in messages["messages"]:
51
- m.pretty_print()
 
1
+ from models.models import groq_model, anthropic_model
2
+ from tools.search import arxiv_search, web_search, google_search
3
  from langgraph.graph import StateGraph, START, END, MessagesState
4
  from langgraph.prebuilt import ToolNode
5
  from langchain_core.messages import HumanMessage
 
7
  tools = [
8
  arxiv_search,
9
  web_search,
10
+ google_search,
11
  ]
12
 
13
  tool_node = ToolNode(tools)
14
+ #bound_model = groq_model.bind_tools(tools)
15
+ bound_model = anthropic_model.bind_tools(tools)
16
  # Define the function that calls the model
17
  def call_model(state: MessagesState):
18
  response = bound_model.invoke(state["messages"])
 
40
  workflow.add_edge("action", "agent")
41
  return workflow.compile()
42
 
43
+ """ if __name__ == "__main__":
44
+ question = "How many studio albums were published by Mercedes Sosa between 2000 and 2009 (included)? You can use the latest 2022 version of english wikipedia."
45
  # Build the graph
46
  graph = build_workflow()
47
  # Run the graph
48
  messages = [HumanMessage(content=question)]
49
  messages = graph.invoke({"messages": messages})
50
  for m in messages["messages"]:
51
+ m.pretty_print() """
tools/search.py CHANGED
@@ -1,5 +1,6 @@
1
  from langchain_core.tools import tool
2
  from langchain_community.tools.tavily_search import TavilySearchResults
 
3
  from langchain_community.document_loaders import WikipediaLoader
4
  from langchain_community.document_loaders import ArxivLoader
5
  from dotenv import load_dotenv
@@ -11,6 +12,7 @@ def wikipedia_search(query: str) -> str:
11
  """Search Wikipedia for a query and return maximum 1 result.
12
  Args:
13
  query: The search query."""
 
14
  search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
15
 
16
  formatted_search_docs = "\n\n---\n\n".join(
@@ -54,3 +56,17 @@ def arxiv_search(query: str) -> str:
54
  ]
55
  )
56
  return {"arxiv_results": formatted_search_docs}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from langchain_core.tools import tool
2
  from langchain_community.tools.tavily_search import TavilySearchResults
3
+ from langchain_community.utilities import GoogleSerperAPIWrapper
4
  from langchain_community.document_loaders import WikipediaLoader
5
  from langchain_community.document_loaders import ArxivLoader
6
  from dotenv import load_dotenv
 
12
  """Search Wikipedia for a query and return maximum 1 result.
13
  Args:
14
  query: The search query."""
15
+ query = "Mercedes Sosa"
16
  search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
17
 
18
  formatted_search_docs = "\n\n---\n\n".join(
 
56
  ]
57
  )
58
  return {"arxiv_results": formatted_search_docs}
59
+
60
+ @tool
61
+ def google_search(query: str) -> str:
62
+ """
63
+ Search Google for a query and return maximum 2 result.
64
+ Args: query: The search query.
65
+ """
66
+ search_docs = GoogleSerperAPIWrapper()
67
+ result = search_docs.run(query)
68
+
69
+ return {"google_results": result}
70
+
71
+
72
+