ConvxO2 committed on
Commit
4b8c370
·
1 Parent(s): 6aa584f

Fix stereo duration handling and robust ECAPA model loading

Browse files
Files changed (2) hide show
  1. app/pipeline.py +35 -14
  2. models/embedder.py +43 -26
app/pipeline.py CHANGED
@@ -1,6 +1,6 @@
1
  """
2
  Speaker Diarization Pipeline
3
- Combines: Voice Activity Detection Segmentation ECAPA-TDNN Embeddings AHC Clustering
4
  """
5
 
6
  import torch
@@ -97,6 +97,16 @@ class DiarizationPipeline:
97
  return "cuda" if torch.cuda.is_available() else "cpu"
98
  return device
99
 
 
 
 
 
 
 
 
 
 
 
100
  def _load_vad(self):
101
  if self._vad_pipeline is not None:
102
  return
@@ -120,13 +130,13 @@ class DiarizationPipeline:
120
  frame_samples = int(frame_duration * self.SAMPLE_RATE)
121
  audio_np = audio.numpy()
122
  frames = [
123
- audio_np[i : i + frame_samples]
124
  for i in range(0, len(audio_np) - frame_samples, frame_samples)
125
  ]
126
 
127
  energies_db = []
128
- for f in frames:
129
- rms = np.sqrt(np.mean(f ** 2) + 1e-10)
130
  energies_db.append(20 * np.log10(rms))
131
 
132
  is_speech = np.array(energies_db) > threshold_db
@@ -198,25 +208,38 @@ class DiarizationPipeline:
198
  ) -> DiarizationResult:
199
  """Run full diarization pipeline on audio."""
200
  import time
 
201
  t_start = time.time()
202
 
203
  if isinstance(audio, (str, Path)):
204
  waveform, sample_rate = self.load_audio(audio)
205
- audio_tensor = waveform.squeeze(0)
206
  else:
207
  assert sample_rate is not None, "sample_rate required when passing tensor"
208
- audio_tensor = audio.squeeze(0) if audio.dim() > 1 else audio
209
 
210
- audio_duration = len(audio_tensor) / sample_rate
 
211
  logger.info(f"Processing {audio_duration:.1f}s audio at {sample_rate}Hz")
212
 
 
 
 
 
 
 
 
 
 
 
213
  processed = self.embedder.preprocess_audio(audio_tensor, sample_rate)
214
 
215
  speech_regions = self._get_speech_regions(processed)
216
  if not speech_regions:
217
  logger.warning("No speech detected in audio.")
218
  return DiarizationResult(
219
- segments=[], num_speakers=0,
 
220
  audio_duration=audio_duration,
221
  processing_time=time.time() - t_start,
222
  sample_rate=sample_rate,
@@ -232,7 +255,8 @@ class DiarizationPipeline:
232
  if len(embeddings) == 0:
233
  logger.warning("No valid embeddings extracted.")
234
  return DiarizationResult(
235
- segments=[], num_speakers=0,
 
236
  audio_duration=audio_duration,
237
  processing_time=time.time() - t_start,
238
  sample_rate=sample_rate,
@@ -245,11 +269,7 @@ class DiarizationPipeline:
245
 
246
  speaker_names = {i: f"SPEAKER_{i:02d}" for i in range(self.max_speakers)}
247
  segments = [
248
- DiarizationSegment(
249
- start=start,
250
- end=end,
251
- speaker=speaker_names[spk_id],
252
- )
253
  for start, end, spk_id in merged
254
  ]
255
 
@@ -268,3 +288,4 @@ class DiarizationPipeline:
268
  processing_time=processing_time,
269
  sample_rate=sample_rate,
270
  )
 
 
1
  """
2
  Speaker Diarization Pipeline
3
+ Combines: Voice Activity Detection -> Segmentation -> ECAPA-TDNN Embeddings -> AHC Clustering
4
  """
5
 
6
  import torch
 
97
  return "cuda" if torch.cuda.is_available() else "cpu"
98
  return device
99
 
100
+ def _to_mono_1d(self, audio: torch.Tensor) -> torch.Tensor:
101
+ """Convert waveform to a mono 1D tensor for duration and preprocessing."""
102
+ if audio.dim() == 1:
103
+ return audio
104
+ if audio.dim() >= 2:
105
+ if audio.shape[0] == 1:
106
+ return audio[0]
107
+ return audio.mean(dim=0)
108
+ return audio.reshape(-1)
109
+
110
  def _load_vad(self):
111
  if self._vad_pipeline is not None:
112
  return
 
130
  frame_samples = int(frame_duration * self.SAMPLE_RATE)
131
  audio_np = audio.numpy()
132
  frames = [
133
+ audio_np[i: i + frame_samples]
134
  for i in range(0, len(audio_np) - frame_samples, frame_samples)
135
  ]
136
 
137
  energies_db = []
138
+ for frame in frames:
139
+ rms = np.sqrt(np.mean(frame ** 2) + 1e-10)
140
  energies_db.append(20 * np.log10(rms))
141
 
142
  is_speech = np.array(energies_db) > threshold_db
 
208
  ) -> DiarizationResult:
