rahul7star committed on
Commit
ef3441d
·
verified ·
1 Parent(s): c5d087d

Update src/chatterbox/mtl_tts.py

Browse files
Files changed (1) hide show
  1. src/chatterbox/mtl_tts.py +69 -86
src/chatterbox/mtl_tts.py CHANGED
@@ -2,12 +2,12 @@ from dataclasses import dataclass
2
  from pathlib import Path
3
  import os
4
  import torch
 
 
5
  import librosa
6
  import perth
7
  import torch.nn.functional as F
8
-
9
  from safetensors.torch import load_file as load_safetensors
10
- from huggingface_hub import snapshot_download
11
 
12
  from .models.t3 import T3
13
  from .models.t3.modules.t3_config import T3Config
@@ -17,64 +17,34 @@ from .models.tokenizers import MTLTokenizer
17
  from .models.voice_encoder import VoiceEncoder
18
  from .models.t3.modules.cond_enc import T3Cond
19
 
20
-
21
  REPO_ID = "ResembleAI/chatterbox"
22
 
 
23
  SUPPORTED_LANGUAGES = {
24
- "ar": "Arabic",
25
- "da": "Danish",
26
- "de": "German",
27
- "el": "Greek",
28
- "en": "English",
29
- "es": "Spanish",
30
- "fi": "Finnish",
31
- "fr": "French",
32
- "he": "Hebrew",
33
- "hi": "Hindi",
34
- "it": "Italian",
35
- "ja": "Japanese",
36
- "ko": "Korean",
37
- "ms": "Malay",
38
- "nl": "Dutch",
39
- "no": "Norwegian",
40
- "pl": "Polish",
41
- "pt": "Portuguese",
42
- "ru": "Russian",
43
- "sv": "Swedish",
44
- "sw": "Swahili",
45
- "tr": "Turkish",
46
- "zh": "Chinese",
47
  }
48
 
49
 
50
  def punc_norm(text: str) -> str:
 
51
  if not text:
52
  return "You need to add some text for me to talk."
53
-
54
- text = text.strip()
55
- text = text[0].upper() + text[1:] if text[0].islower() else text
56
  text = " ".join(text.split())
57
-
58
  replacements = [
59
- ("...", ", "),
60
- ("", ", "),
61
- (":", ","),
62
- (" - ", ", "),
63
- (";", ", "),
64
- ("—", "-"),
65
- ("–", "-"),
66
- (" ,", ","),
67
- ("“", "\""),
68
- ("”", "\""),
69
- ("‘", "'"),
70
- ("’", "'"),
71
  ]
72
- for a, b in replacements:
73
- text = text.replace(a, b)
74
-
75
- if not text.endswith((".", "!", "?", ",", "-", "。", "?", "!")):
76
  text += "."
77
-
78
  return text
79
 
80
 
@@ -84,18 +54,18 @@ class Conditionals:
84
  gen: dict
85
 
86
  def to(self, device):
87
- self.t3 = self.t3.to(device)
88
  for k, v in self.gen.items():
89
  if torch.is_tensor(v):
90
  self.gen[k] = v.to(device)
91
  return self
92
 
93
- def save(self, path: Path):
94
- torch.save({"t3": self.t3.__dict__, "gen": self.gen}, path)
95
 
96
  @classmethod
97
- def load(cls, path, map_location="cpu"):
98
- data = torch.load(path, map_location=map_location, weights_only=True)
99
  return cls(T3Cond(**data["t3"]), data["gen"])
100
 
101
 
@@ -103,7 +73,15 @@ class ChatterboxMultilingualTTS:
103
  ENC_COND_LEN = 6 * S3_SR
104
  DEC_COND_LEN = 10 * S3GEN_SR
105
 
106
- def __init__(self, t3, s3gen, ve, tokenizer, device, conds=None):
 
 
 
 
 
 
 
 
107
  self.sr = S3GEN_SR
108
  self.t3 = t3
109
  self.s3gen = s3gen
@@ -113,61 +91,62 @@ class ChatterboxMultilingualTTS:
113
  self.conds = conds
114
  self.watermarker = perth.PerthImplicitWatermarker()
115
 
116
- # Forward torch behavior
117
- def eval(self):
118
- for m in (self.t3, self.s3gen, self.ve):
119
- m.eval()
120
- return self
121
-
122
- def to(self, device):
123
- self.device = device
124
- for m in (self.t3, self.s3gen, self.ve):
125
- m.to(device)
126
- if self.conds:
127
- self.conds.to(device)
128
- return self
129
-
130
- def parameters(self):
131
- for m in (self.t3, self.s3gen, self.ve):
132
- yield from m.parameters()
133
 
