Muksia commited on
Commit
af5d4bc
·
verified ·
1 Parent(s): 27884a0

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +269 -51
agent.py CHANGED
@@ -1,79 +1,297 @@
1
- import importlib
 
 
 
 
 
 
2
  import os
 
 
3
 
4
  import requests
5
  import yaml
6
  import pandas as pd
7
 
8
- from config import DEFAULT_API_URL
9
- from smolagents import CodeAgent, DuckDuckGoSearchTool, VisitWebpageTool, WikipediaSearchTool, Tool, OpenAIServerModel, SpeechToTextTool
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  class GetTaskFileTool(Tool):
 
 
 
12
  name = "get_task_file_tool"
13
- description = """This tool downloads the file content associated with the given task_id if exists. Returns absolute file path"""
14
  inputs = {
15
- "task_id": {"type": "string", "description": "Task id"},
16
- "file_name": {"type": "string", "description": "File name"},
17
  }
18
- output_type = "string"
19
 
20
  def forward(self, task_id: str, file_name: str) -> str:
21
- response = requests.get(f"{DEFAULT_API_URL}/files/{task_id}", timeout=15)
22
- response.raise_for_status()
23
- with open(file_name, 'wb') as file:
24
- file.write(response.content)
25
- return os.path.abspath(file_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  class LoadXlsxFileTool(Tool):
 
 
 
28
  name = "load_xlsx_file_tool"
29
- description = """This tool loads xlsx file into pandas and returns it"""
30
  inputs = {
31
- "file_path": {"type": "string", "description": "File path"}
32
  }
 
 
33
  output_type = "object"
34
 
35
- def forward(self, file_path: str) -> object:
36
- return pd.read_excel(file_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  class LoadTextFileTool(Tool):
 
 
 
39
  name = "load_text_file_tool"
40
- description = """This tool loads any text file"""
41
  inputs = {
42
- "file_path": {"type": "string", "description": "File path"}
43
  }
44
- output_type = "string"
45
 
46
- def forward(self, file_path: str) -> object:
47
- with open(file_path, 'r', encoding='utf-8') as file:
48
- return file.read()
49
 
 
 
50
 
51
- prompts = yaml.safe_load(
52
- importlib.resources.files("smolagents.prompts").joinpath("code_agent.yaml").read_text()
53
- )
54
- prompts["system_prompt"] = ("You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string. "
55
- + prompts["system_prompt"])
56
-
57
- def init_agent():
58
- gemini_model = OpenAIServerModel(
59
- model_id="gemini-2.0-flash",
60
- api_base="https://generativelanguage.googleapis.com/v1beta/openai/",
61
- api_key=os.getenv("API_KEY"),
62
- temperature=0.7
63
- )
64
- agent = CodeAgent(
65
- tools=[
66
- DuckDuckGoSearchTool(),
67
- VisitWebpageTool(),
68
- WikipediaSearchTool(),
69
- GetTaskFileTool(),
70
- SpeechToTextTool(),
71
- LoadXlsxFileTool(),
72
- LoadTextFileTool()
73
- ],
74
- model=gemini_model,
75
- prompt_templates=prompts,
76
- max_steps=15,
77
- additional_authorized_imports = ["pandas"]
78
- )
79
- return agent
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Initializes and configures a SmolAgents CodeAgent with custom tools
4
+ for file handling and web interaction.
5
+ """
6
+
7
+ import importlib.resources
8
  import os
9
+ import logging # Added for logging errors
10
+ from typing import Type # Added for more specific type hints
11
 
12
  import requests
13
  import yaml
14
  import pandas as pd
15
 
16
+ # Assuming 'config.py' exists in the same directory or Python path
17
+ # and contains: DEFAULT_API_URL = "your_api_url_here"
18
+ try:
19
+ from config import DEFAULT_API_URL
20
+ except ImportError:
21
+ # Provide a default or raise a more specific error if config is crucial
22
+ DEFAULT_API_URL = "http://localhost:8000" # Example default, adjust as needed
23
+ logging.warning("config.py not found or DEFAULT_API_URL not set. Using default: %s", DEFAULT_API_URL)
24
+
25
+ from smolagents import (
26
+ CodeAgent,
27
+ Tool,
28
+ OpenAIServerModel,
29
+ # Standard Tools
30
+ DuckDuckGoSearchTool,
31
+ VisitWebpageTool,
32
+ WikipediaSearchTool,
33
+ SpeechToTextTool,
34
+ )
35
+
36
+ # Configure logging
37
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
38
+
39
+ # --- Custom Tools ---
40
 
41
  class GetTaskFileTool(Tool):
42
+ """
43
+ A tool to download a file associated with a specific task ID from a predefined API endpoint.
44
+ """
45
  name = "get_task_file_tool"
46
+ description = "Downloads the file content associated with the given task_id if it exists. Returns the absolute file path of the downloaded file."
47
  inputs = {
48
+ "task_id": {"type": "string", "description": "The unique identifier for the task."},
49
+ "file_name": {"type": "string", "description": "The desired local name for the downloaded file."},
50
  }
51
+ output_type = "string" # Output is the file path or an error message
52
 
53
  def forward(self, task_id: str, file_name: str) -> str:
54
+ """
55
+ Executes the file download process.
56
+
57
+ Args:
58
+ task_id: The ID of the task whose file should be downloaded.
59
+ file_name: The name to save the downloaded file as locally.
60
+
61
+ Returns:
62
+ The absolute path to the downloaded file if successful,
63
+ otherwise an error message string.
64
+ """
65
+ url = f"{DEFAULT_API_URL}/files/{task_id}"
66
+ logging.info("Attempting to download file from: %s", url)
67
+ try:
68
+ response = requests.get(url, timeout=30) # Increased timeout slightly
69
+ response.raise_for_status() # Raises HTTPError for bad responses (4xx or 5xx)
70
+
71
+ # Ensure the directory exists if file_name includes a path
72
+ # For simplicity here, we assume file_name is just a name,
73
+ # and it's saved in the current working directory.
74
+ # Consider adding directory creation logic if needed:
75
+ # os.makedirs(os.path.dirname(file_path), exist_ok=True)
76
+
77
+ file_path = os.path.abspath(file_name)
78
+ with open(file_path, 'wb') as file:
79
+ file.write(response.content)
80
+ logging.info("File successfully downloaded and saved to: %s", file_path)
81
+ return file_path
82
+
83
+ except requests.exceptions.RequestException as e:
84
+ error_msg = f"Error downloading file for task {task_id}: {e}"
85
+ logging.error(error_msg)
86
+ return error_msg # Return error message for the agent
87
+ except IOError as e:
88
+ error_msg = f"Error saving file {file_name}: {e}"
89
+ logging.error(error_msg)
90
+ return error_msg # Return error message
91
+
92
 
93
  class LoadXlsxFileTool(Tool):
94
+ """
95
+ A tool to load data from an XLSX (Excel) file into a pandas DataFrame.
96
+ """
97
  name = "load_xlsx_file_tool"
98
+ description = "Loads data from an XLSX file specified by its path into a pandas DataFrame."
99
  inputs = {
100
+ "file_path": {"type": "string", "description": "The local path to the XLSX file."}
101
  }
102
+ # Using object is acceptable here as DataFrames are complex types,
103
+ # but adding pandas type hint for internal clarity.
104
  output_type = "object"
105
 
106
+ def forward(self, file_path: str) -> pd.DataFrame | str:
107
+ """
108
+ Executes the XLSX file loading process.
109
+
110
+ Args:
111
+ file_path: The path to the XLSX file.
112
+
113
+ Returns:
114
+ A pandas DataFrame containing the data from the first sheet
115
+ if successful, otherwise an error message string.
116
+ """
117
+ logging.info("Attempting to load XLSX file: %s", file_path)
118
+ try:
119
+ # Ensure the file exists before attempting to read
120
+ if not os.path.exists(file_path):
121
+ raise FileNotFoundError(f"No such file or directory: '{file_path}'")
122
+
123
+ # Load the excel file. You might want to add options like sheet_name=None
124
+ # to load all sheets into a dictionary of DataFrames if needed.
125
+ df = pd.read_excel(file_path)
126
+ logging.info("Successfully loaded XLSX file into DataFrame.")
127
+ # Note: Returning the actual DataFrame object for the agent to use.
128
+ # The agent's Python execution environment needs pandas installed.
129
+ return df
130
+
131
+ except FileNotFoundError as e:
132
+ error_msg = f"Error loading XLSX: {e}"
133
+ logging.error(error_msg)
134
+ return error_msg # Return error message
135
+ except Exception as e:
136
+ # Catch other potential errors during pandas read_excel (e.g., bad format, permissions)
137
+ # xlrd might be needed for .xls, openpyxl for .xlsx
138
+ error_msg = f"Error reading Excel file {file_path}: {e}"
139
+ logging.error(error_msg)
140
+ return error_msg # Return error message
141
+
142
 
143
  class LoadTextFileTool(Tool):
144
+ """
145
+ A tool to load the content of a text file into a single string.
146
+ """
147
  name = "load_text_file_tool"
148
+ description = "Loads the entire content of any text file specified by its path."
149
  inputs = {
150
+ "file_path": {"type": "string", "description": "The local path to the text file."}
151
  }
152
+ output_type = "string" # Output is the file content or an error message
153
 
154
+ def forward(self, file_path: str) -> str:
155
+ """
156
+ Executes the text file loading process.
157
 
158
+ Args:
159
+ file_path: The path to the text file.
160
 
161
+ Returns:
162
+ The content of the text file as a string if successful,
163
+ otherwise an error message string.
164
+ """
165
+ logging.info("Attempting to load text file: %s", file_path)
166
+ try:
167
+ # Ensure the file exists before attempting to read
168
+ if not os.path.exists(file_path):
169
+ raise FileNotFoundError(f"No such file or directory: '{file_path}'")
170
+
171
+ with open(file_path, 'r', encoding='utf-8') as file:
172
+ content = file.read()
173
+ logging.info("Successfully loaded text file.")
174
+ return content
175
+
176
+ except FileNotFoundError as e:
177
+ error_msg = f"Error loading text file: {e}"
178
+ logging.error(error_msg)
179
+ return error_msg # Return error message
180
+ except UnicodeDecodeError as e:
181
+ error_msg = f"Encoding error reading file {file_path} as UTF-8: {e}"
182
+ logging.error(error_msg)
183
+ # Consider trying other encodings or returning raw bytes if appropriate
184
+ return error_msg # Return error message
185
+ except IOError as e:
186
+ error_msg = f"Error reading file {file_path}: {e}"
187
+ logging.error(error_msg)
188
+ return error_msg # Return error message
189
+
190
+
191
+ # --- Agent Configuration ---
192
+
193
+ # Define the custom prefix for the system prompt clearly
194
+ SYSTEM_PROMPT_PREFIX = """You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER].
195
+ YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
196
+ - If you are asked for a number, don't use comma separators (e.g., 1000 instead of 1,000) and avoid units like $ or % unless explicitly requested.
197
+ - If you are asked for a string, use standard capitalization, avoid abbreviations (e.g., Los Angeles instead of LA), and write out digits as words (e.g., five instead of 5) unless numbers are specifically requested. Avoid leading/trailing articles (a, an, the) if possible.
198
+ - If you are asked for a comma-separated list, apply the above rules to each element based on whether it's a number or a string.
199
+ """
200
+
201
+ def load_prompt_templates(yaml_path: str = "code_agent.yaml") -> dict:
202
+ """Loads prompt templates from a YAML file packaged with the library."""
203
+ try:
204
+ # Assumes 'smolagents.prompts' is a valid package/directory containing yaml_path
205
+ prompt_text = importlib.resources.files("smolagents.prompts").joinpath(yaml_path).read_text()
206
+ return yaml.safe_load(prompt_text)
207
+ except FileNotFoundError:
208
+ logging.error("Prompt YAML file not found at expected location: smolagents/prompts/%s", yaml_path)
209
+ # Return default empty dict or raise error, depending on desired behavior
210
+ return {}
211
+ except yaml.YAMLError as e:
212
+ logging.error("Error parsing YAML file %s: %s", yaml_path, e)
213
+ return {}
214
+ except Exception as e: # Catch other potential errors like package not found
215
+ logging.error("Failed to load prompts: %s", e)
216
+ return {}
217
+
218
+ def init_agent(api_key: str | None = None,
219
+ model_id: str = "gemini-1.5-flash", # Updated model ID example
220
+ api_base: str = "https://generativelanguage.googleapis.com/v1beta", # Updated base URL
221
+ temperature: float = 0.7,
222
+ max_steps: int = 15) -> CodeAgent | None:
223
+ """
224
+ Initializes and configures the CodeAgent.
225
+
226
+ Args:
227
+ api_key: The API key for the generative model service. Reads from
228
+ "API_KEY" environment variable if not provided.
229
+ model_id: The identifier of the model to use.
230
+ api_base: The base URL for the API. Note: The original URL seemed incorrect for Gemini via OpenAI proxy format. Check documentation.
231
+ The example here uses the direct Gemini API base URL format. Adjust if using an OpenAI proxy.
232
+ temperature: The sampling temperature for the model.
233
+ max_steps: The maximum number of steps the agent can take.
234
+
235
+ Returns:
236
+ An initialized CodeAgent instance, or None if initialization fails.
237
+ """
238
+ # Prefer passed API key, fallback to environment variable
239
+ resolved_api_key = api_key or os.getenv("API_KEY")
240
+ if not resolved_api_key:
241
+ logging.error("API Key not provided and 'API_KEY' environment variable not set.")
242
+ return None
243
+
244
+ # Load base prompts
245
+ prompts = load_prompt_templates()
246
+ if not prompts or "system_prompt" not in prompts:
247
+ logging.error("Failed to load or parse base prompts. Cannot initialize agent.")
248
+ return None
249
+
250
+ # Prepend the custom instructions to the loaded system prompt
251
+ prompts["system_prompt"] = SYSTEM_PROMPT_PREFIX + prompts["system_prompt"]
252
+
253
+ # Define the model connection
254
+ # Note: Ensure OpenAIServerModel is compatible with the Gemini API structure
255
+ # or use a specific Gemini client library if available/preferred.
256
+ # The api_base URL format might need adjustment based on how OpenAIServerModel constructs the full URL.
257
+ try:
258
+ gemini_model = OpenAIServerModel(
259
+ model_id=model_id,
260
+ # Make sure api_base is correct for how OpenAIServerModel uses it.
261
+ # If it expects an OpenAI-like structure, you might need a proxy or adjust this URL.
262
+ # Example using direct Gemini API base:
263
+ api_base=api_base,
264
+ api_key=resolved_api_key,
265
+ temperature=temperature
266
+ )
267
+ except Exception as e:
268
+ logging.error("Failed to initialize the language model: %s", e)
269
+ return None
270
+
271
+
272
+ # Define the list of tools available to the agent
273
+ tools = [
274
+ DuckDuckGoSearchTool(),
275
+ VisitWebpageTool(),
276
+ WikipediaSearchTool(),
277
+ GetTaskFileTool(), # Custom tool
278
+ SpeechToTextTool(),
279
+ LoadXlsxFileTool(), # Custom tool
280
+ LoadTextFileTool() # Custom tool
281
+ ]
282
+
283
+ # Create the agent instance
284
+ try:
285
+ agent = CodeAgent(
286
+ tools=tools,
287
+ model=gemini_model,
288
+ prompt_templates=prompts,
289
+ max_steps=max_steps,
290
+ # Explicitly list authorized imports for the code execution sandbox
291
+ additional_authorized_imports = ["pandas", "os.path"] # Added os.path for potential use
292
+ )
293
+ logging.info("CodeAgent initialized successfully.")
294
+ return agent
295
+ except Exception as e:
296
+ logging.error("Failed to initialize CodeAgent: %s", e)
297
+ return None