Golfn committed on
Commit
a8bb3a0
·
1 Parent(s): cc36db8

Add mathematical tools and import vector_search into Alfred_Agent.py

Browse files
Files changed (2) hide show
  1. Alfred_Agent.py +56 -33
  2. other_tools.py +51 -19
Alfred_Agent.py CHANGED
@@ -1,42 +1,41 @@
1
  import datasets
2
  from langchain.docstore.document import Document
3
-
4
  # Load the dataset
5
- guest_dataset = datasets.load_dataset("agents-course/unit3-invitees", split="train")
6
 
7
  # Convert dataset entries into Document objects
8
- docs = [
9
- Document(
10
- page_content="\n".join([
11
- f"Name: {guest['name']}",
12
- f"Relation: {guest['relation']}",
13
- f"Description: {guest['description']}",
14
- f"Email: {guest['email']}"
15
- ]),
16
- metadata={"name": guest["name"]}
17
- )
18
- for guest in guest_dataset
19
- ]
20
-
21
- from langchain_community.retrievers import BM25Retriever
22
- from langchain.tools import Tool
23
-
24
- bm25_retriever = BM25Retriever.from_documents(docs)
25
-
26
- def extract_text(query: str) -> str:
27
- """Retrieves detailed information about gala guests based on their name or relation."""
28
- results = bm25_retriever.invoke(query)
29
- if results:
30
- return "\n\n".join([doc.page_content for doc in results[:3]])
31
- else:
32
- return "No matching guest information found."
33
 
34
- guest_info_tool = Tool(
35
- name="guest_info_retriever",
36
- func=extract_text,
37
- description="Retrieves detailed information about gala guests based on their name or relation."
38
- )
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  from typing import TypedDict, Annotated
41
  from langgraph.graph.message import add_messages
42
  from langchain_core.messages import AnyMessage, HumanMessage, AIMessage
@@ -45,7 +44,7 @@ from langgraph.graph import START, StateGraph
45
  from langgraph.prebuilt import tools_condition
46
  from langchain_openai import ChatOpenAI
47
  from Webserch_tool import weather_info_tool
48
- from other_tools import wiki_search, arvix_search, web_search
49
  import os
50
  from dotenv import load_dotenv
51
  load_dotenv()
@@ -56,7 +55,17 @@ llm = ChatOpenAI(temperature=0
56
  tools = [weather_info_tool,wiki_search,arvix_search,web_search]
57
  chat_with_tools = llm.bind_tools(tools)
58
 
 
 
 
 
 
 
 
 
 
59
  # Generate the AgentState and Agent graph
 
60
  class AgentState(TypedDict):
61
  messages: Annotated[list[AnyMessage], add_messages]
62
 
@@ -65,9 +74,23 @@ def assistant(state: AgentState):
65
  "messages": [chat_with_tools.invoke(state["messages"])],
66
  }
67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  ## The graph
69
  builder = StateGraph(AgentState)
70
 
 
71
  # Define nodes: these do the work
72
  builder.add_node("assistant", assistant)
73
  builder.add_node("tools", ToolNode(tools))
 
1
  import datasets
2
  from langchain.docstore.document import Document
 
3
  # Load the dataset
4
+ # guest_dataset = datasets.load_dataset("agents-course/unit3-invitees", split="train")
5
 
6
  # Convert dataset entries into Document objects
7
+ # docs = [
8
+ # Document(
9
+ # page_content="\n".join([
10
+ # f"Name: {guest['name']}",
11
+ # f"Relation: {guest['relation']}",
12
+ # f"Description: {guest['description']}",
13
+ # f"Email: {guest['email']}"
14
+ # ]),
15
+ # metadata={"name": guest["name"]}
16
+ # )
17
+ # for guest in guest_dataset
18
+ # ]
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
+ # from langchain_community.retrievers import BM25Retriever
21
+ # from langchain.tools import Tool
 
 
 
22
 
23
+ # bm25_retriever = BM25Retriever.from_documents(docs)
24
+
25
+ # def extract_text(query: str) -> str:
26
+ # """Retrieves detailed information about gala guests based on their name or relation."""
27
+ # results = bm25_retriever.invoke(query)
28
+ # if results:
29
+ # return "\n\n".join([doc.page_content for doc in results[:3]])
30
+ # else:
31
+ # return "No matching guest information found."
32
+
33
+ # guest_info_tool = Tool(
34
+ # name="guest_info_retriever",
35
+ # func=extract_text,
36
+ # description="Retrieves detailed information about gala guests based on their name or relation."
37
+ # )
38
+ #######################################################################################################################################################
39
  from typing import TypedDict, Annotated
40
  from langgraph.graph.message import add_messages
