Zihan428 commited on
Commit
360cbde
·
1 Parent(s): 62b26ca

v2 update: multilingual improvements

Browse files
requirements.txt CHANGED
@@ -13,6 +13,7 @@ safetensors
13
 
14
  # Optional language-specific dependencies
15
  # Uncomment the ones you need for specific languages:
16
- pkuseg # For Chinese text segmentation (improves mixed text handling)
17
  pykakasi>=2.2.0 # For Japanese text processing (Kanji to Hiragana)
 
18
  # dicta-onnx>=0.1.0 # For Hebrew diacritization
 
13
 
14
  # Optional language-specific dependencies
15
  # Uncomment the ones you need for specific languages:
16
+ spacy_pkuseg # For Chinese text segmentation
17
  pykakasi>=2.2.0 # For Japanese text processing (Kanji to Hiragana)
18
+ russian-text-stresser # For Russian stress labeling
19
  # dicta-onnx>=0.1.0 # For Hebrew diacritization
src/chatterbox/models/t3/modules/t3_config.py CHANGED
@@ -28,7 +28,7 @@ class T3Config:
28
 
29
  @property
30
  def is_multilingual(self):
31
- return self.text_tokens_dict_size == 2352
32
 
33
  @classmethod
34
  def english_only(cls):
@@ -38,4 +38,4 @@ class T3Config:
38
  @classmethod
39
  def multilingual(cls):
40
  """Create configuration for multilingual TTS model."""
41
- return cls(text_tokens_dict_size=2352)
 
28
 
29
  @property
30
  def is_multilingual(self):
31
+ return self.text_tokens_dict_size == 2454
32
 
33
  @classmethod
34
  def english_only(cls):
 
38
  @classmethod
39
  def multilingual(cls):
40
  """Create configuration for multilingual TTS model."""
41
+ return cls(text_tokens_dict_size=2454)
src/chatterbox/models/tokenizers/tokenizer.py CHANGED
@@ -191,7 +191,7 @@ class ChineseCangjieConverter:
191
  def _init_segmenter(self):
192
  """Initialize pkuseg segmenter."""
193
  try:
194
- from pkuseg import pkuseg
195
  self.segmenter = pkuseg()
196
  except ImportError:
197
  logger.warning("pkuseg not available - Chinese segmentation will be skipped")
@@ -235,11 +235,53 @@ class ChineseCangjieConverter:
235
  return "".join(output)
236
 
237
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
  class MTLTokenizer:
239
  def __init__(self, vocab_file_path):
240
  self.tokenizer: Tokenizer = Tokenizer.from_file(vocab_file_path)
241
  model_dir = Path(vocab_file_path).parent
242
  self.cangjie_converter = ChineseCangjieConverter(model_dir)
 
243
  self.check_vocabset_sot_eot()
244
 
245
  def check_vocabset_sot_eot(self):
@@ -262,6 +304,8 @@ class MTLTokenizer:
262
  txt = add_hebrew_diacritics(txt)
263
  elif language_id == 'ko':
264
  txt = korean_normalize(txt)
 
 
265
 
266
  # Prepend language token
267
  if language_id:
 
191
  def _init_segmenter(self):
192
  """Initialize pkuseg segmenter."""
193
  try:
194
+ from spacy_pkuseg import pkuseg
195
  self.segmenter = pkuseg()
196
  except ImportError:
197
  logger.warning("pkuseg not available - Chinese segmentation will be skipped")
 
235
  return "".join(output)
236
 
237
 
