prozorov commited on
Commit
55db96e
·
verified ·
1 Parent(s): 010c895

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +12 -4
agent.py CHANGED
@@ -70,14 +70,21 @@ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-b
70
  supabase: Client = create_client(
71
  os.environ.get("SUPABASE_URL"),
72
  os.environ.get("SUPABASE_SERVICE_KEY"))
 
73
  vector_store = SupabaseVectorStore(
74
  client=supabase,
75
- embedding= embeddings,
76
  table_name="documents",
77
  query_name="match_documents_langchain",
 
 
78
  )
79
- create_retriever_tool = create_retriever_tool(
80
- retriever=vector_store.as_retriever(),
 
 
 
 
81
  name="Question Search",
82
  description="A tool to retrieve similar questions from a vector store.",
83
  )
@@ -86,6 +93,7 @@ tools = [
86
  wiki_search,
87
  web_search,
88
  arvix_search,
 
89
  ]
90
 
91
  def build_graph():
@@ -109,7 +117,7 @@ def build_graph():
109
  return {"messages": [sys_msg] + state["messages"] + [example_msg]}
110
 
111
  builder = StateGraph(MessagesState)
112
- #builder.add_node("retriever", retriever)
113
  builder.add_node("assistant", assistant)
114
  builder.add_node("tools", ToolNode(tools))
115
  builder.add_edge(START, "retriever")
 
70
  supabase: Client = create_client(
71
  os.environ.get("SUPABASE_URL"),
72
  os.environ.get("SUPABASE_SERVICE_KEY"))
73
+
74
  vector_store = SupabaseVectorStore(
75
  client=supabase,
76
+ embedding=embeddings,
77
  table_name="documents",
78
  query_name="match_documents_langchain",
79
+ text_key="content",
80
+ embedding_key="embedding"
81
  )
82
+
83
+ retriever_tool = create_retriever_tool(
84
+ retriever=vector_store.as_retriever(
85
+ search_type="similarity",
86
+ search_kwargs={"k": 5}
87
+ ),
88
  name="Question Search",
89
  description="A tool to retrieve similar questions from a vector store.",
90
  )
 
93
  wiki_search,
94
  web_search,
95
  arvix_search,
96
+ retriever_tool,
97
  ]
98
 
99
  def build_graph():
 
117
  return {"messages": [sys_msg] + state["messages"] + [example_msg]}
118
 
119
  builder = StateGraph(MessagesState)
120
+ builder.add_node("retriever", retriever)
121
  builder.add_node("assistant", assistant)
122
  builder.add_node("tools", ToolNode(tools))
123
  builder.add_edge(START, "retriever")