Upload inference related files
Files changed:

- .gitattributes +3 -0
- anyaccomp/fmt_model.py +367 -0
- anyaccomp/inference_utils.py +124 -0
- anyaccomp/llama_nar.py +667 -0
- config/flow_matching.json +74 -0
- config/vocoder.json +52 -0
- example/gradio/example1.mp3 +3 -0
- example/gradio/example2.wav +3 -0
- example/gradio/example3.wav +3 -0
- models/__init__.py +0 -0
- models/codec/__init__.py +0 -0
- models/codec/amphion_codec/.DS_Store +0 -0
- models/codec/amphion_codec/quantize/__init__.py +11 -0
- models/codec/amphion_codec/quantize/factorized_vector_quantize.py +150 -0
- models/codec/amphion_codec/quantize/lookup_free_quantize.py +77 -0
- models/codec/amphion_codec/quantize/residual_vq.py +177 -0
- models/codec/amphion_codec/quantize/vector_quantize.py +401 -0
- models/codec/amphion_codec/vocos.py +881 -0
- models/codec/coco/rep_coco_model.py +441 -0
- models/codec/melvqgan/melspec.py +108 -0
- utils/__init__.py +0 -0
- utils/hparam.py +659 -0
- utils/util.py +690 -0
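
Together these files form a self-contained inference stack: anyaccomp/ holds the flow-matching model and pipeline, models/codec/ the Coco quantizer and Vocos vocoder, and config/ their settings. As a minimal sketch of how they are meant to be wired up, using the constructor defined in anyaccomp/inference_utils.py below; the checkpoint paths here are illustrative placeholders, not part of this commit:

# Hypothetical setup sketch; only the module path, class name, and
# constructor signature come from this commit. Checkpoint locations
# are assumptions.
from anyaccomp.inference_utils import Sing2SongInferencePipeline

pipeline = Sing2SongInferencePipeline(
    checkpoint_path="./pretrained/flow_matching/model.safetensors",  # placeholder
    cfg_path="config/flow_matching.json",
    vocoder_checkpoint_path="./pretrained/vocoder/model.safetensors",  # placeholder
    vocoder_cfg_path="config/vocoder.json",
    device="cuda",
)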
.gitattributes CHANGED

@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+example/gradio/example1.mp3 filter=lfs diff=lfs merge=lfs -text
+example/gradio/example2.wav filter=lfs diff=lfs merge=lfs -text
+example/gradio/example3.wav filter=lfs diff=lfs merge=lfs -text
anyaccomp/fmt_model.py ADDED

@@ -0,0 +1,367 @@

import torch
import numpy as np
import torch.nn as nn
import math
from einops import rearrange
from anyaccomp.llama_nar import DiffLlamaConcat
import torch.nn.functional as F
from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel
from typing import List, Optional, Tuple, Union
from transformers.models.llama.modeling_llama import BaseModelOutputWithPast


class FlowMatchingTransformerConcat(nn.Module):
    def __init__(
        self,
        vocab_size=1024,
        mel_dim=100,
        hidden_size=1024,
        num_layers=12,
        num_heads=16,
        cfg_scale=0.2,
        use_cond_code=False,
        cond_codebook_size=1024,
        cond_dim=1024,
        cond_scale_factor=1,
        sigma=1e-5,
        time_scheduler="linear",
        cfg=None,
    ):
        super().__init__()
        self.cfg = cfg

        mel_dim = (
            cfg.mel_dim if cfg is not None and hasattr(cfg, "mel_dim") else mel_dim
        )
        hidden_size = (
            cfg.hidden_size
            if cfg is not None and hasattr(cfg, "hidden_size")
            else hidden_size
        )
        num_layers = (
            cfg.num_layers
            if cfg is not None and hasattr(cfg, "num_layers")
            else num_layers
        )
        num_heads = (
            cfg.num_heads
            if cfg is not None and hasattr(cfg, "num_heads")
            else num_heads
        )
        cfg_scale = (
            cfg.cfg_scale
            if cfg is not None and hasattr(cfg, "cfg_scale")
            else cfg_scale
        )
        use_cond_code = (
            cfg.use_cond_code
            if cfg is not None and hasattr(cfg, "use_cond_code")
            else use_cond_code
        )
        cond_codebook_size = (
            cfg.cond_codebook_size
            if cfg is not None and hasattr(cfg, "cond_codebook_size")
            else cond_codebook_size
        )
        cond_dim = (
            cfg.cond_dim if cfg is not None and hasattr(cfg, "cond_dim") else cond_dim
        )
        time_scheduler = (
            cfg.time_scheduler
            if cfg is not None and hasattr(cfg, "time_scheduler")
            else time_scheduler
        )
        sigma = cfg.sigma if cfg is not None and hasattr(cfg, "sigma") else sigma
        cond_scale_factor = (
            cfg.cond_scale_factor
            if cfg is not None and hasattr(cfg, "cond_scale_factor")
            else cond_scale_factor
        )

        self.mel_dim = mel_dim
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.cfg_scale = cfg_scale
        self.use_cond_code = use_cond_code
        self.cond_codebook_size = cond_codebook_size
        self.cond_dim = cond_dim
        self.time_scheduler = time_scheduler
        self.sigma = sigma
        self.cond_scale_factor = cond_scale_factor

        self.vocab_size = (
            cfg.vocab_size
            if cfg is not None and hasattr(cfg, "vocab_size")
            else vocab_size
        )
        self.vocal_mel_proj = (
            nn.Linear(self.cfg.cond_code_dim, self.hidden_size)
            if not self.use_cond_code
            else nn.Sequential(
                nn.Embedding(
                    self.vocab_size, self.mel_dim
                ),  # [batch] -> [batch, mel_dim]
                nn.Linear(
                    self.mel_dim, self.hidden_size
                ),  # [batch, mel_dim] -> [batch, hidden_size]
            )
        )

        self.diff_estimator = DiffLlamaConcat(
            mel_dim=self.mel_dim,
            hidden_size=self.hidden_size,
            num_heads=self.num_heads,
            num_layers=self.num_layers,
            flash_attention=hasattr(cfg, "flash_attention") and cfg.flash_attention,
        )

        if hasattr(cfg, "repa_loss") and cfg.repa_loss.enable:
            repa_dim = (
                cfg.repa_loss.repa_dim
                if hasattr(cfg.repa_loss, "repa_dim")
                else self.hidden_size
            )
            self.repa_proj = nn.Sequential(
                nn.Linear(self.hidden_size, self.hidden_size),
                nn.SiLU(),
                nn.Linear(self.hidden_size, self.hidden_size),
                nn.SiLU(),
                nn.Linear(self.hidden_size, repa_dim),
            )

        self.reset_parameters()

    def reset_parameters(self):
        def _reset_parameters(m):
            if isinstance(m, nn.MultiheadAttention):
                if m._qkv_same_embed_dim:
                    nn.init.normal_(m.in_proj_weight, std=0.02)
                else:
                    nn.init.normal_(m.q_proj_weight, std=0.02)
                    nn.init.normal_(m.k_proj_weight, std=0.02)
                    nn.init.normal_(m.v_proj_weight, std=0.02)

                if m.in_proj_bias is not None:
                    nn.init.constant_(m.in_proj_bias, 0.0)
                    nn.init.constant_(m.out_proj.bias, 0.0)
                if m.bias_k is not None:
                    nn.init.xavier_normal_(m.bias_k)
                if m.bias_v is not None:
                    nn.init.xavier_normal_(m.bias_v)

            elif (
                isinstance(m, nn.Conv1d)
                or isinstance(m, nn.ConvTranspose1d)
                or isinstance(m, nn.Conv2d)
                or isinstance(m, nn.ConvTranspose2d)
            ):
                m.weight.data.normal_(0.0, 0.02)

            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(mean=0.0, std=0.02)
                if m.bias is not None:
                    m.bias.data.zero_()

            elif isinstance(m, nn.Embedding):
                m.weight.data.normal_(mean=0.0, std=0.02)
                if m.padding_idx is not None:
                    m.weight.data[m.padding_idx].zero_()

        self.apply(_reset_parameters)

    @torch.no_grad()
    def forward_diffusion(self, x, t):
        """
        x: (B, T, mel_dim)
        t: (B,)
        """

        new_t = t
        t = t.unsqueeze(-1).unsqueeze(-1)
        z = torch.randn(
            x.shape, dtype=x.dtype, device=x.device, requires_grad=False
        )  # (B, T, mel_dim)

        cfg_scale = self.cfg_scale

        # get prompt len
        if torch.rand(1) > 0.7:
            prompt_len = torch.randint(
                min(x.shape[1] // 4, 5), int(x.shape[1] * 0.4), (x.shape[0],)
            ).to(
                x.device
            )  # (B,)
        else:
            prompt_len = torch.zeros(x.shape[0]).to(x.device)

        split_ratio = torch.rand(prompt_len.shape, device=prompt_len.device)  # (B,)

        left_len = (split_ratio * (prompt_len + 1).float()).long()  # (B,)
        right_len = prompt_len - left_len  # (B,)

        T = x.shape[1]
        is_prompt = torch.zeros_like(x[:, :, 0])  # (B, T)
        col_indices = torch.arange(T, device=x.device).repeat(x.shape[0], 1)  # (B, T)
        left_mask = col_indices < left_len.unsqueeze(1)
        right_mask = col_indices >= (T - right_len.unsqueeze(1))
        is_prompt[left_mask | right_mask] = 1

        mask = torch.ones_like(x[:, :, 0])  # mask if 1, not mask if 0
        mask[is_prompt.bool()] = 0
        mask = mask[:, :, None]

        # flow matching: xt = (1 - (1 - sigma) * t) * x0 + t * x; where x0 ~ N(0, 1), x is a sample
        # flow gt: x - (1 - sigma) * x0 = x - (1 - sigma) * noise
        xt = ((1 - (1 - self.sigma) * t) * z + t * x) * mask + x * (1 - mask)

        return xt, z, new_t, prompt_len, mask

    def loss_t(
        self,
        x,
        x_mask,
        t,
        lyric=None,
        output_hidden_states=False,
    ):
        xt, z, new_t, prompt_len, mask = self.forward_diffusion(x, t)

        noise = z

        prompt_len = prompt_len.float()

        # drop condition using cfg_scale
        if lyric is not None:
            cfg_mask = torch.where(
                torch.rand_like(prompt_len) > self.cfg_scale,
                torch.ones_like(prompt_len),  # keep cond
                torch.zeros_like(prompt_len),  # drop cond
            ).to(lyric.device)

            cond_mask = cfg_mask[:, None, None]  # [b, 1, 1]

            lyric = lyric * cond_mask

        final_mask = mask * x_mask[..., None]  # (B, T, 1)

        output = self.diff_estimator(
            xt, new_t, x_mask, lyric, output_hidden_states=output_hidden_states
        )
        if output_hidden_states:
            return_list = [noise, x, output["hidden_states"], final_mask, prompt_len]
            return_list.append(output["all_hidden_states"])
        else:
            return_list = [noise, x, output, final_mask, prompt_len]

        return return_list

    def compute_loss(self, x, x_mask, lyric=None, output_hidden_states=False):
        # x0: (B, T, num_quantizer)
        # x_mask: (B, T) mask is 0 for padding
        t = torch.rand(x.shape[0], device=x.device, requires_grad=False)
        t = torch.clamp(t, 1e-5, 1.0)
        # from CosyVoice: since generation at the beginning of the process is harder
        # than the later steps, use a cosine scheduler for the timestep t
        if self.time_scheduler == "cos":
            t = 1 - torch.cos(t * math.pi * 0.5)
        else:
            pass
        return self.loss_t(
            x, x_mask, t, lyric, output_hidden_states=output_hidden_states
        )

    def forward(self, x, x_mask, vocal_mel, output_hidden_states=False):
        cond = self.vocal_mel_proj(vocal_mel)
        return self.compute_loss(x, x_mask, cond, output_hidden_states)

    @torch.no_grad()
    def reverse_diffusion(
        self,
        vocal_mel=None,
        prompt=None,
        right_prompt=None,
        x_mask=None,
        prompt_mask=None,
        right_prompt_mask=None,
        target_len=None,
        n_timesteps=10,
        cfg=1.0,
        rescale_cfg=0.75,
    ):
        h = 1.0 / n_timesteps
        prompt_len = prompt.shape[1] if prompt is not None else 0
        right_prompt_len = right_prompt.shape[1] if right_prompt is not None else 0
        # print(prompt_len, right_prompt_len)
        if vocal_mel is not None:
            target_len = vocal_mel.shape[1]
        elif target_len is None:
            target_len = 1000  # hardcode 50Hz 20s
        else:
            raise ValueError
        full_len = target_len
        target_len = target_len - prompt_len - right_prompt_len

        cond = self.vocal_mel_proj(vocal_mel)

        if x_mask is None:
            x_mask = torch.ones(cond.shape[0], target_len).to(cond.device)
        if prompt_mask is None and prompt is not None:
            prompt_mask = torch.ones(cond.shape[0], prompt_len).to(cond.device)
        if right_prompt_mask is None and right_prompt is not None:
            right_prompt_mask = torch.ones(cond.shape[0], right_prompt_len).to(
                cond.device
            )

        if prompt is not None and right_prompt is not None:
            xt_mask = torch.cat([prompt_mask, x_mask, right_prompt_mask], dim=1)
        elif prompt is not None and right_prompt is None:
            xt_mask = torch.cat([prompt_mask, x_mask], dim=1)
        elif prompt is None and right_prompt is not None:
            xt_mask = torch.cat([x_mask, right_prompt_mask], dim=1)
        else:
            xt_mask = x_mask

        z = torch.randn(
            (cond.shape[0], target_len, self.mel_dim),
            dtype=cond.dtype,
            device=cond.device,
            requires_grad=False,
        )
        xt = z
        # t from 0 to 1: x0 = z ~ N(0, 1)
        for i in range(n_timesteps):
            if prompt is not None and right_prompt is not None:
                xt_input = torch.cat([prompt, xt, right_prompt], dim=1)
            elif prompt is not None and right_prompt is None:
                xt_input = torch.cat([prompt, xt], dim=1)
            elif prompt is None and right_prompt is not None:
                xt_input = torch.cat([xt, right_prompt], dim=1)
            else:
                xt_input = xt
            t = (0 + (i + 0.5) * h) * torch.ones(
                z.shape[0], dtype=z.dtype, device=z.device
            )
            flow_pred = self.diff_estimator(xt_input, t, xt_mask, cond)
            flow_pred = flow_pred[:, prompt_len : prompt_len + target_len, :]
            # cfg

            if cfg > 0:
                uncond_flow_pred = self.diff_estimator(
                    xt_input, t, xt_mask, torch.zeros_like(cond)
                )
                uncond_flow_pred = uncond_flow_pred[
                    :, prompt_len : prompt_len + target_len, :
                ]
                pos_flow_pred_std = flow_pred.std()
                flow_pred_cfg = flow_pred + cfg * (flow_pred - uncond_flow_pred)
                rescale_flow_pred = (
                    flow_pred_cfg * pos_flow_pred_std / flow_pred_cfg.std()
                )
                flow_pred = (
                    rescale_cfg * rescale_flow_pred + (1 - rescale_cfg) * flow_pred_cfg
                )

            dxt = flow_pred * h
            xt = xt + dxt

        return xt
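
For reference, reverse_diffusion above integrates the learned velocity field with fixed steps of size h = 1/n_timesteps, evaluating the model at the midpoint t = (i + 0.5) * h of each step. Below is a toy sketch of that integration rule in isolation, with a stand-in velocity field whose exact flow is known; the velocity function is an assumption for illustration, not the model's estimator:

import torch

def integrate(velocity_fn, x0, n_timesteps=10):
    """Fixed-step integration matching reverse_diffusion's loop: each step
    evaluates the velocity at the interval midpoint and takes an explicit
    step of size h = 1 / n_timesteps."""
    h = 1.0 / n_timesteps
    xt = x0  # start from noise, x0 ~ N(0, I), as in the sampler
    for i in range(n_timesteps):
        t = (i + 0.5) * h
        xt = xt + velocity_fn(xt, t) * h
    return xt

# Toy velocity field for the straight-line (rectified) flow between x0 and
# x1: given xt at time t, v = (x1 - xt) / (1 - t). Integrating from noise
# should land near x1.
x0 = torch.randn(2, 8)
x1 = torch.ones(2, 8)
out = integrate(lambda x, t: (x1 - x) / (1 - t + 1e-5), x0, n_timesteps=50)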
anyaccomp/inference_utils.py ADDED

@@ -0,0 +1,124 @@

import math
import json
import librosa
import torch
import torchaudio
import accelerate
import safetensors
import numpy as np
import os
import yaml

import torchvision
from librosa.feature import chroma_stft

import random

import sys


from anyaccomp.fmt_model import FlowMatchingTransformerConcat
from models.codec.amphion_codec.vocos import Vocos
from models.codec.melvqgan.melspec import MelSpectrogram
from models.codec.coco.rep_coco_model import CocoContentStyle, CocoContent, CocoStyle

from tqdm import tqdm

from utils.util import load_config

import io

from transformers import T5Tokenizer, T5EncoderModel

import warnings


class Sing2SongInferencePipeline:
    def __init__(
        self,
        checkpoint_path,
        cfg_path,
        vocoder_checkpoint_path,
        vocoder_cfg_path,
        device="cuda",
    ):
        self.cfg = load_config(cfg_path)
        self.device = device

        self.checkpoint_path = checkpoint_path
        self._load_model(checkpoint_path)

        self._build_input_model()
        self.vocoder_checkpoint_path = vocoder_checkpoint_path
        self.vocoder_cfg = load_config(vocoder_cfg_path)
        self._build_output_model()
        print("Output model built")

    def _load_model(self, checkpoint_path):
        self.model = FlowMatchingTransformerConcat(
            cfg=self.cfg.model.flow_matching_transformer
        )

        accelerate.load_checkpoint_and_dispatch(self.model, checkpoint_path)
        self.model.eval().to(self.device)
        print(
            f"model Params: {round(sum(p.numel() for p in self.model.parameters() if p.requires_grad)/1e6, 2)}M"
        )
        print(f"Loaded model from {checkpoint_path}")

    def _build_input_model(self):
        self.coco_model = CocoStyle(
            cfg=self.cfg.model.coco, construct_only_for_quantizer=True
        )
        self.coco_model.eval()
        self.coco_model.to(self.device)
        accelerate.load_checkpoint_and_dispatch(
            self.coco_model, self.cfg.model.coco.pretrained_path
        )

    def _build_output_model(self):
        # print(vocoder_checkpoint_path)
        self.vocoder = Vocos(cfg=self.vocoder_cfg.model.vocos)
        accelerate.load_checkpoint_and_dispatch(
            self.vocoder, self.vocoder_checkpoint_path
        )
        self.vocoder = self.vocoder.eval().to(self.device)

    @torch.no_grad()
    @torch.cuda.amp.autocast(dtype=torch.bfloat16)
    def _extract_coco_codec(self, speech):
        """
        Args:
            speech: [B, T]
        Returns:
            codecs: [B, T]. Note that the codecs might not be at 50 Hz!
        """
        target_chroma_dim = self.cfg.model.coco.chromagram_dim

        speech = speech.cpu().numpy().squeeze()

        chromagram = chroma_stft(
            y=speech,
            sr=self.cfg.preprocess.chromagram.sample_rate,
            n_fft=self.cfg.preprocess.chromagram.n_fft,
            hop_length=self.cfg.preprocess.chromagram.hop_size,
            win_length=self.cfg.preprocess.chromagram.win_size,
            n_chroma=target_chroma_dim,
        ).T  # [D, T] -> [T, D]
        chromagram_feats = torch.tensor(chromagram).unsqueeze(0).to(self.device)
        codecs, _ = self.coco_model.quantize(chromagram_feats)
        return codecs

    @torch.no_grad()
    def encode_vocal(self, speech):  # (B, T)
        speech = speech.to(self.device)
        codecs = self._extract_coco_codec(speech)
        return codecs

    @torch.no_grad()
    def _generate_audio(self, mel):
        synthesized_audio = (self.vocoder(mel.transpose(1, 2)).detach().cpu())[0]

        return synthesized_audio
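
A plausible end-to-end call path through this pipeline, using only methods defined above: chroma extraction and quantization, flow-matching sampling, then vocoding. The helper name `accompany` and the sampler settings are illustrative assumptions, not part of the file:

import torch

@torch.no_grad()
def accompany(pipe, vocal_waveform, n_timesteps=32, cfg=2.0):
    # vocal_waveform: (1, T) float tensor at 24 kHz (assumed input format)
    codecs = pipe.encode_vocal(vocal_waveform)  # chromagram -> Coco style codes
    # With use_cond_code=true in config/flow_matching.json, vocal_mel_proj
    # embeds the code sequence itself, so the codes are passed directly as
    # the `vocal_mel` conditioning input of reverse_diffusion.
    mel = pipe.model.reverse_diffusion(
        vocal_mel=codecs, n_timesteps=n_timesteps, cfg=cfg
    )
    return pipe._generate_audio(mel)  # Vocos: mel -> waveform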
anyaccomp/llama_nar.py ADDED

@@ -0,0 +1,667 @@

from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel
import torch
import torch.nn.functional as F
import numpy as np
import os
import torch.nn as nn
from typing import List, Optional, Tuple, Union
import math

from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.llama.modeling_llama import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    apply_rotary_pos_emb,
    Cache,
    repeat_kv,
)


class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :] * 1.0
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


class LlamaAdaptiveRMSNorm(nn.Module):
    def __init__(self, hidden_size=1024, eps=1e-6, dim_cond=1024):
        super().__init__()
        self.to_weight = nn.Linear(dim_cond, hidden_size)
        nn.init.zeros_(self.to_weight.weight)
        nn.init.ones_(self.to_weight.bias)
        self.variance_epsilon = eps
        self._is_hf_initialized = True  # disable automatic init

    def forward(self, hidden_states, cond_embedding):
        input_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

        weight = self.to_weight(cond_embedding)
        if len(weight.shape) == 2:
            weight = weight.unsqueeze(1)

        return (weight * hidden_states).to(input_dtype)


class LlamaNARDecoderLayer(LlamaDecoderLayer):
    def __init__(self, config: LlamaConfig, layer_idx: int):
        """Override to use adaptive layer norm"""
        super().__init__(config, layer_idx)  # init attention, mlp, etc.
        # self.self_attn = LlamaXformersAttention(config=config, layer_idx=layer_idx)

        self.self_attn.is_causal = False  # for flash attn

        self.input_layernorm = LlamaAdaptiveRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size
        )
        self.post_attention_layernorm = LlamaAdaptiveRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size
        )

    # add `cond_embedding` to the forward function
    def forward(
        self,
        hidden_states: torch.Tensor,
        cond_embedding: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        """

        residual = hidden_states

        hidden_states = self.input_layernorm(
            hidden_states, cond_embedding=cond_embedding
        )

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(
            hidden_states, cond_embedding=cond_embedding
        )
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs


class DiffLlamaConcat(LlamaModel):
    def __init__(
        self,
        mel_dim=100,
        hidden_size=1024,
        num_heads=16,
        num_layers=16,
        dropout=0.1,
        ffn_dropout=0.1,
        attention_dropout=0.0,
        config=LlamaConfig(0, 256, 1024, 1, 1),
        flash_attention=False,
    ):
        super().__init__(config)

        self.flash_attention = flash_attention
        self.layers = nn.ModuleList(
            [
                LlamaNARDecoderLayer(
                    LlamaConfig(
                        hidden_size=hidden_size,
                        num_attention_heads=num_heads,
                        max_position_embeddings=4096,
                        intermediate_size=hidden_size * 4,
                        attn_implementation=(
                            "flash_attention_2" if self.flash_attention else "eager"
                        ),
                    ),
                    layer_idx=i,
                )
                for i in range(num_layers)
            ]
        )

        self.norm = LlamaAdaptiveRMSNorm(hidden_size, dim_cond=hidden_size)

        self.diff_step_embedding = SinusoidalPosEmb(hidden_size)
        self.diff_step_mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.SiLU(),
            nn.Linear(hidden_size * 4, hidden_size),
        )

        self.cond_mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.SiLU(),
            nn.Linear(hidden_size * 4, hidden_size),
        )

        self.mel_mlp = nn.Sequential(
            nn.Linear(mel_dim, hidden_size * 4),
            nn.SiLU(),
            nn.Linear(hidden_size * 4, hidden_size),
        )

        self.mel_out_mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.SiLU(),
            nn.Linear(hidden_size * 4, mel_dim),
        )

        for layer in self.layers:
            layer.input_layernorm = LlamaAdaptiveRMSNorm(
                hidden_size, dim_cond=hidden_size
            )
            layer.post_attention_layernorm = LlamaAdaptiveRMSNorm(
                hidden_size, dim_cond=hidden_size
            )

        self.embed_tokens = None

        self.post_init()

        # self.reset_parameters()

    def _prepare_decoder_attention_mask(
        self, attention_mask, input_shape, inputs_embeds, past_key_values_length
    ):
        # create a non-causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None

        def _expand_mask(
            mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
        ):
            """
            Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
            """
            bsz, src_len = mask.size()
            tgt_len = tgt_len if tgt_len is not None else src_len

            expanded_mask = (
                mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
            )

            inverted_mask = 1.0 - expanded_mask

            return inverted_mask.masked_fill(
                inverted_mask.to(torch.bool), torch.finfo(dtype).min
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(
                attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
            ).to(inputs_embeds.device)
            combined_attention_mask = (
                expanded_attn_mask
                if combined_attention_mask is None
                else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask

    def forward(
        self,
        x,
        diffusion_step,
        x_mask,
        cond,
        input_ids: torch.LongTensor = None,  # [num_quant, B, T]
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:

        # retrieve some shape info
        batch_size, seq_length, _ = x.shape

        # condition mlp
        cond_embedding = self.cond_mlp(cond)  # (B, T, C)

        # condition mel
        x = self.mel_mlp(x)

        # diffusion step embedding
        diffusion_step = self.diff_step_embedding(diffusion_step).to(x.device)
        diffusion_step = self.diff_step_mlp(diffusion_step)  # (B, C)
        x = x + cond_embedding

        inputs_embeds = x
        # if self.flash_attention:
        #     attention_mask = None
        # else:
        attention_mask = x_mask

        # assert x_mask.shape == batch_size, seq_length

        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length,
                seq_length + past_key_values_length,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if not self.flash_attention:
            # embed positions
            if attention_mask is None:
                attention_mask = torch.ones(
                    (batch_size, seq_length_with_past),
                    dtype=torch.bool,
                    device=inputs_embeds.device,
                )
            attention_mask = self._prepare_decoder_attention_mask(
                attention_mask,
                (batch_size, seq_length),
                inputs_embeds,
                past_key_values_length,
            )

        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = (
                past_key_values[idx] if past_key_values is not None else None
            )

            if self.gradient_checkpointing and self.training:
                raise NotImplementedError

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, output_attentions, None)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    position_ids,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    # attention_mask=attention_mask if not self.flash_attention else None,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cond_embedding=diffusion_step,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states, cond_embedding=diffusion_step)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None

        hidden_states = self.mel_out_mlp(hidden_states)

        if not output_hidden_states:
            return hidden_states
        else:
            return {
                "hidden_states": hidden_states,
                "all_hidden_states": all_hidden_states,
            }


class DiffLlama(LlamaModel):
    def __init__(
        self,
        mel_dim=100,
        hidden_size=1024,
        num_heads=16,
        num_layers=16,
        dropout=0.1,
        ffn_dropout=0.1,
        attention_dropout=0.0,
        config=LlamaConfig(0, 256, 1024, 1, 1),
        flash_attention=False,
    ):
        super().__init__(config)

        self.flash_attention = flash_attention
        self.layers = nn.ModuleList(
            [
                LlamaNARDecoderLayer(
                    LlamaConfig(
                        hidden_size=hidden_size,
                        num_attention_heads=num_heads,
                        max_position_embeddings=4096,
                        intermediate_size=hidden_size * 4,
                        attn_implementation=(
                            "flash_attention_2" if self.flash_attention else "eager"
                        ),
                        is_causal=False,
                    ),
                    layer_idx=i,
                )
                for i in range(num_layers)
            ]
        )

        self.norm = LlamaAdaptiveRMSNorm(hidden_size, dim_cond=hidden_size)

        self.diff_step_embedding = SinusoidalPosEmb(hidden_size)
        self.diff_step_mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.SiLU(),
            nn.Linear(hidden_size * 4, hidden_size),
        )

        # condition MLP; required by forward() below
        self.cond_mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.SiLU(),
            nn.Linear(hidden_size * 4, hidden_size),
        )

        self.mel_mlp = nn.Sequential(
            nn.Linear(mel_dim, hidden_size * 4),
            nn.SiLU(),
            nn.Linear(hidden_size * 4, hidden_size),
        )

        self.mel_out_mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.SiLU(),
            nn.Linear(hidden_size * 4, mel_dim),
        )

        for layer in self.layers:
            layer.input_layernorm = LlamaAdaptiveRMSNorm(
                hidden_size, dim_cond=hidden_size
            )
            layer.post_attention_layernorm = LlamaAdaptiveRMSNorm(
                hidden_size, dim_cond=hidden_size
            )

        self.embed_tokens = None

        self.post_init()

        # self.reset_parameters()

    def _prepare_decoder_attention_mask(
        self, attention_mask, input_shape, inputs_embeds, past_key_values_length
    ):
        # create a non-causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None

        def _expand_mask(
            mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
        ):
            """
            Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
            """
            bsz, src_len = mask.size()
            tgt_len = tgt_len if tgt_len is not None else src_len

            expanded_mask = (
                mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
            )

            inverted_mask = 1.0 - expanded_mask

            return inverted_mask.masked_fill(
                inverted_mask.to(torch.bool), torch.finfo(dtype).min
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(
                attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
            ).to(inputs_embeds.device)
            combined_attention_mask = (
                expanded_attn_mask
                if combined_attention_mask is None
                else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask

    def forward(
        self,
        x,
        diffusion_step,
        x_mask,
        cond,
        input_ids: torch.LongTensor = None,  # [num_quant, B, T]
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:

        # retrieve some shape info
        batch_size, seq_length, _ = x.shape

        # condition mlp
        cond_embedding = self.cond_mlp(cond)  # (B, T, C)

        # condition mel
        x = self.mel_mlp(x)

        # diffusion step embedding
        diffusion_step = self.diff_step_embedding(diffusion_step).to(x.device)
        diffusion_step = self.diff_step_mlp(diffusion_step)  # (B, C)
        x = x + cond_embedding

        inputs_embeds = x
        attention_mask = x_mask

        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length,
                seq_length + past_key_values_length,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = (
                past_key_values[idx] if past_key_values is not None else None
            )

            if self.gradient_checkpointing and self.training:
                raise NotImplementedError

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, output_attentions, None)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    position_ids,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cond_embedding=diffusion_step,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states, cond_embedding=diffusion_step)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None

        hidden_states = self.mel_out_mlp(hidden_states)
        if not output_hidden_states:
            return hidden_states
        else:
            return {
                "hidden_states": hidden_states,
                "all_hidden_states": all_hidden_states,
            }
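
One property worth noting in LlamaAdaptiveRMSNorm above: to_weight is initialized with zero weights and ones bias, so at initialization the conditioning is a no-op and the module reduces to a plain RMSNorm. A quick sketch checking that claim, using the class as defined above:

import torch

norm = LlamaAdaptiveRMSNorm(hidden_size=8, dim_cond=8)
h = torch.randn(2, 5, 8)
cond = torch.randn(2, 8)

out = norm(h, cond)
# At init, to_weight(cond) is all ones (zero weight, ones bias), so the
# adaptive norm coincides with a vanilla RMSNorm of h.
rms = h * torch.rsqrt(h.pow(2).mean(-1, keepdim=True) + 1e-6)
assert torch.allclose(out, rms, atol=1e-5)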
config/flow_matching.json ADDED

@@ -0,0 +1,74 @@

{
    "model_type": "Sing2SongNoText",
    "preprocess": {
        "use_mel": true,
        "sample_rate": 24000,
        "n_fft": 1920,
        "num_mels": 128,
        "sampling_rate": 24000,
        "hop_size": 480,
        "hop_size_vocal": 480,
        "hop_size_accompaniment": 480,
        "win_size": 1920,
        "fmin": 0,
        "fmax": 12000,
        "mel_var": 8.14,
        "mel_mean": -4.92,

        "chromagram": {
            "naive": true,
            "hop_size": 480,
            "sample_rate": 24000,
            "n_fft": 1920,
            "num_mels": 128,
            "win_size": 1920,
            "fmin": 0,
            "fmax": 12000,
            "mel_var": 8.14,
            "mel_mean": -4.92,
            "f0_fmin": 50.0,
            "f0_fmax": 1100.0
        }
    },
    "model": {
        "flow_matching_transformer": {
            "vocab_size": 512,
            "use_cond_code": true,
            "mel_dim": 128,
            "cond_dim": 768,
            "hidden_size": 1024,
            "num_layers": 10,
            "num_heads": 16,
            "cfg_scale": 0.2,
            "prompt_prob": 0.,
            "use_pretrained_model": false,
            "sigma": 1e-5,
            "time_scheduler": "cos",
            "repa_loss": {
                "enable": true,
                "weight": 0.5,
                "repa_layer": 4,
            },
            "flash_attention": false,
        },
        "coco": {
            "coco_type": "style", // content, style, or content_style
            "downsample_rate": 1, // The original frame rate is 50 Hz, downsample to 6.25 Hz
            "codebook_size": 512,
            "hidden_size": 1024, // Representations Dim
            "codebook_dim": 8,
            "encoder": {
                "vocos_dim": 384,
                "vocos_intermediate_dim": 2048,
                "vocos_num_layers": 12,
            },
            "decoder": {
                "vocos_dim": 384,
                "vocos_intermediate_dim": 2048,
                "vocos_num_layers": 12,
            },
            "chromagram_dim": 24,
            "pretrained_path": "./pretrained/vq"
        },
    },
}
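
Note that this file is not strict JSON: the coco block carries // comments and several objects have trailing commas, so it relies on the lenient parsing behind load_config in utils/util.py (the loader imported by anyaccomp/inference_utils.py above). A small sketch of reading it, assuming load_config returns attribute-style access as the model code implies:

# Sketch only; assumes load_config accepts this commented JSON and
# returns an object with attribute access, as the usage above implies.
from utils.util import load_config

cfg = load_config("config/flow_matching.json")
fmt_cfg = cfg.model.flow_matching_transformer
assert fmt_cfg.use_cond_code and fmt_cfg.mel_dim == 128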
config/vocoder.json ADDED

@@ -0,0 +1,52 @@

{
    "model_type": "Vocoder",
    "preprocess": {
        "hop_size": 480,
        "sample_rate": 24000,
        "max_length": 36000,
        "n_fft": 1920,
        "num_mels": 128,
        "win_size": 1920,
        "fmin": 0,
        "fmax": 12000,
        "mel_var": 8.14,
        "mel_mean": -4.92,
        "processed_dir": "",
        "valid_file": "valid.json",
        "train_file": "train.json",
        "use_phone_cond": false,
        "use_emilia_101k": false
    },
    "model": {
        "vocos": {
            "input_channels": 128,
            "dim": 1024,
            "intermediate_dim": 4096,
            "num_layers": 30,
            "n_fft": 1920,
            "hop_size": 480,
            "padding": "same"
        },
        "period_gan": {
            "max_downsample_channels": 1024,
            "channels": 64,
            "channel_increasing_factor": 2
        },
        "spec_gan": {
            "stft_params": {
                "fft_sizes": [128, 256, 512, 1024, 2048],
                "hop_sizes": [32, 64, 128, 256, 512],
                "win_lengths": [128, 256, 512, 1024, 2048],
                "window": "hann_window"
            },
            "in_channels": 1,
            "out_channels": 1,
            "channels": 64,
            "kernel_sizes": [5, 3],
            "max_downsample_channels": 1024,
            "down_scales": [2, 2, 2],
            "use_weight_norm": true,
            "use_complex": false
        }
    },
}
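
The preprocessing numbers in the two configs are mutually consistent and explain the hardcoded target_len = 1000 ("50Hz 20s") fallback in reverse_diffusion: a hop of 480 samples at 24000 Hz yields 50 mel frames per second, so 20 seconds is 1000 frames. As a worked check:

sample_rate = 24000  # from both configs
hop_size = 480       # from both configs

frame_rate = sample_rate / hop_size  # 50 mel frames per second
assert frame_rate == 50.0

# reverse_diffusion falls back to target_len = 1000 when neither a
# condition nor a length is given, i.e. 20 seconds at this frame rate.
assert 20 * frame_rate == 1000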
example/gradio/example1.mp3 ADDED

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2860cf5b49b861f0770805c8cdda5b61276abc3931bb11140a5d6fa451418130
size 384580

example/gradio/example2.wav ADDED

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4a16962f991ca69af95c79f39050017bd14e6dfd2c11b7547a10cd2b123b5ea6
size 2646044

example/gradio/example3.wav ADDED

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7a96d209d00b50cb0b2ea0d985d4d94a5ea29b20874e9f91f4ed15235d4018ec
size 2646044
models/__init__.py
ADDED
File without changes

models/codec/__init__.py
ADDED
File without changes

models/codec/amphion_codec/.DS_Store
ADDED
Binary file (6.15 kB).
models/codec/amphion_codec/quantize/__init__.py
ADDED
@@ -0,0 +1,11 @@
# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from models.codec.amphion_codec.quantize.factorized_vector_quantize import (
    FactorizedVectorQuantize,
)
from models.codec.amphion_codec.quantize.vector_quantize import VectorQuantize
from models.codec.amphion_codec.quantize.lookup_free_quantize import LookupFreeQuantize
from models.codec.amphion_codec.quantize.residual_vq import ResidualVQ
models/codec/amphion_codec/quantize/factorized_vector_quantize.py
ADDED
@@ -0,0 +1,150 @@
# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils import weight_norm


def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))


class FactorizedVectorQuantize(nn.Module):
    def __init__(
        self,
        input_dim,
        codebook_size,
        codebook_dim,
        commitment=0.005,
        codebook_loss_weight=1.0,
        use_l2_normlize=True,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.commitment = commitment
        self.codebook_loss_weight = codebook_loss_weight
        self.use_l2_normlize = use_l2_normlize

        if self.input_dim != self.codebook_dim:
            self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1)
            self.out_project = WNConv1d(
                self.codebook_dim, self.input_dim, kernel_size=1
            )
        else:
            self.in_project = nn.Identity()
            self.out_project = nn.Identity()

        self.codebook = nn.Embedding(self.codebook_size, self.codebook_dim)

    def forward(self, z):
        """
        Parameters
        ----------
        z: torch.Tensor[B x D x T]

        Returns
        -------
        z_q: torch.Tensor[B x D x T]
            Quantized continuous representation of input
        commit_loss: Tensor[B]
            Commitment loss to train encoder to predict vectors closer to codebook entries
        codebook_loss: Tensor[B]
            Codebook loss to update the codebook
        indices: torch.Tensor[B x T]
            Codebook indices (quantized discrete representation of input)
        z_e: torch.Tensor[B x D x T]
            Projected latents (continuous representation of input before quantization)
        """

        # Factorized codes project input into low-dimensional space if self.input_dim != self.codebook_dim
        z_e = self.in_project(z)
        z_q, indices = self.decode_latents(z_e)

        # Compute commitment loss and codebook loss
        if self.training:
            commit_loss = (
                F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
                * self.commitment
            )
            codebook_loss = (
                F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
                * self.codebook_loss_weight
            )
        else:
            commit_loss = torch.zeros(z.shape[0], device=z.device)
            codebook_loss = torch.zeros(z.shape[0], device=z.device)

        z_q = z_e + (z_q - z_e).detach()

        z_q = self.out_project(z_q)

        return z_q, commit_loss, codebook_loss, indices, z_e

    def embed_code(self, embed_id):
        return F.embedding(embed_id, self.codebook.weight)

    def decode_code(self, embed_id):
        return self.embed_code(embed_id).transpose(1, 2)

    def decode_latents(self, latents):
        encodings = rearrange(latents, "b d t -> (b t) d")
        codebook = self.codebook.weight

        # L2 normalize encodings and codebook
        if self.use_l2_normlize:
            encodings = F.normalize(encodings)
            codebook = F.normalize(codebook)

        # Compute euclidean distance between encodings and codebook,
        # if use_l2_normlize is True, the distance is equal to cosine distance
        dist = (
            encodings.pow(2).sum(1, keepdim=True)
            - 2 * encodings @ codebook.t()
            + codebook.pow(2).sum(1, keepdim=True).t()
        )
        indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
        z_q = self.decode_code(indices)

        return z_q, indices

    def vq2emb(self, vq, out_proj=True):
        emb = self.decode_code(vq)
        if out_proj:
            emb = self.out_project(emb)
        return emb

    def latent2dist(self, latents):
        encodings = rearrange(latents, "b d t -> (b t) d")
        codebook = self.codebook.weight

        # L2 normalize encodings and codebook
        if self.use_l2_normlize:
            encodings = F.normalize(encodings)
            codebook = F.normalize(codebook)

        # Compute euclidean distance between encodings and codebook,
        # if use_l2_normlize is True, the distance is equal to cosine distance
        dist = (
            encodings.pow(2).sum(1, keepdim=True)
            - 2 * encodings @ codebook.t()
            + codebook.pow(2).sum(1, keepdim=True).t()
        )  # (b*t, k)

        indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
        dist = rearrange(dist, "(b t) k -> b t k", b=latents.size(0))
        z_q = self.decode_code(indices)

        return -dist, indices, z_q
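A minimal usage sketch for the quantizer above (illustrative dimensions; the import path is the one exported by the package __init__.py): codes live in an 8-d factorized space while the model stream stays 256-d.

import torch
from models.codec.amphion_codec.quantize import FactorizedVectorQuantize

fvq = FactorizedVectorQuantize(input_dim=256, codebook_size=1024, codebook_dim=8)
z = torch.randn(2, 256, 50)                   # (B, D, T)
z_q, commit_loss, codebook_loss, indices, z_e = fvq(z)
print(z_q.shape, indices.shape, z_e.shape)    # (2, 256, 50) (2, 50) (2, 8, 50)
emb = fvq.vq2emb(indices)                     # decode indices back to embeddings
print(emb.shape)                              # (2, 256, 50)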
models/codec/amphion_codec/quantize/lookup_free_quantize.py
ADDED
@@ -0,0 +1,77 @@
# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils import weight_norm


def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))


class LookupFreeQuantize(nn.Module):
    def __init__(
        self,
        input_dim,
        codebook_size,
        codebook_dim,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim

        assert 2**codebook_dim == codebook_size

        if self.input_dim != self.codebook_dim:
            self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1)
            self.out_project = WNConv1d(
                self.codebook_dim, self.input_dim, kernel_size=1
            )
        else:
            self.in_project = nn.Identity()
            self.out_project = nn.Identity()

    def forward(self, z):
        z_e = self.in_project(z)
        z_e = F.sigmoid(z_e)

        z_q = z_e + (torch.round(z_e) - z_e).detach()

        z_q = self.out_project(z_q)

        commit_loss = torch.zeros(z.shape[0], device=z.device)
        codebook_loss = torch.zeros(z.shape[0], device=z.device)

        bits = (
            2
            ** torch.arange(self.codebook_dim, device=z.device)
            .unsqueeze(0)
            .unsqueeze(-1)
            .long()
        )  # (1, d, 1)
        indices = (torch.round(z_e.clone().detach()).long() * bits).sum(1).long()

        return z_q, commit_loss, codebook_loss, indices, z_e

    def vq2emb(self, vq, out_proj=True):
        emb = torch.zeros(
            vq.shape[0], self.codebook_dim, vq.shape[-1], device=vq.device
        )  # (B, d, T)
        for i in range(self.codebook_dim):
            emb[:, i, :] = (vq % 2).float()
            vq = vq // 2
        if out_proj:
            emb = self.out_project(emb)
        return emb
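To make the index packing above concrete: each of the codebook_dim sigmoid channels is rounded to {0, 1} and read as one bit, channel i contributing 2**i to the code index. A small round-trip check with illustrative dimensions:

import torch
from models.codec.amphion_codec.quantize import LookupFreeQuantize

lfq = LookupFreeQuantize(input_dim=64, codebook_size=256, codebook_dim=8)  # 2**8 == 256
z = torch.randn(1, 64, 10)
z_q, _, _, indices, z_e = lfq(z)                 # z_e is the post-sigmoid latent
bits = torch.round(z_e)                          # (1, 8, 10), entries in {0, 1}
weights = 2 ** torch.arange(8).view(1, 8, 1)     # bit weights per channel
assert torch.equal((bits.long() * weights).sum(1), indices)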
models/codec/amphion_codec/quantize/residual_vq.py
ADDED
@@ -0,0 +1,177 @@
# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils import weight_norm

from models.codec.amphion_codec.quantize.factorized_vector_quantize import (
    FactorizedVectorQuantize,
)
from models.codec.amphion_codec.quantize.vector_quantize import VectorQuantize
from models.codec.amphion_codec.quantize.lookup_free_quantize import LookupFreeQuantize


class ResidualVQ(nn.Module):
    """
    Introduced in SoundStream: An end2end neural audio codec
    https://arxiv.org/abs/2107.03312
    """

    def __init__(
        self,
        input_dim: int = 256,
        num_quantizers: int = 8,
        codebook_size: int = 1024,
        codebook_dim: int = 256,
        quantizer_type: str = "vq",  # "vq" or "fvq" or "lfq"
        quantizer_dropout: float = 0.5,
        **kwargs,
    ):
        super().__init__()

        self.input_dim = input_dim
        self.num_quantizers = num_quantizers
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.quantizer_type = quantizer_type
        self.quantizer_dropout = quantizer_dropout

        if quantizer_type == "vq":
            VQ = VectorQuantize
        elif quantizer_type == "fvq":
            VQ = FactorizedVectorQuantize
        elif quantizer_type == "lfq":
            VQ = LookupFreeQuantize
        else:
            raise ValueError(f"Unknown quantizer type {quantizer_type}")

        self.quantizers = nn.ModuleList(
            [
                VQ(
                    input_dim=input_dim,
                    codebook_size=codebook_size,
                    codebook_dim=codebook_dim,
                    **kwargs,
                )
                for _ in range(num_quantizers)
            ]
        )

    def forward(self, z, n_quantizers: int = None):
        """
        Parameters
        ----------
        z : Tensor[B x D x T]
        n_quantizers : int, optional
            No. of quantizers to use
            (n_quantizers < self.n_codebooks ex: for quantizer dropout)
            Note: if `self.quantizer_dropout` is True, this argument is ignored
            when in training mode, and a random number of quantizers is used.
        Returns
        -------
        "quantized_out" : Tensor[B x D x T]
            Quantized continuous representation of input
        "all_indices" : Tensor[N x B x T]
            Codebook indices for each codebook
            (quantized discrete representation of input)
        "all_commit_losses" : Tensor[N]
        "all_codebook_losses" : Tensor[N]
        "all_quantized" : Tensor[N x B x D x T]
        """

        quantized_out = 0.0
        residual = z

        all_commit_losses = []
        all_codebook_losses = []
        all_indices = []
        all_quantized = []

        if n_quantizers is None:
            n_quantizers = self.num_quantizers

        if self.training:
            n_quantizers = torch.ones((z.shape[0],)) * self.num_quantizers + 1
            dropout = torch.randint(1, self.num_quantizers + 1, (z.shape[0],))
            n_dropout = int(z.shape[0] * self.quantizer_dropout)
            n_quantizers[:n_dropout] = dropout[:n_dropout]
            n_quantizers = n_quantizers.to(z.device)

        for i, quantizer in enumerate(self.quantizers):
            if self.training is False and i >= n_quantizers:
                break

            z_q_i, commit_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
                residual
            )

            # Create mask to apply quantizer dropout
            mask = (
                torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
            )
            quantized_out = quantized_out + z_q_i * mask[:, None, None]
            residual = residual - z_q_i

            commit_loss_i = (commit_loss_i * mask).mean()
            codebook_loss_i = (codebook_loss_i * mask).mean()

            all_commit_losses.append(commit_loss_i)
            all_codebook_losses.append(codebook_loss_i)
            all_indices.append(indices_i)
            all_quantized.append(z_q_i)

        all_commit_losses, all_codebook_losses, all_indices, all_quantized = map(
            torch.stack,
            (all_commit_losses, all_codebook_losses, all_indices, all_quantized),
        )

        return (
            quantized_out,
            all_indices,
            all_commit_losses,
            all_codebook_losses,
            all_quantized,
        )

    def vq2emb(self, vq, n_quantizers=None):
        quantized_out = 0.0
        if n_quantizers is None:
            n_quantizers = self.num_quantizers
        for idx, quantizer in enumerate(self.quantizers):
            if idx >= n_quantizers:
                break
            quantized_out += quantizer.vq2emb(vq[idx])
        return quantized_out

    def latent2dist(self, z, n_quantizers=None):
        quantized_out = 0.0
        residual = z

        all_dists = []
        all_indices = []

        if n_quantizers is None:
            n_quantizers = self.num_quantizers

        for i, quantizer in enumerate(self.quantizers):
            if self.training is False and i >= n_quantizers:
                break
            dist_i, indices_i, z_q_i = quantizer.latent2dist(residual)
            all_dists.append(dist_i)
            all_indices.append(indices_i)

            quantized_out = quantized_out + z_q_i
            residual = residual - z_q_i

        all_dists = torch.stack(all_dists)
        all_indices = torch.stack(all_indices)

        return all_dists, all_indices
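A usage sketch of the residual stack in eval mode (illustrative shapes; in training mode the code above resamples n_quantizers per example to implement quantizer dropout). Each quantizer codes the residual left by the previous one, and the outputs stack one entry per used codebook:

import torch
from models.codec.amphion_codec.quantize import ResidualVQ

rvq = ResidualVQ(input_dim=256, num_quantizers=8, codebook_size=1024,
                 codebook_dim=256, quantizer_type="vq").eval()
z = torch.randn(2, 256, 50)
with torch.no_grad():
    z_q, indices, commit, codebook, quantized = rvq(z, n_quantizers=4)
print(z_q.shape, indices.shape)  # (2, 256, 50), (4, 2, 50): 4 codebooks used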
models/codec/amphion_codec/quantize/vector_quantize.py
ADDED
@@ -0,0 +1,401 @@
# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from torch.nn.utils import weight_norm


def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))


def l2norm(t):
    return F.normalize(t, p=2, dim=-1)


def ema_inplace(moving_avg, new, decay):
    moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))


def laplace_smoothing(x, n_categories, eps=1e-5):
    return (x + eps) / (x.sum() + n_categories * eps)


def sample_vectors(samples, num):
    num_samples, device = samples.shape[0], samples.device

    if num_samples >= num:
        indices = torch.randperm(num_samples, device=device)[:num]
    else:
        indices = torch.randint(0, num_samples, (num,), device=device)

    return samples[indices]


def kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False):
    dim, dtype, device = samples.shape[-1], samples.dtype, samples.device

    means = sample_vectors(samples, num_clusters)

    for _ in range(num_iters):
        if use_cosine_sim:
            dists = samples @ means.t()
        else:
            diffs = rearrange(samples, "n d -> n () d") - rearrange(
                means, "c d -> () c d"
            )
            dists = -(diffs**2).sum(dim=-1)

        buckets = dists.max(dim=-1).indices
        bins = torch.bincount(buckets, minlength=num_clusters)
        zero_mask = bins == 0
        bins_min_clamped = bins.masked_fill(zero_mask, 1)

        new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
        new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
        new_means = new_means / bins_min_clamped[..., None]

        if use_cosine_sim:
            new_means = l2norm(new_means)

        means = torch.where(zero_mask[..., None], means, new_means)

    return means, bins


class EuclideanCodebook(nn.Module):
    def __init__(
        self,
        dim,
        codebook_size,
        kmeans_init=False,
        kmeans_iters=10,
        decay=0.8,
        eps=1e-5,
        threshold_ema_dead_code=2,
        weight_init=False,
    ):
        super().__init__()

        self.decay = decay
        init_fn = torch.randn if not weight_init else torch.zeros
        embed = init_fn(codebook_size, dim)

        if weight_init:
            nn.init.uniform_(embed, -1 / codebook_size, 1 / codebook_size)

        self.codebook_size = codebook_size
        self.kmeans_iters = kmeans_iters
        self.eps = eps
        self.threshold_ema_dead_code = threshold_ema_dead_code

        self.register_buffer(
            "initted", torch.Tensor([not kmeans_init])
        )  # if kmeans_init is True, then initted is False; otherwise, initted is True
        self.register_buffer("cluster_size", torch.zeros(codebook_size))
        self.register_buffer("embed", embed)
        self.register_buffer("embed_avg", embed.clone())

    def init_embed_(self, data):
        embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
        self.embed.data.copy_(embed)
        self.embed_avg.data.copy_(embed)
        self.cluster_size.data.copy_(cluster_size)
        self.initted.data.copy_(torch.Tensor([True]))

    def replace(self, samples, mask):
        modified_codebook = torch.where(
            mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
        )
        self.embed.data.copy_(modified_codebook)

    def expire_codes_(self, batch_samples):
        if self.threshold_ema_dead_code == 0:
            return

        expired_codes = self.cluster_size < self.threshold_ema_dead_code
        if not torch.any(expired_codes):
            return
        batch_samples = rearrange(batch_samples, "... d -> (...) d")
        self.replace(batch_samples, mask=expired_codes)

    def forward(self, x):
        shape, dtype = x.shape, x.dtype
        flatten = rearrange(x, "... d -> (...) d")
        embed = self.embed.t()  # (codebook_size, dim) -> (dim, codebook_size)

        if not self.initted:
            self.init_embed_(flatten)

        dist = -(
            flatten.pow(2).sum(1, keepdim=True)
            - 2 * flatten @ embed
            + embed.pow(2).sum(0, keepdim=True)
        )

        embed_ind = dist.max(dim=-1).indices
        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
        embed_ind = embed_ind.view(*shape[:-1])
        quantize = F.embedding(embed_ind, self.embed)

        if self.training:
            ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
            embed_sum = (
                flatten.t() @ embed_onehot
            )  # (dim, ...) @ (..., codebook_size) -> (dim, codebook_size)
            ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
            cluster_size = (
                laplace_smoothing(self.cluster_size, self.codebook_size, self.eps)
                * self.cluster_size.sum()
            )
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
            self.embed.data.copy_(embed_normalized)
            self.expire_codes_(x)

        return quantize, embed_ind

    def vq2emb(self, vq):
        quantize = F.embedding(vq, self.embed)
        return quantize

    def latent2dist(self, x):
        shape, dtype = x.shape, x.dtype
        flatten = rearrange(x, "... d -> (...) d")
        embed = self.embed.t()  # (codebook_size, dim) -> (dim, codebook_size)

        if not self.initted:
            self.init_embed_(flatten)

        dist = -(
            flatten.pow(2).sum(1, keepdim=True)
            - 2 * flatten @ embed
            + embed.pow(2).sum(0, keepdim=True)
        )

        embed_ind = dist.max(dim=-1).indices
        embed_ind = embed_ind.view(*shape[:-1])
        quantize = F.embedding(embed_ind, self.embed)

        dist = dist.view(*shape[:-1], -1)

        return dist, embed_ind, quantize


class SimpleCodebook(nn.Module):
    def __init__(
        self,
        dim,
        codebook_size,
        use_l2_normlize=False,
    ):
        super().__init__()

        self.dim = dim
        self.codebook_size = codebook_size
        self.use_l2_normlize = use_l2_normlize

        self.embed = nn.Embedding(self.codebook_size, self.dim)

    def forward(self, x):
        shape, dtype = x.shape, x.dtype
        flatten = rearrange(x, "... d -> (...) d")
        embed = self.embed.weight.t()  # (codebook_size, dim) -> (dim, codebook_size)

        if self.use_l2_normlize:
            flatten = F.normalize(flatten)
            embed = F.normalize(embed)

        dist = -(
            flatten.pow(2).sum(1, keepdim=True)
            - 2 * flatten @ embed
            + embed.pow(2).sum(0, keepdim=True)
        )

        embed_ind = dist.max(dim=-1).indices
        embed_ind = embed_ind.view(*shape[:-1])
        quantize = F.embedding(embed_ind, self.embed.weight)

        return quantize, embed_ind

    def vq2emb(self, vq):
        quantize = F.embedding(vq, self.embed.weight)
        return quantize

    def latent2dist(self, x):
        shape, dtype = x.shape, x.dtype
        flatten = rearrange(x, "... d -> (...) d")
        embed = self.embed.weight.t()  # (codebook_size, dim) -> (dim, codebook_size)

        if self.use_l2_normlize:
            flatten = F.normalize(flatten)
            embed = F.normalize(embed)

        dist = -(
            flatten.pow(2).sum(1, keepdim=True)
            - 2 * flatten @ embed
            + embed.pow(2).sum(0, keepdim=True)
        )

        embed_ind = dist.max(dim=-1).indices
        embed_ind = embed_ind.view(*shape[:-1])
        quantize = F.embedding(embed_ind, self.embed.weight)

        dist = dist.view(*shape[:-1], -1)

        return dist, embed_ind, quantize


class VectorQuantize(nn.Module):
    """Vector quantization and factorized vector quantization implementation.

    Args:
        input_dim (int): Dimension of input.
        codebook_size (int): Codebook size.
        codebook_dim (int): Codebook dimension. We suggest codebook_dim = input_dim
            when codebook_type == "euclidean"; if you want to use factorized
            vector quantization, use a small codebook_dim (e.g. 8 or 32).
        commitment (float): Weight for commitment loss.
        use_l2_normlize (bool): Whether to use l2-normalized codes for factorized vector
            quantization; we suggest setting it to True when using factorized vector quantization.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified
            threshold with a randomly selected vector from the current batch.
    """

    def __init__(
        self,
        input_dim,
        codebook_size,
        codebook_dim,
        commitment=0.005,
        codebook_loss_weight=1.0,
        use_l2_normlize=False,
        codebook_type="euclidean",  # "euclidean" or "simple"
        kmeans_init=False,
        kmeans_iters=10,
        decay=0.8,
        eps=1e-5,
        threshold_ema_dead_code=2,
        weight_init=False,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.commitment = commitment
        self.codebook_loss_weight = codebook_loss_weight
        self.use_l2_normlize = use_l2_normlize
        self.codebook_type = codebook_type
        self.kmeans_init = kmeans_init
        self.kmeans_iters = kmeans_iters
        self.decay = decay
        self.eps = eps
        self.threshold_ema_dead_code = threshold_ema_dead_code
        self.weight_init = weight_init

        if self.input_dim != self.codebook_dim:
            self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1)
            self.out_project = WNConv1d(
                self.codebook_dim, self.input_dim, kernel_size=1
            )
        else:
            self.in_project = nn.Identity()
            self.out_project = nn.Identity()

        if self.codebook_type == "euclidean":
            self.codebook = EuclideanCodebook(
                self.codebook_dim,
                codebook_size=self.codebook_size,
                kmeans_init=self.kmeans_init,
                kmeans_iters=self.kmeans_iters,
                decay=self.decay,
                eps=self.eps,
                threshold_ema_dead_code=self.threshold_ema_dead_code,
                weight_init=self.weight_init,
            )
        elif self.codebook_type == "simple":
            self.codebook = SimpleCodebook(
                self.codebook_dim,
                codebook_size=self.codebook_size,
                use_l2_normlize=self.use_l2_normlize,
            )
        else:
            raise NotImplementedError(
                f"codebook_type {self.codebook_type} is not implemented!"
            )

    def forward(self, z):
        """
        Parameters
        ----------
        z: torch.Tensor[B x D x T]

        Returns
        -------
        z_q: torch.Tensor[B x D x T]
            Quantized continuous representation of input
        commit_loss: Tensor[B]
            Commitment loss to train encoder to predict vectors closer to codebook entries
        codebook_loss: Tensor[B]
            Codebook loss to update the codebook
        indices: torch.Tensor[B x T]
            Codebook indices (quantized discrete representation of input)
        z_e: torch.Tensor[B x D x T]
            Projected latents (continuous representation of input before quantization)
        """

        # Factorized codes project input into low-dimensional space if self.input_dim != self.codebook_dim
        z_e = self.in_project(z)
        z_q, indices = self.decode_latents(z_e)

        # Compute commitment loss and codebook loss
        if self.training:
            commit_loss = (
                F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
                * self.commitment
            )
            codebook_loss = (
                F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
                * self.codebook_loss_weight
            )
        else:
            commit_loss = torch.zeros(z.shape[0], device=z.device)
            codebook_loss = torch.zeros(z.shape[0], device=z.device)

        z_q = z_e + (z_q - z_e).detach()

        z_q = self.out_project(z_q)

        return z_q, commit_loss, codebook_loss, indices, z_e

    def decode_latents(self, latents):
        encodings = rearrange(latents, "b d t -> b t d")
        z_q, indices = self.codebook(encodings)
        z_q = z_q.transpose(1, 2)
        return z_q, indices

    def vq2emb(self, vq, out_proj=True):
        emb = self.codebook.vq2emb(vq)
        emb = emb.transpose(1, 2)
        if out_proj:
            emb = self.out_project(emb)
        return emb

    def latent2dist(self, latents):
        latents = rearrange(latents, "b d t -> b t d")
        dist, embed_ind, quantize = self.codebook.latent2dist(latents)
        return dist, embed_ind, quantize.transpose(1, 2)
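The line `z_q = z_e + (z_q - z_e).detach()` above is the straight-through estimator: in the backward pass the quantizer behaves like the identity, so encoder gradients survive the non-differentiable code lookup. A quick check with illustrative sizes (eval mode, so the EMA codebook update is skipped):

import torch
from models.codec.amphion_codec.quantize import VectorQuantize

vq = VectorQuantize(input_dim=32, codebook_size=64, codebook_dim=32).eval()
z = torch.randn(1, 32, 20, requires_grad=True)
z_q, commit_loss, codebook_loss, indices, z_e = vq(z)
z_q.sum().backward()
print(z.grad.abs().sum() > 0)  # tensor(True): gradients pass straight through the codes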
models/codec/amphion_codec/vocos.py
ADDED
|
@@ -0,0 +1,881 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Amphion.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from typing import Optional, Tuple
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import scipy
|
| 10 |
+
import torch
|
| 11 |
+
from torch import nn, view_as_real, view_as_complex
|
| 12 |
+
from torch import nn
|
| 13 |
+
from torch.nn.utils import weight_norm, remove_weight_norm
|
| 14 |
+
from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz
|
| 15 |
+
import librosa
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
|
| 19 |
+
"""
|
| 20 |
+
Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values.
|
| 21 |
+
|
| 22 |
+
Args:
|
| 23 |
+
x (Tensor): Input tensor.
|
| 24 |
+
clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7.
|
| 25 |
+
|
| 26 |
+
Returns:
|
| 27 |
+
Tensor: Element-wise logarithm of the input tensor with clipping applied.
|
| 28 |
+
"""
|
| 29 |
+
return torch.log(torch.clip(x, min=clip_val))
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def symlog(x: torch.Tensor) -> torch.Tensor:
|
| 33 |
+
return torch.sign(x) * torch.log1p(x.abs())
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def symexp(x: torch.Tensor) -> torch.Tensor:
|
| 37 |
+
return torch.sign(x) * (torch.exp(x.abs()) - 1)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class STFT(nn.Module):
|
| 41 |
+
def __init__(
|
| 42 |
+
self,
|
| 43 |
+
n_fft: int,
|
| 44 |
+
hop_length: int,
|
| 45 |
+
win_length: int,
|
| 46 |
+
center=True,
|
| 47 |
+
):
|
| 48 |
+
super().__init__()
|
| 49 |
+
self.center = center
|
| 50 |
+
self.n_fft = n_fft
|
| 51 |
+
self.hop_length = hop_length
|
| 52 |
+
self.win_length = win_length
|
| 53 |
+
window = torch.hann_window(win_length)
|
| 54 |
+
self.register_buffer("window", window)
|
| 55 |
+
|
| 56 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 57 |
+
# x: (B, T * hop_length)
|
| 58 |
+
|
| 59 |
+
if not self.center:
|
| 60 |
+
pad = self.win_length - self.hop_length
|
| 61 |
+
x = torch.nn.functional.pad(x, (pad // 2, pad // 2), mode="reflect")
|
| 62 |
+
|
| 63 |
+
stft_spec = torch.stft(
|
| 64 |
+
x,
|
| 65 |
+
self.n_fft,
|
| 66 |
+
hop_length=self.hop_length,
|
| 67 |
+
win_length=self.win_length,
|
| 68 |
+
window=self.window,
|
| 69 |
+
center=self.center,
|
| 70 |
+
return_complex=False,
|
| 71 |
+
) # (B, n_fft // 2 + 1, T, 2)
|
| 72 |
+
|
| 73 |
+
rea = stft_spec[:, :, :, 0] # (B, n_fft // 2 + 1, T, 2)
|
| 74 |
+
imag = stft_spec[:, :, :, 1] # (B, n_fft // 2 + 1, T, 2)
|
| 75 |
+
|
| 76 |
+
log_mag = torch.log(
|
| 77 |
+
torch.abs(torch.sqrt(torch.pow(rea, 2) + torch.pow(imag, 2))) + 1e-5
|
| 78 |
+
) # (B, n_fft // 2 + 1, T)
|
| 79 |
+
phase = torch.atan2(imag, rea) # (B, n_fft // 2 + 1, T)
|
| 80 |
+
|
| 81 |
+
return log_mag, phase
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class ISTFT(nn.Module):
|
| 85 |
+
"""
|
| 86 |
+
Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
|
| 87 |
+
windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
|
| 88 |
+
See issue: https://github.com/pytorch/pytorch/issues/62323
|
| 89 |
+
Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
|
| 90 |
+
The NOLA constraint is met as we trim padded samples anyway.
|
| 91 |
+
|
| 92 |
+
Args:
|
| 93 |
+
n_fft (int): Size of Fourier transform.
|
| 94 |
+
hop_length (int): The distance between neighboring sliding window frames.
|
| 95 |
+
win_length (int): The size of window frame and STFT filter.
|
| 96 |
+
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
|
| 97 |
+
"""
|
| 98 |
+
|
| 99 |
+
def __init__(
|
| 100 |
+
self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"
|
| 101 |
+
):
|
| 102 |
+
super().__init__()
|
| 103 |
+
if padding not in ["center", "same"]:
|
| 104 |
+
raise ValueError("Padding must be 'center' or 'same'.")
|
| 105 |
+
self.padding = padding
|
| 106 |
+
self.n_fft = n_fft
|
| 107 |
+
self.hop_length = hop_length
|
| 108 |
+
self.win_length = win_length
|
| 109 |
+
window = torch.hann_window(win_length)
|
| 110 |
+
self.register_buffer("window", window)
|
| 111 |
+
|
| 112 |
+
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
| 113 |
+
"""
|
| 114 |
+
Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
|
| 115 |
+
|
| 116 |
+
Args:
|
| 117 |
+
spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
|
| 118 |
+
N is the number of frequency bins, and T is the number of time frames.
|
| 119 |
+
|
| 120 |
+
Returns:
|
| 121 |
+
Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
|
| 122 |
+
"""
|
| 123 |
+
if self.padding == "center":
|
| 124 |
+
# Fallback to pytorch native implementation
|
| 125 |
+
return torch.istft(
|
| 126 |
+
spec,
|
| 127 |
+
self.n_fft,
|
| 128 |
+
self.hop_length,
|
| 129 |
+
self.win_length,
|
| 130 |
+
self.window,
|
| 131 |
+
center=True,
|
| 132 |
+
)
|
| 133 |
+
elif self.padding == "same":
|
| 134 |
+
pad = (self.win_length - self.hop_length) // 2
|
| 135 |
+
else:
|
| 136 |
+
raise ValueError("Padding must be 'center' or 'same'.")
|
| 137 |
+
|
| 138 |
+
assert spec.dim() == 3, "Expected a 3D tensor as input"
|
| 139 |
+
B, N, T = spec.shape
|
| 140 |
+
|
| 141 |
+
# Inverse FFT
|
| 142 |
+
ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
|
| 143 |
+
ifft = ifft * self.window[None, :, None]
|
| 144 |
+
|
| 145 |
+
# Overlap and Add
|
| 146 |
+
output_size = (T - 1) * self.hop_length + self.win_length
|
| 147 |
+
y = torch.nn.functional.fold(
|
| 148 |
+
ifft,
|
| 149 |
+
output_size=(1, output_size),
|
| 150 |
+
kernel_size=(1, self.win_length),
|
| 151 |
+
stride=(1, self.hop_length),
|
| 152 |
+
)[:, 0, 0, pad:-pad]
|
| 153 |
+
|
| 154 |
+
# Window envelope
|
| 155 |
+
window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
|
| 156 |
+
window_envelope = torch.nn.functional.fold(
|
| 157 |
+
window_sq,
|
| 158 |
+
output_size=(1, output_size),
|
| 159 |
+
kernel_size=(1, self.win_length),
|
| 160 |
+
stride=(1, self.hop_length),
|
| 161 |
+
).squeeze()[pad:-pad]
|
| 162 |
+
|
| 163 |
+
# Normalize
|
| 164 |
+
assert (window_envelope > 1e-11).all()
|
| 165 |
+
y = y / window_envelope
|
| 166 |
+
|
| 167 |
+
return y
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
class MDCT(nn.Module):
|
| 171 |
+
"""
|
| 172 |
+
Modified Discrete Cosine Transform (MDCT) module.
|
| 173 |
+
|
| 174 |
+
Args:
|
| 175 |
+
frame_len (int): Length of the MDCT frame.
|
| 176 |
+
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
|
| 177 |
+
"""
|
| 178 |
+
|
| 179 |
+
def __init__(self, frame_len: int, padding: str = "same"):
|
| 180 |
+
super().__init__()
|
| 181 |
+
if padding not in ["center", "same"]:
|
| 182 |
+
raise ValueError("Padding must be 'center' or 'same'.")
|
| 183 |
+
self.padding = padding
|
| 184 |
+
self.frame_len = frame_len
|
| 185 |
+
N = frame_len // 2
|
| 186 |
+
n0 = (N + 1) / 2
|
| 187 |
+
window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
|
| 188 |
+
self.register_buffer("window", window)
|
| 189 |
+
|
| 190 |
+
pre_twiddle = torch.exp(-1j * torch.pi * torch.arange(frame_len) / frame_len)
|
| 191 |
+
post_twiddle = torch.exp(-1j * torch.pi * n0 * (torch.arange(N) + 0.5) / N)
|
| 192 |
+
# view_as_real: NCCL Backend does not support ComplexFloat data type
|
| 193 |
+
# https://github.com/pytorch/pytorch/issues/71613
|
| 194 |
+
self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
|
| 195 |
+
self.register_buffer("post_twiddle", view_as_real(post_twiddle))
|
| 196 |
+
|
| 197 |
+
def forward(self, audio: torch.Tensor) -> torch.Tensor:
|
| 198 |
+
"""
|
| 199 |
+
Apply the Modified Discrete Cosine Transform (MDCT) to the input audio.
|
| 200 |
+
|
| 201 |
+
Args:
|
| 202 |
+
audio (Tensor): Input audio waveform of shape (B, T), where B is the batch size
|
| 203 |
+
and T is the length of the audio.
|
| 204 |
+
|
| 205 |
+
Returns:
|
| 206 |
+
Tensor: MDCT coefficients of shape (B, L, N), where L is the number of output frames
|
| 207 |
+
and N is the number of frequency bins.
|
| 208 |
+
"""
|
| 209 |
+
if self.padding == "center":
|
| 210 |
+
audio = torch.nn.functional.pad(
|
| 211 |
+
audio, (self.frame_len // 2, self.frame_len // 2)
|
| 212 |
+
)
|
| 213 |
+
elif self.padding == "same":
|
| 214 |
+
# hop_length is 1/2 frame_len
|
| 215 |
+
audio = torch.nn.functional.pad(
|
| 216 |
+
audio, (self.frame_len // 4, self.frame_len // 4)
|
| 217 |
+
)
|
| 218 |
+
else:
|
| 219 |
+
raise ValueError("Padding must be 'center' or 'same'.")
|
| 220 |
+
|
| 221 |
+
x = audio.unfold(-1, self.frame_len, self.frame_len // 2)
|
| 222 |
+
N = self.frame_len // 2
|
| 223 |
+
x = x * self.window.expand(x.shape)
|
| 224 |
+
X = torch.fft.fft(
|
| 225 |
+
x * view_as_complex(self.pre_twiddle).expand(x.shape), dim=-1
|
| 226 |
+
)[..., :N]
|
| 227 |
+
res = X * view_as_complex(self.post_twiddle).expand(X.shape) * np.sqrt(1 / N)
|
| 228 |
+
return torch.real(res) * np.sqrt(2)
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
class IMDCT(nn.Module):
|
| 232 |
+
"""
|
| 233 |
+
Inverse Modified Discrete Cosine Transform (IMDCT) module.
|
| 234 |
+
|
| 235 |
+
Args:
|
| 236 |
+
frame_len (int): Length of the MDCT frame.
|
| 237 |
+
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
|
| 238 |
+
"""
|
| 239 |
+
|
| 240 |
+
def __init__(self, frame_len: int, padding: str = "same"):
|
| 241 |
+
super().__init__()
|
| 242 |
+
if padding not in ["center", "same"]:
|
| 243 |
+
raise ValueError("Padding must be 'center' or 'same'.")
|
| 244 |
+
self.padding = padding
|
| 245 |
+
self.frame_len = frame_len
|
| 246 |
+
N = frame_len // 2
|
| 247 |
+
n0 = (N + 1) / 2
|
| 248 |
+
window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
|
| 249 |
+
self.register_buffer("window", window)
|
| 250 |
+
|
| 251 |
+
pre_twiddle = torch.exp(1j * torch.pi * n0 * torch.arange(N * 2) / N)
|
| 252 |
+
post_twiddle = torch.exp(1j * torch.pi * (torch.arange(N * 2) + n0) / (N * 2))
|
| 253 |
+
self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
|
| 254 |
+
self.register_buffer("post_twiddle", view_as_real(post_twiddle))
|
| 255 |
+
|
| 256 |
+
def forward(self, X: torch.Tensor) -> torch.Tensor:
|
| 257 |
+
"""
|
| 258 |
+
Apply the Inverse Modified Discrete Cosine Transform (IMDCT) to the input MDCT coefficients.
|
| 259 |
+
|
| 260 |
+
Args:
|
| 261 |
+
X (Tensor): Input MDCT coefficients of shape (B, L, N), where B is the batch size,
|
| 262 |
+
L is the number of frames, and N is the number of frequency bins.
|
| 263 |
+
|
| 264 |
+
Returns:
|
| 265 |
+
Tensor: Reconstructed audio waveform of shape (B, T), where T is the length of the audio.
|
| 266 |
+
"""
|
| 267 |
+
B, L, N = X.shape
|
| 268 |
+
Y = torch.zeros((B, L, N * 2), dtype=X.dtype, device=X.device)
|
| 269 |
+
Y[..., :N] = X
|
| 270 |
+
Y[..., N:] = -1 * torch.conj(torch.flip(X, dims=(-1,)))
|
| 271 |
+
y = torch.fft.ifft(
|
| 272 |
+
Y * view_as_complex(self.pre_twiddle).expand(Y.shape), dim=-1
|
| 273 |
+
)
|
| 274 |
+
y = (
|
| 275 |
+
torch.real(y * view_as_complex(self.post_twiddle).expand(y.shape))
|
| 276 |
+
* np.sqrt(N)
|
| 277 |
+
* np.sqrt(2)
|
| 278 |
+
)
|
| 279 |
+
result = y * self.window.expand(y.shape)
|
| 280 |
+
output_size = (1, (L + 1) * N)
|
| 281 |
+
audio = torch.nn.functional.fold(
|
| 282 |
+
result.transpose(1, 2),
|
| 283 |
+
output_size=output_size,
|
| 284 |
+
kernel_size=(1, self.frame_len),
|
| 285 |
+
stride=(1, self.frame_len // 2),
|
| 286 |
+
)[:, 0, 0, :]
|
| 287 |
+
|
| 288 |
+
if self.padding == "center":
|
| 289 |
+
pad = self.frame_len // 2
|
| 290 |
+
elif self.padding == "same":
|
| 291 |
+
pad = self.frame_len // 4
|
| 292 |
+
else:
|
| 293 |
+
raise ValueError("Padding must be 'center' or 'same'.")
|
| 294 |
+
|
| 295 |
+
audio = audio[:, pad:-pad]
|
| 296 |
+
return audio
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
class FourierHead(nn.Module):
    """Base class for inverse Fourier modules."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
        """
        raise NotImplementedError("Subclasses must implement the forward method.")


class ISTFTHead(FourierHead):
    """
    ISTFT Head module for predicting STFT complex coefficients.

    Args:
        dim (int): Hidden dimension of the model.
        n_fft (int): Size of Fourier transform.
        hop_length (int): The distance between neighboring sliding window frames, which should align with
            the resolution of the input features.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
    """

    def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
        super().__init__()
        out_dim = n_fft + 2
        self.out = torch.nn.Linear(dim, out_dim)
        self.istft = ISTFT(
            n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the ISTFTHead module.

        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
        """
        x = self.out(x).transpose(1, 2)
        mag, p = x.chunk(2, dim=1)
        mag = torch.exp(mag)
        mag = torch.clip(
            mag, max=1e2
        )  # safeguard to prevent excessively large magnitudes
        # Phase wrapping happens here: these two lines produce the real and imaginary parts.
        x = torch.cos(p)
        y = torch.sin(p)
        # Recomputing the phase via atan2 produces nothing new and only costs time:
        #   phase = torch.atan2(y, x)
        #   S = mag * torch.exp(phase * 1j)
        # It is better to build the complex value directly.
        S = mag * (x + 1j * y)
        audio = self.istft(S)
        return audio

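Editor's note: a minimal usage sketch for ISTFTHead (editor-added, not part of the uploaded file; dim/n_fft/hop_length mirror the `Vocos` defaults defined later in this file):

import torch

head = ISTFTHead(dim=384, n_fft=800, hop_length=200, padding="same")
feats = torch.randn(2, 100, 384)   # (B, L, H) backbone output
audio = head(feats)                # (B, T); with "same" padding, T = L * hop_length
print(audio.shape)                 # torch.Size([2, 20000]) with these values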
class IMDCTSymExpHead(FourierHead):
    """
    IMDCT Head module for predicting MDCT coefficients with a symmetric exponential function.

    Args:
        dim (int): Hidden dimension of the model.
        mdct_frame_len (int): Length of the MDCT frame.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
        sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized
            based on perceptual scaling. Defaults to None.
        clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
    """

    def __init__(
        self,
        dim: int,
        mdct_frame_len: int,
        padding: str = "same",
        sample_rate: Optional[int] = None,
        clip_audio: bool = False,
    ):
        super().__init__()
        out_dim = mdct_frame_len // 2
        self.out = nn.Linear(dim, out_dim)
        self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
        self.clip_audio = clip_audio

        if sample_rate is not None:
            # Optionally initialize the last layer following the mel scale.
            m_max = _hz_to_mel(sample_rate // 2)
            m_pts = torch.linspace(0, m_max, out_dim)
            f_pts = _mel_to_hz(m_pts)
            scale = 1 - (f_pts / f_pts.max())

            with torch.no_grad():
                self.out.weight.mul_(scale.view(-1, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the IMDCTSymExpHead module.

        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
        """
        x = self.out(x)
        x = symexp(x)
        x = torch.clip(
            x, min=-1e2, max=1e2
        )  # safeguard to prevent excessively large magnitudes
        audio = self.imdct(x)
        if self.clip_audio:
            # Fixed: clip the reconstructed audio, not the MDCT coefficients `x`.
            audio = torch.clip(audio, min=-1.0, max=1.0)

        return audio

class IMDCTCosHead(FourierHead):
    """
    IMDCT Head module for predicting MDCT coefficients, parametrized as MDCT = exp(m) · cos(p).

    Args:
        dim (int): Hidden dimension of the model.
        mdct_frame_len (int): Length of the MDCT frame.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
        clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
    """

    def __init__(
        self,
        dim: int,
        mdct_frame_len: int,
        padding: str = "same",
        clip_audio: bool = False,
    ):
        super().__init__()
        self.clip_audio = clip_audio
        self.out = nn.Linear(dim, mdct_frame_len)
        self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the IMDCTCosHead module.

        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
        """
        x = self.out(x)
        m, p = x.chunk(2, dim=2)
        m = torch.exp(m).clip(
            max=1e2
        )  # safeguard to prevent excessively large magnitudes
        audio = self.imdct(m * torch.cos(p))
        if self.clip_audio:
            # Fixed: clip the reconstructed audio, not the raw head output `x`.
            audio = torch.clip(audio, min=-1.0, max=1.0)
        return audio

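Editor's note: both MDCT heads predict log-magnitudes and clamp their exponential before synthesis. A small illustrative sketch of that parametrization (editor-added; the values are made up):

import torch

m = torch.tensor([0.0, 2.0, 10.0])             # predicted log-magnitudes
p = torch.tensor([0.0, 3.14159 / 2, 3.14159])  # predicted phases
mag = torch.exp(m).clip(max=1e2)               # exp(10) ≈ 22026 is clamped to 100
coeffs = mag * torch.cos(p)                    # real MDCT coefficients, as in IMDCTCosHead
print(mag)                                     # tensor([  1.0000,   7.3891, 100.0000])
print(coeffs)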
class ConvNeXtBlock(nn.Module):
    """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to a 1D audio signal.

    Args:
        dim (int): Number of input channels.
        intermediate_dim (int): Dimensionality of the intermediate layer.
        layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
            Defaults to None.
        adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
            None means non-conditional LayerNorm. Defaults to None.
    """

    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        layer_scale_init_value: float,
        adanorm_num_embeddings: Optional[int] = None,
    ):
        super().__init__()
        self.dwconv = nn.Conv1d(
            dim, dim, kernel_size=7, padding=3, groups=dim
        )  # depthwise conv
        self.adanorm = adanorm_num_embeddings is not None
        if adanorm_num_embeddings:
            self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
        else:
            self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(
            dim, intermediate_dim
        )  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(intermediate_dim, dim)
        # Guard against None as well, since the docstring allows it to disable scaling.
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
            if layer_scale_init_value is not None and layer_scale_init_value > 0
            else None
        )

    def forward(
        self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        residual = x
        x = self.dwconv(x)
        x = x.transpose(1, 2)  # (B, C, T) -> (B, T, C)
        if self.adanorm:
            assert cond_embedding_id is not None
            x = self.norm(x, cond_embedding_id)
        else:
            x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.transpose(1, 2)  # (B, T, C) -> (B, C, T)

        x = residual + x
        return x

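Editor's note: a quick shape check for the block (editor-added; values are illustrative):

import torch

block = ConvNeXtBlock(dim=384, intermediate_dim=1152, layer_scale_init_value=1 / 8)
x = torch.randn(2, 384, 100)   # (B, C, T)
y = block(x)                   # residual output, same shape as the input
assert y.shape == x.shape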
class AdaLayerNorm(nn.Module):
    """
    Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes.

    Args:
        num_embeddings (int): Number of embeddings.
        embedding_dim (int): Dimension of the embeddings.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.dim = embedding_dim
        self.scale = nn.Embedding(
            num_embeddings=num_embeddings, embedding_dim=embedding_dim
        )
        self.shift = nn.Embedding(
            num_embeddings=num_embeddings, embedding_dim=embedding_dim
        )
        torch.nn.init.ones_(self.scale.weight)
        torch.nn.init.zeros_(self.shift.weight)

    def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor:
        scale = self.scale(cond_embedding_id)
        shift = self.shift(cond_embedding_id)
        x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
        x = x * scale + shift
        return x

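Editor's note: a sketch of how the conditional scale/shift is looked up (editor-added; `VocosBackbone` below passes a single shared `bandwidth_id` rather than per-item ids):

import torch

norm = AdaLayerNorm(num_embeddings=4, embedding_dim=384)
x = torch.randn(2, 100, 384)        # (B, T, C)
cond_id = torch.tensor([[0], [3]])  # one condition class per batch item, shape (B, 1)
y = norm(x, cond_id)                # scale/shift of shape (B, 1, C) broadcast over time
assert y.shape == x.shape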
class ResBlock1(nn.Module):
    """
    ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,
    but without upsampling layers.

    Args:
        dim (int): Number of input channels.
        kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
        dilation (tuple[int], optional): Dilation factors for the dilated convolutions.
            Defaults to (1, 3, 5).
        lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.
            Defaults to 0.1.
        layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
            Defaults to None.
    """

    def __init__(
        self,
        dim: int,
        kernel_size: int = 3,
        dilation: Tuple[int, int, int] = (1, 3, 5),
        lrelu_slope: float = 0.1,
        layer_scale_init_value: Optional[float] = None,
    ):
        super().__init__()
        self.lrelu_slope = lrelu_slope
        self.convs1 = nn.ModuleList(
            [
                weight_norm(
                    nn.Conv1d(
                        dim,
                        dim,
                        kernel_size,
                        1,
                        dilation=dilation[0],
                        padding=self.get_padding(kernel_size, dilation[0]),
                    )
                ),
                weight_norm(
                    nn.Conv1d(
                        dim,
                        dim,
                        kernel_size,
                        1,
                        dilation=dilation[1],
                        padding=self.get_padding(kernel_size, dilation[1]),
                    )
                ),
                weight_norm(
                    nn.Conv1d(
                        dim,
                        dim,
                        kernel_size,
                        1,
                        dilation=dilation[2],
                        padding=self.get_padding(kernel_size, dilation[2]),
                    )
                ),
            ]
        )

        self.convs2 = nn.ModuleList(
            [
                weight_norm(
                    nn.Conv1d(
                        dim,
                        dim,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=self.get_padding(kernel_size, 1),
                    )
                ),
                weight_norm(
                    nn.Conv1d(
                        dim,
                        dim,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=self.get_padding(kernel_size, 1),
                    )
                ),
                weight_norm(
                    nn.Conv1d(
                        dim,
                        dim,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=self.get_padding(kernel_size, 1),
                    )
                ),
            ]
        )

        self.gamma = nn.ParameterList(
            [
                (
                    nn.Parameter(
                        layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
                    )
                    if layer_scale_init_value is not None
                    else None
                ),
                (
                    nn.Parameter(
                        layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
                    )
                    if layer_scale_init_value is not None
                    else None
                ),
                (
                    nn.Parameter(
                        layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
                    )
                    if layer_scale_init_value is not None
                    else None
                ),
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma):
            xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope)
            xt = c1(xt)
            xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope)
            xt = c2(xt)
            if gamma is not None:
                xt = gamma * xt
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)

    @staticmethod
    def get_padding(kernel_size: int, dilation: int = 1) -> int:
        return int((kernel_size * dilation - dilation) / 2)

class Backbone(nn.Module):
    """Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Args:
            x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
                C denotes input features, and L is the sequence length.

        Returns:
            Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
                and H denotes the model dimension.
        """
        raise NotImplementedError("Subclasses must implement the forward method.")

class VocosBackbone(Backbone):
    """
    Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization.

    Args:
        input_channels (int): Number of input feature channels.
        dim (int): Hidden dimension of the model.
        intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
        num_layers (int): Number of ConvNeXtBlock layers.
        layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
        adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
            None means non-conditional model. Defaults to None.
    """

    def __init__(
        self,
        input_channels: int,
        dim: int,
        intermediate_dim: int,
        num_layers: int,
        layer_scale_init_value: Optional[float] = None,
        adanorm_num_embeddings: Optional[int] = None,
    ):
        super().__init__()
        self.input_channels = input_channels
        self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3)
        self.adanorm = adanorm_num_embeddings is not None
        if adanorm_num_embeddings:
            self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
        else:
            self.norm = nn.LayerNorm(dim, eps=1e-6)
        layer_scale_init_value = layer_scale_init_value or 1 / num_layers
        self.convnext = nn.ModuleList(
            [
                ConvNeXtBlock(
                    dim=dim,
                    intermediate_dim=intermediate_dim,
                    layer_scale_init_value=layer_scale_init_value,
                    adanorm_num_embeddings=adanorm_num_embeddings,
                )
                for _ in range(num_layers)
            ]
        )
        self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        bandwidth_id = kwargs.get("bandwidth_id", None)
        x = self.embed(x)
        if self.adanorm:
            assert bandwidth_id is not None
            x = self.norm(x.transpose(1, 2), cond_embedding_id=bandwidth_id)
        else:
            x = self.norm(x.transpose(1, 2))
        x = x.transpose(1, 2)
        for conv_block in self.convnext:
            x = conv_block(x, cond_embedding_id=bandwidth_id)
        x = self.final_layer_norm(x.transpose(1, 2))
        return x

class VocosResNetBackbone(Backbone):
    """
    Vocos backbone module built with ResBlocks.

    Args:
        input_channels (int): Number of input feature channels.
        dim (int): Hidden dimension of the model.
        num_blocks (int): Number of ResBlock1 blocks.
        layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None.
    """

    def __init__(
        self,
        input_channels,
        dim,
        num_blocks,
        layer_scale_init_value=None,
    ):
        super().__init__()
        self.input_channels = input_channels
        self.embed = weight_norm(
            nn.Conv1d(input_channels, dim, kernel_size=3, padding=1)
        )
        layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3
        self.resnet = nn.Sequential(
            *[
                ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value)
                for _ in range(num_blocks)
            ]
        )

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        x = self.embed(x)
        x = self.resnet(x)
        x = x.transpose(1, 2)
        return x

class Vocos(nn.Module):
    def __init__(
        self,
        input_channels: int = 256,
        dim: int = 384,
        intermediate_dim: int = 1152,
        num_layers: int = 8,
        n_fft: int = 800,
        hop_size: int = 200,
        padding: str = "same",
        adanorm_num_embeddings=None,
        cfg=None,
    ):
        super().__init__()

        # Each hyperparameter can be overridden by the corresponding field of `cfg`.
        input_channels = (
            cfg.input_channels
            if cfg is not None and hasattr(cfg, "input_channels")
            else input_channels
        )
        dim = cfg.dim if cfg is not None and hasattr(cfg, "dim") else dim
        intermediate_dim = (
            cfg.intermediate_dim
            if cfg is not None and hasattr(cfg, "intermediate_dim")
            else intermediate_dim
        )
        num_layers = (
            cfg.num_layers
            if cfg is not None and hasattr(cfg, "num_layers")
            else num_layers
        )
        adanorm_num_embeddings = (
            cfg.adanorm_num_embeddings
            if cfg is not None and hasattr(cfg, "adanorm_num_embeddings")
            else adanorm_num_embeddings
        )
        n_fft = cfg.n_fft if cfg is not None and hasattr(cfg, "n_fft") else n_fft
        hop_size = (
            cfg.hop_size if cfg is not None and hasattr(cfg, "hop_size") else hop_size
        )
        padding = (
            cfg.padding if cfg is not None and hasattr(cfg, "padding") else padding
        )

        self.backbone = VocosBackbone(
            input_channels=input_channels,
            dim=dim,
            intermediate_dim=intermediate_dim,
            num_layers=num_layers,
            adanorm_num_embeddings=adanorm_num_embeddings,
        )
        self.head = ISTFTHead(dim, n_fft, hop_size, padding)

    def forward(self, x):
        x = self.backbone(x)
        x = self.head(x)

        return x[:, None, :]
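Editor's note: an end-to-end sketch of the vocoder under its defaults (editor-added, not part of the uploaded file):

import torch

vocoder = Vocos()               # input_channels=256, n_fft=800, hop_size=200
mel = torch.randn(2, 256, 100)  # (B, input_channels, frames)
wav = vocoder(mel)              # (B, 1, T); T = frames * hop_size = 20000 here
print(wav.shape)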
models/codec/coco/rep_coco_model.py
ADDED
@@ -0,0 +1,441 @@
# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch
import torch.nn as nn

from torch.nn import functional as F

from models.codec.amphion_codec.quantize import ResidualVQ
from models.codec.amphion_codec.vocos import VocosBackbone


def init_weights(m):
    if isinstance(m, nn.Conv1d):
        nn.init.trunc_normal_(m.weight, std=0.02)
        nn.init.constant_(m.bias, 0)
    if isinstance(m, nn.Linear):
        nn.init.trunc_normal_(m.weight, std=0.02)
        nn.init.constant_(m.bias, 0)


def compute_codebook_perplexity(indices, codebook_size):
    indices = indices.flatten()
    prob = torch.bincount(indices, minlength=codebook_size).float() / indices.size(0)
    perp = torch.exp(-torch.sum(prob * torch.log(prob + 1e-10)))
    return perp

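Editor's note: codebook perplexity is exp(entropy) of the empirical code distribution, i.e. roughly the effective number of codes in use. A small illustrative check (editor-added):

import torch

# Four codes used uniformly out of a codebook of 8 -> perplexity ≈ 4.
indices = torch.tensor([0, 1, 2, 3, 0, 1, 2, 3])
print(compute_codebook_perplexity(indices, codebook_size=8))  # ≈ tensor(4.)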
class CocoContentStyle(nn.Module):
    def __init__(
        self,
        codebook_size=8192,
        hidden_size=1024,
        codebook_dim=8,
        num_quantizers=1,
        quantizer_type="fvq",
        use_whisper=True,
        use_chromagram=True,
        construct_only_for_quantizer=False,
        cfg=None,
    ):
        super().__init__()

        assert cfg is not None
        self.cfg = cfg

        codebook_size = getattr(cfg, "codebook_size", codebook_size)
        hidden_size = getattr(cfg, "hidden_size", hidden_size)
        codebook_dim = getattr(cfg, "codebook_dim", codebook_dim)
        num_quantizers = getattr(cfg, "num_quantizers", num_quantizers)
        quantizer_type = getattr(cfg, "quantizer_type", quantizer_type)

        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.hidden_size = hidden_size
        self.num_quantizers = num_quantizers
        self.quantizer_type = quantizer_type

        if use_whisper:
            self.whisper_input_layer = nn.Linear(self.cfg.whisper_dim, hidden_size)
        if use_chromagram:
            self.chromagram_input_layer = nn.Linear(
                self.cfg.chromagram_dim, hidden_size
            )

        downsample_rate = getattr(cfg, "downsample_rate", 1)
        if downsample_rate > 1:
            self.do_downsample = True
            assert np.log2(downsample_rate).is_integer()

            down_layers = []
            up_layers = []
            for _ in range(int(np.log2(downsample_rate))):
                down_layers.extend(
                    [
                        nn.Conv1d(
                            hidden_size,
                            hidden_size,
                            kernel_size=3,
                            stride=2,
                            padding=1,
                        ),
                        nn.GELU(),
                    ]
                )
                up_layers.extend(
                    [
                        nn.ConvTranspose1d(
                            hidden_size, hidden_size, kernel_size=4, stride=2, padding=1
                        ),
                        nn.GELU(),
                    ]
                )
            self.downsample_layers = nn.Sequential(*down_layers)
            self.upsample_layers = nn.Sequential(*up_layers)

        else:
            self.do_downsample = False

        self.encoder = nn.Sequential(
            VocosBackbone(
                input_channels=self.hidden_size,
                dim=self.cfg.encoder.vocos_dim,
                intermediate_dim=self.cfg.encoder.vocos_intermediate_dim,
                num_layers=self.cfg.encoder.vocos_num_layers,
                adanorm_num_embeddings=None,
            ),
            nn.Linear(self.cfg.encoder.vocos_dim, self.hidden_size),
        )

        self.quantizer = ResidualVQ(
            input_dim=hidden_size,
            num_quantizers=num_quantizers,
            codebook_size=codebook_size,
            codebook_dim=codebook_dim,
            quantizer_type=quantizer_type,
            quantizer_dropout=0.0,
            commitment=0.15,
            codebook_loss_weight=1.0,
            use_l2_normlize=True,
        )

        if not construct_only_for_quantizer:
            self.decoder = nn.Sequential(
                VocosBackbone(
                    input_channels=self.hidden_size,
                    dim=self.cfg.decoder.vocos_dim,
                    intermediate_dim=self.cfg.decoder.vocos_intermediate_dim,
                    num_layers=self.cfg.decoder.vocos_num_layers,
                    adanorm_num_embeddings=None,
                ),
                nn.Linear(self.cfg.decoder.vocos_dim, self.hidden_size),
            )

            if use_whisper:
                self.whisper_output_layer = nn.Linear(
                    self.hidden_size, self.cfg.whisper_dim
                )
            if use_chromagram:
                self.chromagram_output_layer = nn.Linear(
                    self.hidden_size, self.cfg.chromagram_dim
                )

        self.reset_parameters()

    def forward(
        self,
        whisper_feats,
        chromagram_feats,
        return_for_quantizer=False,
    ):
        """
        Args:
            whisper_feats: [B, T, 1024]
            chromagram_feats: [B, T, 24]
        Returns:
            whisper_rec: [B, T, 1024]
            chromagram_rec: [B, T, 24]
            codebook_loss: scalar tensor
            all_indices: [N, B, T], or [B, T] if num_quantizers == 1
        """
        T = whisper_feats.shape[1]

        # [B, T, D]
        x = self.whisper_input_layer(whisper_feats) + self.chromagram_input_layer(
            chromagram_feats
        )

        # ====== Downsample ======
        if self.do_downsample:
            x = self.downsample_layers(x.transpose(1, 2)).transpose(1, 2)

        # ====== Encoder ======
        x = self.encoder(x.transpose(1, 2)).transpose(1, 2)  # [B, T, D] -> [B, D, T]

        # ====== Quantizer ======
        (
            quantized_out,  # [B, D, T]
            all_indices,  # [num_quantizers, B, T]
            all_commit_losses,  # [num_quantizers]
            all_codebook_losses,  # [num_quantizers]
            _,
        ) = self.quantizer(x)

        if return_for_quantizer:
            if all_indices.shape[0] == 1:
                return all_indices.squeeze(0), quantized_out.transpose(1, 2)
            return all_indices, quantized_out.transpose(1, 2)

        # ====== Decoder ======
        x_rec = self.decoder(quantized_out)  # [B, T, D]

        # ====== Upsample ======
        if self.do_downsample:
            x_rec = self.upsample_layers(x_rec.transpose(1, 2)).transpose(1, 2)

        # Ensure the output time dimension matches the input: trim extra frames,
        # or pad by repeating the last frame.
        if x_rec.shape[1] >= T:
            x_rec = x_rec[:, :T, :]
        else:
            padding_frames = T - x_rec.shape[1]
            last_frame = x_rec[:, -1:, :]
            padding = last_frame.repeat(1, padding_frames, 1)
            x_rec = torch.cat([x_rec, padding], dim=1)

        # ====== Loss ======
        whisper_rec = self.whisper_output_layer(x_rec)  # [B, T, 1024]
        chromagram_rec = self.chromagram_output_layer(x_rec)  # [B, T, 24]

        codebook_loss = (all_codebook_losses + all_commit_losses).mean()

        return whisper_rec, chromagram_rec, codebook_loss, all_indices

    def quantize(self, whisper_feats, chromagram_feats):
        """
        Args:
            whisper_feats: [B, T, 1024]
            chromagram_feats: [B, T, 24]
        Returns:
            all_indices: [N, B, T], or [B, T] if num_quantizers == 1
            quantized_out: [B, T, D]
        """
        all_indices, quantized_out = self.forward(
            whisper_feats,
            chromagram_feats,
            return_for_quantizer=True,
        )
        return all_indices, quantized_out

    def reset_parameters(self):
        self.apply(init_weights)

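Editor's note: a hypothetical construction of the tokenizer (editor-added). The encoder/decoder sizes below are illustrative assumptions; the real values presumably live in this upload's config files:

from types import SimpleNamespace
import torch

enc = SimpleNamespace(vocos_dim=384, vocos_intermediate_dim=1152, vocos_num_layers=8)
cfg = SimpleNamespace(whisper_dim=1024, chromagram_dim=24, encoder=enc, decoder=enc)
model = CocoContentStyle(cfg=cfg)
whisper_feats = torch.randn(2, 150, 1024)
chroma_feats = torch.randn(2, 150, 24)
indices, quantized = model.quantize(whisper_feats, chroma_feats)
print(indices.shape, quantized.shape)  # [B, T] codes and [B, T, hidden_size]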
class CocoContent(CocoContentStyle):
    def __init__(
        self,
        cfg,
        use_whisper=True,
        use_chromagram=False,
        construct_only_for_quantizer=False,
    ):
        super().__init__(
            cfg=cfg,
            use_whisper=use_whisper,
            use_chromagram=use_chromagram,
            construct_only_for_quantizer=construct_only_for_quantizer,
        )

    def forward(
        self,
        whisper_feats,
        return_for_quantizer=False,
    ):
        """
        Args:
            whisper_feats: [B, T, 1024]
        Returns:
            whisper_rec: [B, T, 1024]
            codebook_loss: scalar tensor
            all_indices: [N, B, T]
        """
        T = whisper_feats.shape[1]

        # [B, T, D]
        x = self.whisper_input_layer(whisper_feats)

        # ====== Downsample ======
        if self.do_downsample:
            x = self.downsample_layers(x.transpose(1, 2)).transpose(1, 2)

        # ====== Encoder ======
        x = self.encoder(x.transpose(1, 2)).transpose(1, 2)  # [B, T, D] -> [B, D, T]

        # ====== Quantizer ======
        (
            quantized_out,  # [B, D, T]
            all_indices,  # [num_quantizers, B, T]
            all_commit_losses,  # [num_quantizers]
            all_codebook_losses,  # [num_quantizers]
            _,
        ) = self.quantizer(x)

        if return_for_quantizer:
            if all_indices.shape[0] == 1:
                return all_indices.squeeze(0), quantized_out.transpose(1, 2)
            return all_indices, quantized_out.transpose(1, 2)

        # ====== Decoder ======
        x_rec = self.decoder(quantized_out)  # [B, T, D]

        # ====== Upsample ======
        if self.do_downsample:
            x_rec = self.upsample_layers(x_rec.transpose(1, 2)).transpose(1, 2)

        # Ensure the output time dimension matches the input.
        if x_rec.shape[1] >= T:
            x_rec = x_rec[:, :T, :]
        else:
            padding_frames = T - x_rec.shape[1]
            last_frame = x_rec[:, -1:, :]
            padding = last_frame.repeat(1, padding_frames, 1)
            x_rec = torch.cat([x_rec, padding], dim=1)

        # ====== Loss ======
        whisper_rec = self.whisper_output_layer(x_rec)  # [B, T, 1024]

        codebook_loss = (all_codebook_losses + all_commit_losses).mean()

        return whisper_rec, codebook_loss, all_indices

    def quantize(self, whisper_feats):
        all_indices, quantized_out = self.forward(
            whisper_feats, return_for_quantizer=True
        )
        return all_indices, quantized_out

class CocoStyle(CocoContentStyle):
    def __init__(
        self,
        cfg,
        use_whisper=False,
        use_chromagram=True,
        construct_only_for_quantizer=False,
    ):
        super().__init__(
            cfg=cfg,
            use_whisper=use_whisper,
            use_chromagram=use_chromagram,
            construct_only_for_quantizer=construct_only_for_quantizer,
        )

    def forward(
        self,
        chromagram_feats,
        return_for_quantizer=False,
    ):
        """
        Args:
            chromagram_feats: [B, T, 24]
        Returns:
            chromagram_rec: [B, T, 24]
            codebook_loss: scalar tensor
            all_indices: [N, B, T]
        """
        T = chromagram_feats.shape[1]

        # [B, T, D]
        x = self.chromagram_input_layer(chromagram_feats)

        # ====== Downsample ======
        if self.do_downsample:
            x = self.downsample_layers(x.transpose(1, 2)).transpose(1, 2)

        # ====== Encoder ======
        x = self.encoder(x.transpose(1, 2)).transpose(1, 2)  # [B, T, D] -> [B, D, T]

        # ====== Quantizer ======
        (
            quantized_out,  # [B, D, T]
            all_indices,  # [num_quantizers, B, T]
            all_commit_losses,  # [num_quantizers]
            all_codebook_losses,  # [num_quantizers]
            _,
        ) = self.quantizer(x)

        if return_for_quantizer:
            if all_indices.shape[0] == 1:
                return all_indices.squeeze(0), quantized_out.transpose(1, 2)
            return all_indices, quantized_out.transpose(1, 2)

        # ====== Decoder ======
        x_rec = self.decoder(quantized_out)  # [B, T, D]

        # ====== Upsample ======
        if self.do_downsample:
            x_rec = self.upsample_layers(x_rec.transpose(1, 2)).transpose(1, 2)

        # Ensure the output time dimension matches the input.
        if x_rec.shape[1] >= T:
            x_rec = x_rec[:, :T, :]
        else:
            padding_frames = T - x_rec.shape[1]
            last_frame = x_rec[:, -1:, :]
            padding = last_frame.repeat(1, padding_frames, 1)
            x_rec = torch.cat([x_rec, padding], dim=1)

        # ====== Loss ======
        chromagram_rec = self.chromagram_output_layer(x_rec)  # [B, T, 24]

        codebook_loss = (all_codebook_losses + all_commit_losses).mean()

        return chromagram_rec, codebook_loss, all_indices

    def quantize(self, chromagram_feats):
        all_indices, quantized_out = self.forward(
            chromagram_feats, return_for_quantizer=True
        )
        return all_indices, quantized_out


# if __name__ == "__main__":
#     from utils.util import JsonHParams

#     cfg = JsonHParams(
#         **{
#             "whisper_dim": 1024,
#             "chromagram_dim": 24,
#             "global_speaker_encoder": {
#                 "input_dim": 128,  # e.g. n_mels
#                 "hidden_size": 512,  # 768 for emilia298k
#                 "num_hidden_layers": 4,  # 6 for emilia298k
#                 "num_attention_heads": 8,
#             },
#         }
#     )
#     model = Coco(cfg=cfg)

#     x = torch.randn(2, 150, 1024)
#     tone_height = torch.randn(2)
#     mels = torch.randn(2, 150, 128)
#     mel_mask = torch.ones(2, 150)

#     x_rec, codebook_loss, all_indices, auxiliary_pred_outputs = model(
#         x, tone_height, mels, mel_mask
#     )
#     print(x_rec.shape, codebook_loss, all_indices.shape)
#     for k, v in auxiliary_pred_outputs.items():
#         print(k, v.shape)
models/codec/melvqgan/melspec.py
ADDED
@@ -0,0 +1,108 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch
import torch.nn as nn
from librosa.filters import mel as librosa_mel_fn


def dynamic_range_compression(x, C=1, clip_val=1e-5):
    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)


def dynamic_range_decompression(x, C=1):
    return np.exp(x) / C


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression_torch(x, C=1):
    return torch.exp(x) / C


def spectral_normalize_torch(magnitudes):
    output = dynamic_range_compression_torch(magnitudes)
    return output


def spectral_de_normalize_torch(magnitudes):
    output = dynamic_range_decompression_torch(magnitudes)
    return output


class MelSpectrogram(nn.Module):
    def __init__(
        self,
        n_fft,
        num_mels,
        sampling_rate,
        hop_size,
        win_size,
        fmin,
        fmax,
        center=False,
    ):
        super(MelSpectrogram, self).__init__()
        self.n_fft = n_fft
        self.hop_size = hop_size
        self.win_size = win_size
        self.sampling_rate = sampling_rate
        self.num_mels = num_mels
        self.fmin = fmin
        self.fmax = fmax
        self.center = center

        mel = librosa_mel_fn(
            sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
        )
        mel_basis = torch.from_numpy(mel).float()
        hann_window = torch.hann_window(win_size)

        self.register_buffer("mel_basis", mel_basis)
        self.register_buffer("hann_window", hann_window)

    def forward(self, y):
        # Reflect-pad the waveform so the frame count matches len(y) / hop_size.
        y = torch.nn.functional.pad(
            y.unsqueeze(1),
            (
                int((self.n_fft - self.hop_size) / 2),
                int((self.n_fft - self.hop_size) / 2),
            ),
            mode="reflect",
        )
        y = y.squeeze(1)
        spec = torch.stft(
            y,
            self.n_fft,
            hop_length=self.hop_size,
            win_length=self.win_size,
            window=self.hann_window,
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
        spec = torch.view_as_real(spec)

        spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

        spec = torch.matmul(self.mel_basis, spec)
        spec = spectral_normalize_torch(spec)

        return spec
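Editor's note: a shape sketch for MelSpectrogram with illustrative, editor-chosen parameters (not from this upload's configs):

import torch

melspec = MelSpectrogram(
    n_fft=1024, num_mels=80, sampling_rate=24000,
    hop_size=256, win_size=1024, fmin=0, fmax=12000,
)
wav = torch.randn(2, 256 * 100)  # (B, samples)
mel = melspec(wav)               # (B, num_mels, frames) == (2, 80, 100)
print(mel.shape)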
utils/__init__.py
ADDED
File without changes
utils/hparam.py
ADDED
@@ -0,0 +1,659 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# This code is modified from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/training/python/training/hparam.py pylint: disable=line-too-long
"""Hyperparameter values."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import numbers
import re
import six

# Define the regular expression for parsing a single clause of the input
# (delimited by commas). A legal clause looks like:
#   <variable name>[<index>]? = <rhs>
# where <rhs> is either a single token or a []-enclosed list of tokens.
# For example: "var[1] = a" or "x = [1,2,3]"
PARAM_RE = re.compile(
    r"""
  (?P<name>[a-zA-Z][\w\.]*)    # variable name: "var" or "x"
  (\[\s*(?P<index>\d+)\s*\])?  # (optional) index: "1" or None
  \s*=\s*
  ((?P<val>[^,\[]*)            # single value: "a" or None
   |
   \[(?P<vals>[^\]]*)\])       # list of values: None or "1,2,3"
  ($|,\s*)""",
    re.VERBOSE,
)


def _parse_fail(name, var_type, value, values):
    """Helper function for raising a value error for a bad assignment."""
    raise ValueError(
        "Could not parse hparam '%s' of type '%s' with value '%s' in %s"
        % (name, var_type.__name__, value, values)
    )


def _reuse_fail(name, values):
    """Helper function for raising a value error for reuse of a name."""
    raise ValueError("Multiple assignments to variable '%s' in %s" % (name, values))

def _process_scalar_value(name, parse_fn, var_type, m_dict, values, results_dictionary):
    """Update results_dictionary with a scalar value.

    Used to update the results_dictionary to be returned by parse_values when
    encountering a clause with a scalar RHS (e.g. "s=5" or "arr[0]=5").

    Mutates results_dictionary.

    Args:
      name: Name of variable in assignment ("s" or "arr").
      parse_fn: Function for parsing the actual value.
      var_type: Type of named variable.
      m_dict: Dictionary constructed from regex parsing.
        m_dict['val']: RHS value (scalar)
        m_dict['index']: List index value (or None)
      values: Full expression being parsed
      results_dictionary: The dictionary being updated for return by the parsing
        function.

    Raises:
      ValueError: If the name has already been used.
    """
    try:
        parsed_value = parse_fn(m_dict["val"])
    except ValueError:
        _parse_fail(name, var_type, m_dict["val"], values)

    # If no index is provided
    if not m_dict["index"]:
        if name in results_dictionary:
            _reuse_fail(name, values)
        results_dictionary[name] = parsed_value
    else:
        if name in results_dictionary:
            # If the name has already been used as a scalar, it will be in this
            # dictionary and map to a non-dictionary.
            if not isinstance(results_dictionary.get(name), dict):
                _reuse_fail(name, values)
        else:
            results_dictionary[name] = {}

        index = int(m_dict["index"])
        # Make sure the index position hasn't already been assigned a value.
        if index in results_dictionary[name]:
            _reuse_fail("{}[{}]".format(name, index), values)
        results_dictionary[name][index] = parsed_value


def _process_list_value(name, parse_fn, var_type, m_dict, values, results_dictionary):
    """Update results_dictionary from a list of values.

    Used to update results_dictionary to be returned by parse_values when
    encountering a clause with a list RHS (e.g. "arr=[1,2,3]").

    Mutates results_dictionary.

    Args:
      name: Name of variable in assignment ("arr").
      parse_fn: Function for parsing individual values.
      var_type: Type of named variable.
      m_dict: Dictionary constructed from regex parsing.
        m_dict['val']: RHS value (scalar)
      values: Full expression being parsed
      results_dictionary: The dictionary being updated for return by the parsing
        function.

    Raises:
      ValueError: If the name has an index or the values cannot be parsed.
    """
    if m_dict["index"] is not None:
        raise ValueError("Assignment of a list to a list index.")
    elements = filter(None, re.split("[ ,]", m_dict["vals"]))
    # Make sure the name hasn't already been assigned a value.
    # (_reuse_fail raises itself, so no redundant `raise` is needed here.)
    if name in results_dictionary:
        _reuse_fail(name, values)
    try:
        results_dictionary[name] = [parse_fn(e) for e in elements]
    except ValueError:
        _parse_fail(name, var_type, m_dict["vals"], values)

def _cast_to_type_if_compatible(name, param_type, value):
    """Cast hparam to the provided type, if compatible.

    Args:
      name: Name of the hparam to be cast.
      param_type: The type of the hparam.
      value: The value to be cast, if compatible.

    Returns:
      The result of casting `value` to `param_type`.

    Raises:
      ValueError: If the type of `value` is not compatible with param_type.
        * If `param_type` is a string type, but `value` is not.
        * If `param_type` is a boolean, but `value` is not, or vice versa.
        * If `param_type` is an integer type, but `value` is not.
        * If `param_type` is a float type, but `value` is not a numeric type.
    """
    fail_msg = "Could not cast hparam '%s' of type '%s' from value %r" % (
        name,
        param_type,
        value,
    )

    # Some callers use None, for which we can't do any casting/checking. :(
    if issubclass(param_type, type(None)):
        return value

    # Avoid converting a non-string type to a string.
    if issubclass(param_type, (six.string_types, six.binary_type)) and not isinstance(
        value, (six.string_types, six.binary_type)
    ):
        raise ValueError(fail_msg)

    # Avoid converting a number or string type to a boolean or vice versa.
    if issubclass(param_type, bool) != isinstance(value, bool):
        raise ValueError(fail_msg)

    # Avoid converting a float to an integer (the reverse is fine).
    if issubclass(param_type, numbers.Integral) and not isinstance(
        value, numbers.Integral
    ):
        raise ValueError(fail_msg)

    # Avoid converting a non-numeric type to a numeric type.
    if issubclass(param_type, numbers.Number) and not isinstance(value, numbers.Number):
        raise ValueError(fail_msg)

    return param_type(value)

def parse_values(values, type_map, ignore_unknown=False):
|
| 181 |
+
"""Parses hyperparameter values from a string into a python map.
|
| 182 |
+
|
| 183 |
+
`values` is a string containing comma-separated `name=value` pairs.
|
| 184 |
+
For each pair, the value of the hyperparameter named `name` is set to
|
| 185 |
+
`value`.
|
| 186 |
+
|
| 187 |
+
If a hyperparameter name appears multiple times in `values`, a ValueError
|
| 188 |
+
is raised (e.g. 'a=1,a=2', 'a[1]=1,a[1]=2').
|
| 189 |
+
|
| 190 |
+
If a hyperparameter name in both an index assignment and scalar assignment,
|
| 191 |
+
a ValueError is raised. (e.g. 'a=[1,2,3],a[0] = 1').
|
| 192 |
+
|
| 193 |
+
The hyperparameter name may contain '.' symbols, which will result in an
|
| 194 |
+
attribute name that is only accessible through the getattr and setattr
|
| 195 |
+
functions. (And must be first explicit added through add_hparam.)
|
| 196 |
+
|
| 197 |
+
WARNING: Use of '.' in your variable names is allowed, but is not well
|
| 198 |
+
supported and not recommended.
|
| 199 |
+
|
| 200 |
+
The `value` in `name=value` must follows the syntax according to the
|
| 201 |
+
type of the parameter:
|
| 202 |
+
|
| 203 |
+
* Scalar integer: A Python-parsable integer point value. E.g.: 1,
|
| 204 |
+
100, -12.
|
| 205 |
+
* Scalar float: A Python-parsable floating point value. E.g.: 1.0,
|
| 206 |
+
-.54e89.
|
| 207 |
+
* Boolean: Either true or false.
|
| 208 |
+
* Scalar string: A non-empty sequence of characters, excluding comma,
|
| 209 |
+
spaces, and square brackets. E.g.: foo, bar_1.
|
| 210 |
+
* List: A comma separated list of scalar values of the parameter type
|
| 211 |
+
enclosed in square brackets. E.g.: [1,2,3], [1.0,1e-12], [high,low].
|
| 212 |
+
|
| 213 |
+
When index assignment is used, the corresponding type_map key should be the
|
| 214 |
+
list name. E.g. for "arr[1]=0" the type_map must have the key "arr" (not
|
| 215 |
+
"arr[1]").
|
| 216 |
+
|
| 217 |
+
Args:
|
| 218 |
+
values: String. Comma separated list of `name=value` pairs where
|
| 219 |
+
'value' must follow the syntax described above.
|
| 220 |
+
type_map: A dictionary mapping hyperparameter names to types. Note every
|
| 221 |
+
parameter name in values must be a key in type_map. The values must
|
| 222 |
+
conform to the types indicated, where a value V is said to conform to a
|
| 223 |
+
type T if either V has type T, or V is a list of elements of type T.
|
| 224 |
+
Hence, for a multidimensional parameter 'x' taking float values,
|
| 225 |
+
'x=[0.1,0.2]' will parse successfully if type_map['x'] = float.
|
| 226 |
+
ignore_unknown: Bool. Whether values that are missing a type in type_map
|
| 227 |
+
should be ignored. If set to True, a ValueError will not be raised for
|
| 228 |
+
unknown hyperparameter type.
|
| 229 |
+
|
| 230 |
+
Returns:
|
| 231 |
+
A python map mapping each name to either:
|
| 232 |
+
* A scalar value.
|
| 233 |
+
* A list of scalar values.
|
| 234 |
+
* A dictionary mapping index numbers to scalar values.
|
| 235 |
+
(e.g. "x=5,L=[1,2],arr[1]=3" results in {'x':5,'L':[1,2],'arr':{1:3}}")
|
| 236 |
+
|
| 237 |
+
Raises:
|
| 238 |
+
ValueError: If there is a problem with input.
|
| 239 |
+
* If `values` cannot be parsed.
|
| 240 |
+
* If a list is assigned to a list index (e.g. 'a[1] = [1,2,3]').
|
| 241 |
+
* If the same rvalue is assigned two different values (e.g. 'a=1,a=2',
|
| 242 |
+
'a[1]=1,a[1]=2', or 'a=1,a=[1]')
|
| 243 |
+
"""
|
| 244 |
+
results_dictionary = {}
|
| 245 |
+
pos = 0
|
| 246 |
+
while pos < len(values):
|
| 247 |
+
m = PARAM_RE.match(values, pos)
|
| 248 |
+
if not m:
|
| 249 |
+
raise ValueError("Malformed hyperparameter value: %s" % values[pos:])
|
| 250 |
+
# Check that there is a comma between parameters and move past it.
|
| 251 |
+
pos = m.end()
|
| 252 |
+
# Parse the values.
|
| 253 |
+
m_dict = m.groupdict()
|
| 254 |
+
name = m_dict["name"]
|
| 255 |
+
if name not in type_map:
|
| 256 |
+
if ignore_unknown:
|
| 257 |
+
continue
|
| 258 |
+
raise ValueError("Unknown hyperparameter type for %s" % name)
|
| 259 |
+
type_ = type_map[name]
|
| 260 |
+
|
| 261 |
+
# Set up correct parsing function (depending on whether type_ is a bool)
|
| 262 |
+
if type_ == bool:
|
| 263 |
+
|
| 264 |
+
def parse_bool(value):
|
| 265 |
+
if value in ["true", "True"]:
|
| 266 |
+
return True
|
| 267 |
+
elif value in ["false", "False"]:
|
| 268 |
+
return False
|
| 269 |
+
else:
|
| 270 |
+
try:
|
| 271 |
+
return bool(int(value))
|
| 272 |
+
except ValueError:
|
| 273 |
+
_parse_fail(name, type_, value, values)
|
| 274 |
+
|
| 275 |
+
parse = parse_bool
|
| 276 |
+
else:
|
| 277 |
+
parse = type_
|
| 278 |
+
|
| 279 |
+
# If a singe value is provided
|
| 280 |
+
if m_dict["val"] is not None:
|
| 281 |
+
_process_scalar_value(
|
| 282 |
+
name, parse, type_, m_dict, values, results_dictionary
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
# If the assigned value is a list:
|
| 286 |
+
elif m_dict["vals"] is not None:
|
| 287 |
+
_process_list_value(name, parse, type_, m_dict, values, results_dictionary)
|
| 288 |
+
|
| 289 |
+
else: # Not assigned a list or value
|
| 290 |
+
_parse_fail(name, type_, "", values)
|
| 291 |
+
|
| 292 |
+
return results_dictionary
|
| 293 |
+
|
| 294 |
+
|
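A minimal sketch of how parse_values behaves (hypothetical names and values, assuming the PARAM_RE pattern and _process_* helpers defined earlier in this file):

    type_map = {"lr": float, "layers": int, "arr": int}
    parse_values("lr=0.1,layers=[4,8],arr[1]=3", type_map)
    # -> {"lr": 0.1, "layers": [4, 8], "arr": {1: 3}}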
+class HParams(object):
+    """Class to hold a set of hyperparameters as name-value pairs.
+
+    A `HParams` object holds hyperparameters used to build and train a model,
+    such as the number of hidden units in a neural net layer or the learning rate
+    to use when training.
+
+    You first create a `HParams` object by specifying the names and values of the
+    hyperparameters.
+
+    To make them easily accessible the parameter names are added as direct
+    attributes of the class. A typical usage is as follows:
+
+    ```python
+    # Create a HParams object specifying names and values of the model
+    # hyperparameters:
+    hparams = HParams(learning_rate=0.1, num_hidden_units=100)
+
+    # The hyperparameters are available as attributes of the HParams object:
+    hparams.learning_rate ==> 0.1
+    hparams.num_hidden_units ==> 100
+    ```
+
+    Hyperparameters have a type, which is inferred from the type of their value
+    passed at construction time. The currently supported types are: integer,
+    float, boolean, string, and list of integer, float, boolean, or string.
+
+    You can override hyperparameter values by calling the
+    [`parse()`](#HParams.parse) method, passing a string of comma separated
+    `name=value` pairs. This is intended to make it possible to override
+    any hyperparameter values from a single command-line flag to which
+    the user passes 'hyper-param=value' pairs. It avoids having to define
+    one flag for each hyperparameter.
+
+    The syntax expected for each value depends on the type of the parameter.
+    See `parse()` for a description of the syntax.
+
+    Example:
+
+    ```python
+    # Define a command line flag to pass name=value pairs.
+    # For example using argparse:
+    import argparse
+    parser = argparse.ArgumentParser(description='Train my model.')
+    parser.add_argument('--hparams', type=str,
+                        help='Comma separated list of "name=value" pairs.')
+    args = parser.parse_args()
+    ...
+    def my_program():
+        # Create a HParams object specifying the names and values of the
+        # model hyperparameters:
+        hparams = HParams(learning_rate=0.1, num_hidden_units=100,
+                          activations=['relu', 'tanh'])
+
+        # Override hyperparameter values by parsing the command line
+        hparams.parse(args.hparams)
+
+        # If the user passed `--hparams=learning_rate=0.3` on the command line
+        # then 'hparams' has the following attributes:
+        hparams.learning_rate ==> 0.3
+        hparams.num_hidden_units ==> 100
+        hparams.activations ==> ['relu', 'tanh']
+
+        # If the hyperparameters are in json format use parse_json:
+        hparams.parse_json('{"learning_rate": 0.3, "activations": "relu"}')
+    ```
+    """
+
+    _HAS_DYNAMIC_ATTRIBUTES = True  # Required for pytype checks.
+
+    def __init__(self, model_structure=None, **kwargs):
+        """Create an instance of `HParams` from keyword arguments.
+
+        The keyword arguments specify name-value pairs for the hyperparameters.
+        The parameter types are inferred from the type of the values passed.
+
+        The parameter names are added as attributes of the `HParams` object, so
+        they can be accessed directly with the dot notation `hparams._name_`.
+
+        Example:
+
+        ```python
+        # Define 3 hyperparameters: 'learning_rate' is a float parameter,
+        # 'num_hidden_units' an integer parameter, and 'activation' a string
+        # parameter.
+        hparams = HParams(
+            learning_rate=0.1, num_hidden_units=100, activation='relu')
+
+        hparams.activation ==> 'relu'
+        ```
+
+        Note that a few names are reserved and cannot be used as hyperparameter
+        names. If you use one of the reserved names the constructor raises a
+        `ValueError`.
+
+        Args:
+          model_structure: An instance of ModelStructure, defining the feature
+            crosses to be used in the Trial.
+          **kwargs: Key-value pairs where the key is the hyperparameter name and
+            the value is the value for the parameter.
+
+        Raises:
+          ValueError: If both `hparam_def` and initialization values are provided,
+            or if one of the arguments is invalid.
+
+        """
+        # Register the hyperparameters and their type in _hparam_types.
+        # This simplifies the implementation of parse().
+        # _hparam_types maps the parameter name to a tuple (type, bool).
+        # The type value is the type of the parameter for scalar hyperparameters,
+        # or the type of the list elements for multidimensional hyperparameters.
+        # The bool value is True if the value is a list, False otherwise.
+        self._hparam_types = {}
+        self._model_structure = model_structure
+        for name, value in six.iteritems(kwargs):
+            self.add_hparam(name, value)
+
+    def add_hparam(self, name, value):
+        """Adds {name, value} pair to hyperparameters.
+
+        Args:
+          name: Name of the hyperparameter.
+          value: Value of the hyperparameter. Can be one of the following types:
+            int, float, string, int list, float list, or string list.
+
+        Raises:
+          ValueError: if one of the arguments is invalid.
+        """
+        # Keys in kwargs are unique, but 'name' could be the name of a
+        # pre-existing attribute of this object. In that case we refuse to
+        # use it as a hyperparameter name.
+        if getattr(self, name, None) is not None:
+            raise ValueError("Hyperparameter name is reserved: %s" % name)
+        if isinstance(value, (list, tuple)):
+            if not value:
+                raise ValueError(
+                    "Multi-valued hyperparameters cannot be empty: %s" % name
+                )
+            self._hparam_types[name] = (type(value[0]), True)
+        else:
+            self._hparam_types[name] = (type(value), False)
+        setattr(self, name, value)
+
+    def set_hparam(self, name, value):
+        """Set the value of an existing hyperparameter.
+
+        This function verifies that the type of the value matches the type of the
+        existing hyperparameter.
+
+        Args:
+          name: Name of the hyperparameter.
+          value: New value of the hyperparameter.
+
+        Raises:
+          KeyError: If the hyperparameter doesn't exist.
+          ValueError: If there is a type mismatch.
+        """
+        param_type, is_list = self._hparam_types[name]
+        if isinstance(value, list):
+            if not is_list:
+                raise ValueError(
+                    "Must not pass a list for single-valued parameter: %s" % name
+                )
+            setattr(
+                self,
+                name,
+                [_cast_to_type_if_compatible(name, param_type, v) for v in value],
+            )
+        else:
+            if is_list:
+                raise ValueError(
+                    "Must pass a list for multi-valued parameter: %s." % name
+                )
+            setattr(self, name, _cast_to_type_if_compatible(name, param_type, value))
+
+    def del_hparam(self, name):
+        """Removes the hyperparameter with key 'name'.
+
+        Does nothing if it isn't present.
+
+        Args:
+          name: Name of the hyperparameter.
+        """
+        if hasattr(self, name):
+            delattr(self, name)
+            del self._hparam_types[name]
+
+    def parse(self, values):
+        """Override existing hyperparameter values, parsing new values from a string.
+
+        See parse_values for more detail on the allowed format for values.
+
+        Args:
+          values: String. Comma separated list of `name=value` pairs where 'value'
+            must follow the syntax described above.
+
+        Returns:
+          The `HParams` instance.
+
+        Raises:
+          ValueError: If `values` cannot be parsed or a hyperparameter in `values`
+            doesn't exist.
+        """
+        type_map = {}
+        for name, t in self._hparam_types.items():
+            param_type, _ = t
+            type_map[name] = param_type
+
+        values_map = parse_values(values, type_map)
+        return self.override_from_dict(values_map)
+
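A brief sketch of the add/set/parse flow above (hypothetical parameter names and values, not part of the committed file):

    hparams = HParams(learning_rate=0.1, activations=["relu", "tanh"])
    hparams.set_hparam("learning_rate", 0.3)  # type-checked cast
    hparams.parse("learning_rate=0.5,activations=[gelu,silu]")
    # hparams.learning_rate -> 0.5; hparams.activations -> ["gelu", "silu"]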
+    def override_from_dict(self, values_dict):
+        """Override existing hyperparameter values, parsing new values from a dictionary.
+
+        Args:
+          values_dict: Dictionary of name:value pairs.
+
+        Returns:
+          The `HParams` instance.
+
+        Raises:
+          KeyError: If a hyperparameter in `values_dict` doesn't exist.
+          ValueError: If `values_dict` cannot be parsed.
+        """
+        for name, value in values_dict.items():
+            self.set_hparam(name, value)
+        return self
+
+    def set_model_structure(self, model_structure):
+        self._model_structure = model_structure
+
+    def get_model_structure(self):
+        return self._model_structure
+
+    def to_json(self, indent=None, separators=None, sort_keys=False):
+        """Serializes the hyperparameters into JSON.
+
+        Args:
+          indent: If a non-negative integer, JSON array elements and object members
+            will be pretty-printed with that indent level. An indent level of 0, or
+            negative, will only insert newlines. `None` (the default) selects the
+            most compact representation.
+          separators: Optional `(item_separator, key_separator)` tuple. Default is
+            `(', ', ': ')`.
+          sort_keys: If `True`, the output dictionaries will be sorted by key.
+
+        Returns:
+          A JSON string.
+        """
+
+        def remove_callables(x):
+            """Omit callable elements from input with arbitrary nesting."""
+            if isinstance(x, dict):
+                return {
+                    k: remove_callables(v)
+                    for k, v in six.iteritems(x)
+                    if not callable(v)
+                }
+            elif isinstance(x, list):
+                return [remove_callables(i) for i in x if not callable(i)]
+            return x
+
+        return json.dumps(
+            remove_callables(self.values()),
+            indent=indent,
+            separators=separators,
+            sort_keys=sort_keys,
+        )
+
+    def parse_json(self, values_json):
+        """Override existing hyperparameter values, parsing new values from a json object.
+
+        Args:
+          values_json: String containing a json object of name:value pairs.
+
+        Returns:
+          The `HParams` instance.
+
+        Raises:
+          KeyError: If a hyperparameter in `values_json` doesn't exist.
+          ValueError: If `values_json` cannot be parsed.
+        """
+        values_map = json.loads(values_json)
+        return self.override_from_dict(values_map)
+
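The JSON helpers above round-trip cleanly. A small illustration (hypothetical values):

    hparams = HParams(learning_rate=0.1, num_hidden_units=100)
    s = hparams.to_json(sort_keys=True)
    # s == '{"learning_rate": 0.1, "num_hidden_units": 100}'
    hparams.parse_json('{"learning_rate": 0.3}')  # learning_rate -> 0.3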
+    def values(self):
+        """Return the hyperparameter values as a Python dictionary.
+
+        Returns:
+          A dictionary with hyperparameter names as keys. The values are the
+          hyperparameter values.
+        """
+        return {n: getattr(self, n) for n in self._hparam_types.keys()}
+
+    def get(self, key, default=None):
+        """Returns the value of `key` if it exists, else `default`."""
+        if key in self._hparam_types:
+            # Ensure that default is compatible with the parameter type.
+            if default is not None:
+                param_type, is_param_list = self._hparam_types[key]
+                type_str = "list<%s>" % param_type if is_param_list else str(param_type)
+                fail_msg = (
+                    "Hparam '%s' of type '%s' is incompatible with "
+                    "default=%s" % (key, type_str, default)
+                )
+
+                is_default_list = isinstance(default, list)
+                if is_param_list != is_default_list:
+                    raise ValueError(fail_msg)
+
+                try:
+                    if is_default_list:
+                        for value in default:
+                            _cast_to_type_if_compatible(key, param_type, value)
+                    else:
+                        _cast_to_type_if_compatible(key, param_type, default)
+                except ValueError as e:
+                    raise ValueError("%s. %s" % (fail_msg, e))
+
+            return getattr(self, key)
+
+        return default
+
+    def __contains__(self, key):
+        return key in self._hparam_types
+
+    def __str__(self):
+        return str(sorted(self.values().items()))
+
+    def __repr__(self):
+        return "%s(%s)" % (type(self).__name__, self.__str__())
+
+    @staticmethod
+    def _get_kind_name(param_type, is_list):
+        """Returns the field name given parameter type and is_list.
+
+        Args:
+          param_type: Data type of the hparam.
+          is_list: Whether this is a list.
+
+        Returns:
+          A string representation of the field name.
+
+        Raises:
+          ValueError: If parameter type is not recognized.
+        """
+        if issubclass(param_type, bool):
+            # This check must happen before issubclass(param_type, six.integer_types),
+            # since Python considers bool to be a subclass of int.
+            typename = "bool"
+        elif issubclass(param_type, six.integer_types):
+            # Setting 'int' and 'long' types to be 'int64' to ensure the type is
+            # compatible with both Python2 and Python3.
+            typename = "int64"
+        elif issubclass(param_type, (six.string_types, six.binary_type)):
+            # Setting 'string' and 'bytes' types to be 'bytes' to ensure the type is
+            # compatible with both Python2 and Python3.
+            typename = "bytes"
+        elif issubclass(param_type, float):
+            typename = "float"
+        else:
+            raise ValueError("Unsupported parameter type: %s" % str(param_type))
+
+        suffix = "list" if is_list else "value"
+        return "_".join([typename, suffix])
utils/util.py
ADDED
@@ -0,0 +1,690 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import collections
+import glob
+import os
+import random
+import time
+import argparse
+from collections import OrderedDict
+
+import json5
+import numpy as np
+from torch.nn import functional as F
+
+
+try:
+    from ruamel.yaml import YAML as yaml
+except ImportError:
+    from ruamel_yaml import YAML as yaml
+
+import torch
+
+from utils.hparam import HParams
+import logging
+from logging import handlers
+
+
+def str2bool(v):
+    """Used in argparse.ArgumentParser.add_argument to indicate
+    that the argument is a boolean, for which the user can enter
+
+    - yes, true, t, y, 1, to represent True
+    - no, false, f, n, 0, to represent False
+
+    See https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse  # noqa
+    """
+    if isinstance(v, bool):
+        return v
+    if v.lower() in ("yes", "true", "t", "y", "1"):
+        return True
+    elif v.lower() in ("no", "false", "f", "n", "0"):
+        return False
+    else:
+        raise argparse.ArgumentTypeError("Boolean value expected.")
+
+
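A minimal sketch of wiring str2bool into an argument parser (the flag name is hypothetical):

    parser = argparse.ArgumentParser()
    parser.add_argument("--use_ema", type=str2bool, default=False)
    args = parser.parse_args(["--use_ema", "yes"])  # args.use_ema == True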
+def find_checkpoint_of_mapper(mapper_ckpt_dir):
+    mapper_ckpts = glob.glob(os.path.join(mapper_ckpt_dir, "ckpts/*.pt"))
+
+    # Select the checkpoint with the largest step number
+    mapper_ckpts.sort()
+    mapper_weights_file = mapper_ckpts[-1]
+    return mapper_weights_file
+
+
+def pad_f0_to_tensors(f0s, batched=None):
+    # Initialize
+    tensors = []
+
+    if batched is None:
+        # Get the max frame count for padding
+        size = -1
+        for f0 in f0s:
+            size = max(size, f0.shape[-1])
+
+        tensor = torch.zeros(len(f0s), size)
+
+        for i, f0 in enumerate(f0s):
+            tensor[i, : f0.shape[-1]] = f0[:]
+
+        tensors.append(tensor)
+    else:
+        start = 0
+        while start + batched - 1 < len(f0s):
+            end = start + batched - 1
+
+            # Get the max frame count for padding
+            size = -1
+            for i in range(start, end + 1):
+                size = max(size, f0s[i].shape[-1])
+
+            tensor = torch.zeros(batched, size)
+
+            for i in range(start, end + 1):
+                tensor[i - start, : f0s[i].shape[-1]] = f0s[i][:]
+
+            tensors.append(tensor)
+
+            start = start + batched
+
+        if start != len(f0s):
+            end = len(f0s)
+
+            # Get the max frame count for padding
+            size = -1
+            for i in range(start, end):
+                size = max(size, f0s[i].shape[-1])
+
+            tensor = torch.zeros(len(f0s) - start, size)
+
+            for i in range(start, end):
+                tensor[i - start, : f0s[i].shape[-1]] = f0s[i][:]
+
+            tensors.append(tensor)
+
+    return tensors
+
+
+def pad_mels_to_tensors(mels, batched=None):
+    """
+    Args:
+        mels: A list of mel-specs
+    Returns:
+        tensors: A list of tensors containing the batched mel-specs
+        mel_frames: A list of tensors containing the frames of the original mel-specs
+    """
+    # Initialize
+    tensors = []
+    mel_frames = []
+
+    # Split mel-specs into batches to avoid exceeding CUDA memory
+    if batched is None:
+        # Get the max frame count for padding
+        size = -1
+        for mel in mels:
+            size = max(size, mel.shape[-1])
+
+        tensor = torch.zeros(len(mels), mels[0].shape[0], size)
+        mel_frame = torch.zeros(len(mels), dtype=torch.int32)
+
+        for i, mel in enumerate(mels):
+            tensor[i, :, : mel.shape[-1]] = mel[:]
+            mel_frame[i] = mel.shape[-1]
+
+        tensors.append(tensor)
+        mel_frames.append(mel_frame)
+    else:
+        start = 0
+        while start + batched - 1 < len(mels):
+            end = start + batched - 1
+
+            # Get the max frame count for padding
+            size = -1
+            for i in range(start, end + 1):
+                size = max(size, mels[i].shape[-1])
+
+            tensor = torch.zeros(batched, mels[0].shape[0], size)
+            mel_frame = torch.zeros(batched, dtype=torch.int32)
+
+            for i in range(start, end + 1):
+                tensor[i - start, :, : mels[i].shape[-1]] = mels[i][:]
+                mel_frame[i - start] = mels[i].shape[-1]
+
+            tensors.append(tensor)
+            mel_frames.append(mel_frame)
+
+            start = start + batched
+
+        if start != len(mels):
+            end = len(mels)
+
+            # Get the max frame count for padding
+            size = -1
+            for i in range(start, end):
+                size = max(size, mels[i].shape[-1])
+
+            tensor = torch.zeros(len(mels) - start, mels[0].shape[0], size)
+            mel_frame = torch.zeros(len(mels) - start, dtype=torch.int32)
+
+            for i in range(start, end):
+                tensor[i - start, :, : mels[i].shape[-1]] = mels[i][:]
+                mel_frame[i - start] = mels[i].shape[-1]
+
+            tensors.append(tensor)
+            mel_frames.append(mel_frame)
+
+    return tensors, mel_frames
+
+
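A quick sketch of the padding helpers above (shapes are illustrative, not from the repo):

    mels = [torch.randn(80, 120), torch.randn(80, 95), torch.randn(80, 200)]
    tensors, mel_frames = pad_mels_to_tensors(mels, batched=2)
    # tensors[0].shape == (2, 80, 120); tensors[1].shape == (1, 80, 200)
    # mel_frames[0] == tensor([120, 95]); mel_frames[1] == tensor([200])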
+def load_model_config(args):
+    """Load model configurations (in args.json under checkpoint directory)
+
+    Args:
+        args (ArgumentParser): arguments to run bins/preprocess.py
+
+    Returns:
+        dict: dictionary that stores model configurations
+    """
+    if args.checkpoint_dir is None:
+        assert args.checkpoint_file is not None
+        checkpoint_dir = os.path.split(args.checkpoint_file)[0]
+    else:
+        checkpoint_dir = args.checkpoint_dir
+    config_path = os.path.join(checkpoint_dir, "args.json")
+    print("config_path: ", config_path)
+
+    config = load_config(config_path)
+    return config
+
+
+def remove_and_create(dir):
+    if os.path.exists(dir):
+        os.system("rm -r {}".format(dir))
+    os.makedirs(dir, exist_ok=True)
+
+
+def has_existed(path, warning=False):
+    if not warning:
+        return os.path.exists(path)
+
+    if os.path.exists(path):
+        answer = input(
+            "The path {} already exists. \nInput 'y' (or hit Enter) to skip it, and input 'n' to re-write it [y/n]\n".format(
+                path
+            )
+        )
+        if not answer == "n":
+            return True
+
+    return False
+
+
+def remove_older_ckpt(saved_model_name, checkpoint_dir, max_to_keep=5):
+    if os.path.exists(os.path.join(checkpoint_dir, "checkpoint")):
+        with open(os.path.join(checkpoint_dir, "checkpoint"), "r") as f:
+            ckpts = [x.strip() for x in f.readlines()]
+    else:
+        ckpts = []
+    ckpts.append(saved_model_name)
+    for item in ckpts[:-max_to_keep]:
+        if os.path.exists(os.path.join(checkpoint_dir, item)):
+            os.remove(os.path.join(checkpoint_dir, item))
+    with open(os.path.join(checkpoint_dir, "checkpoint"), "w") as f:
+        for item in ckpts[-max_to_keep:]:
+            f.write("{}\n".format(item))
+
+
+def set_all_random_seed(seed: int):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.random.manual_seed(seed)
+
+
+def save_checkpoint(
+    args,
+    generator,
+    g_optimizer,
+    step,
+    discriminator=None,
+    d_optimizer=None,
+    max_to_keep=5,
+):
+    saved_model_name = "model.ckpt-{}.pt".format(step)
+    checkpoint_path = os.path.join(args.checkpoint_dir, saved_model_name)
+
+    if discriminator and d_optimizer:
+        torch.save(
+            {
+                "generator": generator.state_dict(),
+                "discriminator": discriminator.state_dict(),
+                "g_optimizer": g_optimizer.state_dict(),
+                "d_optimizer": d_optimizer.state_dict(),
+                "global_step": step,
+            },
+            checkpoint_path,
+        )
+    else:
+        torch.save(
+            {
+                "generator": generator.state_dict(),
+                "g_optimizer": g_optimizer.state_dict(),
+                "global_step": step,
+            },
+            checkpoint_path,
+        )
+
+    print("Saved checkpoint: {}".format(checkpoint_path))
+
+    if os.path.exists(os.path.join(args.checkpoint_dir, "checkpoint")):
+        with open(os.path.join(args.checkpoint_dir, "checkpoint"), "r") as f:
+            ckpts = [x.strip() for x in f.readlines()]
+    else:
+        ckpts = []
+    ckpts.append(saved_model_name)
+    for item in ckpts[:-max_to_keep]:
+        if os.path.exists(os.path.join(args.checkpoint_dir, item)):
+            os.remove(os.path.join(args.checkpoint_dir, item))
+    with open(os.path.join(args.checkpoint_dir, "checkpoint"), "w") as f:
+        for item in ckpts[-max_to_keep:]:
+            f.write("{}\n".format(item))
+
+
+def attempt_to_restore(
+    generator, g_optimizer, checkpoint_dir, discriminator=None, d_optimizer=None
+):
+    checkpoint_list = os.path.join(checkpoint_dir, "checkpoint")
+    if os.path.exists(checkpoint_list):
+        checkpoint_filename = open(checkpoint_list).readlines()[-1].strip()
+        checkpoint_path = os.path.join(checkpoint_dir, "{}".format(checkpoint_filename))
+        print("Restore from {}".format(checkpoint_path))
+        checkpoint = torch.load(checkpoint_path, map_location="cpu")
+        if generator:
+            if not list(generator.state_dict().keys())[0].startswith("module."):
+                # The live model is unwrapped: strip the "module." prefix
+                # that DataParallel/DDP wrappers leave in checkpoint keys.
+                raw_dict = checkpoint["generator"]
+                clean_dict = OrderedDict()
+                for k, v in raw_dict.items():
+                    if k.startswith("module."):
+                        clean_dict[k[7:]] = v
+                    else:
+                        clean_dict[k] = v
+                generator.load_state_dict(clean_dict)
+            else:
+                generator.load_state_dict(checkpoint["generator"])
+        if g_optimizer:
+            g_optimizer.load_state_dict(checkpoint["g_optimizer"])
+        global_step = 100000
+        if discriminator and "discriminator" in checkpoint.keys():
+            discriminator.load_state_dict(checkpoint["discriminator"])
+            global_step = checkpoint["global_step"]
+            print("restore discriminator")
+        if d_optimizer and "d_optimizer" in checkpoint.keys():
+            d_optimizer.load_state_dict(checkpoint["d_optimizer"])
+            print("restore d_optimizer...")
+    else:
+        global_step = 0
+    return global_step
+
+
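The two functions above pair up: save_checkpoint appends the file name to a plain-text "checkpoint" manifest (keeping the newest max_to_keep entries), and attempt_to_restore reads the last entry of that manifest. A hypothetical round trip:

    save_checkpoint(args, generator, g_optimizer, step=2000)
    # later, possibly in a fresh process:
    step = attempt_to_restore(generator, g_optimizer, args.checkpoint_dir)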
+class ExponentialMovingAverage(object):
+    def __init__(self, decay):
+        self.decay = decay
+        self.shadow = {}
+
+    def register(self, name, val):
+        self.shadow[name] = val.clone()
+
+    def update(self, name, x):
+        assert name in self.shadow
+        update_delta = self.shadow[name] - x
+        self.shadow[name] -= (1.0 - self.decay) * update_delta
+
+
+def apply_moving_average(model, ema):
+    for name, param in model.named_parameters():
+        if name in ema.shadow:
+            ema.update(name, param.data)
+
+
+def register_model_to_ema(model, ema):
+    for name, param in model.named_parameters():
+        if param.requires_grad:
+            ema.register(name, param.data)
+
+
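A minimal sketch of the EMA workflow (the decay value and train_step are illustrative):

    ema = ExponentialMovingAverage(decay=0.999)
    register_model_to_ema(model, ema)     # snapshot trainable weights
    for batch in loader:
        train_step(model, batch)          # hypothetical training step
        apply_moving_average(model, ema)  # shadow <- decay*shadow + (1-decay)*param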
+class YParams(HParams):
+    def __init__(self, yaml_file):
+        if not os.path.exists(yaml_file):
+            raise IOError("yaml file: {} does not exist".format(yaml_file))
+        super().__init__()
+        self.d = collections.OrderedDict()
+        with open(yaml_file) as fp:
+            for _, v in yaml().load(fp).items():
+                for k1, v1 in v.items():
+                    try:
+                        if self.get(k1):
+                            self.set_hparam(k1, v1)
+                        else:
+                            self.add_hparam(k1, v1)
+                        self.d[k1] = v1
+                    except Exception:
+                        import traceback
+
+                        print(traceback.format_exc())
+
+    # @property
+    def get_elements(self):
+        return self.d.items()
+
+
+def override_config(base_config, new_config):
+    """Update the original dict with configurations from the new dict
+
+    Args:
+        base_config (dict): original dict to be overridden
+        new_config (dict): dict with new configurations
+
+    Returns:
+        dict: updated configuration dict
+    """
+    for k, v in new_config.items():
+        if type(v) == dict:
+            if k not in base_config.keys():
+                base_config[k] = {}
+            base_config[k] = override_config(base_config[k], v)
+        else:
+            base_config[k] = v
+    return base_config
+
+
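override_config merges recursively, so nested keys are updated in place while unrelated siblings survive. For instance (hypothetical values):

    base = {"model": {"dim": 256, "layers": 4}, "lr": 1e-4}
    new = {"model": {"dim": 512}}
    override_config(base, new)
    # -> {"model": {"dim": 512, "layers": 4}, "lr": 1e-4}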
+def get_lowercase_keys_config(cfg):
+    """Change all keys in cfg to lower case
+
+    Args:
+        cfg (dict): dictionary that stores configurations
+
+    Returns:
+        dict: dictionary that stores configurations
+    """
+    updated_cfg = dict()
+    for k, v in cfg.items():
+        if type(v) == dict:
+            v = get_lowercase_keys_config(v)
+        updated_cfg[k.lower()] = v
+    return updated_cfg
+
+
+def _load_config(config_fn, lowercase=False):
+    """Load configurations into a dictionary
+
+    Args:
+        config_fn (str): path to configuration file
+        lowercase (bool, optional): whether to change keys to lower case. Defaults to False.
+
+    Returns:
+        dict: dictionary that stores configurations
+    """
+    with open(config_fn, "r") as f:
+        data = f.read()
+    config_ = json5.loads(data)
+    if "base_config" in config_:
+        # load configurations from the parent config path
+        try:
+            p_config_path = os.path.join(os.getenv("WORK_DIR"), config_["base_config"])
+        except Exception:
+            p_config_path = config_["base_config"]
+        p_config_ = _load_config(p_config_path)
+        config_ = override_config(p_config_, config_)
+    if lowercase:
+        # change keys in config_ to lower case
+        config_ = get_lowercase_keys_config(config_)
+    return config_
+
+
+def load_config(config_fn, lowercase=False):
+    """Load configurations into a JsonHParams object
+
+    Args:
+        config_fn (str): path to configuration file
+        lowercase (bool, optional): whether to change keys to lower case. Defaults to False.
+
+    Returns:
+        JsonHParams: an object that stores configurations
+    """
+    config_ = _load_config(config_fn, lowercase=lowercase)
+    # create a JsonHParams object with the configuration dict
+    cfg = JsonHParams(**config_)
+    return cfg
+
+
+def save_config(save_path, cfg):
+    """Save configurations into a json file
+
+    Args:
+        save_path (str): path to save configurations
+        cfg (dict): dictionary that stores configurations
+    """
+    with open(save_path, "w") as f:
+        json5.dump(
+            cfg, f, ensure_ascii=False, indent=4, quote_keys=True, sort_keys=True
+        )
+
+
+class JsonHParams:
+    def __init__(self, **kwargs):
+        for k, v in kwargs.items():
+            if type(v) == dict:
+                v = JsonHParams(**v)
+            self[k] = v
+
+    def keys(self):
+        return self.__dict__.keys()
+
+    def items(self):
+        return self.__dict__.items()
+
+    def values(self):
+        return self.__dict__.values()
+
+    def __len__(self):
+        return len(self.__dict__)
+
+    def __getitem__(self, key):
+        return getattr(self, key)
+
+    def __setitem__(self, key, value):
+        return setattr(self, key, value)
+
+    def __contains__(self, key):
+        return key in self.__dict__
+
+    def __repr__(self):
+        return self.__dict__.__repr__()
+
+
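Together, these pieces let a JSON5 config inherit from a "base_config" and be accessed with attribute syntax. A hypothetical use (the key names below are illustrative, not taken from the actual config files):

    cfg = load_config("config/flow_matching.json")
    # nested dicts become nested JsonHParams, so both styles work:
    cfg.model.hidden_size
    cfg["model"]["hidden_size"]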
+class ValueWindow:
+    def __init__(self, window_size=100):
+        self._window_size = window_size
+        self._values = []
+
+    def append(self, x):
+        self._values = self._values[-(self._window_size - 1) :] + [x]
+
+    @property
+    def sum(self):
+        return sum(self._values)
+
+    @property
+    def count(self):
+        return len(self._values)
+
+    @property
+    def average(self):
+        return self.sum / max(1, self.count)
+
+    def reset(self):
+        self._values = []
+
+
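ValueWindow is a simple running-average buffer, typically used to smooth loss values for logging:

    loss_window = ValueWindow(window_size=100)
    loss_window.append(0.9)
    loss_window.append(0.7)
    loss_window.average  # -> 0.8, over at most the last 100 values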
+class Logger(object):
+    def __init__(
+        self,
+        filename,
+        level="info",
+        when="D",
+        backCount=10,
+        fmt="%(asctime)s : %(message)s",
+    ):
+        self.level_relations = {
+            "debug": logging.DEBUG,
+            "info": logging.INFO,
+            "warning": logging.WARNING,
+            "error": logging.ERROR,
+            "crit": logging.CRITICAL,
+        }
+        if level == "debug":
+            fmt = "%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s"
+        self.logger = logging.getLogger(filename)
+        format_str = logging.Formatter(fmt)
+        self.logger.setLevel(self.level_relations.get(level))
+        sh = logging.StreamHandler()
+        sh.setFormatter(format_str)
+        th = handlers.TimedRotatingFileHandler(
+            filename=filename, when=when, backupCount=backCount, encoding="utf-8"
+        )
+        th.setFormatter(format_str)
+        self.logger.addHandler(sh)
+        self.logger.addHandler(th)
+        self.logger.info(
+            "==========================New Starting Here=============================="
+        )
+
+
+def init_weights(m, mean=0.0, std=0.01):
+    classname = m.__class__.__name__
+    if classname.find("Conv") != -1:
+        m.weight.data.normal_(mean, std)
+
+
+def get_padding(kernel_size, dilation=1):
+    return int((kernel_size * dilation - dilation) / 2)
+
+
+def slice_segments(x, ids_str, segment_size=4):
+    ret = torch.zeros_like(x[:, :, :segment_size])
+    for i in range(x.size(0)):
+        idx_str = ids_str[i]
+        idx_end = idx_str + segment_size
+        ret[i] = x[i, :, idx_str:idx_end]
+    return ret
+
+
+def rand_slice_segments(x, x_lengths=None, segment_size=4):
+    b, d, t = x.size()
+    if x_lengths is None:
+        x_lengths = t
+    ids_str_max = x_lengths - segment_size + 1
+    ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
+    ret = slice_segments(x, ids_str, segment_size)
+    return ret, ids_str
+
+
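rand_slice_segments draws one random window per batch element, which is the usual way GAN vocoders crop training segments. A shape sketch (shapes are illustrative):

    x = torch.randn(8, 192, 400)  # (batch, dim, frames)
    seg, ids = rand_slice_segments(x, segment_size=32)
    # seg.shape == (8, 192, 32); ids holds each window's start frame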
+def subsequent_mask(length):
+    mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
+    return mask
+
+
+@torch.jit.script
+def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
+    n_channels_int = n_channels[0]
+    in_act = input_a + input_b
+    t_act = torch.tanh(in_act[:, :n_channels_int, :])
+    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
+    acts = t_act * s_act
+    return acts
+
+
+def convert_pad_shape(pad_shape):
+    l = pad_shape[::-1]
+    pad_shape = [item for sublist in l for item in sublist]
+    return pad_shape
+
+
+def sequence_mask(length, max_length=None):
+    if max_length is None:
+        max_length = length.max()
+    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
+    return x.unsqueeze(0) < length.unsqueeze(1)
+
+
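sequence_mask is the positive-polarity counterpart of make_pad_mask defined below: valid positions are True. For example:

    sequence_mask(torch.tensor([1, 3]))
    # tensor([[ True, False, False],
    #         [ True,  True,  True]])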
+def generate_path(duration, mask):
+    """
+    duration: [b, 1, t_x]
+    mask: [b, 1, t_y, t_x]
+    """
+    device = duration.device
+
+    b, _, t_y, t_x = mask.shape
+    cum_duration = torch.cumsum(duration, -1)
+
+    cum_duration_flat = cum_duration.view(b * t_x)
+    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
+    path = path.view(b, t_x, t_y)
+    path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
+    path = path.unsqueeze(1).transpose(2, 3) * mask
+    return path
+
+
+def clip_grad_value_(parameters, clip_value, norm_type=2):
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    parameters = list(filter(lambda p: p.grad is not None, parameters))
+    norm_type = float(norm_type)
+    if clip_value is not None:
+        clip_value = float(clip_value)
+
+    total_norm = 0
+    for p in parameters:
+        param_norm = p.grad.data.norm(norm_type)
+        total_norm += param_norm.item() ** norm_type
+        if clip_value is not None:
+            p.grad.data.clamp_(min=-clip_value, max=clip_value)
+    total_norm = total_norm ** (1.0 / norm_type)
+    return total_norm
+
+
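Unlike torch.nn.utils.clip_grad_value_, this variant also returns the global gradient norm computed before clipping, which is handy for logging:

    grad_norm = clip_grad_value_(model.parameters(), clip_value=5.0)
    # grads are clamped to [-5, 5]; grad_norm is the pre-clipping L2 norm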
+def get_current_time():
+    pass
+
+
+def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
+    """
+    Args:
+      lengths:
+        A 1-D tensor containing sentence lengths.
+      max_len:
+        The length of masks.
+    Returns:
+      Return a 2-D bool tensor, where masked positions
+      are filled with `True` and non-masked positions are
+      filled with `False`.
+
+    >>> lengths = torch.tensor([1, 3, 2, 5])
+    >>> make_pad_mask(lengths)
+    tensor([[False, True, True, True, True],
+            [False, False, False, True, True],
+            [False, False, True, True, True],
+            [False, False, False, False, False]])
+    """
+    assert lengths.ndim == 1, lengths.ndim
+    max_len = max(max_len, lengths.max())
+    n = lengths.size(0)
+    seq_range = torch.arange(0, max_len, device=lengths.device)
+    expanded_lengths = seq_range.unsqueeze(0).expand(n, max_len)
+
+    return expanded_lengths >= lengths.unsqueeze(-1)