kamorou commited on
Commit
c112284
·
verified ·
1 Parent(s): 733bfef

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +57 -57
agent.py CHANGED
@@ -8,13 +8,13 @@ from langchain_google_genai import ChatGoogleGenerativeAI
8
  from langchain_groq import ChatGroq
9
  from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
10
  from langchain_community.tools.tavily_search import TavilySearchResults
11
- from langchain_community.document_loaders import WikipediaLoader
12
- from langchain_community.document_loaders import ArxivLoader
13
- from langchain_community.vectorstores import SupabaseVectorStore
14
  from langchain_core.messages import SystemMessage, HumanMessage
15
  from langchain_core.tools import tool
16
- from langchain.tools.retriever import create_retriever_tool
17
- from supabase.client import Client, create_client
18
 
19
  load_dotenv()
20
 
@@ -69,19 +69,19 @@ def modulus(a: int, b: int) -> int:
69
  """
70
  return a % b
71
 
72
- @tool
73
- def wiki_search(query: str) -> str:
74
- """Search Wikipedia for a query and return maximum 2 results.
75
 
76
- Args:
77
- query: The search query."""
78
- search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
79
- formatted_search_docs = "\n\n---\n\n".join(
80
- [
81
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
82
- for doc in search_docs
83
- ])
84
- return {"wiki_results": formatted_search_docs}
85
 
86
  @tool
87
  def web_search(query: str) -> str:
@@ -97,19 +97,19 @@ def web_search(query: str) -> str:
97
  ])
98
  return {"web_results": formatted_search_docs}
99
 
100
- @tool
101
- def arvix_search(query: str) -> str:
102
- """Search Arxiv for a query and return maximum 3 result.
103
 
104
- Args:
105
- query: The search query."""
106
- search_docs = ArxivLoader(query=query, load_max_docs=3).load()
107
- formatted_search_docs = "\n\n---\n\n".join(
108
- [
109
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
110
- for doc in search_docs
111
- ])
112
- return {"arvix_results": formatted_search_docs}
113
 
114
 
115
 
@@ -120,22 +120,22 @@ with open("system_prompt.txt", "r", encoding="utf-8") as f:
120
  # System message
121
  sys_msg = SystemMessage(content=system_prompt)
122
 
123
- # build a retriever
124
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
125
- supabase: Client = create_client(
126
- os.environ.get("SUPABASE_URL"),
127
- os.environ.get("SUPABASE_SERVICE_KEY"))
128
- vector_store = SupabaseVectorStore(
129
- client=supabase,
130
- embedding= embeddings,
131
- table_name="documents",
132
- query_name="match_documents_langchain",
133
- )
134
- create_retriever_tool = create_retriever_tool(
135
- retriever=vector_store.as_retriever(),
136
- name="Question Search",
137
- description="A tool to retrieve similar questions from a vector store.",
138
- )
139
 
140
 
141
 
@@ -145,9 +145,9 @@ tools = [
145
  subtract,
146
  divide,
147
  modulus,
148
- wiki_search,
149
  web_search,
150
- arvix_search,
151
  ]
152
 
153
  # Build graph function
@@ -178,20 +178,20 @@ def build_graph(provider: str = "groq"):
178
  """Assistant node"""
179
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
180
 
181
- def retriever(state: MessagesState):
182
- """Retriever node"""
183
- similar_question = vector_store.similarity_search(state["messages"][0].content)
184
- example_msg = HumanMessage(
185
- content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
186
- )
187
- return {"messages": [sys_msg] + state["messages"] + [example_msg]}
188
 
189
  builder = StateGraph(MessagesState)
190
- builder.add_node("retriever", retriever)
191
  builder.add_node("assistant", assistant)
192
  builder.add_node("tools", ToolNode(tools))
193
- builder.add_edge(START, "retriever")
194
- builder.add_edge("retriever", "assistant")
195
  builder.add_conditional_edges(
196
  "assistant",
197
  tools_condition,
 
8
  from langchain_groq import ChatGroq
9
  from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
10
  from langchain_community.tools.tavily_search import TavilySearchResults
11
+ # from langchain_community.document_loaders import WikipediaLoader
12
+ # from langchain_community.document_loaders import ArxivLoader
13
+ # from langchain_community.vectorstores import SupabaseVectorStore
14
  from langchain_core.messages import SystemMessage, HumanMessage
15
  from langchain_core.tools import tool
16
+ # from langchain.tools.retriever import create_retriever_tool
17
+ # from supabase.client import Client, create_client
18
 
19
  load_dotenv()
20
 
 
69
  """
70
  return a % b
71
 
72
+ # @tool
73
+ # def wiki_search(query: str) -> str:
74
+ # """Search Wikipedia for a query and return maximum 2 results.
75
 
76
+ # Args:
77
+ # query: The search query."""
78
+ # search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
79
+ # formatted_search_docs = "\n\n---\n\n".join(
80
+ # [
81
+ # f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
82
+ # for doc in search_docs
83
+ # ])
84
+ # return {"wiki_results": formatted_search_docs}
85
 
86
  @tool
87
  def web_search(query: str) -> str:
 
97
  ])
98
  return {"web_results": formatted_search_docs}
99
 
100
+ # @tool
101
+ # def arvix_search(query: str) -> str:
102
+ # """Search Arxiv for a query and return maximum 3 result.
103
 
104
+ # Args:
105
+ # query: The search query."""
106
+ # search_docs = ArxivLoader(query=query, load_max_docs=3).load()
107
+ # formatted_search_docs = "\n\n---\n\n".join(
108
+ # [
109
+ # f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
110
+ # for doc in search_docs
111
+ # ])
112
+ # return {"arvix_results": formatted_search_docs}
113
 
114
 
115
 
 
120
  # System message
121
  sys_msg = SystemMessage(content=system_prompt)
122
 
123
+ # # build a retriever
124
+ # embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
125
+ # supabase: Client = create_client(
126
+ # os.environ.get("SUPABASE_URL"),
127
+ # os.environ.get("SUPABASE_SERVICE_KEY"))
128
+ # vector_store = SupabaseVectorStore(
129
+ # client=supabase,
130
+ # embedding= embeddings,
131
+ # table_name="documents",
132
+ # query_name="match_documents_langchain",
133
+ # )
134
+ # create_retriever_tool = create_retriever_tool(
135
+ # retriever=vector_store.as_retriever(),
136
+ # name="Question Search",
137
+ # description="A tool to retrieve similar questions from a vector store.",
138
+ # )
139
 
140
 
141
 
 
145
  subtract,
146
  divide,
147
  modulus,
148
+ # wiki_search,
149
  web_search,
150
+ # arvix_search,
151
  ]
152
 
153
  # Build graph function
 
178
  """Assistant node"""
179
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
180
 
181
+ # def retriever(state: MessagesState):
182
+ # """Retriever node"""
183
+ # similar_question = vector_store.similarity_search(state["messages"][0].content)
184
+ # example_msg = HumanMessage(
185
+ # content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
186
+ # )
187
+ # return {"messages": [sys_msg] + state["messages"] + [example_msg]}
188
 
189
  builder = StateGraph(MessagesState)
190
+ # builder.add_node("retriever", retriever)
191
  builder.add_node("assistant", assistant)
192
  builder.add_node("tools", ToolNode(tools))
193
+ builder.add_edge(START, "assistant")
194
+ # builder.add_edge("retriever", "assistant")
195
  builder.add_conditional_edges(
196
  "assistant",
197
  tools_condition,