Commit 0322512 · HeartMuLa Gradio Space deployment

Files changed:
- .dockerignore +18 -0
- .github/copilot-instructions.md +89 -0
- .gitignore +18 -0
- DEPLOYMENT_GUIDE.md +115 -0
- Dockerfile +30 -0
- LICENSE +201 -0
- MANUAL_DEPLOYMENT.md +55 -0
- README.md +255 -0
- README_SPACE.md +54 -0
- app.py +228 -0
- examples/README.md +15 -0
- examples/run_lyrics_transcription.py +35 -0
- examples/run_music_generation.py +81 -0
- pyproject.toml +46 -0
- requirements.txt +11 -0
- setup.sh +28 -0
- src/heartlib/__init__.py +7 -0
- src/heartlib/heartcodec/configuration_heartcodec.py +73 -0
- src/heartlib/heartcodec/modeling_heartcodec.py +183 -0
- src/heartlib/heartcodec/models/flow_matching.py +176 -0
- src/heartlib/heartcodec/models/sq_codec.py +539 -0
- src/heartlib/heartcodec/models/transformer.py +501 -0
- src/heartlib/heartmula/configuration_heartmula.py +23 -0
- src/heartlib/heartmula/modeling_heartmula.py +316 -0
- src/heartlib/pipelines/__init__.py +0 -0
- src/heartlib/pipelines/lyrics_transcription.py +40 -0
- src/heartlib/pipelines/music_generation.py +383 -0
.dockerignore
ADDED
@@ -0,0 +1,18 @@
+.git
+.github
+.venv
+.env
+__pycache__
+*.pyc
+*.pyo
+*.egg-info
+.pytest_cache
+.coverage
+*.mp3
+*.wav
+*.ogg
+assets/*.mp3
+assets/*.wav
+.gitignore
+README.md
+LICENSE
.github/copilot-instructions.md
ADDED
@@ -0,0 +1,89 @@
+# GitHub Copilot instructions for heartlib
+
+## What this repo is
+- HeartMuLa music generation stack: converts lyrics + style tags → audio via two-stage pipeline (HeartMuLa LLM → audio tokens, HeartCodec flow-matching codec → waveform).
+- Supports lyrics transcription via Whisper-based HeartTranscriptor.
+- Python package with main entry points: `heartlib.HeartMuLaGenPipeline` and `heartlib.HeartTranscriptorPipeline` (see `src/heartlib/__init__.py`).
+- Examples/reference CLIs in `examples/`; production use: install via `pip install -e .`
+
+## Core architecture & data flow
+**Music generation:** `Inputs(lyrics, tags)` → Tokenizer → HeartMuLa (frame-by-frame token generation) → HeartCodec (flow-matching detokenization) → MP3
+- **HeartMuLa** (LLaMA3.2 backbone, 3B/300M/7B flavors): generates 8 parallel audio codebook streams + 1 prompt guidance stream (9 total, `_parallel_number=9`)
+- **HeartCodec**: VQ codec that reconstructs waveforms from codebook frames in overlapping windows, fixed 48 kHz output
+- **HeartTranscriptor**: Whisper variant fine-tuned for vocal transcription; works on 30-second chunks, batch=16
+
+## Repo map
+- `src/heartlib/pipelines/music_generation.py`: orchestrates tokenization → HeartMuLa inference → HeartCodec detokenize → `torchaudio.save()`
+  - `HeartMuLaGenPipeline.from_pretrained()`: factory with device/dtype/lazy_load config
+  - `_resolve_paths()`: validates checkpoint layout early (hard error if missing)
+  - `_resolve_devices()`: handles scalar/dict device specs; forces `lazy_load=False` for multi-device
+- `src/heartlib/heartmula/modeling_heartmula.py`: backbone (llama3_2_3B/7B/300M factory functions), token generator with CFG support
+- `src/heartlib/heartcodec/modeling_heartcodec.py`: VQ codec + flow-matching decoder, detokenizes `(codebooks, time)` frames
+- `src/heartlib/pipelines/lyrics_transcription.py`: wraps transformers' Whisper; fixed chunk=30s, batch=16
+- `src/heartlib/heartmula/configuration_heartmula.py`, `src/heartlib/heartcodec/configuration_heartcodec.py`: model configs
+
+## Checkpoints & required layout
+Directory structure after downloads (see README or `hf download` commands):
+```
+./ckpt/
+  HeartMuLa-oss-3B/              (or -7B, -300M)
+    config.json
+    model-*.safetensors
+    model.safetensors.index.json
+  HeartCodec-oss/
+    config.json
+    model-*.safetensors
+    model.safetensors.index.json
+  HeartTranscriptor-oss/
+    config.json
+    pytorch_model.bin            (or safetensors)
+  tokenizer.json
+  gen_config.json
+```
+- `_resolve_paths(pretrained_path, version)` validates all required files; raises FileNotFoundError if missing
+- Latest checkpoint: HeartMuLa-RL-oss-3B-20260123 (RL-tuned, recommended for style control)
+
+## Generation pipeline behaviors to know
+- **Inputs:** dict with `lyrics`, `tags` (both strings or file paths); auto-lowercased, tags wrapped with `<tag>...</tag>` if missing
+- **Tokenization:** uses `tokenizers.Tokenizer` (from `tokenizer.json`); token IDs from `HeartMuLaGenConfig` (text_bos_id=128000, text_eos_id=128001, audio_eos_id=8193)
+- **CFG (classifier-free guidance):** if `cfg_scale != 1.0`, batch duplicated for unconditional pass (bs becomes 2×); enables style control tradeoff
+- **Audio generation loop:** runs max `max_audio_length_ms // 80` frames (~12.5 Hz generation rate); stops early if any token ≥ audio_eos_id
+- **Memory optimization:**
+  - `lazy_load=True` defers model loading, unloads after generation (saves CUDA between uses)
+  - Forced `lazy_load=False` if mula_device ≠ codec_device (different device types can't swap)
+  - Uses `torch.autocast` with specified dtype to reduce memory footprint
+- **Output:** via `torchaudio.save(save_path, wav, 48000)` at fixed 48 kHz sample rate
+
+## Codec specifics
+- `HeartCodec.detokenize(frames)` expects shape `(codebooks, time)` where codebooks ≤ 8; pads/repeats to uniform length internally
+- Uses flow-matching inference in overlapping windows (reduces boundary artifacts), then scalar decoder → PCM waveform
+- Fixed 48 kHz output; non-standard rates must be resampled post-generation
+- Model config (`config.json`) defines number of codebooks and codec architecture
+
+## Transcription pipeline behaviors
+- `HeartTranscriptorPipeline.from_pretrained(model_path, device, dtype)` wraps `WhisperForConditionalGeneration` from `HeartTranscriptor-oss`
+- Fixed at 30-second chunks, batch size 16; no dynamic chunking
+- Note: trained on separated vocals; best results with source-separated inputs (use demucs or similar pre-pipeline)
+- Supports beam search and temperature kwargs via `__call__()` decoding_kwargs (see `examples/run_lyrics_transcription.py`)
+
+## Dev workflows & commands
+- **Install:** `pip install -e .` (Python ≥3.9, 3.10 recommended; CUDA deps: torch 2.4.1, torchaudio 2.4.1, torchtune 0.4.0, bitsandbytes 0.49.0)
+- **Generate:** `python examples/run_music_generation.py --model_path ./ckpt --version 3B --lyrics ./assets/lyrics.txt --tags ./assets/tags.txt`
+  - Key flags: `--mula_device cuda --codec_device cuda` (or separate); `--lazy_load true` (single-GPU VRAM relief); `--cfg_scale 1.5` (style strength)
+- **Transcribe:** `python examples/run_lyrics_transcription.py --model_path ./ckpt --music_path ./assets/output.mp3`
+- No test suite; validate changes via example scripts
+
+## Coding conventions & critical patterns
+- **Device specs:** `from_pretrained(..., device=X)` accepts `torch.device` (both models→X) or dict `{"mula": dev1, "codec": dev2}` (forces `lazy_load=False`)
+- **Dtype specs:** mirrors device—scalar dtype or dict with `"mula"`, `"codec"` keys
+- **Token/text handling:** always lowercase inputs, auto-wrap tags with `<tag>...</tag>`, append BOS/EOS via tokenizer config (callers depend on this)
+- **Unimplemented:** reference audio path exists but raises `NotImplementedError`; don't add stub without full end-to-end implementation
+- **Generation loop internals:** `tqdm` progress, `torch.autocast` scope—avoid breaking these or model cache setup (`setup_caches()`)
+- **Memory patterns:** properties `self.mula` / `self.codec` lazy-load on first access if `lazy_load=True`, then can unload via `_unload_models()`
+
+## Quick pointers for agents
+- Extend pipelines, not models, unless changing core LLM/codec logic
+- Validate paths early (mirror `_resolve_paths` style) for new entry points
+- Preserve 48 kHz sample rate and codebook count (8) in outputs
+- When modifying tokenization or BOS/EOS logic, verify examples still run end-to-end
+- Device/dtype flexibility is intentional—test multi-GPU configs if changing device dispatch logic
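The conventions above translate into only a few lines of Python. A minimal generation sketch, assuming the checkpoint layout from "Checkpoints & required layout"; the call mirrors `app.py` and `examples/run_music_generation.py` from this same commit, and the sample lyrics/tags strings are placeholders:

```python
import torch
from heartlib import HeartMuLaGenPipeline

# Dict-form device spec places HeartMuLa and HeartCodec on separate GPUs;
# per the notes above, this forces lazy_load=False internally.
pipe = HeartMuLaGenPipeline.from_pretrained(
    "./ckpt",  # must contain HeartMuLa-oss-3B/, HeartCodec-oss/, tokenizer.json, gen_config.json
    device={"mula": torch.device("cuda:0"), "codec": torch.device("cuda:1")},
    dtype={"mula": torch.bfloat16, "codec": torch.float32},
    version="3B",
)

pipe(
    {"lyrics": "every day the light returns", "tags": "pop,upbeat"},  # auto-lowercased; tags wrapped in <tag>...</tag>
    max_audio_length_ms=60_000,  # 60 s of audio ≈ 750 frames at the ~12.5 Hz generation rate
    topk=50,
    temperature=1.0,
    cfg_scale=1.5,  # != 1.0 duplicates the batch for the unconditional CFG pass
    save_path="./assets/output.mp3",  # written via torchaudio at the fixed 48 kHz rate
)
```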
.gitignore
ADDED
@@ -0,0 +1,18 @@
+__pycache__/
+**.pyc
+checkpoint/
+.DS_Store
+**.wav
+**.mp3
+**.png
+**.jpeg
+**.jpg
+.vscode/
+**.egg-info/
+build/
+.idea/
+ckpt/
+.venv*/
+.env
+models/
+assets/
DEPLOYMENT_GUIDE.md
ADDED
@@ -0,0 +1,115 @@
+# HeartMuLa HF Space Deployment Guide
+
+## What's Been Set Up
+
+✅ **Gradio Web UI** (`app.py`)
+- Interactive music generation interface
+- Real-time parameter adjustment
+- Audio preview and download
+- Example prompts included
+
+✅ **Docker Environment** (`Dockerfile` + `requirements.txt`)
+- CUDA 12.1 with GPU support
+- All dependencies pre-configured
+- Automatic model downloading on startup
+
+✅ **Space Configuration**
+- `README_SPACE.md` - Space documentation
+- `.dockerignore` - Optimized Docker builds
+
+## Deployment Steps
+
+### 1. Push to Your HF Space Repository
+```bash
+cd f:\Projects\heartlib
+git remote add space https://huggingface.co/spaces/brandongraves08/test
+git push space main
+```
+
+### 2. Configure Space Settings (in HF UI)
+1. Go to https://huggingface.co/spaces/brandongraves08/test/settings
+2. Set **Runtime** to **Docker**
+3. Select **GPU** hardware (A100, T4, or A10 recommended)
+4. Save settings
+
+### 3. Space Will Auto-Deploy
+The Dockerfile will:
+- Install all dependencies
+- Download HeartMuLa and HeartCodec models
+- Start the Gradio app on port 7860
+
+## Features
+
+### Generation Parameters
+- **Lyrics**: Custom lyrics for the song
+- **Tags**: Style descriptors (pop, rock, ambient, etc.)
+- **Duration**: 5-60 seconds
+- **Temperature**: 0.1-2.0 (creativity level)
+- **CFG Scale**: 1.0-3.0 (style control strength)
+- **Top-K**: 10-100 (sampling parameter)
+
+### Model Information
+- **HeartMuLa-RL-oss-3B**: RL-tuned 3B model (recommended)
+- **HeartCodec-oss**: High-fidelity codec (48 kHz)
+- **Inference Speed**: ~RTF 1.0
+
+## Local Testing (Before Deploying)
+
+Test locally with a CUDA GPU:
+```bash
+cd f:\Projects\heartlib
+pip install gradio
+python app.py
+```
+
+Then open http://localhost:7860 in your browser.
+
+## Troubleshooting
+
+### Models Not Downloading
+If the Space fails to download models:
+1. Check that the HF token is configured: `huggingface-cli login`
+2. Verify model access on HF
+3. Check the Space logs for download errors
+
+### Out of Memory
+- Reduce the duration slider
+- Use a smaller model version (if available)
+- Enable lazy_load in the pipeline (already done)
+
+### Slow Generation
+- Generation runs at ~RTF 1.0 (real-time speed)
+- The first run may be slower due to model loading
+- CPU is ~10x slower than GPU
+
+## File Structure
+
+```
+heartlib/
+├── app.py                 # Gradio app (main entry point)
+├── Dockerfile             # Docker build config
+├── requirements.txt       # Python dependencies
+├── setup.sh               # Model download script
+├── README_SPACE.md        # Space documentation
+├── .dockerignore          # Docker build optimization
+├── src/heartlib/          # Core library
+│   ├── pipelines/
+│   │   ├── music_generation.py
+│   │   └── lyrics_transcription.py
+│   ├── heartmula/
+│   └── heartcodec/
+└── examples/              # Example scripts
+```
+
+## Next Steps
+
+1. **Commit & Push** to your Space repository
+2. **Monitor** the Space build logs
+3. **Test** once deployment completes
+4. **Share** the Space URL!
+
+## Support
+
+- **Paper**: https://arxiv.org/pdf/2601.10547
+- **GitHub**: https://github.com/HeartMuLa/heartlib
+- **Discord**: https://discord.gg/BKXF5FgH
Dockerfile
ADDED
@@ -0,0 +1,30 @@
+FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04
+
+WORKDIR /app
+
+# Install system dependencies
+RUN apt-get update && apt-get install -y \
+    python3.10 \
+    python3-pip \
+    git \
+    && rm -rf /var/lib/apt/lists/*
+
+# Set Python 3.10 as default
+RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.10 1 && \
+    update-alternatives --install /usr/bin/pip pip /usr/bin/pip3 1
+
+# Copy requirements and install Python dependencies
+COPY requirements.txt .
+RUN pip install --no-cache-dir -r requirements.txt
+
+# Copy application code
+COPY . .
+
+# Install heartlib package
+RUN pip install -e .
+
+# Create models directory
+RUN mkdir -p ./models
+
+# Run setup (downloads models and starts app)
+CMD ["bash", "-c", "bash setup.sh && python app.py"]
LICENSE
ADDED
@@ -0,0 +1,201 @@
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!) The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright [yyyy] [name of copyright owner]
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
MANUAL_DEPLOYMENT.md
ADDED
@@ -0,0 +1,55 @@
+# 🚀 Manual Deployment Steps for HF Space
+
+Since HF requires token authentication, here's how to deploy manually:
+
+## Option 1: Use the HF CLI (Recommended)
+1. Get your HF access token: https://huggingface.co/settings/tokens
+2. Create a write-enabled token
+3. Run this command with your token:
+
+```bash
+huggingface-cli login --token hf_YOUR_TOKEN_HERE --add-to-git-credential
+```
+
+Then push:
+```bash
+cd F:\Projects\heartlib
+git remote add space https://huggingface.co/spaces/brandongraves08/test
+git push space main --force
+```
+
+## Option 2: Manual Upload via Web UI
+1. Go to https://huggingface.co/spaces/brandongraves08/test
+2. Click **Files** → **Upload files**
+3. Upload these files:
+   - app.py
+   - Dockerfile
+   - requirements.txt
+   - setup.sh
+   - .dockerignore
+   - README_SPACE.md
+
+## Option 3: Use Git with an SSH Key
+1. Set up an SSH key on HF: https://huggingface.co/settings/keys
+2. Update the git remote:
+```bash
+git remote set-url space git@huggingface.co:spaces/brandongraves08/test.git
+git push space main
+```
+
+## Space Configuration (After Upload)
+Once the files are on the Space:
+1. Go to **Settings** (gear icon)
+2. Set **Runtime** → **Docker**
+3. Select **GPU Hardware** (A100/T4/A10G recommended)
+4. Click **Save**
+
+The Space will auto-build and start the Gradio app!
+
+## Verify Deployment
+Check status at: https://brandongraves08-test.hf.space
+
+## Troubleshooting
+- **Models not downloading**: Check that the HF token has read access
+- **Build fails**: Check the Space logs for error details
+- **Out of memory**: Reduce the model size or use a larger GPU
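Besides the three options above, a fourth, programmatic route is the `huggingface_hub` client. A hedged sketch (the repo id matches the Space used throughout this guide; adjust `ignore_patterns` to taste):

```python
from huggingface_hub import HfApi

# Pushes the working tree to the Space in a single commit; requires a
# write-enabled token (pass token=... or run `huggingface-cli login` first).
api = HfApi()
api.upload_folder(
    folder_path=".",
    repo_id="brandongraves08/test",
    repo_type="space",
    commit_message="HeartMuLa Gradio Space deployment",
    ignore_patterns=["ckpt/*", "*.mp3", "*.wav"],  # keep large artifacts out of the Space repo
)
```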
README.md
ADDED
@@ -0,0 +1,255 @@
+<p align="center">
+<picture>
+  <source srcset="./assets/logo.png" media="(prefers-color-scheme: dark)">
+  <img src="./assets/logo.png" width="30%">
+</picture>
+</p>
+
+<p align="center">
+<a href="https://heartmula.github.io/">Demo 🎶</a> | 📑 <a href="https://arxiv.org/pdf/2601.10547">Paper</a>
+<br>
+<a href="https://huggingface.co/HeartMuLa/HeartMuLa-oss-3B">HeartMuLa-oss-3B 🤗</a> | <a href="https://modelscope.cn/models/HeartMuLa/HeartMuLa-oss-3B">HeartMuLa-oss-3B <picture>
+  <source srcset="./assets/badge.svg" media="(prefers-color-scheme: dark)">
+  <img src="./assets/badge.svg" width="20px">
+</picture></a>
+<br>
+<a href="https://huggingface.co/HeartMuLa/HeartMuLa-RL-oss-3B-20260123">HeartMuLa-RL-oss-3B-20260123 🤗</a> | <a href="https://modelscope.cn/models/HeartMuLa/HeartMuLa-RL-oss-3B-20260123">HeartMuLa-RL-oss-3B-20260123 <picture>
+  <source srcset="./assets/badge.svg" media="(prefers-color-scheme: dark)">
+  <img src="./assets/badge.svg" width="20px">
+</picture></a>
+</p>
+
+---
+# HeartMuLa: A Family of Open Sourced Music Foundation Models
+
+HeartMuLa is a family of open-sourced music foundation models including:
+1. HeartMuLa: a music language model that generates music conditioned on lyrics and tags, with multilingual support including but not limited to English, Chinese, Japanese, Korean and Spanish;
+2. HeartCodec: a 12.5 Hz music codec with high reconstruction fidelity;
+3. HeartTranscriptor: a Whisper-based model specifically tuned for lyrics transcription; check [this page](./examples/README.md) for its usage;
+4. HeartCLAP: an audio–text alignment model that establishes a unified embedding space for music descriptions and cross-modal retrieval.
+
+---
+
+Below are the experimental results of our oss-3B version compared with other baselines.
+<p align="center">
+<picture>
+  <source srcset="./assets/exp.png" media="(prefers-color-scheme: dark)">
+  <img src="./assets/exp.png" width="90%">
+</picture>
+</p>
+
+---
+
+## 🔥 Highlight
+
+Our latest internal version of HeartMuLa-7B achieves **comparable performance with Suno** in terms of musicality, fidelity and controllability. If you are interested, feel free to reach out to us at heartmula.ai@gmail.com
+
+## 📰 News
+Join us on Discord! [<img alt="join discord" src="https://img.shields.io/discord/842440537755353128?color=%237289da&logo=discord"/>](https://discord.gg/BKXF5FgH)
+
+- 🚀 **23 Jan. 2026**
+
+  By leveraging Reinforcement Learning, we have continuously refined our model and are proud to officially release **HeartMuLa-RL-oss-3B-20260123**. This version is designed to achieve more precise control over styles and tags. Simultaneously, we are launching **HeartCodec-oss-20260123**, which optimizes audio decoding quality.
+
+- 🫶 **20 Jan. 2026**
+
+  [Benji](https://github.com/benjiyaya) has created a wonderful [ComfyUI custom node](https://github.com/benjiyaya/HeartMuLa_ComfyUI) for HeartMuLa. Thanks Benji!
+
+- ⚖️ **20 Jan. 2026**
+
+  License update: we have updated the license of this repo and all related model weights to **Apache 2.0**.
+
+- 🚀 **14 Jan. 2026**
+
+  The official release of **HeartTranscriptor-oss** and the first **HeartMuLa-oss-3B** version, along with our **HeartCodec-oss**.
+
+---
+## 🧭 TODOs
+
+- ⏳ Release scripts for inference acceleration and streaming inference. The current inference speed is around RTF $\approx 1.0$.
+- ⏳ Support **reference audio conditioning**, **fine-grained controllable music generation**, and **hot song generation**.
+- ⏳ Release the **HeartMuLa-oss-7B** version.
+- ✅ Release inference code and pretrained checkpoints of **HeartCodec-oss, HeartMuLa-oss-3B, and HeartTranscriptor-oss**.
+
+---
+
+## 🛠️ Local Deployment
+
+### ⚙️ Environment Setup
+
+We recommend using `python=3.10` for local deployment.
+
+Clone this repo and install locally:
+
+```
+git clone https://github.com/HeartMuLa/heartlib.git
+cd heartlib
+pip install -e .
+```
+
+Download our pretrained checkpoints from Hugging Face or ModelScope using the following commands:
+
+```
+# if you are using huggingface
+hf download --local-dir './ckpt' 'HeartMuLa/HeartMuLaGen'
+
+## To use the version released on 20260123 (recommended)
+hf download --local-dir './ckpt/HeartMuLa-oss-3B' 'HeartMuLa/HeartMuLa-RL-oss-3B-20260123'
+hf download --local-dir './ckpt/HeartCodec-oss' 'HeartMuLa/HeartCodec-oss-20260123'
+
+## To use the oss-3B version
+hf download --local-dir './ckpt/HeartMuLa-oss-3B' 'HeartMuLa/HeartMuLa-oss-3B'
+hf download --local-dir './ckpt/HeartCodec-oss' 'HeartMuLa/HeartCodec-oss'
+
+# if you are using modelscope
+modelscope download --model 'HeartMuLa/HeartMuLaGen' --local_dir './ckpt'
+
+## To use the version released on 20260123 (recommended)
+modelscope download --model 'HeartMuLa/HeartMuLa-RL-oss-3B-20260123' --local_dir './ckpt/HeartMuLa-oss-3B'
+modelscope download --model 'HeartMuLa/HeartCodec-oss-20260123' --local_dir './ckpt/HeartCodec-oss'
+
+## To use the oss-3B version
+modelscope download --model 'HeartMuLa/HeartMuLa-oss-3B' --local_dir './ckpt/HeartMuLa-oss-3B'
+modelscope download --model 'HeartMuLa/HeartCodec-oss' --local_dir './ckpt/HeartCodec-oss'
+```
+
+After downloading, the `./ckpt` folder should be structured like this:
+```
+./ckpt/
+├── HeartCodec-oss/
+├── HeartMuLa-oss-3B/
+├── gen_config.json
+└── tokenizer.json
+```
+
+### ▶️ Example Usage
+
+To generate music, run:
+
+```
+python ./examples/run_music_generation.py --model_path=./ckpt --version="3B"
+```
+
+By default this command generates a piece of music conditioned on the lyrics and tags provided in the `./assets` folder. The output music will be saved at `./assets/output.mp3`.
+
+#### FAQs
+
+1. How to specify lyrics and tags?
+
+   The model loads lyrics from the txt file `--lyrics` points to (by default `./assets/lyrics.txt`). If you would like to use your own lyrics, just modify the content of `./assets/lyrics.txt`. If you would like to save your lyrics to another path, e.g. `my_awesome_lyrics.txt`, remember to pass `--lyrics my_awesome_lyrics.txt`.
+
+   For tags it is basically the same.
+
+2. CUDA out of memory?
+
+   If you have multiple GPUs (e.g. two 4090s), we recommend placing the parameters of HeartMuLa and HeartCodec on separate devices. You can do this with `--mula_device cuda:0 --codec_device cuda:1`.
+
+   If you are running on a single GPU, use `--lazy_load true` so that modules are loaded on demand and deleted once inference completes, saving GPU memory.
+
+All parameters:
+
+- `--model_path` (required): Path to the pretrained model checkpoint
+- `--lyrics`: Path to lyrics file (default: `./assets/lyrics.txt`)
+- `--tags`: Path to tags file (default: `./assets/tags.txt`)
+- `--save_path`: Output audio file path (default: `./assets/output.mp3`)
+- `--max_audio_length_ms`: Maximum audio length in milliseconds (default: 240000)
+- `--topk`: Top-k sampling parameter for generation (default: 50)
+- `--temperature`: Sampling temperature for generation (default: 1.0)
+- `--cfg_scale`: Classifier-free guidance scale (default: 1.5)
+- `--version`: The version of HeartMuLa, chosen from [`3B`, `7B`] (default: `3B`). The `7B` version is not released yet.
+- `--mula_device/--codec_device`: The devices where parameters will be placed. Both are set to `cuda` by default. You can use `--mula_device cuda:0 --codec_device cuda:1` to explicitly place different modules on different devices.
+- `--mula_dtype/--codec_dtype`: Inference dtype, by default `bf16` for HeartMuLa and `fp32` for HeartCodec. Setting `bf16` for HeartCodec may degrade audio quality.
+- `--lazy_load`: Whether or not to use lazy loading (default: false). If turned on, modules are loaded on demand to save GPU memory.
+
+Recommended format of lyrics and tags:
+```txt
+[Intro]
+
+[Verse]
+The sun creeps in across the floor
+I hear the traffic outside the door
+The coffee pot begins to hiss
+It is another morning just like this
+
+[Prechorus]
+The world keeps spinning round and round
+Feet are planted on the ground
+I find my rhythm in the sound
+
+[Chorus]
+Every day the light returns
+Every day the fire burns
+We keep on walking down this street
+Moving to the same steady beat
+It is the ordinary magic that we meet
+
+[Verse]
+The hours tick deeply into noon
+Chasing shadows, chasing the moon
+Work is done and the lights go low
+Watching the city start to glow
+
+[Bridge]
+It is not always easy, not always bright
+Sometimes we wrestle with the night
+But we make it to the morning light
+
+[Chorus]
+Every day the light returns
+Every day the fire burns
+We keep on walking down this street
+Moving to the same steady beat
+
+[Outro]
+Just another day
+Every single day
+```
+
+Regarding tags, check this [issue](https://github.com/HeartMuLa/heartlib/issues/17) for reference.
+Our different tags are comma-separated without spaces, as illustrated below:
+```txt
+piano,happy,wedding,synthesizer,romantic
+```
+
+---
+
+## ⚖️ License
+
+This repository is licensed under the Apache 2.0 License.
+
+---
+
+## 📚 Citation
+
+```
+@misc{yang2026heartmulafamilyopensourced,
+      title={HeartMuLa: A Family of Open Sourced Music Foundation Models},
+      author={Dongchao Yang and Yuxin Xie and Yuguo Yin and Zheyu Wang and Xiaoyu Yi and Gongxi Zhu and Xiaolong Weng and Zihan Xiong and Yingzhe Ma and Dading Cong and Jingliang Liu and Zihang Huang and Jinghan Ru and Rongjie Huang and Haoran Wan and Peixu Wang and Kuoxi Yu and Helin Wang and Liming Liang and Xianwei Zhuang and Yuanyuan Wang and Haohan Guo and Junjie Cao and Zeqian Ju and Songxiang Liu and Yuewen Cao and Heming Weng and Yuexian Zou},
+      year={2026},
+      eprint={2601.10547},
+      archivePrefix={arXiv},
+      primaryClass={cs.SD},
+      url={https://arxiv.org/abs/2601.10547},
+}
+```
+
+## 📬 Contact
+If you are interested in HeartMuLa, feel free to reach us at heartmula.ai@gmail.com
+
+You are welcome to join us through [Discord](https://discord.gg/BKXF5FgH) or our WeChat group.
+
+Scan the QR code on the left to join our WeChat group. If it has expired, feel free to raise an issue to remind us to update it.
+
+If the number of group members exceeds 200, joining the group by directly scanning the QR code is restricted by WeChat. In this case, scan our team member's QR code on the right and send a request with the message **HeartMuLa Group Invite**. We will invite you into the group manually.
+<p align="center">
+<picture>
+  <source srcset="./assets/group_wx.jpeg" media="(prefers-color-scheme: dark)">
+  <img src="./assets/group_wx.jpeg" width="40%">
+</picture>
+<picture>
+  <source srcset="./assets/lead_wx.jpeg" media="(prefers-color-scheme: dark)">
+  <img src="./assets/lead_wx.jpeg" width="40%">
+</picture>
+</p>
README_SPACE.md
ADDED
@@ -0,0 +1,54 @@
+---
+title: HeartMuLa Music Generation
+emoji: 🎵
+colorFrom: purple
+colorTo: pink
+sdk: docker
+sdk_version: latest
+app_file: app.py
+pinned: true
+duplicated_from: HeartMuLa/HeartMuLa-oss
+---
+
+# HeartMuLa Music Generation Space
+
+Generate music from lyrics and style tags using the HeartMuLa family of open-source music foundation models.
+
+## Features
+
+- **Music Generation**: Convert lyrics + style tags → audio via two-stage pipeline
+- **HeartMuLa LLM**: Frame-by-frame audio token generation with style control
+- **HeartCodec**: High-fidelity flow-matching codec (48 kHz output)
+- **Multiple Model Sizes**: 3B, 7B, and 300M versions available
+
+## Setup
+
+The Space will automatically download and set up the required models on first run.
+
+## Usage
+
+1. Enter your **lyrics** in the text field
+2. Add **style tags** (e.g., "pop, upbeat, energetic")
+3. Adjust generation parameters:
+   - **Duration**: Length of generated music (5-60 seconds)
+   - **Temperature**: Creativity level (0.1-2.0)
+   - **CFG Scale**: Style control strength (1.0-3.0)
+   - **Top-K**: Sampling parameter (10-100)
+4. Click **Generate Music** to create your track
+
+## Model Information
+
+- **HeartMuLa-RL-oss-3B-20260123**: RL-tuned version with improved style control (recommended)
+- **HeartCodec-oss-20260123**: Optimized audio decoding quality
+
+## Performance
+
+- RTF ≈ 1.0 (real-time inference speed)
+- 48 kHz sample rate output
+- Supports multiple languages
+
+## References
+
+- [Paper](https://arxiv.org/pdf/2601.10547)
+- [GitHub](https://github.com/HeartMuLa/heartlib)
+- [Discord](https://discord.gg/BKXF5FgH)
app.py
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
HeartMuLa Music Generation Gradio App for Hugging Face Spaces
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import gradio as gr
|
| 7 |
+
import torch
|
| 8 |
+
import os
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from heartlib import HeartMuLaGenPipeline
|
| 11 |
+
import logging
|
| 12 |
+
|
| 13 |
+
logging.basicConfig(level=logging.INFO)
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
# Configuration
|
| 17 |
+
MODEL_PATH = "./models"
|
| 18 |
+
DEFAULT_VERSION = "3B"
|
| 19 |
+
|
| 20 |
+
# Global pipeline instance
|
| 21 |
+
pipeline = None
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def load_pipeline():
|
| 25 |
+
"""Load the HeartMuLa pipeline"""
|
| 26 |
+
global pipeline
|
| 27 |
+
|
| 28 |
+
if pipeline is not None:
|
| 29 |
+
return pipeline
|
| 30 |
+
|
| 31 |
+
logger.info("Loading HeartMuLa pipeline...")
|
| 32 |
+
|
| 33 |
+
# Determine device
|
| 34 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 35 |
+
dtype = torch.float16 if device == "cuda" else torch.float32
|
| 36 |
+
|
| 37 |
+
logger.info(f"Using device: {device}, dtype: {dtype}")
|
| 38 |
+
|
| 39 |
+
try:
|
| 40 |
+
pipeline = HeartMuLaGenPipeline.from_pretrained(
|
| 41 |
+
MODEL_PATH,
|
| 42 |
+
device={
|
| 43 |
+
"mula": torch.device(device),
|
| 44 |
+
"codec": torch.device(device),
|
| 45 |
+
},
|
| 46 |
+
dtype={
|
| 47 |
+
"mula": dtype,
|
| 48 |
+
"codec": dtype,
|
| 49 |
+
},
|
| 50 |
+
version=DEFAULT_VERSION,
|
| 51 |
+
lazy_load=True,
|
| 52 |
+
)
|
| 53 |
+
logger.info("Pipeline loaded successfully!")
|
| 54 |
+
return pipeline
|
| 55 |
+
except Exception as e:
|
| 56 |
+
logger.error(f"Failed to load pipeline: {e}")
|
| 57 |
+
raise gr.Error(f"Failed to load model: {e}")
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def generate_music(
|
| 61 |
+
lyrics: str,
|
| 62 |
+
tags: str,
|
| 63 |
+
max_duration: int = 30,
|
| 64 |
+
temperature: float = 1.0,
|
| 65 |
+
top_k: int = 50,
|
| 66 |
+
    cfg_scale: float = 1.5,
):
    """Generate music from lyrics and tags"""

    if not lyrics.strip():
        raise gr.Error("Please enter lyrics")
    if not tags.strip():
        raise gr.Error("Please enter tags")

    try:
        logger.info(f"Generating music with lyrics: {lyrics[:50]}... and tags: {tags}")

        # Load pipeline
        pipe = load_pipeline()

        # Convert duration to milliseconds
        max_audio_length_ms = max_duration * 1000

        # Generate music
        output_path = "/tmp/generated_music.mp3"
        os.makedirs("/tmp", exist_ok=True)

        with torch.no_grad():
            pipe(
                {
                    "lyrics": lyrics,
                    "tags": tags,
                },
                max_audio_length_ms=max_audio_length_ms,
                save_path=output_path,
                topk=top_k,
                temperature=temperature,
                cfg_scale=cfg_scale,
            )

        logger.info(f"Music generated successfully: {output_path}")
        return output_path

    except Exception as e:
        logger.error(f"Error during generation: {e}")
        raise gr.Error(f"Generation failed: {e}")


def main():
    """Create Gradio interface"""

    # Check if models exist
    if not Path(MODEL_PATH).exists():
        logger.warning(f"Models directory not found at {MODEL_PATH}")
        logger.info("You need to download the models first:")
        logger.info("hf download --local-dir './models/HeartMuLa-oss-3B' HeartMuLa/HeartMuLa-RL-oss-3B-20260123")
        logger.info("hf download --local-dir './models/HeartCodec-oss' HeartMuLa/HeartCodec-oss-20260123")

    with gr.Blocks(title="HeartMuLa Music Generation") as demo:
        gr.Markdown("""
        # 🎵 HeartMuLa Music Generation

        Generate music from lyrics and style tags using HeartMuLa, a family of open-source music foundation models.
        """)

        with gr.Row():
            with gr.Column():
                gr.Markdown("### Input")

                lyrics_input = gr.Textbox(
                    label="Lyrics",
                    placeholder="Enter your lyrics here...",
                    lines=4,
                )

                tags_input = gr.Textbox(
                    label="Style Tags",
                    placeholder="e.g., pop, upbeat, energetic",
                    lines=2,
                )

                with gr.Row():
                    duration_slider = gr.Slider(
                        minimum=5,
                        maximum=60,
                        value=30,
                        step=5,
                        label="Duration (seconds)",
                    )

                with gr.Row():
                    temp_slider = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        value=1.0,
                        step=0.1,
                        label="Temperature",
                    )

                    cfg_slider = gr.Slider(
                        minimum=1.0,
                        maximum=3.0,
                        value=1.5,
                        step=0.1,
                        label="CFG Scale (style strength)",
                    )

                with gr.Row():
                    topk_slider = gr.Slider(
                        minimum=10,
                        maximum=100,
                        value=50,
                        step=5,
                        label="Top-K",
                    )

                generate_btn = gr.Button("🎵 Generate Music", variant="primary", size="lg")

            with gr.Column():
                gr.Markdown("### Output")
                audio_output = gr.Audio(label="Generated Music", type="filepath")

                gr.Markdown("""
                ### 📝 Tips
                - **Lyrics**: Describe the vocals and melody
                - **Tags**: Use style descriptors like "pop", "rock", "ambient", "upbeat", etc.
                - **CFG Scale**: Higher values = stronger style control (1.5 is recommended)
                - **Temperature**: Higher = more creative, lower = more consistent
                """)

        # Connect button to generation function
        generate_btn.click(
            fn=generate_music,
            inputs=[
                lyrics_input,
                tags_input,
                duration_slider,
                temp_slider,
                topk_slider,
                cfg_slider,
            ],
            outputs=audio_output,
        )

        # Example inputs
        gr.Examples(
            examples=[
                [
                    "Love is in the air, feel the magic",
                    "pop, upbeat, romantic",
                ],
                [
                    "Dark skies falling down, lonely tonight",
                    "rock, emotional, melancholic",
                ],
            ],
            inputs=[lyrics_input, tags_input],
            outputs=audio_output,
            fn=generate_music,
            cache_examples=False,
        )

    return demo


if __name__ == "__main__":
    demo = main()
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
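Not part of the committed files, but useful for smoke-testing the Space: a hedged sketch of driving the UI above programmatically with `gradio_client`. The positional argument order mirrors the `click` handler's `inputs` list; the `api_name` is an assumption (Gradio normally derives it from the handler's function name).

```python
from gradio_client import Client

# Point at the locally launched app (or substitute the public Space URL).
client = Client("http://localhost:7860")
audio_path = client.predict(
    "Love is in the air, feel the magic",  # lyrics
    "pop, upbeat, romantic",               # tags
    30,                                    # duration (seconds)
    1.0,                                   # temperature
    50,                                    # top-k
    1.5,                                   # cfg scale
    api_name="/generate_music",            # assumed: derived from the fn name
)
print(audio_path)  # local filepath of the downloaded MP3
```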
examples/README.md
ADDED
@@ -0,0 +1,15 @@

# 🎤 Lyrics Transcription

Download the checkpoint using either of the following commands:
```
hf download --local-dir './ckpt/HeartTranscriptor-oss' 'HeartMuLa/HeartTranscriptor-oss'
modelscope download --model 'HeartMuLa/HeartTranscriptor-oss' --local_dir './ckpt/HeartTranscriptor-oss'
```

Then run:
```
python ./examples/run_lyrics_transcription.py --model_path=./ckpt
```

By default this command loads the generated music file at `./assets/output.mp3` and prints the transcribed lyrics. Use `--music_path` to point it at a different music file.

Note that HeartTranscriptor is trained on separated vocal tracks. This example runs directly on an unseparated mix purely for simplicity of illustration. For better results, we recommend separating the vocals with a source-separation tool such as demucs before transcribing, as sketched below.
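A hedged pre-processing sketch (not part of this commit) of the demucs route the README recommends: separate the vocal stem first, then transcribe it. It assumes `pip install demucs`; the output path follows demucs' default `separated/<model>/<track>/` layout for the default `htdemucs` model and may differ across demucs versions.

```python
import subprocess

import torch

from heartlib import HeartTranscriptorPipeline

song = "./assets/output.mp3"
# writes ./separated/htdemucs/output/{vocals,no_vocals}.wav by default
subprocess.run(["demucs", "--two-stems", "vocals", song], check=True)
vocals_path = "./separated/htdemucs/output/vocals.wav"  # assumed default layout

pipe = HeartTranscriptorPipeline.from_pretrained(
    "./ckpt", device=torch.device("cuda"), dtype=torch.float16
)
with torch.no_grad():
    print(pipe(vocals_path, max_new_tokens=256, task="transcribe"))
```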
examples/run_lyrics_transcription.py
ADDED
@@ -0,0 +1,35 @@

from heartlib import HeartTranscriptorPipeline
import argparse
import torch


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--music_path", type=str, default="./assets/output.mp3")

    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    pipe = HeartTranscriptorPipeline.from_pretrained(
        args.model_path,
        device=torch.device("cuda"),
        dtype=torch.float16,
    )
    with torch.no_grad():
        result = pipe(
            args.music_path,
            **{
                "max_new_tokens": 256,
                "num_beams": 2,
                "task": "transcribe",
                "condition_on_prev_tokens": False,
                "compression_ratio_threshold": 1.8,
                "temperature": (0.0, 0.1, 0.2, 0.4),
                "logprob_threshold": -1.0,
                "no_speech_threshold": 0.4,
            },
        )
    print(result)
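The keyword arguments above are standard Hugging Face Whisper long-form decoding controls: the `temperature` tuple is a fallback ladder, retried from left to right whenever a segment fails the `compression_ratio_threshold` or `logprob_threshold` checks, and `no_speech_threshold` drops segments classified as silence. A stripped-down variant, assuming the pipeline simply forwards these kwargs to the underlying Whisper `generate()`:

```python
import torch
from heartlib import HeartTranscriptorPipeline

pipe = HeartTranscriptorPipeline.from_pretrained(
    "./ckpt", device=torch.device("cuda"), dtype=torch.float16
)
with torch.no_grad():
    # temperature tuple = fallback ladder; decoding retries at the next value
    # whenever the compression-ratio / log-prob checks reject a segment
    result = pipe(
        "./assets/output.mp3",
        task="transcribe",
        max_new_tokens=256,
        temperature=(0.0, 0.2, 0.4),
    )
print(result)
```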
examples/run_music_generation.py
ADDED
@@ -0,0 +1,81 @@

from heartlib import HeartMuLaGenPipeline
import argparse
import torch


def str2bool(value):
    if isinstance(value, bool):
        return value
    if value.lower() in ("yes", "y", "true", "t", "1"):
        return True
    elif value.lower() in ("no", "n", "false", "f", "0"):
        return False
    else:
        raise argparse.ArgumentTypeError(f"Boolean value expected. Got: {value}")


def str2dtype(value):
    value = value.lower()
    if value == "float32" or value == "fp32":
        return torch.float32
    elif value == "float16" or value == "fp16":
        return torch.float16
    elif value == "bfloat16" or value == "bf16":
        return torch.bfloat16
    else:
        raise argparse.ArgumentTypeError(f"Dtype not recognized: {value}")


def str2device(value):
    value = value.lower()
    return torch.device(value)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--version", type=str, default="3B")
    parser.add_argument("--lyrics", type=str, default="./assets/lyrics.txt")
    parser.add_argument("--tags", type=str, default="./assets/tags.txt")
    parser.add_argument("--save_path", type=str, default="./assets/output.mp3")

    parser.add_argument("--max_audio_length_ms", type=int, default=240_000)
    parser.add_argument("--topk", type=int, default=50)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--cfg_scale", type=float, default=1.5)
    parser.add_argument("--mula_device", type=str2device, default="cuda")
    parser.add_argument("--codec_device", type=str2device, default="cuda")
    parser.add_argument("--mula_dtype", type=str2dtype, default="bfloat16")
    parser.add_argument("--codec_dtype", type=str2dtype, default="float32")
    parser.add_argument("--lazy_load", type=str2bool, default=False)
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    pipe = HeartMuLaGenPipeline.from_pretrained(
        args.model_path,
        device={
            "mula": torch.device(args.mula_device),
            "codec": torch.device(args.codec_device),
        },
        dtype={
            "mula": args.mula_dtype,
            "codec": args.codec_dtype,
        },
        version=args.version,
        lazy_load=args.lazy_load,
    )
    with torch.no_grad():
        pipe(
            {
                "lyrics": args.lyrics,
                "tags": args.tags,
            },
            max_audio_length_ms=args.max_audio_length_ms,
            save_path=args.save_path,
            topk=args.topk,
            temperature=args.temperature,
            cfg_scale=args.cfg_scale,
        )
    print(f"Generated music saved to {args.save_path}")
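A typical invocation, taken directly from the argument parser above (the lyrics/tags/save paths shown are the script defaults; the checkpoint directory is illustrative):
```
python ./examples/run_music_generation.py --model_path=./ckpt \
    --lyrics=./assets/lyrics.txt --tags=./assets/tags.txt --save_path=./assets/output.mp3
```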
pyproject.toml
ADDED
@@ -0,0 +1,46 @@

[build-system]
requires = ["setuptools>=61", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "heartlib"
version = "0.1.0"
description = "A Python Library."
readme = "README.md"
requires-python = ">=3.9"
license = {text = "Apache-2.0"}
authors = [
    {name = "HeartMuLa Team", email = "heartmula.ai@gmail.com"}
]
dependencies = [
    "numpy==2.0.2",
    "torch==2.4.1",
    "torchaudio==2.4.1",
    "torchtune==0.4.0",
    "torchao==0.9.0",
    "torchvision==0.19.1",
    "tqdm==4.67.1",
    "traitlets==5.7.1",
    "traittypes==0.2.3",
    "transformers==4.57.0",
    "tokenizers==0.22.1",
    "ipykernel==6.17.1",
    "einops==0.8.1",
    "accelerate==1.12.0",
    "bitsandbytes==0.49.0",
    "vector-quantize-pytorch==1.27.15",
    "modelscope==1.33.0",
    "soundfile"
]
urls = { "homepage" = "https://heartmula.github.io/" }
classifiers = [
    "Programming Language :: Python :: 3",
    "Operating System :: OS Independent"
]

[tool.setuptools]
package-dir = {"" = "src"}

[tool.setuptools.packages.find]
where = ["src"]
requirements.txt
ADDED
@@ -0,0 +1,11 @@

torch==2.4.1
torchaudio==2.4.1
transformers==4.40.0
safetensors==0.4.1
bitsandbytes==0.49.0
torchtune==0.4.0
tokenizers==0.15.0
tqdm==4.66.1
gradio==4.36.1
pydantic==2.5.0
numpy==1.24.3
setup.sh
ADDED
@@ -0,0 +1,28 @@

#!/bin/bash
# Setup script for HF Space deployment

echo "Installing HeartMuLa package..."
pip install -e .

echo "Downloading HeartMuLa models..."
mkdir -p ./models

# Download HeartMuLa 3B model (RL-tuned version - recommended)
echo "Downloading HeartMuLa-RL-oss-3B-20260123..."
huggingface-cli download \
    --local-dir "./models/HeartMuLa-oss-3B" \
    HeartMuLa/HeartMuLa-RL-oss-3B-20260123

# Download HeartCodec
echo "Downloading HeartCodec-oss-20260123..."
huggingface-cli download \
    --local-dir "./models/HeartCodec-oss" \
    HeartMuLa/HeartCodec-oss-20260123

# Copy tokenizer and config
echo "Copying tokenizer and config..."
huggingface-cli download \
    --local-dir "./models" \
    HeartMuLa/HeartMuLaGen

echo "Setup complete! Starting Gradio app..."
src/heartlib/__init__.py
ADDED
@@ -0,0 +1,7 @@

from .pipelines.music_generation import HeartMuLaGenPipeline
from .pipelines.lyrics_transcription import HeartTranscriptorPipeline

__all__ = [
    "HeartMuLaGenPipeline",
    "HeartTranscriptorPipeline"
]
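Downstream code imports both pipelines through this top-level namespace:

```python
from heartlib import HeartMuLaGenPipeline, HeartTranscriptorPipeline
```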
src/heartlib/heartcodec/configuration_heartcodec.py
ADDED
@@ -0,0 +1,73 @@

from transformers.configuration_utils import PretrainedConfig
from typing import List


class HeartCodecConfig(PretrainedConfig):
    model_type = "heartcodec"

    def __init__(
        self,
        # config for rvq
        dim: int = 512,
        codebook_size: int = 8192,
        decay: float = 0.9,
        commitment_weight: float = 1.0,
        threshold_ema_dead_code: int = 2,
        use_cosine_sim: bool = False,
        codebook_dim: int = 32,
        num_quantizers: int = 8,
        # config for diffusion transformer
        attention_head_dim: int = 64,
        in_channels: int = 1024,
        norm_type: str = "ada_norm_single",
        num_attention_heads: int = 24,
        num_layers: int = 24,
        num_layers_2: int = 6,
        out_channels: int = 256,
        # config for sq codec
        num_bands: int = 1,
        sample_rate: int = 48000,
        causal: bool = True,
        num_samples: int = 2,
        downsample_factors: List[int] = [3, 4, 4, 4, 5],
        downsample_kernel_sizes: List[int] = [6, 8, 8, 8, 10],
        upsample_factors: List[int] = [5, 4, 4, 4, 3],
        upsample_kernel_sizes: List[int] = [10, 8, 8, 8, 6],
        latent_hidden_dim: int = 128,
        default_kernel_size: int = 7,
        delay_kernel_size: int = 5,
        init_channel: int = 64,
        res_kernel_size: int = 7,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.dim = dim
        self.codebook_size = codebook_size
        self.decay = decay
        self.commitment_weight = commitment_weight
        self.threshold_ema_dead_code = threshold_ema_dead_code
        self.use_cosine_sim = use_cosine_sim
        self.codebook_dim = codebook_dim
        self.num_quantizers = num_quantizers

        self.attention_head_dim = attention_head_dim
        self.in_channels = in_channels
        self.norm_type = norm_type
        self.num_attention_heads = num_attention_heads
        self.num_layers = num_layers
        self.num_layers_2 = num_layers_2
        self.out_channels = out_channels

        self.num_bands = num_bands
        self.sample_rate = sample_rate
        self.causal = causal
        self.num_samples = num_samples
        self.downsample_factors = downsample_factors
        self.downsample_kernel_sizes = downsample_kernel_sizes
        self.upsample_factors = upsample_factors
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.latent_hidden_dim = latent_hidden_dim
        self.default_kernel_size = default_kernel_size
        self.delay_kernel_size = delay_kernel_size
        self.init_channel = init_channel
        self.res_kernel_size = res_kernel_size
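Since `HeartCodecConfig` subclasses `PretrainedConfig`, the defaults above fully describe the released 48 kHz codec, and any field can be overridden by keyword. Note that the downsample factors multiply to 3·4·4·4·5 = 960, which together with the ×2 `num_samples` pooling yields the 25 Hz latent rate that `detokenize` relies on below. A quick sanity check, assuming the package is installed (`pip install -e .`):

```python
from heartlib.heartcodec.configuration_heartcodec import HeartCodecConfig

cfg = HeartCodecConfig()
print(cfg.sample_rate)         # 48000
print(cfg.downsample_factors)  # [3, 4, 4, 4, 5] -> overall stride 960

cfg_small = HeartCodecConfig(num_layers=12)  # any default can be overridden
print(cfg_small.num_layers)    # 12
```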
src/heartlib/heartcodec/modeling_heartcodec.py
ADDED
@@ -0,0 +1,183 @@

import torch
from .models.flow_matching import FlowMatching
from .models.sq_codec import ScalarModel
from .configuration_heartcodec import HeartCodecConfig
from transformers.modeling_utils import PreTrainedModel
import math
import numpy as np


class HeartCodec(PreTrainedModel):
    config_class = HeartCodecConfig

    def __init__(self, config: HeartCodecConfig):
        super(HeartCodec, self).__init__(config)

        self.config = config

        self.flow_matching = FlowMatching(
            dim=config.dim,
            codebook_size=config.codebook_size,
            decay=config.decay,
            commitment_weight=config.commitment_weight,
            threshold_ema_dead_code=config.threshold_ema_dead_code,
            use_cosine_sim=config.use_cosine_sim,
            codebook_dim=config.codebook_dim,
            num_quantizers=config.num_quantizers,
            attention_head_dim=config.attention_head_dim,
            in_channels=config.in_channels,
            norm_type=config.norm_type,
            num_attention_heads=config.num_attention_heads,
            num_layers=config.num_layers,
            num_layers_2=config.num_layers_2,
            out_channels=config.out_channels,
        )
        self.scalar_model = ScalarModel(
            num_bands=config.num_bands,
            sample_rate=config.sample_rate,
            causal=config.causal,
            num_samples=config.num_samples,
            downsample_factors=config.downsample_factors,
            downsample_kernel_sizes=config.downsample_kernel_sizes,
            upsample_factors=config.upsample_factors,
            upsample_kernel_sizes=config.upsample_kernel_sizes,
            latent_hidden_dim=config.latent_hidden_dim,
            default_kernel_size=config.default_kernel_size,
            delay_kernel_size=config.delay_kernel_size,
            init_channel=config.init_channel,
            res_kernel_size=config.res_kernel_size,
        )
        self.post_init()

        self.sample_rate = config.sample_rate

    @torch.inference_mode()
    def detokenize(
        self,
        codes,
        duration=29.76,
        num_steps=10,
        disable_progress=False,
        guidance_scale=1.25,
    ):
        codes = codes.unsqueeze(0).to(self.device)
        first_latent = torch.randn(
            codes.shape[0], int(duration * 25), 256, dtype=self.dtype
        ).to(self.device)  # B, T, 64
        first_latent_length = 0
        first_latent_codes_length = 0
        min_samples = int(duration * 12.5)
        hop_samples = min_samples // 93 * 80
        ovlp_samples = min_samples - hop_samples
        ovlp_frames = ovlp_samples * 2
        codes_len = codes.shape[-1]
        target_len = int(
            (codes_len - first_latent_codes_length) / 12.5 * self.sample_rate
        )

        # code repeat
        if codes_len < min_samples:
            while codes.shape[-1] < min_samples:
                codes = torch.cat([codes, codes], -1)
            codes = codes[:, :, 0:min_samples]
        codes_len = codes.shape[-1]
        if (codes_len - ovlp_frames) % hop_samples > 0:
            len_codes = (
                math.ceil((codes_len - ovlp_samples) / float(hop_samples)) * hop_samples
                + ovlp_samples
            )
            while codes.shape[-1] < len_codes:
                codes = torch.cat([codes, codes], -1)
            codes = codes[:, :, 0:len_codes]
        latent_length = int(duration * 25)
        latent_list = []

        for sinx in range(0, codes.shape[-1] - hop_samples + 1, hop_samples):
            codes_input = []
            codes_input.append(codes[:, :, sinx : sinx + min_samples])
            if sinx == 0 or ovlp_frames == 0:
                incontext_length = first_latent_length
                latents = self.flow_matching.inference_codes(
                    codes_input,
                    first_latent,
                    latent_length,
                    incontext_length,
                    guidance_scale=guidance_scale,
                    num_steps=num_steps,
                    disable_progress=disable_progress,
                    scenario="other_seg",
                )
                latent_list.append(latents)
            else:
                true_latent = latent_list[-1][:, -ovlp_frames:, :]
                len_add_to_latent = latent_length - true_latent.shape[1]
                incontext_length = true_latent.shape[1]
                true_latent = torch.cat(
                    [
                        true_latent,
                        torch.randn(
                            true_latent.shape[0],
                            len_add_to_latent,
                            true_latent.shape[-1],
                            dtype=self.dtype,
                        ).to(self.device),
                    ],
                    1,
                )
                latents = self.flow_matching.inference_codes(
                    codes_input,
                    true_latent,
                    latent_length,
                    incontext_length,
                    guidance_scale=guidance_scale,
                    num_steps=num_steps,
                    disable_progress=disable_progress,
                    scenario="other_seg",
                )
                latent_list.append(latents)

        # latent_list = [l.float() for l in latent_list]
        latent_list[0] = latent_list[0][:, first_latent_length:, :]
        min_samples = int(duration * self.sample_rate)
        hop_samples = min_samples // 93 * 80
        ovlp_samples = min_samples - hop_samples

        output = None
        for i in range(len(latent_list)):
            latent = latent_list[i]
            bsz, t, f = latent.shape

            latent = latent.reshape(
                latent.shape[0], latent.shape[1], 2, latent.shape[2] // 2
            ).permute(0, 2, 1, 3)
            latent = latent.reshape(
                latent.shape[0] * 2, latent.shape[2], latent.shape[3]
            )
            cur_output = (
                self.scalar_model.decode(latent.transpose(1, 2)).squeeze(0).squeeze(1)
            )  # 1 512 256

            cur_output = cur_output[:, 0:min_samples].detach().cpu()  # B, T
            if cur_output.dim() == 3:
                cur_output = cur_output[0]

            if output is None:
                output = cur_output
            else:
                if ovlp_samples == 0:
                    output = torch.cat([output, cur_output], -1)
                else:
                    ov_win = torch.from_numpy(np.linspace(0, 1, ovlp_samples)[None, :])
                    ov_win = torch.cat([ov_win, 1 - ov_win], -1)
                    output[:, -ovlp_samples:] = (
                        output[:, -ovlp_samples:] * ov_win[:, -ovlp_samples:]
                        + cur_output[:, 0:ovlp_samples] * ov_win[:, 0:ovlp_samples]
                    )
                    output = torch.cat([output, cur_output[:, ovlp_samples:]], -1)
        output = output[:, 0:target_len]
        return output
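`detokenize` decodes long code sequences window by window and stitches consecutive windows with the linear cross-fade built from `ov_win` above. A self-contained toy version of that stitch, with tiny lengths in place of the real 48 kHz windows:

```python
import torch

ovlp = 4                          # overlap length in samples
output = torch.ones(1, 10)        # previously stitched audio
cur_output = torch.zeros(1, 10)   # next decoded window

ramp = torch.linspace(0, 1, ovlp).unsqueeze(0)  # fade-in ramp for the new window
# the old tail fades out (1 -> 0) while the new head fades in (0 -> 1)
output[:, -ovlp:] = output[:, -ovlp:] * (1 - ramp) + cur_output[:, :ovlp] * ramp
output = torch.cat([output, cur_output[:, ovlp:]], dim=-1)
print(output.shape)  # torch.Size([1, 16])
```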
src/heartlib/heartcodec/models/flow_matching.py
ADDED
@@ -0,0 +1,176 @@

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from vector_quantize_pytorch import ResidualVQ
from .transformer import LlamaTransformer


class FlowMatching(nn.Module):
    def __init__(
        self,
        # rvq stuff
        dim: int = 512,
        codebook_size: int = 8192,
        decay: float = 0.9,
        commitment_weight: float = 1.0,
        threshold_ema_dead_code: int = 2,
        use_cosine_sim: bool = False,
        codebook_dim: int = 32,
        num_quantizers: int = 8,
        # dit backbone stuff
        attention_head_dim: int = 64,
        in_channels: int = 1024,
        norm_type: str = "ada_norm_single",
        num_attention_heads: int = 24,
        num_layers: int = 24,
        num_layers_2: int = 6,
        out_channels: int = 256,
    ):
        super().__init__()

        self.vq_embed = ResidualVQ(
            dim=dim,
            codebook_size=codebook_size,
            decay=decay,
            commitment_weight=commitment_weight,
            threshold_ema_dead_code=threshold_ema_dead_code,
            use_cosine_sim=use_cosine_sim,
            codebook_dim=codebook_dim,
            num_quantizers=num_quantizers,
        )
        self.cond_feature_emb = nn.Linear(dim, dim)
        self.zero_cond_embedding1 = nn.Parameter(torch.randn(dim))
        self.estimator = LlamaTransformer(
            attention_head_dim=attention_head_dim,
            in_channels=in_channels,
            norm_type=norm_type,
            num_attention_heads=num_attention_heads,
            num_layers=num_layers,
            num_layers_2=num_layers_2,
            out_channels=out_channels,
        )

        self.latent_dim = out_channels

    @torch.no_grad()
    def inference_codes(
        self,
        codes,
        true_latents,
        latent_length,
        incontext_length,
        guidance_scale=2.0,
        num_steps=20,
        disable_progress=True,
        scenario="start_seg",
    ):
        device = true_latents.device
        dtype = true_latents.dtype
        # codes_bestrq_middle, codes_bestrq_last = codes
        codes_bestrq_emb = codes[0]

        batch_size = codes_bestrq_emb.shape[0]
        self.vq_embed.eval()
        quantized_feature_emb = self.vq_embed.get_output_from_indices(
            codes_bestrq_emb.transpose(1, 2)
        )
        quantized_feature_emb = self.cond_feature_emb(quantized_feature_emb)  # b t 512
        quantized_feature_emb = F.interpolate(
            quantized_feature_emb.permute(0, 2, 1), scale_factor=2, mode="nearest"
        ).permute(0, 2, 1)

        num_frames = quantized_feature_emb.shape[1]
        latents = torch.randn(
            (batch_size, num_frames, self.latent_dim), device=device, dtype=dtype
        )
        latent_masks = torch.zeros(
            latents.shape[0], latents.shape[1], dtype=torch.int64, device=latents.device
        )
        latent_masks[:, 0:latent_length] = 2
        if scenario == "other_seg":
            latent_masks[:, 0:incontext_length] = 1

        # masked positions get the learned "zero" condition embedding
        quantized_feature_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_feature_emb + (
            latent_masks < 0.5
        ).unsqueeze(-1) * self.zero_cond_embedding1.unsqueeze(0)

        incontext_latents = (
            true_latents
            * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float()
        )
        incontext_length = ((latent_masks > 0.5) * (latent_masks < 1.5)).sum(-1)[0]

        additional_model_input = torch.cat([quantized_feature_emb], 1)
        temperature = 1.0
        t_span = torch.linspace(0, 1, num_steps + 1, device=quantized_feature_emb.device)
        latents = self.solve_euler(
            latents * temperature,
            incontext_latents.to(dtype),
            incontext_length,
            t_span,
            additional_model_input,
            guidance_scale,
        )

        latents[:, 0:incontext_length, :] = incontext_latents[
            :, 0:incontext_length, :
        ]  # B, T, dim
        return latents

    def solve_euler(self, x, incontext_x, incontext_length, t_span, mu, guidance_scale):
        """
        Fixed-step Euler solver for the flow-matching ODE.
        Args:
            x (torch.Tensor): random noise
            t_span (torch.Tensor): n_timesteps interpolated
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
        """
        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
        noise = x.clone()
        # Intermediate states are stored so they can be inspected (e.g. from a
        # debugger); a return_all_steps flag could be added later.
        sol = []
        for step in tqdm(range(1, len(t_span))):
            # keep the in-context prefix pinned to its noisy interpolation at time t
            x[:, 0:incontext_length, :] = (1 - (1 - 1e-6) * t) * noise[
                :, 0:incontext_length, :
            ] + t * incontext_x[:, 0:incontext_length, :]
            if guidance_scale > 1.0:
                # classifier-free guidance: batch an unconditional (zeroed mu) copy
                dphi_dt = self.estimator(
                    torch.cat(
                        [
                            torch.cat([x, x], 0),
                            torch.cat([incontext_x, incontext_x], 0),
                            torch.cat([torch.zeros_like(mu), mu], 0),
                        ],
                        2,
                    ),
                    timestep=t.unsqueeze(-1).repeat(2),
                )
                dphi_dt_uncond, dphi_dt_cond = dphi_dt.chunk(2, 0)
                dphi_dt = dphi_dt_uncond + guidance_scale * (
                    dphi_dt_cond - dphi_dt_uncond
                )
            else:
                dphi_dt = self.estimator(
                    torch.cat([x, incontext_x, mu], 2), timestep=t.unsqueeze(-1)
                )

            x = x + dt * dphi_dt
            t = t + dt
            sol.append(x)
            if step < len(t_span) - 1:
                dt = t_span[step + 1] - t

        result = sol[-1]

        return result
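`solve_euler` integrates the learned velocity field with fixed Euler steps and applies classifier-free guidance by running the estimator on a conditional/unconditional batch and recombining. The guidance arithmetic in isolation, with toy velocity fields standing in for the transformer:

```python
import torch

def v_cond(x, t):
    return -x            # stand-in for the conditional velocity estimate

def v_uncond(x, t):
    return -0.5 * x      # stand-in for the zeroed-condition estimate

guidance_scale = 2.0
num_steps = 10
x = torch.randn(4)                           # start from noise
t_span = torch.linspace(0, 1, num_steps + 1) # same grid as inference_codes
for i in range(1, len(t_span)):
    t, dt = t_span[i - 1], t_span[i] - t_span[i - 1]
    # same combination rule as solve_euler: uncond + s * (cond - uncond)
    v = v_uncond(x, t) + guidance_scale * (v_cond(x, t) - v_uncond(x, t))
    x = x + dt * v
print(x)
```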
src/heartlib/heartcodec/models/sq_codec.py
ADDED
@@ -0,0 +1,539 @@

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils import remove_weight_norm
from torch.autograd.function import InplaceFunction


def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)


# Scripting this brings model speed up 1.4x
@torch.jit.script
def snake(x, alpha):
    shape = x.shape
    x = x.reshape(shape[0], shape[1], -1)
    x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
    x = x.reshape(shape)
    return x


class Snake1d(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1, channels, 1))

    def forward(self, x):
        return snake(x, self.alpha)


class Conv1d(nn.Conv1d):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        padding_mode: str = "zeros",
        bias: bool = True,
        padding=None,
        causal: bool = False,
        w_init_gain=None,
    ):
        self.causal = causal
        if padding is None:
            if causal:
                padding = 0
                self.left_padding = dilation * (kernel_size - 1)
            else:
                padding = get_padding(kernel_size, dilation)
        super(Conv1d, self).__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            padding_mode=padding_mode,
            bias=bias,
        )
        if w_init_gain is not None:
            torch.nn.init.xavier_uniform_(
                self.weight, gain=torch.nn.init.calculate_gain(w_init_gain)
            )

    def forward(self, x):
        if self.causal:
            # pad on the left only, so no output depends on future samples
            x = F.pad(x.unsqueeze(2), (self.left_padding, 0, 0, 0)).squeeze(2)

        return super(Conv1d, self).forward(x)


class ConvTranspose1d(nn.ConvTranspose1d):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        output_padding: int = 0,
        groups: int = 1,
        bias: bool = True,
        dilation: int = 1,
        padding=None,
        padding_mode: str = "zeros",
        causal: bool = False,
    ):
        if padding is None:
            padding = 0 if causal else (kernel_size - stride) // 2
        if causal:
            assert padding == 0, "padding is not allowed in causal ConvTranspose1d."
            assert (
                kernel_size == 2 * stride
            ), "kernel_size must equal 2*stride in causal ConvTranspose1d."
        super(ConvTranspose1d, self).__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            groups=groups,
            bias=bias,
            dilation=dilation,
            padding_mode=padding_mode,
        )
        self.causal = causal
        self.stride = stride

    def forward(self, x):
        x = super(ConvTranspose1d, self).forward(x)
        if self.causal:
            x = x[:, :, : -self.stride]
        return x


class PreProcessor(nn.Module):
    def __init__(self, n_in, n_out, num_samples, kernel_size=7, causal=False):
        super(PreProcessor, self).__init__()
        self.pooling = torch.nn.AvgPool1d(kernel_size=num_samples)
        self.conv = Conv1d(n_in, n_out, kernel_size=kernel_size, causal=causal)
        self.activation = nn.PReLU()

    def forward(self, x):
        output = self.activation(self.conv(x))
        output = self.pooling(output)
        return output


class PostProcessor(nn.Module):
    def __init__(self, n_in, n_out, num_samples, kernel_size=7, causal=False):
        super(PostProcessor, self).__init__()
        self.num_samples = num_samples
        self.conv = Conv1d(n_in, n_out, kernel_size=kernel_size, causal=causal)
        self.activation = nn.PReLU()

    def forward(self, x):
        # repeat each frame num_samples times to undo the encoder pooling
        x = torch.transpose(x, 1, 2)
        B, T, C = x.size()
        x = x.repeat(1, 1, self.num_samples).view(B, -1, C)
        x = torch.transpose(x, 1, 2)
        output = self.activation(self.conv(x))
        return output


class ResidualUnit(nn.Module):
    def __init__(self, n_in, n_out, dilation, res_kernel_size=7, causal=False):
        super(ResidualUnit, self).__init__()
        self.conv1 = weight_norm(
            Conv1d(
                n_in,
                n_out,
                kernel_size=res_kernel_size,
                dilation=dilation,
                causal=causal,
            )
        )
        self.conv2 = weight_norm(Conv1d(n_in, n_out, kernel_size=1, causal=causal))
        self.activation1 = nn.PReLU()
        self.activation2 = nn.PReLU()

    def forward(self, x):
        output = self.activation1(self.conv1(x))
        output = self.activation2(self.conv2(output))
        return output + x


class ResEncoderBlock(nn.Module):
    def __init__(
        self, n_in, n_out, stride, down_kernel_size, res_kernel_size=7, causal=False
    ):
        super(ResEncoderBlock, self).__init__()
        # dilations grow as 1, 3, 5, 7, 9 to widen the receptive field
        self.convs = nn.ModuleList(
            [
                ResidualUnit(
                    n_in if d == 1 else n_out // 2,
                    n_out // 2,
                    dilation=d,
                    res_kernel_size=res_kernel_size,
                    causal=causal,
                )
                for d in (1, 3, 5, 7, 9)
            ]
        )

        self.down_conv = DownsampleLayer(
            n_in, n_out, down_kernel_size, stride=stride, causal=causal
        )

    def forward(self, x):
        for conv in self.convs:
            x = conv(x)
        x = self.down_conv(x)
        return x


class ResDecoderBlock(nn.Module):
    def __init__(
        self, n_in, n_out, stride, up_kernel_size, res_kernel_size=7, causal=False
    ):
        super(ResDecoderBlock, self).__init__()
        self.up_conv = UpsampleLayer(
            n_in,
            n_out,
            kernel_size=up_kernel_size,
            stride=stride,
            causal=causal,
            activation=None,
        )

        # dilations 1, 3, 5, 7, 9, mirroring the encoder blocks
        self.convs = nn.ModuleList(
            [
                ResidualUnit(
                    n_out,
                    n_out,
                    dilation=d,
                    res_kernel_size=res_kernel_size,
                    causal=causal,
                )
                for d in (1, 3, 5, 7, 9)
            ]
        )

    def forward(self, x):
        x = self.up_conv(x)
        for conv in self.convs:
            x = conv(x)
        return x


class DownsampleLayer(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        causal: bool = False,
        activation=nn.PReLU(),
        use_weight_norm: bool = True,
        pooling: bool = False,
    ):
        super(DownsampleLayer, self).__init__()
        self.pooling = pooling
        self.stride = stride
        self.activation = nn.PReLU()  # note: always PReLU; the `activation` arg is unused here
        self.use_weight_norm = use_weight_norm
        if pooling:
            self.layer = Conv1d(in_channels, out_channels, kernel_size, causal=causal)
            self.pooling = nn.AvgPool1d(kernel_size=stride)
        else:
            self.layer = Conv1d(
                in_channels, out_channels, kernel_size, stride=stride, causal=causal
            )
        if use_weight_norm:
            self.layer = weight_norm(self.layer)

    def forward(self, x):
        x = self.layer(x)
        x = self.activation(x) if self.activation is not None else x
        if self.pooling:
            x = self.pooling(x)
        return x

    def remove_weight_norm(self):
        if self.use_weight_norm:
            remove_weight_norm(self.layer)


class UpsampleLayer(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        causal: bool = False,
        activation=nn.PReLU(),
        use_weight_norm: bool = True,
        repeat: bool = False,
    ):
        super(UpsampleLayer, self).__init__()
        self.repeat = repeat
        self.stride = stride
        self.activation = activation
        self.use_weight_norm = use_weight_norm
        if repeat:
            self.layer = Conv1d(in_channels, out_channels, kernel_size, causal=causal)
        else:
            self.layer = ConvTranspose1d(
                in_channels, out_channels, kernel_size, stride=stride, causal=causal
            )
        if use_weight_norm:
            self.layer = weight_norm(self.layer)

    def forward(self, x):
        x = self.layer(x)
        x = self.activation(x) if self.activation is not None else x
        if self.repeat:
            x = torch.transpose(x, 1, 2)
            B, T, C = x.size()
            x = x.repeat(1, 1, self.stride).view(B, -1, C)
            x = torch.transpose(x, 1, 2)
        return x

    def remove_weight_norm(self):
        if self.use_weight_norm:
            remove_weight_norm(self.layer)


class round_func9(InplaceFunction):
    @staticmethod
    def forward(ctx, input):
        ctx.input = input
        return torch.round(9 * input) / 9

    @staticmethod
    def backward(ctx, grad_output):
        # straight-through estimator: gradients pass through the rounding
        grad_input = grad_output.clone()
        return grad_input


class ScalarModel(nn.Module):
    def __init__(
        self,
        num_bands,
        sample_rate,
        causal,
        num_samples,
        downsample_factors,
        downsample_kernel_sizes,
        upsample_factors,
        upsample_kernel_sizes,
        latent_hidden_dim,
        default_kernel_size,
        delay_kernel_size,
        init_channel,
        res_kernel_size,
        mode="pre_proj",
    ):
        super(ScalarModel, self).__init__()
        self.encoder = []
        self.decoder = []
        self.vq = round_func9()  # scalar quantizer, 9 levels per unit interval
        self.mode = mode
        # Encoder parts
        self.encoder.append(
            weight_norm(
                Conv1d(
                    num_bands,
                    init_channel,
                    kernel_size=default_kernel_size,
                    causal=causal,
                )
            )
        )
        if num_samples > 1:
            # Downsampling
            self.encoder.append(
                PreProcessor(
                    init_channel,
                    init_channel,
                    num_samples,
                    kernel_size=default_kernel_size,
                    causal=causal,
                )
            )
        for i, down_factor in enumerate(downsample_factors):
            self.encoder.append(
                ResEncoderBlock(
                    init_channel * np.power(2, i),
                    init_channel * np.power(2, i + 1),
                    down_factor,
                    downsample_kernel_sizes[i],
                    res_kernel_size,
                    causal=causal,
                )
            )
        self.encoder.append(
            weight_norm(
                Conv1d(
                    init_channel * np.power(2, len(downsample_factors)),
                    latent_hidden_dim,
                    kernel_size=default_kernel_size,
                    causal=causal,
                )
            )
        )
        # Decoder
        # look ahead
        self.decoder.append(
            weight_norm(
                Conv1d(
                    latent_hidden_dim,
                    init_channel * np.power(2, len(upsample_factors)),
                    kernel_size=delay_kernel_size,
                )
            )
        )
        for i, upsample_factor in enumerate(upsample_factors):
            self.decoder.append(
                ResDecoderBlock(
                    init_channel * np.power(2, len(upsample_factors) - i),
                    init_channel * np.power(2, len(upsample_factors) - i - 1),
                    upsample_factor,
                    upsample_kernel_sizes[i],
                    res_kernel_size,
                    causal=causal,
                )
            )
        if num_samples > 1:
            self.decoder.append(
                PostProcessor(
                    init_channel,
                    init_channel,
                    num_samples,
                    kernel_size=default_kernel_size,
                    causal=causal,
                )
            )
        self.decoder.append(
            weight_norm(
                Conv1d(
                    init_channel,
                    num_bands,
                    kernel_size=default_kernel_size,
                    causal=causal,
                )
            )
        )
        self.encoder = nn.ModuleList(self.encoder)
        self.decoder = nn.ModuleList(self.decoder)

    def forward(self, x):
        for i, layer in enumerate(self.encoder):
            if i != len(self.encoder) - 1:
                x = layer(x)
            else:
                x = F.tanh(layer(x))  # bound latents to [-1, 1]
        x = self.vq.apply(x)  # scalar quantization
        for i, layer in enumerate(self.decoder):
            x = layer(x)
        return x

    def inference(self, x):
        for i, layer in enumerate(self.encoder):
            if i != len(self.encoder) - 1:
                x = layer(x)
            else:
                x = F.tanh(layer(x))  # bound latents to [-1, 1]

        emb = x
        emb_quant = self.vq.apply(emb)
        x = emb_quant
        for i, layer in enumerate(self.decoder):
            x = layer(x)
        return emb, emb_quant, x

    def encode(self, x):
        for i, layer in enumerate(self.encoder):
            if i != len(self.encoder) - 1:
                x = layer(x)
            else:
                x = F.tanh(layer(x))  # bound latents to [-1, 1]

        emb = x
        emb_quant = self.vq.apply(emb)  # computed for parity; only emb is returned
        return emb

    def decode(self, x):
        # quantize so the prediction follows the training-time distribution
        x = self.vq.apply(x)
        for i, layer in enumerate(self.decoder):
            x = layer(x)
        return x
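`round_func9` is a straight-through scalar quantizer: the forward pass snaps the tanh-bounded latents onto the grid k/9 (19 levels across [-1, 1]), while the backward pass copies gradients through unchanged. A quick check of both behaviors, re-declaring the function so the snippet is self-contained:

```python
import torch
from torch.autograd.function import InplaceFunction

class RoundSTE(InplaceFunction):  # same logic as round_func9 above
    @staticmethod
    def forward(ctx, input):
        return torch.round(9 * input) / 9

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.clone()

x = torch.tensor([0.12, -0.40], requires_grad=True)
y = RoundSTE.apply(torch.tanh(x))
print(y * 9)       # integers: latents snapped to multiples of 1/9
y.sum().backward()
print(x.grad)      # equals d tanh/dx: the rounding is gradient-transparent
```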
src/heartlib/heartcodec/models/transformer.py
ADDED
@@ -0,0 +1,501 @@

import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        var = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(var + self.eps)
        return self.weight * x


class RotaryEmbedding(nn.Module):
    def __init__(self, dim: int, base: int = 10000):
        super().__init__()
        self.dim = dim
        self.base = base
        self._cache = {}

    def get_sin_cos(self, seq_len: int, device, dtype):
        key = (seq_len, device, dtype)
        cached = self._cache.get(key, None)
        if cached is not None and cached[0].device == device:
            return cached
        inv_freq = 1.0 / (
            self.base
            ** (torch.arange(0, self.dim, 2, device=device, dtype=dtype) / self.dim)
        )
        t = torch.arange(seq_len, device=device, dtype=dtype)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        sin = freqs.sin()
        cos = freqs.cos()
        self._cache[key] = (sin, cos)
        return sin, cos

    def apply_rotary(
        self, x: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor
    ) -> torch.Tensor:
        x1, x2 = x[..., : self.dim // 2], x[..., self.dim // 2 : self.dim]
        # Interleave sin/cos across pairs
        x_rot = torch.stack((-x2, x1), dim=-1).reshape_as(x[..., : self.dim])
        return (x[..., : self.dim] * cos.unsqueeze(-1)).reshape_as(
            x[..., : self.dim]
        ) + (x_rot * sin.unsqueeze(-1)).reshape_as(x[..., : self.dim])


class LlamaAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        n_heads: int,
        head_dim: int,
        bias: bool = False,
        dropout: float = 0.0,
        rope_dim: Optional[int] = None,
        cross_attention_dim: Optional[int] = None,
        use_sdpa: bool = True,
    ):
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.inner_dim = n_heads * head_dim
        self.cross_attention_dim = cross_attention_dim
        self.q_proj = nn.Linear(dim, self.inner_dim, bias=bias)
        k_in = dim if cross_attention_dim is None else cross_attention_dim
        self.k_proj = nn.Linear(k_in, self.inner_dim, bias=bias)
        self.v_proj = nn.Linear(k_in, self.inner_dim, bias=bias)
        self.o_proj = nn.Linear(self.inner_dim, dim, bias=bias)
        self.dropout = dropout
        self.rope_dim = rope_dim if rope_dim is not None else head_dim
        self.rope = RotaryEmbedding(self.rope_dim)
        self.use_sdpa = use_sdpa
        self._has_sdpa = hasattr(F, "scaled_dot_product_attention")

    def _shape(self, x: torch.Tensor, b: int, t: int) -> torch.Tensor:
        return x.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)

    def forward(
        self,
        x: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        b, t, c = x.shape
        q = self._shape(self.q_proj(x), b, t)
        if encoder_hidden_states is None:
            k = self._shape(self.k_proj(x), b, t)
            v = self._shape(self.v_proj(x), b, t)
        else:
            bt, tk, ck = encoder_hidden_states.shape
            k = self._shape(self.k_proj(encoder_hidden_states), b, tk)
            v = self._shape(self.v_proj(encoder_hidden_states), b, tk)

        # RoPE on first rope_dim of head_dim
        rope_dim = min(self.rope_dim, self.head_dim)
        seq_len_for_rope = k.shape[-2]
        sin, cos = self.rope.get_sin_cos(
            seq_len_for_rope, device=x.device, dtype=x.dtype
        )

        def apply_rope_vec(tensor):
            head = tensor[..., :rope_dim]
            tail = tensor[..., rope_dim:]
            b, h, tt, _ = head.shape
            head = head.view(b, h, tt, rope_dim // 2, 2)
            sin_ = sin.view(1, 1, tt, rope_dim // 2, 1)
            cos_ = cos.view(1, 1, tt, rope_dim // 2, 1)
            x1 = head[..., 0:1]
            x2 = head[..., 1:2]
            rot = torch.cat(
                [x1 * cos_ - x2 * sin_, x1 * sin_ + x2 * cos_], dim=-1
            ).view(b, h, tt, rope_dim)
            return torch.cat([rot, tail], dim=-1)

        q = apply_rope_vec(q)
        k = apply_rope_vec(k)

        # Prefer PyTorch SDPA (can enable FlashAttention kernel on supported GPUs)
        if self.use_sdpa and self._has_sdpa:
            s = k.shape[-2]
            attn_mask_sdpa = None
            if attention_mask is not None:
                m = attention_mask

                if m.dim() == 2 and m.shape == (b, s):  # [b, s]
                    m = m[:, None, None, :]  # [b,1,1,s]
                elif m.dim() == 3 and m.shape[-2] == 1:  # [b,1,s]
                    m = m[:, None, :, :]  # [b,1,1,s]
                elif m.dim() == 3 and m.shape[-2] == t:  # [b,t,s]
                    m = m[:, None, :, :]  # [b,1,t,s]
                elif m.dim() == 4 and m.shape[1] == 1:  # [b,1,t,s] or [b,1,1,s]
                    pass
                attn_mask_sdpa = m

            out = F.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=attn_mask_sdpa,
                dropout_p=self.dropout if self.training else 0.0,
dropout_p=self.dropout if self.training else 0.0,
|
| 149 |
+
is_causal=False,
|
| 150 |
+
)
|
| 151 |
+
out = out.transpose(1, 2).contiguous().view(b, t, self.inner_dim)
|
| 152 |
+
return self.o_proj(out)
|
| 153 |
+
else:
|
| 154 |
+
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(
|
| 155 |
+
self.head_dim
|
| 156 |
+
)
|
| 157 |
+
if attention_mask is not None:
|
| 158 |
+
attn_scores = attn_scores + attention_mask
|
| 159 |
+
attn = attn_scores.softmax(dim=-1)
|
| 160 |
+
attn = F.dropout(attn, p=self.dropout, training=self.training)
|
| 161 |
+
out = torch.matmul(attn, v)
|
| 162 |
+
out = out.transpose(1, 2).contiguous().view(b, t, self.inner_dim)
|
| 163 |
+
return self.o_proj(out)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class LlamaMLP(nn.Module):
|
| 167 |
+
def __init__(
|
| 168 |
+
self,
|
| 169 |
+
dim: int,
|
| 170 |
+
hidden_dim: Optional[int] = None,
|
| 171 |
+
multiple_of: int = 256,
|
| 172 |
+
dropout: float = 0.0,
|
| 173 |
+
):
|
| 174 |
+
super().__init__()
|
| 175 |
+
hidden_dim = hidden_dim or 4 * dim
|
| 176 |
+
# align to multiple_of like Llama
|
| 177 |
+
hidden_dim = int(2 * hidden_dim / 3)
|
| 178 |
+
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
| 179 |
+
self.gate = nn.Linear(dim, hidden_dim, bias=False)
|
| 180 |
+
self.up = nn.Linear(dim, hidden_dim, bias=False)
|
| 181 |
+
self.down = nn.Linear(hidden_dim, dim, bias=False)
|
| 182 |
+
self.dropout = dropout
|
| 183 |
+
|
| 184 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 185 |
+
x = F.silu(self.gate(x)) * self.up(x)
|
| 186 |
+
x = F.dropout(x, p=self.dropout, training=self.training)
|
| 187 |
+
return self.down(x)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
class LlamaTransformerBlock(nn.Module):
|
| 191 |
+
def __init__(
|
| 192 |
+
self,
|
| 193 |
+
dim: int,
|
| 194 |
+
n_heads: int,
|
| 195 |
+
head_dim: int,
|
| 196 |
+
mlp_multiple_of: int = 256,
|
| 197 |
+
dropout: float = 0.0,
|
| 198 |
+
attention_bias: bool = False,
|
| 199 |
+
cross_attention_dim: Optional[int] = None,
|
| 200 |
+
use_ada_layer_norm_single: bool = False,
|
| 201 |
+
):
|
| 202 |
+
super().__init__()
|
| 203 |
+
self.attn_norm = RMSNorm(dim, 1e-6)
|
| 204 |
+
self.attn = LlamaAttention(
|
| 205 |
+
dim,
|
| 206 |
+
n_heads,
|
| 207 |
+
head_dim,
|
| 208 |
+
bias=attention_bias,
|
| 209 |
+
dropout=dropout,
|
| 210 |
+
rope_dim=head_dim,
|
| 211 |
+
cross_attention_dim=None,
|
| 212 |
+
)
|
| 213 |
+
self.cross_attn = None
|
| 214 |
+
if cross_attention_dim is not None:
|
| 215 |
+
self.cross_attn_norm = RMSNorm(dim, 1e-6)
|
| 216 |
+
self.cross_attn = LlamaAttention(
|
| 217 |
+
dim,
|
| 218 |
+
n_heads,
|
| 219 |
+
head_dim,
|
| 220 |
+
bias=attention_bias,
|
| 221 |
+
dropout=dropout,
|
| 222 |
+
rope_dim=head_dim,
|
| 223 |
+
cross_attention_dim=cross_attention_dim,
|
| 224 |
+
)
|
| 225 |
+
self.mlp_norm = RMSNorm(dim, 1e-6)
|
| 226 |
+
self.mlp = LlamaMLP(dim, multiple_of=mlp_multiple_of, dropout=dropout)
|
| 227 |
+
self.use_ada_layer_norm_single = use_ada_layer_norm_single
|
| 228 |
+
if self.use_ada_layer_norm_single:
|
| 229 |
+
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
|
| 230 |
+
|
| 231 |
+
def forward(
|
| 232 |
+
self,
|
| 233 |
+
x: torch.Tensor,
|
| 234 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
| 235 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 236 |
+
timestep: Optional[torch.Tensor] = None,
|
| 237 |
+
) -> torch.Tensor:
|
| 238 |
+
if self.use_ada_layer_norm_single:
|
| 239 |
+
batch_size = x.shape[0]
|
| 240 |
+
# timestep: [B, 6*D]
|
| 241 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
| 242 |
+
self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
|
| 243 |
+
).chunk(6, dim=1)
|
| 244 |
+
|
| 245 |
+
# Self-Attention with modulation and gating
|
| 246 |
+
norm_hidden_states = self.attn_norm(x)
|
| 247 |
+
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
|
| 248 |
+
h = self.attn(norm_hidden_states, attention_mask=attention_mask)
|
| 249 |
+
h = gate_msa * h
|
| 250 |
+
x = x + h
|
| 251 |
+
|
| 252 |
+
# MLP with modulation and gating
|
| 253 |
+
norm_hidden_states = self.mlp_norm(x)
|
| 254 |
+
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
|
| 255 |
+
h = self.mlp(norm_hidden_states)
|
| 256 |
+
h = gate_mlp * h
|
| 257 |
+
x = x + h
|
| 258 |
+
return x
|
| 259 |
+
else:
|
| 260 |
+
h = self.attn(self.attn_norm(x), attention_mask=attention_mask)
|
| 261 |
+
x = x + h
|
| 262 |
+
h = self.mlp(self.mlp_norm(x))
|
| 263 |
+
x = x + h
|
| 264 |
+
return x
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
class ProjectLayer(nn.Module):
|
| 268 |
+
def __init__(self, hidden_size, filter_size, kernel_size=1, dropout=0.0):
|
| 269 |
+
super().__init__()
|
| 270 |
+
self.kernel_size = kernel_size
|
| 271 |
+
self.dropout = dropout
|
| 272 |
+
self.ffn_1 = nn.Conv1d(
|
| 273 |
+
hidden_size, filter_size, kernel_size, padding=kernel_size // 2
|
| 274 |
+
)
|
| 275 |
+
self.ffn_2 = nn.Linear(filter_size, filter_size)
|
| 276 |
+
|
| 277 |
+
def forward(self, x):
|
| 278 |
+
x = self.ffn_1(x.transpose(1, 2)).transpose(1, 2)
|
| 279 |
+
x = x * self.kernel_size**-0.5
|
| 280 |
+
x = self.ffn_2(x)
|
| 281 |
+
return x
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
class LlamaTransformer(nn.Module):
|
| 285 |
+
def __init__(
|
| 286 |
+
self,
|
| 287 |
+
num_attention_heads: int,
|
| 288 |
+
attention_head_dim: int,
|
| 289 |
+
in_channels: int,
|
| 290 |
+
out_channels: int,
|
| 291 |
+
num_layers: int = 12,
|
| 292 |
+
num_layers_2: int = 2,
|
| 293 |
+
dropout: float = 0.0,
|
| 294 |
+
cross_attention_dim: Optional[int] = None,
|
| 295 |
+
norm_type: str = "layer_norm",
|
| 296 |
+
):
|
| 297 |
+
super().__init__()
|
| 298 |
+
inner_dim = num_attention_heads * attention_head_dim
|
| 299 |
+
inner_dim_2 = inner_dim * 2
|
| 300 |
+
self.in_channels = in_channels
|
| 301 |
+
self.out_channels = out_channels
|
| 302 |
+
self.inner_dim = inner_dim
|
| 303 |
+
self.inner_dim_2 = inner_dim_2
|
| 304 |
+
self.dropout = dropout
|
| 305 |
+
|
| 306 |
+
self.proj_in = ProjectLayer(in_channels, inner_dim, kernel_size=3)
|
| 307 |
+
|
| 308 |
+
use_ada_single = norm_type == "ada_norm_single"
|
| 309 |
+
self.transformer_blocks = nn.ModuleList(
|
| 310 |
+
[
|
| 311 |
+
LlamaTransformerBlock(
|
| 312 |
+
dim=inner_dim,
|
| 313 |
+
n_heads=num_attention_heads,
|
| 314 |
+
head_dim=attention_head_dim,
|
| 315 |
+
dropout=dropout,
|
| 316 |
+
attention_bias=False,
|
| 317 |
+
cross_attention_dim=cross_attention_dim,
|
| 318 |
+
use_ada_layer_norm_single=use_ada_single,
|
| 319 |
+
)
|
| 320 |
+
for _ in range(num_layers)
|
| 321 |
+
]
|
| 322 |
+
)
|
| 323 |
+
|
| 324 |
+
self.transformer_blocks_2 = nn.ModuleList(
|
| 325 |
+
[
|
| 326 |
+
LlamaTransformerBlock(
|
| 327 |
+
dim=inner_dim_2,
|
| 328 |
+
n_heads=num_attention_heads,
|
| 329 |
+
head_dim=attention_head_dim * 2,
|
| 330 |
+
dropout=dropout,
|
| 331 |
+
attention_bias=False,
|
| 332 |
+
cross_attention_dim=cross_attention_dim,
|
| 333 |
+
use_ada_layer_norm_single=use_ada_single,
|
| 334 |
+
)
|
| 335 |
+
for _ in range(num_layers_2)
|
| 336 |
+
]
|
| 337 |
+
)
|
| 338 |
+
|
| 339 |
+
self.connection_proj = ProjectLayer(
|
| 340 |
+
in_channels + inner_dim, inner_dim_2, kernel_size=3
|
| 341 |
+
)
|
| 342 |
+
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
|
| 343 |
+
self.norm_out_2 = nn.LayerNorm(inner_dim_2, elementwise_affine=False, eps=1e-6)
|
| 344 |
+
self.scale_shift_table = nn.Parameter(
|
| 345 |
+
torch.randn(2, inner_dim) / inner_dim**0.5
|
| 346 |
+
)
|
| 347 |
+
self.scale_shift_table_2 = nn.Parameter(
|
| 348 |
+
torch.randn(2, inner_dim_2) / inner_dim_2**0.5
|
| 349 |
+
)
|
| 350 |
+
self.proj_out = ProjectLayer(inner_dim_2, out_channels, kernel_size=3)
|
| 351 |
+
self.adaln_single = AdaLayerNormSingleFlow(inner_dim)
|
| 352 |
+
self.adaln_single_2 = AdaLayerNormSingleFlow(inner_dim_2)
|
| 353 |
+
|
| 354 |
+
def forward(
|
| 355 |
+
self,
|
| 356 |
+
hidden_states: torch.Tensor,
|
| 357 |
+
timestep: Optional[torch.LongTensor] = None,
|
| 358 |
+
):
|
| 359 |
+
s = self.proj_in(hidden_states)
|
| 360 |
+
|
| 361 |
+
embedded_timestep = None
|
| 362 |
+
timestep_mod = None
|
| 363 |
+
if self.adaln_single is not None and timestep is not None:
|
| 364 |
+
batch_size = s.shape[0]
|
| 365 |
+
timestep_mod, embedded_timestep = self.adaln_single(
|
| 366 |
+
timestep, hidden_dtype=s.dtype
|
| 367 |
+
)
|
| 368 |
+
for blk in self.transformer_blocks:
|
| 369 |
+
s = blk(s, timestep=timestep_mod)
|
| 370 |
+
|
| 371 |
+
if embedded_timestep is None:
|
| 372 |
+
embedded_timestep = torch.zeros(
|
| 373 |
+
s.size(0), s.size(-1), device=s.device, dtype=s.dtype
|
| 374 |
+
)
|
| 375 |
+
|
| 376 |
+
shift, scale = (
|
| 377 |
+
self.scale_shift_table[None] + embedded_timestep[:, None]
|
| 378 |
+
).chunk(2, dim=1)
|
| 379 |
+
s = self.norm_out(s)
|
| 380 |
+
s = s * (1 + scale) + shift
|
| 381 |
+
|
| 382 |
+
x = torch.cat([hidden_states, s], dim=-1)
|
| 383 |
+
x = self.connection_proj(x)
|
| 384 |
+
|
| 385 |
+
embedded_timestep_2 = None
|
| 386 |
+
timestep_mod_2 = None
|
| 387 |
+
if self.adaln_single_2 is not None and timestep is not None:
|
| 388 |
+
batch_size = x.shape[0]
|
| 389 |
+
timestep_mod_2, embedded_timestep_2 = self.adaln_single_2(
|
| 390 |
+
timestep, hidden_dtype=x.dtype
|
| 391 |
+
)
|
| 392 |
+
for blk in self.transformer_blocks_2:
|
| 393 |
+
x = blk(x, timestep=timestep_mod_2)
|
| 394 |
+
|
| 395 |
+
if embedded_timestep_2 is None:
|
| 396 |
+
embedded_timestep_2 = torch.zeros(
|
| 397 |
+
x.size(0), x.size(-1), device=x.device, dtype=x.dtype
|
| 398 |
+
)
|
| 399 |
+
|
| 400 |
+
shift_2, scale_2 = (
|
| 401 |
+
self.scale_shift_table_2[None] + embedded_timestep_2[:, None]
|
| 402 |
+
).chunk(2, dim=1)
|
| 403 |
+
x = self.norm_out_2(x)
|
| 404 |
+
x = x * (1 + scale_2) + shift_2
|
| 405 |
+
|
| 406 |
+
out = self.proj_out(x)
|
| 407 |
+
|
| 408 |
+
return out
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
class PixArtAlphaCombinedFlowEmbeddings(nn.Module):
|
| 412 |
+
def __init__(self, embedding_dim: int, size_emb_dim: int):
|
| 413 |
+
super().__init__()
|
| 414 |
+
self.flow_t_size = 512
|
| 415 |
+
self.outdim = size_emb_dim
|
| 416 |
+
self.timestep_embedder = TimestepEmbedding(
|
| 417 |
+
in_channels=self.flow_t_size, time_embed_dim=embedding_dim
|
| 418 |
+
)
|
| 419 |
+
|
| 420 |
+
def timestep_embedding(self, timesteps, max_period=10000, scale=1000):
|
| 421 |
+
half = self.flow_t_size // 2
|
| 422 |
+
freqs = torch.exp(
|
| 423 |
+
-math.log(max_period)
|
| 424 |
+
* torch.arange(start=0, end=half, device=timesteps.device)
|
| 425 |
+
/ half
|
| 426 |
+
).type(timesteps.type())
|
| 427 |
+
args = timesteps[:, None] * freqs[None] * scale
|
| 428 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
| 429 |
+
if self.flow_t_size % 2:
|
| 430 |
+
embedding = torch.cat(
|
| 431 |
+
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
|
| 432 |
+
)
|
| 433 |
+
return embedding
|
| 434 |
+
|
| 435 |
+
def forward(self, timestep, hidden_dtype):
|
| 436 |
+
timesteps_proj = self.timestep_embedding(timestep)
|
| 437 |
+
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))
|
| 438 |
+
conditioning = timesteps_emb
|
| 439 |
+
return conditioning
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
class AdaLayerNormSingleFlow(nn.Module):
|
| 443 |
+
def __init__(self, embedding_dim: int):
|
| 444 |
+
super().__init__()
|
| 445 |
+
self.emb = PixArtAlphaCombinedFlowEmbeddings(
|
| 446 |
+
embedding_dim, size_emb_dim=embedding_dim // 3
|
| 447 |
+
)
|
| 448 |
+
self.silu = nn.SiLU()
|
| 449 |
+
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
|
| 450 |
+
|
| 451 |
+
def forward(
|
| 452 |
+
self,
|
| 453 |
+
timestep: torch.Tensor,
|
| 454 |
+
hidden_dtype: Optional[torch.dtype] = None,
|
| 455 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 456 |
+
|
| 457 |
+
embedded_timestep = self.emb(timestep, hidden_dtype=hidden_dtype)
|
| 458 |
+
return self.linear(self.silu(embedded_timestep)), embedded_timestep
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
class TimestepEmbedding(nn.Module):
|
| 462 |
+
def __init__(self, in_channels: int, time_embed_dim: int):
|
| 463 |
+
super().__init__()
|
| 464 |
+
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
|
| 465 |
+
self.act = nn.SiLU()
|
| 466 |
+
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
|
| 467 |
+
|
| 468 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 469 |
+
x = self.linear_1(x)
|
| 470 |
+
x = self.act(x)
|
| 471 |
+
x = self.linear_2(x)
|
| 472 |
+
return x
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
class Timesteps(nn.Module):
|
| 476 |
+
def __init__(
|
| 477 |
+
self,
|
| 478 |
+
num_channels: int,
|
| 479 |
+
flip_sin_to_cos: bool = True,
|
| 480 |
+
downscale_freq_shift: float = 0,
|
| 481 |
+
):
|
| 482 |
+
super().__init__()
|
| 483 |
+
self.num_channels = num_channels
|
| 484 |
+
self.flip_sin_to_cos = flip_sin_to_cos
|
| 485 |
+
self.downscale_freq_shift = downscale_freq_shift
|
| 486 |
+
|
| 487 |
+
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
|
| 488 |
+
half_dim = self.num_channels // 2
|
| 489 |
+
exponent = (
|
| 490 |
+
-math.log(10000)
|
| 491 |
+
* torch.arange(0, half_dim, device=timesteps.device)
|
| 492 |
+
/ (half_dim - self.downscale_freq_shift)
|
| 493 |
+
)
|
| 494 |
+
emb = torch.exp(exponent)[None, :] * timesteps[:, None]
|
| 495 |
+
if self.flip_sin_to_cos:
|
| 496 |
+
emb = torch.cat([torch.cos(emb), torch.sin(emb)], dim=-1)
|
| 497 |
+
else:
|
| 498 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
|
| 499 |
+
if self.num_channels % 2 == 1:
|
| 500 |
+
emb = torch.nn.functional.pad(emb, (0, 1))
|
| 501 |
+
return emb
|
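
This Llama-style stack conditions on a flow-matching timestep via AdaLayerNormSingleFlow: a first stack at inner_dim, a skip-concat (connection_proj) into a second stack at 2x inner_dim. Note that LlamaMLP follows the Llama sizing rule, hidden = (2/3) * 4 * dim rounded up to multiple_of, so dim=3072 yields 8192. Below is a minimal smoke test; the sizes are illustrative assumptions, not the shipped HeartCodec configuration, and it assumes the package is installed (pip install -e .):

# Minimal sketch exercising LlamaTransformer with illustrative sizes.
import torch
from heartlib.heartcodec.models.transformer import LlamaTransformer

model = LlamaTransformer(
    num_attention_heads=4,
    attention_head_dim=32,
    in_channels=64,
    out_channels=64,
    num_layers=2,
    num_layers_2=1,
)
x = torch.randn(2, 16, 64)  # [batch, frames, channels]
t = torch.rand(2)           # one flow-matching timestep per batch item
y = model(x, timestep=t)
print(y.shape)              # torch.Size([2, 16, 64])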

src/heartlib/heartmula/configuration_heartmula.py ADDED
@@ -0,0 +1,23 @@
from transformers.configuration_utils import PretrainedConfig


class HeartMuLaConfig(PretrainedConfig):
    model_type = "heartmula"

    def __init__(
        self,
        backbone_flavor: str = "llama-3B",
        decoder_flavor: str = "llama-300M",
        text_vocab_size: int = 128256,
        audio_vocab_size: int = 8197,
        audio_num_codebooks: int = 8,
        muq_dim: int = 512,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.backbone_flavor = backbone_flavor
        self.decoder_flavor = decoder_flavor
        self.text_vocab_size = text_vocab_size
        self.audio_vocab_size = audio_vocab_size
        self.audio_num_codebooks = audio_num_codebooks
        self.muq_dim = muq_dim
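
A quick sanity check on these defaults, assuming the package is installed: the model keeps one flat audio embedding table, with codebook i's tokens offset by i * audio_vocab_size (see _embed_audio in modeling_heartmula.py below), so the table has audio_vocab_size * audio_num_codebooks rows:

from heartlib.heartmula.configuration_heartmula import HeartMuLaConfig

cfg = HeartMuLaConfig()
# Flat embedding table shared by all codebooks: 8197 * 8 = 65576 rows.
print(cfg.audio_vocab_size * cfg.audio_num_codebooks)  # 65576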

src/heartlib/heartmula/modeling_heartmula.py ADDED
@@ -0,0 +1,316 @@
import torch
import torch.nn as nn
from .configuration_heartmula import HeartMuLaConfig
from transformers.modeling_utils import PreTrainedModel
import torchtune
from torchtune.models import llama3_2


def llama3_2_3B() -> torchtune.modules.transformer.TransformerDecoder:
    return llama3_2.llama3_2(
        vocab_size=128_256,
        num_layers=28,
        num_heads=24,
        num_kv_heads=8,
        embed_dim=3072,
        max_seq_len=8192,
        intermediate_dim=8192,
        attn_dropout=0.0,
        norm_eps=1e-5,
        rope_base=500_000,
        scale_factor=32,
    )


def llama3_2_300M() -> torchtune.modules.transformer.TransformerDecoder:
    return llama3_2.llama3_2(
        vocab_size=128_256,
        num_layers=3,
        num_heads=8,
        num_kv_heads=4,
        embed_dim=3072,
        max_seq_len=2048,
        intermediate_dim=8192,
        attn_dropout=0.0,
        norm_eps=1e-5,
        rope_base=500_000,
        scale_factor=32,
    )


def llama3_2_7B() -> torchtune.modules.transformer.TransformerDecoder:
    return llama3_2.llama3_2(
        vocab_size=128_256,
        num_layers=32,
        num_heads=32,
        num_kv_heads=8,
        embed_dim=4096,
        max_seq_len=8192,
        intermediate_dim=14336,
        attn_dropout=0.0,
        norm_eps=1e-5,
        rope_base=500_000,
        scale_factor=32,
    )


def llama3_2_400M() -> torchtune.modules.transformer.TransformerDecoder:
    return llama3_2.llama3_2(
        vocab_size=128_256,
        num_layers=4,
        num_heads=8,
        num_kv_heads=4,
        embed_dim=3072,
        max_seq_len=2048,
        intermediate_dim=8192,
        attn_dropout=0.0,
        norm_eps=1e-5,
        rope_base=500_000,
        scale_factor=32,
    )  # Smaller num_heads / num_kv_heads ratio: improves accuracy but reduces efficiency.


FLAVORS = {
    "llama-3B": llama3_2_3B,
    "llama-300M": llama3_2_300M,
    "llama-7B": llama3_2_7B,
    "llama-400M": llama3_2_400M,
}


def _prepare_transformer(model):
    embed_dim = model.tok_embeddings.embedding_dim
    model.tok_embeddings = nn.Identity()
    model.output = nn.Identity()
    return model, embed_dim


def _create_causal_mask(seq_len: int, device: torch.device):
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))


def _index_causal_mask(mask: torch.Tensor, input_pos: torch.Tensor):
    r = mask[input_pos, :]
    return r


def _multinomial_sample_one_no_sync(
    probs,
):  # Does multinomial sampling without a cuda synchronization
    q = torch.empty_like(probs).exponential_(1)
    return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int)


def sample_topk(logits: torch.Tensor, topk: int, temperature: float):
    logits = logits / temperature

    filter_value: float = -float("Inf")
    indices_to_remove = logits < torch.topk(logits, topk)[0][..., -1, None]
    scores_processed = logits.masked_fill(indices_to_remove, filter_value)
    scores_processed = torch.nn.functional.log_softmax(scores_processed, dim=-1)
    probs = torch.nn.functional.softmax(scores_processed, dim=-1)

    sample_token = _multinomial_sample_one_no_sync(probs)
    return sample_token


class HeartMuLa(PreTrainedModel):
    config_class = HeartMuLaConfig

    def __init__(
        self,
        config: HeartMuLaConfig,
    ):
        super(HeartMuLa, self).__init__(config)

        self.config = config

        self.backbone, backbone_dim = _prepare_transformer(
            FLAVORS[config.backbone_flavor]()
        )
        self.decoder, decoder_dim = _prepare_transformer(
            FLAVORS[config.decoder_flavor]()
        )

        self.text_embeddings = nn.Embedding(config.text_vocab_size, backbone_dim)
        self.audio_embeddings = nn.Embedding(
            config.audio_vocab_size * config.audio_num_codebooks, backbone_dim
        )
        self.unconditional_text_embedding = nn.Embedding(1, backbone_dim)

        self.projection = nn.Linear(backbone_dim, decoder_dim, bias=False)
        self.codebook0_head = nn.Linear(
            backbone_dim, config.audio_vocab_size, bias=False
        )
        self.audio_head = nn.Parameter(
            torch.empty(
                config.audio_num_codebooks - 1, decoder_dim, config.audio_vocab_size
            )
        )
        self.muq_linear = nn.Linear(config.muq_dim, backbone_dim)
        self.post_init()

    def setup_caches(self, max_batch_size: int):
        dtype = next(self.parameters()).dtype
        device = next(self.parameters()).device

        try:
            self.reset_caches()
        except RuntimeError:
            pass

        with device:
            self.backbone.setup_caches(max_batch_size, dtype)
            self.decoder.setup_caches(
                max_batch_size,
                dtype,
                decoder_max_seq_len=self.config.audio_num_codebooks,
            )

        self.register_buffer(
            "backbone_causal_mask",
            _create_causal_mask(self.backbone.max_seq_len, device),
        )
        self.register_buffer(
            "decoder_causal_mask",
            _create_causal_mask(self.config.audio_num_codebooks, device),
        )

    def generate_frame(
        self,
        tokens: torch.Tensor,
        tokens_mask: torch.Tensor,
        input_pos: torch.Tensor,
        temperature: float,
        topk: int,
        cfg_scale: float,
        continuous_segments: torch.Tensor = None,
        starts=None,
    ) -> torch.Tensor:
        b, s, _ = tokens.size()

        assert self.backbone.caches_are_enabled(), "backbone caches are not enabled"
        curr_backbone_mask = _index_causal_mask(self.backbone_causal_mask, input_pos)

        uncond_mask = None
        if cfg_scale > 1.0 and b > 1:
            actual_B = b // 2
            uncond_mask = torch.cat(
                [
                    torch.zeros(actual_B, dtype=torch.bool, device=tokens.device),
                    torch.ones(actual_B, dtype=torch.bool, device=tokens.device),
                ]
            )

        embeds = self._embed_tokens(tokens, uncond_mask=uncond_mask)
        masked_embeds = embeds * tokens_mask.unsqueeze(-1)
        h = masked_embeds.sum(dim=2, dtype=embeds.dtype)  # merge
        if continuous_segments is not None:
            continuous_segments = self.muq_linear(continuous_segments)
            if uncond_mask is not None:
                uncond_embed = self.unconditional_text_embedding(
                    torch.zeros(1, device=tokens.device, dtype=torch.long)
                )
                mask_expanded = uncond_mask.view(b, 1).expand_as(continuous_segments)
                continuous_segments = torch.where(
                    mask_expanded, uncond_embed, continuous_segments
                )
            batch_indices = torch.arange(h.shape[0], device=h.device)
            h[batch_indices, starts] = continuous_segments
        h = self.backbone(h, input_pos=input_pos, mask=curr_backbone_mask)
        last_h = h[:, -1, :]  # the last frame
        c0_logits = self.codebook0_head(last_h)  # only predict the audio part

        if cfg_scale > 1.0 and b > 1 and (b % 2 == 0):
            actual_B = b // 2
            cond_logits = c0_logits[:actual_B, :]
            uncond_logits = c0_logits[actual_B:, :]
            guided_logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
            c0_sample = sample_topk(guided_logits, topk, temperature)
            c0_sample = c0_sample.repeat(
                2, 1
            )  # repeat to both branches to keep alignment
        else:
            c0_sample = sample_topk(c0_logits, topk, temperature)

        c0_embed = self._embed_audio(0, c0_sample)

        self.decoder.reset_caches()
        curr_h = torch.cat([last_h.unsqueeze(1), c0_embed], dim=1)
        curr_sample = c0_sample.clone()
        curr_pos = (
            torch.arange(0, curr_h.size(1), device=curr_h.device)
            .unsqueeze(0)
            .repeat(curr_h.size(0), 1)
        )
        curr_h = curr_h.to(embeds.dtype)
        for i in range(1, self.config.audio_num_codebooks):
            curr_decoder_mask = _index_causal_mask(self.decoder_causal_mask, curr_pos)
            decoder_h = self.decoder(
                self.projection(curr_h), input_pos=curr_pos, mask=curr_decoder_mask
            )
            ci_logits = torch.mm(decoder_h[:, -1, :], self.audio_head[i - 1])
            if cfg_scale > 1.0 and b > 1 and (b % 2 == 0):
                actual_B = b // 2
                cond_ci = ci_logits[:actual_B, :]
                uncond_ci = ci_logits[actual_B:, :]
                guided_ci = uncond_ci + (cond_ci - uncond_ci) * cfg_scale

                ci_sample = sample_topk(guided_ci, topk, temperature)
                ci_sample = ci_sample.repeat(2, 1)
            else:
                ci_sample = sample_topk(ci_logits, topk, temperature)
            ci_embed = self._embed_audio(i, ci_sample)
            curr_h = ci_embed
            curr_sample = torch.cat([curr_sample, ci_sample], dim=1)
            curr_pos = curr_pos[:, -1:] + 1

        return curr_sample

    def reset_caches(self):
        self.backbone.reset_caches()
        self.decoder.reset_caches()

    def _embed_local_audio(self, tokens):
        """the token from 0-30"""
        audio_tokens = tokens + (
            self.config.audio_vocab_size
            * torch.arange(self.config.audio_num_codebooks - 1, device=tokens.device)
        )
        audio_embeds = self.audio_embeddings(audio_tokens.view(-1)).reshape(
            tokens.size(0), tokens.size(1), self.config.audio_num_codebooks - 1, -1
        )
        return audio_embeds

    def _embed_audio(self, codebook: int, tokens: torch.Tensor) -> torch.Tensor:
        return self.audio_embeddings(tokens + codebook * self.config.audio_vocab_size)

    def _embed_tokens(
        self, tokens: torch.Tensor, uncond_mask: torch.Tensor | None
    ) -> torch.Tensor:
        B, S, _ = tokens.size()
        text_embeds = self.text_embeddings(tokens[:, :, -1])

        if uncond_mask is not None:
            uncond_text_embed = self.unconditional_text_embedding(
                torch.zeros(1, device=tokens.device, dtype=torch.long)
            )
            mask_expanded = uncond_mask.view(B, 1, 1).expand_as(text_embeds)
            text_embeds = torch.where(
                mask_expanded,
                uncond_text_embed,
                text_embeds,
            )

        text_embeds = text_embeds.unsqueeze(-2)

        audio_tokens = tokens[:, :, :-1] + (
            self.config.audio_vocab_size
            * torch.arange(self.config.audio_num_codebooks, device=tokens.device)
        )
        audio_embeds = self.audio_embeddings(audio_tokens.view(-1)).reshape(
            tokens.size(0), tokens.size(1), self.config.audio_num_codebooks, -1
        )
        return torch.cat([audio_embeds, text_embeds], dim=-2)
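
generate_frame applies classifier-free guidance by stacking a conditional and an unconditional branch in one batch and blending their logits before top-k sampling. A standalone sketch of that blend with made-up logits (the vocabulary size is taken from audio_vocab_size):

import torch
from heartlib.heartmula.modeling_heartmula import sample_topk

cond = torch.randn(1, 8197)    # conditional logits
uncond = torch.randn(1, 8197)  # unconditional logits from the duplicated batch half
cfg_scale = 1.5
# Same blend as generate_frame: push logits away from the unconditional branch.
guided = uncond + (cond - uncond) * cfg_scale
token = sample_topk(guided, topk=50, temperature=1.0)
print(token.shape)  # torch.Size([1, 1])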

src/heartlib/pipelines/__init__.py ADDED
File without changes

src/heartlib/pipelines/lyrics_transcription.py ADDED
@@ -0,0 +1,40 @@
from transformers.pipelines.automatic_speech_recognition import (
    AutomaticSpeechRecognitionPipeline,
)
from transformers.models.whisper.modeling_whisper import WhisperForConditionalGeneration
from transformers.models.whisper.processing_whisper import WhisperProcessor
import torch
import os


class HeartTranscriptorPipeline(AutomaticSpeechRecognitionPipeline):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @classmethod
    def from_pretrained(
        cls, pretrained_path: str, device: torch.device, dtype: torch.dtype
    ):
        if os.path.exists(
            hearttranscriptor_path := os.path.join(
                pretrained_path, "HeartTranscriptor-oss"
            )
        ):
            model = WhisperForConditionalGeneration.from_pretrained(
                hearttranscriptor_path, torch_dtype=dtype, low_cpu_mem_usage=True
            )
            processor = WhisperProcessor.from_pretrained(hearttranscriptor_path)
        else:
            raise FileNotFoundError(
                f"Expected to find checkpoint for HeartTranscriptor at {hearttranscriptor_path} but not found. Please check your folder {pretrained_path}."
            )

        return cls(
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            device=device,
            dtype=dtype,
            chunk_length_s=30,
            batch_size=16,
        )
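
A minimal usage sketch, assuming a local checkpoint folder ./pretrained containing HeartTranscriptor-oss/ and an input file vocals.mp3 (both hypothetical). Since this subclasses the standard transformers ASR pipeline, calling it on an audio path should return a dict with a "text" field:

import torch
from heartlib import HeartTranscriptorPipeline

pipe = HeartTranscriptorPipeline.from_pretrained(
    "./pretrained",              # hypothetical folder with HeartTranscriptor-oss/
    device=torch.device("cuda"),
    dtype=torch.float16,
)
result = pipe("vocals.mp3")      # hypothetical input; chunked into 30 s windows
print(result["text"])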

src/heartlib/pipelines/music_generation.py ADDED
@@ -0,0 +1,383 @@
from tokenizers import Tokenizer
from ..heartmula.modeling_heartmula import HeartMuLa
from ..heartcodec.modeling_heartcodec import HeartCodec
import torch
from typing import Dict, Any, Optional, Union
import os
from dataclasses import dataclass
from tqdm import tqdm
import torchaudio
import json
from contextlib import contextmanager
import gc


def _resolve_paths(pretrained_path: str, version: str):

    heartmula_path = os.path.join(pretrained_path, f"HeartMuLa-oss-{version}")
    heartcodec_path = os.path.join(pretrained_path, "HeartCodec-oss")
    tokenizer_path = os.path.join(pretrained_path, "tokenizer.json")
    gen_config_path = os.path.join(pretrained_path, "gen_config.json")

    if not os.path.exists(heartmula_path):
        raise FileNotFoundError(
            f"Expected to find checkpoint for HeartMuLa at {heartmula_path} but not found. Please check your folder {pretrained_path}."
        )
    if not os.path.exists(heartcodec_path):
        raise FileNotFoundError(
            f"Expected to find checkpoint for HeartCodec at {heartcodec_path} but not found. Please check your folder {pretrained_path}."
        )
    if not os.path.isfile(tokenizer_path):
        raise FileNotFoundError(
            f"Expected to find tokenizer.json for HeartMuLa at {tokenizer_path} but not found. Please check your folder {pretrained_path}."
        )
    if not os.path.isfile(gen_config_path):
        raise FileNotFoundError(
            f"Expected to find gen_config.json for HeartMuLa at {gen_config_path} but not found. Please check your folder {pretrained_path}."
        )

    return heartmula_path, heartcodec_path, tokenizer_path, gen_config_path


def _resolve_devices(
    device: Union[torch.device, Dict[str, torch.device]], lazy_load: bool
):
    if isinstance(device, torch.device):
        print(f"All model components will be loaded to device: {device}.")
        mula_device = device
        codec_device = device
    elif isinstance(device, dict):
        print("Model components will be loaded to devices as specified:")
        for k, v in device.items():
            print(f"  {k}: {v}")
        mula_device = device["mula"]
        codec_device = device["codec"]
    else:
        raise ValueError(
            "device must be either torch.device or Dict[str, torch.device]"
        )

    single_device = mula_device == codec_device
    if not single_device:
        print(
            "HeartMuLa and HeartCodec will be loaded to different devices. In this case, lazy_load is turned off."
        )
        lazy_load = False

    return mula_device, codec_device, lazy_load


@dataclass
class HeartMuLaGenConfig:
    text_bos_id: int = 128000
    text_eos_id: int = 128001
    audio_eos_id: int = 8193
    empty_id: int = 0

    @classmethod
    def from_file(cls, path: str):
        with open(path, encoding="utf-8") as fp:
            data = json.load(fp)
        return cls(**data)


class HeartMuLaGenPipeline:
    def __init__(
        self,
        heartmula_path: str,
        heartcodec_path: str,
        heartmula_device: torch.device,
        heartcodec_device: torch.device,
        heartmula_dtype: torch.dtype,
        heartcodec_dtype: torch.dtype,
        lazy_load: bool,
        muq_mulan: Optional[Any],
        text_tokenizer: Tokenizer,
        config: HeartMuLaGenConfig,
    ):

        self.muq_mulan = muq_mulan
        self.text_tokenizer = text_tokenizer
        self.config = config

        # Remain fixed here for simplicity.
        self._parallel_number = 8 + 1
        self._muq_dim = 512

        self.mula_dtype = heartmula_dtype
        self.mula_path = heartmula_path
        self.mula_device = heartmula_device
        self.codec_dtype = heartcodec_dtype
        self.codec_path = heartcodec_path
        self.codec_device = heartcodec_device

        self._mula: Optional[HeartMuLa] = None
        self._codec: Optional[HeartCodec] = None
        if not lazy_load:
            print(
                "You have set lazy_load = False. Loading HeartMuLa and HeartCodec onto device..."
            )
            self._mula = HeartMuLa.from_pretrained(
                self.mula_path,
                device_map=self.mula_device,
                torch_dtype=self.mula_dtype,
            )
            self._codec = HeartCodec.from_pretrained(
                self.codec_path,
                device_map=self.codec_device,
                torch_dtype=self.codec_dtype,
            )
        self.lazy_load = lazy_load

    @property
    def mula(self) -> HeartMuLa:
        if isinstance(self._mula, HeartMuLa):
            return self._mula
        self._mula = HeartMuLa.from_pretrained(
            self.mula_path,
            device_map=self.mula_device,
            torch_dtype=self.mula_dtype,
        )
        return self._mula

    @property
    def codec(self) -> HeartCodec:
        if isinstance(self._codec, HeartCodec):
            return self._codec
        self._codec = HeartCodec.from_pretrained(
            self.codec_path,
            device_map=self.codec_device,
            torch_dtype=self.codec_dtype,
        )
        return self._codec

    def _unload(self):
        if not self.lazy_load:
            return
        if isinstance(self._mula, HeartMuLa):
            print("You have set lazy_load=True. Unloading HeartMuLa from device.")
            print(
                f"CUDA memory before unloading: {torch.cuda.memory_allocated(self.mula_device) / 1024**3:.2f} GB"
            )
            del self._mula
            gc.collect()
            torch.cuda.empty_cache()
            print(
                f"CUDA memory after unloading: {torch.cuda.memory_allocated(self.mula_device) / 1024**3:.2f} GB"
            )
            self._mula = None
        if isinstance(self._codec, HeartCodec):
            print("You have set lazy_load=True. Unloading HeartCodec from device.")
            print(
                f"CUDA memory before unloading: {torch.cuda.memory_allocated(self.codec_device) / 1024**3:.2f} GB"
            )
            del self._codec
            gc.collect()
            torch.cuda.empty_cache()
            print(
                f"CUDA memory after unloading: {torch.cuda.memory_allocated(self.codec_device) / 1024**3:.2f} GB"
            )
            self._codec = None
        return

    def _sanitize_parameters(self, **kwargs):
        preprocess_kwargs = {"cfg_scale": kwargs.get("cfg_scale", 1.5)}
        forward_kwargs = {
            "max_audio_length_ms": kwargs.get("max_audio_length_ms", 120_000),
            "temperature": kwargs.get("temperature", 1.0),
            "topk": kwargs.get("topk", 50),
            "cfg_scale": kwargs.get("cfg_scale", 1.5),
        }
        postprocess_kwargs = {
            "save_path": kwargs.get("save_path", "output.mp3"),
        }
        return preprocess_kwargs, forward_kwargs, postprocess_kwargs

    def preprocess(self, inputs: Dict[str, Any], cfg_scale: float):

        # process tags
        tags = inputs["tags"]
        if os.path.isfile(tags):
            with open(tags, encoding="utf-8") as fp:
                tags = fp.read()
        assert isinstance(tags, str), f"tags must be a string, but got {type(tags)}"

        tags = tags.lower()
        # encapsulate with special <tag> and </tag> tokens
        if not tags.startswith("<tag>"):
            tags = f"<tag>{tags}"
        if not tags.endswith("</tag>"):
            tags = f"{tags}</tag>"

        tags_ids = self.text_tokenizer.encode(tags).ids
        if tags_ids[0] != self.config.text_bos_id:
            tags_ids = [self.config.text_bos_id] + tags_ids
        if tags_ids[-1] != self.config.text_eos_id:
            tags_ids = tags_ids + [self.config.text_eos_id]

        # process reference audio
        ref_audio = inputs.get("ref_audio", None)
        if ref_audio is not None:
            raise NotImplementedError("ref_audio is not supported yet.")
        muq_embed = torch.zeros([self._muq_dim], dtype=self.mula_dtype)
        muq_idx = len(tags_ids)

        # process lyrics
        lyrics = inputs["lyrics"]
        if os.path.isfile(lyrics):
            with open(lyrics, encoding="utf-8") as fp:
                lyrics = fp.read()
        assert isinstance(
            lyrics, str
        ), f"lyrics must be a string, but got {type(lyrics)}"
        lyrics = lyrics.lower()

        lyrics_ids = self.text_tokenizer.encode(lyrics).ids
        if lyrics_ids[0] != self.config.text_bos_id:
            lyrics_ids = [self.config.text_bos_id] + lyrics_ids
        if lyrics_ids[-1] != self.config.text_eos_id:
            lyrics_ids = lyrics_ids + [self.config.text_eos_id]

        # cat them together. tags, ref_audio, lyrics
        prompt_len = len(tags_ids) + 1 + len(lyrics_ids)

        tokens = torch.zeros([prompt_len, self._parallel_number], dtype=torch.long)
        tokens[: len(tags_ids), -1] = torch.tensor(tags_ids)
        tokens[len(tags_ids) + 1 :, -1] = torch.tensor(lyrics_ids)

        tokens_mask = torch.zeros_like(tokens, dtype=torch.bool)
        tokens_mask[:, -1] = True

        bs_size = 2 if cfg_scale != 1.0 else 1

        def _cfg_cat(tensor: torch.Tensor, cfg_scale: float):
            tensor = tensor.unsqueeze(0)
            if cfg_scale != 1.0:
                tensor = torch.cat([tensor, tensor], dim=0)
            return tensor

        return {
            "tokens": _cfg_cat(tokens, cfg_scale),
            "tokens_mask": _cfg_cat(tokens_mask, cfg_scale),
            "muq_embed": _cfg_cat(muq_embed, cfg_scale),
            "muq_idx": [muq_idx] * bs_size,
            "pos": _cfg_cat(torch.arange(prompt_len, dtype=torch.long), cfg_scale),
        }

    def _forward(
        self,
        model_inputs: Dict[str, Any],
        max_audio_length_ms: int,
        temperature: float,
        topk: int,
        cfg_scale: float,
    ):
        prompt_tokens = model_inputs["tokens"].to(self.mula_device)
        prompt_tokens_mask = model_inputs["tokens_mask"].to(self.mula_device)
        continuous_segment = model_inputs["muq_embed"].to(self.mula_device)
        starts = model_inputs["muq_idx"]
        prompt_pos = model_inputs["pos"].to(self.mula_device)
        frames = []

        bs_size = 2 if cfg_scale != 1.0 else 1
        self.mula.setup_caches(bs_size)
        with torch.autocast(device_type=self.mula_device.type, dtype=self.mula_dtype):
            curr_token = self.mula.generate_frame(
                tokens=prompt_tokens,
                tokens_mask=prompt_tokens_mask,
                input_pos=prompt_pos,
                temperature=temperature,
                topk=topk,
                cfg_scale=cfg_scale,
                continuous_segments=continuous_segment,
                starts=starts,
            )
        frames.append(curr_token[0:1,])

        def _pad_audio_token(token: torch.Tensor):
            padded_token = (
                torch.ones(
                    (token.shape[0], self._parallel_number),
                    device=token.device,
                    dtype=torch.long,
                )
                * self.config.empty_id
            )
            padded_token[:, :-1] = token
            padded_token = padded_token.unsqueeze(1)
            padded_token_mask = torch.ones_like(
                padded_token, device=token.device, dtype=torch.bool
            )
            padded_token_mask[..., -1] = False
            return padded_token, padded_token_mask

        max_audio_frames = max_audio_length_ms // 80

        for i in tqdm(range(max_audio_frames)):
            curr_token, curr_token_mask = _pad_audio_token(curr_token)
            with torch.autocast(
                device_type=self.mula_device.type, dtype=self.mula_dtype
            ):
                curr_token = self.mula.generate_frame(
                    tokens=curr_token,
                    tokens_mask=curr_token_mask,
                    input_pos=prompt_pos[..., -1:] + i + 1,
                    temperature=temperature,
                    topk=topk,
                    cfg_scale=cfg_scale,
                    continuous_segments=None,
                    starts=None,
                )
            if torch.any(curr_token[0:1, :] >= self.config.audio_eos_id):
                break
            frames.append(curr_token[0:1,])
        frames = torch.stack(frames).permute(1, 2, 0).squeeze(0)
        self._unload()
        return {"frames": frames}

    def postprocess(self, model_outputs: Dict[str, Any], save_path: str):
        frames = model_outputs["frames"].to(self.codec_device)
        wav = self.codec.detokenize(frames)
        self._unload()
        torchaudio.save(save_path, wav.to(torch.float32).cpu(), 48000)

    def __call__(self, inputs: Dict[str, Any], **kwargs):
        preprocess_kwargs, forward_kwargs, postprocess_kwargs = (
            self._sanitize_parameters(**kwargs)
        )
        model_inputs = self.preprocess(inputs, **preprocess_kwargs)
        model_outputs = self._forward(model_inputs, **forward_kwargs)
        self.postprocess(model_outputs, **postprocess_kwargs)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_path: str,
        device: Union[torch.device, Dict[str, torch.device]],
        dtype: Union[torch.dtype, Dict[str, torch.dtype]],
        version: str,
        lazy_load: bool = False,
    ):

        mula_path, codec_path, tokenizer_path, gen_config_path = _resolve_paths(
            pretrained_path, version
        )
        mula_device, codec_device, lazy_load = _resolve_devices(device, lazy_load)
        tokenizer = Tokenizer.from_file(tokenizer_path)
        gen_config = HeartMuLaGenConfig.from_file(gen_config_path)

        mula_dtype = dtype["mula"] if isinstance(dtype, dict) else dtype
        codec_dtype = dtype["codec"] if isinstance(dtype, dict) else dtype

        return cls(
            heartmula_path=mula_path,
            heartcodec_path=codec_path,
            heartmula_device=mula_device,
            heartcodec_device=codec_device,
            lazy_load=lazy_load,
            muq_mulan=None,
            text_tokenizer=tokenizer,
            config=gen_config,
            heartmula_dtype=mula_dtype,
            heartcodec_dtype=codec_dtype,
        )
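
A minimal end-to-end usage sketch. The ./pretrained folder and the version string are hypothetical placeholders (the folder must be laid out as _resolve_paths expects); the keyword arguments shown are exactly the knobs _sanitize_parameters accepts:

import torch
from heartlib import HeartMuLaGenPipeline

pipe = HeartMuLaGenPipeline.from_pretrained(
    "./pretrained",              # hypothetical: HeartMuLa-oss-<version>/, HeartCodec-oss/, tokenizer.json, gen_config.json
    device=torch.device("cuda"),
    dtype=torch.bfloat16,
    version="3B",                # hypothetical version string; resolves to HeartMuLa-oss-3B
    lazy_load=True,              # load each model on demand and unload it afterwards
)
pipe(
    {"lyrics": "lyrics.txt", "tags": "pop, female vocal, upbeat"},  # file paths or raw strings both work
    max_audio_length_ms=60_000,  # 80 ms per generated frame
    cfg_scale=1.5,
    temperature=1.0,
    topk=50,
    save_path="output.mp3",
)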