rahul7star committed on
Commit
d651e33
·
verified ·
1 Parent(s): 89ed58d

Update src/chatterbox/mtl_tts.py

Browse files
Files changed (1) hide show
  1. src/chatterbox/mtl_tts.py +215 -67
src/chatterbox/mtl_tts.py CHANGED
@@ -1,8 +1,9 @@
1
  from dataclasses import dataclass
2
  from pathlib import Path
3
  import os
4
- import torch
5
  import librosa
 
6
  import perth
7
  import torch.nn.functional as F
8
  from safetensors.torch import load_file as load_safetensors
@@ -16,47 +17,116 @@ from .models.tokenizers import MTLTokenizer
16
  from .models.voice_encoder import VoiceEncoder
17
  from .models.t3.modules.cond_enc import T3Cond
18
 
 
19
  REPO_ID = "ResembleAI/chatterbox"
20
 
21
  # Supported languages for the multilingual model
22
  SUPPORTED_LANGUAGES = {
23
- "ar": "Arabic", "da": "Danish", "de": "German", "el": "Greek", "en": "English",
24
- "es": "Spanish", "fi": "Finnish", "fr": "French", "he": "Hebrew", "hi": "Hindi",
25
- "it": "Italian", "ja": "Japanese", "ko": "Korean", "ms": "Malay", "nl": "Dutch",
26
- "no": "Norwegian", "pl": "Polish", "pt": "Portuguese", "ru": "Russian", "sv": "Swedish",
27
- "sw": "Swahili", "tr": "Turkish", "zh": "Chinese",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  }
29
 
30
 
31
  def punc_norm(text: str) -> str:
 
 
 
 
32
  if len(text) == 0:
33
  return "You need to add some text for me to talk."
 
 
34
  if text[0].islower():
35
  text = text[0].upper() + text[1:]
 
 
36
  text = " ".join(text.split())
37
- replacements = [
38
- ("...", ", "), ("…", ", "), (":", ","), (" - ", ","), (";", ","),
39
- ("—", "-"), ("–", "-"), (" ,", ","), ("“", "\""), ("”", "\""),
40
- ("", "'"), ("’", "'"),
 
 
 
 
 
 
 
 
 
 
 
41
  ]
42
- for old, new in replacements:
43
- text = text.replace(old, new)
44
- if not text[-1] in {".", "!", "?", "-", ",","、",",","。","?","!"}:
 
 
 
 
45
  text += "."
 
46
  return text
47
 
48
 
49
  @dataclass
50
  class Conditionals:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  t3: T3Cond
52
  gen: dict
53
 
 
 
 
 
 
 
 
54
  def save(self, fpath: Path):
55
- torch.save({"t3": self.t3.__dict__, "gen": self.gen}, fpath)
 
 
 
 
56
 
57
  @classmethod
58
- def load(cls, fpath: Path):
59
- kwargs = torch.load(fpath, map_location="cpu", weights_only=True)
60
  return cls(T3Cond(**kwargs['t3']), kwargs['gen'])
61
 
62
 
@@ -64,8 +134,16 @@ class ChatterboxMultilingualTTS:
64
  ENC_COND_LEN = 6 * S3_SR
65
  DEC_COND_LEN = 10 * S3GEN_SR
66
 
67
- def __init__(self, t3, s3gen, ve, tokenizer, device, conds=None):
68
- self.sr = S3GEN_SR
 
 
 
 
 
 
 
 
69
  self.t3 = t3
70
  self.s3gen = s3gen
71
  self.ve = ve
@@ -75,9 +153,18 @@ class ChatterboxMultilingualTTS:
75
  self.watermarker = perth.PerthImplicitWatermarker()
76
 
77
  @classmethod
78
- def from_local(cls, ckpt_dir, device=torch.device("cpu")):
 
 
 
 
 
 
 
79
  ve = VoiceEncoder()
80
- ve.load_state_dict(torch.load(ckpt_dir / "ve.pt", weights_only=True))
 
 
81
  ve.to(device).eval()
