mdicio committed on
Commit
35c8e46
·
1 Parent(s): e8cec3a

back to ggroq

Browse files
Files changed (2) hide show
  1. agent.py +342 -53
  2. app.py +9 -7
agent.py CHANGED
@@ -1,46 +1,265 @@
 
 
1
  import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
 
3
  from datasets import load_dataset
4
  from dotenv import load_dotenv
5
- from langchain_community.embeddings import HuggingFaceEmbeddings
6
  from langchain.schema import Document
7
  from langchain.tools.retriever import create_retriever_tool
8
  from langchain.vectorstores import Chroma
9
  from langchain_community.document_loaders import ArxivLoader, WikipediaLoader
 
10
  from langchain_community.tools.tavily_search import TavilySearchResults
11
- from langchain_core.messages import HumanMessage, SystemMessage
12
  from langchain_core.tools import tool
13
  from langchain_google_genai import ChatGoogleGenerativeAI
14
  from langchain_groq import ChatGroq
15
- from langchain_huggingface import (ChatHuggingFace, HuggingFaceEmbeddings,
16
- HuggingFaceEndpoint)
17
  from langgraph.graph import START, MessagesState, StateGraph
18
  from langgraph.prebuilt import ToolNode, tools_condition
19
- from huggingface_hub import login
20
 
21
  login(token=os.environ["HUGGINGFACE_TOKEN"])
22
 
23
  load_dotenv()
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  @tool
26
- def calculator(query: str) -> str:
27
- """Perform basic arithmetic operations based on the provided query.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
 
 
 
 
29
  Args:
30
- query: A mathematical query as a string, e.g., '2 + 2' or '5 * 6'."""
 
 
 
 
 
 
 
 
 
 
 
31
 
 
 
 
 
 
 
 
 
 
 
 
32
  try:
33
- # Evaluate the mathematical expression
34
- result = eval(query)
35
- return {"calculator_result": str(result)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  except Exception as e:
37
- return {"error": f"Error evaluating the expression: {str(e)}"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
 
40
  @tool
41
  def wiki_search(query: str) -> str:
42
  """Search Wikipedia for a query and return maximum 2 results.
43
-
44
  Args:
45
  query: The search query."""
46
  search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
@@ -56,7 +275,6 @@ def wiki_search(query: str) -> str:
56
  @tool
57
  def web_search(query: str) -> str:
58
  """Search Tavily for a query and return maximum 3 results.
59
-
60
  Args:
61
  query: The search query."""
62
  search_docs = TavilySearchResults(max_results=3).invoke(query=query)
@@ -70,9 +288,8 @@ def web_search(query: str) -> str:
70
 
71
 
72
  @tool
73
- def arvix_search(query: str) -> str:
74
  """Search Arxiv for a query and return maximum 3 result.
75
-
76
  Args:
77
  query: The search query."""
78
  search_docs = ArxivLoader(query=query, load_max_docs=3).load()
@@ -82,7 +299,74 @@ def arvix_search(query: str) -> str:
82
  for doc in search_docs
83
  ]
84
  )
85
- return {"arvix_results": formatted_search_docs}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
 
