Spaces:
Sleeping
Sleeping
| """Speech-to-text using Azure Speech SDK.""" | |
| from __future__ import annotations | |
| import os | |
| import tempfile | |
| import threading | |
| from pathlib import Path | |
| import azure.cognitiveservices.speech as speechsdk | |
| from dataclasses import dataclass | |
| from ..core.config import get_settings | |
| from ..core.errors import SpeechError | |
@dataclass
class STTResult:
    """Result of a speech-to-text transcription."""

    # Recognized text (callers strip whitespace before constructing this).
    transcript: str
    # BCP-47 tag of the detected or configured language; None when unknown.
    language: str | None
class SpeechToTextService:
    """Transcribe audio bytes using Azure Speech."""

    def __init__(self) -> None:
        settings = get_settings()
        self._speech_config = speechsdk.SpeechConfig(
            subscription=settings.azure_speech_key,
            region=settings.azure_speech_region,
        )
        self._default_language = settings.azure_speech_stt_language
        self._auto_languages = [
            candidate.strip()
            for candidate in settings.azure_speech_auto_languages.split(",")
            if candidate.strip()
        ]
        if not self._auto_languages:
            if self._default_language:
                self._speech_config.speech_recognition_language = (
                    self._default_language
                )
            return
        # Auto-detection is active: leave the fixed recognition language unset
        # (setting both would conflict), but keep a fallback default so every
        # result can still report some language.
        if not self._default_language:
            self._default_language = self._auto_languages[0]

    def transcribe(
        self,
        audio_bytes: bytes,
        filename: str | None = None,
        content_type: str | None = None,
    ) -> STTResult:
        """Recognize speech from audio bytes and return transcript."""
        audio_config, temp_path = self._build_audio_config(
            audio_bytes, filename, content_type
        )
        try:
            recognizer = speechsdk.SpeechRecognizer(
                speech_config=self._speech_config,
                audio_config=audio_config,
                auto_detect_source_language_config=self._auto_detect_config(),
            )
            outcome = recognizer.recognize_once()
            reason = outcome.reason
            if reason == speechsdk.ResultReason.RecognizedSpeech:
                text = outcome.text.strip()
                if not text:
                    raise SpeechError(code="stt_empty", message="Empty transcript.")
                detected = _detected_language_from_result(outcome)
                return STTResult(
                    transcript=text,
                    language=detected or self._default_language,
                )
            if reason == speechsdk.ResultReason.NoMatch:
                raise SpeechError(code="stt_no_match", message="No speech recognized.")
            if reason == speechsdk.ResultReason.Canceled:
                details = _cancellation_details(outcome)
                raise SpeechError(
                    code="stt_canceled",
                    message="Speech recognition canceled.",
                    details={
                        "reason": str(details.reason),
                        "error_details": details.error_details,
                    },
                )
            raise SpeechError(code="stt_failed", message="Speech recognition failed.")
        except RuntimeError as exc:
            # The SDK raises RuntimeError for malformed/undecodable input.
            raise SpeechError(
                code="stt_invalid_audio",
                message="Invalid or unsupported audio format.",
            ) from exc
        finally:
            # Best-effort cleanup of the temp WAV file, if one was written.
            if temp_path and os.path.exists(temp_path):
                try:
                    os.unlink(temp_path)
                except OSError:
                    pass

    def start_streaming(
        self, end_silence_ms: int = 1200, initial_silence_ms: int = 5000
    ) -> "StreamingSTTSession":
        """Start a streaming STT session for PCM audio."""
        return StreamingSTTSession(
            self._speech_config,
            self._auto_languages,
            self._default_language,
            end_silence_ms=end_silence_ms,
            initial_silence_ms=initial_silence_ms,
        )

    def _auto_detect_config(self) -> speechsdk.AutoDetectSourceLanguageConfig | None:
        # None (rather than an empty config) when auto-detection is disabled.
        if self._auto_languages:
            return speechsdk.languageconfig.AutoDetectSourceLanguageConfig(
                languages=self._auto_languages
            )
        return None

    def _build_audio_config(
        self,
        audio_bytes: bytes,
        filename: str | None,
        content_type: str | None,
    ) -> tuple[speechsdk.audio.AudioConfig, str]:
        """Build an AudioConfig for the payload.

        Returns (config, temp_path); temp_path is "" unless a temp WAV file
        was written, in which case the caller must delete it.
        """

        def pushed(stream_format: speechsdk.audio.AudioStreamFormat):
            # Feed the whole payload through a push stream and close it so
            # the recognizer sees end-of-stream.
            stream = speechsdk.audio.PushAudioInputStream(stream_format)
            stream.write(audio_bytes)
            stream.close()
            return speechsdk.audio.AudioConfig(stream=stream), ""

        pcm = self._pcm_format_from_content_type(content_type)
        if pcm is not None:
            return pushed(speechsdk.audio.AudioStreamFormat(**pcm))

        suffix = Path(filename or "").suffix.lower()
        if suffix in {".wav", ".wave"} or self._is_wav_content_type(content_type):
            # WAV goes through a temp file so the SDK can parse the header.
            with tempfile.NamedTemporaryFile(
                delete=False, suffix=suffix or ".wav"
            ) as handle:
                handle.write(audio_bytes)
                wav_path = handle.name
            return speechsdk.audio.AudioConfig(filename=wav_path), wav_path

        container = self._container_format_for_suffix(suffix)
        if container is None:
            container = self._container_format_for_content_type(content_type)
        if container is None:
            raise SpeechError(
                code="stt_unsupported_format",
                message=(
                    "Unsupported audio format. Use WAV, MP3, OGG/OPUS, or AAC/M4A."
                ),
            )
        return pushed(
            speechsdk.audio.AudioStreamFormat(compressed_stream_format=container)
        )

    def _container_format_for_suffix(
        self, suffix: str
    ) -> speechsdk.AudioStreamContainerFormat | None:
        by_suffix = {
            ".mp3": "MP3",
            ".ogg": "OGG_OPUS",
            ".opus": "OGG_OPUS",
            ".aac": "AAC",
            ".m4a": "AAC",
        }
        attr = by_suffix.get(suffix)
        if not attr:
            return None
        # getattr tolerates older SDKs missing a given container constant.
        return getattr(speechsdk.AudioStreamContainerFormat, attr, None)

    def _container_format_for_content_type(
        self, content_type: str | None
    ) -> speechsdk.AudioStreamContainerFormat | None:
        if not content_type:
            return None
        media_type = content_type.split(";")[0].strip().lower()
        by_type = {
            "audio/mpeg": "MP3",
            "audio/mp3": "MP3",
            "audio/ogg": "OGG_OPUS",
            "audio/opus": "OGG_OPUS",
            "audio/aac": "AAC",
            "audio/mp4": "AAC",
            "audio/x-m4a": "AAC",
            "audio/m4a": "AAC",
            # WAV types are handled via the temp-file path, not a container.
            "audio/wav": None,
            "audio/wave": None,
            "audio/x-wav": None,
        }
        attr = by_type.get(media_type)
        if not attr:
            return None
        return getattr(speechsdk.AudioStreamContainerFormat, attr, None)

    def _is_wav_content_type(self, content_type: str | None) -> bool:
        if not content_type:
            return False
        media_type = content_type.split(";")[0].strip().lower()
        return media_type in {"audio/wav", "audio/wave", "audio/x-wav"}

    def _pcm_format_from_content_type(
        self, content_type: str | None
    ) -> dict[str, int] | None:
        """Parse raw-PCM parameters (e.g. 'audio/pcm;rate=16000;bits=16').

        Returns kwargs for AudioStreamFormat, or None for non-PCM types.
        """
        if not content_type:
            return None
        fields = [piece.strip().lower() for piece in content_type.split(";")]
        if not fields or fields[0] not in {"audio/pcm", "audio/raw"}:
            return None

        def read_int(name: str, fallback: int) -> int:
            # First matching parameter wins; unparsable values fall back.
            prefix = f"{name}="
            for piece in fields[1:]:
                if piece.startswith(prefix):
                    try:
                        return int(piece[len(prefix):])
                    except ValueError:
                        return fallback
            return fallback

        return {
            "samples_per_second": read_int("rate", 16000),
            "bits_per_sample": read_int("bits", 16),
            "channels": read_int("channels", 1),
        }