238
+ class RussianStressLabeler:
239
+ """Adds stress marks to Russian text when the optional dependency is available."""
240
+
241
+ def __init__(self):
242
+ self._stresser = None
243
+ self._available = False
244
+ self._error_logged = False
245
+ self._initialize()
246
+
247
+ def _initialize(self):
248
+ try:
249
+ from russian_text_stresser.text_stresser import RussianTextStresser
250
+ except ImportError:
251
+ logger.warning("russian_text_stresser not available - Russian stress labeling skipped")
252
+ self._error_logged = True
253
+ return
254
+ except Exception as exc:
255
+ logger.warning(f"Failed to import RussianTextStresser: {exc}")
256
+ self._error_logged = True
257
+ return
258
+
259
+ try:
260
+ self._stresser = RussianTextStresser()
261
+ self._available = True
262
+ except Exception as exc:
263
+ logger.warning(f"Failed to initialize RussianTextStresser: {exc}")
264
+ self._error_logged = True
265
+
266
+ def __call__(self, text: str) -> str:
267
+ if not text or not self._available:
268
+ return text
269
+
270
+ try:
271
+ return self._stresser.stress_text(text)
272
+ except Exception as exc:
273
+ if not self._error_logged:
274
+ logger.warning(f"Russian stress labeling failed: {exc}")
275
+ self._error_logged = True
276
+ return text
277
+
278
+
279
  class MTLTokenizer:
280
  def __init__(self, vocab_file_path):
281
  self.tokenizer: Tokenizer = Tokenizer.from_file(vocab_file_path)
282
  model_dir = Path(vocab_file_path).parent
283
  self.cangjie_converter = ChineseCangjieConverter(model_dir)
284
+ self.russian_stress_labeler = RussianStressLabeler()
285
  self.check_vocabset_sot_eot()
286
 
287
  def check_vocabset_sot_eot(self):
 
304
  txt = add_hebrew_diacritics(txt)
305
  elif language_id == 'ko':
306
  txt = korean_normalize(txt)
307
+ elif language_id == 'ru':
308
+ txt = self.russian_stress_labeler(txt)
309
 
310
  # Prepend language token
311
  if language_id:
src/chatterbox/mtl_tts.py CHANGED
@@ -168,7 +168,7 @@ class ChatterboxMultilingualTTS:
168
  ve.to(device).eval()
169
 
170
  t3 = T3(T3Config.multilingual())
171
- t3_state = load_safetensors(ckpt_dir / "t3_23lang.safetensors")
172
  if "model" in t3_state.keys():
173
  t3_state = t3_state["model"][0]
174
  t3.load_state_dict(t3_state)
@@ -181,7 +181,7 @@ class ChatterboxMultilingualTTS:
181
  s3gen.to(device).eval()
182
 
183
  tokenizer = MTLTokenizer(
184
- str(ckpt_dir / "mtl_tokenizer.json")
185
  )
186
 
187
  conds = None
@@ -197,7 +197,7 @@ class ChatterboxMultilingualTTS:
197
  repo_id=REPO_ID,
198
  repo_type="model",
199
  revision="main",
200
- allow_patterns=["ve.pt", "t3_23lang.safetensors", "s3gen.pt", "mtl_tokenizer.json", "conds.pt", "Cangjie5_TC.json"],
201
  token=os.getenv("HF_TOKEN"),
202
  )
203
  )
 
168
  ve.to(device).eval()
169
 
170
  t3 = T3(T3Config.multilingual())
171
+ t3_state = load_safetensors(ckpt_dir / "t3_mtl23ls_v2.safetensors")
172
  if "model" in t3_state.keys():
173
  t3_state = t3_state["model"][0]
174
  t3.load_state_dict(t3_state)
 
181
  s3gen.to(device).eval()
182
 
183
  tokenizer = MTLTokenizer(
184
+ str(ckpt_dir / "grapheme_mtl_merged_expanded_v1.json")
185
  )
186
 
187
  conds = None
 
197
  repo_id=REPO_ID,
198
  repo_type="model",
199
  revision="main",
200
+ allow_patterns=["ve.pt", "t3_mtl23ls_v2.safetensors", "s3gen.pt", "grapheme_mtl_merged_expanded_v1.json", "conds.pt", "Cangjie5_TC.json"],
201
  token=os.getenv("HF_TOKEN"),
202
  )
203
  )