zaldivards committed on
Commit
e47481b
·
1 Parent(s): 30e4f5b

Add agent tools

Browse files
Files changed (4) hide show
  1. definitions.py +15 -0
  2. requirements.txt +76 -2
  3. tools.py +139 -0
  4. utils.py +109 -0
definitions.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pylint: disable=C0115
2
+ from typing import TypedDict, Literal
3
+
4
+
5
class TranscriptionFile(TypedDict):
    """Location of a transcript, as nested under the ``Transcript`` key."""

    # URI of the transcript JSON document written by the transcription job.
    TranscriptFileUri: str


class _TranscriptionResponseBase(TypedDict):
    """Keys that are read from the job description regardless of job state."""

    # "QUEUED" is included because a freshly started job reports it before
    # moving to "IN_PROGRESS" (see the Amazon Transcribe API reference).
    TranscriptionJobStatus: Literal["QUEUED", "IN_PROGRESS", "COMPLETED", "FAILED"]


class TranscriptionResponse(_TranscriptionResponseBase, total=False):
    """Job description; the keys declared here are optional.

    ``Transcript`` is only useful once the job has completed and
    ``FailureReason`` only when it has failed, so neither is required.
    """

    # Where the transcript was stored (read after a COMPLETED status).
    Transcript: TranscriptionFile
    # Service-provided explanation for a FAILED job.
    FailureReason: str


class TranscriptionJob(TypedDict):
    """Top-level shape of the ``get_transcription_job`` response used by the tools."""

    TranscriptionJob: TranscriptionResponse
