Spaces:
Sleeping
Sleeping
Update agent.py
Browse files
agent.py
CHANGED
|
@@ -18,10 +18,11 @@ from langgraph.graph.message import add_messages
|
|
| 18 |
from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, AIMessage
|
| 19 |
from langgraph.graph import StateGraph, START, END, MessagesState
|
| 20 |
from langgraph.prebuilt import ToolNode, tools_condition
|
| 21 |
-
from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
|
| 22 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 23 |
from langchain_groq import ChatGroq
|
| 24 |
-
|
|
|
|
| 25 |
|
| 26 |
HUGGINGFACEHUB_API_TOKEN = getenv("HUGGINGFACEHUB_API_TOKEN")
|
| 27 |
|
|
@@ -30,6 +31,25 @@ HUGGINGFACEHUB_API_TOKEN = getenv("HUGGINGFACEHUB_API_TOKEN")
|
|
| 30 |
with open("prompt.txt", "r", encoding="utf-8") as f:
|
| 31 |
system_prompt = f.read()
|
| 32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
# System message
|
| 34 |
sys_msg = SystemMessage(content=system_prompt)
|
| 35 |
|
|
@@ -62,12 +82,21 @@ def simple_graph():
|
|
| 62 |
|
| 63 |
def retriever(state: MessagesState):
|
| 64 |
"""Retriever node"""
|
| 65 |
-
|
| 66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
|
| 68 |
# Build graph / nodes
|
| 69 |
builder = StateGraph(MessagesState)
|
| 70 |
-
builder.add_node("retriever", retriever) #
|
| 71 |
builder.add_node("assistant", assistant) # Assistant
|
| 72 |
builder.add_node("tools", ToolNode(tools)) # Tools
|
| 73 |
|
|
|
|
| 18 |
from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, AIMessage
|
| 19 |
from langgraph.graph import StateGraph, START, END, MessagesState
|
| 20 |
from langgraph.prebuilt import ToolNode, tools_condition
|
| 21 |
+
from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace, HuggingFaceEmbeddings
|
| 22 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 23 |
from langchain_groq import ChatGroq
|
| 24 |
+
from langchain_community.vectorstores import SupabaseVectorStore
|
| 25 |
+
from supabase.client import Client, create_client
|
| 26 |
|
| 27 |
HUGGINGFACEHUB_API_TOKEN = getenv("HUGGINGFACEHUB_API_TOKEN")
|
| 28 |
|
|
|
|
| 31 |
with open("prompt.txt", "r", encoding="utf-8") as f:
|
| 32 |
system_prompt = f.read()
|
| 33 |
|
| 34 |
+
# build a retriever
|
| 35 |
+
embeddings = HuggingFaceEmbeddings(
|
| 36 |
+
model_name="sentence-transformers/all-mpnet-base-v2"
|
| 37 |
+
) # dim=768
|
| 38 |
+
supabase: Client = create_client(
|
| 39 |
+
os.environ.get("SUPABASE_URL"), os.environ.get("SUPABASE_SERVICE_ROLE_KEY")
|
| 40 |
+
)
|
| 41 |
+
vector_store = SupabaseVectorStore(
|
| 42 |
+
client=supabase,
|
| 43 |
+
embedding=embeddings,
|
| 44 |
+
table_name="documents2",
|
| 45 |
+
query_name="match_documents_2",
|
| 46 |
+
)
|
| 47 |
+
create_retriever_tool = create_retriever_tool(
|
| 48 |
+
retriever=vector_store.as_retriever(),
|
| 49 |
+
name="Question Search",
|
| 50 |
+
description="A tool to retrieve similar questions from a vector store.",
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
# System message
|
| 54 |
sys_msg = SystemMessage(content=system_prompt)
|
| 55 |
|
|
|
|
| 82 |
|
| 83 |
def retriever(state: MessagesState):
|
| 84 |
"""Retriever node"""
|
| 85 |
+
similar_question = vector_store.similarity_search(state["messages"][0].content)
|
| 86 |
+
|
| 87 |
+
if similar_question: # Check if the list is not empty
|
| 88 |
+
example_msg = HumanMessage(
|
| 89 |
+
content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
|
| 90 |
+
)
|
| 91 |
+
return {"messages": [sys_msg] + state["messages"] + [example_msg]}
|
| 92 |
+
else:
|
| 93 |
+
# Handle the case when no similar questions are found
|
| 94 |
+
return {"messages": [sys_msg] + state["messages"]}
|
| 95 |
+
|
| 96 |
|
| 97 |
# Build graph / nodes
|
| 98 |
builder = StateGraph(MessagesState)
|
| 99 |
+
builder.add_node("retriever", retriever) # Retriever
|
| 100 |
builder.add_node("assistant", assistant) # Assistant
|
| 101 |
builder.add_node("tools", ToolNode(tools)) # Tools
|
| 102 |
|