rahul7star committed on
Commit
c5d087d
·
verified ·
1 Parent(s): 036ffff

Update src/chatterbox/mtl_tts.py

Browse files
Files changed (1) hide show
  1. src/chatterbox/mtl_tts.py +76 -135
src/chatterbox/mtl_tts.py CHANGED
@@ -1,15 +1,11 @@
1
  from dataclasses import dataclass
2
  from pathlib import Path
3
  import os
4
- from pathlib import Path
5
  import torch
6
- import os
7
- from huggingface_hub import snapshot_download
8
-
9
  import librosa
10
- import torch
11
  import perth
12
  import torch.nn.functional as F
 
13
  from safetensors.torch import load_file as load_safetensors
14
  from huggingface_hub import snapshot_download
15
 
@@ -24,51 +20,42 @@ from .models.t3.modules.cond_enc import T3Cond
24
 
25
  REPO_ID = "ResembleAI/chatterbox"
26
 
27
- # Supported languages for the multilingual model
28
  SUPPORTED_LANGUAGES = {
29
- "ar": "Arabic",
30
- "da": "Danish",
31
- "de": "German",
32
- "el": "Greek",
33
- "en": "English",
34
- "es": "Spanish",
35
- "fi": "Finnish",
36
- "fr": "French",
37
- "he": "Hebrew",
38
- "hi": "Hindi",
39
- "it": "Italian",
40
- "ja": "Japanese",
41
- "ko": "Korean",
42
- "ms": "Malay",
43
- "nl": "Dutch",
44
- "no": "Norwegian",
45
- "pl": "Polish",
46
- "pt": "Portuguese",
47
- "ru": "Russian",
48
- "sv": "Swedish",
49
- "sw": "Swahili",
50
- "tr": "Turkish",
51
- "zh": "Chinese",
52
  }
53
 
54
 
55
  def punc_norm(text: str) -> str:
56
- """
57
- Quick cleanup func for punctuation from LLMs or
58
- containing chars not seen often in the dataset
59
- """
60
- if len(text) == 0:
61
  return "You need to add some text for me to talk."
62
 
63
- # Capitalise first letter
64
- if text[0].islower():
65
- text = text[0].upper() + text[1:]
66
-
67
- # Remove multiple space chars
68
  text = " ".join(text.split())
69
 
70
- # Replace uncommon/llm punc
71
- punc_to_replace = [
72
  ("...", ", "),
73
  ("…", ", "),
74
  (":", ","),
@@ -82,13 +69,10 @@ def punc_norm(text: str) -> str:
82
  ("‘", "'"),
83
  ("’", "'"),
84
  ]
85
- for old_char_sequence, new_char in punc_to_replace:
86
- text = text.replace(old_char_sequence, new_char)
87
 
88
- # Add full stop if no ending punc
89
- text = text.rstrip(" ")
90
- sentence_enders = {".", "!", "?", "-", ",","、",",","。","?","!"}
91
- if not any(text.endswith(p) for p in sentence_enders):
92
  text += "."
93
 
94
  return text
@@ -96,58 +80,31 @@ def punc_norm(text: str) -> str:
96
 
97
  @dataclass
98
  class Conditionals:
99
- """
100
- Conditionals for T3 and S3Gen
101
- - T3 conditionals:
102
- - speaker_emb
103
- - clap_emb
104
- - cond_prompt_speech_tokens
105
- - cond_prompt_speech_emb
106
- - emotion_adv
107
- - S3Gen conditionals:
108
- - prompt_token
109
- - prompt_token_len
110
- - prompt_feat
111
- - prompt_feat_len
112
- - embedding
113
- """
114
  t3: T3Cond
115
  gen: dict
116
 
117
  def to(self, device):
118
- self.t3 = self.t3.to(device=device)
119
  for k, v in self.gen.items():
120
  if torch.is_tensor(v):
121
- self.gen[k] = v.to(device=device)
122
  return self
123
 
124
- def save(self, fpath: Path):
125
- arg_dict = dict(
126
- t3=self.t3.__dict__,
127
- gen=self.gen
128
- )
129
- torch.save(arg_dict, fpath)
130
 
131
  @classmethod
132
- def load(cls, fpath, map_location="cpu"):
133
- kwargs = torch.load(fpath, map_location="cpu", weights_only=True)
134
- return cls(T3Cond(**kwargs['t3']), kwargs['gen'])
135
 
136
 
137
  class ChatterboxMultilingualTTS:
138
  ENC_COND_LEN = 6 * S3_SR
139
  DEC_COND_LEN = 10 * S3GEN_SR
140
 
141
- def __init__(
142
- self,
143
- t3: T3,
144
- s3gen: S3Gen,
145
- ve: VoiceEncoder,
146
- tokenizer: MTLTokenizer,
147
- device: str,
148
- conds: Conditionals = None,
149
- ):
150
- self.sr = S3GEN_SR # sample rate of synthesized audio
151
  self.t3 = t3