41
  from langchain_core.messages import AnyMessage, HumanMessage, AIMessage
 
44
  from langgraph.prebuilt import tools_condition
45
  from langchain_openai import ChatOpenAI
46
  from Webserch_tool import weather_info_tool
47
+ from other_tools import wiki_search, arvix_search, web_search, vector_search
48
  import os
49
  from dotenv import load_dotenv
50
  load_dotenv()
 
55
  tools = [weather_info_tool,wiki_search,arvix_search,web_search]
56
  chat_with_tools = llm.bind_tools(tools)
57
 
58
# Instruction prompt that the retriever node prepends to the conversation.
# NOTE(review): this is built as an AIMessage; a SystemMessage is the more
# conventional carrier for system instructions — confirm intent.
_SYSTEM_PROMPT = """You are a helpful assistant tasked with answering questions using a set of tools.
Now, I will ask you a question. Report your thoughts, and finish your answer with the following template:
FINAL ANSWER: [YOUR FINAL ANSWER].
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, Apply the rules above for each element (number or string), ensure there is exactly one space after each comma.
Your answer should only start with "FINAL ANSWER: ", then follows with the answer. """

ai_message = AIMessage(content=_SYSTEM_PROMPT)
64
+
65
+
66
+
67
  # Generate the AgentState and Agent graph
68
+ from langgraph.graph import MessagesState #the same as AgentState
69
class AgentState(TypedDict):
    """State carried between nodes of the LangGraph agent graph."""
    # add_messages is a reducer: new messages are appended to the existing
    # history instead of replacing it.
    messages: Annotated[list[AnyMessage], add_messages]
71
 
 
74
  "messages": [chat_with_tools.invoke(state["messages"])],
75
  }
76
 
77
def retriever(state: AgentState):
    """Retriever node.

    Prepends the instruction prompt (``ai_message``) to the conversation and,
    when the vector store yields a similar solved question, appends it as a
    reference example for the assistant.

    Args:
        state: Current agent state; the first message is taken as the user's
            question.

    Returns:
        dict with a ``"messages"`` list: prompt + history (+ optional example).
    """
    # vector_search (other_tools.py) returns {"vector_results": <formatted str>},
    # not a list — the original `similar_question[0]` raised KeyError on the dict.
    results = vector_search(state["messages"][0].content)
    similar_question = results.get("vector_results") if isinstance(results, dict) else results

    if similar_question:  # empty/None -> nothing similar was found
        example_msg = HumanMessage(
            content=f"Here I provide a similar question and answer for reference: \n\n{similar_question}",
        )
        return {"messages": [ai_message] + state["messages"] + [example_msg]}
    # No similar question found: still inject the instruction prompt.
    return {"messages": [ai_message] + state["messages"]}
89
+
90
  ## The graph
91
  builder = StateGraph(AgentState)
92
 
93
+
94
  # Define nodes: these do the work
95
  builder.add_node("assistant", assistant)
96
  builder.add_node("tools", ToolNode(tools))
other_tools.py CHANGED
@@ -1,45 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
 
2
  from langchain_community.document_loaders import WikipediaLoader
 
3
  def wiki_search(query: str) -> str:
4
- """Search Wikipedia for a query and return maximum 2 results.
5
-
6
- Args:
7
- query: The search query."""
8
  search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
9
  formatted_search_docs = "\n\n---\n\n".join(
10
  [
11
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
12
  for doc in search_docs
13
  ])
14
  return {"wiki_results": formatted_search_docs}
15
 
16
-
17
-
18
  from langchain_community.document_loaders import ArxivLoader
 
19
  def arvix_search(query: str) -> str:
20
- """Search Arxiv for a query and return maximum 3 result.
21
-
22
- Args:
23
- query: The search query."""
24
  search_docs = ArxivLoader(query=query, load_max_docs=3).load()
25
  formatted_search_docs = "\n\n---\n\n".join(
26
  [
27
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
28
  for doc in search_docs
29
  ])
30
  return {"arvix_results": formatted_search_docs}
31
 
32
-
33
  from langchain_community.tools.tavily_search import TavilySearchResults
 
34
  def web_search(query: str) -> str:
35
- """Search Tavily for a query and return maximum 3 results.
36
-
37
- Args:
38
- query: The search query."""
39
  search_docs = TavilySearchResults(max_results=3).invoke(query=query)
40
  formatted_search_docs = "\n\n---\n\n".join(
41
  [
42
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
 
 
 
 
 
 
 
 
 
 
43
  for doc in search_docs
44
  ])
45
- return {"web_results": formatted_search_docs}
 
1
+ from langchain.tools import tool
2
+ #mathematical operations
3
+ import cmath
4
@tool(description="Multiplies two numbers.")
def multiply(a: float, b: float) -> float:
    """Return the product of *a* and *b*."""
    product = a * b
    return product