209
  """Run full diarization pipeline on audio."""
210
  import time
211
+
212
  t_start = time.time()
213
 
214
  if isinstance(audio, (str, Path)):
215
  waveform, sample_rate = self.load_audio(audio)
216
+ audio_tensor = self._to_mono_1d(waveform)
217
  else:
218
  assert sample_rate is not None, "sample_rate required when passing tensor"
219
+ audio_tensor = self._to_mono_1d(audio)
220
 
221
+ num_samples = int(audio_tensor.numel())
222
+ audio_duration = num_samples / float(sample_rate)
223
  logger.info(f"Processing {audio_duration:.1f}s audio at {sample_rate}Hz")
224
 
225
+ if num_samples == 0:
226
+ logger.warning("Received empty audio input.")
227
+ return DiarizationResult(
228
+ segments=[],
229
+ num_speakers=0,
230
+ audio_duration=0.0,
231
+ processing_time=time.time() - t_start,
232
+ sample_rate=sample_rate,
233
+ )
234
+
235
  processed = self.embedder.preprocess_audio(audio_tensor, sample_rate)
236
 
237
  speech_regions = self._get_speech_regions(processed)
238
  if not speech_regions:
239
  logger.warning("No speech detected in audio.")
240
  return DiarizationResult(
241
+ segments=[],
242
+ num_speakers=0,
243
  audio_duration=audio_duration,
244
  processing_time=time.time() - t_start,
245
  sample_rate=sample_rate,
 
255
  if len(embeddings) == 0:
256
  logger.warning("No valid embeddings extracted.")
257
  return DiarizationResult(
258
+ segments=[],
259
+ num_speakers=0,
260
  audio_duration=audio_duration,
261
  processing_time=time.time() - t_start,
262
  sample_rate=sample_rate,
 
269
 
270
  speaker_names = {i: f"SPEAKER_{i:02d}" for i in range(self.max_speakers)}
271
  segments = [
272
+ DiarizationSegment(start=start, end=end, speaker=speaker_names[spk_id])
 
 
 
 
273
  for start, end, spk_id in merged
274
  ]
275
 
 
288
  processing_time=processing_time,
289
  sample_rate=sample_rate,
290
  )
291
+
models/embedder.py CHANGED
@@ -1,10 +1,9 @@
1
- """
2
  Speaker Embedding Extraction using ECAPA-TDNN architecture via SpeechBrain.
3
  Handles audio preprocessing, feature extraction, and L2-normalized embeddings.
4
  """
5
 
6
  import inspect
7
- import shutil
8
  from pathlib import Path
9
  from typing import Union, List, Tuple
10
 
@@ -36,24 +35,36 @@ class EcapaTDNNEmbedder:
36
  return "cuda" if torch.cuda.is_available() else "cpu"
37
  return device
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  def _load_model(self):
40
  if self._model is not None:
41
  return
42
 
43
  try:
44
- import speechbrain.utils.fetching as _fetching
45
- from speechbrain.utils.fetching import LocalStrategy
46
- from speechbrain.inference.classifiers import EncoderClassifier
47
-
48
- def _patched_link(src, dst, local_strategy):
49
- dst_path = Path(dst)
50
- src_path = Path(src)
51
- dst_path.parent.mkdir(parents=True, exist_ok=True)
52
- if dst_path.exists() or dst_path.is_symlink():
53
- dst_path.unlink()
54
- shutil.copy2(str(src_path), str(dst_path))
55
-
56
- _fetching.link_with_strategy = _patched_link
57
 
58
  savedir = self.cache_dir / "ecapa_tdnn"
59
  hf_cache = self.cache_dir / "hf_cache"
@@ -63,23 +74,28 @@ class EcapaTDNNEmbedder:
63
  logger.info(f"Loading ECAPA-TDNN from {self.MODEL_SOURCE}...")
64
  logger.info(f"Savedir: {savedir}, exists: {savedir.exists()}")
65
 
66
- kwargs = {
67
- "source": self.MODEL_SOURCE,
68
- "savedir": str(savedir),
69
- "run_opts": {"device": self.device},
70
- }
71
 
72
- sig = inspect.signature(EncoderClassifier.from_hparams)
73
- if "huggingface_cache_dir" in sig.parameters:
74
- kwargs["huggingface_cache_dir"] = str(hf_cache)
75
- if "local_strategy" in sig.parameters:
76
- kwargs["local_strategy"] = LocalStrategy.COPY
 
 
 
77
 
78
- self._model = EncoderClassifier.from_hparams(**kwargs)
 
 
 
79
  self._model.eval()
80
  logger.success("ECAPA-TDNN model loaded successfully.")
81
  except ImportError as exc:
82
  raise ImportError("SpeechBrain not installed.") from exc
 
 
83
 
84
  def preprocess_audio(
85
  self, audio: Union[np.ndarray, torch.Tensor], sample_rate: int
@@ -157,3 +173,4 @@ class EcapaTDNNEmbedder:
157
  return np.empty((0, self.EMBEDDING_DIM)), []
158
 
159
  return np.stack(embeddings), valid_segments
 
 
1
+ """
2
  Speaker Embedding Extraction using ECAPA-TDNN architecture via SpeechBrain.
3
  Handles audio preprocessing, feature extraction, and L2-normalized embeddings.
4
  """
5
 
6
  import inspect
 
7
  from pathlib import Path
8
  from typing import Union, List, Tuple
9
 
 
35
  return "cuda" if torch.cuda.is_available() else "cpu"
36
  return device
37
 
38
+ def _build_hparams_kwargs(self, encoder_cls, savedir: Path, hf_cache: Path) -> dict:
39
+ kwargs = {
40
+ "source": self.MODEL_SOURCE,
41
+ "savedir": str(savedir),
42
+ "run_opts": {"device": self.device},
43
+ }
44
+
45
+ sig = inspect.signature(encoder_cls.from_hparams)
46
+ if "huggingface_cache_dir" in sig.parameters:
47
+ kwargs["huggingface_cache_dir"] = str(hf_cache)
48
+ if "local_strategy" in sig.parameters:
49
+ try:
50
+ from speechbrain.utils.fetching import LocalStrategy
51
+
52
+ kwargs["local_strategy"] = LocalStrategy.COPY
53
+ except Exception:
54
+ pass
55
+
56
+ return kwargs
57
+
58
  def _load_model(self):
59
  if self._model is not None:
60
  return
61
 
62
  try:
63
+ try:
64
+ from speechbrain.inference.classifiers import EncoderClassifier
65
+ except ImportError:
66
+ # Backward compatibility with older SpeechBrain versions.
67
+ from speechbrain.pretrained import EncoderClassifier
 
 
 
 
 
 
 
 
68
 
69
  savedir = self.cache_dir / "ecapa_tdnn"
70
  hf_cache = self.cache_dir / "hf_cache"
 
74
  logger.info(f"Loading ECAPA-TDNN from {self.MODEL_SOURCE}...")
75
  logger.info(f"Savedir: {savedir}, exists: {savedir.exists()}")
76
 
77
+ kwargs = self._build_hparams_kwargs(EncoderClassifier, savedir, hf_cache)
78
+ model = EncoderClassifier.from_hparams(**kwargs)
 
 
 
79
 
80
+ if model is None:
81
+ # Some SpeechBrain/HF hub combinations ignore optional kwargs.
82
+ logger.warning("ECAPA load returned None; retrying with minimal from_hparams kwargs.")
83
+ model = EncoderClassifier.from_hparams(
84
+ source=self.MODEL_SOURCE,
85
+ savedir=str(savedir),
86
+ run_opts={"device": self.device},
87
+ )
88
 
89
+ if model is None:
90
+ raise RuntimeError("EncoderClassifier.from_hparams returned None")
91
+
92
+ self._model = model
93
  self._model.eval()
94
  logger.success("ECAPA-TDNN model loaded successfully.")
95
  except ImportError as exc:
96
  raise ImportError("SpeechBrain not installed.") from exc
97
+ except Exception as exc:
98
+ raise RuntimeError(f"Failed to load ECAPA-TDNN model: {exc}") from exc
99
 
100
  def preprocess_audio(
101
  self, audio: Union[np.ndarray, torch.Tensor], sample_rate: int
 
173
  return np.empty((0, self.EMBEDDING_DIM)), []
174
 
175
  return np.stack(embeddings), valid_segments
176
+