152
  self.s3gen = s3gen
153
  self.ve = ve
@@ -156,72 +113,61 @@ class ChatterboxMultilingualTTS:
156
  self.conds = conds
157
  self.watermarker = perth.PerthImplicitWatermarker()
158
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  @classmethod
160
  def get_supported_languages(cls):
161
- """Return dictionary of supported language codes and names."""
162
  return SUPPORTED_LANGUAGES.copy()
163
 
164
  @classmethod
165
- def from_local(cls, ckpt_dir, device) -> 'ChatterboxMultilingualTTS':
166
  ckpt_dir = Path(ckpt_dir)
167
 
168
  ve = VoiceEncoder()
169
- ve.load_state_dict(
170
- torch.load(ckpt_dir / "ve.pt", weights_only=True)
171
- )
172
  ve.to(device).eval()
173
 
174
  t3 = T3(T3Config.multilingual())
175
  t3_state = load_safetensors(ckpt_dir / "t3_mtl23ls_v2.safetensors")
176
- if "model" in t3_state.keys():
177
  t3_state = t3_state["model"][0]
178
  t3.load_state_dict(t3_state)
179
  t3.to(device).eval()
180
 
181
  s3gen = S3Gen()
182
- s3gen.load_state_dict(
183
- torch.load(ckpt_dir / "s3gen.pt", weights_only=True)
184
- )
185
  s3gen.to(device).eval()
186
 
187
- tokenizer = MTLTokenizer(
188
- str(ckpt_dir / "grapheme_mtl_merged_expanded_v1.json")
189
- )
190
 
191
  conds = None
192
- if (builtin_voice := ckpt_dir / "conds.pt").exists():
193
- conds = Conditionals.load(builtin_voice).to(device)
194
-
195
- return cls(t3, s3gen, ve, tokenizer, device, conds=conds)
196
 
197
-
198
-
199
 
200
  @classmethod
201
- def from_pretrained(
202
- cls,
203
- device: str | torch.device | None = None,
204
- ) -> "ChatterboxMultilingualTTS":
205
- """
206
- Load ChatterboxMultilingualTTS safely.
207
- Defaults to CPU and never forces CUDA.
208
- """
209
-
210
- # 🔒 Normalize + force CPU
211
- if device is None:
212
- device = torch.device("cpu")
213
- elif isinstance(device, str):
214
- device = torch.device(device)
215
-
216
- # Absolute safety: never allow CUDA
217
- if device.type != "cpu":
218
- device = torch.device("cpu")
219
-
220
- ckpt_dir = Path(
221
- snapshot_download(
222
  repo_id=REPO_ID,
223
- repo_type="model",
224
- revision="main",
225
  allow_patterns=[
226
  "ve.pt",
227
  "t3_mtl23ls_v2.safetensors",
@@ -231,17 +177,12 @@ class ChatterboxMultilingualTTS:
231
  "Cangjie5_TC.json",
232
  ],
233
  token=os.getenv("HF_TOKEN"),
234
- )
235
- )
236
-
237
- model = cls.from_local(ckpt_dir, device)
238
 
239
- # Extra safety: force model tensors to CPU
240
- if hasattr(model, "to"):
241
- model = model.to("cpu")
242
 
243
- model.eval()
244
- for p in model.parameters():
245
- p.requires_grad = False
246
 
247
- return model
 
1
  from dataclasses import dataclass
2
  from pathlib import Path
3
  import os
 
4
  import torch
 
 
 
5
  import librosa
 
6
  import perth
7
  import torch.nn.functional as F
8
+
9
  from safetensors.torch import load_file as load_safetensors
10
  from huggingface_hub import snapshot_download
11
 
 
20
 
21
# Hugging Face Hub repository the pretrained checkpoints are fetched from.
REPO_ID = "ResembleAI/chatterbox"

# Language codes the multilingual model accepts, mapped to display names.
SUPPORTED_LANGUAGES = dict(
    ar="Arabic",
    da="Danish",
    de="German",
    el="Greek",
    en="English",
    es="Spanish",
    fi="Finnish",
    fr="French",
    he="Hebrew",
    hi="Hindi",
    it="Italian",
    ja="Japanese",
    ko="Korean",
    ms="Malay",
    nl="Dutch",
    no="Norwegian",
    pl="Polish",
    pt="Portuguese",
    ru="Russian",
    sv="Swedish",
    sw="Swahili",
    tr="Turkish",
    zh="Chinese",
)
48
 
49
 
50
  def punc_norm(text: str) -> str:
51
+ if not text:
 
 
 
 
52
  return "You need to add some text for me to talk."
53
 
54
+ text = text.strip()
55
+ text = text[0].upper() + text[1:] if text[0].islower() else text
 
 
 
56
  text = " ".join(text.split())
57
 
58
+ replacements = [
 
59
  ("...", ", "),
60
  ("…", ", "),
61
  (":", ","),
 
69
  ("‘", "'"),
70
  ("’", "'"),
71
  ]
