tsrrus committed on
Commit
9a4d5ce
·
verified ·
1 Parent(s): 35d1d29

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +112 -45
agent.py CHANGED
@@ -1,22 +1,28 @@
1
  import os
 
 
2
  from dotenv import load_dotenv
3
  from langgraph.graph import START, StateGraph, MessagesState
4
  from langgraph.prebuilt import tools_condition
5
  from langgraph.prebuilt import ToolNode
6
  from langchain_google_genai import ChatGoogleGenerativeAI
7
  from langchain_groq import ChatGroq
8
- from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
 
 
 
 
9
  from langchain_community.tools.tavily_search import TavilySearchResults
10
  from langchain_community.document_loaders import WikipediaLoader
11
  from langchain_community.document_loaders import ArxivLoader
12
- from langchain_community.vectorstores import SupabaseVectorStore
13
- from langchain_core.messages import SystemMessage, HumanMessage
14
  from langchain_core.tools import tool
15
- from langchain.tools.retriever import create_retriever_tool
16
- from supabase.client import Client, create_client
17
 
18
  load_dotenv()
19
 
 
20
  @tool
21
  def multiply(a: int, b: int) -> int:
22
  """Multiply two numbers.
@@ -26,30 +32,33 @@ def multiply(a: int, b: int) -> int:
26
  """
27
  return a * b
28
 
 
29
  @tool
30
  def add(a: int, b: int) -> int:
31
  """Add two numbers.
32
-
33
  Args:
34
  a: first int
35
  b: second int
36
  """
37
  return a + b
38
 
 
39
  @tool
40
  def subtract(a: int, b: int) -> int:
41
  """Subtract two numbers.
42
-
43
  Args:
44
  a: first int
45
  b: second int
46
  """
47
  return a - b
48
 
 
49
  @tool
50
  def divide(a: int, b: int) -> int:
51
  """Divide two numbers.
52
-
53
  Args:
54
  a: first int
55
  b: second int
@@ -58,20 +67,22 @@ def divide(a: int, b: int) -> int:
58
  raise ValueError("Cannot divide by zero.")
59
  return a / b
60
 
 
61
  @tool
62
  def modulus(a: int, b: int) -> int:
63
  """Get the modulus of two numbers.
64
-
65
  Args:
66
  a: first int
67
  b: second int
68
  """
69
  return a % b
70
 
 
71
  @tool
72
  def wiki_search(query: str) -> str:
73
  """Search Wikipedia for a query and return maximum 2 results.
74
-
75
  Args:
76
  query: The search query."""
77
  search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
@@ -79,13 +90,15 @@ def wiki_search(query: str) -> str:
79
  [
80
  f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
81
  for doc in search_docs
82
- ])
 
83
  return {"wiki_results": formatted_search_docs}
84
 
 
85
  @tool
86
  def web_search(query: str) -> str:
87
  """Search Tavily for a query and return maximum 3 results.
88
-
89
  Args:
90
  query: The search query."""
91
  search_docs = TavilySearchResults(max_results=3).invoke(query=query)
@@ -93,13 +106,15 @@ def web_search(query: str) -> str:
93
  [
94
  f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
95
  for doc in search_docs
96
- ])
 
97
  return {"web_results": formatted_search_docs}
98
 
 
99
  @tool
100
  def arvix_search(query: str) -> str:
101
  """Search Arxiv for a query and return maximum 3 result.
102
-
103
  Args:
