hua101 committed on
Commit
0756a62
·
verified ·
1 Parent(s): b541a3a

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +217 -36
agent.py CHANGED
@@ -17,7 +17,10 @@ from langchain.tools.retriever import create_retriever_tool
17
  from supabase.client import Client, create_client
18
 
19
  load_dotenv()
 
 
20
 
 
21
  @tool
22
  def multiply(a: int, b: int) -> int:
23
  """Multiply two numbers.
@@ -70,6 +73,7 @@ def modulus(a: int, b: int) -> int:
70
  """
71
  return a % b
72
 
 
73
  @tool
74
  def wiki_search(query: str) -> str:
75
  """Search Wikipedia for a query and return maximum 2 results.
@@ -112,34 +116,165 @@ def arvix_search(query: str) -> str:
112
  ])
113
  return {"arvix_results": formatted_search_docs}
114
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
 
117
  # load the system prompt from the file
118
- with open("system_prompt.txt", "r", encoding="utf-8") as f:
119
- system_prompt = f.read()
 
 
 
 
 
 
 
 
 
 
 
120
 
121
  # System message
122
  sys_msg = SystemMessage(content=system_prompt)
123
 
124
- # build a retriever
125
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
126
- supabase: Client = create_client(
127
- os.environ.get("SUPABASE_URL"),
128
- os.environ.get("SUPABASE_SERVICE_KEY"))
129
- vector_store = SupabaseVectorStore(
130
- client=supabase,
131
- embedding= embeddings,
132
- table_name="documents",
133
- query_name="match_documents_langchain",
134
- )
135
- create_retriever_tool = create_retriever_tool(
136
- retriever=vector_store.as_retriever(),
137
- name="Question Search",
138
- description="A tool to retrieve similar questions from a vector store.",
139
- )
140
-
 
 
 
 
 
 
 
 
 
 
 
141
 
 
 
142
 
 
143
  tools = [
144
  multiply,
145
  add,
@@ -149,20 +284,26 @@ tools = [
149
  wiki_search,
150
  web_search,
151
  arvix_search,
 
 
 
152
  ]
153
 
 
 
 
 
 
 
 
154
  # Build graph function
155
  def build_graph(provider: str = "groq"):
156
  """Build the graph"""
157
- # Load environment variables from .env file
158
  if provider == "google":
159
- # Google Gemini
160
  llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
161
  elif provider == "groq":
162
- # Groq https://console.groq.com/docs/models
163
- llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
164
  elif provider == "huggingface":
165
- # TODO: Add huggingface endpoint
166
  llm = ChatHuggingFace(
167
  llm=HuggingFaceEndpoint(
168
  url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
@@ -171,6 +312,7 @@ def build_graph(provider: str = "groq"):
171
  )
172
  else:
173
  raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
 
174
  # Bind tools to LLM
175
  llm_with_tools = llm.bind_tools(tools)
176
 
@@ -180,12 +322,28 @@ def build_graph(provider: str = "groq"):
180
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
181
 
182
  def retriever(state: MessagesState):
183
- """Retriever node"""
184
- similar_question = vector_store.similarity_search(state["messages"][0].content)
185
- example_msg = HumanMessage(
186
- content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
187
- )
188
- return {"messages": [sys_msg] + state["messages"] + [example_msg]}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
 
190
  builder = StateGraph(MessagesState)
191
  builder.add_node("retriever", retriever)
@@ -204,11 +362,34 @@ def build_graph(provider: str = "groq"):
204
 
205
  # test
206
  if __name__ == "__main__":
207
- question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
 
 
 
 
 
 
 
 
208
  # Build the graph
209
  graph = build_graph(provider="groq")
210
- # Run the graph
211
- messages = [HumanMessage(content=question)]
212
- messages = graph.invoke({"messages": messages})
213
- for m in messages["messages"]:
214
- m.pretty_print()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  from supabase.client import Client, create_client
18
 
load_dotenv()
# Startup diagnostics. Only report whether the API key is present: printing
# the key itself would leak the credential into console output and any logs
# that capture it.
print("GROQ_API_KEY:", "set" if os.getenv("GROQ_API_KEY") else "missing")
print("SUPABASE_URL:", os.getenv("SUPABASE_URL"))
22
 
23
+ # === 原有的数学工具 ===
24
  @tool
25
  def multiply(a: int, b: int) -> int:
26
  """Multiply two numbers.
 
73
  """
74
  return a % b
75
 
76
+ # === 原有的搜索工具 ===
77
  @tool
78
  def wiki_search(query: str) -> str:
79
  """Search Wikipedia for a query and return maximum 2 results.
 
116
  ])
117
  return {"arvix_results": formatted_search_docs}
118
 
