aiqtech commited on
Commit
8122ef6
·
verified ·
1 Parent(s): 82087af

Deploy from GitHub repository

Browse files
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/exp.png filter=lfs diff=lfs merge=lfs -text
37
+ assets/logo.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,12 +1,16 @@
1
  ---
2
  title: Heartlib
3
- emoji: 🦀
4
- colorFrom: gray
5
- colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 6.3.0
8
  app_file: app.py
9
  pinned: false
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
1
  ---
2
  title: Heartlib
3
+ emoji: 🚀
4
+ colorFrom: blue
5
+ colorTo: green
6
  sdk: gradio
7
+ sdk_version: "5.35.0"
8
  app_file: app.py
9
  pinned: false
10
  ---
11
 
12
+ # Heartlib
13
+
14
+ <p align="center">
15
+
16
+ Deployed from: https://github.com/HeartMuLa/heartlib
app.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
def main_function(input_data):
    """Echo the input back inside a success message.

    Falsy input (empty string, ``None``) yields a prompt asking for input
    instead of the echo.
    """
    if not input_data:
        return "Please provide input"
    return f"Processed successfully! Input received: {input_data}"
9
+
10
# --- Gradio UI -------------------------------------------------------------
# One-page demo: a textbox on the left feeds main_function, the result is
# shown in a read-only textbox on the right.
with gr.Blocks(title="heartlib") as demo:
    # Plain string: the original used an f-string with no placeholders
    # (useless f-prefix); the rendered markdown is byte-identical.
    gr.Markdown("""
    # Heartlib

    <p align="center">

    This space was created from: [https://github.com/HeartMuLa/heartlib](https://github.com/HeartMuLa/heartlib)
    """)

    with gr.Row():
        with gr.Column():
            input_data = gr.Textbox(
                label="Input",
                placeholder="Enter your input here...",
                lines=3,
            )
            process_btn = gr.Button("Process", variant="primary")

        with gr.Column():
            output_data = gr.Textbox(label="Output")

    # Wire the button: input textbox -> main_function -> output textbox.
    process_btn.click(
        fn=main_function,
        inputs=input_data,
        outputs=output_data,
    )

if __name__ == "__main__":
    demo.launch()
assets/badge.svg ADDED
assets/exp.png ADDED

Git LFS Details

  • SHA256: 715d2f0083971cdf62990de35b717a38eae83a25ba434487e48d0612f444a891
  • Pointer size: 131 Bytes
  • Size of remote file: 555 kB
assets/logo.png ADDED

Git LFS Details

  • SHA256: 4a70ac32f4997dc5396da8b24df054fb8453b769febf2a12ece7c24fdc0e1668
  • Pointer size: 131 Bytes
  • Size of remote file: 142 kB
assets/lyrics.txt ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [Intro]
2
+
3
+ [Verse]
4
+ The sun creeps in across the floor
5
+ I hear the traffic outside the door
6
+ The coffee pot begins to hiss
7
+ It is another morning just like this
8
+
9
+ [Prechorus]
10
+ The world keeps spinning round and round
11
+ Feet are planted on the ground
12
+ I find my rhythm in the sound
13
+
14
+ [Chorus]
15
+ Every day the light returns
16
+ Every day the fire burns
17
+ We keep on walking down this street
18
+ Moving to the same steady beat
19
+ It is the ordinary magic that we meet
20
+
21
+ [Verse]
22
+ The hours tick deeply into noon
23
+ Chasing shadows,chasing the moon
24
+ Work is done and the lights go low
25
+ Watching the city start to glow
26
+
27
+ [Bridge]
28
+ It is not always easy,not always bright
29
+ Sometimes we wrestle with the night
30
+ But we make it to the morning light
31
+
32
+ [Chorus]
33
+ Every day the light returns
34
+ Every day the fire burns
35
+ We keep on walking down this street
36
+ Moving to the same steady beat
37
+
38
+ [Outro]
39
+ Just another day
40
+ Every single day
assets/tags.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ piano,happy
examples/README.md ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🎤 Lyrics Transcription
2
+
3
+ Download the checkpoint using any of the following commands:
4
+ ```
5
+ hf download --local_dir './ckpt/HeartTranscriptor-oss' 'HeartMuLa/HeartTranscriptor-oss'
6
+ modelscope download --model 'HeartMuLa/HeartTranscriptor-oss' --local_dir './ckpt/HeartTranscriptor-oss'
7
+ ```
8
+
9
+ ```
10
+ python ./examples/run_lyrics_transcription.py --model_path=./ckpt
11
+ ```
12
+
13
+ By default this command will load the generated music file at `./assets/output.mp3` and print the transcribed lyrics. Use `--music_path` to specify the path to the music file.
14
+
15
+ Note that our HeartTranscriptor is trained on separated vocal tracks. In this example usage part, we directly demonstrate on unseparated music tracks, which is purely for simplicity of illustration. We recommend using source separation tools like demucs to separate the tracks before transcribing lyrics to achieve better results.
examples/run_lyrics_transcription.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from heartlib import HeartTranscriptorPipeline
2
+ import argparse
3
+ import torch
4
+
5
+
6
def parse_args():
    """Parse CLI options for the lyrics-transcription example.

    ``--model_path`` is mandatory; ``--music_path`` defaults to the bundled
    demo track.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--model_path", type=str, required=True)
    cli.add_argument("--music_path", type=str, default="./assets/output.mp3")
    return cli.parse_args()
12
+
13
+
14
if __name__ == "__main__":
    args = parse_args()
    # Load the transcription pipeline from a local checkpoint directory.
    # NOTE(review): device is hard-coded to CUDA with fp16 — this script will
    # fail on CPU-only machines; confirm whether a fallback is wanted.
    pipe = HeartTranscriptorPipeline.from_pretrained(
        args.model_path,
        device=torch.device("cuda"),
        dtype=torch.float16,
    )
    with torch.no_grad():
        result = pipe(
            args.music_path,
            # Decoding options forwarded to the pipeline; names match
            # Whisper-style generation knobs (beam search, temperature
            # fallback schedule, hallucination thresholds) — presumably
            # passed through to a Whisper-like ASR model; confirm in
            # HeartTranscriptorPipeline.
            **{
                "max_new_tokens": 256,
                "num_beams": 2,
                "task": "transcribe",
                "condition_on_prev_tokens": False,
                "compression_ratio_threshold": 1.8,
                "temperature": (0.0, 0.1, 0.2, 0.4),
                "logprob_threshold": -1.0,
                "no_speech_threshold": 0.4,
            },
        )
    # Print the transcribed lyrics to stdout.
    print(result)
examples/run_music_generation.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from heartlib import HeartMuLaGenPipeline
2
+ import argparse
3
+ import torch
4
+
5
+
6
def parse_args():
    """Parse CLI options for the music-generation example."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--model_path", type=str, required=True)
    # Model variant and I/O locations (all plain string options).
    for flag, default in (
        ("--version", "3B"),
        ("--lyrics", "./assets/lyrics.txt"),
        ("--tags", "./assets/tags.txt"),
        ("--save_path", "./assets/output.mp3"),
    ):
        cli.add_argument(flag, type=str, default=default)
    # Sampling hyper-parameters.
    cli.add_argument("--max_audio_length_ms", type=int, default=240_000)
    cli.add_argument("--topk", type=int, default=50)
    cli.add_argument("--temperature", type=float, default=1.0)
    cli.add_argument("--cfg_scale", type=float, default=1.5)
    return cli.parse_args()
19
+
20
+
21
if __name__ == "__main__":
    args = parse_args()
    # Load the generation pipeline from a local checkpoint directory.
    # NOTE(review): hard-coded to CUDA + bf16 — fails on CPU-only machines.
    pipe = HeartMuLaGenPipeline.from_pretrained(
        args.model_path,
        device=torch.device("cuda"),
        dtype=torch.bfloat16,
        version=args.version,
    )
    with torch.no_grad():
        pipe(
            # Paths to text inputs; presumably the pipeline reads these files
            # itself (they are file paths, not contents) — confirm in
            # HeartMuLaGenPipeline.
            {
                "lyrics": args.lyrics,
                "tags": args.tags,
            },
            max_audio_length_ms=args.max_audio_length_ms,
            save_path=args.save_path,
            topk=args.topk,
            temperature=args.temperature,
            cfg_scale=args.cfg_scale,
        )
    print(f"Generated music saved to {args.save_path}")
