HiMind commited on
Commit
51166c6
·
verified ·
1 Parent(s): 314d289

Upload PackedTTS.py

Browse files
Files changed (1) hide show
  1. PackedTTS.py +497 -0
PackedTTS.py ADDED
@@ -0,0 +1,497 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import copy
5
+ import random
6
+ import tempfile
7
+ from dataclasses import dataclass
8
+ from pathlib import Path
9
+ from typing import Any, Dict, Optional, Tuple
10
+
11
+ import librosa
12
+ import numpy as np
13
+ import soundfile as sf
14
+ import torch
15
+ import torch.nn.functional as F
16
+
17
+ from chichat.chatterbox.models.s3tokenizer import S3_SR, drop_invalid_tokens
18
+ from chichat.chatterbox.models.s3gen import S3GEN_SR, S3Gen
19
+ from chichat.chatterbox.models.t3 import T3
20
+ from chichat.chatterbox.models.t3.modules.cond_enc import T3Cond
21
+ from chichat.chatterbox.models.tokenizers import EnTokenizer
22
+ from chichat.chatterbox.models.voice_encoder import VoiceEncoder
23
+
24
+
25
+ # ----------------------------------------------------------------------------
26
+ # CONFIG
27
+ # ----------------------------------------------------------------------------
28
+ DEFAULT_BUNDLE_PATH = Path("tts.pt")
29
+ DEFAULT_OUTPUT_PATH = Path("output.wav")
30
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
31
+ MAX_REF_SECONDS = 10.0
32
+ S3GEN_SR = 24000
33
+ S3_SR = 16000
34
+
35
+
36
+ # ----------------------------------------------------------------------------
37
+ # UTILITIES
38
+ # ----------------------------------------------------------------------------
39
+ def set_seed(seed: int):
40
+ if seed is None or int(seed) == 0:
41
+ return
42
+ seed = int(seed)
43
+ torch.manual_seed(seed)
44
+ if torch.cuda.is_available():
45
+ torch.cuda.manual_seed(seed)
46
+ torch.cuda.manual_seed_all(seed)
47
+ random.seed(seed)
48
+ np.random.seed(seed)
49
+
50
+
51
+ def clone_tensor(x: Optional[torch.Tensor], device=None) -> Optional[torch.Tensor]:
52
+ if x is None:
53
+ return None
54
+ if not torch.is_tensor(x):
55
+ return x
56
+ out = x.detach().clone()
57
+ if device is not None:
58
+ out = out.to(device)
59
+ return out
60
+
61
+
62
+ def clone_ref_dict(ref_dict: Dict[str, Any], device=None) -> Dict[str, Any]:
63
+ out: Dict[str, Any] = {}
64
+ for k, v in ref_dict.items():
65
+ if torch.is_tensor(v):
66
+ t = v.detach().clone()
67
+ if device is not None:
68
+ t = t.to(device)
69
+ out[k] = t
70
+ else:
71
+ out[k] = copy.deepcopy(v)
72
+ return out
73
+
74
+
75
+ def normalize_name(name: str) -> str:
76
+ import re
77
+
78
+ return re.sub(r"[^a-z0-9]+", "", name.strip().lower())
79
+
80
+
81
+ # ----------------------------------------------------------------------------
82
+ # CONDITIONALS
83
+ # ----------------------------------------------------------------------------
84
+ @dataclass
85
+ class Conditionals:
86
+ t3: T3Cond
87
+ gen: dict
88
+
89
+ def to(self, device):
90
+ self.t3 = self.t3.to(device)
91
+ self.t3.speaker_emb = clone_tensor(self.t3.speaker_emb, device)
92
+
93
+ if getattr(self.t3, "cond_prompt_speech_tokens", None) is not None:
94
+ self.t3.cond_prompt_speech_tokens = clone_tensor(self.t3.cond_prompt_speech_tokens, device)
95
+
96
+ if getattr(self.t3, "emotion_adv", None) is not None:
97
+ self.t3.emotion_adv = clone_tensor(self.t3.emotion_adv, device)
98
+
99
+ for k, v in self.gen.items():
100
+ if torch.is_tensor(v):
101
+ self.gen[k] = clone_tensor(v, device)
102
+ return self
103
+
104
+
105
+ # ----------------------------------------------------------------------------
106
+ # PACKED TTS
107
+ # ----------------------------------------------------------------------------
108
+ class PackedTTS:
109
+ def __init__(self, bundle: Dict[str, Any], device: str = DEVICE):
110
+ self.bundle = bundle
111
+ self.device = device
112
+ self.t3: Optional[T3] = None
113
+ self.s3gen: Optional[S3Gen] = None
114
+ self.ve: Optional[VoiceEncoder] = None
115
+ self.tokenizer: Optional[EnTokenizer] = None
116
+ self.conds: Optional[Conditionals] = None
117
+
118
+ self._tmpdir = tempfile.TemporaryDirectory(prefix="packed_tts_tokenizer_")
119
+ self._load_models_from_bundle()
120
+
121
+ @classmethod
122
+ def load(cls, bundle_path: Path, device: str = DEVICE) -> "PackedTTS":
123
+ bundle = torch.load(bundle_path, map_location="cpu")
124
+ if not isinstance(bundle, dict):
125
+ raise ValueError("Packed bundle did not contain a dictionary.")
126
+ bundle.setdefault("voices", {})
127
+ bundle.setdefault("emotions", {})
128
+ bundle.setdefault("models", {})
129
+ bundle.setdefault("defaults", {})
130
+ bundle.setdefault("indexes", {})
131
+ return cls(bundle=bundle, device=device)
132
+
133
+ def close(self):
134
+ try:
135
+ self._tmpdir.cleanup()
136
+ except Exception:
137
+ pass
138
+
139
+ def __del__(self):
140
+ self.close()
141
+
142
+ # ------------------------------------------------------------------
143
+ # Model restore
144
+ # ------------------------------------------------------------------
145
+ def _load_models_from_bundle(self):
146
+ models = self.bundle.get("models", {})
147
+ if not models:
148
+ raise ValueError("Bundle is missing packed model weights.")
149
+
150
+ t3 = T3()
151
+ t3.load_state_dict(models["t3_state"])
152
+ t3.to(self.device).eval()
153
+ self.t3 = t3
154
+
155
+ s3gen = S3Gen()
156
+ s3gen.load_state_dict(models["s3gen_state"], strict=False)
157
+ s3gen.to(self.device).eval()
158
+ self.s3gen = s3gen
159
+
160
+ ve = VoiceEncoder()
161
+ ve.load_state_dict(models["ve_state"])
162
+ ve.to(self.device).eval()
163
+ self.ve = ve
164
+
165
+ tokenizer_json = models.get("tokenizer_json")
166
+ if not tokenizer_json:
167
+ raise ValueError("Bundle is missing tokenizer_json.")
168
+ tok_path = Path(self._tmpdir.name) / "tokenizer.json"
169
+ tok_path.write_text(tokenizer_json, encoding="utf-8")
170
+ self.tokenizer = EnTokenizer(str(tok_path))
171
+
172
+ # ------------------------------------------------------------------
173
+ # Audio extraction helpers
174
+ # ------------------------------------------------------------------
175
+ def _load_reference_audio(self, ref_audio_path: str):
176
+ wav, _ = librosa.load(
177
+ ref_audio_path,
178
+ sr=S3GEN_SR,
179
+ mono=True,
180
+ duration=MAX_REF_SECONDS,
181
+ )
182
+ max_len = int(MAX_REF_SECONDS * S3GEN_SR)
183
+ if len(wav) > max_len:
184
+ wav = wav[:max_len]
185
+ return wav
186
+
187
+ def extract_conditionals_from_audio(self, ref_audio_path: str, exaggeration: float = 0.5) -> Dict[str, Any]:
188
+ wav = self._load_reference_audio(ref_audio_path)
189
+
190
+ with torch.inference_mode():
191
+ ref_dict_raw = self.s3gen.embed_ref(wav, S3GEN_SR, device=self.device)
192
+
193
+ wav16k = librosa.resample(wav, orig_sr=S3GEN_SR, target_sr=S3_SR)
194
+ wav16k = np.asarray(wav16k, dtype=np.float32)
195
+
196
+ embed = self.ve.embeds_from_wavs([wav16k], sample_rate=S3_SR)
197
+ if isinstance(embed, torch.Tensor):
198
+ speaker_emb = clone_tensor(embed.mean(dim=0, keepdim=True), self.device)
199
+ else:
200
+ speaker_emb = torch.from_numpy(np.asarray(embed)).mean(dim=0, keepdim=True).to(self.device)
201
+
202
+ plen = self.t3.hp.speech_cond_prompt_len
203
+ tok = None
204
+ if plen:
205
+ tokens, _ = self.s3gen.tokenizer.forward([wav16k], max_len=plen)
206
+ tok = torch.atleast_2d(tokens).clone().to(self.device)
207
+
208
+ ref_dict = clone_ref_dict(ref_dict_raw, device=self.device)
209
+ emotion_adv = torch.full((1, 1, 1), float(exaggeration), device=self.device)
210
+
211
+ return {
212
+ "speaker_emb": speaker_emb,
213
+ "cond_prompt_speech_tokens": tok,
214
+ "emotion_adv": emotion_adv,
215
+ "gen": ref_dict,
216
+ }
217
+
218
+ # ------------------------------------------------------------------
219
+ # Resolution helpers
220
+ # ------------------------------------------------------------------
221
+ def list_voices(self):
222
+ return list(self.bundle.get("voices", {}).keys())
223
+
224
+ def list_emotions(self):
225
+ return {k: len(v.get("variations", [])) for k, v in self.bundle.get("emotions", {}).items()}
226
+
227
+ def resolve_voice(self, requested: Optional[str]) -> Tuple[str, Dict[str, Any]]:
228
+ voices = self.bundle.get("voices", {})
229
+ if not voices:
230
+ raise ValueError("No voices are packed in this bundle.")
231
+
232
+ if not requested:
233
+ default_voice = self.bundle.get("defaults", {}).get("default_voice")
234
+ if default_voice and default_voice in voices:
235
+ return default_voice, voices[default_voice]
236
+ picked = random.choice(list(voices.keys()))
237
+ return picked, voices[picked]
238
+
239
+ norm = normalize_name(requested)
240
+ idx = self.bundle.get("indexes", {}).get("voice_norm", {})
241
+ if norm in idx and idx[norm] in voices:
242
+ name = idx[norm]
243
+ return name, voices[name]
244
+
245
+ from difflib import get_close_matches
246
+
247
+ matches = get_close_matches(requested, list(voices.keys()), n=1, cutoff=self.bundle.get("defaults", {}).get("fuzzy_cutoff", 0.72))
248
+ if matches:
249
+ name = matches[0]
250
+ return name, voices[name]
251
+
252
+ picked = random.choice(list(voices.keys()))
253
+ return picked, voices[picked]
254
+
255
+ def resolve_emotion(self, requested: Optional[str]) -> Tuple[str, Dict[str, Any]]:
256
+ emotions = self.bundle.get("emotions", {})
257
+ if not emotions:
258
+ raise ValueError("No emotions are packed in this bundle.")
259
+
260
+ if not requested:
261
+ default_emotion = self.bundle.get("defaults", {}).get("default_emotion")
262
+ if default_emotion and default_emotion in emotions:
263
+ emotion_name = default_emotion
264
+ else:
265
+ emotion_name = random.choice(list(emotions.keys()))
266
+ else:
267
+ norm = normalize_name(requested)
268
+ idx = self.bundle.get("indexes", {}).get("emotion_norm", {})
269
+ if norm in idx and idx[norm] in emotions:
270
+ emotion_name = idx[norm]
271
+ else:
272
+ from difflib import get_close_matches
273
+
274
+ matches = get_close_matches(requested, list(emotions.keys()), n=1, cutoff=self.bundle.get("defaults", {}).get("fuzzy_cutoff", 0.72))
275
+ emotion_name = matches[0] if matches else random.choice(list(emotions.keys()))
276
+
277
+ variations = emotions[emotion_name].get("variations", [])
278
+ if not variations:
279
+ raise ValueError(f"Emotion '{emotion_name}' has no variations.")
280
+ return emotion_name, random.choice(variations)
281
+
282
+ # ------------------------------------------------------------------
283
+ # Voice/emotion selection logic
284
+ # ------------------------------------------------------------------
285
+ def _resolve_voice_source(
286
+ self,
287
+ voice: Optional[str],
288
+ voice_ref: Optional[str],
289
+ exaggeration: float,
290
+ ) -> Tuple[str, Dict[str, Any], Dict[str, Any]]:
291
+ """Return (voice_name, voice_entry_or_extracted, extracted_conditionals_if_any)."""
292
+ if voice_ref:
293
+ extracted = self.extract_conditionals_from_audio(voice_ref, exaggeration=exaggeration)
294
+ return voice_ref, {"complete": True, **extracted}, extracted
295
+
296
+ voice_name, entry = self.resolve_voice(voice)
297
+ if entry.get("complete") and entry.get("speaker_emb") is not None:
298
+ return voice_name, entry, entry
299
+
300
+ raise ValueError(
301
+ f"Voice '{voice_name}' does not have packed generation conditionals. Provide voice_ref or repack the voice with a sample.wav."
302
+ )
303
+
304
+ def _resolve_emotion_source(
305
+ self,
306
+ emotion: Optional[str],
307
+ emo_ref: Optional[str],
308
+ voice_source_entry: Dict[str, Any],
309
+ voice_extracted: Dict[str, Any],
310
+ exaggeration: float,
311
+ ) -> Tuple[str, Dict[str, Any]]:
312
+ if emo_ref:
313
+ extracted = self.extract_conditionals_from_audio(emo_ref, exaggeration=exaggeration)
314
+ return emo_ref, extracted
315
+
316
+ if emotion:
317
+ emotion_name, variation = self.resolve_emotion(emotion)
318
+ return emotion_name, variation
319
+
320
+ # No explicit emotion: prefer the voice's stored emotion if available.
321
+ if voice_source_entry.get("emotion_adv") is not None:
322
+ return "voice_default", {"emotion_adv": clone_tensor(voice_source_entry["emotion_adv"], self.device)}
323
+
324
+ # If the voice came from a ref audio, reuse its extracted emotion.
325
+ if voice_extracted.get("emotion_adv") is not None:
326
+ return "voice_ref", {"emotion_adv": clone_tensor(voice_extracted["emotion_adv"], self.device)}
327
+
328
+ # Final fallback.
329
+ return "fallback", {"emotion_adv": torch.full((1, 1, 1), float(exaggeration), device=self.device)}
330
+
331
+ # ------------------------------------------------------------------
332
+ # Inference helpers
333
+ # ------------------------------------------------------------------
334
+ def infer_t3(self, text: str, cfg_weight: float, temperature: float):
335
+ assert self.conds is not None, "Conditionals not prepared."
336
+ text = text.strip()
337
+ sot, eot = self.t3.hp.start_text_token, self.t3.hp.stop_text_token
338
+ tokens = self.tokenizer.text_to_tokens(text).to(self.device)
339
+
340
+ if cfg_weight > 0:
341
+ tokens = torch.cat([tokens, tokens], dim=0)
342
+
343
+ tokens = F.pad(tokens, (1, 0), value=sot)
344
+ tokens = F.pad(tokens, (0, 1), value=eot)
345
+
346
+ with torch.inference_mode():
347
+ out = self.t3.inference(
348
+ t3_cond=self.conds.t3,
349
+ text_tokens=tokens,
350
+ max_new_tokens=1000,
351
+ temperature=temperature,
352
+ cfg_weight=cfg_weight,
353
+ )
354
+
355
+ return drop_invalid_tokens(out[0]).to(self.device)
356
+
357
+ def infer_s3gen(self, speech_tokens: torch.Tensor):
358
+ with torch.inference_mode():
359
+ wav, _ = self.s3gen.inference(
360
+ speech_tokens=speech_tokens,
361
+ ref_dict=self.conds.gen,
362
+ )
363
+ return wav.squeeze(0).detach().cpu().numpy()
364
+
365
+ # ------------------------------------------------------------------
366
+ # Public API
367
+ # ------------------------------------------------------------------
368
+ def generate(
369
+ self,
370
+ text: str,
371
+ voice: Optional[str] = None,
372
+ emotion: Optional[str] = None,
373
+ voice_ref: Optional[str] = None,
374
+ emo_ref: Optional[str] = None,
375
+ cfg_weight: float = 0.5,
376
+ temperature: float = 0.8,
377
+ exaggeration: float = 0.5,
378
+ seed: int = 0,
379
+ ):
380
+ if seed:
381
+ set_seed(seed)
382
+
383
+ voice_name, voice_entry, voice_extracted = self._resolve_voice_source(voice, voice_ref, exaggeration)
384
+ emotion_name, emotion_source = self._resolve_emotion_source(
385
+ emotion=emotion,
386
+ emo_ref=emo_ref,
387
+ voice_source_entry=voice_entry,
388
+ voice_extracted=voice_extracted,
389
+ exaggeration=exaggeration,
390
+ )
391
+
392
+ speaker_emb = voice_entry.get("speaker_emb")
393
+ if speaker_emb is None:
394
+ speaker_emb = voice_extracted.get("speaker_emb")
395
+ speaker_emb = clone_tensor(speaker_emb, self.device)
396
+
397
+ cond_prompt = voice_entry.get("cond_prompt_speech_tokens")
398
+ if cond_prompt is None:
399
+ cond_prompt = voice_extracted.get("cond_prompt_speech_tokens")
400
+ cond_prompt = clone_tensor(cond_prompt, self.device)
401
+
402
+ emotion_adv = emotion_source.get("emotion_adv")
403
+ emotion_adv = clone_tensor(emotion_adv, self.device)
404
+
405
+ gen = voice_entry.get("gen")
406
+ if gen is None:
407
+ gen = voice_extracted.get("gen")
408
+ if gen is None:
409
+ gen = {}
410
+ gen = clone_ref_dict(gen, device=self.device)
411
+
412
+ self.conds = Conditionals(
413
+ t3=T3Cond(
414
+ speaker_emb=speaker_emb,
415
+ cond_prompt_speech_tokens=cond_prompt,
416
+ emotion_adv=emotion_adv,
417
+ ),
418
+ gen=gen,
419
+ )
420
+
421
+ tokens = self.infer_t3(text, cfg_weight, temperature)
422
+ wav = self.infer_s3gen(tokens)
423
+ return S3GEN_SR, wav, {"voice": voice_name, "emotion": emotion_name}
424
+
425
+ forward = generate
426
+
427
+
428
+ # ----------------------------------------------------------------------------
429
+ # CLI
430
+ # ----------------------------------------------------------------------------
431
+ def build_parser() -> argparse.ArgumentParser:
432
+ p = argparse.ArgumentParser(description="Use a packed TTS bundle to generate speech.")
433
+ p.add_argument("--bundle", type=Path, default=DEFAULT_BUNDLE_PATH)
434
+ p.add_argument("--text", type=str, default="Hello world, this is a test.")
435
+ p.add_argument("--voice", type=str, default=None)
436
+ p.add_argument("--emotion", type=str, default=None)
437
+ p.add_argument("--voice-ref", type=Path, default=None)
438
+ p.add_argument("--emo-ref", type=Path, default=None)
439
+ p.add_argument("--cfg-weight", type=float, default=0.5)
440
+ p.add_argument("--temperature", type=float, default=0.8)
441
+ p.add_argument("--exaggeration", type=float, default=0.5)
442
+ p.add_argument("--seed", type=int, default=42)
443
+ p.add_argument("--output", type=Path, default=DEFAULT_OUTPUT_PATH)
444
+ p.add_argument("--list", action="store_true", help="List packed voices and emotions, then exit")
445
+ return p
446
+
447
+
448
+ def main() -> None:
449
+ args = build_parser().parse_args()
450
+ tts = PackedTTS.load(args.bundle, device=DEVICE)
451
+
452
+ if args.list:
453
+ print("Voices:")
454
+ for name in tts.list_voices():
455
+ print(f" - {name}")
456
+ print("\nEmotions:")
457
+ for name, count in tts.list_emotions().items():
458
+ print(f" - {name} ({count} variations)")
459
+ return
460
+
461
+ voice_ref = str(args.voice_ref) if args.voice_ref else None
462
+ emo_ref = str(args.emo_ref) if args.emo_ref else None
463
+ sr, audio, meta = tts.generate(
464
+ text=args.text,
465
+ voice=args.voice,
466
+ emotion=args.emotion,
467
+ voice_ref=voice_ref,
468
+ emo_ref=emo_ref,
469
+ cfg_weight=args.cfg_weight,
470
+ temperature=args.temperature,
471
+ exaggeration=args.exaggeration,
472
+ seed=args.seed,
473
+ )
474
+ sf.write(str(args.output), audio, sr)
475
+ print(f"Saved {args.output}")
476
+ print(f"Resolved voice={meta['voice']} emotion={meta['emotion']}")
477
+
478
+
479
+ if __name__ == "__main__":
480
+ bundle_path = DEFAULT_BUNDLE_PATH
481
+ output_path = Path("sarah_happy_test.wav")
482
+
483
+ tts = PackedTTS.load(bundle_path, device=DEVICE)
484
+
485
+ sr, audio, meta = tts.generate(
486
+ text="Hi, this is Sarah speaking with a angry emotion.",
487
+ voice="Sarah",
488
+ emotion="Disgust",
489
+ cfg_weight=0.5,
490
+ temperature=0.8,
491
+ exaggeration=0.5,
492
+ seed=42,
493
+ )
494
+
495
+ sf.write(str(output_path), audio, sr)
496
+ print(f"Saved {output_path}")
497
+ print(f"Resolved voice={meta['voice']} emotion={meta['emotion']}")