tsrrus committed on
Commit
3e95386
·
verified ·
1 Parent(s): fd98589

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +62 -137
agent.py CHANGED
@@ -1,5 +1,4 @@
1
- """LangGraph Agent - Fixed Version"""
2
-
3
  import os
4
  from dotenv import load_dotenv
5
  from langgraph.graph import START, StateGraph, MessagesState
@@ -7,11 +6,7 @@ from langgraph.prebuilt import tools_condition
7
  from langgraph.prebuilt import ToolNode
8
  from langchain_google_genai import ChatGoogleGenerativeAI
9
  from langchain_groq import ChatGroq
10
- from langchain_huggingface import (
11
- ChatHuggingFace,
12
- HuggingFaceEndpoint,
13
- HuggingFaceEmbeddings,
14
- )
15
  from langchain_community.tools.tavily_search import TavilySearchResults
16
  from langchain_community.document_loaders import WikipediaLoader
17
  from langchain_community.document_loaders import ArxivLoader
@@ -23,7 +18,6 @@ from supabase.client import Client, create_client
23
 
24
  load_dotenv()
25
 
26
-
27
  @tool
28
  def multiply(a: int, b: int) -> int:
29
  """Multiply two numbers.
@@ -33,33 +27,30 @@ def multiply(a: int, b: int) -> int:
33
  """
34
  return a * b
35
 
36
-
37
  @tool
38
  def add(a: int, b: int) -> int:
39
  """Add two numbers.
40
-
41
  Args:
42
  a: first int
43
  b: second int
44
  """
45
  return a + b
46
 
47
-
48
  @tool
49
  def subtract(a: int, b: int) -> int:
50
  """Subtract two numbers.
51
-
52
  Args:
53
  a: first int
54
  b: second int
55
  """
56
  return a - b
57
 
58
-
59
  @tool
60
  def divide(a: int, b: int) -> int:
61
  """Divide two numbers.
62
-
63
  Args:
64
  a: first int
65
  b: second int
@@ -68,116 +59,85 @@ def divide(a: int, b: int) -> int:
68
  raise ValueError("Cannot divide by zero.")
69
  return a / b
70
 
71
-
72
  @tool
73
  def modulus(a: int, b: int) -> int:
74
  """Get the modulus of two numbers.
75
-
76
  Args:
77
  a: first int
78
  b: second int
79
  """
80
  return a % b
81
 
82
-
83
  @tool
84
  def wiki_search(query: str) -> str:
85
  """Search Wikipedia for a query and return maximum 2 results.
86
-
87
  Args:
88
  query: The search query."""
89
- try:
90
- search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
91
- formatted_search_docs = "\n\n---\n\n".join(
92
- [
93
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
94
- for doc in search_docs
95
- ]
96
- )
97
- return {"wiki_results": formatted_search_docs}
98
- except Exception as e:
99
- return {"wiki_results": f"Wikipedia search failed: {str(e)}"}
100
-
101
 
102
  @tool
103
  def web_search(query: str) -> str:
104
  """Search Tavily for a query and return maximum 3 results.
105
-
106
  Args:
107
  query: The search query."""
108
- try:
109
- search_docs = TavilySearchResults(max_results=3).invoke(query=query)
110
- formatted_search_docs = "\n\n---\n\n".join(
111
- [
112
- f'<Document source="{doc.get("url", "unknown")}" page="{doc.get("title", "")}"/>\n{doc.get("content", "")}\n</Document>'
113
- for doc in search_docs
114
- ]
115
- )
116
- return {"web_results": formatted_search_docs}
117
- except Exception as e:
118
- return {"web_results": f"Web search failed: {str(e)}"}
119
-
120
 
121
  @tool
122
  def arvix_search(query: str) -> str:
123
  """Search Arxiv for a query and return maximum 3 result.
124
-
125
  Args:
