# DesgrabadorAI / diarization_speechbrain.py
# (Hugging Face upload header: JoaquinZ, "Upload 19 files", commit e641d41)
"""SpeechBrain-based diarization backend (modular, safe).
Public API:
is_available() -> bool
load_model(device: str) -> EncoderClassifier
diarize(audio_path: str, device: str = 'cpu', config: dict | None = None) -> list[dict]
Each diarization segment dict:
{ 'start': float, 'end': float, 'speaker': 'SPEAKER_00' }
Design goals:
- No side-effects (download only if needed)
- Graceful fallback if dependencies missing
- Clear separation: load_model, extract embeddings, cluster
- Minimal external assumptions (caller handles errors)
Configuration keys (with defaults, matching DEFAULT_CONFIG below):
window_sec: 2.5
overlap_sec: 1.0
min_segment_sec: 0.3
min_speakers: 2
max_speakers: 10
use_silhouette: True
silhouette_min_embeddings: 4
random_state: 42
use_vad: True
vad_frame_sec: 0.5
vad_frame_hop_sec: 0.25
vad_energy_threshold: 0.0 (auto)
vad_energy_std_factor: 0.5
vad_min_speech_sec: 0.6
vad_merge_gap_sec: 0.3
clustering_method: 'spectral' | 'agglomerative' | 'hdbscan'
agglomerative_linkage: 'average'
hdbscan_min_cluster_size: 2
hdbscan_min_samples: 1
post_smoothing: True
min_segment_post_sec: 0.4
merge_same_speaker_gap_sec: 0.3
robust_audio_load: True
force_single_split_sec: 30.0
force_single_split_window_sec: 4.0
force_single_split_overlap_sec: 1.0
force_min_speakers_if_split: 2
Requires packages: speechbrain, torch, torchaudio, numpy, scikit-learn.
"""
from __future__ import annotations
import os
import math
import time
import traceback
import warnings
from typing import List, Dict, Any, Optional, Tuple
from pathlib import Path
# Default diarization configuration; any key may be overridden per call via
# the ``config`` argument of ``diarize``.
DEFAULT_CONFIG = {
    "window_sec": 2.5,  # smaller windows for better speaker-change detection
    "overlap_sec": 1.0,  # less overlap for more embedding diversity
    "min_segment_sec": 0.3,  # allow shorter segments
    "min_speakers": 2,
    "max_speakers": 10,  # allow more speakers
    "use_silhouette": True,
    "silhouette_min_embeddings": 4,  # less restrictive
    "random_state": 42,
    # VAD params (adaptive energy / RMS)
    "use_vad": True,
    "vad_frame_sec": 0.5,
    "vad_frame_hop_sec": 0.25,
    "vad_energy_threshold": 0.0,  # 0 => auto (mean + factor * std)
    "vad_energy_std_factor": 0.5,
    "vad_min_speech_sec": 0.6,
    "vad_merge_gap_sec": 0.3,
    # Clustering
    "clustering_method": "spectral",  # spectral | agglomerative | hdbscan(auto)
    "agglomerative_linkage": "average",
    "hdbscan_min_cluster_size": 2,
    "hdbscan_min_samples": 1,
    # Post smoothing
    "post_smoothing": True,
    "min_segment_post_sec": 0.4,
    "merge_same_speaker_gap_sec": 0.3,
    "robust_audio_load": True,
    "force_single_split_sec": 30.0,
    "force_single_split_window_sec": 4.0,
    "force_single_split_overlap_sec": 1.0,
    "force_min_speakers_if_split": 2,
}
# Process-wide model cache: the encoder is loaded once and reused across
# diarize() calls. "source" is reserved but currently unused.
_MODEL_CACHE = {"model": None, "source": None}
def _inject_local_ffmpeg():
"""Add bundled ffmpeg bin folder to PATH so torchaudio/ffmpeg-based conversion works."""
try:
base_dir = Path(__file__).parent
candidates = [
base_dir / 'ffmpeg' / 'ffmpeg-8.0-essentials_build' / 'bin',
base_dir / 'ffmpeg' / 'bin',
]
for c in candidates:
if c.is_dir():
bin_path = str(c)
if bin_path not in os.environ.get('PATH',''):
os.environ['PATH'] = bin_path + os.pathsep + os.environ.get('PATH','')
print(f"[SpeechBrain] FFmpeg agregado al PATH: {bin_path}")
break
except Exception:
pass
_inject_local_ffmpeg()
def is_available() -> bool:
    """Whether the core ``speechbrain`` dependency can be imported."""
    try:
        import speechbrain  # noqa: F401
    except Exception:
        return False
    return True
