Aya1610 committed on
Commit
cdb890f
·
verified ·
1 Parent(s): 58c4e06

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +91 -19
agent.py CHANGED
@@ -165,16 +165,13 @@ tools = [
165
  create_retriever_tool
166
  ]
167
 
168
-
169
-
170
-
171
  def build_graph(provider: str = "openai"):
172
  """Build the graph using OpenAI or Hugging Face"""
173
  # Validate provider
174
  if provider not in ["openai", "huggingface"]:
175
  raise ValueError("Invalid provider. Choose 'openai' or 'huggingface'.")
176
 
177
- # Initialize LLM based on provider
178
  if provider == "openai":
179
  from langchain_openai import ChatOpenAI
180
  llm = ChatOpenAI(model="gpt-4o", temperature=0)
@@ -192,27 +189,35 @@ def build_graph(provider: str = "openai"):
192
 
193
  # Define nodes
194
  def assistant(state: MessagesState):
195
- """Assistant node"""
196
- return {"messages": [llm_with_tools.invoke(state["messages"])]}
 
 
 
 
 
197
 
198
  def retriever(state: MessagesState):
199
  """Retriever node - provides context from vector store"""
200
- query = state["messages"][-1].content
 
 
 
 
 
201
  similar_docs = vector_store.similarity_search(query, k=1)
202
 
203
  if not similar_docs:
204
- return {"messages": [AIMessage(content="No relevant information found")]}
205
-
206
- similar_doc = similar_docs[0]
207
- content = similar_doc.page_content
208
 
209
- # Extract answer if formatted, otherwise use full content
210
- if "Final answer :" in content:
211
- answer = content.split("Final answer :")[-1].strip()
212
- else:
213
- answer = content.strip()
214
-
215
- return {"messages": [AIMessage(content=answer)]}
216
 
217
  # Build graph
218
  builder = StateGraph(MessagesState)
@@ -230,9 +235,76 @@ def build_graph(provider: str = "openai"):
230
  tools_condition,
231
  {"continue": "tools", "end": END}
232
  )
233
- builder.add_edge("tools", "assistant")
234
 
235
  return builder.compile()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
  # def build_graph(provider: str = "google"):
237
  # """Build the graph"""
238
  # # Load environment variables from .env file
 
165
  create_retriever_tool
166
  ]
167
 
 
 
 
168
  def build_graph(provider: str = "openai"):
169
  """Build the graph using OpenAI or Hugging Face"""
170
  # Validate provider
171
  if provider not in ["openai", "huggingface"]:
172
  raise ValueError("Invalid provider. Choose 'openai' or 'huggingface'.")
173
 
174
+ # Initialize LLM
175
  if provider == "openai":
176
  from langchain_openai import ChatOpenAI
177
  llm = ChatOpenAI(model="gpt-4o", temperature=0)
 
189
 
190
  # Define nodes
191
  def assistant(state: MessagesState):
192
+ """Assistant node - generates responses"""
193
+ # Get current messages
194
+ messages = state["messages"]
195
+ # Generate response using LLM
196
+ response = llm_with_tools.invoke(messages)
197
+ # Append new message to state
198
+ return {"messages": messages + [response]}
199
 
200
  def retriever(state: MessagesState):
201
  """Retriever node - provides context from vector store"""
202
+ # Get current messages
203
+ messages = state["messages"]
204
+ # Last message is the user query
205
+ query = messages[-1].content
206
+
207
+ # Retrieve similar documents
208
  similar_docs = vector_store.similarity_search(query, k=1)
209
 
210
  if not similar_docs:
211
+ # Return original messages if no context found
212
+ return {"messages": messages}
 
 
213
 
214
+ # Get context from first document
215
+ context = similar_docs[0].page_content
216
+ # Create system message with context
217
+ context_msg = SystemMessage(content=f"Reference context:\n{context}")
218
+
219
+ # Append context to messages
220
+ return {"messages": messages + [context_msg]}
221
 
222
  # Build graph
223
  builder = StateGraph(MessagesState)
 
235
  tools_condition,
236
  {"continue": "tools", "end": END}
237
  )
238
+ builder.add_edge("tools", "retriever") # Go back to retriever after tools
239
 
240
  return builder.compile()
241
+
242
+
243
+ # def build_graph(provider: str = "openai"):
244
+ # """Build the graph using OpenAI or Hugging Face"""
245
+ # # Validate provider
246
+ # if provider not in ["openai", "huggingface"]:
247
+ # raise ValueError("Invalid provider. Choose 'openai' or 'huggingface'.")
248
+
249
+ # # Initialize LLM based on provider
250
+ # if provider == "openai":
251
+ # from langchain_openai import ChatOpenAI
252
+ # llm = ChatOpenAI(model="gpt-4o", temperature=0)
253
+ # else: # huggingface
254
+ # from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
255
+ # llm = ChatHuggingFace(
256
+ # llm=HuggingFaceEndpoint(
257
+ # endpoint_url="https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-8B-Instruct",
258
+ # temperature=0,
259
+ # )
260
+ # )
261
+
262
+ # # Bind tools to LLM
263
+ # llm_with_tools = llm.bind_tools(tools)
264
+
265
+ # # Define nodes
266
+ # def assistant(state: MessagesState):
267
+ # """Assistant node"""
268
+ # return {"messages": [llm_with_tools.invoke(state["messages"])]}
269
+
270
+ # def retriever(state: MessagesState):
271
+ # """Retriever node - provides context from vector store"""
272
+ # query = state["messages"][-1].content
273
+ # similar_docs = vector_store.similarity_search(query, k=1)
274
+
275
+ # if not similar_docs:
276
+ # return {"messages": [AIMessage(content="No relevant information found")]}
277
+
278
+ # similar_doc = similar_docs[0]
279
+ # content = similar_doc.page_content
280
+
281
+ # # Extract answer if formatted, otherwise use full content
282
+ # if "Final answer :" in content:
283
+ # answer = content.split("Final answer :")[-1].strip()
284
+ # else:
285
+ # answer = content.strip()
286
+
287
+ # return {"messages": [AIMessage(content=answer)]}
288
+
289
+ # # Build graph
290
+ # builder = StateGraph(MessagesState)
291
+
292
+ # # Add nodes
293
+ # builder.add_node("retriever", retriever)
294
+ # builder.add_node("assistant", assistant)
295
+ # builder.add_node("tools", ToolNode(tools))
296
+
297
+ # # Set up edges
298
+ # builder.set_entry_point("retriever")
299
+ # builder.add_edge("retriever", "assistant")
300
+ # builder.add_conditional_edges(
301
+ # "assistant",
302
+ # tools_condition,
303
+ # {"continue": "tools", "end": END}
304
+ # )
305
+ # builder.add_edge("tools", "assistant")
306
+
307
+ # return builder.compile()
308
  # def build_graph(provider: str = "google"):
309
  # """Build the graph"""
310
  # # Load environment variables from .env file