| from __future__ import annotations
|
|
|
| import argparse
|
| import copy
|
| import random
|
| import tempfile
|
| from dataclasses import dataclass
|
| from pathlib import Path
|
| from typing import Any, Dict, Optional, Tuple
|
|
|
| import librosa
|
| import numpy as np
|
| import soundfile as sf
|
| import torch
|
| import torch.nn.functional as F
|
|
|
| from chichat.chatterbox.models.s3tokenizer import S3_SR, drop_invalid_tokens
|
| from chichat.chatterbox.models.s3gen import S3GEN_SR, S3Gen
|
| from chichat.chatterbox.models.t3 import T3
|
| from chichat.chatterbox.models.t3.modules.cond_enc import T3Cond
|
| from chichat.chatterbox.models.tokenizers import EnTokenizer
|
| from chichat.chatterbox.models.voice_encoder import VoiceEncoder
|
|
|
|
|
|
|
|
|
|
|
# Default location of the packed bundle read by PackedTTS.load / the CLI.
DEFAULT_BUNDLE_PATH = Path("tts.pt")
# Default WAV path written by the CLI.
DEFAULT_OUTPUT_PATH = Path("output.wav")
# Prefer the GPU whenever PyTorch can see one.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Reference clips are capped at this many seconds before conditioning.
MAX_REF_SECONDS = 10.0

