# BlueV2/models/utils.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio.transforms as T


def compress_latents(z: torch.Tensor, factor: int = 6) -> torch.Tensor:
    """Fold `factor` consecutive frames into the channel axis: (B, C, T) -> (B, C * factor, T / factor)."""
    B, C, T = z.shape
    if T % factor != 0:
        # Right-pad the time axis with zeros so it divides evenly by `factor`.
        pad = factor - (T % factor)
        z = F.pad(z, (0, pad))
        T = T + pad
    return z.view(B, C, T // factor, factor).permute(0, 1, 3, 2).flatten(1, 2)


def decompress_latents(z: torch.Tensor, factor: int = 6, target_channels: int = 24) -> torch.Tensor:
    """Inverse of compress_latents: (B, target_channels * factor, T_low) -> (B, target_channels, T_low * factor)."""
    B, _, T_low = z.shape
    return z.view(B, target_channels, factor, T_low).permute(0, 1, 3, 2).flatten(2, 3)
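

# Round-trip sketch (shapes assume the defaults above; the tensors are
# illustrative): compress_latents folds `factor` consecutive frames into the
# channel axis, and decompress_latents inverts it exactly, up to the zero
# padding appended on the right of the time axis.
#
#     z = torch.randn(2, 24, 100)                                 # (B, C, T)
#     zc = compress_latents(z, factor=6)                          # (2, 144, 17); T padded 100 -> 102
#     zr = decompress_latents(zc, factor=6, target_channels=24)   # (2, 24, 102)
#     assert torch.equal(zr[..., :100], z)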


def _resolve_vocab_size(char_dict_path, default=256):
    """Infer the vocabulary size from a character-dictionary JSON, falling back to `default`."""
    import json as _json
    import os as _os
    if char_dict_path and _os.path.exists(char_dict_path):
        try:
            with open(char_dict_path, "r") as f:
                cd = _json.load(f)
            if isinstance(cd, dict) and "vocab_size" in cd:
                return int(cd["vocab_size"])
            if isinstance(cd, dict) and "char_to_id" in cd and isinstance(cd["char_to_id"], dict):
                return max(cd["char_to_id"].values()) + 1
            if isinstance(cd, dict):
                # Bare {char: id} mapping; ids are assumed to be 0-based.
                return max(cd.values()) + 1 if cd else default
            # A plain list of characters: the vocab size is its length.
            return len(cd)
        except Exception:
            pass
    return default
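

# _resolve_vocab_size accepts several char-dict layouts (values illustrative):
#
#     {"vocab_size": 512}                 -> 512
#     {"char_to_id": {"a": 0, "b": 1}}    -> 2  (max id + 1)
#     {"a": 0, "b": 1}                    -> 2  (bare mapping, max id + 1)
#     ["<pad>", "a", "b"]                 -> 3  (length of the list)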


def load_ttl_config(config_path="configs/tts.json"):
    """Load the full TTS config and flatten the fields the model builders consume."""
    import json
    with open(config_path, "r") as f:
        full_config = json.load(f)
    ttl = full_config["ttl"]
    ae = full_config.get("ae", {})
    dp = full_config.get("dp", {})
    te = ttl["text_encoder"]
    se = ttl["style_encoder"]
    vf = ttl["vector_field"]
    um = ttl["uncond_masker"]
    # Vocab sizes are resolved from the char-dict files, with fallbacks.
    char_dict_path = te.get("char_dict_path", te.get("text_embedder", {}).get("char_dict_path"))
    vocab_size = _resolve_vocab_size(char_dict_path, default=256)
    dp_char_dict_path = (
        dp.get("sentence_encoder", {}).get("char_dict_path")
        or dp.get("sentence_encoder", {}).get("text_embedder", {}).get("char_dict_path")
    )
    dp_vocab_size = _resolve_vocab_size(dp_char_dict_path, default=vocab_size)
    # Autoencoder decoder/encoder configs, with defaults for missing keys.
    ae_dec = ae.get("decoder", {})
    ae_dec_cfg = {
        "idim": ae_dec.get("idim", 24),
        "hdim": ae_dec.get("hdim", 512),
        "intermediate_dim": ae_dec.get("intermediate_dim", 2048),
        "ksz": ae_dec.get("ksz", 7),
        "dilation_lst": ae_dec.get("dilation_lst", [1, 2, 4, 1, 2, 4, 1, 1, 1, 1]),
        "chunk_compress_factor": ae.get("chunk_compress_factor", 1),
        "head": {
            "idim": ae_dec.get("head", {}).get("idim", ae_dec.get("hdim", 512)),
            "hdim": ae_dec.get("head", {}).get("hdim", 2048),
            "odim": ae_dec.get("head", {}).get("odim", 512),
            "ksz": ae_dec.get("head", {}).get("ksz", 3),
        },
    }
    ae_enc = ae.get("encoder", {})
    ae_enc_spec = ae_enc.get("spec_processor", {})
    ae_enc_cfg = {
        "ksz": ae_enc.get("ksz", 7),
        "hdim": ae_enc.get("hdim", 512),
        "intermediate_dim": ae_enc.get("intermediate_dim", 2048),
        "dilation_lst": ae_enc.get("dilation_lst", [1] * 10),
        "odim": ae_enc.get("odim", 24),
        "idim": ae_enc.get("idim", 1253),
    }
    dp_se = dp.get("style_encoder", {}).get("style_token_layer", {})
    return {
        "full_config": full_config,
        "ttl": ttl,
        "ae": ae,
        "dp": dp,
        "vocab_size": vocab_size,
        "char_dict_path": char_dict_path,
        "dp_vocab_size": dp_vocab_size,
        "latent_dim": ttl["latent_dim"],
        "chunk_compress_factor": ttl["chunk_compress_factor"],
        "compressed_channels": ttl["latent_dim"] * ttl["chunk_compress_factor"],
        "normalizer_scale": ttl["normalizer"]["scale"],
        "sigma_min": ttl["flow_matching"]["sig_min"],
        "Ke": ttl["batch_expander"]["n_batch_expand"],
        "te_d_model": te["text_embedder"]["char_emb_dim"],
        "te_convnext_layers": te["convnext"]["num_layers"],
        "te_expansion_factor": te["convnext"]["intermediate_dim"] // te["text_embedder"]["char_emb_dim"],
        "te_attn_n_layers": te["attn_encoder"]["n_layers"],
        "te_attn_p_dropout": te["attn_encoder"]["p_dropout"],
        "se_d_model": se["proj_in"]["odim"],
        "se_hidden_dim": se["convnext"]["intermediate_dim"],
        "se_num_blocks": se["convnext"]["num_layers"],
        "se_n_style": se["style_token_layer"]["n_style"],
        "se_n_heads": se["style_token_layer"]["n_heads"],
        "prob_both_uncond": um["prob_both_uncond"],
        "prob_text_uncond": um["prob_text_uncond"],
        "uncond_init_std": um["std"],
        "um_text_dim": um["text_dim"],
        "um_n_style": um["n_style"],
        "um_style_key_dim": um["style_key_dim"],
        "um_style_value_dim": um["style_value_dim"],
        "vf_hidden": vf["proj_in"]["odim"],
        "vf_time_dim": vf["time_encoder"]["time_dim"],
        "vf_n_blocks": vf["main_blocks"]["n_blocks"],
        "vf_text_dim": vf["main_blocks"]["text_cond_layer"]["text_dim"],
        "vf_text_n_heads": vf["main_blocks"]["text_cond_layer"]["n_heads"],
        "vf_style_dim": vf["main_blocks"]["style_cond_layer"]["style_dim"],
        "vf_rotary_scale": vf["main_blocks"]["text_cond_layer"]["rotary_scale"],
        "ae_dec_cfg": ae_dec_cfg,
        "ae_enc_cfg": ae_enc_cfg,
        "ae_sample_rate": ae.get("sample_rate", 44100),
        "ae_n_fft": ae_enc_spec.get("n_fft", 2048),
        "ae_hop_length": ae_enc_spec.get("hop_length", 512),
        "ae_n_mels": ae_enc_spec.get("n_mels", 1253),
        "dp_style_tokens": dp_se.get("n_style", 8),
        "dp_style_dim": dp_se.get("style_value_dim", 16),
    }
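

# Typical use (the config path and shapes here are illustrative only):
#
#     cfg = load_ttl_config("configs/tts.json")
#     z = torch.randn(1, cfg["latent_dim"], 120)
#     zc = compress_latents(z, factor=cfg["chunk_compress_factor"])
#     assert zc.shape[1] == cfg["compressed_channels"]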


class MelSpectrogram(nn.Module):
    """Log-magnitude mel spectrogram (power=1.0, clamped at 1e-5 before the log)."""

    def __init__(self, sample_rate=44100, n_fft=2048, win_length=2048,
                 hop_length=512, n_mels=1253, f_min=0, f_max=None):
        super().__init__()
        self.mel = T.MelSpectrogram(
            sample_rate=sample_rate, n_fft=n_fft, win_length=win_length,
            hop_length=hop_length, n_mels=n_mels, f_min=f_min, f_max=f_max,
            center=True, power=1.0,
        )

    def forward(self, audio):
        # Clamp before the log so silent frames do not produce -inf.
        mel = torch.log(torch.clamp(self.mel(audio), min=1e-5))
        # Drop a singleton channel axis if the input carried one.
        return mel.squeeze(1) if mel.dim() == 4 and mel.shape[1] == 1 else mel


class MelSpectrogramNoLog(nn.Module):
    """Linear-scale mel spectrogram; unlike MelSpectrogram, no log is applied."""

    def __init__(self, sample_rate=44100, n_fft=2048, win_length=2048,
                 hop_length=512, n_mels=1253, f_min=0, f_max=12000, power=1.0):
        super().__init__()
        self.mel = T.MelSpectrogram(
            sample_rate=sample_rate, n_fft=n_fft, win_length=win_length,
            hop_length=hop_length, n_mels=n_mels, f_min=f_min, f_max=f_max,
            center=True, power=power,
        )

    def forward(self, audio):
        mel = self.mel(audio)
        # Drop a singleton channel axis if the input carried one.
        return mel.squeeze(1) if mel.dim() == 4 and mel.shape[1] == 1 else mel


class LinearMelSpectrogram(nn.Module):
    """Log-linear spectrogram and log-mel spectrogram, concatenated along the frequency axis."""

    def __init__(self, sample_rate=44100, n_fft=2048, win_length=2048,
                 hop_length=512, n_mels=1253, f_min=0, f_max=None):
        super().__init__()
        self.spectrogram = T.Spectrogram(
            n_fft=n_fft, win_length=win_length, hop_length=hop_length,
            center=True, power=1.0,
        )
        self.mel_scale = T.MelScale(
            n_mels=n_mels, sample_rate=sample_rate,
            n_stft=n_fft // 2 + 1, f_min=f_min, f_max=f_max,
        )

    def forward(self, audio):
        spec = self.spectrogram(audio)   # (B, n_fft // 2 + 1, frames)
        mel = self.mel_scale(spec)       # (B, n_mels, frames)
        log_spec = torch.log(torch.clamp(spec, min=1e-5))
        log_mel = torch.log(torch.clamp(mel, min=1e-5))
        return torch.cat([log_spec, log_mel], dim=1)
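

# Shape sketch (with the 44.1 kHz defaults above; values are illustrative):
# for mono input `audio` of shape (B, samples), MelSpectrogram returns a
# log-scale (B, n_mels, frames) tensor, MelSpectrogramNoLog the same shape in
# linear scale, and LinearMelSpectrogram (B, n_fft // 2 + 1 + n_mels, frames),
# i.e. the log-linear spectrogram concatenated with the log-mel along dim=1.
#
#     audio = torch.randn(2, 44100)          # 1 s at 44.1 kHz
#     log_mel = MelSpectrogram()(audio)      # (2, 1253, 87)
#     both = LinearMelSpectrogram()(audio)   # (2, 1025 + 1253, 87)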