# Final_Assignment_Template / tools/speech_recognition_tool.py
# Last commit a6908e4 by FD900 (verified): "Update tools/speech_recognition_tool.py" (3.23 kB)
from smolagents import Tool
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline, logging
import warnings
class SpeechRecognitionTool(Tool):
    """smolagents tool that transcribes a local audio file with a Whisper ASR pipeline.

    The Hugging Face pipeline is built the first time the class is instantiated
    and cached on the class (``pipe``), so creating further instances does not
    reload the model weights.
    """

    name = 'speech_to_text'
    description = 'Transcribes spoken audio to text with optional time markers.'
    inputs = {
        'audio': {
            'type': 'string',
            'description': 'Local path to the audio file to transcribe.',
        },
        'with_time_markers': {
            'type': 'boolean',
            'description': 'Include timestamps in output.',
            'nullable': True,
            'default': False,
        },
    }
    output_type = 'string'
    chunk_length_s = 30  # chunk length (seconds) passed to the pipeline for inference
    model_id = 'openai/whisper-large-v3-turbo'  # Hub checkpoint loaded by _load_pipeline
    pipe = None  # shared ASR pipeline; set once per class by __new__

    def __new__(cls, *args, **kwargs):
        # Load eagerly (so `pipe` is ready right after construction, as before),
        # but only once per class. vars(cls) is checked instead of cls.pipe so a
        # subclass with a different model_id does not reuse its parent's pipeline.
        if vars(cls).get('pipe') is None:
            cls.pipe = cls._load_pipeline()
        # Do not forward *args/**kwargs: object.__new__ raises TypeError on extra
        # arguments when __new__ is overridden. __init__ still receives them.
        return super().__new__(cls)

    @classmethod
    def _load_pipeline(cls):
        """Load ``cls.model_id`` and wrap it in an automatic-speech-recognition pipeline."""
        use_cuda = torch.cuda.is_available()
        device = 'cuda:0' if use_cuda else 'cpu'
        # Half precision on the GPU, full precision on CPU.
        dtype = torch.float16 if use_cuda else torch.float32
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            cls.model_id,
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
            use_safetensors=True,
        ).to(device)
        processor = AutoProcessor.from_pretrained(cls.model_id)
        # Process-wide side effects (kept from the original): silence transformers'
        # logging below error level and ignore every FutureWarning.
        logging.set_verbosity_error()
        warnings.filterwarnings("ignore", category=FutureWarning)
        return pipeline(
            task='automatic-speech-recognition',
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            torch_dtype=dtype,
            device=device,
            chunk_length_s=cls.chunk_length_s,
            return_timestamps=True,
        )

    def forward(self, audio: str, with_time_markers: bool = False) -> str:
        """
        Run speech recognition on the input audio file.

        Args:
            audio (str): Path to a local audio file (e.g. .wav or .mp3).
            with_time_markers (bool): If true, return one entry per non-empty
                segment, each written as three lines: ``[start]``, the text,
                ``[end]`` (times in seconds, two decimals). A false or ``None``
                value returns the plain transcript.

        Returns:
            str: The stripped transcript, or the newline-joined timed segments.
        """
        result = self.pipe(audio)
        if not with_time_markers:
            return result['text'].strip()
        lines = [
            f"[{ch['start']:.2f}]\n{ch['text']}\n[{ch['end']:.2f}]"
            for ch in self._normalize_chunks(result['chunks'])
        ]
        return "\n".join(lines).strip()

    def _normalize_chunks(self, chunks):
        """Convert pipeline segments to cumulative times and drop empty-text segments.

        This assumes timestamps can restart near zero for each
        ``chunk_length_s`` window. Whenever a timestamp is smaller than the one
        before it, one window length is added to a running offset.

        A timestamp of ``None`` (e.g. a final segment with no closing timestamp)
        is replaced by the nearest known time, so the result is always numeric.

        Args:
            chunks: The pipeline's ``result['chunks']``: dicts with a
                ``'timestamp'`` (start, end) pair and a ``'text'`` string.

        Returns:
            list[dict]: ``{'start': float, 'end': float, 'text': str}`` entries
            with stripped, non-empty text.
        """
        offset = 0.0
        prev = 0.0  # last raw (window-relative) timestamp seen
        normalized = []
        for chunk in chunks:
            ts_start, ts_end = chunk['timestamp']
            if ts_start is None:
                ts_start = prev
            if ts_end is None:
                ts_end = ts_start
            # A start earlier than the previous end means a new window began.
            if ts_start < prev:
                offset += self.chunk_length_s
            start = offset + ts_start
            # An end earlier than its own start means the segment crossed a window.
            if ts_end < ts_start:
                offset += self.chunk_length_s
            end = offset + ts_end
            prev = ts_end
            text = chunk['text'].strip()
            if text:
                normalized.append({'start': start, 'end': end, 'text': text})
        return normalized