requirements.txt CHANGED
@@ -1,2 +1,76 @@
1
- gradio
2
- requests
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==24.1.0 ; python_version >= '3.8'
2
+ annotated-types==0.7.0 ; python_version >= '3.8'
3
+ anyio==4.9.0 ; python_version >= '3.9'
4
+ boto3==1.38.23
5
+ botocore==1.38.23 ; python_version >= '3.9'
6
+ certifi==2025.4.26 ; python_version >= '3.6'
7
+ charset-normalizer==3.4.2 ; python_version >= '3.7'
8
+ click==8.2.1 ; python_version >= '3.10'
9
+ et-xmlfile==2.0.0 ; python_version >= '3.8'
10
+ fastapi==0.115.12 ; python_version >= '3.8'
11
+ ffmpy==0.5.0 ; python_version >= '3.8' and python_version < '4.0'
12
+ filelock==3.18.0 ; python_version >= '3.9'
13
+ fsspec==2025.5.1 ; python_version >= '3.9'
14
+ gradio==5.31.0
15
+ gradio-client==1.10.1 ; python_version >= '3.10'
16
+ groovy==0.1.2 ; python_version >= '3.10'
17
+ h11==0.16.0 ; python_version >= '3.8'
18
+ hf-xet==1.1.2 ; platform_machine == 'x86_64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'aarch64'
19
+ httpcore==1.0.9 ; python_version >= '3.8'
20
+ httpx==0.28.1 ; python_version >= '3.8'
21
+ huggingface-hub==0.32.1 ; python_full_version >= '3.8.0'
22
+ idna==3.10 ; python_version >= '3.6'
23
+ jinja2==3.1.6 ; python_version >= '3.7'
24
+ jmespath==1.0.1 ; python_version >= '3.7'
25
+ jsonpatch==1.33 ; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6'
26
+ jsonpointer==3.0.0 ; python_version >= '3.7'
27
+ langchain-core==0.3.61 ; python_version >= '3.9'
28
+ langgraph==0.4.7
29
+ langgraph-checkpoint==2.0.26 ; python_version >= '3.9'
30
+ langgraph-prebuilt==0.2.1 ; python_version >= '3.9'
31
+ langgraph-sdk==0.1.70 ; python_version >= '3.9'
32
+ langsmith==0.3.42 ; python_version >= '3.9'
33
+ markdown-it-py==3.0.0 ; python_version >= '3.8'
34
+ markupsafe==3.0.2 ; python_version >= '3.9'
35
+ mdurl==0.1.2 ; python_version >= '3.7'
36
+ numpy==2.2.6 ; python_version >= '3.10'
37
+ openpyxl==3.1.5
38
+ orjson==3.10.18 ; python_version >= '3.9'
39
+ ormsgpack==1.10.0 ; python_version >= '3.9'
40
+ packaging==24.2 ; python_version >= '3.8'
41
+ pandas==2.2.3 ; python_version >= '3.9'
42
+ pillow==11.2.1 ; python_version >= '3.9'
43
+ pydantic==2.11.5 ; python_version >= '3.9'
44
+ pydantic-core==2.33.2 ; python_version >= '3.9'
45
+ pydub==0.25.1
46
+ pygments==2.19.1 ; python_version >= '3.8'
47
+ pymupdf==1.26.0
48
+ python-dateutil==2.9.0.post0 ; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'
49
+ python-dotenv==1.1.0
50
+ python-multipart==0.0.20 ; python_version >= '3.8'
51
+ pytz==2025.2
52
+ pyyaml==6.0.2 ; python_version >= '3.8'
53
+ requests==2.32.3
54
+ requests-toolbelt==1.0.0 ; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'
55
+ rich==14.0.0 ; python_full_version >= '3.8.0'
56
+ ruff==0.11.11 ; sys_platform != 'emscripten'
57
+ s3transfer==0.13.0 ; python_version >= '3.9'
58
+ safehttpx==0.1.6 ; python_version >= '3.10'
59
+ semantic-version==2.10.0 ; python_version >= '2.7'
60
+ shellingham==1.5.4 ; python_version >= '3.7'
61
+ six==1.17.0 ; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'
62
+ smolagents==1.16.1
63
+ sniffio==1.3.1 ; python_version >= '3.7'
64
+ starlette==0.46.2 ; sys_platform != 'emscripten'
65
+ tenacity==9.1.2 ; python_version >= '3.9'
66
+ tomlkit==0.13.2 ; python_version >= '3.8'
67
+ tqdm==4.67.1 ; python_version >= '3.7'
68
+ typer==0.16.0 ; sys_platform != 'emscripten'
69
+ typing-extensions==4.13.2 ; python_version >= '3.8'
70
+ typing-inspection==0.4.1 ; python_version >= '3.9'
71
+ tzdata==2025.2 ; python_version >= '2'
72
+ urllib3==2.4.0 ; python_version >= '3.9'
73
+ uvicorn==0.34.2 ; sys_platform != 'emscripten'
74
+ websockets==15.0.1 ; python_version >= '3.9'
75
+ xxhash==3.5.0 ; python_version >= '3.7'
76
+ zstandard==0.23.0 ; python_version >= '3.8'
tools.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pylint: disable=W0718
2
+ import ast
3
+ import json
4
+ import os
5
+ from time import sleep
6
+ from uuid import uuid4
7
+
8
+ import boto3
9
+ import fitz
10
+ from pandas import read_excel
11
+ from smolagents import tool, Tool
12
+
13
+ from definitions import TranscriptionJob
14
+ from utils import get_file, s3_upload_file, s3_download_file
15
+
16
+
17
+ @tool
18
def math_calculator(query: str) -> str:
    """A simple calculator tool that evaluates mathematical expressions.

    Supports numeric literals, parentheses, unary plus/minus and the binary
    operators +, -, *, /, //, % and **. Anything else (names, calls,
    attribute access, strings, ...) is rejected and reported as an error.

    Args:
        query (str): A mathematical expression as a string, e.g., "2 + 2 * 3".
    """
    # Imported locally so the function is self-contained.
    import operator

    binary_ops = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.FloorDiv: operator.floordiv,
        ast.Mod: operator.mod,
        ast.Pow: operator.pow,
    }
    unary_ops = {ast.UAdd: operator.pos, ast.USub: operator.neg}

    def evaluate(node):
        """Recursively evaluate a whitelisted arithmetic AST node."""
        if isinstance(node, ast.Expression):
            return evaluate(node.body)
        if isinstance(node, ast.Constant):
            value = node.value
            # bool is a subclass of int; it is not a number for our purposes.
            if isinstance(value, (int, float, complex)) and not isinstance(value, bool):
                return value
            raise ValueError(f"unsupported constant: {value!r}")
        if isinstance(node, ast.BinOp) and type(node.op) in binary_ops:
            left = evaluate(node.left)
            right = evaluate(node.right)
            # Guard against expressions such as 9 ** 9 ** 9 that would
            # consume unbounded time and memory.
            if isinstance(node.op, ast.Pow) and abs(right) > 1000:
                raise ValueError("exponent too large")
            return binary_ops[type(node.op)](left, right)
        if isinstance(node, ast.UnaryOp) and type(node.op) in unary_ops:
            return unary_ops[type(node.op)](evaluate(node.operand))
        raise ValueError(f"unsupported expression element: {type(node).__name__}")

    try:
        # ast.literal_eval only accepts literals, so it cannot compute
        # "2 + 2 * 3"; parse the expression and walk the tree instead.
        # No eval()/exec() is involved, only the operators listed above.
        tree = ast.parse(query.strip(), mode="eval")
        return str(evaluate(tree))
    except Exception as e:
        return f"Error evaluating expression: {e}"
