import json
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio.transforms as T


def compress_latents(z: torch.Tensor, factor: int = 6) -> torch.Tensor:
    """Fold `factor` consecutive time steps into the channel axis:
    (B, C, T) -> (B, C * factor, T // factor). The time axis is right-padded
    with zeros so its length is a multiple of `factor`."""
    B, C, T_len = z.shape  # local name avoids shadowing the torchaudio alias T
    if T_len % factor != 0:
        pad = factor - (T_len % factor)
        z = F.pad(z, (0, pad))
        T_len = T_len + pad
    return z.view(B, C, T_len // factor, factor).permute(0, 1, 3, 2).flatten(1, 2)


def decompress_latents(z: torch.Tensor, factor: int = 6, target_channels: int = 24) -> torch.Tensor:
    """Inverse of `compress_latents`:
    (B, target_channels * factor, T_low) -> (B, target_channels, T_low * factor)."""
    B, _, T_low = z.shape
    return z.view(B, target_channels, factor, T_low).permute(0, 1, 3, 2).flatten(2, 3)


def _resolve_vocab_size(char_dict_path, default=256):
    """Infer the vocabulary size from a character-dictionary JSON file.

    Handles an explicit "vocab_size" field, a nested "char_to_id" mapping,
    or a flat {char: id} mapping; a JSON list falls back to its length.
    Returns `default` if the file is missing or unreadable.
    """
    if char_dict_path and os.path.exists(char_dict_path):
        try:
            with open(char_dict_path, "r") as f:
                cd = json.load(f)
            if isinstance(cd, dict) and "vocab_size" in cd:
                return int(cd["vocab_size"])
            if isinstance(cd, dict) and "char_to_id" in cd and isinstance(cd["char_to_id"], dict):
                return max(cd["char_to_id"].values()) + 1
            if isinstance(cd, dict):
                return max(cd.values()) + 1 if cd else default
            return len(cd)
        except Exception:
            pass
    return default


def load_ttl_config(config_path="configs/tts.json"):
    """Load the full config JSON and flatten the fields needed to build the
    TTL, AE, and DP modules into a single dictionary."""
    with open(config_path, "r") as f:
        full_config = json.load(f)

    ttl = full_config["ttl"]
    ae = full_config.get("ae", {})
    dp = full_config.get("dp", {})

    te = ttl["text_encoder"]
    se = ttl["style_encoder"]
    vf = ttl["vector_field"]
    um = ttl["uncond_masker"]

    char_dict_path = te.get("char_dict_path", te.get("text_embedder", {}).get("char_dict_path"))
    vocab_size = _resolve_vocab_size(char_dict_path, default=256)

    dp_char_dict_path = (
        dp.get("sentence_encoder", {}).get("char_dict_path")
        or dp.get("sentence_encoder", {}).get("text_embedder", {}).get("char_dict_path")
    )
    dp_vocab_size = _resolve_vocab_size(dp_char_dict_path, default=vocab_size)

    ae_dec = ae.get("decoder", {})
    ae_dec_cfg = {
        "idim": ae_dec.get("idim", 24),
        "hdim": ae_dec.get("hdim", 512),
        "intermediate_dim": ae_dec.get("intermediate_dim", 2048),
        "ksz": ae_dec.get("ksz", 7),
        "dilation_lst": ae_dec.get("dilation_lst", [1, 2, 4, 1, 2, 4, 1, 1, 1, 1]),
        "chunk_compress_factor": ae.get("chunk_compress_factor", 1),
        "head": {
            "idim": ae_dec.get("head", {}).get("idim", ae_dec.get("hdim", 512)),
            "hdim": ae_dec.get("head", {}).get("hdim", 2048),
            "odim": ae_dec.get("head", {}).get("odim", 512),
            "ksz": ae_dec.get("head", {}).get("ksz", 3),
        },
    }

    ae_enc = ae.get("encoder", {})
    ae_enc_spec = ae_enc.get("spec_processor", {})
    ae_enc_cfg = {
        "ksz": ae_enc.get("ksz", 7),
        "hdim": ae_enc.get("hdim", 512),
        "intermediate_dim": ae_enc.get("intermediate_dim", 2048),
        "dilation_lst": ae_enc.get("dilation_lst", [1] * 10),
        "odim": ae_enc.get("odim", 24),
        "idim": ae_enc.get("idim", 1253),
    }

    dp_se = dp.get("style_encoder", {}).get("style_token_layer", {})

    return {
        "full_config": full_config,
        "ttl": ttl,
        "ae": ae,
        "dp": dp,
        "vocab_size": vocab_size,
        "char_dict_path": char_dict_path,
        "dp_vocab_size": dp_vocab_size,
        "latent_dim": ttl["latent_dim"],
        "chunk_compress_factor": ttl["chunk_compress_factor"],
        "compressed_channels": ttl["latent_dim"] * ttl["chunk_compress_factor"],
        "normalizer_scale": ttl["normalizer"]["scale"],
        "sigma_min": ttl["flow_matching"]["sig_min"],
        "Ke": ttl["batch_expander"]["n_batch_expand"],
        "te_d_model": te["text_embedder"]["char_emb_dim"],
        "te_convnext_layers": te["convnext"]["num_layers"],
        "te_expansion_factor": te["convnext"]["intermediate_dim"] // te["text_embedder"]["char_emb_dim"],
        "te_attn_n_layers": te["attn_encoder"]["n_layers"],
        "te_attn_p_dropout": te["attn_encoder"]["p_dropout"],
        "se_d_model": se["proj_in"]["odim"],
        "se_hidden_dim": se["convnext"]["intermediate_dim"],
        "se_num_blocks": se["convnext"]["num_layers"],
        "se_n_style": se["style_token_layer"]["n_style"],
        "se_n_heads": se["style_token_layer"]["n_heads"],
        "prob_both_uncond": um["prob_both_uncond"],
        "prob_text_uncond": um["prob_text_uncond"],
        "uncond_init_std": um["std"],
        "um_text_dim": um["text_dim"],
        "um_n_style": um["n_style"],
        "um_style_key_dim": um["style_key_dim"],
        "um_style_value_dim": um["style_value_dim"],
        "vf_hidden": vf["proj_in"]["odim"],
        "vf_time_dim": vf["time_encoder"]["time_dim"],
        "vf_n_blocks": vf["main_blocks"]["n_blocks"],
        "vf_text_dim": vf["main_blocks"]["text_cond_layer"]["text_dim"],
        "vf_text_n_heads": vf["main_blocks"]["text_cond_layer"]["n_heads"],
        "vf_style_dim": vf["main_blocks"]["style_cond_layer"]["style_dim"],
        "vf_rotary_scale": vf["main_blocks"]["text_cond_layer"]["rotary_scale"],
        "ae_dec_cfg": ae_dec_cfg,
        "ae_enc_cfg": ae_enc_cfg,
        "ae_sample_rate": ae.get("sample_rate", 44100),
        "ae_n_fft": ae_enc_spec.get("n_fft", 2048),
        "ae_hop_length": ae_enc_spec.get("hop_length", 512),
        "ae_n_mels": ae_enc_spec.get("n_mels", 1253),
        "dp_style_tokens": dp_se.get("n_style", 8),
        "dp_style_dim": dp_se.get("style_value_dim", 16),
    }
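
# A minimal, self-contained check for `_resolve_vocab_size`. The helper name
# and the char-dictionary layout below are hypothetical, chosen only to
# exercise the "char_to_id" branch; they are not part of the module's API.
def _demo_resolve_vocab_size():
    import tempfile

    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
        json.dump({"char_to_id": {"a": 0, "b": 1, "c": 2}}, f)
        path = f.name
    try:
        assert _resolve_vocab_size(path) == 3  # max id (2) + 1
    finally:
        os.remove(path)
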
"se_d_model": se["proj_in"]["odim"], "se_hidden_dim": se["convnext"]["intermediate_dim"], "se_num_blocks": se["convnext"]["num_layers"], "se_n_style": se["style_token_layer"]["n_style"], "se_n_heads": se["style_token_layer"]["n_heads"], "prob_both_uncond": um["prob_both_uncond"], "prob_text_uncond": um["prob_text_uncond"], "uncond_init_std": um["std"], "um_text_dim": um["text_dim"], "um_n_style": um["n_style"], "um_style_key_dim": um["style_key_dim"], "um_style_value_dim": um["style_value_dim"], "vf_hidden": vf["proj_in"]["odim"], "vf_time_dim": vf["time_encoder"]["time_dim"], "vf_n_blocks": vf["main_blocks"]["n_blocks"], "vf_text_dim": vf["main_blocks"]["text_cond_layer"]["text_dim"], "vf_text_n_heads": vf["main_blocks"]["text_cond_layer"]["n_heads"], "vf_style_dim": vf["main_blocks"]["style_cond_layer"]["style_dim"], "vf_rotary_scale": vf["main_blocks"]["text_cond_layer"]["rotary_scale"], "ae_dec_cfg": ae_dec_cfg, "ae_enc_cfg": ae_enc_cfg, "ae_sample_rate": ae.get("sample_rate", 44100), "ae_n_fft": ae_enc_spec.get("n_fft", 2048), "ae_hop_length": ae_enc_spec.get("hop_length", 512), "ae_n_mels": ae_enc_spec.get("n_mels", 1253), "dp_style_tokens": dp_se.get("n_style", 8), "dp_style_dim": dp_se.get("style_value_dim", 16), } class MelSpectrogram(nn.Module): def __init__(self, sample_rate=44100, n_fft=2048, win_length=2048, hop_length=512, n_mels=1253, f_min=0, f_max=None): super().__init__() self.mel = T.MelSpectrogram( sample_rate=sample_rate, n_fft=n_fft, win_length=win_length, hop_length=hop_length, n_mels=n_mels, f_min=f_min, f_max=f_max, center=True, power=1.0, ) def forward(self, audio): mel = torch.log(torch.clamp(self.mel(audio), min=1e-5)) return mel.squeeze(1) if mel.dim() == 4 and mel.shape[1] == 1 else mel class MelSpectrogramNoLog(nn.Module): def __init__(self, sample_rate=44100, n_fft=2048, win_length=2048, hop_length=512, n_mels=1253, f_min=0, f_max=12000, power=1.0): super().__init__() self.mel = T.MelSpectrogram( sample_rate=sample_rate, n_fft=n_fft, win_length=win_length, hop_length=hop_length, n_mels=n_mels, f_min=f_min, f_max=f_max, center=True, power=power, ) def forward(self, audio): mel = self.mel(audio) return mel.squeeze(1) if mel.dim() == 4 and mel.shape[1] == 1 else mel class LinearMelSpectrogram(nn.Module): def __init__(self, sample_rate=44100, n_fft=2048, win_length=2048, hop_length=512, n_mels=1253, f_min=0, f_max=None): super().__init__() self.spectrogram = T.Spectrogram( n_fft=n_fft, win_length=win_length, hop_length=hop_length, center=True, power=1.0, ) self.mel_scale = T.MelScale( n_mels=n_mels, sample_rate=sample_rate, n_stft=n_fft // 2 + 1, f_min=f_min, f_max=f_max, ) def forward(self, audio): spec = self.spectrogram(audio) mel = self.mel_scale(spec) log_spec = torch.log(torch.clamp(spec, min=1e-5)) log_mel = torch.log(torch.clamp(mel, min=1e-5)) return torch.cat([log_spec, log_mel], dim=1)