Muksia committed on
Commit
fe9507a
·
verified ·
1 Parent(s): af5d4bc

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +51 -269
agent.py CHANGED
@@ -1,297 +1,79 @@
1
- # -*- coding: utf-8 -*-
2
- """
3
- Initializes and configures a SmolAgents CodeAgent with custom tools
4
- for file handling and web interaction.
5
- """
6
-
7
- import importlib.resources
8
  import os
9
- import logging # Added for logging errors
10
- from typing import Type # Added for more specific type hints
11
 
12
  import requests
13
  import yaml
14
  import pandas as pd
15
 
16
- # Assuming 'config.py' exists in the same directory or Python path
17
- # and contains: DEFAULT_API_URL = "your_api_url_here"
18
- try:
19
- from config import DEFAULT_API_URL
20
- except ImportError:
21
- # Provide a default or raise a more specific error if config is crucial
22
- DEFAULT_API_URL = "http://localhost:8000" # Example default, adjust as needed
23
- logging.warning("config.py not found or DEFAULT_API_URL not set. Using default: %s", DEFAULT_API_URL)
24
-
25
- from smolagents import (
26
- CodeAgent,
27
- Tool,
28
- OpenAIServerModel,
29
- # Standard Tools
30
- DuckDuckGoSearchTool,
31
- VisitWebpageTool,
32
- WikipediaSearchTool,
33
- SpeechToTextTool,
34
- )
35
-
36
- # Configure logging
37
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
38
-
39
- # --- Custom Tools ---
40
 
41
  class GetTaskFileTool(Tool):
42
- """
43
- A tool to download a file associated with a specific task ID from a predefined API endpoint.
44
- """
45
  name = "get_task_file_tool"
46
- description = "Downloads the file content associated with the given task_id if it exists. Returns the absolute file path of the downloaded file."
47
  inputs = {
48
- "task_id": {"type": "string", "description": "The unique identifier for the task."},
49
- "file_name": {"type": "string", "description": "The desired local name for the downloaded file."},
50
  }
51
- output_type = "string" # Output is the file path or an error message
52
 
53
  def forward(self, task_id: str, file_name: str) -> str:
54
- """
55
- Executes the file download process.
56
-
57
- Args:
58
- task_id: The ID of the task whose file should be downloaded.
59
- file_name: The name to save the downloaded file as locally.
60
-
61
- Returns:
62
- The absolute path to the downloaded file if successful,
63
- otherwise an error message string.
64
- """
65
- url = f"{DEFAULT_API_URL}/files/{task_id}"
66
- logging.info("Attempting to download file from: %s", url)
67
- try:
68
- response = requests.get(url, timeout=30) # Increased timeout slightly
69
- response.raise_for_status() # Raises HTTPError for bad responses (4xx or 5xx)
70
-
71
- # Ensure the directory exists if file_name includes a path
72
- # For simplicity here, we assume file_name is just a name,
73
- # and it's saved in the current working directory.
74
- # Consider adding directory creation logic if needed:
75
- # os.makedirs(os.path.dirname(file_path), exist_ok=True)
76
-
77
- file_path = os.path.abspath(file_name)
78
- with open(file_path, 'wb') as file:
79
- file.write(response.content)
80
- logging.info("File successfully downloaded and saved to: %s", file_path)
81
- return file_path
82
-
83
- except requests.exceptions.RequestException as e:
84
- error_msg = f"Error downloading file for task {task_id}: {e}"
85
- logging.error(error_msg)
86
- return error_msg # Return error message for the agent
87
- except IOError as e:
88
- error_msg = f"Error saving file {file_name}: {e}"
89
- logging.error(error_msg)
90
- return error_msg # Return error message
91
-
92
 
93
  class LoadXlsxFileTool(Tool):
94
- """
95
- A tool to load data from an XLSX (Excel) file into a pandas DataFrame.
96
- """
97
  name = "load_xlsx_file_tool"
98
- description = "Loads data from an XLSX file specified by its path into a pandas DataFrame."
99
  inputs = {
100
- "file_path": {"type": "string", "description": "The local path to the XLSX file."}
101
  }
102
- # Using object is acceptable here as DataFrames are complex types,
103
- # but adding pandas type hint for internal clarity.
104
  output_type = "object"
105
 
