"""
Speaker diarization service.

Supported backends:
- pyannote
- sortformer

Production / QA optimized for call-center audio.
"""

import asyncio
import logging
import threading
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional

import librosa
import numpy as np
import torch

from app.core.config import get_settings

logger = logging.getLogger(__name__)
settings = get_settings()

# Enable TensorFloat-32 for CUDA matmuls and cuDNN: trades a little float
# precision for speed on GPUs that support it.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True


# =========================================================
# DATA MODELS
# =========================================================

@dataclass
class SpeakerSegment:
    """One contiguous turn attributed to a single speaker."""

    start: float  # start time reported by the backend (presumably seconds)
    end: float  # end time, same unit as `start`
    speaker: str  # normalized label such as "Speaker 1"
    confidence: float = 1.0  # not set by either backend in this module

    @property
    def duration(self) -> float:
        """Return the segment length (`end - start`)."""
        return self.end - self.start


@dataclass
class DiarizationResult:
    """Backend-independent diarization output."""

    segments: List[SpeakerSegment]  # sorted by start time
    speaker_count: int  # number of distinct labels in `speakers`
    speakers: List[str]  # distinct labels, in order of first appearance
    roles: Dict[str, str]  # speaker label -> "NV" or "KH" (see infer_roles)


# =========================================================
# BASE DIARIZER
# =========================================================

class BaseDiarizer(ABC):
    """Interface every diarization backend implements, plus shared helpers."""

    @abstractmethod
    def diarize(
        self,
        audio_path: Path,
        num_speakers: Optional[int] = None,
        min_speakers: int = 1,
        max_speakers: int = 10,
    ) -> DiarizationResult:
        """Diarize the audio file at `audio_path`.

        Args:
            audio_path: Audio file to process.
            num_speakers: Exact speaker count, if known.
            min_speakers: Lower bound hint when `num_speakers` is None.
            max_speakers: Upper bound hint when `num_speakers` is None.

        Returns:
            The normalized DiarizationResult.
        """

    # -----------------------------------------------------
    # ROLE INFERENCE
    # -----------------------------------------------------

    @staticmethod
    def infer_roles(segments: List[SpeakerSegment]) -> Dict[str, str]:
        """Assign a call-center role to every speaker.

        The speaker with the largest total talk time is labelled "NV" (the
        agent); every other speaker is labelled "KH" (presumably the
        customer -- Vietnamese abbreviations). On a tie, `max` keeps the
        speaker that appears first in `segments`. A lone speaker is "NV".

        Returns:
            Mapping of speaker label to role; empty when `segments` is empty.
        """
        duration_map: Dict[str, float] = {}
        for seg in segments:
            duration_map[seg.speaker] = (
                duration_map.get(seg.speaker, 0.0) + seg.duration
            )

        if not duration_map:
            return {}

        agent = max(duration_map, key=duration_map.get)
        return {
            spk: ("NV" if spk == agent else "KH")
            for spk in duration_map
        }

    # -----------------------------------------------------
    # SHARED HELPERS
    # -----------------------------------------------------

    @staticmethod
    def _label_for(speaker_map: Dict[str, str], raw_label: str) -> str:
        """Map a backend label to "Speaker N", numbering by first sighting."""
        if raw_label not in speaker_map:
            speaker_map[raw_label] = f"Speaker {len(speaker_map) + 1}"
        return speaker_map[raw_label]

    @classmethod
    def _build_result(
        cls,
        segments: List[SpeakerSegment],
    ) -> DiarizationResult:
        """Sort segments by start and assemble a DiarizationResult."""
        ordered = sorted(segments, key=lambda seg: seg.start)
        # dict.fromkeys keeps first-appearance order. A set would make the
        # order depend on per-process string hash randomization.
        speakers = list(dict.fromkeys(seg.speaker for seg in ordered))
        return DiarizationResult(
            segments=ordered,
            speaker_count=len(speakers),
            speakers=speakers,
            roles=cls.infer_roles(ordered),
        )


# =========================================================
# PYANNOTE
# =========================================================

