tsrrus committed on
Commit
497b4bc
·
verified ·
1 Parent(s): d708ed8

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +34 -225
agent.py CHANGED
@@ -1,77 +1,55 @@
1
  """LangGraph Agent"""
2
-
3
  import os
4
- import json
5
- from typing import Optional, Dict, Any, List
6
  from dotenv import load_dotenv
7
- from langgraph.graph import START, END, StateGraph, MessagesState
8
  from langgraph.prebuilt import tools_condition
9
  from langgraph.prebuilt import ToolNode
10
  from langchain_google_genai import ChatGoogleGenerativeAI
 
 
11
  from langchain_groq import ChatGroq
12
- from langchain_huggingface import (
13
- ChatHuggingFace,
14
- HuggingFaceEndpoint,
15
- HuggingFaceEmbeddings,
16
- )
17
- from langchain_community.utilities import GoogleSerperAPIWrapper
18
  from langchain_community.document_loaders import WikipediaLoader
19
  from langchain_community.document_loaders import ArxivLoader
20
  from langchain_community.vectorstores import SupabaseVectorStore
21
- from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
22
  from langchain_core.tools import tool
23
  from langchain.tools.retriever import create_retriever_tool
24
  from supabase.client import Client, create_client
25
- from langchain_core.prompts import ChatPromptTemplate
26
- from langchain_core.output_parsers import StrOutputParser
27
-
28
- import os
29
- from supabase import Client, create_client
30
-
31
- supabase: Client = create_client(
32
- os.environ.get("SUPABASE_URL"),
33
- os.environ.get("SUPABASE_SERVICE_KEY"))
34
 
35
  load_dotenv()
36
 
37
-
38
  @tool
39
  def multiply(a: int, b: int) -> int:
40
  """Multiply two numbers.
41
-
42
  Args:
43
  a: first int
44
  b: second int
45
  """
46
  return a * b
47
 
48
-
49
  @tool
50
  def add(a: int, b: int) -> int:
51
  """Add two numbers.
52
-
53
  Args:
54
  a: first int
55
  b: second int
56
  """
57
  return a + b
58
 
59
-
60
  @tool
61
  def subtract(a: int, b: int) -> int:
62
  """Subtract two numbers.
63
-
64
  Args:
65
  a: first int
66
  b: second int
67
  """
68
  return a - b
69
 
70
-
71
  @tool
72
  def divide(a: int, b: int) -> int:
73
  """Divide two numbers.
74
-
75
  Args:
76
  a: first int
77
  b: second int
@@ -80,22 +58,18 @@ def divide(a: int, b: int) -> int:
80
  raise ValueError("Cannot divide by zero.")
81
  return a / b
82
 
83
-
84
  @tool
85
  def modulus(a: int, b: int) -> int:
86
  """Get the modulus of two numbers.
87
-
88
  Args:
89
  a: first int
90
  b: second int
91
  """
92
  return a % b
93
 
94
-
95
  @tool
96
  def wiki_search(query: str) -> str:
97
  """Search Wikipedia for a query and return maximum 2 results.
98
-
99
  Args:
100
  query: The search query."""
101
  search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
@@ -103,26 +77,25 @@ def wiki_search(query: str) -> str:
103
  [
104
  f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
105
  for doc in search_docs
106
- ]
107
- )
108
  return {"wiki_results": formatted_search_docs}
109
 
110
-
111
  @tool
112
  def web_search(query: str) -> str:
113
  """Search Tavily for a query and return maximum 3 results.
114
-
115
  Args:
116
  query: The search query."""
117
- search = GoogleSerperAPIWrapper()
118
- result = search.run(query)
119
- return {"web_results": result}
120
-
 
 
 
121
 
122
  @tool
123
  def arvix_search(query: str) -> str:
124
  """Search Arxiv for a query and return maximum 3 result.
125
-
126
  Args:
127
  query: The search query."""
128
  search_docs = ArxivLoader(query=query, load_max_docs=3).load()
@@ -130,86 +103,12 @@ def arvix_search(query: str) -> str:
130
  [
131
  f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
132
  for doc in search_docs
133
- ]
134
- )
135
  return {"arvix_results": formatted_search_docs}
136
 
137
 
138
- def load_gaia_answers() -> List[Dict[str, Any]]:
139
- """Load the GAIA questions and answers from the JSON file."""
140
- try:
141
- with open("gaia.json", "r", encoding="utf-8") as f:
142
- return json.load(f)
143
- except Exception as e:
144
- print(f"Error loading GAIA answers: {e}")
145
- return []
146
-
147
- def find_gaia_answer(question: str) -> Optional[str]:
148
- """
149
- Find the most relevant answer in the GAIA dataset for the given question using LLM.
150
- Returns the answer if found, None otherwise.
151
- """
152
- try:
153
- # Load GAIA data
154
- gaia_data = load_gaia_answers()
155
- if not gaia_data:
156
- return None
157
-
158
- # First, try exact match for efficiency
159
- for entry in gaia_data:
160
- if entry.get("Question", "").strip() == question.strip():
161
- return entry.get("Final answer", "")
162
-
163
-
164
-
165
- # Initialize LLM (using the same provider as the main graph for consistency)
166
- llm = ChatHuggingFace(
167
- llm=HuggingFaceEndpoint(
168
- repo_id="meta-llama/Llama-3.1-8B-Instruct",
169
- temperature=0,
170
- ),
171
- )
172
-
173
- # Create a prompt template
174
- template = """You are an expert at matching questions to answers.
175
- Given the following question and a list of question-answer pairs from the GAIA dataset,
176
- find the most relevant answer. If no good match is found, return 'NO_MATCH'.
177
-
178
- Question: {question}
179
-
180
- Available question-answer pairs:
181
- {qa_pairs}
182
-
183
- Return ONLY the answer text if a match is found, or 'NO_MATCH' if no good match is found.
184
- """
185
-
186
- # Prepare the QA pairs string
187
- qa_pairs = "\n\n".join([
188
- f"Q: {entry.get('Question', '')}\nA: {entry.get('Final answer', '')}"
189
- for entry in gaia_data
190
- ])
191
-
192
- # Create and run the chain
193
- prompt = ChatPromptTemplate.from_template(template)
194
- chain = prompt | llm | StrOutputParser()
195
-
196
- # Get the response
197
- response = chain.invoke({
198
- "question": question,
199
- "qa_pairs": qa_pairs
200
- })
201
-
202
- # Parse the response
203
- response = response.strip()
204
- if response and response.upper() != "NO_MATCH":
205
- return response
206
-
207
- except Exception as e:
208
- print(f"Error in find_gaia_answer: {e}")
209
-
210
- return None
211
 
212
- # Load the system prompt from the file
213
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
214
  system_prompt = f.read()
215
 
@@ -217,15 +116,13 @@ with open("system_prompt.txt", "r", encoding="utf-8") as f:
217
  sys_msg = SystemMessage(content=system_prompt)
218
 
219
  # build a retriever
220
- embeddings = HuggingFaceEmbeddings(
221
- model_name="sentence-transformers/all-mpnet-base-v2"
222
- ) # dim=768
223
  supabase: Client = create_client(
224
- os.environ.get("SUPABASE_URL"), os.environ.get("SUPABAS_SERVICE_KEY")
225
- )
226
  vector_store = SupabaseVectorStore(
227
  client=supabase,
228
- embedding=embeddings,
229
  table_name="documents",
230
  query_name="match_documents_langchain",
231
  )
@@ -236,22 +133,20 @@ create_retriever_tool = create_retriever_tool(
236
  )
237
 
238
 
 
239
  tools = [
240
  multiply,
241
  add,
242
  subtract,
243
  divide,
244
  modulus,
245
- # wiki_search,
246
  web_search,
247
  arvix_search,
248
  ]
249
 
250
- class AgentState(MessagesState):
251
- cheating_used: bool = False
252
-
253
  # Build graph function
254
- def build_graph(provider: str = "huggingface"):
255
  """Build the graph"""
256
  # Load environment variables from .env file
257
  if provider == "google":
@@ -259,29 +154,27 @@ def build_graph(provider: str = "huggingface"):
259
  llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
260
  elif provider == "groq":
261
  # Groq https://console.groq.com/docs/models
262
- llm = ChatGroq(
263
- model="qwen-qwq-32b", temperature=0
264
- ) # optional : qwen-qwq-32b gemma2-9b-it
 
265
  elif provider == "huggingface":
266
- # TODO: Add huggingface endpoint
267
  llm = ChatHuggingFace(
268
  llm=HuggingFaceEndpoint(
269
- repo_id="meta-llama/Llama-3.1-8B-Instruct",
270
  temperature=0,
271
  ),
272
  )
273
  else:
274
  raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
275
-
276
  # Bind tools to LLM
277
  llm_with_tools = llm.bind_tools(tools)
278
 
279
- # Node: Assistant
280
  def assistant(state: MessagesState):
281
  """Assistant node"""
282
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
283
 
284
- # Node: Retriever
285
  def retriever(state: MessagesState):
286
  """Retriever node"""
287
  similar_question = vector_store.similarity_search(state["messages"][0].content)
@@ -289,63 +182,13 @@ def build_graph(provider: str = "huggingface"):
289
  content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
290
  )
291
  return {"messages": [sys_msg] + state["messages"] + [example_msg]}
292
-
293
- # Node: Cheating - Check if question exists in GAIA dataset
294
- def cheating_node(state: MessagesState):
295
- """Cheating node that checks if question exists in GAIA dataset"""
296
- if not state["messages"] or not isinstance(state["messages"][-1], HumanMessage):
297
- return {"messages": state["messages"], "cheating_used": False}
298
-
299
- question = state["messages"][-1].content
300
- print("Checking if question exists in GAIA dataset...")
301
- answer = find_gaia_answer(question)
302
-
303
- if answer:
304
- # If answer found in GAIA, return it directly
305
- print("Answer found in GAIA dataset.")
306
- return {
307
- "messages": state["messages"] + [AIMessage(content=f"FINAL ANSWER: {answer}")],
308
- "cheating_used": True
309
- }
310
-
311
- # If not found, continue with normal flow
312
- return {
313
- "messages": state["messages"],
314
- "cheating_used": False
315
- }
316
 
317
- # Build the graph
318
- builder = StateGraph(AgentState)
319
-
320
- # Add nodes
321
- builder.add_node("cheating", cheating_node)
322
  builder.add_node("assistant", assistant)
323
  builder.add_node("tools", ToolNode(tools))
324
-
325
- # Define the workflow
326
- builder.add_edge(START, "cheating")
327
-
328
- # After cheating node, check if we found an answer
329
- def route_after_cheating(state: AgentState):
330
- """Route to end if cheating was used, otherwise to assistant"""
331
- cheating_used = state.get("cheating_used", False)
332
- print(f"Routing after cheating - cheating_used: {cheating_used}")
333
-
334
- # If we found an answer in GAIA, end the flow
335
- if cheating_used:
336
- print("Cheating was used, ending flow")
337
- return END
338
-
339
- # Otherwise, continue to assistant
340
- print("No cheating, continuing to assistant")
341
- return "assistant"
342
-
343
- builder.add_conditional_edges(
344
- "cheating",
345
- route_after_cheating
346
- )
347
-
348
- # Normal flow edges
349
  builder.add_conditional_edges(
350
  "assistant",
351
  tools_condition,
@@ -353,38 +196,4 @@ def build_graph(provider: str = "huggingface"):
353
  builder.add_edge("tools", "assistant")
354
 
355
  # Compile graph
356
- return builder.compile()
357
-
358
- class Agent():
359
- def __init__(self):
360
- self.graph = build_graph(provider="huggingface")
361
-
362
- def __call__(self, question: str) -> str:
363
- messages = [HumanMessage(content=question)]
364
- result = self.graph.invoke({"messages": messages})
365
-
366
- # Print all messages for debugging
367
- for m in result["messages"]:
368
- m.pretty_print()
369
-
370
- # Return the final answer if found
371
- if result["messages"] and result["messages"][-1].content.startswith("FINAL ANSWER: "):
372
- return result["messages"][-1].content.removeprefix("FINAL ANSWER: ")
373
-
374
- # If no final answer found but we have messages, return the last message
375
- if result["messages"]:
376
- return result["messages"][-1].content
377
-
378
- raise ValueError("No response generated.")
379
-
380
- # test
381
- if __name__ == "__main__":
382
- question = "How many studio albums were published by Mercedes Sosa between 2000 and 2009 (included)? You can use the latest 2022 version of english wikipedia."
383
- # Build the graph
384
- agent = Agent()
385
- print(agent.graph.get_graph().draw_ascii())
386
-
387
- # # Run the graph
388
- answer = agent(question)
389
- print("\n\nSubmitted answer:")
390
- print(answer)
 
1
  """LangGraph Agent"""
 
2
  import os
 
 
3
  from dotenv import load_dotenv
4
+ from langgraph.graph import START, StateGraph, MessagesState
5
  from langgraph.prebuilt import tools_condition
6
  from langgraph.prebuilt import ToolNode
7
  from langchain_google_genai import ChatGoogleGenerativeAI
8
+ from langchain_openai import ChatOpenAI
9
+ from langchain.agents import initialize_agent, Tool
10
  from langchain_groq import ChatGroq
11
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
12
+ from langchain_community.tools.tavily_search import TavilySearchResults
 
 
 
 
13
  from langchain_community.document_loaders import WikipediaLoader
14
  from langchain_community.document_loaders import ArxivLoader
15
  from langchain_community.vectorstores import SupabaseVectorStore
16
+ from langchain_core.messages import SystemMessage, HumanMessage
17
  from langchain_core.tools import tool
18
  from langchain.tools.retriever import create_retriever_tool
19
  from supabase.client import Client, create_client
 
 
 
 
 
 
 
 
 
20
 
21
  load_dotenv()
22
 
 
23
  @tool
24
  def multiply(a: int, b: int) -> int:
25
  """Multiply two numbers.
 
