| import asyncio |
| import re |
| import threading |
| import numpy as np |
| import logging |
|
|
|
|
| from diart import SpeakerDiarization, SpeakerDiarizationConfig |
| from diart.inference import StreamingInference |
| from diart.sources import AudioSource |
| from whisperlivekit.timed_objects import SpeakerSegment |
| from diart.sources import MicrophoneAudioSource |
| from rx.core import Observer |
| from typing import Tuple, Any, List |
| from pyannote.core import Annotation |
|
|
| logger = logging.getLogger(__name__) |
|
|
| def extract_number(s: str) -> int: |
| m = re.search(r'\d+', s) |
| return int(m.group()) if m else None |
|
|
| class DiarizationObserver(Observer): |
| """Observer that logs all data emitted by the diarization pipeline and stores speaker segments.""" |
| |
| def __init__(self): |
| self.speaker_segments = [] |
| self.processed_time = 0 |
| self.segment_lock = threading.Lock() |
| |
| def on_next(self, value: Tuple[Annotation, Any]): |
| annotation, audio = value |
| |
| logger.debug("\n--- New Diarization Result ---") |
| |
| duration = audio.extent.end - audio.extent.start |
| logger.debug(f"Audio segment: {audio.extent.start:.2f}s - {audio.extent.end:.2f}s (duration: {duration:.2f}s)") |
| logger.debug(f"Audio shape: {audio.data.shape}") |
| |
| with self.segment_lock: |
| if audio.extent.end > self.processed_time: |
| self.processed_time = audio.extent.end |
| if annotation and len(annotation._labels) > 0: |
| logger.debug("\nSpeaker segments:") |
| for speaker, label in annotation._labels.items(): |
| for start, end in zip(label.segments_boundaries_[:-1], label.segments_boundaries_[1:]): |
| print(f" {speaker}: {start:.2f}s-{end:.2f}s") |
| self.speaker_segments.append(SpeakerSegment( |
| speaker=speaker, |
| start=start, |
| end=end |
| )) |
| else: |
| logger.debug("\nNo speakers detected in this segment") |
| |
| def get_segments(self) -> List[SpeakerSegment]: |
| """Get a copy of the current speaker segments.""" |
| with self.segment_lock: |
| return self.speaker_segments.copy() |
| |
| def clear_old_segments(self, older_than: float = 30.0): |
| """Clear segments older than the specified time.""" |
| with self.segment_lock: |
| current_time = self.processed_time |
| self.speaker_segments = [ |
| segment for segment in self.speaker_segments |
| if current_time - segment.end < older_than |
| ] |
| |
| def on_error(self, error): |
| """Handle an error in the stream.""" |
| logger.debug(f"Error in diarization stream: {error}") |
| |
| def on_completed(self): |
| """Handle the completion of the stream.""" |
| logger.debug("Diarization stream completed") |
|
|
|
|
| class WebSocketAudioSource(AudioSource): |
| """ |
| Custom AudioSource that blocks in read() until close() is called. |
| Use push_audio() to inject PCM chunks. |
| """ |
| def __init__(self, uri: str = "websocket", sample_rate: int = 16000): |
| super().__init__(uri, sample_rate) |
| self._closed = False |
| self._close_event = threading.Event() |
|
|
| def read(self): |
| self._close_event.wait() |
|
|
| def close(self): |
| if not self._closed: |
| self._closed = True |
| self.stream.on_completed() |
| self._close_event.set() |
|
|
| def push_audio(self, chunk: np.ndarray): |
| if not self._closed: |
| new_audio = np.expand_dims(chunk, axis=0) |
| logger.debug('Add new chunk with shape:', new_audio.shape) |
| self.stream.on_next(new_audio) |
|
|
|
|
| class DiartDiarization: |
| def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False): |
| self.pipeline = SpeakerDiarization(config=config) |
| self.observer = DiarizationObserver() |
| |
| if use_microphone: |
| self.source = MicrophoneAudioSource() |
| self.custom_source = None |
| else: |
| self.custom_source = WebSocketAudioSource(uri="websocket_source", sample_rate=sample_rate) |
| self.source = self.custom_source |
| |
| self.inference = StreamingInference( |
| pipeline=self.pipeline, |
| source=self.source, |
| do_plot=False, |
| show_progress=False, |
| ) |
| self.inference.attach_observers(self.observer) |
| asyncio.get_event_loop().run_in_executor(None, self.inference) |
|
|
| async def diarize(self, pcm_array: np.ndarray): |
| """ |
| Process audio data for diarization. |
| Only used when working with WebSocketAudioSource. |
| """ |
| if self.custom_source: |
| self.custom_source.push_audio(pcm_array) |
| self.observer.clear_old_segments() |
| return self.observer.get_segments() |
|
|
| def close(self): |
| """Close the audio source.""" |
| if self.custom_source: |
| self.custom_source.close() |
|
|
| def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list) -> float: |
| """ |
| Assign speakers to tokens based on timing overlap with speaker segments. |
| Uses the segments collected by the observer. |
| """ |
| segments = self.observer.get_segments() |
| |
| for token in tokens: |
| for segment in segments: |
| if not (segment.end <= token.start or segment.start >= token.end): |
| token.speaker = extract_number(segment.speaker) + 1 |
| end_attributed_speaker = max(token.end, end_attributed_speaker) |
| return end_attributed_speaker |