# unit4_test/tools/yt_inspector_tool.py
# Vladyslav Khaitov
# Add new YouTube tools, change audio tool to audio transcriber, improve system prompt
# 4933f00
import base64
import platform
from smolagents import Tool
from smolagents.models import Model, ChatMessage
import yt_dlp
import tempfile
import os
import cv2
class YouTubeVisualInspectorTool(Tool):
    """Answer a question about the visual content of a YouTube video.

    The video is downloaded with ``yt_dlp`` into a temporary directory,
    every ``frame_interval``-th frame is JPEG-encoded with OpenCV, and the
    question plus the base64-encoded frames are sent to ``model`` as a
    single user message.
    """

    name = "youtube_visual_inspector"
    description = """A tool that downloads a YouTube video, extracts frames, and answers a question based on the video content. Use this tool to ask questions about the visual content of a YouTube video."""
    inputs = {
        "youtube_url": {
            "description": "The URL of the YouTube video to analyze.",
            "type": "string",
        },
        "question": {
            "description": "The question to answer about the video.",
            "type": "string",
        },
    }
    output_type = "string"

    def __init__(self, model: Model, frame_interval: int = 25):
        """Store the model and the frame sampling stride.

        Args:
            model: Callable model that receives the list of chat messages.
            frame_interval: Keep one frame out of every ``frame_interval``
                decoded frames. Defaults to 25.

        Raises:
            ValueError: If ``frame_interval`` is not a positive integer.
        """
        super().__init__()
        if not isinstance(frame_interval, int) or frame_interval < 1:
            raise ValueError("`frame_interval` must be a positive integer.")
        self.model = model
        self.frame_interval = frame_interval

    def forward(self, youtube_url: str, question: str) -> str:
        """Download the video, sample frames, and ask the model the question.

        Args:
            youtube_url: URL of the video to analyze.
            question: Question to answer about the video.

        Returns:
            The model's answer converted to ``str``.

        Raises:
            Exception: If an argument is not a string, no frame could be
                extracted, or the model call fails.
        """
        if not isinstance(youtube_url, str) or not isinstance(question, str):
            raise Exception("You should provide both `youtube_url` and `question` string arguments to this tool!")

        # The encoded frames live in memory, so the temporary directory can
        # be discarded as soon as sampling is finished.
        with tempfile.TemporaryDirectory() as tmpdir:
            video_path = self._download_video(youtube_url, tmpdir)
            frames = self._sample_frames(video_path)

        if not frames:
            # Without frames the model would answer with no visual input.
            raise Exception("Video QA failed: no frames could be extracted from the video.")

        content = [{"type": "text", "text": question}]
        content.extend(
            {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{frame}"}}
            for frame in frames
        )
        messages = [ChatMessage(role="user", content=content)]

        try:
            output = self.model(messages).content
        except Exception as e:
            raise Exception("Video QA failed: " + str(e)) from e
        # str() also covers the case where the content is a list.
        return str(output)

    @staticmethod
    def _download_video(youtube_url: str, tmpdir: str) -> str:
        """Download the video as mp4 into ``tmpdir`` and return its path."""
        ydl_opts = {
            'format': 'mp4',
            'outtmpl': os.path.join(tmpdir, '%(id)s.%(ext)s'),
            'quiet': True,
            'noplaylist': True,
        }
        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            info = ydl.extract_info(youtube_url, download=True)
            video_path = ydl.prepare_filename(info)

        if not video_path.endswith('.mp4'):
            # The reported filename is not an mp4; prefer any mp4 that was
            # actually written to the directory, otherwise keep the reported path.
            mp4_name = next(
                (entry for entry in os.listdir(tmpdir) if entry.endswith('.mp4')),
                None,
            )
            if mp4_name is not None:
                video_path = os.path.join(tmpdir, mp4_name)
        return video_path

    def _sample_frames(self, video_path: str) -> list:
        """Return base64 JPEG strings for every ``frame_interval``-th frame."""
        capture = cv2.VideoCapture(video_path)
        frames = []
        try:
            if not capture.isOpened():
                raise Exception("Video QA failed: could not open the downloaded video file.")
            index = 0
            success, image = capture.read()
            while success:
                if index % self.frame_interval == 0:
                    encoded, buffer = cv2.imencode('.jpg', image)
                    # Skip frames that fail to encode instead of sending garbage.
                    if encoded:
                        frames.append(base64.b64encode(buffer.tobytes()).decode('utf-8'))
                success, image = capture.read()
                index += 1
        finally:
            # Always free the decoder handle, even when reading or encoding raises.
            capture.release()
        return frames
class YouTubeAudioTranscriberTool(Tool):
    """Transcribe the audio track of a YouTube video with an audio-capable model.

    The audio is downloaded with ``yt_dlp``, converted to mp3 by the
    ``FFmpegExtractAudio`` post-processor, base64-encoded, and sent to
    ``model`` together with a transcription instruction.
    """

    name = "youtube_audio_transcriber"
    description = """A tool that downloads audio from a YouTube video and transcribes it to text. Use this tool when you need to convert speech or audio content from YouTube videos into written text.
This tool handles various audio formats and provides accurate transcriptions of audio content from YouTube videos."""
    inputs = {
        "youtube_url": {
            "description": "The URL of the YouTube video to download audio from and transcribe.",
            "type": "string",
        },
    }
    output_type = "string"

    def __init__(self, model: Model):
        """Store the model that will receive the audio message."""
        super().__init__()
        self.model = model

    def forward(self, youtube_url: str) -> str:
        """Download the audio of ``youtube_url`` and return its transcription.

        Args:
            youtube_url: URL of the video whose audio should be transcribed.

        Returns:
            The model's output converted to ``str``.

        Raises:
            Exception: If the argument is not a string, no audio file was
                produced, or the model call fails.
        """
        if not isinstance(youtube_url, str):
            raise Exception("You should provide the `youtube_url` string argument to this tool!")

        with tempfile.TemporaryDirectory() as tmpdir:
            audio_path = self._download_audio(youtube_url, tmpdir)
            # Read the file while the temporary directory still exists.
            with open(audio_path, "rb") as audio_file:
                base64_audio = base64.b64encode(audio_file.read()).decode('utf-8')

        # Extension without the leading dot, e.g. "mp3".
        audio_format = os.path.splitext(audio_path)[1].lstrip(".")

        messages = [
            ChatMessage(
                role="user",
                content=[
                    {
                        "type": "text",
                        "text": "Please transcribe this audio file accurately. Provide only the transcribed text without any additional commentary or formatting.",
                    },
                    {
                        "type": "input_audio",
                        "input_audio": {
                            "data": base64_audio,
                            "format": audio_format,
                        },
                    },
                ],
            )
        ]

        try:
            output = self.model(messages).content
        except Exception as e:
            raise Exception("Transcription failed: " + str(e)) from e
        # str() also covers the case where the content is a list of dicts.
        return str(output)

    @staticmethod
    def _download_audio(youtube_url: str, tmpdir: str) -> str:
        """Download the best audio stream into ``tmpdir`` and return the file path."""
        ydl_opts = {
            'format': 'bestaudio/best',
            'outtmpl': os.path.join(tmpdir, '%(id)s.%(ext)s'),
            'quiet': True,
            'noplaylist': True,
            'postprocessors': [{
                'key': 'FFmpegExtractAudio',
                'preferredcodec': 'mp3',
                'preferredquality': '192',
            }],
        }

        # Only point yt_dlp at the Homebrew ffmpeg when it is really there;
        # otherwise leave the option unset so yt_dlp uses its default lookup.
        homebrew_ffmpeg = '/opt/homebrew/bin/ffmpeg'
        if platform.system() == "Darwin" and os.path.isfile(homebrew_ffmpeg):
            ydl_opts['ffmpeg_location'] = homebrew_ffmpeg

        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            info = ydl.extract_info(youtube_url, download=True)
            audio_path = ydl.prepare_filename(info)

        # prepare_filename() can report a non-mp3 name even though the
        # post-processor is configured for mp3, so look for the mp3 on disk.
        if not audio_path.endswith('.mp3'):
            mp3_name = next(
                (entry for entry in os.listdir(tmpdir) if entry.endswith('.mp3')),
                None,
            )
            if mp3_name is not None:
                audio_path = os.path.join(tmpdir, mp3_name)

        if not os.path.isfile(audio_path):
            raise Exception("Transcription failed: no audio file was produced by the download.")
        return audio_path