{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\"\"\"\n",
"You can run either this notebook locally (if you have all the dependencies and a GPU) or on Google Colab.\n",
"\n",
"Instructions for setting up Colab are as follows:\n",
"1. Open a new Python 3 notebook.\n",
"2. Import this notebook from GitHub (File -> Upload Notebook -> \"GITHUB\" tab -> copy/paste GitHub URL)\n",
"3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select \"GPU\" for hardware accelerator)\n",
"4. Run this cell to set up dependencies.\n",
"\"\"\"\n",
"# If you're using Google Colab and not running locally, run this cell.\n",
"\n",
"## Install dependencies\n",
"!pip install wget\n",
"!apt-get install sox libsndfile1 ffmpeg\n",
"!pip install text-unidecode\n",
"!pip install ipython\n",
"\n",
"# ## Install NeMo\n",
"BRANCH = 'main'\n",
"!python -m pip install git+https://github.com/NVIDIA/NeMo.git@{BRANCH}#egg=nemo_toolkit[asr]\n",
"\n",
"## Install TorchAudio\n",
"!pip install torchaudio -f https://download.pytorch.org/whl/torch_stable.html"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# End-to-End Speaker Diarization in NeMo"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<img src=\"images/cascaded_diar_diagram.png\" alt=\"Cascaded Diar System \" style=\"width: 800px;\"/>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Traditional cascaded (also referred to as modular or pipelined) speaker diarization systems consist of multiple modules such as a speaker activity detection (SAD) module and a speaker embedding extractor module. Cascaded systems are often time-consuming to develop since each module should be separately trained and optimized."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<img src=\"images/e2e_diar_diagram.png\" alt=\"End-to-end diarization model diagram\" style=\"width: 800px;\"/>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"On the other hand, end-to-end speaker diarization systems pursue a much more simplified version of a system where a single neural network model accepts raw audio signals and outputs speaker activity for each audio frame. Therefore, end-to-end diarization models have an advantage in ease of optimization and depoloyments."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Sortformer Diarization Inference\n",
"\n",
"As explained in the [Sortformer Diarization Training](https://github.com/NVIDIA/NeMo/blob/main/tutorials/speaker_tasks/Speaker_Diarization_Training.ipynb) tutorial, Sortformer is trained with Sort-Loss to generate speaker segments in arrival-time order. If a diarization model can generate speaker segments in a pre-defined rule or order, we do not need to match the permutations when we train diarization model with multi-speaker automatic speech recognition (ASR) models or we do not need to match permutations from each window when a diarization model is running in streaming mode where audio chunk sequences are processed, creating a problem of permutation matchin between inference windows. "
]
},
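{
"cell_type": "markdown",
"metadata": {},
"source": [
"One way to make \"arrival-time order\" concrete is sketched below. The notation is an assumption introduced here for illustration and is not taken from the Sortformer implementation: let $y_{t,n} \\in \\{0, 1\\}$ indicate whether speaker $n$ is active at frame $t$, and define the arrival time $a_n$ of a speaker as the first frame in which that speaker is active. The permutation $\\pi$ orders the $N$ speakers by arrival time.\n",
"\n",
"$$a_n = \\min \\{\\, t : y_{t,n} = 1 \\,\\}, \\qquad a_{\\pi(1)} \\le a_{\\pi(2)} \\le \\dots \\le a_{\\pi(N)}$$\n",
"\n",
"Under this reading, the first output channel is assigned to speaker $\\pi(1)$, the second to speaker $\\pi(2)$, and so on. For example, if speaker A first talks at frame 40 and speaker B first talks at frame 5, then B takes the first channel and A the second, regardless of who talks more."
]
},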
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<img src=\"images/intro_comparison.png\" alt=\"Intro Comparison\" style=\"width: 800px;\"/>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### A toy example for speaker diarization with a Sortformer model "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Download a toy example audio file (`an4_diarize_test.wav`) and its ground-truth label file (`an4_diarize_test.rttm`)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import wget\n",
"ROOT = os.getcwd()\n",
"data_dir = os.path.join(ROOT,'data')\n",
"os.makedirs(data_dir, exist_ok=True)\n",
"an4_audio = os.path.join(data_dir,'an4_diarize_test.wav')\n",
"an4_rttm = os.path.join(data_dir,'an4_diarize_test.rttm')\n",
"if not os.path.exists(an4_audio):\n",
" an4_audio_url = \"https://nemo-public.s3.us-east-2.amazonaws.com/an4_diarize_test.wav\"\n",
" an4_audio = wget.download(an4_audio_url, data_dir)\n",
"if not os.path.exists(an4_rttm):\n",
" an4_rttm_url = \"https://nemo-public.s3.us-east-2.amazonaws.com/an4_diarize_test.rttm\"\n",
" an4_rttm = wget.download(an4_rttm_url, data_dir)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's plot and listen to the audio. Pay attention to the two speakers in the audio clip."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import IPython\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import librosa\n",
"\n",
"sr = 16000\n",
"signal, sr = librosa.load(an4_audio,sr=sr) \n",
"\n",
"fig,ax = plt.subplots(1,1)\n",
"fig.set_figwidth(20)\n",
"fig.set_figheight(2)\n",
"plt.plot(np.arange(len(signal)),signal,'gray')\n",
"fig.suptitle('Reference merged an4 audio', fontsize=16)\n",
"plt.xlabel('time (secs)', fontsize=18)\n",
"ax.margins(x=0)\n",
"plt.ylabel('signal strength', fontsize=16)\n",
"a,_ = plt.xticks();plt.xticks(a,a/sr)\n",
"\n",
"IPython.display.Audio(an4_audio)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now that we have downloaded the example audio file contains multiple speakers, we can feed the audio clip into Sortformer diarizer and get the speaker diarization results."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Download Sortformer diarizer model\n",
"\n",
"To download Sortformer diarizer from [Sortformer HuggingFace model card](https://huggingface.co/nvidia/diar_sortformer_4spk-v1), you need to get a [HuggingFace Acces Token](https://huggingface.co/docs/hub/en/security-tokens) and feed your access token in your python environment using [HuggingFace CLI](https://huggingface.co/docs/huggingface_hub/main/en/guides/cli).\n",
"\n",
"If you are having trouble getting a HuggingFace token, you can download Sortformer model from [Sortformer HuggingFace model card](https://huggingface.co/nvidia/diar_sortformer_4spk-v1) and specify the path to the downloaded model.\n",
"\n",
"### Timestamp output and tensor output\n",
"\n",
"When excuting `diarize()` function, if you specify `include_tensor_outputs=True`, the diarization model will return the predicted speaker-labeled segments and tensors containing T by N (N is number of max speakers) sigmoid values for each audio frame. \n",
"\n",
"Without `include_tensor_outputs` variable or `include_tensor_outputs=False`, only speaker labeled segments will be returned."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from nemo.collections.asr.models import SortformerEncLabelModel\n",
"from huggingface_hub import get_token as get_hf_token\n",
"import torch\n",
"\n",
"if get_hf_token() is not None and get_hf_token().startswith(\"hf_\"):\n",
" # If you have logged into HuggingFace hub and have access token \n",
" diar_model = SortformerEncLabelModel.from_pretrained(\"nvidia/diar_sortformer_4spk-v1\")\n",
"else:\n",
" # You can downloaded \".nemo\" file from https://huggingface.co/nvidia/diar_sortformer_4spk-v1 and specify the path.\n",
" diar_model = SortformerEncLabelModel.restore_from(restore_path=\"/path/to/diar_sortformer_4spk-v1.nemo\", map_location=torch.device('cuda'), strict=False)\n",
"\n",
"pred_list, pred_tensor_list = diar_model.diarize(audio=an4_audio, include_tensor_outputs=True)\n",
"\n",
"print(pred_list[0])\n",
"print(pred_tensor_list[0].shape)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"Now let's visualize the predicted speaker diarization results. The diarization model outputs a tensor where each row represents a speaker and each column represents a frame. The sigmoid values in the tensor represent the probability of the frame being spoken by that speaker.\n",
"\n",
"In the following code, we'll plot the predicted speaker diarization results for the sample audio file."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"def plot_diarout(preds):\n",
" preds_mat = preds.cpu().numpy().transpose()\n",
" cmap_str, grid_color_p= 'viridis', 'gray'\n",
" LW, FS = 0.4, 36\n",
"\n",
" yticklabels = [\"spk0\", \"spk1\", \"spk2\", \"spk3\"]\n",
" yticks = np.arange(len(yticklabels))\n",
" fig, axs = plt.subplots(1, 1, figsize=(30, 3)) \n",
"\n",
" axs.imshow(preds_mat, cmap=cmap_str, interpolation='nearest') \n",
" axs.set_title('Predictions', fontsize=FS)\n",
" axs.set_xticks(np.arange(-.5, preds_mat.shape[1], 1), minor=True)\n",
" axs.set_yticks(yticks)\n",
" axs.set_yticklabels(yticklabels)\n",
" axs.set_xlabel(f\"80 ms Frames\", fontsize=FS)\n",
" axs.grid(which='minor', color=grid_color_p, linestyle='-', linewidth=LW)\n",
"\n",
" plt.savefig('plot.png', dpi=300) \n",
" plt.show()\n",
"\n",
"\n",
"plot_diarout(pred_tensor_list[0])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that the first speaker always gets the first generic speaker label `spk0`. Sortformer model is trained to generate speaker labels in arrival time order, thus permutations of speakers are always predictable if we know the arrival time order of speakers."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Post-Processing of Speaker Timestamps\n",
"\n",
"In the previous steps, we have obtained predictions of speaker activity in a form of Tensors. While the floating point probability values can be interpreted as speaker probabilities, these values are not designed to consumed as is and still requires to be processed to be segment information which has start and end time for each time stamp.\n",
"\n"
]
},
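{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal worked sketch of that conversion follows. It rests on two assumptions that are not stated in this tutorial: a single threshold $\\theta = 0.5$, and a frame length $\\Delta = 0.08$ s (matching the 80 ms frames on the plot axis above). Let $p_{t,n}$ be the sigmoid value for frame $t$ and speaker $n$.\n",
"\n",
"$$b_{t,n} = \\begin{cases} 1 & \\text{if } p_{t,n} \\ge \\theta \\\\ 0 & \\text{otherwise} \\end{cases}$$\n",
"\n",
"A run of consecutive active frames $t_s, \\dots, t_e$ for speaker $n$ then becomes one segment:\n",
"\n",
"$$\\text{start} = t_s \\, \\Delta, \\qquad \\text{end} = (t_e + 1) \\, \\Delta$$\n",
"\n",
"For example, if `spk0` is active on frames 25 through 49, the segment spans $25 \\times 0.08 = 2.00$ s to $50 \\times 0.08 = 4.00$ s. The exact boundary convention used inside `diarize()` may differ from this sketch."
]
},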
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from omegaconf import OmegaConf\n",
"import os\n",
"import wget\n",
"import pandas as pd\n",
"from nemo.collections.asr.parts.utils.speaker_utils import rttm_to_labels, labels_to_pyannote_object\n",
"from pyannote.metrics.diarization import DiarizationErrorRate\n",
"\n",
"ROOT = os.getcwd()\n",
"data_dir = os.path.join(ROOT,'data')\n",
"\n",
"yaml_name=\"sortformer_diar_4spk-v1_dihard3-dev.yaml\"\n",
"MODEL_CONFIG = os.path.join(data_dir, yaml_name)\n",
"if True:\n",
" config_url = f\"https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/speaker_tasks/diarization/conf/post_processing/{yaml_name}\"\n",
" MODEL_CONFIG = wget.download(config_url, data_dir)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The post-processing yaml file `sortformer_diar_4spk-v1_dihard3-dev.yaml` is containing parameters that are optimized to have the lowest Diarization Error Rate (DER) on the [DiHARD 3 development set](https://catalog.ldc.upenn.edu/LDC2022S12)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from nemo.collections.asr.parts.utils.vad_utils import load_postprocessing_from_yaml\n",
"import json\n",
"from omegaconf import OmegaConf \n",
"post_processing_params = load_postprocessing_from_yaml(MODEL_CONFIG)\n",
"print(json.dumps(OmegaConf.to_container(post_processing_params), indent=4))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Post-processing of Speaker Timestamps\n",
"\n",
"The parameters in post-processing yaml configurations perform the following tasks:\n",
"\n",
"| **Parameter** | **Description** |\n",
"|-------------------:|--------------------------------------------------------------|\n",
"| **onset** | Onset threshold for detecting the beginning of a speech segment. |\n",
"| **offset** | Offset threshold for detecting the end of a speech segment. |\n",
"| **pad_onset** | Adds the specified duration at the beginning of each speech segment. |\n",
"| **pad_offset** | Adds the specified duration at the end of each speech segment. |\n",
"| **min_duration_on**| Removes short speech segments if the duration is less than the specified minimum duration. |\n",
"| **min_duration_off**| Removes short silences if the duration is less than the specified minimum duration. |\n",
"\n",
"Now let's check the diarization output timestamps and compare how post-processing changes the diarization output."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"def show_diar_df(pred_session_list):\n",
" data = [segment.split() for segment in pred_session_list]\n",
" df = pd.DataFrame(data, columns=['Start Time', 'End Time', 'Speaker'])\n",
" print(df)\n",
"\n",
"# Call `diarize` method without postprocessing params\n",
"pred_list_bn = diar_model.diarize(audio=an4_audio)\n",
"print(f\" [Default Binarized Diarization Output]: \")\n",
"show_diar_df(pred_list_bn[0])\n",
"\n",
"# Call `diarize` method with postprocessing params\n",
"pred_list_pp = diar_model.diarize(audio=an4_audio, postprocessing_yaml=MODEL_CONFIG)\n",
"print(f\" [Post-processed Diarization Output]: \")\n",
"show_diar_df(pred_list_pp[0])\n",
"\n",
"print(f\" [Ground-truth Diarization Output]: \")\n",
"ref_labels = rttm_to_labels(an4_rttm)\n",
"show_diar_df(ref_labels)\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can see that the post-processed segments have more on-set padding while the differences are not significant. \n",
"\n",
"Now let's calculate DER (Diarization Error Rate) between the predicted labels and the ground-truth labels for both raw binarized and post-processed diarization outputs."
]
},
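{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, DER is conventionally defined as the sum of three error durations divided by the total duration of reference speech. The symbols below are introduced here for illustration:\n",
"\n",
"$$\\mathrm{DER} = \\frac{T_{\\mathrm{FA}} + T_{\\mathrm{miss}} + T_{\\mathrm{conf}}}{T_{\\mathrm{ref}}}$$\n",
"\n",
"Here $T_{\\mathrm{FA}}$ is the duration of false-alarm speech, $T_{\\mathrm{miss}}$ the duration of missed speech, $T_{\\mathrm{conf}}$ the duration attributed to the wrong speaker, and $T_{\\mathrm{ref}}$ the total reference speech duration. As a made-up illustration, 0.2 s of false alarm, 0.3 s of missed speech, and 0.1 s of confusion over 10 s of reference speech give $\\mathrm{DER} = 0.6 / 10 = 0.06$."
]
},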
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Get the refernce labels from ground-truth RTTM file\n",
"ref_labels = rttm_to_labels(an4_rttm)\n",
"\n",
"reference = labels_to_pyannote_object(ref_labels, uniq_name=\"binarize\")\n",
"hypothesis1 = labels_to_pyannote_object(pred_list_bn[0], uniq_name=\"binarize\")\n",
"der_metric1 = DiarizationErrorRate(collar=0, skip_overlap=False)\n",
"der_metric1(reference, hypothesis1, detailed=True)\n",
"print(f\"Raw Binarization DER: {abs(der_metric1):.6f}\")\n",
"\n",
"reference = labels_to_pyannote_object(ref_labels, uniq_name=\"post_processing\")\n",
"hypothesis2 = labels_to_pyannote_object(pred_list_pp[0], uniq_name=\"post_processing\")\n",
"der_metric2 = DiarizationErrorRate(collar=0, skip_overlap=False)\n",
"der_metric2(reference, hypothesis2, detailed=True)\n",
"print(f\"Post-Processing DER: {abs(der_metric2):.6f}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since the diarization output with post-processing is optimized to get the lowest DER for the given sigmoid tensor outputs, it generaly gives lower DER values when compared to the raw binarized diarization output. To achieve the lowest DER, it is recommended to use the post-processing parameters that are optimized for your dataset of interest and your diarization model. "
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
},
"pycharm": {
"stem_cell": {
"cell_type": "raw",
"metadata": {
"collapsed": false
},
"source": []
}
},
"vscode": {
"interpreter": {
"hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}