Tingusto committed on
Commit
faf691a
·
1 Parent(s): 57dc57d

Refactor agent.py to implement basic arithmetic operations (multiply, add, subtract, divide, modulus) and restore wiki and arxiv search functionalities. Remove SerpAPI, audio transcription, Excel analysis, and OCR tools. Update requirements.txt to include Gradio for OAuth support.

Browse files
Files changed (2) hide show
  1. agent.py +65 -126
  2. requirements.txt +1 -0
agent.py CHANGED
@@ -8,180 +8,119 @@ from langchain_community.document_loaders import WikipediaLoader
8
  from langchain_community.document_loaders import ArxivLoader
9
  from langchain_core.messages import SystemMessage, HumanMessage
10
  from langchain_core.tools import tool
11
- import requests
12
- from bs4 import BeautifulSoup
13
- import urllib.parse
14
- import re
15
- import pandas as pd
16
- import pytesseract
17
- from PIL import Image
18
- import whisper
19
- import yt_dlp
20
- import tempfile
21
- import subprocess
22
 
23
  load_dotenv()
24
 
25
- SERPAPI_API_KEY = os.getenv("SERPAPI_API_KEY")
26
-
27
@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia for information.

    Args:
        query: The search query."""
    try:
        docs = WikipediaLoader(query=query, load_max_docs=2).load()
        # Render each page inside a pseudo-XML <Document> wrapper so the
        # LLM can tell individual sources apart.
        rendered = [
            f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
            for doc in docs
        ]
        return {"wiki_results": "\n\n---\n\n".join(rendered)}
    except Exception as e:
        return f"Error searching Wikipedia: {str(e)}"
43
 
44
@tool
def serpapi_search(query: str) -> str:
    """Search the web using SerpAPI (Google Custom Search).

    Args:
        query: The search query."""
    try:
        if not SERPAPI_API_KEY:
            return "SerpAPI key not set"
        params = {
            "q": query,
            "api_key": SERPAPI_API_KEY,
            "engine": "google",
            "num": 3,
            "hl": "en"
        }
        # A timeout stops the agent from hanging indefinitely when SerpAPI
        # is unreachable — the original call had no timeout at all.
        response = requests.get("https://serpapi.com/search", params=params, timeout=30)
        response.raise_for_status()
        data = response.json()
        results = []
        for r in data.get("organic_results", [])[:3]:
            title = r.get("title", "")
            snippet = r.get("snippet", "")
            results.append(f"Title: {title}\nSnippet: {snippet}")
        # NOTE(review): success path returns a dict while the error paths and
        # annotation say str; kept as-is since downstream consumers may rely
        # on the "web_results" key — confirm before changing.
        return {"web_results": "\n\n".join(results) if results else "No results found"}
    except Exception as e:
        return f"Error searching web: {str(e)}"
71
 
72
@tool
def arxiv_search(query: str) -> str:
    """Search Arxiv for scientific papers.

    Args:
        query: The search query."""
    try:
        docs = ArxivLoader(query=query, load_max_docs=2).load()
        # Wrap each paper in a pseudo-XML <Document> tag, keeping only the
        # first 1000 characters of the body to limit context size.
        rendered = [
            f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
            for doc in docs
        ]
        return {"arxiv_results": "\n\n---\n\n".join(rendered)}
    except Exception as e:
        return f"Error searching Arxiv: {str(e)}"
88
 
89
@tool
def reverse_text(text: str) -> str:
    """Reverse the given text.

    Args:
        text: The text to reverse."""
    # Equivalent to text[::-1]: walk the characters back-to-front.
    return "".join(reversed(text))
 
 
 
 
96
 
97
@tool
def transcribe_audio(file_path: str) -> str:
    """Transcribe an audio file using Whisper.

    Args:
        file_path: Path to the audio file."""
    try:
        # Load the small "base" Whisper model and transcribe in one pass.
        transcription = whisper.load_model("base").transcribe(file_path)
        return transcription["text"]
    except Exception as e:
        return f"Error transcribing audio: {str(e)}"
109
 
110
@tool
def analyze_excel(file_path: str, column: str = None) -> str:
    """Analyze an Excel file and return the sum of a column or all data.

    Args:
        file_path: Path to the Excel file.
        column: Optional column to sum."""
    try:
        df = pd.read_excel(file_path)
        if column:
            # The original silently fell through to dumping the whole sheet
            # when the requested column did not exist; report the miss
            # explicitly instead so the caller gets a meaningful answer.
            if column not in df.columns:
                return f"Error analyzing Excel: column '{column}' not found"
            return str(df[column].sum())
        return df.to_csv(index=False)
    except Exception as e:
        return f"Error analyzing Excel: {str(e)}"
124
-
125
@tool
def ocr_image(file_path: str) -> str:
    """Extract text from an image using OCR.

    Args:
        file_path: Path to the image file."""
    try:
        # Open the image and run Tesseract over it in a single expression.
        return pytesseract.image_to_string(Image.open(file_path))
    except Exception as e:
        return f"Error extracting text from image: {str(e)}"
 
137
 
138
@tool
def analyze_youtube_video(video_url: str) -> str:
    """Download and transcribe a YouTube video using yt-dlp and Whisper.

    Args:
        video_url: The URL of the YouTube video."""
    try:
        with tempfile.TemporaryDirectory() as tmpdir:
            # Fetch the best available audio stream into the temp directory.
            downloader_opts = {
                'format': 'bestaudio/best',
                'outtmpl': f'{tmpdir}/%(id)s.%(ext)s',
                'quiet': True,
            }
            with yt_dlp.YoutubeDL(downloader_opts) as downloader:
                video_info = downloader.extract_info(video_url, download=True)
                audio_path = downloader.prepare_filename(video_info)
            transcription = whisper.load_model("base").transcribe(audio_path)
            return transcription["text"][:2000]  # Limit the output size
    except Exception as e:
        return f"Error analyzing YouTube video: {str(e)}"
159
 
160
# System prompt: strict answer-formatting instructions for the QA assistant.
system_prompt = (
    "You are a highly accurate question-answering assistant. Your answers must be:\n"
    "- Direct, with no extra words or explanations.\n"
    "- Formatted exactly as requested (numbers only, comma-separated lists, etc.).\n"
    "- If the question involves a file, extract only the requested information.\n"
    "- If the question is about a video, audio, or image, use the appropriate tool to extract the answer.\n"
    "- If you are unsure, provide the most likely answer in the correct format.\n"
    "- Never add units, explanations, or formatting unless explicitly requested.\n"
)
 
 
 
 
 
162
 
163
# System message
sys_msg = SystemMessage(content=system_prompt)

# Tools exposed to the agent, grouped by capability (order preserved).
tools = [
    # Search
    wiki_search,
    serpapi_search,
    arxiv_search,
    # Text utilities
    reverse_text,
    # Media / file analysis
    transcribe_audio,
    analyze_excel,
    ocr_image,
    analyze_youtube_video,
]
177
 
178
  def build_graph():
179
  """Build the graph"""
 
180
  llm = ChatGroq(
181
  model="llama3-70b-8192",
182
  temperature=0.1
183
  )
184
 
 
185
  llm_with_tools = llm.bind_tools(tools)
186
 
187
  # Node
 
8
  from langchain_community.document_loaders import ArxivLoader
9
  from langchain_core.messages import SystemMessage, HumanMessage
10
  from langchain_core.tools import tool
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  load_dotenv()
13
 
 
 
14
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two numbers.

    Args:
        a: first int
        b: second int
    """
    # Plain integer product.
    product = a * b
    return product
 
 
 
 
 
 
 
