yashu21 commited on
Commit
7e56021
·
verified ·
1 Parent(s): adb9345

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +19 -18
agent.py CHANGED
@@ -112,8 +112,6 @@ def arvix_search(query: str) -> str:
112
  ])
113
  return {"arvix_results": formatted_search_docs}
114
 
115
-
116
-
117
  # load the system prompt from the file
118
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
119
  system_prompt = f.read()
@@ -122,23 +120,16 @@ with open("system_prompt.txt", "r", encoding="utf-8") as f:
122
  sys_msg = SystemMessage(content=system_prompt)
123
 
124
  # build a retriever
125
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
126
  supabase: Client = create_client(
127
  os.environ.get("SUPABASE_URL"),
128
  os.environ.get("SUPABASE_SERVICE_KEY"))
129
  vector_store = SupabaseVectorStore(
130
  client=supabase,
131
- embedding= embeddings,
132
  table_name="documents",
133
  query_name="match_documents_langchain",
134
  )
135
- create_retriever_tool = create_retriever_tool(
136
- retriever=vector_store.as_retriever(),
137
- name="Question Search",
138
- description="A tool to retrieve similar questions from a vector store.",
139
- )
140
-
141
-
142
 
143
  tools = [
144
  multiply,
@@ -160,7 +151,7 @@ def build_graph(provider: str = "groq"):
160
  llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
161
  elif provider == "groq":
162
  # Groq https://console.groq.com/docs/models
163
- llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
164
  elif provider == "huggingface":
165
  # TODO: Add huggingface endpoint
166
  llm = ChatHuggingFace(
@@ -177,14 +168,24 @@ def build_graph(provider: str = "groq"):
177
  # Node
178
  def assistant(state: MessagesState):
179
  """Assistant node"""
 
 
 
 
180
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
181
 
182
  def retriever(state: MessagesState):
183
  """Retriever node"""
184
- similar_question = vector_store.similarity_search(state["messages"][0].content)
185
- example_msg = HumanMessage(
186
- content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
187
- )
 
 
 
 
 
 
188
  return {"messages": [sys_msg] + state["messages"] + [example_msg]}
189
 
190
  builder = StateGraph(MessagesState)
@@ -204,11 +205,11 @@ def build_graph(provider: str = "groq"):
204
 
205
  # test
206
  if __name__ == "__main__":
207
- question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
208
  # Build the graph
209
  graph = build_graph(provider="groq")
210
  # Run the graph
211
  messages = [HumanMessage(content=question)]
212
  messages = graph.invoke({"messages": messages})
213
  for m in messages["messages"]:
214
- m.pretty_print()
 
112
  ])
113
  return {"arvix_results": formatted_search_docs}
114
 
 
 
115
  # load the system prompt from the file
116
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
117
  system_prompt = f.read()
 
120
  sys_msg = SystemMessage(content=system_prompt)
121
 
122
  # build a retriever
123
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
124
  supabase: Client = create_client(
125
  os.environ.get("SUPABASE_URL"),
126
  os.environ.get("SUPABASE_SERVICE_KEY"))
127
  vector_store = SupabaseVectorStore(
128
  client=supabase,
129
+ embedding=embeddings,
130
  table_name="documents",
131
  query_name="match_documents_langchain",
132
  )
 
 
 
 
 
 
 
133
 
134
  tools = [
135
  multiply,
 
151
  llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
152
  elif provider == "groq":
153
  # Groq https://console.groq.com/docs/models
154
+ llm = ChatGroq(model="llama-3.1-8b-instant", temperature=0) # Updated model name
155
  elif provider == "huggingface":
156
  # TODO: Add huggingface endpoint
157
  llm = ChatHuggingFace(
 
168
  # Node
169
  def assistant(state: MessagesState):
170
  """Assistant node"""
171
+ # For debugging: Force a wiki_search call for the question
172
+ if "Mercedes Sosa" in state["messages"][-1].content:
173
+ wiki_results = wiki_search.invoke({"query": "Mercedes Sosa discography"})
174
+ state["messages"].append(HumanMessage(content=f"Wikipedia search results:\n{wiki_results['wiki_results']}"))
175
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
176
 
177
  def retriever(state: MessagesState):
178
  """Retriever node"""
179
+ query = state["messages"][0].content
180
+ similar_questions = vector_store.similarity_search(query)
181
+ if similar_questions:
182
+ example_msg = HumanMessage(
183
+ content=f"Here I provide a similar question and answer for reference: \n\n{similar_questions[0].page_content}",
184
+ )
185
+ else:
186
+ example_msg = HumanMessage(
187
+ content="No similar questions found in the vector store."
188
+ )
189
  return {"messages": [sys_msg] + state["messages"] + [example_msg]}
190
 
191
  builder = StateGraph(MessagesState)
 
205
 
206
  # test
207
  if __name__ == "__main__":
208
+ question = "How many studio albums were published by Mercedes Sosa between 2000 and 2009 (included)?"
209
  # Build the graph
210
  graph = build_graph(provider="groq")
211
  # Run the graph
212
  messages = [HumanMessage(content=question)]
213
  messages = graph.invoke({"messages": messages})
214
  for m in messages["messages"]:
215
+ m.pretty_print()