104
  query: The search query."""
105
  search_docs = ArxivLoader(query=query, load_max_docs=3).load()
@@ -107,10 +122,74 @@ def arvix_search(query: str) -> str:
107
  [
108
  f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
109
  for doc in search_docs
110
- ])
 
111
  return {"arvix_results": formatted_search_docs}
112
 
113
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
  # load the system prompt from the file
116
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
@@ -119,25 +198,6 @@ with open("system_prompt.txt", "r", encoding="utf-8") as f:
119
  # System message
120
  sys_msg = SystemMessage(content=system_prompt)
121
 
122
- # build a retriever
123
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
124
- supabase: Client = create_client(
125
- os.environ.get("SUPABASE_URL"),
126
- os.environ.get("SUPABASE_SERVICE_KEY"))
127
- vector_store = SupabaseVectorStore(
128
- client=supabase,
129
- embedding= embeddings,
130
- table_name="documents",
131
- query_name="match_documents_langchain",
132
- )
133
- create_retriever_tool = create_retriever_tool(
134
- retriever=vector_store.as_retriever(),
135
- name="Question Search",
136
- description="A tool to retrieve similar questions from a vector store.",
137
- )
138
-
139
-
140
-
141
  tools = [
142
  multiply,
143
  add,
@@ -149,6 +209,7 @@ tools = [
149
  arvix_search,
150
  ]
151
 
 
152
  # Build graph function
153
  def build_graph(provider: str = "groq"):
154
  """Build the graph"""
@@ -158,7 +219,9 @@ def build_graph(provider: str = "groq"):
158
  llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
159
  elif provider == "groq":
160
  # Groq https://console.groq.com/docs/models
161
- llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
 
 
162
  elif provider == "huggingface":
163
  # TODO: Add huggingface endpoint
164
  llm = ChatHuggingFace(
@@ -176,31 +239,35 @@ def build_graph(provider: str = "groq"):
176
  def assistant(state: MessagesState):
177
  """Assistant node"""
178
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
179
-
180
- from langchain_core.messages import AIMessage
181
 
182
  def retriever(state: MessagesState):
 
183
  query = state["messages"][-1].content
184
- similar_docs = vector_store.similarity_search(query, k=1)
185
-
186
  # Handle empty results
187
  if not similar_docs:
188
- return {"messages": [AIMessage(content="I don't have information about this topic in my knowledge base. Please try a different question.")]}
189
-
 
 
 
 
 
 
190
  similar_doc = similar_docs[0]
191
  content = similar_doc.page_content
192
-
193
  if "Final answer :" in content:
194
  answer = content.split("Final answer :")[-1].strip()
195
  else:
196
  answer = content.strip()
197
-
198
  # Ensure answer is not empty
199
  if not answer:
200
  answer = "I found related information but couldn't extract a clear answer. Please rephrase your question."
201
-
202
- return {"messages": [AIMessage(content=answer)]}
203
 
 
204
 
205
  builder = StateGraph(MessagesState)
206
  builder.add_node("retriever", retriever)
 
1
  import os
2
+ import pandas as pd
3
+ import numpy as np
4
  from dotenv import load_dotenv
5
  from langgraph.graph import START, StateGraph, MessagesState
6
  from langgraph.prebuilt import tools_condition
7
  from langgraph.prebuilt import ToolNode
8
  from langchain_google_genai import ChatGoogleGenerativeAI
9
  from langchain_groq import ChatGroq
10
+ from langchain_huggingface import (
11
+ ChatHuggingFace,
12
+ HuggingFaceEndpoint,
13
+ HuggingFaceEmbeddings,
14
+ )
15
  from langchain_community.tools.tavily_search import TavilySearchResults
16
  from langchain_community.document_loaders import WikipediaLoader
17
  from langchain_community.document_loaders import ArxivLoader
18
+ from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
 
19
  from langchain_core.tools import tool
20
+ from sklearn.metrics.pairwise import cosine_similarity
21
+ import ast
22
 
23
  load_dotenv()
24
 
25
+
26
  @tool
27
  def multiply(a: int, b: int) -> int:
28
  """Multiply two numbers.
 
32
  """
33
  return a * b
34
 
35
+
36
@tool
def add(a: int, b: int) -> int:
    """Add two numbers.

    Args:
        a: first int
        b: second int
    """
    # Plain integer addition, exposed to the agent as a tool.
    total = a + b
    return total
45
 
46
+
47
@tool
def subtract(a: int, b: int) -> int:
    """Subtract two numbers.

    Args:
        a: first int
        b: second int
    """
    # Plain integer subtraction (a minus b), exposed to the agent as a tool.
    difference = a - b
    return difference
56
 
57
+
58
  @tool
59
  def divide(a: int, b: int) -> int:
60
  """Divide two numbers.
61
+
62
  Args:
63
  a: first int
64
  b: second int
 
67
  raise ValueError("Cannot divide by zero.")
68
  return a / b
69
 
70
+
71
@tool
def modulus(a: int, b: int) -> int:
    """Get the modulus of two numbers.

    Args:
        a: first int
        b: second int
    """
    # Remainder of a divided by b, exposed to the agent as a tool.
    remainder = a % b
    return remainder
80
 
81
+
82
  @tool
83
  def wiki_search(query: str) -> str:
84
  """Search Wikipedia for a query and return maximum 2 results.
85
+
86
  Args:
87
  query: The search query."""
88
  search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
 
90
  [
91
  f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
92
  for doc in search_docs
93
+ ]
94
+ )
95
  return {"wiki_results": formatted_search_docs}
96
 
97
+
98
  @tool
99
  def web_search(query: str) -> str:
100
  """Search Tavily for a query and return maximum 3 results.
101
+
102
  Args:
103
  query: The search query."""
104
  search_docs = TavilySearchResults(max_results=3).invoke(query=query)
 
106
  [
107
  f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
108
  for doc in search_docs
109
+ ]
110
+ )
111
  return {"web_results": formatted_search_docs}
112
 
113
+
114
  @tool
115
  def arvix_search(query: str) -> str:
116
  """Search Arxiv for a query and return maximum 3 result.
117
+
118
  Args:
119
  query: The search query."""
120
  search_docs = ArxivLoader(query=query, load_max_docs=3).load()
 
122
  [
123
  f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
124
  for doc in search_docs
125
+ ]
126
+ )
127
  return {"arvix_results": formatted_search_docs}
128
 
129
 