class PyannoteDiarizer(BaseDiarizer):
    """Diarizer backed by a pretrained pyannote.audio pipeline."""

    # Pipeline hyper-parameters applied after loading.
    CLUSTERING_THRESHOLD = 0.65
    MIN_DURATION_OFF = 0.4

    def __init__(self):
        """Load and configure the pipeline named by `settings.pyannote_model`.

        Raises:
            RuntimeError: if `Pipeline.from_pretrained` returns no pipeline.
        """
        # Lazy import: pyannote is only required when this backend is chosen.
        from pyannote.audio import Pipeline

        logger.info("Loading pyannote model: %s", settings.pyannote_model)

        pipeline = Pipeline.from_pretrained(
            settings.pyannote_model,
            token=settings.hf_token,
        )
        if pipeline is None:
            # from_pretrained can return None (e.g. gated model / bad token);
            # fail loudly instead of crashing later with an AttributeError.
            raise RuntimeError(
                f"Could not load pyannote pipeline "
                f"'{settings.pyannote_model}' (check hf_token / model access)"
            )
        self.pipeline = pipeline

        self.pipeline.instantiate({
            "clustering": {"threshold": self.CLUSTERING_THRESHOLD},
            "segmentation": {"min_duration_off": self.MIN_DURATION_OFF},
        })

        device = torch.device(settings.resolved_device)
        # Only CUDA is moved explicitly; otherwise the pipeline stays where
        # from_pretrained placed it.
        if device.type == "cuda":
            self.pipeline = self.pipeline.to(device)

        logger.info("Pyannote READY")

    def diarize(
        self,
        audio_path: Path,
        num_speakers: Optional[int] = None,
        min_speakers: int = 1,
        max_speakers: int = 10,
    ) -> DiarizationResult:
        """Run the pipeline and normalize its annotation.

        `num_speakers`, when given, takes precedence over the
        `min_speakers`/`max_speakers` bounds.
        """
        if num_speakers is not None:
            params = {"num_speakers": num_speakers}
        else:
            params = {
                "min_speakers": min_speakers,
                "max_speakers": max_speakers,
            }

        diarization = self.pipeline(str(audio_path), **params)

        # Some pyannote versions return a wrapper exposing
        # `.speaker_diarization`; others return the Annotation itself.
        annotation = getattr(diarization, "speaker_diarization", diarization)

        speaker_map: Dict[str, str] = {}
        segments: List[SpeakerSegment] = []
        for turn, _, speaker in annotation.itertracks(yield_label=True):
            segments.append(
                SpeakerSegment(
                    start=float(turn.start),
                    end=float(turn.end),
                    speaker=self._label_for(speaker_map, speaker),
                )
            )

        return self._build_result(segments)


# =========================================================
# SORTFORMER
# =========================================================

class SortformerDiarizer(BaseDiarizer):
    """Diarizer backed by a NeMo Sortformer model."""

    def __init__(self):
        """Load `settings.sortformer_model` onto `settings.resolved_device`."""
        # Lazy import: NeMo is only required when this backend is chosen.
        import nemo.collections.asr as nemo_asr

        logger.info("Loading sortformer model: %s", settings.sortformer_model)

        self.model = (
            nemo_asr.models.SortformerEncLabelModel
            .from_pretrained(model_name=settings.sortformer_model)
            .to(settings.resolved_device)
        )

        logger.info("Sortformer READY")

    def diarize(
        self,
        audio_path: Path,
        num_speakers: Optional[int] = None,
        min_speakers: int = 1,
        max_speakers: int = 10,
    ) -> DiarizationResult:
        """Diarize a single file.

        The speaker-count arguments are accepted for interface compatibility
        only; they are not forwarded to the model.
        """
        pred = self.model.diarize(audio=str(audio_path), batch_size=1)
        return self._build_result(self.normalize(pred))

    # -----------------------------------------------------
    # NORMALIZE OUTPUT
    # -----------------------------------------------------

    def normalize(self, pred) -> List[SpeakerSegment]:
        """Convert Sortformer text output into sorted SpeakerSegments.

        Each usable entry is a whitespace-separated "start end speaker ..."
        string. Non-strings, entries with fewer than three fields and entries
        with non-numeric times are skipped. Speakers are relabelled
        "Speaker N" in order of first speech.

        Args:
            pred: Model output -- either a list of entry strings, or a
                one-element list wrapping such a list (one item per file).

        Returns:
            Segments sorted by start time.
        """
        # Unwrap the per-file outer list, but never unwrap a bare segment
        # string (that would iterate it character by character and drop it).
        if isinstance(pred, str):
            pred = [pred]
        elif (
            isinstance(pred, list)
            and len(pred) == 1
            and not isinstance(pred[0], str)
        ):
            pred = pred[0]

        parsed = []
        for entry in pred:
            if not isinstance(entry, str):
                continue
            parts = entry.split()
            if len(parts) < 3:
                continue
            try:
                start, end = float(parts[0]), float(parts[1])
            except ValueError:
                logger.warning("Skipping unparsable sortformer segment: %r", entry)
                continue
            parsed.append((start, end, parts[2]))

        # Label chronologically so numbering matches the pyannote backend.
        parsed.sort(key=lambda item: item[0])

        speaker_map: Dict[str, str] = {}
        return [
            SpeakerSegment(
                start=start,
                end=end,
                speaker=self._label_for(speaker_map, raw_speaker),
            )
            for start, end, raw_speaker in parsed
        ]


