# Upstream commit 0b2d478 ("RAG, language updates") — stray web-page header removed.
"""Speech-to-text using Azure Speech SDK."""
from __future__ import annotations
import os
import tempfile
import threading
from pathlib import Path
import azure.cognitiveservices.speech as speechsdk
from dataclasses import dataclass
from ..core.config import get_settings
from ..core.errors import SpeechError
@dataclass
class STTResult:
    """Result of a speech-to-text recognition pass."""

    # Recognized text (whitespace-stripped by the callers that build this).
    transcript: str
    # Language tag for the transcript: the auto-detected source language when
    # available, otherwise the configured default; None when neither is known.
    language: str | None
class SpeechToTextService:
    """Transcribe audio using Azure Speech.

    Supports one-shot recognition of uploaded audio bytes (WAV, raw PCM, or
    compressed MP3/OGG-OPUS/AAC containers) via :meth:`transcribe`, and
    continuous recognition of 16 kHz/16-bit mono PCM via
    :meth:`start_streaming`.
    """

    def __init__(self) -> None:
        settings = get_settings()
        # Retained so start_streaming() can build a fresh SpeechConfig per
        # session instead of mutating the shared one (see start_streaming).
        self._settings = settings
        self._speech_config = speechsdk.SpeechConfig(
            subscription=settings.azure_speech_key,
            region=settings.azure_speech_region,
        )
        self._default_language = settings.azure_speech_stt_language
        # Comma-separated list of candidate languages for auto-detection.
        self._auto_languages = [
            lang.strip()
            for lang in settings.azure_speech_auto_languages.split(",")
            if lang.strip()
        ]
        if self._auto_languages:
            # Avoid conflicting config: do not set a fixed language when
            # auto-detecting; fall back to the first candidate as the default
            # reported language if none was configured explicitly.
            if not self._default_language:
                self._default_language = self._auto_languages[0]
        elif self._default_language:
            self._speech_config.speech_recognition_language = self._default_language

    def transcribe(
        self,
        audio_bytes: bytes,
        filename: str | None = None,
        content_type: str | None = None,
    ) -> STTResult:
        """Recognize speech from audio bytes and return the transcript.

        Args:
            audio_bytes: Raw audio payload.
            filename: Optional original filename; its suffix helps pick the
                container format.
            content_type: Optional MIME type; used for PCM parameters and as a
                fallback for container detection.

        Returns:
            STTResult with the transcript and detected/default language.

        Raises:
            SpeechError: On unsupported formats, empty transcripts, no match,
                cancellation, or invalid audio.
        """
        audio_config, temp_path = self._build_audio_config(
            audio_bytes, filename, content_type
        )
        auto_config = self._auto_detect_config()
        try:
            recognizer = speechsdk.SpeechRecognizer(
                speech_config=self._speech_config,
                audio_config=audio_config,
                auto_detect_source_language_config=auto_config,
            )
            # Single-utterance recognition: stops at the first end of speech.
            result = recognizer.recognize_once()
            if result.reason == speechsdk.ResultReason.RecognizedSpeech:
                transcript = result.text.strip()
                if not transcript:
                    raise SpeechError(code="stt_empty", message="Empty transcript.")
                detected = _detected_language_from_result(result)
                return STTResult(
                    transcript=transcript,
                    language=detected or self._default_language,
                )
            if result.reason == speechsdk.ResultReason.NoMatch:
                raise SpeechError(code="stt_no_match", message="No speech recognized.")
            if result.reason == speechsdk.ResultReason.Canceled:
                cancellation = _cancellation_details(result)
                raise SpeechError(
                    code="stt_canceled",
                    message="Speech recognition canceled.",
                    details={
                        "reason": str(cancellation.reason),
                        "error_details": cancellation.error_details,
                    },
                )
            raise SpeechError(code="stt_failed", message="Speech recognition failed.")
        except RuntimeError as exc:
            # The SDK raises RuntimeError for malformed/undecodable input.
            raise SpeechError(
                code="stt_invalid_audio",
                message="Invalid or unsupported audio format.",
            ) from exc
        finally:
            # WAV input goes through a temp file; remove it best-effort.
            if temp_path and os.path.exists(temp_path):
                try:
                    os.unlink(temp_path)
                except OSError:
                    pass

    def start_streaming(
        self, end_silence_ms: int = 1200, initial_silence_ms: int = 5000
    ) -> "StreamingSTTSession":
        """Start a streaming STT session for 16 kHz/16-bit mono PCM audio.

        A fresh SpeechConfig is built for each session because the session
        sets per-session silence-timeout properties on the config it is
        given; reusing the shared config would leak those timeouts into
        later transcribe() calls and other sessions.

        Args:
            end_silence_ms: Silence that ends an utterance.
            initial_silence_ms: Max leading silence before giving up.
        """
        session_config = speechsdk.SpeechConfig(
            subscription=self._settings.azure_speech_key,
            region=self._settings.azure_speech_region,
        )
        # Mirror __init__: fixed language only when not auto-detecting.
        if not self._auto_languages and self._default_language:
            session_config.speech_recognition_language = self._default_language
        return StreamingSTTSession(
            session_config,
            self._auto_languages,
            self._default_language,
            end_silence_ms=end_silence_ms,
            initial_silence_ms=initial_silence_ms,
        )

    def _auto_detect_config(self) -> speechsdk.AutoDetectSourceLanguageConfig | None:
        """Build the auto-detect config, or None when no candidates are set."""
        if not self._auto_languages:
            return None
        return speechsdk.languageconfig.AutoDetectSourceLanguageConfig(
            languages=self._auto_languages
        )

    def _build_audio_config(
        self,
        audio_bytes: bytes,
        filename: str | None,
        content_type: str | None,
    ) -> tuple[speechsdk.audio.AudioConfig, str]:
        """Build an AudioConfig for the payload.

        Returns (audio_config, temp_path); temp_path is non-empty only for
        the WAV temp-file path and must be deleted by the caller.
        """
        # Raw PCM: feed directly through a push stream with explicit format.
        pcm_format = self._pcm_format_from_content_type(content_type)
        if pcm_format is not None:
            stream_format = speechsdk.audio.AudioStreamFormat(**pcm_format)
            push_stream = speechsdk.audio.PushAudioInputStream(stream_format)
            push_stream.write(audio_bytes)
            push_stream.close()
            return speechsdk.audio.AudioConfig(stream=push_stream), ""
        # WAV: the SDK reads the header itself, but only from a file.
        suffix = Path(filename or "").suffix.lower()
        if suffix in {".wav", ".wave"} or self._is_wav_content_type(content_type):
            temp_path = ""
            with tempfile.NamedTemporaryFile(
                delete=False, suffix=suffix or ".wav"
            ) as temp:
                temp.write(audio_bytes)
                temp_path = temp.name
            return speechsdk.audio.AudioConfig(filename=temp_path), temp_path
        # Compressed containers: suffix takes precedence over content type.
        container = self._container_format_for_suffix(suffix)
        if container is None:
            container = self._container_format_for_content_type(content_type)
        if container is None:
            raise SpeechError(
                code="stt_unsupported_format",
                message=(
                    "Unsupported audio format. Use WAV, MP3, OGG/OPUS, or AAC/M4A."
                ),
            )
        stream_format = speechsdk.audio.AudioStreamFormat(
            compressed_stream_format=container
        )
        push_stream = speechsdk.audio.PushAudioInputStream(stream_format)
        push_stream.write(audio_bytes)
        push_stream.close()
        return speechsdk.audio.AudioConfig(stream=push_stream), ""

    def _container_format_for_suffix(
        self, suffix: str
    ) -> speechsdk.AudioStreamContainerFormat | None:
        """Map a file suffix to an SDK container format, if supported."""
        mapping = {
            ".mp3": "MP3",
            ".ogg": "OGG_OPUS",
            ".opus": "OGG_OPUS",
            ".aac": "AAC",
            ".m4a": "AAC",
        }
        name = mapping.get(suffix)
        if not name:
            return None
        # getattr guards against older SDKs missing a given enum member.
        return getattr(speechsdk.AudioStreamContainerFormat, name, None)

    def _container_format_for_content_type(
        self, content_type: str | None
    ) -> speechsdk.AudioStreamContainerFormat | None:
        """Map a MIME type to an SDK container format, if supported."""
        if not content_type:
            return None
        # Drop parameters (e.g. "; codecs=...") and normalize case.
        normalized = content_type.split(";")[0].strip().lower()
        mapping = {
            "audio/mpeg": "MP3",
            "audio/mp3": "MP3",
            "audio/ogg": "OGG_OPUS",
            "audio/opus": "OGG_OPUS",
            "audio/aac": "AAC",
            "audio/mp4": "AAC",
            "audio/x-m4a": "AAC",
            "audio/m4a": "AAC",
            # WAV is handled via the temp-file path, not a compressed stream.
            "audio/wav": None,
            "audio/wave": None,
            "audio/x-wav": None,
        }
        name = mapping.get(normalized)
        if not name:
            return None
        return getattr(speechsdk.AudioStreamContainerFormat, name, None)

    def _is_wav_content_type(self, content_type: str | None) -> bool:
        """Return True when the MIME type denotes WAV audio."""
        if not content_type:
            return False
        normalized = content_type.split(";")[0].strip().lower()
        return normalized in {"audio/wav", "audio/wave", "audio/x-wav"}

    def _pcm_format_from_content_type(
        self, content_type: str | None
    ) -> dict[str, int] | None:
        """Parse raw-PCM parameters from a MIME type.

        Accepts "audio/pcm" or "audio/raw" with optional rate/bits/channels
        parameters (e.g. "audio/pcm;rate=16000;bits=16;channels=1").
        Returns AudioStreamFormat kwargs, or None for non-PCM types.
        """
        if not content_type:
            return None
        parts = [p.strip().lower() for p in content_type.split(";")]
        if not parts or parts[0] not in {"audio/pcm", "audio/raw"}:
            return None

        def _get_int_param(key: str, default: int) -> int:
            # First matching parameter wins; malformed values fall back.
            for part in parts[1:]:
                if part.startswith(f"{key}="):
                    try:
                        return int(part.split("=", 1)[1])
                    except ValueError:
                        return default
            return default

        return {
            "samples_per_second": _get_int_param("rate", 16000),
            "bits_per_sample": _get_int_param("bits", 16),
            "channels": _get_int_param("channels", 1),
        }