class StreamingSTTSession:
    """Streaming STT session using PushAudioInputStream and ContinuousRecognizer.

    Audio is fed incrementally via write() as PCM (16 kHz, 16-bit, mono —
    fixed by the stream format below); finish() closes the stream and returns
    the accumulated transcript.
    """

    def __init__(
        self,
        speech_config: speechsdk.SpeechConfig,
        auto_languages: list[str],
        default_language: str | None,  # widened: the service may pass None
        end_silence_ms: int,
        initial_silence_ms: int,
    ) -> None:
        self._texts: list[str] = []        # finalized utterances, in order
        self._partial: str | None = None   # latest interim hypothesis
        self._language: str | None = None  # detected language, if reported
        self._default_language = default_language
        self._done = threading.Event()     # set when recognition ends
        self._error: SpeechError | None = None
        self._auto_languages = auto_languages
        # NOTE(review): these properties mutate the *shared* SpeechConfig
        # owned by the service, so silence timeouts leak into later sessions
        # and one-shot transcriptions — confirm whether a per-session copy of
        # the config is needed.
        speech_config.set_property(
            speechsdk.PropertyId.SpeechServiceConnection_EndSilenceTimeoutMs,
            str(end_silence_ms),
        )
        speech_config.set_property(
            speechsdk.PropertyId.SpeechServiceConnection_InitialSilenceTimeoutMs,
            str(initial_silence_ms),
        )
        stream_format = speechsdk.audio.AudioStreamFormat(
            samples_per_second=16000, bits_per_sample=16, channels=1
        )
        self._push_stream = speechsdk.audio.PushAudioInputStream(stream_format)
        self._audio_config = speechsdk.audio.AudioConfig(stream=self._push_stream)
        auto_config = self._auto_detect_config()
        self._recognizer = speechsdk.SpeechRecognizer(
            speech_config=speech_config,
            audio_config=self._audio_config,
            auto_detect_source_language_config=auto_config,
        )
        self._recognizer.recognized.connect(self._on_recognized)
        self._recognizer.recognizing.connect(self._on_recognizing)
        self._recognizer.canceled.connect(self._on_canceled)
        self._recognizer.session_stopped.connect(self._on_session_stopped)
        self._recognizer.start_continuous_recognition()

    def write(self, audio_bytes: bytes) -> None:
        """Append a chunk of raw PCM audio to the recognition stream."""
        self._push_stream.write(audio_bytes)

    def finish(self, timeout: float = 10.0) -> STTResult:
        """Close the input stream, wait for recognition to end, return result.

        Raises:
            SpeechError: on timeout ("stt_timeout"), cancellation
                ("stt_canceled"), or an empty transcript ("stt_empty").
        """
        self._push_stream.close()
        completed = self._done.wait(timeout)
        # Fix: always stop continuous recognition so the SDK releases its
        # connection and worker thread; previously only the timeout path
        # stopped the recognizer.
        self._recognizer.stop_continuous_recognition()
        if not completed:
            raise SpeechError(code="stt_timeout", message="STT timed out.")
        if self._error:
            raise self._error
        transcript = " ".join(self._texts).strip()
        if not transcript and self._partial:
            # Fall back to the last interim hypothesis when nothing was
            # finalized (e.g. the stream ended mid-phrase).
            transcript = self._partial.strip()
        if not transcript:
            raise SpeechError(code="stt_empty", message="Empty transcript.")
        return STTResult(
            transcript=transcript,
            language=self._language or self._default_language,
        )

    def _on_recognized(self, evt: speechsdk.SpeechRecognitionEventArgs) -> None:
        # A phrase was finalized: record its text and any detected language.
        if evt.result.reason == speechsdk.ResultReason.RecognizedSpeech:
            text = evt.result.text.strip()
            if text:
                self._texts.append(text)
            lang = _detected_language_from_result(evt.result)
            if lang:
                self._language = lang

    def _on_recognizing(self, evt: speechsdk.SpeechRecognitionEventArgs) -> None:
        # Interim hypothesis: keep only the most recent one.
        if evt.result.reason == speechsdk.ResultReason.RecognizingSpeech:
            text = evt.result.text.strip()
            if text:
                self._partial = text
            lang = _detected_language_from_result(evt.result)
            if lang:
                self._language = lang

    def _on_canceled(self, evt: speechsdk.SpeechRecognitionCanceledEventArgs) -> None:
        cancellation = _cancellation_details(evt.result)
        # EndOfStream is the normal termination after the push stream closes.
        if cancellation.reason == speechsdk.CancellationReason.EndOfStream:
            self._done.set()
            return
        self._error = SpeechError(
            code="stt_canceled",
            message="Speech recognition canceled.",
            details={
                "reason": str(cancellation.reason),
                "error_details": cancellation.error_details,
            },
        )
        self._done.set()

    def _on_session_stopped(self, evt: speechsdk.SessionEventArgs) -> None:
        # Session ended (with or without results): unblock finish().
        self._done.set()

    def _auto_detect_config(self) -> speechsdk.AutoDetectSourceLanguageConfig | None:
        # None (rather than an empty config) when auto-detection is disabled.
        if not self._auto_languages:
            return None
        return speechsdk.languageconfig.AutoDetectSourceLanguageConfig(
            languages=self._auto_languages
        )
def _detected_language_from_result(
    result: speechsdk.SpeechRecognitionResult,
) -> str | None:
    """Extract the auto-detected source language from a recognition result.

    Prefers the typed AutoDetectSourceLanguageResult helper when the SDK
    provides it; otherwise falls back to the raw result property. Returns
    None when no language can be determined.
    """
    helper = getattr(speechsdk, "AutoDetectSourceLanguageResult", None)
    if helper is not None:
        try:
            # Mirrors the original behavior: if the helper succeeds, its
            # language attribute is returned even when it is None.
            return getattr(helper.from_result(result), "language", None)
        except Exception:
            pass
    try:
        raw = result.properties.get(
            speechsdk.PropertyId.SpeechServiceConnection_AutoDetectSourceLanguageResult
        )
    except Exception:
        return None
    return str(raw) if raw else None
def _cancellation_details(result: speechsdk.SpeechRecognitionResult) -> speechsdk.CancellationDetails:
    """Compatibility wrapper for SDKs without CancellationDetails.from_result."""
    factory = getattr(speechsdk.CancellationDetails, "from_result", None)
    if factory is None:
        # Older SDKs construct the details directly from the result.
        return speechsdk.CancellationDetails(result)
    return factory(result)  # type: ignore[no-any-return]