ORromu commited on
Commit
f91dabb
·
verified ·
1 Parent(s): b69765a

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +34 -5
agent.py CHANGED
@@ -18,10 +18,11 @@ from langgraph.graph.message import add_messages
18
  from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, AIMessage
19
  from langgraph.graph import StateGraph, START, END, MessagesState
20
  from langgraph.prebuilt import ToolNode, tools_condition
21
- from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
22
  from langchain_google_genai import ChatGoogleGenerativeAI
23
  from langchain_groq import ChatGroq
24
-
 
25
 
26
  HUGGINGFACEHUB_API_TOKEN = getenv("HUGGINGFACEHUB_API_TOKEN")
27
 
@@ -30,6 +31,25 @@ HUGGINGFACEHUB_API_TOKEN = getenv("HUGGINGFACEHUB_API_TOKEN")
30
  with open("prompt.txt", "r", encoding="utf-8") as f:
31
  system_prompt = f.read()
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  # System message
34
  sys_msg = SystemMessage(content=system_prompt)
35
 
@@ -62,12 +82,21 @@ def simple_graph():
62
 
63
  def retriever(state: MessagesState):
64
  """Retriever node"""
65
- # I don't want to use a Retriever / Using similar questions.
66
- return {"messages": [sys_msg] + state["messages"]}
 
 
 
 
 
 
 
 
 
67
 
68
  # Build graph / nodes
69
  builder = StateGraph(MessagesState)
70
- builder.add_node("retriever", retriever) # Assistant
71
  builder.add_node("assistant", assistant) # Assistant
72
  builder.add_node("tools", ToolNode(tools)) # Tools
73
 
 
18
  from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, AIMessage
19
  from langgraph.graph import StateGraph, START, END, MessagesState
20
  from langgraph.prebuilt import ToolNode, tools_condition
21
+ from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace, HuggingFaceEmbeddings
22
  from langchain_google_genai import ChatGoogleGenerativeAI
23
  from langchain_groq import ChatGroq
24
+ from langchain_community.vectorstores import SupabaseVectorStore
25
+ from supabase.client import Client, create_client
26
 
27
  HUGGINGFACEHUB_API_TOKEN = getenv("HUGGINGFACEHUB_API_TOKEN")
28
 
 
31
  with open("prompt.txt", "r", encoding="utf-8") as f:
32
  system_prompt = f.read()
33
 
34
+ # build a retriever
35
+ embeddings = HuggingFaceEmbeddings(
36
+ model_name="sentence-transformers/all-mpnet-base-v2"
37
+ ) # dim=768
38
+ supabase: Client = create_client(
39
+ os.environ.get("SUPABASE_URL"), os.environ.get("SUPABASE_SERVICE_ROLE_KEY")
40
+ )
41
+ vector_store = SupabaseVectorStore(
42
+ client=supabase,
43
+ embedding=embeddings,
44
+ table_name="documents2",
45
+ query_name="match_documents_2",
46
+ )
47
+ create_retriever_tool = create_retriever_tool(
48
+ retriever=vector_store.as_retriever(),
49
+ name="Question Search",
50
+ description="A tool to retrieve similar questions from a vector store.",
51
+ )
52
+
53
  # System message
54
  sys_msg = SystemMessage(content=system_prompt)
55
 
 
82
 
83
  def retriever(state: MessagesState):
84
  """Retriever node"""
85
+ similar_question = vector_store.similarity_search(state["messages"][0].content)
86
+
87
+ if similar_question: # Check if the list is not empty
88
+ example_msg = HumanMessage(
89
+ content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
90
+ )
91
+ return {"messages": [sys_msg] + state["messages"] + [example_msg]}
92
+ else:
93
+ # Handle the case when no similar questions are found
94
+ return {"messages": [sys_msg] + state["messages"]}
95
+
96
 
97
  # Build graph / nodes
98
  builder = StateGraph(MessagesState)
99
+ builder.add_node("retriever", retriever) # Retriever
100
  builder.add_node("assistant", assistant) # Assistant
101
  builder.add_node("tools", ToolNode(tools)) # Tools
102