viewfinder-annn committed
Commit 85651ad · verified · 1 Parent(s): 50ec95a

Upload inference related files
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ example/gradio/example1.mp3 filter=lfs diff=lfs merge=lfs -text
+ example/gradio/example2.wav filter=lfs diff=lfs merge=lfs -text
+ example/gradio/example3.wav filter=lfs diff=lfs merge=lfs -text
anyaccomp/fmt_model.py ADDED
@@ -0,0 +1,367 @@
+ import torch
+ import numpy as np
+ import torch.nn as nn
+ import math
+ from einops import rearrange
+ from anyaccomp.llama_nar import DiffLlamaConcat
+ import torch.nn.functional as F
+ from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel
+ from typing import List, Optional, Tuple, Union
+ from transformers.models.llama.modeling_llama import BaseModelOutputWithPast
+
+
+ class FlowMatchingTransformerConcat(nn.Module):
+     def __init__(
+         self,
+         vocab_size=1024,
+         mel_dim=100,
+         hidden_size=1024,
+         num_layers=12,
+         num_heads=16,
+         cfg_scale=0.2,
+         use_cond_code=False,
+         cond_codebook_size=1024,
+         cond_dim=1024,
+         cond_scale_factor=1,
+         sigma=1e-5,
+         time_scheduler="linear",
+         cfg=None,
+     ):
+         super().__init__()
+         self.cfg = cfg
+
+         mel_dim = (
+             cfg.mel_dim if cfg is not None and hasattr(cfg, "mel_dim") else mel_dim
+         )
+         hidden_size = (
+             cfg.hidden_size
+             if cfg is not None and hasattr(cfg, "hidden_size")
+             else hidden_size
+         )
+         num_layers = (
+             cfg.num_layers
+             if cfg is not None and hasattr(cfg, "num_layers")
+             else num_layers
+         )
+         num_heads = (
+             cfg.num_heads
+             if cfg is not None and hasattr(cfg, "num_heads")
+             else num_heads
+         )
+         cfg_scale = (
+             cfg.cfg_scale
+             if cfg is not None and hasattr(cfg, "cfg_scale")
+             else cfg_scale
+         )
+         use_cond_code = (
+             cfg.use_cond_code
+             if cfg is not None and hasattr(cfg, "use_cond_code")
+             else use_cond_code
+         )
+         cond_codebook_size = (
+             cfg.cond_codebook_size
+             if cfg is not None and hasattr(cfg, "cond_codebook_size")
+             else cond_codebook_size
+         )
+         cond_dim = (
+             cfg.cond_dim if cfg is not None and hasattr(cfg, "cond_dim") else cond_dim
+         )
+         time_scheduler = (
+             cfg.time_scheduler
+             if cfg is not None and hasattr(cfg, "time_scheduler")
+             else time_scheduler
+         )
+         sigma = cfg.sigma if cfg is not None and hasattr(cfg, "sigma") else sigma
+         cond_scale_factor = (
+             cfg.cond_scale_factor
+             if cfg is not None and hasattr(cfg, "cond_scale_factor")
+             else cond_scale_factor
+         )
+
+         self.mel_dim = mel_dim
+         self.hidden_size = hidden_size
+         self.num_layers = num_layers
+         self.num_heads = num_heads
+         self.cfg_scale = cfg_scale
+         self.use_cond_code = use_cond_code
+         self.cond_codebook_size = cond_codebook_size
+         self.cond_dim = cond_dim
+         self.time_scheduler = time_scheduler
+         self.sigma = sigma
+         self.cond_scale_factor = cond_scale_factor
+
+         self.vocab_size = (
+             cfg.vocab_size
+             if cfg is not None and hasattr(cfg, "vocab_size")
+             else vocab_size
+         )
+         self.vocal_mel_proj = (
+             nn.Linear(self.cfg.cond_code_dim, self.hidden_size)
+             if not self.use_cond_code
+             else nn.Sequential(
+                 nn.Embedding(
+                     self.vocab_size, self.mel_dim
+                 ),  # [batch] -> [batch, mel_dim]
+                 nn.Linear(
+                     self.mel_dim, self.hidden_size
+                 ),  # [batch, mel_dim] -> [batch, hidden_size]
+             )
+         )
+
+         self.diff_estimator = DiffLlamaConcat(
+             mel_dim=self.mel_dim,
+             hidden_size=self.hidden_size,
+             num_heads=self.num_heads,
+             num_layers=self.num_layers,
+             flash_attention=hasattr(cfg, "flash_attention") and cfg.flash_attention,
+         )
+
+         if hasattr(cfg, "repa_loss") and cfg.repa_loss.enable:
+             repa_dim = (
+                 cfg.repa_loss.repa_dim
+                 if hasattr(cfg.repa_loss, "repa_dim")
+                 else self.hidden_size
+             )
+             self.repa_proj = nn.Sequential(
+                 nn.Linear(self.hidden_size, self.hidden_size),
+                 nn.SiLU(),
+                 nn.Linear(self.hidden_size, self.hidden_size),
+                 nn.SiLU(),
+                 nn.Linear(self.hidden_size, repa_dim),
+             )
+
+         self.reset_parameters()
+
+     def reset_parameters(self):
+         def _reset_parameters(m):
+             if isinstance(m, nn.MultiheadAttention):
+                 if m._qkv_same_embed_dim:
+                     nn.init.normal_(m.in_proj_weight, std=0.02)
+                 else:
+                     nn.init.normal_(m.q_proj_weight, std=0.02)
+                     nn.init.normal_(m.k_proj_weight, std=0.02)
+                     nn.init.normal_(m.v_proj_weight, std=0.02)
+
+                 if m.in_proj_bias is not None:
+                     nn.init.constant_(m.in_proj_bias, 0.0)
+                     nn.init.constant_(m.out_proj.bias, 0.0)
+                 if m.bias_k is not None:
+                     nn.init.xavier_normal_(m.bias_k)
+                 if m.bias_v is not None:
+                     nn.init.xavier_normal_(m.bias_v)
+
+             elif (
+                 isinstance(m, nn.Conv1d)
+                 or isinstance(m, nn.ConvTranspose1d)
+                 or isinstance(m, nn.Conv2d)
+                 or isinstance(m, nn.ConvTranspose2d)
+             ):
+                 m.weight.data.normal_(0.0, 0.02)
+
+             elif isinstance(m, nn.Linear):
+                 m.weight.data.normal_(mean=0.0, std=0.02)
+                 if m.bias is not None:
+                     m.bias.data.zero_()
+
+             elif isinstance(m, nn.Embedding):
+                 m.weight.data.normal_(mean=0.0, std=0.02)
+                 if m.padding_idx is not None:
+                     m.weight.data[m.padding_idx].zero_()
+
+         self.apply(_reset_parameters)
+
+     @torch.no_grad()
+     def forward_diffusion(self, x, t):
+         """
+         x: (B, T, mel_dim)
+         t: (B,)
+         """
+
+         new_t = t
+         t = t.unsqueeze(-1).unsqueeze(-1)
+         z = torch.randn(
+             x.shape, dtype=x.dtype, device=x.device, requires_grad=False
+         )  # (B, T, mel_dim)
+
+         cfg_scale = self.cfg_scale
+
+         # get prompt len
+         if torch.rand(1) > 0.7:
+             prompt_len = torch.randint(
+                 min(x.shape[1] // 4, 5), int(x.shape[1] * 0.4), (x.shape[0],)
+             ).to(
+                 x.device
+             )  # (B,)
+         else:
+             prompt_len = torch.zeros(x.shape[0]).to(x.device)
+
+         split_ratio = torch.rand(prompt_len.shape, device=prompt_len.device)  # (B,)
+
+         left_len = (split_ratio * (prompt_len + 1).float()).long()  # (B,)
+         right_len = prompt_len - left_len  # (B,)
+
+         T = x.shape[1]
+         is_prompt = torch.zeros_like(x[:, :, 0])  # (B, T)
+         col_indices = torch.arange(T, device=x.device).repeat(x.shape[0], 1)  # (B, T)
+         left_mask = col_indices < left_len.unsqueeze(1)
+         right_mask = col_indices >= (T - right_len.unsqueeze(1))
+         is_prompt[left_mask | right_mask] = 1
+
+         mask = torch.ones_like(x[:, :, 0])  # mask if 1, not mask if 0
+         mask[is_prompt.bool()] = 0
+         mask = mask[:, :, None]
+
+         # flow matching: xt = (1 - (1 - sigma) * t) * x0 + t * x; where x0 ~ N(0, 1), x is a sample
+         # flow gt: x - (1 - sigma) * x0 = x - (1 - sigma) * noise
+         xt = ((1 - (1 - self.sigma) * t) * z + t * x) * mask + x * (1 - mask)
+
+         return xt, z, new_t, prompt_len, mask
+
+     def loss_t(
+         self,
+         x,
+         x_mask,
+         t,
+         lyric=None,
+         output_hidden_states=False,
+     ):
+         xt, z, new_t, prompt_len, mask = self.forward_diffusion(x, t)
+
+         noise = z
+
+         prompt_len = prompt_len.float()
+
+         # drop condition using cfg_scale
+         if lyric is not None:
+             cfg_mask = torch.where(
+                 torch.rand_like(prompt_len) > self.cfg_scale,
+                 torch.ones_like(prompt_len),  # keep cond
+                 torch.zeros_like(prompt_len),  # drop cond
+             ).to(lyric.device)
+
+             cond_mask = cfg_mask[:, None, None]  # [b, 1, 1]
+
+             lyric = lyric * cond_mask
+
+         final_mask = mask * x_mask[..., None]  # (B, T, 1)
+
+         output = self.diff_estimator(
+             xt, new_t, x_mask, lyric, output_hidden_states=output_hidden_states
+         )
+         if output_hidden_states:
+             return_list = [noise, x, output["hidden_states"], final_mask, prompt_len]
+             return_list.append(output["all_hidden_states"])
+         else:
+             return_list = [noise, x, output, final_mask, prompt_len]
+
+         return return_list
+
+     def compute_loss(self, x, x_mask, lyric=None, output_hidden_states=False):
+         # x0: (B, T, num_quantizer)
+         # x_mask: (B, T) mask is 0 for padding
+         t = torch.rand(x.shape[0], device=x.device, requires_grad=False)
+         t = torch.clamp(t, 1e-5, 1.0)
+         # from CosyVoice: since generation is harder at the beginning than later on, we use a cosine scheduler for the timestep t
+         if self.time_scheduler == "cos":
+             t = 1 - torch.cos(t * math.pi * 0.5)
+         else:
+             pass
+         return self.loss_t(
+             x, x_mask, t, lyric, output_hidden_states=output_hidden_states
+         )
+
+     def forward(self, x, x_mask, vocal_mel, output_hidden_states=False):
+         cond = self.vocal_mel_proj(vocal_mel)
+         return self.compute_loss(x, x_mask, cond, output_hidden_states)
+
+     @torch.no_grad()
+     def reverse_diffusion(
+         self,
+         vocal_mel=None,
+         prompt=None,
+         right_prompt=None,
+         x_mask=None,
+         prompt_mask=None,
+         right_prompt_mask=None,
+         target_len=None,
+         n_timesteps=10,
+         cfg=1.0,
+         rescale_cfg=0.75,
+     ):
+         h = 1.0 / n_timesteps
+         prompt_len = prompt.shape[1] if prompt is not None else 0
+         right_prompt_len = right_prompt.shape[1] if right_prompt is not None else 0
+         # print(prompt_len, right_prompt_len)
+         if vocal_mel is not None:
+             target_len = vocal_mel.shape[1]
+         elif target_len is None:
+             target_len = 1000  # hardcode 50Hz 20s
+         else:
+             raise ValueError
+         full_len = target_len
+         target_len = target_len - prompt_len - right_prompt_len
+
+         cond = self.vocal_mel_proj(vocal_mel)
+
+         if x_mask is None:
+             x_mask = torch.ones(cond.shape[0], target_len).to(cond.device)
+         if prompt_mask is None and prompt is not None:
+             prompt_mask = torch.ones(cond.shape[0], prompt_len).to(cond.device)
+         if right_prompt_mask is None and right_prompt is not None:
+             right_prompt_mask = torch.ones(cond.shape[0], right_prompt_len).to(
+                 cond.device
+             )
+
+         if prompt is not None and right_prompt is not None:
+             xt_mask = torch.cat([prompt_mask, x_mask, right_prompt_mask], dim=1)
+         elif prompt is not None and right_prompt is None:
+             xt_mask = torch.cat([prompt_mask, x_mask], dim=1)
+         elif prompt is None and right_prompt is not None:
+             xt_mask = torch.cat([x_mask, right_prompt_mask], dim=1)
+         else:
+             xt_mask = x_mask
+
+         z = torch.randn(
+             (cond.shape[0], target_len, self.mel_dim),
+             dtype=cond.dtype,
+             device=cond.device,
+             requires_grad=False,
+         )
+         xt = z
+         # t from 0 to 1: x0 = z ~ N(0, 1)
+         for i in range(n_timesteps):
+             if prompt is not None and right_prompt is not None:
+                 xt_input = torch.cat([prompt, xt, right_prompt], dim=1)
+             elif prompt is not None and right_prompt is None:
+                 xt_input = torch.cat([prompt, xt], dim=1)
+             elif prompt is None and right_prompt is not None:
+                 xt_input = torch.cat([xt, right_prompt], dim=1)
+             else:
+                 xt_input = xt
+             t = (0 + (i + 0.5) * h) * torch.ones(
+                 z.shape[0], dtype=z.dtype, device=z.device
+             )
+             flow_pred = self.diff_estimator(xt_input, t, xt_mask, cond)
+             flow_pred = flow_pred[:, prompt_len : prompt_len + target_len, :]
+             # cfg
+
+             if cfg > 0:
+                 uncond_flow_pred = self.diff_estimator(
+                     xt_input, t, xt_mask, torch.zeros_like(cond)
+                 )
+                 uncond_flow_pred = uncond_flow_pred[
+                     :, prompt_len : prompt_len + target_len, :
+                 ]
+                 pos_flow_pred_std = flow_pred.std()
+                 flow_pred_cfg = flow_pred + cfg * (flow_pred - uncond_flow_pred)
+                 rescale_flow_pred = (
+                     flow_pred_cfg * pos_flow_pred_std / flow_pred_cfg.std()
+                 )
+                 flow_pred = (
+                     rescale_cfg * rescale_flow_pred + (1 - rescale_cfg) * flow_pred_cfg
+                 )
+
+             dxt = flow_pred * h
+             xt = xt + dxt
+
+         return xt
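
For orientation, a minimal sketch of driving the sampler above; it is illustrative only (untrained weights and random conditioning codes instead of real CocoStyle tokens), with sizes borrowed from config/flow_matching.json and the 50 Hz frame rate noted in the code:

import torch
from anyaccomp.fmt_model import FlowMatchingTransformerConcat

# Hypothetical smoke test, not part of the commit.
model = FlowMatchingTransformerConcat(use_cond_code=True, vocab_size=512, mel_dim=128)
model.eval()

vocal_codes = torch.randint(0, 512, (1, 500))  # (B, T): ~10 s at 50 Hz
mel = model.reverse_diffusion(vocal_mel=vocal_codes, n_timesteps=32, cfg=2.0)
print(mel.shape)  # torch.Size([1, 500, 128]): one mel frame per conditioning frame
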
anyaccomp/inference_utils.py ADDED
@@ -0,0 +1,124 @@
+ import math
+ import json
+ import librosa
+ import torch
+ import torchaudio
+ import accelerate
+ import safetensors
+ import numpy as np
+ import os
+ import yaml
+
+ import torchvision
+ from librosa.feature import chroma_stft
+
+ import random
+
+ import sys
+
+
+ from anyaccomp.fmt_model import FlowMatchingTransformerConcat
+ from models.codec.amphion_codec.vocos import Vocos
+ from models.codec.melvqgan.melspec import MelSpectrogram
+ from models.codec.coco.rep_coco_model import CocoContentStyle, CocoContent, CocoStyle
+
+ from tqdm import tqdm
+
+ from utils.util import load_config
+
+ import io
+
+ from transformers import T5Tokenizer, T5EncoderModel
+
+ import warnings
+
+
+ class Sing2SongInferencePipeline:
+     def __init__(
+         self,
+         checkpoint_path,
+         cfg_path,
+         vocoder_checkpoint_path,
+         vocoder_cfg_path,
+         device="cuda",
+     ):
+         self.cfg = load_config(cfg_path)
+         self.device = device
+
+         self.checkpoint_path = checkpoint_path
+         self._load_model(checkpoint_path)
+
+         self._build_input_model()
+         self.vocoder_checkpoint_path = vocoder_checkpoint_path
+         self.vocoder_cfg = load_config(vocoder_cfg_path)
+         self._build_output_model()
+         print("Output model built")
+
+     def _load_model(self, checkpoint_path):
+         self.model = FlowMatchingTransformerConcat(
+             cfg=self.cfg.model.flow_matching_transformer
+         )
+
+         accelerate.load_checkpoint_and_dispatch(self.model, checkpoint_path)
+         self.model.eval().to(self.device)
+         print(
+             f"model Params: {round(sum(p.numel() for p in self.model.parameters() if p.requires_grad)/1e6, 2)}M"
+         )
+         print(f"Loaded model from {checkpoint_path}")
+
+     def _build_input_model(self):
+         self.coco_model = CocoStyle(
+             cfg=self.cfg.model.coco, construct_only_for_quantizer=True
+         )
+         self.coco_model.eval()
+         self.coco_model.to(self.device)
+         accelerate.load_checkpoint_and_dispatch(
+             self.coco_model, self.cfg.model.coco.pretrained_path
+         )
+
+     def _build_output_model(self):
+         # print(vocoder_checkpoint_path)
+         self.vocoder = Vocos(cfg=self.vocoder_cfg.model.vocos)
+         accelerate.load_checkpoint_and_dispatch(
+             self.vocoder, self.vocoder_checkpoint_path
+         )
+         self.vocoder = self.vocoder.eval().to(self.device)
+
+     @torch.no_grad()
+     @torch.cuda.amp.autocast(dtype=torch.bfloat16)
+     def _extract_coco_codec(self, speech):
+         """
+         Args:
+             speech: [B, T]
+         Returns:
+             codecs: [B, T]. Note that codecs might not be at 50Hz!
+         """
+         target_chroma_dim = self.cfg.model.coco.chromagram_dim
+
+         speech = speech.cpu().numpy().squeeze()
+
+         chromagram = chroma_stft(
+             y=speech,
+             sr=self.cfg.preprocess.chromagram.sample_rate,
+             n_fft=self.cfg.preprocess.chromagram.n_fft,
+             hop_length=self.cfg.preprocess.chromagram.hop_size,
+             win_length=self.cfg.preprocess.chromagram.win_size,
+             n_chroma=target_chroma_dim,
+         ).T  # [D, T] -> [T, D]
+         chromagram_feats = torch.tensor(chromagram).unsqueeze(0).to(self.device)
+         codecs, _ = self.coco_model.quantize(chromagram_feats)
+         return codecs
+
+     @torch.no_grad()
+     def encode_vocal(self, speech):  # (B, T)
+         speech = speech.to(self.device)
+         codecs = self._extract_coco_codec(speech)
+         return codecs
+
+     @torch.no_grad()
+     def _generate_audio(self, mel):
+         synthesized_audio = (self.vocoder(mel.transpose(1, 2)).detach().cpu())[0]
+
+         return synthesized_audio
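
A hedged sketch of wiring the pipeline together; the checkpoint paths below are placeholders (only ./pretrained/vq appears in the shipped config), while the two config files and the example audio are the ones added in this commit:

import librosa
import torch
from anyaccomp.inference_utils import Sing2SongInferencePipeline

# Paths are assumptions for illustration, not part of the commit.
pipeline = Sing2SongInferencePipeline(
    checkpoint_path="./pretrained/fmt/model.safetensors",
    cfg_path="./config/flow_matching.json",
    vocoder_checkpoint_path="./pretrained/vocoder/model.safetensors",
    vocoder_cfg_path="./config/vocoder.json",
)

wav, _ = librosa.load("example/gradio/example2.wav", sr=24000)
codes = pipeline.encode_vocal(torch.tensor(wav).unsqueeze(0))  # (1, T') style codes
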
anyaccomp/llama_nar.py ADDED
@@ -0,0 +1,667 @@
+ from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel
+ import torch
+ import torch.nn.functional as F
+ import numpy as np
+ import os
+ import torch.nn as nn
+ from typing import List, Optional, Tuple, Union
+ import math
+
+ from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+ from transformers.models.llama.modeling_llama import BaseModelOutputWithPast
+ from transformers.models.llama.modeling_llama import (
+     LlamaAttention,
+     apply_rotary_pos_emb,
+     Cache,
+     repeat_kv,
+ )
+
+
+ class SinusoidalPosEmb(nn.Module):
+     def __init__(self, dim):
+         super().__init__()
+         self.dim = dim
+
+     def forward(self, x):
+         device = x.device
+         half_dim = self.dim // 2
+         emb = math.log(10000) / (half_dim - 1)
+         emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
+         emb = x[:, None] * emb[None, :] * 1.0
+         emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
+         return emb
+
+
+ class LlamaAdaptiveRMSNorm(nn.Module):
+     def __init__(self, hidden_size=1024, eps=1e-6, dim_cond=1024):
+         super().__init__()
+         self.to_weight = nn.Linear(dim_cond, hidden_size)
+         nn.init.zeros_(self.to_weight.weight)
+         nn.init.ones_(self.to_weight.bias)
+         self.variance_epsilon = eps
+         self._is_hf_initialized = True  # disable automatic init
+
+     def forward(self, hidden_states, cond_embedding):
+         input_dtype = hidden_states.dtype
+         variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+
+         weight = self.to_weight(cond_embedding)
+         if len(weight.shape) == 2:
+             weight = weight.unsqueeze(1)
+
+         return (weight * hidden_states).to(input_dtype)
+
+
+ class LlamaNARDecoderLayer(LlamaDecoderLayer):
+     def __init__(self, config: LlamaConfig, layer_idx: int):
+         """Override to use adaptive layer norm"""
+         super().__init__(config, layer_idx)  # init attention, mlp, etc.
+         # self.self_attn = LlamaXformersAttention(config=config, layer_idx=layer_idx)
+
+         self.self_attn.is_causal = False  # for flash attn
+
+         self.input_layernorm = LlamaAdaptiveRMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size
+         )
+         self.post_attention_layernorm = LlamaAdaptiveRMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size
+         )
+
+     # add `cond` in forward function
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         cond_embedding: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_value: Optional[Tuple[torch.Tensor]] = None,
+         output_attentions: Optional[bool] = False,
+         use_cache: Optional[bool] = False,
+     ) -> Tuple[
+         torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
+     ]:
+         """
+         Args:
+             hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+             attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+                 `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+             output_attentions (`bool`, *optional*):
+                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                 returned tensors for more detail.
+             use_cache (`bool`, *optional*):
+                 If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                 (see `past_key_values`).
+             past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+         """
+
+         residual = hidden_states
+
+         hidden_states = self.input_layernorm(
+             hidden_states, cond_embedding=cond_embedding
+         )
+
+         # Self Attention
+         hidden_states, self_attn_weights, present_key_value = self.self_attn(
+             hidden_states=hidden_states,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_value=past_key_value,
+             output_attentions=output_attentions,
+             use_cache=use_cache,
+         )
+         hidden_states = residual + hidden_states
+
+         # Fully Connected
+         residual = hidden_states
+         hidden_states = self.post_attention_layernorm(
+             hidden_states, cond_embedding=cond_embedding
+         )
+         hidden_states = self.mlp(hidden_states)
+         hidden_states = residual + hidden_states
+
+         outputs = (hidden_states,)
+
+         if output_attentions:
+             outputs += (self_attn_weights,)
+
+         if use_cache:
+             outputs += (present_key_value,)
+
+         return outputs
+
+
+ class DiffLlamaConcat(LlamaModel):
+     def __init__(
+         self,
+         mel_dim=100,
+         hidden_size=1024,
+         num_heads=16,
+         num_layers=16,
+         dropout=0.1,
+         ffn_dropout=0.1,
+         attention_dropout=0.0,
+         config=LlamaConfig(0, 256, 1024, 1, 1),
+         flash_attention=False,
+     ):
+         super().__init__(config)
+
+         self.flash_attention = flash_attention
+         self.layers = nn.ModuleList(
+             [
+                 LlamaNARDecoderLayer(
+                     LlamaConfig(
+                         hidden_size=hidden_size,
+                         num_attention_heads=num_heads,
+                         max_position_embeddings=4096,
+                         intermediate_size=hidden_size * 4,
+                         attn_implementation=(
+                             "flash_attention_2" if self.flash_attention else "eager"
+                         ),
+                     ),
+                     layer_idx=i,
+                 )
+                 for i in range(num_layers)
+             ]
+         )
+
+         self.norm = LlamaAdaptiveRMSNorm(hidden_size, dim_cond=hidden_size)
+
+         self.diff_step_embedding = SinusoidalPosEmb(hidden_size)
+         self.diff_step_mlp = nn.Sequential(
+             nn.Linear(hidden_size, hidden_size * 4),
+             nn.SiLU(),
+             nn.Linear(hidden_size * 4, hidden_size),
+         )
+
+         self.cond_mlp = nn.Sequential(
+             nn.Linear(hidden_size, hidden_size * 4),
+             nn.SiLU(),
+             nn.Linear(hidden_size * 4, hidden_size),
+         )
+
+         self.mel_mlp = nn.Sequential(
+             nn.Linear(mel_dim, hidden_size * 4),
+             nn.SiLU(),
+             nn.Linear(hidden_size * 4, hidden_size),
+         )
+
+         self.mel_out_mlp = nn.Sequential(
+             nn.Linear(hidden_size, hidden_size * 4),
+             nn.SiLU(),
+             nn.Linear(hidden_size * 4, mel_dim),
+         )
+
+         for layer in self.layers:
+             layer.input_layernorm = LlamaAdaptiveRMSNorm(
+                 hidden_size, dim_cond=hidden_size
+             )
+             layer.post_attention_layernorm = LlamaAdaptiveRMSNorm(
+                 hidden_size, dim_cond=hidden_size
+             )
+
+         self.embed_tokens = None
+
+         self.post_init()
+
+         # self.reset_parameters()
+
+     def _prepare_decoder_attention_mask(
+         self, attention_mask, input_shape, inputs_embeds, past_key_values_length
+     ):
+         # create noncausal mask
+         # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+         combined_attention_mask = None
+
+         def _expand_mask(
+             mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
+         ):
+             """
+             Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+             """
+             bsz, src_len = mask.size()
+             tgt_len = tgt_len if tgt_len is not None else src_len
+
+             expanded_mask = (
+                 mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+             )
+
+             inverted_mask = 1.0 - expanded_mask
+
+             return inverted_mask.masked_fill(
+                 inverted_mask.to(torch.bool), torch.finfo(dtype).min
+             )
+
+         if attention_mask is not None:
+             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+             expanded_attn_mask = _expand_mask(
+                 attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+             ).to(inputs_embeds.device)
+             combined_attention_mask = (
+                 expanded_attn_mask
+                 if combined_attention_mask is None
+                 else expanded_attn_mask + combined_attention_mask
+             )
+
+         return combined_attention_mask
+
+     def forward(
+         self,
+         x,
+         diffusion_step,
+         x_mask,
+         cond,
+         input_ids: torch.LongTensor = None,  # [num_quant, B, T]
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, BaseModelOutputWithPast]:
+
+         # retrieve some shape info
+         batch_size, seq_length, _ = x.shape
+
+         # condition mlp
+         cond_embedding = self.cond_mlp(cond)  # (B, T, C)
+
+         # condition mel
+         x = self.mel_mlp(x)
+
+         # diffusion step embedding
+         diffusion_step = self.diff_step_embedding(diffusion_step).to(x.device)
+         diffusion_step = self.diff_step_mlp(diffusion_step)  # (B, C)
+         x = x + cond_embedding
+
+         inputs_embeds = x
+         # if self.flash_attention:
+         #     attention_mask = None
+         # else:
+         attention_mask = x_mask
+
+         # assert x_mask.shape == batch_size, seq_length
+
+         output_attentions = (
+             output_attentions
+             if output_attentions is not None
+             else self.config.output_attentions
+         )
+         output_hidden_states = (
+             output_hidden_states
+             if output_hidden_states is not None
+             else self.config.output_hidden_states
+         )
+         use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+         return_dict = (
+             return_dict if return_dict is not None else self.config.use_return_dict
+         )
+
+         seq_length_with_past = seq_length
+         past_key_values_length = 0
+
+         if past_key_values is not None:
+             past_key_values_length = past_key_values[0][0].shape[2]
+             seq_length_with_past = seq_length_with_past + past_key_values_length
+
+         if position_ids is None:
+             device = input_ids.device if input_ids is not None else inputs_embeds.device
+             position_ids = torch.arange(
+                 past_key_values_length,
+                 seq_length + past_key_values_length,
+                 dtype=torch.long,
+                 device=device,
+             )
+             position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+         else:
+             position_ids = position_ids.view(-1, seq_length).long()
+
+         if not self.flash_attention:
+             # embed positions
+             if attention_mask is None:
+                 attention_mask = torch.ones(
+                     (batch_size, seq_length_with_past),
+                     dtype=torch.bool,
+                     device=inputs_embeds.device,
+                 )
+             attention_mask = self._prepare_decoder_attention_mask(
+                 attention_mask,
+                 (batch_size, seq_length),
+                 inputs_embeds,
+                 past_key_values_length,
+             )
+
+         hidden_states = inputs_embeds
+
+         if self.gradient_checkpointing and self.training:
+             if use_cache:
+                 use_cache = False
+
+         # decoder layers
+         all_hidden_states = () if output_hidden_states else None
+         all_self_attns = () if output_attentions else None
+         next_decoder_cache = () if use_cache else None
+
+         for idx, decoder_layer in enumerate(self.layers):
+             if output_hidden_states:
+                 all_hidden_states += (hidden_states,)
+
+             past_key_value = (
+                 past_key_values[idx] if past_key_values is not None else None
+             )
+
+             if self.gradient_checkpointing and self.training:
+                 raise NotImplementedError
+
+                 def create_custom_forward(module):
+                     def custom_forward(*inputs):
+                         # None for past_key_value
+                         return module(*inputs, output_attentions, None)
+
+                     return custom_forward
+
+                 layer_outputs = torch.utils.checkpoint.checkpoint(
+                     create_custom_forward(decoder_layer),
+                     hidden_states,
+                     attention_mask,
+                     position_ids,
+                     None,
+                 )
+             else:
+                 layer_outputs = decoder_layer(
+                     hidden_states,
+                     # attention_mask=attention_mask if not self.flash_attention else None,
+                     attention_mask=attention_mask,
+                     position_ids=position_ids,
+                     past_key_value=past_key_value,
+                     output_attentions=output_attentions,
+                     use_cache=use_cache,
+                     cond_embedding=diffusion_step,
+                 )
+
+             hidden_states = layer_outputs[0]
+
+             if use_cache:
+                 next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+             if output_attentions:
+                 all_self_attns += (layer_outputs[1],)
+
+         hidden_states = self.norm(hidden_states, cond_embedding=diffusion_step)
+
+         # add hidden states from the last decoder layer
+         if output_hidden_states:
+             all_hidden_states += (hidden_states,)
+
+         next_cache = next_decoder_cache if use_cache else None
+
+         hidden_states = self.mel_out_mlp(hidden_states)
+
+         if not output_hidden_states:
+             return hidden_states
+         else:
+             return {
+                 "hidden_states": hidden_states,
+                 "all_hidden_states": all_hidden_states,
+             }
+
+
+ class DiffLlama(LlamaModel):
+     def __init__(
+         self,
+         mel_dim=100,
+         hidden_size=1024,
+         num_heads=16,
+         num_layers=16,
+         dropout=0.1,
+         ffn_dropout=0.1,
+         attention_dropout=0.0,
+         config=LlamaConfig(0, 256, 1024, 1, 1),
+         flash_attention=False,
+     ):
+         super().__init__(config)
+
+         self.flash_attention = flash_attention
+         self.layers = nn.ModuleList(
+             [
+                 LlamaNARDecoderLayer(
+                     LlamaConfig(
+                         hidden_size=hidden_size,
+                         num_attention_heads=num_heads,
+                         max_position_embeddings=4096,
+                         intermediate_size=hidden_size * 4,
+                         attn_implementation=(
+                             "flash_attention_2" if self.flash_attention else "eager"
+                         ),
+                         is_causal=False,
+                     ),
+                     layer_idx=i,
+                 )
+                 for i in range(num_layers)
+             ]
+         )
+
+         self.norm = LlamaAdaptiveRMSNorm(hidden_size, dim_cond=hidden_size)
+
+         self.diff_step_embedding = SinusoidalPosEmb(hidden_size)
+         self.diff_step_mlp = nn.Sequential(
+             nn.Linear(hidden_size, hidden_size * 4),
+             nn.SiLU(),
+             nn.Linear(hidden_size * 4, hidden_size),
+         )
+
+         # forward() below projects the conditioning sequence with this MLP
+         self.cond_mlp = nn.Sequential(
+             nn.Linear(hidden_size, hidden_size * 4),
+             nn.SiLU(),
+             nn.Linear(hidden_size * 4, hidden_size),
+         )
+
+         self.mel_mlp = nn.Sequential(
+             nn.Linear(mel_dim, hidden_size * 4),
+             nn.SiLU(),
+             nn.Linear(hidden_size * 4, hidden_size),
+         )
+
+         self.mel_out_mlp = nn.Sequential(
+             nn.Linear(hidden_size, hidden_size * 4),
+             nn.SiLU(),
+             nn.Linear(hidden_size * 4, mel_dim),
+         )
+
+         for layer in self.layers:
+             layer.input_layernorm = LlamaAdaptiveRMSNorm(
+                 hidden_size, dim_cond=hidden_size
+             )
+             layer.post_attention_layernorm = LlamaAdaptiveRMSNorm(
+                 hidden_size, dim_cond=hidden_size
+             )
+
+         self.embed_tokens = None
+
+         self.post_init()
+
+         # self.reset_parameters()
+
+     def _prepare_decoder_attention_mask(
+         self, attention_mask, input_shape, inputs_embeds, past_key_values_length
+     ):
+         # create noncausal mask
+         # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+         combined_attention_mask = None
+
+         def _expand_mask(
+             mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
+         ):
+             """
+             Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+             """
+             bsz, src_len = mask.size()
+             tgt_len = tgt_len if tgt_len is not None else src_len
+
+             expanded_mask = (
+                 mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+             )
+
+             inverted_mask = 1.0 - expanded_mask
+
+             return inverted_mask.masked_fill(
+                 inverted_mask.to(torch.bool), torch.finfo(dtype).min
+             )
+
+         if attention_mask is not None:
+             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+             expanded_attn_mask = _expand_mask(
+                 attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+             ).to(inputs_embeds.device)
+             combined_attention_mask = (
+                 expanded_attn_mask
+                 if combined_attention_mask is None
+                 else expanded_attn_mask + combined_attention_mask
+             )
+
+         return combined_attention_mask
+
+     def forward(
+         self,
+         x,
+         diffusion_step,
+         x_mask,
+         cond,
+         input_ids: torch.LongTensor = None,  # [num_quant, B, T]
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, BaseModelOutputWithPast]:
+
+         # retrieve some shape info
+         batch_size, seq_length, _ = x.shape
+
+         # condition mlp
+         cond_embedding = self.cond_mlp(cond)  # (B, T, C)
+
+         # condition mel
+         x = self.mel_mlp(x)
+
+         # diffusion step embedding
+         diffusion_step = self.diff_step_embedding(diffusion_step).to(x.device)
+         diffusion_step = self.diff_step_mlp(diffusion_step)  # (B, C)
+         x = x + cond_embedding
+
+         inputs_embeds = x
+         attention_mask = x_mask
+
+         output_attentions = (
+             output_attentions
+             if output_attentions is not None
+             else self.config.output_attentions
+         )
+         output_hidden_states = (
+             output_hidden_states
+             if output_hidden_states is not None
+             else self.config.output_hidden_states
+         )
+         use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+         return_dict = (
+             return_dict if return_dict is not None else self.config.use_return_dict
+         )
+
+         seq_length_with_past = seq_length
+         past_key_values_length = 0
+
+         if past_key_values is not None:
+             past_key_values_length = past_key_values[0][0].shape[2]
+             seq_length_with_past = seq_length_with_past + past_key_values_length
+
+         if position_ids is None:
+             device = input_ids.device if input_ids is not None else inputs_embeds.device
+             position_ids = torch.arange(
+                 past_key_values_length,
+                 seq_length + past_key_values_length,
+                 dtype=torch.long,
+                 device=device,
+             )
+             position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+         else:
+             position_ids = position_ids.view(-1, seq_length).long()
+
+         hidden_states = inputs_embeds
+
+         if self.gradient_checkpointing and self.training:
+             if use_cache:
+                 use_cache = False
+
+         # decoder layers
+         all_hidden_states = () if output_hidden_states else None
+         all_self_attns = () if output_attentions else None
+         next_decoder_cache = () if use_cache else None
+
+         for idx, decoder_layer in enumerate(self.layers):
+             if output_hidden_states:
+                 all_hidden_states += (hidden_states,)
+
+             past_key_value = (
+                 past_key_values[idx] if past_key_values is not None else None
+             )
+
+             if self.gradient_checkpointing and self.training:
+                 raise NotImplementedError
+
+                 def create_custom_forward(module):
+                     def custom_forward(*inputs):
+                         # None for past_key_value
+                         return module(*inputs, output_attentions, None)
+
+                     return custom_forward
+
+                 layer_outputs = torch.utils.checkpoint.checkpoint(
+                     create_custom_forward(decoder_layer),
+                     hidden_states,
+                     attention_mask,
+                     position_ids,
+                     None,
+                 )
+             else:
+                 layer_outputs = decoder_layer(
+                     hidden_states,
+                     attention_mask=attention_mask,
+                     position_ids=position_ids,
+                     past_key_value=past_key_value,
+                     output_attentions=output_attentions,
+                     use_cache=use_cache,
+                     cond_embedding=diffusion_step,
+                 )
+
+             hidden_states = layer_outputs[0]
+
+             if use_cache:
+                 next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+             if output_attentions:
+                 all_self_attns += (layer_outputs[1],)
+
+         hidden_states = self.norm(hidden_states, cond_embedding=diffusion_step)
+
+         # add hidden states from the last decoder layer
+         if output_hidden_states:
+             all_hidden_states += (hidden_states,)
+
+         next_cache = next_decoder_cache if use_cache else None
+
+         hidden_states = self.mel_out_mlp(hidden_states)
+         if not output_hidden_states:
+             return hidden_states
+         else:
+             return {
+                 "hidden_states": hidden_states,
+                 "all_hidden_states": all_hidden_states,
+             }
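
A quick shape check for the two building blocks above, illustrative only: the sinusoidal embedding lifts scalar timesteps to vectors, and the adaptive RMSNorm turns that vector into per-channel gains through a zero-initialized linear (so it starts out as a plain RMSNorm):

import torch
from anyaccomp.llama_nar import SinusoidalPosEmb, LlamaAdaptiveRMSNorm

t = torch.rand(4)  # (B,) diffusion timesteps in [0, 1)
emb = SinusoidalPosEmb(1024)(t)
print(emb.shape)  # torch.Size([4, 1024])

norm = LlamaAdaptiveRMSNorm(hidden_size=1024, dim_cond=1024)
h = torch.randn(4, 50, 1024)  # (B, T, C) hidden states
print(norm(h, cond_embedding=emb).shape)  # torch.Size([4, 50, 1024])
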
config/flow_matching.json ADDED
@@ -0,0 +1,74 @@
+ {
+     "model_type": "Sing2SongNoText",
+     "preprocess": {
+         "use_mel": true,
+         "sample_rate": 24000,
+         "n_fft": 1920,
+         "num_mels": 128,
+         "sampling_rate": 24000,
+         "hop_size": 480,
+         "hop_size_vocal": 480,
+         "hop_size_accompaniment": 480,
+         "win_size": 1920,
+         "fmin": 0,
+         "fmax": 12000,
+         "mel_var": 8.14,
+         "mel_mean": -4.92,
+
+         "chromagram": {
+             "naive": true,
+             "hop_size": 480,
+             "sample_rate": 24000,
+             "n_fft": 1920,
+             "num_mels": 128,
+             "win_size": 1920,
+             "fmin": 0,
+             "fmax": 12000,
+             "mel_var": 8.14,
+             "mel_mean": -4.92,
+             "f0_fmin": 50.0,
+             "f0_fmax": 1100.0
+         }
+     },
+     "model": {
+         "flow_matching_transformer": {
+             "vocab_size": 512,
+             "use_cond_code": true,
+             "mel_dim": 128,
+             "cond_dim": 768,
+             "hidden_size": 1024,
+             "num_layers": 10,
+             "num_heads": 16,
+             "cfg_scale": 0.2,
+             "prompt_prob": 0.,
+             "use_pretrained_model": false,
+             "sigma": 1e-5,
+             "time_scheduler": "cos",
+             "repa_loss": {
+                 "enable": true,
+                 "weight": 0.5,
+                 "repa_layer": 4,
+             },
+             "flash_attention": false,
+         },
+         "coco": {
+             "coco_type": "style", // content, style, or content_style
+             "downsample_rate": 1, // The original frame rate is 50 Hz, downsample to 6.25 Hz
+             "codebook_size": 512,
+             "hidden_size": 1024, // Representations Dim
+             "codebook_dim": 8,
+             "encoder": {
+                 "vocos_dim": 384,
+                 "vocos_intermediate_dim": 2048,
+                 "vocos_num_layers": 12,
+             },
+             "decoder": {
+                 "vocos_dim": 384,
+                 "vocos_intermediate_dim": 2048,
+                 "vocos_num_layers": 12,
+             },
+             "chromagram_dim": 24,
+             "pretrained_path": "./pretrained/vq"
+         },
+     },
+ }
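
Note that this file uses // comments and trailing commas, so it is JSON5 rather than strict JSON; the repo's load_config is presumed to tolerate that, but a plain json.load would not. A reading sketch, assuming the json5 package is available:

import json5

with open("config/flow_matching.json") as f:
    cfg = json5.load(f)

# sample_rate / hop_size = 24000 / 480 = 50 mel frames per second,
# matching the "hardcode 50Hz 20s" default in fmt_model.reverse_diffusion.
print(cfg["preprocess"]["sample_rate"] / cfg["preprocess"]["hop_size"])  # 50.0
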
config/vocoder.json ADDED
@@ -0,0 +1,52 @@
+ {
+     "model_type": "Vocoder",
+     "preprocess": {
+         "hop_size": 480,
+         "sample_rate": 24000,
+         "max_length": 36000,
+         "n_fft": 1920,
+         "num_mels": 128,
+         "win_size": 1920,
+         "fmin": 0,
+         "fmax": 12000,
+         "mel_var": 8.14,
+         "mel_mean": -4.92,
+         "processed_dir": "",
+         "valid_file": "valid.json",
+         "train_file": "train.json",
+         "use_phone_cond": false,
+         "use_emilia_101k": false
+     },
+     "model": {
+         "vocos": {
+             "input_channels": 128,
+             "dim": 1024,
+             "intermediate_dim": 4096,
+             "num_layers": 30,
+             "n_fft": 1920,
+             "hop_size": 480,
+             "padding": "same"
+         },
+         "period_gan": {
+             "max_downsample_channels": 1024,
+             "channels": 64,
+             "channel_increasing_factor": 2
+         },
+         "spec_gan": {
+             "stft_params": {
+                 "fft_sizes": [128, 256, 512, 1024, 2048],
+                 "hop_sizes": [32, 64, 128, 256, 512],
+                 "win_lengths": [128, 256, 512, 1024, 2048],
+                 "window": "hann_window"
+             },
+             "in_channels": 1,
+             "out_channels": 1,
+             "channels": 64,
+             "kernel_sizes": [5, 3],
+             "max_downsample_channels": 1024,
+             "down_scales": [2, 2, 2],
+             "use_weight_norm": true,
+             "use_complex": false
+         }
+     },
+ }
example/gradio/example1.mp3 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2860cf5b49b861f0770805c8cdda5b61276abc3931bb11140a5d6fa451418130
+ size 384580
example/gradio/example2.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4a16962f991ca69af95c79f39050017bd14e6dfd2c11b7547a10cd2b123b5ea6
+ size 2646044
example/gradio/example3.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7a96d209d00b50cb0b2ea0d985d4d94a5ea29b20874e9f91f4ed15235d4018ec
+ size 2646044
models/__init__.py ADDED
File without changes
models/codec/__init__.py ADDED
File without changes
models/codec/amphion_codec/.DS_Store ADDED
Binary file (6.15 kB).
 
models/codec/amphion_codec/quantize/__init__.py ADDED
@@ -0,0 +1,11 @@
+ # Copyright (c) 2024 Amphion.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from models.codec.amphion_codec.quantize.factorized_vector_quantize import (
+     FactorizedVectorQuantize,
+ )
+ from models.codec.amphion_codec.quantize.vector_quantize import VectorQuantize
+ from models.codec.amphion_codec.quantize.lookup_free_quantize import LookupFreeQuantize
+ from models.codec.amphion_codec.quantize.residual_vq import ResidualVQ
models/codec/amphion_codec/quantize/factorized_vector_quantize.py ADDED
@@ -0,0 +1,150 @@
+ # Copyright (c) 2024 Amphion.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from einops import rearrange
+ from torch.nn.utils import weight_norm
+
+
+ def WNConv1d(*args, **kwargs):
+     return weight_norm(nn.Conv1d(*args, **kwargs))
+
+
+ def WNConvTranspose1d(*args, **kwargs):
+     return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
+
+
+ class FactorizedVectorQuantize(nn.Module):
+     def __init__(
+         self,
+         input_dim,
+         codebook_size,
+         codebook_dim,
+         commitment=0.005,
+         codebook_loss_weight=1.0,
+         use_l2_normlize=True,
+     ):
+         super().__init__()
+         self.input_dim = input_dim
+         self.codebook_size = codebook_size
+         self.codebook_dim = codebook_dim
+         self.commitment = commitment
+         self.codebook_loss_weight = codebook_loss_weight
+         self.use_l2_normlize = use_l2_normlize
+
+         if self.input_dim != self.codebook_dim:
+             self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1)
+             self.out_project = WNConv1d(
+                 self.codebook_dim, self.input_dim, kernel_size=1
+             )
+
+         else:
+             self.in_project = nn.Identity()
+             self.out_project = nn.Identity()
+
+         self.codebook = nn.Embedding(self.codebook_size, self.codebook_dim)
+
+     def forward(self, z):
+         """
+         Parameters
+         ----------
+         z: torch.Tensor[B x D x T]
+
+         Returns
+         -------
+         z_q: torch.Tensor[B x D x T]
+             Quantized continuous representation of input
+         commit_loss: Tensor[B]
+             Commitment loss to train encoder to predict vectors closer to codebook entries
+         codebook_loss: Tensor[B]
+             Codebook loss to update the codebook
+         indices: torch.Tensor[B x T]
+             Codebook indices (quantized discrete representation of input)
+         z_e: torch.Tensor[B x D x T]
+             Projected latents (continuous representation of input before quantization)
+         """
+
+         # Factorized codes project input into low-dimensional space if self.input_dim != self.codebook_dim
+         z_e = self.in_project(z)
+         z_q, indices = self.decode_latents(z_e)
+
+         # Compute commitment loss and codebook loss
+         if self.training:
+             commit_loss = (
+                 F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
+                 * self.commitment
+             )
+             codebook_loss = (
+                 F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
+                 * self.codebook_loss_weight
+             )
+         else:
+             commit_loss = torch.zeros(z.shape[0], device=z.device)
+             codebook_loss = torch.zeros(z.shape[0], device=z.device)
+
+         z_q = z_e + (z_q - z_e).detach()
+
+         z_q = self.out_project(z_q)
+
+         return z_q, commit_loss, codebook_loss, indices, z_e
+
+     def embed_code(self, embed_id):
+         return F.embedding(embed_id, self.codebook.weight)
+
+     def decode_code(self, embed_id):
+         return self.embed_code(embed_id).transpose(1, 2)
+
+     def decode_latents(self, latents):
+         encodings = rearrange(latents, "b d t -> (b t) d")
+         codebook = self.codebook.weight
+
+         # L2 normalize encodings and codebook
+         if self.use_l2_normlize:
+             encodings = F.normalize(encodings)
+             codebook = F.normalize(codebook)
+
+         # Compute euclidean distance between encodings and codebook,
+         # if use_l2_normlize is True, the distance is equal to cosine distance
+         dist = (
+             encodings.pow(2).sum(1, keepdim=True)
+             - 2 * encodings @ codebook.t()
+             + codebook.pow(2).sum(1, keepdim=True).t()
+         )
+         indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
+         z_q = self.decode_code(indices)
+
+         return z_q, indices
+
+     def vq2emb(self, vq, out_proj=True):
+         emb = self.decode_code(vq)
+         if out_proj:
+             emb = self.out_project(emb)
+         return emb
+
+     def latent2dist(self, latents):
+         encodings = rearrange(latents, "b d t -> (b t) d")
+         codebook = self.codebook.weight
+
+         # L2 normalize encodings and codebook
+         if self.use_l2_normlize:
+             encodings = F.normalize(encodings)
+             codebook = F.normalize(codebook)
+
+         # Compute euclidean distance between encodings and codebook,
+         # if use_l2_normlize is True, the distance is equal to cosine distance
+         dist = (
+             encodings.pow(2).sum(1, keepdim=True)
+             - 2 * encodings @ codebook.t()
+             + codebook.pow(2).sum(1, keepdim=True).t()
+         )  # (b*t, k)
+
+         indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
+         dist = rearrange(dist, "(b t) k -> b t k", b=latents.size(0))
+         z_q = self.decode_code(indices)
+
+         return -dist, indices, z_q
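
A standalone shape check for the quantizer above (the values mirror the coco block in config/flow_matching.json; the example itself is illustrative, not part of the commit): 1024-d latents are projected into an 8-d codebook space, matched against 512 codes, and projected back:

import torch
from models.codec.amphion_codec.quantize import FactorizedVectorQuantize

vq = FactorizedVectorQuantize(input_dim=1024, codebook_size=512, codebook_dim=8)
z = torch.randn(2, 1024, 50)  # (B, D, T)
z_q, commit_loss, codebook_loss, indices, z_e = vq(z)
print(z_q.shape, indices.shape)  # torch.Size([2, 1024, 50]) torch.Size([2, 50])
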
models/codec/amphion_codec/quantize/lookup_free_quantize.py ADDED
@@ -0,0 +1,77 @@
+ # Copyright (c) 2024 Amphion.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from einops import rearrange
+ from torch.nn.utils import weight_norm
+
+
+ def WNConv1d(*args, **kwargs):
+     return weight_norm(nn.Conv1d(*args, **kwargs))
+
+
+ def WNConvTranspose1d(*args, **kwargs):
+     return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
+
+
+ class LookupFreeQuantize(nn.Module):
+     def __init__(
+         self,
+         input_dim,
+         codebook_size,
+         codebook_dim,
+     ):
+         super().__init__()
+         self.input_dim = input_dim
+         self.codebook_size = codebook_size
+         self.codebook_dim = codebook_dim
+
+         assert 2**codebook_dim == codebook_size
+
+         if self.input_dim != self.codebook_dim:
+             self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1)
+             self.out_project = WNConv1d(
+                 self.codebook_dim, self.input_dim, kernel_size=1
+             )
+
+         else:
+             self.in_project = nn.Identity()
+             self.out_project = nn.Identity()
+
+     def forward(self, z):
+         z_e = self.in_project(z)
+         z_e = F.sigmoid(z_e)
+
+         z_q = z_e + (torch.round(z_e) - z_e).detach()
+
+         z_q = self.out_project(z_q)
+
+         commit_loss = torch.zeros(z.shape[0], device=z.device)
+         codebook_loss = torch.zeros(z.shape[0], device=z.device)
+
+         bits = (
+             2
+             ** torch.arange(self.codebook_dim, device=z.device)
+             .unsqueeze(0)
+             .unsqueeze(-1)
+             .long()
+         )  # (1, d, 1)
+         indices = (torch.round(z_e.clone().detach()).long() * bits).sum(1).long()
+
+         return z_q, commit_loss, codebook_loss, indices, z_e
+
+     def vq2emb(self, vq, out_proj=True):
+         emb = torch.zeros(
+             vq.shape[0], self.codebook_dim, vq.shape[-1], device=vq.device
+         )  # (B, d, T)
+         for i in range(self.codebook_dim):
+             emb[:, i, :] = (vq % 2).float()
+             vq = vq // 2
+         if out_proj:
+             emb = self.out_project(emb)
+         return emb
models/codec/amphion_codec/quantize/residual_vq.py ADDED
@@ -0,0 +1,177 @@
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from typing import Union
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from einops import rearrange
13
+ from torch.nn.utils import weight_norm
14
+
15
+ from models.codec.amphion_codec.quantize.factorized_vector_quantize import (
16
+ FactorizedVectorQuantize,
17
+ )
18
+ from models.codec.amphion_codec.quantize.vector_quantize import VectorQuantize
19
+ from models.codec.amphion_codec.quantize.lookup_free_quantize import LookupFreeQuantize
20
+
21
+
22
+ class ResidualVQ(nn.Module):
23
+ """
24
+ Introduced in SoundStream: An end2end neural audio codec
25
+ https://arxiv.org/abs/2107.03312
26
+ """
27
+
28
+ def __init__(
29
+ self,
30
+ input_dim: int = 256,
31
+ num_quantizers: int = 8,
32
+ codebook_size: int = 1024,
33
+ codebook_dim: int = 256,
34
+ quantizer_type: str = "vq", # "vq" or "fvq" or "lfq"
35
+ quantizer_dropout: float = 0.5,
36
+ **kwargs,
37
+ ):
38
+ super().__init__()
39
+
40
+ self.input_dim = input_dim
41
+ self.num_quantizers = num_quantizers
42
+ self.codebook_size = codebook_size
43
+ self.codebook_dim = codebook_dim
44
+ self.quantizer_type = quantizer_type
45
+ self.quantizer_dropout = quantizer_dropout
46
+
47
+ if quantizer_type == "vq":
48
+ VQ = VectorQuantize
49
+ elif quantizer_type == "fvq":
50
+ VQ = FactorizedVectorQuantize
51
+ elif quantizer_type == "lfq":
52
+ VQ = LookupFreeQuantize
53
+ else:
54
+ raise ValueError(f"Unknown quantizer type {quantizer_type}")
55
+
56
+ self.quantizers = nn.ModuleList(
57
+ [
58
+ VQ(
59
+ input_dim=input_dim,
60
+ codebook_size=codebook_size,
61
+ codebook_dim=codebook_dim,
62
+ **kwargs,
63
+ )
64
+ for _ in range(num_quantizers)
65
+ ]
66
+ )
67
+
68
+ def forward(self, z, n_quantizers: int = None):
69
+ """
70
+ Parameters
71
+ ----------
72
+ z : Tensor[B x D x T]
73
+ n_quantizers : int, optional
74
+ No. of quantizers to use
75
+ (n_quantizers < self.n_codebooks ex: for quantizer dropout)
76
+ Note: if `self.quantizer_dropout` is True, this argument is ignored
77
+ when in training mode, and a random number of quantizers is used.
78
+ Returns
79
+ -------
80
+ "quantized_out" : Tensor[B x D x T]
81
+ Quantized continuous representation of input
82
+ "all_indices" : Tensor[N x B x T]
83
+ Codebook indices for each codebook
84
+ (quantized discrete representation of input)
85
+ "all_commit_losses" : Tensor[N]
86
+ "all_codebook_losses" : Tensor[N]
87
+ "all_quantized" : Tensor[N x B x D x T]
88
+ """
89
+
90
+ quantized_out = 0.0
91
+ residual = z
92
+
93
+ all_commit_losses = []
94
+ all_codebook_losses = []
95
+ all_indices = []
96
+ all_quantized = []
97
+
98
+ if n_quantizers is None:
99
+ n_quantizers = self.num_quantizers
100
+
101
+ if self.training:
102
+ n_quantizers = torch.ones((z.shape[0],)) * self.num_quantizers + 1
103
+ dropout = torch.randint(1, self.num_quantizers + 1, (z.shape[0],))
104
+ n_dropout = int(z.shape[0] * self.quantizer_dropout)
105
+ n_quantizers[:n_dropout] = dropout[:n_dropout]
106
+ n_quantizers = n_quantizers.to(z.device)
107
+
108
+ for i, quantizer in enumerate(self.quantizers):
109
+ if self.training is False and i >= n_quantizers:
110
+ break
111
+
112
+ z_q_i, commit_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
113
+ residual
114
+ )
115
+
116
+ # Create mask to apply quantizer dropout
117
+ mask = (
118
+ torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
119
+ )
120
+ quantized_out = quantized_out + z_q_i * mask[:, None, None]
121
+ residual = residual - z_q_i
122
+
123
+ commit_loss_i = (commit_loss_i * mask).mean()
124
+ codebook_loss_i = (codebook_loss_i * mask).mean()
125
+
126
+ all_commit_losses.append(commit_loss_i)
127
+ all_codebook_losses.append(codebook_loss_i)
128
+ all_indices.append(indices_i)
129
+ all_quantized.append(z_q_i)
130
+
131
+ all_commit_losses, all_codebook_losses, all_indices, all_quantized = map(
132
+ torch.stack,
133
+ (all_commit_losses, all_codebook_losses, all_indices, all_quantized),
134
+ )
135
+
136
+ return (
137
+ quantized_out,
138
+ all_indices,
139
+ all_commit_losses,
140
+ all_codebook_losses,
141
+ all_quantized,
142
+ )
143
+
144
+ def vq2emb(self, vq, n_quantizers=None):
145
+ quantized_out = 0.0
146
+ if n_quantizers is None:
147
+ n_quantizers = self.num_quantizers
148
+ for idx, quantizer in enumerate(self.quantizers):
149
+ if idx >= n_quantizers:
150
+ break
151
+ quantized_out += quantizer.vq2emb(vq[idx])
152
+ return quantized_out
153
+
154
+ def latent2dist(self, z, n_quantizers=None):
155
+ quantized_out = 0.0
156
+ residual = z
157
+
158
+ all_dists = []
159
+ all_indices = []
160
+
161
+ if n_quantizers is None:
162
+ n_quantizers = self.num_quantizers
163
+
164
+ for i, quantizer in enumerate(self.quantizers):
165
+ if self.training is False and i >= n_quantizers:
166
+ break
167
+ dist_i, indices_i, z_q_i = quantizer.latent2dist(residual)
168
+ all_dists.append(dist_i)
169
+ all_indices.append(indices_i)
170
+
171
+ quantized_out = quantized_out + z_q_i
172
+ residual = residual - z_q_i
173
+
174
+ all_dists = torch.stack(all_dists)
175
+ all_indices = torch.stack(all_indices)
176
+
177
+ return all_dists, all_indices
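For reference, a minimal usage sketch of the residual quantizer above. The constructor kwargs mirror the call site in models/codec/coco/rep_coco_model.py further down; treat them as one illustrative configuration, not canonical defaults:

import torch
from models.codec.amphion_codec.quantize import ResidualVQ

rvq = ResidualVQ(
    input_dim=1024,
    num_quantizers=1,
    codebook_size=8192,
    codebook_dim=8,
    quantizer_type="fvq",
    quantizer_dropout=0.0,
    commitment=0.15,
    codebook_loss_weight=1.0,
    use_l2_normlize=True,
)
z = torch.randn(2, 1024, 150)  # [B, D, T] encoder latents
quantized_out, all_indices, commit_losses, codebook_losses, all_quantized = rvq(z)
# quantized_out: [2, 1024, 150]; all_indices: [1, 2, 150] (single codebook here)

In eval mode the loop applies the first n_quantizers codebooks and stops; in training mode the quantizer-dropout branch draws a per-sample cutoff instead, which is why the commit/codebook losses are masked per sample.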
models/codec/amphion_codec/quantize/vector_quantize.py ADDED
@@ -0,0 +1,401 @@
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from einops import rearrange, repeat
11
+ from torch.nn.utils import weight_norm
12
+
13
+
14
+ def WNConv1d(*args, **kwargs):
15
+ return weight_norm(nn.Conv1d(*args, **kwargs))
16
+
17
+
18
+ def WNConvTranspose1d(*args, **kwargs):
19
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
20
+
21
+
22
+ def l2norm(t):
23
+ return F.normalize(t, p=2, dim=-1)
24
+
25
+
26
+ def ema_inplace(moving_avg, new, decay):
27
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
28
+
29
+
30
+ def laplace_smoothing(x, n_categories, eps=1e-5):
31
+ return (x + eps) / (x.sum() + n_categories * eps)
32
+
33
+
34
+ def sample_vectors(samples, num):
35
+ num_samples, device = samples.shape[0], samples.device
36
+
37
+ if num_samples >= num:
38
+ indices = torch.randperm(num_samples, device=device)[:num]
39
+ else:
40
+ indices = torch.randint(0, num_samples, (num,), device=device)
41
+
42
+ return samples[indices]
43
+
44
+
45
+ def kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False):
46
+ dim, dtype, device = samples.shape[-1], samples.dtype, samples.device
47
+
48
+ means = sample_vectors(samples, num_clusters)
49
+
50
+ for _ in range(num_iters):
51
+ if use_cosine_sim:
52
+ dists = samples @ means.t()
53
+ else:
54
+ diffs = rearrange(samples, "n d -> n () d") - rearrange(
55
+ means, "c d -> () c d"
56
+ )
57
+ dists = -(diffs**2).sum(dim=-1)
58
+
59
+ buckets = dists.max(dim=-1).indices
60
+ bins = torch.bincount(buckets, minlength=num_clusters)
61
+ zero_mask = bins == 0
62
+ bins_min_clamped = bins.masked_fill(zero_mask, 1)
63
+
64
+ new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
65
+ new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
66
+ new_means = new_means / bins_min_clamped[..., None]
67
+
68
+ if use_cosine_sim:
69
+ new_means = l2norm(new_means)
70
+
71
+ means = torch.where(zero_mask[..., None], means, new_means)
72
+
73
+ return means, bins
74
+
75
+
76
+ class EuclideanCodebook(nn.Module):
77
+ def __init__(
78
+ self,
79
+ dim,
80
+ codebook_size,
81
+ kmeans_init=False,
82
+ kmeans_iters=10,
83
+ decay=0.8,
84
+ eps=1e-5,
85
+ threshold_ema_dead_code=2,
86
+ weight_init=False,
87
+ ):
88
+ super().__init__()
89
+
90
+ self.decay = decay
91
+ init_fn = torch.randn if not weight_init else torch.zeros
92
+ embed = init_fn(codebook_size, dim)
93
+
94
+ if weight_init:
95
+ nn.init.uniform_(embed, -1 / codebook_size, 1 / codebook_size)
96
+
97
+ self.codebook_size = codebook_size
98
+ self.kmeans_iters = kmeans_iters
99
+ self.eps = eps
100
+ self.threshold_ema_dead_code = threshold_ema_dead_code
101
+
102
+ self.register_buffer(
103
+ "initted", torch.Tensor([not kmeans_init])
104
+ ) # if kmeans_init is True, then initted is False; otherwise, initted is True
105
+ self.register_buffer("cluster_size", torch.zeros(codebook_size))
106
+ self.register_buffer("embed", embed)
107
+ self.register_buffer("embed_avg", embed.clone())
108
+
109
+ def init_embed_(self, data):
110
+ embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
111
+ self.embed.data.copy_(embed)
112
+ self.embed_avg.data.copy_(embed)
113
+ self.cluster_size.data.copy_(cluster_size)
114
+ self.initted.data.copy_(torch.Tensor([True]))
115
+
116
+ def replace(self, samples, mask):
117
+ modified_codebook = torch.where(
118
+ mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
119
+ )
120
+ self.embed.data.copy_(modified_codebook)
121
+
122
+ def expire_codes_(self, batch_samples):
123
+ if self.threshold_ema_dead_code == 0:
124
+ return
125
+
126
+ expired_codes = self.cluster_size < self.threshold_ema_dead_code
127
+ if not torch.any(expired_codes):
128
+ return
129
+ batch_samples = rearrange(batch_samples, "... d -> (...) d")
130
+ self.replace(batch_samples, mask=expired_codes)
131
+
132
+ def forward(self, x):
133
+ shape, dtype = x.shape, x.dtype
134
+ flatten = rearrange(x, "... d -> (...) d")
135
+ embed = self.embed.t() # (codebook_size, dim) -> (dim, codebook_size)
136
+
137
+ if not self.initted:
138
+ self.init_embed_(flatten)
139
+
140
+ dist = -(
141
+ flatten.pow(2).sum(1, keepdim=True)
142
+ - 2 * flatten @ embed
143
+ + embed.pow(2).sum(0, keepdim=True)
144
+ )
145
+
146
+ embed_ind = dist.max(dim=-1).indices
147
+ embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
148
+ embed_ind = embed_ind.view(*shape[:-1])
149
+ quantize = F.embedding(embed_ind, self.embed)
150
+
151
+ if self.training:
152
+ ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
153
+ embed_sum = (
154
+ flatten.t() @ embed_onehot
155
+ ) # (dim, ...) @ (..., codebook_size) -> (dim, codebook_size)
156
+ ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
157
+ cluster_size = (
158
+ laplace_smoothing(self.cluster_size, self.codebook_size, self.eps)
159
+ * self.cluster_size.sum()
160
+ )
161
+ embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
162
+ self.embed.data.copy_(embed_normalized)
163
+ self.expire_codes_(x)
164
+
165
+ return quantize, embed_ind
166
+
167
+ def vq2emb(self, vq):
168
+ quantize = F.embedding(vq, self.embed)
169
+ return quantize
170
+
171
+ def latent2dist(self, x):
172
+ shape, dtype = x.shape, x.dtype
173
+ flatten = rearrange(x, "... d -> (...) d")
174
+ embed = self.embed.t() # (codebook_size, dim) -> (dim, codebook_size)
175
+
176
+ if not self.initted:
177
+ self.init_embed_(flatten)
178
+
179
+ dist = -(
180
+ flatten.pow(2).sum(1, keepdim=True)
181
+ - 2 * flatten @ embed
182
+ + embed.pow(2).sum(0, keepdim=True)
183
+ )
184
+
185
+ embed_ind = dist.max(dim=-1).indices
186
+ embed_ind = embed_ind.view(*shape[:-1])
187
+ quantize = F.embedding(embed_ind, self.embed)
188
+
189
+ dist = dist.view(*shape[:-1], -1)
190
+
191
+ return dist, embed_ind, quantize
192
+
193
+
194
+ class SimpleCodebook(nn.Module):
195
+ def __init__(
196
+ self,
197
+ dim,
198
+ codebook_size,
199
+ use_l2_normlize=False,
200
+ ):
201
+ super().__init__()
202
+
203
+ self.dim = dim
204
+ self.codebook_size = codebook_size
205
+ self.use_l2_normlize = use_l2_normlize
206
+
207
+ self.embed = nn.Embedding(self.codebook_size, self.dim)
208
+
209
+ def forward(self, x):
210
+ shape, dtype = x.shape, x.dtype
211
+ flatten = rearrange(x, "... d -> (...) d")
212
+ embed = self.embed.weight.t() # (codebook_size, dim) -> (dim, codebook_size)
213
+
214
+ if self.use_l2_normlize:
215
+ flatten = F.normalize(flatten)
216
+ embed = F.normalize(embed)
217
+
218
+ dist = -(
219
+ flatten.pow(2).sum(1, keepdim=True)
220
+ - 2 * flatten @ embed
221
+ + embed.pow(2).sum(0, keepdim=True)
222
+ )
223
+
224
+ embed_ind = dist.max(dim=-1).indices
225
+ embed_ind = embed_ind.view(*shape[:-1])
226
+ quantize = F.embedding(embed_ind, self.embed.weight)  # embed is an nn.Embedding; F.embedding needs its weight tensor
227
+
228
+ return quantize, embed_ind
229
+
230
+ def vq2emb(self, vq):
231
+ quantize = F.embedding(vq, self.embed.weight)
232
+ return quantize
233
+
234
+ def latent2dist(self, x):
235
+ shape, dtype = x.shape, x.dtype
236
+ flatten = rearrange(x, "... d -> (...) d")
237
+ embed = self.embed.weight.t() # (codebook_size, dim) -> (dim, codebook_size)
238
+
239
+ if self.use_l2_normlize:
240
+ flatten = F.normalize(flatten)
241
+ embed = F.normalize(embed)
242
+
243
+ dist = -(
244
+ flatten.pow(2).sum(1, keepdim=True)
245
+ - 2 * flatten @ embed
246
+ + embed.pow(2).sum(0, keepdim=True)
247
+ )
248
+
249
+ embed_ind = dist.max(dim=-1).indices
250
+ embed_ind = embed_ind.view(*shape[:-1])
251
+ quantize = F.embedding(embed_ind, self.embed.weight)  # use the weight tensor, as in forward
252
+
253
+ dist = dist.view(*shape[:-1], -1)
254
+
255
+ return dist, embed_ind, quantize
256
+
257
+
258
+ class VectorQuantize(nn.Module):
259
+ """Vector quantization and factorized vecotor quantization implementation
260
+ Args:
261
+ input_dim (int): Dimension of input.
262
+ codebook_size (int): Codebook size.
263
+ codebook_dim (int): Codebook dimension. We suggest use codebook_dim = input_dim
264
+ if use codebook_type == "euclidean", otherwise, if you want to use
265
+ factorized vector quantization, use codebook_dim as small number (e.g. 8 or 32).
266
+ commitment (float): Weight for commitment loss.
267
+ use_l2_normlize (bool): Whether to use l2 normlized codes for factorized vecotor quantization,
268
+ we suggest use it as True if you want to use factorized vector quantization
269
+ kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
270
+ kmeans_iters (int): Number of iterations used for kmeans initialization.
271
+ decay (float): Decay for exponential moving average over the codebooks.
272
+ epsilon (float): Epsilon value for numerical stability.
273
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
274
+ that have an exponential moving average cluster size less than the specified threshold with
275
+ a randomly selected vector from the current batch.
276
+ """
277
+
278
+ def __init__(
279
+ self,
280
+ input_dim,
281
+ codebook_size,
282
+ codebook_dim,
283
+ commitment=0.005,
284
+ codebook_loss_weight=1.0,
285
+ use_l2_normlize=False,
286
+ codebook_type="euclidean", # "euclidean" or "simple"
287
+ kmeans_init=False,
288
+ kmeans_iters=10,
289
+ decay=0.8,
290
+ eps=1e-5,
291
+ threshold_ema_dead_code=2,
292
+ weight_init=False,
293
+ ):
294
+ super().__init__()
295
+ self.input_dim = input_dim
296
+ self.codebook_size = codebook_size
297
+ self.codebook_dim = codebook_dim
298
+ self.commitment = commitment
299
+ self.codebook_loss_weight = codebook_loss_weight
300
+ self.use_l2_normlize = use_l2_normlize
301
+ self.codebook_type = codebook_type
302
+ self.kmeans_init = kmeans_init
303
+ self.kmeans_iters = kmeans_iters
304
+ self.decay = decay
305
+ self.eps = eps
306
+ self.threshold_ema_dead_code = threshold_ema_dead_code
307
+ self.weight_init = weight_init
308
+
309
+ if self.input_dim != self.codebook_dim:
310
+ self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1)
311
+ self.out_project = WNConv1d(
312
+ self.codebook_dim, self.input_dim, kernel_size=1
313
+ )
314
+
315
+ else:
316
+ self.in_project = nn.Identity()
317
+ self.out_project = nn.Identity()
318
+
319
+ if self.codebook_type == "euclidean":
320
+ self.codebook = EuclideanCodebook(
321
+ self.codebook_dim,
322
+ codebook_size=self.codebook_size,
323
+ kmeans_init=self.kmeans_init,
324
+ kmeans_iters=self.kmeans_iters,
325
+ decay=self.decay,
326
+ eps=self.eps,
327
+ threshold_ema_dead_code=self.threshold_ema_dead_code,
328
+ weight_init=self.weight_init,
329
+ )
330
+ elif self.codebook_type == "simple":
331
+ self.codebook = SimpleCodebook(
332
+ self.codebook_dim,
333
+ codebook_size=self.codebook_size,
334
+ use_l2_normlize=self.use_l2_normlize,
335
+ )
336
+ else:
337
+ raise NotImplementedError(
338
+ f"codebook_type {self.codebook_type} is not implemented!"
339
+ )
340
+
341
+ def forward(self, z):
342
+ """
343
+ Parameters
344
+ ----------
345
+ z: torch.Tensor[B x D x T]
346
+
347
+ Returns
348
+ -------
349
+ z_q: torch.Tensor[B x D x T]
350
+ Quantized continuous representation of input
351
+ commit_loss: Tensor[B]
352
+ Commitment loss to train encoder to predict vectors closer to codebook entries
353
+ codebook_loss: Tensor[B]
354
+ Codebook loss to update the codebook
355
+ indices: torch.Tensor[B x T]
356
+ Codebook indices (quantized discrete representation of input)
357
+ z_e: torch.Tensor[B x D x T]
358
+ Projected latents (continuous representation of input before quantization)
359
+ """
360
+
361
+ # Factorized codes project input into low-dimensional space if self.input_dim != self.codebook_dim
362
+ z_e = self.in_project(z)
363
+ z_q, indices = self.decode_latents(z_e)
364
+
365
+ # Compute commitment loss and codebook loss
366
+ if self.training:
367
+ commit_loss = (
368
+ F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
369
+ * self.commitment
370
+ )
371
+ codebook_loss = (
372
+ F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
373
+ * self.codebook_loss_weight
374
+ )
375
+ else:
376
+ commit_loss = torch.zeros(z.shape[0], device=z.device)
377
+ codebook_loss = torch.zeros(z.shape[0], device=z.device)
378
+
379
+ z_q = z_e + (z_q - z_e).detach()
380
+
381
+ z_q = self.out_project(z_q)
382
+
383
+ return z_q, commit_loss, codebook_loss, indices, z_e
384
+
385
+ def decode_latents(self, latents):
386
+ encodings = rearrange(latents, "b d t -> b t d")
387
+ z_q, indices = self.codebook(encodings)
388
+ z_q = z_q.transpose(1, 2)
389
+ return z_q, indices
390
+
391
+ def vq2emb(self, vq, out_proj=True):
392
+ emb = self.codebook.vq2emb(vq)
393
+ emb = emb.transpose(1, 2)
394
+ if out_proj:
395
+ emb = self.out_project(emb)
396
+ return emb
397
+
398
+ def latent2dist(self, latents):
399
+ latents = rearrange(latents, "b d t -> b t d")
400
+ dist, embed_ind, quantize = self.codebook.latent2dist(latents)
401
+ return dist, embed_ind, quantize.transpose(1, 2)
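A quick shape sanity check for the quantizer above (the dimensions are illustrative only):

import torch

vq = VectorQuantize(input_dim=256, codebook_size=1024, codebook_dim=8)
z = torch.randn(4, 256, 100)  # [B, D, T]
z_q, commit_loss, codebook_loss, indices, z_e = vq(z)
# z_q: [4, 256, 100] (back in input_dim after out_project); indices: [4, 100]

Note the straight-through estimator in forward, z_q = z_e + (z_q - z_e).detach(): the forward value is the quantized code, while the gradient bypasses the non-differentiable nearest-neighbour lookup and flows directly into the encoder.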
models/codec/amphion_codec/vocos.py ADDED
@@ -0,0 +1,881 @@
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from typing import Optional, Tuple
7
+
8
+ import numpy as np
9
+ import scipy
10
+ import torch
11
+ from torch import nn, view_as_real, view_as_complex
13
+ from torch.nn.utils import weight_norm, remove_weight_norm
14
+ from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz
15
+ import librosa
16
+
17
+
18
+ def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
19
+ """
20
+ Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values.
21
+
22
+ Args:
23
+ x (Tensor): Input tensor.
24
+ clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7.
25
+
26
+ Returns:
27
+ Tensor: Element-wise logarithm of the input tensor with clipping applied.
28
+ """
29
+ return torch.log(torch.clip(x, min=clip_val))
30
+
31
+
32
+ def symlog(x: torch.Tensor) -> torch.Tensor:
33
+ return torch.sign(x) * torch.log1p(x.abs())
34
+
35
+
36
+ def symexp(x: torch.Tensor) -> torch.Tensor:
37
+ return torch.sign(x) * (torch.exp(x.abs()) - 1)
38
+
39
+
40
+ class STFT(nn.Module):
41
+ def __init__(
42
+ self,
43
+ n_fft: int,
44
+ hop_length: int,
45
+ win_length: int,
46
+ center=True,
47
+ ):
48
+ super().__init__()
49
+ self.center = center
50
+ self.n_fft = n_fft
51
+ self.hop_length = hop_length
52
+ self.win_length = win_length
53
+ window = torch.hann_window(win_length)
54
+ self.register_buffer("window", window)
55
+
56
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
57
+ # x: (B, T * hop_length)
58
+
59
+ if not self.center:
60
+ pad = self.win_length - self.hop_length
61
+ x = torch.nn.functional.pad(x, (pad // 2, pad // 2), mode="reflect")
62
+
63
+ stft_spec = torch.stft(
64
+ x,
65
+ self.n_fft,
66
+ hop_length=self.hop_length,
67
+ win_length=self.win_length,
68
+ window=self.window,
69
+ center=self.center,
70
+ return_complex=False,
71
+ ) # (B, n_fft // 2 + 1, T, 2)
72
+
73
+ rea = stft_spec[:, :, :, 0]  # (B, n_fft // 2 + 1, T)
74
+ imag = stft_spec[:, :, :, 1]  # (B, n_fft // 2 + 1, T)
75
+
76
+ log_mag = torch.log(
77
+ torch.abs(torch.sqrt(torch.pow(rea, 2) + torch.pow(imag, 2))) + 1e-5
78
+ ) # (B, n_fft // 2 + 1, T)
79
+ phase = torch.atan2(imag, rea) # (B, n_fft // 2 + 1, T)
80
+
81
+ return log_mag, phase
82
+
83
+
84
+ class ISTFT(nn.Module):
85
+ """
86
+ Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
87
+ windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
88
+ See issue: https://github.com/pytorch/pytorch/issues/62323
89
+ Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
90
+ The NOLA constraint is met as we trim padded samples anyway.
91
+
92
+ Args:
93
+ n_fft (int): Size of Fourier transform.
94
+ hop_length (int): The distance between neighboring sliding window frames.
95
+ win_length (int): The size of window frame and STFT filter.
96
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
97
+ """
98
+
99
+ def __init__(
100
+ self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"
101
+ ):
102
+ super().__init__()
103
+ if padding not in ["center", "same"]:
104
+ raise ValueError("Padding must be 'center' or 'same'.")
105
+ self.padding = padding
106
+ self.n_fft = n_fft
107
+ self.hop_length = hop_length
108
+ self.win_length = win_length
109
+ window = torch.hann_window(win_length)
110
+ self.register_buffer("window", window)
111
+
112
+ def forward(self, spec: torch.Tensor) -> torch.Tensor:
113
+ """
114
+ Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
115
+
116
+ Args:
117
+ spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
118
+ N is the number of frequency bins, and T is the number of time frames.
119
+
120
+ Returns:
121
+ Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
122
+ """
123
+ if self.padding == "center":
124
+ # Fallback to pytorch native implementation
125
+ return torch.istft(
126
+ spec,
127
+ self.n_fft,
128
+ self.hop_length,
129
+ self.win_length,
130
+ self.window,
131
+ center=True,
132
+ )
133
+ elif self.padding == "same":
134
+ pad = (self.win_length - self.hop_length) // 2
135
+ else:
136
+ raise ValueError("Padding must be 'center' or 'same'.")
137
+
138
+ assert spec.dim() == 3, "Expected a 3D tensor as input"
139
+ B, N, T = spec.shape
140
+
141
+ # Inverse FFT
142
+ ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
143
+ ifft = ifft * self.window[None, :, None]
144
+
145
+ # Overlap and Add
146
+ output_size = (T - 1) * self.hop_length + self.win_length
147
+ y = torch.nn.functional.fold(
148
+ ifft,
149
+ output_size=(1, output_size),
150
+ kernel_size=(1, self.win_length),
151
+ stride=(1, self.hop_length),
152
+ )[:, 0, 0, pad:-pad]
153
+
154
+ # Window envelope
155
+ window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
156
+ window_envelope = torch.nn.functional.fold(
157
+ window_sq,
158
+ output_size=(1, output_size),
159
+ kernel_size=(1, self.win_length),
160
+ stride=(1, self.hop_length),
161
+ ).squeeze()[pad:-pad]
162
+
163
+ # Normalize
164
+ assert (window_envelope > 1e-11).all()
165
+ y = y / window_envelope
166
+
167
+ return y
168
+
169
+
170
+ class MDCT(nn.Module):
171
+ """
172
+ Modified Discrete Cosine Transform (MDCT) module.
173
+
174
+ Args:
175
+ frame_len (int): Length of the MDCT frame.
176
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
177
+ """
178
+
179
+ def __init__(self, frame_len: int, padding: str = "same"):
180
+ super().__init__()
181
+ if padding not in ["center", "same"]:
182
+ raise ValueError("Padding must be 'center' or 'same'.")
183
+ self.padding = padding
184
+ self.frame_len = frame_len
185
+ N = frame_len // 2
186
+ n0 = (N + 1) / 2
187
+ window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
188
+ self.register_buffer("window", window)
189
+
190
+ pre_twiddle = torch.exp(-1j * torch.pi * torch.arange(frame_len) / frame_len)
191
+ post_twiddle = torch.exp(-1j * torch.pi * n0 * (torch.arange(N) + 0.5) / N)
192
+ # view_as_real: NCCL Backend does not support ComplexFloat data type
193
+ # https://github.com/pytorch/pytorch/issues/71613
194
+ self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
195
+ self.register_buffer("post_twiddle", view_as_real(post_twiddle))
196
+
197
+ def forward(self, audio: torch.Tensor) -> torch.Tensor:
198
+ """
199
+ Apply the Modified Discrete Cosine Transform (MDCT) to the input audio.
200
+
201
+ Args:
202
+ audio (Tensor): Input audio waveform of shape (B, T), where B is the batch size
203
+ and T is the length of the audio.
204
+
205
+ Returns:
206
+ Tensor: MDCT coefficients of shape (B, L, N), where L is the number of output frames
207
+ and N is the number of frequency bins.
208
+ """
209
+ if self.padding == "center":
210
+ audio = torch.nn.functional.pad(
211
+ audio, (self.frame_len // 2, self.frame_len // 2)
212
+ )
213
+ elif self.padding == "same":
214
+ # hop_length is 1/2 frame_len
215
+ audio = torch.nn.functional.pad(
216
+ audio, (self.frame_len // 4, self.frame_len // 4)
217
+ )
218
+ else:
219
+ raise ValueError("Padding must be 'center' or 'same'.")
220
+
221
+ x = audio.unfold(-1, self.frame_len, self.frame_len // 2)
222
+ N = self.frame_len // 2
223
+ x = x * self.window.expand(x.shape)
224
+ X = torch.fft.fft(
225
+ x * view_as_complex(self.pre_twiddle).expand(x.shape), dim=-1
226
+ )[..., :N]
227
+ res = X * view_as_complex(self.post_twiddle).expand(X.shape) * np.sqrt(1 / N)
228
+ return torch.real(res) * np.sqrt(2)
229
+
230
+
231
+ class IMDCT(nn.Module):
232
+ """
233
+ Inverse Modified Discrete Cosine Transform (IMDCT) module.
234
+
235
+ Args:
236
+ frame_len (int): Length of the MDCT frame.
237
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
238
+ """
239
+
240
+ def __init__(self, frame_len: int, padding: str = "same"):
241
+ super().__init__()
242
+ if padding not in ["center", "same"]:
243
+ raise ValueError("Padding must be 'center' or 'same'.")
244
+ self.padding = padding
245
+ self.frame_len = frame_len
246
+ N = frame_len // 2
247
+ n0 = (N + 1) / 2
248
+ window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
249
+ self.register_buffer("window", window)
250
+
251
+ pre_twiddle = torch.exp(1j * torch.pi * n0 * torch.arange(N * 2) / N)
252
+ post_twiddle = torch.exp(1j * torch.pi * (torch.arange(N * 2) + n0) / (N * 2))
253
+ self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
254
+ self.register_buffer("post_twiddle", view_as_real(post_twiddle))
255
+
256
+ def forward(self, X: torch.Tensor) -> torch.Tensor:
257
+ """
258
+ Apply the Inverse Modified Discrete Cosine Transform (IMDCT) to the input MDCT coefficients.
259
+
260
+ Args:
261
+ X (Tensor): Input MDCT coefficients of shape (B, L, N), where B is the batch size,
262
+ L is the number of frames, and N is the number of frequency bins.
263
+
264
+ Returns:
265
+ Tensor: Reconstructed audio waveform of shape (B, T), where T is the length of the audio.
266
+ """
267
+ B, L, N = X.shape
268
+ Y = torch.zeros((B, L, N * 2), dtype=X.dtype, device=X.device)
269
+ Y[..., :N] = X
270
+ Y[..., N:] = -1 * torch.conj(torch.flip(X, dims=(-1,)))
271
+ y = torch.fft.ifft(
272
+ Y * view_as_complex(self.pre_twiddle).expand(Y.shape), dim=-1
273
+ )
274
+ y = (
275
+ torch.real(y * view_as_complex(self.post_twiddle).expand(y.shape))
276
+ * np.sqrt(N)
277
+ * np.sqrt(2)
278
+ )
279
+ result = y * self.window.expand(y.shape)
280
+ output_size = (1, (L + 1) * N)
281
+ audio = torch.nn.functional.fold(
282
+ result.transpose(1, 2),
283
+ output_size=output_size,
284
+ kernel_size=(1, self.frame_len),
285
+ stride=(1, self.frame_len // 2),
286
+ )[:, 0, 0, :]
287
+
288
+ if self.padding == "center":
289
+ pad = self.frame_len // 2
290
+ elif self.padding == "same":
291
+ pad = self.frame_len // 4
292
+ else:
293
+ raise ValueError("Padding must be 'center' or 'same'.")
294
+
295
+ audio = audio[:, pad:-pad]
296
+ return audio
297
+
298
+
299
+ class FourierHead(nn.Module):
300
+ """Base class for inverse fourier modules."""
301
+
302
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
303
+ """
304
+ Args:
305
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
306
+ L is the sequence length, and H denotes the model dimension.
307
+
308
+ Returns:
309
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
310
+ """
311
+ raise NotImplementedError("Subclasses must implement the forward method.")
312
+
313
+
314
+ class ISTFTHead(FourierHead):
315
+ """
316
+ ISTFT Head module for predicting STFT complex coefficients.
317
+
318
+ Args:
319
+ dim (int): Hidden dimension of the model.
320
+ n_fft (int): Size of Fourier transform.
321
+ hop_length (int): The distance between neighboring sliding window frames, which should align with
322
+ the resolution of the input features.
323
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
324
+ """
325
+
326
+ def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
327
+ super().__init__()
328
+ out_dim = n_fft + 2
329
+ self.out = torch.nn.Linear(dim, out_dim)
330
+ self.istft = ISTFT(
331
+ n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding
332
+ )
333
+
334
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
335
+ """
336
+ Forward pass of the ISTFTHead module.
337
+
338
+ Args:
339
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
340
+ L is the sequence length, and H denotes the model dimension.
341
+
342
+ Returns:
343
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
344
+ """
345
+ x = self.out(x).transpose(1, 2)
346
+ mag, p = x.chunk(2, dim=1)
347
+ mag = torch.exp(mag)
348
+ mag = torch.clip(
349
+ mag, max=1e2
350
+ ) # safeguard to prevent excessively large magnitudes
351
+ # phase wrapping happens here; these two lines produce the real and imaginary parts
352
+ x = torch.cos(p)
353
+ y = torch.sin(p)
354
+ # recalculating phase here does not produce anything new
355
+ # only costs time
356
+ # phase = torch.atan2(y, x)
357
+ # S = mag * torch.exp(phase * 1j)
358
+ # better directly produce the complex value
359
+ S = mag * (x + 1j * y)
360
+ audio = self.istft(S)
361
+ return audio
362
+
363
+
364
+ class IMDCTSymExpHead(FourierHead):
365
+ """
366
+ IMDCT Head module for predicting MDCT coefficients with symmetric exponential function
367
+
368
+ Args:
369
+ dim (int): Hidden dimension of the model.
370
+ mdct_frame_len (int): Length of the MDCT frame.
371
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
372
+ sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized
373
+ based on perceptual scaling. Defaults to None.
374
+ clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
375
+ """
376
+
377
+ def __init__(
378
+ self,
379
+ dim: int,
380
+ mdct_frame_len: int,
381
+ padding: str = "same",
382
+ sample_rate: Optional[int] = None,
383
+ clip_audio: bool = False,
384
+ ):
385
+ super().__init__()
386
+ out_dim = mdct_frame_len // 2
387
+ self.out = nn.Linear(dim, out_dim)
388
+ self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
389
+ self.clip_audio = clip_audio
390
+
391
+ if sample_rate is not None:
392
+ # optionally init the last layer following mel-scale
393
+ m_max = _hz_to_mel(sample_rate // 2)
394
+ m_pts = torch.linspace(0, m_max, out_dim)
395
+ f_pts = _mel_to_hz(m_pts)
396
+ scale = 1 - (f_pts / f_pts.max())
397
+
398
+ with torch.no_grad():
399
+ self.out.weight.mul_(scale.view(-1, 1))
400
+
401
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
402
+ """
403
+ Forward pass of the IMDCTSymExpHead module.
404
+
405
+ Args:
406
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
407
+ L is the sequence length, and H denotes the model dimension.
408
+
409
+ Returns:
410
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
411
+ """
412
+ x = self.out(x)
413
+ x = symexp(x)
414
+ x = torch.clip(
415
+ x, min=-1e2, max=1e2
416
+ ) # safeguard to prevent excessively large magnitudes
417
+ audio = self.imdct(x)
418
+ if self.clip_audio:
419
+ audio = torch.clip(audio, min=-1.0, max=1.0)  # clip the reconstructed waveform, not the pre-IMDCT input
420
+
421
+ return audio
422
+
423
+
424
+ class IMDCTCosHead(FourierHead):
425
+ """
426
+ IMDCT Head module for predicting MDCT coefficients with parametrizing MDCT = exp(m) · cos(p)
427
+
428
+ Args:
429
+ dim (int): Hidden dimension of the model.
430
+ mdct_frame_len (int): Length of the MDCT frame.
431
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
432
+ clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
433
+ """
434
+
435
+ def __init__(
436
+ self,
437
+ dim: int,
438
+ mdct_frame_len: int,
439
+ padding: str = "same",
440
+ clip_audio: bool = False,
441
+ ):
442
+ super().__init__()
443
+ self.clip_audio = clip_audio
444
+ self.out = nn.Linear(dim, mdct_frame_len)
445
+ self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
446
+
447
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
448
+ """
449
+ Forward pass of the IMDCTCosHead module.
450
+
451
+ Args:
452
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
453
+ L is the sequence length, and H denotes the model dimension.
454
+
455
+ Returns:
456
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
457
+ """
458
+ x = self.out(x)
459
+ m, p = x.chunk(2, dim=2)
460
+ m = torch.exp(m).clip(
461
+ max=1e2
462
+ ) # safeguard to prevent excessively large magnitudes
463
+ audio = self.imdct(m * torch.cos(p))
464
+ if self.clip_audio:
465
+ audio = torch.clip(audio, min=-1.0, max=1.0)  # clip the reconstructed waveform, not the raw projection
466
+ return audio
467
+
468
+
469
+ class ConvNeXtBlock(nn.Module):
470
+ """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
471
+
472
+ Args:
473
+ dim (int): Number of input channels.
474
+ intermediate_dim (int): Dimensionality of the intermediate layer.
475
+ layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
476
+ Defaults to None.
477
+ adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
478
+ None means non-conditional LayerNorm. Defaults to None.
479
+ """
480
+
481
+ def __init__(
482
+ self,
483
+ dim: int,
484
+ intermediate_dim: int,
485
+ layer_scale_init_value: float,
486
+ adanorm_num_embeddings: Optional[int] = None,
487
+ ):
488
+ super().__init__()
489
+ self.dwconv = nn.Conv1d(
490
+ dim, dim, kernel_size=7, padding=3, groups=dim
491
+ ) # depthwise conv
492
+ self.adanorm = adanorm_num_embeddings is not None
493
+ if adanorm_num_embeddings:
494
+ self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
495
+ else:
496
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
497
+ self.pwconv1 = nn.Linear(
498
+ dim, intermediate_dim
499
+ ) # pointwise/1x1 convs, implemented with linear layers
500
+ self.act = nn.GELU()
501
+ self.pwconv2 = nn.Linear(intermediate_dim, dim)
502
+ self.gamma = (
503
+ nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
504
+ if layer_scale_init_value > 0
505
+ else None
506
+ )
507
+
508
+ def forward(
509
+ self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
510
+ ) -> torch.Tensor:
511
+ residual = x
512
+ x = self.dwconv(x)
513
+ x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
514
+ if self.adanorm:
515
+ assert cond_embedding_id is not None
516
+ x = self.norm(x, cond_embedding_id)
517
+ else:
518
+ x = self.norm(x)
519
+ x = self.pwconv1(x)
520
+ x = self.act(x)
521
+ x = self.pwconv2(x)
522
+ if self.gamma is not None:
523
+ x = self.gamma * x
524
+ x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
525
+
526
+ x = residual + x
527
+ return x
528
+
529
+
530
+ class AdaLayerNorm(nn.Module):
531
+ """
532
+ Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes
533
+
534
+ Args:
535
+ num_embeddings (int): Number of embeddings.
536
+ embedding_dim (int): Dimension of the embeddings.
537
+ """
538
+
539
+ def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6):
540
+ super().__init__()
541
+ self.eps = eps
542
+ self.dim = embedding_dim
543
+ self.scale = nn.Embedding(
544
+ num_embeddings=num_embeddings, embedding_dim=embedding_dim
545
+ )
546
+ self.shift = nn.Embedding(
547
+ num_embeddings=num_embeddings, embedding_dim=embedding_dim
548
+ )
549
+ torch.nn.init.ones_(self.scale.weight)
550
+ torch.nn.init.zeros_(self.shift.weight)
551
+
552
+ def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor:
553
+ scale = self.scale(cond_embedding_id)
554
+ shift = self.shift(cond_embedding_id)
555
+ x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
556
+ x = x * scale + shift
557
+ return x
558
+
559
+
560
+ class ResBlock1(nn.Module):
561
+ """
562
+ ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,
563
+ but without upsampling layers.
564
+
565
+ Args:
566
+ dim (int): Number of input channels.
567
+ kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
568
+ dilation (tuple[int], optional): Dilation factors for the dilated convolutions.
569
+ Defaults to (1, 3, 5).
570
+ lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.
571
+ Defaults to 0.1.
572
+ layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
573
+ Defaults to None.
574
+ """
575
+
576
+ def __init__(
577
+ self,
578
+ dim: int,
579
+ kernel_size: int = 3,
580
+ dilation: Tuple[int, int, int] = (1, 3, 5),
581
+ lrelu_slope: float = 0.1,
582
+ layer_scale_init_value: Optional[float] = None,
583
+ ):
584
+ super().__init__()
585
+ self.lrelu_slope = lrelu_slope
586
+ self.convs1 = nn.ModuleList(
587
+ [
588
+ weight_norm(
589
+ nn.Conv1d(
590
+ dim,
591
+ dim,
592
+ kernel_size,
593
+ 1,
594
+ dilation=dilation[0],
595
+ padding=self.get_padding(kernel_size, dilation[0]),
596
+ )
597
+ ),
598
+ weight_norm(
599
+ nn.Conv1d(
600
+ dim,
601
+ dim,
602
+ kernel_size,
603
+ 1,
604
+ dilation=dilation[1],
605
+ padding=self.get_padding(kernel_size, dilation[1]),
606
+ )
607
+ ),
608
+ weight_norm(
609
+ nn.Conv1d(
610
+ dim,
611
+ dim,
612
+ kernel_size,
613
+ 1,
614
+ dilation=dilation[2],
615
+ padding=self.get_padding(kernel_size, dilation[2]),
616
+ )
617
+ ),
618
+ ]
619
+ )
620
+
621
+ self.convs2 = nn.ModuleList(
622
+ [
623
+ weight_norm(
624
+ nn.Conv1d(
625
+ dim,
626
+ dim,
627
+ kernel_size,
628
+ 1,
629
+ dilation=1,
630
+ padding=self.get_padding(kernel_size, 1),
631
+ )
632
+ ),
633
+ weight_norm(
634
+ nn.Conv1d(
635
+ dim,
636
+ dim,
637
+ kernel_size,
638
+ 1,
639
+ dilation=1,
640
+ padding=self.get_padding(kernel_size, 1),
641
+ )
642
+ ),
643
+ weight_norm(
644
+ nn.Conv1d(
645
+ dim,
646
+ dim,
647
+ kernel_size,
648
+ 1,
649
+ dilation=1,
650
+ padding=self.get_padding(kernel_size, 1),
651
+ )
652
+ ),
653
+ ]
654
+ )
655
+
656
+ self.gamma = nn.ParameterList(
657
+ [
658
+ (
659
+ nn.Parameter(
660
+ layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
661
+ )
662
+ if layer_scale_init_value is not None
663
+ else None
664
+ ),
665
+ (
666
+ nn.Parameter(
667
+ layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
668
+ )
669
+ if layer_scale_init_value is not None
670
+ else None
671
+ ),
672
+ (
673
+ nn.Parameter(
674
+ layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
675
+ )
676
+ if layer_scale_init_value is not None
677
+ else None
678
+ ),
679
+ ]
680
+ )
681
+
682
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
683
+ for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma):
684
+ xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope)
685
+ xt = c1(xt)
686
+ xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope)
687
+ xt = c2(xt)
688
+ if gamma is not None:
689
+ xt = gamma * xt
690
+ x = xt + x
691
+ return x
692
+
693
+ def remove_weight_norm(self):
694
+ for l in self.convs1:
695
+ remove_weight_norm(l)
696
+ for l in self.convs2:
697
+ remove_weight_norm(l)
698
+
699
+ @staticmethod
700
+ def get_padding(kernel_size: int, dilation: int = 1) -> int:
701
+ return int((kernel_size * dilation - dilation) / 2)
702
+
703
+
704
+ class Backbone(nn.Module):
705
+ """Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""
706
+
707
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
708
+ """
709
+ Args:
710
+ x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
711
+ C denotes output features, and L is the sequence length.
712
+
713
+ Returns:
714
+ Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
715
+ and H denotes the model dimension.
716
+ """
717
+ raise NotImplementedError("Subclasses must implement the forward method.")
718
+
719
+
720
+ class VocosBackbone(Backbone):
721
+ """
722
+ Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization
723
+
724
+ Args:
725
+ input_channels (int): Number of input features channels.
726
+ dim (int): Hidden dimension of the model.
727
+ intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
728
+ num_layers (int): Number of ConvNeXtBlock layers.
729
+ layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
730
+ adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
731
+ None means non-conditional model. Defaults to None.
732
+ """
733
+
734
+ def __init__(
735
+ self,
736
+ input_channels: int,
737
+ dim: int,
738
+ intermediate_dim: int,
739
+ num_layers: int,
740
+ layer_scale_init_value: Optional[float] = None,
741
+ adanorm_num_embeddings: Optional[int] = None,
742
+ ):
743
+ super().__init__()
744
+ self.input_channels = input_channels
745
+ self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3)
746
+ self.adanorm = adanorm_num_embeddings is not None
747
+ if adanorm_num_embeddings:
748
+ self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
749
+ else:
750
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
751
+ layer_scale_init_value = layer_scale_init_value or 1 / num_layers
752
+ self.convnext = nn.ModuleList(
753
+ [
754
+ ConvNeXtBlock(
755
+ dim=dim,
756
+ intermediate_dim=intermediate_dim,
757
+ layer_scale_init_value=layer_scale_init_value,
758
+ adanorm_num_embeddings=adanorm_num_embeddings,
759
+ )
760
+ for _ in range(num_layers)
761
+ ]
762
+ )
763
+ self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
764
+ self.apply(self._init_weights)
765
+
766
+ def _init_weights(self, m):
767
+ if isinstance(m, (nn.Conv1d, nn.Linear)):
768
+ nn.init.trunc_normal_(m.weight, std=0.02)
769
+ nn.init.constant_(m.bias, 0)
770
+
771
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
772
+ bandwidth_id = kwargs.get("bandwidth_id", None)
773
+ x = self.embed(x)
774
+ if self.adanorm:
775
+ assert bandwidth_id is not None
776
+ x = self.norm(x.transpose(1, 2), cond_embedding_id=bandwidth_id)
777
+ else:
778
+ x = self.norm(x.transpose(1, 2))
779
+ x = x.transpose(1, 2)
780
+ for conv_block in self.convnext:
781
+ x = conv_block(x, cond_embedding_id=bandwidth_id)
782
+ x = self.final_layer_norm(x.transpose(1, 2))
783
+ return x
784
+
785
+
786
+ class VocosResNetBackbone(Backbone):
787
+ """
788
+ Vocos backbone module built with ResBlocks.
789
+
790
+ Args:
791
+ input_channels (int): Number of input features channels.
792
+ dim (int): Hidden dimension of the model.
793
+ num_blocks (int): Number of ResBlock1 blocks.
794
+ layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None.
795
+ """
796
+
797
+ def __init__(
798
+ self,
799
+ input_channels,
800
+ dim,
801
+ num_blocks,
802
+ layer_scale_init_value=None,
803
+ ):
804
+ super().__init__()
805
+ self.input_channels = input_channels
806
+ self.embed = weight_norm(
807
+ nn.Conv1d(input_channels, dim, kernel_size=3, padding=1)
808
+ )
809
+ layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3
810
+ self.resnet = nn.Sequential(
811
+ *[
812
+ ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value)
813
+ for _ in range(num_blocks)
814
+ ]
815
+ )
816
+
817
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
818
+ x = self.embed(x)
819
+ x = self.resnet(x)
820
+ x = x.transpose(1, 2)
821
+ return x
822
+
823
+
824
+ class Vocos(nn.Module):
825
+ def __init__(
826
+ self,
827
+ input_channels: int = 256,
828
+ dim: int = 384,
829
+ intermediate_dim: int = 1152,
830
+ num_layers: int = 8,
831
+ n_fft: int = 800,
832
+ hop_size: int = 200,
833
+ padding: str = "same",
834
+ adanorm_num_embeddings=None,
835
+ cfg=None,
836
+ ):
837
+ super().__init__()
838
+
839
+ input_channels = (
840
+ cfg.input_channels
841
+ if cfg is not None and hasattr(cfg, "input_channels")
842
+ else input_channels
843
+ )
844
+ dim = cfg.dim if cfg is not None and hasattr(cfg, "dim") else dim
845
+ intermediate_dim = (
846
+ cfg.intermediate_dim
847
+ if cfg is not None and hasattr(cfg, "intermediate_dim")
848
+ else intermediate_dim
849
+ )
850
+ num_layers = (
851
+ cfg.num_layers
852
+ if cfg is not None and hasattr(cfg, "num_layers")
853
+ else num_layers
854
+ )
855
+ adanorm_num_embeddings = (
856
+ cfg.adanorm_num_embeddings
857
+ if cfg is not None and hasattr(cfg, "adanorm_num_embeddings")
858
+ else adanorm_num_embeddings
859
+ )
860
+ n_fft = cfg.n_fft if cfg is not None and hasattr(cfg, "n_fft") else n_fft
861
+ hop_size = (
862
+ cfg.hop_size if cfg is not None and hasattr(cfg, "hop_size") else hop_size
863
+ )
864
+ padding = (
865
+ cfg.padding if cfg is not None and hasattr(cfg, "padding") else padding
866
+ )
867
+
868
+ self.backbone = VocosBackbone(
869
+ input_channels=input_channels,
870
+ dim=dim,
871
+ intermediate_dim=intermediate_dim,
872
+ num_layers=num_layers,
873
+ adanorm_num_embeddings=adanorm_num_embeddings,
874
+ )
875
+ self.head = ISTFTHead(dim, n_fft, hop_size, padding)
876
+
877
+ def forward(self, x):
878
+ x = self.backbone(x)
879
+ x = self.head(x)
880
+
881
+ return x[:, None, :]
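End to end, Vocos maps frame-level features to a waveform. With the defaults above (input_channels=256, n_fft=800, hop_size=200) and "same" padding, the output length is exactly T * hop_size:

import torch

vocoder = Vocos()
feats = torch.randn(1, 256, 100)  # [B, C, T] frame-level features
wav = vocoder(feats)              # [1, 1, 20000] waveform, i.e. 100 * 200 samples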
models/codec/coco/rep_coco_model.py ADDED
@@ -0,0 +1,441 @@
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ from torch.nn import functional as F
12
+
13
+ from models.codec.amphion_codec.quantize import ResidualVQ
14
+ from models.codec.amphion_codec.vocos import VocosBackbone
15
+
16
+
17
+ def init_weights(m):
18
+ if isinstance(m, nn.Conv1d):
19
+ nn.init.trunc_normal_(m.weight, std=0.02)
20
+ nn.init.constant_(m.bias, 0)
21
+ if isinstance(m, nn.Linear):
22
+ nn.init.trunc_normal_(m.weight, std=0.02)
23
+ nn.init.constant_(m.bias, 0)
24
+
25
+
26
+ def compute_codebook_perplexity(indices, codebook_size):
27
+ indices = indices.flatten()
28
+ prob = torch.bincount(indices, minlength=codebook_size).float() / indices.size(0)
29
+ perp = torch.exp(-torch.sum(prob * torch.log(prob + 1e-10)))
30
+ return perp
31
+
32
+
33
+ class CocoContentStyle(nn.Module):
34
+ def __init__(
35
+ self,
36
+ codebook_size=8192,
37
+ hidden_size=1024,
38
+ codebook_dim=8,
39
+ num_quantizers=1,
40
+ quantizer_type="fvq",
41
+ use_whisper=True,
42
+ use_chromagram=True,
43
+ construct_only_for_quantizer=False,
44
+ cfg=None,
45
+ ):
46
+ super().__init__()
47
+
48
+ assert cfg is not None
49
+ self.cfg = cfg
50
+
51
+ codebook_size = getattr(cfg, "codebook_size", codebook_size)
52
+ hidden_size = getattr(cfg, "hidden_size", hidden_size)
53
+ codebook_dim = getattr(cfg, "codebook_dim", codebook_dim)
54
+ num_quantizers = getattr(cfg, "num_quantizers", num_quantizers)
55
+ quantizer_type = getattr(cfg, "quantizer_type", quantizer_type)
56
+
57
+ self.codebook_size = codebook_size
58
+ self.codebook_dim = codebook_dim
59
+ self.hidden_size = hidden_size
60
+ self.num_quantizers = num_quantizers
61
+ self.quantizer_type = quantizer_type
62
+
63
+ if use_whisper:
64
+ self.whisper_input_layer = nn.Linear(self.cfg.whisper_dim, hidden_size)
65
+ if use_chromagram:
66
+ self.chromagram_input_layer = nn.Linear(
67
+ self.cfg.chromagram_dim, hidden_size
68
+ )
69
+
70
+ downsample_rate = getattr(cfg, "downsample_rate", 1)
71
+ if downsample_rate > 1:
72
+ self.do_downsample = True
73
+ assert np.log2(downsample_rate).is_integer()
74
+
75
+ down_layers = []
76
+ up_layers = []
77
+ for _ in range(int(np.log2(downsample_rate))):
78
+ down_layers.extend(
79
+ [
80
+ nn.Conv1d(
81
+ hidden_size,
82
+ hidden_size,
83
+ kernel_size=3,
84
+ stride=2,
85
+ padding=1,
86
+ ),
87
+ nn.GELU(),
88
+ ]
89
+ )
90
+ up_layers.extend(
91
+ [
92
+ nn.ConvTranspose1d(
93
+ hidden_size, hidden_size, kernel_size=4, stride=2, padding=1
94
+ ),
95
+ nn.GELU(),
96
+ ]
97
+ )
98
+ self.downsample_layers = nn.Sequential(*down_layers)
99
+ self.upsample_layers = nn.Sequential(*up_layers)
100
+
101
+ else:
102
+ self.do_downsample = False
103
+
104
+ self.encoder = nn.Sequential(
105
+ VocosBackbone(
106
+ input_channels=self.hidden_size,
107
+ dim=self.cfg.encoder.vocos_dim,
108
+ intermediate_dim=self.cfg.encoder.vocos_intermediate_dim,
109
+ num_layers=self.cfg.encoder.vocos_num_layers,
110
+ adanorm_num_embeddings=None,
111
+ ),
112
+ nn.Linear(self.cfg.encoder.vocos_dim, self.hidden_size),
113
+ )
114
+
115
+ self.quantizer = ResidualVQ(
116
+ input_dim=hidden_size,
117
+ num_quantizers=num_quantizers,
118
+ codebook_size=codebook_size,
119
+ codebook_dim=codebook_dim,
120
+ quantizer_type=quantizer_type,
121
+ quantizer_dropout=0.0,
122
+ commitment=0.15,
123
+ codebook_loss_weight=1.0,
124
+ use_l2_normlize=True,
125
+ )
126
+
127
+ if not construct_only_for_quantizer:
128
+ self.decoder = nn.Sequential(
129
+ VocosBackbone(
130
+ input_channels=self.hidden_size,
131
+ dim=self.cfg.decoder.vocos_dim,
132
+ intermediate_dim=self.cfg.decoder.vocos_intermediate_dim,
133
+ num_layers=self.cfg.decoder.vocos_num_layers,
134
+ adanorm_num_embeddings=None,
135
+ ),
136
+ nn.Linear(self.cfg.decoder.vocos_dim, self.hidden_size),
137
+ )
138
+
139
+ if use_whisper:
140
+ self.whisper_output_layer = nn.Linear(
141
+ self.hidden_size, self.cfg.whisper_dim
142
+ )
143
+ if use_chromagram:
144
+ self.chromagram_output_layer = nn.Linear(
145
+ self.hidden_size, self.cfg.chromagram_dim
146
+ )
147
+
148
+ self.reset_parameters()
149
+
150
+ def forward(
151
+ self,
152
+ whisper_feats,
153
+ chromagram_feats,
154
+ return_for_quantizer=False,
155
+ ):
156
+ """
157
+ Args:
158
+ whisper_feats: [B, T, 1024]
159
+ chromagram_feats: [B, T, 24]
160
+ Returns:
161
+ whisper_rec: [B, T, 1024]
162
+ chromagram_rec: [B, T, 24]
163
+ codebook_loss: float
164
+ all_indices: [N, B, T] or [B, T] if num_of_quantizers == 1
165
+ """
166
+ T = whisper_feats.shape[1]
167
+
168
+ # [B, T, D]
169
+ x = self.whisper_input_layer(whisper_feats) + self.chromagram_input_layer(
170
+ chromagram_feats
171
+ )
172
+ # print("Before downsample:", x.shape)
173
+
174
+ # ====== Downsample ======
175
+ if self.do_downsample:
176
+ x = self.downsample_layers(x.transpose(1, 2)).transpose(1, 2)
177
+
178
+ # print("After downsample:", x.shape)
179
+
180
+ # ====== Encoder ======
181
+ x = self.encoder(x.transpose(1, 2)).transpose(1, 2) # [B, T, D] -> [B, D, T]
182
+
183
+ # ====== Quantizer ======
184
+ (
185
+ quantized_out, # [B, D, T]
186
+ all_indices, # [num_of_quantizers, B, T]
187
+ all_commit_losses, # [num_of_quantizers]
188
+ all_codebook_losses, # [num_of_quantizers]
189
+ _,
190
+ ) = self.quantizer(x)
191
+
192
+ if return_for_quantizer:
193
+ if all_indices.shape[0] == 1:
194
+ return all_indices.squeeze(0), quantized_out.transpose(1, 2)
195
+ return all_indices, quantized_out.transpose(1, 2)
196
+
197
+ # ====== Decoder ======
198
+ x_rec = self.decoder(quantized_out) # [B, T, D]
199
+
200
+ # ====== Upsample ======
201
+ if self.do_downsample:
202
+ x_rec = self.upsample_layers(x_rec.transpose(1, 2)).transpose(1, 2)
203
+
204
+ # print("After upsample:", x_rec.shape)
205
+
206
+ # Ensure output dimensions match input
207
+ if x_rec.shape[1] >= T: # Check time dimension
208
+ x_rec = x_rec[:, :T, :]
209
+ else:
210
+ padding_frames = T - x_rec.shape[1]
211
+ last_frame = x_rec[:, -1:, :]
212
+ padding = last_frame.repeat(1, padding_frames, 1)
213
+ x_rec = torch.cat([x_rec, padding], dim=1)
214
+
215
+ # ====== Loss ======
216
+ whisper_rec = self.whisper_output_layer(x_rec) # [B, T, 1024]
217
+ chromagram_rec = self.chromagram_output_layer(x_rec) # [B, T, 24]
218
+
219
+ codebook_loss = (all_codebook_losses + all_commit_losses).mean()
221
+
222
+ return whisper_rec, chromagram_rec, codebook_loss, all_indices
223
+
224
+ def quantize(self, whisper_feats, chromagram_feats):
225
+ """
226
+ Args:
227
+ whisper_feats: [B, T, 1024]
228
+ chromagram_feats: [B, T, 24]
229
+ Returns:
230
+ all_indices: [N, B, T], or [B, T] if num_of_quantizers == 1
231
+ quantized_out: [B, D, T]
232
+ """
233
+ all_indices, quantized_out = self.forward(
234
+ whisper_feats,
235
+ chromagram_feats,
236
+ return_for_quantizer=True,
237
+ )
238
+ return all_indices, quantized_out
239
+
240
+ def reset_parameters(self):
241
+ self.apply(init_weights)
242
+
243
+
244
+ class CocoContent(CocoContentStyle):
245
+ def __init__(
246
+ self,
247
+ cfg,
248
+ use_whisper=True,
249
+ use_chromagram=False,
250
+ construct_only_for_quantizer=False,
251
+ ):
252
+ super().__init__(
253
+ cfg=cfg,
254
+ use_whisper=use_whisper,
255
+ use_chromagram=use_chromagram,
256
+ construct_only_for_quantizer=construct_only_for_quantizer,
257
+ )
258
+
259
+ def forward(
260
+ self,
261
+ whisper_feats,
262
+ return_for_quantizer=False,
263
+ ):
264
+ """
265
+ Args:
266
+ whisper_feats: [B, T, 1024]
267
+ Returns:
268
+ whisper_rec: [B, T, 1024]
269
+ codebook_loss: float
270
+ all_indices: [N, B, T]
271
+ """
272
+ T = whisper_feats.shape[1]
273
+
274
+ # [B, T, D]
275
+ x = self.whisper_input_layer(whisper_feats)
276
+
277
+ # ====== Downsample ======
278
+ if self.do_downsample:
279
+ x = self.downsample_layers(x.transpose(1, 2)).transpose(1, 2)
280
+
281
+ # ====== Encoder ======
282
+ x = self.encoder(x.transpose(1, 2)).transpose(1, 2) # [B, T, D] -> [B, D, T]
283
+
284
+ # ====== Quantizer ======
285
+ (
286
+ quantized_out, # [B, D, T]
287
+ all_indices, # [num_of_quantizers, B, T]
288
+ all_commit_losses, # [num_of_quantizers]
289
+ all_codebook_losses, # [num_of_quantizers]
290
+ _,
291
+ ) = self.quantizer(x)
292
+
293
+ if return_for_quantizer:
294
+ if all_indices.shape[0] == 1:
295
+ return all_indices.squeeze(0), quantized_out.transpose(1, 2)
296
+ return all_indices, quantized_out.transpose(1, 2)
297
+
298
+ # ====== Decoder ======
299
+ x_rec = self.decoder(quantized_out) # [B, T, D]
300
+
301
+ # ====== Upsample ======
302
+ if self.do_downsample:
303
+ x_rec = self.upsample_layers(x_rec.transpose(1, 2)).transpose(1, 2)
304
+
305
+ # Ensure output dimensions match input
306
+ if x_rec.shape[1] >= T: # Check time dimension
307
+ x_rec = x_rec[:, :T, :]
308
+ else:
309
+ padding_frames = T - x_rec.shape[1]
310
+ last_frame = x_rec[:, -1:, :]
311
+ padding = last_frame.repeat(1, padding_frames, 1)
312
+ x_rec = torch.cat([x_rec, padding], dim=1)
313
+
314
+ # ====== Loss ======
315
+ whisper_rec = self.whisper_output_layer(x_rec) # [B, T, 1024]
316
+
317
+ codebook_loss = (all_codebook_losses + all_commit_losses).mean()
319
+
+         return whisper_rec, codebook_loss, all_indices
+
+     def quantize(self, whisper_feats):
+         all_indices, quantized_out = self.forward(
+             whisper_feats, return_for_quantizer=True
+         )
+         return all_indices, quantized_out
+
+
+ class CocoStyle(CocoContentStyle):
+     def __init__(
+         self,
+         cfg,
+         use_whisper=False,
+         use_chromagram=True,
+         construct_only_for_quantizer=False,
+     ):
+         super().__init__(
+             cfg=cfg,
+             use_whisper=use_whisper,
+             use_chromagram=use_chromagram,
+             construct_only_for_quantizer=construct_only_for_quantizer,
+         )
+
+     def forward(
+         self,
+         chromagram_feats,
+         return_for_quantizer=False,
+     ):
+         """
+         Args:
+             chromagram_feats: [B, T, 24]
+         Returns:
+             chromagram_rec: [B, T, 24]
+             codebook_loss: float
+             all_indices: [N, B, T]
+         """
+         T = chromagram_feats.shape[1]
+
+         # [B, T, D]
+         x = self.chromagram_input_layer(chromagram_feats)
+
+         # ====== Downsample ======
+         if self.do_downsample:
+             x = self.downsample_layers(x.transpose(1, 2)).transpose(1, 2)
+
+         # ====== Encoder ======
+         x = self.encoder(x.transpose(1, 2)).transpose(1, 2)  # encoder runs on [B, D, T]; output back to [B, T, D]
+
+         # ====== Quantizer ======
+         (
+             quantized_out,  # [B, D, T]
+             all_indices,  # [num_of_quantizers, B, T]
+             all_commit_losses,  # [num_of_quantizers]
+             all_codebook_losses,  # [num_of_quantizers]
+             _,
+         ) = self.quantizer(x)
+
+         if return_for_quantizer:
+             if all_indices.shape[0] == 1:
+                 return all_indices.squeeze(0), quantized_out.transpose(1, 2)
+             return all_indices, quantized_out.transpose(1, 2)
+
+         # ====== Decoder ======
+         x_rec = self.decoder(quantized_out)  # [B, T, D]
+
+         # ====== Upsample ======
+         if self.do_downsample:
+             x_rec = self.upsample_layers(x_rec.transpose(1, 2)).transpose(1, 2)
+
+         # Ensure output dimensions match input
+         if x_rec.shape[1] >= T:  # Check time dimension
+             x_rec = x_rec[:, :T, :]
+         else:
+             padding_frames = T - x_rec.shape[1]
+             last_frame = x_rec[:, -1:, :]
+             padding = last_frame.repeat(1, padding_frames, 1)
+             x_rec = torch.cat([x_rec, padding], dim=1)
+
+         # ====== Loss ======
+         chromagram_rec = self.chromagram_output_layer(x_rec)  # [B, T, 24]
+
+         codebook_loss = (all_codebook_losses + all_commit_losses).mean()
+
+         return chromagram_rec, codebook_loss, all_indices
+
+     def quantize(self, chromagram_feats):
+         all_indices, quantized_out = self.forward(
+             chromagram_feats, return_for_quantizer=True
+         )
+         return all_indices, quantized_out
+
+
+ # if __name__ == "__main__":
+ #     from utils.util import JsonHParams
+
+ #     cfg = JsonHParams(
+ #         **{
+ #             "whisper_dim": 1024,
+ #             "chromagram_dim": 24,
+ #             "global_speaker_encoder": {
+ #                 "input_dim": 128,  # Eg: n_mels
+ #                 "hidden_size": 512,  # 768 for emilia298k
+ #                 "num_hidden_layers": 4,  # 6 for emilia298k
+ #                 "num_attention_heads": 8,
+ #             },
+ #         }
+ #     )
+ #     model = Coco(cfg=cfg)
+
+ #     x = torch.randn(2, 150, 1024)
+ #     tone_height = torch.randn(2)
+ #     mels = torch.randn(2, 150, 128)
+ #     mel_mask = torch.ones(2, 150)
+
+ #     x_rec, codebook_loss, all_indices, auxillary_pred_outputs = model(
+ #         x, tone_height, mels, mel_mask
+ #     )
+ #     print(x_rec.shape, codebook_loss, all_indices.shape)
+ #     for k, v in auxillary_pred_outputs.items():
+ #         print(k, v.shape)
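+
+ # Illustrative usage sketch for the quantizers above (shapes follow the
+ # docstrings; the cfg object and the batch/length values are assumptions):
+ #
+ #     model = CocoContent(cfg=cfg, construct_only_for_quantizer=True)
+ #     whisper_feats = torch.randn(2, 150, 1024)           # [B, T, 1024]
+ #     indices, quantized = model.quantize(whisper_feats)
+ #     # indices: [B, T] for a single quantizer (leading axis squeezed),
+ #     # quantized: [B, T, D] after the transpose in forward()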
models/codec/melvqgan/melspec.py ADDED
@@ -0,0 +1,108 @@
+ # Copyright (c) 2023 Amphion.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import torch
+ import pyworld as pw
+ import numpy as np
+ import soundfile as sf
+ import os
+ from torchaudio.functional import pitch_shift
+ import librosa
+ from librosa.filters import mel as librosa_mel_fn
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import tqdm
+
+
+ def dynamic_range_compression(x, C=1, clip_val=1e-5):
+     return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
+
+
+ def dynamic_range_decompression(x, C=1):
+     return np.exp(x) / C
+
+
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
+     return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+ def dynamic_range_decompression_torch(x, C=1):
+     return torch.exp(x) / C
+
+
+ def spectral_normalize_torch(magnitudes):
+     output = dynamic_range_compression_torch(magnitudes)
+     return output
+
+
+ def spectral_de_normalize_torch(magnitudes):
+     output = dynamic_range_decompression_torch(magnitudes)
+     return output
+
+
+ class MelSpectrogram(nn.Module):
+     def __init__(
+         self,
+         n_fft,
+         num_mels,
+         sampling_rate,
+         hop_size,
+         win_size,
+         fmin,
+         fmax,
+         center=False,
+     ):
+         super(MelSpectrogram, self).__init__()
+         self.n_fft = n_fft
+         self.hop_size = hop_size
+         self.win_size = win_size
+         self.sampling_rate = sampling_rate
+         self.num_mels = num_mels
+         self.fmin = fmin
+         self.fmax = fmax
+         self.center = center
+
+         mel = librosa_mel_fn(
+             sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
+         )
+         mel_basis = torch.from_numpy(mel).float()
+         hann_window = torch.hann_window(win_size)
+
+         self.register_buffer("mel_basis", mel_basis)
+         self.register_buffer("hann_window", hann_window)
+
+     def forward(self, y):
+         y = torch.nn.functional.pad(
+             y.unsqueeze(1),
+             (
+                 int((self.n_fft - self.hop_size) / 2),
+                 int((self.n_fft - self.hop_size) / 2),
+             ),
+             mode="reflect",
+         )
+         y = y.squeeze(1)
+         spec = torch.stft(
+             y,
+             self.n_fft,
+             hop_length=self.hop_size,
+             win_length=self.win_size,
+             window=self.hann_window,
+             center=self.center,
+             pad_mode="reflect",
+             normalized=False,
+             onesided=True,
+             return_complex=True,
+         )
+         spec = torch.view_as_real(spec)
+
+         spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
+
+         spec = torch.matmul(self.mel_basis, spec)
+         spec = spectral_normalize_torch(spec)
+
+         return spec
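+
+ # Illustrative usage sketch (the STFT/mel settings below are typical
+ # HiFi-GAN-style values, not taken from this repo's configs):
+ #
+ #     mel_fn = MelSpectrogram(
+ #         n_fft=1024, num_mels=100, sampling_rate=24000,
+ #         hop_size=256, win_size=1024, fmin=0, fmax=12000,
+ #     )
+ #     wav = torch.randn(1, 24000)   # [B, num_samples]
+ #     mel = mel_fn(wav)             # [B, num_mels, num_frames]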
utils/__init__.py ADDED
File without changes
utils/hparam.py ADDED
@@ -0,0 +1,659 @@
+ # Copyright (c) 2023 Amphion.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # This code is modified from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/training/python/training/hparam.py pylint: disable=line-too-long
+ """Hyperparameter values."""
+ from __future__ import absolute_import
+ from __future__ import division
+ from __future__ import print_function
+
+ import json
+ import numbers
+ import re
+ import six
+
+ # Define the regular expression for parsing a single clause of the input
+ # (delimited by commas). A legal clause looks like:
+ #   <variable name>[<index>]? = <rhs>
+ # where <rhs> is either a single token or [] enclosed list of tokens.
+ # For example: "var[1] = a" or "x = [1,2,3]"
+ PARAM_RE = re.compile(
+     r"""
+   (?P<name>[a-zA-Z][\w\.]*)    # variable name: "var" or "x"
+   (\[\s*(?P<index>\d+)\s*\])?  # (optional) index: "1" or None
+   \s*=\s*
+   ((?P<val>[^,\[]*)            # single value: "a" or None
+    |
+    \[(?P<vals>[^\]]*)\])       # list of values: None or "1,2,3"
+   ($|,\s*)""",
+     re.VERBOSE,
+ )
+
+
+ def _parse_fail(name, var_type, value, values):
+     """Helper function for raising a value error for bad assignment."""
+     raise ValueError(
+         "Could not parse hparam '%s' of type '%s' with value '%s' in %s"
+         % (name, var_type.__name__, value, values)
+     )
+
+
+ def _reuse_fail(name, values):
+     """Helper function for raising a value error for reuse of name."""
+     raise ValueError("Multiple assignments to variable '%s' in %s" % (name, values))
+
+
+ def _process_scalar_value(name, parse_fn, var_type, m_dict, values, results_dictionary):
+     """Update results_dictionary with a scalar value.
+
+     Used to update the results_dictionary to be returned by parse_values when
+     encountering a clause with a scalar RHS (e.g. "s=5" or "arr[0]=5".)
+
+     Mutates results_dictionary.
+
+     Args:
+       name: Name of variable in assignment ("s" or "arr").
+       parse_fn: Function for parsing the actual value.
+       var_type: Type of named variable.
+       m_dict: Dictionary constructed from regex parsing.
+         m_dict['val']: RHS value (scalar)
+         m_dict['index']: List index value (or None)
+       values: Full expression being parsed
+       results_dictionary: The dictionary being updated for return by the parsing
+         function.
+
+     Raises:
+       ValueError: If the name has already been used.
+     """
+     try:
+         parsed_value = parse_fn(m_dict["val"])
+     except ValueError:
+         _parse_fail(name, var_type, m_dict["val"], values)
+
+     # If no index is provided
+     if not m_dict["index"]:
+         if name in results_dictionary:
+             _reuse_fail(name, values)
+         results_dictionary[name] = parsed_value
+     else:
+         if name in results_dictionary:
+             # If the name has already been used as a scalar, it will be in
+             # this dictionary and map to a non-dictionary value.
+             if not isinstance(results_dictionary.get(name), dict):
+                 _reuse_fail(name, values)
+         else:
+             results_dictionary[name] = {}
+
+         index = int(m_dict["index"])
+         # Make sure the index position hasn't already been assigned a value.
+         if index in results_dictionary[name]:
+             _reuse_fail("{}[{}]".format(name, index), values)
+         results_dictionary[name][index] = parsed_value
+
+
+ def _process_list_value(name, parse_fn, var_type, m_dict, values, results_dictionary):
+     """Update results_dictionary from a list of values.
+
+     Used to update results_dictionary to be returned by parse_values when
+     encountering a clause with a list RHS (e.g. "arr=[1,2,3]".)
+
+     Mutates results_dictionary.
+
+     Args:
+       name: Name of variable in assignment ("arr").
+       parse_fn: Function for parsing individual values.
+       var_type: Type of named variable.
+       m_dict: Dictionary constructed from regex parsing.
+         m_dict['val']: RHS value (scalar)
+       values: Full expression being parsed
+       results_dictionary: The dictionary being updated for return by the parsing
+         function.
+
+     Raises:
+       ValueError: If the name has an index or the values cannot be parsed.
+     """
+     if m_dict["index"] is not None:
+         raise ValueError("Assignment of a list to a list index.")
+     elements = filter(None, re.split("[ ,]", m_dict["vals"]))
+     # Make sure the name hasn't already been assigned a value
+     if name in results_dictionary:
+         _reuse_fail(name, values)
+     try:
+         results_dictionary[name] = [parse_fn(e) for e in elements]
+     except ValueError:
+         _parse_fail(name, var_type, m_dict["vals"], values)
+
+
+ def _cast_to_type_if_compatible(name, param_type, value):
+     """Cast hparam to the provided type, if compatible.
+
+     Args:
+       name: Name of the hparam to be cast.
+       param_type: The type of the hparam.
+       value: The value to be cast, if compatible.
+
+     Returns:
+       The result of casting `value` to `param_type`.
+
+     Raises:
+       ValueError: If the type of `value` is not compatible with param_type.
+         * If `param_type` is a string type, but `value` is not.
+         * If `param_type` is a boolean, but `value` is not, or vice versa.
+         * If `param_type` is an integer type, but `value` is not.
+         * If `param_type` is a float type, but `value` is not a numeric type.
+     """
+     fail_msg = "Could not cast hparam '%s' of type '%s' from value %r" % (
+         name,
+         param_type,
+         value,
+     )
+
+     # Some callers use None, for which we can't do any casting/checking. :(
+     if issubclass(param_type, type(None)):
+         return value
+
+     # Avoid converting a non-string type to a string.
+     if issubclass(param_type, (six.string_types, six.binary_type)) and not isinstance(
+         value, (six.string_types, six.binary_type)
+     ):
+         raise ValueError(fail_msg)
+
+     # Avoid converting a number or string type to a boolean or vice versa.
+     if issubclass(param_type, bool) != isinstance(value, bool):
+         raise ValueError(fail_msg)
+
+     # Avoid converting float to an integer (the reverse is fine).
+     if issubclass(param_type, numbers.Integral) and not isinstance(
+         value, numbers.Integral
+     ):
+         raise ValueError(fail_msg)
+
+     # Avoid converting a non-numeric type to a numeric type.
+     if issubclass(param_type, numbers.Number) and not isinstance(value, numbers.Number):
+         raise ValueError(fail_msg)
+
+     return param_type(value)
+
+
+ def parse_values(values, type_map, ignore_unknown=False):
+     """Parses hyperparameter values from a string into a python map.
+
+     `values` is a string containing comma-separated `name=value` pairs.
+     For each pair, the value of the hyperparameter named `name` is set to
+     `value`.
+
+     If a hyperparameter name appears multiple times in `values`, a ValueError
+     is raised (e.g. 'a=1,a=2', 'a[1]=1,a[1]=2').
+
+     If a hyperparameter name appears in both an index assignment and a scalar
+     assignment, a ValueError is raised (e.g. 'a=[1,2,3],a[0] = 1').
+
+     The hyperparameter name may contain '.' symbols, which will result in an
+     attribute name that is only accessible through the getattr and setattr
+     functions. (And it must first be explicitly added through add_hparam.)
+
+     WARNING: Use of '.' in your variable names is allowed, but is not well
+     supported and not recommended.
+
+     The `value` in `name=value` must follow the syntax according to the
+     type of the parameter:
+
+     * Scalar integer: A Python-parsable integer value. E.g.: 1,
+       100, -12.
+     * Scalar float: A Python-parsable floating point value. E.g.: 1.0,
+       -.54e89.
+     * Boolean: Either true or false.
+     * Scalar string: A non-empty sequence of characters, excluding comma,
+       spaces, and square brackets. E.g.: foo, bar_1.
+     * List: A comma separated list of scalar values of the parameter type
+       enclosed in square brackets. E.g.: [1,2,3], [1.0,1e-12], [high,low].
+
+     When index assignment is used, the corresponding type_map key should be the
+     list name. E.g. for "arr[1]=0" the type_map must have the key "arr" (not
+     "arr[1]").
+
+     Args:
+       values: String. Comma separated list of `name=value` pairs where
+         'value' must follow the syntax described above.
+       type_map: A dictionary mapping hyperparameter names to types. Note every
+         parameter name in values must be a key in type_map. The values must
+         conform to the types indicated, where a value V is said to conform to a
+         type T if either V has type T, or V is a list of elements of type T.
+         Hence, for a multidimensional parameter 'x' taking float values,
+         'x=[0.1,0.2]' will parse successfully if type_map['x'] = float.
+       ignore_unknown: Bool. Whether values that are missing a type in type_map
+         should be ignored. If set to True, a ValueError will not be raised for
+         unknown hyperparameter type.
+
+     Returns:
+       A python map mapping each name to either:
+       * A scalar value.
+       * A list of scalar values.
+       * A dictionary mapping index numbers to scalar values.
+       (e.g. "x=5,L=[1,2],arr[1]=3" results in {'x':5,'L':[1,2],'arr':{1:3}}")
+
+     Raises:
+       ValueError: If there is a problem with input.
+       * If `values` cannot be parsed.
+       * If a list is assigned to a list index (e.g. 'a[1] = [1,2,3]').
+       * If the same rvalue is assigned two different values (e.g. 'a=1,a=2',
+         'a[1]=1,a[1]=2', or 'a=1,a=[1]')
+     """
+     results_dictionary = {}
+     pos = 0
+     while pos < len(values):
+         m = PARAM_RE.match(values, pos)
+         if not m:
+             raise ValueError("Malformed hyperparameter value: %s" % values[pos:])
+         # Check that there is a comma between parameters and move past it.
+         pos = m.end()
+         # Parse the values.
+         m_dict = m.groupdict()
+         name = m_dict["name"]
+         if name not in type_map:
+             if ignore_unknown:
+                 continue
+             raise ValueError("Unknown hyperparameter type for %s" % name)
+         type_ = type_map[name]
+
+         # Set up correct parsing function (depending on whether type_ is a bool)
+         if type_ == bool:
+
+             def parse_bool(value):
+                 if value in ["true", "True"]:
+                     return True
+                 elif value in ["false", "False"]:
+                     return False
+                 else:
+                     try:
+                         return bool(int(value))
+                     except ValueError:
+                         _parse_fail(name, type_, value, values)
+
+             parse = parse_bool
+         else:
+             parse = type_
+
+         # If a single value is provided
+         if m_dict["val"] is not None:
+             _process_scalar_value(
+                 name, parse, type_, m_dict, values, results_dictionary
+             )
+
+         # If the assigned value is a list:
+         elif m_dict["vals"] is not None:
+             _process_list_value(name, parse, type_, m_dict, values, results_dictionary)
+
+         else:  # Not assigned a list or value
+             _parse_fail(name, type_, "", values)
+
+     return results_dictionary
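+
+ # Illustrative sketch of parse_values on the example from its docstring:
+ #
+ #     type_map = {"x": int, "L": int, "arr": int}
+ #     parse_values("x=5,L=[1,2],arr[1]=3", type_map)
+ #     # -> {'x': 5, 'L': [1, 2], 'arr': {1: 3}}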
+
+
+ class HParams(object):
+     """Class to hold a set of hyperparameters as name-value pairs.
+
+     A `HParams` object holds hyperparameters used to build and train a model,
+     such as the number of hidden units in a neural net layer or the learning rate
+     to use when training.
+
+     You first create a `HParams` object by specifying the names and values of the
+     hyperparameters.
+
+     To make them easily accessible the parameter names are added as direct
+     attributes of the class. A typical usage is as follows:
+
+     ```python
+     # Create a HParams object specifying names and values of the model
+     # hyperparameters:
+     hparams = HParams(learning_rate=0.1, num_hidden_units=100)
+
+     # The hyperparameters are available as attributes of the HParams object:
+     hparams.learning_rate ==> 0.1
+     hparams.num_hidden_units ==> 100
+     ```
+
+     Hyperparameters have a type, which is inferred from the type of their value
+     passed at construction time. The currently supported types are: integer,
+     float, boolean, string, and list of integer, float, boolean, or string.
+
+     You can override hyperparameter values by calling the
+     [`parse()`](#HParams.parse) method, passing a string of comma separated
+     `name=value` pairs. This is intended to make it possible to override
+     any hyperparameter values from a single command-line flag to which
+     the user passes 'hyper-param=value' pairs. It avoids having to define
+     one flag for each hyperparameter.
+
+     The syntax expected for each value depends on the type of the parameter.
+     See `parse()` for a description of the syntax.
+
+     Example:
+
+     ```python
+     # Define a command line flag to pass name=value pairs.
+     # For example using argparse:
+     import argparse
+     parser = argparse.ArgumentParser(description='Train my model.')
+     parser.add_argument('--hparams', type=str,
+                         help='Comma separated list of "name=value" pairs.')
+     args = parser.parse_args()
+     ...
+     def my_program():
+         # Create a HParams object specifying the names and values of the
+         # model hyperparameters:
+         hparams = tf.HParams(learning_rate=0.1, num_hidden_units=100,
+                              activations=['relu', 'tanh'])
+
+         # Override hyperparameter values by parsing the command line
+         hparams.parse(args.hparams)
+
+         # If the user passed `--hparams=learning_rate=0.3` on the command line
+         # then 'hparams' has the following attributes:
+         hparams.learning_rate ==> 0.3
+         hparams.num_hidden_units ==> 100
+         hparams.activations ==> ['relu', 'tanh']
+
+         # If the hyperparameters are in json format use parse_json:
+         hparams.parse_json('{"learning_rate": 0.3, "activations": "relu"}')
+     ```
+     """
+
+     _HAS_DYNAMIC_ATTRIBUTES = True  # Required for pytype checks.
+
+     def __init__(self, model_structure=None, **kwargs):
+         """Create an instance of `HParams` from keyword arguments.
+
+         The keyword arguments specify name-value pairs for the hyperparameters.
+         The parameter types are inferred from the type of the values passed.
+
+         The parameter names are added as attributes of `HParams` object, so they
+         can be accessed directly with the dot notation `hparams._name_`.
+
+         Example:
+
+         ```python
+         # Define 3 hyperparameters: 'learning_rate' is a float parameter,
+         # 'num_hidden_units' an integer parameter, and 'activation' a string
+         # parameter.
+         hparams = tf.HParams(
+             learning_rate=0.1, num_hidden_units=100, activation='relu')
+
+         hparams.activation ==> 'relu'
+         ```
+
+         Note that a few names are reserved and cannot be used as hyperparameter
+         names. If you use one of the reserved names the constructor raises a
+         `ValueError`.
+
+         Args:
+           model_structure: An instance of ModelStructure, defining the feature
+             crosses to be used in the Trial.
+           **kwargs: Key-value pairs where the key is the hyperparameter name and
+             the value is the value for the parameter.
+
+         Raises:
+           ValueError: If both `hparam_def` and initialization values are provided,
+             or if one of the arguments is invalid.
+
+         """
+         # Register the hyperparameters and their type in _hparam_types.
+         # This simplifies the implementation of parse().
+         # _hparam_types maps the parameter name to a tuple (type, bool).
+         # The type value is the type of the parameter for scalar hyperparameters,
+         # or the type of the list elements for multidimensional hyperparameters.
+         # The bool value is True if the value is a list, False otherwise.
+         self._hparam_types = {}
+         self._model_structure = model_structure
+         for name, value in six.iteritems(kwargs):
+             self.add_hparam(name, value)
+
+     def add_hparam(self, name, value):
+         """Adds {name, value} pair to hyperparameters.
+
+         Args:
+           name: Name of the hyperparameter.
+           value: Value of the hyperparameter. Can be one of the following types:
+             int, float, string, int list, float list, or string list.
+
+         Raises:
+           ValueError: if one of the arguments is invalid.
+         """
+         # Keys in kwargs are unique, but 'name' could be the name of a pre-existing
+         # attribute of this object. In that case we refuse to use it as a
+         # hyperparameter name.
+         if getattr(self, name, None) is not None:
+             raise ValueError("Hyperparameter name is reserved: %s" % name)
+         if isinstance(value, (list, tuple)):
+             if not value:
+                 raise ValueError(
+                     "Multi-valued hyperparameters cannot be empty: %s" % name
+                 )
+             self._hparam_types[name] = (type(value[0]), True)
+         else:
+             self._hparam_types[name] = (type(value), False)
+         setattr(self, name, value)
+
+     def set_hparam(self, name, value):
+         """Set the value of an existing hyperparameter.
+
+         This function verifies that the type of the value matches the type of the
+         existing hyperparameter.
+
+         Args:
+           name: Name of the hyperparameter.
+           value: New value of the hyperparameter.
+
+         Raises:
+           KeyError: If the hyperparameter doesn't exist.
+           ValueError: If there is a type mismatch.
+         """
+         param_type, is_list = self._hparam_types[name]
+         if isinstance(value, list):
+             if not is_list:
+                 raise ValueError(
+                     "Must not pass a list for single-valued parameter: %s" % name
+                 )
+             setattr(
+                 self,
+                 name,
+                 [_cast_to_type_if_compatible(name, param_type, v) for v in value],
+             )
+         else:
+             if is_list:
+                 raise ValueError(
+                     "Must pass a list for multi-valued parameter: %s." % name
+                 )
+             setattr(self, name, _cast_to_type_if_compatible(name, param_type, value))
+
+     def del_hparam(self, name):
+         """Removes the hyperparameter with key 'name'.
+
+         Does nothing if it isn't present.
+
+         Args:
+           name: Name of the hyperparameter.
+         """
+         if hasattr(self, name):
+             delattr(self, name)
+             del self._hparam_types[name]
+
+     def parse(self, values):
+         """Override existing hyperparameter values, parsing new values from a string.
+
+         See parse_values for more detail on the allowed format for values.
+
+         Args:
+           values: String. Comma separated list of `name=value` pairs where 'value'
+             must follow the syntax described above.
+
+         Returns:
+           The `HParams` instance.
+
+         Raises:
+           ValueError: If `values` cannot be parsed or a hyperparameter in `values`
+             doesn't exist.
+         """
+         type_map = {}
+         for name, t in self._hparam_types.items():
+             param_type, _ = t
+             type_map[name] = param_type
+
+         values_map = parse_values(values, type_map)
+         return self.override_from_dict(values_map)
+
+     def override_from_dict(self, values_dict):
+         """Override existing hyperparameter values, parsing new values from a dictionary.
+
+         Args:
+           values_dict: Dictionary of name:value pairs.
+
+         Returns:
+           The `HParams` instance.
+
+         Raises:
+           KeyError: If a hyperparameter in `values_dict` doesn't exist.
+           ValueError: If `values_dict` cannot be parsed.
+         """
+         for name, value in values_dict.items():
+             self.set_hparam(name, value)
+         return self
+
+     def set_model_structure(self, model_structure):
+         self._model_structure = model_structure
+
+     def get_model_structure(self):
+         return self._model_structure
+
+     def to_json(self, indent=None, separators=None, sort_keys=False):
+         """Serializes the hyperparameters into JSON.
+
+         Args:
+           indent: If a non-negative integer, JSON array elements and object members
+             will be pretty-printed with that indent level. An indent level of 0, or
+             negative, will only insert newlines. `None` (the default) selects the
+             most compact representation.
+           separators: Optional `(item_separator, key_separator)` tuple. Default is
+             `(', ', ': ')`.
+           sort_keys: If `True`, the output dictionaries will be sorted by key.
+
+         Returns:
+           A JSON string.
+         """
+
+         def remove_callables(x):
+             """Omit callable elements from input with arbitrary nesting."""
+             if isinstance(x, dict):
+                 return {
+                     k: remove_callables(v)
+                     for k, v in six.iteritems(x)
+                     if not callable(v)
+                 }
+             elif isinstance(x, list):
+                 return [remove_callables(i) for i in x if not callable(i)]
+             return x
+
+         return json.dumps(
+             remove_callables(self.values()),
+             indent=indent,
+             separators=separators,
+             sort_keys=sort_keys,
+         )
+
+     def parse_json(self, values_json):
+         """Override existing hyperparameter values, parsing new values from a json object.
+
+         Args:
+           values_json: String containing a json object of name:value pairs.
+
+         Returns:
+           The `HParams` instance.
+
+         Raises:
+           KeyError: If a hyperparameter in `values_json` doesn't exist.
+           ValueError: If `values_json` cannot be parsed.
+         """
+         values_map = json.loads(values_json)
+         return self.override_from_dict(values_map)
+
+     def values(self):
+         """Return the hyperparameter values as a Python dictionary.
+
+         Returns:
+           A dictionary with hyperparameter names as keys. The values are the
+           hyperparameter values.
+         """
+         return {n: getattr(self, n) for n in self._hparam_types.keys()}
+
+     def get(self, key, default=None):
+         """Returns the value of `key` if it exists, else `default`."""
+         if key in self._hparam_types:
+             # Ensure that default is compatible with the parameter type.
+             if default is not None:
+                 param_type, is_param_list = self._hparam_types[key]
+                 type_str = "list<%s>" % param_type if is_param_list else str(param_type)
+                 fail_msg = (
+                     "Hparam '%s' of type '%s' is incompatible with "
+                     "default=%s" % (key, type_str, default)
+                 )
+
+                 is_default_list = isinstance(default, list)
+                 if is_param_list != is_default_list:
+                     raise ValueError(fail_msg)
+
+                 try:
+                     if is_default_list:
+                         for value in default:
+                             _cast_to_type_if_compatible(key, param_type, value)
+                     else:
+                         _cast_to_type_if_compatible(key, param_type, default)
+                 except ValueError as e:
+                     raise ValueError("%s. %s" % (fail_msg, e))
+
+             return getattr(self, key)
+
+         return default
+
+     def __contains__(self, key):
+         return key in self._hparam_types
+
+     def __str__(self):
+         return str(sorted(self.values().items()))
+
+     def __repr__(self):
+         return "%s(%s)" % (type(self).__name__, self.__str__())
+
+     @staticmethod
+     def _get_kind_name(param_type, is_list):
+         """Returns the field name given parameter type and is_list.
+
+         Args:
+           param_type: Data type of the hparam.
+           is_list: Whether this is a list.
+
+         Returns:
+           A string representation of the field name.
+
+         Raises:
+           ValueError: If parameter type is not recognized.
+         """
+         if issubclass(param_type, bool):
+             # This check must happen before issubclass(param_type, six.integer_types),
+             # since Python considers bool to be a subclass of int.
+             typename = "bool"
+         elif issubclass(param_type, six.integer_types):
+             # Setting 'int' and 'long' types to be 'int64' to ensure the type is
+             # compatible with both Python2 and Python3.
+             typename = "int64"
+         elif issubclass(param_type, (six.string_types, six.binary_type)):
+             # Setting 'string' and 'bytes' types to be 'bytes' to ensure the type is
+             # compatible with both Python2 and Python3.
+             typename = "bytes"
+         elif issubclass(param_type, float):
+             typename = "float"
+         else:
+             raise ValueError("Unsupported parameter type: %s" % str(param_type))
+
+         suffix = "list" if is_list else "value"
+         return "_".join([typename, suffix])
utils/util.py ADDED
@@ -0,0 +1,690 @@
+ # Copyright (c) 2023 Amphion.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+
+ import collections
+ import glob
+ import os
+ import random
+ import time
+ import argparse
+ from collections import OrderedDict
+
+ import json5
+ import numpy as np
+ from torch.nn import functional as F
+
+
+ try:
+     from ruamel.yaml import YAML as yaml
+ except ImportError:
+     from ruamel_yaml import YAML as yaml
+
+ import torch
+
+ from utils.hparam import HParams
+ import logging
+ from logging import handlers
+
+
+ def str2bool(v):
+     """Used in argparse.ArgumentParser.add_argument to indicate
+     that a type is a bool type and the user can enter
+
+     - yes, true, t, y, 1, to represent True
+     - no, false, f, n, 0, to represent False
+
+     See https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse  # noqa
+     """
+     if isinstance(v, bool):
+         return v
+     if v.lower() in ("yes", "true", "t", "y", "1"):
+         return True
+     elif v.lower() in ("no", "false", "f", "n", "0"):
+         return False
+     else:
+         raise argparse.ArgumentTypeError("Boolean value expected.")
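+
+ # Illustrative sketch: wiring str2bool into argparse (the flag name is
+ # hypothetical):
+ #
+ #     parser = argparse.ArgumentParser()
+ #     parser.add_argument("--use_ema", type=str2bool, default=False)
+ #     args = parser.parse_args(["--use_ema", "yes"])  # args.use_ema == True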
+
+
+ def find_checkpoint_of_mapper(mapper_ckpt_dir):
+     mapper_ckpts = glob.glob(os.path.join(mapper_ckpt_dir, "ckpts/*.pt"))
+
+     # Select the max steps
+     mapper_ckpts.sort()
+     mapper_weights_file = mapper_ckpts[-1]
+     return mapper_weights_file
+
+
+ def pad_f0_to_tensors(f0s, batched=None):
+     # Initialize
+     tensors = []
+
+     if batched is None:
+         # Get the max frame for padding
+         size = -1
+         for f0 in f0s:
+             size = max(size, f0.shape[-1])
+
+         tensor = torch.zeros(len(f0s), size)
+
+         for i, f0 in enumerate(f0s):
+             tensor[i, : f0.shape[-1]] = f0[:]
+
+         tensors.append(tensor)
+     else:
+         start = 0
+         while start + batched - 1 < len(f0s):
+             end = start + batched - 1
+
+             # Get the max frame for padding
+             size = -1
+             for i in range(start, end + 1):
+                 size = max(size, f0s[i].shape[-1])
+
+             tensor = torch.zeros(batched, size)
+
+             for i in range(start, end + 1):
+                 tensor[i - start, : f0s[i].shape[-1]] = f0s[i][:]
+
+             tensors.append(tensor)
+
+             start = start + batched
+
+         if start != len(f0s):
+             end = len(f0s)
+
+             # Get the max frame for padding
+             size = -1
+             for i in range(start, end):
+                 size = max(size, f0s[i].shape[-1])
+
+             tensor = torch.zeros(len(f0s) - start, size)
+
+             for i in range(start, end):
+                 tensor[i - start, : f0s[i].shape[-1]] = f0s[i][:]
+
+             tensors.append(tensor)
+
+     return tensors
+
+
+ def pad_mels_to_tensors(mels, batched=None):
+     """
+     Args:
+         mels: A list of mel-specs
+     Returns:
+         tensors: A list of tensors containing the batched mel-specs
+         mel_frames: A list of tensors containing the frames of the original mel-specs
+     """
+     # Initialize
+     tensors = []
+     mel_frames = []
+
+     # Split mel-specs into batches to avoid exceeding CUDA memory
+     if batched is None:
+         # Get the max frame for padding
+         size = -1
+         for mel in mels:
+             size = max(size, mel.shape[-1])
+
+         tensor = torch.zeros(len(mels), mels[0].shape[0], size)
+         mel_frame = torch.zeros(len(mels), dtype=torch.int32)
+
+         for i, mel in enumerate(mels):
+             tensor[i, :, : mel.shape[-1]] = mel[:]
+             mel_frame[i] = mel.shape[-1]
+
+         tensors.append(tensor)
+         mel_frames.append(mel_frame)
+     else:
+         start = 0
+         while start + batched - 1 < len(mels):
+             end = start + batched - 1
+
+             # Get the max frame for padding
+             size = -1
+             for i in range(start, end + 1):
+                 size = max(size, mels[i].shape[-1])
+
+             tensor = torch.zeros(batched, mels[0].shape[0], size)
+             mel_frame = torch.zeros(batched, dtype=torch.int32)
+
+             for i in range(start, end + 1):
+                 tensor[i - start, :, : mels[i].shape[-1]] = mels[i][:]
+                 mel_frame[i - start] = mels[i].shape[-1]
+
+             tensors.append(tensor)
+             mel_frames.append(mel_frame)
+
+             start = start + batched
+
+         if start != len(mels):
+             end = len(mels)
+
+             # Get the max frame for padding
+             size = -1
+             for i in range(start, end):
+                 size = max(size, mels[i].shape[-1])
+
+             tensor = torch.zeros(len(mels) - start, mels[0].shape[0], size)
+             mel_frame = torch.zeros(len(mels) - start, dtype=torch.int32)
+
+             for i in range(start, end):
+                 tensor[i - start, :, : mels[i].shape[-1]] = mels[i][:]
+                 mel_frame[i - start] = mels[i].shape[-1]
+
+             tensors.append(tensor)
+             mel_frames.append(mel_frame)
+
+     return tensors, mel_frames
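+
+ # Illustrative sketch (an assumed 80-bin mel, three variable-length clips):
+ #
+ #     mels = [torch.randn(80, t) for t in (95, 120, 70)]
+ #     tensors, mel_frames = pad_mels_to_tensors(mels, batched=2)
+ #     # tensors: shapes [2, 80, 120] and [1, 80, 70]
+ #     # mel_frames: [95, 120] and [70] (original frame counts per clip)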
+
+
+ def load_model_config(args):
+     """Load model configurations (in args.json under checkpoint directory)
+
+     Args:
+         args (ArgumentParser): arguments to run bins/preprocess.py
+
+     Returns:
+         dict: dictionary that stores model configurations
+     """
+     if args.checkpoint_dir is None:
+         assert args.checkpoint_file is not None
+         checkpoint_dir = os.path.split(args.checkpoint_file)[0]
+     else:
+         checkpoint_dir = args.checkpoint_dir
+     config_path = os.path.join(checkpoint_dir, "args.json")
+     print("config_path: ", config_path)
+
+     config = load_config(config_path)
+     return config
+
+
+ def remove_and_create(dir):
+     if os.path.exists(dir):
+         os.system("rm -r {}".format(dir))
+     os.makedirs(dir, exist_ok=True)
+
+
+ def has_existed(path, warning=False):
+     if not warning:
+         return os.path.exists(path)
+
+     if os.path.exists(path):
+         answer = input(
+             "The path {} already exists.\nInput 'y' (or hit Enter) to skip it, and input 'n' to re-write it [y/n]\n".format(
+                 path
+             )
+         )
+         if answer != "n":
+             return True
+
+     return False
+
+
+ def remove_older_ckpt(saved_model_name, checkpoint_dir, max_to_keep=5):
+     if os.path.exists(os.path.join(checkpoint_dir, "checkpoint")):
+         with open(os.path.join(checkpoint_dir, "checkpoint"), "r") as f:
+             ckpts = [x.strip() for x in f.readlines()]
+     else:
+         ckpts = []
+     ckpts.append(saved_model_name)
+     for item in ckpts[:-max_to_keep]:
+         if os.path.exists(os.path.join(checkpoint_dir, item)):
+             os.remove(os.path.join(checkpoint_dir, item))
+     with open(os.path.join(checkpoint_dir, "checkpoint"), "w") as f:
+         for item in ckpts[-max_to_keep:]:
+             f.write("{}\n".format(item))
+
+
+ def set_all_random_seed(seed: int):
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.random.manual_seed(seed)
+
+
+ def save_checkpoint(
+     args,
+     generator,
+     g_optimizer,
+     step,
+     discriminator=None,
+     d_optimizer=None,
+     max_to_keep=5,
+ ):
+     saved_model_name = "model.ckpt-{}.pt".format(step)
+     checkpoint_path = os.path.join(args.checkpoint_dir, saved_model_name)
+
+     if discriminator and d_optimizer:
+         torch.save(
+             {
+                 "generator": generator.state_dict(),
+                 "discriminator": discriminator.state_dict(),
+                 "g_optimizer": g_optimizer.state_dict(),
+                 "d_optimizer": d_optimizer.state_dict(),
+                 "global_step": step,
+             },
+             checkpoint_path,
+         )
+     else:
+         torch.save(
+             {
+                 "generator": generator.state_dict(),
+                 "g_optimizer": g_optimizer.state_dict(),
+                 "global_step": step,
+             },
+             checkpoint_path,
+         )
+
+     print("Saved checkpoint: {}".format(checkpoint_path))
+
+     if os.path.exists(os.path.join(args.checkpoint_dir, "checkpoint")):
+         with open(os.path.join(args.checkpoint_dir, "checkpoint"), "r") as f:
+             ckpts = [x.strip() for x in f.readlines()]
+     else:
+         ckpts = []
+     ckpts.append(saved_model_name)
+     for item in ckpts[:-max_to_keep]:
+         if os.path.exists(os.path.join(args.checkpoint_dir, item)):
+             os.remove(os.path.join(args.checkpoint_dir, item))
+     with open(os.path.join(args.checkpoint_dir, "checkpoint"), "w") as f:
+         for item in ckpts[-max_to_keep:]:
+             f.write("{}\n".format(item))
+
+
+ def attempt_to_restore(
+     generator, g_optimizer, checkpoint_dir, discriminator=None, d_optimizer=None
+ ):
+     checkpoint_list = os.path.join(checkpoint_dir, "checkpoint")
+     if os.path.exists(checkpoint_list):
+         checkpoint_filename = open(checkpoint_list).readlines()[-1].strip()
+         checkpoint_path = os.path.join(checkpoint_dir, "{}".format(checkpoint_filename))
+         print("Restore from {}".format(checkpoint_path))
+         checkpoint = torch.load(checkpoint_path, map_location="cpu")
+         if generator:
+             if not list(generator.state_dict().keys())[0].startswith("module."):
+                 raw_dict = checkpoint["generator"]
+                 clean_dict = OrderedDict()
+                 for k, v in raw_dict.items():
+                     if k.startswith("module."):
+                         clean_dict[k[7:]] = v
+                     else:
+                         clean_dict[k] = v
+                 generator.load_state_dict(clean_dict)
+             else:
+                 generator.load_state_dict(checkpoint["generator"])
+         if g_optimizer:
+             g_optimizer.load_state_dict(checkpoint["g_optimizer"])
+         global_step = 100000
+         if discriminator and "discriminator" in checkpoint.keys():
+             discriminator.load_state_dict(checkpoint["discriminator"])
+             global_step = checkpoint["global_step"]
+             print("restore discriminator")
+         if d_optimizer and "d_optimizer" in checkpoint.keys():
+             d_optimizer.load_state_dict(checkpoint["d_optimizer"])
+             print("restore d_optimizer...")
+     else:
+         global_step = 0
+     return global_step
+
+
+ class ExponentialMovingAverage(object):
+     def __init__(self, decay):
+         self.decay = decay
+         self.shadow = {}
+
+     def register(self, name, val):
+         self.shadow[name] = val.clone()
+
+     def update(self, name, x):
+         assert name in self.shadow
+         update_delta = self.shadow[name] - x
+         self.shadow[name] -= (1.0 - self.decay) * update_delta
+
+
+ def apply_moving_average(model, ema):
+     for name, param in model.named_parameters():
+         if name in ema.shadow:
+             ema.update(name, param.data)
+
+
+ def register_model_to_ema(model, ema):
+     for name, param in model.named_parameters():
+         if param.requires_grad:
+             ema.register(name, param.data)
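+
+ # Illustrative sketch of the EMA helpers above (`model` is an assumed
+ # nn.Module): register once, then update the shadow after each step.
+ #
+ #     ema = ExponentialMovingAverage(decay=0.999)
+ #     register_model_to_ema(model, ema)  # snapshot trainable parameters
+ #     # ... after every optimizer step:
+ #     apply_moving_average(model, ema)   # shadow <- decay*shadow + (1-decay)*param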
+
+
+ class YParams(HParams):
+     def __init__(self, yaml_file):
+         if not os.path.exists(yaml_file):
+             raise IOError("yaml file {} does not exist".format(yaml_file))
+         super().__init__()
+         self.d = collections.OrderedDict()
+         with open(yaml_file) as fp:
+             for _, v in yaml().load(fp).items():
+                 for k1, v1 in v.items():
+                     try:
+                         if self.get(k1):
+                             self.set_hparam(k1, v1)
+                         else:
+                             self.add_hparam(k1, v1)
+                         self.d[k1] = v1
+                     except Exception:
+                         import traceback
+
+                         print(traceback.format_exc())
+
+     # @property
+     def get_elements(self):
+         return self.d.items()
+
+
+ def override_config(base_config, new_config):
+     """Update new configurations in the original dict with the new dict
+
+     Args:
+         base_config (dict): original dict to be overridden
+         new_config (dict): dict with new configurations
+
+     Returns:
+         dict: updated configuration dict
+     """
+     for k, v in new_config.items():
+         if type(v) == dict:
+             if k not in base_config.keys():
+                 base_config[k] = {}
+             base_config[k] = override_config(base_config[k], v)
+         else:
+             base_config[k] = v
+     return base_config
+
+
+ def get_lowercase_keys_config(cfg):
+     """Change all keys in cfg to lower case
+
+     Args:
+         cfg (dict): dictionary that stores configurations
+
+     Returns:
+         dict: dictionary that stores configurations
+     """
+     updated_cfg = dict()
+     for k, v in cfg.items():
+         if type(v) == dict:
+             v = get_lowercase_keys_config(v)
+         updated_cfg[k.lower()] = v
+     return updated_cfg
+
+
+ def _load_config(config_fn, lowercase=False):
+     """Load configurations into a dictionary
+
+     Args:
+         config_fn (str): path to configuration file
+         lowercase (bool, optional): whether changing keys to lower case. Defaults to False.
+
+     Returns:
+         dict: dictionary that stores configurations
+     """
+     with open(config_fn, "r") as f:
+         data = f.read()
+     config_ = json5.loads(data)
+     if "base_config" in config_:
+         # load configurations from new path
+         try:
+             p_config_path = os.path.join(os.getenv("WORK_DIR"), config_["base_config"])
+         except Exception:
+             p_config_path = config_["base_config"]
+         p_config_ = _load_config(p_config_path)
+         config_ = override_config(p_config_, config_)
+     if lowercase:
+         # change keys in config_ to lower case
+         config_ = get_lowercase_keys_config(config_)
+     return config_
+
+
+ def load_config(config_fn, lowercase=False):
+     """Load configurations into a dictionary
+
+     Args:
+         config_fn (str): path to configuration file
+         lowercase (bool, optional): whether to change keys to lower case. Defaults to False.
+
+     Returns:
+         JsonHParams: an object that stores configurations
+     """
+     config_ = _load_config(config_fn, lowercase=lowercase)
+     # create a JsonHParams object with configuration dict
+     cfg = JsonHParams(**config_)
+     return cfg
+
+
+ def save_config(save_path, cfg):
+     """Save configurations into a json file
+
+     Args:
+         save_path (str): path to save configurations
+         cfg (dict): dictionary that stores configurations
+     """
+     with open(save_path, "w") as f:
+         json5.dump(
+             cfg, f, ensure_ascii=False, indent=4, quote_keys=True, sort_keys=True
+         )
+
+
+ class JsonHParams:
+     def __init__(self, **kwargs):
+         for k, v in kwargs.items():
+             if type(v) == dict:
+                 v = JsonHParams(**v)
+             self[k] = v
+
+     def keys(self):
+         return self.__dict__.keys()
+
+     def items(self):
+         return self.__dict__.items()
+
+     def values(self):
+         return self.__dict__.values()
+
+     def __len__(self):
+         return len(self.__dict__)
+
+     def __getitem__(self, key):
+         return getattr(self, key)
+
+     def __setitem__(self, key, value):
+         return setattr(self, key, value)
+
+     def __contains__(self, key):
+         return key in self.__dict__
+
+     def __repr__(self):
+         return self.__dict__.__repr__()
+
+
+ class ValueWindow:
+     def __init__(self, window_size=100):
+         self._window_size = window_size
+         self._values = []
+
+     def append(self, x):
+         self._values = self._values[-(self._window_size - 1) :] + [x]
+
+     @property
+     def sum(self):
+         return sum(self._values)
+
+     @property
+     def count(self):
+         return len(self._values)
+
+     @property
+     def average(self):
+         return self.sum / max(1, self.count)
+
+     def reset(self):
+         self._values = []
+
+
+ class Logger(object):
+     def __init__(
+         self,
+         filename,
+         level="info",
+         when="D",
+         backCount=10,
+         fmt="%(asctime)s : %(message)s",
+     ):
+         self.level_relations = {
+             "debug": logging.DEBUG,
+             "info": logging.INFO,
+             "warning": logging.WARNING,
+             "error": logging.ERROR,
+             "crit": logging.CRITICAL,
+         }
+         if level == "debug":
+             fmt = "%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s"
+         self.logger = logging.getLogger(filename)
+         format_str = logging.Formatter(fmt)
+         self.logger.setLevel(self.level_relations.get(level))
+         sh = logging.StreamHandler()
+         sh.setFormatter(format_str)
+         th = handlers.TimedRotatingFileHandler(
+             filename=filename, when=when, backupCount=backCount, encoding="utf-8"
+         )
+         th.setFormatter(format_str)
+         self.logger.addHandler(sh)
+         self.logger.addHandler(th)
+         self.logger.info(
+             "==========================New Starting Here=============================="
+         )
+
+
+ def init_weights(m, mean=0.0, std=0.01):
+     classname = m.__class__.__name__
+     if classname.find("Conv") != -1:
+         m.weight.data.normal_(mean, std)
+
+
+ def get_padding(kernel_size, dilation=1):
+     return int((kernel_size * dilation - dilation) / 2)
+
+
+ def slice_segments(x, ids_str, segment_size=4):
+     ret = torch.zeros_like(x[:, :, :segment_size])
+     for i in range(x.size(0)):
+         idx_str = ids_str[i]
+         idx_end = idx_str + segment_size
+         ret[i] = x[i, :, idx_str:idx_end]
+     return ret
+
+
+ def rand_slice_segments(x, x_lengths=None, segment_size=4):
+     b, d, t = x.size()
+     if x_lengths is None:
+         x_lengths = t
+     ids_str_max = x_lengths - segment_size + 1
+     ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
+     ret = slice_segments(x, ids_str, segment_size)
+     return ret, ids_str
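+
+ # Illustrative sketch: randomly crop aligned 4-frame windows from a batch.
+ #
+ #     x = torch.randn(3, 192, 32)                  # [B, D, T]
+ #     segments, ids_str = rand_slice_segments(x, segment_size=4)
+ #     # segments: [3, 192, 4]; ids_str holds each sample's start frame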
+
+
+ def subsequent_mask(length):
+     mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
+     return mask
+
+
+ @torch.jit.script
+ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
+     n_channels_int = n_channels[0]
+     in_act = input_a + input_b
+     t_act = torch.tanh(in_act[:, :n_channels_int, :])
+     s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
+     acts = t_act * s_act
+     return acts
+
+
+ def convert_pad_shape(pad_shape):
+     l = pad_shape[::-1]
+     pad_shape = [item for sublist in l for item in sublist]
+     return pad_shape
+
+
+ def sequence_mask(length, max_length=None):
+     if max_length is None:
+         max_length = length.max()
+     x = torch.arange(max_length, dtype=length.dtype, device=length.device)
+     return x.unsqueeze(0) < length.unsqueeze(1)
+
+
+ def generate_path(duration, mask):
+     """
+     duration: [b, 1, t_x]
+     mask: [b, 1, t_y, t_x]
+     """
+     device = duration.device
+
+     b, _, t_y, t_x = mask.shape
+     cum_duration = torch.cumsum(duration, -1)
+
+     cum_duration_flat = cum_duration.view(b * t_x)
+     path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
+     path = path.view(b, t_x, t_y)
+     path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
+     path = path.unsqueeze(1).transpose(2, 3) * mask
+     return path
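+
+ # Illustrative sketch: durations [2, 3] over 5 output frames yield a hard
+ # monotonic alignment (frames 0-1 -> token 0, frames 2-4 -> token 1):
+ #
+ #     duration = torch.tensor([[[2.0, 3.0]]])  # [b, 1, t_x]
+ #     mask = torch.ones(1, 1, 5, 2)             # [b, 1, t_y, t_x]
+ #     generate_path(duration, mask)             # one-hot path, shape [1, 1, 5, 2]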
+
+
+ def clip_grad_value_(parameters, clip_value, norm_type=2):
+     if isinstance(parameters, torch.Tensor):
+         parameters = [parameters]
+     parameters = list(filter(lambda p: p.grad is not None, parameters))
+     norm_type = float(norm_type)
+     if clip_value is not None:
+         clip_value = float(clip_value)
+
+     total_norm = 0
+     for p in parameters:
+         param_norm = p.grad.data.norm(norm_type)
+         total_norm += param_norm.item() ** norm_type
+         if clip_value is not None:
+             p.grad.data.clamp_(min=-clip_value, max=clip_value)
+     total_norm = total_norm ** (1.0 / norm_type)
+     return total_norm
+
+
+ def get_current_time():
+     pass
+
+
+ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
+     """
+     Args:
+       lengths:
+         A 1-D tensor containing sentence lengths.
+       max_len:
+         The length of masks.
+     Returns:
+       Return a 2-D bool tensor, where masked positions
+       are filled with `True` and non-masked positions are
+       filled with `False`.
+
+     >>> lengths = torch.tensor([1, 3, 2, 5])
+     >>> make_pad_mask(lengths)
+     tensor([[False, True, True, True, True],
+             [False, False, False, True, True],
+             [False, False, True, True, True],
+             [False, False, False, False, False]])
+     """
+     assert lengths.ndim == 1, lengths.ndim
+     max_len = max(max_len, lengths.max())
+     n = lengths.size(0)
+     seq_range = torch.arange(0, max_len, device=lengths.device)
+     expanded_lengths = seq_range.unsqueeze(0).expand(n, max_len)
+
+     return expanded_lengths >= lengths.unsqueeze(-1)
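+
+ # Note: make_pad_mask marks padding with True, while sequence_mask (above)
+ # marks valid positions with True; for the same 1-D lengths tensor the two
+ # are logical complements:
+ #
+ #     lengths = torch.tensor([1, 3, 2, 5])
+ #     assert torch.equal(make_pad_mask(lengths), ~sequence_mask(lengths))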