Aya1610 commited on
Commit
c52d751
·
verified ·
1 Parent(s): 88542c2

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +170 -34
agent.py CHANGED
@@ -1,8 +1,7 @@
1
  import os
2
  from dotenv import load_dotenv
3
  from langgraph.graph import START, END, StateGraph, MessagesState
4
- from langgraph.prebuilt import tools_condition
5
- from langgraph.prebuilt import ToolNode
6
  from langchain_google_genai import ChatGoogleGenerativeAI
7
  from langchain_groq import ChatGroq
8
  from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
@@ -10,7 +9,7 @@ from langchain_community.tools.tavily_search import TavilySearchResults
10
  from langchain_community.document_loaders import WikipediaLoader
11
  from langchain_community.document_loaders import ArxivLoader
12
  from langchain_community.vectorstores import SupabaseVectorStore
13
- from langchain_core.messages import SystemMessage, HumanMessage
14
  from langchain_core.tools import tool
15
  from langchain.tools.retriever import create_retriever_tool
16
  from supabase.client import Client, create_client
@@ -167,57 +166,47 @@ tools = [
167
  ]
168
 
169
 
 
 
170
  def build_graph(provider: str = "openai"):
171
  """Build the graph using OpenAI or Hugging Face"""
172
-
173
- if provider == "openai":
174
- # OpenAI ChatGPT (e.g., GPT-4 or GPT-3.5)
175
- from langchain.chat_models import ChatOpenAI
176
- llm = ChatOpenAI(model="gpt-4", temperature=0)
177
-
178
- elif provider == "huggingface":
179
- # Hugging Face endpoint
180
- from langchain.chat_models import ChatHuggingFace
181
- from langchain.llms import HuggingFaceEndpoint
182
 
 
 
 
 
 
 
183
  llm = ChatHuggingFace(
184
  llm=HuggingFaceEndpoint(
185
- url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
186
  temperature=0,
187
  )
188
  )
189
 
190
- else:
191
- raise ValueError("Invalid provider. Choose 'openai' or 'huggingface'.")
192
-
193
  # Bind tools to LLM
194
  llm_with_tools = llm.bind_tools(tools)
195
 
196
- # return llm_with_tools
197
-
198
- # Node
199
  def assistant(state: MessagesState):
200
  """Assistant node"""
201
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
202
 
203
- # def retriever(state: MessagesState):
204
- # """Retriever node"""
205
- # similar_question = vector_store.similarity_search(state["messages"][0].content)
206
- #example_msg = HumanMessage(
207
- # content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
208
- # )
209
- # return {"messages": [sys_msg] + state["messages"] + [example_msg]}
210
-
211
- from langchain_core.messages import AIMessage
212
-
213
  def retriever(state: MessagesState):
 
214
  query = state["messages"][-1].content
215
- similar_doc = vector_store.similarity_search(query, k=1)[0]
 
216
  if not similar_docs:
217
- return {"messages": [AIMessage(content="No relevant information found")]}
 
218
  similar_doc = similar_docs[0]
219
-
220
  content = similar_doc.page_content
 
 
221
  if "Final answer :" in content:
222
  answer = content.split("Final answer :")[-1].strip()
223
  else:
@@ -225,6 +214,74 @@ def build_graph(provider: str = "openai"):
225
 
226
  return {"messages": [AIMessage(content=answer)]}
227
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
  # builder = StateGraph(MessagesState)
229
  #builder.add_node("retriever", retriever)
230
  #builder.add_node("assistant", assistant)
@@ -240,10 +297,89 @@ def build_graph(provider: str = "openai"):
240
  builder = StateGraph(MessagesState)
241
  builder.add_node("retriever", retriever)
242
 
243
-
244
  builder.set_entry_point("retriever")
245
  builder.set_finish_point("retriever")
246
 
247
  # Compile graph
248
  return builder.compile()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
 
 
1
  import os
2
  from dotenv import load_dotenv
3
  from langgraph.graph import START, END, StateGraph, MessagesState
4
+ from langgraph.prebuilt import tools_condition, ToolNode
 
5
  from langchain_google_genai import ChatGoogleGenerativeAI
6
  from langchain_groq import ChatGroq
7
  from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
 
