Spaces:
Sleeping
Sleeping
| from smolagents import Tool | |
| from tools.speech_recognition_tool import SpeechRecognitionTool | |
| from transformers import HfInference | |
| from io import BytesIO | |
| import yt_dlp | |
| import av | |
| import subprocess | |
| import requests | |
| import base64 | |
| import tempfile | |
| import re | |
| import os | |
class YouTubeVideoTool(Tool):
    """Answer a query about a YouTube video from its metadata, captions, and frames.

    The video is walked in fixed-length chunks; for each chunk a text prompt is
    built from the title, description, the chunk's captions, and the answer
    accumulated so far, then sent to a hosted text-generation model. Captions
    come from the video's subtitle tracks, with an optional speech-recognition
    fallback when none exist.

    NOTE(review): frames are extracted by `_process` but `_compose_prompt`
    never includes them, so the vision path is currently unused — confirm
    whether frames should be attached to the model request.
    """

    name = 'youtube_video'
    description = 'Extract information from YouTube video content using vision, audio, and captions.'
    inputs = {
        'url': {'type': 'string', 'description': 'YouTube video URL'},
        'query': {'type': 'string', 'description': 'Query about the video content'},
    }
    output_type = 'string'

    def __init__(
        self,
        endpoint_url: str,
        video_quality: int = 360,
        frames_interval: float = 2,
        chunk_duration: float = 2,
        speech_tool: SpeechRecognitionTool | None = None,
        debug: bool = False,
        **kwargs,
    ):
        """
        Args:
            endpoint_url: URL of the hosted text-generation inference endpoint.
            video_quality: Target video height in pixels used for format selection.
            frames_interval: Minimum seconds between two captured frames.
            chunk_duration: Length in seconds of each processed chunk.
            speech_tool: Optional ASR fallback used when the video has no captions.
            debug: Stored flag; not otherwise read in this class.
        """
        self.video_quality = video_quality
        self.frames_interval = frames_interval
        self.chunk_duration = chunk_duration
        self.speech_tool = speech_tool
        self.debug = debug
        # NOTE(review): `HfInference` is imported from `transformers` at the top
        # of this file; the usual HF client is `huggingface_hub.InferenceClient`
        # — verify this import actually resolves and exposes `.text_generation`.
        self.client = HfInference(endpoint_url=endpoint_url)
        super().__init__(**kwargs)

    def forward(self, url: str, query: str) -> str:
        """Iteratively refine an answer to `query` over successive video chunks.

        Each chunk's prompt carries the previous answer forward so the model can
        revise it; the final answer wins.
        """
        full_answer = ''
        for chunk in self._split_video(url):
            prompt = self._compose_prompt(chunk, query, full_answer)
            resp = self.client.text_generation(
                prompt, model='mistralai/Mistral-7B-Instruct-v0.1', max_new_tokens=512
            )
            full_answer = resp.generated_text.strip()
        # Sentinel the model is expected to emit while it has no answer yet;
        # map it to the empty string for callers.
        return full_answer if full_answer != 'I need to keep watching.' else ''

    def _split_video(self, url):
        """Yield consecutive `chunk_duration`-second chunk dicts for the video."""
        video = self._process(url)
        # yt-dlp reports duration=None for some entries (e.g. live streams);
        # treat that as zero length rather than raising on the comparison.
        dur = video['duration'] or 0
        start = 0
        while start < dur:
            end = min(start + self.chunk_duration, dur)
            yield self._chunk(video, start, end)
            start += self.chunk_duration

    def _chunk(self, video, start, end):
        """Slice captions and frames overlapping [start, end] into a chunk dict."""
        # A caption belongs to the chunk if its interval overlaps the window.
        caps = [c for c in video['captions'] if c['start'] <= end and c['end'] >= start]
        frames = [f for f in video['frames'] if start <= f['timestamp'] <= end]
        return {
            'title': video['title'],
            'description': video['description'],
            'start': start,
            'end': end,
            'captions': '\n'.join(c['text'] for c in caps),
            'frames': frames,
        }

    def _compose_prompt(self, chunk, query, previous):
        """Build the text prompt for one chunk, threading in the prior answer."""
        parts = [
            f"VIDEO TITLE:\n{chunk['title']}",
            f"DESCRIPTION:\n{chunk['description']}",
            f"CAPTIONS:\n{chunk['captions']}",
        ]
        if previous:
            parts.append(f"PRIOR ANSWER:\n{previous}")
        parts.append(f"QUESTION:\n{query}")
        return "\n\n".join(parts)

    def _process(self, url):
        """Fetch metadata, captions, and sampled frames for the video at `url`."""
        info = self._get_info(url)
        captions = self._get_captions(info)
        frames = self._get_frames(info)
        return {
            'id': info['id'],
            'title': info['title'],
            'description': info['description'],
            'duration': info['duration'],
            'captions': captions,
            'frames': frames,
        }

    def _get_info(self, url):
        """Return the yt-dlp info dict for `url` without downloading media."""
        ydl_opts = {
            'quiet': True,
            'skip_download': True,
            'format': f"bestvideo[height<={self.video_quality}]+bestaudio/best",
            'forceurl': True,
        }
        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            return ydl.extract_info(url, download=False)

    def _get_captions(self, info):
        """Return caption dicts; fall back to speech recognition when none exist."""
        lang = 'en'
        caps = self._extract_captions(lang, info.get('subtitles', {}), info.get('automatic_captions', {}))
        if not caps and self.speech_tool:
            audio_url = self._select_audio_format(info['formats'])
            audio = self._capture_audio(audio_url)
            # The speech tool takes a file path, so spill the WAV to disk
            # and always clean it up afterwards.
            with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as f:
                f.write(audio.read())
                f.flush()
                path = f.name
            try:
                txt = self.speech_tool(audio=path, with_time_markers=True)
                return self._parse_transcript(txt)
            finally:
                os.remove(path)
        return caps

    def _parse_transcript(self, raw):
        """Parse an ASR transcript of `[start] / text / [end]` sections into caption dicts.

        The closing time marker of one segment is the opening marker of the
        next, so it is matched with a lookahead instead of being consumed —
        a consuming match silently dropped every other segment. The text after
        the final marker (no closing marker) is skipped, as before.
        """
        chunks = []
        marker = r'\[(\d+\.\d+)\]\n(.+?)\n(?=\[(\d+\.\d+)\])'
        for match in re.finditer(marker, raw, re.DOTALL):
            s, t, e = match.groups()
            chunks.append({'start': float(s), 'end': float(e), 'text': t.strip()})
        return chunks

    def _extract_captions(self, lang, subs, auto):
        """Download and parse the first SRT or VTT track for `lang`.

        Manual subtitles are preferred over automatic captions. Returns a list
        of {'start', 'end', 'text'} dicts, or [] when no usable track exists.
        """
        # pysrt/webvtt are only needed on this path, hence the lazy import.
        import pysrt, webvtt
        from io import StringIO

        def to_sec(t):
            # pysrt SubRipTime components -> seconds.
            return t.hours * 3600 + t.minutes * 60 + t.seconds + t.milliseconds / 1000

        def from_srt(srt_url):
            resp = requests.get(srt_url)
            return [{
                'start': to_sec(sub.start),
                'end': to_sec(sub.end),
                'text': sub.text.strip(),
            } for sub in pysrt.from_string(resp.text)]

        def from_vtt(vtt_url):
            def vtt_to_sec(ts):
                # "HH:MM:SS.mmm" -> seconds.
                h, m, s = ts.split(':')
                s, ms = s.split('.')
                return int(h) * 3600 + int(m) * 60 + int(s) + int(ms) / 1000
            resp = requests.get(vtt_url)
            out = []
            for c in webvtt.read_buffer(StringIO(resp.text)):
                out.append({'start': vtt_to_sec(c.start), 'end': vtt_to_sec(c.end), 'text': c.text.strip()})
            return out

        cap_track = subs.get(lang) or auto.get(lang) or []
        for track in cap_track:
            if track['ext'] == 'srt':
                return from_srt(track['url'])
            if track['ext'] == 'vtt':
                return from_vtt(track['url'])
        return []

    def _get_frames(self, info):
        """Sample frames from the best-matching video format."""
        video_url = self._select_video_format(info['formats'])['url']
        return self._extract_frames(video_url)

    def _extract_frames(self, url):
        """Remux the stream to a local file and sample a frame every `frames_interval` seconds.

        Raises:
            RuntimeError: if ffmpeg fails to fetch/remux the stream.
        """
        with tempfile.NamedTemporaryFile(suffix='.mkv', delete=False) as tmp:
            path = tmp.name
        try:
            # Remux (no re-encode) to a seekable local file for decoding.
            proc = subprocess.run(
                ['ffmpeg', '-y', '-i', url, '-f', 'matroska', path],
                stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL,
            )
            if proc.returncode != 0:
                # Fail loudly here instead of with a confusing av.open error.
                raise RuntimeError('Frame capture failed')
            container = av.open(path)
            try:
                stream = container.streams.video[0]
                tb = stream.time_base
                frames = []
                next_t = 0.0
                for frame in container.decode(stream):
                    if frame.pts is None:
                        continue
                    ts = float(frame.pts * tb)
                    if ts >= next_t:
                        frames.append({'timestamp': ts, 'image': frame.to_image()})
                        next_t += self.frames_interval
                return frames
            finally:
                container.close()
        finally:
            # Temp file was created with delete=False; always remove it.
            os.remove(path)

    def _select_video_format(self, formats):
        """Pick the video format whose height best matches `video_quality`.

        Prefers an exact match, otherwise the largest height not exceeding the
        target — mirroring the `height<=` selector used in `_get_info`, which
        previously made an exact-equality check fail spuriously.

        Raises:
            ValueError: if no video format fits under the target height.
        """
        candidates = [
            f for f in formats
            if f.get('vcodec') != 'none'
            and f.get('height') is not None
            and f['height'] <= self.video_quality
        ]
        if not candidates:
            raise ValueError('No matching video format found')
        return max(candidates, key=lambda f: f['height'])

    def _select_audio_format(self, formats):
        """Return the URL of the best audio-only format (highest bitrate; m4a wins ties).

        yt-dlp may report `abr` as present-but-None, so it is coerced to 0
        before negation (the bare `-f.get('abr', 0)` raised TypeError then).

        Raises:
            ValueError: if the info dict contains no audio-only format.
        """
        audio_formats = [
            f for f in formats
            if f.get('vcodec') == 'none' and f.get('acodec') != 'none'
        ]
        if not audio_formats:
            raise ValueError('No audio format found')
        audio_formats.sort(key=lambda f: (-(f.get('abr') or 0), f.get('ext') != 'm4a'))
        return audio_formats[0]['url']

    def _capture_audio(self, audio_url):
        """Fetch `audio_url` via ffmpeg as 16 kHz mono 16-bit WAV in a BytesIO.

        Raises:
            RuntimeError: if ffmpeg exits non-zero.
        """
        cmd = [
            'ffmpeg', '-i', audio_url,
            '-f', 'wav', '-acodec', 'pcm_s16le', '-ac', '1', '-ar', '16000', '-',
        ]
        proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        if proc.returncode != 0:
            raise RuntimeError('Audio capture failed')
        # BytesIO(initial_bytes) starts positioned at 0, ready for .read().
        return BytesIO(proc.stdout)