pyproject.toml ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["setuptools>=61", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "heartlib"
7
+ version = "0.1.0"
8
+ description = "A Python Library."
9
+ readme = "README.md"
10
+ requires-python = ">=3.9"
11
+ license = {text = "CC-BY-NC-4.0"}
12
+ authors = [
13
+ {name = "HeartMuLa Team", email = "heartmula.ai@gmail.com"}
14
+ ]
15
+ dependencies = [
16
+ "numpy==2.0.2",
17
+ "torch==2.4.1",
18
+ "torchaudio==2.4.1",
19
+ "torchtune==0.4.0",
20
+ "torchao==0.9.0",
21
+ "torchvision==0.19.1",
22
+ "tqdm==4.67.1",
23
+ "traitlets==5.7.1",
24
+ "traittypes==0.2.3",
25
+ "transformers==4.57.0",
26
+ "tokenizers==0.22.1",
27
+ "ipykernel==6.17.1",
28
+ "einops==0.8.1",
29
+ "accelerate==1.12.0",
30
+ "bitsandbytes==0.49.0",
31
+ "vector-quantize-pytorch==1.27.15",
32
+ "modelscope==1.33.0",
33
+ "soundfile"
34
+ ]
35
+ urls = { "homepage" = "https://heartmula.github.io/" }
36
+ classifiers = [
37
+ "Programming Language :: Python :: 3",
38
+ "License :: Other/Proprietary License",
39
+ "Operating System :: OS Independent"
40
+ ]
41
+
42
+ [tool.setuptools]
43
+ package-dir = {"" = "src"}
44
+
45
+ [tool.setuptools.packages.find]
46
+ where = ["src"]
47
+
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ gradio>=5.35.0
src/heartlib/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from .pipelines.music_generation import HeartMuLaGenPipeline
2
+ from .pipelines.lyrics_transcription import HeartTranscriptorPipeline
3
+
4
+ __all__ = [
5
+ "HeartMuLaGenPipeline",
6
+ "HeartTranscriptorPipeline"
7
+ ]
src/heartlib/heartcodec/configuration_heartcodec.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from typing import List, Optional

from transformers.configuration_utils import PretrainedConfig
3
+
4
+
5
class HeartCodecConfig(PretrainedConfig):
    """Configuration for HeartCodec.

    Groups three sub-configurations:
      * residual vector quantizer (``dim`` .. ``num_quantizers``)
      * diffusion transformer latent decoder (``attention_head_dim`` .. ``out_channels``)
      * scalar (SQ) waveform codec (``num_bands`` .. ``res_kernel_size``)
    """

    model_type = "heartcodec"

    def __init__(
        self,
        # config for rvq
        dim: int = 512,
        codebook_size: int = 8192,
        decay: float = 0.9,
        commitment_weight: float = 1.0,
        threshold_ema_dead_code: int = 2,
        use_cosine_sim: bool = False,
        codebook_dim: int = 32,
        num_quantizers: int = 8,
        # config for diffusion transformer
        attention_head_dim: int = 64,
        in_channels: int = 1024,
        norm_type: str = "ada_norm_single",
        num_attention_heads: int = 24,
        num_layers: int = 24,
        num_layers_2: int = 6,
        out_channels: int = 256,
        # config for sq codec
        num_bands: int = 1,
        sample_rate: int = 48000,
        causal: bool = True,
        num_samples: int = 2,
        # FIX: the original used mutable list literals as defaults, which are
        # shared across every config instance (and mutable in place). ``None``
        # sentinels preserve the documented defaults while giving each
        # instance its own fresh list.
        downsample_factors: Optional[List[int]] = None,      # default [3, 4, 4, 4, 5]
        downsample_kernel_sizes: Optional[List[int]] = None,  # default [6, 8, 8, 8, 10]
        upsample_factors: Optional[List[int]] = None,         # default [5, 4, 4, 4, 3]
        upsample_kernel_sizes: Optional[List[int]] = None,    # default [10, 8, 8, 8, 6]
        latent_hidden_dim: int = 128,
        default_kernel_size: int = 7,
        delay_kernel_size: int = 5,
        init_channel: int = 64,
        res_kernel_size: int = 7,
        **kwargs
    ):
        super().__init__(**kwargs)
        # --- residual vector quantizer ---
        self.dim = dim
        self.codebook_size = codebook_size
        self.decay = decay
        self.commitment_weight = commitment_weight
        self.threshold_ema_dead_code = threshold_ema_dead_code
        self.use_cosine_sim = use_cosine_sim
        self.codebook_dim = codebook_dim
        self.num_quantizers = num_quantizers

        # --- diffusion transformer ---
        self.attention_head_dim = attention_head_dim
        self.in_channels = in_channels
        self.norm_type = norm_type
        self.num_attention_heads = num_attention_heads
        self.num_layers = num_layers
        self.num_layers_2 = num_layers_2
        self.out_channels = out_channels

        # --- scalar (SQ) codec ---
        self.num_bands = num_bands
        self.sample_rate = sample_rate
        self.causal = causal
        self.num_samples = num_samples
        self.downsample_factors = (
            [3, 4, 4, 4, 5] if downsample_factors is None else downsample_factors
        )
        self.downsample_kernel_sizes = (
            [6, 8, 8, 8, 10]
            if downsample_kernel_sizes is None
            else downsample_kernel_sizes
        )
        self.upsample_factors = (
            [5, 4, 4, 4, 3] if upsample_factors is None else upsample_factors
        )
        self.upsample_kernel_sizes = (
            [10, 8, 8, 8, 6]
            if upsample_kernel_sizes is None
            else upsample_kernel_sizes
        )
        self.latent_hidden_dim = latent_hidden_dim
        self.default_kernel_size = default_kernel_size
        self.delay_kernel_size = delay_kernel_size
        self.init_channel = init_channel
        self.res_kernel_size = res_kernel_size
src/heartlib/heartcodec/modeling_heartcodec.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from .models.flow_matching import FlowMatching
3
+ from .models.sq_codec import ScalarModel
4
+ from .configuration_heartcodec import HeartCodecConfig
5
+ from transformers.modeling_utils import PreTrainedModel
6
+ import math
7
+ import numpy as np
8
+
9
+
10
class HeartCodec(PreTrainedModel):
    """Neural audio codec wrapper: flow-matching latent decoder + SQ waveform decoder.

    ``detokenize`` converts discrete codec codes back into a waveform in two
    stages: codes -> continuous latents (``FlowMatching``), then latents ->
    audio samples (``ScalarModel``), processing long inputs in overlapping
    windows that are cross-faded together.
    """

    config_class = HeartCodecConfig

    def __init__(
        self,
        config: HeartCodecConfig,
    ):
        super(HeartCodec, self).__init__(config)

        self.config = config

        # Latent generator conditioned on RVQ codes (diffusion-transformer backbone).
        self.flow_matching = FlowMatching(
            dim=config.dim,
            codebook_size=config.codebook_size,
            decay=config.decay,
            commitment_weight=config.commitment_weight,
            threshold_ema_dead_code=config.threshold_ema_dead_code,
            use_cosine_sim=config.use_cosine_sim,
            codebook_dim=config.codebook_dim,
            num_quantizers=config.num_quantizers,
            attention_head_dim=config.attention_head_dim,
            in_channels=config.in_channels,
            norm_type=config.norm_type,
            num_attention_heads=config.num_attention_heads,
            num_layers=config.num_layers,
            num_layers_2=config.num_layers_2,
            out_channels=config.out_channels,
        )
        # Waveform decoder that maps latents back to audio samples.
        self.scalar_model = ScalarModel(
            num_bands=config.num_bands,
            sample_rate=config.sample_rate,
            causal=config.causal,
            num_samples=config.num_samples,
            downsample_factors=config.downsample_factors,
            downsample_kernel_sizes=config.downsample_kernel_sizes,
            upsample_factors=config.upsample_factors,
            upsample_kernel_sizes=config.upsample_kernel_sizes,
            latent_hidden_dim=config.latent_hidden_dim,
            default_kernel_size=config.default_kernel_size,
            delay_kernel_size=config.delay_kernel_size,
            init_channel=config.init_channel,
            res_kernel_size=config.res_kernel_size,
        )
        self.post_init()

        self.sample_rate = config.sample_rate

    @torch.inference_mode()
    def detokenize(
        self,
        codes,            # discrete codec codes; indexed as [B, Q, T] below
        duration=29.76,   # window length in seconds processed per flow-matching call
        num_steps=10,     # Euler ODE steps per window
        disable_progress=False,
        guidance_scale=1.25,  # classifier-free guidance strength
        device="cuda",
    ):
        """Decode codec codes into a waveform tensor.

        Works window-by-window: each ``duration``-second window of codes is
        decoded to latents (seeded by the tail of the previous window for
        continuity), then all latents are decoded to audio and cross-faded at
        the overlaps. Returns a (channels, samples) tensor truncated to the
        length implied by the input code count.
        """
        codes = codes.unsqueeze(0).to(device)
        # Noise seed for the first window. NOTE(review): the trailing comment
        # says "B, T, 64" but the tensor is built with 256 latent channels;
        # the magic rates (25 latent frames/s, 12.5 code frames/s) are
        # presumably codec frame rates — confirm against training config.
        first_latent = torch.randn(codes.shape[0], int(duration * 25), 256).to(
            device
        )  # B, T, 64
        first_latent_length = 0
        first_latent_codes_length = 0
        # Window geometry in code frames: hop is ~86% of a window (93 -> 80),
        # the remainder overlaps the next window.
        min_samples = int(duration * 12.5)
        hop_samples = min_samples // 93 * 80
        ovlp_samples = min_samples - hop_samples
        # Latents run at 2x the code frame rate, so overlap doubles in frames.
        ovlp_frames = ovlp_samples * 2
        codes_len = codes.shape[-1]  #
        # Final output length in audio samples, derived from the code count.
        target_len = int(
            (codes_len - first_latent_codes_length) / 12.5 * self.sample_rate
        )

        # code repeat
        # Tile the codes until they fill at least one whole window.
        if codes_len < min_samples:
            while codes.shape[-1] < min_samples:
                codes = torch.cat([codes, codes], -1)
            codes = codes[:, :, 0:min_samples]
        codes_len = codes.shape[-1]
        # Pad (by tiling) up to a whole number of hops so the window loop
        # covers the entire sequence.
        if (codes_len - ovlp_frames) % hop_samples > 0:
            len_codes = (
                math.ceil((codes_len - ovlp_samples) / float(hop_samples)) * hop_samples
                + ovlp_samples
            )
            while codes.shape[-1] < len_codes:
                codes = torch.cat([codes, codes], -1)
            codes = codes[:, :, 0:len_codes]
        latent_length = int(duration * 25)
        latent_list = []

        # --- Stage 1: codes -> latents, one overlapping window at a time ---
        for sinx in range(0, codes.shape[-1] - hop_samples + 1, hop_samples):
            codes_input = []
            codes_input.append(codes[:, :, sinx : sinx + min_samples])
            if sinx == 0 or ovlp_frames == 0:
                # First window (or no overlap): start from pure noise.
                incontext_length = first_latent_length
                latents = self.flow_matching.inference_codes(
                    codes_input,
                    first_latent,
                    latent_length,
                    incontext_length,
                    guidance_scale=guidance_scale,
                    num_steps=num_steps,
                    disable_progress=disable_progress,
                    scenario="other_seg",
                )
                latent_list.append(latents)
            else:
                # Later windows: seed the head with the previous window's tail
                # so the decode is continuous across the boundary; the rest is
                # fresh noise.
                true_latent = latent_list[-1][:, -ovlp_frames:, :]
                len_add_to_latent = latent_length - true_latent.shape[1]  #
                incontext_length = true_latent.shape[1]
                true_latent = torch.cat(
                    [
                        true_latent,
                        torch.randn(
                            true_latent.shape[0],
                            len_add_to_latent,
                            true_latent.shape[-1],
                        ).to(device),
                    ],
                    1,
                )
                latents = self.flow_matching.inference_codes(
                    codes_input,
                    true_latent,
                    latent_length,
                    incontext_length,
                    guidance_scale=guidance_scale,
                    num_steps=num_steps,
                    disable_progress=disable_progress,
                    scenario="other_seg",
                )
                latent_list.append(latents)

        # --- Stage 2: latents -> audio, cross-fading overlaps ---
        latent_list = [l.float() for l in latent_list]
        latent_list[0] = latent_list[0][:, first_latent_length:, :]
        # Re-express window geometry in audio samples for the output stitch.
        min_samples = int(duration * self.sample_rate)
        hop_samples = min_samples // 93 * 80
        ovlp_samples = min_samples - hop_samples

        output = None
        for i in range(len(latent_list)):
            latent = latent_list[i]
            bsz, t, f = latent.shape

            # Split the 2x-rate latent into two interleaved streams:
            # (B, T, F) -> (B, 2, T, F/2) -> (B*2, T, F/2), decoded as a batch.
            latent = latent.reshape(
                latent.shape[0], latent.shape[1], 2, latent.shape[2] // 2
            ).permute(0, 2, 1, 3)
            latent = latent.reshape(
                latent.shape[0] * 2, latent.shape[2], latent.shape[3]
            )
            cur_output = (
                self.scalar_model.decode(latent.transpose(1, 2)).squeeze(0).squeeze(1)
            )  # 1 512 256

            cur_output = cur_output[:, 0:min_samples].detach().cpu()  # B, T
            if cur_output.dim() == 3:
                cur_output = cur_output[0]

            if output is None:
                output = cur_output
            else:
                if ovlp_samples == 0:
                    output = torch.cat([output, cur_output], -1)
                else:
                    # Linear cross-fade over the overlap region: fade out the
                    # previous window while fading in the current one.
                    ov_win = torch.from_numpy(np.linspace(0, 1, ovlp_samples)[None, :])
                    ov_win = torch.cat([ov_win, 1 - ov_win], -1)
                    output[:, -ovlp_samples:] = (
                        output[:, -ovlp_samples:] * ov_win[:, -ovlp_samples:]
                        + cur_output[:, 0:ovlp_samples] * ov_win[:, 0:ovlp_samples]
                    )
                    output = torch.cat([output, cur_output[:, ovlp_samples:]], -1)
        # Trim the tiled/padded tail back to the true target length.
        output = output[:, 0:target_len]
        return output
src/heartlib/heartcodec/models/flow_matching.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from tqdm import tqdm
5
+ from vector_quantize_pytorch import ResidualVQ
6
+ from .transformer import LlamaTransformer
7
+
8
+
9
class FlowMatching(nn.Module):
    """Conditional flow matching: decode RVQ codes into continuous latents.

    The residual VQ embeds incoming discrete codes into conditioning features;
    a Llama-style transformer (``estimator``) predicts the ODE velocity field,
    which is integrated with a fixed-step Euler solver, optionally with
    classifier-free guidance.
    """

    def __init__(
        self,
        # rvq stuff
        dim: int = 512,
        codebook_size: int = 8192,
        decay: float = 0.9,
        commitment_weight: float = 1.0,
        threshold_ema_dead_code: int = 2,
        use_cosine_sim: bool = False,
        codebook_dim: int = 32,
        num_quantizers: int = 8,
        # dit backbone stuff
        attention_head_dim: int = 64,
        in_channels: int = 1024,
        norm_type: str = "ada_norm_single",
        num_attention_heads: int = 24,
        num_layers: int = 24,
        num_layers_2: int = 6,
        out_channels: int = 256,
    ):
        super().__init__()

        # Used at inference only to turn code indices back into embeddings.
        self.vq_embed = ResidualVQ(
            dim=dim,
            codebook_size=codebook_size,
            decay=decay,
            commitment_weight=commitment_weight,
            threshold_ema_dead_code=threshold_ema_dead_code,
            use_cosine_sim=use_cosine_sim,
            codebook_dim=codebook_dim,
            num_quantizers=num_quantizers,
        )
        self.cond_feature_emb = nn.Linear(dim, dim)
        # Learned "null" conditioning vector substituted wherever the
        # conditioning mask is off (enables classifier-free guidance).
        self.zero_cond_embedding1 = nn.Parameter(torch.randn(dim))
        self.estimator = LlamaTransformer(
            attention_head_dim=attention_head_dim,
            in_channels=in_channels,
            norm_type=norm_type,
            num_attention_heads=num_attention_heads,
            num_layers=num_layers,
            num_layers_2=num_layers_2,
            out_channels=out_channels,
        )

        self.latent_dim = out_channels

    @torch.no_grad()
    def inference_codes(
        self,
        codes,              # list of code tensors; only codes[0] is used here
        true_latents,       # in-context latents to pin at the sequence head
        latent_length,      # number of frames to actually generate
        incontext_length,   # frames of true_latents to keep fixed
        guidance_scale=2.0,
        num_steps=20,
        disable_progress=True,
        scenario="start_seg",
    ):
        """Generate latents conditioned on codes via Euler-integrated flow matching."""
        device = true_latents.device
        dtype = true_latents.dtype
        # codes_bestrq_middle, codes_bestrq_last = codes
        codes_bestrq_emb = codes[0]

        batch_size = codes_bestrq_emb.shape[0]
        self.vq_embed.eval()
        # Code indices -> continuous embeddings via the RVQ codebooks.
        quantized_feature_emb = self.vq_embed.get_output_from_indices(
            codes_bestrq_emb.transpose(1, 2)
        )
        quantized_feature_emb = self.cond_feature_emb(quantized_feature_emb)  # b t 512
        # assert 1==2
        # Nearest-neighbour upsample 2x along time: latents run at twice the
        # code frame rate.
        quantized_feature_emb = F.interpolate(
            quantized_feature_emb.permute(0, 2, 1), scale_factor=2, mode="nearest"
        ).permute(0, 2, 1)

        num_frames = quantized_feature_emb.shape[1]  #
        # Start from Gaussian noise over the full frame span.
        latents = torch.randn(
            (batch_size, num_frames, self.latent_dim), device=device, dtype=dtype
        )
        # Mask semantics: 0 = unconditioned padding, 1 = in-context (pinned),
        # 2 = frames to generate.
        latent_masks = torch.zeros(
            latents.shape[0], latents.shape[1], dtype=torch.int64, device=latents.device
        )
        latent_masks[:, 0:latent_length] = 2
        if scenario == "other_seg":
            latent_masks[:, 0:incontext_length] = 1

        # Replace conditioning with the learned null embedding wherever masked off.
        quantized_feature_emb = (latent_masks > 0.5).unsqueeze(
            -1
        ) * quantized_feature_emb + (latent_masks < 0.5).unsqueeze(
            -1
        ) * self.zero_cond_embedding1.unsqueeze(
            0
        )

        # Keep only the in-context region of true_latents (mask value == 1).
        incontext_latents = (
            true_latents
            * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float()
        )
        # Scalar count of pinned frames (taken from the first batch element).
        incontext_length = ((latent_masks > 0.5) * (latent_masks < 1.5)).sum(-1)[0]

        additional_model_input = torch.cat([quantized_feature_emb], 1)
        temperature = 1.0
        # Integration grid t in [0, 1], num_steps intervals.
        t_span = torch.linspace(
            0, 1, num_steps + 1, device=quantized_feature_emb.device
        )
        latents = self.solve_euler(
            latents * temperature,
            incontext_latents,
            incontext_length,
            t_span,
            additional_model_input,
            guidance_scale,
        )

        # Hard-overwrite the pinned region with the exact in-context latents.
        latents[:, 0:incontext_length, :] = incontext_latents[
            :, 0:incontext_length, :
        ]  # B, T, dim
        return latents

    def solve_euler(self, x, incontext_x, incontext_length, t_span, mu, guidance_scale):
        """
        Fixed euler solver for ODEs.
        Args:
            x (torch.Tensor): random noise
            t_span (torch.Tensor): n_timesteps interpolated
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
        """
        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
        noise = x.clone()

        # I am storing this because I can later plot it by putting a debugger here and saving it to a file
        # Or in future might add like a return_all_steps flag
        sol = []
        for step in tqdm(range(1, len(t_span))):
            # Re-noise the in-context prefix to the current time t so the pinned
            # frames follow the same probability path as the generated ones.
            x[:, 0:incontext_length, :] = (1 - (1 - 1e-6) * t) * noise[
                :, 0:incontext_length, :
            ] + t * incontext_x[:, 0:incontext_length, :]
            if guidance_scale > 1.0:
                # Classifier-free guidance: run conditioned and unconditioned
                # (zeroed mu) passes in one doubled batch, then extrapolate.
                dphi_dt = self.estimator(
                    torch.cat(
                        [
                            torch.cat([x, x], 0),
                            torch.cat([incontext_x, incontext_x], 0),
                            torch.cat([torch.zeros_like(mu), mu], 0),
                        ],
                        2,
                    ),
                    timestep=t.unsqueeze(-1).repeat(2),
                )
                # NOTE: "dhpi" is a typo for "dphi" but the name is used
                # consistently, so behavior is unaffected.
                dphi_dt_uncond, dhpi_dt_cond = dphi_dt.chunk(2, 0)
                dphi_dt = dphi_dt_uncond + guidance_scale * (
                    dhpi_dt_cond - dphi_dt_uncond
                )
            else:
                dphi_dt = self.estimator(
                    torch.cat([x, incontext_x, mu], 2), timestep=t.unsqueeze(-1)
                )

            # Explicit Euler step: x_{t+dt} = x_t + dt * v(x_t, t).
            x = x + dt * dphi_dt
            t = t + dt
            sol.append(x)
            if step < len(t_span) - 1:
                dt = t_span[step + 1] - t

        result = sol[-1]

        return result
src/heartlib/heartcodec/models/sq_codec.py ADDED
@@ -0,0 +1,539 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ from torch.nn.utils.parametrizations import weight_norm
6
+ from torch.nn.utils import remove_weight_norm
7
+ from torch.autograd.function import InplaceFunction
8
+
9
+
10
def get_padding(kernel_size, dilation=1):
    """Symmetric 'same'-style padding: half the dilated kernel's span."""
    span = kernel_size * dilation - dilation
    return int(span / 2)
12
+
13
+
14
+ # Scripting this brings model speed up 1.4x
15
@torch.jit.script
def snake(x, alpha):
    # Snake activation: x + (1/alpha) * sin^2(alpha * x), applied elementwise.
    # Flatten trailing dims to (B, C, -1) so alpha (shape (1, C, 1)) broadcasts
    # per channel, then restore the original shape. The 1e-9 guards against
    # division by zero when alpha is 0.
    shape = x.shape
    x = x.reshape(shape[0], shape[1], -1)
    x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
    x = x.reshape(shape)
    return x
22
+
23
+
24
class Snake1d(nn.Module):
    """Channel-wise Snake activation with a learnable per-channel alpha."""

    def __init__(self, channels):
        super().__init__()
        # One alpha per channel, broadcast over batch and time dims.
        self.alpha = nn.Parameter(torch.ones(1, channels, 1))

    def forward(self, x):
        out = snake(x, self.alpha)
        return out
31
+
32
+
33
class Conv1d(nn.Conv1d):
    """nn.Conv1d with optional causal (left-only) padding.

    In causal mode the base class applies no padding; instead the input is
    left-padded by dilation*(kernel_size-1) in forward(), so no output sample
    depends on future inputs. In non-causal mode symmetric 'same' padding is
    computed via get_padding().
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        padding_mode: str = "zeros",
        bias: bool = True,
        padding=None,
        causal: bool = False,
        w_init_gain=None,
    ):
        self.causal = causal
        if padding is None:
            if causal:
                padding = 0
                # Full receptive-field padding, applied on the left in forward().
                self.left_padding = dilation * (kernel_size - 1)
            else:
                padding = get_padding(kernel_size, dilation)
        # NOTE(review): if an explicit `padding` is passed together with
        # causal=True, self.left_padding is never set and forward() would
        # raise AttributeError — confirm no caller does this.
        super(Conv1d, self).__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            padding_mode=padding_mode,
            bias=bias,
        )
        if w_init_gain is not None:
            # Optional Xavier init with gain chosen for the named nonlinearity.
            torch.nn.init.xavier_uniform_(
                self.weight, gain=torch.nn.init.calculate_gain(w_init_gain)
            )

    def forward(self, x):
        if self.causal:
            # Pad only on the left of the time axis; the unsqueeze/squeeze pair
            # adapts the 3D input to F.pad's 4-value 2D padding signature.
            x = F.pad(x.unsqueeze(2), (self.left_padding, 0, 0, 0)).squeeze(2)

        return super(Conv1d, self).forward(x)
76
+
77
+
78
class ConvTranspose1d(nn.ConvTranspose1d):
    """nn.ConvTranspose1d with optional causal output trimming.

    Causal mode requires kernel_size == 2*stride and zero padding; the last
    `stride` output samples (which depend on future inputs) are trimmed in
    forward(). Non-causal mode uses the standard (kernel_size - stride) // 2
    padding.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        output_padding: int = 0,
        groups: int = 1,
        bias: bool = True,
        dilation: int = 1,
        padding=None,
        padding_mode: str = "zeros",
        causal: bool = False,
    ):
        if padding is None:
            padding = 0 if causal else (kernel_size - stride) // 2
        if causal:
            assert padding == 0, "padding is not allowed in causal ConvTranspose1d."
            assert (
                kernel_size == 2 * stride
            ), "kernel_size must be equal to 2*stride is not allowed in causal ConvTranspose1d."
        super(ConvTranspose1d, self).__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            groups=groups,
            bias=bias,
            dilation=dilation,
            padding_mode=padding_mode,
        )
        self.causal = causal
        self.stride = stride

    def forward(self, x):
        x = super(ConvTranspose1d, self).forward(x)
        if self.causal:
            # Drop the trailing samples that would leak future information.
            x = x[:, :, : -self.stride]
        return x
