| from smolagents import Tool
|
| from openai import OpenAI
|
| from tools.speech_recognition_tool import SpeechRecognitionTool
|
| from io import BytesIO
|
| import yt_dlp
|
| import av
|
| import torchaudio
|
| import subprocess
|
| import requests
|
| import base64
|
| import tempfile
|
| import re
|
| import os
|
|
|
class YouTubeVideoTool(Tool):
    """Answer a question about a YouTube video by walking through it in chunks.

    The video metadata, captions (or a speech-recognition transcript when no
    captions exist) and sampled frames are gathered once. The video is then
    split into fixed-length chunks; each chunk is sent to the language model
    together with the answer aggregated so far, and the final aggregated
    answer is returned.
    """

    name = 'youtube_video'
    description = '''Process the video and return the requested information from it.'''
    inputs = {
        'url': {
            'type': 'string',
            'description': 'The URL of the YouTube video.',
        },
        'query': {
            'type': 'string',
            'description': 'The question to answer',
        },
    }
    output_type = 'string'

    def __init__(
        self,
        video_quality: int = 360,
        frames_interval: int | float | None = 2,
        chunk_duration: int | float | None = 2,
        speech_recognition_tool: SpeechRecognitionTool | None = None,
        client: OpenAI | None = None,
        model_id: str = 'gpt-4.1-mini',
        debug: bool = False,
        **kwargs,
    ):
        """Configure the tool.

        Args:
            video_quality: Preferred frame height (in pixels) of the video
                stream used for frame extraction.
            frames_interval: Seconds between sampled frames; ``None`` keeps
                every decoded frame.
            chunk_duration: Length in seconds of each chunk sent to the
                model; ``None`` processes the whole video as one chunk.
            speech_recognition_tool: Optional tool used to transcribe the
                audio track when the video has no captions.
            client: OpenAI client; a default ``OpenAI()`` is created when
                omitted.
            model_id: Identifier of the model used to answer the query.
            debug: When true, print every prompt and model answer.
            **kwargs: Forwarded to ``Tool.__init__``.
        """
        self.video_quality = video_quality
        self.speech_recognition_tool = speech_recognition_tool
        self.frames_interval = frames_interval
        self.chunk_duration = chunk_duration

        self.client = client or OpenAI()
        self.model_id = model_id

        self.debug = debug

        super().__init__(**kwargs)

    def forward(self, url: str, query: str):
        """Process the video and return the requested information.

        Args:
            url(str): The URL of the YouTube video.
            query(str): The question to answer.

        Returns:
            str: Answer to the query; empty when no chunk produced one.
        """
        answer = ''
        for chunk in self._split_video_into_chunks(url):
            prompt = self._prompt(chunk, query, answer)

            # The prompt talks about frames, so attach the frames of this
            # chunk as inline JPEG images next to the text.
            content = [{'type': 'input_text', 'text': prompt}]
            content.extend(
                {
                    'type': 'input_image',
                    'image_url': f'data:image/jpeg;base64,{encoded}',
                }
                for encoded in self._base64_frames(chunk['frames'])
            )

            response = self.client.responses.create(
                model=self.model_id,
                input=[{'role': 'user', 'content': content}],
            )
            reply = response.output_text
            if self.debug:
                print(
                    f"CHUNK {chunk['start']} - {chunk['end']}:\n\n{prompt}\n\nANSWER:\n{reply}"
                )
            # A deferral carries no information: keep whatever was already
            # aggregated from earlier chunks instead of discarding it.
            if not self._is_deferral(reply):
                answer = reply
        return answer

    @staticmethod
    def _is_deferral(answer):
        """Return True when the model replied with the keep-watching sentinel.

        The prompt asks for the phrase without a trailing period, but models
        frequently add punctuation or quotes, so the comparison ignores
        surrounding quotes, trailing '.'/'!' and letter case.
        """
        normalized = (answer or '').strip().strip('"\'').rstrip('.!').strip()
        return normalized.casefold() == 'i need to keep watching'

    def _prompt(self, chunk, query, aggregated_answer):
        """Build the text prompt for one chunk.

        Args:
            chunk: Chunk dict as produced by ``_get_video_chunk``.
            query: The user question.
            aggregated_answer: Answer built from earlier chunks; the
                corresponding section is omitted when it is empty.

        Returns:
            str: Prompt sections joined by blank lines.
        """
        sections = [
            'These are some frames of a YouTube video.\n'
            "I will ask a question about the entire video, but you'll only see part of it at a time.\n"
            'Aggregate answer about the entire video, use information about previous parts but do not reference the previous parts in the answer directly.\n'
            '\n'
            'Ground your answer based on video title, description, captions, vide frames or answer from previous parts.\n'
            'If no evidences presented just say "I need to keep watching".\n'
            '\n'
            'VIDEO TITLE:\n'
            f'{chunk["title"]}\n'
            '\n'
            'VIDEO DESCRIPTION:\n'
            f'{chunk["description"]}\n'
            '\n'
            'FRAMES SUBTITLES:\n'
            f'{chunk["captions"]}'
        ]

        if aggregated_answer:
            sections.append(
                'Here is the answer to the same question based on the previous video parts:\n'
                '\n'
                'BASED ON PREVIOUS PARTS:\n'
                f'{aggregated_answer}'
            )

        sections.append(f'QUESTION:\n{query}')

        return "\n\n".join(sections)

    def _split_video_into_chunks(
        self, url: str, with_captions: bool = True, with_frames: bool = True
    ):
        """Yield consecutive chunks covering the whole video.

        Args:
            url: The URL of the YouTube video.
            with_captions: Whether to collect captions / transcript.
            with_frames: Whether to download the video and sample frames.

        Yields:
            dict: One chunk per ``chunk_duration`` seconds (see
            ``_get_video_chunk``). Nothing is yielded for a zero-length video.
        """
        video = self._process_video(
            url, with_captions=with_captions, with_frames=with_frames
        )
        video_duration = video['duration']
        if not video_duration:
            # Duration can be missing from the metadata; fall back to the
            # latest timestamp found in the collected captions and frames.
            marks = [c['end'] for c in video['captions']]
            marks.extend(f['timestamp'] for f in video['frames'])
            video_duration = max(marks, default=0)
            video = {**video, 'duration': video_duration}

        chunk_duration = self.chunk_duration
        if not chunk_duration or chunk_duration <= 0:
            # None / 0 / negative: treat the whole video as a single chunk
            # (a negative step would otherwise never terminate).
            chunk_duration = video_duration

        index = 0
        while True:
            # Multiply instead of accumulating to avoid float drift.
            chunk_start = float(index * chunk_duration)
            if chunk_start >= video_duration:
                break
            chunk_end = min(chunk_start + chunk_duration, video_duration)
            yield self._get_video_chunk(video, chunk_start, chunk_end)
            index += 1

    def _get_video_chunk(self, video, start, end):
        """Extract the captions and frames that belong to ``[start, end]``.

        Captions are included when they overlap the interval. Frames use a
        half-open interval ``[start, end)`` so that a frame sitting exactly on
        a chunk boundary is not sent twice; the final chunk additionally keeps
        every frame from ``end`` onward.

        Args:
            video: Video dict as returned by ``_process_video``.
            start: Chunk start in seconds.
            end: Chunk end in seconds.

        Returns:
            dict: ``title``, ``description``, ``start``, ``end``, ``captions``
            (newline-joined text) and ``frames`` (list of frame dicts).
        """
        duration = video.get('duration')
        is_final = duration is not None and end >= duration

        chunk_captions = [
            c for c in video['captions'] if c['start'] <= end and c['end'] >= start
        ]
        chunk_frames = [
            f
            for f in video['frames']
            if f['timestamp'] >= start and (f['timestamp'] < end or is_final)
        ]

        return {
            'title': video['title'],
            'description': video['description'],
            'start': start,
            'end': end,
            'captions': '\n'.join(c['text'] for c in chunk_captions),
            'frames': chunk_frames,
        }

    def _process_video(
        self, url: str, with_captions: bool = True, with_frames: bool = True
    ):
        """Collect metadata, captions and frames for the video.

        Args:
            url: The URL of the YouTube video.
            with_captions: Collect captions; when none are published and a
                speech recognition tool is configured, transcribe the audio.
            with_frames: Download the video stream and sample frames.

        Returns:
            dict: ``id``, ``title``, ``description``, ``duration``,
            ``captions`` (list of ``start``/``end``/``text`` dicts) and
            ``frames`` (list of ``timestamp``/``image`` dicts).
        """
        lang = 'en'
        info = self._get_video_info(url, lang)

        captions = []
        if with_captions:
            captions = self._extract_captions(
                lang, info.get('subtitles', {}), info.get('automatic_captions', {})
            )
            if not captions and self.speech_recognition_tool:
                captions = self._transcribe_audio(info['formats'])

        frames = []
        if with_frames:
            video_url = self._select_video_format(
                info['formats'], self.video_quality
            )['url']
            frames = self._capture_video_frames(video_url, self.frames_interval)

        return {
            'id': info['id'],
            'title': info.get('title') or '',
            'description': info.get('description') or '',
            'duration': info.get('duration'),
            'captions': captions,
            'frames': frames,
        }

    def _transcribe_audio(self, formats):
        """Transcribe the audio track with the speech recognition tool.

        The audio is written to a temporary WAV file because the tool is
        called with a file path. The transcript is expected to consist of
        ``[start]\\ntext\\n[end]`` segments with float second markers.

        Returns:
            list[dict]: Caption dicts with ``start``, ``end`` and ``text``.
        """
        audio_url = self._select_audio_format(formats)
        audio = self._capture_audio(audio_url)
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
            tmp.write(audio.read())
            tmp_path = tmp.name

        try:
            transcription = self.speech_recognition_tool(
                audio=tmp_path, with_time_markers=True
            )
        finally:
            os.remove(tmp_path)

        pattern = re.compile(r'\[(\d+\.\d+)\]\n(.+?)\n\[(\d+\.\d+)\]', re.DOTALL)
        return [
            {
                'start': float(match.group(1)),
                'end': float(match.group(3)),
                'text': match.group(2).strip(),
            }
            for match in pattern.finditer(transcription)
        ]

    def _get_video_info(self, url: str, lang: str):
        """Fetch video metadata (formats, subtitles, ...) without downloading.

        Args:
            url: The URL of the YouTube video.
            lang: Subtitle language code to request.

        Returns:
            dict: The info dictionary returned by ``yt_dlp``.
        """
        quality = self.video_quality
        ydl_opts = {
            'quiet': True,
            'skip_download': True,
            'format': (
                f'bestvideo[ext=mp4][height<={quality}]+bestaudio[ext=m4a]'
                f'/best[height<={quality}]'
            ),
            "forceurl": True,
            "noplaylist": True,
            "writesubtitles": True,
            "writeautomaticsub": True,
            "subtitlesformat": "vtt",
            "subtitleslangs": [lang],
        }

        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            return ydl.extract_info(url, download=False)

    def _extract_captions(self, lang, subtitles, auto_captions):
        """Download and parse captions for ``lang``.

        Manually uploaded subtitles are preferred over automatic captions,
        and an SRT track is preferred over a VTT track. Only one track is
        parsed so the same cues are never returned twice.

        Args:
            lang: Language code of the wanted captions.
            subtitles: Mapping of language to uploaded subtitle tracks.
            auto_captions: Mapping of language to automatic caption tracks.

        Returns:
            list[dict]: Caption dicts with ``start``, ``end`` (seconds) and
            ``text``; empty when no SRT/VTT track is available.
        """
        caption_tracks = (
            (subtitles or {}).get(lang) or (auto_captions or {}).get(lang) or []
        )

        srt_track = next(
            (track for track in caption_tracks if track.get('ext') == 'srt'), None
        )
        if srt_track:
            return self._load_srt_captions(srt_track['url'])

        vtt_track = next(
            (track for track in caption_tracks if track.get('ext') == 'vtt'), None
        )
        if vtt_track:
            return self._load_vtt_captions(vtt_track['url'])

        return []

    @staticmethod
    def _load_srt_captions(track_url):
        """Download an SRT track and convert it into caption dicts."""
        import pysrt

        # A timeout keeps a stalled caption server from hanging the tool.
        response = requests.get(track_url, timeout=30)
        response.raise_for_status()
        srt_data = response.content.decode('utf-8')

        def to_sec(t):
            return t.hours * 3600 + t.minutes * 60 + t.seconds + t.milliseconds / 1000

        return [
            {
                "start": to_sec(sub.start),
                "end": to_sec(sub.end),
                "text": sub.text.strip(),
            }
            for sub in pysrt.from_string(srt_data)
        ]

    @staticmethod
    def _load_vtt_captions(track_url):
        """Download a WebVTT track and convert it into caption dicts."""
        import webvtt
        from io import StringIO

        response = requests.get(track_url, timeout=30)
        response.raise_for_status()

        def to_sec(t):
            """Convert '[HH:]MM:SS.mmm' to float seconds."""
            parts = t.strip().split(':')
            seconds = float(parts[-1].replace(',', '.'))
            minutes = int(parts[-2]) if len(parts) >= 2 else 0
            hours = int(parts[-3]) if len(parts) >= 3 else 0
            return hours * 3600 + minutes * 60 + seconds

        return [
            {
                'start': to_sec(caption.start),
                "end": to_sec(caption.end),
                "text": caption.text.strip(),
            }
            for caption in webvtt.read_buffer(StringIO(response.text))
        ]

    def _select_video_format(self, formats, video_quality):
        """Pick the video format closest to the requested height.

        The first format with exactly ``video_quality`` lines wins. Otherwise
        the tallest format below that height is used, and as a last resort
        the smallest available one.

        Args:
            formats: Format dicts from the ``yt_dlp`` info dictionary.
            video_quality: Preferred frame height in pixels.

        Returns:
            dict: The selected format.

        Raises:
            ValueError: If no format carries a video stream with a height.
        """
        candidates = [
            f for f in formats if f.get('vcodec') != 'none' and f.get('height')
        ]
        if not candidates:
            # Raised explicitly: a bare next() would leak StopIteration into
            # the calling generator and surface as an opaque RuntimeError.
            raise ValueError('No valid video formats found')

        exact = next((f for f in candidates if f['height'] == video_quality), None)
        if exact is not None:
            return exact

        lower = [f for f in candidates if f['height'] < video_quality]
        if lower:
            return max(lower, key=lambda f: f['height'])
        return min(candidates, key=lambda f: f['height'])

    def _capture_video_frames(self, video_url, capture_interval_sec=None):
        """Download the video with ffmpeg and sample frames from it.

        Args:
            video_url: Direct URL of the video stream.
            capture_interval_sec: Seconds between captured frames. ``None``
                (or a non-positive value) captures every decoded frame.

        Returns:
            list[dict]: Dicts with ``timestamp`` (seconds) and ``image``
            (the result of ``frame.to_image()``).

        Raises:
            RuntimeError: If ffmpeg fails or produces an empty file.
        """
        with tempfile.NamedTemporaryFile(suffix='.mkv', delete=False) as tmp:
            tmp_path = tmp.name

        try:
            ffmpeg_cmd = [
                "ffmpeg",
                "-y",
                "-loglevel", "error",
                "-i", video_url,
                "-f", "matroska",
                tmp_path,
            ]
            result = subprocess.run(
                ffmpeg_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE
            )
            if (
                result.returncode != 0
                or not os.path.exists(tmp_path)
                or os.path.getsize(tmp_path) == 0
            ):
                details = result.stderr.decode(errors='replace').strip()
                raise RuntimeError(
                    f"ffmpeg failed to create a valid video file for {video_url}"
                    + (f":\n{details}" if details else "")
                )

            every_frame = capture_interval_sec is None or capture_interval_sec <= 0
            frames = []
            next_capture_time = 0
            # The context manager closes the container even if decoding fails.
            with av.open(tmp_path) as container:
                stream = container.streams.video[0]
                time_base = stream.time_base
                for frame in container.decode(stream):
                    if frame.pts is None:
                        continue
                    timestamp = float(frame.pts * time_base)
                    if not every_frame:
                        if timestamp < next_capture_time:
                            continue
                        # Skip every slot this frame already covers so an
                        # offset first timestamp does not trigger a burst of
                        # consecutive captures.
                        while next_capture_time <= timestamp:
                            next_capture_time += capture_interval_sec
                    frames.append(
                        {
                            'timestamp': timestamp,
                            'image': frame.to_image(),
                        }
                    )
            return frames
        finally:
            try:
                os.remove(tmp_path)
            except FileNotFoundError:
                pass

    def _base64_frames(self, frames):
        """Encode the frame images as base64 JPEG strings.

        Args:
            frames: Frame dicts whose ``image`` supports
                ``save(buffer, format='JPEG')``.

        Returns:
            list[str]: One base64 string per frame, in the same order.
        """
        base64_frames = []
        for f in frames:
            buffered = BytesIO()
            f['image'].save(buffered, format='JPEG')
            base64_frames.append(base64.b64encode(buffered.getvalue()).decode('utf-8'))
        return base64_frames

    def _select_audio_format(self, formats):
        """Return the URL of the preferred audio-only format.

        Formats are ranked by container (``m4a``, then ``webm``, then the
        rest) and, within a container, by descending audio bitrate.

        Args:
            formats: Format dicts from the ``yt_dlp`` info dictionary.

        Returns:
            str: URL of the selected audio stream.

        Raises:
            ValueError: If there is no audio-only format.
        """
        audio_formats = [
            f
            for f in formats
            if f.get('vcodec') == 'none'
            and f.get('acodec')
            and f.get('acodec') != 'none'
        ]

        if not audio_formats:
            raise ValueError('No valid audio-only formats found')

        preferred_exts = ['m4a', 'webm']

        def sort_key(f):
            ext = f.get('ext')
            ext_score = preferred_exts.index(ext) if ext in preferred_exts else 99
            abr = f.get('abr') or 0
            return (ext_score, -abr)

        # min() returns the first best entry, like a stable sort would.
        return min(audio_formats, key=sort_key)['url']

    def _capture_audio(self, audio_url) -> BytesIO:
        """Download the audio stream as 16 kHz mono 16-bit PCM WAV.

        Args:
            audio_url: Direct URL of the audio stream.

        Returns:
            BytesIO: Buffer positioned at the start of the WAV data.

        Raises:
            RuntimeError: If ffmpeg exits with a non-zero status.
        """
        ffmpeg_audio_cmd = [
            "ffmpeg",
            "-i",
            audio_url,
            "-f",
            "wav",
            "-acodec",
            "pcm_s16le",
            "-ac",
            "1",
            "-ar",
            "16000",
            "-",
        ]

        result = subprocess.run(
            ffmpeg_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
        )
        if result.returncode != 0:
            raise RuntimeError(
                'ffmpeg failed:\n' + result.stderr.decode(errors='replace')
            )

        return BytesIO(result.stdout)