"""
SpeechBrain ASR wrapper (optional)
Provides a lightweight adapter around SpeechBrain's EncoderASR/EncoderDecoderASR to be used
as an optional backend in `meeting_transcriber`.
"""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Any, List, Optional
import numpy as np
import torch
from src.diarization import SpeakerSegment
from src.transcriber import TranscriptSegment
@dataclass
class SpeechBrainASRConfig:
    """Configuration for the optional SpeechBrain ASR backend."""

    # Hub id (or local path) passed to ``from_hparams(source=...)``.
    model_id: str = "speechbrain/asr-crdnn-rnnlm-librispeech"
    # Evaluated once at class-definition (import) time: GPU if available, else CPU.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    # Chunk length in seconds. NOTE(review): not referenced anywhere in this
    # file — presumably consumed by the caller; confirm before removing.
    chunk_length_s: float = 30.0
class SpeechBrainTranscriber:
    """Adapter for SpeechBrain ASR models.

    Lazily loads an ``EncoderDecoderASR`` (falling back to ``EncoderASR``)
    and exposes two entry points: full-audio transcription and per-segment
    transcription aligned with diarization output. Every failure degrades to
    an empty result rather than raising, since this is an optional backend.

    Usage:
        t = SpeechBrainTranscriber(config)
        t.transcribe_segments(waveform, segments, sample_rate)
    """

    # Segments shorter than this (seconds) are skipped: they rarely contain a
    # complete word and tend to produce garbage hypotheses.
    MIN_SEGMENT_S = 0.2

    def __init__(self, config: Optional[SpeechBrainASRConfig] = None, models_dir: str = "./models"):
        """
        Args:
            config: Backend configuration; a default config is built when None.
            models_dir: Cache directory handed to ``from_hparams(savedir=...)``;
                created (with parents) if it does not exist.
        """
        self.config = config or SpeechBrainASRConfig()
        self.models_dir = Path(models_dir)
        self.models_dir.mkdir(parents=True, exist_ok=True)
        self._model = None  # loaded lazily on first transcription call

    def _load_model(self) -> None:
        """Load the SpeechBrain model once; on any failure leave ``_model`` as None."""
        if self._model is not None:
            return
        try:
            # Prefer the new import path to avoid deprecation warnings in SpeechBrain >=1.0
            try:
                from speechbrain.inference import (  # type: ignore
                    EncoderASR,
                    EncoderDecoderASR,
                )
            except Exception:
                from speechbrain.pretrained import (  # type: ignore
                    EncoderASR,
                    EncoderDecoderASR,
                )
            # Try EncoderDecoderASR first (seq2seq), fall back to EncoderASR
            try:
                self._model = EncoderDecoderASR.from_hparams(
                    source=self.config.model_id, savedir=str(self.models_dir)
                )
            except Exception:
                self._model = EncoderASR.from_hparams(
                    source=self.config.model_id, savedir=str(self.models_dir)
                )
        except Exception as e:
            print(f"[SpeechBrain] Could not load model: {e}")
            self._model = None

    def _transcribe_array(self, audio_np: np.ndarray, sample_rate: int) -> str:
        """Transcribe a mono float waveform with the loaded model.

        Shared by ``transcribe_full_audio`` and ``transcribe_segments`` (the
        two previously duplicated this logic). Tries the in-memory API first,
        then falls back to a temporary WAV file. Exceptions from the fallback
        propagate to the caller, which logs per-call-site messages.
        """
        try:
            # BUGFIX: ``transcribe_batch`` takes a (batch, time) tensor plus a
            # relative-length tensor — the old ``transcribe_batch([audio_np])``
            # call always raised and silently forced the slow temp-file path.
            wavs = torch.from_numpy(
                np.ascontiguousarray(audio_np, dtype=np.float32)
            ).unsqueeze(0)
            wav_lens = torch.tensor([1.0])
            # BUGFIX: the return value is a (words, tokens) pair whose first
            # element is a *list* of hypotheses; ``str(res[0])`` on the old
            # handling produced text like "['HELLO']".
            hyps, _tokens = self._model.transcribe_batch(wavs, wav_lens)
            return str(hyps[0])
        except Exception:
            # Fallback: round-trip through a temporary WAV file.
            import os
            import tempfile

            import soundfile as sf

            # NamedTemporaryFile(delete=True) cannot be reopened by name on
            # Windows, so manage the file's lifetime explicitly instead.
            fd, tmp_path = tempfile.mkstemp(suffix=".wav")
            os.close(fd)
            try:
                sf.write(tmp_path, audio_np.astype("float32"), sample_rate)
                return str(self._model.transcribe_file(tmp_path))
            finally:
                try:
                    os.unlink(tmp_path)
                except OSError:
                    pass

    def transcribe_full_audio(self, waveform: torch.Tensor, sample_rate: int = 16000) -> str:
        """Transcribe full audio waveform. Returns post-processed text (raw).

        Returns "" when the model cannot be loaded or transcription fails.
        """
        self._load_model()
        if self._model is None:
            return ""
        try:
            audio_np = waveform.squeeze().cpu().numpy()
            return self._transcribe_array(audio_np, sample_rate)
        except Exception as e:
            print(f"[SpeechBrain] Full audio transcription failed: {e}")
            return ""

    def transcribe_segments(
        self, waveform: torch.Tensor, segments: List[SpeakerSegment], sample_rate: int = 16000
    ) -> List[TranscriptSegment]:
        """Transcribe each segment and return list of TranscriptSegment objects.

        Empty/too-short segments and segments that yield no text are dropped.
        Returns an empty list when the model cannot be loaded.
        """
        self._load_model()
        transcripts: List[TranscriptSegment] = []
        if self._model is None:
            return transcripts
        for seg in segments:
            # Skip extremely short segments before doing any slicing work.
            if seg.end - seg.start < self.MIN_SEGMENT_S:
                continue
            start = int(seg.start * sample_rate)
            end = int(seg.end * sample_rate)
            # assumes waveform is (channels, time) — TODO confirm with callers
            segment_np = waveform[:, start:end].squeeze().cpu().numpy()
            if segment_np.size == 0:
                continue
            try:
                text = self._transcribe_array(segment_np, sample_rate)
            except Exception as e:
                print(f"[SpeechBrain] Segment transcription failed: {e}")
                text = ""
            if not text or not text.strip():
                continue
            transcripts.append(
                TranscriptSegment(
                    speaker_id=seg.speaker_id,
                    start=seg.start,
                    end=seg.end,
                    text=text.strip(),
                    confidence=getattr(seg, "confidence", 1.0),
                    is_overlap=getattr(seg, "is_overlap", False),
                )
            )
        return transcripts