29
+
30
+
31
+ @tool
32
def excel_reader(task_id: str, file_name: str) -> str:
    """Reads an Excel file and returns its content as a dataframe string.

    Args:
        task_id (str): The ID of the task associated with the file.
        file_name (str): The name of the Excel file to read.
    """
    # NOTE: the docstring above is kept verbatim because the tool decorator
    # applied to this function presumably uses it as the tool description.
    try:
        # The workbook is fetched by task id; file_name only labels errors.
        workbook_bytes = get_file(task_id)
        frame = read_excel(workbook_bytes, engine="openpyxl")
        rendered = frame.to_string(index=False)
    except Exception as err:
        return f"Error reading Excel file {file_name}: {err}"
    return rendered
45
+
46
+
47
def txt_reader(task_id: str, file_name: str) -> str:
    """Reads a text file and returns its content as a string.

    Args:
        task_id (str): The ID of the task associated with the file.
        file_name (str): The name of the file to read.
    """
    try:
        # The file is located by task id; file_name only labels errors.
        raw_bytes = get_file(task_id).read()
        decoded = raw_bytes.decode("utf-8")
    except Exception as err:
        # Download and decoding failures alike are reported as text
        # instead of being raised.
        return f"Error reading file {file_name}: {err}"
    return decoded
59
+
60
+
61
def pdf_reader(task_id: str, file_name: str) -> str:
    """Reads a PDF file and returns its content as a string.

    Args:
        task_id (str): The ID of the task associated with the file.
        file_name (str): The name of the PDF file to read.
    """
    try:
        # The document is fetched by task id; file_name only labels messages.
        file_content = get_file(task_id)
        with fitz.open(stream=file_content.read(), filetype="pdf") as doc:
            # Extract each page exactly once, then drop the pages that
            # yielded no text. The list is built while the document is open.
            extracted = (page.get_text() for page in doc)
            content = [page_text for page_text in extracted if page_text]
        # Strip before the emptiness check so a document containing only
        # whitespace is reported as having no text rather than returning "".
        text = "\n".join(content).strip()
        if not text:
            return f"No text found in PDF file {file_name}."
        return text
    except Exception as e:
        return f"Error reading PDF file {file_name}: {e}"
78
+
79
+
80
class AudioTranscriber(Tool):  # pylint: disable=C0115
    """Agent tool that transcribes a task's audio file with Amazon Transcribe.

    The file is fetched by task id, uploaded to the bucket named by the
    ``SOURCE_BUCKET`` environment variable, transcribed, and the transcript
    JSON is read back from the bucket named by ``TARGET_BUCKET``.
    """

    name = "AudioTranscriber"
    description = "Extract text from audio files, such as MP3, MP4, WAV, etc."
    inputs = {
        "task_id": {
            "type": "string",
            "description": "The ID of the task associated with the audio file.",
        },
        "file_name": {
            "type": "string",
            "description": "The name of the audio file to transcribe.",
        },
    }
    # smolagents tools must declare the type of value ``forward`` returns,
    # alongside name/description/inputs.
    output_type = "string"

    # Polling budget for a transcription job: 180 checks, 5 seconds apart
    # (about 15 minutes) before giving up instead of waiting forever.
    poll_interval_seconds = 5
    max_poll_attempts = 180

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        region = os.getenv("AWS_REGION", "us-east-1")
        self.client = boto3.client("transcribe", region_name=region)

    @staticmethod
    def _build_job_name(file_name: str) -> str:
        """Return a unique job name derived from ``file_name``.

        Characters outside ``[0-9a-zA-Z._-]`` are replaced with ``_`` and the
        result is capped at 200 characters, matching the job-name constraints
        in the Amazon Transcribe API reference. Names that were already valid
        are unchanged.
        """
        stem = file_name.split(".")[0]
        safe_stem = "".join(
            ch if (ch.isascii() and ch.isalnum()) or ch in "._-" else "_" for ch in stem
        )
        return f"{uuid4()}-{safe_stem}"[:200]

    def _transcribe_audio(self, job_name: str, media_uri: str) -> None:
        """Start a transcription job whose output is ``<job_name>.json`` in TARGET_BUCKET."""
        self.client.start_transcription_job(
            TranscriptionJobName=job_name,
            Media={"MediaFileUri": media_uri},
            IdentifyLanguage=True,
            OutputBucketName=os.getenv("TARGET_BUCKET"),
            OutputKey=f"{job_name}.json",
        )

    def _get_transcription(self, job_name: str) -> str:
        """Wait for the job to finish and return the transcript text.

        Raises:
            TimeoutError: The job did not finish within the polling budget.
            RuntimeError: The job finished with status ``FAILED``.
            Exception: Downloading or parsing the transcript file failed.
        """
        for _ in range(self.max_poll_attempts):
            response: TranscriptionJob = self.client.get_transcription_job(TranscriptionJobName=job_name)
            job = response["TranscriptionJob"]
            status = job["TranscriptionJobStatus"]
            if status in ("COMPLETED", "FAILED"):
                break
            sleep(self.poll_interval_seconds)
        else:
            waited = self.max_poll_attempts * self.poll_interval_seconds
            raise TimeoutError(f"Transcription job {job_name} did not finish within {waited} seconds")

        if status == "FAILED":
            # A failed job has no transcript to download; surface the reason
            # reported by the service rather than failing on a missing key.
            reason = job.get("FailureReason", "no failure reason reported")
            raise RuntimeError(f"Transcription job {job_name} failed: {reason}")

        transcript_url = job["Transcript"]["TranscriptFileUri"]
        try:
            # The last URI segment is used as the object key in TARGET_BUCKET.
            bytes_result = s3_download_file(os.getenv("TARGET_BUCKET"), transcript_url.split("/")[-1])
            transcription_data = json.loads(bytes_result.read().decode("utf-8"))
            return transcription_data["transcripts"][0]["transcript"]
        except json.JSONDecodeError as e:
            print(f"Error decoding transcription JSON: {e}")
            raise
        except Exception as e:
            print(f"Error downloading or processing transcription file: {e}")
            raise

    def forward(self, task_id: str, file_name: str) -> str:  # pylint: disable=W0221
        """Transcribe the audio file of ``task_id``; failures are returned as an error string."""
        try:
            file_content = get_file(task_id)
            s3_upload_file(file_content, os.getenv("SOURCE_BUCKET"), file_name)

            media_uri = f"s3://{os.getenv('SOURCE_BUCKET')}/{file_name}"
            job_name = self._build_job_name(file_name)
            self._transcribe_audio(job_name, media_uri)
            return self._get_transcription(job_name)
        except Exception as e:
            return f"Error starting transcription job for {file_name}: {e}"