82
 
83
  t3 = T3(T3Config.multilingual())
@@ -88,66 +175,127 @@ class ChatterboxMultilingualTTS:
88
  t3.to(device).eval()
89
 
90
  s3gen = S3Gen()
91
- s3gen.load_state_dict(torch.load(ckpt_dir / "s3gen.pt", weights_only=True))
 
 
92
  s3gen.to(device).eval()
93
 
94
- tokenizer = MTLTokenizer(str(ckpt_dir / "grapheme_mtl_merged_expanded_v1.json"))
 
 
95
 
96
  conds = None
97
  if (builtin_voice := ckpt_dir / "conds.pt").exists():
98
- conds = Conditionals.load(builtin_voice)
99
 
100
- return cls(t3, s3gen, ve, tokenizer, device, conds)
101
 
102
  @classmethod
103
- def from_pretrained(cls, device=None):
104
- if device is None:
105
- device = torch.device("cpu")
106
- elif isinstance(device, str):
107
- device = torch.device(device)
108
- if device.type != "cpu":
109
- device = torch.device("cpu")
110
-
111
- ckpt_dir = Path(snapshot_download(
112
- repo_id=REPO_ID,
113
- repo_type="model",
114
- revision="main",
115
- allow_patterns=[
116
- "ve.pt", "t3_mtl23ls_v2.safetensors", "s3gen.pt",
117
- "grapheme_mtl_merged_expanded_v1.json", "conds.pt"
118
- ],
119
- token=os.getenv("HF_TOKEN"),
120
- ))
121
-
122
  return cls.from_local(ckpt_dir, device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
- @torch.no_grad()
125
- def generate(self, text: str, speaker_embedding=None, language_id=None, **kwargs):
126
- """
127
- CPU-safe text-to-speech.
128
- Accepts optional `language_id` and any other kwargs.
129
- """
130
- # Normalize punctuation
131
- text = text.strip()
132
- if not text.endswith("."):
133
- text += "."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
- # Encode text
136
- token_ids = self.tokenizer.encode(text)
137
- token_ids = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0).to(self.device)
 
 
 
 
 
138
 
139
- conds = self.conds.gen if self.conds else {}
 
 
 
140
 
141
- # Include language_id in conds if provided
142
- if language_id is not None:
143
- conds = conds.copy()
144
- conds['language_id'] = language_id
145
 
146
- # Run through T3 and S3Gen
147
- t3_out = self.t3(token_ids, **conds)
148
- audio = self.s3gen(t3_out, **conds)
 
 
 
 
 
 
 
 
 
 
149
 
150
- if isinstance(audio, torch.Tensor):
151
- audio = audio.squeeze(0).cpu().numpy()
 
152
 
153
- return audio
 
 
 
 
 
 
 
1
  from dataclasses import dataclass
2
  from pathlib import Path
3
  import os
4
+
5
  import librosa
6
+ import torch
7
  import perth
8
  import torch.nn.functional as F
9
  from safetensors.torch import load_file as load_safetensors
 
17
  from .models.voice_encoder import VoiceEncoder
18
  from .models.t3.modules.cond_enc import T3Cond
19
 
20
+
21
REPO_ID = "ResembleAI/chatterbox"

# Language codes accepted by the multilingual model, mapped to display names.
SUPPORTED_LANGUAGES = {
    "ar": "Arabic",    "da": "Danish",     "de": "German",
    "el": "Greek",     "en": "English",    "es": "Spanish",
    "fi": "Finnish",   "fr": "French",     "he": "Hebrew",
    "hi": "Hindi",     "it": "Italian",    "ja": "Japanese",
    "ko": "Korean",    "ms": "Malay",      "nl": "Dutch",
    "no": "Norwegian", "pl": "Polish",     "pt": "Portuguese",
    "ru": "Russian",   "sv": "Swedish",    "sw": "Swahili",
    "tr": "Turkish",   "zh": "Chinese",
}
49
 
50
 
