brandongraves08 committed · Commit 0322512 · 0 Parent(s)

HeartMuLa Gradio Space deployment
.dockerignore ADDED
@@ -0,0 +1,18 @@
+ .git
+ .github
+ .venv
+ .env
+ __pycache__
+ *.pyc
+ *.pyo
+ *.egg-info
+ .pytest_cache
+ .coverage
+ *.mp3
+ *.wav
+ *.ogg
+ assets/*.mp3
+ assets/*.wav
+ .gitignore
+ README.md
+ LICENSE
.github/copilot-instructions.md ADDED
@@ -0,0 +1,89 @@
+ # GitHub Copilot instructions for heartlib
+
+ ## What this repo is
+ - HeartMuLa music generation stack: converts lyrics + style tags → audio via two-stage pipeline (HeartMuLa LLM → audio tokens, HeartCodec flow-matching codec → waveform).
+ - Supports lyrics transcription via Whisper-based HeartTranscriptor.
+ - Python package with main entry points: `heartlib.HeartMuLaGenPipeline` and `heartlib.HeartTranscriptorPipeline` (see `src/heartlib/__init__.py`).
+ - Examples/reference CLIs in `examples/`; production use: install via `pip install -e .`
+
+ ## Core architecture & data flow
+ **Music generation:** `Inputs(lyrics, tags)` → Tokenizer → HeartMuLa (frame-by-frame token generation) → HeartCodec (flow-matching detokenization) → MP3
+ - **HeartMuLa** (LLaMA3.2 backbone, 3B/300M/7B flavors): generates 8 parallel audio codebook streams + 1 prompt guidance stream (9 total, `_parallel_number=9`)
+ - **HeartCodec**: VQ codec that reconstructs waveforms from codebook frames in overlapping windows, fixed 48 kHz output
+ - **HeartTranscriptor**: Whisper variant fine-tuned for vocal transcription; works on 30-second chunks, batch=16
+
+ ## Repo map
+ - `src/heartlib/pipelines/music_generation.py`: orchestrates tokenization → HeartMuLa inference → HeartCodec detokenize → `torchaudio.save()`
+   - `HeartMuLaGenPipeline.from_pretrained()`: factory with device/dtype/lazy_load config
+   - `_resolve_paths()`: validates checkpoint layout early (hard error if missing)
+   - `_resolve_devices()`: handles scalar/dict device specs; forces `lazy_load=False` for multi-device
+ - `src/heartlib/heartmula/modeling_heartmula.py`: backbone (llama3_2_3B/7B/300M factory functions), token generator with CFG support
+ - `src/heartlib/heartcodec/modeling_heartcodec.py`: VQ codec + flow-matching decoder, detokenizes `(codebooks, time)` frames
+ - `src/heartlib/pipelines/lyrics_transcription.py`: wraps transformers' Whisper; fixed chunk=30s, batch=16
+ - `src/heartlib/heartmula/configuration_heartmula.py`, `src/heartlib/heartcodec/configuration_heartcodec.py`: model configs
+
+ ## Checkpoints & required layout
+ Directory structure after downloads (see README or `hf download` commands):
+ ```
+ ./ckpt/
+   HeartMuLa-oss-3B/            (or -7B, -300M)
+     config.json
+     model-*.safetensors
+     model.safetensors.index.json
+   HeartCodec-oss/
+     config.json
+     model-*.safetensors
+     model.safetensors.index.json
+   HeartTranscriptor-oss/
+     config.json
+     pytorch_model.bin           (or safetensors)
+   tokenizer.json
+   gen_config.json
+ ```
+ - `_resolve_paths(pretrained_path, version)` validates all required files; raises FileNotFoundError if missing
+ - Latest checkpoint: HeartMuLa-RL-oss-3B-20260123 (RL-tuned, recommended for style control)
+
+ ## Generation pipeline behaviors to know
+ - **Inputs:** dict with `lyrics`, `tags` (both strings or file paths); auto-lowercased, tags wrapped with `<tag>...</tag>` if missing
+ - **Tokenization:** uses `tokenizers.Tokenizer` (from `tokenizer.json`); token IDs from `HeartMuLaGenConfig` (text_bos_id=128000, text_eos_id=128001, audio_eos_id=8193)
+ - **CFG (classifier-free guidance):** if `cfg_scale != 1.0`, batch duplicated for unconditional pass (bs becomes 2×); enables style control tradeoff
+ - **Audio generation loop:** runs max `max_audio_length_ms // 80` frames (~12.5 Hz generation rate); stops early if any token ≥ audio_eos_id
+ - **Memory optimization:**
+   - `lazy_load=True` defers model loading, unloads after generation (saves CUDA between uses)
+   - Forced `lazy_load=False` if mula_device ≠ codec_device (different device types can't swap)
+   - Uses `torch.autocast` with specified dtype to reduce memory footprint
+ - **Output:** via `torchaudio.save(save_path, wav, 48000)` at fixed 48 kHz sample rate
+
+ ## Codec specifics
+ - `HeartCodec.detokenize(frames)` expects shape `(codebooks, time)` where codebooks ≤ 8; pads/repeats to uniform length internally
+ - Uses flow-matching inference in overlapping windows (reduces boundary artifacts), then scalar decoder → PCM waveform
+ - Fixed 48 kHz output; non-standard rates must be resampled post-generation
+ - Model config (`config.json`) defines number of codebooks and codec architecture
+
+ ## Transcription pipeline behaviors
+ - `HeartTranscriptorPipeline.from_pretrained(model_path, device, dtype)` wraps `WhisperForConditionalGeneration` from `HeartTranscriptor-oss`
+ - Fixed at 30-second chunks, batch size 16; no dynamic chunking
+ - Note: trained on separated vocals; best results with source-separated inputs (use demucs or similar pre-pipeline)
+ - Supports beam search and temperature kwargs via `__call__()` decoding_kwargs (see `examples/run_lyrics_transcription.py`)
+
+ ## Dev workflows & commands
+ - **Install:** `pip install -e .` (Python ≥3.9, 3.10 recommended; CUDA deps: torch 2.4.1, torchaudio 2.4.1, torchtune 0.4.0, bitsandbytes 0.49.0)
+ - **Generate:** `python examples/run_music_generation.py --model_path ./ckpt --version 3B --lyrics ./assets/lyrics.txt --tags ./assets/tags.txt`
+   - Key flags: `--mula_device cuda --codec_device cuda` (or separate); `--lazy_load true` (single-GPU VRAM relief); `--cfg_scale 1.5` (style strength)
+ - **Transcribe:** `python examples/run_lyrics_transcription.py --model_path ./ckpt --music_path ./assets/output.mp3`
+ - No test suite; validate changes via example scripts
+
+ ## Coding conventions & critical patterns
+ - **Device specs:** `from_pretrained(..., device=X)` accepts `torch.device` (both models→X) or dict `{"mula": dev1, "codec": dev2}` (forces `lazy_load=False`)
+ - **Dtype specs:** mirrors device—scalar dtype or dict with `"mula"`, `"codec"` keys
+ - **Token/text handling:** always lowercase inputs, auto-wrap tags with `<tag>...</tag>`, append BOS/EOS via tokenizer config (callers depend on this)
+ - **Unimplemented:** reference audio path exists but raises `NotImplementedError`; don't add a stub without a full end-to-end implementation
+ - **Generation loop internals:** `tqdm` progress, `torch.autocast` scope—avoid breaking these or model cache setup (`setup_caches()`)
+ - **Memory patterns:** properties `self.mula` / `self.codec` lazy-load on first access if `lazy_load=True`, then can unload via `_unload_models()`
+
+ ## Quick pointers for agents
+ - Extend pipelines, not models, unless changing core LLM/codec logic
+ - Validate paths early (mirror `_resolve_paths` style) for new entry points
+ - Preserve 48 kHz sample rate and codebook count (8) in outputs
+ - When modifying tokenization or BOS/EOS logic, verify examples still run end-to-end
+ - Device/dtype flexibility is intentional—test multi-GPU configs if changing device dispatch logic
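The device/dtype dict convention described above maps onto a call like the following minimal sketch (it mirrors `examples/run_music_generation.py` later in this commit; the `./ckpt` path assumes the checkpoint layout shown earlier):

```python
import torch
from heartlib import HeartMuLaGenPipeline

# Split the two models across GPUs; per the repo docs, a dict device spec forces lazy_load=False.
pipe = HeartMuLaGenPipeline.from_pretrained(
    "./ckpt",
    device={"mula": torch.device("cuda:0"), "codec": torch.device("cuda:1")},
    dtype={"mula": torch.bfloat16, "codec": torch.float32},
    version="3B",
    lazy_load=False,
)

with torch.no_grad():
    pipe(
        {"lyrics": "./assets/lyrics.txt", "tags": "./assets/tags.txt"},
        max_audio_length_ms=240_000,
        save_path="./assets/output.mp3",
        topk=50,
        temperature=1.0,
        cfg_scale=1.5,
    )
```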
.gitignore ADDED
@@ -0,0 +1,18 @@
+ __pycache__/
+ **.pyc
+ checkpoint/
+ .DS_Store
+ **.wav
+ **.mp3
+ **.png
+ **.jpeg
+ **.jpg
+ .vscode/
+ **.egg-info/
+ build/
+ .idea/
+ ckpt/
+ .venv*/
+ .env
+ models/
+ assets/
DEPLOYMENT_GUIDE.md ADDED
@@ -0,0 +1,115 @@
+ # HeartMuLa HF Space Deployment Guide
+
+ ## What's Been Set Up
+
+ ✅ **Gradio Web UI** (`app.py`)
+ - Interactive music generation interface
+ - Real-time parameter adjustment
+ - Audio preview and download
+ - Example prompts included
+
+ ✅ **Docker Environment** (`Dockerfile` + `requirements.txt`)
+ - CUDA 12.1 with GPU support
+ - All dependencies pre-configured
+ - Automatic model downloading on startup
+
+ ✅ **Space Configuration**
+ - `README_SPACE.md` - Space documentation
+ - `.dockerignore` - Optimized Docker builds
+
+ ## Deployment Steps
+
+ ### 1. Push to Your HF Space Repository
+ ```bash
+ cd f:\Projects\heartlib
+ git remote add space https://huggingface.co/spaces/brandongraves08/test
+ git push space main
+ ```
+
+ ### 2. Configure Space Settings (in HF UI)
+ 1. Go to https://huggingface.co/spaces/brandongraves08/test/settings
+ 2. Set **Runtime** to **Docker**
+ 3. Select **GPU** hardware (A100, T4, or A10 recommended)
+ 4. Save settings
+
+ ### 3. Space Will Auto-Deploy
+ The Dockerfile will:
+ - Install all dependencies
+ - Download the HeartMuLa and HeartCodec models
+ - Start the Gradio app on port 7860
+
+ ## Features
+
+ ### Generation Parameters
+ - **Lyrics**: Custom lyrics for the song
+ - **Tags**: Style descriptors (pop, rock, ambient, etc.)
+ - **Duration**: 5-60 seconds
+ - **Temperature**: 0.1-2.0 (creativity level)
+ - **CFG Scale**: 1.0-3.0 (style control strength)
+ - **Top-K**: 10-100 (sampling parameter)
+
+ ### Model Information
+ - **HeartMuLa-RL-oss-3B**: RL-tuned 3B model (recommended)
+ - **HeartCodec-oss**: High-fidelity codec (48 kHz)
+ - **Inference Speed**: ~RTF 1.0
+
+ ## Local Testing (Before Deploying)
+
+ Test locally with a CUDA GPU:
+ ```bash
+ cd f:\Projects\heartlib
+ pip install gradio
+ python app.py
+ ```
+
+ Then open http://localhost:7860 in your browser.
+
+ ## Troubleshooting
+
+ ### Models Not Downloading
+ If the Space fails to download models:
+ 1. Check that an HF token is configured: `huggingface-cli login`
+ 2. Verify model access on HF
+ 3. Check the Space logs for download errors
+
+ ### Out of Memory
+ - Reduce the duration slider
+ - Use a smaller model version (if available)
+ - Enable lazy_load in the pipeline (already done)
+
+ ### Slow Generation
+ - Generation is ~RTF 1.0 (real-time speed)
+ - The first run may be slower due to model loading
+ - CPU is ~10x slower than GPU
+
+ ## File Structure
+
+ ```
+ heartlib/
+ ├── app.py              # Gradio app (main entry point)
+ ├── Dockerfile          # Docker build config
+ ├── requirements.txt    # Python dependencies
+ ├── setup.sh            # Model download script
+ ├── README_SPACE.md     # Space documentation
+ ├── .dockerignore       # Docker build optimization
+ ├── src/heartlib/       # Core library
+ │   ├── pipelines/
+ │   │   ├── music_generation.py
+ │   │   └── lyrics_transcription.py
+ │   ├── heartmula/
+ │   └── heartcodec/
+ └── examples/           # Example scripts
+ ```
+
+ ## Next Steps
+
+ 1. **Commit & Push** to your Space repository
+ 2. **Monitor** the Space build logs
+ 3. **Test** once deployment completes
+ 4. **Share** the Space URL!
+
+ ## Support
+
+ - **Paper**: https://arxiv.org/pdf/2601.10547
+ - **GitHub**: https://github.com/HeartMuLa/heartlib
+ - **Discord**: https://discord.gg/BKXF5FgH
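The "Test" step in the guide can also be scripted once the Space is live. A minimal sketch with `gradio_client`; the parameter order follows `generate_music()` in `app.py`, but the endpoint name exposed by the Blocks app is an assumption:

```python
from gradio_client import Client

# Hypothetical remote smoke test against the deployed Space.
client = Client("brandongraves08/test")
audio_path = client.predict(
    "Love is in the air, feel the magic",  # lyrics
    "pop, upbeat, romantic",               # tags
    30,    # duration (seconds)
    1.0,   # temperature
    50,    # top-k
    1.5,   # CFG scale
    api_name="/generate_music",  # assumed endpoint name
)
print("Generated file:", audio_path)
```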
Dockerfile ADDED
@@ -0,0 +1,30 @@
+ FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04
+
+ WORKDIR /app
+
+ # Install system dependencies
+ RUN apt-get update && apt-get install -y \
+     python3.10 \
+     python3-pip \
+     git \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Set Python 3.10 as default
+ RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.10 1 && \
+     update-alternatives --install /usr/bin/pip pip /usr/bin/pip3 1
+
+ # Copy requirements and install Python dependencies
+ COPY requirements.txt .
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Copy application code
+ COPY . .
+
+ # Install heartlib package
+ RUN pip install -e .
+
+ # Create models directory
+ RUN mkdir -p ./models
+
+ # Run setup (downloads models and starts app)
+ CMD ["bash", "-c", "bash setup.sh && python app.py"]
LICENSE ADDED
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
MANUAL_DEPLOYMENT.md ADDED
@@ -0,0 +1,55 @@
+ # 🚀 Manual Deployment Steps for HF Space
+
+ Since HF requires token authentication, here's how to deploy manually:
+
+ ## Option 1: Use HF CLI (Recommended)
+ 1. Get your HF access token: https://huggingface.co/settings/tokens
+ 2. Create a write-enabled token
+ 3. Run this command with your token:
+
+ ```bash
+ huggingface-cli login --token hf_YOUR_TOKEN_HERE --add-to-git-credential
+ ```
+
+ Then push:
+ ```bash
+ cd F:\Projects\heartlib
+ git remote add space https://huggingface.co/spaces/brandongraves08/test
+ git push space main --force
+ ```
+
+ ## Option 2: Manual Upload via Web UI
+ 1. Go to https://huggingface.co/spaces/brandongraves08/test
+ 2. Click **Files** → **Upload files**
+ 3. Upload these files:
+    - app.py
+    - Dockerfile
+    - requirements.txt
+    - setup.sh
+    - .dockerignore
+    - README_SPACE.md
+
+ ## Option 3: Use Git with SSH Key
+ 1. Set up an SSH key on HF: https://huggingface.co/settings/keys
+ 2. Update the git remote:
+ ```bash
+ git remote set-url space git@huggingface.co:spaces/brandongraves08/test.git
+ git push space main
+ ```
+
+ ## Space Configuration (After Upload)
+ Once the files are on the Space:
+ 1. Go to **Settings** (gear icon)
+ 2. Set **Runtime** → **Docker**
+ 3. Select **GPU Hardware** (A100/T4/A10G recommended)
+ 4. Click **Save**
+
+ The Space will auto-build and start the Gradio app!
+
+ ## Verify Deployment
+ Check status at: https://brandongraves08-test.hf.space
+
+ ## Troubleshooting
+ - **Models not downloading**: Check that the HF token has read access
+ - **Build fails**: Check the Space logs for error details
+ - **Out of memory**: Reduce the model size or use a larger GPU
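The manual upload in Option 2 can also be done programmatically with `huggingface_hub`; a minimal sketch, assuming a write token is already configured (e.g. via `huggingface-cli login`) and that the listed exclusions are the ones you want:

```python
from huggingface_hub import HfApi

api = HfApi()
# Push the working tree to the Space repo; large local artifacts are skipped explicitly.
api.upload_folder(
    folder_path=".",
    repo_id="brandongraves08/test",
    repo_type="space",
    ignore_patterns=["ckpt/*", "models/*", "assets/*", ".git/*"],
)
```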
README.md ADDED
@@ -0,0 +1,255 @@
+ <p align="center">
+   <picture>
+     <source srcset="./assets/logo.png" media="(prefers-color-scheme: dark)">
+     <img src="./assets/logo.png" width="30%">
+   </picture>
+
+ </p>
+
+ <p align="center">
+   <a href="https://heartmula.github.io/">Demo 🎶</a> &nbsp;|&nbsp; 📑 <a href="https://arxiv.org/pdf/2601.10547">Paper</a>
+   <br>
+   <a href="https://huggingface.co/HeartMuLa/HeartMuLa-oss-3B">HeartMuLa-oss-3B 🤗</a> &nbsp;|&nbsp; <a href="https://modelscope.cn/models/HeartMuLa/HeartMuLa-oss-3B">HeartMuLa-oss-3B <picture>
+     <source srcset="./assets/badge.svg" media="(prefers-color-scheme: dark)">
+     <img src="./assets/badge.svg" width="20px">
+   </picture></a>
+   <br>
+   <a href="https://huggingface.co/HeartMuLa/HeartMuLa-RL-oss-3B-20260123"> HeartMuLa-RL-oss-3B-20260123 🤗</a> &nbsp;|&nbsp; <a href="https://modelscope.cn/models/HeartMuLa/HeartMuLa-RL-oss-3B-20260123">HeartMuLa-RL-oss-3B-20260123 <picture>
+     <source srcset="./assets/badge.svg" media="(prefers-color-scheme: dark)">
+     <img src="./assets/badge.svg" width="20px">
+   </picture></a>
+
+ </p>
+
+ ---
+ # HeartMuLa: A Family of Open Sourced Music Foundation Models
+
+ HeartMuLa is a family of open-sourced music foundation models including:
+ 1. HeartMuLa: a music language model that generates music conditioned on lyrics and tags, with multilingual support including but not limited to English, Chinese, Japanese, Korean, and Spanish;
+ 2. HeartCodec: a 12.5 Hz music codec with high reconstruction fidelity;
+ 3. HeartTranscriptor: a Whisper-based model specifically tuned for lyrics transcription; check [this page](./examples/README.md) for its usage;
+ 4. HeartCLAP: an audio–text alignment model that establishes a unified embedding space for music descriptions and cross-modal retrieval.
+ ---
+
+
+ The figure below shows the experimental results of our oss-3B version compared with other baselines.
+ <p align="center">
+   <picture>
+     <source srcset="./assets/exp.png" media="(prefers-color-scheme: dark)">
+     <img src="./assets/exp.png" width="90%">
+   </picture>
+
+ </p>
+
+ ---
+
+ ## 🔥 Highlight
+
+ Our latest internal version of HeartMuLa-7B achieves **comparable performance with Suno** in terms of musicality, fidelity, and controllability. If you are interested, feel free to reach out to us at heartmula.ai@gmail.com.
+
+ ## 📰 News
+ Join us on Discord! [<img alt="join discord" src="https://img.shields.io/discord/842440537755353128?color=%237289da&logo=discord"/>](https://discord.gg/BKXF5FgH)
+
+ - 🚀 **23 Jan. 2026**
+
+   By leveraging Reinforcement Learning, we have continuously refined our model and are proud to officially release **HeartMuLa-RL-oss-3B-20260123**. This version is designed to achieve more precise control over styles and tags. Simultaneously, we are launching **HeartCodec-oss-20260123**, which optimizes audio decoding quality.
+
+ - 🫶 **20 Jan. 2026**
+
+   [Benji](https://github.com/benjiyaya) has created a wonderful [ComfyUI custom node](https://github.com/benjiyaya/HeartMuLa_ComfyUI) for HeartMuLa. Thanks, Benji!
+ - ⚖️ **20 Jan. 2026**
+
+   License update: we have updated the license of this repo and all related model weights to **Apache 2.0**.
+ - 🚀 **14 Jan. 2026**
+   The official release of **HeartTranscriptor-oss** and the first **HeartMuLa-oss-3B** version, along with our **HeartCodec-oss**.
+
+ ---
+ ## 🧭 TODOs
+
+ - ⏳ Release scripts for inference acceleration and streaming inference. The current inference speed is around RTF $\approx 1.0$.
+ - ⏳ Support **reference audio conditioning**, **fine-grained controllable music generation**, and **hot song generation**.
+ - ⏳ Release the **HeartMuLa-oss-7B** version.
+ - ✅ Release inference code and pretrained checkpoints of
+   **HeartCodec-oss, HeartMuLa-oss-3B, and HeartTranscriptor-oss**.
+
+ ---
+
+ ## 🛠️ Local Deployment
+
+ ### ⚙️ Environment Setup
+
+ We recommend using `python=3.10` for local deployment.
+
+ Clone this repo and install locally.
+
+ ```
+ git clone https://github.com/HeartMuLa/heartlib.git
+ cd heartlib
+ pip install -e .
+ ```
+
+ Download our pretrained checkpoints from Hugging Face or ModelScope using the following commands:
+
+ ```
+ # if you are using huggingface
+ hf download --local-dir './ckpt' 'HeartMuLa/HeartMuLaGen'
+
+ ## To use the version released on 20260123 (recommended)
+ hf download --local-dir './ckpt/HeartMuLa-oss-3B' 'HeartMuLa/HeartMuLa-RL-oss-3B-20260123'
+ hf download --local-dir './ckpt/HeartCodec-oss' HeartMuLa/HeartCodec-oss-20260123
+
+ ## To use the oss-3B version
+ hf download --local-dir './ckpt/HeartMuLa-oss-3B' 'HeartMuLa/HeartMuLa-oss-3B'
+ hf download --local-dir './ckpt/HeartCodec-oss' 'HeartMuLa/HeartCodec-oss'
+
+ # if you are using modelscope
+ modelscope download --model 'HeartMuLa/HeartMuLaGen' --local_dir './ckpt'
+
+ ## To use the version released on 20260123 (recommended)
+ modelscope download --model 'HeartMuLa/HeartMuLa-RL-oss-3B-20260123' --local_dir './ckpt/HeartMuLa-oss-3B'
+ modelscope download --model 'HeartMuLa/HeartCodec-oss-20260123' --local_dir './ckpt/HeartCodec-oss'
+
+ ## To use the oss-3B version
+ modelscope download --model 'HeartMuLa/HeartMuLa-oss-3B' --local_dir './ckpt/HeartMuLa-oss-3B'
+ modelscope download --model 'HeartMuLa/HeartCodec-oss' --local_dir './ckpt/HeartCodec-oss'
+ ```
+
+ After downloading, the `./ckpt` subfolder should be structured like this:
+ ```
+ ./ckpt/
+ ├── HeartCodec-oss/
+ ├── HeartMuLa-oss-3B/
+ ├── gen_config.json
+ └── tokenizer.json
+ ```
+
+
+ ### ▶️ Example Usage
+
+ To generate music, run:
+
+ ```
+ python ./examples/run_music_generation.py --model_path=./ckpt --version="3B"
+ ```
+
+ By default this command will generate a piece of music conditioned on the lyrics and tags provided in the `./assets` folder. The output music will be saved at `./assets/output.mp3`.
+
+ #### FAQs
+
+ 1. How to specify lyrics and tags?
+
+    The model will load lyrics from the txt file `--lyrics` points to (by default `./assets/lyrics.txt`). If you would like to use your own lyrics, just modify the content in `./assets/lyrics.txt`. If you would like to save your lyrics to another path, e.g. `my_awesome_lyrics.txt`, remember to pass `--lyrics my_awesome_lyrics.txt`.
+
+    Tags work basically the same way.
+
+ 2. CUDA out of memory?
+
+    If you have multiple GPUs (e.g. two 4090s), we recommend placing the params of HeartMuLa and HeartCodec on different devices. You can do this with `--mula_device cuda:0 --codec_device cuda:1`.
+
+    If you are running on a single GPU, use `--lazy_load true` so that modules are loaded on demand and deleted once inference completes, saving GPU memory.
+
+ All parameters:
+
+ - `--model_path` (required): Path to the pretrained model checkpoint
+ - `--lyrics`: Path to lyrics file (default: `./assets/lyrics.txt`)
+ - `--tags`: Path to tags file (default: `./assets/tags.txt`)
+ - `--save_path`: Output audio file path (default: `./assets/output.mp3`)
+ - `--max_audio_length_ms`: Maximum audio length in milliseconds (default: 240000)
+ - `--topk`: Top-k sampling parameter for generation (default: 50)
+ - `--temperature`: Sampling temperature for generation (default: 1.0)
+ - `--cfg_scale`: Classifier-free guidance scale (default: 1.5)
+ - `--version`: The version of HeartMuLa, chosen from [`3B`, `7B`] (default: `3B`). # The `7B` version is not released yet.
+ - `--mula_device/--codec_device`: The devices where params will be placed. Both are set to `cuda` by default. You can use `--mula_device cuda:0 --codec_device cuda:1` to explicitly place different modules on different devices.
+ - `--mula_dtype/--codec_dtype`: Inference dtype. By default `bf16` for HeartMuLa and `fp32` for HeartCodec. Setting `bf16` for HeartCodec may degrade audio quality.
+ - `--lazy_load`: Whether or not to use lazy loading (default: false). If turned on, modules will be loaded on demand to save GPU memory.
+ Recommended format of lyrics and tags:
+ ```txt
+ [Intro]
+
+ [Verse]
+ The sun creeps in across the floor
+ I hear the traffic outside the door
+ The coffee pot begins to hiss
+ It is another morning just like this
+
+ [Prechorus]
+ The world keeps spinning round and round
+ Feet are planted on the ground
+ I find my rhythm in the sound
+
+ [Chorus]
+ Every day the light returns
+ Every day the fire burns
+ We keep on walking down this street
+ Moving to the same steady beat
+ It is the ordinary magic that we meet
+
+ [Verse]
+ The hours tick deeply into noon
+ Chasing shadows, chasing the moon
+ Work is done and the lights go low
+ Watching the city start to glow
+
+ [Bridge]
+ It is not always easy, not always bright
+ Sometimes we wrestle with the night
+ But we make it to the morning light
+
+ [Chorus]
+ Every day the light returns
+ Every day the fire burns
+ We keep on walking down this street
+ Moving to the same steady beat
+
+ [Outro]
+ Just another day
+ Every single day
+ ```
+
+ Regarding tags, check this [issue](https://github.com/HeartMuLa/heartlib/issues/17) for reference.
+ Tags are comma-separated without spaces, as illustrated below:
+ ```txt
+ piano,happy,wedding,synthesizer,romantic
+ ```
+
+ ---
+
+
+ ## ⚖️ License
+
+ This repository is licensed under the Apache 2.0 License.
+
+ ---
+
+ ## 📚 Citation
+
+ ```
+ @misc{yang2026heartmulafamilyopensourced,
+       title={HeartMuLa: A Family of Open Sourced Music Foundation Models},
+       author={Dongchao Yang and Yuxin Xie and Yuguo Yin and Zheyu Wang and Xiaoyu Yi and Gongxi Zhu and Xiaolong Weng and Zihan Xiong and Yingzhe Ma and Dading Cong and Jingliang Liu and Zihang Huang and Jinghan Ru and Rongjie Huang and Haoran Wan and Peixu Wang and Kuoxi Yu and Helin Wang and Liming Liang and Xianwei Zhuang and Yuanyuan Wang and Haohan Guo and Junjie Cao and Zeqian Ju and Songxiang Liu and Yuewen Cao and Heming Weng and Yuexian Zou},
+       year={2026},
+       eprint={2601.10547},
+       archivePrefix={arXiv},
+       primaryClass={cs.SD},
+       url={https://arxiv.org/abs/2601.10547},
+ }
+ ```
+
+ ## 📬 Contact
+ If you are interested in HeartMuLa, feel free to reach us at heartmula.ai@gmail.com.
+
+ You are welcome to join us through [Discord](https://discord.gg/BKXF5FgH) or our WeChat group.
+
+ Scan the QR code on the left to join our WeChat group. If it expires, feel free to raise an issue to remind us to update it.
+
+ If the number of group members exceeds 200, joining the group by directly scanning the QR code is restricted by WeChat. In this case, scan our team member's QR code on the right and send a request with the message **HeartMuLa Group Invite**. We will invite you into the group manually.
+ <p align="center">
+   <picture>
+     <source srcset="./assets/group_wx.jpeg" media="(prefers-color-scheme: dark)">
+     <img src="./assets/group_wx.jpeg" width="40%">
+   </picture>
+   <picture>
+     <source srcset="./assets/lead_wx.jpeg" media="(prefers-color-scheme: dark)">
+     <img src="./assets/lead_wx.jpeg" width="40%">
+   </picture>
+ </p>
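As a sanity check on `--max_audio_length_ms`, the generation budget can be estimated from the ~12.5 Hz frame rate documented above (80 ms of audio per generated frame, 48 kHz output); a small sketch of that arithmetic:

```python
# Rough relationship between the CLI flag and the generation loop length
# (80 ms per frame and the 48 kHz output rate are taken from the repo docs above).
max_audio_length_ms = 240_000          # default --max_audio_length_ms
frames = max_audio_length_ms // 80     # 3000 frames at ~12.5 Hz
seconds = max_audio_length_ms / 1000   # up to 240 s of audio
samples = int(seconds * 48_000)        # ~11.5 M output samples at 48 kHz
print(frames, seconds, samples)
```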
README_SPACE.md ADDED
@@ -0,0 +1,54 @@
+ ---
+ title: HeartMuLa Music Generation
+ emoji: 🎵
+ colorFrom: purple
+ colorTo: pink
+ sdk: docker
+ sdk_version: latest
+ app_file: app.py
+ pinned: true
+ duplicated_from: HeartMuLa/HeartMuLa-oss
+ ---
+
+ # HeartMuLa Music Generation Space
+
+ Generate music from lyrics and style tags using the HeartMuLa family of open-source music foundation models.
+
+ ## Features
+
+ - **Music Generation**: Convert lyrics + style tags → audio via a two-stage pipeline
+ - **HeartMuLa LLM**: Frame-by-frame audio token generation with style control
+ - **HeartCodec**: High-fidelity flow-matching codec (48 kHz output)
+ - **Multiple Model Sizes**: 3B, 7B, and 300M versions available
+
+ ## Setup
+
+ The Space will automatically download and set up the required models on first run.
+
+ ## Usage
+
+ 1. Enter your **lyrics** in the text field
+ 2. Add **style tags** (e.g., "pop, upbeat, energetic")
+ 3. Adjust generation parameters:
+    - **Duration**: Length of generated music (5-60 seconds)
+    - **Temperature**: Creativity level (0.1-2.0)
+    - **CFG Scale**: Style control strength (1.0-3.0)
+    - **Top-K**: Sampling parameter (10-100)
+ 4. Click **Generate Music** to create your track
+
+ ## Model Information
+
+ - **HeartMuLa-RL-oss-3B-20260123**: RL-tuned version with improved style control (recommended)
+ - **HeartCodec-oss-20260123**: Optimized audio decoding quality
+
+ ## Performance
+
+ - RTF ≈ 1.0 (real-time inference speed)
+ - 48 kHz sample rate output
+ - Supports multiple languages
+
+ ## References
+
+ - [Paper](https://arxiv.org/pdf/2601.10547)
+ - [GitHub](https://github.com/HeartMuLa/heartlib)
+ - [Discord](https://discord.gg/BKXF5FgH)
app.py ADDED
@@ -0,0 +1,228 @@
+ #!/usr/bin/env python3
+ """
+ HeartMuLa Music Generation Gradio App for Hugging Face Spaces
+ """
+
+ import gradio as gr
+ import torch
+ import os
+ from pathlib import Path
+ from heartlib import HeartMuLaGenPipeline
+ import logging
+
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # Configuration
+ MODEL_PATH = "./models"
+ DEFAULT_VERSION = "3B"
+
+ # Global pipeline instance
+ pipeline = None
+
+
+ def load_pipeline():
+     """Load the HeartMuLa pipeline"""
+     global pipeline
+
+     if pipeline is not None:
+         return pipeline
+
+     logger.info("Loading HeartMuLa pipeline...")
+
+     # Determine device
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     dtype = torch.float16 if device == "cuda" else torch.float32
+
+     logger.info(f"Using device: {device}, dtype: {dtype}")
+
+     try:
+         pipeline = HeartMuLaGenPipeline.from_pretrained(
+             MODEL_PATH,
+             device={
+                 "mula": torch.device(device),
+                 "codec": torch.device(device),
+             },
+             dtype={
+                 "mula": dtype,
+                 "codec": dtype,
+             },
+             version=DEFAULT_VERSION,
+             lazy_load=True,
+         )
+         logger.info("Pipeline loaded successfully!")
+         return pipeline
+     except Exception as e:
+         logger.error(f"Failed to load pipeline: {e}")
+         raise gr.Error(f"Failed to load model: {e}")
+
+
+ def generate_music(
+     lyrics: str,
+     tags: str,
+     max_duration: int = 30,
+     temperature: float = 1.0,
+     top_k: int = 50,
+     cfg_scale: float = 1.5,
+ ):
+     """Generate music from lyrics and tags"""
+
+     if not lyrics.strip():
+         raise gr.Error("Please enter lyrics")
+     if not tags.strip():
+         raise gr.Error("Please enter tags")
+
+     try:
+         logger.info(f"Generating music with lyrics: {lyrics[:50]}... and tags: {tags}")
+
+         # Load pipeline
+         pipe = load_pipeline()
+
+         # Convert duration to milliseconds
+         max_audio_length_ms = max_duration * 1000
+
+         # Generate music
+         output_path = "/tmp/generated_music.mp3"
+         os.makedirs("/tmp", exist_ok=True)
+
+         with torch.no_grad():
+             pipe(
+                 {
+                     "lyrics": lyrics,
+                     "tags": tags,
+                 },
+                 max_audio_length_ms=max_audio_length_ms,
+                 save_path=output_path,
+                 topk=top_k,
+                 temperature=temperature,
+                 cfg_scale=cfg_scale,
+             )
+
+         logger.info(f"Music generated successfully: {output_path}")
+         return output_path
+
+     except Exception as e:
+         logger.error(f"Error during generation: {e}")
+         raise gr.Error(f"Generation failed: {e}")
+
+
+ def main():
+     """Create Gradio interface"""
+
+     # Check if models exist
+     if not Path(MODEL_PATH).exists():
+         logger.warning(f"Models directory not found at {MODEL_PATH}")
+         logger.info("You need to download the models first:")
+         logger.info("hf download --local-dir './models/HeartMuLa-oss-3B' HeartMuLa/HeartMuLa-RL-oss-3B-20260123")
+         logger.info("hf download --local-dir './models/HeartCodec-oss' HeartMuLa/HeartCodec-oss-20260123")
+
+     with gr.Blocks(title="HeartMuLa Music Generation") as demo:
+         gr.Markdown("""
+ # 🎵 HeartMuLa Music Generation
+
+ Generate music from lyrics and style tags using HeartMuLa, a family of open-source music foundation models.
+ """)
+
+         with gr.Row():
+             with gr.Column():
+                 gr.Markdown("### Input")
+
+                 lyrics_input = gr.Textbox(
+                     label="Lyrics",
+                     placeholder="Enter your lyrics here...",
+                     lines=4,
+                 )
+
+                 tags_input = gr.Textbox(
+                     label="Style Tags",
+                     placeholder="e.g., pop, upbeat, energetic",
+                     lines=2,
+                 )
+
+                 with gr.Row():
+                     duration_slider = gr.Slider(
+                         minimum=5,
+                         maximum=60,
+                         value=30,
+                         step=5,
+                         label="Duration (seconds)",
+                     )
+
+                 with gr.Row():
+                     temp_slider = gr.Slider(
+                         minimum=0.1,
+                         maximum=2.0,
+                         value=1.0,
+                         step=0.1,
+                         label="Temperature",
+                     )
+
+                     cfg_slider = gr.Slider(
+                         minimum=1.0,
+                         maximum=3.0,
+                         value=1.5,
+                         step=0.1,
+                         label="CFG Scale (style strength)",
+                     )
+
+                 with gr.Row():
+                     topk_slider = gr.Slider(
+                         minimum=10,
+                         maximum=100,
+                         value=50,
+                         step=5,
+                         label="Top-K",
+                     )
+
+                 generate_btn = gr.Button("🎵 Generate Music", variant="primary", size="lg")
+
+             with gr.Column():
+                 gr.Markdown("### Output")
+                 audio_output = gr.Audio(label="Generated Music", type="filepath")
+
+                 gr.Markdown("""
+ ### 📝 Tips
+ - **Lyrics**: Describe the vocals and melody
+ - **Tags**: Use style descriptors like "pop", "rock", "ambient", "upbeat", etc.
+ - **CFG Scale**: Higher values = stronger style control (1.5 is recommended)
+ - **Temperature**: Higher = more creative, lower = more consistent
+ """)
+
+         # Connect button to generation function
+         generate_btn.click(
+             fn=generate_music,
+             inputs=[
+                 lyrics_input,
+                 tags_input,
+                 duration_slider,
+                 temp_slider,
+                 topk_slider,
+                 cfg_slider,
+             ],
+             outputs=audio_output,
+         )
+
+         # Example inputs
+         gr.Examples(
+             examples=[
+                 [
+                     "Love is in the air, feel the magic",
+                     "pop, upbeat, romantic",
+                 ],
+                 [
+                     "Dark skies falling down, lonely tonight",
+                     "rock, emotional, melancholic",
+                 ],
+             ],
+             inputs=[lyrics_input, tags_input],
+             outputs=audio_output,
+             fn=generate_music,
+             cache_examples=False,
+         )
+
+     return demo
+
+
+ if __name__ == "__main__":
+     demo = main()
+     demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
examples/README.md ADDED
@@ -0,0 +1,15 @@
+ # 🎤 Lyrics Transcription
+
+ Download the checkpoint using either of the following commands:
+ ```
+ hf download --local-dir './ckpt/HeartTranscriptor-oss' 'HeartMuLa/HeartTranscriptor-oss'
+ modelscope download --model 'HeartMuLa/HeartTranscriptor-oss' --local_dir './ckpt/HeartTranscriptor-oss'
+ ```
+
+ ```
+ python ./examples/run_lyrics_transcription.py --model_path=./ckpt
+ ```
+
+ By default this command will load the generated music file at `./assets/output.mp3` and print the transcribed lyrics. Use `--music_path` to specify the path to the music file.
+
+ Note that HeartTranscriptor is trained on separated vocal tracks. In this example we transcribe an unseparated music track directly, purely for simplicity of illustration. We recommend using source separation tools such as demucs to separate the tracks before transcribing lyrics, to achieve better results.
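For the source-separation step recommended above, a minimal sketch that runs demucs first and then transcribes the vocal stem. The demucs CLI flags and its `separated/htdemucs/...` output layout are assumptions about that external tool, not part of heartlib; the transcription call mirrors `run_lyrics_transcription.py`:

```python
import subprocess
import torch
from heartlib import HeartTranscriptorPipeline

# Extract the vocal stem with demucs before transcription (demucs must be installed separately).
subprocess.run(
    ["demucs", "--two-stems", "vocals", "-o", "./separated", "./assets/output.mp3"],
    check=True,
)
vocals = "./separated/htdemucs/output/vocals.wav"  # hypothetical demucs output path

pipe = HeartTranscriptorPipeline.from_pretrained(
    "./ckpt", device=torch.device("cuda"), dtype=torch.float16
)
with torch.no_grad():
    print(pipe(vocals, max_new_tokens=256, num_beams=2, task="transcribe"))
```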
examples/run_lyrics_transcription.py ADDED
@@ -0,0 +1,35 @@
+ from heartlib import HeartTranscriptorPipeline
+ import argparse
+ import torch
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--model_path", type=str, required=True)
+     parser.add_argument("--music_path", type=str, default="./assets/output.mp3")
+
+     return parser.parse_args()
+
+
+ if __name__ == "__main__":
+     args = parse_args()
+     pipe = HeartTranscriptorPipeline.from_pretrained(
+         args.model_path,
+         device=torch.device("cuda"),
+         dtype=torch.float16,
+     )
+     with torch.no_grad():
+         result = pipe(
+             args.music_path,
+             **{
+                 "max_new_tokens": 256,
+                 "num_beams": 2,
+                 "task": "transcribe",
+                 "condition_on_prev_tokens": False,
+                 "compression_ratio_threshold": 1.8,
+                 "temperature": (0.0, 0.1, 0.2, 0.4),
+                 "logprob_threshold": -1.0,
+                 "no_speech_threshold": 0.4,
+             },
+         )
+     print(result)
examples/run_music_generation.py ADDED
@@ -0,0 +1,81 @@
+ from heartlib import HeartMuLaGenPipeline
+ import argparse
+ import torch
+
+
+ def str2bool(value):
+     if isinstance(value, bool):
+         return value
+     if value.lower() in ("yes", "y", "true", "t", "1"):
+         return True
+     elif value.lower() in ("no", "n", "false", "f", "0"):
+         return False
+     else:
+         raise argparse.ArgumentTypeError(f"Boolean value expected. Got: {value}")
+
+
+ def str2dtype(value):
+     value = value.lower()
+     if value == "float32" or value == "fp32":
+         return torch.float32
+     elif value == "float16" or value == "fp16":
+         return torch.float16
+     elif value == "bfloat16" or value == "bf16":
+         return torch.bfloat16
+     else:
+         raise argparse.ArgumentTypeError(f"Dtype not recognized: {value}")
+
+
+ def str2device(value):
+     value = value.lower()
+     return torch.device(value)
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--model_path", type=str, required=True)
+     parser.add_argument("--version", type=str, default="3B")
+     parser.add_argument("--lyrics", type=str, default="./assets/lyrics.txt")
+     parser.add_argument("--tags", type=str, default="./assets/tags.txt")
+     parser.add_argument("--save_path", type=str, default="./assets/output.mp3")
+
+     parser.add_argument("--max_audio_length_ms", type=int, default=240_000)
+     parser.add_argument("--topk", type=int, default=50)
+     parser.add_argument("--temperature", type=float, default=1.0)
+     parser.add_argument("--cfg_scale", type=float, default=1.5)
+     parser.add_argument("--mula_device", type=str2device, default="cuda")
+     parser.add_argument("--codec_device", type=str2device, default="cuda")
+     parser.add_argument("--mula_dtype", type=str2dtype, default="bfloat16")
+     parser.add_argument("--codec_dtype", type=str2dtype, default="float32")
+     parser.add_argument("--lazy_load", type=str2bool, default=False)
+     return parser.parse_args()
+
+
+ if __name__ == "__main__":
+     args = parse_args()
+     pipe = HeartMuLaGenPipeline.from_pretrained(
+         args.model_path,
+         device={
+             "mula": torch.device(args.mula_device),
+             "codec": torch.device(args.codec_device),
+         },
+         dtype={
+             "mula": args.mula_dtype,
+             "codec": args.codec_dtype,
+         },
+         version=args.version,
+         lazy_load=args.lazy_load,
+     )
+     with torch.no_grad():
+         pipe(
+             {
+                 "lyrics": args.lyrics,
+                 "tags": args.tags,
+             },
+             max_audio_length_ms=args.max_audio_length_ms,
+             save_path=args.save_path,
+             topk=args.topk,
+             temperature=args.temperature,
+             cfg_scale=args.cfg_scale,
+         )
+     print(f"Generated music saved to {args.save_path}")
pyproject.toml ADDED
@@ -0,0 +1,46 @@
+ [build-system]
+ requires = ["setuptools>=61", "wheel"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "heartlib"
+ version = "0.1.0"
+ description = "A Python Library."
+ readme = "README.md"
+ requires-python = ">=3.9"
+ license = {text = "Apache-2.0"}
+ authors = [
+     {name = "HeartMuLa Team", email = "heartmula.ai@gmail.com"}
+ ]
+ dependencies = [
+     "numpy==2.0.2",
+     "torch==2.4.1",
+     "torchaudio==2.4.1",
+     "torchtune==0.4.0",
+     "torchao==0.9.0",
+     "torchvision==0.19.1",
+     "tqdm==4.67.1",
+     "traitlets==5.7.1",
+     "traittypes==0.2.3",
+     "transformers==4.57.0",
+     "tokenizers==0.22.1",
+     "ipykernel==6.17.1",
+     "einops==0.8.1",
+     "accelerate==1.12.0",
+     "bitsandbytes==0.49.0",
+     "vector-quantize-pytorch==1.27.15",
+     "modelscope==1.33.0",
+     "soundfile"
+ ]
+ urls = { "homepage" = "https://heartmula.github.io/" }
+ classifiers = [
+     "Programming Language :: Python :: 3",
+     "Operating System :: OS Independent"
+ ]
+
+ [tool.setuptools]
+ package-dir = {"" = "src"}
+
+ [tool.setuptools.packages.find]
+ where = ["src"]
+
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ torch==2.4.1
+ torchaudio==2.4.1
+ transformers==4.40.0
+ safetensors==0.4.1
+ bitsandbytes==0.49.0
+ torchtune==0.4.0
+ tokenizers==0.15.0
+ tqdm==4.66.1
+ gradio==4.36.1
+ pydantic==2.5.0
+ numpy==1.24.3
setup.sh ADDED
@@ -0,0 +1,28 @@
+ #!/bin/bash
+ # Setup script for HF Space deployment
+
+ echo "Installing HeartMuLa package..."
+ pip install -e .
+
+ echo "Downloading HeartMuLa models..."
+ mkdir -p ./models
+
+ # Download HeartMuLa 3B model (RL-tuned version - recommended)
+ echo "Downloading HeartMuLa-RL-oss-3B-20260123..."
+ huggingface-cli download \
+     --local-dir "./models/HeartMuLa-oss-3B" \
+     HeartMuLa/HeartMuLa-RL-oss-3B-20260123
+
+ # Download HeartCodec
+ echo "Downloading HeartCodec-oss-20260123..."
+ huggingface-cli download \
+     --local-dir "./models/HeartCodec-oss" \
+     HeartMuLa/HeartCodec-oss-20260123
+
+ # Copy tokenizer and config
+ echo "Copying tokenizer and config..."
+ huggingface-cli download \
+     --local-dir "./models" \
+     HeartMuLa/HeartMuLaGen
+
+ echo "Setup complete! Starting Gradio app..."
src/heartlib/__init__.py ADDED
@@ -0,0 +1,7 @@
+ from .pipelines.music_generation import HeartMuLaGenPipeline
+ from .pipelines.lyrics_transcription import HeartTranscriptorPipeline
+
+ __all__ = [
+     "HeartMuLaGenPipeline",
+     "HeartTranscriptorPipeline"
+ ]
src/heartlib/heartcodec/configuration_heartcodec.py ADDED
@@ -0,0 +1,73 @@
+ from transformers.configuration_utils import PretrainedConfig
+ from typing import List
+
+
+ class HeartCodecConfig(PretrainedConfig):
+     model_type = "heartcodec"
+
+     def __init__(
+         self,
+         # config for rvq
+         dim: int = 512,
+         codebook_size: int = 8192,
+         decay: float = 0.9,
+         commitment_weight: float = 1.0,
+         threshold_ema_dead_code: int = 2,
+         use_cosine_sim: bool = False,
+         codebook_dim: int = 32,
+         num_quantizers: int = 8,
+         # config for diffusion transformer
+         attention_head_dim: int = 64,
+         in_channels: int = 1024,
+         norm_type: str = "ada_norm_single",
+         num_attention_heads: int = 24,
+         num_layers: int = 24,
+         num_layers_2: int = 6,
+         out_channels: int = 256,
+         # config for sq codec
+         num_bands: int = 1,
+         sample_rate: int = 48000,
+         causal: bool = True,
+         num_samples: int = 2,
+         downsample_factors: List[int] = [3, 4, 4, 4, 5],
+         downsample_kernel_sizes: List[int] = [6, 8, 8, 8, 10],
+         upsample_factors: List[int] = [5, 4, 4, 4, 3],
+         upsample_kernel_sizes: List[int] = [10, 8, 8, 8, 6],
+         latent_hidden_dim: int = 128,
+         default_kernel_size: int = 7,
+         delay_kernel_size: int = 5,
+         init_channel: int = 64,
+         res_kernel_size: int = 7,
+         **kwargs
+     ):
+         super().__init__(**kwargs)
+         self.dim = dim
+         self.codebook_size = codebook_size
+         self.decay = decay
+         self.commitment_weight = commitment_weight
+         self.threshold_ema_dead_code = threshold_ema_dead_code
+         self.use_cosine_sim = use_cosine_sim
+         self.codebook_dim = codebook_dim
+         self.num_quantizers = num_quantizers
+
+         self.attention_head_dim = attention_head_dim
+         self.in_channels = in_channels
+         self.norm_type = norm_type
+         self.num_attention_heads = num_attention_heads
+         self.num_layers = num_layers
+         self.num_layers_2 = num_layers_2
+         self.out_channels = out_channels
+
+         self.num_bands = num_bands
+         self.sample_rate = sample_rate
+         self.causal = causal
+         self.num_samples = num_samples
+         self.downsample_factors = downsample_factors
+         self.downsample_kernel_sizes = downsample_kernel_sizes
+         self.upsample_factors = upsample_factors
+         self.upsample_kernel_sizes = upsample_kernel_sizes
+         self.latent_hidden_dim = latent_hidden_dim
+         self.default_kernel_size = default_kernel_size
+         self.delay_kernel_size = delay_kernel_size
+         self.init_channel = init_channel
+         self.res_kernel_size = res_kernel_size
src/heartlib/heartcodec/modeling_heartcodec.py ADDED
@@ -0,0 +1,183 @@
1
+ import torch
2
+ from .models.flow_matching import FlowMatching
3
+ from .models.sq_codec import ScalarModel
4
+ from .configuration_heartcodec import HeartCodecConfig
5
+ from transformers.modeling_utils import PreTrainedModel
6
+ import math
7
+ import numpy as np
8
+
9
+
10
+ class HeartCodec(PreTrainedModel):
11
+ config_class = HeartCodecConfig
12
+
13
+ def __init__(
14
+ self,
15
+ config: HeartCodecConfig,
16
+ ):
17
+ super(HeartCodec, self).__init__(config)
18
+
19
+ self.config = config
20
+
21
+ self.flow_matching = FlowMatching(
22
+ dim=config.dim,
23
+ codebook_size=config.codebook_size,
24
+ decay=config.decay,
25
+ commitment_weight=config.commitment_weight,
26
+ threshold_ema_dead_code=config.threshold_ema_dead_code,
27
+ use_cosine_sim=config.use_cosine_sim,
28
+ codebook_dim=config.codebook_dim,
29
+ num_quantizers=config.num_quantizers,
30
+ attention_head_dim=config.attention_head_dim,
31
+ in_channels=config.in_channels,
32
+ norm_type=config.norm_type,
33
+ num_attention_heads=config.num_attention_heads,
34
+ num_layers=config.num_layers,
35
+ num_layers_2=config.num_layers_2,
36
+ out_channels=config.out_channels,
37
+ )
38
+ self.scalar_model = ScalarModel(
39
+ num_bands=config.num_bands,
40
+ sample_rate=config.sample_rate,
41
+ causal=config.causal,
42
+ num_samples=config.num_samples,
43
+ downsample_factors=config.downsample_factors,
44
+ downsample_kernel_sizes=config.downsample_kernel_sizes,
45
+ upsample_factors=config.upsample_factors,
46
+ upsample_kernel_sizes=config.upsample_kernel_sizes,
47
+ latent_hidden_dim=config.latent_hidden_dim,
48
+ default_kernel_size=config.default_kernel_size,
49
+ delay_kernel_size=config.delay_kernel_size,
50
+ init_channel=config.init_channel,
51
+ res_kernel_size=config.res_kernel_size,
52
+ )
53
+ self.post_init()
54
+
55
+ self.sample_rate = config.sample_rate
56
+
57
+ @torch.inference_mode()
58
+ def detokenize(
59
+ self,
60
+ codes,
61
+ duration=29.76,
62
+ num_steps=10,
63
+ disable_progress=False,
64
+ guidance_scale=1.25,
65
+ ):
66
+ codes = codes.unsqueeze(0).to(self.device)
67
+ first_latent = torch.randn(
68
+ codes.shape[0], int(duration * 25), 256, dtype=self.dtype
69
+ ).to(
70
+ self.device
71
+ ) # B, T, 256
72
+ first_latent_length = 0
73
+ first_latent_codes_length = 0
74
+ min_samples = int(duration * 12.5)
75
+ hop_samples = min_samples // 93 * 80
76
+ ovlp_samples = min_samples - hop_samples
77
+ ovlp_frames = ovlp_samples * 2
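+ # Sketch of the windowing below (inferred from the constants above): codes arrive at
+ # 12.5 frames/s and latents at 25 frames/s, so each window spans `duration` seconds.
+ # Windows advance by hop_samples codes and overlap by ovlp_samples; the overlapping
+ # waveform regions are cross-faded with a linear ramp before trimming to target_len.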
78
+ codes_len = codes.shape[-1] #
79
+ target_len = int(
80
+ (codes_len - first_latent_codes_length) / 12.5 * self.sample_rate
81
+ )
82
+
83
+ # code repeat
84
+ if codes_len < min_samples:
85
+ while codes.shape[-1] < min_samples:
86
+ codes = torch.cat([codes, codes], -1)
87
+ codes = codes[:, :, 0:min_samples]
88
+ codes_len = codes.shape[-1]
89
+ if (codes_len - ovlp_frames) % hop_samples > 0:
90
+ len_codes = (
91
+ math.ceil((codes_len - ovlp_samples) / float(hop_samples)) * hop_samples
92
+ + ovlp_samples
93
+ )
94
+ while codes.shape[-1] < len_codes:
95
+ codes = torch.cat([codes, codes], -1)
96
+ codes = codes[:, :, 0:len_codes]
97
+ latent_length = int(duration * 25)
98
+ latent_list = []
99
+
100
+ for sinx in range(0, codes.shape[-1] - hop_samples + 1, hop_samples):
101
+ codes_input = []
102
+ codes_input.append(codes[:, :, sinx : sinx + min_samples])
103
+ if sinx == 0 or ovlp_frames == 0:
104
+ incontext_length = first_latent_length
105
+ latents = self.flow_matching.inference_codes(
106
+ codes_input,
107
+ first_latent,
108
+ latent_length,
109
+ incontext_length,
110
+ guidance_scale=guidance_scale,
111
+ num_steps=num_steps,
112
+ disable_progress=disable_progress,
113
+ scenario="other_seg",
114
+ )
115
+ latent_list.append(latents)
116
+ else:
117
+ true_latent = latent_list[-1][:, -ovlp_frames:, :]
118
+ len_add_to_latent = latent_length - true_latent.shape[1] #
119
+ incontext_length = true_latent.shape[1]
120
+ true_latent = torch.cat(
121
+ [
122
+ true_latent,
123
+ torch.randn(
124
+ true_latent.shape[0],
125
+ len_add_to_latent,
126
+ true_latent.shape[-1],
127
+ dtype=self.dtype,
128
+ ).to(self.device),
129
+ ],
130
+ 1,
131
+ )
132
+ latents = self.flow_matching.inference_codes(
133
+ codes_input,
134
+ true_latent,
135
+ latent_length,
136
+ incontext_length,
137
+ guidance_scale=guidance_scale,
138
+ num_steps=num_steps,
139
+ disable_progress=disable_progress,
140
+ scenario="other_seg",
141
+ )
142
+ latent_list.append(latents)
143
+
144
+ # latent_list = [l.float() for l in latent_list]
145
+ latent_list[0] = latent_list[0][:, first_latent_length:, :]
146
+ min_samples = int(duration * self.sample_rate)
147
+ hop_samples = min_samples // 93 * 80
148
+ ovlp_samples = min_samples - hop_samples
149
+
150
+ output = None
151
+ for i in range(len(latent_list)):
152
+ latent = latent_list[i]
153
+ bsz, t, f = latent.shape
154
+
155
+ latent = latent.reshape(
156
+ latent.shape[0], latent.shape[1], 2, latent.shape[2] // 2
157
+ ).permute(0, 2, 1, 3)
158
+ latent = latent.reshape(
159
+ latent.shape[0] * 2, latent.shape[2], latent.shape[3]
160
+ )
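+ # The 256-dim latent is split into two 128-dim halves (matching latent_hidden_dim=128)
+ # and folded into the batch dimension, so the scalar codec decodes both halves, assumed
+ # to be the two output channels, in a single pass.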
161
+ cur_output = (
162
+ self.scalar_model.decode(latent.transpose(1, 2)).squeeze(0).squeeze(1)
163
+ ) # 1 512 256
164
+
165
+ cur_output = cur_output[:, 0:min_samples].detach().cpu() # B, T
166
+ if cur_output.dim() == 3:
167
+ cur_output = cur_output[0]
168
+
169
+ if output is None:
170
+ output = cur_output
171
+ else:
172
+ if ovlp_samples == 0:
173
+ output = torch.cat([output, cur_output], -1)
174
+ else:
175
+ ov_win = torch.from_numpy(np.linspace(0, 1, ovlp_samples)[None, :])
176
+ ov_win = torch.cat([ov_win, 1 - ov_win], -1)
177
+ output[:, -ovlp_samples:] = (
178
+ output[:, -ovlp_samples:] * ov_win[:, -ovlp_samples:]
179
+ + cur_output[:, 0:ovlp_samples] * ov_win[:, 0:ovlp_samples]
180
+ )
181
+ output = torch.cat([output, cur_output[:, ovlp_samples:]], -1)
182
+ output = output[:, 0:target_len]
183
+ return output
src/heartlib/heartcodec/models/flow_matching.py ADDED
@@ -0,0 +1,176 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from tqdm import tqdm
5
+ from vector_quantize_pytorch import ResidualVQ
6
+ from .transformer import LlamaTransformer
7
+
8
+
9
+ class FlowMatching(nn.Module):
10
+ def __init__(
11
+ self,
12
+ # rvq stuff
13
+ dim: int = 512,
14
+ codebook_size: int = 8192,
15
+ decay: float = 0.9,
16
+ commitment_weight: float = 1.0,
17
+ threshold_ema_dead_code: int = 2,
18
+ use_cosine_sim: bool = False,
19
+ codebook_dim: int = 32,
20
+ num_quantizers: int = 8,
21
+ # dit backbone stuff
22
+ attention_head_dim: int = 64,
23
+ in_channels: int = 1024,
24
+ norm_type: str = "ada_norm_single",
25
+ num_attention_heads: int = 24,
26
+ num_layers: int = 24,
27
+ num_layers_2: int = 6,
28
+ out_channels: int = 256,
29
+ ):
30
+ super().__init__()
31
+
32
+ self.vq_embed = ResidualVQ(
33
+ dim=dim,
34
+ codebook_size=codebook_size,
35
+ decay=decay,
36
+ commitment_weight=commitment_weight,
37
+ threshold_ema_dead_code=threshold_ema_dead_code,
38
+ use_cosine_sim=use_cosine_sim,
39
+ codebook_dim=codebook_dim,
40
+ num_quantizers=num_quantizers,
41
+ )
42
+ self.cond_feature_emb = nn.Linear(dim, dim)
43
+ self.zero_cond_embedding1 = nn.Parameter(torch.randn(dim))
44
+ self.estimator = LlamaTransformer(
45
+ attention_head_dim=attention_head_dim,
46
+ in_channels=in_channels,
47
+ norm_type=norm_type,
48
+ num_attention_heads=num_attention_heads,
49
+ num_layers=num_layers,
50
+ num_layers_2=num_layers_2,
51
+ out_channels=out_channels,
52
+ )
53
+
54
+ self.latent_dim = out_channels
55
+
56
+ @torch.no_grad()
57
+ def inference_codes(
58
+ self,
59
+ codes,
60
+ true_latents,
61
+ latent_length,
62
+ incontext_length,
63
+ guidance_scale=2.0,
64
+ num_steps=20,
65
+ disable_progress=True,
66
+ scenario="start_seg",
67
+ ):
68
+ device = true_latents.device
69
+ dtype = true_latents.dtype
70
+ # codes_bestrq_middle, codes_bestrq_last = codes
71
+ codes_bestrq_emb = codes[0]
72
+
73
+ batch_size = codes_bestrq_emb.shape[0]
74
+ self.vq_embed.eval()
75
+ quantized_feature_emb = self.vq_embed.get_output_from_indices(
76
+ codes_bestrq_emb.transpose(1, 2)
77
+ )
78
+ quantized_feature_emb = self.cond_feature_emb(quantized_feature_emb) # b t 512
79
+ # assert 1==2
80
+ quantized_feature_emb = F.interpolate(
81
+ quantized_feature_emb.permute(0, 2, 1), scale_factor=2, mode="nearest"
82
+ ).permute(0, 2, 1)
83
+
84
+ num_frames = quantized_feature_emb.shape[1] #
85
+ latents = torch.randn(
86
+ (batch_size, num_frames, self.latent_dim), device=device, dtype=dtype
87
+ )
88
+ latent_masks = torch.zeros(
89
+ latents.shape[0], latents.shape[1], dtype=torch.int64, device=latents.device
90
+ )
91
+ latent_masks[:, 0:latent_length] = 2
92
+ if scenario == "other_seg":
93
+ latent_masks[:, 0:incontext_length] = 1
94
+
95
+ quantized_feature_emb = (latent_masks > 0.5).unsqueeze(
96
+ -1
97
+ ) * quantized_feature_emb + (latent_masks < 0.5).unsqueeze(
98
+ -1
99
+ ) * self.zero_cond_embedding1.unsqueeze(
100
+ 0
101
+ )
102
+
103
+ incontext_latents = (
104
+ true_latents
105
+ * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float()
106
+ )
107
+ incontext_length = ((latent_masks > 0.5) * (latent_masks < 1.5)).sum(-1)[0]
108
+
109
+ additional_model_input = torch.cat([quantized_feature_emb], 1)
110
+ temperature = 1.0
111
+ t_span = torch.linspace(
112
+ 0, 1, num_steps + 1, device=quantized_feature_emb.device
113
+ )
114
+ latents = self.solve_euler(
115
+ latents * temperature,
116
+ incontext_latents.to(dtype),
117
+ incontext_length,
118
+ t_span,
119
+ additional_model_input,
120
+ guidance_scale,
121
+ )
122
+
123
+ latents[:, 0:incontext_length, :] = incontext_latents[
124
+ :, 0:incontext_length, :
125
+ ] # B, T, dim
126
+ return latents
127
+
128
+ def solve_euler(self, x, incontext_x, incontext_length, t_span, mu, guidance_scale):
129
+ """
130
+ Fixed euler solver for ODEs.
131
+ Args:
132
+ x (torch.Tensor): random noise
133
+ t_span (torch.Tensor): n_timesteps interpolated
134
+ shape: (n_timesteps + 1,)
135
+ mu (torch.Tensor): output of encoder
136
+ shape: (batch_size, n_feats, mel_timesteps)
137
+ """
138
+ t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
139
+ noise = x.clone()
140
+ # I am storing this because I can later plot it by putting a debugger here and saving it to a file
141
+ # Or in future might add like a return_all_steps flag
142
+ sol = []
143
+ for step in tqdm(range(1, len(t_span))):
144
+ x[:, 0:incontext_length, :] = (1 - (1 - 1e-6) * t) * noise[
145
+ :, 0:incontext_length, :
146
+ ] + t * incontext_x[:, 0:incontext_length, :]
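+ # Classifier-free guidance as implemented below: the batch is duplicated, the second
+ # copy sees a zeroed conditioning tensor, and the two velocity estimates are blended
+ # as uncond + guidance_scale * (cond - uncond) before the Euler update.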
147
+ if guidance_scale > 1.0:
148
+ dphi_dt = self.estimator(
149
+ torch.cat(
150
+ [
151
+ torch.cat([x, x], 0),
152
+ torch.cat([incontext_x, incontext_x], 0),
153
+ torch.cat([torch.zeros_like(mu), mu], 0),
154
+ ],
155
+ 2,
156
+ ),
157
+ timestep=t.unsqueeze(-1).repeat(2),
158
+ )
159
+ dphi_dt_uncond, dphi_dt_cond = dphi_dt.chunk(2, 0)
160
+ dphi_dt = dphi_dt_uncond + guidance_scale * (
161
+ dphi_dt_cond - dphi_dt_uncond
162
+ )
163
+ else:
164
+ dphi_dt = self.estimator(
165
+ torch.cat([x, incontext_x, mu], 2), timestep=t.unsqueeze(-1)
166
+ )
167
+
168
+ x = x + dt * dphi_dt
169
+ t = t + dt
170
+ sol.append(x)
171
+ if step < len(t_span) - 1:
172
+ dt = t_span[step + 1] - t
173
+
174
+ result = sol[-1]
175
+
176
+ return result
src/heartlib/heartcodec/models/sq_codec.py ADDED
@@ -0,0 +1,539 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ from torch.nn.utils.parametrizations import weight_norm
6
+ from torch.nn.utils import remove_weight_norm
7
+ from torch.autograd.function import InplaceFunction
8
+
9
+
10
+ def get_padding(kernel_size, dilation=1):
11
+ return int((kernel_size * dilation - dilation) / 2)
12
+
13
+
14
+ # Scripting this brings model speed up 1.4x
15
+ @torch.jit.script
16
+ def snake(x, alpha):
17
+ shape = x.shape
18
+ x = x.reshape(shape[0], shape[1], -1)
19
+ x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
20
+ x = x.reshape(shape)
21
+ return x
22
+
23
+
24
+ class Snake1d(nn.Module):
25
+ def __init__(self, channels):
26
+ super().__init__()
27
+ self.alpha = nn.Parameter(torch.ones(1, channels, 1))
28
+
29
+ def forward(self, x):
30
+ return snake(x, self.alpha)
31
+
32
+
33
+ class Conv1d(nn.Conv1d):
34
+ def __init__(
35
+ self,
36
+ in_channels: int,
37
+ out_channels: int,
38
+ kernel_size: int,
39
+ stride: int = 1,
40
+ dilation: int = 1,
41
+ groups: int = 1,
42
+ padding_mode: str = "zeros",
43
+ bias: bool = True,
44
+ padding=None,
45
+ causal: bool = False,
46
+ w_init_gain=None,
47
+ ):
48
+ self.causal = causal
49
+ if padding is None:
50
+ if causal:
51
+ padding = 0
52
+ self.left_padding = dilation * (kernel_size - 1)
53
+ else:
54
+ padding = get_padding(kernel_size, dilation)
55
+ super(Conv1d, self).__init__(
56
+ in_channels,
57
+ out_channels,
58
+ kernel_size,
59
+ stride=stride,
60
+ padding=padding,
61
+ dilation=dilation,
62
+ groups=groups,
63
+ padding_mode=padding_mode,
64
+ bias=bias,
65
+ )
66
+ if w_init_gain is not None:
67
+ torch.nn.init.xavier_uniform_(
68
+ self.weight, gain=torch.nn.init.calculate_gain(w_init_gain)
69
+ )
70
+
71
+ def forward(self, x):
72
+ if self.causal:
73
+ x = F.pad(x.unsqueeze(2), (self.left_padding, 0, 0, 0)).squeeze(2)
74
+
75
+ return super(Conv1d, self).forward(x)
76
+
77
+
78
+ class ConvTranspose1d(nn.ConvTranspose1d):
79
+ def __init__(
80
+ self,
81
+ in_channels: int,
82
+ out_channels: int,
83
+ kernel_size: int,
84
+ stride: int = 1,
85
+ output_padding: int = 0,
86
+ groups: int = 1,
87
+ bias: bool = True,
88
+ dilation: int = 1,
89
+ padding=None,
90
+ padding_mode: str = "zeros",
91
+ causal: bool = False,
92
+ ):
93
+ if padding is None:
94
+ padding = 0 if causal else (kernel_size - stride) // 2
95
+ if causal:
96
+ assert padding == 0, "padding is not allowed in causal ConvTranspose1d."
97
+ assert (
98
+ kernel_size == 2 * stride
99
+ ), "kernel_size must equal 2*stride for causal ConvTranspose1d."
100
+ super(ConvTranspose1d, self).__init__(
101
+ in_channels,
102
+ out_channels,
103
+ kernel_size,
104
+ stride=stride,
105
+ padding=padding,
106
+ output_padding=output_padding,
107
+ groups=groups,
108
+ bias=bias,
109
+ dilation=dilation,
110
+ padding_mode=padding_mode,
111
+ )
112
+ self.causal = causal
113
+ self.stride = stride
114
+
115
+ def forward(self, x):
116
+ x = super(ConvTranspose1d, self).forward(x)
117
+ if self.causal:
118
+ x = x[:, :, : -self.stride]
119
+ return x
120
+
121
+
122
+ class PreProcessor(nn.Module):
123
+ def __init__(self, n_in, n_out, num_samples, kernel_size=7, causal=False):
124
+ super(PreProcessor, self).__init__()
125
+ self.pooling = torch.nn.AvgPool1d(kernel_size=num_samples)
126
+ self.conv = Conv1d(n_in, n_out, kernel_size=kernel_size, causal=causal)
127
+ self.activation = nn.PReLU()
128
+
129
+ def forward(self, x):
130
+ output = self.activation(self.conv(x))
131
+ output = self.pooling(output)
132
+ return output
133
+
134
+
135
+ class PostProcessor(nn.Module):
136
+ def __init__(self, n_in, n_out, num_samples, kernel_size=7, causal=False):
137
+ super(PostProcessor, self).__init__()
138
+ self.num_samples = num_samples
139
+ self.conv = Conv1d(n_in, n_out, kernel_size=kernel_size, causal=causal)
140
+ self.activation = nn.PReLU()
141
+
142
+ def forward(self, x):
143
+ x = torch.transpose(x, 1, 2)
144
+ B, T, C = x.size()
145
+ x = x.repeat(1, 1, self.num_samples).view(B, -1, C)
146
+ x = torch.transpose(x, 1, 2)
147
+ output = self.activation(self.conv(x))
148
+ return output
149
+
150
+
151
+ class ResidualUnit(nn.Module):
152
+ def __init__(self, n_in, n_out, dilation, res_kernel_size=7, causal=False):
153
+ super(ResidualUnit, self).__init__()
154
+ self.conv1 = weight_norm(
155
+ Conv1d(
156
+ n_in,
157
+ n_out,
158
+ kernel_size=res_kernel_size,
159
+ dilation=dilation,
160
+ causal=causal,
161
+ )
162
+ )
163
+ self.conv2 = weight_norm(Conv1d(n_in, n_out, kernel_size=1, causal=causal))
164
+ self.activation1 = nn.PReLU()
165
+ self.activation2 = nn.PReLU()
166
+
167
+ def forward(self, x):
168
+ output = self.activation1(self.conv1(x))
169
+ output = self.activation2(self.conv2(output))
170
+ return output + x
171
+
172
+
173
+ class ResEncoderBlock(nn.Module):
174
+ def __init__(
175
+ self, n_in, n_out, stride, down_kernel_size, res_kernel_size=7, causal=False
176
+ ):
177
+ super(ResEncoderBlock, self).__init__()
178
+ self.convs = nn.ModuleList(
179
+ [
180
+ ResidualUnit(
181
+ n_in,
182
+ n_out // 2,
183
+ dilation=1,
184
+ res_kernel_size=res_kernel_size,
185
+ causal=causal,
186
+ ),
187
+ ResidualUnit(
188
+ n_out // 2,
189
+ n_out // 2,
190
+ dilation=3,
191
+ res_kernel_size=res_kernel_size,
192
+ causal=causal,
193
+ ),
194
+ ResidualUnit(
195
+ n_out // 2,
196
+ n_out // 2,
197
+ dilation=5,
198
+ res_kernel_size=res_kernel_size,
199
+ causal=causal,
200
+ ),
201
+ ResidualUnit(
202
+ n_out // 2,
203
+ n_out // 2,
204
+ dilation=7,
205
+ res_kernel_size=res_kernel_size,
206
+ causal=causal,
207
+ ),
208
+ ResidualUnit(
209
+ n_out // 2,
210
+ n_out // 2,
211
+ dilation=9,
212
+ res_kernel_size=res_kernel_size,
213
+ causal=causal,
214
+ ),
215
+ ]
216
+ )
217
+
218
+ self.down_conv = DownsampleLayer(
219
+ n_in, n_out, down_kernel_size, stride=stride, causal=causal
220
+ )
221
+
222
+ def forward(self, x):
223
+ for conv in self.convs:
224
+ x = conv(x)
225
+ x = self.down_conv(x)
226
+ return x
227
+
228
+
229
+ class ResDecoderBlock(nn.Module):
230
+ def __init__(
231
+ self, n_in, n_out, stride, up_kernel_size, res_kernel_size=7, causal=False
232
+ ):
233
+ super(ResDecoderBlock, self).__init__()
234
+ self.up_conv = UpsampleLayer(
235
+ n_in,
236
+ n_out,
237
+ kernel_size=up_kernel_size,
238
+ stride=stride,
239
+ causal=causal,
240
+ activation=None,
241
+ )
242
+
243
+ self.convs = nn.ModuleList(
244
+ [
245
+ ResidualUnit(
246
+ n_out,
247
+ n_out,
248
+ dilation=1,
249
+ res_kernel_size=res_kernel_size,
250
+ causal=causal,
251
+ ),
252
+ ResidualUnit(
253
+ n_out,
254
+ n_out,
255
+ dilation=3,
256
+ res_kernel_size=res_kernel_size,
257
+ causal=causal,
258
+ ),
259
+ ResidualUnit(
260
+ n_out,
261
+ n_out,
262
+ dilation=5,
263
+ res_kernel_size=res_kernel_size,
264
+ causal=causal,
265
+ ),
266
+ ResidualUnit(
267
+ n_out,
268
+ n_out,
269
+ dilation=7,
270
+ res_kernel_size=res_kernel_size,
271
+ causal=causal,
272
+ ),
273
+ ResidualUnit(
274
+ n_out,
275
+ n_out,
276
+ dilation=9,
277
+ res_kernel_size=res_kernel_size,
278
+ causal=causal,
279
+ ),
280
+ ]
281
+ )
282
+
283
+ def forward(self, x):
284
+ x = self.up_conv(x)
285
+ for conv in self.convs:
286
+ x = conv(x)
287
+ return x
288
+
289
+
290
+ class DownsampleLayer(nn.Module):
291
+ def __init__(
292
+ self,
293
+ in_channels: int,
294
+ out_channels: int,
295
+ kernel_size: int,
296
+ stride: int = 1,
297
+ causal: bool = False,
298
+ activation=nn.PReLU(),
299
+ use_weight_norm: bool = True,
300
+ pooling: bool = False,
301
+ ):
302
+ super(DownsampleLayer, self).__init__()
303
+ self.pooling = pooling
304
+ self.stride = stride
305
+ self.activation = nn.PReLU()
306
+ self.use_weight_norm = use_weight_norm
307
+ if pooling:
308
+ self.layer = Conv1d(in_channels, out_channels, kernel_size, causal=causal)
309
+ self.pooling = nn.AvgPool1d(kernel_size=stride)
310
+ else:
311
+ self.layer = Conv1d(
312
+ in_channels, out_channels, kernel_size, stride=stride, causal=causal
313
+ )
314
+ if use_weight_norm:
315
+ self.layer = weight_norm(self.layer)
316
+
317
+ def forward(self, x):
318
+ x = self.layer(x)
319
+ x = self.activation(x) if self.activation is not None else x
320
+ if self.pooling:
321
+ x = self.pooling(x)
322
+ return x
323
+
324
+ def remove_weight_norm(self):
325
+ if self.use_weight_norm:
326
+ remove_weight_norm(self.layer)
327
+
328
+
329
+ class UpsampleLayer(nn.Module):
330
+ def __init__(
331
+ self,
332
+ in_channels: int,
333
+ out_channels: int,
334
+ kernel_size: int,
335
+ stride: int = 1,
336
+ causal: bool = False,
337
+ activation=nn.PReLU(),
338
+ use_weight_norm: bool = True,
339
+ repeat: bool = False,
340
+ ):
341
+ super(UpsampleLayer, self).__init__()
342
+ self.repeat = repeat
343
+ self.stride = stride
344
+ self.activation = activation
345
+ self.use_weight_norm = use_weight_norm
346
+ if repeat:
347
+ self.layer = Conv1d(in_channels, out_channels, kernel_size, causal=causal)
348
+ else:
349
+ self.layer = ConvTranspose1d(
350
+ in_channels, out_channels, kernel_size, stride=stride, causal=causal
351
+ )
352
+ if use_weight_norm:
353
+ self.layer = weight_norm(self.layer)
354
+
355
+ def forward(self, x):
356
+ x = self.layer(x)
357
+ x = self.activation(x) if self.activation is not None else x
358
+ if self.repeat:
359
+ x = torch.transpose(x, 1, 2)
360
+ B, T, C = x.size()
361
+ x = x.repeat(1, 1, self.stride).view(B, -1, C)
362
+ x = torch.transpose(x, 1, 2)
363
+ return x
364
+
365
+ def remove_weight_norm(self):
366
+ if self.use_weight_norm:
367
+ remove_weight_norm(self.layer)
368
+
369
+
370
+ class round_func9(InplaceFunction):
371
+ @staticmethod
372
+ def forward(ctx, input):
373
+ ctx.input = input
374
+ return torch.round(9 * input) / 9
375
+
376
+ @staticmethod
377
+ def backward(ctx, grad_output):
378
+ grad_input = grad_output.clone()
379
+ return grad_input
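+
+ # Straight-through scalar quantizer: forward() snaps tanh-bounded activations onto a
+ # 19-level grid (round(9 * x) / 9 for x in [-1, 1]); backward() passes gradients through
+ # unchanged, the usual straight-through-estimator trick for non-differentiable rounding.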
380
+
381
+
382
+ class ScalarModel(nn.Module):
383
+ def __init__(
384
+ self,
385
+ num_bands,
386
+ sample_rate,
387
+ causal,
388
+ num_samples,
389
+ downsample_factors,
390
+ downsample_kernel_sizes,
391
+ upsample_factors,
392
+ upsample_kernel_sizes,
393
+ latent_hidden_dim,
394
+ default_kernel_size,
395
+ delay_kernel_size,
396
+ init_channel,
397
+ res_kernel_size,
398
+ mode="pre_proj",
399
+ ):
400
+ super(ScalarModel, self).__init__()
401
+ # self.args = args
402
+ self.encoder = []
403
+ self.decoder = []
404
+ self.vq = round_func9() # using 9
405
+ self.mode = mode
406
+ # Encoder parts
407
+ self.encoder.append(
408
+ weight_norm(
409
+ Conv1d(
410
+ num_bands,
411
+ init_channel,
412
+ kernel_size=default_kernel_size,
413
+ causal=causal,
414
+ )
415
+ )
416
+ )
417
+ if num_samples > 1:
418
+ # Downsampling
419
+ self.encoder.append(
420
+ PreProcessor(
421
+ init_channel,
422
+ init_channel,
423
+ num_samples,
424
+ kernel_size=default_kernel_size,
425
+ causal=causal,
426
+ )
427
+ )
428
+ for i, down_factor in enumerate(downsample_factors):
429
+ self.encoder.append(
430
+ ResEncoderBlock(
431
+ init_channel * np.power(2, i),
432
+ init_channel * np.power(2, i + 1),
433
+ down_factor,
434
+ downsample_kernel_sizes[i],
435
+ res_kernel_size,
436
+ causal=causal,
437
+ )
438
+ )
439
+ self.encoder.append(
440
+ weight_norm(
441
+ Conv1d(
442
+ init_channel * np.power(2, len(downsample_factors)),
443
+ latent_hidden_dim,
444
+ kernel_size=default_kernel_size,
445
+ causal=causal,
446
+ )
447
+ )
448
+ )
449
+ # Decoder
450
+ # look ahead
451
+ self.decoder.append(
452
+ weight_norm(
453
+ Conv1d(
454
+ latent_hidden_dim,
455
+ init_channel * np.power(2, len(upsample_factors)),
456
+ kernel_size=delay_kernel_size,
457
+ )
458
+ )
459
+ )
460
+ for i, upsample_factor in enumerate(upsample_factors):
461
+ self.decoder.append(
462
+ ResDecoderBlock(
463
+ init_channel * np.power(2, len(upsample_factors) - i),
464
+ init_channel * np.power(2, len(upsample_factors) - i - 1),
465
+ upsample_factor,
466
+ upsample_kernel_sizes[i],
467
+ res_kernel_size,
468
+ causal=causal,
469
+ )
470
+ )
471
+ if num_samples > 1:
472
+ self.decoder.append(
473
+ PostProcessor(
474
+ init_channel,
475
+ init_channel,
476
+ num_samples,
477
+ kernel_size=default_kernel_size,
478
+ causal=causal,
479
+ )
480
+ )
481
+ self.decoder.append(
482
+ weight_norm(
483
+ Conv1d(
484
+ init_channel,
485
+ num_bands,
486
+ kernel_size=default_kernel_size,
487
+ causal=causal,
488
+ )
489
+ )
490
+ )
491
+ self.encoder = nn.ModuleList(self.encoder)
492
+ self.decoder = nn.ModuleList(self.decoder)
493
+
494
+ def forward(self, x):
495
+ for i, layer in enumerate(self.encoder):
496
+ if i != len(self.encoder) - 1:
497
+ x = layer(x)
498
+ else:
499
+ x = F.tanh(layer(x))
500
+ # import pdb; pdb.set_trace()
501
+ x = self.vq.apply(x) # vq
502
+ for i, layer in enumerate(self.decoder):
503
+ x = layer(x)
504
+ return x
505
+
506
+ def inference(self, x):
507
+ for i, layer in enumerate(self.encoder):
508
+ if i != len(self.encoder) - 1:
509
+ x = layer(x)
510
+ else:
511
+ x = F.tanh(layer(x)) # reverse to tanh
512
+
513
+ emb = x
514
+ # import pdb; pdb.set_trace()
515
+ emb_quant = self.vq.apply(emb) # vq
516
+ x = emb_quant
517
+ for i, layer in enumerate(self.decoder):
518
+ x = layer(x)
519
+ return emb, emb_quant, x
520
+
521
+ def encode(self, x):
522
+ for i, layer in enumerate(self.encoder):
523
+ if i != len(self.encoder) - 1:
524
+ x = layer(x)
525
+ else:
526
+ x = F.tanh(layer(x)) # reverse to tanh
527
+
528
+ emb = x
529
+ # import pdb; pdb.set_trace()
530
+ emb_quant = self.vq.apply(emb) # vq
531
+ return emb
532
+
533
+ def decode(self, x):
534
+ x = self.vq.apply(
535
+ x
536
+ ) # make sure the prediction follows a similar distribution
537
+ for i, layer in enumerate(self.decoder):
538
+ x = layer(x)
539
+ return x
src/heartlib/heartcodec/models/transformer.py ADDED
@@ -0,0 +1,501 @@
1
+ import math
2
+ from typing import Optional, Tuple
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ class RMSNorm(nn.Module):
9
+ def __init__(self, dim: int, eps: float = 1e-6):
10
+ super().__init__()
11
+ self.eps = eps
12
+ self.weight = nn.Parameter(torch.ones(dim))
13
+
14
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
15
+ var = x.pow(2).mean(dim=-1, keepdim=True)
16
+ x = x * torch.rsqrt(var + self.eps)
17
+ return self.weight * x
18
+
19
+
20
+ class RotaryEmbedding(nn.Module):
21
+ def __init__(self, dim: int, base: int = 10000):
22
+ super().__init__()
23
+ self.dim = dim
24
+ self.base = base
25
+ self._cache = {}
26
+
27
+ def get_sin_cos(self, seq_len: int, device, dtype):
28
+ key = (seq_len, device, dtype)
29
+ cached = self._cache.get(key, None)
30
+ if cached is not None and cached[0].device == device:
31
+ return cached
32
+ inv_freq = 1.0 / (
33
+ self.base
34
+ ** (torch.arange(0, self.dim, 2, device=device, dtype=dtype) / self.dim)
35
+ )
36
+ t = torch.arange(seq_len, device=device, dtype=dtype)
37
+ freqs = torch.einsum("i,j->ij", t, inv_freq)
38
+ sin = freqs.sin()
39
+ cos = freqs.cos()
40
+ self._cache[key] = (sin, cos)
41
+ return sin, cos
42
+
43
+ def apply_rotary(
44
+ self, x: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor
45
+ ) -> torch.Tensor:
46
+ x1, x2 = x[..., : self.dim // 2], x[..., self.dim // 2 : self.dim]
47
+ # Interleave sin/cos across pairs
48
+ x_rot = torch.stack((-x2, x1), dim=-1).reshape_as(x[..., : self.dim])
49
+ return (x[..., : self.dim] * cos.unsqueeze(-1)).reshape_as(
50
+ x[..., : self.dim]
51
+ ) + (x_rot * sin.unsqueeze(-1)).reshape_as(x[..., : self.dim])
52
+
53
+
54
+ class LlamaAttention(nn.Module):
55
+ def __init__(
56
+ self,
57
+ dim: int,
58
+ n_heads: int,
59
+ head_dim: int,
60
+ bias: bool = False,
61
+ dropout: float = 0.0,
62
+ rope_dim: Optional[int] = None,
63
+ cross_attention_dim: Optional[int] = None,
64
+ use_sdpa: bool = True,
65
+ ):
66
+ super().__init__()
67
+ self.dim = dim
68
+ self.n_heads = n_heads
69
+ self.head_dim = head_dim
70
+ self.inner_dim = n_heads * head_dim
71
+ self.cross_attention_dim = cross_attention_dim
72
+ self.q_proj = nn.Linear(dim, self.inner_dim, bias=bias)
73
+ k_in = dim if cross_attention_dim is None else cross_attention_dim
74
+ self.k_proj = nn.Linear(k_in, self.inner_dim, bias=bias)
75
+ self.v_proj = nn.Linear(k_in, self.inner_dim, bias=bias)
76
+ self.o_proj = nn.Linear(self.inner_dim, dim, bias=bias)
77
+ self.dropout = dropout
78
+ self.rope_dim = rope_dim if rope_dim is not None else head_dim
79
+ self.rope = RotaryEmbedding(self.rope_dim)
80
+ self.use_sdpa = use_sdpa
81
+ self._has_sdpa = hasattr(F, "scaled_dot_product_attention")
82
+
83
+ def _shape(self, x: torch.Tensor, b: int, t: int) -> torch.Tensor:
84
+ return x.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
85
+
86
+ def forward(
87
+ self,
88
+ x: torch.Tensor,
89
+ encoder_hidden_states: Optional[torch.Tensor] = None,
90
+ attention_mask: Optional[torch.Tensor] = None,
91
+ ) -> torch.Tensor:
92
+ b, t, c = x.shape
93
+ q = self._shape(self.q_proj(x), b, t)
94
+ if encoder_hidden_states is None:
95
+ k = self._shape(self.k_proj(x), b, t)
96
+ v = self._shape(self.v_proj(x), b, t)
97
+ else:
98
+ bt, tk, ck = encoder_hidden_states.shape
99
+ k = self._shape(self.k_proj(encoder_hidden_states), b, tk)
100
+ v = self._shape(self.v_proj(encoder_hidden_states), b, tk)
101
+
102
+ # RoPE on first rope_dim of head_dim
103
+ rope_dim = min(self.rope_dim, self.head_dim)
104
+ seq_len_for_rope = k.shape[-2]
105
+ sin, cos = self.rope.get_sin_cos(
106
+ seq_len_for_rope, device=x.device, dtype=x.dtype
107
+ )
108
+
109
+ def apply_rope_vec(tensor):
110
+ head = tensor[..., :rope_dim]
111
+ tail = tensor[..., rope_dim:]
112
+ b, h, tt, _ = head.shape
113
+ head = head.view(b, h, tt, rope_dim // 2, 2)
114
+ sin_ = sin.view(1, 1, tt, rope_dim // 2, 1)
115
+ cos_ = cos.view(1, 1, tt, rope_dim // 2, 1)
116
+ x1 = head[..., 0:1]
117
+ x2 = head[..., 1:2]
118
+ rot = torch.cat(
119
+ [x1 * cos_ - x2 * sin_, x1 * sin_ + x2 * cos_], dim=-1
120
+ ).view(b, h, tt, rope_dim)
121
+ return torch.cat([rot, tail], dim=-1)
122
+
123
+ q = apply_rope_vec(q)
124
+ k = apply_rope_vec(k)
125
+
126
+ # Prefer PyTorch SDPA (can enable FlashAttention kernel on supported GPUs)
127
+ if self.use_sdpa and self._has_sdpa:
128
+ s = k.shape[-2]
129
+ attn_mask_sdpa = None
130
+ if attention_mask is not None:
131
+ m = attention_mask
132
+
133
+ if m.dim() == 2 and m.shape == (b, s): # [b, s]
134
+ m = m[:, None, None, :] # [b,1,1,s]
135
+ elif m.dim() == 3 and m.shape[-2] == 1: # [b,1,s]
136
+ m = m[:, None, :, :] # [b,1,1,s]
137
+ elif m.dim() == 3 and m.shape[-2] == t: # [b,t,s]
138
+ m = m[:, None, :, :] # [b,1,t,s]
139
+ elif m.dim() == 4 and m.shape[1] == 1: # [b,1,t,s] or [b,1,1,s]
140
+ pass
141
+ attn_mask_sdpa = m
142
+
143
+ out = F.scaled_dot_product_attention(
144
+ q,
145
+ k,
146
+ v,
147
+ attn_mask=attn_mask_sdpa,
148
+ dropout_p=self.dropout if self.training else 0.0,
149
+ is_causal=False,
150
+ )
151
+ out = out.transpose(1, 2).contiguous().view(b, t, self.inner_dim)
152
+ return self.o_proj(out)
153
+ else:
154
+ attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(
155
+ self.head_dim
156
+ )
157
+ if attention_mask is not None:
158
+ attn_scores = attn_scores + attention_mask
159
+ attn = attn_scores.softmax(dim=-1)
160
+ attn = F.dropout(attn, p=self.dropout, training=self.training)
161
+ out = torch.matmul(attn, v)
162
+ out = out.transpose(1, 2).contiguous().view(b, t, self.inner_dim)
163
+ return self.o_proj(out)
164
+
165
+
166
+ class LlamaMLP(nn.Module):
167
+ def __init__(
168
+ self,
169
+ dim: int,
170
+ hidden_dim: Optional[int] = None,
171
+ multiple_of: int = 256,
172
+ dropout: float = 0.0,
173
+ ):
174
+ super().__init__()
175
+ hidden_dim = hidden_dim or 4 * dim
176
+ # align to multiple_of like Llama
177
+ hidden_dim = int(2 * hidden_dim / 3)
178
+ hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
179
+ self.gate = nn.Linear(dim, hidden_dim, bias=False)
180
+ self.up = nn.Linear(dim, hidden_dim, bias=False)
181
+ self.down = nn.Linear(hidden_dim, dim, bias=False)
182
+ self.dropout = dropout
183
+
184
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
185
+ x = F.silu(self.gate(x)) * self.up(x)
186
+ x = F.dropout(x, p=self.dropout, training=self.training)
187
+ return self.down(x)
188
+
189
+
190
+ class LlamaTransformerBlock(nn.Module):
191
+ def __init__(
192
+ self,
193
+ dim: int,
194
+ n_heads: int,
195
+ head_dim: int,
196
+ mlp_multiple_of: int = 256,
197
+ dropout: float = 0.0,
198
+ attention_bias: bool = False,
199
+ cross_attention_dim: Optional[int] = None,
200
+ use_ada_layer_norm_single: bool = False,
201
+ ):
202
+ super().__init__()
203
+ self.attn_norm = RMSNorm(dim, 1e-6)
204
+ self.attn = LlamaAttention(
205
+ dim,
206
+ n_heads,
207
+ head_dim,
208
+ bias=attention_bias,
209
+ dropout=dropout,
210
+ rope_dim=head_dim,
211
+ cross_attention_dim=None,
212
+ )
213
+ self.cross_attn = None
214
+ if cross_attention_dim is not None:
215
+ self.cross_attn_norm = RMSNorm(dim, 1e-6)
216
+ self.cross_attn = LlamaAttention(
217
+ dim,
218
+ n_heads,
219
+ head_dim,
220
+ bias=attention_bias,
221
+ dropout=dropout,
222
+ rope_dim=head_dim,
223
+ cross_attention_dim=cross_attention_dim,
224
+ )
225
+ self.mlp_norm = RMSNorm(dim, 1e-6)
226
+ self.mlp = LlamaMLP(dim, multiple_of=mlp_multiple_of, dropout=dropout)
227
+ self.use_ada_layer_norm_single = use_ada_layer_norm_single
228
+ if self.use_ada_layer_norm_single:
229
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
230
+
231
+ def forward(
232
+ self,
233
+ x: torch.Tensor,
234
+ encoder_hidden_states: Optional[torch.Tensor] = None,
235
+ attention_mask: Optional[torch.Tensor] = None,
236
+ timestep: Optional[torch.Tensor] = None,
237
+ ) -> torch.Tensor:
238
+ if self.use_ada_layer_norm_single:
239
+ batch_size = x.shape[0]
240
+ # timestep: [B, 6*D]
241
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
242
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
243
+ ).chunk(6, dim=1)
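+ # AdaLN-single style modulation (sketch): one timestep embedding, offset by a learned
+ # table, yields six vectors (shift/scale/gate for attention and again for the MLP)
+ # rather than a separate adaptive norm per layer.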
244
+
245
+ # Self-Attention with modulation and gating
246
+ norm_hidden_states = self.attn_norm(x)
247
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
248
+ h = self.attn(norm_hidden_states, attention_mask=attention_mask)
249
+ h = gate_msa * h
250
+ x = x + h
251
+
252
+ # MLP with modulation and gating
253
+ norm_hidden_states = self.mlp_norm(x)
254
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
255
+ h = self.mlp(norm_hidden_states)
256
+ h = gate_mlp * h
257
+ x = x + h
258
+ return x
259
+ else:
260
+ h = self.attn(self.attn_norm(x), attention_mask=attention_mask)
261
+ x = x + h
262
+ h = self.mlp(self.mlp_norm(x))
263
+ x = x + h
264
+ return x
265
+
266
+
267
+ class ProjectLayer(nn.Module):
268
+ def __init__(self, hidden_size, filter_size, kernel_size=1, dropout=0.0):
269
+ super().__init__()
270
+ self.kernel_size = kernel_size
271
+ self.dropout = dropout
272
+ self.ffn_1 = nn.Conv1d(
273
+ hidden_size, filter_size, kernel_size, padding=kernel_size // 2
274
+ )
275
+ self.ffn_2 = nn.Linear(filter_size, filter_size)
276
+
277
+ def forward(self, x):
278
+ x = self.ffn_1(x.transpose(1, 2)).transpose(1, 2)
279
+ x = x * self.kernel_size**-0.5
280
+ x = self.ffn_2(x)
281
+ return x
282
+
283
+
284
+ class LlamaTransformer(nn.Module):
285
+ def __init__(
286
+ self,
287
+ num_attention_heads: int,
288
+ attention_head_dim: int,
289
+ in_channels: int,
290
+ out_channels: int,
291
+ num_layers: int = 12,
292
+ num_layers_2: int = 2,
293
+ dropout: float = 0.0,
294
+ cross_attention_dim: Optional[int] = None,
295
+ norm_type: str = "layer_norm",
296
+ ):
297
+ super().__init__()
298
+ inner_dim = num_attention_heads * attention_head_dim
299
+ inner_dim_2 = inner_dim * 2
300
+ self.in_channels = in_channels
301
+ self.out_channels = out_channels
302
+ self.inner_dim = inner_dim
303
+ self.inner_dim_2 = inner_dim_2
304
+ self.dropout = dropout
305
+
306
+ self.proj_in = ProjectLayer(in_channels, inner_dim, kernel_size=3)
307
+
308
+ use_ada_single = norm_type == "ada_norm_single"
309
+ self.transformer_blocks = nn.ModuleList(
310
+ [
311
+ LlamaTransformerBlock(
312
+ dim=inner_dim,
313
+ n_heads=num_attention_heads,
314
+ head_dim=attention_head_dim,
315
+ dropout=dropout,
316
+ attention_bias=False,
317
+ cross_attention_dim=cross_attention_dim,
318
+ use_ada_layer_norm_single=use_ada_single,
319
+ )
320
+ for _ in range(num_layers)
321
+ ]
322
+ )
323
+
324
+ self.transformer_blocks_2 = nn.ModuleList(
325
+ [
326
+ LlamaTransformerBlock(
327
+ dim=inner_dim_2,
328
+ n_heads=num_attention_heads,
329
+ head_dim=attention_head_dim * 2,
330
+ dropout=dropout,
331
+ attention_bias=False,
332
+ cross_attention_dim=cross_attention_dim,
333
+ use_ada_layer_norm_single=use_ada_single,
334
+ )
335
+ for _ in range(num_layers_2)
336
+ ]
337
+ )
338
+
339
+ self.connection_proj = ProjectLayer(
340
+ in_channels + inner_dim, inner_dim_2, kernel_size=3
341
+ )
342
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
343
+ self.norm_out_2 = nn.LayerNorm(inner_dim_2, elementwise_affine=False, eps=1e-6)
344
+ self.scale_shift_table = nn.Parameter(
345
+ torch.randn(2, inner_dim) / inner_dim**0.5
346
+ )
347
+ self.scale_shift_table_2 = nn.Parameter(
348
+ torch.randn(2, inner_dim_2) / inner_dim_2**0.5
349
+ )
350
+ self.proj_out = ProjectLayer(inner_dim_2, out_channels, kernel_size=3)
351
+ self.adaln_single = AdaLayerNormSingleFlow(inner_dim)
352
+ self.adaln_single_2 = AdaLayerNormSingleFlow(inner_dim_2)
353
+
354
+ def forward(
355
+ self,
356
+ hidden_states: torch.Tensor,
357
+ timestep: Optional[torch.LongTensor] = None,
358
+ ):
359
+ s = self.proj_in(hidden_states)
360
+
361
+ embedded_timestep = None
362
+ timestep_mod = None
363
+ if self.adaln_single is not None and timestep is not None:
364
+ batch_size = s.shape[0]
365
+ timestep_mod, embedded_timestep = self.adaln_single(
366
+ timestep, hidden_dtype=s.dtype
367
+ )
368
+ for blk in self.transformer_blocks:
369
+ s = blk(s, timestep=timestep_mod)
370
+
371
+ if embedded_timestep is None:
372
+ embedded_timestep = torch.zeros(
373
+ s.size(0), s.size(-1), device=s.device, dtype=s.dtype
374
+ )
375
+
376
+ shift, scale = (
377
+ self.scale_shift_table[None] + embedded_timestep[:, None]
378
+ ).chunk(2, dim=1)
379
+ s = self.norm_out(s)
380
+ s = s * (1 + scale) + shift
381
+
382
+ x = torch.cat([hidden_states, s], dim=-1)
383
+ x = self.connection_proj(x)
384
+
385
+ embedded_timestep_2 = None
386
+ timestep_mod_2 = None
387
+ if self.adaln_single_2 is not None and timestep is not None:
388
+ batch_size = x.shape[0]
389
+ timestep_mod_2, embedded_timestep_2 = self.adaln_single_2(
390
+ timestep, hidden_dtype=x.dtype
391
+ )
392
+ for blk in self.transformer_blocks_2:
393
+ x = blk(x, timestep=timestep_mod_2)
394
+
395
+ if embedded_timestep_2 is None:
396
+ embedded_timestep_2 = torch.zeros(
397
+ x.size(0), x.size(-1), device=x.device, dtype=x.dtype
398
+ )
399
+
400
+ shift_2, scale_2 = (
401
+ self.scale_shift_table_2[None] + embedded_timestep_2[:, None]
402
+ ).chunk(2, dim=1)
403
+ x = self.norm_out_2(x)
404
+ x = x * (1 + scale_2) + shift_2
405
+
406
+ out = self.proj_out(x)
407
+
408
+ return out
409
+
410
+
411
+ class PixArtAlphaCombinedFlowEmbeddings(nn.Module):
412
+ def __init__(self, embedding_dim: int, size_emb_dim: int):
413
+ super().__init__()
414
+ self.flow_t_size = 512
415
+ self.outdim = size_emb_dim
416
+ self.timestep_embedder = TimestepEmbedding(
417
+ in_channels=self.flow_t_size, time_embed_dim=embedding_dim
418
+ )
419
+
420
+ def timestep_embedding(self, timesteps, max_period=10000, scale=1000):
421
+ half = self.flow_t_size // 2
422
+ freqs = torch.exp(
423
+ -math.log(max_period)
424
+ * torch.arange(start=0, end=half, device=timesteps.device)
425
+ / half
426
+ ).type(timesteps.type())
427
+ args = timesteps[:, None] * freqs[None] * scale
428
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
429
+ if self.flow_t_size % 2:
430
+ embedding = torch.cat(
431
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
432
+ )
433
+ return embedding
434
+
435
+ def forward(self, timestep, hidden_dtype):
436
+ timesteps_proj = self.timestep_embedding(timestep)
437
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))
438
+ conditioning = timesteps_emb
439
+ return conditioning
440
+
441
+
442
+ class AdaLayerNormSingleFlow(nn.Module):
443
+ def __init__(self, embedding_dim: int):
444
+ super().__init__()
445
+ self.emb = PixArtAlphaCombinedFlowEmbeddings(
446
+ embedding_dim, size_emb_dim=embedding_dim // 3
447
+ )
448
+ self.silu = nn.SiLU()
449
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
450
+
451
+ def forward(
452
+ self,
453
+ timestep: torch.Tensor,
454
+ hidden_dtype: Optional[torch.dtype] = None,
455
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
456
+
457
+ embedded_timestep = self.emb(timestep, hidden_dtype=hidden_dtype)
458
+ return self.linear(self.silu(embedded_timestep)), embedded_timestep
459
+
460
+
461
+ class TimestepEmbedding(nn.Module):
462
+ def __init__(self, in_channels: int, time_embed_dim: int):
463
+ super().__init__()
464
+ self.linear_1 = nn.Linear(in_channels, time_embed_dim)
465
+ self.act = nn.SiLU()
466
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
467
+
468
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
469
+ x = self.linear_1(x)
470
+ x = self.act(x)
471
+ x = self.linear_2(x)
472
+ return x
473
+
474
+
475
+ class Timesteps(nn.Module):
476
+ def __init__(
477
+ self,
478
+ num_channels: int,
479
+ flip_sin_to_cos: bool = True,
480
+ downscale_freq_shift: float = 0,
481
+ ):
482
+ super().__init__()
483
+ self.num_channels = num_channels
484
+ self.flip_sin_to_cos = flip_sin_to_cos
485
+ self.downscale_freq_shift = downscale_freq_shift
486
+
487
+ def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
488
+ half_dim = self.num_channels // 2
489
+ exponent = (
490
+ -math.log(10000)
491
+ * torch.arange(0, half_dim, device=timesteps.device)
492
+ / (half_dim - self.downscale_freq_shift)
493
+ )
494
+ emb = torch.exp(exponent)[None, :] * timesteps[:, None]
495
+ if self.flip_sin_to_cos:
496
+ emb = torch.cat([torch.cos(emb), torch.sin(emb)], dim=-1)
497
+ else:
498
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
499
+ if self.num_channels % 2 == 1:
500
+ emb = torch.nn.functional.pad(emb, (0, 1))
501
+ return emb
src/heartlib/heartmula/configuration_heartmula.py ADDED
@@ -0,0 +1,23 @@
1
+ from transformers.configuration_utils import PretrainedConfig
2
+
3
+
4
+ class HeartMuLaConfig(PretrainedConfig):
5
+ model_type = "heartmula"
6
+
7
+ def __init__(
8
+ self,
9
+ backbone_flavor: str = "llama-3B",
10
+ decoder_flavor: str = "llama-300M",
11
+ text_vocab_size: int = 128256,
12
+ audio_vocab_size: int = 8197,
13
+ audio_num_codebooks: int = 8,
14
+ muq_dim: int = 512,
15
+ **kwargs
16
+ ):
17
+ super().__init__(**kwargs)
18
+ self.backbone_flavor = backbone_flavor
19
+ self.decoder_flavor = decoder_flavor
20
+ self.text_vocab_size = text_vocab_size
21
+ self.audio_vocab_size = audio_vocab_size
22
+ self.audio_num_codebooks = audio_num_codebooks
23
+ self.muq_dim = muq_dim
src/heartlib/heartmula/modeling_heartmula.py ADDED
@@ -0,0 +1,316 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from .configuration_heartmula import HeartMuLaConfig
4
+ from transformers.modeling_utils import PreTrainedModel
5
+ import torch
6
+ import torch.nn as nn
7
+ import torchtune
8
+ from torchtune.models import llama3_2
9
+
10
+
11
+ def llama3_2_3B() -> torchtune.modules.transformer.TransformerDecoder:
12
+ return llama3_2.llama3_2(
13
+ vocab_size=128_256,
14
+ num_layers=28,
15
+ num_heads=24,
16
+ num_kv_heads=8,
17
+ embed_dim=3072,
18
+ max_seq_len=8192,
19
+ intermediate_dim=8192,
20
+ attn_dropout=0.0,
21
+ norm_eps=1e-5,
22
+ rope_base=500_000,
23
+ scale_factor=32,
24
+ )
25
+
26
+
27
+ def llama3_2_300M() -> torchtune.modules.transformer.TransformerDecoder:
28
+ return llama3_2.llama3_2(
29
+ vocab_size=128_256,
30
+ num_layers=3,
31
+ num_heads=8,
32
+ num_kv_heads=4,
33
+ embed_dim=3072,
34
+ max_seq_len=2048,
35
+ intermediate_dim=8192,
36
+ attn_dropout=0.0,
37
+ norm_eps=1e-5,
38
+ rope_base=500_000,
39
+ scale_factor=32,
40
+ )
41
+
42
+
43
+ def llama3_2_7B() -> torchtune.modules.transformer.TransformerDecoder:
44
+ return llama3_2.llama3_2(
45
+ vocab_size=128_256,
46
+ num_layers=32,
47
+ num_heads=32,
48
+ num_kv_heads=8,
49
+ embed_dim=4096,
50
+ max_seq_len=8192,
51
+ intermediate_dim=14336,
52
+ attn_dropout=0.0,
53
+ norm_eps=1e-5,
54
+ rope_base=500_000,
55
+ scale_factor=32,
56
+ )
57
+
58
+
59
+ def llama3_2_400M() -> torchtune.modules.transformer.TransformerDecoder:
60
+ return llama3_2.llama3_2(
61
+ vocab_size=128_256,
62
+ num_layers=4,
63
+ num_heads=8,
64
+ num_kv_heads=4,
65
+ embed_dim=3072,
66
+ max_seq_len=2048,
67
+ intermediate_dim=8192,
68
+ attn_dropout=0.0,
69
+ norm_eps=1e-5,
70
+ rope_base=500_000,
71
+ scale_factor=32,
72
+ ) # Reduces the ratio between num_heads and num_kv_heads, improving accuracy but lowering efficiency
73
+
74
+
75
+ FLAVORS = {
76
+ "llama-3B": llama3_2_3B,
77
+ "llama-300M": llama3_2_300M,
78
+ "llama-7B": llama3_2_7B,
79
+ "llama-400M": llama3_2_400M,
80
+ }
81
+
82
+
83
+ def _prepare_transformer(model):
84
+ embed_dim = model.tok_embeddings.embedding_dim
85
+ model.tok_embeddings = nn.Identity()
86
+ model.output = nn.Identity()
87
+ return model, embed_dim
88
+
89
+
90
+ def _create_causal_mask(seq_len: int, device: torch.device):
91
+ return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))
92
+
93
+
94
+ def _index_causal_mask(mask: torch.Tensor, input_pos: torch.Tensor):
95
+ r = mask[input_pos, :]
96
+ return r
97
+
98
+
99
+ def _multinomial_sample_one_no_sync(
100
+ probs,
101
+ ): # Does multinomial sampling without a cuda synchronization
102
+ q = torch.empty_like(probs).exponential_(1)
103
+ return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int)
104
+
105
+
106
+ def sample_topk(logits: torch.Tensor, topk: int, temperature: float):
107
+ logits = logits / temperature
108
+
109
+ filter_value: float = -float("Inf")
110
+ indices_to_remove = logits < torch.topk(logits, topk)[0][..., -1, None]
111
+ scores_processed = logits.masked_fill(indices_to_remove, filter_value)
112
+ scores_processed = torch.nn.functional.log_softmax(scores_processed, dim=-1)
113
+ probs = torch.nn.functional.softmax(scores_processed, dim=-1)
114
+
115
+ sample_token = _multinomial_sample_one_no_sync(probs)
116
+ return sample_token
117
+
118
+
119
+ class HeartMuLa(PreTrainedModel):
120
+ config_class = HeartMuLaConfig
121
+
122
+ def __init__(
123
+ self,
124
+ config: HeartMuLaConfig,
125
+ ):
126
+ super(HeartMuLa, self).__init__(config)
127
+
128
+ self.config = config
129
+
130
+ self.backbone, backbone_dim = _prepare_transformer(
131
+ FLAVORS[config.backbone_flavor]()
132
+ )
133
+ self.decoder, decoder_dim = _prepare_transformer(
134
+ FLAVORS[config.decoder_flavor]()
135
+ )
136
+
137
+ self.text_embeddings = nn.Embedding(config.text_vocab_size, backbone_dim)
138
+ self.audio_embeddings = nn.Embedding(
139
+ config.audio_vocab_size * config.audio_num_codebooks, backbone_dim
140
+ )
141
+ self.unconditional_text_embedding = nn.Embedding(1, backbone_dim)
142
+
143
+ self.projection = nn.Linear(backbone_dim, decoder_dim, bias=False)
144
+ self.codebook0_head = nn.Linear(
145
+ backbone_dim, config.audio_vocab_size, bias=False
146
+ )
147
+ self.audio_head = nn.Parameter(
148
+ torch.empty(
149
+ config.audio_num_codebooks - 1, decoder_dim, config.audio_vocab_size
150
+ )
151
+ )
152
+ self.muq_linear = nn.Linear(config.muq_dim, backbone_dim)
153
+ self.post_init()
154
+
155
+ def setup_caches(self, max_batch_size: int):
156
+ dtype = next(self.parameters()).dtype
157
+ device = next(self.parameters()).device
158
+
159
+ try:
160
+ self.reset_caches()
161
+ except RuntimeError:
162
+ pass
163
+
164
+ with device:
165
+ self.backbone.setup_caches(max_batch_size, dtype)
166
+ self.decoder.setup_caches(
167
+ max_batch_size,
168
+ dtype,
169
+ decoder_max_seq_len=self.config.audio_num_codebooks,
170
+ )
171
+
172
+ self.register_buffer(
173
+ "backbone_causal_mask",
174
+ _create_causal_mask(self.backbone.max_seq_len, device),
175
+ )
176
+ self.register_buffer(
177
+ "decoder_causal_mask",
178
+ _create_causal_mask(self.config.audio_num_codebooks, device),
179
+ )
180
+
181
+ def generate_frame(
182
+ self,
183
+ tokens: torch.Tensor,
184
+ tokens_mask: torch.Tensor,
185
+ input_pos: torch.Tensor,
186
+ temperature: float,
187
+ topk: int,
188
+ cfg_scale: float,
189
+ continuous_segments: torch.Tensor = None,
190
+ starts=None,
191
+ ) -> torch.Tensor:
192
+ b, s, _ = tokens.size()
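+ # Two-stage frame decode as written below: the backbone consumes the summed per-codebook
+ # embeddings plus the text stream and predicts codebook 0 for the next frame; the small
+ # decoder then emits codebooks 1..N-1 autoregressively within the frame. When cfg_scale > 1
+ # the batch holds a conditional and an unconditional copy and logits are blended as
+ # uncond + cfg_scale * (cond - uncond).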
193
+
194
+ assert self.backbone.caches_are_enabled(), "backbone caches are not enabled"
195
+ curr_backbone_mask = _index_causal_mask(self.backbone_causal_mask, input_pos)
196
+
197
+ uncond_mask = None
198
+ if cfg_scale > 1.0 and b > 1:
199
+ actual_B = b // 2
200
+ uncond_mask = torch.cat(
201
+ [
202
+ torch.zeros(actual_B, dtype=torch.bool, device=tokens.device),
203
+ torch.ones(actual_B, dtype=torch.bool, device=tokens.device),
204
+ ]
205
+ )
206
+
207
+ embeds = self._embed_tokens(tokens, uncond_mask=uncond_mask)
208
+ masked_embeds = embeds * tokens_mask.unsqueeze(-1)
209
+ h = masked_embeds.sum(dim=2, dtype=embeds.dtype) # merge
210
+ if continuous_segments is not None:
211
+ continuous_segments = self.muq_linear(continuous_segments)
212
+ if uncond_mask is not None:
213
+ uncond_embed = self.unconditional_text_embedding(
214
+ torch.zeros(1, device=tokens.device, dtype=torch.long)
215
+ )
216
+ mask_expanded = uncond_mask.view(b, 1).expand_as(continuous_segments)
217
+ continuous_segments = torch.where(
218
+ mask_expanded, uncond_embed, continuous_segments
219
+ )
220
+ batch_indices = torch.arange(h.shape[0], device=h.device)
221
+ h[batch_indices, starts] = continuous_segments
222
+ h = self.backbone(h, input_pos=input_pos, mask=curr_backbone_mask)
223
+ last_h = h[:, -1, :] # the last frame
224
+ c0_logits = self.codebook0_head(last_h) # only predict the audio part
225
+
226
+ if cfg_scale > 1.0 and b > 1 and (b % 2 == 0):
227
+ actual_B = b // 2
228
+ cond_logits = c0_logits[:actual_B, :]
229
+ uncond_logits = c0_logits[actual_B:, :]
230
+ guided_logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
231
+ c0_sample = sample_topk(guided_logits, topk, temperature)
232
+ c0_sample = c0_sample.repeat(
233
+ 2, 1
234
+ ) # repeat to both branches to keep alignment
235
+ else:
236
+ c0_sample = sample_topk(c0_logits, topk, temperature)
237
+
238
+ c0_embed = self._embed_audio(0, c0_sample)
239
+
240
+ self.decoder.reset_caches()
241
+ curr_h = torch.cat([last_h.unsqueeze(1), c0_embed], dim=1)
242
+ curr_sample = c0_sample.clone()
243
+ curr_pos = (
244
+ torch.arange(0, curr_h.size(1), device=curr_h.device)
245
+ .unsqueeze(0)
246
+ .repeat(curr_h.size(0), 1)
247
+ )
248
+ curr_h = curr_h.to(embeds.dtype)
249
+ for i in range(1, self.config.audio_num_codebooks):
250
+ curr_decoder_mask = _index_causal_mask(self.decoder_causal_mask, curr_pos)
251
+ decoder_h = self.decoder(
252
+ self.projection(curr_h), input_pos=curr_pos, mask=curr_decoder_mask
253
+ )
254
+ ci_logits = torch.mm(decoder_h[:, -1, :], self.audio_head[i - 1])
255
+ if cfg_scale > 1.0 and b > 1 and (b % 2 == 0):
256
+ actual_B = b // 2
257
+ cond_ci = ci_logits[:actual_B, :]
258
+ uncond_ci = ci_logits[actual_B:, :]
259
+ guided_ci = uncond_ci + (cond_ci - uncond_ci) * cfg_scale
260
+
261
+ ci_sample = sample_topk(guided_ci, topk, temperature)
262
+ ci_sample = ci_sample.repeat(2, 1)
263
+ else:
264
+ ci_sample = sample_topk(ci_logits, topk, temperature)
265
+ ci_embed = self._embed_audio(i, ci_sample)
266
+ curr_h = ci_embed
267
+ curr_sample = torch.cat([curr_sample, ci_sample], dim=1)
268
+ curr_pos = curr_pos[:, -1:] + 1
269
+
270
+ return curr_sample
271
+
272
+ def reset_caches(self):
273
+ self.backbone.reset_caches()
274
+ self.decoder.reset_caches()
275
+
276
+ def _embed_local_audio(self, tokens):
277
+ """the token from 0-30"""
278
+ audio_tokens = tokens + (
279
+ self.config.audio_vocab_size
280
+ * torch.arange(self.config.audio_num_codebooks - 1, device=tokens.device)
281
+ )
282
+ audio_embeds = self.audio_embeddings(audio_tokens.view(-1)).reshape(
283
+ tokens.size(0), tokens.size(1), self.config.audio_num_codebooks - 1, -1
284
+ )
285
+ return audio_embeds
286
+
287
+ def _embed_audio(self, codebook: int, tokens: torch.Tensor) -> torch.Tensor:
288
+ return self.audio_embeddings(tokens + codebook * self.config.audio_vocab_size)
289
+
290
+ def _embed_tokens(
291
+ self, tokens: torch.Tensor, uncond_mask: torch.Tensor | None
292
+ ) -> torch.Tensor:
293
+ B, S, _ = tokens.size()
294
+ text_embeds = self.text_embeddings(tokens[:, :, -1])
295
+
296
+ if uncond_mask is not None:
297
+ uncond_text_embed = self.unconditional_text_embedding(
298
+ torch.zeros(1, device=tokens.device, dtype=torch.long)
299
+ )
300
+ mask_expanded = uncond_mask.view(B, 1, 1).expand_as(text_embeds)
301
+ text_embeds = torch.where(
302
+ mask_expanded,
303
+ uncond_text_embed,
304
+ text_embeds,
305
+ )
306
+
307
+ text_embeds = text_embeds.unsqueeze(-2)
308
+
309
+ audio_tokens = tokens[:, :, :-1] + (
310
+ self.config.audio_vocab_size
311
+ * torch.arange(self.config.audio_num_codebooks, device=tokens.device)
312
+ )
313
+ audio_embeds = self.audio_embeddings(audio_tokens.view(-1)).reshape(
314
+ tokens.size(0), tokens.size(1), self.config.audio_num_codebooks, -1
315
+ )
316
+ return torch.cat([audio_embeds, text_embeds], dim=-2)
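For orientation, here is a minimal sketch of the classifier-free-guidance step used in the frame-generation code above: when `cfg_scale > 1.0`, the batch stacks a conditional and an unconditional branch, the two sets of logits are blended, and the sampled token is mirrored onto both branches so their KV caches stay aligned. The `sample_topk` below is a simplified stand-in for the helper the model code calls (its definition is not shown in this hunk), and `cfg_blend_and_sample` is a hypothetical name used only for this sketch.

```python
import torch

def sample_topk(logits: torch.Tensor, topk: int, temperature: float) -> torch.Tensor:
    # Simplified stand-in: keep the top-k logits, renormalize, and sample one token per row.
    values, indices = torch.topk(logits / temperature, k=topk, dim=-1)
    probs = torch.softmax(values, dim=-1)
    choice = torch.multinomial(probs, num_samples=1)
    return indices.gather(-1, choice)

def cfg_blend_and_sample(
    logits: torch.Tensor, cfg_scale: float, topk: int, temperature: float
) -> torch.Tensor:
    # Mirrors the guidance branch above: rows [:b//2] are conditional, rows [b//2:] unconditional.
    b = logits.shape[0]
    if cfg_scale > 1.0 and b > 1 and b % 2 == 0:
        half = b // 2
        cond, uncond = logits[:half], logits[half:]
        guided = uncond + (cond - uncond) * cfg_scale
        sample = sample_topk(guided, topk, temperature)
        return sample.repeat(2, 1)  # copy the sample onto both branches to keep them aligned
    return sample_topk(logits, topk, temperature)

# e.g. cfg_blend_and_sample(torch.randn(2, 8192), cfg_scale=1.5, topk=50, temperature=1.0)
# (the vocabulary size here is arbitrary, chosen only for the example)
```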
src/heartlib/pipelines/__init__.py ADDED
File without changes
src/heartlib/pipelines/lyrics_transcription.py ADDED
@@ -0,0 +1,40 @@
1
+ from transformers.pipelines.automatic_speech_recognition import (
2
+ AutomaticSpeechRecognitionPipeline,
3
+ )
4
+ from transformers.models.whisper.modeling_whisper import WhisperForConditionalGeneration
5
+ from transformers.models.whisper.processing_whisper import WhisperProcessor
6
+ import torch
7
+ import os
8
+
9
+
10
+ class HeartTranscriptorPipeline(AutomaticSpeechRecognitionPipeline):
11
+ def __init__(self, *args, **kwargs):
12
+ super().__init__(*args, **kwargs)
13
+
14
+ @classmethod
15
+ def from_pretrained(
16
+ cls, pretrained_path: str, device: torch.device, dtype: torch.dtype
17
+ ):
18
+ if os.path.exists(
19
+ hearttranscriptor_path := os.path.join(
20
+ pretrained_path, "HeartTranscriptor-oss"
21
+ )
22
+ ):
23
+ model = WhisperForConditionalGeneration.from_pretrained(
24
+ hearttranscriptor_path, torch_dtype=dtype, low_cpu_mem_usage=True
25
+ )
26
+ processor = WhisperProcessor.from_pretrained(hearttranscriptor_path)
27
+ else:
28
+ raise FileNotFoundError(
29
+ f"Expected to find checkpoint for HeartTranscriptor at {hearttranscriptor_path} but not found. Please check your folder {pretrained_path}."
30
+ )
31
+
32
+ return cls(
33
+ model=model,
34
+ tokenizer=processor.tokenizer,
35
+ feature_extractor=processor.feature_extractor,
36
+ device=device,
37
+ dtype=dtype,
38
+ chunk_length_s=30,
39
+ batch_size=16,
40
+ )
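A hypothetical usage sketch for the transcription pipeline above; the checkpoint directory and audio file names are placeholders, not files shipped with this commit. `from_pretrained` expects a folder containing `HeartTranscriptor-oss/`, and the resulting object behaves like a standard transformers ASR pipeline with the 30-second chunking and batch size 16 fixed above.

```python
import torch
from heartlib import HeartTranscriptorPipeline

pipe = HeartTranscriptorPipeline.from_pretrained(
    pretrained_path="./ckpt",          # must contain a HeartTranscriptor-oss/ subfolder
    device=torch.device("cuda:0"),
    dtype=torch.float16,
)

# Standard transformers ASR pipeline call; chunk_length_s=30 and batch_size=16
# are set by from_pretrained above.
result = pipe("my_song.mp3")
print(result["text"])
```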
src/heartlib/pipelines/music_generation.py ADDED
@@ -0,0 +1,383 @@
1
+ from tokenizers import Tokenizer
2
+ from ..heartmula.modeling_heartmula import HeartMuLa
3
+ from ..heartcodec.modeling_heartcodec import HeartCodec
4
+ import torch
5
+ from typing import Dict, Any, Optional, Union
6
+ import os
7
+ from dataclasses import dataclass
8
+ from tqdm import tqdm
9
+ import torchaudio
10
+ import json
11
+ from contextlib import contextmanager
12
+ import gc
13
+
14
+
15
+ def _resolve_paths(pretrained_path: str, version: str):
16
+
17
+ heartmula_path = os.path.join(pretrained_path, f"HeartMuLa-oss-{version}")
18
+ heartcodec_path = os.path.join(pretrained_path, "HeartCodec-oss")
19
+ tokenizer_path = os.path.join(pretrained_path, "tokenizer.json")
20
+ gen_config_path = os.path.join(pretrained_path, "gen_config.json")
21
+
22
+ if not os.path.exists(heartmula_path):
23
+ raise FileNotFoundError(
24
+ f"Expected to find checkpoint for HeartMuLa at {heartmula_path} but not found. Please check your folder {pretrained_path}."
25
+ )
26
+ if not os.path.exists(heartcodec_path):
27
+ raise FileNotFoundError(
28
+ f"Expected to find checkpoint for HeartCodec at {heartcodec_path} but not found. Please check your folder {pretrained_path}."
29
+ )
30
+ if not os.path.isfile(tokenizer_path):
31
+ raise FileNotFoundError(
32
+ f"Expected to find tokenizer.json for HeartMuLa at {tokenizer_path} but not found. Please check your folder {pretrained_path}."
33
+ )
34
+ if not os.path.isfile(gen_config_path):
35
+ raise FileNotFoundError(
36
+ f"Expected to find gen_config.json for HeartMuLa at {gen_config_path} but not found. Please check your folder {pretrained_path}."
37
+ )
38
+
39
+ return heartmula_path, heartcodec_path, tokenizer_path, gen_config_path
40
+
41
+
42
+ def _resolve_devices(
43
+ device: Union[torch.device, Dict[str, torch.device]], lazy_load: bool
44
+ ):
45
+ if isinstance(device, torch.device):
46
+ print(f"All model components will be loaded to device: {device}.")
47
+ mula_device = device
48
+ codec_device = device
49
+ elif isinstance(device, dict):
50
+ print("Model components will be loaded to devices as specified:")
51
+ for k, v in device.items():
52
+ print(f" {k}: {v}")
53
+ mula_device = device["mula"]
54
+ codec_device = device["codec"]
55
+ else:
56
+ raise ValueError(
57
+ "device must be either torch.device or Dict[str, torch.device]"
58
+ )
59
+
60
+ single_device = mula_device == codec_device
61
+ if not single_device:
62
+ print(
63
+ f"HeartMuLa and HeartCodec will be loaded to different devices. In this case, lazy_load is turned off."
64
+ )
65
+ lazy_load = False
66
+
67
+ return mula_device, codec_device, lazy_load
68
+
69
+
70
+ @dataclass
71
+ class HeartMuLaGenConfig:
72
+ text_bos_id: int = 128000
73
+ text_eos_id: int = 128001
74
+ audio_eos_id: int = 8193
75
+ empty_id: int = 0
76
+
77
+ @classmethod
78
+ def from_file(cls, path: str):
79
+ with open(path, encoding="utf-8") as fp:
80
+ data = json.load(fp)
81
+ return cls(**data)
82
+
83
+
84
+ class HeartMuLaGenPipeline:
85
+ def __init__(
86
+ self,
87
+ heartmula_path: str,
88
+ heartcodec_path: str,
89
+ heartmula_device: torch.device,
90
+ heartcodec_device: torch.device,
91
+ heartmula_dtype: torch.dtype,
92
+ heartcodec_dtype: torch.dtype,
93
+ lazy_load: bool,
94
+ muq_mulan: Optional[Any],
95
+ text_tokenizer: Tokenizer,
96
+ config: HeartMuLaGenConfig,
97
+ ):
98
+
99
+ self.muq_mulan = muq_mulan
100
+ self.text_tokenizer = text_tokenizer
101
+ self.config = config
102
+
103
+ # Kept fixed here for simplicity: 8 audio codebook streams + 1 text stream, and a 512-dim MuQ embedding.
104
+ self._parallel_number = 8 + 1
105
+ self._muq_dim = 512
106
+
107
+ self.mula_dtype = heartmula_dtype
108
+ self.mula_path = heartmula_path
109
+ self.mula_device = heartmula_device
110
+ self.codec_dtype = heartcodec_dtype
111
+ self.codec_path = heartcodec_path
112
+ self.codec_device = heartcodec_device
113
+
114
+ self._mula: Optional[HeartMuLa] = None
115
+ self._codec: Optional[HeartCodec] = None
116
+ if not lazy_load:
117
+ print(
118
+ f"You have set lazy_load = False. Loading HeartMuLa and HeartCodec onto device..."
119
+ )
120
+ self._mula = HeartMuLa.from_pretrained(
121
+ self.mula_path,
122
+ device_map=self.mula_device,
123
+ torch_dtype=self.mula_dtype,
124
+ )
125
+ self._codec = HeartCodec.from_pretrained(
126
+ self.codec_path,
127
+ device_map=self.codec_device,
128
+ torch_dtype=self.codec_dtype,
129
+ )
130
+ self.lazy_load = lazy_load
131
+
132
+ @property
133
+ def mula(self) -> HeartMuLa:
134
+ if isinstance(self._mula, HeartMuLa):
135
+ return self._mula
136
+ self._mula = HeartMuLa.from_pretrained(
137
+ self.mula_path,
138
+ device_map=self.mula_device,
139
+ torch_dtype=self.mula_dtype,
140
+ )
141
+ return self._mula
142
+
143
+ @property
144
+ def codec(self) -> HeartCodec:
145
+ if isinstance(self._codec, HeartCodec):
146
+ return self._codec
147
+ self._codec = HeartCodec.from_pretrained(
148
+ self.codec_path,
149
+ device_map=self.codec_device,
150
+ torch_dtype=self.codec_dtype,
151
+ )
152
+ return self._codec
153
+
154
+ def _unload(self):
155
+ if not self.lazy_load:
156
+ return
157
+ if isinstance(self._mula, HeartMuLa):
158
+ print(f"You have set lazy_load=True. Unloading HeartMuLa from device.")
159
+ print(
160
+ f"CUDA memory before unloading: {torch.cuda.memory_allocated(self.mula_device) / 1024**3:.2f} GB"
161
+ )
162
+ del self._mula
163
+ gc.collect()
164
+ torch.cuda.empty_cache()
165
+ print(
166
+ f"CUDA memory after unloading: {torch.cuda.memory_allocated(self.mula_device) / 1024**3:.2f} GB"
167
+ )
168
+ self._mula = None
169
+ if isinstance(self._codec, HeartCodec):
170
+ print(f"You have set lazy_load=True. Unloading HeartCodec from device.")
171
+ print(
172
+ f"CUDA memory before unloading: {torch.cuda.memory_allocated(self.codec_device) / 1024**3:.2f} GB"
173
+ )
174
+ del self._codec
175
+ gc.collect()
176
+ torch.cuda.empty_cache()
177
+ print(
178
+ f"CUDA memory after unloading: {torch.cuda.memory_allocated(self.codec_device) / 1024**3:.2f} GB"
179
+ )
180
+ self._codec = None
181
+ return
182
+
183
+ def _sanitize_parameters(self, **kwargs):
184
+ preprocess_kwargs = {"cfg_scale": kwargs.get("cfg_scale", 1.5)}
185
+ forward_kwargs = {
186
+ "max_audio_length_ms": kwargs.get("max_audio_length_ms", 120_000),
187
+ "temperature": kwargs.get("temperature", 1.0),
188
+ "topk": kwargs.get("topk", 50),
189
+ "cfg_scale": kwargs.get("cfg_scale", 1.5),
190
+ }
191
+ postprocess_kwargs = {
192
+ "save_path": kwargs.get("save_path", "output.mp3"),
193
+ }
194
+ return preprocess_kwargs, forward_kwargs, postprocess_kwargs
195
+
196
+ def preprocess(self, inputs: Dict[str, Any], cfg_scale: float):
197
+
198
+ # process tags
199
+ tags = inputs["tags"]
200
+ if os.path.isfile(tags):
201
+ with open(tags, encoding="utf-8") as fp:
202
+ tags = fp.read()
203
+ assert isinstance(tags, str), f"tags must be a string, but got {type(tags)}"
204
+
205
+ tags = tags.lower()
206
+ # encapsulate with special <tag> and </tag> tokens
207
+ if not tags.startswith("<tag>"):
208
+ tags = f"<tag>{tags}"
209
+ if not tags.endswith("</tag>"):
210
+ tags = f"{tags}</tag>"
211
+
212
+ tags_ids = self.text_tokenizer.encode(tags).ids
213
+ if tags_ids[0] != self.config.text_bos_id:
214
+ tags_ids = [self.config.text_bos_id] + tags_ids
215
+ if tags_ids[-1] != self.config.text_eos_id:
216
+ tags_ids = tags_ids + [self.config.text_eos_id]
217
+
218
+ # process reference audio
219
+ ref_audio = inputs.get("ref_audio", None)
220
+ if ref_audio is not None:
221
+ raise NotImplementedError("ref_audio is not supported yet.")
222
+ muq_embed = torch.zeros([self._muq_dim], dtype=self.mula_dtype)
223
+ muq_idx = len(tags_ids)
224
+
225
+ # process lyrics
226
+ lyrics = inputs["lyrics"]
227
+ if os.path.isfile(lyrics):
228
+ with open(lyrics, encoding="utf-8") as fp:
229
+ lyrics = fp.read()
230
+ assert isinstance(
231
+ lyrics, str
232
+ ), f"lyrics must be a string, but got {type(lyrics)}"
233
+ lyrics = lyrics.lower()
234
+
235
+ lyrics_ids = self.text_tokenizer.encode(lyrics).ids
236
+ if lyrics_ids[0] != self.config.text_bos_id:
237
+ lyrics_ids = [self.config.text_bos_id] + lyrics_ids
238
+ if lyrics_ids[-1] != self.config.text_eos_id:
239
+ lyrics_ids = lyrics_ids + [self.config.text_eos_id]
240
+
241
+ # cat them together. tags, ref_audio, lyrics
242
+ prompt_len = len(tags_ids) + 1 + len(lyrics_ids)
243
+
244
+ tokens = torch.zeros([prompt_len, self._parallel_number], dtype=torch.long)
245
+ tokens[: len(tags_ids), -1] = torch.tensor(tags_ids)
246
+ tokens[len(tags_ids) + 1 :, -1] = torch.tensor(lyrics_ids)
247
+
248
+ tokens_mask = torch.zeros_like(tokens, dtype=torch.bool)
249
+ tokens_mask[:, -1] = True
250
+
251
+ bs_size = 2 if cfg_scale != 1.0 else 1
252
+
253
+ def _cfg_cat(tensor: torch.Tensor, cfg_scale: float):
254
+ tensor = tensor.unsqueeze(0)
255
+ if cfg_scale != 1.0:
256
+ tensor = torch.cat([tensor, tensor], dim=0)
257
+ return tensor
258
+
259
+ return {
260
+ "tokens": _cfg_cat(tokens, cfg_scale),
261
+ "tokens_mask": _cfg_cat(tokens_mask, cfg_scale),
262
+ "muq_embed": _cfg_cat(muq_embed, cfg_scale),
263
+ "muq_idx": [muq_idx] * bs_size,
264
+ "pos": _cfg_cat(torch.arange(prompt_len, dtype=torch.long), cfg_scale),
265
+ }
266
+
267
+ def _forward(
268
+ self,
269
+ model_inputs: Dict[str, Any],
270
+ max_audio_length_ms: int,
271
+ temperature: float,
272
+ topk: int,
273
+ cfg_scale: float,
274
+ ):
275
+ prompt_tokens = model_inputs["tokens"].to(self.mula_device)
276
+ prompt_tokens_mask = model_inputs["tokens_mask"].to(self.mula_device)
277
+ continuous_segment = model_inputs["muq_embed"].to(self.mula_device)
278
+ starts = model_inputs["muq_idx"]
279
+ prompt_pos = model_inputs["pos"].to(self.mula_device)
280
+ frames = []
281
+
282
+ bs_size = 2 if cfg_scale != 1.0 else 1
283
+ self.mula.setup_caches(bs_size)
284
+ with torch.autocast(device_type=self.mula_device.type, dtype=self.mula_dtype):
285
+ curr_token = self.mula.generate_frame(
286
+ tokens=prompt_tokens,
287
+ tokens_mask=prompt_tokens_mask,
288
+ input_pos=prompt_pos,
289
+ temperature=temperature,
290
+ topk=topk,
291
+ cfg_scale=cfg_scale,
292
+ continuous_segments=continuous_segment,
293
+ starts=starts,
294
+ )
295
+ frames.append(curr_token[0:1,])
296
+
297
+ def _pad_audio_token(token: torch.Tensor):
298
+ padded_token = (
299
+ torch.ones(
300
+ (token.shape[0], self._parallel_number),
301
+ device=token.device,
302
+ dtype=torch.long,
303
+ )
304
+ * self.config.empty_id
305
+ )
306
+ padded_token[:, :-1] = token
307
+ padded_token = padded_token.unsqueeze(1)
308
+ padded_token_mask = torch.ones_like(
309
+ padded_token, device=token.device, dtype=torch.bool
310
+ )
311
+ padded_token_mask[..., -1] = False
312
+ return padded_token, padded_token_mask
313
+
314
+ max_audio_frames = max_audio_length_ms // 80  # each generated frame covers 80 ms of audio
315
+
316
+ for i in tqdm(range(max_audio_frames)):
317
+ curr_token, curr_token_mask = _pad_audio_token(curr_token)
318
+ with torch.autocast(
319
+ device_type=self.mula_device.type, dtype=self.mula_dtype
320
+ ):
321
+ curr_token = self.mula.generate_frame(
322
+ tokens=curr_token,
323
+ tokens_mask=curr_token_mask,
324
+ input_pos=prompt_pos[..., -1:] + i + 1,
325
+ temperature=temperature,
326
+ topk=topk,
327
+ cfg_scale=cfg_scale,
328
+ continuous_segments=None,
329
+ starts=None,
330
+ )
331
+ if torch.any(curr_token[0:1, :] >= self.config.audio_eos_id):
332
+ break
333
+ frames.append(curr_token[0:1,])
334
+ frames = torch.stack(frames).permute(1, 2, 0).squeeze(0)
335
+ self._unload()
336
+ return {"frames": frames}
337
+
338
+ def postprocess(self, model_outputs: Dict[str, Any], save_path: str):
339
+ frames = model_outputs["frames"].to(self.codec_device)
340
+ wav = self.codec.detokenize(frames)
341
+ self._unload()
342
+ torchaudio.save(save_path, wav.to(torch.float32).cpu(), 48000)
343
+
344
+ def __call__(self, inputs: Dict[str, Any], **kwargs):
345
+ preprocess_kwargs, forward_kwargs, postprocess_kwargs = (
346
+ self._sanitize_parameters(**kwargs)
347
+ )
348
+ model_inputs = self.preprocess(inputs, **preprocess_kwargs)
349
+ model_outputs = self._forward(model_inputs, **forward_kwargs)
350
+ self.postprocess(model_outputs, **postprocess_kwargs)
351
+
352
+ @classmethod
353
+ def from_pretrained(
354
+ cls,
355
+ pretrained_path: str,
356
+ device: Union[torch.device, Dict[str, torch.device]],
357
+ dtype: Union[torch.dtype, Dict[str, torch.dtype]],
358
+ version: str,
359
+ lazy_load: bool = False,
360
+ ):
361
+
362
+ mula_path, codec_path, tokenizer_path, gen_config_path = _resolve_paths(
363
+ pretrained_path, version
364
+ )
365
+ mula_device, codec_device, lazy_load = _resolve_devices(device, lazy_load)
366
+ tokenizer = Tokenizer.from_file(tokenizer_path)
367
+ gen_config = HeartMuLaGenConfig.from_file(gen_config_path)
368
+
369
+ mula_dtype = dtype["mula"] if isinstance(dtype, dict) else dtype
370
+ codec_dtype = dtype["codec"] if isinstance(dtype, dict) else dtype
371
+
372
+ return cls(
373
+ heartmula_path=mula_path,
374
+ heartcodec_path=codec_path,
375
+ heartmula_device=mula_device,
376
+ heartcodec_device=codec_device,
377
+ lazy_load=lazy_load,
378
+ muq_mulan=None,
379
+ text_tokenizer=tokenizer,
380
+ config=gen_config,
381
+ heartmula_dtype=mula_dtype,
382
+ heartcodec_dtype=codec_dtype,
383
+ )
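To close, a hypothetical end-to-end sketch of the generation pipeline defined above; the checkpoint directory, lyrics/tags strings, and output path are placeholders. The keyword arguments map directly onto `from_pretrained`, `_sanitize_parameters`, and `__call__` as written in this file.

```python
import torch
from heartlib import HeartMuLaGenPipeline

pipe = HeartMuLaGenPipeline.from_pretrained(
    pretrained_path="./ckpt",
    # A single device works; a {"mula": ..., "codec": ...} dict places the two models
    # on different devices, which also forces lazy_load=False (see _resolve_devices).
    device=torch.device("cuda:0"),
    dtype=torch.bfloat16,
    version="3B",        # resolves to ./ckpt/HeartMuLa-oss-3B
    lazy_load=True,      # load/unload each model around its stage to reduce peak VRAM
)

pipe(
    {"lyrics": "la la la", "tags": "pop, female vocal"},  # strings or paths to text files
    max_audio_length_ms=120_000,  # one generated frame per 80 ms
    temperature=1.0,
    topk=50,
    cfg_scale=1.5,                # >1.0 adds an unconditional branch for guidance
    save_path="output.mp3",
)
```

With `lazy_load=True` and a single device, HeartMuLa is loaded for token generation and released before HeartCodec is loaded for detokenization, trading load time for peak memory.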