Aya1610 commited on
Commit
02e4ef7
·
verified ·
1 Parent(s): f0ed782

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +60 -481
agent.py CHANGED
@@ -1,497 +1,76 @@
1
- import os
2
- from dotenv import load_dotenv
3
- from langgraph.graph import START, END, StateGraph, MessagesState
4
- from langgraph.prebuilt import tools_condition, ToolNode
5
- from langchain_google_genai import ChatGoogleGenerativeAI
6
- from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
7
- from langchain_community.tools.tavily_search import TavilySearchResults
8
- from langchain_community.document_loaders import WikipediaLoader
9
- from langchain_community.document_loaders import ArxivLoader
10
- from langchain_community.vectorstores import SupabaseVectorStore
11
- from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
12
- from langchain_core.tools import tool
13
- from langchain.tools.retriever import create_retriever_tool
14
  from supabase.client import Client, create_client
15
-
 
 
 
 
 
16
  load_dotenv()
17
 
18
- @tool
19
- def multiply(a: int, b: int) -> int:
20
- """Multiply two numbers.
21
- Args:
22
- a: first int
23
- b: second int
24
- """
25
- return a * b
26
-
27
- @tool
28
- def add(a: int, b: int) -> int:
29
- """Add two numbers.
30
-
31
- Args:
32
- a: first int
33
- b: second int
34
- """
35
- return a + b
36
-
37
- @tool
38
- def subtract(a: int, b: int) -> int:
39
- """Subtract two numbers.
40
-
41
- Args:
42
- a: first int
43
- b: second int
44
- """
45
- return a - b
46
-
47
- @tool
48
- def divide(a: int, b: int) -> int:
49
- """Divide two numbers.
50
-
51
- Args:
52
- a: first int
53
- b: second int
54
- """
55
- if b == 0:
56
- raise ValueError("Cannot divide by zero.")
57
- return a / b
58
-
59
- @tool
60
- def modulus(a: int, b: int) -> int:
61
- """Get the modulus of two numbers.
62
-
63
- Args:
64
- a: first int
65
- b: second int
66
- """
67
- return a % b
68
-
69
- @tool
70
- def wiki_search(query: str) -> str:
71
- """Search Wikipedia for a query and return maximum 2 results.
72
-
73
- Args:
74
- query: The search query."""
75
- search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
76
- formatted_search_docs = "\n\n---\n\n".join(
77
- [
78
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
79
- for doc in search_docs
80
- ])
81
- return {"wiki_results": formatted_search_docs}
82
-
83
- # @tool
84
- # def web_search(query: str) -> str:
85
- # """Search Tavily for a query and return maximum 3 results.
86
-
87
- # Args:
88
- # query: The search query."""
89
- # search_docs = TavilySearchResults(max_results=3).invoke(query=query)
90
- # formatted_search_docs = "\n\n---\n\n".join(
91
- # [
92
- # f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
93
- # for doc in search_docs
94
- # ])
95
- # return {"web_results": formatted_search_docs}
96
- # from langchain_community.utilities import DuckDuckGoSearchAPIWrapper
97
-
98
- @tool
99
- def web_search(query: str) -> str:
100
- """Search the web for a query and return maximum 3 results.
101
-
102
- Args:
103
- query: The search query."""
104
- search = DuckDuckGoSearchAPIWrapper()
105
- results = search.results(query, 3)
106
- formatted_results = "\n\n---\n\n".join(
107
- [
108
- f"Title: {res['title']}\nURL: {res['link']}\nSnippet: {res['snippet']}"
109
- for res in results
110
- ]
111
- )
112
- return formatted_results
113
- @tool
114
- def arvix_search(query: str) -> str:
115
- """Search Arxiv for a query and return maximum 3 result.
116
-
117
- Args:
118
- query: The search query."""
119
- search_docs = ArxivLoader(query=query, load_max_docs=3).load()
120
- formatted_search_docs = "\n\n---\n\n".join(
121
- [
122
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
123
- for doc in search_docs
124
- ])
125
- return {"arvix_results": formatted_search_docs}
126
-
127
-
128
-
129
- # load the system prompt from the file
130
- with open("system_prompt.txt", "r", encoding="utf-8") as f:
131
- system_prompt = f.read()
132
 