120
+
121
+
122
class PreProcessor(nn.Module):
    """Conv + PReLU followed by average pooling over `num_samples` frames."""

    def __init__(self, n_in, n_out, num_samples, kernel_size=7, causal=False):
        super().__init__()
        self.pooling = torch.nn.AvgPool1d(kernel_size=num_samples)
        self.conv = Conv1d(n_in, n_out, kernel_size=kernel_size, causal=causal)
        self.activation = nn.PReLU()

    def forward(self, x):
        # Convolve, activate, then downsample along time.
        return self.pooling(self.activation(self.conv(x)))
133
+
134
+
135
class PostProcessor(nn.Module):
    """Nearest-neighbour upsampling by `num_samples` followed by conv + PReLU."""

    def __init__(self, n_in, n_out, num_samples, kernel_size=7, causal=False):
        super(PostProcessor, self).__init__()
        self.num_samples = num_samples
        self.conv = Conv1d(n_in, n_out, kernel_size=kernel_size, causal=causal)
        self.activation = nn.PReLU()

    def forward(self, x):
        # (B, C, T) -> (B, T, C) so time can be tiled.
        x = torch.transpose(x, 1, 2)
        B, T, C = x.size()
        # Repeat along the channel axis then reshape: each time step is
        # duplicated num_samples times -> (B, T*num_samples, C).
        x = x.repeat(1, 1, self.num_samples).view(B, -1, C)
        # Back to (B, C, T*num_samples) for the conv.
        x = torch.transpose(x, 1, 2)
        output = self.activation(self.conv(x))
        return output
149
+
150
+
151
class ResidualUnit(nn.Module):
    """Residual block: dilated conv then 1x1 conv, each PReLU'd, plus skip."""

    def __init__(self, n_in, n_out, dilation, res_kernel_size=7, causal=False):
        super().__init__()
        # Weight-normalized dilated conv followed by a pointwise conv.
        self.conv1 = weight_norm(
            Conv1d(
                n_in,
                n_out,
                kernel_size=res_kernel_size,
                dilation=dilation,
                causal=causal,
            )
        )
        self.conv2 = weight_norm(Conv1d(n_in, n_out, kernel_size=1, causal=causal))
        self.activation1 = nn.PReLU()
        self.activation2 = nn.PReLU()

    def forward(self, x):
        h = self.activation1(self.conv1(x))
        h = self.activation2(self.conv2(h))
        return h + x
171
+
172
+
173
class ResEncoderBlock(nn.Module):
    """Five dilated ResidualUnits followed by a strided downsampling conv."""

    def __init__(
        self, n_in, n_out, stride, down_kernel_size, res_kernel_size=7, causal=False
    ):
        super().__init__()
        # Dilations grow 1, 3, 5, 7, 9; the first unit maps n_in -> n_out//2,
        # the rest stay at n_out//2.
        units = []
        for i, dilation in enumerate((1, 3, 5, 7, 9)):
            units.append(
                ResidualUnit(
                    n_in if i == 0 else n_out // 2,
                    n_out // 2,
                    dilation=dilation,
                    res_kernel_size=res_kernel_size,
                    causal=causal,
                )
            )
        self.convs = nn.ModuleList(units)

        self.down_conv = DownsampleLayer(
            n_in, n_out, down_kernel_size, stride=stride, causal=causal
        )

    def forward(self, x):
        for unit in self.convs:
            x = unit(x)
        return self.down_conv(x)
227
+
228
+
229
class ResDecoderBlock(nn.Module):
    """Upsampling conv followed by five dilated ResidualUnits."""

    def __init__(
        self, n_in, n_out, stride, up_kernel_size, res_kernel_size=7, causal=False
    ):
        super().__init__()
        self.up_conv = UpsampleLayer(
            n_in,
            n_out,
            kernel_size=up_kernel_size,
            stride=stride,
            causal=causal,
            activation=None,
        )

        # Dilations grow 1, 3, 5, 7, 9; all units operate at n_out channels.
        self.convs = nn.ModuleList(
            [
                ResidualUnit(
                    n_out,
                    n_out,
                    dilation=dilation,
                    res_kernel_size=res_kernel_size,
                    causal=causal,
                )
                for dilation in (1, 3, 5, 7, 9)
            ]
        )

    def forward(self, x):
        x = self.up_conv(x)
        for unit in self.convs:
            x = unit(x)
        return x
