| """SpeechBrain-based diarization backend (modular, safe). | |
| Public API: | |
| is_available() -> bool | |
| load_model(device: str) -> EncoderClassifier | |
| diarize(audio_path: str, device: str = 'cpu', config: dict | None = None) -> list[dict] | |
| Each diarization segment dict: | |
| { 'start': float, 'end': float, 'speaker': 'SPEAKER_00' } | |
| Design goals: | |
| - No side-effects (download only if needed) | |
| - Graceful fallback if dependencies missing | |
| - Clear separation: load_model, extract embeddings, cluster | |
| - Minimal external assumptions (caller handles errors) | |
| Configuration keys (with defaults): | |
| window_sec: 3.0 | |
| overlap_sec: 1.5 | |
| min_segment_sec: 0.5 | |
| min_speakers: 2 | |
| max_speakers: 8 | |
| use_silhouette: True | |
| silhouette_min_embeddings: 6 | |
| random_state: 42 | |
| use_vad: True | |
| vad_frame_sec: 0.5 | |
| vad_frame_hop_sec: 0.25 | |
| vad_energy_threshold: 0.0 (auto) | |
| vad_energy_std_factor: 0.5 | |
| vad_min_speech_sec: 0.6 | |
| vad_merge_gap_sec: 0.3 | |
| clustering_method: 'spectral' | 'agglomerative' | 'hdbscan' | |
| agglomerative_linkage: 'average' | |
| hdbscan_min_cluster_size: 2 | |
| hdbscan_min_samples: 1 | |
| post_smoothing: True | |
| min_segment_post_sec: 0.4 | |
| merge_same_speaker_gap_sec: 0.3 | |
| robust_audio_load: True | |
| force_single_split_sec: 30.0 | |
| force_single_split_window_sec: 4.0 | |
| force_single_split_overlap_sec: 1.0 | |
| force_min_speakers_if_split: 2 | |
| Requires packages: speechbrain, torch, torchaudio, numpy, scikit-learn. | |
| """ | |
from __future__ import annotations

import os
import math
import time
import traceback
import warnings
from typing import List, Dict, Any, Optional, Tuple
from pathlib import Path

DEFAULT_CONFIG = {
    "window_sec": 2.5,               # smaller windows for better speaker-change detection
    "overlap_sec": 1.0,              # less overlap for more embedding diversity
    "min_segment_sec": 0.3,          # allow shorter segments
    "min_speakers": 2,
    "max_speakers": 10,              # allow more speakers
    "use_silhouette": True,
    "silhouette_min_embeddings": 4,  # less restrictive
    "random_state": 42,
    # VAD params (adaptive RMS energy)
    "use_vad": True,
    "vad_frame_sec": 0.5,
    "vad_frame_hop_sec": 0.25,
    "vad_energy_threshold": 0.0,     # 0 => auto (mean + factor * std)
    "vad_energy_std_factor": 0.5,
    "vad_min_speech_sec": 0.6,
    "vad_merge_gap_sec": 0.3,
    # Clustering
    "clustering_method": "spectral",  # spectral | agglomerative | hdbscan (auto k)
    "agglomerative_linkage": "average",
    "hdbscan_min_cluster_size": 2,
    "hdbscan_min_samples": 1,
    # Post smoothing
    "post_smoothing": True,
    "min_segment_post_sec": 0.4,
    "merge_same_speaker_gap_sec": 0.3,
    "robust_audio_load": True,
    "force_single_split_sec": 30.0,
    "force_single_split_window_sec": 4.0,
    "force_single_split_overlap_sec": 1.0,
    "force_min_speakers_if_split": 2,
}
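# Overrides passed to diarize() are shallow-merged over these defaults, i.e.
# cfg = {**DEFAULT_CONFIG, **(config or {})}, so any key left out keeps its
# default. Example (illustrative): config={"clustering_method": "agglomerative"}
# swaps only the clustering backend and leaves every other knob untouched.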
_MODEL_CACHE = {"model": None, "source": None}

def _inject_local_ffmpeg():
    """Add the bundled ffmpeg bin folder to PATH so torchaudio/ffmpeg-based conversion works."""
    try:
        base_dir = Path(__file__).parent
        candidates = [
            base_dir / 'ffmpeg' / 'ffmpeg-8.0-essentials_build' / 'bin',
            base_dir / 'ffmpeg' / 'bin',
        ]
        for c in candidates:
            if c.is_dir():
                bin_path = str(c)
                if bin_path not in os.environ.get('PATH', ''):
                    os.environ['PATH'] = bin_path + os.pathsep + os.environ.get('PATH', '')
                    print(f"[SpeechBrain] FFmpeg added to PATH: {bin_path}")
                break
    except Exception:
        pass

_inject_local_ffmpeg()

def is_available() -> bool:
    """Return True if speechbrain is importable (core dependency)."""
    try:
        import speechbrain  # noqa: F401
        return True
    except Exception:
        return False

def load_model(device: str = "cpu", local_dir: str = "speechbrain_pretrained"):
    """Load (or reuse) the SpeechBrain encoder model, preferring LocalStrategy.

    Avoids symlink issues on Windows and logs each phase.
    """
    if _MODEL_CACHE["model"] is not None:
        return _MODEL_CACHE["model"]
    t0 = time.time()
    try:
        os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS", "1")
        os.environ.setdefault("HUGGINGFACE_HUB_DISABLE_SYMLINKS", "1")
        os.environ.setdefault("SPEECHBRAIN_LOCALSTRATEGY", "1")
    except Exception:
        pass
    from speechbrain.inference import EncoderClassifier
    # Monkeypatch Pretrainer collect_in to avoid symlink warnings (best effort)
    try:
        from speechbrain.utils import parameter_transfer as _pt
        if hasattr(_pt, "Pretrainer"):
            _OrigPretrainer = _pt.Pretrainer

            class _PatchedPretrainer(_OrigPretrainer):  # type: ignore
                def __init__(self, *args, **kwargs):
                    if "collect_in" in kwargs:
                        kwargs["collect_in"] = None
                    super().__init__(*args, **kwargs)

            _pt.Pretrainer = _PatchedPretrainer  # type: ignore
    except Exception:
        pass
    # Warning filters
    warnings.filterwarnings(
        "ignore",
        message=r"Requested Pretrainer collection using symlinks on Windows",
        category=UserWarning,
        module=r"speechbrain\.utils\.parameter_transfer",
    )
    warnings.filterwarnings(
        "ignore",
        message=r"`torch\.cuda\.amp\.custom_fwd\(args\.\.\.\)` is deprecated",
        category=FutureWarning,
        module=r"speechbrain\.utils\.autocast",
    )
    # Try LocalStrategy first, then fall back to the plain loader
    model = None
    try:
        from speechbrain.utils.fetching import LocalStrategy
        os.makedirs(local_dir, exist_ok=True)
        fetch_strategy = LocalStrategy()
        model = EncoderClassifier.from_hparams(
            source="speechbrain/spkrec-ecapa-voxceleb",
            savedir=local_dir,
            fetch_strategy=fetch_strategy,
            run_opts={"device": device},
        )
        print("[SpeechBrain] Model loaded with LocalStrategy (%.2fs)" % (time.time() - t0))
    except Exception as e:
        # Silence the specific TypeError from older versions where LocalStrategy is an internal Enum
        if isinstance(e, TypeError) and 'EnumType.__call__' in str(e):
            print("[SpeechBrain] LocalStrategy not supported by this version -> fallback (%.2fs)" % (time.time() - t0))
        else:
            print("[SpeechBrain] LocalStrategy unavailable (%s: %s); falling back (%.2fs)"
                  % (type(e).__name__, e, time.time() - t0))
        model = EncoderClassifier.from_hparams(
            source="speechbrain/spkrec-ecapa-voxceleb",
            savedir=local_dir,
            run_opts={"device": device},
        )
        print("[SpeechBrain] Model loaded via fallback (%.2fs)" % (time.time() - t0))
    _MODEL_CACHE["model"] = model
    print("[SpeechBrain] load_model complete (%.2fs)" % (time.time() - t0))
    return model

def _segment_audio(waveform, sample_rate: int, cfg: dict) -> List[Tuple[int, int]]:
    win = int(cfg["window_sec"] * sample_rate)
    hop = int((cfg["window_sec"] - cfg["overlap_sec"]) * sample_rate)
    if hop <= 0:
        hop = win  # no-overlap safety
    total = waveform.shape[1]
    segments = []
    for start in range(0, total, hop):
        end = min(start + win, total)
        dur = (end - start) / sample_rate
        if dur >= cfg["min_segment_sec"]:
            segments.append((start, end))
        if end == total:
            break
    return segments
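
# Worked example of the window arithmetic above (using the shipped defaults):
# window_sec=2.5 and overlap_sec=1.0 give hop = 2.5 - 1.0 = 1.5 s, i.e.
# win=40000 and hop=24000 samples at 16 kHz. A 10 s mono file (160000 samples)
# yields windows starting at 0.0, 1.5, 3.0, ..., 9.0 s; the final window
# (9.0-10.0 s, 1.0 s long) is kept because it still exceeds min_segment_sec=0.3.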

def _apply_vad(waveform, sample_rate: int, cfg: dict):
    """Very simple VAD based on per-frame RMS energy.

    Returns a list of (start_sample, end_sample) regions detected as speech.
    Nearby regions (< vad_merge_gap_sec apart) are merged and short ones
    (< vad_min_speech_sec) are discarded.
    """
    import torch
    import numpy as np
    frame = int(cfg["vad_frame_sec"] * sample_rate)
    hop = int(cfg["vad_frame_hop_sec"] * sample_rate)
    if frame <= 0:
        frame = int(0.5 * sample_rate)
    if hop <= 0:
        hop = frame
    wav = waveform.squeeze(0)
    total = wav.shape[0]
    energies = []
    bounds = []
    for start in range(0, total, hop):
        end = min(start + frame, total)
        seg = wav[start:end]
        if seg.numel() == 0:
            continue
        # RMS energy
        e = torch.sqrt(torch.mean(seg ** 2) + 1e-12).item()
        energies.append(e)
        bounds.append((start, end))
        if end == total:
            break
    if not energies:
        return []
    arr = np.array(energies)
    thr = cfg["vad_energy_threshold"]
    if thr <= 0.0:
        thr = float(arr.mean() + cfg["vad_energy_std_factor"] * arr.std())
    speech_flags = arr >= thr
    # Convert contiguous True frames to sample ranges
    speech_ranges = []
    cur_start = None
    for i, flag in enumerate(speech_flags):
        if flag and cur_start is None:
            cur_start = bounds[i][0]
        if not flag and cur_start is not None:
            speech_ranges.append((cur_start, bounds[i][1]))
            cur_start = None
    if cur_start is not None:
        speech_ranges.append((cur_start, bounds[-1][1]))
    # Merge close ranges
    merged = []
    gap = int(cfg["vad_merge_gap_sec"] * sample_rate)
    for (s, e) in speech_ranges:
        if not merged:
            merged.append([s, e])
        elif s - merged[-1][1] <= gap:
            merged[-1][1] = e
        else:
            merged.append([s, e])
    # Drop short regions
    min_len = int(cfg["vad_min_speech_sec"] * sample_rate)
    return [(s, e) for (s, e) in merged if (e - s) >= min_len]
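
# Numeric sketch of the adaptive threshold (values made up for illustration):
# with vad_energy_threshold=0.0 the threshold is mean + 0.5 * std of the frame
# energies, so frame RMS values [0.01, 0.02, 0.30, 0.28, 0.01] (mean ≈ 0.124,
# std ≈ 0.136) give thr ≈ 0.19 and only the two loud frames count as speech.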

def _compute_embeddings(model, waveform, sample_rate: int, segments: List[Tuple[int, int]]):
    import torch
    import numpy as np
    embs = []
    times = []
    with torch.no_grad():
        for (s, e) in segments:
            seg = waveform[:, s:e]
            # Expected shape [batch, time]; input is assumed mono already.
            batch = seg.squeeze(0).unsqueeze(0)
            emb = model.encode_batch(batch)
            vec = emb.squeeze().cpu().numpy()
            # L2-normalize for cosine stability
            norm = np.linalg.norm(vec) + 1e-9
            vec = vec / norm
            embs.append(vec)
            times.append((s / sample_rate, e / sample_rate))
    return embs, times
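
# Note on shapes (for the spkrec-ecapa-voxceleb model loaded above, whose
# ECAPA-TDNN embeddings are 192-dimensional): encode_batch on a [1, time] batch
# returns a [1, 1, 192] tensor, so the squeeze() collapses it to a flat
# 192-vector before normalization. After the division np.linalg.norm(vec) ≈ 1,
# which is what lets the clustering below treat dot products as cosine scores.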

def _estimate_num_speakers(embeddings, cfg: dict, sklearn_available: bool) -> int:
    import numpy as np
    if sklearn_available:
        try:
            from sklearn.metrics import silhouette_score
            from sklearn.cluster import SpectralClustering
        except Exception:
            sklearn_available = False
    n = len(embeddings)
    if n < 2:
        return 1
    arr = np.vstack(embeddings)
    k_min = cfg["min_speakers"]
    k_max = min(cfg["max_speakers"], n)
    if k_max < k_min:
        return k_min
    if (sklearn_available and cfg.get("use_silhouette", True)
            and n >= cfg.get("silhouette_min_embeddings", 4) and (k_max - k_min) >= 1):
        best_k = k_min
        best_score = -1.0
        for k in range(k_min, k_max + 1):
            try:
                # Use the same affinity as the real clustering step
                clustering = SpectralClustering(n_clusters=k, affinity="cosine", random_state=cfg["random_state"])
                labels = clustering.fit_predict(arr)
                if len(set(labels)) < 2:
                    continue
                score = silhouette_score(arr, labels, metric="cosine")
                # Bias toward more clusters when scores are close
                score_bonus = k * 0.02  # small bonus per extra cluster
                adjusted_score = score + score_bonus
                if adjusted_score > best_score:
                    best_score = adjusted_score
                    best_k = k
            except Exception:
                continue
        return best_k
    # Fallback heuristic, biased toward more speakers
    estimated = max(k_min, min(k_max, int(round(math.sqrt(n * 1.5)))))
    # Assume at least 2 speakers for conversations
    return max(2, estimated)
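
# Worked example of the fallback heuristic: with n=12 embeddings,
# sqrt(12 * 1.5) = sqrt(18) ≈ 4.24, which rounds to k=4; the result is clamped
# into [min_speakers, max_speakers] and floored at 2, so this path never
# collapses a long recording to a single estimated speaker.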

def _cluster_embeddings(embeddings, times, cfg: dict):
    import numpy as np
    sklearn_available = True
    try:
        from sklearn.cluster import SpectralClustering, AgglomerativeClustering
    except Exception:
        sklearn_available = False
    if len(embeddings) == 0:
        return []
    if len(embeddings) == 1:
        return [(times[0][0], times[0][1], 0)]
    arr = np.vstack(embeddings)
    method = cfg.get("clustering_method", "spectral")
    k = _estimate_num_speakers(embeddings, cfg, sklearn_available) if method != "hdbscan" else None
    debug = cfg.get("debug_log_path")
    if debug:
        try:
            with open(debug, 'a', encoding='utf-8') as f:
                f.write(f"cluster: n_emb={len(embeddings)} method={method} k_est={k} sklearn={sklearn_available}\n")
        except Exception:
            pass
    labels = None
    if method == "spectral" and sklearn_available:
        try:
            clustering = SpectralClustering(n_clusters=k, affinity="cosine", random_state=cfg["random_state"])
            labels = clustering.fit_predict(arr)
        except Exception:
            labels = None
    elif method == "agglomerative" and sklearn_available:
        try:
            linkage = cfg.get("agglomerative_linkage", "average")
            # Newer sklearn replaces `affinity` with `metric`; retry on TypeError
            try:
                clustering = AgglomerativeClustering(n_clusters=k, affinity="cosine", linkage=linkage)
                labels = clustering.fit_predict(arr)
            except TypeError:
                clustering = AgglomerativeClustering(n_clusters=k, metric="cosine", linkage=linkage)
                labels = clustering.fit_predict(arr)
        except Exception:
            labels = None
    elif method == "hdbscan":
        try:
            import hdbscan  # type: ignore
            clusterer = hdbscan.HDBSCAN(min_cluster_size=cfg.get("hdbscan_min_cluster_size", 2),
                                        min_samples=cfg.get("hdbscan_min_samples", 1),
                                        metric='euclidean')
            labels = clusterer.fit_predict(arr)
            # HDBSCAN may label some points as -1 (noise). Reassign noise points to the nearest non-noise centroid.
            if (labels == -1).any():
                valid = arr[labels != -1]
                if valid.shape[0] >= 1:
                    cent = []
                    for lab in sorted(set(labels) - {-1}):
                        cent.append(valid[labels[labels != -1] == lab].mean(axis=0))
                    cent = np.stack(cent) if cent else valid
                    for idx, lab in enumerate(labels):
                        if lab == -1:
                            # Assign by cosine similarity
                            sims = np.dot(cent, arr[idx]) / (np.linalg.norm(cent, axis=1) * np.linalg.norm(arr[idx]) + 1e-9)
                            labels[idx] = int(np.argmax(sims))
            # Remap labels to 0..K-1
            uniq = {lab: i for i, lab in enumerate(sorted(set(labels)))}
            labels = np.array([uniq[l] for l in labels])
        except Exception:
            labels = None
    # Manual cosine k-means fallback if labels are still None
    if labels is None:
        if k is None:
            k = _estimate_num_speakers(embeddings, cfg, sklearn_available)
        rng = np.random.default_rng(cfg.get("random_state", 42))
        centroids = arr[rng.choice(len(arr), size=k, replace=False)] if k <= len(arr) else arr
        for _ in range(10):
            sims = np.matmul(arr, centroids.T)
            arr_norm = np.linalg.norm(arr, axis=1, keepdims=True) + 1e-9
            cent_norm = np.linalg.norm(centroids, axis=1, keepdims=True).T + 1e-9
            cos = sims / (arr_norm * cent_norm)
            labels = np.argmax(cos, axis=1)
            new_centroids = []
            for j in range(len(centroids)):
                mask = labels == j
                if np.any(mask):
                    new_centroids.append(arr[mask].mean(axis=0))
                else:
                    new_centroids.append(centroids[j])
            new_centroids = np.vstack(new_centroids)
            if np.allclose(new_centroids, centroids):
                break
            centroids = new_centroids
    # Debug: label distribution
    if debug:
        try:
            import collections
            dist = collections.Counter(labels.tolist())
            with open(debug, 'a', encoding='utf-8') as f:
                f.write(f"labels_dist={dict(dist)}\n")
        except Exception:
            pass
    return [(times[i][0], times[i][1], int(labels[i])) for i in range(len(labels))]
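
# Design note: because _compute_embeddings L2-normalizes every vector, cosine
# similarity reduces to a plain dot product, so the manual fallback above is
# effectively spherical k-means; the explicit norm division only guards against
# centroids drifting off the unit sphere after averaging.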

def _post_smooth(clustered: List[Tuple[float, float, int]], cfg: dict, debug: str | None):
    if not clustered:
        return clustered
    if not cfg.get("post_smoothing", True):
        return clustered
    merged = []
    # First merge consecutive segments from the same speaker when the gap is small
    gap = cfg.get("merge_same_speaker_gap_sec", 0.3)
    for seg in clustered:
        if not merged:
            merged.append(list(seg))
        else:
            last = merged[-1]
            if seg[2] == last[2] and seg[0] - last[1] <= gap:
                last[1] = seg[1]
            else:
                merged.append(list(seg))
    # Absorb segments that are too short
    min_len = cfg.get("min_segment_post_sec", 0.4)
    if len(merged) > 1:
        i = 0
        while i < len(merged):
            s, e, spk = merged[i]
            dur = e - s
            if dur < min_len:
                # Decide which neighbour absorbs it (smallest temporal gap; same speaker preferred)
                left_gap = (s - merged[i - 1][1]) if i > 0 else float('inf')
                right_gap = (merged[i + 1][0] - e) if i < len(merged) - 1 else float('inf')
                target = None
                if i > 0 and (i == len(merged) - 1 or left_gap <= right_gap):
                    target = i - 1
                elif i < len(merged) - 1:
                    target = i + 1
                if target is not None:
                    if target < i:  # merge into the left neighbour, keeping its speaker
                        merged[target][1] = e
                        merged.pop(i)
                        i -= 1
                    else:  # merge into the right neighbour
                        merged[target][0] = s
                        merged.pop(i)
                        i -= 1
            i += 1
    out = [(a, b, int(c)) for (a, b, c) in merged]
    if debug:
        try:
            with open(debug, 'a', encoding='utf-8') as f:
                f.write(f"post_smooth: in={len(clustered)} out={len(out)}\n")
        except Exception:
            pass
    return out
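
# Tiny worked trace (illustrative numbers, default thresholds): the turns
# [(0.0, 2.0, 0), (2.1, 2.3, 1), (2.4, 5.0, 0)] survive the same-speaker merge
# (neighbours differ), then the 0.2 s middle segment (< 0.4 s) is absorbed by
# its left neighbour (left gap 0.1 s <= right gap 0.1 s), yielding
# [(0.0, 2.3, 0), (2.4, 5.0, 0)]; the two speaker-0 turns stay separate because
# the same-speaker merge pass runs before, not after, the absorption step.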

def diarize(audio_path: str, device: str = "cpu", config: Optional[dict] = None) -> List[Dict[str, Any]]:
    """High-level diarization returning a WhisperX-compatible segment list.

    Raises exceptions on fatal errors (the caller may wrap them).
    """
    cfg = {**DEFAULT_CONFIG, **(config or {})}
    import torchaudio
    import torch
    model = load_model(device=device)

    def _safe_load(path):
        try:
            return torchaudio.load(path)
        except Exception as e_primary:
            # Try converting to wav when the input is m4a and robust_audio_load is on
            if cfg.get("robust_audio_load", True) and path.lower().endswith('.m4a'):
                import subprocess
                import tempfile
                tmp_wav = os.path.join(tempfile.gettempdir(), f"_conv_{int(time.time() * 1000)}.wav")
                cmd = ['ffmpeg', '-y', '-i', path, '-ac', '1', '-ar', '16000', tmp_wav]
                try:
                    # Dynamic timeout based on file size (roughly 10 s per MB, at least 60 s)
                    conversion_timeout = max(60, int(os.path.getsize(path) / (1024 * 1024) * 10))
                    subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, timeout=conversion_timeout)
                    return torchaudio.load(tmp_wav)
                except Exception:
                    raise e_primary
            raise e_primary

    waveform, sr = _safe_load(audio_path)
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    if sr != 16000:
        resample = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
        waveform = resample(waveform)
        sr = 16000
    debug_path = os.path.join(os.path.dirname(__file__), 'speechbrain_debug.log')
    cfg["debug_log_path"] = debug_path
    t_start = time.time()
    # Optional VAD to get purer speech regions before windowing
    vad_regions = []
    if cfg.get("use_vad", True):
        try:
            vad_regions = _apply_vad(waveform, sr, cfg)
        except Exception:
            vad_regions = []
    # If VAD found regions, subdivide each one with the windowing scheme (for robustness and overlap)
    if vad_regions:
        segments = []
        base_win = int(cfg["window_sec"] * sr)
        base_hop = int((cfg["window_sec"] - cfg["overlap_sec"]) * sr)
        if base_hop <= 0:
            base_hop = base_win
        for (rs, re) in vad_regions:
            cur = rs
            while cur < re:
                end = min(cur + base_win, re)
                dur = (end - cur) / sr
                if dur >= cfg["min_segment_sec"]:
                    segments.append((cur, end))
                if end == re:
                    break
                cur += base_hop
    else:
        segments = _segment_audio(waveform, sr, cfg)
    try:
        with open(debug_path, 'a', encoding='utf-8') as f:
            total_vad_dur = 0.0
            if vad_regions:
                total_vad_dur = sum((re - rs) / sr for rs, re in vad_regions)
            f.write(
                f"diarize: sr={sr} samples={waveform.shape[1]} segs={len(segments)} "
                f"vad_regions={len(vad_regions)} vad_dur={total_vad_dur:.1f}s audio_total={waveform.shape[1] / sr:.1f}s "
                f"use_vad={cfg.get('use_vad', True)} window={cfg['window_sec']} overlap={cfg['overlap_sec']}\n"
            )
    except Exception:
        pass
    if len(segments) == 0:
        return []
    embeddings, times = _compute_embeddings(model, waveform, sr, segments)
    # Heuristic: a single embedding on a fairly long file suggests the windowing collapsed;
    # force a finer re-segmentation to avoid a false single-speaker result.
    audio_total_sec = waveform.shape[1] / sr
    if len(embeddings) == 1 and audio_total_sec >= cfg.get("force_single_split_sec", 30.0):
        try:
            win_sec = cfg.get("force_single_split_window_sec", 4.0)
            ov_sec = cfg.get("force_single_split_overlap_sec", 1.0)
            tmp_cfg = dict(cfg)
            tmp_cfg["window_sec"] = win_sec
            tmp_cfg["overlap_sec"] = ov_sec
            segments_refine = _segment_audio(waveform, sr, tmp_cfg)
            if len(segments_refine) > 1:
                embeddings, times = _compute_embeddings(model, waveform, sr, segments_refine)
                segments = segments_refine
                cfg["min_speakers"] = max(cfg.get("min_speakers", 2), cfg.get("force_min_speakers_if_split", 2))
                with open(debug_path, 'a', encoding='utf-8') as f:
                    f.write(f"force_single_split applied win={win_sec} overlap={ov_sec} new_embs={len(embeddings)}\n")
        except Exception:
            pass
    try:
        with open(debug_path, 'a', encoding='utf-8') as f:
            f.write(f"embeddings: count={len(embeddings)} first_shape={getattr(embeddings[0], 'shape', None)}\n")
    except Exception:
        pass
    clustered = _cluster_embeddings(embeddings, times, cfg)
    clustered = _post_smooth(clustered, cfg, debug_path)
    try:
        with open(debug_path, 'a', encoding='utf-8') as f:
            uniq = sorted({c[2] for c in clustered})
            f.write(f"clustered: segments={len(clustered)} speakers_est={len(uniq)} speakers={uniq} total_time={time.time() - t_start:.2f}s\n")
    except Exception:
        pass
    # Convert to the expected list of dicts
    out = []
    for (start, end, spk) in clustered:
        out.append({"start": float(start), "end": float(end), "speaker": f"SPEAKER_{spk:02d}"})
    return out
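
# Illustrative downstream helper (an assumption, not part of this module's API):
# given WhisperX-style transcript segments, pick for each one the diarized
# speaker whose turn overlaps it the most.
def _example_assign_speakers(transcript_segments: List[Dict[str, Any]],
                             diar_segments: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    out = []
    for ts in transcript_segments:
        best_spk, best_ov = None, 0.0
        for ds in diar_segments:
            # Overlap between [ts.start, ts.end] and [ds.start, ds.end]
            ov = min(ts["end"], ds["end"]) - max(ts["start"], ds["start"])
            if ov > best_ov:
                best_spk, best_ov = ds["speaker"], ov
        out.append({**ts, "speaker": best_spk})
    return out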
| if __name__ == "__main__": # simple manual test guard | |
| ap = os.environ.get("TEST_AUDIO") | |
| if ap and os.path.isfile(ap): | |
| print("Running quick diarization test on", ap) | |
| try: | |
| segs = diarize(ap) | |
| print("Segments:", segs[:5], "... total", len(segs)) | |
| except Exception as e: | |
| print("Error:", e) | |
| print(traceback.format_exc()) | |
| else: | |
| print("Set TEST_AUDIO env var to an audio file path for test.") | |