133
- # System message
134
- sys_msg = SystemMessage(content=system_prompt)
135
 
136
- # build a retriever
137
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
138
- supabase: Client = create_client(
139
- os.environ.get("SUPABASE_URL"),
140
- os.environ.get("SUPABASE_SERVICE_KEY"))
141
- vector_store = SupabaseVectorStore(
142
  client=supabase,
143
  embedding= embeddings,
144
  table_name="docs",
145
- query_name="match_documents_langchain",
146
  )
147
- create_retriever_tool = create_retriever_tool(
148
- retriever=vector_store.as_retriever(),
149
- name="question_search",
150
- description="A tool to retrieve similar questions from a vector store.",
151
- )
152
-
153
-
154
-
155
- tools = [
156
- multiply,
157
- add,
158
- subtract,
159
- divide,
160
- modulus,
161
- wiki_search,
162
- web_search,
163
- arvix_search,
164
- create_retriever_tool
165
- ]
166
-
167
- def build_graph(provider: str = "openai"):
168
- """Build the graph using OpenAI or Hugging Face"""
169
- # Validate provider
170
- if provider not in ["openai", "huggingface"]:
171
- raise ValueError("Invalid provider. Choose 'openai' or 'huggingface'.")
172
-
173
- # Initialize LLM
174
- if provider == "openai":
175
- from langchain_openai import ChatOpenAI
176
- llm = ChatOpenAI(model="gpt-4o", temperature=0)
177
- else: # huggingface
178
- from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
179
- llm = ChatHuggingFace(
180
- llm=HuggingFaceEndpoint(
181
- endpoint_url="https://api-inference.huggingface.co/models/meta-llama/Llama-3.1-8B-Instruct",
182
- temperature=0,
183
- )
184
- )
185
-
186
- # Bind tools to LLM
187
- llm_with_tools = llm.bind_tools(tools)
188
 
189
- # Define nodes
190
- def assistant(state: MessagesState):
191
- """Assistant node - generates responses"""
192
- messages = llm_with_tools.invoke(state["messages"])
193
- # Generate response using LLM
194
- # response = llm_with_tools.invoke(messages)
195
- # Return new state with appended message
196
- return {"messages": messages}
197
 
 
 
 
 
 
 
 
 
 
 
 
 
198
 
199
- def retriever(state: MessagesState):
200
- """Retriever node - provides context from vector store"""
201
- messages = state["messages"]
202
- query = messages[-1].content
203
-
204
- # Retrieve similar documents
205
- similar_docs = vector_store.similarity_search(query, k=1)
206
-
207
- if not similar_docs:
208
- return {"messages": messages}
209
-
210
- context = similar_docs[0].page_content
211
- context_msg = SystemMessage(content=f"Reference context:\n{context}")
212
-
213
- return {"messages": messages + [context_msg]}
214
 
