Aya1610 committed on
Commit
55af95a
·
verified ·
1 Parent(s): cc5b899

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +46 -404
agent.py CHANGED
@@ -82,20 +82,36 @@ def wiki_search(query: str) -> str:
82
  ])
83
  return {"wiki_results": formatted_search_docs}
84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  @tool
86
  def web_search(query: str) -> str:
87
- """Search Tavily for a query and return maximum 3 results.
88
 
89
  Args:
90
  query: The search query."""
91
- search_docs = TavilySearchResults(max_results=3).invoke(query=query)
92
- formatted_search_docs = "\n\n---\n\n".join(
 
93
  [
94
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
95
- for doc in search_docs
96
- ])
97
- return {"web_results": formatted_search_docs}
98
-
99
  @tool
100
  def arvix_search(query: str) -> str:
101
  """Search Arxiv for a query and return maximum 3 result.
@@ -127,12 +143,12 @@ supabase: Client = create_client(
127
  vector_store = SupabaseVectorStore(
128
  client=supabase,
129
  embedding= embeddings,
130
- table_name="question",
131
  query_name="match_documents_langchain",
132
  )
133
  create_retriever_tool = create_retriever_tool(
134
  retriever=vector_store.as_retriever(),
135
- name="Question Search",
136
  description="A tool to retrieve similar questions from a vector store.",
137
  )
138
 
@@ -149,29 +165,35 @@ tools = [
149
  arvix_search,
150
  ]
151
 
152
- # Build graph function
153
- def build_graph(provider: str = "google"):
154
- """Build the graph"""
155
- # Load environment variables from .env file
156
- if provider == "google":
157
- # Google Gemini
158
- llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
159
- elif provider == "groq":
160
- # Groq https://console.groq.com/docs/models
161
- llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
162
  elif provider == "huggingface":
163
- # TODO: Add huggingface endpoint
 
 
 
164
  llm = ChatHuggingFace(
165
  llm=HuggingFaceEndpoint(
166
  url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
167
  temperature=0,
168
- ),
169
  )
 
170
  else:
171
- raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
 
172
  # Bind tools to LLM
173
  llm_with_tools = llm.bind_tools(tools)
174
 
 
 
175
  # Node
176
  def assistant(state: MessagesState):
177
  """Assistant node"""
@@ -214,390 +236,10 @@ def build_graph(provider: str = "google"):
214
  builder = StateGraph(MessagesState)
215
  builder.add_node("retriever", retriever)
216
 
217
- # Retriever ist Start und Endpunkt
218
  builder.set_entry_point("retriever")
219
  builder.set_finish_point("retriever")
220
 
221
  # Compile graph
222
  return builder.compile()
223
 
224
- # GAIA Agent Solution with LangGraph and OpenAI - Standalone Version
225
- # import os
226
- # from dotenv import load_dotenv
227
- # from langgraph.graph import START, StateGraph, MessagesState
228
- # from langgraph.prebuilt import tools_condition
229
- # from langgraph.prebuilt import ToolNode
230
- # from langchain_google_genai import ChatGoogleGenerativeAI
231
- # from langchain_groq import ChatGroq
232
- # from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
233
- # from langchain_community.tools.tavily_search import TavilySearchResults
234
- # from langchain_community.document_loaders import WikipediaLoader
235
- # from langchain_community.document_loaders import ArxivLoader
236
- # from langchain_community.vectorstores import SupabaseVectorStore
237
- # from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
238
- # from langchain_core.tools import tool
239
- # from langchain.tools.retriever import create_retriever_tool
240
- # from supabase.client import Client, create_client
241
- # load_dotenv()
242
-
243
- # # --- Supabase Setup (only if credentials are provided) ---
244
- # supabase_url = os.getenv("SUPABASE_URL")
245
- # supabase_key = os.getenv("SUPABASE_SERVICE_KEY") or os.getenv("SUPABASE_KEY")
246
-
247
- # if supabase_url and supabase_key:
248
- # from supabase.client import Client, create_client
249
- # from langchain_community.vectorstores import SupabaseVectorStore
250
- # from langchain.tools.retriever import create_retriever_tool
251
- # from langchain_openai import OpenAIEmbeddings
252
- # supabase: Client = create_client(supabase_url, supabase_key)
253
- # else:
254
- # supabase = None
255
-
256
- # # --- Standard Imports ---
257
-
258
-
259
- # # OpenAI LLM
260
- # from langchain_openai import ChatOpenAI
261
-
262
- # # Optional document loaders
263
- # from langchain_community.tools.tavily_search import TavilySearchResults
264
- # from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
265
-
266
- # # --- Simple Math Tools ---
267
- # @tool
268
- # def multiply(a: int, b: int) -> int:
269
- # """Multiply two integers and return the result"""
270
- # return a * b
271
-
272
- # @tool
273
- # def add(a: int, b: int) -> int:
274
- # """Add two integers and return the sum"""
275
- # return a + b
276
-
277
- # @tool
278
- # def subtract(a: int, b: int) -> int:
279
- # """Subtract the second integer from the first and return the difference"""
280
- # return a - b
281
-
282
- # @tool
283
- # def divide(a: int, b: int) -> float:
284
- # """Divide the first integer by the second and return the quotient"""
285
- # if b == 0:
286
- # raise ValueError("Cannot divide by zero.")
287
- # return a / b
288
-
289
- # @tool
290
- # def modulus(a: int, b: int) -> int:
291
- # """Return the modulus of dividing the first integer by the second"""
292
- # return a % b
293
-
294
- # # --- Search Tools ---
295
- # @tool
296
- # def wiki_search(query: str) -> str:
297
- # """Search Wikipedia for the query and return up to 2 documents"""
298
- # try:
299
- # docs = WikipediaLoader(query=query, load_max_docs=2).load()
300
- # return "\n\n---\n\n".join(
301
- # f'<Document source="{doc.metadata["source"]}"/>\n{doc.page_content}' for doc in docs
302
- # )
303
- # except Exception as e:
304
- # return f"Wikipedia search failed: {str(e)}"
305
-
306
- # @tool
307
- # def web_search(query: str) -> str:
308
- # """Search the web using Tavily and return up to 3 results"""
309
- # try:
310
- # tavily_api_key = os.getenv("search")
311
- # if not tavily_api_key:
312
- # return "Web search unavailable: TAVILY_API_KEY not configured"
313
-
314
- # search_tool = TavilySearchResults(max_results=3, api_key=tavily_api_key)
315
- # docs = search_tool.invoke({"query": query})
316
- # return "\n\n---\n\n".join(
317
- # f'<Document source="{doc.get("url", "Unknown")}"/>\n{doc.get("content", "")}' for doc in docs
318
- # )
319
- # except Exception as e:
320
- # return f"Web search failed: {str(e)}"
321
-
322
- # @tool
323
- # def arxiv_search(query: str) -> str:
324
- # """Search Arxiv for the query and return up to 3 documents"""
325
- # try:
326
- # docs = ArxivLoader(query=query, load_max_docs=3).load()
327
- # return "\n\n---\n\n".join(
328
- # f'<Document source="{doc.metadata["source"]}"/>\n{doc.page_content[:1000]}' for doc in docs
329
- # )
330
- # except Exception as e:
331
- # return f"Arxiv search failed: {str(e)}"
332
-
333
- # # --- Assemble Tools List ---
334
- # tools = [multiply, add, subtract, divide, modulus, wiki_search, web_search, arxiv_search]
335
-
336
- # # If supabase is configured, add retriever tool
337
- # if supabase:
338
- # try:
339
- # embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
340
- # # embeddings = OpenAIEmbeddings()
341
- # vector_store = SupabaseVectorStore(
342
- # client=supabase,
343
- # embedding=embeddings,
344
- # table_name="question",
345
- # query_name="match_documents_langchain",
346
- # )
347
- # retriever_tool = create_retriever_tool(
348
- # retriever=vector_store.as_retriever(),
349
- # name="question_search",
350
- # description="Retrieve similar questions from the vector store",
351
- # )
352
- # tools.append(retriever_tool)
353
- # except Exception as e:
354
- # print(f"Could not initialize Supabase retriever: {e}")
355
-
356
- # # --- Load System Prompt ---
357
- # def load_system_prompt():
358
- # """Load system prompt with fallback"""
359
- # try:
360
- # with open("system_prompt.txt", "r", encoding="utf-8") as f:
361
- # return SystemMessage(content=f.read())
362
- # except FileNotFoundError:
363
- # # Fallback system prompt
364
- # default_prompt = """You are a helpful AI assistant with access to various tools including:
365
- # - Math operations (add, subtract, multiply, divide, modulus)
366
- # - Search capabilities (Wikipedia, Arxiv, web search via Tavily)
367
- # - Information retrieval
368
-
369
- # Use these tools when appropriate to answer questions accurately and helpfully. When performing calculations, always use the provided math tools. When users ask for information that might require current data or research, use the appropriate search tools.
370
-
371
- # Be concise but thorough in your responses. If you use a tool, explain what you found or calculated."""
372
- # return SystemMessage(content=default_prompt)
373
-
374
- # sys_msg = load_system_prompt()
375
-
376
- # # --- Graph Builder (OpenAI) ---
377
- # def build_graph(provider: str = "google"):
378
- # """Build the graph"""
379
- # # Load environment variables from .env file
380
- # if provider == "google":
381
- # # Google Gemini
382
- # llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
383
- # elif provider == "groq":
384
- # # Groq https://console.groq.com/docs/models
385
- # llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
386
- # elif provider == "huggingface":
387
- # # TODO: Add huggingface endpoint
388
- # llm = ChatHuggingFace(
389
- # llm=HuggingFaceEndpoint(
390
- # url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
391
- # temperature=0,
392
- # ),
393
- # )
394
- # else:
395
- # raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
396
- # # Bind tools to LLM
397
- # llm_with_tools = llm.bind_tools(tools)
398
-
399
- # # Node
400
- # def assistant(state: MessagesState):
401
- # """Assistant node"""
402
- # return {"messages": [llm_with_tools.invoke(state["messages"])]}
403
-
404
- # # def retriever(state: MessagesState):
405
- # # """Retriever node"""
406
- # # similar_question = vector_store.similarity_search(state["messages"][0].content)
407
- # #example_msg = HumanMessage(
408
- # # content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
409
- # # )
410
- # # return {"messages": [sys_msg] + state["messages"] + [example_msg]}
411
-
412
- # # from langchain_core.messages import AIMessage
413
-
414
- # def retriever(state: MessagesState):
415
- # query = state["messages"][-1].content
416
- # similar_doc = vector_store.similarity_search(query, k=1)[0]
417
-
418
- # content = similar_doc.page_content
419
- # if "Final answer :" in content:
420
- # answer = content.split("Final answer :")[-1].strip()
421
- # else:
422
- # answer = content.strip()
423
-
424
- # return {"messages": [AIMessage(content=answer)]}
425
-
426
- # # builder = StateGraph(MessagesState)
427
- # #builder.add_node("retriever", retriever)
428
- # #builder.add_node("assistant", assistant)
429
- # #builder.add_node("tools", ToolNode(tools))
430
- # #builder.add_edge(START, "retriever")
431
- # #builder.add_edge("retriever", "assistant")
432
- # #builder.add_conditional_edges(
433
- # # "assistant",
434
- # # tools_condition,
435
- # #)
436
- # #builder.add_edge("tools", "assistant")
437
-
438
- # builder = StateGraph(MessagesState)
439
- # builder.add_node("retriever", retriever)
440
-
441
- # # Retriever ist Start und Endpunkt
442
- # builder.set_entry_point("retriever")
443
- # builder.set_finish_point("retriever")
444
-
445
- # # Compile graph
446
- # return builder.compile()
447
- # # def build_graph():
448
- # # """
449
- # # Build and return a StateGraph using OpenAI ChatGPT with tools.
450
- # # """
451
- # # print("=== BUILDING OPENAI GRAPH ===")
452
-
453
- # # # Check for OpenAI API key
454
- # # openai_api_key = os.getenv("OPENAI_API_KEY")
455
- # # print(f"OpenAI API Key: {'Found' if openai_api_key else 'Not found'}")
456
-
457
- # # if openai_api_key:
458
- # # print(f"API Key starts with: {openai_api_key[:10]}...")
459
-
460
- # # try:
461
- # # if openai_api_key and len(openai_api_key.strip()) > 0:
462
- # # print("Attempting to initialize OpenAI ChatGPT...")
463
-
464
- # # # Initialize OpenAI LLM
465
- # # llm = ChatOpenAI(
466
- # # model="gpt-3.5-turbo", # You can change to "gpt-4" if you have access
467
- # # temperature=0.1,
468
- # # api_key=openai_api_key.strip(),
469
- # # max_tokens=512
470
- # # )
471
-
472
- # # # Test the connection
473
- # # test_response = llm.invoke([HumanMessage(content="Hello")])
474
- # # print("✓ Successfully connected to OpenAI")
475
- # # print(f"Test response: {test_response.content[:50]}...")
476
-
477
- # # else:
478
- # # raise Exception("No valid OPENAI_API_KEY found")
479
-
480
- # # except Exception as e:
481
- # # print(f"Error initializing OpenAI LLM: {e}")
482
- # # print("Creating functional mock LLM...")
483
-
484
- # # class FunctionalMockLLM:
485
- # # def bind_tools(self, tools):
486
- # # self.tools = tools
487
- # # return self
488
-
489
- # # def invoke(self, messages):
490
- # # from langchain_core.messages import AIMessage
491
- # # import json
492
- # # import re
493
-
494
- # # last_msg = messages[-1] if messages else None
495
- # # if not last_msg:
496
- # # return AIMessage(content="Please ask me a question!")
497
-
498
- # # content = getattr(last_msg, 'content', str(last_msg))
499
- # # content_lower = content.lower()
500
-
501
- # # # Handle math operations with tool calls
502
- # # math_patterns = [
503
- # # (r'(\d+)\s*\+\s*(\d+)', 'add'),
504
- # # (r'(\d+)\s*-\s*(\d+)', 'subtract'),
505
- # # (r'(\d+)\s*\*\s*(\d+)', 'multiply'),
506
- # # (r'(\d+)\s*/\s*(\d+)', 'divide'),
507
- # # (r'(\d+)\s*%\s*(\d+)', 'modulus'),
508
- # # ]
509
-
510
- # # for pattern, operation in math_patterns:
511
- # # match = re.search(pattern, content)
512
- # # if match:
513
- # # a, b = int(match.group(1)), int(match.group(2))
514
-
515
- # # tool_call = {
516
- # # "name": operation,
517
- # # "args": {"a": a, "b": b},
518
- # # "id": f"call_{operation}_{a}_{b}"
519
- # # }
520
-
521
- # # return AIMessage(
522
- # # content=f"I'll {operation} {a} and {b} for you.",
523
- # # tool_calls=[tool_call]
524
- # # )
525
-
526
- # # # Handle search requests
527
- # # if any(word in content_lower for word in ['search', 'find', 'look up', 'what is', 'who is', 'tell me about']):
528
- # # # Extract search query
529
- # # search_query = content
530
- # # for phrase in ['search for', 'find', 'look up', 'what is', 'who is', 'tell me about']:
531
- # # search_query = search_query.lower().replace(phrase, '').strip()
532
-
533
- # # if len(search_query) > 100:
534
- # # search_query = search_query[:100]
535
-
536
- # # if 'wikipedia' in content_lower:
537
- # # tool_name = "wiki_search"
538
- # # elif 'arxiv' in content_lower or 'research' in content_lower or 'paper' in content_lower:
539
- # # tool_name = "arxiv_search"
540
- # # else:
541
- # # tool_name = "web_search"
542
-
543
- # # tool_call = {
544
- # # "name": tool_name,
545
- # # "args": {"query": search_query},
546
- # # "id": f"call_{tool_name}_{hash(search_query) % 1000}"
547
- # # }
548
-
549
- # # return AIMessage(
550
- # # content=f"I'll search for information about: {search_query}",
551
- # # tool_calls=[tool_call]
552
- # # )
553
-
554
- # # # Default response for other questions
555
- # # return AIMessage(content=f"I understand you're asking: {content[:200]}... I can help with math calculations and information searches. Please configure OPENAI_API_KEY for full functionality, or try asking me to calculate something or search for information.")
556
-
557
- # # llm = FunctionalMockLLM()
558
- # # print("✓ Using functional mock LLM")
559
-
560
- # # # Bind tools to LLM
561
- # # llm_with_tools = llm.bind_tools(tools)
562
-
563
- # # def retriever(state: MessagesState):
564
- # # """Add system message and handle retrieval if Supabase is available"""
565
- # # messages = [sys_msg] + state["messages"]
566
-
567
- # # if supabase and len(tools) > 8: # Check if retriever tool was added
568
- # # try:
569
- # # query = state["messages"][-1].content
570
- # # docs = vector_store.similarity_search(query, k=1)
571
- # # if docs:
572
- # # doc = docs[0]
573
- # # content = doc.page_content
574
- # # answer = content.split("Final answer :")[-1].strip() if "Final answer :" in content else content.strip()
575
- # # return {"messages": messages + [AIMessage(content=f"Retrieved context: {answer}")]}
576
- # # except Exception as e:
577
- # # print(f"Retrieval error: {e}")
578
-
579
- # # return {"messages": messages}
580
-
581
- # # def assistant(state: MessagesState):
582
- # # """Main assistant function"""
583
- # # try:
584
- # # response = llm_with_tools.invoke(state["messages"])
585
- # # return {"messages": [response]}
586
- # # except Exception as e:
587
- # # print(f"Assistant error: {e}")
588
- # # return {"messages": [AIMessage(content=f"I encountered an error: {str(e)}. Please make sure your OPENAI_API_KEY is configured correctly.")]}
589
-
590
- # # # Build the graph
591
- # # g = StateGraph(MessagesState)
592
- # # g.add_node("retriever", retriever)
593
- # # g.add_node("assistant", assistant)
594
- # # g.add_node("tools", ToolNode(tools))
595
-
596
- # # # Define edges
597
- # # g.add_edge(START, "retriever")
598
- # # g.add_edge("retriever", "assistant")
599
- # # g.add_conditional_edges("assistant", tools_condition)
600
- # # g.add_edge("tools", "assistant")
601
-
602
- # # print("✓ Graph compiled successfully")
603
- # # return g.compile()
 