class StreamingSTTSession:
    """Streaming STT session using PushAudioInputStream and ContinuousRecognizer.

    Lifecycle: construct (recognition starts immediately), call write() with
    16 kHz/16-bit mono PCM chunks, then finish() to close the stream and
    collect the final transcript. Event callbacks run on SDK threads and
    communicate with finish() through a threading.Event.
    """

    def __init__(
        self,
        speech_config: speechsdk.SpeechConfig,
        auto_languages: list[str],
        default_language: str,
        end_silence_ms: int,
        initial_silence_ms: int,
    ) -> None:
        # Finalized utterance texts, in recognition order.
        self._texts: list[str] = []
        # Latest interim hypothesis; used only if nothing was finalized.
        self._partial: str | None = None
        # Last auto-detected language seen in any event.
        self._language: str | None = None
        self._default_language = default_language
        # Set when the session ends (stopped, canceled, or end of stream).
        self._done = threading.Event()
        self._error: SpeechError | None = None
        self._auto_languages = auto_languages
        # NOTE(review): these set_property calls mutate the caller-supplied
        # config in place — callers should pass a per-session config, not a
        # shared one, or the timeouts leak into other recognizers.
        speech_config.set_property(
            speechsdk.PropertyId.SpeechServiceConnection_EndSilenceTimeoutMs,
            str(end_silence_ms),
        )
        speech_config.set_property(
            speechsdk.PropertyId.SpeechServiceConnection_InitialSilenceTimeoutMs,
            str(initial_silence_ms),
        )
        # Fixed input format: 16 kHz, 16-bit, mono PCM.
        stream_format = speechsdk.audio.AudioStreamFormat(
            samples_per_second=16000, bits_per_sample=16, channels=1
        )
        self._push_stream = speechsdk.audio.PushAudioInputStream(stream_format)
        self._audio_config = speechsdk.audio.AudioConfig(stream=self._push_stream)
        auto_config = self._auto_detect_config()
        self._recognizer = speechsdk.SpeechRecognizer(
            speech_config=speech_config,
            audio_config=self._audio_config,
            auto_detect_source_language_config=auto_config,
        )
        # Wire handlers before starting so no early events are missed.
        self._recognizer.recognized.connect(self._on_recognized)
        self._recognizer.recognizing.connect(self._on_recognizing)
        self._recognizer.canceled.connect(self._on_canceled)
        self._recognizer.session_stopped.connect(self._on_session_stopped)
        self._recognizer.start_continuous_recognition()

    def write(self, audio_bytes: bytes) -> None:
        """Feed a chunk of raw PCM audio into the recognizer."""
        self._push_stream.write(audio_bytes)

    def finish(self, timeout: float = 10.0) -> STTResult:
        """Close the audio stream and wait for the final transcript.

        Raises SpeechError on timeout, cancellation, or empty transcript.
        """
        self._push_stream.close()
        completed = self._done.wait(timeout)
        if not completed:
            self._recognizer.stop_continuous_recognition()
            raise SpeechError(code="stt_timeout", message="STT timed out.")
        if self._error:
            raise self._error
        transcript = " ".join(self._texts).strip()
        # Fall back to the last interim hypothesis if nothing was finalized.
        if not transcript and self._partial:
            transcript = self._partial.strip()
        if not transcript:
            raise SpeechError(code="stt_empty", message="Empty transcript.")
        return STTResult(
            transcript=transcript,
            language=self._language or self._default_language,
        )

    def _on_recognized(self, evt: speechsdk.SpeechRecognitionEventArgs) -> None:
        """Collect a finalized utterance and its detected language."""
        if evt.result.reason == speechsdk.ResultReason.RecognizedSpeech:
            text = evt.result.text.strip()
            if text:
                self._texts.append(text)
            lang = _detected_language_from_result(evt.result)
            if lang:
                self._language = lang

    def _on_recognizing(self, evt: speechsdk.SpeechRecognitionEventArgs) -> None:
        """Track the latest interim hypothesis (overwrites the previous one)."""
        if evt.result.reason == speechsdk.ResultReason.RecognizingSpeech:
            text = evt.result.text.strip()
            if text:
                self._partial = text
            lang = _detected_language_from_result(evt.result)
            if lang:
                self._language = lang

    def _on_canceled(self, evt: speechsdk.SpeechRecognitionCanceledEventArgs) -> None:
        """Record cancellation; EndOfStream is the normal way a session ends."""
        cancellation = _cancellation_details(evt.result)
        if cancellation.reason == speechsdk.CancellationReason.EndOfStream:
            self._done.set()
            return
        self._error = SpeechError(
            code="stt_canceled",
            message="Speech recognition canceled.",
            details={
                "reason": str(cancellation.reason),
                "error_details": cancellation.error_details,
            },
        )
        self._done.set()

    def _on_session_stopped(self, evt: speechsdk.SessionEventArgs) -> None:
        """Unblock finish() when the SDK reports the session has stopped."""
        self._done.set()

    def _auto_detect_config(self) -> speechsdk.AutoDetectSourceLanguageConfig | None:
        """Build the auto-detect config, or None when no candidates are set."""
        if not self._auto_languages:
            return None
        return speechsdk.languageconfig.AutoDetectSourceLanguageConfig(
            languages=self._auto_languages
        )
def _detected_language_from_result(
    result: speechsdk.SpeechRecognitionResult,
) -> str | None:
    """Extract the auto-detected source language from a recognition result.

    Prefers the typed AutoDetectSourceLanguageResult helper when the SDK
    provides it, then falls back to the raw result property; returns None
    when no language information is available.
    """
    detector_cls = getattr(speechsdk, "AutoDetectSourceLanguageResult", None)
    if detector_cls is not None:
        try:
            detected = detector_cls.from_result(result)
        except Exception:
            # Best-effort across SDK versions; fall through to the raw property.
            pass
        else:
            return getattr(detected, "language", None)
    try:
        raw = result.properties.get(
            speechsdk.PropertyId.SpeechServiceConnection_AutoDetectSourceLanguageResult
        )
    except Exception:
        return None
    return str(raw) if raw else None
def _cancellation_details(result: speechsdk.SpeechRecognitionResult) -> speechsdk.CancellationDetails:
    """Compatibility wrapper for SDKs without CancellationDetails.from_result."""
    factory = getattr(speechsdk.CancellationDetails, "from_result", None)
    if factory is None:
        # Older SDKs take the result directly in the constructor.
        return speechsdk.CancellationDetails(result)
    return factory(result)  # type: ignore[no-any-return]