Ahmed Wasfy committed on
Commit b6577ee · 1 Parent(s): c77a697

New model changes

src/chatterbox/models/t3/t3.py CHANGED
@@ -10,7 +10,11 @@ import torch
 import torch.nn.functional as F
 from torch import nn, Tensor
 from transformers import LlamaModel, LlamaConfig
-from transformers.generation.logits_process import TopPLogitsWarper, RepetitionPenaltyLogitsProcessor, MinPLogitsWarper
+from transformers.generation.logits_process import (
+    TopPLogitsWarper,
+    RepetitionPenaltyLogitsProcessor,
+    MinPLogitsWarper,
+)
 
 from .modules.learned_pos_emb import LearnedPositionEmbeddings
 
@@ -27,8 +31,12 @@ logger = logging.getLogger(__name__)
 
 def _ensure_BOT_EOT(text_tokens: Tensor, hp):
     B = text_tokens.size(0)
-    assert (text_tokens == hp.start_text_token).int().sum() >= B, "missing start_text_token"
-    assert (text_tokens == hp.stop_text_token).int().sum() >= B, "missing stop_text_token"
+    assert (
+        text_tokens == hp.start_text_token
+    ).int().sum() >= B, "missing start_text_token"
+    assert (
+        text_tokens == hp.stop_text_token
+    ).int().sum() >= B, "missing stop_text_token"
 
 
 class T3(nn.Module):
@@ -43,7 +51,9 @@ class T3(nn.Module):
 
     def __init__(self, hp=None):
         if hp is None:
-            hp = T3Config.english_only()  # Default to English-only config for backward compatibility
+            hp = (
+                T3Config.english_only()
+            )  # Default to English-only config for backward compatibility
         super().__init__()
         self.hp = hp
         self.cfg = LlamaConfig(**LLAMA_CONFIGS[hp.llama_config_name])
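
Note: with this default, constructing T3 with no arguments keeps older English-only checkpoints working, while multilingual use passes a config explicitly (as from_local does in mtl_tts.py below). A minimal usage sketch, assuming only the T3Config constructors named in this commit:

    t3_default = T3()                     # hp=None falls back to T3Config.english_only()
    t3_mtl = T3(T3Config.multilingual())  # explicit multilingual hparams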
@@ -65,8 +75,12 @@ class T3(nn.Module):
         self.speech_pos_emb = LearnedPositionEmbeddings(max_mel_seq_len, self.dim)
 
         # logit projection
-        self.text_head = nn.Linear(self.cfg.hidden_size, hp.text_tokens_dict_size, bias=False)
-        self.speech_head = nn.Linear(self.cfg.hidden_size, hp.speech_tokens_dict_size, bias=False)
+        self.text_head = nn.Linear(
+            self.cfg.hidden_size, hp.text_tokens_dict_size, bias=False
+        )
+        self.speech_head = nn.Linear(
+            self.cfg.hidden_size, hp.speech_tokens_dict_size, bias=False
+        )
         self.compiled = False
 
     @property
@@ -77,9 +91,13 @@ class T3(nn.Module):
         """
         Token cond data needs to be embedded, so that needs to be here instead of in `T3CondEnc`.
         """
-        if t3_cond.cond_prompt_speech_tokens is not None and t3_cond.cond_prompt_speech_emb is None:
-            t3_cond.cond_prompt_speech_emb = self.speech_emb(t3_cond.cond_prompt_speech_tokens) + \
-                self.speech_pos_emb(t3_cond.cond_prompt_speech_tokens)
+        if (
+            t3_cond.cond_prompt_speech_tokens is not None
+            and t3_cond.cond_prompt_speech_emb is None
+        ):
+            t3_cond.cond_prompt_speech_emb = self.speech_emb(
+                t3_cond.cond_prompt_speech_tokens
+            ) + self.speech_pos_emb(t3_cond.cond_prompt_speech_tokens)
         return self.cond_enc(t3_cond)  # (B, len_cond, dim)
 
     def prepare_input_embeds(
@@ -103,13 +121,15 @@ class T3(nn.Module):
         len_cond = cond_emb.size(1)
 
         if cond_emb.size(0) != text_emb.size(0):
-            cond_emb = cond_emb.expand(text_emb.size(0), -1, -1)
+            cond_emb = cond_emb.expand(text_emb.size(0), -1, -1)
 
         # concat
-        embeds = torch.stack([
-            torch.cat((ce, te, se))
-            for ce, te, se in zip(cond_emb, text_emb, speech_emb)
-        ])  # (B, length, dim)
+        embeds = torch.stack(
+            [
+                torch.cat((ce, te, se))
+                for ce, te, se in zip(cond_emb, text_emb, speech_emb)
+            ]
+        )  # (B, length, dim)
         return embeds, len_cond
 
     def forward(
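
Note: prepare_input_embeds lays each batch item out as [conditioning | text | speech] along the sequence axis, which is why len_cond is returned and later used to splice the hidden states back apart. A minimal shape sketch with dummy tensors standing in for the real embeddings (dimensions are illustrative only):

    import torch

    B, dim = 2, 8
    cond_emb = torch.randn(B, 3, dim)    # (B, len_cond, dim)
    text_emb = torch.randn(B, 5, dim)    # (B, len_text, dim)
    speech_emb = torch.randn(B, 7, dim)  # (B, len_speech, dim)
    embeds = torch.stack(
        [torch.cat((ce, te, se)) for ce, te, se in zip(cond_emb, text_emb, speech_emb)]
    )
    assert embeds.shape == (B, 3 + 5 + 7, dim)  # segments concatenated in time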
@@ -140,7 +160,9 @@ class T3(nn.Module):
             return_dict=True,
             use_cache=(not training),
         )
-        hidden_states = tfmr_out.hidden_states[-1]  # final tfmr layer output, (B, seq, dim)
+        hidden_states = tfmr_out.hidden_states[
+            -1
+        ]  # final tfmr layer output, (B, seq, dim)
 
         # post-processing: splice out text and speech parts of hidden states
         len_text = text_tokens.size(1)
@@ -154,8 +176,8 @@ class T3(nn.Module):
             text_end = len_cond + ttl[i].item()
             speech_start = len_cond + text_tokens.size(1)
             speech_end = speech_start + stl[i].item()
-            text_latents[i, :ttl[i]] = hidden_states[i, len_cond:text_end]
-            speech_latents[i, :stl[i]] = hidden_states[i, speech_start:speech_end]
+            text_latents[i, : ttl[i]] = hidden_states[i, len_cond:text_end]
+            speech_latents[i, : stl[i]] = hidden_states[i, speech_start:speech_end]
 
         # logit projection
         text_logits = self.text_head(text_latents)
@@ -173,17 +195,21 @@ class T3(nn.Module):
         self,
         *,
         t3_cond: T3Cond,
-        text_tokens: torch.LongTensor,
-        text_token_lens: torch.LongTensor,
-        speech_tokens: torch.LongTensor,
-        speech_token_lens: torch.LongTensor,
+        text_tokens: torch.LongTensor,  # (B, S_text_padded), includes BOS & EOS
+        text_token_lens: torch.LongTensor,  # (B,), actual lengths including BOS & EOS
+        speech_tokens: torch.LongTensor,  # (B, S_speech_padded), includes BOS & EOS
+        speech_token_lens: torch.LongTensor,  # (B,), actual lengths including BOS & EOS
+        labels_text: torch.LongTensor,  # (B, S_text_padded-1), already masked with -100
+        labels_speech: torch.LongTensor,  # (B, S_speech_padded-1), already masked with -100
     ):
-        "training method"
-        len_text = text_tokens.size(1)
-        len_speech = speech_tokens.size(1)
-        assert len_text == text_token_lens.max()
-        assert len_speech == speech_token_lens.max()
+        """
+        Compute text and speech cross-entropy using pre-masked labels from the collator.
+        Assumes:
+        - labels_text corresponds to predicting text_tokens[:, 1:] with -100 where ignored
+        - labels_speech corresponds to predicting speech_tokens[:, 1:] with -100 where ignored
+        """
 
+        # 1) Run model to get logits
         out = self.forward(
             t3_cond=t3_cond,
             text_tokens=text_tokens,
@@ -191,19 +217,42 @@ class T3(nn.Module):
             speech_tokens=speech_tokens,
             speech_token_lens=speech_token_lens,
             training=True,
-        )  # (B, seq, vocab_size)
-
-        # Calc CCE losses
-        IGNORE_ID = -100
+        )
+        # out.text_logits:   (B, S_text_padded,   V_text)
+        # out.speech_logits: (B, S_speech_padded, V_speech)
         device = out.text_logits.device
-        mask_text = torch.arange(len_text, device=device)[None] >= text_token_lens[:, None]  # (B, len_text)
-        mask_speech = torch.arange(len_speech, device=device)[None] >= speech_token_lens[:, None]  # (B, len_speech)
-        masked_text = text_tokens.masked_fill(mask_text, IGNORE_ID)
-        masked_speech = speech_tokens.masked_fill(mask_speech, IGNORE_ID)
-        loss_text = F.cross_entropy(out.text_logits, masked_text, ignore_index=IGNORE_ID)
-        loss_speech = F.cross_entropy(out.speech_logits, masked_speech, ignore_index=IGNORE_ID)
+        IGNORE_ID = -100
 
-        return loss_text, loss_speech
+        # --- Text Loss (use labels_text directly) ---
+        # Align logits: predict t1..EOS from inputs [BOS, t1..]
+        logits_for_text = out.text_logits[
+            :, :-1, :
+        ].contiguous()  # (B, S_text_padded-1, V_text)
+        # labels_text already has shape (B, S_text_padded-1) with -100 where masked
+        if logits_for_text.size(1) == 0:
+            loss_text = torch.tensor(0.0, device=device, requires_grad=self.training)
+        else:
+            loss_text = F.cross_entropy(
+                logits_for_text.transpose(1, 2),  # (B, V_text, S_text_padded-1)
+                labels_text,  # (B, S_text_padded-1), ignore_index=-100
+                ignore_index=IGNORE_ID,
+            )
+
+        # --- Speech Loss (use labels_speech directly) ---
+        logits_for_speech = out.speech_logits[
+            :, :-1, :
+        ].contiguous()  # (B, S_speech_padded-1, V_speech)
+        # labels_speech already has shape (B, S_speech_padded-1) with -100 where masked
+        if logits_for_speech.size(1) == 0:
+            loss_speech = torch.tensor(0.0, device=device, requires_grad=self.training)
+        else:
+            loss_speech = F.cross_entropy(
+                logits_for_speech.transpose(1, 2),  # (B, V_speech, S_speech_padded-1)
+                labels_speech,  # (B, S_speech_padded-1), ignore_index=-100
+                ignore_index=IGNORE_ID,
+            )
+
+        return loss_text, loss_speech, out.speech_logits
 
     @torch.inference_mode()
     def inference(
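
Note: the rewritten loss shifts masking out of the model: the data collator must now supply labels_text/labels_speech already shifted by one position and padded with -100. A sketch of a compatible label builder (make_labels is hypothetical, not part of this commit; the shift and ignore id follow the docstring above, and the method is assumed to be T3.loss):

    import torch

    IGNORE_ID = -100

    def make_labels(tokens: torch.LongTensor, lens: torch.LongTensor) -> torch.LongTensor:
        # tokens: (B, S_padded) including BOS/EOS; lens: (B,) true lengths.
        labels = tokens[:, 1:].clone()  # predict t1..EOS from inputs [BOS, t1, ...]
        S = labels.size(1)
        pad = torch.arange(S, device=tokens.device)[None] >= (lens[:, None] - 1)
        return labels.masked_fill(pad, IGNORE_ID)  # (B, S_padded-1)

    # loss_text, loss_speech, speech_logits = t3.loss(
    #     t3_cond=..., text_tokens=..., text_token_lens=...,
    #     speech_tokens=..., speech_token_lens=...,
    #     labels_text=make_labels(text_tokens, text_token_lens),
    #     labels_speech=make_labels(speech_tokens, speech_token_lens),
    # )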
@@ -211,11 +260,9 @@ class T3(nn.Module):
         *,
         t3_cond: T3Cond,
         text_tokens: Tensor,
-        initial_speech_tokens: Optional[Tensor]=None,
-
+        initial_speech_tokens: Optional[Tensor] = None,
         # misc conditioning
-        prepend_prompt_speech_tokens: Optional[Tensor]=None,
-
+        prepend_prompt_speech_tokens: Optional[Tensor] = None,
         # HF generate args
         num_return_sequences=1,
         max_new_tokens=None,
@@ -235,11 +282,15 @@ class T3(nn.Module):
         # Validate / sanitize inputs
         assert prepend_prompt_speech_tokens is None, "not implemented"
         _ensure_BOT_EOT(text_tokens, self.hp)
-        text_tokens = torch.atleast_2d(text_tokens).to(dtype=torch.long, device=self.device)
+        text_tokens = torch.atleast_2d(text_tokens).to(
+            dtype=torch.long, device=self.device
+        )
 
         # Default initial speech to a single start-of-speech token
         if initial_speech_tokens is None:
-            initial_speech_tokens = self.hp.start_speech_token * torch.ones_like(text_tokens[:, :1])
+            initial_speech_tokens = self.hp.start_speech_token * torch.ones_like(
+                text_tokens[:, :1]
+            )
 
         # Prepare custom input embeds
         embeds, len_cond = self.prepare_input_embeds(
@@ -264,7 +315,7 @@ class T3(nn.Module):
             self.tfmr,
             None,
             text_tokens_slice=(len_cond, len_cond + text_tokens.size(-1)),
-            alignment_layer_idx=9, # TODO: hparam or something?
+            alignment_layer_idx=9,  # TODO: hparam or something?
             eos_idx=self.hp.stop_speech_token,
         )
         assert alignment_stream_analyzer.eos_idx == self.hp.stop_speech_token
@@ -298,7 +349,9 @@ class T3(nn.Module):
 
         device = embeds.device
 
-        bos_token = torch.tensor([[self.hp.start_speech_token]], dtype=torch.long, device=device)
+        bos_token = torch.tensor(
+            [[self.hp.start_speech_token]], dtype=torch.long, device=device
+        )
         bos_embed = self.speech_emb(bos_token)  # shape: (B, 1, embed_dim)
         bos_embed = bos_embed + self.speech_pos_emb.get_fixed_embedding(0)
 
@@ -316,7 +369,9 @@ class T3(nn.Module):
         top_p_warper = TopPLogitsWarper(top_p=top_p)
         min_p_warper = MinPLogitsWarper(min_p=min_p)
         top_p_warper = TopPLogitsWarper(top_p=top_p)
-        repetition_penalty_processor = RepetitionPenaltyLogitsProcessor(penalty=float(repetition_penalty))
+        repetition_penalty_processor = RepetitionPenaltyLogitsProcessor(
+            penalty=float(repetition_penalty)
+        )
 
         # ---- Initial Forward Pass (no kv_cache yet) ----
         output = self.patched_model(
@@ -332,29 +387,33 @@ class T3(nn.Module):
 
         # ---- Generation Loop using kv_cache ----
         for i in tqdm(range(max_new_tokens), desc="Sampling", dynamic_ncols=True):
-            logits_step = output.logits[:, -1, :]
+            logits_step = output.logits[:, -1, :]
             # CFG combine → (1, V)
-            cond = logits_step[0:1, :]
+            cond = logits_step[0:1, :]
             uncond = logits_step[1:2, :]
             cfg = torch.as_tensor(cfg_weight, device=cond.device, dtype=cond.dtype)
             logits = cond + cfg * (cond - uncond)
-
+
             # Apply alignment stream analyzer integrity checks
             if self.patched_model.alignment_stream_analyzer is not None:
-                if logits.dim() == 1: # guard in case something upstream squeezed
-                    logits = logits.unsqueeze(0) # (1, V)
+                if logits.dim() == 1:  # guard in case something upstream squeezed
+                    logits = logits.unsqueeze(0)  # (1, V)
                 # Pass the last generated token for repetition tracking
-                last_token = generated_ids[0, -1].item() if len(generated_ids[0]) > 0 else None
-                logits = self.patched_model.alignment_stream_analyzer.step(logits, next_token=last_token) # (1, V)
+                last_token = (
+                    generated_ids[0, -1].item() if len(generated_ids[0]) > 0 else None
+                )
+                logits = self.patched_model.alignment_stream_analyzer.step(
+                    logits, next_token=last_token
+                )  # (1, V)
 
             # Apply repetition penalty
-            ids_for_proc = generated_ids[:1, ...] # batch = 1
+            ids_for_proc = generated_ids[:1, ...]  # batch = 1
             logits = repetition_penalty_processor(ids_for_proc, logits)  # expects (B,V)
-
+
             # Apply temperature scaling.
             if temperature != 1.0:
                 logits = logits / temperature
-
+
             # Apply min_p and top_p filtering
             logits = min_p_warper(ids_for_proc, logits)
             logits = top_p_warper(ids_for_proc, logits)
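
Note: the CFG combine above is the usual classifier-free guidance form, logits = cond + w * (cond - uncond): w = 0 returns the conditional logits unchanged, and larger w pushes the distribution further from the unconditional one. A self-contained numeric sketch (toy values, not model outputs):

    import torch

    cond = torch.tensor([[2.0, 1.0, 0.0]])    # conditional logits, (1, V)
    uncond = torch.tensor([[1.0, 1.0, 1.0]])  # unconditional logits, (1, V)
    w = 0.5                                   # cfg_weight
    print(cond + w * (cond - uncond))         # tensor([[ 2.5000,  1.0000, -0.5000]])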
@@ -373,7 +432,9 @@ class T3(nn.Module):
 
         # Get embedding for the new token.
         next_token_embed = self.speech_emb(next_token)
-        next_token_embed = next_token_embed + self.speech_pos_emb.get_fixed_embedding(i + 1)
+        next_token_embed = (
+            next_token_embed + self.speech_pos_emb.get_fixed_embedding(i + 1)
+        )
 
         # For CFG
         next_token_embed = torch.cat([next_token_embed, next_token_embed])
src/chatterbox/mtl_tts.py CHANGED
@@ -22,36 +22,36 @@ REPO_ID = "ResembleAI/chatterbox"
 
 # Supported languages for the multilingual model
 SUPPORTED_LANGUAGES = {
-    "ar": "Arabic",
-    "da": "Danish",
-    "de": "German",
-    "el": "Greek",
-    "en": "English",
-    "es": "Spanish",
-    "fi": "Finnish",
-    "fr": "French",
-    "he": "Hebrew",
-    "hi": "Hindi",
-    "it": "Italian",
-    "ja": "Japanese",
-    "ko": "Korean",
-    "ms": "Malay",
-    "nl": "Dutch",
-    "no": "Norwegian",
-    "pl": "Polish",
-    "pt": "Portuguese",
-    "ru": "Russian",
-    "sv": "Swedish",
-    "sw": "Swahili",
-    "tr": "Turkish",
-    "zh": "Chinese",
+    "ar": "Arabic",
+    "da": "Danish",
+    "de": "German",
+    "el": "Greek",
+    "en": "English",
+    "es": "Spanish",
+    "fi": "Finnish",
+    "fr": "French",
+    "he": "Hebrew",
+    "hi": "Hindi",
+    "it": "Italian",
+    "ja": "Japanese",
+    "ko": "Korean",
+    "ms": "Malay",
+    "nl": "Dutch",
+    "no": "Norwegian",
+    "pl": "Polish",
+    "pt": "Portuguese",
+    "ru": "Russian",
+    "sv": "Swedish",
+    "sw": "Swahili",
+    "tr": "Turkish",
+    "zh": "Chinese",
 }
 
 
 def punc_norm(text: str) -> str:
     """
-    Quick cleanup func for punctuation from LLMs or
-    containing chars not seen often in the dataset
+    Quick cleanup func for punctuation from LLMs or
+    containing chars not seen often in the dataset
     """
     if len(text) == 0:
         return "You need to add some text for me to talk."
@@ -73,8 +73,8 @@ def punc_norm(text: str) -> str:
         ("—", "-"),
         ("–", "-"),
         (" ,", ","),
-        ("“", "\""),
-        ("”", "\""),
+        ("“", '"'),
+        ("”", '"'),
         ("‘", "'"),
         ("’", "'"),
     ]
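
Note: judging only from the replacement table above and the sentence-ender check in the next hunk, punc_norm maps curly quotes and dashes to their ASCII forms and appends a period when the text lacks terminal punctuation, e.g.:

    punc_norm('He said “hello” — twice')  # -> 'He said "hello" - twice.'
    punc_norm("你好。")                    # ends with 。, so no period is appended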
@@ -83,7 +83,7 @@
 
     # Add full stop if no ending punc
     text = text.rstrip(" ")
-    sentence_enders = {".", "!", "?", "-", ",","、",",","。","?","!"}
+    sentence_enders = {".", "!", "?", "-", ",", "、", ",", "。", "?", "!"}
     if not any(text.endswith(p) for p in sentence_enders):
         text += "."
 
@@ -107,6 +107,7 @@ class Conditionals:
     - prompt_feat_len
     - embedding
     """
+
     t3: T3Cond
     gen: dict
 
@@ -118,16 +119,13 @@
         return self
 
     def save(self, fpath: Path):
-        arg_dict = dict(
-            t3=self.t3.__dict__,
-            gen=self.gen
-        )
+        arg_dict = dict(t3=self.t3.__dict__, gen=self.gen)
         torch.save(arg_dict, fpath)
 
     @classmethod
     def load(cls, fpath, map_location="cpu"):
         kwargs = torch.load(fpath, map_location=map_location, weights_only=True)
-        return cls(T3Cond(**kwargs['t3']), kwargs['gen'])
+        return cls(T3Cond(**kwargs["t3"]), kwargs["gen"])
 
 
 class ChatterboxMultilingualTTS:
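
Note: save/load round-trip the dataclass through a plain dict, and weights_only=True means everything stored in t3.__dict__ and gen must be tensors or other types on torch's safe-load allowlist. A usage sketch (the path is illustrative):

    conds = Conditionals(t3_cond, gen_dict)  # t3_cond: T3Cond, gen_dict: dict
    conds.save(Path("voice_conds.pt"))
    restored = Conditionals.load("voice_conds.pt")  # cls(T3Cond(**kwargs["t3"]), kwargs["gen"])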
@@ -158,13 +156,11 @@ class ChatterboxMultilingualTTS:
         return SUPPORTED_LANGUAGES.copy()
 
     @classmethod
-    def from_local(cls, ckpt_dir, device) -> 'ChatterboxMultilingualTTS':
+    def from_local(cls, ckpt_dir, device) -> "ChatterboxMultilingualTTS":
         ckpt_dir = Path(ckpt_dir)
 
         ve = VoiceEncoder()
-        ve.load_state_dict(
-            torch.load(ckpt_dir / "ve.pt", weights_only=True)
-        )
+        ve.load_state_dict(torch.load(ckpt_dir / "ve.pt", weights_only=True))
         ve.to(device).eval()
 
         t3 = T3(T3Config.multilingual())
@@ -175,14 +171,10 @@ class ChatterboxMultilingualTTS:
         t3.to(device).eval()
 
         s3gen = S3Gen()
-        s3gen.load_state_dict(
-            torch.load(ckpt_dir / "s3gen.pt", weights_only=True)
-        )
+        s3gen.load_state_dict(torch.load(ckpt_dir / "s3gen.pt", weights_only=True))
         s3gen.to(device).eval()
 
-        tokenizer = MTLTokenizer(
-            str(ckpt_dir / "grapheme_mtl_merged_expanded_v1.json")
-        )
+        tokenizer = MTLTokenizer(str(ckpt_dir / "grapheme_mtl_merged_expanded_v1.json"))
 
         conds = None
         if (builtin_voice := ckpt_dir / "conds.pt").exists():
@@ -191,36 +183,94 @@ class ChatterboxMultilingualTTS:
         return cls(t3, s3gen, ve, tokenizer, device, conds=conds)
 
     @classmethod
-    def from_pretrained(cls, device: torch.device) -> 'ChatterboxMultilingualTTS':
+    def from_pretrained(cls, device: torch.device) -> "ChatterboxMultilingualTTS":
         ckpt_dir = Path(
             snapshot_download(
                 repo_id=REPO_ID,
                 repo_type="model",
-                revision="main",
-                allow_patterns=["ve.pt", "t3_mtl23ls_v2.safetensors", "s3gen.pt", "grapheme_mtl_merged_expanded_v1.json", "conds.pt", "Cangjie5_TC.json"],
+                revision="main",
+                allow_patterns=[
+                    "ve.pt",
+                    "t3_mtl23ls_v2.safetensors",
+                    "s3gen.pt",
+                    "grapheme_mtl_merged_expanded_v1.json",
+                    "conds.pt",
+                    "Cangjie5_TC.json",
+                ],
                 token=os.getenv("HF_TOKEN"),
             )
         )
         return cls.from_local(ckpt_dir, device)
-
+
+    @classmethod
+    def from_checkpoint(
+        cls, save_dir, device: torch.device
+    ) -> "ChatterboxMultilingualTTS":
+        ckpt_dir = Path(
+            snapshot_download(
+                repo_id=REPO_ID,
+                repo_type="model",
+                revision="main",
+                allow_patterns=[
+                    "ve.pt",
+                    "t3_mtl23ls_v2.safetensors",
+                    "s3gen.pt",
+                    "grapheme_mtl_merged_expanded_v1.json",
+                    "conds.pt",
+                    "Cangjie5_TC.json",
+                ],
+                token=os.getenv("HF_TOKEN"),
+            )
+        )
+        ckpt_dir = Path(ckpt_dir)
+
+        ve = VoiceEncoder()
+        ve.load_state_dict(torch.load(ckpt_dir / "ve.pt", weights_only=True))
+        ve.to(device).eval()
+
+        t3 = T3(T3Config.multilingual())
+        t3_state = load_safetensors(save_dir + "t3_mtl23ls_v2.safetensors")
+        if "model" in t3_state.keys():
+            t3_state = t3_state["model"][0]
+        t3.load_state_dict(t3_state)
+        t3.to(device).eval()
+
+        s3gen = S3Gen()
+        s3gen.load_state_dict(torch.load(ckpt_dir / "s3gen.pt", weights_only=True))
+        s3gen.to(device).eval()
+
+        tokenizer = MTLTokenizer(str(ckpt_dir / "grapheme_mtl_merged_expanded_v1.json"))
+
+        conds = Conditionals.load(save_dir + "conds.pt").to(device)
+
+        return cls(t3, s3gen, ve, tokenizer, device, conds=conds)
+
     def prepare_conditionals(self, wav_fpath, exaggeration=0.5):
         ## Load reference wav
         s3gen_ref_wav, _sr = librosa.load(wav_fpath, sr=S3GEN_SR)
 
         ref_16k_wav = librosa.resample(s3gen_ref_wav, orig_sr=S3GEN_SR, target_sr=S3_SR)
 
-        s3gen_ref_wav = s3gen_ref_wav[:self.DEC_COND_LEN]
-        s3gen_ref_dict = self.s3gen.embed_ref(s3gen_ref_wav, S3GEN_SR, device=self.device)
+        s3gen_ref_wav = s3gen_ref_wav[: self.DEC_COND_LEN]
+        s3gen_ref_dict = self.s3gen.embed_ref(
+            s3gen_ref_wav, S3GEN_SR, device=self.device
+        )
 
         # Speech cond prompt tokens
         t3_cond_prompt_tokens = None
         if plen := self.t3.hp.speech_cond_prompt_len:
             s3_tokzr = self.s3gen.tokenizer
-            t3_cond_prompt_tokens, _ = s3_tokzr.forward([ref_16k_wav[:self.ENC_COND_LEN]], max_len=plen)
-            t3_cond_prompt_tokens = torch.atleast_2d(t3_cond_prompt_tokens).to(self.device)
+            t3_cond_prompt_tokens, _ = s3_tokzr.forward(
+                [ref_16k_wav[: self.ENC_COND_LEN]], max_len=plen
+            )
+            t3_cond_prompt_tokens = torch.atleast_2d(t3_cond_prompt_tokens).to(
+                self.device
+            )
 
         # Voice-encoder speaker embedding
-        ve_embed = torch.from_numpy(self.ve.embeds_from_wavs([ref_16k_wav], sample_rate=S3_SR))
+        ve_embed = torch.from_numpy(
+            self.ve.embeds_from_wavs([ref_16k_wav], sample_rate=S3_SR)
+        )
         ve_embed = ve_embed.mean(axis=0, keepdim=True).to(self.device)
 
         t3_cond = T3Cond(
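
Note: from_checkpoint still snapshots the hub repo for ve.pt and s3gen.pt, but reads the T3 weights and conds.pt from save_dir via plain string concatenation, so save_dir must end with a path separator. load_safetensors is not imported anywhere in this diff; presumably it aliases safetensors.torch.load_file (an assumption). A usage sketch with an illustrative directory:

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ChatterboxMultilingualTTS.from_checkpoint("finetune_out/", device)  # trailing "/" required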
@@ -249,11 +299,13 @@ class ChatterboxMultilingualTTS:
                 f"Unsupported language_id '{language_id}'. "
                 f"Supported languages: {supported_langs}"
             )
-
+
         if audio_prompt_path:
             self.prepare_conditionals(audio_prompt_path, exaggeration=exaggeration)
         else:
-            assert self.conds is not None, "Please `prepare_conditionals` first or specify `audio_prompt_path`"
+            assert (
+                self.conds is not None
+            ), "Please `prepare_conditionals` first or specify `audio_prompt_path`"
 
         # Update exaggeration if needed
         if float(exaggeration) != float(self.conds.t3.emotion_adv[0, 0, 0].item()):
@@ -266,8 +318,12 @@ class ChatterboxMultilingualTTS:
 
         # Norm and tokenize text
         text = punc_norm(text)
-        text_tokens = self.tokenizer.text_to_tokens(text, language_id=language_id.lower() if language_id else None).to(self.device)
-        text_tokens = torch.cat([text_tokens, text_tokens], dim=0)  # Need two seqs for CFG
+        text_tokens = self.tokenizer.text_to_tokens(
+            text, language_id=language_id.lower() if language_id else None
+        ).to(self.device)
+        text_tokens = torch.cat(
+            [text_tokens, text_tokens], dim=0
+        )  # Need two seqs for CFG
 
         sot = self.t3.hp.start_text_token
         eot = self.t3.hp.stop_text_token
@@ -297,5 +353,5 @@
             ref_dict=self.conds.gen,
         )
         wav = wav.squeeze(0).detach().cpu().numpy()
-        watermarked_wav = self.watermarker.apply_watermark(wav, sample_rate=self.sr)
-        return torch.from_numpy(watermarked_wav).unsqueeze(0)
+        # wav = self.watermarker.apply_watermark(wav, sample_rate=self.sr)
+        return torch.from_numpy(wav).unsqueeze(0)
 
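Note: with the watermark call commented out, generate now returns the raw synthesized waveform as a (1, num_samples) tensor. A usage sketch built only from arguments visible in this diff (text is assumed to be the first positional argument; the prompt path is illustrative):

    wav = model.generate(
        "Bonjour tout le monde",
        language_id="fr",
        audio_prompt_path="reference.wav",
        exaggeration=0.5,
    )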