288
+
289
+
290
class DownsampleLayer(nn.Module):
    """Downsampling block: Conv1d (optionally weight-normed) plus an
    activation; temporal reduction is done either by a strided convolution
    (default) or by a stride-1 convolution followed by average pooling
    (``pooling=True``).

    Fixes over the previous revision:
      * the ``activation`` argument was silently ignored — a fresh
        ``nn.PReLU()`` was always used; it is now honored, and passing
        ``activation=None`` disables the activation (as ``forward``
        already supported).
      * the old ``activation=nn.PReLU()`` default was evaluated once at
        class-definition time, so layers relying on the default would share
        a single PReLU parameter; the default is now created per instance.
    """

    # Sentinel so callers can still pass ``activation=None`` explicitly.
    _DEFAULT_ACTIVATION = "__default_prelu__"

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        causal: bool = False,
        activation=_DEFAULT_ACTIVATION,
        use_weight_norm: bool = True,
        pooling: bool = False,
    ):
        super(DownsampleLayer, self).__init__()
        self.pooling = pooling
        self.stride = stride
        # Honor the caller's activation; build a fresh PReLU per instance
        # when the default is requested (avoids a shared module parameter).
        if activation is DownsampleLayer._DEFAULT_ACTIVATION:
            activation = nn.PReLU()
        self.activation = activation
        self.use_weight_norm = use_weight_norm
        if pooling:
            # Stride-1 conv; temporal reduction happens in the AvgPool1d.
            self.layer = Conv1d(in_channels, out_channels, kernel_size, causal=causal)
            # NOTE: ``self.pooling`` is rebound from bool to a module here;
            # the truthiness test in ``forward`` works for both cases.
            self.pooling = nn.AvgPool1d(kernel_size=stride)
        else:
            self.layer = Conv1d(
                in_channels, out_channels, kernel_size, stride=stride, causal=causal
            )
        if use_weight_norm:
            self.layer = weight_norm(self.layer)

    def forward(self, x):
        """Conv -> optional activation -> optional average pooling."""
        x = self.layer(x)
        x = self.activation(x) if self.activation is not None else x
        if self.pooling:
            x = self.pooling(x)
        return x

    def remove_weight_norm(self):
        """Strip weight normalization from the conv (inference export)."""
        if self.use_weight_norm:
            # Resolves to the module-level ``remove_weight_norm`` function.
            remove_weight_norm(self.layer)
327
+
328
+
329
class UpsampleLayer(nn.Module):
    """Upsampling block: either a transposed convolution (default, learned
    upsample by ``stride``) or a stride-1 convolution followed by frame
    repetition along time (``repeat=True``), with an optional activation.

    Fix over the previous revision: the ``activation=nn.PReLU()`` default
    was evaluated once at class-definition time, so every UpsampleLayer
    built with the default shared a single PReLU parameter. Each instance
    now gets its own. Passing ``activation=None`` (as ResDecoderBlock does)
    still disables the activation.
    """

    # Sentinel so callers can still pass ``activation=None`` explicitly.
    _DEFAULT_ACTIVATION = "__default_prelu__"

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        causal: bool = False,
        activation=_DEFAULT_ACTIVATION,
        use_weight_norm: bool = True,
        repeat: bool = False,
    ):
        super(UpsampleLayer, self).__init__()
        self.repeat = repeat
        self.stride = stride
        if activation is UpsampleLayer._DEFAULT_ACTIVATION:
            activation = nn.PReLU()  # fresh instance, not shared
        self.activation = activation
        self.use_weight_norm = use_weight_norm
        if repeat:
            self.layer = Conv1d(in_channels, out_channels, kernel_size, causal=causal)
        else:
            self.layer = ConvTranspose1d(
                in_channels, out_channels, kernel_size, stride=stride, causal=causal
            )
        if use_weight_norm:
            self.layer = weight_norm(self.layer)

    def forward(self, x):
        """Upsample ``x`` (B, C, T) -> (B, C', T * stride)."""
        x = self.layer(x)
        x = self.activation(x) if self.activation is not None else x
        if self.repeat:
            # Repeat each time frame ``stride`` times: (B, T, C) ->
            # (B, T, C*stride) -> view as (B, T*stride, C).
            x = torch.transpose(x, 1, 2)
            B, T, C = x.size()
            x = x.repeat(1, 1, self.stride).view(B, -1, C)
            x = torch.transpose(x, 1, 2)
        return x

    def remove_weight_norm(self):
        """Strip weight normalization from the conv (inference export)."""
        if self.use_weight_norm:
            # Resolves to the module-level ``remove_weight_norm`` function.
            remove_weight_norm(self.layer)
368
+
369
+
370
class round_func9(InplaceFunction):
    """Straight-through rounding onto a uniform 1/9 grid.

    Forward snaps each element to the nearest multiple of 1/9; backward
    passes the incoming gradient through unchanged (straight-through
    estimator), since rounding has zero gradient almost everywhere.
    """

    @staticmethod
    def forward(ctx, input):
        # Keep a reference on the context to mirror the original
        # implementation (backward does not actually read it).
        ctx.input = input
        return torch.round(input * 9) / 9

    @staticmethod
    def backward(ctx, grad_output):
        # Identity gradient.
        return grad_output.clone()
380
+
381
+
382
class ScalarModel(nn.Module):
    """Scalar-quantized convolutional autoencoder (codec).

    Encoder: weight-normed Conv1d stem -> optional PreProcessor (only when
    ``num_samples > 1``) -> one ResEncoderBlock per downsample factor
    (channels double at every stage) -> weight-normed Conv1d projecting to a
    tanh-bounded latent of ``latent_hidden_dim`` channels.

    Quantization: ``round_func9`` snaps the latent to multiples of 1/9 with
    a straight-through gradient.

    Decoder mirrors the encoder: a "look ahead" Conv1d (non-causal — it is
    built without the ``causal`` flag and with ``delay_kernel_size``) ->
    one ResDecoderBlock per upsample factor -> optional PostProcessor ->
    weight-normed Conv1d back to ``num_bands`` channels.

    NOTE(review): ``sample_rate`` is accepted but not used anywhere in this
    class; presumably kept for config compatibility — confirm.
    """

    def __init__(
        self,
        num_bands,
        sample_rate,
        causal,
        num_samples,
        downsample_factors,
        downsample_kernel_sizes,
        upsample_factors,
        upsample_kernel_sizes,
        latent_hidden_dim,
        default_kernel_size,
        delay_kernel_size,
        init_channel,
        res_kernel_size,
        mode="pre_proj",
    ):
        super(ScalarModel, self).__init__()
        # Built as plain lists first, converted to nn.ModuleList at the end
        # so parameter registration and state_dict keys follow list order.
        self.encoder = []
        self.decoder = []
        # Scalar quantizer. Instantiating the Function is unusual but
        # harmless: only the classmethod ``apply`` is ever used.
        self.vq = round_func9()  # using 9
        self.mode = mode
        # Encoder parts
        self.encoder.append(
            weight_norm(
                Conv1d(
                    num_bands,
                    init_channel,
                    kernel_size=default_kernel_size,
                    causal=causal,
                )
            )
        )
        if num_samples > 1:
            # Downsampling
            self.encoder.append(
                PreProcessor(
                    init_channel,
                    init_channel,
                    num_samples,
                    kernel_size=default_kernel_size,
                    causal=causal,
                )
            )
        # Channel width doubles with every downsampling stage:
        # init_channel * 2**i -> init_channel * 2**(i+1).
        for i, down_factor in enumerate(downsample_factors):
            self.encoder.append(
                ResEncoderBlock(
                    init_channel * np.power(2, i),
                    init_channel * np.power(2, i + 1),
                    down_factor,
                    downsample_kernel_sizes[i],
                    res_kernel_size,
                    causal=causal,
                )
            )
        # Project the widest feature map down to the latent dimension.
        self.encoder.append(
            weight_norm(
                Conv1d(
                    init_channel * np.power(2, len(downsample_factors)),
                    latent_hidden_dim,
                    kernel_size=default_kernel_size,
                    causal=causal,
                )
            )
        )
        # Decoder
        # look ahead: deliberately built WITHOUT the causal flag so the
        # decoder may peek ``delay_kernel_size`` latent frames ahead.
        self.decoder.append(
            weight_norm(
                Conv1d(
                    latent_hidden_dim,
                    init_channel * np.power(2, len(upsample_factors)),
                    kernel_size=delay_kernel_size,
                )
            )
        )
        # Mirror image of the encoder: channels halve at every stage.
        for i, upsample_factor in enumerate(upsample_factors):
            self.decoder.append(
                ResDecoderBlock(
                    init_channel * np.power(2, len(upsample_factors) - i),
                    init_channel * np.power(2, len(upsample_factors) - i - 1),
                    upsample_factor,
                    upsample_kernel_sizes[i],
                    res_kernel_size,
                    causal=causal,
                )
            )
        if num_samples > 1:
            self.decoder.append(
                PostProcessor(
                    init_channel,
                    init_channel,
                    num_samples,
                    kernel_size=default_kernel_size,
                    causal=causal,
                )
            )
        self.decoder.append(
            weight_norm(
                Conv1d(
                    init_channel,
                    num_bands,
                    kernel_size=default_kernel_size,
                    causal=causal,
                )
            )
        )
        self.encoder = nn.ModuleList(self.encoder)
        self.decoder = nn.ModuleList(self.decoder)

    def forward(self, x):
        """Full autoencode: encode -> quantize -> decode.

        The last encoder layer's output is squashed with tanh so the latent
        lies in [-1, 1] before rounding. (``F.tanh`` is a deprecated alias
        of ``torch.tanh``; behavior is identical.)
        """
        for i, layer in enumerate(self.encoder):
            if i != len(self.encoder) - 1:
                x = layer(x)
            else:
                x = F.tanh(layer(x))
        x = self.vq.apply(x)  # vq
        for i, layer in enumerate(self.decoder):
            x = layer(x)
        return x

    def inference(self, x):
        """Like ``forward`` but also returns the latents.

        Returns ``(emb, emb_quant, x)``: the tanh-bounded embedding, its
        quantized version, and the reconstruction decoded from the latter.
        """
        for i, layer in enumerate(self.encoder):
            if i != len(self.encoder) - 1:
                x = layer(x)
            else:
                x = F.tanh(layer(x))  # reverse to tanh

        emb = x
        emb_quant = self.vq.apply(emb)  # vq
        x = emb_quant
        for i, layer in enumerate(self.decoder):
            x = layer(x)
        return emb, emb_quant, x

    def encode(self, x):
        """Run the encoder and return the (un-quantized) embedding.

        NOTE(review): ``emb_quant`` is computed and then discarded — the
        method returns the pre-quantization ``emb``. Confirm whether
        callers expect the quantized latent here.
        """
        for i, layer in enumerate(self.encoder):
            if i != len(self.encoder) - 1:
                x = layer(x)
            else:
                x = F.tanh(layer(x))  # reverse to tanh

        emb = x
        emb_quant = self.vq.apply(emb)  # vq
        return emb

    def decode(self, x):
        """Quantize ``x`` (so inputs match the training distribution of the
        decoder) and decode it back to band features."""
        x = self.vq.apply(
            x
        )  # make sure the prediction follow the similar disctribution
        for i, layer in enumerate(self.decoder):
            x = layer(x)
        return x