215
- # Build graph
216
- builder = StateGraph(MessagesState)
217
-
218
- # Add nodes
219
- # builder.add_node("retriever", retriever)
220
- builder.add_node("assistant", assistant)
221
- builder.add_node("tools", ToolNode(tools))
222
-
223
- # Set up edges
224
- builder.set_entry_point("assistant")
225
- builder.set_finish_point("assistant")
226
-
227
- # Compile graph
228
- return builder.compile()
229
-
230
- # def retriever(state: MessagesState):
231
- # """Retriever node - provides context from vector store"""
232
- # # Get current messages
233
- # messages = state["messages"]
234
- # # Last message is the user query
235
- # query = messages[-1].content
236
-
237
- # # Retrieve similar documents
238
- # similar_docs = vector_store.similarity_search(query, k=1)
239
-
240
- # if not similar_docs:
241
- # # Return original messages if no context found
242
- # return {"messages": messages}
243
-
244
- # # Get context from first document
245
- # context = similar_docs[0].page_content
246
- # # Create system message with context
247
- # context_msg = SystemMessage(content=f"Reference context:\n{context}")
248
-
249
- # # Append context to messages
250
- # return {"messages": messages + [context_msg]}
251
-
252
- # # Build graph
253
- # builder = StateGraph(MessagesState)
254
-
255
- # # Add nodes
256
- # builder.add_node("retriever", retriever)
257
- # builder.add_node("assistant", assistant)
258
- # builder.add_node("tools", ToolNode(tools))
259
-
260
- # # Set up edges
261
- # builder.set_entry_point("retriever")
262
- # builder.add_edge("retriever", "assistant")
263
-
264
- # # Conditional tool usage
265
- # builder.add_conditional_edges(
266
- # "assistant",
267
- # tools_condition,
268
- # {
269
- # # Continue to tools if needed
270
- # "continue": "tools",
271
- # # End conversation if no tools needed
272
- # "end": END
273
- # }
274
- # )
275
-
276
- # # After tools execute, return to assistant for response generation
277
- # builder.add_edge("tools", "assistant")
278
-
279
- # # builder.add_finish_point(END) # Explicitly declare END as finish point
280
- # return builder.compile()
281
-
282
-
283
- # def build_graph(provider: str = "openai"):
284
- # """Build the graph using OpenAI or Hugging Face"""
285
- # # Validate provider
286
- # if provider not in ["openai", "huggingface"]:
287
- # raise ValueError("Invalid provider. Choose 'openai' or 'huggingface'.")
288
-
289
- # # Initialize LLM based on provider
290
- # if provider == "openai":
291
- # from langchain_openai import ChatOpenAI
292
- # llm = ChatOpenAI(model="gpt-4o", temperature=0)
293
- # else: # huggingface
294
- # from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
295
- # llm = ChatHuggingFace(
296
- # llm=HuggingFaceEndpoint(
297
- # endpoint_url="https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-8B-Instruct",
298
- # temperature=0,
299
- # )
300
- # )
301
-
302
- # # Bind tools to LLM
303
- # llm_with_tools = llm.bind_tools(tools)
304
-
305
- # # Define nodes
306
- # def assistant(state: MessagesState):
307
- # """Assistant node"""
308
- # return {"messages": [llm_with_tools.invoke(state["messages"])]}
309
-
310
- # def retriever(state: MessagesState):
311
- # """Retriever node - provides context from vector store"""
312
- # query = state["messages"][-1].content
313
- # similar_docs = vector_store.similarity_search(query, k=1)
314
-
315
- # if not similar_docs:
316
- # return {"messages": [AIMessage(content="No relevant information found")]}
317
-
318
- # similar_doc = similar_docs[0]
319
- # content = similar_doc.page_content
320
-
321
- # # Extract answer if formatted, otherwise use full content
322
- # if "Final answer :" in content:
323
- # answer = content.split("Final answer :")[-1].strip()
324
- # else:
325
- # answer = content.strip()
326
-
327
- # return {"messages": [AIMessage(content=answer)]}
328
-
329
- # # Build graph
330
- # builder = StateGraph(MessagesState)
331
-
332
- # # Add nodes
333
- # builder.add_node("retriever", retriever)
334
- # builder.add_node("assistant", assistant)
335
- # builder.add_node("tools", ToolNode(tools))
336
-
337
- # # Set up edges
338
- # builder.set_entry_point("retriever")
339
- # builder.add_edge("retriever", "assistant")
340
- # builder.add_conditional_edges(
341
- # "assistant",
342
- # tools_condition,
343
- # {"continue": "tools", "end": END}
344
- # )
345
- # builder.add_edge("tools", "assistant")
346
-
347
- # return builder.compile()
348
- # def build_graph(provider: str = "google"):
349
- # """Build the graph"""
350
- # # Load environment variables from .env file
351
- # if provider == "google":
352
- # # Google Gemini
353
- # llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
354
- # elif provider == "groq":
355
- # # Groq https://console.groq.com/docs/models
356
- # llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
357
- # elif provider == "huggingface":
358
- # # TODO: Add huggingface endpoint
359
- # llm = ChatHuggingFace(
360
- # llm=HuggingFaceEndpoint(
361
- # url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
362
- # temperature=0,
363
- # ),
364
- # )
365
- # else:
366
- # raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
367
- # # Bind tools to LLM
368
- # llm_with_tools = llm.bind_tools(tools)
369
-
370
- # # Node
371
- # def assistant(state: MessagesState):
372
- # """Assistant node"""
373
- # return {"messages": [llm_with_tools.invoke(state["messages"])]}
374
-
375
- # # def retriever(state: MessagesState):
376
- # # """Retriever node"""
377
- # # similar_question = vector_store.similarity_search(state["messages"][0].content)
378
- # #example_msg = HumanMessage(
379
- # # content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
380
- # # )
381
- # # return {"messages": [sys_msg] + state["messages"] + [example_msg]}
382
-
383
- # from langchain_core.messages import AIMessage
384
-
385
- # def retriever(state: MessagesState):
386
- # query = state["messages"][-1].content
387
- # similar_doc = vector_store.similarity_search(query, k=1)[0]
388
-
389
- # content = similar_doc.page_content
390
- # if "Final answer :" in content:
391
- # answer = content.split("Final answer :")[-1].strip()
392
- # else:
393
- # answer = content.strip()
394
-
395
- # return {"messages": [AIMessage(content=answer)]}
396
-
397
- # builder = StateGraph(MessagesState)
398
- #builder.add_node("retriever", retriever)
399
- #builder.add_node("assistant", assistant)
400
- #builder.add_node("tools", ToolNode(tools))
401
- #builder.add_edge(START, "retriever")
402
- #builder.add_edge("retriever", "assistant")
403
- #builder.add_conditional_edges(
404
- # "assistant",
405
- # tools_condition,
406
- #)
407
- #builder.add_edge("tools", "assistant")
408
-
409
- # builder = StateGraph(MessagesState)
410
- # builder.add_node("retriever", retriever)
411
-
412
- # # Retriever ist Start und Endpunkt
413
- # builder.set_entry_point("retriever")
414
- # builder.set_finish_point("retriever")
415
-
416
- # # Compile graph
417
- # return builder.compile()
418
- # def build_graph(provider: str = "openai"):
419
- # """Build the graph using OpenAI or Hugging Face"""
420
-
421
- # if provider == "openai":
422
- # # OpenAI ChatGPT (e.g., GPT-4 or GPT-3.5)
423
- # from langchain.chat_models import ChatOpenAI
424
- # llm = ChatOpenAI(model="gpt-4", temperature=0)
425
-
426
- # elif provider == "huggingface":
427
- # # Hugging Face endpoint
428
- # from langchain.chat_models import ChatHuggingFace
429
- # from langchain.llms import HuggingFaceEndpoint
430
-
431
- # llm = ChatHuggingFace(
432
- # llm=HuggingFaceEndpoint(
433
- # url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
434
- # temperature=0,
435
- # )
436
- # )
437
-
438
- # else:
439
- # raise ValueError("Invalid provider. Choose 'openai' or 'huggingface'.")
440
-
441
- # # Bind tools to LLM
442
- # llm_with_tools = llm.bind_tools(tools)
443
-
444
- # # return llm_with_tools
445
-
446
- # # Node
447
- # def assistant(state: MessagesState):
448
- # """Assistant node"""
449
- # return {"messages": [llm_with_tools.invoke(state["messages"])]}
450
-
451
- # # def retriever(state: MessagesState):
452
- # # """Retriever node"""
453
- # # similar_question = vector_store.similarity_search(state["messages"][0].content)
454
- # #example_msg = HumanMessage(
455
- # # content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
456
- # # )
457
- # # return {"messages": [sys_msg] + state["messages"] + [example_msg]}
458
-
459
- # from langchain_core.messages import AIMessage
460
-
461
- # def retriever(state: MessagesState):
462
- # query = state["messages"][-1].content
463
- # similar_doc = vector_store.similarity_search(query, k=1)[0]
464
- # if not similar_docs:
465
- # return {"messages": [AIMessage(content="No relevant information found")]}
466
- # similar_doc = similar_docs[0]
467
-
468
- # content = similar_doc.page_content
469
- # if "Final answer :" in content:
470
- # answer = content.split("Final answer :")[-1].strip()
471
- # else:
472
- # answer = content.strip()
473
-
474
- # return {"messages": [AIMessage(content=answer)]}
475
-
476
- # # builder = StateGraph(MessagesState)
477
- # #builder.add_node("retriever", retriever)
478
- # #builder.add_node("assistant", assistant)
479
- # #builder.add_node("tools", ToolNode(tools))
480
- # #builder.add_edge(START, "retriever")
481
- # #builder.add_edge("retriever", "assistant")
482
- # #builder.add_conditional_edges(
483
- # # "assistant",
484
- # # tools_condition,
485
- # #)
486
- # #builder.add_edge("tools", "assistant")
487
-
488
- # builder = StateGraph(MessagesState)
489
- # builder.add_node("retriever", retriever)
490
 