134
  @classmethod
135
  def get_supported_languages(cls):
136
  return SUPPORTED_LANGUAGES.copy()
137
 
138
  @classmethod
139
- def from_local(cls, ckpt_dir, device):
140
  ckpt_dir = Path(ckpt_dir)
141
 
 
142
  ve = VoiceEncoder()
143
  ve.load_state_dict(torch.load(ckpt_dir / "ve.pt", map_location="cpu", weights_only=True))
144
- ve.to(device).eval()
145
 
 
146
  t3 = T3(T3Config.multilingual())
147
  t3_state = load_safetensors(ckpt_dir / "t3_mtl23ls_v2.safetensors")
148
  if "model" in t3_state:
149
  t3_state = t3_state["model"][0]
150
  t3.load_state_dict(t3_state)
151
- t3.to(device).eval()
152
 
 
153
  s3gen = S3Gen()
154
  s3gen.load_state_dict(torch.load(ckpt_dir / "s3gen.pt", map_location="cpu", weights_only=True))
155
- s3gen.to(device).eval()
156
 
 
157
  tokenizer = MTLTokenizer(str(ckpt_dir / "grapheme_mtl_merged_expanded_v1.json"))
158
 
 
159
  conds = None
160
  if (ckpt_dir / "conds.pt").exists():
161
  conds = Conditionals.load(ckpt_dir / "conds.pt").to(device)
162
 
163
- return cls(t3, s3gen, ve, tokenizer, device, conds)
164
 
165
  @classmethod
166
- def from_pretrained(cls, device=None):
167
- device = torch.device("cpu")
168
-
169
- ckpt_dir = snapshot_download(
 
 
 
 
 
 
 
 
170
  repo_id=REPO_ID,
 
 
171
  allow_patterns=[
172
  "ve.pt",
173
  "t3_mtl23ls_v2.safetensors",
@@ -176,13 +155,17 @@ class ChatterboxMultilingualTTS:
176
  "conds.pt",
177
  "Cangjie5_TC.json",
178
  ],
179
- token=os.getenv("HF_TOKEN"),
180
- )
181
 
182
  model = cls.from_local(ckpt_dir, device)
183
- model.eval()
184
-
185
- for p in model.parameters():
186
- p.requires_grad_(False)
187
-
188
  return model
 
 
 
 
 
 
 
 
 
 
2
  from pathlib import Path
3
  import os
4
  import torch
5
+ from huggingface_hub import snapshot_download
6
+
7
  import librosa
8
  import perth
9
  import torch.nn.functional as F
 
10
  from safetensors.torch import load_file as load_safetensors
 
11
 
12
  from .models.t3 import T3
13
  from .models.t3.modules.t3_config import T3Config
 
17
  from .models.voice_encoder import VoiceEncoder
18
  from .models.t3.modules.cond_enc import T3Cond
19
 
 
20
REPO_ID = "ResembleAI/chatterbox"

# Supported languages for the multilingual model (ISO 639-1 code -> English name).
SUPPORTED_LANGUAGES = {
    "ar": "Arabic",
    "da": "Danish",
    "de": "German",
    "el": "Greek",
    "en": "English",
    "es": "Spanish",
    "fi": "Finnish",
    "fr": "French",
    "he": "Hebrew",
    "hi": "Hindi",
    "it": "Italian",
    "ja": "Japanese",
    "ko": "Korean",
    "ms": "Malay",
    "nl": "Dutch",
    "no": "Norwegian",
    "pl": "Polish",
    "pt": "Portuguese",
    "ru": "Russian",
    "sv": "Swedish",
    "sw": "Swahili",
    "tr": "Turkish",
    "zh": "Chinese",
}
31
 
32
 
33
  def punc_norm(text: str) -> str:
34
+ """Normalize punctuation for TTS text."""
35
  if not text:
36
  return "You need to add some text for me to talk."
37
+ text = text[0].upper() + text[1:]
 
 
38
  text = " ".join(text.split())
 
39
  replacements = [
40
+ ("...", ", "), ("…", ", "), (":", ","), (" - ", ","),
41
+ (";", ","), ("—", "-"), ("–", "-"), (" ,", ","),
42
+ ("", "\""), ("”", "\""), ("‘", "'"), ("’", "'")
 
 
 
 
 
 
 
 
 
43
  ]
44
+ for old, new in replacements:
45
+ text = text.replace(old, new)
46
+ if not any(text.endswith(p) for p in {".", "!", "?", "-", ",","、",",","。","?","!"}):
 