src/heartlib/heartcodec/models/transformer.py ADDED
@@ -0,0 +1,501 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Optional, Tuple
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (no mean subtraction, no bias).

    Normalizes over the last dimension by the RMS of the activations and
    applies a learned per-channel gain.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # 1 / RMS over the feature dimension, eps-stabilized.
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * (x * inv_rms)
18
+
19
+
20
class RotaryEmbedding(nn.Module):
    """Precomputes and caches sin/cos tables for rotary position embeddings.

    ``get_sin_cos`` returns tables of shape (seq_len, dim // 2); results are
    memoized per (seq_len, device, dtype).
    """

    def __init__(self, dim: int, base: int = 10000):
        super().__init__()
        self.dim = dim          # number of channels RoPE is applied to
        self.base = base        # frequency base (10000 as in the RoPE paper)
        self._cache = {}        # (seq_len, device, dtype) -> (sin, cos)

    def get_sin_cos(self, seq_len: int, device, dtype):
        """Return (sin, cos) tables of shape (seq_len, dim // 2)."""
        key = (seq_len, device, dtype)
        cached = self._cache.get(key, None)
        # Device is part of the key already; the extra check is defensive.
        if cached is not None and cached[0].device == device:
            return cached
        # inv_freq[j] = base**(-2j/dim), the standard RoPE frequency ladder.
        inv_freq = 1.0 / (
            self.base
            ** (torch.arange(0, self.dim, 2, device=device, dtype=dtype) / self.dim)
        )
        t = torch.arange(seq_len, device=device, dtype=dtype)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        sin = freqs.sin()
        cos = freqs.cos()
        self._cache[key] = (sin, cos)
        return sin, cos

    def apply_rotary(
        self, x: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor
    ) -> torch.Tensor:
        """Rotate the first ``dim`` channels of ``x`` by (sin, cos).

        NOTE(review): this method appears unused — LlamaAttention implements
        its own pairing (``apply_rope_vec``), and the unsqueeze/reshape
        pattern here does not obviously match the (seq, dim//2) shape of
        the cached tables. Verify before relying on it.
        """
        x1, x2 = x[..., : self.dim // 2], x[..., self.dim // 2 : self.dim]
        # Interleave sin/cos across pairs
        x_rot = torch.stack((-x2, x1), dim=-1).reshape_as(x[..., : self.dim])
        return (x[..., : self.dim] * cos.unsqueeze(-1)).reshape_as(
            x[..., : self.dim]
        ) + (x_rot * sin.unsqueeze(-1)).reshape_as(x[..., : self.dim])
52
+
53
+
54
class LlamaAttention(nn.Module):
    """Multi-head attention with rotary position embeddings (Llama-style).

    Self-attention by default; cross-attention when ``cross_attention_dim``
    is given (keys/values are projected from ``encoder_hidden_states``).
    Prefers PyTorch SDPA when available, with a manual matmul/softmax
    fallback.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        head_dim: int,
        bias: bool = False,
        dropout: float = 0.0,
        rope_dim: Optional[int] = None,
        cross_attention_dim: Optional[int] = None,
        use_sdpa: bool = True,
    ):
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.inner_dim = n_heads * head_dim
        self.cross_attention_dim = cross_attention_dim
        self.q_proj = nn.Linear(dim, self.inner_dim, bias=bias)
        # K/V come from the encoder stream when cross-attending.
        k_in = dim if cross_attention_dim is None else cross_attention_dim
        self.k_proj = nn.Linear(k_in, self.inner_dim, bias=bias)
        self.v_proj = nn.Linear(k_in, self.inner_dim, bias=bias)
        self.o_proj = nn.Linear(self.inner_dim, dim, bias=bias)
        self.dropout = dropout
        # RoPE is applied to the first ``rope_dim`` channels of each head.
        self.rope_dim = rope_dim if rope_dim is not None else head_dim
        self.rope = RotaryEmbedding(self.rope_dim)
        self.use_sdpa = use_sdpa
        self._has_sdpa = hasattr(F, "scaled_dot_product_attention")

    def _shape(self, x: torch.Tensor, b: int, t: int) -> torch.Tensor:
        # (b, t, inner_dim) -> (b, n_heads, t, head_dim)
        return x.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)

    def forward(
        self,
        x: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Attend over ``x`` (self) or ``encoder_hidden_states`` (cross).

        ``attention_mask`` may be [b, s], [b, 1, s], [b, t, s] or
        [b, 1, t, s]; it is broadcast to SDPA's 4-D convention. In the
        fallback path it is ADDED to the scores, so it is expected to be an
        additive (-inf/0) mask there.
        """
        b, t, c = x.shape
        q = self._shape(self.q_proj(x), b, t)
        if encoder_hidden_states is None:
            k = self._shape(self.k_proj(x), b, t)
            v = self._shape(self.v_proj(x), b, t)
        else:
            bt, tk, ck = encoder_hidden_states.shape
            k = self._shape(self.k_proj(encoder_hidden_states), b, tk)
            v = self._shape(self.v_proj(encoder_hidden_states), b, tk)

        # RoPE on first rope_dim of head_dim
        rope_dim = min(self.rope_dim, self.head_dim)
        # NOTE(review): sin/cos are sized from the KEY length and reused for
        # the queries below; for cross-attention with differing q/k lengths
        # the ``sin_.view`` inside apply_rope_vec would not match — confirm
        # this path is only used with equal lengths.
        seq_len_for_rope = k.shape[-2]
        sin, cos = self.rope.get_sin_cos(
            seq_len_for_rope, device=x.device, dtype=x.dtype
        )

        def apply_rope_vec(tensor):
            # Rotate the first ``rope_dim`` channels pairwise; pass the
            # remaining channels through untouched.
            head = tensor[..., :rope_dim]
            tail = tensor[..., rope_dim:]
            b, h, tt, _ = head.shape
            head = head.view(b, h, tt, rope_dim // 2, 2)
            sin_ = sin.view(1, 1, tt, rope_dim // 2, 1)
            cos_ = cos.view(1, 1, tt, rope_dim // 2, 1)
            x1 = head[..., 0:1]
            x2 = head[..., 1:2]
            rot = torch.cat(
                [x1 * cos_ - x2 * sin_, x1 * sin_ + x2 * cos_], dim=-1
            ).view(b, h, tt, rope_dim)
            return torch.cat([rot, tail], dim=-1)

        q = apply_rope_vec(q)
        k = apply_rope_vec(k)

        # Prefer PyTorch SDPA (can enable FlashAttention kernel on supported GPUs)
        if self.use_sdpa and self._has_sdpa:
            s = k.shape[-2]
            attn_mask_sdpa = None
            if attention_mask is not None:
                m = attention_mask

                # Normalize the mask to SDPA's broadcastable 4-D layout.
                if m.dim() == 2 and m.shape == (b, s):  # [b, s]
                    m = m[:, None, None, :]  # [b,1,1,s]
                elif m.dim() == 3 and m.shape[-2] == 1:  # [b,1,s]
                    m = m[:, None, :, :]  # [b,1,1,s]
                elif m.dim() == 3 and m.shape[-2] == t:  # [b,t,s]
                    m = m[:, None, :, :]  # [b,1,t,s]
                elif m.dim() == 4 and m.shape[1] == 1:  # [b,1,t,s] or [b,1,1,s]
                    pass
                attn_mask_sdpa = m

            out = F.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=attn_mask_sdpa,
                dropout_p=self.dropout if self.training else 0.0,
                is_causal=False,
            )
            out = out.transpose(1, 2).contiguous().view(b, t, self.inner_dim)
            return self.o_proj(out)
        else:
            # Manual fallback: scaled dot-product with additive mask.
            attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(
                self.head_dim
            )
            if attention_mask is not None:
                attn_scores = attn_scores + attention_mask
            attn = attn_scores.softmax(dim=-1)
            attn = F.dropout(attn, p=self.dropout, training=self.training)
            out = torch.matmul(attn, v)
            out = out.transpose(1, 2).contiguous().view(b, t, self.inner_dim)
            return self.o_proj(out)
164
+
165
+
166
class LlamaMLP(nn.Module):
    """SwiGLU feed-forward block (Llama-style).

    The hidden width defaults to ``4 * dim``, is scaled by 2/3 (SwiGLU's
    third projection keeps parameter count comparable to a plain 4x MLP),
    then rounded UP to the nearest multiple of ``multiple_of``.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: Optional[int] = None,
        multiple_of: int = 256,
        dropout: float = 0.0,
    ):
        super().__init__()
        if not hidden_dim:
            hidden_dim = 4 * dim
        # align to multiple_of like Llama
        hidden_dim = int(2 * hidden_dim / 3)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        self.gate = nn.Linear(dim, hidden_dim, bias=False)
        self.up = nn.Linear(dim, hidden_dim, bias=False)
        self.down = nn.Linear(hidden_dim, dim, bias=False)
        self.dropout = dropout

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # silu(gate(x)) * up(x), dropout, project back to ``dim``.
        gated = F.silu(self.gate(x)) * self.up(x)
        gated = F.dropout(gated, p=self.dropout, training=self.training)
        return self.down(gated)
188
+
189
+
190
class LlamaTransformerBlock(nn.Module):
    """Pre-norm transformer block: RMSNorm -> attention -> residual, then
    RMSNorm -> SwiGLU MLP -> residual.

    With ``use_ada_layer_norm_single=True`` (PixArt-style adaLN-single) the
    normalized activations are additionally modulated by shift/scale/gate
    vectors derived from ``timestep``.

    NOTE(review): when ``cross_attention_dim`` is given, a cross-attention
    layer is constructed but never invoked in ``forward`` — confirm whether
    that path is intentionally dormant.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        head_dim: int,
        mlp_multiple_of: int = 256,
        dropout: float = 0.0,
        attention_bias: bool = False,
        cross_attention_dim: Optional[int] = None,
        use_ada_layer_norm_single: bool = False,
    ):
        super().__init__()
        self.attn_norm = RMSNorm(dim, 1e-6)
        # Self-attention: note cross_attention_dim is fixed to None here.
        self.attn = LlamaAttention(
            dim,
            n_heads,
            head_dim,
            bias=attention_bias,
            dropout=dropout,
            rope_dim=head_dim,
            cross_attention_dim=None,
        )
        self.cross_attn = None
        if cross_attention_dim is not None:
            self.cross_attn_norm = RMSNorm(dim, 1e-6)
            self.cross_attn = LlamaAttention(
                dim,
                n_heads,
                head_dim,
                bias=attention_bias,
                dropout=dropout,
                rope_dim=head_dim,
                cross_attention_dim=cross_attention_dim,
            )
        self.mlp_norm = RMSNorm(dim, 1e-6)
        self.mlp = LlamaMLP(dim, multiple_of=mlp_multiple_of, dropout=dropout)
        self.use_ada_layer_norm_single = use_ada_layer_norm_single
        if self.use_ada_layer_norm_single:
            # Learned base table for the six modulation vectors
            # (shift/scale/gate for attention and for the MLP).
            self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)

    def forward(
        self,
        x: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        timestep: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the block.

        In adaLN-single mode ``timestep`` must be a [B, 6*dim] modulation
        tensor (as produced by AdaLayerNormSingleFlow); otherwise it is
        ignored. ``encoder_hidden_states`` is currently unused (see class
        note).
        """
        if self.use_ada_layer_norm_single:
            batch_size = x.shape[0]
            # timestep: [B, 6*D]
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
                self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
            ).chunk(6, dim=1)

            # Self-Attention with modulation and gating
            norm_hidden_states = self.attn_norm(x)
            norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
            h = self.attn(norm_hidden_states, attention_mask=attention_mask)
            h = gate_msa * h
            x = x + h

            # MLP with modulation and gating
            norm_hidden_states = self.mlp_norm(x)
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
            h = self.mlp(norm_hidden_states)
            h = gate_mlp * h
            x = x + h
            return x
        else:
            # Plain pre-norm residual block.
            h = self.attn(self.attn_norm(x), attention_mask=attention_mask)
            x = x + h
            h = self.mlp(self.mlp_norm(x))
            x = x + h
            return x
265
+
266
+
267
class ProjectLayer(nn.Module):
    """Projection over (B, T, C) features: a Conv1d along time followed by
    a Linear mixing layer.

    The conv output is rescaled by ``kernel_size ** -0.5`` before the
    linear layer. ``dropout`` is stored but not applied in ``forward``
    (matches the original implementation).
    """

    def __init__(self, hidden_size, filter_size, kernel_size=1, dropout=0.0):
        super().__init__()
        self.kernel_size = kernel_size
        self.dropout = dropout
        # 'same'-style padding for odd kernel sizes.
        self.ffn_1 = nn.Conv1d(
            hidden_size, filter_size, kernel_size, padding=kernel_size // 2
        )
        self.ffn_2 = nn.Linear(filter_size, filter_size)

    def forward(self, x):
        # (B, T, C) -> (B, C, T) for the conv, back, rescale, then mix.
        y = self.ffn_1(x.transpose(1, 2)).transpose(1, 2)
        y = y * self.kernel_size**-0.5
        return self.ffn_2(y)
282
+
283
+
284
class LlamaTransformer(nn.Module):
    """Two-stage Llama-style transformer with adaLN-style timestep
    conditioning (flow-matching head).

    Stage 1: project ``hidden_states`` into ``inner_dim``, run
    ``num_layers`` blocks, apply a modulated LayerNorm.
    Stage 2: concatenate the stage-1 output with the raw input, project to
    ``inner_dim_2`` (= 2 * inner_dim), run ``num_layers_2`` wider blocks,
    apply a second modulated LayerNorm, and project to ``out_channels``.

    NOTE(review): ``norm_type`` selects adaLN-single inside the blocks, but
    the blocks' own modulation path is only exercised when the adaln
    embedders produce a ``timestep_mod`` — both are always constructed
    below, so the ``is not None`` checks in ``forward`` are defensive.
    """

    def __init__(
        self,
        num_attention_heads: int,
        attention_head_dim: int,
        in_channels: int,
        out_channels: int,
        num_layers: int = 12,
        num_layers_2: int = 2,
        dropout: float = 0.0,
        cross_attention_dim: Optional[int] = None,
        norm_type: str = "layer_norm",
    ):
        super().__init__()
        inner_dim = num_attention_heads * attention_head_dim
        inner_dim_2 = inner_dim * 2
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.inner_dim = inner_dim
        self.inner_dim_2 = inner_dim_2
        self.dropout = dropout

        self.proj_in = ProjectLayer(in_channels, inner_dim, kernel_size=3)

        use_ada_single = norm_type == "ada_norm_single"
        self.transformer_blocks = nn.ModuleList(
            [
                LlamaTransformerBlock(
                    dim=inner_dim,
                    n_heads=num_attention_heads,
                    head_dim=attention_head_dim,
                    dropout=dropout,
                    attention_bias=False,
                    cross_attention_dim=cross_attention_dim,
                    use_ada_layer_norm_single=use_ada_single,
                )
                for _ in range(num_layers)
            ]
        )

        # Second (wider) stack: doubled width via doubled head_dim.
        self.transformer_blocks_2 = nn.ModuleList(
            [
                LlamaTransformerBlock(
                    dim=inner_dim_2,
                    n_heads=num_attention_heads,
                    head_dim=attention_head_dim * 2,
                    dropout=dropout,
                    attention_bias=False,
                    cross_attention_dim=cross_attention_dim,
                    use_ada_layer_norm_single=use_ada_single,
                )
                for _ in range(num_layers_2)
            ]
        )

        # Fuses the raw input with the stage-1 output for stage 2.
        self.connection_proj = ProjectLayer(
            in_channels + inner_dim, inner_dim_2, kernel_size=3
        )
        self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
        self.norm_out_2 = nn.LayerNorm(inner_dim_2, elementwise_affine=False, eps=1e-6)
        # Learned base shift/scale for the two output norms.
        self.scale_shift_table = nn.Parameter(
            torch.randn(2, inner_dim) / inner_dim**0.5
        )
        self.scale_shift_table_2 = nn.Parameter(
            torch.randn(2, inner_dim_2) / inner_dim_2**0.5
        )
        self.proj_out = ProjectLayer(inner_dim_2, out_channels, kernel_size=3)
        # Timestep embedders (defined later in this module; resolved at
        # runtime, which is fine for module-level classes).
        self.adaln_single = AdaLayerNormSingleFlow(inner_dim)
        self.adaln_single_2 = AdaLayerNormSingleFlow(inner_dim_2)

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: Optional[torch.LongTensor] = None,
    ):
        """Map (B, T, in_channels) -> (B, T, out_channels), optionally
        conditioned on ``timestep`` (one scalar per batch element)."""
        s = self.proj_in(hidden_states)

        embedded_timestep = None
        timestep_mod = None
        if self.adaln_single is not None and timestep is not None:
            batch_size = s.shape[0]
            timestep_mod, embedded_timestep = self.adaln_single(
                timestep, hidden_dtype=s.dtype
            )
        for blk in self.transformer_blocks:
            s = blk(s, timestep=timestep_mod)

        # Without a timestep, fall back to zero conditioning so the
        # modulated output norm reduces to its learned base table.
        if embedded_timestep is None:
            embedded_timestep = torch.zeros(
                s.size(0), s.size(-1), device=s.device, dtype=s.dtype
            )

        shift, scale = (
            self.scale_shift_table[None] + embedded_timestep[:, None]
        ).chunk(2, dim=1)
        s = self.norm_out(s)
        s = s * (1 + scale) + shift

        # Stage 2: fuse raw input with stage-1 features.
        x = torch.cat([hidden_states, s], dim=-1)
        x = self.connection_proj(x)

        embedded_timestep_2 = None
        timestep_mod_2 = None
        if self.adaln_single_2 is not None and timestep is not None:
            batch_size = x.shape[0]
            timestep_mod_2, embedded_timestep_2 = self.adaln_single_2(
                timestep, hidden_dtype=x.dtype
            )
        for blk in self.transformer_blocks_2:
            x = blk(x, timestep=timestep_mod_2)

        if embedded_timestep_2 is None:
            embedded_timestep_2 = torch.zeros(
                x.size(0), x.size(-1), device=x.device, dtype=x.dtype
            )

        shift_2, scale_2 = (
            self.scale_shift_table_2[None] + embedded_timestep_2[:, None]
        ).chunk(2, dim=1)
        x = self.norm_out_2(x)
        x = x * (1 + scale_2) + shift_2

        out = self.proj_out(x)

        return out
409
+
410
+
411
class PixArtAlphaCombinedFlowEmbeddings(nn.Module):
    """Sinusoidal flow-timestep embedding followed by a learned MLP.

    Produces a 512-dim sinusoidal embedding of the (continuous) flow
    timestep, scaled by 1000, and projects it to ``embedding_dim`` with a
    TimestepEmbedding MLP.

    NOTE(review): ``size_emb_dim`` is stored as ``outdim`` but not used
    anywhere in this class — presumably kept for interface parity with the
    PixArt-alpha embedder it is modeled on.
    """

    def __init__(self, embedding_dim: int, size_emb_dim: int):
        super().__init__()
        self.flow_t_size = 512  # width of the raw sinusoidal embedding
        self.outdim = size_emb_dim
        self.timestep_embedder = TimestepEmbedding(
            in_channels=self.flow_t_size, time_embed_dim=embedding_dim
        )

    def timestep_embedding(self, timesteps, max_period=10000, scale=1000):
        """Return the [B, flow_t_size] sinusoidal embedding of ``timesteps``.

        ``scale`` multiplies the timestep before the sinusoids (flow
        timesteps are typically in [0, 1], so this maps them to the
        classic 0..1000 diffusion range).
        """
        half = self.flow_t_size // 2
        # freqs[k] = exp(-log(max_period) * k / half), cast to the
        # timesteps' dtype.
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(start=0, end=half, device=timesteps.device)
            / half
        ).type(timesteps.type())
        args = timesteps[:, None] * freqs[None] * scale
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        # Dormant branch: flow_t_size is fixed at 512 (even), so no padding
        # is ever added here.
        if self.flow_t_size % 2:
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
        return embedding

    def forward(self, timestep, hidden_dtype):
        """Embed ``timestep`` and project to the model width in
        ``hidden_dtype``."""
        timesteps_proj = self.timestep_embedding(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))
        conditioning = timesteps_emb
        return conditioning
440
+
441
+
442
class AdaLayerNormSingleFlow(nn.Module):
    """adaLN-single conditioning head (PixArt-alpha style) for flow
    timesteps.

    Embeds the timestep, then maps it through SiLU + Linear to a
    [B, 6 * embedding_dim] modulation vector, which the transformer blocks
    reshape into six shift/scale/gate vectors.
    """

    def __init__(self, embedding_dim: int):
        super().__init__()
        # size_emb_dim is unused downstream; see the embedder's note.
        self.emb = PixArtAlphaCombinedFlowEmbeddings(
            embedding_dim, size_emb_dim=embedding_dim // 3
        )
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)

    def forward(
        self,
        timestep: torch.Tensor,
        hidden_dtype: Optional[torch.dtype] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (modulation [B, 6*D], embedded_timestep [B, D])."""

        embedded_timestep = self.emb(timestep, hidden_dtype=hidden_dtype)
        return self.linear(self.silu(embedded_timestep)), embedded_timestep
459
+
460
+
461
class TimestepEmbedding(nn.Module):
    """Two-layer MLP (Linear -> SiLU -> Linear) for timestep features,
    mapping ``in_channels`` to ``time_embed_dim``."""

    def __init__(self, in_channels: int, time_embed_dim: int):
        super().__init__()
        self.linear_1 = nn.Linear(in_channels, time_embed_dim)
        self.act = nn.SiLU()
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear_2(self.act(self.linear_1(x)))
473
+
474
+
475
class Timesteps(nn.Module):
    """Non-learned sinusoidal timestep embedding.

    Produces ``num_channels`` features per timestep using log-spaced
    frequencies; ``flip_sin_to_cos`` chooses [cos | sin] ordering instead
    of [sin | cos]. Odd channel counts are zero-padded by one.
    """

    def __init__(
        self,
        num_channels: int,
        flip_sin_to_cos: bool = True,
        downscale_freq_shift: float = 0,
    ):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift

    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
        half_dim = self.num_channels // 2
        # exponent[k] = -log(10000) * k / (half_dim - shift)
        exponent = (
            -math.log(10000)
            * torch.arange(0, half_dim, device=timesteps.device)
            / (half_dim - self.downscale_freq_shift)
        )
        emb = torch.exp(exponent)[None, :] * timesteps[:, None]
        if self.flip_sin_to_cos:
            parts = (torch.cos(emb), torch.sin(emb))
        else:
            parts = (torch.sin(emb), torch.cos(emb))
        emb = torch.cat(parts, dim=-1)
        if self.num_channels % 2 == 1:
            emb = torch.nn.functional.pad(emb, (0, 1))
        return emb
src/heartlib/heartmula/configuration_heartmula.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers.configuration_utils import PretrainedConfig
2
+
3
+
4
class HeartMuLaConfig(PretrainedConfig):
    """HuggingFace configuration for the HeartMuLa model.

    Attributes:
        backbone_flavor: key into the FLAVORS registry for the main
            torchtune decoder (default "llama-3B").
        decoder_flavor: key into FLAVORS for the small per-frame audio
            decoder (default "llama-300M").
        text_vocab_size: size of the text token vocabulary.
        audio_vocab_size: size of EACH audio codebook's vocabulary.
        audio_num_codebooks: number of residual audio codebooks per frame.
        muq_dim: width of the MuQ (music-understanding) conditioning
            embedding fed through ``muq_linear``.
    """

    model_type = "heartmula"

    def __init__(
        self,
        backbone_flavor: str = "llama-3B",
        decoder_flavor: str = "llama-300M",
        text_vocab_size: int = 128256,
        audio_vocab_size: int = 8197,
        audio_num_codebooks: int = 8,
        muq_dim: int = 512,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.backbone_flavor = backbone_flavor
        self.decoder_flavor = decoder_flavor
        self.text_vocab_size = text_vocab_size
        self.audio_vocab_size = audio_vocab_size
        self.audio_num_codebooks = audio_num_codebooks
        self.muq_dim = muq_dim
src/heartlib/heartmula/modeling_heartmula.py ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from .configuration_heartmula import HeartMuLaConfig
4
+ from transformers.modeling_utils import PreTrainedModel
5
+ import torch
6
+ import torch.nn as nn
7
+ import torchtune
8
+ from torchtune.models import llama3_2
9
+
10
+
11
+ def llama3_2_3B() -> torchtune.modules.transformer.TransformerDecoder:
12
+ return llama3_2.llama3_2(
13
+ vocab_size=128_256,
14
+ num_layers=28,
15
+ num_heads=24,
16
+ num_kv_heads=8,
17
+ embed_dim=3072,
18
+ max_seq_len=8192,
19
+ intermediate_dim=8192,
20
+ attn_dropout=0.0,
21
+ norm_eps=1e-5,
22
+ rope_base=500_000,
23
+ scale_factor=32,
24
+ )
25
+
26
+
27
+ def llama3_2_300M() -> torchtune.modules.transformer.TransformerDecoder:
28
+ return llama3_2.llama3_2(
29
+ vocab_size=128_256,
30
+ num_layers=3,
31
+ num_heads=8,
32
+ num_kv_heads=4,
33
+ embed_dim=3072,
34
+ max_seq_len=2048,
35
+ intermediate_dim=8192,
36
+ attn_dropout=0.0,
37
+ norm_eps=1e-5,
38
+ rope_base=500_000,
39
+ scale_factor=32,
40
+ )
41
+
42
+
43
+ def llama3_2_7B() -> torchtune.modules.transformer.TransformerDecoder:
44
+ return llama3_2.llama3_2(
45
+ vocab_size=128_256,
46
+ num_layers=32,
47
+ num_heads=32,
48
+ num_kv_heads=8,
49
+ embed_dim=4096,
50
+ max_seq_len=8192,
51
+ intermediate_dim=14336,
52
+ attn_dropout=0.0,
53
+ norm_eps=1e-5,
54
+ rope_base=500_000,
55
+ scale_factor=32,
56
+ )
57
+
58
+
59
+ def llama3_2_400M() -> torchtune.modules.transformer.TransformerDecoder:
60
+ return llama3_2.llama3_2(
61
+ vocab_size=128_256,
62
+ num_layers=4,
63
+ num_heads=8,
64
+ num_kv_heads=4,
65
+ embed_dim=3072,
66
+ max_seq_len=2048,
67
+ intermediate_dim=8192,
68
+ attn_dropout=0.0,
69
+ norm_eps=1e-5,
70
+ rope_base=500_000,
71
+ scale_factor=32,
72
+ ) # 减少了num_heads和num_kv_heads之间的倍速,提升了精确度,但降低了效率
73
+
74
+
75
+ FLAVORS = {
76
+ "llama-3B": llama3_2_3B,
77
+ "llama-300M": llama3_2_300M,
78
+ "llama-7B": llama3_2_7B,
79
+ "llama-400M": llama3_2_400M,
80
+ }
81
+
82
+
83
+ def _prepare_transformer(model):
84
+ embed_dim = model.tok_embeddings.embedding_dim
85
+ model.tok_embeddings = nn.Identity()
86
+ model.output = nn.Identity()
87
+ return model, embed_dim
88
+
89
+
90
+ def _create_causal_mask(seq_len: int, device: torch.device):
91
+ return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))
92
+
93
+
94
+ def _index_causal_mask(mask: torch.Tensor, input_pos: torch.Tensor):
95
+ r = mask[input_pos, :]
96
+ return r
97
+
98
+
99
+ def _multinomial_sample_one_no_sync(
100
+ probs,
101
+ ): # Does multinomial sampling without a cuda synchronization
102
+ q = torch.empty_like(probs).exponential_(1)
103
+ return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int)
104
+
105
+
106
+ def sample_topk(logits: torch.Tensor, topk: int, temperature: float):
107
+ logits = logits / temperature
108
+
109
+ filter_value: float = -float("Inf")
110
+ indices_to_remove = logits < torch.topk(logits, topk)[0][..., -1, None]
111
+ scores_processed = logits.masked_fill(indices_to_remove, filter_value)
112
+ scores_processed = torch.nn.functional.log_softmax(scores_processed, dim=-1)
113
+ probs = torch.nn.functional.softmax(scores_processed, dim=-1)
114
+
115
+ sample_token = _multinomial_sample_one_no_sync(probs)
116
+ return sample_token
117
+
118
+
119
+ class HeartMuLa(PreTrainedModel):
120
+ config_class = HeartMuLaConfig
121
+
122
+ def __init__(
123
+ self,
124
+ config: HeartMuLaConfig,
125
+ ):
126
+ super(HeartMuLa, self).__init__(config)
127
+
128
+ self.config = config
129
+
130
+ self.backbone, backbone_dim = _prepare_transformer(
131
+ FLAVORS[config.backbone_flavor]()
132
+ )
133
+ self.decoder, decoder_dim = _prepare_transformer(
134
+ FLAVORS[config.decoder_flavor]()
135
+ )
136
+
137
+ self.text_embeddings = nn.Embedding(config.text_vocab_size, backbone_dim)
138
+ self.audio_embeddings = nn.Embedding(
139
+ config.audio_vocab_size * config.audio_num_codebooks, backbone_dim
140
+ )
141
+ self.unconditional_text_embedding = nn.Embedding(1, backbone_dim)
142
+
143
+ self.projection = nn.Linear(backbone_dim, decoder_dim, bias=False)
144
+ self.codebook0_head = nn.Linear(
145
+ backbone_dim, config.audio_vocab_size, bias=False
146
+ )
147
+ self.audio_head = nn.Parameter(
148
+ torch.empty(
149
+ config.audio_num_codebooks - 1, decoder_dim, config.audio_vocab_size
150
+ )
151
+ )
152
+ self.muq_linear = nn.Linear(config.muq_dim, backbone_dim)
153
+ self.post_init()
154
+
155
+ def setup_caches(self, max_batch_size: int):
156
+ dtype = next(self.parameters()).dtype
157
+ device = next(self.parameters()).device
158
+
159
+ try:
160
+ self.reset_caches()
161
+ except RuntimeError:
162
+ pass
163
+
164
+ with device:
165
+ self.backbone.setup_caches(max_batch_size, dtype)
166
+ self.decoder.setup_caches(
167
+ max_batch_size,
168
+ dtype,
169
+ decoder_max_seq_len=self.config.audio_num_codebooks,
170
+ )
171
+
172
+ self.register_buffer(
173
+ "backbone_causal_mask",
174
+ _create_causal_mask(self.backbone.max_seq_len, device),
175
+ )
176
+ self.register_buffer(
177
+ "decoder_causal_mask",
178
+ _create_causal_mask(self.config.audio_num_codebooks, device),
179
+ )
180
+
181
    def generate_frame(
        self,
        tokens: torch.Tensor,
        tokens_mask: torch.Tensor,
        input_pos: torch.Tensor,
        temperature: float,
        topk: int,
        cfg_scale: float,
        continuous_segments: torch.Tensor = None,
        starts=None,
    ) -> torch.Tensor:
        """Autoregressively sample one audio frame (all codebooks) for the batch.

        The backbone predicts codebook 0 from the merged token embeddings; the
        small decoder then predicts codebooks 1..N-1 one at a time, conditioned
        on the backbone's last hidden state and the previously sampled codebook.

        Args:
            tokens: (B, S, C) token ids; per ``_embed_tokens`` the last channel
                is the text token and the preceding channels are audio codebooks.
            tokens_mask: (B, S, C) bool mask selecting which channels contribute.
            input_pos: (B, S) absolute positions into the backbone KV cache.
            temperature: Sampling temperature for ``sample_topk``.
            topk: Top-k cutoff for ``sample_topk``.
            cfg_scale: Classifier-free-guidance scale; > 1.0 with an even batch
                means rows [0, B/2) are conditional and [B/2, B) unconditional.
            continuous_segments: Optional MuQ embedding injected at position
                ``starts`` of each sequence (prompt step only).
            starts: Per-batch index at which ``continuous_segments`` is written.

        Returns:
            (B, audio_num_codebooks) sampled token ids for this frame.
        """
        b, s, _ = tokens.size()

        assert self.backbone.caches_are_enabled(), "backbone caches are not enabled"
        curr_backbone_mask = _index_causal_mask(self.backbone_causal_mask, input_pos)

        # With CFG the batch is [conditional | unconditional] halves; mark the
        # unconditional rows so their text/MuQ inputs get replaced below.
        uncond_mask = None
        if cfg_scale > 1.0 and b > 1:
            actual_B = b // 2
            uncond_mask = torch.cat(
                [
                    torch.zeros(actual_B, dtype=torch.bool, device=tokens.device),
                    torch.ones(actual_B, dtype=torch.bool, device=tokens.device),
                ]
            )

        embeds = self._embed_tokens(tokens, uncond_mask=uncond_mask)
        masked_embeds = embeds * tokens_mask.unsqueeze(-1)
        h = masked_embeds.sum(dim=2, dtype=embeds.dtype)  # merge
        if continuous_segments is not None:
            # Project the MuQ conditioning into backbone space and splice it in.
            continuous_segments = self.muq_linear(continuous_segments)
            if uncond_mask is not None:
                # Unconditional rows get the learned "no text" embedding instead.
                uncond_embed = self.unconditional_text_embedding(
                    torch.zeros(1, device=tokens.device, dtype=torch.long)
                )
                mask_expanded = uncond_mask.view(b, 1).expand_as(continuous_segments)
                continuous_segments = torch.where(
                    mask_expanded, uncond_embed, continuous_segments
                )
            batch_indices = torch.arange(h.shape[0], device=h.device)
            h[batch_indices, starts] = continuous_segments
        h = self.backbone(h, input_pos=input_pos, mask=curr_backbone_mask)
        last_h = h[:, -1, :]  # the last frame
        c0_logits = self.codebook0_head(last_h)  # only predict the audio part

        if cfg_scale > 1.0 and b > 1 and (b % 2 == 0):
            actual_B = b // 2
            cond_logits = c0_logits[:actual_B, :]
            uncond_logits = c0_logits[actual_B:, :]
            # Standard CFG combination: uncond + scale * (cond - uncond).
            guided_logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
            c0_sample = sample_topk(guided_logits, topk, temperature)
            c0_sample = c0_sample.repeat(
                2, 1
            )  # repeat to both branches to keep alignment
        else:
            c0_sample = sample_topk(c0_logits, topk, temperature)

        c0_embed = self._embed_audio(0, c0_sample)

        # The decoder cache is per-frame: clear it before the codebook loop.
        self.decoder.reset_caches()
        curr_h = torch.cat([last_h.unsqueeze(1), c0_embed], dim=1)
        curr_sample = c0_sample.clone()
        curr_pos = (
            torch.arange(0, curr_h.size(1), device=curr_h.device)
            .unsqueeze(0)
            .repeat(curr_h.size(0), 1)
        )
        curr_h = curr_h.to(embeds.dtype)
        for i in range(1, self.config.audio_num_codebooks):
            curr_decoder_mask = _index_causal_mask(self.decoder_causal_mask, curr_pos)
            decoder_h = self.decoder(
                self.projection(curr_h), input_pos=curr_pos, mask=curr_decoder_mask
            )
            # Each non-zero codebook has its own output projection (audio_head).
            ci_logits = torch.mm(decoder_h[:, -1, :], self.audio_head[i - 1])
            if cfg_scale > 1.0 and b > 1 and (b % 2 == 0):
                actual_B = b // 2
                cond_ci = ci_logits[:actual_B, :]
                uncond_ci = ci_logits[actual_B:, :]
                guided_ci = uncond_ci + (cond_ci - uncond_ci) * cfg_scale

                ci_sample = sample_topk(guided_ci, topk, temperature)
                ci_sample = ci_sample.repeat(2, 1)
            else:
                ci_sample = sample_topk(ci_logits, topk, temperature)
            ci_embed = self._embed_audio(i, ci_sample)
            curr_h = ci_embed
            curr_sample = torch.cat([curr_sample, ci_sample], dim=1)
            curr_pos = curr_pos[:, -1:] + 1

        return curr_sample