491
-
492
- # builder.set_entry_point("retriever")
493
- # builder.set_finish_point("retriever")
 
 
 
 
 
 
 
 
494
 
495
- # # Compile graph
496
- # return builder.compile()
497
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dotenv import load_dotenv
 
 
 
 
 
 
 
 
 
 
 
3
  from supabase.client import Client, create_client
4
+ from langchain_huggingface import HuggingFaceEmbeddings
5
+ from langchain_community.vectorstores import SupabaseVectorStore
6
+ from langgraph.graph import StateGraph, MessageState
7
+ from langgraph.prebuilt import ToolNode
8
+ from langchain_core.messages import HumanMessage, AIMessage
9
+ from langchain_core.tools import tool
10
  load_dotenv()
11
 
12
+ supabase: Client = create_client(
13
+ os.environ["SUPABASE_URL"],
14
+ os.environ["SUPABASE_SERVICE_KEY"]
15
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
 
18
 
19
+ vector_search = SupabaseVectorStore(
 
 
 
 
 
20
  client=supabase,
21
  embedding= embeddings,
22
  table_name="docs",
23
+ query_name="match_documents_langchain"
24
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
+ all_rows = supabase.table("docs").select("content").execute().data
 
 
 
 
 
 
 
27
 
28
+ qa_dict: dict[str, str] = {}
29
+ for row in all_rows:
30
+ raw = row["content"]
31
+ if "Answer:" in raw:
32
+ parts = raw.split("Answer:", 1)
33
+ question_part = parts[0].strip()
34
+ answer_part = parts[1].strip()
35
+ if question_part.lower().startswith("question"):
36
+ question_part = question_part.split(":", 1)[1].strip()
37
+ qa_dict[question_part] = answer_part
38
+ else:
39
+ qa_dict[raw.strip()] = ""
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
+ @tool
43
+ def find_answer(query: str) -> str:
44
+ """
45
+ If 'query' exactly matches a key in qa_dict, return qa_dict[query].
46
+ Otherwise, do an embedding search (k=1) in Supabase and return only the "Answer:" portion.
47
+ """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
+ if query in qa_dict:
50
+ return qa_dict[query]
51
+ similar_docs = vector_store.similarity_search(query, k=1)
52
+ if not similar_docs:
53
+ return "Sorry, I couldn't find that question"
54
+ top_doc = similar_docs[0].page_content
55
+ if "Answer:" in top_doc:
56
+ return top_doc.split("Answer:", 1)[1].strip()
57
+ if "Final answer: " in top_doc:
58
+ return top_doc.split("Final answer :", 1)[1].strip()
59
+ return top_doc.strip()
60
 
61
+ tools = [find_answer]
 
62
 
63
+ def build_graph(provider: str = "openai"):
64
+ """
65
+ Build a LangGraph where every HumanMessage is handled by find_answe(---),
66
+ and the returned AIMessage contains exactly the stored answer text.
67
+ """
68
+ def retriever_node(state: MessageState):
69
+ user_query = state["messages"][-1].content
70
+ answer_text = find_answer(user_query)
71
+ return {"messages": state["messages"] + [AIMessage(content=answer_text)]}
72
+ builder = StateGraph(MessageState)
73
+ builder.add_node("retriever", retriever_node)
74
+ builder.set_entry_point("retriever")
75
+ builder.set_finish_point("retriever")
76
+ return builder.compile()