9
  from langchain_community.document_loaders import WikipediaLoader
10
  from langchain_community.document_loaders import ArxivLoader
11
  from langchain_community.vectorstores import SupabaseVectorStore
12
+ from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
13
  from langchain_core.tools import tool
14
  from langchain.tools.retriever import create_retriever_tool
15
  from supabase.client import Client, create_client
 
166
  ]
167
 
168
 
169
+
170
+
171
  def build_graph(provider: str = "openai"):
172
  """Build the graph using OpenAI or Hugging Face"""
173
+ # Validate provider
174
+ if provider not in ["openai", "huggingface"]:
175
+ raise ValueError("Invalid provider. Choose 'openai' or 'huggingface'.")
 
 
 
 
 
 
 
176
 
177
+ # Initialize LLM based on provider
178
+ if provider == "openai":
179
+ from langchain_openai import ChatOpenAI
180
+ llm = ChatOpenAI(model="gpt-4o", temperature=0)
181
+ else: # huggingface
182
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
183
  llm = ChatHuggingFace(
184
  llm=HuggingFaceEndpoint(
185
+ endpoint_url="https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-8B-Instruct",
186
  temperature=0,
187
  )
188
  )
189
 
 
 
 
190
  # Bind tools to LLM
191
  llm_with_tools = llm.bind_tools(tools)
192
 
193
+ # Define nodes
 
 
194
  def assistant(state: MessagesState):
195
  """Assistant node"""
196
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
197
 
 
 
 
 
 
 
 
 
 
 
198
  def retriever(state: MessagesState):
199
+ """Retriever node - provides context from vector store"""
200
  query = state["messages"][-1].content
201
+ similar_docs = vector_store.similarity_search(query, k=1)
202
+
203
  if not similar_docs:
204
+ return {"messages": [AIMessage(content="No relevant information found")]}
205
+
206
  similar_doc = similar_docs[0]
 
207
  content = similar_doc.page_content
208
+
209
+ # Extract answer if formatted, otherwise use full content
210
  if "Final answer :" in content:
211
  answer = content.split("Final answer :")[-1].strip()
212
  else:
 
214
 
215
  return {"messages": [AIMessage(content=answer)]}
216
 
217
+ # Build graph
218
+ builder = StateGraph(MessagesState)
219
+
220
+ # Add nodes
221
+ builder.add_node("retriever", retriever)
222
+ builder.add_node("assistant", assistant)
223
+ builder.add_node("tools", ToolNode(tools))
224
+
225
+ # Set up edges
226
+ builder.set_entry_point("retriever")
227
+ builder.add_edge("retriever", "assistant")
228
+ builder.add_conditional_edges(
229
+ "assistant",
230
+ tools_condition,
231
+ {"continue": "tools", "end": END}
232
+ )
233
+ builder.add_edge("tools", "assistant")
234
+
235
+ return builder.compile()
236
+ # def build_graph(provider: str = "google"):
237
+ # """Build the graph"""
238
+ # # Load environment variables from .env file
239
+ # if provider == "google":
240
+ # # Google Gemini
241
+ # llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
242
+ # elif provider == "groq":
243
+ # # Groq https://console.groq.com/docs/models
244
+ # llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
245
+ # elif provider == "huggingface":
246
+ # # TODO: Add huggingface endpoint
247
+ # llm = ChatHuggingFace(
248
+ # llm=HuggingFaceEndpoint(
249
+ # url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
250
+ # temperature=0,
251
+ # ),
252
+ # )
253
+ # else:
254
+ # raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
255
+ # # Bind tools to LLM
256
+ # llm_with_tools = llm.bind_tools(tools)
257
+
258
+ # # Node
259
+ # def assistant(state: MessagesState):
260
+ # """Assistant node"""
261
+ # return {"messages": [llm_with_tools.invoke(state["messages"])]}
262
+
263
+ # # def retriever(state: MessagesState):
264
+ # # """Retriever node"""
265
+ # # similar_question = vector_store.similarity_search(state["messages"][0].content)
266
+ # #example_msg = HumanMessage(
267
+ # # content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
268
+ # # )
269
+ # # return {"messages": [sys_msg] + state["messages"] + [example_msg]}
270
+
271
+ # from langchain_core.messages import AIMessage
272
+
273
+ # def retriever(state: MessagesState):
274
+ # query = state["messages"][-1].content
275
+ # similar_doc = vector_store.similarity_search(query, k=1)[0]
276
+
277
+ # content = similar_doc.page_content
278
+ # if "Final answer :" in content:
279
+ # answer = content.split("Final answer :")[-1].strip()
280
+ # else:
281
+ # answer = content.strip()
282
+
283
+ # return {"messages": [AIMessage(content=answer)]}
284
+
285
  # builder = StateGraph(MessagesState)
