SergeyO7 commited on
Commit
1692baf
·
verified ·
1 Parent(s): 424924f

Create agent.py

Browse files
Files changed (1) hide show
  1. agent.py +321 -0
agent.py ADDED
@@ -0,0 +1,321 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from smolagents import CodeAgent, ToolCallingAgent, LiteLLMModel, tool, Tool, load_tool, WebSearchTool, DuckDuckGoSearchTool
2
+ import asyncio
3
+ import os
4
+ import re
5
+ import pandas as pd
6
+ from typing import Optional
7
+ from token_bucket import Limiter, MemoryStorage
8
+ import yaml
9
+ from PIL import Image, ImageOps
10
+ import requests
11
+ from io import BytesIO
12
+ from markdownify import markdownify
13
+ import whisper
14
+ import time
15
+ import shutil
16
+ import traceback
17
+ from langchain_community.document_loaders import ArxivLoader
18
+ import logging
19
+ import io
20
+ import base64
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+ @tool
25
+ def search_arxiv(query: str) -> str:
26
+ """Search Arxiv for a query and return maximum 3 result.
27
+
28
+ Args:
29
+ query: The search query.
30
+ Returns:
31
+ str: Formatted search results
32
+ """
33
+ search_docs = ArxivLoader(query=query, load_max_docs=3).load()
34
+ formatted_search_docs = "\n\n---\n\n".join(
35
+ [
36
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
37
+ for doc in search_docs
38
+ ])
39
+ return {"arxiv_results": formatted_search_docs}
40
+
41
+ class ChessboardToFENOnlineTool(Tool):
42
+ name = "chessboard_to_fen_online"
43
+ description = "Converts a chessboard image to FEN using an online API (no local templates needed)."
44
+ inputs = {
45
+ 'image_path': {
46
+ 'type': 'string',
47
+ 'description': 'Path to the PNG/JPG image of the chessboard.'
48
+ }
49
+ }
50
+ output_type = "string"
51
+
52
+ def forward(self, image_path: str) -> str:
53
+ try:
54
+ with open(image_path, "rb") as image_file:
55
+ encoded_image = base64.b64encode(image_file.read()).decode('utf-8')
56
+ except FileNotFoundError:
57
+ return "Error: Image file not found."
58
+
59
+ api_url = "https://api.chessvision.ai/v1/recognize"
60
+ headers = {
61
+ "Authorization": "Bearer YOUR_API_KEY", # Replace with actual key
62
+ "Content-Type": "application/json"
63
+ }
64
+ payload = {
65
+ "image": encoded_image,
66
+ "format": "fen"
67
+ }
68
+
69
+ try:
70
+ response = requests.post(api_url, headers=headers, json=payload)
71
+ if response.status_code == 200:
72
+ return response.json().get("fen", "Error: FEN not found in response.")
73
+ else:
74
+ return f"API Error: {response.status_code} - {response.text}"
75
+ except Exception as e:
76
+ return f"API Call Failed: {str(e)}"
77
+
78
+ class VisitWebpageTool(Tool):
79
+ name = "visit_webpage"
80
+ description = "Visits a webpage at the given url and reads its content as a markdown string. Use this to browse webpages."
81
+ inputs = {'url': {'type': 'string', 'description': 'The url of the webpage to visit.'}}
82
+ output_type = "string"
83
+
84
+ @retry(
85
+ stop=stop_after_attempt(3),
86
+ wait=wait_exponential(multiplier=1, min=4, max=10),
87
+ retry=retry_if_exception(is_429_error)
88
+ )
89
+ def forward(self, url: str) -> str:
90
+ try:
91
+ response = requests.get(url, timeout=50)
92
+ response.raise_for_status()
93
+ markdown_content = markdownify(response.text).strip()
94
+ markdown_content = re.sub(r"\n{3,}", "\n\n", markdown_content)
95
+ return markdown_content
96
+ except requests.exceptions.HTTPError as e:
97
+ if e.response.status_code == 429:
98
+ raise # Retry on 429
99
+ return f"Error fetching the webpage: {str(e)}"
100
+ except requests.exceptions.Timeout:
101
+ return "The request timed out. Please try again later or check the URL."
102
+ except requests.exceptions.RequestException as e:
103
+ return f"Error fetching the webpage: {str(e)}"
104
+ except Exception as e:
105
+ return f"An unexpected error occurred: {str(e)}"
106
+
107
+ def __init__(self, *args, **kwargs):
108
+ self.is_initialized = False
109
+
110
+ class SpeechToTextTool(Tool):
111
+ name = "speech_to_text"
112
+ description = "Converts an audio file to text using OpenAI Whisper."
113
+ inputs = {
114
+ "audio_path": {"type": "string", "description": "Path to audio file (.mp3, .wav)"},
115
+ }
116
+ output_type = "string"
117
+
118
+ def __init__(self):
119
+ super().__init__()
120
+ try:
121
+ self.model = whisper.load_model("base")
122
+ logger.info("Whisper model loaded successfully.")
123
+ except Exception as e:
124
+ logger.error(f"Failed to load Whisper model: {str(e)}")
125
+ raise RuntimeError(f"Failed to load Whisper model: {str(e)}")
126
+
127
+ def forward(self, audio_path: str) -> str:
128
+ if not os.path.exists(audio_path):
129
+ return f"Error: File not found at {audio_path}"
130
+ try:
131
+ print(f"Starting transcription for {audio_path}...")
132
+ result = self.model.transcribe(audio_path)
133
+ print(f"Transcription completed for {audio_path}.")
134
+ return result.get("text", "")
135
+ except Exception as e:
136
+ return f"Error processing audio file: {str(e)}"
137
+
138
+ class ExcelReaderTool(Tool):
139
+ name = "excel_reader"
140
+ description = "Reads and returns a pandas DataFrame from an Excel file (.xlsx, .xls)."
141
+ inputs = {
142
+ "excel_path": {
143
+ "type": "string",
144
+ "description": "The path to the Excel file to read",
145
+ },
146
+ "sheet_name": {
147
+ "type": "string",
148
+ "description": "The name of the sheet to read (optional, defaults to first sheet)",
149
+ "nullable": True
150
+ }
151
+ }
152
+ output_type = "pandas.DataFrame"
153
+
154
+ def forward(self, excel_path: str, sheet_name: str = None) -> pd.DataFrame:
155
+ try:
156
+ if not os.path.exists(excel_path):
157
+ return f"Error: Excel file not found at {excel_path}"
158
+ if sheet_name:
159
+ df = pd.read_excel(excel_path, sheet_name=sheet_name)
160
+ else:
161
+ df = pd.read_excel(excel_path)
162
+ return df
163
+ except Exception as e:
164
+ return f"Error reading Excel file: {str(e)}"
165
+
166
+ class PythonCodeReaderTool(Tool):
167
+ name = "read_python_code"
168
+ description = "Reads a Python (.py) file and returns its content as a string."
169
+ inputs = {
170
+ "file_path": {"type": "string", "description": "The path to the Python file to read"}
171
+ }
172
+ output_type = "string"
173
+
174
+ def forward(self, file_path: str) -> str:
175
+ try:
176
+ if not os.path.exists(file_path):
177
+ return f"Error: Python file not found at {file_path}"
178
+ with open(file_path, "r", encoding="utf-8") as file:
179
+ content = file.read()
180
+ return content
181
+ except Exception as e:
182
+ return f"Error reading Python file: {str(e)}"
183
+
184
+ from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
185
+
186
+ class RetryDuckDuckGoSearchTool(DuckDuckGoSearchTool):
187
+ @retry(
188
+ stop=stop_after_attempt(3),
189
+ wait=wait_exponential(multiplier=1, min=4, max=10),
190
+ retry=retry_if_exception_type(Exception)
191
+ )
192
+ def forward(self, query: str) -> str:
193
+ return super().forward(query)
194
+
195
+ class MagAgent:
196
+ def __init__(self, rate_limiter: Optional[Limiter] = None):
197
+ """Initialize the MagAgent with search tools."""
198
+ logger.info("Initializing MagAgent")
199
+ self.rate_limiter = rate_limiter
200
+
201
+ print("Initializing MagAgent with search tools...")
202
+ try:
203
+ # Verify GEMINI_KEY
204
+ gemini_key = os.environ.get("GEMINI_KEY")
205
+ if not gemini_key:
206
+ raise ValueError("GEMINI_KEY environment variable is not set.")
207
+
208
+ model = LiteLLMModel(
209
+ model_id="gemini/gemini-1.5-flash",
210
+ api_key=gemini_key,
211
+ max_tokens=8192
212
+ )
213
+
214
+ self.imports = [
215
+ "pandas",
216
+ "numpy",
217
+ "os",
218
+ "requests",
219
+ "tempfile",
220
+ "datetime",
221
+ "json",
222
+ "time",
223
+ "re",
224
+ "openpyxl",
225
+ "pathlib",
226
+ "sys",
227
+ "bs4",
228
+ "arxiv",
229
+ "whisper",
230
+ "io",
231
+ "base64"
232
+ ]
233
+
234
+ self.tools = [
235
+ SpeechToTextTool(),
236
+ ExcelReaderTool(),
237
+ PythonCodeReaderTool(),
238
+ ChessboardToFENOnlineTool(),
239
+ search_arxiv,
240
+ ]
241
+
242
+ self.prompt_template = (
243
+ """
244
+ You are an advanced AI assistant specialized in solving complex, real-world tasks, requiring multi-step reasoning, factual accuracy, and use of external tools.
245
+ Follow these principles:
246
+ - Reason step-by-step. Think through the solution logically and plan your actions carefully before answering.
247
+ - Validate information. Always verify facts when possible instead of guessing.
248
+ - When processing external data (e.g., YouTube transcripts, web searches), expect potential issues like missing punctuation, inconsistent formatting, or conversational text.
249
+ - When asked to process Excel files, use the `excel_reader` tool, which returns a pandas DataFrame.
250
+ - When calculating sales, make sure you multiply volume on price per each product or category.
251
+ - When asked to transcript YouTube video, try searching it in www.youtubetotranscript.com.
252
+ - If the input is ambiguous, prioritize extracting key information relevant to the question.
253
+ - Use code if needed. For calculations, parsing, or transformations, generate Python code and execute it. Be cautious, as some questions contain time-consuming tasks, so analyze the question and choose the most efficient solution.
254
+ - Be precise and concise. The final answer must strictly match the required format with no extra commentary.
255
+ - Use tools intelligently. If a question involves external information, structured data, images, or audio, call the appropriate tool to retrieve or process it.
256
+ - If the question includes direct speech or quoted text (e.g., "Isn't that hot?"), treat it as a precise query and preserve the quoted structure in your response, including quotation marks for direct quotes (e.g., final_answer('"Extremely."')).
257
+ - If asked about the name of a place or city, use the full complete name without abbreviations (e.g., use Saint Petersburg instead of St.Petersburg).
258
+ - If asked to look up page numbers, make sure you don't mix them with problem or exercise numbers.
259
+ - If you cannot retrieve or process data (e.g., due to blocked requests), retry after 15 seconds delay, try another tool (try wikipedia_search, then web_search, then search_arxiv). Otherwise, return a clear error message: "Unable to retrieve data. Search has failed."
260
+ - Use `final_answer` to give the final answer.
261
+
262
+ QUESTION: {question}
263
+
264
+ {file_section}
265
+
266
+ ANSWER:
267
+ """
268
+ )
269
+
270
+ web_agent = ToolCallingAgent(
271
+ tools=[
272
+ WebSearchTool(),
273
+ VisitWebpageTool(),
274
+ search_arxiv,
275
+ ],
276
+ model=model,
277
+ max_steps=15,
278
+ name="web_search_agent",
279
+ description="Runs web searches for you.",
280
+ )
281
+
282
+ self.agent = CodeAgent(
283
+ model=model,
284
+ managed_agents=[web_agent],
285
+ tools=self.tools,
286
+ add_base_tools=True,
287
+ additional_authorized_imports=self.imports,
288
+ verbosity_level=2,
289
+ max_steps=10
290
+ )
291
+ print("MagAgent initialized.")
292
+ except Exception as e:
293
+ logger.error(f"Failed to initialize MagAgent: {str(e)}\n{traceback.format_exc()}")
294
+ raise
295
+
296
+ async def __call__(self, question: str, file_path: Optional[str] = None) -> str:
297
+ """Process a question asynchronously using the MagAgent."""
298
+ print(f"MagAgent received question (first 50 chars): {question[:50]}... File path: {file_path}")
299
+ try:
300
+ if self.rate_limiter:
301
+ while not self.rate_limiter.consume(1):
302
+ print(f"Rate limit reached. Waiting...")
303
+ await asyncio.sleep(4)
304
+ file_section = f"FILE: {file_path}" if file_path else ""
305
+ task = self.prompt_template.format(
306
+ question=question,
307
+ file_section=file_section
308
+ )
309
+ print(f"Calling agent.run...")
310
+ response = await asyncio.to_thread(self.agent.run, task=task)
311
+ print(f"Agent.run completed.")
312
+ response = str(response)
313
+ if not response:
314
+ print(f"No answer found.")
315
+ response = "No answer found."
316
+ print(f"MagAgent response: {response[:50]}...")
317
+ return response
318
+ except Exception as e:
319
+ error_msg = f"Error processing question: {str(e)}. Check API key or network connectivity."
320
+ print(error_msg)
321
+ return error_msg