Upload PackedTTS.py
Browse files- PackedTTS.py +497 -0
PackedTTS.py
ADDED
|
@@ -0,0 +1,497 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import copy
|
| 5 |
+
import random
|
| 6 |
+
import tempfile
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Any, Dict, Optional, Tuple
|
| 10 |
+
|
| 11 |
+
import librosa
|
| 12 |
+
import numpy as np
|
| 13 |
+
import soundfile as sf
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
|
| 17 |
+
from chichat.chatterbox.models.s3tokenizer import S3_SR, drop_invalid_tokens
|
| 18 |
+
from chichat.chatterbox.models.s3gen import S3GEN_SR, S3Gen
|
| 19 |
+
from chichat.chatterbox.models.t3 import T3
|
| 20 |
+
from chichat.chatterbox.models.t3.modules.cond_enc import T3Cond
|
| 21 |
+
from chichat.chatterbox.models.tokenizers import EnTokenizer
|
| 22 |
+
from chichat.chatterbox.models.voice_encoder import VoiceEncoder
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# ----------------------------------------------------------------------------
|
| 26 |
+
# CONFIG
|
| 27 |
+
# ----------------------------------------------------------------------------
|
| 28 |
+
# Default on-disk locations for the packed bundle and the generated audio.
DEFAULT_BUNDLE_PATH = Path("tts.pt")
DEFAULT_OUTPUT_PATH = Path("output.wav")
# Run on CUDA whenever torch reports it as available, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Reference clips are cut to at most this many seconds when they are loaded.
MAX_REF_SECONDS = 10.0
# NOTE(review): these two assignments rebind the S3GEN_SR / S3_SR names that
# the imports from chichat.chatterbox already provide; confirm the hard-coded
# sample rates agree with the library values.
S3GEN_SR = 24000
S3_SR = 16000
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# ----------------------------------------------------------------------------
|
| 37 |
+
# UTILITIES
|
| 38 |
+
# ----------------------------------------------------------------------------
|
| 39 |
+
def set_seed(seed: int):
    """Seed torch (CPU and CUDA), ``random`` and NumPy from one integer.

    A seed of ``None`` or ``0`` means "do not seed": every generator is left
    exactly as it was.
    """
    if seed is None:
        return
    value = int(seed)
    if value == 0:
        return

    torch.manual_seed(value)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(value)
        torch.cuda.manual_seed_all(value)
    random.seed(value)
    np.random.seed(value)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def clone_tensor(x: Optional[torch.Tensor], device=None) -> Optional[torch.Tensor]:
    """Return a detached copy of *x*, optionally moved to *device*.

    ``None`` and any non-tensor value are handed back unchanged.
    """
    if x is None or not torch.is_tensor(x):
        return x
    duplicate = x.detach().clone()
    return duplicate if device is None else duplicate.to(device)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def clone_ref_dict(ref_dict: Dict[str, Any], device=None) -> Dict[str, Any]:
    """Build an independent copy of *ref_dict*.

    Tensor values are detached, cloned and (when *device* is given) moved to
    that device; every other value is deep-copied.
    """

    def _duplicate(value):
        if not torch.is_tensor(value):
            return copy.deepcopy(value)
        cloned = value.detach().clone()
        return cloned if device is None else cloned.to(device)

    return {key: _duplicate(value) for key, value in ref_dict.items()}
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def normalize_name(name: str) -> str:
    """Reduce *name* to a lookup key of lowercase ASCII letters and digits.

    The name is stripped and lowercased, then every run of characters outside
    ``a-z`` / ``0-9`` is removed.
    """
    import re

    lowered = name.strip().lower()
    return re.sub(r"[^a-z0-9]+", "", lowered)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
