Spaces:
Sleeping
Sleeping
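"""
Speaker diarization service.

Supported backends:
- pyannote (pyannote.audio pipeline)
- sortformer (NVIDIA NeMo Sortformer model)

Optimized for production / QA use on call-center audio: the speaker with the
most talk time is labeled "NV" and every other speaker "KH" (presumably
agent / customer).
"""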
import logging
import threading
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional

import librosa
import numpy as np
import torch

from app.core.config import get_settings
# Module-level objects used throughout this file.
settings = get_settings()
logger = logging.getLogger(__name__)

# Allow TensorFloat-32 for CUDA matmuls and cuDNN. This mutates global torch
# state at import time, so it applies to every torch op in the process, not
# only to diarization.
torch.backends.cuda.matmul.allow_tf32 = torch.backends.cudnn.allow_tf32 = True
# =========================================================
# DATA MODELS
# =========================================================
@dataclass
class SpeakerSegment:
    """One contiguous speech turn attributed to a single speaker.

    ``start``/``end`` are the times reported by the backend (assumed to be
    seconds — confirm against the backend docs).
    """

    start: float
    end: float
    speaker: str
    # Neither diarizer in this file sets it, so it is always 1.0 today.
    confidence: float = 1.0

    @property
    def duration(self) -> float:
        """Length of the segment, ``end - start``."""
        return self.end - self.start
@dataclass
class DiarizationResult:
    """Output of one diarization run.

    Attributes:
        segments: speech turns; the diarizers in this file sort them by start.
        speaker_count: number of distinct speaker labels in ``segments``.
        speakers: the distinct speaker labels.
        roles: speaker label -> role code. ``BaseDiarizer.infer_roles`` gives
            "NV" to the speaker with the most talk time and "KH" to all others
            (presumably agent / customer).
    """

    segments: List["SpeakerSegment"]
    speaker_count: int
    speakers: List[str]
    roles: Dict[str, str]
# =========================================================
# BASE DIARIZER
# =========================================================
class BaseDiarizer(ABC):
    """Common interface for diarization backends."""

    @abstractmethod
    def diarize(
        self,
        audio_path: Path,
        num_speakers: Optional[int] = None,
        min_speakers: int = 1,
        max_speakers: int = 10,
    ) -> "DiarizationResult":
        """Diarize the audio file at ``audio_path``.

        Args:
            audio_path: path to the audio file.
            num_speakers: exact speaker count if known (a backend may ignore it).
            min_speakers: lower bound used when ``num_speakers`` is None.
            max_speakers: upper bound used when ``num_speakers`` is None.

        Returns:
            A ``DiarizationResult`` for the whole file.
        """
        raise NotImplementedError

    # -----------------------------------------------------
    # ROLE INFERENCE
    # -----------------------------------------------------
    @staticmethod
    def infer_roles(
        segments: List["SpeakerSegment"],
    ) -> Dict[str, str]:
        """Assign a call-center role code to every speaker.

        The speaker with the largest total talk time gets "NV"; every other
        speaker gets "KH". On a tie the speaker seen first in ``segments``
        wins (``max`` returns the first maximal key in insertion order).

        Args:
            segments: objects with ``start``, ``end`` and ``speaker`` attributes.

        Returns:
            Mapping speaker label -> role code; empty if ``segments`` is empty.
        """
        talk_time: Dict[str, float] = {}
        for seg in segments:
            # Computed inline (same value as SpeakerSegment.duration) so this
            # works whether duration is exposed as a property or a method.
            talk_time[seg.speaker] = talk_time.get(seg.speaker, 0.0) + (
                seg.end - seg.start
            )
        if not talk_time:
            return {}
        agent = max(talk_time, key=talk_time.get)
        return {spk: ("NV" if spk == agent else "KH") for spk in talk_time}
# =========================================================
# PYANNOTE
# =========================================================
class PyannoteDiarizer(BaseDiarizer):
    """Diarizer backed by a pretrained pyannote.audio pipeline."""

    def __init__(
        self,
        clustering_threshold: float = 0.65,
        min_duration_off: float = 0.4,
    ):
        """Load ``settings.pyannote_model`` and set its hyper-parameters.

        Args:
            clustering_threshold: value for the pipeline's
                ``clustering.threshold`` hyper-parameter (default 0.65).
            min_duration_off: value for ``segmentation.min_duration_off``
                (default 0.4).

        Raises:
            RuntimeError: if ``Pipeline.from_pretrained`` returns None.
        """
        # Imported here so pyannote is only required when this backend is built.
        from pyannote.audio import Pipeline

        logger.info(f"Loading pyannote model: {settings.pyannote_model}")
        pipeline = Pipeline.from_pretrained(
            settings.pyannote_model,
            token=settings.hf_token,
        )
        # NOTE(review): some pyannote versions return None instead of raising
        # when the model cannot be fetched (e.g. gated repo, bad token); fail
        # with a clear message rather than an AttributeError below.
        if pipeline is None:
            raise RuntimeError(
                f"Could not load pyannote pipeline {settings.pyannote_model!r}; "
                f"check hf_token and model access"
            )
        pipeline.instantiate({
            "clustering": {"threshold": clustering_threshold},
            "segmentation": {"min_duration_off": min_duration_off},
        })

        device = torch.device(settings.resolved_device)
        if device.type == "cuda":
            pipeline = pipeline.to(device)
        self.pipeline = pipeline
        logger.info("Pyannote READY")

    def diarize(
        self,
        audio_path: Path,
        num_speakers: Optional[int] = None,
        min_speakers: int = 1,
        max_speakers: int = 10,
    ) -> DiarizationResult:
        """Run the pipeline on ``audio_path``.

        ``num_speakers`` is forwarded when given; otherwise ``min_speakers``
        and ``max_speakers`` are forwarded as bounds. Raw pyannote labels are
        renamed "Speaker 1", "Speaker 2", ... in the order ``itertracks``
        yields them.

        Returns:
            ``DiarizationResult`` with segments sorted by start time and
            speakers listed in order of first appearance.
        """
        if num_speakers is not None:
            params = {"num_speakers": num_speakers}
        else:
            params = {"min_speakers": min_speakers, "max_speakers": max_speakers}

        output = self.pipeline(str(audio_path), **params)
        # Some pyannote versions return an object exposing
        # ``speaker_diarization``; others return the annotation directly.
        annotation = getattr(output, "speaker_diarization", output)

        speaker_map = {}
        segments: List[SpeakerSegment] = []
        for turn, _, raw_label in annotation.itertracks(yield_label=True):
            if raw_label not in speaker_map:
                speaker_map[raw_label] = f"Speaker {len(speaker_map) + 1}"
            segments.append(
                SpeakerSegment(
                    start=float(turn.start),
                    end=float(turn.end),
                    speaker=speaker_map[raw_label],
                )
            )
        segments.sort(key=lambda seg: seg.start)

        # Ordered de-duplication: a set would make this order vary between
        # processes because str hashing is randomized per process.
        speakers = list(dict.fromkeys(seg.speaker for seg in segments))
        return DiarizationResult(
            segments=segments,
            speaker_count=len(speakers),
            speakers=speakers,
            roles=self.infer_roles(segments),
        )
# =========================================================
# SORTFORMER
# =========================================================
class SortformerDiarizer(BaseDiarizer):
    """Diarizer backed by an NVIDIA NeMo Sortformer model."""

    def __init__(self):
        """Load ``settings.sortformer_model`` onto ``settings.resolved_device``."""
        # Imported here so NeMo is only required when this backend is built.
        import nemo.collections.asr as nemo_asr

        logger.info(f"Loading sortformer model: {settings.sortformer_model}")
        self.model = (
            nemo_asr.models.SortformerEncLabelModel
            .from_pretrained(model_name=settings.sortformer_model)
            .to(settings.resolved_device)
        )
        # Inference only: switch torch modules (dropout etc.) to eval mode.
        self.model.eval()
        logger.info("Sortformer READY")

    def diarize(
        self,
        audio_path: Path,
        num_speakers: Optional[int] = None,
        min_speakers: int = 1,
        max_speakers: int = 10,
    ) -> DiarizationResult:
        """Run Sortformer on ``audio_path``.

        ``num_speakers``, ``min_speakers`` and ``max_speakers`` are accepted
        for interface compatibility with ``BaseDiarizer`` but are NOT passed
        to the model.

        Returns:
            ``DiarizationResult`` with segments sorted by start time and
            speakers listed in order of first appearance.
        """
        pred = self.model.diarize(audio=str(audio_path), batch_size=1)
        segments = self.normalize(pred)

        # Ordered de-duplication: a set would make this order vary between
        # processes because str hashing is randomized per process.
        speakers = list(dict.fromkeys(seg.speaker for seg in segments))
        return DiarizationResult(
            segments=segments,
            speaker_count=len(speakers),
            speakers=speakers,
            roles=self.infer_roles(segments),
        )

    # -----------------------------------------------------
    # NORMALIZE OUTPUT
    # -----------------------------------------------------
    def normalize(
        self,
        pred,
    ) -> List[SpeakerSegment]:
        """Convert raw Sortformer output into sorted ``SpeakerSegment`` objects.

        ``pred`` is expected to be an iterable of ``"<start> <end> <speaker>"``
        strings, optionally wrapped in a one-element outer list (one entry per
        input file with ``batch_size=1``). Entries that are not strings, have
        fewer than three fields, or have non-numeric times are skipped.

        Raw speaker ids are renamed "Speaker 1", "Speaker 2", ... in order of
        first appearance by start time.
        """
        # Unwrap the per-file batch list only when its single element is
        # itself a collection: unwrapping a flat one-segment list would yield
        # a str, whose iteration produces characters and drops the segment.
        if (
            isinstance(pred, (list, tuple))
            and len(pred) == 1
            and isinstance(pred[0], (list, tuple))
        ):
            pred = pred[0]

        records = []
        for entry in pred:
            if not isinstance(entry, str):
                continue
            parts = entry.split()
            if len(parts) < 3:
                continue
            try:
                start, end = float(parts[0]), float(parts[1])
            except ValueError:
                logger.warning("Skipping malformed sortformer segment: %r", entry)
                continue
            records.append((start, end, parts[2]))

        # Stable sort by start time before labelling so "Speaker 1" is the
        # first voice heard, independent of the model's output order.
        records.sort(key=lambda rec: rec[0])

        speaker_map: Dict[str, str] = {}
        segments: List[SpeakerSegment] = []
        for start, end, raw_speaker in records:
            if raw_speaker not in speaker_map:
                speaker_map[raw_speaker] = f"Speaker {len(speaker_map) + 1}"
            segments.append(
                SpeakerSegment(
                    start=start,
                    end=end,
                    speaker=speaker_map[raw_speaker],
                )
            )
        return segments
# =========================================================
# MAIN SERVICE
# =========================================================
class DiarizationService:
    """Process-wide entry point that lazily loads and caches one diarizer.

    The backend is selected by ``settings.diarization_backend``
    ("pyannote" or "sortformer", case- and whitespace-insensitive).
    All public methods are classmethods, so they can be called on the class
    or on the (singleton) instance.
    """

    _instance = None
    _diarizer = None
    # Guards lazy model loading: diarize_async runs diarize() on executor
    # threads, so concurrent first calls could otherwise each load a model.
    _lock = threading.Lock()

    def __new__(cls):
        """Return the single shared instance, creating it on first call."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    # -----------------------------------------------------
    # LOAD MODEL
    # -----------------------------------------------------
    @classmethod
    def get_diarizer(cls) -> "BaseDiarizer":
        """Return the cached diarizer, constructing it on first use.

        A failed construction leaves the cache empty, so a later call retries.

        Raises:
            ValueError: if ``settings.diarization_backend`` is unsupported.
        """
        # Fast path once loaded; no lock needed to read the cached object.
        if cls._diarizer is not None:
            return cls._diarizer
        with cls._lock:
            # Re-check: another thread may have finished loading while we waited.
            if cls._diarizer is not None:
                return cls._diarizer

            model_type = (settings.diarization_backend or "").lower().strip()
            logger.info(f"Initializing diarization backend: {model_type}")

            if model_type == "pyannote":
                cls._diarizer = PyannoteDiarizer()
            elif model_type == "sortformer":
                cls._diarizer = SortformerDiarizer()
            else:
                raise ValueError(
                    f"Unsupported diarization backend: {model_type}"
                )
            return cls._diarizer

    # -----------------------------------------------------
    # MAIN API
    # -----------------------------------------------------
    @classmethod
    def diarize(
        cls,
        audio_path: Path,
        num_speakers: Optional[int] = None,
        min_speakers: int = 1,
        max_speakers: int = 10,
    ) -> "DiarizationResult":
        """Diarize ``audio_path`` with the configured backend (blocking).

        Loads the backend on first use; arguments are forwarded unchanged.
        """
        # NOTE(review): from diarize_async, this may run concurrently on
        # several executor threads against one shared model; confirm the
        # backend tolerates that or serialize inference.
        return cls.get_diarizer().diarize(
            audio_path=audio_path,
            num_speakers=num_speakers,
            min_speakers=min_speakers,
            max_speakers=max_speakers,
        )

    # -----------------------------------------------------
    # ASYNC
    # -----------------------------------------------------
    @classmethod
    async def diarize_async(
        cls,
        audio_path: Path,
        num_speakers: Optional[int] = None,
        min_speakers: int = 1,
        max_speakers: int = 10,
    ) -> "DiarizationResult":
        """Run :meth:`diarize` in the running loop's default executor.

        Keeps the blocking model call off the event-loop thread.
        """
        import asyncio

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None,
            lambda: cls.diarize(
                audio_path=audio_path,
                num_speakers=num_speakers,
                min_speakers=min_speakers,
                max_speakers=max_speakers,
            ),
        )

    # -----------------------------------------------------
    # PRELOAD
    # -----------------------------------------------------
    @classmethod
    def preload_pipeline(cls) -> None:
        """Eagerly load the diarizer; failures are logged, not raised.

        Best-effort by design: after a failure the cache stays empty and the
        next ``diarize`` call retries loading.
        """
        try:
            cls.get_diarizer()
        except Exception as e:
            logger.warning(
                f"Failed to preload diarization "
                f"pipeline: {e}",
                exc_info=True,
            )