Deploy from GitHub repository
Browse files- .gitattributes +2 -0
- README.md +9 -5
- app.py +38 -0
- assets/badge.svg +12 -0
- assets/exp.png +3 -0
- assets/logo.png +3 -0
- assets/lyrics.txt +40 -0
- assets/tags.txt +1 -0
- examples/README.md +15 -0
- examples/run_lyrics_transcription.py +35 -0
- examples/run_music_generation.py +41 -0
- pyproject.toml +47 -0
- requirements.txt +1 -0
- src/heartlib/__init__.py +7 -0
- src/heartlib/heartcodec/configuration_heartcodec.py +73 -0
- src/heartlib/heartcodec/modeling_heartcodec.py +181 -0
- src/heartlib/heartcodec/models/flow_matching.py +177 -0
- src/heartlib/heartcodec/models/sq_codec.py +539 -0
- src/heartlib/heartcodec/models/transformer.py +501 -0
- src/heartlib/heartmula/configuration_heartmula.py +23 -0
- src/heartlib/heartmula/modeling_heartmula.py +316 -0
- src/heartlib/pipelines/__init__.py +0 -0
- src/heartlib/pipelines/lyrics_transcription.py +40 -0
- src/heartlib/pipelines/music_generation.py +256 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
assets/exp.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
assets/logo.png filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
|
@@ -1,12 +1,16 @@
|
|
| 1 |
---
|
| 2 |
title: Heartlib
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version:
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
---
|
| 11 |
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
title: Heartlib
|
| 3 |
+
emoji: 🚀
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: green
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: "5.35.0"
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
---
|
| 11 |
|
| 12 |
+
# Heartlib
|
| 13 |
+
|
| 14 |
+
<p align="center">
|
| 15 |
+
|
| 16 |
+
Deployed from: https://github.com/HeartMuLa/heartlib
|
app.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
|
| 3 |
+
def main_function(input_data):
    """Echo the user's input back inside a confirmation message.

    Returns a fixed prompt string when the input is empty/falsy.
    """
    if not input_data:
        return "Please provide input"
    return f"Processed successfully! Input received: {input_data}"
|
| 9 |
+
|
| 10 |
+
# Minimal Gradio UI: one text input routed through main_function to one output.
with gr.Blocks(title="heartlib") as demo:
    # Plain string literal: the original used an f-string with no placeholders
    # (ruff F541); the rendered text is identical.
    gr.Markdown("""
# Heartlib

<p align="center">

This space was created from: [https://github.com/HeartMuLa/heartlib](https://github.com/HeartMuLa/heartlib)
""")

    with gr.Row():
        with gr.Column():
            input_data = gr.Textbox(
                label="Input",
                placeholder="Enter your input here...",
                lines=3
            )
            process_btn = gr.Button("Process", variant="primary")

        with gr.Column():
            output_data = gr.Textbox(label="Output")

    # Wire the button to the processing function.
    process_btn.click(
        fn=main_function,
        inputs=input_data,
        outputs=output_data
    )

if __name__ == "__main__":
    demo.launch()
|
assets/badge.svg
ADDED
|
|
assets/exp.png
ADDED
|
Git LFS Details
|
assets/logo.png
ADDED
|
Git LFS Details
|
assets/lyrics.txt
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[Intro]
|
| 2 |
+
|
| 3 |
+
[Verse]
|
| 4 |
+
The sun creeps in across the floor
|
| 5 |
+
I hear the traffic outside the door
|
| 6 |
+
The coffee pot begins to hiss
|
| 7 |
+
It is another morning just like this
|
| 8 |
+
|
| 9 |
+
[Prechorus]
|
| 10 |
+
The world keeps spinning round and round
|
| 11 |
+
Feet are planted on the ground
|
| 12 |
+
I find my rhythm in the sound
|
| 13 |
+
|
| 14 |
+
[Chorus]
|
| 15 |
+
Every day the light returns
|
| 16 |
+
Every day the fire burns
|
| 17 |
+
We keep on walking down this street
|
| 18 |
+
Moving to the same steady beat
|
| 19 |
+
It is the ordinary magic that we meet
|
| 20 |
+
|
| 21 |
+
[Verse]
|
| 22 |
+
The hours tick deeply into noon
|
| 23 |
+
Chasing shadows,chasing the moon
|
| 24 |
+
Work is done and the lights go low
|
| 25 |
+
Watching the city start to glow
|
| 26 |
+
|
| 27 |
+
[Bridge]
|
| 28 |
+
It is not always easy,not always bright
|
| 29 |
+
Sometimes we wrestle with the night
|
| 30 |
+
But we make it to the morning light
|
| 31 |
+
|
| 32 |
+
[Chorus]
|
| 33 |
+
Every day the light returns
|
| 34 |
+
Every day the fire burns
|
| 35 |
+
We keep on walking down this street
|
| 36 |
+
Moving to the same steady beat
|
| 37 |
+
|
| 38 |
+
[Outro]
|
| 39 |
+
Just another day
|
| 40 |
+
Every single day
|
assets/tags.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
piano,happy
|
examples/README.md
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🎤 Lyrics Transcription
|
| 2 |
+
|
| 3 |
+
Download the checkpoint using any of the following commands:
|
| 4 |
+
```
|
| 5 |
+
hf download --local_dir './ckpt/HeartTranscriptor-oss' 'HeartMuLa/HeartTranscriptor-oss'
|
| 6 |
+
modelscope download --model 'HeartMuLa/HeartTranscriptor-oss' --local_dir './ckpt/HeartTranscriptor-oss'
|
| 7 |
+
```
|
| 8 |
+
|
| 9 |
+
```
|
| 10 |
+
python ./examples/run_lyrics_transcription.py --model_path=./ckpt
|
| 11 |
+
```
|
| 12 |
+
|
| 13 |
+
By default this command will load the generated music file at `./assets/output.mp3` and print the transcribed lyrics. Use `--music_path` to specify the path to the music file.
|
| 14 |
+
|
| 15 |
+
Note that our HeartTranscriptor is trained on separated vocal tracks. In this example we demonstrate it directly on unseparated music tracks, purely for simplicity of illustration. For better results, we recommend using a source-separation tool such as demucs to isolate the vocal track before transcribing lyrics.
|
examples/run_lyrics_transcription.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from heartlib import HeartTranscriptorPipeline
import argparse
import torch


def parse_args():
    """Parse CLI options for the lyrics-transcription example.

    --model_path is required; --music_path defaults to the music-generation
    example's output file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--music_path", type=str, default="./assets/output.mp3")

    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    # Load the transcription pipeline on GPU in half precision.
    pipe = HeartTranscriptorPipeline.from_pretrained(
        args.model_path,
        device=torch.device("cuda"),
        dtype=torch.float16,
    )
    with torch.no_grad():
        # Whisper-style decoding options passed directly as keywords
        # (the original wrapped them in a redundant `**{...}` literal).
        result = pipe(
            args.music_path,
            max_new_tokens=256,
            num_beams=2,
            task="transcribe",
            condition_on_prev_tokens=False,
            compression_ratio_threshold=1.8,
            temperature=(0.0, 0.1, 0.2, 0.4),
            logprob_threshold=-1.0,
            no_speech_threshold=0.4,
        )
    print(result)
|
examples/run_music_generation.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from heartlib import HeartMuLaGenPipeline
import argparse
import torch


def parse_args():
    """Collect command-line options for the music-generation example."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--model_path", type=str, required=True)
    cli.add_argument("--version", type=str, default="3B")
    cli.add_argument("--lyrics", type=str, default="./assets/lyrics.txt")
    cli.add_argument("--tags", type=str, default="./assets/tags.txt")
    cli.add_argument("--save_path", type=str, default="./assets/output.mp3")

    cli.add_argument("--max_audio_length_ms", type=int, default=240_000)
    cli.add_argument("--topk", type=int, default=50)
    cli.add_argument("--temperature", type=float, default=1.0)
    cli.add_argument("--cfg_scale", type=float, default=1.5)
    return cli.parse_args()


if __name__ == "__main__":
    args = parse_args()
    # Load the generation pipeline on GPU in bfloat16.
    pipe = HeartMuLaGenPipeline.from_pretrained(
        args.model_path,
        device=torch.device("cuda"),
        dtype=torch.bfloat16,
        version=args.version,
    )
    # Inputs are file paths; the pipeline reads lyrics/tags from disk.
    request = {
        "lyrics": args.lyrics,
        "tags": args.tags,
    }
    sampling_options = dict(
        max_audio_length_ms=args.max_audio_length_ms,
        save_path=args.save_path,
        topk=args.topk,
        temperature=args.temperature,
        cfg_scale=args.cfg_scale,
    )
    with torch.no_grad():
        pipe(request, **sampling_options)
    print(f"Generated music saved to {args.save_path}")
|
pyproject.toml
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=61", "wheel"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "heartlib"
|
| 7 |
+
version = "0.1.0"
|
| 8 |
+
description = "A Python Library."
|
| 9 |
+
readme = "README.md"
|
| 10 |
+
requires-python = ">=3.9"
|
| 11 |
+
license = {text = "CC-BY-NC-4.0"}
|
| 12 |
+
authors = [
|
| 13 |
+
{name = "HeartMuLa Team", email = "heartmula.ai@gmail.com"}
|
| 14 |
+
]
|
| 15 |
+
dependencies = [
|
| 16 |
+
"numpy==2.0.2",
|
| 17 |
+
"torch==2.4.1",
|
| 18 |
+
"torchaudio==2.4.1",
|
| 19 |
+
"torchtune==0.4.0",
|
| 20 |
+
"torchao==0.9.0",
|
| 21 |
+
"torchvision==0.19.1",
|
| 22 |
+
"tqdm==4.67.1",
|
| 23 |
+
"traitlets==5.7.1",
|
| 24 |
+
"traittypes==0.2.3",
|
| 25 |
+
"transformers==4.57.0",
|
| 26 |
+
"tokenizers==0.22.1",
|
| 27 |
+
"ipykernel==6.17.1",
|
| 28 |
+
"einops==0.8.1",
|
| 29 |
+
"accelerate==1.12.0",
|
| 30 |
+
"bitsandbytes==0.49.0",
|
| 31 |
+
"vector-quantize-pytorch==1.27.15",
|
| 32 |
+
"modelscope==1.33.0",
|
| 33 |
+
"soundfile"
|
| 34 |
+
]
|
| 35 |
+
urls = { "homepage" = "https://heartmula.github.io/" }
|
| 36 |
+
classifiers = [
|
| 37 |
+
"Programming Language :: Python :: 3",
|
| 38 |
+
"License :: Other/Proprietary License",
|
| 39 |
+
"Operating System :: OS Independent"
|
| 40 |
+
]
|
| 41 |
+
|
| 42 |
+
[tool.setuptools]
|
| 43 |
+
package-dir = {"" = "src"}
|
| 44 |
+
|
| 45 |
+
[tool.setuptools.packages.find]
|
| 46 |
+
where = ["src"]
|
| 47 |
+
|
requirements.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
gradio>=5.35.0
|
src/heartlib/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Top-level package exports for heartlib.

