rahul7star committed on
Commit
8f0bff6
·
verified ·
1 Parent(s): b073708

Update src/chatterbox/mtl_tts.py

Browse files
Files changed (1) hide show
  1. src/chatterbox/mtl_tts.py +53 -107
src/chatterbox/mtl_tts.py CHANGED
@@ -1,6 +1,10 @@
1
  from dataclasses import dataclass
2
  from pathlib import Path
3
  import os
 
 
 
 
4
 
5
  import librosa
6
  import torch
@@ -190,112 +194,54 @@ class ChatterboxMultilingualTTS:
190
 
191
  return cls(t3, s3gen, ve, tokenizer, device, conds=conds)
192
 
 
 
 
193
  @classmethod
194
- def from_pretrained(cls, device: torch.device) -> 'ChatterboxMultilingualTTS':
195
- ckpt_dir = Path(
196
- snapshot_download(
197
- repo_id=REPO_ID,
198
- repo_type="model",
199
- revision="main",
200
- allow_patterns=["ve.pt", "t3_mtl23ls_v2.safetensors", "s3gen.pt", "grapheme_mtl_merged_expanded_v1.json", "conds.pt", "Cangjie5_TC.json"],
201
- token=os.getenv("HF_TOKEN"),
202
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
  )
204
- return cls.from_local(ckpt_dir, device)
205
-
206
- def prepare_conditionals(self, wav_fpath, exaggeration=0.5):
207
- ## Load reference wav
208
- s3gen_ref_wav, _sr = librosa.load(wav_fpath, sr=S3GEN_SR)
209
-
210
- ref_16k_wav = librosa.resample(s3gen_ref_wav, orig_sr=S3GEN_SR, target_sr=S3_SR)
211
-
212
- s3gen_ref_wav = s3gen_ref_wav[:self.DEC_COND_LEN]
213
- s3gen_ref_dict = self.s3gen.embed_ref(s3gen_ref_wav, S3GEN_SR, device=self.device)
214
-
215
- # Speech cond prompt tokens
216
- t3_cond_prompt_tokens = None
217
- if plen := self.t3.hp.speech_cond_prompt_len:
218
- s3_tokzr = self.s3gen.tokenizer
219
- t3_cond_prompt_tokens, _ = s3_tokzr.forward([ref_16k_wav[:self.ENC_COND_LEN]], max_len=plen)
220
- t3_cond_prompt_tokens = torch.atleast_2d(t3_cond_prompt_tokens).to(self.device)
221
-
222
- # Voice-encoder speaker embedding
223
- ve_embed = torch.from_numpy(self.ve.embeds_from_wavs([ref_16k_wav], sample_rate=S3_SR))
224
- ve_embed = ve_embed.mean(axis=0, keepdim=True).to(self.device)
225
-
226
- t3_cond = T3Cond(
227
- speaker_emb=ve_embed,
228
- cond_prompt_speech_tokens=t3_cond_prompt_tokens,
229
- emotion_adv=exaggeration * torch.ones(1, 1, 1),
230
- ).to(device=self.device)
231
- self.conds = Conditionals(t3_cond, s3gen_ref_dict)
232
-
233
- def generate(
234
- self,
235
- text,
236
- language_id,
237
- audio_prompt_path=None,
238
- exaggeration=0.5,
239
- cfg_weight=0.5,
240
- temperature=0.8,
241
- repetition_penalty=2.0,
242
- min_p=0.05,
243
- top_p=1.0,
244
- ):
245
- # Validate language_id
246
- if language_id and language_id.lower() not in SUPPORTED_LANGUAGES:
247
- supported_langs = ", ".join(SUPPORTED_LANGUAGES.keys())
248
- raise ValueError(
249
- f"Unsupported language_id '{language_id}'. "
250
- f"Supported languages: {supported_langs}"
251
- )
252
-
253
- if audio_prompt_path:
254
- self.prepare_conditionals(audio_prompt_path, exaggeration=exaggeration)
255
- else:
256
- assert self.conds is not None, "Please `prepare_conditionals` first or specify `audio_prompt_path`"
257
-
258
- # Update exaggeration if needed
259
- if float(exaggeration) != float(self.conds.t3.emotion_adv[0, 0, 0].item()):
260
- _cond: T3Cond = self.conds.t3
261
- self.conds.t3 = T3Cond(
262
- speaker_emb=_cond.speaker_emb,
263
- cond_prompt_speech_tokens=_cond.cond_prompt_speech_tokens,
264
- emotion_adv=exaggeration * torch.ones(1, 1, 1),
265
- ).to(device=self.device)
266
-
267
- # Norm and tokenize text
268
- text = punc_norm(text)
269
- text_tokens = self.tokenizer.text_to_tokens(text, language_id=language_id.lower() if language_id else None).to(self.device)
270
- text_tokens = torch.cat([text_tokens, text_tokens], dim=0) # Need two seqs for CFG
271
-
272
- sot = self.t3.hp.start_text_token
273
- eot = self.t3.hp.stop_text_token
274
- text_tokens = F.pad(text_tokens, (1, 0), value=sot)
275
- text_tokens = F.pad(text_tokens, (0, 1), value=eot)
276
-
277
- with torch.inference_mode():
278
- speech_tokens = self.t3.inference(
279
- t3_cond=self.conds.t3,
280
- text_tokens=text_tokens,
281
- max_new_tokens=1000, # TODO: use the value in config
282
- temperature=temperature,
283
- cfg_weight=cfg_weight,
284
- repetition_penalty=repetition_penalty,
285
- min_p=min_p,
286
- top_p=top_p,
287
- )
288
- # Extract only the conditional batch.
289
- speech_tokens = speech_tokens[0]
290
-
291
- # TODO: output becomes 1D
292
- speech_tokens = drop_invalid_tokens(speech_tokens)
293
- speech_tokens = speech_tokens.to(self.device)
294
-
295
- wav, _ = self.s3gen.inference(
296
- speech_tokens=speech_tokens,
297
- ref_dict=self.conds.gen,
298
- )
299
- wav = wav.squeeze(0).detach().cpu().numpy()
300
- watermarked_wav = self.watermarker.apply_watermark(wav, sample_rate=self.sr)
301
- return torch.from_numpy(watermarked_wav).unsqueeze(0)
 
1
  from dataclasses import dataclass
2
  from pathlib import Path
3
  import os
4
+ from pathlib import Path
5
+ import torch
6
+ import os
7
+ from huggingface_hub import snapshot_download
8
 
9
  import librosa
10
  import torch
 
194
 
195
  return cls(t3, s3gen, ve, tokenizer, device, conds=conds)
196
 
197
+
198
+
199
+
200
  @classmethod
201
+ def from_pretrained(
202
+ cls,
203
+ device: str | torch.device | None = None,
204
+ ) -> "ChatterboxMultilingualTTS":
205
+ """
206
+ Load ChatterboxMultilingualTTS safely.
207
+ Defaults to CPU and never forces CUDA.
208
+ """
209
+
210
+ # 🔒 Normalize + force CPU
211
+ if device is None:
212
+ device = torch.device("cpu")
213
+ elif isinstance(device, str):
214
+ device = torch.device(device)
215
+
216
+ # Absolute safety: never allow CUDA
217
+ if device.type != "cpu":
218
+ device = torch.device("cpu")
219
+
220
+ ckpt_dir = Path(
221
+ snapshot_download(
222
+ repo_id=REPO_ID,
223
+ repo_type="model",
224
+ revision="main",
225
+ allow_patterns=[
226
+ "ve.pt",
227
+ "t3_mtl23ls_v2.safetensors",
228
+ "s3gen.pt",
229
+ "grapheme_mtl_merged_expanded_v1.json",
230
+ "conds.pt",
231
+ "Cangjie5_TC.json",
232
+ ],
233
+ token=os.getenv("HF_TOKEN"),
234
  )
235
+ )
236
+
237
+ model = cls.from_local(ckpt_dir, device)
238
+
239
+ # Extra safety: force model tensors to CPU
240
+ if hasattr(model, "to"):
241
+ model = model.to("cpu")
242
+
243
+ model.eval()
244
+ for p in model.parameters():
245
+ p.requires_grad = False
246
+
247
+ return model