PradeepBodhi committed on
Commit
12d1066
·
verified ·
1 Parent(s): 4d768fd

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +108 -94
agent.py CHANGED
@@ -1,23 +1,70 @@
 
 
 
 
1
  from langchain_groq import ChatGroq
 
 
 
2
  from langchain_core.messages import SystemMessage, HumanMessage
3
  from langchain_core.tools import tool
4
- from langchain_community.tools.tavily_search import TavilySearchResults
5
- from langchain_community.document_loaders import WikipediaLoader
6
- from langgraph.graph import START, StateGraph, MessagesState
7
- from langgraph.prebuilt import tools_condition
8
- from langgraph.prebuilt import ToolNode
9
- from dotenv import load_dotenv
10
- import os
11
- from supabase.client import Client, create_client
12
- import json
13
- from langchain.schema import Document
14
  from langchain_huggingface import HuggingFaceEmbeddings
15
- from langchain_community.document_loaders import ArxivLoader
16
- from langchain_community.vectorstores import SupabaseVectorStore
17
-
18
 
19
  load_dotenv()
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  @tool
23
  def wiki_search(query: str) -> str:
@@ -62,19 +109,16 @@ def arvix_search(query: str) -> str:
62
  return {"arvix_results": formatted_search_docs}
63
 
64
 
65
- tools = [
66
- wiki_search,
67
- web_search,
68
- arvix_search,
69
- ]
70
 
71
  # load the system prompt from the file
72
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
73
  system_prompt = f.read()
74
 
 
 
75
 
76
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2"
77
-
78
  supabase: Client = create_client(
79
  os.environ.get("SUPABASE_URL"),
80
  os.environ.get("SUPABASE_SERVICE_KEY"))
@@ -84,14 +128,43 @@ vector_store = SupabaseVectorStore(
84
  table_name="documents",
85
  query_name="match_documents",
86
  )
 
 
 
 
 
87
 
88
- # System message
89
- sys_msg = SystemMessage(content=system_prompt)
90
 
91
 
 
 
 
 
 
 
 
 
 
 
 
92
  # Build graph function
93
- def build_graph():
94
- llm = ChatGroq(model="qwen-qwq-32b", temperature=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  llm_with_tools = llm.bind_tools(tools)
96
 
97
  # Node
@@ -106,7 +179,7 @@ def build_graph():
106
  content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
107
  )
108
  return {"messages": [sys_msg] + state["messages"] + [example_msg]}
109
-
110
  builder = StateGraph(MessagesState)
111
  builder.add_node("retriever", retriever)
112
  builder.add_node("assistant", assistant)
@@ -121,73 +194,14 @@ def build_graph():
121
 
122
  # Compile graph
123
  return builder.compile()
124
-
125
-
126
- def load_documents():
127
-
128
- embeddings = AzureOpenAIEmbeddings(model="text-embedding-3-small")
129
- with open('metadata.jsonl', 'r') as jsonl_file:
130
- json_list = list(jsonl_file)
131
-
132
- json_QA = []
133
- for json_str in json_list:
134
- json_data = json.loads(json_str)
135
- json_QA.append(json_data)
136
-
137
- supabase_url = os.environ.get("SUPABASE_URL")
138
- supabase_key = os.environ.get("SUPABASE_SERVICE_KEY")
139
- supabase: Client = create_client(supabase_url, supabase_key)
140
-
141
- docs: list[Document] = []
142
- for sample in json_QA:
143
- content = f"Question : {sample['Question']}\n\nFinal answer : {sample['Final answer']}"
144
- doc = {
145
- "content" : content,
146
- "metadata" : {
147
- "source" : sample['task_id']
148
- },
149
- "embedding" : embeddings.embed_query(content),
150
- }
151
- docs.append(doc)
152
-
153
- # upload the documents to the vector database
154
- try:
155
- response = (
156
- supabase.table("documents")
157
- .insert(docs)
158
- .execute()
159
- )
160
- except Exception as exception:
161
- print("Error inserting data into Supabase:", exception)
162
-
163
-
164
- def search_documents() -> list[Document]:
165
- query = "On June 6, 2023, an article by Carolyn Collins Petersen was published in Universe Today. This article mentions a team that produced a paper about their observations, linked at the bottom of the article. Find this paper. Under what NASA award number was the work performed by R. G. Arendt supported by?"
166
- vector_store = SupabaseVectorStore(
167
- client=supabase,
168
- embedding= embeddings,
169
- table_name="documents",
170
- query_name="match_documents",
171
- )
172
- retriever = vector_store.as_retriever()
173
- docs = retriever.get_relevant_documents(query)
174
- return docs[0]
175
-
176
-
177
-
178
- if __name__ == "__main__":
179
- question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
180
- # Build the graph
181
- graph = build_graph()
182
- # Run the graph
183
- messages = [HumanMessage(content=question)]
184
- messages = graph.invoke({"messages": messages})
185
- for m in messages["messages"]:
186
- m.pretty_print()
187
-
188
- # load_documents()
189
- # search_documents()
190
-
191
-
192
-
193
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dotenv import load_dotenv
3
+ from langgraph.graph import START, StateGraph, MessagesState
4
+ from langgraph.prebuilt import tools_condition, ToolNode
5
  from langchain_groq import ChatGroq