def load_model(device: str = "cpu", local_dir: str = "speechbrain_pretrained"):
    """Load (or reuse) the SpeechBrain ECAPA speaker-embedding model.

    Tries the ``LocalStrategy`` fetch first (avoids symlink issues on
    Windows) and falls back to the default fetching strategy; logs each
    phase with elapsed time.

    Parameters:
        device: torch device string passed to ``run_opts``.
        local_dir: directory where the pretrained files are cached.

    Returns the loaded ``EncoderClassifier``.

    NOTE(review): the cache key ignores ``device``/``local_dir`` — a second
    call with a different device returns the first model unchanged.
    """
    if _MODEL_CACHE["model"] is not None:
        return _MODEL_CACHE["model"]
    t0 = time.time()
    # Disable HF-hub symlinks before speechbrain is imported (best effort).
    try:
        os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS", "1")
        os.environ.setdefault("HUGGINGFACE_HUB_DISABLE_SYMLINKS", "1")
        os.environ.setdefault("SPEECHBRAIN_LOCALSTRATEGY", "1")
    except Exception:
        pass
    from speechbrain.inference import EncoderClassifier
    # Monkeypatch Pretrainer collect_in to avoid symlink warnings (best effort)
    try:
        from speechbrain.utils import parameter_transfer as _pt
        if hasattr(_pt, "Pretrainer"):
            _OrigPretrainer = _pt.Pretrainer
            class _PatchedPretrainer(_OrigPretrainer):  # type: ignore
                def __init__(self, *args, **kwargs):
                    # Dropping collect_in disables the symlink collection step.
                    if "collect_in" in kwargs:
                        kwargs["collect_in"] = None
                    super().__init__(*args, **kwargs)
            _pt.Pretrainer = _PatchedPretrainer  # type: ignore
    except Exception:
        pass
    # Silence known-noisy warnings from speechbrain internals.
    warnings.filterwarnings(
        "ignore",
        message=r"Requested Pretrainer collection using symlinks on Windows",
        category=UserWarning,
        module=r"speechbrain\.utils\.parameter_transfer"
    )
    warnings.filterwarnings(
        "ignore",
        message=r"`torch\.cuda\.amp\.custom_fwd\(args\.\.\.\)` is deprecated",
        category=FutureWarning,
        module=r"speechbrain\.utils\.autocast"
    )
    # Try LocalStrategy then fallback
    model = None
    try:
        from speechbrain.utils.fetching import LocalStrategy
        os.makedirs(local_dir, exist_ok=True)
        fetch_strategy = LocalStrategy()
        model = EncoderClassifier.from_hparams(
            source="speechbrain/spkrec-ecapa-voxceleb",
            savedir=local_dir,
            fetch_strategy=fetch_strategy,
            run_opts={"device": device}
        )
        print("[SpeechBrain] Modelo cargado con LocalStrategy (%.2fs)" % (time.time()-t0))
    except Exception as e:
        # Silence the specific TypeError from older versions where
        # LocalStrategy is an internal Enum.
        if isinstance(e, TypeError) and 'EnumType.__call__' in str(e):
            print("[SpeechBrain] LocalStrategy no soportada en esta versión -> fallback (%.2fs)" % (time.time()-t0))
        else:
            # BUG FIX: the original mixed an f-string with %-formatting
            # (f"...{e}..." % (...)); any '%' inside the exception text made
            # this print raise. Use a single f-string instead.
            print(f"[SpeechBrain] LocalStrategy no disponible ({type(e).__name__}: {e}); fallback ({time.time()-t0:.2f}s)")
        model = EncoderClassifier.from_hparams(
            source="speechbrain/spkrec-ecapa-voxceleb",
            savedir=local_dir,
            run_opts={"device": device}
        )
        print("[SpeechBrain] Modelo cargado fallback (%.2fs)" % (time.time()-t0))
    _MODEL_CACHE["model"] = model
    print("[SpeechBrain] load_model completo (%.2fs)" % (time.time()-t0))
    return model
