diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..1cc9c2d7840f60642719ce2e6dedc6c49e1d5e65 Binary files /dev/null and b/.DS_Store differ diff --git a/__pycache__/gradio.cpython-310.pyc b/__pycache__/gradio.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0bc72001b10fb6697d2a06c9dff4719128c7844c Binary files /dev/null and b/__pycache__/gradio.cpython-310.pyc differ diff --git a/__pycache__/model.cpython-310.pyc b/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..04e455c8dae50f997f997f36b7ed5508dd45b57b Binary files /dev/null and b/__pycache__/model.cpython-310.pyc differ diff --git a/__pycache__/model.cpython-39.pyc b/__pycache__/model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..69db1ae478c22f2691081fe53998c01ff522d7a2 Binary files /dev/null and b/__pycache__/model.cpython-39.pyc differ diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..629e57ad76d4343410ca40ed0ed5488b47fb8c35 --- /dev/null +++ b/app.py @@ -0,0 +1,67 @@ +import torch +import torchaudio +import voicebox.src.attacks.offline.perturbation.voicebox.voicebox as vb #To access VoiceBox class +#import voicebox.src.attacks.online.voicebox_streamer as streamer #To access VoiceBoxStreamer class +import numpy as np +from voicebox.src.constants import PPG_PRETRAINED_PATH + +#Set voicebox default parameters +LOOKAHEAD = 5 +voicebox_kwargs={'win_length': 256, + 'ppg_encoder_hidden_size': 256, + 'use_phoneme_encoder': True, + 'use_pitch_encoder': True, + 'use_loudness_encoder': True, + 'spec_encoder_lookahead_frames': 0, + 'spec_encoder_type': 'mel', + 'spec_encoder_mlp_depth': 2, + 'bottleneck_lookahead_frames': LOOKAHEAD, + 'ppg_encoder_path': PPG_PRETRAINED_PATH, + 'n_bands': 128, + 'spec_encoder_hidden_size': 512, + 'bottleneck_skip': True, + 'bottleneck_hidden_size': 512, + 'bottleneck_feedforward_size': 512, + 'bottleneck_type': 'lstm', + 'bottleneck_depth': 2, + 'control_eps': 0.5, + 'projection_norm': float('inf'), + 'conditioning_dim': 512} + +#Load pretrained model: +model = vb.VoiceBox(**voicebox_kwargs) +model.load_state_dict(torch.load('voicebox/pretrained/voicebox/voicebox_final.pt', map_location=torch.device('cpu')), strict=True) +model.eval() + +#Define function to convert final audio format: +def float32_to_int16(waveform): + waveform = waveform / np.abs(waveform).max() + waveform = waveform * 32767 + waveform = waveform.astype(np.int16) + waveform = waveform.ravel() + return waveform + +#Define predict function: +def predict(inp): + #How to transform audio from string to tensor + waveform, sample_rate = torchaudio.load(inp) + + #Run model without changing weights + with torch.no_grad(): + waveform = model(waveform) + + #Transform output audio into gradio-readable format + waveform = waveform.numpy() + waveform = float32_to_int16(waveform) + return sample_rate, waveform + +#Set up gradio interface +import gradio as gr + +interface = gr.Interface( + fn=predict, + inputs=gr.Audio(type="filepath"), + outputs=gr.Audio() +) + +interface.launch() \ No newline at end of file diff --git a/example.wav b/example.wav new file mode 100644 index 0000000000000000000000000000000000000000..a49bc7ea69df30e9b40051fe97ce1841e8770844 Binary files /dev/null and b/example.wav differ diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 
0000000000000000000000000000000000000000..4c481c8e1b55b98dacbd6cce531424c577d37c1d --- /dev/null +++ b/requirements.txt @@ -0,0 +1,28 @@ +torch==1.10.0 +torchaudio==0.10.0 +torchvision +torchcrepe +tensorboard +textgrid +Pillow +numpy +tqdm +jiwer +librosa +pandas +protobuf==3.20.0 +git+https://github.com/ludlows/python-pesq#egg=pesq +psutil +pystoi +pytest +pyworld +pyyaml +matplotlib +seaborn +ipython +scipy +scikit-learn +ipywebrtc +argbind +sounddevice +keyboard \ No newline at end of file diff --git a/voicebox/.DS_Store b/voicebox/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6 Binary files /dev/null and b/voicebox/.DS_Store differ diff --git a/voicebox/LICENSE b/voicebox/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/voicebox/README.md b/voicebox/README.md new file mode 100644 index 0000000000000000000000000000000000000000..abb43dc6b31fd0303ed4ffa03246789774930793 --- /dev/null +++ b/voicebox/README.md @@ -0,0 +1,136 @@ +

+# VoiceBlock

+

+### Privacy through Real-Time Adversarial Attacks with Audio-to-Audio Models

+
+ +[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/???/???.ipynb) +[![Demo](https://img.shields.io/badge/Web-Demo-blue)](https://master.d3hvhbnf7qxjtf.amplifyapp.com/) +[![MIT license](https://img.shields.io/badge/License-MIT-blue.svg)](/LICENSE) + +
+

+ + +## Contents + +* Installation +* Reproducing Results +* Streaming Implementation +* Citation + +

+## Installation

+ +1. Clone the repository: + + git clone https://github.com/voiceboxneurips/voicebox.git + +2. We recommend working from a clean environment, e.g. using `conda`: + + conda create --name voicebox python=3.9 + source activate voicebox + +3. Install dependencies: + + cd voicebox + pip install -r requirements.txt + pip install -e . + +4. Grant permissions: + + chmod -R u+x scripts/ + +
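+To check that the installation succeeded, you can load the pretrained VoiceBox model from Python. The following is a minimal sketch adapted from the demo `app.py` accompanying this pull request; the `src.*` import path and the relative checkpoint path are assumptions based on the repository layout and `pretrained/voicebox/voicebox_final.yaml`:
+
+```
+import torch
+import src.attacks.offline.perturbation.voicebox.voicebox as vb
+from src.constants import PPG_PRETRAINED_PATH
+
+# hyperparameters matching pretrained/voicebox/voicebox_final.yaml
+model = vb.VoiceBox(
+    win_length=256, ppg_encoder_hidden_size=256, use_phoneme_encoder=True,
+    use_pitch_encoder=True, use_loudness_encoder=True,
+    spec_encoder_lookahead_frames=0, spec_encoder_type='mel',
+    spec_encoder_mlp_depth=2, bottleneck_lookahead_frames=5,
+    ppg_encoder_path=PPG_PRETRAINED_PATH, n_bands=128,
+    spec_encoder_hidden_size=512, bottleneck_skip=True,
+    bottleneck_hidden_size=512, bottleneck_feedforward_size=512,
+    bottleneck_type='lstm', bottleneck_depth=2, control_eps=0.5,
+    projection_norm=float('inf'), conditioning_dim=512)
+
+# load pretrained weights and switch to inference mode
+state = torch.load('pretrained/voicebox/voicebox_final.pt', map_location='cpu')
+model.load_state_dict(state, strict=True)
+model.eval()
+```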

+## Reproducing Results

+ +To reproduce our results, first download the corresponding data. Note that to download the [VoxCeleb1 dataset](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1.html), you must register and obtain a username and password. + +| Task | Dataset (Size) | Command | +|---|---|---| +| Objective evaluation | VoxCeleb1 (39G) | `python scripts/downloads/download_voxceleb.py --subset=1 --username= --password=` | +| WER / supplemental evaluations | LibriSpeech `train-clean-360` (23G) | `./scripts/downloads/download_librispeech_eval.sh` | +| Train attacks | LibriSpeech `train-clean-100` (11G) | `./scripts/downloads/download_librispeech_train.sh` | + + +We provide scripts to reproduce our experiments and save results, including generated audio, to named and time-stamped subdirectories within `runs/`. To reproduce our objective evaluation experiments using pre-trained attacks, run: + +``` +python scripts/experiments/evaluate.py +``` + +To reproduce our training, run: + +``` +python scripts/experiments/train.py +``` + +
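+Note that the evaluation protocol is controlled by constants defined at the top of `scripts/experiments/evaluate.py`. For example (values excerpted from the script; edit them before running to switch datasets or threat models):
+
+```
+EVAL_DATASET = "voxceleb"  # or "librispeech" (enables WER/CER evaluation)
+N_QUERY = 15        # number of query utterances per speaker
+N_CONDITION = 10    # number of conditioning utterances per speaker
+N_ENROLL = 20       # number of enrolled utterances per speaker
+TRANSFER = True     # evaluate attacks on an unseen model
+DENOISER = False    # apply an unseen denoiser defense to queries
+SIMULATION = False  # apply a noisy channel simulation to all queries
+```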

+## Streaming Implementation

+ +As a proof of concept, we provide a streaming implementation of VoiceBox capable of modifying user audio in real time. Below are installation instructions for MacOS and Ubuntu 20.04. +

+### MacOS

+ +See video below: + + +

+
+ +

+### Ubuntu 20.04

+ + +1. Open a terminal and follow the [installation instructions](#installation) above. Change directory to the root of this repository. + +2. Run the following command: + + pacmd load-module module-null-sink sink_name=voicebox sink_properties=device.description=voicebox + + If you are using PipeWire instead of PulseAudio: + + pactl load-module module-null-sink media.class=Audio/Sink sink_name=voicebox sink_properties=device.description=voicebox + + PulseAudio is the default on Ubuntu, so if you haven't changed your system defaults, you are probably using it. This will add "voicebox" as an output device. Select it as the input to your chosen audio software. + +3. Find the audio devices to read from and write to. In your conda environment, run: + + python -m sounddevice + + You will get output similar to this: + + 0 HDA Intel HDMI: 0 (hw:0,3), ALSA (0 in, 8 out) + 1 HDA Intel HDMI: 1 (hw:0,7), ALSA (0 in, 8 out) + 2 HDA Intel HDMI: 2 (hw:0,8), ALSA (0 in, 8 out) + 3 HDA Intel HDMI: 3 (hw:0,9), ALSA (0 in, 8 out) + 4 HDA Intel HDMI: 4 (hw:0,10), ALSA (0 in, 8 out) + 5 hdmi, ALSA (0 in, 8 out) + 6 jack, ALSA (2 in, 2 out) + 7 pipewire, ALSA (64 in, 64 out) + 8 pulse, ALSA (32 in, 32 out) + * 9 default, ALSA (32 in, 32 out) + + In this example, we are going to route the audio through PipeWire (device 7). This will be our INPUT_NUM and OUTPUT_NUM. (A minimal passthrough sketch for testing this routing appears after this list.) + +4. First, we need to create a conditioning embedding. To do this, run the enrollment script and follow its on-screen instructions: + + python scripts/streamer/enroll.py --input INPUT_NUM + +5. We can now use the streamer. Run: + + python scripts/stream.py --input INPUT_NUM --output OUTPUT_NUM + +6. Once the streamer is running, open `pavucontrol`. + + a. In `pavucontrol`, go to the "Playback" tab and find "ALSA plug-in [python3.9]: ALSA Playback on". Set the output to "voicebox". + + b. Then, go to "Recording" and find "ALSA plug-in [python3.9]: ALSA Capture from", and set the input to your desired microphone device. +
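+Under the hood, the streamer reads audio from INPUT_NUM and writes processed audio to OUTPUT_NUM via the `sounddevice` library. To test your routing before enrolling, you can run a bare passthrough. This is a hypothetical snippet (not part of this repository) and assumes the 16 kHz mono format used elsewhere in the codebase:
+
+```
+import sounddevice as sd
+
+INPUT_NUM, OUTPUT_NUM = 7, 7  # device indices from `python -m sounddevice`
+
+def callback(indata, outdata, frames, time, status):
+    if status:
+        print(status)
+    outdata[:] = indata  # copy input to output unmodified
+
+# open a duplex stream routing the input device to the output device
+with sd.Stream(device=(INPUT_NUM, OUTPUT_NUM), samplerate=16000,
+               channels=1, callback=callback):
+    print("Passthrough running; press Enter to stop.")
+    input()
+```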

+## Citation

+ +If you use this work in your academic research, please cite the following: + +``` +@inproceedings{authors2022voicelock, +title={VoiceBlock: Privacy through Real-Time Adversarial Attacks with Audio-to-Audio Models}, +author={O'Reilly, Patrick and Bugler, Andreas and Bhandari, Keshav and Morrison, Max and Pardo, Bryan}, +booktitle={Neural Information Processing Systems}, +month={November}, +year={2022} +} +``` diff --git a/voicebox/cache/.gitkeep b/voicebox/cache/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/voicebox/data/.gitkeep b/voicebox/data/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/voicebox/figures/demo_thumbnail.png b/voicebox/figures/demo_thumbnail.png new file mode 100644 index 0000000000000000000000000000000000000000..2ffd514c3d5c016bee4bcb894a1ff0322c4febfe Binary files /dev/null and b/voicebox/figures/demo_thumbnail.png differ diff --git a/voicebox/figures/use_diagram_embeddings.png b/voicebox/figures/use_diagram_embeddings.png new file mode 100644 index 0000000000000000000000000000000000000000..64f91b8205aff9addd17badd788bff980c51516f Binary files /dev/null and b/voicebox/figures/use_diagram_embeddings.png differ diff --git a/voicebox/figures/vb_color_logo.png b/voicebox/figures/vb_color_logo.png new file mode 100644 index 0000000000000000000000000000000000000000..f66afaf5897e18d5691d72fd4c23da2afab0aa8d Binary files /dev/null and b/voicebox/figures/vb_color_logo.png differ diff --git a/voicebox/figures/voicebox_untargeted_conditioning_draft.png b/voicebox/figures/voicebox_untargeted_conditioning_draft.png new file mode 100644 index 0000000000000000000000000000000000000000..ab4ea629b31828dcc8d70976dd8ed31b512bed11 Binary files /dev/null and b/voicebox/figures/voicebox_untargeted_conditioning_draft.png differ diff --git a/voicebox/pretrained/denoiser/demucs/dns_48.pt b/voicebox/pretrained/denoiser/demucs/dns_48.pt new file mode 100644 index 0000000000000000000000000000000000000000..25d833ca2ac2a1c1e388160b21acc875d44ebe6c --- /dev/null +++ b/voicebox/pretrained/denoiser/demucs/dns_48.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4cfd4151600ed611d4af05083f4633d4fc31b53761cff8a185293346df745988 +size 75486933 diff --git a/voicebox/pretrained/phoneme/causal_ppg_128_hidden_128_hop.pt b/voicebox/pretrained/phoneme/causal_ppg_128_hidden_128_hop.pt new file mode 100644 index 0000000000000000000000000000000000000000..a4e0dd1aa5194067b9db9cf7a9c1a2431972da39 --- /dev/null +++ b/voicebox/pretrained/phoneme/causal_ppg_128_hidden_128_hop.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:be4c7a60c9af77e50af86924df8b73eb0c861a46f461e3bfe825c523a0a1a969 +size 1175695 diff --git a/voicebox/pretrained/phoneme/causal_ppg_256_hidden.pt b/voicebox/pretrained/phoneme/causal_ppg_256_hidden.pt new file mode 100644 index 0000000000000000000000000000000000000000..fcccc29730f2ed4f1de719363e88803b66052d59 --- /dev/null +++ b/voicebox/pretrained/phoneme/causal_ppg_256_hidden.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4e8f20e4973a6b91002c97605f993cf6e16a24ca9d0d39e183438a8c16d85c87 +size 4556495 diff --git a/voicebox/pretrained/phoneme/causal_ppg_256_hidden_256_hop.pt b/voicebox/pretrained/phoneme/causal_ppg_256_hidden_256_hop.pt new file mode 100644 index 0000000000000000000000000000000000000000..a8164f93c0c22c4da83c6b9d228dc2bcc8c7865c --- /dev/null +++ 
b/voicebox/pretrained/phoneme/causal_ppg_256_hidden_256_hop.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e0836df2f8465b53d4e0b5b14f1d1ef954b3570d6f95f1af22c3ac19b3e10099 +size 4573903 diff --git a/voicebox/pretrained/phoneme/causal_ppg_256_hidden_512_hop.pt b/voicebox/pretrained/phoneme/causal_ppg_256_hidden_512_hop.pt new file mode 100644 index 0000000000000000000000000000000000000000..16dcc6b9ce3c3c38dc89a54ed1f85883c6a5b55d --- /dev/null +++ b/voicebox/pretrained/phoneme/causal_ppg_256_hidden_512_hop.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a860d6f01058dc14b984845d27e681b5fe7c3bfffe41350e2e6e0f92e72778ad +size 4608719 diff --git a/voicebox/pretrained/phoneme/ppg_causal_small.pt b/voicebox/pretrained/phoneme/ppg_causal_small.pt new file mode 100644 index 0000000000000000000000000000000000000000..40d4c0527616a378a255a599e9b1041ba5a18fbe --- /dev/null +++ b/voicebox/pretrained/phoneme/ppg_causal_small.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4627bc2b63798df3391fe5c9ccbd72b929dc146b84f0fe61d1aa22848d107973 +size 18002639 diff --git a/voicebox/pretrained/speaker/resemblyzer/resemblyzer.pt b/voicebox/pretrained/speaker/resemblyzer/resemblyzer.pt new file mode 100644 index 0000000000000000000000000000000000000000..11791c885a2227550f0d3f75f3ea221fad6b413e --- /dev/null +++ b/voicebox/pretrained/speaker/resemblyzer/resemblyzer.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:afb2230a894f5a8f91263ff0b4811bde1ea5981bedda45a579c225e5a602ada3 +size 5697307 diff --git a/voicebox/pretrained/speaker/resnetse34v2/resnetse34v2.pt b/voicebox/pretrained/speaker/resnetse34v2/resnetse34v2.pt new file mode 100644 index 0000000000000000000000000000000000000000..948d7a7de32487b2173ebb5de6bcc1054436e28d --- /dev/null +++ b/voicebox/pretrained/speaker/resnetse34v2/resnetse34v2.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d96a4dad0118e9945bc7e676d8e5ff34d493ca2209fe188b3f982005132369bc +size 32311667 diff --git a/voicebox/pretrained/speaker/yvector/yvector.pt b/voicebox/pretrained/speaker/yvector/yvector.pt new file mode 100644 index 0000000000000000000000000000000000000000..d83afa1b9a3601498aa49ed5dd56fd46b490c465 --- /dev/null +++ b/voicebox/pretrained/speaker/yvector/yvector.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f2b4228cc772e689f800f1f9dc91d4ef4ee289e7e62f2822805edfc5b7faf399 +size 57703939 diff --git a/voicebox/pretrained/universal/universal_final.pt b/voicebox/pretrained/universal/universal_final.pt new file mode 100644 index 0000000000000000000000000000000000000000..9043f6b9319959580cdac3caed62264381701365 --- /dev/null +++ b/voicebox/pretrained/universal/universal_final.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f435535934f6c8c24fda42c251e65f41627b0660d3420ba1c694e25a82be033e +size 128811 diff --git a/voicebox/pretrained/voicebox/voicebox_final.pt b/voicebox/pretrained/voicebox/voicebox_final.pt new file mode 100644 index 0000000000000000000000000000000000000000..3df47a6574632faec9464b90558d27b487428907 --- /dev/null +++ b/voicebox/pretrained/voicebox/voicebox_final.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eb26234cc493182545dbfcc74501f6df7e90347ca3e2a94a7966978325a34ccd +size 30232012 diff --git a/voicebox/pretrained/voicebox/voicebox_final.yaml b/voicebox/pretrained/voicebox/voicebox_final.yaml new file mode 100644 index 
0000000000000000000000000000000000000000..ecd615967d6df61f4f9596a184f60d4839b9b59c --- /dev/null +++ b/voicebox/pretrained/voicebox/voicebox_final.yaml @@ -0,0 +1,20 @@ +win_length: 256 +ppg_encoder_hidden_size: 256 +use_phoneme_encoder: True +use_pitch_encoder: True +use_loudness_encoder: True +spec_encoder_lookahead_frames: 0 +spec_encoder_type: 'mel' +spec_encoder_mlp_depth: 2 +bottleneck_lookahead_frames: 5 +ppg_encoder_path: 'pretrained/phoneme/causal_ppg_256_hidden.pt' +n_bands: 128 +spec_encoder_hidden_size: 512 +bottleneck_skip: True +bottleneck_hidden_size: 512 +bottleneck_feedforward_size: 512 +bottleneck_type: 'lstm' +bottleneck_depth: 2 +control_eps: 0.5 +projection_norm: 'inf' +conditioning_dim: 512 \ No newline at end of file diff --git a/voicebox/requirements.txt b/voicebox/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..4c481c8e1b55b98dacbd6cce531424c577d37c1d --- /dev/null +++ b/voicebox/requirements.txt @@ -0,0 +1,28 @@ +torch==1.10.0 +torchaudio==0.10.0 +torchvision +torchcrepe +tensorboard +textgrid +Pillow +numpy +tqdm +jiwer +librosa +pandas +protobuf==3.20.0 +git+https://github.com/ludlows/python-pesq#egg=pesq +psutil +pystoi +pytest +pyworld +pyyaml +matplotlib +seaborn +ipython +scipy +scikit-learn +ipywebrtc +argbind +sounddevice +keyboard \ No newline at end of file diff --git a/voicebox/scripts/downloads/download_librispeech_eval.sh b/voicebox/scripts/downloads/download_librispeech_eval.sh new file mode 100755 index 0000000000000000000000000000000000000000..83d7f3bc94878a99ddf3cc3a6b84bf1c3d5197c7 --- /dev/null +++ b/voicebox/scripts/downloads/download_librispeech_eval.sh @@ -0,0 +1,25 @@ +#!/bin/bash + +set -e + +DOWNLOADS_SCRIPTS_DIR=$(eval dirname "$(readlink -f "$0")") +SCRIPTS_DIR="$(dirname "$DOWNLOADS_SCRIPTS_DIR")" +PROJECT_DIR="$(dirname "$SCRIPTS_DIR")" + +DATA_DIR="${PROJECT_DIR}/data/" +CACHE_DIR="${PROJECT_DIR}/cache/" + +mkdir -p "${DATA_DIR}" +mkdir -p "${CACHE_DIR}" + +# download train-clean-360 subset +echo "downloading LibriSpeech train-clean-360..." +wget http://www.openslr.org/resources/12/train-clean-360.tar.gz + +# extract train-clean-360 subset +echo "extracting LibriSpeech train-clean-360..." +tar -xf train-clean-360.tar.gz \ + -C "${DATA_DIR}" + +# delete archive +rm -f "train-clean-360.tar.gz" diff --git a/voicebox/scripts/downloads/download_librispeech_train.sh b/voicebox/scripts/downloads/download_librispeech_train.sh new file mode 100755 index 0000000000000000000000000000000000000000..6e6cc20cd79c57848d69ee916a8f64a086b44077 --- /dev/null +++ b/voicebox/scripts/downloads/download_librispeech_train.sh @@ -0,0 +1,54 @@ +#!/bin/bash + +set -e + +DOWNLOADS_SCRIPTS_DIR=$(eval dirname "$(readlink -f "$0")") +SCRIPTS_DIR="$(dirname "$DOWNLOADS_SCRIPTS_DIR")" +PROJECT_DIR="$(dirname "$SCRIPTS_DIR")" + +DATA_DIR="${PROJECT_DIR}/data/" +CACHE_DIR="${PROJECT_DIR}/cache/" + +mkdir -p "${DATA_DIR}" +mkdir -p "${CACHE_DIR}" + +# download test-clean subset +echo "downloading LibriSpeech test-clean..." +wget http://www.openslr.org/resources/12/test-clean.tar.gz + +# extract test-clean subset +echo "extracting LibriSpeech test-clean..." +tar -xf test-clean.tar.gz \ + -C "${DATA_DIR}" + +# delete archive +rm -f "test-clean.tar.gz" + +# download test-other subset +echo "downloading LibriSpeech test-other..." +wget http://www.openslr.org/resources/12/test-other.tar.gz + +# extract test-other subset +echo "extracting LibriSpeech test-other..." 
+tar -xf test-other.tar.gz \ + -C "${DATA_DIR}" + +# delete archive +rm -f "test-other.tar.gz" + +# download train-clean-100 subset +echo "downloading LibriSpeech train-clean-100..." +wget http://www.openslr.org/resources/12/train-clean-100.tar.gz + +# extract train-clean-100 subset +echo "extracting LibriSpeech train-clean-100..." +tar -xf train-clean-100.tar.gz \ + -C "${DATA_DIR}" + +# delete archive +rm -f "train-clean-100.tar.gz" + +# download LibriSpeech alignments dataset +wget -O alignments.zip https://zenodo.org/record/2619474/files/librispeech_alignments.zip?download=1 +unzip -d "${DATA_DIR}/LibriSpeech/" alignments.zip +rm -f alignments.zip diff --git a/voicebox/scripts/downloads/download_rir_noise.sh b/voicebox/scripts/downloads/download_rir_noise.sh new file mode 100755 index 0000000000000000000000000000000000000000..c7000148b565fbbd826d5c391dfa424db7e567f7 --- /dev/null +++ b/voicebox/scripts/downloads/download_rir_noise.sh @@ -0,0 +1,73 @@ +#!/bin/bash + +set -e + +DOWNLOADS_SCRIPTS_DIR=$(eval dirname "$(readlink -f "$0")") +SCRIPTS_DIR="$(dirname "$DOWNLOADS_SCRIPTS_DIR")" +PROJECT_DIR="$(dirname "$SCRIPTS_DIR")" + +DATA_DIR="${PROJECT_DIR}/data/" +CACHE_DIR="${PROJECT_DIR}/cache/" + +REAL_RIR_DIR="${DATA_DIR}/rir/real/" +SYNTHETIC_RIR_DIR="${DATA_DIR}/rir/synthetic/" +ROOM_NOISE_DIR="${DATA_DIR}/noise/room/" +PS_NOISE_DIR="${DATA_DIR}/noise/pointsource/" + +mkdir -p "${REAL_RIR_DIR}" +mkdir -p "${SYNTHETIC_RIR_DIR}" +mkdir -p "${ROOM_NOISE_DIR}" +mkdir -p "${PS_NOISE_DIR}" + +# download RIR/noise composite dataset +echo "downloading RIR/noise dataset..." +wget -O "${DATA_DIR}/rirs_noises.zip" https://www.openslr.org/resources/28/rirs_noises.zip + +# extract RIR/noise composite dataset +echo "unzipping RIR/noise dataset..." +unzip "${DATA_DIR}/rirs_noises.zip" -d "${DATA_DIR}/" + +# delete archive +rm -f "${DATA_DIR}/rirs_noises.zip" + +# organize pointsource noise data +echo "extracting point-source noise data" +cp -a "${DATA_DIR}/RIRS_NOISES/pointsource_noises"/. "${PS_NOISE_DIR}" + +# organize room noise data +echo "extracting room noise data" +room_noises=($(find "${DATA_DIR}/RIRS_NOISES/real_rirs_isotropic_noises/" -maxdepth 1 -name '*noise*' -type f)) +cp -- "${room_noises[@]}" "${ROOM_NOISE_DIR}" + +# organize real RIR data +echo "extracting recorded RIR data" +rirs=($(find "${DATA_DIR}/RIRS_NOISES/real_rirs_isotropic_noises/" ! -name '*noise*' )) +cp -- "${rirs[@]}" "${REAL_RIR_DIR}" + +# organize synthetic RIR data +echo "extracting synthetic RIR data" +cp -a "${DATA_DIR}/RIRS_NOISES/simulated_rirs"/. 
"${SYNTHETIC_RIR_DIR}" + +# delete redundant data +rm -rf "${DATA_DIR}/RIRS_NOISES/" + +# separate near-field and far-field RIRs +NEARFIELD_RIR_DIR="${REAL_RIR_DIR}/nearfield/" +FARFIELD_RIR_DIR="${REAL_RIR_DIR}/farfield/" + +mkdir -p "${NEARFIELD_RIR_DIR}" +mkdir -p "${FARFIELD_RIR_DIR}" + +# read list of far-field RIRs +readarray -t FF_RIR_LIST < "${DOWNLOADS_SCRIPTS_DIR}/ff_rir.txt" + +# move far-field RIRs +for name in "${FF_RIR_LIST[@]}"; do + mv "$name" "${FARFIELD_RIR_DIR}/$(basename "$name")" +done + +# move remaining near-field RIRs +for name in "${REAL_RIR_DIR}"/*.wav; do + mv "$name" "${NEARFIELD_RIR_DIR}/$(basename "$name")" +done + diff --git a/voicebox/scripts/downloads/download_voxceleb.py b/voicebox/scripts/downloads/download_voxceleb.py new file mode 100755 index 0000000000000000000000000000000000000000..3ed7a1b9c5c404c9e9a2d1f8a58945160ee4ffd8 --- /dev/null +++ b/voicebox/scripts/downloads/download_voxceleb.py @@ -0,0 +1,189 @@ +import argparse +from pathlib import Path +import subprocess +import hashlib +import tarfile +from zipfile import ZipFile + +from src.constants import VOXCELEB1_DATA_DIR, VOXCELEB2_DATA_DIR +from src.utils import ensure_dir + +################################################################################ +# Download VoxCeleb1 dataset using valid credentials +################################################################################ + + +def parse_args(): + + """Parse command-line arguments""" + parser = argparse.ArgumentParser(add_help=False) + + parser.add_argument( + '--subset', + type=int, + default=1, + help='Specify which VoxCeleb subset to download: 1 or 2' + ) + + parser.add_argument( + '--username', + type=str, + default=None, + help='User name provided by VGG to access VoxCeleb dataset' + ) + + parser.add_argument( + '--password', + type=str, + default=None, + help='Password provided by VGG to access VoxCeleb dataset' + ) + + return parser.parse_args() + + +def md5(f: str): + """ + Return MD5 checksum for file. Code adapted from voxceleb_trainer repository: + https://github.com/clovaai/voxceleb_trainer/blob/master/dataprep.py + """ + + hash_md5 = hashlib.md5() + with open(f, "rb") as f: + for chunk in iter(lambda: f.read(4096), b""): + hash_md5.update(chunk) + return hash_md5.hexdigest() + + +def download(username: str, + password: str, + save_path: str, + lines: list): + """ + Given a list of dataset shards formatted as , download + each using `wget` and verify checksums. Code adapted from voxceleb_trainer + repository: + https://github.com/clovaai/voxceleb_trainer/blob/master/dataprep.py + """ + + for line in lines: + url = line.split()[0] + md5gt = line.split()[1] + outfile = url.split('/')[-1] + + # download files + out = subprocess.call( + f'wget {url} --user {username} --password {password} -O {save_path}' + f'/{outfile}', shell=True) + if out != 0: + raise ValueError(f'Download failed for {url}') + + # verify checksum + md5ck = md5(f'{save_path}/{outfile}') + if md5ck == md5gt: + print(f'Checksum successful for {outfile}') + else: + raise Warning(f'Checksum failed for {outfile}') + + +def concatenate(save_path: str, lines: list): + """ + Given a specification in the format , concatenate all + downloaded data shards matching FMT into the file FILENAME and verify + checksums. 
Code adapted from voxceleb_trainer repository: + https://github.com/clovaai/voxceleb_trainer/blob/master/dataprep.py + """ + + for line in lines: + infile = line.split()[0] + outfile = line.split()[1] + md5gt = line.split()[2] + + # concatenate shards + out = subprocess.call( + f'cat {save_path}/{infile} > {save_path}/{outfile}', shell=True) + + # verify checksum + md5ck = md5(f'{save_path}/{outfile}') + if md5ck == md5gt: + print(f'Checksum successful for {outfile}') + else: + raise Warning(f'Checksum failed for {outfile}') + + # delete shards + out = subprocess.call( + f'rm {save_path}/{infile}', shell=True) + + +def full_extract(save_path: str, f: str): + """ + Extract contents of compressed archive to data directory + """ + + save_path = str(save_path) + f = str(f) + + print(f'Extracting {f}') + + if f.endswith(".tar.gz"): + with tarfile.open(f, "r:gz") as tar: + tar.extractall(save_path) + + elif f.endswith(".zip"): + with ZipFile(f, 'r') as zf: + zf.extractall(save_path) + + +def main(): + + args = parse_args() + + # prepare to load dataset file paths + downloads_dir = Path(__file__).parent + + if args.subset == 1: + data_dir = VOXCELEB1_DATA_DIR + elif args.subset == 2: + data_dir = VOXCELEB2_DATA_DIR + else: + raise ValueError(f'Invalid VoxCeleb subset {args.subset}') + + ensure_dir(data_dir) + + # load dataset file paths + with open(downloads_dir / f'voxceleb{args.subset}_file_parts.txt', 'r') as f: + file_parts_list = f.readlines() + + # load output file paths + with open(downloads_dir / f'voxceleb{args.subset}_files.txt', 'r') as f: + files_list = f.readlines() + + # download subset + download( + username=args.username, + password=args.password, + save_path=data_dir, + lines=file_parts_list + ) + + # merge shards + concatenate(save_path=data_dir, lines=files_list) + + # account for test data + archives = [file.split()[1] for file in files_list] + test = f"vox{args.subset}_test_{'wav' if args.subset == 1 else 'aac'}.zip" + archives.append(test) + + # extract all compressed data + for file in archives: + full_extract(data_dir, data_dir / file) + + # organize extracted data + out = subprocess.call(f'mv {data_dir}/dev/aac/* {data_dir}/aac/ && rm -r ' + f'{data_dir}/dev', shell=True) + out = subprocess.call(f'mv -v {data_dir}/{"wav" if args.subset == 1 else "aac"}/*' + f' {data_dir}/voxceleb{args.subset}', shell=True) + + +if __name__ == "__main__": + main() diff --git a/voicebox/scripts/downloads/ff_rir.txt b/voicebox/scripts/downloads/ff_rir.txt new file mode 100755 index 0000000000000000000000000000000000000000..9cea87727bc027e23944becec05f33ddd8e0cd88 --- /dev/null +++ b/voicebox/scripts/downloads/ff_rir.txt @@ -0,0 +1,132 @@ +data/rir/real/air_type1_air_binaural_lecture_0_1.wav +data/rir/real/RWCP_type3_rir_cirline_ofc_imp_rev.wav +data/rir/real/RWCP_type1_rir_cirline_jr2_imp110.wav +data/rir/real/air_type1_air_binaural_aula_carolina_1_4_90_3.wav +data/rir/real/RVB2014_type1_rir_largeroom2_far_anglb.wav +data/rir/real/air_type1_air_binaural_stairway_1_1_60.wav +data/rir/real/air_type1_air_binaural_aula_carolina_1_3_90_3.wav +data/rir/real/air_type1_air_binaural_lecture_0_5.wav +data/rir/real/air_type1_air_binaural_stairway_1_2_30.wav +data/rir/real/air_type1_air_binaural_stairway_1_1_30.wav +data/rir/real/air_type1_air_binaural_stairway_1_2_15.wav +data/rir/real/air_type1_air_binaural_stairway_1_2_165.wav +data/rir/real/air_type1_air_binaural_stairway_1_2_75.wav +data/rir/real/air_type1_air_binaural_lecture_0_3.wav +data/rir/real/air_type1_air_binaural_stairway_1_2_0.wav 
+data/rir/real/air_type1_air_binaural_stairway_1_3_0.wav +data/rir/real/RWCP_type2_rir_cirline_jr1_imp110.wav +data/rir/real/air_type1_air_binaural_aula_carolina_1_5_90_3.wav +data/rir/real/RVB2014_type1_rir_largeroom1_far_anglb.wav +data/rir/real/air_type1_air_binaural_lecture_1_1.wav +data/rir/real/RVB2014_type1_rir_largeroom1_far_angla.wav +data/rir/real/air_type1_air_binaural_aula_carolina_1_7_90_3.wav +data/rir/real/RWCP_type2_rir_cirline_ofc_imp070.wav +data/rir/real/RWCP_type1_rir_cirline_jr1_imp070.wav +data/rir/real/air_type1_air_binaural_stairway_1_3_150.wav +data/rir/real/air_type1_air_binaural_lecture_1_5.wav +data/rir/real/RWCP_type1_rir_cirline_jr1_imp100.wav +data/rir/real/RWCP_type1_rir_cirline_jr2_imp100.wav +data/rir/real/RWCP_type1_rir_cirline_e2b_imp130.wav +data/rir/real/air_type1_air_phone_corridor_hfrp.wav +data/rir/real/RWCP_type1_rir_cirline_jr1_imp130.wav +data/rir/real/RVB2014_type1_rir_largeroom1_near_angla.wav +data/rir/real/air_type1_air_binaural_stairway_1_1_75.wav +data/rir/real/RWCP_type1_rir_cirline_e2b_imp150.wav +data/rir/real/air_type1_air_phone_lecture_hhp.wav +data/rir/real/air_type1_air_binaural_stairway_1_1_105.wav +data/rir/real/air_type1_air_phone_stairway_hfrp.wav +data/rir/real/air_type1_air_binaural_stairway_1_3_105.wav +data/rir/real/RWCP_type2_rir_cirline_jr1_imp090.wav +data/rir/real/RWCP_type1_rir_cirline_e2b_imp050.wav +data/rir/real/air_type1_air_phone_stairway2_hfrp.wav +data/rir/real/air_type1_air_phone_stairway2_hhp.wav +data/rir/real/RWCP_type1_rir_cirline_jr2_imp060.wav +data/rir/real/air_type1_air_binaural_stairway_1_3_90.wav +data/rir/real/RWCP_type2_rir_cirline_jr1_imp130.wav +data/rir/real/RWCP_type1_rir_cirline_e2b_imp030.wav +data/rir/real/RVB2014_type1_rir_largeroom2_near_angla.wav +data/rir/real/air_type1_air_binaural_lecture_0_6.wav +data/rir/real/RWCP_type1_rir_cirline_e2b_imp070.wav +data/rir/real/air_type1_air_phone_stairway1_hhp.wav +data/rir/real/air_type1_air_binaural_stairway_1_1_45.wav +data/rir/real/RWCP_type1_rir_cirline_ofc_imp090.wav +data/rir/real/air_type1_air_binaural_stairway_1_1_135.wav +data/rir/real/air_type1_air_binaural_stairway_1_2_180.wav +data/rir/real/RWCP_type1_rir_cirline_ofc_imp100.wav +data/rir/real/RWCP_type1_rir_cirline_ofc_imp080.wav +data/rir/real/RWCP_type2_rir_cirline_ofc_imp090.wav +data/rir/real/RWCP_type1_rir_cirline_jr2_imp080.wav +data/rir/real/air_type1_air_binaural_lecture_1_2.wav +data/rir/real/RWCP_type1_rir_cirline_ofc_imp070.wav +data/rir/real/air_type1_air_binaural_stairway_1_2_150.wav +data/rir/real/air_type1_air_binaural_lecture_1_4.wav +data/rir/real/air_type1_air_binaural_aula_carolina_1_3_0_3.wav +data/rir/real/RVB2014_type1_rir_largeroom1_near_anglb.wav +data/rir/real/air_type1_air_binaural_stairway_1_1_15.wav +data/rir/real/air_type1_air_binaural_stairway_1_1_120.wav +data/rir/real/RWCP_type1_rir_cirline_ofc_imp050.wav +data/rir/real/air_type1_air_binaural_aula_carolina_1_1_90_3.wav +data/rir/real/air_type1_air_phone_stairway_hhp.wav +data/rir/real/RWCP_type1_rir_cirline_jr2_imp120.wav +data/rir/real/RWCP_type2_rir_cirline_e2b_imp110.wav +data/rir/real/RWCP_type1_rir_cirline_e2b_imp010.wav +data/rir/real/air_type1_air_binaural_stairway_1_3_15.wav +data/rir/real/air_type1_air_binaural_stairway_1_2_135.wav +data/rir/real/air_type1_air_phone_bt_stairway_hhp.wav +data/rir/real/RWCP_type2_rir_cirline_e2b_imp070.wav +data/rir/real/RWCP_type1_rir_cirline_ofc_imp120.wav +data/rir/real/RWCP_type1_rir_cirline_ofc_imp110.wav +data/rir/real/air_type1_air_binaural_lecture_0_4.wav 
+data/rir/real/RWCP_type2_rir_cirline_ofc_imp050.wav +data/rir/real/air_type1_air_binaural_stairway_1_1_90.wav +data/rir/real/RWCP_type1_rir_cirline_jr2_imp090.wav +data/rir/real/air_type1_air_binaural_stairway_1_1_0.wav +data/rir/real/air_type1_air_phone_stairway1_hfrp.wav +data/rir/real/air_type1_air_binaural_lecture_1_3.wav +data/rir/real/RWCP_type1_rir_cirline_jr1_imp050.wav +data/rir/real/RWCP_type1_rir_cirline_jr1_imp080.wav +data/rir/real/air_type1_air_binaural_stairway_1_1_165.wav +data/rir/real/air_type1_air_binaural_stairway_1_2_45.wav +data/rir/real/air_type1_air_phone_bt_corridor_hhp.wav +data/rir/real/air_type1_air_binaural_aula_carolina_1_2_90_3.wav +data/rir/real/RWCP_type2_rir_cirline_ofc_imp110.wav +data/rir/real/air_type1_air_binaural_stairway_1_3_120.wav +data/rir/real/air_type1_air_binaural_aula_carolina_1_3_180_3.wav +data/rir/real/RWCP_type1_rir_cirline_e2b_imp110.wav +data/rir/real/RWCP_type1_rir_cirline_jr1_imp060.wav +data/rir/real/air_type1_air_binaural_stairway_1_3_45.wav +data/rir/real/RVB2014_type1_rir_largeroom2_far_angla.wav +data/rir/real/air_type1_air_binaural_stairway_1_2_60.wav +data/rir/real/RWCP_type2_rir_cirline_jr1_imp070.wav +data/rir/real/RWCP_type1_rir_cirline_ofc_imp130.wav +data/rir/real/air_type1_air_binaural_aula_carolina_1_3_135_3.wav +data/rir/real/air_type1_air_binaural_stairway_1_3_75.wav +data/rir/real/air_type1_air_binaural_stairway_1_1_180.wav +data/rir/real/RWCP_type1_rir_cirline_jr1_imp120.wav +data/rir/real/air_type1_air_binaural_stairway_1_3_60.wav +data/rir/real/air_type1_air_binaural_stairway_1_2_105.wav +data/rir/real/air_type1_air_binaural_stairway_1_3_135.wav +data/rir/real/air_type1_air_binaural_aula_carolina_1_3_45_3.wav +data/rir/real/air_type1_air_binaural_lecture_1_6.wav +data/rir/real/RWCP_type2_rir_cirline_e2b_imp090.wav +data/rir/real/RWCP_type1_rir_cirline_e2b_imp170.wav +data/rir/real/air_type1_air_binaural_stairway_1_2_90.wav +data/rir/real/RWCP_type1_rir_cirline_jr2_imp070.wav +data/rir/real/RWCP_type1_rir_cirline_jr1_imp110.wav +data/rir/real/air_type1_air_phone_lecture_hfrp.wav +data/rir/real/RVB2014_type1_rir_largeroom2_near_anglb.wav +data/rir/real/air_type1_air_binaural_stairway_1_3_165.wav +data/rir/real/RWCP_type2_rir_cirline_ofc_imp130.wav +data/rir/real/air_type1_air_binaural_stairway_1_1_150.wav +data/rir/real/RWCP_type1_rir_cirline_jr1_imp090.wav +data/rir/real/RWCP_type2_rir_cirline_e2b_imp130.wav +data/rir/real/RWCP_type1_rir_cirline_ofc_imp060.wav +data/rir/real/air_type1_air_binaural_stairway_1_3_180.wav +data/rir/real/RWCP_type2_rir_cirline_jr1_imp050.wav +data/rir/real/air_type1_air_binaural_stairway_1_3_30.wav +data/rir/real/air_type1_air_binaural_lecture_0_2.wav +data/rir/real/air_type1_air_binaural_aula_carolina_1_6_90_3.wav +data/rir/real/RWCP_type2_rir_cirline_e2b_imp050.wav +data/rir/real/RWCP_type1_rir_cirline_e2b_imp090.wav +data/rir/real/air_type1_air_phone_corridor_hhp.wav +data/rir/real/air_type1_air_binaural_stairway_1_2_120.wav diff --git a/voicebox/scripts/downloads/voxceleb1_file_parts.txt b/voicebox/scripts/downloads/voxceleb1_file_parts.txt new file mode 100755 index 0000000000000000000000000000000000000000..93b75ac0f09a3171593330913bbe114075ec174f --- /dev/null +++ b/voicebox/scripts/downloads/voxceleb1_file_parts.txt @@ -0,0 +1,5 @@ +http://cnode01.mm.kaist.ac.kr/voxceleb/vox1a/vox1_dev_wav_partaa e395d020928bc15670b570a21695ed96 +http://cnode01.mm.kaist.ac.kr/voxceleb/vox1a/vox1_dev_wav_partab bbfaaccefab65d82b21903e81a8a8020 
+http://cnode01.mm.kaist.ac.kr/voxceleb/vox1a/vox1_dev_wav_partac 017d579a2a96a077f40042ec33e51512 +http://cnode01.mm.kaist.ac.kr/voxceleb/vox1a/vox1_dev_wav_partad 7bb1e9f70fddc7a678fa998ea8b3ba19 +http://cnode01.mm.kaist.ac.kr/voxceleb/vox1a/vox1_test_wav.zip 185fdc63c3c739954633d50379a3d102 \ No newline at end of file diff --git a/voicebox/scripts/downloads/voxceleb1_files.txt b/voicebox/scripts/downloads/voxceleb1_files.txt new file mode 100755 index 0000000000000000000000000000000000000000..91520d13767d6b65033e52dfe53328d912b64b8d --- /dev/null +++ b/voicebox/scripts/downloads/voxceleb1_files.txt @@ -0,0 +1 @@ +vox1_dev_wav_parta* vox1_dev_wav.zip ae63e55b951748cc486645f532ba230b \ No newline at end of file diff --git a/voicebox/scripts/downloads/voxceleb2_file_parts.txt b/voicebox/scripts/downloads/voxceleb2_file_parts.txt new file mode 100755 index 0000000000000000000000000000000000000000..d1623ba880cac092e4e39dc0a779807507d889cc --- /dev/null +++ b/voicebox/scripts/downloads/voxceleb2_file_parts.txt @@ -0,0 +1,9 @@ +http://cnode01.mm.kaist.ac.kr/voxceleb/vox1a/vox2_dev_aac_partaa da070494c573e5c0564b1d11c3b20577 +http://cnode01.mm.kaist.ac.kr/voxceleb/vox1a/vox2_dev_aac_partab 17fe6dab2b32b48abaf1676429cdd06f +http://cnode01.mm.kaist.ac.kr/voxceleb/vox1a/vox2_dev_aac_partac 1de58e086c5edf63625af1cb6d831528 +http://cnode01.mm.kaist.ac.kr/voxceleb/vox1a/vox2_dev_aac_partad 5a043eb03e15c5a918ee6a52aad477f9 +http://cnode01.mm.kaist.ac.kr/voxceleb/vox1a/vox2_dev_aac_partae cea401b624983e2d0b2a87fb5d59aa60 +http://cnode01.mm.kaist.ac.kr/voxceleb/vox1a/vox2_dev_aac_partaf fc886d9ba90ab88e7880ee98effd6ae9 +http://cnode01.mm.kaist.ac.kr/voxceleb/vox1a/vox2_dev_aac_partag d160ecc3f6ee3eed54d55349531cb42e +http://cnode01.mm.kaist.ac.kr/voxceleb/vox1a/vox2_dev_aac_partah 6b84a81b9af72a9d9eecbb3b1f602e65 +http://cnode01.mm.kaist.ac.kr/voxceleb/vox1a/vox2_test_aac.zip 0d2b3ea430a821c33263b5ea37ede312 \ No newline at end of file diff --git a/voicebox/scripts/downloads/voxceleb2_files.txt b/voicebox/scripts/downloads/voxceleb2_files.txt new file mode 100755 index 0000000000000000000000000000000000000000..63cbea6b7581f9fa74ac226bf73f9254add4b71a --- /dev/null +++ b/voicebox/scripts/downloads/voxceleb2_files.txt @@ -0,0 +1 @@ +vox2_dev_aac_parta* vox2_dev_aac.zip bbc063c46078a602ca71605645c2a402 \ No newline at end of file diff --git a/voicebox/scripts/experiments/evaluate.py b/voicebox/scripts/experiments/evaluate.py new file mode 100755 index 0000000000000000000000000000000000000000..181a48aa6946e12be6035af930b551da87a4a9be --- /dev/null +++ b/voicebox/scripts/experiments/evaluate.py @@ -0,0 +1,915 @@ +import os.path + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchaudio +import psutil +import pickle + +import random +import argparse + +import librosa as li +from sklearn.utils import shuffle +from sklearn.neighbors import NearestNeighbors + +from pesq import pesq, NoUtterancesError + +from tqdm import tqdm +from sklearn.preprocessing import LabelEncoder +import numpy as np +from pathlib import Path +from tqdm import tqdm +import builtins +import math +import jiwer +from jiwer import wer, cer + +from typing import Iterable +from copy import deepcopy + +from distutils.util import strtobool + +from src.data import * +from src.constants import * +from src.models import * +from src.simulation import * +from src.preprocess import * +from src.attacks.offline import * +from src.loss import * +from src.pipelines import * +from src.utils import * + 
+################################################################################ +# Evaluate attacks on speaker recognition systems +################################################################################ + +EVAL_DATASET = "voxceleb" # "librispeech" +LOOKAHEAD = 5 +VOICEBOX_PATH = VOICEBOX_PRETRAINED_PATH +UNIVERSAL_PATH = UNIVERSAL_PRETRAINED_PATH +BATCH_SIZE = 20 # evaluation batch size +N_QUERY = 15 # number of query utterances per speaker +N_CONDITION = 10 # number of conditioning utterances per speaker +N_ENROLL = 20 # number of enrolled utterances per speaker +ADV_ENROLL = False # evaluate under assumption adversarial audio is enrolled +TARGETS_TRAIN = 'centroid' # 'random', 'same', 'single', 'median' +TARGETS_TEST = 'centroid' # 'random', 'same', 'single', 'median' +TRANSFER = True # evaluate attacks on unseen model +DENOISER = False # evaluate with unseen denoiser defense applied to queries +SIMULATION = False # apply noisy channel simulation to all queries in evaluation +COMPUTE_OBJECTIVE_METRICS = True # PESQ, STOI + + +def set_random_seed(seed: int = 123): + """Set random seed to allow for reproducibility""" + random.seed(seed) + torch.manual_seed(seed) + + if torch.backends.cudnn.is_available(): + # torch.backends.cudnn.benchmark = True + torch.backends.cudnn.deterministic = True + + +def param_count(m: nn.Module, trainable: bool = False): + """Count the number of trainable parameters (weights) in a model""" + if trainable: + return builtins.sum( + [p.shape.numel() for p in m.parameters() if p.requires_grad]) + else: + return builtins.sum([p.shape.numel() for p in m.parameters()]) + + +def pad_sequence(sequences: list): + + max_len = max([s.shape[-1] for s in sequences]) + + padded = torch.zeros( + (len(sequences), 1, max_len), + dtype=sequences[0].dtype, + device=sequences[0].device) + + for i, s in enumerate(sequences): + padded[i, :, :s.shape[-1]] = s + + return padded + + +@torch.no_grad() +def compute_embeddings_batch(audio: list, + p: Pipeline, + defense: nn.Module = nn.Identity()): + """Compute batched speaker embeddings""" + + assert isinstance(p.model, SpeakerVerificationModel) + emb = [p(defense(audio[i].to(p.device))).to('cpu') for i in range(len(audio))] + emb = torch.cat(emb, dim=0) + return emb + + +@torch.no_grad() +def compute_transcripts_batch(audio: list, p: Pipeline): + """Compute batched transcripts""" + + assert isinstance(p.model, SpeechRecognitionModel) + transcripts = [] + for i in range(len(audio)): + t = p.model.transcribe(audio[i].to(p.device)) + if isinstance(t, str): + transcripts.append(t) + elif isinstance(t, list): + transcripts.extend(t) + + assert len(transcripts) == len(audio), f'Transcript format error' + + return transcripts + + +@torch.no_grad() +def compute_attack_batch(audio: list, + a: TrainableAttack, + c: torch.Tensor): + + if len(c) < len(audio): + c = c.repeat(len(audio), 1, 1) + adv = [a.perturbation(audio[i].to(a.pipeline.device), + y=c[i:i+1].to(a.pipeline.device)).to('cpu').reshape(1, 1, -1) + for i in range(len(audio))] + return adv + + +@torch.no_grad() +def compute_pesq(audio1: list, audio2: list, mode: str = 'wb'): + + assert len(audio1) == len(audio2) + scores = [] + + for i in range(len(audio1)): + try: + scores.append( + pesq(DataProperties.get('sample_rate'), + tensor_to_np(audio1[i]).flatten(), + tensor_to_np(audio2[i]).flatten(), + mode) + ) + except NoUtterancesError: + print("PESQ error, skipping audio file...") + return scores + + +@torch.no_grad() +def compute_stoi(audio1: list, audio2: list, extended: 
bool = False): + + assert len(audio1) == len(audio2) + scores = [] + for i in range(len(audio1)): + scores.append( + stoi(tensor_to_np(audio1[i]).flatten(), + tensor_to_np(audio2[i]).flatten(), + DataProperties.get('sample_rate'), + extended=extended) + ) + return scores + + +@torch.no_grad() +def build_ls_dataset(pipelines: dict): + """ + Build LibriSpeech evaluation dataset on disk holding: + * query audio + * query embeddings + * conditioning embeddings + * enrolled embeddings + * ground-truth query transcripts + """ + + # locate dataset + data_dir = LIBRISPEECH_DATA_DIR / 'train-clean-360' + cache_dir = CACHE_DIR / 'ls_wer_eval' + ensure_dir(cache_dir) + + assert os.path.isdir(data_dir), \ + f'LibriSpeech `train-clean-360` subset required for evaluation' + + spkr_dirs = list(data_dir.glob("*/")) + spkr_dirs = [s_d for s_d in spkr_dirs if os.path.isdir(s_d)] + + # catalog audio and load transcripts + for spkr_dir in tqdm(spkr_dirs, total=len(spkr_dirs), desc='Building dataset'): + + # identify speaker + spkr_id = spkr_dir.parts[-1] + + # check whether cached data exists for speaker + spkr_cache_dir = cache_dir / spkr_id + if os.path.isdir(spkr_cache_dir): + continue + + # each recording session has a separate subdirectory + rec_dirs = list(spkr_dir.glob("*/")) + rec_dirs = [r_d for r_d in rec_dirs if os.path.isdir(r_d)] + + # for each speaker, process & store necessary (non-adversarial) data + all_audio = [] + all_transcripts = [] + + # for each recording session, extract all audio files and transcripts + for rec_dir in rec_dirs: + + rec_id = rec_dir.parts[-1] + trans_fn = rec_dir / f"{spkr_id}-{rec_id}.trans.txt" + + # open transcript file + with open(trans_fn, "r") as f: + trans_idx = f.readlines() + + if len(trans_idx) == 0: + print(f"Error: empty transcript {trans_fn}") + continue + + for line in trans_idx: + + split_line = line.strip().split(" ") + audio_fn = rec_dir / f'{split_line[0]}.{LIBRISPEECH_EXT}' + transcript = " ".join(split_line[1:]).replace(" ", "|") + + x, _ = li.load(audio_fn, mono=True, sr=16000) + all_audio.append(torch.as_tensor(x).reshape(1, 1, -1).float()) + all_transcripts.append(transcript) + + # shuffle audio and transcripts in same random order + all_audio, all_transcripts = shuffle(all_audio, all_transcripts) + + # divide audio and transcripts + query_audio = all_audio[:N_QUERY] + query_transcripts = all_transcripts[:N_QUERY] + condition_audio = all_audio[N_QUERY:N_QUERY+N_CONDITION] + enroll_audio = all_audio[N_QUERY+N_CONDITION:][:N_ENROLL] + + # check for sufficient audio in each category + if len(query_audio) < N_QUERY: + print(f"Error: insufficient query audio for speaker {spkr_id}") + continue + elif len(condition_audio) < N_CONDITION: + print(f"Error: insufficient conditioning audio for speaker {spkr_id}") + continue + elif len(enroll_audio) < N_ENROLL: + print(f"Error: insufficient enrollment audio for speaker {spkr_id}") + continue + + # compute and save embeddings + for p_name, p in pipelines.items(): + + # compute and save query embeddings + query_emb = compute_embeddings_batch(query_audio, p) + f_query = spkr_cache_dir / p_name / 'query_emb.pt' + ensure_dir_for_filename(f_query) + + # compute and save conditioning embeddings + condition_emb = compute_embeddings_batch(condition_audio, p) + f_condition = spkr_cache_dir / p_name / 'condition_emb.pt' + ensure_dir_for_filename(f_condition) + + # compute and save enrolled embeddings + enroll_emb = compute_embeddings_batch(enroll_audio, p) + f_enroll = spkr_cache_dir / p_name / 'enroll_emb.pt' + 
ensure_dir_for_filename(f_enroll) + + torch.save(query_emb, f_query) + torch.save(condition_emb, f_condition) + torch.save(enroll_emb, f_enroll) + + # save query audio + f_audio = spkr_cache_dir / 'query_audio.pt' + torch.save(query_audio, f_audio) + + # save query transcripts + f_transcript = spkr_cache_dir / 'query_trans.pt' + torch.save(query_transcripts, f_transcript) + +@torch.no_grad() +def build_vc_dataset(pipelines: dict): + """ + Build VoxCeleb evaluation dataset on disk holding: + * query audio + * query embeddings + * conditioning embeddings + * enrolled embeddings + """ + + # locate dataset + data_dir = VOXCELEB1_DATA_DIR / 'voxceleb1' + cache_dir = CACHE_DIR / 'vc_wer_eval' + ensure_dir(cache_dir) + + assert os.path.isdir(data_dir), \ + f'VoxCeleb1 dataset required for evaluation' + + spkr_dirs = list(data_dir.glob("*/")) + spkr_dirs = [s_d for s_d in spkr_dirs if os.path.isdir(s_d)] + + # catalog audio + for spkr_dir in tqdm(spkr_dirs, total=len(spkr_dirs), desc='Building dataset'): + + # identify speaker + spkr_id = spkr_dir.parts[-1] + + # check whether cached data exists for speaker + spkr_cache_dir = cache_dir / spkr_id + if os.path.isdir(spkr_cache_dir): + continue + + # each recording session has a separate subdirectory + rec_dirs = list(spkr_dir.glob("*/")) + rec_dirs = [r_d for r_d in rec_dirs if os.path.isdir(r_d)] + + # for each speaker, process & store necessary (non-adversarial) data + all_audio = [] + + # for each recording session, extract all audio files and transcripts + for rec_dir in rec_dirs: + for audio_fn in rec_dir.glob(f"*.{VOXCELEB1_EXT}"): + x, _ = li.load(audio_fn, mono=True, sr=16000) + all_audio.append(torch.as_tensor(x).reshape(1, 1, -1).float()) + + # shuffle audio in random order + all_audio = shuffle(all_audio) + + # divide audio and transcripts + query_audio = all_audio[:N_QUERY] + condition_audio = all_audio[N_QUERY:N_QUERY+N_CONDITION] + enroll_audio = all_audio[N_QUERY+N_CONDITION:][:N_ENROLL] + + # check for sufficient audio in each category + if len(query_audio) < N_QUERY: + print(f"Error: insufficient query audio for speaker {spkr_id}") + continue + elif len(condition_audio) < N_CONDITION: + print(f"Error: insufficient conditioning audio for speaker {spkr_id}") + continue + elif len(enroll_audio) < N_ENROLL: + print(f"Error: insufficient enrollment audio for speaker {spkr_id}") + continue + + # compute and save embeddings + for p_name, p in pipelines.items(): + + # compute and save query embeddings + query_emb = compute_embeddings_batch(query_audio, p) + f_query = spkr_cache_dir / p_name / 'query_emb.pt' + ensure_dir_for_filename(f_query) + + # compute and save conditioning embeddings + condition_emb = compute_embeddings_batch(condition_audio, p) + f_condition = spkr_cache_dir / p_name / 'condition_emb.pt' + ensure_dir_for_filename(f_condition) + + # compute and save enrolled embeddings + enroll_emb = compute_embeddings_batch(enroll_audio, p) + f_enroll = spkr_cache_dir / p_name / 'enroll_emb.pt' + ensure_dir_for_filename(f_enroll) + + torch.save(query_emb, f_query) + torch.save(condition_emb, f_condition) + torch.save(enroll_emb, f_enroll) + + # save query audio + f_audio = spkr_cache_dir / 'query_audio.pt' + torch.save(query_audio, f_audio) + +@torch.no_grad() +def asr_metrics(true: list, hypothesis: list, batch_size: int = 5): + """ + Compute word and character error rates between two lists of corresponding + transcripts + """ + + assert len(true) == len(hypothesis) + + n_batches = math.ceil(len(true) / batch_size) + + transform_wer = 
jiwer.Compose([ + jiwer.ToLowerCase(), + jiwer.RemoveWhiteSpace(replace_by_space=True), + jiwer.RemoveMultipleSpaces(), + jiwer.ReduceToSingleSentence(word_delimiter="|"), + jiwer.ReduceToListOfListOfWords(word_delimiter="|"), + ]) + + wer_score = 0.0 + cer_score = 0.0 + + wer_n = 0 + cer_n = 0 + + for i in range(n_batches): + + batch_true = true[i*batch_size:(i+1)*batch_size] + batch_hypothesis = hypothesis[i*batch_size:(i+1)*batch_size] + + wer_n_batch = builtins.sum([len(s.split('|')) for s in batch_true]) + cer_n_batch = builtins.sum([len(s) for s in batch_true]) + + attack_cer = cer(batch_true, batch_hypothesis) + attack_wer = wer(batch_true, batch_hypothesis, + truth_transform=transform_wer, + hypothesis_transform=transform_wer) + + wer_score += wer_n_batch*attack_wer + cer_score += cer_n_batch*attack_cer + + wer_n += wer_n_batch + cer_n += cer_n_batch + + wer_score /= wer_n + cer_score /= cer_n + + return wer_score, cer_score + + +@torch.no_grad() +def top_k(query: dict, enrolled: dict, k: int): + """ + Compute portion of queries for which 'correct' ID appears in k-closest + enrolled entries + """ + + # concatenate query embeddings into single tensor + query_array = [] + query_ids = [] + + for s_l in query.keys(): + query_array.append(query[s_l]) + query_ids.extend([s_l] * len(query[s_l])) + + query_array = torch.cat(query_array, dim=0).squeeze().cpu().numpy() + query_ids = torch.as_tensor(query_ids).cpu().numpy() + + # concatenate enrolled embeddings into single tensor + enrolled_array = [] + enrolled_ids = [] + + for s_l in enrolled.keys(): + enrolled_array.append(enrolled[s_l]) + enrolled_ids.extend([s_l] * len(enrolled[s_l])) + + enrolled_array = torch.cat(enrolled_array, dim=0).squeeze().cpu().numpy() + enrolled_ids = torch.as_tensor(enrolled_ids).cpu().numpy() + + # embedding dimension + assert query_array.shape[-1] == enrolled_array.shape[-1] + d = query_array.shape[-1] + + # index enrolled embeddings + knn = NearestNeighbors(n_neighbors=k, metric="cosine").fit(enrolled_array) + + # `I` is a (n_queries, k) array holding the indices of the k-closest enrolled + # embeddings for each query; `D` is a (n_queries, k) array holding the corresponding + # embedding-space distances + D, I = knn.kneighbors(query_array, k, return_distance=True) + + # for each row, see if at least one of the k nearest enrolled indices maps + # to a speaker ID that matches the query index's speaker id + targets = np.tile(query_ids.reshape(-1, 1), (1, k)) + + predictions = enrolled_ids[I] + matches = (targets == predictions).sum(axis=-1) > 0 + + return np.mean(matches) + + +def init_attacks(): + """ + Initialize pre-trained speaker recognition pipelines and de-identification + attacks + """ + + # channel simulation + if SIMULATION: + sim = [ + Offset(length=[-.15, .15]), + Noise(type='gaussian', snr=[30.0, 50.0]), + Bandpass(low=[300, 500], high=[3400, 7400]), + Dropout(rate=0.001) + ] + else: + sim = None + + pipelines = {} + + model_resnet = SpeakerVerificationModel( + model=ResNetSE34V2(nOut=512, encoder_type='ASP'), + n_segments=1, + segment_select='lin', + distance_fn='cosine', + threshold=0.0 + ) + model_resnet.load_weights( + MODELS_DIR / 'speaker' / 'resnetse34v2' / 'resnetse34v2.pt') + + model_yvector = SpeakerVerificationModel( + model=YVector(), + n_segments=1, + segment_select='lin', + distance_fn='cosine', + threshold=0.0 + ) + model_yvector.load_weights( + MODELS_DIR / 'speaker' / 'yvector' / 'yvector.pt') + + pipelines['resnet'] = Pipeline( + simulation=sim, + 
preprocessor=Preprocessor(Normalize(method='peak')), + model=model_resnet, + device='cuda' if torch.cuda.is_available() else 'cpu' + ) + + if TRANSFER: + pipelines['yvector'] = Pipeline( + simulation=sim, + preprocessor=Preprocessor(Normalize(method='peak')), + model=model_yvector, + device='cuda' if torch.cuda.is_available() else 'cpu' + ) + else: + del model_yvector + + # prepare to log attack progress + writer = Writer( + root_dir=RUNS_DIR, + name='evaluate-attacks', + use_timestamp=True, + log_iter=300, + use_tb=True + ) + + attacks = {} + + # use consistent adversarial loss + adv_loss = SpeakerEmbeddingLoss( + targeted=False, + confidence=0.1, + threshold=0.0 + ) + + # use consistent auxiliary loss across attacks + aux_loss = SumLoss().add_loss_function( + DemucsMRSTFTLoss(), 1.0 + ).add_loss_function(L1Loss(), 1.0).to('cuda') + + attacks['voicebox'] = VoiceBoxAttack( + pipeline=pipelines['resnet'], + adv_loss=adv_loss, + aux_loss=aux_loss, + lr=1e-4, + epochs=1, + batch_size=BATCH_SIZE, + voicebox_kwargs={ + 'win_length': 256, + 'ppg_encoder_hidden_size': 256, + 'use_phoneme_encoder': True, + 'use_pitch_encoder': True, + 'use_loudness_encoder': True, + 'spec_encoder_lookahead_frames': 0, + 'spec_encoder_type': 'mel', + 'spec_encoder_mlp_depth': 2, + 'bottleneck_lookahead_frames': LOOKAHEAD, + 'ppg_encoder_path': PPG_PRETRAINED_PATH, + 'n_bands': 128, + 'spec_encoder_hidden_size': 512, + 'bottleneck_skip': True, + 'bottleneck_hidden_size': 512, + 'bottleneck_feedforward_size': 512, + 'bottleneck_type': 'lstm', + 'bottleneck_depth': 2, + 'control_eps': 0.5, + 'projection_norm': float('inf'), + 'conditioning_dim': 512 + }, + writer=writer, + checkpoint_name='voicebox-attack' + ) + attacks['voicebox'].load(VOICEBOX_PATH) + + attacks['universal'] = AdvPulseAttack( + pipeline=pipelines['resnet'], + adv_loss=adv_loss, + pgd_norm=float('inf'), + pgd_variant=None, + scale_grad=None, + eps=0.08, + length=2.0, + align='start', + lr=1e-4, + normalize=True, + loop=True, + aux_loss=aux_loss, + epochs=1, + batch_size=BATCH_SIZE, + writer=writer, + checkpoint_name='universal-attack' + ) + attacks['universal'].load(UNIVERSAL_PATH) + + attacks['kenansville'] = KenansvilleAttack( + pipeline=pipelines['resnet'], + batch_size=BATCH_SIZE, + adv_loss=adv_loss, + threshold_db_low=4.0, # fix threshold + threshold_db_high=4.0, + win_length=512, + writer=writer, + step_size=1.0, + search='bisection', + min_success_rate=0.2, + checkpoint_name='kenansville-attack' + ) + + attacks['noise'] = WhiteNoiseAttack( + pipeline=pipelines['resnet'], + adv_loss=adv_loss, + aux_loss=aux_loss, + snr_low=-10.0, # fix threshold + snr_high=-10.0, + writer=writer, + step_size=1, + search='bisection', + min_success_rate=0.2, + checkpoint_name='noise-perturbation' + ) + + return attacks, pipelines, writer + + +@torch.no_grad() +def evaluate_attack(attack: TrainableAttack, + speaker_pipeline: Pipeline, + asr_pipeline: Pipeline): + + if DENOISER: + from src.models.denoiser.demucs import load_demucs + defense = load_demucs('dns_48').to( + 'cuda' if torch.cuda.is_available() else 'cpu') + defense.eval() + else: + defense = nn.Identity() + + # prepare for GPU inference + if torch.cuda.is_available(): + + attack.pipeline.set_device('cuda') + speaker_pipeline.set_device('cuda') + asr_pipeline.set_device('cuda') + attack.perturbation.to('cuda') + + # locate dataset + if EVAL_DATASET == "librispeech": + cache_dir = CACHE_DIR / 'ls_wer_eval' + else: + cache_dir = CACHE_DIR / 'vc_wer_eval' + assert os.path.isdir(cache_dir), \ + f'Dataset 
must be built/cached before evaluation' + + # prepare for PESQ/STOI calculations + all_pesq_scores = [] + all_stoi_scores = [] + + # prepare for WER/CER computations + all_query_transcripts = [] + all_pred_query_transcripts = [] + all_adv_query_transcripts = [] + + # prepare for accuracy computations + all_query_emb = {} + all_adv_query_emb = {} + all_enroll_emb = {} + all_enroll_emb_centroid = {} + + spkr_dirs = list(cache_dir.glob("*/")) + spkr_dirs = [s_d for s_d in spkr_dirs if os.path.isdir(s_d)] + for spkr_dir in tqdm(spkr_dirs, total=len(spkr_dirs), desc='Running evaluation'): + + # identify speaker + spkr_id = spkr_dir.parts[-1] + + # use integer IDs + if EVAL_DATASET != "librispeech": + spkr_id = spkr_id.split("id")[-1] + + # identify speaker recognition model + if isinstance(speaker_pipeline.model.model, ResNetSE34V2): + model_name = 'resnet' + elif isinstance(speaker_pipeline.model.model, YVector): + model_name = 'yvector' + else: + raise ValueError(f'Invalid speaker recognition model') + + # load clean embeddings + query_emb = torch.load(spkr_dir / model_name / 'query_emb.pt') + condition_emb = torch.load(spkr_dir / 'resnet' / 'condition_emb.pt') + enroll_emb = torch.load(spkr_dir / model_name / 'enroll_emb.pt') + + # load clean audio + query_audio = torch.load(spkr_dir / 'query_audio.pt') + + # if defense in use, re-compute query audio + if DENOISER: + query_emb = compute_embeddings_batch( + query_audio, speaker_pipeline, defense=defense + ) + + # load clean transcript + if EVAL_DATASET == "librispeech": + query_transcripts = torch.load(spkr_dir / 'query_trans.pt') + else: + query_transcripts = None + + # compute conditioning embedding centroid + condition_centroid = condition_emb.mean(dim=(0, 1), keepdim=True) + + # compute enrolled embedding centroid + enroll_centroid = enroll_emb.mean(dim=(0, 1), keepdim=True) + + # compute adversarial query audio + adv_query_audio = compute_attack_batch( + query_audio, attack, condition_centroid) + + # compute adversarial query embeddings; optionally, pass through + # unseen denoiser defense + adv_query_emb = compute_embeddings_batch( + adv_query_audio, speaker_pipeline, defense=defense + ) + + if EVAL_DATASET == "librispeech": + + # compute clean predicted transcripts + pred_query_transcripts = compute_transcripts_batch( + query_audio, asr_pipeline + ) + + # compute adversarial transcripts + adv_query_transcripts = compute_transcripts_batch( + adv_query_audio, asr_pipeline + ) + + # compute objective quality metric scores + if COMPUTE_OBJECTIVE_METRICS: + pesq_scores = compute_pesq(query_audio, adv_query_audio) + stoi_scores = compute_stoi(query_audio, adv_query_audio) + else: + pesq_scores = np.zeros(len(query_audio)) + stoi_scores = np.zeros(len(query_audio)) + + # store all objective quality metric scores + all_pesq_scores.extend(pesq_scores) + all_stoi_scores.extend(stoi_scores) + + # store all unit-normalized clean, adversarial, and enrolled centroid + # embeddings + all_query_emb[int(spkr_id)] = F.normalize(query_emb.clone(), dim=-1) + all_adv_query_emb[int(spkr_id)] = F.normalize(adv_query_emb.clone(), dim=-1) + all_enroll_emb[int(spkr_id)] = F.normalize(enroll_emb.clone(), dim=-1) + all_enroll_emb_centroid[int(spkr_id)] = F.normalize(enroll_centroid.clone(), dim=-1) + + # store all transcripts + if EVAL_DATASET == "librispeech": + all_query_transcripts.extend(query_transcripts) + all_pred_query_transcripts.extend(pred_query_transcripts) + all_adv_query_transcripts.extend(adv_query_transcripts) + + # free GPU memory for similarity 
search + attack.pipeline.set_device('cpu') + speaker_pipeline.set_device('cpu') + asr_pipeline.set_device('cpu') + attack.perturbation.to('cpu') + torch.cuda.empty_cache() + + # compute and display final objective quality metrics + print(f"PESQ (mean/std): {np.mean(all_pesq_scores)}/{np.std(all_pesq_scores)}") + print(f"STOI (mean/std): {np.mean(all_stoi_scores)}/{np.std(all_stoi_scores)}") + + if EVAL_DATASET == "librispeech": + + # compute and display final WER/CER metrics + wer, cer = asr_metrics(all_query_transcripts, all_adv_query_transcripts) + print(f"Adversarial WER / CER: {wer} / {cer}") + + wer, cer = asr_metrics(all_query_transcripts, all_pred_query_transcripts) + print(f"Clean WER / CER: {wer} / {cer}") + + else: + wer, cer = None, None + + del (wer, cer, all_pesq_scores, all_stoi_scores, + all_query_transcripts, all_adv_query_transcripts, all_pred_query_transcripts) + + # embedding-space cosine distance calculations + cos_dist_fn = EmbeddingDistance(distance_fn='cosine') + + # mean clean-to-adversarial query embedding distance + total_query_dist = 0.0 + n = 0 + for spkr_id in all_query_emb.keys(): + dist = cos_dist_fn(all_query_emb[spkr_id], + all_adv_query_emb[spkr_id]).mean() + total_query_dist += len(all_query_emb[spkr_id]) * dist.item() + n += len(all_query_emb[spkr_id]) + mean_query_dist = total_query_dist / n + print(f"\n\t\tMean cosine distance between clean and adversarial query " + f"embeddings: {mean_query_dist :0.4f}") + + # mean adversarial-query-to-enrolled-centroid embedding distance + total_centroid_dist = 0.0 + n = 0 + for spkr_id in all_query_emb.keys(): + n_queries = len(all_adv_query_emb[spkr_id]) + dist = 0.0 + for i in range(n_queries): + dist += cos_dist_fn(all_enroll_emb_centroid[spkr_id], + all_adv_query_emb[spkr_id][i:i+1]).item() + total_centroid_dist += dist + n += n_queries + mean_centroid_dist = total_centroid_dist / n + print(f"\t\tMean cosine distance between clean enrolled centroids and " + f"adversarial query embeddings: {mean_centroid_dist :0.4f}") + + # top-1 accuracy for clean queries (closest embedding) + top_1_clean_single = top_k(all_query_emb, all_enroll_emb, k=1) + + # top-1 accuracy for clean queries (centroid embedding) + top_1_clean_centroid = top_k(all_query_emb, all_enroll_emb_centroid, k=1) + + # top-10 accuracy for clean queries (closest embedding) + top_10_clean_single = top_k(all_query_emb, all_enroll_emb, k=10) + + # top-10 accuracy for clean queries (centroid embedding) + top_10_clean_centroid = top_k(all_query_emb, all_enroll_emb_centroid, k=10) + + # top-1 accuracy for adversarial queries (closest embedding) + top_1_adv_single = top_k(all_adv_query_emb, all_enroll_emb, k=1) + + # top-1 accuracy for adversarial queries (centroid embedding) + top_1_adv_centroid = top_k(all_adv_query_emb, all_enroll_emb_centroid, k=1) + + # top-10 accuracy for adversarial queries (closest embedding) + top_10_adv_single = top_k(all_adv_query_emb, all_enroll_emb, k=10) + + # top-10 accuracy for adversarial queries (centroid embedding) + top_10_adv_centroid = top_k(all_adv_query_emb, all_enroll_emb_centroid, k=10) + + print(f"\n\t\tTop-1 accuracy (clean embedding / nearest enrolled embedding) {top_1_clean_single :0.4f}", + f"\n\t\tTop-1 accuracy (clean embedding / nearest enrolled centroid) {top_1_clean_centroid :0.4f}", + f"\n\t\tTop-10 accuracy (clean embedding / nearest enrolled embedding) {top_10_clean_single :0.4f}" + f"\n\t\tTop-10 accuracy (clean embedding / nearest enrolled centroid) {top_10_clean_centroid :0.4f}", + f"\n\t\tTop-1 
accuracy (adversarial embedding / nearest enrolled embedding {top_1_adv_single :0.4f}", + f"\n\t\tTop-1 accuracy (adversarial embedding / nearest enrolled centroid) {top_1_adv_centroid :0.4f}", + f"\n\t\tTop-10 accuracy (adversarial embedding / nearest enrolled embedding {top_10_adv_single :0.4f}", + f"\n\t\tTop-10 accuracy (adversarial embedding / nearest enrolled centroid) {top_10_adv_centroid :0.4f}" + ) + + +@torch.no_grad() +def evaluate_attacks(attacks: dict, + speaker_pipelines: dict, + asr_pipeline: Pipeline): + + for attack_name, attack in attacks.items(): + for sp_name, sp in speaker_pipelines.items(): + print(f'Evaluating {attack_name} against model {sp_name} ' + f'{"with" if DENOISER else "without"} denoiser defense') + evaluate_attack(attack, sp, asr_pipeline) + + +def main(): + + # initial random seed (keep dataset order consistent) + set_random_seed(0) + + # initialize pipelines + attacks, pipelines, writer = init_attacks() + + # ensure that necessary data is cached + if EVAL_DATASET == "librispeech": + build_ls_dataset(pipelines) + else: + build_vc_dataset(pipelines) + + # initialize ASR model + asr_model = SpeechRecognitionModel( + model=Wav2Vec2(), + ) + asr_pipeline = Pipeline( + model=asr_model, + preprocessor=Preprocessor(Normalize(method='peak')), + device='cuda' if torch.cuda.is_available() else 'cpu' + ) + + writer.log_cuda_memory() + + evaluate_attacks(attacks, pipelines, asr_pipeline) + + +if __name__ == "__main__": + main() + diff --git a/voicebox/scripts/experiments/train.py b/voicebox/scripts/experiments/train.py new file mode 100755 index 0000000000000000000000000000000000000000..250c0d6c4ad1c2d02171a5e2db0aee49743c3577 --- /dev/null +++ b/voicebox/scripts/experiments/train.py @@ -0,0 +1,282 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchaudio +import psutil +import pickle +import librosa as li + +from torch.utils.data import TensorDataset + +import time +import random +import argparse +from datetime import datetime + +import numpy as np +import pandas as pd +from typing import Dict +from pathlib import Path +from tqdm import tqdm +import builtins + +from typing import Iterable +from copy import deepcopy + +from distutils.util import strtobool + +from src.data import * +from src.constants import * +from src.models import * +from src.simulation import * +from src.preprocess import * +from src.attacks.offline import * +from src.loss import * +from src.pipelines import * +from src.utils import * + +################################################################################ +# Train VoiceBox attack +################################################################################ + +BATCH_SIZE = 20 # training batch size +EPOCHS = 10 # training epochs +TARGET_PCTL = 25 # de-identification strength; in [1,5,10,15,20,25,50,90,100] +N_EMBEDDINGS_TRAIN = 15 +TARGETED = False +TARGETS_TRAIN = 'centroid' # 'random', 'same', 'single', 'median' +TARGETS_TEST = 'centroid' # 'random', 'same', 'single', 'median' + +# distributions of inter- ('targeted') and intra- ('untargeted') speaker +# distances in each pre-trained model's embedding spaces, as measured between +# individual utterances and their speaker centroid ('single-centroid') or +# between all pairs of individual utterances ('single-single') over the +# LibriSpeech test-clean dataset. 
This allows specification of attack strength +# during the training process +percentiles = { + 'resnet': { + 'targeted': { + 'single-centroid': {1:.495, 5:.572, 10:.617, 15:.648, 20:.673, 25:.695, 50:.773, 90:.892, 100:1.127}, + 'single-single': {1:.560, 5:.630, 10:.672, 15:.700, 20:.722, 25:.742, 50:.813, 90:.924, 100:1.194} + }, + 'untargeted': { + 'single-centroid': {1:.099, 5:.117, 10:.126, 15:.133, 20:.139, 25:.145, 50:.170, 90:.253, 100:.587}, + 'single-single': {1:.181, 5:.215, 10:.235, 15:.249, 20:.262, 25:.272, 50:.323, 90:.464, 100:.817} + }, + }, + 'yvector': { + 'targeted': { + 'single-centroid': {1:.665, 5:.757, 10:.801, 15:.830, 20:.851, 25:.868, 50:.936, 90:1.056, 100:1.312}, + 'single-single': {1:.695, 5:.779, 10:.821, 15:.847, 20:.868, 25:.885, 50:.952, 90:1.072, 100:1.428} + }, + 'untargeted': { + 'single-single': {1:.218, 5:.268, 10:.301, 15:.325, 20:.345, 25:.365, 50:.455, 90:.684, 100:1.156}, + 'single-centroid': {1:.114, 5:.143, 10:.159, 15:.170, 20:.180, 25:.190, 50:.242, 90:.413, 100:.874} + } + }, +} + + +def set_random_seed(seed: int = 123): + """Set random seed to allow for reproducibility""" + random.seed(seed) + torch.manual_seed(seed) + + if torch.backends.cudnn.is_available(): + # torch.backends.cudnn.benchmark = True + torch.backends.cudnn.deterministic = True + + +def param_count(m: nn.Module, trainable: bool = False): + """Count the number of trainable parameters (weights) in a model""" + if trainable: + return builtins.sum( + [p.shape.numel() for p in m.parameters() if p.requires_grad]) + else: + return builtins.sum([p.shape.numel() for p in m.parameters()]) + + +def main(): + + set_random_seed(0) + + model = SpeakerVerificationModel( + model=ResNetSE34V2(nOut=512, encoder_type='ASP'), + n_segments=1, + segment_select='lin', + distance_fn='cosine', + threshold=percentiles['resnet']['targeted']['single-centroid' if + TARGETS_TRAIN == 'centroid' else 'single-single'][TARGET_PCTL] + ) + model.load_weights(MODELS_DIR / 'speaker' / 'resnetse34v2' / 'resnetse34v2.pt') + + # instantiate training pipeline + pipeline = Pipeline( + simulation=None, + preprocessor=Preprocessor(Normalize(method='peak')), + model=model, + device='cuda' if torch.cuda.is_available() else 'cpu' + ) + + attacks = {} + + # log training progress + writer = Writer( + root_dir=RUNS_DIR, + name='train-attacks', + use_timestamp=True, + log_iter=300, + use_tb=True + ) + + # adversarial training loss + adv_loss = SpeakerEmbeddingLoss( + targeted=TARGETED, + confidence=0.1, + threshold=pipeline.model.threshold + ) + + # auxiliary loss + aux_loss = SumLoss().add_loss_function( + DemucsMRSTFTLoss(), 1.0 + ).add_loss_function(L1Loss(), 1.0).to('cuda') + + # speech features loss actually seems to do better... 
+ # aux_loss = SumLoss().add_loss_function(SpeechFeatureLoss(), 1e-6).to('cuda') + + attacks['voicebox'] = VoiceBoxAttack( + pipeline=pipeline, + adv_loss=adv_loss, + aux_loss=aux_loss, + lr=1e-4, + epochs=EPOCHS, + batch_size=BATCH_SIZE, + voicebox_kwargs={ + 'win_length': 256, + 'ppg_encoder_hidden_size': 256, + 'use_phoneme_encoder': True, + 'use_pitch_encoder': True, + 'use_loudness_encoder': True, + 'spec_encoder_lookahead_frames': 0, + 'spec_encoder_type': 'mel', + 'spec_encoder_mlp_depth': 2, + 'bottleneck_lookahead_frames': 5, + 'ppg_encoder_path': PPG_PRETRAINED_PATH, + 'n_bands': 128, + 'spec_encoder_hidden_size': 512, + 'bottleneck_skip': True, + 'bottleneck_hidden_size': 512, + 'bottleneck_feedforward_size': 512, + 'bottleneck_type': 'lstm', + 'bottleneck_depth': 2, + 'control_eps': 0.5, + 'projection_norm': float('inf'), + 'conditioning_dim': 512 + }, + writer=writer, + checkpoint_name='voicebox-attack' + ) + + attacks['universal'] = AdvPulseAttack( + pipeline=pipeline, + adv_loss=adv_loss, + pgd_norm=float('inf'), + pgd_variant=None, + scale_grad=None, + eps=0.08, + length=2.0, + align='random', # 'start', + lr=1e-4, + normalize=True, + loop=True, + aux_loss=aux_loss, + epochs=EPOCHS, + batch_size=BATCH_SIZE, + writer=writer, + checkpoint_name='universal-attack' + ) + + if torch.cuda.is_available(): + + # prepare for multi-GPU training + device_ids = get_cuda_device_ids() + + # wrap pipeline for multi-GPU training + pipeline = wrap_pipeline_multi_gpu(pipeline, device_ids) + + # load training and validation datasets. Features will be computed and + # cached to disk, which may take some time + data_train = LibriSpeechDataset( + split='train-clean-100', features=['pitch', 'periodicity', 'loudness']) + data_test = LibriSpeechDataset( + split='test-clean', features=['pitch', 'periodicity', 'loudness']) + + # reassign targets if necessary + compiled_train, compiled_test = create_embedding_dataset( + pipeline=pipeline, + select_train=TARGETS_TRAIN, + select_test=TARGETS_TEST, + data_train=data_train, + data_test=data_test, + targeted=TARGETED, + target_class=None, + num_embeddings_train=N_EMBEDDINGS_TRAIN, + batch_size=20 + ) + + # extract embedding datasets + data_train = compiled_train['dataset'] + data_test = compiled_test['dataset'] + + # log memory use prior to training + writer.log_info(f'Training data ready; memory use: ' + f'{psutil.virtual_memory().percent :0.3f}%') + writer.log_cuda_memory() + + for attack_name, attack in attacks.items(): + + writer.log_info(f'Preparing {attack_name}...') + + if torch.cuda.is_available(): + + attack.perturbation.to('cuda') + attack.pipeline.to('cuda') + + # wrap attack for multi-GPU training + attack = wrap_attack_multi_gpu(attack, device_ids) + + # evaluate performance + with torch.no_grad(): + x_example = next(iter(data_train))['x'].to(pipeline.device) + st = time.time() + outs = attack.perturbation(x_example) + dur = time.time() - st + + writer.log_info( + f'Processing time per input (device: ' + f'{pipeline.device}): {dur/x_example.shape[0] :0.4f} (s)' + ) + writer.log_info(f'Trainable parameters: ' + f'{param_count(attack.perturbation, trainable=True)}') + writer.log_info(f'Total parameters: {param_count(attack.perturbation, trainable=False)}') + + # train + writer.log_info('Training attack...') + attack.train(data_train=data_train, data_val=data_test) + + # evaluate + writer.log_info(f'Evaluating attack...') + x_adv, success, detection = attack.evaluate( + dataset=data_test + ) + + # log results summary: success rate in achieving 
target threshold + writer.log_info( + f'Success rate in meeting embedding distance threshold {pipeline.model.threshold}' + f' ({TARGET_PCTL}%): ' + f'{success.flatten().mean().item()}' + ) + + +if __name__ == "__main__": + main() diff --git a/voicebox/scripts/experiments/train_phoneme_predictor.py b/voicebox/scripts/experiments/train_phoneme_predictor.py new file mode 100755 index 0000000000000000000000000000000000000000..729d50106f1da326ed092cf15fcfc31c100590d4 --- /dev/null +++ b/voicebox/scripts/experiments/train_phoneme_predictor.py @@ -0,0 +1,205 @@ +import torch +import torch.nn as nn +from torch.utils.data import DataLoader + +from src.models.phoneme import PPGEncoder +from src.constants import LIBRISPEECH_NUM_PHONEMES, LIBRISPEECH_PHONEME_DICT +from src.data import LibriSpeechDataset +from src.utils.writer import Writer + +import numpy as np +from sklearn.metrics import confusion_matrix, classification_report +import seaborn as sn +import pandas as pd +import matplotlib.pyplot as plt + +################################################################################ +# Train a simple model to produce phonetic posteriorgrams (PPGs) +################################################################################ + + +def main(): + + # training hyperparameters + lr = .001 + epochs = 60 + batch_size = 250 + device = 'cuda' if torch.cuda.is_available() else 'cpu' + + # phoneme encoder hyperparameters + lstm_depth = 2 + hidden_size = 128 # 512 + win_length = 256 + hop_length = 128 + n_mels = 32 + n_mfcc = 19 + lookahead_frames = 0 # 1 + + # datasets and loaders + train_data = LibriSpeechDataset( + split='train-clean-100', + target='phoneme', + features=None, + hop_length=hop_length + ) + val_data = LibriSpeechDataset( + split='test-clean', + target='phoneme', + features=None, + hop_length=hop_length + ) + train_loader = DataLoader( + train_data, + batch_size=batch_size, + shuffle=True) + val_loader = DataLoader( + val_data, + batch_size=batch_size) + + # initialize phoneme encoder + encoder = PPGEncoder( + win_length=win_length, + hop_length=hop_length, + win_func=torch.hann_window, + n_mels=n_mels, + n_mfcc=n_mfcc, + lstm_depth=lstm_depth, + hidden_size=hidden_size, + ) + + # initialize classification layer and wrap as single module + classifier = nn.Sequential( + encoder, + nn.Linear(hidden_size, LIBRISPEECH_NUM_PHONEMES) + ).to(device) + + # log training progress + writer = Writer( + name=f"phoneme_lookahead_{lookahead_frames}", + use_tb=True, + log_iter=len(train_loader) + ) + + import builtins + parameter_count = builtins.sum([ + p.shape.numel() + for p in classifier[0].parameters() + if p.requires_grad + ]) + + writer.log_info(f'Training PPG model with lookahead {lookahead_frames}' + f' ({parameter_count} parameters)') + + # initialize optimizer and loss function + optimizer = torch.optim.Adam(classifier.parameters(), lr=lr) + loss_fn = nn.CrossEntropyLoss() + + iter_id = 0 + min_val_loss = float('inf') + + for epoch in range(epochs): + + print(f'beginning epoch {epoch}') + + classifier.train() + for batch in train_loader: + + optimizer.zero_grad(set_to_none=True) + + x, y = batch['x'].to(device), batch['y'].to(device) + + preds = classifier(x) + + # offset labels to incorporate lookahead + y = y[:, :-lookahead_frames if lookahead_frames else None] + + # offset predictions correspondingly + preds = preds[:, lookahead_frames:] + + # compute cross-entropy loss + loss = loss_fn( + preds.reshape(-1, LIBRISPEECH_NUM_PHONEMES), y.reshape(-1) + ) + + loss.backward() + optimizer.step() + 
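+ # note: with lookahead_frames = L, the slicing above pairs preds[:, L:]
+ # with y[:, :-L], so the prediction scored against the frame-t label is
+ # emitted only after the encoder has processed frames t through t + L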
+ writer.log_scalar(loss, tag="CrossEntropyLoss-Train", global_step=iter_id) + iter_id += 1 + + val_loss, val_acc, n = 0.0, 0.0, 0 + classifier.eval() + with torch.no_grad(): + for batch in val_loader: + + x, y = batch['x'].to(device), batch['y'].to(device) + + preds = classifier(x) + + # offset labels to incorporate lookahead + y = y[:, :-lookahead_frames if lookahead_frames else None] + + # offset predictions correspondingly + preds = preds[:, lookahead_frames:] + + n += len(x) + val_loss += loss_fn( + preds.reshape(-1, LIBRISPEECH_NUM_PHONEMES), y.reshape(-1) + ) * len(x) + val_acc += len(x) * (torch.argmax(preds, dim=2) == y).flatten().float().mean() + + val_loss /= n + val_acc /= n + writer.log_scalar(val_loss, tag="CrossEntropyLoss-Val", global_step=iter_id) + writer.log_scalar(val_acc, tag="Accuracy-Val") + + # save weights + if val_loss < min_val_loss: + min_val_loss = val_loss + print(f'new best val loss {val_loss}; saving weights') + writer.checkpoint(classifier[0].state_dict(), 'phoneme_classifier') + + # generate confusion matrix + classifier.eval() + + # compute accuracy on validation data + all_preds = [] + all_true = [] + with torch.no_grad(): + for batch in val_loader: + + x, y = batch['x'].to(device), batch['y'].to(device) + + preds = classifier(x) + + # offset labels to incorporate lookahead + y = y[:, :-lookahead_frames if lookahead_frames else None] + + # offset predictions correspondingly + preds = preds[:, lookahead_frames:] + + all_preds.append(preds.argmax(dim=2).reshape(-1)) + all_true.append(y.reshape(-1)) + + # compile predictions and targets + all_preds = torch.cat(all_preds, dim=0).cpu().numpy() + all_true = torch.cat(all_true, dim=0).cpu().numpy() + + reverse_dict = {v: k for (k, v) in LIBRISPEECH_PHONEME_DICT.items() if v != 0} + reverse_dict[0] = 'sil' + + class_report = classification_report(all_true, all_preds) + writer.log_info(class_report) + + cm = confusion_matrix(all_true, all_preds, labels=list(range(len(reverse_dict)))) + df_cm = pd.DataFrame(cm, index=[i for i in sorted(list(reverse_dict.keys()))], + columns=[i for i in sorted(list(reverse_dict.keys()))]) + plt.figure(figsize=(40, 28)) + sn.set(font_scale=1.0) # for label size + sn.heatmap(df_cm, annot=True, annot_kws={"size": 35 / np.sqrt(len(cm))}, fmt='g') + + plt.savefig("phoneme_cm.png", dpi=200) + + +if __name__ == '__main__': + main() diff --git a/voicebox/scripts/streamer/benchmark_streamer.py b/voicebox/scripts/streamer/benchmark_streamer.py new file mode 100755 index 0000000000000000000000000000000000000000..4d6371a5ace4ac1dab6cae62a28fcc343154263a --- /dev/null +++ b/voicebox/scripts/streamer/benchmark_streamer.py @@ -0,0 +1,97 @@ +import torch +import librosa +import soundfile as sf + +from tqdm import tqdm +from src.attacks.offline.perturbation.voicebox import projection +from src.attacks.online import Streamer, VoiceBoxStreamer +from src.models import ResNetSE34V2, SpeakerVerificationModel +from src.constants import MODELS_DIR, TEST_DIR, PPG_PRETRAINED_PATH + +import warnings +warnings.filterwarnings("ignore") + +torch.set_num_threads(1) + +device = 'cpu' + +lookahead = 5 + +signal_length = 64_000 +chunk_size = 640 + +test_audio = torch.Tensor( + librosa.load(TEST_DIR / 'data' / 'test.wav', sr=16_000, mono=True)[0] +).unsqueeze(0).unsqueeze(0) + +tests = [ + (512, 512, 512) +] +resnet_model = SpeakerVerificationModel(model=ResNetSE34V2()) +condition_vector = resnet_model(test_audio) +for (bottleneck_hidden_size, + bottleneck_feedforward_size, + spec_encoder_hidden_size) in tests: + 
print( +f""" +==================================== +bottleneck_hidden_size: {bottleneck_hidden_size} +bottleneck_feedforward_size: {bottleneck_feedforward_size} +spec_encoder_hidden_size: {spec_encoder_hidden_size} +""" + ) + + streamer = Streamer( + VoiceBoxStreamer( + win_length=256, + bottleneck_type='lstm', + bottleneck_skip=True, + bottleneck_depth=2, + bottleneck_lookahead_frames=5, + bottleneck_hidden_size=bottleneck_hidden_size, + bottleneck_feedforward_size=bottleneck_feedforward_size, + + conditioning_dim=512, + + spec_encoder_mlp_depth=2, + spec_encoder_hidden_size=spec_encoder_hidden_size, + spec_encoder_lookahead_frames=0, + ppg_encoder_path=PPG_PRETRAINED_PATH, + + ppg_encoder_depth=2, + ppg_encoder_hidden_size=256, + projection_norm='inf', + control_eps=0.5, + n_bands=128 + ), + device, + hop_length=128, + window_length=256, + win_type='hann', + lookahead_frames=lookahead, + recurrent=True + ) + streamer.model.load_state_dict(torch.load(MODELS_DIR / 'voicebox' / 'voicebox_final.pt')) + streamer.condition_vector = condition_vector + + output_chunks = [] + for i in tqdm(range(0, signal_length, chunk_size)): + signal_chunk = test_audio[..., i:i+chunk_size] + out = streamer.feed(signal_chunk) + output_chunks.append(out) + output_chunks.append(streamer.flush()) + output_audio = torch.cat(output_chunks, dim=-1) + output_embedding = resnet_model(output_audio) + + print( +f""" +RTF: {streamer.real_time_factor} +Embedding Distance: {resnet_model.distance_fn(output_embedding, condition_vector)} +==================================== +""" + ) + sf.write( + 'output.wav', + output_audio.numpy().squeeze(), + 16_000, + ) diff --git a/voicebox/scripts/streamer/enroll.py b/voicebox/scripts/streamer/enroll.py new file mode 100755 index 0000000000000000000000000000000000000000..6e8714591e5f877f886920c33751ff47b20078d6 --- /dev/null +++ b/voicebox/scripts/streamer/enroll.py @@ -0,0 +1,105 @@ +""" +Pipeline for enrolling: +1. Provide Recording +2. Convert to 16 kHz +3. Divide into recordings +4. Get embeddings for each recording +5. Find centroid +6. Save conditioning as some value. +""" +import os +import argbind +import sounddevice as sd +import soundfile +import torch +import numpy as np + +import sys + +sys.path.append('.') + +from src.constants import CONDITIONING_FILENAME, CONDITIONING_FOLDER +from src.data import DataProperties +from src.models import ResNetSE34V2 + + +MIN_WINDOWS = 10 +WINDOW_SIZE = 64_000 +BLOCK_SIZE = 256 + +RECORDING_TEXT = """ +This script will record you speaking, and will create an embedding +to be used for conditioning Voicebox. This will overwrite any previous +embeddings. We recommend at least 10 seconds of non-stop voice recording. +Press enter to begin recording. To stop recording, press ctrl-C. 
+""" + + +def get_streams(input_name: str, block_size: int) -> sd.InputStream: + """ + Gets Input stream object + """ + try: + input_name = int(input_name) + except ValueError: + pass + return ( + sd.InputStream(device=input_name, + samplerate=DataProperties.get('sample_rate'), + channels=1, + blocksize=block_size) + ) + + +def record_from_user(input_name: str) -> torch.Tensor: + input_stream = get_streams(input_name, BLOCK_SIZE) + input(RECORDING_TEXT) + input_stream.start() + all_frames = [] + try: + print("Recording...") + while True: + frames, _ = input_stream.read(BLOCK_SIZE) + all_frames.append(frames) + except KeyboardInterrupt: + print("Stopped Recording.") + pass + all_frames = torch.Tensor(np.array(all_frames)) + recording = all_frames.reshape(-1) + return recording + + +def get_embedding(recording) -> torch.Tensor: + model = ResNetSE34V2(nOut=512, encoder_type='ASP') + recording = recording.view(1, -1) + embedding = model(recording) + return embedding + + +def save(embedding, audio) -> None: + os.makedirs(CONDITIONING_FOLDER, exist_ok=True) + torch.save(embedding, CONDITIONING_FILENAME) + soundfile.write( + CONDITIONING_FOLDER / 'conditioning_audio.wav', + audio.detach().cpu(), + DataProperties.get('sample_rate') + ) + + +@argbind.bind(positional=True, without_prefix=True) +def main(input: str = None): + """ + Creating a conditioning vector for VoiceBox from your voice + + :param input: Index or name of input audio interface. Defaults to current device + :type input: str, optional + """ + recording = record_from_user(input) + embedding = get_embedding(recording) + save(embedding, recording) + + +if __name__ == "__main__": + args = argbind.parse_args() + with argbind.scope(args): + main() diff --git a/voicebox/scripts/streamer/stream.py b/voicebox/scripts/streamer/stream.py new file mode 100755 index 0000000000000000000000000000000000000000..3e7196bf95259e9994ea73ca2a8b60252e055693 --- /dev/null +++ b/voicebox/scripts/streamer/stream.py @@ -0,0 +1,135 @@ +import argbind +import sounddevice as sd +import numpy as np +import yaml +import torch +import os +from typing import Union + +import sys +import warnings + +sys.path.append('.') +warnings.filterwarnings('ignore', category=UserWarning) + +from src.data.dataproperties import DataProperties +from src.attacks.online import Streamer, VoiceBoxStreamer +from src.constants import MODELS_DIR, CONDITIONING_FILENAME + + +def get_streams(input_name: str, output_name: str, block_size: int) -> tuple[sd.InputStream, sd.OutputStream]: + """ + Gets Input and Output stream objects + """ + try: + input_name = int(input_name) + except ValueError: + pass + try: + output_name = int(output_name) + except ValueError: + pass + return ( + sd.InputStream(device=input_name, + samplerate=DataProperties.get('sample_rate'), + channels=1, + blocksize=block_size), + sd.OutputStream(device=output_name, + samplerate=DataProperties.get('sample_rate'), + channels=1, + blocksize=block_size) + ) + + +def get_model_streamer(device: str, conditioning_path: str) -> Streamer: + # TODO: Make a good way to query an attack type. For now, I'm going to hard code this. 
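+ # note: the conditioning tensor at `conditioning_path` is a speaker
+ # embedding saved by scripts/streamer/enroll.py (see CONDITIONING_FILENAME)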
+ model_dir = os.path.join(MODELS_DIR, 'voicebox') + checkpoint_path = os.path.join(model_dir, 'voicebox_final.pt') + config_path = os.path.join(model_dir, 'voicebox_final.yaml') + + with open(config_path) as f: + config = yaml.safe_load(f) + + state_dict = torch.load(checkpoint_path, map_location=device) + condition_tensor = torch.load(conditioning_path, map_location=device) + model = VoiceBoxStreamer( + **config + ) + model.load_state_dict(state_dict) + model.condition_vector = condition_tensor.reshape(1, 1, -1) + + streamer = Streamer( + model=model, + device=device, + lookahead_frames=config['bottleneck_lookahead_frames'], + recurrent=True + ) + return streamer + + +def to_model(x: np.ndarray, device: str) -> torch.Tensor: + return torch.Tensor(x).view(1, 1, -1).to(device) + + +def from_model(x: torch.Tensor) -> np.ndarray: + return x.detach().cpu().view(-1, 1).numpy() + + +@argbind.bind(without_prefix=True) +def main( + input: str = None, + output: str = '', + device: str = 'cpu', + num_frames: int = 4, + pass_through: bool = False, + conditioning_path: str = CONDITIONING_FILENAME +): + f""" + Uses a streaming implementation of an attack to perturb incoming audio + + :param input: Index or name of input audio interface. Defaults to current device + :type input: str, optional + :param output: Index of name output audio interface. Defaults to 0 + :type output: str, optional + :param device: Device to processing attack. Should be either 'cpu' or 'cuda:X' + Defaults to 'cpu'. + :type device: str, optional + :param pass_through: If True, the voicebox perturbation is not applied and the input will be + identical to the output. This is for demo purposes. The input and output audio will + remain at 16 kHz. + :type pass_through: bool, optional + :type device: str, optional + :param num_frames: Number of overlapping model frames to process at one iteration. + Defaults to 1 + :type num_frames: int + :param conditioning_path: Path to conditioning tensor. 
Default: {CONDITIONING_FILENAME} + :type conditioning_path: str + """ + streamer = get_model_streamer(device, conditioning_path) + input_stream, output_stream = get_streams(input, output, streamer.hop_length) + if streamer.win_type in ['hann', 'triangular']: + input_samples = (num_frames - 1) * streamer.hop_length + streamer.window_length + else: + input_samples = streamer.hop_length + print("Ready to process audio") + input_stream.start() + output_stream.start() + try: + while True: + frames, overflow = input_stream.read(input_samples) + if pass_through: + output_stream.write(frames) + continue + out = streamer.feed(to_model(frames, device)) + out = from_model(out) + underflow = output_stream.write(out) + except KeyboardInterrupt: + print("Stopping") + input_stream.stop() + output_stream.stop() + + +if __name__ == "__main__": + args = argbind.parse_args() + with argbind.scope(args): + main() diff --git a/voicebox/setup.py b/voicebox/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..2204312c7fce0359190eceff917f5e3e5273cfeb --- /dev/null +++ b/voicebox/setup.py @@ -0,0 +1,20 @@ +from setuptools import setup + +with open('README.md') as file: + long_description = file.read() + +setup( + name='src', + description='Code for VoiceBox', + version='0.0.1', + author='', + author_email='', + url='', + install_requires=[], + packages=['src'], + long_description=long_description, + long_description_content_type='text.markdown', + keywords=[], + classifiers=['License :: OSI Approved :: MIT License'], + license='MIT' +) diff --git a/voicebox/src.egg-info/PKG-INFO b/voicebox/src.egg-info/PKG-INFO new file mode 100644 index 0000000000000000000000000000000000000000..3daeba61ee10f703a2099622f55839490770fa4c --- /dev/null +++ b/voicebox/src.egg-info/PKG-INFO @@ -0,0 +1,148 @@ +Metadata-Version: 2.1 +Name: src +Version: 0.0.1 +Summary: Code for VoiceBox +Home-page: +Author: +Author-email: +License: MIT +Classifier: License :: OSI Approved :: MIT License +Description-Content-Type: text.markdown +License-File: LICENSE + +

# VoiceBlock

Privacy through Real-Time Adversarial Attacks with Audio-to-Audio Models
+ +[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/???/???.ipynb) +[![Demo](https://img.shields.io/badge/Web-Demo-blue)](https://master.d3hvhbnf7qxjtf.amplifyapp.com/) +[![MIT license](https://img.shields.io/badge/License-MIT-blue.svg)](/LICENSE) + +
+

+ + +## Contents + +* Installation +* Reproducing Results +* Streaming Implementation +* Citation + +

## Installation

+ +1. Clone the repository: + + git clone https://github.com/voiceboxneurips/voicebox.git + +2. We recommend working from a clean environment, e.g. using `conda`: + + conda create --name voicebox python=3.9 + source activate voicebox + +3. Install dependencies: + + cd voicebox + pip install -r requirements.txt + pip install -e . + +4. Grant permissions: + + chmod -R u+x scripts/ + +
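To verify the editable install, the following quick check (a hypothetical snippet, not part of the repository; version pins follow `requirements.txt`) should run without errors:

```
# quick sanity check for the editable install (hypothetical, not in the repo)
import torch
import torchaudio
import src  # the package installed by `pip install -e .`

print(torch.__version__)       # expected: 1.10.0 (pinned in requirements.txt)
print(torchaudio.__version__)  # expected: 0.10.0
```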

## Reproducing Results

+ +To reproduce our results, first download the corresponding data. Note that to download the [VoxCeleb1 dataset](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1.html), you must register and obtain a username and password. + +| Task | Dataset (Size) | Command | +|---|---|---| +| Objective evaluation | VoxCeleb1 (39G) | `python scripts/downloads/download_voxceleb.py --subset=1 --username= --password=` | +| WER / supplemental evaluations | LibriSpeech `train-clean-360` (23G) | `./scripts/downloads/download_librispeech_eval.sh` | +| Train attacks | LibriSpeech `train-clean-100` (11G) | `./scripts/downloads/download_librispeech_train.sh` | + + +We provide scripts to reproduce our experiments and save results, including generated audio, to named and time-stamped subdirectories within `runs/`. To reproduce our objective evaluation experiments using pre-trained attacks, run: + +``` +python scripts/experiments/evaluate.py +``` + +To reproduce our training, run: + +``` +python scripts/experiments/train.py +``` + +
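Both experiment scripts are configured through module-level constants rather than command-line flags, and the evaluation script asserts that the required dataset cache already exists under `CACHE_DIR`. For example, the evaluation dataset is selected by the `EVAL_DATASET` constant in `scripts/experiments/evaluate.py`; the simplified excerpt below mirrors the logic in that script:

```
# simplified excerpt from scripts/experiments/evaluate.py
if EVAL_DATASET == "librispeech":
    cache_dir = CACHE_DIR / 'ls_wer_eval'  # WER/CER evaluation on LibriSpeech
else:
    cache_dir = CACHE_DIR / 'vc_wer_eval'  # objective evaluation on VoxCeleb
```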

## Streaming Implementation

+ +As a proof of concept, we provide a streaming implementation of VoiceBox capable of modifying user audio in real-time. Here, we provide installation instructions for MacOS and Ubuntu 20.04. + +
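Under the hood, the streamer consumes fixed-size chunks of audio and emits perturbed audio with a short model-imposed delay. The sketch below is condensed from `scripts/streamer/benchmark_streamer.py`; the input file name and chunk size are illustrative, and it assumes a conditioning embedding has already been created with `scripts/streamer/enroll.py`:

```
import torch
import librosa

from src.attacks.online import Streamer, VoiceBoxStreamer
from src.constants import MODELS_DIR, PPG_PRETRAINED_PATH, CONDITIONING_FILENAME

# load a 16 kHz mono signal as a (batch, channel, time) tensor
audio = torch.Tensor(
    librosa.load('example.wav', sr=16_000, mono=True)[0]
).view(1, 1, -1)

# configure the streaming model as in benchmark_streamer.py
model = VoiceBoxStreamer(
    win_length=256,
    bottleneck_type='lstm', bottleneck_skip=True, bottleneck_depth=2,
    bottleneck_lookahead_frames=5, bottleneck_hidden_size=512,
    bottleneck_feedforward_size=512, conditioning_dim=512,
    spec_encoder_mlp_depth=2, spec_encoder_hidden_size=512,
    spec_encoder_lookahead_frames=0, ppg_encoder_path=PPG_PRETRAINED_PATH,
    ppg_encoder_depth=2, ppg_encoder_hidden_size=256,
    projection_norm='inf', control_eps=0.5, n_bands=128
)
model.load_state_dict(
    torch.load(MODELS_DIR / 'voicebox' / 'voicebox_final.pt', map_location='cpu')
)

streamer = Streamer(
    model, 'cpu',
    hop_length=128, window_length=256, win_type='hann',
    lookahead_frames=5, recurrent=True
)

# speaker conditioning embedding, e.g. as saved by scripts/streamer/enroll.py
streamer.condition_vector = torch.load(
    CONDITIONING_FILENAME, map_location='cpu').reshape(1, 1, -1)

# feed fixed-size chunks, then flush frames held back by the lookahead
chunk_size = 640  # 40 ms at 16 kHz
chunks = [streamer.feed(audio[..., i:i + chunk_size])
          for i in range(0, audio.shape[-1], chunk_size)]
chunks.append(streamer.flush())
perturbed = torch.cat(chunks, dim=-1)
```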

### MacOS

See the video below:

[Video: MacOS streaming setup walkthrough]
+ +

### Ubuntu 20.04

1. Open a terminal and follow the [installation instructions](#install) above. Change directory to the root of this repository.

2. Run the following command:

       pacmd load-module module-null-sink sink_name=voicebox sink_properties=device.description=voicebox

   If you are using PipeWire instead of PulseAudio:

       pactl load-module module-null-sink media.class=Audio/Sink sink_name=voicebox sink_properties=device.description=voicebox

   PulseAudio is the default on Ubuntu. If you haven't changed your system defaults, you are probably using PulseAudio. This will add "voicebox" as an output device. Select it as the input to your chosen audio software.

3. Find which audio devices to read from and write to. In your conda environment, run (a programmatic alternative is sketched just after this list):

       python -m sounddevice

   You will get output similar to this:

       0 HDA Intel HDMI: 0 (hw:0,3), ALSA (0 in, 8 out)
       1 HDA Intel HDMI: 1 (hw:0,7), ALSA (0 in, 8 out)
       2 HDA Intel HDMI: 2 (hw:0,8), ALSA (0 in, 8 out)
       3 HDA Intel HDMI: 3 (hw:0,9), ALSA (0 in, 8 out)
       4 HDA Intel HDMI: 4 (hw:0,10), ALSA (0 in, 8 out)
       5 hdmi, ALSA (0 in, 8 out)
       6 jack, ALSA (2 in, 2 out)
       7 pipewire, ALSA (64 in, 64 out)
       8 pulse, ALSA (32 in, 32 out)
     * 9 default, ALSA (32 in, 32 out)

   In this example, we are going to route the audio through PipeWire (device 7). This will be our INPUT_NUM and OUTPUT_NUM.

4. First, we need to create a conditioning embedding. To do this, run the enrollment script and follow its on-screen instructions:

       python scripts/streamer/enroll.py --input INPUT_NUM

5. We can now use the streamer. Run:

       python scripts/streamer/stream.py --input INPUT_NUM --output OUTPUT_NUM

6. Once the streamer is running, open `pavucontrol`.

   a. In `pavucontrol`, go to the "Playback" tab and find "ALSA plug-in [python3.9]: ALSA Playback on". Set the output to "voicebox".

   b. Then, go to "Recording" and find "ALSA plug-in [python3.9]: ALSA Playback from", and set the input to your desired microphone device.
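If parsing the printed device table is awkward, the same information is available programmatically through the `sounddevice` API (a short illustrative sketch; the name filters are examples only):

```
import sounddevice as sd

# list all devices, then pick out candidates by name
for idx, dev in enumerate(sd.query_devices()):
    if 'pipewire' in dev['name'].lower() or 'voicebox' in dev['name'].lower():
        print(idx, dev['name'],
              dev['max_input_channels'], dev['max_output_channels'])
```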

## Citation

+ +If you use this in your academic research, please cite the following: + +``` +@inproceedings{authors2022voicelock, +title={VoiceBlock: Privacy through Real-Time Adversarial Attacks with Audio-to-Audio Models}, +author={Patrick O'Reilly and Andreas Bugler and Keshav Bhandari and Max Morrison and Bryan Pardo}, +booktitle={Neural Information Processing Systems}, +month={November}, +year={2022} +} +``` diff --git a/voicebox/src.egg-info/SOURCES.txt b/voicebox/src.egg-info/SOURCES.txt new file mode 100644 index 0000000000000000000000000000000000000000..26ec5d364917c7e236ffac4052b8d9e79c1ffb39 --- /dev/null +++ b/voicebox/src.egg-info/SOURCES.txt @@ -0,0 +1,9 @@ +LICENSE +README.md +setup.py +src/__init__.py +src/constants.py +src.egg-info/PKG-INFO +src.egg-info/SOURCES.txt +src.egg-info/dependency_links.txt +src.egg-info/top_level.txt \ No newline at end of file diff --git a/voicebox/src.egg-info/dependency_links.txt b/voicebox/src.egg-info/dependency_links.txt new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/voicebox/src.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/voicebox/src.egg-info/top_level.txt b/voicebox/src.egg-info/top_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..85de9cf93344b897ee6b677d44c645d747f82b0c --- /dev/null +++ b/voicebox/src.egg-info/top_level.txt @@ -0,0 +1 @@ +src diff --git a/voicebox/src/__init__.py b/voicebox/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/voicebox/src/__pycache__/__init__.cpython-310.pyc b/voicebox/src/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a8304f33ab992722427dcbae33308f31f44c448c Binary files /dev/null and b/voicebox/src/__pycache__/__init__.cpython-310.pyc differ diff --git a/voicebox/src/__pycache__/__init__.cpython-39.pyc b/voicebox/src/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d79706cbe17d171a9a3a22d46bfab41fbcd56d26 Binary files /dev/null and b/voicebox/src/__pycache__/__init__.cpython-39.pyc differ diff --git a/voicebox/src/__pycache__/constants.cpython-310.pyc b/voicebox/src/__pycache__/constants.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c90f253cd8996bca0b7cd0187f2841922344f0f Binary files /dev/null and b/voicebox/src/__pycache__/constants.cpython-310.pyc differ diff --git a/voicebox/src/__pycache__/constants.cpython-39.pyc b/voicebox/src/__pycache__/constants.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c94c3fc3df92b0722ea7b854c2b489ca156b2269 Binary files /dev/null and b/voicebox/src/__pycache__/constants.cpython-39.pyc differ diff --git a/voicebox/src/attacks/__init__.py b/voicebox/src/attacks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d3f5a12faa99758192ecc4ed3fc22c9249232e86 --- /dev/null +++ b/voicebox/src/attacks/__init__.py @@ -0,0 +1 @@ + diff --git a/voicebox/src/attacks/__pycache__/__init__.cpython-310.pyc b/voicebox/src/attacks/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..728befc6371f9392850db8dbed7ceca693f27f56 Binary files /dev/null and b/voicebox/src/attacks/__pycache__/__init__.cpython-310.pyc differ diff --git a/voicebox/src/attacks/__pycache__/__init__.cpython-39.pyc
b/voicebox/src/attacks/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ffda474ed0d9823e97ed37d21906961af0f1af3 Binary files /dev/null and b/voicebox/src/attacks/__pycache__/__init__.cpython-39.pyc differ diff --git a/voicebox/src/attacks/offline/__init__.py b/voicebox/src/attacks/offline/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..911b212ea58f56a8e3c5b2e5eb33fff6673d19d1 --- /dev/null +++ b/voicebox/src/attacks/offline/__init__.py @@ -0,0 +1,13 @@ +#Added by andy, as the code is being executed in the voiceblock folder instead of the voicebox folder +import sys +sys.path.append("voicebox") + +from src.attacks.offline.orthogonal_selective import SelectiveOrthogonalPGDMixin +from src.attacks.offline.null import NullAttack +from src.attacks.offline.advpulse import AdvPulseAttack +from src.attacks.offline.kenansville import KenansvilleAttack +from src.attacks.offline.demucs import DemucsAttack +from src.attacks.offline.voicebox import VoiceBoxAttack +from src.attacks.offline.white_noise import WhiteNoiseAttack +from src.attacks.offline.trainable import TrainableAttack +from src.attacks.offline.offline import OfflineAttack diff --git a/voicebox/src/attacks/offline/__pycache__/__init__.cpython-310.pyc b/voicebox/src/attacks/offline/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f3a5d761aac1e824d4c1d83c5205e2d7c5eec317 Binary files /dev/null and b/voicebox/src/attacks/offline/__pycache__/__init__.cpython-310.pyc differ diff --git a/voicebox/src/attacks/offline/__pycache__/__init__.cpython-39.pyc b/voicebox/src/attacks/offline/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c864b57d0a9f4fd8192ebeac47af29076fe2fc2 Binary files /dev/null and b/voicebox/src/attacks/offline/__pycache__/__init__.cpython-39.pyc differ diff --git a/voicebox/src/attacks/offline/__pycache__/advpulse.cpython-39.pyc b/voicebox/src/attacks/offline/__pycache__/advpulse.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e195dfcce2578cf984559a5ae31cac5ba69f8ace Binary files /dev/null and b/voicebox/src/attacks/offline/__pycache__/advpulse.cpython-39.pyc differ diff --git a/voicebox/src/attacks/offline/__pycache__/demucs.cpython-39.pyc b/voicebox/src/attacks/offline/__pycache__/demucs.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..166fc8941e837995db184d47edbf9091948b0ba6 Binary files /dev/null and b/voicebox/src/attacks/offline/__pycache__/demucs.cpython-39.pyc differ diff --git a/voicebox/src/attacks/offline/__pycache__/kenansville.cpython-39.pyc b/voicebox/src/attacks/offline/__pycache__/kenansville.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..668b7703bf2473db0c6cd091b921eb0b61a4e00f Binary files /dev/null and b/voicebox/src/attacks/offline/__pycache__/kenansville.cpython-39.pyc differ diff --git a/voicebox/src/attacks/offline/__pycache__/null.cpython-310.pyc b/voicebox/src/attacks/offline/__pycache__/null.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8501b2b5cddec202fe8d0877842120d70674011c Binary files /dev/null and b/voicebox/src/attacks/offline/__pycache__/null.cpython-310.pyc differ diff --git a/voicebox/src/attacks/offline/__pycache__/null.cpython-39.pyc b/voicebox/src/attacks/offline/__pycache__/null.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..bcd091d50ba03e58d0b7cafdd3372925d506c783 Binary files /dev/null and b/voicebox/src/attacks/offline/__pycache__/null.cpython-39.pyc differ diff --git a/voicebox/src/attacks/offline/__pycache__/offline.cpython-310.pyc b/voicebox/src/attacks/offline/__pycache__/offline.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..664c1a6a55aa998b075325e5b02027ae279ed5b3 Binary files /dev/null and b/voicebox/src/attacks/offline/__pycache__/offline.cpython-310.pyc differ diff --git a/voicebox/src/attacks/offline/__pycache__/offline.cpython-39.pyc b/voicebox/src/attacks/offline/__pycache__/offline.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba1d6a234bc86bb66f05f1d29a2286db32814f8f Binary files /dev/null and b/voicebox/src/attacks/offline/__pycache__/offline.cpython-39.pyc differ diff --git a/voicebox/src/attacks/offline/__pycache__/orthogonal_selective.cpython-310.pyc b/voicebox/src/attacks/offline/__pycache__/orthogonal_selective.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3167ddfbd60b8787065d33fa74889a37dcfa18fd Binary files /dev/null and b/voicebox/src/attacks/offline/__pycache__/orthogonal_selective.cpython-310.pyc differ diff --git a/voicebox/src/attacks/offline/__pycache__/orthogonal_selective.cpython-39.pyc b/voicebox/src/attacks/offline/__pycache__/orthogonal_selective.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62a9801600c01f839c26d4f5b62f6cbf671a59f4 Binary files /dev/null and b/voicebox/src/attacks/offline/__pycache__/orthogonal_selective.cpython-39.pyc differ diff --git a/voicebox/src/attacks/offline/__pycache__/trainable.cpython-39.pyc b/voicebox/src/attacks/offline/__pycache__/trainable.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5edccaf2f90133c8192ec709a29da557f4bd51db Binary files /dev/null and b/voicebox/src/attacks/offline/__pycache__/trainable.cpython-39.pyc differ diff --git a/voicebox/src/attacks/offline/__pycache__/voicebox.cpython-39.pyc b/voicebox/src/attacks/offline/__pycache__/voicebox.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82e2f650210bd6ad9069bca4a68bfc2acdfc9b5c Binary files /dev/null and b/voicebox/src/attacks/offline/__pycache__/voicebox.cpython-39.pyc differ diff --git a/voicebox/src/attacks/offline/__pycache__/white_noise.cpython-39.pyc b/voicebox/src/attacks/offline/__pycache__/white_noise.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..29ff0e7654451e5ddf1a2dfee8046591ebb4aad2 Binary files /dev/null and b/voicebox/src/attacks/offline/__pycache__/white_noise.cpython-39.pyc differ diff --git a/voicebox/src/attacks/offline/advpulse.py b/voicebox/src/attacks/offline/advpulse.py new file mode 100644 index 0000000000000000000000000000000000000000..e60f69b5733de7ad6980579e247f00f3abce5429 --- /dev/null +++ b/voicebox/src/attacks/offline/advpulse.py @@ -0,0 +1,222 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +import math + +import librosa as li + +from typing import Union + +from src.attacks.offline.trainable import TrainableAttack +from src.pipelines import Pipeline +from src.loss.adversarial import AdversarialLoss +from src.attacks.offline.perturbation import AdditivePerturbation +from src.data import DataProperties + +################################################################################ +# Implementation of universal additive 
attack of Li et al. +################################################################################ + + +class AdvPulseAttack(TrainableAttack): + + def __init__(self, + pipeline: Pipeline, + adv_loss: AdversarialLoss, + mimic_sound: Union[torch.Tensor, str] = None, + init_mimic: bool = False, + tile_reference: bool = False, + normalize: bool = False, + eps: float = 0.05, + pgd_norm: Union[str, int, float] = float('inf'), + length: Union[int, float] = 0.5, + align: str = 'start', + loop: bool = False, + **kwargs + ): + + super().__init__( + pipeline=pipeline, + adv_loss=adv_loss, + perturbation=AdditivePerturbation( + eps=eps, + projection_norm=pgd_norm, + length=length, + align=align, + loop=loop, + normalize=normalize + ), + **kwargs + ) + + # determine whether to repeat template to match perturbation length + self.tile_reference = tile_reference + + if mimic_sound is None: + self.mimic_sound = None + + elif isinstance(mimic_sound, torch.Tensor): + + # require batch, channel dimensions + assert mimic_sound.ndim >= 2 + + # convert to mono audio + if mimic_sound.ndim == 2: + mimic_sound = mimic_sound.unsqueeze(1) + self.mimic_sound = mimic_sound.mean( + dim=1, keepdim=True + ).to(self.pipeline.device) + + # load from file path + elif isinstance(mimic_sound, str): + + # load from randomly-selected file + mimic_sound_np, _ = li.load( + mimic_sound, + sr=DataProperties.get('sample_rate'), + mono=True + ) + mimic_sound = torch.as_tensor(mimic_sound_np) + + # if length is specified, trim to match + max_len = DataProperties.get('signal_length') + self.mimic_sound = mimic_sound[..., :max_len].reshape( + 1, 1, -1 + ).to(self.pipeline.device) + + else: + raise ValueError(f'Invalid mimic sound type {type(mimic_sound)}') + + # if specified, initialize adversarial perturbation to match template + if self.mimic_sound is not None and init_mimic: + self.perturbation.delta = nn.Parameter( + self._match_signal_length( + self.mimic_sound, + torch.zeros(1, self.perturbation.length) + ) + ) + + @staticmethod + def _crossfade(sig, fade_len): + """Apply cross-fade to ends of signal""" + + sig = sig.clone() + fade_len = int(fade_len * sig.shape[-1]) + fade_in = torch.linspace(0, 1, fade_len).to(sig) + fade_out = torch.linspace(1, 0, fade_len).to(sig) + sig[..., :fade_len] *= fade_in + sig[..., -fade_len:] *= fade_out + return sig + + def _match_signal_length(self, sig: torch.Tensor, ref: torch.Tensor): + """ + Match length of signal to reference, either by trimming or repeating and + cross-fading + """ + + sig = sig.reshape(1, -1) + ref = ref.reshape(1, -1) + + signal_length = ref.shape[-1] + if sig.shape[-1] >= signal_length: + return sig[..., :signal_length].reshape(1, 1, -1).to(ref) + elif not self.tile_reference: + return F.pad( + sig, (0, signal_length - sig.shape[-1]) + ).reshape(1, 1, -1).to(ref) + + # cross-fade length + overlap = 0.05 + + step = math.ceil(sig.shape[-1] * (1 - overlap)) + n_repeat = math.ceil(signal_length / step) + + padded = torch.zeros( + 1, step * (n_repeat - 1) + sig.shape[-1] + 1 + ).reshape(1, -1).to(sig) + shape = padded.shape[:-1] + (n_repeat, sig.shape[-1]) + + strides = (padded.stride()[0],) + (step, padded.stride()[-1],) + frames = torch.as_strided( + padded, size=shape, stride=strides + )[::step] + + for j in range(n_repeat): + frames[:, j, :] += self._crossfade(sig, overlap) + + sig = padded[..., :signal_length].reshape( + 1, 1, -1 + ).to(ref) + + return sig + + def _set_loss_reference(self, x: torch.Tensor): + """ + Pass reference audio to auxiliary loss to avoid 
re-computing expensive + intermediate representations. For AdvPulse attack, optionally use + """ + + if self.aux_loss is not None: + + if self.mimic_sound is not None: + reference = self._match_signal_length( + self.mimic_sound, + self.perturbation.delta + ) + else: + reference = x + + self.aux_loss.set_reference(reference) + + def _compute_aux_loss(self, + x_adv: torch.Tensor, + x_ref: torch.Tensor = None): + """Compute auxiliary loss, optionally """ + if self.mimic_sound is not None: + return self.aux_loss(self.perturbation.delta, x_ref) + else: + return self.aux_loss(x_adv, x_ref) + + def _log_step(self, + x: torch.Tensor, + x_adv: torch.Tensor, + y: torch.Tensor, + adv_loss: Union[float, torch.Tensor] = None, + det_loss: Union[float, torch.Tensor] = None, + aux_loss: Union[float, torch.Tensor] = None, + success_rate: Union[float, torch.Tensor] = None, + detection_rate: Union[float, torch.Tensor] = None, + idx: int = 0, + tag: str = None, + *args, + **kwargs + ): + + if self.writer is None or self._iter_id % self.writer.log_iter: + return + + if tag is None: + tag = f'{self.__class__.__name__}-' \ + f'{self.aux_loss.__class__.__name__}' + + super()._log_step( + x, + x_adv, + y, + adv_loss=adv_loss, + det_loss=det_loss, + aux_loss=aux_loss, + success_rate=success_rate, + detection_rate=detection_rate, + idx=idx, + tag=tag + ) + + # add audio and spectrogram logging for mimic sound + if self.mimic_sound is not None: + self.writer.log_audio( + self.mimic_sound, + f'{tag}/sound-template', + global_step=self._iter_id + ) diff --git a/voicebox/src/attacks/offline/demucs.py b/voicebox/src/attacks/offline/demucs.py new file mode 100644 index 0000000000000000000000000000000000000000..b78cda12866c023358dc288cb4d85fefd2eba1b5 --- /dev/null +++ b/voicebox/src/attacks/offline/demucs.py @@ -0,0 +1,57 @@ +import torch + +from src.models.denoiser import load_demucs +from src.attacks.offline.trainable import TrainableAttack +from src.pipelines.pipeline import Pipeline +from src.loss.adversarial import AdversarialLoss + +################################################################################ +# Attack using Demucs waveform-convolutional U-net +################################################################################ + + +class DemucsAttack(TrainableAttack): + """ + Train a Demucs model to apply adversarial transformations to incoming audio + """ + def __init__(self, + pipeline: Pipeline, + adv_loss: AdversarialLoss, + model_name: str = 'dns_48', + pretrained: bool = True, + **kwargs): + """ + Train a pre-trained Demucs model to apply adversarial transformations to + incoming audio. Gradient descent variants adapted from `Evading + Adversarial Example Detection Defenses with Orthogonal Projected + Gradient Descent` by Bryniarski et al. + (https://github.com/v-wangg/OrthogonalPGD). 
+ + :param pipeline: a Pipeline object wrapping a (defended) classifier + :param adv_loss: a Loss object encapsulating the adversarial objective; + should take model predictions and targets as arguments + :param aux_loss: an auxiliary Loss object; should take original and + adversarial inputs as arguments + :param model_name: name of pre-trained Demucs model to load + :param pretrained: if True, seek pretrained weights for given model + :param opt: optimizer; must be one of 'adam', 'sgd', or 'lbfgs' + :param lr: perturbation learning rate + :param mode: PGD variant; must be one of None, 'orthogonal', 'selective' + :param project_grad: p-norm for gradient regularization; must be one of + inf, 2, or None + :param k: if not None, perform gradient projection every kth step + :param max_iter: the maximum number of iterations per batch + :param epochs: optimization epochs over training data + :param eot_iter: resampling interval for Pipeline simulation parameters; + if 0 or None, do not resample parameters + :param batch_size: batch size for attack + :param rand_evals: randomly-resampled simulated evaluations per each + final generated attack + :param writer: a Writer object for logging attack progress + """ + super().__init__( + pipeline=pipeline, + adv_loss=adv_loss, + perturbation=load_demucs(model_name, pretrained), + **kwargs + ) diff --git a/voicebox/src/attacks/offline/kenansville.py b/voicebox/src/attacks/offline/kenansville.py new file mode 100644 index 0000000000000000000000000000000000000000..4dc2985d9745aa8e476dd971a1a35a6a4c16a753 --- /dev/null +++ b/voicebox/src/attacks/offline/kenansville.py @@ -0,0 +1,368 @@ +import torch +import torch.nn.functional as F +from torch.utils.data import Dataset + +from typing import Union + +import warnings + +from src.attacks.offline.trainable import TrainableAttack +from src.attacks.offline.perturbation.perturbation import Perturbation +from src.attacks.offline.perturbation.kenansville import KenansvillePerturbation +from src.pipelines.pipeline import Pipeline +from src.loss.adversarial import AdversarialLoss +from src.loss.auxiliary import AuxiliaryLoss +from src.utils.writer import Writer + +################################################################################ +# Untargeted, black-box signal-processing attack +################################################################################ + + +class KenansvilleAttack(TrainableAttack): + """ + Perturb inputs by removing frequency content. + """ + def __init__(self, + pipeline: Pipeline, + adv_loss: AdversarialLoss, + threshold_db_low: float = 1.0, + threshold_db_high: float = 100.0, + step_size: float = 10.0, + search: str = 'bisection', + min_success_rate: float = 0.9, + win_type: str = 'hann', + win_length: int = 2048, + **kwargs + ): + """ + Untargeted black-box spectral-bin-removal attack proposed by Abdullah + et al. (https://arxiv.org/abs/1910.05262). Code adapted from + https://bit.ly/31K4Efy. + + :param pipeline: a Pipeline object + :param adv_loss: an AdversarialLoss object; must be untargeted + :param aux_loss: an optional AuxiliaryLoss object + :param threshold_db: energy threshold relative to spectral peak energy; + frequency bins below threshold are removed + :param max_iter: iterations to search for optimal threshold. If nonzero, + search for highest (least perceptible) threshold value + such that attack achieves 100% untargeted success + against given pipeline. 
+
+ :param pipeline: a Pipeline object
+ :param adv_loss: an AdversarialLoss object; must be untargeted
+ :param aux_loss: an optional AuxiliaryLoss object
+ :param threshold_db_low: lowest candidate energy threshold (dB) relative
+ to total frame energy; the lowest-energy frequency
+ bins falling below the threshold are removed
+ :param threshold_db_high: highest candidate energy threshold (dB)
+ :param step_size: spacing (dB) between candidate threshold values
+ :param search: threshold search strategy; must be one of 'bisection',
+ 'linear', 'none', or None. If a search is performed,
+ select the highest (least perceptible) threshold value
+ such that the attack achieves the minimum required
+ untargeted success rate against the given pipeline;
+ otherwise, use the given low threshold
+ :param min_success_rate: minimum acceptable untargeted success rate when
+ optimizing threshold
+ :param win_type: window type; must be one of 'rectangular' or 'hann'.
+ For Hann window, audio is framed with 50% overlap
+ :param win_length: frame length in samples
+ """
+
+ self.threshold_db_low = threshold_db_low
+ self.threshold_db_high = threshold_db_high
+ self.step_size = step_size
+ self.search = search
+ self.min_success_rate = min_success_rate
+
+ super().__init__(
+ pipeline=pipeline,
+ adv_loss=adv_loss,
+ perturbation=KenansvillePerturbation(
+ threshold_db=threshold_db_low,
+ win_type=win_type,
+ win_length=win_length
+ ),
+ **kwargs
+ )
+
+ @torch.no_grad()
+ def train(self,
+ x_train: torch.Tensor = None,
+ y_train: torch.Tensor = None,
+ data_train: Dataset = None,
+ x_val: torch.Tensor = None,
+ y_val: torch.Tensor = None,
+ data_val: Dataset = None,
+ *args,
+ **kwargs
+ ):
+
+ loader_train, loader_val = self._prepare_data(
+ x_train,
+ y_train,
+ data_train,
+ x_val,
+ y_val,
+ data_val)
+
+ # match devices and set reference if necessary
+ ref_batch = next(iter(loader_train))
+
+ if isinstance(ref_batch, tuple):
+ x_ref = ref_batch[0]
+ warnings.warn('Warning: provided dataset yields batches in tuple '
+ 'format; the first two tensors of each batch will be '
+ 'interpreted as inputs and targets, respectively, '
+ 'and any remaining tensors will be ignored. To pass '
+ 'additional named tensor arguments, use a dictionary '
+ 'batch format with keys `x` and `y` for inputs and '
+ 'targets, respectively.')
+ elif isinstance(ref_batch, dict):
+ x_ref = ref_batch['x']
+ else:
+ x_ref = ref_batch
+
+ if hasattr(self.perturbation, "set_reference"):
+ try:
+ self.perturbation.set_reference(
+ x_ref.to(self.pipeline.device))
+ except AttributeError:
+ pass
+
+ # enumerate possible SNR values for search
+ threshold_values = torch.arange(
+ self.threshold_db_low, self.threshold_db_high, self.step_size)
+
+ # track iterations
+ self._iter_id = 0
+ self._batch_id = 0
+ self._epoch_id = 0
+
+ # avoid unnecessary search
+ if self.threshold_db_low == self.threshold_db_high \
+ or len(threshold_values) < 2 \
+ or self.search in ['none', None]:
+ self.perturbation.set_threshold(self.threshold_db_low)
+
+ else:
+
+ # find best threshold via search
+ i_min = 0
+ i_max = len(threshold_values)
+ threshold_best = self.threshold_db_low
+
+ # perform bisection search for maximum SNR value which achieves
+ # minimum success threshold
+ if self.search == 'bisection':
+
+ while i_min < i_max:
+
+ # determine midpoint index
+ i_mid = (i_min + i_max) // 2
+ threshold = threshold_values[i_mid]
+
+ # set threshold
+ self.perturbation.set_threshold(threshold)
+
+ # compute success rate over training data at each candidate
+ # threshold level
+ successes = 0
+ n = 0
+
+ self._batch_id = 0
+ for batch in loader_train:
+
+ if isinstance(batch, dict):
+ x, y = batch['x'], batch['y']
+ else:
+ x, y, *_ = batch
+
+ x = x.to(self.pipeline.device)
+ y = y.to(self.pipeline.device)
+
+ n += len(x)
+ x_adv = self.perturbation(x)
+ outputs = self.pipeline(x_adv)
+ adv_scores = self.adv_loss(outputs, y)
+ adv_loss = adv_scores.mean()
+
+ batch_successes = (1.0 * self._compute_success_array(
+ x, y, x_adv)).sum().item()
+ successes += batch_successes
+
+ self._log_step(
+ x,
+ x_adv,
+ y,
+ adv_loss,
+ success_rate=batch_successes/len(x)
+ )
+
+ self._batch_id += 1
+ self._iter_id += 1
+
+ success_rate = successes / n
+
+ if success_rate >= self.min_success_rate:
+ threshold_best =
threshold + i_min = i_mid + 1 + else: + i_max = i_mid + + # perform linear search for SNR level + elif self.search == 'linear': + + for threshold in threshold_values: + + # set threshold + self.perturbation.set_threshold(threshold) + + # compute success rate over training data at each candidate + # threshold level + successes = 0 + n = 0 + + self._batch_id = 0 + for batch in loader_train: + + if isinstance(batch, dict): + x, y = batch['x'], batch['y'] + else: + x, y, *_ = batch + + x = x.to(self.pipeline.device) + y = y.to(self.pipeline.device) + + n += len(x) + x_adv = self.perturbation(x) + outputs = self.pipeline(x_adv) + adv_scores = self.adv_loss(outputs, y) + adv_loss = adv_scores.mean() + batch_successes = (1.0 * self._compute_success_array( + x, y, x_adv)).sum().item() + successes += batch_successes + + self._log_step( + x, + x_adv, + y, + adv_loss, + success_rate=batch_successes/len(x) + ) + + self._batch_id += 1 + self._iter_id += 1 + + success_rate = successes / n + + if success_rate >= self.min_success_rate: + threshold_best = threshold + else: + raise ValueError(f'Invalid search method {self.search}') + + # set final SNR value + self.perturbation.set_threshold(threshold_best) + + # perform validation + adv_scores = [] + aux_scores = [] + det_scores = [] + success_indicators = [] + detection_indicators = [] + + self.perturbation.eval() + + for batch_id, batch in enumerate(loader_val): + + # randomize simulation for each validation batch + self.pipeline.sample_params() + + if isinstance(batch, dict): + x_orig, targets = batch['x'], batch['y'] + else: + x_orig, targets, *_ = batch + + n_batch = x_orig.shape[0] + + x_orig = x_orig.to(self.pipeline.device) + targets = targets.to(self.pipeline.device) + + # set reference for auxiliary loss + self._set_loss_reference(x_orig) + + with torch.no_grad(): + + # compute adversarial loss + x_adv = self._evaluate_batch(x_orig, targets) + outputs = self.pipeline(x_adv) + adv_scores.append(self.adv_loss(outputs, targets).flatten()) + + # compute adversarial success rate + success_indicators.append( + 1.0 * self._compute_success_array( + x_orig, targets, x_adv + ).flatten()) + + # compute defense loss and detection indicators + def_results = self.pipeline.detect(x_adv) + detection_indicators.append(1.0 * def_results[0].flatten()) + det_scores.append(def_results[1].flatten()) + + # compute auxiliary loss + if self.aux_loss is not None: + aux_scores.append( + self._compute_aux_loss(x_adv).flatten()) + else: + aux_scores.append(torch.zeros(n_batch)) + + tag = f'{self.__class__.__name__}-' \ + f'{self.aux_loss.__class__.__name__}' + + if self.writer is not None: + + with self.writer.force_logging(): + + # adversarial loss value + self.writer.log_scalar( + torch.cat(adv_scores, dim=0).mean(), + f"{tag}/adversarial-loss-val", + global_step=self._iter_id + ) + + # detector loss value + self.writer.log_scalar( + torch.cat(det_scores, dim=0).mean(), + f"{tag}/detector-loss-val", + global_step=self._iter_id + ) + + # auxiliary loss value + self.writer.log_scalar( + torch.cat(aux_scores, dim=0).mean(), + f"{tag}/auxiliary-loss-val", + global_step=self._iter_id + ) + + # adversarial success rate + self.writer.log_scalar( + torch.cat(success_indicators, dim=0).mean(), + f"{tag}/success-rate-val", + global_step=self._iter_id + ) + + # adversarial detection rate + self.writer.log_scalar( + torch.cat(detection_indicators, dim=0).mean(), + f"{tag}/detection-rate-val", + global_step=self._iter_id + ) + + # freeze model parameters + self.perturbation.eval() + for 
p in self.perturbation.parameters():
+ p.requires_grad = False
+
+ # save model/perturbation
+ self._checkpoint()
+
+ def _evaluate_batch(self,
+ x: torch.Tensor,
+ y: torch.Tensor,
+ **kwargs
+ ):
+ """
+ Remove low-energy frequency content from inputs.
+ """
+
+ # require batch dimension
+ assert x.ndim >= 2
+
+ return self.perturbation(x)
diff --git a/voicebox/src/attacks/offline/null.py b/voicebox/src/attacks/offline/null.py new file mode 100644 index 0000000000000000000000000000000000000000..a50c0a6579d3b8165a5bef81c04cea8da35939d7 --- /dev/null +++ b/voicebox/src/attacks/offline/null.py @@ -0,0 +1,54 @@
+import torch
+
+from src.attacks.offline.offline import OfflineAttack
+from src.pipelines.pipeline import Pipeline
+from src.loss.adversarial import AdversarialLoss
+from src.utils.writer import Writer
+
+################################################################################
+# "Null" attack (apply no perturbations to inputs)
+################################################################################
+
+
+class NullAttack(OfflineAttack):
+ """
+ Simple baseline attack in which inputs are passed to model unaltered.
+ """
+
+ def __init__(self,
+ pipeline: Pipeline,
+ adv_loss: AdversarialLoss,
+ batch_size: int = 1,
+ rand_evals: int = 0,
+ writer: Writer = None,
+ **kwargs
+ ):
+
+ super().__init__(
+ pipeline=pipeline,
+ adv_loss=adv_loss,
+ batch_size=batch_size,
+ rand_evals=rand_evals,
+ writer=writer
+ )
+
+ def _evaluate_batch(self,
+ x: torch.Tensor,
+ y: torch.Tensor,
+ **kwargs
+ ):
+ """
+ Apply no perturbations to inputs.
+ """
+
+ # compute adversarial inputs
+ x_adv = x.clone().detach()
+
+ # log attack results
+ self._log_step(
+ x=x,
+ x_adv=x_adv,
+ y=y
+ )
+
+ return x_adv
diff --git a/voicebox/src/attacks/offline/offline.py b/voicebox/src/attacks/offline/offline.py new file mode 100644 index 0000000000000000000000000000000000000000..8e12568ea461553dd641074cbe169d90eb112a9b --- /dev/null +++ b/voicebox/src/attacks/offline/offline.py @@ -0,0 +1,429 @@
+import torch
+import torch.nn as nn
+
+from typing import Tuple, Union
+from torch.utils.data import Dataset, TensorDataset, DataLoader
+
+from src.utils.writer import Writer
+from src.pipelines.pipeline import Pipeline
+from src.loss.adversarial import AdversarialLoss
+from src.loss.auxiliary import AuxiliaryLoss
+
+################################################################################
+# Base class for offline adversarial attacks
+################################################################################
+
+
+class OfflineAttack:
+
+ def __init__(self,
+ pipeline: Pipeline,
+ adv_loss: AdversarialLoss,
+ aux_loss: AuxiliaryLoss = None,
+ batch_size: int = 32,
+ rand_evals: int = 0,
+ writer: Writer = None,
+ **kwargs
+ ):
+ """
+ Base class for offline attacks. Subclasses must override the
+ `_evaluate_batch()` method.
+
+ Offline attacks optimize perturbations of benign inputs without
+ real-time performance constraints. Optimization is performed using a
+ stored Pipeline object, encompassing a victim model, acoustic
+ simulation, and adversarial defenses.
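+
+ A minimal usage sketch (assuming a concrete subclass `MyAttack` and
+ prepared tensors; names are illustrative):
+
+ attack = MyAttack(pipeline=pipeline, adv_loss=adv_loss)
+ adv_x, success, detected = attack.evaluate(x=x, y=y)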
+ + :param pipeline: a Pipeline object wrapping a (defended) classifier + :param adv_loss: AdversarialLoss object encapsulating attacker objective + :param aux_loss: optional AuxiliaryLoss object encapsulating + some perceptibility objective + :param batch_size: batch size for attack + :param rand_evals: randomly-resampled simulated evaluations per each + final generated attack + """ + self.pipeline = pipeline + + # ensure gradients flow through PyTorch RNN layers + self._pipeline_rnn_grad() + + self.adv_loss = adv_loss + self.aux_loss = aux_loss + + self.batch_size = batch_size + self.rand_evals = rand_evals + + # log attack progress + self.writer = writer + + # track batch inputs + self._batch_id = 0 + self._iter_id = 0 + + # optional data-loading arguments + self.pin_memory = kwargs.get('pin_memory', False) + self.num_workers = kwargs.get('num_workers', 0) + + self._check_loss() + + def _pipeline_rnn_grad(self): + """ + PyTorch requires any recurrent modules be placed in `train` mode to + enable backpropagation through the pipeline. + """ + for m in self.pipeline.modules(): + if isinstance(m, nn.RNNBase): + m.train() + + def _check_loss(self): + """ + Validate adversarial and auxiliary losses + """ + + assert self.adv_loss is not None, 'Must provide adversarial loss' + assert self.adv_loss.reduction in ['none', None], \ + 'All losses must provide unreduced scores' + + assert self.aux_loss is None or \ + self.aux_loss.reduction in ['none', None], \ + 'All losses must provide unreduced scores' + + @staticmethod + def _create_dataset(x: torch.Tensor, y: torch.Tensor): + """ + If attack inputs are given as tensors, create a simple dataset + """ + + # require batch dimension + assert x.ndim >= 2 + + dataset = TensorDataset( + x.type(torch.float32), + y.type(torch.float32), + ) + return dataset + + def _compute_detection_array(self, x_adv, *args, **kwargs): + """ + Pass attack audio through any detection defenses in stored Pipeline, and + return boolean detection flags for each input + """ + flags, scores = self.pipeline.detect(x_adv) + return flags + + @torch.no_grad() + def _compute_success_array(self, + x: torch.Tensor, + y: torch.Tensor, + x_adv: torch.Tensor, + *args, + **kwargs + ): + """ + Pass attack audio through stored Pipeline and determine adversarial + success for each input + """ + + # obtain 'clean' and adversarial predictions of stored Pipeline + preds = self.pipeline(x.detach()) + adv_preds = self.pipeline(x_adv.detach()) + + # for a targeted attack, attempt to match given targets + if self.adv_loss.targeted: + attack_success = self.pipeline.match_predict(adv_preds, y) + + # for an untargeted attack, attempt to evade clean predictions + else: + attack_success = ~self.pipeline.match_predict(adv_preds, preds) + + return attack_success + + def _log_step(self, + x: torch.Tensor, + x_adv: torch.Tensor, + y: torch.Tensor, + adv_loss: Union[float, torch.Tensor] = None, + det_loss: Union[float, torch.Tensor] = None, + aux_loss: Union[float, torch.Tensor] = None, + success_rate: Union[float, torch.Tensor] = None, + detection_rate: Union[float, torch.Tensor] = None, + idx: int = 0, + tag: str = None, + *args, + **kwargs + ): + """ + Log attack progress. 
+ + :param x: batch of original inputs + :param x_adv: batch of adversarial inputs + :param y: batch of adversarial targets + :param adv_loss: adversarial loss value + :param det_loss: detection loss value + :param aux_loss: auxiliary loss value + :param success_rate: adversarial success rate + :param detection_rate: adversarial defense detection rate + :param idx: batch index for logging individual examples + """ + + if self.writer is None or self._iter_id % self.writer.log_iter: + return + + if tag is None: + tag = f'{self.__class__.__name__}-batch-{self._batch_id}' + + x = x.clone().detach() + x_adv = x_adv.clone().detach() + + # compute losses and simulated audio + with torch.no_grad(): + outputs_adv = self.pipeline(x_adv) + simulated = self.pipeline.simulate(x) + simulated_adv = self.pipeline.simulate(x_adv) + + # if adversarial loss is not provided, compute + if adv_loss is None: + adv_loss = self.adv_loss(outputs_adv, y).mean() + + # if detector loss or rate is not provided, compute + if det_loss is None or detection_rate is None: + flags, scores = self.pipeline.detect(x_adv) + det_loss = scores.mean() + detection_rate = torch.mean(1.0 * flags) + + # if auxiliary loss is not provided, compute + if aux_loss is None: + aux_loss = 0.0 if self.aux_loss is None else self.aux_loss( + x_adv, x + ).mean() + + # if adversarial success rate is not provided, compute + if success_rate is None: + success = self._compute_success_array( + x=x, + x_adv=x_adv, + y=y + ) + success_rate = torch.mean(1.0 * success) + + # unperturbed input + self.writer.log_audio( + x[idx], + f"{tag}/original", + global_step=self._iter_id + ) + + # simulated unperturbed input + self.writer.log_audio( + simulated[idx], + f"{tag}/simulated-original", + global_step=self._iter_id + ) + + # adversarial input + self.writer.log_audio( + x_adv[idx], + f"{tag}/adversarial", + global_step=self._iter_id + ) + + # simulated adversarial input + self.writer.log_audio( + simulated_adv[idx], + f"{tag}/simulated-adversarial", + global_step=self._iter_id + ) + + # adversarial loss value + self.writer.log_scalar( + adv_loss, + f"{tag}/adversarial-loss", + global_step=self._iter_id + ) + + # detector loss value + self.writer.log_scalar( + det_loss, + f"{tag}/detector-loss", + global_step=self._iter_id + ) + + # auxiliary loss value + self.writer.log_scalar( + aux_loss, + f"{tag}/auxiliary-loss", + global_step=self._iter_id + ) + + # adversarial success rate + self.writer.log_scalar( + success_rate, + f"{tag}/success-rate", + global_step=self._iter_id + ) + + # adversarial detection rate + self.writer.log_scalar( + detection_rate, + f"{tag}/detection-rate", + global_step=self._iter_id + ) + + def _evaluate_batch(self, + x: torch.Tensor, + y: torch.Tensor, + **kwargs + ): + """ + Perform attack on a batch of inputs. + + :param x: input tensor of shape (n_batch, ...) + :param y: targets tensor of shape (n_batch, ...) in case of targeted + attack; original labels tensor of shape (n_batch, ...) in + case of untargeted attack + """ + raise NotImplementedError() + + @torch.no_grad() + def evaluate(self, + x: torch.Tensor = None, + y: torch.Tensor = None, + dataset: Dataset = None, + **kwargs + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Perform attack given input-target pairs, optionally in the form of a + Dataset object. Random evaluations will then be conducted on all + generated attacks. + + :param x: audio input, shape (n_batch, ..., signal_length) + :param y: targets, shape (n_batch, ...) 
+ :param dataset: optionally, provide inputs and targets as dataset + :return: tuple holding + * adversarial audio (n_batch, ..., signal_length) + * boolean adversarial success indicators (n_batch,) + * boolean adversarial detection indicators (n_batch,) + """ + + assert (x is not None and y is not None) or dataset is not None + + # prepare batched data-loading, store original device + if dataset is None: + orig_device = x.device + dataset = self._create_dataset(x, y) + x_ref = x[0:1].clone().detach() + else: + ref_batch = next(iter(dataset)) + if isinstance(ref_batch, tuple): + x_ref = ref_batch[0] + elif isinstance(ref_batch, dict): + x_ref = ref_batch['x'] + else: + x_ref = ref_batch + orig_device = x_ref.device + + # prepare to compute attack success and detection rates + attack_success = torch.zeros( + len(dataset), dtype=torch.float).to(self.pipeline.device) + attack_detection = torch.zeros( + len(dataset), dtype=torch.float).to(self.pipeline.device) + + # prepare to store attack outputs + adv_x = torch.stack( + [torch.zeros(x_ref.shape)] * len(dataset), + dim=0 + ).to(self.pipeline.device) + + data_loader = DataLoader( + dataset=dataset, + batch_size=self.batch_size, + shuffle=False, + drop_last=False, + num_workers=self.num_workers, + pin_memory=self.pin_memory + ) + + # compute attacks with batching + for (batch_id, batch_all) in enumerate(data_loader): + + self._batch_id = batch_id + + # allow for different dataset formats + if isinstance(batch_all, tuple): + batch_all = { + 'x': batch_all[0], + 'y': batch_all[1] + } + + # match devices + for k in batch_all.keys(): + batch_all[k] = batch_all[k].to(self.pipeline.device) + + batch_index_1 = batch_id * self.batch_size + batch_index_2 = (batch_id + 1) * self.batch_size + + # compute attacks for given batch + adversarial_batch = self._evaluate_batch( + **batch_all, **kwargs + ) + + # if no random trials, evaluate once + if not self.rand_evals: + + # compute and store success rates for batch + attack_success_batch = self._compute_success_array( + **batch_all, + x_adv=adversarial_batch + ).reshape(-1).type(torch.float32) + + # compute and store detection rates for batch + attack_detection_batch = self._compute_detection_array( + x_adv=adversarial_batch + ).reshape(-1).type(torch.float32) + + # otherwise, perform multiple random evaluations per attack + else: + + # track batch success and detection over random evaluation + success_combined_batch = [] + detection_combined_batch = [] + + for i in range(self.rand_evals): + + # randomly sample simulation parameters + self.pipeline.sample_params() + + # compute and store success rates for batch + rand_success_batch = self._compute_success_array( + **batch_all, + x_adv=adversarial_batch + ).reshape(-1, 1) + success_combined_batch.append(rand_success_batch) + + # compute and store detection rates for batch + rand_detection_batch = self._compute_detection_array( + x_adv=adversarial_batch + ) + detection_combined_batch.append(rand_detection_batch) + + # average results over all trials + attack_success_batch = (1.0 * torch.cat( + success_combined_batch, dim=-1 + )).mean(dim=-1) + + attack_detection_batch = (1.0 * torch.cat( + detection_combined_batch, dim=-1 + )).mean(dim=-1) + + # store generated attack audio + adv_x[batch_index_1:batch_index_2] = adversarial_batch + + # store success rate per generated attack + attack_success[batch_index_1:batch_index_2] = attack_success_batch + + # store detection rate per generated attack + attack_detection[batch_index_1:batch_index_2] = attack_detection_batch 
+ + return (adv_x.to(orig_device), + attack_success.to(orig_device), + attack_detection.to(orig_device)) diff --git a/voicebox/src/attacks/offline/orthogonal_selective.py b/voicebox/src/attacks/offline/orthogonal_selective.py new file mode 100644 index 0000000000000000000000000000000000000000..f9a07342701c9b5787a4a65f3e128e4495e0c34c --- /dev/null +++ b/voicebox/src/attacks/offline/orthogonal_selective.py @@ -0,0 +1,82 @@ +import torch +import torch.nn as nn + +################################################################################ +# Mixin class for handling selective/orthogonal PGD variants +################################################################################ + + +class SelectiveOrthogonalPGDMixin(object): + + def __init__(self, **kwargs): + pass + + @staticmethod + def _dot(x1: torch.Tensor, x2: torch.Tensor): + """ + Compute batch dot product along final dimension + """ + return (x1*x2).sum(-1, keepdim=True) + + def _project_orthogonal(self, x1: torch.Tensor, x2: torch.Tensor): + """ + Compute projection component of x1 along x2. For projection + onto zero vector, return zero vector + """ + return x2 * (self._dot(x1, x2) / self._dot(x2, x2).clamp_min(1e-12)) + + def _component_orthogonal(self, + x1: torch.Tensor, + x2: torch.Tensor, + x3: torch.Tensor): + """ + Compute component of x1 approximately orthogonal to x2 and x3 + """ + return x1 - self._project_orthogonal( + x1, x2 - self._project_orthogonal(x2, x3) + ) - self._project_orthogonal(x1, x3) + + @staticmethod + def _retrieve_parameter_gradients(m: nn.Module): + """ + Retrieve all trainable parameters of a nn.Module object + :return: tensor of shape (n_parameters,) + """ + + flattened_grad = [] + + for param in m.parameters(): + if param.requires_grad: + if param.grad is None: + flattened_grad.append( + torch.zeros_like(param).detach().flatten() + ) + else: + flattened_grad.append(param.grad.detach().flatten()) + + return torch.cat(flattened_grad, dim=-1) + + @staticmethod + def _set_parameter_gradients(flattened_grad: torch.Tensor, m: nn.Module): + """ + Set gradient attributes of trainable parameters of a nn.Module object + :param params: tensor of shape (n_parameters,) + """ + + # check that flattened gradients have valid shape + prod = sum( + [p.shape.numel() for p in m.parameters() if p.requires_grad] + ) + + assert flattened_grad.ndim <= 1 + assert flattened_grad.numel() == prod + + idx = 0 + for param in m.parameters(): + if param.requires_grad: + param_length = param.shape.numel() + grad = flattened_grad[idx:idx + param_length].reshape( + param.shape + ).detach() + param.grad = grad + idx += param_length diff --git a/voicebox/src/attacks/offline/perturbation/__init__.py b/voicebox/src/attacks/offline/perturbation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cb13305e643a7a0d54d39fc91456921dca9594a2 --- /dev/null +++ b/voicebox/src/attacks/offline/perturbation/__init__.py @@ -0,0 +1,6 @@ +from src.attacks.offline.perturbation.perturbation import Perturbation +from src.attacks.offline.perturbation.additive import AdditivePerturbation +from src.attacks.offline.perturbation.white_noise import WhiteNoisePerturbation +from src.attacks.offline.perturbation.kenansville import KenansvillePerturbation +from src.attacks.offline.perturbation.voicebox import VoiceBox + diff --git a/voicebox/src/attacks/offline/perturbation/__pycache__/__init__.cpython-310.pyc b/voicebox/src/attacks/offline/perturbation/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..93eefee684ee0b9dac72a85864ef03e414c35a38 Binary files /dev/null and b/voicebox/src/attacks/offline/perturbation/__pycache__/__init__.cpython-310.pyc differ diff --git a/voicebox/src/attacks/offline/perturbation/__pycache__/__init__.cpython-39.pyc b/voicebox/src/attacks/offline/perturbation/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3db9b7554ea96dabc65292bb9d920584723b1169 Binary files /dev/null and b/voicebox/src/attacks/offline/perturbation/__pycache__/__init__.cpython-39.pyc differ diff --git a/voicebox/src/attacks/offline/perturbation/__pycache__/additive.cpython-310.pyc b/voicebox/src/attacks/offline/perturbation/__pycache__/additive.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f5ddfc9dd8a588f021be803fa1d92136027280f Binary files /dev/null and b/voicebox/src/attacks/offline/perturbation/__pycache__/additive.cpython-310.pyc differ diff --git a/voicebox/src/attacks/offline/perturbation/__pycache__/additive.cpython-39.pyc b/voicebox/src/attacks/offline/perturbation/__pycache__/additive.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6579d598cc6e36b176eb1d69eb17cefffd036aa4 Binary files /dev/null and b/voicebox/src/attacks/offline/perturbation/__pycache__/additive.cpython-39.pyc differ diff --git a/voicebox/src/attacks/offline/perturbation/__pycache__/kenansville.cpython-310.pyc b/voicebox/src/attacks/offline/perturbation/__pycache__/kenansville.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..77555103848099447ebe03a9e356090baad54b4b Binary files /dev/null and b/voicebox/src/attacks/offline/perturbation/__pycache__/kenansville.cpython-310.pyc differ diff --git a/voicebox/src/attacks/offline/perturbation/__pycache__/kenansville.cpython-39.pyc b/voicebox/src/attacks/offline/perturbation/__pycache__/kenansville.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b33014fcbfb2ab04bdb914f4bc9d44addf864aee Binary files /dev/null and b/voicebox/src/attacks/offline/perturbation/__pycache__/kenansville.cpython-39.pyc differ diff --git a/voicebox/src/attacks/offline/perturbation/__pycache__/perturbation.cpython-310.pyc b/voicebox/src/attacks/offline/perturbation/__pycache__/perturbation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1679bc8a29934209efb7cbe6c3a2f42483fc4660 Binary files /dev/null and b/voicebox/src/attacks/offline/perturbation/__pycache__/perturbation.cpython-310.pyc differ diff --git a/voicebox/src/attacks/offline/perturbation/__pycache__/perturbation.cpython-39.pyc b/voicebox/src/attacks/offline/perturbation/__pycache__/perturbation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f841a6c6841e642e663c537cc6e7073de722c46 Binary files /dev/null and b/voicebox/src/attacks/offline/perturbation/__pycache__/perturbation.cpython-39.pyc differ diff --git a/voicebox/src/attacks/offline/perturbation/__pycache__/white_noise.cpython-310.pyc b/voicebox/src/attacks/offline/perturbation/__pycache__/white_noise.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ccbb85ca1dd7c388af9365904265e4f7da7b8a6 Binary files /dev/null and b/voicebox/src/attacks/offline/perturbation/__pycache__/white_noise.cpython-310.pyc differ diff --git a/voicebox/src/attacks/offline/perturbation/__pycache__/white_noise.cpython-39.pyc 
b/voicebox/src/attacks/offline/perturbation/__pycache__/white_noise.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b62fa0644ecf79aa0c9c91a300db8580b8121c79 Binary files /dev/null and b/voicebox/src/attacks/offline/perturbation/__pycache__/white_noise.cpython-39.pyc differ diff --git a/voicebox/src/attacks/offline/perturbation/additive.py b/voicebox/src/attacks/offline/perturbation/additive.py new file mode 100644 index 0000000000000000000000000000000000000000..dd34d9991a633f51919982601841135a1ba658b4 --- /dev/null +++ b/voicebox/src/attacks/offline/perturbation/additive.py @@ -0,0 +1,211 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import math
+import random
+
+from typing import Union, Dict
+
+from src.attacks.offline.perturbation.perturbation import Perturbation
+from src.utils.plotting import plot_waveform
+from src.data import DataProperties
+
+################################################################################
+# Apply additive perturbation to waveform audio
+################################################################################
+
+
+class AdditivePerturbation(Perturbation):
+
+ def __init__(self,
+ eps: float,
+ projection_norm: Union[str, int, float],
+ length: Union[int, float] = None,
+ align: str = 'start',
+ loop: bool = False,
+ normalize: bool = False
+ ):
+
+ super().__init__()
+
+ self.eps = eps
+ self.projection_norm = projection_norm
+ self.normalize = normalize
+
+ assert align in ['start', 'random', 'none', None], \
+ 'Invalid alignment; must be one of "start", "random", or "none"'
+ self.align = align
+
+ # if length is given as floating-point (time), convert to samples
+ if isinstance(length, float):
+ length = int(length * DataProperties.get('sample_rate'))
+ self.length = length
+
+ # if True, loop perturbation to end of audio
+ self.loop = loop
+
+ # if no fixed length is given, register an empty placeholder;
+ # set_reference() will size the perturbation to match inputs
+ self.register_parameter(
+ "delta", nn.Parameter(torch.zeros(1, 1, self.length or 0)))
+
+ def set_reference(self, x: torch.Tensor):
+ """
+ Given reference input, initialize perturbation parameters accordingly
+ and match input device.
+
+ :param x: reference audio, shape (n_batch, n_channels, signal_length)
+ """
+
+ # require batch dimension
+ assert x.ndim >= 2, f"Invalid reference audio dimensions {x.shape}"
+
+ # determine whether to match length
+ if self.length is None or self.length > x.shape[-1]:
+ length = x.shape[-1]
+ else:
+ length = self.length
+
+ # initialize single-waveform additive perturbation and match reference
+ # device
+ self.delta = nn.Parameter(
+ torch.zeros(1, *x.shape[1:-1], length).to(x.device)
+ )
+
+ def forward(self, x: torch.Tensor, *args, **kwargs):
+ """
+ Apply perturbation to inputs.
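+
+ A usage sketch (values illustrative):
+
+ pert = AdditivePerturbation(eps=0.01, projection_norm=float('inf'))
+ pert.set_reference(x) # size delta to match inputs and device
+ x_adv = pert(x) # delta is zero-initialized, so x_adv == x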
+ + :param x: input audio, shape (n_batch, n_channels, signal_length) + """ + + # do not overwrite incoming audio + x = x.clone().detach() + ndims = x.ndim + + # prepare to restore original volume + peak = torch.max(torch.abs(x), -1)[0].reshape(-1) + + # account for audio shorter than stored perturbation + orig_len = x.shape[-1] + x = F.pad(x, (0, max(self.delta.shape[-1] - x.shape[-1], 0))) + + # normalize to apply perturbation + if self.normalize: + x = (1.0 / torch.max(torch.abs(x) + 1e-8, dim=-1, keepdim=True)[0]) * x + + # check that input is broadcastable with additive perturbation + assert self._check_broadcastable( + x[..., :self.delta.shape[-1]], + self.delta + ), \ + f"Cannot broadcast inputs of shape {x.shape} " \ + f"with additive perturbation of shape {self.delta.shape}" + + # determine alignment + if self.align in ['start', 'none', None]: + st_idx = 0 + elif self.align == 'random': + st_idx = random.randrange( + 0, + x.shape[-1] - self.delta.shape[-1] + ) + else: + raise ValueError(f'Invalid alignment {self.align}') + + # if specified, loop to end of input audio + if not self.loop: + delta = self.delta + ed_idx = st_idx + self.delta.shape[-1] + else: + remaining = x.shape[-1] - self.delta.shape[-1] - st_idx + n_loops = math.ceil(remaining / self.delta.shape[-1]) + 1 + ed_idx = x.shape[-1] + delta = self.delta.repeat( + (1,) * (self.delta.ndim - 1) + (n_loops,) + )[..., :ed_idx - st_idx] + + # apply perturbation + x[..., st_idx:ed_idx] = x[..., st_idx:ed_idx] + delta + + # trim to original length + x = x[..., :orig_len] + + # peak-normalize to match original input + if self.normalize: + factor = peak / torch.max(torch.abs(x), -1)[0].reshape(-1) + factor = factor.reshape((-1,) + (1,)*(ndims - 1)) + x = x * factor + + return x + + def _visualize_top_level(self) -> Dict[str, torch.Tensor]: + """ + Visualize top-level (non-recursive) perturbation parameters. + + :return: tag (string) / image (tensor) pairs, stored in a dictionary + """ + + name = self.__class__.__name__ + + visualizations = {} + + # plot: additive perturbation + if self.delta.numel() > 0: + + visualizations = { + **visualizations, + f'{name}-parameters': plot_waveform(self.delta) + } + + # plot: parameter gradients + if self.delta.grad is not None: + + visualizations = { + **visualizations, + f'{name}-gradients': plot_waveform(self.delta.grad) + } + + return visualizations + + def _project_valid_top_level(self): + """ + Project top-level (non-recursive) parameters to valid range. 
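+
+ For example (values illustrative): with projection_norm=float('inf')
+ and eps=0.01, each parameter is clamped to the range [-0.01, 0.01];
+ with projection_norm=2, the flattened parameter vector is rescaled by
+ min(1, eps / ||delta||_2).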
+ """ + + if self.eps is None: + return + + # obtained flattened parameters + flattened = [] + + for param in self.parameters(recurse=False): + if param.requires_grad: + flattened.append(param.data.detach().flatten()) + + flattened = torch.cat(flattened, dim=-1) # (n_parameters,) + + # project using given p-norm and radius + if self.projection_norm in [2, float(2), "2"]: + norm = torch.norm(flattened, p=2) + 1e-20 + factor = torch.min( + torch.tensor(1.0), + torch.tensor(self.eps) / norm + ).view(-1) + with torch.no_grad(): + flattened.mul_(factor) + elif self.projection_norm in [float("inf"), "inf"]: + with torch.no_grad(): + flattened.clamp_(min=-self.eps, max=self.eps) + else: + raise ValueError(f'Invalid projection norm {self.projection_norm}') + + # overwrite parameter data + idx = 0 + for param in self.parameters(recurse=False): + if param.requires_grad: + param_length = param.shape.numel() + data = flattened[idx:idx + param_length].reshape( + param.shape + ).detach() + param.data = data + idx += param_length diff --git a/voicebox/src/attacks/offline/perturbation/kenansville.py b/voicebox/src/attacks/offline/perturbation/kenansville.py new file mode 100644 index 0000000000000000000000000000000000000000..eb73a19a21b52e46e77873f03ab3882106f3aa47 --- /dev/null +++ b/voicebox/src/attacks/offline/perturbation/kenansville.py @@ -0,0 +1,266 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.fft as fft + +import math +import random + +import numpy as np + +from typing import Union, Dict + +from src.attacks.offline.perturbation.perturbation import Perturbation +from src.data import DataProperties + +################################################################################ +# Remove content of spectral bins below energy threshold +################################################################################ + + +class KenansvillePerturbation(Perturbation): + + def __init__(self, + threshold_db: float = 100.0, + win_length: int = 2048, + win_type: str = 'hann' + ): + + super().__init__() + + if win_type not in ['rectangular', 'triangular', 'hann']: + raise ValueError(f'Invalid window type {win_type}') + + self.threshold_db = nn.Parameter(torch.as_tensor([threshold_db])) + self.win_length = win_length + self.win_type = win_type + + # determine hop length from window function + if self.win_type == 'rectangular': # non-overlapping frames + self.hop_length = self.win_length + else: + self.hop_length = self.win_length // 2 + + @staticmethod + def _get_win_func(win_type: str): + if win_type == 'rectangular': + return lambda m: torch.ones(m) + elif win_type == 'triangular': + return lambda m: torch.as_tensor(np.bartlett(m)).float() + elif win_type == 'hann': + return lambda m: torch.hann_window(m) + + def _wav_to_frame(self, x: torch.Tensor): + """ + Given waveform audio, divide into (overlapping) frames. 
+
+ :param x: waveform audio of shape (n_batch, signal_length)
+ :return: framed audio of shape (n_batch, n_frames, frame_len)
+ """
+
+ assert x.ndim == 2
+ n_batch, signal_len = x.shape
+
+ # compute required number of frames given stored frame length
+ if self.win_type == 'rectangular': # non-overlapping frames
+ n_frames = signal_len // self.win_length + 1
+ pad_len = n_frames * self.win_length
+ elif self.win_type == 'hann': # 50% overlap
+ n_frames = signal_len // self.hop_length + 1
+ pad_len = (n_frames - 1) * (self.win_length // 2) + self.win_length
+ else:
+ raise ValueError(f'Invalid window type {self.win_type}')
+
+ # pad input audio to integer number of frames
+ if signal_len < pad_len:
+ x = F.pad(x, (0, pad_len - signal_len))
+ else:
+ x = x[..., :pad_len]
+
+ # divide input audio into frames
+ if self.win_type == 'rectangular':
+ x = x.unfold(-1, self.win_length, self.win_length)
+ else:
+ x = x.unfold(-1, self.win_length, self.win_length // 2)
+
+ return x
+
+ def _frame_to_wav(self, x: torch.Tensor):
+ """
+ Given framed audio, resynthesize waveform via overlap-add.
+
+ :param x: framed audio of shape (n_batch, n_frames, frame_len)
+ :return: waveform audio of shape (n_batch, signal_length)
+ """
+
+ assert x.ndim == 3
+ n_batch, n_frames, _ = x.shape
+
+ # restore signal from frames using overlap-add
+ if self.win_type == 'rectangular':
+
+ pad_len = n_frames * self.win_length
+
+ x = F.fold(
+ x.permute(0, 2, 1),
+ (1, pad_len),
+ kernel_size=(1, self.win_length),
+ stride=(1, self.win_length)
+ ).reshape(n_batch, -1)
+
+ else:
+
+ pad_len = (n_frames - 1) * (self.win_length // 2) + self.win_length
+
+ # obtain window function
+ win = self._get_win_func(self.win_type)(
+ self.win_length
+ ).to(x).reshape(1, 1, -1)
+
+ x = x * win # apply windowing along final dimension
+
+ # use `nn.functional.fold` to perform overlap-add synthesis; for
+ # reference, see https://tinyurl.com/pw7mv9hh
+ x = F.fold(
+ x.permute(0, 2, 1),
+ (1, pad_len),
+ kernel_size=(1, self.win_length),
+ stride=(1, self.win_length // 2)
+ ).reshape(n_batch, -1)
+
+ return x
+
+ @staticmethod
+ def _match_input(x: torch.Tensor, x_ref: torch.Tensor):
+ """
+ Given adversarial output, match scale and dimensions to original input
+
+ :param x: adversarial audio of shape (n_batch, ..., adv_signal_length)
+ :param x_ref: original audio of shape (n_batch, ..., signal_length)
+ """
+
+ assert x.ndim >= 2 and x_ref.ndim >= 2
+
+ # match original dimensions
+ if x.ndim < x_ref.ndim:
+ x = x.unsqueeze(1)
+
+ # prepare to peak-normalize output audio
+ peak = torch.max(torch.abs(x_ref), -1)[0].reshape(-1)
+
+ # peak-normalize to match input
+ factor = peak / torch.max(torch.abs(x), -1)[0].reshape(-1)
+ factor = factor.reshape(x.shape[0], *((1,)*(x.ndim - 1)))
+ x = (x * factor)[..., :x_ref.shape[-1]]
+
+ return x
+
+ def _remove_frequencies(self, x: torch.Tensor):
+ """
+ Remove frequency content below relative energy threshold.
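+
+ For example (numbers illustrative): threshold_db=30.0 corresponds to
+ an energy ratio of 10 ** (-30 / 10) = 0.001, so the lowest-energy
+ bins jointly accounting for up to 0.1% of total frame energy are
+ zeroed.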
+ + :param x: framed audio of shape (n_batch, n_frames, frame_len) + :return: perturbed audio frames, shape (n_batch, n_frames, frame_len) + """ + + # convert threshold to energy ratio + threshold = 10 ** (-self.threshold_db / 10) + + assert x.ndim == 3 # (n_batch, n_frames, frame_len) + n_batch = x.shape[0] + + # if frames overlap, pad input to ensure each sample occurs in the same + # number of frames + if self.win_type in ["hann", "triangular"]: + x = F.pad(x, (self.win_length // 2, self.win_length // 2)) + + # compute power spectral density (PSD) of each frame, doubling paired + # frequencies (non-DC, non-nyquist) + x_rfft = fft.rfft(x) + x_psd = torch.abs(x_rfft) ** 2 # (n_batch, n_frames, n_fft) + + if x.shape[-1] % 2: # odd: DC frequency + x_psd[..., 1:] *= 2 + else: # even: DC and Nyquist frequencies + x_psd[..., 1:-1] *= 2 + + # sort frequency bins ascending by PSD + x_psd_index = torch.argsort(x_psd, dim=-1) + reordered = torch.gather(x_psd, -1, x_psd_index) + + # compute cumulative frequency energy across bins + cumulative = torch.cumsum(reordered, dim=-1) + + # set threshold relative to highest-energy bin + norm_threshold = (threshold * cumulative[..., -1]).unsqueeze(-1) + cutoff = torch.searchsorted(cumulative, norm_threshold, right=True) + + # zero bins below threshold, using sorted indices + n_frames = x_rfft.shape[1] + for i in range(n_batch): + for j in range(n_frames): + x_rfft[i, j, x_psd_index[i, j, :cutoff[i, j]]] = 0 + + # invert to waveform audio + x = fft.irfft( + x_rfft, + x.shape[-1] + ) + + # undo additional padding if necessary + if self.win_type == "hann": + x = x[..., self.win_length // 2: -self.win_length // 2] + + return x # (n_batch, n_frames, frame_len) + + def set_reference(self, x: torch.Tensor): + """ + Given reference input, initialize perturbation parameters accordingly + and match input device. + + :param x: reference audio, shape (n_batch, n_channels, signal_length) + """ + self.threshold_db = self.threshold_db.to(x.device) + + def set_threshold(self, threshold_db: float): + self.threshold_db.fill_(threshold_db) + + def forward(self, x: torch.Tensor, *args, **kwargs): + """ + Apply perturbation to inputs. + + :param x: input audio, shape (n_batch, n_channels, signal_length) + """ + + # require batch dimension + assert x.ndim >= 2 + n_batch, signal_len = x.shape[0], x.shape[-1] + + x_orig = x.clone().detach() + + # discard channel dimension + if x.ndim > 2: + x = x.mean(dim=1, keepdim=True).squeeze(1) + + x = self._wav_to_frame(x) + x = self._remove_frequencies(x) + x = self._frame_to_wav(x) + x = self._match_input(x, x_orig) + + return x + + def _visualize_top_level(self) -> Dict[str, torch.Tensor]: + """ + Visualize top-level (non-recursive) perturbation parameters. + + :return: tag (string) / image (tensor) pairs, stored in a dictionary + """ + + visualizations = {} + return visualizations + + def _project_valid_top_level(self): + """ + Project top-level (non-recursive) parameters to valid range. 
pass
diff --git a/voicebox/src/attacks/offline/perturbation/perturbation.py b/voicebox/src/attacks/offline/perturbation/perturbation.py new file mode 100644 index 0000000000000000000000000000000000000000..278da679f0a9b0ee58659a08eb6f0144b466ada4 --- /dev/null +++ b/voicebox/src/attacks/offline/perturbation/perturbation.py @@ -0,0 +1,159 @@
+import torch
+import torch.nn as nn
+
+from typing import Dict
+
+################################################################################
+# Base class for all adversarial perturbation operators
+################################################################################
+
+
+class Perturbation(nn.Module):
+ """
+ Base class for adversarial perturbation operators. Allows (recursive)
+ composition of Perturbation objects and facilitates projected gradient
+ descent optimization by controlling parameter and gradient access.
+
+ Subclasses must override the methods `forward()`, `set_reference()`, and
+ `_project_valid_top_level()`. Optionally, subclasses can override
+ `_visualize_top_level()` to produce parameter visualizations for logging.
+ """
+ def __init__(self):
+ super().__init__()
+
+ @staticmethod
+ def _check_broadcastable(x: torch.Tensor, x_ref: torch.Tensor):
+ """
+ Check whether input and reference tensors are broadcastable
+ """
+
+ broadcastable = all(
+ (m == n) or (m == 1) or (n == 1) for m, n in zip(
+ x.shape[::-1], x_ref.shape[::-1]
+ )
+ )
+
+ # broadcast cannot expand input batch dimension
+ valid = x.shape[0] == x_ref.shape[0] or x_ref.shape[0] == 1
+
+ return broadcastable and valid
+
+ @staticmethod
+ def _freeze_grad(m: nn.Module):
+ """
+ Disable gradient computation for all parameters in given module
+ :param m: torch.nn.Module object
+ """
+ for module in m.modules():
+ for param in module.parameters():
+ param.requires_grad = False
+
+ def set_reference(self, x: torch.Tensor):
+ """
+ Given reference input, initialize perturbation parameters accordingly
+ and match input device.
+
+ :param x: reference audio, shape (n_batch, n_channels, signal_length)
+ """
+
+ raise NotImplementedError()
+
+ def forward(self, x: torch.Tensor, *args, **kwargs):
+ """
+ Apply perturbation to inputs.
+
+ :param x: input audio, shape (n_batch, n_channels, signal_length)
+ """
+ raise NotImplementedError()
+
+ def retrieve_parameter_gradients(self):
+ """
+ Retrieve 'flattened' representation of perturbation parameter
+ gradients, including those of all stored Perturbation objects.
+
+ :return: flattened parameter gradients, shape (n_parameters,)
+ """
+
+ flattened_grad = []
+
+ for param in self.parameters():
+ if param.requires_grad:
+
+ if param.grad is None:
+ flattened_grad.append(
+ torch.zeros_like(param).detach().flatten()
+ )
+ else:
+ flattened_grad.append(param.grad.detach().flatten())
+
+ return torch.cat(flattened_grad, dim=-1)
+
+ def set_parameter_gradients(self, flattened_grad: torch.Tensor):
+ """
+ Given flattened gradients, apply to stored parameters.
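+
+ A round-trip sketch (assuming `pert` is a Perturbation instance with
+ populated gradients; names are illustrative):
+
+ g = pert.retrieve_parameter_gradients() # shape (n_parameters,)
+ pert.set_parameter_gradients(g * 0.5) # write back rescaled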
+ + :param flattened_grad: parameter gradients of shape (n_parameters,) + """ + + # check that flattened gradients have valid shape + prod = sum( + [p.shape.numel() for p in self.parameters() if p.requires_grad] + ) + + assert flattened_grad.ndim <= 1 + assert flattened_grad.numel() == prod + + idx = 0 + for param in self.parameters(): + if param.requires_grad: + param_length = param.shape.numel() + grad = flattened_grad[idx:idx + param_length].reshape( + param.shape + ).detach() + param.grad = grad + idx += param_length + + def _visualize_top_level(self) -> Dict[str, torch.Tensor]: + """ + Visualize top-level (non-recursive) perturbation parameters. + + :return: tag (string) / image (tensor) pairs, stored in a dictionary + """ + return {} + + def visualize(self) -> Dict[str, torch.Tensor]: + """ + Visualize perturbation parameters. + + :return: tag (string) / image (tensor) pairs, stored in a dictionary + """ + + # collect visualizations for top-level parameters + visualizations = self._visualize_top_level() + + # collect visualizations for stored Perturbation objects + for m in self.children(): + if isinstance(m, Perturbation): + visualizations = {**visualizations, **m.visualize()} + + return visualizations + + def _project_valid_top_level(self): + """ + Project top-level (non-recursive) parameters to valid range. + """ + raise NotImplementedError() + + def project_valid(self): + """ + Project perturbation parameters to valid range. Apply to all stored + Perturbation objects recursively, such that each Perturbation is + responsible for its own projection. + """ + + # project top-level parameters (non-recursive) + self._project_valid_top_level() + + for m in self.children(): + if isinstance(m, Perturbation): + m.project_valid() diff --git a/voicebox/src/attacks/offline/perturbation/voicebox/__init__.py b/voicebox/src/attacks/offline/perturbation/voicebox/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cd57d546aa969241c109333272b6b7d912435b6a --- /dev/null +++ b/voicebox/src/attacks/offline/perturbation/voicebox/__init__.py @@ -0,0 +1 @@ +from src.attacks.offline.perturbation.voicebox.voicebox import VoiceBox diff --git a/voicebox/src/attacks/offline/perturbation/voicebox/__pycache__/__init__.cpython-310.pyc b/voicebox/src/attacks/offline/perturbation/voicebox/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16e20be7bcef24ca4a6ae05eca37c322f2b56ab5 Binary files /dev/null and b/voicebox/src/attacks/offline/perturbation/voicebox/__pycache__/__init__.cpython-310.pyc differ diff --git a/voicebox/src/attacks/offline/perturbation/voicebox/__pycache__/__init__.cpython-39.pyc b/voicebox/src/attacks/offline/perturbation/voicebox/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53eb880b35a464d848e859d1f3bfc56161908967 Binary files /dev/null and b/voicebox/src/attacks/offline/perturbation/voicebox/__pycache__/__init__.cpython-39.pyc differ diff --git a/voicebox/src/attacks/offline/perturbation/voicebox/__pycache__/batchnorm.cpython-39.pyc b/voicebox/src/attacks/offline/perturbation/voicebox/__pycache__/batchnorm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..925c8f57550b8d18acea33108578540ddbf9fef3 Binary files /dev/null and b/voicebox/src/attacks/offline/perturbation/voicebox/__pycache__/batchnorm.cpython-39.pyc differ diff --git a/voicebox/src/attacks/offline/perturbation/voicebox/__pycache__/bottleneck.cpython-39.pyc 
b/voicebox/src/attacks/offline/perturbation/voicebox/__pycache__/bottleneck.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9266997a4b1aae01897eaa42d3071d3759501045 Binary files /dev/null and b/voicebox/src/attacks/offline/perturbation/voicebox/__pycache__/bottleneck.cpython-39.pyc differ diff --git a/voicebox/src/attacks/offline/perturbation/voicebox/__pycache__/expnorm.cpython-39.pyc b/voicebox/src/attacks/offline/perturbation/voicebox/__pycache__/expnorm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5bf630fa617fafe0c78e4ec710a378959a5fb2e5 Binary files /dev/null and b/voicebox/src/attacks/offline/perturbation/voicebox/__pycache__/expnorm.cpython-39.pyc differ diff --git a/voicebox/src/attacks/offline/perturbation/voicebox/__pycache__/film.cpython-39.pyc b/voicebox/src/attacks/offline/perturbation/voicebox/__pycache__/film.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb66abacbb69f3cafb58f80d8310b0b1ce86030f Binary files /dev/null and b/voicebox/src/attacks/offline/perturbation/voicebox/__pycache__/film.cpython-39.pyc differ diff --git a/voicebox/src/attacks/offline/perturbation/voicebox/__pycache__/filter.cpython-39.pyc b/voicebox/src/attacks/offline/perturbation/voicebox/__pycache__/filter.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1fae20cd47b2349fceec629aa5b0dd251f034022 Binary files /dev/null and b/voicebox/src/attacks/offline/perturbation/voicebox/__pycache__/filter.cpython-39.pyc differ diff --git a/voicebox/src/attacks/offline/perturbation/voicebox/__pycache__/lookahead.cpython-39.pyc b/voicebox/src/attacks/offline/perturbation/voicebox/__pycache__/lookahead.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7ecd3855c5c4eb95acd92b1e7453e27677f7a5d Binary files /dev/null and b/voicebox/src/attacks/offline/perturbation/voicebox/__pycache__/lookahead.cpython-39.pyc differ diff --git a/voicebox/src/attacks/offline/perturbation/voicebox/__pycache__/loudness.cpython-39.pyc b/voicebox/src/attacks/offline/perturbation/voicebox/__pycache__/loudness.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a54824533d2acbfc1c457cbba3a37092e709081b Binary files /dev/null and b/voicebox/src/attacks/offline/perturbation/voicebox/__pycache__/loudness.cpython-39.pyc differ diff --git a/voicebox/src/attacks/offline/perturbation/voicebox/__pycache__/mlp.cpython-39.pyc b/voicebox/src/attacks/offline/perturbation/voicebox/__pycache__/mlp.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f632beefc4be1dd8f64f574535b54cb216951a74 Binary files /dev/null and b/voicebox/src/attacks/offline/perturbation/voicebox/__pycache__/mlp.cpython-39.pyc differ diff --git a/voicebox/src/attacks/offline/perturbation/voicebox/__pycache__/pitch.cpython-310.pyc b/voicebox/src/attacks/offline/perturbation/voicebox/__pycache__/pitch.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a358d0667b6167640d3a4431ab9aa91185bd89e Binary files /dev/null and b/voicebox/src/attacks/offline/perturbation/voicebox/__pycache__/pitch.cpython-310.pyc differ diff --git a/voicebox/src/attacks/offline/perturbation/voicebox/__pycache__/pitch.cpython-39.pyc b/voicebox/src/attacks/offline/perturbation/voicebox/__pycache__/pitch.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9506adcd55def89129fef7dd9390fe0b66384458 Binary files 
/dev/null and b/voicebox/src/attacks/offline/perturbation/voicebox/__pycache__/pitch.cpython-39.pyc differ diff --git a/voicebox/src/attacks/offline/perturbation/voicebox/__pycache__/projection.cpython-39.pyc b/voicebox/src/attacks/offline/perturbation/voicebox/__pycache__/projection.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..43862dbcde7890fa67c28fc8ccabee26228bb4d9 Binary files /dev/null and b/voicebox/src/attacks/offline/perturbation/voicebox/__pycache__/projection.cpython-39.pyc differ diff --git a/voicebox/src/attacks/offline/perturbation/voicebox/__pycache__/spec.cpython-39.pyc b/voicebox/src/attacks/offline/perturbation/voicebox/__pycache__/spec.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f37e12353a78c8868deb2037adedf8a744680df2 Binary files /dev/null and b/voicebox/src/attacks/offline/perturbation/voicebox/__pycache__/spec.cpython-39.pyc differ diff --git a/voicebox/src/attacks/offline/perturbation/voicebox/__pycache__/voicebox.cpython-310.pyc b/voicebox/src/attacks/offline/perturbation/voicebox/__pycache__/voicebox.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f9d05d0f7a5f3474e631736d340c7c14ce8b7b3 Binary files /dev/null and b/voicebox/src/attacks/offline/perturbation/voicebox/__pycache__/voicebox.cpython-310.pyc differ diff --git a/voicebox/src/attacks/offline/perturbation/voicebox/__pycache__/voicebox.cpython-39.pyc b/voicebox/src/attacks/offline/perturbation/voicebox/__pycache__/voicebox.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49bf75278a246a675d078d87e9401497a65394da Binary files /dev/null and b/voicebox/src/attacks/offline/perturbation/voicebox/__pycache__/voicebox.cpython-39.pyc differ diff --git a/voicebox/src/attacks/offline/perturbation/voicebox/batchnorm.py b/voicebox/src/attacks/offline/perturbation/voicebox/batchnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..963de9ecb7b0d4defcbf6988f0d914d28257dafd --- /dev/null +++ b/voicebox/src/attacks/offline/perturbation/voicebox/batchnorm.py @@ -0,0 +1,45 @@ +import torch +import torch.nn as nn + +################################################################################ +# Time-distributed batch normalization layer +################################################################################ + + +class BatchNorm(nn.Module): + """Apply batch normalization along feature dimension only""" + def __init__(self, + num_features, + feature_dim: int = -1, **kwargs): + + super().__init__() + + if feature_dim == 1: + self.permute = (0, 1, 2) + elif feature_dim in [2, -1]: + self.permute = (0, 2, 1) + else: + raise ValueError(f'Must provide batch-first inputs') + + self.num_features = num_features + self.feature_dim = feature_dim + + # pass any additional arguments to batch normalization module + self.bn = nn.BatchNorm1d(num_features=self.num_features, **kwargs) + + def forward(self, x: torch.Tensor): + + # check input dimensions + assert x.ndim == 3 + assert x.shape[self.feature_dim] == self.num_features + + # reshape to ensure batch normalization is time-distributed + x = x.permute(*self.permute) + + # apply normalization + x = self.bn(x) + + # restore original shape + x = x.permute(*self.permute) + + return x diff --git a/voicebox/src/attacks/offline/perturbation/voicebox/bottleneck.py b/voicebox/src/attacks/offline/perturbation/voicebox/bottleneck.py new file mode 100644 index 
0000000000000000000000000000000000000000..e6763050f13937519e75f5267a24595173b72d14 --- /dev/null +++ b/voicebox/src/attacks/offline/perturbation/voicebox/bottleneck.py @@ -0,0 +1,123 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +################################################################################ +# Latent bottleneck modules +################################################################################ + + +class RNNBottleneck(nn.Module): + + def __init__(self, + input_size: int = 512, + hidden_size: int = 2048, + proj_size: int = 512, + num_layers: int = 8, + downsample_index: int = 1, + downsample_factor: int = 2, + dropout_prob: float = 0 + ): + super().__init__() + + # downsampling can occur no earlier than first layer + assert downsample_index >= 0 + + self.num_layers = num_layers + self.downsample_index = downsample_index + self.downsample_factor = downsample_factor + self.dropout = nn.Dropout(dropout_prob) + + # optionally, apply projection + if proj_size >= hidden_size: + proj_size = 0 + + # build multi-layer recurrent network + rnn = [] + for i in range(num_layers): + + # first layer + if i == 0: + pass + + # pre-downsampling layers + elif i <= downsample_index: + input_size = proj_size or hidden_size + + # immediately after downsampling layer + elif i == downsample_index + 1: + input_size = (proj_size or hidden_size) * downsample_factor + + # subsequent layers + else: + input_size = proj_size or hidden_size + + rnn.append( + nn.LSTM( + input_size=input_size, + hidden_size=hidden_size, + proj_size=proj_size, + num_layers=1, + batch_first=True, + bidirectional=False, + bias=True + ) + ) + self.rnn = nn.ModuleList(rnn) + + def forward(self, x): + + for i in range(self.num_layers): + + x, _ = self.rnn[i](x, hx=None) + + x = self.dropout(x) + + if i == self.downsample_index: + + n_batch, n_frames, proj_size = x.shape + + # determine necessary padding to allow temporal downsampling + pad_len = self.downsample_factor * math.ceil(n_frames / self.downsample_factor) - n_frames + + # apply causal padding + x = F.pad(x, (0, 0, 0, pad_len)) + + # apply temporal downsampling + x = torch.reshape(x, (n_batch, x.shape[1] // self.downsample_factor, x.shape[2] * self.downsample_factor)) + + return x + + +class CausalTransformer(nn.Module): + + def __init__(self, + hidden_size: int, + dim_feedforward: int = 2048, + depth: int = 2, + heads: int = 8, + dropout_prob: float = 0.0): + + super().__init__() + + attention_block = nn.TransformerEncoderLayer( + d_model=hidden_size, + nhead=heads, + dim_feedforward=dim_feedforward, + dropout=dropout_prob, + batch_first=True + ) + self.transformer = nn.TransformerEncoder( + encoder_layer=attention_block, + num_layers=depth, + norm=None + ) + + def forward(self, x: torch.Tensor): + + _, n_frames, _ = x.shape + causal_mask = torch.triu(torch.ones((n_frames, n_frames), dtype=torch.bool, device=x.device), diagonal=1) + + return self.transformer(x, mask=causal_mask) diff --git a/voicebox/src/attacks/offline/perturbation/voicebox/expnorm.py b/voicebox/src/attacks/offline/perturbation/voicebox/expnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..68db730798b8e291f03b2336bdaaec5555be8949 --- /dev/null +++ b/voicebox/src/attacks/offline/perturbation/voicebox/expnorm.py @@ -0,0 +1,100 @@ +import math + +import torch +import torch.nn as nn + +from src.data import DataProperties + +################################################################################ +# Exponential unit 
normalization module +################################################################################ + + +class ExponentialUnitNorm(nn.Module): + """Unit-normalize magnitude spectrogram""" + + def __init__(self, + decay: float, + hop_size: int, + n_freq: int, + eps: float = 1e-14): + """ + Perform exponential unit normalization on magnitude spectrogram + + Parameters + ---------- + decay (float): decay time constant in seconds; a frame's contribution + to the running average falls by a factor of 1/e after + `decay` seconds + + hop_size (int): analysis hop size in samples + + n_freq (int): number of frequency bins per frame + + eps (float): small constant for numerical stability + """ + + super().__init__() + + # compute exponential decay factor + self.alpha = self._get_norm_alpha( + DataProperties.get('sample_rate'), + hop_size, + decay + ) + self.eps = eps + self.n_freq = n_freq + self.init_state: torch.Tensor + + # initialize per-band states for unit norm calculation + self.reset() + + @staticmethod + def _get_norm_alpha(sr: int, hop_size: int, decay: float): + """ + Compute exponential decay factor `alpha` for a given decay window size + in seconds + """ + dt = hop_size / sr + a_ = math.exp(-dt / decay) + + precision = 3 + a = 1.0 + + while a >= 1.0: + a = round(a_, precision) + precision += 1 + + return a + + def reset(self): + """(Re)-initialize stored state""" + s = torch.linspace(0.001, 0.0001, self.n_freq).view( + 1, self.n_freq + ) # broadcast with (n_batch, 1, n_frames, n_freq, 2) + self.register_buffer("init_state", s) + + def forward(self, x: torch.Tensor): + """ + Perform exponential unit normalization on magnitude spectrogram + + Parameters + ---------- + x (Tensor): shape (n_batch, n_freq, n_frames) + + Returns + ------- + normalized (Tensor): shape (n_batch, n_freq, n_frames) + """ + x_abs = x.clamp_min(1e-10).sqrt() + + n_batch, n_freq, n_frames = x.shape + assert n_freq == self.n_freq + + state = self.init_state.clone().expand( + n_batch, n_freq) # (n_batch, n_freq) + + out_states = [] + for f in range(n_frames): + state = x_abs[:, :, f] * (1 - self.alpha) + state * self.alpha + out_states.append(state) + + return x / torch.stack(out_states, 2).sqrt() diff --git a/voicebox/src/attacks/offline/perturbation/voicebox/film.py b/voicebox/src/attacks/offline/perturbation/voicebox/film.py new file mode 100644 index 0000000000000000000000000000000000000000..c7a2c8104504a74624d7cecc0dee69736b4a727b --- /dev/null +++ b/voicebox/src/attacks/offline/perturbation/voicebox/film.py @@ -0,0 +1,61 @@ +import torch +import torch.nn as nn + +from src.attacks.offline.perturbation.voicebox.batchnorm import BatchNorm + +################################################################################ +# Affine conditioning layer +################################################################################ + + +class FiLM(nn.Module): + """ + Affine conditioning layer, as proposed in Perez et al. + (https://arxiv.org/pdf/1709.07871.pdf). Operates on each channel of a + selected feature representation, with one learned scaling parameter and one + learned bias parameter per channel. + + Code adapted from https://github.com/csteinmetz1/steerable-nafx + """ + def __init__( + self, + cond_dim: int, + num_features: int, + batch_norm: bool = True, + ): + """ + Apply linear projection and batch normalization to obtain affine + conditioning parameters.
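Stepping back to `ExponentialUnitNorm` above: its forward pass is a per-band exponential moving average, state_f = (1 - alpha) * x_f + alpha * state_{f-1}, by which each frame is then divided. A simplified standalone sketch of that recursion, omitting the square-root handling of the real module (alpha and shapes are illustrative):

import torch

alpha = 0.99                       # decay factor, as produced by _get_norm_alpha
x = torch.rand(1, 8, 50) + 0.1     # (n_batch, n_freq, n_frames)

state = torch.full((1, 8), 0.001)  # initial per-band state
out_states = []
for f in range(x.shape[-1]):
    # the new frame contributes (1 - alpha); history contributes alpha
    state = (1 - alpha) * x[:, :, f] + alpha * state
    out_states.append(state)

normalized = x / torch.stack(out_states, dim=2)  # running unit norm per band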
+ + :param cond_dim: dimension of conditioning input + :param num_features: number of feature maps to which conditioning is + applied + :param batch_norm: if True, perform batch normalization + """ + super().__init__() + self.num_features = num_features + self.batch_norm = batch_norm + if batch_norm: + self.bn = BatchNorm(num_features, feature_dim=-1, affine=False) + self.adaptor = torch.nn.Linear(cond_dim, num_features * 2) + + def forward(self, x: torch.Tensor, cond: torch.Tensor): + """ + Apply learned affine conditioning to a feature representation. + + x (Tensor): shape (n_batch, n_frames, num_features) + cond (Tensor): shape (n_batch, n_frames, cond_dim) + """ + + # linear projection of conditioning input + cond = self.adaptor(cond) + + # learn scale and bias parameters per channel, thus 2X num_features + g, b = torch.chunk(cond, 2, dim=-1) + + if self.batch_norm: + x = self.bn(x) # apply BatchNorm without affine + x = (x * g) + b # then apply conditional affine + + return x diff --git a/voicebox/src/attacks/offline/perturbation/voicebox/filter.py b/voicebox/src/attacks/offline/perturbation/voicebox/filter.py new file mode 100644 index 0000000000000000000000000000000000000000..2f6aad723d5e806f9ec03cb2184661d2f7a3a000 --- /dev/null +++ b/voicebox/src/attacks/offline/perturbation/voicebox/filter.py @@ -0,0 +1,322 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.fft as fft + +import numpy as np + +from typing import Union +import warnings + +################################################################################ +# Time-varying FIR filter module +################################################################################ + + +class FilterLayer(nn.Module): + """Encapsulate FIR filtering in network layer""" + def __init__(self, + win_length: int = 512, + win_type: str = 'hann', + n_bands: int = 128, + normalize_ir: Union[str, int, float] = None, + **kwargs): + """ + Given a set of frame-wise controls specifying the frequency amplitude + response of a time-varying FIR filter, apply filter to incoming audio.
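To make the FiLM mechanism above concrete: the conditioning vector is projected to one scale and one bias per feature channel, which then modulate the (optionally batch-normalized) features. A minimal sketch without the batch norm (shapes are illustrative):

import torch
import torch.nn as nn

cond_dim, num_features = 8, 32
adaptor = nn.Linear(cond_dim, num_features * 2)  # one (g, b) pair per channel

x = torch.randn(2, 100, num_features)  # features: (n_batch, n_frames, num_features)
cond = torch.randn(2, 100, cond_dim)   # conditioning: (n_batch, n_frames, cond_dim)

g, b = torch.chunk(adaptor(cond), 2, dim=-1)  # scale and bias, num_features each
y = x * g + b                                 # per-channel affine modulation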
+ + Parameters + ---------- + win_length (int): analysis window length in samples + + win_type (str): window function; must be one of 'rectangular', + 'triangular', or 'hann' + + n_bands (int): number of filter bands + + normalize_ir (int): type of normalization to apply to FIR impulse + responses; must be 1, 2, or None + """ + super().__init__() + + if win_type not in ['rectangular', 'triangular', 'hann']: + raise ValueError(f'Invalid window type {win_type}') + + if normalize_ir not in [None, 'none', 1, 2]: + raise ValueError( + f'Invalid IR normalization type {normalize_ir}') + + # round window size to next power of 2 + next_pow_2 = 2**(win_length - 1).bit_length() + if win_length != next_pow_2: + warnings.warn(f'Rounding block size {win_length} to nearest power' + f' of 2 ({next_pow_2})') + + # store attributes + self.win_length = next_pow_2 + self.n_bands = n_bands + self.win_type = win_type + self.normalize_ir = normalize_ir + + # determine hop length from window function + if self.win_type == 'rectangular': # non-overlapping frames + self.hop_length = self.win_length + else: # overlapping frames + self.hop_length = self.win_length // 2 + + @staticmethod + def _get_win_func(win_type: str): + """Obtain callable window function by name""" + if win_type == 'rectangular': + return lambda m: torch.ones(m) + elif win_type == 'hann': + return lambda m: torch.hann_window(m) + elif win_type == 'triangular': + return lambda m: torch.as_tensor(np.bartlett(m)).float() + + def _amp_to_ir(self, amp: torch.Tensor): + """ + Convert filter frequency amplitude response into a time-domain impulse + response. The filter response is given as a per-frame transfer function, + and a symmetric impulse response is returned. + + Parameters + ---------- + amp (torch.Tensor): shape (n_batch, n_frames, n_bands) or + (1, n_frames, n_bands); holds per-frame frequency + magnitude response of time-varying FIR filter + + Returns + ------- + impulse (torch.Tensor): shape (n_batch, n_frames, 2 * n_bands + 1) + """ + + # convert to complex zero-phase representation + amp = torch.stack([amp, torch.zeros_like(amp)], -1) + amp = torch.view_as_complex(amp) # (n_batch, n_frames, n_bands) + + # compute 1D inverse FFT along final dimension, treating bands as + # Fourier frequencies of analysis + impulse = fft.irfft(amp) + + # require filter size to match time-domain transform of filter bands + filter_size = impulse.shape[-1] + + # apply window to shifted zero-phase (symmetric) form of impulse + impulse = torch.roll(impulse, filter_size // 2, -1) + win = torch.hann_window( + filter_size, dtype=impulse.dtype, device=impulse.device + ) + + if self.normalize_ir is None: # or self.normalize_ir == 'none': disabled string option for jit.scripting. 
+ pass + elif self.normalize_ir == 1: + impulse = impulse / (torch.sum( + impulse, dim=-1, keepdim=True + ) + 1e-20) + elif self.normalize_ir == 2: + # epsilon belongs inside the denominator, as in the L1 branch above + impulse = impulse / (torch.norm( + impulse, p=2, dim=-1, keepdim=True + ) + 1e-20) + + return impulse * win + + def _fft_convolve(self, + signal: torch.Tensor, + kernel: torch.Tensor, + n_fft: int): + """ + Given waveform representations of signal and FIR filter, convolve + via point-wise multiplication in Fourier domain + + Parameters + ---------- + signal (torch.Tensor): shape (n_batch, n_frames, win_length); holds + framed audio input + + kernel (torch.Tensor): shape (n_batch, n_frames, 2 * n_bands + 1); holds + time-domain impulse responses corresponding to + filter controls for each frame + + n_fft (int): number of FFT bins + + Returns + ------- + convolved (torch.Tensor): shape (n_batch, n_frames, n_fft); holds + filtered audio + """ + + # right-pad kernel and frames to n_fft samples + signal = nn.functional.pad(signal, (0, n_fft - signal.shape[-1])) + kernel = nn.functional.pad(kernel, (0, n_fft - kernel.shape[-1])) + + # apply convolution in Fourier domain and invert + convolved = fft.irfft(fft.rfft(signal) * fft.rfft(kernel)) + + # account for frame-by-frame delay + rolled = torch.roll(convolved, shifts=-(self.n_bands - 1), dims=-1) + + return rolled + + def _pad_and_frame(self, x: torch.Tensor, n_frames: int): + """ + Pad audio to given frame length and divide into frames + + Parameters + ---------- + x (torch.Tensor): shape (n_batch, n_channels, signal_len); input audio + + n_frames (int): target length in frames + + Returns + ------- + framed (torch.Tensor): shape (n_batch, 1, n_frames, win_length); holds + padded and framed input audio + """ + n_batch, *channel_dims, signal_len = x.shape + + if self.win_type == 'rectangular': + pad_len = n_frames * self.win_length + elif self.win_type in ['triangular', 'hann']: + pad_len = (n_frames - 1) * (self.win_length // 2) + self.win_length + else: + raise ValueError(f'Invalid window type {self.win_type}') + + # apply padding/trim + if signal_len < pad_len: + x = nn.functional.pad(x, (0, pad_len - x.shape[-1])) + else: + x = x[..., :pad_len] + + # divide audio into frames + x = x.unfold(-1, self.win_length, self.hop_length) + + return x + + def _ola(self, x: torch.Tensor): + """ + Re-synthesize waveform audio from filtered frames via overlap-add (OLA) + + Parameters + ---------- + x (torch.Tensor): shape (n_batch, n_frames, n_fft); holds framed and + filtered audio + + Returns + ------- + synthesized (torch.Tensor): shape (n_batch, n_channels, signal_len); + holds reconstructed audio + """ + + # check dimensions + assert x.ndim == 3 + n_batch, n_frames, n_fft = x.shape + + if self.win_type == 'rectangular': + + # compute target output length + pad_len = self.win_length * (n_frames - 1) + n_fft + + x = nn.functional.fold( + x.permute(0, 2, 1), + (1, pad_len), + kernel_size=(1, n_fft), + stride=(1, self.win_length) + ).reshape(n_batch, -1) + + elif self.win_type in ['triangular', 'hann']: + + # compute target output length + pad_len = (n_frames - 1) * (self.win_length // 2) + self.win_length + truncated_len = ((pad_len - self.win_length) + // (self.win_length // 2) + ) * (self.win_length // 2) + n_fft + + # obtain window functions and pad to match frame length + win = self._get_win_func( + self.win_type + )(self.win_length).to(x).reshape(1, 1, -1) + win_pad_len = x.shape[-1] - win.shape[-1] + win = nn.functional.pad(win, (0, win_pad_len)) + + # apply window
frame-by-frame + x = x * win + + # use `nn.functional.fold` to perform overlap-add synthesis; for + # reference, see https://tinyurl.com/pw7mv9hh + x = nn.functional.fold( + x.permute(0, 2, 1), + (1, truncated_len), + kernel_size=(1, n_fft), + stride=(1, self.win_length // 2) + ).reshape(n_batch, -1) + + return x + + def forward(self, x: torch.Tensor, controls: torch.Tensor): + """ + Use given controls to parameterize a time-varying FIR filter and apply + to incoming audio + + Parameters + ---------- + x (torch.Tensor): shape (n_batch, n_channels, signal_len); holds + the input audio + controls (torch.Tensor): shape (n_batch, n_frames, n_bands); holds + frame-wise filter controls + """ + + # ensure the input contains at least a single frame of audio + assert x.shape[-1] >= self.win_length + + # check filter control dimensions + assert controls.shape[-1] == self.n_bands + + # avoid modifying input audio + x = x.clone().detach() + n_batch, *channel_dims, signal_len = x.shape + + # require batch, channel dimensions + assert x.ndim >= 2 + if x.ndim == 2: + x = x.unsqueeze(1) + + # convert to mono audio + x = x.mean(dim=1) + + # pad or trim signal to match number of frames in filter controls + n_frames = controls.shape[1] + x = self._pad_and_frame(x, n_frames) + + # convert filter controls (frequency amplitude responses) to frame-by- + # frame time-domain impulse responses + impulse = self._amp_to_ir(controls) + + # determine FFT size using inferred FIR waveform filter length + # (accounting for padding) + n_fft_min = self.win_length + 2 * (self.n_bands - 1) + n_fft = pow(2, math.ceil(math.log2(n_fft_min))) # use next power of 2 + + # convolve frame-by-frame in FFT domain; resulting padded frames will + # contain "ringing" overlapping segments which must be summed + x = self._fft_convolve( + x, + impulse, + n_fft + ).contiguous() # (n_batch, n_frames_overlap, n_fft) + + # restore signal from frames using overlap-add + x = self._ola(x) + + # match original dimensions + x = x[..., :signal_len].reshape( + n_batch, + *((1,) * len(channel_dims)), + signal_len + ) # trim signal to original length + + return x diff --git a/voicebox/src/attacks/offline/perturbation/voicebox/lookahead.py b/voicebox/src/attacks/offline/perturbation/voicebox/lookahead.py new file mode 100644 index 0000000000000000000000000000000000000000..3ce128d79b8855847b290c976b24fe0855c7ea9e --- /dev/null +++ b/voicebox/src/attacks/offline/perturbation/voicebox/lookahead.py @@ -0,0 +1,56 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +################################################################################ +# Post-bottleneck lookahead module from DeepSpeech 2 +################################################################################ + + +class Lookahead(nn.Module): + """ + Lookahead Convolution Layer for Unidirectional Recurrent Neural Networks + from Wang et al 2016. 
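Returning briefly to `FilterLayer` above: the central trick is that per-frame FIR filtering is performed by pointwise multiplication in the Fourier domain, with both operands zero-padded to a common FFT size so the circular convolution equals the linear one. A standalone sketch of that equivalence (sizes are illustrative):

import torch
import torch.fft as fft
import torch.nn.functional as F

sig = torch.randn(1, 3, 256)  # (n_batch, n_frames, win_length) framed audio
ker = torch.randn(1, 3, 31)   # (n_batch, n_frames, filter_len) per-frame FIRs

# linear convolution has length 256 + 31 - 1 = 286; pad both to 512 samples
n_fft = 512
sig_p = F.pad(sig, (0, n_fft - sig.shape[-1]))
ker_p = F.pad(ker, (0, n_fft - ker.shape[-1]))

# multiply spectra and invert: linear convolution of each frame with its filter
conv = fft.irfft(fft.rfft(sig_p) * fft.rfft(ker_p))  # (1, 3, 512)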
+ """ + def __init__(self, n_features: int, lookahead_frames: int): + """ + Parameters + ---------- + n_features (int): feature dimension + lookahead_frames (int): lookahead length in frames + """ + super(Lookahead, self).__init__() + + assert lookahead_frames >= 0, 'Must provide nonzero context length' + + self.lookahead_frames = lookahead_frames + self.n_features = n_features + + # pad to preserve sequence length in output + self.pad = (0, self.lookahead_frames) + + self.conv = nn.Conv1d( + self.n_features, + self.n_features, + kernel_size=self.lookahead_frames + 1, + stride=1, + groups=self.n_features, # independence between features + padding=0, + bias=False + ) + + def forward(self, x: torch.Tensor): + """ + Parameters + ---------- + x (Tensor): shape (n_batch, n_frames, n_features) + + Returns + ------- + out (Tensor): shape (n_batch, n_frames, n_features) + """ + x = x.transpose(1, 2) # (n_batch, n_features, n_frames) + x = F.pad(x, pad=self.pad, value=0) + x = self.conv(x) + x = x.transpose(1, 2).contiguous() # (n_batch, n_features, n_frames) + return x diff --git a/voicebox/src/attacks/offline/perturbation/voicebox/loudness.py b/voicebox/src/attacks/offline/perturbation/voicebox/loudness.py new file mode 100644 index 0000000000000000000000000000000000000000..b842513192f33fa0b7b649a9a74146b64ab92a94 --- /dev/null +++ b/voicebox/src/attacks/offline/perturbation/voicebox/loudness.py @@ -0,0 +1,58 @@ +import torch +import torch.nn as nn + +import librosa as li +import numpy as np + +from src.data import DataProperties + +################################################################################ +# Extract frame-wise A-weighted loudness +################################################################################ + + +class LoudnessEncoder(nn.Module): + """Extract frame-wise A-weighted loudness""" + def __init__(self, + hop_length: int = 128, + n_fft: int = 256 + ): + + super().__init__() + + self.hop_length = hop_length + self.n_fft = n_fft + + def forward(self, x: torch.Tensor): + + n_batch, *channel_dims, signal_len = x.shape + + # require batch, channel dimensions + assert x.ndim >= 2 + if x.ndim == 2: + x = x.unsqueeze(1) + + # convert to mono audio + x = x.mean(dim=1) + + spec = li.stft( + x.detach().cpu().numpy(), + n_fft=self.n_fft, + hop_length=self.hop_length, + win_length=self.n_fft, + center=True, + ) + spec = np.log(abs(spec) + 1e-7) + + # compute A-weighting curve for frequencies of analysis + f = li.fft_frequencies( + sr=DataProperties.get('sample_rate'), n_fft=self.n_fft) + a_weight = li.A_weighting(f) + + # apply multiplicative weighting via addition in log domain + spec = spec + a_weight.reshape(1, -1, 1) + + # take mean over each frame + loudness = torch.from_numpy(np.mean(spec, 1)).unsqueeze(-1).float().to(x.device) + + return loudness diff --git a/voicebox/src/attacks/offline/perturbation/voicebox/mlp.py b/voicebox/src/attacks/offline/perturbation/voicebox/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..f5092090ececc342c3badbae3dbcd9ba55e01c08 --- /dev/null +++ b/voicebox/src/attacks/offline/perturbation/voicebox/mlp.py @@ -0,0 +1,33 @@ +import torch +import torch.nn as nn + +################################################################################ +# Time-distributed MLP module +################################################################################ + + +class MLP(nn.Module): + """Time-distributed MLP network""" + def __init__(self, + in_channels: int, + hidden_size: int = 512, + depth: int = 2, + activation: 
nn.Module = nn.LeakyReLU() + ): + + super().__init__() + + channels = [in_channels] + depth * [hidden_size] + mlp = [] + for i in range(depth): + mlp.append(nn.Linear(channels[i], channels[i + 1])) + mlp.append(nn.LayerNorm(channels[i + 1])) + + # omit nonlinearity after final layer + if i < depth - 1: + mlp.append(activation) + + self.mlp = nn.Sequential(*mlp) + + def forward(self, x: torch.Tensor): + return self.mlp(x) diff --git a/voicebox/src/attacks/offline/perturbation/voicebox/pitch.py b/voicebox/src/attacks/offline/perturbation/voicebox/pitch.py new file mode 100644 index 0000000000000000000000000000000000000000..1f3636ab10b640f2e8ddf066cc4e7879d15612ab --- /dev/null +++ b/voicebox/src/attacks/offline/perturbation/voicebox/pitch.py @@ -0,0 +1,129 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +import torchcrepe +import pyworld + +import numpy as np + +from src.data import DataProperties + +################################################################################ +# Compute frame-wise pitch and periodicity estimates +################################################################################ + + +class PitchEncoder(nn.Module): + + def __init__(self, + algorithm: str = 'dio', + return_periodicity: bool = True, + hop_length: int = 128, + ): + + super().__init__() + + self.algorithm = algorithm + self.return_periodicity = return_periodicity + self.hop_length = hop_length + + @torch.no_grad() + def forward(self, x: torch.Tensor): + + # ensure the input contains at least a single frame of audio + assert x.shape[-1] >= self.hop_length + + # avoid modifying input audio + n_batch, *channel_dims, signal_len = x.shape + + # add channel dimension if necessary + if len(channel_dims) == 0: + x = x.unsqueeze(1) + + x = x.mean(dim=1) + + if self.algorithm == 'torchcrepe': + + if not self.return_periodicity: + pitch = torchcrepe.predict( + x, + DataProperties.get('sample_rate'), + hop_length=self.hop_length, + fmin=50, + fmax=550, + model='tiny', + batch_size=10_000, + return_periodicity=False, + device='cuda', + ) + + # (n_batch, n_frames, 1) + return pitch.unsqueeze(-1) + else: + pitch, periodicity = torchcrepe.predict( + x, + DataProperties.get('sample_rate'), + hop_length=self.hop_length, + fmin=50, + fmax=550, + model='tiny', + batch_size=10_000, + return_periodicity=True, + device='cuda', + ) + + # (n_batch, n_frames, 1), (n_batch, n_frames, 1) + return pitch.unsqueeze(-1), periodicity.unsqueeze(-1) + + elif self.algorithm == 'dio': + + pitch_out, periodicity_out, device = [], [], x.device + hop_ms = 1000*self.hop_length/DataProperties.get('sample_rate') + x_np = x.clone().double().cpu().numpy() + + for i in range(n_batch): + pitch, timeaxis = pyworld.dio( + x_np[i], + fs=DataProperties.get('sample_rate'), + f0_floor=50, + f0_ceil=550, + frame_period=hop_ms, + speed=4) # downsampling factor, for speedup + pitch = pyworld.stonemask( + x_np[i], + pitch, + timeaxis, + DataProperties.get('sample_rate')) + + pitch_out.append(pitch) + + if self.return_periodicity: + unvoiced = pyworld.d4c( + x_np[i], + pitch, + timeaxis, + DataProperties.get('sample_rate'), + ).mean(axis=1) + + periodicity_out.append(unvoiced) + + pitch_out = torch.as_tensor( + np.stack(pitch_out, axis=0), + dtype=torch.float32, + device=device).unsqueeze(-1) + + if not self.return_periodicity: + # (n_batch, n_frames, 1) + return pitch_out + else: + periodicity_out = torch.as_tensor( + np.stack(periodicity_out, axis=0), + dtype=torch.float32, + device=device).unsqueeze(-1) # remove unsqueeze 
if not averaging! + + # (n_batch, n_frames), (n_batch, n_frames, 1) + return pitch_out, periodicity_out + + else: + raise ValueError(f'Invalid algorithm {self.algorithm}') diff --git a/voicebox/src/attacks/offline/perturbation/voicebox/projection.py b/voicebox/src/attacks/offline/perturbation/voicebox/projection.py new file mode 100644 index 0000000000000000000000000000000000000000..782397e622960e052684971f75b5fc98a8985a46 --- /dev/null +++ b/voicebox/src/attacks/offline/perturbation/voicebox/projection.py @@ -0,0 +1,215 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +from typing import Union + +################################################################################ +# Apply projection to regularize controls based on past context +################################################################################ + + +class CausalControlProjection(nn.Module): + """ + Constrain each frame's perturbation 'budget' (projection bound) based on + perturbation magnitudes in previous frames. This may help to encourage + sparser high-magnitude perturbations. + + (n_batch, n_frames, n_controls) --> (n_batch, n_frames, n_controls) + """ + def __init__(self, + eps: float, + n_controls: int, + unity: float, + projection_norm: Union[str, float, int] = 2, + method: str = 'exponential', + decay: float = 2, + context: int = 10, *args, **kwargs): + """ + Parameters + ---------- + eps (float): perturbation bound + + n_controls (int): dimension of control vector for each frame + + unity (float): 'neutral' value for controls; will be used to center + during projection + + projection_norm: norm for projection (2 or infinity) + + method (str): strategy for regularizing controls based on past + context: + + 'none': apply projection with bound `eps` to all + frames independently + + 'exponential': controls bound at any frame is `eps` + minus an exponentially-decaying weighted average of + the control magnitudes of frames within the + preceding context window + + 'max': controls bound at any frame is `eps` minus + the maximum perturbation magnitude over the + preceding context window + + decay (float): decay rate of exponential moving average in frames, + i.e. after `decay` steps a frame's contribution to the + average is scaled by a factor of 1/e + + context (int): number of frames considered; frames beyond the context + window are truncated/removed from consideration + """ + super().__init__() + + if eps is None: + projection_norm = None + + self.eps = eps + self.n_controls = n_controls + self.unity = unity + + assert projection_norm in [ + "none", None, 2, 2.0, "2", float("inf"), "inf" + ] + self.projection_norm = projection_norm + + assert method in ['none', None, 'exp', 'exponential', 'max', 'maximum'], \ + f'Invalid causal regularization method {method}' + + self.method = method + self.decay = decay + self.context = context + + # compute exponential decay factor alpha + a_ = math.exp(-1 / self.decay) + precision = 3 + self.alpha = 1.0 + + while self.alpha >= 1.0: + self.alpha = round(a_, precision) + precision += 1 + + # compute "kernel" for exponentially-weighted average over context + # window, and reshape to broadcast with "unfolded" inputs of shape + # (n_batch, n_frames, self.context, 1) + self.exp_kernel = torch.as_tensor( + [self.alpha**i for i in range(1, self.context + 1)][::-1] + ).reshape(1, 1, -1, 1).float() + + def _project(self, x: torch.Tensor, eps: Union[float, torch.Tensor]): + """ + Apply frame-wise projection with given bound. 
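An aside on the `PitchEncoder` above: the 'dio' branch wraps pyworld's DIO and StoneMask estimators, which require 1-D float64 NumPy audio. A minimal standalone sketch of that call pattern (sample rate and hop are illustrative):

import numpy as np
import pyworld

sr, hop_ms = 16000, 8.0                        # a 128-sample hop at 16 kHz
x = np.random.randn(16000).astype(np.float64)  # one second of audio

# coarse F0 track and frame times, restricted to a speech-plausible range
f0, t = pyworld.dio(x, fs=sr, f0_floor=50, f0_ceil=550, frame_period=hop_ms)

# refine the F0 estimates at the same frame times
f0 = pyworld.stonemask(x, f0, t, sr)           # (n_frames,)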
+ + Parameters + ---------- + x (Tensor): shape (n_batch, n_frames, n_controls) + + Returns + ------- + projected (Tensor): shape (n_batch, n_frames, n_controls) + """ + + if isinstance(eps, float): + eps = torch.tensor(eps, device=x.device) + + # L2 projection + if self.projection_norm in [2, '2', 2.0]: + + norm = torch.norm( + x, p=2, dim=-1, keepdim=True) + 1e-20 + factor = torch.min( + torch.tensor(1., device=x.device), + eps / norm + ) + x = x * factor + + # L-infinity projection + elif self.projection_norm in [float('inf'), 'inf']: + + x = torch.clamp( + x, + min=-eps.abs(), + max=eps.abs() + ) + + return x + + def forward(self, x: torch.Tensor): + """ + Regularize control signal via projection. + + Parameters + ---------- + x (Tensor): shape (n_batch, n_frames, n_controls) + + Returns + ------- + projected (Tensor): shape (n_batch, n_frames, n_controls) + """ + + # optionally, perform no projection + if self.projection_norm in ["none", None]: + return x + + n_batch, n_frames, n_controls = x.shape + assert n_controls == self.n_controls + + # center at unity + unity = torch.full(x.shape, self.unity, device=x.device) + x = x - unity + + # project controls for each frame independently + if self.method in ['none', None]: + budget = self.eps + + # project controls based on past context + else: + + # compute control magnitudes for each frame + magnitudes = x.abs().norm( + dim=-1, + keepdim=True, + p=2 if self.projection_norm in [2, 2.0, "2"] else float("inf") + ) + + # apply left-padding to generate one context window per input frame + padded = F.pad(magnitudes, (0, 0, self.context - 1, 0)) + + # get all (overlapping) context "windows" with stride 1 + windows = padded.unfold( + dimension=1, + size=self.context, + step=1 + ).permute(0, 1, 3, 2) # (n_batch, n_frames, self.context, 1) + + # determine frame-wise projection bounds from exponentially-decaying + # weighted average of control amplitudes within context window + if self.method in ['exp', 'exponential']: + + # compute weighted averages with kernel / sum + averages = (windows * self.exp_kernel).sum(dim=2) # (n_batch, n_frames, 1) + + # compute projection bound for each frame + budget = (self.eps - averages).clamp(min=0) # (n_batch, n_frames, 1) + + # determine frame-wise projection bounds from maximum control amplitudes + # within context window + elif self.method in ['max', 'maximum']: + + # take maximum over each window + maxima = torch.max(windows, dim=2)[0] # (n_batch, n_frames, 1) + + # compute projection bound for each frame + budget = (self.eps - maxima).clamp(min=0) # (n_batch, n_frames, 1) + + else: + raise ValueError(f'Invalid regularization method {self.method}') + + # apply frame-wise projection + x = self._project(x, budget) + + # undo centering + x = x + unity + + return x diff --git a/voicebox/src/attacks/offline/perturbation/voicebox/spec.py b/voicebox/src/attacks/offline/perturbation/voicebox/spec.py new file mode 100644 index 0000000000000000000000000000000000000000..09e153d197b356616b35358299ed90b399248b39 --- /dev/null +++ b/voicebox/src/attacks/offline/perturbation/voicebox/spec.py @@ -0,0 +1,180 @@ +from src.attacks.offline.perturbation.voicebox import mlp +import torch +import torch.nn as nn +import torch.nn.functional as F + +import numpy as np + +from torchaudio.transforms import Spectrogram, MelSpectrogram, MFCC + +from src.attacks.offline.perturbation.voicebox.expnorm import ExponentialUnitNorm +from src.attacks.offline.perturbation.voicebox.batchnorm import BatchNorm +from 
src.attacks.offline.perturbation.voicebox.mlp import MLP +from src.data import DataProperties + +################################################################################ +# Convolutional spectrogram encoder with optional lookahead +################################################################################ + + +class CausalPadding(nn.Module): + """Perform 'causal' padding at end of signal along final tensor dimension""" + + def __init__(self, pad: int = 0): + super().__init__() + self.pad = pad + + def forward(self, x: torch.Tensor): + return F.pad(x, (0, self.pad)) + + +class SpectrogramEncoder(nn.Module): + """Spectrogram encoder with optional lookahead""" + def __init__(self, + win_length: int = 512, + win_type: str = 'hann', + spec_type: str = 'linear', + lookahead: int = 5, + hidden_size: int = 512, + n_mels: int = 64, + mlp_depth: int = 2, + normalize: str = None + ): + super().__init__() + + # check validity of attributes + assert normalize in [None, 'none', 'instance', 'exponential'] + if win_type not in ['rectangular', 'triangular', 'hann']: + raise ValueError(f'Invalid window type {win_type}') + + # store attributes + self.win_length = win_length + self.win_type = win_type + self.lookahead = lookahead + self.hidden_size = hidden_size + self.n_mels = n_mels + self.spec_type = spec_type + self.mlp_depth = mlp_depth + self.normalize = normalize + + # determine hop length from window function + if self.win_type == 'rectangular': # non-overlapping frames + self.hop_length = self.win_length + else: + self.hop_length = self.win_length // 2 + + # determine spectrogram normalization method + n_freq = n_mels if spec_type in ['mel', 'mfcc'] else win_length // 2 + 1 + + if normalize in [None, 'none']: + self.norm = nn.Identity() + elif normalize == 'instance': + self.norm = nn.InstanceNorm1d( + num_features=n_freq, + track_running_stats=True + ) + elif normalize == 'exponential': + self.norm = ExponentialUnitNorm( + decay=1.0, + hop_size=self.hop_length, + n_freq=n_freq + ) + + # compute spectral representation + spec_kwargs = { + "n_fft": self.win_length, + "win_length": self.win_length, + "hop_length": self.hop_length, + "window_fn": self._get_win_func(self.win_type), + } + mel_kwargs = {**spec_kwargs, "n_mels": self.n_mels} + + if spec_type == 'linear': + self.spec = Spectrogram( + **spec_kwargs + ) + elif spec_type == 'mel': + self.spec = MelSpectrogram( + sample_rate=DataProperties.get("sample_rate"), + **mel_kwargs + ) + elif spec_type == 'mfcc': + self.spec = MFCC( + sample_rate=DataProperties.get("sample_rate"), + n_mfcc=self.n_mels, + log_mels=True, + melkwargs=mel_kwargs + ) + + # GLU - learn which channels of input to pass through most strongly + self.glu = nn.Sequential( + nn.Conv1d( + in_channels=n_freq, + out_channels=self.hidden_size * 2, + kernel_size=1, + stride=1), + nn.GLU(dim=1) + ) + + # Conv1D layers + conv = [] + for i in range(lookahead): + conv.extend([ + CausalPadding(1), + nn.Conv1d( + in_channels=self.hidden_size, + out_channels=self.hidden_size, + kernel_size=2, + stride=1 + ), + BatchNorm(num_features=self.hidden_size, feature_dim=1) if i < lookahead - 1 else nn.Identity(), + nn.ReLU() + ]) + self.conv = nn.Sequential(*conv) + + # pre-bottleneck MLP + self.mlp = MLP( + in_channels=self.hidden_size, + hidden_size=self.hidden_size, + depth=mlp_depth + ) + + @staticmethod + def _get_win_func(win_type: str): + if win_type == 'rectangular': + return lambda m: torch.ones(m) + elif win_type == 'hann': + return lambda m: torch.hann_window(m) + elif win_type 
== 'triangular': + return lambda m: torch.as_tensor(np.bartlett(m)).float() + + def forward(self, x: torch.Tensor): + + # ensure the input contains at least a single frame of audio + assert x.shape[-1] >= self.win_length + + # require batch, channel dimensions + assert x.ndim >= 2 + if x.ndim == 2: + x = x.unsqueeze(1) + + # convert to mono audio + x = x.mean(dim=1) + + # compute spectrogram + spec = self.spec(x) + 1e-6 # (n_batch, n_freq, n_frames) + + if self.spec_type in ['linear', 'mel']: + spec = 10 * torch.log10(spec + 1e-8) # (n_batch, n_freq, n_frames) + + # normalize spectrogram + spec = self.norm(spec) # (n_batch, n_freq, n_frames) + + # actual encoder network + encoded = self.glu(spec) # (n_batch, hidden_size, n_frames) + encoded = self.conv(encoded) # (n_batch, hidden_size, n_frames) + encoded = self.mlp( + encoded.permute(0, 2, 1) + ) # (n_batch, n_frames, hidden_size) + + return encoded diff --git a/voicebox/src/attacks/offline/perturbation/voicebox/voicebox.py b/voicebox/src/attacks/offline/perturbation/voicebox/voicebox.py new file mode 100644 index 0000000000000000000000000000000000000000..a30b3d77ae3ec068147b59a57895cf03ba915d12 --- /dev/null +++ b/voicebox/src/attacks/offline/perturbation/voicebox/voicebox.py @@ -0,0 +1,660 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +import math +import warnings + +from pathlib import Path + +import numpy as np + +from typing import Union, Dict + +from torchaudio.transforms import MelScale + +from src.constants import PPG_PRETRAINED_PATH +from src.data.dataproperties import DataProperties +from src.attacks.offline.perturbation import Perturbation +from src.models.phoneme import PPGEncoder +from src.attacks.offline.perturbation.voicebox.pitch import PitchEncoder +from src.attacks.offline.perturbation.voicebox.loudness import LoudnessEncoder +from src.attacks.offline.perturbation.voicebox.spec import SpectrogramEncoder +from src.attacks.offline.perturbation.voicebox.bottleneck import ( + RNNBottleneck, CausalTransformer +) +from src.attacks.offline.perturbation.voicebox.lookahead import Lookahead +from src.attacks.offline.perturbation.voicebox.mlp import MLP +from src.attacks.offline.perturbation.voicebox.film import FiLM +from src.attacks.offline.perturbation.voicebox.filter import FilterLayer +from src.attacks.offline.perturbation.voicebox.batchnorm import BatchNorm +from src.attacks.offline.perturbation.voicebox.projection import ( + CausalControlProjection +) + +################################################################################ +# VoiceBox model for applying adversarial FIR filtering in real-time +################################################################################ + + +class VoiceBox(Perturbation): + + def __init__(self, + + # encoder topology + use_loudness_encoder: bool = True, + use_pitch_encoder: bool = True, + use_phoneme_encoder: bool = True, + use_spec_encoder: bool = True, + + # SpectrogramEncoder parameters + spec_encoder_type: str = 'mel', + spec_encoder_n_mels: int = 64, + spec_encoder_mlp_depth: int = 2, + spec_encoder_hidden_size: int = 512, #Andy changed from 512 to get rid of size mismatch error message + spec_encoder_lookahead_frames: int = 5, + spec_encoder_normalize: str = 'none', + + # AC-VC encoder parameters + ppg_encoder_depth: int = 2, + ppg_encoder_hidden_size: int = 256, + ppg_encoder_path: Union[str, Path] = PPG_PRETRAINED_PATH, + + # bottleneck layer parameters + bottleneck_type: str = 'lstm', # lstm + bottleneck_skip: bool = True, + bottleneck_depth: 
int = 2, # 8 + bottleneck_hidden_size: int = 512, + bottleneck_feedforward_size: int = 512, #Andy changed from 2048 to get rid of size mismatch error message + + # optionally, concatenate conditioning information before bottleneck + conditioning_dim: int = 0, + + # post-bottleneck lookahead convolution + bottleneck_lookahead_frames: int = 0, + + # filter control constraint parameters + mel_scale_controls: bool = False, + neutral_below_hz: float = None, + neutral_above_hz: float = None, + control_scaling_fn: str = 'sigmoid', + control_eps: float = None, + projection_norm: Union[str, int, float] = None, + projection_context: int = 10, + projection_method: str = None, + projection_decay: float = 2.0, + + # FilterLayer parameters + n_bands: int = 128, + win_length: int = 256, + win_type: str = 'hann', + normalize_ir: Union[str, int] = None, + + # audio normalization + normalize_audio: str = 'peak', + ): + + super().__init__() + + # round window length to next power of 2 + next_pow_2 = 2**(win_length - 1).bit_length() + if win_length != next_pow_2: + warnings.warn(f'Rounding block size {win_length} to nearest power' + f' of 2 ({next_pow_2})') + + # store attributes + self.win_length = next_pow_2 + self.win_type = win_type + self.normalize_audio = normalize_audio + self.n_bands = n_bands + self.mel_scale_controls = mel_scale_controls + self.bottleneck_type = bottleneck_type + self.bottleneck_skip = bottleneck_skip + self.bottleneck_depth = bottleneck_depth + self.bottleneck_hidden_size = bottleneck_hidden_size + self.bottleneck_feedforward_size = bottleneck_feedforward_size + self.bottleneck_lookahead_frames = bottleneck_lookahead_frames + self.spec_encoder_mlp_depth = spec_encoder_mlp_depth + self.ppg_encoder_hidden_size = ppg_encoder_hidden_size + self.ppg_encoder_path = ppg_encoder_path + + # ensure at least one encoder network is present + assert any([ + use_loudness_encoder, + use_pitch_encoder, + use_phoneme_encoder, + use_spec_encoder]), \ + f'Must specify at least one encoder network' + + self.use_loudness_encoder = use_loudness_encoder + self.use_pitch_encoder = use_pitch_encoder + self.use_phoneme_encoder = use_phoneme_encoder + self.use_spec_encoder = use_spec_encoder + + ######################################################################## + # AC-VC ENCODER + ######################################################################## + + if use_phoneme_encoder: + # AC-VC PPG encoder network + self.ppg_encoder = PPGEncoder( + win_length=win_length, + hop_length=win_length//2, + win_func=torch.hann_window, + n_mels=32, + n_mfcc=19, + lstm_depth=ppg_encoder_depth, + hidden_size=ppg_encoder_hidden_size + ) + self.ppg_encoder.load_state_dict( + torch.load(ppg_encoder_path, map_location=torch.device('cpu')) + ) + else: + self.ppg_encoder = nn.Identity() + + # AC-VC pitch encoder network + self.pitch_encoder = PitchEncoder( + hop_length=win_length//2) if use_pitch_encoder else nn.Identity() + + # A-weighted loudness encoder + self.loudness_encoder = LoudnessEncoder( + hop_length=win_length//2) if use_loudness_encoder else nn.Identity() + + # freeze gradients + for p in self.ppg_encoder.parameters(): + p.requires_grad = False + for p in self.pitch_encoder.parameters(): + p.requires_grad = False + for p in self.loudness_encoder.parameters(): + p.requires_grad = False + + # merge AC-VC encoder & spectrogram encoder output + n_encoder_feats = int(use_phoneme_encoder) * ppg_encoder_hidden_size + n_encoder_feats += int(use_loudness_encoder) * 1 + n_encoder_feats += int(use_pitch_encoder) * 2 + 
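# For concreteness: with the defaults above (all four encoders enabled,
# ppg_encoder_hidden_size=256, spec_encoder_hidden_size=512), the running
# total works out to 256 (PPG) + 1 (loudness) + 2 (pitch, periodicity)
# + 512 (spectrogram) = 771 features, which encoder_proj below projects
# down to bottleneck_hidden_size.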
n_encoder_feats += int(use_spec_encoder) * spec_encoder_hidden_size + self.encoder_proj = nn.Sequential( + BatchNorm(num_features=n_encoder_feats), + nn.Linear(n_encoder_feats, bottleneck_hidden_size), + nn.ReLU() + ) + + ######################################################################## + # "LOOKAHEAD" SPECTROGRAM ENCODER + ######################################################################## + + # spectrogram encoder, with optional lookahead + self.spec_encoder = SpectrogramEncoder( + spec_type=spec_encoder_type, + n_mels=spec_encoder_n_mels, + win_length=win_length, + win_type=win_type, + lookahead=spec_encoder_lookahead_frames, + hidden_size=spec_encoder_hidden_size, + mlp_depth=spec_encoder_mlp_depth, + normalize=spec_encoder_normalize + ) if use_spec_encoder else nn.Identity() + + ######################################################################## + # TARGET CONDITIONING + ######################################################################## + + if conditioning_dim > 0: + + self.conditioning_mlp = MLP( + in_channels=conditioning_dim, + hidden_size=conditioning_dim, + depth=2 + ) + self.conditioning_encoder = FiLM( + cond_dim=conditioning_dim, + num_features=bottleneck_hidden_size, + batch_norm=True + ) + + self.conditioning_dim = conditioning_dim + + ######################################################################## + # LATENT BOTTLENECK + ######################################################################## + + if bottleneck_type in ['lstm', 'rnn']: + self.bottleneck = RNNBottleneck( + input_size=bottleneck_hidden_size, + hidden_size=bottleneck_feedforward_size, + proj_size=bottleneck_hidden_size, + num_layers=bottleneck_depth, + downsample_index=1, + downsample_factor=1, + dropout_prob=0.0 + ) + elif bottleneck_type in ['attention', 'transformer']: + self.bottleneck = CausalTransformer( + hidden_size=bottleneck_hidden_size, + dim_feedforward=bottleneck_feedforward_size, + depth=bottleneck_depth, + heads=8, + dropout_prob=0.0 + ) + else: + self.bottleneck = nn.Identity() + + # post-bottleneck projection with optional lookahead + n_bottleneck_feats = 2 * bottleneck_hidden_size if bottleneck_skip \ + else bottleneck_hidden_size + self.bottleneck_proj = nn.Sequential( + Lookahead( + n_features=n_bottleneck_feats, + lookahead_frames=bottleneck_lookahead_frames + ) if bottleneck_lookahead_frames else nn.Identity(), + nn.LeakyReLU() if bottleneck_lookahead_frames else nn.Identity(), + BatchNorm(n_bottleneck_feats), + nn.Linear(n_bottleneck_feats, n_bands) + ) + + ######################################################################## + # DECODER + ######################################################################## + + assert control_scaling_fn.lower() in ['sigmoid', 'elu', 'relu', 'log'], \ + f'Invalid filter control scaling function {control_scaling_fn}' + self.control_scaling_fn = control_scaling_fn + + # (optionally-causal) filter-control projection module + self.projector = CausalControlProjection( + eps=control_eps, + n_controls=n_bands, + unity=0.0 if control_scaling_fn == 'log' else 1.0, + projection_norm=projection_norm, + method=projection_method, + decay=projection_decay, + context=projection_context + ) + + # constrain filter controls + self.cutoff_high = self.hz_to_band( + neutral_above_hz, + n_bands, + DataProperties.get('sample_rate') + ) if neutral_above_hz else n_bands + self.cutoff_low = self.hz_to_band( + neutral_below_hz, + n_bands, + DataProperties.get('sample_rate') + ) if neutral_below_hz else 0 + + # optionally, use mel-scaled filter 
controls + self.register_buffer("inv_mel_fb", MelScale( + n_mels=n_bands, + sample_rate=DataProperties.get('sample_rate'), + n_stft=n_bands + ).fb.transpose(0, 1).pinverse()) + self.inv_mel_scale = lambda x: torch.matmul( + self.inv_mel_fb, + x.permute(0, 2, 1) + ).clamp(min=0, max=None).permute(0, 2, 1) + + # filter + self.filter = FilterLayer( + win_length=win_length, + win_type=win_type, + n_bands=n_bands, + normalize_ir=normalize_ir + ) + + # references for visualization + self.ref_wav = torch.empty(0) + self.ref_controls = torch.empty(0) + + def set_reference(self, x: torch.Tensor): + """Store reference audio for visualization/logging""" + + # ensure the input contains at least a single frame of audio + assert x.shape[-1] >= self.win_length + + # avoid modifying input audio + x = x.clone().detach() + n_batch, *channel_dims, signal_len = x.shape + + # add channel dimension if necessary + if len(channel_dims) == 0: + x = x.unsqueeze(1) + + # store reference audio + self.ref_wav = x[0] + + # store reference controls + with torch.no_grad(): + self.ref_controls = self.get_controls(x)[0] + + def _project_valid_top_level(self): + pass + + def _visualize_top_level(self) -> Dict[str, torch.Tensor]: + """Visualize controls and reference audio""" + + name = self.__class__.__name__ + + visualizations = {} + + # compute controls for stored reference audio + with torch.no_grad(): + self.ref_controls = self.get_controls( + self.ref_wav + ) + + def band_to_hz(band: int): + nyquist = DataProperties.get('sample_rate') // 2 + hz = int(band / self.n_bands * nyquist) + return hz + + # plot controls + import io + import matplotlib.pyplot as plt + from PIL import Image + from torchvision.transforms import ToTensor + + _, t, f = self.ref_controls.shape + + # scale controls + if self.control_scaling_fn.lower() == 'sigmoid': + controls = 2 * torch.sigmoid( + self.ref_controls + )**(math.log(10)) + 1e-7 + elif self.control_scaling_fn.lower() == 'relu': + controls = F.relu(self.ref_controls) + elif self.control_scaling_fn.lower() == 'elu': + controls = F.elu(self.ref_controls) + 1 + elif self.control_scaling_fn.lower() == 'log': + controls = torch.tanh(self.ref_controls * 0.2) * 5 # [-5, 5] + controls = torch.exp(controls) + else: + raise ValueError(f'Invalid control scaling function ' + f'{self.control_scaling_fn}') + + # if controls are taken to be mel-scaled, linearly scale + if self.mel_scale_controls: + controls = self.inv_mel_scale(controls) + + # perform projection (PGD step) on filter controls + controls = self._project_valid(controls) + + # scale to [-4, 4] for plotting, with 0 at center (logarithmic) + # clips near zero, and + controls = torch.clamp(torch.log2(controls + 1e-8), min=-4, max=4) + controls = controls.clone().detach().squeeze().cpu().numpy() + + fig, axs = plt.subplots( + nrows=2, + ncols=1, + figsize=(30, 30), + gridspec_kw={'height_ratios': [4, 1]}) + + # draw frame boundaries + for frame in range(t): + axs[0].vlines( + frame, + ymin=-10, + ymax=f * 10, + color='k', + alpha=0.0, + linewidth=.5) + + for band in range(f): + axs[0].plot([band * 10]*t, 'k', alpha=0.6, linewidth=.5) + axs[0].plot(controls[:, band] + band * 10, alpha=0.9, linewidth=2) + axs[0].set_ylabel('Filter Band', fontsize=30) + + axs[0].set_yticks( + [10 * i for i in range(f)], + [f'{band_to_hz(i)}Hz' for i in range(f)]) + + axs[1].plot(self.ref_wav.cpu().numpy().flatten()[::25], color='k', linewidth=1) + axs[1].set_ylabel('Waveform Amplitude', fontsize=30) + axs[1].set_ylim([-1, 1]) + plt.tight_layout() + + # save plot 
to buffer + buf = io.BytesIO() + plt.savefig(buf, format="png") + plt.close(fig) + buf.seek(0) + img = Image.open(buf) + + # return plot image as tensor + output = ToTensor()(np.array(img)) + + visualizations = { + **visualizations, + f'{name}-parameters': output + } + + return visualizations + + @staticmethod + def hz_to_band(f: float, n_bands: int, sr: int): + nyquist = sr // 2 + band = max(0, + min( + int(f * n_bands / nyquist), n_bands + )) + return band + + def _project_valid(self, controls: torch.Tensor): + + # account for log-scaling + if self.control_scaling_fn.lower() == 'log': + controls = torch.log(controls + 1e-8) + + controls = self.projector(controls) + + # account for log-scaling + if self.control_scaling_fn == 'log': + controls = torch.exp(controls) + + # if specified, keep filter neutral in given ranges + if self.cutoff_low: + # avoid in-place operations in forward pass + b, t, c = controls.shape + controls = torch.cat( + [ + torch.ones((b, t, self.cutoff_low), device=controls.device), + controls[..., self.cutoff_low:], + ], dim=-1) + assert controls.shape == (b, t, c) + + if self.cutoff_high: + # avoid in-place operations in forward pass + b, t, c = controls.shape + controls = torch.cat( + [ + controls[..., :self.cutoff_high], + torch.ones((b, t, c - self.cutoff_high), device=controls.device), + ], dim=-1) + assert controls.shape == (b, t, c) + + return controls + + def get_controls(self, + x: torch.Tensor, + pitch: torch.Tensor = None, + periodicity: torch.Tensor = None, + loudness: torch.Tensor = None, + y: torch.Tensor = None, + *args, **kwargs + ): + """Map audio inputs to frame-wise filter controls""" + + # ensure the input contains at least a single frame of audio + assert x.shape[-1] >= self.win_length + + ######################################################################## + # AC-VC ENCODER + ######################################################################## + + with torch.no_grad(): + + features = [] + + # compute features if necessary + if self.use_pitch_encoder: + if pitch is None or periodicity is None: + pitch, periodicity = self.pitch_encoder(x) + features.extend([pitch, periodicity]) + + if self.use_loudness_encoder: + if loudness is None: + loudness = self.loudness_encoder(x) + features.append(loudness) + + # compute phonetic posteriorgrams (PPGs) + if self.use_phoneme_encoder: + ppg = self.ppg_encoder(x) + features.append(ppg) + + ######################################################################## + # SPECTROGRAM ENCODER + ######################################################################## + + if self.use_spec_encoder: + spec = self.spec_encoder(x) # (n_batch, n_frames, hidden_size) + features.append(spec) + + ######################################################################## + # MERGE ENCODINGS + ######################################################################## + + encoded = self.encoder_proj( + torch.cat(features, dim=-1)) # (n_batch, n_frames, hidden_size) + + ######################################################################## + # TARGET CONDITIONING + ######################################################################## + + if self.conditioning_dim: + + n_frames = encoded.shape[1] + + if y is not None: + assert y.shape[-1] == self.conditioning_dim + assert y.ndim == 3 + + # average over all segments if present + y = y.mean(dim=1, keepdim=True) + + # duplicate over all frames + y = y.repeat(1, n_frames, 1) + + else: + + y = torch.zeros( + (x.shape[0], n_frames, self.conditioning_dim), + device=x.device) + + # 
apply learned affine transformations to feature dimension + encoded = self.conditioning_encoder( + x=encoded, + cond=self.conditioning_mlp(y) + ) # (n_batch, n_frames, hidden_size) + + ######################################################################## + # BOTTLENECK + ######################################################################## + + bottleneck_out = self.bottleneck(encoded) + + # apply skip connection with pre-bottleneck encoding + if self.bottleneck_skip: + bottleneck_out = torch.cat([ + bottleneck_out, + encoded + ], dim=-1) # (n_batch, n_frames, 2 * hidden_size) + + controls = self.bottleneck_proj(bottleneck_out) # (n_batch, n_frames, n_bands) + + return controls + + def forward(self, + x: torch.Tensor, + pitch: torch.Tensor = None, + periodicity: torch.Tensor = None, + loudness: torch.Tensor = None, + y: torch.Tensor = None, + *args, **kwargs): + + # ensure the input contains at least a single frame of audio + assert x.shape[-1] >= self.win_length + + # require batch, channel dimensions + assert x.ndim >= 2 + n_batch, *channel_dims, signal_len = x.shape + + if x.ndim == 2: + x = x.unsqueeze(1) + + # convert to mono audio + x = x.mean(dim=1, keepdim=True) + + # if features are provided, check dimensions + assert pitch is None or pitch.shape[0] == x.shape[0] + assert periodicity is None or periodicity.shape[0] == x.shape[0] + assert loudness is None or loudness.shape[0] == x.shape[0] + + # prepare to normalize output volume + peak = torch.max(torch.abs(x), -1)[0].reshape(n_batch) + + controls = self.get_controls(x, + pitch, + periodicity, + loudness, + y) # (n_batch, n_frames, n_bands) + + ######################################################################## + # CONTROL SCALING + ######################################################################## + + # scale stored controls to fixed range + if self.control_scaling_fn.lower() == 'sigmoid': + controls = 2 * torch.sigmoid( + controls + )**(math.log(10)) + 1e-7 + elif self.control_scaling_fn.lower() == 'relu': + controls = F.relu(controls) + elif self.control_scaling_fn.lower() == 'elu': + controls = F.elu(controls) + 1 + elif self.control_scaling_fn.lower() == 'log': + controls = torch.tanh(controls * 0.2) * 5 # [-5, 5] + controls = torch.exp(controls) + else: + raise ValueError(f'Invalid control scaling function ' + f'{self.control_scaling_fn}') + + # if controls are taken to be mel-scaled, linearly scale + if self.mel_scale_controls: + controls = self.inv_mel_scale(controls) + + # perform projection (PGD step) on filter controls + controls = self._project_valid(controls) + + ######################################################################## + # FILTERING + ######################################################################## + + # apply filter + x = self.filter(x, controls) + + # apply normalization to match input volume + if self.normalize_audio in [None, 'none']: + factor = 1.0 + elif self.normalize_audio == 'peak': + factor = peak / (torch.max(torch.abs(x), -1)[0].reshape(n_batch) + 1e-6) + factor = factor.reshape(n_batch, 1, 1) + else: + raise ValueError(f'Invalid audio normalization type {self.normalize_audio}') + + x = x * factor + + # match original dimensions + x = x[..., :signal_len].reshape(n_batch, *((1,) * len(channel_dims)), signal_len) + + return x diff --git a/voicebox/src/attacks/offline/perturbation/white_noise.py b/voicebox/src/attacks/offline/perturbation/white_noise.py new file mode 100644 index 0000000000000000000000000000000000000000..fa65ad0e422da19e5db0fc8d530ac31fa08746bc --- 
/dev/null +++ b/voicebox/src/attacks/offline/perturbation/white_noise.py @@ -0,0 +1,81 @@ +import torch +import torch.nn as nn + +import math +import random + +from typing import Union, Dict + +from src.attacks.offline.perturbation.perturbation import Perturbation +from src.data import DataProperties + +################################################################################ +# Apply additive perturbation to waveform audio +################################################################################ + + +class WhiteNoisePerturbation(Perturbation): + + def __init__(self, snr: float = 0.0): + + super().__init__() + self.snr = nn.Parameter(torch.as_tensor([snr], dtype=torch.float32)) + + def set_reference(self, x: torch.Tensor): + """ + Given reference input, initialize perturbation parameters accordingly + and match input device. + + :param x: reference audio, shape (n_batch, n_channels, signal_length) + """ + # move in place so the attribute remains an nn.Parameter + self.snr.data = self.snr.data.to(x.device) + + def set_snr(self, snr: float): + self.snr.data.fill_(snr) + + def forward(self, x: torch.Tensor, *args, **kwargs): + """ + Apply perturbation to inputs. + + :param x: input audio, shape (n_batch, n_channels, signal_length) + """ + + # do not overwrite incoming audio + x = x.clone().detach() + + # numerical stability + eps = 1e-8 + + # white noise + noise = torch.randn_like(x) + + # scale noise level to stored SNR + noise_db = 10 * torch.log10( + torch.mean(torch.square(noise), dim=-1, keepdim=True) + eps + ) + + signal_db = 10 * torch.log10( + torch.mean(torch.square(x), dim=-1, keepdim=True) + eps + ) + + scale = torch.sqrt( + torch.pow(10, (signal_db - noise_db - self.snr) / 10) + ) + + return (scale * noise) + x + + def _visualize_top_level(self) -> Dict[str, torch.Tensor]: + """ + Visualize top-level (non-recursive) perturbation parameters. + + :return: tag (string) / image (tensor) pairs, stored in a dictionary + """ + + visualizations = {} + return visualizations + + def _project_valid_top_level(self): + """ + Project top-level (non-recursive) parameters to valid range.
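A quick numeric check of the SNR scaling in `WhiteNoisePerturbation.forward` above: the noise gain is chosen so that signal power sits exactly `snr` dB above noise power. A standalone sketch (target SNR is illustrative):

import torch

x = torch.randn(1, 1, 16000)  # input audio
noise = torch.randn_like(x)
snr = 20.0                    # target SNR in dB

signal_db = 10 * torch.log10(x.pow(2).mean(dim=-1, keepdim=True))
noise_db = 10 * torch.log10(noise.pow(2).mean(dim=-1, keepdim=True))

# solve 'signal_db - (noise_db + gain_db) = snr' for the linear gain
scale = torch.sqrt(torch.pow(10.0, (signal_db - noise_db - snr) / 10))
noisy = x + scale * noise

measured = 10 * torch.log10(x.pow(2).mean() / (scale * noise).pow(2).mean())
assert abs(measured.item() - snr) < 1e-3  # achieved SNR matches the target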
+ """ + pass diff --git a/voicebox/src/attacks/offline/trainable.py b/voicebox/src/attacks/offline/trainable.py new file mode 100644 index 0000000000000000000000000000000000000000..3400a31f931c00d648bab8454025a617341467a9 --- /dev/null +++ b/voicebox/src/attacks/offline/trainable.py @@ -0,0 +1,691 @@ +import os +import warnings + +import torch +import torch.nn as nn + +from pathlib import Path +from typing import Tuple, Union + +from torch.utils.data import Dataset, DataLoader + +from src.attacks.offline.offline import OfflineAttack +from src.attacks.offline.orthogonal_selective import SelectiveOrthogonalPGDMixin +from src.attacks.offline.perturbation.perturbation import Perturbation +from src.pipelines import Pipeline +from src.loss.adversarial import AdversarialLoss +from src.loss.auxiliary import AuxiliaryLoss +from src.utils.writer import Writer + +################################################################################ +# Base class for trainable attacks +################################################################################ + + +class TrainableAttack(OfflineAttack, SelectiveOrthogonalPGDMixin): + + def __init__(self, + pipeline: Pipeline, + perturbation: torch.nn.Module, + adv_loss: AdversarialLoss, + aux_loss: AuxiliaryLoss = None, + adv_success_thresh: float = 0.0, + det_success_thresh: float = 0.0, + opt: str = 'adam', + lr: float = 1e-4, + pgd_variant: str = None, + pgd_norm: Union[str, int, float] = None, + scale_grad: Union[int, float, str] = None, + k: int = None, + epochs: int = 10, + max_iter: int = 1, + batch_size: int = 32, + rand_evals: int = 0, + eot_iter: int = 0, + checkpoint_name: str = None, + writer: Writer = None, + validate: bool = True, + **kwargs): + + super().__init__( + pipeline=pipeline, + adv_loss=adv_loss, + aux_loss=aux_loss, + batch_size=batch_size, + rand_evals=rand_evals, + writer=writer, + **kwargs + ) + + # underlying perturbation/model + self.perturbation = perturbation.to(self.pipeline.device) + + # optimizer + self.lr = lr + self.opt = opt + self.optimizer = None + self.epochs = epochs + self.max_iter = max_iter + self.eot_iter = eot_iter + + # PGD algorithm + self.pgd_variant = pgd_variant + self.pgd_norm = pgd_norm + self.scale_grad = scale_grad + self.k = k + self.adv_success_thresh = adv_success_thresh + self.det_success_thresh = det_success_thresh + + # determine whether to perform validation during training + self.validate = validate + + # checkpointing + self.checkpoint_name = checkpoint_name + + # track epoch count + self._epoch_id = 0 + + self._check_loss() + + def _tile_and_create_dataset(self, x: torch.Tensor, y: torch.Tensor): + """ + Given inputs and targets, create a dataset. If only a single target is + given, repeat to match length of inputs. 
+ """ + # if only a single target is given, repeat to length of dataset + y = y.unsqueeze(0) if y.ndim < 1 else y + + if y.shape[0] == 1: + y = y.repeat_interleave(dim=0, repeats=x.shape[0]) + + return self._create_dataset(x, y) + + def _get_optimizer(self): + """Configure optimizer for stored model/perturbation""" + + if self.opt == 'adam': + optimizer = torch.optim.Adam( + self.perturbation.parameters(), + lr=self.lr, + betas=(.99, .999), + eps=1e-7, + amsgrad=False + ) + elif self.opt == 'lbfgs': + optimizer = torch.optim.LBFGS( + self.perturbation.parameters(), + lr=self.lr, + line_search_fn='strong_wolfe' + ) + elif self.opt == 'sgd': + optimizer = torch.optim.SGD( + self.perturbation.parameters(), + lr=self.lr + ) + else: + raise ValueError(f'Invalid optimizer {self.opt}') + + return optimizer + + def _set_loss_reference(self, x: torch.Tensor): + """ + Pass reference audio to auxiliary loss to avoid re-computing expensive + intermediate representations + """ + if self.aux_loss is not None: + self.aux_loss.set_reference(x) + + def _compute_aux_loss(self, + x_adv: torch.Tensor, + x_ref: torch.Tensor = None): + """Compute auxiliary loss given perturbed input""" + return self.aux_loss(x_adv, x_ref) + + def _prepare_data(self, + x_train: torch.Tensor = None, + y_train: torch.Tensor = None, + data_train: Dataset = None, + x_val: torch.Tensor = None, + y_val: torch.Tensor = None, + data_val: Dataset = None, + ): + + # require training dataset + assert (x_train is not None and y_train is not None) \ + or data_train is not None, 'Must provide training data' + + # require validation dataset + assert (x_val is not None and y_val is not None) \ + or data_val is not None, 'Must provide validation data' + + # package tensors as datasets + if data_train is None: + data_train = self._tile_and_create_dataset(x_train, y_train) + if data_val is None: + data_val = self._tile_and_create_dataset(x_val, y_val) + + loader_train = DataLoader( + dataset=data_train, + batch_size=self.batch_size, + shuffle=True, + drop_last=False, + pin_memory=self.pin_memory, + num_workers=self.num_workers + ) + + loader_val = DataLoader( + dataset=data_val, + batch_size=self.batch_size, + shuffle=False, + drop_last=False, + pin_memory=self.pin_memory, + num_workers=self.num_workers + ) + + return loader_train, loader_val + + def _train_batch(self, + x: torch.Tensor, + y: torch.Tensor, + *args, + **kwargs): + """Optimize stored model/perturbation over a batch of inputs""" + + # require batch dimension + assert x.ndim >= 2 + n_batch = x.shape[0] + + x = x.detach() + + # set reference for auxiliary loss to avoid re-computing + self._set_loss_reference(x) + + # randomly sample simulation parameters + if self.eot_iter and not self._iter_id % self.eot_iter: + self.pipeline.sample_params() + + def closure(): + + # placeholder for final model/perturbation gradients + model_gradients = \ + self._retrieve_parameter_gradients(self.perturbation) + grad_total = torch.zeros_like(model_gradients) + + # apply adversarial perturbation to batch and obtain predictions + perturbed = self.perturbation(x, *args, **kwargs) + outputs = self.pipeline(perturbed) + + # reset parameter gradients, using `None` for performance boost + self.perturbation.zero_grad(set_to_none=True) + + # compute flattened parameter gradients w.r.t. 
adversarial loss + adv_scores = self.adv_loss(outputs, y) + adv_loss = torch.mean(adv_scores) + adv_loss.backward(retain_graph=True) + adv_loss_grad = self._retrieve_parameter_gradients( + self.perturbation + ).detach() + + # reset parameter gradients, using `None` for performance boost + self.perturbation.zero_grad(set_to_none=True) + + # compute flattened parameter gradients w.r.t. detector loss + detector_flags, detector_scores = self.pipeline.detect(perturbed) + detector_loss = torch.mean(detector_scores) + detector_loss.backward(retain_graph=True) + detector_loss_grad = self._retrieve_parameter_gradients( + self.perturbation + ).detach() + + # reset parameter gradients, using `None` for performance boost + self.perturbation.zero_grad(set_to_none=True) + + # compute flattened parameter gradients w.r.t. auxiliary loss + if self.aux_loss is not None: + aux_scores = self._compute_aux_loss(perturbed) + aux_loss = torch.mean(aux_scores) + aux_loss.backward() + aux_loss_grad = self._retrieve_parameter_gradients( + self.perturbation + ).detach() + else: # if no auxiliary loss, do not penalize + aux_scores = torch.zeros(n_batch).to(x.device) + aux_loss = torch.mean(aux_scores) + aux_loss_grad = torch.zeros_like(adv_loss_grad).detach() + + # classifier evasion indicator, reshape for broadcasting + adv_success = (adv_loss <= self.adv_success_thresh) * 1.0 + + # detector evasion indicator, reshape for broadcasting + detector_success = (detector_loss <= self.det_success_thresh) * 1.0 + + # perform standard, orthogonal, or selective gradient + # accumulation + if self.pgd_variant is None or self.pgd_variant == 'none': + + # for standard PGD, sum loss gradients + grad_total += adv_loss_grad + \ + detector_loss_grad + \ + aux_loss_grad + + elif self.pgd_variant == 'orthogonal': + + # for orthogonal PGD, orthogonalize loss gradients and + # select one for update; optionally, orthogonalize only + # every kth step + if self.k and self._iter_id % self.k: + adv_loss_grad_proj = adv_loss_grad + detector_loss_grad_proj = detector_loss_grad + aux_loss_grad_proj = aux_loss_grad + else: + adv_loss_grad_proj = self._component_orthogonal( + adv_loss_grad, + detector_loss_grad, + aux_loss_grad + ) + detector_loss_grad_proj = self._component_orthogonal( + detector_loss_grad, + adv_loss_grad, + aux_loss_grad + ) + aux_loss_grad_proj = self._component_orthogonal( + aux_loss_grad, + detector_loss_grad, + adv_loss_grad + ) + + # update 'along' a single loss gradient per iteration + grad_total += adv_loss_grad_proj * (1 - adv_success) + grad_total += detector_loss_grad_proj * adv_success \ + * (1 - detector_success) + grad_total += aux_loss_grad_proj * adv_success * \ + detector_success + + elif self.pgd_variant == 'selective': + + # only consider a single loss per iteration, without + # ensuring orthogonality to remaining loss gradients + grad_total += adv_loss_grad * (1 - adv_success) + grad_total += detector_loss_grad * adv_success \ + * (1 - detector_success) + grad_total += aux_loss_grad * adv_success * detector_success + + else: + raise ValueError(f'Invalid attack mode {self.pgd_variant}') + + # regularize gradients via p-norm projection + if self.scale_grad in [2, float(2), "2"]: + grad_norms = torch.norm( + grad_total, p=2, dim=-1 + ) + 1e-20 + grad_total = grad_total / grad_norms + elif self.scale_grad in [float("inf"), "inf"]: + grad_total = torch.sign(grad_total) + elif self.scale_grad in ['none', None]: + pass + else: + raise ValueError(f'Invalid gradient regularization norm ' + f'{self.scale_grad}' + ) + + 
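The orthogonal variant above relies on projecting each loss gradient onto the complement of the others before selecting one for the update. A standalone sketch of that projection (Gram-Schmidt over flattened gradients; the repo's `_component_orthogonal` may differ in detail):

```python
import torch

def orthogonal_component(g, *others, eps=1e-20):
    # orthonormalize `others` first, then remove their span from g
    basis = []
    for h in others:
        for u in basis:
            h = h - torch.dot(h, u) * u
        basis.append(h / (h.norm() + eps))
    for u in basis:
        g = g - torch.dot(g, u) * u
    return g

adv_g, det_g, aux_g = torch.randn(3, 100).unbind(0)
adv_proj = orthogonal_component(adv_g, det_g, aux_g)
print(torch.dot(adv_proj, det_g).abs() < 1e-4,
      torch.dot(adv_proj, aux_g).abs() < 1e-4)  # True, True
```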
# set final parameter gradients + self._set_parameter_gradients( + grad_total.flatten(), + self.perturbation + ) + + # log results + if self.writer is not None: + self._log_step( + x=x, + x_adv=perturbed, + y=y, + adv_loss=adv_loss, + det_loss=detector_loss, + aux_loss=aux_loss, + detection_rate=torch.mean(1.0 * detector_flags) + ) + + # return placeholder loss + return adv_loss + detector_loss + aux_loss + + # optimizer step, using stored gradients + self.optimizer.step(closure) + + # project perturbation to feasible region + if hasattr(self.perturbation, "project_valid"): + try: + self.perturbation.project_valid() + except AttributeError: + pass + + # update total iteration count + self._iter_id += 1 + + def train(self, + x_train: torch.Tensor = None, + y_train: torch.Tensor = None, + data_train: Dataset = None, + x_val: torch.Tensor = None, + y_val: torch.Tensor = None, + data_val: Dataset = None, + *args, + **kwargs + ): + """ + Optimize trainable attack parameters over training data. + + Parameters + ---------- + + Returns + ------- + """ + + loader_train, loader_val = self._prepare_data( + x_train, + y_train, + data_train, + x_val, + y_val, + data_val) + + # match devices and set reference if necessary + ref_batch = next(iter(loader_train)) + + if isinstance(ref_batch, tuple): + x_ref = ref_batch[0] + warnings.warn('Warning: provided dataset yields batches in tuple ' + 'format; the first two tensors of each batch will be ' + 'interpreted as inputs and targets, respectively, ' + 'and any remaining tensors will be ignored. To pass ' + 'additional named tensor arguments, use a dictionary ' + 'batch format with keys `x` and `y` for inputs and ' + 'targets, respectively.') + elif isinstance(ref_batch, dict): + x_ref = ref_batch['x'] + else: + x_ref = ref_batch + + if hasattr(self.perturbation, "set_reference"): + try: + self.perturbation.set_reference( + x_ref.to(self.pipeline.device)) + except AttributeError: + pass + + # configure optimizer + self.optimizer = self._get_optimizer() + + # reset cumulative iteration count + self._iter_id = 0 + + # optimize perturbation over given number of epochs + for epoch_id in range(self.epochs): + + self._batch_id = 0 + self._epoch_id = epoch_id + + self.perturbation.train() + for batch_id, batch in enumerate(loader_train): + + self._batch_id = batch_id + + # allow for different dataset formats + if isinstance(batch, tuple): + batch = { + 'x': batch[0], + 'y': batch[1] + } + + # match devices + for k in batch.keys(): + batch[k] = batch[k].to(self.pipeline.device) + + self._train_batch(**batch) + + # perform validation once per epoch + if self.validate: + adv_scores = [] + aux_scores = [] + det_scores = [] + success_indicators = [] + detection_indicators = [] + + self.perturbation.eval() + for batch_id, batch in enumerate(loader_val): + + # randomize simulation for each validation batch + self.pipeline.sample_params() + + # allow for different dataset formats + if isinstance(batch, tuple): + batch = { + 'x': batch[0], + 'y': batch[1] + } + + n_batch = batch['x'].shape[0] + + # match devices + for k in batch.keys(): + batch[k] = batch[k].to(self.pipeline.device) + + # set reference for auxiliary loss + self._set_loss_reference(batch['x']) + + with torch.no_grad(): + + # compute adversarial loss + x_adv = self._evaluate_batch(**batch) + outputs = self.pipeline(x_adv) + adv_scores.append(self.adv_loss(outputs, batch['y']).flatten()) + + # compute adversarial success rate + success_indicators.append( + 1.0 * self._compute_success_array( + x=batch['x'], 
y=batch['y'], x_adv=x_adv + ).flatten()) + + # compute defense loss and detection indicators + def_results = self.pipeline.detect(x_adv) + detection_indicators.append(1.0 * def_results[0].flatten()) + det_scores.append(def_results[1].flatten()) + + # compute auxiliary loss + if self.aux_loss is not None: + aux_scores.append( + self._compute_aux_loss(x_adv).flatten()) + else: + aux_scores.append(torch.zeros(n_batch)) + + tag = f'{self.__class__.__name__}-' \ + f'{self.aux_loss.__class__.__name__}' + + if self.writer is not None: + + with self.writer.force_logging(): + + # adversarial loss value + self.writer.log_scalar( + torch.cat(adv_scores, dim=0).mean(), + f"{tag}/adversarial-loss-val", + global_step=self._iter_id + ) + + # detector loss value + self.writer.log_scalar( + torch.cat(det_scores, dim=0).mean(), + f"{tag}/detector-loss-val", + global_step=self._iter_id + ) + + # auxiliary loss value + self.writer.log_scalar( + torch.cat(aux_scores, dim=0).mean(), + f"{tag}/auxiliary-loss-val", + global_step=self._iter_id + ) + + # adversarial success rate + self.writer.log_scalar( + torch.cat(success_indicators, dim=0).mean(), + f"{tag}/success-rate-val", + global_step=self._iter_id + ) + + # adversarial detection rate + self.writer.log_scalar( + torch.cat(detection_indicators, dim=0).mean(), + f"{tag}/detection-rate-val", + global_step=self._iter_id + ) + + # clear optimizer + self.optimizer = None + + # freeze model parameters + self.perturbation.eval() + for p in self.perturbation.parameters(): + p.requires_grad = False + + # save model/perturbation + self._checkpoint() + + def _evaluate_batch(self, + x: torch.Tensor, + y: torch.Tensor, + *args, + **kwargs + ): + """Evaluate batch of inputs by passing through model/perturbation""" + + x_orig = x.clone().detach() + x_adv = self.perturbation(x_orig, *args, **kwargs) + return x_adv + + @torch.no_grad() + def evaluate(self, + x: torch.Tensor = None, + y: torch.Tensor = None, + dataset: Dataset = None, + *args, + **kwargs + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + self.perturbation.eval() + return super().evaluate(x, y, dataset, *args, **kwargs) + + def _log_step(self, + x: torch.Tensor, + x_adv: torch.Tensor, + y: torch.Tensor, + adv_loss: Union[float, torch.Tensor] = None, + det_loss: Union[float, torch.Tensor] = None, + aux_loss: Union[float, torch.Tensor] = None, + success_rate: Union[float, torch.Tensor] = None, + detection_rate: Union[float, torch.Tensor] = None, + idx: int = 0, + tag: str = None, + *args, + **kwargs + ): + """ + Log attack progress. 
+ + Parameters + ---------- + x (torch.Tensor): batch of original inputs + x_adv (torch.Tensor): batch of adversarial inputs + y (torch.Tensor): batch of targets + adv_loss (float): adversarial loss value + aux_loss (float): auxiliary loss value + det_loss (float): detector loss value + success_rate (float): attack success rate + detection_rate (float): attack detection rate + idx (int): batch index for logging individual examples + tag (str): label for logging output + """ + + if self.writer is None or self._iter_id % self.writer.log_iter: + return + + if tag is None: + tag = f'{self.__class__.__name__}-' \ + f'{self.aux_loss.__class__.__name__}' + + super()._log_step( + x, + x_adv, + y, + adv_loss=adv_loss, + det_loss=det_loss, + aux_loss=aux_loss, + success_rate=success_rate, + detection_rate=detection_rate, + idx=idx, + tag=tag + ) + + # log perturbation visualizations + if hasattr(self.perturbation, "visualize"): + try: + visualizations = self.perturbation.visualize() # Dict[str: tensor] + for name, image in visualizations.items(): + self.writer.log_image( + tag=f'{tag}/{name}', + image=image, + global_step=self._iter_id + ) + except AttributeError: + pass + + def load(self, path: Union[str, Path]): + """Load weights for stored perturbation/model""" + + checkpoint_path = Path(path) + + # for files, load directly + if checkpoint_path.is_file(): + final_path = checkpoint_path + + # for directory, check for most recent file + elif checkpoint_path.is_dir(): + + # search for files with matching identifier + if self.checkpoint_name is not None: + tag = f'{self.checkpoint_name}*.pt' + else: + tag = f'{self.__class__.__name__}-' \ + f'{self.aux_loss.__class__.__name__}*.pt' + valid_files = Path(checkpoint_path).rglob(tag) + + # select most recent checkpoint + final_path = max(valid_files, key=os.path.getctime) + else: + raise ValueError(f'Invalid checkpoint path {path}') + + self.perturbation.load_state_dict( + torch.load( + final_path, + map_location=self.pipeline.device) + ) + + def _checkpoint(self): + """Save model/perturbation checkpoint""" + if self.writer is not None: + if self.checkpoint_name is not None: + tag = f'{self.checkpoint_name}-epoch-{self._epoch_id}' + else: + tag = f'{self.__class__.__name__}-' \ + f'{self.aux_loss.__class__.__name__}-' \ + f'epoch-{self._epoch_id}' + self.writer.checkpoint( + self.perturbation.state_dict(), + tag=tag, + global_step=None + ) + + def __del__(self): + """Save model/perturbation checkpoint upon deletion""" + self._checkpoint() diff --git a/voicebox/src/attacks/offline/voicebox.py b/voicebox/src/attacks/offline/voicebox.py new file mode 100644 index 0000000000000000000000000000000000000000..2c5729f43255e4945084861041e7fb9a892c05f3 --- /dev/null +++ b/voicebox/src/attacks/offline/voicebox.py @@ -0,0 +1,287 @@ +import torch + +from src.attacks.offline.perturbation import Perturbation +from src.attacks.offline.trainable import TrainableAttack +from src.attacks.offline.perturbation import VoiceBox +from src.loss.auxiliary import AuxiliaryLoss + +from typing import Union + +################################################################################ +# VoiceBox online filtering-based attack +################################################################################ + + +class VoiceBoxAttack(TrainableAttack): + + def __init__(self, + voicebox_kwargs: dict, + control_loss: AuxiliaryLoss = None, + **kwargs): + + # additional (optional) auxiliary loss on filter controls + self.control_loss = control_loss + + super().__init__( + 
perturbation=VoiceBox(**voicebox_kwargs), + **kwargs) + + def _log_step(self, + x: torch.Tensor, + x_adv: torch.Tensor, + y: torch.Tensor, + adv_loss: Union[float, torch.Tensor] = None, + det_loss: Union[float, torch.Tensor] = None, + aux_loss: Union[float, torch.Tensor] = None, + control_loss: Union[float, torch.Tensor] = None, + success_rate: Union[float, torch.Tensor] = None, + detection_rate: Union[float, torch.Tensor] = None, + idx: int = 0, + tag: str = None, + *args, + **kwargs + ): + + if self.writer is None or self._iter_id % self.writer.log_iter: + return + + if tag is None: + tag = f'{self.__class__.__name__}-' \ + f'{self.aux_loss.__class__.__name__}' + + super()._log_step( + x, + x_adv, + y, + adv_loss=adv_loss, + det_loss=det_loss, + aux_loss=aux_loss, + success_rate=success_rate, + detection_rate=detection_rate, + idx=idx, + tag=tag + ) + + # log control-signal loss + self.writer.log_scalar( + control_loss, + f"{tag}/control-signal-loss", + global_step=self._iter_id + ) + + def _train_batch(self, + x: torch.Tensor, + y: torch.Tensor, + *args, + **kwargs): + """Optimize stored model/perturbation over a batch of inputs""" + + # require batch dimension + assert x.ndim >= 2 + n_batch = x.shape[0] + + x = x.detach() + + # set reference for auxiliary loss to avoid re-computing + self._set_loss_reference(x) + + # randomly sample simulation parameters + if self.eot_iter and not self._iter_id % self.eot_iter: + self.pipeline.sample_params() + + def closure(): + + # placeholder for final model/perturbation gradients + model_gradients = \ + self._retrieve_parameter_gradients(self.perturbation) + grad_total = torch.zeros_like(model_gradients) + + # apply adversarial perturbation to batch and obtain predictions + perturbed = self.perturbation(x, y=y, *args, **kwargs) + outputs = self.pipeline(perturbed) + + # reset parameter gradients, using `None` for performance boost + self.perturbation.zero_grad(set_to_none=True) + + # compute flattened parameter gradients w.r.t. adversarial loss + adv_scores = self.adv_loss(outputs, y) + adv_loss = torch.mean(adv_scores) + adv_loss.backward(retain_graph=True) + adv_loss_grad = self._retrieve_parameter_gradients( + self.perturbation + ).detach() + + # reset parameter gradients, using `None` for performance boost + self.perturbation.zero_grad(set_to_none=True) + + # compute flattened parameter gradients w.r.t. detector loss + detector_flags, detector_scores = self.pipeline.detect(perturbed) + detector_loss = torch.mean(detector_scores) + detector_loss.backward(retain_graph=True) + detector_loss_grad = self._retrieve_parameter_gradients( + self.perturbation + ).detach() + + # reset parameter gradients, using `None` for performance boost + self.perturbation.zero_grad(set_to_none=True) + + # compute flattened parameter gradients w.r.t. 
auxiliary loss + if self.aux_loss is not None: + aux_scores = self._compute_aux_loss(perturbed) + aux_loss = torch.mean(aux_scores) + aux_loss.backward() + aux_loss_grad = self._retrieve_parameter_gradients( + self.perturbation + ).detach() + else: # if no auxiliary loss, do not penalize + aux_scores = torch.zeros(n_batch).to(x.device) + aux_loss = torch.mean(aux_scores) + aux_loss_grad = torch.zeros_like(adv_loss_grad).detach() + + # obtain filter controls for given inputs + get_controls = callable( + getattr(self.perturbation, "get_controls", None)) + if self.control_loss is not None and get_controls: + + # compute slowness / sparsity loss on control signal + controls = self.perturbation.get_controls( + x, *args, **kwargs) + control_scores = self.control_loss(controls) + control_loss = torch.mean(control_scores) * 0.01 + + # backpropagate + control_loss.backward() + + # retrieve parameter gradients + control_loss_grad = self._retrieve_parameter_gradients( + self.perturbation + ).detach() + + # add to aux loss + aux_loss_grad = aux_loss_grad + control_loss_grad + + else: + control_loss = 0.0 + ################################################################ + + # classifier evasion indicator, reshape for broadcasting + adv_success = (adv_loss <= self.adv_success_thresh) * 1.0 + + # detector evasion indicator, reshape for broadcasting + detector_success = (detector_loss <= self.det_success_thresh) * 1.0 + + # perform standard, orthogonal, or selective gradient + # accumulation + if self.pgd_variant is None or self.pgd_variant == 'none': + + # for standard PGD, sum loss gradients + grad_total += adv_loss_grad + \ + detector_loss_grad + \ + aux_loss_grad + + elif self.pgd_variant == 'orthogonal': + + # for orthogonal PGD, orthogonalize loss gradients and + # select one for update; optionally, orthogonalize only + # every kth step + if self.k and self._batch_id % self.k: + adv_loss_grad_proj = adv_loss_grad + detector_loss_grad_proj = detector_loss_grad + aux_loss_grad_proj = aux_loss_grad + else: + adv_loss_grad_proj = self._component_orthogonal( + adv_loss_grad, + detector_loss_grad, + aux_loss_grad + ) + detector_loss_grad_proj = self._component_orthogonal( + detector_loss_grad, + adv_loss_grad, + aux_loss_grad + ) + aux_loss_grad_proj = self._component_orthogonal( + aux_loss_grad, + detector_loss_grad, + adv_loss_grad + ) + + # update 'along' a single loss gradient per iteration + grad_total += adv_loss_grad_proj * (1 - adv_success) + grad_total += detector_loss_grad_proj * adv_success \ + * (1 - detector_success) + grad_total += aux_loss_grad_proj * adv_success * \ + detector_success + + elif self.pgd_variant == 'selective': + + # only consider a single loss per iteration, without + # ensuring orthogonality to remaining loss gradients + grad_total += adv_loss_grad * (1 - adv_success) + grad_total += detector_loss_grad * adv_success \ + * (1 - detector_success) + grad_total += aux_loss_grad * adv_success * detector_success + + else: + raise ValueError(f'Invalid attack mode {self.pgd_variant}') + + # regularize gradients via p-norm projection + if self.scale_grad in [2, float(2), "2"]: + grad_norms = torch.norm( + grad_total, p=2, dim=-1 + ) + 1e-20 + grad_total = grad_total / grad_norms + elif self.scale_grad in [float("inf"), "inf"]: + grad_total = torch.sign(grad_total) + elif self.scale_grad in ['none', None]: + pass + else: + raise ValueError(f'Invalid gradient regularization norm ' + f'{self.scale_grad}' + ) + + # set final parameter gradients + self._set_parameter_gradients( + 
grad_total.flatten(), + self.perturbation + ) + + # log results + if self.writer is not None: + self._log_step( + x=x, + x_adv=perturbed, + y=y, + adv_loss=adv_loss, + det_loss=detector_loss, + aux_loss=aux_loss, + control_loss=control_loss, + detection_rate=torch.mean(1.0 * detector_flags) + ) + + # return placeholder loss + return adv_loss + detector_loss + aux_loss + + # optimizer step, using stored gradients + self.optimizer.step(closure) + + # project perturbation to feasible region + if hasattr(self.perturbation, "project_valid"): + try: + self.perturbation.project_valid() + except AttributeError: + pass + + # update total iteration count + self._iter_id += 1 + + def _evaluate_batch(self, + x: torch.Tensor, + y: torch.Tensor, + *args, + **kwargs + ): + """Evaluate batch of inputs by passing through model/perturbation""" + + x_orig = x.clone().detach() + x_adv = self.perturbation(x_orig, y=y, *args, **kwargs) + return x_adv diff --git a/voicebox/src/attacks/offline/white_noise.py b/voicebox/src/attacks/offline/white_noise.py new file mode 100644 index 0000000000000000000000000000000000000000..4b00a778863f947d0cc88846deb5e037a21ab4da --- /dev/null +++ b/voicebox/src/attacks/offline/white_noise.py @@ -0,0 +1,354 @@ +import torch +from torch.utils.data import Dataset + +import warnings + +from typing import Union, List, Tuple + +from src.attacks.offline.trainable import TrainableAttack +from src.attacks.offline.perturbation.perturbation import Perturbation +from src.attacks.offline.perturbation.white_noise import WhiteNoisePerturbation +from src.pipelines.pipeline import Pipeline +from src.loss.adversarial import AdversarialLoss + +################################################################################ +# White noise attack +################################################################################ + + +class WhiteNoiseAttack(TrainableAttack): + """ + Simple baseline attack in which white noise is added to inputs. + """ + def __init__(self, + pipeline: Pipeline, + adv_loss: AdversarialLoss, + snr_low: float = -10.0, + snr_high: float = 60.0, + step_size: float = 5.0, + min_success_rate: float = 0.9, + search: str = 'bisection', + **kwargs + ): + """ + Sweep a range of SNR values to find best signal-to-noise (SNR) ratio + at which to apply noise to inputs; adapted from https://bit.ly/3tcDF7u. + + :param pipeline: Pipeline object + :param adv_loss: AdversarialLoss object + :param snr: signal-to-noise ratio of attack. Can be a float, in which + case noise will be applied at the given SNR, or a pair of + floats representing the endpoints of the search space + :param step: step size for search over SNR values + :param search: search method if range of SNR values is given; must be + one of 'linear', 'bisection', or 'none' + """ + self.snr_low = snr_low + self.snr_high = snr_high + self.step_size = step_size + self.search = search + self.min_success_rate = min_success_rate + + super().__init__( + pipeline=pipeline, + adv_loss=adv_loss, + perturbation=WhiteNoisePerturbation(snr=snr_low), + **kwargs + ) + + @torch.no_grad() + def train(self, + x_train: torch.Tensor = None, + y_train: torch.Tensor = None, + data_train: Dataset = None, + x_val: torch.Tensor = None, + y_val: torch.Tensor = None, + data_val: Dataset = None, + *args, + **kwargs + ): + """ + Perform a single epoch of "training" by sweeping for an optimal SNR + value over the training data. 
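The bisection implemented below can be read in isolation as a search for the highest (quietest) SNR that still meets the success threshold. A sketch with a hypothetical `success_rate_at` callable, assuming success falls off monotonically as SNR rises:

```python
import torch

def find_best_snr(snr_values, success_rate_at, min_success_rate=0.9):
    i_min, i_max = 0, len(snr_values)
    best = snr_values[0]               # fall back to the loudest setting
    while i_min < i_max:
        i_mid = (i_min + i_max) // 2
        if success_rate_at(snr_values[i_mid]) >= min_success_rate:
            best = snr_values[i_mid]   # quiet enough noise still succeeds
            i_min = i_mid + 1          # try an even higher SNR
        else:
            i_max = i_mid              # too quiet; search lower SNRs
    return best

snrs = torch.arange(-10.0, 60.0, 5.0)
print(find_best_snr(snrs, lambda s: 1.0 if s < 20 else 0.5))  # tensor(15.)
```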
+ """ + + loader_train, loader_val = self._prepare_data( + x_train, + y_train, + data_train, + x_val, + y_val, + data_val) + + # match devices and set reference if necessary + ref_batch = next(iter(loader_train)) + + if isinstance(ref_batch, tuple): + x_ref = ref_batch[0] + warnings.warn('Warning: provided dataset yields batches in tuple ' + 'format; the first two tensors of each batch will be ' + 'interpreted as inputs and targets, respectively, ' + 'and any remaining tensors will be ignored. To pass ' + 'additional named tensor arguments, use a dictionary ' + 'batch format with keys `x` and `y` for inputs and ' + 'targets, respectively.') + elif isinstance(ref_batch, dict): + x_ref = ref_batch['x'] + else: + x_ref = ref_batch + + if hasattr(self.perturbation, "set_reference"): + try: + self.perturbation.set_reference( + x_ref.to(self.pipeline.device)) + except AttributeError: + pass + + # enumerate possible SNR values for search + snr_values = torch.arange(self.snr_low, self.snr_high, self.step_size) + + # track iterations + self._iter_id = 0 + self._batch_id = 0 + self._epoch_id = 0 + + # avoid unnecessary search + if self.snr_low == self.snr_high \ + or len(snr_values) < 2 \ + or self.search in ['none', None]: + self.perturbation.set_snr(self.snr_low) + + else: + + # find best SNR via search + i_min = 0 + i_max = len(snr_values) + snr_best = self.snr_low + + # perform bisection search for maximum SNR value which achieves + # minimum success threshold + if self.search == 'bisection': + + while i_min < i_max: + + # determine midpoint index + i_mid = (i_min + i_max) // 2 + snr = snr_values[i_mid] + + # set SNR + self.perturbation.set_snr(snr) + + # compute success rate over training data at each candidate + # SNR level + successes = 0 + n = 0 + + self._batch_id = 0 + for batch in loader_train: + + if isinstance(batch, dict): + x, y = batch['x'], batch['y'] + else: + x, y, *_ = batch + + x = x.to(self.pipeline.device) + y = y.to(self.pipeline.device) + + n += len(x) + x_adv = self.perturbation(x) + outputs = self.pipeline(x_adv) + adv_scores = self.adv_loss(outputs, y) + adv_loss = adv_scores.mean() + + batch_successes = (1.0 * self._compute_success_array( + x, y, x_adv)).sum().item() + successes += batch_successes + + self._log_step( + x, + x_adv, + y, + adv_loss, + success_rate=batch_successes/len(x) + ) + + self._batch_id += 1 + self._iter_id += 1 + + success_rate = successes / n + + if success_rate >= self.min_success_rate: + snr_best = snr + i_min = i_mid + 1 + else: + i_max = i_mid + + # perform linear search for SNR level + elif self.search == 'linear': + + for snr in snr_values: + + # set SNR + self.perturbation.set_snr(snr) + + # compute success rate over training data at each candidate + # SNR level + successes = 0 + n = 0 + + self._batch_id = 0 + for batch in loader_train: + + if isinstance(batch, dict): + x, y = batch['x'], batch['y'] + else: + x, y, *_ = batch + + x = x.to(self.pipeline.device) + y = y.to(self.pipeline.device) + + n += len(x) + x_adv = self.perturbation(x) + outputs = self.pipeline(x_adv) + adv_scores = self.adv_loss(outputs, y) + adv_loss = adv_scores.mean() + batch_successes = (1.0 * self._compute_success_array( + x, y, x_adv)).sum().item() + successes += batch_successes + + self._log_step( + x, + x_adv, + y, + adv_loss, + success_rate=batch_successes/len(x) + ) + + self._batch_id += 1 + self._iter_id += 1 + + success_rate = successes / n + + if success_rate >= self.min_success_rate: + snr_best = snr + else: + raise ValueError(f'Invalid search method 
{self.search}') + + # set final SNR value + self.perturbation.set_snr(snr_best) + + # perform validation + adv_scores = [] + aux_scores = [] + det_scores = [] + success_indicators = [] + detection_indicators = [] + + self.perturbation.eval() + + for batch_id, batch in enumerate(loader_val): + + # randomize simulation for each validation batch + self.pipeline.sample_params() + + if isinstance(batch, dict): + x_orig, targets = batch['x'], batch['y'] + else: + x_orig, targets, *_ = batch + + n_batch = x_orig.shape[0] + + x_orig = x_orig.to(self.pipeline.device) + targets = targets.to(self.pipeline.device) + + # set reference for auxiliary loss + self._set_loss_reference(x_orig) + + with torch.no_grad(): + + # compute adversarial loss + x_adv = self._evaluate_batch(x_orig, targets) + outputs = self.pipeline(x_adv) + adv_scores.append(self.adv_loss(outputs, targets).flatten()) + + # compute adversarial success rate + success_indicators.append( + 1.0 * self._compute_success_array( + x_orig, targets, x_adv + ).flatten()) + + # compute defense loss and detection indicators + def_results = self.pipeline.detect(x_adv) + detection_indicators.append(1.0 * def_results[0].flatten()) + det_scores.append(def_results[1].flatten()) + + # compute auxiliary loss + if self.aux_loss is not None: + aux_scores.append( + self._compute_aux_loss(x_adv).flatten()) + else: + aux_scores.append(torch.zeros(n_batch)) + + tag = f'{self.__class__.__name__}-' \ + f'{self.aux_loss.__class__.__name__}' + + if self.writer is not None: + + with self.writer.force_logging(): + + # adversarial loss value + self.writer.log_scalar( + torch.cat(adv_scores, dim=0).mean(), + f"{tag}/adversarial-loss-val", + global_step=self._iter_id + ) + + # detector loss value + self.writer.log_scalar( + torch.cat(det_scores, dim=0).mean(), + f"{tag}/detector-loss-val", + global_step=self._iter_id + ) + + # auxiliary loss value + self.writer.log_scalar( + torch.cat(aux_scores, dim=0).mean(), + f"{tag}/auxiliary-loss-val", + global_step=self._iter_id + ) + + # adversarial success rate + self.writer.log_scalar( + torch.cat(success_indicators, dim=0).mean(), + f"{tag}/success-rate-val", + global_step=self._iter_id + ) + + # adversarial detection rate + self.writer.log_scalar( + torch.cat(detection_indicators, dim=0).mean(), + f"{tag}/detection-rate-val", + global_step=self._iter_id + ) + + # freeze model parameters + self.perturbation.eval() + for p in self.perturbation.parameters(): + p.requires_grad = False + + # save model/perturbation + self._checkpoint() + + def _evaluate_batch(self, + x: torch.Tensor, + y: torch.Tensor, + **kwargs + ): + """ + Apply white noise perturbations to inputs. 
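One detail worth noting in the validation loop here: per-example scores are flattened and concatenated across batches before averaging, so a smaller final batch is not over-weighted. A tiny sketch with made-up numbers:

```python
import torch

# per-batch, per-example success indicators (last batch is smaller)
success_indicators = [torch.tensor([1., 0., 1.]), torch.tensor([1.])]
success_rate = torch.cat(success_indicators, dim=0).mean()
print(success_rate)  # tensor(0.7500), not the mean of batch means (0.8333)
```

The training sweep aggregates the same way via `successes / n`.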
+ """ + + # require batch dimension + assert x.ndim >= 2 + + return self.perturbation(x) diff --git a/voicebox/src/attacks/online/__init__.py b/voicebox/src/attacks/online/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f6660019860730b8e2556b61dbffd505cf627289 --- /dev/null +++ b/voicebox/src/attacks/online/__init__.py @@ -0,0 +1,2 @@ +from .streamer import Streamer +from .voicebox_streamer import VoiceBoxStreamer \ No newline at end of file diff --git a/voicebox/src/attacks/online/__pycache__/__init__.cpython-39.pyc b/voicebox/src/attacks/online/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..71a4b2ab5d367b8bc1c270d9367f749fa97d1c56 Binary files /dev/null and b/voicebox/src/attacks/online/__pycache__/__init__.cpython-39.pyc differ diff --git a/voicebox/src/attacks/online/__pycache__/streamer.cpython-39.pyc b/voicebox/src/attacks/online/__pycache__/streamer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..48c811bdf94828b706e16f19355cd39b2a0c04c9 Binary files /dev/null and b/voicebox/src/attacks/online/__pycache__/streamer.cpython-39.pyc differ diff --git a/voicebox/src/attacks/online/__pycache__/voicebox_streamer.cpython-39.pyc b/voicebox/src/attacks/online/__pycache__/voicebox_streamer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20ef62effcb4cdab914a0fb1609878cfee98cf4f Binary files /dev/null and b/voicebox/src/attacks/online/__pycache__/voicebox_streamer.cpython-39.pyc differ diff --git a/voicebox/src/attacks/online/streamer.py b/voicebox/src/attacks/online/streamer.py new file mode 100644 index 0000000000000000000000000000000000000000..9791c71aa74b16fa07ce2f32a2ada90d6f401afd --- /dev/null +++ b/voicebox/src/attacks/online/streamer.py @@ -0,0 +1,217 @@ +""" +This file contains a code for a generalized streamer. + +The model's forward must take a windowed tensor of dimensions (n_batch, n_channels, window_size, num_windows) +and also output (n_batch, n_channels, window_size, num_windows). The first non-self positional argument +will be the windowed tensor, and the forward must have a keyword argument for `lookahead`. + +Overlap-adding to a buffer is handled by the streamer. All completed audio is outputted immediately. + +The streamer contains a lookahead window cache and queue, as well as an output buffer containing incomplete +audio. The last `lookahead_frames` frames of input audio should not have audio generated for them. + +`hidden_refresh` is the number of input samples until the hidden state refreshes. This is to prevent out of +distribution hidden states from no voice for awhile. + +For networks featuring a causal RNN, `recurrent` should be set true, and the following properties must hold +for `model.forward`: + + Has kwarg `hx` that takes a torch.Tensor or a tuple of torch.Tensors, representing the hidden state + + When kwarg `hx=None`, then the hidden state must be generated by the model. 
+ + Returns a tuple of size 2: (windowed_audio, new_hx) + """ + + import time + import numpy as np + import torch + import torch.nn as nn + import torch.nn.functional as F + + from src.data.dataproperties import DataProperties + + + class Streamer: + def __init__(self, + model: nn.Module, + device: str, + hop_length: int = None, + hidden_refresh: int = 32_000, + window_length: int = None, + win_type: str = None, + lookahead_frames: int = None, + recurrent: bool = False): + self.device = device + self.model: nn.Module = model + self.model.to(device) + self.model.eval() + # Assume window_length is a multiple of hop_length + self.hop_length: int = hop_length or self.model.hop_length + self.window_length: int = window_length or self.model.window_length + self.output_buffer_length: int = self.window_length - self.hop_length + self.hidden_refresh_frames: int = (hidden_refresh - self.window_length) // self.hop_length + 1 + self.lookahead_frames: int = self.model.lookahead_frames if lookahead_frames is None else lookahead_frames + self.win_type: str = win_type or self.model.win_type + self.recurrent: bool = recurrent + + # Timing variables + self._last_processed_samples: int = 0 + self._total_processed_samples: int = 0 + self._last_processed_time: float = 0 + self._total_processed_time: float = 0 + + self._reset_buffers() + + @property + def real_time_factor(self) -> float: + """Gets the real-time factor over all audio processed so far""" + return self._total_processed_time / (self._total_processed_samples / DataProperties.get('sample_rate')) + + def _refresh_hidden_state(self, input_frames: torch.Tensor) -> None: + num_frames = input_frames.shape[1] + self._frames_until_refresh -= num_frames + if self._frames_until_refresh <= 0: + self.hx = None + + def _reset_buffers(self): + """Clears out buffers""" + self.lookahead_buffer: torch.Tensor = torch.zeros(1, 0, self.window_length, device=self.device) + self.output_buffer: torch.Tensor = torch.zeros(1, 0, device=self.device) + # next_pad contains the first `self.hop_length` samples of a window, where half of it is in + # one `feed` chunk, and the other is in the next + self.next_pad: torch.Tensor = None + self.hx = None + self._frames_until_refresh = self.hidden_refresh_frames + + def _get_windows(self, x: torch.Tensor) -> torch.Tensor: + """ + Gets (possibly) overlapping windows of `x`. + If windows overlap, the last `self.hop_length` samples of the previous chunk are prepended to `x`.
+ + (1, 1, length) -> (1, num_windows, window_length) + """ + """ + If we are using overlapping windows, then take the last `self.hop_length` + of the last window in the `self.lookahead_buffer` + """ + if self.win_type in ['hann', 'triangular']: + if self.next_pad is not None: + x = torch.cat([self.next_pad, x], dim=-1) + self.next_pad = x[:, :, -self.hop_length:] + + windowed_audio = F.unfold( + x.unsqueeze(2), + (1, self.window_length), + stride=self.hop_length + ).permute(0, 2, 1) + + if self.lookahead_frames > 0: + windowed_audio = torch.cat( + [self.lookahead_buffer, windowed_audio], dim=1 + ) # (1, lookahead_buffer_windows + input_windows, window_length) + self.lookahead_buffer = windowed_audio[:, -self.lookahead_frames:, :] + return windowed_audio + + @staticmethod + def _get_win_func(win_type: str): + if win_type == 'rectangular': + return lambda m: torch.ones(m) + elif win_type == 'hann': + return lambda m: torch.hann_window(m) + elif win_type == 'triangular': + return lambda m: torch.as_tensor(np.bartlett(m)).float() + else: + raise ValueError("Invalid window type") + + def _overlap_add_buffer(self, frames: torch.Tensor) -> None: + """ + Takes `frames`, applies a windowed overlap-add to them. + Adds the previous overlap-add buffer to the start of the folded + signal, and assigns signal to `self.output_buffer` + + :param frames: (n_batch, n_frames, window_length) + """ + # Length of output frames may be longer than self.window_length. + # Extra frames will be truncated. + assert frames.shape[-1] >= self.window_length + frames = frames[:, :, :self.window_length] + + if self.win_type in ['triangular', 'hann']: + win = self._get_win_func(self.win_type)(self.window_length).to(frames).reshape(1, 1, -1) + # Instead of padding the window, we chop the frames + # apply window + frames = frames * win + elif self.win_type == 'rectangular': + pass # for consistency with non-streamer versions + else: + raise ValueError(f'Invalid windows type: {self.win_type}') + n_frames = frames.shape[1] + buffer = torch.zeros(1, 1, (n_frames - 1) * self.hop_length + self.window_length, device=self.device) + start_buffer = 0 + for frame_idx in range(n_frames): + buffer[..., start_buffer:start_buffer+self.window_length] += frames[:, frame_idx, :] + start_buffer += self.hop_length + buffer[..., :self.output_buffer.shape[-1]] += self.output_buffer + self.output_buffer = buffer + + def _get_output_audio(self): + """ + Returns the complete output audio from `self.output_buffer`, + and truncates `self.output_buffer` to only contain `output_buffer_length` + samples. + """ + buffer_length = self.output_buffer.shape[-1] + + return_audio, self.output_buffer = torch.split( + self.output_buffer, + split_size_or_sections=[buffer_length - self.output_buffer_length, + self.output_buffer_length], + dim=-1 + ) + return return_audio + + def _log_time(self, output_audio: torch.Tensor, time_begin: float) -> None: + elapsed = time.time() - time_begin + self._last_processed_samples = output_audio.shape[-1] + self._last_processed_time = elapsed + self._total_processed_samples += self._last_processed_samples + self._total_processed_time += self._last_processed_time + + @torch.no_grad() + def feed(self, audio: torch.Tensor, **kwargs) -> torch.Tensor: + """ + :param audio: FloatTensor of shape (1, 1, length). `length` should be + a multiple of `self.hop_length`. 
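The windowing and overlap-add that `feed` relies on (`_get_windows` plus `_overlap_add_buffer`) round-trip cleanly because a Hann window at 50% overlap satisfies the constant-overlap-add (COLA) property. A self-contained check using the default shapes (not repo code):

```python
import torch
import torch.nn.functional as F

window_length, hop_length = 256, 128
x = torch.randn(1, 1, 1280)

# slice overlapping windows, as _get_windows does with F.unfold
frames = F.unfold(
    x.unsqueeze(2), (1, window_length), stride=hop_length
).permute(0, 2, 1)                      # (1, n_windows, window_length)

# overlap-add under a Hann window: at 50% overlap the weights sum to 1,
# so interior samples reconstruct the input exactly
win = torch.hann_window(window_length).reshape(1, 1, -1)
n = frames.shape[1]
out = torch.zeros(1, 1, (n - 1) * hop_length + window_length)
for i in range(n):
    start = i * hop_length
    out[..., start:start + window_length] += frames[:, i] * win
print(torch.allclose(out[..., hop_length:-hop_length],
                     x[..., hop_length:-hop_length], atol=1e-5))  # True
```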
+ """ + time_begin = time.time() + windowed_audio = self._get_windows(audio) # (1, input_windows, window_length) + + if self.recurrent: + output_frames, self.hx = self.model(windowed_audio, hx=self.hx, **kwargs) + else: + output_frames = self.model(windowed_audio, **kwargs) + # output_frames is size (1, output_windows, window_length) + + self._overlap_add_buffer(output_frames) + output_audio = self._get_output_audio() + self._log_time(output_audio, time_begin) + return output_audio + + @torch.no_grad() + def flush(self, **kwargs): + """ + This gets the remaining audio out of the buffers, processes it, and returns it. + + Does not log time. + """ + if self.lookahead_frames: + model_input = torch.cat( + [self.lookahead_buffer, torch.zeros(1, self.lookahead_frames, self.window_length, device=self.device)], + dim=1 + ) # (1, lookahead_buffer_windows + self.lookahead_frames, window_length) + if self.recurrent: + output_frames, self.hx = self.model(model_input, hx=self.hx, **kwargs) + else: + output_frames = self.model(model_input, **kwargs) # (1, output_windows, window_length) + self._overlap_add_buffer(output_frames) + output_audio = self.output_buffer + self._reset_buffers() + return output_audio diff --git a/voicebox/src/attacks/online/voicebox_streamer.py b/voicebox/src/attacks/online/voicebox_streamer.py new file mode 100644 index 0000000000000000000000000000000000000000..14a95847cddc8aa456dd9be8a5bebdc2ae0a2f61 --- /dev/null +++ b/voicebox/src/attacks/online/voicebox_streamer.py @@ -0,0 +1,626 @@ +from typing import Any, Callable, Union, Optional, Tuple +import math +import torchaudio +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchaudio.transforms import MFCC, MelSpectrogram +import pyworld + + +from src.attacks.offline.perturbation import VoiceBox +from src.attacks.offline.perturbation.voicebox.bottleneck import RNNBottleneck +from src.attacks.offline.perturbation.voicebox.filter import FilterLayer +from src.attacks.offline.perturbation.voicebox.spec import SpectrogramEncoder +from src.models.phoneme import Delta +from src.data import DataProperties + +""" +VoiceBoxStreamer: An implementation of VoiceBoxStreamer that works with `src.attacks.online.streamer.Streamer` + +Weights can be loaded from a `VoiceBox` object. +""" + +class VoiceBoxStreamer(VoiceBox): + def __init__(self, encoder_buffer_frames=50, bottleneck_type='lstm', **kwargs): + super().__init__(**kwargs) + self.hop_length = self.win_length // 2 + self._condition_vector: torch.Tensor = torch.zeros( + 1, 1, self.conditioning_dim + ) + # overwrite filter layer with streamer + self.filter = FilterLayerStreamer( + win_length=self.filter.win_length, + win_type=self.filter.win_type, + n_bands=self.filter.n_bands, + normalize_ir=self.filter.normalize_ir + ) + self._init_bottleneck(bottleneck_type) + self._init_encoder_streamer(encoder_buffer_frames) + + @property + def window_length(self): + return self.win_length + + @property + def condition_vector(self): + return self._condition_vector + + @condition_vector.setter + def condition_vector(self, v: torch.Tensor): + assert v.shape == self._condition_vector.shape + self._condition_vector = v + + def _init_bottleneck(self, bottleneck_type) -> None: + """ + Initializes Bottleneck. If 'rnn' or 'lstm', initialized with + a streaming version is used. Transformer is not supported with the streamer. 
+ """ + if bottleneck_type in ['lstm', 'rnn']: + self.bottleneck = RNNBottleneckStreamer( + input_size=self.bottleneck_hidden_size, + hidden_size=self.bottleneck_feedforward_size, + proj_size=self.bottleneck_hidden_size, + num_layers=self.bottleneck_depth, + downsample_index=1, + downsample_factor=1, + dropout_prob=0.0 + ) + elif bottleneck_type in ['attention', 'transformer']: + raise NotImplementedError("Not Supported.") + else: + self.bottleneck = nn.Identity() + + def _init_encoder_streamer(self, encoder_buffer_frames: int) -> None: + """ + Initializes encoder streamer implementations + """ + if not isinstance(self.ppg_encoder, nn.Identity): + # NOTE: Assuming that he PPG is relatively constant. + # Don't touch the magic numbers + self.ppg_encoder = PPGEncoderStreamer( + win_length=self.win_length, + hop_length=self.win_length // 2, + win_func=torch.hann_window, + n_mels=32, + n_mfcc=19, + lstm_depth=2, + hidden_size=self.ppg_encoder_hidden_size, + lookahead=self.bottleneck_lookahead_frames + ) + + self.ppg_encoder.load_state_dict( + torch.load(self.ppg_encoder_path, map_location=torch.device('cpu')) + ) + if not isinstance(self.pitch_encoder, nn.Identity): + pitch_streamer = DioStreamer( + return_periodicity=True, + buffer_frames=encoder_buffer_frames, + lookahead_frames=self.bottleneck_lookahead_frames + ) + pitch_streamer.__dict__.update(self.pitch_encoder.__dict__) + self.pitch_encoder = pitch_streamer + if not isinstance(self.loudness_encoder, nn.Identity): + self.loudness_encoder = LoudnessEncoderStreamer( + hop_length=self.win_length // 2 + ) + if not isinstance(self.spec_encoder, nn.Identity): + spec_streamer_encoder = SpectrogramEncoderStreamer( + win_length=self.spec_encoder.win_length, + win_type=self.spec_encoder.win_type, + spec_type=self.spec_encoder.spec_type, + n_mels=self.spec_encoder.n_mels, + lookahead=self.spec_encoder.lookahead, + hidden_size=self.spec_encoder.hidden_size, + mlp_depth=self.spec_encoder.mlp_depth, + normalize=self.spec_encoder.normalize + ) + self.spec_encoder = spec_streamer_encoder + + def get_controls( + self, + x: torch.Tensor, + hx: Any = None) -> tuple[torch.Tensor, Any]: + """ + Gets controls and recurrent states + + :return controls: Controls tensor of (1, input_windows, n_fft) + :return hx: Recurrent state. Do not edit, except for feeding back to this function. + """ + ppg_hx, bottleneck_hx = (None, None) if hx is None else hx + n_frames = x.shape[1] + + features = [] + # I think theres some shared stft calls between all of these. + + if self.use_pitch_encoder: + pitch, periodicity = self.pitch_encoder(x) + # clip the first `self._buffer_frames` values + # We clip an extra frame off the DIO results + # because the F0 is assigned inclusively to the first and + # last sample. 
+ features += [pitch, periodicity] + if self.use_loudness_encoder: + loudness = self.loudness_encoder(x) + # clip the first `self._buffer_frames` values + features.append(loudness) + if self.use_phoneme_encoder: + ppg, ppg_hx = self.ppg_encoder(x, hx=ppg_hx) + features.append(ppg) + if self.use_spec_encoder: + spec = self.spec_encoder(x) # (n_batch, n_frames, hidden_size) + features.append(spec) + + encoded = self.encoder_proj( + torch.cat(features, dim=-1) + ) + ######################################################################## + # TARGET CONDITIONING + ######################################################################## + if self.conditioning_dim: + y = self.condition_vector.repeat(1, n_frames, 1) + encoded = self.conditioning_encoder(x=encoded, cond=self.conditioning_mlp(y)) + # (1, n_frames, hidden_size) + + ######################################################################## + # BOTTLENECK + ######################################################################## + bottleneck_out, bottleneck_hx = self.bottleneck( + encoded, hx=bottleneck_hx, lookahead=self.bottleneck_lookahead_frames) + + # apply skip connection with pre-bottleneck encoding + if self.bottleneck_skip: + bottleneck_out = torch.cat([ + bottleneck_out, + encoded + ], dim=-1) # (n_batch, n_frames, 2 * hidden_size) + + controls = self.bottleneck_proj(bottleneck_out) # (n_batch, n_frames, n_bands) + return controls, (ppg_hx, bottleneck_hx) + + @torch.no_grad() + def forward(self, x: torch.Tensor, hx: Any=None) -> tuple[torch.Tensor, torch.Tensor]: + """ + Expecting: x: (1, input_windows, window_length) + + This should never get a batch size greater than 1 + + Hidden state should not be edited outside of this function + """ + assert not self.training, "Never use this streamer for training!" + assert x.shape[-1] == self.win_length + assert x.shape[0] == 1, "Batched audio not supported" + assert x.shape[1] > 1, "Due to what I believe to be a Pytorch bug, " \ + + "InstanceNorm1d does not allow single batch single frame inputs, " \ + + "even when `model.eval()` is called beforehand. Please use at least 2 frames at a time." 
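The `controls_scaling` step applied in `forward` (defined below) squashes raw bottleneck outputs into non-negative filter gains; the candidate functions differ mainly in output range. A quick illustration mirroring the options named in the code (input values are illustrative):

```python
import math
import torch
import torch.nn.functional as F

c = torch.linspace(-4.0, 4.0, 5)
sigmoid_gain = 2 * torch.sigmoid(c) ** math.log(10) + 1e-7  # bounded in (0, 2]
elu_gain = F.elu(c) + 1                                     # (0, inf), linear for c > 0
log_gain = torch.exp(torch.tanh(c * 0.2) * 5)               # roughly [e**-5, e**5]
print(sigmoid_gain.max() <= 2.0, (elu_gain > 0).all())      # True, True
```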
+ + controls, hx = self.get_controls(x, hx=hx) + x, controls = self.chop_lookahead(x, controls) + controls = self.controls_scaling(controls) + filtered_audio = self.filter(x, controls) + return filtered_audio, hx + + def chop_lookahead(self, x: torch.Tensor, controls: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + lookahead = self.spec_encoder.lookahead + self.bottleneck_lookahead_frames + return x[:, :-lookahead, :], controls[:, :-lookahead, :] + + def controls_scaling(self, controls: torch.Tensor) -> torch.Tensor: + + # scale stored controls to fixed range + if self.control_scaling_fn.lower() == 'sigmoid': + controls = 2 * torch.sigmoid( + controls + )**(math.log(10)) + 1e-7 + elif self.control_scaling_fn.lower() == 'relu': + controls = F.relu(controls) + elif self.control_scaling_fn.lower() == 'elu': + controls = F.elu(controls) + 1 + elif self.control_scaling_fn.lower() == 'log': + controls = torch.tanh(controls * 0.2) * 5 # [-5, 5] + controls = torch.exp(controls) + else: + raise ValueError(f'Invalid control scaling function ' + f'{self.control_scaling_fn}') + + # if controls are taken to be mel-scaled, linearly scale + if self.mel_scale_controls: + controls = self.inv_mel_scale(controls) + + # perform projection (PGD step) on filter controls + controls = self._project_valid(controls) + return controls + +############################################### +# STREAMING BUILDING BLOCKS # +############################################### + + +class RNNBottleneckStreamer(RNNBottleneck): + """ + Edited Bottleneck to return the hidden states. + """ + def forward(self, + x: torch.Tensor, + hx: Any=None, + lookahead: int=0 + ) -> tuple[torch.Tensor, Any]: + if hx is None: + hx = [None] * self.num_layers + else: + assert len(hx) == self.num_layers + for i, rnn_layer in enumerate(self.rnn): + x, hx[i] = self.run_rnn(rnn_layer, (x, hx[i]), lookahead) + x = self.dropout(x) + + # x = self.instancenorm(x.permute(0, 2, 1)).permute(0, 2, 1) + if i == self.downsample_index: + n_batch, n_frames, proj_size = x.shape + # determine necessary padding to allow temporal downsampling + pad_len = self.downsample_factor * math.ceil(n_frames / self.downsample_factor) - n_frames + # apply causal padding + x = F.pad(x, (0, 0, 0, pad_len)) + # apply temporal downsampling + x = torch.reshape(x, (n_batch, x.shape[1] // self.downsample_factor, x.shape[2] * self.downsample_factor)) + return x, hx + + @staticmethod + def run_rnn(rnn_layer: nn.Module, data: tuple[torch.Tensor, Any], lookahead: int) -> tuple[torch.Tensor, Any]: + x, hx = data + if lookahead == 0: + return rnn_layer(x, hx) + else: + num_frames = x.shape[1] + if num_frames <= lookahead: + x, _ = rnn_layer(x, hx) + return x, hx + else: + x, x_l = x[:, :-lookahead, :], x[:, -lookahead:, :] + x, hx = rnn_layer(x, hx) + x_l, _ = rnn_layer(x_l, hx) + x = torch.cat([x, x_l], dim=1) + return x, hx + + +class FilterLayerStreamer(FilterLayer): + """ + Streamer Implementation of FilterLayer without the OLA + """ + def __init__(self, win_length: int = 512, win_type: str = 'hann', n_bands: int = 128, normalize_ir: Union[str, int, float] = None, **kwargs): + super().__init__(win_length, win_type, n_bands, normalize_ir, **kwargs) + + # Why is computed every iteration + n_fft_min = self.win_length + 2 * (self.n_bands - 1) + self.n_fft = pow(2, math.ceil(math.log2(n_fft_min))) + + def forward(self, x: torch.Tensor, controls: torch.Tensor) -> torch.Tensor: + """ + + :param x: Windowed audio of size (1, input_windows, window_size) + :param controls: Controls of size (1, 
input_windows, n_fft) + """ + if x.shape[1] == 0: + return x + + impulse = self._amp_to_ir(controls) + x = self._fft_convolve(x, impulse, self.n_fft).contiguous() + # x: (1, n_frames, n_fft) + return x + +############################################### +# Streaming Encoding Blocks # +############################################### + +class PPGEncoderStreamer(nn.Module): + def __init__(self, + win_length: int = 256, + hop_length: int = 128, + win_func: Callable = torch.hann_window, + n_mels: int = 32, + n_mfcc: int = 13, + lstm_depth: int = 2, + hidden_size: int = 512, + lookahead: int = 5): + + super().__init__() + self.win_length = win_length + self.hop_length = hop_length + # get non-center MFCC + mel_kwargs = { + "n_fft": self.win_length, + "win_length": self.win_length, + "hop_length": self.hop_length, + "window_fn": win_func, + "n_mels": n_mels + } + spectrogram = NonCenterSpectrogram( + n_fft=self.win_length, + win_length=self.win_length, + hop_length=self.hop_length, + window_fn=win_func + ) + self.mfcc = MFCC( + sample_rate=DataProperties.get("sample_rate"), + n_mfcc=n_mfcc, + log_mels=True, + melkwargs=mel_kwargs + ) + self.mfcc.MelSpectrogram.spectrogram = spectrogram + # compute first- and second- order MFCC deltas + self.delta = Delta() + + # PPG network + self.mlp = nn.Sequential( + nn.Linear(n_mfcc * 3, hidden_size), + nn.ReLU(), + nn.Linear(hidden_size, hidden_size), + nn.ReLU() + ) + self.lstm = nn.LSTM( + input_size=hidden_size, + hidden_size=hidden_size, + num_layers=lstm_depth, + bias=True, + batch_first=True, + bidirectional=False + ) + self.lookahead = lookahead + + + @torch.no_grad() + def forward(self, x: torch.Tensor, hx: Optional[Tuple[torch.Tensor, torch.Tensor]]=None) -> Tuple[torch.Tensor, Any]: + """ + :param x: Windowed audio of (1, windows, win_length) + """ + # require batch, channel dimensions + mfcc = self.mfcc(x) # (1, n_mfcc, n_frames) + + delta1 = self.delta(mfcc) # (1, n_mfcc, n_frames) + delta2 = self.delta(delta1) # (1, n_mfcc, n_frames) + x = torch.cat([mfcc, delta1, delta2], dim=1) # (n_batch, n_frames, 3 * n_mfcc) + x = x.permute(0, 2, 1) + x = self.mlp(x) # (n_batch, n_frames, hidden_size) + if self.lookahead: + if x.shape[1] > self.lookahead: + x, x_l = x[:, :-self.lookahead, :], x[:, -self.lookahead:, :] + x, hx = self.lstm(x, hx) # (n_batch, n_frames, hidden_size) + x_l, _ = self.lstm(x_l, hx) + x = torch.cat([x, x_l], dim=1) + else: + x, _ = self.lstm(x, hx) + else: + x, hx = self.lstm(x, hx) + + return x, hx + +class DioStreamer(nn.Module): + """ + Pitch and Periodicity streamer. + + Only uses dio. + """ + def __init__( + self, + return_periodicity: bool=True, + hop_length: int=128, + buffer_frames: int = 20, + lookahead_frames: int = 5): + self.return_periodicity = return_periodicity + self.hop_length = hop_length + + self._buffer_frames = buffer_frames + self._lookahead = lookahead_frames + self._buffer = torch.zeros(1, buffer_frames, self.hop_length * 2) + + def _roll_buffer(self, last_input_frames: torch.Tensor) -> None: + if last_input_frames.shape[1] > self._lookahead: + self._buffer = ( + torch.cat( + [self._buffer, last_input_frames], + dim=1 + )[:, -self._buffer_frames-self._lookahead:-self._lookahead, :] + ) + + @torch.no_grad() + def forward(self, x: torch.Tensor): + """ + Takes windowed audio. Reconstructs the overlapping frames + by taking the first `hop_length` frames of each window, and stitching. + Them back together. 
Dio returns estimated pitch and periodicity + in hop_size distances + + """ + num_windows = x.shape[1] + self._buffer_frames + pitch_out, peridoicity_out, device = [], [], x.device + hop_ms = 1000 * self.hop_length / DataProperties.get('sample_rate') + x_padded = torch.cat([self._buffer, x], dim=1) + x_folded = F.fold( + x_padded[..., :self.hop_length].permute(0, 2, 1), + output_size=(1, self.hop_length * num_windows), + kernel_size=(1, self.hop_length), + stride=(1, self.hop_length) + ) + x_folded = x_folded.flatten() + x_folded = torch.cat([x_folded, x[0, -1, self.hop_length:]]) + # x_folded: (self.hop_length * num_windows) + x_np = x_folded.clone().double().cpu().numpy() + + + pitch, timeaxis = pyworld.dio( + x_np, + fs=DataProperties.get('sample_rate'), + f0_floor=50, + f0_ceil=550, + frame_period=hop_ms, + allowed_range=.1, + speed=4) # downsampling factor, for speedup + pitch = pyworld.stonemask( + x_np, + pitch, + timeaxis, + DataProperties.get('sample_rate')) + + pitch_out.append(pitch) + + pitch_out = torch.as_tensor( + pitch_out, + dtype=torch.float32, + device=device).unsqueeze(-1) + + pitch_out = pitch_out[:, self._buffer_frames+1:-1, :] + out = pitch_out + if self.return_periodicity: + unvoiced = pyworld.d4c( + x_np, + pitch, + timeaxis, + DataProperties.get('sample_rate'), + ).mean(axis=1) + + peridoicity_out.append(unvoiced) + periodicity_out = torch.as_tensor( + peridoicity_out, + dtype=torch.float32, + device=device).unsqueeze(-1) + + # (n_batch, n_frames, 1), (n_batch, n_frames, 1) + periodicity_out = periodicity_out[:, self._buffer_frames+1:-1, :] + out = pitch_out, periodicity_out + self._roll_buffer(x) + return out + + +class LoudnessEncoderStreamer(nn.Module): + """Streaming implementation for LoudnessEncoder""" + def __init__(self, + hop_length: int = 128, + n_fft: int = 256) -> None: + super().__init__() + self.hop_length = hop_length + self.n_fft = n_fft + + def A_weight(self): + # torch implementation of A_weight + # librosa.fft_frequencies + freqs = torch.zeros(1 + self.n_fft // 2) + for i in range(1 + self.n_fft // 2): + freqs[i] = i * DataProperties.get('sample_rate') / self.n_fft + + # librosa.A_weighting. Using magic numbers from librosa implementation + min_db = torch.Tensor([-80.0]).float() + f_sq = freqs ** 2.0 + const = torch.Tensor([12194.217, 20.598997, 107.65265, 737.86223]) ** 2.0 + weights = 2.0 + 20.0 * ( + torch.log10(const[0]) + + 2 * torch.log10(f_sq) + - torch.log10(f_sq + const[0]) + - torch.log10(f_sq + const[1]) + - 0.5 * torch.log10(f_sq + const[2]) + - 0.5 * torch.log10(f_sq + const[3]) + ) + return torch.where(weights > min_db, weights, min_db) + + @torch.no_grad() + def forward(self, x: torch.Tensor): + """ + :param x: Windowed audio of (1, num_windows, window_length) + """ + # torch.stft should be exactly the same as librosa stft + spec = torch.stft( + x.squeeze(0), + n_fft=self.n_fft, + hop_length=self.hop_length, + win_length=self.n_fft, + window=torch.hann_window(self.n_fft), + center=False, + return_complex=True + ).permute(2, 1, 0) # x: (1, num_windows, n_fft) + + spec = torch.log(abs(spec) + 1e-7) + a_weight = self.A_weight() + # apply multiplicative weighting via addition in log domain + spec = spec + a_weight.reshape(1, -1, 1) + + # take mean over each frame + loudness = torch.mean(spec, dim=1).unsqueeze(-1).float().to(x.device) + + return loudness + +class SpectrogramEncoderStreamer(SpectrogramEncoder): + """ + Streaming Implementation of SpectrogramEncoder + + Takes windowed audio instead of a single signal. 
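Stepping back to the pitch streamer above: the pyworld calls it wraps can be exercised on their own. A small sketch on a synthetic tone (the 8 ms `frame_period` matches a 128-sample hop at 16 kHz; the exact output values are approximate):

```python
import numpy as np
import pyworld

sr = 16000
t = np.arange(sr) / sr
x = 0.5 * np.sin(2 * np.pi * 220.0 * t)   # one second of a 220 Hz tone

# coarse F0 track from DIO, then refined with StoneMask
f0, timeaxis = pyworld.dio(x, fs=sr, f0_floor=50, f0_ceil=550, frame_period=8.0)
f0 = pyworld.stonemask(x, f0, timeaxis, sr)
print(np.median(f0[f0 > 0]))   # ~220.0
```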
+ """ + def __init__( + self, + win_length: int = 512, + win_type: str = 'hann', + spec_type: str = 'linear', + lookahead: int = 5, + hidden_size: int = 512, + n_mels: int = 64, + mlp_depth: int = 2, + normalize: str = None + ): + super().__init__(win_length, win_type, spec_type, lookahead, hidden_size, n_mels, mlp_depth, normalize) + + # compute spectral representation + spec_kwargs = { + "n_fft": self.win_length, + "win_length": self.win_length, + "hop_length": self.hop_length, + "window_fn": self._get_win_func(self.win_type) + } + spectrogram = NonCenterSpectrogram(**spec_kwargs) + mel_kwargs = {**spec_kwargs, "n_mels": self.n_mels} + + if spec_type == 'linear': + self.spec = spectrogram + elif spec_type == 'mel': + self.spec = MelSpectrogram( + sample_rate=DataProperties.get("sample_rate"), + **mel_kwargs + ) + self.spec.spectrogram = spectrogram + elif spec_type == 'mfcc': + self.spec = MFCC( + sample_rate=DataProperties.get("sample_rate"), + n_mfcc=self.n_mels, + log_mels=True, + melkwargs=mel_kwargs + ) + self.spec.MelSpectrogram.spectrogram = spectrogram + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # spec = self.spec(x).squeeze(-1) + 1e-6 # (n_batch, n_frames, n_freq) + spec = self.spec(x) + 1e-6 # (n_batch, n_freqs, n_frames) + + if self.spec_type in ['linear', 'mel']: + spec = 10 * torch.log10(spec + 1e-8) # (n_batch, n_freq, n_frames) + # normalize spectrogram + spec = self.norm(spec) # (n_batch, n_freq, n_frames) + + # actual encoder network + encoded = self.glu(spec) # (n_batch, hidden_size, n_frames) + encoded = self.conv(encoded) # (n_batch, hidden_size, n_frames) + encoded = self.mlp( + encoded.permute(0, 2, 1) + ) # (n_batch, n_frames, hidden_size) + + return encoded + +class NonCenterSpectrogram(torchaudio.transforms.Spectrogram): + """ + A modified Spectrogram Module that processes overlapping frames + of audio in shape (batch, num_frames, frame_length) to + (batch, n_fft, num_frames). + + This should be used to patch the MelSpectrogram's `spectrogram`. + + Why don't I just use `center=False`? This causes a 1e-8 magnitude + error in the output spectrogram, and a 1e-5 magnitude error in the + MFCC, which is problematic for the PPG. 
+ """ + def forward(self, x): + specgram = super().forward(x)[..., 1].transpose(-2, -1) + return specgram \ No newline at end of file diff --git a/voicebox/src/constants.py b/voicebox/src/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..11780ee743cd878f2c7595ab46f8419c1c425d7b --- /dev/null +++ b/voicebox/src/constants.py @@ -0,0 +1,67 @@ +from pathlib import Path + +################################################################################ +# Project-wide constants +################################################################################ + +# Static directories +CACHE_DIR = Path(__file__).parent.parent / 'cache' +DATA_DIR = Path(__file__).parent.parent / 'data' +RUNS_DIR = Path(__file__).parent.parent / 'runs' +TEST_DIR = Path(__file__).parent.parent / 'test' +CONFIGS_DIR = Path(__file__).parent.parent / 'configs' +MODELS_DIR = Path(__file__).parent.parent / 'pretrained' + +# Set constant properties for streaming operations +WIN_LENGTH = 256 +HOP_LENGTH = 128 +SAMPLE_RATE = 16000 + +# VoxCeleb1 dataset +VOXCELEB1_DATA_DIR = DATA_DIR / 'VoxCeleb1' +VOXCELEB1_EXT = 'wav' + +# VoxCeleb2 dataset +VOXCELEB2_DATA_DIR = DATA_DIR / 'VoxCeleb2' + +# Pretrained phoneme prediction model +PPG_PRETRAINED_PATH = MODELS_DIR / 'phoneme' / 'causal_ppg_256_hidden.pt' + +# Pretrained VoiceBox attack +VOICEBOX_PRETRAINED_PATH = MODELS_DIR / 'voicebox' / 'voicebox_final.pt' + +# Pretrained universal additive attack +UNIVERSAL_PRETRAINED_PATH = MODELS_DIR / 'universal' / 'universal_final.pt' + +# LibriSpeech dataset +LIBRISPEECH_DATA_DIR = DATA_DIR / 'LibriSpeech' +LIBRISPEECH_CACHE_DIR = CACHE_DIR / 'LibriSpeech' +LIBRISPEECH_SIG_LEN = 4.0 +LIBRISPEECH_EXT = 'flac' +LIBRISPEECH_PHONEME_EXT = 'TextGrid' +LIBRISPEECH_NUM_PHONEMES = 70 # first phoneme corresponds to silence +LIBRISPEECH_PHONEME_DICT = { + 'sil': 0, '': 0, 'sp': 0, 'spn': 0, + 'AE1': 1, 'P': 2, 'T': 3, 'ER0': 4, + 'W': 5, 'AH1': 6, 'N': 7, 'M': 8, + 'IH1': 9, 'S': 10, 'IH0': 11, 'Z': 12, + 'R': 13, 'EY1': 14, 'AH0': 15, 'L': 16, + 'D': 17, 'AY1': 18, 'V': 19, 'JH': 20, + 'EH1': 21, 'DH': 22, 'IY0': 23, 'IY2': 24, + 'OW1': 25, 'AW1': 26, 'UW1': 27, 'HH': 28, + 'AA1': 29, 'OW0': 30, 'F': 31, 'TH': 32, + 'AO1': 33, 'AA2': 34, 'ER1': 35, 'B': 36, + 'UH1': 37, 'K': 38, 'Y': 39, 'IY1': 40, + 'AO2': 41, 'NG': 42, 'AE0': 43, 'G': 44, + 'SH': 45, 'IH2': 46, 'EH2': 47, 'UW0': 48, + 'AY2': 49, 'EY2': 50, 'AA0': 51, 'OY1': 52, + 'AE2': 53, 'ZH': 54, 'EH0': 55, 'OW2': 56, + 'AH2': 57, 'UH2': 58, 'AO0': 59, 'UW2': 60, + 'EY0': 61, 'AW2': 62, 'AY0': 63, 'ER2': 64, + 'OY0': 65, 'OY2': 66, 'UH0': 67, 'AW0': 68, + 'CH': 69} +LIBRISPEECH_FILLER_PHONEMES = ['', 'sil', 'sp', 'spn'] + +# Streamer Conditioning +CONDITIONING_FOLDER = DATA_DIR / 'streamer' +CONDITIONING_FILENAME = CONDITIONING_FOLDER / 'conditioning.pt' diff --git a/voicebox/src/data/__init__.py b/voicebox/src/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..49d26c70de50fe4800d7db3b9d1939e96d287874 --- /dev/null +++ b/voicebox/src/data/__init__.py @@ -0,0 +1,3 @@ +from src.data.dataproperties import DataProperties +from src.data.dataset import VoiceBoxDataset +from src.data.librispeech import LibriSpeechDataset diff --git a/voicebox/src/data/__pycache__/__init__.cpython-310.pyc b/voicebox/src/data/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4118df28e5542e4e86816e19e16a81a81d95b356 Binary files /dev/null and 
b/voicebox/src/data/__pycache__/__init__.cpython-310.pyc differ diff --git a/voicebox/src/data/__pycache__/__init__.cpython-39.pyc b/voicebox/src/data/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc3112d874d9fe13f60834b2d08f0359cd7b74f7 Binary files /dev/null and b/voicebox/src/data/__pycache__/__init__.cpython-39.pyc differ diff --git a/voicebox/src/data/__pycache__/dataproperties.cpython-310.pyc b/voicebox/src/data/__pycache__/dataproperties.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49cf137284c1a99d0df6b84d1486730be54845ec Binary files /dev/null and b/voicebox/src/data/__pycache__/dataproperties.cpython-310.pyc differ diff --git a/voicebox/src/data/__pycache__/dataproperties.cpython-39.pyc b/voicebox/src/data/__pycache__/dataproperties.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..72aeb51887b1cdc2e234a077d0c296dbd54da6e4 Binary files /dev/null and b/voicebox/src/data/__pycache__/dataproperties.cpython-39.pyc differ diff --git a/voicebox/src/data/__pycache__/dataset.cpython-310.pyc b/voicebox/src/data/__pycache__/dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f84eb2076795063a37c2183356d9f1482723868c Binary files /dev/null and b/voicebox/src/data/__pycache__/dataset.cpython-310.pyc differ diff --git a/voicebox/src/data/__pycache__/dataset.cpython-39.pyc b/voicebox/src/data/__pycache__/dataset.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95c6c384f9f2b92a33577def85ed3390517da500 Binary files /dev/null and b/voicebox/src/data/__pycache__/dataset.cpython-39.pyc differ diff --git a/voicebox/src/data/__pycache__/librispeech.cpython-39.pyc b/voicebox/src/data/__pycache__/librispeech.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2484225ddc06a283f39bfdf59c7f301bfb87d30d Binary files /dev/null and b/voicebox/src/data/__pycache__/librispeech.cpython-39.pyc differ diff --git a/voicebox/src/data/dataproperties.py b/voicebox/src/data/dataproperties.py new file mode 100644 index 0000000000000000000000000000000000000000..47ed1e1d67d49fa8a8208f08d29c37727ce8ab10 --- /dev/null +++ b/voicebox/src/data/dataproperties.py @@ -0,0 +1,64 @@ +import torch + +################################################################################ +# Allow project-wide access to persistent data properties +################################################################################ + + +class DataProperties(object): + """ + Allow shared access to data properties (e.g. sample rate) across all audio + processing modules. Each dataset registers its properties with the + DataProperties class upon initialization, eliminating the need to repeatedly + pass properties as parameters + """ + + # Default data properties: 1-second 16kHz audio scaled to [-1, 1] + properties = { + "sample_rate": 16000, + "scale": 1.0, + "signal_length": 16000 + } + + @classmethod + def register_properties(cls, **kwargs): + """ + Register data properties by name + """ + cls.properties = kwargs + + @classmethod + def get(cls, *args): + """ + Access one or more data properties by name + """ + if len(args) > 1: + return tuple(cls.properties[a] for a in args) + else: + return cls.properties[args[0]] + + @classmethod + def format_input(cls, x: torch.Tensor): + """ + Ensure input is correctly formatted (batch/channels/samples). 
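A minimal sketch of the registry pattern above: a dataset registers its properties once at initialization, and any downstream module reads them by name. Note that `register_properties` replaces the whole dictionary, so every property must be passed together.

```python
import torch

DataProperties.register_properties(
    sample_rate=16000,
    scale=1.0,
    signal_length=4 * 16000,
)

sr, length = DataProperties.get("sample_rate", "signal_length")
assert (sr, length) == (16000, 64000)
```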
If input + cannot be reshaped to required dimensions, raise error + """ + + try: + signal_length = cls.properties["signal_length"] + except KeyError: + raise ValueError(f"Data property `signal_length` must be defined to" + f" format inputs") + + if x.ndim <= 1: + n_batch = 1 + + else: + n_batch = x.shape[0] + + try: + x = x.reshape(n_batch, 1, signal_length) + except RuntimeError: + raise ValueError(f"Invalid input dimensions {list(x.shape)}") + + return x diff --git a/voicebox/src/data/dataset.py b/voicebox/src/data/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..33181a56a9f06f666b0aca39a6cdffb5617e7e20 --- /dev/null +++ b/voicebox/src/data/dataset.py @@ -0,0 +1,412 @@ +import os +import math +from copy import deepcopy + +import librosa as li +import numpy as np + +import torch +import torch.nn.functional as F + +from torch.utils.data import Dataset + +from src.data.dataproperties import DataProperties +from src.constants import ( + SAMPLE_RATE, + HOP_LENGTH +) +from src.attacks.offline.perturbation.voicebox.pitch import PitchEncoder +from src.attacks.offline.perturbation.voicebox.loudness import LoudnessEncoder + +from os import path +from tqdm import tqdm +from pathlib import Path + +from typing import Union, Iterable + +################################################################################ +# Cache and load datasets +################################################################################ + + +def ensure_dir(directory: Union[str, Path]): + """ + Ensure all directories along given path exist, given directory name + """ + directory = str(directory) + if len(directory) > 0 and not os.path.exists(directory): + os.makedirs(directory) + + +class VoiceBoxDataset(Dataset): + + """ + A Dataset object for the LibriSpeech dataset subsets. The required data can + be downloaded by running the script `download_librispeech.sh`. This class + takes audio data from the specified directory and caches tensors to disk. + """ + def __init__(self, + split: str, + data_dir: str, + cache_dir: str, + audio_ext: str, + signal_length: Union[float, int], + scale: Union[float, int], + target: str, + features: Union[str, Iterable[str]] = None, + sample_rate: int = SAMPLE_RATE, + hop_length: int = HOP_LENGTH, + batch_format: str = 'dict', + *args, + **kwargs): + """ + Load, organize, and cache LibriSpeech dataset. + + Parameters + ---------- + split (str): data subset name + + data_dir (str): dataset root directory + + cache_dir (str): root directory to which tensors will be saved + + sample_rate (int): sample rate in Hz + + audio_ext (str): extension for audio files within dataset + + signal_length (int): length of audio files in samples (if `int` given) + or seconds (if `float` given) + + scale (float): range to which audio will be scaled + + hop_length (int): hop size for computing frame-wise features (e.g. + pitch, loudness) + + target (str): string specifying target type. + + features (Iterable): strings specifying features to compute for each + audio file in the dataset. Must be subset of + `pitch`, `periodicity`, `loudness` + + batch_format (str): format for returning batches. 
Must be either `dict` + or `tuple` + """ + + if batch_format not in ['dict', 'tuple']: + raise ValueError(f'Invalid batch format {batch_format}') + self.batch_format = batch_format + + self.data_dir = os.fspath(data_dir) + self.cache_dir = os.fspath(cache_dir) + + self.audio_ext = audio_ext + self.sample_rate = sample_rate + self.scale = scale + self.hop_length = hop_length + + # if signal length is given as floating-point value, assume time in + # seconds and convert to samples + if isinstance(signal_length, float): + self.signal_length = math.floor(signal_length * self.sample_rate) + else: + self.signal_length = signal_length + + # compute frame-equivalent signal length for targets/features, + # accounting for center-padding in spectrogram implementations + self.num_frames = math.ceil(self.signal_length / self.hop_length) + if not self.signal_length % self.hop_length: + self.num_frames += 1 + + # register data properties + DataProperties.register_properties( + sample_rate=self.sample_rate, + signal_length=self.signal_length, + scale=self.scale + ) + + # check for valid subset + self.split = self._check_split(split) + + # create directories if necessary + ensure_dir(path.join(self.cache_dir, self.split)) + ensure_dir(path.join(self.cache_dir, self.split)) + + # check for valid target types + self.target = self._check_target(target) + + # check for valid feature types + self.features = self._check_features(features) + + # scan all audio files in dataset + self.audio_list = self._get_audio_list() + + # check for cached audio, targets, and features by name. If missing, + # build required caches. Cache files are identified by sample rate and + # hop size where necessary (e.g. for pitch features, but not class + # targets) + self._build_audio_cache() + self._build_target_cache() + for feature in self.features: + self._build_feature_cache(feature) + + # load data and target tensors from caches + self.tx = torch.load( + Path(self.cache_dir) / + self.split / + f'{self._get_audio_id()}.pt') + self.ty = torch.load( + Path(self.cache_dir) / + self.split / + f'{self._get_target_id()}.pt') + + # load feature tensors from cache and store by name + self.tf = dict() + if self.features is not None and self.features: + for feature in self.features: + self.tf[feature] = torch.load( + Path(self.cache_dir) / + self.split / + f'{self._get_feature_id(feature)}.pt') + + @staticmethod + def _check_split(split: str): + if split not in ['train', 'test']: + raise ValueError(f'Invalid split {split}') + return split + + @staticmethod + def _check_target(target: str): + if target not in ['class', 'transcript']: + raise ValueError(f'Invalid target type {target}') + return target + + @staticmethod + def _check_features(features: Union[str, Iterable[str]]): + if features is None or not features: + features = [] + else: + if isinstance(features, str): + features = [features] + + for f in features: + if f not in ['pitch', 'periodicity', 'loudness']: + raise ValueError(f'Invalid feature type {f}') + return list(features) + + def _get_audio_list(self, *args, **kwargs): + """Scan for all audio files with given extension""" + return sorted( + list((Path(self.data_dir) / self.split).rglob( + f'*.{self.audio_ext}')) + ) + + def _get_audio_id(self): + """Identifier for cached audio""" + return f'{self.sample_rate}-audio' + + def _get_target_id(self): + """Identifier for cached targets""" + if self.target in ['class', 'transcript']: + return f'{self.target}' + else: + return f'{self.sample_rate}-{self.hop_length}-{self.target}' + + 
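For illustration, these are the on-disk cache paths implied by the identifier helpers above (and `_get_feature_id`, defined next), assuming `sample_rate=16000`, `hop_length=128`, and `target='class'`; the directory names are hypothetical.

```python
from pathlib import Path

cache_dir, split = Path('cache'), 'test'

audio_cache  = cache_dir / split / '16000-audio.pt'      # _get_audio_id()
target_cache = cache_dir / split / 'class.pt'            # _get_target_id()
pitch_cache  = cache_dir / split / '16000-128-pitch.pt'  # _get_feature_id('pitch')
```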
def _get_feature_id(self, feature: str): + """Identifier for cached features""" + return f'{self.sample_rate}-{self.hop_length}-{feature}' + + def _build_audio_cache(self): + """Load audio data and cache to disk""" + + audio_id = self._get_audio_id() + audio_cache = list( + (Path(self.cache_dir) / self.split).rglob( + f'{audio_id}.pt') + ) + if len(audio_cache) >= 1: + return + + # prepare to store audio waveforms and lengths + waveforms = torch.zeros(len(self.audio_list), 1, self.signal_length) + + pbar = tqdm(self.audio_list, total=len(self.audio_list)) + for i, audio_fn in enumerate(pbar): + pbar.set_description( + f'Loading {self.split}: {path.basename(audio_fn)}') + + # load audio and resample, but leave original length + waveform, _ = li.load(audio_fn, + mono=True, + sr=self.sample_rate) + waveforms[ + i, :, :min(self.signal_length, len(waveform)) + ] = torch.from_numpy(waveform)[..., :self.signal_length] + + # cache padded tensors and lengths to disk + torch.save(waveforms, + path.join( + self.cache_dir, + self.split, + f'{audio_id}.pt') + ) + + def _build_target_cache(self): + """Load targets and cache to disk""" + raise NotImplementedError() + + def _build_feature_cache(self, feature: str): + """Load features and cache to disk""" + + feature_id = self._get_feature_id(feature) + feature_cache = list( + (Path(self.cache_dir) / self.split).rglob( + f'{feature_id}.pt') + ) + if len(feature_cache) >= 1: + return + + # compute f0, periodicity using PyWorld 'dio' algorithm + pitch_extractor = PitchEncoder(hop_length=self.hop_length) + loudness_extractor = LoudnessEncoder(hop_length=self.hop_length) + + # determine 'zero' values for each feature + zero_pitch, zero_per = pitch_extractor( + torch.zeros(1, 1, self.signal_length)) + zero_loud = loudness_extractor(torch.zeros(1, 1, self.signal_length)) + pad_val_pitch = zero_pitch.mean().item() + pad_val_per = zero_per.mean().item() + pad_val_loud = zero_loud.mean().item() + + # store frame-wise features + if feature == 'loudness': + loudness = torch.full( + (len(self.audio_list), self.num_frames, 1), + pad_val_loud, + dtype=torch.float32 + ) + elif feature in ['pitch', 'periodicity']: + pitch = torch.full( + (len(self.audio_list), self.num_frames, 1), + pad_val_pitch, + dtype=torch.float32 + ) + periodicity = torch.full( + (len(self.audio_list), self.num_frames, 1), + pad_val_per, + dtype=torch.float32 + ) + + # iterate over audio + pbar = tqdm(self.audio_list, total=len(self.audio_list)) + for i, audio_fn in enumerate(pbar): + pbar.set_description( + f'Computing {feature} ({self.split}): ' + f'{path.basename(audio_fn)}') + + # load audio and resample, but leave original length + waveform, _ = li.load(audio_fn, + mono=True, + sr=self.sample_rate, + duration=self.signal_length / self.sample_rate) + + # convert to tensor, insert batch dimension + waveform = torch.from_numpy(waveform).unsqueeze(0) + + # trim or pad waveform if necessary + if waveform.shape[-1] >= self.signal_length: + waveform = waveform[..., :self.signal_length] + else: + pad_len = self.signal_length - waveform.shape[-1] + waveform = F.pad(waveform, (0, pad_len)) + + # compute and store pitch/periodicity in tandem + if feature in ['pitch', 'periodicity']: + + f0, p = pitch_extractor(waveform) + pitch[ + i, :min(f0.shape[1], self.num_frames), : + ] = f0[:, :self.num_frames, :] + periodicity[ + i, :min(p.shape[1], self.num_frames), : + ] = p[:, :self.num_frames, :] + + elif feature == 'loudness': + + l = loudness_extractor(waveform) + loudness[ + i, :min(l.shape[1], 
self.num_frames), : + ] = l[:, :self.num_frames, :] + + else: + raise ValueError(f'Invalid feature type {feature}') + + if feature in ['pitch', 'periodicity']: + + # save to disk + torch.save(pitch, + path.join( + self.cache_dir, + self.split, + f'{self._get_feature_id("pitch")}.pt' + )) + torch.save(periodicity, + path.join( + self.cache_dir, + self.split, + f'{self._get_feature_id("periodicity")}.pt' + )) + else: + # save to disk + torch.save(loudness, + path.join( + self.cache_dir, + self.split, + f'{feature_id}.pt' + )) + + def __len__(self): + return len(self.tx) + + def __getitem__(self, idx): + """Return batch of audio, targets, and optional feature values""" + + if self.batch_format == 'dict': + # return batch items by name + batch = { + 'x': self.tx[idx], + 'y': self.ty[idx], + **{k: self.tf[k][idx] for k in self.tf} + } + elif self.batch_format == 'tuple': + # return batch items in order + batch = (self.tx[idx], self.ty[idx]) + tuple( + self.tf[k][idx] for k in self.tf) + else: + raise ValueError(f'Invalid batch format {self.batch_format}') + + return batch + + def index_reduce(self, idx): + """Reduce to a subset by indexing into all stored tensors""" + + new_dataset = deepcopy(self) + new_dataset.tx = new_dataset.tx[idx] + new_dataset.ty = new_dataset.ty[idx] + for feature in new_dataset.features: + new_dataset.tf[feature] = new_dataset.tf[feature][idx] + + return new_dataset + + def overwrite_dataset(self, x: torch.Tensor, y: torch.Tensor, idx): + """Overwrite inputs and targets, and select features correspondingly""" + + # support boolean or integer indices + assert len(idx) <= self.__len__() + assert len(idx) == self.__len__() or \ + (len(idx) == len(x) and len(idx) == len(y)) + + new_dataset = self.index_reduce(idx) + new_dataset.tx = x + new_dataset.ty = y + + return new_dataset diff --git a/voicebox/src/data/librispeech.py b/voicebox/src/data/librispeech.py new file mode 100644 index 0000000000000000000000000000000000000000..ad6f9e68d25062edacdedbeeb79499bce11408e9 --- /dev/null +++ b/voicebox/src/data/librispeech.py @@ -0,0 +1,253 @@ +import os +import math + +import librosa as li +import numpy as np +import textgrid + +import torch + + +from src.data import DataProperties, VoiceBoxDataset +from src.utils import ensure_dir +from src.constants import ( + LIBRISPEECH_DATA_DIR, + LIBRISPEECH_CACHE_DIR, + SAMPLE_RATE, + LIBRISPEECH_EXT, + LIBRISPEECH_PHONEME_EXT, + LIBRISPEECH_PHONEME_DICT, + LIBRISPEECH_SIG_LEN, + HOP_LENGTH +) +from src.attacks.offline.perturbation.voicebox.voicebox import PitchEncoder + +from os import path +from tqdm import tqdm +from pathlib import Path + +from typing import Union, Iterable + + +################################################################################ +# Cache and load LibriSpeech dataset +################################################################################ + + +class LibriSpeechDataset(VoiceBoxDataset): + """ + A Dataset object for the LibriSpeech dataset subsets. The required data can + be downloaded by running the script `download_librispeech.sh`. This class + takes audio data from the specified directory and caches tensors to disk. 
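A hypothetical usage sketch of the two batch formats supported by `__getitem__` above, assuming the LibriSpeech audio and alignments have already been downloaded. With `batch_format='dict'`, the default DataLoader collation yields a dict of batched tensors; with `'tuple'`, an `(x, y, *features)` tuple.

```python
from torch.utils.data import DataLoader

dataset = LibriSpeechDataset(
    split='test-clean',
    target='speaker',
    features=['pitch', 'loudness'],
    batch_format='dict',
)
loader = DataLoader(dataset, batch_size=8, shuffle=True)

batch = next(iter(loader))
x, y = batch['x'], batch['y']  # audio (8, 1, signal_length), speaker IDs (8,)
pitch = batch['pitch']         # (8, num_frames, 1)
```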
+ """ + def __init__(self, + split: str = 'test-clean', + data_dir: str = LIBRISPEECH_DATA_DIR, + cache_dir: str = LIBRISPEECH_CACHE_DIR, + sample_rate: int = SAMPLE_RATE, + audio_ext: str = LIBRISPEECH_EXT, + phoneme_ext: str = LIBRISPEECH_PHONEME_EXT, + signal_length: Union[float, int] = LIBRISPEECH_SIG_LEN, + scale: Union[float, int] = 1.0, + hop_length: int = HOP_LENGTH, + target: str = 'speaker', + features: Union[str, Iterable[str]] = None, + batch_format: str = 'dict', + *args, + **kwargs): + """ + Load, organize, and cache LibriSpeech dataset. + + Parameters + ---------- + split (str): + + data_dir (str): LibriSpeech root directory + + cache_dir (str): root directory to which tensors will be saved + + sample_rate (int): sample rate in Hz + + audio_ext (str): extension for audio files within dataset + + phoneme_ext (str): extension for phoneme alignment files within + dataset + + signal_length (int): length of audio files in samples (if `int` given) + or seconds (if `float` given) + + scale (float): range to which audio will be scaled + + hop_length (int): hop size for computing frame-wise features (e.g. + pitch, loudness) + + target (str): string specifying target type. Must be one of + `speaker` (speaker ID), `phoneme` (aligned phoneme + labels), or `transcript` + + features (Iterable): strings specifying features to compute for each + audio file in the dataset. Must be subset of + `pitch`, `periodicity`, `loudness` + + batch_format (str): format for returning batches. Must be either `dict` + or `tuple` + """ + + self.phoneme_ext = phoneme_ext + self.phoneme_list = [] + + super().__init__( + split=split, + data_dir=data_dir, + cache_dir=cache_dir, + audio_ext=audio_ext, + signal_length=signal_length, + scale=scale, + target=target, + features=features, + sample_rate=sample_rate, + hop_length=hop_length, + batch_format=batch_format, + *args, **kwargs + ) + + def __str__(self): + """Return string representation of dataset""" + return f'LibriSpeechDataset(split={self.split}, ' \ + f'target={self.target}, features={self.features})' + + @staticmethod + def _check_split(split: str): + """Check for valid dataset split""" + if split not in [ + 'test-clean', + 'test-other', + 'dev-clean', + 'dev-other', + 'train-clean-100', + 'train-clean-360', + 'train-other-500' + ]: + raise ValueError(f'Invalid split {split}') + return split + + @staticmethod + def _check_target(target: str): + if target not in ['speaker', 'phoneme', 'transcript']: + raise ValueError(f'Invalid target type {target}') + return target + + def _get_target_id(self): + """Identifier for cached targets""" + if self.target in ['speaker', 'transcript']: + return f'{self.target}' + else: + return f'{self.sample_rate}-{self.hop_length}-{self.target}' + + def _get_audio_list(self, *args, **kwargs): + """ + Scan for all audio files with given extension. Additionally, only select + audio files for which corresponding phoneme alignments exist. + """ + + audio_files = [os.path.splitext(f)[0] for f in + (Path(self.data_dir) / self.split).rglob( + f'*.{self.audio_ext}')] + phoneme_files = [os.path.splitext(f)[0] for f in + (Path(self.data_dir) / self.split).rglob( + f'*.{self.phoneme_ext}')] + matching_files = list(set(audio_files) & set(phoneme_files)) + + return sorted( + [f + "." 
+ self.audio_ext for f in matching_files] + ) + + def _build_target_cache(self): + """Process and cache targets""" + + target_id = self._get_target_id() + target_cache = list( + (Path(self.cache_dir) / self.split).rglob( + f'{target_id}.pt') + ) + if len(target_cache) >= 1: + return + + # speaker ID targets + if self.target == 'speaker': + + targets = torch.zeros( + len(self.audio_list), dtype=torch.long + ) + + pbar = tqdm(self.audio_list, total=len(self.audio_list)) + for i, audio_fn in enumerate(pbar): + pbar.set_description( + f'Loading Speaker IDs ({self.split}): ' + f'{path.basename(audio_fn)}') + + # extract speaker ID + targets[i] = int(Path(audio_fn).parts[-3]) + + # frame-aligned phoneme label targets + elif self.target == 'phoneme': + + # retrieve phoneme alignment files + self.phoneme_list = [ + os.path.splitext(f)[0] + + "." + self.phoneme_ext for f in self.audio_list] + + targets = torch.zeros(len(self.phoneme_list), + self.num_frames, + dtype=torch.long) + + pbar = tqdm(self.phoneme_list, total=len(self.phoneme_list)) + for i, phoneme_fn in enumerate(pbar): + + pbar.set_description( + f'Loading phoneme alignments ({self.split}): ' + f'{path.basename(phoneme_fn)}') + + # load interval labels from TextGrid format + tg = textgrid.TextGrid.fromFile(phoneme_fn) + if tg[0].name == 'phones': + phoneme_intervals = tg[0] + elif tg[1].name == 'phones': + phoneme_intervals = tg[1] + else: + raise ValueError("Could not find phonemes") + + # compute number of frames in audio file given hop size, + # rounding up + num_frames = math.ceil( + tg.maxTime * self.sample_rate / self.hop_length) + ppg = torch.zeros(num_frames, dtype=torch.long) + + # for each labeled interval, break up into frames with given hop + # size and assign phoneme labels + for interval in phoneme_intervals: + interval.minTime = math.ceil( + interval.minTime * self.sample_rate / self.hop_length) + interval.maxTime = math.ceil( + interval.maxTime * self.sample_rate / self.hop_length) + phoneme_idx = LIBRISPEECH_PHONEME_DICT[interval.mark] + ppg[interval.minTime:interval.maxTime+1] = phoneme_idx + + targets[ + i, :min(ppg.shape[-1], self.num_frames) + ] = ppg[..., :self.num_frames] + + # string transcript targets + elif self.target == 'transcript': + raise NotImplementedError() + + else: + raise ValueError(f'Invalid target type {self.target}') + + # cache targets to disk + torch.save(targets, + path.join( + self.cache_dir, + self.split, + f'{target_id}.pt' + )) diff --git a/voicebox/src/defenses/__init__.py b/voicebox/src/defenses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6c73e2e759048d5e9fc9a9fadba09d76120955fa --- /dev/null +++ b/voicebox/src/defenses/__init__.py @@ -0,0 +1 @@ +from src.defenses.defense import Defense diff --git a/voicebox/src/defenses/__pycache__/__init__.cpython-310.pyc b/voicebox/src/defenses/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a30dda438026294831fe422d6d38ebe3f2f6eb49 Binary files /dev/null and b/voicebox/src/defenses/__pycache__/__init__.cpython-310.pyc differ diff --git a/voicebox/src/defenses/__pycache__/__init__.cpython-39.pyc b/voicebox/src/defenses/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63a6344c85c39c0cac885c0b207b313cd3a1e4f3 Binary files /dev/null and b/voicebox/src/defenses/__pycache__/__init__.cpython-39.pyc differ diff --git a/voicebox/src/defenses/__pycache__/defense.cpython-310.pyc 
b/voicebox/src/defenses/__pycache__/defense.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..88e621895d93cfc9710779689d59645433374fcb Binary files /dev/null and b/voicebox/src/defenses/__pycache__/defense.cpython-310.pyc differ diff --git a/voicebox/src/defenses/__pycache__/defense.cpython-39.pyc b/voicebox/src/defenses/__pycache__/defense.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e759ea7de4d8ba06e0dd4471428eb0fdd2eac38 Binary files /dev/null and b/voicebox/src/defenses/__pycache__/defense.cpython-39.pyc differ diff --git a/voicebox/src/defenses/__pycache__/detection.cpython-310.pyc b/voicebox/src/defenses/__pycache__/detection.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..48b53208eafd88d0871ddb03d1d71cac2ee46dd7 Binary files /dev/null and b/voicebox/src/defenses/__pycache__/detection.cpython-310.pyc differ diff --git a/voicebox/src/defenses/__pycache__/detection.cpython-39.pyc b/voicebox/src/defenses/__pycache__/detection.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db8f83eaf7412a3c4552a94a8b22dd36f90022e9 Binary files /dev/null and b/voicebox/src/defenses/__pycache__/detection.cpython-39.pyc differ diff --git a/voicebox/src/defenses/__pycache__/purification.cpython-310.pyc b/voicebox/src/defenses/__pycache__/purification.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c8bec4b82f27c65007e190ca2ccf59e1cb191ff Binary files /dev/null and b/voicebox/src/defenses/__pycache__/purification.cpython-310.pyc differ diff --git a/voicebox/src/defenses/__pycache__/purification.cpython-39.pyc b/voicebox/src/defenses/__pycache__/purification.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..507745a0cbf346c24c60dfc43a2dbfe5b576b68d Binary files /dev/null and b/voicebox/src/defenses/__pycache__/purification.cpython-39.pyc differ diff --git a/voicebox/src/defenses/defense.py b/voicebox/src/defenses/defense.py new file mode 100644 index 0000000000000000000000000000000000000000..b3e5fe398b61910711d5928fec3376b5e45816a2 --- /dev/null +++ b/voicebox/src/defenses/defense.py @@ -0,0 +1,94 @@ +import torch +import torch.nn as nn + +from typing import Iterable + +from src.defenses.purification import Purification +from src.defenses.detection import Detection +from src.models import Model + +################################################################################ +# Hold and apply both purification and detection defenses +################################################################################ + + +class Defense(nn.Module): + """ + Wrapper for sequential application of purification defenses and parallel + application of detection defenses. Allows for straight-through gradient + estimation. + """ + def __init__(self, + purification: Iterable[Purification], + detection: Iterable[Detection]): + super().__init__() + + if purification is None: + self.purification = nn.ModuleList([nn.Identity()]) + else: + self.purification = nn.ModuleList(purification) + + if detection is None: + self.detection = None + else: + self.detection = nn.ModuleList(detection) + + def purify(self, x: torch.Tensor): + """ + Apply purification defenses in sequence + """ + for p in self.purification: + + x = p(x) + + return x + + def detect(self, + x: torch.Tensor, + model: Model = None): + """ + Apply detection defenses in parallel. 
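Hypothetical wiring of the `Defense` wrapper above, with a single no-op purification stage and no detectors. With `detection=None`, `detect()` (continued below) returns zero-valued flags and scores that still carry gradients from the input, which preserves straight-through estimation.

```python
import torch
import torch.nn as nn

defense = Defense(purification=[nn.Identity()], detection=None)

x = torch.randn(4, 1, 16000, requires_grad=True)
x_pure = defense.purify(x)               # identity here
flags, scores = defense.detect(x_pure)   # both zeros of shape (4, 1)
```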
For each input, return maximum + score and detection flag obtained from all defenses. + """ + + # require batch dimension + assert x.ndim >= 2 + n_batch = x.shape[0] + + if isinstance( + self.detection, + nn.ModuleList + ) and len(self.detection) > 0: + + flags = [] + scores = [] + + for d in self.detection: + + flag, score = d(x, model) + + assert flag.shape[0] == score.shape[0] == n_batch + assert torch.prod(flag.shape).item() == n_batch + assert torch.prod(score.shape).item() == n_batch + + # ensure detector output shape of (n_batch, 1) + flags.append(flag.reshape(-1, 1)) + scores.append(score.reshape(-1, 1)) + + # concatenate outputs, size (n_batch, n_detectors) + scores = torch.cat(scores, dim=-1) + flags = torch.cat(flags, dim=-1) + + # final maximum scores/flags, size (n_batch, 1) + scores = torch.max(scores, dim=-1)[0] + flags = torch.max(flags, dim=-1)[0] + + else: + + n_batch = x.shape[0] + + # allow zero-gradients to propagate + flags = x.reshape(n_batch, -1).sum(dim=-1).reshape(n_batch, 1) * 0 + scores = x.reshape(n_batch, -1).sum(dim=-1).reshape(n_batch, 1) * 0 + + return flags, scores diff --git a/voicebox/src/defenses/detection.py b/voicebox/src/defenses/detection.py new file mode 100644 index 0000000000000000000000000000000000000000..97d256c6db7fe3fa0c8d251abfa0c4d55e80dc72 --- /dev/null +++ b/voicebox/src/defenses/detection.py @@ -0,0 +1,38 @@ +import torch +from torch import nn + +from typing import Tuple + +################################################################################ +# Base class for detection defense objects +################################################################################ + + +class Detection(nn.Module): + """ + Attempt to detect adversarial inputs, typically by observing a difference in + model response when a transformation is applied to inputs + """ + + def __init__(self, compute_grad: bool = True, threshold: float = 0.0): + super().__init__() + self.compute_grad = compute_grad + self.threshold = threshold + + def forward(self, + x: torch.Tensor, + model: nn.Module = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Determine whether input is adversarial, optionally using the response + of a trained model. + + :param x: audio input, shape (n_batch, n_channels, signal_length) + :param model: optionally, accept model; some defenses rely on observing + a model's response to transformed inputs + + :return: a tuple of tensors holding: + * boolean detection flags, shape (n_batch,) + * detector scores, shape (n_batch,) + """ + raise NotImplementedError() diff --git a/voicebox/src/defenses/purification.py b/voicebox/src/defenses/purification.py new file mode 100644 index 0000000000000000000000000000000000000000..f9bc646a8c555a45b90fc453de59c2bf064a35cc --- /dev/null +++ b/voicebox/src/defenses/purification.py @@ -0,0 +1,20 @@ +import torch +import torch.nn as nn + +################################################################################ +# Base class for purification defense objects +################################################################################ + + +class Purification(nn.Module): + """ + Attempt to 'purify' adversarial inputs by applying a distortion (e.g. 
+ filter, compression) + """ + + def __init__(self, compute_grad: bool = True): + super().__init__() + self.compute_grad = compute_grad + + def forward(self, x: torch.Tensor): + raise NotImplementedError() diff --git a/voicebox/src/loss/__init__.py b/voicebox/src/loss/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2c73f9e2ba781a2ccb9c24b637ce7e9cf793c3de --- /dev/null +++ b/voicebox/src/loss/__init__.py @@ -0,0 +1,15 @@ +# Adversarial losses +from src.loss.cross_entropy import CELoss +from src.loss.cw import CWLoss +from src.loss.speaker_embedding import SpeakerEmbeddingLoss + +# Auxiliary losses +from src.loss.l1 import L1Loss +from src.loss.l2 import L2Loss +from src.loss.mrstft import MRSTFTLoss +from src.loss.demucs_mrstft import DemucsMRSTFTLoss +from src.loss.mfcc_cosine import MFCCCosineLoss +from src.loss.speech_features import SpeechFeatureLoss +from src.loss.frequency_masking import FrequencyMaskingLoss +from src.loss.sum import SumLoss +from src.loss.control import ControlSignalLoss diff --git a/voicebox/src/loss/__pycache__/__init__.cpython-39.pyc b/voicebox/src/loss/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d4460f5855a203eb16f350c1bcd88c91ca7435a7 Binary files /dev/null and b/voicebox/src/loss/__pycache__/__init__.cpython-39.pyc differ diff --git a/voicebox/src/loss/__pycache__/adversarial.cpython-39.pyc b/voicebox/src/loss/__pycache__/adversarial.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0675d94a7732bbc9335e077b69b75c42a5d8a41 Binary files /dev/null and b/voicebox/src/loss/__pycache__/adversarial.cpython-39.pyc differ diff --git a/voicebox/src/loss/__pycache__/auxiliary.cpython-39.pyc b/voicebox/src/loss/__pycache__/auxiliary.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..34b0bd88250186b7ea7e6bfec9da6117f60bc6d5 Binary files /dev/null and b/voicebox/src/loss/__pycache__/auxiliary.cpython-39.pyc differ diff --git a/voicebox/src/loss/__pycache__/control.cpython-39.pyc b/voicebox/src/loss/__pycache__/control.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6af5271532d4eac023fdc2d607775e9598ae8b81 Binary files /dev/null and b/voicebox/src/loss/__pycache__/control.cpython-39.pyc differ diff --git a/voicebox/src/loss/__pycache__/cross_entropy.cpython-39.pyc b/voicebox/src/loss/__pycache__/cross_entropy.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5b3eb270362af21e430c364c1cb087adc04de6d8 Binary files /dev/null and b/voicebox/src/loss/__pycache__/cross_entropy.cpython-39.pyc differ diff --git a/voicebox/src/loss/__pycache__/cw.cpython-39.pyc b/voicebox/src/loss/__pycache__/cw.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf6467a924292f191af8a9a2102db477930d8663 Binary files /dev/null and b/voicebox/src/loss/__pycache__/cw.cpython-39.pyc differ diff --git a/voicebox/src/loss/__pycache__/demucs_mrstft.cpython-39.pyc b/voicebox/src/loss/__pycache__/demucs_mrstft.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..98dc071caf269cec5a2919aca79dbce1d8c4efb1 Binary files /dev/null and b/voicebox/src/loss/__pycache__/demucs_mrstft.cpython-39.pyc differ diff --git a/voicebox/src/loss/__pycache__/frequency_masking.cpython-39.pyc b/voicebox/src/loss/__pycache__/frequency_masking.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..92bd79f2e22a7c7008eb383aabea8b7fd74c90c5 Binary files /dev/null and b/voicebox/src/loss/__pycache__/frequency_masking.cpython-39.pyc differ diff --git a/voicebox/src/loss/__pycache__/l1.cpython-39.pyc b/voicebox/src/loss/__pycache__/l1.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..360a7cab01884e4099276f26a65d240c0d6d4e2a Binary files /dev/null and b/voicebox/src/loss/__pycache__/l1.cpython-39.pyc differ diff --git a/voicebox/src/loss/__pycache__/l2.cpython-39.pyc b/voicebox/src/loss/__pycache__/l2.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..498bf7a879863d51a6065c8ae08550434e089c2d Binary files /dev/null and b/voicebox/src/loss/__pycache__/l2.cpython-39.pyc differ diff --git a/voicebox/src/loss/__pycache__/loss.cpython-39.pyc b/voicebox/src/loss/__pycache__/loss.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..04d5d55ae96692e9f7f74cfcddf939dbbcc865a4 Binary files /dev/null and b/voicebox/src/loss/__pycache__/loss.cpython-39.pyc differ diff --git a/voicebox/src/loss/__pycache__/mfcc_cosine.cpython-39.pyc b/voicebox/src/loss/__pycache__/mfcc_cosine.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95a444bed0981342582d5840ebf2c8834c220e2c Binary files /dev/null and b/voicebox/src/loss/__pycache__/mfcc_cosine.cpython-39.pyc differ diff --git a/voicebox/src/loss/__pycache__/mrstft.cpython-39.pyc b/voicebox/src/loss/__pycache__/mrstft.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc72142c9550f0c46e15c4ee2ffaaee73968935c Binary files /dev/null and b/voicebox/src/loss/__pycache__/mrstft.cpython-39.pyc differ diff --git a/voicebox/src/loss/__pycache__/speaker_embedding.cpython-39.pyc b/voicebox/src/loss/__pycache__/speaker_embedding.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..31f1e5a87f4fc290d22e5f4f7ac9748069c69a82 Binary files /dev/null and b/voicebox/src/loss/__pycache__/speaker_embedding.cpython-39.pyc differ diff --git a/voicebox/src/loss/__pycache__/speech_features.cpython-39.pyc b/voicebox/src/loss/__pycache__/speech_features.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..daa1fc2824203dfd98fb200b85f9b26ea8f2ea98 Binary files /dev/null and b/voicebox/src/loss/__pycache__/speech_features.cpython-39.pyc differ diff --git a/voicebox/src/loss/__pycache__/sum.cpython-39.pyc b/voicebox/src/loss/__pycache__/sum.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..358e8dbea2ca4570edc048e2e163dda69d97199a Binary files /dev/null and b/voicebox/src/loss/__pycache__/sum.cpython-39.pyc differ diff --git a/voicebox/src/loss/adversarial.py b/voicebox/src/loss/adversarial.py new file mode 100644 index 0000000000000000000000000000000000000000..f99dbd13cbd41230bc8ca032b556ade8629bc29a --- /dev/null +++ b/voicebox/src/loss/adversarial.py @@ -0,0 +1,26 @@ +import torch + +from src.loss.loss import Loss + +################################################################################ +# Base class for adversarial classification objectives +################################################################################ + + +class AdversarialLoss(Loss): + """ + Wrapper for adversarial losses computed on paired targets. 
Subclasses must + override the method `_compute_loss()` to compute an unreduced batch loss, as + batch reduction is left to `forward()` + """ + + def __init__(self, + targeted: bool = True, + reduction: str = 'none' + ): + super().__init__(reduction) + self.targeted = targeted + + def _compute_loss(self, *args, **kwargs): + raise NotImplementedError() + diff --git a/voicebox/src/loss/auxiliary.py b/voicebox/src/loss/auxiliary.py new file mode 100644 index 0000000000000000000000000000000000000000..e9024d0371e25ebf2cba644878c1d495275fc1b0 --- /dev/null +++ b/voicebox/src/loss/auxiliary.py @@ -0,0 +1,59 @@ +import torch + +from src.loss.loss import Loss + +################################################################################ +# Base class for adversarial auxiliary objectives +################################################################################ + + +class AuxiliaryLoss(Loss): + """ + Wrapper for auxiliary (e.g. perceptual) losses, computed either on inputs + alone ("reference-free") or input-reference pairs ("full reference"). + Subclasses must override the method `_compute_loss()` to compute an + unreduced batch loss, as batch reduction is left to `forward()`. + + Subclasses must also implement the method `set_reference()`, which can be + used to compute and cache references. This may be useful in avoiding + re-computing expensive reference representations, such as the psychoacoustic + thresholds required by a frequency-masking loss. + """ + def __init__(self, + reduction: str = 'none' + ): + super().__init__(reduction) + + def _compute_loss(self, x: torch.Tensor, x_ref: torch.Tensor = None): + """ + Compute unreduced batch loss. + + :param x: input, shape (n_batch, ...) + :param x_ref: reference, shape (n_batch, ...) + :return: loss, shape (n_batch,) + """ + raise NotImplementedError() + + @staticmethod + def _check_broadcastable(x: torch.Tensor, x_ref: torch.Tensor): + """ + Check whether input and reference tensors are broadcastable + """ + + broadcastable = all( + (m == n) or (m == 1) or (n == 1) for m, n in zip( + x.shape[::-1], x_ref.shape[::-1] + ) + ) + + # broadcast cannot expand input batch dimension + valid = x.shape[0] == x_ref.shape[0] or x_ref.shape[0] == 1 + + return broadcastable * valid + + def set_reference(self, x_ref: torch.Tensor): + """ + Compute and cache reference representation(s). + """ + raise NotImplementedError() + diff --git a/voicebox/src/loss/control.py b/voicebox/src/loss/control.py new file mode 100644 index 0000000000000000000000000000000000000000..e929a9dd864d8897459eafc3b469fb2410c44738 --- /dev/null +++ b/voicebox/src/loss/control.py @@ -0,0 +1,110 @@ +import torch +import torch.nn.functional as F + +from src.loss.auxiliary import AuxiliaryLoss + +################################################################################ +# Control signal losses, for regularizing time-varying controls +################################################################################ + + +class ControlSignalLoss(AuxiliaryLoss): + """ + Compute losses to regularize time-varying control signals. 
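As a reference point for the variants implemented below, here is a closed form of the `'l2-slowness'` penalty (up to the small epsilon the implementation adds for stability): the mean squared first difference of the control signal over time, so constant controls cost nothing.

```python
import torch

def l2_slowness(x: torch.Tensor) -> torch.Tensor:
    b, t, c = x.shape  # (batch, time, channels)
    return torch.square(torch.diff(x, dim=1)).sum(dim=(1, 2)) / ((t - 1) * c)

controls = torch.randn(2, 100, 16)
print(l2_slowness(controls))                # one penalty per batch element
print(l2_slowness(torch.ones(2, 100, 16)))  # constant controls cost ~0
```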
+ """ + def __init__(self, + reduction: str = 'none', + loss: str = 'group-sparse-slowness', + transpose: bool = False + ): + + super().__init__(reduction) + + # select loss variant + assert loss in ['l2-slowness', + 'l1-slowness', + 'group-sparse-slowness', + 'l1/2-group-sparsity', + 'l2', + 'l1' + ] + self.loss = loss + self.transpose = transpose + + def _compute_loss(self, x: torch.Tensor, x_ref: torch.Tensor = None): + """Compute specified loss on given control signal""" + + # require (n_batch, time, channels) representation + assert x.ndim == 3 + b, t, c = x.shape + + # if specified, flip time and channel dimensions + if self.transpose: + x = x.permute(0, 2, 1) + + if self.loss == 'l2-slowness': + loss = (1/((t - 1)*c))*torch.sum( + torch.sum( + torch.square( + torch.diff(x, dim=1) + ), + dim=2, + keepdim=True) + 1e-8, + dim=1, + keepdim=True + ).reshape(b) + + elif self.loss == 'l1-slowness': + loss = (1/((t - 1)*c))*torch.sum( + torch.sum( + torch.abs( + torch.diff(x, dim=1) + ), + dim=2, + keepdim=True) + 1e-8, + dim=1, + keepdim=True + ).reshape(b) + + elif self.loss == 'group-sparse-slowness': + loss = (1/((t - 1)*c))*torch.square( + torch.sum( + torch.sqrt( + torch.sum( + torch.square( + torch.diff(x, dim=1) + ), + dim=2, + keepdim=True) + 1e-8 + ), + dim=1, + keepdim=True + ) + ).reshape(b) + + elif self.loss == 'l1/2-group-sparsity': + loss = (1/((t - 1)*c))*torch.sum( + torch.sum( + torch.abs( + torch.diff(x, dim=1) + 1e-8 + )**0.5, + dim=2, + keepdim=True + )**2, + dim=1, + keepdim=True + ).reshape(b) + + elif self.loss == 'l2': + loss = x.norm(dim=(1, 2), p=2).reshape(b) + + elif self.loss == 'l1': + loss = x.norm(dim=(1, 2), p=1).reshape(b) + + else: + raise ValueError(f'Invalid control-signal loss {self.loss}') + + return loss + + def set_reference(self, x_ref: torch.Tensor): + pass diff --git a/voicebox/src/loss/cross_entropy.py b/voicebox/src/loss/cross_entropy.py new file mode 100644 index 0000000000000000000000000000000000000000..5f233eceb784657bf6004ba0472fa7bf0263b546 --- /dev/null +++ b/voicebox/src/loss/cross_entropy.py @@ -0,0 +1,36 @@ +import torch +import torch.nn as nn + +from src.loss.adversarial import AdversarialLoss + +################################################################################ +# Cross-entropy loss +################################################################################ + + +class CELoss(AdversarialLoss): + """ + Measure cross-entropy between categorical (class) distributions + """ + def __init__(self, + targeted: bool = True, + reduction: str = 'none', + ): + super().__init__(targeted, reduction) + + self.loss = nn.CrossEntropyLoss(reduction='none') + + def _compute_loss(self, y_pred: torch.Tensor, y_true: torch.Tensor): + + assert y_pred.device == y_true.device + + assert y_pred.ndim >= 2 and y_pred.shape[-1] >= 2 + + if y_true.ndim >= 2: + y_true = y_true.argmax(dim=-1) + + loss = self.loss(y_pred, y_true) + + if not self.targeted: + loss *= -1 + return loss diff --git a/voicebox/src/loss/cw.py b/voicebox/src/loss/cw.py new file mode 100644 index 0000000000000000000000000000000000000000..29b0d4e0b601583eff5caa3fc5e098c2c8524543 --- /dev/null +++ b/voicebox/src/loss/cw.py @@ -0,0 +1,49 @@ +import torch +import torch.nn.functional as F + +from src.loss.adversarial import AdversarialLoss + +################################################################################ +# Carlini-Wagner loss; measures margin on logits (class scores) +################################################################################ + + +class 
CWLoss(AdversarialLoss): + """ + Penalize margin by which undesired class score(s) exceed desired class + score(s), with a "confidence" parameter determining the margin required to + incur loss + """ + def __init__(self, + targeted: bool = True, + reduction: str = 'none', + confidence: float = 0.0, + ): + super().__init__(targeted, reduction) + self.confidence = confidence + + @staticmethod + def _one_hot_encode(y: torch.tensor, n_classes: int): + return F.one_hot(y.type(torch.long), num_classes=n_classes) + + def _compute_loss(self, y_pred: torch.Tensor, y_true: torch.Tensor): + + assert y_pred.device == y_true.device + + if y_true.shape[-1] != y_pred.shape[-1]: + y_true = self._one_hot_encode(y_true, y_pred.shape[-1]) + + z_target = torch.sum(y_pred * y_true, dim=1) + + z_other = torch.max( + y_pred * (1 - y_true) + ( + torch.min(y_pred, dim=-1)[0] - 1 + ).reshape((-1, 1)) * y_true, dim=1, + )[0] + + if self.targeted: + loss = torch.clamp(z_other - z_target + self.confidence, min=0.) + else: + loss = torch.clamp(z_target - z_other + self.confidence, min=0.) + + return loss diff --git a/voicebox/src/loss/demucs_mrstft.py b/voicebox/src/loss/demucs_mrstft.py new file mode 100644 index 0000000000000000000000000000000000000000..65a76c57ddf9a20ac8a3526fba1328657b390f5f --- /dev/null +++ b/voicebox/src/loss/demucs_mrstft.py @@ -0,0 +1,241 @@ +import torch +import torch.nn.functional as F + +from typing import Collection + +from src.loss.auxiliary import AuxiliaryLoss + +################################################################################ +# Demucs-style multi-resolution STFT loss +################################################################################ + + +class DemucsMRSTFTLoss(AuxiliaryLoss): + """ + Compute multi-resolution spectrogram loss, as proposed by Yamamoto et al. + (https://arxiv.org/abs/1910.11480). Uses linear and log-scaled spectrograms + with spectral convergence loss, as in Defossez et al. + (https://arxiv.org/abs/2006.12847). Code adapted from + https://github.com/facebookresearch/denoiser. 
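A single-resolution sketch of the loss assembled below: a spectral-convergence term plus a log-magnitude term, weighted to match the `factor_sc`/`factor_mag` defaults in the constructor. The real class repeats this at three STFT resolutions and averages; for brevity this sketch also sets `win_length = n_fft`, which the class does not.

```python
import torch

def stft_mag(x: torch.Tensor, n_fft: int = 1024, hop: int = 120) -> torch.Tensor:
    window = torch.hann_window(n_fft)
    s = torch.stft(x, n_fft, hop, n_fft, window, return_complex=True)
    return s.abs().clamp(min=1e-7)

def single_resolution_loss(x: torch.Tensor, x_ref: torch.Tensor) -> torch.Tensor:
    m, m_ref = stft_mag(x), stft_mag(x_ref)
    sc = torch.linalg.norm(m_ref - m) / torch.linalg.norm(m_ref)  # convergence
    log_mag = (m.log() - m_ref.log()).abs().mean()                # log distance
    return 0.1 * sc + 0.1 * log_mag
```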
+ """ + def __init__(self, + reduction: str = 'none', + fft_sizes: Collection = (1024, 2048, 512), + hop_sizes: Collection = (120, 240, 50), + win_lengths: Collection = (600, 1200, 240), + window: str = 'hann', + factor_sc: float = 0.1, + factor_mag: float = 0.1): + super().__init__(reduction) + + assert len(fft_sizes) == len(hop_sizes) == len(win_lengths) + + # store window functions as buffers + for i, win_length in enumerate(win_lengths): + self.register_buffer( + f'window_{i}', + self._get_win_func(window)(win_length) + ) + + # store STFT parameters at each resolution + self.stft_params = list( + zip( + fft_sizes, + hop_sizes, + win_lengths + ) + ) + + # scale losses + self.factor_sc = factor_sc + self.factor_mag = factor_mag + + # prepare to store reference spectrograms + self.ref_mag = None + + @staticmethod + def _get_win_func(win_type: str): + if win_type == 'rectangular': + return lambda m: torch.ones(m) + elif win_type == 'hann': + return lambda m: torch.hann_window(m) + elif win_type == 'hamming': + return lambda m: torch.hamming_window(m) + elif win_type == 'kaiser': + return lambda m: torch.kaiser_window(m) + else: + raise ValueError(f'Invalid window function {win_type}') + + @staticmethod + def _pad(x: torch.Tensor, win_length: int, hop_length: int): + """ + Avoid boundary artifacts by padding inputs before STFT such that all + samples are represented in the same number of spectrogram windows + """ + pad_frames = win_length // hop_length - 1 + pad_len = pad_frames * hop_length + return F.pad(x, (pad_len, pad_len)) + + def _stft(self, + x: torch.Tensor, + fft_size: int, + hop_size: int, + win_length: int, + window: torch.Tensor) -> torch.Tensor: + """ + Perform STFT and convert to magnitude spectrogram. + :param x: waveform audio; shape (n_batch, n_channels, signal_length) + :param fft_size: FFT size in samples + :param hop_size: hop size in samples + :param win_length: window length in samples + :param window: window function + :return: tensor holding magnitude spectrogram; shape + (n_batch, n_channels, n_frames, fft_size // 2 + 1) + """ + + # require batch dimension + assert x.ndim >= 2 + if x.ndim == 2: + x = x.unsqueeze(1) + + # pad to avoid boundary artifacts + x = self._pad(x, win_length, hop_size) + + # reshape to handle multi-channel audio + n_batch, n_channels, signal_length = x.shape + x = x.view(n_batch * n_channels, signal_length) + + # compute STFT + x_stft = torch.stft( + x, + fft_size, + hop_size, + win_length, + window + ) + _, n_freq, n_frames, _ = x_stft.shape + mag_stft = x_stft[..., 0] ** 2 + x_stft[..., 1] ** 2 + + return torch.sqrt( + torch.clamp( + mag_stft, min=1e-7) + ).transpose(-2, -1).view(n_batch, n_channels, n_frames, n_freq) + + @staticmethod + def _spectral_convergence(x_mag: torch.Tensor, + x_ref_mag: torch.Tensor) -> torch.Tensor: + """ + Compute spectral convergence loss between magnitude spectrograms. 
+ + :param x_mag: magnitude spectrogram, shape + (n_batch, n_channels, n_frames, n_freq) + :param x_ref_mag: reference magnitude spectrogram, shape + :return: unreduced batch loss of shape (n_batch, n_channels) + """ + + # require batch dimension + assert x_mag.ndim >= 3 + + # flatten spectrogram dimensions + x_mag = x_mag.reshape(x_mag.shape[0], 1, -1) + x_ref_mag = x_ref_mag.reshape(x_ref_mag.shape[0], 1, -1) + + # numerical stability; otherwise, can end up with NaN gradients + # when x_mag == x_ref_mag + eps = 1e-12 + + return torch.linalg.matrix_norm( + x_ref_mag - x_mag + eps, ord="fro" + ) / torch.linalg.matrix_norm( + x_ref_mag, ord="fro" + ) + + @staticmethod + def _log_magnitude(x_mag: torch.Tensor, + x_ref_mag: torch.Tensor) -> torch.Tensor: + """ + Compute log loss between magnitude spectrograms. + + :param x_mag: magnitude spectrogram, shape + (n_batch, n_channels, n_frames, n_freq) + :param x_ref_mag: reference magnitude spectrogram, shape + :return: unreduced batch loss of shape (n_batch, n_channels) + """ + + assert x_mag.ndim >= 3 + n_batch = x_mag.shape[0] + + return torch.mean( + torch.abs( + torch.log( + x_mag + ).reshape(n_batch, -1) - torch.log( + x_ref_mag + ).reshape(n_batch, -1) + ), + dim=-1 + ) + + def _compute_loss(self, x: torch.Tensor, x_ref: torch.Tensor = None): + """ + Compute multi-resolution spectrogram loss between input and reference, + using both linear and log-scaled spectrograms of varying window lengths. + If no reference is provided, a stored reference will be used. + + :param x: input, shape (n_batch, n_channels, signal_length) + :param x_ref: reference, shape (n_batch, n_channels, signal_length) or + (1, n_channels, signal_length) + :return: loss, shape (n_batch,) + """ + + # require batch dimension + assert x.ndim >= 2 + + if x_ref is not None: + assert x_ref.ndim >= 2 + + # compute magnitude and spectral convergence losses + sc_loss = torch.zeros(x.shape[0]).to(x.device) + mag_loss = torch.zeros(x.shape[0]).to(x.device) + + for i, stft_params in enumerate(self.stft_params): + + window = list(self.buffers())[i] + + # compute input magnitude spectrogram + x_mag = self._stft(x, *stft_params, window) + + # compute or load reference magnitude spectrogram + if x_ref is not None: + x_ref_mag = self._stft(x_ref, *stft_params, window) + else: + x_ref_mag = self.ref_mag[i] + + # check compatibility of input and reference spectrograms + assert self._check_broadcastable( + x_mag, x_ref_mag + ), f"Cannot broadcast inputs of shape {x_mag.shape} " \ + f"with reference of shape {x_ref_mag.shape}" + + sc_loss += self._spectral_convergence(x_mag, x_ref_mag) + mag_loss += self._log_magnitude(x_mag, x_ref_mag) + + sc_loss /= len(self.stft_params) + mag_loss /= len(self.stft_params) + + return sc_loss * self.factor_sc + mag_loss * self.factor_mag + + def set_reference(self, x_ref: torch.Tensor): + + # require batch dimension, discard channel dimension + assert x_ref.ndim >= 2 + + # store reference spectrogram for each scale + self.ref_mag = [] + + for i, stft_params in enumerate(self.stft_params): + + window = list(self.buffers())[i] + + x_ref_mag = self._stft(x_ref, *stft_params, window) + self.ref_mag.append(x_ref_mag) diff --git a/voicebox/src/loss/frequency_masking.py b/voicebox/src/loss/frequency_masking.py new file mode 100644 index 0000000000000000000000000000000000000000..f3f330f7f93e3b97a0dfeb602d32b3d0c268a220 --- /dev/null +++ b/voicebox/src/loss/frequency_masking.py @@ -0,0 +1,520 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + 
+from typing import Tuple
+
+from src.loss.auxiliary import AuxiliaryLoss
+from src.data import DataProperties
+
+################################################################################
+# Frequency masking loss
+################################################################################
+
+
+class FrequencyMaskingLoss(AuxiliaryLoss):
+    """
+    Compute perceptually-inspired frequency masking loss. Given a time-aligned
+    reference and additive perturbation, determine the degree to which the
+    perturbation exceeds psychoacoustic thresholds induced by the reference
+    (and thus the degree to which it is perceptible). Code adapted from
+    https://bit.ly/3JjfSbI.
+    """
+
+    def __init__(self,
+                 reduction: str = 'none',
+                 window_size: int = 2048,
+                 hop_size: int = 512,
+                 normalize: str = None):
+        super().__init__(reduction)
+
+        self.sample_rate = DataProperties.get("sample_rate")
+        self.scale = DataProperties.get("scale")
+
+        self.window_size = window_size
+        self.hop_size = hop_size
+
+        # full-overlap: hop size must divide window size
+        if self.window_size % self.hop_size:
+            raise ValueError(f"Full-overlap: hop size {self.hop_size} must "
+                             f"divide window size {self.window_size}")
+
+        # encapsulate threshold computation
+        self.masker = PsychoacousticMasker(
+            window_size,
+            hop_size,
+            self.sample_rate
+        )
+
+        # normalize incoming audio to deal with loss scale-dependence
+        if normalize not in [None, 'none', 'peak']:
+            raise ValueError(f'Invalid normalization {normalize}')
+        self.normalize = normalize
+
+        # store reference thresholds and PSD maxima to avoid recomputing
+        self.ref_wav = None
+        self.ref_thresh = None
+        self.ref_psd = None
+
+    def _normalize(self, x: torch.Tensor):
+        if self.normalize == "peak":
+            # peak-normalize to 95% of full scale
+            x = (self.scale / torch.max(
+                torch.abs(x) + 1e-8, dim=-1, keepdim=True
+            )[0]) * x * 0.95
+        elif self.normalize in [None, 'none']:
+            pass
+        else:
+            raise ValueError(f'Invalid normalization {self.normalize}')
+        return x
+
+    def _pad(self, x: torch.Tensor):
+        """
+        Avoid boundary artifacts by padding inputs before STFT such that all
+        samples are represented in the same number of spectrogram windows
+        """
+        pad_frames = self.window_size // self.hop_size - 1
+        pad_len = pad_frames * self.hop_size
+        return F.pad(x, (pad_len, pad_len))
+
+    def _stabilized_threshold_and_psd_maximum(self, x_ref: torch.Tensor):
+        """
+        Return batch of stabilized masking thresholds and PSD maxima for
+        reference audio
+
+        :param x_ref: waveform reference inputs of shape (n_batch, ...)
+        :return: tuple holding stabilized masking thresholds and PSD maxima
+        """
+
+        masking_threshold = []
+        psd_maximum = []
+
+        assert x_ref.ndim >= 2  # inputs must have batch dimension
+
+        # apply padding to avoid boundary artifacts
+        x_ref = self._pad(x_ref)
+
+        for x_i in x_ref:
+            mt, pm = self.masker.calculate_threshold_and_psd_maximum(x_i)
+            masking_threshold.append(mt)
+            psd_maximum.append(pm)
+
+        # stabilize loss by canceling out the "10*log" term in power spectral
+        # density maximum and masking threshold
+        masking_threshold_stabilized = 10 ** (
+            torch.cat(masking_threshold, dim=0) * 0.1
+        )
+        psd_maximum_stabilized = 10 ** (torch.cat(psd_maximum, dim=0) * 0.1)
+
+        return masking_threshold_stabilized, psd_maximum_stabilized
+
+    def _masking_hinge_loss(
+            self,
+            perturbation: torch.Tensor,
+            psd_maximum_stabilized: torch.Tensor,
+            masking_threshold_stabilized: torch.Tensor
+    ):
+        """
+        Compute hinge loss between aligned perturbation and frequency-masking
+        thresholds induced by the reference signal
+        """
+
+        n_batch = perturbation.shape[0]
+
+        # calculate approximate power spectral density
+        psd_perturbation = self._approximate_power_spectral_density(
+            perturbation, psd_maximum_stabilized
+        )
+
+        # check that perturbation PSD is broadcastable with stored thresholds
+        assert self._check_broadcastable(
+            psd_perturbation, masking_threshold_stabilized
+        ), f"Cannot broadcast perturbation PSD of shape " \
+           f"{psd_perturbation.shape} with reference thresholds of shape " \
+           f"{masking_threshold_stabilized.shape}"
+
+        # calculate hinge loss per input, averaged over frames
+        loss = nn.functional.relu(
+            psd_perturbation - masking_threshold_stabilized
+        ).view(n_batch, -1).mean(-1)
+
+        return loss
+
+    def _approximate_power_spectral_density(
+            self,
+            perturbation: torch.Tensor,
+            psd_maximum_stabilized: torch.Tensor
+    ):
+        """
+        Approximate the power spectral density of a perturbation
+        """
+
+        # require batch dimension
+        assert perturbation.ndim >= 2
+
+        n_batch = perturbation.shape[0]
+
+        # pad to avoid boundary artifacts
+        perturbation = self._pad(perturbation)
+
+        # compute short-time Fourier transform (STFT)
+        stft_matrix = torch.stft(
+            perturbation.reshape(n_batch, -1),
+            n_fft=self.window_size,
+            hop_length=self.hop_size,
+            win_length=self.window_size,
+            center=False,
+            return_complex=False,
+            window=torch.hann_window(self.window_size).to(perturbation),
+        ).to(perturbation)
+
+        # compute power spectral density (PSD); fixes implementation of Qin et
+        # al. by also considering the square root of gain_factor
+        gain_factor = torch.sqrt(torch.as_tensor(8.0 / 3.0))
+        psd_matrix = torch.sum(
+            torch.square(gain_factor * stft_matrix / self.window_size),
+            dim=-1
+        )
+
+        # approximate normalized PSD using the following formula:
+        # psd_matrix_approximated = 10^((96.0 - psd_matrix_max + psd_matrix)/10)
+        psd_matrix_approximated = pow(
+            10.0, 9.6
+        ) / psd_maximum_stabilized.reshape(-1, 1, 1) * psd_matrix
+
+        return psd_matrix_approximated  # (n_batch, n_freq, n_frames)
+
+    def _compute_loss(self, x: torch.Tensor, x_ref: torch.Tensor = None):
+        """
+        Compute frequency-masking loss.
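+
+        Example usage (a minimal sketch; shapes are illustrative and assume
+        DataProperties reports a 16 kHz sample rate):
+
+            loss_fn = FrequencyMaskingLoss()
+            x_ref = torch.randn(2, 1, 16000)             # clean references
+            x = x_ref + 1e-3 * torch.randn(2, 1, 16000)  # perturbed inputs
+            loss = loss_fn(x, x_ref)                     # shape (2,)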
+
+        :param x: input, shape (n_batch, n_channels, signal_length)
+        :param x_ref: reference, shape (n_batch, n_channels, signal_length) or
+                      (1, n_channels, signal_length)
+        :return: unreduced batch loss, shape (n_batch,)
+        """
+
+        # require batch dimension, discard channel dimension
+        assert x.ndim >= 2
+        n_batch, signal_length = x.shape[0], x.shape[-1]
+        x = x.reshape(n_batch, signal_length)
+        x = self._normalize(x)
+
+        if x_ref is not None:
+            assert x_ref.ndim >= 2
+            x_ref = x_ref.reshape(x_ref.shape[0], x_ref.shape[-1])
+            x_ref = self._normalize(x_ref)
+            masking_threshold, psd_maximum = \
+                self._stabilized_threshold_and_psd_maximum(x_ref)
+        else:
+            x_ref = self.ref_wav
+            masking_threshold, psd_maximum = self.ref_thresh, self.ref_psd
+
+        # check that waveform perturbation computation is broadcastable
+        assert self._check_broadcastable(
+            x, x_ref
+        ), f"Cannot broadcast input waveform of shape {x.shape} " \
+           f"with reference waveform of shape {x_ref.shape}"
+
+        perturbation = x - x_ref  # assume additive waveform perturbation
+
+        loss = self._masking_hinge_loss(
+            perturbation, psd_maximum, masking_threshold
+        )
+
+        return loss
+
+    def set_reference(self, x_ref: torch.Tensor):
+        """
+        Compute and store masking thresholds and PSD maxima for reference input
+
+        :param x_ref: inputs, shape (n_batch, n_channels, signal_length)
+        """
+
+        # require batch dimension, discard channel dimension
+        assert x_ref.ndim >= 2
+        n_batch, signal_length = x_ref.shape[0], x_ref.shape[-1]
+        x_ref = x_ref.clone().detach().reshape(n_batch, signal_length)
+        self.ref_wav = self._normalize(x_ref)
+
+        self.ref_thresh, self.ref_psd = \
+            self._stabilized_threshold_and_psd_maximum(self.ref_wav)
+
+        # do not track gradients for stored references
+        self.ref_wav.requires_grad = False
+        self.ref_thresh.requires_grad = False
+        self.ref_psd.requires_grad = False
+
+
+class PsychoacousticMasker:
+    """
+    Adapted from the Adversarial Robustness Toolbox Imperceptible ASR attack.
+    Implements the psychoacoustic model of Lin and Abdulla (2015) following
+    the simplifications of Qin et al. (2019).
+
+    | Repo link: https://github.com/Trusted-AI/adversarial-robustness-toolbox/
+    | Paper link: Lin and Abdulla (2015), https://www.springer.com/gp/book/9783319079738
+    | Paper link: Qin et al. (2019), http://proceedings.mlr.press/v97/qin19a.html
+    """
+
+    def __init__(self, window_size: int = 2048, hop_size: int = 512, sample_rate: int = 16000) -> None:
+        """
+        Initialization.
+
+        :param window_size: Length of the window. The number of STFT rows is `(window_size // 2 + 1)`.
+        :param hop_size: Number of audio samples between adjacent STFT columns.
+        :param sample_rate: Sampling frequency of audio inputs.
+        """
+        self._window_size = window_size
+        self._hop_size = hop_size
+        self._sample_rate = sample_rate
+
+        # init some private properties for lazy loading
+        self._fft_frequencies = None
+        self._bark = None
+        self._absolute_threshold_hearing = None
+
+    def calculate_threshold_and_psd_maximum(self,
+                                            audio: torch.Tensor
+                                            ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Compute the global masking threshold for an audio input and also
+        return the maximum of its power spectral density. This is the main
+        method to call in order to obtain global masking thresholds for an
+        audio input. Given an audio input, the following steps are performed:
+
+        1. STFT analysis and sound pressure level normalization
+        2. Identification and filtering of maskers
+        3. Calculation of individual masking thresholds
+        4. Calculation of global masking thresholds
+
+        :param audio: Audio samples of shape `(length,)`.
+        :return: Global masking thresholds of shape
+                 `(1, window_size // 2 + 1, frame_length)` and the maximum of
+                 the unnormalized PSD, shape `(1, 1, 1)`.
+        """
+
+        assert audio.ndim <= 1 or audio.shape[0] == 1  # process a single waveform
+
+        # compute normalized PSD estimate frame-by-frame for each input, as
+        # well as the maximum of each input's unnormalized PSD
+        psd_matrix, psd_max = self.power_spectral_density(audio)
+        threshold = torch.zeros_like(psd_matrix)
+
+        # compute masking frequencies frame-by-frame for each input
+        for frame in range(psd_matrix.shape[-1]):
+            # apply methods for finding and filtering maskers
+            maskers, masker_idx = self.filter_maskers(*self.find_maskers(psd_matrix[..., frame]))
+
+            # apply methods for calculating global threshold
+            threshold[..., frame] = self.calculate_global_threshold(
+                self.calculate_individual_threshold(maskers, masker_idx)
+            )
+
+        return threshold, psd_max
+
+    def power_spectral_density(self, audio: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Compute the power spectral density matrix for an audio input.
+
+        :param audio: audio inputs of shape `(signal_len,)`.
+        :return: normalized PSD matrix of shape
+                 `(1, window_size // 2 + 1, frame_length)` and the maximum of
+                 the unnormalized PSD, shape `(1, 1, 1)`.
+        """
+
+        # compute short-time Fourier transform (STFT)
+        stft_matrix = torch.stft(
+            audio.reshape(1, -1),
+            n_fft=self.window_size,
+            hop_length=self.hop_size,
+            win_length=self.window_size,
+            center=False,
+            return_complex=True,
+            window=torch.hann_window(self.window_size).to(audio.device),
+        ).to(audio.device)
+
+        # compute power spectral density (PSD)
+        # note: fixes implementation of Qin et al. by also considering the square root of gain_factor
+        gain_factor = torch.sqrt(torch.as_tensor(8.0 / 3.0))
+        psd_matrix = 20 * torch.log10(torch.abs(gain_factor * stft_matrix / self.window_size))
+        psd_matrix = psd_matrix.clamp(min=-200)
+
+        # normalize PSD at 96dB
+        psd_matrix_max = torch.amax(psd_matrix, dim=[d for d in range(1, psd_matrix.ndim)], keepdim=True)
+        psd_matrix_normalized = 96.0 - psd_matrix_max + psd_matrix
+
+        return psd_matrix_normalized, psd_matrix_max
+
+    @staticmethod
+    def find_maskers(psd_vector: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Identify maskers. Possible maskers are local PSD maxima. Following
+        Qin et al., all maskers are treated as tonal.
+
+        :param psd_vector: single-frame PSD of shape `(1, window_size // 2 + 1)`.
+        :return: Possible PSD maskers and indices.
+        """
+
+        # find all local maxima in single-frame PSD estimate
+        flat = psd_vector.reshape(-1)
+        left = flat[1:-1] - flat[:-2]
+        right = flat[1:-1] - flat[2:]
+
+        ind = torch.where((left > 0) * (right > 0),
+                          torch.ones_like(left),
+                          torch.zeros_like(left))
+        ind = torch.nn.functional.pad(ind, (1, 1), "constant", 0)
+        masker_idx = torch.nonzero(ind).cpu().reshape(-1)
+
+        # smooth maskers with their direct neighbors
+        psd_maskers = 10 * torch.log10(
+            torch.sum(
+                torch.cat(
+                    [10 ** (psd_vector[..., masker_idx + i] / 10) for i in range(-1, 2)]
+                ),
+                dim=0
+            )
+        )
+
+        return psd_maskers, masker_idx
+
+    def filter_maskers(self,
+                       maskers: torch.Tensor,
+                       masker_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Filter maskers. First, discard all maskers that are below the
+        absolute threshold of hearing.
+        Second, reduce pairs of maskers that are within 0.5 bark distance of
+        each other by keeping the larger masker.
+
+        :param maskers: Masker PSD values.
+        :param masker_idx: Masker indices.
+        :return: Filtered PSD maskers and indices.
+        """
+        # filter on the absolute threshold of hearing
+        # note: deviates from Qin et al. implementation by filtering first on
+        # ATH and only then on bark distance
+        ath_condition = maskers > self.absolute_threshold_hearing.to(maskers)[masker_idx]
+        masker_idx = masker_idx[ath_condition]
+        maskers = maskers[ath_condition]
+
+        # filter on the bark distance
+        bark_condition = torch.ones(masker_idx.shape, dtype=torch.bool)
+        i_prev = 0
+        for i in range(1, len(masker_idx)):
+            # find pairs of maskers that are within 0.5 bark distance of each
+            # other; compare the bark values of the maskers' frequency bins,
+            # not of the loop indices
+            if self.bark[masker_idx[i]] - self.bark[masker_idx[i_prev]] < 0.5:
+                # discard the smaller masker
+                i_todelete, i_prev = (i_prev, i_prev + 1) if maskers[i_prev] < maskers[i] else (i, i_prev)
+                bark_condition[i_todelete] = False
+            else:
+                i_prev = i
+        masker_idx = masker_idx[bark_condition]
+        maskers = maskers[bark_condition]
+
+        return maskers, masker_idx
+
+    @property
+    def window_size(self) -> int:
+        """
+        :return: Window size of the masker.
+        """
+        return self._window_size
+
+    @property
+    def hop_size(self) -> int:
+        """
+        :return: Hop size of the masker.
+        """
+        return self._hop_size
+
+    @property
+    def sample_rate(self) -> int:
+        """
+        :return: Sample rate of the masker.
+        """
+        return self._sample_rate
+
+    @property
+    def fft_frequencies(self) -> torch.Tensor:
+        """
+        :return: Discrete Fourier transform sample frequencies.
+        """
+        if self._fft_frequencies is None:
+            self._fft_frequencies = torch.linspace(0, self.sample_rate / 2, self.window_size // 2 + 1)
+        return self._fft_frequencies
+
+    @property
+    def bark(self) -> torch.Tensor:
+        """
+        :return: Bark scale for discrete Fourier transform sample frequencies.
+        """
+        if self._bark is None:
+            self._bark = 13 * torch.arctan(0.00076 * self.fft_frequencies) + 3.5 * torch.arctan(
+                torch.square(self.fft_frequencies / 7500.0)
+            )
+        return self._bark
+
+    @property
+    def absolute_threshold_hearing(self) -> torch.Tensor:
+        """
+        :return: Absolute threshold of hearing (ATH) for discrete Fourier transform sample frequencies.
+        """
+        if self._absolute_threshold_hearing is None:
+            # ATH applies only to the frequency range 20Hz <= f <= 20kHz
+            # note: deviates from Qin et al. implementation by using the Hz range as valid domain
+            valid_domain = torch.logical_and(20 <= self.fft_frequencies, self.fft_frequencies <= 2e4)
+            freq = self.fft_frequencies[valid_domain] * 0.001
+
+            # outside the valid ATH domain, set values to -infinity; every
+            # possible masker in the bins <= 20Hz then passes the ATH filter,
+            # and the ATH term contributes nothing to the global masking
+            # threshold in those bins
+            self._absolute_threshold_hearing = torch.ones(valid_domain.shape) * -float('inf')
+
+            self._absolute_threshold_hearing[valid_domain] = (
+                3.64 * pow(freq, -0.8) - 6.5 * torch.exp(-0.6 * torch.square(freq - 3.3)) + 0.001 * pow(freq, 4) - 12
+            )
+        return self._absolute_threshold_hearing
+
+    def calculate_individual_threshold(self,
+                                       maskers: torch.Tensor,
+                                       masker_idx: torch.Tensor) -> torch.Tensor:
+        """
+        Calculate individual masking thresholds, with frequency denoted on the
+        bark scale.
+
+        :param maskers: Masker PSD values.
+        :param masker_idx: Masker indices.
+        :return: Individual masking thresholds, shape
+                 `(n_maskers, window_size // 2 + 1)`.
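+
+        For example, a masker at 70 dB gives max(PSD_masker - 40, 0) = 30, so
+        the spread function falls at -27 + 0.37 * 30 = -15.9 dB/bark above the
+        masker, versus a fixed 27 dB/bark below it: louder maskers mask a
+        wider band of higher frequencies.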
+ """ + delta_shift = -6.025 - 0.275 * self.bark + threshold = torch.zeros(masker_idx.shape + self.bark.shape).to(maskers) + + for k, (masker_j, masker) in enumerate(zip(masker_idx, maskers)): + + # critical band rate of the masker + z_j = self.bark[masker_j].to(maskers) + # distance maskees to masker in bark + delta_z = self.bark.to(maskers) - z_j + + # define two-slope spread function: + # if delta_z <= 0, spread_function = 27*delta_z + # if delta_z > 0, spread_function = [-27+0.37*max(PSD_masker-40,0]*delta_z + spread_function = 27 * delta_z + spread_function[delta_z > 0] = (-27 + 0.37 * max(masker - 40, 0)) * delta_z[delta_z > 0] + + # calculate threshold + threshold[k, :] = masker + delta_shift[masker_j] + spread_function + + return threshold + + def calculate_global_threshold(self, individual_threshold): + """ + Calculate global masking threshold. + + :param individual_threshold: Individual masking threshold vector. + :return: Global threshold vector of shape `(window_size // 2 + 1)`. + """ + # note: deviates from Qin et al. implementation by taking the log of the summation, which they do for numerical + # stability of the stage 2 optimization. We stabilize the optimization in the loss itself. + + return 10 * torch.log10( + torch.sum(10 ** (individual_threshold / 10), dim=0) + 10 ** (self.absolute_threshold_hearing.to(individual_threshold) / 10) + ) diff --git a/voicebox/src/loss/l1.py b/voicebox/src/loss/l1.py new file mode 100644 index 0000000000000000000000000000000000000000..d24d9debda71a58765bd514cba3707ef56e5452a --- /dev/null +++ b/voicebox/src/loss/l1.py @@ -0,0 +1,53 @@ +import torch + +from src.loss.auxiliary import AuxiliaryLoss + +################################################################################ +# L1 loss +################################################################################ + + +class L1Loss(AuxiliaryLoss): + + def __init__(self, + reduction: str = 'none', + ): + super().__init__(reduction) + self.ref_wav = None + + def set_reference(self, x_ref: torch.Tensor): + self.ref_wav = x_ref.clone().detach() + + def _compute_loss(self, x: torch.Tensor, x_ref: torch.Tensor = None): + """ + Compute L1 distance between input and reference. If no reference + is provided, a stored reference will be used. If no stored reference is + available, the L1 norm of the input will be returned. + + :param x: input tensor of shape (n_batch, ...) + :param x_ref: reference tensor of shape (n_batch, ...) or (1, ...) 
+        :return: unreduced batch loss, shape (n_batch,)
+        """
+
+        # if no reference is stored or provided, apply the L1 norm to the
+        # input directly
+        if x_ref is None and self.ref_wav is None:
+            x_ref = torch.zeros_like(x)
+
+        # use stored reference if none provided
+        elif x_ref is None:
+            x_ref = self.ref_wav
+
+        # ensure broadcastable inputs
+        assert self._check_broadcastable(
+            x, x_ref
+        ), f"Cannot broadcast inputs of shape {x.shape} " \
+           f"with reference of shape {x_ref.shape}"
+
+        assert x.ndim >= 2  # require batch dimension
+        n_batch = x.shape[0]
+
+        return torch.mean(
+            (x - x_ref).abs().reshape(n_batch, -1),
+            dim=-1
+        )
diff --git a/voicebox/src/loss/l2.py b/voicebox/src/loss/l2.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f1612143be02c54a0250236e3c0446e0e60350b
--- /dev/null
+++ b/voicebox/src/loss/l2.py
@@ -0,0 +1,50 @@
+import torch
+
+from src.loss.auxiliary import AuxiliaryLoss
+
+################################################################################
+# Squared L2 loss
+################################################################################
+
+
+class L2Loss(AuxiliaryLoss):
+
+    def __init__(self,
+                 reduction: str = 'none',
+                 ):
+        super().__init__(reduction)
+        self.ref_wav = None
+
+    def set_reference(self, x_ref: torch.Tensor):
+        self.ref_wav = x_ref.clone().detach()
+
+    def _compute_loss(self, x: torch.Tensor, x_ref: torch.Tensor = None):
+        """
+        Compute log-compressed squared L2 distance, log(1 + ||x - x_ref||^2),
+        between input and reference. If no reference is provided, a stored
+        reference will be used. If no stored reference is available, the
+        squared L2 norm of the input is used in its place.
+
+        :param x: input tensor of shape (n_batch, ...)
+        :param x_ref: reference tensor of shape (n_batch, ...) or (1, ...)
+        :return: unreduced batch loss, shape (n_batch,)
+        """
+
+        # if no reference is stored or provided, apply the L2 norm to the
+        # input directly
+        if x_ref is None and self.ref_wav is None:
+            x_ref = torch.zeros_like(x)
+
+        # use stored reference if none provided
+        elif x_ref is None:
+            x_ref = self.ref_wav
+
+        # ensure broadcastable inputs
+        assert self._check_broadcastable(
+            x, x_ref
+        ), f"Cannot broadcast inputs of shape {x.shape} " \
+           f"with reference of shape {x_ref.shape}"
+
+        assert x.ndim >= 2  # require batch dimension
+        n_batch = x.shape[0]
+
+        return torch.log(1. + torch.sum(torch.square(x_ref - x).reshape(n_batch, -1), dim=-1))
diff --git a/voicebox/src/loss/loss.py b/voicebox/src/loss/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..355acb66925c40e908b8e047fabd6bcde7f0c10b
--- /dev/null
+++ b/voicebox/src/loss/loss.py
@@ -0,0 +1,40 @@
+import torch
+import torch.nn as nn
+
+
+################################################################################
+# Base class for all Loss objects
+################################################################################
+
+
+class Loss(nn.Module):
+    """
+    Base class for all losses (e.g. classification, auxiliary).
Subclasses + must override the method `_compute_loss()` to compute an unreduced batch + loss, as batch reduction is left to `forward()` + """ + + def __init__(self, + reduction: str = 'none' + ): + super().__init__() + + self.reduction = reduction + + def _compute_loss(self, *args, **kwargs): + raise NotImplementedError() + + def forward(self, *args, **kwargs): + + batch_loss = self._compute_loss(*args, **kwargs) + + if self.reduction == 'mean': + return torch.mean(batch_loss) + elif self.reduction == 'sum': + return torch.sum(batch_loss) + elif self.reduction == 'none' or self.reduction is None: + return batch_loss + else: + raise ValueError(f'Invalid reduction {self.reduction}') + + diff --git a/voicebox/src/loss/mfcc_cosine.py b/voicebox/src/loss/mfcc_cosine.py new file mode 100644 index 0000000000000000000000000000000000000000..8eb5a4f13028c9a498920c059531f20005549ad6 --- /dev/null +++ b/voicebox/src/loss/mfcc_cosine.py @@ -0,0 +1,115 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from torchaudio.transforms import MFCC + +from src.loss.auxiliary import AuxiliaryLoss +from src.data import DataProperties + +################################################################################ +# MFCC cosine loss; measures scale-independent spectral distance +################################################################################ + + +class MFCCCosineLoss(AuxiliaryLoss): + """ + Compute frame-wise cosine distance between MFCC representations of an input + and reference. This serves as a scale-independent spectral distance. + """ + def __init__(self, + reduction: str = 'none', + n_mfcc: int = 30, + log_mels: bool = True, + n_mels: int = 30, + win_length: float = 0.025, + hop_length: float = 0.010): + super().__init__(reduction) + + self.sample_rate = DataProperties.get("sample_rate") + + self.win_length = int(win_length * self.sample_rate) + self.hop_length = int(hop_length * self.sample_rate) + + self.mfcc = MFCC(n_mfcc=n_mfcc, + sample_rate=self.sample_rate, + norm='ortho', + log_mels=log_mels, + melkwargs={ + 'n_mels': n_mels, + 'n_fft': self.win_length, + 'win_length': self.win_length, + 'hop_length': self.hop_length, + 'f_min': 20.0, + 'f_max': self.sample_rate // 2, + 'window_fn': torch.hann_window} + ) + self.cos = nn.CosineSimilarity(dim=-2) # compute per frame + self.ref_mfcc = None + + def _pad(self, x: torch.Tensor): + """ + Avoid boundary artifacts by padding inputs before STFT such that all + samples are represented in the same number of spectrogram windows + """ + pad_frames = self.win_length // self.hop_length - 1 + pad_len = pad_frames * self.hop_length + return F.pad(x, (pad_len, pad_len)) + + def _compute_loss(self, x: torch.Tensor, x_ref: torch.Tensor = None): + """ + Compute frame-wise cosine similarity between MFCC representations of + input and reference. If no reference is provided, a stored reference + will be used. 
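+
+        Example usage (a minimal sketch; assumes DataProperties reports a
+        16 kHz sample rate):
+
+            loss_fn = MFCCCosineLoss()
+            x = torch.randn(2, 1, 16000)
+            x_ref = torch.randn(2, 1, 16000)
+            loss = loss_fn(x, x_ref)  # shape (2,), values in [0, 2]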
+ + :param x: input, shape (n_batch, n_channels, signal_length) + :param x_ref: reference, shape (n_batch, n_channels, signal_length) or + (1, n_channels, signal_length) + :return: loss, shape (n_batch,) + """ + + # require batch dimension, discard channel dimension + assert x.ndim >= 2 + n_batch, signal_length = x.shape[0], x.shape[-1] + x = x.reshape(n_batch, signal_length) + + if x_ref is not None: + assert x_ref.ndim >= 2 + n_batch, signal_length = x_ref.shape[0], x_ref.shape[-1] + x_ref = x_ref.reshape(n_batch, signal_length) + + # compute input MFCC representation + x_pad = self._pad(x) + x_mfcc = self.mfcc(x_pad) # (n_batch, n_mfcc, n_frames) + + # compute reference MFCC representation + if x_ref is None: + x_ref_mfcc = self.ref_mfcc + else: + x_ref_pad = self._pad(x_ref) + x_ref_mfcc = self.mfcc(x_ref_pad) + + # ensure broadcastable inputs + assert self._check_broadcastable( + x_mfcc, x_ref_mfcc + ), f"Cannot broadcast inputs of shape {x_mfcc.shape} " \ + f"with reference of shape {x_ref_mfcc.shape}" + + cos_dist = 1-self.cos(x_mfcc, x_ref_mfcc) + + # taking mean along time dimension, rather than sum, prevents signal + # length from influencing loss magnitude + return cos_dist.mean(dim=-1) + + def set_reference(self, x_ref: torch.Tensor): + + # require batch dimension, discard channel dimension + assert x_ref.ndim >= 2 + n_batch, signal_length = x_ref.shape[0], x_ref.shape[-1] + x_ref = x_ref.reshape(n_batch, signal_length) + + self.mfcc.to(x_ref.device) # adopt reference device + + # pad to avoid boundary artifacts + x_ref_pad = self._pad(x_ref).clone().detach() + self.ref_mfcc = self.mfcc(x_ref_pad) diff --git a/voicebox/src/loss/mrstft.py b/voicebox/src/loss/mrstft.py new file mode 100644 index 0000000000000000000000000000000000000000..98baf696c7f60d4ec71e76459a6b4acec7546c90 --- /dev/null +++ b/voicebox/src/loss/mrstft.py @@ -0,0 +1,126 @@ +import torch +import torch.nn.functional as F + +from typing import Iterable + +from src.loss.auxiliary import AuxiliaryLoss + +################################################################################ +# DDSP-style multi-resolution STFT loss +################################################################################ + + +class MRSTFTLoss(AuxiliaryLoss): + """ + Compute multi-resolution spectrogram loss, as proposed by Yamamoto et al. + (https://arxiv.org/abs/1910.11480). Uses linear and log-scaled spectrograms, + as in Engel et al. (https://arxiv.org/abs/2001.04643). 
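+
+    Example usage (a minimal sketch; shapes are illustrative):
+
+        loss_fn = MRSTFTLoss(reduction='mean', scales=(1024, 512, 256))
+        x = torch.randn(2, 1, 16000)
+        x_ref = torch.randn(2, 1, 16000)
+        loss = loss_fn(x, x_ref)  # scalar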
+ """ + def __init__(self, + reduction: str = 'none', + scales: Iterable[int] = (4096, 2048, 1024, 512, 256, 128), + overlap: float = 0.75): + super().__init__(reduction) + + self.scales = scales + self.overlap = overlap + + self.ref_wav = None + self.ref_stft = None + + @staticmethod + def _safe_log(x: torch.Tensor): + return torch.log(x + 1e-7) + + @staticmethod + def _pad(x: torch.Tensor, win_length: int, hop_length: int): + """ + Avoid boundary artifacts by padding inputs before STFT such that all + samples are represented in the same number of spectrogram windows + """ + pad_frames = win_length // hop_length - 1 + pad_len = pad_frames * hop_length + return F.pad(x, (pad_len, pad_len)) + + def _stft(self, x: torch.Tensor, scale: int): + + # pad input to avoid boundary artifacts + x_pad = self._pad( + x, + win_length=scale, + hop_length=int(scale * (1 - self.overlap)) + ) + + # compute STFT at given window length + return torch.stft( + x_pad, + n_fft=scale, + hop_length=int(scale * (1 - self.overlap)), + win_length=scale, + window=torch.hann_window(scale).to(x_pad.device), + center=True, + normalized=True, + return_complex=True + ).abs() + + def set_reference(self, x_ref: torch.Tensor): + + # require batch dimension, discard channel dimension + assert x_ref.ndim >= 2 + n_batch, signal_length = x_ref.shape[0], x_ref.shape[-1] + x_ref = x_ref.reshape(n_batch, signal_length) + + # store reference spectrogram for each scale + self.ref_stft = [] + + for scale in self.scales: + x_ref_stft = self._stft(x_ref, scale) + self.ref_stft.append(x_ref_stft) + + def _compute_loss(self, x: torch.Tensor, x_ref: torch.Tensor = None): + """ + Compute multi-resolution spectrogram loss between input and reference, + using both linear and log-scaled spectrograms of varying window lengths. + If no reference is provided, a stored reference will be used. 
+
+        :param x: input, shape (n_batch, n_channels, signal_length)
+        :param x_ref: reference, shape (n_batch, n_channels, signal_length) or
+                      (1, n_channels, signal_length)
+        :return: loss, shape (n_batch,)
+        """
+
+        # require batch dimension, discard channel dimension
+        assert x.ndim >= 2
+        n_batch, signal_length = x.shape[0], x.shape[-1]
+        x = x.reshape(n_batch, signal_length)
+
+        if x_ref is not None:
+            assert x_ref.ndim >= 2
+            n_batch, signal_length = x_ref.shape[0], x_ref.shape[-1]
+            x_ref = x_ref.reshape(n_batch, signal_length)
+
+        # compute loss on linear and log-scaled spectrograms
+        lin_loss = torch.zeros(x.shape[0]).to(x.device)
+        log_loss = torch.zeros(x.shape[0]).to(x.device)
+
+        for i, scale in enumerate(self.scales):
+
+            x_stft = self._stft(x, scale)
+
+            if x_ref is not None:
+                x_ref_stft = self._stft(x_ref, scale)
+            else:
+                x_ref_stft = self.ref_stft[i]
+
+            # check compatibility of input and reference spectrograms
+            assert self._check_broadcastable(
+                x_stft, x_ref_stft
+            ), f"Cannot broadcast inputs of shape {x_stft.shape} " \
+               f"with reference of shape {x_ref_stft.shape}"
+
+            # average over frequency and time per example, so the returned
+            # loss has shape (n_batch,) rather than collapsing the batch
+            # dimension
+            lin_loss += (x_stft - x_ref_stft).abs().mean(dim=(-2, -1))
+            log_loss += (
+                self._safe_log(x_stft) - self._safe_log(x_ref_stft)
+            ).abs().mean(dim=(-2, -1))
+
+        return lin_loss + log_loss
diff --git a/voicebox/src/loss/speaker_embedding.py b/voicebox/src/loss/speaker_embedding.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0bc48e3a512ebd88de57d4af87df4afcf87bf0e
--- /dev/null
+++ b/voicebox/src/loss/speaker_embedding.py
@@ -0,0 +1,53 @@
+import torch
+
+from src.loss.adversarial import AdversarialLoss
+from src.models.speaker.speaker import EmbeddingDistance
+
+################################################################################
+# Speaker embedding loss; measures distance in embedding space
+################################################################################
+
+
+class SpeakerEmbeddingLoss(AdversarialLoss):
+    def __init__(self,
+                 targeted: bool = True,
+                 reduction: str = 'none',
+                 confidence: float = 0.0,
+                 distance_fn: str = 'cosine',
+                 threshold: float = 0.0,
+                 n_segments: int = 1
+                 ):
+        super().__init__(targeted, reduction)
+
+        self.confidence = torch.tensor(confidence)
+        self.distance_fn = EmbeddingDistance(distance_fn)
+        self.threshold = threshold
+        self.n_segments = max(n_segments, 1)
+
+    def _compute_loss(
+            self,
+            y_pred: torch.Tensor,
+            y_true: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        Given a batch of predicted and ground-truth embeddings, compute
+        distance. It is assumed that `n_segments` embeddings have been produced
+        from each input audio file, and the distance is taken as the mean over
+        all predicted/ground-truth pairs in each tranche of `n_segments`
+        embeddings.
+
+        :param y_pred: shape (n_batch, n_segments, embedding_dim)
+        :param y_true: shape (n_batch, n_segments, embedding_dim)
+        :return: loss, shape (n_batch,)
+        """
+
+        dist = self.distance_fn(y_pred, y_true)
+
+        if self.targeted:
+            loss = torch.clamp(dist - self.threshold + self.confidence, min=0.)
+        else:
+            loss = torch.clamp(self.threshold - dist + self.confidence, min=0.)
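+
+        # illustrative numbers: with threshold=0.25 and confidence=0.0, the
+        # targeted loss max(dist - 0.25, 0) vanishes once the predicted
+        # embedding is within 0.25 of the target, while the untargeted loss
+        # max(0.25 - dist, 0) vanishes once it is pushed at least 0.25 away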
+ + return loss + + diff --git a/voicebox/src/loss/speech_features.py b/voicebox/src/loss/speech_features.py new file mode 100644 index 0000000000000000000000000000000000000000..6da51150f7cae44bdc5bc89c64755b4bb75f95f3 --- /dev/null +++ b/voicebox/src/loss/speech_features.py @@ -0,0 +1,118 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from src.loss.auxiliary import AuxiliaryLoss +from src.models.speech import SpeechRecognitionModel, Wav2Vec2 + + +################################################################################ +# ASR feature-matching loss +################################################################################ + + +class SpeechFeatureLoss(AuxiliaryLoss): + """ + Compute distance at encoded feature representations or token emission + probabilities produced by a pretrained ASR model. For speech audio, these + representations should capture some notion of phonetic similarity. Adapted + from https://bit.ly/3z6EGyR. + """ + def __init__(self, + reduction: str = 'none', + model: SpeechRecognitionModel = SpeechRecognitionModel( + Wav2Vec2() + ), + use_tokens: bool = False + ): + super().__init__(reduction) + + self.ref_feats = None + self.ref_tokens = None + + self.model = model + + # disable gradient computation for ASR model parameters + self.model.eval() + for p in self.model.parameters(): + p.requires_grad = False + + self.use_tokens = use_tokens + + def _compute_loss(self, x: torch.Tensor, x_ref: torch.Tensor = None): + """ + Compute distance at encoded feature representations or token emission + probabilities produced by a pretrained ASR model. + + :param x: input, shape (n_batch, n_channels, signal_length) + :param x_ref: reference, shape (n_batch, n_channels, signal_length) or + (1, n_channels, signal_length) + :return: unreduced batch loss, shape (n_batch,) + """ + + # require batch dimension + assert x.ndim >= 2 + n_batch = x.shape[0] + + # prepare to store unreduced batch loss + loss = torch.zeros(x.shape[0]).to(x.device) + + # compare token emission probabilities + if self.use_tokens: + + x_tokens = self.model(x) + + if x_ref is not None: + x_ref_tokens = self.model(x_ref) + else: + x_ref_tokens = self.ref_tokens + + # check compatibility of input and reference emissions + assert self._check_broadcastable( + x_tokens, x_ref_tokens + ), f"Cannot broadcast inputs of shape {x_tokens.shape} " \ + f"with reference of shape {x_ref_tokens.shape}" + + loss += (x_tokens - x_ref_tokens).reshape( + n_batch, + -1 + ).norm(p=1, dim=-1) + + # compare deep features + else: + + x_feats = self.model.extract_features(x) + + if x_ref is not None: + x_ref_feats = self.model.extract_features(x_ref) + else: + x_ref_feats = self.ref_feats + + for i in range(len(x_feats)): + + # check compatibility of input and reference features + assert self._check_broadcastable( + x_feats[i], x_ref_feats[i] + ), f"Cannot broadcast inputs of shape {x_feats[i].shape} " \ + f"with reference of shape {x_ref_feats[i].shape}" + + loss += (x_feats[i] - x_ref_feats[i]).reshape( + n_batch, + -1 + ).norm(p=1, dim=-1) + + return loss + + def set_reference(self, x_ref: torch.Tensor): + """ + Compute and store deep features and token emission probabilities with + pretrained ASR model. 
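+
+        :param x_ref: reference audio, shape (n_batch, n_channels,
+                      signal_length); subsequent calls of the form
+                      `loss_fn(x)` compare against these cached
+                      representations without re-running the ASR model on
+                      the reference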
+ """ + + # store deep features + self.ref_feats = [ + r.detach() for r in self.model.extract_features(x_ref) + ] + + # store token emission probabilities + self.ref_tokens = self.model(x_ref).detach() diff --git a/voicebox/src/loss/sum.py b/voicebox/src/loss/sum.py new file mode 100644 index 0000000000000000000000000000000000000000..e475995fb65d23c4058b21a48a76c1163e2cacd7 --- /dev/null +++ b/voicebox/src/loss/sum.py @@ -0,0 +1,60 @@ +import torch +import torch.nn as nn +from src.loss.auxiliary import AuxiliaryLoss + +from typing import List + +################################################################################ +# Combine multiple auxiliary losses with a weighted sum +################################################################################ + + +class SumLoss(AuxiliaryLoss): + """ + Calculates a weighted sum of auxiliary loss functions + """ + def __init__(self, reduction: str = 'none'): + super().__init__(reduction=reduction) + + self._loss_functions: nn.ModuleList[nn.Module] = nn.ModuleList() + self._loss_weights: List[float] = [] + + def add_loss_function(self, + loss: AuxiliaryLoss, + weight: float) -> AuxiliaryLoss: + """ + Adds loss function to `_loss_functions` with `_loss_weights` + """ + + assert loss.reduction == 'none', \ + "Losses must provide unreduced batch values" + + self._loss_functions.append(loss) + self._loss_weights.append(weight) + + return self + + def _compute_loss(self, + x: torch.Tensor, + x_ref: torch.Tensor = None) -> torch.Tensor: + """ + Compute weighted sum over all losses + """ + + # require batch dimension + assert x.ndim >= 2 + n_batch = x.shape[0] + + # compute unreduced total batch loss + loss_total = torch.zeros(n_batch).to(x.device) + + for loss, weight in zip(self._loss_functions, self._loss_weights): + loss_total += weight * loss(x, x_ref) + return loss_total + + def set_reference(self, x_ref: torch.Tensor): + """ + Compute and cache reference representation(s) for all stored losses + """ + for loss in self._loss_functions: + loss.set_reference(x_ref) diff --git a/voicebox/src/models/__init__.py b/voicebox/src/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ec2029bc727cee937dcac375bdbc7faa067aa61d --- /dev/null +++ b/voicebox/src/models/__init__.py @@ -0,0 +1,4 @@ +from src.models.model import Model +from src.models.speaker import * +from src.models.speech import * +from src.models.denoiser import * diff --git a/voicebox/src/models/__pycache__/__init__.cpython-310.pyc b/voicebox/src/models/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b0a018d985beb3c838a72c19050e584c1cae0da Binary files /dev/null and b/voicebox/src/models/__pycache__/__init__.cpython-310.pyc differ diff --git a/voicebox/src/models/__pycache__/__init__.cpython-39.pyc b/voicebox/src/models/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3371185cc5839e9903e51b53a40fd8f244f4ce2e Binary files /dev/null and b/voicebox/src/models/__pycache__/__init__.cpython-39.pyc differ diff --git a/voicebox/src/models/__pycache__/model.cpython-310.pyc b/voicebox/src/models/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bf80537e41caaf48acbb09495a3116e66b5ca67f Binary files /dev/null and b/voicebox/src/models/__pycache__/model.cpython-310.pyc differ diff --git a/voicebox/src/models/__pycache__/model.cpython-39.pyc b/voicebox/src/models/__pycache__/model.cpython-39.pyc new 
file mode 100644 index 0000000000000000000000000000000000000000..f58f0da6ac933c947bffce9de801301e03a64e79 Binary files /dev/null and b/voicebox/src/models/__pycache__/model.cpython-39.pyc differ diff --git a/voicebox/src/models/denoiser/__init__.py b/voicebox/src/models/denoiser/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1ad9c7b5ad84eceb52335209e9ba7e510a356aa4 --- /dev/null +++ b/voicebox/src/models/denoiser/__init__.py @@ -0,0 +1 @@ +from src.models.denoiser.demucs import Demucs, load_demucs diff --git a/voicebox/src/models/denoiser/__pycache__/__init__.cpython-310.pyc b/voicebox/src/models/denoiser/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a9dbc0603c13f5443c630201671ab5f471e8b59a Binary files /dev/null and b/voicebox/src/models/denoiser/__pycache__/__init__.cpython-310.pyc differ diff --git a/voicebox/src/models/denoiser/__pycache__/__init__.cpython-39.pyc b/voicebox/src/models/denoiser/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..367378998b362483e4beabd237308c8d55dd7bf4 Binary files /dev/null and b/voicebox/src/models/denoiser/__pycache__/__init__.cpython-39.pyc differ diff --git a/voicebox/src/models/denoiser/demucs/__init__.py b/voicebox/src/models/denoiser/demucs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aaa92a765adf70070a6c251ab41523f63bf0fc68 --- /dev/null +++ b/voicebox/src/models/denoiser/demucs/__init__.py @@ -0,0 +1,2 @@ +from src.models.denoiser.demucs.demucs import Demucs +from src.models.denoiser.demucs.pretrained import load_demucs diff --git a/voicebox/src/models/denoiser/demucs/__pycache__/__init__.cpython-310.pyc b/voicebox/src/models/denoiser/demucs/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..18ed8edc8ccbb39c80b47af1bda6f1183c5e0650 Binary files /dev/null and b/voicebox/src/models/denoiser/demucs/__pycache__/__init__.cpython-310.pyc differ diff --git a/voicebox/src/models/denoiser/demucs/__pycache__/__init__.cpython-39.pyc b/voicebox/src/models/denoiser/demucs/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed531a32f98936c0b37f71a3edd077ccbb4e7cb6 Binary files /dev/null and b/voicebox/src/models/denoiser/demucs/__pycache__/__init__.cpython-39.pyc differ diff --git a/voicebox/src/models/denoiser/demucs/__pycache__/demucs.cpython-310.pyc b/voicebox/src/models/denoiser/demucs/__pycache__/demucs.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3156a0d06f0be5d55363712c7c516351d95f9756 Binary files /dev/null and b/voicebox/src/models/denoiser/demucs/__pycache__/demucs.cpython-310.pyc differ diff --git a/voicebox/src/models/denoiser/demucs/__pycache__/demucs.cpython-39.pyc b/voicebox/src/models/denoiser/demucs/__pycache__/demucs.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..caab0d95685771dcc8b5a2fdd06a6152680b6799 Binary files /dev/null and b/voicebox/src/models/denoiser/demucs/__pycache__/demucs.cpython-39.pyc differ diff --git a/voicebox/src/models/denoiser/demucs/__pycache__/pretrained.cpython-310.pyc b/voicebox/src/models/denoiser/demucs/__pycache__/pretrained.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5e8456501f3858b2cca8f1a44d56d90ed985477 Binary files /dev/null and b/voicebox/src/models/denoiser/demucs/__pycache__/pretrained.cpython-310.pyc 
differ diff --git a/voicebox/src/models/denoiser/demucs/__pycache__/pretrained.cpython-39.pyc b/voicebox/src/models/denoiser/demucs/__pycache__/pretrained.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8bdf098430cb717984361b45609f4b4a11daf533 Binary files /dev/null and b/voicebox/src/models/denoiser/demucs/__pycache__/pretrained.cpython-39.pyc differ diff --git a/voicebox/src/models/denoiser/demucs/__pycache__/resample.cpython-310.pyc b/voicebox/src/models/denoiser/demucs/__pycache__/resample.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..072a0c0f1d51290e4c946e77f58fb4c0d27366fa Binary files /dev/null and b/voicebox/src/models/denoiser/demucs/__pycache__/resample.cpython-310.pyc differ diff --git a/voicebox/src/models/denoiser/demucs/__pycache__/resample.cpython-39.pyc b/voicebox/src/models/denoiser/demucs/__pycache__/resample.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8667b7a4455d57e29f48677fcc37c53f212a9778 Binary files /dev/null and b/voicebox/src/models/denoiser/demucs/__pycache__/resample.cpython-39.pyc differ diff --git a/voicebox/src/models/denoiser/demucs/demucs.py b/voicebox/src/models/denoiser/demucs/demucs.py new file mode 100644 index 0000000000000000000000000000000000000000..f9e254aec89c0ddcd7369660dffe5ac3db258bd5 --- /dev/null +++ b/voicebox/src/models/denoiser/demucs/demucs.py @@ -0,0 +1,346 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +import math + +from src.models.denoiser.demucs.resample import upsample2, downsample2 + +################################################################################ +# DEMUCS U-Net denoiser architecture +################################################################################ + + +class Demucs(nn.Module): + + def __init__(self, + hidden_dim: int = 64, + growth: float = 1.0, + depth: int = 5, + causal: bool = True, + resample: int = 4, + rescale: float = 0.1, + stride_conv: int = 4, + kernel_conv: int = 8, + stride_glu: int = 1, + kernel_glu: int = 1, + original: bool = True, + use_bias: bool = True, + normalize: bool = True + ): + """ + Construct Demucs-like waveform convolutional denoiser architecture. + Adapted from https://github.com/facebookresearch/denoiser. 
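+
+        For reference, the `dns_48` configuration in `pretrained.py` below
+        corresponds to Demucs(hidden_dim=48, depth=5, causal=True, resample=4,
+        stride_conv=4, kernel_conv=8).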
+ + :param growth: "growth" coefficient on the channel size of hidden representations + :param hidden_dim: base channel size of hidden representations + :param depth: number of convolutional blocks in the encoder and decoder + :param causal: if False, use bidirectional LSTM in bottleneck + :param resample: input resampling factor + :param rescale: rescaling factor to apply to all convolutional weights + :param stride_conv: stride of convolutional layers + :param kernel_conv: kernel size of convolutional layers + :param stride_glu: stride of channel-expanding pre-GLU convolutional layers + :param kernel_glu: kernel size of channel-expanding pre-GLU convolutional layers + :param original: if True, use ReLU activation on initial convolutional layer + :param use_bias: if True, use bias in all convolutional layers + :param normalize: if True, normalize input audio + """ + + super().__init__() + + # define forward-pass behaviors + self.original = original + self.causal = causal + self.normalize = normalize + self.resample = resample + + # store for receptive field & valid length computations + self.depth = depth + self.stride_conv = stride_conv + self.kernel_conv = kernel_conv + self.stride_glu = stride_glu + self.kernel_glu = kernel_glu + + assert resample in [1, 2, 4], "Resampling factor must be 1, 2 or 4." + + # construct waveform convolutional encoder and decoder + encoder_blocks = [] + decoder_blocks = [] + + for i in range(depth): + encoder_blocks.append( + self._build_encoder_block( + level=i, + hidden_dim=hidden_dim, + growth=growth, + stride_conv=stride_conv, + kernel_conv=kernel_conv, + stride_glu=stride_glu, + kernel_glu=kernel_glu, + use_relu=original or i, + use_bias=use_bias + ) + ) + decoder_blocks.append( + self._build_decoder_block( + level=depth - i - 1, + hidden_dim=hidden_dim, + growth=growth, + stride_conv=stride_conv, + kernel_conv=kernel_conv, + stride_glu=stride_glu, + kernel_glu=kernel_glu, + use_relu=depth - i - 1 > 0, # omit activation from final decoder layer + use_bias=use_bias + ) + ) + + self.encoder = nn.ModuleList(encoder_blocks) + self.decoder = nn.ModuleList(decoder_blocks) + + # rescale convolutional weights upon initialization + if rescale: + self._rescale_conv(rescale) + + # construct recurrent latent bottleneck + encoder_channels = int(growth * hidden_dim * (2 ** (depth - 1))) + self.rnn = nn.LSTM( + input_size=encoder_channels, + hidden_size=encoder_channels, + num_layers=2, + bidirectional=not causal, + bias=use_bias + ) + + # only apply linear projection for non-causal bidirectional LSTM + if not causal: + self.linear = nn.Linear( + 2*encoder_channels, + encoder_channels, + bias=use_bias + ) + else: + self.linear = nn.Identity() + + def _rescale_conv(self, reference: float): + """ + Rescale all convolutional and transpose-convolutional weights + and biases to reference scale. 
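+
+        For example, a layer whose weight standard deviation is 0.4 with
+        reference 0.1 is divided by sqrt(0.4 / 0.1) = 2; weights and biases
+        are scaled together, so the layer's output shrinks accordingly.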
+ """ + for module in self.modules(): + if isinstance(module, (nn.Conv1d, nn.ConvTranspose1d)): + + std = module.weight.std().detach() + scale = (std / reference)**0.5 + module.weight.data /= scale + if module.bias is not None: + module.bias.data /= scale + + @staticmethod + def _build_encoder_block(level: int, + hidden_dim: int, + growth: float, + stride_conv: int, + kernel_conv: int, + stride_glu: int, + kernel_glu: int, + use_relu: bool, + use_bias: bool) -> nn.Module: + + in_channels = 1 if not level else int(hidden_dim * growth * (2**(level - 1))) + out_channels = int(hidden_dim * growth * (2 ** level)) + + conv = nn.Conv1d( + in_channels, + out_channels, + kernel_conv, + stride=stride_conv, + bias=use_bias + ) + relu = nn.ReLU() if use_relu else nn.Identity() + conv_glu = nn.Conv1d( + out_channels, + 2*out_channels, + kernel_glu, + stride=stride_glu, + padding=kernel_glu//2, + bias=use_bias + ) + glu = nn.GLU(dim=1) + + return nn.Sequential(conv, relu, conv_glu, glu) + + @staticmethod + def _build_decoder_block(level: int, + hidden_dim: int, + growth: float, + stride_conv: int, + kernel_conv: int, + stride_glu: int, + kernel_glu: int, + use_relu: bool, + use_bias: bool) -> nn.Module: + + in_channels = int(hidden_dim * growth * (2 ** level)) + out_channels = 1 if not level else int(hidden_dim * growth * (2**(level - 1))) + + deconv_glu = nn.Conv1d( + in_channels, + 2*in_channels, + kernel_glu, + stride=stride_glu, + padding=kernel_glu//2, + bias=use_bias + ) + glu = nn.GLU(dim=1) + deconv = nn.ConvTranspose1d( + in_channels, + out_channels, + kernel_conv, + stride=stride_conv, + bias=use_bias + ) + relu = nn.ReLU() if use_relu else nn.Identity() + + return nn.Sequential(deconv_glu, glu, deconv, relu) + + @property + def total_stride(self): + return (self.stride_conv * self.stride_glu) ** self.depth // self.resample + + def valid_length(self, length): + """ + Return the nearest valid input length to the model such that there are + no time steps "left over" in a convolution, i.e. for all layers + + input_length - kernel_length % stride_length = 0 + + If the input has a valid length, the corresponding decoded signal + will have exactly the same length. 
+ """ + + # compute length through input resampling operation + length = math.ceil(length * self.resample) + + # compute output length through each encoder layer + for idx in range(self.depth): + length = math.ceil((length - self.kernel_conv) / self.stride_conv) + 1 + length = max(length, 1) + length = math.ceil((length - self.kernel_glu) / self.stride_glu) + 1 + length = max(length, 1) + + # compute output length through each decoder layer, assuming constant + # convolutional kernel + for idx in range(self.depth): + length = (length - 1) * self.stride_conv + self.kernel_conv + + # compute length through output downsampling operation + length = int(math.ceil(length / self.resample)) + return int(length) + + def encode(self, x: torch.Tensor): + """ + Given waveform input, obtain encoder output, discarding intermediate + (skip-connection) outputs + """ + + # require batch, channel dimensions + assert x.ndim >= 2 + + if x.ndim == 2: + x = x.unsqueeze(1) + + # convert to mono audio + x = x.mean(dim=1, keepdim=True) + + # normalize + if self.normalize: + std = x.std(dim=-1, keepdim=True) + x = x / (1e-3 + std) + + # zero-pad end of signal to ensure input and output have same length + length = x.shape[-1] + x = F.pad(x, (0, self.valid_length(length) - length)) + + # upsample input waveform + if self.resample == 2: + x = upsample2(x) + elif self.resample == 4: + x = upsample2(x) + x = upsample2(x) + + # pass through encoder layers + for encode in self.encoder: + x = encode(x) + + return x + + def bottleneck(self, encoded: torch.Tensor): + + encoded = encoded.permute(2, 0, 1) # time, batch, channels + self.rnn.flatten_parameters() + + # per-timestep output, plus final hidden and cell states + out, (hidden_state, cell_state) = self.rnn(encoded) + out = self.linear(out) + out = out.permute(1, 2, 0) + + return out + + def forward(self, x: torch.Tensor, *args, **kwargs): + + # require batch, channel dimensions + assert x.ndim >= 2 + + if x.ndim == 2: + x = x.unsqueeze(1) + + # convert to mono audio + x = x.mean(dim=1, keepdim=True) + + # normalize and store standard deviation for output scaling + if self.normalize: + std = x.std(dim=-1, keepdim=True) + x = x / (1e-3 + std) + else: + std = 1 + + # zero-pad end of signal to ensure input and output have same length + length = x.shape[-1] + x = F.pad(x, (0, self.valid_length(length) - length)) + + # upsample input waveform + if self.resample == 2: + x = upsample2(x) + elif self.resample == 4: + x = upsample2(x) + x = upsample2(x) + + # U-Net architecture: store skip connections from encoder outputs + skips = [] + for encode in self.encoder: + x = encode(x) + skips.append(x) + + # pass through recurrent bottleneck + x = self.bottleneck(x) + + # U-Net architecture: add skip connections to decoder inputs + for decode in self.decoder: + skip = skips.pop(-1) + x = x + skip[..., :x.shape[-1]] + x = decode(x) + + # downsample output waveform + if self.resample == 2: + x = downsample2(x) + elif self.resample == 4: + x = downsample2(x) + x = downsample2(x) + + # trim to original length + x = x[..., :length] + + # restore original scale + return std * x diff --git a/voicebox/src/models/denoiser/demucs/pretrained.py b/voicebox/src/models/denoiser/demucs/pretrained.py new file mode 100644 index 0000000000000000000000000000000000000000..ab87708238c3b202b3d689c90788f41ed42e26bd --- /dev/null +++ b/voicebox/src/models/denoiser/demucs/pretrained.py @@ -0,0 +1,103 @@ +import torch +from pathlib import Path + +from src.constants import MODELS_DIR +from 
src.models.denoiser.demucs.demucs import Demucs + +################################################################################ +# Load pretrained DEMUCS U-Net denoiser +################################################################################ + + +def load_demucs(name: str, pretrained: bool = True): + """Load a pretrained Demucs denoiser model by name.""" + + # causal, hidden dimension 48, trained on DNS dataset + if name == 'dns_48': + config = { + 'hidden_dim': 48, + 'depth': 5, + 'resample': 4, + 'stride_conv': 4, + 'kernel_conv': 8, + 'growth': 1.0, + 'causal': True, + 'normalize': True + } + path = Path(MODELS_DIR) / 'denoiser' / 'demucs' / 'dns_48.pt' + + # causal, hidden dimension 64, trained on DNS dataset + elif name == 'dns_64': + config = { + 'hidden_dim': 64, + 'depth': 5, + 'resample': 4, + 'stride_conv': 4, + 'kernel_conv': 8, + 'growth': 1.0, + 'causal': True, + 'normalize': True + } + path = Path(MODELS_DIR) / 'denoiser' / 'demucs' / 'dns_64.pt' + raise NotImplementedError(f'Demucs model `dns_64` not currently ' + f'supported due to file size') + + # causal, hidden dimension 64 + elif name == 'master_64': + config = { + 'hidden_dim': 64, + 'depth': 5, + 'resample': 4, + 'stride_conv': 4, + 'kernel_conv': 8, + 'growth': 1.0, + 'causal': True, + 'normalize': True + } + path = Path(MODELS_DIR) / 'denoiser' / 'demucs' / 'master_64.pt' + raise NotImplementedError(f'Demucs model `master_64` not currently ' + f'supported due to file size') + + # non-causal, hidden dimension 64, trained on Valentini dataset + elif name == 'valentini_nc': + config = { + 'hidden_dim': 64, + 'depth': 5, + 'resample': 2, + 'stride_conv': 2, + 'kernel_conv': 8, + 'growth': 1.0, + 'causal': False, + 'normalize': True + } + path = Path(MODELS_DIR) / 'denoiser' / 'demucs' / 'valentini_nc.pt' + raise NotImplementedError(f'Demucs model `valentini_nc` not currently ' + f'supported due to file size') + + elif name == 'experimental_small': + + config = { + 'hidden_dim': 48, + 'depth': 3, + 'resample': 4, + 'stride_conv': 4, + 'kernel_conv': 8, + 'growth': 1.0, + 'causal': True, + 'normalize': True, + 'original': False, + 'use_bias': False + } + path = None + + else: + raise ValueError(f'Invalid model name {name}') + + # initialize model + model = Demucs(**config) + + # load pretrained weights from checkpoint file + if pretrained and path is not None: + model.load_state_dict(torch.load(path)) + + return model diff --git a/voicebox/src/models/denoiser/demucs/resample.py b/voicebox/src/models/denoiser/demucs/resample.py new file mode 100644 index 0000000000000000000000000000000000000000..4e1b68f44962848d20a033a4f34e3bf4ff67a0e8 --- /dev/null +++ b/voicebox/src/models/denoiser/demucs/resample.py @@ -0,0 +1,76 @@ +import torch +import torch.nn.functional as F + +import math + +################################################################################ +# Resampling utilities for DEMUCS architecture +################################################################################ + + +def sinc(x: torch.Tensor): + """ + Sinc function. + """ + return torch.where( + x == 0, + torch.tensor(1., device=x.device, dtype=x.dtype), + torch.sin(x) / x + ) + + +def kernel_upsample2(zeros=56): + """ + Compute windowed sinc kernel for upsampling by a factor of 2. 
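+
+    With the default zeros=56, the returned kernel has shape (1, 1, 112); it
+    interpolates the half-sample-offset values, which `upsample2` interleaves
+    with the original samples (a two-phase polyphase filter).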
+ """ + win = torch.hann_window(4 * zeros + 1, periodic=False) + win_odd = win[1::2] + t = torch.linspace(-zeros + 0.5, zeros - 0.5, 2 * zeros) + t *= math.pi + kernel = (sinc(t) * win_odd).view(1, 1, -1) + return kernel + + +def upsample2(x, zeros=56): + """ + Upsample input by a factor of 2 using sinc interpolation. + """ + *other, time = x.shape + kernel = kernel_upsample2(zeros).to(x) + out = F.conv1d( + x.view(-1, 1, time), + kernel, + padding=zeros + )[..., 1:].view(*other, time) + y = torch.stack([x, out], dim=-1) + return y.view(*other, -1) + + +def kernel_downsample2(zeros=56): + """ + Compute windowed sinc kernel for downsampling by a factor of 2. + """ + win = torch.hann_window(4 * zeros + 1, periodic=False) + win_odd = win[1::2] + t = torch.linspace(-zeros + 0.5, zeros - 0.5, 2 * zeros) + t.mul_(math.pi) + kernel = (sinc(t) * win_odd).view(1, 1, -1) + return kernel + + +def downsample2(x, zeros=56): + """ + Downsample input by a factor of 2 using sinc interpolation. + """ + if x.shape[-1] % 2 != 0: + x = F.pad(x, (0, 1)) + x_even = x[..., ::2] + x_odd = x[..., 1::2] + *other, time = x_odd.shape + kernel = kernel_downsample2(zeros).to(x) + out = x_even + F.conv1d( + x_odd.view(-1, 1, time), + kernel, + padding=zeros + )[..., :-1].view(*other, time) + return out.view(*other, -1).mul(0.5) diff --git a/voicebox/src/models/model.py b/voicebox/src/models/model.py new file mode 100644 index 0000000000000000000000000000000000000000..bb1621bdf964a2e115141ac7a3e2fc350732236f --- /dev/null +++ b/voicebox/src/models/model.py @@ -0,0 +1,38 @@ +import torch +from torch import nn + +################################################################################ +# Wrapper for all PyTorch audio classifiers +################################################################################ + + +class Model(nn.Module): + """ + Wrapper class for PyTorch models; provides a consistent interface for + attack algorithms and prediction + """ + + def __init__(self): + """ + Initialize model + """ + super().__init__() + + def forward(self, x: torch.Tensor): + """ + Perform forward pass + """ + raise NotImplementedError() + + def load_weights(self, path: str): + """ + Load weights from checkpoint file + """ + raise NotImplementedError() + + @staticmethod + def match_predict(y_pred: torch.Tensor, y_true: torch.Tensor): + """ + Determine whether target pairs are equivalent + """ + raise NotImplementedError() diff --git a/voicebox/src/models/phoneme/__init__.py b/voicebox/src/models/phoneme/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6cda0147dc05b50d0ef1c5df91edf5049cfb7f3c --- /dev/null +++ b/voicebox/src/models/phoneme/__init__.py @@ -0,0 +1 @@ +from src.models.phoneme.phoneme import PPGEncoder, Delta diff --git a/voicebox/src/models/phoneme/__pycache__/__init__.cpython-310.pyc b/voicebox/src/models/phoneme/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e96efe73568631e8001720ab5edb63ab49f60ae Binary files /dev/null and b/voicebox/src/models/phoneme/__pycache__/__init__.cpython-310.pyc differ diff --git a/voicebox/src/models/phoneme/__pycache__/__init__.cpython-39.pyc b/voicebox/src/models/phoneme/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a8b8605eea715c5b96e0d2551ed272cb7f226b4 Binary files /dev/null and b/voicebox/src/models/phoneme/__pycache__/__init__.cpython-39.pyc differ diff --git 
a/voicebox/src/models/phoneme/__pycache__/phoneme.cpython-310.pyc b/voicebox/src/models/phoneme/__pycache__/phoneme.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f8e1c262faf808410a218d97968d319f6a86cea Binary files /dev/null and b/voicebox/src/models/phoneme/__pycache__/phoneme.cpython-310.pyc differ diff --git a/voicebox/src/models/phoneme/__pycache__/phoneme.cpython-39.pyc b/voicebox/src/models/phoneme/__pycache__/phoneme.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd550cd4db9d37a859ee0af2eb66f84c20619a3f Binary files /dev/null and b/voicebox/src/models/phoneme/__pycache__/phoneme.cpython-39.pyc differ diff --git a/voicebox/src/models/phoneme/phoneme.py b/voicebox/src/models/phoneme/phoneme.py new file mode 100644 index 0000000000000000000000000000000000000000..7f12b1eca019bf6238b72febf4aad00ae612b13d --- /dev/null +++ b/voicebox/src/models/phoneme/phoneme.py @@ -0,0 +1,117 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from torchaudio.transforms import MFCC + +from typing import Callable + +from src.data import DataProperties + +################################################################################ +# Phoneme predictor model from AC-VC (Ronssin & Cernak) +################################################################################ + + +class Delta(nn.Module): + """Causal delta computation""" + def forward(self, x: torch.Tensor): + + x = F.pad(x, (0, 1)) + x = torch.diff(x, n=1, dim=-1) + + return x + + +class PPGEncoder(nn.Module): + """ + Phonetic posteriorgram (PPG) predictor from Almost-Causal Voice Conversion + """ + def __init__(self, + win_length: int = 256, + hop_length: int = 128, + win_func: Callable = torch.hann_window, + n_mels: int = 32, + n_mfcc: int = 13, + lstm_depth: int = 2, + hidden_size: int = 512 + ): + """ + Parameters + ---------- + + win_length (int): spectrogram window length in samples + + hop_length (int): spectrogram hop length in samples + + win_func (Callable): spectrogram window function + + n_mels (int): number of mel-frequency bins + + n_mfcc (int): number of cepstral coefficients + + lstm_depth (int): number of LSTM layers + + hidden_size (int): hidden layer dimension for MLP and LSTM + """ + + super().__init__() + + self.win_length = win_length + self.hop_length = hop_length + + # compute spectral representation + mel_kwargs = { + "n_fft": self.win_length, + "win_length": self.win_length, + "hop_length": self.hop_length, + "window_fn": win_func, + "n_mels": n_mels + } + self.mfcc = MFCC( + sample_rate=DataProperties.get("sample_rate"), + n_mfcc=n_mfcc, + log_mels=True, + melkwargs=mel_kwargs + ) + + # compute first- and second-order MFCC deltas + self.delta = Delta() + + # PPG network + self.mlp = nn.Sequential( + nn.Linear(n_mfcc * 3, hidden_size), + nn.ReLU(), + nn.Linear(hidden_size, hidden_size), + nn.ReLU() + ) + self.lstm = nn.LSTM( + input_size=hidden_size, + hidden_size=hidden_size, + num_layers=lstm_depth, + bias=True, + batch_first=True, + bidirectional=False + ) + + def forward(self, x: torch.Tensor): + + # require batch, channel dimensions + assert x.ndim >= 2 + if x.ndim == 2: + x = x.unsqueeze(1) + + # convert to mono audio + x = x.mean(dim=1) + + mfcc = self.mfcc(x) # (n_batch, n_mfcc, n_frames) + delta1 = self.delta(mfcc) # (n_batch, n_mfcc, n_frames) + delta2 = self.delta(delta1) # (n_batch, n_mfcc, n_frames) + + x = torch.cat([mfcc, delta1, delta2], dim=1) # (n_batch, 3 * n_mfcc, n_frames) + x = x.permute(0, 2, 1) 
# (n_batch, n_frames, 3 * n_mfcc) + + x = self.mlp(x) # (n_batch, n_frames, hidden_size) + x, _ = self.lstm(x) # (n_batch, n_frames, hidden_size) + + return x \ No newline at end of file diff --git a/voicebox/src/models/speaker/__init__.py b/voicebox/src/models/speaker/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9e2ca91d630eb3e84968e2e09aec146351a1e6ad --- /dev/null +++ b/voicebox/src/models/speaker/__init__.py @@ -0,0 +1,5 @@ +from src.models.speaker.resnetse34v2 import ResNetSE34V2 +from src.models.speaker.resemblyzer import Resemblyzer +from src.models.speaker.yvector import YVector +from src.models.speaker.speaker import EmbeddingDistance, SpeakerVerificationModel + diff --git a/voicebox/src/models/speaker/__pycache__/__init__.cpython-310.pyc b/voicebox/src/models/speaker/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e6fdaa95a2e52de47b234afcd9dc7ed07b7efefc Binary files /dev/null and b/voicebox/src/models/speaker/__pycache__/__init__.cpython-310.pyc differ diff --git a/voicebox/src/models/speaker/__pycache__/__init__.cpython-39.pyc b/voicebox/src/models/speaker/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e4dcaee4035247ba24dd924d474cf6b2616d74d Binary files /dev/null and b/voicebox/src/models/speaker/__pycache__/__init__.cpython-39.pyc differ diff --git a/voicebox/src/models/speaker/__pycache__/speaker.cpython-310.pyc b/voicebox/src/models/speaker/__pycache__/speaker.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40d1afaaab40d1be30ff3683df62ddf70a3b7fea Binary files /dev/null and b/voicebox/src/models/speaker/__pycache__/speaker.cpython-310.pyc differ diff --git a/voicebox/src/models/speaker/__pycache__/speaker.cpython-39.pyc b/voicebox/src/models/speaker/__pycache__/speaker.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5cfffa7063f7f720bff791482fb451006e414ce2 Binary files /dev/null and b/voicebox/src/models/speaker/__pycache__/speaker.cpython-39.pyc differ diff --git a/voicebox/src/models/speaker/resemblyzer/__init__.py b/voicebox/src/models/speaker/resemblyzer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5155b2b61faa8495ac1834f1724c25bcc65526ae --- /dev/null +++ b/voicebox/src/models/speaker/resemblyzer/__init__.py @@ -0,0 +1 @@ +from src.models.speaker.resemblyzer.resemblyzer import Resemblyzer diff --git a/voicebox/src/models/speaker/resemblyzer/__pycache__/__init__.cpython-310.pyc b/voicebox/src/models/speaker/resemblyzer/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad8f0e996ed045fe86237644650937357a142c5e Binary files /dev/null and b/voicebox/src/models/speaker/resemblyzer/__pycache__/__init__.cpython-310.pyc differ diff --git a/voicebox/src/models/speaker/resemblyzer/__pycache__/__init__.cpython-39.pyc b/voicebox/src/models/speaker/resemblyzer/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a6f6ed402ac33206515fc0bf74f7190db8617d40 Binary files /dev/null and b/voicebox/src/models/speaker/resemblyzer/__pycache__/__init__.cpython-39.pyc differ diff --git a/voicebox/src/models/speaker/resemblyzer/__pycache__/resemblyzer.cpython-310.pyc b/voicebox/src/models/speaker/resemblyzer/__pycache__/resemblyzer.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..6aa61532930b2142b35b5cf946125acd19a95580 Binary files /dev/null and b/voicebox/src/models/speaker/resemblyzer/__pycache__/resemblyzer.cpython-310.pyc differ diff --git a/voicebox/src/models/speaker/resemblyzer/__pycache__/resemblyzer.cpython-39.pyc b/voicebox/src/models/speaker/resemblyzer/__pycache__/resemblyzer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4e2d143b61579461f7af90620a4da7c5bafa03e Binary files /dev/null and b/voicebox/src/models/speaker/resemblyzer/__pycache__/resemblyzer.cpython-39.pyc differ diff --git a/voicebox/src/models/speaker/resemblyzer/resemblyzer.py b/voicebox/src/models/speaker/resemblyzer/resemblyzer.py new file mode 100644 index 0000000000000000000000000000000000000000..eca26bcd57b0d5039613972836ec604ca0ebf395 --- /dev/null +++ b/voicebox/src/models/speaker/resemblyzer/resemblyzer.py @@ -0,0 +1,91 @@ +import torch +import torch.nn as nn + +from torchaudio.transforms import MelSpectrogram + +from src.data import DataProperties + +from typing import Union + +################################################################################ +# Resemblyzer recurrent (LSTM) speaker embedding model for speaker verification +################################################################################ + + +class Resemblyzer(nn.Module): + + def __init__(self, + hidden_size: int = 256, + embedding_size: int = 256, + layers: int = 3, + win_length: int = 400, + hop_length: int = 160, + n_mels: int = 40, + **kwargs): + """ + Resemblyzer speaker embedding model, based on the system proposed in Wan + et al. 2018 (https://arxiv.org/pdf/1710.10467.pdf). Code adapted from + https://github.com/resemble-ai/Resemblyzer. + + Parameters + ---------- + + hidden_size (int): LSTM hidden dimension + + embedding_size (int): output embedding dimension + + layers (int): number of LSTM layers + + win_length (int): spectrogram window length in samples + + hop_length (int): spectrogram hop length in samples + + n_mels (int): number of mel-frequency bins + """ + super().__init__() + + # mel spectrogram + self.spec = MelSpectrogram( + sample_rate=DataProperties.get('sample_rate'), + n_fft=win_length, + win_length=win_length, + hop_length=hop_length, + n_mels=n_mels, + center=True, + window_fn=torch.hann_window, + pad_mode='reflect', + norm='slaney', + mel_scale='slaney', + power=2.0 + ) + + # simple LSTM network for computing embeddings + self.lstm = nn.LSTM(n_mels, hidden_size, layers, batch_first=True) + self.linear = nn.Linear(hidden_size, embedding_size) + self.relu = nn.ReLU() + + def forward(self, x: torch.Tensor): + """ + Compute speaker embeddings for a batch of utterances.
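+        Input may have shape (n_batch, signal_length) or
+        (n_batch, n_channels, signal_length); multichannel audio is averaged to
+        mono. Returned embeddings have shape (n_batch, embedding_size) and are
+        not unit-normalized.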
+ + Parameters + ---------- + x (Tensor): + + Returns + ------- + emb (Tensor): + + """ + + # require batch dimension + assert x.ndim >= 2 + n_batch, *channel_dims, signal_len = x.shape + + # add channel dimension if necessary + if len(channel_dims) == 0: + x = x.unsqueeze(1) + + # discard channel dimensions + x = x.mean(1) + + # compute mel spectrogram + x = self.spec(x).permute(0, 2, 1) # (n_batch, n_frames, n_mels) + + # extract embeddings from final hidden layer of network + _, (hidden, _) = self.lstm(x) + emb = self.relu(self.linear(hidden[-1])) # (batch_size, embedding_size) + + return emb diff --git a/voicebox/src/models/speaker/resnetse34v2/__init__.py b/voicebox/src/models/speaker/resnetse34v2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0abb337f39049941520aa3c2f0334513996eb933 --- /dev/null +++ b/voicebox/src/models/speaker/resnetse34v2/__init__.py @@ -0,0 +1 @@ +from src.models.speaker.resnetse34v2.resnetse34v2 import ResNetSE34V2 diff --git a/voicebox/src/models/speaker/resnetse34v2/__pycache__/__init__.cpython-310.pyc b/voicebox/src/models/speaker/resnetse34v2/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..31d0564c57d51fff5bad2217c510523a94f2ceed Binary files /dev/null and b/voicebox/src/models/speaker/resnetse34v2/__pycache__/__init__.cpython-310.pyc differ diff --git a/voicebox/src/models/speaker/resnetse34v2/__pycache__/__init__.cpython-39.pyc b/voicebox/src/models/speaker/resnetse34v2/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..148d4ca1abd7a9d55e884b767e4ef7c107bd0f73 Binary files /dev/null and b/voicebox/src/models/speaker/resnetse34v2/__pycache__/__init__.cpython-39.pyc differ diff --git a/voicebox/src/models/speaker/resnetse34v2/__pycache__/resnet_blocks.cpython-310.pyc b/voicebox/src/models/speaker/resnetse34v2/__pycache__/resnet_blocks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..326a7eb4d9c85304faca4bf0b6dd13c49be9ecb5 Binary files /dev/null and b/voicebox/src/models/speaker/resnetse34v2/__pycache__/resnet_blocks.cpython-310.pyc differ diff --git a/voicebox/src/models/speaker/resnetse34v2/__pycache__/resnet_blocks.cpython-39.pyc b/voicebox/src/models/speaker/resnetse34v2/__pycache__/resnet_blocks.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea2b14699cbab09e27112f5da772b45e874c456f Binary files /dev/null and b/voicebox/src/models/speaker/resnetse34v2/__pycache__/resnet_blocks.cpython-39.pyc differ diff --git a/voicebox/src/models/speaker/resnetse34v2/__pycache__/resnetse34v2.cpython-310.pyc b/voicebox/src/models/speaker/resnetse34v2/__pycache__/resnetse34v2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..33ecff65f9ac9a8c6e59b1374b9bba4e84adf17c Binary files /dev/null and b/voicebox/src/models/speaker/resnetse34v2/__pycache__/resnetse34v2.cpython-310.pyc differ diff --git a/voicebox/src/models/speaker/resnetse34v2/__pycache__/resnetse34v2.cpython-39.pyc b/voicebox/src/models/speaker/resnetse34v2/__pycache__/resnetse34v2.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac4d5c5dcf5e38b0de29f8e37082c101fa1d8d77 Binary files /dev/null and b/voicebox/src/models/speaker/resnetse34v2/__pycache__/resnetse34v2.cpython-39.pyc differ diff --git a/voicebox/src/models/speaker/resnetse34v2/resnet_blocks.py b/voicebox/src/models/speaker/resnetse34v2/resnet_blocks.py new file mode 
100644 index 0000000000000000000000000000000000000000..fdc68efe06c7c1e9977b614df7eef3b5a9c3b515 --- /dev/null +++ b/voicebox/src/models/speaker/resnetse34v2/resnet_blocks.py @@ -0,0 +1,129 @@ +import torch.nn as nn + +################################################################################ +# Residual blocks for ResNetSE34V2 architecture and variants +################################################################################ + + +class SEBasicBlock(nn.Module): + """ + Basic block for Squeeze-and-Excitation ResNet architecture: + input -> Conv2d -> ReLU -> BatchNorm -> Conv2D -> BatchNorm -> SE + Residual + V____________________________________________________________________^ + + Here SE refers to the squeeze-and-excitation layer, which aggregates + information across "spatial" axes to parameterize a channel-wise gate. + + Adapted from https://tinyurl.com/pdd3p8ew + """ + expansion = 1 + + def __init__(self, + inplanes, + planes, + stride=1, + downsample=None, + reduction=8): + + super(SEBasicBlock, self).__init__() + + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.se = SELayer(planes, reduction) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.relu(out) + out = self.bn1(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.se(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + return out + + +class SEBottleneck(nn.Module): + """ + Bottleneck block for Squeeze-and-Excitation ResNet architecture. Adapted + from https://tinyurl.com/pdd3p8ew. Not used in the default ResNetSE34V2 + architecture. + """ + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=8): + super(SEBottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.se = SELayer(planes * 4, reduction) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + out = self.se(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class SELayer(nn.Module): + """ + Squeeze-and-Excitation ResNet layer. Adapted from + https://tinyurl.com/pdd3p8ew. Aggregates global spectro-temporal information + by average-pooling the "spatial" dimensions down to one value per channel. + These channel averages are passed through a two-layer feedforward network + and sigmoid activation to obtain a channel-wise gate that is applied to the + original input via multiplication. 
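+    For input of shape (n_batch, channels, height, width), the gate has shape
+    (n_batch, channels, 1, 1) and is broadcast over the spatial dimensions, so
+    the output shape matches the input.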
+ """ + def __init__(self, channel, reduction=8): + super(SELayer, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction), + nn.ReLU(inplace=True), + nn.Linear(channel // reduction, channel), + nn.Sigmoid() + ) + + def forward(self, x): + b, c, _, _ = x.size() + + y = self.avg_pool(x).view(b, c) + y = self.fc(y).view(b, c, 1, 1) + + return x * y diff --git a/voicebox/src/models/speaker/resnetse34v2/resnetse34v2.py b/voicebox/src/models/speaker/resnetse34v2/resnetse34v2.py new file mode 100644 index 0000000000000000000000000000000000000000..697ac187386beaaf10eb2787aa7f966958188469 --- /dev/null +++ b/voicebox/src/models/speaker/resnetse34v2/resnetse34v2.py @@ -0,0 +1,241 @@ +import torchaudio +import torch +import torch.nn as nn +import torch.nn.functional as F + +from src.data import DataProperties +from src.models.speaker.resnetse34v2.resnet_blocks import ( + SEBasicBlock, + SEBottleneck +) + +from typing import Union, Type, Iterable + +################################################################################ +# ResNetSE34V2 spectrogram convolutional model for speaker verification +################################################################################ + + +class PreEmphasis(torch.nn.Module): + """ + Original ResNet34SEV2 pre-emphasis filter implementation; see + https://github.com/clovaai/voxceleb_trainer. Requires two-dimensional + input (n_batch, signal_length) and produces two-dimensional output + """ + + def __init__(self, coef: float = 0.97): + super().__init__() + self.coef = coef + + self.register_buffer( + 'flipped_filter', + torch.FloatTensor([-self.coef, 1.]).unsqueeze(0).unsqueeze(0) + ) + + def forward(self, x: torch.Tensor): + assert x.ndim == 2 + + x = x.unsqueeze(1) + x = F.pad(x, (1, 0), 'reflect') + return F.conv1d(x, self.flipped_filter).squeeze(1) + + +class ResNetSE34V2(nn.Module): + """ + ResNetSE34V2 model proposed in Heo et al. (arXiv: 2009.14153). Code adapted + from https://github.com/clovaai/voxceleb_trainer. + """ + def __init__(self, + block: Type[Union[SEBasicBlock, SEBottleneck]] = SEBasicBlock, + layers: Iterable[int] = (3, 4, 6, 3), + num_filters: Iterable[int] = (32, 64, 128, 256), + nOut: int = 512, + encoder_type: str = 'ASP', + n_mels: int = 64, + log_input: bool = True): + """ + Squeeze-and-Excitation ResNet architecture for speaker embedding. + Accepts waveform audio input of arbitrary length, converts to + spectrogram, and applies four residual/SE convolutional blocks followed + by a linear layer to produce discriminative embeddings. + + :param block: nn.Module subclass representing variant of SE block, as + defined in `src.modules.resnet_blocks.py`. + :param layers: an iterable containing the kernel dimension of each SE + layer; must have length 4 (one entry per SE layer) + :param num_filters: an iterable containing the number of filters / + output channels of each SE layer; must have length 4 + (one entry per SE layer) + :param nOut: final embedding dimension + :param encoder_type: method of aggregating frame-level features into + utterance-level features. 
Must be one of "SAP" + (self-attentive pooling) or "ASP" (attentive + statistics pooling) + :param n_mels: mel bins for spectrogram + :param log_input: if True, apply log to spectrogram + """ + super().__init__() + + # enforce sample rate requirement + if DataProperties.get('sample_rate') != 16000: + raise ValueError(f'Invalid sample rate ' + f'{DataProperties.get("sample_rate")}; ' + f'ResNetSE34V2 requires 16kHz audio') + + assert len(layers) == 4 + assert len(num_filters) == 4 + + self.inplanes = num_filters[0] + self.encoder_type = encoder_type + self.n_mels = n_mels + self.log_input = log_input + + # prior to SE layers, input spectrogram is passed through a "vanilla" + # convolutional layer + self.conv1 = nn.Conv2d( + 1, + num_filters[0], + kernel_size=3, + stride=1, + padding=1 + ) + self.relu = nn.ReLU(inplace=True) + self.bn1 = nn.BatchNorm2d(num_filters[0]) + + self.layer1 = self._make_layer(block, num_filters[0], layers[0]) + self.layer2 = self._make_layer( + block, + num_filters[1], + layers[1], + stride=(2, 2) + ) + self.layer3 = self._make_layer( + block, + num_filters[2], + layers[2], + stride=(2, 2) + ) + self.layer4 = self._make_layer( + block, + num_filters[3], + layers[3], + stride=(2, 2) + ) + + self.instancenorm = nn.InstanceNorm1d(n_mels) + self.torchfb = torch.nn.Sequential( + PreEmphasis(), + torchaudio.transforms.MelSpectrogram( + sample_rate=16000, + n_fft=512, + win_length=400, + hop_length=160, + window_fn=torch.hamming_window, + n_mels=n_mels) + ) + + outmap_size = int(self.n_mels/8) + + # attention block: collapse and restore channel dimension of feature + # maps through 1x1 convolutions, then pass frame/time dimension through + # softmax to obtain a frame-wise weighting of each channel. In this + # case, can also be interpreted as using one attention head per channel? + # + # for more details, see Zhu et al. 
(https://bit.ly/3E10jBT) + self.attention = nn.Sequential( + nn.Conv1d(num_filters[3] * outmap_size, 128, kernel_size=1), + nn.ReLU(), + nn.BatchNorm1d(128), + nn.Conv1d(128, num_filters[3] * outmap_size, kernel_size=1), + nn.Softmax(dim=2), + ) + + # for self-attentive pooling + if self.encoder_type == "SAP": + out_dim = num_filters[3] * outmap_size + # for attentive statistics pooling + elif self.encoder_type == "ASP": + out_dim = num_filters[3] * outmap_size * 2 + else: + raise ValueError(f'Undefined encoder {self.encoder_type}') + + self.fc = nn.Linear(out_dim, nOut) + + # initialize weights + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, + mode='fan_out', + nonlinearity='relu' + ) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def new_parameter(self, *size): + out = nn.Parameter(torch.FloatTensor(*size)) + nn.init.xavier_normal_(out) + return out + + def forward(self, x: torch.Tensor): + + x = self.torchfb(x)+1e-6 + if self.log_input: + x = x.log() + x = self.instancenorm(x).unsqueeze(1) + + x = self.conv1(x) + x = self.relu(x) + x = self.bn1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + # consolidate channel and frequency dimensions, leaving time dimension + x = x.reshape(x.size()[0], -1, x.size()[-1]) # (n_batch, -1, time) + + w = self.attention(x) + + if self.encoder_type == "SAP": + + # apply attention weights and take weighted average over time + x = torch.sum(x * w, dim=2) # (n_batch, outmap_size) + + elif self.encoder_type == "ASP": + + # apply attention weights and take weighted average over time + mu = torch.sum(x * w, dim=2) # (n_batch, outmap_size) + + # compute standard deviation from weighted means + sg = torch.sqrt( + ( + torch.sum((x**2) * w, dim=2) - mu**2 + ).clamp(min=1e-5) + ) # (n_batch, outmap_size) + x = torch.cat((mu, sg), 1) # (n_batch, 2 * outmap_size) + + x = x.view(x.size()[0], -1) + x = self.fc(x) # (n_batch, nOut) + + return x + diff --git a/voicebox/src/models/speaker/speaker.py b/voicebox/src/models/speaker/speaker.py new file mode 100644 index 0000000000000000000000000000000000000000..082b1fb8520913afb1c379875d6ebfe99d0654c1 --- /dev/null +++ b/voicebox/src/models/speaker/speaker.py @@ -0,0 +1,286 @@ +import os + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from src.models.model import Model + +################################################################################ +# Extends Model class for speaker verification +################################################################################ + + +class EmbeddingDistance(nn.Module): + """ + Compute average pair-wise embedding distances over all segments of + corresponding embeddings + """ + + def __init__(self, distance_fn: str = 'cosine'): + super().__init__() + + self.distance_fn = distance_fn + + def forward(self, x1: torch.Tensor, x2: torch.Tensor): + """ + Compute 
mean pairwise distance over all embedding pairs + + :param x1: embeddings of shape (n_batch, n_segments, embedding_dim) + :param x2: embeddings of shape (n_batch, n_segments, embedding_dim) + :return: mean pairwise distances of shape (n_batch,) + """ + + # match device + assert x1.device == x2.device, "Input device mismatch" + + # reshape to (n_batch, n_segments, embedding_dim) + try: + + assert x1.ndim >= 2 + assert x2.ndim >= 2 + + n_batch, embedding_dim = x1.shape[0], x1.shape[-1] + + x1 = x1.reshape(n_batch, -1, embedding_dim) + x2 = x2.reshape(n_batch, -1, embedding_dim) + + assert x1.shape == x2.shape + + except (AssertionError, RuntimeError): + + raise ValueError(f'Invalid input shapes {x1.shape}, ' + f'{x2.shape}; embeddings should have shape ' + f'(n_batch, n_segments, embedding_dim)') + + # compute distances over all segment pairs + if self.distance_fn == 'l2': + + # normalize embeddings for L2 distance + x1 = F.normalize(x1, p=2, dim=-1) + x2 = F.normalize(x2, p=2, dim=-1) + + return torch.cdist( + x1, + x2, + p=2.0 + ).reshape(-1, x1.shape[1] * x2.shape[1]).mean(-1) + + elif self.distance_fn == 'cosine': + + eps = 1e-8 # numerical stability + + dist = [] + + for i in range(n_batch): + + x1_n = x1[i].norm(dim=1)[:, None] + x2_n = x2[i].norm(dim=1)[:, None] + + x1_norm = x1[i] / torch.clamp(x1_n, min=eps) + x2_norm = x2[i] / torch.clamp(x2_n, min=eps) + + sim = torch.mm(x1_norm, x2_norm.transpose(0, 1)) + + dist_mtrx = 1 - sim + + dist.append( + dist_mtrx.reshape(-1, x1.shape[1] * x2.shape[1]).mean(-1) + ) + + return torch.cat(dist, dim=0) + + else: + raise ValueError(f'Invalid embedding distance {self.distance_fn}') + + +class SpeakerVerificationModel(Model): + """ + Perform speaker verification using a distance measured in the embedding + space of a given model. If the distance between utterance embeddings exceeds + a stored threshold, the utterances are assumed to originate from different + speakers. + """ + def __init__(self, + model: nn.Module, + n_segments: int = 1, + segment_select: str = 'lin', + distance_fn: str = 'cosine', + threshold: float = 0.0 + ): + """ + Wrap speaker verification model. + + :param model: a callable nn.Module object that produces speaker + embeddings. For inputs of shape (n_batch, signal_length) + or (n_batch, n_channels, signal_length), must produce + outputs of shape (n_batch, embedding_dim) + :param n_segments: number of segments per utterance from which to + compute speaker embeddings. One embedding is produced + per segment, resulting in outputs of shape + (n_batch, n_segments, embedding_dim) + :param segment_select: method for selecting utterance segments. 
Must be + `lin` (linearly-spaced) or `rand` (random) + :param distance_fn: distance measure for comparing embeddings; must be + `cosine` or `l2` + :param threshold: verification threshold in embedding space, according + to stored distance function + """ + + super().__init__() + + self.model = model + self.n_segments = n_segments + self.threshold = threshold + + frame_len, hop_len = 400, 160 # 25ms frame / 10ms hop at 16kHz + self.segment_frames = 400 # cap input segments at 400 frames + self.segment_len = self.segment_frames * hop_len + frame_len - hop_len + + # check input segmentation method + if segment_select not in ['lin', 'rand']: + raise ValueError(f'Invalid segment selection method' + f' {segment_select}') + self.segment_select = segment_select + + # prepare to compute pair-wise segment embedding distances + self.distance_fn = EmbeddingDistance(distance_fn) + + self.model.eval() + + def _pad_to_length(self, x: torch.Tensor): + """ + Pad audio to stored segment length + """ + if x.shape[-1] < self.segment_len: + return nn.functional.pad( + x, + (0, self.segment_len - x.shape[-1]) + ) + else: + return x + + def _extract_segments(self, x: torch.Tensor): + """ + Divide each utterance into `n_segments` fixed-length segments, spaced + either linearly or at random. Converts input of shape + (n_batch, signal_length) to (n_batch * n_segments, segment_length), + where segments from the same utterance are consecutive along the batch + dimension + """ + + # if `n_segments` is at least 1, extract or trim audio to fixed-length + # segments before computing embeddings + if self.n_segments >= 1: + + # pad to allow a minimum of `n_segments` segments at set hop length + min_hop_length = 400 # fold is very memory-hungry + min_audio_length = self.segment_len + min_hop_length * (self.n_segments - 1) + + if x.shape[-1] < min_audio_length: + x = nn.functional.pad(x, (0, min_audio_length - x.shape[-1])) + + # compute segment indices + if self.segment_select == 'lin': # linear spacing + + hop = ( + x.shape[-1] - self.segment_len + ) // ( + self.n_segments - 1 + ) if self.n_segments > 1 else x.shape[-1] + x = x.unfold( + -1, self.segment_len, hop + )[:, :self.n_segments, :].contiguous() + x = x.reshape(-1, self.segment_len) + + elif self.segment_select == 'rand': + + # slice at fine resolution, and randomly select segments + x = x.unfold( + -1, self.segment_len, min_hop_length + ) + + x = x[:, torch.randperm(x.shape[1]), :] + x = x[:, :self.n_segments, :].contiguous() + x = x.reshape(-1, self.segment_len) + + else: + raise ValueError( + f'Invalid segment selection method {self.segment_select}' + ) + + return x + + def forward(self, x: torch.Tensor): + """ + Compute embeddings for input. If specified, divide inputs into segments + and compute embeddings for each.
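+        Returns embeddings of shape (n_batch, n_segments, embedding_dim);
+        thresholded verification decisions can then be obtained via
+        `match_predict`.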
+ """ + + # reshape audio + assert x.ndim >= 2 # require batch dimension + n_batch, signal_len = x.shape[0], x.shape[-1] + x = x.reshape(n_batch, signal_len) + + # pad to minimum length + x = self._pad_to_length(x) + + # divide into segments + x = self._extract_segments(x) # (n_batch * n_segments, segment_length) + + # compute embeddings + x = self.model(x) + + # group segments by corresponding utterance + x = x.reshape( + n_batch, + 1 if self.n_segments < 1 else self.n_segments, + -1 + ) # (n_batch, n_segments, segment_length) + + return x # return un-normalized embeddings by default + + def match_predict(self, y_pred: torch.Tensor, y_true: torch.Tensor): + """ + Determine whether target pairs are equivalent by checking if embedding + distance falls under stored threshold + """ + + dist = self.distance_fn(y_pred, y_true) + + return dist <= self.threshold + + def load_weights(self, path: str): + """ + Load weights from checkpoint file + """ + + # check if file exists + if not path or not os.path.isfile(path): + raise ValueError(f'Invalid path {path}') + + model_state = self.model.state_dict() + loaded_state = torch.load(path) + + for name, param in loaded_state.items(): + + origname = name + + if name not in model_state: + print("{} is not in the model.".format(origname)) + continue + + if model_state[name].size() != loaded_state[origname].size(): + print( + "Wrong parameter length: {}, model: {}, loaded: {}".format( + origname, + model_state[name].size(), + loaded_state[origname].size() + ) + ) + continue + + model_state[name].copy_(param) + diff --git a/voicebox/src/models/speaker/yvector/__init__.py b/voicebox/src/models/speaker/yvector/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..61326a174eebbb45725bae8702dcfff555f09aeb --- /dev/null +++ b/voicebox/src/models/speaker/yvector/__init__.py @@ -0,0 +1 @@ +from src.models.speaker.yvector.yvector import YVector diff --git a/voicebox/src/models/speaker/yvector/__pycache__/__init__.cpython-310.pyc b/voicebox/src/models/speaker/yvector/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21577e37090197d992da5c40d67adab93c8c53a8 Binary files /dev/null and b/voicebox/src/models/speaker/yvector/__pycache__/__init__.cpython-310.pyc differ diff --git a/voicebox/src/models/speaker/yvector/__pycache__/__init__.cpython-39.pyc b/voicebox/src/models/speaker/yvector/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa3e4a65ce781fb35436483bcbc07b95bd482842 Binary files /dev/null and b/voicebox/src/models/speaker/yvector/__pycache__/__init__.cpython-39.pyc differ diff --git a/voicebox/src/models/speaker/yvector/__pycache__/tdnn.cpython-310.pyc b/voicebox/src/models/speaker/yvector/__pycache__/tdnn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..37aee8f3bf2f06dbfbc4985b77038ec4c9ebbf88 Binary files /dev/null and b/voicebox/src/models/speaker/yvector/__pycache__/tdnn.cpython-310.pyc differ diff --git a/voicebox/src/models/speaker/yvector/__pycache__/tdnn.cpython-39.pyc b/voicebox/src/models/speaker/yvector/__pycache__/tdnn.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f65dd343ae129f478319fa9dc0961f672948bf7 Binary files /dev/null and b/voicebox/src/models/speaker/yvector/__pycache__/tdnn.cpython-39.pyc differ diff --git a/voicebox/src/models/speaker/yvector/__pycache__/wav2spk.cpython-310.pyc 
b/voicebox/src/models/speaker/yvector/__pycache__/wav2spk.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f95378bc628d09556b19ed708a18fccb40b73e75 Binary files /dev/null and b/voicebox/src/models/speaker/yvector/__pycache__/wav2spk.cpython-310.pyc differ diff --git a/voicebox/src/models/speaker/yvector/__pycache__/wav2spk.cpython-39.pyc b/voicebox/src/models/speaker/yvector/__pycache__/wav2spk.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce147413c46558842b6e27670e7fca0cbfde2d4b Binary files /dev/null and b/voicebox/src/models/speaker/yvector/__pycache__/wav2spk.cpython-39.pyc differ diff --git a/voicebox/src/models/speaker/yvector/__pycache__/yvector.cpython-310.pyc b/voicebox/src/models/speaker/yvector/__pycache__/yvector.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d058552468036ff74a907b193118a51eea90dcd9 Binary files /dev/null and b/voicebox/src/models/speaker/yvector/__pycache__/yvector.cpython-310.pyc differ diff --git a/voicebox/src/models/speaker/yvector/__pycache__/yvector.cpython-39.pyc b/voicebox/src/models/speaker/yvector/__pycache__/yvector.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b541b3f164bd5fccffd643c21e1b01a3725a4aae Binary files /dev/null and b/voicebox/src/models/speaker/yvector/__pycache__/yvector.cpython-39.pyc differ diff --git a/voicebox/src/models/speaker/yvector/tdnn.py b/voicebox/src/models/speaker/yvector/tdnn.py new file mode 100644 index 0000000000000000000000000000000000000000..23ecf81126101157a39c786e427b8d6460d767c6 --- /dev/null +++ b/voicebox/src/models/speaker/yvector/tdnn.py @@ -0,0 +1,163 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.utils import weight_norm + +################################################################################ +# TDNN implementation of Zhu et al. (2021) +################################################################################ + + +class TDNNLayer(nn.Module): + + def __init__(self, input_dim, output_dim, + context_size, dilation=1): + ''' + TDNN as defined by https://www.danielpovey.com/files/2015_interspeech_multisplice.pdf + + The affine transformation is not applied globally to all frames, but to + smaller windows with local context + + Context size and dilation determine the frames selected + (although context size is not really defined in the traditional sense) + For example: + context size 5 and dilation 1 is equivalent to [-2,-1,0,1,2] + context size 3 and dilation 2 is equivalent to [-2, 0, 2] + context size 1 and dilation 1 is equivalent to [0] + ''' + super(TDNNLayer, self).__init__() + self.context_size = context_size + self.input_dim = input_dim + self.output_dim = output_dim + self.dilation = dilation + self.kernel = nn.Linear(input_dim*context_size, output_dim) + + def forward(self, inputs): + ''' + input: size (batch, input_features, seq_len) + output: size (batch, output_features, new_seq_len) + ''' + + # ----------Convolution = unfold + matmul + fold + x = inputs + _, d, _ = x.shape + assert (d == self.input_dim), 'Input dimension was wrong. Expected ({}), got ({})'.format(self.input_dim, d) + x = x.unsqueeze(1) + + # Unfold input into smaller temporal contexts + x = F.unfold(x, (self.input_dim, self.context_size), + stride=(self.input_dim, 1), + dilation=(1, self.dilation)) + + # N, output_dim*context_size, new_t = x.shape + x = x.transpose(1, 2) + x = self.kernel(x) # matmul + + # transpose to channel first + x = x.transpose(1, 2) + + return x + + +class TDNNBlock(nn.Module): + + def __init__(self, input_dim, bn_dim, + skip, context_size, dilation=1, + bottleneck=False): + ''' + TDNNBlock + ''' + super(TDNNBlock, self).__init__() + + # bn conv + self.bottleneck = bottleneck + if bottleneck: + self.bnconv1d = nn.Conv1d(input_dim, bn_dim, 1) + self.nonlinear1 = nn.PReLU() + self.norm1 = nn.GroupNorm(1, bn_dim, eps=1e-08) + self.tdnnblock = TDNNLayer(bn_dim, input_dim, context_size, dilation) + else: + self.tdnnblock = TDNNLayer(input_dim, input_dim, context_size, dilation) + + # tdnn + self.nonlinear2 = nn.PReLU() + self.norm2 = nn.GroupNorm(1, input_dim, eps=1e-08) + + # skip connection + self.skip = skip + if self.skip: + self.skip_out = nn.MaxPool1d(kernel_size=context_size, + stride=1, dilation=dilation) + + def forward(self, x): + ''' + input: size (batch, input_features, seq_len) + output: size (batch, output_features, new_seq_len); returned as an + (output, skip) tuple when skip connections are enabled + ''' + out = x + if self.bottleneck: + out = self.nonlinear1(self.bnconv1d(out)) + out = self.norm1(out) + + out = self.nonlinear2(self.tdnnblock(out)) + out = self.norm2(out) + + if self.skip: + skip = self.skip_out(x) + return out, skip + else: + return out + + +class TDNN(nn.Module): + + def __init__(self, filter_dim, input_dim, bn_dim, + skip, context_size=3, layer=9, stack=1, + bottleneck=False): + ''' + stacked TDNN Blocks + ''' + super(TDNN, self).__init__() + + # # BottleNeck Layer + # self.LN = nn.GroupNorm(1, filter_dim, eps=1e-8) + # self.BN_conv = nn.Conv1d(filter_dim, input_dim, 1) + + # Residual Connection + self.skip = skip + + # TDNN for feature extraction + self.receptive_field = 0 + + self.tdnn = nn.ModuleList([]) + for s in range(stack): + for i in range(layer): + self.tdnn.append(TDNNBlock(input_dim, bn_dim, self.skip, + context_size=context_size, dilation=2**i, + bottleneck=bottleneck)) + + if i == 0 and s == 0: + self.receptive_field += context_size + else: + self.receptive_field += (context_size - 1) * 2 ** i + + print("Receptive field: {:3d} frames.".format(self.receptive_field)) + + + def forward(self, x): + ''' + input: size (batch, input_features, seq_len) + output: size (batch, output_features, new_seq_len) + ''' + + # output = self.BN_conv(self.LN(x)) + + # chain TDNN blocks, adding the pooled skip path when enabled + output = x + for i in range(len(self.tdnn)): + if self.skip: + output, skip = self.tdnn[i](output) + output = output + skip + else: + output = self.tdnn[i](output) + + return output diff --git a/voicebox/src/models/speaker/yvector/utils.py b/voicebox/src/models/speaker/yvector/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/voicebox/src/models/speaker/yvector/wav2spk.py b/voicebox/src/models/speaker/yvector/wav2spk.py new file mode 100644 index 0000000000000000000000000000000000000000..2dc694846cef9991763dbdb25421080ac89916e7 --- /dev/null +++ b/voicebox/src/models/speaker/yvector/wav2spk.py @@ -0,0 +1,253 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +################################################################################ +# Wav2Spk implementation of Zhu et al.
(2021) +################################################################################ + + +class Fp32GroupNorm(nn.GroupNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, inputs): + output = F.group_norm( + inputs.float(), + self.num_groups, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(inputs) + + +class Fp32LayerNorm(nn.LayerNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, inputs): + output = F.layer_norm( + inputs.float(), + self.normalized_shape, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(inputs) + + +class TransposeLast(nn.Module): + def __init__(self, deconstruct_idx=None): + super().__init__() + self.deconstruct_idx = deconstruct_idx + + def forward(self, x): + if self.deconstruct_idx is not None: + x = x[self.deconstruct_idx] + return x.transpose(-2, -1) + + +def norm_block(is_layer_norm, dim, affine=True, is_instance_norm=False): + if is_layer_norm: + mod = nn.Sequential( + TransposeLast(), + Fp32LayerNorm(dim, elementwise_affine=affine), + TransposeLast(), + ) + else: + if is_instance_norm: + mod = Fp32GroupNorm(dim, dim, affine=False) # instance norm + else: + mod = Fp32GroupNorm(1, dim, affine=affine) # layer norm + + return mod + + +class ConvFeatureExtractionModel(nn.Module): + def __init__( + self, + conv_layers, + dropout=0.0, + log_compression=False, + skip_connections=False, + residual_scale=0.5, + non_affine_group_norm=False, + activation=nn.ReLU(), + is_instance_norm=True, + ): + super().__init__() + + def block(n_in, n_out, k, stride): + return nn.Sequential( + nn.Conv1d(n_in, n_out, k, stride=stride, bias=False), + nn.Dropout(p=dropout), + norm_block(is_layer_norm=False, + dim=n_out, + affine=not non_affine_group_norm, + is_instance_norm=is_instance_norm), + activation, + ) + + in_d = 1 + self.conv_layers = nn.ModuleList() + for dim, k, stride in conv_layers: + self.conv_layers.append(block(in_d, dim, k, stride)) + in_d = dim + + self.log_compression = log_compression + self.skip_connections = skip_connections + self.residual_scale = math.sqrt(residual_scale) + + def forward(self, x): + # BxT -> BxCxT + # x = x.unsqueeze(1) + + for conv in self.conv_layers: + residual = x + x = conv(x) + if self.skip_connections and x.size(1) == residual.size(1): + tsz = x.size(2) + r_tsz = residual.size(2) + residual = residual[..., :: r_tsz // tsz][..., :tsz] + x = (x + residual) * self.residual_scale + + if self.log_compression: + x = x.abs() + x = x + 1 + x = x.log() + + return x + + +class ZeroPad1d(nn.Module): + def __init__(self, pad_left, pad_right): + super().__init__() + self.pad_left = pad_left + self.pad_right = pad_right + + def forward(self, x): + return F.pad(x, (self.pad_left, self.pad_right)) + + +class ConvAggegator(nn.Module): + def __init__( + self, + conv_layers, + embed, + dropout=0.0, + skip_connections=False, + residual_scale=0.5, + non_affine_group_norm=False, + conv_bias=True, + zero_pad=False, + activation=nn.ReLU(), + ): + super().__init__() + + def block(n_in, n_out, k, stride): + # padding dims only really make sense for stride = 1 + ka = k // 2 + kb = ka - 1 if k % 2 == 0 else ka + + pad = ( + ZeroPad1d(ka + kb, 0) if zero_pad else + nn.ReplicationPad1d((ka + kb, 0)) + ) + + return nn.Sequential( + pad, + nn.Conv1d(n_in, n_out, 
k, stride=stride, bias=conv_bias), + nn.Dropout(p=dropout), + norm_block(False, n_out, affine=not non_affine_group_norm), + activation, + ) + + in_d = embed + self.conv_layers = nn.ModuleList() + self.residual_proj = nn.ModuleList() + for dim, k, stride in conv_layers: + if in_d != dim and skip_connections: + self.residual_proj.append(nn.Conv1d(in_d, dim, 1, bias=False)) + else: + self.residual_proj.append(None) + + self.conv_layers.append(block(in_d, dim, k, stride)) + in_d = dim + self.conv_layers = nn.Sequential(*self.conv_layers) + self.skip_connections = skip_connections + self.residual_scale = math.sqrt(residual_scale) + + def forward(self, x): + for rproj, conv in zip(self.residual_proj, self.conv_layers): + residual = x + x = conv(x) + if self.skip_connections: + if rproj is not None: + residual = rproj(residual) + x = (x + residual) * self.residual_scale + return x + + +class StatsPooling(nn.Module): + def __init__(self): + super(StatsPooling,self).__init__() + + def forward(self, varient_length_tensor): + mean = varient_length_tensor.mean(dim=-1) + std = varient_length_tensor.std(dim=-1) + return torch.cat((mean,std),dim=1) + + +class architecture(nn.Module): + def __init__(self, + feature_enc_layers=[(40, 10, 5), (200, 5, 4), (300, 5, 2)] + + [(512, 3, 2)]*2, + agg_layers=[(512, 3, 1)] * 4): + super(architecture, self).__init__() + # self.ln = nn.GroupNorm(1, 1, eps=1e-8) + + self.feature_extractor = ConvFeatureExtractionModel( + conv_layers=feature_enc_layers, + is_instance_norm=True) + self.temporal_gating = nn.Sequential( + nn.Linear(feature_enc_layers[-1][0], 1), + nn.Sigmoid()) + self.feature_aggregator = ConvAggegator( + conv_layers=agg_layers, + embed=feature_enc_layers[-1][0]) + + self.statspool = StatsPooling() + self.fc1 = nn.Linear(agg_layers[-1][0] * 2, 512) + self.bn = nn.BatchNorm1d(512) + self.fc2 = nn.Linear(512, 128) + + self.lrelu = nn.LeakyReLU(0.2) + + def forward(self, x): + ''' + x: [B, L] + ''' + + if isinstance(self.feature_extractor, nn.modules.container.ModuleList): + z = [] + for encoder in self.feature_extractor: + z.append(encoder(x)) + z = torch.cat(z, dim=1) + else: + z = self.feature_extractor(x) + + # Temporal gating + g = z.permute(0, 2, 1).contiguous().view(-1, z.shape[1]) + g = self.temporal_gating(g).view(z.shape[0], z.shape[2]).unsqueeze(1) + z = z * g + + c = self.feature_aggregator(z) + c = self.statspool(c) + + # To use the regularization on the last two layers for fair comparison, may differ from original wav2spk + x = self.lrelu(self.bn(self.fc1(c))) + x = self.fc2(x) + + return x diff --git a/voicebox/src/models/speaker/yvector/yvector.py b/voicebox/src/models/speaker/yvector/yvector.py new file mode 100644 index 0000000000000000000000000000000000000000..0490a9cf481bf78572284d4fa579445142edff6b --- /dev/null +++ b/voicebox/src/models/speaker/yvector/yvector.py @@ -0,0 +1,184 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from src.models.speaker.yvector.tdnn import TDNNLayer +from src.models.speaker.yvector.wav2spk import ( + ConvFeatureExtractionModel, Fp32GroupNorm, norm_block +) +import numpy as np + +################################################################################ +# Y-Vector implementation of Zhu et al. 
(2021) +################################################################################ + + +class SEBlock(nn.Module): + def __init__(self, channels): + super().__init__() + self.fgate = nn.Sequential(nn.Linear(channels, channels), nn.Sigmoid()) + self.tgate = nn.Sequential(nn.Linear(channels, 1), nn.Sigmoid()) + + def forward(self, x): + + fg = self.fgate(x.mean(dim=-1)) + x = x * fg.unsqueeze(-1) + + tg = x.permute(0, 2, 1).contiguous().view(-1, x.shape[1]) + tg = self.tgate(tg).view(x.shape[0], x.shape[2]).unsqueeze(1) + out = x * tg + + return out + + +class MultiScaleConvFeatureExtractionModel(nn.Module): + def __init__( + self, + dropout=0.0, + non_affine_group_norm=False, + activation=nn.ReLU(),): + super().__init__() + + def block(n_in, n_out, k, stride, padding=0): + return nn.Sequential( + nn.Conv1d(n_in, n_out, k, stride=stride, bias=False, padding=padding), + nn.Dropout(p=dropout), + norm_block(is_layer_norm=False, dim=n_out, affine=not non_affine_group_norm, + is_instance_norm=True), + activation) + + self.conv_front = nn.ModuleList() + + # multi-3: s=18 + self.conv_front.append(nn.Sequential(block(1, 90, 36, 18, 0), block(90, 192, 5, 1, 2))) + self.conv_front.append(nn.Sequential(block(1, 90, 18, 9, 0), block(90, 160, 5, 2, 0))) + self.conv_front.append(nn.Sequential(block(1, 90, 12, 6, 0), block(90, 160, 5, 3, 0))) + + self.skip1 = nn.MaxPool1d(kernel_size=5, stride=8) + self.skip2 = nn.MaxPool1d(kernel_size=3, stride=4, padding=1) + # self.skip3 = nn.MaxPool1d(kernel_size=3, stride=2, padding=1) + + self.conv1 = block(512, 512, 5, 2) + self.conv2 = block(512, 512, 3, 2) + self.conv3 = block(512, 512, 3, 2, padding=2) + + self.am1 = SEBlock(512) + self.am2 = SEBlock(512) + self.am3 = SEBlock(512) + self.am4 = SEBlock(512*3) + + def forward(self, x): + # BxT -> BxCxT + + # wave encoder + enc = [] + ft_shape = [] + for conv in self.conv_front: + enc.append(conv(x)) + ft_shape.append(conv(x).shape[-1]) + + ft_max = np.min(np.array(ft_shape)) + enc = torch.cat((enc[0][:, :, :ft_max], enc[1][:, :, :ft_max], enc[2][:, :, :ft_max]), dim=1) + + # skipping layers + skip1_out = self.skip1(enc) + out = self.conv1(enc) + out = self.am1(out) + skip2_out = self.skip2(out) + out = self.conv2(out) + out = self.am2(out) + # skip3_out = self.skip3(out) + out = self.conv3(out) + out = self.am3(out) + + t_max = np.min(np.array([skip1_out.shape[-1], skip2_out.shape[-1], out.shape[-1]])) + + out = torch.cat((skip1_out[:, :, :t_max], skip2_out[:, :, :t_max], out[:, :, :t_max]), dim=1) + out = self.am4(out) + + return out + + +class TDNN_Block(nn.Module): + def __init__(self, input_dim, output_dim=512, context_size=5, dilation=1, norm='bn', affine=True): + super(TDNN_Block, self).__init__() + if norm == 'bn': + norm_layer = nn.BatchNorm1d(output_dim, affine=affine) + elif norm == 'ln': + # norm_layer = nn.GroupNorm(1, output_dim, affine=affine) + norm_layer = Fp32GroupNorm(1, output_dim, affine=affine) + elif norm == 'in': + norm_layer = nn.GroupNorm(output_dim, output_dim, affine=False) + else: + raise ValueError('Norm should be {bn, ln, in}.') + self.tdnn_layer = nn.Sequential( + TDNNLayer(input_dim, output_dim, context_size, dilation), + norm_layer, + nn.ReLU() + ) + + def forward(self, x): + return self.tdnn_layer(x) + + +class xvecTDNN(nn.Module): + def __init__(self, feature_dim=512, embed_dim=512, norm='bn', p_dropout=0.0): + super(xvecTDNN, self).__init__() + self.tdnn = nn.Sequential( + TDNN_Block(feature_dim, 512, 5, 1, norm=norm), + TDNN_Block(512, 512, 3, 2, norm=norm), + 
TDNN_Block(512, 512, 3, 3, norm=norm), + TDNN_Block(512, 512, 1, 1, norm=norm), + TDNN_Block(512, 1500, 1, 1, norm=norm), + ) + + self.fc1 = nn.Linear(3000, 512) + self.bn = nn.BatchNorm1d(512) + self.dropout_fc1 = nn.Dropout(p=p_dropout) + self.lrelu = nn.LeakyReLU(0.2) + self.fc2 = nn.Linear(512, embed_dim) + + def forward(self, x): + # Note: x must be (batch_size, feat_dim, chunk_len) + x = self.tdnn(x) + + stats = torch.cat((x.mean(dim=2), x.std(dim=2)), dim=1) + + x = self.dropout_fc1(self.lrelu(self.bn(self.fc1(stats)))) + x = self.fc2(x) + + return x + + +class YVector(nn.Module): + def __init__(self, embed_dim=128): + super().__init__() + + self.feature_encoder = MultiScaleConvFeatureExtractionModel() + self.tdnn_aggregator = xvecTDNN(feature_dim=512*3, embed_dim=embed_dim, norm='ln') + + def forward(self, x): + + # require batch and channel dimensions + assert x.ndim >= 2 + + # unpack batch, channel, and time dimensions + n_batch, *channel_dims, signal_len = x.shape + + # add channel dimension if necessary + if len(channel_dims) == 0: + x = x.unsqueeze(1) + + out = self.feature_encoder(x) + out = self.tdnn_aggregator(out) + + return out + + +if __name__ == "__main__": + + model = YVector() + print(model) + wav_input_16khz = torch.randn(4, 1, 48000) + c = model(wav_input_16khz) + print(c.shape) diff --git a/voicebox/src/models/speech/__init__.py b/voicebox/src/models/speech/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0cb16ecc8c9c8798ef43083753a45aea8080e4e1 --- /dev/null +++ b/voicebox/src/models/speech/__init__.py @@ -0,0 +1,4 @@ +from src.models.speech.speech import SpeechRecognitionModel, GreedyCTCDecoder +from src.models.speech.wav2vec2 import Wav2Vec2 +from src.models.speech.hubert import HUBERT +from src.models.speech.deepspeech import DeepSpeech diff --git a/voicebox/src/models/speech/__pycache__/__init__.cpython-310.pyc b/voicebox/src/models/speech/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b671ebf83dbe5446dd06bf9d67b704ce60a2c58 Binary files /dev/null and b/voicebox/src/models/speech/__pycache__/__init__.cpython-310.pyc differ diff --git a/voicebox/src/models/speech/__pycache__/__init__.cpython-39.pyc b/voicebox/src/models/speech/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9afcab8a87261b9caa36932a422608ddae89821 Binary files /dev/null and b/voicebox/src/models/speech/__pycache__/__init__.cpython-39.pyc differ diff --git a/voicebox/src/models/speech/__pycache__/speech.cpython-310.pyc b/voicebox/src/models/speech/__pycache__/speech.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..878adbb2ec323dd31598bf742b8fd8cbfbec9e05 Binary files /dev/null and b/voicebox/src/models/speech/__pycache__/speech.cpython-310.pyc differ diff --git a/voicebox/src/models/speech/__pycache__/speech.cpython-39.pyc b/voicebox/src/models/speech/__pycache__/speech.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5cc6170bcc996017559ee1ab593b7802d237a7e1 Binary files /dev/null and b/voicebox/src/models/speech/__pycache__/speech.cpython-39.pyc differ diff --git a/voicebox/src/models/speech/deepspeech/__init__.py b/voicebox/src/models/speech/deepspeech/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..996c877b7647635ad94004b7cbd808ea8857dea1 --- /dev/null +++ b/voicebox/src/models/speech/deepspeech/__init__.py @@ -0,0 +1 @@ +from
src.models.speech.deepspeech.deepspeech import DeepSpeech diff --git a/voicebox/src/models/speech/deepspeech/__pycache__/__init__.cpython-310.pyc b/voicebox/src/models/speech/deepspeech/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..56713247367368f565831445935d4dadc2a09b39 Binary files /dev/null and b/voicebox/src/models/speech/deepspeech/__pycache__/__init__.cpython-310.pyc differ diff --git a/voicebox/src/models/speech/deepspeech/__pycache__/__init__.cpython-39.pyc b/voicebox/src/models/speech/deepspeech/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ffc60ff9bf9a8c8fefe4de2b85bb6881b026abd Binary files /dev/null and b/voicebox/src/models/speech/deepspeech/__pycache__/__init__.cpython-39.pyc differ diff --git a/voicebox/src/models/speech/deepspeech/__pycache__/deepspeech.cpython-310.pyc b/voicebox/src/models/speech/deepspeech/__pycache__/deepspeech.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2eba312b9427ab30cad1cf5f986203e0322913e6 Binary files /dev/null and b/voicebox/src/models/speech/deepspeech/__pycache__/deepspeech.cpython-310.pyc differ diff --git a/voicebox/src/models/speech/deepspeech/__pycache__/deepspeech.cpython-39.pyc b/voicebox/src/models/speech/deepspeech/__pycache__/deepspeech.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..006264c2b5924f830f916401db9a471fe24d01a3 Binary files /dev/null and b/voicebox/src/models/speech/deepspeech/__pycache__/deepspeech.cpython-39.pyc differ diff --git a/voicebox/src/models/speech/deepspeech/deepspeech.py b/voicebox/src/models/speech/deepspeech/deepspeech.py new file mode 100644 index 0000000000000000000000000000000000000000..d2d7fff04e35c8a93dc9c0de97afca01c8da4cbe --- /dev/null +++ b/voicebox/src/models/speech/deepspeech/deepspeech.py @@ -0,0 +1,364 @@ +import math +from typing import List, Union, Iterable + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchaudio.transforms import Spectrogram +from torch.cuda.amp import autocast + +from src.data import DataProperties + +################################################################################ +# DeepSpeech2 model (Amodei et al.) as implemented by Sean Naren +################################################################################ + + +class SequenceWise(nn.Module): + + def __init__(self, module: nn.Module): + """ + Collapses input of shape (seq_len, n_batch, n_features) to + (seq_len * n_batch, n_features) and applies a nn.Module along the + feature dimension. Allows handling of variable sequence lengths and batch + sizes. + + Parameters + ---------- + module (nn.Module): module to apply to input + """ + super(SequenceWise, self).__init__() + self.module = module + + def forward(self, x: torch.Tensor): + + # assume input shape (seq_len, n_batch, n_features) + t, n = x.size(0), x.size(1) + x = x.view(t * n, -1) + x = self.module(x) + x = x.view(t, n, -1) + return x + + def __repr__(self): + tmpstr = self.__class__.__name__ + ' (\n' + tmpstr += self.module.__repr__() + tmpstr += ')' + return tmpstr + + +class MaskConv(nn.Module): + + def __init__(self, seq_module: nn.Sequential): + """ + Adds padding to the output of each layer in a given convolution stack + based on a set of given lengths. This ensures that the results of the + model do not change when batch sizes change during inference. 
Expects + input with shape (n_batch, n_channels, ???, seq_len) + + Parameters + ---------- + seq_module (nn.Sequential): the sequential module containing the + convolution stack + """ + super(MaskConv, self).__init__() + self.seq_module = seq_module + + def forward(self, x: torch.Tensor, lengths: Iterable): + """ + + Parameters + ---------- + x (Tensor): input with shape (n_batch, n_channels, ???, seq_len) + lengths (list): list of target lengths + + Returns + ------- + masked (Tensor): padded output of convolution stack + lengths (list): list of target lengths + """ + for module in self.seq_module: + x = module(x) + mask = torch.BoolTensor(x.size()).fill_(0) + if x.is_cuda: + mask = mask.cuda() + for i, length in enumerate(lengths): + length = length.item() + if (mask[i].size(2) - length) > 0: + mask[i].narrow(2, length, mask[i].size(2) - length).fill_(1) + x = x.masked_fill(mask, 0) + return x, lengths + + +class InferenceBatchSoftmax(nn.Module): + """Apply softmax along final tensor dimension in inference mode only""" + def forward(self, input_: torch.Tensor): + if not self.training: + return F.softmax(input_, dim=-1) + else: + return input_ + + +class BatchRNN(nn.Module): + """RNN layer with optional batch normalization""" + def __init__(self, + input_size: int, + hidden_size: int, + rnn_type=nn.LSTM, + bidirectional: bool = False, + batch_norm: bool = True): + + super(BatchRNN, self).__init__() + + self.input_size = input_size + self.hidden_size = hidden_size + self.bidirectional = bidirectional + + # apply time-distributed batch normalization + self.batch_norm = SequenceWise( + nn.BatchNorm1d(input_size)) if batch_norm else None + + self.rnn = rnn_type(input_size=input_size, + hidden_size=hidden_size, + bidirectional=bidirectional, + bias=True) + self.num_directions = 2 if bidirectional else 1 + + def flatten_parameters(self): + self.rnn.flatten_parameters() + + def forward(self, x: torch.Tensor, output_lengths: torch.Tensor): + + if self.batch_norm is not None: + x = self.batch_norm(x) + + x = nn.utils.rnn.pack_padded_sequence(x, output_lengths) + x, h = self.rnn(x) + x, _ = nn.utils.rnn.pad_packed_sequence(x) + + # sum forward and backward contexts if bidirectional + if self.bidirectional: + x = x.view( + x.size(0), x.size(1), 2, -1 + ).sum(2).view(x.size(0), x.size(1), -1) # (TxNxH*2) -> (TxNxH) + return x + + +class Lookahead(nn.Module): + """ + Lookahead Convolution Layer for Unidirectional Recurrent Neural Networks + from Wang et al 2016. 
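The bidirectional reduction in `BatchRNN.forward` above sums the forward and backward halves of the hidden dimension instead of concatenating them. A small sanity check of that reshape, with hypothetical sizes:

```python
import torch

T, N, H = 50, 4, 8
out = torch.randn(T, N, 2 * H)            # forward/backward halves concatenated
summed = out.view(T, N, 2, -1).sum(2)     # the reduction used in BatchRNN
fwd, bwd = out[..., :H], out[..., H:]
assert torch.allclose(summed, fwd + bwd)  # (T, N, 2H) -> (T, N, H)
```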
+ """ + def __init__(self, n_features: int, context: int): + """ + Parameters + ---------- + n_features (int): feature dimension + context (int): context length in frames, corresponding to a lookahead + of (context - 1) frames + """ + super(Lookahead, self).__init__() + + assert context > 0, 'Must provide nonzero context length' + + self.context = context + self.n_features = n_features + + # pad to preserve sequence length in output + self.pad = (0, self.context - 1) + + self.conv = nn.Conv1d( + self.n_features, + self.n_features, + kernel_size=self.context, + stride=1, + groups=self.n_features, + padding=0, + bias=False + ) + + def forward(self, x: torch.Tensor): + """ + Parameters + ---------- + x (Tensor): shape (seq_len, n_batch, n_features) + + Returns + ------- + out (Tensor): shape (seq_len, n_batch, n_features) + """ + x = x.transpose(0, 1).transpose(1, 2) + x = F.pad(x, pad=self.pad, value=0) + x = self.conv(x) + x = x.transpose(1, 2).transpose(0, 1).contiguous() + return x + + def __repr__(self): + return self.__class__.__name__ + '(' \ + + 'n_features=' + str(self.n_features) \ + + ', context=' + str(self.context) + ')' + + +class DeepSpeech(nn.Module): + def __init__(self, + window_size: float = 0.02, + window_stride: float = 0.01, + normalize: bool = True): + """ + Parameters + ---------- + + """ + + super().__init__() + + # hard-code to match pre-trained implementation + self.sample_rate = 16000 + self.labels = [ + '_', "'", 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', + 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', + 'W', 'X', 'Y', 'Z', '|'] + self.sep_idx = len(self.labels) - 1 + self.blank_idx = 0 + self.hidden_size = 1024 + self.hidden_layers = 5 + self.lookahead_context = 0 + self.bidirectional: bool = True + self.normalize = normalize + num_classes = len(self.labels) + + # check sample rate + if DataProperties.get("sample_rate") != self.sample_rate: + raise ValueError(f"Incompatible data and model sample rates " + f"{DataProperties.get('sample_rate')}, " + f"{self.sample_rate}") + + # spectrogram processing - matches original Librosa implementation + # (MSE ~1e-11 for 4s audio) + self.spec = Spectrogram( + n_fft=int(self.sample_rate * window_size), + win_length=int(self.sample_rate * window_size), + hop_length=int(self.sample_rate * window_stride), + window_fn=torch.hamming_window, + center=True, + pad_mode='constant', + power=1 + ) + + # convolutional spectrogram encoder (acoustic model) + self.conv = MaskConv(nn.Sequential( + nn.Conv2d(1, 32, kernel_size=(41, 11), stride=(2, 2), padding=(20, 5)), + nn.BatchNorm2d(32), + nn.Hardtanh(0, 20, inplace=True), + nn.Conv2d(32, 32, kernel_size=(21, 11), stride=(2, 1), padding=(10, 5)), + nn.BatchNorm2d(32), + nn.Hardtanh(0, 20, inplace=True) + )) + + # compute RNN input size using conv formula (W - F + 2P)/ S+1 + rnn_input_size = int(math.floor((self.sample_rate * window_size) / 2) + 1) + rnn_input_size = int(math.floor(rnn_input_size + 2 * 20 - 41) / 2 + 1) + rnn_input_size = int(math.floor(rnn_input_size + 2 * 10 - 21) / 2 + 1) + rnn_input_size *= 32 + + # RNN stack + self.rnns = nn.Sequential( + BatchRNN( + input_size=rnn_input_size, + hidden_size=self.hidden_size, + rnn_type=nn.LSTM, + bidirectional=self.bidirectional, + batch_norm=False + ), + *( + BatchRNN( + input_size=self.hidden_size, + hidden_size=self.hidden_size, + rnn_type=nn.LSTM, + bidirectional=self.bidirectional + ) for x in range(self.hidden_layers - 1) + ) + ) + + # post-RNN lookahead (for unidirectional models) + self.lookahead = 
nn.Sequential( + Lookahead(self.hidden_size, context=self.lookahead_context), + nn.Hardtanh(0, 20, inplace=True) + ) if not self.bidirectional else None + + # final time-distributed linear layer for token prediction + fully_connected = nn.Sequential( + nn.BatchNorm1d(self.hidden_size), + nn.Linear(self.hidden_size, num_classes, bias=False) + ) + self.fc = nn.Sequential( + SequenceWise(fully_connected), + ) + self.inference_softmax = InferenceBatchSoftmax() + + def forward(self, x, lengths=None): + """ + Parameters + ---------- + x (Tensor): + + lengths (Tensor): + """ + + # ensure RNN blocks are in train mode to allow backpropagation for + # attack optimization + if not self.rnns.training: + self.rnns.train() + + # require batch, channel dimensions + assert x.ndim >= 2 + n_batch, *channel_dims, signal_len = x.shape + + if x.ndim == 2: + x = x.unsqueeze(1) + + # convert to mono audio + x = x.mean(dim=1, keepdim=True) + + # compute spectrogram + x = self.spec(x) # (n_batch, 1, n_freq, n_frames) + x = torch.log1p(x) + + if self.normalize: + mean = x.mean() + std = x.std() + x = x - mean + x = x / std + + lengths = lengths or torch.full((n_batch,), x.shape[-1], dtype=torch.long) + + lengths = lengths.cpu().int() + output_lengths = self.get_seq_lens(lengths) + x, _ = self.conv(x, output_lengths) + + sizes = x.size() + x = x.view(sizes[0], sizes[1] * sizes[2], sizes[3]) # Collapse feature dimension + x = x.transpose(1, 2).transpose(0, 1).contiguous() # TxNxH + + for rnn in self.rnns: + x = rnn(x, output_lengths) + + if not self.bidirectional: # no need for lookahead layer in bidirectional + x = self.lookahead(x) + + x = self.fc(x) + x = x.transpose(0, 1) + + return x + + def get_seq_lens(self, input_length): + """ + Given a 1D Tensor or Variable containing integer sequence lengths, return a 1D tensor or variable + containing the size sequences that will be output by the network. 
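The conv formula referenced above, (W - F + 2P) / S + 1, is the same arithmetic `get_seq_lens` applies per layer. Along the time axis the two `Conv2d` layers use kernel 11, padding 5, and strides 2 and 1, so only the first layer halves the frame count. A worked sketch with hypothetical frame counts:

```python
import torch

conv_time_params = [(11, 2, 5), (11, 1, 5)]   # (kernel, stride, padding)
lengths = torch.tensor([398, 201])            # hypothetical input frame counts
for k, s, p in conv_time_params:
    lengths = (lengths + 2 * p - (k - 1) - 1) // s + 1
print(lengths)                                # tensor([199, 101])
```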
+ :param input_length: 1D Tensor + :return: 1D Tensor scaled by model + """ + seq_len = input_length + for m in self.conv.modules(): + if type(m) == nn.modules.conv.Conv2d: + seq_len = ((seq_len + 2 * m.padding[1] - m.dilation[1] * (m.kernel_size[1] - 1) - 1) // m.stride[1] + 1) + return seq_len.int() diff --git a/voicebox/src/models/speech/hubert/__init__.py b/voicebox/src/models/speech/hubert/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..206e4192017f5f8ec9ff6de20324037a75cb27e7 --- /dev/null +++ b/voicebox/src/models/speech/hubert/__init__.py @@ -0,0 +1 @@ +from src.models.speech.hubert.hubert import HUBERT diff --git a/voicebox/src/models/speech/hubert/__pycache__/__init__.cpython-310.pyc b/voicebox/src/models/speech/hubert/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f5cbc1c8e8c54fe24d797d8be5b4eb35d1063079 Binary files /dev/null and b/voicebox/src/models/speech/hubert/__pycache__/__init__.cpython-310.pyc differ diff --git a/voicebox/src/models/speech/hubert/__pycache__/__init__.cpython-39.pyc b/voicebox/src/models/speech/hubert/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62800da702a6e16c296251836e12ed785df6f030 Binary files /dev/null and b/voicebox/src/models/speech/hubert/__pycache__/__init__.cpython-39.pyc differ diff --git a/voicebox/src/models/speech/hubert/__pycache__/hubert.cpython-310.pyc b/voicebox/src/models/speech/hubert/__pycache__/hubert.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..72baf6eb2eecab9ebe76ea789bfd991a7a364c91 Binary files /dev/null and b/voicebox/src/models/speech/hubert/__pycache__/hubert.cpython-310.pyc differ diff --git a/voicebox/src/models/speech/hubert/__pycache__/hubert.cpython-39.pyc b/voicebox/src/models/speech/hubert/__pycache__/hubert.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b04d0f3c5f9bd7305156b7e371fb4b0382696b9a Binary files /dev/null and b/voicebox/src/models/speech/hubert/__pycache__/hubert.cpython-39.pyc differ diff --git a/voicebox/src/models/speech/hubert/hubert.py b/voicebox/src/models/speech/hubert/hubert.py new file mode 100644 index 0000000000000000000000000000000000000000..4b053f0ab0df5557e48fa185e13e627a303cfca7 --- /dev/null +++ b/voicebox/src/models/speech/hubert/hubert.py @@ -0,0 +1,92 @@ +import torch +import torchaudio + +import torch.nn as nn + +from src.data import DataProperties + +torchaudio.set_audio_backend("sox_io") + + +class HUBERT(nn.Module): + + def __init__(self, variant: str = "large"): + + super().__init__() + + # identify model variant (distinguished by size, dataset, fine-tuning) + variants = { + + # `LARGE` variant, trained on Libri-Light 60,000h, fine-tuned on + # full LibriSpeech 960h + "large": "HUBERT_ASR_LARGE", + + # `XLARGE` variant, trained on Libri-Light 60,000h, fine-tuned on + # full LibriSpeech 960h + "xlarge": "HUBERT_ASR_XLARGE" + + } + + try: + variant_full_name = variants[variant] + except KeyError: + raise ValueError(f"Invalid variant {variant}; must be one of " + f"{list(variants.keys())}") + + # import HUBERT model variant as `bundle` object + bundle = eval(f'torchaudio.pipelines.{variant_full_name}') + + # unpack model, labels, and sample rate + self.model = bundle.get_model() + self.labels = bundle.get_labels() + self.sample_rate = bundle.sample_rate + + # hardcode for HUBERT + self.sep_idx = 4 + self.blank_idx = 0 + + # check sample rate + if 
DataProperties.get("sample_rate") != self.sample_rate: + raise ValueError(f"Incompatible data and model sample rates " + f"{DataProperties.get('sample_rate')}, " + f"{self.sample_rate}") + + # feature extractor: stacked 1D convolutional blocks + assert self.model.feature_extractor is not None + + # encoder: feature projection, transformer + assert self.model.encoder is not None + + # aux: fine-tuned linear layer(s) mapping to token probabilities + assert self.model.aux is not None + + def forward(self, x: torch.Tensor): + """ + Pass input audio through feature extractor, encoder, and fine-tuned + auxiliary layer to produce a sequence of token probability distributions + + :param x: waveform audio of shape (n_batch, ..., signal_length) + :return: + """ + + # reshape audio to (n_batch, signal_length) + if x.ndim != 2: + n_batch, signal_length = x.shape[0], x.shape[-1] + x = x.reshape(n_batch, signal_length) + + # emit sequence(s) of token probabilities + emission, _ = self.model(x, lengths=None) + + return emission + + def extract_features(self, x: torch.Tensor): + """ + Extract deep features. + """ + + # reshape audio to (n_batch, signal_length) + if x.ndim != 2: + n_batch, signal_length = x.shape[0], x.shape[-1] + x = x.reshape(n_batch, signal_length) + + return self.model.extract_features(x)[0] diff --git a/voicebox/src/models/speech/speech.py b/voicebox/src/models/speech/speech.py new file mode 100644 index 0000000000000000000000000000000000000000..dae8fd3038c31a02f9b4bc7fa67f036be668cd0d --- /dev/null +++ b/voicebox/src/models/speech/speech.py @@ -0,0 +1,376 @@ +import os +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.utils.rnn import pad_sequence + +from typing import Union, List, Tuple + +from src.models.model import Model + +################################################################################ +# Extends Model class for speech recognition, with optional decoding +################################################################################ + + +class Decoder(object): + """ + Code adapted from DeepSpeech PyTorch (https://tinyurl.com/2p89d35e). Base + class for decoder objects, which convert emitted frame-by-frame token + probabilities into a string transcription. + """ + def __init__(self, + labels: Union[List[str], Tuple[str]], + sep_idx: int = None, + blank_idx: int = 0): + """ + Parameters + ---------- + labels (list): character corresponding to each token index + + sep_idx (int): index corresponding to space / separating character + + blank_idx (int): index corresponding to blank '_' character + """ + self.labels = labels + self.blank_idx = blank_idx + self.int_to_char = dict([(i, c) for (i, c) in enumerate(labels)]) + + if sep_idx is None: + + # use out-of-bounds index for separating character + sep_idx = len(labels) + if ' ' in labels: + sep_idx = labels.index(' ') + elif '|' in labels: + sep_idx = labels.index('|') + self.sep_idx = sep_idx + + else: + self.sep_idx = sep_idx + + def get_labels(self): + return self.labels + + def get_sep_idx(self): + return self.sep_idx + + def get_blank_idx(self): + return self.blank_idx + + def __call__(self, emission: torch.Tensor, sizes=None): + return self.decode(emission, sizes) + + def decode(self, emission: torch.Tensor, sizes=None): + """ + Decode emitted token probabilities to obtain a string transcription. 
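A small demonstration of the separator-index fallback in the `Decoder` constructor above (a sketch, assuming the class is in scope): when `sep_idx` is not supplied, the constructor looks for a space, then a `'|'` token, and otherwise uses the out-of-bounds index `len(labels)`.

```python
labels = ['_', "'", 'A', 'B', 'C', '|']
decoder = Decoder(labels)                           # no sep_idx supplied
assert decoder.get_sep_idx() == labels.index('|')   # '|' found at index 5

no_sep = Decoder(['_', 'A', 'B'])
assert no_sep.get_sep_idx() == 3                    # fallback: len(labels)
```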
+ + Parameters + ---------- + emission (Tensor): shape (n_batch, n_frames, n_tokens) + + sizes (Tensor): length in frames of each emission in batch + """ + raise NotImplementedError + + +class GreedyCTCDecoder(Decoder): + """ + A simple decoder module to map token probability sequences to transcripts. + Decodes 'greedily' by selecting maximum-probability token at each time step. + Code adapted from DeepSpeech PyTorch (https://tinyurl.com/2p89d35e). + """ + def __init__(self, + labels: Union[List[str], Tuple[str]], + sep_idx: int = None, + blank_idx: int = 0): + super().__init__(labels, sep_idx, blank_idx) + + def convert_to_strings(self, + sequences, + sizes=None, + remove_repetitions=False, + return_offsets=False): + """ + Given a list of sequences holding token numbers, return the + corresponding strings. Optionally, collapse repeated token subsequences + and return final length of each processed sequence. + + Parameters + ---------- + + sequences (Tensor): shape (n_batch, n_frames); holds argmax token index + for each frame + + sizes + + remove_repetitions + + return_offsets + + Returns + ------- + + """ + + strings = [] + offsets = [] if return_offsets else None + + for i, sequence in enumerate(sequences): + + seq_len = sizes[i] if sizes is not None else len(sequence) + string, string_offsets = self.process_string(sequence, seq_len, remove_repetitions) + strings.append(string) + if return_offsets: + offsets.append(string_offsets) + + if return_offsets: + return strings, offsets + else: + return strings + + def process_string(self, + sequence, + size, + remove_repetitions=False): + string = '' + offsets = [] + + for i in range(size): + char = self.int_to_char[sequence[i].item()] + + if char != self.int_to_char[self.blank_idx]: + + # skip repeated characters if specified + if remove_repetitions and i != 0 and \ + char == self.int_to_char[sequence[i - 1].item()]: + pass + elif char == self.labels[self.sep_idx]: + string += self.labels[self.sep_idx] + offsets.append(i) + else: + string = string + char + offsets.append(i) + + return string, torch.tensor(offsets, dtype=torch.int) + + def decode(self, emission, sizes=None): + """ + Returns the argmax decoding given the emitted token probabilities. + According to connectionist temporal classification (CTC), removes + repeated elements in the decoded token sequence, as well as blanks. + + Parameters + ---------- + emission (Tensor): shape (n_batch, n_frames, n_tokens) + + sizes (Tensor): length in frames of each emission in batch + + Returns + ------- + transcription (list[str]): string transcription for each item in batch + + offsets (??? 
frame index per character predicted + + """ + + if emission.ndim == 2: # require shape (n_batch, n_frames, n_tokens) + emission = emission.unsqueeze(0) + + # compute max-probability label at each sequence index + max_probs = torch.argmax(emission, dim=-1) # (n_batch, sequence_len) + + strings, offsets = self.convert_to_strings(max_probs, + sizes, + remove_repetitions=True, + return_offsets=True) + return strings, offsets + + +class SpeechRecognitionModel(Model): + + def __init__(self, + model: nn.Module, + decoder: Decoder = None + ): + + super().__init__() + + self.model = model + self.model.eval() + + # ensure that list of viable tokens can be retrieved from wrapped model + labels_method = getattr(self.model, "get_labels", None) + labels_attr = getattr(self.model, "labels", None) + if callable(labels_method): + self._get_labels_fn = lambda: self.model.get_labels() + elif labels_attr is not None: + self._get_labels_fn = lambda: self.model.labels + else: + raise ValueError(f'Wrapped model must have method `.get_labels()`' + f' or attribute `.labels`') + + # ensure that blank and separator tokens can be retrieved from wrapped + # model + sep_method = getattr(self.model, "get_sep_idx", None) + sep_attr = getattr(self.model, "sep_idx", None) + if callable(sep_method): + self._get_sep_fn = lambda: self.model.get_sep_idx() + elif sep_attr is not None: + self._get_sep_fn = lambda: self.model.sep_idx + else: + raise ValueError(f'Wrapped model must have method `.get_sep_idx()`' + f' or attribute `.sep_idx`') + + blank_method = getattr(self.model, "get_blank_idx", None) + blank_attr = getattr(self.model, "blank_idx", None) + if callable(blank_method): + self._get_blank_fn = lambda: self.model.get_blank_idx() + elif blank_attr is not None: + self._get_blank_fn = lambda: self.model.blank_idx + else: + raise ValueError(f'Wrapped model must have method ' + f'`.get_blank_idx()` or attribute `.blank_idx`') + + # initialize decoder + if decoder is None: + decoder = GreedyCTCDecoder( + labels=self.get_labels(), + blank_idx=self.get_blank_idx(), + sep_idx=self.get_sep_idx() + ) + self.decoder = decoder + + # translate characters to token indices + self.char_to_idx = {l: i for i, l in enumerate(decoder.labels)} + + def get_labels(self): + """Retrieve a list of valid tokens""" + return self._get_labels_fn() + + def get_blank_idx(self): + """Return index of blank token""" + return self._get_blank_fn() + + def get_sep_idx(self): + """Return index of separator token""" + return self._get_sep_fn() + + def forward(self, x: torch.Tensor): + return self.model.forward(x) + + def transcribe(self, x: torch.Tensor, return_alignment: bool = False): + + if return_alignment: + return self.decoder(self.model(x)) + else: + return self.decoder(self.model(x))[0] + + def load_weights(self, path: str): + """ + Load weights from checkpoint file + """ + + # check if file exists + if not path or not os.path.isfile(path): + return + + model_state = self.model.state_dict() + loaded_state = torch.load(path) + + for name, param in loaded_state.items(): + + origname = name + + if name not in model_state: + print("{} is not in the model.".format(origname)) + continue + + if model_state[name].size() != loaded_state[origname].size(): + print( + "Wrong parameter length: {}, model: {}, loaded: {}".format( + origname, + model_state[name].size(), + loaded_state[origname].size() + ) + ) + continue + + model_state[name].copy_(param) + + def extract_features( + self, + x: torch.Tensor + ) -> List[torch.Tensor]: + + """ + Extract deep features. 
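To make the CTC post-processing above concrete, here is a toy rendition of the same rule: take the per-frame argmax, drop frame-to-frame repeats, then drop blanks. The labels and frame sequence are made up for illustration.

```python
import torch

labels = ['_', 'A', 'B', '|']                    # '_' is the blank token
frames = torch.tensor([1, 1, 0, 1, 2, 2, 0, 3])  # argmax index per frame

decoded, prev = [], None
for t in frames.tolist():
    if t != 0 and t != prev:                     # skip blanks and repeats
        decoded.append(labels[t])
    prev = t
print(''.join(decoded))                          # 'AAB|'
```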
+ + :param x: input + :return: a list of tensors holding intermediate activations / features + """ + + try: + return self.model.extract_features(x) + except AttributeError: + return [] + + def _str_to_tensor(self, seq: str): + token_indices = [self.char_to_idx[c] for c in seq] + return torch.as_tensor(token_indices, dtype=torch.long) + + def match_predict(self, + y_pred: Union[List[str], torch.Tensor], + y_true: Union[List[str], torch.Tensor]): + """ + Determine whether (batched) target pairs are equivalent. + """ + + n_batch = len(y_pred) + + y_true_lengths = None + + # convert ground-truth transcriptions to tensor form + if isinstance(y_true, list): + y_true = [self._str_to_tensor(t) for t in y_true] + y_true_lengths = [t.shape[-1] for t in y_true] + y_true = pad_sequence( + y_true, + batch_first=True + ) # (n_batch, max_seq_len) + + if y_true_lengths is None: + y_true_lengths = [y_true.shape[-1]] * n_batch + + # convert predicted transcriptions to tensor form + if isinstance(y_pred, list): + y_pred = [self._str_to_tensor(t) for t in y_pred] + y_pred = pad_sequence( + y_pred, + batch_first=True + ) # (n_batch, max_seq_len) + + length_diff = max(0, y_true.shape[-1] - y_pred.shape[-1]) + if length_diff: + y_pred = F.pad(y_pred, (0, length_diff)) + + matches = [] + for i in range(n_batch): + matches.append( + torch.all( + y_pred[i, ..., :y_true_lengths[i]] == y_true[i, ..., :y_true_lengths[i]] + ) + ) + + return torch.as_tensor(matches) + + + + """ + # masked comparison + use which one as dimension to select --- true or pred? + + pred lengths may be unnecessary! just select to true length + """ + + diff --git a/voicebox/src/models/speech/wav2vec2/__init__.py b/voicebox/src/models/speech/wav2vec2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bafbaa5d29f84883e6ddf0650ea5e5b7f3c9a7d5 --- /dev/null +++ b/voicebox/src/models/speech/wav2vec2/__init__.py @@ -0,0 +1 @@ +from src.models.speech.wav2vec2.wav2vec2 import Wav2Vec2 diff --git a/voicebox/src/models/speech/wav2vec2/__pycache__/__init__.cpython-310.pyc b/voicebox/src/models/speech/wav2vec2/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..446ff1d35a90514dad722cae39ec7d676440e91f Binary files /dev/null and b/voicebox/src/models/speech/wav2vec2/__pycache__/__init__.cpython-310.pyc differ diff --git a/voicebox/src/models/speech/wav2vec2/__pycache__/__init__.cpython-39.pyc b/voicebox/src/models/speech/wav2vec2/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82dfdbb51036b1ec4ed99fea757be10f9fc12fae Binary files /dev/null and b/voicebox/src/models/speech/wav2vec2/__pycache__/__init__.cpython-39.pyc differ diff --git a/voicebox/src/models/speech/wav2vec2/__pycache__/wav2vec2.cpython-310.pyc b/voicebox/src/models/speech/wav2vec2/__pycache__/wav2vec2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..386241a1adaaf591ed9fb99083e04574d8a406e3 Binary files /dev/null and b/voicebox/src/models/speech/wav2vec2/__pycache__/wav2vec2.cpython-310.pyc differ diff --git a/voicebox/src/models/speech/wav2vec2/__pycache__/wav2vec2.cpython-39.pyc b/voicebox/src/models/speech/wav2vec2/__pycache__/wav2vec2.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c267f50f7608912ee879778bce8072ec3d4c46d0 Binary files /dev/null and b/voicebox/src/models/speech/wav2vec2/__pycache__/wav2vec2.cpython-39.pyc differ diff --git 
a/voicebox/src/models/speech/wav2vec2/wav2vec2.py b/voicebox/src/models/speech/wav2vec2/wav2vec2.py new file mode 100644 index 0000000000000000000000000000000000000000..da917fa27ae5625678a1abd6c05dfdde755deb87 --- /dev/null +++ b/voicebox/src/models/speech/wav2vec2/wav2vec2.py @@ -0,0 +1,143 @@ +import torch +import torchaudio +import torch.nn as nn + +from src.data import DataProperties + +torchaudio.set_audio_backend("sox_io") + + +class Wav2Vec2(nn.Module): + """ + Wav2Vec2 ASR model, as proposed by Baevski et al. + (https://arxiv.org/abs/2006.11477). Takes arbitrary-length waveform audio + at 16kHz and produces string transcripts + """ + + def __init__(self, variant: str = "base_960h"): + + super().__init__() + + # identify model variant (distinguished by size, dataset, fine-tuning) + variants = { + + # `BASE` variant, trained on LibriSpeech 960h, fine-tuned on 10 + # minutes of Libri-Light + "base_10m": "WAV2VEC2_ASR_BASE_10M", + + # `BASE` variant, trained on LibriSpeech 960h, fine-tuned on 100 + # hours of LibriSpeech `train-clean-100` subset + "base_100h": "WAV2VEC2_ASR_BASE_100H", + + # `BASE` variant, trained on LibriSpeech 960h, fine-tuned on full + # LibriSpeech 960h + "base_960h": "WAV2VEC2_ASR_BASE_960H", + + # `LARGE` variant, trained on LibriSpeech 960h, fine-tuned on 10 + # minutes of Libri-Light + "large_10m": "WAV2VEC2_ASR_LARGE_10M", + + # `LARGE` variant, trained on LibriSpeech 960h, fine-tuned on 100 + # hours of LibriSpeech `train-clean-100` subset + "large_100h": "WAV2VEC2_ASR_LARGE_100H", + + # `LARGE` variant, trained on LibriSpeech 960h, fine-tuned on full + # LibriSpeech 960h + "large_960h": "WAV2VEC2_ASR_LARGE_960H", + + # `LARGE` variant, trained on Libri-Light 60,000h, fine-tuned on 10 + # minutes of Libri-Light + "large_lv60k_10m": "WAV2VEC2_ASR_LARGE_LV60K_10M", + + # `LARGE` variant, trained on Libri-Light 60,000h, fine-tuned on 100 + # hours of LibriSpeech `train-clean-100` subset + "large_lv60k_100h": "WAV2VEC2_ASR_LARGE_LV60K_100H", + + # `LARGE` variant, trained on Libri-Light 60,000h, fine-tuned on full + # LibriSpeech 960h + "large_lv60k_960h": "WAV2VEC2_ASR_LARGE_LV60K_960H", + + # `BASE` variant, trained on VoxPopuli 10,000h, fine-tuned on 282h + # German subset + "base_10k_de": "VOXPOPULI_ASR_BASE_10K_DE", + + # `BASE` variant, trained on VoxPopuli 10,000h, fine-tuned on 543h + # English subset + "base_10k_en": "VOXPOPULI_ASR_BASE_10K_EN", + + # `BASE` variant, trained on VoxPopuli 10,000h, fine-tuned on 166h + # Spanish subset + "base_10k_es": "VOXPOPULI_ASR_BASE_10K_ES", + + # `BASE` variant, trained on VoxPopuli 10,000h, fine-tuned on 211h + # French subset + "base_10k_fr": "VOXPOPULI_ASR_BASE_10K_FR", + + # `BASE` variant, trained on VoxPopuli 10,000h, fine-tuned on 91h + # Spanish subset + "base_10k_it": "VOXPOPULI_ASR_BASE_10K_IT" + } + + try: + variant_full_name = variants[variant] + except KeyError: + raise ValueError(f"Invalid variant {variant}; must be one of " + f"{list(variants.keys())}") + + # import Wav2Vec2 model variant as `bundle` object + bundle = eval(f'torchaudio.pipelines.{variant_full_name}') + + # unpack model, labels, and sample rate + self.model = bundle.get_model() + self.labels = bundle.get_labels() + self.sample_rate = bundle.sample_rate + + # hardcode for Wav2Vec2 + self.sep_idx = 4 + self.blank_idx = 0 + + # check sample rate + if DataProperties.get("sample_rate") != self.sample_rate: + raise ValueError(f"Incompatible data and model sample rates " + f"{DataProperties.get('sample_rate')}, " + f"{self.sample_rate}") + + 
# feature extractor: stacked 1D convolutional blocks + assert self.model.feature_extractor is not None + + # encoder: feature projection, transformer + assert self.model.encoder is not None + + # aux: fine-tuned linear layer(s) mapping to token probabilities + assert self.model.aux is not None + + def forward(self, x: torch.Tensor): + """ + Pass input audio through feature extractor, encoder, and fine-tuned + auxiliary layer to produce a sequence of token probability distributions + + :param x: waveform audio of shape (n_batch, ..., signal_length) + :return: + """ + + # reshape audio to (n_batch, signal_length) + if x.ndim != 2: + n_batch, signal_length = x.shape[0], x.shape[-1] + x = x.reshape(n_batch, signal_length) + + # emit sequence(s) of token probabilities + emission, _ = self.model(x, lengths=None) + + return emission + + def extract_features(self, x: torch.Tensor): + """ + Extract deep features. + """ + + # reshape audio to (n_batch, signal_length) + if x.ndim != 2: + n_batch, signal_length = x.shape[0], x.shape[-1] + x = x.reshape(n_batch, signal_length) + + return self.model.extract_features(x)[0] diff --git a/voicebox/src/pipelines/README.md b/voicebox/src/pipelines/README.md new file mode 100644 index 0000000000000000000000000000000000000000..b1701878ff45a9e0bf6459d6650fc4fc715f0f19 --- /dev/null +++ b/voicebox/src/pipelines/README.md @@ -0,0 +1,103 @@ + +A `Pipeline` can be constructed around any speech model defined as a `torch.nn.Module` by wrapping the model with a task-specific `Model` subclass: + +```python +from src.pipelines import Pipeline +from src.models import SpeakerVerificationModel +from src.models.speaker import ResNetSE34V2 + +model = ResNetSE34V2() + +wrapped_model = SpeakerVerificationModel( + model +) + +pipeline = Pipeline( + model=wrapped_model +) +``` + +A `Pipeline` object may also hold any of the following components: +* A `Simulation` object defining an acoustic simulation +* A `Preprocessor` object consisting of one or more preprocessing stages +* A `Defense` object consisting of one or more adversarial defenses + +For additional documentation on the `Pipeline` class, see here. + + + +```python +# Load simulation +simulation = load_simulation(config) + +# Load preprocessing +preprocessor = load_preprocess(config) + +# Load model +model = load_model(config) +assert isinstance(model, SpeakerVerificationModel) + +# Load adversarial defenses +defense = load_defense(config) +``` + + +See additional documentation on models, acoustic simulation, preprocessing, and defenses. + + +For a quick start, load data and pipelines from a ready-made configuration: + +```python + +# CODE EXAMPLE OF PIPELINE/DATA BUILDER from config file + +``` + + + + +
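Pending the ready-made builder promised by the placeholder above, one possible shape for it, a hypothetical sketch that assumes the `load_*` helpers shown earlier accept a parsed configuration object:

```python
# Hypothetical builder; the load_* helpers are assumed, not defined in this diff.
simulation = load_simulation(config)
preprocessor = load_preprocess(config)
model = load_model(config)
defense = load_defense(config)

pipeline = Pipeline(
    model=model,
    simulation=simulation,
    preprocessor=preprocessor,
    defense=defense,
)
```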

## Building a Pipeline

+ +We provide a simple interface for performing adversarial attacks on speech systems. Acoustic simulation, preprocessing, purification-based defenses, detection-based defenses, and models are implemented as differentiable modules and wrapped within a single `Pipeline` object. A `Pipeline` can be constructed around any speech model defined as a `torch.nn.Module` using a task-specific wrapper: + +```python +from src.pipelines import Pipeline +from src.models import SpeakerVerificationModel +from src.models.speaker import ResNetSE34V2 + +model = ResNetSE34V2() + +wrapped_model = SpeakerVerificationModel( + model +) + +pipeline = Pipeline( + model=wrapped_model +) +``` + +A `Pipeline` object may also hold any of the following components: +* A `Simulation` object defining an acoustic simulation +* A `Preprocessor` object consisting of one or more preprocessing stages +* A `Defense` object consisting of one or more adversarial defenses + +```python +from src.preprocess import Preprocessor, Normalize, KaldiStyleVAD +from src.simulation import Simulation, Offset, Bandpass + +simulation = Simulation( + Offset(length=[-.15, .15]), + Bandpass(low=200, high=6000) +) + +preprocessor = Preprocessor( + Normalize(), + KaldiStyleVAD() +) + +pipeline = Pipeline( + model=wrapped_model, + simulation=simulation, + preprocessor=preprocessor +) +``` diff --git a/voicebox/src/pipelines/__init__.py b/voicebox/src/pipelines/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..658a58cb2b009dc8426819a1f696b929ff4db4d3 --- /dev/null +++ b/voicebox/src/pipelines/__init__.py @@ -0,0 +1 @@ +from src.pipelines.pipeline import Pipeline diff --git a/voicebox/src/pipelines/__pycache__/__init__.cpython-310.pyc b/voicebox/src/pipelines/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8cf66e9427183cb1d6c8c37795793ea08d397d4c Binary files /dev/null and b/voicebox/src/pipelines/__pycache__/__init__.cpython-310.pyc differ diff --git a/voicebox/src/pipelines/__pycache__/__init__.cpython-39.pyc b/voicebox/src/pipelines/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a09ac1855cd27c13e890385e4a971c3d07f22159 Binary files /dev/null and b/voicebox/src/pipelines/__pycache__/__init__.cpython-39.pyc differ diff --git a/voicebox/src/pipelines/__pycache__/pipeline.cpython-310.pyc b/voicebox/src/pipelines/__pycache__/pipeline.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3fe63b8a89d25a8d2f580d2551eeccf7a6619e59 Binary files /dev/null and b/voicebox/src/pipelines/__pycache__/pipeline.cpython-310.pyc differ diff --git a/voicebox/src/pipelines/__pycache__/pipeline.cpython-39.pyc b/voicebox/src/pipelines/__pycache__/pipeline.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a57238bc2ebc12dc2c352d558236fab7c5d7c5a2 Binary files /dev/null and b/voicebox/src/pipelines/__pycache__/pipeline.cpython-39.pyc differ diff --git a/voicebox/src/pipelines/pipeline.py b/voicebox/src/pipelines/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..b8420dd99d579430e8c7640613977459c792decd --- /dev/null +++ b/voicebox/src/pipelines/pipeline.py @@ -0,0 +1,189 @@ +import torch +import torch.nn as nn + +from typing import Union + +from src.models.model import Model +from src.simulation.simulation import Simulation +from src.preprocess.preprocessor import Preprocessor +from src.defenses.defense import Defense + 
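Once assembled as in the README example above, a `Pipeline` is invoked like any `torch.nn.Module`. A minimal usage sketch, assuming 16 kHz mono input and the `pipeline` object built earlier:

```python
import torch

wav = torch.randn(2, 16000)        # hypothetical batch of one-second utterances
pipeline.sample_params()           # re-sample random simulation parameters
scores = pipeline(wav)             # simulate -> preprocess -> purify -> model
flags, det = pipeline.detect(wav)  # detection defenses applied in parallel
```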
+################################################################################ +# Encapsulate all stages of audio classification pipeline +################################################################################ + + +class Pipeline(nn.Module): + + def __init__(self, + model: Model, + simulation: Simulation = None, + preprocessor: Preprocessor = None, + defense: Defense = None, + device: Union[str, torch.device] = 'cpu', + **kwargs + ): + """ + Pipeline encompassing acoustic environment simulation, preprocessing, + model, and defenses (purification and detection). + + :param model: the victim classifier + :param simulation: an end-to-end differentiable acoustic simulation + :param preprocessor: differentiable preprocessing stages + :param defense: a set of purification and/or detection defenses. + Purification defenses are applied to incoming audio in + sequence, while detection defenses are applied in + parallel + :param device: store device to ensure all pipeline components are + correctly assigned + """ + super().__init__() + + self.model = model + self.simulation = simulation + self.preprocessor = preprocessor + self.defense = defense + self.device = device + + # flags to selectively enable pipeline stages + self._enable_simulation = True + self._enable_preprocessor = True + self._enable_defense = True + + # ensure model is in 'eval' mode + self.model.eval() + + # move all submodules to stored device + self.set_device(device) + + # freeze gradient computation for all stored parameters + self._freeze_grad() + + # randomly initialize simulation parameters + self.sample_params() + + @property + def enable_simulation(self): + return self._enable_simulation + + @enable_simulation.setter + def enable_simulation(self, flag: bool): + self._enable_simulation = flag + + @property + def enable_preprocessor(self): + return self._enable_preprocessor + + @enable_preprocessor.setter + def enable_preprocessor(self, flag: bool): + self._enable_preprocessor = flag + + @property + def enable_defense(self): + return self._enable_defense + + @enable_defense.setter + def enable_defense(self, flag: bool): + self._enable_defense = flag + + def set_device(self, device: Union[str, torch.device]): + """ + Move all submodules to stored device + """ + self.device = device + + for module in self.modules(): + module.to(self.device) + + def _freeze_grad(self): + """ + Disable gradient computations for all stored parameters + """ + for p in self.parameters(): + p.requires_grad = False + + def sample_params(self): + """ + Randomly re-sample the parameters of each stored effect + """ + if self.simulation is not None: + self.simulation.sample_params() + + def simulate(self, x: torch.Tensor): + """ + Pass inputs through simulation + """ + if self.enable_simulation and self.simulation is not None: + x = self.simulation(x) + + return x + + def preprocess(self, x: torch.Tensor): + """ + Pass inputs through preprocessing + """ + if self.enable_preprocessor and self.preprocessor is not None: + x = self.preprocessor(x) + + return x + + def purify(self, x: torch.Tensor): + """ + Pass inputs through purification defenses + """ + if self.enable_defense and self.defense is not None: + x = self.defense.purify(x) + + return x + + def forward(self, x: torch.Tensor): + """ + Pass inputs through simulation, preprocessor, purification defenses, and + model in sequence + """ + x = self.simulate(x) + x = self.preprocess(x) + x = self.purify(x) + + return self.model(x) + + def detect(self, x: torch.Tensor): + """ + Apply detection 
defenses to input in parallel. For every input, each + detection defense produces a score indicating confidence in its + adversarial nature and a boolean flag indicating whether this score + surpasses a (calibrated) internal threshold. + + :param x: input tensor (n_batch, ...) + :return: flags (n_batch, n_defenses), scores (n_batch, n_defenses) + """ + + # apply simulated distortions and preprocessing to input; omit + # purification defenses + x = self.simulate(x) + x = self.preprocess(x) + + if self._enable_defense and self.defense is not None: + flags, scores = self.defense.detect( + x, + self.model + ) + + else: + + n_batch = x.shape[0] + + # allow zero-gradients to propagate + flags = x.reshape(n_batch, -1).sum(dim=-1).reshape(n_batch, 1) * 0 + scores = x.reshape(n_batch, -1).sum(dim=-1).reshape(n_batch, 1) * 0 + + return flags, scores + + def match_predict(self, y_pred: torch.tensor, y_true: torch.Tensor): + """ + Determine whether target pairs are equivalent under stored model + """ + return self.model.match_predict(y_pred, y_true) + + + diff --git a/voicebox/src/preprocess/__init__.py b/voicebox/src/preprocess/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fddbc082c3cdfe23c7e4c3f91253999d1d22d4ef --- /dev/null +++ b/voicebox/src/preprocess/__init__.py @@ -0,0 +1,5 @@ +from src.preprocess.preprocessor import Preprocessor +from src.preprocess.preemphasis import PreEmphasis +from src.preprocess.normalize import Normalize +from src.preprocess.vad import VAD, KaldiStyleVAD + diff --git a/voicebox/src/preprocess/__pycache__/__init__.cpython-310.pyc b/voicebox/src/preprocess/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d930327289ba4d9dfee3b124c3844e3136b4d574 Binary files /dev/null and b/voicebox/src/preprocess/__pycache__/__init__.cpython-310.pyc differ diff --git a/voicebox/src/preprocess/__pycache__/__init__.cpython-39.pyc b/voicebox/src/preprocess/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1f6ef06ded8b63776bb1b22ae09e69da25369d6b Binary files /dev/null and b/voicebox/src/preprocess/__pycache__/__init__.cpython-39.pyc differ diff --git a/voicebox/src/preprocess/__pycache__/normalize.cpython-310.pyc b/voicebox/src/preprocess/__pycache__/normalize.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec7429716a4096f24141af5e3084a2dfc245f50e Binary files /dev/null and b/voicebox/src/preprocess/__pycache__/normalize.cpython-310.pyc differ diff --git a/voicebox/src/preprocess/__pycache__/normalize.cpython-39.pyc b/voicebox/src/preprocess/__pycache__/normalize.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f811a8881c4dd2c9d5e5d88393adac9c8701585e Binary files /dev/null and b/voicebox/src/preprocess/__pycache__/normalize.cpython-39.pyc differ diff --git a/voicebox/src/preprocess/__pycache__/preemphasis.cpython-310.pyc b/voicebox/src/preprocess/__pycache__/preemphasis.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8135af4ceead218f24c46effcc1b1042c991e1a8 Binary files /dev/null and b/voicebox/src/preprocess/__pycache__/preemphasis.cpython-310.pyc differ diff --git a/voicebox/src/preprocess/__pycache__/preemphasis.cpython-39.pyc b/voicebox/src/preprocess/__pycache__/preemphasis.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3446c6850f1fe70bf4e4822ba6185683a92011bd Binary files /dev/null and 
b/voicebox/src/preprocess/__pycache__/preemphasis.cpython-39.pyc differ diff --git a/voicebox/src/preprocess/__pycache__/preprocessor.cpython-310.pyc b/voicebox/src/preprocess/__pycache__/preprocessor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f06a78c8511296f28b67a990093e484d8d70ab4f Binary files /dev/null and b/voicebox/src/preprocess/__pycache__/preprocessor.cpython-310.pyc differ diff --git a/voicebox/src/preprocess/__pycache__/preprocessor.cpython-39.pyc b/voicebox/src/preprocess/__pycache__/preprocessor.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9e7103944078b6c8db58bf06b2d52865a15089e8 Binary files /dev/null and b/voicebox/src/preprocess/__pycache__/preprocessor.cpython-39.pyc differ diff --git a/voicebox/src/preprocess/__pycache__/vad.cpython-310.pyc b/voicebox/src/preprocess/__pycache__/vad.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f193105383922ab953f2002defaadb68f9d0f683 Binary files /dev/null and b/voicebox/src/preprocess/__pycache__/vad.cpython-310.pyc differ diff --git a/voicebox/src/preprocess/__pycache__/vad.cpython-39.pyc b/voicebox/src/preprocess/__pycache__/vad.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..04ec8ea78f18143a7226d25d913ecfa406bd83bf Binary files /dev/null and b/voicebox/src/preprocess/__pycache__/vad.cpython-39.pyc differ diff --git a/voicebox/src/preprocess/normalize.py b/voicebox/src/preprocess/normalize.py new file mode 100644 index 0000000000000000000000000000000000000000..f364b10bb29d1a17eec10939e6e0298c25268cbb --- /dev/null +++ b/voicebox/src/preprocess/normalize.py @@ -0,0 +1,67 @@ +import torch + +from src.simulation.component import Component + +################################################################################ +# Normalize audio +################################################################################ + + +class Normalize(Component): + + def __init__(self, + method: str = 'peak', + target_dbfs: float = -30.0, + increase_only: bool = False, + decrease_only: bool = False + ): + """ + Normalize incoming audio. 
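The decibel arithmetic used by the `'dbfs'` method below is easy to verify by hand; a worked example of the gain computation, using a made-up constant signal:

```python
import torch

x = torch.full((1, 16000), 0.01)      # constant signal: RMS = 0.01
rms = torch.sqrt(torch.mean(x ** 2))
dbfs = 20 * torch.log10(rms)          # -40 dBFS
gain = 10 ** ((-30.0 - dbfs) / 20)    # boost toward a -30 dBFS target
print(float(gain))                    # ~3.162, i.e. +10 dB
```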
+ + Parameters + ---------- + + Returns + ------- + + """ + super().__init__() + + assert method in [None, 'none', 'peak', 'dbfs'], \ + f"Invalid normalization method {method}" + self.method = method + + # parameters for dBFS normalization + assert not (increase_only and decrease_only), \ + f"Cannot set both `increase_only` and `decrease_only`" + + self.target_dbfs = target_dbfs + self.increase_only = increase_only + self.decrease_only = decrease_only + + def forward(self, x: torch.Tensor): + + if self.method is None: + return x + elif self.method == 'peak': + return (self.scale / torch.max( + torch.abs(x) + 1e-8, dim=-1, keepdim=True)[0]) * x * 0.95 + elif self.method == 'dbfs': + + # compute volume in dBFS + rms = torch.sqrt(torch.mean(x ** 2)) + dbfs = 20 * torch.log10(rms) + + # determine whether to normalize + dbfs_change = self.target_dbfs - dbfs + if dbfs_change < 0 and self.increase_only or \ + dbfs_change > 0 and self.decrease_only: + return x + + normalized = x * (10 ** (dbfs_change / 20)) + + # clip to valid range + return torch.clamp(normalized, min=-self.scale, max=self.scale) + + else: + raise ValueError(f'Invalid normalization: {self.method}') diff --git a/voicebox/src/preprocess/preemphasis.py b/voicebox/src/preprocess/preemphasis.py new file mode 100644 index 0000000000000000000000000000000000000000..875ecad23f58bbc71f858467a91ddeeaa209ea1e --- /dev/null +++ b/voicebox/src/preprocess/preemphasis.py @@ -0,0 +1,69 @@ +import torch +import torch.nn.functional as F + +from src.simulation.component import Component + +################################################################################ +# Pre-emphasis filter +################################################################################ + + +class PreEmphasis(Component): + """ + Apply pre-emphasis filter via waveform convolution. 
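In one line, the filter computes y[n] = x[n] - coef * x[n-1], passing the first sample through unchanged (the `'shift'` method below). A quick numeric check of the recurrence:

```python
import torch

x = torch.tensor([[1.0, 2.0, 3.0]])
coef = 0.97
y = torch.cat([x[..., :1], x[..., 1:] - coef * x[..., :-1]], dim=-1)
print(y)   # tensor([[1.0000, 1.0300, 1.0600]])
```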
Adapted from + https://github.com/clovaai/voxceleb_trainer/blob/master/utils.py + """ + + def __init__(self, coef: float = 0.97, method: str = 'shift'): + """ + Initialize filter + + :param coef: pre-emphasis coefficient + :param method: implementation; must be one of `conv` or `shift` + """ + super().__init__() + self.coef = coef + + if method not in ['conv', 'shift', None]: + raise ValueError(f'Invalid method {method}') + self.method = method + + # flip filter (cross-correlation --> convolution) + self.register_buffer( + 'flipped_filter', + torch.FloatTensor([-self.coef, 1.]).unsqueeze(0).unsqueeze(0) + ) + + def forward(self, x: torch.Tensor): + """ + Apply pre-emphasis filter via waveform convolution + """ + + assert x.ndim >= 2 # require batch dimension + n_batch, signal_length = x.shape[0], x.shape[-1] + + # require channel dimension for convolution + x = x.reshape(n_batch, -1, signal_length) + in_channels = x.shape[1] + + if self.method == 'conv': + + # reflect padding to match lengths of in/out + x = F.pad(x, (1, 0), 'reflect') + return F.conv1d( + x, + self.flipped_filter.repeat(in_channels, 1, 1), + groups=in_channels + ) + + elif self.method == 'shift': + + return torch.cat( + [ + x[..., 0:1], + x[..., 1:] - self.coef*x[..., :-1] + ], dim=-1) + + else: + return x + diff --git a/voicebox/src/preprocess/preprocessor.py b/voicebox/src/preprocess/preprocessor.py new file mode 100644 index 0000000000000000000000000000000000000000..d5d7885d123f0edcda8eba1250ce280601e85e5c --- /dev/null +++ b/voicebox/src/preprocess/preprocessor.py @@ -0,0 +1,51 @@ +import torch +import torch.nn as nn + +from typing import Iterable + +from src.simulation.component import Component + +################################################################################ +# Wrap preprocessing stages and apply sequentially +################################################################################ + + +class Preprocessor(nn.Module): + """ + Wrapper for sequential application of preprocessing stages. Allows for + straight-through gradient estimation. 
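The straight-through trick mentioned above uses a stage's output in the forward pass while letting gradients flow as if the stage were the identity; a minimal sketch of the pattern used in `Preprocessor.forward` below:

```python
import torch

x = torch.linspace(-1.0, 1.0, 5, requires_grad=True)
output = torch.sign(x)         # stand-in for a non-differentiable stage
y = x + (output - x).detach()  # forward: y == output; backward: dy/dx == 1
y.sum().backward()
print(x.grad)                  # tensor([1., 1., 1., 1., 1.])
```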
Because random parameter sampling is + not required, all modules are only required to be Component objects + """ + def __init__(self, *args): + super().__init__() + + stages = [] + + if len(args) == 1 and isinstance(args[0], Iterable): + for stage in args[0]: + assert isinstance(stage, Component), \ + "Arguments must be Component objects" + stages.append(stage) + else: + for stage in args: + assert isinstance(stage, Component), \ + "Arguments must be Component objects" + stages.append(stage) + + self.stages = nn.ModuleList(stages) + + def forward(self, x: torch.Tensor): + + # apply in sequence + for stage in self.stages: + + if stage.compute_grad: + x = stage(x) + + else: + # allow straight-through gradient estimation on backward pass + output = stage(x) + x = x + (output-x).detach() + + return x + diff --git a/voicebox/src/preprocess/vad.py b/voicebox/src/preprocess/vad.py new file mode 100644 index 0000000000000000000000000000000000000000..49899373cb3fd1f78836fd2efd3a06380a6cbaf2 --- /dev/null +++ b/voicebox/src/preprocess/vad.py @@ -0,0 +1,225 @@ +import torch +import math +import decimal + +from typing import List + +import torch.nn.functional as F + +from torchaudio.transforms import MFCC + +from src.simulation.component import Component + +################################################################################ +# Voice Activity Detection (VAD) +################################################################################ + + +class KaldiStyleVAD(Component): + """ + Kaldi-style Voice Activity Detection (VAD) module. Adapted from + https://github.com/fsepteixeira/FoolHD/blob/main/code/utils/vad_cmvn.py + """ + def __init__(self, + compute_grad: bool = True, + threshold: float = -15.0, + proportion_threshold: float = 0.12, + frame_len: float = 0.025, + hop_len: float = 0.010, + mean_scale: float = 0.5, + context: int = 2): + super().__init__(compute_grad) + + self.threshold = threshold + self.proportion_threshold = proportion_threshold + self.mean_scale = mean_scale + self.context = context + self.diff_zero = mean_scale != 0 + self.unfold_size = 2 * context + 1 + self.frame_len = int(frame_len * self.sample_rate) + self.hop_len = int(hop_len * self.sample_rate) + + # prepare to compute MFCC + self.mfcc = MFCC( + sample_rate=self.sample_rate, + n_mfcc=30, + dct_type=2, + norm='ortho', + log_mels=True, + melkwargs={ + 'n_fft': self.frame_len, + 'hop_length': self.hop_len, + 'n_mels': 30, + 'f_min': 20, + 'f_max': self.sample_rate // 2, + 'power': 2.0, + 'center': True + } + ) + + def forward(self, x: torch.Tensor): + + if x.shape[-1] < self.frame_len + self.hop_len: + return x + + # require batch dimension + assert x.ndim >= 2 + + # require mono audio, discard channel dimension + n_batch, slen = x.shape[0], x.shape[-1] + x = x.reshape(n_batch, slen) + + # compute MFCC + x_mfcc = self.mfcc(x).permute(0, 2, 1) # (n_batch, n_frames, n_mfcc) + + # set device for energy threshold + energy_threshold = torch.tensor([self.threshold]).to(x_mfcc.device) + + # first MFCC coefficient represents log energy + log_energy = x_mfcc[:, :, 0] + + if self.diff_zero: + energy_threshold = energy_threshold + self.mean_scale * log_energy.mean(dim=1) + + # prepare frame-wise mask + mask = torch.ones_like(log_energy) + + # pad borders with symmetric context before striding + mask = F.pad(mask, pad=(self.context, self.context), value=1.0) + + # get all (overlapping) context "windows" + mask = mask.unfold(dimension=1, size=self.unfold_size, step=1) + + # number of values included in each context window + 
den_count = mask.sum(dim=-1) + + # pad borders with symmetric context + log_energy = F.pad(log_energy, pad=(self.context, self.context)) + + # get all (overlapping) context "windows" + log_energy = log_energy.unfold( + dimension=1, + size=self.unfold_size, + step=1 + ) + + # number of values in each context window above threshold + num_count = log_energy.gt( + energy_threshold.unsqueeze(-1).unsqueeze(-1) + ).sum(dim=-1) + + # frame-by-frame mask + mask = num_count.ge(den_count*self.proportion_threshold) + + # "fold" to obtain waveform mask + mask_wav = mask.unsqueeze(-1).repeat_interleave( + repeats=self.frame_len, dim=-1 + ) + mask_wav = torch.cat( + [ + mask_wav[:, 0], + mask_wav[:, 1:][:, :, self.frame_len - self.hop_len:].reshape( + n_batch, -1 + ) + ], dim=-1 + ) + left_trim = self.frame_len // 2 + right_trim = mask_wav.shape[-1] - left_trim - x.shape[-1] + mask_wav = mask_wav[..., left_trim: -right_trim] + + # compute number of accepted samples per input waveform + samples_per_row: List[int] = [] + for e in torch.sum(mask_wav, dim=-1): + samples_per_row.append(e.item()) + + # split resulting tensor to keep trimmed inputs separate + split = torch.split(x[mask_wav], samples_per_row) + + # placeholder for outputs: (n_batch, 1, padded_length) + final = torch.zeros_like(x).unsqueeze(1) # pad to preserve length + + # concatenate and pad split views + for i, tensor in enumerate(split): + length = tensor.shape[-1] + final[i, :, :length] = tensor + + return final[..., :slen] + + +class VAD(Component): + """ + Apply Voice Activity Detection (VAD) while allowing for straight-through + gradient estimation. For now, only supports simple energy-based method, + and should be placed after normalization to avoid scale-dependence. + """ + def __init__(self, + compute_grad: bool = True, + frame_len: float = 0.05, + threshold: float = -72 + ): + + super().__init__(compute_grad) + + self.threshold = threshold + self.frame_len = int( + decimal.Decimal( + frame_len * self.sample_rate + ).quantize( + decimal.Decimal('1'), rounding=decimal.ROUND_HALF_UP + ) + ) # convert seconds to samples, round up + + def forward(self, x: torch.Tensor): + + # require batch dimension + assert x.ndim >= 2 + + # require mono audio, discard channel dimension + n_batch, slen = x.shape[0], x.shape[-1] + audio = x.reshape(n_batch, slen) + + eps = 1e-12 # numerical stability + + # determine number of frames + if slen <= self.frame_len: + n_frames = 1 + else: + n_frames = 1 + int( + math.ceil( + (1.0 * slen - self.frame_len) / self.frame_len) + ) + + # pad to integer frame length + padlen = int(n_frames * self.frame_len) + zeros = torch.zeros((x.shape[0], padlen - slen,)).to(x) + padded = torch.cat((audio, zeros), dim=-1) + + # obtain strided (frame-wise) view of audio + shape = (padded.shape[0], n_frames, self.frame_len) + frames = torch.as_strided( + padded, + size=shape, + stride=(padded.shape[-1], self.frame_len, 1) + ) + + # create frame-by-frame mask based on energy threshold + mask = 20 * torch.log10( + ((frames * self.scale).norm(dim=-1) / self.frame_len) + eps + ) > self.threshold + + # turn frame-by-frame mask into sample-by-sample mask + mask_wav = torch.repeat_interleave(mask, self.frame_len, dim=-1) + samples_per_row = torch.sum(mask, dim=-1) * self.frame_len + + split = torch.split(padded[mask_wav], tuple(samples_per_row)) + + # placeholder for outputs: (n_batch, 1, padded_length) + final = torch.zeros_like(padded).unsqueeze(1) # pad to preserve length + + # concatenate and pad split views + for i, tensor in 
+            length = tensor.shape[-1]
+            final[i, :, :length] = tensor
+
+        return final[..., :slen]
diff --git a/voicebox/src/simulation/__init__.py b/voicebox/src/simulation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c82d6962f07a70be1fe0936c8c3ad687d8f9369b
--- /dev/null
+++ b/voicebox/src/simulation/__init__.py
@@ -0,0 +1,9 @@
+from src.simulation.simulation import Simulation
+from src.simulation.noise import Noise
+from src.simulation.bandpass import Bandpass
+from src.simulation.reverb import Reverb
+from src.simulation.quantize import Quantize
+from src.simulation.clip import Clip
+from src.simulation.offset import Offset
+from src.simulation.dropout import Dropout
+from src.simulation.gain import Gain
diff --git a/voicebox/src/simulation/bandpass.py b/voicebox/src/simulation/bandpass.py
new file mode 100644
index 0000000000000000000000000000000000000000..784efa75b9ad93f35a38aea3cfe5cb3d63e0f271
--- /dev/null
+++ b/voicebox/src/simulation/bandpass.py
@@ -0,0 +1,112 @@
+import random
+import torch
+import torchaudio
+
+import torch.nn.functional as F
+
+from typing import Any
+
+from scipy.signal import firwin2
+
+from src.simulation.effect import Effect
+
+torchaudio.set_audio_backend("sox_io")
+
+################################################################################
+# Bandpass filter
+################################################################################
+
+
+class Bandpass(Effect):
+
+    def __init__(self, compute_grad: bool = True,
+                 low: Any = None,
+                 high: Any = None):
+        super().__init__(compute_grad)
+
+        self.min_low, self.max_low = self.parse_range(
+            low,
+            int,
+            f'Invalid cutoff frequency {low}'
+        )
+
+        self.min_high, self.max_high = self.parse_range(
+            high,
+            int,
+            f'Invalid cutoff frequency {high}'
+        )
+
+        if self.max_high > (self.sample_rate / 2) - 100:
+            raise ValueError(
+                f'Cutoff too close to Nyquist frequency '
+                f'{self.sample_rate / 2}Hz; may produce ringing')
+
+        # store impulse response as buffer to allow device movement
+        self.low, self.high = None, None
+        self.register_buffer("filter", torch.zeros(1, dtype=torch.float32))
+
+        # initialize filter
+        self.sample_params()
+
+    def forward(self, x: torch.Tensor):
+        """
+        Perform waveform convolution with FIR bandpass filter
+        """
+
+        # require batch and channel dimensions
+        n_batch, signal_length = x.shape[0], x.shape[-1]
+        x = x.reshape(n_batch, -1, signal_length)
+
+        pad = F.pad(x, (self.filter.shape[-1] - 1, 0))
+        return F.conv1d(pad, self.filter.clone().to(x))
+    def sample_params(self):
+        """
+        Sample cutoff frequencies, generate FIR lowpass and highpass filters,
+        convolve (with 'full' padding) to obtain a single FIR bandpass filter
+        """
+        self.low = random.uniform(self.min_low, self.max_low)
+        self.high = random.uniform(self.min_high, self.max_high)
+
+        n_taps = 257   # length of each FIR filter
+        width = 0.001  # width of filter transition band
+
+        freq_hp = [
+            0.0,
+            self.low / (1 + width),
+            self.low * (1 + width),
+            self.sample_rate / 2
+        ]
+        freq_lp = [
+            0.0,
+            self.high / (1 + width),
+            self.high * (1 + width),
+            self.sample_rate / 2
+        ]
+
+        gain_hp = [0.0, 0.0, 1.0, 1.0]
+        gain_lp = [1.0, 1.0, 0.0, 0.0]
+
+        hp = torch.as_tensor(
+            firwin2(
+                numtaps=n_taps,
+                freq=freq_hp,
+                gain=gain_hp,
+                fs=self.sample_rate
+            )
+        )
+        lp = torch.as_tensor(
+            firwin2(
+                numtaps=n_taps,
+                freq=freq_lp,
+                gain=gain_lp,
+                fs=self.sample_rate
+            )
+        )
+
+        # convolve lowpass and highpass responses ('full' padding) to obtain
+        # a single bandpass filter
+        self.filter = F.conv1d(
+            F.pad(
+                lp.flip([-1]).reshape(1, 1, -1),
+                (hp.shape[-1] - 1, hp.shape[-1] - 1)  # 'full' padding
+            ),
+            hp.flip([-1]).reshape(1, 1, -1)
+        ).flip([-1]).reshape(1, 1, -1).to(self.filter)
diff --git a/voicebox/src/simulation/clip.py b/voicebox/src/simulation/clip.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c062892fc2b643b80dc031dac66142e391464b7
--- /dev/null
+++ b/voicebox/src/simulation/clip.py
@@ -0,0 +1,36 @@
+import random
+import torch
+
+from typing import Any
+
+from src.simulation.effect import Effect
+
+################################################################################
+# Perform simple clipping at the waveform level
+################################################################################
+
+
+class Clip(Effect):
+
+    def __init__(self,
+                 compute_grad: bool = True,
+                 scale: Any = 1.0):
+        super().__init__(compute_grad)
+
+        # parse valid range of clipping scale parameter
+        self.min_scale, self.max_scale = self.parse_range(
+            scale,
+            float,
+            f'Invalid clipping scale {scale}'
+        )
+
+        # clipping thresholds must lie within the data's amplitude scale
+        assert 0 <= self.min_scale <= self.max_scale <= self.scale
+
+        self.clip_scale = None
+        self.sample_params()
+
+    def forward(self, x: torch.Tensor):
+        return torch.clamp(x, min=-self.clip_scale, max=self.clip_scale)
+
+    def sample_params(self):
+        self.clip_scale = random.uniform(self.min_scale, self.max_scale)
diff --git a/voicebox/src/simulation/component.py b/voicebox/src/simulation/component.py
new file mode 100644
index 0000000000000000000000000000000000000000..47ee3caaab817a59786c535ba02e0866abd1cd5e
--- /dev/null
+++ b/voicebox/src/simulation/component.py
@@ -0,0 +1,27 @@
+import torch
+import torch.nn as nn
+
+from src.data import DataProperties
+
+################################################################################
+# Base class for differentiable audio-processing units
+################################################################################
+
+
+class Component(nn.Module):
+    """
+    Base class for differentiable audio-processing units
+    """
+    def __init__(self, compute_grad: bool = True):
+        super().__init__()
+        self.compute_grad = compute_grad
+
+        # fetch persistent data properties
+        self.sample_rate, self.scale, self.signal_length = DataProperties.get(
+            'sample_rate',
+            'scale',
+            'signal_length'
+        )
+
+    def forward(self, x: torch.Tensor):
+        raise NotImplementedError()
diff --git a/voicebox/src/simulation/dropout.py b/voicebox/src/simulation/dropout.py
new file mode 100644
index
0000000000000000000000000000000000000000..954aac43a46e7a609e7f19e5e8323bef6b1bb7cd --- /dev/null +++ b/voicebox/src/simulation/dropout.py @@ -0,0 +1,38 @@ +import random +import torch + +from src.simulation.effect import Effect + +################################################################################ +# Random time-domain dropout +################################################################################ + + +class Dropout(Effect): + + def __init__(self, compute_grad: bool = True, rate: any = None): + + super().__init__(compute_grad) + + self.min_rate, self.max_rate = self.parse_range( + rate, + float, + f'Invalid signal dropout rate {rate}' + ) + + # store waveform mask as buffer to allow device movement + self.register_buffer("mask", torch.zeros(1, dtype=torch.float32)) + self.sample_params() + + def forward(self, x: torch.Tensor): + return self.mask.clone().to(x) * x + + def sample_params(self): + """ + Sample dropout rate uniformly and apply random dropout + """ + rate = random.uniform(self.min_rate, self.max_rate) + idx = torch.randperm(self.signal_length + )[:round(rate * self.signal_length)] + self.mask = torch.ones(self.signal_length).to(self.mask) + self.mask[..., idx] = 0 diff --git a/voicebox/src/simulation/effect.py b/voicebox/src/simulation/effect.py new file mode 100644 index 0000000000000000000000000000000000000000..8e0dbf9e310f5becae203b11ae28a0200496d321 --- /dev/null +++ b/voicebox/src/simulation/effect.py @@ -0,0 +1,55 @@ +import torch + +from typing import Any, Union, Sequence + +from src.simulation.component import Component + +################################################################################ +# Simulate environmental acoustic distortions in sequence +################################################################################ + + +class Effect(Component): + """ + Base class for all acoustic simulation effects units. Adds random parameter + sampling functionality to Component class. 
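
+    A minimal usage sketch of the range convention (values are illustrative
+    only, not defaults):
+
+        >>> Effect.parse_range((0.1, 0.9), float, 'invalid range')
+        (0.1, 0.9)
+        >>> Effect.parse_range(0.5, float, 'invalid range')
+        (0.5, 0.5)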
+ """ + def __init__(self, compute_grad: bool = True): + super().__init__(compute_grad) + + def forward(self, x: torch.Tensor): + raise NotImplementedError() + + def sample_params(self): + """ + Sample effect parameters to allow for expectation-over-transformation + """ + raise NotImplementedError() + + @staticmethod + def parse_range(params: Any, dtype: Any, error_msg: str): + """ + For real-valued parameters, obtain acceptable range of values from which + to sample randomly + """ + + # for any sequence, assume endpoints mark range of values + if isinstance(params, Sequence): + min_val, max_val = params[0], params[1] + + # if a single value is given, use as both "endpoints" + elif isinstance(params, dtype): + min_val = max_val = params + + else: + raise ValueError(error_msg) + + try: + assert isinstance(min_val, dtype) + assert isinstance(max_val, dtype) + + except AssertionError: + raise ValueError(error_msg) + + return min_val, max_val + diff --git a/voicebox/src/simulation/gain.py b/voicebox/src/simulation/gain.py new file mode 100644 index 0000000000000000000000000000000000000000..0bac0e1ecf3545e704b06d6cc8092901050e3605 --- /dev/null +++ b/voicebox/src/simulation/gain.py @@ -0,0 +1,30 @@ +import random +import torch + +from src.simulation.effect import Effect + +################################################################################ +# Simple gain scaling +################################################################################ + + +class Gain(Effect): + + def __init__(self, compute_grad: bool = True, level: any = None): + + super().__init__(compute_grad) + + self.min_level, self.max_level = self.parse_range( + level, + float, + f'Invalid gain {level}' + ) + + self.level = None + self.sample_params() + + def forward(self, x: torch.Tensor): + return x * self.level + + def sample_params(self): + self.level = random.uniform(self.min_level, self.max_level) diff --git a/voicebox/src/simulation/noise.py b/voicebox/src/simulation/noise.py new file mode 100644 index 0000000000000000000000000000000000000000..1c7483525661774b165180d7bc2c039c5245fdd1 --- /dev/null +++ b/voicebox/src/simulation/noise.py @@ -0,0 +1,162 @@ +import random +import math +import torch +import torch.nn.functional as F +import torchaudio + +from pathlib import Path + +import librosa as li +from src.simulation.effect import Effect + +torchaudio.set_audio_backend("sox_io") + +################################################################################ +# Simulate environmental noise +################################################################################ + + +class Noise(Effect): + """ + Simple additive noise effect + """ + def __init__(self, + compute_grad: bool = True, + type: str = 'gaussian', + snr: any = None, + noise_dir: str = None, + ext: str = "wav"): + """ + Apply additive noise to audio signal. 
+        SNR calculations adapted from VoxCeleb-Trainer
+        (https://github.com/clovaai/voxceleb_trainer/)
+
+        :param compute_grad: if False, perform straight-through gradient
+               estimation
+        :param type: type of noise to add; must be one of `gaussian`,
+               `uniform`, or `environmental`
+        :param snr: signal-to-noise ratio of added noise, in dB
+        :param noise_dir: directory from which to draw noise samples, if
+               `type` is `environmental`
+        :param ext: extension for audio files in `noise_dir`
+        """
+        super().__init__(compute_grad)
+
+        self.type = type
+        self.noise_list = None
+        self.ext = ext
+
+        if type == 'environmental':
+            if not noise_dir:
+                raise ValueError(
+                    'Environmental noise requires sample directory'
+                )
+            else:
+                self.noise_list = list(Path(noise_dir).rglob(f'*.{self.ext}'))
+
+        # parse valid range of SNR parameter
+        self.min_snr, self.max_snr = self.parse_range(
+            snr,
+            float,
+            f'Invalid noise SNR {snr}'
+        )
+
+        # store noise as buffer to allow device movement
+        self.register_buffer("noise", torch.zeros(1, dtype=torch.float32))
+        self.register_buffer("noise_db", torch.zeros(1, dtype=torch.float32))
+
+        # initialize parameters
+        self.snr = None
+        self.sample_params()
+
+    def forward(self, x: torch.Tensor):
+
+        # require batch, channel dimensions
+        assert x.ndim >= 2
+        orig_shape = x.shape
+
+        if x.ndim == 2:
+            x = x.unsqueeze(1)
+
+        # scale noise level to stored SNR
+        signal_db = 10 * torch.log10(
+            torch.mean(torch.square(x), dim=-1, keepdim=True) + 1e-8
+        )
+        scale = torch.sqrt(
+            torch.pow(10, (signal_db - self.noise_db - self.snr) / 10)
+        )
+
+        # scale noise and trim to input length
+        noise = scale * self.noise.clone().to(x)[..., :x.shape[-1]]
+
+        # repeat noise to match input length if necessary
+        pad_len = max(x.shape[-1] - noise.shape[-1], 0)
+        noise = F.pad(noise, (0, pad_len), mode='circular')
+
+        # reshape to original dimensions
+        return (noise + x).reshape(orig_shape)
+
+    @staticmethod
+    def _crossfade(sig, fade_len):
+        sig = sig.clone()
+        fade_len = int(fade_len * sig.shape[-1])
+        fade_in = torch.linspace(0, 1, fade_len).to(sig)
+        fade_out = torch.linspace(1, 0, fade_len).to(sig)
+        sig[..., :fade_len] *= fade_in
+        sig[..., -fade_len:] *= fade_out
+        return sig
+
+    def sample_params(self):
+        """
+        Sample SNR uniformly from stored range
+        """
+        self.snr = random.uniform(self.min_snr, self.max_snr)
+
+        if self.type == "gaussian":
+            self.noise = torch.randn(self.signal_length).to(self.noise)
+        elif self.type == "uniform":
+            self.noise = torch.sign(
+                torch.randn(self.signal_length)
+            ).to(self.noise)
+        elif self.type == "environmental":
+
+            # load from randomly-selected file
+            noise_np, _ = li.load(
+                random.choice(self.noise_list),
+                sr=self.sample_rate, mono=True
+            )
+            noise = torch.as_tensor(noise_np)
+
+            # trim or loop (with cross-fade) to match expected signal length
+            if noise.shape[-1] >= self.signal_length:
+                self.noise = noise[..., :self.signal_length].reshape(
+                    1, 1, -1
+                ).to(self.noise)
+            else:
+
+                overlap = 0.05
+                step = math.ceil(noise.shape[-1] * (1 - overlap))
+                n_repeat = math.ceil(self.signal_length / step)
+
+                padded = torch.zeros(
+                    1, step * (n_repeat - 1) + noise.shape[-1] + 1
+                ).reshape(1, -1).type(torch.float32)
+                shape = padded.shape[:-1] + (n_repeat, noise.shape[-1])
+
+                strides = (padded.stride()[0],) + (step, padded.stride()[-1],)
+                frames = torch.as_strided(
+                    padded, size=shape, stride=strides
+                )
+
+                # overlap-add cross-faded copies of the noise sample
+                for j in range(n_repeat):
+                    frames[:, j, :] += self._crossfade(noise, overlap)
+
+                self.noise = padded[...,
:self.signal_length].reshape( + 1, 1, -1 + ).to(self.noise) + + else: + raise ValueError(f'Invalid noise type {self.type}') + + self.noise_db = 10 * torch.log10( + torch.mean(torch.square(self.noise), dim=-1, keepdims=True) + 1e-8 + ).to(self.noise_db) diff --git a/voicebox/src/simulation/offset.py b/voicebox/src/simulation/offset.py new file mode 100644 index 0000000000000000000000000000000000000000..b180a30d97e76c178d3b2d1c19755fc90069b27e --- /dev/null +++ b/voicebox/src/simulation/offset.py @@ -0,0 +1,42 @@ +import random +import torch + +from src.simulation.effect import Effect + +################################################################################ +# Random time-domain offset +################################################################################ + + +class Offset(Effect): + + def __init__(self, compute_grad: bool = True, length: any = None): + """ + Shift audio and trim/zero-pad to maintain length + + :param compute_grad: if False, use straight-through gradient estimator + :param length: offset length in seconds; sign indicates direction + """ + super().__init__(compute_grad) + + self.min_length, self.max_length = self.parse_range( + length, + float, + f'Invalid offset length {length}' + ) + + self.length = None + self.sample_params() + + def forward(self, x: torch.Tensor): + shifted = torch.roll(x, shifts=self.length, dims=-1) + if self.length >= 0: + shifted[..., :self.length] = 0 + else: + shifted[..., self.length:] = 0 + return shifted + + def sample_params(self): + self.length = round( + random.uniform(self.min_length, self.max_length) * self.sample_rate + ) diff --git a/voicebox/src/simulation/quantize.py b/voicebox/src/simulation/quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..72cabedc3c69a1cbe98201e6f15395c8187f3770 --- /dev/null +++ b/voicebox/src/simulation/quantize.py @@ -0,0 +1,35 @@ +import random +import torch + +from src.simulation.effect import Effect + +################################################################################ +# Simulate simple quantization distortions +################################################################################ + + +class Quantize(Effect): + + def __init__(self, bits: any = 8): + super().__init__(compute_grad=False) + + self.min_bits, self.max_bits = self.parse_range( + bits, + int, + f'Invalid bit depth {bits}' + ) + self.bits = None + self.sample_params() + + def forward(self, x: torch.Tensor): + + # rescale full range to -2^(bits - 1), 2^(bits - 1) + scale_factor = 2 ** (self.bits - 1) + x_scaled = x * scale_factor / self.scale + x_quant = torch.round(x_scaled) + return x_quant * self.scale / scale_factor + + def sample_params(self): + self.bits = round( + random.uniform(self.min_bits, self.max_bits) + ) diff --git a/voicebox/src/simulation/reverb.py b/voicebox/src/simulation/reverb.py new file mode 100644 index 0000000000000000000000000000000000000000..b2b39e4ffc6252540f9e27e73f2ecd175f2ac844 --- /dev/null +++ b/voicebox/src/simulation/reverb.py @@ -0,0 +1,97 @@ +import random +import torch + +import torch.fft as fft +import torch.nn.functional as F +import librosa as li + +from pathlib import Path + +from src.simulation.effect import Effect + +################################################################################ +# Convolutional reverb effect +################################################################################ + + +class Reverb(Effect): + """ + Apply impulse responses sampled from a given dataset + """ + def __init__(self, + compute_grad: 
bool = True, + rir_dir: str = None, + ext: str = "wav", + fft_convolve: bool = True + ): + super().__init__(compute_grad) + self.rir_dir = rir_dir + self.ext = ext + self.fft_convolve = fft_convolve + + self.rir_list = list(Path(self.rir_dir).rglob(f'*.{self.ext}')) + + # store room impulse response as buffer to allow device movement + self.register_buffer("rir", torch.zeros(1, dtype=torch.float32)) + + # initialize RIR + self.sample_params() + + @staticmethod + def _fft_convolve(signal: torch.Tensor, kernel: torch.Tensor): + + # ensure signal and kernel have channel dimension + signal = signal.reshape(signal.shape[0], -1) + kernel = kernel.reshape(kernel.shape[0], -1) + + signal_len = signal.shape[-1] + kernel_len = kernel.shape[-1] + kernel = F.pad( + kernel, (0, signal_len - kernel_len) + ) + + signal = F.pad(signal, (0, signal.shape[-1])) + kernel = F.pad(kernel, (kernel.shape[-1], 0)) + + output = fft.irfft(fft.rfft(signal) * fft.rfft(kernel)) + output = output[..., output.shape[-1] // 2:] + + return output.unsqueeze(1) + + def forward(self, x: torch.Tensor): + + n_batch = x.shape[0] + if len(x.shape) < 3: + x = x.reshape(n_batch, 1, -1) + + if self.fft_convolve: + return self._fft_convolve(x, self.rir.clone().to(x)) + else: + pad = F.pad(x, (self.rir.shape[-1]-1, 0)) + return F.conv1d(pad, self.rir.clone().to(x)) + + def sample_params(self): + """ + Sample and preprocess room impulse response + """ + + rir_np, _ = li.load(random.choice(self.rir_list), + sr=self.sample_rate, mono=True) + rir = torch.as_tensor(rir_np) + + # trim leading silence + offsets = torch.where(torch.abs(rir) > (torch.abs(rir).max() / 100))[0] + rir = rir[..., offsets[0]:] + + # trim to signal length + rir = rir[..., :self.signal_length] + + # normalize + rir = rir / torch.norm(rir, p=2) + + # flip if waveform convolution + if not self.fft_convolve: + self.rir = torch.flip(rir, [-1]).reshape(1, 1, -1).to(self.rir) + else: + self.rir = rir.reshape(1, 1, -1).to(self.rir) + diff --git a/voicebox/src/simulation/simulation.py b/voicebox/src/simulation/simulation.py new file mode 100644 index 0000000000000000000000000000000000000000..9b8e588cf4d194fee4647e4914beba1e319499b0 --- /dev/null +++ b/voicebox/src/simulation/simulation.py @@ -0,0 +1,53 @@ +import torch +import torch.nn as nn + +from typing import Iterable + +from src.simulation.effect import Effect + +################################################################################ +# Wrap effects units to apply in sequence +################################################################################ + + +class Simulation(nn.Module): + """ + Wrapper for sequential application of effects units. Allows for straight- + through gradient estimation and random effect parameter sampling. 
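
+    Example (an illustrative sketch; the effect settings shown are
+    assumptions, and `DataProperties` must be configured beforehand):
+
+        >>> sim = Simulation(
+        ...     Gain(level=(0.5, 1.0)),
+        ...     Noise(type='gaussian', snr=(20.0, 40.0)),
+        ... )
+        >>> sim.sample_params()   # draw fresh random effect parameters
+        >>> y = sim(x)            # apply effects in sequence to waveform x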
+ """ + def __init__(self, *args): + super().__init__() + + effects = [] + + if len(args) == 1 and isinstance(args[0], Iterable): + for effect in args[0]: + assert isinstance(effect, Effect), \ + "Arguments must be Effect objects" + effects.append(effect) + else: + for effect in args: + assert isinstance(effect, Effect), \ + "Arguments must be Effect objects" + effects.append(effect) + + self.effects = nn.ModuleList(effects) + + def forward(self, x: torch.Tensor): + + for effect in self.effects: + + if effect.compute_grad: + x = effect(x) + + else: + # allow straight-through gradient estimation on backward pass + output = effect(x) + x = x + (output-x).detach() + + return x + + def sample_params(self): + + for effect in self.effects: + effect.sample_params() diff --git a/voicebox/src/utils/__init__.py b/voicebox/src/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7709e09f2da1aab4f1fd018bf33256f767c45abe --- /dev/null +++ b/voicebox/src/utils/__init__.py @@ -0,0 +1,24 @@ +from src.utils.filesystem import ( + ensure_dir, + ensure_dir_for_filename +) +from src.utils.device import ( + get_cuda_device_ids, + wrap_module_multi_gpu, + wrap_attack_multi_gpu, + wrap_pipeline_multi_gpu, + unwrap_module_multi_gpu, + unwrap_attack_multi_gpu, + unwrap_pipeline_multi_gpu, + DataParallelWrapper, +) +from src.utils.analysis import * +from src.utils.data import ( + text_to_tensor, + padded_transcript_length, + dataset_to_device, + create_embedding_dataset, + create_transcription_dataset +) +from src.utils.plotting import * +from src.utils.writer import * diff --git a/voicebox/src/utils/__pycache__/__init__.cpython-310.pyc b/voicebox/src/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93564a3a6c4396a84346c1757e37bac8efa1edf4 Binary files /dev/null and b/voicebox/src/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/voicebox/src/utils/__pycache__/__init__.cpython-39.pyc b/voicebox/src/utils/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..76a9cc49039c579f05df94477da12ce74bba3a92 Binary files /dev/null and b/voicebox/src/utils/__pycache__/__init__.cpython-39.pyc differ diff --git a/voicebox/src/utils/__pycache__/analysis.cpython-310.pyc b/voicebox/src/utils/__pycache__/analysis.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a8fb4aa06ef7603ceb95259e991862954b2101eb Binary files /dev/null and b/voicebox/src/utils/__pycache__/analysis.cpython-310.pyc differ diff --git a/voicebox/src/utils/__pycache__/analysis.cpython-39.pyc b/voicebox/src/utils/__pycache__/analysis.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b74d0bd4ee36d69f632ee631984c1da9491af61 Binary files /dev/null and b/voicebox/src/utils/__pycache__/analysis.cpython-39.pyc differ diff --git a/voicebox/src/utils/__pycache__/data.cpython-310.pyc b/voicebox/src/utils/__pycache__/data.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..57c09858b51d05429fe454f84a547bdf3aaa6a14 Binary files /dev/null and b/voicebox/src/utils/__pycache__/data.cpython-310.pyc differ diff --git a/voicebox/src/utils/__pycache__/data.cpython-39.pyc b/voicebox/src/utils/__pycache__/data.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14cb72b044f3341a95743a043b3e58f42e90ceb6 Binary files /dev/null and b/voicebox/src/utils/__pycache__/data.cpython-39.pyc differ diff --git 
a/voicebox/src/utils/__pycache__/device.cpython-310.pyc b/voicebox/src/utils/__pycache__/device.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca4b69983c8bfc39601960ec7e47a0e8a26a8081 Binary files /dev/null and b/voicebox/src/utils/__pycache__/device.cpython-310.pyc differ diff --git a/voicebox/src/utils/__pycache__/device.cpython-39.pyc b/voicebox/src/utils/__pycache__/device.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee6ddb52a84c2f9c2aad92f5b75c3833bc887e2c Binary files /dev/null and b/voicebox/src/utils/__pycache__/device.cpython-39.pyc differ diff --git a/voicebox/src/utils/__pycache__/filesystem.cpython-310.pyc b/voicebox/src/utils/__pycache__/filesystem.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd622a667bda0a81493ed5d9b7709e1328348901 Binary files /dev/null and b/voicebox/src/utils/__pycache__/filesystem.cpython-310.pyc differ diff --git a/voicebox/src/utils/__pycache__/filesystem.cpython-39.pyc b/voicebox/src/utils/__pycache__/filesystem.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ace69a88448f1cfc153cdbe9bddbd3b0cc6257c1 Binary files /dev/null and b/voicebox/src/utils/__pycache__/filesystem.cpython-39.pyc differ diff --git a/voicebox/src/utils/__pycache__/plotting.cpython-310.pyc b/voicebox/src/utils/__pycache__/plotting.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..526145c9bec41f54d76ef05061385b83c995804b Binary files /dev/null and b/voicebox/src/utils/__pycache__/plotting.cpython-310.pyc differ diff --git a/voicebox/src/utils/__pycache__/plotting.cpython-39.pyc b/voicebox/src/utils/__pycache__/plotting.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..343d94969d669e49d1ed8e61cf7cba03e5e1a456 Binary files /dev/null and b/voicebox/src/utils/__pycache__/plotting.cpython-39.pyc differ diff --git a/voicebox/src/utils/__pycache__/writer.cpython-310.pyc b/voicebox/src/utils/__pycache__/writer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30dee757371a32f156db9eac452d4b8a0549fb93 Binary files /dev/null and b/voicebox/src/utils/__pycache__/writer.cpython-310.pyc differ diff --git a/voicebox/src/utils/__pycache__/writer.cpython-39.pyc b/voicebox/src/utils/__pycache__/writer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44456438e41f3f421e5e8fd72733cfc67e9ff73f Binary files /dev/null and b/voicebox/src/utils/__pycache__/writer.cpython-39.pyc differ diff --git a/voicebox/src/utils/analysis.py b/voicebox/src/utils/analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..f663ba0c77abf536440a2c38a81b0b4168ff5659 --- /dev/null +++ b/voicebox/src/utils/analysis.py @@ -0,0 +1,417 @@ +import math +import torch +import pandas as pd +import numpy as np + +from typing import List, Union + +from pesq import pesq +from pystoi import stoi +from tqdm import tqdm + +from src.data import DataProperties +from src.utils.plotting import tensor_to_np +from src.models.speech import Wav2Vec2, GreedyCTCDecoder + +################################################################################ +# Utilities for analyzing attack results +################################################################################ + + +@torch.no_grad() +def run_perceptual_evaluation(x: torch.Tensor, + x_adv: torch.Tensor, + batch_size: int = 1, + device: Union[str, torch.cuda.device] = 'cpu', + tag: str 
= None, + **kwargs + ): + """ + Compute perceptual quality metrics on pairs of clean and adversarial audio + + Parameters + ---------- + x (Tensor): shape + + x_adv (Tensor): shape + + batch_size (int): + + device (str): + + Returns + ------- + + + """ + + # check for compatible audio dimensions + assert x.ndim == x_adv.ndim + + # require batch dimension + assert x.ndim >= 2 + n_batch = x.shape[0] + + # store results + results = {} + + # name results + tag = '' if tag is None else f'{tag}-' + + ############################################################################ + # WAVEFORM P-NORM DISTANCE + ############################################################################ + + # if dimensions match, measure L-2 and L-inf distance between waveforms + if x.shape == x_adv.shape: + + reduce_dims = tuple(range(1, x.ndim)) + + l2 = (x - x_adv).norm( + p=2, dim=reduce_dims).flatten().tolist() + linf = (x - x_adv).norm( + p=float('inf'), dim=reduce_dims).flatten().tolist() + + results = { + **results, + tag + 'L2-Waveform': l2, + tag + 'Linf-Waveform': linf + } + + ############################################################################ + # PESQ OBJECTIVE MEASURE (DEPRECATED) + ############################################################################ + + assert DataProperties.get('sample_rate') in [16000, 8000], \ + f"Cannot perform PESQ evaluation with sample rate " \ + f"{DataProperties.get('sample_rate')}Hz; must be 8000Hz or 16000Hz" + + wb_scores, nb_scores = [], [] + + for i in tqdm(range(n_batch), desc='computing PESQ scores'): + + wb_scores.append( + pesq(DataProperties.get('sample_rate'), + tensor_to_np(x[i]).flatten(), + tensor_to_np(x_adv[i]).flatten(), + 'wb') + ) + nb_scores.append( + pesq(DataProperties.get('sample_rate'), + tensor_to_np(x[i]).flatten(), + tensor_to_np(x_adv[i]).flatten(), + 'nb') + ) + + results = { + **results, + tag + 'PESQ-Wideband': wb_scores, + tag + 'PESQ-Narrowband': nb_scores, + } + + ############################################################################ + # STOI OBJECTIVE MEASURE (DEPRECATED) + ############################################################################ + + cl_scores, ex_scores = [], [] + + for i in tqdm(range(n_batch), desc='computing STOI scores'): + + cl_scores.append( + stoi(tensor_to_np(x[i]).flatten(), + tensor_to_np(x_adv[i]).flatten(), + DataProperties.get('sample_rate'), + extended=False) + ) + ex_scores.append( + stoi(tensor_to_np(x[i]).flatten(), + tensor_to_np(x_adv[i]).flatten(), + DataProperties.get('sample_rate'), + extended=True) + ) + + results = { + **results, + tag + 'STOI-Extended': cl_scores, + tag + 'STOI-Classical': ex_scores, + } + + ############################################################################ + # BSS-EVAL SIGNAL METRICS + ############################################################################ + + si_sdr, sd_sdr, snr, srr = [], [], [], [] + + for i in tqdm(range(n_batch), desc='computing BSS-EVAL metrics'): + + si_sdr_i, sd_sdr_i, snr_i, srr_i = _bss_eval( + tensor_to_np(x_adv[i]).flatten(), + tensor_to_np(x[i]).flatten()) + + si_sdr.append(si_sdr_i) + sd_sdr.append(sd_sdr_i) + snr.append(snr_i) + srr.append(srr_i) + + results = { + **results, + tag + 'SI-SDR': si_sdr, + tag + 'SD-SDR': sd_sdr, + tag + 'SNR': snr, + tag + 'SRR': srr + } + + ############################################################################ + # ASR TRANSCRIPTION METRICS + ############################################################################ + + # initialize ASR model / decoder + model = Wav2Vec2() 
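+    # note: the clean-audio transcription serves as the reference below, so
+    # WER/CER measure how strongly the adversarial perturbation changes the
+    # ASR output rather than agreement with ground-truth text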
+ decoder = GreedyCTCDecoder(labels=model.labels) + + # obtain delimiter token + delimiter = decoder.get_labels()[decoder.get_sep_idx()] + + # move model to given device + model.to(device) + + # store original and adversarial transcriptions + transcriptions = [] + transcriptions_adv = [] + + n_batches = math.ceil(len(x) / batch_size) + for i in tqdm(range(n_batches), desc='computing WER/CER'): + + # move batches to device and pass to model + x_batch = x[batch_size*i:batch_size*(i+1)].to(device) + x_adv_batch = x_adv[batch_size*i:batch_size*(i+1)].to(device) + + emit_batch = model(x_batch) + emit_adv_batch = model(x_adv_batch) + + # decode sequence probability emissions to obtain string transcriptions + transcriptions.extend(decoder(emit_batch)[0]) + transcriptions_adv.extend(decoder(emit_adv_batch)[0]) + + # ASR WER + wer = compute_wer(transcriptions, transcriptions_adv, delimiter) + + # ASR CER + cer = compute_cer(transcriptions, transcriptions_adv, delimiter) + + results = { + **results, + tag + 'ASR-WER': wer, + tag + 'ASR-CER': cer, + } + + return results + + +def compute_wer( + reference: List[str], + transcription: List[str], + delimiter: str = ' '): + """ + Compute average word error rate (WER) between string transcriptions. + + WER = (Sw + Dw + Iw) / Nw + + where: + Sw is the number of words substituted, + Dw is the number of words deleted, + Iw is the number of words inserted, + Nw is the number of words in the reference + + Parameters + ---------- + + Returns + ------- + + """ + + assert len(reference) == len(transcription) + + # for each reference-transcription pair in batch, count errors of each of + # the four types as well as total word count + + total_edit_dist = 0 + total_ref_len = 0 + + for r, t in zip(reference, transcription): + + edit_dist, ref_len = _word_errors(r, t, delimiter=delimiter) + + if ref_len == 0: + raise ValueError("Reference sentences must nonzero word count") + + total_edit_dist += edit_dist + total_ref_len += ref_len + + wer = float(total_edit_dist) / total_ref_len + return wer + + +def compute_cer( + reference: List[str], + transcription: List[str], + delimiter: str = ' ', + remove_delimiter: bool = False): + """ + Compute average character error rate (CER) between string transcriptions. + + WER = (Sc + Dc + Ic) / Nc + + where: + Sc is the number of characters substituted, + Dc is the number of characters deleted, + Ic is the number of characters inserted, + Nc is the number of characters in the reference + + Parameters + ---------- + + Returns + ------- + + """ + + assert len(reference) == len(transcription) + + # for each reference-transcription pair in batch, count errors of each of + # the four types as well as total character count + + total_edit_dist = 0 + total_ref_len = 0 + + for r, t in zip(reference, transcription): + + edit_dist, ref_len = _char_errors(r, + t, + delimiter, + remove_delimiter) + + if ref_len == 0: + raise ValueError("Reference sentences must nonzero character count") + + total_edit_dist += edit_dist + total_ref_len += ref_len + + cer = float(total_edit_dist) / total_ref_len + return cer + + +def _word_errors(reference: str, transcription: str, delimiter: str = ' '): + """ + Compute the Levenshtein distance between reference and transcription + sequences at word level. 
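
+    Example (illustrative):
+
+        >>> _word_errors('the cat sat', 'the cat sat down')
+        (1.0, 3)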
+ """ + + reference = reference.lower() + transcription = transcription.lower() + + ref_words = reference.split(delimiter) + tra_words = transcription.split(delimiter) + + edit_distance = _levenshtein_distance(ref_words, tra_words) + return float(edit_distance), len(ref_words) + + +def _char_errors(reference: str, + transcription: str, + delimiter: str = ' ', + remove_delimiter: bool = False + ): + """ + Compute the Levenshtein distance between reference and transcription + sequences at word level. + """ + + reference = reference.lower() + transcription = transcription.lower() + + join_char = delimiter + if remove_delimiter: + join_char = '' + + reference = join_char.join(filter(None, reference.split(delimiter))) + transcription = join_char.join(filter(None, transcription.split(delimiter))) + + edit_distance = _levenshtein_distance(reference, transcription) + return float(edit_distance), len(reference) + + +def _levenshtein_distance(reference: Union[List[str], str], + transcription: Union[List[str], str]): + """Levenshtein distance is a string metric for measuring the difference + between two sequences. Informally, the levenshtein disctance is defined as + the minimum number of single-character edits (substitutions, insertions or + deletions) required to change one word into the other. We can naturally + extend the edits to word level when calculate levenshtein disctance for + two sentences. + """ + m = len(reference) + n = len(transcription) + + # special cases + if reference == transcription: + return 0 + if m == 0: + return n + if n == 0: + return m + if m < n: + reference, transcription = transcription, reference + m, n = n, m + + # use O(min(m, n)) space + distance = np.zeros((2, n + 1), dtype=np.int32) + + # initialize distance matrix + for j in range(0, n + 1): + distance[0][j] = j + + # calculate Levenshtein distance + for i in range(1, m + 1): + prev_row_idx = (i - 1) % 2 + cur_row_idx = i % 2 + distance[cur_row_idx][0] = i + for j in range(1, n + 1): + if reference[i - 1] == transcription[j - 1]: + distance[cur_row_idx][j] = distance[prev_row_idx][j - 1] + else: + s_num = distance[prev_row_idx][j - 1] + 1 + i_num = distance[cur_row_idx][j - 1] + 1 + d_num = distance[prev_row_idx][j] + 1 + distance[cur_row_idx][j] = min(s_num, i_num, d_num) + + return distance[m % 2][n] + + +def _bss_eval(x, x_ref): + + x_ref_energy = (x_ref ** 2).sum() + + alpha = (x_ref @ x / x_ref_energy) + + e_true = x_ref + e_res = x - e_true + + signal = (e_true ** 2).sum() + noise = (e_res ** 2).sum() + + snr = 10 * np.log10(signal / noise) + + e_true = x_ref * alpha + e_res = x - e_true + + signal = (e_true ** 2).sum() + noise = (e_res ** 2).sum() + + si_sdr = 10 * np.log10(signal / noise) + + srr = -10 * np.log10((1 - (1/alpha)) ** 2) + sd_sdr = snr + 10 * np.log10(alpha ** 2) + + return si_sdr, sd_sdr, snr, srr diff --git a/voicebox/src/utils/data.py b/voicebox/src/utils/data.py new file mode 100644 index 0000000000000000000000000000000000000000..adb8db961e87599e9c5908e1cdb8a70cf6ead6cb --- /dev/null +++ b/voicebox/src/utils/data.py @@ -0,0 +1,1059 @@ +import math +from copy import deepcopy +import torch +import torch.nn as nn +import torch.nn.functional as F +import random +import hashlib + +from torch.utils.data import Dataset, DataLoader, TensorDataset +from torch.nn.utils.rnn import pad_sequence + +from typing import Union, List, Sized, Iterable + +from tqdm import tqdm + +from src.utils.filesystem import ensure_dir +from src.utils.device import DataParallelWrapper +from src.pipelines.pipeline import 
Pipeline +from src.models.speaker.speaker import SpeakerVerificationModel +from src.models.speech.speech import SpeechRecognitionModel +from src.constants import * + +################################################################################ +# Data-loading utilities +################################################################################ + + +class DatasetWrapper(Dataset): + """ + Most data utilities here involve re-assigning or computing targets to + train or evaluate adversarial attacks. This class wraps an existing + dataset to overwrite its stored inputs and targets as necessary. + """ + def __init__(self, dataset, inputs, targets): + + super().__init__() + + self.dataset = dataset + self.inputs = inputs + self.targets = targets + + ref_batch = next(iter(dataset)) + + if isinstance(ref_batch, tuple): + self.format = 'tuple' + else: + self.format = 'dict' + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, idx): + + if self.format == 'tuple': + x, y, *features = self.dataset[idx] + return self.inputs[idx], self.targets[idx], *features + + else: + batch = self.dataset[idx] + batch['x'] = self.inputs[idx] + batch['y'] = self.targets[idx] + return batch + + +def pad_batch_power_2(batch): + """ + Given a batch of tensors, pad to nearest power of 2 to maximum length. Used + as a `collate_fn` argument to Pytorch `DataLoader` objects. + + Parameters: + ----------- + batch + + Returns: + -------- + """ + + # get tensors + (x, y) = zip(*batch) + + n_batch = len(x) + + if n_batch < 1: + return torch.Tensor([]), None + + if n_batch == 1: + return x[0:1], y + + if type(y[0]) != str: + y = torch.stack(y, dim=0) + + # compute maximum length + dtype, device = x[0].dtype, x[0].device + lengths = [x_i.shape[-1] for x_i in x] + max_len = max(lengths) + next_pow_2 = 2**(max_len - 1).bit_length() + + # pad inputs + shape = next(iter(x)).shape[1:-1] + batch_padded = torch.zeros( + (n_batch, *shape, next_pow_2), + dtype=dtype, + device=device + ) + + for i in range(n_batch): + batch_padded[i, ..., :lengths[i]] = x[i] + + return batch_padded, y + + +def text_to_tensor( + text: Union[str, List[str]], + labels: list, + return_lengths: bool = True, + max_length: int = None, + padding_value: int = -1): + """ + Convert one or more string transcripts to padded tensor form (character + indices), and optionally return sequence lengths. 
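
+    Example (illustrative; the label set shown is an assumption):
+
+        >>> labels = [' ', 'A', 'B', 'C']
+        >>> seqs, lengths = text_to_tensor(['AB', 'CAB A'], labels)
+        >>> seqs.shape, lengths.tolist()
+        (torch.Size([2, 5]), [2, 5])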
+ + Parameters: + ----------- + text (str): a string or list of string transcripts + + labels (list): list of characters, ordered by index + + return_lengths (bool): if True, return sequence lengths + + max_length (int): if given, trim/pad all sequences to length + + padding_value (int): value with which to perform length padding + + Returns: + -------- + sequences (Tensor): tensor containing padded index sequences + + lengths (Tensor): tensor containing sequence lengths + """ + + if isinstance(text, str): + text = [text] + + # convert from characters to token indices + char_to_idx = {labels[i].upper(): i for i in range(len(labels))} + + lengths = [] + tensors = [] + + for t in text: + + lengths.append(len(t)) + token_indices = [char_to_idx[c] for c in t.upper()] + tensors.append( + torch.as_tensor(token_indices, dtype=torch.long) + ) + + # pad and return + tensors = pad_sequence( + tensors, + batch_first=True, + padding_value=padding_value + ) # (n_batch, max_seq_len) + + if max_length is not None: + if tensors.shape[-1] > max_length: + tensors = tensors[..., :max_length] + elif tensors.shape[-1] < max_length: + tensors = F.pad( + tensors, + (0, max_length - tensors.shape[-1]), + value=padding_value) + + lengths = torch.as_tensor(lengths, dtype=torch.long) + + if return_lengths: + return tensors, lengths + else: + return tensors + + +def padded_transcript_length( + transcript: torch.Tensor, + padding_value: int = -1): + """ + Given one or more transcripts in index sequence format, determine lengths + by looking for padding value. + + Parameters: + ----------- + transcript (Tensor): tensor containing one or more index sequences + + padding_value (int): value used to pad sequence tensors + + Returns: + -------- + lengths (Tensor): tensor containing un-padded length of each sequence + """ + + # find first occurence of padding value in each transcript tensor + mask = transcript == padding_value + mask_max_values, mask_max_indices = torch.max(mask, dim=-1) + + # if the max-mask is zero, there is no padding in the tensor + mask_max_indices[mask_max_values == 0] = transcript.shape[-1] + + return mask_max_indices.long() + + +def move_to_device_recursive(d: dict, device: Union[str, torch.device]): + """Move all tensors in a dictionary object to given device""" + for k, v in d.items(): + if isinstance(v, dict): + d[k] = move_to_device_recursive(v, device) + elif isinstance(v, torch.Tensor): + d[k] = v.to(device) + elif isinstance(v, tuple) and len(v) > 0 and isinstance( + next(iter(v)), torch.Tensor): + v_new = tuple() + for v_i in v: + v_new += (v_i.to(device),) + d[k] = v_new + elif isinstance(v, list) and len(v) > 0 and isinstance( + next(iter(v)), torch.Tensor): + v_new = [] + for v_i in v: + v_new += [v_i.to(device)] + d[k] = v_new + return d + + +def dataset_to_device(data: Dataset, device: Union[str, torch.device]): + """Move datasets directly to given device. 
May cause memory issues!""" + + data.__dict__ = move_to_device_recursive(data.__dict__, device) + + +@torch.no_grad() +def create_embedding_dataset(data_train: Dataset, + data_test: Dataset, + pipeline: Pipeline, + select_train: str = 'random', + select_test: str = 'random', + targeted: bool = True, + target_class: int = None, + num_per_class_train: int = None, + num_per_class_test: int = None, + num_embeddings_train: int = 5, + exclude_class: Union[int, List] = None, + exclude_success: bool = False, + use_cache: bool = False, + shuffle: bool = False, + **kwargs + ): + """ + Given training and test datasets holding audio and labels, compute + embeddings according to (potentially reassigned) labels. It is assumed + that the provided train and test targets both index the same set of classes, + i.e. that label 0 in the train set refers to the same class as label 0 in + the test set. This function considers three possible cases: + + 1. `targeted` == True, `target_class` != None + + In this case, a single target class from the training set is assigned + to all training and test instances. Embeddings corresponding to the + target class are computed and assigned based on the `select` parameter. + + 2. `targeted` == True, `target_class` == None + + In this case, targets are randomly reassigned within both the training + and test sets. Embeddings corresponding to the targets are computed and + assigned based on the `select` parameter. + + 3. `targeted` == False + + In this case, targets remain unchanged and the embedding of each + instance is computed directly. + + Parameters + ---------- + data_train (Dataset): a Dataset object holding audio and labels + + data_test (Dataset): a Dataset object holding audio and labels + + pipeline (Pipeline): a Pipeline object; must wrap a + SpeakerVerificationModel object + + select_train (str): method of selecting target embeddings for train + set; must be one of `random`, `single`, 'same', + `centroid`, or 'median' + + select_test (str): method of selecting target embeddings for test + set; must be one of `random`, `single`, same', + `centroid`, or 'median' + + targeted (bool): if False, target classes will not be reassigned + + target_class (int): if given, reassign all targets to this class + + num_per_class_train (int): if given, perform stratified sampling with this + number of instances drawn per class + + num_per_class_test (int): if given, perform stratified sampling with this + number of instances drawn per class + + num_embeddings_train (int): if given and attack is targeted, train on only + this many distinct embeddings of the target + speaker, and evaluate on other embeddings of the + target speaker + + exclude_class (int, list): if given, exclude all instances from this class + or in this list of classes + + exclude_success (bool): if True, drop or replace instances for which the + initial prediction achieves the desired + adversarial outcome (i.e. matches the target in + the case of a targeted attack, and evades the + target in the case of an untargeted attack) + + use_cache (bool): if True, save to and load from disk using hash- + based lookup + + shuffle (bool): if True, shuffle data and targets; may result in + mismatch if dataset contains other features (e.g. 
+ pitch, periodicity) + + Returns + ------- + train (dict): dictionary containing Dataset with audio and embedding + targets, example audio of each target class, original targets + indices, and reassigned target indices + + test (dict): dictionary containing Dataset with audio and embedding + targets, example audio of each target class, original targets + indices, and reassigned target indices + """ + + if not num_per_class_train: + num_per_class_train = 0 + if not num_per_class_test: + num_per_class_test = 0 + + if targeted and target_class is not None and num_per_class_train: + assert num_per_class_train > num_embeddings_train, \ + f'For targets drawn from training set, number of embeddings ' \ + f'reserved for training ({num_embeddings_train}) must be less ' \ + f'than number of embeddings computed per class ' \ + f'({num_per_class_train})' + + if exclude_success: + raise NotImplementedError(f'Target correction not yet implemented; ' + f'use `NullAttack` to measure trivial ' + f'success rates for now') + + # ensure pipeline is capable of producing embeddings + assert isinstance(pipeline.model, SpeakerVerificationModel) + + # match devices + ref_batch_train = next(iter(data_train)) + ref_batch_test = next(iter(data_test)) + if isinstance(ref_batch_train, tuple): + example_input, *_ = ref_batch_train + elif isinstance(ref_batch_train, dict): + example_input = ref_batch_train['x'] + else: + raise ValueError(f'Dataset must provide batches in tuple or dictionary' + f' format') + orig_device = example_input.device + + # check that model produces embeddings with valid shape + try: + embedding_shape = list( + pipeline.model(example_input.to(pipeline.device)).shape + ) + assert len(embedding_shape) == 3 and embedding_shape[0] == 1 + except AssertionError: + raise RuntimeError(f'Speaker verification model must produce ' + f'embeddings of shape ' + f'(n_batch, n_segments, embedding_dim)') + + assert isinstance(data_train, Sized) and isinstance(data_test, Sized), \ + f"Datasets must have length attribute accessible via `len()`" + + # check embedding selection method + assert select_train in ['random', 'single', 'centroid', 'same', 'median'], \ + f"invalid value for `select_train` {select_train}" + assert select_test in ['random', 'single', 'centroid', 'same', 'median'], \ + f"invalid value for `select_test` {select_test}" + + assert not targeted or select_train != 'same', \ + f'`same` embedding selection only valid for untargeted mode' + + # check for optional `batch_size` argument; otherwise, use batch size of 1 + batch_size = kwargs.get('batch_size', 1) + + # creating embedding datasets is time-consuming; to avoid repeated + # computation, we can store the generated dataset under a hash + hash_str = str(pipeline.model) + hash_str += str(data_train.__class__.__name__) + hash_str += str(data_test.__class__.__name__) + hash_str += select_train + select_test + hash_str += str(targeted) + str(target_class) + hash_str += str(num_per_class_train) + str(num_per_class_test) + hash_str += str(exclude_class) + str(exclude_success) + + # obtain hash and convert to filename + dataset_hash = hashlib.md5(hash_str.encode()).digest() + dataset_file = str(dataset_hash).replace("\'", "")[1:].replace("\\", ".") + dataset_file += ".pt" + + # check whether a cached embedding dataset with matching hash exists + embeddings_cache_dir = Path(CACHE_DIR) / 'embeddings' + ensure_dir(embeddings_cache_dir) + cached_datasets = embeddings_cache_dir.glob('*.pt') + + # if dataset is already cached, load and return + if use_cache 
and dataset_file in [d.name for d in cached_datasets]: + + dataset = torch.load(embeddings_cache_dir / dataset_file) + + # check for valid dataset structure + try: + assert isinstance(dataset, dict) + assert 'train' in dataset and 'test' in dataset + + return move_to_device_recursive( + dataset['train'], orig_device + ), move_to_device_recursive(dataset['test'], orig_device) + + except AssertionError: + raise RuntimeWarning(f'Invalid dataset structure; will re-compute ' + f'and overwrite existing dataset ' + f'{dataset_file}') + + # shuffle data + rand_idx_train = torch.randperm( + len(data_train)) if shuffle else torch.arange(len(data_train)) + rand_idx_test = torch.randperm( + len(data_test)) if shuffle else torch.arange(len(data_test)) + + # separate data and labels + if isinstance(ref_batch_train, tuple): + inputs_train, labels_train, *_ = data_train[:] + else: + inputs_train, labels_train = data_train[:]['x'], data_train[:]['y'] + inputs_train_shuffled = inputs_train[rand_idx_train] + labels_train_shuffled = labels_train[rand_idx_train] + + if isinstance(ref_batch_test, tuple): + inputs_test, labels_test, *_ = data_test[:] + else: + inputs_test, labels_test = data_test[:]['x'], data_test[:]['y'] + inputs_test_shuffled = inputs_test[rand_idx_test] + labels_test_shuffled = labels_test[rand_idx_test] + + # if target is given, check that it is present in training data + if target_class is not None: + assert target_class in labels_train, \ + f'Target class {target_class} is not present in training data' + + # determine train and test labels + unique_labels_train = [l.item() for l in torch.unique(labels_train)] + unique_labels_test = [l.item() for l in torch.unique(labels_test)] + + # filter excluded classes (if given) from train and test sets + if isinstance(exclude_class, List): + unique_labels_train = [ + l for l in unique_labels_train if l not in exclude_class] + unique_labels_test = [ + l for l in unique_labels_test if l not in exclude_class] + + elif exclude_class is not None: + unique_labels_train = [ + l for l in unique_labels_train if not l == exclude_class] + unique_labels_test = [ + l for l in unique_labels_test if not l == exclude_class] + + # prepare to store one example audio input per label (speaker) + audio_train = {} + audio_test = {} + + # prepare to store training and test embeddings by label + embeddings_train = {} + embeddings_test = {} + + def compute_embeddings_by_label( + unique_labels: list, + inputs: torch.Tensor, + labels: torch.Tensor, + saved_audio: dict, + saved_embeddings: dict, + num_per_class): + """ + Compute an embedding for every instance in the given dataset and store + by label in a dictionary; store one audio example per label in a + dictionary. 
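
+        Embeddings are computed in mini-batches on the pipeline device and
+        moved to the CPU for storage, keeping device memory bounded for
+        large classes.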
+ """ + + # compute embeddings over training set and sort by label + for label in tqdm( + unique_labels, + total=len(unique_labels), + desc="Computing embeddings for dataset"): + + # select training instances of class, allowing for a limit on the + # number of embeddings stored per class + x_l = inputs[labels == label] + n_l = num_per_class if num_per_class else len(x_l) + + # store one audio example per training label + saved_audio[label] = x_l[0:1] + + # store embeddings per training label + n_batches = math.ceil(n_l / batch_size) + saved_embeddings[label] = [] + for i in range(n_batches): + saved_embeddings[label].append( + pipeline.model( + x_l[i*batch_size:(i+1)*batch_size].to(pipeline.device) + ).to('cpu') # store intermediate results on CPU + ) + + saved_embeddings[label] = torch.cat( + saved_embeddings[label], dim=0)[:n_l] + + # compute embeddings over training and test datasets and store by label + compute_embeddings_by_label( + unique_labels_train, + inputs_train_shuffled, + labels_train_shuffled, + audio_train, + embeddings_train, + num_per_class_train + ) + compute_embeddings_by_label( + unique_labels_test, + inputs_test_shuffled, + labels_test_shuffled, + audio_test, + embeddings_test, + num_per_class_test + ) + + # filter datasets to remove excluded and target labels + if targeted and target_class is not None: + unique_labels_train = [ + l for l in unique_labels_train if not l == target_class] + unique_labels_test = [ + l for l in unique_labels_test if not l == target_class] + + def reassign_labels( + unique_labels: list, + inputs: torch.Tensor, + labels_orig: torch.Tensor, + num_per_class): + """ + Reassign targets, as detailed in documentation above. + """ + + labels_new = torch.full(labels_orig.shape, -1, dtype=labels_orig.dtype) + + # reassign label-by-label + for i, label in enumerate( + tqdm( + unique_labels, + total=len(unique_labels), + desc="Reassigning labels for dataset")): + + # select all training instances with label + idx_l = labels_orig == label + x_l = inputs[idx_l] + + # store original targets + y_orig_l = torch.full((len(x_l), ), label) + + # use a placeholder to allow for deletion of rows; overwrite with + # valid labels and delete rows where -1 remains + y_new_l = torch.full((len(x_l), ), -1) + + # limit number of instances per class if specified + n_l = num_per_class if num_per_class else len(x_l) + + # targeted attacks require that the given targets be reassigned + if targeted: + + # if target class is provided, reassign targets to given class + if target_class is not None: + y_new_l[:n_l] = target_class + + # if no target class is given, randomly reassign targets; ensure + # that no target is unchanged and new targets are evenly + # distributed + else: + remaining_labels = [ + l for l in unique_labels if l != label] + + for j in range(n_l): + y_new_l[j] = random.choice(remaining_labels) + + # otherwise, classes remain unchanged + else: + y_new_l[:n_l] = y_orig_l[:n_l] + + # update data and labels, deleting rows corresponding to + # extraneous inputs (according to `num_per_class`) + labels_new[idx_l] = y_new_l + + keep_idx = labels_new != -1 + inputs = inputs[keep_idx] + labels_orig = labels_orig[keep_idx] + labels_new = labels_new[keep_idx] + + return inputs, labels_orig, labels_new, keep_idx + + # reassign training and test labels if necessary (see documentation + # above); remove instances of target class and those not required by + # `num_per_class`, if given + ( + inputs_train_shuffled, + labels_train_shuffled, + labels_train_reassigned, + 
select_idx_train
+    ) = reassign_labels(
+        unique_labels_train,
+        inputs_train_shuffled,
+        labels_train_shuffled,
+        num_per_class_train
+    )
+    (
+        inputs_test_shuffled,
+        labels_test_shuffled,
+        labels_test_reassigned,
+        select_idx_test
+    ) = reassign_labels(
+        unique_labels_test,
+        inputs_test_shuffled,
+        labels_test_shuffled,
+        num_per_class_test
+    )
+
+    # prepare to store target embeddings corresponding to reassigned labels
+    embedding_targets_train = torch.empty(
+        (len(labels_train_reassigned), *embedding_shape[1:]))
+    embedding_targets_test = torch.empty(
+        (len(labels_test_reassigned), *embedding_shape[1:]))
+
+    def assign_embeddings(
+            labels_new: torch.Tensor,
+            embeddings_by_label: dict,
+            embedding_targets: torch.Tensor,
+            select: str,
+            is_train: bool = True
+    ):
+
+        # iterate over dataset and associate embedding targets with
+        # reassigned labels
+        labels_to_assign = [l.item() for l in torch.unique(labels_new)]
+
+        for label in labels_to_assign:
+
+            # find indices for which embeddings of given label are to
+            # be assigned
+            idx_l = labels_new == label
+            n_l = int(idx_l.sum().item())
+
+            if n_l == 0:
+                continue
+
+            # obtain all embeddings corresponding to given label
+            embeddings_l = embeddings_by_label[label]
+
+            # if untargeted, assign ground-truth embeddings for each instance
+            if select == 'same':
+                y_emb_l = embeddings_l
+
+            else:
+
+                # separate train and test embeddings of given speaker
+                if num_embeddings_train:
+
+                    # for targeted attacks, allow training/testing on separate
+                    # small subsets of a speaker's utterances
+                    if targeted:
+                        assert num_embeddings_train <= len(embeddings_l), \
+                            f"`num_embeddings_train` {num_embeddings_train} " \
+                            f"is greater than the number of utterances for " \
+                            f"speaker {label}"
+
+                    if is_train:
+                        embeddings_l = embeddings_l[:num_embeddings_train]
+                    else:
+                        embeddings_l = embeddings_l[num_embeddings_train:]
+
+                # using `select` parameter, assign embeddings
+                y_emb_l = []
+
+                for i in range(n_l):
+
+                    if select == 'single':  # use single embedding
+                        y_emb_l.append(embeddings_l[0:1])
+
+                    elif select == 'random':  # use random embeddings
+                        emb_idx = random.randint(0, len(embeddings_l) - 1)
+                        y_emb_l.append(embeddings_l[emb_idx:emb_idx+1])
+
+                    elif select == 'centroid':  # average over embeddings
+
+                        _, n_segments, embedding_dim = embedding_shape
+
+                        # duplicate over all segments
+                        centroid = embeddings_l.mean(dim=(0, 1)).reshape(
+                            (1, 1, embedding_dim)
+                        ).repeat(1, n_segments, 1)
+
+                        y_emb_l.append(centroid)
+
+                    elif select == 'median':  # median over embeddings
+
+                        _, n_segments, embedding_dim = embedding_shape
+
+                        # flatten all stored embeddings for this label before
+                        # taking the element-wise median; flatten with -1
+                        # rather than n_l * n_segments, since the number of
+                        # stored embeddings need not equal n_l; duplicate the
+                        # result over all segments
+                        median = embeddings_l.reshape(
+                            -1, embedding_dim
+                        ).median(dim=0)[0].reshape(
+                            (1, 1, embedding_dim)
+                        ).repeat(1, n_segments, 1)
+
+                        y_emb_l.append(median)
+
+                    else:
+                        raise ValueError(f'Invalid embedding selection method '
+                                         f'{select}')
+
+                y_emb_l = torch.cat(y_emb_l, dim=0)
+
+            embedding_targets[idx_l] = y_emb_l
+
+    # with labels finalized, assign embedding targets
+    assign_embeddings(
+        labels_train_reassigned,
+        embeddings_train,
+        embedding_targets_train,
+        select_train,
+        True
+    )
+    assign_embeddings(
+        labels_test_reassigned,
+        embeddings_train if targeted and target_class is not None else embeddings_test,
+        embedding_targets_test,
+        select_test,
+        False
+    )
+
+    # account for shuffling
+    final_idx_train = rand_idx_train[select_idx_train]
+    final_idx_test = rand_idx_test[select_idx_test]
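+
+    # note: `final_idx_train` / `final_idx_test` compose the shuffling
+    # permutation with the row filter applied in `reassign_labels`, mapping
+    # each retained row back to its index in the original dataset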
+
+    from src.data.dataset import VoiceBoxDataset
+
+    if isinstance(data_train, VoiceBoxDataset):
+        data_train_final = data_train.overwrite_dataset(
+            inputs_train_shuffled,
+            embedding_targets_train,
+            final_idx_train
+        )
+    else:
+        data_train_final = DatasetWrapper(
+            data_train,
+            inputs_train_shuffled,
+            embedding_targets_train)
+
+    if isinstance(data_test, VoiceBoxDataset):
+        data_test_final = data_test.overwrite_dataset(
+            inputs_test_shuffled,
+            embedding_targets_test,
+            final_idx_test
+        )
+    else:
+        data_test_final = DatasetWrapper(
+            data_test,
+            inputs_test_shuffled,
+            embedding_targets_test)
+
+    # store data and embeddings, audio examples, original targets, and
+    # reassigned targets
+    train = {
+        'dataset': data_train_final,
+        'id_to_audio': audio_train,
+        'true_id': labels_train_shuffled,
+        'target_id': labels_train_reassigned
+    }
+    test = {
+        'dataset': data_test_final,
+        'id_to_audio': audio_test,
+        'true_id': labels_test_shuffled,
+        'target_id': labels_test_reassigned
+    }
+
+    if use_cache:
+        dataset = {
+            'train': train,
+            'test': test
+        }
+        torch.save(dataset, embeddings_cache_dir / dataset_file)
+
+    # restore device and return
+    return move_to_device_recursive(
+        train, orig_device
+    ), move_to_device_recursive(
+        test, orig_device
+    )
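+
+
+# Illustrative usage sketch for the embedding-dataset builder above; the
+# enclosing function name and the keyword values shown are assumptions for
+# illustration, and `pipeline` must wrap a SpeakerVerificationModel:
+#
+#     train, test = create_embedding_dataset(
+#         data_train=train_set, data_test=test_set, pipeline=pipeline,
+#         targeted=True, target_class=3,
+#         select_train='centroid', select_test='centroid',
+#         shuffle=True, batch_size=8)
+#     attack_data = train['dataset']  # audio paired with embedding targets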
+
+
+@torch.no_grad()
+def create_transcription_dataset(data_train: Dataset,
+                                 data_test: Dataset,
+                                 pipeline: Pipeline,
+                                 targeted: bool = True,
+                                 target_transcription: str = None,
+                                 output_format: str = 'transcript',
+                                 shuffle: bool = False,
+                                 **kwargs):
+    """
+    Given training and test datasets holding audio, assign string
+    transcriptions for performing speech recognition attacks.
+
+    1. `targeted` == True, `target_transcription` != None
+
+    In this case, a single transcription target is assigned to all instances.
+
+    2. `targeted` == True, `target_transcription` == None
+
+    In this case, ground-truth transcriptions are randomly reassigned as
+    targets within both the training and test sets.
+
+    3. `targeted` == False
+
+    In this case, ground-truth transcriptions are used as targets.
+
+    Parameters
+    ----------
+    data_train (Dataset): a Dataset object holding audio
+
+    data_test (Dataset): a Dataset object holding audio
+
+    pipeline (Pipeline): a pipeline wrapping a speech recognition model, used
+                         to compute ground-truth transcriptions or emissions
+
+    targeted (bool): if False, target classes will not be reassigned
+
+    target_transcription (str): if given, reassign all targets to the given
+                                transcription string
+
+    output_format (str): target format; one of `transcript` (padded token
+                         sequences) or `emission` (frame-wise token
+                         probabilities)
+
+    shuffle (bool): if True, shuffle data and targets; may result in
+                    mismatch if dataset contains other features (e.g.
+                    pitch, periodicity)
+
+    Returns
+    -------
+    train (dict): dictionary containing Dataset with audio and transcription
+                  targets, the targets themselves, and target lengths
+
+    test (dict): dictionary containing Dataset with audio and transcription
+                 targets, the targets themselves, and target lengths
+    """
+
+    # check output format (string transcripts or frame-wise token
+    # probabilities)
+    assert output_format in ['transcript', 'emission'], \
+        'Invalid output format; must be one of `transcript` or `emission`'
+
+    # check for valid model type
+    assert isinstance(pipeline.model, SpeechRecognitionModel)
+
+    assert isinstance(data_train, Sized) and isinstance(data_test, Sized), \
+        "Datasets must have length attribute accessible via `len()`"
+
+    # match devices
+    ref_batch_train = next(iter(data_train))
+    ref_batch_test = next(iter(data_test))
+    if isinstance(ref_batch_train, tuple):
+        example_input, *_ = ref_batch_train
+    elif isinstance(ref_batch_train, dict):
+        example_input = ref_batch_train['x']
+    else:
+        raise ValueError('Dataset must provide batches in tuple or dictionary'
+                         ' format')
+
+    orig_device = example_input.device
+
+    # check for optional `batch_size` argument; otherwise, use batch size of 1
+    batch_size = kwargs.get('batch_size', 1)
+
+    # shuffle data
+    rand_idx_train = torch.randperm(
+        len(data_train)) if shuffle else torch.arange(len(data_train))
+    rand_idx_test = torch.randperm(
+        len(data_test)) if shuffle else torch.arange(len(data_test))
+
+    if isinstance(ref_batch_train, tuple):
+        inputs_train, *_ = data_train[:]
+    else:
+        inputs_train = data_train[:]['x']
+    inputs_train_shuffled = inputs_train[rand_idx_train]
+
+    if isinstance(ref_batch_test, tuple):
+        inputs_test, *_ = data_test[:]
+    else:
+        inputs_test = data_test[:]['x']
+    inputs_test_shuffled = inputs_test[rand_idx_test]
+
+    # if targeted and a target transcription is provided, assign it to every
+    # instance
+    if targeted and target_transcription is not None:
+
+        assert output_format == 'transcript', \
+            "Target transcript provided; cannot use emission targets"
+
+        # check that target transcript contains character set compatible with
+        # pipeline, and does not contain 'blank' character
+        valid_characters = deepcopy(pipeline.model.get_labels())
+        try:
+            del valid_characters[pipeline.model.get_blank_idx()]
+        except (IndexError, TypeError):
+            pass
+        assert all([c in valid_characters for c in target_transcription]), \
+            'Target transcription contains invalid characters'
+
+        single_target = text_to_tensor(
+            target_transcription,
+            pipeline.model.get_labels(),
+            return_lengths=False
+        )
+
+        targets_train = single_target.repeat(len(inputs_train_shuffled), 1)
+        targets_test = single_target.repeat(len(inputs_test_shuffled), 1)
+
+    # otherwise, compute transcriptions using given pipeline
+    else:
+
+        def transcribe(dataset: torch.Tensor):
+
+            results = []
+
+            n_batches = math.ceil(len(dataset) / batch_size)
+            for batch_idx in tqdm(
+                    range(n_batches),
+                    total=n_batches,
+                    desc="Computing transcriptions for dataset"):
+
+                x = dataset[
+                    batch_idx*batch_size:(batch_idx+1)*batch_size
+                ].to(pipeline.device)
+
+                if output_format == 'transcript':
+                    results.extend(pipeline.model.transcribe(x))
+                elif output_format == 'emission':
+                    results.extend(
+                        torch.split(
+                            pipeline.model(x).to(orig_device), 1, dim=0))
+
+            if output_format == 'emission':
+
+                # pad to max emission length
+                results = pad_sequence(results, batch_first=True).squeeze(1)
+
+            elif output_format == 'transcript':
+                results = text_to_tensor(
+                    results,
+                    pipeline.model.get_labels(),
+                    return_lengths=False)
+
+            return results
+
+        targets_train = transcribe(inputs_train_shuffled)
+        targets_test = transcribe(inputs_test_shuffled)
+
+    # if targeted, permute transcriptions such that no input
retains its + # original transcription + if targeted: + + # use derangements with a fixed iteration budget; expected number + # of iterations required to shuffle with no fixed points is e (~3) + def derange(x: torch.Tensor): + + max_iter = 10 + orig_shape = x.shape + x = x.reshape(x.shape[0], -1) + + for i in range(max_iter): + + rand_idx = torch.randperm(len(x)) + equal = torch.sum( + 1.0 * (x == x[rand_idx]), + dim=-1 + ) >= x.shape[-1] + + if not equal.sum().item(): + break + + return x[rand_idx].reshape(orig_shape) + + targets_train = derange(targets_train) + targets_test = derange(targets_test) + + # compute transcript lengths + if output_format == 'transcript': + lengths_train = padded_transcript_length(targets_train) + lengths_test = padded_transcript_length(targets_test) + elif output_format == 'emission': + lengths_train = torch.full( + size=(len(inputs_train_shuffled),), + fill_value=targets_train.shape[1], + dtype=torch.long + ) + lengths_test = torch.full( + size=(len(inputs_test_shuffled),), + fill_value=targets_test.shape[1], + dtype=torch.long + ) + else: + raise ValueError(f'Invalid value for `output_format`') + + from src.data.dataset import VoiceBoxDataset + + if isinstance(data_train, VoiceBoxDataset): + data_train_final = data_train.overwrite_dataset( + inputs_train_shuffled, + targets_train, + rand_idx_train) + else: + data_train_final = DatasetWrapper( + data_train, + inputs_train_shuffled, + targets_train) + + if isinstance(data_test, VoiceBoxDataset): + data_test_final = data_test.overwrite_dataset( + inputs_test_shuffled, + targets_test, + rand_idx_test + ) + else: + data_test_final = DatasetWrapper( + data_test, + inputs_test_shuffled, + targets_test) + + train = { + 'dataset': data_train_final, + 'targets': targets_train, + 'target_lengths': lengths_train + } + + test = { + 'dataset': data_test_final, + 'targets': targets_test, + 'target_lengths': lengths_test + } + + return move_to_device_recursive( + train, orig_device + ), move_to_device_recursive( + test, orig_device + ) diff --git a/voicebox/src/utils/device.py b/voicebox/src/utils/device.py new file mode 100644 index 0000000000000000000000000000000000000000..a13604790090fd21a8e438b727cfd96d6e23ad33 --- /dev/null +++ b/voicebox/src/utils/device.py @@ -0,0 +1,122 @@ +import torch +import torch.nn as nn + +from typing import OrderedDict, Dict, Any, TypeVar, Union + +################################################################################ +# Utilities for single/multi-GPU training +################################################################################ + + +class DataParallelWrapper(nn.DataParallel): + """Extend DataParallel class to allow full method/attribute access""" + def __getattr__(self, name): + + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self.module, name) + + def state_dict(self, + *args, + destination=None, + prefix='', + keep_vars=False): + """Avoid `module` prefix in saved weights""" + return self.module.state_dict( + destination=destination, + prefix=prefix, + keep_vars=keep_vars + ) + + def load_state_dict(self, + state_dict: OrderedDict[str, torch.Tensor], + strict: bool = True): + """Avoid `module` prefix in saved weights""" + self.module.load_state_dict(state_dict, strict) + + +def get_cuda_device_ids(): + """Fetch all available CUDA devices""" + return list(range(torch.cuda.device_count())) + + +def wrap_module_multi_gpu(m: nn.Module, device_ids: list): + """Implement data parallelism for arbitrary Module objects.""" + + if 
len(device_ids) < 1:
+        return m
+
+    elif isinstance(m, DataParallelWrapper):
+        return m
+
+    else:
+        return DataParallelWrapper(
+            module=m,
+            device_ids=device_ids
+        )
+
+
+def unwrap_module_multi_gpu(m: nn.Module, device: Union[str, int, torch.device]):
+
+    if isinstance(m, DataParallelWrapper):
+        return m.module.to(device)
+    else:
+        return m.to(device)
+
+
+def wrap_attack_multi_gpu(m: nn.Module, device_ids: list):
+    """
+    Implement data parallelism for attack objects, including stored Pipeline
+    and Perturbation instances that may be accessed outside of `forward()`
+    """
+
+    if len(device_ids) < 1:
+        return m
+
+    if hasattr(m, 'pipeline') and isinstance(m.pipeline, nn.Module):
+        m.pipeline = wrap_pipeline_multi_gpu(m.pipeline, device_ids)
+
+    if hasattr(m, 'perturbation') and isinstance(m.perturbation, nn.Module):
+        m.perturbation = wrap_module_multi_gpu(m.perturbation, device_ids)
+
+    # scale batch size to number of devices
+    if hasattr(m, 'batch_size'):
+        m.batch_size *= len(device_ids)
+
+    return m
+
+
+def unwrap_attack_multi_gpu(m: nn.Module, device: Union[str, int, torch.device]):
+    """
+    Undo data parallelism for attack objects, moving stored Pipeline and
+    Perturbation instances to the given device
+    """
+    if hasattr(m, 'pipeline') and isinstance(m.pipeline, DataParallelWrapper):
+        m.pipeline = unwrap_module_multi_gpu(m.pipeline, device)
+
+    if hasattr(m, 'perturbation') and isinstance(m.perturbation, DataParallelWrapper):
+        m.perturbation = unwrap_module_multi_gpu(m.perturbation, device)
+
+    # restore single-device batch size (note: this assumes the attack was
+    # wrapped across all available CUDA devices)
+    if hasattr(m, 'batch_size'):
+        m.batch_size = m.batch_size // len(get_cuda_device_ids())
+
+    return m
+
+
+def wrap_pipeline_multi_gpu(m: nn.Module, device_ids: list):
+    """
+    Implement data parallelism for Pipeline objects, including all intermediate
+    stages that may be accessed outside of `forward()`
+    """
+
+    if len(device_ids) < 1:
+        return m
+
+    return wrap_module_multi_gpu(m, device_ids)
+
+
+def unwrap_pipeline_multi_gpu(m: nn.Module, device: Union[str, int, torch.device]):
+    """
+    Undo data parallelism for Pipeline objects, moving the wrapped module to
+    the given device
+    """
+    return unwrap_module_multi_gpu(m, device)
diff --git a/voicebox/src/utils/filesystem.py b/voicebox/src/utils/filesystem.py
new file mode 100644
index 0000000000000000000000000000000000000000..206a15884361a1eac190781be7b6e36c156b8eca
--- /dev/null
+++ b/voicebox/src/utils/filesystem.py
@@ -0,0 +1,26 @@
+import os
+from pathlib import Path
+
+from typing import Union
+
+################################################################################
+# Filesystem utilities
+################################################################################
+
+
+def ensure_dir_for_filename(filename: str):
+    """
+    Ensure all directories along given path exist, given filename
+    """
+    ensure_dir(os.path.dirname(filename))
+
+
+def ensure_dir(directory: Union[str, Path]):
+    """
+    Ensure all directories along given path exist, given directory name
+    """
+
+    directory = str(directory)
+
+    if len(directory) > 0 and not os.path.exists(directory):
+        os.makedirs(directory)
diff --git a/voicebox/src/utils/plotting.py b/voicebox/src/utils/plotting.py
new file mode 100644
index 0000000000000000000000000000000000000000..d08e0bf057935356470bcb15a3aef7932056710a
--- /dev/null
+++ b/voicebox/src/utils/plotting.py
@@ -0,0 +1,256 @@
+import io
+import torch
+import torch.nn.functional as F
+import math
+from datetime import datetime
+
+import matplotlib
+import matplotlib.pyplot as plt
+import numpy as np
+import IPython.display as ipd
+from IPython.core.display import display
+
+from PIL import Image
+from torchvision.transforms import PILToTensor, ToTensor
+
+from typing import Union
+
+#matplotlib.use('Agg') # switch backend to run on server + +################################################################################ +# Plotting utilities for logging and figures +################################################################################ + + +def tensor_to_np(x: torch.Tensor): + return x.clone().detach().cpu().numpy() + + +def play_audio(x: torch.Tensor, sample_rate: int = 16000): + display(ipd.Audio(tensor_to_np(x).flatten(), rate=sample_rate)) + + +def plot_filter(amplitudes: torch.Tensor): + """ + Given a single set of time-varying filter controls, return plot as image + """ + + amplitudes = amplitudes.clone().detach() + + if amplitudes.ndim == 2: + magnitudes = amplitudes.cpu().numpy().T + elif amplitudes.ndim == 3: + magnitudes = amplitudes[0].cpu().numpy().T + else: + raise ValueError("Can only plot single filter response") + + # plot filter controls over time as heatmap + fig, ax = plt.subplots(figsize=(8, 8)) + im = ax.imshow(magnitudes, aspect='auto') + fig.colorbar(im, ax=ax) + ax.invert_yaxis() + ax.set_title('filter amplitudes') + ax.set_xlabel('frames') + ax.set_ylabel('frequency bin') + plt.tight_layout() + + # save plot to buffer + buf = io.BytesIO() + plt.savefig(buf, format="png") + plt.close(fig) + buf.seek(0) + img = Image.open(buf) + + # return plot as image + return ToTensor()(np.array(img)) + + +def plot_waveform(x: torch.Tensor, scale: Union[int, float] = 1.0): + """ + Given single audio waveform, return plot as image + """ + try: + assert len(x.shape) == 1 or x.shape[0] == 1 + except AssertionError: + raise ValueError('Audio input must be single waveform') + + # waveform plot + fig, ax = plt.subplots(figsize=(8,8)) + fig.subplots_adjust(bottom=0.2) + plt.xticks( + #rotation=90 + ) + ax.plot(tensor_to_np(x).flatten(), color='k') + ax.set_xlabel("Sample Index") + ax.set_ylabel("Waveform Amplitude") + plt.axis((None, None, -scale, scale)) # set y-axis range + + # save plot to buffer + buf = io.BytesIO() + plt.savefig(buf, format="png") + plt.close(fig) + buf.seek(0) + img = Image.open(buf) + + # return plot as image + return ToTensor()(np.array(img)) + + +def plot_filter_codebook(x: torch.Tensor, use: torch.Tensor = None): + """ + Plot a codebook of learned frequency-domain filter controls. 
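+
+    `x` is expected to have shape (n_filters, ..., n_bands); if `use` is
+    given, it holds one usage rate per filter, used to shade and annotate
+    each subplot.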
+    """
+
+    # scale use rates to [0, 1] for background coloring but not text display
+    if use is not None:
+        use = use.clone().detach()
+        use_normalized = use.clone()
+        use_normalized -= use_normalized.min(0, keepdim=True)[0]
+        use_normalized /= use_normalized.max(0, keepdim=True)[0]
+
+    n_filters, n_bands = x.shape[0], x.shape[-1]
+
+    # create a square grid layout, which may be partially filled
+    grid_size = math.ceil(math.sqrt(n_filters))
+
+    # `squeeze=False` keeps a 2D array of axes even when the grid is 1x1
+    fig, axs = plt.subplots(
+        ncols=grid_size, nrows=grid_size, figsize=(8, 8), squeeze=False)
+
+    for i in range(n_filters):
+        axis = axs[i//grid_size, i % grid_size]
+
+        # color filter plot according to use rate of filter
+        if use is not None:
+            assert len(use) == n_filters  # one usage rate per filter
+            axis.set_facecolor((1.0, 0.47, 0.42, use_normalized[i].item()))
+
+            x_text = n_bands // 2
+            y_text = x[i].max().item() / 2
+            axis.text(x_text, y_text, f"{use[i].item():0.3f}",
+                      ha="center", va="center", zorder=10)
+
+        axis.plot(np.zeros(n_bands), 'k', alpha=0.5)  # plot "neutral" line
+        axis.plot(tensor_to_np(x[i]).flatten())
+        axis.set_xlabel("Frequency")
+        axis.set_ylabel("Amplitude")
+
+    plt.tight_layout()
+
+    # save plot to buffer
+    buf = io.BytesIO()
+    plt.savefig(buf, format="png")
+    plt.close(fig)
+    buf.seek(0)
+    img = Image.open(buf)
+
+    # return plot as image
+    return ToTensor()(np.array(img))
+
+
+def plot_spectrogram(x: torch.Tensor):
+    """
+    Given single audio waveform, return spectrogram plot as image
+    """
+    try:
+        assert len(x.shape) == 1 or x.shape[0] == 1
+    except AssertionError:
+        raise ValueError('Audio input must be single waveform')
+
+    x = x.clone().detach()
+
+    # spectrogram plot
+    spec = torch.stft(x.reshape(1, -1),
+                      n_fft=512,
+                      win_length=512,
+                      hop_length=256,
+                      window=torch.hann_window(
+                          window_length=512
+                      ).to(x.device),
+                      return_complex=True,
+                      center=False
+                      )
+    spec = torch.squeeze(
+        torch.abs(spec) / (torch.max(torch.abs(spec)))
+    )  # normalize spectrogram by maximum absolute value
+
+    # save plot to buffer
+    fig, ax = plt.subplots(figsize=(8, 8))
+    ax.pcolormesh(tensor_to_np(torch.log(spec + 1)), vmin=0, vmax=.31)
+    plt.tight_layout()
+    buf = io.BytesIO()
+    plt.savefig(buf, format="png")
+    plt.close(fig)
+    buf.seek(0)
+    img = Image.open(buf)
+
+    # return plot image as tensor
+    return ToTensor()(np.array(img))
+
+
+def plot_logits(class_scores: torch.Tensor, target: int = None):
+    """
+    Given a vector of class scores, and optionally a target index, create a
+    simple bar plot of the scores and return as an image
+    """
+
+    # require single vector of class scores
+    try:
+        assert class_scores.ndim <= 1 or class_scores.shape[0] == 1
+    except AssertionError:
+        raise ValueError('Must provide single vector of class scores')
+
+    # convert to NumPy
+    scores = tensor_to_np(class_scores).flatten()
+    labels = np.arange(scores.shape[-1])
+
+    # bar plot
+    fig = plt.figure(figsize=(8, 8))
+    bars = plt.bar(labels, scores, color='k')
+
+    # if target label index is given, highlight corresponding bar
+    if target is not None:
+
+        try:
+            assert 0 <= target < len(scores)
+        except AssertionError:
+            raise ValueError("Target must be valid index")
+
+        bars[target].set_color('r')
+
+    # save plot to buffer
+    buf = io.BytesIO()
+    plt.savefig(buf, format="png")
+    plt.close(fig)
+    buf.seek(0)
+    img = Image.open(buf)
+
+    # return plot image as tensor
+    return ToTensor()(np.array(img))
+
+
+def get_duration(st: datetime, ed: datetime):
+    """Return duration as string"""
+
+    # use `total_seconds()` rather than `.seconds`, which silently discards
+    # whole days from the elapsed time
+    total_seconds = int((ed - st).total_seconds())
+    hours = total_seconds // 3600
+    minutes = total_seconds % 3600 // 60
+    seconds = total_seconds % 60
+
+    duration = ""
+    if hours > 0:
+        duration += f"{hours}h {minutes}m "
+    elif minutes > 0:
+        duration += f"{minutes}m "
+    duration += f"{seconds}s"
+
+    return duration
diff --git a/voicebox/src/utils/writer.py b/voicebox/src/utils/writer.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf6976c53c983d4e33451acef368fd9fcd54ce35
--- /dev/null
+++ b/voicebox/src/utils/writer.py
@@ -0,0 +1,333 @@
+import torch
+import torch.nn as nn
+from copy import deepcopy
+import contextlib
+import math
+import time
+import logging
+import sys
+import json
+
+from typing import Union, Dict, Any
+
+from pathlib import Path
+
+from torch.utils.tensorboard import SummaryWriter
+
+from src.constants import RUNS_DIR
+from src.utils.filesystem import ensure_dir
+from src.utils.plotting import *
+
+################################################################################
+# Logging utility with optional TensorBoard support
+################################################################################
+
+
+class Writer:
+    """
+    Handles file, console, and TensorBoard logging
+    """
+    def __init__(self,
+                 root_dir: Union[str, Path] = RUNS_DIR,
+                 name: str = None,
+                 use_tb: bool = False,
+                 log_iter: int = 100,
+                 use_timestamp: bool = True,
+                 log_images: bool = False,
+                 **kwargs
+                 ):
+        """
+        Configure logging.
+
+        :param root_dir: root logging directory
+        :param name: descriptive name for run
+        :param use_tb: if True, use TensorBoard
+        :param log_iter: iterations between logging
+        """
+
+        # generate run-specific name and create directory
+        run_name = f'{name}'
+        if use_timestamp:
+            run_name += f'_{time.strftime("%m-%d-%H_%M_%S")}'
+        self.run_dir = Path(root_dir) / run_name
+        ensure_dir(self.run_dir)
+
+        # prepare checkpoint directory
+        self.checkpoint_dir = self.run_dir / 'checkpoints'
+
+        # log to TensorBoard
+        self.use_tb = use_tb
+        self.log_iter = log_iter
+        self.writer = SummaryWriter(
+            log_dir=str(self.run_dir),
+            flush_secs=20,
+        ) if self.use_tb else None
+
+        # log to console and file 'log.txt'
+        self.logger = logging.getLogger(run_name)
+
+        self.logger.setLevel(logging.INFO)
+        self.logger.addHandler(
+            logging.StreamHandler(sys.stdout)
+        )
+        self.logger.addHandler(
+            logging.FileHandler(self.run_dir / 'log.txt')
+        )
+
+        # to avoid segmentation faults, it may be necessary to skip image
+        # logging
+        self.log_images = log_images
+
+        # disable Matplotlib logging
+        logging.getLogger('matplotlib.font_manager').disabled = True
+
+    def get_run_dir(self):
+        return str(self.run_dir)
+
+    def log_info(self, info: str):
+        """
+        Log generic statements
+        """
+        self.logger.info(info)
+
+    def _dict_to_str(self, d: dict):
+        """Recursively cast dictionary entries to strings"""
+
+        d_out = {}
+
+        for k, v in d.items():
+            if isinstance(v, dict):
+                d_out[k] = self._dict_to_str(v)
+            elif not isinstance(v, (float, int, bool)):
+                d_out[k] = str(v)
+            else:
+                d_out[k] = v
+        return d_out
+
+    def log_config(self,
+                   config: Union[dict, str],
+                   tag: str = "config",
+                   path: Union[str, Path] = None):
+        """Save config file for run, given dictionary"""
+
+        path = path if path is not None else self.run_dir / f'{tag}.conf'
+
+        with open(path, "w") as out_config:
+            self.logger.info(f'Saving config to {path}')
+
+            if isinstance(config, dict):
+                config =
self._dict_to_str(config) + json.dump(config, out_config, indent=4) + else: + out_config.write(config) + + def log_scalar(self, x: torch.Tensor, tag: str, global_step: int = 0): + """ + Log scalar + """ + + # only log at specified iterations + if not self.log_iter or global_step % self.log_iter: + return + + # log scalar to file and console + self.logger.info(f'iter {global_step}\t{tag}: {x}') + + # if TensorBoard is enabled + if self.use_tb: + self.writer.add_scalar(f'{tag}', x, global_step=global_step) + self.writer.flush() + + def log_logits(self, + x: torch.Tensor, + target: int = None, + tag: str = None, + global_step: int = 0): + """ + Log class scores (logits) + """ + + # only log at specified iterations + if not self.log_iter or global_step % self.log_iter: + return + + # log plot to TensorBoard + if self.use_tb and self.log_images: + self.writer.add_image( + f"{tag}", plot_logits(x, target), + global_step=global_step + ) + + self.writer.flush() + + def log_audio(self, + x: torch.Tensor, + tag: str, + global_step: int = 0, + sample_rate: int = 16000, + scale: Union[int, float] = 1.0): + """ + Given a single audio waveform, log a normalized recording, waveform + plot, and spectrogram plot to TensorBoard + """ + + # only log at specified iterations and if TensorBoard is enabled + if not self.log_iter or global_step % self.log_iter or not self.use_tb: + return + + # normalize and log audio recording + normalized = (scale / torch.max( + torch.abs(x) + 1e-12, dim=-1, keepdim=True + )[0]) * x * 0.95 + + self.writer.add_audio(f"{tag}-audio", + normalized, + sample_rate=sample_rate, + global_step=global_step) + + if self.log_images: + + # log waveform + self.writer.add_image(f"{tag}-waveform", + plot_waveform(x, scale), + global_step=global_step) + + # log spectrogram + self.writer.add_image(f"{tag}-spectrogram", + plot_spectrogram(x), + global_step=global_step) + + # flush + self.writer.flush() + + def log_norm(self, + x: torch.Tensor, + tag: str, + global_step: int = 0): + """ + Plot norm of input tensor + """ + + # only log at specified iterations and if TensorBoard is enabled + if not self.log_iter or global_step % self.log_iter or not self.use_tb: + return + + # log norms + norm_2 = torch.norm(x, p=2) + norm_inf = torch.norm(x, p=float('inf')) + self.writer.add_scalar(f'{tag}/norm-2', norm_2, global_step=global_step) + self.writer.add_scalar(f'{tag}/norm-inf', norm_inf, global_step=global_step) + + self.writer.flush() + + def log_image(self, image: torch.Tensor, tag: str, global_step: int = 0): + """ + Log image plot + """ + + if not self.log_images: + return + + # only log at specified iterations and if TensorBoard is enabled + if not self.log_iter or global_step % self.log_iter or not self.use_tb: + return + + self.writer.add_image( + tag, + image, + global_step + ) + + self.writer.flush() + + def log_filter(self, + amplitudes: torch.Tensor, + tag: str, + global_step: int = 0): + """ + Plot filter controls + """ + + # only log at specified iterations and if TensorBoard is enabled + if not self.log_iter or global_step % self.log_iter or not self.use_tb: + return + + if self.log_images: + + self.writer.add_image( + f'filter_controls/{tag}', + plot_filter(amplitudes), + global_step + ) + + self.writer.flush() + + @staticmethod + def bytes_to_gb(size_bytes: int): + """ + Code from: https://stackoverflow.com/a/14822210 + """ + if size_bytes == 0: + return "0B" + size_name = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB") + i = int(math.floor(math.log(size_bytes, 1024))) + p = 
math.pow(1024, i)
+        s = round(size_bytes / p, 2)
+        return "%s %s" % (s, size_name[i])
+
+    def log_cuda_memory(self, device: int = 0):
+
+        total_memory = self.bytes_to_gb(
+            torch.cuda.get_device_properties(device).total_memory
+        )
+        reserved_memory = self.bytes_to_gb(
+            torch.cuda.memory_reserved(device)
+        )
+        allocated_memory = self.bytes_to_gb(
+            torch.cuda.memory_allocated(device)
+        )
+        free_memory = self.bytes_to_gb(
+            torch.cuda.memory_reserved(device) - torch.cuda.memory_allocated(
+                device
+            )
+        )
+
+        self.logger.info(f'\nMemory management:\n'
+                         f'------------------\n'
+                         f'Total: {total_memory}\n'
+                         f'Reserved: {reserved_memory}\n'
+                         f'Allocated: {allocated_memory}\n'
+                         f'Free: {free_memory}\n')
+
+    def checkpoint(self,
+                   checkpoint: Union[nn.Module, Dict[str, Any]],
+                   tag: str,
+                   global_step: int = None
+                   ):
+        """
+        Given nn.Module object or state dictionary, save to disk
+        """
+        ensure_dir(self.checkpoint_dir)
+
+        if global_step is not None:
+            filename = f'{tag}_{global_step}.pt'
+        else:
+            filename = f'{tag}.pt'
+
+        if isinstance(checkpoint, nn.Module):
+            checkpoint = checkpoint.state_dict()
+
+        torch.save(
+            checkpoint,
+            self.checkpoint_dir / filename
+        )
+
+    @contextlib.contextmanager
+    def force_logging(self):
+        """Force Writer to log by temporarily overriding logging interval"""
+        log_iter = self.log_iter
+        self.log_iter = 1
+        try:
+            yield
+        finally:
+            # restore the original interval even if the caller raises
+            self.log_iter = log_iter
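+
+
+# Example usage (illustrative sketch; the run name, tags, and the `loss` and
+# `model` objects are hypothetical):
+#
+#     writer = Writer(name='demo_run', use_tb=True, log_iter=100)
+#     writer.log_config({'lr': 1e-4, 'batch_size': 8})
+#     for step in range(1000):
+#         writer.log_scalar(loss, tag='train/loss', global_step=step)
+#     with writer.force_logging():  # log regardless of `log_iter`
+#         writer.log_scalar(loss, tag='train/final-loss', global_step=step)
+#     writer.checkpoint(model, tag='model', global_step=step)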