Spaces:
Build error
Build error
Commit ·
b1175d1
1
Parent(s): bb739d5
Initial commit
Browse files- .gitattributes +8 -0
- .gitignore +10 -0
- .gradio/certificate.pem +31 -0
- README.md +7 -7
- app.py +773 -0
- higgs_audio/__init__.py +1 -0
- higgs_audio/audio_processing/LICENSE +51 -0
- higgs_audio/audio_processing/descriptaudiocodec/__init__.py +0 -0
- higgs_audio/audio_processing/descriptaudiocodec/dac/model/base.py +286 -0
- higgs_audio/audio_processing/descriptaudiocodec/dac/model/dac.py +365 -0
- higgs_audio/audio_processing/descriptaudiocodec/dac/nn/layers.py +33 -0
- higgs_audio/audio_processing/descriptaudiocodec/dac/nn/quantize.py +251 -0
- higgs_audio/audio_processing/higgs_audio_tokenizer.py +341 -0
- higgs_audio/audio_processing/quantization/__init__.py +8 -0
- higgs_audio/audio_processing/quantization/ac.py +301 -0
- higgs_audio/audio_processing/quantization/core_vq.py +360 -0
- higgs_audio/audio_processing/quantization/core_vq_lsx_version.py +431 -0
- higgs_audio/audio_processing/quantization/ddp_utils.py +197 -0
- higgs_audio/audio_processing/quantization/distrib.py +123 -0
- higgs_audio/audio_processing/quantization/vq.py +116 -0
- higgs_audio/audio_processing/semantic_module.py +310 -0
- higgs_audio/constants.py +3 -0
- higgs_audio/data_collator/__init__.py +0 -0
- higgs_audio/data_collator/higgs_audio_collator.py +583 -0
- higgs_audio/data_types.py +38 -0
- higgs_audio/dataset/__init__.py +0 -0
- higgs_audio/dataset/chatml_dataset.py +554 -0
- higgs_audio/model/__init__.py +9 -0
- higgs_audio/model/audio_head.py +139 -0
- higgs_audio/model/common.py +27 -0
- higgs_audio/model/configuration_higgs_audio.py +235 -0
- higgs_audio/model/cuda_graph_runner.py +129 -0
- higgs_audio/model/custom_modules.py +155 -0
- higgs_audio/model/modeling_higgs_audio.py +0 -0
- higgs_audio/model/utils.py +778 -0
- higgs_audio/serve/serve_engine.py +474 -0
- higgs_audio/serve/utils.py +254 -0
- pyproject.toml +100 -0
- requirements.txt +17 -0
- theme.json +285 -0
- voice_examples/config.json +30 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,11 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
voice_examples/en_woman.wav filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
voice_examples/mabel.wav filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
voice_examples/vex.wav filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
voice_examples/zh_man_sichuan.wav filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
voice_examples/belinda.wav filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
voice_examples/broom_salesman.wav filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
voice_examples/chadwick.wav filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
voice_examples/en_man.wav filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.pyc
|
| 3 |
+
*.pyo
|
| 4 |
+
*.pyd
|
| 5 |
+
*.pyw
|
| 6 |
+
*.pyz
|
| 7 |
+
*.pywz
|
| 8 |
+
*.pyzw
|
| 9 |
+
*.pyzwz
|
| 10 |
+
.ruff_cache/
|
.gradio/certificate.pem
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
-----BEGIN CERTIFICATE-----
|
| 2 |
+
MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
|
| 3 |
+
TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
|
| 4 |
+
cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
|
| 5 |
+
WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
|
| 6 |
+
ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
|
| 7 |
+
MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
|
| 8 |
+
h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
|
| 9 |
+
0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
|
| 10 |
+
A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
|
| 11 |
+
T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
|
| 12 |
+
B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
|
| 13 |
+
B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
|
| 14 |
+
KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
|
| 15 |
+
OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
|
| 16 |
+
jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
|
| 17 |
+
qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
|
| 18 |
+
rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
|
| 19 |
+
HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
|
| 20 |
+
hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
|
| 21 |
+
ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
|
| 22 |
+
3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
|
| 23 |
+
NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
|
| 24 |
+
ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
|
| 25 |
+
TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
|
| 26 |
+
jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
|
| 27 |
+
oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
|
| 28 |
+
4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
|
| 29 |
+
mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
|
| 30 |
+
emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
|
| 31 |
+
-----END CERTIFICATE-----
|
README.md
CHANGED
|
@@ -1,13 +1,13 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version:
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
-
|
| 11 |
---
|
| 12 |
|
| 13 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Banana Voice
|
| 3 |
+
emoji: 🎤
|
| 4 |
+
colorFrom: red
|
| 5 |
+
colorTo: yellow
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 5.36.2
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
+
short_description: Banana Voice
|
| 11 |
---
|
| 12 |
|
| 13 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
|
@@ -0,0 +1,773 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Xentrik Audio Text-to-Speech - TTS for Chatters
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
import base64
|
| 7 |
+
import os
|
| 8 |
+
import uuid
|
| 9 |
+
import json
|
| 10 |
+
from typing import Optional
|
| 11 |
+
import gradio as gr
|
| 12 |
+
from loguru import logger
|
| 13 |
+
import numpy as np
|
| 14 |
+
import time
|
| 15 |
+
from functools import lru_cache
|
| 16 |
+
import re
|
| 17 |
+
import spaces
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
# Import HiggsAudio components
|
| 21 |
+
from higgs_audio.serve.serve_engine import HiggsAudioServeEngine
|
| 22 |
+
from higgs_audio.data_types import ChatMLSample, AudioContent, Message
|
| 23 |
+
|
| 24 |
+
# Global engine instance
|
| 25 |
+
engine = None
|
| 26 |
+
|
| 27 |
+
# Default model configuration
|
| 28 |
+
DEFAULT_MODEL_PATH = "bosonai/higgs-audio-v2-generation-3B-base"
|
| 29 |
+
DEFAULT_AUDIO_TOKENIZER_PATH = "bosonai/higgs-audio-v2-tokenizer"
|
| 30 |
+
SAMPLE_RATE = 24000
|
| 31 |
+
|
| 32 |
+
DEFAULT_SYSTEM_PROMPT = (
|
| 33 |
+
"Generate audio following instruction.\n\n"
|
| 34 |
+
"<|scene_desc_start|>\n"
|
| 35 |
+
"Audio is recorded from a Quiet Bedroom.\n"
|
| 36 |
+
"<|scene_desc_end|>"
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
DEFAULT_STOP_STRINGS = ["<|end_of_text|>", "<|eot_id|>"]
|
| 40 |
+
|
| 41 |
+
# Predefined examples for system and input messages - OnlyFans themed
|
| 42 |
+
PREDEFINED_EXAMPLES = {
|
| 43 |
+
"voice-clone": {
|
| 44 |
+
"system_prompt": "",
|
| 45 |
+
"input_text": "Hey there! I'm your hottest, sexiest and sweetest Voice Cloner. See below for Custom Reference, Click on it and upload a voice file. - let's clone some vocals and bring your voice to life! ",
|
| 46 |
+
"description": "Voice clone to clone the reference audio. Leave the system prompt empty.",
|
| 47 |
+
},
|
| 48 |
+
"smart-voice": {
|
| 49 |
+
"system_prompt": DEFAULT_SYSTEM_PROMPT,
|
| 50 |
+
"input_text": "The OnlyFans models are becoming richer than movie stars, I am doing nothing but OFM for sure. this is about to go super crazy!",
|
| 51 |
+
"description": "Smart voice to generate speech based on the context",
|
| 52 |
+
},
|
| 53 |
+
"multispeaker-voice-description": {
|
| 54 |
+
"system_prompt": "You are an AI assistant designed to convert text into speech.\n"
|
| 55 |
+
"If the user's message includes a [SPEAKER*] tag, do not read out the tag and generate speech for the following text, using the specified voice.\n"
|
| 56 |
+
"If no speaker tag is present, select a suitable voice on your own.\n\n"
|
| 57 |
+
"<|scene_desc_start|>\n"
|
| 58 |
+
"SPEAKER0: feminine\n"
|
| 59 |
+
"SPEAKER1: masculine\n"
|
| 60 |
+
"<|scene_desc_end|>",
|
| 61 |
+
"input_text": "[SPEAKER0] I can't believe you did that without even asking me first!\n"
|
| 62 |
+
"[SPEAKER1] Oh, come on! It wasn't a big deal, and I knew you would overreact like this.\n"
|
| 63 |
+
"[SPEAKER0] Overreact? You made a decision that affects both of us without even considering my opinion!\n"
|
| 64 |
+
"[SPEAKER1] Because I didn't have time to sit around waiting for you to make up your mind! Someone had to act.",
|
| 65 |
+
"description": "Multispeaker with different voice descriptions in the system prompt",
|
| 66 |
+
},
|
| 67 |
+
"single-speaker-voice-description": {
|
| 68 |
+
"system_prompt": "Generate audio following instruction.\n\n"
|
| 69 |
+
"<|scene_desc_start|>\n"
|
| 70 |
+
"SPEAKER0: He speaks with a clear British accent and a conversational, inquisitive tone. His delivery is articulate and at a moderate pace, and very clear audio.\n"
|
| 71 |
+
"<|scene_desc_end|>",
|
| 72 |
+
"input_text": "Hey, everyone! Welcome back to Tech Talk Tuesdays.\n"
|
| 73 |
+
"It's your host, Alex, and today, we're diving into a topic that's become absolutely crucial in the tech world — deep learning.\n"
|
| 74 |
+
"And let's be honest, if you've been even remotely connected to tech, AI, or machine learning lately, you know that deep learning is everywhere.\n"
|
| 75 |
+
"\n"
|
| 76 |
+
"So here's the big question: Do you want to understand how deep learning works?\n",
|
| 77 |
+
"description": "Single speaker with voice description in the system prompt",
|
| 78 |
+
},
|
| 79 |
+
"single-speaker-bgm": {
|
| 80 |
+
"system_prompt": DEFAULT_SYSTEM_PROMPT,
|
| 81 |
+
"input_text": "[music start] I will remember this, thought Ender, when I am defeated. To keep dignity, and give honor where it's due, so that defeat is not disgrace. And I hope I don't have to do it often. [music end]",
|
| 82 |
+
"description": "Single speaker with BGM using music tag. This is an experimental feature and you may need to try multiple times to get the best result.",
|
| 83 |
+
},
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
@lru_cache(maxsize=20)
|
| 88 |
+
def encode_audio_file(file_path):
|
| 89 |
+
"""Encode an audio file to base64."""
|
| 90 |
+
with open(file_path, "rb") as audio_file:
|
| 91 |
+
return base64.b64encode(audio_file.read()).decode("utf-8")
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def get_current_device():
|
| 95 |
+
"""Get the current device."""
|
| 96 |
+
return "cuda" if torch.cuda.is_available() else "cpu"
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def load_voice_presets():
|
| 100 |
+
"""Load the voice presets from both config.json and voice files directory."""
|
| 101 |
+
voice_presets = {}
|
| 102 |
+
voice_dir = os.path.join(os.path.dirname(__file__), "voice_examples")
|
| 103 |
+
config_path = os.path.join(voice_dir, "config.json")
|
| 104 |
+
|
| 105 |
+
# First try to load from config.json if it exists
|
| 106 |
+
if os.path.exists(config_path):
|
| 107 |
+
try:
|
| 108 |
+
with open(config_path, "r", encoding='utf-8') as f:
|
| 109 |
+
voice_dict = json.load(f)
|
| 110 |
+
|
| 111 |
+
for voice_id, voice_info in voice_dict.items():
|
| 112 |
+
if isinstance(voice_info, dict) and "transcript" in voice_info:
|
| 113 |
+
voice_presets[voice_id] = voice_info["transcript"]
|
| 114 |
+
else:
|
| 115 |
+
voice_presets[voice_id] = f"Voice sample {voice_id}"
|
| 116 |
+
|
| 117 |
+
logger.info(f"Loaded {len(voice_presets)} voice presets from config.json")
|
| 118 |
+
|
| 119 |
+
except json.JSONDecodeError as e:
|
| 120 |
+
logger.warning(f"Config.json has JSON syntax error: {e}. Will scan folder instead.")
|
| 121 |
+
except Exception as e:
|
| 122 |
+
logger.warning(f"Error loading config.json: {e}. Will scan folder instead.")
|
| 123 |
+
|
| 124 |
+
# Then scan the voice_examples folder for any .wav files
|
| 125 |
+
if os.path.exists(voice_dir):
|
| 126 |
+
wav_files = [f for f in os.listdir(voice_dir) if f.endswith('.wav')]
|
| 127 |
+
|
| 128 |
+
for wav_file in wav_files:
|
| 129 |
+
voice_id = os.path.splitext(wav_file)[0] # Remove .wav extension
|
| 130 |
+
|
| 131 |
+
# Only add if not already in config or if we need to override
|
| 132 |
+
if voice_id not in voice_presets:
|
| 133 |
+
# Create a friendly name from the filename
|
| 134 |
+
friendly_name = voice_id.replace('_', ' ').title()
|
| 135 |
+
voice_presets[voice_id] = f"{friendly_name} Voice"
|
| 136 |
+
logger.info(f"Added voice preset from file: {voice_id}")
|
| 137 |
+
|
| 138 |
+
# Always include EMPTY option
|
| 139 |
+
voice_presets["EMPTY"] = "No reference voice"
|
| 140 |
+
|
| 141 |
+
logger.info(f"Total voice presets available: {list(voice_presets.keys())}")
|
| 142 |
+
return voice_presets
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def get_voice_preset(voice_preset):
|
| 146 |
+
"""Get the voice path and text for a given voice preset."""
|
| 147 |
+
if voice_preset == "EMPTY":
|
| 148 |
+
return None, ""
|
| 149 |
+
|
| 150 |
+
voice_path = os.path.join(os.path.dirname(__file__), "voice_examples", f"{voice_preset}.wav")
|
| 151 |
+
|
| 152 |
+
if not os.path.exists(voice_path):
|
| 153 |
+
logger.warning(f"Voice preset file not found: {voice_path}")
|
| 154 |
+
return None, "Voice preset not found"
|
| 155 |
+
|
| 156 |
+
# Get the transcript from loaded presets or create a default
|
| 157 |
+
voice_presets = load_voice_presets()
|
| 158 |
+
text = voice_presets.get(voice_preset, f"Voice sample: {voice_preset}")
|
| 159 |
+
|
| 160 |
+
return voice_path, text
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def normalize_chinese_punctuation(text):
|
| 164 |
+
"""
|
| 165 |
+
Convert Chinese (full-width) punctuation marks to English (half-width) equivalents.
|
| 166 |
+
"""
|
| 167 |
+
# Mapping of Chinese punctuation to English punctuation
|
| 168 |
+
chinese_to_english_punct = {
|
| 169 |
+
",": ", ", # comma
|
| 170 |
+
"。": ".", # period
|
| 171 |
+
":": ":", # colon
|
| 172 |
+
";": ";", # semicolon
|
| 173 |
+
"?": "?", # question mark
|
| 174 |
+
"!": "!", # exclamation mark
|
| 175 |
+
"极": "(", # left parenthesis
|
| 176 |
+
")": ")", # right parenthesis
|
| 177 |
+
"【": "[", # left square bracket
|
| 178 |
+
"】": "]", # right square bracket
|
| 179 |
+
"《": "<", # left angle quote
|
| 180 |
+
"》": ">", # right angle quote
|
| 181 |
+
"“": '"', # left double quotation
|
| 182 |
+
"”": '"', # right double quotation
|
| 183 |
+
"‘": "'", # left single quotation
|
| 184 |
+
"’": "'", # right single quotation
|
| 185 |
+
"、": ",", # enumeration comma
|
| 186 |
+
"—": "-", # em dash
|
| 187 |
+
"…": "...", # ellipsis
|
| 188 |
+
"·": ".", # middle dot
|
| 189 |
+
"「": '"', # left corner bracket
|
| 190 |
+
"」": '"', # right corner bracket
|
| 191 |
+
"『": '"', # left double corner bracket
|
| 192 |
+
"』": '"', # right double corner bracket
|
| 193 |
+
}
|
| 194 |
+
|
| 195 |
+
# Replace each Chinese punctuation with its English counterpart
|
| 196 |
+
for zh_punct, en_punct in chinese_to_english_punct.items():
|
| 197 |
+
text = text.replace(zh_punct, en_punct)
|
| 198 |
+
|
| 199 |
+
return text
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def normalize_text(transcript: str):
|
| 203 |
+
transcript = normalize_chinese_punctuation(transcript)
|
| 204 |
+
# Other normalizations (e.g., parentheses and other symbols. Will be improved in the future)
|
| 205 |
+
transcript = transcript.replace("(", " ")
|
| 206 |
+
transcript = transcript.replace(")", " ")
|
| 207 |
+
transcript = transcript.replace("°F", " degrees Fahrenheit")
|
| 208 |
+
transcript = transcript.replace("°C", " degrees Celsius")
|
| 209 |
+
|
| 210 |
+
for tag, replacement in [
|
| 211 |
+
("[laugh]", "<SE>[Laughter]</SE>"),
|
| 212 |
+
("[humming start]", "<SE>[Humming]</SE>"),
|
| 213 |
+
("[humming end]", "<SE_e>[Humming]</SE_e>"),
|
| 214 |
+
("[music start]", "<SE_s>[Music]</SE_s>"),
|
| 215 |
+
("[music end]", "<SE_e>[Music]</SE_e>"),
|
| 216 |
+
("[music]", "<SE>[Music]</SE>"),
|
| 217 |
+
("[sing start]", "<SE_s>[Singing]</SE_s>"),
|
| 218 |
+
("[sing end]", "<SE_e>[Singing]</SE_e>"),
|
| 219 |
+
("[applause]", "<SE>[Applause]</SE>"),
|
| 220 |
+
("[cheering]", "<SE>[Cheering]</SE>"),
|
| 221 |
+
("[cough]", "<SE>[Cough]</SE>"),
|
| 222 |
+
]:
|
| 223 |
+
transcript = transcript.replace(tag, replacement)
|
| 224 |
+
|
| 225 |
+
lines = transcript.split("\n")
|
| 226 |
+
transcript = "\n".join([" ".join(line.split()) for line in lines if line.strip()])
|
| 227 |
+
transcript = transcript.strip()
|
| 228 |
+
|
| 229 |
+
if not any([transcript.endswith(c) for c in [".", "!", "?", ",", ";", '"', "'", "</SE_e>", "</SE>"]]):
|
| 230 |
+
transcript += "."
|
| 231 |
+
|
| 232 |
+
return transcript
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
@spaces.GPU
|
| 236 |
+
def initialize_engine(model_path, audio_tokenizer_path) -> bool:
|
| 237 |
+
"""
|
| 238 |
+
Initialize the HiggsAudioServeEngine with the specified model and tokenizer.
|
| 239 |
+
|
| 240 |
+
Args:
|
| 241 |
+
model_path: Path to the model to load
|
| 242 |
+
audio_tokenizer_path: Path to the audio tokenizer to load
|
| 243 |
+
|
| 244 |
+
Returns:
|
| 245 |
+
True if initialization was successful, False otherwise
|
| 246 |
+
"""
|
| 247 |
+
global engine
|
| 248 |
+
try:
|
| 249 |
+
logger.info(f"Initializing engine with model: {model_path} and audio tokenizer: {audio_tokenizer_path}")
|
| 250 |
+
engine = HiggsAudioServeEngine(
|
| 251 |
+
model_name_or_path=model_path,
|
| 252 |
+
audio_tokenizer_name_or_path=audio_tokenizer_path,
|
| 253 |
+
device=get_current_device(),
|
| 254 |
+
)
|
| 255 |
+
logger.info(f"Successfully initialized HiggsAudioServeEngine with model: {model_path}")
|
| 256 |
+
return True
|
| 257 |
+
except Exception as e:
|
| 258 |
+
logger.error(f"Failed to initialize engine: {e}")
|
| 259 |
+
return False
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def check_return_audio(audio_wv: np.ndarray):
|
| 263 |
+
# check if the audio returned is all silent
|
| 264 |
+
if np.all(audio_wv == 0):
|
| 265 |
+
logger.warning("Audio is silent, returning None")
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def process_text_output(text_output: str):
|
| 269 |
+
# remove all the continuous <|AUDIO_OUT|> tokens with a single <|AUDIO_OUT|>
|
| 270 |
+
text_output = re.sub(r"(<\|AUDIO_OUT\|>)+", r"<|AUDIO_OUT|>", text_output)
|
| 271 |
+
return text_output
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def prepare_chatml_sample(
|
| 275 |
+
voice_preset: str,
|
| 276 |
+
text: str,
|
| 277 |
+
reference_audio: Optional[str] = None,
|
| 278 |
+
reference_text: Optional[str] = None,
|
| 279 |
+
system_prompt: str = DEFAULT_SYSTEM_PROMPT,
|
| 280 |
+
):
|
| 281 |
+
"""Prepare a ChatMLSample for the HiggsAudioServeEngine."""
|
| 282 |
+
messages = []
|
| 283 |
+
|
| 284 |
+
# Add system message if provided
|
| 285 |
+
if len(system_prompt) > 0:
|
| 286 |
+
messages.append(Message(role="system", content=system_prompt))
|
| 287 |
+
|
| 288 |
+
# Add reference audio if provided
|
| 289 |
+
audio_base64 = None
|
| 290 |
+
ref_text = ""
|
| 291 |
+
|
| 292 |
+
if reference_audio:
|
| 293 |
+
# Custom reference audio
|
| 294 |
+
audio_base64 = encode_audio_file(reference_audio)
|
| 295 |
+
ref_text = reference_text or ""
|
| 296 |
+
elif voice_preset != "EMPTY":
|
| 297 |
+
# Voice preset
|
| 298 |
+
voice_path, ref_text = get_voice_preset(voice_preset)
|
| 299 |
+
if voice_path is None:
|
| 300 |
+
logger.warning(f"Voice preset {voice_preset} not found, skipping reference audio")
|
| 301 |
+
else:
|
| 302 |
+
audio_base64 = encode_audio_file(voice_path)
|
| 303 |
+
|
| 304 |
+
# Only add reference audio if we have it
|
| 305 |
+
if audio_base64 is not None:
|
| 306 |
+
# Add user message with reference text
|
| 307 |
+
messages.append(Message(role="user", content=ref_text))
|
| 308 |
+
|
| 309 |
+
# Add assistant message with audio content
|
| 310 |
+
audio_content = AudioContent(raw_audio=audio_base64, audio_url="")
|
| 311 |
+
messages.append(Message(role="assistant", content=[audio_content]))
|
| 312 |
+
|
| 313 |
+
# Add the main user message
|
| 314 |
+
text = normalize_text(text)
|
| 315 |
+
messages.append(Message(role="user", content=text))
|
| 316 |
+
|
| 317 |
+
return ChatMLSample(messages=messages)
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
@spaces.GPU(duration=120)
|
| 321 |
+
def text_to_speech(
|
| 322 |
+
text,
|
| 323 |
+
voice_preset,
|
| 324 |
+
reference_audio=None,
|
| 325 |
+
reference_text=None,
|
| 326 |
+
max_completion_tokens=1024,
|
| 327 |
+
temperature=1.0,
|
| 328 |
+
top_p=0.95,
|
| 329 |
+
top_k=50,
|
| 330 |
+
system_prompt=DEFAULT_SYSTEM_PROMPT,
|
| 331 |
+
stop_strings=None,
|
| 332 |
+
ras_win_len=7,
|
| 333 |
+
ras_win_max_num_repeat=2,
|
| 334 |
+
):
|
| 335 |
+
"""
|
| 336 |
+
Convert text to speech using HiggsAudioServeEngine.
|
| 337 |
+
|
| 338 |
+
Args:
|
| 339 |
+
text: The text to convert to speech
|
| 340 |
+
voice_preset: The voice preset to use (or "EMPTY" for no preset)
|
| 341 |
+
reference_audio: Optional path to reference audio file
|
| 342 |
+
reference_text: Optional transcript of the reference audio
|
| 343 |
+
max_completion_tokens: Maximum number of tokens to generate
|
| 344 |
+
temperature: Sampling temperature for generation
|
| 345 |
+
top_p: Top-p sampling parameter
|
| 346 |
+
top_k: Top-k sampling parameter
|
| 347 |
+
system_prompt: System prompt to guide the model
|
| 348 |
+
stop_strings: Dataframe containing stop strings
|
| 349 |
+
ras_win_len: Window length for repetition avoidance sampling
|
| 350 |
+
ras_win_max_num_repeat: Maximum number of repetitions allowed in the window
|
| 351 |
+
|
| 352 |
+
Returns:
|
| 353 |
+
Tuple of (generated_text, (sample_rate, audio_data)) where audio_data is int16 numpy array
|
| 354 |
+
"""
|
| 355 |
+
global engine
|
| 356 |
+
|
| 357 |
+
if engine is None:
|
| 358 |
+
initialize_engine(DEFAULT_MODEL_PATH, DEFAULT_AUDIO_TOKENIZER_PATH)
|
| 359 |
+
|
| 360 |
+
try:
|
| 361 |
+
# Prepare ChatML sample
|
| 362 |
+
chatml_sample = prepare_chatml_sample(voice_preset, text, reference_audio, reference_text, system_prompt)
|
| 363 |
+
|
| 364 |
+
# Convert stop strings format
|
| 365 |
+
if stop_strings is None:
|
| 366 |
+
stop_list = DEFAULT_STOP_STRINGS
|
| 367 |
+
else:
|
| 368 |
+
stop_list = [s for s in stop_strings["stops"] if s.strip()]
|
| 369 |
+
|
| 370 |
+
request_id = f"tts-playground-{str(uuid.uuid4())}"
|
| 371 |
+
logger.info(
|
| 372 |
+
f"{request_id}: Generating speech for text: {text[:100]}..., \n"
|
| 373 |
+
f"with parameters: temperature={temperature}, top_p={top_p}, top_k={top_k}, stop_list={stop_list}, "
|
| 374 |
+
f"ras_win_len={ras_win_len}, ras_win_max_num_repeat={ras_win_max_num_repeat}"
|
| 375 |
+
)
|
| 376 |
+
start_time = time.time()
|
| 377 |
+
|
| 378 |
+
# Generate using the engine
|
| 379 |
+
response = engine.generate(
|
| 380 |
+
chat_ml_sample=chatml_sample,
|
| 381 |
+
max_new_tokens=max_completion_tokens,
|
| 382 |
+
temperature=temperature,
|
| 383 |
+
top_k=top_k if top_k > 0 else None,
|
| 384 |
+
top_p=top_p,
|
| 385 |
+
stop_strings=stop_list,
|
| 386 |
+
ras_win_len=ras_win_len if ras_win_len > 0 else None,
|
| 387 |
+
ras_win_max_num_repeat=max(ras_win_len, ras_win_max_num_repeat),
|
| 388 |
+
)
|
| 389 |
+
|
| 390 |
+
generation_time = time.time() - start_time
|
| 391 |
+
logger.info(f"{request_id}: Generated audio in {generation_time:.3f} seconds")
|
| 392 |
+
gr.Info(f"Generated audio in {generation_time:.3f} seconds")
|
| 393 |
+
|
| 394 |
+
# Process the response
|
| 395 |
+
text_output = process_text_output(response.generated_text)
|
| 396 |
+
|
| 397 |
+
if response.audio is not None:
|
| 398 |
+
# Convert to int16 for Gradio
|
| 399 |
+
audio_data = (response.audio * 32767).astype(np.int16)
|
| 400 |
+
check_return_audio(audio_data)
|
| 401 |
+
return text_output, (response.sampling_rate, audio_data)
|
| 402 |
+
else:
|
| 403 |
+
logger.warning("No audio generated")
|
| 404 |
+
return text_output, None
|
| 405 |
+
|
| 406 |
+
except Exception as e:
|
| 407 |
+
error_msg = f"Error generating speech: {e}"
|
| 408 |
+
logger.error(error_msg)
|
| 409 |
+
gr.Error(error_msg)
|
| 410 |
+
return f"❌ {error_msg}", None
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
def create_ui():
|
| 414 |
+
# Load voice presets
|
| 415 |
+
VOICE_PRESETS = load_voice_presets()
|
| 416 |
+
|
| 417 |
+
# Load theme with fallback
|
| 418 |
+
try:
|
| 419 |
+
my_theme = gr.Theme.load("theme.json")
|
| 420 |
+
except:
|
| 421 |
+
my_theme = gr.themes.Default()
|
| 422 |
+
logger.warning("Using default theme - theme.json not found")
|
| 423 |
+
|
| 424 |
+
# ... rest of your UI code ...
|
| 425 |
+
|
| 426 |
+
# Custom OnlyFans-inspired CSS
|
| 427 |
+
custom_css = """
|
| 428 |
+
.gradio-container {
|
| 429 |
+
max-width: 1200px;
|
| 430 |
+
margin: 0 auto;
|
| 431 |
+
border-radius: 20px;
|
| 432 |
+
background: rgba(255, 255, 255, 0.9);
|
| 433 |
+
backdrop-filter: blur(10px);
|
| 434 |
+
box-shadow: 0 10px 30px rgba(0, 0, 0, 0.1);
|
| 435 |
+
padding: 25px;
|
| 436 |
+
}
|
| 437 |
+
h1 {
|
| 438 |
+
background: linear-gradient(90deg, #ff4da6 0%, #9d4dff 100%);
|
| 439 |
+
-webkit-background-clip: text;
|
| 440 |
+
-webkit-text-fill-color: transparent;
|
| 441 |
+
text-align: center;
|
| 442 |
+
font-size: 2.5em;
|
| 443 |
+
font-weight: 800;
|
| 444 |
+
margin-bottom: 10px;
|
| 445 |
+
}
|
| 446 |
+
.gr-markdown p {
|
| 447 |
+
text-align: center;
|
| 448 |
+
color: #666;
|
| 449 |
+
font-size: 1.1em;
|
| 450 |
+
margin-bottom: 25px;
|
| 451 |
+
}
|
| 452 |
+
.gr-box {
|
| 453 |
+
border-radius: 15px;
|
| 454 |
+
border: 2px solid #ff4da6;
|
| 455 |
+
padding: 15px;
|
| 456 |
+
}
|
| 457 |
+
textarea, select, input {
|
| 458 |
+
border-radius: 15px;
|
| 459 |
+
border: 2px solid #ff9ec4;
|
| 460 |
+
padding: 12px;
|
| 461 |
+
font-size: 1em;
|
| 462 |
+
transition: all 0.3s ease;
|
| 463 |
+
}
|
| 464 |
+
textarea:focus, select:focus, input:focus {
|
| 465 |
+
border-color: #ff4da6;
|
| 466 |
+
box-shadow: 0 0 0 2px rgba(255, 77, 166, 0.2);
|
| 467 |
+
outline: none;
|
| 468 |
+
}
|
| 469 |
+
button {
|
| 470 |
+
background: linear-gradient(90deg, #ff4da6 0%, #9d4dff 100%);
|
| 471 |
+
color: white;
|
| 472 |
+
border: none;
|
| 473 |
+
border-radius: 15px;
|
| 474 |
+
padding: 12px 25px;
|
| 475 |
+
font-weight: 600;
|
| 476 |
+
font-size: 1.1em;
|
| 477 |
+
cursor: pointer;
|
| 478 |
+
transition: all 0.3s ease;
|
| 479 |
+
}
|
| 480 |
+
button:hover {
|
| 481 |
+
transform: translateY(-2px);
|
| 482 |
+
box-shadow: 0 5px 15px rgba(255, 77, 166, 0.4);
|
| 483 |
+
}
|
| 484 |
+
.gr-accordion {
|
| 485 |
+
border-radius: 15px;
|
| 486 |
+
border: 2px solid #ff9ec4;
|
| 487 |
+
margin-bottom: 15px;
|
| 488 |
+
}
|
| 489 |
+
.gr-accordion .gr-button {
|
| 490 |
+
background: transparent;
|
| 491 |
+
color: #ff4da6;
|
| 492 |
+
font-weight: 600;
|
| 493 |
+
}
|
| 494 |
+
.label-wrap {
|
| 495 |
+
font-weight: 600;
|
| 496 |
+
color: #ff4da6;
|
| 497 |
+
margin-bottom: 8px;
|
| 498 |
+
}
|
| 499 |
+
.tooltip {
|
| 500 |
+
background: #ff4da6;
|
| 501 |
+
color: white;
|
| 502 |
+
}
|
| 503 |
+
"""
|
| 504 |
+
|
| 505 |
+
default_template = "smart-voice"
|
| 506 |
+
|
| 507 |
+
"""Create the Gradio UI."""
|
| 508 |
+
with gr.Blocks(theme=my_theme, css=custom_css, title="OnlyAgencies Audio Text-to-Speech") as demo:
|
| 509 |
+
gr.Markdown("# OnlyAgencies Audio Text-to-Speech")
|
| 510 |
+
gr.Markdown("Create irresistible audio messages that keep your fans coming back for more 😘")
|
| 511 |
+
|
| 512 |
+
# Main UI section
|
| 513 |
+
with gr.Row():
|
| 514 |
+
with gr.Column(scale=2):
|
| 515 |
+
# Template selection dropdown
|
| 516 |
+
template_dropdown = gr.Dropdown(
|
| 517 |
+
label="TTS Template",
|
| 518 |
+
choices=list(PREDEFINED_EXAMPLES.keys()),
|
| 519 |
+
value=default_template,
|
| 520 |
+
info="Select a predefined example for system and input messages.",
|
| 521 |
+
)
|
| 522 |
+
|
| 523 |
+
# Template description display
|
| 524 |
+
template_description = gr.HTML(
|
| 525 |
+
value=f'<p style="font-size: 0.85em; color: #ff4da6; margin: 0; padding: 0;">{PREDEFINED_EXAMPLES[default_template]["description"]}</p>',
|
| 526 |
+
visible=True,
|
| 527 |
+
)
|
| 528 |
+
|
| 529 |
+
system_prompt = gr.TextArea(
|
| 530 |
+
label="System Prompt",
|
| 531 |
+
placeholder="Enter system prompt to guide the model...",
|
| 532 |
+
value=PREDEFINED_EXAMPLES[default_template]["system_prompt"],
|
| 533 |
+
lines=2,
|
| 534 |
+
)
|
| 535 |
+
|
| 536 |
+
input_text = gr.TextArea(
|
| 537 |
+
label="Input Text",
|
| 538 |
+
placeholder="Type the text you want to convert to speech...",
|
| 539 |
+
value=PREDEFINED_EXAMPLES[default_template]["input_text"],
|
| 540 |
+
lines=5,
|
| 541 |
+
)
|
| 542 |
+
|
| 543 |
+
voice_preset = gr.Dropdown(
|
| 544 |
+
label="Voice Preset",
|
| 545 |
+
choices=list(VOICE_PRESETS.keys()),
|
| 546 |
+
value="EMPTY",
|
| 547 |
+
interactive=False, # Disabled by default since default template is not voice-clone
|
| 548 |
+
visible=False,
|
| 549 |
+
)
|
| 550 |
+
|
| 551 |
+
with gr.Accordion(
|
| 552 |
+
"Custom Reference (Optional)", open=False, visible=False
|
| 553 |
+
) as custom_reference_accordion:
|
| 554 |
+
reference_audio = gr.Audio(label="Reference Audio", type="filepath")
|
| 555 |
+
reference_text = gr.TextArea(
|
| 556 |
+
label="Reference Text (transcript of the reference audio)",
|
| 557 |
+
placeholder="Enter the transcript of your reference audio...",
|
| 558 |
+
lines=3,
|
| 559 |
+
)
|
| 560 |
+
|
| 561 |
+
with gr.Accordion("Advanced Parameters", open=False):
|
| 562 |
+
max_completion_tokens = gr.Slider(
|
| 563 |
+
minimum=128,
|
| 564 |
+
maximum=4096,
|
| 565 |
+
value=1024,
|
| 566 |
+
step=10,
|
| 567 |
+
label="Max Completion Tokens",
|
| 568 |
+
)
|
| 569 |
+
temperature = gr.Slider(
|
| 570 |
+
minimum=0.0,
|
| 571 |
+
maximum=1.5,
|
| 572 |
+
value=1.0,
|
| 573 |
+
step=0.1,
|
| 574 |
+
label="Temperature",
|
| 575 |
+
)
|
| 576 |
+
top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top P")
|
| 577 |
+
top_k = gr.Slider(minimum=-1, maximum=100, value=50, step=1, label="Top K")
|
| 578 |
+
ras_win_len = gr.Slider(
|
| 579 |
+
minimum=0,
|
| 580 |
+
maximum=10,
|
| 581 |
+
value=7,
|
| 582 |
+
step=1,
|
| 583 |
+
label="RAS Window Length",
|
| 584 |
+
info="Window length for repetition avoidance sampling",
|
| 585 |
+
)
|
| 586 |
+
ras_win_max_num_repeat = gr.Slider(
|
| 587 |
+
minimum=1,
|
| 588 |
+
maximum=10,
|
| 589 |
+
value=2,
|
| 590 |
+
step=1,
|
| 591 |
+
label="RAS Max Num Repeat",
|
| 592 |
+
info="Maximum number of repetitions allowed in the window",
|
| 593 |
+
)
|
| 594 |
+
# Add stop strings component
|
| 595 |
+
stop_strings = gr.Dataframe(
|
| 596 |
+
label="Stop Strings",
|
| 597 |
+
headers=["stops"],
|
| 598 |
+
datatype=["str"],
|
| 599 |
+
value=[[s] for s in DEFAULT_STOP_STRINGS],
|
| 600 |
+
interactive=True,
|
| 601 |
+
col_count=(1, "fixed"),
|
| 602 |
+
)
|
| 603 |
+
|
| 604 |
+
submit_btn = gr.Button("Generate Speech", variant="primary", scale=1)
|
| 605 |
+
|
| 606 |
+
with gr.Column(scale=2):
|
| 607 |
+
output_text = gr.TextArea(label="Model Response", lines=2)
|
| 608 |
+
|
| 609 |
+
# Audio output
|
| 610 |
+
output_audio = gr.Audio(label="Generated Audio", interactive=False, autoplay=True)
|
| 611 |
+
|
| 612 |
+
stop_btn = gr.Button("Stop Playback", variant="primary")
|
| 613 |
+
|
| 614 |
+
# Example voice
|
| 615 |
+
with gr.Row(visible=False) as voice_samples_section:
|
| 616 |
+
voice_samples_table = gr.Dataframe(
|
| 617 |
+
headers=["Voice Preset", "Sample Text"],
|
| 618 |
+
datatype=["str", "str"],
|
| 619 |
+
value=[[preset, text] for preset, text in VOICE_PRESETS.items() if preset != "EMPTY"],
|
| 620 |
+
interactive=False,
|
| 621 |
+
)
|
| 622 |
+
sample_audio = gr.Audio(label="Voice Sample")
|
| 623 |
+
|
| 624 |
+
# Function to play voice sample when clicking on a row
|
| 625 |
+
def play_voice_sample(evt: gr.SelectData):
|
| 626 |
+
"""
|
| 627 |
+
Play a voice sample when a row is clicked in the voice samples table.
|
| 628 |
+
|
| 629 |
+
Args:
|
| 630 |
+
evt: The select event containing the clicked row index
|
| 631 |
+
|
| 632 |
+
Returns:
|
| 633 |
+
Path to the voice sample audio file, or None if not found
|
| 634 |
+
"""
|
| 635 |
+
try:
|
| 636 |
+
# Get the preset name from the clicked row
|
| 637 |
+
preset_names = [preset for preset in VOICE_PRESETS.keys() if preset != "EMPTY"]
|
| 638 |
+
if evt.index[0] < len(preset_names):
|
| 639 |
+
preset = preset_names[evt.index[0]]
|
| 640 |
+
voice_path, _ = get_voice_preset(preset)
|
| 641 |
+
if voice_path and os.path.exists(voice_path):
|
| 642 |
+
return voice_path
|
| 643 |
+
else:
|
| 644 |
+
gr.Warning(f"Voice sample file not found for preset: {preset}")
|
| 645 |
+
return None
|
| 646 |
+
else:
|
| 647 |
+
gr.Warning("Invalid voice preset selection")
|
| 648 |
+
return None
|
| 649 |
+
except Exception as e:
|
| 650 |
+
logger.error(f"Error playing voice sample: {e}")
|
| 651 |
+
gr.Error(f"Error playing voice sample: {e}")
|
| 652 |
+
return None
|
| 653 |
+
|
| 654 |
+
voice_samples_table.select(fn=play_voice_sample, outputs=[sample_audio])
|
| 655 |
+
|
| 656 |
+
# Function to handle template selection
|
| 657 |
+
def apply_template(template_name):
|
| 658 |
+
"""
|
| 659 |
+
Apply a predefined template to the UI components.
|
| 660 |
+
|
| 661 |
+
Args:
|
| 662 |
+
template_name: Name of the template to apply
|
| 663 |
+
|
| 664 |
+
Returns:
|
| 665 |
+
Tuple of updated values for system_prompt, input_text, template_description,
|
| 666 |
+
voice_preset, custom_reference_accordion, voice_samples_section, and ras_win_len
|
| 667 |
+
"""
|
| 668 |
+
if template_name in PREDEFINED_EXAMPLES:
|
| 669 |
+
template = PREDEFINED_EXAMPLES[template_name]
|
| 670 |
+
# Enable voice preset and custom reference only for voice-clone template
|
| 671 |
+
is_voice_clone = template_name == "voice-clone"
|
| 672 |
+
voice_preset_value = "belinda" if is_voice_clone else "EMPTY"
|
| 673 |
+
# Set ras_win_len to 0 for single-speaker-bgm, 7 for others
|
| 674 |
+
ras_win_len_value = 0 if template_name == "single-speaker-bgm" else 7
|
| 675 |
+
description_text = f'<p style="font-size: 0.85em; color: #ff4da6; margin: 0; padding: 0;">{template["description"]}</p>'
|
| 676 |
+
return (
|
| 677 |
+
template["system_prompt"], # system_prompt
|
| 678 |
+
template["input_text"], # input_text
|
| 679 |
+
description_text, # template_description
|
| 680 |
+
gr.update(
|
| 681 |
+
value=voice_preset_value, interactive=is_voice_clone, visible=is_voice_clone
|
| 682 |
+
), # voice_preset (value and interactivity)
|
| 683 |
+
gr.update(visible=is_voice_clone), # custom reference accordion visibility
|
| 684 |
+
gr.update(visible=is_voice_clone), # voice samples section visibility
|
| 685 |
+
ras_win_len_value, # ras_win_len
|
| 686 |
+
)
|
| 687 |
+
else:
|
| 688 |
+
return (
|
| 689 |
+
gr.update(),
|
| 690 |
+
gr.update(),
|
| 691 |
+
gr.update(),
|
| 692 |
+
gr.update(),
|
| 693 |
+
gr.update(),
|
| 694 |
+
gr.update(),
|
| 695 |
+
gr.update(),
|
| 696 |
+
) # No change if template not found
|
| 697 |
+
|
| 698 |
+
# Set up event handlers
|
| 699 |
+
|
| 700 |
+
# Connect template dropdown to handler
|
| 701 |
+
template_dropdown.change(
|
| 702 |
+
fn=apply_template,
|
| 703 |
+
inputs=[template_dropdown],
|
| 704 |
+
outputs=[
|
| 705 |
+
system_prompt,
|
| 706 |
+
input_text,
|
| 707 |
+
template_description,
|
| 708 |
+
voice_preset,
|
| 709 |
+
custom_reference_accordion,
|
| 710 |
+
voice_samples_section,
|
| 711 |
+
ras_win_len,
|
| 712 |
+
],
|
| 713 |
+
)
|
| 714 |
+
|
| 715 |
+
# Connect submit button to the TTS function
|
| 716 |
+
submit_btn.click(
|
| 717 |
+
fn=text_to_speech,
|
| 718 |
+
inputs=[
|
| 719 |
+
input_text,
|
| 720 |
+
voice_preset,
|
| 721 |
+
reference_audio,
|
| 722 |
+
reference_text,
|
| 723 |
+
max_completion_tokens,
|
| 724 |
+
temperature,
|
| 725 |
+
top_p,
|
| 726 |
+
top_k,
|
| 727 |
+
system_prompt,
|
| 728 |
+
stop_strings,
|
| 729 |
+
ras_win_len,
|
| 730 |
+
ras_win_max_num_repeat,
|
| 731 |
+
],
|
| 732 |
+
outputs=[output_text, output_audio],
|
| 733 |
+
api_name="generate_speech",
|
| 734 |
+
)
|
| 735 |
+
|
| 736 |
+
# Stop button functionality
|
| 737 |
+
stop_btn.click(
|
| 738 |
+
fn=lambda: None,
|
| 739 |
+
inputs=[],
|
| 740 |
+
outputs=[output_audio],
|
| 741 |
+
js="() => {const audio = document.querySelector('audio'); if(audio) audio.pause(); return null;}",
|
| 742 |
+
)
|
| 743 |
+
|
| 744 |
+
return demo
|
| 745 |
+
|
| 746 |
+
|
| 747 |
+
def main():
|
| 748 |
+
"""Main function to parse arguments and launch the UI."""
|
| 749 |
+
global DEFAULT_MODEL_PATH, DEFAULT_AUDIO_TOKENIZER_PATH, VOICE_PRESETS
|
| 750 |
+
|
| 751 |
+
parser = argparse.ArgumentParser(description="Gradio UI for Text-to-Speech using HiggsAudioServeEngine")
|
| 752 |
+
parser.add_argument(
|
| 753 |
+
"--device",
|
| 754 |
+
type=str,
|
| 755 |
+
default="cuda",
|
| 756 |
+
choices=["cuda", "cpu"],
|
| 757 |
+
help="Device to run the model on.",
|
| 758 |
+
)
|
| 759 |
+
parser.add_argument("--host", type=str, default="0.0.0.0", help="Host for the Gradio interface.")
|
| 760 |
+
parser.add_argument("--port", type=int, default=7860, help="Port for the Gradio interface.")
|
| 761 |
+
|
| 762 |
+
args = parser.parse_args()
|
| 763 |
+
|
| 764 |
+
# Update default values if provided via command line
|
| 765 |
+
VOICE_PRESETS = load_voice_presets()
|
| 766 |
+
|
| 767 |
+
# Create and launch the UI
|
| 768 |
+
demo = create_ui()
|
| 769 |
+
demo.launch(server_name=args.host, server_port=args.port, share=True)
|
| 770 |
+
|
| 771 |
+
|
| 772 |
+
if __name__ == "__main__":
|
| 773 |
+
main()
|
higgs_audio/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .model import HiggsAudioConfig, HiggsAudioModel
|
higgs_audio/audio_processing/LICENSE
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Third-Party License Attribution for Audio Processing Module
|
| 2 |
+
===========================================================
|
| 3 |
+
|
| 4 |
+
This directory contains code derived from multiple open-source projects.
|
| 5 |
+
The following sections detail the licenses and attributions for third-party code.
|
| 6 |
+
|
| 7 |
+
## XCodec Repository
|
| 8 |
+
The code in this directory is derived from:
|
| 9 |
+
https://github.com/zhenye234/xcodec
|
| 10 |
+
|
| 11 |
+
## Individual File Attributions
|
| 12 |
+
|
| 13 |
+
### Quantization Module (quantization/)
|
| 14 |
+
- Several files contain code derived from Meta Platforms, Inc. and the vector-quantize-pytorch repository
|
| 15 |
+
- Individual files contain their own license headers where applicable
|
| 16 |
+
- The vector-quantize-pytorch portions are licensed under the MIT License
|
| 17 |
+
|
| 18 |
+
## License Terms
|
| 19 |
+
|
| 20 |
+
### MIT License (for applicable portions)
|
| 21 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 22 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 23 |
+
in the Software without restriction, including without limitation the rights
|
| 24 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 25 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 26 |
+
furnished to do so, subject to the following conditions:
|
| 27 |
+
|
| 28 |
+
The above copyright notice and this permission notice shall be included in all
|
| 29 |
+
copies or substantial portions of the Software.
|
| 30 |
+
|
| 31 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 32 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 33 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 34 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 35 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 36 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 37 |
+
SOFTWARE.
|
| 38 |
+
|
| 39 |
+
## Attribution Requirements
|
| 40 |
+
When using this code, please ensure proper attribution to:
|
| 41 |
+
1. The original xcodec repository: https://github.com/zhenye234/xcodec
|
| 42 |
+
2. Any other repositories mentioned in individual file headers
|
| 43 |
+
3. This derivative work and its modifications
|
| 44 |
+
|
| 45 |
+
## Disclaimer
|
| 46 |
+
This directory contains modified versions of the original code. Please refer to
|
| 47 |
+
the original repositories for the canonical implementations and their specific
|
| 48 |
+
license terms.
|
| 49 |
+
|
| 50 |
+
For any questions about licensing or attribution, please check the individual
|
| 51 |
+
file headers and the original source repositories.
|
higgs_audio/audio_processing/descriptaudiocodec/__init__.py
ADDED
|
File without changes
|
higgs_audio/audio_processing/descriptaudiocodec/dac/model/base.py
ADDED
|
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import Union
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import tqdm
|
| 9 |
+
from audiotools import AudioSignal
|
| 10 |
+
from torch import nn
|
| 11 |
+
|
| 12 |
+
SUPPORTED_VERSIONS = ["1.0.0"]
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass
|
| 16 |
+
class DACFile:
|
| 17 |
+
codes: torch.Tensor
|
| 18 |
+
|
| 19 |
+
# Metadata
|
| 20 |
+
chunk_length: int
|
| 21 |
+
original_length: int
|
| 22 |
+
input_db: float
|
| 23 |
+
channels: int
|
| 24 |
+
sample_rate: int
|
| 25 |
+
padding: bool
|
| 26 |
+
dac_version: str
|
| 27 |
+
|
| 28 |
+
def save(self, path):
|
| 29 |
+
artifacts = {
|
| 30 |
+
"codes": self.codes.numpy().astype(np.uint16),
|
| 31 |
+
"metadata": {
|
| 32 |
+
"input_db": self.input_db.numpy().astype(np.float32),
|
| 33 |
+
"original_length": self.original_length,
|
| 34 |
+
"sample_rate": self.sample_rate,
|
| 35 |
+
"chunk_length": self.chunk_length,
|
| 36 |
+
"channels": self.channels,
|
| 37 |
+
"padding": self.padding,
|
| 38 |
+
"dac_version": SUPPORTED_VERSIONS[-1],
|
| 39 |
+
},
|
| 40 |
+
}
|
| 41 |
+
path = Path(path).with_suffix(".dac")
|
| 42 |
+
with open(path, "wb") as f:
|
| 43 |
+
np.save(f, artifacts)
|
| 44 |
+
return path
|
| 45 |
+
|
| 46 |
+
@classmethod
|
| 47 |
+
def load(cls, path):
|
| 48 |
+
artifacts = np.load(path, allow_pickle=True)[()]
|
| 49 |
+
codes = torch.from_numpy(artifacts["codes"].astype(int))
|
| 50 |
+
if artifacts["metadata"].get("dac_version", None) not in SUPPORTED_VERSIONS:
|
| 51 |
+
raise RuntimeError(f"Given file {path} can't be loaded with this version of descript-audio-codec.")
|
| 52 |
+
return cls(codes=codes, **artifacts["metadata"])
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class CodecMixin:
|
| 56 |
+
@property
|
| 57 |
+
def padding(self):
|
| 58 |
+
if not hasattr(self, "_padding"):
|
| 59 |
+
self._padding = True
|
| 60 |
+
return self._padding
|
| 61 |
+
|
| 62 |
+
@padding.setter
|
| 63 |
+
def padding(self, value):
|
| 64 |
+
assert isinstance(value, bool)
|
| 65 |
+
|
| 66 |
+
layers = [l for l in self.modules() if isinstance(l, (nn.Conv1d, nn.ConvTranspose1d))]
|
| 67 |
+
|
| 68 |
+
for layer in layers:
|
| 69 |
+
if value:
|
| 70 |
+
if hasattr(layer, "original_padding"):
|
| 71 |
+
layer.padding = layer.original_padding
|
| 72 |
+
else:
|
| 73 |
+
layer.original_padding = layer.padding
|
| 74 |
+
layer.padding = tuple(0 for _ in range(len(layer.padding)))
|
| 75 |
+
|
| 76 |
+
self._padding = value
|
| 77 |
+
|
| 78 |
+
def get_delay(self):
|
| 79 |
+
# Any number works here, delay is invariant to input length
|
| 80 |
+
l_out = self.get_output_length(0)
|
| 81 |
+
L = l_out
|
| 82 |
+
|
| 83 |
+
layers = []
|
| 84 |
+
for layer in self.modules():
|
| 85 |
+
if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
|
| 86 |
+
layers.append(layer)
|
| 87 |
+
|
| 88 |
+
for layer in reversed(layers):
|
| 89 |
+
d = layer.dilation[0]
|
| 90 |
+
k = layer.kernel_size[0]
|
| 91 |
+
s = layer.stride[0]
|
| 92 |
+
|
| 93 |
+
if isinstance(layer, nn.ConvTranspose1d):
|
| 94 |
+
L = ((L - d * (k - 1) - 1) / s) + 1
|
| 95 |
+
elif isinstance(layer, nn.Conv1d):
|
| 96 |
+
L = (L - 1) * s + d * (k - 1) + 1
|
| 97 |
+
|
| 98 |
+
L = math.ceil(L)
|
| 99 |
+
|
| 100 |
+
l_in = L
|
| 101 |
+
|
| 102 |
+
return (l_in - l_out) // 2
|
| 103 |
+
|
| 104 |
+
def get_output_length(self, input_length):
|
| 105 |
+
L = input_length
|
| 106 |
+
# Calculate output length
|
| 107 |
+
for layer in self.modules():
|
| 108 |
+
if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
|
| 109 |
+
d = layer.dilation[0]
|
| 110 |
+
k = layer.kernel_size[0]
|
| 111 |
+
s = layer.stride[0]
|
| 112 |
+
|
| 113 |
+
if isinstance(layer, nn.Conv1d):
|
| 114 |
+
L = ((L - d * (k - 1) - 1) / s) + 1
|
| 115 |
+
elif isinstance(layer, nn.ConvTranspose1d):
|
| 116 |
+
L = (L - 1) * s + d * (k - 1) + 1
|
| 117 |
+
|
| 118 |
+
L = math.floor(L)
|
| 119 |
+
return L
|
| 120 |
+
|
| 121 |
+
@torch.no_grad()
|
| 122 |
+
def compress(
|
| 123 |
+
self,
|
| 124 |
+
audio_path_or_signal: Union[str, Path, AudioSignal],
|
| 125 |
+
win_duration: float = 1.0,
|
| 126 |
+
verbose: bool = False,
|
| 127 |
+
normalize_db: float = -16,
|
| 128 |
+
n_quantizers: int = None,
|
| 129 |
+
) -> DACFile:
|
| 130 |
+
"""Processes an audio signal from a file or AudioSignal object into
|
| 131 |
+
discrete codes. This function processes the signal in short windows,
|
| 132 |
+
using constant GPU memory.
|
| 133 |
+
|
| 134 |
+
Parameters
|
| 135 |
+
----------
|
| 136 |
+
audio_path_or_signal : Union[str, Path, AudioSignal]
|
| 137 |
+
audio signal to reconstruct
|
| 138 |
+
win_duration : float, optional
|
| 139 |
+
window duration in seconds, by default 5.0
|
| 140 |
+
verbose : bool, optional
|
| 141 |
+
by default False
|
| 142 |
+
normalize_db : float, optional
|
| 143 |
+
normalize db, by default -16
|
| 144 |
+
|
| 145 |
+
Returns
|
| 146 |
+
-------
|
| 147 |
+
DACFile
|
| 148 |
+
Object containing compressed codes and metadata
|
| 149 |
+
required for decompression
|
| 150 |
+
"""
|
| 151 |
+
audio_signal = audio_path_or_signal
|
| 152 |
+
if isinstance(audio_signal, (str, Path)):
|
| 153 |
+
audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal))
|
| 154 |
+
|
| 155 |
+
self.eval()
|
| 156 |
+
original_padding = self.padding
|
| 157 |
+
original_device = audio_signal.device
|
| 158 |
+
|
| 159 |
+
audio_signal = audio_signal.clone()
|
| 160 |
+
original_sr = audio_signal.sample_rate
|
| 161 |
+
|
| 162 |
+
resample_fn = audio_signal.resample
|
| 163 |
+
loudness_fn = audio_signal.loudness
|
| 164 |
+
|
| 165 |
+
# If audio is > 10 minutes long, use the ffmpeg versions
|
| 166 |
+
if audio_signal.signal_duration >= 10 * 60 * 60:
|
| 167 |
+
resample_fn = audio_signal.ffmpeg_resample
|
| 168 |
+
loudness_fn = audio_signal.ffmpeg_loudness
|
| 169 |
+
|
| 170 |
+
original_length = audio_signal.signal_length
|
| 171 |
+
resample_fn(self.sample_rate)
|
| 172 |
+
input_db = loudness_fn()
|
| 173 |
+
|
| 174 |
+
if normalize_db is not None:
|
| 175 |
+
audio_signal.normalize(normalize_db)
|
| 176 |
+
audio_signal.ensure_max_of_audio()
|
| 177 |
+
|
| 178 |
+
nb, nac, nt = audio_signal.audio_data.shape
|
| 179 |
+
audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt)
|
| 180 |
+
win_duration = audio_signal.signal_duration if win_duration is None else win_duration
|
| 181 |
+
|
| 182 |
+
if audio_signal.signal_duration <= win_duration:
|
| 183 |
+
# Unchunked compression (used if signal length < win duration)
|
| 184 |
+
self.padding = True
|
| 185 |
+
n_samples = nt
|
| 186 |
+
hop = nt
|
| 187 |
+
else:
|
| 188 |
+
# Chunked inference
|
| 189 |
+
self.padding = False
|
| 190 |
+
# Zero-pad signal on either side by the delay
|
| 191 |
+
audio_signal.zero_pad(self.delay, self.delay)
|
| 192 |
+
n_samples = int(win_duration * self.sample_rate)
|
| 193 |
+
# Round n_samples to nearest hop length multiple
|
| 194 |
+
n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length)
|
| 195 |
+
hop = self.get_output_length(n_samples)
|
| 196 |
+
|
| 197 |
+
codes = []
|
| 198 |
+
range_fn = range if not verbose else tqdm.trange
|
| 199 |
+
|
| 200 |
+
for i in range_fn(0, nt, hop):
|
| 201 |
+
x = audio_signal[..., i : i + n_samples]
|
| 202 |
+
x = x.zero_pad(0, max(0, n_samples - x.shape[-1]))
|
| 203 |
+
|
| 204 |
+
audio_data = x.audio_data.to(self.device)
|
| 205 |
+
audio_data = self.preprocess(audio_data, self.sample_rate)
|
| 206 |
+
_, c, _, _, _ = self.encode(audio_data, n_quantizers)
|
| 207 |
+
codes.append(c.to(original_device))
|
| 208 |
+
chunk_length = c.shape[-1]
|
| 209 |
+
|
| 210 |
+
codes = torch.cat(codes, dim=-1)
|
| 211 |
+
|
| 212 |
+
dac_file = DACFile(
|
| 213 |
+
codes=codes,
|
| 214 |
+
chunk_length=chunk_length,
|
| 215 |
+
original_length=original_length,
|
| 216 |
+
input_db=input_db,
|
| 217 |
+
channels=nac,
|
| 218 |
+
sample_rate=original_sr,
|
| 219 |
+
padding=self.padding,
|
| 220 |
+
dac_version=SUPPORTED_VERSIONS[-1],
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
if n_quantizers is not None:
|
| 224 |
+
codes = codes[:, :n_quantizers, :]
|
| 225 |
+
|
| 226 |
+
self.padding = original_padding
|
| 227 |
+
return dac_file
|
| 228 |
+
|
| 229 |
+
@torch.no_grad()
|
| 230 |
+
def decompress(
|
| 231 |
+
self,
|
| 232 |
+
obj: Union[str, Path, DACFile],
|
| 233 |
+
verbose: bool = False,
|
| 234 |
+
) -> AudioSignal:
|
| 235 |
+
"""Reconstruct audio from a given .dac file
|
| 236 |
+
|
| 237 |
+
Parameters
|
| 238 |
+
----------
|
| 239 |
+
obj : Union[str, Path, DACFile]
|
| 240 |
+
.dac file location or corresponding DACFile object.
|
| 241 |
+
verbose : bool, optional
|
| 242 |
+
Prints progress if True, by default False
|
| 243 |
+
|
| 244 |
+
Returns
|
| 245 |
+
-------
|
| 246 |
+
AudioSignal
|
| 247 |
+
Object with the reconstructed audio
|
| 248 |
+
"""
|
| 249 |
+
self.eval()
|
| 250 |
+
if isinstance(obj, (str, Path)):
|
| 251 |
+
obj = DACFile.load(obj)
|
| 252 |
+
|
| 253 |
+
original_padding = self.padding
|
| 254 |
+
self.padding = obj.padding
|
| 255 |
+
|
| 256 |
+
range_fn = range if not verbose else tqdm.trange
|
| 257 |
+
codes = obj.codes
|
| 258 |
+
original_device = codes.device
|
| 259 |
+
chunk_length = obj.chunk_length
|
| 260 |
+
recons = []
|
| 261 |
+
|
| 262 |
+
for i in range_fn(0, codes.shape[-1], chunk_length):
|
| 263 |
+
c = codes[..., i : i + chunk_length].to(self.device)
|
| 264 |
+
z = self.quantizer.from_codes(c)[0]
|
| 265 |
+
r = self.decode(z)
|
| 266 |
+
recons.append(r.to(original_device))
|
| 267 |
+
|
| 268 |
+
recons = torch.cat(recons, dim=-1)
|
| 269 |
+
recons = AudioSignal(recons, self.sample_rate)
|
| 270 |
+
|
| 271 |
+
resample_fn = recons.resample
|
| 272 |
+
loudness_fn = recons.loudness
|
| 273 |
+
|
| 274 |
+
# If audio is > 10 minutes long, use the ffmpeg versions
|
| 275 |
+
if recons.signal_duration >= 10 * 60 * 60:
|
| 276 |
+
resample_fn = recons.ffmpeg_resample
|
| 277 |
+
loudness_fn = recons.ffmpeg_loudness
|
| 278 |
+
|
| 279 |
+
recons.normalize(obj.input_db)
|
| 280 |
+
resample_fn(obj.sample_rate)
|
| 281 |
+
recons = recons[..., : obj.original_length]
|
| 282 |
+
loudness_fn()
|
| 283 |
+
recons.audio_data = recons.audio_data.reshape(-1, obj.channels, obj.original_length)
|
| 284 |
+
|
| 285 |
+
self.padding = original_padding
|
| 286 |
+
return recons
|
higgs_audio/audio_processing/descriptaudiocodec/dac/model/dac.py
ADDED
|
@@ -0,0 +1,365 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import List
|
| 3 |
+
from typing import Union
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
from audiotools import AudioSignal
|
| 8 |
+
from audiotools.ml import BaseModel
|
| 9 |
+
from torch import nn
|
| 10 |
+
|
| 11 |
+
from .base import CodecMixin
|
| 12 |
+
from dac.nn.layers import Snake1d
|
| 13 |
+
from dac.nn.layers import WNConv1d
|
| 14 |
+
from dac.nn.layers import WNConvTranspose1d
|
| 15 |
+
from dac.nn.quantize import ResidualVectorQuantize
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def init_weights(m):
|
| 19 |
+
if isinstance(m, nn.Conv1d):
|
| 20 |
+
nn.init.trunc_normal_(m.weight, std=0.02)
|
| 21 |
+
nn.init.constant_(m.bias, 0)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class ResidualUnit(nn.Module):
|
| 25 |
+
def __init__(self, dim: int = 16, dilation: int = 1):
|
| 26 |
+
super().__init__()
|
| 27 |
+
pad = ((7 - 1) * dilation) // 2
|
| 28 |
+
self.block = nn.Sequential(
|
| 29 |
+
Snake1d(dim),
|
| 30 |
+
WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
|
| 31 |
+
Snake1d(dim),
|
| 32 |
+
WNConv1d(dim, dim, kernel_size=1),
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
def forward(self, x):
|
| 36 |
+
y = self.block(x)
|
| 37 |
+
pad = (x.shape[-1] - y.shape[-1]) // 2
|
| 38 |
+
if pad > 0:
|
| 39 |
+
x = x[..., pad:-pad]
|
| 40 |
+
return x + y
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class EncoderBlock(nn.Module):
|
| 44 |
+
def __init__(self, dim: int = 16, stride: int = 1):
|
| 45 |
+
super().__init__()
|
| 46 |
+
self.block = nn.Sequential(
|
| 47 |
+
ResidualUnit(dim // 2, dilation=1),
|
| 48 |
+
ResidualUnit(dim // 2, dilation=3),
|
| 49 |
+
ResidualUnit(dim // 2, dilation=9),
|
| 50 |
+
Snake1d(dim // 2),
|
| 51 |
+
WNConv1d(
|
| 52 |
+
dim // 2,
|
| 53 |
+
dim,
|
| 54 |
+
kernel_size=2 * stride,
|
| 55 |
+
stride=stride,
|
| 56 |
+
padding=math.ceil(stride / 2),
|
| 57 |
+
),
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
def forward(self, x):
|
| 61 |
+
return self.block(x)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class Encoder(nn.Module):
|
| 65 |
+
def __init__(
|
| 66 |
+
self,
|
| 67 |
+
d_model: int = 64,
|
| 68 |
+
strides: list = [2, 4, 8, 8],
|
| 69 |
+
d_latent: int = 256,
|
| 70 |
+
):
|
| 71 |
+
super().__init__()
|
| 72 |
+
# Create first convolution
|
| 73 |
+
self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
|
| 74 |
+
|
| 75 |
+
# Create EncoderBlocks that double channels as they downsample by `stride`
|
| 76 |
+
for stride in strides:
|
| 77 |
+
d_model *= 2
|
| 78 |
+
self.block += [EncoderBlock(d_model, stride=stride)]
|
| 79 |
+
|
| 80 |
+
# Create last convolution
|
| 81 |
+
self.block += [
|
| 82 |
+
Snake1d(d_model),
|
| 83 |
+
WNConv1d(d_model, d_latent, kernel_size=3, padding=1),
|
| 84 |
+
]
|
| 85 |
+
|
| 86 |
+
# Wrap black into nn.Sequential
|
| 87 |
+
self.block = nn.Sequential(*self.block)
|
| 88 |
+
self.enc_dim = d_model
|
| 89 |
+
|
| 90 |
+
def forward(self, x):
|
| 91 |
+
return self.block(x)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class DecoderBlock(nn.Module):
|
| 95 |
+
def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1, out_pad=0):
|
| 96 |
+
super().__init__()
|
| 97 |
+
self.block = nn.Sequential(
|
| 98 |
+
Snake1d(input_dim),
|
| 99 |
+
WNConvTranspose1d(
|
| 100 |
+
input_dim,
|
| 101 |
+
output_dim,
|
| 102 |
+
kernel_size=2 * stride,
|
| 103 |
+
stride=stride,
|
| 104 |
+
padding=math.ceil(stride / 2),
|
| 105 |
+
output_padding=stride % 2, # out_pad,
|
| 106 |
+
),
|
| 107 |
+
ResidualUnit(output_dim, dilation=1),
|
| 108 |
+
ResidualUnit(output_dim, dilation=3),
|
| 109 |
+
ResidualUnit(output_dim, dilation=9),
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
def forward(self, x):
|
| 113 |
+
return self.block(x)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class Decoder(nn.Module):
|
| 117 |
+
def __init__(
|
| 118 |
+
self,
|
| 119 |
+
input_channel,
|
| 120 |
+
channels,
|
| 121 |
+
rates,
|
| 122 |
+
d_out: int = 1,
|
| 123 |
+
):
|
| 124 |
+
super().__init__()
|
| 125 |
+
|
| 126 |
+
# Add first conv layer
|
| 127 |
+
layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]
|
| 128 |
+
|
| 129 |
+
# Add upsampling + MRF blocks
|
| 130 |
+
for i, stride in enumerate(rates):
|
| 131 |
+
input_dim = channels // 2**i
|
| 132 |
+
output_dim = channels // 2 ** (i + 1)
|
| 133 |
+
if i == 1:
|
| 134 |
+
out_pad = 1
|
| 135 |
+
else:
|
| 136 |
+
out_pad = 0
|
| 137 |
+
layers += [DecoderBlock(input_dim, output_dim, stride, out_pad)]
|
| 138 |
+
|
| 139 |
+
# Add final conv layer
|
| 140 |
+
layers += [
|
| 141 |
+
Snake1d(output_dim),
|
| 142 |
+
WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
|
| 143 |
+
# nn.Tanh(),
|
| 144 |
+
]
|
| 145 |
+
|
| 146 |
+
self.model = nn.Sequential(*layers)
|
| 147 |
+
|
| 148 |
+
def forward(self, x):
|
| 149 |
+
return self.model(x)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class DAC(BaseModel, CodecMixin):
|
| 153 |
+
def __init__(
|
| 154 |
+
self,
|
| 155 |
+
encoder_dim: int = 64,
|
| 156 |
+
encoder_rates: List[int] = [2, 4, 8, 8],
|
| 157 |
+
latent_dim: int = None,
|
| 158 |
+
decoder_dim: int = 1536,
|
| 159 |
+
decoder_rates: List[int] = [8, 8, 4, 2],
|
| 160 |
+
n_codebooks: int = 9,
|
| 161 |
+
codebook_size: int = 1024,
|
| 162 |
+
codebook_dim: Union[int, list] = 8,
|
| 163 |
+
quantizer_dropout: bool = False,
|
| 164 |
+
sample_rate: int = 44100,
|
| 165 |
+
):
|
| 166 |
+
super().__init__()
|
| 167 |
+
|
| 168 |
+
self.encoder_dim = encoder_dim
|
| 169 |
+
self.encoder_rates = encoder_rates
|
| 170 |
+
self.decoder_dim = decoder_dim
|
| 171 |
+
self.decoder_rates = decoder_rates
|
| 172 |
+
self.sample_rate = sample_rate
|
| 173 |
+
|
| 174 |
+
if latent_dim is None:
|
| 175 |
+
latent_dim = encoder_dim * (2 ** len(encoder_rates))
|
| 176 |
+
|
| 177 |
+
self.latent_dim = latent_dim
|
| 178 |
+
|
| 179 |
+
self.hop_length = np.prod(encoder_rates)
|
| 180 |
+
self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim)
|
| 181 |
+
|
| 182 |
+
self.n_codebooks = n_codebooks
|
| 183 |
+
self.codebook_size = codebook_size
|
| 184 |
+
self.codebook_dim = codebook_dim
|
| 185 |
+
self.quantizer = ResidualVectorQuantize(
|
| 186 |
+
input_dim=latent_dim,
|
| 187 |
+
n_codebooks=n_codebooks,
|
| 188 |
+
codebook_size=codebook_size,
|
| 189 |
+
codebook_dim=codebook_dim,
|
| 190 |
+
quantizer_dropout=quantizer_dropout,
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
self.decoder = Decoder(
|
| 194 |
+
latent_dim,
|
| 195 |
+
decoder_dim,
|
| 196 |
+
decoder_rates,
|
| 197 |
+
)
|
| 198 |
+
self.sample_rate = sample_rate
|
| 199 |
+
self.apply(init_weights)
|
| 200 |
+
|
| 201 |
+
self.delay = self.get_delay()
|
| 202 |
+
|
| 203 |
+
def preprocess(self, audio_data, sample_rate):
|
| 204 |
+
if sample_rate is None:
|
| 205 |
+
sample_rate = self.sample_rate
|
| 206 |
+
assert sample_rate == self.sample_rate
|
| 207 |
+
|
| 208 |
+
length = audio_data.shape[-1]
|
| 209 |
+
right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
|
| 210 |
+
audio_data = nn.functional.pad(audio_data, (0, right_pad))
|
| 211 |
+
|
| 212 |
+
return audio_data
|
| 213 |
+
|
| 214 |
+
def encode(
|
| 215 |
+
self,
|
| 216 |
+
audio_data: torch.Tensor,
|
| 217 |
+
n_quantizers: int = None,
|
| 218 |
+
):
|
| 219 |
+
"""Encode given audio data and return quantized latent codes
|
| 220 |
+
|
| 221 |
+
Parameters
|
| 222 |
+
----------
|
| 223 |
+
audio_data : Tensor[B x 1 x T]
|
| 224 |
+
Audio data to encode
|
| 225 |
+
n_quantizers : int, optional
|
| 226 |
+
Number of quantizers to use, by default None
|
| 227 |
+
If None, all quantizers are used.
|
| 228 |
+
|
| 229 |
+
Returns
|
| 230 |
+
-------
|
| 231 |
+
dict
|
| 232 |
+
A dictionary with the following keys:
|
| 233 |
+
"z" : Tensor[B x D x T]
|
| 234 |
+
Quantized continuous representation of input
|
| 235 |
+
"codes" : Tensor[B x N x T]
|
| 236 |
+
Codebook indices for each codebook
|
| 237 |
+
(quantized discrete representation of input)
|
| 238 |
+
"latents" : Tensor[B x N*D x T]
|
| 239 |
+
Projected latents (continuous representation of input before quantization)
|
| 240 |
+
"vq/commitment_loss" : Tensor[1]
|
| 241 |
+
Commitment loss to train encoder to predict vectors closer to codebook
|
| 242 |
+
entries
|
| 243 |
+
"vq/codebook_loss" : Tensor[1]
|
| 244 |
+
Codebook loss to update the codebook
|
| 245 |
+
"length" : int
|
| 246 |
+
Number of samples in input audio
|
| 247 |
+
"""
|
| 248 |
+
z = self.encoder(audio_data)
|
| 249 |
+
z, codes, latents, commitment_loss, codebook_loss = self.quantizer(z, n_quantizers)
|
| 250 |
+
return z, codes, latents, commitment_loss, codebook_loss
|
| 251 |
+
|
| 252 |
+
def decode(self, z: torch.Tensor):
|
| 253 |
+
"""Decode given latent codes and return audio data
|
| 254 |
+
|
| 255 |
+
Parameters
|
| 256 |
+
----------
|
| 257 |
+
z : Tensor[B x D x T]
|
| 258 |
+
Quantized continuous representation of input
|
| 259 |
+
length : int, optional
|
| 260 |
+
Number of samples in output audio, by default None
|
| 261 |
+
|
| 262 |
+
Returns
|
| 263 |
+
-------
|
| 264 |
+
dict
|
| 265 |
+
A dictionary with the following keys:
|
| 266 |
+
"audio" : Tensor[B x 1 x length]
|
| 267 |
+
Decoded audio data.
|
| 268 |
+
"""
|
| 269 |
+
return self.decoder(z)
|
| 270 |
+
|
| 271 |
+
def forward(
|
| 272 |
+
self,
|
| 273 |
+
audio_data: torch.Tensor,
|
| 274 |
+
sample_rate: int = None,
|
| 275 |
+
n_quantizers: int = None,
|
| 276 |
+
):
|
| 277 |
+
"""Model forward pass
|
| 278 |
+
|
| 279 |
+
Parameters
|
| 280 |
+
----------
|
| 281 |
+
audio_data : Tensor[B x 1 x T]
|
| 282 |
+
Audio data to encode
|
| 283 |
+
sample_rate : int, optional
|
| 284 |
+
Sample rate of audio data in Hz, by default None
|
| 285 |
+
If None, defaults to `self.sample_rate`
|
| 286 |
+
n_quantizers : int, optional
|
| 287 |
+
Number of quantizers to use, by default None.
|
| 288 |
+
If None, all quantizers are used.
|
| 289 |
+
|
| 290 |
+
Returns
|
| 291 |
+
-------
|
| 292 |
+
dict
|
| 293 |
+
A dictionary with the following keys:
|
| 294 |
+
"z" : Tensor[B x D x T]
|
| 295 |
+
Quantized continuous representation of input
|
| 296 |
+
"codes" : Tensor[B x N x T]
|
| 297 |
+
Codebook indices for each codebook
|
| 298 |
+
(quantized discrete representation of input)
|
| 299 |
+
"latents" : Tensor[B x N*D x T]
|
| 300 |
+
Projected latents (continuous representation of input before quantization)
|
| 301 |
+
"vq/commitment_loss" : Tensor[1]
|
| 302 |
+
Commitment loss to train encoder to predict vectors closer to codebook
|
| 303 |
+
entries
|
| 304 |
+
"vq/codebook_loss" : Tensor[1]
|
| 305 |
+
Codebook loss to update the codebook
|
| 306 |
+
"length" : int
|
| 307 |
+
Number of samples in input audio
|
| 308 |
+
"audio" : Tensor[B x 1 x length]
|
| 309 |
+
Decoded audio data.
|
| 310 |
+
"""
|
| 311 |
+
length = audio_data.shape[-1]
|
| 312 |
+
audio_data = self.preprocess(audio_data, sample_rate)
|
| 313 |
+
z, codes, latents, commitment_loss, codebook_loss = self.encode(audio_data, n_quantizers)
|
| 314 |
+
|
| 315 |
+
x = self.decode(z)
|
| 316 |
+
return {
|
| 317 |
+
"audio": x[..., :length],
|
| 318 |
+
"z": z,
|
| 319 |
+
"codes": codes,
|
| 320 |
+
"latents": latents,
|
| 321 |
+
"vq/commitment_loss": commitment_loss,
|
| 322 |
+
"vq/codebook_loss": codebook_loss,
|
| 323 |
+
}
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
if __name__ == "__main__":
|
| 327 |
+
import numpy as np
|
| 328 |
+
from functools import partial
|
| 329 |
+
|
| 330 |
+
model = DAC().to("cpu")
|
| 331 |
+
|
| 332 |
+
for n, m in model.named_modules():
|
| 333 |
+
o = m.extra_repr()
|
| 334 |
+
p = sum([np.prod(p.size()) for p in m.parameters()])
|
| 335 |
+
fn = lambda o, p: o + f" {p / 1e6:<.3f}M params."
|
| 336 |
+
setattr(m, "extra_repr", partial(fn, o=o, p=p))
|
| 337 |
+
print(model)
|
| 338 |
+
print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()]))
|
| 339 |
+
|
| 340 |
+
length = 88200 * 2
|
| 341 |
+
x = torch.randn(1, 1, length).to(model.device)
|
| 342 |
+
x.requires_grad_(True)
|
| 343 |
+
x.retain_grad()
|
| 344 |
+
|
| 345 |
+
# Make a forward pass
|
| 346 |
+
out = model(x)["audio"]
|
| 347 |
+
print("Input shape:", x.shape)
|
| 348 |
+
print("Output shape:", out.shape)
|
| 349 |
+
|
| 350 |
+
# Create gradient variable
|
| 351 |
+
grad = torch.zeros_like(out)
|
| 352 |
+
grad[:, :, grad.shape[-1] // 2] = 1
|
| 353 |
+
|
| 354 |
+
# Make a backward pass
|
| 355 |
+
out.backward(grad)
|
| 356 |
+
|
| 357 |
+
# Check non-zero values
|
| 358 |
+
gradmap = x.grad.squeeze(0)
|
| 359 |
+
gradmap = (gradmap != 0).sum(0) # sum across features
|
| 360 |
+
rf = (gradmap != 0).sum()
|
| 361 |
+
|
| 362 |
+
print(f"Receptive field: {rf.item()}")
|
| 363 |
+
|
| 364 |
+
x = AudioSignal(torch.randn(1, 1, 44100 * 60), 44100)
|
| 365 |
+
model.decompress(model.compress(x, verbose=True), verbose=True)
|
higgs_audio/audio_processing/descriptaudiocodec/dac/nn/layers.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from einops import rearrange
|
| 6 |
+
from torch.nn.utils import weight_norm
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def WNConv1d(*args, **kwargs):
|
| 10 |
+
return weight_norm(nn.Conv1d(*args, **kwargs))
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def WNConvTranspose1d(*args, **kwargs):
|
| 14 |
+
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# Scripting this brings model speed up 1.4x
|
| 18 |
+
@torch.jit.script
|
| 19 |
+
def snake(x, alpha):
|
| 20 |
+
shape = x.shape
|
| 21 |
+
x = x.reshape(shape[0], shape[1], -1)
|
| 22 |
+
x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
|
| 23 |
+
x = x.reshape(shape)
|
| 24 |
+
return x
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class Snake1d(nn.Module):
|
| 28 |
+
def __init__(self, channels):
|
| 29 |
+
super().__init__()
|
| 30 |
+
self.alpha = nn.Parameter(torch.ones(1, channels, 1))
|
| 31 |
+
|
| 32 |
+
def forward(self, x):
|
| 33 |
+
return snake(x, self.alpha)
|
higgs_audio/audio_processing/descriptaudiocodec/dac/nn/quantize.py
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Union
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from einops import rearrange
|
| 8 |
+
from torch.nn.utils import weight_norm
|
| 9 |
+
|
| 10 |
+
from dac.nn.layers import WNConv1d
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class VectorQuantize(nn.Module):
|
| 14 |
+
"""
|
| 15 |
+
Implementation of VQ similar to Karpathy's repo:
|
| 16 |
+
https://github.com/karpathy/deep-vector-quantization
|
| 17 |
+
Additionally uses following tricks from Improved VQGAN
|
| 18 |
+
(https://arxiv.org/pdf/2110.04627.pdf):
|
| 19 |
+
1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
|
| 20 |
+
for improved codebook usage
|
| 21 |
+
2. l2-normalized codes: Converts euclidean distance to cosine similarity which
|
| 22 |
+
improves training stability
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
|
| 26 |
+
super().__init__()
|
| 27 |
+
self.codebook_size = codebook_size
|
| 28 |
+
self.codebook_dim = codebook_dim
|
| 29 |
+
|
| 30 |
+
self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
|
| 31 |
+
self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
|
| 32 |
+
self.codebook = nn.Embedding(codebook_size, codebook_dim)
|
| 33 |
+
|
| 34 |
+
def forward(self, z):
|
| 35 |
+
"""Quantized the input tensor using a fixed codebook and returns
|
| 36 |
+
the corresponding codebook vectors
|
| 37 |
+
|
| 38 |
+
Parameters
|
| 39 |
+
----------
|
| 40 |
+
z : Tensor[B x D x T]
|
| 41 |
+
|
| 42 |
+
Returns
|
| 43 |
+
-------
|
| 44 |
+
Tensor[B x D x T]
|
| 45 |
+
Quantized continuous representation of input
|
| 46 |
+
Tensor[1]
|
| 47 |
+
Commitment loss to train encoder to predict vectors closer to codebook
|
| 48 |
+
entries
|
| 49 |
+
Tensor[1]
|
| 50 |
+
Codebook loss to update the codebook
|
| 51 |
+
Tensor[B x T]
|
| 52 |
+
Codebook indices (quantized discrete representation of input)
|
| 53 |
+
Tensor[B x D x T]
|
| 54 |
+
Projected latents (continuous representation of input before quantization)
|
| 55 |
+
"""
|
| 56 |
+
|
| 57 |
+
# Factorized codes (ViT-VQGAN) Project input into low-dimensional space
|
| 58 |
+
z_e = self.in_proj(z) # z_e : (B x D x T)
|
| 59 |
+
z_q, indices = self.decode_latents(z_e)
|
| 60 |
+
|
| 61 |
+
commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
|
| 62 |
+
codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
|
| 63 |
+
|
| 64 |
+
z_q = z_e + (z_q - z_e).detach() # noop in forward pass, straight-through gradient estimator in backward pass
|
| 65 |
+
|
| 66 |
+
z_q = self.out_proj(z_q)
|
| 67 |
+
|
| 68 |
+
return z_q, commitment_loss, codebook_loss, indices, z_e
|
| 69 |
+
|
| 70 |
+
def embed_code(self, embed_id):
|
| 71 |
+
return F.embedding(embed_id, self.codebook.weight)
|
| 72 |
+
|
| 73 |
+
def decode_code(self, embed_id):
|
| 74 |
+
return self.embed_code(embed_id).transpose(1, 2)
|
| 75 |
+
|
| 76 |
+
def decode_latents(self, latents):
|
| 77 |
+
encodings = rearrange(latents, "b d t -> (b t) d")
|
| 78 |
+
codebook = self.codebook.weight # codebook: (N x D)
|
| 79 |
+
|
| 80 |
+
# L2 normalize encodings and codebook (ViT-VQGAN)
|
| 81 |
+
encodings = F.normalize(encodings)
|
| 82 |
+
codebook = F.normalize(codebook)
|
| 83 |
+
|
| 84 |
+
# Compute euclidean distance with codebook
|
| 85 |
+
dist = (
|
| 86 |
+
encodings.pow(2).sum(1, keepdim=True)
|
| 87 |
+
- 2 * encodings @ codebook.t()
|
| 88 |
+
+ codebook.pow(2).sum(1, keepdim=True).t()
|
| 89 |
+
)
|
| 90 |
+
indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
|
| 91 |
+
z_q = self.decode_code(indices)
|
| 92 |
+
return z_q, indices
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class ResidualVectorQuantize(nn.Module):
|
| 96 |
+
"""
|
| 97 |
+
Introduced in SoundStream: An end2end neural audio codec
|
| 98 |
+
https://arxiv.org/abs/2107.03312
|
| 99 |
+
"""
|
| 100 |
+
|
| 101 |
+
def __init__(
|
| 102 |
+
self,
|
| 103 |
+
input_dim: int = 512,
|
| 104 |
+
n_codebooks: int = 9,
|
| 105 |
+
codebook_size: int = 1024,
|
| 106 |
+
codebook_dim: Union[int, list] = 8,
|
| 107 |
+
quantizer_dropout: float = 0.0,
|
| 108 |
+
):
|
| 109 |
+
super().__init__()
|
| 110 |
+
if isinstance(codebook_dim, int):
|
| 111 |
+
codebook_dim = [codebook_dim for _ in range(n_codebooks)]
|
| 112 |
+
|
| 113 |
+
self.n_codebooks = n_codebooks
|
| 114 |
+
self.codebook_dim = codebook_dim
|
| 115 |
+
self.codebook_size = codebook_size
|
| 116 |
+
|
| 117 |
+
self.quantizers = nn.ModuleList(
|
| 118 |
+
[VectorQuantize(input_dim, codebook_size, codebook_dim[i]) for i in range(n_codebooks)]
|
| 119 |
+
)
|
| 120 |
+
self.quantizer_dropout = quantizer_dropout
|
| 121 |
+
|
| 122 |
+
def forward(self, z, n_quantizers: int = None):
|
| 123 |
+
"""Quantized the input tensor using a fixed set of `n` codebooks and returns
|
| 124 |
+
the corresponding codebook vectors
|
| 125 |
+
Parameters
|
| 126 |
+
----------
|
| 127 |
+
z : Tensor[B x D x T]
|
| 128 |
+
n_quantizers : int, optional
|
| 129 |
+
No. of quantizers to use
|
| 130 |
+
(n_quantizers < self.n_codebooks ex: for quantizer dropout)
|
| 131 |
+
Note: if `self.quantizer_dropout` is True, this argument is ignored
|
| 132 |
+
when in training mode, and a random number of quantizers is used.
|
| 133 |
+
Returns
|
| 134 |
+
-------
|
| 135 |
+
dict
|
| 136 |
+
A dictionary with the following keys:
|
| 137 |
+
|
| 138 |
+
"z" : Tensor[B x D x T]
|
| 139 |
+
Quantized continuous representation of input
|
| 140 |
+
"codes" : Tensor[B x N x T]
|
| 141 |
+
Codebook indices for each codebook
|
| 142 |
+
(quantized discrete representation of input)
|
| 143 |
+
"latents" : Tensor[B x N*D x T]
|
| 144 |
+
Projected latents (continuous representation of input before quantization)
|
| 145 |
+
"vq/commitment_loss" : Tensor[1]
|
| 146 |
+
Commitment loss to train encoder to predict vectors closer to codebook
|
| 147 |
+
entries
|
| 148 |
+
"vq/codebook_loss" : Tensor[1]
|
| 149 |
+
Codebook loss to update the codebook
|
| 150 |
+
"""
|
| 151 |
+
z_q = 0
|
| 152 |
+
residual = z
|
| 153 |
+
commitment_loss = 0
|
| 154 |
+
codebook_loss = 0
|
| 155 |
+
|
| 156 |
+
codebook_indices = []
|
| 157 |
+
latents = []
|
| 158 |
+
|
| 159 |
+
if n_quantizers is None:
|
| 160 |
+
n_quantizers = self.n_codebooks
|
| 161 |
+
if self.training:
|
| 162 |
+
n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
|
| 163 |
+
dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
|
| 164 |
+
n_dropout = int(z.shape[0] * self.quantizer_dropout)
|
| 165 |
+
n_quantizers[:n_dropout] = dropout[:n_dropout]
|
| 166 |
+
n_quantizers = n_quantizers.to(z.device)
|
| 167 |
+
|
| 168 |
+
for i, quantizer in enumerate(self.quantizers):
|
| 169 |
+
if self.training is False and i >= n_quantizers:
|
| 170 |
+
break
|
| 171 |
+
|
| 172 |
+
z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(residual)
|
| 173 |
+
|
| 174 |
+
# Create mask to apply quantizer dropout
|
| 175 |
+
mask = torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
|
| 176 |
+
z_q = z_q + z_q_i * mask[:, None, None]
|
| 177 |
+
residual = residual - z_q_i
|
| 178 |
+
|
| 179 |
+
# Sum losses
|
| 180 |
+
commitment_loss += (commitment_loss_i * mask).mean()
|
| 181 |
+
codebook_loss += (codebook_loss_i * mask).mean()
|
| 182 |
+
|
| 183 |
+
codebook_indices.append(indices_i)
|
| 184 |
+
latents.append(z_e_i)
|
| 185 |
+
|
| 186 |
+
codes = torch.stack(codebook_indices, dim=1)
|
| 187 |
+
latents = torch.cat(latents, dim=1)
|
| 188 |
+
|
| 189 |
+
return z_q, codes, latents, commitment_loss, codebook_loss
|
| 190 |
+
|
| 191 |
+
def from_codes(self, codes: torch.Tensor):
|
| 192 |
+
"""Given the quantized codes, reconstruct the continuous representation
|
| 193 |
+
Parameters
|
| 194 |
+
----------
|
| 195 |
+
codes : Tensor[B x N x T]
|
| 196 |
+
Quantized discrete representation of input
|
| 197 |
+
Returns
|
| 198 |
+
-------
|
| 199 |
+
Tensor[B x D x T]
|
| 200 |
+
Quantized continuous representation of input
|
| 201 |
+
"""
|
| 202 |
+
z_q = 0.0
|
| 203 |
+
z_p = []
|
| 204 |
+
n_codebooks = codes.shape[1]
|
| 205 |
+
for i in range(n_codebooks):
|
| 206 |
+
z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
|
| 207 |
+
z_p.append(z_p_i)
|
| 208 |
+
|
| 209 |
+
z_q_i = self.quantizers[i].out_proj(z_p_i)
|
| 210 |
+
z_q = z_q + z_q_i
|
| 211 |
+
return z_q, torch.cat(z_p, dim=1), codes
|
| 212 |
+
|
| 213 |
+
def from_latents(self, latents: torch.Tensor):
|
| 214 |
+
"""Given the unquantized latents, reconstruct the
|
| 215 |
+
continuous representation after quantization.
|
| 216 |
+
|
| 217 |
+
Parameters
|
| 218 |
+
----------
|
| 219 |
+
latents : Tensor[B x N x T]
|
| 220 |
+
Continuous representation of input after projection
|
| 221 |
+
|
| 222 |
+
Returns
|
| 223 |
+
-------
|
| 224 |
+
Tensor[B x D x T]
|
| 225 |
+
Quantized representation of full-projected space
|
| 226 |
+
Tensor[B x D x T]
|
| 227 |
+
Quantized representation of latent space
|
| 228 |
+
"""
|
| 229 |
+
z_q = 0
|
| 230 |
+
z_p = []
|
| 231 |
+
codes = []
|
| 232 |
+
dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
|
| 233 |
+
|
| 234 |
+
n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[0]
|
| 235 |
+
for i in range(n_codebooks):
|
| 236 |
+
j, k = dims[i], dims[i + 1]
|
| 237 |
+
z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
|
| 238 |
+
z_p.append(z_p_i)
|
| 239 |
+
codes.append(codes_i)
|
| 240 |
+
|
| 241 |
+
z_q_i = self.quantizers[i].out_proj(z_p_i)
|
| 242 |
+
z_q = z_q + z_q_i
|
| 243 |
+
|
| 244 |
+
return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
if __name__ == "__main__":
|
| 248 |
+
rvq = ResidualVectorQuantize(quantizer_dropout=True)
|
| 249 |
+
x = torch.randn(16, 512, 80)
|
| 250 |
+
y = rvq(x)
|
| 251 |
+
print(y["latents"].shape)
|
higgs_audio/audio_processing/higgs_audio_tokenizer.py
ADDED
|
@@ -0,0 +1,341 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Based on code from: https://github.com/zhenye234/xcodec
|
| 2 |
+
# Licensed under MIT License
|
| 3 |
+
# Modifications by BosonAI
|
| 4 |
+
|
| 5 |
+
import math
|
| 6 |
+
import os
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from typing import Optional, Union, Sequence
|
| 11 |
+
import numpy as np
|
| 12 |
+
from transformers import AutoModel
|
| 13 |
+
import torchaudio
|
| 14 |
+
import json
|
| 15 |
+
import librosa
|
| 16 |
+
from huggingface_hub import snapshot_download
|
| 17 |
+
|
| 18 |
+
from vector_quantize_pytorch import ResidualFSQ
|
| 19 |
+
from .descriptaudiocodec.dac.model import dac as dac2
|
| 20 |
+
from .quantization.vq import ResidualVectorQuantizer
|
| 21 |
+
from .semantic_module import Encoder, Decoder
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class EncodedResult:
|
| 25 |
+
def __init__(self, audio_codes):
|
| 26 |
+
self.audio_codes = audio_codes
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class HiggsAudioFeatureExtractor(nn.Module):
|
| 30 |
+
def __init__(self, sampling_rate=16000):
|
| 31 |
+
super().__init__()
|
| 32 |
+
self.sampling_rate = sampling_rate
|
| 33 |
+
|
| 34 |
+
def forward(self, raw_audio, sampling_rate=16000, return_tensors="pt"):
|
| 35 |
+
# Convert from librosa to torch
|
| 36 |
+
audio_signal = torch.tensor(raw_audio)
|
| 37 |
+
audio_signal = audio_signal.unsqueeze(0)
|
| 38 |
+
if len(audio_signal.shape) < 3:
|
| 39 |
+
audio_signal = audio_signal.unsqueeze(0)
|
| 40 |
+
return {"input_values": audio_signal}
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class HiggsAudioTokenizer(nn.Module):
|
| 44 |
+
def __init__(
|
| 45 |
+
self,
|
| 46 |
+
n_filters: int = 32,
|
| 47 |
+
D: int = 128,
|
| 48 |
+
target_bandwidths: Sequence[Union[int, float]] = [1, 1.5, 2, 4, 6],
|
| 49 |
+
ratios: Sequence[int] = [8, 5, 4, 2], # downsampling by 320
|
| 50 |
+
sample_rate: int = 16000,
|
| 51 |
+
bins: int = 1024,
|
| 52 |
+
n_q: int = 8,
|
| 53 |
+
codebook_dim: int = None,
|
| 54 |
+
normalize: bool = False,
|
| 55 |
+
causal: bool = False,
|
| 56 |
+
semantic_techer: str = "hubert_base_general",
|
| 57 |
+
last_layer_semantic: bool = True,
|
| 58 |
+
merge_mode: str = "concat",
|
| 59 |
+
downsample_mode: str = "step_down",
|
| 60 |
+
semantic_mode: str = "classic",
|
| 61 |
+
vq_scale: int = 1,
|
| 62 |
+
semantic_sample_rate: int = None,
|
| 63 |
+
device: str = "cuda",
|
| 64 |
+
):
|
| 65 |
+
super().__init__()
|
| 66 |
+
self.hop_length = np.prod(ratios)
|
| 67 |
+
self.semantic_techer = semantic_techer
|
| 68 |
+
|
| 69 |
+
self.frame_rate = math.ceil(sample_rate / np.prod(ratios)) # 50 Hz
|
| 70 |
+
|
| 71 |
+
self.target_bandwidths = target_bandwidths
|
| 72 |
+
self.n_q = n_q
|
| 73 |
+
self.sample_rate = sample_rate
|
| 74 |
+
self.encoder = dac2.Encoder(64, ratios, D)
|
| 75 |
+
|
| 76 |
+
self.decoder_2 = dac2.Decoder(D, 1024, ratios)
|
| 77 |
+
self.last_layer_semantic = last_layer_semantic
|
| 78 |
+
self.device = device
|
| 79 |
+
if semantic_techer == "hubert_base":
|
| 80 |
+
self.semantic_model = AutoModel.from_pretrained("facebook/hubert-base-ls960")
|
| 81 |
+
self.semantic_sample_rate = 16000
|
| 82 |
+
self.semantic_dim = 768
|
| 83 |
+
self.encoder_semantic_dim = 768
|
| 84 |
+
|
| 85 |
+
elif semantic_techer == "wavlm_base_plus":
|
| 86 |
+
self.semantic_model = AutoModel.from_pretrained("microsoft/wavlm-base-plus")
|
| 87 |
+
self.semantic_sample_rate = 16000
|
| 88 |
+
self.semantic_dim = 768
|
| 89 |
+
self.encoder_semantic_dim = 768
|
| 90 |
+
|
| 91 |
+
elif semantic_techer == "hubert_base_general":
|
| 92 |
+
self.semantic_model = AutoModel.from_pretrained("ZhenYe234/hubert_base_general_audio")
|
| 93 |
+
self.semantic_sample_rate = 16000
|
| 94 |
+
self.semantic_dim = 768
|
| 95 |
+
self.encoder_semantic_dim = 768
|
| 96 |
+
|
| 97 |
+
# Overwrite semantic model sr to ensure semantic_downsample_factor is an integer
|
| 98 |
+
if semantic_sample_rate is not None:
|
| 99 |
+
self.semantic_sample_rate = semantic_sample_rate
|
| 100 |
+
|
| 101 |
+
self.semantic_model.eval()
|
| 102 |
+
|
| 103 |
+
# make the semantic model parameters do not need gradient
|
| 104 |
+
for param in self.semantic_model.parameters():
|
| 105 |
+
param.requires_grad = False
|
| 106 |
+
|
| 107 |
+
self.semantic_downsample_factor = int(self.hop_length / (self.sample_rate / self.semantic_sample_rate) / 320)
|
| 108 |
+
|
| 109 |
+
self.quantizer_dim = int((D + self.encoder_semantic_dim) // vq_scale)
|
| 110 |
+
self.encoder_semantic = Encoder(input_channels=self.semantic_dim, encode_channels=self.encoder_semantic_dim)
|
| 111 |
+
self.decoder_semantic = Decoder(
|
| 112 |
+
code_dim=self.encoder_semantic_dim,
|
| 113 |
+
output_channels=self.semantic_dim,
|
| 114 |
+
decode_channels=self.semantic_dim,
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
# out_D=D+768
|
| 118 |
+
if isinstance(bins, int): # RVQ
|
| 119 |
+
self.quantizer = ResidualVectorQuantizer(
|
| 120 |
+
dimension=self.quantizer_dim,
|
| 121 |
+
codebook_dim=codebook_dim,
|
| 122 |
+
n_q=n_q,
|
| 123 |
+
bins=bins,
|
| 124 |
+
)
|
| 125 |
+
self.quantizer_type = "RVQ"
|
| 126 |
+
else: # RFSQ
|
| 127 |
+
self.quantizer = ResidualFSQ(dim=self.quantizer_dim, levels=bins, num_quantizers=n_q)
|
| 128 |
+
self.quantizer_type = "RFSQ"
|
| 129 |
+
|
| 130 |
+
self.fc_prior = nn.Linear(D + self.encoder_semantic_dim, self.quantizer_dim)
|
| 131 |
+
self.fc_post1 = nn.Linear(self.quantizer_dim, self.encoder_semantic_dim)
|
| 132 |
+
self.fc_post2 = nn.Linear(self.quantizer_dim, D)
|
| 133 |
+
|
| 134 |
+
self.downsample_mode = downsample_mode
|
| 135 |
+
if downsample_mode == "avg":
|
| 136 |
+
self.semantic_pooling = nn.AvgPool1d(
|
| 137 |
+
kernel_size=self.semantic_downsample_factor,
|
| 138 |
+
stride=self.semantic_downsample_factor,
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
self.audio_tokenizer_feature_extractor = HiggsAudioFeatureExtractor(sampling_rate=self.sample_rate)
|
| 142 |
+
|
| 143 |
+
@property
|
| 144 |
+
def tps(self):
|
| 145 |
+
return self.frame_rate
|
| 146 |
+
|
| 147 |
+
@property
|
| 148 |
+
def sampling_rate(self):
|
| 149 |
+
return self.sample_rate
|
| 150 |
+
|
| 151 |
+
@property
|
| 152 |
+
def num_codebooks(self):
|
| 153 |
+
return self.n_q
|
| 154 |
+
|
| 155 |
+
@property
|
| 156 |
+
def codebook_size(self):
|
| 157 |
+
return self.quantizer_dim
|
| 158 |
+
|
| 159 |
+
def get_last_layer(self):
|
| 160 |
+
return self.decoder.layers[-1].weight
|
| 161 |
+
|
| 162 |
+
def calculate_rec_loss(self, rec, target):
|
| 163 |
+
target = target / target.norm(dim=-1, keepdim=True)
|
| 164 |
+
rec = rec / rec.norm(dim=-1, keepdim=True)
|
| 165 |
+
rec_loss = (1 - (target * rec).sum(-1)).mean()
|
| 166 |
+
|
| 167 |
+
return rec_loss
|
| 168 |
+
|
| 169 |
+
@torch.no_grad()
|
| 170 |
+
def get_regress_target(self, x):
|
| 171 |
+
x = torchaudio.functional.resample(x, self.sample_rate, self.semantic_sample_rate)
|
| 172 |
+
|
| 173 |
+
if (
|
| 174 |
+
self.semantic_techer == "hubert_base"
|
| 175 |
+
or self.semantic_techer == "hubert_base_general"
|
| 176 |
+
or self.semantic_techer == "wavlm_base_plus"
|
| 177 |
+
):
|
| 178 |
+
x = x[:, 0, :]
|
| 179 |
+
x = F.pad(x, (160, 160))
|
| 180 |
+
target = self.semantic_model(x, output_hidden_states=True).hidden_states
|
| 181 |
+
target = torch.stack(target, dim=1) # .transpose(-1, -2)#.flatten(start_dim=1, end_dim=2)
|
| 182 |
+
|
| 183 |
+
# average for all layers
|
| 184 |
+
target = target.mean(1)
|
| 185 |
+
# target = target[9]
|
| 186 |
+
# if self.hop_length > 320:
|
| 187 |
+
# target = self.semantic_pooling(target.transpose(1, 2)).transpose(1, 2)
|
| 188 |
+
|
| 189 |
+
elif self.semantic_techer == "w2v_bert2":
|
| 190 |
+
target = self.semantic_model(x)
|
| 191 |
+
|
| 192 |
+
elif self.semantic_techer.startswith("whisper"):
|
| 193 |
+
if self.last_layer_semantic:
|
| 194 |
+
target = self.semantic_model(x, avg_layers=False)
|
| 195 |
+
else:
|
| 196 |
+
target = self.semantic_model(x, avg_layers=True)
|
| 197 |
+
|
| 198 |
+
elif self.semantic_techer.startswith("mert_music"):
|
| 199 |
+
if self.last_layer_semantic:
|
| 200 |
+
target = self.semantic_model(x, avg_layers=False)
|
| 201 |
+
else:
|
| 202 |
+
target = self.semantic_model(x, avg_layers=True)
|
| 203 |
+
|
| 204 |
+
elif self.semantic_techer.startswith("qwen_audio_omni"):
|
| 205 |
+
target = self.semantic_model(x)
|
| 206 |
+
|
| 207 |
+
if self.downsample_mode == "step_down":
|
| 208 |
+
if self.semantic_downsample_factor > 1:
|
| 209 |
+
target = target[:, :: self.semantic_downsample_factor, :]
|
| 210 |
+
|
| 211 |
+
elif self.downsample_mode == "avg":
|
| 212 |
+
target = self.semantic_pooling(target.transpose(1, 2)).transpose(1, 2)
|
| 213 |
+
return target
|
| 214 |
+
|
| 215 |
+
def forward(self, x: torch.Tensor, bw: int):
|
| 216 |
+
e_semantic_input = self.get_regress_target(x).detach()
|
| 217 |
+
|
| 218 |
+
e_semantic = self.encoder_semantic(e_semantic_input.transpose(1, 2))
|
| 219 |
+
e_acoustic = self.encoder(x)
|
| 220 |
+
|
| 221 |
+
e = torch.cat([e_acoustic, e_semantic], dim=1)
|
| 222 |
+
|
| 223 |
+
e = self.fc_prior(e.transpose(1, 2))
|
| 224 |
+
|
| 225 |
+
if self.quantizer_type == "RVQ":
|
| 226 |
+
e = e.transpose(1, 2)
|
| 227 |
+
quantized, codes, bandwidth, commit_loss = self.quantizer(e, self.frame_rate, bw)
|
| 228 |
+
quantized = quantized.transpose(1, 2)
|
| 229 |
+
else:
|
| 230 |
+
quantized, codes = self.quantizer(e)
|
| 231 |
+
commit_loss = torch.tensor(0.0)
|
| 232 |
+
|
| 233 |
+
quantized_semantic = self.fc_post1(quantized).transpose(1, 2)
|
| 234 |
+
quantized_acoustic = self.fc_post2(quantized).transpose(1, 2)
|
| 235 |
+
|
| 236 |
+
o = self.decoder_2(quantized_acoustic)
|
| 237 |
+
|
| 238 |
+
o_semantic = self.decoder_semantic(quantized_semantic)
|
| 239 |
+
semantic_recon_loss = F.mse_loss(e_semantic_input.transpose(1, 2).detach(), o_semantic)
|
| 240 |
+
|
| 241 |
+
return o, commit_loss, semantic_recon_loss, None
|
| 242 |
+
|
| 243 |
+
def encode(
|
| 244 |
+
self,
|
| 245 |
+
audio_path_or_wv,
|
| 246 |
+
sr=None,
|
| 247 |
+
loudness_normalize=False,
|
| 248 |
+
loudness_threshold=-23.0,
|
| 249 |
+
):
|
| 250 |
+
if isinstance(audio_path_or_wv, str):
|
| 251 |
+
wv, sr = librosa.load(audio_path_or_wv, mono=True, sr=None)
|
| 252 |
+
else:
|
| 253 |
+
wv = audio_path_or_wv
|
| 254 |
+
assert sr is not None
|
| 255 |
+
if loudness_normalize:
|
| 256 |
+
import pyloudnorm as pyln
|
| 257 |
+
|
| 258 |
+
meter = pyln.Meter(sr)
|
| 259 |
+
l = meter.integrated_loudness(wv)
|
| 260 |
+
wv = pyln.normalize.loudness(wv, l, loudness_threshold)
|
| 261 |
+
if sr != self.sampling_rate:
|
| 262 |
+
wv = librosa.resample(wv, orig_sr=sr, target_sr=self.sampling_rate)
|
| 263 |
+
if self.audio_tokenizer_feature_extractor is not None:
|
| 264 |
+
inputs = self.audio_tokenizer_feature_extractor(
|
| 265 |
+
raw_audio=wv,
|
| 266 |
+
sampling_rate=self.audio_tokenizer_feature_extractor.sampling_rate,
|
| 267 |
+
return_tensors="pt",
|
| 268 |
+
)
|
| 269 |
+
input_values = inputs["input_values"].to(self.device)
|
| 270 |
+
else:
|
| 271 |
+
input_values = torch.from_numpy(wv).float().unsqueeze(0)
|
| 272 |
+
with torch.no_grad():
|
| 273 |
+
encoder_outputs = self._xcodec_encode(input_values)
|
| 274 |
+
vq_code = encoder_outputs.audio_codes[0]
|
| 275 |
+
return vq_code
|
| 276 |
+
|
| 277 |
+
def _xcodec_encode(self, x: torch.Tensor, target_bw: Optional[int] = None) -> torch.Tensor:
|
| 278 |
+
bw = target_bw
|
| 279 |
+
|
| 280 |
+
e_semantic_input = self.get_regress_target(x).detach()
|
| 281 |
+
|
| 282 |
+
e_semantic = self.encoder_semantic(e_semantic_input.transpose(1, 2))
|
| 283 |
+
e_acoustic = self.encoder(x)
|
| 284 |
+
|
| 285 |
+
if e_acoustic.shape[2] != e_semantic.shape[2]:
|
| 286 |
+
pad_size = 160 * self.semantic_downsample_factor
|
| 287 |
+
e_acoustic = self.encoder(F.pad(x[:, 0, :], (pad_size, pad_size)).unsqueeze(0))
|
| 288 |
+
|
| 289 |
+
if e_acoustic.shape[2] != e_semantic.shape[2]:
|
| 290 |
+
if e_acoustic.shape[2] > e_semantic.shape[2]:
|
| 291 |
+
e_acoustic = e_acoustic[:, :, : e_semantic.shape[2]]
|
| 292 |
+
else:
|
| 293 |
+
e_semantic = e_semantic[:, :, : e_acoustic.shape[2]]
|
| 294 |
+
|
| 295 |
+
e = torch.cat([e_acoustic, e_semantic], dim=1)
|
| 296 |
+
|
| 297 |
+
e = self.fc_prior(e.transpose(1, 2))
|
| 298 |
+
|
| 299 |
+
if self.quantizer_type == "RVQ":
|
| 300 |
+
e = e.transpose(1, 2)
|
| 301 |
+
quantized, codes, bandwidth, commit_loss = self.quantizer(e, self.frame_rate, bw)
|
| 302 |
+
codes = codes.permute(1, 0, 2)
|
| 303 |
+
else:
|
| 304 |
+
quantized, codes = self.quantizer(e)
|
| 305 |
+
codes = codes.permute(0, 2, 1)
|
| 306 |
+
|
| 307 |
+
# return codes
|
| 308 |
+
return EncodedResult(codes)
|
| 309 |
+
|
| 310 |
+
def decode(self, vq_code: torch.Tensor) -> torch.Tensor:
|
| 311 |
+
if self.quantizer_type == "RVQ":
|
| 312 |
+
vq_code = vq_code.permute(1, 0, 2)
|
| 313 |
+
quantized = self.quantizer.decode(vq_code)
|
| 314 |
+
quantized = quantized.transpose(1, 2)
|
| 315 |
+
else:
|
| 316 |
+
vq_code = vq_code.permute(0, 2, 1)
|
| 317 |
+
quantized = self.quantizer.get_output_from_indices(vq_code)
|
| 318 |
+
quantized_acoustic = self.fc_post2(quantized).transpose(1, 2)
|
| 319 |
+
|
| 320 |
+
o = self.decoder_2(quantized_acoustic)
|
| 321 |
+
return o.cpu().numpy()
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
def load_higgs_audio_tokenizer(tokenizer_name_or_path, device="cuda"):
|
| 325 |
+
is_local = os.path.exists(tokenizer_name_or_path)
|
| 326 |
+
if not is_local:
|
| 327 |
+
tokenizer_path = snapshot_download(tokenizer_name_or_path)
|
| 328 |
+
else:
|
| 329 |
+
tokenizer_path = tokenizer_name_or_path
|
| 330 |
+
config_path = os.path.join(tokenizer_path, "config.json")
|
| 331 |
+
model_path = os.path.join(tokenizer_path, "model.pth")
|
| 332 |
+
config = json.load(open(config_path))
|
| 333 |
+
model = HiggsAudioTokenizer(
|
| 334 |
+
**config,
|
| 335 |
+
device=device,
|
| 336 |
+
)
|
| 337 |
+
parameter_dict = torch.load(model_path, map_location=device)
|
| 338 |
+
model.load_state_dict(parameter_dict, strict=False)
|
| 339 |
+
model.to(device)
|
| 340 |
+
model.eval()
|
| 341 |
+
return model
|
higgs_audio/audio_processing/quantization/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# flake8: noqa
|
| 8 |
+
from .vq import QuantizedResult, ResidualVectorQuantizer
|
higgs_audio/audio_processing/quantization/ac.py
ADDED
|
@@ -0,0 +1,301 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Arithmetic coder."""
|
| 8 |
+
|
| 9 |
+
import io
|
| 10 |
+
import math
|
| 11 |
+
import random
|
| 12 |
+
import typing as tp
|
| 13 |
+
import torch
|
| 14 |
+
|
| 15 |
+
from ..binary import BitPacker, BitUnpacker
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def build_stable_quantized_cdf(
|
| 19 |
+
pdf: torch.Tensor,
|
| 20 |
+
total_range_bits: int,
|
| 21 |
+
roundoff: float = 1e-8,
|
| 22 |
+
min_range: int = 2,
|
| 23 |
+
check: bool = True,
|
| 24 |
+
) -> torch.Tensor:
|
| 25 |
+
"""Turn the given PDF into a quantized CDF that splits
|
| 26 |
+
[0, 2 ** self.total_range_bits - 1] into chunks of size roughly proportional
|
| 27 |
+
to the PDF.
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
pdf (torch.Tensor): probability distribution, shape should be `[N]`.
|
| 31 |
+
total_range_bits (int): see `ArithmeticCoder`, the typical range we expect
|
| 32 |
+
during the coding process is `[0, 2 ** total_range_bits - 1]`.
|
| 33 |
+
roundoff (float): will round the pdf up to that level to remove difference coming
|
| 34 |
+
from e.g. evaluating the Language Model on different architectures.
|
| 35 |
+
min_range (int): minimum range width. Should always be at least 2 for numerical
|
| 36 |
+
stability. Use this to avoid pathological behavior is a value
|
| 37 |
+
that is expected to be rare actually happens in real life.
|
| 38 |
+
check (bool): if True, checks that nothing bad happened, can be deactivated for speed.
|
| 39 |
+
"""
|
| 40 |
+
pdf = pdf.detach()
|
| 41 |
+
if roundoff:
|
| 42 |
+
pdf = (pdf / roundoff).floor() * roundoff
|
| 43 |
+
# interpolate with uniform distribution to achieve desired minimum probability.
|
| 44 |
+
total_range = 2**total_range_bits
|
| 45 |
+
cardinality = len(pdf)
|
| 46 |
+
alpha = min_range * cardinality / total_range
|
| 47 |
+
assert alpha <= 1, "you must reduce min_range"
|
| 48 |
+
ranges = (((1 - alpha) * total_range) * pdf).floor().long()
|
| 49 |
+
ranges += min_range
|
| 50 |
+
quantized_cdf = torch.cumsum(ranges, dim=-1)
|
| 51 |
+
if min_range < 2:
|
| 52 |
+
raise ValueError("min_range must be at least 2.")
|
| 53 |
+
if check:
|
| 54 |
+
assert quantized_cdf[-1] <= 2**total_range_bits, quantized_cdf[-1]
|
| 55 |
+
if ((quantized_cdf[1:] - quantized_cdf[:-1]) < min_range).any() or quantized_cdf[0] < min_range:
|
| 56 |
+
raise ValueError("You must increase your total_range_bits.")
|
| 57 |
+
return quantized_cdf
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class ArithmeticCoder:
|
| 61 |
+
"""ArithmeticCoder,
|
| 62 |
+
Let us take a distribution `p` over `N` symbols, and assume we have a stream
|
| 63 |
+
of random variables `s_t` sampled from `p`. Let us assume that we have a budget
|
| 64 |
+
of `B` bits that we can afford to write on device. There are `2**B` possible numbers,
|
| 65 |
+
corresponding to the range `[0, 2 ** B - 1]`. We can map each of those number to a single
|
| 66 |
+
sequence `(s_t)` by doing the following:
|
| 67 |
+
|
| 68 |
+
1) Initialize the current range to` [0 ** 2 B - 1]`.
|
| 69 |
+
2) For each time step t, split the current range into contiguous chunks,
|
| 70 |
+
one for each possible outcome, with size roughly proportional to `p`.
|
| 71 |
+
For instance, if `p = [0.75, 0.25]`, and the range is `[0, 3]`, the chunks
|
| 72 |
+
would be `{[0, 2], [3, 3]}`.
|
| 73 |
+
3) Select the chunk corresponding to `s_t`, and replace the current range with this.
|
| 74 |
+
4) When done encoding all the values, just select any value remaining in the range.
|
| 75 |
+
|
| 76 |
+
You will notice that this procedure can fail: for instance if at any point in time
|
| 77 |
+
the range is smaller than `N`, then we can no longer assign a non-empty chunk to each
|
| 78 |
+
possible outcome. Intuitively, the more likely a value is, the less the range width
|
| 79 |
+
will reduce, and the longer we can go on encoding values. This makes sense: for any efficient
|
| 80 |
+
coding scheme, likely outcomes would take less bits, and more of them can be coded
|
| 81 |
+
with a fixed budget.
|
| 82 |
+
|
| 83 |
+
In practice, we do not know `B` ahead of time, but we have a way to inject new bits
|
| 84 |
+
when the current range decreases below a given limit (given by `total_range_bits`), without
|
| 85 |
+
having to redo all the computations. If we encode mostly likely values, we will seldom
|
| 86 |
+
need to inject new bits, but a single rare value can deplete our stock of entropy!
|
| 87 |
+
|
| 88 |
+
In this explanation, we assumed that the distribution `p` was constant. In fact, the present
|
| 89 |
+
code works for any sequence `(p_t)` possibly different for each timestep.
|
| 90 |
+
We also assume that `s_t ~ p_t`, but that doesn't need to be true, although the smaller
|
| 91 |
+
the KL between the true distribution and `p_t`, the most efficient the coding will be.
|
| 92 |
+
|
| 93 |
+
Args:
|
| 94 |
+
fo (IO[bytes]): file-like object to which the bytes will be written to.
|
| 95 |
+
total_range_bits (int): the range `M` described above is `2 ** total_range_bits.
|
| 96 |
+
Any time the current range width fall under this limit, new bits will
|
| 97 |
+
be injected to rescale the initial range.
|
| 98 |
+
"""
|
| 99 |
+
|
| 100 |
+
def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
|
| 101 |
+
assert total_range_bits <= 30
|
| 102 |
+
self.total_range_bits = total_range_bits
|
| 103 |
+
self.packer = BitPacker(bits=1, fo=fo) # we push single bits at a time.
|
| 104 |
+
self.low: int = 0
|
| 105 |
+
self.high: int = 0
|
| 106 |
+
self.max_bit: int = -1
|
| 107 |
+
self._dbg: tp.List[tp.Any] = []
|
| 108 |
+
self._dbg2: tp.List[tp.Any] = []
|
| 109 |
+
|
| 110 |
+
@property
|
| 111 |
+
def delta(self) -> int:
|
| 112 |
+
"""Return the current range width."""
|
| 113 |
+
return self.high - self.low + 1
|
| 114 |
+
|
| 115 |
+
def _flush_common_prefix(self):
|
| 116 |
+
# If self.low and self.high start with the sames bits,
|
| 117 |
+
# those won't change anymore as we always just increase the range
|
| 118 |
+
# by powers of 2, and we can flush them out to the bit stream.
|
| 119 |
+
assert self.high >= self.low, (self.low, self.high)
|
| 120 |
+
assert self.high < 2 ** (self.max_bit + 1)
|
| 121 |
+
while self.max_bit >= 0:
|
| 122 |
+
b1 = self.low >> self.max_bit
|
| 123 |
+
b2 = self.high >> self.max_bit
|
| 124 |
+
if b1 == b2:
|
| 125 |
+
self.low -= b1 << self.max_bit
|
| 126 |
+
self.high -= b1 << self.max_bit
|
| 127 |
+
assert self.high >= self.low, (self.high, self.low, self.max_bit)
|
| 128 |
+
assert self.low >= 0
|
| 129 |
+
self.max_bit -= 1
|
| 130 |
+
self.packer.push(b1)
|
| 131 |
+
else:
|
| 132 |
+
break
|
| 133 |
+
|
| 134 |
+
def push(self, symbol: int, quantized_cdf: torch.Tensor):
|
| 135 |
+
"""Push the given symbol on the stream, flushing out bits
|
| 136 |
+
if possible.
|
| 137 |
+
|
| 138 |
+
Args:
|
| 139 |
+
symbol (int): symbol to encode with the AC.
|
| 140 |
+
quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
|
| 141 |
+
to build this from your pdf estimate.
|
| 142 |
+
"""
|
| 143 |
+
while self.delta < 2**self.total_range_bits:
|
| 144 |
+
self.low *= 2
|
| 145 |
+
self.high = self.high * 2 + 1
|
| 146 |
+
self.max_bit += 1
|
| 147 |
+
|
| 148 |
+
range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item()
|
| 149 |
+
range_high = quantized_cdf[symbol].item() - 1
|
| 150 |
+
effective_low = int(math.ceil(range_low * (self.delta / (2**self.total_range_bits))))
|
| 151 |
+
effective_high = int(math.floor(range_high * (self.delta / (2**self.total_range_bits))))
|
| 152 |
+
assert self.low <= self.high
|
| 153 |
+
self.high = self.low + effective_high
|
| 154 |
+
self.low = self.low + effective_low
|
| 155 |
+
assert self.low <= self.high, (
|
| 156 |
+
effective_low,
|
| 157 |
+
effective_high,
|
| 158 |
+
range_low,
|
| 159 |
+
range_high,
|
| 160 |
+
)
|
| 161 |
+
self._dbg.append((self.low, self.high))
|
| 162 |
+
self._dbg2.append((self.low, self.high))
|
| 163 |
+
outs = self._flush_common_prefix()
|
| 164 |
+
assert self.low <= self.high
|
| 165 |
+
assert self.max_bit >= -1
|
| 166 |
+
assert self.max_bit <= 61, self.max_bit
|
| 167 |
+
return outs
|
| 168 |
+
|
| 169 |
+
def flush(self):
|
| 170 |
+
"""Flush the remaining information to the stream."""
|
| 171 |
+
while self.max_bit >= 0:
|
| 172 |
+
b1 = (self.low >> self.max_bit) & 1
|
| 173 |
+
self.packer.push(b1)
|
| 174 |
+
self.max_bit -= 1
|
| 175 |
+
self.packer.flush()
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
class ArithmeticDecoder:
|
| 179 |
+
"""ArithmeticDecoder, see `ArithmeticCoder` for a detailed explanation.
|
| 180 |
+
|
| 181 |
+
Note that this must be called with **exactly** the same parameters and sequence
|
| 182 |
+
of quantized cdf as the arithmetic encoder or the wrong values will be decoded.
|
| 183 |
+
|
| 184 |
+
If the AC encoder current range is [L, H], with `L` and `H` having the some common
|
| 185 |
+
prefix (i.e. the same most significant bits), then this prefix will be flushed to the stream.
|
| 186 |
+
For instances, having read 3 bits `b1 b2 b3`, we know that `[L, H]` is contained inside
|
| 187 |
+
`[b1 b2 b3 0 ... 0 b1 b3 b3 1 ... 1]`. Now this specific sub-range can only be obtained
|
| 188 |
+
for a specific sequence of symbols and a binary-search allows us to decode those symbols.
|
| 189 |
+
At some point, the prefix `b1 b2 b3` will no longer be sufficient to decode new symbols,
|
| 190 |
+
and we will need to read new bits from the stream and repeat the process.
|
| 191 |
+
|
| 192 |
+
"""
|
| 193 |
+
|
| 194 |
+
def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
|
| 195 |
+
self.total_range_bits = total_range_bits
|
| 196 |
+
self.low: int = 0
|
| 197 |
+
self.high: int = 0
|
| 198 |
+
self.current: int = 0
|
| 199 |
+
self.max_bit: int = -1
|
| 200 |
+
self.unpacker = BitUnpacker(bits=1, fo=fo) # we pull single bits at a time.
|
| 201 |
+
# Following is for debugging
|
| 202 |
+
self._dbg: tp.List[tp.Any] = []
|
| 203 |
+
self._dbg2: tp.List[tp.Any] = []
|
| 204 |
+
self._last: tp.Any = None
|
| 205 |
+
|
| 206 |
+
@property
|
| 207 |
+
def delta(self) -> int:
|
| 208 |
+
return self.high - self.low + 1
|
| 209 |
+
|
| 210 |
+
def _flush_common_prefix(self):
|
| 211 |
+
# Given the current range [L, H], if both have a common prefix,
|
| 212 |
+
# we know we can remove it from our representation to avoid handling large numbers.
|
| 213 |
+
while self.max_bit >= 0:
|
| 214 |
+
b1 = self.low >> self.max_bit
|
| 215 |
+
b2 = self.high >> self.max_bit
|
| 216 |
+
if b1 == b2:
|
| 217 |
+
self.low -= b1 << self.max_bit
|
| 218 |
+
self.high -= b1 << self.max_bit
|
| 219 |
+
self.current -= b1 << self.max_bit
|
| 220 |
+
assert self.high >= self.low
|
| 221 |
+
assert self.low >= 0
|
| 222 |
+
self.max_bit -= 1
|
| 223 |
+
else:
|
| 224 |
+
break
|
| 225 |
+
|
| 226 |
+
def pull(self, quantized_cdf: torch.Tensor) -> tp.Optional[int]:
|
| 227 |
+
"""Pull a symbol, reading as many bits from the stream as required.
|
| 228 |
+
This returns `None` when the stream has been exhausted.
|
| 229 |
+
|
| 230 |
+
Args:
|
| 231 |
+
quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
|
| 232 |
+
to build this from your pdf estimate. This must be **exatly**
|
| 233 |
+
the same cdf as the one used at encoding time.
|
| 234 |
+
"""
|
| 235 |
+
while self.delta < 2**self.total_range_bits:
|
| 236 |
+
bit = self.unpacker.pull()
|
| 237 |
+
if bit is None:
|
| 238 |
+
return None
|
| 239 |
+
self.low *= 2
|
| 240 |
+
self.high = self.high * 2 + 1
|
| 241 |
+
self.current = self.current * 2 + bit
|
| 242 |
+
self.max_bit += 1
|
| 243 |
+
|
| 244 |
+
def bin_search(low_idx: int, high_idx: int):
|
| 245 |
+
# Binary search is not just for coding interviews :)
|
| 246 |
+
if high_idx < low_idx:
|
| 247 |
+
raise RuntimeError("Binary search failed")
|
| 248 |
+
mid = (low_idx + high_idx) // 2
|
| 249 |
+
range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0
|
| 250 |
+
range_high = quantized_cdf[mid].item() - 1
|
| 251 |
+
effective_low = int(math.ceil(range_low * (self.delta / (2**self.total_range_bits))))
|
| 252 |
+
effective_high = int(math.floor(range_high * (self.delta / (2**self.total_range_bits))))
|
| 253 |
+
low = effective_low + self.low
|
| 254 |
+
high = effective_high + self.low
|
| 255 |
+
if self.current >= low:
|
| 256 |
+
if self.current <= high:
|
| 257 |
+
return (mid, low, high, self.current)
|
| 258 |
+
else:
|
| 259 |
+
return bin_search(mid + 1, high_idx)
|
| 260 |
+
else:
|
| 261 |
+
return bin_search(low_idx, mid - 1)
|
| 262 |
+
|
| 263 |
+
self._last = (self.low, self.high, self.current, self.max_bit)
|
| 264 |
+
sym, self.low, self.high, self.current = bin_search(0, len(quantized_cdf) - 1)
|
| 265 |
+
self._dbg.append((self.low, self.high, self.current))
|
| 266 |
+
self._flush_common_prefix()
|
| 267 |
+
self._dbg2.append((self.low, self.high, self.current))
|
| 268 |
+
|
| 269 |
+
return sym
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
def test():
|
| 273 |
+
torch.manual_seed(1234)
|
| 274 |
+
random.seed(1234)
|
| 275 |
+
for _ in range(4):
|
| 276 |
+
pdfs = []
|
| 277 |
+
cardinality = random.randrange(4000)
|
| 278 |
+
steps = random.randrange(100, 500)
|
| 279 |
+
fo = io.BytesIO()
|
| 280 |
+
encoder = ArithmeticCoder(fo)
|
| 281 |
+
symbols = []
|
| 282 |
+
for step in range(steps):
|
| 283 |
+
pdf = torch.softmax(torch.randn(cardinality), dim=0)
|
| 284 |
+
pdfs.append(pdf)
|
| 285 |
+
q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
|
| 286 |
+
symbol = torch.multinomial(pdf, 1).item()
|
| 287 |
+
symbols.append(symbol)
|
| 288 |
+
encoder.push(symbol, q_cdf)
|
| 289 |
+
encoder.flush()
|
| 290 |
+
|
| 291 |
+
fo.seek(0)
|
| 292 |
+
decoder = ArithmeticDecoder(fo)
|
| 293 |
+
for idx, (pdf, symbol) in enumerate(zip(pdfs, symbols)):
|
| 294 |
+
q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
|
| 295 |
+
decoded_symbol = decoder.pull(q_cdf)
|
| 296 |
+
assert decoded_symbol == symbol, idx
|
| 297 |
+
assert decoder.pull(torch.zeros(1)) is None
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
if __name__ == "__main__":
|
| 301 |
+
test()
|
higgs_audio/audio_processing/quantization/core_vq.py
ADDED
|
@@ -0,0 +1,360 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
#
|
| 7 |
+
# This implementation is inspired from
|
| 8 |
+
# https://github.com/lucidrains/vector-quantize-pytorch
|
| 9 |
+
# which is released under MIT License. Hereafter, the original license:
|
| 10 |
+
# MIT License
|
| 11 |
+
#
|
| 12 |
+
# Copyright (c) 2020 Phil Wang
|
| 13 |
+
#
|
| 14 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 15 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 16 |
+
# in the Software without restriction, including without limitation the rights
|
| 17 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 18 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 19 |
+
# furnished to do so, subject to the following conditions:
|
| 20 |
+
#
|
| 21 |
+
# The above copyright notice and this permission notice shall be included in all
|
| 22 |
+
# copies or substantial portions of the Software.
|
| 23 |
+
#
|
| 24 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 25 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 26 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 27 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 28 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 29 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 30 |
+
# SOFTWARE.
|
| 31 |
+
|
| 32 |
+
"""Core vector quantization implementation."""
|
| 33 |
+
|
| 34 |
+
import typing as tp
|
| 35 |
+
|
| 36 |
+
from einops import rearrange, repeat
|
| 37 |
+
import torch
|
| 38 |
+
from torch import nn
|
| 39 |
+
import torch.nn.functional as F
|
| 40 |
+
|
| 41 |
+
from xcodec.quantization.distrib import broadcast_tensors, rank
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def default(val: tp.Any, d: tp.Any) -> tp.Any:
|
| 45 |
+
return val if val is not None else d
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def ema_inplace(moving_avg, new, decay: float):
|
| 49 |
+
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
|
| 53 |
+
return (x + epsilon) / (x.sum() + n_categories * epsilon)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def uniform_init(*shape: int):
|
| 57 |
+
t = torch.empty(shape)
|
| 58 |
+
nn.init.kaiming_uniform_(t)
|
| 59 |
+
return t
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def sample_vectors(samples, num: int):
|
| 63 |
+
num_samples, device = samples.shape[0], samples.device
|
| 64 |
+
|
| 65 |
+
if num_samples >= num:
|
| 66 |
+
indices = torch.randperm(num_samples, device=device)[:num]
|
| 67 |
+
else:
|
| 68 |
+
indices = torch.randint(0, num_samples, (num,), device=device)
|
| 69 |
+
|
| 70 |
+
return samples[indices]
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def kmeans(samples, num_clusters: int, num_iters: int = 10):
|
| 74 |
+
dim, dtype = samples.shape[-1], samples.dtype
|
| 75 |
+
|
| 76 |
+
means = sample_vectors(samples, num_clusters)
|
| 77 |
+
|
| 78 |
+
for _ in range(num_iters):
|
| 79 |
+
diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d")
|
| 80 |
+
dists = -(diffs**2).sum(dim=-1)
|
| 81 |
+
|
| 82 |
+
buckets = dists.max(dim=-1).indices
|
| 83 |
+
bins = torch.bincount(buckets, minlength=num_clusters)
|
| 84 |
+
zero_mask = bins == 0
|
| 85 |
+
bins_min_clamped = bins.masked_fill(zero_mask, 1)
|
| 86 |
+
|
| 87 |
+
new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
|
| 88 |
+
new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
|
| 89 |
+
new_means = new_means / bins_min_clamped[..., None]
|
| 90 |
+
|
| 91 |
+
means = torch.where(zero_mask[..., None], means, new_means)
|
| 92 |
+
|
| 93 |
+
return means, bins
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class EuclideanCodebook(nn.Module):
|
| 97 |
+
"""Codebook with Euclidean distance.
|
| 98 |
+
Args:
|
| 99 |
+
dim (int): Dimension.
|
| 100 |
+
codebook_size (int): Codebook size.
|
| 101 |
+
kmeans_init (bool): Whether to use k-means to initialize the codebooks.
|
| 102 |
+
If set to true, run the k-means algorithm on the first training batch and use
|
| 103 |
+
the learned centroids as initialization.
|
| 104 |
+
kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
|
| 105 |
+
decay (float): Decay for exponential moving average over the codebooks.
|
| 106 |
+
epsilon (float): Epsilon value for numerical stability.
|
| 107 |
+
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
|
| 108 |
+
that have an exponential moving average cluster size less than the specified threshold with
|
| 109 |
+
randomly selected vector from the current batch.
|
| 110 |
+
"""
|
| 111 |
+
|
| 112 |
+
def __init__(
|
| 113 |
+
self,
|
| 114 |
+
dim: int,
|
| 115 |
+
codebook_size: int,
|
| 116 |
+
kmeans_init: int = False,
|
| 117 |
+
kmeans_iters: int = 10,
|
| 118 |
+
decay: float = 0.99,
|
| 119 |
+
epsilon: float = 1e-5,
|
| 120 |
+
threshold_ema_dead_code: int = 2,
|
| 121 |
+
):
|
| 122 |
+
super().__init__()
|
| 123 |
+
self.decay = decay
|
| 124 |
+
init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
|
| 125 |
+
embed = init_fn(codebook_size, dim)
|
| 126 |
+
|
| 127 |
+
self.codebook_size = codebook_size
|
| 128 |
+
|
| 129 |
+
self.kmeans_iters = kmeans_iters
|
| 130 |
+
self.epsilon = epsilon
|
| 131 |
+
self.threshold_ema_dead_code = threshold_ema_dead_code
|
| 132 |
+
|
| 133 |
+
self.register_buffer("inited", torch.Tensor([not kmeans_init]))
|
| 134 |
+
self.register_buffer("cluster_size", torch.zeros(codebook_size))
|
| 135 |
+
self.register_buffer("embed", embed)
|
| 136 |
+
self.register_buffer("embed_avg", embed.clone())
|
| 137 |
+
|
| 138 |
+
@torch.jit.ignore
|
| 139 |
+
def init_embed_(self, data):
|
| 140 |
+
if self.inited:
|
| 141 |
+
return
|
| 142 |
+
|
| 143 |
+
embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
|
| 144 |
+
self.embed.data.copy_(embed)
|
| 145 |
+
self.embed_avg.data.copy_(embed.clone())
|
| 146 |
+
self.cluster_size.data.copy_(cluster_size)
|
| 147 |
+
self.inited.data.copy_(torch.Tensor([True]))
|
| 148 |
+
# Make sure all buffers across workers are in sync after initialization
|
| 149 |
+
broadcast_tensors(self.buffers())
|
| 150 |
+
|
| 151 |
+
def replace_(self, samples, mask):
|
| 152 |
+
modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed)
|
| 153 |
+
self.embed.data.copy_(modified_codebook)
|
| 154 |
+
|
| 155 |
+
def expire_codes_(self, batch_samples):
|
| 156 |
+
if self.threshold_ema_dead_code == 0:
|
| 157 |
+
return
|
| 158 |
+
|
| 159 |
+
expired_codes = self.cluster_size < self.threshold_ema_dead_code
|
| 160 |
+
if not torch.any(expired_codes):
|
| 161 |
+
return
|
| 162 |
+
|
| 163 |
+
batch_samples = rearrange(batch_samples, "... d -> (...) d")
|
| 164 |
+
self.replace_(batch_samples, mask=expired_codes)
|
| 165 |
+
broadcast_tensors(self.buffers())
|
| 166 |
+
|
| 167 |
+
def preprocess(self, x):
|
| 168 |
+
x = rearrange(x, "... d -> (...) d")
|
| 169 |
+
return x
|
| 170 |
+
|
| 171 |
+
def quantize(self, x):
|
| 172 |
+
embed = self.embed.t()
|
| 173 |
+
dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed + embed.pow(2).sum(0, keepdim=True))
|
| 174 |
+
embed_ind = dist.max(dim=-1).indices
|
| 175 |
+
return embed_ind
|
| 176 |
+
|
| 177 |
+
def postprocess_emb(self, embed_ind, shape):
|
| 178 |
+
return embed_ind.view(*shape[:-1])
|
| 179 |
+
|
| 180 |
+
def dequantize(self, embed_ind):
|
| 181 |
+
quantize = F.embedding(embed_ind, self.embed) # get embedding based on index
|
| 182 |
+
return quantize
|
| 183 |
+
|
| 184 |
+
def encode(self, x):
|
| 185 |
+
shape = x.shape
|
| 186 |
+
# pre-process
|
| 187 |
+
x = self.preprocess(x)
|
| 188 |
+
# quantize
|
| 189 |
+
embed_ind = self.quantize(x) # get index based on Euclidean distance
|
| 190 |
+
# post-process
|
| 191 |
+
embed_ind = self.postprocess_emb(embed_ind, shape)
|
| 192 |
+
return embed_ind
|
| 193 |
+
|
| 194 |
+
def decode(self, embed_ind):
|
| 195 |
+
quantize = self.dequantize(embed_ind)
|
| 196 |
+
return quantize
|
| 197 |
+
|
| 198 |
+
def forward(self, x):
|
| 199 |
+
shape, dtype = x.shape, x.dtype
|
| 200 |
+
x = self.preprocess(x)
|
| 201 |
+
|
| 202 |
+
self.init_embed_(x)
|
| 203 |
+
|
| 204 |
+
embed_ind = self.quantize(x)
|
| 205 |
+
embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
|
| 206 |
+
embed_ind = self.postprocess_emb(embed_ind, shape)
|
| 207 |
+
quantize = self.dequantize(embed_ind)
|
| 208 |
+
|
| 209 |
+
if self.training:
|
| 210 |
+
# We do the expiry of code at that point as buffers are in sync
|
| 211 |
+
# and all the workers will take the same decision.
|
| 212 |
+
self.expire_codes_(x)
|
| 213 |
+
ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
|
| 214 |
+
embed_sum = x.t() @ embed_onehot
|
| 215 |
+
ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
|
| 216 |
+
cluster_size = (
|
| 217 |
+
laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) * self.cluster_size.sum()
|
| 218 |
+
)
|
| 219 |
+
embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
|
| 220 |
+
self.embed.data.copy_(embed_normalized)
|
| 221 |
+
|
| 222 |
+
return quantize, embed_ind
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
class VectorQuantization(nn.Module):
|
| 226 |
+
"""Vector quantization implementation.
|
| 227 |
+
Currently supports only euclidean distance.
|
| 228 |
+
Args:
|
| 229 |
+
dim (int): Dimension
|
| 230 |
+
codebook_size (int): Codebook size
|
| 231 |
+
codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
|
| 232 |
+
decay (float): Decay for exponential moving average over the codebooks.
|
| 233 |
+
epsilon (float): Epsilon value for numerical stability.
|
| 234 |
+
kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
|
| 235 |
+
kmeans_iters (int): Number of iterations used for kmeans initialization.
|
| 236 |
+
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
|
| 237 |
+
that have an exponential moving average cluster size less than the specified threshold with
|
| 238 |
+
randomly selected vector from the current batch.
|
| 239 |
+
commitment_weight (float): Weight for commitment loss.
|
| 240 |
+
"""
|
| 241 |
+
|
| 242 |
+
def __init__(
|
| 243 |
+
self,
|
| 244 |
+
dim: int,
|
| 245 |
+
codebook_size: int,
|
| 246 |
+
codebook_dim: tp.Optional[int] = None,
|
| 247 |
+
decay: float = 0.99,
|
| 248 |
+
epsilon: float = 1e-5,
|
| 249 |
+
kmeans_init: bool = True,
|
| 250 |
+
kmeans_iters: int = 50,
|
| 251 |
+
threshold_ema_dead_code: int = 2,
|
| 252 |
+
commitment_weight: float = 1.0,
|
| 253 |
+
):
|
| 254 |
+
super().__init__()
|
| 255 |
+
_codebook_dim: int = default(codebook_dim, dim)
|
| 256 |
+
|
| 257 |
+
requires_projection = _codebook_dim != dim
|
| 258 |
+
self.project_in = nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
|
| 259 |
+
self.project_out = nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
|
| 260 |
+
|
| 261 |
+
self.epsilon = epsilon
|
| 262 |
+
self.commitment_weight = commitment_weight
|
| 263 |
+
|
| 264 |
+
self._codebook = EuclideanCodebook(
|
| 265 |
+
dim=_codebook_dim,
|
| 266 |
+
codebook_size=codebook_size,
|
| 267 |
+
kmeans_init=kmeans_init,
|
| 268 |
+
kmeans_iters=kmeans_iters,
|
| 269 |
+
decay=decay,
|
| 270 |
+
epsilon=epsilon,
|
| 271 |
+
threshold_ema_dead_code=threshold_ema_dead_code,
|
| 272 |
+
)
|
| 273 |
+
self.codebook_size = codebook_size
|
| 274 |
+
|
| 275 |
+
@property
|
| 276 |
+
def codebook(self):
|
| 277 |
+
return self._codebook.embed
|
| 278 |
+
|
| 279 |
+
def encode(self, x):
|
| 280 |
+
x = rearrange(x, "b d n -> b n d")
|
| 281 |
+
x = self.project_in(x)
|
| 282 |
+
embed_in = self._codebook.encode(x)
|
| 283 |
+
return embed_in
|
| 284 |
+
|
| 285 |
+
def decode(self, embed_ind):
|
| 286 |
+
quantize = self._codebook.decode(embed_ind)
|
| 287 |
+
quantize = self.project_out(quantize)
|
| 288 |
+
quantize = rearrange(quantize, "b n d -> b d n")
|
| 289 |
+
return quantize
|
| 290 |
+
|
| 291 |
+
def forward(self, x):
|
| 292 |
+
device = x.device
|
| 293 |
+
x = rearrange(x, "b d n -> b n d")
|
| 294 |
+
x = self.project_in(x)
|
| 295 |
+
|
| 296 |
+
quantize, embed_ind = self._codebook(x)
|
| 297 |
+
|
| 298 |
+
if self.training:
|
| 299 |
+
quantize = x + (quantize - x).detach()
|
| 300 |
+
|
| 301 |
+
loss = torch.tensor([0.0], device=device, requires_grad=self.training)
|
| 302 |
+
|
| 303 |
+
if self.training:
|
| 304 |
+
if self.commitment_weight > 0:
|
| 305 |
+
commit_loss = F.mse_loss(quantize.detach(), x)
|
| 306 |
+
loss = loss + commit_loss * self.commitment_weight
|
| 307 |
+
|
| 308 |
+
quantize = self.project_out(quantize)
|
| 309 |
+
quantize = rearrange(quantize, "b n d -> b d n")
|
| 310 |
+
return quantize, embed_ind, loss
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
class ResidualVectorQuantization(nn.Module):
|
| 314 |
+
"""Residual vector quantization implementation.
|
| 315 |
+
Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
|
| 316 |
+
"""
|
| 317 |
+
|
| 318 |
+
def __init__(self, *, num_quantizers, **kwargs):
|
| 319 |
+
super().__init__()
|
| 320 |
+
self.layers = nn.ModuleList([VectorQuantization(**kwargs) for _ in range(num_quantizers)])
|
| 321 |
+
|
| 322 |
+
def forward(self, x, n_q: tp.Optional[int] = None):
|
| 323 |
+
quantized_out = 0.0
|
| 324 |
+
residual = x
|
| 325 |
+
|
| 326 |
+
all_losses = []
|
| 327 |
+
all_indices = []
|
| 328 |
+
|
| 329 |
+
n_q = n_q or len(self.layers)
|
| 330 |
+
|
| 331 |
+
for layer in self.layers[:n_q]:
|
| 332 |
+
quantized, indices, loss = layer(residual)
|
| 333 |
+
residual = residual - quantized
|
| 334 |
+
quantized_out = quantized_out + quantized
|
| 335 |
+
|
| 336 |
+
all_indices.append(indices)
|
| 337 |
+
all_losses.append(loss)
|
| 338 |
+
|
| 339 |
+
out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
|
| 340 |
+
return quantized_out, out_indices, out_losses
|
| 341 |
+
|
| 342 |
+
def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
|
| 343 |
+
residual = x
|
| 344 |
+
all_indices = []
|
| 345 |
+
n_q = n_q or len(self.layers)
|
| 346 |
+
for layer in self.layers[:n_q]:
|
| 347 |
+
indices = layer.encode(residual)
|
| 348 |
+
quantized = layer.decode(indices)
|
| 349 |
+
residual = residual - quantized
|
| 350 |
+
all_indices.append(indices)
|
| 351 |
+
out_indices = torch.stack(all_indices)
|
| 352 |
+
return out_indices
|
| 353 |
+
|
| 354 |
+
def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
|
| 355 |
+
quantized_out = torch.tensor(0.0, device=q_indices.device)
|
| 356 |
+
for i, indices in enumerate(q_indices):
|
| 357 |
+
layer = self.layers[i]
|
| 358 |
+
quantized = layer.decode(indices)
|
| 359 |
+
quantized_out = quantized_out + quantized
|
| 360 |
+
return quantized_out
|
higgs_audio/audio_processing/quantization/core_vq_lsx_version.py
ADDED
|
@@ -0,0 +1,431 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c)
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
# This implementation is inspired from
|
| 6 |
+
# https://github.com/rosinality/vq-vae-2-pytorch/blob/master/vqvae.py and
|
| 7 |
+
# https://github.com/clementchadebec/benchmark_VAE/blob/dfa0dcf6c79172df5d27769c09c860c42008baaa/src/pythae/models/vq_vae/vq_vae_utils.py#L81
|
| 8 |
+
#
|
| 9 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 10 |
+
# All rights reserved.
|
| 11 |
+
#
|
| 12 |
+
# This source code is licensed under the license found in the
|
| 13 |
+
# LICENSE file in the root directory of this source tree.
|
| 14 |
+
#
|
| 15 |
+
# This implementation is inspired from
|
| 16 |
+
# https://github.com/lucidrains/vector-quantize-pytorch
|
| 17 |
+
# which is released under MIT License. Hereafter, the original license:
|
| 18 |
+
# MIT License
|
| 19 |
+
#
|
| 20 |
+
# Copyright (c) 2020 Phil Wang
|
| 21 |
+
#
|
| 22 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 23 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 24 |
+
# in the Software without restriction, including without limitation the rights
|
| 25 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 26 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 27 |
+
# furnished to do so, subject to the following conditions:
|
| 28 |
+
#
|
| 29 |
+
# The above copyright notice and this permission notice shall be included in all
|
| 30 |
+
# copies or substantial portions of the Software.
|
| 31 |
+
#
|
| 32 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 33 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 34 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 35 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 36 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 37 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 38 |
+
# SOFTWARE.
|
| 39 |
+
|
| 40 |
+
"""Core vector quantization implementation."""
|
| 41 |
+
|
| 42 |
+
import typing as tp
|
| 43 |
+
|
| 44 |
+
from einops import rearrange
|
| 45 |
+
import torch
|
| 46 |
+
from torch import nn
|
| 47 |
+
import torch.nn.functional as F
|
| 48 |
+
import torch.distributed as dist
|
| 49 |
+
|
| 50 |
+
from .distrib import broadcast_tensors, is_distributed
|
| 51 |
+
from .ddp_utils import SyncFunction
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def default(val: tp.Any, d: tp.Any) -> tp.Any:
|
| 55 |
+
return val if val is not None else d
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def ema_inplace(moving_avg, new, decay: float):
|
| 59 |
+
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
|
| 63 |
+
return (x + epsilon) / (x.sum() + n_categories * epsilon)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def uniform_init(*shape: int):
|
| 67 |
+
t = torch.empty(shape)
|
| 68 |
+
nn.init.kaiming_uniform_(t)
|
| 69 |
+
return t
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def sample_vectors(samples, num: int):
|
| 73 |
+
num_samples, device = samples.shape[0], samples.device
|
| 74 |
+
|
| 75 |
+
if num_samples >= num:
|
| 76 |
+
indices = torch.randperm(num_samples, device=device)[:num]
|
| 77 |
+
else:
|
| 78 |
+
indices = torch.randint(0, num_samples, (num,), device=device)
|
| 79 |
+
|
| 80 |
+
return samples[indices]
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def kmeans(
|
| 84 |
+
samples,
|
| 85 |
+
num_clusters: int,
|
| 86 |
+
num_iters: int = 10,
|
| 87 |
+
frames_to_use: int = 10_000,
|
| 88 |
+
batch_size: int = 64,
|
| 89 |
+
):
|
| 90 |
+
"""
|
| 91 |
+
Memory-efficient K-means clustering.
|
| 92 |
+
Args:
|
| 93 |
+
samples (tensor): shape [N, D]
|
| 94 |
+
num_clusters (int): number of centroids.
|
| 95 |
+
num_iters (int): number of iterations.
|
| 96 |
+
frames_to_use (int): subsample size from total samples.
|
| 97 |
+
batch_size (int): batch size used in distance computation.
|
| 98 |
+
Returns:
|
| 99 |
+
means: [num_clusters, D]
|
| 100 |
+
bins: [num_clusters] (number of points per cluster)
|
| 101 |
+
"""
|
| 102 |
+
N, D = samples.shape
|
| 103 |
+
dtype, device = samples.dtype, samples.device
|
| 104 |
+
|
| 105 |
+
if frames_to_use < N:
|
| 106 |
+
indices = torch.randperm(N, device=device)[:frames_to_use]
|
| 107 |
+
samples = samples[indices]
|
| 108 |
+
|
| 109 |
+
means = sample_vectors(samples, num_clusters)
|
| 110 |
+
|
| 111 |
+
for _ in range(num_iters):
|
| 112 |
+
# Store cluster assignments
|
| 113 |
+
all_assignments = []
|
| 114 |
+
|
| 115 |
+
for i in range(0, samples.shape[0], batch_size):
|
| 116 |
+
batch = samples[i : i + batch_size] # [B, D]
|
| 117 |
+
dists = torch.cdist(batch, means, p=2) # [B, C]
|
| 118 |
+
assignments = dists.argmin(dim=1) # [B]
|
| 119 |
+
all_assignments.append(assignments)
|
| 120 |
+
|
| 121 |
+
buckets = torch.cat(all_assignments, dim=0) # [N]
|
| 122 |
+
bins = torch.bincount(buckets, minlength=num_clusters)
|
| 123 |
+
zero_mask = bins == 0
|
| 124 |
+
bins_min_clamped = bins.masked_fill(zero_mask, 1)
|
| 125 |
+
|
| 126 |
+
# Compute new means
|
| 127 |
+
new_means = torch.zeros_like(means)
|
| 128 |
+
for i in range(num_clusters):
|
| 129 |
+
mask = buckets == i
|
| 130 |
+
if mask.any():
|
| 131 |
+
new_means[i] = samples[mask].mean(dim=0)
|
| 132 |
+
|
| 133 |
+
means = torch.where(zero_mask[:, None], means, new_means)
|
| 134 |
+
|
| 135 |
+
return means, bins
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class EuclideanCodebook(nn.Module):
|
| 139 |
+
"""Codebook with Euclidean distance.
|
| 140 |
+
Args:
|
| 141 |
+
dim (int): Dimension.
|
| 142 |
+
codebook_size (int): Codebook size.
|
| 143 |
+
kmeans_init (bool): Whether to use k-means to initialize the codebooks.
|
| 144 |
+
If set to true, run the k-means algorithm on the first training batch and use
|
| 145 |
+
the learned centroids as initialization.
|
| 146 |
+
kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
|
| 147 |
+
decay (float): Decay for exponential moving average over the codebooks.
|
| 148 |
+
epsilon (float): Epsilon value for numerical stability.
|
| 149 |
+
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
|
| 150 |
+
that have an exponential moving average cluster size less than the specified threshold with
|
| 151 |
+
randomly selected vector from the current batch.
|
| 152 |
+
"""
|
| 153 |
+
|
| 154 |
+
def __init__(
|
| 155 |
+
self,
|
| 156 |
+
dim: int,
|
| 157 |
+
codebook_size: int,
|
| 158 |
+
kmeans_init: int = False,
|
| 159 |
+
kmeans_iters: int = 10,
|
| 160 |
+
decay: float = 0.99,
|
| 161 |
+
epsilon: float = 1e-5,
|
| 162 |
+
threshold_ema_dead_code: int = 2,
|
| 163 |
+
):
|
| 164 |
+
super().__init__()
|
| 165 |
+
self.decay = decay
|
| 166 |
+
init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
|
| 167 |
+
embed = init_fn(codebook_size, dim)
|
| 168 |
+
|
| 169 |
+
self.codebook_size = codebook_size
|
| 170 |
+
|
| 171 |
+
self.kmeans_iters = kmeans_iters
|
| 172 |
+
self.epsilon = epsilon
|
| 173 |
+
self.threshold_ema_dead_code = threshold_ema_dead_code
|
| 174 |
+
|
| 175 |
+
# Flag variable to indicate whether the codebook is initialized
|
| 176 |
+
self.register_buffer("inited", torch.Tensor([not kmeans_init]))
|
| 177 |
+
# Runing EMA cluster size/count: N_i^t in eq. (6) in vqvae paper
|
| 178 |
+
self.register_buffer("cluster_size", torch.zeros(codebook_size))
|
| 179 |
+
# Codebook
|
| 180 |
+
self.register_buffer("embed", embed)
|
| 181 |
+
# EMA codebook: eq. (7) in vqvae paper
|
| 182 |
+
self.register_buffer("embed_avg", embed.clone())
|
| 183 |
+
|
| 184 |
+
@torch.jit.ignore
|
| 185 |
+
def init_embed_(self, data):
|
| 186 |
+
"""Initialize codebook.
|
| 187 |
+
Args:
|
| 188 |
+
data (tensor): [B * T, D].
|
| 189 |
+
"""
|
| 190 |
+
if self.inited:
|
| 191 |
+
return
|
| 192 |
+
|
| 193 |
+
## NOTE (snippet added by Songxiang Liu): gather data from all gpus
|
| 194 |
+
if dist.is_available() and dist.is_initialized():
|
| 195 |
+
# [B * T * world_size, D]
|
| 196 |
+
data = SyncFunction.apply(data)
|
| 197 |
+
|
| 198 |
+
embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
|
| 199 |
+
self.embed.data.copy_(embed)
|
| 200 |
+
self.embed_avg.data.copy_(embed.clone())
|
| 201 |
+
self.cluster_size.data.copy_(cluster_size)
|
| 202 |
+
self.inited.data.copy_(torch.Tensor([True]))
|
| 203 |
+
# Make sure all buffers across workers are in sync after initialization
|
| 204 |
+
broadcast_tensors(self.buffers())
|
| 205 |
+
|
| 206 |
+
def replace_(self, samples, mask):
|
| 207 |
+
modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed)
|
| 208 |
+
self.embed.data.copy_(modified_codebook)
|
| 209 |
+
|
| 210 |
+
def expire_codes_(self, batch_samples):
|
| 211 |
+
if self.threshold_ema_dead_code == 0:
|
| 212 |
+
return
|
| 213 |
+
|
| 214 |
+
expired_codes = self.cluster_size < self.threshold_ema_dead_code
|
| 215 |
+
if not torch.any(expired_codes):
|
| 216 |
+
return
|
| 217 |
+
|
| 218 |
+
## NOTE (snippet added by Songxiang Liu): gather data from all gpus
|
| 219 |
+
if is_distributed():
|
| 220 |
+
# [B * T * world_size, D]
|
| 221 |
+
batch_samples = SyncFunction.apply(batch_samples)
|
| 222 |
+
|
| 223 |
+
batch_samples = rearrange(batch_samples, "... d -> (...) d")
|
| 224 |
+
self.replace_(batch_samples, mask=expired_codes)
|
| 225 |
+
broadcast_tensors(self.buffers())
|
| 226 |
+
|
| 227 |
+
def preprocess(self, x):
|
| 228 |
+
x = rearrange(x, "... d -> (...) d")
|
| 229 |
+
return x
|
| 230 |
+
|
| 231 |
+
def quantize(self, x):
|
| 232 |
+
embed = self.embed.t()
|
| 233 |
+
dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed + embed.pow(2).sum(0, keepdim=True))
|
| 234 |
+
embed_ind = dist.max(dim=-1).indices
|
| 235 |
+
return embed_ind
|
| 236 |
+
|
| 237 |
+
def postprocess_emb(self, embed_ind, shape):
|
| 238 |
+
return embed_ind.view(*shape[:-1])
|
| 239 |
+
|
| 240 |
+
def dequantize(self, embed_ind):
|
| 241 |
+
quantize = F.embedding(embed_ind, self.embed)
|
| 242 |
+
return quantize
|
| 243 |
+
|
| 244 |
+
def encode(self, x):
|
| 245 |
+
shape = x.shape
|
| 246 |
+
# pre-process
|
| 247 |
+
x = self.preprocess(x) # [B, T, D] -> [B*T, D]
|
| 248 |
+
# quantize
|
| 249 |
+
embed_ind = self.quantize(x)
|
| 250 |
+
# post-process
|
| 251 |
+
embed_ind = self.postprocess_emb(embed_ind, shape)
|
| 252 |
+
return embed_ind
|
| 253 |
+
|
| 254 |
+
def decode(self, embed_ind):
|
| 255 |
+
quantize = self.dequantize(embed_ind)
|
| 256 |
+
return quantize
|
| 257 |
+
|
| 258 |
+
def forward(self, x):
|
| 259 |
+
# shape: [B, T, D]
|
| 260 |
+
shape, dtype = x.shape, x.dtype
|
| 261 |
+
x = self.preprocess(x) # [B, T, D] -> [B*T, D]
|
| 262 |
+
|
| 263 |
+
# Initialize codebook
|
| 264 |
+
self.init_embed_(x)
|
| 265 |
+
|
| 266 |
+
embed_ind = self.quantize(x) # [B*T,]
|
| 267 |
+
embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) # [B*T, cb-size]
|
| 268 |
+
embed_ind = self.postprocess_emb(embed_ind, shape) # [B, T]
|
| 269 |
+
quantize = self.dequantize(embed_ind) # [B, T, D]
|
| 270 |
+
|
| 271 |
+
if self.training:
|
| 272 |
+
### Update codebook by EMA
|
| 273 |
+
embed_onehot_sum = embed_onehot.sum(0) # [cb-size,]
|
| 274 |
+
embed_sum = x.t() @ embed_onehot # [D, cb-size]
|
| 275 |
+
if is_distributed():
|
| 276 |
+
dist.all_reduce(embed_onehot_sum)
|
| 277 |
+
dist.all_reduce(embed_sum)
|
| 278 |
+
# Update ema cluster count N_i^t, eq. (6) in vqvae paper
|
| 279 |
+
self.cluster_size.data.mul_(self.decay).add_(embed_onehot_sum, alpha=1 - self.decay)
|
| 280 |
+
# Update ema embed: eq. (7) in vqvae paper
|
| 281 |
+
self.embed_avg.data.mul_(self.decay).add_(embed_sum.t(), alpha=1 - self.decay)
|
| 282 |
+
# apply laplace smoothing
|
| 283 |
+
n = self.cluster_size.sum()
|
| 284 |
+
cluster_size = (self.cluster_size + self.epsilon) / (n + self.codebook_size * self.epsilon) * n
|
| 285 |
+
# Update ema embed: eq. (8) in vqvae paper
|
| 286 |
+
embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
|
| 287 |
+
self.embed.data.copy_(embed_normalized)
|
| 288 |
+
|
| 289 |
+
# We do the expiry of code at that point as buffers are in sync
|
| 290 |
+
# and all the workers will take the same decision.
|
| 291 |
+
self.expire_codes_(x)
|
| 292 |
+
|
| 293 |
+
return quantize, embed_ind
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
class VectorQuantization(nn.Module):
|
| 297 |
+
"""Vector quantization implementation.
|
| 298 |
+
Currently supports only euclidean distance.
|
| 299 |
+
Args:
|
| 300 |
+
dim (int): Dimension
|
| 301 |
+
codebook_size (int): Codebook size
|
| 302 |
+
codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
|
| 303 |
+
decay (float): Decay for exponential moving average over the codebooks.
|
| 304 |
+
epsilon (float): Epsilon value for numerical stability.
|
| 305 |
+
kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
|
| 306 |
+
kmeans_iters (int): Number of iterations used for kmeans initialization.
|
| 307 |
+
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
|
| 308 |
+
that have an exponential moving average cluster size less than the specified threshold with
|
| 309 |
+
randomly selected vector from the current batch.
|
| 310 |
+
commitment_weight (float): Weight for commitment loss.
|
| 311 |
+
"""
|
| 312 |
+
|
| 313 |
+
def __init__(
|
| 314 |
+
self,
|
| 315 |
+
dim: int,
|
| 316 |
+
codebook_size: int,
|
| 317 |
+
codebook_dim: tp.Optional[int] = None,
|
| 318 |
+
decay: float = 0.99,
|
| 319 |
+
epsilon: float = 1e-5,
|
| 320 |
+
kmeans_init: bool = True,
|
| 321 |
+
kmeans_iters: int = 50,
|
| 322 |
+
threshold_ema_dead_code: int = 2,
|
| 323 |
+
commitment_weight: float = 1.0,
|
| 324 |
+
):
|
| 325 |
+
super().__init__()
|
| 326 |
+
_codebook_dim: int = default(codebook_dim, dim)
|
| 327 |
+
|
| 328 |
+
requires_projection = _codebook_dim != dim
|
| 329 |
+
self.project_in = nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
|
| 330 |
+
self.project_out = nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
|
| 331 |
+
|
| 332 |
+
self.epsilon = epsilon
|
| 333 |
+
self.commitment_weight = commitment_weight
|
| 334 |
+
|
| 335 |
+
self._codebook = EuclideanCodebook(
|
| 336 |
+
dim=_codebook_dim,
|
| 337 |
+
codebook_size=codebook_size,
|
| 338 |
+
kmeans_init=kmeans_init,
|
| 339 |
+
kmeans_iters=kmeans_iters,
|
| 340 |
+
decay=decay,
|
| 341 |
+
epsilon=epsilon,
|
| 342 |
+
threshold_ema_dead_code=threshold_ema_dead_code,
|
| 343 |
+
)
|
| 344 |
+
self.codebook_size = codebook_size
|
| 345 |
+
|
| 346 |
+
@property
|
| 347 |
+
def codebook(self):
|
| 348 |
+
return self._codebook.embed
|
| 349 |
+
|
| 350 |
+
def encode(self, x):
|
| 351 |
+
x = rearrange(x, "b d n -> b n d")
|
| 352 |
+
x = self.project_in(x)
|
| 353 |
+
embed_in = self._codebook.encode(x)
|
| 354 |
+
return embed_in
|
| 355 |
+
|
| 356 |
+
def decode(self, embed_ind):
|
| 357 |
+
quantize = self._codebook.decode(embed_ind)
|
| 358 |
+
quantize = self.project_out(quantize)
|
| 359 |
+
quantize = rearrange(quantize, "b n d -> b d n")
|
| 360 |
+
return quantize
|
| 361 |
+
|
| 362 |
+
def forward(self, x):
|
| 363 |
+
device = x.device
|
| 364 |
+
x = x.transpose(1, 2).contiguous() # [b d n] -> [b n d]
|
| 365 |
+
x = self.project_in(x)
|
| 366 |
+
|
| 367 |
+
quantize, embed_ind = self._codebook(x)
|
| 368 |
+
|
| 369 |
+
if self.training:
|
| 370 |
+
quantize = x + (quantize - x).detach()
|
| 371 |
+
|
| 372 |
+
loss = torch.tensor([0.0], device=device, requires_grad=self.training)
|
| 373 |
+
|
| 374 |
+
if self.training:
|
| 375 |
+
if self.commitment_weight > 0:
|
| 376 |
+
commit_loss = F.mse_loss(quantize.detach(), x)
|
| 377 |
+
loss = loss + commit_loss * self.commitment_weight
|
| 378 |
+
|
| 379 |
+
quantize = self.project_out(quantize)
|
| 380 |
+
quantize = quantize.transpose(1, 2).contiguous() # [b n d] -> [b d n]
|
| 381 |
+
return quantize, embed_ind, loss
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
class ResidualVectorQuantization(nn.Module):
|
| 385 |
+
"""Residual vector quantization implementation.
|
| 386 |
+
Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
|
| 387 |
+
"""
|
| 388 |
+
|
| 389 |
+
def __init__(self, *, num_quantizers, **kwargs):
|
| 390 |
+
super().__init__()
|
| 391 |
+
self.layers = nn.ModuleList([VectorQuantization(**kwargs) for _ in range(num_quantizers)])
|
| 392 |
+
|
| 393 |
+
def forward(self, x, n_q: tp.Optional[int] = None):
|
| 394 |
+
quantized_out = 0.0
|
| 395 |
+
residual = x
|
| 396 |
+
|
| 397 |
+
all_losses = []
|
| 398 |
+
all_indices = []
|
| 399 |
+
|
| 400 |
+
n_q = n_q or len(self.layers)
|
| 401 |
+
|
| 402 |
+
for layer in self.layers[:n_q]:
|
| 403 |
+
quantized, indices, loss = layer(residual)
|
| 404 |
+
residual = residual - quantized
|
| 405 |
+
quantized_out = quantized_out + quantized
|
| 406 |
+
|
| 407 |
+
all_indices.append(indices)
|
| 408 |
+
all_losses.append(loss)
|
| 409 |
+
|
| 410 |
+
out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
|
| 411 |
+
return quantized_out, out_indices, out_losses
|
| 412 |
+
|
| 413 |
+
def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
|
| 414 |
+
residual = x
|
| 415 |
+
all_indices = []
|
| 416 |
+
n_q = n_q or len(self.layers)
|
| 417 |
+
for layer in self.layers[:n_q]:
|
| 418 |
+
indices = layer.encode(residual)
|
| 419 |
+
quantized = layer.decode(indices)
|
| 420 |
+
residual = residual - quantized
|
| 421 |
+
all_indices.append(indices)
|
| 422 |
+
out_indices = torch.stack(all_indices)
|
| 423 |
+
return out_indices
|
| 424 |
+
|
| 425 |
+
def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
|
| 426 |
+
quantized_out = torch.tensor(0.0, device=q_indices.device)
|
| 427 |
+
for i, indices in enumerate(q_indices):
|
| 428 |
+
layer = self.layers[i]
|
| 429 |
+
quantized = layer.decode(indices)
|
| 430 |
+
quantized_out = quantized_out + quantized
|
| 431 |
+
return quantized_out
|
higgs_audio/audio_processing/quantization/ddp_utils.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import random
|
| 3 |
+
import subprocess
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.distributed as dist
|
| 9 |
+
from torch.nn.parallel import DistributedDataParallel
|
| 10 |
+
from torch.nn.parallel.distributed import _find_tensors
|
| 11 |
+
import torch.optim
|
| 12 |
+
import torch.utils.data
|
| 13 |
+
from packaging import version
|
| 14 |
+
from omegaconf import OmegaConf
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def set_random_seed(seed):
|
| 18 |
+
random.seed(seed)
|
| 19 |
+
np.random.seed(seed)
|
| 20 |
+
torch.manual_seed(seed)
|
| 21 |
+
torch.cuda.manual_seed_all(seed)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def is_logging_process():
|
| 25 |
+
return not dist.is_initialized() or dist.get_rank() == 0
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def get_logger(cfg, name=None):
|
| 29 |
+
# log_file_path is used when unit testing
|
| 30 |
+
if is_logging_process():
|
| 31 |
+
logging.config.dictConfig(OmegaConf.to_container(cfg.job_logging_config, resolve=True))
|
| 32 |
+
return logging.getLogger(name)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# from https://github.com/Lightning-AI/lightning-bolts/blob/5d61197cd2f491f69e238137a5edabe80ae14ad9/pl_bolts/models/self_supervised/simclr/simclr_module.py#L20
|
| 36 |
+
class SyncFunction(torch.autograd.Function):
|
| 37 |
+
@staticmethod
|
| 38 |
+
# @torch.no_grad()
|
| 39 |
+
def forward(ctx, tensor):
|
| 40 |
+
ctx.batch_size = tensor.shape[0]
|
| 41 |
+
|
| 42 |
+
gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())]
|
| 43 |
+
|
| 44 |
+
torch.distributed.all_gather(gathered_tensor, tensor)
|
| 45 |
+
gathered_tensor = torch.cat(gathered_tensor, 0)
|
| 46 |
+
|
| 47 |
+
return gathered_tensor
|
| 48 |
+
|
| 49 |
+
@staticmethod
|
| 50 |
+
def backward(ctx, grad_output):
|
| 51 |
+
grad_input = grad_output.clone()
|
| 52 |
+
torch.distributed.all_reduce(grad_input, op=torch.distributed.ReduceOp.SUM, async_op=False)
|
| 53 |
+
|
| 54 |
+
idx_from = torch.distributed.get_rank() * ctx.batch_size
|
| 55 |
+
idx_to = (torch.distributed.get_rank() + 1) * ctx.batch_size
|
| 56 |
+
return grad_input[idx_from:idx_to]
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def get_timestamp():
|
| 60 |
+
return datetime.now().strftime("%y%m%d-%H%M%S")
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def get_commit_hash():
|
| 64 |
+
message = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"])
|
| 65 |
+
return message.strip().decode("utf-8")
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class DDP(DistributedDataParallel):
|
| 69 |
+
"""
|
| 70 |
+
Override the forward call in lightning so it goes to training and validation step respectively
|
| 71 |
+
"""
|
| 72 |
+
|
| 73 |
+
def forward(self, *inputs, **kwargs): # pragma: no cover
|
| 74 |
+
if version.parse(torch.__version__[:6]) < version.parse("1.11"):
|
| 75 |
+
self._sync_params()
|
| 76 |
+
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
|
| 77 |
+
assert len(self.device_ids) == 1
|
| 78 |
+
if self.module.training:
|
| 79 |
+
output = self.module.training_step(*inputs[0], **kwargs[0])
|
| 80 |
+
elif self.module.testing:
|
| 81 |
+
output = self.module.test_step(*inputs[0], **kwargs[0])
|
| 82 |
+
else:
|
| 83 |
+
output = self.module.validation_step(*inputs[0], **kwargs[0])
|
| 84 |
+
if torch.is_grad_enabled():
|
| 85 |
+
# We'll return the output object verbatim since it is a freeform
|
| 86 |
+
# object. We need to find any tensors in this object, though,
|
| 87 |
+
# because we need to figure out which parameters were used during
|
| 88 |
+
# this forward pass, to ensure we short circuit reduction for any
|
| 89 |
+
# unused parameters. Only if `find_unused_parameters` is set.
|
| 90 |
+
if self.find_unused_parameters:
|
| 91 |
+
self.reducer.prepare_for_backward(list(_find_tensors(output)))
|
| 92 |
+
else:
|
| 93 |
+
self.reducer.prepare_for_backward([])
|
| 94 |
+
else:
|
| 95 |
+
from torch.nn.parallel.distributed import (
|
| 96 |
+
logging,
|
| 97 |
+
Join,
|
| 98 |
+
_DDPSink,
|
| 99 |
+
_tree_flatten_with_rref,
|
| 100 |
+
_tree_unflatten_with_rref,
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
with torch.autograd.profiler.record_function("DistributedDataParallel.forward"):
|
| 104 |
+
if torch.is_grad_enabled() and self.require_backward_grad_sync:
|
| 105 |
+
self.logger.set_runtime_stats_and_log()
|
| 106 |
+
self.num_iterations += 1
|
| 107 |
+
self.reducer.prepare_for_forward()
|
| 108 |
+
|
| 109 |
+
# Notify the join context that this process has not joined, if
|
| 110 |
+
# needed
|
| 111 |
+
work = Join.notify_join_context(self)
|
| 112 |
+
if work:
|
| 113 |
+
self.reducer._set_forward_pass_work_handle(work, self._divide_by_initial_world_size)
|
| 114 |
+
|
| 115 |
+
# Calling _rebuild_buckets before forward compuation,
|
| 116 |
+
# It may allocate new buckets before deallocating old buckets
|
| 117 |
+
# inside _rebuild_buckets. To save peak memory usage,
|
| 118 |
+
# call _rebuild_buckets before the peak memory usage increases
|
| 119 |
+
# during forward computation.
|
| 120 |
+
# This should be called only once during whole training period.
|
| 121 |
+
if torch.is_grad_enabled() and self.reducer._rebuild_buckets():
|
| 122 |
+
logging.info("Reducer buckets have been rebuilt in this iteration.")
|
| 123 |
+
self._has_rebuilt_buckets = True
|
| 124 |
+
|
| 125 |
+
# sync params according to location (before/after forward) user
|
| 126 |
+
# specified as part of hook, if hook was specified.
|
| 127 |
+
buffer_hook_registered = hasattr(self, "buffer_hook")
|
| 128 |
+
if self._check_sync_bufs_pre_fwd():
|
| 129 |
+
self._sync_buffers()
|
| 130 |
+
|
| 131 |
+
if self._join_config.enable:
|
| 132 |
+
# Notify joined ranks whether they should sync in backwards pass or not.
|
| 133 |
+
self._check_global_requires_backward_grad_sync(is_joined_rank=False)
|
| 134 |
+
|
| 135 |
+
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
|
| 136 |
+
if self.module.training:
|
| 137 |
+
output = self.module.training_step(*inputs[0], **kwargs[0])
|
| 138 |
+
elif self.module.testing:
|
| 139 |
+
output = self.module.test_step(*inputs[0], **kwargs[0])
|
| 140 |
+
else:
|
| 141 |
+
output = self.module.validation_step(*inputs[0], **kwargs[0])
|
| 142 |
+
|
| 143 |
+
# sync params according to location (before/after forward) user
|
| 144 |
+
# specified as part of hook, if hook was specified.
|
| 145 |
+
if self._check_sync_bufs_post_fwd():
|
| 146 |
+
self._sync_buffers()
|
| 147 |
+
|
| 148 |
+
if torch.is_grad_enabled() and self.require_backward_grad_sync:
|
| 149 |
+
self.require_forward_param_sync = True
|
| 150 |
+
# We'll return the output object verbatim since it is a freeform
|
| 151 |
+
# object. We need to find any tensors in this object, though,
|
| 152 |
+
# because we need to figure out which parameters were used during
|
| 153 |
+
# this forward pass, to ensure we short circuit reduction for any
|
| 154 |
+
# unused parameters. Only if `find_unused_parameters` is set.
|
| 155 |
+
if self.find_unused_parameters and not self.static_graph:
|
| 156 |
+
# Do not need to populate this for static graph.
|
| 157 |
+
self.reducer.prepare_for_backward(list(_find_tensors(output)))
|
| 158 |
+
else:
|
| 159 |
+
self.reducer.prepare_for_backward([])
|
| 160 |
+
else:
|
| 161 |
+
self.require_forward_param_sync = False
|
| 162 |
+
|
| 163 |
+
# TODO: DDPSink is currently enabled for unused parameter detection and
|
| 164 |
+
# static graph training for first iteration.
|
| 165 |
+
if (self.find_unused_parameters and not self.static_graph) or (
|
| 166 |
+
self.static_graph and self.num_iterations == 1
|
| 167 |
+
):
|
| 168 |
+
state_dict = {
|
| 169 |
+
"static_graph": self.static_graph,
|
| 170 |
+
"num_iterations": self.num_iterations,
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
output_tensor_list, treespec, output_is_rref = _tree_flatten_with_rref(output)
|
| 174 |
+
output_placeholders = [None for _ in range(len(output_tensor_list))]
|
| 175 |
+
# Do not touch tensors that have no grad_fn, which can cause issues
|
| 176 |
+
# such as https://github.com/pytorch/pytorch/issues/60733
|
| 177 |
+
for i, output in enumerate(output_tensor_list):
|
| 178 |
+
if torch.is_tensor(output) and output.grad_fn is None:
|
| 179 |
+
output_placeholders[i] = output
|
| 180 |
+
|
| 181 |
+
# When find_unused_parameters=True, makes tensors which require grad
|
| 182 |
+
# run through the DDPSink backward pass. When not all outputs are
|
| 183 |
+
# used in loss, this makes those corresponding tensors receive
|
| 184 |
+
# undefined gradient which the reducer then handles to ensure
|
| 185 |
+
# param.grad field is not touched and we don't error out.
|
| 186 |
+
passthrough_tensor_list = _DDPSink.apply(
|
| 187 |
+
self.reducer,
|
| 188 |
+
state_dict,
|
| 189 |
+
*output_tensor_list,
|
| 190 |
+
)
|
| 191 |
+
for i in range(len(output_placeholders)):
|
| 192 |
+
if output_placeholders[i] is None:
|
| 193 |
+
output_placeholders[i] = passthrough_tensor_list[i]
|
| 194 |
+
|
| 195 |
+
# Reconstruct output data structure.
|
| 196 |
+
output = _tree_unflatten_with_rref(output_placeholders, treespec, output_is_rref)
|
| 197 |
+
return output
|
higgs_audio/audio_processing/quantization/distrib.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Torch distributed utilities."""
|
| 8 |
+
|
| 9 |
+
import typing as tp
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def rank():
|
| 15 |
+
if torch.distributed.is_initialized():
|
| 16 |
+
return torch.distributed.get_rank()
|
| 17 |
+
else:
|
| 18 |
+
return 0
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def world_size():
|
| 22 |
+
if torch.distributed.is_initialized():
|
| 23 |
+
return torch.distributed.get_world_size()
|
| 24 |
+
else:
|
| 25 |
+
return 1
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def is_distributed():
|
| 29 |
+
return world_size() > 1
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
|
| 33 |
+
if is_distributed():
|
| 34 |
+
return torch.distributed.all_reduce(tensor, op)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _is_complex_or_float(tensor):
|
| 38 |
+
return torch.is_floating_point(tensor) or torch.is_complex(tensor)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def _check_number_of_params(params: tp.List[torch.Tensor]):
|
| 42 |
+
# utility function to check that the number of params in all workers is the same,
|
| 43 |
+
# and thus avoid a deadlock with distributed all reduce.
|
| 44 |
+
if not is_distributed() or not params:
|
| 45 |
+
return
|
| 46 |
+
# print('params[0].device ', params[0].device)
|
| 47 |
+
tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long)
|
| 48 |
+
all_reduce(tensor)
|
| 49 |
+
if tensor.item() != len(params) * world_size():
|
| 50 |
+
# If not all the workers have the same number, for at least one of them,
|
| 51 |
+
# this inequality will be verified.
|
| 52 |
+
raise RuntimeError(
|
| 53 |
+
f"Mismatch in number of params: ours is {len(params)}, at least one worker has a different one."
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0):
|
| 58 |
+
"""Broadcast the tensors from the given parameters to all workers.
|
| 59 |
+
This can be used to ensure that all workers have the same model to start with.
|
| 60 |
+
"""
|
| 61 |
+
if not is_distributed():
|
| 62 |
+
return
|
| 63 |
+
tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)]
|
| 64 |
+
_check_number_of_params(tensors)
|
| 65 |
+
handles = []
|
| 66 |
+
for tensor in tensors:
|
| 67 |
+
handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True)
|
| 68 |
+
handles.append(handle)
|
| 69 |
+
for handle in handles:
|
| 70 |
+
handle.wait()
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def sync_buffer(buffers, average=True):
|
| 74 |
+
"""
|
| 75 |
+
Sync grad for buffers. If average is False, broadcast instead of averaging.
|
| 76 |
+
"""
|
| 77 |
+
if not is_distributed():
|
| 78 |
+
return
|
| 79 |
+
handles = []
|
| 80 |
+
for buffer in buffers:
|
| 81 |
+
if torch.is_floating_point(buffer.data):
|
| 82 |
+
if average:
|
| 83 |
+
handle = torch.distributed.all_reduce(buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
|
| 84 |
+
else:
|
| 85 |
+
handle = torch.distributed.broadcast(buffer.data, src=0, async_op=True)
|
| 86 |
+
handles.append((buffer, handle))
|
| 87 |
+
for buffer, handle in handles:
|
| 88 |
+
handle.wait()
|
| 89 |
+
if average:
|
| 90 |
+
buffer.data /= world_size
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def sync_grad(params):
|
| 94 |
+
"""
|
| 95 |
+
Simpler alternative to DistributedDataParallel, that doesn't rely
|
| 96 |
+
on any black magic. For simple models it can also be as fast.
|
| 97 |
+
Just call this on your model parameters after the call to backward!
|
| 98 |
+
"""
|
| 99 |
+
if not is_distributed():
|
| 100 |
+
return
|
| 101 |
+
handles = []
|
| 102 |
+
for p in params:
|
| 103 |
+
if p.grad is not None:
|
| 104 |
+
handle = torch.distributed.all_reduce(p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
|
| 105 |
+
handles.append((p, handle))
|
| 106 |
+
for p, handle in handles:
|
| 107 |
+
handle.wait()
|
| 108 |
+
p.grad.data /= world_size()
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def average_metrics(metrics: tp.Dict[str, float], count=1.0):
|
| 112 |
+
"""Average a dictionary of metrics across all workers, using the optional
|
| 113 |
+
`count` as unormalized weight.
|
| 114 |
+
"""
|
| 115 |
+
if not is_distributed():
|
| 116 |
+
return metrics
|
| 117 |
+
keys, values = zip(*metrics.items())
|
| 118 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 119 |
+
tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32)
|
| 120 |
+
tensor *= count
|
| 121 |
+
all_reduce(tensor)
|
| 122 |
+
averaged = (tensor[:-1] / tensor[-1]).cpu().tolist()
|
| 123 |
+
return dict(zip(keys, averaged))
|
higgs_audio/audio_processing/quantization/vq.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Residual vector quantizer implementation."""
|
| 8 |
+
|
| 9 |
+
from dataclasses import dataclass, field
|
| 10 |
+
import math
|
| 11 |
+
import typing as tp
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
from torch import nn
|
| 15 |
+
|
| 16 |
+
# from .core_vq import ResidualVectorQuantization
|
| 17 |
+
from .core_vq_lsx_version import ResidualVectorQuantization
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@dataclass
|
| 21 |
+
class QuantizedResult:
|
| 22 |
+
quantized: torch.Tensor
|
| 23 |
+
codes: torch.Tensor
|
| 24 |
+
bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item.
|
| 25 |
+
penalty: tp.Optional[torch.Tensor] = None
|
| 26 |
+
metrics: dict = field(default_factory=dict)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class ResidualVectorQuantizer(nn.Module):
|
| 30 |
+
"""Residual Vector Quantizer.
|
| 31 |
+
Args:
|
| 32 |
+
dimension (int): Dimension of the codebooks.
|
| 33 |
+
n_q (int): Number of residual vector quantizers used.
|
| 34 |
+
bins (int): Codebook size.
|
| 35 |
+
decay (float): Decay for exponential moving average over the codebooks.
|
| 36 |
+
kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
|
| 37 |
+
kmeans_iters (int): Number of iterations used for kmeans initialization.
|
| 38 |
+
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
|
| 39 |
+
that have an exponential moving average cluster size less than the specified threshold with
|
| 40 |
+
randomly selected vector from the current batch.
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
def __init__(
|
| 44 |
+
self,
|
| 45 |
+
dimension: int = 256,
|
| 46 |
+
codebook_dim: int = None,
|
| 47 |
+
n_q: int = 8,
|
| 48 |
+
bins: int = 1024,
|
| 49 |
+
decay: float = 0.99,
|
| 50 |
+
kmeans_init: bool = True,
|
| 51 |
+
kmeans_iters: int = 50,
|
| 52 |
+
threshold_ema_dead_code: int = 2,
|
| 53 |
+
):
|
| 54 |
+
super().__init__()
|
| 55 |
+
self.n_q = n_q
|
| 56 |
+
self.dimension = dimension
|
| 57 |
+
self.codebook_dim = codebook_dim
|
| 58 |
+
self.bins = bins
|
| 59 |
+
self.decay = decay
|
| 60 |
+
self.kmeans_init = kmeans_init
|
| 61 |
+
self.kmeans_iters = kmeans_iters
|
| 62 |
+
self.threshold_ema_dead_code = threshold_ema_dead_code
|
| 63 |
+
self.vq = ResidualVectorQuantization(
|
| 64 |
+
dim=self.dimension,
|
| 65 |
+
codebook_dim=self.codebook_dim,
|
| 66 |
+
codebook_size=self.bins,
|
| 67 |
+
num_quantizers=self.n_q,
|
| 68 |
+
decay=self.decay,
|
| 69 |
+
kmeans_init=self.kmeans_init,
|
| 70 |
+
kmeans_iters=self.kmeans_iters,
|
| 71 |
+
threshold_ema_dead_code=self.threshold_ema_dead_code,
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
def forward(self, x: torch.Tensor, sample_rate: int, bandwidth: tp.Optional[float] = None): # -> QuantizedResult:
|
| 75 |
+
"""Residual vector quantization on the given input tensor.
|
| 76 |
+
Args:
|
| 77 |
+
x (torch.Tensor): Input tensor.
|
| 78 |
+
sample_rate (int): Sample rate of the input tensor.
|
| 79 |
+
bandwidth (float): Target bandwidth.
|
| 80 |
+
Returns:
|
| 81 |
+
QuantizedResult:
|
| 82 |
+
The quantized (or approximately quantized) representation with
|
| 83 |
+
the associated bandwidth and any penalty term for the loss.
|
| 84 |
+
"""
|
| 85 |
+
bw_per_q = self.get_bandwidth_per_quantizer(sample_rate)
|
| 86 |
+
n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth)
|
| 87 |
+
quantized, codes, commit_loss = self.vq(x, n_q=n_q)
|
| 88 |
+
bw = torch.tensor(n_q * bw_per_q).to(x)
|
| 89 |
+
return quantized, codes, bw, torch.mean(commit_loss)
|
| 90 |
+
# return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss))
|
| 91 |
+
|
| 92 |
+
def get_num_quantizers_for_bandwidth(self, sample_rate: int, bandwidth: tp.Optional[float] = None) -> int:
|
| 93 |
+
"""Return n_q based on specified target bandwidth."""
|
| 94 |
+
bw_per_q = self.get_bandwidth_per_quantizer(sample_rate)
|
| 95 |
+
n_q = self.n_q
|
| 96 |
+
if bandwidth and bandwidth > 0.0:
|
| 97 |
+
n_q = int(max(1, math.floor(bandwidth / bw_per_q)))
|
| 98 |
+
return n_q
|
| 99 |
+
|
| 100 |
+
def get_bandwidth_per_quantizer(self, sample_rate: int):
|
| 101 |
+
"""Return bandwidth per quantizer for a given input sample rate."""
|
| 102 |
+
return math.log2(self.bins) * sample_rate / 1000
|
| 103 |
+
|
| 104 |
+
def encode(self, x: torch.Tensor, sample_rate: int, bandwidth: tp.Optional[float] = None) -> torch.Tensor:
|
| 105 |
+
"""Encode a given input tensor with the specified sample rate at the given bandwidth.
|
| 106 |
+
The RVQ encode method sets the appropriate number of quantizer to use
|
| 107 |
+
and returns indices for each quantizer.
|
| 108 |
+
"""
|
| 109 |
+
n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth)
|
| 110 |
+
codes = self.vq.encode(x, n_q=n_q)
|
| 111 |
+
return codes
|
| 112 |
+
|
| 113 |
+
def decode(self, codes: torch.Tensor) -> torch.Tensor:
|
| 114 |
+
"""Decode the given codes to the quantized representation."""
|
| 115 |
+
quantized = self.vq.decode(codes)
|
| 116 |
+
return quantized
|
higgs_audio/audio_processing/semantic_module.py
ADDED
|
@@ -0,0 +1,310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Based on code from: https://github.com/zhenye234/xcodec
|
| 2 |
+
# Licensed under MIT License
|
| 3 |
+
# Modifications by BosonAI
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class Conv1d1x1(nn.Conv1d):
|
| 10 |
+
"""1x1 Conv1d."""
|
| 11 |
+
|
| 12 |
+
def __init__(self, in_channels, out_channels, bias=True):
|
| 13 |
+
super(Conv1d1x1, self).__init__(in_channels, out_channels, kernel_size=1, bias=bias)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class Conv1d(nn.Module):
|
| 17 |
+
def __init__(
|
| 18 |
+
self,
|
| 19 |
+
in_channels: int,
|
| 20 |
+
out_channels: int,
|
| 21 |
+
kernel_size: int,
|
| 22 |
+
stride: int = 1,
|
| 23 |
+
padding: int = -1,
|
| 24 |
+
dilation: int = 1,
|
| 25 |
+
groups: int = 1,
|
| 26 |
+
bias: bool = True,
|
| 27 |
+
):
|
| 28 |
+
super().__init__()
|
| 29 |
+
self.in_channels = in_channels
|
| 30 |
+
self.out_channels = out_channels
|
| 31 |
+
self.kernel_size = kernel_size
|
| 32 |
+
if padding < 0:
|
| 33 |
+
padding = (kernel_size - 1) // 2 * dilation
|
| 34 |
+
self.dilation = dilation
|
| 35 |
+
self.conv = nn.Conv1d(
|
| 36 |
+
in_channels=in_channels,
|
| 37 |
+
out_channels=out_channels,
|
| 38 |
+
kernel_size=kernel_size,
|
| 39 |
+
stride=stride,
|
| 40 |
+
padding=padding,
|
| 41 |
+
dilation=dilation,
|
| 42 |
+
groups=groups,
|
| 43 |
+
bias=bias,
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
def forward(self, x):
|
| 47 |
+
"""
|
| 48 |
+
Args:
|
| 49 |
+
x (Tensor): Float tensor variable with the shape (B, C, T).
|
| 50 |
+
Returns:
|
| 51 |
+
Tensor: Float tensor variable with the shape (B, C, T).
|
| 52 |
+
"""
|
| 53 |
+
x = self.conv(x)
|
| 54 |
+
return x
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class ResidualUnit(nn.Module):
|
| 58 |
+
def __init__(
|
| 59 |
+
self,
|
| 60 |
+
in_channels: int,
|
| 61 |
+
out_channels: int,
|
| 62 |
+
kernel_size=3,
|
| 63 |
+
dilation=1,
|
| 64 |
+
bias=False,
|
| 65 |
+
nonlinear_activation="ELU",
|
| 66 |
+
nonlinear_activation_params={},
|
| 67 |
+
):
|
| 68 |
+
super().__init__()
|
| 69 |
+
self.activation = getattr(nn, nonlinear_activation)(**nonlinear_activation_params)
|
| 70 |
+
self.conv1 = Conv1d(
|
| 71 |
+
in_channels=in_channels,
|
| 72 |
+
out_channels=out_channels,
|
| 73 |
+
kernel_size=kernel_size,
|
| 74 |
+
stride=1,
|
| 75 |
+
dilation=dilation,
|
| 76 |
+
bias=bias,
|
| 77 |
+
)
|
| 78 |
+
self.conv2 = Conv1d1x1(out_channels, out_channels, bias)
|
| 79 |
+
|
| 80 |
+
def forward(self, x):
|
| 81 |
+
y = self.conv1(self.activation(x))
|
| 82 |
+
y = self.conv2(self.activation(y))
|
| 83 |
+
return x + y
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class ConvTranspose1d(nn.Module):
|
| 87 |
+
def __init__(
|
| 88 |
+
self,
|
| 89 |
+
in_channels: int,
|
| 90 |
+
out_channels: int,
|
| 91 |
+
kernel_size: int,
|
| 92 |
+
stride: int,
|
| 93 |
+
padding=-1,
|
| 94 |
+
output_padding=-1,
|
| 95 |
+
groups=1,
|
| 96 |
+
bias=True,
|
| 97 |
+
):
|
| 98 |
+
super().__init__()
|
| 99 |
+
if padding < 0:
|
| 100 |
+
padding = (stride + 1) // 2
|
| 101 |
+
if output_padding < 0:
|
| 102 |
+
output_padding = 1 if stride % 2 else 0
|
| 103 |
+
self.deconv = nn.ConvTranspose1d(
|
| 104 |
+
in_channels=in_channels,
|
| 105 |
+
out_channels=out_channels,
|
| 106 |
+
kernel_size=kernel_size,
|
| 107 |
+
stride=stride,
|
| 108 |
+
padding=padding,
|
| 109 |
+
output_padding=output_padding,
|
| 110 |
+
groups=groups,
|
| 111 |
+
bias=bias,
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
def forward(self, x):
|
| 115 |
+
"""
|
| 116 |
+
Args:
|
| 117 |
+
x (Tensor): Float tensor variable with the shape (B, C, T).
|
| 118 |
+
Returns:
|
| 119 |
+
Tensor: Float tensor variable with the shape (B, C', T').
|
| 120 |
+
"""
|
| 121 |
+
x = self.deconv(x)
|
| 122 |
+
return x
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
class EncoderBlock(nn.Module):
|
| 126 |
+
def __init__(
|
| 127 |
+
self,
|
| 128 |
+
in_channels: int,
|
| 129 |
+
out_channels: int,
|
| 130 |
+
stride: int,
|
| 131 |
+
dilations=(1, 1),
|
| 132 |
+
unit_kernel_size=3,
|
| 133 |
+
bias=True,
|
| 134 |
+
):
|
| 135 |
+
super().__init__()
|
| 136 |
+
self.res_units = torch.nn.ModuleList()
|
| 137 |
+
for dilation in dilations:
|
| 138 |
+
self.res_units += [
|
| 139 |
+
ResidualUnit(
|
| 140 |
+
in_channels,
|
| 141 |
+
in_channels,
|
| 142 |
+
kernel_size=unit_kernel_size,
|
| 143 |
+
dilation=dilation,
|
| 144 |
+
)
|
| 145 |
+
]
|
| 146 |
+
self.num_res = len(self.res_units)
|
| 147 |
+
|
| 148 |
+
self.conv = Conv1d(
|
| 149 |
+
in_channels=in_channels,
|
| 150 |
+
out_channels=out_channels,
|
| 151 |
+
kernel_size=3 if stride == 1 else (2 * stride), # special case: stride=1, do not use kernel=2
|
| 152 |
+
stride=stride,
|
| 153 |
+
bias=bias,
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
def forward(self, x):
|
| 157 |
+
for idx in range(self.num_res):
|
| 158 |
+
x = self.res_units[idx](x)
|
| 159 |
+
x = self.conv(x)
|
| 160 |
+
return x
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
class Encoder(nn.Module):
|
| 164 |
+
def __init__(
|
| 165 |
+
self,
|
| 166 |
+
input_channels: int,
|
| 167 |
+
encode_channels: int,
|
| 168 |
+
channel_ratios=(1, 1),
|
| 169 |
+
strides=(1, 1),
|
| 170 |
+
kernel_size=3,
|
| 171 |
+
bias=True,
|
| 172 |
+
block_dilations=(1, 1),
|
| 173 |
+
unit_kernel_size=3,
|
| 174 |
+
):
|
| 175 |
+
super().__init__()
|
| 176 |
+
assert len(channel_ratios) == len(strides)
|
| 177 |
+
|
| 178 |
+
self.conv = Conv1d(
|
| 179 |
+
in_channels=input_channels,
|
| 180 |
+
out_channels=encode_channels,
|
| 181 |
+
kernel_size=kernel_size,
|
| 182 |
+
stride=1,
|
| 183 |
+
bias=False,
|
| 184 |
+
)
|
| 185 |
+
self.conv_blocks = torch.nn.ModuleList()
|
| 186 |
+
in_channels = encode_channels
|
| 187 |
+
for idx, stride in enumerate(strides):
|
| 188 |
+
out_channels = int(encode_channels * channel_ratios[idx]) # could be float
|
| 189 |
+
self.conv_blocks += [
|
| 190 |
+
EncoderBlock(
|
| 191 |
+
in_channels,
|
| 192 |
+
out_channels,
|
| 193 |
+
stride,
|
| 194 |
+
dilations=block_dilations,
|
| 195 |
+
unit_kernel_size=unit_kernel_size,
|
| 196 |
+
bias=bias,
|
| 197 |
+
)
|
| 198 |
+
]
|
| 199 |
+
in_channels = out_channels
|
| 200 |
+
self.num_blocks = len(self.conv_blocks)
|
| 201 |
+
self.out_channels = out_channels
|
| 202 |
+
|
| 203 |
+
def forward(self, x):
|
| 204 |
+
x = self.conv(x)
|
| 205 |
+
for i in range(self.num_blocks):
|
| 206 |
+
x = self.conv_blocks[i](x)
|
| 207 |
+
return x
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
class DecoderBlock(nn.Module):
|
| 211 |
+
"""Decoder block (no up-sampling)"""
|
| 212 |
+
|
| 213 |
+
def __init__(
|
| 214 |
+
self,
|
| 215 |
+
in_channels: int,
|
| 216 |
+
out_channels: int,
|
| 217 |
+
stride: int,
|
| 218 |
+
dilations=(1, 1),
|
| 219 |
+
unit_kernel_size=3,
|
| 220 |
+
bias=True,
|
| 221 |
+
):
|
| 222 |
+
super().__init__()
|
| 223 |
+
|
| 224 |
+
if stride == 1:
|
| 225 |
+
self.conv = Conv1d(
|
| 226 |
+
in_channels=in_channels,
|
| 227 |
+
out_channels=out_channels,
|
| 228 |
+
kernel_size=3, # fix kernel=3 when stride=1 for unchanged shape
|
| 229 |
+
stride=stride,
|
| 230 |
+
bias=bias,
|
| 231 |
+
)
|
| 232 |
+
else:
|
| 233 |
+
self.conv = ConvTranspose1d(
|
| 234 |
+
in_channels=in_channels,
|
| 235 |
+
out_channels=out_channels,
|
| 236 |
+
kernel_size=(2 * stride),
|
| 237 |
+
stride=stride,
|
| 238 |
+
bias=bias,
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
self.res_units = torch.nn.ModuleList()
|
| 242 |
+
for idx, dilation in enumerate(dilations):
|
| 243 |
+
self.res_units += [
|
| 244 |
+
ResidualUnit(
|
| 245 |
+
out_channels,
|
| 246 |
+
out_channels,
|
| 247 |
+
kernel_size=unit_kernel_size,
|
| 248 |
+
dilation=dilation,
|
| 249 |
+
)
|
| 250 |
+
]
|
| 251 |
+
self.num_res = len(self.res_units)
|
| 252 |
+
|
| 253 |
+
def forward(self, x):
|
| 254 |
+
x = self.conv(x)
|
| 255 |
+
for idx in range(self.num_res):
|
| 256 |
+
x = self.res_units[idx](x)
|
| 257 |
+
return x
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
class Decoder(nn.Module):
|
| 261 |
+
def __init__(
|
| 262 |
+
self,
|
| 263 |
+
code_dim: int,
|
| 264 |
+
output_channels: int,
|
| 265 |
+
decode_channels: int,
|
| 266 |
+
channel_ratios=(1, 1),
|
| 267 |
+
strides=(1, 1),
|
| 268 |
+
kernel_size=3,
|
| 269 |
+
bias=True,
|
| 270 |
+
block_dilations=(1, 1),
|
| 271 |
+
unit_kernel_size=3,
|
| 272 |
+
):
|
| 273 |
+
super().__init__()
|
| 274 |
+
assert len(channel_ratios) == len(strides)
|
| 275 |
+
|
| 276 |
+
self.conv1 = Conv1d(
|
| 277 |
+
in_channels=code_dim,
|
| 278 |
+
out_channels=int(decode_channels * channel_ratios[0]),
|
| 279 |
+
kernel_size=kernel_size,
|
| 280 |
+
stride=1,
|
| 281 |
+
bias=False,
|
| 282 |
+
)
|
| 283 |
+
|
| 284 |
+
self.conv_blocks = torch.nn.ModuleList()
|
| 285 |
+
for idx, stride in enumerate(strides):
|
| 286 |
+
in_channels = int(decode_channels * channel_ratios[idx])
|
| 287 |
+
if idx < (len(channel_ratios) - 1):
|
| 288 |
+
out_channels = int(decode_channels * channel_ratios[idx + 1])
|
| 289 |
+
else:
|
| 290 |
+
out_channels = decode_channels
|
| 291 |
+
self.conv_blocks += [
|
| 292 |
+
DecoderBlock(
|
| 293 |
+
in_channels,
|
| 294 |
+
out_channels,
|
| 295 |
+
stride,
|
| 296 |
+
dilations=block_dilations,
|
| 297 |
+
unit_kernel_size=unit_kernel_size,
|
| 298 |
+
bias=bias,
|
| 299 |
+
)
|
| 300 |
+
]
|
| 301 |
+
self.num_blocks = len(self.conv_blocks)
|
| 302 |
+
|
| 303 |
+
self.conv2 = Conv1d(out_channels, output_channels, kernel_size, 1, bias=False)
|
| 304 |
+
|
| 305 |
+
def forward(self, z):
|
| 306 |
+
x = self.conv1(z)
|
| 307 |
+
for i in range(self.num_blocks):
|
| 308 |
+
x = self.conv_blocks[i](x)
|
| 309 |
+
x = self.conv2(x)
|
| 310 |
+
return x
|
higgs_audio/constants.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
AUDIO_IN_TOKEN = "<|AUDIO|>"
|
| 2 |
+
AUDIO_OUT_TOKEN = "<|AUDIO_OUT|>"
|
| 3 |
+
EOS_TOKEN = "<|end_of_text|>"
|
higgs_audio/data_collator/__init__.py
ADDED
|
File without changes
|
higgs_audio/data_collator/higgs_audio_collator.py
ADDED
|
@@ -0,0 +1,583 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import librosa
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import math
|
| 5 |
+
import numpy as np
|
| 6 |
+
from typing import List, Tuple, Dict
|
| 7 |
+
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
from typing import List, Optional
|
| 10 |
+
from transformers.models.whisper.processing_whisper import WhisperProcessor
|
| 11 |
+
|
| 12 |
+
from ..dataset.chatml_dataset import ChatMLDatasetSample, RankedChatMLDatasetSampleTuple
|
| 13 |
+
from ..model.utils import build_delay_pattern_mask
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def _ceil_to_nearest(n, round_to):
|
| 17 |
+
return (n + round_to - 1) // round_to * round_to
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@dataclass
|
| 21 |
+
class HiggsAudioBatchInput:
|
| 22 |
+
input_ids: torch.LongTensor # shape (bsz, seq_len).
|
| 23 |
+
attention_mask: torch.Tensor # shape (bsz, seq_len).
|
| 24 |
+
audio_features: Optional[torch.Tensor] # shape (num_audio_in, feature_dim, max_mel_seq_len).
|
| 25 |
+
audio_feature_attention_mask: Optional[torch.Tensor] # shape (num_audio_in, max_mel_seq_len).
|
| 26 |
+
audio_out_ids: Optional[torch.LongTensor] # shape (num_codebooks, audio_out_total_length)
|
| 27 |
+
audio_out_ids_start: Optional[torch.LongTensor] # shape (num_audio_out,)
|
| 28 |
+
# The audio_out_ids_start_group_loc has the same length as audio_out_ids_start. It is used to recover group location in a batch for an audio segment
|
| 29 |
+
# Currently, we concatenante audio segments along dim 0 to handle variadic audio segment length. However, in the alignment stage, we need the location information
|
| 30 |
+
# For example,
|
| 31 |
+
# audio_out_ids_start = [0, 2, 4, 8]; and the first two audio segments come from the same sample in a batch, and other two come from different samples.
|
| 32 |
+
# This is a batch of 3 samples, then we will have the group location as:
|
| 33 |
+
# audio_out_ids_start_group_loc = [0, 0, 1, 2]
|
| 34 |
+
audio_out_ids_start_group_loc: Optional[
|
| 35 |
+
torch.LongTensor
|
| 36 |
+
] # shape (num_audio_out,), specify which a sample's group location in the batch
|
| 37 |
+
audio_in_ids: Optional[torch.LongTensor] # shape (num_codebooks, audio_in_total_length)
|
| 38 |
+
audio_in_ids_start: Optional[torch.LongTensor] # shape (num_audio_in,)
|
| 39 |
+
label_ids: Optional[torch.LongTensor] # shape (bsz, seq_len)
|
| 40 |
+
label_audio_ids: Optional[torch.LongTensor] # shape (num_codebooks, audio_out_total_length)
|
| 41 |
+
reward: Optional[float] = None
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class HiggsAudioSampleCollator:
|
| 45 |
+
"""Sample collator for Higgs-Audio model.
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
whisper_processor (WhisperProcessor): The whisper processor.
|
| 49 |
+
audio_in_token_id (int): The token id for audio-in.
|
| 50 |
+
audio_out_token_id (int): The token id for audio-out.
|
| 51 |
+
pad_token_id (int): The token id for padding.
|
| 52 |
+
audio_stream_bos_id (int): The token id for audio-stream beginning of sentence.
|
| 53 |
+
audio_stream_eos_id (int): The token id for audio-stream end of sentence.
|
| 54 |
+
round_to (int): The round-to value.
|
| 55 |
+
pad_left (bool): Whether to pad left.
|
| 56 |
+
return_audio_in_tokens (bool): Whether to return audio-in tokens.
|
| 57 |
+
use_delay_pattern (bool): Whether to use delay pattern.
|
| 58 |
+
disable_audio_codes_transform (bool): Whether to add bos and eos tokens to audio codes.
|
| 59 |
+
chunk_size_seconds (int): The chunk size in seconds.
|
| 60 |
+
add_new_bos_eos_for_long_chunk (bool): Whether to add new bos and eos tokens for long chunks.
|
| 61 |
+
mask_audio_out_token_label (bool): Whether to always mask the label associated with <|AUDIO_OUT|> token. Since we will always have `<|AUDIO_OUT|>` after `<|audio_bos|>`, we can safely mask <|AUDIO_OUT|>.
|
| 62 |
+
|
| 63 |
+
"""
|
| 64 |
+
|
| 65 |
+
def __init__(
|
| 66 |
+
self,
|
| 67 |
+
whisper_processor: WhisperProcessor,
|
| 68 |
+
audio_in_token_id,
|
| 69 |
+
audio_out_token_id,
|
| 70 |
+
pad_token_id,
|
| 71 |
+
audio_stream_bos_id,
|
| 72 |
+
audio_stream_eos_id,
|
| 73 |
+
round_to=8,
|
| 74 |
+
pad_left=False,
|
| 75 |
+
encode_whisper_embed=True,
|
| 76 |
+
return_audio_in_tokens=True,
|
| 77 |
+
audio_num_codebooks=None,
|
| 78 |
+
use_delay_pattern=False,
|
| 79 |
+
disable_audio_codes_transform=False,
|
| 80 |
+
chunk_size_seconds=30, # Maximum duration for each chunk
|
| 81 |
+
add_new_bos_eos_for_long_chunk=True,
|
| 82 |
+
mask_audio_out_token_label=True,
|
| 83 |
+
):
|
| 84 |
+
self.whisper_processor = whisper_processor
|
| 85 |
+
self.round_to = round_to
|
| 86 |
+
self.pad_left = pad_left
|
| 87 |
+
self.audio_in_token_id = audio_in_token_id
|
| 88 |
+
self.audio_out_token_id = audio_out_token_id
|
| 89 |
+
self.audio_stream_bos_id = audio_stream_bos_id
|
| 90 |
+
self.audio_stream_eos_id = audio_stream_eos_id
|
| 91 |
+
self.pad_token_id = pad_token_id
|
| 92 |
+
self.encode_whisper_embed = encode_whisper_embed
|
| 93 |
+
self.return_audio_in_tokens = return_audio_in_tokens
|
| 94 |
+
self.audio_num_codebooks = audio_num_codebooks
|
| 95 |
+
self.use_delay_pattern = use_delay_pattern
|
| 96 |
+
if encode_whisper_embed:
|
| 97 |
+
self.chunk_size_seconds = chunk_size_seconds
|
| 98 |
+
self.chunk_size_samples = int(chunk_size_seconds * whisper_processor.feature_extractor.sampling_rate)
|
| 99 |
+
else:
|
| 100 |
+
self.chunk_size_seconds = None
|
| 101 |
+
self.chunk_size_samples = None
|
| 102 |
+
self.disable_audio_codes_transform = disable_audio_codes_transform
|
| 103 |
+
self.add_new_bos_eos_for_long_chunk = add_new_bos_eos_for_long_chunk
|
| 104 |
+
self.mask_audio_out_token_label = mask_audio_out_token_label
|
| 105 |
+
|
| 106 |
+
def _process_and_duplicate_audio_tokens(
|
| 107 |
+
self,
|
| 108 |
+
input_ids: torch.Tensor,
|
| 109 |
+
audio_idx: int,
|
| 110 |
+
wv: torch.Tensor,
|
| 111 |
+
sr: int,
|
| 112 |
+
labels: Optional[torch.Tensor] = None,
|
| 113 |
+
) -> Tuple[torch.Tensor, torch.Tensor, int]:
|
| 114 |
+
"""Process long audio and duplicate corresponding audio tokens.
|
| 115 |
+
|
| 116 |
+
Args:
|
| 117 |
+
input_ids: Input token ids
|
| 118 |
+
audio_idx: Index of the audio token in the sequence
|
| 119 |
+
wv: Audio waveform
|
| 120 |
+
sr: Sample rate
|
| 121 |
+
labels: Optional label ids to be duplicated alongside input ids
|
| 122 |
+
|
| 123 |
+
Returns:
|
| 124 |
+
Tuple of:
|
| 125 |
+
- New input ids with duplicated audio tokens
|
| 126 |
+
- New label ids (if labels were provided) or None
|
| 127 |
+
- Number of chunks created
|
| 128 |
+
"""
|
| 129 |
+
# Calculate number of chunks needed
|
| 130 |
+
total_samples = len(wv)
|
| 131 |
+
num_chunks = math.ceil(total_samples / self.chunk_size_samples)
|
| 132 |
+
|
| 133 |
+
if num_chunks <= 1:
|
| 134 |
+
return input_ids, labels, 1
|
| 135 |
+
|
| 136 |
+
# Get the three tokens: <|audio_bos|><|AUDIO|><|audio_eos|>
|
| 137 |
+
audio_token_seq = input_ids[audio_idx - 1 : audio_idx + 2]
|
| 138 |
+
# Duplicate sequence for each chunk
|
| 139 |
+
duplicated_sequence = audio_token_seq.repeat(num_chunks)
|
| 140 |
+
|
| 141 |
+
# Create new input_ids with duplicated tokens
|
| 142 |
+
new_input_ids = torch.cat(
|
| 143 |
+
[
|
| 144 |
+
input_ids[: audio_idx - 1],
|
| 145 |
+
duplicated_sequence,
|
| 146 |
+
input_ids[audio_idx + 2 :],
|
| 147 |
+
]
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
# If labels are provided, duplicate them as well
|
| 151 |
+
new_labels = None
|
| 152 |
+
if labels is not None:
|
| 153 |
+
label_seq = labels[audio_idx - 1 : audio_idx + 2]
|
| 154 |
+
duplicated_labels = label_seq.repeat(num_chunks)
|
| 155 |
+
new_labels = torch.cat([labels[: audio_idx - 1], duplicated_labels, labels[audio_idx + 2 :]])
|
| 156 |
+
|
| 157 |
+
return new_input_ids, new_labels, num_chunks
|
| 158 |
+
|
| 159 |
+
def __call__(self, batch: List[ChatMLDatasetSample]):
|
| 160 |
+
"""Collate the input data with support for long audio processing."""
|
| 161 |
+
|
| 162 |
+
label_ids = None
|
| 163 |
+
label_audio_ids = None
|
| 164 |
+
if all([ele.label_ids is None for ele in batch]):
|
| 165 |
+
return_labels = False
|
| 166 |
+
else:
|
| 167 |
+
return_labels = True
|
| 168 |
+
|
| 169 |
+
if self.encode_whisper_embed:
|
| 170 |
+
# Process each sample in the batch to handle long audio
|
| 171 |
+
# TODO(?) The implementation here can be optimized.
|
| 172 |
+
processed_batch = []
|
| 173 |
+
for i in range(len(batch)):
|
| 174 |
+
sample = batch[i]
|
| 175 |
+
audio_in_mask = sample.input_ids == self.audio_in_token_id
|
| 176 |
+
audio_in_indices = torch.where(audio_in_mask)[0]
|
| 177 |
+
audio_out_mask = sample.input_ids == self.audio_out_token_id
|
| 178 |
+
|
| 179 |
+
# Process each audio token and duplicate if needed
|
| 180 |
+
modified_input_ids = sample.input_ids
|
| 181 |
+
modified_labels = sample.label_ids if return_labels else None
|
| 182 |
+
modified_waveforms_concat = []
|
| 183 |
+
modified_waveforms_start = []
|
| 184 |
+
modified_sample_rate = []
|
| 185 |
+
offset = 0 # Track position changes from duplicating tokens
|
| 186 |
+
curr_wv_offset = 0
|
| 187 |
+
|
| 188 |
+
# Process input audio tokens
|
| 189 |
+
for idx, audio_idx in enumerate(audio_in_indices):
|
| 190 |
+
# Get the audio for this token
|
| 191 |
+
wv, sr = sample.get_wv(idx) # Use idx since we want the original audio index
|
| 192 |
+
if sr != self.whisper_processor.feature_extractor.sampling_rate:
|
| 193 |
+
resampled_wv = librosa.resample(
|
| 194 |
+
wv.cpu().numpy(),
|
| 195 |
+
orig_sr=sr,
|
| 196 |
+
target_sr=self.whisper_processor.feature_extractor.sampling_rate,
|
| 197 |
+
)
|
| 198 |
+
else:
|
| 199 |
+
resampled_wv = wv.cpu().numpy()
|
| 200 |
+
wv = torch.tensor(resampled_wv, device=wv.device)
|
| 201 |
+
sr = self.whisper_processor.feature_extractor.sampling_rate
|
| 202 |
+
|
| 203 |
+
# Process and duplicate tokens if necessary
|
| 204 |
+
token_pos = audio_idx + offset
|
| 205 |
+
modified_input_ids, modified_labels, num_chunks = self._process_and_duplicate_audio_tokens(
|
| 206 |
+
modified_input_ids, token_pos, wv, sr, modified_labels
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
# Update audio data
|
| 210 |
+
for chunk_idx in range(num_chunks):
|
| 211 |
+
chunk_start = chunk_idx * self.chunk_size_samples
|
| 212 |
+
chunk_end = min((chunk_idx + 1) * self.chunk_size_samples, len(wv))
|
| 213 |
+
chunk_wv = wv[chunk_start:chunk_end]
|
| 214 |
+
modified_waveforms_concat.append(chunk_wv)
|
| 215 |
+
modified_waveforms_start.append(curr_wv_offset)
|
| 216 |
+
curr_wv_offset += len(chunk_wv)
|
| 217 |
+
modified_sample_rate.append(sr)
|
| 218 |
+
|
| 219 |
+
# Update offset for next iteration
|
| 220 |
+
offset += (num_chunks - 1) * 3 # Each new chunk adds 3 more tokens
|
| 221 |
+
|
| 222 |
+
# Create new sample with modified tokens and audio data
|
| 223 |
+
processed_sample = ChatMLDatasetSample(
|
| 224 |
+
input_ids=modified_input_ids,
|
| 225 |
+
label_ids=modified_labels if return_labels else sample.label_ids,
|
| 226 |
+
audio_ids_concat=sample.audio_ids_concat,
|
| 227 |
+
audio_ids_start=sample.audio_ids_start,
|
| 228 |
+
audio_waveforms_concat=torch.cat(modified_waveforms_concat)
|
| 229 |
+
if modified_waveforms_concat
|
| 230 |
+
else sample.audio_waveforms_concat,
|
| 231 |
+
audio_waveforms_start=torch.tensor(modified_waveforms_start, dtype=torch.long)
|
| 232 |
+
if modified_waveforms_start
|
| 233 |
+
else sample.audio_waveforms_start,
|
| 234 |
+
audio_sample_rate=torch.tensor(modified_sample_rate)
|
| 235 |
+
if modified_sample_rate
|
| 236 |
+
else sample.audio_sample_rate,
|
| 237 |
+
audio_speaker_indices=torch.tensor([]),
|
| 238 |
+
# FIXME(sxjscience): The logic here is not correct for audio_label_ids_concat.
|
| 239 |
+
audio_label_ids_concat=sample.audio_label_ids_concat,
|
| 240 |
+
)
|
| 241 |
+
# audio_in_chunk_len = len(torch.where(modified_input_ids == self.audio_in_token_id)[0])
|
| 242 |
+
# assert audio_in_chunk_len == processed_sample.num_audios(), f"Mismatch: audio_in_chunk_len={audio_in_chunk_len}, processed_sample.num_audios()={processed_sample.num_audios()}"
|
| 243 |
+
processed_batch.append(processed_sample)
|
| 244 |
+
else:
|
| 245 |
+
processed_batch = batch
|
| 246 |
+
|
| 247 |
+
# Get the max sequence length based on processed batch
|
| 248 |
+
max_seq_length = _ceil_to_nearest(max([len(sample.input_ids) for sample in processed_batch]), self.round_to)
|
| 249 |
+
|
| 250 |
+
# Get the ids for audio-in and audio-out for each batch
|
| 251 |
+
audio_in_wv_l = []
|
| 252 |
+
audio_in_ids_l = []
|
| 253 |
+
audio_out_ids_l = []
|
| 254 |
+
audio_out_ids_group_loc_l = []
|
| 255 |
+
audio_in_label_ids_l = None
|
| 256 |
+
audio_out_label_ids_l = None
|
| 257 |
+
reward_l = []
|
| 258 |
+
|
| 259 |
+
if return_labels:
|
| 260 |
+
audio_out_no_train_flag = [] # Whether the audio-out data should be trained on or not.
|
| 261 |
+
|
| 262 |
+
# Process the audio inputs and outputs
|
| 263 |
+
for i in range(len(processed_batch)):
|
| 264 |
+
audio_in_mask = processed_batch[i].input_ids == self.audio_in_token_id
|
| 265 |
+
audio_out_mask = processed_batch[i].input_ids == self.audio_out_token_id
|
| 266 |
+
audio_ids = torch.ones_like(processed_batch[i].input_ids)
|
| 267 |
+
audio_ids[audio_in_mask ^ audio_out_mask] = torch.cumsum(audio_ids[audio_in_mask ^ audio_out_mask], 0) - 1
|
| 268 |
+
audio_in_ids = audio_ids[audio_in_mask]
|
| 269 |
+
audio_out_ids = audio_ids[audio_out_mask]
|
| 270 |
+
|
| 271 |
+
if return_labels:
|
| 272 |
+
audio_out_no_train_flag.append(processed_batch[i].label_ids[audio_out_mask] < 0)
|
| 273 |
+
if self.mask_audio_out_token_label:
|
| 274 |
+
processed_batch[i].label_ids[audio_out_mask] = -100
|
| 275 |
+
|
| 276 |
+
# Process audio inputs
|
| 277 |
+
if self.return_audio_in_tokens:
|
| 278 |
+
audio_in_ids_l.extend(
|
| 279 |
+
[processed_batch[i].get_audio_codes(idx)[: self.audio_num_codebooks, :] for idx in audio_in_ids]
|
| 280 |
+
)
|
| 281 |
+
if processed_batch[i].audio_label_ids_concat is not None:
|
| 282 |
+
if audio_in_label_ids_l is None:
|
| 283 |
+
audio_in_label_ids_l = []
|
| 284 |
+
audio_in_label_ids_l.extend(
|
| 285 |
+
[
|
| 286 |
+
processed_batch[i].get_audio_codes_labels(idx)[: self.audio_num_codebooks, :]
|
| 287 |
+
for idx in audio_in_ids
|
| 288 |
+
]
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
audio_out_ids_l.extend(
|
| 292 |
+
[processed_batch[i].get_audio_codes(idx)[: self.audio_num_codebooks, :] for idx in audio_out_ids]
|
| 293 |
+
)
|
| 294 |
+
audio_out_ids_group_loc_l.append(i)
|
| 295 |
+
if processed_batch[i].reward is not None:
|
| 296 |
+
reward_l.append(processed_batch[i].reward)
|
| 297 |
+
|
| 298 |
+
if processed_batch[i].audio_label_ids_concat is not None:
|
| 299 |
+
if audio_out_label_ids_l is None:
|
| 300 |
+
audio_out_label_ids_l = []
|
| 301 |
+
audio_out_label_ids_l.extend(
|
| 302 |
+
[
|
| 303 |
+
processed_batch[i].get_audio_codes_labels(idx)[: self.audio_num_codebooks, :]
|
| 304 |
+
for idx in audio_out_ids
|
| 305 |
+
]
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
if self.encode_whisper_embed:
|
| 309 |
+
for idx in audio_in_ids:
|
| 310 |
+
wv, sr = processed_batch[i].get_wv(idx)
|
| 311 |
+
resampled_wv = wv.cpu().numpy()
|
| 312 |
+
# Split long audio into chunks
|
| 313 |
+
total_samples = len(resampled_wv)
|
| 314 |
+
for chunk_start in range(0, total_samples, self.chunk_size_samples):
|
| 315 |
+
chunk_end = min(chunk_start + self.chunk_size_samples, total_samples)
|
| 316 |
+
chunk = resampled_wv[chunk_start:chunk_end]
|
| 317 |
+
audio_in_wv_l.append(chunk)
|
| 318 |
+
# assert len(audio_in_wv_l) == processed_batch[i].num_audios(), \
|
| 319 |
+
# f"Assertion failed: Mismatch in number of audios. " \
|
| 320 |
+
# f"Expected {processed_batch[i].num_audios()}, but got {len(audio_in_wv_l)} at index {i}."
|
| 321 |
+
|
| 322 |
+
if return_labels:
|
| 323 |
+
audio_out_no_train_flag = torch.cat(audio_out_no_train_flag, dim=0)
|
| 324 |
+
|
| 325 |
+
# Process all audio features
|
| 326 |
+
if len(audio_in_wv_l) > 0:
|
| 327 |
+
feature_ret = self.whisper_processor.feature_extractor(
|
| 328 |
+
audio_in_wv_l,
|
| 329 |
+
sampling_rate=self.whisper_processor.feature_extractor.sampling_rate,
|
| 330 |
+
return_attention_mask=True,
|
| 331 |
+
padding="max_length",
|
| 332 |
+
)
|
| 333 |
+
audio_features = torch.from_numpy(feature_ret["input_features"])
|
| 334 |
+
audio_feature_attention_mask = torch.from_numpy(feature_ret["attention_mask"])
|
| 335 |
+
else:
|
| 336 |
+
if self.encode_whisper_embed:
|
| 337 |
+
audio_features = torch.zeros(
|
| 338 |
+
(
|
| 339 |
+
0,
|
| 340 |
+
self.whisper_processor.feature_extractor.feature_size,
|
| 341 |
+
self.whisper_processor.feature_extractor.nb_max_frames,
|
| 342 |
+
),
|
| 343 |
+
dtype=torch.float32,
|
| 344 |
+
)
|
| 345 |
+
audio_feature_attention_mask = torch.zeros(
|
| 346 |
+
(0, self.whisper_processor.feature_extractor.nb_max_frames),
|
| 347 |
+
dtype=torch.int32,
|
| 348 |
+
)
|
| 349 |
+
else:
|
| 350 |
+
audio_features = None
|
| 351 |
+
audio_feature_attention_mask = None
|
| 352 |
+
|
| 353 |
+
# Process audio input tokens
|
| 354 |
+
if len(audio_in_ids_l) > 0:
|
| 355 |
+
# Append audio-stream-bos and eos tokens
|
| 356 |
+
new_audio_in_ids_l = []
|
| 357 |
+
for ele in audio_in_ids_l:
|
| 358 |
+
if self.disable_audio_codes_transform:
|
| 359 |
+
# Do not add audio-stream-bos or eos tokens.
|
| 360 |
+
# This may indicate that the sample comes from ConstantLengthDatasetWithBuffer.
|
| 361 |
+
audio_codes = ele
|
| 362 |
+
else:
|
| 363 |
+
audio_codes = torch.cat(
|
| 364 |
+
[
|
| 365 |
+
torch.full(
|
| 366 |
+
(ele.shape[0], 1),
|
| 367 |
+
self.audio_stream_bos_id,
|
| 368 |
+
dtype=torch.long,
|
| 369 |
+
),
|
| 370 |
+
ele,
|
| 371 |
+
torch.full(
|
| 372 |
+
(ele.shape[0], 1),
|
| 373 |
+
self.audio_stream_eos_id,
|
| 374 |
+
dtype=torch.long,
|
| 375 |
+
),
|
| 376 |
+
],
|
| 377 |
+
dim=1,
|
| 378 |
+
)
|
| 379 |
+
if self.use_delay_pattern:
|
| 380 |
+
audio_codes = build_delay_pattern_mask(
|
| 381 |
+
audio_codes.unsqueeze(0),
|
| 382 |
+
bos_token_id=self.audio_stream_bos_id,
|
| 383 |
+
pad_token_id=self.audio_stream_eos_id,
|
| 384 |
+
)[0].squeeze(0)
|
| 385 |
+
new_audio_in_ids_l.append(audio_codes)
|
| 386 |
+
audio_in_ids = torch.cat(new_audio_in_ids_l, dim=1).long()
|
| 387 |
+
audio_in_ids_start = torch.cumsum(
|
| 388 |
+
torch.tensor([0] + [audio_codes.shape[1] for audio_codes in new_audio_in_ids_l[:-1]]),
|
| 389 |
+
dim=0,
|
| 390 |
+
)
|
| 391 |
+
else:
|
| 392 |
+
audio_in_ids = torch.zeros((0, 0), dtype=torch.long)
|
| 393 |
+
audio_in_ids_start = torch.zeros(0, dtype=torch.long)
|
| 394 |
+
|
| 395 |
+
# Process audio output tokens
|
| 396 |
+
audio_out_ids_start_group_loc = None
|
| 397 |
+
if len(audio_out_ids_l) > 0:
|
| 398 |
+
new_audio_out_ids_l = []
|
| 399 |
+
label_audio_ids_l = []
|
| 400 |
+
for idx, ele in enumerate(audio_out_ids_l):
|
| 401 |
+
if self.disable_audio_codes_transform:
|
| 402 |
+
# Do not add audio-stream-bos or eos tokens.
|
| 403 |
+
# This may indicate that the sample comes from ConstantLengthDatasetWithBuffer.
|
| 404 |
+
audio_codes = ele
|
| 405 |
+
if return_labels:
|
| 406 |
+
label_audio_ids = audio_out_label_ids_l[idx]
|
| 407 |
+
else:
|
| 408 |
+
audio_codes = torch.cat(
|
| 409 |
+
[
|
| 410 |
+
torch.full(
|
| 411 |
+
(ele.shape[0], 1),
|
| 412 |
+
self.audio_stream_bos_id,
|
| 413 |
+
dtype=torch.long,
|
| 414 |
+
),
|
| 415 |
+
ele,
|
| 416 |
+
torch.full(
|
| 417 |
+
(ele.shape[0], 1),
|
| 418 |
+
self.audio_stream_eos_id,
|
| 419 |
+
dtype=torch.long,
|
| 420 |
+
),
|
| 421 |
+
],
|
| 422 |
+
dim=1,
|
| 423 |
+
)
|
| 424 |
+
if return_labels:
|
| 425 |
+
label_audio_ids = torch.cat(
|
| 426 |
+
[
|
| 427 |
+
torch.full((ele.shape[0], 1), -100, dtype=torch.long),
|
| 428 |
+
ele,
|
| 429 |
+
torch.full(
|
| 430 |
+
(ele.shape[0], 1),
|
| 431 |
+
self.audio_stream_eos_id,
|
| 432 |
+
dtype=torch.long,
|
| 433 |
+
),
|
| 434 |
+
],
|
| 435 |
+
dim=1,
|
| 436 |
+
)
|
| 437 |
+
if self.use_delay_pattern:
|
| 438 |
+
audio_codes = build_delay_pattern_mask(
|
| 439 |
+
audio_codes.unsqueeze(0),
|
| 440 |
+
bos_token_id=self.audio_stream_bos_id,
|
| 441 |
+
pad_token_id=self.audio_stream_eos_id,
|
| 442 |
+
)[0].squeeze(0)
|
| 443 |
+
if return_labels:
|
| 444 |
+
label_audio_ids = build_delay_pattern_mask(
|
| 445 |
+
label_audio_ids.unsqueeze(0),
|
| 446 |
+
bos_token_id=-100,
|
| 447 |
+
pad_token_id=-100,
|
| 448 |
+
)[0].squeeze(0)
|
| 449 |
+
new_audio_out_ids_l.append(audio_codes)
|
| 450 |
+
|
| 451 |
+
if return_labels:
|
| 452 |
+
if audio_out_no_train_flag[idx]:
|
| 453 |
+
label_audio_ids[:] = -100
|
| 454 |
+
label_audio_ids_l.append(label_audio_ids)
|
| 455 |
+
|
| 456 |
+
audio_out_ids = torch.cat(new_audio_out_ids_l, dim=1).long()
|
| 457 |
+
if return_labels:
|
| 458 |
+
label_audio_ids = torch.cat(label_audio_ids_l, dim=1).long()
|
| 459 |
+
audio_out_ids_start = torch.cumsum(
|
| 460 |
+
torch.tensor([0] + [audio_codes.shape[1] for audio_codes in new_audio_out_ids_l[:-1]]),
|
| 461 |
+
dim=0,
|
| 462 |
+
)
|
| 463 |
+
audio_out_ids_start_group_loc = torch.tensor(audio_out_ids_group_loc_l, dtype=torch.long)
|
| 464 |
+
else:
|
| 465 |
+
audio_out_ids = torch.zeros((0, 0), dtype=torch.long)
|
| 466 |
+
audio_out_ids_start = torch.zeros(0, dtype=torch.long)
|
| 467 |
+
if return_labels:
|
| 468 |
+
label_audio_ids = torch.zeros((0, 0), dtype=torch.long)
|
| 469 |
+
|
| 470 |
+
reward = torch.tensor(reward_l, dtype=torch.float32)
|
| 471 |
+
|
| 472 |
+
# Handle padding for input ids and attention mask
|
| 473 |
+
if self.pad_left:
|
| 474 |
+
input_ids = torch.stack(
|
| 475 |
+
[
|
| 476 |
+
F.pad(
|
| 477 |
+
ele.input_ids,
|
| 478 |
+
(max_seq_length - len(ele.input_ids), 0),
|
| 479 |
+
value=self.pad_token_id,
|
| 480 |
+
)
|
| 481 |
+
for ele in processed_batch
|
| 482 |
+
]
|
| 483 |
+
)
|
| 484 |
+
if return_labels:
|
| 485 |
+
label_ids = torch.stack(
|
| 486 |
+
[
|
| 487 |
+
F.pad(
|
| 488 |
+
ele.label_ids,
|
| 489 |
+
(max_seq_length - len(ele.label_ids), 0),
|
| 490 |
+
value=-100,
|
| 491 |
+
)
|
| 492 |
+
for ele in processed_batch
|
| 493 |
+
]
|
| 494 |
+
)
|
| 495 |
+
attention_mask = torch.stack(
|
| 496 |
+
[
|
| 497 |
+
F.pad(
|
| 498 |
+
torch.ones_like(ele.input_ids),
|
| 499 |
+
(max_seq_length - len(ele.input_ids), 0),
|
| 500 |
+
value=0,
|
| 501 |
+
)
|
| 502 |
+
for ele in processed_batch
|
| 503 |
+
]
|
| 504 |
+
)
|
| 505 |
+
else:
|
| 506 |
+
input_ids = torch.stack(
|
| 507 |
+
[
|
| 508 |
+
F.pad(
|
| 509 |
+
ele.input_ids,
|
| 510 |
+
(0, max_seq_length - len(ele.input_ids)),
|
| 511 |
+
value=self.pad_token_id,
|
| 512 |
+
)
|
| 513 |
+
for ele in processed_batch
|
| 514 |
+
]
|
| 515 |
+
)
|
| 516 |
+
if return_labels:
|
| 517 |
+
label_ids = torch.stack(
|
| 518 |
+
[
|
| 519 |
+
F.pad(
|
| 520 |
+
ele.label_ids,
|
| 521 |
+
(0, max_seq_length - len(ele.label_ids)),
|
| 522 |
+
value=-100,
|
| 523 |
+
)
|
| 524 |
+
for ele in processed_batch
|
| 525 |
+
]
|
| 526 |
+
)
|
| 527 |
+
attention_mask = torch.stack(
|
| 528 |
+
[
|
| 529 |
+
F.pad(
|
| 530 |
+
torch.ones_like(ele.input_ids),
|
| 531 |
+
(0, max_seq_length - len(ele.input_ids)),
|
| 532 |
+
value=0,
|
| 533 |
+
)
|
| 534 |
+
for ele in processed_batch
|
| 535 |
+
]
|
| 536 |
+
)
|
| 537 |
+
|
| 538 |
+
if not self.return_audio_in_tokens:
|
| 539 |
+
audio_in_ids = None
|
| 540 |
+
audio_in_ids_start = None
|
| 541 |
+
|
| 542 |
+
# Apply audio_num_codebooks limit if specified
|
| 543 |
+
if self.audio_num_codebooks is not None:
|
| 544 |
+
if audio_in_ids is not None:
|
| 545 |
+
audio_in_ids = audio_in_ids[: self.audio_num_codebooks]
|
| 546 |
+
if audio_out_ids is not None:
|
| 547 |
+
audio_out_ids = audio_out_ids[: self.audio_num_codebooks]
|
| 548 |
+
if label_audio_ids is not None:
|
| 549 |
+
label_audio_ids = label_audio_ids[: self.audio_num_codebooks]
|
| 550 |
+
|
| 551 |
+
return HiggsAudioBatchInput(
|
| 552 |
+
input_ids=input_ids,
|
| 553 |
+
attention_mask=attention_mask,
|
| 554 |
+
audio_features=audio_features,
|
| 555 |
+
audio_feature_attention_mask=audio_feature_attention_mask,
|
| 556 |
+
audio_out_ids=audio_out_ids,
|
| 557 |
+
audio_out_ids_start=audio_out_ids_start,
|
| 558 |
+
audio_out_ids_start_group_loc=audio_out_ids_start_group_loc,
|
| 559 |
+
audio_in_ids=audio_in_ids,
|
| 560 |
+
audio_in_ids_start=audio_in_ids_start,
|
| 561 |
+
label_ids=label_ids,
|
| 562 |
+
label_audio_ids=label_audio_ids,
|
| 563 |
+
reward=reward,
|
| 564 |
+
)
|
| 565 |
+
|
| 566 |
+
|
| 567 |
+
class HiggsAudioDPOSamplesCollator(HiggsAudioSampleCollator):
|
| 568 |
+
def __init__(self, *args, **kwargs):
|
| 569 |
+
super().__init__(*args, **kwargs)
|
| 570 |
+
|
| 571 |
+
def __call__(self, batch: List[RankedChatMLDatasetSampleTuple]) -> HiggsAudioBatchInput:
|
| 572 |
+
# flatten ranked chatml samples
|
| 573 |
+
chosen = []
|
| 574 |
+
rejected = []
|
| 575 |
+
|
| 576 |
+
for sample in batch:
|
| 577 |
+
chosen.append(sample.max_score_sample())
|
| 578 |
+
rejected.append(sample.min_score_sample())
|
| 579 |
+
|
| 580 |
+
merged = chosen
|
| 581 |
+
merged.extend(rejected)
|
| 582 |
+
|
| 583 |
+
return super().__call__(batch=merged)
|
higgs_audio/data_types.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Basic data types for multimodal ChatML format."""
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Dict, List, Optional, Union
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@dataclass
|
| 8 |
+
class AudioContent:
|
| 9 |
+
audio_url: str
|
| 10 |
+
# Base64 encoded audio bytes
|
| 11 |
+
raw_audio: Optional[str] = None
|
| 12 |
+
offset: Optional[float] = None
|
| 13 |
+
duration: Optional[float] = None
|
| 14 |
+
row_id: Optional[int] = None
|
| 15 |
+
type: str = "audio"
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@dataclass
|
| 19 |
+
class TextContent:
|
| 20 |
+
text: str
|
| 21 |
+
type: str = "text"
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@dataclass
|
| 25 |
+
class Message:
|
| 26 |
+
role: str
|
| 27 |
+
content: Union[str, AudioContent, TextContent, List[Union[str, AudioContent, TextContent]]]
|
| 28 |
+
recipient: Optional[str] = None
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@dataclass
|
| 32 |
+
class ChatMLSample:
|
| 33 |
+
"""Dataclass to hold multimodal ChatML data."""
|
| 34 |
+
|
| 35 |
+
messages: List[Message]
|
| 36 |
+
start_index: Optional[int] = None # We will mask the messages[:start_index] when finetuning the LLM.
|
| 37 |
+
misc: Optional[Dict] = None
|
| 38 |
+
speaker: Optional[str] = None
|
higgs_audio/dataset/__init__.py
ADDED
|
File without changes
|
higgs_audio/dataset/chatml_dataset.py
ADDED
|
@@ -0,0 +1,554 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import dacite
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import torch
|
| 4 |
+
import json
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import multiprocessing as mp
|
| 8 |
+
|
| 9 |
+
from dataclasses import dataclass, fields
|
| 10 |
+
from abc import ABC, abstractmethod
|
| 11 |
+
from typing import Union, List, Dict, Optional
|
| 12 |
+
|
| 13 |
+
from ..data_types import ChatMLSample, TextContent, AudioContent
|
| 14 |
+
from ..constants import AUDIO_IN_TOKEN, AUDIO_OUT_TOKEN
|
| 15 |
+
|
| 16 |
+
from loguru import logger
|
| 17 |
+
|
| 18 |
+
# Whisper processor, 30 sec -> 3000 features
|
| 19 |
+
# Then we divide 4 in the audio towker, we decrease 3000 features to 750, which gives 25 Hz
|
| 20 |
+
WHISPER_EMBED_NUM_HIDDEN_STATE_PER_SEC = 25
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@dataclass
|
| 24 |
+
class ChatMLDatasetSample:
|
| 25 |
+
input_ids: torch.LongTensor # Shape (seq_len,): The input text tokens.
|
| 26 |
+
label_ids: torch.LongTensor # Shape (seq_len,): The label ids.
|
| 27 |
+
audio_ids_concat: torch.LongTensor # Shape (num_codebooks, audio_seq_len): The audio tokens that are concatenated.
|
| 28 |
+
# Here `audio_seq_len` is the length of the concatenated audio tokens.`
|
| 29 |
+
audio_ids_start: (
|
| 30 |
+
torch.LongTensor
|
| 31 |
+
) # Shape (num_audios,): The start index of each audio token in the concatenated audio tokens.
|
| 32 |
+
audio_waveforms_concat: (
|
| 33 |
+
torch.Tensor
|
| 34 |
+
) # Shape (total_wv_length,): The concatenated audio waveforms for audio-in features.
|
| 35 |
+
audio_waveforms_start: (
|
| 36 |
+
torch.LongTensor
|
| 37 |
+
) # Shape (num_audios,): The start index of each audio waveform in the concatenated audio waveforms.
|
| 38 |
+
audio_sample_rate: torch.Tensor # Shape (num_audios,): The sampling rate of the audio waveforms.
|
| 39 |
+
audio_speaker_indices: (
|
| 40 |
+
torch.LongTensor
|
| 41 |
+
) # Shape (num_audios,) -1 means unknown speaker: The speaker indices for each audio.
|
| 42 |
+
audio_label_ids_concat: Optional[torch.LongTensor] = (
|
| 43 |
+
None # Shape (num_codebooks, audio_seq_len): The audio tokens that are concatenated.
|
| 44 |
+
)
|
| 45 |
+
# Here `audio_seq_len` is the length of the concatenated audio tokens.`
|
| 46 |
+
reward: Optional[float] = None
|
| 47 |
+
|
| 48 |
+
def num_audios(self):
|
| 49 |
+
return max(len(self.audio_waveforms_start), len(self.audio_ids_start))
|
| 50 |
+
|
| 51 |
+
def get_audio_codes(self, idx):
|
| 52 |
+
code_start = self.audio_ids_start[idx]
|
| 53 |
+
if idx < len(self.audio_ids_start) - 1:
|
| 54 |
+
code_end = self.audio_ids_start[idx + 1]
|
| 55 |
+
else:
|
| 56 |
+
code_end = self.audio_ids_concat.shape[-1]
|
| 57 |
+
|
| 58 |
+
return self.audio_ids_concat[:, code_start:code_end]
|
| 59 |
+
|
| 60 |
+
def get_audio_codes_labels(self, idx):
|
| 61 |
+
if self.audio_label_ids_concat is None:
|
| 62 |
+
return None
|
| 63 |
+
code_start = self.audio_ids_start[idx]
|
| 64 |
+
if idx < len(self.audio_ids_start) - 1:
|
| 65 |
+
code_end = self.audio_ids_start[idx + 1]
|
| 66 |
+
else:
|
| 67 |
+
code_end = self.audio_ids_concat.shape[-1]
|
| 68 |
+
|
| 69 |
+
return self.audio_label_ids_concat[:, code_start:code_end]
|
| 70 |
+
|
| 71 |
+
def get_wv(self, idx):
|
| 72 |
+
wv_start = self.audio_waveforms_start[idx]
|
| 73 |
+
sr = self.audio_sample_rate[idx]
|
| 74 |
+
if idx < len(self.audio_waveforms_start) - 1:
|
| 75 |
+
wv_end = self.audio_waveforms_start[idx + 1]
|
| 76 |
+
else:
|
| 77 |
+
wv_end = self.audio_waveforms_concat.shape[-1]
|
| 78 |
+
return self.audio_waveforms_concat[wv_start:wv_end], sr
|
| 79 |
+
|
| 80 |
+
def cal_num_tokens(
|
| 81 |
+
self,
|
| 82 |
+
encode_whisper_embed: bool = True,
|
| 83 |
+
encode_audio_in_tokens: bool = False,
|
| 84 |
+
encode_audio_out_tokens: bool = True,
|
| 85 |
+
audio_in_token_id: int = 128015,
|
| 86 |
+
audio_out_token_id: int = 128016,
|
| 87 |
+
) -> int:
|
| 88 |
+
# we firstly exclude <|AUDIO|> and <|AUDIO_OUT|> because we do late merging and replace those position with actual audio features and audio token ids
|
| 89 |
+
# It's assumed that we always have audio_ids when audio_waveforms are there (but not vice-versa)
|
| 90 |
+
num_tokens = len(self.input_ids) - len(self.audio_ids_start)
|
| 91 |
+
|
| 92 |
+
if encode_whisper_embed and len(self.audio_waveforms_concat) > 0:
|
| 93 |
+
audio_lengths = torch.diff(self.audio_waveforms_start)
|
| 94 |
+
if len(audio_lengths):
|
| 95 |
+
# Sum before calling .item()
|
| 96 |
+
num_tokens += (
|
| 97 |
+
(
|
| 98 |
+
np.ceil(WHISPER_EMBED_NUM_HIDDEN_STATE_PER_SEC * audio_lengths / self.audio_sample_rate[:-1])
|
| 99 |
+
).sum()
|
| 100 |
+
).item()
|
| 101 |
+
# add the last audio's token estimation
|
| 102 |
+
num_tokens += (
|
| 103 |
+
np.ceil(
|
| 104 |
+
WHISPER_EMBED_NUM_HIDDEN_STATE_PER_SEC
|
| 105 |
+
* (self.audio_waveforms_concat.shape[0] - self.audio_waveforms_start[-1])
|
| 106 |
+
/ self.audio_sample_rate[-1]
|
| 107 |
+
)
|
| 108 |
+
).item()
|
| 109 |
+
|
| 110 |
+
if self.audio_ids_concat.size(1) > 0:
|
| 111 |
+
audio_io_ids = self.input_ids[
|
| 112 |
+
(self.input_ids == audio_in_token_id) | (self.input_ids == audio_out_token_id)
|
| 113 |
+
]
|
| 114 |
+
audio_io_id_lengths = torch.concat(
|
| 115 |
+
[
|
| 116 |
+
torch.diff(self.audio_ids_start),
|
| 117 |
+
torch.tensor([self.audio_ids_concat.shape[-1] - self.audio_ids_start[-1]]),
|
| 118 |
+
]
|
| 119 |
+
)
|
| 120 |
+
if encode_audio_in_tokens:
|
| 121 |
+
num_tokens += torch.sum(audio_io_id_lengths[audio_io_ids == audio_in_token_id]).item()
|
| 122 |
+
|
| 123 |
+
if encode_audio_out_tokens:
|
| 124 |
+
num_tokens += torch.sum(audio_io_id_lengths[audio_io_ids == audio_out_token_id]).item()
|
| 125 |
+
|
| 126 |
+
return int(num_tokens)
|
| 127 |
+
|
| 128 |
+
@classmethod
|
| 129 |
+
def merge(
|
| 130 |
+
cls,
|
| 131 |
+
samples: List["ChatMLDatasetSample"],
|
| 132 |
+
eos_token_id: int,
|
| 133 |
+
ignore_index: int,
|
| 134 |
+
padding_size: Optional[int] = None,
|
| 135 |
+
) -> "ChatMLDatasetSample":
|
| 136 |
+
"""Merges a list of ChatMLDatasetSample instances, inserting eos_token_id and ignore_index between them, and adjusting offsets for audio_ids_start and audio_waveforms_start.
|
| 137 |
+
|
| 138 |
+
Args:
|
| 139 |
+
samples (List[ChatMLDatasetSample]): List of samples to merge.
|
| 140 |
+
eos_token_id (int): Tokens to be inserted into input_ids between samples.
|
| 141 |
+
ignore_index (int): Default label for padding.
|
| 142 |
+
padding_size (Optional[int]): If provided, pad the sequence to with this length.
|
| 143 |
+
|
| 144 |
+
Returns:
|
| 145 |
+
ChatMLDatasetSample: Merged and potentially padded sample.
|
| 146 |
+
"""
|
| 147 |
+
if not samples:
|
| 148 |
+
logger.fatal("The samples list is empty and cannot be merged.")
|
| 149 |
+
raise ValueError("The samples list is empty and cannot be merged.")
|
| 150 |
+
|
| 151 |
+
# Initialize empty lists for concatenation
|
| 152 |
+
input_ids_list = []
|
| 153 |
+
label_ids_list = []
|
| 154 |
+
audio_ids_concat_list = []
|
| 155 |
+
audio_ids_start_list = []
|
| 156 |
+
audio_waveforms_concat_list = []
|
| 157 |
+
audio_waveforms_start_list = []
|
| 158 |
+
audio_sample_rate_list = []
|
| 159 |
+
audio_speaker_indices_list = []
|
| 160 |
+
|
| 161 |
+
# Track offsets
|
| 162 |
+
audio_ids_offset = 0
|
| 163 |
+
audio_waveforms_offset = 0
|
| 164 |
+
|
| 165 |
+
for sample in samples:
|
| 166 |
+
# Add input_ids and label_ids with padding
|
| 167 |
+
if input_ids_list:
|
| 168 |
+
input_ids_list.append(torch.tensor([eos_token_id], dtype=torch.long))
|
| 169 |
+
label_ids_list.append(torch.tensor([ignore_index], dtype=torch.long))
|
| 170 |
+
input_ids_list.append(sample.input_ids)
|
| 171 |
+
label_ids_list.append(sample.label_ids)
|
| 172 |
+
|
| 173 |
+
# Add audio_ids_concat and handle empty audio ids
|
| 174 |
+
if sample.audio_ids_concat.size(1) > 0:
|
| 175 |
+
audio_ids_concat_list.append(sample.audio_ids_concat)
|
| 176 |
+
|
| 177 |
+
# Offset and add audio_ids_start
|
| 178 |
+
audio_ids_start_list.append(sample.audio_ids_start + audio_ids_offset)
|
| 179 |
+
audio_ids_offset += sample.audio_ids_concat.size(
|
| 180 |
+
1
|
| 181 |
+
) # (num_codebooks, seq_len): Update offset by audio_seq_len
|
| 182 |
+
|
| 183 |
+
# Add audio_waveforms_concat
|
| 184 |
+
if sample.audio_waveforms_concat.size(0) > 0:
|
| 185 |
+
# Check dimensions of the audio waveform to ensure consistency
|
| 186 |
+
if (
|
| 187 |
+
audio_waveforms_concat_list
|
| 188 |
+
and sample.audio_waveforms_concat.dim() != audio_waveforms_concat_list[0].dim()
|
| 189 |
+
):
|
| 190 |
+
logger.warning(
|
| 191 |
+
f"Skipping audio waveform with inconsistent dimensions: expected {audio_waveforms_concat_list[0].dim()}D, got {sample.audio_waveforms_concat.dim()}D"
|
| 192 |
+
)
|
| 193 |
+
continue
|
| 194 |
+
|
| 195 |
+
audio_waveforms_concat_list.append(sample.audio_waveforms_concat)
|
| 196 |
+
audio_waveforms_start_list.append(sample.audio_waveforms_start + audio_waveforms_offset)
|
| 197 |
+
audio_waveforms_offset += sample.audio_waveforms_concat.size(0)
|
| 198 |
+
|
| 199 |
+
# Add audio_sample_rate and audio_speaker_indices
|
| 200 |
+
audio_sample_rate_list.append(sample.audio_sample_rate)
|
| 201 |
+
|
| 202 |
+
audio_speaker_indices_list.append(sample.audio_speaker_indices)
|
| 203 |
+
|
| 204 |
+
# Concatenate all tensors
|
| 205 |
+
input_ids = torch.cat(input_ids_list, dim=0)
|
| 206 |
+
label_ids = torch.cat(label_ids_list, dim=0)
|
| 207 |
+
|
| 208 |
+
# Apply padding if padding_size is specified
|
| 209 |
+
if padding_size is not None and padding_size > 0:
|
| 210 |
+
input_ids = torch.cat(
|
| 211 |
+
[
|
| 212 |
+
input_ids,
|
| 213 |
+
torch.full((padding_size,), eos_token_id, dtype=torch.long),
|
| 214 |
+
],
|
| 215 |
+
dim=0,
|
| 216 |
+
)
|
| 217 |
+
label_ids = torch.cat(
|
| 218 |
+
[
|
| 219 |
+
label_ids,
|
| 220 |
+
torch.full((padding_size,), ignore_index, dtype=torch.long),
|
| 221 |
+
],
|
| 222 |
+
dim=0,
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
# Safely concatenate audio tensors with proper error handling
|
| 226 |
+
try:
|
| 227 |
+
audio_ids_concat = torch.cat(audio_ids_concat_list, dim=1) if audio_ids_concat_list else torch.tensor([[]])
|
| 228 |
+
audio_ids_start = torch.cat(audio_ids_start_list, dim=0) if audio_ids_start_list else torch.tensor([])
|
| 229 |
+
|
| 230 |
+
# Check for dimensional consistency in audio waveforms
|
| 231 |
+
if audio_waveforms_concat_list:
|
| 232 |
+
dims = [t.dim() for t in audio_waveforms_concat_list]
|
| 233 |
+
if not all(d == dims[0] for d in dims):
|
| 234 |
+
# If dimensions don't match, log warning and filter out the problematic tensors
|
| 235 |
+
logger.warning(
|
| 236 |
+
f"Inconsistent dimensions in audio waveforms: {dims}. Filtering to keep only consistent ones."
|
| 237 |
+
)
|
| 238 |
+
expected_dim = max(set(dims), key=dims.count) # Most common dimension
|
| 239 |
+
audio_waveforms_concat_list = [t for t in audio_waveforms_concat_list if t.dim() == expected_dim]
|
| 240 |
+
|
| 241 |
+
# Recalculate audio_waveforms_start with the filtered list
|
| 242 |
+
if audio_waveforms_concat_list:
|
| 243 |
+
audio_waveforms_offset = 0
|
| 244 |
+
audio_waveforms_start_list = []
|
| 245 |
+
for waveform in audio_waveforms_concat_list:
|
| 246 |
+
audio_waveforms_start_list.append(torch.tensor([audio_waveforms_offset]))
|
| 247 |
+
audio_waveforms_offset += waveform.size(0)
|
| 248 |
+
|
| 249 |
+
audio_waveforms_concat = (
|
| 250 |
+
torch.cat(audio_waveforms_concat_list, dim=0) if audio_waveforms_concat_list else torch.tensor([])
|
| 251 |
+
)
|
| 252 |
+
audio_waveforms_start = (
|
| 253 |
+
torch.cat(audio_waveforms_start_list, dim=0) if audio_waveforms_start_list else torch.tensor([])
|
| 254 |
+
)
|
| 255 |
+
audio_sample_rate = (
|
| 256 |
+
torch.cat(audio_sample_rate_list, dim=0) if audio_sample_rate_list else torch.tensor([])
|
| 257 |
+
)
|
| 258 |
+
audio_speaker_indices = (
|
| 259 |
+
torch.cat(audio_speaker_indices_list, dim=0) if audio_speaker_indices_list else torch.tensor([])
|
| 260 |
+
)
|
| 261 |
+
|
| 262 |
+
except RuntimeError as e:
|
| 263 |
+
logger.error(f"Error during tensor concatenation: {str(e)}")
|
| 264 |
+
logger.warning("Falling back to empty audio tensors")
|
| 265 |
+
# Fall back to empty tensors
|
| 266 |
+
audio_ids_concat = torch.tensor([[]])
|
| 267 |
+
audio_ids_start = torch.tensor([])
|
| 268 |
+
audio_waveforms_concat = torch.tensor([])
|
| 269 |
+
audio_waveforms_start = torch.tensor([])
|
| 270 |
+
audio_sample_rate = torch.tensor([])
|
| 271 |
+
audio_speaker_indices = torch.tensor([])
|
| 272 |
+
|
| 273 |
+
# Create the merged sample
|
| 274 |
+
merged_sample = cls(
|
| 275 |
+
input_ids=input_ids,
|
| 276 |
+
label_ids=label_ids,
|
| 277 |
+
audio_ids_concat=audio_ids_concat,
|
| 278 |
+
audio_ids_start=audio_ids_start,
|
| 279 |
+
audio_waveforms_concat=audio_waveforms_concat,
|
| 280 |
+
audio_waveforms_start=audio_waveforms_start,
|
| 281 |
+
audio_sample_rate=audio_sample_rate,
|
| 282 |
+
audio_speaker_indices=audio_speaker_indices,
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
return merged_sample
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
@dataclass
|
| 289 |
+
class RankedChatMLDatasetSampleTuple:
|
| 290 |
+
samples: List[ChatMLDatasetSample]
|
| 291 |
+
scores: List[float]
|
| 292 |
+
|
| 293 |
+
def max_score_sample(self) -> ChatMLDatasetSample:
|
| 294 |
+
idx = self.scores.index(max(self.scores))
|
| 295 |
+
self.samples[idx].reward = self.scores[idx]
|
| 296 |
+
return self.samples[idx]
|
| 297 |
+
|
| 298 |
+
def min_score_sample(self) -> ChatMLDatasetSample:
|
| 299 |
+
idx = self.scores.index(min(self.scores))
|
| 300 |
+
self.samples[idx].reward = self.scores[idx]
|
| 301 |
+
return self.samples[idx]
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
@dataclass
|
| 305 |
+
class ChatMLDatasetStorageSample:
|
| 306 |
+
input_tokens: torch.LongTensor
|
| 307 |
+
label_tokens: torch.LongTensor
|
| 308 |
+
audio_bytes_cache_dir_index: int
|
| 309 |
+
audio_codes_cache_dir_index: int
|
| 310 |
+
audio_bytes_indices: torch.LongTensor
|
| 311 |
+
audio_codes_indices: torch.LongTensor
|
| 312 |
+
speaker_indices: torch.LongTensor
|
| 313 |
+
file_index: int
|
| 314 |
+
original_sample_index: int
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
# TODO(sxjscience): We need to revist the logic about parsing speaker ids.
|
| 318 |
+
# Currently, we assume that the speaker id is stored at the "misc" field in ChatMLSample.
|
| 319 |
+
def prepare_chatml_sample(sample: Union[ChatMLSample, Dict], tokenizer):
|
| 320 |
+
"""Preprocess the ChatML sample to get the tokens for the text part.
|
| 321 |
+
|
| 322 |
+
Args:
|
| 323 |
+
sample (ChatMLSample): The ChatML sample to preprocess.
|
| 324 |
+
tokenizer: The tokenizer to use for encoding the text.
|
| 325 |
+
|
| 326 |
+
"""
|
| 327 |
+
|
| 328 |
+
try:
|
| 329 |
+
if not isinstance(sample, ChatMLSample):
|
| 330 |
+
# Handle all fields that could be NaN
|
| 331 |
+
if "speaker" in sample and pd.isna(sample["speaker"]):
|
| 332 |
+
sample["speaker"] = None
|
| 333 |
+
if "start_index" in sample and pd.isna(sample["start_index"]):
|
| 334 |
+
sample["start_index"] = None
|
| 335 |
+
if "content" in sample and pd.isna(sample["content"]):
|
| 336 |
+
sample["content"] = ""
|
| 337 |
+
|
| 338 |
+
# Convert any other potential NaN values in nested structures
|
| 339 |
+
def convert_nan_to_none(obj):
|
| 340 |
+
import numpy as np
|
| 341 |
+
|
| 342 |
+
if isinstance(obj, (pd.Series, np.ndarray)):
|
| 343 |
+
return obj.tolist()
|
| 344 |
+
elif pd.api.types.is_scalar(obj) and pd.isna(obj):
|
| 345 |
+
return None
|
| 346 |
+
elif isinstance(obj, dict):
|
| 347 |
+
return {k: convert_nan_to_none(v) for k, v in obj.items()}
|
| 348 |
+
elif isinstance(obj, (list, tuple)): # Fixed: Handle both list and tuple
|
| 349 |
+
return [convert_nan_to_none(item) for item in obj]
|
| 350 |
+
return obj
|
| 351 |
+
|
| 352 |
+
# Clean the sample data
|
| 353 |
+
clean_sample = convert_nan_to_none(sample)
|
| 354 |
+
|
| 355 |
+
val_keys = []
|
| 356 |
+
for field in fields(ChatMLSample):
|
| 357 |
+
if field.name in clean_sample:
|
| 358 |
+
val_keys.append(field.name)
|
| 359 |
+
clean_sample = {k: clean_sample[k] for k in val_keys}
|
| 360 |
+
|
| 361 |
+
try:
|
| 362 |
+
sample = dacite.from_dict(
|
| 363 |
+
data_class=ChatMLSample,
|
| 364 |
+
data=clean_sample,
|
| 365 |
+
config=dacite.Config(strict=True, check_types=True),
|
| 366 |
+
)
|
| 367 |
+
except Exception as e:
|
| 368 |
+
print(f"Failed to convert to ChatMLSample: {e}")
|
| 369 |
+
print(f"Clean sample: {json.dumps(clean_sample, indent=2)}")
|
| 370 |
+
return None, None, None, None
|
| 371 |
+
|
| 372 |
+
input_tokens = []
|
| 373 |
+
label_tokens = []
|
| 374 |
+
audio_contents = []
|
| 375 |
+
speaker_id = None
|
| 376 |
+
if sample.speaker is not None:
|
| 377 |
+
speaker_id = sample.speaker
|
| 378 |
+
elif sample.misc is not None:
|
| 379 |
+
if "speaker" in sample.misc:
|
| 380 |
+
speaker_id = sample.misc["speaker"]
|
| 381 |
+
|
| 382 |
+
total_m = len(sample.messages)
|
| 383 |
+
for turn_id, message in enumerate(sample.messages):
|
| 384 |
+
role = message.role
|
| 385 |
+
recipient = message.recipient
|
| 386 |
+
content = message.content
|
| 387 |
+
content_l = []
|
| 388 |
+
|
| 389 |
+
if isinstance(content, str):
|
| 390 |
+
content_l.append(TextContent(text=content))
|
| 391 |
+
elif isinstance(content, TextContent):
|
| 392 |
+
content_l.append(content)
|
| 393 |
+
elif isinstance(content, AudioContent):
|
| 394 |
+
content_l.append(content)
|
| 395 |
+
elif isinstance(content, list):
|
| 396 |
+
for ele in content:
|
| 397 |
+
if isinstance(ele, str):
|
| 398 |
+
content_l.append(TextContent(text=ele))
|
| 399 |
+
else:
|
| 400 |
+
content_l.append(ele)
|
| 401 |
+
if turn_id == 0:
|
| 402 |
+
prefix = f"<|begin_of_text|><|start_header_id|>{role}<|end_header_id|>\n\n"
|
| 403 |
+
else:
|
| 404 |
+
prefix = f"<|start_header_id|>{role}<|end_header_id|>\n\n"
|
| 405 |
+
eot_postfix = "<|eot_id|>"
|
| 406 |
+
eom_postfix = "<|eom_id|>"
|
| 407 |
+
|
| 408 |
+
prefix_tokens = tokenizer.encode(prefix, add_special_tokens=False)
|
| 409 |
+
input_tokens.extend(prefix_tokens)
|
| 410 |
+
label_tokens.extend([-100 for _ in prefix_tokens])
|
| 411 |
+
|
| 412 |
+
if recipient:
|
| 413 |
+
assert role == "assistant", "Recipient is only available for assistant role."
|
| 414 |
+
recipient_tokens = tokenizer.encode(f"{recipient}<|recipient|>", add_special_tokens=False)
|
| 415 |
+
input_tokens.extend(recipient_tokens)
|
| 416 |
+
label_tokens.extend(recipient_tokens)
|
| 417 |
+
|
| 418 |
+
for content in content_l:
|
| 419 |
+
if content.type == "text":
|
| 420 |
+
text_tokens = tokenizer.encode(content.text, add_special_tokens=False)
|
| 421 |
+
input_tokens.extend(text_tokens)
|
| 422 |
+
if role == "assistant" and (sample.start_index is None or turn_id >= sample.start_index):
|
| 423 |
+
label_tokens.extend(text_tokens)
|
| 424 |
+
else:
|
| 425 |
+
label_tokens.extend([-100 for _ in text_tokens])
|
| 426 |
+
|
| 427 |
+
elif content.type == "audio":
|
| 428 |
+
# Generate the text-part of the audio tokens
|
| 429 |
+
audio_contents.append(content)
|
| 430 |
+
if role == "user" or role == "system":
|
| 431 |
+
# Add the text tokens
|
| 432 |
+
text_tokens = tokenizer.encode(
|
| 433 |
+
f"<|audio_bos|><|AUDIO|><|audio_eos|>",
|
| 434 |
+
add_special_tokens=False,
|
| 435 |
+
)
|
| 436 |
+
input_tokens.extend(text_tokens)
|
| 437 |
+
label_tokens.extend([-100 for _ in text_tokens])
|
| 438 |
+
elif role == "assistant":
|
| 439 |
+
# Add the text tokens for audio-out part.
|
| 440 |
+
text_tokens = tokenizer.encode(
|
| 441 |
+
f"<|audio_out_bos|><|AUDIO_OUT|><|audio_eos|>",
|
| 442 |
+
add_special_tokens=False,
|
| 443 |
+
)
|
| 444 |
+
input_tokens.extend(text_tokens)
|
| 445 |
+
if sample.start_index is None or turn_id >= sample.start_index:
|
| 446 |
+
label_tokens.extend(text_tokens)
|
| 447 |
+
else:
|
| 448 |
+
label_tokens.extend([-100 for _ in text_tokens])
|
| 449 |
+
next_id = turn_id + 1
|
| 450 |
+
if role == "assistant" and next_id != total_m and sample.messages[next_id].role == "assistant":
|
| 451 |
+
postfix_tokens = tokenizer.encode(eom_postfix, add_special_tokens=False)
|
| 452 |
+
input_tokens.extend(postfix_tokens)
|
| 453 |
+
else:
|
| 454 |
+
postfix_tokens = tokenizer.encode(eot_postfix, add_special_tokens=False)
|
| 455 |
+
input_tokens.extend(postfix_tokens)
|
| 456 |
+
if role == "assistant" and (sample.start_index is None or turn_id >= sample.start_index):
|
| 457 |
+
label_tokens.extend(postfix_tokens)
|
| 458 |
+
else:
|
| 459 |
+
label_tokens.extend([-100 for _ in postfix_tokens])
|
| 460 |
+
|
| 461 |
+
return input_tokens, label_tokens, audio_contents, speaker_id
|
| 462 |
+
|
| 463 |
+
except Exception as e:
|
| 464 |
+
print(f"Error in prepare_chatml_sample: {str(e)}")
|
| 465 |
+
print(f"Sample data: {json.dumps(sample, indent=2)}")
|
| 466 |
+
return None, None, None, None
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
def extract_generation_prompt_from_input_tokens(input_tokens, tokenizer):
|
| 470 |
+
"""Extract the generation prompt and reference answer from the input tokens.
|
| 471 |
+
|
| 472 |
+
For example:
|
| 473 |
+
|
| 474 |
+
Input Text = '<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n
|
| 475 |
+
What words do you hear from the provided audio? Write it down for me.<|audio_bos|><|AUDIO|><|audio_eos|><|eot_id|>
|
| 476 |
+
<|start_header_id|>assistant<|end_header_id|>\n\nAt first they went by quick, too quick to even get.<|eot_id|>'
|
| 477 |
+
|
| 478 |
+
-->
|
| 479 |
+
|
| 480 |
+
Prompt = '<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n
|
| 481 |
+
What words do you hear from the provided audio? Write it down for me.<|audio_bos|><|AUDIO|><|audio_eos|><|eot_id|>
|
| 482 |
+
<|start_header_id|>assistant<|end_header_id|>\n\n',
|
| 483 |
+
Reference = 'At first they went by quick, too quick to even get.'
|
| 484 |
+
|
| 485 |
+
Args:
|
| 486 |
+
input_tokens: The input tokens.
|
| 487 |
+
audio_contents: The audio contents.
|
| 488 |
+
tokenizer: The tokenizer to use for decoding the text.
|
| 489 |
+
|
| 490 |
+
Returns:
|
| 491 |
+
prompt_tokens: The tokens for the prompt.
|
| 492 |
+
reference_answer: The reference answer.
|
| 493 |
+
num_audios_in_reference: The number of audios in the reference answer.
|
| 494 |
+
|
| 495 |
+
"""
|
| 496 |
+
input_text = tokenizer.decode(input_tokens)
|
| 497 |
+
generation_prefix = "<|start_header_id|>assistant<|end_header_id|>\n\n"
|
| 498 |
+
postfix = "<|eot_id|>"
|
| 499 |
+
assert generation_prefix in input_text
|
| 500 |
+
generation_prompt_end_loc = input_text.rfind(generation_prefix) + len(generation_prefix)
|
| 501 |
+
generation_prompt = input_text[:generation_prompt_end_loc]
|
| 502 |
+
reference_answer = input_text[generation_prompt_end_loc : input_text.find(postfix, generation_prompt_end_loc)]
|
| 503 |
+
num_audios_in_reference = reference_answer.count(AUDIO_IN_TOKEN) + reference_answer.count(AUDIO_OUT_TOKEN)
|
| 504 |
+
return (
|
| 505 |
+
tokenizer.encode(generation_prompt, add_special_tokens=False),
|
| 506 |
+
reference_answer,
|
| 507 |
+
num_audios_in_reference,
|
| 508 |
+
)
|
| 509 |
+
|
| 510 |
+
|
| 511 |
+
def prepare_chatml_dataframe_single_process(df, tokenizer):
|
| 512 |
+
"""Prepare the ChatML DataFrame."""
|
| 513 |
+
ret = []
|
| 514 |
+
for _, row in df.iterrows():
|
| 515 |
+
input_tokens, label_tokens, audio_contents, speaker_id = prepare_chatml_sample(row.to_dict(), tokenizer)
|
| 516 |
+
ret.append((input_tokens, label_tokens, audio_contents, speaker_id))
|
| 517 |
+
return ret
|
| 518 |
+
|
| 519 |
+
|
| 520 |
+
def prepare_chatml_dataframe(df, tokenizer, num_process=16):
|
| 521 |
+
if num_process is None:
|
| 522 |
+
return prepare_chatml_dataframe_single_process(df, tokenizer)
|
| 523 |
+
else:
|
| 524 |
+
num_process = max(min(len(df) // 1000, num_process), 1)
|
| 525 |
+
workloads = np.array_split(df, num_process)
|
| 526 |
+
with mp.Pool(num_process) as pool:
|
| 527 |
+
ret = pool.starmap(
|
| 528 |
+
prepare_chatml_dataframe_single_process,
|
| 529 |
+
[(workload, tokenizer) for workload in workloads],
|
| 530 |
+
)
|
| 531 |
+
return sum(ret, [])
|
| 532 |
+
|
| 533 |
+
|
| 534 |
+
class DatasetInterface(ABC):
|
| 535 |
+
@abstractmethod
|
| 536 |
+
def __getitem__(self, idx) -> Union["ChatMLDatasetSample", "RankedChatMLDatasetSampleTuple"]:
|
| 537 |
+
"""Retrieve a dataset sample by index."""
|
| 538 |
+
raise NotImplementedError
|
| 539 |
+
|
| 540 |
+
|
| 541 |
+
class IterableDatasetInterface(ABC):
|
| 542 |
+
@abstractmethod
|
| 543 |
+
def __iter__(
|
| 544 |
+
self,
|
| 545 |
+
) -> Union["ChatMLDatasetSample", "RankedChatMLDatasetSampleTuple"]:
|
| 546 |
+
"""Retrieve a sample by iterating through the dataset."""
|
| 547 |
+
raise NotImplementedError
|
| 548 |
+
|
| 549 |
+
|
| 550 |
+
@dataclass
|
| 551 |
+
class DatasetInfo:
|
| 552 |
+
dataset_type: str
|
| 553 |
+
group_type: Optional[str] = None
|
| 554 |
+
mask_text: Optional[bool] = None # Whether to mask the text tokens for pretraining samples.
|
higgs_audio/model/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import AutoConfig, AutoModel
|
| 2 |
+
|
| 3 |
+
from .configuration_higgs_audio import HiggsAudioConfig, HiggsAudioEncoderConfig
|
| 4 |
+
from .modeling_higgs_audio import HiggsAudioModel
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
AutoConfig.register("higgs_audio_encoder", HiggsAudioEncoderConfig)
|
| 8 |
+
AutoConfig.register("higgs_audio", HiggsAudioConfig)
|
| 9 |
+
AutoModel.register(HiggsAudioConfig, HiggsAudioModel)
|
higgs_audio/model/audio_head.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Projector that maps hidden states from the LLM component to multimodal logits."""
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import Optional, Tuple
|
| 8 |
+
|
| 9 |
+
from .common import HiggsAudioPreTrainedModel
|
| 10 |
+
from .configuration_higgs_audio import HiggsAudioConfig
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@dataclass
|
| 14 |
+
class HiggsAudioDecoderLayerOutput:
|
| 15 |
+
logits: torch.FloatTensor
|
| 16 |
+
audio_logits: torch.FloatTensor
|
| 17 |
+
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
| 18 |
+
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class HiggsAudioDecoderProjector(HiggsAudioPreTrainedModel):
|
| 22 |
+
"""Projection layers that map hidden states from the LLM component to audio / text logits.
|
| 23 |
+
|
| 24 |
+
We support two type of audio head:
|
| 25 |
+
- Basic Audio Head:
|
| 26 |
+
Directly map the hidden states to audio logits for all the codebooks.
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
def __init__(self, config: HiggsAudioConfig, layer_idx: Optional[int] = None):
|
| 30 |
+
super().__init__(config)
|
| 31 |
+
self.text_lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
|
| 32 |
+
self.audio_lm_head = nn.Linear(
|
| 33 |
+
config.text_config.hidden_size,
|
| 34 |
+
config.audio_num_codebooks * (config.audio_codebook_size + 2),
|
| 35 |
+
bias=False,
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
# Initialize weights and apply final processing
|
| 39 |
+
self.post_init()
|
| 40 |
+
|
| 41 |
+
def forward(
|
| 42 |
+
self,
|
| 43 |
+
hidden_states,
|
| 44 |
+
audio_out_mask,
|
| 45 |
+
label_audio_ids=None,
|
| 46 |
+
attention_mask=None,
|
| 47 |
+
position_ids=None,
|
| 48 |
+
past_key_values=None,
|
| 49 |
+
use_cache=None,
|
| 50 |
+
output_attentions=None,
|
| 51 |
+
output_hidden_states=None,
|
| 52 |
+
output_audio_hidden_states=False,
|
| 53 |
+
cache_position=None,
|
| 54 |
+
):
|
| 55 |
+
"""
|
| 56 |
+
Args:
|
| 57 |
+
hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_size)`):
|
| 58 |
+
Hidden states from the LLM component
|
| 59 |
+
audio_out_mask (`torch.Tensor` of shape `(batch_size, seq_len)`):
|
| 60 |
+
Mask for identifying the audio out tokens.
|
| 61 |
+
label_audio_ids (`torch.Tensor` of shape `(num_codebooks, num_audio_out_tokens)`):
|
| 62 |
+
Label tokens for the audio-out part. This is used for calculating the logits if RQ-Transformer is used.
|
| 63 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, seq_len)`):
|
| 64 |
+
Mask to avoid performing attention on padding token indices
|
| 65 |
+
position_ids (`torch.Tensor` of shape `(batch_size, seq_len)`):
|
| 66 |
+
Position ids for the input tokens
|
| 67 |
+
|
| 68 |
+
Returns:
|
| 69 |
+
logits (`torch.Tensor` of shape `(batch_size, seq_len, vocab_size)`):
|
| 70 |
+
Logits for text tokens
|
| 71 |
+
audio_logits (`torch.Tensor` of shape `(num_audio_out_tokens, audio_num_codebooks * audio_codebook_size)`):
|
| 72 |
+
Logits for audio tokens. We ensure `num_text_tokens + num_audio_tokens == batch_size * seq_len`
|
| 73 |
+
"""
|
| 74 |
+
logits = self.text_lm_head(hidden_states)
|
| 75 |
+
|
| 76 |
+
all_hidden_states = () if output_hidden_states else None
|
| 77 |
+
all_self_attns = () if output_attentions else None
|
| 78 |
+
next_decoder_cache = None
|
| 79 |
+
|
| 80 |
+
# TODO(sxjscience) Need to check if DeepSpeed Zero3 supports zero-shape input.
|
| 81 |
+
if self.config.audio_decoder_proj_num_layers > 0:
|
| 82 |
+
# create position embeddings to be shared across the decoder layers
|
| 83 |
+
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
| 84 |
+
for decoder_layer in self.transformer_layers:
|
| 85 |
+
if output_hidden_states:
|
| 86 |
+
all_hidden_states += (hidden_states,)
|
| 87 |
+
|
| 88 |
+
if self.gradient_checkpointing and self.training:
|
| 89 |
+
layer_outputs = self._gradient_checkpointing_func(
|
| 90 |
+
decoder_layer.__call__,
|
| 91 |
+
hidden_states,
|
| 92 |
+
attention_mask,
|
| 93 |
+
position_ids,
|
| 94 |
+
past_key_values,
|
| 95 |
+
output_attentions,
|
| 96 |
+
use_cache,
|
| 97 |
+
cache_position,
|
| 98 |
+
position_embeddings,
|
| 99 |
+
)
|
| 100 |
+
else:
|
| 101 |
+
layer_outputs = decoder_layer(
|
| 102 |
+
hidden_states,
|
| 103 |
+
attention_mask=attention_mask,
|
| 104 |
+
position_ids=position_ids,
|
| 105 |
+
past_key_value=past_key_values,
|
| 106 |
+
output_attentions=output_attentions,
|
| 107 |
+
use_cache=use_cache,
|
| 108 |
+
cache_position=cache_position,
|
| 109 |
+
position_embeddings=position_embeddings,
|
| 110 |
+
)
|
| 111 |
+
hidden_states = layer_outputs[0]
|
| 112 |
+
hidden_states = self.norm(hidden_states)
|
| 113 |
+
|
| 114 |
+
if output_hidden_states:
|
| 115 |
+
all_hidden_states += (hidden_states,)
|
| 116 |
+
|
| 117 |
+
if output_attentions:
|
| 118 |
+
all_self_attns += (layer_outputs[1],)
|
| 119 |
+
|
| 120 |
+
if use_cache:
|
| 121 |
+
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
| 122 |
+
|
| 123 |
+
next_cache = next_decoder_cache if use_cache else None
|
| 124 |
+
|
| 125 |
+
audio_logits = self.audio_lm_head(hidden_states[audio_out_mask])
|
| 126 |
+
|
| 127 |
+
if output_audio_hidden_states:
|
| 128 |
+
audio_hidden_states = hidden_states[audio_out_mask]
|
| 129 |
+
else:
|
| 130 |
+
audio_hidden_states = None
|
| 131 |
+
|
| 132 |
+
return (
|
| 133 |
+
logits,
|
| 134 |
+
audio_logits,
|
| 135 |
+
all_self_attns,
|
| 136 |
+
all_hidden_states,
|
| 137 |
+
audio_hidden_states,
|
| 138 |
+
next_cache,
|
| 139 |
+
)
|
higgs_audio/model/common.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch import nn
|
| 2 |
+
|
| 3 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 4 |
+
|
| 5 |
+
from .configuration_higgs_audio import HiggsAudioConfig
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class HiggsAudioPreTrainedModel(PreTrainedModel):
|
| 9 |
+
config_class = HiggsAudioConfig
|
| 10 |
+
base_model_prefix = "model"
|
| 11 |
+
supports_gradient_checkpointing = True
|
| 12 |
+
_no_split_modules = []
|
| 13 |
+
_skip_keys_device_placement = "past_key_values"
|
| 14 |
+
_supports_flash_attn_2 = True
|
| 15 |
+
_supports_sdpa = True
|
| 16 |
+
|
| 17 |
+
def _init_weights(self, module):
|
| 18 |
+
std = self.config.init_std if hasattr(self.config, "init_std") else self.config.audio_encoder_config.init_std
|
| 19 |
+
|
| 20 |
+
if isinstance(module, (nn.Linear, nn.Conv1d)):
|
| 21 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 22 |
+
if module.bias is not None:
|
| 23 |
+
module.bias.data.zero_()
|
| 24 |
+
elif isinstance(module, nn.Embedding):
|
| 25 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 26 |
+
if module.padding_idx is not None:
|
| 27 |
+
module.weight.data[module.padding_idx].zero_()
|
higgs_audio/model/configuration_higgs_audio.py
ADDED
|
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 2 |
+
from transformers.models.auto import CONFIG_MAPPING
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class HiggsAudioEncoderConfig(PretrainedConfig):
|
| 6 |
+
"""Configuration of the Audio encoder in Higgs-Audio."""
|
| 7 |
+
|
| 8 |
+
model_type = "higgs_audio_encoder"
|
| 9 |
+
|
| 10 |
+
def __init__(
|
| 11 |
+
self,
|
| 12 |
+
num_mel_bins=128,
|
| 13 |
+
encoder_layers=32,
|
| 14 |
+
encoder_attention_heads=20,
|
| 15 |
+
encoder_ffn_dim=5120,
|
| 16 |
+
encoder_layerdrop=0.0,
|
| 17 |
+
d_model=1280,
|
| 18 |
+
dropout=0.0,
|
| 19 |
+
attention_dropout=0.0,
|
| 20 |
+
activation_function="gelu",
|
| 21 |
+
activation_dropout=0.0,
|
| 22 |
+
scale_embedding=False,
|
| 23 |
+
init_std=0.02,
|
| 24 |
+
max_source_positions=1500,
|
| 25 |
+
pad_token_id=128001,
|
| 26 |
+
**kwargs,
|
| 27 |
+
):
|
| 28 |
+
super().__init__(**kwargs)
|
| 29 |
+
|
| 30 |
+
self.num_mel_bins = num_mel_bins
|
| 31 |
+
self.d_model = d_model
|
| 32 |
+
self.encoder_layers = encoder_layers
|
| 33 |
+
self.encoder_attention_heads = encoder_attention_heads
|
| 34 |
+
self.encoder_ffn_dim = encoder_ffn_dim
|
| 35 |
+
self.dropout = dropout
|
| 36 |
+
self.attention_dropout = attention_dropout
|
| 37 |
+
self.activation_function = activation_function
|
| 38 |
+
self.activation_dropout = activation_dropout
|
| 39 |
+
self.encoder_layerdrop = encoder_layerdrop
|
| 40 |
+
self.num_hidden_layers = encoder_layers
|
| 41 |
+
self.init_std = init_std
|
| 42 |
+
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
|
| 43 |
+
self.max_source_positions = max_source_positions
|
| 44 |
+
self.pad_token_id = pad_token_id
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class HiggsAudioConfig(PretrainedConfig):
|
| 48 |
+
r"""
|
| 49 |
+
This is the configuration class for the HiggsAudioModel.
|
| 50 |
+
|
| 51 |
+
Args:
|
| 52 |
+
text_config (`Union[AutoConfig, dict]`):
|
| 53 |
+
The config object or dictionary of the text backbone.
|
| 54 |
+
audio_encoder_config (`Union[AutoConfig, dict]`):
|
| 55 |
+
The config object or dictionary of the whisper encoder.
|
| 56 |
+
The audio encoder will be bidirectional and will be only available for audio understanding.
|
| 57 |
+
audio_tokenizer_config
|
| 58 |
+
The config object or dictionary of the audio tokenizer.
|
| 59 |
+
audio_adapter_type
|
| 60 |
+
The type of audio adapter to use. We support two types of adapter:
|
| 61 |
+
- stack:
|
| 62 |
+
We stack additional Transformer layers after the main LLM backbone for audio generation.
|
| 63 |
+
- dual_ffn:
|
| 64 |
+
For selected part of the LLM backbone, we replace the text FFN with a dual FFN architecture
|
| 65 |
+
that contains an additional audio FFN. The audio FFN will be triggered when the location is marked for audio tokens.
|
| 66 |
+
- dual_ffn_fast_forward:
|
| 67 |
+
We pick a few layers in the LLM backbone to plug-in the audio FFN. For the remaining layers,
|
| 68 |
+
the audio hidden states will be directly fast-forward to the next layer.
|
| 69 |
+
This reduces the computational cost for audio generation.
|
| 70 |
+
audio_embed_avg (`bool`, *optional*, defaults to False):
|
| 71 |
+
Whether to average the audio embeddings before sending them to the text attention layer.
|
| 72 |
+
audio_ffn_hidden_size
|
| 73 |
+
The hidden size of the audio feedforward network in dual-path FFN
|
| 74 |
+
audio_ffn_intermediate_size
|
| 75 |
+
The intermediate size of the audio feedforward network in dual-path FFN
|
| 76 |
+
audio_dual_ffn_layers
|
| 77 |
+
The layers in the LLM backbone to plug-in the dual FFN layer (mixture of audio FFN and text FFN).
|
| 78 |
+
audio_decoder_proj_num_attention (`int`, *optional*, defaults to 0):
|
| 79 |
+
The number of attention heads in the audio decoder projection layer.
|
| 80 |
+
use_delay_pattern (`bool`, *optional*, defaults to False):
|
| 81 |
+
Whether to use delay pattern in the audio decoder.
|
| 82 |
+
skip_audio_tower (`bool`, *optional*, defaults to False):
|
| 83 |
+
Whether to skip the audio tower in the audio encoder.
|
| 84 |
+
use_audio_out_embed_projector (`bool`, *optional*, defaults to False):
|
| 85 |
+
Whether to use an embedding projector to map audio out embeddings.
|
| 86 |
+
use_audio_out_self_attention (`bool`, *optional*, defaults to False):
|
| 87 |
+
Whether to use self-attention to aggregate information from audio-tokens before sending to the text attention layer.
|
| 88 |
+
audio_num_codebooks (`int`, *optional*, defaults to 12):
|
| 89 |
+
The number of codebooks in RVQGAN.
|
| 90 |
+
audio_codebook_size (`int`, *optional*, defaults to 1024):
|
| 91 |
+
The size of each codebook in RVQGAN.
|
| 92 |
+
audio_stream_bos_id
|
| 93 |
+
The id of the bos in the audio stream
|
| 94 |
+
audio_stream_eos_id
|
| 95 |
+
The id of the eos in the audio stream
|
| 96 |
+
audio_bos_token (`str`, *optional*, defaults to "<|audio_bos|>"):
|
| 97 |
+
The special `<|audio_bos|>` token. In Higgs-Audio, it is mapped to 128011,
|
| 98 |
+
which is the index of `<|reserved_special_token_3|>` in Llama-3.1-8B-Instruct's tokenizer.
|
| 99 |
+
audio_eos_token (`str`, *optional*, defaults to "<|audio_eos|>"):
|
| 100 |
+
The special `<|audio_eos|>` token. We use 128012 as the default value,
|
| 101 |
+
which is the index of `<|reserved_special_token_4|>` in Llama-3.1-8B-Instruct's tokenizer.
|
| 102 |
+
audio_out_bos_token (`str`, *optional*, defaults to "<|audio_out_bos|>"):
|
| 103 |
+
The special `<|audio_out_bos|>` token. We use 128013 as the default value,
|
| 104 |
+
which is the index of `<|reserved_special_token_5|>` in Llama-3.1-8B-Instruct's tokenizer.
|
| 105 |
+
audio_token (`str`, *optional*, defaults to "<|AUDIO|>"):
|
| 106 |
+
The special `<|AUDIO|>` token. We use 128015 as the default value,
|
| 107 |
+
which is the index of `<|reserved_special_token_7|>` in Llama-3.1-8B-Instruct's tokenizer.
|
| 108 |
+
This token indicates that the location should be filled in with whisper features.
|
| 109 |
+
audio_out_token (`str`, *optional*, defaults to "<|AUDIO_OUT|>"):
|
| 110 |
+
The special `<|AUDIO_OUT|>` token. We use 128016 as the default value,
|
| 111 |
+
which is the index of `<|reserved_special_token_8|>` in Llama-3.1-8B-Instruct's tokenizer.
|
| 112 |
+
This token indicates that the location should be filled in with audio tokens extracted via audio tokenizer.
|
| 113 |
+
"""
|
| 114 |
+
|
| 115 |
+
model_type = "higgs_audio"
|
| 116 |
+
is_composition = True
|
| 117 |
+
|
| 118 |
+
def __init__(
|
| 119 |
+
self,
|
| 120 |
+
text_config=None,
|
| 121 |
+
audio_encoder_config=None,
|
| 122 |
+
audio_tokenizer_config=None,
|
| 123 |
+
audio_adapter_type="stack",
|
| 124 |
+
audio_embed_avg=False,
|
| 125 |
+
audio_ffn_hidden_size=4096,
|
| 126 |
+
audio_ffn_intermediate_size=14336,
|
| 127 |
+
audio_dual_ffn_layers=None,
|
| 128 |
+
audio_decoder_proj_num_layers=0,
|
| 129 |
+
encode_whisper_embed=True,
|
| 130 |
+
encode_audio_in_tokens=False,
|
| 131 |
+
use_delay_pattern=False,
|
| 132 |
+
skip_audio_tower=False,
|
| 133 |
+
use_audio_out_embed_projector=False,
|
| 134 |
+
use_audio_out_self_attention=False,
|
| 135 |
+
use_rq_transformer=False,
|
| 136 |
+
rq_transformer_hidden_size=None,
|
| 137 |
+
rq_transformer_intermediate_size=None,
|
| 138 |
+
rq_transformer_num_attention_heads=None,
|
| 139 |
+
rq_transformer_num_key_value_heads=None,
|
| 140 |
+
rq_transformer_num_hidden_layers=3,
|
| 141 |
+
audio_num_codebooks=12,
|
| 142 |
+
audio_codebook_size=1024,
|
| 143 |
+
audio_stream_bos_id=1024,
|
| 144 |
+
audio_stream_eos_id=1025,
|
| 145 |
+
audio_bos_token="<|audio_bos|>",
|
| 146 |
+
audio_eos_token="<|audio_eos|>",
|
| 147 |
+
audio_out_bos_token="<|audio_out_bos|>",
|
| 148 |
+
audio_in_token="<|AUDIO|>",
|
| 149 |
+
audio_out_token="<|AUDIO_OUT|>",
|
| 150 |
+
audio_in_token_idx=128015,
|
| 151 |
+
audio_out_token_idx=128016,
|
| 152 |
+
pad_token_id=128001,
|
| 153 |
+
audio_out_bos_token_id=128013,
|
| 154 |
+
audio_eos_token_id=128012,
|
| 155 |
+
**kwargs,
|
| 156 |
+
):
|
| 157 |
+
if isinstance(audio_encoder_config, dict):
|
| 158 |
+
audio_encoder_config["model_type"] = (
|
| 159 |
+
audio_encoder_config["model_type"] if "model_type" in audio_encoder_config else "higgs_audio_encoder"
|
| 160 |
+
)
|
| 161 |
+
audio_encoder_config = CONFIG_MAPPING[audio_encoder_config["model_type"]](**audio_encoder_config)
|
| 162 |
+
elif audio_encoder_config is None:
|
| 163 |
+
audio_encoder_config = HiggsAudioEncoderConfig()
|
| 164 |
+
|
| 165 |
+
if isinstance(text_config, dict):
|
| 166 |
+
text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama"
|
| 167 |
+
text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
|
| 168 |
+
elif text_config is None:
|
| 169 |
+
text_config = CONFIG_MAPPING["llama"]()
|
| 170 |
+
|
| 171 |
+
assert audio_adapter_type in [
|
| 172 |
+
"stack",
|
| 173 |
+
"dual_ffn",
|
| 174 |
+
"dual_ffn_fast_forward",
|
| 175 |
+
], f"Invalid audio adapter type: {audio_adapter_type}"
|
| 176 |
+
if audio_adapter_type.startswith("dual_ffn"):
|
| 177 |
+
assert audio_dual_ffn_layers is not None, (
|
| 178 |
+
"audio_dual_ffn_layers must be specified when using dual_ffn adapter."
|
| 179 |
+
)
|
| 180 |
+
self.text_config = text_config
|
| 181 |
+
self.audio_encoder_config = audio_encoder_config
|
| 182 |
+
self.audio_tokenizer_config = audio_tokenizer_config
|
| 183 |
+
self.audio_adapter_type = audio_adapter_type
|
| 184 |
+
self.audio_embed_avg = audio_embed_avg
|
| 185 |
+
self.audio_ffn_hidden_size = audio_ffn_hidden_size
|
| 186 |
+
self.audio_ffn_intermediate_size = audio_ffn_intermediate_size
|
| 187 |
+
self.audio_dual_ffn_layers = audio_dual_ffn_layers
|
| 188 |
+
self.audio_decoder_proj_num_layers = audio_decoder_proj_num_layers
|
| 189 |
+
self.encode_whisper_embed = encode_whisper_embed
|
| 190 |
+
self.encode_audio_in_tokens = encode_audio_in_tokens
|
| 191 |
+
self.use_delay_pattern = use_delay_pattern
|
| 192 |
+
self.skip_audio_tower = skip_audio_tower
|
| 193 |
+
self.use_audio_out_embed_projector = use_audio_out_embed_projector
|
| 194 |
+
self.use_audio_out_self_attention = use_audio_out_self_attention
|
| 195 |
+
|
| 196 |
+
self.use_rq_transformer = use_rq_transformer
|
| 197 |
+
|
| 198 |
+
if self.use_rq_transformer:
|
| 199 |
+
assert not self.use_delay_pattern, "Delay pattern is not supported if you turned on RQ-Transformer!"
|
| 200 |
+
self.rq_transformer_hidden_size = rq_transformer_hidden_size
|
| 201 |
+
self.rq_transformer_intermediate_size = rq_transformer_intermediate_size
|
| 202 |
+
self.rq_transformer_num_attention_heads = rq_transformer_num_attention_heads
|
| 203 |
+
self.rq_transformer_num_key_value_heads = rq_transformer_num_key_value_heads
|
| 204 |
+
self.rq_transformer_num_hidden_layers = rq_transformer_num_hidden_layers
|
| 205 |
+
|
| 206 |
+
if use_rq_transformer:
|
| 207 |
+
# For RQ-Transformer, we set the hidden_size to the same as the text model's hidden size if it is not specified.
|
| 208 |
+
if self.rq_transformer_hidden_size is None:
|
| 209 |
+
self.rq_transformer_hidden_size = text_config.hidden_size
|
| 210 |
+
assert self.rq_transformer_hidden_size % 128 == 0
|
| 211 |
+
if self.rq_transformer_intermediate_size is None:
|
| 212 |
+
self.rq_transformer_intermediate_size = text_config.intermediate_size
|
| 213 |
+
if self.rq_transformer_num_attention_heads is None:
|
| 214 |
+
self.rq_transformer_num_attention_heads = self.rq_transformer_hidden_size // 128
|
| 215 |
+
if self.rq_transformer_num_key_value_heads is None:
|
| 216 |
+
self.rq_transformer_num_key_value_heads = self.rq_transformer_hidden_size // 128 // 4
|
| 217 |
+
assert self.rq_transformer_hidden_size % self.rq_transformer_num_attention_heads == 0
|
| 218 |
+
assert self.rq_transformer_hidden_size % self.rq_transformer_num_key_value_heads == 0
|
| 219 |
+
|
| 220 |
+
self.audio_num_codebooks = audio_num_codebooks
|
| 221 |
+
self.audio_codebook_size = audio_codebook_size
|
| 222 |
+
self.audio_bos_token = audio_bos_token
|
| 223 |
+
self.audio_eos_token = audio_eos_token
|
| 224 |
+
self.audio_out_bos_token = audio_out_bos_token
|
| 225 |
+
self.audio_in_token = audio_in_token
|
| 226 |
+
self.audio_out_token = audio_out_token
|
| 227 |
+
self.audio_in_token_idx = audio_in_token_idx
|
| 228 |
+
self.audio_out_token_idx = audio_out_token_idx
|
| 229 |
+
self.audio_stream_bos_id = audio_stream_bos_id
|
| 230 |
+
self.audio_stream_eos_id = audio_stream_eos_id
|
| 231 |
+
self.audio_out_bos_token_id = audio_out_bos_token_id
|
| 232 |
+
self.audio_eos_token_id = audio_eos_token_id
|
| 233 |
+
|
| 234 |
+
super().__init__(**kwargs)
|
| 235 |
+
self.pad_token_id = pad_token_id
|
higgs_audio/model/cuda_graph_runner.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from typing import Optional, List, Dict, Tuple, Union
|
| 4 |
+
import gc
|
| 5 |
+
|
| 6 |
+
from transformers.cache_utils import Cache
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
_NUM_WARMUP_ITERS = 2
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class CUDAGraphRunner(nn.Module):
|
| 13 |
+
def __init__(self, model):
|
| 14 |
+
super().__init__()
|
| 15 |
+
self.model = model
|
| 16 |
+
|
| 17 |
+
self.input_buffers: Dict[str, torch.Tensor] = {}
|
| 18 |
+
self.output_buffers: Dict[str, torch.Tensor] = {}
|
| 19 |
+
|
| 20 |
+
self._graph: Optional[torch.cuda.CUDAGraph] = None
|
| 21 |
+
|
| 22 |
+
@property
|
| 23 |
+
def graph(self):
|
| 24 |
+
assert self._graph is not None
|
| 25 |
+
return self._graph
|
| 26 |
+
|
| 27 |
+
def capture(
|
| 28 |
+
self,
|
| 29 |
+
hidden_states: torch.Tensor,
|
| 30 |
+
causal_mask: torch.Tensor,
|
| 31 |
+
position_ids: torch.Tensor,
|
| 32 |
+
audio_discrete_codes_mask: torch.Tensor,
|
| 33 |
+
cache_position: torch.Tensor,
|
| 34 |
+
past_key_values: Union[Cache, List[torch.FloatTensor]],
|
| 35 |
+
use_cache: bool,
|
| 36 |
+
audio_attention_mask: torch.Tensor,
|
| 37 |
+
fast_forward_attention_mask: torch.Tensor,
|
| 38 |
+
output_attentions: bool,
|
| 39 |
+
output_hidden_states: bool,
|
| 40 |
+
is_decoding_audio_token: Optional[bool] = None,
|
| 41 |
+
is_using_cuda_graph: Optional[bool] = False,
|
| 42 |
+
stream: torch.cuda.Stream = None,
|
| 43 |
+
memory_pool: Optional[Tuple[int, int]] = None,
|
| 44 |
+
):
|
| 45 |
+
assert self._graph is None
|
| 46 |
+
# Run warmup iterations
|
| 47 |
+
for _ in range(_NUM_WARMUP_ITERS):
|
| 48 |
+
self.model(
|
| 49 |
+
hidden_states=hidden_states,
|
| 50 |
+
causal_mask=causal_mask,
|
| 51 |
+
position_ids=position_ids,
|
| 52 |
+
audio_discrete_codes_mask=audio_discrete_codes_mask,
|
| 53 |
+
cache_position=cache_position,
|
| 54 |
+
past_key_values=past_key_values,
|
| 55 |
+
use_cache=use_cache,
|
| 56 |
+
audio_attention_mask=audio_attention_mask,
|
| 57 |
+
fast_forward_attention_mask=fast_forward_attention_mask,
|
| 58 |
+
output_attentions=output_attentions,
|
| 59 |
+
output_hidden_states=output_hidden_states,
|
| 60 |
+
is_decoding_audio_token=is_decoding_audio_token,
|
| 61 |
+
is_using_cuda_graph=is_using_cuda_graph,
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
torch.cuda.synchronize()
|
| 65 |
+
|
| 66 |
+
# Capture the graph
|
| 67 |
+
self._graph = torch.cuda.CUDAGraph()
|
| 68 |
+
with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
|
| 69 |
+
out_hidden_states, all_hidden_states, all_self_attns = self.model(
|
| 70 |
+
hidden_states=hidden_states,
|
| 71 |
+
causal_mask=causal_mask,
|
| 72 |
+
position_ids=position_ids,
|
| 73 |
+
audio_discrete_codes_mask=audio_discrete_codes_mask,
|
| 74 |
+
cache_position=cache_position,
|
| 75 |
+
past_key_values=past_key_values,
|
| 76 |
+
use_cache=use_cache,
|
| 77 |
+
audio_attention_mask=audio_attention_mask,
|
| 78 |
+
fast_forward_attention_mask=fast_forward_attention_mask,
|
| 79 |
+
output_attentions=output_attentions,
|
| 80 |
+
output_hidden_states=output_hidden_states,
|
| 81 |
+
is_decoding_audio_token=is_decoding_audio_token,
|
| 82 |
+
is_using_cuda_graph=is_using_cuda_graph,
|
| 83 |
+
)
|
| 84 |
+
# hidden_states_out = torch.ops._C.weak_ref_tensor(outputs[0])
|
| 85 |
+
# del outputs
|
| 86 |
+
gc.collect()
|
| 87 |
+
torch.cuda.synchronize()
|
| 88 |
+
|
| 89 |
+
# Save input and output buffers
|
| 90 |
+
self.input_buffers = {
|
| 91 |
+
"hidden_states": hidden_states,
|
| 92 |
+
"causal_mask": causal_mask,
|
| 93 |
+
"position_ids": position_ids,
|
| 94 |
+
"audio_discrete_codes_mask": audio_discrete_codes_mask,
|
| 95 |
+
"cache_position": cache_position,
|
| 96 |
+
"past_key_values": past_key_values,
|
| 97 |
+
"audio_attention_mask": audio_attention_mask,
|
| 98 |
+
"fast_forward_attention_mask": fast_forward_attention_mask,
|
| 99 |
+
}
|
| 100 |
+
self.output_buffers = {
|
| 101 |
+
"hidden_states": out_hidden_states,
|
| 102 |
+
"all_hidden_states": all_hidden_states,
|
| 103 |
+
"all_self_attns": all_self_attns,
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
def forward(
|
| 107 |
+
self,
|
| 108 |
+
hidden_states: torch.Tensor,
|
| 109 |
+
causal_mask: torch.Tensor,
|
| 110 |
+
position_ids: torch.Tensor,
|
| 111 |
+
audio_discrete_codes_mask: torch.Tensor,
|
| 112 |
+
cache_position: torch.Tensor,
|
| 113 |
+
audio_attention_mask: torch.Tensor,
|
| 114 |
+
fast_forward_attention_mask: torch.Tensor,
|
| 115 |
+
**kwargs,
|
| 116 |
+
) -> torch.Tensor:
|
| 117 |
+
# Copy input tensors to buffers
|
| 118 |
+
self.input_buffers["hidden_states"].copy_(hidden_states, non_blocking=True)
|
| 119 |
+
self.input_buffers["causal_mask"].copy_(causal_mask, non_blocking=True)
|
| 120 |
+
self.input_buffers["position_ids"].copy_(position_ids, non_blocking=True)
|
| 121 |
+
self.input_buffers["audio_discrete_codes_mask"].copy_(audio_discrete_codes_mask, non_blocking=True)
|
| 122 |
+
self.input_buffers["cache_position"].copy_(cache_position, non_blocking=True)
|
| 123 |
+
self.input_buffers["audio_attention_mask"].copy_(audio_attention_mask, non_blocking=True)
|
| 124 |
+
self.input_buffers["fast_forward_attention_mask"].copy_(fast_forward_attention_mask, non_blocking=True)
|
| 125 |
+
|
| 126 |
+
# Run the captured graph
|
| 127 |
+
self.graph.replay()
|
| 128 |
+
|
| 129 |
+
return self.output_buffers["hidden_states"], None, None
|
higgs_audio/model/custom_modules.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class PartiallyFrozenEmbedding(nn.Module):
|
| 6 |
+
"""Split an existing `nn.Embedding` module that splits the embedding into:
|
| 7 |
+
|
| 8 |
+
- A frozen embedding for indices [0..freeze_until_idx].
|
| 9 |
+
- A trainable embedding for indices [freeze_until_idx+1..vocab_size-1].
|
| 10 |
+
|
| 11 |
+
This should work with both Zero-2 and Zero-3 seamlessly
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
def __init__(self, original_embedding: nn.Embedding, freeze_until_idx: int):
|
| 15 |
+
"""
|
| 16 |
+
:param original_embedding: An instance of nn.Embedding (the original embedding layer).
|
| 17 |
+
:param freeze_until_idx: The index up to which the embedding is frozen (excluding). The freeze_until_idx is not frozen.
|
| 18 |
+
"""
|
| 19 |
+
super().__init__()
|
| 20 |
+
self.freeze_until_idx = freeze_until_idx
|
| 21 |
+
self.original_vocab_size = original_embedding.num_embeddings
|
| 22 |
+
self.embedding_dim = original_embedding.embedding_dim
|
| 23 |
+
|
| 24 |
+
# Split the original embedding into frozen and trainable parts
|
| 25 |
+
self.embedding_frozen = nn.Embedding(
|
| 26 |
+
freeze_until_idx,
|
| 27 |
+
self.embedding_dim,
|
| 28 |
+
dtype=original_embedding.weight.dtype,
|
| 29 |
+
device=original_embedding.weight.device,
|
| 30 |
+
)
|
| 31 |
+
self.embedding_trainable = nn.Embedding(
|
| 32 |
+
self.original_vocab_size - freeze_until_idx,
|
| 33 |
+
self.embedding_dim,
|
| 34 |
+
dtype=original_embedding.weight.dtype,
|
| 35 |
+
device=original_embedding.weight.device,
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
# Copy weights from the original embedding into the frozen and trainable parts
|
| 39 |
+
with torch.no_grad():
|
| 40 |
+
self.embedding_frozen.weight.copy_(original_embedding.weight[:freeze_until_idx])
|
| 41 |
+
self.embedding_trainable.weight.copy_(original_embedding.weight[freeze_until_idx:])
|
| 42 |
+
|
| 43 |
+
# Freeze the frozen embedding
|
| 44 |
+
self.embedding_frozen.weight.requires_grad = False
|
| 45 |
+
|
| 46 |
+
def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
|
| 47 |
+
"""
|
| 48 |
+
Forward pass for the split embedding wrapper.
|
| 49 |
+
:param input_ids: Tensor of shape [batch_size, seq_len] with indices in [0..original_vocab_size-1].
|
| 50 |
+
"""
|
| 51 |
+
# Masks to separate frozen and trainable indices
|
| 52 |
+
# (bsz, seq_len)
|
| 53 |
+
mask_frozen = input_ids < self.freeze_until_idx
|
| 54 |
+
mask_trainable = ~mask_frozen
|
| 55 |
+
|
| 56 |
+
# Output tensor for embedding results
|
| 57 |
+
batch_size, seq_len = input_ids.shape
|
| 58 |
+
embeddings = torch.zeros(
|
| 59 |
+
batch_size,
|
| 60 |
+
seq_len,
|
| 61 |
+
self.embedding_dim,
|
| 62 |
+
device=input_ids.device,
|
| 63 |
+
dtype=self.embedding_frozen.weight.dtype,
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
# Handle frozen embedding
|
| 67 |
+
if mask_frozen.any():
|
| 68 |
+
frozen_ids = input_ids[mask_frozen]
|
| 69 |
+
frozen_emb = self.embedding_frozen(frozen_ids)
|
| 70 |
+
embeddings[mask_frozen] = frozen_emb
|
| 71 |
+
|
| 72 |
+
# Handle trainable embedding
|
| 73 |
+
if mask_trainable.any():
|
| 74 |
+
# Adjust trainable IDs to the local index space of the trainable embedding
|
| 75 |
+
trainable_ids = input_ids[mask_trainable] - (self.freeze_until_idx)
|
| 76 |
+
trainable_emb = self.embedding_trainable(trainable_ids)
|
| 77 |
+
embeddings[mask_trainable] = trainable_emb
|
| 78 |
+
|
| 79 |
+
return embeddings
|
| 80 |
+
|
| 81 |
+
def to_unsplit(self) -> nn.Embedding:
|
| 82 |
+
unsplit_embedding = nn.Embedding(
|
| 83 |
+
self.original_vocab_size,
|
| 84 |
+
self.embedding_dim,
|
| 85 |
+
dtype=self.embedding_frozen.weight.dtype,
|
| 86 |
+
device=self.embedding_frozen.weight.device,
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
with torch.no_grad():
|
| 90 |
+
unsplit_embedding.weight[: self.freeze_until_idx].copy_(self.embedding_frozen.weight)
|
| 91 |
+
unsplit_embedding.weight[self.freeze_until_idx :].copy_(self.embedding_trainable.weight)
|
| 92 |
+
|
| 93 |
+
return unsplit_embedding
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class PartiallyFrozenLinear(nn.Module):
|
| 97 |
+
"""A wrapper around nn.Linear to partially freeze part of the weight matrix."""
|
| 98 |
+
|
| 99 |
+
def __init__(self, original_linear: nn.Linear, freeze_until_idx: int):
|
| 100 |
+
"""
|
| 101 |
+
:param original_linear: The original nn.Linear layer.
|
| 102 |
+
:param freeze_until_idx: The index up to which the rows of the weight matrix are frozen.
|
| 103 |
+
"""
|
| 104 |
+
super().__init__()
|
| 105 |
+
assert original_linear.bias is None, "Currently only support linear module without bias"
|
| 106 |
+
|
| 107 |
+
self.freeze_until_idx = freeze_until_idx
|
| 108 |
+
self.input_dim = original_linear.in_features
|
| 109 |
+
self.output_dim = original_linear.out_features
|
| 110 |
+
|
| 111 |
+
# Create frozen and trainable linear layers
|
| 112 |
+
self.linear_frozen = nn.Linear(
|
| 113 |
+
self.input_dim,
|
| 114 |
+
freeze_until_idx,
|
| 115 |
+
bias=False,
|
| 116 |
+
dtype=original_linear.weight.dtype,
|
| 117 |
+
device=original_linear.weight.device,
|
| 118 |
+
)
|
| 119 |
+
self.linear_trainable = nn.Linear(
|
| 120 |
+
self.input_dim,
|
| 121 |
+
self.output_dim - freeze_until_idx,
|
| 122 |
+
bias=False,
|
| 123 |
+
dtype=original_linear.weight.dtype,
|
| 124 |
+
device=original_linear.weight.device,
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
# Copy weights from the original linear layer
|
| 128 |
+
with torch.no_grad():
|
| 129 |
+
self.linear_frozen.weight.copy_(original_linear.weight[:freeze_until_idx])
|
| 130 |
+
self.linear_trainable.weight.copy_(original_linear.weight[freeze_until_idx:])
|
| 131 |
+
|
| 132 |
+
# Freeze the frozen linear layer
|
| 133 |
+
self.linear_frozen.weight.requires_grad = False
|
| 134 |
+
|
| 135 |
+
def forward(self, input_tensor):
|
| 136 |
+
# input_tensor: (bsz, seq_len, hidden_state_dim)
|
| 137 |
+
frozen_output = self.linear_frozen(input_tensor)
|
| 138 |
+
trainable_output = self.linear_trainable(input_tensor)
|
| 139 |
+
return torch.cat((frozen_output, trainable_output), dim=-1)
|
| 140 |
+
|
| 141 |
+
def to_unsplit(self) -> nn.Linear:
|
| 142 |
+
unsplit_linear = nn.Linear(
|
| 143 |
+
self.input_dim,
|
| 144 |
+
self.output_dim,
|
| 145 |
+
bias=False,
|
| 146 |
+
dtype=self.linear_frozen.weight.dtype,
|
| 147 |
+
device=self.linear_frozen.weight.device,
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
# Copy weights from the frozen and trainable layers into the unsplit linear layer
|
| 151 |
+
with torch.no_grad():
|
| 152 |
+
unsplit_linear.weight[: self.freeze_until_idx].copy_(self.linear_frozen.weight)
|
| 153 |
+
unsplit_linear.weight[self.freeze_until_idx :].copy_(self.linear_trainable.weight)
|
| 154 |
+
|
| 155 |
+
return unsplit_linear
|
higgs_audio/model/modeling_higgs_audio.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
higgs_audio/model/utils.py
ADDED
|
@@ -0,0 +1,778 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import contextlib
|
| 2 |
+
from contextlib import contextmanager
|
| 3 |
+
from functools import wraps
|
| 4 |
+
import torch
|
| 5 |
+
from transformers.integrations import is_deepspeed_available
|
| 6 |
+
|
| 7 |
+
if is_deepspeed_available():
|
| 8 |
+
from deepspeed.utils import groups as deepspeed_groups
|
| 9 |
+
from deepspeed.sequence.layer import _SeqAllToAll
|
| 10 |
+
else:
|
| 11 |
+
deepspeed_groups = None
|
| 12 |
+
_SeqAllToAll = None
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def _ceil_to_nearest(n, round_to):
|
| 16 |
+
return (n + round_to - 1) // round_to * round_to
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def count_parameters(model, trainable_only=True):
|
| 20 |
+
if trainable_only:
|
| 21 |
+
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 22 |
+
else:
|
| 23 |
+
return sum(p.numel() for p in model.parameters())
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# TODO(sxjscience) Consider to move the function to audio_processing/utils.py
|
| 27 |
+
def build_delay_pattern_mask(
|
| 28 |
+
input_ids: torch.LongTensor,
|
| 29 |
+
bos_token_id: int,
|
| 30 |
+
pad_token_id: int,
|
| 31 |
+
):
|
| 32 |
+
"""Implement the delay pattern proposed in "Simple and Controllable Music Generation", https://arxiv.org/pdf/2306.05284
|
| 33 |
+
|
| 34 |
+
In the delay pattern, each codebook is offset by the previous codebook by
|
| 35 |
+
one. We insert a special delay token at the start of the sequence if its delayed, and append pad token once the sequence finishes.
|
| 36 |
+
|
| 37 |
+
Take the example where there are 4 codebooks and audio sequence length=5. After shifting, the output should have length seq_len + num_codebooks - 1
|
| 38 |
+
|
| 39 |
+
- [ *, *, *, *, *, P, P, P]
|
| 40 |
+
- [ B, *, *, *, *, *, P, P]
|
| 41 |
+
- [ B, B, *, *, *, *, *, P]
|
| 42 |
+
- [ B, B, B, *, *, *, *, *]
|
| 43 |
+
|
| 44 |
+
where B indicates the delay token id, P is the special padding token id and `*` indicates that the original audio token.
|
| 45 |
+
|
| 46 |
+
Now let's consider the case where we have a sequence of audio tokens to condition on.
|
| 47 |
+
The audio tokens were originally in the following non-delayed form:
|
| 48 |
+
|
| 49 |
+
- [a, b]
|
| 50 |
+
- [c, d]
|
| 51 |
+
- [e, f]
|
| 52 |
+
- [g, h]
|
| 53 |
+
|
| 54 |
+
After conversion, we get the following delayed form:
|
| 55 |
+
- [a, b, -1, -1, -1]
|
| 56 |
+
- [B, c, d, -1, -1]
|
| 57 |
+
- [B, B, e, f, -1]
|
| 58 |
+
- [B, B, B, g, h]
|
| 59 |
+
|
| 60 |
+
Note that we have a special token `-1` that indicates it should be replaced by a new token we see in the generation phase.
|
| 61 |
+
In that case, we should override the `-1` tokens in auto-regressive generation.
|
| 62 |
+
|
| 63 |
+
Args:
|
| 64 |
+
input_ids (:obj:`torch.LongTensor`):
|
| 65 |
+
The input ids of the prompt. It will have shape (bsz, num_codebooks, seq_len).
|
| 66 |
+
bos_token_id (:obj:`int`):
|
| 67 |
+
The id of the special delay token
|
| 68 |
+
pad_token_id (:obj:`int`):
|
| 69 |
+
The id of the padding token. Should be the same as eos_token_id.
|
| 70 |
+
|
| 71 |
+
Returns:
|
| 72 |
+
input_ids (:obj:`torch.LongTensor`):
|
| 73 |
+
The transformed input ids with delay pattern applied. It will have shape (bsz, num_codebooks, seq_len + num_codebooks - 1).
|
| 74 |
+
input_ids_with_gen_mask (:obj:`torch.LongTensor`):
|
| 75 |
+
The transformed input ids with delay pattern applied. The -1 in the output indicates new tokens that should be generated.
|
| 76 |
+
|
| 77 |
+
"""
|
| 78 |
+
bsz, num_codebooks, seq_len = input_ids.shape
|
| 79 |
+
|
| 80 |
+
new_seq_len = seq_len + num_codebooks - 1
|
| 81 |
+
input_ids_with_gen_mask = torch.ones((bsz, num_codebooks, new_seq_len), dtype=torch.long, device=input_ids.device)
|
| 82 |
+
bos_mask = torch.tril(input_ids_with_gen_mask, -1) > 0
|
| 83 |
+
eos_mask = torch.triu(input_ids_with_gen_mask, seq_len) > 0
|
| 84 |
+
input_ids_with_gen_mask[bos_mask] = bos_token_id
|
| 85 |
+
input_ids_with_gen_mask[(~bos_mask) & (~eos_mask)] = input_ids.reshape(-1)
|
| 86 |
+
input_ids = input_ids_with_gen_mask.clone()
|
| 87 |
+
input_ids[eos_mask] = pad_token_id
|
| 88 |
+
input_ids_with_gen_mask[eos_mask] = -1
|
| 89 |
+
return input_ids, input_ids_with_gen_mask
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def revert_delay_pattern(data):
|
| 93 |
+
"""Convert samples encoded with delay pattern back to the original form.
|
| 94 |
+
|
| 95 |
+
Args:
|
| 96 |
+
data (:obj:`torch.Tensor`):
|
| 97 |
+
The data with delay pattern applied. It will have shape (num_codebooks, seq_len + num_codebooks - 1).
|
| 98 |
+
|
| 99 |
+
Returns:
|
| 100 |
+
ret (:obj:`torch.Tensor`):
|
| 101 |
+
Recovered data with delay pattern removed. It will have shape (num_codebooks, seq_len).
|
| 102 |
+
"""
|
| 103 |
+
assert len(data.shape) == 2
|
| 104 |
+
out_l = []
|
| 105 |
+
num_codebooks = data.shape[0]
|
| 106 |
+
for i in range(num_codebooks):
|
| 107 |
+
out_l.append(data[i : (i + 1), i : (data.shape[1] - num_codebooks + 1 + i)])
|
| 108 |
+
return torch.cat(out_l, dim=0)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def merge_input_ids_with_audio_features(
|
| 112 |
+
audio_features_embed,
|
| 113 |
+
audio_features_length,
|
| 114 |
+
audio_in_embed,
|
| 115 |
+
audio_in_ids_start,
|
| 116 |
+
audio_out_embed,
|
| 117 |
+
audio_out_ids_start,
|
| 118 |
+
audio_in_token_idx,
|
| 119 |
+
audio_out_token_idx,
|
| 120 |
+
inputs_embeds,
|
| 121 |
+
input_ids,
|
| 122 |
+
attention_mask,
|
| 123 |
+
label_ids,
|
| 124 |
+
pad_token_id,
|
| 125 |
+
ignore_index=-100,
|
| 126 |
+
round_to=8,
|
| 127 |
+
left_padding=True,
|
| 128 |
+
):
|
| 129 |
+
"""
|
| 130 |
+
Merge input_ids with audio features into final embeddings.
|
| 131 |
+
|
| 132 |
+
Args:
|
| 133 |
+
audio_features_embed (`torch.Tensor` of shape `(num_audios, max_audio_tokens, embed_dim)`):
|
| 134 |
+
Encoded vectors of all audios in the batch (obtained from the semantic encoder)
|
| 135 |
+
audio_features_length (`torch.LongTensor` of shape `(num_audios,)`):
|
| 136 |
+
The length of audio embeddings of each audio as stacked in `audio_features_embed`
|
| 137 |
+
audio_in_embed (`torch.Tensor` of shape `(total_num_audio_in_tokens, embed_dim)`):
|
| 138 |
+
The embeddings of audio-in tokens
|
| 139 |
+
audio_in_ids_start (`torch.LongTensor` of shape `(num_audios,)`):
|
| 140 |
+
The start index of the audio-in tokens for each audio
|
| 141 |
+
audio_out_embed (`torch.Tensor` of shape `(total_num_audio_out_tokens, embed_dim)`):
|
| 142 |
+
The embeddings of audio-out tokens
|
| 143 |
+
audio_out_ids_start (`torch.LongTensor` of shape `(num_audios,)`):
|
| 144 |
+
The start index of the audio-out tokens for each audio
|
| 145 |
+
audio_in_token_idx
|
| 146 |
+
The index of the audio-in token in the vocabulary
|
| 147 |
+
audio_out_token_idx
|
| 148 |
+
The index of the audio-out token in the vocabulary
|
| 149 |
+
inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, embed_dim)`):
|
| 150 |
+
Token embeddings before merging with audio embeddings
|
| 151 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 152 |
+
Input_ids of tokens, possibly filled with audio token
|
| 153 |
+
attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 154 |
+
Mask to avoid performing attention on padding token indices.
|
| 155 |
+
label_ids (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*)
|
| 156 |
+
labels need to be recalculated to support training (if provided)
|
| 157 |
+
pad_token_id (`int`):
|
| 158 |
+
The index of the pad token in the vocabulary
|
| 159 |
+
ignore_index
|
| 160 |
+
The index to ignore in the loss calculation
|
| 161 |
+
round_to
|
| 162 |
+
The number to round to for padding
|
| 163 |
+
left_padding
|
| 164 |
+
Whether to apply left padding
|
| 165 |
+
|
| 166 |
+
Returns:
|
| 167 |
+
final_embedding
|
| 168 |
+
The final embeddings after merging audio embeddings with text embeddings.
|
| 169 |
+
final_attention_mask
|
| 170 |
+
The final attention mask after merging audio embeddings with text embeddings.
|
| 171 |
+
final_labels
|
| 172 |
+
The labels for the text stream
|
| 173 |
+
position_ids
|
| 174 |
+
Positional ids for the merged data
|
| 175 |
+
final_input_ids
|
| 176 |
+
The final input_ids after merging audio embeddings with text embeddings.
|
| 177 |
+
final_audio_in_mask
|
| 178 |
+
Mask for audio-in embeddings
|
| 179 |
+
final_audio_in_discrete_codes_mask
|
| 180 |
+
Mask for audio-in discrete tokens
|
| 181 |
+
final_audio_out_mask
|
| 182 |
+
Mask for audio-out embeddings
|
| 183 |
+
|
| 184 |
+
Explanation:
|
| 185 |
+
each audio has variable length embeddings, with length specified by
|
| 186 |
+
- audio_features_length
|
| 187 |
+
- audio_in_ids_start
|
| 188 |
+
- audio_out_ids_start
|
| 189 |
+
|
| 190 |
+
Task:
|
| 191 |
+
- fill each <|AUDIO|> with audio embeddings (it can be the combination of embeddings extracted by WhisperEncoder and embeddings from audio codebooks)
|
| 192 |
+
- fill each <|AUDIO_OUT|> with the audio-out embeddings
|
| 193 |
+
|
| 194 |
+
Example:
|
| 195 |
+
<|AUDIO_OUT|>: X (5 tokens), Y (3 tokens)
|
| 196 |
+
<|AUDIO|>: Z (8 tokens)
|
| 197 |
+
|
| 198 |
+
X, Y are in the same sequence (in-context voice-clone). Z is in a different sequence (audio understanding).
|
| 199 |
+
if right padding
|
| 200 |
+
input_ids: [
|
| 201 |
+
a b c d e f X g h i j k Y l m
|
| 202 |
+
o p q r Z s t u v _ _ _ _ _ _
|
| 203 |
+
]
|
| 204 |
+
input_ids should be: [
|
| 205 |
+
a b c d e f X X X X X g h i j k Y Y Y l m
|
| 206 |
+
o p q r Z Z Z Z Z Z Z Z s t u v _ _ _ _ _
|
| 207 |
+
]
|
| 208 |
+
labels should be: [
|
| 209 |
+
a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
|
| 210 |
+
o p q r _ _ _ _ _ _ _ _ s t u v _ _ _ _ _
|
| 211 |
+
]
|
| 212 |
+
elif left padding
|
| 213 |
+
input_ids: [
|
| 214 |
+
a b c d e f X g h i j k Y l m
|
| 215 |
+
_ _ _ _ _ _ o p q r Z s t u v
|
| 216 |
+
]
|
| 217 |
+
input_ids should be: [
|
| 218 |
+
a b c d e f X X X X X g h i j k Y Y Y l m
|
| 219 |
+
_ _ _ _ _ o p q r Z Z Z Z Z Z Z Z s t u v
|
| 220 |
+
]
|
| 221 |
+
labels should be: [
|
| 222 |
+
a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
|
| 223 |
+
_ _ _ _ _ o p q r _ _ _ _ _ _ _ _ s t u v
|
| 224 |
+
]
|
| 225 |
+
|
| 226 |
+
"""
|
| 227 |
+
if label_ids is None:
|
| 228 |
+
skip_labels = True
|
| 229 |
+
else:
|
| 230 |
+
skip_labels = False
|
| 231 |
+
if audio_features_embed is not None and audio_features_embed.shape[0] == 0:
|
| 232 |
+
audio_features_embed = None
|
| 233 |
+
if audio_in_embed is not None and audio_in_embed.shape[0] == 0:
|
| 234 |
+
audio_in_embed = None
|
| 235 |
+
if audio_out_embed is not None and audio_out_embed.shape[0] == 0:
|
| 236 |
+
audio_out_embed = None
|
| 237 |
+
|
| 238 |
+
batch_size, sequence_length, embed_dim = inputs_embeds.shape
|
| 239 |
+
|
| 240 |
+
target_device = inputs_embeds.device
|
| 241 |
+
if left_padding is None:
|
| 242 |
+
left_padding = torch.any(attention_mask[:, 0] == 0)
|
| 243 |
+
|
| 244 |
+
audio_in_token_mask = input_ids == audio_in_token_idx
|
| 245 |
+
audio_out_token_mask = input_ids == audio_out_token_idx
|
| 246 |
+
text_token_mask = (input_ids != audio_in_token_idx) & (input_ids != audio_out_token_idx)
|
| 247 |
+
|
| 248 |
+
# 1. Calculate the number of tokens for each placeholder (like [<|AUDIO|>, <|AUDIO_OUT|>]).
|
| 249 |
+
token_placeholder_num = torch.ones_like(input_ids)
|
| 250 |
+
|
| 251 |
+
if audio_features_embed is not None:
|
| 252 |
+
num_audios, max_audio_tokens, _ = audio_features_embed.shape
|
| 253 |
+
audio_in_features_mask = torch.arange(max_audio_tokens).expand(num_audios, max_audio_tokens).to(
|
| 254 |
+
audio_features_length.device
|
| 255 |
+
) < audio_features_length.unsqueeze(1)
|
| 256 |
+
masked_audio_in_features = audio_features_embed[audio_in_features_mask].view(-1, embed_dim)
|
| 257 |
+
token_placeholder_num[audio_in_token_mask] = audio_features_length.long()
|
| 258 |
+
|
| 259 |
+
if audio_in_embed is not None:
|
| 260 |
+
audio_in_codes_length = torch.concat(
|
| 261 |
+
[
|
| 262 |
+
audio_in_ids_start[1:] - audio_in_ids_start[:-1],
|
| 263 |
+
torch.tensor(
|
| 264 |
+
[audio_in_embed.shape[0] - audio_in_ids_start[-1]],
|
| 265 |
+
device=audio_in_ids_start.device,
|
| 266 |
+
dtype=torch.long,
|
| 267 |
+
),
|
| 268 |
+
],
|
| 269 |
+
dim=0,
|
| 270 |
+
)
|
| 271 |
+
if audio_features_embed is not None:
|
| 272 |
+
token_placeholder_num[audio_in_token_mask] += audio_in_codes_length.long()
|
| 273 |
+
else:
|
| 274 |
+
token_placeholder_num[audio_in_token_mask] = audio_in_codes_length.long()
|
| 275 |
+
|
| 276 |
+
if audio_out_embed is not None:
|
| 277 |
+
audio_out_codes_length = torch.concat(
|
| 278 |
+
[
|
| 279 |
+
audio_out_ids_start[1:] - audio_out_ids_start[:-1],
|
| 280 |
+
torch.tensor(
|
| 281 |
+
[audio_out_embed.shape[0] - audio_out_ids_start[-1]],
|
| 282 |
+
device=audio_out_ids_start.device,
|
| 283 |
+
dtype=torch.long,
|
| 284 |
+
),
|
| 285 |
+
],
|
| 286 |
+
dim=0,
|
| 287 |
+
)
|
| 288 |
+
token_placeholder_num[audio_out_token_mask] = audio_out_codes_length.long()
|
| 289 |
+
|
| 290 |
+
new_token_positions = torch.cumsum(token_placeholder_num, -1) - 1
|
| 291 |
+
max_token_num = _ceil_to_nearest(token_placeholder_num.sum(-1).max(), round_to)
|
| 292 |
+
nb_audio_pad = max_token_num - 1 - new_token_positions[:, -1]
|
| 293 |
+
|
| 294 |
+
if left_padding:
|
| 295 |
+
new_token_positions += nb_audio_pad[:, None] # offset for left padding
|
| 296 |
+
|
| 297 |
+
# 2. Create the full embedding, already padded to the maximum position
|
| 298 |
+
final_embedding = torch.zeros(
|
| 299 |
+
(batch_size, max_token_num, embed_dim),
|
| 300 |
+
dtype=inputs_embeds.dtype,
|
| 301 |
+
device=inputs_embeds.device,
|
| 302 |
+
)
|
| 303 |
+
final_attention_mask = torch.zeros(
|
| 304 |
+
(batch_size, max_token_num),
|
| 305 |
+
dtype=attention_mask.dtype,
|
| 306 |
+
device=inputs_embeds.device,
|
| 307 |
+
)
|
| 308 |
+
final_input_ids = torch.full(
|
| 309 |
+
(batch_size, max_token_num),
|
| 310 |
+
pad_token_id,
|
| 311 |
+
dtype=input_ids.dtype,
|
| 312 |
+
device=inputs_embeds.device,
|
| 313 |
+
)
|
| 314 |
+
if skip_labels:
|
| 315 |
+
final_labels = None
|
| 316 |
+
else:
|
| 317 |
+
final_labels = torch.full(
|
| 318 |
+
(batch_size, max_token_num),
|
| 319 |
+
ignore_index,
|
| 320 |
+
dtype=label_ids.dtype,
|
| 321 |
+
device=inputs_embeds.device,
|
| 322 |
+
)
|
| 323 |
+
|
| 324 |
+
final_audio_in_mask = torch.full(
|
| 325 |
+
(batch_size, max_token_num),
|
| 326 |
+
False,
|
| 327 |
+
dtype=torch.bool,
|
| 328 |
+
device=inputs_embeds.device,
|
| 329 |
+
)
|
| 330 |
+
final_audio_in_discrete_codes_mask = torch.full(
|
| 331 |
+
(batch_size, max_token_num),
|
| 332 |
+
False,
|
| 333 |
+
dtype=torch.bool,
|
| 334 |
+
device=inputs_embeds.device,
|
| 335 |
+
)
|
| 336 |
+
final_audio_out_mask = torch.full(
|
| 337 |
+
(batch_size, max_token_num),
|
| 338 |
+
False,
|
| 339 |
+
dtype=torch.bool,
|
| 340 |
+
device=inputs_embeds.device,
|
| 341 |
+
)
|
| 342 |
+
# 3. Get the audio-in token positions and audio-out token positions
|
| 343 |
+
batch_id = torch.arange(batch_size, device=target_device).unsqueeze(1).expand(batch_size, sequence_length)
|
| 344 |
+
audio_in_batch_id = batch_id[audio_in_token_mask] # Shape (num_audio_in,)
|
| 345 |
+
audio_out_batch_id = batch_id[audio_out_token_mask] # Shape (num_audio_out,)
|
| 346 |
+
audio_features_token_ends = new_token_positions[audio_in_token_mask] # Shape (num_audio_in,)
|
| 347 |
+
audio_out_embed_ends = new_token_positions[audio_out_token_mask] # Shape (num_audio_out,)
|
| 348 |
+
|
| 349 |
+
if audio_in_embed is not None:
|
| 350 |
+
# Fill in the audio-in embeddings
|
| 351 |
+
seq_indices = (
|
| 352 |
+
torch.arange(max_token_num, device=target_device)
|
| 353 |
+
.unsqueeze(0)
|
| 354 |
+
.expand(audio_in_ids_start.shape[0], max_token_num)
|
| 355 |
+
)
|
| 356 |
+
audio_in_embed_token_starts = audio_features_token_ends - audio_in_codes_length + 1
|
| 357 |
+
batch_indices, col_indices = torch.where(
|
| 358 |
+
(seq_indices >= audio_in_embed_token_starts.unsqueeze(1))
|
| 359 |
+
& (seq_indices <= audio_features_token_ends.unsqueeze(1))
|
| 360 |
+
)
|
| 361 |
+
batch_indices = audio_in_batch_id[batch_indices]
|
| 362 |
+
final_embedding[batch_indices, col_indices] = audio_in_embed
|
| 363 |
+
final_input_ids[batch_indices, col_indices] = audio_in_token_idx
|
| 364 |
+
if not skip_labels:
|
| 365 |
+
final_labels[batch_indices, col_indices] = ignore_index
|
| 366 |
+
final_audio_in_mask[batch_indices, col_indices] = True
|
| 367 |
+
final_audio_in_discrete_codes_mask[batch_indices, col_indices] = True
|
| 368 |
+
audio_features_token_ends = audio_features_token_ends - audio_in_codes_length
|
| 369 |
+
|
| 370 |
+
if audio_features_embed is not None:
|
| 371 |
+
# Fill in the audio features
|
| 372 |
+
seq_indices = (
|
| 373 |
+
torch.arange(max_token_num, device=target_device)
|
| 374 |
+
.unsqueeze(0)
|
| 375 |
+
.expand(audio_features_embed.shape[0], max_token_num)
|
| 376 |
+
)
|
| 377 |
+
audio_features_token_starts = audio_features_token_ends - audio_features_length + 1
|
| 378 |
+
batch_indices, col_indices = torch.where(
|
| 379 |
+
(seq_indices >= audio_features_token_starts.unsqueeze(1))
|
| 380 |
+
& (seq_indices <= audio_features_token_ends.unsqueeze(1))
|
| 381 |
+
)
|
| 382 |
+
batch_indices = audio_in_batch_id[batch_indices]
|
| 383 |
+
final_embedding[batch_indices, col_indices] = masked_audio_in_features
|
| 384 |
+
final_input_ids[batch_indices, col_indices] = audio_in_token_idx
|
| 385 |
+
if not skip_labels:
|
| 386 |
+
final_labels[batch_indices, col_indices] = ignore_index
|
| 387 |
+
final_audio_in_mask[batch_indices, col_indices] = True
|
| 388 |
+
|
| 389 |
+
if audio_out_embed is not None:
|
| 390 |
+
# Fill in the audio-out embeddings
|
| 391 |
+
seq_indices = (
|
| 392 |
+
torch.arange(max_token_num, device=target_device)
|
| 393 |
+
.unsqueeze(0)
|
| 394 |
+
.expand(audio_out_ids_start.shape[0], max_token_num)
|
| 395 |
+
)
|
| 396 |
+
audio_out_embed_token_starts = audio_out_embed_ends - audio_out_codes_length + 1
|
| 397 |
+
batch_indices, col_indices = torch.where(
|
| 398 |
+
(seq_indices >= audio_out_embed_token_starts.unsqueeze(1))
|
| 399 |
+
& (seq_indices <= audio_out_embed_ends.unsqueeze(1))
|
| 400 |
+
)
|
| 401 |
+
batch_indices = audio_out_batch_id[batch_indices]
|
| 402 |
+
final_embedding[batch_indices, col_indices] = audio_out_embed
|
| 403 |
+
final_input_ids[batch_indices, col_indices] = audio_out_token_idx
|
| 404 |
+
if not skip_labels:
|
| 405 |
+
final_labels[batch_indices, col_indices] = ignore_index
|
| 406 |
+
final_audio_out_mask[batch_indices, col_indices] = True
|
| 407 |
+
|
| 408 |
+
# Fill in the original text embeddings and labels
|
| 409 |
+
batch_indices, non_audio_indices = torch.where(text_token_mask)
|
| 410 |
+
text_to_overwrite = new_token_positions[batch_indices, non_audio_indices]
|
| 411 |
+
final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_audio_indices]
|
| 412 |
+
if not skip_labels:
|
| 413 |
+
final_labels[batch_indices, text_to_overwrite] = label_ids[batch_indices, non_audio_indices]
|
| 414 |
+
final_input_ids[batch_indices, text_to_overwrite] = input_ids[batch_indices, non_audio_indices]
|
| 415 |
+
final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_audio_indices]
|
| 416 |
+
final_attention_mask = final_attention_mask | final_audio_in_mask | final_audio_out_mask
|
| 417 |
+
|
| 418 |
+
# Trim the tensor if there are redundant padding tokens
|
| 419 |
+
if left_padding:
|
| 420 |
+
first_non_zero_loc = final_attention_mask.sum(0).nonzero()[0]
|
| 421 |
+
first_non_zero_loc = (first_non_zero_loc // round_to) * round_to
|
| 422 |
+
if first_non_zero_loc > 0:
|
| 423 |
+
final_attention_mask = final_attention_mask[:, first_non_zero_loc:]
|
| 424 |
+
final_embedding = final_embedding[:, first_non_zero_loc:]
|
| 425 |
+
if not skip_labels:
|
| 426 |
+
final_labels = final_labels[:, first_non_zero_loc:]
|
| 427 |
+
final_input_ids = final_input_ids[:, first_non_zero_loc:]
|
| 428 |
+
final_audio_in_mask = final_audio_in_mask[:, first_non_zero_loc:]
|
| 429 |
+
final_audio_in_discrete_codes_mask = final_audio_in_discrete_codes_mask[:, first_non_zero_loc:]
|
| 430 |
+
final_audio_out_mask = final_audio_out_mask[:, first_non_zero_loc:]
|
| 431 |
+
else:
|
| 432 |
+
# We have done right padding, so we need to trim the mask
|
| 433 |
+
last_non_zero_loc = final_attention_mask.sum(0).nonzero()[-1] + 1
|
| 434 |
+
last_non_zero_loc = ((last_non_zero_loc + round_to - 1) // round_to) * round_to
|
| 435 |
+
if last_non_zero_loc < max_token_num:
|
| 436 |
+
final_attention_mask = final_attention_mask[:, :last_non_zero_loc]
|
| 437 |
+
final_embedding = final_embedding[:, :last_non_zero_loc]
|
| 438 |
+
if not skip_labels:
|
| 439 |
+
final_labels = final_labels[:, :last_non_zero_loc]
|
| 440 |
+
final_input_ids = final_input_ids[:, :last_non_zero_loc]
|
| 441 |
+
final_audio_in_mask = final_audio_in_mask[:, :last_non_zero_loc]
|
| 442 |
+
final_audio_in_discrete_codes_mask = final_audio_in_discrete_codes_mask[:, :last_non_zero_loc]
|
| 443 |
+
final_audio_out_mask = final_audio_out_mask[:, :last_non_zero_loc]
|
| 444 |
+
|
| 445 |
+
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
|
| 446 |
+
return (
|
| 447 |
+
final_embedding,
|
| 448 |
+
final_attention_mask,
|
| 449 |
+
final_labels,
|
| 450 |
+
position_ids,
|
| 451 |
+
final_input_ids,
|
| 452 |
+
final_audio_in_mask,
|
| 453 |
+
final_audio_in_discrete_codes_mask,
|
| 454 |
+
final_audio_out_mask,
|
| 455 |
+
)
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
def is_deepspeed_ulysses_enabled():
|
| 459 |
+
if deepspeed_groups is None:
|
| 460 |
+
return False
|
| 461 |
+
|
| 462 |
+
"""Check if sequence parallelism is enabled."""
|
| 463 |
+
return deepspeed_groups._get_sequence_parallel_world_size() > 1
|
| 464 |
+
|
| 465 |
+
|
| 466 |
+
def support_deepspeed_ulysses(module):
|
| 467 |
+
"""A decorator around Pytorch module. It is needed for the module that needs access to sequence parallel info."""
|
| 468 |
+
module._sp_size = None
|
| 469 |
+
module._sp_rank = None
|
| 470 |
+
module._sp_group = None
|
| 471 |
+
|
| 472 |
+
@property
|
| 473 |
+
def sp_size(self):
|
| 474 |
+
if self._sp_size is None:
|
| 475 |
+
self._sp_size = 1
|
| 476 |
+
if is_deepspeed_ulysses_enabled():
|
| 477 |
+
self._sp_size = deepspeed_groups._get_sequence_parallel_group().size()
|
| 478 |
+
return self._sp_size
|
| 479 |
+
|
| 480 |
+
@property
|
| 481 |
+
def sp_rank(self):
|
| 482 |
+
if self._sp_rank is None:
|
| 483 |
+
self._sp_rank = 0
|
| 484 |
+
if is_deepspeed_ulysses_enabled():
|
| 485 |
+
self._sp_rank = deepspeed_groups._get_sequence_parallel_rank()
|
| 486 |
+
return self._sp_rank
|
| 487 |
+
|
| 488 |
+
@property
|
| 489 |
+
def sp_group(self):
|
| 490 |
+
if self._sp_group is None and is_deepspeed_ulysses_enabled():
|
| 491 |
+
self._sp_group = deepspeed_groups._get_sequence_parallel_group()
|
| 492 |
+
return self._sp_group
|
| 493 |
+
|
| 494 |
+
module.sp_size = sp_size
|
| 495 |
+
module.sp_rank = sp_rank
|
| 496 |
+
module.sp_group = sp_group
|
| 497 |
+
|
| 498 |
+
return module
|
| 499 |
+
|
| 500 |
+
|
| 501 |
+
def deepspeed_ulysses_attention(seq_dim=1, head_dim=2):
|
| 502 |
+
"""Perform all-to-all before and after the attention function."""
|
| 503 |
+
|
| 504 |
+
def attention_decorator(attn_func=None):
|
| 505 |
+
def wrapped(*args, **kwargs):
|
| 506 |
+
if is_deepspeed_ulysses_enabled():
|
| 507 |
+
sp_group = deepspeed_groups._get_sequence_parallel_group()
|
| 508 |
+
scatter_idx = head_dim # Scatter on num_heads dimension
|
| 509 |
+
gather_idx = seq_dim # Gather on seq_len dimension
|
| 510 |
+
batch_dim_idx = 0
|
| 511 |
+
args = list(args)
|
| 512 |
+
args[0] = _SeqAllToAll.apply(sp_group, args[0], scatter_idx, gather_idx, batch_dim_idx)
|
| 513 |
+
args[1] = _SeqAllToAll.apply(sp_group, args[1], scatter_idx, gather_idx, batch_dim_idx)
|
| 514 |
+
args[2] = _SeqAllToAll.apply(sp_group, args[2], scatter_idx, gather_idx, batch_dim_idx)
|
| 515 |
+
args = tuple(args)
|
| 516 |
+
|
| 517 |
+
attn_output = attn_func(*args, **kwargs)
|
| 518 |
+
|
| 519 |
+
if is_deepspeed_ulysses_enabled():
|
| 520 |
+
scatter_idx = seq_dim # Scatter back on seq_len dimension
|
| 521 |
+
gather_idx = head_dim # Gather on num_heads dimension
|
| 522 |
+
batch_dim_idx = 0
|
| 523 |
+
attn_output = _SeqAllToAll.apply(sp_group, attn_output, scatter_idx, gather_idx, batch_dim_idx)
|
| 524 |
+
|
| 525 |
+
return attn_output
|
| 526 |
+
|
| 527 |
+
return wrapped
|
| 528 |
+
|
| 529 |
+
return attention_decorator
|
| 530 |
+
|
| 531 |
+
|
| 532 |
+
def deepspeed_ulysses_rope(state_seq_dim=2, trig_seq_dim=1):
|
| 533 |
+
"""Slice the corresponding cos and sin chunks for rope."""
|
| 534 |
+
|
| 535 |
+
def rope_decorator(rope_func=None):
|
| 536 |
+
def wrapped(*args, **kwargs):
|
| 537 |
+
if is_deepspeed_ulysses_enabled():
|
| 538 |
+
sp_rank = deepspeed_groups._get_sequence_parallel_rank()
|
| 539 |
+
args = list(args)
|
| 540 |
+
seq_chunk_size = args[0].size(state_seq_dim)
|
| 541 |
+
args[2] = torch.narrow(args[2], trig_seq_dim, sp_rank * seq_chunk_size, seq_chunk_size)
|
| 542 |
+
args[3] = torch.narrow(args[3], trig_seq_dim, sp_rank * seq_chunk_size, seq_chunk_size)
|
| 543 |
+
args = tuple(args)
|
| 544 |
+
|
| 545 |
+
return rope_func(*args, **kwargs)
|
| 546 |
+
|
| 547 |
+
return wrapped
|
| 548 |
+
|
| 549 |
+
return rope_decorator
|
| 550 |
+
|
| 551 |
+
|
| 552 |
+
def _gather_tensors(input_, group=None):
|
| 553 |
+
"""Gather tensors and concatenate them along a dimension."""
|
| 554 |
+
input_ = input_.contiguous()
|
| 555 |
+
world_size = torch.distributed.get_world_size(group)
|
| 556 |
+
if world_size == 1:
|
| 557 |
+
return input_
|
| 558 |
+
tensor_shapes = [
|
| 559 |
+
torch.empty(len(input_.size()), dtype=torch.int64, device=input_.device) for _ in range(world_size)
|
| 560 |
+
]
|
| 561 |
+
input_size = torch.tensor(input_.size(), dtype=torch.int64, device=input_.device)
|
| 562 |
+
torch.distributed.all_gather(tensor_shapes, input_size, group=group)
|
| 563 |
+
gathered_buffers = [
|
| 564 |
+
torch.empty(tensor_shapes[i].tolist(), dtype=input_.dtype, device=input_.device) for i in range(world_size)
|
| 565 |
+
]
|
| 566 |
+
torch.distributed.all_gather(gathered_buffers, input_, group=group)
|
| 567 |
+
return gathered_buffers
|
| 568 |
+
|
| 569 |
+
|
| 570 |
+
def _scatter_tensors(input_, group=None):
|
| 571 |
+
"""Scatter tensors."""
|
| 572 |
+
world_size = torch.distributed.get_world_size(group)
|
| 573 |
+
if world_size == 1:
|
| 574 |
+
return input_
|
| 575 |
+
rank = torch.distributed.get_rank(group)
|
| 576 |
+
return input_[rank]
|
| 577 |
+
|
| 578 |
+
|
| 579 |
+
class _GatherTensors(torch.autograd.Function):
|
| 580 |
+
"""All gather tensors among the ranks."""
|
| 581 |
+
|
| 582 |
+
@staticmethod
|
| 583 |
+
def symbolic(graph, input_, group):
|
| 584 |
+
return _gather_tensors(input_, group)
|
| 585 |
+
|
| 586 |
+
@staticmethod
|
| 587 |
+
def forward(ctx, input_, group):
|
| 588 |
+
ctx.group = group
|
| 589 |
+
return torch.nested.as_nested_tensor(_gather_tensors(input_, group), layout=torch.jagged)
|
| 590 |
+
|
| 591 |
+
@staticmethod
|
| 592 |
+
def backward(ctx, grad_output):
|
| 593 |
+
return _scatter_tensors(grad_output, ctx.group), None
|
| 594 |
+
|
| 595 |
+
|
| 596 |
+
def all_gather_tensors(input_, size=None, dim=0, group=None):
|
| 597 |
+
if torch.distributed.get_world_size(group) == 1:
|
| 598 |
+
# no sequence parallelism
|
| 599 |
+
return input_
|
| 600 |
+
gathered_tensors = _GatherTensors.apply(input_, group)
|
| 601 |
+
|
| 602 |
+
if size:
|
| 603 |
+
split_gathered_tensors = []
|
| 604 |
+
for s, gathered_tensor in zip(size, gathered_tensors):
|
| 605 |
+
split_gathered_tensor = torch.split(gathered_tensor, s.tolist())
|
| 606 |
+
split_gathered_tensors.append(split_gathered_tensor)
|
| 607 |
+
|
| 608 |
+
gathered_tensors = [y for x in zip(*split_gathered_tensors) for y in x]
|
| 609 |
+
|
| 610 |
+
return torch.cat(gathered_tensors, dim).contiguous()
|
| 611 |
+
|
| 612 |
+
|
| 613 |
+
def get_sequence_data_parallel_world_size():
|
| 614 |
+
return torch.distributed.get_world_size()
|
| 615 |
+
|
| 616 |
+
|
| 617 |
+
def get_sequence_data_parallel_rank():
|
| 618 |
+
return torch.distributed.get_rank()
|
| 619 |
+
|
| 620 |
+
|
| 621 |
+
def get_sequence_data_parallel_group():
|
| 622 |
+
return torch.distributed.group.WORLD
|
| 623 |
+
|
| 624 |
+
|
| 625 |
+
if is_deepspeed_available():
|
| 626 |
+
deepspeed_groups._get_sequence_data_parallel_world_size = get_sequence_data_parallel_world_size
|
| 627 |
+
deepspeed_groups._get_sequence_data_parallel_rank = get_sequence_data_parallel_rank
|
| 628 |
+
deepspeed_groups._get_sequence_data_parallel_group = get_sequence_data_parallel_group
|
| 629 |
+
|
| 630 |
+
|
| 631 |
+
def _gather_tokens(input_, dim=0, group=None):
|
| 632 |
+
"""Gather tensors and concatenate them along a dimension"""
|
| 633 |
+
input_ = input_.contiguous()
|
| 634 |
+
world_size = torch.distributed.get_world_size(group)
|
| 635 |
+
if world_size == 1:
|
| 636 |
+
return input_
|
| 637 |
+
|
| 638 |
+
gather_buffer = torch.empty(world_size * input_.numel(), dtype=input_.dtype, device=input_.device)
|
| 639 |
+
torch.distributed.all_gather_into_tensor(gather_buffer, input_, group=group)
|
| 640 |
+
if dim == 0:
|
| 641 |
+
shape = list(input_.size())
|
| 642 |
+
shape[0] = shape[0] * world_size
|
| 643 |
+
output = gather_buffer.view(shape)
|
| 644 |
+
else:
|
| 645 |
+
tensor_list = [
|
| 646 |
+
gather_buffer.narrow(0, input_.numel() * i, input_.numel()).view_as(input_) for i in range(world_size)
|
| 647 |
+
]
|
| 648 |
+
# Note: torch.cat already creates a contiguous tensor.
|
| 649 |
+
output = torch.cat(tensor_list, dim=dim).contiguous()
|
| 650 |
+
|
| 651 |
+
return output
|
| 652 |
+
|
| 653 |
+
|
| 654 |
+
def _drop_tokens(input_, dim=0, group=None):
|
| 655 |
+
"""Divide a tensor among the sequence parallel ranks"""
|
| 656 |
+
world_size = torch.distributed.get_world_size(group)
|
| 657 |
+
if world_size == 1:
|
| 658 |
+
return input_
|
| 659 |
+
this_rank = torch.distributed.get_rank(group)
|
| 660 |
+
assert input_.shape[dim] % world_size == 0, (
|
| 661 |
+
f"input dimension {dim} ({input_.shape[dim]}) is not divisible by sequence parallel world size ({world_size})"
|
| 662 |
+
)
|
| 663 |
+
chunk_size = input_.shape[dim] // world_size
|
| 664 |
+
|
| 665 |
+
return torch.narrow(input_, dim, this_rank * chunk_size, chunk_size)
|
| 666 |
+
|
| 667 |
+
|
| 668 |
+
class _DropTokens(torch.autograd.Function):
|
| 669 |
+
"Divide tokens equally among the sequence parallel ranks"
|
| 670 |
+
|
| 671 |
+
@staticmethod
|
| 672 |
+
def symbolic(graph, input_, dim, group, grad_scale):
|
| 673 |
+
return _drop_tokens(input_, dim, group)
|
| 674 |
+
|
| 675 |
+
@staticmethod
|
| 676 |
+
def forward(ctx, input_, dim, group, grad_scale):
|
| 677 |
+
ctx.dim = dim
|
| 678 |
+
ctx.group = group
|
| 679 |
+
ctx.grad_scale = grad_scale
|
| 680 |
+
return _drop_tokens(input_, dim, group)
|
| 681 |
+
|
| 682 |
+
@staticmethod
|
| 683 |
+
def backward(ctx, grad_output):
|
| 684 |
+
grad_input = _gather_tokens(grad_output, ctx.dim, ctx.group)
|
| 685 |
+
if ctx.grad_scale != 1:
|
| 686 |
+
grad_input /= ctx.grad_scale
|
| 687 |
+
return grad_input, None, None, None
|
| 688 |
+
|
| 689 |
+
|
| 690 |
+
class _GatherTokens(torch.autograd.Function):
|
| 691 |
+
"Gather tokens among the sequence parallel ranks"
|
| 692 |
+
|
| 693 |
+
@staticmethod
|
| 694 |
+
def symbolic(graph, input_, dim, group, grad_scale):
|
| 695 |
+
return _gather_tokens(input_, dim, group)
|
| 696 |
+
|
| 697 |
+
@staticmethod
|
| 698 |
+
def forward(ctx, input_, dim, group, grad_scale):
|
| 699 |
+
ctx.dim = dim
|
| 700 |
+
ctx.group = group
|
| 701 |
+
ctx.grad_scale = grad_scale
|
| 702 |
+
return _gather_tokens(input_, dim, group)
|
| 703 |
+
|
| 704 |
+
@staticmethod
|
| 705 |
+
def backward(ctx, grad_output):
|
| 706 |
+
grad_input = _drop_tokens(grad_output, ctx.dim, ctx.group)
|
| 707 |
+
if ctx.grad_scale != 1:
|
| 708 |
+
grad_input *= ctx.grad_scale
|
| 709 |
+
return grad_input, None, None, None
|
| 710 |
+
|
| 711 |
+
|
| 712 |
+
def drop_tokens(input_, dim=0, group=None, grad_scale=1):
|
| 713 |
+
if torch.distributed.get_world_size(group) == 1:
|
| 714 |
+
# no sequence parallelism
|
| 715 |
+
return input_
|
| 716 |
+
return _DropTokens.apply(input_, dim, group, grad_scale)
|
| 717 |
+
|
| 718 |
+
|
| 719 |
+
def gather_tokens(input_, dim=0, group=None, grad_scale=1):
|
| 720 |
+
if torch.distributed.get_world_size(group) == 1:
|
| 721 |
+
# no sequence parallelism
|
| 722 |
+
return input_
|
| 723 |
+
return _GatherTokens.apply(input_, dim, group, grad_scale)
|
| 724 |
+
|
| 725 |
+
|
| 726 |
+
def sequence_chunking_per_rank(sp_size, sp_rank, *args, dim=1):
|
| 727 |
+
"""
|
| 728 |
+
Slice the inputs to create chuncks per the sequence parallel rank. This is used for the context parallel training.
|
| 729 |
+
|
| 730 |
+
Args:
|
| 731 |
+
sp_size (`int`):
|
| 732 |
+
Sequence parallel size.
|
| 733 |
+
sp_rank (`int`):
|
| 734 |
+
Sequence parallel rank for the current process.
|
| 735 |
+
dim (`int`):
|
| 736 |
+
The dimension to slice
|
| 737 |
+
"""
|
| 738 |
+
if sp_size == 1:
|
| 739 |
+
return args[0] if len(args) == 1 else args
|
| 740 |
+
|
| 741 |
+
seq_length = args[0].size(dim)
|
| 742 |
+
for arg in args[1:]:
|
| 743 |
+
assert arg.size(dim) == seq_length, (
|
| 744 |
+
f"arg={arg} ({arg.shape[dim]}) does not have the same size as args[0] ({seq_length}) in dimension {dim}"
|
| 745 |
+
)
|
| 746 |
+
assert seq_length % sp_size == 0, (
|
| 747 |
+
f"dimension {dim} ({args[0].shape[dim]}) is not divisible by sequence parallel world size ({sp_size})"
|
| 748 |
+
)
|
| 749 |
+
|
| 750 |
+
sub_seq_length = seq_length // sp_size
|
| 751 |
+
sub_seq_start = sp_rank * sub_seq_length
|
| 752 |
+
|
| 753 |
+
output = []
|
| 754 |
+
for ind in args:
|
| 755 |
+
ind = torch.narrow(ind, dim, sub_seq_start, sub_seq_length)
|
| 756 |
+
output.append(ind)
|
| 757 |
+
|
| 758 |
+
return tuple(output) if len(output) > 1 else output[0]
|
| 759 |
+
|
| 760 |
+
|
| 761 |
+
@contextmanager
|
| 762 |
+
def disable_deepspeed_ulysses():
|
| 763 |
+
"""Disable deepspeed ulysses (sequence parallelism) if it is enabled"""
|
| 764 |
+
if is_deepspeed_ulysses_enabled():
|
| 765 |
+
_old_get_sequence_parallel_world_size = deepspeed_groups._get_sequence_parallel_world_size
|
| 766 |
+
|
| 767 |
+
def _get_sequence_parallel_world_size():
|
| 768 |
+
return 1
|
| 769 |
+
|
| 770 |
+
deepspeed_groups._get_sequence_parallel_world_size = _get_sequence_parallel_world_size
|
| 771 |
+
try:
|
| 772 |
+
yield
|
| 773 |
+
finally:
|
| 774 |
+
deepspeed_groups._get_sequence_parallel_world_size = _old_get_sequence_parallel_world_size
|
| 775 |
+
else:
|
| 776 |
+
context = contextlib.nullcontext
|
| 777 |
+
with context():
|
| 778 |
+
yield
|
higgs_audio/serve/serve_engine.py
ADDED
|
@@ -0,0 +1,474 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import base64
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
from io import BytesIO
|
| 6 |
+
from dataclasses import dataclass, field
|
| 7 |
+
from typing import List, Optional, Union
|
| 8 |
+
from copy import deepcopy
|
| 9 |
+
from transformers import AutoTokenizer, AutoProcessor
|
| 10 |
+
from transformers.cache_utils import StaticCache
|
| 11 |
+
from transformers.generation.streamers import BaseStreamer
|
| 12 |
+
from transformers.generation.stopping_criteria import StoppingCriteria
|
| 13 |
+
from dataclasses import asdict
|
| 14 |
+
from loguru import logger
|
| 15 |
+
import threading
|
| 16 |
+
import librosa
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
from ..dataset.chatml_dataset import (
|
| 20 |
+
ChatMLSample,
|
| 21 |
+
ChatMLDatasetSample,
|
| 22 |
+
prepare_chatml_sample,
|
| 23 |
+
)
|
| 24 |
+
from ..model import HiggsAudioModel
|
| 25 |
+
from ..model.utils import revert_delay_pattern
|
| 26 |
+
from ..data_collator.higgs_audio_collator import HiggsAudioSampleCollator
|
| 27 |
+
from ..audio_processing.higgs_audio_tokenizer import load_higgs_audio_tokenizer
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def normalize_chinese_punctuation(text):
|
| 31 |
+
"""
|
| 32 |
+
Convert Chinese (full-width) punctuation marks to English (half-width) equivalents.
|
| 33 |
+
"""
|
| 34 |
+
# Mapping of Chinese punctuation to English punctuation
|
| 35 |
+
chinese_to_english_punct = {
|
| 36 |
+
",": ",", # comma
|
| 37 |
+
"。": ".", # period
|
| 38 |
+
":": ":", # colon
|
| 39 |
+
";": ";", # semicolon
|
| 40 |
+
"?": "?", # question mark
|
| 41 |
+
"!": "!", # exclamation mark
|
| 42 |
+
"(": "(", # left parenthesis
|
| 43 |
+
")": ")", # right parenthesis
|
| 44 |
+
"【": "[", # left square bracket
|
| 45 |
+
"】": "]", # right square bracket
|
| 46 |
+
"《": "<", # left angle quote
|
| 47 |
+
"》": ">", # right angle quote
|
| 48 |
+
"“": '"', # left double quotation
|
| 49 |
+
"”": '"', # right double quotation
|
| 50 |
+
"‘": "'", # left single quotation
|
| 51 |
+
"’": "'", # right single quotation
|
| 52 |
+
"、": ",", # enumeration comma
|
| 53 |
+
"—": "-", # em dash
|
| 54 |
+
"…": "...", # ellipsis
|
| 55 |
+
"·": ".", # middle dot
|
| 56 |
+
"「": '"', # left corner bracket
|
| 57 |
+
"」": '"', # right corner bracket
|
| 58 |
+
"『": '"', # left double corner bracket
|
| 59 |
+
"』": '"', # right double corner bracket
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
# Replace each Chinese punctuation with its English counterpart
|
| 63 |
+
for zh_punct, en_punct in chinese_to_english_punct.items():
|
| 64 |
+
text = text.replace(zh_punct, en_punct)
|
| 65 |
+
|
| 66 |
+
return text
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
@dataclass
|
| 70 |
+
class HiggsAudioStreamerDelta:
|
| 71 |
+
"""Represents a chunk of generated content, either text or audio tokens."""
|
| 72 |
+
|
| 73 |
+
text: Optional[str] = None
|
| 74 |
+
text_tokens: Optional[torch.Tensor] = None
|
| 75 |
+
audio_tokens: Optional[torch.Tensor] = None
|
| 76 |
+
finish_reason: Optional[str] = None
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class AsyncHiggsAudioStreamer(BaseStreamer):
|
| 80 |
+
"""
|
| 81 |
+
Async streamer that handles both text and audio token generation from Higgs-Audio model.
|
| 82 |
+
Stores chunks in a queue to be consumed by downstream applications.
|
| 83 |
+
|
| 84 |
+
Parameters:
|
| 85 |
+
tokenizer (`AutoTokenizer`):
|
| 86 |
+
The tokenizer used to decode text tokens.
|
| 87 |
+
skip_prompt (`bool`, *optional*, defaults to `False`):
|
| 88 |
+
Whether to skip the prompt tokens in generation.
|
| 89 |
+
timeout (`float`, *optional*):
|
| 90 |
+
The timeout for the queue. If `None`, the queue will block indefinitely.
|
| 91 |
+
decode_kwargs (`dict`, *optional*):
|
| 92 |
+
Additional keyword arguments to pass to the tokenizer's `decode` method.
|
| 93 |
+
|
| 94 |
+
Examples:
|
| 95 |
+
```python
|
| 96 |
+
>>> from transformers import AutoTokenizer
|
| 97 |
+
>>> from threading import Thread
|
| 98 |
+
>>> import asyncio
|
| 99 |
+
|
| 100 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("path/to/higgs/tokenizer")
|
| 101 |
+
>>> model = HiggsAudioModel.from_pretrained("path/to/higgs/model")
|
| 102 |
+
>>> inputs = tokenizer(["Generate some text and audio:"], return_tensors="pt")
|
| 103 |
+
|
| 104 |
+
>>> async def main():
|
| 105 |
+
... streamer = AsyncHiggsAudioStreamer(tokenizer)
|
| 106 |
+
... generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=20)
|
| 107 |
+
... thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
| 108 |
+
... thread.start()
|
| 109 |
+
...
|
| 110 |
+
... async for delta in streamer:
|
| 111 |
+
... if delta.text is not None:
|
| 112 |
+
... print("Text:", delta.text)
|
| 113 |
+
... if delta.audio_tokens is not None:
|
| 114 |
+
... print("Audio tokens shape:", delta.audio_tokens.shape)
|
| 115 |
+
>>> asyncio.run(main())
|
| 116 |
+
```
|
| 117 |
+
"""
|
| 118 |
+
|
| 119 |
+
def __init__(
|
| 120 |
+
self,
|
| 121 |
+
tokenizer: "AutoTokenizer",
|
| 122 |
+
skip_prompt: bool = False,
|
| 123 |
+
timeout: Optional[float] = None,
|
| 124 |
+
audio_num_codebooks: int = 1,
|
| 125 |
+
**decode_kwargs,
|
| 126 |
+
):
|
| 127 |
+
self.tokenizer = tokenizer
|
| 128 |
+
self.skip_prompt = skip_prompt
|
| 129 |
+
self.timeout = timeout
|
| 130 |
+
self.decode_kwargs = decode_kwargs
|
| 131 |
+
self.audio_num_codebooks = audio_num_codebooks
|
| 132 |
+
|
| 133 |
+
# Queue to store generated chunks
|
| 134 |
+
self.queue = asyncio.Queue()
|
| 135 |
+
self.stop_signal = None
|
| 136 |
+
|
| 137 |
+
# Get running event loop
|
| 138 |
+
self.loop = asyncio.get_running_loop()
|
| 139 |
+
self.has_asyncio_timeout = hasattr(asyncio, "timeout")
|
| 140 |
+
|
| 141 |
+
# State tracking
|
| 142 |
+
self.next_tokens_are_prompt = True
|
| 143 |
+
|
| 144 |
+
def put(self, value: torch.Tensor):
|
| 145 |
+
"""
|
| 146 |
+
Receives tokens and processes them as either text or audio tokens.
|
| 147 |
+
For text tokens, decodes and caches them until complete words are formed.
|
| 148 |
+
For audio tokens, directly queues them.
|
| 149 |
+
"""
|
| 150 |
+
if value.shape[0] > 1 and not self.next_tokens_are_prompt:
|
| 151 |
+
# This is likely audio tokens (shape: [audio_num_codebooks])
|
| 152 |
+
assert value.shape[0] == self.audio_num_codebooks, "Number of codebooks mismatch"
|
| 153 |
+
delta = HiggsAudioStreamerDelta(audio_tokens=value)
|
| 154 |
+
self.loop.call_soon_threadsafe(self.queue.put_nowait, delta)
|
| 155 |
+
return
|
| 156 |
+
|
| 157 |
+
# Skip prompt tokens if configured
|
| 158 |
+
if self.skip_prompt and self.next_tokens_are_prompt:
|
| 159 |
+
self.next_tokens_are_prompt = False
|
| 160 |
+
return
|
| 161 |
+
|
| 162 |
+
# Process as text tokens
|
| 163 |
+
if len(value.shape) > 1:
|
| 164 |
+
value = value[0]
|
| 165 |
+
|
| 166 |
+
text = self.tokenizer.decode(value, **self.decode_kwargs)
|
| 167 |
+
delta = HiggsAudioStreamerDelta(text=text, text_tokens=value)
|
| 168 |
+
self.loop.call_soon_threadsafe(self.queue.put_nowait, delta)
|
| 169 |
+
|
| 170 |
+
def end(self):
|
| 171 |
+
"""Flushes any remaining text tokens and signals the end of generation."""
|
| 172 |
+
self.next_tokens_are_prompt = True
|
| 173 |
+
self.loop.call_soon_threadsafe(self.queue.put_nowait, self.stop_signal)
|
| 174 |
+
|
| 175 |
+
def __aiter__(self):
|
| 176 |
+
return self
|
| 177 |
+
|
| 178 |
+
async def __anext__(self):
|
| 179 |
+
try:
|
| 180 |
+
if self.has_asyncio_timeout:
|
| 181 |
+
async with asyncio.timeout(self.timeout):
|
| 182 |
+
value = await self.queue.get()
|
| 183 |
+
else:
|
| 184 |
+
value = await asyncio.wait_for(self.queue.get(), timeout=self.timeout)
|
| 185 |
+
except asyncio.TimeoutError:
|
| 186 |
+
raise TimeoutError()
|
| 187 |
+
else:
|
| 188 |
+
if value == self.stop_signal:
|
| 189 |
+
raise StopAsyncIteration()
|
| 190 |
+
else:
|
| 191 |
+
return value
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
class AsyncStoppingCriteria(StoppingCriteria):
|
| 195 |
+
"""
|
| 196 |
+
Stopping criteria that checks for stop signal from a threading event.
|
| 197 |
+
|
| 198 |
+
Args:
|
| 199 |
+
stop_signal (threading.Event): Event that will receive stop signals
|
| 200 |
+
"""
|
| 201 |
+
|
| 202 |
+
def __init__(self, stop_signal: threading.Event):
|
| 203 |
+
self.stop_signal = stop_signal
|
| 204 |
+
|
| 205 |
+
def __call__(self, input_ids, scores, **kwargs) -> bool:
|
| 206 |
+
if self.stop_signal.is_set():
|
| 207 |
+
logger.info(f"Stop signal received. Can be caused by client disconnection.")
|
| 208 |
+
return True
|
| 209 |
+
return False
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
@dataclass
|
| 213 |
+
class HiggsAudioResponse:
|
| 214 |
+
audio: Optional[np.ndarray] = None
|
| 215 |
+
generated_audio_tokens: Optional[np.ndarray] = None
|
| 216 |
+
sampling_rate: Optional[int] = None
|
| 217 |
+
generated_text: str = ""
|
| 218 |
+
generated_text_tokens: np.ndarray = field(default_factory=np.ndarray)
|
| 219 |
+
usage: Optional[dict] = None
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
class HiggsAudioServeEngine:
|
| 223 |
+
def __init__(
|
| 224 |
+
self,
|
| 225 |
+
model_name_or_path: str,
|
| 226 |
+
audio_tokenizer_name_or_path: str,
|
| 227 |
+
tokenizer_name_or_path: Optional[str] = None,
|
| 228 |
+
device: str = "cuda",
|
| 229 |
+
torch_dtype: Union[torch.dtype, str] = "auto",
|
| 230 |
+
kv_cache_lengths: List[int] = [1024, 4096, 8192], # Multiple KV cache sizes
|
| 231 |
+
):
|
| 232 |
+
"""
|
| 233 |
+
Initialize the HiggsAudioServeEngine, a serving wrapper for the HiggsAudioModel.
|
| 234 |
+
The model, tokenizer, and audio tokenizer will be downloaded from the Hugging Face Hub if they are not local.
|
| 235 |
+
|
| 236 |
+
Args:
|
| 237 |
+
model_name_or_path (str):
|
| 238 |
+
The name or path of the model to load.
|
| 239 |
+
audio_tokenizer_name_or_path (str):
|
| 240 |
+
The name or path of the audio tokenizer to load.
|
| 241 |
+
tokenizer_name_or_path (str):
|
| 242 |
+
The name or path of the tokenizer to load.
|
| 243 |
+
device (str):
|
| 244 |
+
The device to use for the model.
|
| 245 |
+
kv_cache_lengths (List[int]):
|
| 246 |
+
The lengths of the KV caches to use for the model. Used for cuda graph capture when device is cuda.
|
| 247 |
+
torch_dtype (Union[torch.dtype, str]):
|
| 248 |
+
The dtype to use for the model.
|
| 249 |
+
"""
|
| 250 |
+
self.device = device
|
| 251 |
+
self.model_name_or_path = model_name_or_path
|
| 252 |
+
self.torch_dtype = torch_dtype
|
| 253 |
+
|
| 254 |
+
# Initialize model and tokenizer
|
| 255 |
+
self.model = HiggsAudioModel.from_pretrained(model_name_or_path, torch_dtype=torch_dtype).to(device)
|
| 256 |
+
logger.info(f"Loaded model from {model_name_or_path}, dtype: {self.model.dtype}")
|
| 257 |
+
|
| 258 |
+
if tokenizer_name_or_path is None:
|
| 259 |
+
tokenizer_name_or_path = model_name_or_path
|
| 260 |
+
logger.info(f"Loading tokenizer from {tokenizer_name_or_path}")
|
| 261 |
+
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
|
| 262 |
+
|
| 263 |
+
logger.info(f"Initializing Higgs Audio Tokenizer")
|
| 264 |
+
self.audio_tokenizer = load_higgs_audio_tokenizer(audio_tokenizer_name_or_path, device=device)
|
| 265 |
+
|
| 266 |
+
self.audio_num_codebooks = self.model.config.audio_num_codebooks
|
| 267 |
+
self.audio_codebook_size = self.model.config.audio_codebook_size
|
| 268 |
+
self.audio_tokenizer_tps = self.audio_tokenizer.tps
|
| 269 |
+
self.samples_per_token = int(self.audio_tokenizer.sampling_rate // self.audio_tokenizer_tps)
|
| 270 |
+
self.hamming_window_len = 2 * self.audio_num_codebooks * self.samples_per_token
|
| 271 |
+
# Set the audio special tokens
|
| 272 |
+
self.model.set_audio_special_tokens(self.tokenizer)
|
| 273 |
+
|
| 274 |
+
# Prepare KV caches for different lengths
|
| 275 |
+
cache_config = deepcopy(self.model.config.text_config)
|
| 276 |
+
cache_config.num_hidden_layers = self.model.config.text_config.num_hidden_layers
|
| 277 |
+
if self.model.config.audio_dual_ffn_layers:
|
| 278 |
+
cache_config.num_hidden_layers += len(self.model.config.audio_dual_ffn_layers)
|
| 279 |
+
# A list of KV caches for different lengths
|
| 280 |
+
self.kv_caches = {
|
| 281 |
+
length: StaticCache(
|
| 282 |
+
config=cache_config,
|
| 283 |
+
max_batch_size=1,
|
| 284 |
+
max_cache_len=length,
|
| 285 |
+
device=self.model.device,
|
| 286 |
+
dtype=self.model.dtype,
|
| 287 |
+
)
|
| 288 |
+
for length in sorted(kv_cache_lengths)
|
| 289 |
+
}
|
| 290 |
+
|
| 291 |
+
if self.model.config.encode_whisper_embed:
|
| 292 |
+
logger.info(f"Loading whisper processor")
|
| 293 |
+
whisper_processor = AutoProcessor.from_pretrained(
|
| 294 |
+
"openai/whisper-large-v3-turbo",
|
| 295 |
+
trust_remote=True,
|
| 296 |
+
device=self.device,
|
| 297 |
+
)
|
| 298 |
+
else:
|
| 299 |
+
whisper_processor = None
|
| 300 |
+
|
| 301 |
+
# Reuse collator to prepare inference samples
|
| 302 |
+
self.collator = HiggsAudioSampleCollator(
|
| 303 |
+
whisper_processor=whisper_processor,
|
| 304 |
+
encode_whisper_embed=self.model.config.encode_whisper_embed,
|
| 305 |
+
audio_in_token_id=self.model.config.audio_in_token_idx,
|
| 306 |
+
audio_out_token_id=self.model.config.audio_out_token_idx,
|
| 307 |
+
audio_stream_bos_id=self.model.config.audio_stream_bos_id,
|
| 308 |
+
audio_stream_eos_id=self.model.config.audio_stream_eos_id,
|
| 309 |
+
pad_token_id=self.model.config.pad_token_id,
|
| 310 |
+
return_audio_in_tokens=False,
|
| 311 |
+
use_delay_pattern=self.model.config.use_delay_pattern,
|
| 312 |
+
audio_num_codebooks=self.model.config.audio_num_codebooks,
|
| 313 |
+
round_to=1,
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
# Lock to prevent multiple generations from happening at the same time
|
| 317 |
+
self.generate_lock = threading.Lock()
|
| 318 |
+
|
| 319 |
+
# Capture CUDA graphs for each KV cache length
|
| 320 |
+
if device == "cuda":
|
| 321 |
+
logger.info(f"Capturing CUDA graphs for each KV cache length")
|
| 322 |
+
self.model.capture_model(self.kv_caches.values())
|
| 323 |
+
|
| 324 |
+
def _prepare_inputs(self, chat_ml_sample: ChatMLSample, force_audio_gen: bool = False):
|
| 325 |
+
input_tokens, _, audio_contents, _ = prepare_chatml_sample(
|
| 326 |
+
chat_ml_sample,
|
| 327 |
+
self.tokenizer,
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
postfix = "<|start_header_id|>assistant<|end_header_id|>\n\n"
|
| 331 |
+
if force_audio_gen:
|
| 332 |
+
postfix += "<|audio_out_bos|>"
|
| 333 |
+
postfix = self.tokenizer.encode(postfix, add_special_tokens=False)
|
| 334 |
+
input_tokens.extend(postfix)
|
| 335 |
+
|
| 336 |
+
# Configure the audio inputs
|
| 337 |
+
audio_ids_l = []
|
| 338 |
+
for audio_content in audio_contents:
|
| 339 |
+
if audio_content.audio_url not in ["placeholder", ""]:
|
| 340 |
+
raw_audio, _ = librosa.load(audio_content.audio_url, sr=self.audio_tokenizer.sampling_rate)
|
| 341 |
+
elif audio_content.raw_audio is not None:
|
| 342 |
+
raw_audio, _ = librosa.load(
|
| 343 |
+
BytesIO(base64.b64decode(audio_content.raw_audio)),
|
| 344 |
+
sr=self.audio_tokenizer.sampling_rate,
|
| 345 |
+
)
|
| 346 |
+
else:
|
| 347 |
+
raw_audio = None
|
| 348 |
+
|
| 349 |
+
if raw_audio is not None:
|
| 350 |
+
audio_ids = self.audio_tokenizer.encode(raw_audio, self.audio_tokenizer.sampling_rate)
|
| 351 |
+
audio_ids_l.append(audio_ids.squeeze(0).cpu())
|
| 352 |
+
|
| 353 |
+
if len(audio_ids_l) > 0:
|
| 354 |
+
audio_ids_start = torch.tensor(
|
| 355 |
+
np.cumsum(np.array([0] + [audio_ids.shape[1] for audio_ids in audio_ids_l])),
|
| 356 |
+
dtype=torch.long,
|
| 357 |
+
device=self.device,
|
| 358 |
+
)[0:-1]
|
| 359 |
+
audio_ids_concat = torch.cat(audio_ids_l, dim=1)
|
| 360 |
+
else:
|
| 361 |
+
audio_ids_start = None
|
| 362 |
+
audio_ids_concat = None
|
| 363 |
+
|
| 364 |
+
sample = ChatMLDatasetSample(
|
| 365 |
+
input_ids=torch.LongTensor(input_tokens),
|
| 366 |
+
label_ids=None,
|
| 367 |
+
audio_ids_concat=audio_ids_concat,
|
| 368 |
+
audio_ids_start=audio_ids_start,
|
| 369 |
+
audio_waveforms_concat=None,
|
| 370 |
+
audio_waveforms_start=None,
|
| 371 |
+
audio_sample_rate=None,
|
| 372 |
+
audio_speaker_indices=None,
|
| 373 |
+
)
|
| 374 |
+
data = self.collator([sample])
|
| 375 |
+
inputs = asdict(data)
|
| 376 |
+
for k, v in inputs.items():
|
| 377 |
+
if isinstance(v, torch.Tensor):
|
| 378 |
+
inputs[k] = v.to(self.model.device)
|
| 379 |
+
|
| 380 |
+
return inputs
|
| 381 |
+
|
| 382 |
+
def _prepare_kv_caches(self):
|
| 383 |
+
for kv_cache in self.kv_caches.values():
|
| 384 |
+
kv_cache.reset()
|
| 385 |
+
|
| 386 |
+
def generate(
|
| 387 |
+
self,
|
| 388 |
+
chat_ml_sample: ChatMLSample,
|
| 389 |
+
max_new_tokens: int,
|
| 390 |
+
temperature: float = 0.7,
|
| 391 |
+
top_k: Optional[int] = None,
|
| 392 |
+
top_p: float = 0.95,
|
| 393 |
+
stop_strings: Optional[List[str]] = None,
|
| 394 |
+
force_audio_gen: bool = False,
|
| 395 |
+
ras_win_len: Optional[int] = None,
|
| 396 |
+
ras_win_max_num_repeat: int = 2,
|
| 397 |
+
):
|
| 398 |
+
"""
|
| 399 |
+
Generate audio from a chatml sample.
|
| 400 |
+
Args:
|
| 401 |
+
chat_ml_sample: A chatml sample.
|
| 402 |
+
max_new_tokens: The maximum number of new tokens to generate.
|
| 403 |
+
temperature: The temperature to use for the generation.
|
| 404 |
+
top_p: The top p to use for the generation.
|
| 405 |
+
Returns:
|
| 406 |
+
A dictionary with the following keys:
|
| 407 |
+
audio: The generated audio.
|
| 408 |
+
sampling_rate: The sampling rate of the generated audio.
|
| 409 |
+
"""
|
| 410 |
+
# Default stop strings
|
| 411 |
+
if stop_strings is None:
|
| 412 |
+
stop_strings = ["<|end_of_text|>", "<|eot_id|>"]
|
| 413 |
+
|
| 414 |
+
with torch.no_grad(), self.generate_lock:
|
| 415 |
+
inputs = self._prepare_inputs(chat_ml_sample, force_audio_gen=force_audio_gen)
|
| 416 |
+
prompt_token_ids = inputs["input_ids"][0].cpu().numpy()
|
| 417 |
+
|
| 418 |
+
self._prepare_kv_caches()
|
| 419 |
+
|
| 420 |
+
outputs = self.model.generate(
|
| 421 |
+
**inputs,
|
| 422 |
+
max_new_tokens=max_new_tokens,
|
| 423 |
+
use_cache=True,
|
| 424 |
+
stop_strings=stop_strings,
|
| 425 |
+
tokenizer=self.tokenizer,
|
| 426 |
+
do_sample=False if temperature == 0.0 else True,
|
| 427 |
+
temperature=temperature,
|
| 428 |
+
top_k=top_k,
|
| 429 |
+
top_p=top_p,
|
| 430 |
+
past_key_values_buckets=self.kv_caches,
|
| 431 |
+
ras_win_len=ras_win_len,
|
| 432 |
+
ras_win_max_num_repeat=ras_win_max_num_repeat,
|
| 433 |
+
)
|
| 434 |
+
|
| 435 |
+
if len(outputs[1]) > 0:
|
| 436 |
+
wv_list = []
|
| 437 |
+
for output_audio in outputs[1]:
|
| 438 |
+
vq_code = revert_delay_pattern(output_audio).clip(0, self.audio_codebook_size - 1)[:, 1:-1]
|
| 439 |
+
wv_numpy = self.audio_tokenizer.decode(vq_code.unsqueeze(0))[0, 0]
|
| 440 |
+
wv_list.append(wv_numpy)
|
| 441 |
+
wv_numpy = np.concatenate(wv_list)
|
| 442 |
+
else:
|
| 443 |
+
wv_numpy = None
|
| 444 |
+
|
| 445 |
+
# We only support one request at a time now
|
| 446 |
+
generated_text_tokens = outputs[0][0].cpu().numpy()[len(prompt_token_ids) :]
|
| 447 |
+
generated_text = self.tokenizer.decode(generated_text_tokens)
|
| 448 |
+
generated_audio_tokens = outputs[1][0].cpu().numpy()
|
| 449 |
+
return HiggsAudioResponse(
|
| 450 |
+
audio=wv_numpy,
|
| 451 |
+
generated_audio_tokens=generated_audio_tokens,
|
| 452 |
+
sampling_rate=self.audio_tokenizer.sampling_rate,
|
| 453 |
+
generated_text=generated_text,
|
| 454 |
+
generated_text_tokens=generated_text_tokens,
|
| 455 |
+
usage={
|
| 456 |
+
"prompt_tokens": prompt_token_ids.shape[0],
|
| 457 |
+
"completion_tokens": generated_text_tokens.shape[0] + generated_audio_tokens.shape[1],
|
| 458 |
+
"total_tokens": (
|
| 459 |
+
prompt_token_ids.shape[0] + generated_text_tokens.shape[0] + generated_audio_tokens.shape[1]
|
| 460 |
+
),
|
| 461 |
+
"cached_tokens": 0,
|
| 462 |
+
},
|
| 463 |
+
)
|
| 464 |
+
|
| 465 |
+
def text_normalize(self, text: str) -> str:
|
| 466 |
+
"""
|
| 467 |
+
Normalize the text.
|
| 468 |
+
"""
|
| 469 |
+
# Perform some basic normalization
|
| 470 |
+
text = normalize_chinese_punctuation(text)
|
| 471 |
+
# Handle parentheses
|
| 472 |
+
text = text.replace("(", " ")
|
| 473 |
+
text = text.replace(")", " ")
|
| 474 |
+
return text
|
higgs_audio/serve/utils.py
ADDED
|
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import uuid
|
| 2 |
+
import base64
|
| 3 |
+
import re
|
| 4 |
+
import regex
|
| 5 |
+
from typing import AsyncGenerator, Union
|
| 6 |
+
import io
|
| 7 |
+
from pydub import AudioSegment
|
| 8 |
+
import torch
|
| 9 |
+
import numpy as np
|
| 10 |
+
from functools import lru_cache
|
| 11 |
+
|
| 12 |
+
from ..audio_processing.higgs_audio_tokenizer import HiggsAudioTokenizer
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def random_uuid() -> str:
|
| 16 |
+
return str(uuid.uuid4().hex)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
async def async_generator_wrap(first_element, gen: AsyncGenerator):
|
| 20 |
+
"""Wrap an async generator with the first element."""
|
| 21 |
+
yield first_element
|
| 22 |
+
async for item in gen:
|
| 23 |
+
yield item
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@lru_cache(maxsize=50)
|
| 27 |
+
def encode_base64_content_from_file(file_path: str) -> str:
|
| 28 |
+
"""Encode a content from a local file to base64 format."""
|
| 29 |
+
# Read the MP3 file as binary and encode it directly to Base64
|
| 30 |
+
with open(file_path, "rb") as audio_file:
|
| 31 |
+
audio_base64 = base64.b64encode(audio_file.read()).decode("utf-8")
|
| 32 |
+
return audio_base64
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def pcm16_to_target_format(
|
| 36 |
+
np_audio: np.ndarray,
|
| 37 |
+
sample_rate: int,
|
| 38 |
+
bit_depth: int,
|
| 39 |
+
channels: int,
|
| 40 |
+
format: str,
|
| 41 |
+
target_rate: int,
|
| 42 |
+
):
|
| 43 |
+
wav_audio = AudioSegment(
|
| 44 |
+
np_audio.tobytes(),
|
| 45 |
+
frame_rate=sample_rate,
|
| 46 |
+
sample_width=bit_depth // 8,
|
| 47 |
+
channels=channels,
|
| 48 |
+
)
|
| 49 |
+
if target_rate is not None and target_rate != sample_rate:
|
| 50 |
+
wav_audio = wav_audio.set_frame_rate(target_rate)
|
| 51 |
+
|
| 52 |
+
# Convert WAV to MP3
|
| 53 |
+
target_io = io.BytesIO()
|
| 54 |
+
wav_audio.export(target_io, format=format)
|
| 55 |
+
target_io.seek(0)
|
| 56 |
+
|
| 57 |
+
return target_io
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
chinese_char_pattern = re.compile(r"[\u4e00-\u9fff]+")
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def contains_chinese(text: str):
|
| 64 |
+
return bool(chinese_char_pattern.search(text))
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
# remove blank between chinese character
|
| 68 |
+
def replace_blank(text: str):
|
| 69 |
+
out_str = []
|
| 70 |
+
for i, c in enumerate(text):
|
| 71 |
+
if c == " ":
|
| 72 |
+
if (text[i + 1].isascii() and text[i + 1] != " ") and (text[i - 1].isascii() and text[i - 1] != " "):
|
| 73 |
+
out_str.append(c)
|
| 74 |
+
else:
|
| 75 |
+
out_str.append(c)
|
| 76 |
+
return "".join(out_str)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def replace_corner_mark(text: str):
|
| 80 |
+
text = text.replace("²", "平方")
|
| 81 |
+
text = text.replace("³", "立方")
|
| 82 |
+
return text
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
# remove meaningless symbol
|
| 86 |
+
def remove_bracket(text: str):
|
| 87 |
+
text = text.replace("(", "").replace(")", "")
|
| 88 |
+
text = text.replace("【", "").replace("】", "")
|
| 89 |
+
text = text.replace("`", "").replace("`", "")
|
| 90 |
+
text = text.replace("——", " ")
|
| 91 |
+
return text
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
# split paragrah logic:
|
| 95 |
+
# 1. per sentence max len token_max_n, min len token_min_n, merge if last sentence len less than merge_len
|
| 96 |
+
# 2. cal sentence len according to lang
|
| 97 |
+
# 3. split sentence according to puncatation
|
| 98 |
+
def split_paragraph(
|
| 99 |
+
text: str,
|
| 100 |
+
tokenize,
|
| 101 |
+
lang="zh",
|
| 102 |
+
token_max_n=80,
|
| 103 |
+
token_min_n=60,
|
| 104 |
+
merge_len=20,
|
| 105 |
+
comma_split=False,
|
| 106 |
+
):
|
| 107 |
+
def calc_utt_length(_text: str):
|
| 108 |
+
if lang == "zh":
|
| 109 |
+
return len(_text)
|
| 110 |
+
else:
|
| 111 |
+
return len(tokenize(_text))
|
| 112 |
+
|
| 113 |
+
def should_merge(_text: str):
|
| 114 |
+
if lang == "zh":
|
| 115 |
+
return len(_text) < merge_len
|
| 116 |
+
else:
|
| 117 |
+
return len(tokenize(_text)) < merge_len
|
| 118 |
+
|
| 119 |
+
if lang == "zh":
|
| 120 |
+
pounc = ["。", "?", "!", ";", ":", "、", ".", "?", "!", ";"]
|
| 121 |
+
else:
|
| 122 |
+
pounc = [".", "?", "!", ";", ":"]
|
| 123 |
+
if comma_split:
|
| 124 |
+
pounc.extend([",", ","])
|
| 125 |
+
|
| 126 |
+
if text[-1] not in pounc:
|
| 127 |
+
if lang == "zh":
|
| 128 |
+
text += "。"
|
| 129 |
+
else:
|
| 130 |
+
text += "."
|
| 131 |
+
|
| 132 |
+
st = 0
|
| 133 |
+
utts = []
|
| 134 |
+
for i, c in enumerate(text):
|
| 135 |
+
if c in pounc:
|
| 136 |
+
if len(text[st:i]) > 0:
|
| 137 |
+
utts.append(text[st:i] + c)
|
| 138 |
+
if i + 1 < len(text) and text[i + 1] in ['"', "”"]:
|
| 139 |
+
tmp = utts.pop(-1)
|
| 140 |
+
utts.append(tmp + text[i + 1])
|
| 141 |
+
st = i + 2
|
| 142 |
+
else:
|
| 143 |
+
st = i + 1
|
| 144 |
+
|
| 145 |
+
final_utts = []
|
| 146 |
+
cur_utt = ""
|
| 147 |
+
for utt in utts:
|
| 148 |
+
if calc_utt_length(cur_utt + utt) > token_max_n and calc_utt_length(cur_utt) > token_min_n:
|
| 149 |
+
final_utts.append(cur_utt)
|
| 150 |
+
cur_utt = ""
|
| 151 |
+
cur_utt = cur_utt + utt
|
| 152 |
+
if len(cur_utt) > 0:
|
| 153 |
+
if should_merge(cur_utt) and len(final_utts) != 0:
|
| 154 |
+
final_utts[-1] = final_utts[-1] + cur_utt
|
| 155 |
+
else:
|
| 156 |
+
final_utts.append(cur_utt)
|
| 157 |
+
|
| 158 |
+
return final_utts
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def is_only_punctuation(text: str):
|
| 162 |
+
# Regular expression: Match strings that consist only of punctuation marks or are empty.
|
| 163 |
+
punctuation_pattern = r"^[\p{P}\p{S}]*$"
|
| 164 |
+
return bool(regex.fullmatch(punctuation_pattern, text))
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
# spell Arabic numerals
|
| 168 |
+
def spell_out_number(text: str, inflect_parser):
|
| 169 |
+
new_text = []
|
| 170 |
+
st = None
|
| 171 |
+
for i, c in enumerate(text):
|
| 172 |
+
if not c.isdigit():
|
| 173 |
+
if st is not None:
|
| 174 |
+
num_str = inflect_parser.number_to_words(text[st:i])
|
| 175 |
+
new_text.append(num_str)
|
| 176 |
+
st = None
|
| 177 |
+
new_text.append(c)
|
| 178 |
+
else:
|
| 179 |
+
if st is None:
|
| 180 |
+
st = i
|
| 181 |
+
if st is not None and st < len(text):
|
| 182 |
+
num_str = inflect_parser.number_to_words(text[st:])
|
| 183 |
+
new_text.append(num_str)
|
| 184 |
+
return "".join(new_text)
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def remove_emoji(text: str):
|
| 188 |
+
# Pattern to match emojis and their modifiers
|
| 189 |
+
# - Standard emoji range
|
| 190 |
+
# - Zero-width joiners (U+200D)
|
| 191 |
+
# - Variation selectors (U+FE0F, U+FE0E)
|
| 192 |
+
# - Skin tone modifiers (U+1F3FB to U+1F3FF)
|
| 193 |
+
emoji_pattern = re.compile(
|
| 194 |
+
r"["
|
| 195 |
+
r"\U00010000-\U0010FFFF" # Standard emoji range
|
| 196 |
+
r"\u200D" # Zero-width joiner
|
| 197 |
+
r"\uFE0F\uFE0E" # Variation selectors
|
| 198 |
+
r"\U0001F3FB-\U0001F3FF" # Skin tone modifiers
|
| 199 |
+
r"]+",
|
| 200 |
+
flags=re.UNICODE,
|
| 201 |
+
)
|
| 202 |
+
return emoji_pattern.sub(r"", text)
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def remove_repeated_punctuations(text, punctuations):
|
| 206 |
+
if len(punctuations) == 0:
|
| 207 |
+
return text
|
| 208 |
+
pattern = f"[{re.escape(''.join(punctuations))}]" # Create regex pattern for given punctuations
|
| 209 |
+
return re.sub(rf"({pattern})\1+", r"\1", text)
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def full_to_half_width(text: str) -> str:
|
| 213 |
+
"""Convert full-width punctuation to half-width in a given string."""
|
| 214 |
+
full_width = "!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~"
|
| 215 |
+
half_width = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
|
| 216 |
+
trans_table = str.maketrans(full_width, half_width)
|
| 217 |
+
return text.translate(trans_table)
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def split_interleaved_delayed_audios(
|
| 221 |
+
audio_data: Union[list[list[int]], torch.Tensor],
|
| 222 |
+
audio_tokenizer: HiggsAudioTokenizer,
|
| 223 |
+
audio_stream_eos_id: int,
|
| 224 |
+
) -> list[tuple[list[list[int]], torch.Tensor]]:
|
| 225 |
+
separator = [audio_stream_eos_id] * audio_tokenizer.num_codebooks
|
| 226 |
+
|
| 227 |
+
# Convert separator to numpy array if audio_data is numpy array
|
| 228 |
+
if isinstance(audio_data, torch.Tensor):
|
| 229 |
+
audio_data = audio_data.transpose(1, 0)
|
| 230 |
+
separator = torch.tensor(separator)
|
| 231 |
+
# Find the indices where the rows equal the separator
|
| 232 |
+
split_indices = torch.where(torch.all(audio_data == separator, dim=1))[0]
|
| 233 |
+
start = 0
|
| 234 |
+
groups = []
|
| 235 |
+
for idx in split_indices:
|
| 236 |
+
groups.append(audio_data[start:idx].transpose(1, 0))
|
| 237 |
+
start = idx + 1
|
| 238 |
+
if start < len(audio_data):
|
| 239 |
+
groups.append(audio_data[start:].transpose(1, 0))
|
| 240 |
+
else:
|
| 241 |
+
groups = []
|
| 242 |
+
current = []
|
| 243 |
+
for row in audio_data:
|
| 244 |
+
current.append(row)
|
| 245 |
+
|
| 246 |
+
if row == separator:
|
| 247 |
+
groups.append(current)
|
| 248 |
+
current = []
|
| 249 |
+
|
| 250 |
+
# Don't forget the last group if there's no trailing separator
|
| 251 |
+
if current:
|
| 252 |
+
groups.append(current)
|
| 253 |
+
|
| 254 |
+
return groups
|
pyproject.toml
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[tool.ruff]
|
| 6 |
+
line-length = 119
|
| 7 |
+
target-version = "py310"
|
| 8 |
+
indent-width = 4
|
| 9 |
+
exclude = [
|
| 10 |
+
".bzr",
|
| 11 |
+
".direnv",
|
| 12 |
+
".eggs",
|
| 13 |
+
".git",
|
| 14 |
+
".git-rewrite",
|
| 15 |
+
".hg",
|
| 16 |
+
".ipynb_checkpoints",
|
| 17 |
+
".mypy_cache",
|
| 18 |
+
".nox",
|
| 19 |
+
".pants.d",
|
| 20 |
+
".pyenv",
|
| 21 |
+
".pytest_cache",
|
| 22 |
+
".pytype",
|
| 23 |
+
".ruff_cache",
|
| 24 |
+
".svn",
|
| 25 |
+
".tox",
|
| 26 |
+
".venv",
|
| 27 |
+
".vscode",
|
| 28 |
+
"__pypackages__",
|
| 29 |
+
"_build",
|
| 30 |
+
"buck-out",
|
| 31 |
+
"build",
|
| 32 |
+
"dist",
|
| 33 |
+
"node_modules",
|
| 34 |
+
"site-packages",
|
| 35 |
+
"venv",
|
| 36 |
+
"external",
|
| 37 |
+
"third_party",
|
| 38 |
+
]
|
| 39 |
+
|
| 40 |
+
[tool.ruff.lint]
|
| 41 |
+
preview = true
|
| 42 |
+
ignore-init-module-imports = true
|
| 43 |
+
extend-select = [
|
| 44 |
+
"B009", # static getattr
|
| 45 |
+
"B010", # static setattr
|
| 46 |
+
"CPY", # Copyright
|
| 47 |
+
"E", # PEP8 errors
|
| 48 |
+
"F", # PEP8 formatting
|
| 49 |
+
"I", # Import sorting
|
| 50 |
+
"TID251", # Banned API
|
| 51 |
+
"UP", # Pyupgrade
|
| 52 |
+
"W", # PEP8 warnings
|
| 53 |
+
]
|
| 54 |
+
ignore = [
|
| 55 |
+
"E501", # Line length (handled by ruff-format)
|
| 56 |
+
"E741", # Ambiguous variable name
|
| 57 |
+
"W605", # Invalid escape sequence
|
| 58 |
+
"UP007", # X | Y type annotations
|
| 59 |
+
]
|
| 60 |
+
|
| 61 |
+
[tool.ruff.lint.per-file-ignores]
|
| 62 |
+
"__init__.py" = [
|
| 63 |
+
"F401", # Ignore seemingly unused imports (they're meant for re-export)
|
| 64 |
+
]
|
| 65 |
+
|
| 66 |
+
[tool.ruff.lint.isort]
|
| 67 |
+
lines-after-imports = 2
|
| 68 |
+
known-first-party = ["character_tuning"]
|
| 69 |
+
|
| 70 |
+
[tool.ruff.format]
|
| 71 |
+
# Like Black, use double quotes for strings.
|
| 72 |
+
quote-style = "double"
|
| 73 |
+
|
| 74 |
+
# Like Black, indent with spaces, rather than tabs.
|
| 75 |
+
indent-style = "space"
|
| 76 |
+
|
| 77 |
+
# Like Black, respect magic trailing commas.
|
| 78 |
+
skip-magic-trailing-comma = false
|
| 79 |
+
|
| 80 |
+
# Like Black, automatically detect the appropriate line ending.
|
| 81 |
+
line-ending = "auto"
|
| 82 |
+
|
| 83 |
+
# Enable auto-formatting of code examples in docstrings. Markdown,
|
| 84 |
+
# reStructuredText code/literal blocks and doctests are all supported.
|
| 85 |
+
#
|
| 86 |
+
# This is currently disabled by default, but it is planned for this
|
| 87 |
+
# to be opt-out in the future.
|
| 88 |
+
docstring-code-format = false
|
| 89 |
+
|
| 90 |
+
# Set the line length limit used when formatting code snippets in
|
| 91 |
+
# docstrings.
|
| 92 |
+
#
|
| 93 |
+
# This only has an effect when the `docstring-code-format` setting is
|
| 94 |
+
# enabled.
|
| 95 |
+
docstring-code-line-length = "dynamic"
|
| 96 |
+
|
| 97 |
+
[tool.ruff.lint.flake8-tidy-imports.banned-api]
|
| 98 |
+
"os.getenv".msg = "Use os.environ instead"
|
| 99 |
+
"os.putenv".msg = "Use os.environ instead"
|
| 100 |
+
"os.unsetenv".msg = "Use os.environ instead"
|
requirements.txt
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
descript-audio-codec
|
| 2 |
+
torch==2.5.1
|
| 3 |
+
torchaudio==2.5.1
|
| 4 |
+
transformers>=4.45.1,<4.47.0
|
| 5 |
+
librosa
|
| 6 |
+
dacite
|
| 7 |
+
boto3==1.35.36
|
| 8 |
+
s3fs
|
| 9 |
+
json_repair
|
| 10 |
+
pandas
|
| 11 |
+
pydantic
|
| 12 |
+
vector_quantize_pytorch
|
| 13 |
+
loguru
|
| 14 |
+
pydub
|
| 15 |
+
ruff==0.12.2
|
| 16 |
+
omegaconf
|
| 17 |
+
click
|
theme.json
ADDED
|
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"theme": {
|
| 3 |
+
"_font": [
|
| 4 |
+
{"__gradio_font__": true, "name": "Arial", "class": "font"},
|
| 5 |
+
{"__gradio_font__": true, "name": "sans-serif", "class": "font"}
|
| 6 |
+
],
|
| 7 |
+
"_font_mono": [
|
| 8 |
+
{"__gradio_font__": true, "name": "Courier New", "class": "font"},
|
| 9 |
+
{"__gradio_font__": true, "name": "monospace", "class": "font"}
|
| 10 |
+
],
|
| 11 |
+
"_stylesheets": [],
|
| 12 |
+
"background_fill_primary": "black",
|
| 13 |
+
"background_fill_primary_dark": "black",
|
| 14 |
+
"background_fill_secondary": "#1a1a1a",
|
| 15 |
+
"background_fill_secondary_dark": "#1a1a1a",
|
| 16 |
+
"block_background_fill": "#1a1a1a",
|
| 17 |
+
"block_background_fill_dark": "#1a1a1a",
|
| 18 |
+
"block_border_color": "#333333",
|
| 19 |
+
"block_border_color_dark": "#333333",
|
| 20 |
+
"block_border_width": "1px",
|
| 21 |
+
"block_info_text_color": "#cccccc",
|
| 22 |
+
"block_info_text_color_dark": "#cccccc",
|
| 23 |
+
"block_info_text_size": "*text_sm",
|
| 24 |
+
"block_info_text_weight": "600",
|
| 25 |
+
"block_label_background_fill": "#333333",
|
| 26 |
+
"block_label_background_fill_dark": "#333333",
|
| 27 |
+
"block_label_border_color": "#444444",
|
| 28 |
+
"block_label_border_color_dark": "#444444",
|
| 29 |
+
"block_label_border_width": "1px",
|
| 30 |
+
"block_label_margin": "*spacing_md",
|
| 31 |
+
"block_label_padding": "*spacing_sm *spacing_md",
|
| 32 |
+
"block_label_radius": "*radius_md",
|
| 33 |
+
"block_label_right_radius": "0 calc(*radius_lg - 1px) 0 calc(*radius_lg - 1px)",
|
| 34 |
+
"block_label_text_color": "yellow",
|
| 35 |
+
"block_label_text_color_dark": "yellow",
|
| 36 |
+
"block_label_text_size": "*text_md",
|
| 37 |
+
"block_label_text_weight": "700",
|
| 38 |
+
"block_padding": "*spacing_xl calc(*spacing_xl + 2px)",
|
| 39 |
+
"block_radius": "*radius_lg",
|
| 40 |
+
"block_shadow": "none",
|
| 41 |
+
"block_title_background_fill": "#333333",
|
| 42 |
+
"block_title_border_color": "none",
|
| 43 |
+
"block_title_border_width": "0px",
|
| 44 |
+
"block_title_padding": "*block_label_padding",
|
| 45 |
+
"block_title_radius": "*block_label_radius",
|
| 46 |
+
"block_title_text_color": "yellow",
|
| 47 |
+
"block_title_text_color_dark": "yellow",
|
| 48 |
+
"block_title_text_size": "*text_md",
|
| 49 |
+
"block_title_text_weight": "700",
|
| 50 |
+
"body_background_fill": "black",
|
| 51 |
+
"body_background_fill_dark": "black",
|
| 52 |
+
"body_text_color": "white",
|
| 53 |
+
"body_text_color_dark": "white",
|
| 54 |
+
"body_text_color_subdued": "#aaaaaa",
|
| 55 |
+
"body_text_color_subdued_dark": "#aaaaaa",
|
| 56 |
+
"body_text_size": "*text_md",
|
| 57 |
+
"body_text_weight": "600",
|
| 58 |
+
"border_color_accent": "yellow",
|
| 59 |
+
"border_color_accent_dark": "yellow",
|
| 60 |
+
"border_color_primary": "#333333",
|
| 61 |
+
"border_color_primary_dark": "#333333",
|
| 62 |
+
"button_border_width": "*input_border_width",
|
| 63 |
+
"button_border_width_dark": "*input_border_width",
|
| 64 |
+
"button_cancel_background_fill": "#333333",
|
| 65 |
+
"button_cancel_background_fill_dark": "#333333",
|
| 66 |
+
"button_cancel_background_fill_hover": "#444444",
|
| 67 |
+
"button_cancel_background_fill_hover_dark": "#444444",
|
| 68 |
+
"button_cancel_border_color": "#555555",
|
| 69 |
+
"button_cancel_border_color_dark": "#555555",
|
| 70 |
+
"button_cancel_border_color_hover": "#666666",
|
| 71 |
+
"button_cancel_border_color_hover_dark": "#666666",
|
| 72 |
+
"button_cancel_text_color": "white",
|
| 73 |
+
"button_cancel_text_color_dark": "white",
|
| 74 |
+
"button_cancel_text_color_hover": "white",
|
| 75 |
+
"button_cancel_text_color_hover_dark": "white",
|
| 76 |
+
"button_large_padding": "*spacing_lg calc(2 * *spacing_lg)",
|
| 77 |
+
"button_large_radius": "*radius_lg",
|
| 78 |
+
"button_large_text_size": "*text_lg",
|
| 79 |
+
"button_large_text_weight": "700",
|
| 80 |
+
"button_primary_background_fill": "yellow",
|
| 81 |
+
"button_primary_background_fill_dark": "yellow",
|
| 82 |
+
"button_primary_background_fill_hover": "#ffff80",
|
| 83 |
+
"button_primary_background_fill_hover_dark": "#ffff80",
|
| 84 |
+
"button_primary_border_color": "yellow",
|
| 85 |
+
"button_primary_border_color_dark": "yellow",
|
| 86 |
+
"button_primary_border_color_hover": "#ffff80",
|
| 87 |
+
"button_primary_border_color_hover_dark": "#ffff80",
|
| 88 |
+
"button_primary_text_color": "black",
|
| 89 |
+
"button_primary_text_color_dark": "black",
|
| 90 |
+
"button_primary_text_color_hover": "black",
|
| 91 |
+
"button_primary_text_color_hover_dark": "black",
|
| 92 |
+
"button_secondary_background_fill": "#333333",
|
| 93 |
+
"button_secondary_background_fill_dark": "#333333",
|
| 94 |
+
"button_secondary_background_fill_hover": "#444444",
|
| 95 |
+
"button_secondary_background_fill_hover_dark": "#444444",
|
| 96 |
+
"button_secondary_border_color": "#555555",
|
| 97 |
+
"button_secondary_border_color_dark": "#555555",
|
| 98 |
+
"button_secondary_border_color_hover": "#666666",
|
| 99 |
+
"button_secondary_border_color_hover_dark": "#666666",
|
| 100 |
+
"button_secondary_text_color": "white",
|
| 101 |
+
"button_secondary_text_color_dark": "white",
|
| 102 |
+
"button_secondary_text_color_hover": "white",
|
| 103 |
+
"button_secondary_text_color_hover_dark": "white",
|
| 104 |
+
"button_shadow": "*shadow_drop_lg",
|
| 105 |
+
"button_shadow_active": "*shadow_inset",
|
| 106 |
+
"button_shadow_hover": "*shadow_drop_lg",
|
| 107 |
+
"button_small_padding": "*spacing_sm calc(2 * *spacing_sm)",
|
| 108 |
+
"button_small_radius": "*radius_lg",
|
| 109 |
+
"button_small_text_size": "*text_md",
|
| 110 |
+
"button_small_text_weight": "700",
|
| 111 |
+
"button_transition": "background-color 0.2s ease",
|
| 112 |
+
"checkbox_background_color": "#1a1a1a",
|
| 113 |
+
"checkbox_background_color_dark": "#1a1a1a",
|
| 114 |
+
"checkbox_background_color_focus": "#1a1a1a",
|
| 115 |
+
"checkbox_background_color_focus_dark": "#1a1a1a",
|
| 116 |
+
"checkbox_background_color_hover": "#1a1a1a",
|
| 117 |
+
"checkbox_background_color_hover_dark": "#1a1a1a",
|
| 118 |
+
"checkbox_background_color_selected": "yellow",
|
| 119 |
+
"checkbox_background_color_selected_dark": "yellow",
|
| 120 |
+
"checkbox_border_color": "#333333",
|
| 121 |
+
"checkbox_border_color_dark": "#333333",
|
| 122 |
+
"checkbox_border_color_focus": "yellow",
|
| 123 |
+
"checkbox_border_color_focus_dark": "yellow",
|
| 124 |
+
"checkbox_border_color_hover": "#444444",
|
| 125 |
+
"checkbox_border_color_hover_dark": "#444444",
|
| 126 |
+
"checkbox_border_color_selected": "yellow",
|
| 127 |
+
"checkbox_border_color_selected_dark": "yellow",
|
| 128 |
+
"checkbox_border_radius": "*radius_sm",
|
| 129 |
+
"checkbox_border_width": "1px",
|
| 130 |
+
"checkbox_border_width_dark": "*input_border_width",
|
| 131 |
+
"checkbox_check": "url(\"data:image/svg+xml,%3csvg viewBox='0 0 16 16' fill='black' xmlns='http://www.w3.org/2000/svg'%3e%3cpath d='M12.207 4.793a1 1 0 010 1.414l-5 5a1 1 0 01-1.414 0l-2-2a1 1 0 011.414-1.414L6.5 9.086l4.293-4.293a1 1 0 011.414 0z'/%3e%3c/svg%3e\")",
|
| 132 |
+
"checkbox_label_background_fill": "#333333",
|
| 133 |
+
"checkbox_label_background_fill_dark": "#333333",
|
| 134 |
+
"checkbox_label_background_fill_hover": "#444444",
|
| 135 |
+
"checkbox_label_background_fill_hover_dark": "#444444",
|
| 136 |
+
"checkbox_label_background_fill_selected": "yellow",
|
| 137 |
+
"checkbox_label_background_fill_selected_dark": "yellow",
|
| 138 |
+
"checkbox_label_border_color": "#555555",
|
| 139 |
+
"checkbox_label_border_color_dark": "#555555",
|
| 140 |
+
"checkbox_label_border_color_hover": "#666666",
|
| 141 |
+
"checkbox_label_border_color_hover_dark": "#666666",
|
| 142 |
+
"checkbox_label_border_width": "*input_border_width",
|
| 143 |
+
"checkbox_label_border_width_dark": "*input_border_width",
|
| 144 |
+
"checkbox_label_gap": "*spacing_lg",
|
| 145 |
+
"checkbox_label_padding": "*spacing_md calc(2 * *spacing_md)",
|
| 146 |
+
"checkbox_label_shadow": "*shadow_drop_lg",
|
| 147 |
+
"checkbox_label_text_color": "white",
|
| 148 |
+
"checkbox_label_text_color_dark": "white",
|
| 149 |
+
"checkbox_label_text_color_selected": "black",
|
| 150 |
+
"checkbox_label_text_color_selected_dark": "black",
|
| 151 |
+
"checkbox_label_text_size": "*text_md",
|
| 152 |
+
"checkbox_label_text_weight": "700",
|
| 153 |
+
"checkbox_shadow": "none",
|
| 154 |
+
"color_accent": "yellow",
|
| 155 |
+
"color_accent_soft": "#333300",
|
| 156 |
+
"color_accent_soft_dark": "#333300",
|
| 157 |
+
"container_radius": "*radius_lg",
|
| 158 |
+
"embed_radius": "*radius_lg",
|
| 159 |
+
"error_background_fill": "#330000",
|
| 160 |
+
"error_background_fill_dark": "#330000",
|
| 161 |
+
"error_border_color": "#660000",
|
| 162 |
+
"error_border_color_dark": "#660000",
|
| 163 |
+
"error_border_width": "1px",
|
| 164 |
+
"error_text_color": "#ff6666",
|
| 165 |
+
"error_text_color_dark": "#ff6666",
|
| 166 |
+
"font": "'Arial', 'sans-serif'",
|
| 167 |
+
"font_mono": "'Courier New', 'monospace'",
|
| 168 |
+
"form_gap_width": "0px",
|
| 169 |
+
"input_background_fill": "#1a1a1a",
|
| 170 |
+
"input_background_fill_dark": "#1a1a1a",
|
| 171 |
+
"input_background_fill_focus": "#333333",
|
| 172 |
+
"input_background_fill_focus_dark": "#333333",
|
| 173 |
+
"input_background_fill_hover": "#1a1a1a",
|
| 174 |
+
"input_background_fill_hover_dark": "#1a1a1a",
|
| 175 |
+
"input_border_color": "#333333",
|
| 176 |
+
"input_border_color_dark": "#333333",
|
| 177 |
+
"input_border_color_focus": "yellow",
|
| 178 |
+
"input_border_color_focus_dark": "yellow",
|
| 179 |
+
"input_border_color_hover": "#444444",
|
| 180 |
+
"input_border_color_hover_dark": "#444444",
|
| 181 |
+
"input_border_width": "1px",
|
| 182 |
+
"input_padding": "*spacing_xl",
|
| 183 |
+
"input_placeholder_color": "#666666",
|
| 184 |
+
"input_placeholder_color_dark": "#666666",
|
| 185 |
+
"input_radius": "*radius_lg",
|
| 186 |
+
"input_shadow": "none",
|
| 187 |
+
"input_shadow_focus": "0 0 0 2px rgba(255,255,0,0.2)",
|
| 188 |
+
"input_text_size": "*text_md",
|
| 189 |
+
"input_text_weight": "600",
|
| 190 |
+
"layout_gap": "*spacing_xxl",
|
| 191 |
+
"link_text_color": "yellow",
|
| 192 |
+
"link_text_color_active": "#ffff80",
|
| 193 |
+
"link_text_color_active_dark": "#ffff80",
|
| 194 |
+
"link_text_color_dark": "yellow",
|
| 195 |
+
"link_text_color_hover": "#ffff80",
|
| 196 |
+
"link_text_color_hover_dark": "#ffff80",
|
| 197 |
+
"link_text_color_visited": "yellow",
|
| 198 |
+
"link_text_color_visited_dark": "yellow",
|
| 199 |
+
"loader_color": "yellow",
|
| 200 |
+
"neutral_100": "#1a1a1a",
|
| 201 |
+
"neutral_200": "#333333",
|
| 202 |
+
"neutral_300": "#444444",
|
| 203 |
+
"neutral_400": "#666666",
|
| 204 |
+
"neutral_50": "#0d0d0d",
|
| 205 |
+
"neutral_500": "#808080",
|
| 206 |
+
"neutral_600": "#999999",
|
| 207 |
+
"neutral_700": "#b3b3b3",
|
| 208 |
+
"neutral_800": "#cccccc",
|
| 209 |
+
"neutral_900": "#e6e6e6",
|
| 210 |
+
"neutral_950": "#f2f2f2",
|
| 211 |
+
"panel_background_fill": "#1a1a1a",
|
| 212 |
+
"panel_background_fill_dark": "#1a1a1a",
|
| 213 |
+
"panel_border_color": "#333333",
|
| 214 |
+
"panel_border_color_dark": "#333333",
|
| 215 |
+
"panel_border_width": "1px",
|
| 216 |
+
"primary_100": "#333300",
|
| 217 |
+
"primary_200": "#666600",
|
| 218 |
+
"primary_300": "#999900",
|
| 219 |
+
"primary_400": "#cccc00",
|
| 220 |
+
"primary_50": "#1a1a00",
|
| 221 |
+
"primary_500": "yellow",
|
| 222 |
+
"primary_600": "#ffff33",
|
| 223 |
+
"primary_700": "#ffff66",
|
| 224 |
+
"primary_800": "#ffff99",
|
| 225 |
+
"primary_900": "#ffffcc",
|
| 226 |
+
"primary_950": "#ffffe6",
|
| 227 |
+
"prose_header_text_weight": "700",
|
| 228 |
+
"prose_text_size": "*text_md",
|
| 229 |
+
"prose_text_weight": "600",
|
| 230 |
+
"radio_circle": "url(\"data:image/svg+xml,%3csvg viewBox='0 0 16 16' fill='black' xmlns='http://www.w3.org/2000/svg'%3e%3ccircle cx='8' cy='8' r='3'/%3e%3c/svg%3e\")",
|
| 231 |
+
"radius_lg": "6px",
|
| 232 |
+
"radius_md": "4px",
|
| 233 |
+
"radius_sm": "2px",
|
| 234 |
+
"radius_xl": "8px",
|
| 235 |
+
"radius_xs": "1px",
|
| 236 |
+
"radius_xxl": "12px",
|
| 237 |
+
"radius_xxs": "1px",
|
| 238 |
+
"secondary_100": "#333333",
|
| 239 |
+
"secondary_200": "#444444",
|
| 240 |
+
"secondary_300": "#555555",
|
| 241 |
+
"secondary_400": "#666666",
|
| 242 |
+
"secondary_50": "#1a1a1a",
|
| 243 |
+
"secondary_500": "#777777",
|
| 244 |
+
"secondary_600": "#888888",
|
| 245 |
+
"secondary_700": "#999999",
|
| 246 |
+
"secondary_800": "#aaaaaa",
|
| 247 |
+
"secondary_900": "#bbbbbb",
|
| 248 |
+
"secondary_950": "#cccccc",
|
| 249 |
+
"section_header_text_size": "*text_md",
|
| 250 |
+
"section_header_text_weight": "700",
|
| 251 |
+
"shadow_drop": "0 1px 4px 0 rgba(255,255,0,0.1)",
|
| 252 |
+
"shadow_drop_lg": "0 2px 5px 0 rgba(255,255,0,0.2)",
|
| 253 |
+
"shadow_inset": "rgba(255,255,0,0.1) 0px 2px 4px 0px inset",
|
| 254 |
+
"shadow_spread": "6px",
|
| 255 |
+
"shadow_spread_dark": "1px",
|
| 256 |
+
"slider_color": "yellow",
|
| 257 |
+
"slider_color_dark": "yellow",
|
| 258 |
+
"spacing_lg": "8px",
|
| 259 |
+
"spacing_md": "6px",
|
| 260 |
+
"spacing_sm": "4px",
|
| 261 |
+
"spacing_xl": "10px",
|
| 262 |
+
"spacing_xs": "2px",
|
| 263 |
+
"spacing_xxl": "16px",
|
| 264 |
+
"spacing_xxs": "1px",
|
| 265 |
+
"stat_background_fill": "#333300",
|
| 266 |
+
"stat_background_fill_dark": "#333300",
|
| 267 |
+
"table_border_color": "#333333",
|
| 268 |
+
"table_border_color_dark": "#333333",
|
| 269 |
+
"table_even_background_fill": "#1a1a1a",
|
| 270 |
+
"table_even_background_fill_dark": "#1a1a1a",
|
| 271 |
+
"table_odd_background_fill": "#0d0d0d",
|
| 272 |
+
"table_odd_background_fill_dark": "#0d0d0d",
|
| 273 |
+
"table_radius": "*radius_lg",
|
| 274 |
+
"table_row_focus": "#333300",
|
| 275 |
+
"table_row_focus_dark": "#333300",
|
| 276 |
+
"text_lg": "16px",
|
| 277 |
+
"text_md": "14px",
|
| 278 |
+
"text_sm": "12px",
|
| 279 |
+
"text_xl": "22px",
|
| 280 |
+
"text_xs": "10px",
|
| 281 |
+
"text_xxl": "26px",
|
| 282 |
+
"text_xxs": "9px"
|
| 283 |
+
},
|
| 284 |
+
"version": "1.0.0"
|
| 285 |
+
}
|
voice_examples/config.json
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"belinda": {
|
| 3 |
+
"transcript": "Twas the night before my birthday. Hooray! It's almost here! It may not be a holiday, but it's the best day of the year.",
|
| 4 |
+
"audio_file": "belinda.wav"
|
| 5 |
+
},
|
| 6 |
+
"broom_salesman": {
|
| 7 |
+
"transcript": "I would imagine so. A wand with a dragon heartstring core is capable of dazzling magic. And the bond between you and your wand should only grow stronger. Do not be surprised at your new wand's ability to perceive your intentions - particularly in a moment of need.",
|
| 8 |
+
"audio_file": "broom_salesman.wav"
|
| 9 |
+
},
|
| 10 |
+
"chadwick": {
|
| 11 |
+
"transcript": "Oh dear, who left all this junk lying around? Whoops, there it goes! Mind your pointed little pink head, starfish man.",
|
| 12 |
+
"audio_file": "chadwick.wav"
|
| 13 |
+
},
|
| 14 |
+
"en_man": {
|
| 15 |
+
"transcript": "Maintaining your ability to learn translates into increased marketability, improved career options and higher salaries.",
|
| 16 |
+
"audio_file": "en_man.wav"
|
| 17 |
+
},
|
| 18 |
+
"en_woman": {
|
| 19 |
+
"transcript": "The device would work during the day as well, if you took steps to either block direct sunlight or point it away from the sun.",
|
| 20 |
+
"audio_file": "en_woman.wav"
|
| 21 |
+
},
|
| 22 |
+
"mabel": {
|
| 23 |
+
"transcript": "You do talk an awful lot about weather, did you know that? Sometimes I wonder if you're actually content to be a wizard or if you're secretly harbouring a desire to become a seer of the clouds.",
|
| 24 |
+
"audio_file": "mabel.wav"
|
| 25 |
+
},
|
| 26 |
+
"vex": {
|
| 27 |
+
"transcript": "Uhh, this is going to take forever. Why is everything so far?",
|
| 28 |
+
"audio_file": "vex.wav"
|
| 29 |
+
},
|
| 30 |
+
}
|