from smolagents import Tool
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline, logging
import warnings


class SpeechRecognitionTool(Tool):
    name = "speech_to_text"
    description = """Transcribes speech from audio."""

    inputs = {
        "audio": {
            "type": "string",
            "description": "Path to the audio file to transcribe.",
        },
        "with_time_markers": {
            "type": "boolean",
            "description": "Whether to include timestamps in the transcription output. Each timestamp appears on its own line in the format [float], indicating the number of seconds elapsed from the start of the audio.",
            "nullable": True,
            "default": False,
        },
    }
    output_type = "string"

    # Window length (in seconds) the pipeline uses for chunked long-form decoding.
    chunk_length_s = 30

    def __new__(cls, *args, **kwargs):
        # Build the ASR pipeline once and cache it on the class, so repeated
        # instantiations do not reload the large Whisper checkpoint.
        if not hasattr(cls, "pipe"):
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
            torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

            model_id = "openai/whisper-large-v3-turbo"
            model = AutoModelForSpeechSeq2Seq.from_pretrained(
                model_id,
                torch_dtype=torch_dtype,
                low_cpu_mem_usage=True,
                use_safetensors=True,
            )
            model.to(device)
            processor = AutoProcessor.from_pretrained(model_id)

            # Silence the FutureWarning the pipeline emits about the legacy
            # `inputs` argument name.
            logging.set_verbosity_error()
            warnings.filterwarnings(
                "ignore",
                category=FutureWarning,
                message=r".*The input name `inputs` is deprecated.*",
            )
            cls.pipe = pipeline(
                "automatic-speech-recognition",
                model=model,
                tokenizer=processor.tokenizer,
                feature_extractor=processor.feature_extractor,
                torch_dtype=torch_dtype,
                device=device,
                chunk_length_s=cls.chunk_length_s,
                return_timestamps=True,
            )

        # object.__new__() rejects extra arguments when __new__ is overridden;
        # any constructor arguments are handled by Tool.__init__ instead.
        return super().__new__(cls)

    def forward(self, audio: str, with_time_markers: bool = False) -> str:
        """
        Transcribes speech from audio.

        Args:
            audio (str): Path to the audio file to transcribe.
            with_time_markers (bool): Whether to include timestamps in the transcription output. Each timestamp appears on its own line in the format [float], indicating the number of seconds elapsed from the start of the audio.

        Returns:
            str: The transcribed text.
        """
        result = self.pipe(audio)
        if not with_time_markers:
            return result["text"].strip()

        # Interleave each chunk's text with absolute [start] / [end] markers.
        txt = ""
        for chunk in self._normalize_chunks(result["chunks"]):
            txt += f"[{chunk['start']:.2f}]\n{chunk['text']}\n[{chunk['end']:.2f}]\n"
        return txt.strip()

    def transcribe(self, audio, **kwargs):
        # Helper that returns normalized chunk dicts instead of a flat string.
        result = self.pipe(audio, **kwargs)
        return self._normalize_chunks(result["chunks"])

    def _normalize_chunks(self, chunks):
        # The pipeline reports timestamps relative to the current decoding
        # window, so they reset every `chunk_length_s` seconds. Detect each
        # wrap-around and convert the timestamps to absolute offsets from the
        # start of the audio.
        chunk_length_s = self.chunk_length_s
        absolute_offset = 0.0
        chunk_offset = 0.0
        normalized = []

        for chunk in chunks:
            timestamp_start = chunk["timestamp"][0]
            timestamp_end = chunk["timestamp"][1]
            # A start timestamp below the previous one means the decoding
            # window rolled over; advance the offset by one window.
            if timestamp_start < chunk_offset:
                absolute_offset += chunk_length_s
            chunk_offset = timestamp_start
            absolute_start = absolute_offset + timestamp_start

            # The window can also roll over in the middle of a chunk, in
            # which case the end timestamp is smaller than the start.
            if timestamp_end < timestamp_start:
                absolute_offset += chunk_length_s
            absolute_end = absolute_offset + timestamp_end
            chunk_offset = timestamp_end

            # Drop chunks that contain no text.
            chunk_text = chunk["text"].strip()
            if chunk_text:
                normalized.append(
                    {
                        "start": absolute_start,
                        "end": absolute_end,
                        "text": chunk_text,
                    }
                )

        return normalized
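
# A minimal usage sketch (assumes a local "sample.wav" exists; a smolagents
# agent would normally invoke the tool through its `name`/`inputs` schema
# rather than calling forward() directly). With with_time_markers=True the
# output interleaves [start] / [end] second markers with the text, e.g.:
#   [0.00]
#   Hello world.
#   [2.53]
if __name__ == "__main__":
    tool = SpeechRecognitionTool()
    print(tool.forward("sample.wav", with_time_markers=True))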