rishabhsabnavis commited on
Commit
589b573
·
verified ·
1 Parent(s): 62be552

Upload 2 files

Browse files
Files changed (2) hide show
  1. infer.py +219 -0
  2. infer_utils.py +445 -0
infer.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ASLP-LAB
2
+ # 2025 Huakang Chen (huakang@mail.nwpu.edu.cn)
3
+ # 2025 Guobin Ma (guobin.ma@gmail.com)
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import argparse
18
+ import os
19
+ import time
20
+ import random
21
+
22
+ import torch
23
+ import torchaudio
24
+ from einops import rearrange
25
+
26
+ print("Current working directory:", os.getcwd())
27
+
28
+ from infer_utils import (
29
+ decode_audio,
30
+ get_lrc_token,
31
+ get_negative_style_prompt,
32
+ get_reference_latent,
33
+ get_style_prompt,
34
+ prepare_model,
35
+ )
36
+
37
+
38
def inference(
    cfm_model,
    vae_model,
    cond,
    text,
    duration,
    style_prompt,
    negative_style_prompt,
    start_time,
    pred_frames,
    batch_infer_num,
    chunked=False,
):
    """Sample latents from the CFM model and decode them to int16 audio.

    Args:
        cfm_model: conditional flow-matching model exposing ``sample``.
        vae_model: VAE used by ``decode_audio`` to turn latents into audio.
        cond: latent conditioning prompt.
        text: frame-aligned lyric token tensor.
        duration: number of latent frames to generate.
        style_prompt: positive style embedding.
        negative_style_prompt: negative style embedding for CFG.
        start_time: normalized start-time tensor.
        pred_frames: (start, end) frame segments to predict.
        batch_infer_num: number of candidate songs sampled per batch.
        chunked: decode the VAE in overlapping chunks to bound memory.

    Returns:
        List of int16 CPU tensors, one waveform per sampled candidate.
    """
    with torch.inference_mode():
        latents, _ = cfm_model.sample(
            cond=cond,
            text=text,
            duration=duration,
            style_prompt=style_prompt,
            negative_style_prompt=negative_style_prompt,
            steps=32,
            cfg_strength=4.0,
            start_time=start_time,
            latent_pred_segments=pred_frames,
            batch_infer_num=batch_infer_num,
        )

    decoded = []
    for lat in latents:
        # The VAE expects float32 latents shaped [b, d, t].
        wave = decode_audio(
            lat.to(torch.float32).transpose(1, 2), vae_model, chunked=chunked
        )

        # Flatten the batch into one long [d, n] sequence.
        wave = rearrange(wave, "b d n -> d (b n)")

        # Peak-normalize, clip, and quantize to 16-bit PCM on the CPU.
        peak = torch.max(torch.abs(wave))
        pcm = (
            (wave.to(torch.float32) / peak)
            .clamp(-1, 1)
            .mul(32767)
            .to(torch.int16)
            .cpu()
        )
        decoded.append(pcm)

    return decoded