# ----------------------------------------------------------------------------
|
| 82 |
+
# CONDITIONALS
|
| 83 |
+
# ----------------------------------------------------------------------------
|
| 84 |
+
@dataclass
class Conditionals:
    """Pair of conditioning inputs used for one generation call."""

    # Conditioning passed to T3.inference as ``t3_cond``.
    t3: T3Cond
    # Reference dictionary passed to S3Gen.inference as ``ref_dict``.
    gen: dict

    def to(self, device):
        """Move both conditioning parts to *device* and return ``self``.

        ``self.t3`` is replaced by the result of its own ``to(device)``; its
        tensor fields and every tensor value in ``self.gen`` are then replaced
        by detached clones on *device*.
        """
        self.t3 = self.t3.to(device)
        self.t3.speaker_emb = clone_tensor(self.t3.speaker_emb, device)

        # Optional fields are only re-cloned when present and not None.
        for field_name in ("cond_prompt_speech_tokens", "emotion_adv"):
            current = getattr(self.t3, field_name, None)
            if current is not None:
                setattr(self.t3, field_name, clone_tensor(current, device))

        moved = {
            key: clone_tensor(value, device)
            for key, value in self.gen.items()
            if torch.is_tensor(value)
        }
        self.gen.update(moved)
        return self
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
# ----------------------------------------------------------------------------
|
| 106 |
+
# PACKED TTS
|
| 107 |
+
# ----------------------------------------------------------------------------
|
| 108 |
+
class PackedTTS:
    """Text-to-speech pipeline restored from one packed bundle dictionary.

    The bundle is read through these top-level keys: ``models`` (state dicts
    ``t3_state`` / ``s3gen_state`` / ``ve_state`` plus ``tokenizer_json``),
    ``voices``, ``emotions``, ``defaults`` and ``indexes``.

    The instance owns a temporary directory holding the tokenizer file. Call
    :meth:`close` when done, or use the object as a context manager.
    """

    def __init__(self, bundle: Dict[str, Any], device: str = DEVICE):
        """Restore all models from *bundle* onto *device*.

        Raises:
            ValueError: if the bundle lacks model weights or the tokenizer JSON.
        """
        self.bundle = bundle
        self.device = device
        self.t3: Optional[T3] = None
        self.s3gen: Optional[S3Gen] = None
        self.ve: Optional[VoiceEncoder] = None
        self.tokenizer: Optional[EnTokenizer] = None
        self.conds: Optional[Conditionals] = None

        # EnTokenizer is constructed from a file path, so the packed tokenizer
        # JSON is written into a temporary directory owned by this instance.
        self._tmpdir = tempfile.TemporaryDirectory(prefix="packed_tts_tokenizer_")
        try:
            self._load_models_from_bundle()
        except Exception:
            # Do not leave the temporary directory behind on a failed restore.
            self.close()
            raise

    @classmethod
    def load(cls, bundle_path: Path, device: str = DEVICE) -> "PackedTTS":
        """Read a bundle file from *bundle_path* and build a ``PackedTTS``.

        Missing top-level sections are filled with empty dicts.

        Raises:
            ValueError: if the file does not contain a dictionary.
        """
        # SECURITY NOTE: torch.load unpickles the file; only load bundles that
        # come from a trusted source.
        bundle = torch.load(bundle_path, map_location="cpu")
        if not isinstance(bundle, dict):
            raise ValueError("Packed bundle did not contain a dictionary.")
        for section in ("voices", "emotions", "models", "defaults", "indexes"):
            bundle.setdefault(section, {})
        return cls(bundle=bundle, device=device)

    def close(self):
        """Remove the temporary tokenizer directory (best effort, repeatable)."""
        # getattr guards the case where __init__ never got as far as creating it.
        tmpdir = getattr(self, "_tmpdir", None)
        if tmpdir is None:
            return
        try:
            tmpdir.cleanup()
        except Exception:
            # Deliberately best-effort: close() is also reached from __del__,
            # where raising would only produce noise.
            pass

    def __enter__(self) -> "PackedTTS":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    def __del__(self):
        self.close()

    # ------------------------------------------------------------------
    # Model restore
    # ------------------------------------------------------------------
    def _load_models_from_bundle(self):
        """Instantiate T3, S3Gen, VoiceEncoder and the tokenizer from the bundle.

        Raises:
            ValueError: if ``models`` is empty, a state dict is missing, or
                ``tokenizer_json`` is absent.
        """
        models = self.bundle.get("models", {})
        if not models:
            raise ValueError("Bundle is missing packed model weights.")

        # Validate everything before building any (potentially large) model.
        missing = [key for key in ("t3_state", "s3gen_state", "ve_state") if key not in models]
        if missing:
            raise ValueError(f"Bundle is missing packed model weights: {', '.join(missing)}.")
        tokenizer_json = models.get("tokenizer_json")
        if not tokenizer_json:
            raise ValueError("Bundle is missing tokenizer_json.")

        t3 = T3()
        t3.load_state_dict(models["t3_state"])
        t3.to(self.device).eval()
        self.t3 = t3

        s3gen = S3Gen()
        # strict=False: keys that do not line up with the module are tolerated.
        s3gen.load_state_dict(models["s3gen_state"], strict=False)
        s3gen.to(self.device).eval()
        self.s3gen = s3gen

        ve = VoiceEncoder()
        ve.load_state_dict(models["ve_state"])
        ve.to(self.device).eval()
        self.ve = ve

        tok_path = Path(self._tmpdir.name) / "tokenizer.json"
        tok_path.write_text(tokenizer_json, encoding="utf-8")
        self.tokenizer = EnTokenizer(str(tok_path))

    # ------------------------------------------------------------------
    # Audio extraction helpers
    # ------------------------------------------------------------------
    def _load_reference_audio(self, ref_audio_path: str):
        """Load *ref_audio_path* as mono audio at ``S3GEN_SR``.

        At most ``MAX_REF_SECONDS`` seconds are kept.
        """
        wav, _ = librosa.load(
            ref_audio_path,
            sr=S3GEN_SR,
            mono=True,
            duration=MAX_REF_SECONDS,
        )
        # Safety net in case the loader returns more samples than requested.
        max_len = int(MAX_REF_SECONDS * S3GEN_SR)
        if len(wav) > max_len:
            wav = wav[:max_len]
        return wav

    def extract_conditionals_from_audio(self, ref_audio_path: str, exaggeration: float = 0.5) -> Dict[str, Any]:
        """Compute generation conditionals from a reference audio file.

        Returns a dict with keys ``speaker_emb``, ``cond_prompt_speech_tokens``
        (``None`` when the T3 prompt length is falsy), ``emotion_adv`` and
        ``gen``. ``emotion_adv`` is a ``(1, 1, 1)`` tensor filled with
        *exaggeration*; it is not derived from the audio.
        """
        wav = self._load_reference_audio(ref_audio_path)

        with torch.inference_mode():
            ref_dict_raw = self.s3gen.embed_ref(wav, S3GEN_SR, device=self.device)

            wav16k = librosa.resample(wav, orig_sr=S3GEN_SR, target_sr=S3_SR)
            wav16k = np.asarray(wav16k, dtype=np.float32)

            embed = self.ve.embeds_from_wavs([wav16k], sample_rate=S3_SR)
            # The encoder result is accepted either as a tensor or as array-like data.
            if isinstance(embed, torch.Tensor):
                speaker_emb = clone_tensor(embed.mean(dim=0, keepdim=True), self.device)
            else:
                speaker_emb = torch.from_numpy(np.asarray(embed)).mean(dim=0, keepdim=True).to(self.device)

            plen = self.t3.hp.speech_cond_prompt_len
            tok = None
            if plen:
                tokens, _ = self.s3gen.tokenizer.forward([wav16k], max_len=plen)
                tok = torch.atleast_2d(tokens).clone().to(self.device)

            ref_dict = clone_ref_dict(ref_dict_raw, device=self.device)
            emotion_adv = torch.full((1, 1, 1), float(exaggeration), device=self.device)

        return {
            "speaker_emb": speaker_emb,
            "cond_prompt_speech_tokens": tok,
            "emotion_adv": emotion_adv,
            "gen": ref_dict,
        }

    # ------------------------------------------------------------------
    # Resolution helpers
    # ------------------------------------------------------------------
    def list_voices(self):
        """Return the names of all packed voices."""
        return list(self.bundle.get("voices", {}).keys())

    def list_emotions(self):
        """Return a mapping of emotion name to its number of variations."""
        return {k: len(v.get("variations", [])) for k, v in self.bundle.get("emotions", {}).items()}

    def _match_name(self, requested: str, table: Dict[str, Any], index_key: str) -> Optional[str]:
        """Map a user-supplied name onto a key of *table*, or return ``None``.

        Tried in order: the packed normalized index ``indexes[index_key]``, a
        direct comparison of normalized keys (covers bundles packed without an
        index), then a ``difflib`` fuzzy match against the raw keys.
        """
        norm = normalize_name(requested)
        idx = self.bundle.get("indexes", {}).get(index_key, {})
        if norm in idx and idx[norm] in table:
            return idx[norm]

        if norm:
            for key in table:
                if isinstance(key, str) and normalize_name(key) == norm:
                    return key

        from difflib import get_close_matches

        cutoff = self.bundle.get("defaults", {}).get("fuzzy_cutoff", 0.72)
        matches = get_close_matches(requested, list(table.keys()), n=1, cutoff=cutoff)
        return matches[0] if matches else None

    def resolve_voice(self, requested: Optional[str]) -> Tuple[str, Dict[str, Any]]:
        """Pick a packed voice and return ``(name, entry)``.

        With no request, the bundle's ``default_voice`` is used when valid.
        When nothing matches, a voice is chosen at random.

        Raises:
            ValueError: if the bundle contains no voices.
        """
        voices = self.bundle.get("voices", {})
        if not voices:
            raise ValueError("No voices are packed in this bundle.")

        name: Optional[str] = None
        if requested:
            name = self._match_name(requested, voices, "voice_norm")
        else:
            default_voice = self.bundle.get("defaults", {}).get("default_voice")
            if default_voice and default_voice in voices:
                name = default_voice

        if name is None:
            name = random.choice(list(voices.keys()))
        return name, voices[name]

    def resolve_emotion(self, requested: Optional[str]) -> Tuple[str, Dict[str, Any]]:
        """Pick a packed emotion and return ``(name, variation)``.

        The emotion is selected like a voice in :meth:`resolve_voice`; one of
        its ``variations`` is then chosen at random.

        Raises:
            ValueError: if the bundle has no emotions, or the selected emotion
                has no variations.
        """
        emotions = self.bundle.get("emotions", {})
        if not emotions:
            raise ValueError("No emotions are packed in this bundle.")

        emotion_name: Optional[str] = None
        if requested:
            emotion_name = self._match_name(requested, emotions, "emotion_norm")
        else:
            default_emotion = self.bundle.get("defaults", {}).get("default_emotion")
            if default_emotion and default_emotion in emotions:
                emotion_name = default_emotion

        if emotion_name is None:
            emotion_name = random.choice(list(emotions.keys()))

        variations = emotions[emotion_name].get("variations", [])
        if not variations:
            raise ValueError(f"Emotion '{emotion_name}' has no variations.")
        return emotion_name, random.choice(variations)

    # ------------------------------------------------------------------
    # Voice/emotion selection logic
    # ------------------------------------------------------------------
    def _resolve_voice_source(
        self,
        voice: Optional[str],
        voice_ref: Optional[str],
        exaggeration: float,
    ) -> Tuple[str, Dict[str, Any], Dict[str, Any]]:
        """Return (voice_name, voice_entry_or_extracted, extracted_conditionals_if_any).

        A *voice_ref* audio path takes precedence over a packed *voice* name.

        Raises:
            ValueError: if the packed voice is not marked ``complete`` or has
                no ``speaker_emb``.
        """
        if voice_ref:
            extracted = self.extract_conditionals_from_audio(voice_ref, exaggeration=exaggeration)
            return voice_ref, {"complete": True, **extracted}, extracted

        voice_name, entry = self.resolve_voice(voice)
        if entry.get("complete") and entry.get("speaker_emb") is not None:
            return voice_name, entry, entry

        raise ValueError(
            f"Voice '{voice_name}' does not have packed generation conditionals. Provide voice_ref or repack the voice with a sample.wav."
        )

    def _resolve_emotion_source(
        self,
        emotion: Optional[str],
        emo_ref: Optional[str],
        voice_source_entry: Dict[str, Any],
        voice_extracted: Dict[str, Any],
        exaggeration: float,
    ) -> Tuple[str, Dict[str, Any]]:
        """Return ``(label, source)`` where ``source`` may carry ``emotion_adv``.

        Precedence: *emo_ref* audio, packed *emotion* name, the voice entry's
        own ``emotion_adv``, the extracted voice's ``emotion_adv``, and finally
        a tensor filled with *exaggeration*.
        """
        if emo_ref:
            # NOTE(review): generate() only reads "emotion_adv" from this dict,
            # and extraction fills that field from `exaggeration` alone.
            extracted = self.extract_conditionals_from_audio(emo_ref, exaggeration=exaggeration)
            return emo_ref, extracted

        if emotion:
            emotion_name, variation = self.resolve_emotion(emotion)
            return emotion_name, variation

        # No explicit emotion: prefer the voice's stored emotion if available.
        if voice_source_entry.get("emotion_adv") is not None:
            return "voice_default", {"emotion_adv": clone_tensor(voice_source_entry["emotion_adv"], self.device)}

        # If the voice came from a ref audio, reuse its extracted emotion.
        if voice_extracted.get("emotion_adv") is not None:
            return "voice_ref", {"emotion_adv": clone_tensor(voice_extracted["emotion_adv"], self.device)}

        # Final fallback.
        return "fallback", {"emotion_adv": torch.full((1, 1, 1), float(exaggeration), device=self.device)}

    # ------------------------------------------------------------------
    # Inference helpers
    # ------------------------------------------------------------------
    def _require_conds(self) -> "Conditionals":
        """Return the prepared conditionals or raise ``RuntimeError``."""
        if self.conds is None:
            # Explicit raise instead of assert so the check survives `python -O`.
            raise RuntimeError("Conditionals not prepared.")
        return self.conds

    def infer_t3(self, text: str, cfg_weight: float, temperature: float):
        """Run T3 on *text* and return the speech tokens kept by ``drop_invalid_tokens``.

        Raises:
            RuntimeError: if ``self.conds`` has not been set.
        """
        conds = self._require_conds()
        text = text.strip()
        sot, eot = self.t3.hp.start_text_token, self.t3.hp.stop_text_token
        tokens = self.tokenizer.text_to_tokens(text).to(self.device)

        if cfg_weight > 0:
            # Duplicate the row along dim 0 when CFG is active.
            # NOTE(review): presumably the conditional/unconditional pair that
            # T3.inference expects - confirm against T3.
            tokens = torch.cat([tokens, tokens], dim=0)

        # Wrap the sequence in start-of-text / end-of-text markers.
        tokens = F.pad(tokens, (1, 0), value=sot)
        tokens = F.pad(tokens, (0, 1), value=eot)

        with torch.inference_mode():
            out = self.t3.inference(
                t3_cond=conds.t3,
                text_tokens=tokens,
                max_new_tokens=1000,
                temperature=temperature,
                cfg_weight=cfg_weight,
            )

        return drop_invalid_tokens(out[0]).to(self.device)

    def infer_s3gen(self, speech_tokens: torch.Tensor):
        """Run S3Gen on *speech_tokens* and return the waveform as a NumPy array.

        Raises:
            RuntimeError: if ``self.conds`` has not been set.
        """
        conds = self._require_conds()
        with torch.inference_mode():
            wav, _ = self.s3gen.inference(
                speech_tokens=speech_tokens,
                ref_dict=conds.gen,
            )
        return wav.squeeze(0).detach().cpu().numpy()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def generate(
        self,
        text: str,
        voice: Optional[str] = None,
        emotion: Optional[str] = None,
        voice_ref: Optional[str] = None,
        emo_ref: Optional[str] = None,
        cfg_weight: float = 0.5,
        temperature: float = 0.8,
        exaggeration: float = 0.5,
        seed: int = 0,
    ):
        """Synthesize *text* and return ``(sample_rate, waveform, meta)``.

        Args:
            text: Text to speak.
            voice: Name of a packed voice; ignored when *voice_ref* is given.
            emotion: Name of a packed emotion; ignored when *emo_ref* is given.
            voice_ref: Path of a reference audio file to take the voice from.
            emo_ref: Path of a reference audio file used as the emotion source.
            cfg_weight: Forwarded to ``T3.inference``; values above 0 duplicate
                the text-token row.
            temperature: Forwarded to ``T3.inference``.
            exaggeration: Fill value for ``emotion_adv`` wherever one is created.
            seed: Non-zero values seed all random generators first.

        Returns:
            ``(S3GEN_SR, wav, {"voice": ..., "emotion": ...})``.

        Raises:
            ValueError: if no usable voice or emotion can be resolved.
        """
        if seed:
            # Seed before resolution so random voice/emotion picks are repeatable.
            set_seed(seed)

        voice_name, voice_entry, voice_extracted = self._resolve_voice_source(voice, voice_ref, exaggeration)
        emotion_name, emotion_source = self._resolve_emotion_source(
            emotion=emotion,
            emo_ref=emo_ref,
            voice_source_entry=voice_entry,
            voice_extracted=voice_extracted,
            exaggeration=exaggeration,
        )

        # Each field comes from the voice entry first, then the extracted dict.
        speaker_emb = voice_entry.get("speaker_emb")
        if speaker_emb is None:
            speaker_emb = voice_extracted.get("speaker_emb")
        speaker_emb = clone_tensor(speaker_emb, self.device)

        cond_prompt = voice_entry.get("cond_prompt_speech_tokens")
        if cond_prompt is None:
            cond_prompt = voice_extracted.get("cond_prompt_speech_tokens")
        cond_prompt = clone_tensor(cond_prompt, self.device)

        emotion_adv = emotion_source.get("emotion_adv")
        emotion_adv = clone_tensor(emotion_adv, self.device)

        gen = voice_entry.get("gen")
        if gen is None:
            gen = voice_extracted.get("gen")
        if gen is None:
            gen = {}
        # Clone so the tensors stored in the bundle are never shared with a run.
        gen = clone_ref_dict(gen, device=self.device)

        self.conds = Conditionals(
            t3=T3Cond(
                speaker_emb=speaker_emb,
                cond_prompt_speech_tokens=cond_prompt,
                emotion_adv=emotion_adv,
            ),
            gen=gen,
        )

        tokens = self.infer_t3(text, cfg_weight, temperature)
        wav = self.infer_s3gen(tokens)
        return S3GEN_SR, wav, {"voice": voice_name, "emotion": emotion_name}

    forward = generate
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
# ----------------------------------------------------------------------------
|
| 429 |
+
# CLI
|
| 430 |
+
# ----------------------------------------------------------------------------
|
| 431 |
+
def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the packed-TTS script."""
    parser = argparse.ArgumentParser(description="Use a packed TTS bundle to generate speech.")

    # (flag, value type, default) for every option that takes a value;
    # order matters because it is the order shown in --help.
    valued_options = (
        ("--bundle", Path, DEFAULT_BUNDLE_PATH),
        ("--text", str, "Hello world, this is a test."),
        ("--voice", str, None),
        ("--emotion", str, None),
        ("--voice-ref", Path, None),
        ("--emo-ref", Path, None),
        ("--cfg-weight", float, 0.5),
        ("--temperature", float, 0.8),
        ("--exaggeration", float, 0.5),
        ("--seed", int, 42),
        ("--output", Path, DEFAULT_OUTPUT_PATH),
    )
    for flag, value_type, default in valued_options:
        parser.add_argument(flag, type=value_type, default=default)

    parser.add_argument("--list", action="store_true", help="List packed voices and emotions, then exit")
    return parser
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
def main() -> None:
    """Command-line entry point.

    Loads the bundle named by ``--bundle``. With ``--list`` it prints the
    packed voices and emotions and exits. Otherwise it synthesizes ``--text``,
    writes the audio to ``--output`` and prints the resolved voice/emotion.
    """
    args = build_parser().parse_args()
    tts = PackedTTS.load(args.bundle, device=DEVICE)

    # try/finally so the temporary tokenizer directory is removed on every
    # exit path instead of relying on garbage collection.
    try:
        if args.list:
            print("Voices:")
            for name in tts.list_voices():
                print(f" - {name}")
            print("\nEmotions:")
            for name, count in tts.list_emotions().items():
                print(f" - {name} ({count} variations)")
            return

        # generate() takes the reference paths as plain strings.
        voice_ref = str(args.voice_ref) if args.voice_ref else None
        emo_ref = str(args.emo_ref) if args.emo_ref else None
        sr, audio, meta = tts.generate(
            text=args.text,
            voice=args.voice,
            emotion=args.emotion,
            voice_ref=voice_ref,
            emo_ref=emo_ref,
            cfg_weight=args.cfg_weight,
            temperature=args.temperature,
            exaggeration=args.exaggeration,
            seed=args.seed,
        )
        sf.write(str(args.output), audio, sr)
        print(f"Saved {args.output}")
        print(f"Resolved voice={meta['voice']} emotion={meta['emotion']}")
    finally:
        tts.close()
|
| 477 |
+
|
| 478 |
+
|
| 479 |
+
if __name__ == "__main__":
    # Run the argparse-driven entry point so the command-line options
    # (--bundle, --text, --voice, --emotion, --output, --list, ...) are honoured.
    main()
|