surya07 commited on
Commit
c7724ac
·
verified ·
1 Parent(s): 38394f9

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +38 -29
agent.py CHANGED
@@ -21,7 +21,6 @@ load_dotenv()
21
  @tool
22
  def multiply(a: int, b: int) -> int:
23
  """Multiply two numbers.
24
-
25
  Args:
26
  a: first int
27
  b: second int
@@ -152,7 +151,7 @@ tools = [
152
  ]
153
 
154
  # Build graph function
155
- def build_graph(provider: str = "groq"):
156
  """Build the graph"""
157
  # Load environment variables from .env file
158
  if provider == "google":
@@ -179,36 +178,46 @@ def build_graph(provider: str = "groq"):
179
  """Assistant node"""
180
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
181
 
 
 
 
 
 
 
 
 
 
 
182
  def retriever(state: MessagesState):
183
- """Retriever node"""
184
- similar_question = vector_store.similarity_search(state["messages"][0].content)
185
- example_msg = HumanMessage(
186
- content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
187
- )
188
- return {"messages": [sys_msg] + state["messages"] + [example_msg]}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
 
190
  builder = StateGraph(MessagesState)
191
  builder.add_node("retriever", retriever)
192
- builder.add_node("assistant", assistant)
193
- builder.add_node("tools", ToolNode(tools))
194
- builder.add_edge(START, "retriever")
195
- builder.add_edge("retriever", "assistant")
196
- builder.add_conditional_edges(
197
- "assistant",
198
- tools_condition,
199
- )
200
- builder.add_edge("tools", "assistant")
201
 
202
  # Compile graph
203
- return builder.compile()
204
-
205
- # test
206
- if __name__ == "__main__":
207
- question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
208
- # Build the graph
209
- graph = build_graph(provider="groq")
210
- # Run the graph
211
- messages = [HumanMessage(content=question)]
212
- messages = graph.invoke({"messages": messages})
213
- for m in messages["messages"]:
214
- m.pretty_print()
 
21
  @tool
22
  def multiply(a: int, b: int) -> int:
23
  """Multiply two numbers.
 
24
  Args:
25
  a: first int
26
  b: second int
 
151
  ]
152
 
153
  # Build graph function
154
+ def build_graph(provider: str = "google"):
155
  """Build the graph"""
156
  # Load environment variables from .env file
157
  if provider == "google":
 
178
  """Assistant node"""
179
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
180
 
181
+ # def retriever(state: MessagesState):
182
+ # """Retriever node"""
183
+ # similar_question = vector_store.similarity_search(state["messages"][0].content)
184
+ #example_msg = HumanMessage(
185
+ # content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
186
+ # )
187
+ # return {"messages": [sys_msg] + state["messages"] + [example_msg]}
188
+
189
+ from langchain_core.messages import AIMessage
190
+
191
  def retriever(state: MessagesState):
192
+ query = state["messages"][-1].content
193
+ similar_doc = vector_store.similarity_search(query, k=1)[0]
194
+
195
+ content = similar_doc.page_content
196
+ if "Final answer :" in content:
197
+ answer = content.split("Final answer :")[-1].strip()
198
+ else:
199
+ answer = content.strip()
200
+
201
+ return {"messages": [AIMessage(content=answer)]}
202
+
203
+ # builder = StateGraph(MessagesState)
204
+ #builder.add_node("retriever", retriever)
205
+ #builder.add_node("assistant", assistant)
206
+ #builder.add_node("tools", ToolNode(tools))
207
+ #builder.add_edge(START, "retriever")
208
+ #builder.add_edge("retriever", "assistant")
209
+ #builder.add_conditional_edges(
210
+ # "assistant",
211
+ # tools_condition,
212
+ #)
213
+ #builder.add_edge("tools", "assistant")
214
 
215
  builder = StateGraph(MessagesState)
216
  builder.add_node("retriever", retriever)
217
+
218
+ # Retriever ist Start und Endpunkt
219
+ builder.set_entry_point("retriever")
220
+ builder.set_finish_point("retriever")
 
 
 
 
 
221
 
222
  # Compile graph
223
+ return builder.compile()