86
+
87
+
88
+ if __name__ == "__main__":
89
+ parser = argparse.ArgumentParser()
90
+ parser.add_argument(
91
+ "--lrc-path",
92
+ type=str,
93
+ help="lyrics of target song",
94
+ ) # lyrics of target song
95
+ parser.add_argument(
96
+ "--ref-prompt",
97
+ type=str,
98
+ help="reference prompt as style prompt for target song",
99
+ required=False,
100
+ ) # reference prompt as style prompt for target song
101
+ parser.add_argument(
102
+ "--ref-audio-path",
103
+ type=str,
104
+ help="reference audio as style prompt for target song",
105
+ required=False,
106
+ ) # reference audio as style prompt for target song
107
+ parser.add_argument(
108
+ "--chunked",
109
+ action="store_true",
110
+ help="whether to use chunked decoding",
111
+ ) # whether to use chunked decoding
112
+ parser.add_argument(
113
+ "--audio-length",
114
+ type=int,
115
+ default=95,
116
+ choices=[95, 285],
117
+ help="length of generated song",
118
+ ) # length of target song
119
+ parser.add_argument(
120
+ "--repo-id", type=str, default="ASLP-lab/DiffRhythm-base", help="target model"
121
+ )
122
+ parser.add_argument(
123
+ "--output-dir",
124
+ type=str,
125
+ default="infer/example/output",
126
+ help="output directory fo generated song",
127
+ ) # output directory of target song
128
+ parser.add_argument(
129
+ "--edit",
130
+ action="store_true",
131
+ help="whether to open edit mode",
132
+ ) # edit flag
133
+ parser.add_argument(
134
+ "--ref-song",
135
+ type=str,
136
+ required=False,
137
+ help="reference prompt as latent prompt for editing",
138
+ ) # reference prompt as latent prompt for editing
139
+ parser.add_argument(
140
+ "--edit-segments",
141
+ type=str,
142
+ required=False,
143
+ help="Time segments to edit (in seconds). Format: `[[start1,end1],...]`. "
144
+ "Use `-1` for audio start/end (e.g., `[[-1,25], [50.0,-1]]`)."
145
+ ) # edit segments of target song
146
+ parser.add_argument(
147
+ "--batch-infer-num",
148
+ type=int,
149
+ default=1,
150
+ required=False,
151
+ help="number of songs per batch",
152
+ ) # number of songs per batch
153
+ args = parser.parse_args()
154
+
155
+ assert (
156
+ args.ref_prompt or args.ref_audio_path
157
+ ), "either ref_prompt or ref_audio_path should be provided"
158
+ assert not (
159
+ args.ref_prompt and args.ref_audio_path
160
+ ), "only one of them should be provided"
161
+ if args.edit:
162
+ assert (
163
+ args.ref_song and args.edit_segments
164
+ ), "reference song and edit segments should be provided for editing"
165
+
166
+ device = "cpu"
167
+ if torch.cuda.is_available():
168
+ device = "cuda"
169
+ elif torch.mps.is_available():
170
+ device = "mps"
171
+
172
+ audio_length = args.audio_length
173
+ if audio_length == 95:
174
+ max_frames = 2048
175
+ elif audio_length == 285: # current not available
176
+ max_frames = 6144
177
+
178
+ cfm, tokenizer, muq, vae = prepare_model(max_frames, device, repo_id=args.repo_id)
179
+
180
+ if args.lrc_path:
181
+ with open(args.lrc_path, "r", encoding='utf-8') as f:
182
+ lrc = f.read()
183
+ else:
184
+ lrc = ""
185
+ lrc_prompt, start_time = get_lrc_token(max_frames, lrc, tokenizer, device)
186
+
187
+ if args.ref_audio_path:
188
+ style_prompt = get_style_prompt(muq, args.ref_audio_path)
189
+ else:
190
+ style_prompt = get_style_prompt(muq, prompt=args.ref_prompt)
191
+
192
+ negative_style_prompt = get_negative_style_prompt(device)
193
+
194
+ latent_prompt, pred_frames = get_reference_latent(device, max_frames, args.edit, args.edit_segments, args.ref_song, vae)
195
+
196
+ s_t = time.time()
197
+ generated_songs = inference(
198
+ cfm_model=cfm,
199
+ vae_model=vae,
200
+ cond=latent_prompt,
201
+ text=lrc_prompt,
202
+ duration=max_frames,
203
+ style_prompt=style_prompt,
204
+ negative_style_prompt=negative_style_prompt,
205
+ start_time=start_time,
206
+ pred_frames=pred_frames,
207
+ chunked=args.chunked,
208
+ batch_infer_num=args.batch_infer_num
209
+ )
210
+ e_t = time.time() - s_t
211
+ print(f"inference cost {e_t:.2f} seconds")
212
+
213
+ generated_song = random.sample(generated_songs, 1)[0]
214
+
215
+ output_dir = args.output_dir
216
+ os.makedirs(output_dir, exist_ok=True)
217
+
218
+ output_path = os.path.join(output_dir, "output.wav")
219
+ torchaudio.save(output_path, generated_song, sample_rate=44100)
infer_utils.py ADDED
@@ -0,0 +1,445 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ASLP-LAB
2
+ # 2025 Huakang Chen (huakang@mail.nwpu.edu.cn)
3
+ # 2025 Guobin Ma (guobin.ma@gmail.com)
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import torch
18
+ import librosa
19
+ import torchaudio
20
+ import random
21
+ import json
22
+ from muq import MuQMuLan
23
+ from mutagen.mp3 import MP3
24
+ import os
25
+ import numpy as np
26
+ from huggingface_hub import hf_hub_download
27
+
28
+ from sys import path
29
+ path.append(os.getcwd())
30
+
31
+ from model import DiT, CFM
32
+
33
def vae_sample(mean, scale):
    """Reparameterized draw from the VAE posterior plus its KL penalty.

    Args:
        mean: posterior mean tensor.
        scale: raw scale tensor; softplus maps it to a positive stdev.

    Returns:
        (latents, kl): a sample ``mean + stdev * eps`` and the KL divergence
        to a standard normal, summed over dim 1 and averaged over the rest.
    """
    # Softplus keeps the stdev strictly positive; the epsilon avoids log(0).
    stdev = torch.nn.functional.softplus(scale) + 1e-4
    variance = stdev.pow(2)

    noise = torch.randn_like(mean)
    latents = mean + noise * stdev

    # KL(N(mean, var) || N(0, 1)) per element, summed over the channel dim.
    kl = (mean.pow(2) + variance - torch.log(variance) - 1).sum(1).mean()

    return latents, kl
