# File size: 4,118 bytes (revision 9de0414)
from smolagents import Tool
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline, logging
import warnings
class SpeechRecognitionTool(Tool):
    """Tool that transcribes an audio file with a Whisper speech-recognition pipeline.

    The pipeline (model ``openai/whisper-large-v3-turbo``) is built the first
    time the class is instantiated and stored on the class as ``pipe``, so
    later instances reuse the already-loaded model instead of loading it again.
    """

    name = 'speech_to_text'
    description = '''Transcribes speech from audio'''
    inputs = {
        'audio': {
            'type': 'string',
            'description': 'Path to the audio file to transcribe.',
        },
        'with_time_markers': {
            'type': 'boolean',
            'description': 'Whether to include timestamps in the transcription output. Each timestamp appears on its own line in the format [float], indicating the number of seconds elapsed from the start of the audio.',
            'nullable': True,
            'default': False,
        },
    }
    output_type = 'string'
    # Length in seconds of the audio windows handed to the pipeline; also the
    # amount added when re-basing timestamps in _normalize_chunks.
    chunk_length_s = 30

    def __new__(cls, *args, **kwargs):
        """Create the instance, building the shared pipeline on first use."""
        # Loading the model is expensive, so do it once per class rather than
        # once per instance.
        if getattr(cls, 'pipe', None) is None:
            cls.pipe = cls._build_pipeline()
        # NOTE(review): arguments are forwarded unchanged. If no base class
        # defines __new__, object.__new__ rejects extra arguments -- confirm
        # the tool is only ever constructed without arguments.
        return super().__new__(cls, *args, **kwargs)

    @classmethod
    def _build_pipeline(cls):
        """Load the Whisper model and processor and wrap them in an ASR pipeline.

        Uses ``cuda:0`` with float16 when CUDA is available, otherwise the CPU
        with float32.

        Returns:
            The ``automatic-speech-recognition`` pipeline, configured with
            ``chunk_length_s`` windows and ``return_timestamps=True``.
        """
        cuda_available = torch.cuda.is_available()
        device = 'cuda:0' if cuda_available else 'cpu'
        torch_dtype = torch.float16 if cuda_available else torch.float32

        model_id = 'openai/whisper-large-v3-turbo'
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id,
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=True,
            use_safetensors=True,
        )
        model.to(device)
        processor = AutoProcessor.from_pretrained(model_id)

        # `logging` here is transformers.logging: only errors are reported.
        logging.set_verbosity_error()
        # Silence the deprecation notice about the pipeline's "inputs" argument.
        warnings.filterwarnings(
            'ignore',
            category=FutureWarning,
            message=r'.*The input name "inputs" is deprecated.*',
        )

        return pipeline(
            'automatic-speech-recognition',
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            torch_dtype=torch_dtype,
            device=device,
            chunk_length_s=cls.chunk_length_s,
            return_timestamps=True,
        )

    def forward(self, audio: str, with_time_markers: bool = False) -> str:
        '''
        Transcribes speech from audio.

        Args:
            audio (str): Path to the audio file to transcribe.
            with_time_markers (bool): Whether to include timestamps in the transcription output. Each timestamp appears on its own line in the format [float], indicating the number of seconds elapsed from the start of the audio.

        Returns:
            str: The transcribed text. With time markers, every non-empty chunk
            is rendered as three lines: "[start]", the chunk text, "[end]".
        '''
        result = self.pipe(audio)
        if not with_time_markers:
            return result['text'].strip()

        segments = self._normalize_chunks(result.get('chunks') or [])
        # Each segment's text is already stripped and non-empty, so joining
        # with newlines needs no further trimming.
        return '\n'.join(
            f"[{segment['start']:0.2f}]\n{segment['text']}\n[{segment['end']:0.2f}]"
            for segment in segments
        )

    def _normalize_chunks(self, chunks):
        """Re-base chunk timestamps that restart and drop chunks without text.

        Whenever a timestamp is smaller than the one before it (a chunk start
        below the previous chunk's end, or a chunk end below its own start),
        the timestamps are treated as having restarted and ``chunk_length_s``
        seconds are added to the running offset.
        NOTE(review): this presumes a restart always corresponds to exactly one
        ``chunk_length_s`` window -- confirm against the pipeline's output.

        Args:
            chunks: Iterable of dicts with a ``'timestamp'`` pair
                ``(start, end)`` and a ``'text'`` string. A missing (``None``)
                end falls back to the chunk's start, and a missing start falls
                back to the previous chunk's end.

        Returns:
            list[dict]: One ``{'start', 'end', 'text'}`` dict per chunk whose
            stripped text is non-empty, in input order.
        """
        window_s = self.chunk_length_s
        window_base = 0.0   # seconds added to raw timestamps after restarts
        previous_end = 0.0  # raw end timestamp of the previous chunk
        normalized = []

        for chunk in chunks:
            raw_start, raw_end = chunk['timestamp']
            # The pipeline may report None when it could not predict a
            # timestamp (e.g. audio cut off mid-word); without a fallback the
            # comparisons and the float formatting would raise TypeError.
            if raw_start is None:
                raw_start = previous_end
            if raw_end is None:
                raw_end = raw_start

            if raw_start < previous_end:
                window_base += window_s
            absolute_start = window_base + raw_start

            if raw_end < raw_start:
                window_base += window_s
            absolute_end = window_base + raw_end
            previous_end = raw_end

            text = chunk['text'].strip()
            if text:
                normalized.append(
                    {
                        'start': absolute_start,
                        'end': absolute_end,
                        'text': text,
                    }
                )
        return normalized
# Manual smoke test (uncomment the lines below to run it):
# speech_to_text = SpeechRecognitionTool()
# transcription = speech_to_text(
#     audio="https://agents-course-unit4-scoring.hf.space/files/99c9cc74-fdc8-46c6-8f8d-3ce2d3bfeea3",
#     with_time_markers=True,
# )
# print(transcription)