271
+
272
    def reset_caches(self):
        """Clear the KV caches of both the backbone and the decoder."""
        self.backbone.reset_caches()
        self.decoder.reset_caches()
275
+
276
+ def _embed_local_audio(self, tokens):
277
+ """the token from 0-30"""
278
+ audio_tokens = tokens + (
279
+ self.config.audio_vocab_size
280
+ * torch.arange(self.config.audio_num_codebooks - 1, device=tokens.device)
281
+ )
282
+ audio_embeds = self.audio_embeddings(audio_tokens.view(-1)).reshape(
283
+ tokens.size(0), tokens.size(1), self.config.audio_num_codebooks - 1, -1
284
+ )
285
+ return audio_embeds
286
+
287
    def _embed_audio(self, codebook: int, tokens: torch.Tensor) -> torch.Tensor:
        """Embed the tokens of one codebook.

        The embedding table is shared across codebooks; each codebook owns a
        contiguous slice of size ``audio_vocab_size``, so ids are offset by
        ``codebook * audio_vocab_size`` before lookup.
        """
        return self.audio_embeddings(tokens + codebook * self.config.audio_vocab_size)
289
+
290
    def _embed_tokens(
        self, tokens: torch.Tensor, uncond_mask: torch.Tensor | None
    ) -> torch.Tensor:
        """Embed the interleaved audio+text token tensor.

        Args:
            tokens: (B, S, audio_num_codebooks + 1) ids; the last channel is the
                text token, the preceding channels are the audio codebooks.
            uncond_mask: Optional (B,) bool mask; True rows have their text
                embedding replaced by the learned unconditional embedding
                (classifier-free guidance branch).

        Returns:
            (B, S, audio_num_codebooks + 1, D) embeddings with the text channel
            appended after the audio channels.
        """
        B, S, _ = tokens.size()
        text_embeds = self.text_embeddings(tokens[:, :, -1])

        if uncond_mask is not None:
            uncond_text_embed = self.unconditional_text_embedding(
                torch.zeros(1, device=tokens.device, dtype=torch.long)
            )
            # Broadcast the per-row mask over (S, D) and swap in the
            # unconditional embedding for masked rows.
            mask_expanded = uncond_mask.view(B, 1, 1).expand_as(text_embeds)
            text_embeds = torch.where(
                mask_expanded,
                uncond_text_embed,
                text_embeds,
            )

        text_embeds = text_embeds.unsqueeze(-2)

        # Shift each codebook's ids into its slice of the shared table.
        audio_tokens = tokens[:, :, :-1] + (
            self.config.audio_vocab_size
            * torch.arange(self.config.audio_num_codebooks, device=tokens.device)
        )
        audio_embeds = self.audio_embeddings(audio_tokens.view(-1)).reshape(
            tokens.size(0), tokens.size(1), self.config.audio_num_codebooks, -1
        )
        return torch.cat([audio_embeds, text_embeds], dim=-2)