22
 
23
@tool
def add(a: int, b: int) -> int:
    """Add two numbers.

    Args:
        a: first int
        b: second int
    """
    # Sum of the two operands.
    total = a + b
    return total
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
@tool
def subtract(a: int, b: int) -> int:
    """Subtract two numbers.

    Args:
        a: first int
        b: second int
    """
    # Difference a - b.
    difference = a - b
    return difference
 
 
 
 
 
 
 
42
 
43
@tool
def divide(a: int, b: int) -> float:
    """Divide two numbers.

    Args:
        a: first int
        b: second int
    """
    # Guard against division by zero with an explicit, catchable error.
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    # True division of ints always yields a float, so the return annotation
    # is float (the original ``-> int`` annotation was inaccurate).
    return a / b
54
 
55
@tool
def modulus(a: int, b: int) -> int:
    """Get the modulus of two numbers.

    Args:
        a: first int
        b: second int
    """
    # Remainder of a divided by b (Python floor-mod semantics).
    remainder = a % b
    return remainder
 
 
 
64
 
65
@tool
def wiki_search(query: str) -> dict:
    """Search Wikipedia for a query and return maximum 2 results.

    Args:
        query: The search query."""
    # The original signature said ``-> str`` but the function returns a
    # dict; the annotation now matches the actual return value.
    search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
    # Wrap each page in a pseudo-XML <Document> tag so sources stay distinct.
    formatted_search_docs = "\n\n---\n\n".join(
        [
            f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
            for doc in search_docs
        ])
    return {"wiki_results": formatted_search_docs}
78
 
79
@tool
def arxiv_search(query: str) -> dict:
    """Search Arxiv for a query and return maximum 3 results.

    Args:
        query: The search query."""
    # The original signature said ``-> str`` but the function returns a
    # dict; the annotation now matches the actual return value. Also fixed
    # the docstring typo "3 result" -> "3 results".
    search_docs = ArxivLoader(query=query, load_max_docs=3).load()
    # Keep only the first 1000 characters of each paper to limit context size.
    formatted_search_docs = "\n\n---\n\n".join(
        [
            f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
            for doc in search_docs
        ])
    return {"arxiv_results": formatted_search_docs}
 
 
 
 
 
 
 
 
92
 
93
# System prompt: strict answer-formatting instructions for the QA assistant.
system_prompt = (
    "You are a highly accurate question-answering assistant. Your answers must be:\n"
    "- Direct, with no extra words or explanations.\n"
    "- Formatted exactly as requested (numbers only, comma-separated lists, etc.).\n"
    "- If the question involves a file, extract only the requested information.\n"
    "- If you are unsure, provide the most likely answer in the correct format.\n"
    "- Never add units, explanations, or formatting unless explicitly requested."
)
100
 
101
# System message
sys_msg = SystemMessage(content=system_prompt)

# Tools the agent can call, grouped by capability (order preserved).
tools = [
    # Arithmetic
    multiply,
    add,
    subtract,
    divide,
    modulus,
    # Search
    wiki_search,
    arxiv_search,
]
114
 
115
  def build_graph():
116
  """Build the graph"""
117
+ # Initialize Groq LLM
118
  llm = ChatGroq(
119
  model="llama3-70b-8192",
120
  temperature=0.1
121
  )
122
 
123
+ # Bind tools to LLM
124
  llm_with_tools = llm.bind_tools(tools)
125
 
126
  # Node
requirements.txt CHANGED
@@ -19,3 +19,4 @@ whisper
19
  pytesseract
20
  pillow
21
  yt-dlp
 
 
19
  pytesseract
20
  pillow
21
  yt-dlp
22
+ gradio[oauth]