47
  text += "."
 
48
  return text
49
 
50
 
 
54
  gen: dict
55
 
56
  def to(self, device):
57
+ """Move only tensors in `.gen` to device. T3Cond stays as-is."""
58
  for k, v in self.gen.items():
59
  if torch.is_tensor(v):
60
  self.gen[k] = v.to(device)
61
  return self
62
 
63
+ def save(self, fpath: Path):
64
+ torch.save({"t3": self.t3.__dict__, "gen": self.gen}, fpath)
65
 
66
  @classmethod
67
+ def load(cls, fpath: Path, map_location="cpu"):
68
+ data = torch.load(fpath, map_location=map_location, weights_only=True)
69
  return cls(T3Cond(**data["t3"]), data["gen"])
70
 
71
 
 
73
  ENC_COND_LEN = 6 * S3_SR
74
  DEC_COND_LEN = 10 * S3GEN_SR
75
 
76
+ def __init__(
77
+ self,
78
+ t3: T3,
79
+ s3gen: S3Gen,
80
+ ve: VoiceEncoder,
81
+ tokenizer: MTLTokenizer,
82
+ device: str,
83
+ conds: Conditionals = None,
84
+ ):
85
  self.sr = S3GEN_SR
86
  self.t3 = t3
87
  self.s3gen = s3gen
 
91
  self.conds = conds
92
  self.watermarker = perth.PerthImplicitWatermarker()
93
 
94
+ # Disable gradients for safety
95
+ for p in self.parameters():
96
+ p.requires_grad = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
  @classmethod
99
  def get_supported_languages(cls):
100
  return SUPPORTED_LANGUAGES.copy()
101
 
102
  @classmethod
103
+ def from_local(cls, ckpt_dir, device) -> "ChatterboxMultilingualTTS":
104
  ckpt_dir = Path(ckpt_dir)
105
 
106
+ # Voice Encoder
107
  ve = VoiceEncoder()
108
  ve.load_state_dict(torch.load(ckpt_dir / "ve.pt", map_location="cpu", weights_only=True))
109
+ ve.to(device)
110
 
111
+ # T3
112
  t3 = T3(T3Config.multilingual())
113
  t3_state = load_safetensors(ckpt_dir / "t3_mtl23ls_v2.safetensors")
114
  if "model" in t3_state:
115
  t3_state = t3_state["model"][0]
116
  t3.load_state_dict(t3_state)
117
+ t3.to(device)
118
 
119
+ # S3Gen
120
  s3gen = S3Gen()
121
  s3gen.load_state_dict(torch.load(ckpt_dir / "s3gen.pt", map_location="cpu", weights_only=True))
122
+ s3gen.to(device)
123
 
124
+ # Tokenizer
125
  tokenizer = MTLTokenizer(str(ckpt_dir / "grapheme_mtl_merged_expanded_v1.json"))
126
 
127
+ # Conditionals
128
  conds = None
129
  if (ckpt_dir / "conds.pt").exists():
130
  conds = Conditionals.load(ckpt_dir / "conds.pt").to(device)
131
 
132
+ return cls(t3, s3gen, ve, tokenizer, device, conds=conds)
133
 
134
  @classmethod
135
+ def from_pretrained(cls, device: str | torch.device | None = None) -> "ChatterboxMultilingualTTS":
136
+ """Load model fully on CPU, never use CUDA."""
137
+ if device is None:
138
+ device = torch.device("cpu")
139
+ elif isinstance(device, str):
140
+ device = torch.device(device)
141
+
142
+ # Force CPU
143
+ if device.type != "cpu":
144
+ device = torch.device("cpu")
145
+
146
+ ckpt_dir = Path(snapshot_download(
147
  repo_id=REPO_ID,
148
+ repo_type="model",
149
+ revision="main",
150
  allow_patterns=[
151
  "ve.pt",
152
  "t3_mtl23ls_v2.safetensors",
 
155
  "conds.pt",
156
  "Cangjie5_TC.json",
157
  ],
158
+ token=os.getenv("HF_TOKEN")
159
+ ))
160
 
161
  model = cls.from_local(ckpt_dir, device)
 
 
 
 
 
162
  return model
163
+
164
+ def parameters(self):
165
+ """Iterate over all parameters in T3, S3Gen, and VE for disabling gradients."""
166
+ for p in self.t3.parameters():
167
+ yield p
168
+ for p in self.s3gen.parameters():
169
+ yield p
170
+ for p in self.ve.parameters():
171
+ yield p