def _segment_audio(waveform, sample_rate: int, cfg: dict) -> List[Tuple[int, int]]:
win = int(cfg["window_sec"] * sample_rate)
hop = int((cfg["window_sec"] - cfg["overlap_sec"]) * sample_rate)
if hop <= 0:
hop = win # no overlap safety
total = waveform.shape[1]
segments = []
for start in range(0, total, hop):
end = min(start + win, total)
dur = (end - start) / sample_rate
if dur >= cfg["min_segment_sec"]:
segments.append((start, end))
if end == total:
break
return segments
def _apply_vad(waveform, sample_rate: int, cfg: dict):
"""Muy simple VAD basado en energía RMS por frames.
Devuelve lista de (start_sample, end_sample) de regiones detectadas como voz.
Se fusionan regiones cercanas (< vad_merge_gap_sec) y se descartan cortas (< vad_min_speech_sec).
"""
import torch
frame = int(cfg["vad_frame_sec"] * sample_rate)
hop = int(cfg["vad_frame_hop_sec"] * sample_rate)
if frame <= 0:
frame = int(0.5 * sample_rate)
if hop <= 0:
hop = frame
wav = waveform.squeeze(0)
total = wav.shape[0]
energies = []
bounds = []
for start in range(0, total, hop):
end = min(start + frame, total)
seg = wav[start:end]
if seg.numel() == 0:
continue
# RMS energy
e = torch.sqrt(torch.mean(seg ** 2) + 1e-12).item()
energies.append(e)
bounds.append((start, end))
if end == total:
break
import numpy as np
if not energies:
return []
arr = np.array(energies)
thr = cfg["vad_energy_threshold"]
if thr <= 0.0:
thr = float(arr.mean() + cfg["vad_energy_std_factor"] * arr.std())
speech_flags = arr >= thr
# Convert contiguous True frames to sample ranges
speech_ranges = []
cur_start = None
for i, flag in enumerate(speech_flags):
if flag and cur_start is None:
cur_start = bounds[i][0]
if not flag and cur_start is not None:
speech_ranges.append((cur_start, bounds[i][1]))
cur_start = None
if cur_start is not None:
speech_ranges.append((cur_start, bounds[-1][1]))
# Merge close ranges
merged = []
gap = int(cfg["vad_merge_gap_sec"] * sample_rate)
for (s, e) in speech_ranges:
if not merged:
merged.append([s, e])
else:
if s - merged[-1][1] <= gap:
merged[-1][1] = e
else:
merged.append([s, e])
# Drop short
min_len = int(cfg["vad_min_speech_sec"] * sample_rate)
final = [(s, e) for (s, e) in merged if (e - s) >= min_len]
return final
def _compute_embeddings(model, waveform, sample_rate: int, segments: List[Tuple[int, int]]):
import torch
import numpy as np
embs = []
times = []
with torch.no_grad():
for (s, e) in segments:
seg = waveform[:, s:e]
# Expected shape [batch, time]; ensure mono already.
batch = seg.squeeze(0).unsqueeze(0)
emb = model.encode_batch(batch)
vec = emb.squeeze().cpu().numpy()
# L2 normalize for cosine stability
norm = np.linalg.norm(vec) + 1e-9
vec = vec / norm
embs.append(vec)
times.append((s / sample_rate, e / sample_rate))
return embs, times
def _estimate_num_speakers(embeddings, cfg: dict, sklearn_available: bool) -> int:
import numpy as np
if sklearn_available:
try:
from sklearn.metrics import silhouette_score
from sklearn.cluster import SpectralClustering
except Exception:
sklearn_available = False
n = len(embeddings)
if n < 2:
return 1
arr = np.vstack(embeddings)
k_min = cfg["min_speakers"]
k_max = min(cfg["max_speakers"], n)
if k_max < k_min:
return k_min
if sklearn_available and cfg.get("use_silhouette", True) and n >= cfg.get("silhouette_min_embeddings", 4) and (k_max - k_min) >= 1:
best_k = k_min
best_score = -1.0
for k in range(k_min, k_max + 1):
try:
# Usar misma afinidad que el clustering real
clustering = SpectralClustering(n_clusters=k, affinity="cosine", random_state=cfg["random_state"])
labels = clustering.fit_predict(arr)
if len(set(labels)) < 2:
continue
score = silhouette_score(arr, labels, metric="cosine")
# Bias hacia más clusters si el score es similar
score_bonus = k * 0.02 # Pequeño bonus por más clusters
adjusted_score = score + score_bonus
if adjusted_score > best_score:
best_score = adjusted_score
best_k = k
except Exception:
continue
return best_k
# Fallback heuristic - bias hacia más speakers
estimated = max(k_min, min(k_max, int(round(math.sqrt(n * 1.5)))))
# Asegurar al menos 2 speakers para conversaciones
return max(2, estimated)
def _cluster_embeddings(embeddings, times, cfg: dict):
    """Cluster per-window embeddings into speaker labels.

    Returns a list of (start_sec, end_sec, label) triples, one per input
    embedding/time pair. Tries the configured method (spectral /
    agglomerative / hdbscan); if that fails or the libraries are missing,
    falls back to a small hand-rolled cosine k-means.
    """
    import numpy as np
    sklearn_available = True
    try:
        from sklearn.cluster import SpectralClustering, AgglomerativeClustering
    except Exception:
        sklearn_available = False
    # Trivial cases: nothing to cluster, or a single segment -> speaker 0.
    if len(embeddings) == 0:
        return []
    if len(embeddings) == 1:
        return [(times[0][0], times[0][1], 0)]
    arr = np.vstack(embeddings)
    method = cfg.get("clustering_method", "spectral")
    # hdbscan picks its own cluster count, so k is only estimated otherwise.
    k = _estimate_num_speakers(embeddings, cfg, sklearn_available) if method != "hdbscan" else None
    debug = cfg.get("debug_log_path")
    if debug:
        try:
            with open(debug, 'a', encoding='utf-8') as f:
                f.write(f"cluster: n_emb={len(embeddings)} method={method} k_est={k} sklearn={sklearn_available}\n")
        except Exception:
            pass
    labels = None
    if method == "spectral" and sklearn_available:
        try:
            clustering = SpectralClustering(n_clusters=k, affinity="cosine", random_state=cfg["random_state"])
            labels = clustering.fit_predict(arr)
        except Exception:
            labels = None
    elif method == "agglomerative" and sklearn_available:
        try:
            linkage = cfg.get("agglomerative_linkage", "average")
            clustering = AgglomerativeClustering(n_clusters=k, affinity="cosine", linkage=linkage)
            # Newer sklearn may deprecate affinity in favor of metric
            try:
                labels = clustering.fit_predict(arr)
            except TypeError:
                clustering = AgglomerativeClustering(n_clusters=k, metric="cosine", linkage=linkage)
                labels = clustering.fit_predict(arr)
        except Exception:
            labels = None
    elif method == "hdbscan":
        try:
            import hdbscan  # type: ignore
            clusterer = hdbscan.HDBSCAN(min_cluster_size=cfg.get("hdbscan_min_cluster_size", 2),
                                        min_samples=cfg.get("hdbscan_min_samples", 1),
                                        metric='euclidean')
            labels = clusterer.fit_predict(arr)
            # HDBSCAN may label some points as -1 (noise). Reassign noise clusters to nearest non-noise centroid.
            if (labels == -1).any():
                valid = arr[labels != -1]
                if valid.shape[0] >= 1:
                    import numpy as np
                    cent = []
                    # Build one centroid per non-noise label (mask over valid rows).
                    for lab in sorted(set(labels) - {-1}):
                        cent.append(valid[labels[labels != -1] == lab].mean(axis=0))
                    cent = np.stack(cent) if cent else valid
                    for idx, lab in enumerate(labels):
                        if lab == -1:
                            # assign by cosine similarity
                            sims = np.dot(cent, arr[idx]) / (np.linalg.norm(cent, axis=1) * np.linalg.norm(arr[idx]) + 1e-9)
                            labels[idx] = int(np.argmax(sims))
            # Remap labels to 0..K-1
            uniq = {lab: i for i, lab in enumerate(sorted(set(labels)))}
            labels = np.array([uniq[l] for l in labels])
        except Exception:
            labels = None
    # Fallback manual k-means if labels still None
    if labels is None:
        if k is None:
            k = _estimate_num_speakers(embeddings, cfg, sklearn_available)
        rng = np.random.default_rng(cfg.get("random_state", 42))
        centroids = arr[rng.choice(len(arr), size=k, replace=False)] if k <= len(arr) else arr
        # Lloyd iterations (max 10) with cosine similarity as the assignment metric.
        for _ in range(10):
            sims = np.matmul(arr, centroids.T)
            arr_norm = np.linalg.norm(arr, axis=1, keepdims=True) + 1e-9
            cent_norm = np.linalg.norm(centroids, axis=1, keepdims=True).T + 1e-9
            cos = sims / (arr_norm * cent_norm)
            labels = np.argmax(cos, axis=1)
            new_centroids = []
            for j in range(len(centroids)):
                mask = labels == j
                if np.any(mask):
                    new_centroids.append(arr[mask].mean(axis=0))
                else:
                    # Keep an empty cluster's centroid unchanged.
                    new_centroids.append(centroids[j])
            new_centroids = np.vstack(new_centroids)
            if np.allclose(new_centroids, centroids):
                break
            centroids = new_centroids
    # Debug: distribution of labels
    if debug:
        try:
            import collections
            dist = collections.Counter(labels.tolist()) if 'labels' in locals() else {}
            with open(debug, 'a', encoding='utf-8') as f:
                f.write(f"labels_dist={dict(dist)}\n")
        except Exception:
            pass
    return [(times[i][0], times[i][1], int(labels[i])) for i in range(len(labels))]