src/heartlib/pipelines/__init__.py ADDED
File without changes
src/heartlib/pipelines/lyrics_transcription.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers.pipelines.automatic_speech_recognition import (
2
+ AutomaticSpeechRecognitionPipeline,
3
+ )
4
+ from transformers.models.whisper.modeling_whisper import WhisperForConditionalGeneration
5
+ from transformers.models.whisper.processing_whisper import WhisperProcessor
6
+ import torch
7
+ import os
8
+
9
+
10
class HeartTranscriptorPipeline(AutomaticSpeechRecognitionPipeline):
    """Whisper-based ASR pipeline specialised for lyrics transcription.

    Thin wrapper around :class:`AutomaticSpeechRecognitionPipeline` that knows
    how to locate the ``HeartTranscriptor-oss`` checkpoint inside a pretrained
    folder. The redundant ``__init__`` that only forwarded to ``super()`` has
    been removed; the inherited constructor is used directly.
    """

    @classmethod
    def from_pretrained(
        cls, pretrained_path: str, device: torch.device, dtype: torch.dtype
    ):
        """Build the pipeline from a local checkpoint folder.

        Args:
            pretrained_path: Folder expected to contain ``HeartTranscriptor-oss``.
            device: Device the pipeline should run on.
            dtype: Torch dtype used to load the model weights.

        Returns:
            A configured :class:`HeartTranscriptorPipeline`.

        Raises:
            FileNotFoundError: If the checkpoint folder does not exist.
        """
        hearttranscriptor_path = os.path.join(pretrained_path, "HeartTranscriptor-oss")
        # Guard clause: fail fast with a clear message before touching weights.
        if not os.path.exists(hearttranscriptor_path):
            raise FileNotFoundError(
                f"Expected to find checkpoint for HeartTranscriptor at {hearttranscriptor_path} but not found. Please check your folder {pretrained_path}."
            )

        model = WhisperForConditionalGeneration.from_pretrained(
            hearttranscriptor_path, torch_dtype=dtype, low_cpu_mem_usage=True
        )
        processor = WhisperProcessor.from_pretrained(hearttranscriptor_path)

        return cls(
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            device=device,
            dtype=dtype,
            # 30 s chunking matches Whisper's native receptive window.
            chunk_length_s=30,
            batch_size=16,
        )