# Sample rates (S3GEN_SR, S3_SR) come from the chatterbox s3gen / s3tokenizer
# modules imported above rather than being hard-coded here, so they always
# match what the models expect.
|
|
|
|
|
|
|
|
|
|
|
def set_seed(seed: Optional[int]) -> None:
    """Seed PyTorch (CPU and CUDA), ``random`` and NumPy for reproducible sampling.

    Args:
        seed: Seed value. ``None`` or ``0`` means "do not seed" and leaves every
            RNG state untouched.
    """
    if seed is None or int(seed) == 0:
        return
    seed = int(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # manual_seed_all covers every visible device, including the current one.
        torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    # NumPy's legacy seeder only accepts 0 <= seed < 2**32; fold other ints
    # (e.g. a negative --seed) into that range instead of raising.
    np.random.seed(seed % (2**32))
|
|
|
|
|
def clone_tensor(x: Optional[torch.Tensor], device=None) -> Optional[torch.Tensor]:
    """Return a detached, storage-independent copy of ``x``.

    When ``device`` is given the copy is moved there. ``None`` and any
    non-tensor value are handed back as-is without copying.
    """
    if x is None or not torch.is_tensor(x):
        return x
    copied = x.detach().clone()
    return copied if device is None else copied.to(device)
|
|
|
|
|
def clone_ref_dict(ref_dict: Dict[str, Any], device=None) -> Dict[str, Any]:
    """Return an independent copy of a reference-conditioning dict.

    Tensor values are detached, cloned and (if ``device`` is given) moved to
    ``device``; every other value is copied with ``copy.deepcopy``. Key order
    is preserved and the input dict is not modified.
    """

    def _copy_value(value: Any) -> Any:
        if not torch.is_tensor(value):
            return copy.deepcopy(value)
        cloned = value.detach().clone()
        return cloned.to(device) if device is not None else cloned

    return {key: _copy_value(value) for key, value in ref_dict.items()}
|
|
|
|
|
def normalize_name(name: str) -> str:
    """Canonicalize a voice/emotion name for lookup.

    Surrounding whitespace is stripped, the text is lowercased, and every
    character outside ``[a-z0-9]`` is removed, e.g. ``" Sad-Voice 2 "`` ->
    ``"sadvoice2"``.
    """
    import re

    disallowed = re.compile(r"[^a-z0-9]+")
    lowered = name.strip().lower()
    return disallowed.sub("", lowered)
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class Conditionals:
    """Conditioning inputs for a single generation call.

    Attributes:
        t3: Conditioning passed as ``t3_cond`` to ``T3.inference`` (speaker
            embedding, optional prompt speech tokens, ``emotion_adv``).
        gen: Reference dict passed as ``ref_dict`` to ``S3Gen.inference``.
    """

    t3: T3Cond
    gen: dict

    def to(self, device):
        """Move all conditioning tensors to ``device`` and return ``self``.

        Tensors are re-cloned so the result never shares storage with the
        caller's tensors. ``gen`` is rebound to a new dict instead of being
        edited in place, so a dict shared with another object (for example a
        packed voice entry) is left untouched.
        """
        # NOTE(review): upstream Chatterbox declares T3Cond.to(*, device=None,
        # dtype=None) as keyword-only; passing device= by keyword works with
        # both a keyword-only and a positional signature.
        self.t3 = self.t3.to(device=device)
        self.t3.speaker_emb = clone_tensor(self.t3.speaker_emb, device)

        for attr in ("cond_prompt_speech_tokens", "emotion_adv"):
            value = getattr(self.t3, attr, None)
            if value is not None:
                setattr(self.t3, attr, clone_tensor(value, device))

        self.gen = {
            key: clone_tensor(value, device) if torch.is_tensor(value) else value
            for key, value in self.gen.items()
        }
        return self
|
|
|
|
|
|
|
|
|
|
|
class PackedTTS:
    """Text-to-speech runtime that loads every model from one packed bundle.

    The bundle dict (see :meth:`load`) has these sections:

    - ``"models"``: model weights and the tokenizer JSON.
    - ``"voices"``: packed voice conditionals.
    - ``"emotions"``: emotion variations.
    - ``"defaults"``: default names and the fuzzy-match cutoff.
    - ``"indexes"``: normalized-name lookup tables.

    Not thread-safe: :meth:`generate` stores the per-call conditioning on
    ``self.conds`` before running inference. Use the instance as a context
    manager, or call :meth:`close`, to remove the temporary tokenizer
    directory promptly.
    """

    def __init__(self, bundle: Dict[str, Any], device: str = DEVICE):
        self.bundle = bundle
        self.device = device
        self.t3: Optional[T3] = None
        self.s3gen: Optional[S3Gen] = None
        self.ve: Optional[VoiceEncoder] = None
        self.tokenizer: Optional[EnTokenizer] = None
        self.conds: Optional[Conditionals] = None

        # EnTokenizer is built from a file path, so the bundled tokenizer
        # JSON is written into this private temporary directory.
        self._tmpdir = tempfile.TemporaryDirectory(prefix="packed_tts_tokenizer_")
        try:
            self._load_models_from_bundle()
        except BaseException:
            # Don't leave the temp dir around until GC when construction fails.
            self.close()
            raise

    @classmethod
    def load(cls, bundle_path: Path, device: str = DEVICE) -> "PackedTTS":
        """Read a packed bundle from ``bundle_path`` and build a runtime on ``device``.

        Missing top-level sections are filled in with empty dicts.

        Raises:
            ValueError: if the file does not contain a dict, or (from the
                constructor) if required model data is missing.
        """
        # SECURITY: torch.load is pickle-based unless weights_only applies;
        # only load bundles from trusted sources.
        bundle = torch.load(bundle_path, map_location="cpu")
        if not isinstance(bundle, dict):
            raise ValueError("Packed bundle did not contain a dictionary.")
        for section in ("voices", "emotions", "models", "defaults", "indexes"):
            bundle.setdefault(section, {})
        return cls(bundle=bundle, device=device)

    def close(self) -> None:
        """Remove the temporary tokenizer directory. Safe to call repeatedly."""
        tmpdir = getattr(self, "_tmpdir", None)
        if tmpdir is None:
            return
        try:
            tmpdir.cleanup()
        except Exception:
            # Deliberately best-effort: close() also runs from __del__,
            # possibly during interpreter shutdown, where raising only adds noise.
            pass

    def __enter__(self) -> "PackedTTS":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    def __del__(self):
        self.close()

    def _load_models_from_bundle(self) -> None:
        """Build T3, S3Gen, VoiceEncoder and EnTokenizer from ``bundle["models"]``.

        Raises:
            ValueError: if the weights section, a required state dict, or the
                tokenizer JSON is missing.
        """
        models = self.bundle.get("models", {})
        if not models:
            raise ValueError("Bundle is missing packed model weights.")

        # Validate everything up front, so a broken bundle fails before the
        # slow model construction and weight loading.
        missing = [key for key in ("t3_state", "s3gen_state", "ve_state") if key not in models]
        if missing:
            raise ValueError(f"Bundle is missing model state dict(s): {', '.join(missing)}.")
        tokenizer_json = models.get("tokenizer_json")
        if not tokenizer_json:
            raise ValueError("Bundle is missing tokenizer_json.")

        t3 = T3()
        t3.load_state_dict(models["t3_state"])
        t3.to(self.device).eval()
        self.t3 = t3

        s3gen = S3Gen()
        # NOTE(review): strict=False silently ignores missing/unexpected keys.
        # This is presumably intentional, but it means a truncated s3gen_state
        # is not detected here.
        s3gen.load_state_dict(models["s3gen_state"], strict=False)
        s3gen.to(self.device).eval()
        self.s3gen = s3gen

        ve = VoiceEncoder()
        ve.load_state_dict(models["ve_state"])
        ve.to(self.device).eval()
        self.ve = ve

        tok_path = Path(self._tmpdir.name) / "tokenizer.json"
        tok_path.write_text(tokenizer_json, encoding="utf-8")
        self.tokenizer = EnTokenizer(str(tok_path))

    def _load_reference_audio(self, ref_audio_path: str):
        """Load at most MAX_REF_SECONDS of mono audio, resampled to S3GEN_SR."""
        wav, _ = librosa.load(
            ref_audio_path,
            sr=S3GEN_SR,
            mono=True,
            duration=MAX_REF_SECONDS,
        )
        # `duration` already bounds the read; this is a defensive hard cap.
        max_len = int(MAX_REF_SECONDS * S3GEN_SR)
        if len(wav) > max_len:
            wav = wav[:max_len]
        return wav

    def extract_conditionals_from_audio(self, ref_audio_path: str, exaggeration: float = 0.5) -> Dict[str, Any]:
        """Compute voice conditionals from a reference audio file.

        Args:
            ref_audio_path: Audio file readable by librosa. Only the first
                MAX_REF_SECONDS seconds are used.
            exaggeration: Value stored in the returned ``emotion_adv`` tensor.

        Returns:
            A dict with these keys:

            - ``speaker_emb``: VoiceEncoder embedding, averaged over dim 0
              with keepdim.
            - ``cond_prompt_speech_tokens``: ``None`` when
              ``t3.hp.speech_cond_prompt_len`` is falsy.
            - ``emotion_adv``: tensor of shape (1, 1, 1) filled with
              ``exaggeration``.
            - ``gen``: cloned output of ``S3Gen.embed_ref``.

            All tensors live on ``self.device``.
        """
        wav = self._load_reference_audio(ref_audio_path)

        with torch.inference_mode():
            ref_dict_raw = self.s3gen.embed_ref(wav, S3GEN_SR, device=self.device)

        # The voice encoder and speech tokenizer consume 16 kHz audio.
        wav16k = librosa.resample(wav, orig_sr=S3GEN_SR, target_sr=S3_SR)
        wav16k = np.asarray(wav16k, dtype=np.float32)

        embed = self.ve.embeds_from_wavs([wav16k], sample_rate=S3_SR)
        if isinstance(embed, torch.Tensor):
            speaker_emb = clone_tensor(embed.mean(dim=0, keepdim=True), self.device)
        else:
            speaker_emb = torch.from_numpy(np.asarray(embed)).mean(dim=0, keepdim=True).to(self.device)

        plen = self.t3.hp.speech_cond_prompt_len
        tok = None
        if plen:
            tokens, _ = self.s3gen.tokenizer.forward([wav16k], max_len=plen)
            tok = torch.atleast_2d(tokens).clone().to(self.device)

        ref_dict = clone_ref_dict(ref_dict_raw, device=self.device)
        emotion_adv = torch.full((1, 1, 1), float(exaggeration), device=self.device)

        return {
            "speaker_emb": speaker_emb,
            "cond_prompt_speech_tokens": tok,
            "emotion_adv": emotion_adv,
            "gen": ref_dict,
        }

    def list_voices(self):
        """Return the names of all packed voices."""
        return list(self.bundle.get("voices", {}).keys())

    def list_emotions(self):
        """Return ``{emotion_name: number_of_variations}`` for all packed emotions."""
        return {k: len(v.get("variations", [])) for k, v in self.bundle.get("emotions", {}).items()}

    def _resolve_name(
        self,
        requested: Optional[str],
        table: Dict[str, Any],
        index_key: str,
        default_key: str,
        kind: str,
    ) -> str:
        """Pick a key of the non-empty ``table`` for a user-supplied name.

        An empty ``requested`` selects ``defaults[default_key]`` if it names an
        entry, and otherwise a random key.

        A non-empty ``requested`` is tried against these sources, in order:

        1. The bundle's normalized-name index ``indexes[index_key]``.
        2. The normalized keys of ``table`` itself, so bundles packed without
           an index still match case and punctuation variants.
        3. ``difflib`` fuzzy matching on the raw string, with cutoff
           ``defaults["fuzzy_cutoff"]`` (default 0.72).

        If none of these match, a random key is returned and a warning is
        emitted so typos do not go unnoticed.
        """
        import warnings
        from difflib import get_close_matches

        names = list(table.keys())
        defaults = self.bundle.get("defaults", {})

        if not requested:
            default_name = defaults.get(default_key)
            if default_name and default_name in table:
                return default_name
            return random.choice(names)

        norm = normalize_name(requested)
        idx = self.bundle.get("indexes", {}).get(index_key, {})
        if norm in idx and idx[norm] in table:
            return idx[norm]

        if norm:
            for name in names:
                if normalize_name(name) == norm:
                    return name

        matches = get_close_matches(requested, names, n=1, cutoff=defaults.get("fuzzy_cutoff", 0.72))
        if matches:
            return matches[0]

        picked = random.choice(names)
        warnings.warn(
            f"No {kind} matches {requested!r}; using randomly chosen {kind} {picked!r}.",
            stacklevel=3,
        )
        return picked

    def resolve_voice(self, requested: Optional[str]) -> Tuple[str, Dict[str, Any]]:
        """Return ``(voice_name, voice_entry)`` for ``requested``.

        Falls back to the bundle default, or to a random voice, as described
        for the name resolver.

        Raises:
            ValueError: if the bundle contains no voices.
        """
        voices = self.bundle.get("voices", {})
        if not voices:
            raise ValueError("No voices are packed in this bundle.")
        name = self._resolve_name(requested, voices, "voice_norm", "default_voice", "voice")
        return name, voices[name]

    def resolve_emotion(self, requested: Optional[str]) -> Tuple[str, Dict[str, Any]]:
        """Return ``(emotion_name, variation)``.

        ``variation`` is chosen at random from the emotion's ``variations``.

        Raises:
            ValueError: if the bundle has no emotions, or the chosen emotion
                has no variations.
        """
        emotions = self.bundle.get("emotions", {})
        if not emotions:
            raise ValueError("No emotions are packed in this bundle.")
        emotion_name = self._resolve_name(requested, emotions, "emotion_norm", "default_emotion", "emotion")

        variations = emotions[emotion_name].get("variations", [])
        if not variations:
            raise ValueError(f"Emotion '{emotion_name}' has no variations.")
        return emotion_name, random.choice(variations)

    def _resolve_voice_source(
        self,
        voice: Optional[str],
        voice_ref: Optional[str],
        exaggeration: float,
    ) -> Tuple[str, Dict[str, Any], Dict[str, Any]]:
        """Return (voice_name, voice_entry_or_extracted, extracted_conditionals_if_any).

        A ``voice_ref`` audio path takes priority over a packed ``voice`` name.

        Raises:
            ValueError: if the packed voice is not marked complete or lacks
                ``speaker_emb``.
        """
        if voice_ref:
            extracted = self.extract_conditionals_from_audio(voice_ref, exaggeration=exaggeration)
            return voice_ref, {"complete": True, **extracted}, extracted

        voice_name, entry = self.resolve_voice(voice)
        if entry.get("complete") and entry.get("speaker_emb") is not None:
            return voice_name, entry, entry

        raise ValueError(
            f"Voice '{voice_name}' does not have packed generation conditionals. Provide voice_ref or repack the voice with a sample.wav."
        )

    def _resolve_emotion_source(
        self,
        emotion: Optional[str],
        emo_ref: Optional[str],
        voice_source_entry: Dict[str, Any],
        voice_extracted: Dict[str, Any],
        exaggeration: float,
    ) -> Tuple[str, Dict[str, Any]]:
        """Return ``(emotion_label, source_dict)``.

        ``source_dict`` carries the ``emotion_adv`` to use. Sources are tried
        in this priority order:

        1. ``emo_ref`` audio.
        2. A packed ``emotion``.
        3. The voice entry's ``emotion_adv``.
        4. The extracted ``emotion_adv``.
        5. A constant tensor filled with ``exaggeration``.
        """
        if emo_ref:
            # NOTE(review): extract_conditionals_from_audio fills emotion_adv
            # with the constant `exaggeration`, and generate() only reads
            # emotion_adv from this dict. So the emo_ref audio content does not
            # influence the output (it is still loaded and embedded). Confirm
            # whether this is intended.
            extracted = self.extract_conditionals_from_audio(emo_ref, exaggeration=exaggeration)
            return emo_ref, extracted

        if emotion:
            emotion_name, variation = self.resolve_emotion(emotion)
            return emotion_name, variation

        if voice_source_entry.get("emotion_adv") is not None:
            return "voice_default", {"emotion_adv": clone_tensor(voice_source_entry["emotion_adv"], self.device)}

        if voice_extracted.get("emotion_adv") is not None:
            return "voice_ref", {"emotion_adv": clone_tensor(voice_extracted["emotion_adv"], self.device)}

        return "fallback", {"emotion_adv": torch.full((1, 1, 1), float(exaggeration), device=self.device)}

    def infer_t3(self, text: str, cfg_weight: float, temperature: float, max_new_tokens: int = 1000):
        """Run T3 on ``text`` using ``self.conds.t3`` and return the speech tokens.

        When ``cfg_weight > 0`` the text batch is duplicated for classifier-free
        guidance. Only row 0 of the output is kept, after ``drop_invalid_tokens``.

        Raises:
            RuntimeError: if conditionals have not been prepared.
            ValueError: if ``text`` is empty after stripping whitespace.
        """
        if self.conds is None:
            raise RuntimeError("Conditionals not prepared.")
        text = text.strip()
        if not text:
            raise ValueError("Text to synthesize is empty.")

        sot, eot = self.t3.hp.start_text_token, self.t3.hp.stop_text_token
        tokens = self.tokenizer.text_to_tokens(text).to(self.device)

        if cfg_weight > 0:
            tokens = torch.cat([tokens, tokens], dim=0)

        # Wrap the sequence in start/stop-of-text tokens.
        tokens = F.pad(tokens, (1, 0), value=sot)
        tokens = F.pad(tokens, (0, 1), value=eot)

        with torch.inference_mode():
            out = self.t3.inference(
                t3_cond=self.conds.t3,
                text_tokens=tokens,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                cfg_weight=cfg_weight,
            )

        # NOTE(review): upstream Chatterbox additionally drops tokens >= the S3
        # speech vocab size before S3Gen. Confirm that drop_invalid_tokens is
        # sufficient for this fork.
        return drop_invalid_tokens(out[0]).to(self.device)

    def infer_s3gen(self, speech_tokens: torch.Tensor):
        """Vocode ``speech_tokens`` with S3Gen using ``self.conds.gen``.

        Returns a NumPy waveform with the leading dim squeezed.

        Raises:
            RuntimeError: if conditionals have not been prepared.
        """
        if self.conds is None:
            raise RuntimeError("Conditionals not prepared.")
        with torch.inference_mode():
            wav, _ = self.s3gen.inference(
                speech_tokens=speech_tokens,
                ref_dict=self.conds.gen,
            )
        return wav.squeeze(0).detach().cpu().numpy()

    def generate(
        self,
        text: str,
        voice: Optional[str] = None,
        emotion: Optional[str] = None,
        voice_ref: Optional[str] = None,
        emo_ref: Optional[str] = None,
        cfg_weight: float = 0.5,
        temperature: float = 0.8,
        exaggeration: float = 0.5,
        seed: int = 0,
        max_new_tokens: int = 1000,
    ):
        """Synthesize ``text`` and return ``(sample_rate, waveform, meta)``.

        The voice comes from ``voice_ref`` audio if given, otherwise from the
        packed ``voice`` name. Emotion resolution follows the priority order
        of the emotion-source resolver. A ``seed`` of 0 leaves RNG state
        untouched. ``meta`` has the keys ``"voice"`` and ``"emotion"``
        (resolved names), and ``sample_rate`` is ``S3GEN_SR``.

        Raises:
            ValueError: if the voice cannot be resolved to usable conditionals.
        """
        if seed:
            set_seed(seed)

        voice_name, voice_entry, voice_extracted = self._resolve_voice_source(voice, voice_ref, exaggeration)
        emotion_name, emotion_source = self._resolve_emotion_source(
            emotion=emotion,
            emo_ref=emo_ref,
            voice_source_entry=voice_entry,
            voice_extracted=voice_extracted,
            exaggeration=exaggeration,
        )

        speaker_emb = voice_entry.get("speaker_emb")
        if speaker_emb is None:
            speaker_emb = voice_extracted.get("speaker_emb")
        speaker_emb = clone_tensor(speaker_emb, self.device)

        cond_prompt = voice_entry.get("cond_prompt_speech_tokens")
        if cond_prompt is None:
            cond_prompt = voice_extracted.get("cond_prompt_speech_tokens")
        cond_prompt = clone_tensor(cond_prompt, self.device)

        emotion_adv = clone_tensor(emotion_source.get("emotion_adv"), self.device)

        gen = voice_entry.get("gen")
        if gen is None:
            gen = voice_extracted.get("gen")
        if not gen:
            # S3Gen cannot vocode without reference conditionals; fail here with
            # a clear message instead of deep inside S3Gen.inference.
            raise ValueError(
                f"Voice '{voice_name}' has no S3Gen reference conditionals ('gen'). "
                "Provide voice_ref or repack the voice with a sample.wav."
            )
        gen = clone_ref_dict(gen, device=self.device)

        self.conds = Conditionals(
            t3=T3Cond(
                speaker_emb=speaker_emb,
                cond_prompt_speech_tokens=cond_prompt,
                emotion_adv=emotion_adv,
            ),
            gen=gen,
        )

        tokens = self.infer_t3(text, cfg_weight, temperature, max_new_tokens=max_new_tokens)
        wav = self.infer_s3gen(tokens)
        return S3GEN_SR, wav, {"voice": voice_name, "emotion": emotion_name}

    forward = generate
|
|
|
|
|
|
|
|
|
|
|
def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the packed-bundle TTS script."""
    p = argparse.ArgumentParser(description="Use a packed TTS bundle to generate speech.")
    p.add_argument("--bundle", type=Path, default=DEFAULT_BUNDLE_PATH,
                   help="Path to the packed TTS bundle.")
    p.add_argument("--text", type=str, default="Hello world, this is a test.",
                   help="Text to synthesize.")
    p.add_argument("--voice", type=str, default=None,
                   help="Packed voice name (normalized/fuzzy matched; bundle default or a random voice if omitted).")
    p.add_argument("--emotion", type=str, default=None,
                   help="Packed emotion name (normalized/fuzzy matched).")
    p.add_argument("--voice-ref", type=Path, default=None,
                   help="Reference audio file to clone the voice from; takes priority over --voice.")
    p.add_argument("--emo-ref", type=Path, default=None,
                   help="Reference audio file for emotion; takes priority over --emotion.")
    p.add_argument("--cfg-weight", type=float, default=0.5,
                   help="Classifier-free guidance weight passed to T3.")
    p.add_argument("--temperature", type=float, default=0.8,
                   help="Sampling temperature passed to T3.")
    p.add_argument("--exaggeration", type=float, default=0.5,
                   help="Emotion exaggeration used when conditionals come from reference audio or no emotion is resolved.")
    p.add_argument("--seed", type=int, default=42,
                   help="Random seed; 0 leaves the RNGs unseeded.")
    p.add_argument("--output", type=Path, default=DEFAULT_OUTPUT_PATH,
                   help="Output WAV file path.")
    p.add_argument("--list", action="store_true", help="List packed voices and emotions, then exit")
    return p
|
|
|
|
|
def main(argv: Optional[list[str]] = None) -> None:
    """Command-line entry point.

    With ``--list``, print the packed voice names and the variation count of
    each emotion. Otherwise, synthesize ``--text`` and write it to ``--output``.

    Args:
        argv: Arguments to parse; ``None`` means ``sys.argv[1:]``.
    """
    args = build_parser().parse_args(argv)
    tts = PackedTTS.load(args.bundle, device=DEVICE)
    try:
        if args.list:
            print("Voices:")
            for name in tts.list_voices():
                print(f" - {name}")
            print("\nEmotions:")
            for name, count in tts.list_emotions().items():
                print(f" - {name} ({count} variations)")
            return

        sr, audio, meta = tts.generate(
            text=args.text,
            voice=args.voice,
            emotion=args.emotion,
            voice_ref=str(args.voice_ref) if args.voice_ref else None,
            emo_ref=str(args.emo_ref) if args.emo_ref else None,
            cfg_weight=args.cfg_weight,
            temperature=args.temperature,
            exaggeration=args.exaggeration,
            seed=args.seed,
        )
        # Create the output directory up front so sf.write doesn't fail on a fresh path.
        args.output.parent.mkdir(parents=True, exist_ok=True)
        sf.write(str(args.output), audio, sr)
        print(f"Saved {args.output}")
        print(f"Resolved voice={meta['voice']} emotion={meta['emotion']}")
    finally:
        # Release the temporary tokenizer directory deterministically.
        tts.close()
|
|
|
|
|
if __name__ == "__main__":
    # Run the CLI so command-line options are honored. The former hard-coded
    # demo can be reproduced with, for example:
    #   python <this script> --voice Sarah --emotion Disgust --output sarah_disgust_test.wav
    main()