"""Real-time transcription with speaker colors.

Streams microphone audio, runs diart online speaker diarization,
transcribes with whisper_timestamped, and prints captions
color-coded by speaker using rich.
"""
import logging
| import os |
| import sys |
| import traceback |
| from contextlib import contextmanager |
|
|
| import diart.operators as dops |
| import numpy as np |
| import rich |
| import rx.operators as ops |
| import whisper_timestamped as whisper |
| from diart import OnlineSpeakerDiarization, PipelineConfig |
| from diart.sources import MicrophoneAudioSource |
| from pyannote.core import Annotation, SlidingWindowFeature, SlidingWindow, Segment |
|
|
| def concat(chunks, collar=0.05): |
| """ |
| Concatenate predictions and audio |
| given a list of `(diarization, waveform)` pairs |
| and merge contiguous single-speaker regions |
| with pauses shorter than `collar` seconds. |
| """ |
| first_annotation = chunks[0][0] |
| first_waveform = chunks[0][1] |
| annotation = Annotation(uri=first_annotation.uri) |
| data = [] |
| for ann, wav in chunks: |
| annotation.update(ann) |
| data.append(wav.data) |
| annotation = annotation.support(collar) |
| window = SlidingWindow( |
| first_waveform.sliding_window.duration, |
| first_waveform.sliding_window.step, |
| first_waveform.sliding_window.start, |
| ) |
| data = np.concatenate(data, axis=0) |
| return annotation, SlidingWindowFeature(data, window) |
|
|
def colorize_transcription(transcription):
    """
    Merge a speaker-aware transcription, given as a list of
    `(speaker: int, text: str)` pairs, into a single string
    with rich color markup per speaker.
    """
    # 10 distinct colors, repeated to cover up to 20 speakers
    colors = 2 * [
| "bright_red", |
| "bright_blue", |
| "bright_green", |
| "orange3", |
| "deep_pink1", |
| "yellow2", |
| "magenta", |
| "cyan", |
| "bright_magenta", |
| "dodger_blue2", |
| ] |
| result = [] |
| for speaker, text in transcription: |
| if speaker == -1: |
            # No speaker was found for this text; use the default terminal color
| result.append(text) |
| else: |
| result.append(f"[{colors[speaker]}]{text}") |
| return "\n".join(result) |
|
|
@contextmanager
def suppress_stdout():
    # Whisper prints progress and decoded text to stdout; redirect it
    # to the null device so it does not clutter the live captions
    with open(os.devnull, "w") as devnull:
        old_stdout = sys.stdout
        sys.stdout = devnull
        try:
            yield
        finally:
            sys.stdout = old_stdout
|
|
| class WhisperTranscriber: |
| def __init__(self, model="small", device=None): |
| self.model = whisper.load_model(model, device=device) |
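        # Running transcript, used to condition Whisper on past context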
| self._buffer = "" |
|
|
| def transcribe(self, waveform): |
| """Transcribe audio using Whisper""" |
        # Flatten to float32 mono and pad/trim to Whisper's 30s input window
| audio = waveform.data.astype("float32").reshape(-1) |
| audio = whisper.pad_or_trim(audio) |
|
        # Transcribe while suppressing Whisper's console output
        with suppress_stdout():
            transcription = whisper.transcribe(
                self.model,
                audio,
                # Condition the model on the transcript so far
                initial_prompt=self._buffer,
                # verbose=True prints segments instead of a progress
                # bar; the prints go to devnull anyway
                verbose=True,
            )
|
|
| return transcription |
|
|
| def identify_speakers(self, transcription, diarization, time_shift): |
| """Iterate over transcription segments to assign speakers""" |
| speaker_captions = [] |
| for segment in transcription["segments"]: |
            # Crop the diarization to the segment's word timestamps
| start = time_shift + segment["words"][0]["start"] |
| end = time_shift + segment["words"][-1]["end"] |
| dia = diarization.crop(Segment(start, end)) |
|
|
            # Assign a speaker to the segment based on the diarization
| speakers = dia.labels() |
| num_speakers = len(speakers) |
| if num_speakers == 0: |
                # No speaker was detected in this segment
| caption = (-1, segment["text"]) |
| elif num_speakers == 1: |
                # Only one speaker is active during this segment
| spk_id = int(speakers[0].split("speaker")[1]) |
| caption = (spk_id, segment["text"]) |
| else: |
                # Multiple speakers: pick the one who speaks the most
                max_idx = int(
                    np.argmax([dia.label_duration(spk) for spk in speakers])
                )
                spk_id = int(speakers[max_idx].split("speaker")[1])
                caption = (spk_id, segment["text"])
| speaker_captions.append(caption) |
|
|
| return speaker_captions |
|
|
| def __call__(self, diarization, waveform): |
        # Step 1: transcribe the audio chunk
        transcription = self.transcribe(waveform)
        # Keep the full transcript as context for future chunks
        self._buffer += transcription["text"]
        # The chunk may start anywhere in the conversation, not at t=0
        time_shift = waveform.sliding_window.start
        # Step 2: assign a speaker to each transcribed segment
| speaker_transcriptions = self.identify_speakers( |
| transcription, diarization, time_shift |
| ) |
| return speaker_transcriptions |
|
|
# Suppress whisper-timestamped warnings for a clean output
logging.getLogger("whisper_timestamped").setLevel(logging.ERROR)
|
|
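# Online diarization over 5s sliding windows with a 500ms step and minimum latency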
config = PipelineConfig(
    duration=5,
    step=0.5,
    latency="min",
    tau_active=0.5,
    rho_update=0.1,
    delta_new=0.57,
)
| dia = OnlineSpeakerDiarization(config) |
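# Stream audio from the default microphone at the pipeline's sample rate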
| source = MicrophoneAudioSource(config.sample_rate) |
|
|
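# Load Whisper; pass device="cuda" to run the model on GPU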
| asr = WhisperTranscriber(model="small") |
|
|
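# Transcribe every 2s of audio, i.e. every 4 diarization steps of 500ms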
| transcription_duration = 2 |
| batch_size = int(transcription_duration // config.step) |
|
|
| source.stream.pipe( |
    # Format the audio stream into sliding windows of 5s with a 500ms step
    dops.rearrange_audio_stream(config.duration, config.step, config.sample_rate),
    # Wait until a full batch of chunks is available
    # The output is a list of audio chunks
    ops.buffer_with_count(count=batch_size),
    # Obtain diarization predictions
    # The output is a list of `(diarization, audio chunk)` pairs
    ops.map(dia),
    # Concatenate the 500ms predictions/chunks into a single 2s chunk
    ops.map(concat),
    # Ignore the chunk if it does not contain speech
    ops.filter(lambda ann_wav: ann_wav[0].get_timeline().duration() > 0),
    # Obtain speaker-aware transcriptions
    # The output is a list of `(speaker: int, caption: str)` pairs
    ops.starmap(asr),
    # Color each caption according to its speaker
    # The output is plain text with rich color tags
    ops.map(colorize_transcription),
| ).subscribe( |
    on_next=rich.print,  # print the colored transcription
    on_error=lambda _: traceback.print_exc(),  # print the stacktrace on error
| ) |
|
|
| print("Listening...") |
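# Start reading the microphone (blocks until the stream is closed)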
| source.read() |
|
|