# Source: PoC_ASR_v6_dev/app/services/diarization.py
# Last change: "Update app/services/diarization.py" by vyluong (commit 4386f6c, verified)
"""
Speaker diarization service.

Supported backends:
- pyannote
- sortformer

Tuned for production / QA use on call-center audio.
"""
import logging
import threading
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Dict

import librosa
import numpy as np
import torch

from app.core.config import get_settings
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Application settings, fetched once when this module is imported.
settings = get_settings()
# Process-wide PyTorch switches: permit TensorFloat-32 math for CUDA matmuls
# and cuDNN convolutions. Per the PyTorch documentation this trades a little
# precision for speed on GPUs that support TF32; it affects every torch user
# in the process, not only the diarizers below.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# =========================================================
# DATA MODELS
# =========================================================
@dataclass
class SpeakerSegment:
    """A single stretch of audio attributed to one speaker."""

    # Segment boundaries (presumably seconds from the start of the audio --
    # they are filled from the backend's timestamps; confirm with callers).
    start: float
    end: float
    # Label of the speaker this segment belongs to.
    speaker: str
    # Score attached to the segment; 1.0 when the caller does not supply one.
    confidence: float = 1.0

    @property
    def duration(self) -> float:
        """Return the segment length, i.e. ``end`` minus ``start``."""
        length = self.end - self.start
        return length
@dataclass
class DiarizationResult:
    """Bundle of everything a diarizer returns for one audio file."""

    # Speaker-attributed segments produced by the backend.
    segments: List["SpeakerSegment"]
    # Number of distinct speaker labels in the result.
    speaker_count: int
    # The distinct speaker labels themselves.
    speakers: List[str]
    # Mapping of speaker label -> inferred role tag.
    roles: Dict[str, str]
# =========================================================
# BASE DIARIZER
# =========================================================
class BaseDiarizer(ABC):
    """Common interface for diarization backends, plus shared role inference."""

    @abstractmethod
    def diarize(
        self,
        audio_path: Path,
        num_speakers: Optional[int] = None,
        min_speakers: int = 1,
        max_speakers: int = 10
    ) -> DiarizationResult:
        """Diarize the audio file at ``audio_path``.

        Args:
            audio_path: Location of the audio file to process.
            num_speakers: Exact speaker count, when known.
            min_speakers: Lower bound on the speaker count.
            max_speakers: Upper bound on the speaker count.

        Returns:
            The backend's segments, speaker labels and inferred roles.
        """

    # -----------------------------------------------------
    # ROLE INFERENCE
    # -----------------------------------------------------
    @staticmethod
    def infer_roles(
        segments: List[SpeakerSegment]
    ) -> Dict[str, str]:
        """Tag the speaker with the most total talk time as the agent.

        Args:
            segments: Segments exposing ``speaker`` and ``duration``.

        Returns:
            A dict keyed by speaker label, in order of first appearance:
            ``"NV"`` for the speaker with the largest summed duration (the
            agent) and ``"KH"`` for every other speaker (presumably the
            customer side -- confirm with the consumers of these tags).
            Empty when ``segments`` is empty. On a tie, the speaker seen
            first wins.
        """
        talk_time: Dict[str, float] = {}
        for segment in segments:
            label = segment.speaker
            talk_time[label] = talk_time.get(label, 0.0) + segment.duration

        if not talk_time:
            return {}

        # max() keeps the first maximal key, so ties go to the earlier speaker.
        dominant = max(talk_time, key=talk_time.get)

        roles: Dict[str, str] = {}
        for label in talk_time:
            if label == dominant:
                roles[label] = "NV"
            else:
                roles[label] = "KH"
        return roles