72
+ for a, b in replacements:
73
+ text = text.replace(a, b)
74
 
75
+ if not text.endswith((".", "!", "?", ",", "-", "。", "?", "!")):
 
 
 
76
  text += "."
77
 
78
  return text
 
80
 
81
@dataclass
class Conditionals:
    """Conditioning inputs shared by the T3 and S3Gen models.

    Attributes:
        t3: conditioning object consumed by the T3 model.
        gen: mapping of conditioning names to values; tensor values are
            moved along with the object in ``to``.
    """

    t3: T3Cond
    gen: dict

    def to(self, device):
        """Move all held tensors to ``device``; returns self for chaining."""
        # Mutate the gen dict in place so external references stay valid.
        for key, val in self.gen.items():
            if torch.is_tensor(val):
                self.gen[key] = val.to(device)
        self.t3 = self.t3.to(device)
        return self

    def save(self, path: Path):
        """Serialize both conditioning bundles to ``path`` via torch.save."""
        payload = {"t3": self.t3.__dict__, "gen": self.gen}
        torch.save(payload, path)

    @classmethod
    def load(cls, path, map_location="cpu"):
        """Deserialize a Conditionals previously written by ``save``."""
        # weights_only=True restricts unpickling to tensor/containers — safe
        # for checkpoints from untrusted sources.
        payload = torch.load(path, map_location=map_location, weights_only=True)
        return cls(T3Cond(**payload["t3"]), payload["gen"])
100
 
101
 
102
  class ChatterboxMultilingualTTS:
103
  ENC_COND_LEN = 6 * S3_SR
104
  DEC_COND_LEN = 10 * S3GEN_SR
105
 
106
+ def __init__(self, t3, s3gen, ve, tokenizer, device, conds=None):
107
+ self.sr = S3GEN_SR
 
 
 
 
 
 
 
 
108
  self.t3 = t3
109
  self.s3gen = s3gen
110
  self.ve = ve
 
113
  self.conds = conds
114
  self.watermarker = perth.PerthImplicitWatermarker()
115
 
116
+ # ✅ Forward torch behavior
117
+ def eval(self):
118
+ for m in (self.t3, self.s3gen, self.ve):
119
+ m.eval()
120
+ return self
121
+
122
+ def to(self, device):
123
+ self.device = device
124
+ for m in (self.t3, self.s3gen, self.ve):
125
+ m.to(device)
126
+ if self.conds:
127
+ self.conds.to(device)
128
+ return self
129
+
130
+ def parameters(self):
131
+ for m in (self.t3, self.s3gen, self.ve):
132
+ yield from m.parameters()
133
+
134
  @classmethod
135
  def get_supported_languages(cls):
 
136
  return SUPPORTED_LANGUAGES.copy()
137
 
138
  @classmethod
139
+ def from_local(cls, ckpt_dir, device):
140
  ckpt_dir = Path(ckpt_dir)
141
 
142
  ve = VoiceEncoder()
143
+ ve.load_state_dict(torch.load(ckpt_dir / "ve.pt", map_location="cpu", weights_only=True))
 
 
144
  ve.to(device).eval()
145
 
146
  t3 = T3(T3Config.multilingual())
147
  t3_state = load_safetensors(ckpt_dir / "t3_mtl23ls_v2.safetensors")
148
+ if "model" in t3_state:
149
  t3_state = t3_state["model"][0]
150
  t3.load_state_dict(t3_state)
151
  t3.to(device).eval()
152
 
153
  s3gen = S3Gen()
154
+ s3gen.load_state_dict(torch.load(ckpt_dir / "s3gen.pt", map_location="cpu", weights_only=True))
 
 
155
  s3gen.to(device).eval()
156
 
157
+ tokenizer = MTLTokenizer(str(ckpt_dir / "grapheme_mtl_merged_expanded_v1.json"))
 
 
158
 
159
  conds = None
160
+ if (ckpt_dir / "conds.pt").exists():
161
+ conds = Conditionals.load(ckpt_dir / "conds.pt").to(device)
 
 
162
 
163
+ return cls(t3, s3gen, ve, tokenizer, device, conds)
 
164
 
165
  @classmethod
166
+ def from_pretrained(cls, device=None):
167
+ device = torch.device("cpu")
168
+
169
+ ckpt_dir = snapshot_download(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
  repo_id=REPO_ID,
 
 
171
  allow_patterns=[
172
  "ve.pt",
173
  "t3_mtl23ls_v2.safetensors",
 
177
  "Cangjie5_TC.json",
178
  ],
179
  token=os.getenv("HF_TOKEN"),
180
+ )
 
 
 
181
 
182
+ model = cls.from_local(ckpt_dir, device)
183
+ model.eval()
 
184
 
185
+ for p in model.parameters():
186
+ p.requires_grad_(False)
 
187
 
188
+ return model