7
+
8
@tool(description="Adds two numbers.")
def add(a: float, b: float) -> float:
    """Return the sum of *a* and *b*."""
    total = a + b
    return total
11
+
12
@tool(description="Subtracts two numbers.")
def subtract(a: float, b: float) -> float:
    """Return *a* minus *b*.

    The return annotation was ``int`` in the original, but subtracting two
    floats yields a float — corrected to match the actual result type.
    """
    return a - b
15
+
16
@tool(description="Divides two numbers.")
def divide(a: float, b: float) -> float:
    """Return *a* divided by *b*.

    Raises:
        ValueError: if *b* is zero.
    """
    if b == 0:
        # Fixed grammar of the error message ("Cannot divided" -> "Cannot divide").
        raise ValueError("Cannot divide by zero.")
    return a / b
21
+
22
@tool(description="Get the modulus of two numbers.")
def modulus(a: int, b: int) -> int:
    """Return the remainder of *a* divided by *b*."""
    remainder = a % b
    return remainder
25
+
26
@tool(description="Get the power of two numbers.")
def power(a: float, b: float) -> float:
    """Return *a* raised to the power *b*."""
    result = a ** b
    return result
29
+
30
@tool(description="Get the square root of a number.")
def square_root(a: float) -> float | complex:
    """Return the square root of *a*; a complex result for negative input."""
    return a ** 0.5 if a >= 0 else cmath.sqrt(a)
35
 
36
  from langchain_community.document_loaders import WikipediaLoader
37
@tool(description="Search Wikipedia for a query and return maximum 2 results.")
def wiki_search(query: str) -> dict:
    """Search Wikipedia and return up to 2 results.

    Args:
        query: The search query.

    Returns:
        dict with key ``"wiki_results"`` mapping to the formatted documents.
        (Return annotation corrected: the function returns a dict, not a str.)
    """
    search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
    # Fixed malformed markup: the tag was self-closed ("/>") and then also
    # closed with </Document>; now a proper open/close pair.
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}">{doc.page_content}</Document>'
        for doc in search_docs
    )
    return {"wiki_results": formatted_search_docs}
46
 
 
 
47
  from langchain_community.document_loaders import ArxivLoader
48
@tool(description="Search Arxiv for a query and return maximum 3 results.")
def arvix_search(query: str) -> dict:
    """Search arXiv and return up to 3 results (first 1000 chars each).

    Args:
        query: The search query.

    Returns:
        dict with key ``"arvix_results"`` mapping to the formatted documents.
        (Return annotation corrected: the function returns a dict, not a str.)
    """
    search_docs = ArxivLoader(query=query, load_max_docs=3).load()
    # Fixed malformed markup: the tag was self-closed ("/>") and then also
    # closed with </Document>; now a proper open/close pair.
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}">{doc.page_content[:1000]}</Document>'
        for doc in search_docs
    )
    return {"arvix_results": formatted_search_docs}
57
 
 
58
  from langchain_community.tools.tavily_search import TavilySearchResults
59
@tool(description="Search Tavily for a query and return maximum 3 results.")
def web_search(query: str) -> dict:
    """Search the web via Tavily and return up to 3 results.

    Args:
        query: The search query.

    Returns:
        dict with key ``"web_results"`` mapping to the formatted results.
        (Return annotation corrected: the function returns a dict, not a str.)
    """
    # BaseTool.invoke takes its input positionally; invoke(query=query) raises
    # a TypeError because the parameter is named `input`.
    search_docs = TavilySearchResults(max_results=3).invoke(query)
    # TavilySearchResults returns a list of dicts ({"url": ..., "content": ...}),
    # not Document objects — doc.metadata/doc.page_content would raise
    # AttributeError, so index by key instead. TODO confirm key names against
    # the installed langchain_community version.
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc.get("url", "")}">{doc.get("content", "")}</Document>'
        for doc in search_docs
    )
    return {"web_results": formatted_search_docs}
68
+
69
from upload_metadata_n_setup_retrivers import retriever
def vector_search(query: str) -> dict:
    """Look up documents similar to *query* in the vector store.

    Args:
        query: The search query.

    Returns:
        dict with key ``"vector_results"`` mapping to the formatted documents.
        (Return annotation corrected: the function returns a dict, not a str.)
    """
    # BaseRetriever.invoke's first parameter is named `input`; passing query=
    # as a keyword raises TypeError, so pass the query positionally.
    search_docs = retriever.invoke(query)
    # Fixed malformed markup: self-closing "/>" followed by </Document>.
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}">{doc.page_content}</Document>'
        for doc in search_docs
    )
    return {"vector_results": formatted_search_docs}