88
  system_prompt = """You are a helpful assistant tasked with answering questions using a set of tools.
@@ -100,7 +384,13 @@ embeddings = HuggingFaceEmbeddings(
100
  ) # dim=768
101
 
102
  # Load the GAIA validation dataset
103
- dataset = load_dataset("gaia-benchmark/GAIA", name="2023_level1", split="validation", trust_remote_code=True, cache_dir = "ragdata")
 
 
 
 
 
 
104
 
105
  # Extract questions and their answers
106
  documents = []
@@ -120,7 +410,7 @@ for entry in dataset:
120
  documents.append(Document(page_content=question, metadata=metadata))
121
 
122
  # Insert the documents into Chroma
123
- vectorstore = Chroma.from_documents(
124
  documents=documents,
125
  embedding=embeddings,
126
  collection_name="gaia_validation",
@@ -128,17 +418,29 @@ vectorstore = Chroma.from_documents(
128
  )
129
 
130
  create_retriever_tool = create_retriever_tool(
131
- retriever=vectorstore.as_retriever(),
132
  name="Question Search",
133
  description="A tool to retrieve similar questions from a vector store.",
134
  )
135
 
136
 
137
  tools = [
138
- calculator,
139
- wiki_search,
140
  web_search,
141
- arvix_search,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  ]
143
 
144
 
@@ -156,10 +458,10 @@ def build_graph(provider: str = "groq"):
156
  ) # optional : qwen-qwq-32b gemma2-9b-it
157
  elif provider == "huggingface":
158
  # TODO: Add huggingface endpoint
159
- llm=HuggingFaceEndpoint(
160
- repo_id="Meta-DeepLearning/llama-2-7b-chat-hf",
161
- temperature=0,
162
- )
163
  else:
164
  raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
165
  # Bind tools to LLM
@@ -172,35 +474,22 @@ def build_graph(provider: str = "groq"):
172
 
173
  def retriever(state: MessagesState):
174
  """Retriever node"""
175
- similar_question = vectorstore.similarity_search(state["messages"][0].content)
176
  example_msg = HumanMessage(
177
  content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
178
  )
179
  return {"messages": [sys_msg] + state["messages"] + [example_msg]}
180
 
181
  builder = StateGraph(MessagesState)
182
- builder.add_node("retriever", retriever)
183
- builder.add_node("assistant", assistant)
184
- builder.add_node("tools", ToolNode(tools))
185
- builder.add_edge(START, "retriever")
186
- builder.add_edge("retriever", "assistant")
187
- builder.add_conditional_edges(
188
- "assistant",
189
- tools_condition,
190
- )
191
- builder.add_edge("tools", "assistant")
192
-
193
- # Compile graph
194
- return builder.compile()
195
-
196
-
197
- # test
198
- if __name__ == "__main__":
199
- question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
200
- # Build the graph
201
- graph = build_graph(provider="groq")
202
- # Run the graph
203
- messages = [HumanMessage(content=question)]
204
- messages = graph.invoke({"messages": messages})
205
- for m in messages["messages"]:
206
- m.pretty_print()
 
1
+ import cmath
2
+ import json
3
  import os
4
+ import re
5
+ import tempfile
6
+ import uuid
7
+ from typing import Any, Dict, List, Optional
8
+ from urllib.parse import urlparse
9
+
10
+ import numpy as np
11
+ import pandas as pd
12
+ import pytesseract
13
+ import requests
14
+ from code_interpreter import CodeInterpreter
15
+ from dotenv import load_dotenv
16
+ from PIL import Image, ImageDraw, ImageEnhance, ImageFilter, ImageFont
17
+
18
+ interpreter_instance = CodeInterpreter()
19
+
20
+ from image_processing import *
21
 
22
+ """Langraph"""
23
  from datasets import load_dataset
24
  from dotenv import load_dotenv
25
+ from huggingface_hub import login
26
  from langchain.schema import Document
27
  from langchain.tools.retriever import create_retriever_tool
28
  from langchain.vectorstores import Chroma
29
  from langchain_community.document_loaders import ArxivLoader, WikipediaLoader
30
+ from langchain_community.embeddings import HuggingFaceEmbeddings
31
  from langchain_community.tools.tavily_search import TavilySearchResults
32
+ from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
33
  from langchain_core.tools import tool
34
  from langchain_google_genai import ChatGoogleGenerativeAI
35
  from langchain_groq import ChatGroq
36
+ from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint
 
37
  from langgraph.graph import START, MessagesState, StateGraph
38
  from langgraph.prebuilt import ToolNode, tools_condition
39
+ from supabase.client import Client, create_client
40
 
41
  login(token=os.environ["HUGGINGFACE_TOKEN"])
42
 
43
  load_dotenv()
44
 
45
+
46
+ @tool
47
def multiply(a: float, b: float) -> float:
    """
    Return the product of two numbers.
    Args:
        a (float): the multiplicand
        b (float): the multiplier
    """
    product = a * b
    return product
55
+
56
+
57
+ @tool
58
def add(a: float, b: float) -> float:
    """
    Return the sum of two numbers.
    Args:
        a (float): the first addend
        b (float): the second addend
    """
    total = a + b
    return total
66
+
67
+
68
+ @tool
69
def subtract(a: float, b: float) -> float:
    """
    Subtracts the second number from the first.
    Args:
        a (float): the number to subtract from
        b (float): the number to subtract
    Returns:
        float: the difference a - b
    """
    # The result of subtracting two floats is a float, so the return
    # annotation is float (it was previously declared as int).
    return a - b
77
+
78
+
79
+ @tool
80
def divide(a: float, b: float) -> float:
    """
    Divides the first number by the second.
    Args:
        a (float): the dividend
        b (float): the divisor
    Returns:
        float: the quotient a / b
    Raises:
        ValueError: if b is zero
    """
    if b == 0:
        # Raise a ValueError with a readable message instead of letting the
        # bare ZeroDivisionError propagate.
        raise ValueError("Cannot divide by zero.")
    return a / b
90
+
91
+
92
  @tool
93
def modulus(a: int, b: int) -> int:
    """
    Return the remainder of dividing one number by another.
    Args:
        a (int): the dividend
        b (int): the divisor
    """
    remainder = a % b
    return remainder
101
+
102
+
103
+ @tool
104
def power(a: float, b: float) -> float:
    """
    Raise a base to an exponent.
    Args:
        a (float): the base
        b (float): the exponent
    """
    return pow(a, b)
112
+
113
+
114
+ @tool
115
+ def square_root(a: float) -> float | complex:
116
+ """
117
+ Get the square root of a number.
118
+ Args:
119
+ a (float): the number to get the square root of
120
+ """
121
+ if a >= 0:
122
+ return a**0.5
123
+ return cmath.sqrt(a)
124
+
125
+
126
+ ### =============== DOCUMENT PROCESSING TOOLS =============== ###
127
+
128
 
129
+ @tool
130
def save_and_read_file(content: str, filename: Optional[str] = None) -> str:
    """
    Save content to a file in the system temp directory and return the path.
    Args:
        content (str): the content to save to the file
        filename (str, optional): the name of the file. Only its final path
            component is used, so the file always lands inside the temp
            directory. If not provided (or empty), a random name file will
            be created.
    Returns:
        str: a message containing the path of the written file
    """
    temp_dir = tempfile.gettempdir()

    # Drop any directory part so a name such as "../../x" or "/etc/x" cannot
    # place the file outside the temp directory.
    safe_name = os.path.basename(filename) if filename else ""
    if safe_name in (".", ".."):
        safe_name = ""

    if safe_name:
        filepath = os.path.join(temp_dir, safe_name)
    else:
        # mkstemp returns an open descriptor; close it immediately so the
        # handle is not leaked and the path can be reopened on any platform.
        fd, filepath = tempfile.mkstemp(dir=temp_dir)
        os.close(fd)

    # Explicit encoding keeps the output independent of the platform default.
    with open(filepath, "w", encoding="utf-8") as f:
        f.write(content)

    return f"File saved to {filepath}. You can read this file to process its contents."
148
+
149
+
150
+ @tool
151
def download_file_from_url(url: str, filename: Optional[str] = None) -> str:
    """
    Download a file from a URL and save it to a temporary location.
    Args:
        url (str): the URL of the file to download.
        filename (str, optional): the name of the file. Only its final path
            component is used. If not provided, the name is taken from the
            URL path, falling back to a random name.
    Returns:
        str: a message with the saved path, or an error message if the
        download failed.
    """
    try:
        # Derive the name from the URL path when the caller gave none.
        if not filename:
            filename = os.path.basename(urlparse(url).path)

        # Keep only the last path component so the file cannot be written
        # outside the temp directory.
        filename = os.path.basename(filename)
        if not filename or filename in (".", ".."):
            filename = f"downloaded_{uuid.uuid4().hex[:8]}"

        filepath = os.path.join(tempfile.gettempdir(), filename)

        # A timeout prevents an unresponsive server from blocking forever;
        # the context manager releases the streamed connection when done.
        with requests.get(url, stream=True, timeout=30) as response:
            response.raise_for_status()

            with open(filepath, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:  # skip keep-alive chunks
                        f.write(chunk)

        return f"File downloaded to {filepath}. You can read this file to process its contents."
    except Exception as e:
        # Tool boundary: report the failure as text instead of raising.
        return f"Error downloading file: {str(e)}"
182
+
183
+
184
+ @tool
185
def extract_text_from_image(image_path: str) -> str:
    """
    Extract text from an image using OCR library pytesseract (if available).
    Args:
        image_path (str): the path to the image file.
    Returns:
        str: the extracted text, or an error message if extraction failed.
    """
    try:
        # The context manager closes the underlying file handle once OCR is
        # done instead of leaving it open until garbage collection.
        with Image.open(image_path) as image:
            text = pytesseract.image_to_string(image)

        return f"Extracted text from image:\n\n{text}"
    except Exception as e:
        # Tool boundary: report the failure as text instead of raising.
        return f"Error extracting text from image: {str(e)}"
201
+
202
+
203
+ @tool
204
def analyze_csv_file(file_path: str, query: str) -> str:
    """
    Load a CSV file with pandas and return a textual overview of it.
    Args:
        file_path (str): the path to the CSV file.
        query (str): Question about the data
    """
    try:
        frame = pd.read_csv(file_path)
        row_count, column_count = len(frame), len(frame.columns)

        # Assemble the report pieces in order and concatenate once.
        report_parts = [
            f"CSV file loaded with {row_count} rows and {column_count} columns.\n",
            f"Columns: {', '.join(frame.columns)}\n\n",
            "Summary statistics:\n",
            str(frame.describe()),
        ]
        return "".join(report_parts)

    except Exception as e:
        return f"Error analyzing CSV file: {str(e)}"
227
+
228
+
229
+ @tool
230
def analyze_excel_file(file_path: str, query: str) -> str:
    """
    Load an Excel file with pandas and return a textual overview of it.
    Args:
        file_path (str): the path to the Excel file.
        query (str): Question about the data. Not used by the current
            implementation; the same overview is returned for every query.
    Returns:
        str: row/column counts, column labels and summary statistics, or an
        error message if the file could not be analyzed.
    """
    try:
        # Read the Excel file
        df = pd.read_excel(file_path)

        # Header cells may hold numbers or dates, so convert every label to
        # text before joining; str.join rejects non-string items.
        column_labels = ", ".join(str(label) for label in df.columns)

        result = (
            f"Excel file loaded with {len(df)} rows and {len(df.columns)} columns.\n"
        )
        result += f"Columns: {column_labels}\n\n"

        # Add summary statistics
        result += "Summary statistics:\n"
        result += str(df.describe())

        return result

    except Exception as e:
        # Tool boundary: report the failure as text instead of raising.
        return f"Error analyzing Excel file: {str(e)}"
255
+
256
+
257
+ ### ============== IMAGE PROCESSING AND GENERATION TOOLS =============== ###
258
 
259
 
260
  @tool
261
  def wiki_search(query: str) -> str:
262
  """Search Wikipedia for a query and return maximum 2 results.
 