Re-exports the two user-facing pipelines so callers can write
``from heartlib import HeartMuLaGenPipeline, HeartTranscriptorPipeline``.
"""

from .pipelines.music_generation import HeartMuLaGenPipeline
from .pipelines.lyrics_transcription import HeartTranscriptorPipeline

# Public API of the package.
__all__ = [
    "HeartMuLaGenPipeline",
    "HeartTranscriptorPipeline"
]
|
src/heartlib/heartcodec/configuration_heartcodec.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 2 |
+
from typing import List
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class HeartCodecConfig(PretrainedConfig):
    """Configuration for the HeartCodec model.

    Groups three sub-configurations:
      * residual vector quantizer (``dim`` .. ``num_quantizers``),
      * diffusion-transformer backbone (``attention_head_dim`` .. ``out_channels``),
      * SQ (scalar-quantization) codec (``num_bands`` .. ``res_kernel_size``).

    The list-valued parameters (down/upsample factors and kernel sizes)
    are stored as defensive copies: the original code stored the default
    list objects by reference, so mutating one config instance's list
    would silently mutate the shared defaults for every future instance.
    """

    model_type = "heartcodec"

    def __init__(
        self,
        # config for rvq
        dim: int = 512,
        codebook_size: int = 8192,
        decay: float = 0.9,
        commitment_weight: float = 1.0,
        threshold_ema_dead_code: int = 2,
        use_cosine_sim: bool = False,
        codebook_dim: int = 32,
        num_quantizers: int = 8,
        # config for diffusion transformer
        attention_head_dim: int = 64,
        in_channels: int = 1024,
        norm_type: str = "ada_norm_single",
        num_attention_heads: int = 24,
        num_layers: int = 24,
        num_layers_2: int = 6,
        out_channels: int = 256,
        # config for sq codec
        num_bands: int = 1,
        sample_rate: int = 48000,
        causal: bool = True,
        num_samples: int = 2,
        downsample_factors: List[int] = [3, 4, 4, 4, 5],
        downsample_kernel_sizes: List[int] = [6, 8, 8, 8, 10],
        upsample_factors: List[int] = [5, 4, 4, 4, 3],
        upsample_kernel_sizes: List[int] = [10, 8, 8, 8, 6],
        latent_hidden_dim: int = 128,
        default_kernel_size: int = 7,
        delay_kernel_size: int = 5,
        init_channel: int = 64,
        res_kernel_size: int = 7,
        **kwargs
    ):
        super().__init__(**kwargs)
        # --- residual VQ ---
        self.dim = dim
        self.codebook_size = codebook_size
        self.decay = decay
        self.commitment_weight = commitment_weight
        self.threshold_ema_dead_code = threshold_ema_dead_code
        self.use_cosine_sim = use_cosine_sim
        self.codebook_dim = codebook_dim
        self.num_quantizers = num_quantizers

        # --- diffusion transformer backbone ---
        self.attention_head_dim = attention_head_dim
        self.in_channels = in_channels
        self.norm_type = norm_type
        self.num_attention_heads = num_attention_heads
        self.num_layers = num_layers
        self.num_layers_2 = num_layers_2
        self.out_channels = out_channels

        # --- SQ codec ---
        self.num_bands = num_bands
        self.sample_rate = sample_rate
        self.causal = causal
        self.num_samples = num_samples
        # Copy list arguments so instances never alias the shared
        # mutable default objects (classic mutable-default pitfall).
        self.downsample_factors = list(downsample_factors)
        self.downsample_kernel_sizes = list(downsample_kernel_sizes)
        self.upsample_factors = list(upsample_factors)
        self.upsample_kernel_sizes = list(upsample_kernel_sizes)
        self.latent_hidden_dim = latent_hidden_dim
        self.default_kernel_size = default_kernel_size
        self.delay_kernel_size = delay_kernel_size
        self.init_channel = init_channel
        self.res_kernel_size = res_kernel_size
|
src/heartlib/heartcodec/modeling_heartcodec.py
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from .models.flow_matching import FlowMatching
|
| 3 |
+
from .models.sq_codec import ScalarModel
|
| 4 |
+
from .configuration_heartcodec import HeartCodecConfig
|
| 5 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 6 |
+
import math
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class HeartCodec(PreTrainedModel):
    """Neural audio codec model: decodes discrete codes into a waveform.

    Composed of a flow-matching latent generator (``FlowMatching``) and a
    scalar-quantization waveform decoder (``ScalarModel``). Only decoding
    (``detokenize``) is implemented in this class.
    """

    config_class = HeartCodecConfig

    def __init__(
        self,
        config: HeartCodecConfig,
    ):
        """Build the flow-matching generator and the scalar decoder from ``config``."""
        super(HeartCodec, self).__init__(config)

        self.config = config

        # Latent generator conditioned on RVQ code embeddings.
        self.flow_matching = FlowMatching(
            dim=config.dim,
            codebook_size=config.codebook_size,
            decay=config.decay,
            commitment_weight=config.commitment_weight,
            threshold_ema_dead_code=config.threshold_ema_dead_code,
            use_cosine_sim=config.use_cosine_sim,
            codebook_dim=config.codebook_dim,
            num_quantizers=config.num_quantizers,
            attention_head_dim=config.attention_head_dim,
            in_channels=config.in_channels,
            norm_type=config.norm_type,
            num_attention_heads=config.num_attention_heads,
            num_layers=config.num_layers,
            num_layers_2=config.num_layers_2,
            out_channels=config.out_channels,
        )
        # Waveform decoder that turns flow-matched latents into audio samples.
        self.scalar_model = ScalarModel(
            num_bands=config.num_bands,
            sample_rate=config.sample_rate,
            causal=config.causal,
            num_samples=config.num_samples,
            downsample_factors=config.downsample_factors,
            downsample_kernel_sizes=config.downsample_kernel_sizes,
            upsample_factors=config.upsample_factors,
            upsample_kernel_sizes=config.upsample_kernel_sizes,
            latent_hidden_dim=config.latent_hidden_dim,
            default_kernel_size=config.default_kernel_size,
            delay_kernel_size=config.delay_kernel_size,
            init_channel=config.init_channel,
            res_kernel_size=config.res_kernel_size,
        )
        self.post_init()

        self.sample_rate = config.sample_rate

    @torch.inference_mode()
    def detokenize(
        self,
        codes,
        duration=29.76,
        num_steps=10,
        disable_progress=False,
        guidance_scale=1.25,
        device="cuda",
    ):
        """Decode a token sequence into a waveform.

        The code sequence is processed in windows of ``duration`` seconds;
        consecutive windows overlap and are cross-faded with a linear ramp
        before concatenation, then the result is trimmed to the length
        implied by the original (pre-padding) code count.

        Args:
            codes: code tensor; the ``unsqueeze(0)`` below implies an
                unbatched (num_quantizers, T) layout — TODO confirm.
            duration: window length in seconds per decoding segment.
            num_steps: Euler steps for the flow-matching ODE solver.
            disable_progress: suppress the solver's tqdm progress bar.
            guidance_scale: classifier-free guidance strength.
            device: device on which decoding runs.

        Returns:
            Waveform tensor on CPU, trimmed to ``target_len`` samples.
        """
        codes = codes.unsqueeze(0).to(device)
        # Initial noise latent for the first window; the factor of 25
        # suggests 25 latent frames per second (twice the 12.5 Hz code
        # rate used below) — confirm against the codec's frame rates.
        first_latent = torch.randn(codes.shape[0], int(duration * 25), 256).to(
            device
        )  # B, T, 64
        first_latent_length = 0
        first_latent_codes_length = 0
        # Windowing geometry in code frames (12.5 codes per second).
        min_samples = int(duration * 12.5)
        hop_samples = min_samples // 93 * 80
        ovlp_samples = min_samples - hop_samples
        # Overlap expressed in latent frames (latents run at 2x the code rate).
        ovlp_frames = ovlp_samples * 2
        codes_len = codes.shape[-1]  #
        # Final audio length implied by the original number of codes.
        target_len = int(
            (codes_len - first_latent_codes_length) / 12.5 * self.sample_rate
        )

        # code repeat: tile the sequence until it fills at least one window.
        if codes_len < min_samples:
            while codes.shape[-1] < min_samples:
                codes = torch.cat([codes, codes], -1)
            codes = codes[:, :, 0:min_samples]
        codes_len = codes.shape[-1]
        # Pad (by tiling) so the sequence divides evenly into hop-sized steps.
        if (codes_len - ovlp_frames) % hop_samples > 0:
            len_codes = (
                math.ceil((codes_len - ovlp_samples) / float(hop_samples)) * hop_samples
                + ovlp_samples
            )
            while codes.shape[-1] < len_codes:
                codes = torch.cat([codes, codes], -1)
            codes = codes[:, :, 0:len_codes]
        latent_length = int(duration * 25)
        latent_list = []

        # Generate latents window by window; each window after the first is
        # conditioned on the tail of the previous window's latents.
        for sinx in range(0, codes.shape[-1] - hop_samples + 1, hop_samples):
            codes_input = []
            codes_input.append(codes[:, :, sinx : sinx + min_samples])
            if sinx == 0 or ovlp_frames == 0:
                # First window (or no overlap): start from pure noise.
                incontext_length = first_latent_length
                latents = self.flow_matching.inference_codes(
                    codes_input,
                    first_latent,
                    latent_length,
                    incontext_length,
                    guidance_scale=guidance_scale,
                    num_steps=num_steps,
                    disable_progress=disable_progress,
                    scenario="other_seg",
                )
                latent_list.append(latents)
            else:
                # Subsequent windows: reuse the previous window's trailing
                # latents as fixed in-context frames, pad the rest with noise.
                true_latent = latent_list[-1][:, -ovlp_frames:, :]
                len_add_to_latent = latent_length - true_latent.shape[1]  #
                incontext_length = true_latent.shape[1]
                true_latent = torch.cat(
                    [
                        true_latent,
                        torch.randn(
                            true_latent.shape[0],
                            len_add_to_latent,
                            true_latent.shape[-1],
                        ).to(device),
                    ],
                    1,
                )
                latents = self.flow_matching.inference_codes(
                    codes_input,
                    true_latent,
                    latent_length,
                    incontext_length,
                    guidance_scale=guidance_scale,
                    num_steps=num_steps,
                    disable_progress=disable_progress,
                    scenario="other_seg",
                )
                latent_list.append(latents)

        latent_list = [l.float() for l in latent_list]
        latent_list[0] = latent_list[0][:, first_latent_length:, :]
        # Re-express the window geometry in audio samples for overlap-add.
        min_samples = int(duration * self.sample_rate)
        hop_samples = min_samples // 93 * 80
        ovlp_samples = min_samples - hop_samples

        output = None
        for i in range(len(latent_list)):
            latent = latent_list[i]
            bsz, t, f = latent.shape

            # Split the feature dim in two and fold into the batch dim:
            # (B, T, F) -> (B*2, T, F/2), matching the decoder's input layout.
            latent = latent.reshape(
                latent.shape[0], latent.shape[1], 2, latent.shape[2] // 2
            ).permute(0, 2, 1, 3)
            latent = latent.reshape(
                latent.shape[0] * 2, latent.shape[2], latent.shape[3]
            )
            cur_output = (
                self.scalar_model.decode(latent.transpose(1, 2)).squeeze(0).squeeze(1)
            )  # 1 512 256

            cur_output = cur_output[:, 0:min_samples].detach().cpu()  # B, T
            if cur_output.dim() == 3:
                cur_output = cur_output[0]

            if output is None:
                output = cur_output
            else:
                if ovlp_samples == 0:
                    output = torch.cat([output, cur_output], -1)
                else:
                    # Linear cross-fade over the overlapping region.
                    ov_win = torch.from_numpy(np.linspace(0, 1, ovlp_samples)[None, :])
                    ov_win = torch.cat([ov_win, 1 - ov_win], -1)
                    output[:, -ovlp_samples:] = (
                        output[:, -ovlp_samples:] * ov_win[:, -ovlp_samples:]
                        + cur_output[:, 0:ovlp_samples] * ov_win[:, 0:ovlp_samples]
                    )
                    output = torch.cat([output, cur_output[:, ovlp_samples:]], -1)
        # Drop the samples produced from tiling/padding.
        output = output[:, 0:target_len]
        return output
|
src/heartlib/heartcodec/models/flow_matching.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
from vector_quantize_pytorch import ResidualVQ
|
| 6 |
+
from .transformer import LlamaTransformer
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class FlowMatching(nn.Module):
    """Flow-matching latent generator conditioned on RVQ code embeddings.

    A ``ResidualVQ`` recovers continuous embeddings from discrete codes;
    a transformer (``LlamaTransformer``) estimates the flow field, which is
    integrated with a fixed-step Euler solver (``solve_euler``) under
    optional classifier-free guidance.
    """

    def __init__(
        self,
        # rvq stuff
        dim: int = 512,
        codebook_size: int = 8192,
        decay: float = 0.9,
        commitment_weight: float = 1.0,
        threshold_ema_dead_code: int = 2,
        use_cosine_sim: bool = False,
        codebook_dim: int = 32,
        num_quantizers: int = 8,
        # dit backbone stuff
        attention_head_dim: int = 64,
        in_channels: int = 1024,
        norm_type: str = "ada_norm_single",
        num_attention_heads: int = 24,
        num_layers: int = 24,
        num_layers_2: int = 6,
        out_channels: int = 256,
    ):
        super().__init__()

        # Residual VQ used only to map code indices back to embeddings here.
        self.vq_embed = ResidualVQ(
            dim=dim,
            codebook_size=codebook_size,
            decay=decay,
            commitment_weight=commitment_weight,
            threshold_ema_dead_code=threshold_ema_dead_code,
            use_cosine_sim=use_cosine_sim,
            codebook_dim=codebook_dim,
            num_quantizers=num_quantizers,
        )
        self.cond_feature_emb = nn.Linear(dim, dim)
        # Learned embedding substituted for the condition on masked-out
        # (unconditioned) frames.
        self.zero_cond_embedding1 = nn.Parameter(torch.randn(dim))
        # Flow-field estimator network.
        self.estimator = LlamaTransformer(
            attention_head_dim=attention_head_dim,
            in_channels=in_channels,
            norm_type=norm_type,
            num_attention_heads=num_attention_heads,
            num_layers=num_layers,
            num_layers_2=num_layers_2,
            out_channels=out_channels,
        )

        self.latent_dim = out_channels

    @torch.no_grad()
    def inference_codes(
        self,
        codes,
        true_latents,
        latent_length,
        incontext_length,
        guidance_scale=2.0,
        num_steps=20,
        disable_progress=True,
        scenario="start_seg",
    ):
        """Generate latents for one window of codes.

        Args:
            codes: list whose first element is the code-index tensor;
                ``transpose(1, 2)`` below implies a (B, Q, T) layout —
                TODO confirm.
            true_latents: latent tensor supplying the fixed in-context
                prefix (only used when ``scenario == "other_seg"``).
            latent_length: number of frames marked as valid output.
            incontext_length: number of leading frames held fixed.
            guidance_scale: classifier-free guidance strength (> 1 enables
                the two-pass conditional/unconditional estimate).
            num_steps: Euler integration steps.
            disable_progress: NOTE(review): accepted but not forwarded to
                the solver — the tqdm bar always shows.
            scenario: "other_seg" marks an in-context prefix; any other
                value generates the whole window from noise.

        Returns:
            Latent tensor (B, T, latent_dim) with the in-context prefix
            copied back verbatim.
        """
        device = true_latents.device
        dtype = true_latents.dtype
        # codes_bestrq_middle, codes_bestrq_last = codes
        codes_bestrq_emb = codes[0]

        batch_size = codes_bestrq_emb.shape[0]
        self.vq_embed.eval()
        # Indices -> continuous embeddings via the RVQ codebooks.
        quantized_feature_emb = self.vq_embed.get_output_from_indices(
            codes_bestrq_emb.transpose(1, 2)
        )
        quantized_feature_emb = self.cond_feature_emb(quantized_feature_emb)  # b t 512
        # assert 1==2
        # Upsample the condition 2x along time to match the latent frame rate.
        quantized_feature_emb = F.interpolate(
            quantized_feature_emb.permute(0, 2, 1), scale_factor=2, mode="nearest"
        ).permute(0, 2, 1)

        num_frames = quantized_feature_emb.shape[1]  #
        latents = torch.randn(
            (batch_size, num_frames, self.latent_dim), device=device, dtype=dtype
        )
        # Per-frame mask: 0 = padding, 1 = in-context (fixed), 2 = generate.
        latent_masks = torch.zeros(
            latents.shape[0], latents.shape[1], dtype=torch.int64, device=latents.device
        )
        latent_masks[:, 0:latent_length] = 2
        if scenario == "other_seg":
            latent_masks[:, 0:incontext_length] = 1

        # Replace the condition on padding frames with the learned
        # "no condition" embedding.
        quantized_feature_emb = (latent_masks > 0.5).unsqueeze(
            -1
        ) * quantized_feature_emb + (latent_masks < 0.5).unsqueeze(
            -1
        ) * self.zero_cond_embedding1.unsqueeze(
            0
        )

        # Zero out everything except the in-context frames (mask == 1).
        incontext_latents = (
            true_latents
            * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float()
        )
        # Recompute the in-context length from the mask of the first batch item.
        incontext_length = ((latent_masks > 0.5) * (latent_masks < 1.5)).sum(-1)[0]

        additional_model_input = torch.cat([quantized_feature_emb], 1)
        temperature = 1.0
        t_span = torch.linspace(
            0, 1, num_steps + 1, device=quantized_feature_emb.device
        )
        latents = self.solve_euler(
            latents * temperature,
            incontext_latents,
            incontext_length,
            t_span,
            additional_model_input,
            guidance_scale,
        )

        # Restore the in-context prefix exactly.
        latents[:, 0:incontext_length, :] = incontext_latents[
            :, 0:incontext_length, :
        ]  # B, T, dim
        return latents

    def solve_euler(self, x, incontext_x, incontext_length, t_span, mu, guidance_scale):
        """
        Fixed-step Euler solver for the flow-matching ODE.

        At each step the in-context prefix of ``x`` is re-pinned to the
        noise/target interpolation at time ``t``, then the estimator's flow
        field is applied. When ``guidance_scale > 1`` the conditional and
        unconditional passes are batched together (inputs doubled along
        dim 0) and combined with classifier-free guidance.

        Args:
            x (torch.Tensor): random noise
            incontext_x (torch.Tensor): latents pinned over the prefix
            incontext_length: number of pinned leading frames
            t_span (torch.Tensor): n_timesteps interpolated
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            guidance_scale: classifier-free guidance strength
        """
        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
        noise = x.clone()

        # I am storing this because I can later plot it by putting a debugger here and saving it to a file
        # Or in future might add like a return_all_steps flag
        sol = []
        for step in tqdm(range(1, len(t_span))):
            # Pin the prefix to the time-t interpolation between noise and target.
            x[:, 0:incontext_length, :] = (1 - (1 - 1e-6) * t) * noise[
                :, 0:incontext_length, :
            ] + t * incontext_x[:, 0:incontext_length, :]
            if guidance_scale > 1.0:
                # Single batched forward: [uncond | cond] along dim 0,
                # with the condition zeroed for the unconditional half.
                dphi_dt = self.estimator(
                    torch.cat(
                        [
                            torch.cat([x, x], 0),
                            torch.cat([incontext_x, incontext_x], 0),
                            torch.cat([torch.zeros_like(mu), mu], 0),
                        ],
                        2,
                    ),
                    timestep=t.unsqueeze(-1).repeat(2),
                )
                dphi_dt_uncond, dhpi_dt_cond = dphi_dt.chunk(2, 0)
                dphi_dt = dphi_dt_uncond + guidance_scale * (
                    dhpi_dt_cond - dphi_dt_uncond
                )
            else:
                dphi_dt = self.estimator(
                    torch.cat([x, incontext_x, mu], 2), timestep=t.unsqueeze(-1)
                )

            x = x + dt * dphi_dt
            t = t + dt
            sol.append(x)
            if step < len(t_span) - 1:
                dt = t_span[step + 1] - t

        result = sol[-1]

        return result
|
src/heartlib/heartcodec/models/sq_codec.py
ADDED
|
@@ -0,0 +1,539 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import numpy as np
|
| 5 |
+
from torch.nn.utils.parametrizations import weight_norm
|
| 6 |
+
from torch.nn.utils import remove_weight_norm
|
| 7 |
+
from torch.autograd.function import InplaceFunction
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def get_padding(kernel_size: int, dilation: int = 1) -> int:
    """Padding that keeps a stride-1 dilated conv length-preserving.

    Equal to half the effective receptive span, dilation * (kernel_size - 1) / 2,
    truncated to an int (exact for odd kernel sizes).
    """
    effective_span = dilation * (kernel_size - 1)
    return effective_span // 2
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# Scripting this brings model speed up 1.4x
@torch.jit.script
def snake(x, alpha):
    """Snake activation: x + sin^2(alpha * x) / alpha, applied channel-wise.

    `alpha` broadcasts as (1, C, 1); inputs of any trailing shape are
    flattened to (B, C, -1) for the computation and restored afterwards.
    The 1e-9 guards against division by zero for tiny alpha.
    """
    shape = x.shape
    x = x.reshape(shape[0], shape[1], -1)
    x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
    x = x.reshape(shape)
    return x
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class Snake1d(nn.Module):
    """Snake activation layer with one learnable frequency per channel."""

    def __init__(self, channels: int):
        super().__init__()
        # Shape (1, C, 1) so alpha broadcasts over batch and time.
        alpha_init = torch.ones(1, channels, 1)
        self.alpha = nn.Parameter(alpha_init)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Delegate to the jit-scripted snake() kernel for speed.
        return snake(x, self.alpha)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class Conv1d(nn.Conv1d):
    """nn.Conv1d with optional causal (left-only) padding and Xavier init.

    In causal mode the conv itself pads nothing; forward() left-pads by
    dilation * (kernel_size - 1) so output frame t never sees inputs > t.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        padding_mode: str = "zeros",
        bias: bool = True,
        padding=None,
        causal: bool = False,
        w_init_gain=None,
    ):
        self.causal = causal
        if padding is None:
            if causal:
                padding = 0
                # Entire receptive field shifted to the past.
                self.left_padding = dilation * (kernel_size - 1)
            else:
                padding = get_padding(kernel_size, dilation)
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            padding_mode=padding_mode,
            bias=bias,
        )
        if w_init_gain is not None:
            gain = torch.nn.init.calculate_gain(w_init_gain)
            torch.nn.init.xavier_uniform_(self.weight, gain=gain)

    def forward(self, x):
        if not self.causal:
            return super().forward(x)
        # Left-pad along time; the unsqueeze/squeeze pair lets F.pad treat the
        # tensor as 4-D so only the last dimension is padded.
        padded = F.pad(x.unsqueeze(2), (self.left_padding, 0, 0, 0)).squeeze(2)
        return super().forward(padded)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class ConvTranspose1d(nn.ConvTranspose1d):
    """nn.ConvTranspose1d with an optional causal mode.

    Causal mode requires kernel_size == 2 * stride and zero padding; the
    trailing `stride` frames (which depend on future inputs) are trimmed
    in forward(), yielding exactly stride x upsampling.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        output_padding: int = 0,
        groups: int = 1,
        bias: bool = True,
        dilation: int = 1,
        padding=None,
        padding_mode: str = "zeros",
        causal: bool = False,
    ):
        if padding is None:
            padding = 0 if causal else (kernel_size - stride) // 2
        if causal:
            assert padding == 0, "padding is not allowed in causal ConvTranspose1d."
            assert (
                kernel_size == 2 * stride
            ), "kernel_size must be equal to 2*stride is not allowed in causal ConvTranspose1d."
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            groups=groups,
            bias=bias,
            dilation=dilation,
            padding_mode=padding_mode,
        )
        self.causal = causal
        self.stride = stride

    def forward(self, x):
        out = super().forward(x)
        if self.causal:
            # Drop frames that would leak future information.
            out = out[:, :, : -self.stride]
        return out
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class PreProcessor(nn.Module):
    """Conv1d + PReLU followed by average pooling over `num_samples` frames."""

    def __init__(self, n_in, n_out, num_samples, kernel_size=7, causal=False):
        super().__init__()
        self.pooling = torch.nn.AvgPool1d(kernel_size=num_samples)
        self.conv = Conv1d(n_in, n_out, kernel_size=kernel_size, causal=causal)
        self.activation = nn.PReLU()

    def forward(self, x):
        features = self.conv(x)
        features = self.activation(features)
        # Downsample time by num_samples after the conv.
        return self.pooling(features)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class PostProcessor(nn.Module):
    """Nearest-neighbour upsample time by `num_samples`, then Conv1d + PReLU."""

    def __init__(self, n_in, n_out, num_samples, kernel_size=7, causal=False):
        super().__init__()
        self.num_samples = num_samples
        self.conv = Conv1d(n_in, n_out, kernel_size=kernel_size, causal=causal)
        self.activation = nn.PReLU()

    def forward(self, x):
        # (B, C, T) -> (B, T, C); repeating along the last dim then viewing as
        # (B, T * num_samples, C) duplicates each frame num_samples times.
        x = x.transpose(1, 2)
        batch, frames, channels = x.size()
        x = x.repeat(1, 1, self.num_samples).view(batch, -1, channels)
        x = x.transpose(1, 2)
        return self.activation(self.conv(x))
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
class ResidualUnit(nn.Module):
    """Dilated conv + 1x1 conv (both weight-normed, PReLU) with a skip add.

    The residual addition requires the unit's input and output channel
    counts to match.
    """

    def __init__(self, n_in, n_out, dilation, res_kernel_size=7, causal=False):
        super().__init__()
        dilated_conv = Conv1d(
            n_in,
            n_out,
            kernel_size=res_kernel_size,
            dilation=dilation,
            causal=causal,
        )
        self.conv1 = weight_norm(dilated_conv)
        self.conv2 = weight_norm(Conv1d(n_in, n_out, kernel_size=1, causal=causal))
        self.activation1 = nn.PReLU()
        self.activation2 = nn.PReLU()

    def forward(self, x):
        residual = x
        y = self.activation1(self.conv1(x))
        y = self.activation2(self.conv2(y))
        return y + residual
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
class ResEncoderBlock(nn.Module):
    """Five dilated ResidualUnits (dilations 1,3,5,7,9) then a downsampling layer.

    The first unit maps n_in -> n_out // 2; the rest stay at n_out // 2
    (in the codec n_in == n_out // 2, so the residual adds line up).
    """

    _DILATIONS = (1, 3, 5, 7, 9)

    def __init__(
        self, n_in, n_out, stride, down_kernel_size, res_kernel_size=7, causal=False
    ):
        super().__init__()
        half = n_out // 2
        units = []
        for idx, dilation in enumerate(self._DILATIONS):
            units.append(
                ResidualUnit(
                    n_in if idx == 0 else half,
                    half,
                    dilation=dilation,
                    res_kernel_size=res_kernel_size,
                    causal=causal,
                )
            )
        self.convs = nn.ModuleList(units)

        self.down_conv = DownsampleLayer(
            n_in, n_out, down_kernel_size, stride=stride, causal=causal
        )

    def forward(self, x):
        for unit in self.convs:
            x = unit(x)
        return self.down_conv(x)
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
class ResDecoderBlock(nn.Module):
    """Upsampling layer then five dilated ResidualUnits (dilations 1,3,5,7,9)."""

    _DILATIONS = (1, 3, 5, 7, 9)

    def __init__(
        self, n_in, n_out, stride, up_kernel_size, res_kernel_size=7, causal=False
    ):
        super().__init__()
        # activation=None: the residual units supply the non-linearities.
        self.up_conv = UpsampleLayer(
            n_in,
            n_out,
            kernel_size=up_kernel_size,
            stride=stride,
            causal=causal,
            activation=None,
        )

        self.convs = nn.ModuleList(
            [
                ResidualUnit(
                    n_out,
                    n_out,
                    dilation=dilation,
                    res_kernel_size=res_kernel_size,
                    causal=causal,
                )
                for dilation in self._DILATIONS
            ]
        )

    def forward(self, x):
        x = self.up_conv(x)
        for unit in self.convs:
            x = unit(x)
        return x
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
class DownsampleLayer(nn.Module):
    """Downsample time by `stride` with a strided conv, or conv + AvgPool1d.

    Args:
        in_channels, out_channels: conv channel counts.
        kernel_size: conv kernel size.
        stride: temporal downsampling factor.
        causal: use left-only padding in the conv.
        activation: pass None to disable the post-conv activation; any other
            value (including the default) yields a fresh per-instance PReLU.
            NOTE: the default is evaluated once at class-definition time, which
            is why a fresh PReLU is constructed instead of using it directly.
        use_weight_norm: wrap the conv in weight_norm.
        pooling: if True, keep the conv at stride 1 and downsample with
            AvgPool1d(stride) afterwards.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        causal: bool = False,
        activation=nn.PReLU(),
        use_weight_norm: bool = True,
        pooling: bool = False,
    ):
        super(DownsampleLayer, self).__init__()
        self.pooling = pooling
        self.stride = stride
        # Fix: the `activation` argument used to be ignored (a PReLU was always
        # created), so activation=None could never disable the non-linearity.
        # Custom activation modules are still replaced by a fresh PReLU to
        # avoid sharing the mutable default's learnable parameters.
        self.activation = None if activation is None else nn.PReLU()
        self.use_weight_norm = use_weight_norm
        if pooling:
            self.layer = Conv1d(in_channels, out_channels, kernel_size, causal=causal)
            # Fix: the pooling module used to overwrite the boolean
            # `self.pooling` flag; keep flag and module in separate attributes.
            self.pool = nn.AvgPool1d(kernel_size=stride)
        else:
            self.layer = Conv1d(
                in_channels, out_channels, kernel_size, stride=stride, causal=causal
            )
            self.pool = None
        if use_weight_norm:
            self.layer = weight_norm(self.layer)

    def forward(self, x):
        x = self.layer(x)
        if self.activation is not None:
            x = self.activation(x)
        if self.pool is not None:
            x = self.pool(x)
        return x

    def remove_weight_norm(self):
        # Only defined when the conv was wrapped at construction time.
        if self.use_weight_norm:
            remove_weight_norm(self.layer)
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
class UpsampleLayer(nn.Module):
    """Upsample time by `stride` via transposed conv, or conv + frame repetition.

    NOTE: the default `activation` module is created at class-definition time,
    so instances built with the default share its learnable parameters.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        causal: bool = False,
        activation=nn.PReLU(),
        use_weight_norm: bool = True,
        repeat: bool = False,
    ):
        super().__init__()
        self.repeat = repeat
        self.stride = stride
        self.activation = activation
        self.use_weight_norm = use_weight_norm
        if repeat:
            core = Conv1d(in_channels, out_channels, kernel_size, causal=causal)
        else:
            core = ConvTranspose1d(
                in_channels, out_channels, kernel_size, stride=stride, causal=causal
            )
        self.layer = weight_norm(core) if use_weight_norm else core

    def forward(self, x):
        x = self.layer(x)
        if self.activation is not None:
            x = self.activation(x)
        if self.repeat:
            # Nearest-neighbour upsampling: duplicate each frame `stride` times.
            x = x.transpose(1, 2)
            batch, frames, channels = x.size()
            x = x.repeat(1, 1, self.stride).view(batch, -1, channels)
            x = x.transpose(1, 2)
        return x

    def remove_weight_norm(self):
        if self.use_weight_norm:
            remove_weight_norm(self.layer)
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
class round_func9(InplaceFunction):
    """Straight-through quantiser onto the uniform grid {k/9 : k integer}.

    Forward snaps 9*x to the nearest integer (torch.round semantics: ties to
    even) and divides by 9. Backward is the straight-through estimator: the
    gradient passes through unchanged, treating rounding as identity.
    """

    @staticmethod
    def forward(ctx, input):
        # Fix: the old `ctx.input = input` kept the whole activation tensor
        # alive across forward/backward although backward never reads it;
        # nothing needs to be saved for the straight-through gradient.
        return torch.round(9 * input) / 9

    @staticmethod
    def backward(ctx, grad_output):
        # Identity gradient (straight-through).
        grad_input = grad_output.clone()
        return grad_input
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
class ScalarModel(nn.Module):
    """Scalar-quantisation (SQ) audio codec: conv encoder -> tanh -> round to
    the {k/9} grid -> conv decoder.

    There is no learned codebook: the last encoder conv is squashed to
    [-1, 1] by tanh and snapped to a uniform 19-level grid by `round_func9`
    (straight-through gradients), so quantisation is a fixed element-wise op.
    """

    def __init__(
        self,
        num_bands,
        sample_rate,
        causal,
        num_samples,
        downsample_factors,
        downsample_kernel_sizes,
        upsample_factors,
        upsample_kernel_sizes,
        latent_hidden_dim,
        default_kernel_size,
        delay_kernel_size,
        init_channel,
        res_kernel_size,
        mode="pre_proj",
    ):
        # Parameters:
        #   num_bands: input/output channel count.
        #   sample_rate: kept for config compatibility; not used in this class.
        #   causal: causal convolutions throughout encoder/decoder.
        #   num_samples: extra pre/post pooling factor; >1 adds Pre/PostProcessor.
        #   downsample_factors / upsample_factors: per-stage strides (channel
        #       width doubles per encoder stage, halves per decoder stage).
        #   latent_hidden_dim: channel width of the quantised latent.
        #   delay_kernel_size: kernel of the decoder's first ("look ahead")
        #       conv, which is deliberately NOT causal.
        #   init_channel: base channel count.
        #   mode: stored but not read in this class.
        super(ScalarModel, self).__init__()
        self.encoder = []
        self.decoder = []
        # `round_func9` is an autograd Function; only its .apply is used below.
        self.vq = round_func9()  # using 9
        self.mode = mode
        # Encoder parts
        self.encoder.append(
            weight_norm(
                Conv1d(
                    num_bands,
                    init_channel,
                    kernel_size=default_kernel_size,
                    causal=causal,
                )
            )
        )
        if num_samples > 1:
            # Downsampling
            self.encoder.append(
                PreProcessor(
                    init_channel,
                    init_channel,
                    num_samples,
                    kernel_size=default_kernel_size,
                    causal=causal,
                )
            )
        # Each stage doubles the channel count while striding down in time.
        for i, down_factor in enumerate(downsample_factors):
            self.encoder.append(
                ResEncoderBlock(
                    init_channel * np.power(2, i),
                    init_channel * np.power(2, i + 1),
                    down_factor,
                    downsample_kernel_sizes[i],
                    res_kernel_size,
                    causal=causal,
                )
            )
        # Project to the latent width that gets quantised.
        self.encoder.append(
            weight_norm(
                Conv1d(
                    init_channel * np.power(2, len(downsample_factors)),
                    latent_hidden_dim,
                    kernel_size=default_kernel_size,
                    causal=causal,
                )
            )
        )
        # Decoder
        # look ahead
        self.decoder.append(
            weight_norm(
                Conv1d(
                    latent_hidden_dim,
                    init_channel * np.power(2, len(upsample_factors)),
                    kernel_size=delay_kernel_size,
                )
            )
        )
        for i, upsample_factor in enumerate(upsample_factors):
            self.decoder.append(
                ResDecoderBlock(
                    init_channel * np.power(2, len(upsample_factors) - i),
                    init_channel * np.power(2, len(upsample_factors) - i - 1),
                    upsample_factor,
                    upsample_kernel_sizes[i],
                    res_kernel_size,
                    causal=causal,
                )
            )
        if num_samples > 1:
            self.decoder.append(
                PostProcessor(
                    init_channel,
                    init_channel,
                    num_samples,
                    kernel_size=default_kernel_size,
                    causal=causal,
                )
            )
        # Final projection back to the input band count.
        self.decoder.append(
            weight_norm(
                Conv1d(
                    init_channel,
                    num_bands,
                    kernel_size=default_kernel_size,
                    causal=causal,
                )
            )
        )
        self.encoder = nn.ModuleList(self.encoder)
        self.decoder = nn.ModuleList(self.decoder)

    def forward(self, x):
        """Full autoencode pass: encode, quantise, decode; returns the reconstruction."""
        for i, layer in enumerate(self.encoder):
            if i != len(self.encoder) - 1:
                x = layer(x)
            else:
                # Last encoder conv is squashed to [-1, 1] so the fixed-grid
                # quantiser below is well bounded.
                x = F.tanh(layer(x))
        x = self.vq.apply(x)  # vq
        for i, layer in enumerate(self.decoder):
            x = layer(x)
        return x

    def inference(self, x):
        """Like forward(), but also returns the pre- and post-quantisation latents.

        Returns:
            (emb, emb_quant, reconstruction)
        """
        for i, layer in enumerate(self.encoder):
            if i != len(self.encoder) - 1:
                x = layer(x)
            else:
                x = F.tanh(layer(x))  # reverse to tanh

        emb = x
        emb_quant = self.vq.apply(emb)  # vq
        x = emb_quant
        for i, layer in enumerate(self.decoder):
            x = layer(x)
        return emb, emb_quant, x

    def encode(self, x):
        """Encode to the continuous (pre-quantisation) latent.

        NOTE(review): `emb_quant` is computed but neither used nor returned —
        dead work; callers receive the *unquantised* embedding `emb`.
        """
        for i, layer in enumerate(self.encoder):
            if i != len(self.encoder) - 1:
                x = layer(x)
            else:
                x = F.tanh(layer(x))  # reverse to tanh

        emb = x
        emb_quant = self.vq.apply(emb)  # vq
        return emb

    def decode(self, x):
        """Quantise `x` onto the codec grid and run the decoder stack."""
        x = self.vq.apply(
            x
        )  # make sure the prediction follow the similar disctribution
        for i, layer in enumerate(self.decoder):
            x = layer(x)
        return x
|
src/heartlib/heartcodec/models/transformer.py
ADDED
|
@@ -0,0 +1,501 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import Optional, Tuple
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class RMSNorm(nn.Module):
    """Root-mean-square layer norm (no mean subtraction), Llama-style."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalise by the RMS of the last dimension, then apply the
        # learnable per-feature gain.
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        normalized = x * torch.rsqrt(mean_square + self.eps)
        return self.weight * normalized
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class RotaryEmbedding(nn.Module):
    """Precomputes and caches RoPE sin/cos tables for a rotary dimension `dim`."""

    def __init__(self, dim: int, base: int = 10000):
        super().__init__()
        self.dim = dim
        self.base = base
        # Cache keyed by (seq_len, device, dtype). Unbounded, but in practice
        # only a handful of sequence lengths occur.
        self._cache = {}

    def get_sin_cos(self, seq_len: int, device, dtype):
        """Return (sin, cos) tables, each of shape (seq_len, dim // 2)."""
        key = (seq_len, device, dtype)
        cached = self._cache.get(key, None)
        # The device re-check is redundant (device is part of the key) but harmless.
        if cached is not None and cached[0].device == device:
            return cached
        # Standard RoPE inverse frequencies: base^(-2i/dim) for i in [0, dim/2).
        inv_freq = 1.0 / (
            self.base
            ** (torch.arange(0, self.dim, 2, device=device, dtype=dtype) / self.dim)
        )
        t = torch.arange(seq_len, device=device, dtype=dtype)
        # Outer product: freqs[p, i] = position p * inv_freq[i].
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        sin = freqs.sin()
        cos = freqs.cos()
        self._cache[key] = (sin, cos)
        return sin, cos

    def apply_rotary(
        self, x: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor
    ) -> torch.Tensor:
        # NOTE(review): this method appears unused — LlamaAttention applies
        # RoPE via its own local helper — and its broadcasting looks
        # inconsistent (sin/cos have dim//2 entries while x[..., :dim] has
        # dim entries, so `cos.unsqueeze(-1)` cannot broadcast against it).
        # Confirm before relying on it.
        x1, x2 = x[..., : self.dim // 2], x[..., self.dim // 2 : self.dim]
        # Interleave sin/cos across pairs
        x_rot = torch.stack((-x2, x1), dim=-1).reshape_as(x[..., : self.dim])
        return (x[..., : self.dim] * cos.unsqueeze(-1)).reshape_as(
            x[..., : self.dim]
        ) + (x_rot * sin.unsqueeze(-1)).reshape_as(x[..., : self.dim])
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class LlamaAttention(nn.Module):
    """Multi-head attention with partial RoPE, optional cross-attention and SDPA.

    When `cross_attention_dim` is set, k/v are projected from
    `encoder_hidden_states`; otherwise this is self-attention. RoPE is applied
    to the first `rope_dim` channels of each head, on both q and k.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        head_dim: int,
        bias: bool = False,
        dropout: float = 0.0,
        rope_dim: Optional[int] = None,
        cross_attention_dim: Optional[int] = None,
        use_sdpa: bool = True,
    ):
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.head_dim = head_dim
        # Total projection width; may differ from `dim` (o_proj maps it back).
        self.inner_dim = n_heads * head_dim
        self.cross_attention_dim = cross_attention_dim
        self.q_proj = nn.Linear(dim, self.inner_dim, bias=bias)
        # k/v read from the encoder states in cross-attention mode.
        k_in = dim if cross_attention_dim is None else cross_attention_dim
        self.k_proj = nn.Linear(k_in, self.inner_dim, bias=bias)
        self.v_proj = nn.Linear(k_in, self.inner_dim, bias=bias)
        self.o_proj = nn.Linear(self.inner_dim, dim, bias=bias)
        self.dropout = dropout
        self.rope_dim = rope_dim if rope_dim is not None else head_dim
        self.rope = RotaryEmbedding(self.rope_dim)
        self.use_sdpa = use_sdpa
        self._has_sdpa = hasattr(F, "scaled_dot_product_attention")

    def _shape(self, x: torch.Tensor, b: int, t: int) -> torch.Tensor:
        # (b, t, inner_dim) -> (b, n_heads, t, head_dim)
        return x.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)

    def forward(
        self,
        x: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        b, t, c = x.shape
        q = self._shape(self.q_proj(x), b, t)
        if encoder_hidden_states is None:
            k = self._shape(self.k_proj(x), b, t)
            v = self._shape(self.v_proj(x), b, t)
        else:
            # Assumes encoder batch == x batch (reshaped with b, not bt).
            bt, tk, ck = encoder_hidden_states.shape
            k = self._shape(self.k_proj(encoder_hidden_states), b, tk)
            v = self._shape(self.v_proj(encoder_hidden_states), b, tk)

        # RoPE on first rope_dim of head_dim
        rope_dim = min(self.rope_dim, self.head_dim)
        # Tables are sized to the key length.
        # NOTE(review): in cross-attention with t != tk, applying the same
        # tables to q (length t) would hit a view-size mismatch below —
        # looks like this path assumes equal q/k lengths; confirm.
        seq_len_for_rope = k.shape[-2]
        sin, cos = self.rope.get_sin_cos(
            seq_len_for_rope, device=x.device, dtype=x.dtype
        )

        def apply_rope_vec(tensor):
            # Rotate the first rope_dim channels pairwise; pass the rest through.
            head = tensor[..., :rope_dim]
            tail = tensor[..., rope_dim:]
            b, h, tt, _ = head.shape
            head = head.view(b, h, tt, rope_dim // 2, 2)
            sin_ = sin.view(1, 1, tt, rope_dim // 2, 1)
            cos_ = cos.view(1, 1, tt, rope_dim // 2, 1)
            x1 = head[..., 0:1]
            x2 = head[..., 1:2]
            # Standard 2-D rotation of each (x1, x2) pair by the position angle.
            rot = torch.cat(
                [x1 * cos_ - x2 * sin_, x1 * sin_ + x2 * cos_], dim=-1
            ).view(b, h, tt, rope_dim)
            return torch.cat([rot, tail], dim=-1)

        q = apply_rope_vec(q)
        k = apply_rope_vec(k)

        # Prefer PyTorch SDPA (can enable FlashAttention kernel on supported GPUs)
        if self.use_sdpa and self._has_sdpa:
            s = k.shape[-2]
            attn_mask_sdpa = None
            if attention_mask is not None:
                # Normalise the mask to a 4-D, head-broadcastable shape.
                m = attention_mask

                if m.dim() == 2 and m.shape == (b, s):  # [b, s]
                    m = m[:, None, None, :]  # [b,1,1,s]
                elif m.dim() == 3 and m.shape[-2] == 1:  # [b,1,s]
                    m = m[:, None, :, :]  # [b,1,1,s]
                elif m.dim() == 3 and m.shape[-2] == t:  # [b,t,s]
                    m = m[:, None, :, :]  # [b,1,t,s]
                elif m.dim() == 4 and m.shape[1] == 1:  # [b,1,t,s] or [b,1,1,s]
                    pass
                # Unrecognised shapes are passed through unchanged.
                attn_mask_sdpa = m

            out = F.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=attn_mask_sdpa,
                dropout_p=self.dropout if self.training else 0.0,
                is_causal=False,
            )
            out = out.transpose(1, 2).contiguous().view(b, t, self.inner_dim)
            return self.o_proj(out)
        else:
            # Manual attention fallback; expects an *additive* attention_mask.
            attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(
                self.head_dim
            )
            if attention_mask is not None:
                attn_scores = attn_scores + attention_mask
            attn = attn_scores.softmax(dim=-1)
            attn = F.dropout(attn, p=self.dropout, training=self.training)
            out = torch.matmul(attn, v)
            out = out.transpose(1, 2).contiguous().view(b, t, self.inner_dim)
            return self.o_proj(out)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class LlamaMLP(nn.Module):
    """SwiGLU feed-forward block with Llama-style hidden-size rounding."""

    def __init__(
        self,
        dim: int,
        hidden_dim: Optional[int] = None,
        multiple_of: int = 256,
        dropout: float = 0.0,
    ):
        super().__init__()
        hidden_dim = hidden_dim or 4 * dim
        # Llama convention: shrink to 2/3 (SwiGLU has three matrices), then
        # round up to the nearest multiple of `multiple_of`.
        hidden_dim = int(2 * hidden_dim / 3)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        self.gate = nn.Linear(dim, hidden_dim, bias=False)
        self.up = nn.Linear(dim, hidden_dim, bias=False)
        self.down = nn.Linear(hidden_dim, dim, bias=False)
        self.dropout = dropout

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gated = F.silu(self.gate(x)) * self.up(x)
        gated = F.dropout(gated, p=self.dropout, training=self.training)
        return self.down(gated)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
class LlamaTransformerBlock(nn.Module):
    """Pre-norm Llama block (RMSNorm + attention + SwiGLU MLP), optionally with
    AdaLN-single timestep modulation (PixArt-style scale/shift/gate).

    NOTE(review): when `cross_attention_dim` is set, a cross-attention module
    is constructed but never invoked in forward(), and `encoder_hidden_states`
    is accepted but unused — confirm whether cross-attention is applied
    elsewhere or is dead weight here.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        head_dim: int,
        mlp_multiple_of: int = 256,
        dropout: float = 0.0,
        attention_bias: bool = False,
        cross_attention_dim: Optional[int] = None,
        use_ada_layer_norm_single: bool = False,
    ):
        super().__init__()
        self.attn_norm = RMSNorm(dim, 1e-6)
        self.attn = LlamaAttention(
            dim,
            n_heads,
            head_dim,
            bias=attention_bias,
            dropout=dropout,
            rope_dim=head_dim,
            cross_attention_dim=None,
        )
        self.cross_attn = None
        if cross_attention_dim is not None:
            self.cross_attn_norm = RMSNorm(dim, 1e-6)
            self.cross_attn = LlamaAttention(
                dim,
                n_heads,
                head_dim,
                bias=attention_bias,
                dropout=dropout,
                rope_dim=head_dim,
                cross_attention_dim=cross_attention_dim,
            )
        self.mlp_norm = RMSNorm(dim, 1e-6)
        self.mlp = LlamaMLP(dim, multiple_of=mlp_multiple_of, dropout=dropout)
        self.use_ada_layer_norm_single = use_ada_layer_norm_single
        if self.use_ada_layer_norm_single:
            # Learnable base table for the 6 modulation vectors
            # (shift/scale/gate for attention and for the MLP).
            self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)

    def forward(
        self,
        x: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        timestep: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if self.use_ada_layer_norm_single:
            batch_size = x.shape[0]
            # timestep: [B, 6*D]
            # Base table plus the timestep embedding yields per-sample
            # shift/scale/gate for both sub-layers.
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
                self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
            ).chunk(6, dim=1)

            # Self-Attention with modulation and gating
            norm_hidden_states = self.attn_norm(x)
            norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
            h = self.attn(norm_hidden_states, attention_mask=attention_mask)
            h = gate_msa * h
            x = x + h

            # MLP with modulation and gating
            norm_hidden_states = self.mlp_norm(x)
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
            h = self.mlp(norm_hidden_states)
            h = gate_mlp * h
            x = x + h
            return x
        else:
            # Plain pre-norm residual block; `timestep` is ignored here.
            h = self.attn(self.attn_norm(x), attention_mask=attention_mask)
            x = x + h
            h = self.mlp(self.mlp_norm(x))
            x = x + h
            return x
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
class ProjectLayer(nn.Module):
    """Channel projection: Conv1d (same-length padding) -> scale -> Linear.

    The conv output is scaled by ``kernel_size ** -0.5`` to keep variance
    roughly independent of the kernel width. Input/output layout is
    [batch, time, channels].
    """

    def __init__(self, hidden_size, filter_size, kernel_size=1, dropout=0.0):
        super().__init__()
        self.kernel_size = kernel_size
        self.dropout = dropout
        self.ffn_1 = nn.Conv1d(
            hidden_size, filter_size, kernel_size, padding=kernel_size // 2
        )
        self.ffn_2 = nn.Linear(filter_size, filter_size)

    def forward(self, x):
        # Conv1d wants [B, C, T]; move channels, convolve, move them back.
        conv_out = self.ffn_1(x.transpose(1, 2)).transpose(1, 2)
        scaled = conv_out * self.kernel_size**-0.5
        return self.ffn_2(scaled)
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
class LlamaTransformer(nn.Module):
    """Two-stage Llama-block transformer with a widened refinement stage.

    Stage 1 runs ``num_layers`` blocks at ``inner_dim``; its normalized output
    is concatenated with the raw input and projected up to ``inner_dim_2``
    (= 2 * inner_dim) for ``num_layers_2`` refinement blocks. Each stage has
    its own AdaLN-single timestep conditioning (flow-matching style).
    """

    def __init__(
        self,
        num_attention_heads: int,
        attention_head_dim: int,
        in_channels: int,
        out_channels: int,
        num_layers: int = 12,
        num_layers_2: int = 2,
        dropout: float = 0.0,
        cross_attention_dim: Optional[int] = None,
        norm_type: str = "layer_norm",
    ):
        super().__init__()
        inner_dim = num_attention_heads * attention_head_dim
        inner_dim_2 = inner_dim * 2
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.inner_dim = inner_dim
        self.inner_dim_2 = inner_dim_2
        self.dropout = dropout

        self.proj_in = ProjectLayer(in_channels, inner_dim, kernel_size=3)

        # AdaLN-single modulation is enabled for both stages by norm_type.
        use_ada_single = norm_type == "ada_norm_single"
        self.transformer_blocks = nn.ModuleList(
            [
                LlamaTransformerBlock(
                    dim=inner_dim,
                    n_heads=num_attention_heads,
                    head_dim=attention_head_dim,
                    dropout=dropout,
                    attention_bias=False,
                    cross_attention_dim=cross_attention_dim,
                    use_ada_layer_norm_single=use_ada_single,
                )
                for _ in range(num_layers)
            ]
        )

        # Stage 2: same head count, doubled head_dim -> doubled model width.
        self.transformer_blocks_2 = nn.ModuleList(
            [
                LlamaTransformerBlock(
                    dim=inner_dim_2,
                    n_heads=num_attention_heads,
                    head_dim=attention_head_dim * 2,
                    dropout=dropout,
                    attention_bias=False,
                    cross_attention_dim=cross_attention_dim,
                    use_ada_layer_norm_single=use_ada_single,
                )
                for _ in range(num_layers_2)
            ]
        )

        # Bridges [input || stage-1 output] into the wide stage-2 space.
        self.connection_proj = ProjectLayer(
            in_channels + inner_dim, inner_dim_2, kernel_size=3
        )
        self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
        self.norm_out_2 = nn.LayerNorm(inner_dim_2, elementwise_affine=False, eps=1e-6)
        # Learned shift/scale bases for the two output modulations.
        self.scale_shift_table = nn.Parameter(
            torch.randn(2, inner_dim) / inner_dim**0.5
        )
        self.scale_shift_table_2 = nn.Parameter(
            torch.randn(2, inner_dim_2) / inner_dim_2**0.5
        )
        self.proj_out = ProjectLayer(inner_dim_2, out_channels, kernel_size=3)
        self.adaln_single = AdaLayerNormSingleFlow(inner_dim)
        self.adaln_single_2 = AdaLayerNormSingleFlow(inner_dim_2)

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: Optional[torch.LongTensor] = None,
    ):
        """Map [B, T, in_channels] (+ optional timestep) to [B, T, out_channels]."""
        s = self.proj_in(hidden_states)

        embedded_timestep = None
        timestep_mod = None
        if self.adaln_single is not None and timestep is not None:
            batch_size = s.shape[0]
            timestep_mod, embedded_timestep = self.adaln_single(
                timestep, hidden_dtype=s.dtype
            )
        for blk in self.transformer_blocks:
            s = blk(s, timestep=timestep_mod)

        # No timestep given: fall back to zero conditioning for the output norm.
        if embedded_timestep is None:
            embedded_timestep = torch.zeros(
                s.size(0), s.size(-1), device=s.device, dtype=s.dtype
            )

        shift, scale = (
            self.scale_shift_table[None] + embedded_timestep[:, None]
        ).chunk(2, dim=1)
        s = self.norm_out(s)
        s = s * (1 + scale) + shift

        # Skip connection: feed the raw input alongside stage-1 features.
        x = torch.cat([hidden_states, s], dim=-1)
        x = self.connection_proj(x)

        embedded_timestep_2 = None
        timestep_mod_2 = None
        if self.adaln_single_2 is not None and timestep is not None:
            batch_size = x.shape[0]
            timestep_mod_2, embedded_timestep_2 = self.adaln_single_2(
                timestep, hidden_dtype=x.dtype
            )
        for blk in self.transformer_blocks_2:
            x = blk(x, timestep=timestep_mod_2)

        if embedded_timestep_2 is None:
            embedded_timestep_2 = torch.zeros(
                x.size(0), x.size(-1), device=x.device, dtype=x.dtype
            )

        shift_2, scale_2 = (
            self.scale_shift_table_2[None] + embedded_timestep_2[:, None]
        ).chunk(2, dim=1)
        x = self.norm_out_2(x)
        x = x * (1 + scale_2) + shift_2

        out = self.proj_out(x)

        return out
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
class PixArtAlphaCombinedFlowEmbeddings(nn.Module):
    """Sinusoidal flow-timestep embedding followed by a two-layer MLP.

    ``size_emb_dim`` is stored but unused in this trimmed variant — kept for
    interface parity with PixArt-alpha's combined embeddings.
    """

    def __init__(self, embedding_dim: int, size_emb_dim: int):
        super().__init__()
        # Width of the raw [cos | sin] feature vector fed to the MLP.
        self.flow_t_size = 512
        self.outdim = size_emb_dim
        self.timestep_embedder = TimestepEmbedding(
            in_channels=self.flow_t_size, time_embed_dim=embedding_dim
        )

    def timestep_embedding(self, timesteps, max_period=10000, scale=1000):
        """Return [cos(args) | sin(args)] features of width ``flow_t_size``.

        ``scale`` stretches the (typically 0-1) flow time so the sinusoids
        cover a useful frequency range.
        """
        half = self.flow_t_size // 2
        # Geometric frequency ladder from 1 down to ~1/max_period.
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(start=0, end=half, device=timesteps.device)
            / half
        ).type(timesteps.type())
        args = timesteps[:, None] * freqs[None] * scale
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if self.flow_t_size % 2:
            # Odd target width: pad one zero column (unreachable for 512).
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
        return embedding

    def forward(self, timestep, hidden_dtype):
        """Embed ``timestep`` and project it to the conditioning dimension."""
        timesteps_proj = self.timestep_embedding(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))
        conditioning = timesteps_emb
        return conditioning
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
class AdaLayerNormSingleFlow(nn.Module):
    """AdaLN-single conditioning head for flow timesteps.

    Embeds the timestep once and emits a 6*dim modulation vector that the
    transformer blocks later split into shift/scale/gate pairs.
    """

    def __init__(self, embedding_dim: int):
        super().__init__()
        self.emb = PixArtAlphaCombinedFlowEmbeddings(
            embedding_dim, size_emb_dim=embedding_dim // 3
        )
        self.silu = nn.SiLU()
        # 6 chunks: shift/scale/gate for attention and for the MLP.
        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)

    def forward(
        self,
        timestep: torch.Tensor,
        hidden_dtype: Optional[torch.dtype] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (modulation [B, 6*dim], embedded_timestep [B, dim])."""
        embedded_timestep = self.emb(timestep, hidden_dtype=hidden_dtype)
        return self.linear(self.silu(embedded_timestep)), embedded_timestep
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
class TimestepEmbedding(nn.Module):
    """Two-layer MLP with SiLU that lifts raw timestep features to the model dim."""

    def __init__(self, in_channels: int, time_embed_dim: int):
        super().__init__()
        self.linear_1 = nn.Linear(in_channels, time_embed_dim)
        self.act = nn.SiLU()
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply linear -> SiLU -> linear."""
        return self.linear_2(self.act(self.linear_1(x)))
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
class Timesteps(nn.Module):
    """Classic sinusoidal timestep features (DDPM-style).

    Maps a 1-D batch of scalar timesteps to ``num_channels`` sin/cos features.
    ``flip_sin_to_cos=True`` places the cosine half first; odd channel counts
    are zero-padded with one trailing column.
    """

    def __init__(
        self,
        num_channels: int,
        flip_sin_to_cos: bool = True,
        downscale_freq_shift: float = 0,
    ):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift

    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
        """Return [B, num_channels] sinusoidal features for [B] timesteps."""
        half = self.num_channels // 2
        # Geometric frequency ladder; the shift tweaks the decay denominator.
        exponent = torch.arange(0, half, device=timesteps.device) * (
            -math.log(10000) / (half - self.downscale_freq_shift)
        )
        angles = torch.exp(exponent)[None, :] * timesteps[:, None]
        if self.flip_sin_to_cos:
            features = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
        else:
            features = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
        if self.num_channels % 2 == 1:
            features = torch.nn.functional.pad(features, (0, 1))
        return features
|
src/heartlib/heartmula/configuration_heartmula.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class HeartMuLaConfig(PretrainedConfig):
    """Configuration for the HeartMuLa music language model.

    Args:
        backbone_flavor: key into the FLAVORS registry selecting the large
            torchtune Llama backbone (e.g. "llama-3B").
        decoder_flavor: FLAVORS key for the small per-frame codebook decoder.
        text_vocab_size: size of the text token vocabulary.
        audio_vocab_size: tokens per audio codebook (incl. special ids).
        audio_num_codebooks: number of parallel RVQ codebooks per frame.
        muq_dim: dimensionality of the MuQ/MuLan conditioning embedding.
    """

    model_type = "heartmula"

    def __init__(
        self,
        backbone_flavor: str = "llama-3B",
        decoder_flavor: str = "llama-300M",
        text_vocab_size: int = 128256,
        audio_vocab_size: int = 8197,
        audio_num_codebooks: int = 8,
        muq_dim: int = 512,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.backbone_flavor = backbone_flavor
        self.decoder_flavor = decoder_flavor
        self.text_vocab_size = text_vocab_size
        self.audio_vocab_size = audio_vocab_size
        self.audio_num_codebooks = audio_num_codebooks
        self.muq_dim = muq_dim
|
src/heartlib/heartmula/modeling_heartmula.py
ADDED
|
@@ -0,0 +1,316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from .configuration_heartmula import HeartMuLaConfig
|
| 4 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torchtune
|
| 8 |
+
from torchtune.models import llama3_2
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def llama3_2_3B() -> torchtune.modules.transformer.TransformerDecoder:
    """Build the ~3B-parameter Llama-3.2 decoder ("llama-3B" flavor)."""
    return llama3_2.llama3_2(
        vocab_size=128_256,
        num_layers=28,
        num_heads=24,
        num_kv_heads=8,
        embed_dim=3072,
        max_seq_len=8192,
        intermediate_dim=8192,
        attn_dropout=0.0,
        norm_eps=1e-5,
        rope_base=500_000,
        scale_factor=32,
    )
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def llama3_2_300M() -> torchtune.modules.transformer.TransformerDecoder:
    """Build a small ~300M-parameter Llama-3.2 decoder ("llama-300M" flavor)."""
    return llama3_2.llama3_2(
        vocab_size=128_256,
        num_layers=3,
        num_heads=8,
        num_kv_heads=4,
        embed_dim=3072,
        max_seq_len=2048,
        intermediate_dim=8192,
        attn_dropout=0.0,
        norm_eps=1e-5,
        rope_base=500_000,
        scale_factor=32,
    )
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def llama3_2_7B() -> torchtune.modules.transformer.TransformerDecoder:
    """Build a ~7/8B-parameter Llama-3.2 decoder ("llama-7B" flavor)."""
    return llama3_2.llama3_2(
        vocab_size=128_256,
        num_layers=32,
        num_heads=32,
        num_kv_heads=8,
        embed_dim=4096,
        max_seq_len=8192,
        intermediate_dim=14336,
        attn_dropout=0.0,
        norm_eps=1e-5,
        rope_base=500_000,
        scale_factor=32,
    )
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def llama3_2_400M() -> torchtune.modules.transformer.TransformerDecoder:
    """Build a ~400M-parameter Llama-3.2 decoder ("llama-400M" flavor)."""
    return llama3_2.llama3_2(
        vocab_size=128_256,
        num_layers=4,
        num_heads=8,
        num_kv_heads=4,
        embed_dim=3072,
        max_seq_len=2048,
        intermediate_dim=8192,
        attn_dropout=0.0,
        norm_eps=1e-5,
        rope_base=500_000,
        scale_factor=32,
    )  # reduced num_heads/num_kv_heads ratio: better accuracy, lower efficiency
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
# Registry mapping flavor names (referenced by HeartMuLaConfig) to the
# torchtune decoder factory functions above.
FLAVORS = {
    "llama-3B": llama3_2_3B,
    "llama-300M": llama3_2_300M,
    "llama-7B": llama3_2_7B,
    "llama-400M": llama3_2_400M,
}
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def _prepare_transformer(model):
|
| 84 |
+
embed_dim = model.tok_embeddings.embedding_dim
|
| 85 |
+
model.tok_embeddings = nn.Identity()
|
| 86 |
+
model.output = nn.Identity()
|
| 87 |
+
return model, embed_dim
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def _create_causal_mask(seq_len: int, device: torch.device):
|
| 91 |
+
return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def _index_causal_mask(mask: torch.Tensor, input_pos: torch.Tensor):
|
| 95 |
+
r = mask[input_pos, :]
|
| 96 |
+
return r
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def _multinomial_sample_one_no_sync(
    probs,
):  # Does multinomial sampling without a cuda synchronization
    # Exponential-noise trick: argmax(probs / Exp(1)) samples from the
    # categorical distribution without the host-device sync that
    # torch.multinomial incurs.
    q = torch.empty_like(probs).exponential_(1)
    return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int)


def sample_topk(logits: torch.Tensor, topk: int, temperature: float):
    """Sample one token id per row from the temperature-scaled top-k distribution.

    Args:
        logits: [..., vocab] unnormalized scores.
        topk: number of highest-scoring candidates kept per row.
        temperature: softmax temperature (>1 flattens, <1 sharpens).

    Returns:
        int32 tensor of shape [..., 1] with one sampled id per row.
    """
    logits = logits / temperature

    filter_value: float = -float("Inf")
    # Mask out everything strictly below the k-th largest logit.
    indices_to_remove = logits < torch.topk(logits, topk)[0][..., -1, None]
    scores_processed = logits.masked_fill(indices_to_remove, filter_value)
    # Fix: the original ran log_softmax and then softmax on its result;
    # softmax(log_softmax(x)) == softmax(x), so one softmax suffices.
    probs = torch.nn.functional.softmax(scores_processed, dim=-1)

    sample_token = _multinomial_sample_one_no_sync(probs)
    return sample_token
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class HeartMuLa(PreTrainedModel):
    """Two-stage music language model.

    A large torchtune Llama backbone predicts codebook 0 of each audio frame
    from the interleaved text+audio context; a small Llama decoder then
    autoregressively predicts the remaining ``audio_num_codebooks - 1``
    codebooks within the frame. Supports classifier-free guidance (CFG) by
    stacking conditional and unconditional halves in the batch dimension.
    """

    config_class = HeartMuLaConfig

    def __init__(
        self,
        config: HeartMuLaConfig,
    ):
        super(HeartMuLa, self).__init__(config)

        self.config = config

        # Strip embeddings/heads from both torchtune decoders; embeddings and
        # output projections are owned by this wrapper instead.
        self.backbone, backbone_dim = _prepare_transformer(
            FLAVORS[config.backbone_flavor]()
        )
        self.decoder, decoder_dim = _prepare_transformer(
            FLAVORS[config.decoder_flavor]()
        )

        self.text_embeddings = nn.Embedding(config.text_vocab_size, backbone_dim)
        # One shared table for all codebooks; codebook i's tokens are offset
        # by i * audio_vocab_size before lookup.
        self.audio_embeddings = nn.Embedding(
            config.audio_vocab_size * config.audio_num_codebooks, backbone_dim
        )
        # Learned "null" embedding substituted on the unconditional CFG branch.
        self.unconditional_text_embedding = nn.Embedding(1, backbone_dim)

        self.projection = nn.Linear(backbone_dim, decoder_dim, bias=False)
        # Codebook 0 is predicted directly from the backbone's hidden state.
        self.codebook0_head = nn.Linear(
            backbone_dim, config.audio_vocab_size, bias=False
        )
        # Per-codebook output heads for codebooks 1..N-1 (decoder stage).
        self.audio_head = nn.Parameter(
            torch.empty(
                config.audio_num_codebooks - 1, decoder_dim, config.audio_vocab_size
            )
        )
        # Projects MuQ conditioning embeddings into the backbone space.
        self.muq_linear = nn.Linear(config.muq_dim, backbone_dim)
        self.post_init()

    def setup_caches(self, max_batch_size: int):
        """Allocate KV caches for both transformers and the causal mask buffers."""
        dtype = next(self.parameters()).dtype
        device = next(self.parameters()).device

        # torchtune raises RuntimeError if no caches exist yet; ignore that.
        try:
            self.reset_caches()
        except RuntimeError:
            pass

        with device:
            self.backbone.setup_caches(max_batch_size, dtype)
            # The decoder only ever sees one frame of codebooks at a time.
            self.decoder.setup_caches(
                max_batch_size,
                dtype,
                decoder_max_seq_len=self.config.audio_num_codebooks,
            )

        self.register_buffer(
            "backbone_causal_mask",
            _create_causal_mask(self.backbone.max_seq_len, device),
        )
        self.register_buffer(
            "decoder_causal_mask",
            _create_causal_mask(self.config.audio_num_codebooks, device),
        )

    def generate_frame(
        self,
        tokens: torch.Tensor,
        tokens_mask: torch.Tensor,
        input_pos: torch.Tensor,
        temperature: float,
        topk: int,
        cfg_scale: float,
        continuous_segments: torch.Tensor = None,
        starts=None,
    ) -> torch.Tensor:
        """Sample all codebooks for the next audio frame.

        Args:
            tokens: [B, S, num_codebooks + 1] token grid; last channel is text.
            tokens_mask: [B, S, num_codebooks + 1] validity mask for summing.
            input_pos: positions of ``tokens`` in the backbone KV cache.
            temperature / topk: sampling parameters (see ``sample_topk``).
            cfg_scale: >1 enables classifier-free guidance; the batch is then
                [cond | uncond] halves stacked along dim 0.
            continuous_segments: optional MuQ embeddings injected at ``starts``.
            starts: per-sample index where the MuQ embedding replaces h.

        Returns:
            [B, num_codebooks] sampled token ids for the new frame.
        """
        b, s, _ = tokens.size()

        assert self.backbone.caches_are_enabled(), "backbone caches are not enabled"
        curr_backbone_mask = _index_causal_mask(self.backbone_causal_mask, input_pos)

        # First half of the batch is conditional, second half unconditional.
        uncond_mask = None
        if cfg_scale > 1.0 and b > 1:
            actual_B = b // 2
            uncond_mask = torch.cat(
                [
                    torch.zeros(actual_B, dtype=torch.bool, device=tokens.device),
                    torch.ones(actual_B, dtype=torch.bool, device=tokens.device),
                ]
            )

        embeds = self._embed_tokens(tokens, uncond_mask=uncond_mask)
        masked_embeds = embeds * tokens_mask.unsqueeze(-1)
        h = masked_embeds.sum(dim=2, dtype=embeds.dtype)  # merge
        if continuous_segments is not None:
            continuous_segments = self.muq_linear(continuous_segments)
            if uncond_mask is not None:
                # Unconditional branch sees the learned null embedding instead.
                uncond_embed = self.unconditional_text_embedding(
                    torch.zeros(1, device=tokens.device, dtype=torch.long)
                )
                mask_expanded = uncond_mask.view(b, 1).expand_as(continuous_segments)
                continuous_segments = torch.where(
                    mask_expanded, uncond_embed, continuous_segments
                )
            batch_indices = torch.arange(h.shape[0], device=h.device)
            h[batch_indices, starts] = continuous_segments
        h = self.backbone(h, input_pos=input_pos, mask=curr_backbone_mask)
        last_h = h[:, -1, :]  # the last frame
        c0_logits = self.codebook0_head(last_h)  # only predict the audio part

        if cfg_scale > 1.0 and b > 1 and (b % 2 == 0):
            actual_B = b // 2
            cond_logits = c0_logits[:actual_B, :]
            uncond_logits = c0_logits[actual_B:, :]
            # Standard CFG extrapolation: uncond + scale * (cond - uncond).
            guided_logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
            c0_sample = sample_topk(guided_logits, topk, temperature)
            c0_sample = c0_sample.repeat(
                2, 1
            )  # repeat to both branches to keep alignment
        else:
            c0_sample = sample_topk(c0_logits, topk, temperature)

        c0_embed = self._embed_audio(0, c0_sample)

        # Inner decoder runs fresh per frame over [backbone state, c0, c1, ...].
        self.decoder.reset_caches()
        curr_h = torch.cat([last_h.unsqueeze(1), c0_embed], dim=1)
        curr_sample = c0_sample.clone()
        curr_pos = (
            torch.arange(0, curr_h.size(1), device=curr_h.device)
            .unsqueeze(0)
            .repeat(curr_h.size(0), 1)
        )
        curr_h = curr_h.to(embeds.dtype)
        for i in range(1, self.config.audio_num_codebooks):
            curr_decoder_mask = _index_causal_mask(self.decoder_causal_mask, curr_pos)
            decoder_h = self.decoder(
                self.projection(curr_h), input_pos=curr_pos, mask=curr_decoder_mask
            )
            ci_logits = torch.mm(decoder_h[:, -1, :], self.audio_head[i - 1])
            if cfg_scale > 1.0 and b > 1 and (b % 2 == 0):
                actual_B = b // 2
                cond_ci = ci_logits[:actual_B, :]
                uncond_ci = ci_logits[actual_B:, :]
                guided_ci = uncond_ci + (cond_ci - uncond_ci) * cfg_scale

                ci_sample = sample_topk(guided_ci, topk, temperature)
                ci_sample = ci_sample.repeat(2, 1)
            else:
                ci_sample = sample_topk(ci_logits, topk, temperature)
            ci_embed = self._embed_audio(i, ci_sample)
            curr_h = ci_embed
            curr_sample = torch.cat([curr_sample, ci_sample], dim=1)
            curr_pos = curr_pos[:, -1:] + 1

        return curr_sample

    def reset_caches(self):
        """Clear the KV caches of both transformers."""
        self.backbone.reset_caches()
        self.decoder.reset_caches()

    def _embed_local_audio(self, tokens):
        """Embed [B, S, num_codebooks - 1] local audio tokens.

        Each codebook's ids are offset into its own slice of the shared
        embedding table before lookup.
        """
        audio_tokens = tokens + (
            self.config.audio_vocab_size
            * torch.arange(self.config.audio_num_codebooks - 1, device=tokens.device)
        )
        audio_embeds = self.audio_embeddings(audio_tokens.view(-1)).reshape(
            tokens.size(0), tokens.size(1), self.config.audio_num_codebooks - 1, -1
        )
        return audio_embeds

    def _embed_audio(self, codebook: int, tokens: torch.Tensor) -> torch.Tensor:
        """Embed tokens of a single codebook via its offset table slice."""
        return self.audio_embeddings(tokens + codebook * self.config.audio_vocab_size)

    def _embed_tokens(
        self, tokens: torch.Tensor, uncond_mask: torch.Tensor | None
    ) -> torch.Tensor:
        """Embed the full [B, S, num_codebooks + 1] grid (audio + text channel).

        Returns [B, S, num_codebooks + 1, D] with the text embedding last;
        rows flagged in ``uncond_mask`` have their text embedding replaced by
        the learned unconditional embedding (CFG).
        """
        B, S, _ = tokens.size()
        text_embeds = self.text_embeddings(tokens[:, :, -1])

        if uncond_mask is not None:
            uncond_text_embed = self.unconditional_text_embedding(
                torch.zeros(1, device=tokens.device, dtype=torch.long)
            )
            mask_expanded = uncond_mask.view(B, 1, 1).expand_as(text_embeds)
            text_embeds = torch.where(
                mask_expanded,
                uncond_text_embed,
                text_embeds,
            )

        text_embeds = text_embeds.unsqueeze(-2)

        audio_tokens = tokens[:, :, :-1] + (
            self.config.audio_vocab_size
            * torch.arange(self.config.audio_num_codebooks, device=tokens.device)
        )
        audio_embeds = self.audio_embeddings(audio_tokens.view(-1)).reshape(
            tokens.size(0), tokens.size(1), self.config.audio_num_codebooks, -1
        )
        return torch.cat([audio_embeds, text_embeds], dim=-2)
|
src/heartlib/pipelines/__init__.py
ADDED
|
File without changes
|
src/heartlib/pipelines/lyrics_transcription.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers.pipelines.automatic_speech_recognition import (
|
| 2 |
+
AutomaticSpeechRecognitionPipeline,
|
| 3 |
+
)
|
| 4 |
+
from transformers.models.whisper.modeling_whisper import WhisperForConditionalGeneration
|
| 5 |
+
from transformers.models.whisper.processing_whisper import WhisperProcessor
|
| 6 |
+
import torch
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class HeartTranscriptorPipeline(AutomaticSpeechRecognitionPipeline):
    """ASR pipeline over the "HeartTranscriptor" fine-tuned Whisper checkpoint,
    used for lyrics transcription."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @classmethod
    def from_pretrained(
        cls, pretrained_path: str, device: torch.device, dtype: torch.dtype
    ):
        """Load model + processor from ``<pretrained_path>/HeartTranscriptor-oss``.

        Raises:
            FileNotFoundError: if the checkpoint directory does not exist.
        """
        if os.path.exists(
            hearttranscriptor_path := os.path.join(
                pretrained_path, "HeartTranscriptor-oss"
            )
        ):
            model = WhisperForConditionalGeneration.from_pretrained(
                hearttranscriptor_path, torch_dtype=dtype, low_cpu_mem_usage=True
            )
            processor = WhisperProcessor.from_pretrained(hearttranscriptor_path)
        else:
            raise FileNotFoundError(
                f"Expected to find checkpoint for HeartTranscriptor at {hearttranscriptor_path} but not found. Please check your folder {pretrained_path}."
            )

        # 30 s chunks match Whisper's training segment length.
        return cls(
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            device=device,
            dtype=dtype,
            chunk_length_s=30,
            batch_size=16,
        )
|
src/heartlib/pipelines/music_generation.py
ADDED
|
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers.pipelines.base import Pipeline
|
| 2 |
+
from tokenizers import Tokenizer
|
| 3 |
+
from ..heartmula.modeling_heartmula import HeartMuLa
|
| 4 |
+
from ..heartcodec.modeling_heartcodec import HeartCodec
|
| 5 |
+
import torch
|
| 6 |
+
from typing import Dict, Any, Optional
|
| 7 |
+
import os
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
import torchaudio
|
| 11 |
+
import json
|
| 12 |
+
from transformers import BitsAndBytesConfig
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass
class HeartMuLaGenConfig:
    """Special-token ids used when building HeartMuLa generation prompts."""

    text_bos_id: int = 128000  # beginning-of-text marker
    text_eos_id: int = 128001  # end-of-text marker
    audio_eos_id: int = 8193  # end-of-audio marker
    empty_id: int = 0  # padding / empty slot

    @classmethod
    def from_file(cls, path: str):
        """Build a config from a JSON file whose keys match the field names."""
        with open(path, encoding="utf-8") as fp:
            return cls(**json.load(fp))
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class HeartMuLaGenPipeline(Pipeline):
|
| 30 |
+
def __init__(
    self,
    model: HeartMuLa,
    audio_codec: HeartCodec,
    muq_mulan: Optional[Any],
    text_tokenizer: Tokenizer,
    config: HeartMuLaGenConfig,
    device: torch.device,
    dtype: torch.dtype,
):
    """Assemble the music-generation pipeline.

    Args:
        model: the HeartMuLa language model producing audio tokens.
        audio_codec: HeartCodec used to decode token frames to audio.
        muq_mulan: optional MuQ/MuLan embedder for reference-audio
            conditioning (may be None).
        text_tokenizer: tokenizer for tags and lyrics.
        config: special-token ids for prompt construction.
        device / dtype: placement and precision passed to the base Pipeline.
    """
    super().__init__(model, dtype=dtype)
    self.model = model
    self.audio_codec = audio_codec
    self.muq_mulan = muq_mulan
    self.text_tokenizer = text_tokenizer
    self.config = config

    # Tokens per frame: one row per codec quantizer plus the text channel.
    self._parallel_number = audio_codec.config.num_quantizers + 1
    self._muq_dim = model.config.muq_dim
|
| 50 |
+
def _sanitize_parameters(self, **kwargs):
|
| 51 |
+
preprocess_kwargs = {"cfg_scale": kwargs.get("cfg_scale", 1.5)}
|
| 52 |
+
forward_kwargs = {
|
| 53 |
+
"max_audio_length_ms": kwargs.get("max_audio_length_ms", 120_000),
|
| 54 |
+
"temperature": kwargs.get("temperature", 1.0),
|
| 55 |
+
"topk": kwargs.get("topk", 50),
|
| 56 |
+
"cfg_scale": kwargs.get("cfg_scale", 1.5),
|
| 57 |
+
}
|
| 58 |
+
postprocess_kwargs = {
|
| 59 |
+
"save_path": kwargs.get("save_path", "output.mp3"),
|
| 60 |
+
}
|
| 61 |
+
return preprocess_kwargs, forward_kwargs, postprocess_kwargs
|
| 62 |
+
|
| 63 |
+
def preprocess(self, inputs: Dict[str, Any], cfg_scale: float):
|
| 64 |
+
|
| 65 |
+
# process tags
|
| 66 |
+
tags = inputs["tags"]
|
| 67 |
+
if os.path.isfile(tags):
|
| 68 |
+
with open(tags, encoding="utf-8") as fp:
|
| 69 |
+
tags = fp.read()
|
| 70 |
+
assert isinstance(tags, str), f"tags must be a string, but got {type(tags)}"
|
| 71 |
+
|
| 72 |
+
tags = tags.lower()
|
| 73 |
+
# encapsulate with special <tag> and </tag> tokens
|
| 74 |
+
if not tags.startswith("<tag>"):
|
| 75 |
+
tags = f"<tag>{tags}"
|
| 76 |
+
if not tags.endswith("</tag>"):
|
| 77 |
+
tags = f"{tags}</tag>"
|
| 78 |
+
|
| 79 |
+
tags_ids = self.text_tokenizer.encode(tags).ids
|
| 80 |
+
if tags_ids[0] != self.config.text_bos_id:
|
| 81 |
+
tags_ids = [self.config.text_bos_id] + tags_ids
|
| 82 |
+
if tags_ids[-1] != self.config.text_eos_id:
|
| 83 |
+
tags_ids = tags_ids + [self.config.text_eos_id]
|
| 84 |
+
|
| 85 |
+
# process reference audio
|
| 86 |
+
ref_audio = inputs.get("ref_audio", None)
|
| 87 |
+
if ref_audio is not None:
|
| 88 |
+
raise NotImplementedError("ref_audio is not supported yet.")
|
| 89 |
+
muq_embed = torch.zeros([self._muq_dim], dtype=self.dtype)
|
| 90 |
+
muq_idx = len(tags_ids)
|
| 91 |
+
|
| 92 |
+
# process lyrics
|
| 93 |
+
lyrics = inputs["lyrics"]
|
| 94 |
+
if os.path.isfile(lyrics):
|
| 95 |
+
with open(lyrics, encoding="utf-8") as fp:
|
| 96 |
+
lyrics = fp.read()
|
| 97 |
+
assert isinstance(
|
| 98 |
+
lyrics, str
|
| 99 |
+
), f"lyrics must be a string, but got {type(lyrics)}"
|
| 100 |
+
lyrics = lyrics.lower()
|
| 101 |
+
|
| 102 |
+
lyrics_ids = self.text_tokenizer.encode(lyrics).ids
|
| 103 |
+
if lyrics_ids[0] != self.config.text_bos_id:
|
| 104 |
+
lyrics_ids = [self.config.text_bos_id] + lyrics_ids
|
| 105 |
+
if lyrics_ids[-1] != self.config.text_eos_id:
|
| 106 |
+
lyrics_ids = lyrics_ids + [self.config.text_eos_id]
|
| 107 |
+
|
| 108 |
+
# cat them together. tags, ref_audio, lyrics
|
| 109 |
+
prompt_len = len(tags_ids) + 1 + len(lyrics_ids)
|
| 110 |
+
|
| 111 |
+
tokens = torch.zeros([prompt_len, self._parallel_number], dtype=torch.long)
|
| 112 |
+
tokens[: len(tags_ids), -1] = torch.tensor(tags_ids)
|
| 113 |
+
tokens[len(tags_ids) + 1 :, -1] = torch.tensor(lyrics_ids)
|
| 114 |
+
|
| 115 |
+
tokens_mask = torch.zeros_like(tokens, dtype=torch.bool)
|
| 116 |
+
tokens_mask[:, -1] = True
|
| 117 |
+
|
| 118 |
+
bs_size = 2 if cfg_scale != 1.0 else 1
|
| 119 |
+
|
| 120 |
+
def _cfg_cat(tensor: torch.Tensor, cfg_scale: float):
|
| 121 |
+
tensor = tensor.unsqueeze(0)
|
| 122 |
+
if cfg_scale != 1.0:
|
| 123 |
+
tensor = torch.cat([tensor, tensor], dim=0)
|
| 124 |
+
return tensor
|
| 125 |
+
|
| 126 |
+
return {
|
| 127 |
+
"tokens": _cfg_cat(tokens, cfg_scale),
|
| 128 |
+
"tokens_mask": _cfg_cat(tokens_mask, cfg_scale),
|
| 129 |
+
"muq_embed": _cfg_cat(muq_embed, cfg_scale),
|
| 130 |
+
"muq_idx": [muq_idx] * bs_size,
|
| 131 |
+
"pos": _cfg_cat(torch.arange(prompt_len, dtype=torch.long), cfg_scale),
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
def _forward(
|
| 135 |
+
self,
|
| 136 |
+
model_inputs: Dict[str, Any],
|
| 137 |
+
max_audio_length_ms: int,
|
| 138 |
+
temperature: float,
|
| 139 |
+
topk: int,
|
| 140 |
+
cfg_scale: float,
|
| 141 |
+
):
|
| 142 |
+
prompt_tokens = model_inputs["tokens"]
|
| 143 |
+
prompt_tokens_mask = model_inputs["tokens_mask"]
|
| 144 |
+
continuous_segment = model_inputs["muq_embed"]
|
| 145 |
+
starts = model_inputs["muq_idx"]
|
| 146 |
+
prompt_pos = model_inputs["pos"]
|
| 147 |
+
|
| 148 |
+
frames = []
|
| 149 |
+
|
| 150 |
+
bs_size = 2 if cfg_scale != 1.0 else 1
|
| 151 |
+
self.model.setup_caches(bs_size)
|
| 152 |
+
with torch.autocast(device_type=self.device.type, dtype=self.dtype):
|
| 153 |
+
curr_token = self.model.generate_frame(
|
| 154 |
+
tokens=prompt_tokens,
|
| 155 |
+
tokens_mask=prompt_tokens_mask,
|
| 156 |
+
input_pos=prompt_pos,
|
| 157 |
+
temperature=temperature,
|
| 158 |
+
topk=topk,
|
| 159 |
+
cfg_scale=cfg_scale,
|
| 160 |
+
continuous_segments=continuous_segment,
|
| 161 |
+
starts=starts,
|
| 162 |
+
)
|
| 163 |
+
frames.append(curr_token[0:1,])
|
| 164 |
+
|
| 165 |
+
def _pad_audio_token(token: torch.Tensor):
|
| 166 |
+
padded_token = (
|
| 167 |
+
torch.ones(
|
| 168 |
+
(token.shape[0], self._parallel_number),
|
| 169 |
+
device=token.device,
|
| 170 |
+
dtype=torch.long,
|
| 171 |
+
)
|
| 172 |
+
* self.config.empty_id
|
| 173 |
+
)
|
| 174 |
+
padded_token[:, :-1] = token
|
| 175 |
+
padded_token = padded_token.unsqueeze(1)
|
| 176 |
+
padded_token_mask = torch.ones_like(
|
| 177 |
+
padded_token, device=token.device, dtype=torch.bool
|
| 178 |
+
)
|
| 179 |
+
padded_token_mask[..., -1] = False
|
| 180 |
+
return padded_token, padded_token_mask
|
| 181 |
+
|
| 182 |
+
max_audio_frames = max_audio_length_ms // 80
|
| 183 |
+
|
| 184 |
+
for i in tqdm(range(max_audio_frames)):
|
| 185 |
+
curr_token, curr_token_mask = _pad_audio_token(curr_token)
|
| 186 |
+
with torch.autocast(device_type=self.device.type, dtype=self.dtype):
|
| 187 |
+
curr_token = self.model.generate_frame(
|
| 188 |
+
tokens=curr_token,
|
| 189 |
+
tokens_mask=curr_token_mask,
|
| 190 |
+
input_pos=prompt_pos[..., -1:] + i + 1,
|
| 191 |
+
temperature=temperature,
|
| 192 |
+
topk=topk,
|
| 193 |
+
cfg_scale=cfg_scale,
|
| 194 |
+
continuous_segments=None,
|
| 195 |
+
starts=None,
|
| 196 |
+
)
|
| 197 |
+
if torch.any(curr_token[0:1, :] >= self.config.audio_eos_id):
|
| 198 |
+
break
|
| 199 |
+
frames.append(curr_token[0:1,])
|
| 200 |
+
frames = torch.stack(frames).permute(1, 2, 0).squeeze(0)
|
| 201 |
+
wav = self.audio_codec.detokenize(frames)
|
| 202 |
+
return {"wav": wav}
|
| 203 |
+
|
| 204 |
+
def postprocess(self, model_outputs: Dict[str, Any], save_path: str):
|
| 205 |
+
wav = model_outputs["wav"]
|
| 206 |
+
torchaudio.save(save_path, wav, 48000)
|
| 207 |
+
|
| 208 |
+
@classmethod
|
| 209 |
+
def from_pretrained(
|
| 210 |
+
cls,
|
| 211 |
+
pretrained_path: str,
|
| 212 |
+
device: torch.device,
|
| 213 |
+
dtype: torch.dtype,
|
| 214 |
+
version: str,
|
| 215 |
+
bnb_config: Optional[BitsAndBytesConfig] = None,
|
| 216 |
+
):
|
| 217 |
+
|
| 218 |
+
if os.path.exists(
|
| 219 |
+
heartcodec_path := os.path.join(pretrained_path, "HeartCodec-oss")
|
| 220 |
+
):
|
| 221 |
+
heartcodec = HeartCodec.from_pretrained(heartcodec_path, device_map=device)
|
| 222 |
+
else:
|
| 223 |
+
raise FileNotFoundError(
|
| 224 |
+
f"Expected to find checkpoint for HeartCodec at {heartcodec_path} but not found. Please check your folder {pretrained_path}."
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
if os.path.exists(
|
| 228 |
+
heartmula_path := os.path.join(pretrained_path, f"HeartMuLa-oss-{version}")
|
| 229 |
+
):
|
| 230 |
+
heartmula = HeartMuLa.from_pretrained(
|
| 231 |
+
heartmula_path, dtype=dtype, quantization_config=bnb_config
|
| 232 |
+
)
|
| 233 |
+
else:
|
| 234 |
+
raise FileNotFoundError(
|
| 235 |
+
f"Expected to find checkpoint for HeartMuLa at {heartmula_path} but not found. Please check your folder {pretrained_path}."
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
if os.path.isfile(
|
| 239 |
+
vocab_path := os.path.join(pretrained_path, "tokenizer.json")
|
| 240 |
+
):
|
| 241 |
+
tokenizer = Tokenizer.from_file(vocab_path)
|
| 242 |
+
else:
|
| 243 |
+
raise FileNotFoundError(
|
| 244 |
+
f"Expected to find tokenizer.json for HeartMuLa at {vocab_path} but not found. Please check your folder {pretrained_path}."
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
if os.path.isfile(
|
| 248 |
+
gen_config_path := os.path.join(pretrained_path, "gen_config.json")
|
| 249 |
+
):
|
| 250 |
+
gen_config = HeartMuLaGenConfig.from_file(gen_config_path)
|
| 251 |
+
else:
|
| 252 |
+
raise FileNotFoundError(
|
| 253 |
+
f"Expected to find gen_config.json for HeartMuLa at {gen_config_path} but not found. Please check your folder {pretrained_path}."
|
| 254 |
+
)
|
| 255 |
+
|
| 256 |
+
return cls(heartmula, heartcodec, None, tokenizer, gen_config, device, dtype)
|