# =========================================================
# PYANNOTE
# =========================================================
class PyannoteDiarizer(BaseDiarizer):
    """Diarizer backed by a pretrained ``pyannote.audio`` pipeline."""

    def __init__(self):
        """Load and configure the pipeline named by ``settings.pyannote_model``.

        Raises:
            RuntimeError: If the pipeline loader hands back ``None`` instead
                of a pipeline object.
        """
        # Imported here so pyannote is only required when this backend is used.
        from pyannote.audio import Pipeline

        logger.info("Loading pyannote model: %s", settings.pyannote_model)

        pipeline = Pipeline.from_pretrained(
            settings.pyannote_model,
            token=settings.hf_token,
        )
        if pipeline is None:
            # NOTE(review): pyannote.audio can print a hint and return None
            # (e.g. gated model / bad token) rather than raising; fail loudly
            # here instead of with an AttributeError on the next line.
            raise RuntimeError(
                f"Could not load pyannote pipeline: {settings.pyannote_model}"
            )
        self.pipeline = pipeline

        # Override selected hyper-parameters of the pretrained pipeline.
        # The values are presumably tuned for call-center audio
        # (min_duration_off presumably in seconds) -- confirm before changing.
        self.pipeline.instantiate({
            "clustering": {
                "threshold": 0.65,
            },
            "segmentation": {
                "min_duration_off": 0.4,
            },
        })

        # Only an explicit CUDA device triggers a move; otherwise the
        # pipeline stays wherever from_pretrained placed it.
        device = torch.device(settings.resolved_device)
        if device.type == "cuda":
            self.pipeline = self.pipeline.to(device)

        logger.info("Pyannote READY")

    def diarize(
        self,
        audio_path: Path,
        num_speakers: Optional[int] = None,
        min_speakers: int = 1,
        max_speakers: int = 10
    ) -> DiarizationResult:
        """Run the pyannote pipeline on ``audio_path``.

        Args:
            audio_path: Audio file to diarize; passed to the pipeline as a
                string path.
            num_speakers: Exact speaker count. When given, the min/max
                bounds are not forwarded.
            min_speakers: Lower bound, used only if ``num_speakers`` is None.
            max_speakers: Upper bound, used only if ``num_speakers`` is None.

        Returns:
            Segments sorted by start time with labels renamed to
            ``"Speaker N"`` in order of first appearance, the distinct
            speakers in that same order, and the inferred roles.
        """
        if num_speakers is not None:
            params = {"num_speakers": num_speakers}
        else:
            params = {
                "min_speakers": min_speakers,
                "max_speakers": max_speakers,
            }

        output = self.pipeline(str(audio_path), **params)

        # The pipeline result is either the annotation itself or an object
        # carrying it in a ``speaker_diarization`` attribute; accept both.
        annotation = getattr(output, "speaker_diarization", output)

        segments: List[SpeakerSegment] = []
        speaker_map: Dict[str, str] = {}
        for turn, _, raw_label in annotation.itertracks(yield_label=True):
            if raw_label not in speaker_map:
                speaker_map[raw_label] = f"Speaker {len(speaker_map) + 1}"
            segments.append(
                SpeakerSegment(
                    start=float(turn.start),
                    end=float(turn.end),
                    speaker=speaker_map[raw_label],
                )
            )

        segments.sort(key=lambda seg: seg.start)

        # dict.fromkeys de-duplicates while keeping first-appearance order;
        # a plain set would make the order vary between interpreter runs.
        speakers = list(dict.fromkeys(seg.speaker for seg in segments))

        return DiarizationResult(
            segments=segments,
            speaker_count=len(speakers),
            speakers=speakers,
            roles=self.infer_roles(segments),
        )
# =========================================================
# SORTFORMER
# =========================================================
class SortformerDiarizer(BaseDiarizer):
    """Diarizer backed by a NeMo ``SortformerEncLabelModel``."""

    def __init__(self):
        """Load the model named by ``settings.sortformer_model``.

        The model is moved to ``settings.resolved_device`` after loading.
        """
        # Imported here so NeMo is only required when this backend is used.
        import nemo.collections.asr as nemo_asr

        logger.info("Loading sortformer model: %s", settings.sortformer_model)

        model = nemo_asr.models.SortformerEncLabelModel.from_pretrained(
            model_name=settings.sortformer_model
        )
        self.model = model.to(settings.resolved_device)

        logger.info("Sortformer READY")

    def diarize(
        self,
        audio_path: Path,
        num_speakers: Optional[int] = None,
        min_speakers: int = 1,
        max_speakers: int = 10
    ) -> DiarizationResult:
        """Run the Sortformer model on ``audio_path``.

        Args:
            audio_path: Audio file to diarize; passed to the model as a
                string path.
            num_speakers: Accepted for interface parity; not forwarded to
                the model.
            min_speakers: Accepted for interface parity; not forwarded.
            max_speakers: Accepted for interface parity; not forwarded.

        Returns:
            Segments sorted by start time, the distinct speakers in order
            of first appearance, and the inferred roles.
        """
        pred = self.model.diarize(
            audio=str(audio_path),
            batch_size=1,
        )
        segments = self.normalize(pred)

        # dict.fromkeys de-duplicates while keeping first-appearance order;
        # a plain set would make the order vary between interpreter runs.
        speakers = list(dict.fromkeys(seg.speaker for seg in segments))

        return DiarizationResult(
            segments=segments,
            speaker_count=len(speakers),
            speakers=speakers,
            roles=self.infer_roles(segments),
        )

    # -----------------------------------------------------
    # NORMALIZE OUTPUT
    # -----------------------------------------------------
    def normalize(
        self,
        pred
    ) -> List[SpeakerSegment]:
        """Convert raw model predictions into sorted ``SpeakerSegment`` objects.

        Each usable entry is a whitespace-separated string whose first three
        fields are start, end and a speaker label. Non-string entries and
        entries with fewer than three fields are skipped.

        Args:
            pred: Either an iterable of entry strings, or a one-element
                list wrapping such an iterable. ``None`` is treated as
                "no segments".

        Returns:
            Segments sorted by start time, with labels renamed to
            ``"Speaker N"`` in the order the raw labels are first seen.

        Raises:
            ValueError: If a start or end field is not a valid float.
        """
        if pred is None:
            return []

        # Unwrap the one-element outer list, but only when the single element
        # is not itself an entry string: a flat one-segment list such as
        # ["0.0 1.2 speaker_0"] must not be reduced to a bare string, which
        # would then be iterated character by character and yield nothing.
        if (
            isinstance(pred, list)
            and len(pred) == 1
            and not isinstance(pred[0], str)
        ):
            pred = pred[0]

        segments: List[SpeakerSegment] = []
        speaker_map: Dict[str, str] = {}
        for entry in pred:
            if not isinstance(entry, str):
                continue
            parts = entry.split()
            if len(parts) < 3:
                continue
            raw_speaker = parts[2]
            if raw_speaker not in speaker_map:
                speaker_map[raw_speaker] = f"Speaker {len(speaker_map) + 1}"
            segments.append(
                SpeakerSegment(
                    start=float(parts[0]),
                    end=float(parts[1]),
                    speaker=speaker_map[raw_speaker],
                )
            )

        segments.sort(key=lambda seg: seg.start)
        return segments