# =========================================================
# MAIN SERVICE
# =========================================================

class DiarizationService:
    """Singleton facade that lazily loads one configured backend."""

    _instance = None
    _diarizer: Optional[BaseDiarizer] = None
    # Serializes backend construction: diarize_async runs in executor threads,
    # so concurrent first calls could otherwise load the model twice.
    _lock = threading.Lock()

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    # -----------------------------------------------------
    # LOAD MODEL
    # -----------------------------------------------------

    @classmethod
    def get_diarizer(cls) -> BaseDiarizer:
        """Return the shared backend, creating it on first use.

        Raises:
            ValueError: if `settings.diarization_backend` is unsupported.
        """
        if cls._diarizer is not None:
            return cls._diarizer

        with cls._lock:
            # Re-check: another thread may have finished loading meanwhile.
            if cls._diarizer is None:
                cls._diarizer = cls._create_diarizer()
            return cls._diarizer

    @staticmethod
    def _create_diarizer() -> BaseDiarizer:
        """Instantiate the backend named by `settings.diarization_backend`."""
        model_type = settings.diarization_backend.lower().strip()

        logger.info("Initializing diarization backend: %s", model_type)

        if model_type == "pyannote":
            return PyannoteDiarizer()
        if model_type == "sortformer":
            return SortformerDiarizer()
        raise ValueError(f"Unsupported diarization backend: {model_type}")

    # -----------------------------------------------------
    # MAIN API
    # -----------------------------------------------------

    @classmethod
    def diarize(
        cls,
        audio_path: Path,
        num_speakers: Optional[int] = None,
        min_speakers: int = 1,
        max_speakers: int = 10,
    ) -> DiarizationResult:
        """Diarize `audio_path` with the configured backend (blocking)."""
        return cls.get_diarizer().diarize(
            audio_path=audio_path,
            num_speakers=num_speakers,
            min_speakers=min_speakers,
            max_speakers=max_speakers,
        )

    # -----------------------------------------------------
    # ASYNC
    # -----------------------------------------------------

    @classmethod
    async def diarize_async(
        cls,
        audio_path: Path,
        num_speakers: Optional[int] = None,
        min_speakers: int = 1,
        max_speakers: int = 10,
    ) -> DiarizationResult:
        """Run `diarize` in the loop's default executor without blocking it."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None,
            lambda: cls.diarize(
                audio_path=audio_path,
                num_speakers=num_speakers,
                min_speakers=min_speakers,
                max_speakers=max_speakers,
            ),
        )

    # -----------------------------------------------------
    # PRELOAD
    # -----------------------------------------------------

    @classmethod
    def preload_pipeline(cls) -> None:
        """Eagerly load the backend; failures are logged, not raised."""
        try:
            cls.get_diarizer()
        except Exception as e:
            # Best-effort warm-up: keep the traceback for diagnosis.
            logger.warning(
                "Failed to preload diarization pipeline: %s", e, exc_info=True
            )