42
+
43
def normalize_audio(y, target_dbfs=0):
    """Peak-normalize an audio tensor to the given dBFS level.

    Args:
        y: audio tensor of any shape.
        target_dbfs: target peak level in dBFS (0 -> peak amplitude 1.0).

    Returns:
        The scaled tensor; a completely silent input is returned unchanged
        (the original code divided by a zero peak, producing NaN/inf).
    """
    max_amplitude = torch.max(torch.abs(y))

    # Guard: silent audio has a zero peak; scaling it is both undefined
    # and unnecessary, so return it as-is.
    if max_amplitude == 0:
        return y

    # dBFS -> linear amplitude: 10^(dB / 20).
    target_amplitude = 10.0**(target_dbfs / 20.0)
    scale_factor = target_amplitude / max_amplitude

    normalized_audio = y * scale_factor

    return normalized_audio
52
+
53
def set_audio_channels(audio, target_channels):
    """Coerce a [batch, channels, samples] tensor to target_channels channels.

    Mono target: average all channels. Stereo target: duplicate a mono
    channel or keep only the first two. Any other target is a no-op.
    """
    if target_channels == 1:
        # Downmix to mono by averaging across the channel dim.
        return audio.mean(1, keepdim=True)
    if target_channels == 2:
        n_channels = audio.shape[1]
        if n_channels == 1:
            # Duplicate the mono channel into left/right.
            return audio.repeat(1, 2, 1)
        if n_channels > 2:
            # Drop any channels beyond the first two.
            return audio[:, :2, :]
    return audio
64
+
65
class PadCrop(torch.nn.Module):
    """Pad or crop a [channels, samples] signal to a fixed length.

    Shorter signals are zero-padded on the right; longer ones are cropped,
    from the start when ``randomize=False`` or at a random offset otherwise.
    """

    def __init__(self, n_samples, randomize=True):
        super().__init__()
        self.n_samples = n_samples
        self.randomize = randomize

    def __call__(self, signal):
        channels, length = signal.shape
        # Random crop offset only when requested and the signal is longer
        # than the target; otherwise start at sample 0.
        if self.randomize:
            start = torch.randint(0, max(0, length - self.n_samples) + 1, []).item()
        else:
            start = 0
        out = signal.new_zeros([channels, self.n_samples])
        span = min(length, self.n_samples)
        # Copy what fits; any remainder of `out` stays zero (right padding).
        out[:, :span] = signal[:, start:start + self.n_samples]
        return out
