yashu21 commited on
Commit
df41870
·
verified ·
1 Parent(s): 7e56021

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +18 -19
agent.py CHANGED
@@ -112,6 +112,8 @@ def arvix_search(query: str) -> str:
112
  ])
113
  return {"arvix_results": formatted_search_docs}
114
 
 
 
115
  # load the system prompt from the file
116
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
117
  system_prompt = f.read()
@@ -120,16 +122,23 @@ with open("system_prompt.txt", "r", encoding="utf-8") as f:
120
  sys_msg = SystemMessage(content=system_prompt)
121
 
122
  # build a retriever
123
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
124
  supabase: Client = create_client(
125
  os.environ.get("SUPABASE_URL"),
126
  os.environ.get("SUPABASE_SERVICE_KEY"))
127
  vector_store = SupabaseVectorStore(
128
  client=supabase,
129
- embedding=embeddings,
130
  table_name="documents",
131
  query_name="match_documents_langchain",
132
  )
 
 
 
 
 
 
 
133
 
134
  tools = [
135
  multiply,
@@ -151,7 +160,7 @@ def build_graph(provider: str = "groq"):
151
  llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
152
  elif provider == "groq":
153
  # Groq https://console.groq.com/docs/models
154
- llm = ChatGroq(model="llama-3.1-8b-instant", temperature=0) # Updated model name
155
  elif provider == "huggingface":
156
  # TODO: Add huggingface endpoint
157
  llm = ChatHuggingFace(
@@ -168,24 +177,14 @@ def build_graph(provider: str = "groq"):
168
  # Node
169
  def assistant(state: MessagesState):
170
  """Assistant node"""
171
- # For debugging: Force a wiki_search call for the question
172
- if "Mercedes Sosa" in state["messages"][-1].content:
173
- wiki_results = wiki_search.invoke({"query": "Mercedes Sosa discography"})
174
- state["messages"].append(HumanMessage(content=f"Wikipedia search results:\n{wiki_results['wiki_results']}"))
175
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
176
 
177
  def retriever(state: MessagesState):
178
  """Retriever node"""
179
- query = state["messages"][0].content
180
- similar_questions = vector_store.similarity_search(query)
181
- if similar_questions:
182
- example_msg = HumanMessage(
183
- content=f"Here I provide a similar question and answer for reference: \n\n{similar_questions[0].page_content}",
184
- )
185
- else:
186
- example_msg = HumanMessage(
187
- content="No similar questions found in the vector store."
188
- )
189
  return {"messages": [sys_msg] + state["messages"] + [example_msg]}
190
 
191
  builder = StateGraph(MessagesState)
@@ -205,11 +204,11 @@ def build_graph(provider: str = "groq"):
205
 
206
  # test
207
  if __name__ == "__main__":
208
- question = "How many studio albums were published by Mercedes Sosa between 2000 and 2009 (included)?"
209
  # Build the graph
210
  graph = build_graph(provider="groq")
211
  # Run the graph
212
  messages = [HumanMessage(content=question)]
213
  messages = graph.invoke({"messages": messages})
214
  for m in messages["messages"]:
215
- m.pretty_print()
 
112
  ])
113
  return {"arvix_results": formatted_search_docs}
114
 
115
+
116
+
117
  # load the system prompt from the file
118
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
119
  system_prompt = f.read()
 
122
  sys_msg = SystemMessage(content=system_prompt)
123
 
124
  # build a retriever
125
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
126
  supabase: Client = create_client(
127
  os.environ.get("SUPABASE_URL"),
128
  os.environ.get("SUPABASE_SERVICE_KEY"))
129
  vector_store = SupabaseVectorStore(
130
  client=supabase,
131
+ embedding= embeddings,
132
  table_name="documents",
133
  query_name="match_documents_langchain",
134
  )
135
+ create_retriever_tool = create_retriever_tool(
136
+ retriever=vector_store.as_retriever(),
137
+ name="Question Search",
138
+ description="A tool to retrieve similar questions from a vector store.",
139
+ )
140
+
141
+
142
 
143
  tools = [
144
  multiply,
 
160
  llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
161
  elif provider == "groq":
162
  # Groq https://console.groq.com/docs/models
163
+ llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
164
  elif provider == "huggingface":
165
  # TODO: Add huggingface endpoint
166
  llm = ChatHuggingFace(
 
177
  # Node
178
  def assistant(state: MessagesState):
179
  """Assistant node"""
 
 
 
 
180
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
181
 
182
  def retriever(state: MessagesState):
183
  """Retriever node"""
184
+ similar_question = vector_store.similarity_search(state["messages"][0].content)
185
+ example_msg = HumanMessage(
186
+ content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
187
+ )
 
 
 
 
 
 
188
  return {"messages": [sys_msg] + state["messages"] + [example_msg]}
189
 
190
  builder = StateGraph(MessagesState)
 
204
 
205
  # test
206
  if __name__ == "__main__":
207
+ question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
208
  # Build the graph
209
  graph = build_graph(provider="groq")
210
  # Run the graph
211
  messages = [HumanMessage(content=question)]
212
  messages = graph.invoke({"messages": messages})
213
  for m in messages["messages"]:
214
+ m.pretty_print()