106
- def forward(self, file_path: str) -> pd.DataFrame | str:
107
- """
108
- Executes the XLSX file loading process.
109
-
110
- Args:
111
- file_path: The path to the XLSX file.
112
-
113
- Returns:
114
- A pandas DataFrame containing the data from the first sheet
115
- if successful, otherwise an error message string.
116
- """
117
- logging.info("Attempting to load XLSX file: %s", file_path)
118
- try:
119
- # Ensure the file exists before attempting to read
120
- if not os.path.exists(file_path):
121
- raise FileNotFoundError(f"No such file or directory: '{file_path}'")
122
-
123
- # Load the excel file. You might want to add options like sheet_name=None
124
- # to load all sheets into a dictionary of DataFrames if needed.
125
- df = pd.read_excel(file_path)
126
- logging.info("Successfully loaded XLSX file into DataFrame.")
127
- # Note: Returning the actual DataFrame object for the agent to use.
128
- # The agent's Python execution environment needs pandas installed.
129
- return df
130
-
131
- except FileNotFoundError as e:
132
- error_msg = f"Error loading XLSX: {e}"
133
- logging.error(error_msg)
134
- return error_msg # Return error message
135
- except Exception as e:
136
- # Catch other potential errors during pandas read_excel (e.g., bad format, permissions)
137
- # xlrd might be needed for .xls, openpyxl for .xlsx
138
- error_msg = f"Error reading Excel file {file_path}: {e}"
139
- logging.error(error_msg)
140
- return error_msg # Return error message
141
-
142
 
143
  class LoadTextFileTool(Tool):
144
- """
145
- A tool to load the content of a text file into a single string.
146
- """
147
  name = "load_text_file_tool"
148
- description = "Loads the entire content of any text file specified by its path."
149
  inputs = {
150
- "file_path": {"type": "string", "description": "The local path to the text file."}
151
  }
152
- output_type = "string" # Output is the file content or an error message
153
 
