tsrrus commited on
Commit
62309ba
·
verified ·
1 Parent(s): 7355acc

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +76 -24
agent.py CHANGED
@@ -121,21 +121,33 @@ with open("system_prompt.txt", "r", encoding="utf-8") as f:
121
  sys_msg = SystemMessage(content=system_prompt)
122
 
123
  # build a retriever
124
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
125
- supabase: Client = create_client(
126
- os.environ.get("SUPABASE_URL"),
127
- os.environ.get("SUPABASE_SERVICE_KEY"))
128
- vector_store = SupabaseVectorStore(
129
- client=supabase,
130
- embedding= embeddings,
131
- table_name="documents",
132
- query_name="match_documents_langchain",
133
- )
 
 
 
 
 
 
 
 
 
 
 
 
134
  create_retriever_tool = create_retriever_tool(
135
- retriever=vector_store.as_retriever(),
136
  name="Question Search",
137
  description="A tool to retrieve similar questions from a vector store.",
138
- )
139
 
140
 
141
 
@@ -152,18 +164,19 @@ tools = [
152
 
153
  # Build graph function
154
  def build_graph(provider: str = "huggingface"):
155
- """Build the graph"""
156
 
157
  if provider == "groq":
158
- llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
 
 
159
  elif provider == "huggingface":
160
  llm = ChatHuggingFace(
161
- llm=HuggingFaceEndpoint(
162
- repo_id = "Qwen/Qwen2.5-Coder-32B-Instruct"
163
- ),
164
  )
165
  else:
166
  raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
 
167
  # Bind tools to LLM
168
  llm_with_tools = llm.bind_tools(tools)
169
 
@@ -171,14 +184,32 @@ def build_graph(provider: str = "huggingface"):
171
  def assistant(state: MessagesState):
172
  """Assistant node"""
173
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
174
-
175
  def retriever(state: MessagesState):
176
- """Retriever node"""
177
- similar_question = vector_store.similarity_search(state["messages"][0].content)
178
- example_msg = HumanMessage(
179
- content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
180
- )
181
- return {"messages": [sys_msg] + state["messages"] + [example_msg]}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
 
183
  builder = StateGraph(MessagesState)
184
  builder.add_node("retriever", retriever)
@@ -194,6 +225,27 @@ def build_graph(provider: str = "huggingface"):
194
 
195
  # Compile graph
196
  return builder.compile()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
 
198
  # test
199
  if __name__ == "__main__":
 
121
  sys_msg = SystemMessage(content=system_prompt)
122
 
123
  # build a retriever
124
+ try:
125
+ embeddings = HuggingFaceEmbeddings(
126
+ model_name="sentence-transformers/all-mpnet-base-v2"
127
+ ) # dim=768
128
+ supabase: Client = create_client(
129
+ os.environ.get("SUPABASE_URL"), os.environ.get("SUPABASE_SERVICE_KEY")
130
+ )
131
+ vector_store = SupabaseVectorStore(
132
+ client=supabase,
133
+ embedding=embeddings,
134
+ table_name="documents",
135
+ query_name="match_documents_langchain",
136
+ )
137
+
138
+ # Test the connection
139
+ test_results = vector_store.similarity_search("test query", k=1)
140
+ print(f"Vector store initialized successfully. Test returned {len(test_results)} results.")
141
+
142
+ except Exception as e:
143
+ print(f"Warning: Vector store initialization failed: {e}")
144
+ vector_store = None
145
+
146
  create_retriever_tool = create_retriever_tool(
147
+ retriever=vector_store.as_retriever() if vector_store else None,
148
  name="Question Search",
149
  description="A tool to retrieve similar questions from a vector store.",
150
+ ) if vector_store else None
151
 
152
 
153
 
 
164
 
165
  # Build graph function
166
  def build_graph(provider: str = "huggingface"):
167
+ """Build the graph with improved error handling"""
168
 
169
  if provider == "groq":
170
+ llm = ChatGroq(
171
+ model="qwen-qwq-32b", temperature=0
172
+ ) # optional : qwen-qwq-32b gemma2-9b-it
173
  elif provider == "huggingface":
174
  llm = ChatHuggingFace(
175
+ llm=HuggingFaceEndpoint(repo_id="Qwen/Qwen2.5-Coder-32B-Instruct"),
 
 
176
  )
177
  else:
178
  raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
179
+
180
  # Bind tools to LLM
181
  llm_with_tools = llm.bind_tools(tools)
182
 
 
184
  def assistant(state: MessagesState):
185
  """Assistant node"""
186
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
187
+
188
  def retriever(state: MessagesState):
189
+ """Retriever node with error handling"""
190
+ try:
191
+ # Check if vector_store is available
192
+ if vector_store is None:
193
+ print("Vector store not available, proceeding without retrieval")
194
+ return {"messages": [sys_msg] + state["messages"]}
195
+
196
+ similar_question = vector_store.similarity_search(state["messages"][0].content)
197
+
198
+ # Check if we have results before accessing them
199
+ if similar_question and len(similar_question) > 0:
200
+ example_msg = HumanMessage(
201
+ content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
202
+ )
203
+ return {"messages": [sys_msg] + state["messages"] + [example_msg]}
204
+ else:
205
+ # No similar questions found, proceed without reference
206
+ print("No similar questions found in vector store")
207
+ return {"messages": [sys_msg] + state["messages"]}
208
+
209
+ except Exception as e:
210
+ print(f"Error in retriever: {e}")
211
+ # Fallback: continue without retrieval
212
+ return {"messages": [sys_msg] + state["messages"]}
213
 
214
  builder = StateGraph(MessagesState)
215
  builder.add_node("retriever", retriever)
 
225
 
226
  # Compile graph
227
  return builder.compile()
228
+
229
+ def retriever(state: MessagesState):
230
+ """Retriever node with error handling"""
231
+ try:
232
+ similar_question = vector_store.similarity_search(state["messages"][0].content)
233
+
234
+ # Check if we have results before accessing them
235
+ if similar_question and len(similar_question) > 0:
236
+ example_msg = HumanMessage(
237
+ content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
238
+ )
239
+ return {"messages": [sys_msg] + state["messages"] + [example_msg]}
240
+ else:
241
+ # No similar questions found, proceed without reference
242
+ print("No similar questions found in vector store")
243
+ return {"messages": [sys_msg] + state["messages"]}
244
+
245
+ except Exception as e:
246
+ print(f"Error in retriever: {e}")
247
+ # Fallback: continue without retrieval
248
+ return {"messages": [sys_msg] + state["messages"]}
249
 
250
  # test
251
  if __name__ == "__main__":