Commit ·
8b9204d
0
Parent(s):
Duplicate from m-a-p/MERT-v1-330M
Browse filesCo-authored-by: Yizhi Li <yizhilll@users.noreply.huggingface.co>
- .gitattributes +34 -0
- MERT-v1-330M_fairseq.pt +3 -0
- README.md +124 -0
- config.json +92 -0
- configuration_MERT.py +141 -0
- modeling_MERT.py +409 -0
- preprocessor_config.json +9 -0
- pytorch_model.bin +3 -0
.gitattributes
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
MERT-v1-330M_fairseq.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:13d9b88455f88f8608399aa0e921d23100298d867d04380402be171a01f50a89
|
| 3 |
+
size 3991038973
|
README.md
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: cc-by-nc-4.0
|
| 3 |
+
inference: false
|
| 4 |
+
tags:
|
| 5 |
+
- music
|
| 6 |
+
pipeline_tag: audio-classification
|
| 7 |
+
---
|
| 8 |
+
|
| 9 |
+
# Introduction to our series of work
|
| 10 |
+
|
| 11 |
+
The development log of our Music Audio Pre-training (m-a-p) model family:
|
| 12 |
+
- 02/06/2023: [arxiv pre-print](https://arxiv.org/abs/2306.00107) and training [codes](https://github.com/yizhilll/MERT) released.
|
| 13 |
+
- 17/03/2023: we release two advanced music understanding models, [MERT-v1-95M](https://huggingface.co/m-a-p/MERT-v1-95M) and [MERT-v1-330M](https://huggingface.co/m-a-p/MERT-v1-330M) , trained with new paradigm and dataset. They outperform the previous models and can better generalize to more tasks.
|
| 14 |
+
- 14/03/2023: we retrained the MERT-v0 model with an open-source-only music dataset [MERT-v0-public](https://huggingface.co/m-a-p/MERT-v0-public)
|
| 15 |
+
- 29/12/2022: a music understanding model [MERT-v0](https://huggingface.co/m-a-p/MERT-v0) trained with **MLM** paradigm, which performs better at downstream tasks.
|
| 16 |
+
- 29/10/2022: a pre-trained MIR model [music2vec](https://huggingface.co/m-a-p/music2vec-v1) trained with **BYOL** paradigm.
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
Here is a table for quick model pick-up:
|
| 21 |
+
|
| 22 |
+
| Name | Pre-train Paradigm | Training Data (hour) | Pre-train Context (second) | Model Size | Transformer Layer-Dimension | Feature Rate | Sample Rate | Release Date |
|
| 23 |
+
| ------------------------------------------------------------ | ------------------ | -------------------- | ---------------------------- | ---------- | --------------------------- | ------------ | ----------- | ------------ |
|
| 24 |
+
| [MERT-v1-330M](https://huggingface.co/m-a-p/MERT-v1-330M) | MLM | 160K | 5 | 330M | 24-1024 | 75 Hz | 24K Hz | 17/03/2023 |
|
| 25 |
+
| [MERT-v1-95M](https://huggingface.co/m-a-p/MERT-v1-95M) | MLM | 20K | 5 | 95M | 12-768 | 75 Hz | 24K Hz | 17/03/2023 |
|
| 26 |
+
| [MERT-v0-public](https://huggingface.co/m-a-p/MERT-v0-public) | MLM | 900 | 5 | 95M | 12-768 | 50 Hz | 16K Hz | 14/03/2023 |
|
| 27 |
+
| [MERT-v0](https://huggingface.co/m-a-p/MERT-v0) | MLM | 1000 | 5 | 95 M | 12-768 | 50 Hz | 16K Hz | 29/12/2022 |
|
| 28 |
+
| [music2vec-v1](https://huggingface.co/m-a-p/music2vec-v1) | BYOL | 1000 | 30 | 95 M | 12-768 | 50 Hz | 16K Hz | 30/10/2022 |
|
| 29 |
+
|
| 30 |
+
## Explanation
|
| 31 |
+
|
| 32 |
+
The m-a-p models share a similar model architecture, and the most distinguishing difference is the paradigm used in pre-training. Other than that, there are several technical configuration details you need to know before using them:
|
| 33 |
+
|
| 34 |
+
- **Model Size**: the number of parameters that would be loaded to memory. Please select the appropriate size fitting your hardware.
|
| 35 |
+
- **Transformer Layer-Dimension**: The number of transformer layers and the corresponding feature dimension that our model can output. This is pointed out because features extracted from **different layers can perform differently depending on the task**.
|
| 36 |
+
- **Feature Rate**: Given a 1-second audio input, the number of features output by the model.
|
| 37 |
+
- **Sample Rate**: The frequency of audio that the model is trained with.
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# Introduction to MERT-v1
|
| 42 |
+
|
| 43 |
+
Compared to MERT-v0, we introduce multiple new things in the MERT-v1 pre-training:
|
| 44 |
+
|
| 45 |
+
- Change the pseudo labels to 8 codebooks from [encodec](https://github.com/facebookresearch/encodec), which potentially have higher quality and empower our model to support music generation.
|
| 46 |
+
- MLM prediction with in-batch noise mixture.
|
| 47 |
+
- Train with higher audio frequency (24K Hz).
|
| 48 |
+
- Train with more audio data (up to 160 thousand hours).
|
| 49 |
+
- More available model sizes 95M and 330M.
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
More details can be found in our [paper](https://arxiv.org/abs/2306.00107).
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
# Model Usage
|
| 58 |
+
|
| 59 |
+
```python
|
| 60 |
+
# from transformers import Wav2Vec2Processor
|
| 61 |
+
from transformers import Wav2Vec2FeatureExtractor
|
| 62 |
+
from transformers import AutoModel
|
| 63 |
+
import torch
|
| 64 |
+
from torch import nn
|
| 65 |
+
import torchaudio.transforms as T
|
| 66 |
+
from datasets import load_dataset
|
| 67 |
+
|
| 68 |
+
# loading our model weights
|
| 69 |
+
model = AutoModel.from_pretrained("m-a-p/MERT-v1-330M", trust_remote_code=True)
|
| 70 |
+
# loading the corresponding preprocessor config
|
| 71 |
+
processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v1-330M",trust_remote_code=True)
|
| 72 |
+
|
| 73 |
+
# load demo audio and set processor
|
| 74 |
+
dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
|
| 75 |
+
dataset = dataset.sort("id")
|
| 76 |
+
sampling_rate = dataset.features["audio"].sampling_rate
|
| 77 |
+
|
| 78 |
+
resample_rate = processor.sampling_rate
|
| 79 |
+
# make sure the sample rates are aligned
|
| 80 |
+
if resample_rate != sampling_rate:
|
| 81 |
+
print(f'setting rate from {sampling_rate} to {resample_rate}')
|
| 82 |
+
resampler = T.Resample(sampling_rate, resample_rate)
|
| 83 |
+
else:
|
| 84 |
+
resampler = None
|
| 85 |
+
|
| 86 |
+
# audio file is decoded on the fly
|
| 87 |
+
if resampler is None:
|
| 88 |
+
input_audio = dataset[0]["audio"]["array"]
|
| 89 |
+
else:
|
| 90 |
+
input_audio = resampler(torch.from_numpy(dataset[0]["audio"]["array"]))
|
| 91 |
+
|
| 92 |
+
inputs = processor(input_audio, sampling_rate=resample_rate, return_tensors="pt")
|
| 93 |
+
with torch.no_grad():
|
| 94 |
+
outputs = model(**inputs, output_hidden_states=True)
|
| 95 |
+
|
| 96 |
+
# take a look at the output shape, there are 25 layers of representation
|
| 97 |
+
# each layer performs differently in different downstream tasks, you should choose empirically
|
| 98 |
+
all_layer_hidden_states = torch.stack(outputs.hidden_states).squeeze()
|
| 99 |
+
print(all_layer_hidden_states.shape) # [25 layer, Time steps, 1024 feature_dim]
|
| 100 |
+
|
| 101 |
+
# for utterance level classification tasks, you can simply reduce the representation in time
|
| 102 |
+
time_reduced_hidden_states = all_layer_hidden_states.mean(-2)
|
| 103 |
+
print(time_reduced_hidden_states.shape) # [25, 1024]
|
| 104 |
+
|
| 105 |
+
# you can even use a learnable weighted average representation
|
| 106 |
+
aggregator = nn.Conv1d(in_channels=25, out_channels=1, kernel_size=1)
|
| 107 |
+
weighted_avg_hidden_states = aggregator(time_reduced_hidden_states.unsqueeze(0)).squeeze()
|
| 108 |
+
print(weighted_avg_hidden_states.shape) # [1024]
|
| 109 |
+
```
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
# Citation
|
| 114 |
+
|
| 115 |
+
```bibtex
|
| 116 |
+
@misc{li2023mert,
|
| 117 |
+
title={MERT: Acoustic Music Understanding Model with Large-Scale Self-supervised Training},
|
| 118 |
+
author={Yizhi Li and Ruibin Yuan and Ge Zhang and Yinghao Ma and Xingran Chen and Hanzhi Yin and Chenghua Lin and Anton Ragni and Emmanouil Benetos and Norbert Gyenge and Roger Dannenberg and Ruibo Liu and Wenhu Chen and Gus Xia and Yemin Shi and Wenhao Huang and Yike Guo and Jie Fu},
|
| 119 |
+
year={2023},
|
| 120 |
+
eprint={2306.00107},
|
| 121 |
+
archivePrefix={arXiv},
|
| 122 |
+
primaryClass={cs.SD}
|
| 123 |
+
}
|
| 124 |
+
```
|
config.json
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_name_or_path": "m-a-p/MERT-v1-330M",
|
| 3 |
+
"activation_dropout": 0.0,
|
| 4 |
+
"apply_spec_augment": true,
|
| 5 |
+
"architectures": [
|
| 6 |
+
"MERTModel"
|
| 7 |
+
],
|
| 8 |
+
"attention_dropout": 0.0,
|
| 9 |
+
"attention_relax": 32.0,
|
| 10 |
+
"auto_map": {
|
| 11 |
+
"AutoConfig": "configuration_MERT.MERTConfig",
|
| 12 |
+
"AutoModel": "modeling_MERT.MERTModel"
|
| 13 |
+
},
|
| 14 |
+
"bos_token_id": 1,
|
| 15 |
+
"classifier_proj_size": 256,
|
| 16 |
+
"conv_bias": false,
|
| 17 |
+
"conv_dim": [
|
| 18 |
+
512,
|
| 19 |
+
512,
|
| 20 |
+
512,
|
| 21 |
+
512,
|
| 22 |
+
512,
|
| 23 |
+
512,
|
| 24 |
+
512
|
| 25 |
+
],
|
| 26 |
+
"conv_kernel": [
|
| 27 |
+
10,
|
| 28 |
+
3,
|
| 29 |
+
3,
|
| 30 |
+
3,
|
| 31 |
+
3,
|
| 32 |
+
2,
|
| 33 |
+
2
|
| 34 |
+
],
|
| 35 |
+
"conv_stride": [
|
| 36 |
+
5,
|
| 37 |
+
2,
|
| 38 |
+
2,
|
| 39 |
+
2,
|
| 40 |
+
2,
|
| 41 |
+
2,
|
| 42 |
+
2
|
| 43 |
+
],
|
| 44 |
+
"ctc_loss_reduction": "sum",
|
| 45 |
+
"ctc_zero_infinity": false,
|
| 46 |
+
"deepnorm": false,
|
| 47 |
+
"do_stable_layer_norm": true,
|
| 48 |
+
"eos_token_id": 2,
|
| 49 |
+
"feat_extract_activation": "gelu",
|
| 50 |
+
"feat_extract_dropout": 0.0,
|
| 51 |
+
"feat_extract_norm": "group",
|
| 52 |
+
"feat_proj_dropout": 0.0,
|
| 53 |
+
"feat_proj_layer_norm": true,
|
| 54 |
+
"feature_extractor_cqt": false,
|
| 55 |
+
"feature_extractor_cqt_bins": 336,
|
| 56 |
+
"final_dropout": 0.0,
|
| 57 |
+
"gradient_checkpointing": false,
|
| 58 |
+
"hidden_act": "gelu",
|
| 59 |
+
"hidden_dropout": 0.0,
|
| 60 |
+
"hidden_size": 1024,
|
| 61 |
+
"initializer_range": 0.02,
|
| 62 |
+
"intermediate_size": 4096,
|
| 63 |
+
"layer_norm_eps": 1e-05,
|
| 64 |
+
"layerdrop": 0.0,
|
| 65 |
+
"mask_channel_length": 10,
|
| 66 |
+
"mask_channel_min_space": 1,
|
| 67 |
+
"mask_channel_other": 0.0,
|
| 68 |
+
"mask_channel_prob": 0.0,
|
| 69 |
+
"mask_channel_selection": "static",
|
| 70 |
+
"mask_feature_length": 10,
|
| 71 |
+
"mask_feature_min_masks": 0,
|
| 72 |
+
"mask_feature_prob": 0.0,
|
| 73 |
+
"mask_time_length": 10,
|
| 74 |
+
"mask_time_min_masks": 2,
|
| 75 |
+
"mask_time_min_space": 1,
|
| 76 |
+
"mask_time_other": 0.0,
|
| 77 |
+
"mask_time_prob": 0.075,
|
| 78 |
+
"mask_time_selection": "static",
|
| 79 |
+
"model_type": "mert_model",
|
| 80 |
+
"num_attention_heads": 16,
|
| 81 |
+
"num_conv_pos_embedding_groups": 16,
|
| 82 |
+
"num_conv_pos_embeddings": 128,
|
| 83 |
+
"num_feat_extract_layers": 7,
|
| 84 |
+
"num_hidden_layers": 24,
|
| 85 |
+
"pad_token_id": 0,
|
| 86 |
+
"sample_rate": 24000,
|
| 87 |
+
"tokenizer_class": "Wav2Vec2CTCTokenizer",
|
| 88 |
+
"torch_dtype": "float32",
|
| 89 |
+
"transformers_version": "4.27.1",
|
| 90 |
+
"use_weighted_layer_sum": false,
|
| 91 |
+
"vocab_size": 32
|
| 92 |
+
}
|
configuration_MERT.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MERT model configuration
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import functools
|
| 6 |
+
import operator
|
| 7 |
+
|
| 8 |
+
# from ...configuration_utils import PretrainedConfig
|
| 9 |
+
# from ...utils import logging
|
| 10 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 11 |
+
from transformers.utils import logging
|
| 12 |
+
|
| 13 |
+
logger = logging.get_logger(__name__)
|
| 14 |
+
|
| 15 |
+
# TODO: use this MAP while uploading to Huggingface
|
| 16 |
+
# HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
| 17 |
+
# "facebook/hubert-base-ls960": "https://huggingface.co/facebook/hubert-base-ls960/resolve/main/config.json",
|
| 18 |
+
# # See all Hubert models at https://huggingface.co/models?filter=hubert
|
| 19 |
+
# }
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class MERTConfig(PretrainedConfig):
    r"""
    Configuration object storing the hyper-parameters of a MERT model.

    Every keyword argument is stored on the instance under the same name.
    `pad_token_id`, `bos_token_id`, `eos_token_id` and any extra `**kwargs`
    are forwarded to `PretrainedConfig.__init__`.

    The arguments fall into these groups:

    - transformer encoder: `vocab_size`, `hidden_size`, `num_hidden_layers`,
      `num_attention_heads`, `intermediate_size`, `hidden_act`, the dropout
      rates, `layerdrop`, `initializer_range`, `layer_norm_eps`,
      `do_stable_layer_norm`;
    - convolutional feature extractor: `feat_extract_norm`,
      `feat_extract_activation`, `conv_dim`, `conv_stride`, `conv_kernel`,
      `conv_bias` (the three `conv_*` sequences must have the same length);
    - feature projection and positional convolution: `feat_proj_layer_norm`,
      `feat_proj_dropout`, `num_conv_pos_embeddings`,
      `num_conv_pos_embedding_groups`;
    - SpecAugment masking: `apply_spec_augment`, `mask_time_*`,
      `mask_feature_*`;
    - CTC loss and classification heads: `ctc_loss_reduction`,
      `ctc_zero_infinity`, `use_weighted_layer_sum`, `classifier_proj_size`;
    - MERT-specific switches: `feature_extractor_cqt`,
      `feature_extractor_cqt_bins`, `deepnorm`, `attention_relax`.

    Raises:
        ValueError: if `conv_dim`, `conv_stride` and `conv_kernel` do not all
            have the same number of entries.
    """
    model_type = "mert_model"

    def __init__(
        self,
        vocab_size=32,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout=0.1,
        activation_dropout=0.1,
        attention_dropout=0.1,
        feat_proj_layer_norm=True,
        feat_proj_dropout=0.0,
        final_dropout=0.1,
        layerdrop=0.1,
        initializer_range=0.02,
        layer_norm_eps=1e-5,
        feat_extract_norm="group",
        feat_extract_activation="gelu",
        conv_dim=(512, 512, 512, 512, 512, 512, 512),
        conv_stride=(5, 2, 2, 2, 2, 2, 2),
        conv_kernel=(10, 3, 3, 3, 3, 2, 2),
        conv_bias=False,
        num_conv_pos_embeddings=128,
        num_conv_pos_embedding_groups=16,
        do_stable_layer_norm=False,
        apply_spec_augment=True,
        mask_time_prob=0.05,
        mask_time_length=10,
        mask_time_min_masks=2,
        mask_feature_prob=0.0,
        mask_feature_length=10,
        mask_feature_min_masks=0,
        ctc_loss_reduction="sum",
        ctc_zero_infinity=False,
        use_weighted_layer_sum=False,
        classifier_proj_size=256,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        feature_extractor_cqt=False,
        feature_extractor_cqt_bins=336,
        deepnorm=False,
        attention_relax=-1.0,
        **kwargs
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )

        # convolutional feature extractor (sequences are stored as lists)
        self.feat_extract_norm = feat_extract_norm
        self.feat_extract_activation = feat_extract_activation
        self.conv_dim = list(conv_dim)
        self.conv_stride = list(conv_stride)
        self.conv_kernel = list(conv_kernel)
        self.conv_bias = conv_bias
        self.num_feat_extract_layers = len(self.conv_dim)

        # positional convolution embedding
        self.num_conv_pos_embeddings = num_conv_pos_embeddings
        self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups

        # transformer encoder
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.layer_norm_eps = layer_norm_eps
        self.initializer_range = initializer_range
        self.do_stable_layer_norm = do_stable_layer_norm

        # dropout / regularisation rates
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.feat_proj_dropout = feat_proj_dropout
        self.final_dropout = final_dropout
        self.layerdrop = layerdrop

        # feature projection and downstream heads
        self.feat_proj_layer_norm = feat_proj_layer_norm
        self.use_weighted_layer_sum = use_weighted_layer_sum
        self.classifier_proj_size = classifier_proj_size

        # the three per-layer conv sequences must describe the same number of layers
        n_dim = len(self.conv_dim)
        n_stride = len(self.conv_stride)
        n_kernel = len(self.conv_kernel)
        if not (n_stride == n_kernel == self.num_feat_extract_layers):
            raise ValueError(
                "Configuration for convolutional layers is incorrect. It is required that `len(config.conv_dim)` =="
                " `len(config.conv_stride)` == `len(config.conv_kernel)`, but is `len(config.conv_dim) ="
                f" {n_dim}`, `len(config.conv_stride) = {n_stride}`,"
                f" `len(config.conv_kernel) = {n_kernel}`."
            )

        # fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
        self.apply_spec_augment = apply_spec_augment
        self.mask_time_prob = mask_time_prob
        self.mask_time_length = mask_time_length
        self.mask_time_min_masks = mask_time_min_masks
        self.mask_feature_prob = mask_feature_prob
        self.mask_feature_length = mask_feature_length
        self.mask_feature_min_masks = mask_feature_min_masks

        # ctc loss
        self.ctc_loss_reduction = ctc_loss_reduction
        self.ctc_zero_infinity = ctc_zero_infinity

        # cqt feature extractor
        self.feature_extractor_cqt = feature_extractor_cqt
        self.feature_extractor_cqt_bins = feature_extractor_cqt_bins

        # deepnorm: up-scale weighted residual connection + down-scale initial value transformer encoder
        self.deepnorm = deepnorm

        self.attention_relax = attention_relax

        # fix bug with hf > 4.42
        self.conv_pos_batch_norm = False

    @property
    def inputs_to_logits_ratio(self):
        """Return the product of all entries of `conv_stride`."""
        ratio = 1
        for stride in self.conv_stride:
            ratio = ratio * stride
        return ratio
|
modeling_MERT.py
ADDED
|
@@ -0,0 +1,409 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MERT model definition.
|
| 3 |
+
We largely adapt codes from:
|
| 4 |
+
1. https://github.com/huggingface/transformers/blob/main/src/transformers/models/hubert/modeling_hubert.py
|
| 5 |
+
2. https://github.com/facebookresearch/fairseq/blob/main/fairseq/models/wav2vec/wav2vec2.py
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from typing import Optional, Tuple, Union
|
| 9 |
+
from transformers.modeling_outputs import BaseModelOutput
|
| 10 |
+
import torch
|
| 11 |
+
from torch import nn
|
| 12 |
+
|
| 13 |
+
from transformers.models.hubert.modeling_hubert import (
|
| 14 |
+
HubertFeatureEncoder,
|
| 15 |
+
HubertModel,
|
| 16 |
+
HubertEncoderStableLayerNorm,
|
| 17 |
+
HubertEncoder,
|
| 18 |
+
HubertEncoderLayer,
|
| 19 |
+
HubertPositionalConvEmbedding,
|
| 20 |
+
HubertAttention,
|
| 21 |
+
HubertFeedForward,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
# nnAudio is an optional dependency: it is only needed when
# `config.feature_extractor_cqt` is enabled. Record whether it could be
# imported so the model can check the flag at construction time.
try:
    from nnAudio import features as nnAudioFeatures
    NNAUDIO_INSTALLED = True
except Exception:
    # Best-effort import: any failure while importing nnAudio only disables the
    # CQT feature extractor. `Exception` (rather than a bare `except`) keeps
    # KeyboardInterrupt / SystemExit from being swallowed here.
    print("WARNING: feature_extractor_cqt requires the libray 'nnAudio'")
    NNAUDIO_INSTALLED = False
|
| 30 |
+
|
| 31 |
+
from .configuration_MERT import MERTConfig
|
| 32 |
+
|
| 33 |
+
class MERTFeatureProjection(nn.Module):
    """Map extracted features to the transformer hidden size.

    The input width is `config.conv_dim[-1]`, extended by
    `config.feature_extractor_cqt_bins` when `config.feature_extractor_cqt`
    is set. An optional LayerNorm (enabled by `config.feat_proj_layer_norm`)
    is applied before a linear projection to `config.hidden_size`, followed
    by dropout with rate `config.feat_proj_dropout`.
    """

    def __init__(self, config):
        super().__init__()
        self.feat_proj_layer_norm = config.feat_proj_layer_norm
        self.feature_extractor_cqt = config.feature_extractor_cqt

        input_dim = config.conv_dim[-1]
        if self.feature_extractor_cqt:
            # v3 concat features
            input_dim = input_dim + config.feature_extractor_cqt_bins
            print(f"feature dimention: {input_dim}")
        self.feature_dimension = input_dim

        # the LayerNorm sub-module only exists when it is enabled in the config
        if self.feat_proj_layer_norm:
            self.layer_norm = nn.LayerNorm(input_dim, eps=config.layer_norm_eps)
        self.projection = nn.Linear(input_dim, config.hidden_size)
        self.dropout = nn.Dropout(config.feat_proj_dropout)

    def forward(self, hidden_states):
        """Return `dropout(projection(x))`, normalising `x` first if enabled."""
        normed = self.layer_norm(hidden_states) if self.feat_proj_layer_norm else hidden_states
        return self.dropout(self.projection(normed))
|
| 57 |
+
|
| 58 |
+
class MERTModel(HubertModel):
    """HuBERT-style encoder with MERT's feature projection and optional CQT features.

    The model runs the convolutional feature encoder on the raw input, can
    concatenate CQT features computed from the same input (when
    `config.feature_extractor_cqt` is set), projects the result with
    `MERTFeatureProjection`, and feeds it to one of three encoder variants
    selected by `config.do_stable_layer_norm` / `config.deepnorm`.
    """

    # overwrite config class
    config_class = MERTConfig
    base_model_prefix = "mert_model"

    def __init__(
        self,
        config: MERTConfig,
    ) -> None:
        """
        Build the sub-modules of the model.

        `HubertModel.__init__` is deliberately bypassed: initialization starts
        from the next class in the MRO (the grandparent, HubertPreTrainedModel
        per the original note) and the HubertModel construction is then
        re-implemented here with the MERT-specific changes.
        """
        super(HubertModel, self).__init__(config)

        self.config = config

        self.feature_extractor = HubertFeatureEncoder(config)
        # replace Feature Projection for introducing new features
        self.feature_projection = MERTFeatureProjection(config)

        if self.config.feature_extractor_cqt:
            assert NNAUDIO_INSTALLED, "ERROR: feature_extractor_cqt requires the libray 'nnAudio', try after `pip install nnAudio` "
            print('initializing cqt extractor for MERT')
            sample_rate = self.config.sample_rate
            cqt_bins = self.config.feature_extractor_cqt_bins
            self.feature_extractor_cqt = nnAudioFeatures.cqt.CQT(
                sr=sample_rate,
                hop_length=sample_rate // 50,
                fmin=32.7,
                fmax=None,
                n_bins=cqt_bins,
                bins_per_octave=cqt_bins // 7,
                filter_scale=1,
                norm=1,
                window='hann',
                center=True,
                pad_mode='constant',
                trainable=False,
                output_format='Magnitude',
                verbose=True,
            )

        # the mask embedding is only created when some masking is configured
        if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
            self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_())

        if config.do_stable_layer_norm:
            assert not config.deepnorm, "must use post-layer_norm with deepnorm"
            self.encoder = HubertEncoderStableLayerNorm(config)
        elif config.deepnorm:
            self.encoder = HubertEncoder_extend(config)
        else:
            self.encoder = HubertEncoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(self, input_values: Optional[torch.Tensor], attention_mask: Optional[torch.Tensor] = None, mask_time_indices: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None) -> Union[Tuple, BaseModelOutput]:
        """
        Encode `input_values` and return the encoder outputs.

        `output_attentions`, `output_hidden_states` and `return_dict` fall back
        to the corresponding config values when left as `None`. The result is
        a `BaseModelOutput` when `return_dict` is truthy, otherwise a tuple
        whose first element is the last hidden state.
        """
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        extract_features = self.feature_extractor(input_values).transpose(1, 2)

        # add additional cqt features for transformer input
        if self.config.feature_extractor_cqt:
            features_cqt = self.feature_extractor_cqt(input_values).transpose(1, 2)
            # align shape: keep only as many CQT steps as there are conv feature steps
            num_steps = extract_features.shape[1]
            features_cqt = features_cqt[:, :num_steps, :]
            # v3: concatenate along the last dimension
            extract_features = torch.cat([extract_features, features_cqt], 2)

        if attention_mask is not None:
            # compute reduced attention_mask corresponding to feature vectors
            attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)

        projected = self.feature_projection(extract_features)
        projected = self._mask_hidden_states(projected, mask_time_indices=mask_time_indices)

        encoder_outputs = self.encoder(
            projected,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # take last_hidden from encoder output
        last_hidden = encoder_outputs[0]

        if return_dict:
            return BaseModelOutput(
                last_hidden_state=last_hidden,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )

        return (last_hidden,) + encoder_outputs[1:]
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class HubertEncoder_extend(HubertEncoder):
    """HubertEncoder variant built from `HubertEncoderLayerExtend` layers.

    When `config.deepnorm` is set, the feed-forward, `out_proj` and `v_proj`
    parameters are divided by `(8 * num_hidden_layers) ** 0.25` after
    construction.
    """

    def __init__(self, config):
        # Only the nn.Module initialisation is run; the sub-modules are
        # created below instead of by HubertEncoder.__init__.
        nn.Module.__init__(self)

        self.config = config
        self.pos_conv_embed = HubertPositionalConvEmbedding(config)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout)

        self.layers = nn.ModuleList([HubertEncoderLayerExtend(config) for _ in range(config.num_hidden_layers)])

        self.gradient_checkpointing = False

        if config.deepnorm:
            import math

            init_scale = math.pow(8.0 * config.num_hidden_layers, 0.25)
            # parameter-name fragments selecting the weights that get down-scaled
            scaled_fragments = (
                "feed_forward.intermediate_dense",
                "feed_forward.output_dense",
                "out_proj",
                "v_proj",
            )
            for param_name, param in self.named_parameters():
                if any(fragment in param_name for fragment in scaled_fragments):
                    param.data.div_(init_scale)
|
| 180 |
+
|
| 181 |
+
class HubertEncoderLayerExtend(HubertEncoderLayer):
    """Post-layer-norm encoder layer with a scalable residual branch.

    The attention module is ``HubertAttention_extend`` when
    ``config.attention_relax > 0`` and the stock ``HubertAttention``
    otherwise. With ``config.deepnorm`` enabled, the residual input is
    multiplied by ``(2 * num_hidden_layers) ** 0.25`` before being added to
    the sub-layer output; otherwise the multiplier is ``1.0``.
    """

    def __init__(self, config):
        # Only nn.Module.__init__ is run; HubertEncoderLayer.__init__ is
        # intentionally skipped and all submodules are built here.
        nn.Module.__init__(self)
        import math

        # Arguments shared by both attention implementations.
        attn_kwargs = dict(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=False,
        )
        if config.attention_relax > 0:
            self.attention = HubertAttention_extend(
                attention_relax=config.attention_relax, **attn_kwargs
            )
        else:
            self.attention = HubertAttention(**attn_kwargs)

        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.feed_forward = HubertFeedForward(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        self.residual_alpha = (
            math.pow(2.0 * config.num_hidden_layers, 0.25) if config.deepnorm else 1.0
        )

    def residual_connection(self, x, residual):
        """Return ``residual * residual_alpha + x``.

        residual: the input that was fed to the sub-layer f()
        x: the sub-layer output f(residual)
        """
        scaled_skip = residual * self.residual_alpha
        return scaled_skip + x

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        """Run attention and feed-forward, each followed by residual add + LayerNorm.

        Returns a tuple ``(hidden_states,)``, extended with the attention
        weights when ``output_attentions`` is truthy.
        """
        layer_input = hidden_states
        attn_out, attn_weights, _ = self.attention(
            layer_input, attention_mask=attention_mask, output_attentions=output_attentions
        )
        attn_out = self.dropout(attn_out)

        # Attention sub-layer: (scaled) residual add, then normalize.
        post_attn = self.layer_norm(self.residual_connection(attn_out, layer_input))

        # Feed-forward sub-layer: same residual + norm pattern.
        ffn_out = self.feed_forward(post_attn)
        post_ffn = self.final_layer_norm(self.residual_connection(ffn_out, post_attn))

        if output_attentions:
            return (post_ffn, attn_weights)
        return (post_ffn,)
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
class HubertAttention_extend(nn.Module):
    """Multi-headed scaled dot-product attention with an optional "relaxed" softmax.

    When ``attention_relax > 0`` the attention logits are divided by
    ``attention_relax``, shifted by their row-wise maximum and multiplied back
    before the softmax. Algebraically this equals subtracting the row-wise
    maximum of the original logits, so the softmax output is mathematically
    unchanged (presumably the rescaling is meant to keep intermediate values
    small in reduced precision -- TODO confirm). With ``attention_relax <= 0``
    a plain softmax is applied.

    Args:
        embed_dim: Model (channel) dimension; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention probabilities.
        is_decoder: If true, ``forward`` returns the key/value states as
            ``past_key_value`` so that they can be reused.
        bias: Whether the q/k/v/out linear projections carry a bias.
        attention_relax: Relaxation factor; values ``<= 0`` disable it.

    Raises:
        ValueError: If ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        attention_relax: float = -1.0,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        # Always store the factor: forward() reads it unconditionally, so
        # leaving it unset for non-positive values would raise AttributeError.
        self.attention_relax = attention_relax

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape ``(bsz, seq_len, embed_dim)`` to ``(bsz, num_heads, seq_len, head_dim)``."""
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel

        Args:
            hidden_states: Query input of shape ``(bsz, tgt_len, embed_dim)``.
            key_value_states: If given, keys/values are projected from it
                (cross-attention) instead of from ``hidden_states``.
            past_key_value: Cached ``(key_states, value_states)`` to reuse or
                to extend along the sequence dimension.
            attention_mask: Additive mask of shape ``(bsz, 1, tgt_len, src_len)``.
            layer_head_mask: Per-head multiplier of shape ``(num_heads,)``.
            output_attentions: Whether to also return the attention weights.

        Returns:
            ``(attn_output, attn_weights, past_key_value)`` where
            ``attn_output`` is ``(bsz, tgt_len, embed_dim)``, ``attn_weights``
            is ``(bsz, num_heads, tgt_len, src_len)`` or ``None`` when
            ``output_attentions`` is false, and ``past_key_value`` is only
            populated by this call when ``is_decoder`` is true.

        Raises:
            ValueError: If an intermediate tensor, the attention mask or the
                head mask does not have the expected size.
        """

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scaling
        # get key, value proj
        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
        # is checking that the `sequence_length` of the `past_key_value` is the same as
        # the provided `key_value_states` to support prefix tuning
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        # Fold the head dimension into the batch dimension for torch.bmm.
        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if self.attention_relax > 0:
            # => (bsz * self.num_heads, tgt_len, src_len)
            attn_weights_relax = attn_weights / self.attention_relax

            # => (bsz * self.num_heads, tgt_len, 1)
            # torch.max with `dim` returns a (values, indices) named tuple, so
            # the values must be selected explicitly; keepdim=True keeps the
            # trailing axis needed for broadcasting.
            attn_max_relax = torch.max(attn_weights_relax, dim=-1, keepdim=True).values
            attn_weights = (attn_weights_relax - attn_max_relax) * self.attention_relax

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            # Report the shape that is actually compared above.
            raise ValueError(
                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped, past_key_value
|
preprocessor_config.json
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"do_normalize": false,
|
| 3 |
+
"feature_extractor_type": "Wav2Vec2FeatureExtractor",
|
| 4 |
+
"feature_size": 1,
|
| 5 |
+
"padding_side": "right",
|
| 6 |
+
"padding_value": 0,
|
| 7 |
+
"return_attention_mask": true,
|
| 8 |
+
"sampling_rate": 24000
|
| 9 |
+
}
|
pytorch_model.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0c03774d821a6ad972cb48220b128a10ebea2c790116bd0fd35e5a013f8017f6
|
| 3 |
+
size 1261846489
|