diff --git a/README.md b/README.md
index f4bb5a5ad043ac09b6987704e3c4a54d5ab7febf..e24016b993ec87d4377212e338f1ee041f15e1dd 100644
--- a/README.md
+++ b/README.md
@@ -18,7 +18,7 @@ This repository provides an implementation of the MiniMax-Speech model, featurin
 ## Architecture
 
 ### Stage 1: Audio to Discrete Tokens
-Converts raw audio into discrete representations using the DAC (Descript Audio Codec) framework.
+Converts raw audio into discrete representations using the FSQ speech tokenizer (S3Tokenizer).
 
 ### Stage 2: Discrete Tokens to Continuous Latent Space
 Maps discrete tokens to a continuous latent space using a Variational Autoencoder (VAE).
@@ -29,25 +29,25 @@ Maps discrete tokens to a continuous latent space using a Variational Autoencode
 
 ### 1. Model Training
 
-#### BPE tokens to DAC codec tokens
-- Based on the DAC codec
-- Using Auto Regressive to predict the DAC codec tokens with learnable speaker extractor
+#### BPE tokens to FSQ tokens
+- Based on the FSQ (S3Tokenizer) speech tokenizer
+- Uses an autoregressive model with a learnable speaker extractor to predict the FSQ tokens
 
-#### DAC codec tokens to DAC-VAE latent
+#### FSQ tokens to DAC-VAE latent
 - Based on the CosyVoice2 flow matching decoder
 - Learns continuous latent representations from discrete tokens
 
 ### 2. Feature Extraction
 Before training the main model:
-1. Extract discrete tokens using the trained DAC codec [Descript Audio Codec](https://github.com/descriptinc/descript-audio-codec)
+1. Extract discrete tokens using the trained FSQ tokenizer [S3Tokenizer](https://github.com/xingchensong/S3Tokenizer)
 2. Generate continuous latent representations using the trained DAC-VAE - the pretrained model is provided here: [DAC-VAE](https://drive.google.com/file/d/1iwZhPlcdDwvPjeON3bFAeYarsV4ZtI2E/view?usp=sharing)
 
 ### 3. Two-Stage Training
 Train the models sequentially:
-- **Stage 1**: BPE tokens → Discrete DAC codec
-- **Stage 2**: Discrete DAC codec → DAC-VAE Continuous latent space
+- **Stage 1**: BPE tokens → discrete FSQ tokens
+- **Stage 2**: discrete FSQ tokens → DAC-VAE continuous latent space
 
 ## Getting Started
 
@@ -59,7 +59,7 @@ pip install -r requirements.txt
 
 ### Training Pipeline
 
-1. **Extracting DAC Codec** (if not using pretrained)
+1. **Extracting FSQ Tokens** (if not using pretrained)
 ```bash
 # Add training command
 ```
@@ -88,13 +88,12 @@ minimax-speech/
 ├── configs/
 │   └── dac_vae.yaml
 ├── models/
-│   ├── dac_codec/
+│   ├── fsq/
 │   └── dac_vae/
 ├── cosyvoice/           # Components from CosyVoice2
 │   ├── flow/
 │   ├── transformer/
 │   └── utils/
-├── train_dac_vae.py
 └── README.md
 ```
 
@@ -130,13 +129,13 @@ If you use this code in your research, please cite:
 
 This project follows the licensing terms of its dependencies:
 - CosyVoice2 components: [Check CosyVoice2 License](https://github.com/FunAudioLLM/CosyVoice/blob/main/LICENSE)
-- DAC components: [Apache 2.0 License](https://github.com/descriptinc/descript-audio-codec/blob/main/LICENSE)
+- FSQ components: [Apache 2.0 License](https://github.com/xingchensong/S3Tokenizer/blob/main/LICENSE)
 - Original contributions: [Specify your license here]
 
 ## Acknowledgments
 
 - **[CosyVoice2](https://github.com/FunAudioLLM/CosyVoice)**: This implementation extensively uses code and architectures from CosyVoice2
-- **[Descript Audio Codec](https://github.com/descriptinc/descript-audio-codec)**: For the DAC implementation
+- **[S3Tokenizer](https://github.com/xingchensong/S3Tokenizer)**: For the FSQ tokenizer implementation
 - **MiniMax team**: For the technical report and methodology
 - **FunAudioLLM team**: For the excellent CosyVoice2 codebase
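For reference, a minimal sketch of feature-extraction step 1 above, assuming the upstream S3Tokenizer API (`load_model`, `load_audio`, `log_mel_spectrogram`, `padding`, `quantize`); the checkpoint name and wav paths are placeholders, and the DAC-VAE call at the end is hypothetical rather than an interface this repo is known to expose:

```python
import s3tokenizer

# Load a pretrained FSQ speech tokenizer; the variant name is a placeholder,
# see the S3Tokenizer repo for the checkpoints it actually ships.
tokenizer = s3tokenizer.load_model("speech_tokenizer_v1")

wav_paths = ["sample1.wav", "sample2.wav"]  # placeholder paths

# Compute log-mel features per utterance, pad them into a batch, then
# quantize to discrete FSQ token ids plus per-utterance lengths.
mels = [s3tokenizer.log_mel_spectrogram(s3tokenizer.load_audio(p)) for p in wav_paths]
mels, mels_lens = s3tokenizer.padding(mels)
codes, codes_lens = tokenizer.quantize(mels, mels_lens)

for i, path in enumerate(wav_paths):
    print(path, codes[i, : codes_lens[i].item()])

# Step 2 would pass the raw waveforms through the pretrained DAC-VAE to get
# continuous latents for the flow-matching stage, e.g. (hypothetical):
#   latents = dac_vae.encode(waveform)
```

These cached tokens and latents are what the two training stages consume, so extracting them once up front avoids re-running the tokenizer inside the training loop.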
diff --git a/dac-codec/assets/comparsion_stats.png b/dac-codec/assets/comparsion_stats.png
deleted file mode 100644
index be1aea8ac5181c8cc26fbfbddb2316cb99be6c77..0000000000000000000000000000000000000000
--- a/dac-codec/assets/comparsion_stats.png
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:46dcd8f1b60cf44443354b21cece5b88f2a122aa4788dc4f899e5d28f34e2dac
-size 184991
diff --git a/dac-codec/assets/objective_comparisons.png b/dac-codec/assets/objective_comparisons.png
deleted file mode 100644
index c3ab1e94fbfb54f00e7c399c4917467a4a605a5e..0000000000000000000000000000000000000000
--- a/dac-codec/assets/objective_comparisons.png
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:919fa32c38d51aed15a7fb43eba2d28b687636ba3a289e7f9eeb10ab6489030d
-size 530958
diff --git a/dac-codec/conf/1gpu.yml b/dac-codec/conf/1gpu.yml
deleted file mode 100644
index ce878027b0bfaa6464081ee7aab203ffa300a128..0000000000000000000000000000000000000000
--- a/dac-codec/conf/1gpu.yml
+++ /dev/null
@@ -1,6 +0,0 @@
-$include:
-  - conf/base.yml
-
-batch_size: 12
-val_batch_size: 12
-num_workers: 4
diff --git a/dac-codec/conf/ablations/baseline.yml b/dac-codec/conf/ablations/baseline.yml
deleted file mode 100644
index 1510ce237b7ec7655572f53ff74b195c957b64b5..0000000000000000000000000000000000000000
--- a/dac-codec/conf/ablations/baseline.yml
+++ /dev/null
@@ -1,3 +0,0 @@
-$include:
-  - conf/base.yml
-  - conf/1gpu.yml
diff --git a/dac-codec/conf/ablations/diff-mb.yml b/dac-codec/conf/ablations/diff-mb.yml
deleted file mode 100644
index afa758d87c998c0379c97538a8eeee2f06d030bc..0000000000000000000000000000000000000000
--- a/dac-codec/conf/ablations/diff-mb.yml
+++ /dev/null
@@ -1,22 +0,0 @@
-$include:
-  - conf/base.yml
-  - conf/1gpu.yml
-
-Discriminator.sample_rate: 44100
-Discriminator.fft_sizes: [2048, 1024, 512]
-Discriminator.bands:
-  - [0.0, 0.05]
-  - [0.05, 0.1]
-  - [0.1, 0.25]
-  - [0.25, 0.5]
-  - [0.5, 1.0]
-
-
-# re-weight lambdas to make up for
-# lost discriminators vs baseline
-lambdas:
-  mel/loss: 15.0
-  adv/feat_loss: 5.0
-  adv/gen_loss: 1.0
-  vq/commitment_loss: 0.25
-  vq/codebook_loss: 1.0
diff --git 
a/dac-codec/conf/ablations/equal-mb.yml b/dac-codec/conf/ablations/equal-mb.yml deleted file mode 100644 index 2c091ac6f5bcc0a2d519d4a849bed178d5a76050..0000000000000000000000000000000000000000 --- a/dac-codec/conf/ablations/equal-mb.yml +++ /dev/null @@ -1,22 +0,0 @@ -$include: - - conf/base.yml - - conf/1gpu.yml - -Discriminator.sample_rate: 44100 -Discriminator.fft_sizes: [2048, 1024, 512] -Discriminator.bands: - - [0.0, 0.2] - - [0.2, 0.4] - - [0.4, 0.6] - - [0.6, 0.8] - - [0.8, 1.0] - - -# re-weight lambdas to make up for -# lost discriminators vs baseline -lambdas: - mel/loss: 15.0 - adv/feat_loss: 5.0 - adv/gen_loss: 1.0 - vq/commitment_loss: 0.25 - vq/codebook_loss: 1.0 diff --git a/dac-codec/conf/ablations/no-adv.yml b/dac-codec/conf/ablations/no-adv.yml deleted file mode 100644 index 75e271badf185b4a68edf1645e7d073a80968db1..0000000000000000000000000000000000000000 --- a/dac-codec/conf/ablations/no-adv.yml +++ /dev/null @@ -1,9 +0,0 @@ -$include: - - conf/base.yml - - conf/1gpu.yml - -lambdas: - mel/loss: 1.0 - waveform/loss: 1.0 - vq/commitment_loss: 0.25 - vq/codebook_loss: 1.0 diff --git a/dac-codec/conf/ablations/no-data-balance.yml b/dac-codec/conf/ablations/no-data-balance.yml deleted file mode 100644 index a88f39254045144255fd8b46aaccf31444b0b254..0000000000000000000000000000000000000000 --- a/dac-codec/conf/ablations/no-data-balance.yml +++ /dev/null @@ -1,22 +0,0 @@ -$include: - - conf/base.yml - - conf/1gpu.yml - -train/build_dataset.folders: - speech: - - /data/daps/train - - /data/vctk - - /data/vocalset - - /data/read_speech - - /data/french_speech - - /data/emotional_speech/ - - /data/common_voice/ - - /data/german_speech/ - - /data/russian_speech/ - - /data/spanish_speech/ - music: - - /data/musdb/train - - /data/jamendo - general: - - /data/audioset/data/unbalanced_train_segments/ - - /data/audioset/data/balanced_train_segments/ diff --git a/dac-codec/conf/ablations/no-low-hop.yml b/dac-codec/conf/ablations/no-low-hop.yml deleted file mode 100644 index abde9239186444fd045eccc8bbbc409b27e7ff9e..0000000000000000000000000000000000000000 --- a/dac-codec/conf/ablations/no-low-hop.yml +++ /dev/null @@ -1,18 +0,0 @@ -$include: - - conf/base.yml - - conf/1gpu.yml - -MelSpectrogramLoss.n_mels: [80] -MelSpectrogramLoss.window_lengths: [512] -MelSpectrogramLoss.mel_fmin: [0] -MelSpectrogramLoss.mel_fmax: [null] -MelSpectrogramLoss.pow: 1.0 -MelSpectrogramLoss.clamp_eps: 1.0e-5 -MelSpectrogramLoss.mag_weight: 0.0 - -lambdas: - mel/loss: 100.0 - adv/feat_loss: 2.0 - adv/gen_loss: 1.0 - vq/commitment_loss: 0.25 - vq/codebook_loss: 1.0 diff --git a/dac-codec/conf/ablations/no-mb.yml b/dac-codec/conf/ablations/no-mb.yml deleted file mode 100644 index 3aa0015b26e61e6811e1f1f01914277390a459fe..0000000000000000000000000000000000000000 --- a/dac-codec/conf/ablations/no-mb.yml +++ /dev/null @@ -1,17 +0,0 @@ -$include: - - conf/base.yml - - conf/1gpu.yml - -Discriminator.sample_rate: 44100 -Discriminator.fft_sizes: [2048, 1024, 512] -Discriminator.bands: - - [0.0, 1.0] - -# re-weight lambdas to make up for -# lost discriminators vs baseline -lambdas: - mel/loss: 15.0 - adv/feat_loss: 5.0 - adv/gen_loss: 1.0 - vq/commitment_loss: 0.25 - vq/codebook_loss: 1.0 diff --git a/dac-codec/conf/ablations/no-mpd-msd.yml b/dac-codec/conf/ablations/no-mpd-msd.yml deleted file mode 100644 index 9059b825cea1606ad9ebc0ba352bd5705eab8576..0000000000000000000000000000000000000000 --- a/dac-codec/conf/ablations/no-mpd-msd.yml +++ /dev/null @@ -1,21 +0,0 @@ -$include: - - conf/base.yml - - conf/1gpu.yml - 
-Discriminator.sample_rate: 44100 -Discriminator.rates: [] -Discriminator.periods: [] -Discriminator.fft_sizes: [2048, 1024, 512] -Discriminator.bands: - - [0.0, 0.1] - - [0.1, 0.25] - - [0.25, 0.5] - - [0.5, 0.75] - - [0.75, 1.0] - -lambdas: - mel/loss: 15.0 - adv/feat_loss: 2.66 - adv/gen_loss: 1.0 - vq/commitment_loss: 0.25 - vq/codebook_loss: 1.0 diff --git a/dac-codec/conf/ablations/no-mpd.yml b/dac-codec/conf/ablations/no-mpd.yml deleted file mode 100644 index 0e4dc933bdaa8f31cbc28345da838b5ceccd37a5..0000000000000000000000000000000000000000 --- a/dac-codec/conf/ablations/no-mpd.yml +++ /dev/null @@ -1,21 +0,0 @@ -$include: - - conf/base.yml - - conf/1gpu.yml - -Discriminator.sample_rate: 44100 -Discriminator.rates: [1] -Discriminator.periods: [] -Discriminator.fft_sizes: [2048, 1024, 512] -Discriminator.bands: - - [0.0, 0.1] - - [0.1, 0.25] - - [0.25, 0.5] - - [0.5, 0.75] - - [0.75, 1.0] - -lambdas: - mel/loss: 15.0 - adv/feat_loss: 2.5 - adv/gen_loss: 1.0 - vq/commitment_loss: 0.25 - vq/codebook_loss: 1.0 diff --git a/dac-codec/conf/ablations/only-speech.yml b/dac-codec/conf/ablations/only-speech.yml deleted file mode 100644 index c2bbc0d34ea8911effe31d4f0aa4ffcc56747b01..0000000000000000000000000000000000000000 --- a/dac-codec/conf/ablations/only-speech.yml +++ /dev/null @@ -1,22 +0,0 @@ -$include: - - conf/base.yml - - conf/1gpu.yml - -train/build_dataset.folders: - speech_fb: - - /data/daps/train - speech_hq: - - /data/vctk - - /data/vocalset - - /data/read_speech - - /data/french_speech - speech_uq: - - /data/emotional_speech/ - - /data/common_voice/ - - /data/german_speech/ - - /data/russian_speech/ - - /data/spanish_speech/ - -val/build_dataset.folders: - speech_hq: - - /data/daps/val diff --git a/dac-codec/conf/base.yml b/dac-codec/conf/base.yml deleted file mode 100644 index 746e6f76e1274ba3b50688303192b0e67de82123..0000000000000000000000000000000000000000 --- a/dac-codec/conf/base.yml +++ /dev/null @@ -1,123 +0,0 @@ -# Model setup -DAC.sample_rate: 44100 -DAC.encoder_dim: 64 -DAC.encoder_rates: [2, 4, 8, 8] -DAC.decoder_dim: 1536 -DAC.decoder_rates: [8, 8, 4, 2] - -# Quantization -DAC.n_codebooks: 9 -DAC.codebook_size: 1024 -DAC.codebook_dim: 8 -DAC.quantizer_dropout: 1.0 - -# Discriminator -Discriminator.sample_rate: 44100 -Discriminator.rates: [] -Discriminator.periods: [2, 3, 5, 7, 11] -Discriminator.fft_sizes: [2048, 1024, 512] -Discriminator.bands: - - [0.0, 0.1] - - [0.1, 0.25] - - [0.25, 0.5] - - [0.5, 0.75] - - [0.75, 1.0] - -# Optimization -AdamW.betas: [0.8, 0.99] -AdamW.lr: 0.0001 -ExponentialLR.gamma: 0.999996 - -amp: false -val_batch_size: 100 -device: cuda -num_iters: 250000 -save_iters: [10000, 50000, 100000, 200000] -valid_freq: 1000 -sample_freq: 10000 -num_workers: 32 -val_idx: [0, 1, 2, 3, 4, 5, 6, 7] -seed: 0 -lambdas: - mel/loss: 15.0 - adv/feat_loss: 2.0 - adv/gen_loss: 1.0 - vq/commitment_loss: 0.25 - vq/codebook_loss: 1.0 - -VolumeNorm.db: [const, -16] - -# Transforms -build_transform.preprocess: - - Identity -build_transform.augment_prob: 0.0 -build_transform.augment: - - Identity -build_transform.postprocess: - - VolumeNorm - - RescaleAudio - - ShiftPhase - -# Loss setup -MultiScaleSTFTLoss.window_lengths: [2048, 512] -MelSpectrogramLoss.n_mels: [5, 10, 20, 40, 80, 160, 320] -MelSpectrogramLoss.window_lengths: [32, 64, 128, 256, 512, 1024, 2048] -MelSpectrogramLoss.mel_fmin: [0, 0, 0, 0, 0, 0, 0] -MelSpectrogramLoss.mel_fmax: [null, null, null, null, null, null, null] -MelSpectrogramLoss.pow: 1.0 -MelSpectrogramLoss.clamp_eps: 1.0e-5 
-MelSpectrogramLoss.mag_weight: 0.0 - -# Data -batch_size: 72 -train/AudioDataset.duration: 0.38 -train/AudioDataset.n_examples: 10000000 - -val/AudioDataset.duration: 5.0 -val/build_transform.augment_prob: 1.0 -val/AudioDataset.n_examples: 250 - -test/AudioDataset.duration: 10.0 -test/build_transform.augment_prob: 1.0 -test/AudioDataset.n_examples: 1000 - -AudioLoader.shuffle: true -AudioDataset.without_replacement: true - -train/build_dataset.folders: - speech_fb: - - /data/daps/train - speech_hq: - - /data/vctk - - /data/vocalset - - /data/read_speech - - /data/french_speech - speech_uq: - - /data/emotional_speech/ - - /data/common_voice/ - - /data/german_speech/ - - /data/russian_speech/ - - /data/spanish_speech/ - music_hq: - - /data/musdb/train - music_uq: - - /data/jamendo - general: - - /data/audioset/data/unbalanced_train_segments/ - - /data/audioset/data/balanced_train_segments/ - -val/build_dataset.folders: - speech_hq: - - /data/daps/val - music_hq: - - /data/musdb/test - general: - - /data/audioset/data/eval_segments/ - -test/build_dataset.folders: - speech_hq: - - /data/daps/test - music_hq: - - /data/musdb/test - general: - - /data/audioset/data/eval_segments/ diff --git a/dac-codec/conf/downsampling/1024x.yml b/dac-codec/conf/downsampling/1024x.yml deleted file mode 100644 index 2719f9b0378d1ac4b4b0cbd2774f7e51cb4d8ad2..0000000000000000000000000000000000000000 --- a/dac-codec/conf/downsampling/1024x.yml +++ /dev/null @@ -1,16 +0,0 @@ -$include: - - conf/base.yml - - conf/1gpu.yml - -# Model setup -DAC.sample_rate: 44100 -DAC.encoder_dim: 64 -DAC.encoder_rates: [2, 8, 8, 8] -DAC.decoder_dim: 1536 -DAC.decoder_rates: [8, 4, 4, 2, 2, 2] - -# Quantization -DAC.n_codebooks: 19 -DAC.codebook_size: 1024 -DAC.codebook_dim: 8 -DAC.quantizer_dropout: 1.0 diff --git a/dac-codec/conf/downsampling/128x.yml b/dac-codec/conf/downsampling/128x.yml deleted file mode 100644 index cf7d5a417b50717b2ce24a7cc11bb5630a638c02..0000000000000000000000000000000000000000 --- a/dac-codec/conf/downsampling/128x.yml +++ /dev/null @@ -1,16 +0,0 @@ -$include: - - conf/base.yml - - conf/1gpu.yml - -# Model setup -DAC.sample_rate: 44100 -DAC.encoder_dim: 64 -DAC.encoder_rates: [2, 4, 4, 4] -DAC.decoder_dim: 1536 -DAC.decoder_rates: [4, 4, 2, 2, 2, 1] - -# Quantization -DAC.n_codebooks: 2 -DAC.codebook_size: 1024 -DAC.codebook_dim: 8 -DAC.quantizer_dropout: 1.0 diff --git a/dac-codec/conf/downsampling/1536x.yml b/dac-codec/conf/downsampling/1536x.yml deleted file mode 100644 index fa695b1657a9f8b6902479d1da918a58333e6df7..0000000000000000000000000000000000000000 --- a/dac-codec/conf/downsampling/1536x.yml +++ /dev/null @@ -1,16 +0,0 @@ -$include: - - conf/base.yml - - conf/1gpu.yml - -# Model setup -DAC.sample_rate: 44100 -DAC.encoder_dim: 96 -DAC.encoder_rates: [2, 8, 8, 12] -DAC.decoder_dim: 1536 -DAC.decoder_rates: [12, 4, 4, 2, 2, 2] - -# Quantization -DAC.n_codebooks: 28 -DAC.codebook_size: 1024 -DAC.codebook_dim: 8 -DAC.quantizer_dropout: 1.0 diff --git a/dac-codec/conf/downsampling/768x.yml b/dac-codec/conf/downsampling/768x.yml deleted file mode 100644 index 81005454c63370d94fe2751f45a25bd1577c656d..0000000000000000000000000000000000000000 --- a/dac-codec/conf/downsampling/768x.yml +++ /dev/null @@ -1,16 +0,0 @@ -$include: - - conf/base.yml - - conf/1gpu.yml - -# Model setup -DAC.sample_rate: 44100 -DAC.encoder_dim: 64 -DAC.encoder_rates: [2, 6, 8, 8] -DAC.decoder_dim: 1536 -DAC.decoder_rates: [6, 4, 4, 2, 2, 2] - -# Quantization -DAC.n_codebooks: 14 -DAC.codebook_size: 1024 -DAC.codebook_dim: 8 
-DAC.quantizer_dropout: 1.0 diff --git a/dac-codec/conf/final/16khz.yml b/dac-codec/conf/final/16khz.yml deleted file mode 100644 index a86e10773ae3e9beff90d43237ce13f9207f782e..0000000000000000000000000000000000000000 --- a/dac-codec/conf/final/16khz.yml +++ /dev/null @@ -1,123 +0,0 @@ -# Model setup -DAC.sample_rate: 16000 -DAC.encoder_dim: 64 -DAC.encoder_rates: [2, 4, 5, 8] -DAC.decoder_dim: 1536 -DAC.decoder_rates: [8, 5, 4, 2] - -# Quantization -DAC.n_codebooks: 12 -DAC.codebook_size: 1024 -DAC.codebook_dim: 8 -DAC.quantizer_dropout: 0.5 - -# Discriminator -Discriminator.sample_rate: 16000 -Discriminator.rates: [] -Discriminator.periods: [2, 3, 5, 7, 11] -Discriminator.fft_sizes: [2048, 1024, 512] -Discriminator.bands: - - [0.0, 0.1] - - [0.1, 0.25] - - [0.25, 0.5] - - [0.5, 0.75] - - [0.75, 1.0] - -# Optimization -AdamW.betas: [0.8, 0.99] -AdamW.lr: 0.0001 -ExponentialLR.gamma: 0.999996 - -amp: false -val_batch_size: 100 -device: cuda -num_iters: 400000 -save_iters: [10000, 50000, 100000, 200000] -valid_freq: 1000 -sample_freq: 10000 -num_workers: 32 -val_idx: [0, 1, 2, 3, 4, 5, 6, 7] -seed: 0 -lambdas: - mel/loss: 15.0 - adv/feat_loss: 2.0 - adv/gen_loss: 1.0 - vq/commitment_loss: 0.25 - vq/codebook_loss: 1.0 - -VolumeNorm.db: [const, -16] - -# Transforms -build_transform.preprocess: - - Identity -build_transform.augment_prob: 0.0 -build_transform.augment: - - Identity -build_transform.postprocess: - - VolumeNorm - - RescaleAudio - - ShiftPhase - -# Loss setup -MultiScaleSTFTLoss.window_lengths: [2048, 512] -MelSpectrogramLoss.n_mels: [5, 10, 20, 40, 80, 160, 320] -MelSpectrogramLoss.window_lengths: [32, 64, 128, 256, 512, 1024, 2048] -MelSpectrogramLoss.mel_fmin: [0, 0, 0, 0, 0, 0, 0] -MelSpectrogramLoss.mel_fmax: [null, null, null, null, null, null, null] -MelSpectrogramLoss.pow: 1.0 -MelSpectrogramLoss.clamp_eps: 1.0e-5 -MelSpectrogramLoss.mag_weight: 0.0 - -# Data -batch_size: 72 -train/AudioDataset.duration: 0.38 -train/AudioDataset.n_examples: 10000000 - -val/AudioDataset.duration: 5.0 -val/build_transform.augment_prob: 1.0 -val/AudioDataset.n_examples: 250 - -test/AudioDataset.duration: 10.0 -test/build_transform.augment_prob: 1.0 -test/AudioDataset.n_examples: 1000 - -AudioLoader.shuffle: true -AudioDataset.without_replacement: true - -train/build_dataset.folders: - speech_fb: - - /data/daps/train - speech_hq: - - /data/vctk - - /data/vocalset - - /data/read_speech - - /data/french_speech - speech_uq: - - /data/emotional_speech/ - - /data/common_voice/ - - /data/german_speech/ - - /data/russian_speech/ - - /data/spanish_speech/ - music_hq: - - /data/musdb/train - music_uq: - - /data/jamendo - general: - - /data/audioset/data/unbalanced_train_segments/ - - /data/audioset/data/balanced_train_segments/ - -val/build_dataset.folders: - speech_hq: - - /data/daps/val - music_hq: - - /data/musdb/test - general: - - /data/audioset/data/eval_segments/ - -test/build_dataset.folders: - speech_hq: - - /data/daps/test - music_hq: - - /data/musdb/test - general: - - /data/audioset/data/eval_segments/ diff --git a/dac-codec/conf/final/24khz.yml b/dac-codec/conf/final/24khz.yml deleted file mode 100644 index b20298a302da392f1bcf414f03762bbb888593f7..0000000000000000000000000000000000000000 --- a/dac-codec/conf/final/24khz.yml +++ /dev/null @@ -1,123 +0,0 @@ -# Model setup -DAC.sample_rate: 24000 -DAC.encoder_dim: 64 -DAC.encoder_rates: [2, 4, 5, 8] -DAC.decoder_dim: 1536 -DAC.decoder_rates: [8, 5, 4, 2] - -# Quantization -DAC.n_codebooks: 32 -DAC.codebook_size: 1024 -DAC.codebook_dim: 8 
-DAC.quantizer_dropout: 0.5 - -# Discriminator -Discriminator.sample_rate: 24000 -Discriminator.rates: [] -Discriminator.periods: [2, 3, 5, 7, 11] -Discriminator.fft_sizes: [2048, 1024, 512] -Discriminator.bands: - - [0.0, 0.1] - - [0.1, 0.25] - - [0.25, 0.5] - - [0.5, 0.75] - - [0.75, 1.0] - -# Optimization -AdamW.betas: [0.8, 0.99] -AdamW.lr: 0.0001 -ExponentialLR.gamma: 0.999996 - -amp: false -val_batch_size: 100 -device: cuda -num_iters: 400000 -save_iters: [10000, 50000, 100000, 200000] -valid_freq: 1000 -sample_freq: 10000 -num_workers: 32 -val_idx: [0, 1, 2, 3, 4, 5, 6, 7] -seed: 0 -lambdas: - mel/loss: 15.0 - adv/feat_loss: 2.0 - adv/gen_loss: 1.0 - vq/commitment_loss: 0.25 - vq/codebook_loss: 1.0 - -VolumeNorm.db: [const, -16] - -# Transforms -build_transform.preprocess: - - Identity -build_transform.augment_prob: 0.0 -build_transform.augment: - - Identity -build_transform.postprocess: - - VolumeNorm - - RescaleAudio - - ShiftPhase - -# Loss setup -MultiScaleSTFTLoss.window_lengths: [2048, 512] -MelSpectrogramLoss.n_mels: [5, 10, 20, 40, 80, 160, 320] -MelSpectrogramLoss.window_lengths: [32, 64, 128, 256, 512, 1024, 2048] -MelSpectrogramLoss.mel_fmin: [0, 0, 0, 0, 0, 0, 0] -MelSpectrogramLoss.mel_fmax: [null, null, null, null, null, null, null] -MelSpectrogramLoss.pow: 1.0 -MelSpectrogramLoss.clamp_eps: 1.0e-5 -MelSpectrogramLoss.mag_weight: 0.0 - -# Data -batch_size: 72 -train/AudioDataset.duration: 0.38 -train/AudioDataset.n_examples: 10000000 - -val/AudioDataset.duration: 5.0 -val/build_transform.augment_prob: 1.0 -val/AudioDataset.n_examples: 250 - -test/AudioDataset.duration: 10.0 -test/build_transform.augment_prob: 1.0 -test/AudioDataset.n_examples: 1000 - -AudioLoader.shuffle: true -AudioDataset.without_replacement: true - -train/build_dataset.folders: - speech_fb: - - /data/daps/train - speech_hq: - - /data/vctk - - /data/vocalset - - /data/read_speech - - /data/french_speech - speech_uq: - - /data/emotional_speech/ - - /data/common_voice/ - - /data/german_speech/ - - /data/russian_speech/ - - /data/spanish_speech/ - music_hq: - - /data/musdb/train - music_uq: - - /data/jamendo - general: - - /data/audioset/data/unbalanced_train_segments/ - - /data/audioset/data/balanced_train_segments/ - -val/build_dataset.folders: - speech_hq: - - /data/daps/val - music_hq: - - /data/musdb/test - general: - - /data/audioset/data/eval_segments/ - -test/build_dataset.folders: - speech_hq: - - /data/daps/test - music_hq: - - /data/musdb/test - general: - - /data/audioset/data/eval_segments/ diff --git a/dac-codec/conf/final/44khz-16kbps.yml b/dac-codec/conf/final/44khz-16kbps.yml deleted file mode 100644 index 3ee405de7a23c1f6b1389d10dbaf162631bc1ac7..0000000000000000000000000000000000000000 --- a/dac-codec/conf/final/44khz-16kbps.yml +++ /dev/null @@ -1,124 +0,0 @@ -# Model setup -DAC.sample_rate: 44100 -DAC.encoder_dim: 64 -DAC.encoder_rates: [2, 4, 8, 8] -DAC.latent_dim: 128 -DAC.decoder_dim: 1536 -DAC.decoder_rates: [8, 8, 4, 2] - -# Quantization -DAC.n_codebooks: 18 # Max bitrate of 16kbps -DAC.codebook_size: 1024 -DAC.codebook_dim: 8 -DAC.quantizer_dropout: 0.5 - -# Discriminator -Discriminator.sample_rate: 44100 -Discriminator.rates: [] -Discriminator.periods: [2, 3, 5, 7, 11] -Discriminator.fft_sizes: [2048, 1024, 512] -Discriminator.bands: - - [0.0, 0.1] - - [0.1, 0.25] - - [0.25, 0.5] - - [0.5, 0.75] - - [0.75, 1.0] - -# Optimization -AdamW.betas: [0.8, 0.99] -AdamW.lr: 0.0001 -ExponentialLR.gamma: 0.999996 - -amp: false -val_batch_size: 100 -device: cuda -num_iters: 400000 
-save_iters: [10000, 50000, 100000, 200000] -valid_freq: 1000 -sample_freq: 10000 -num_workers: 32 -val_idx: [0, 1, 2, 3, 4, 5, 6, 7] -seed: 0 -lambdas: - mel/loss: 15.0 - adv/feat_loss: 2.0 - adv/gen_loss: 1.0 - vq/commitment_loss: 0.25 - vq/codebook_loss: 1.0 - -VolumeNorm.db: [const, -16] - -# Transforms -build_transform.preprocess: - - Identity -build_transform.augment_prob: 0.0 -build_transform.augment: - - Identity -build_transform.postprocess: - - VolumeNorm - - RescaleAudio - - ShiftPhase - -# Loss setup -MultiScaleSTFTLoss.window_lengths: [2048, 512] -MelSpectrogramLoss.n_mels: [5, 10, 20, 40, 80, 160, 320] -MelSpectrogramLoss.window_lengths: [32, 64, 128, 256, 512, 1024, 2048] -MelSpectrogramLoss.mel_fmin: [0, 0, 0, 0, 0, 0, 0] -MelSpectrogramLoss.mel_fmax: [null, null, null, null, null, null, null] -MelSpectrogramLoss.pow: 1.0 -MelSpectrogramLoss.clamp_eps: 1.0e-5 -MelSpectrogramLoss.mag_weight: 0.0 - -# Data -batch_size: 72 -train/AudioDataset.duration: 0.38 -train/AudioDataset.n_examples: 10000000 - -val/AudioDataset.duration: 5.0 -val/build_transform.augment_prob: 1.0 -val/AudioDataset.n_examples: 250 - -test/AudioDataset.duration: 10.0 -test/build_transform.augment_prob: 1.0 -test/AudioDataset.n_examples: 1000 - -AudioLoader.shuffle: true -AudioDataset.without_replacement: true - -train/build_dataset.folders: - speech_fb: - - /data/daps/train - speech_hq: - - /data/vctk - - /data/vocalset - - /data/read_speech - - /data/french_speech - speech_uq: - - /data/emotional_speech/ - - /data/common_voice/ - - /data/german_speech/ - - /data/russian_speech/ - - /data/spanish_speech/ - music_hq: - - /data/musdb/train - music_uq: - - /data/jamendo - general: - - /data/audioset/data/unbalanced_train_segments/ - - /data/audioset/data/balanced_train_segments/ - -val/build_dataset.folders: - speech_hq: - - /data/daps/val - music_hq: - - /data/musdb/test - general: - - /data/audioset/data/eval_segments/ - -test/build_dataset.folders: - speech_hq: - - /data/daps/test - music_hq: - - /data/musdb/test - general: - - /data/audioset/data/eval_segments/ diff --git a/dac-codec/conf/final/44khz.yml b/dac-codec/conf/final/44khz.yml deleted file mode 100644 index f3de25e37bcbe85212566b4ced30a6264b36f70a..0000000000000000000000000000000000000000 --- a/dac-codec/conf/final/44khz.yml +++ /dev/null @@ -1,123 +0,0 @@ -# Model setup -DAC.sample_rate: 44100 -DAC.encoder_dim: 64 -DAC.encoder_rates: [2, 4, 8, 8] -DAC.decoder_dim: 1536 -DAC.decoder_rates: [8, 8, 4, 2] - -# Quantization -DAC.n_codebooks: 9 -DAC.codebook_size: 1024 -DAC.codebook_dim: 8 -DAC.quantizer_dropout: 0.5 - -# Discriminator -Discriminator.sample_rate: 44100 -Discriminator.rates: [] -Discriminator.periods: [2, 3, 5, 7, 11] -Discriminator.fft_sizes: [2048, 1024, 512] -Discriminator.bands: - - [0.0, 0.1] - - [0.1, 0.25] - - [0.25, 0.5] - - [0.5, 0.75] - - [0.75, 1.0] - -# Optimization -AdamW.betas: [0.8, 0.99] -AdamW.lr: 0.0001 -ExponentialLR.gamma: 0.999996 - -amp: false -val_batch_size: 100 -device: cuda -num_iters: 400000 -save_iters: [10000, 50000, 100000, 200000] -valid_freq: 1000 -sample_freq: 10000 -num_workers: 32 -val_idx: [0, 1, 2, 3, 4, 5, 6, 7] -seed: 0 -lambdas: - mel/loss: 15.0 - adv/feat_loss: 2.0 - adv/gen_loss: 1.0 - vq/commitment_loss: 0.25 - vq/codebook_loss: 1.0 - -VolumeNorm.db: [const, -16] - -# Transforms -build_transform.preprocess: - - Identity -build_transform.augment_prob: 0.0 -build_transform.augment: - - Identity -build_transform.postprocess: - - VolumeNorm - - RescaleAudio - - ShiftPhase - -# Loss setup 
-MultiScaleSTFTLoss.window_lengths: [2048, 512] -MelSpectrogramLoss.n_mels: [5, 10, 20, 40, 80, 160, 320] -MelSpectrogramLoss.window_lengths: [32, 64, 128, 256, 512, 1024, 2048] -MelSpectrogramLoss.mel_fmin: [0, 0, 0, 0, 0, 0, 0] -MelSpectrogramLoss.mel_fmax: [null, null, null, null, null, null, null] -MelSpectrogramLoss.pow: 1.0 -MelSpectrogramLoss.clamp_eps: 1.0e-5 -MelSpectrogramLoss.mag_weight: 0.0 - -# Data -batch_size: 72 -train/AudioDataset.duration: 0.38 -train/AudioDataset.n_examples: 10000000 - -val/AudioDataset.duration: 5.0 -val/build_transform.augment_prob: 1.0 -val/AudioDataset.n_examples: 250 - -test/AudioDataset.duration: 10.0 -test/build_transform.augment_prob: 1.0 -test/AudioDataset.n_examples: 1000 - -AudioLoader.shuffle: true -AudioDataset.without_replacement: true - -train/build_dataset.folders: - speech_fb: - - /data/daps/train - speech_hq: - - /data/vctk - - /data/vocalset - - /data/read_speech - - /data/french_speech - speech_uq: - - /data/emotional_speech/ - - /data/common_voice/ - - /data/german_speech/ - - /data/russian_speech/ - - /data/spanish_speech/ - music_hq: - - /data/musdb/train - music_uq: - - /data/jamendo - general: - - /data/audioset/data/unbalanced_train_segments/ - - /data/audioset/data/balanced_train_segments/ - -val/build_dataset.folders: - speech_hq: - - /data/daps/val - music_hq: - - /data/musdb/test - general: - - /data/audioset/data/eval_segments/ - -test/build_dataset.folders: - speech_hq: - - /data/daps/test - music_hq: - - /data/musdb/test - general: - - /data/audioset/data/eval_segments/ diff --git a/dac-codec/conf/quantizer/24kbps.yml b/dac-codec/conf/quantizer/24kbps.yml deleted file mode 100644 index 1b2f26ae13433489f098a8037a248699989b2bc4..0000000000000000000000000000000000000000 --- a/dac-codec/conf/quantizer/24kbps.yml +++ /dev/null @@ -1,5 +0,0 @@ -$include: - - conf/base.yml - - conf/1gpu.yml - -DAC.n_codebooks: 28 diff --git a/dac-codec/conf/quantizer/256d.yml b/dac-codec/conf/quantizer/256d.yml deleted file mode 100644 index 2d958f85fd229074bc4a34176f792a7058378ee8..0000000000000000000000000000000000000000 --- a/dac-codec/conf/quantizer/256d.yml +++ /dev/null @@ -1,5 +0,0 @@ -$include: - - conf/base.yml - - conf/1gpu.yml - -DAC.codebook_dim: 256 diff --git a/dac-codec/conf/quantizer/2d.yml b/dac-codec/conf/quantizer/2d.yml deleted file mode 100644 index aae678efc13d0c85a16d9b56893640ed97064b00..0000000000000000000000000000000000000000 --- a/dac-codec/conf/quantizer/2d.yml +++ /dev/null @@ -1,5 +0,0 @@ -$include: - - conf/base.yml - - conf/1gpu.yml - -DAC.codebook_dim: 2 diff --git a/dac-codec/conf/quantizer/32d.yml b/dac-codec/conf/quantizer/32d.yml deleted file mode 100644 index 24ba180c7309131973d0a507266b6dfe1f1fdc7e..0000000000000000000000000000000000000000 --- a/dac-codec/conf/quantizer/32d.yml +++ /dev/null @@ -1,5 +0,0 @@ -$include: - - conf/base.yml - - conf/1gpu.yml - -DAC.codebook_dim: 32 diff --git a/dac-codec/conf/quantizer/4d.yml b/dac-codec/conf/quantizer/4d.yml deleted file mode 100644 index 48d52872cdceb770ee6b920b56345c191a5e26cb..0000000000000000000000000000000000000000 --- a/dac-codec/conf/quantizer/4d.yml +++ /dev/null @@ -1,5 +0,0 @@ -$include: - - conf/base.yml - - conf/1gpu.yml - -DAC.codebook_dim: 4 diff --git a/dac-codec/conf/quantizer/512d.yml b/dac-codec/conf/quantizer/512d.yml deleted file mode 100644 index 2a9d9ae4d49ac6b54837c3c6a27c17987cae90c8..0000000000000000000000000000000000000000 --- a/dac-codec/conf/quantizer/512d.yml +++ /dev/null @@ -1,5 +0,0 @@ -$include: - - conf/base.yml - - 
conf/1gpu.yml - -DAC.codebook_dim: 512 diff --git a/dac-codec/conf/quantizer/dropout-0.0.yml b/dac-codec/conf/quantizer/dropout-0.0.yml deleted file mode 100644 index 93a657765caff2e5aa9dde8b7d56af310a4e2d93..0000000000000000000000000000000000000000 --- a/dac-codec/conf/quantizer/dropout-0.0.yml +++ /dev/null @@ -1,5 +0,0 @@ -$include: - - conf/base.yml - - conf/1gpu.yml - -DAC.quantizer_dropout: 0.0 diff --git a/dac-codec/conf/quantizer/dropout-0.25.yml b/dac-codec/conf/quantizer/dropout-0.25.yml deleted file mode 100644 index d0c1ff43285e9968468bb4aed25668849466fb88..0000000000000000000000000000000000000000 --- a/dac-codec/conf/quantizer/dropout-0.25.yml +++ /dev/null @@ -1,5 +0,0 @@ -$include: - - conf/base.yml - - conf/1gpu.yml - -DAC.quantizer_dropout: 0.25 diff --git a/dac-codec/conf/quantizer/dropout-0.5.yml b/dac-codec/conf/quantizer/dropout-0.5.yml deleted file mode 100644 index f6682b31fc3e20456da49bb51bcbe8bc6b07d7fa..0000000000000000000000000000000000000000 --- a/dac-codec/conf/quantizer/dropout-0.5.yml +++ /dev/null @@ -1,5 +0,0 @@ -$include: - - conf/base.yml - - conf/1gpu.yml - -DAC.quantizer_dropout: 0.5 diff --git a/dac-codec/conf/size/medium.yml b/dac-codec/conf/size/medium.yml deleted file mode 100644 index 5751decf073e136848351f0576443870bf51f42f..0000000000000000000000000000000000000000 --- a/dac-codec/conf/size/medium.yml +++ /dev/null @@ -1,5 +0,0 @@ -$include: - - conf/base.yml - - conf/1gpu.yml - -DAC.decoder_dim: 1024 diff --git a/dac-codec/conf/size/small.yml b/dac-codec/conf/size/small.yml deleted file mode 100644 index d67649bf050aceb9d737d9f5461cdba831ff9d19..0000000000000000000000000000000000000000 --- a/dac-codec/conf/size/small.yml +++ /dev/null @@ -1,5 +0,0 @@ -$include: - - conf/base.yml - - conf/1gpu.yml - -DAC.decoder_dim: 512 diff --git a/dac-codec/dac/__init__.py b/dac-codec/dac/__init__.py deleted file mode 100644 index 51205ef6ded9c6735a988b76008e0f6bdce8e215..0000000000000000000000000000000000000000 --- a/dac-codec/dac/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -__version__ = "1.0.0" - -# preserved here for legacy reasons -__model_version__ = "latest" - -import audiotools - -audiotools.ml.BaseModel.INTERN += ["dac.**"] -audiotools.ml.BaseModel.EXTERN += ["einops"] - - -from . import nn -from . import model -from . import utils -from .model import DAC -from .model import DACFile diff --git a/dac-codec/dac/__main__.py b/dac-codec/dac/__main__.py deleted file mode 100644 index 2fa8d15307997663f8143669c2bd56e0889cb021..0000000000000000000000000000000000000000 --- a/dac-codec/dac/__main__.py +++ /dev/null @@ -1,36 +0,0 @@ -import sys - -import argbind - -from dac.utils import download -from dac.utils.decode import decode -from dac.utils.encode import encode - -STAGES = ["encode", "decode", "download"] - - -def run(stage: str): - """Run stages. - - Parameters - ---------- - stage : str - Stage to run - """ - if stage not in STAGES: - raise ValueError(f"Unknown command: {stage}. 
Allowed commands are {STAGES}") - stage_fn = globals()[stage] - - if stage == "download": - stage_fn() - return - - stage_fn() - - -if __name__ == "__main__": - group = sys.argv.pop(1) - args = argbind.parse_args(group=group) - - with argbind.scope(args): - run(group) diff --git a/dac-codec/dac/__pycache__/__init__.cpython-310.pyc b/dac-codec/dac/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 3926a25606b38aaf5bee89c51a8651e4e8abd591..0000000000000000000000000000000000000000 Binary files a/dac-codec/dac/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/dac-codec/dac/__pycache__/__main__.cpython-310.pyc b/dac-codec/dac/__pycache__/__main__.cpython-310.pyc deleted file mode 100644 index 2bb7debf53a3d8553e8e62c3fd5fff593b8e5bbf..0000000000000000000000000000000000000000 Binary files a/dac-codec/dac/__pycache__/__main__.cpython-310.pyc and /dev/null differ diff --git a/dac-codec/dac/compare/__init__.py b/dac-codec/dac/compare/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/dac-codec/dac/compare/encodec.py b/dac-codec/dac/compare/encodec.py deleted file mode 100644 index 42877de3cffa7d681b28266e4e1f537d48b749eb..0000000000000000000000000000000000000000 --- a/dac-codec/dac/compare/encodec.py +++ /dev/null @@ -1,54 +0,0 @@ -import torch -from audiotools import AudioSignal -from audiotools.ml import BaseModel -from encodec import EncodecModel - - -class Encodec(BaseModel): - def __init__(self, sample_rate: int = 24000, bandwidth: float = 24.0): - super().__init__() - - if sample_rate == 24000: - self.model = EncodecModel.encodec_model_24khz() - else: - self.model = EncodecModel.encodec_model_48khz() - self.model.set_target_bandwidth(bandwidth) - self.sample_rate = 44100 - - def forward( - self, - audio_data: torch.Tensor, - sample_rate: int = 44100, - n_quantizers: int = None, - ): - signal = AudioSignal(audio_data, sample_rate) - signal.resample(self.model.sample_rate) - recons = self.model(signal.audio_data) - recons = AudioSignal(recons, self.model.sample_rate) - recons.resample(sample_rate) - return {"audio": recons.audio_data} - - -if __name__ == "__main__": - import numpy as np - from functools import partial - - model = Encodec() - - for n, m in model.named_modules(): - o = m.extra_repr() - p = sum([np.prod(p.size()) for p in m.parameters()]) - fn = lambda o, p: o + f" {p/1e6:<.3f}M params." 
- setattr(m, "extra_repr", partial(fn, o=o, p=p)) - print(model) - print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()])) - - length = 88200 * 2 - x = torch.randn(1, 1, length).to(model.device) - x.requires_grad_(True) - x.retain_grad() - - # Make a forward pass - out = model(x)["audio"] - - print(x.shape, out.shape) diff --git a/dac-codec/dac/model/__init__.py b/dac-codec/dac/model/__init__.py deleted file mode 100644 index 02a75b7ad6028f5c41b6a8285b0257d4c23bdfcf..0000000000000000000000000000000000000000 --- a/dac-codec/dac/model/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .base import CodecMixin -from .base import DACFile -from .dac import DAC -from .discriminator import Discriminator diff --git a/dac-codec/dac/model/__pycache__/__init__.cpython-310.pyc b/dac-codec/dac/model/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 9cfb0cd3b1b5248a3274a59416e052d461aa6457..0000000000000000000000000000000000000000 Binary files a/dac-codec/dac/model/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/dac-codec/dac/model/__pycache__/base.cpython-310.pyc b/dac-codec/dac/model/__pycache__/base.cpython-310.pyc deleted file mode 100644 index 52a1ea921233d4b27a9446b0c1e4df56555de9b5..0000000000000000000000000000000000000000 Binary files a/dac-codec/dac/model/__pycache__/base.cpython-310.pyc and /dev/null differ diff --git a/dac-codec/dac/model/__pycache__/dac.cpython-310.pyc b/dac-codec/dac/model/__pycache__/dac.cpython-310.pyc deleted file mode 100644 index 5d0dd9fa678d29d89eabd59f7bf072089e128190..0000000000000000000000000000000000000000 Binary files a/dac-codec/dac/model/__pycache__/dac.cpython-310.pyc and /dev/null differ diff --git a/dac-codec/dac/model/__pycache__/discriminator.cpython-310.pyc b/dac-codec/dac/model/__pycache__/discriminator.cpython-310.pyc deleted file mode 100644 index 6976c240c9f308342fbc6597d0c219845428fe57..0000000000000000000000000000000000000000 Binary files a/dac-codec/dac/model/__pycache__/discriminator.cpython-310.pyc and /dev/null differ diff --git a/dac-codec/dac/model/base.py b/dac-codec/dac/model/base.py deleted file mode 100644 index 546b3cb7092d6bd1837ec780228d2a5b3e01fe8d..0000000000000000000000000000000000000000 --- a/dac-codec/dac/model/base.py +++ /dev/null @@ -1,294 +0,0 @@ -import math -from dataclasses import dataclass -from pathlib import Path -from typing import Union - -import numpy as np -import torch -import tqdm -from audiotools import AudioSignal -from torch import nn - -SUPPORTED_VERSIONS = ["1.0.0"] - - -@dataclass -class DACFile: - codes: torch.Tensor - - # Metadata - chunk_length: int - original_length: int - input_db: float - channels: int - sample_rate: int - padding: bool - dac_version: str - - def save(self, path): - artifacts = { - "codes": self.codes.numpy().astype(np.uint16), - "metadata": { - "input_db": self.input_db.numpy().astype(np.float32), - "original_length": self.original_length, - "sample_rate": self.sample_rate, - "chunk_length": self.chunk_length, - "channels": self.channels, - "padding": self.padding, - "dac_version": SUPPORTED_VERSIONS[-1], - }, - } - path = Path(path).with_suffix(".dac") - with open(path, "wb") as f: - np.save(f, artifacts) - return path - - @classmethod - def load(cls, path): - artifacts = np.load(path, allow_pickle=True)[()] - codes = torch.from_numpy(artifacts["codes"].astype(int)) - if artifacts["metadata"].get("dac_version", None) not in SUPPORTED_VERSIONS: - raise RuntimeError( - f"Given file {path} can't be loaded with this version of 
descript-audio-codec." - ) - return cls(codes=codes, **artifacts["metadata"]) - - -class CodecMixin: - @property - def padding(self): - if not hasattr(self, "_padding"): - self._padding = True - return self._padding - - @padding.setter - def padding(self, value): - assert isinstance(value, bool) - - layers = [ - l for l in self.modules() if isinstance(l, (nn.Conv1d, nn.ConvTranspose1d)) - ] - - for layer in layers: - if value: - if hasattr(layer, "original_padding"): - layer.padding = layer.original_padding - else: - layer.original_padding = layer.padding - layer.padding = tuple(0 for _ in range(len(layer.padding))) - - self._padding = value - - def get_delay(self): - # Any number works here, delay is invariant to input length - l_out = self.get_output_length(0) - L = l_out - - layers = [] - for layer in self.modules(): - if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)): - layers.append(layer) - - for layer in reversed(layers): - d = layer.dilation[0] - k = layer.kernel_size[0] - s = layer.stride[0] - - if isinstance(layer, nn.ConvTranspose1d): - L = ((L - d * (k - 1) - 1) / s) + 1 - elif isinstance(layer, nn.Conv1d): - L = (L - 1) * s + d * (k - 1) + 1 - - L = math.ceil(L) - - l_in = L - - return (l_in - l_out) // 2 - - def get_output_length(self, input_length): - L = input_length - # Calculate output length - for layer in self.modules(): - if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)): - d = layer.dilation[0] - k = layer.kernel_size[0] - s = layer.stride[0] - - if isinstance(layer, nn.Conv1d): - L = ((L - d * (k - 1) - 1) / s) + 1 - elif isinstance(layer, nn.ConvTranspose1d): - L = (L - 1) * s + d * (k - 1) + 1 - - L = math.floor(L) - return L - - @torch.no_grad() - def compress( - self, - audio_path_or_signal: Union[str, Path, AudioSignal], - win_duration: float = 1.0, - verbose: bool = False, - normalize_db: float = -16, - n_quantizers: int = None, - ) -> DACFile: - """Processes an audio signal from a file or AudioSignal object into - discrete codes. This function processes the signal in short windows, - using constant GPU memory. 
- - Parameters - ---------- - audio_path_or_signal : Union[str, Path, AudioSignal] - audio signal to reconstruct - win_duration : float, optional - window duration in seconds, by default 5.0 - verbose : bool, optional - by default False - normalize_db : float, optional - normalize db, by default -16 - - Returns - ------- - DACFile - Object containing compressed codes and metadata - required for decompression - """ - audio_signal = audio_path_or_signal - if isinstance(audio_signal, (str, Path)): - audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal)) - - self.eval() - original_padding = self.padding - original_device = audio_signal.device - - audio_signal = audio_signal.clone() - original_sr = audio_signal.sample_rate - - resample_fn = audio_signal.resample - loudness_fn = audio_signal.loudness - - # If audio is > 10 minutes long, use the ffmpeg versions - if audio_signal.signal_duration >= 10 * 60 * 60: - resample_fn = audio_signal.ffmpeg_resample - loudness_fn = audio_signal.ffmpeg_loudness - - original_length = audio_signal.signal_length - resample_fn(self.sample_rate) - input_db = loudness_fn() - - if normalize_db is not None: - audio_signal.normalize(normalize_db) - audio_signal.ensure_max_of_audio() - - nb, nac, nt = audio_signal.audio_data.shape - audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt) - win_duration = ( - audio_signal.signal_duration if win_duration is None else win_duration - ) - - if audio_signal.signal_duration <= win_duration: - # Unchunked compression (used if signal length < win duration) - self.padding = True - n_samples = nt - hop = nt - else: - # Chunked inference - self.padding = False - # Zero-pad signal on either side by the delay - audio_signal.zero_pad(self.delay, self.delay) - n_samples = int(win_duration * self.sample_rate) - # Round n_samples to nearest hop length multiple - n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length) - hop = self.get_output_length(n_samples) - - codes = [] - range_fn = range if not verbose else tqdm.trange - - for i in range_fn(0, nt, hop): - x = audio_signal[..., i : i + n_samples] - x = x.zero_pad(0, max(0, n_samples - x.shape[-1])) - - audio_data = x.audio_data.to(self.device) - audio_data = self.preprocess(audio_data, self.sample_rate) - _, c, _, _, _ = self.encode(audio_data, n_quantizers) - codes.append(c.to(original_device)) - chunk_length = c.shape[-1] - - codes = torch.cat(codes, dim=-1) - - dac_file = DACFile( - codes=codes, - chunk_length=chunk_length, - original_length=original_length, - input_db=input_db, - channels=nac, - sample_rate=original_sr, - padding=self.padding, - dac_version=SUPPORTED_VERSIONS[-1], - ) - - if n_quantizers is not None: - codes = codes[:, :n_quantizers, :] - - self.padding = original_padding - return dac_file - - @torch.no_grad() - def decompress( - self, - obj: Union[str, Path, DACFile], - verbose: bool = False, - ) -> AudioSignal: - """Reconstruct audio from a given .dac file - - Parameters - ---------- - obj : Union[str, Path, DACFile] - .dac file location or corresponding DACFile object. 
- verbose : bool, optional - Prints progress if True, by default False - - Returns - ------- - AudioSignal - Object with the reconstructed audio - """ - self.eval() - if isinstance(obj, (str, Path)): - obj = DACFile.load(obj) - - original_padding = self.padding - self.padding = obj.padding - - range_fn = range if not verbose else tqdm.trange - codes = obj.codes - original_device = codes.device - chunk_length = obj.chunk_length - recons = [] - - for i in range_fn(0, codes.shape[-1], chunk_length): - c = codes[..., i : i + chunk_length].to(self.device) - z = self.quantizer.from_codes(c)[0] - r = self.decode(z) - recons.append(r.to(original_device)) - - recons = torch.cat(recons, dim=-1) - recons = AudioSignal(recons, self.sample_rate) - - resample_fn = recons.resample - loudness_fn = recons.loudness - - # If audio is > 10 minutes long, use the ffmpeg versions - if recons.signal_duration >= 10 * 60 * 60: - resample_fn = recons.ffmpeg_resample - loudness_fn = recons.ffmpeg_loudness - - recons.normalize(obj.input_db) - resample_fn(obj.sample_rate) - recons = recons[..., : obj.original_length] - loudness_fn() - recons.audio_data = recons.audio_data.reshape( - -1, obj.channels, obj.original_length - ) - - self.padding = original_padding - return recons diff --git a/dac-codec/dac/model/dac.py b/dac-codec/dac/model/dac.py deleted file mode 100644 index eb754b25a30a182b85db93783c9b2748a61e6dd2..0000000000000000000000000000000000000000 --- a/dac-codec/dac/model/dac.py +++ /dev/null @@ -1,364 +0,0 @@ -import math -from typing import List -from typing import Union - -import numpy as np -import torch -from audiotools import AudioSignal -from audiotools.ml import BaseModel -from torch import nn - -from .base import CodecMixin -from dac.nn.layers import Snake1d -from dac.nn.layers import WNConv1d -from dac.nn.layers import WNConvTranspose1d -from dac.nn.quantize import ResidualVectorQuantize - - -def init_weights(m): - if isinstance(m, nn.Conv1d): - nn.init.trunc_normal_(m.weight, std=0.02) - nn.init.constant_(m.bias, 0) - - -class ResidualUnit(nn.Module): - def __init__(self, dim: int = 16, dilation: int = 1): - super().__init__() - pad = ((7 - 1) * dilation) // 2 - self.block = nn.Sequential( - Snake1d(dim), - WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad), - Snake1d(dim), - WNConv1d(dim, dim, kernel_size=1), - ) - - def forward(self, x): - y = self.block(x) - pad = (x.shape[-1] - y.shape[-1]) // 2 - if pad > 0: - x = x[..., pad:-pad] - return x + y - - -class EncoderBlock(nn.Module): - def __init__(self, dim: int = 16, stride: int = 1): - super().__init__() - self.block = nn.Sequential( - ResidualUnit(dim // 2, dilation=1), - ResidualUnit(dim // 2, dilation=3), - ResidualUnit(dim // 2, dilation=9), - Snake1d(dim // 2), - WNConv1d( - dim // 2, - dim, - kernel_size=2 * stride, - stride=stride, - padding=math.ceil(stride / 2), - ), - ) - - def forward(self, x): - return self.block(x) - - -class Encoder(nn.Module): - def __init__( - self, - d_model: int = 64, - strides: list = [2, 4, 8, 8], - d_latent: int = 64, - ): - super().__init__() - # Create first convolution - self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)] - - # Create EncoderBlocks that double channels as they downsample by `stride` - for stride in strides: - d_model *= 2 - self.block += [EncoderBlock(d_model, stride=stride)] - - # Create last convolution - self.block += [ - Snake1d(d_model), - WNConv1d(d_model, d_latent, kernel_size=3, padding=1), - ] - - # Wrap black into nn.Sequential - self.block = 
nn.Sequential(*self.block) - self.enc_dim = d_model - - def forward(self, x): - return self.block(x) - - -class DecoderBlock(nn.Module): - def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1): - super().__init__() - self.block = nn.Sequential( - Snake1d(input_dim), - WNConvTranspose1d( - input_dim, - output_dim, - kernel_size=2 * stride, - stride=stride, - padding=math.ceil(stride / 2), - ), - ResidualUnit(output_dim, dilation=1), - ResidualUnit(output_dim, dilation=3), - ResidualUnit(output_dim, dilation=9), - ) - - def forward(self, x): - return self.block(x) - - -class Decoder(nn.Module): - def __init__( - self, - input_channel, - channels, - rates, - d_out: int = 1, - ): - super().__init__() - - # Add first conv layer - layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)] - - # Add upsampling + MRF blocks - for i, stride in enumerate(rates): - input_dim = channels // 2**i - output_dim = channels // 2 ** (i + 1) - layers += [DecoderBlock(input_dim, output_dim, stride)] - - # Add final conv layer - layers += [ - Snake1d(output_dim), - WNConv1d(output_dim, d_out, kernel_size=7, padding=3), - nn.Tanh(), - ] - - self.model = nn.Sequential(*layers) - - def forward(self, x): - return self.model(x) - - -class DAC(BaseModel, CodecMixin): - def __init__( - self, - encoder_dim: int = 64, - encoder_rates: List[int] = [2, 4, 8, 8], - latent_dim: int = None, - decoder_dim: int = 1536, - decoder_rates: List[int] = [8, 8, 4, 2], - n_codebooks: int = 9, - codebook_size: int = 1024, - codebook_dim: Union[int, list] = 8, - quantizer_dropout: bool = False, - sample_rate: int = 44100, - ): - super().__init__() - - self.encoder_dim = encoder_dim - self.encoder_rates = encoder_rates - self.decoder_dim = decoder_dim - self.decoder_rates = decoder_rates - self.sample_rate = sample_rate - - if latent_dim is None: - latent_dim = encoder_dim * (2 ** len(encoder_rates)) - - self.latent_dim = latent_dim - - self.hop_length = np.prod(encoder_rates) - self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim) - - self.n_codebooks = n_codebooks - self.codebook_size = codebook_size - self.codebook_dim = codebook_dim - self.quantizer = ResidualVectorQuantize( - input_dim=latent_dim, - n_codebooks=n_codebooks, - codebook_size=codebook_size, - codebook_dim=codebook_dim, - quantizer_dropout=quantizer_dropout, - ) - - self.decoder = Decoder( - latent_dim, - decoder_dim, - decoder_rates, - ) - self.sample_rate = sample_rate - self.apply(init_weights) - - self.delay = self.get_delay() - - def preprocess(self, audio_data, sample_rate): - if sample_rate is None: - sample_rate = self.sample_rate - assert sample_rate == self.sample_rate - - length = audio_data.shape[-1] - right_pad = math.ceil(length / self.hop_length) * self.hop_length - length - audio_data = nn.functional.pad(audio_data, (0, right_pad)) - - return audio_data - - def encode( - self, - audio_data: torch.Tensor, - n_quantizers: int = None, - ): - """Encode given audio data and return quantized latent codes - - Parameters - ---------- - audio_data : Tensor[B x 1 x T] - Audio data to encode - n_quantizers : int, optional - Number of quantizers to use, by default None - If None, all quantizers are used. 
- - Returns - ------- - dict - A dictionary with the following keys: - "z" : Tensor[B x D x T] - Quantized continuous representation of input - "codes" : Tensor[B x N x T] - Codebook indices for each codebook - (quantized discrete representation of input) - "latents" : Tensor[B x N*D x T] - Projected latents (continuous representation of input before quantization) - "vq/commitment_loss" : Tensor[1] - Commitment loss to train encoder to predict vectors closer to codebook - entries - "vq/codebook_loss" : Tensor[1] - Codebook loss to update the codebook - "length" : int - Number of samples in input audio - """ - z = self.encoder(audio_data) - z, codes, latents, commitment_loss, codebook_loss = self.quantizer( - z, n_quantizers - ) - return z, codes, latents, commitment_loss, codebook_loss - - def decode(self, z: torch.Tensor): - """Decode given latent codes and return audio data - - Parameters - ---------- - z : Tensor[B x D x T] - Quantized continuous representation of input - length : int, optional - Number of samples in output audio, by default None - - Returns - ------- - dict - A dictionary with the following keys: - "audio" : Tensor[B x 1 x length] - Decoded audio data. - """ - return self.decoder(z) - - def forward( - self, - audio_data: torch.Tensor, - sample_rate: int = None, - n_quantizers: int = None, - ): - """Model forward pass - - Parameters - ---------- - audio_data : Tensor[B x 1 x T] - Audio data to encode - sample_rate : int, optional - Sample rate of audio data in Hz, by default None - If None, defaults to `self.sample_rate` - n_quantizers : int, optional - Number of quantizers to use, by default None. - If None, all quantizers are used. - - Returns - ------- - dict - A dictionary with the following keys: - "z" : Tensor[B x D x T] - Quantized continuous representation of input - "codes" : Tensor[B x N x T] - Codebook indices for each codebook - (quantized discrete representation of input) - "latents" : Tensor[B x N*D x T] - Projected latents (continuous representation of input before quantization) - "vq/commitment_loss" : Tensor[1] - Commitment loss to train encoder to predict vectors closer to codebook - entries - "vq/codebook_loss" : Tensor[1] - Codebook loss to update the codebook - "length" : int - Number of samples in input audio - "audio" : Tensor[B x 1 x length] - Decoded audio data. - """ - length = audio_data.shape[-1] - audio_data = self.preprocess(audio_data, sample_rate) - z, codes, latents, commitment_loss, codebook_loss = self.encode( - audio_data, n_quantizers - ) - - x = self.decode(z) - return { - "audio": x[..., :length], - "z": z, - "codes": codes, - "latents": latents, - "vq/commitment_loss": commitment_loss, - "vq/codebook_loss": codebook_loss, - } - - -if __name__ == "__main__": - import numpy as np - from functools import partial - - model = DAC().to("cpu") - - for n, m in model.named_modules(): - o = m.extra_repr() - p = sum([np.prod(p.size()) for p in m.parameters()]) - fn = lambda o, p: o + f" {p/1e6:<.3f}M params." 
- setattr(m, "extra_repr", partial(fn, o=o, p=p)) - print(model) - print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()])) - - length = 88200 * 2 - x = torch.randn(1, 1, length).to(model.device) - x.requires_grad_(True) - x.retain_grad() - - # Make a forward pass - out = model(x)["audio"] - print("Input shape:", x.shape) - print("Output shape:", out.shape) - - # Create gradient variable - grad = torch.zeros_like(out) - grad[:, :, grad.shape[-1] // 2] = 1 - - # Make a backward pass - out.backward(grad) - - # Check non-zero values - gradmap = x.grad.squeeze(0) - gradmap = (gradmap != 0).sum(0) # sum across features - rf = (gradmap != 0).sum() - - print(f"Receptive field: {rf.item()}") - - x = AudioSignal(torch.randn(1, 1, 44100 * 60), 44100) - model.decompress(model.compress(x, verbose=True), verbose=True) diff --git a/dac-codec/dac/model/discriminator.py b/dac-codec/dac/model/discriminator.py deleted file mode 100644 index 09c79d1342ca46bef21daca64667577f05e61638..0000000000000000000000000000000000000000 --- a/dac-codec/dac/model/discriminator.py +++ /dev/null @@ -1,228 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from audiotools import AudioSignal -from audiotools import ml -from audiotools import STFTParams -from einops import rearrange -from torch.nn.utils import weight_norm - - -def WNConv1d(*args, **kwargs): - act = kwargs.pop("act", True) - conv = weight_norm(nn.Conv1d(*args, **kwargs)) - if not act: - return conv - return nn.Sequential(conv, nn.LeakyReLU(0.1)) - - -def WNConv2d(*args, **kwargs): - act = kwargs.pop("act", True) - conv = weight_norm(nn.Conv2d(*args, **kwargs)) - if not act: - return conv - return nn.Sequential(conv, nn.LeakyReLU(0.1)) - - -class MPD(nn.Module): - def __init__(self, period): - super().__init__() - self.period = period - self.convs = nn.ModuleList( - [ - WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)), - WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)), - WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)), - WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)), - WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)), - ] - ) - self.conv_post = WNConv2d( - 1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False - ) - - def pad_to_period(self, x): - t = x.shape[-1] - x = F.pad(x, (0, self.period - t % self.period), mode="reflect") - return x - - def forward(self, x): - fmap = [] - - x = self.pad_to_period(x) - x = rearrange(x, "b c (l p) -> b c l p", p=self.period) - - for layer in self.convs: - x = layer(x) - fmap.append(x) - - x = self.conv_post(x) - fmap.append(x) - - return fmap - - -class MSD(nn.Module): - def __init__(self, rate: int = 1, sample_rate: int = 44100): - super().__init__() - self.convs = nn.ModuleList( - [ - WNConv1d(1, 16, 15, 1, padding=7), - WNConv1d(16, 64, 41, 4, groups=4, padding=20), - WNConv1d(64, 256, 41, 4, groups=16, padding=20), - WNConv1d(256, 1024, 41, 4, groups=64, padding=20), - WNConv1d(1024, 1024, 41, 4, groups=256, padding=20), - WNConv1d(1024, 1024, 5, 1, padding=2), - ] - ) - self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False) - self.sample_rate = sample_rate - self.rate = rate - - def forward(self, x): - x = AudioSignal(x, self.sample_rate) - x.resample(self.sample_rate // self.rate) - x = x.audio_data - - fmap = [] - - for l in self.convs: - x = l(x) - fmap.append(x) - x = self.conv_post(x) - fmap.append(x) - - return fmap - - -BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)] - - -class MRD(nn.Module): - def __init__( - self, - 
window_length: int, - hop_factor: float = 0.25, - sample_rate: int = 44100, - bands: list = BANDS, - ): - """Complex multi-band spectrogram discriminator. - Parameters - ---------- - window_length : int - Window length of STFT. - hop_factor : float, optional - Hop factor of the STFT, defaults to ``0.25 * window_length``. - sample_rate : int, optional - Sampling rate of audio in Hz, by default 44100 - bands : list, optional - Bands to run discriminator over. - """ - super().__init__() - - self.window_length = window_length - self.hop_factor = hop_factor - self.sample_rate = sample_rate - self.stft_params = STFTParams( - window_length=window_length, - hop_length=int(window_length * hop_factor), - match_stride=True, - ) - - n_fft = window_length // 2 + 1 - bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands] - self.bands = bands - - ch = 32 - convs = lambda: nn.ModuleList( - [ - WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)), - WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), - WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), - WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), - WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)), - ] - ) - self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))]) - self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False) - - def spectrogram(self, x): - x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params) - x = torch.view_as_real(x.stft()) - x = rearrange(x, "b 1 f t c -> (b 1) c t f") - # Split into bands - x_bands = [x[..., b[0] : b[1]] for b in self.bands] - return x_bands - - def forward(self, x): - x_bands = self.spectrogram(x) - fmap = [] - - x = [] - for band, stack in zip(x_bands, self.band_convs): - for layer in stack: - band = layer(band) - fmap.append(band) - x.append(band) - - x = torch.cat(x, dim=-1) - x = self.conv_post(x) - fmap.append(x) - - return fmap - - -class Discriminator(ml.BaseModel): - def __init__( - self, - rates: list = [], - periods: list = [2, 3, 5, 7, 11], - fft_sizes: list = [2048, 1024, 512], - sample_rate: int = 44100, - bands: list = BANDS, - ): - """Discriminator that combines multiple discriminators. - - Parameters - ---------- - rates : list, optional - sampling rates (in Hz) to run MSD at, by default [] - If empty, MSD is not used. 
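For orientation, the gradient probe in the `__main__` block further up generalizes to any convolutional stack: push a one-hot gradient back from a single output sample and count the non-zero entries in the input gradient. A minimal, self-contained sketch of that trick (the three-layer stack below is a stand-in, not the DAC model):

```python
import torch
import torch.nn as nn

# Stand-in network; the probe above runs the same logic on the full codec.
net = nn.Sequential(
    nn.Conv1d(1, 8, kernel_size=7, padding=3),
    nn.Conv1d(8, 8, kernel_size=7, padding=3),
    nn.Conv1d(8, 1, kernel_size=7, padding=3),
)

x = torch.randn(1, 1, 1024, requires_grad=True)
y = net(x)

grad = torch.zeros_like(y)
grad[:, :, y.shape[-1] // 2] = 1   # poke a single output sample
y.backward(grad)

rf = (x.grad.squeeze() != 0).sum().item()
print(f"Receptive field: {rf} samples")  # three k=7 layers -> 1 + 3*(7-1) = 19
```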
- periods : list, optional - periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11] - fft_sizes : list, optional - Window sizes of the FFT to run MRD at, by default [2048, 1024, 512] - sample_rate : int, optional - Sampling rate of audio in Hz, by default 44100 - bands : list, optional - Bands to run MRD at, by default `BANDS` - """ - super().__init__() - discs = [] - discs += [MPD(p) for p in periods] - discs += [MSD(r, sample_rate=sample_rate) for r in rates] - discs += [MRD(f, sample_rate=sample_rate, bands=bands) for f in fft_sizes] - self.discriminators = nn.ModuleList(discs) - - def preprocess(self, y): - # Remove DC offset - y = y - y.mean(dim=-1, keepdims=True) - # Peak normalize the volume of input audio - y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9) - return y - - def forward(self, x): - x = self.preprocess(x) - fmaps = [d(x) for d in self.discriminators] - return fmaps - - -if __name__ == "__main__": - disc = Discriminator() - x = torch.zeros(1, 1, 44100) - results = disc(x) - for i, result in enumerate(results): - print(f"disc{i}") - for i, r in enumerate(result): - print(r.shape, r.mean(), r.min(), r.max()) - print() diff --git a/dac-codec/dac/nn/__init__.py b/dac-codec/dac/nn/__init__.py deleted file mode 100644 index 6718c8b1a3d36c31655b030f4c515a144cde4db7..0000000000000000000000000000000000000000 --- a/dac-codec/dac/nn/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from . import layers -from . import loss -from . import quantize diff --git a/dac-codec/dac/nn/__pycache__/__init__.cpython-310.pyc b/dac-codec/dac/nn/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index df87b35cc2e709e7274cfd24c066d4fea0d3e8bc..0000000000000000000000000000000000000000 Binary files a/dac-codec/dac/nn/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/dac-codec/dac/nn/__pycache__/layers.cpython-310.pyc b/dac-codec/dac/nn/__pycache__/layers.cpython-310.pyc deleted file mode 100644 index 9ad43e2df2828248d4652da8914311ab6e5de6de..0000000000000000000000000000000000000000 Binary files a/dac-codec/dac/nn/__pycache__/layers.cpython-310.pyc and /dev/null differ diff --git a/dac-codec/dac/nn/__pycache__/loss.cpython-310.pyc b/dac-codec/dac/nn/__pycache__/loss.cpython-310.pyc deleted file mode 100644 index 2cf4e98d4940dee7694b175d7d6bfd68e5d9aed4..0000000000000000000000000000000000000000 Binary files a/dac-codec/dac/nn/__pycache__/loss.cpython-310.pyc and /dev/null differ diff --git a/dac-codec/dac/nn/__pycache__/quantize.cpython-310.pyc b/dac-codec/dac/nn/__pycache__/quantize.cpython-310.pyc deleted file mode 100644 index 60331ff77031774b6f13eea0e04e7244689eb593..0000000000000000000000000000000000000000 Binary files a/dac-codec/dac/nn/__pycache__/quantize.cpython-310.pyc and /dev/null differ diff --git a/dac-codec/dac/nn/layers.py b/dac-codec/dac/nn/layers.py deleted file mode 100644 index 44fbc2929715e11d843b24195d7042a528969a94..0000000000000000000000000000000000000000 --- a/dac-codec/dac/nn/layers.py +++ /dev/null @@ -1,33 +0,0 @@ -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange -from torch.nn.utils import weight_norm - - -def WNConv1d(*args, **kwargs): - return weight_norm(nn.Conv1d(*args, **kwargs)) - - -def WNConvTranspose1d(*args, **kwargs): - return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) - - -# Scripting this brings model speed up 1.4x -@torch.jit.script -def snake(x, alpha): - shape = x.shape - x = x.reshape(shape[0], shape[1], -1) - x = x + (alpha + 
1e-9).reciprocal() * torch.sin(alpha * x).pow(2) - x = x.reshape(shape) - return x - - -class Snake1d(nn.Module): - def __init__(self, channels): - super().__init__() - self.alpha = nn.Parameter(torch.ones(1, channels, 1)) - - def forward(self, x): - return snake(x, self.alpha) diff --git a/dac-codec/dac/nn/loss.py b/dac-codec/dac/nn/loss.py deleted file mode 100644 index 9bb3dd6ce08a7a24f18f941eeb5b68fe9461e86b..0000000000000000000000000000000000000000 --- a/dac-codec/dac/nn/loss.py +++ /dev/null @@ -1,368 +0,0 @@ -import typing -from typing import List - -import torch -import torch.nn.functional as F -from audiotools import AudioSignal -from audiotools import STFTParams -from torch import nn - - -class L1Loss(nn.L1Loss): - """L1 Loss between AudioSignals. Defaults - to comparing ``audio_data``, but any - attribute of an AudioSignal can be used. - - Parameters - ---------- - attribute : str, optional - Attribute of signal to compare, defaults to ``audio_data``. - weight : float, optional - Weight of this loss, defaults to 1.0. - - Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py - """ - - def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs): - self.attribute = attribute - self.weight = weight - super().__init__(**kwargs) - - def forward(self, x: AudioSignal, y: AudioSignal): - """ - Parameters - ---------- - x : AudioSignal - Estimate AudioSignal - y : AudioSignal - Reference AudioSignal - - Returns - ------- - torch.Tensor - L1 loss between AudioSignal attributes. - """ - if isinstance(x, AudioSignal): - x = getattr(x, self.attribute) - y = getattr(y, self.attribute) - return super().forward(x, y) - - -class SISDRLoss(nn.Module): - """ - Computes the Scale-Invariant Source-to-Distortion Ratio between a batch - of estimated and reference audio signals or aligned features. - - Parameters - ---------- - scaling : int, optional - Whether to use scale-invariant (True) or - signal-to-noise ratio (False), by default True - reduction : str, optional - How to reduce across the batch (either 'mean', - 'sum', or None), by default 'mean' - zero_mean : int, optional - Zero mean the references and estimates before - computing the loss, by default True - clip_min : int, optional - The minimum possible loss value. Helps network - to not focus on making already good examples better, by default None - weight : float, optional - Weight of this loss, defaults to 1.0.
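A minimal sketch of the snake activation deleted above, without the `weight_norm`/`torch.jit` machinery: snake(x) = x + (1/α)·sin²(αx), with one learnable α per channel:

```python
import torch

def snake(x: torch.Tensor, alpha: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
    # alpha broadcasts over (B, C, T); eps guards against alpha == 0
    return x + (alpha + eps).reciprocal() * torch.sin(alpha * x).pow(2)

x = torch.linspace(-3, 3, 7).view(1, 1, -1)
alpha = torch.ones(1, 1, 1)   # Snake1d learns one alpha per channel
print(snake(x, alpha))        # for alpha = 1: x + sin(x)**2
```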
- - Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py - """ - - def __init__( - self, - scaling: int = True, - reduction: str = "mean", - zero_mean: int = True, - clip_min: int = None, - weight: float = 1.0, - ): - self.scaling = scaling - self.reduction = reduction - self.zero_mean = zero_mean - self.clip_min = clip_min - self.weight = weight - super().__init__() - - def forward(self, x: AudioSignal, y: AudioSignal): - eps = 1e-8 - # nb, nc, nt - if isinstance(x, AudioSignal): - references = x.audio_data - estimates = y.audio_data - else: - references = x - estimates = y - - nb = references.shape[0] - references = references.reshape(nb, 1, -1).permute(0, 2, 1) - estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1) - - # samples now on axis 1 - if self.zero_mean: - mean_reference = references.mean(dim=1, keepdim=True) - mean_estimate = estimates.mean(dim=1, keepdim=True) - else: - mean_reference = 0 - mean_estimate = 0 - - _references = references - mean_reference - _estimates = estimates - mean_estimate - - references_projection = (_references**2).sum(dim=-2) + eps - references_on_estimates = (_estimates * _references).sum(dim=-2) + eps - - scale = ( - (references_on_estimates / references_projection).unsqueeze(1) - if self.scaling - else 1 - ) - - e_true = scale * _references - e_res = _estimates - e_true - - signal = (e_true**2).sum(dim=1) - noise = (e_res**2).sum(dim=1) - sdr = -10 * torch.log10(signal / noise + eps) - - if self.clip_min is not None: - sdr = torch.clamp(sdr, min=self.clip_min) - - if self.reduction == "mean": - sdr = sdr.mean() - elif self.reduction == "sum": - sdr = sdr.sum() - return sdr - - -class MultiScaleSTFTLoss(nn.Module): - """Computes the multi-scale STFT loss from [1]. - - Parameters - ---------- - window_lengths : List[int], optional - Length of each window of each STFT, by default [2048, 512] - loss_fn : typing.Callable, optional - How to compare each loss, by default nn.L1Loss() - clamp_eps : float, optional - Clamp on the log magnitude, below, by default 1e-5 - mag_weight : float, optional - Weight of raw magnitude portion of loss, by default 1.0 - log_weight : float, optional - Weight of log magnitude portion of loss, by default 1.0 - pow : float, optional - Power to raise magnitude to before taking log, by default 2.0 - weight : float, optional - Weight of this loss, by default 1.0 - match_stride : bool, optional - Whether to match the stride of convolutional layers, by default False - - References - ---------- - - 1. Engel, Jesse, Chenjie Gu, and Adam Roberts. - "DDSP: Differentiable Digital Signal Processing." - International Conference on Learning Representations. 2019. 
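A worked restatement of the SI-SDR computation above on plain `(B, T)` tensors, assuming the default scale-invariant, zero-mean settings:

```python
import torch

def si_sdr_loss(est: torch.Tensor, ref: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    est = est - est.mean(dim=-1, keepdim=True)            # zero-mean both signals
    ref = ref - ref.mean(dim=-1, keepdim=True)
    scale = (est * ref).sum(-1, keepdim=True) / ((ref**2).sum(-1, keepdim=True) + eps)
    e_true = scale * ref                                  # optimally scaled reference
    e_res = est - e_true                                  # residual distortion
    signal = (e_true**2).sum(-1)
    noise = (e_res**2).sum(-1)
    return -10 * torch.log10(signal / noise + eps)        # negated: lower is better

ref = torch.randn(2, 16000)
print(si_sdr_loss(ref + 0.1 * torch.randn_like(ref), ref))  # roughly -20 per item
```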
- - Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py - """ - - def __init__( - self, - window_lengths: List[int] = [2048, 512], - loss_fn: typing.Callable = nn.L1Loss(), - clamp_eps: float = 1e-5, - mag_weight: float = 1.0, - log_weight: float = 1.0, - pow: float = 2.0, - weight: float = 1.0, - match_stride: bool = False, - window_type: str = None, - ): - super().__init__() - self.stft_params = [ - STFTParams( - window_length=w, - hop_length=w // 4, - match_stride=match_stride, - window_type=window_type, - ) - for w in window_lengths - ] - self.loss_fn = loss_fn - self.log_weight = log_weight - self.mag_weight = mag_weight - self.clamp_eps = clamp_eps - self.weight = weight - self.pow = pow - - def forward(self, x: AudioSignal, y: AudioSignal): - """Computes multi-scale STFT between an estimate and a reference - signal. - - Parameters - ---------- - x : AudioSignal - Estimate signal - y : AudioSignal - Reference signal - - Returns - ------- - torch.Tensor - Multi-scale STFT loss. - """ - loss = 0.0 - for s in self.stft_params: - x.stft(s.window_length, s.hop_length, s.window_type) - y.stft(s.window_length, s.hop_length, s.window_type) - loss += self.log_weight * self.loss_fn( - x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(), - y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(), - ) - loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude) - return loss - - -class MelSpectrogramLoss(nn.Module): - """Compute distance between mel spectrograms. Can be used - in a multi-scale way. - - Parameters - ---------- - n_mels : List[int] - Number of mels per STFT, by default [150, 80], - window_lengths : List[int], optional - Length of each window of each STFT, by default [2048, 512] - loss_fn : typing.Callable, optional - How to compare each loss, by default nn.L1Loss() - clamp_eps : float, optional - Clamp on the log magnitude, below, by default 1e-5 - mag_weight : float, optional - Weight of raw magnitude portion of loss, by default 1.0 - log_weight : float, optional - Weight of log magnitude portion of loss, by default 1.0 - pow : float, optional - Power to raise magnitude to before taking log, by default 2.0 - weight : float, optional - Weight of this loss, by default 1.0 - match_stride : bool, optional - Whether to match the stride of convolutional layers, by default False - - Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py - """ - - def __init__( - self, - n_mels: List[int] = [150, 80], - window_lengths: List[int] = [2048, 512], - loss_fn: typing.Callable = nn.L1Loss(), - clamp_eps: float = 1e-5, - mag_weight: float = 1.0, - log_weight: float = 1.0, - pow: float = 2.0, - weight: float = 1.0, - match_stride: bool = False, - mel_fmin: List[float] = [0.0, 0.0], - mel_fmax: List[float] = [None, None], - window_type: str = None, - ): - super().__init__() - self.stft_params = [ - STFTParams( - window_length=w, - hop_length=w // 4, - match_stride=match_stride, - window_type=window_type, - ) - for w in window_lengths - ] - self.n_mels = n_mels - self.loss_fn = loss_fn - self.clamp_eps = clamp_eps - self.log_weight = log_weight - self.mag_weight = mag_weight - self.weight = weight - self.mel_fmin = mel_fmin - self.mel_fmax = mel_fmax - self.pow = pow - - def forward(self, x: AudioSignal, y: AudioSignal): - """Computes mel loss between an estimate and a reference - signal. 
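A compact sketch of the multi-scale STFT loss above using `torch.stft` directly instead of `audiotools`, mirroring the defaults (windows `[2048, 512]`, hop `window // 4`, L1 on both log-magnitude and raw magnitude; `power` stands in for the class's `pow` parameter):

```python
import torch
import torch.nn.functional as F

def multiscale_stft_loss(x, y, window_lengths=(2048, 512), clamp_eps=1e-5, power=2.0):
    loss = 0.0
    for w in window_lengths:
        window = torch.hann_window(w)
        X = torch.stft(x, w, hop_length=w // 4, window=window, return_complex=True).abs()
        Y = torch.stft(y, w, hop_length=w // 4, window=window, return_complex=True).abs()
        # L1 on log-magnitudes plus L1 on raw magnitudes, as in the class above
        loss = loss + F.l1_loss(X.clamp(min=clamp_eps).pow(power).log10(),
                                Y.clamp(min=clamp_eps).pow(power).log10())
        loss = loss + F.l1_loss(X, Y)
    return loss

x, y = torch.randn(1, 44100), torch.randn(1, 44100)
print(multiscale_stft_loss(x, y))
```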
- - Parameters - ---------- - x : AudioSignal - Estimate signal - y : AudioSignal - Reference signal - - Returns - ------- - torch.Tensor - Mel loss. - """ - loss = 0.0 - for n_mels, fmin, fmax, s in zip( - self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params - ): - kwargs = { - "window_length": s.window_length, - "hop_length": s.hop_length, - "window_type": s.window_type, - } - x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs) - y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs) - - loss += self.log_weight * self.loss_fn( - x_mels.clamp(self.clamp_eps).pow(self.pow).log10(), - y_mels.clamp(self.clamp_eps).pow(self.pow).log10(), - ) - loss += self.mag_weight * self.loss_fn(x_mels, y_mels) - return loss - - -class GANLoss(nn.Module): - """ - Computes a discriminator loss, given a discriminator on - generated waveforms/spectrograms compared to ground truth - waveforms/spectrograms. Computes the loss for both the - discriminator and the generator in separate functions. - """ - - def __init__(self, discriminator): - super().__init__() - self.discriminator = discriminator - - def forward(self, fake, real): - d_fake = self.discriminator(fake.audio_data) - d_real = self.discriminator(real.audio_data) - return d_fake, d_real - - def discriminator_loss(self, fake, real): - d_fake, d_real = self.forward(fake.clone().detach(), real) - - loss_d = 0 - for x_fake, x_real in zip(d_fake, d_real): - loss_d += torch.mean(x_fake[-1] ** 2) - loss_d += torch.mean((1 - x_real[-1]) ** 2) - return loss_d - - def generator_loss(self, fake, real): - d_fake, d_real = self.forward(fake, real) - - loss_g = 0 - for x_fake in d_fake: - loss_g += torch.mean((1 - x_fake[-1]) ** 2) - - loss_feature = 0 - - for i in range(len(d_fake)): - for j in range(len(d_fake[i]) - 1): - loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach()) - return loss_g, loss_feature diff --git a/dac-codec/dac/nn/quantize.py b/dac-codec/dac/nn/quantize.py deleted file mode 100644 index b17ff4a868e489c97fb977a12c112bbbaa5183f6..0000000000000000000000000000000000000000 --- a/dac-codec/dac/nn/quantize.py +++ /dev/null @@ -1,262 +0,0 @@ -from typing import Union - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange -from torch.nn.utils import weight_norm - -from dac.nn.layers import WNConv1d - - -class VectorQuantize(nn.Module): - """ - Implementation of VQ similar to Karpathy's repo: - https://github.com/karpathy/deep-vector-quantization - Additionally uses following tricks from Improved VQGAN - (https://arxiv.org/pdf/2110.04627.pdf): - 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space - for improved codebook usage - 2. 
L2-normalized codes: Converts Euclidean distance to cosine similarity which - improves training stability - """ - - def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int): - super().__init__() - self.codebook_size = codebook_size - self.codebook_dim = codebook_dim - - self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1) - self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1) - self.codebook = nn.Embedding(codebook_size, codebook_dim) - - def forward(self, z): - """Quantizes the input tensor using a fixed codebook and returns - the corresponding codebook vectors - - Parameters - ---------- - z : Tensor[B x D x T] - - Returns - ------- - Tensor[B x D x T] - Quantized continuous representation of input - Tensor[1] - Commitment loss to train encoder to predict vectors closer to codebook - entries - Tensor[1] - Codebook loss to update the codebook - Tensor[B x T] - Codebook indices (quantized discrete representation of input) - Tensor[B x D x T] - Projected latents (continuous representation of input before quantization) - """ - - # Factorized codes (ViT-VQGAN): project input into low-dimensional space - z_e = self.in_proj(z) # z_e : (B x D x T) - z_q, indices = self.decode_latents(z_e) - - commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) - codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2]) - - z_q = ( - z_e + (z_q - z_e).detach() - ) # noop in forward pass, straight-through gradient estimator in backward pass - - z_q = self.out_proj(z_q) - - return z_q, commitment_loss, codebook_loss, indices, z_e - - def embed_code(self, embed_id): - return F.embedding(embed_id, self.codebook.weight) - - def decode_code(self, embed_id): - return self.embed_code(embed_id).transpose(1, 2) - - def decode_latents(self, latents): - encodings = rearrange(latents, "b d t -> (b t) d") - codebook = self.codebook.weight # codebook: (N x D) - - # L2 normalize encodings and codebook (ViT-VQGAN) - encodings = F.normalize(encodings) - codebook = F.normalize(codebook) - - # Compute Euclidean distance with codebook - dist = ( - encodings.pow(2).sum(1, keepdim=True) - - 2 * encodings @ codebook.t() - + codebook.pow(2).sum(1, keepdim=True).t() - ) - indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0)) - z_q = self.decode_code(indices) - return z_q, indices - - -class ResidualVectorQuantize(nn.Module): - """ - Introduced in SoundStream: An End-to-End Neural Audio Codec - https://arxiv.org/abs/2107.03312 - """ - - def __init__( - self, - input_dim: int = 512, - n_codebooks: int = 9, - codebook_size: int = 1024, - codebook_dim: Union[int, list] = 8, - quantizer_dropout: float = 0.0, - ): - super().__init__() - if isinstance(codebook_dim, int): - codebook_dim = [codebook_dim for _ in range(n_codebooks)] - - self.n_codebooks = n_codebooks - self.codebook_dim = codebook_dim - self.codebook_size = codebook_size - - self.quantizers = nn.ModuleList( - [ - VectorQuantize(input_dim, codebook_size, codebook_dim[i]) - for i in range(n_codebooks) - ] - ) - self.quantizer_dropout = quantizer_dropout - - def forward(self, z, n_quantizers: int = None): - """Quantizes the input tensor using a fixed set of `n` codebooks and returns - the corresponding codebook vectors - Parameters - ---------- - z : Tensor[B x D x T] - n_quantizers : int, optional - No.
of quantizers to use - (n_quantizers < self.n_codebooks ex: for quantizer dropout) - Note: if `self.quantizer_dropout` is True, this argument is ignored - when in training mode, and a random number of quantizers is used. - Returns - ------- - dict - A dictionary with the following keys: - - "z" : Tensor[B x D x T] - Quantized continuous representation of input - "codes" : Tensor[B x N x T] - Codebook indices for each codebook - (quantized discrete representation of input) - "latents" : Tensor[B x N*D x T] - Projected latents (continuous representation of input before quantization) - "vq/commitment_loss" : Tensor[1] - Commitment loss to train encoder to predict vectors closer to codebook - entries - "vq/codebook_loss" : Tensor[1] - Codebook loss to update the codebook - """ - z_q = 0 - residual = z - commitment_loss = 0 - codebook_loss = 0 - - codebook_indices = [] - latents = [] - - if n_quantizers is None: - n_quantizers = self.n_codebooks - if self.training: - n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1 - dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],)) - n_dropout = int(z.shape[0] * self.quantizer_dropout) - n_quantizers[:n_dropout] = dropout[:n_dropout] - n_quantizers = n_quantizers.to(z.device) - - for i, quantizer in enumerate(self.quantizers): - if self.training is False and i >= n_quantizers: - break - - z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer( - residual - ) - - # Create mask to apply quantizer dropout - mask = ( - torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers - ) - z_q = z_q + z_q_i * mask[:, None, None] - residual = residual - z_q_i - - # Sum losses - commitment_loss += (commitment_loss_i * mask).mean() - codebook_loss += (codebook_loss_i * mask).mean() - - codebook_indices.append(indices_i) - latents.append(z_e_i) - - codes = torch.stack(codebook_indices, dim=1) - latents = torch.cat(latents, dim=1) - - return z_q, codes, latents, commitment_loss, codebook_loss - - def from_codes(self, codes: torch.Tensor): - """Given the quantized codes, reconstruct the continuous representation - Parameters - ---------- - codes : Tensor[B x N x T] - Quantized discrete representation of input - Returns - ------- - Tensor[B x D x T] - Quantized continuous representation of input - """ - z_q = 0.0 - z_p = [] - n_codebooks = codes.shape[1] - for i in range(n_codebooks): - z_p_i = self.quantizers[i].decode_code(codes[:, i, :]) - z_p.append(z_p_i) - - z_q_i = self.quantizers[i].out_proj(z_p_i) - z_q = z_q + z_q_i - return z_q, torch.cat(z_p, dim=1), codes - - def from_latents(self, latents: torch.Tensor): - """Given the unquantized latents, reconstruct the - continuous representation after quantization. 
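The residual loop above is the core of RVQ: each codebook quantizes what the previous stages left behind, and the per-stage code vectors sum back into the approximation. A stripped-down sketch with fixed random codebooks and plain nearest-neighbour lookup (no projections, losses, or quantizer dropout):

```python
import torch

def rvq_encode(z, codebooks):
    """z: (B, D, T); codebooks: list of (K, D) tensors; returns (z_q, codes)."""
    residual = z
    z_q = torch.zeros_like(z)
    codes = []
    for cb in codebooks:
        flat = residual.permute(0, 2, 1)                  # (B, T, D)
        dist = (flat.unsqueeze(2) - cb).pow(2).sum(-1)    # (B, T, K) squared distances
        idx = dist.argmin(-1)                             # nearest code per frame
        q = cb[idx].permute(0, 2, 1)                      # (B, D, T) code vectors
        z_q = z_q + q                                     # stages sum to the approximation
        residual = residual - q                           # next stage sees what is left
        codes.append(idx)
    return z_q, torch.stack(codes, dim=1)                 # codes: (B, N, T)

codebooks = [torch.randn(1024, 8) for _ in range(4)]
z = torch.randn(2, 8, 20)
z_q, codes = rvq_encode(z, codebooks)
print(codes.shape, (z - z_q).pow(2).mean())  # error typically drops per added stage
```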
- - Parameters - ---------- - latents : Tensor[B x N x T] - Continuous representation of input after projection - - Returns - ------- - Tensor[B x D x T] - Quantized representation of full-projected space - Tensor[B x D x T] - Quantized representation of latent space - """ - z_q = 0 - z_p = [] - codes = [] - dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers]) - - n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[ - 0 - ] - for i in range(n_codebooks): - j, k = dims[i], dims[i + 1] - z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :]) - z_p.append(z_p_i) - codes.append(codes_i) - - z_q_i = self.quantizers[i].out_proj(z_p_i) - z_q = z_q + z_q_i - - return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1) - - -if __name__ == "__main__": - rvq = ResidualVectorQuantize(quantizer_dropout=True) - x = torch.randn(16, 512, 80) - y = rvq(x) - print(y["latents"].shape) diff --git a/dac-codec/dac/utils/__init__.py b/dac-codec/dac/utils/__init__.py deleted file mode 100644 index 9e107bce41b8247c49eea68066fda53267693704..0000000000000000000000000000000000000000 --- a/dac-codec/dac/utils/__init__.py +++ /dev/null @@ -1,123 +0,0 @@ -from pathlib import Path - -import argbind -from audiotools import ml - -import dac - -DAC = dac.model.DAC -Accelerator = ml.Accelerator - -__MODEL_LATEST_TAGS__ = { - ("44khz", "8kbps"): "0.0.1", - ("24khz", "8kbps"): "0.0.4", - ("16khz", "8kbps"): "0.0.5", - ("44khz", "16kbps"): "1.0.0", -} - -__MODEL_URLS__ = { - ( - "44khz", - "0.0.1", - "8kbps", - ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.1/weights.pth", - ( - "24khz", - "0.0.4", - "8kbps", - ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.4/weights_24khz.pth", - ( - "16khz", - "0.0.5", - "8kbps", - ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.5/weights_16khz.pth", - ( - "44khz", - "1.0.0", - "16kbps", - ): "https://github.com/descriptinc/descript-audio-codec/releases/download/1.0.0/weights_44khz_16kbps.pth", -} - - -@argbind.bind(group="download", positional=True, without_prefix=True) -def download( - model_type: str = "44khz", model_bitrate: str = "8kbps", tag: str = "latest" -): - """ - Function that downloads the weights file from URL if a local cache is not found. - - Parameters - ---------- - model_type : str - The type of model to download. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". - model_bitrate: str - Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps". - Only 44khz model supports 16kbps. - tag : str - The tag of the model to download. Defaults to "latest". - - Returns - ------- - Path - Directory path required to load model via audiotools. 
- """ - model_type = model_type.lower() - tag = tag.lower() - - assert model_type in [ - "44khz", - "24khz", - "16khz", - ], "model_type must be one of '44khz', '24khz', or '16khz'" - - assert model_bitrate in [ - "8kbps", - "16kbps", - ], "model_bitrate must be one of '8kbps', or '16kbps'" - - if tag == "latest": - tag = __MODEL_LATEST_TAGS__[(model_type, model_bitrate)] - - download_link = __MODEL_URLS__.get((model_type, tag, model_bitrate), None) - - if download_link is None: - raise ValueError( - f"Could not find model with tag {tag} and model type {model_type}" - ) - - local_path = ( - Path.home() - / ".cache" - / "descript" - / "dac" - / f"weights_{model_type}_{model_bitrate}_{tag}.pth" - ) - if not local_path.exists(): - local_path.parent.mkdir(parents=True, exist_ok=True) - - # Download the model - import requests - - response = requests.get(download_link) - - if response.status_code != 200: - raise ValueError( - f"Could not download model. Received response code {response.status_code}" - ) - local_path.write_bytes(response.content) - - return local_path - - -def load_model( - model_type: str = "44khz", - model_bitrate: str = "8kbps", - tag: str = "latest", - load_path: str = None, -): - if not load_path: - load_path = download( - model_type=model_type, model_bitrate=model_bitrate, tag=tag - ) - generator = DAC.load(load_path) - return generator diff --git a/dac-codec/dac/utils/__pycache__/__init__.cpython-310.pyc b/dac-codec/dac/utils/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 633e291000456c29e1ff62fe19bea0b4d1d6df9a..0000000000000000000000000000000000000000 Binary files a/dac-codec/dac/utils/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/dac-codec/dac/utils/__pycache__/decode.cpython-310.pyc b/dac-codec/dac/utils/__pycache__/decode.cpython-310.pyc deleted file mode 100644 index 2d9922a2f7c54a451cc7758b9b4a3bd287dce21c..0000000000000000000000000000000000000000 Binary files a/dac-codec/dac/utils/__pycache__/decode.cpython-310.pyc and /dev/null differ diff --git a/dac-codec/dac/utils/__pycache__/encode.cpython-310.pyc b/dac-codec/dac/utils/__pycache__/encode.cpython-310.pyc deleted file mode 100644 index c3f5f5611ce664c4ee4b6fcd784c39d895ca1b38..0000000000000000000000000000000000000000 Binary files a/dac-codec/dac/utils/__pycache__/encode.cpython-310.pyc and /dev/null differ diff --git a/dac-codec/dac/utils/decode.py b/dac-codec/dac/utils/decode.py deleted file mode 100644 index 08d44e8453ec4fa3433c2a9952d1a4da15315939..0000000000000000000000000000000000000000 --- a/dac-codec/dac/utils/decode.py +++ /dev/null @@ -1,95 +0,0 @@ -import warnings -from pathlib import Path - -import argbind -import numpy as np -import torch -from audiotools import AudioSignal -from tqdm import tqdm - -from dac import DACFile -from dac.utils import load_model - -warnings.filterwarnings("ignore", category=UserWarning) - - -@argbind.bind(group="decode", positional=True, without_prefix=True) -@torch.inference_mode() -@torch.no_grad() -def decode( - input: str, - output: str = "", - weights_path: str = "", - model_tag: str = "latest", - model_bitrate: str = "8kbps", - device: str = "cuda", - model_type: str = "44khz", - verbose: bool = False, -): - """Decode audio from codes. - - Parameters - ---------- - input : str - Path to input directory or file - output : str, optional - Path to output directory, by default "". - If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`. 
- weights_path : str, optional - Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the - model_tag and model_type. - model_tag : str, optional - Tag of the model to use, by default "latest". Ignored if `weights_path` is specified. - model_bitrate: str - Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps". - device : str, optional - Device to use, by default "cuda". If "cpu", the model will be loaded on the CPU. - model_type : str, optional - The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified. - """ - generator = load_model( - model_type=model_type, - model_bitrate=model_bitrate, - tag=model_tag, - load_path=weights_path, - ) - generator.to(device) - generator.eval() - - # Find all .dac files in input directory - _input = Path(input) - input_files = list(_input.glob("**/*.dac")) - - # If input is a .dac file, add it to the list - if _input.suffix == ".dac": - input_files.append(_input) - - # Create output directory - output = Path(output) - output.mkdir(parents=True, exist_ok=True) - - for i in tqdm(range(len(input_files)), desc=f"Decoding files"): - # Load file - artifact = DACFile.load(input_files[i]) - - # Reconstruct audio from codes - recons = generator.decompress(artifact, verbose=verbose) - - # Compute output path - relative_path = input_files[i].relative_to(input) - output_dir = output / relative_path.parent - if not relative_path.name: - output_dir = output - relative_path = input_files[i] - output_name = relative_path.with_suffix(".wav").name - output_path = output_dir / output_name - output_path.parent.mkdir(parents=True, exist_ok=True) - - # Write to file - recons.write(output_path) - - -if __name__ == "__main__": - args = argbind.parse_args() - with argbind.scope(args): - decode() diff --git a/dac-codec/dac/utils/encode.py b/dac-codec/dac/utils/encode.py deleted file mode 100644 index aa3f6f44b3c210f485da1b1726b85494ff5e7804..0000000000000000000000000000000000000000 --- a/dac-codec/dac/utils/encode.py +++ /dev/null @@ -1,94 +0,0 @@ -import math -import warnings -from pathlib import Path - -import argbind -import numpy as np -import torch -from audiotools import AudioSignal -from audiotools.core import util -from tqdm import tqdm - -from dac.utils import load_model - -warnings.filterwarnings("ignore", category=UserWarning) - - -@argbind.bind(group="encode", positional=True, without_prefix=True) -@torch.inference_mode() -@torch.no_grad() -def encode( - input: str, - output: str = "", - weights_path: str = "", - model_tag: str = "latest", - model_bitrate: str = "8kbps", - n_quantizers: int = None, - device: str = "cuda", - model_type: str = "44khz", - win_duration: float = 5.0, - verbose: bool = False, -): - """Encode audio files in input path to .dac format. - - Parameters - ---------- - input : str - Path to input audio file or directory - output : str, optional - Path to output directory, by default "". If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`. - weights_path : str, optional - Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the - model_tag and model_type. - model_tag : str, optional - Tag of the model to use, by default "latest". Ignored if `weights_path` is specified. - model_bitrate: str - Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps". 
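The output-path logic in `decode.py` above (and `encode.py` below) mirrors the input directory tree under `output` and swaps the file suffix. A small sketch of that pattern (`mirrored_output_path` is a hypothetical helper; the scripts also create the parent directory before writing):

```python
from pathlib import Path

def mirrored_output_path(in_file: Path, in_root: Path, out_root: Path, suffix: str) -> Path:
    # Re-create in_file's position under out_root, swapping the extension
    rel = in_file.relative_to(in_root)
    return (out_root / rel).with_suffix(suffix)

print(mirrored_output_path(Path("audio/a/b.wav"), Path("audio"), Path("codes"), ".dac"))
# -> codes/a/b.dac
```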
- n_quantizers : int, optional - Number of quantizers to use, by default None. If not specified, all the quantizers will be used and the model will compress at maximum bitrate. - device : str, optional - Device to use, by default "cuda" - model_type : str, optional - The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified. - """ - generator = load_model( - model_type=model_type, - model_bitrate=model_bitrate, - tag=model_tag, - load_path=weights_path, - ) - generator.to(device) - generator.eval() - kwargs = {"n_quantizers": n_quantizers} - - # Find all audio files in input path - input = Path(input) - audio_files = util.find_audio(input) - - output = Path(output) - output.mkdir(parents=True, exist_ok=True) - - for i in tqdm(range(len(audio_files)), desc="Encoding files"): - # Load file - signal = AudioSignal(audio_files[i]) - - # Encode audio to .dac format - artifact = generator.compress(signal, win_duration, verbose=verbose, **kwargs) - - # Compute output path - relative_path = audio_files[i].relative_to(input) - output_dir = output / relative_path.parent - if not relative_path.name: - output_dir = output - relative_path = audio_files[i] - output_name = relative_path.with_suffix(".dac").name - output_path = output_dir / output_name - output_path.parent.mkdir(parents=True, exist_ok=True) - - artifact.save(output_path) - - -if __name__ == "__main__": - args = argbind.parse_args() - with argbind.scope(args): - encode() diff --git a/dac-codec/scripts/compute_entropy.py b/dac-codec/scripts/compute_entropy.py deleted file mode 100644 index a065cfd05794dd461818a06cfc529d4656422a42..0000000000000000000000000000000000000000 --- a/dac-codec/scripts/compute_entropy.py +++ /dev/null @@ -1,50 +0,0 @@ -import argbind -import audiotools as at -import numpy as np -import torch -import tqdm - -import dac - - -@argbind.bind(without_prefix=True, positional=True) -def main( - folder: str, - model_path: str, - n_samples: int = 1024, - device: str = "cuda", -): - files = at.util.find_audio(folder)[:n_samples] - signals = [ - at.AudioSignal.salient_excerpt(f, loudness_cutoff=-20, duration=1.0) - for f in files - ] - - with torch.no_grad(): - model = dac.model.DAC.load(model_path).to(device) - model.eval() - - codes = [] - for x in tqdm.tqdm(signals): - x = x.to(model.device) - o = model.encode(x.audio_data, x.sample_rate) - codes.append(o["codes"].cpu()) - - codes = torch.cat(codes, dim=-1) - entropy = [] - - for i in range(codes.shape[1]): - codes_ = codes[0, i, :] - counts = torch.bincount(codes_) - counts = (counts / counts.sum()).clamp(1e-10) - entropy.append(-(counts * counts.log()).sum().item() * np.log2(np.e)) - - pct = sum(entropy) / (10 * len(entropy)) - print(f"Entropy for each codebook: {entropy}") - print(f"Effective percentage: {pct * 100}%") - - -if __name__ == "__main__": - args = argbind.parse_args() - with argbind.scope(args): - main() diff --git a/dac-codec/scripts/evaluate.py b/dac-codec/scripts/evaluate.py deleted file mode 100644 index 07cdc3c95e6b3e3f2a28127741b2a9c623890e56..0000000000000000000000000000000000000000 --- a/dac-codec/scripts/evaluate.py +++ /dev/null @@ -1,105 +0,0 @@ -import csv -import multiprocessing as mp -from concurrent.futures import ProcessPoolExecutor -from dataclasses import dataclass -from pathlib import Path - -import argbind -import torch -from audiotools import AudioSignal -from audiotools import metrics -from audiotools.core import util -from audiotools.ml.decorators import Tracker 
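The entropy computation in `compute_entropy.py` above reduces to: bincount the code indices per codebook, normalize to probabilities, take −Σp·ln p, and convert nats to bits; the "effective percentage" divides by the 10 bits a 1024-entry codebook can carry at most. A self-contained restatement with stand-in codes in place of `model.encode` output:

```python
import numpy as np
import torch

codes = torch.randint(0, 1024, (1, 9, 50_000))  # stand-in for encoded codes (B, N, T)

entropy = []
for i in range(codes.shape[1]):
    counts = torch.bincount(codes[0, i, :], minlength=1024)
    p = (counts / counts.sum()).clamp(1e-10)                 # empirical code distribution
    entropy.append(-(p * p.log()).sum().item() * np.log2(np.e))  # nats -> bits

pct = sum(entropy) / (10 * len(entropy))                     # vs. log2(1024) = 10 bits
print(f"Entropy per codebook (bits): {[round(e, 2) for e in entropy]}")
print(f"Effective percentage: {pct * 100:.1f}%")
```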
-from train import losses - - -@dataclass -class State: - stft_loss: losses.MultiScaleSTFTLoss - mel_loss: losses.MelSpectrogramLoss - waveform_loss: losses.L1Loss - sisdr_loss: losses.SISDRLoss - - -def get_metrics(signal_path, recons_path, state): - output = {} - signal = AudioSignal(signal_path) - recons = AudioSignal(recons_path) - for sr in [22050, 44100]: - x = signal.clone().resample(sr) - y = recons.clone().resample(sr) - k = "22k" if sr == 22050 else "44k" - output.update( - { - f"mel-{k}": state.mel_loss(x, y), - f"stft-{k}": state.stft_loss(x, y), - f"waveform-{k}": state.waveform_loss(x, y), - f"sisdr-{k}": state.sisdr_loss(x, y), - f"visqol-audio-{k}": metrics.quality.visqol(x, y), - f"visqol-speech-{k}": metrics.quality.visqol(x, y, "speech"), - } - ) - output["path"] = signal.path_to_file - output.update(signal.metadata) - return output - - -@argbind.bind(without_prefix=True) -@torch.no_grad() -def evaluate( - input: str = "samples/input", - output: str = "samples/output", - n_proc: int = 50, -): - tracker = Tracker() - - waveform_loss = losses.L1Loss() - stft_loss = losses.MultiScaleSTFTLoss() - mel_loss = losses.MelSpectrogramLoss() - sisdr_loss = losses.SISDRLoss() - - state = State( - waveform_loss=waveform_loss, - stft_loss=stft_loss, - mel_loss=mel_loss, - sisdr_loss=sisdr_loss, - ) - - audio_files = util.find_audio(input) - output = Path(output) - output.mkdir(parents=True, exist_ok=True) - - @tracker.track("metrics", len(audio_files)) - def record(future, writer): - o = future.result() - for k, v in o.items(): - if torch.is_tensor(v): - o[k] = v.item() - writer.writerow(o) - o.pop("path") - return o - - futures = [] - with tracker.live: - with open(output / "metrics.csv", "w") as csvfile: - with ProcessPoolExecutor(n_proc, mp.get_context("fork")) as pool: - for i in range(len(audio_files)): - future = pool.submit( - get_metrics, audio_files[i], output / audio_files[i].name, state - ) - futures.append(future) - - keys = list(futures[0].result().keys()) - writer = csv.DictWriter(csvfile, fieldnames=keys) - writer.writeheader() - - for future in futures: - record(future, writer) - - tracker.done("test", f"N={len(audio_files)}") - - -if __name__ == "__main__": - args = argbind.parse_args() - with argbind.scope(args): - evaluate() diff --git a/dac-codec/scripts/get_samples.py b/dac-codec/scripts/get_samples.py deleted file mode 100644 index b3adad0eac9dcd79bdfb79429eb8cea72947ef0c..0000000000000000000000000000000000000000 --- a/dac-codec/scripts/get_samples.py +++ /dev/null @@ -1,95 +0,0 @@ -from pathlib import Path - -import argbind -import torch -from audiotools import AudioSignal -from audiotools.core import util -from audiotools.ml.decorators import Tracker -from train import Accelerator -from train import DAC - -from dac.compare.encodec import Encodec - -Encodec = argbind.bind(Encodec) - - -def load_state( - accel: Accelerator, - tracker: Tracker, - save_path: str, - tag: str = "latest", - load_weights: bool = False, - model_type: str = "dac", - bandwidth: float = 24.0, -): - kwargs = { - "folder": f"{save_path}/{tag}", - "map_location": "cpu", - "package": not load_weights, - } - tracker.print(f"Resuming from {str(Path('.').absolute())}/{kwargs['folder']}") - - if model_type == "dac": - generator, _ = DAC.load_from_folder(**kwargs) - elif model_type == "encodec": - generator = Encodec(bandwidth=bandwidth) - - generator = accel.prepare_model(generator) - return generator - - -@torch.no_grad() -def process(signal, accel, generator, **kwargs): - signal = 
signal.to(accel.device) - recons = generator(signal.audio_data, signal.sample_rate, **kwargs)["audio"] - recons = AudioSignal(recons, signal.sample_rate) - recons = recons.normalize(signal.loudness()) - return recons.cpu() - - -@argbind.bind(without_prefix=True) -@torch.no_grad() -def get_samples( - accel, - path: str = "ckpt", - input: str = "samples/input", - output: str = "samples/output", - model_type: str = "dac", - model_tag: str = "latest", - bandwidth: float = 24.0, - n_quantizers: int = None, -): - tracker = Tracker(log_file=f"{path}/eval.txt", rank=accel.local_rank) - generator = load_state( - accel, - tracker, - save_path=path, - model_type=model_type, - bandwidth=bandwidth, - tag=model_tag, - ) - generator.eval() - kwargs = {"n_quantizers": n_quantizers} if model_type == "dac" else {} - - audio_files = util.find_audio(input) - - global process - process = tracker.track("process", len(audio_files))(process) - - output = Path(output) - output.mkdir(parents=True, exist_ok=True) - - with tracker.live: - for i in range(len(audio_files)): - signal = AudioSignal(audio_files[i]) - recons = process(signal, accel, generator, **kwargs) - recons.write(output / audio_files[i].name) - - tracker.done("test", f"N={len(audio_files)}") - - -if __name__ == "__main__": - args = argbind.parse_args() - with argbind.scope(args): - with Accelerator() as accel: - get_samples(accel) diff --git a/dac-codec/scripts/mushra.py b/dac-codec/scripts/mushra.py deleted file mode 100644 index 2bdf0060c1d73f0d8fcc05c8b9078a14af7dc94d..0000000000000000000000000000000000000000 --- a/dac-codec/scripts/mushra.py +++ /dev/null @@ -1,104 +0,0 @@ -import string -from dataclasses import dataclass -from pathlib import Path -from typing import List - -import argbind -import gradio as gr -from audiotools import preference as pr - - -@argbind.bind(without_prefix=True) -@dataclass -class Config: - folder: str = None - save_path: str = "results.csv" - conditions: List[str] = None - reference: str = None - seed: int = 0 - share: bool = False - n_samples: int = 10 - - -def get_text(wav_file: str): - txt_file = Path(wav_file).with_suffix(".txt") - if Path(txt_file).exists(): - with open(txt_file, "r") as f: - txt = f.read() - else: - txt = "" - return f"""
<div>{txt}</div>
""" - - -def main(config: Config): - with gr.Blocks() as app: - save_path = config.save_path - samples = gr.State(pr.Samples(config.folder, n_samples=config.n_samples)) - - reference = config.reference - conditions = config.conditions - - player = pr.Player(app) - player.create() - if reference is not None: - player.add("Play Reference") - - user = pr.create_tracker(app) - ratings = [] - - with gr.Row(): - txt = gr.HTML("") - - with gr.Row(): - gr.Button("Rate audio quality", interactive=False) - with gr.Column(scale=8): - gr.HTML(pr.slider_mushra) - - for i in range(len(conditions)): - with gr.Row().style(equal_height=True): - x = string.ascii_uppercase[i] - player.add(f"Play {x}") - with gr.Column(scale=9): - ratings.append(gr.Slider(value=50, interactive=True)) - - def build(user, samples, *ratings): - # Filter out samples user has done already, by looking in the CSV. - samples.filter_completed(user, save_path) - - # Write results to CSV - if samples.current > 0: - start_idx = 1 if reference is not None else 0 - name = samples.names[samples.current - 1] - result = {"sample": name, "user": user} - for k, r in zip(samples.order[start_idx:], ratings): - result[k] = r - pr.save_result(result, save_path) - - updates, done, pbar = samples.get_next_sample(reference, conditions) - wav_file = updates[0]["value"] - - txt_update = gr.update(value=get_text(wav_file)) - - return ( - updates - + [gr.update(value=50) for _ in ratings] - + [done, samples, pbar, txt_update] - ) - - progress = gr.HTML() - begin = gr.Button("Submit", elem_id="start-survey") - begin.click( - fn=build, - inputs=[user, samples] + ratings, - outputs=player.to_list() + ratings + [begin, samples, progress, txt], - ).then(None, _js=pr.reset_player) - - # Comment this back in to actually launch the script. 
- app.launch(share=config.share) - - -if __name__ == "__main__": - args = argbind.parse_args() - with argbind.scope(args): - config = Config() - main(config) diff --git a/dac-codec/scripts/organize_daps.py b/dac-codec/scripts/organize_daps.py deleted file mode 100644 index 13a43db17fa8136a2e67059ab375d8ce69ec008a..0000000000000000000000000000000000000000 --- a/dac-codec/scripts/organize_daps.py +++ /dev/null @@ -1,97 +0,0 @@ -import os -import pathlib -import shutil -from collections import defaultdict -from typing import Tuple - -import argbind -import numpy as np -import tqdm -from audiotools import util - - -@argbind.bind() -def split( - audio_files, ratio: Tuple[float, float, float] = (0.8, 0.1, 0.1), seed: int = 0 -): - assert sum(ratio) == 1.0 - util.seed(seed) - - idx = np.arange(len(audio_files)) - np.random.shuffle(idx) - - b = np.cumsum([0] + list(ratio)) * len(idx) - b = [int(_b) for _b in b] - train_idx = idx[b[0] : b[1]] - val_idx = idx[b[1] : b[2]] - test_idx = idx[b[2] :] - - audio_files = np.array(audio_files) - train_files = audio_files[train_idx] - val_files = audio_files[val_idx] - test_files = audio_files[test_idx] - - return train_files, val_files, test_files - - -def assign(val_split, test_split): - def _assign(value): - if value in val_split: - return "val" - if value in test_split: - return "test" - return "train" - - return _assign - - -DAPS_VAL = ["f2", "m2"] -DAPS_TEST = ["f10", "m10"] - - -@argbind.bind(without_prefix=True) -def process( - dataset: str = "daps", - daps_subset: str = "", -): - get_split = None - get_value = lambda path: path - - data_path = pathlib.Path("/data") - dataset_path = data_path / dataset - audio_files = util.find_audio(dataset_path) - - if dataset == "daps": - get_split = assign(DAPS_VAL, DAPS_TEST) - get_value = lambda path: (str(path).split("/")[-1].split("_", maxsplit=4)[0]) - audio_files = [ - x - for x in util.find_audio(dataset_path) - if daps_subset in str(x) and "breaths" not in str(x) - ] - - if get_split is None: - _, val, test = split(audio_files) - get_split = assign(val, test) - - splits = defaultdict(list) - for x in audio_files: - _split = get_split(get_value(x)) - splits[_split].append(x) - - with util.chdir(dataset_path): - for k, v in splits.items(): - v = sorted(v) - print(f"Processing {k} in {dataset_path} of length {len(v)}") - for _v in tqdm.tqdm(v): - tgt_path = pathlib.Path( - str(_v).replace(str(dataset_path), str(dataset_path / k)) - ) - tgt_path.parent.mkdir(parents=True, exist_ok=True) - shutil.copyfile(_v, tgt_path) - - -if __name__ == "__main__": - args = argbind.parse_args() - with argbind.scope(args): - process() diff --git a/dac-codec/scripts/save_test_set.py b/dac-codec/scripts/save_test_set.py deleted file mode 100644 index 93fd1b168a492528b60b767a0b235b7d323f735e..0000000000000000000000000000000000000000 --- a/dac-codec/scripts/save_test_set.py +++ /dev/null @@ -1,55 +0,0 @@ -import csv -from pathlib import Path - -import argbind -import torch -from audiotools.core import util -from audiotools.ml.decorators import Tracker -from train import Accelerator - -import scripts.train as train - - -@torch.no_grad() -def process(batch, accel, test_data): - batch = util.prepare_batch(batch, accel.device) - signal = test_data.transform(batch["signal"].clone(), **batch["transform_args"]) - return signal.cpu() - - -@argbind.bind(without_prefix=True) -@torch.no_grad() -def save_test_set(args, accel, sample_rate: int = 44100, output: str = "samples/input"): - tracker = Tracker() - with argbind.scope(args, "test"): - 
test_data = train.build_dataset(sample_rate) - - global process - process = tracker.track("process", len(test_data))(process) - - output = Path(output) - output.mkdir(parents=True, exist_ok=True) - (output.parent / "input").mkdir(parents=True, exist_ok=True) - with open(output / "metadata.csv", "w") as csvfile: - keys = ["path", "original"] - writer = csv.DictWriter(csvfile, fieldnames=keys) - writer.writeheader() - - with tracker.live: - for i in range(len(test_data)): - signal = process(test_data[i], accel, test_data) - input_path = output.parent / "input" / f"sample_{i}.wav" - metadata = { - "path": str(input_path), - "original": str(signal.path_to_input_file), - } - writer.writerow(metadata) - signal.write(input_path) - tracker.done("test", f"N={len(test_data)}") - - -if __name__ == "__main__": - args = argbind.parse_args() - with argbind.scope(args): - with Accelerator() as accel: - save_test_set(args, accel) diff --git a/dac-codec/tests/__init__.py b/dac-codec/tests/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/dac-codec/tests/test_cli.py b/dac-codec/tests/test_cli.py deleted file mode 100644 index 9487b66d416f0906657b5ec363436c79d2338a3a..0000000000000000000000000000000000000000 --- a/dac-codec/tests/test_cli.py +++ /dev/null @@ -1,84 +0,0 @@ -""" -Tests for CLI. -""" -import subprocess -from pathlib import Path - -import argbind -import numpy as np -import pytest -import torch -from audiotools import AudioSignal - -from dac.__main__ import run - - -def setup_module(module): - data_dir = Path(__file__).parent / "assets" - data_dir.mkdir(exist_ok=True, parents=True) - input_dir = data_dir / "input" - input_dir.mkdir(exist_ok=True, parents=True) - - for i in range(5): - signal = AudioSignal(np.random.randn(1000), 44_100) - signal.write(input_dir / f"sample_{i}.wav") - return input_dir - - -def teardown_module(module): - repo_root = Path(__file__).parent.parent - subprocess.check_output(["rm", "-rf", f"{repo_root}/tests/assets"]) - - -@pytest.mark.parametrize("model_type", ["44khz", "24khz", "16khz"]) -def test_reconstruction(model_type): - # Test encoding - input_dir = Path(__file__).parent / "assets" / "input" - output_dir = input_dir.parent / model_type / "encoded_output" - args = { - "input": str(input_dir), - "output": str(output_dir), - "device": "cuda" if torch.cuda.is_available() else "cpu", - "model_type": model_type, - } - with argbind.scope(args): - run("encode") - - # Test decoding - input_dir = output_dir - output_dir = input_dir.parent / model_type / "decoded_output" - args = { - "input": str(input_dir), - "output": str(output_dir), - "model_type": model_type, - } - with argbind.scope(args): - run("decode") - - -def test_compression(): - # Test encoding - input_dir = Path(__file__).parent / "assets" / "input" - output_dir = input_dir.parent / "encoded_output_quantizers" - args = { - "input": str(input_dir), - "output": str(output_dir), - "n_quantizers": 3, - "device": "cuda" if torch.cuda.is_available() else "cpu", - } - with argbind.scope(args): - run("encode") - - # Open .dac file - dac_file = output_dir / "sample_0.dac" - artifacts = np.load(dac_file, allow_pickle=True)[()] - codes = artifacts["codes"] - - # Ensure that the number of quantizers is correct - assert codes.shape[1] == 3 - - # Ensure that dtype of compression is uint16 - assert codes.dtype == np.uint16 - - -# CUDA_VISIBLE_DEVICES=0 python -m pytest tests/test_cli.py -s
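As the compression test above shows, a `.dac` artifact is a NumPy-pickled dictionary; a quick way to inspect one (path hypothetical):

```python
import numpy as np

artifacts = np.load("sample_0.dac", allow_pickle=True)[()]  # hypothetical path
codes = artifacts["codes"]
print(codes.shape)  # (1, n_quantizers, T_frames)
print(codes.dtype)  # uint16: one codebook index per quantizer per frame
```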