154
- def forward(self, file_path: str) -> str:
155
- """
156
- Executes the text file loading process.
157
 
158
- Args:
159
- file_path: The path to the text file.
160
 
161
- Returns:
162
- The content of the text file as a string if successful,
163
- otherwise an error message string.
164
- """
165
- logging.info("Attempting to load text file: %s", file_path)
166
- try:
167
- # Ensure the file exists before attempting to read
168
- if not os.path.exists(file_path):
169
- raise FileNotFoundError(f"No such file or directory: '{file_path}'")
170
-
171
- with open(file_path, 'r', encoding='utf-8') as file:
172
- content = file.read()
173
- logging.info("Successfully loaded text file.")
174
- return content
175
-
176
- except FileNotFoundError as e:
177
- error_msg = f"Error loading text file: {e}"
178
- logging.error(error_msg)
179
- return error_msg # Return error message
180
- except UnicodeDecodeError as e:
181
- error_msg = f"Encoding error reading file {file_path} as UTF-8: {e}"
182
- logging.error(error_msg)
183
- # Consider trying other encodings or returning raw bytes if appropriate
184
- return error_msg # Return error message
185
- except IOError as e:
186
- error_msg = f"Error reading file {file_path}: {e}"
187
- logging.error(error_msg)
188
- return error_msg # Return error message
189
-
190
-
191
- # --- Agent Configuration ---
192
-
193
- # Define the custom prefix for the system prompt clearly
194
- SYSTEM_PROMPT_PREFIX = """You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER].
195
- YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
196
- - If you are asked for a number, don't use comma separators (e.g., 1000 instead of 1,000) and avoid units like $ or % unless explicitly requested.
197
- - If you are asked for a string, use standard capitalization, avoid abbreviations (e.g., Los Angeles instead of LA), and write out digits as words (e.g., five instead of 5) unless numbers are specifically requested. Avoid leading/trailing articles (a, an, the) if possible.
198
- - If you are asked for a comma-separated list, apply the above rules to each element based on whether it's a number or a string.
199
- """
200
-
201
- def load_prompt_templates(yaml_path: str = "code_agent.yaml") -> dict:
202
- """Loads prompt templates from a YAML file packaged with the library."""
203
- try:
204
- # Assumes 'smolagents.prompts' is a valid package/directory containing yaml_path
205
- prompt_text = importlib.resources.files("smolagents.prompts").joinpath(yaml_path).read_text()
206
- return yaml.safe_load(prompt_text)
207
- except FileNotFoundError:
208
- logging.error("Prompt YAML file not found at expected location: smolagents/prompts/%s", yaml_path)
209
- # Return default empty dict or raise error, depending on desired behavior
210
- return {}
211
- except yaml.YAMLError as e:
212
- logging.error("Error parsing YAML file %s: %s", yaml_path, e)
213
- return {}
214
- except Exception as e: # Catch other potential errors like package not found
215
- logging.error("Failed to load prompts: %s", e)
216
- return {}
217
-
218
- def init_agent(api_key: str | None = None,
219
- model_id: str = "gemini-1.5-flash", # Updated model ID example
220
- api_base: str = "https://generativelanguage.googleapis.com/v1beta", # Updated base URL
221
- temperature: float = 0.7,
222
- max_steps: int = 15) -> CodeAgent | None:
223
- """
224
- Initializes and configures the CodeAgent.
225
-
226
- Args:
227
- api_key: The API key for the generative model service. Reads from
228
- "API_KEY" environment variable if not provided.
229
- model_id: The identifier of the model to use.
230
- api_base: The base URL for the API. Note: The original URL seemed incorrect for Gemini via OpenAI proxy format. Check documentation.
231
- The example here uses the direct Gemini API base URL format. Adjust if using an OpenAI proxy.
232
- temperature: The sampling temperature for the model.
233
- max_steps: The maximum number of steps the agent can take.
234
-
235
- Returns:
236
- An initialized CodeAgent instance, or None if initialization fails.
237
- """
238
- # Prefer passed API key, fallback to environment variable
239
- resolved_api_key = api_key or os.getenv("API_KEY")
240
- if not resolved_api_key:
241
- logging.error("API Key not provided and 'API_KEY' environment variable not set.")
242
- return None
243
-
244
- # Load base prompts
245
- prompts = load_prompt_templates()
246
- if not prompts or "system_prompt" not in prompts:
247
- logging.error("Failed to load or parse base prompts. Cannot initialize agent.")
248
- return None
249
-
250
- # Prepend the custom instructions to the loaded system prompt
251
- prompts["system_prompt"] = SYSTEM_PROMPT_PREFIX + prompts["system_prompt"]
252
-
253
- # Define the model connection
254
- # Note: Ensure OpenAIServerModel is compatible with the Gemini API structure
255
- # or use a specific Gemini client library if available/preferred.
256
- # The api_base URL format might need adjustment based on how OpenAIServerModel constructs the full URL.
257
- try:
258
- gemini_model = OpenAIServerModel(
259
- model_id=model_id,
260
- # Make sure api_base is correct for how OpenAIServerModel uses it.
261
- # If it expects an OpenAI-like structure, you might need a proxy or adjust this URL.
262
- # Example using direct Gemini API base:
263
- api_base=api_base,
264
- api_key=resolved_api_key,
265
- temperature=temperature
266
- )
267
- except Exception as e:
268
- logging.error("Failed to initialize the language model: %s", e)
269
- return None
270
-
271
-
272
- # Define the list of tools available to the agent
273
- tools = [
274
- DuckDuckGoSearchTool(),
275
- VisitWebpageTool(),
276
- WikipediaSearchTool(),
277
- GetTaskFileTool(), # Custom tool
278
- SpeechToTextTool(),
279
- LoadXlsxFileTool(), # Custom tool
280
- LoadTextFileTool() # Custom tool
281
- ]
282
-
283
- # Create the agent instance
284
- try:
285
- agent = CodeAgent(
286
- tools=tools,
287
- model=gemini_model,
288
- prompt_templates=prompts,
289
- max_steps=max_steps,
290
- # Explicitly list authorized imports for the code execution sandbox
291
- additional_authorized_imports = ["pandas", "os.path"] # Added os.path for potential use
292
- )
293
- logging.info("CodeAgent initialized successfully.")
294
- return agent
295
- except Exception as e:
296
- logging.error("Failed to initialize CodeAgent: %s", e)
297
- return None
 
1
+ import importlib
 
 
 
 
 
 
2
  import os
 
 
3
 
4
  import requests
5
  import yaml
6
  import pandas as pd
7
 