6
+ from langchain_community.tools.tavily_search import TavilySearchResults
7
+ from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
8
+ from langchain_community.vectorstores import SupabaseVectorStore
9
  from langchain_core.messages import SystemMessage, HumanMessage
10
  from langchain_core.tools import tool
11
+ from langchain.tools.retriever import create_retriever_tool
 
 
 
 
 
 
 
 
 
12
  from langchain_huggingface import HuggingFaceEmbeddings
13
+ from supabase.client import Client, create_client
 
 
14
 
15
  load_dotenv()
16
 
17
+ @tool
18
def multiply(a: int, b: int) -> int:
    """Multiply two numbers.

    Args:
        a: first int
        b: second int
    """
    # Docstring wording is kept as-is: the function is wrapped by @tool, which
    # presumably surfaces it as the tool description.
    product = a * b
    return product
26
+
27
+ @tool
28
def add(a: int, b: int) -> int:
    """Add two numbers.

    Args:
        a: first int
        b: second int
    """
    # Docstring wording is kept as-is: the function is wrapped by @tool, which
    # presumably surfaces it as the tool description.
    total = a + b
    return total
36
+
37
+ @tool
38
def subtract(a: int, b: int) -> int:
    """Subtract two numbers.

    Args:
        a: first int
        b: second int
    """
    # Docstring wording is kept as-is: the function is wrapped by @tool, which
    # presumably surfaces it as the tool description.
    difference = a - b
    return difference
46
+
47
+ @tool
48
def divide(a: int, b: int) -> float:
    """Divide two numbers.

    Args:
        a: first int (the dividend)
        b: second int (the divisor)

    Returns:
        The true-division quotient ``a / b``, which is always a float
        (e.g. ``divide(6, 3)`` is ``2.0``).

    Raises:
        ValueError: if ``b`` is zero.
    """
    # Reject a zero divisor up front so the caller gets a clear message
    # instead of a bare ZeroDivisionError.
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    # `/` is true division and yields a float even for int operands, so the
    # return annotation is float rather than int.
    return a / b
58
+
59
+ @tool
60
def modulus(a: int, b: int) -> int:
    """Get the modulus of two numbers.

    Args:
        a: first int
        b: second int
    """
    # Docstring wording is kept as-is: the function is wrapped by @tool, which
    # presumably surfaces it as the tool description.
    remainder = a % b
    return remainder
68
 
69
  @tool
70
  def wiki_search(query: str) -> str:
 
109
  return {"arvix_results": formatted_search_docs}
110
 
111
 
 
 
 
 
 
112
 
113
  # load the system prompt from the file
114
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
115
  system_prompt = f.read()
116
 
117
+ # System message
118
+ sys_msg = SystemMessage(content=system_prompt)
119
 
120
+ # build a retriever
121
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
122
  supabase: Client = create_client(
123
  os.environ.get("SUPABASE_URL"),
124
  os.environ.get("SUPABASE_SERVICE_KEY"))
 
128
  table_name="documents",
129
  query_name="match_documents",
130
  )
131
# Wrap the vector-store retriever as a tool for looking up similar questions.
# The result gets its own name: assigning it back to `create_retriever_tool`
# would rebind the imported factory function to a tool instance, so the
# factory could no longer be used anywhere below this line.
# The tool name uses an underscore instead of a space because tool-calling
# chat APIs generally only accept letters, digits, underscores and hyphens in
# function names.
question_retriever_tool = create_retriever_tool(
    retriever=vector_store.as_retriever(),
    name="question_search",
    description="A tool to retrieve similar questions from a vector store.",
)
136
 
 
 
137
 
138
 
139
# Tools handed to `llm.bind_tools` in build_graph: the five integer arithmetic
# helpers first, then the three search tools. Order matches the original list.
tools = [multiply, add, subtract, divide, modulus]
tools += [wiki_search, web_search, arvix_search]
149
+
150
  # Build graph function
151
+ def build_graph(provider: str = "groq"):
152
+ """Build the graph"""
153
+ # Load environment variables from .env file
154
+ if provider == "groq":
155
+ # Groq https://console.groq.com/docs/models
156
+ llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
157
+ elif provider == "huggingface":
158
+ # TODO: Add huggingface endpoint
159
+ llm = ChatHuggingFace(
160
+ llm=HuggingFaceEndpoint(
161
+ url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
162
+ temperature=0,
163
+ ),
164
+ )
165
+ else:
166
+ raise ValueError("Invalid provider. Choose, 'groq' or 'huggingface'.")
167
+ # Bind tools to LLM
168
  llm_with_tools = llm.bind_tools(tools)
169
 
170
  # Node
 
179
  content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
180
  )
181
  return {"messages": [sys_msg] + state["messages"] + [example_msg]}
182
+
183
  builder = StateGraph(MessagesState)
184
  builder.add_node("retriever", retriever)
185
  builder.add_node("assistant", assistant)
 
194
 
195
  # Compile graph
196
  return builder.compile()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
 
198
+ # test
199
+ # if __name__ == "__main__":
200
+ # question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
201
+ # # Build the graph
202
+ # graph = build_graph(provider="groq")
203
+ # # Run the graph
204
+ # messages = [HumanMessage(content=question)]
205
+ # messages = graph.invoke({"messages": messages})
206
+ # for m in messages["messages"]:
207
+ # m.pretty_print()