Spaces:
Sleeping
Sleeping
| from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline, logging | |
| import torch | |
| import warnings | |
| from tools.base_tool import BaseTool | |
class SpeechRecognitionTool(BaseTool):
    """Tool that transcribes speech from an audio file with a Whisper pipeline.

    The model, processor and ASR pipeline are built once in ``__init__`` and
    reused by every ``transcribe`` call.
    """

    name = 'speech_to_text'
    description = 'Transcribes speech from audio input.'

    # Length in seconds of the windows the pipeline splits audio into. The
    # timestamp normalisation in ``_parse_timed_chunks`` must use the same
    # value, so it is defined once here.
    _CHUNK_LENGTH_S = 30

    def __init__(self):
        """Load the Whisper model and build the speech-recognition pipeline.

        Uses CUDA with float16 when a GPU is available, otherwise CPU with
        float32.
        """
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        dtype = torch.float16 if device == 'cuda' else torch.float32
        model_id = 'openai/whisper-large-v3-turbo'

        self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id,
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
            use_safetensors=True,
        ).to(device)
        self.processor = AutoProcessor.from_pretrained(model_id)

        # Silence transformers' own log output and FutureWarning noise before
        # the pipeline is constructed.
        logging.set_verbosity_error()
        warnings.filterwarnings("ignore", category=FutureWarning)

        self.pipeline = pipeline(
            "automatic-speech-recognition",
            model=self.model,
            tokenizer=self.processor.tokenizer,
            feature_extractor=self.processor.feature_extractor,
            torch_dtype=dtype,
            device=device,
            chunk_length_s=self._CHUNK_LENGTH_S,
            return_timestamps=True,
        )

    def transcribe(self, audio_path: str, with_timestamps: bool = False) -> str:
        """Transcribe the audio at ``audio_path``.

        Args:
            audio_path: Value handed to the ASR pipeline unchanged.
            with_timestamps: When true, return one
                ``[start]`` / text / ``[end]`` group per non-empty chunk
                (times printed with two decimals) instead of the plain text.

        Returns:
            The stripped transcript, or the stripped timestamped listing.
        """
        result = self.pipeline(audio_path)
        if not with_timestamps:
            return result['text'].strip()

        formatted = "".join(
            f"[{chunk['start']:.2f}]\n{chunk['text']}\n[{chunk['end']:.2f}]\n"
            for chunk in self._parse_timed_chunks(result['chunks'])
        )
        return formatted.strip()

    def _parse_timed_chunks(self, chunks):
        """Convert pipeline chunks into dicts with monotonically shifted times.

        Each input chunk is expected to carry a ``'timestamp'`` pair
        ``(start, end)`` and a ``'text'`` string. Whenever a timestamp is
        smaller than the one seen just before it, the times are treated as
        having restarted in a new window and everything from that point on is
        shifted by one more window length.

        A missing end time (``None``) is replaced by the chunk's start time so
        the chunk is still reported instead of raising ``TypeError``.
        NOTE(review): Whisper is understood to emit ``None`` when it predicts
        no closing timestamp (e.g. audio cut mid-word) - confirm against the
        installed transformers version.

        Args:
            chunks: Iterable of chunk dicts as found in the pipeline result.

        Returns:
            A list of ``{"start", "end", "text"}`` dicts; chunks whose text is
            empty after stripping are omitted.
        """
        window_base = 0.0   # total shift applied to the current window
        previous_end = 0.0  # raw end time of the chunk processed last
        normalized = []

        for chunk in chunks:
            start, end = chunk['timestamp']
            if end is None:
                # No closing timestamp available: collapse to a zero-length
                # span rather than failing on the comparisons below.
                end = start

            # Start earlier than the previous end: times restarted.
            if start < previous_end:
                window_base += self._CHUNK_LENGTH_S
            start_time = window_base + start

            # End earlier than start: the restart happened inside this chunk.
            if end < start:
                window_base += self._CHUNK_LENGTH_S
            end_time = window_base + end
            previous_end = end

            text = chunk['text'].strip()
            if text:
                normalized.append(
                    {"start": start_time, "end": end_time, "text": text}
                )
        return normalized