51
def punc_norm(text: str) -> str:
    """
    Quick cleanup func for punctuation from LLMs or
    containing chars not seen often in the dataset.

    Normalizes whitespace, capitalises the first letter, rewrites uncommon
    punctuation to common equivalents, and guarantees a sentence-ending mark.

    Args:
        text: Raw input text (may be empty).

    Returns:
        The normalized text; a fallback prompt string when `text` is empty.
    """
    if len(text) == 0:
        return "You need to add some text for me to talk."

    # Remove multiple space chars FIRST. Doing this before capitalisation
    # fixes the bug where a leading-whitespace input (e.g. "  hello") was
    # never capitalised because text[0] was a space.
    text = " ".join(text.split())

    # Capitalise first letter (guard: whitespace-only input collapses to "")
    if text and text[0].islower():
        text = text[0].upper() + text[1:]

    # Replace uncommon/llm punc
    punc_to_replace = [
        ("...", ", "),
        ("…", ", "),
        (":", ","),
        (" - ", ", "),
        (";", ", "),
        ("—", "-"),
        ("–", "-"),
        (" ,", ","),
        ("“", "\""),
        ("”", "\""),
        ("‘", "'"),
        ("’", "'"),
    ]
    for old_char_sequence, new_char in punc_to_replace:
        text = text.replace(old_char_sequence, new_char)

    # Add full stop if no ending punc
    text = text.rstrip(" ")
    if not text.endswith((".", "!", "?", "-", ",", "、", ",", "。", "?", "!")):
        text += "."

    return text
91
 
92
 
93
@dataclass
class Conditionals:
    """
    Conditionals for T3 and S3Gen.

    - T3 conditionals:
        - speaker_emb
        - clap_emb
        - cond_prompt_speech_tokens
        - cond_prompt_speech_emb
        - emotion_adv
    - S3Gen conditionals:
        - prompt_token
        - prompt_token_len
        - prompt_feat
        - prompt_feat_len
        - embedding
    """
    t3: T3Cond
    gen: dict

    def to(self, device):
        """Move the T3 cond and every tensor in `gen` to `device` (in place); returns self."""
        self.t3 = self.t3.to(device=device)
        for name, value in self.gen.items():
            if torch.is_tensor(value):
                self.gen[name] = value.to(device=device)
        return self

    def save(self, fpath: Path):
        """Serialize both condition sets to `fpath` with torch.save."""
        torch.save({"t3": self.t3.__dict__, "gen": self.gen}, fpath)

    @classmethod
    def load(cls, fpath, map_location="cpu"):
        """Deserialize a Conditionals previously written by `save`."""
        kwargs = torch.load(fpath, map_location=map_location, weights_only=True)
        return cls(T3Cond(**kwargs['t3']), kwargs['gen'])
131
 
132
 
 
134
  ENC_COND_LEN = 6 * S3_SR
135
  DEC_COND_LEN = 10 * S3GEN_SR
136
 
137
+ def __init__(
138
+ self,
139
+ t3: T3,
140
+ s3gen: S3Gen,
141
+ ve: VoiceEncoder,
142
+ tokenizer: MTLTokenizer,
143
+ device: str,
144
+ conds: Conditionals = None,
145
+ ):
146
+ self.sr = S3GEN_SR # sample rate of synthesized audio
147
  self.t3 = t3
148
  self.s3gen = s3gen
149
  self.ve = ve
 
153
  self.watermarker = perth.PerthImplicitWatermarker()
154
 
155
  @classmethod
156
+ def get_supported_languages(cls):
157
+ """Return dictionary of supported language codes and names."""
158
+ return SUPPORTED_LANGUAGES.copy()
159
+
160
+ @classmethod
161
+ def from_local(cls, ckpt_dir, device) -> 'ChatterboxMultilingualTTS':
162
+ ckpt_dir = Path(ckpt_dir)
163
+
164
  ve = VoiceEncoder()
165
+ ve.load_state_dict(
166
+ torch.load(ckpt_dir / "ve.pt", weights_only=True)
167
+ )
168
  ve.to(device).eval()