82
  ])
83
  return {"wiki_results": formatted_search_docs}
84
 
85
+ # @tool
86
+ # def web_search(query: str) -> str:
87
+ # """Search Tavily for a query and return maximum 3 results.
88
+
89
+ # Args:
90
+ # query: The search query."""
91
+ # search_docs = TavilySearchResults(max_results=3).invoke(query=query)
92
+ # formatted_search_docs = "\n\n---\n\n".join(
93
+ # [
94
+ # f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
95
+ # for doc in search_docs
96
+ # ])
97
+ # return {"web_results": formatted_search_docs}
98
+ # from langchain_community.utilities import DuckDuckGoSearchAPIWrapper
99
+
100
@tool
def web_search(query: str) -> str:
    """Search the web via DuckDuckGo and return at most 3 results.

    Args:
        query: The search query.

    Returns:
        A string with the title, URL and snippet of each result,
        separated by "---" dividers, or a notice when nothing was found.
    """
    # Imported locally: the module-level import of DuckDuckGoSearchAPIWrapper
    # was left commented out in this revision, so without this line the call
    # below raises NameError at runtime.
    from langchain_community.utilities import DuckDuckGoSearchAPIWrapper

    search = DuckDuckGoSearchAPIWrapper()
    results = search.results(query, 3)
    if not results:
        # Guard the empty case so callers get a readable message instead of "".
        return "No results found."
    return "\n\n---\n\n".join(
        f"Title: {res['title']}\nURL: {res['link']}\nSnippet: {res['snippet']}"
        for res in results
    )