# === New: Supabase tools ===
# NOTE: LangChain's @tool exposes the docstring to the model as the tool
# description, so maintainer notes are kept in comments, not the docstring.
# NOTE(review): the embedding model and Supabase client are rebuilt on every
# call; consider sharing one instance if call latency matters.
@tool
def supabase_vector_search(query: str, max_results: int = 3) -> dict:
    """Search the Supabase knowledge base using vector similarity.

    Args:
        query: The search query
        max_results: Maximum number of results to return (default: 3)
    """
    snippet_limit = 800  # cap per-document text so the tool output stays small

    try:
        embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
        supabase: Client = create_client(
            os.environ.get("SUPABASE_URL"),
            os.environ.get("SUPABASE_SERVICE_KEY"),
        )

        vector_store = SupabaseVectorStore(
            client=supabase,
            embedding=embeddings,
            table_name="supabase_docs",  # use your actual table name
            query_name="match_documents",  # the SQL function created for this table
        )

        # The model chooses max_results; never pass 0 or a negative k downstream.
        results = vector_store.similarity_search(query, k=max(1, max_results))

        if not results:
            return {"message": "No relevant documents found in knowledge base"}

        chunks = []
        for doc in results:
            text = doc.page_content or ""
            # Only mark the text as elided when something was actually cut off.
            if len(text) > snippet_limit:
                text = text[:snippet_limit] + "..."
            chunks.append(f'<Document similarity="high"/>\n{text}\n</Document>')

        return {"supabase_vector_results": "\n\n---\n\n".join(chunks)}

    except Exception as e:
        # Tool boundary: report the failure to the agent instead of raising.
        return {"error": f"Supabase vector search failed: {str(e)}"}
155
 
# NOTE: LangChain's @tool exposes the docstring to the model as the tool
# description, so maintainer notes are kept in comments, not the docstring.
@tool
def supabase_text_search(query: str, max_results: int = 3) -> dict:
    """Search the Supabase knowledge base using text search.

    Args:
        query: The search query
        max_results: Maximum number of results to return (default: 3)
    """
    snippet_limit = 800  # cap per-document text so the tool output stays small

    try:
        supabase: Client = create_client(
            os.environ.get("SUPABASE_URL"),
            os.environ.get("SUPABASE_SERVICE_KEY"),
        )

        # Call the hybrid search function in text-only mode.
        result = supabase.rpc(
            "hybrid_search",
            {
                "search_query": query,
                "search_type": "text",
                # The model chooses max_results; never request 0 or fewer rows.
                "max_results": max(1, max_results),
            },
        ).execute()

        if not result.data:
            return {"message": "No relevant documents found in knowledge base"}

        chunks = []
        for item in result.data:
            # `or` fallbacks also cover keys that are present but null, so one
            # incomplete row cannot turn the whole result set into an error.
            similarity = item.get("similarity") or 0
            text = item.get("content") or ""
            # Only mark the text as elided when something was actually cut off.
            if len(text) > snippet_limit:
                text = text[:snippet_limit] + "..."
            chunks.append(f'<Document similarity="{similarity:.3f}"/>\n{text}\n</Document>')

        return {"supabase_text_results": "\n\n---\n\n".join(chunks)}

    except Exception as e:
        # Tool boundary: report the failure to the agent instead of raising.
        return {"error": f"Supabase text search failed: {str(e)}"}
188
+
# NOTE: LangChain's @tool exposes the docstring to the model as the tool
# description, so maintainer notes are kept in comments, not the docstring.
@tool
def get_knowledge_context(query: str) -> dict:
    """Get contextual information from the knowledge base for better understanding.

    Args:
        query: The user's question
    """
    try:
        supabase: Client = create_client(
            os.environ.get("SUPABASE_URL"),
            os.environ.get("SUPABASE_SERVICE_KEY"),
        )

        result = supabase.rpc(
            "get_agent_context",
            {"user_query": query, "context_limit": 2},
        ).execute()

        if not result.data:
            return {"message": "No context available"}

        context_data = result.data[0]
        # `or` fallbacks also cover keys that are present but null; a null
        # score or count would otherwise raise on the format/compare below.
        context_text = context_data.get("context_text") or ""
        confidence = context_data.get("confidence_score") or 0
        source_count = context_data.get("source_count") or 0

        if not context_text or source_count <= 0:
            return {"message": "No relevant context found"}

        return {
            "context": context_text[:1000],  # limit length
            "confidence": f"{confidence:.2f}",
            "sources": source_count,
        }

    except Exception as e:
        # Tool boundary: report the failure to the agent instead of raising.
        return {"error": f"Context retrieval failed: {str(e)}"}
226
 
# load the system prompt from the file
try:
    with open("system_prompt.txt", "r", encoding="utf-8") as f:
        system_prompt = f.read()
except FileNotFoundError:
    # Fall back to a built-in default prompt when the file does not exist.
    # (The previous literal contained mis-encoded bytes in place of "工".)
    system_prompt = """你是一个智能助手,可以使用多种工具来回答用户的问题。

可用工具包括:
1. 数学计算工具(加减乘除等)
2. 网络搜索工具(Wikipedia, Arxiv, Web搜索)
3. Supabase 知识库工具(向量搜索、文本搜索、上下文获取)

请根据用户的问题选择最合适的工具,并提供准确、有用的答案。对于知识库中的信息,优先使用 Supabase 工具。"""
241
 