126
  query: The search query."""
127
- try:
128
- search_docs = ArxivLoader(query=query, load_max_docs=3).load()
129
- formatted_search_docs = "\n\n---\n\n".join(
130
- [
131
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
132
- for doc in search_docs
133
- ]
134
- )
135
- return {"arvix_results": formatted_search_docs}
136
- except Exception as e:
137
- return {"arvix_results": f"Arxiv search failed: {str(e)}"}
138
 
139
 
140
  # load the system prompt from the file
141
- try:
142
- with open("system_prompt.txt", "r", encoding="utf-8") as f:
143
- system_prompt = f.read()
144
- except FileNotFoundError:
145
- system_prompt = "You are a helpful AI assistant. Answer questions accurately and concisely."
146
 
147
  # System message
148
  sys_msg = SystemMessage(content=system_prompt)
149
 
150
- # build a retriever with error handling
151
- def initialize_vector_store():
152
- """Initialize vector store with proper error handling."""
153
- try:
154
- embeddings = HuggingFaceEmbeddings(
155
- model_name="sentence-transformers/all-mpnet-base-v2"
156
- )
157
- supabase: Client = create_client(
158
- os.environ.get("SUPABASE_URL"), os.environ.get("SUPABASE_SERVICE_KEY")
159
- )
160
- vector_store = SupabaseVectorStore(
161
- client=supabase,
162
- embedding=embeddings,
163
- table_name="documents",
164
- query_name="match_documents_langchain",
165
- )
166
- return vector_store
167
- except Exception as e:
168
- print(f"Warning: Failed to initialize vector store: {e}")
169
- return None
170
 
171
- # Initialize vector store
172
- vector_store = initialize_vector_store()
173
 
174
- # Create retriever tool if vector store is available
175
- if vector_store:
176
- create_retriever_tool = create_retriever_tool(
177
- retriever=vector_store.as_retriever(),
178
- name="Question Search",
179
- description="A tool to retrieve similar questions from a vector store.",
180
- )
181
 
182
  tools = [
183
  multiply,
@@ -190,26 +150,20 @@ tools = [
190
  arvix_search,
191
  ]
192
 
193
-
194
  # Build graph function
195
  def build_graph(provider: str = "huggingface"):
196
  """Build the graph"""
197
- # Load environment variables from .env file
198
- if provider == "google":
199
- # Google Gemini
200
- llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
201
- elif provider == "groq":
202
- # Groq https://console.groq.com/docs/models
203
- llm = ChatGroq(
204
- model="qwen-qwq-32b", temperature=0
205
- ) # optional : qwen-qwq-32b gemma2-9b-it
206
  elif provider == "huggingface":
207
  llm = ChatHuggingFace(
208
- llm=HuggingFaceEndpoint(repo_id="Qwen/Qwen2.5-Coder-32B-Instruct"),
 
 
209
  )
210
  else:
211
  raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
212
-
213
  # Bind tools to LLM
214
  llm_with_tools = llm.bind_tools(tools)
215
 
@@ -217,42 +171,14 @@ def build_graph(provider: str = "huggingface"):
217
  def assistant(state: MessagesState):
218
  """Assistant node"""
219
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
220
-
221
  def retriever(state: MessagesState):
222
- """Retriever node with proper error handling"""
223
- try:
224
- # Check if vector store is available
225
- if vector_store is None:
226
- print("Vector store not available, proceeding without retrieval")
227
- return {"messages": [sys_msg] + state["messages"]}
228
-
229
- # Get the user question
230
- user_question = state["messages"][-1].content if state["messages"] else ""
231
-
232
- if not user_question:
233
- print("No user question found, proceeding without retrieval")
234
- return {"messages": [sys_msg] + state["messages"]}
235
-
236
- # Perform similarity search
237
- similar_questions = vector_store.similarity_search(user_question, k=1)
238
-
239
- # Check if we found any similar questions
240
- if similar_questions and len(similar_questions) > 0:
241
- # Extract the first similar question
242
- similar_content = similar_questions[0].page_content
243
- example_msg = HumanMessage(
244
- content=f"Here I provide a similar question and answer for reference: \n\n{similar_content}",
245
- )
246
- print("Found similar question for retrieval")
247
- return {"messages": [sys_msg] + state["messages"] + [example_msg]}
248
- else:
249
- print("No similar questions found, proceeding without retrieval example")
250
- return {"messages": [sys_msg] + state["messages"]}
251
-
252
- except Exception as e:
253
- print(f"Error in retriever node: {e}")
254
- # Fallback: proceed without retrieval
255
- return {"messages": [sys_msg] + state["messages"]}
256
 
257
  builder = StateGraph(MessagesState)
258
  builder.add_node("retriever", retriever)
@@ -269,7 +195,6 @@ def build_graph(provider: str = "huggingface"):
269
  # Compile graph
270
  return builder.compile()
271
 
272
-
273
  # test
274
  if __name__ == "__main__":
275
  question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
 
1
+ """LangGraph Agent"""
 
2
  import os
3
  from dotenv import load_dotenv
4
  from langgraph.graph import START, StateGraph, MessagesState
 
6
  from langgraph.prebuilt import ToolNode
7
  from langchain_google_genai import ChatGoogleGenerativeAI
8
  from langchain_groq import ChatGroq
9
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
 
 
 
 
10
  from langchain_community.tools.tavily_search import TavilySearchResults
11
  from langchain_community.document_loaders import WikipediaLoader
12
  from langchain_community.document_loaders import ArxivLoader
 
18
 
19
  load_dotenv()
20
 
 
21
  @tool
22
  def multiply(a: int, b: int) -> int:
23
  """Multiply two numbers.
 