169
 
170
  t3 = T3(T3Config.multilingual())
 
175
  t3.to(device).eval()
176
 
177
  s3gen = S3Gen()
178
+ s3gen.load_state_dict(
179
+ torch.load(ckpt_dir / "s3gen.pt", weights_only=True)
180
+ )
181
  s3gen.to(device).eval()
182
 
183
+ tokenizer = MTLTokenizer(
184
+ str(ckpt_dir / "grapheme_mtl_merged_expanded_v1.json")
185
+ )
186
 
187
  conds = None
188
  if (builtin_voice := ckpt_dir / "conds.pt").exists():
189
+ conds = Conditionals.load(builtin_voice).to(device)
190
 
191
+ return cls(t3, s3gen, ve, tokenizer, device, conds=conds)
192
 
193
  @classmethod
194
+ def from_pretrained(cls, device: torch.device) -> 'ChatterboxMultilingualTTS':
195
+ ckpt_dir = Path(
196
+ snapshot_download(
197
+ repo_id=REPO_ID,
198
+ repo_type="model",
199
+ revision="main",
200
+ allow_patterns=["ve.pt", "t3_mtl23ls_v2.safetensors", "s3gen.pt", "grapheme_mtl_merged_expanded_v1.json", "conds.pt", "Cangjie5_TC.json"],
201
+ token=os.getenv("HF_TOKEN"),
202
+ )
203
+ )
 
 
 
 
 
 
 
 
 
204
  return cls.from_local(ckpt_dir, device)
205
+
206
    def prepare_conditionals(self, wav_fpath, exaggeration=0.5):
        """
        Build T3/S3Gen conditioning from a reference wav and cache it on `self.conds`.

        Args:
            wav_fpath: Path to the reference audio file; loaded at S3GEN_SR.
            exaggeration: Scalar emotion/exaggeration conditioning baked into the
                T3 conditionals (default 0.5).
        """
        ## Load reference wav
        s3gen_ref_wav, _sr = librosa.load(wav_fpath, sr=S3GEN_SR)

        # 16 kHz copy for the tokenizer and voice encoder
        ref_16k_wav = librosa.resample(s3gen_ref_wav, orig_sr=S3GEN_SR, target_sr=S3_SR)

        # Truncate the decoder reference to at most DEC_COND_LEN samples (10 s)
        s3gen_ref_wav = s3gen_ref_wav[:self.DEC_COND_LEN]
        s3gen_ref_dict = self.s3gen.embed_ref(s3gen_ref_wav, S3GEN_SR, device=self.device)

        # Speech cond prompt tokens (skipped when speech_cond_prompt_len is 0/None)
        t3_cond_prompt_tokens = None
        if plen := self.t3.hp.speech_cond_prompt_len:
            s3_tokzr = self.s3gen.tokenizer
            # Encoder reference limited to ENC_COND_LEN samples (6 s at 16 kHz)
            t3_cond_prompt_tokens, _ = s3_tokzr.forward([ref_16k_wav[:self.ENC_COND_LEN]], max_len=plen)
            t3_cond_prompt_tokens = torch.atleast_2d(t3_cond_prompt_tokens).to(self.device)

        # Voice-encoder speaker embedding (averaged over wavs; single wav here)
        ve_embed = torch.from_numpy(self.ve.embeds_from_wavs([ref_16k_wav], sample_rate=S3_SR))
        ve_embed = ve_embed.mean(axis=0, keepdim=True).to(self.device)

        t3_cond = T3Cond(
            speaker_emb=ve_embed,
            cond_prompt_speech_tokens=t3_cond_prompt_tokens,
            emotion_adv=exaggeration * torch.ones(1, 1, 1),
        ).to(device=self.device)
        self.conds = Conditionals(t3_cond, s3gen_ref_dict)
232
 