def _post_smooth(clustered: List[Tuple[float, float, int]], cfg: dict, debug: str | None):
    """Smooth clustered segments: merge same-speaker neighbors and absorb
    very short segments into an adjacent one.

    Input/output are lists of (start_sec, end_sec, speaker_label). No-op
    when empty or when ``post_smoothing`` is disabled.
    """
    if not clustered:
        return clustered
    if not cfg.get("post_smoothing", True):
        return clustered
    merged = []
    # First merge consecutive segments of the same speaker when the gap is small
    gap = cfg.get("merge_same_speaker_gap_sec", 0.3)
    for seg in clustered:
        if not merged:
            merged.append(list(seg))
        else:
            last = merged[-1]
            if seg[2] == last[2] and seg[0] - last[1] <= gap:
                last[1] = seg[1]
            else:
                merged.append(list(seg))
    # Absorb segments that are too short
    min_len = cfg.get("min_segment_post_sec", 0.4)
    if len(merged) > 1:
        i = 0
        while i < len(merged):
            s, e, spk = merged[i]
            dur = e - s
            if dur < min_len:
                # decide which neighbor absorbs it (smaller temporal gap; left preferred on tie)
                left_gap = (s - merged[i-1][1]) if i > 0 else float('inf')
                right_gap = (merged[i+1][0] - e) if i < len(merged)-1 else float('inf')
                target = None
                if i > 0 and (i == len(merged)-1 or left_gap <= right_gap):
                    target = i-1
                elif i < len(merged)-1:
                    target = i+1
                if target is not None:
                    # extend the target segment over the absorbed span
                    if target < i:  # merge into left
                        merged[target][1] = e
                        # keep the target's speaker label
                        merged.pop(i)
                        i -= 1
                    else:  # merge into right
                        merged[target][0] = s
                        merged.pop(i)
                        i -= 1
            i += 1
    out = [(a, b, int(c)) for (a, b, c) in merged]
    if debug:
        try:
            with open(debug, 'a', encoding='utf-8') as f:
                f.write(f"post_smooth: in={len(clustered)} out={len(out)}\n")
        except Exception:
            pass
    return out