27
  """
28
  return a * b
29
 
 
30
  @tool
31
  def add(a: int, b: int) -> int:
32
  """Add two numbers.
33
+
34
  Args:
35
  a: first int
36
  b: second int
37
  """
38
  return a + b
39
 
 
40
  @tool
41
  def subtract(a: int, b: int) -> int:
42
  """Subtract two numbers.
43
+
44
  Args:
45
  a: first int
46
  b: second int
47
  """
48
  return a - b
49
 
 
50
  @tool
51
  def divide(a: int, b: int) -> int:
52
  """Divide two numbers.
53
+
54
  Args:
55
  a: first int
56
  b: second int
 
59
  raise ValueError("Cannot divide by zero.")
60
  return a / b
61
 
 
62
  @tool
63
  def modulus(a: int, b: int) -> int:
64
  """Get the modulus of two numbers.
65
+
66
  Args:
67
  a: first int
68
  b: second int
69
  """
70
  return a % b
71
 
 
72
  @tool
73
  def wiki_search(query: str) -> str:
74
  """Search Wikipedia for a query and return maximum 2 results.
75
+
76
  Args:
77
  query: The search query."""
78
+ search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
79
+ formatted_search_docs = "\n\n---\n\n".join(
80
+ [
81
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
82
+ for doc in search_docs
83
+ ])
84
+ return {"wiki_results": formatted_search_docs}
 
 
 
 
 
85
 
86
  @tool
87
  def web_search(query: str) -> str:
88
  """Search Tavily for a query and return maximum 3 results.
89
+
90
  Args:
91
  query: The search query."""
92
+ search_docs = TavilySearchResults(max_results=3).invoke(query=query)
93
+ formatted_search_docs = "\n\n---\n\n".join(
94
+ [
95
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
96
+ for doc in search_docs
97
+ ])
98
+ return {"web_results": formatted_search_docs}
 
 
 
 
 
99
 
100
  @tool
101
  def arvix_search(query: str) -> str:
102
  """Search Arxiv for a query and return maximum 3 result.
103
+
104
  Args:
105
  query: The search query."""
106
+ search_docs = ArxivLoader(query=query, load_max_docs=3).load()
107
+ formatted_search_docs = "\n\n---\n\n".join(
108
+ [
109
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
110
+ for doc in search_docs
111
+ ])
112
+ return {"arvix_results": formatted_search_docs}
113
+
 
 
 
114
 
115
 
116
  # load the system prompt from the file
117
+ with open("system_prompt.txt", "r", encoding="utf-8") as f:
118
+ system_prompt = f.read()
 
 
 
119
 
120
  # System message
121
  sys_msg = SystemMessage(content=system_prompt)
122
 
123
+ # build a retriever
124
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
125
+ supabase: Client = create_client(
126
+ os.environ.get("SUPABASE_URL"),
127
+ os.environ.get("SUPABASE_SERVICE_KEY"))
128
+ vector_store = SupabaseVectorStore(
129
+ client=supabase,
130
+ embedding= embeddings,
131
+ table_name="documents2",
132
+ query_name="match_documents_2",
133
+ )
134
+ create_retriever_tool = create_retriever_tool(
135
+ retriever=vector_store.as_retriever(),
136
+ name="Question Search",
137
+ description="A tool to retrieve similar questions from a vector store.",
138
+ )
 
 
 
 
139
 
 
 
140
 
 
 
 
 
 
 
 
141
 
142
  tools = [
143
  multiply,
 
150
  arvix_search,
151
  ]
152
 
 
153
  # Build graph function
154
  def build_graph(provider: str = "huggingface"):
155
  """Build the graph"""
156
+
157
+ if provider == "groq":
158
+ llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
 
 
 
 
 
 
159
  elif provider == "huggingface":
160
  llm = ChatHuggingFace(
161
+ llm=HuggingFaceEndpoint(
162
+ repo_id = "Qwen/Qwen2.5-Coder-32B-Instruct"
163
+ ),
164
  )
165
  else:
166
  raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
 
167
  # Bind tools to LLM
168
  llm_with_tools = llm.bind_tools(tools)
169
 
 
171
  def assistant(state: MessagesState):
172
  """Assistant node"""
173
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
174
+
175
  def retriever(state: MessagesState):
176
+ """Retriever node"""
177
+ similar_question = vector_store.similarity_search(state["messages"][0].content)
178
+ example_msg = HumanMessage(
179
+ content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
180
+ )
181
+ return {"messages": [sys_msg] + state["messages"] + [example_msg]}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
 
183
  builder = StateGraph(MessagesState)
184
  builder.add_node("retriever", retriever)
 
195
  # Compile graph
196
  return builder.compile()
197
 
 
198
  # test
199
  if __name__ == "__main__":
200
  question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"