# Notulen_Otomatis / src / transcriber_speechbrain.py
"""
SpeechBrain ASR wrapper (optional)
Provides a lightweight adapter around SpeechBrain's EncoderASR/EncoderDecoderASR to be used
as an optional backend in `meeting_transcriber`.
"""
from __future__ import annotations

from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, List, Optional

import numpy as np
import torch

from src.diarization import SpeakerSegment
from src.transcriber import TranscriptSegment
@dataclass
class SpeechBrainASRConfig:
model_id: str = "speechbrain/asr-crdnn-rnnlm-librispeech"
device: str = "cuda" if torch.cuda.is_available() else "cpu"
chunk_length_s: float = 30.0
class SpeechBrainTranscriber:
"""Adapter for SpeechBrain ASR models.
Usage:
t = SpeechBrainTranscriber(config)
t.transcribe_segments(waveform, segments, sample_rate)
"""
def __init__(self, config: Optional[SpeechBrainASRConfig] = None, models_dir: str = "./models"):
self.config = config or SpeechBrainASRConfig()
self.models_dir = Path(models_dir)
self.models_dir.mkdir(parents=True, exist_ok=True)
self._model = None
def _load_model(self):
if self._model is not None:
return
try:
# Prefer the new import path to avoid deprecation warnings in SpeechBrain >=1.0
try:
from speechbrain.inference import ( # type: ignore
EncoderASR,
EncoderDecoderASR,
)
except Exception:
from speechbrain.pretrained import ( # type: ignore
EncoderASR,
EncoderDecoderASR,
)
# Try EncoderDecoderASR first (seq2seq), fall back to EncoderASR
try:
self._model = EncoderDecoderASR.from_hparams(
source=self.config.model_id, savedir=str(self.models_dir)
)
except Exception:
self._model = EncoderASR.from_hparams(
source=self.config.model_id, savedir=str(self.models_dir)
)
except Exception as e:
print(f"[SpeechBrain] Could not load model: {e}")
self._model = None
def transcribe_full_audio(self, waveform: torch.Tensor, sample_rate: int = 16000) -> str:
"""Transcribe full audio waveform. Returns post-processed text (raw)."""
self._load_model()
if self._model is None:
return ""
# SpeechBrain typically expects a file path for convenience; some models accept numpy arrays
try:
audio_np = waveform.squeeze().cpu().numpy()
# Many SpeechBrain models accept numpy arrays for `transcribe_batch`/`transcribe_file`
# Use transcribe_batch for in-memory audio
try:
res = self._model.transcribe_batch([audio_np])
if isinstance(res, list):
return str(res[0])
return str(res)
except Exception:
# Fallback: write temporary file
import tempfile
import soundfile as sf
with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as tmp:
sf.write(tmp.name, audio_np.astype("float32"), sample_rate)
return str(self._model.transcribe_file(tmp.name))
except Exception as e:
print(f"[SpeechBrain] Full audio transcription failed: {e}")
return ""
def transcribe_segments(
self, waveform: torch.Tensor, segments: List[SpeakerSegment], sample_rate: int = 16000
) -> List[TranscriptSegment]:
"""Transcribe each segment and return list of TranscriptSegment objects."""
self._load_model()
transcripts: List[TranscriptSegment] = []
if self._model is None:
return transcripts
for seg in segments:
start = int(seg.start * sample_rate)
end = int(seg.end * sample_rate)
segment_np = waveform[:, start:end].squeeze().cpu().numpy()
if segment_np.size == 0:
continue
# Skip extremely short segments
if seg.end - seg.start < 0.2:
continue
try:
# prefer in-memory transcribe_batch
res = self._model.transcribe_batch([segment_np])
text = str(res[0]) if isinstance(res, list) else str(res)
except Exception:
# fallback to temporary file path
try:
import tempfile
import soundfile as sf
with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as tmp:
sf.write(tmp.name, segment_np.astype("float32"), sample_rate)
text = str(self._model.transcribe_file(tmp.name))
except Exception as e:
print(f"[SpeechBrain] Segment transcription failed: {e}")
text = ""
if not text or not text.strip():
continue
transcripts.append(
TranscriptSegment(
speaker_id=seg.speaker_id,
start=seg.start,
end=seg.end,
text=text.strip(),
confidence=getattr(seg, "confidence", 1.0),
is_overlap=getattr(seg, "is_overlap", False),
)
)
return transcripts