utils.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from io import BytesIO
3
+
4
+ import boto3
5
+ import requests
6
+ from dotenv import load_dotenv
7
+
8
# Base URL of the service from which task files are downloaded.
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"


# Load variables from a .env file before the environment is read below.
load_dotenv()

# Shared Bedrock runtime client; region and credentials come from the
# environment (unset variables are passed through as None).
bedrock_client = boto3.client(
    service_name="bedrock-runtime",
    aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
    aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
    region_name=os.getenv("AWS_REGION"),
)
19
+
20
+
21
def get_file(task_id: str) -> BytesIO:
    """Download the file attached to a task.

    Parameters
    ----------
    task_id : str
        Identifier of the task whose file should be downloaded.

    Returns
    -------
    BytesIO
        In-memory buffer holding the raw response body.

    Raises
    ------
    requests.exceptions.RequestException
        If the HTTP request fails or the server answers with an error status.
    Exception
        Any other unexpected error; it is logged and re-raised unchanged.
    """
    endpoint = f"{DEFAULT_API_URL}/files/{task_id}"
    try:
        reply = requests.get(endpoint, timeout=15)
        # Turn 4xx/5xx answers into exceptions.
        reply.raise_for_status()
        payload = reply.content
        return BytesIO(payload)
    except requests.exceptions.RequestException as err:
        print(f"Error fetching file for task {task_id}: {err}")
        raise
    except Exception as err:
        print(f"An unexpected error occurred fetching file for task {task_id}: {err}")
        raise
51
+
52
+
53
def s3_upload_file(file_content: BytesIO, bucket_name: str, object_name: str) -> None:
    """Store an in-memory file as an object in an S3 bucket.

    Parameters
    ----------
    file_content : BytesIO
        Buffer whose entire contents are uploaded.
    bucket_name : str
        Destination bucket.
    object_name : str
        Key under which the object is stored.

    Raises
    ------
    Exception
        Any error raised while uploading; it is logged and re-raised.
    """
    try:
        uploader = boto3.client("s3")
        request = {
            "Bucket": bucket_name,
            "Key": object_name,
            # getvalue() returns the whole buffer regardless of read position.
            "Body": file_content.getvalue(),
            "ContentType": "application/octet-stream",
        }
        uploader.put_object(**request)
    except Exception as err:
        print(f"Error uploading file to S3: {err}")
        raise
81
+
82
+
83
def s3_download_file(bucket_name: str, object_name: str) -> BytesIO:
    """Fetch an object from an S3 bucket into memory.

    Parameters
    ----------
    bucket_name : str
        Bucket that holds the object.
    object_name : str
        Key of the object to fetch.

    Returns
    -------
    BytesIO
        Buffer containing the full object body.

    Raises
    ------
    Exception
        Any error raised while downloading; it is logged and re-raised.
    """
    try:
        s3_object = boto3.client("s3").get_object(Bucket=bucket_name, Key=object_name)
        # Read the whole body into memory before wrapping it in a buffer.
        body_bytes = s3_object["Body"].read()
        return BytesIO(body_bytes)
    except Exception as err:
        print(f"Error downloading file from S3: {err}")
        raise