286
  #builder.add_node("retriever", retriever)
287
  #builder.add_node("assistant", assistant)
 
297
  builder = StateGraph(MessagesState)
298
  builder.add_node("retriever", retriever)
299
 
300
+ # Retriever ist Start und Endpunkt
301
  builder.set_entry_point("retriever")
302
  builder.set_finish_point("retriever")
303
 
304
  # Compile graph
305
  return builder.compile()
306
+ # def build_graph(provider: str = "openai"):
307
+ # """Build the graph using OpenAI or Hugging Face"""
308
+
309
+ # if provider == "openai":
310
+ # # OpenAI ChatGPT (e.g., GPT-4 or GPT-3.5)
311
+ # from langchain.chat_models import ChatOpenAI
312
+ # llm = ChatOpenAI(model="gpt-4", temperature=0)
313
+
314
+ # elif provider == "huggingface":
315
+ # # Hugging Face endpoint
316
+ # from langchain.chat_models import ChatHuggingFace
317
+ # from langchain.llms import HuggingFaceEndpoint
318
+
319
+ # llm = ChatHuggingFace(
320
+ # llm=HuggingFaceEndpoint(
321
+ # url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
322
+ # temperature=0,
323
+ # )
324
+ # )
325
+
326
+ # else:
327
+ # raise ValueError("Invalid provider. Choose 'openai' or 'huggingface'.")
328
+
329
+ # # Bind tools to LLM
330
+ # llm_with_tools = llm.bind_tools(tools)
331
+
332
+ # # return llm_with_tools
333
+
334
+ # # Node
335
+ # def assistant(state: MessagesState):
336
+ # """Assistant node"""
337
+ # return {"messages": [llm_with_tools.invoke(state["messages"])]}
338
+
339
+ # # def retriever(state: MessagesState):
340
+ # # """Retriever node"""
341
+ # # similar_question = vector_store.similarity_search(state["messages"][0].content)
342
+ # #example_msg = HumanMessage(
343
+ # # content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
344
+ # # )
345
+ # # return {"messages": [sys_msg] + state["messages"] + [example_msg]}
346
+
347
+ # from langchain_core.messages import AIMessage
348
+
349
+ # def retriever(state: MessagesState):
350
+ # query = state["messages"][-1].content
351
+ # similar_doc = vector_store.similarity_search(query, k=1)[0]
352
+ # if not similar_docs:
353
+ # return {"messages": [AIMessage(content="No relevant information found")]}
354
+ # similar_doc = similar_docs[0]
355
+
356
+ # content = similar_doc.page_content
357
+ # if "Final answer :" in content:
358
+ # answer = content.split("Final answer :")[-1].strip()
359
+ # else:
360
+ # answer = content.strip()
361
+
362
+ # return {"messages": [AIMessage(content=answer)]}
363
+
364
+ # # builder = StateGraph(MessagesState)
365
+ # #builder.add_node("retriever", retriever)
366
+ # #builder.add_node("assistant", assistant)
367
+ # #builder.add_node("tools", ToolNode(tools))
368
+ # #builder.add_edge(START, "retriever")
369
+ # #builder.add_edge("retriever", "assistant")
370
+ # #builder.add_conditional_edges(
371
+ # # "assistant",
372
+ # # tools_condition,
373
+ # #)
374
+ # #builder.add_edge("tools", "assistant")
375
+
376
+ # builder = StateGraph(MessagesState)
377
+ # builder.add_node("retriever", retriever)
378
+
379
+
380
+ # builder.set_entry_point("retriever")
381
+ # builder.set_finish_point("retriever")
382
+
383
+ # # Compile graph
384
+ # return builder.compile()
385