263
  Args:
264
  query: The search query."""
265
  search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
 
275
  @tool
276
  def web_search(query: str) -> str:
277
  """Search Tavily for a query and return maximum 3 results.
 
278
  Args:
279
  query: The search query."""
280
  search_docs = TavilySearchResults(max_results=3).invoke(query=query)
 
288
 
289
 
290
  @tool
291
+ def arxiv_search(query: str) -> str:
292
  """Search Arxiv for a query and return maximum 3 result.
 
293
  Args:
294
  query: The search query."""
295
  search_docs = ArxivLoader(query=query, load_max_docs=3).load()
 
299
  for doc in search_docs
300
  ]
301
  )
302
+ return {"arxiv_results": formatted_search_docs}
303
+
304
+
305
+ ### =============== CODE INTERPRETER TOOLS =============== ###
306
+
307
+
308
+ @tool
309
def execute_code_multilang(code: str, language: str = "python") -> str:
    """Run a code snippet through the shared interpreter and summarize the outcome.
    Args:
        code (str): The source code to execute.
        language (str): The language of the code. Supported: "python", "bash", "sql", "c", "java".
    Returns:
        A string summarizing the execution results (stdout, stderr, errors, plots, dataframes if any).
    """
    known_languages = ["python", "bash", "sql", "c", "java"]
    language = language.lower()

    # Reject unknown languages before touching the interpreter.
    if language not in known_languages:
        return f"❌ Unsupported language: {language}. Supported languages are: {', '.join(known_languages)}"

    result = interpreter_instance.execute_code(code, language=language)

    # Failure path first: headline plus the error log when there is one.
    if result["status"] != "success":
        lines = [f"❌ Code execution failed in **{language.upper()}**"]
        if result.get("stderr"):
            lines.append(f"\n**Error Log:**\n```\n{result['stderr'].strip()}\n```")
        return "\n".join(lines)

    lines = [f"✅ Code executed successfully in **{language.upper()}**"]

    if result.get("stdout"):
        lines.append(f"\n**Standard Output:**\n```\n{result['stdout'].strip()}\n```")

    if result.get("stderr"):
        lines.append(
            f"\n**Standard Error (if any):**\n```\n{result['stderr'].strip()}\n```"
        )

    if result.get("result") is not None:
        lines.append(
            f"\n**Execution Result:**\n```\n{str(result['result']).strip()}\n```"
        )

    if result.get("dataframes"):
        for frame_info in result["dataframes"]:
            lines.append(
                f"\n**DataFrame `{frame_info['name']}` (Shape: {frame_info['shape']})**"
            )
            preview = pd.DataFrame(frame_info["head"])
            lines.append(f"First 5 rows:\n```\n{str(preview)}\n```")

    if result.get("plots"):
        lines.append(
            f"\n**Generated {len(result['plots'])} plot(s)** (Image data returned separately)"
        )

    return "\n".join(lines)
370
 
371
 
372
  system_prompt = """You are a helpful assistant tasked with answering questions using a set of tools.
 