242
  # System message
243
  sys_msg = SystemMessage(content=system_prompt)
244
 
# === Updated retriever setup ===
def setup_vector_store():
    """Build the Supabase vector store and a retriever tool on top of it.

    Returns:
        A ``(vector_store, retriever_tool)`` tuple, or ``(None, None)`` when
        any part of the setup raises (the error is printed, not re-raised).
    """
    try:
        embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
        supabase: Client = create_client(
            os.environ.get("SUPABASE_URL"),
            os.environ.get("SUPABASE_SERVICE_KEY"),
        )

        vector_store = SupabaseVectorStore(
            client=supabase,
            embedding=embeddings,
            table_name="supabase_docs",  # changed to the correct table name
            query_name="match_documents",  # the SQL function created for this table
        )

        retriever_tool = create_retriever_tool(
            retriever=vector_store.as_retriever(search_kwargs={"k": 3}),
            # Tool names are sent to the provider as function names, which
            # must not contain spaces (letters, digits, "_" and "-" only).
            name="knowledge_base_search",
            description="Search the knowledge base for similar questions and answers.",
        )

        return vector_store, retriever_tool

    except Exception as e:
        # Best-effort setup: callers treat (None, None) as "no knowledge base".
        print(f"❌ Vector store setup failed: {e}")
        return None, None
273
 
274
+ # 设置向量存储
275
+ vector_store, retriever_tool = setup_vector_store()
276
 
277
+ # === 更新工具列表 ===
278
  tools = [
279
  multiply,
280
  add,
 
284
  wiki_search,
285
  web_search,
286
  arvix_search,
287
+ supabase_vector_search, # 新增
288
+ supabase_text_search, # 新增
289
+ get_knowledge_context, # 新增
290
  ]
291
 
292
+ # 如果 retriever 设置成功,添加到工具列表
293
+ if retriever_tool:
294
+ tools.append(retriever_tool)
295
+ print("✅ Knowledge base retriever tool added")
296
+ else:
297
+ print("⚠️ Knowledge base retriever tool not available")
298
+
299
  # Build graph function
300
  def build_graph(provider: str = "groq"):
301
  """Build the graph"""
 
302
  if provider == "google":
 
303
  llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
304
  elif provider == "groq":
305
+ llm = ChatGroq(model="qwen-qwq-32b", temperature=0)
 
306
  elif provider == "huggingface":
 
307
  llm = ChatHuggingFace(
308
  llm=HuggingFaceEndpoint(
309
  url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
 
312
  )
313
  else:
314
  raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
315
+
316
  # Bind tools to LLM
317
  llm_with_tools = llm.bind_tools(tools)
318
 
 
322
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
323
 
324
  def retriever(state: MessagesState):
325
+ """Enhanced retriever node with Supabase integration"""
326
+ try:
327
+ if vector_store and len(state["messages"]) > 0:
328
+ user_query = state["messages"][-1].content
329
+ similar_questions = vector_store.similarity_search(user_query, k=2)
330
+
331
+ if similar_questions:
332
+ example_content = "\n\n".join([
333
+ f"Similar Q&A {i+1}: {doc.page_content[:400]}..."
334
+ for i, doc in enumerate(similar_questions)
335
+ ])
336
+ example_msg = HumanMessage(
337
+ content=f"Here are similar questions and answers from the knowledge base for reference:\n\n{example_content}",
338
+ )
339
+ return {"messages": [sys_msg] + state["messages"] + [example_msg]}
340
+
341
+ # 如果没有向量存储或搜索失败,返回原始消息
342
+ return {"messages": [sys_msg] + state["messages"]}
343
+
344
+ except Exception as e:
345
+ print(f"Retriever error: {e}")
346
+ return {"messages": [sys_msg] + state["messages"]}
347
 
348
  builder = StateGraph(MessagesState)
349
  builder.add_node("retriever", retriever)
 
362
 
# test
if __name__ == "__main__":
    # Exercise several kinds of questions.
    test_questions = [
        "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?",
        "What is the area of the green polygon?",  # exercises the knowledge base search
        "Calculate 25 times 17",  # exercises the math tools
    ]
    total = len(test_questions)  # keeps the progress line right if the list changes

    print("🚀 开始测试 Agent...")

    # Build the graph
    graph = build_graph(provider="groq")

    for i, question in enumerate(test_questions, 1):
        print(f"\n{'='*60}")
        print(f"测试 {i}/{total}: {question}")
        print(f"{'='*60}")

        try:
            messages = [HumanMessage(content=question)]
            result = graph.invoke({"messages": messages})

            print("\n📋 对话历史:")
            for m in result["messages"]:
                m.pretty_print()

        except Exception as e:
            # Keep going so one failing question does not hide the others.
            print(f"❌ 处理问题时出错: {e}")

        print(f"\n{'-'*60}")

    print("\n🎉 测试完成!")