115
  @tool
116
  def arvix_search(query: str) -> str:
117
  """Search Arxiv for a query and return maximum 3 result.
 
143
# Vector store over the Supabase "docs" table; similarity lookups go through
# the "match_documents_langchain" SQL function.
vector_store = SupabaseVectorStore(
    client=supabase,
    embedding=embeddings,
    table_name="docs",
    query_name="match_documents_langchain",
)

# A retrieval tool that surfaces similar questions from the vector store.
retriever_tool = create_retriever_tool(
    retriever=vector_store.as_retriever(),
    name="question_search",
    description="A tool to retrieve similar questions from a vector store.",
)
# Backward-compatible alias: the previous revision rebound the factory's own
# name to its result, shadowing the imported create_retriever_tool function.
# Keep the old name so any existing references still resolve.
create_retriever_tool = retriever_tool
154
 
 
165
  arvix_search,
166
  ]
167
 
168
+
169
+ def build_graph(provider: str = "openai"):
170
+ """Build the graph using OpenAI or Hugging Face"""
171
+
172
+ if provider == "openai":
173
+ # OpenAI ChatGPT (e.g., GPT-4 or GPT-3.5)
174
+ from langchain.chat_models import ChatOpenAI
175
+ llm = ChatOpenAI(model="gpt-4", temperature=0)
176
+
 
177
  elif provider == "huggingface":
178
+ # Hugging Face endpoint
179
+ from langchain.chat_models import ChatHuggingFace
180
+ from langchain.llms import HuggingFaceEndpoint
181
+
182
  llm = ChatHuggingFace(
183
  llm=HuggingFaceEndpoint(
184
  url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
185
  temperature=0,
186
+ )
187
  )
188
+
189
  else:
190
+ raise ValueError("Invalid provider. Choose 'openai' or 'huggingface'.")
191
+
192
  # Bind tools to LLM
193
  llm_with_tools = llm.bind_tools(tools)
194
 
195
+ return llm_with_tools
196
+
197
  # Node
198
  def assistant(state: MessagesState):
199
  """Assistant node"""
 
236
  builder = StateGraph(MessagesState)
237
  builder.add_node("retriever", retriever)
238
 
239
+
240
  builder.set_entry_point("retriever")
241
  builder.set_finish_point("retriever")
242
 
243
  # Compile graph
244
  return builder.compile()
245