384
  ) # dim=768
385
 
386
  # Load the GAIA validation dataset
387
+ dataset = load_dataset(
388
+ "gaia-benchmark/GAIA",
389
+ name="2023_level1",
390
+ split="validation",
391
+ trust_remote_code=True,
392
+ cache_dir="ragdata",
393
+ )
394
 
395
  # Extract questions and their answers
396
  documents = []
 
410
  documents.append(Document(page_content=question, metadata=metadata))
411
 
412
  # Insert the documents into Chroma
413
+ vector_store = Chroma.from_documents(
414
  documents=documents,
415
  embedding=embeddings,
416
  collection_name="gaia_validation",
 
418
  )
419
 
420
  create_retriever_tool = create_retriever_tool(
421
+ retriever=vector_store.as_retriever(),
422
  name="Question Search",
423
  description="A tool to retrieve similar questions from a vector store.",
424
  )
425
 
426
 
427
  tools = [
 
 
428
  web_search,
429
+ wiki_search,
430
+ arxiv_search,
431
+ multiply,
432
+ add,
433
+ subtract,
434
+ divide,
435
+ modulus,
436
+ power,
437
+ square_root,
438
+ save_and_read_file,
439
+ download_file_from_url,
440
+ extract_text_from_image,
441
+ analyze_csv_file,
442
+ analyze_excel_file,
443
+ execute_code_multilang,
444
  ]