130
# Load CSV data and embeddings
class LocalCSVRetriever:
    """Similarity-search retriever over a local CSV export of documents.

    The CSV is expected to contain a ``content`` column, an ``embedding``
    column holding a list of floats (possibly as its string
    representation), and optionally a ``metadata`` column holding a dict
    (or its string representation). Query embeddings must come from the
    same model that produced the stored document embeddings.
    """

    def __init__(self, csv_file_path="supabase_docs.csv"):
        self.csv_file_path = csv_file_path
        self.df = None
        # Query-side embedding model; must match the stored embeddings.
        self.embeddings_model = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-mpnet-base-v2"
        )
        self.load_data()

    def load_data(self):
        """Load data from CSV file into ``self.df``.

        On any failure ``self.df`` becomes an empty DataFrame so that
        ``similarity_search`` degrades to returning no results instead
        of crashing.
        """
        try:
            self.df = pd.read_csv(self.csv_file_path)
            print(f"Loaded {len(self.df)} documents from {self.csv_file_path}")

            # Convert string representation of embeddings back to numpy arrays
            if 'embedding' in self.df.columns:
                self.df['embedding_array'] = self.df['embedding'].apply(
                    lambda x: np.array(ast.literal_eval(x)) if isinstance(x, str) else np.array(x)
                )
        except FileNotFoundError:
            print(f"CSV file {self.csv_file_path} not found!")
            self.df = pd.DataFrame()
        except Exception as e:
            print(f"Error loading CSV: {e}")
            self.df = pd.DataFrame()

    def similarity_search(self, query: str, k: int = 1):
        """Return up to *k* document-like objects most similar to *query*.

        Results are ordered best-first. Each result exposes
        ``page_content`` and ``metadata`` attributes, mirroring the
        LangChain ``Document`` interface.
        """
        # Guard: nothing loaded, or embeddings were never parsed.
        if self.df is None or self.df.empty or 'embedding_array' not in self.df.columns:
            return []

        # Embed the query once; shape (dim,).
        query_vec = np.asarray(self.embeddings_model.embed_query(query), dtype=float)

        # Stack all stored embeddings into one (n_docs, dim) matrix and
        # compute cosine similarity in a single vectorized pass — same
        # scores as the original per-row sklearn loop, without the
        # sklearn dependency or the Python-level iteration.
        doc_matrix = np.vstack(self.df['embedding_array'].to_numpy())
        doc_norms = np.linalg.norm(doc_matrix, axis=1)
        denom = doc_norms * np.linalg.norm(query_vec)
        denom[denom == 0] = 1.0  # avoid division by zero on all-zero vectors
        similarities = (doc_matrix @ query_vec) / denom

        # Row positions of the top-k most similar documents, best first.
        top_positions = np.argsort(similarities)[::-1][:k]

        results = []
        for pos in top_positions:
            row = self.df.iloc[int(pos)]
            # Metadata may be absent or stored as a string literal.
            raw_meta = row['metadata'] if 'metadata' in self.df.columns else {}
            metadata = ast.literal_eval(raw_meta) if isinstance(raw_meta, str) else raw_meta
            # Lightweight stand-in for langchain's Document type.
            doc = type('Document', (), {
                'page_content': row['content'],
                'metadata': metadata,
            })()
            results.append(doc)

        return results
189
+
190
+
191
+ # Initialize the local retriever
192
+ local_retriever = LocalCSVRetriever()
193
 
194
  # load the system prompt from the file
195
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
 
198
  # System message
199
  sys_msg = SystemMessage(content=system_prompt)
200
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
  tools = [
202
  multiply,
203
  add,
 
209
  arvix_search,
210
  ]
211
 
212
+
213
  # Build graph function
214
  def build_graph(provider: str = "groq"):
215
  """Build the graph"""
 
219
  llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
220
  elif provider == "groq":
221
  # Groq https://console.groq.com/docs/models
222
+ llm = ChatGroq(
223
+ model="qwen-qwq-32b", temperature=0
224
+ ) # optional : qwen-qwq-32b gemma2-9b-it
225
  elif provider == "huggingface":
226
  # TODO: Add huggingface endpoint
227
  llm = ChatHuggingFace(
 
239
  def assistant(state: MessagesState):
240
  """Assistant node"""
241
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
 
 
242
 
243
def retriever(state: MessagesState):
    """Retriever node backed by the local CSV knowledge base.

    Looks up the single most similar stored document for the latest
    message and replies with the answer extracted from it.
    """
    question = state["messages"][-1].content
    matches = local_retriever.similarity_search(question, k=1)

    # Nothing matched: tell the user rather than failing downstream.
    if not matches:
        no_info = AIMessage(
            content="I don't have information about this topic in my knowledge base. Please try a different question."
        )
        return {"messages": [no_info]}

    text = matches[0].page_content

    # Stored documents may embed the answer after a "Final answer :" marker.
    marker = "Final answer :"
    if marker in text:
        answer = text.split(marker)[-1].strip()
    else:
        answer = text.strip()

    # Ensure answer is not empty
    if not answer:
        answer = "I found related information but couldn't extract a clear answer. Please rephrase your question."

    return {"messages": [AIMessage(content=answer)]}
271
 
272
  builder = StateGraph(MessagesState)
273
  builder.add_node("retriever", retriever)