Muksia commited on
Commit
27884a0
·
verified ·
1 Parent(s): b54694c

Create agent.py

Browse files
Files changed (1) hide show
  1. agent.py +79 -0
agent.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ import os
3
+
4
+ import requests
5
+ import yaml
6
+ import pandas as pd
7
+
8
+ from config import DEFAULT_API_URL
9
+ from smolagents import CodeAgent, DuckDuckGoSearchTool, VisitWebpageTool, WikipediaSearchTool, Tool, OpenAIServerModel, SpeechToTextTool
10
+
11
+ class GetTaskFileTool(Tool):
12
+ name = "get_task_file_tool"
13
+ description = """This tool downloads the file content associated with the given task_id if exists. Returns absolute file path"""
14
+ inputs = {
15
+ "task_id": {"type": "string", "description": "Task id"},
16
+ "file_name": {"type": "string", "description": "File name"},
17
+ }
18
+ output_type = "string"
19
+
20
+ def forward(self, task_id: str, file_name: str) -> str:
21
+ response = requests.get(f"{DEFAULT_API_URL}/files/{task_id}", timeout=15)
22
+ response.raise_for_status()
23
+ with open(file_name, 'wb') as file:
24
+ file.write(response.content)
25
+ return os.path.abspath(file_name)
26
+
27
+ class LoadXlsxFileTool(Tool):
28
+ name = "load_xlsx_file_tool"
29
+ description = """This tool loads xlsx file into pandas and returns it"""
30
+ inputs = {
31
+ "file_path": {"type": "string", "description": "File path"}
32
+ }
33
+ output_type = "object"
34
+
35
+ def forward(self, file_path: str) -> object:
36
+ return pd.read_excel(file_path)
37
+
38
+ class LoadTextFileTool(Tool):
39
+ name = "load_text_file_tool"
40
+ description = """This tool loads any text file"""
41
+ inputs = {
42
+ "file_path": {"type": "string", "description": "File path"}
43
+ }
44
+ output_type = "string"
45
+
46
+ def forward(self, file_path: str) -> object:
47
+ with open(file_path, 'r', encoding='utf-8') as file:
48
+ return file.read()
49
+
50
+
51
+ prompts = yaml.safe_load(
52
+ importlib.resources.files("smolagents.prompts").joinpath("code_agent.yaml").read_text()
53
+ )
54
+ prompts["system_prompt"] = ("You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string. "
55
+ + prompts["system_prompt"])
56
+
57
+ def init_agent():
58
+ gemini_model = OpenAIServerModel(
59
+ model_id="gemini-2.0-flash",
60
+ api_base="https://generativelanguage.googleapis.com/v1beta/openai/",
61
+ api_key=os.getenv("API_KEY"),
62
+ temperature=0.7
63
+ )
64
+ agent = CodeAgent(
65
+ tools=[
66
+ DuckDuckGoSearchTool(),
67
+ VisitWebpageTool(),
68
+ WikipediaSearchTool(),
69
+ GetTaskFileTool(),
70
+ SpeechToTextTool(),
71
+ LoadXlsxFileTool(),
72
+ LoadTextFileTool()
73
+ ],
74
+ model=gemini_model,
75
+ prompt_templates=prompts,
76
+ max_steps=15,
77
+ additional_authorized_imports = ["pandas"]
78
+ )
79
+ return agent