445
 
446
 
 
458
  ) # optional : qwen-qwq-32b gemma2-9b-it
459
  elif provider == "huggingface":
460
  # TODO: Add huggingface endpoint
461
+ llm = HuggingFaceEndpoint(
462
+ repo_id="Meta-DeepLearning/llama-2-7b-chat-hf",
463
+ temperature=0,
464
+ )
465
  else:
466
  raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
467
  # Bind tools to LLM
 
474
 
475
  def retriever(state: MessagesState):
476
  """Retriever node"""
477
+ similar_question = vector_store.similarity_search(state["messages"][0].content)
478
  example_msg = HumanMessage(
479
  content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
480
  )
481
  return {"messages": [sys_msg] + state["messages"] + [example_msg]}
482
 
483
  builder = StateGraph(MessagesState)
484
+ builder.add_node('retriever', retriever)
485
+ builder.add_node('assistant', assistant)
486
+ builder.add_node('tools', ToolNode(tools))
487
+
488
+ builder.add_edge(START, 'retriever')
489
+ builder.add_edge('retriever', 'assistant')
490
+ builder.add_conditional_edges('assistant', tools_condition)
491
+ builder.add_edge('tools', 'assistant')
492
+
493
+ graph = builder.compile()
494
+
495
+ return graph
 
 
 
 
 
 
 
 
 
 
 
 
 
app.py CHANGED
@@ -17,15 +17,17 @@ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
17
  # ----- THIS IS WHERE YOU CAN BUILD WHAT YOU WANT ------
18
  class BasicAgent:
19
  def __init__(self):
20
- self.graph = build_graph(provider="groq")
21
- print("✅ Agent initialized.")
22
-
23
  def __call__(self, question: str) -> str:
24
- print(f"📨 Received question: {question[:60]}...")
 
 
 
25
  messages = [HumanMessage(content=question)]
26
- result = self.graph.invoke({"messages": messages})
27
- return result["messages"][-1].content # Simplify if needed
28
-
29
 
30
  def run_and_submit_all(profile: gr.OAuthProfile | None):
31
  """
 
17
  # ----- THIS IS WHERE YOU CAN BUILD WHAT YOU WANT ------
18
class BasicAgent:
    """Callable agent that answers a question by invoking the compiled graph."""

    # NOTE(review): the previous code dropped the first 14 characters of the
    # reply, which is exactly the length of this marker. The system prompt is
    # not visible from here - confirm that this is the prefix it requests.
    _ANSWER_MARKER = "FINAL ANSWER: "

    def __init__(self):
        print("BasicAgent initialized.")
        self.graph = build_graph(provider="groq")

    def __call__(self, question: str) -> str:
        """Run the graph on the question and return the text of the last message.

        Args:
            question: the question to answer.

        Returns:
            The text following the last answer marker in the final message,
            or the whole message content when the marker is absent.
        """
        print(f"Agent received question (first 50 chars): {question[:50]}...")
        messages = [HumanMessage(content=question)]
        result = self.graph.invoke({"messages": messages})
        ans = result["messages"][-1].content
        # Strip the marker only when it is really there; a blind slice would
        # cut off the start of any reply that does not carry the prefix.
        _, marker, tail = ans.rpartition(self._ANSWER_MARKER)
        return tail if marker else ans
31
 
32
  def run_and_submit_all(profile: gr.OAuthProfile | None):
33
  """