+ )
src/heartlib/pipelines/music_generation.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers.pipelines.base import Pipeline
2
+ from tokenizers import Tokenizer
3
+ from ..heartmula.modeling_heartmula import HeartMuLa
4
+ from ..heartcodec.modeling_heartcodec import HeartCodec
5
+ import torch
6
+ from typing import Dict, Any, Optional
7
+ import os
8
+ from dataclasses import dataclass
9
+ from tqdm import tqdm
10
+ import torchaudio
11
+ import json
12
+ from transformers import BitsAndBytesConfig
13
+
14
+
15
@dataclass
class HeartMuLaGenConfig:
    """Special-token ids used during HeartMuLa music generation."""

    text_bos_id: int = 128000  # beginning-of-sequence id for text streams
    text_eos_id: int = 128001  # end-of-sequence id for text streams
    audio_eos_id: int = 8193  # first id treated as end-of-audio
    empty_id: int = 0  # filler id for unused codebook slots

    @classmethod
    def from_file(cls, path: str):
        """Build a config from a JSON file whose keys match the field names."""
        with open(path, encoding="utf-8") as fp:
            return cls(**json.load(fp))
27
+
28
+
29
class HeartMuLaGenPipeline(Pipeline):
    """End-to-end lyrics+tags -> music generation pipeline.

    Wires together the HeartMuLa language model (token generation), the
    HeartCodec audio codec (token -> waveform), and a text tokenizer, following
    the transformers ``Pipeline`` contract (``preprocess`` / ``_forward`` /
    ``postprocess``).
    """

    def __init__(
        self,
        model: HeartMuLa,
        audio_codec: HeartCodec,
        muq_mulan: Optional[Any],
        text_tokenizer: Tokenizer,
        config: HeartMuLaGenConfig,
        device: torch.device,
        dtype: torch.dtype,
    ):
        """Store components and derive per-frame layout constants.

        Args:
            model: HeartMuLa token-generation model.
            audio_codec: HeartCodec used to decode tokens into audio.
            muq_mulan: Optional MuQ-MuLan model (currently unused; ref_audio
                conditioning is not implemented).
            text_tokenizer: Tokenizer for tags and lyrics.
            config: Special-token id configuration.
            device: Target device.
            dtype: Autocast/model dtype.
        """
        super().__init__(model, dtype=dtype)
        self.model = model
        self.audio_codec = audio_codec
        self.muq_mulan = muq_mulan
        self.text_tokenizer = text_tokenizer
        self.config = config

        # One column per codec quantizer plus one text column.
        self._parallel_number = audio_codec.config.num_quantizers + 1
        self._muq_dim = model.config.muq_dim

    def _sanitize_parameters(self, **kwargs):
        """Split user kwargs into preprocess/forward/postprocess kwargs.

        Note: ``cfg_scale`` goes to both preprocess (to decide whether to
        duplicate the batch for classifier-free guidance) and forward (to apply
        the guidance when sampling).
        """
        preprocess_kwargs = {"cfg_scale": kwargs.get("cfg_scale", 1.5)}
        forward_kwargs = {
            "max_audio_length_ms": kwargs.get("max_audio_length_ms", 120_000),
            "temperature": kwargs.get("temperature", 1.0),
            "topk": kwargs.get("topk", 50),
            "cfg_scale": kwargs.get("cfg_scale", 1.5),
        }
        postprocess_kwargs = {
            "save_path": kwargs.get("save_path", "output.mp3"),
        }
        return preprocess_kwargs, forward_kwargs, postprocess_kwargs

    def preprocess(self, inputs: Dict[str, Any], cfg_scale: float):
        """Turn ``{"tags": ..., "lyrics": ..., "ref_audio": ...}`` into model tensors.

        ``tags`` and ``lyrics`` may each be a literal string or a path to a
        text file. Both are lowercased and tokenized with BOS/EOS added; tags
        are additionally wrapped in ``<tag>``/``</tag>``. When ``cfg_scale``
        differs from 1.0 every tensor is duplicated along the batch dim so the
        model can run conditional and unconditional branches in one pass.
        """

        # process tags
        tags = inputs["tags"]
        if os.path.isfile(tags):
            with open(tags, encoding="utf-8") as fp:
                tags = fp.read()
        assert isinstance(tags, str), f"tags must be a string, but got {type(tags)}"

        tags = tags.lower()
        # encapsulate with special <tag> and </tag> tokens
        if not tags.startswith("<tag>"):
            tags = f"<tag>{tags}"
        if not tags.endswith("</tag>"):
            tags = f"{tags}</tag>"

        tags_ids = self.text_tokenizer.encode(tags).ids
        if tags_ids[0] != self.config.text_bos_id:
            tags_ids = [self.config.text_bos_id] + tags_ids
        if tags_ids[-1] != self.config.text_eos_id:
            tags_ids = tags_ids + [self.config.text_eos_id]

        # process reference audio
        ref_audio = inputs.get("ref_audio", None)
        if ref_audio is not None:
            raise NotImplementedError("ref_audio is not supported yet.")
        # Placeholder MuQ embedding (zeros) inserted right after the tags.
        muq_embed = torch.zeros([self._muq_dim], dtype=self.dtype)
        muq_idx = len(tags_ids)

        # process lyrics
        lyrics = inputs["lyrics"]
        if os.path.isfile(lyrics):
            with open(lyrics, encoding="utf-8") as fp:
                lyrics = fp.read()
        assert isinstance(
            lyrics, str
        ), f"lyrics must be a string, but got {type(lyrics)}"
        lyrics = lyrics.lower()

        lyrics_ids = self.text_tokenizer.encode(lyrics).ids
        if lyrics_ids[0] != self.config.text_bos_id:
            lyrics_ids = [self.config.text_bos_id] + lyrics_ids
        if lyrics_ids[-1] != self.config.text_eos_id:
            lyrics_ids = lyrics_ids + [self.config.text_eos_id]

        # cat them together. tags, ref_audio, lyrics
        # The "+ 1" reserves one position for the MuQ embedding slot.
        prompt_len = len(tags_ids) + 1 + len(lyrics_ids)

        # Text ids live in the last column; audio columns stay zero in the prompt.
        tokens = torch.zeros([prompt_len, self._parallel_number], dtype=torch.long)
        tokens[: len(tags_ids), -1] = torch.tensor(tags_ids)
        tokens[len(tags_ids) + 1 :, -1] = torch.tensor(lyrics_ids)

        tokens_mask = torch.zeros_like(tokens, dtype=torch.bool)
        tokens_mask[:, -1] = True

        bs_size = 2 if cfg_scale != 1.0 else 1

        def _cfg_cat(tensor: torch.Tensor, cfg_scale: float):
            # Add a batch dim; duplicate it when CFG needs an uncond branch.
            tensor = tensor.unsqueeze(0)
            if cfg_scale != 1.0:
                tensor = torch.cat([tensor, tensor], dim=0)
            return tensor

        return {
            "tokens": _cfg_cat(tokens, cfg_scale),
            "tokens_mask": _cfg_cat(tokens_mask, cfg_scale),
            "muq_embed": _cfg_cat(muq_embed, cfg_scale),
            "muq_idx": [muq_idx] * bs_size,
            "pos": _cfg_cat(torch.arange(prompt_len, dtype=torch.long), cfg_scale),
        }

    def _forward(
        self,
        model_inputs: Dict[str, Any],
        max_audio_length_ms: int,
        temperature: float,
        topk: int,
        cfg_scale: float,
    ):
        """Autoregressively generate audio frames, then decode to a waveform.

        Args:
            model_inputs: Tensors produced by :meth:`preprocess`.
            max_audio_length_ms: Hard cap on output length; one generated frame
                corresponds to 80 ms (``max_audio_length_ms // 80`` frames).
            temperature: Sampling temperature.
            topk: Top-k cutoff for sampling.
            cfg_scale: Classifier-free-guidance scale.

        Returns:
            ``{"wav": waveform}`` decoded by the audio codec.
        """
        prompt_tokens = model_inputs["tokens"]
        prompt_tokens_mask = model_inputs["tokens_mask"]
        continuous_segment = model_inputs["muq_embed"]
        starts = model_inputs["muq_idx"]
        prompt_pos = model_inputs["pos"]

        frames = []

        bs_size = 2 if cfg_scale != 1.0 else 1
        self.model.setup_caches(bs_size)
        # First step consumes the whole prompt (with MuQ conditioning injected).
        with torch.autocast(device_type=self.device.type, dtype=self.dtype):
            curr_token = self.model.generate_frame(
                tokens=prompt_tokens,
                tokens_mask=prompt_tokens_mask,
                input_pos=prompt_pos,
                temperature=temperature,
                topk=topk,
                cfg_scale=cfg_scale,
                continuous_segments=continuous_segment,
                starts=starts,
            )
        # Keep only row 0 (the conditional branch) for output.
        frames.append(curr_token[0:1,])

        def _pad_audio_token(token: torch.Tensor):
            # Re-pack a sampled audio frame into the (audio..., text) column
            # layout: audio ids in the first columns, empty_id in the text slot,
            # with the mask excluding the text column.
            padded_token = (
                torch.ones(
                    (token.shape[0], self._parallel_number),
                    device=token.device,
                    dtype=torch.long,
                )
                * self.config.empty_id
            )
            padded_token[:, :-1] = token
            padded_token = padded_token.unsqueeze(1)
            padded_token_mask = torch.ones_like(
                padded_token, device=token.device, dtype=torch.bool
            )
            padded_token_mask[..., -1] = False
            return padded_token, padded_token_mask

        # 80 ms of audio per frame — see max_audio_length_ms docstring above.
        max_audio_frames = max_audio_length_ms // 80

        for i in tqdm(range(max_audio_frames)):
            curr_token, curr_token_mask = _pad_audio_token(curr_token)
            with torch.autocast(device_type=self.device.type, dtype=self.dtype):
                curr_token = self.model.generate_frame(
                    tokens=curr_token,
                    tokens_mask=curr_token_mask,
                    input_pos=prompt_pos[..., -1:] + i + 1,
                    temperature=temperature,
                    topk=topk,
                    cfg_scale=cfg_scale,
                    continuous_segments=None,
                    starts=None,
                )
            # Stop as soon as any codebook of the conditional row emits EOS.
            if torch.any(curr_token[0:1, :] >= self.config.audio_eos_id):
                break
            frames.append(curr_token[0:1,])
        # (T, 1, K) -> (K, T): codebooks-major token grid for the codec.
        frames = torch.stack(frames).permute(1, 2, 0).squeeze(0)
        wav = self.audio_codec.detokenize(frames)
        return {"wav": wav}

    def postprocess(self, model_outputs: Dict[str, Any], save_path: str):
        """Write the generated waveform to ``save_path`` at 48 kHz."""
        wav = model_outputs["wav"]
        torchaudio.save(save_path, wav, 48000)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_path: str,
        device: torch.device,
        dtype: torch.dtype,
        version: str,
        bnb_config: Optional[BitsAndBytesConfig] = None,
    ):
        """Assemble the pipeline from a local checkpoint folder.

        Expects ``pretrained_path`` to contain ``HeartCodec-oss``,
        ``HeartMuLa-oss-{version}``, ``tokenizer.json`` and ``gen_config.json``.

        Raises:
            FileNotFoundError: If any of the four required artifacts is missing.
        """

        if os.path.exists(
            heartcodec_path := os.path.join(pretrained_path, "HeartCodec-oss")
        ):
            heartcodec = HeartCodec.from_pretrained(heartcodec_path, device_map=device)
        else:
            raise FileNotFoundError(
                f"Expected to find checkpoint for HeartCodec at {heartcodec_path} but not found. Please check your folder {pretrained_path}."
            )

        if os.path.exists(
            heartmula_path := os.path.join(pretrained_path, f"HeartMuLa-oss-{version}")
        ):
            heartmula = HeartMuLa.from_pretrained(
                heartmula_path, dtype=dtype, quantization_config=bnb_config
            )
        else:
            raise FileNotFoundError(
                f"Expected to find checkpoint for HeartMuLa at {heartmula_path} but not found. Please check your folder {pretrained_path}."
            )

        if os.path.isfile(
            vocab_path := os.path.join(pretrained_path, "tokenizer.json")
        ):
            tokenizer = Tokenizer.from_file(vocab_path)
        else:
            raise FileNotFoundError(
                f"Expected to find tokenizer.json for HeartMuLa at {vocab_path} but not found. Please check your folder {pretrained_path}."
            )

        if os.path.isfile(
            gen_config_path := os.path.join(pretrained_path, "gen_config.json")
        ):
            gen_config = HeartMuLaGenConfig.from_file(gen_config_path)
        else:
            raise FileNotFoundError(
                f"Expected to find gen_config.json for HeartMuLa at {gen_config_path} but not found. Please check your folder {pretrained_path}."
            )

        return cls(heartmula, heartcodec, None, tokenizer, gen_config, device, dtype)