78
+
79
def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device):
    """Move audio to device, resample, pad/crop, batch, and fix channel count.

    Args:
        audio: waveform tensor ([samples], [channels, samples], or batched).
        in_sr: source sample rate.
        target_sr: desired sample rate; resampled only when they differ.
        target_length: desired sample count; None keeps the current length.
        target_channels: channel count passed to set_audio_channels.
        device: device the audio is moved to.

    Returns:
        [batch, target_channels, samples] tensor on `device`.
    """
    audio = audio.to(device)

    # Resample only when the rates actually differ.
    if in_sr != target_sr:
        audio = torchaudio.functional.Resample(in_sr, target_sr).to(device)(audio)

    # target_length=None keeps the (possibly resampled) length as-is.
    length = audio.shape[-1] if target_length is None else target_length
    audio = PadCrop(length, randomize=False)(audio)

    # Promote to a [batch, channels, samples] layout.
    if audio.dim() == 1:
        audio = audio[None, None, :]
    elif audio.dim() == 2:
        audio = audio[None, :, :]

    return set_audio_channels(audio, target_channels)
99
+
100
def decode_audio(latents, vae_model, chunked=False, overlap=32, chunk_size=128):
    """Decode VAE latents [b, d, t] into a stereo waveform [b, 2, t * 2048].

    Args:
        latents: latent tensor with time along dim 2.
        vae_model: VAE (TorchScript module here) exposing ``decode_export``.
        chunked: decode in overlapping chunks to bound peak memory.
        overlap: overlap between neighboring chunks, in latent frames.
        chunk_size: chunk length, in latent frames.

    Returns:
        The decoded waveform tensor on the latents' device.
    """
    downsampling_ratio = 2048  # audio samples produced per latent frame
    io_channels = 2  # stereo output
    if not chunked:
        # Decode the entire latent sequence in a single call.
        return vae_model.decode_export(latents)
    else:
        # chunked decoding
        hop_size = chunk_size - overlap
        total_size = latents.shape[2]
        batch_size = latents.shape[0]
        chunks = []
        i = 0
        for i in range(0, total_size - chunk_size + 1, hop_size):
            chunk = latents[:, :, i : i + chunk_size]
            chunks.append(chunk)
        # If the hop pattern did not land exactly on the end, add one more
        # chunk aligned to the tail so no frames are lost.
        if i + chunk_size != total_size:
            # Final chunk
            chunk = latents[:, :, -chunk_size:]
            chunks.append(chunk)
        chunks = torch.stack(chunks)
        num_chunks = chunks.shape[0]
        # samples_per_latent is just the downsampling ratio
        samples_per_latent = downsampling_ratio
        # Create an empty waveform, we will populate it with chunks as decode them
        y_size = total_size * samples_per_latent
        y_final = torch.zeros((batch_size, io_channels, y_size)).to(latents.device)
        for i in range(num_chunks):
            x_chunk = chunks[i, :]
            # decode the chunk
            y_chunk = vae_model.decode_export(x_chunk)
            # figure out where to put the audio along the time domain
            if i == num_chunks - 1:
                # final chunk always goes at the end
                t_end = y_size
                t_start = t_end - y_chunk.shape[2]
            else:
                t_start = i * hop_size * samples_per_latent
                t_end = t_start + chunk_size * samples_per_latent
            # remove the edges of the overlaps
            ol = (overlap // 2) * samples_per_latent
            chunk_start = 0
            chunk_end = y_chunk.shape[2]
            if i > 0:
                # no overlap for the start of the first chunk
                t_start += ol
                chunk_start += ol
            if i < num_chunks - 1:
                # no overlap for the end of the last chunk
                t_end -= ol
                chunk_end -= ol
            # paste the chunked audio into our y_final output audio
            y_final[:, :, t_start:t_end] = y_chunk[:, :, chunk_start:chunk_end]
        return y_final
153
+
154
def encode_audio(audio, vae_model, chunked=False, overlap=32, chunk_size=128):
    """Encode a waveform [b, c, n] into VAE latents [b, 128, n // 2048].

    Args:
        audio: input waveform batch; length should be a multiple of 2048.
        vae_model: VAE (TorchScript module here) exposing ``encode_export``.
        chunked: encode in overlapping chunks to bound peak memory.
        overlap: overlap between neighboring chunks, in latent frames.
        chunk_size: chunk length, in latent frames.

    Returns:
        The latent tensor on the audio's device.
    """
    downsampling_ratio = 2048  # audio samples per latent frame
    latent_dim = 128  # channel width of the encoder output
    if not chunked:
        # default behavior. Encode the entire audio in parallel
        return vae_model.encode_export(audio)
    else:
        # CHUNKED ENCODING
        # samples_per_latent is just the downsampling ratio (which is also the upsampling ratio)
        samples_per_latent = downsampling_ratio
        total_size = audio.shape[2]  # in samples
        batch_size = audio.shape[0]
        chunk_size *= samples_per_latent  # converting metric in latents to samples
        overlap *= samples_per_latent  # converting metric in latents to samples
        hop_size = chunk_size - overlap
        chunks = []
        for i in range(0, total_size - chunk_size + 1, hop_size):
            chunk = audio[:,:,i:i+chunk_size]
            chunks.append(chunk)
        # If the hop pattern did not land exactly on the end, add a final
        # tail-aligned chunk so no samples are lost.
        if i+chunk_size != total_size:
            # Final chunk
            chunk = audio[:,:,-chunk_size:]
            chunks.append(chunk)
        chunks = torch.stack(chunks)
        num_chunks = chunks.shape[0]
        # Note: y_size might be a different value from the latent length used in diffusion training
        # because we can encode audio of varying lengths
        # However, the audio should've been padded to a multiple of samples_per_latent by now.
        y_size = total_size // samples_per_latent
        # Create an empty latent, we will populate it with chunks as we encode them
        y_final = torch.zeros((batch_size,latent_dim,y_size)).to(audio.device)
        for i in range(num_chunks):
            x_chunk = chunks[i,:]
            # encode the chunk
            y_chunk = vae_model.encode_export(x_chunk)
            # figure out where to put the audio along the time domain
            if i == num_chunks-1:
                # final chunk always goes at the end
                t_end = y_size
                t_start = t_end - y_chunk.shape[2]
            else:
                t_start = i * hop_size // samples_per_latent
                t_end = t_start + chunk_size // samples_per_latent
            # remove the edges of the overlaps
            ol = overlap//samples_per_latent//2
            chunk_start = 0
            chunk_end = y_chunk.shape[2]
            if i > 0:
                # no overlap for the start of the first chunk
                t_start += ol
                chunk_start += ol
            if i < num_chunks-1:
                # no overlap for the end of the last chunk
                t_end -= ol
                chunk_end -= ol
            # paste the chunked audio into our y_final output audio
            y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
        return y_final
212
+
213
def prepare_model(max_frames, device, repo_id="ASLP-lab/DiffRhythm-1_2"):
    """Download and construct all models needed for inference.

    Args:
        max_frames: latent frame budget the DiT/CFM is built for.
        device: device every model is moved to.
        repo_id: HuggingFace repo holding the CFM checkpoint.

    Returns:
        (cfm, tokenizer, muq, vae) ready for inference.
    """
    # prepare cfm model
    dit_ckpt_path = hf_hub_download(
        repo_id=repo_id, filename="cfm_model.pt", cache_dir="./pretrained"
    )
    # NOTE(review): this config path is fixed regardless of repo_id -- confirm
    # the 1b config matches every checkpoint this is called with.
    dit_config_path = "./config/diffrhythm-1b.json"
    with open(dit_config_path) as f:
        model_config = json.load(f)
    dit_model_cls = DiT
    cfm = CFM(
        transformer=dit_model_cls(**model_config["model"], max_frames=max_frames),
        num_channels=model_config["model"]["mel_dim"],
        max_frames=max_frames
    )
    cfm = cfm.to(device)
    # Loads the raw (non-EMA) weights; load_checkpoint also casts to fp16.
    cfm = load_checkpoint(cfm, dit_ckpt_path, device=device, use_ema=False)

    # prepare tokenizer
    tokenizer = CNENTokenizer()

    # prepare muq
    muq = MuQMuLan.from_pretrained("OpenMuQ/MuQ-MuLan-large", cache_dir="./pretrained")
    muq = muq.to(device).eval()

    # prepare vae (a TorchScript archive, loaded on CPU then moved)
    vae_ckpt_path = hf_hub_download(
        repo_id="ASLP-lab/DiffRhythm-vae",
        filename="vae_model.pt",
        cache_dir="./pretrained",
    )
    vae = torch.jit.load(vae_ckpt_path, map_location="cpu").to(device)

    return cfm, tokenizer, muq, vae
246
+
247
+
248
# for song edit, will be added in the future
def get_reference_latent(device, max_frames, edit, pred_segments, ref_song, vae_model):
    """Build the latent conditioning prompt and the frame segments to predict.

    Generation mode (edit=False): a zero prompt covering all frames.
    Edit mode: encode ref_song with the VAE and predict only the frames in
    pred_segments, a JSON string "[[start_sec, end_sec], ...]" where -1
    means the audio start/end.

    Returns:
        (prompt, pred_frames): [b, t, d] latent prompt tensor and a list of
        (start_frame, end_frame) tuples to predict.
    """
    sampling_rate = 44100
    downsample_rate = 2048  # audio samples per latent frame
    io_channels = 2
    if edit:
        input_audio, in_sr = torchaudio.load(ref_song)
        input_audio = prepare_audio(input_audio, in_sr=in_sr, target_sr=sampling_rate, target_length=None, target_channels=io_channels, device=device)
        # Normalize the reference to -6 dBFS before encoding.
        input_audio = normalize_audio(input_audio, -6)

        with torch.no_grad():
            latent = encode_audio(input_audio, vae_model, chunked=True) # [b d t]
            # The encoder stacks mean and scale along the channel dim;
            # draw one latent sample from that posterior.
            mean, scale = latent.chunk(2, dim=1)
            prompt, _ = vae_sample(mean, scale)
            prompt = prompt.transpose(1, 2) # [b t d]

        pred_segments = json.loads(pred_segments)

        # Convert second-based segments into latent frame indices.
        pred_frames = []
        for st, et in pred_segments:
            sf = 0 if st == -1 else int(st * sampling_rate / downsample_rate)
            ef = max_frames if et == -1 else int(et * sampling_rate / downsample_rate)
            pred_frames.append((sf, ef))

        return prompt, pred_frames
    else:
        # Pure generation: zero prompt, predict every frame.
        prompt = torch.zeros(1, max_frames, 64).to(device)
        pred_frames = [(0, max_frames)]
        return prompt, pred_frames
277
+
278
+
279
def get_negative_style_prompt(device):
    """Load the precomputed negative (vocal) style embedding.

    Reads infer/example/vocal.npy relative to the working directory and
    returns it as a half-precision tensor on `device` ([1, 512]).
    """
    embedding = np.load("infer/example/vocal.npy")
    # fp16 to match the other style prompts fed to the CFM model.
    return torch.from_numpy(embedding).to(device).half()
287
+
288
+
289
@torch.no_grad()
def get_style_prompt(model, wav_path=None, prompt=None):
    """Compute a half-precision style embedding from text or audio.

    Args:
        model: MuQMuLan model, callable with ``texts=...`` or ``wavs=...``.
        wav_path: path to a .mp3/.wav/.flac file; used when prompt is None.
        prompt: optional text prompt; takes precedence over wav_path.

    Returns:
        [1, 512] fp16 embedding tensor.

    Raises:
        ValueError: unsupported audio file extension.
        AssertionError: audio shorter than 10 seconds.
    """
    mulan = model

    if prompt is not None:
        return mulan(texts=prompt).half()

    ext = os.path.splitext(wav_path)[-1].lower()
    if ext == ".mp3":
        # mutagen reads the duration from MP3 headers without decoding.
        meta = MP3(wav_path)
        audio_len = meta.info.length
    elif ext in [".wav", ".flac"]:
        audio_len = librosa.get_duration(path=wav_path)
    else:
        raise ValueError("Unsupported file format: {}".format(ext))

    # Explain the failure before the assert below aborts on short audio.
    if audio_len < 10:
        print(
            f"Warning: The audio file {wav_path} is too short ({audio_len:.2f} seconds). Expected at least 10 seconds."
        )

    assert audio_len >= 10

    # Embed a 10-second window centered on the middle of the track.
    mid_time = audio_len // 2
    start_time = mid_time - 5
    wav, _ = librosa.load(wav_path, sr=24000, offset=start_time, duration=10)

    wav = torch.tensor(wav).unsqueeze(0).to(model.device)

    # The @torch.no_grad() decorator already disables grad here; the
    # original's redundant inner no_grad block and the no-op
    # `audio_emb = audio_emb` self-assignment were removed.
    audio_emb = mulan(wavs=wav)  # [1, 512]

    return audio_emb.half()
325
+
326
+
327
def parse_lyrics(lyrics: str):
    """Parse LRC-style lyrics into a list of (seconds, text) tuples.

    Each line is expected as "[mm:ss.xx]lyric text" (8-char timestamp at
    positions 1-8). Malformed lines are skipped instead of aborting.

    Args:
        lyrics: raw LRC text, one timestamped line per newline.

    Returns:
        List of (start_seconds, lyric_string) in file order.
    """
    lyrics_with_time = []
    lyrics = lyrics.strip()
    for line in lyrics.split("\n"):
        try:
            time, lyric = line[1:9], line[10:]
            lyric = lyric.strip()
            mins, secs = time.split(":")
            secs = int(mins) * 60 + float(secs)
            lyrics_with_time.append((secs, lyric))
        except ValueError:
            # Malformed timestamp (bad split or non-numeric field): skip the
            # line. Narrowed from a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit.
            continue
    return lyrics_with_time
340
+
341
+
342
class CNENTokenizer:
    """Phoneme tokenizer for mixed Chinese/English lyrics backed by g2p.

    Loads the phoneme vocabulary from ./g2p/g2p/vocab.json relative to the
    working directory; token ids fed to the model are the vocab ids + 1.
    """

    def __init__(self):
        # vocab maps phoneme string -> 0-based id; encode/decode shift by +1,
        # presumably reserving 0 for padding -- see get_lrc_token's zeros().
        with open("./g2p/g2p/vocab.json", "r", encoding='utf-8') as file:
            self.phone2id: dict = json.load(file)["vocab"]
        self.id2phone = {v: k for (k, v) in self.phone2id.items()}
        # Imported here (not at module top) so importing this module does
        # not require the g2p package until a tokenizer is constructed.
        from g2p.g2p_generation import chn_eng_g2p

        self.tokenizer = chn_eng_g2p

    def encode(self, text):
        # chn_eng_g2p returns (phonemes, ids); only the ids are used,
        # shifted by +1.
        phone, token = self.tokenizer(text)
        token = [x + 1 for x in token]
        return token

    def decode(self, token):
        # Inverse of encode: shift ids back and join phonemes with '|'.
        return "|".join([self.id2phone[x - 1] for x in token])
358
+
359
+
360
def get_lrc_token(max_frames, text, tokenizer, device):
    """Convert LRC lyrics into a frame-aligned phoneme-token tensor.

    Args:
        max_frames: latent frame budget; tokens beyond it are truncated.
        text: raw LRC lyrics ("[mm:ss.xx]line" per line).
        tokenizer: object with encode(str) -> list[int] (CNENTokenizer).
        device: target device for the returned tensors.

    Returns:
        (lrc_emb, normalized_start_time): a [1, max_frames] long tensor of
        tokens (0 = padding) and a [1] fp16 tensor (always 0.0 here).
    """

    lyrics_shift = 0  # random jitter range in frames; 0 disables jitter below
    sampling_rate = 44100
    downsample_rate = 2048
    # Seconds of audio covered by max_frames latent frames.
    max_secs = max_frames / (sampling_rate / downsample_rate)

    comma_token_id = 1
    period_token_id = 2

    lrc_with_time = parse_lyrics(text)

    # Tokenize each lyric line, keeping its start timestamp.
    modified_lrc_with_time = []
    for i in range(len(lrc_with_time)):
        time, line = lrc_with_time[i]
        line_token = tokenizer.encode(line)
        modified_lrc_with_time.append((time, line_token))
    lrc_with_time = modified_lrc_with_time

    # Drop lines that start past the frame budget.
    lrc_with_time = [
        (time_start, line)
        for (time_start, line) in lrc_with_time
        if time_start < max_secs
    ]
    # For the short (2048-frame) model the last line is dropped --
    # presumably to leave room at the end of the song; TODO confirm intent.
    if max_frames == 2048:
        lrc_with_time = lrc_with_time[:-1] if len(lrc_with_time) >= 1 else lrc_with_time

    normalized_start_time = 0.0

    # 0 is the padding token; each line is written at its timestamped frame.
    lrc = torch.zeros((max_frames,), dtype=torch.long)

    tokens_count = 0
    last_end_pos = 0
    for time_start, line in lrc_with_time:
        # Replace in-line period tokens with commas and terminate the line
        # with a single period token.
        tokens = [
            token if token != period_token_id else comma_token_id for token in line
        ] + [period_token_id]
        tokens = torch.tensor(tokens, dtype=torch.long)
        num_tokens = tokens.shape[0]

        # Latent frame index corresponding to the line's start time.
        gt_frame_start = int(time_start * sampling_rate / downsample_rate)

        # With lyrics_shift == 0 this is always 0 (no jitter applied).
        frame_shift = random.randint(int(-lyrics_shift), int(lyrics_shift))

        # Never overwrite the previous line; clip to the frame budget.
        frame_start = max(gt_frame_start - frame_shift, last_end_pos)
        frame_len = min(num_tokens, max_frames - frame_start)

        lrc[frame_start : frame_start + frame_len] = tokens[:frame_len]

        tokens_count += num_tokens
        last_end_pos = frame_start + frame_len

    lrc_emb = lrc.unsqueeze(0).to(device)

    normalized_start_time = torch.tensor(normalized_start_time).unsqueeze(0).to(device)
    normalized_start_time = normalized_start_time.half()

    return lrc_emb, normalized_start_time
418
+
419
+
420
def load_checkpoint(model, ckpt_path, device, use_ema=True):
    """Load weights from a .pt or .safetensors checkpoint into a fp16 model.

    Args:
        model: module to receive the weights (cast to half precision first).
        ckpt_path: checkpoint file; format is inferred from the extension.
        device: device the loaded model is moved to.
        use_ema: load the EMA weights (with the "ema_model." prefix and the
            "initted"/"step" bookkeeping keys stripped) instead of the raw
            model weights.

    Returns:
        The model with weights loaded (strict=False), moved to `device`.
    """
    model = model.half()

    is_safetensors = ckpt_path.split(".")[-1] == "safetensors"
    if is_safetensors:
        from safetensors.torch import load_file

        checkpoint = load_file(ckpt_path)
    else:
        # weights_only avoids arbitrary code execution from pickled payloads.
        checkpoint = torch.load(ckpt_path, weights_only=True)

    if use_ema:
        # Safetensors files store the EMA state as a flat tensor dict.
        if is_safetensors:
            checkpoint = {"ema_model_state_dict": checkpoint}
        ema_state = checkpoint["ema_model_state_dict"]
        # Strip the "ema_model." prefix and drop EMA bookkeeping entries.
        checkpoint["model_state_dict"] = {
            key.replace("ema_model.", ""): value
            for key, value in ema_state.items()
            if key not in ["initted", "step"]
        }
    elif is_safetensors:
        checkpoint = {"model_state_dict": checkpoint}

    model.load_state_dict(checkpoint["model_state_dict"], strict=False)
    return model.to(device)