def diarize(audio_path: str, device: str = "cpu", config: Optional[dict] = None) -> List[Dict[str, Any]]:
    """High-level diarization returning WhisperX-compatible segments list.

    Pipeline: load audio (with ffmpeg fallback for m4a) -> mono/16 kHz ->
    optional VAD -> windowed segmentation -> embeddings -> clustering ->
    post-smoothing. Each returned dict has 'start', 'end' (seconds) and
    'speaker' ('SPEAKER_NN'). Appends diagnostics to speechbrain_debug.log
    next to this file. Raises exceptions on fatal errors (caller may wrap).
    """
    cfg = {**DEFAULT_CONFIG, **(config or {})}
    import torchaudio
    import torch
    model = load_model(device=device)
    def _safe_load(path):
        # Load audio; on failure optionally convert m4a via ffmpeg and retry.
        try:
            return torchaudio.load(path)
        except Exception as e_primary:
            # Try converting to wav when the file is m4a and robust_audio_load is on
            if cfg.get("robust_audio_load", True) and path.lower().endswith('.m4a'):
                import subprocess, tempfile
                tmp_wav = os.path.join(tempfile.gettempdir(), f"_conv_{int(time.time()*1000)}.wav")
                cmd = ['ffmpeg', '-y', '-i', path, '-ac', '1', '-ar', '16000', tmp_wav]
                try:
                    # Dynamic timeout scaled by file size (~10 s per MB, min 60 s)
                    conversion_timeout = max(60, int(os.path.getsize(path) / (1024 * 1024) * 10))
                    subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, timeout=conversion_timeout)
                    return torchaudio.load(tmp_wav)
                except Exception:
                    raise e_primary
            raise e_primary
    waveform, sr = _safe_load(audio_path)
    # Downmix to mono and resample to the 16 kHz rate the encoder expects.
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    if sr != 16000:
        resample = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
        waveform = resample(waveform)
        sr = 16000
    debug_path = os.path.join(os.path.dirname(__file__), 'speechbrain_debug.log')
    cfg["debug_log_path"] = debug_path
    t_start = time.time()
    # Optional VAD to get purer speech regions before long-window segmentation.
    vad_regions = []
    if cfg.get("use_vad", True):
        try:
            vad_regions = _apply_vad(waveform, sr, cfg)
        except Exception:
            vad_regions = []
    # If there are VAD regions, subdivide each one with the windowing scheme (for robustness and overlap)
    if vad_regions:
        segments = []
        base_win = int(cfg["window_sec"] * sr)
        base_hop = int((cfg["window_sec"] - cfg["overlap_sec"]) * sr)
        if base_hop <= 0:
            base_hop = base_win
        for (rs, re) in vad_regions:
            cur = rs
            while cur < re:
                end = min(cur + base_win, re)
                dur = (end - cur) / sr
                if dur >= cfg["min_segment_sec"]:
                    segments.append((cur, end))
                if end == re:
                    break
                cur += base_hop
    else:
        segments = _segment_audio(waveform, sr, cfg)
    try:
        with open(debug_path, 'a', encoding='utf-8') as f:
            total_vad_dur = 0.0
            if vad_regions:
                total_vad_dur = sum((re - rs)/sr for rs, re in vad_regions)
            f.write(
                f"diarize: sr={sr} samples={waveform.shape[1]} segs={len(segments)} "
                f"vad_regions={len(vad_regions)} vad_dur={total_vad_dur:.1f}s audio_total={waveform.shape[1]/sr:.1f}s "
                f"use_vad={cfg.get('use_vad', True)} window={cfg['window_sec']} overlap={cfg['overlap_sec']}\n"
            )
    except Exception:
        pass
    if len(segments) == 0:
        return []
    embeddings, times = _compute_embeddings(model, waveform, sr, segments)
    # Heuristic: with only 1 embedding on fairly long audio, force a finer
    # re-segmentation to avoid a false collapse to a single speaker.
    audio_total_sec = waveform.shape[1] / sr
    if len(embeddings) == 1 and audio_total_sec >= cfg.get("force_single_split_sec", 30.0):
        try:
            win_sec = cfg.get("force_single_split_window_sec", 4.0)
            ov_sec = cfg.get("force_single_split_overlap_sec", 1.0)
            tmp_cfg = dict(cfg)
            tmp_cfg["window_sec"] = win_sec
            tmp_cfg["overlap_sec"] = ov_sec
            segments_refine = _segment_audio(waveform, sr, tmp_cfg)
            if len(segments_refine) > 1:
                embeddings, times = _compute_embeddings(model, waveform, sr, segments_refine)
                segments = segments_refine
                cfg["min_speakers"] = max(cfg.get("min_speakers", 2), cfg.get("force_min_speakers_if_split", 2))
                debug_path = os.path.join(os.path.dirname(__file__), 'speechbrain_debug.log')
                with open(debug_path, 'a', encoding='utf-8') as f:
                    f.write(f"force_single_split applied win={win_sec} overlap={ov_sec} new_embs={len(embeddings)}\n")
        except Exception:
            pass
    try:
        with open(debug_path, 'a', encoding='utf-8') as f:
            f.write(f"embeddings: count={len(embeddings)} first_shape={getattr(embeddings[0],'shape',None)}\n")
    except Exception:
        pass
    clustered = _cluster_embeddings(embeddings, times, cfg)
    clustered = _post_smooth(clustered, cfg, debug_path)
    try:
        with open(debug_path, 'a', encoding='utf-8') as f:
            uniq = sorted({c[2] for c in clustered})
            f.write(f"clustered: segments={len(clustered)} speakers_est={len(uniq)} speakers={uniq} total_time={time.time()-t_start:.2f}s\n")
    except Exception:
        pass
    # Convert to expected list of dicts
    out = []
    for (start, end, spk) in clustered:
        out.append({"start": float(start), "end": float(end), "speaker": f"SPEAKER_{spk:02d}"})
    return out
if __name__ == "__main__":  # simple manual test guard
    test_path = os.environ.get("TEST_AUDIO")
    if not (test_path and os.path.isfile(test_path)):
        print("Set TEST_AUDIO env var to an audio file path for test.")
    else:
        print("Running quick diarization test on", test_path)
        try:
            result = diarize(test_path)
            print("Segments:", result[:5], "... total", len(result))
        except Exception as e:
            print("Error:", e)
            print(traceback.format_exc())