233
    def generate(
        self,
        text,
        language_id,
        audio_prompt_path=None,
        exaggeration=0.5,
        cfg_weight=0.5,
        temperature=0.8,
        repetition_penalty=2.0,
        min_p=0.05,
        top_p=1.0,
    ):
        """
        Synthesize watermarked speech for `text` in language `language_id`.

        Args:
            text: Input text; normalized with `punc_norm` before tokenization.
            language_id: Language code; when truthy it must be a key of
                SUPPORTED_LANGUAGES (compared case-insensitively).
            audio_prompt_path: Optional reference wav; when given,
                `prepare_conditionals` is re-run on it. Otherwise conditionals
                must already have been prepared.
            exaggeration: Emotion/exaggeration scalar; conditionals are rebuilt
                when it differs from the cached value.
            cfg_weight, temperature, repetition_penalty, min_p, top_p:
                Sampling controls forwarded to `T3.inference`.

        Returns:
            torch tensor with a leading batch dim of 1 containing the
            watermarked waveform at `self.sr` (presumably (1, num_samples) —
            depends on `s3gen.inference` output shape; confirm against S3Gen).

        Raises:
            ValueError: if `language_id` is not supported.
            AssertionError: if no conditionals exist and `audio_prompt_path`
                is not given. NOTE(review): `assert` is stripped under -O;
                consider raising explicitly (kept as-is to preserve the
                exception type callers may rely on).
        """
        # Validate language_id
        if language_id and language_id.lower() not in SUPPORTED_LANGUAGES:
            supported_langs = ", ".join(SUPPORTED_LANGUAGES.keys())
            raise ValueError(
                f"Unsupported language_id '{language_id}'. "
                f"Supported languages: {supported_langs}"
            )

        if audio_prompt_path:
            self.prepare_conditionals(audio_prompt_path, exaggeration=exaggeration)
        else:
            assert self.conds is not None, "Please `prepare_conditionals` first or specify `audio_prompt_path`"

        # Update exaggeration if needed (rebuild T3 cond with the new scalar)
        if float(exaggeration) != float(self.conds.t3.emotion_adv[0, 0, 0].item()):
            _cond: T3Cond = self.conds.t3
            self.conds.t3 = T3Cond(
                speaker_emb=_cond.speaker_emb,
                cond_prompt_speech_tokens=_cond.cond_prompt_speech_tokens,
                emotion_adv=exaggeration * torch.ones(1, 1, 1),
            ).to(device=self.device)

        # Norm and tokenize text
        text = punc_norm(text)
        text_tokens = self.tokenizer.text_to_tokens(text, language_id=language_id.lower() if language_id else None).to(self.device)
        text_tokens = torch.cat([text_tokens, text_tokens], dim=0)  # Need two seqs for CFG

        # Wrap the token sequence in start/stop text tokens
        sot = self.t3.hp.start_text_token
        eot = self.t3.hp.stop_text_token
        text_tokens = F.pad(text_tokens, (1, 0), value=sot)
        text_tokens = F.pad(text_tokens, (0, 1), value=eot)

        with torch.inference_mode():
            speech_tokens = self.t3.inference(
                t3_cond=self.conds.t3,
                text_tokens=text_tokens,
                max_new_tokens=1000,  # TODO: use the value in config
                temperature=temperature,
                cfg_weight=cfg_weight,
                repetition_penalty=repetition_penalty,
                min_p=min_p,
                top_p=top_p,
            )
            # Extract only the conditional batch.
            speech_tokens = speech_tokens[0]

            # TODO: output becomes 1D
            speech_tokens = drop_invalid_tokens(speech_tokens)
            speech_tokens = speech_tokens.to(self.device)

            wav, _ = self.s3gen.inference(
                speech_tokens=speech_tokens,
                ref_dict=self.conds.gen,
            )
            wav = wav.squeeze(0).detach().cpu().numpy()
        # Watermark operates on the numpy waveform; re-wrap as a batched tensor
        watermarked_wav = self.watermarker.apply_watermark(wav, sample_rate=self.sr)
        return torch.from_numpy(watermarked_wav).unsqueeze(0)