8
+ from config import DEFAULT_API_URL
9
+ from smolagents import CodeAgent, DuckDuckGoSearchTool, VisitWebpageTool, WikipediaSearchTool, Tool, OpenAIServerModel, SpeechToTextTool
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
class GetTaskFileTool(Tool):
    """Tool that fetches the file attached to a task and stores it locally.

    The file is requested from ``{DEFAULT_API_URL}/files/{task_id}`` and the
    raw response body is written to ``file_name``.
    """

    name = "get_task_file_tool"
    description = "This tool downloads the file content associated with the given task_id if exists. Returns absolute file path"
    inputs = {
        "task_id": {"type": "string", "description": "Task id"},
        "file_name": {"type": "string", "description": "File name"},
    }
    output_type = "string"

    def forward(self, task_id: str, file_name: str) -> str:
        """Download the task's file to ``file_name`` and return its absolute path.

        Args:
            task_id: Identifier interpolated into the ``/files/{task_id}`` URL.
            file_name: Path the response body is written to (binary mode).

        Returns:
            ``os.path.abspath(file_name)``.

        Raises:
            requests.RequestException: On connection problems or timeout, or
                (via ``raise_for_status``) on a 4xx/5xx response.
            OSError: If ``file_name`` cannot be opened for writing.
        """
        download_url = f"{DEFAULT_API_URL}/files/{task_id}"
        reply = requests.get(download_url, timeout=15)
        # Fail on an HTTP error status rather than saving the error body as the file.
        reply.raise_for_status()
        with open(file_name, 'wb') as out:
            out.write(reply.content)
        saved_at = os.path.abspath(file_name)
        return saved_at
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
class LoadXlsxFileTool(Tool):
    """Tool that reads an Excel workbook through pandas."""

    name = "load_xlsx_file_tool"
    description = "This tool loads xlsx file into pandas and returns it"
    inputs = {
        "file_path": {"type": "string", "description": "File path"}
    }
    # The result is a pandas object rather than a primitive, hence "object".
    output_type = "object"

    def forward(self, file_path: str) -> object:
        """Return whatever ``pandas.read_excel`` yields for ``file_path``.

        ``read_excel`` is called with its default arguments only; any error it
        raises (missing file, unreadable workbook, ...) propagates to the caller.
        """
        workbook = pd.read_excel(file_path)
        return workbook
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
class LoadTextFileTool(Tool):
    """Tool that reads a whole text file into a single string."""

    name = "load_text_file_tool"
    description = "This tool loads any text file"
    inputs = {
        "file_path": {"type": "string", "description": "File path"}
    }
    output_type = "string"

    def forward(self, file_path: str) -> str:
        """Return the full contents of ``file_path`` decoded as UTF-8.

        The return annotation is ``str`` to agree with ``output_type`` above
        and with what a text-mode ``read()`` actually returns.

        Raises:
            OSError: If the file cannot be opened.
            UnicodeDecodeError: If the file is not valid UTF-8.
        """
        with open(file_path, 'r', encoding='utf-8') as file:
            return file.read()
49
 
 
 
50
 
# `import importlib` on its own does not guarantee that the `resources`
# submodule is loaded; import it explicitly so the attribute access below
# does not depend on some other module having imported it first.
import importlib.resources

# Load the `code_agent.yaml` prompt templates packaged in `smolagents.prompts`.
# The encoding is pinned so parsing does not depend on the platform's locale.
prompts = yaml.safe_load(
    importlib.resources.files("smolagents.prompts")
    .joinpath("code_agent.yaml")
    .read_text(encoding="utf-8")
)

# Prepend the answer-format instructions to the packaged system prompt.
# The literal is kept verbatim: it is runtime text sent to the model.
prompts["system_prompt"] = (
    "You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string. "
    + prompts["system_prompt"]
)
56
+
def init_agent(
    *,
    model_id: str = "gemini-2.0-flash",
    api_base: str = "https://generativelanguage.googleapis.com/v1beta/openai/",
    temperature: float = 0.7,
    max_steps: int = 15,
) -> "CodeAgent":
    """Build the CodeAgent with its model, tools and prompt templates.

    Called with no arguments this behaves exactly as before; the keyword-only
    parameters merely expose values that used to be hard-coded.

    Args:
        model_id: Model identifier passed to ``OpenAIServerModel``.
        api_base: Base URL passed to ``OpenAIServerModel``.
        temperature: Sampling temperature passed to the model.
        max_steps: Step limit passed to ``CodeAgent``.

    Returns:
        The configured ``CodeAgent`` instance.
    """
    gemini_model = OpenAIServerModel(
        model_id=model_id,
        api_base=api_base,
        # Read from the environment; this is None when API_KEY is unset and
        # no check is made here.
        api_key=os.getenv("API_KEY"),
        temperature=temperature,
    )
    tools = [
        DuckDuckGoSearchTool(),
        VisitWebpageTool(),
        WikipediaSearchTool(),
        GetTaskFileTool(),
        SpeechToTextTool(),
        LoadXlsxFileTool(),
        LoadTextFileTool(),
    ]
    return CodeAgent(
        tools=tools,
        model=gemini_model,
        prompt_templates=prompts,
        max_steps=max_steps,
        # pandas is authorized as an extra import for the agent's generated code.
        additional_authorized_imports=["pandas"],
    )