# =========================================================
# MAIN SERVICE
# =========================================================
class DiarizationService:
    """Entry point that lazily builds and reuses one diarization backend.

    The backend is chosen from ``settings.diarization_backend`` on first use
    and cached on the class; every later call reuses it.
    """

    _instance = None
    _diarizer = None
    # Serialises the first backend construction: diarize_async() runs
    # diarize() on executor threads, so several first calls can arrive
    # concurrently and would otherwise each load the model.
    _lock = threading.Lock()

    def __new__(cls):
        """Return the single shared instance, creating it on first call."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    # -----------------------------------------------------
    # LOAD MODEL
    # -----------------------------------------------------
    @classmethod
    def get_diarizer(cls):
        """Return the cached backend, constructing it on first use.

        Raises:
            ValueError: If ``settings.diarization_backend`` (lower-cased and
                stripped) is neither ``"pyannote"`` nor ``"sortformer"``.
        """
        if cls._diarizer is not None:
            return cls._diarizer

        with cls._lock:
            # Re-check under the lock: another thread may have finished
            # building the backend while this one was waiting.
            if cls._diarizer is None:
                cls._diarizer = cls._build_diarizer()
            return cls._diarizer

    @classmethod
    def _build_diarizer(cls):
        """Instantiate the backend selected in the settings."""
        model_type = settings.diarization_backend.lower().strip()

        logger.info("Initializing diarization backend: %s", model_type)

        if model_type == "pyannote":
            return PyannoteDiarizer()
        if model_type == "sortformer":
            return SortformerDiarizer()
        raise ValueError(
            f"Unsupported diarization backend: "
            f"{model_type}"
        )

    # -----------------------------------------------------
    # MAIN API
    # -----------------------------------------------------
    @classmethod
    def diarize(
        cls,
        audio_path: Path,
        num_speakers: Optional[int] = None,
        min_speakers: int = 1,
        max_speakers: int = 10
    ) -> "DiarizationResult":
        """Diarize ``audio_path`` with the configured backend (blocking).

        Args:
            audio_path: Audio file to diarize.
            num_speakers: Exact speaker count, when known.
            min_speakers: Lower bound on the speaker count.
            max_speakers: Upper bound on the speaker count.

        Returns:
            Whatever the backend's ``diarize`` returns for these arguments.
        """
        diarizer = cls.get_diarizer()
        return diarizer.diarize(
            audio_path=audio_path,
            num_speakers=num_speakers,
            min_speakers=min_speakers,
            max_speakers=max_speakers,
        )

    # -----------------------------------------------------
    # ASYNC
    # -----------------------------------------------------
    @classmethod
    async def diarize_async(
        cls,
        audio_path: Path,
        num_speakers: Optional[int] = None,
        min_speakers: int = 1,
        max_speakers: int = 10
    ) -> "DiarizationResult":
        """Run :meth:`diarize` on the loop's default executor and await it.

        Takes the same arguments and returns the same result as
        :meth:`diarize`, without blocking the event loop while it runs.
        """
        import asyncio

        # Inside a coroutine a loop is always running; get_running_loop()
        # is the recommended accessor there, unlike get_event_loop().
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None,
            lambda: cls.diarize(
                audio_path=audio_path,
                num_speakers=num_speakers,
                min_speakers=min_speakers,
                max_speakers=max_speakers,
            ),
        )

    # -----------------------------------------------------
    # PRELOAD
    # -----------------------------------------------------
    @classmethod
    def preload_pipeline(cls):
        """Try to build the backend ahead of the first request.

        Best effort: any failure is logged as a warning (with traceback)
        and swallowed, so the backend is retried on first real use.
        """
        try:
            cls.get_diarizer()
        except Exception as e:
            logger.warning(
                "Failed to preload diarization pipeline: %s",
                e,
                exc_info=True,
            )