26
  Args:
27
  a: first int
28
  b: second int
29
  """
30
  return a * b
31
 
 
32
  @tool
33
  def add(a: int, b: int) -> int:
34
  """Add two numbers.
 
35
  Args:
36
  a: first int
37
  b: second int
38
  """
39
  return a + b
40
 
 
41
  @tool
42
  def subtract(a: int, b: int) -> int:
43
  """Subtract two numbers.
 
44
  Args:
45
  a: first int
46
  b: second int
47
  """
48
  return a - b
49
 
 
50
  @tool
51
  def divide(a: int, b: int) -> int:
52
  """Divide two numbers.
 
53
  Args:
54
  a: first int
55
  b: second int
 
58
  raise ValueError("Cannot divide by zero.")
59
  return a / b
60
 
 
61
  @tool
62
  def modulus(a: int, b: int) -> int:
63
  """Get the modulus of two numbers.
 
64
  Args:
65
  a: first int
66
  b: second int
67
  """
68
  return a % b
69
 
 
70
  @tool
71
  def wiki_search(query: str) -> str:
72
  """Search Wikipedia for a query and return maximum 2 results.
 
73
  Args:
74
  query: The search query."""
75
  search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
 
77
  [
78
  f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
79
  for doc in search_docs
80
+ ])
 
81
  return {"wiki_results": formatted_search_docs}
82
 
 
83
  @tool
84
  def web_search(query: str) -> str:
85
  """Search Tavily for a query and return maximum 3 results.
 
86
  Args:
87
  query: The search query."""
88
+ search_docs = TavilySearchResults(max_results=3).invoke(query=query)
89
+ formatted_search_docs = "\n\n---\n\n".join(
90
+ [
91
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
92
+ for doc in search_docs
93
+ ])
94
+ return {"web_results": formatted_search_docs}
95
 
96
  @tool
97
  def arvix_search(query: str) -> str:
98
  """Search Arxiv for a query and return maximum 3 result.
 
99
  Args:
100
  query: The search query."""
101
  search_docs = ArxivLoader(query=query, load_max_docs=3).load()
 
103
  [
104
  f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
105
  for doc in search_docs
106
+ ])
 
107
  return {"arvix_results": formatted_search_docs}
108
 
109
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
+ # load the system prompt from the file
112
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
113
  system_prompt = f.read()
114
 
 
116
  sys_msg = SystemMessage(content=system_prompt)
117
 
118
  # build a retriever
119
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
 
 
120
  supabase: Client = create_client(
121
+ os.environ.get("SUPABASE_URL"),
122
+ os.environ.get("SUPABASE_SERVICE_KEY"))
123
  vector_store = SupabaseVectorStore(
124
  client=supabase,
125
+ embedding= embeddings,
126
  table_name="documents",
127
  query_name="match_documents_langchain",
128
  )
 
133
  )
134
 
135
 
136
+
137
  tools = [
138
  multiply,
139
  add,
140
  subtract,
141
  divide,
142
  modulus,
143
+ wiki_search,
144
  web_search,
145
  arvix_search,
146
  ]
147
 
 
 
 
148
  # Build graph function
149
+ def build_graph(provider: str = "groq"):
150
  """Build the graph"""
151
  # Load environment variables from .env file
152
  if provider == "google":
 
154
  llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
155
  elif provider == "groq":
156
  # Groq https://console.groq.com/docs/models
157
+ llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
158
+ elif provider == "openai":
159
+ # OpenAI
160
+ llm = ChatOpenAI(model="gpt-4", temperature=0)
161
  elif provider == "huggingface":
 
162
  llm = ChatHuggingFace(
163
  llm=HuggingFaceEndpoint(
164
+ url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
165
  temperature=0,
166
  ),
167
  )
168
  else:
169
  raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
 
170
  # Bind tools to LLM
171
  llm_with_tools = llm.bind_tools(tools)
172
 
173
+ # Node
174
  def assistant(state: MessagesState):
175
  """Assistant node"""
176
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
177
 
 
178
  def retriever(state: MessagesState):
179
  """Retriever node"""
180
  similar_question = vector_store.similarity_search(state["messages"][0].content)
 
182
  content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
183
  )
184
  return {"messages": [sys_msg] + state["messages"] + [example_msg]}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
 
186
+ builder = StateGraph(MessagesState)
187
+ builder.add_node("retriever", retriever)
 
 
 
188
  builder.add_node("assistant", assistant)
189
  builder.add_node("tools", ToolNode(tools))
190
+ builder.add_edge(START, "retriever")
191
+ builder.add_edge("retriever", "assistant")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
  builder.add_conditional_edges(
193
  "assistant",
194
  tools_condition,
 
196
  builder.add_edge("tools", "assistant")
197
 
198
  # Compile graph
199
+ return builder.compile()