rahul7star committed on
Commit
d610107
·
verified ·
1 Parent(s): ef3441d

Update src/chatterbox/mtl_tts.py

Browse files
Files changed (1) hide show
  1. src/chatterbox/mtl_tts.py +58 -69
src/chatterbox/mtl_tts.py CHANGED
@@ -2,12 +2,11 @@ from dataclasses import dataclass
2
  from pathlib import Path
3
  import os
4
  import torch
5
- from huggingface_hub import snapshot_download
6
-
7
  import librosa
8
  import perth
9
  import torch.nn.functional as F
10
  from safetensors.torch import load_file as load_safetensors
 
11
 
12
  from .models.t3 import T3
13
  from .models.t3.modules.t3_config import T3Config
@@ -21,29 +20,28 @@ REPO_ID = "ResembleAI/chatterbox"
21
 
22
  # Supported languages for the multilingual model
23
  SUPPORTED_LANGUAGES = {
24
- "ar": "Arabic", "da": "Danish", "de": "German", "el": "Greek",
25
- "en": "English", "es": "Spanish", "fi": "Finnish", "fr": "French",
26
- "he": "Hebrew", "hi": "Hindi", "it": "Italian", "ja": "Japanese",
27
- "ko": "Korean", "ms": "Malay", "nl": "Dutch", "no": "Norwegian",
28
- "pl": "Polish", "pt": "Portuguese", "ru": "Russian", "sv": "Swedish",
29
  "sw": "Swahili", "tr": "Turkish", "zh": "Chinese",
30
  }
31
 
32
 
33
  def punc_norm(text: str) -> str:
34
- """Normalize punctuation for TTS text."""
35
- if not text:
36
  return "You need to add some text for me to talk."
37
- text = text[0].upper() + text[1:]
 
38
  text = " ".join(text.split())
39
  replacements = [
40
- ("...", ", "), ("…", ", "), (":", ","), (" - ", ","),
41
- (";", ","), ("", "-"), ("", "-"), (" ,", ","),
42
- ("“", "\""), ("”", "\""), ("‘", "'"), ("’", "'")
43
  ]
44
  for old, new in replacements:
45
  text = text.replace(old, new)
46
- if not any(text.endswith(p) for p in {".", "!", "?", "-", ",","、",",","。","?","!"}):
47
  text += "."
48
  return text
49
 
@@ -53,35 +51,22 @@ class Conditionals:
53
  t3: T3Cond
54
  gen: dict
55
 
56
- def to(self, device):
57
- """Move only tensors in `.gen` to device. T3Cond stays as-is."""
58
- for k, v in self.gen.items():
59
- if torch.is_tensor(v):
60
- self.gen[k] = v.to(device)
61
- return self
62
-
63
  def save(self, fpath: Path):
64
  torch.save({"t3": self.t3.__dict__, "gen": self.gen}, fpath)
65
 
66
  @classmethod
67
- def load(cls, fpath: Path, map_location="cpu"):
68
- data = torch.load(fpath, map_location=map_location, weights_only=True)
69
- return cls(T3Cond(**data["t3"]), data["gen"])
70
 
71
 
72
  class ChatterboxMultilingualTTS:
73
  ENC_COND_LEN = 6 * S3_SR
74
  DEC_COND_LEN = 10 * S3GEN_SR
75
 
76
- def __init__(
77
- self,
78
- t3: T3,
79
- s3gen: S3Gen,
80
- ve: VoiceEncoder,
81
- tokenizer: MTLTokenizer,
82
- device: str,
83
- conds: Conditionals = None,
84
- ):
85
  self.sr = S3GEN_SR
86
  self.t3 = t3
87
  self.s3gen = s3gen
@@ -91,55 +76,43 @@ class ChatterboxMultilingualTTS:
91
  self.conds = conds
92
  self.watermarker = perth.PerthImplicitWatermarker()
93
 
94
- # Disable gradients for safety
95
- for p in self.parameters():
96
- p.requires_grad = False
97
-
98
  @classmethod
99
  def get_supported_languages(cls):
100
  return SUPPORTED_LANGUAGES.copy()
101
 
102
  @classmethod
103
- def from_local(cls, ckpt_dir, device) -> "ChatterboxMultilingualTTS":
104
  ckpt_dir = Path(ckpt_dir)
105
 
106
- # Voice Encoder
107
  ve = VoiceEncoder()
108
- ve.load_state_dict(torch.load(ckpt_dir / "ve.pt", map_location="cpu", weights_only=True))
109
- ve.to(device)
110
 
111
- # T3
112
  t3 = T3(T3Config.multilingual())
113
  t3_state = load_safetensors(ckpt_dir / "t3_mtl23ls_v2.safetensors")
114
- if "model" in t3_state:
115
  t3_state = t3_state["model"][0]
116
  t3.load_state_dict(t3_state)
117
- t3.to(device)
118
 
119
- # S3Gen
120
  s3gen = S3Gen()
121
- s3gen.load_state_dict(torch.load(ckpt_dir / "s3gen.pt", map_location="cpu", weights_only=True))
122
- s3gen.to(device)
123
 
124
- # Tokenizer
125
  tokenizer = MTLTokenizer(str(ckpt_dir / "grapheme_mtl_merged_expanded_v1.json"))
126
 
127
- # Conditionals
128
  conds = None
129
- if (ckpt_dir / "conds.pt").exists():
130
- conds = Conditionals.load(ckpt_dir / "conds.pt").to(device)
131
 
132
  return cls(t3, s3gen, ve, tokenizer, device, conds=conds)
133
 
134
  @classmethod
135
  def from_pretrained(cls, device: str | torch.device | None = None) -> "ChatterboxMultilingualTTS":
136
- """Load model fully on CPU, never use CUDA."""
137
  if device is None:
138
  device = torch.device("cpu")
139
  elif isinstance(device, str):
140
  device = torch.device(device)
141
-
142
- # Force CPU
143
  if device.type != "cpu":
144
  device = torch.device("cpu")
145
 
@@ -148,24 +121,40 @@ class ChatterboxMultilingualTTS:
148
  repo_type="model",
149
  revision="main",
150
  allow_patterns=[
151
- "ve.pt",
152
- "t3_mtl23ls_v2.safetensors",
153
- "s3gen.pt",
154
- "grapheme_mtl_merged_expanded_v1.json",
155
- "conds.pt",
156
- "Cangjie5_TC.json",
157
  ],
158
- token=os.getenv("HF_TOKEN")
159
  ))
160
 
161
  model = cls.from_local(ckpt_dir, device)
 
 
 
 
 
 
 
 
 
162
  return model
163
 
164
- def parameters(self):
165
- """Iterate over all parameters in T3, S3Gen, and VE for disabling gradients."""
166
- for p in self.t3.parameters():
167
- yield p
168
- for p in self.s3gen.parameters():
169
- yield p
170
- for p in self.ve.parameters():
171
- yield p
 
 
 
 
 
 
 
 
 
 
 
2
  from pathlib import Path
3
  import os
4
  import torch
 
 
5
  import librosa
6
  import perth
7
  import torch.nn.functional as F
8
  from safetensors.torch import load_file as load_safetensors
9
+ from huggingface_hub import snapshot_download
10
 
11
  from .models.t3 import T3
12
  from .models.t3.modules.t3_config import T3Config
 
20
 
21
  # Supported languages for the multilingual model
22
  SUPPORTED_LANGUAGES = {
23
+ "ar": "Arabic", "da": "Danish", "de": "German", "el": "Greek", "en": "English",
24
+ "es": "Spanish", "fi": "Finnish", "fr": "French", "he": "Hebrew", "hi": "Hindi",
25
+ "it": "Italian", "ja": "Japanese", "ko": "Korean", "ms": "Malay", "nl": "Dutch",
26
+ "no": "Norwegian", "pl": "Polish", "pt": "Portuguese", "ru": "Russian", "sv": "Swedish",
 
27
  "sw": "Swahili", "tr": "Turkish", "zh": "Chinese",
28
  }
29
 
30
 
31
  def punc_norm(text: str) -> str:
32
+ if len(text) == 0:
 
33
  return "You need to add some text for me to talk."
34
+ if text[0].islower():
35
+ text = text[0].upper() + text[1:]
36
  text = " ".join(text.split())
37
  replacements = [
38
+ ("...", ", "), ("…", ", "), (":", ","), (" - ", ","), (";", ","),
39
+ ("", "-"), ("", "-"), (" ,", ","), ("“", "\""), ("”", "\""),
40
+ ("‘", "'"), ("’", "'"),
41
  ]
42
  for old, new in replacements:
43
  text = text.replace(old, new)
44
+ if not text[-1] in {".", "!", "?", "-", ",","、",",","。","?","!"}:
45
  text += "."
46
  return text
47
 
 
51
  t3: T3Cond
52
  gen: dict
53
 
 
 
 
 
 
 
 
54
  def save(self, fpath: Path):
55
  torch.save({"t3": self.t3.__dict__, "gen": self.gen}, fpath)
56
 
57
  @classmethod
58
+ def load(cls, fpath: Path):
59
+ kwargs = torch.load(fpath, map_location="cpu", weights_only=True)
60
+ return cls(T3Cond(**kwargs['t3']), kwargs['gen'])
61
 
62
 
63
  class ChatterboxMultilingualTTS:
64
  ENC_COND_LEN = 6 * S3_SR
65
  DEC_COND_LEN = 10 * S3GEN_SR
66
 
67
+ def __init__(self, t3: T3, s3gen: S3Gen, ve: VoiceEncoder,
68
+ tokenizer: MTLTokenizer, device: torch.device,
69
+ conds: Conditionals = None):
 
 
 
 
 
 
70
  self.sr = S3GEN_SR
71
  self.t3 = t3
72
  self.s3gen = s3gen
 
76
  self.conds = conds
77
  self.watermarker = perth.PerthImplicitWatermarker()
78
 
 
 
 
 
79
  @classmethod
80
  def get_supported_languages(cls):
81
  return SUPPORTED_LANGUAGES.copy()
82
 
83
  @classmethod
84
+ def from_local(cls, ckpt_dir: Path, device: torch.device) -> "ChatterboxMultilingualTTS":
85
  ckpt_dir = Path(ckpt_dir)
86
 
 
87
  ve = VoiceEncoder()
88
+ ve.load_state_dict(torch.load(ckpt_dir / "ve.pt", weights_only=True))
89
+ ve.to(device).eval()
90
 
 
91
  t3 = T3(T3Config.multilingual())
92
  t3_state = load_safetensors(ckpt_dir / "t3_mtl23ls_v2.safetensors")
93
+ if "model" in t3_state.keys():
94
  t3_state = t3_state["model"][0]
95
  t3.load_state_dict(t3_state)
96
+ t3.to(device).eval()
97
 
 
98
  s3gen = S3Gen()
99
+ s3gen.load_state_dict(torch.load(ckpt_dir / "s3gen.pt", weights_only=True))
100
+ s3gen.to(device).eval()
101
 
 
102
  tokenizer = MTLTokenizer(str(ckpt_dir / "grapheme_mtl_merged_expanded_v1.json"))
103
 
 
104
  conds = None
105
+ if (builtin_voice := ckpt_dir / "conds.pt").exists():
106
+ conds = Conditionals.load(builtin_voice)
107
 
108
  return cls(t3, s3gen, ve, tokenizer, device, conds=conds)
109
 
110
  @classmethod
111
  def from_pretrained(cls, device: str | torch.device | None = None) -> "ChatterboxMultilingualTTS":
 
112
  if device is None:
113
  device = torch.device("cpu")
114
  elif isinstance(device, str):
115
  device = torch.device(device)
 
 
116
  if device.type != "cpu":
117
  device = torch.device("cpu")
118
 
 
121
  repo_type="model",
122
  revision="main",
123
  allow_patterns=[
124
+ "ve.pt", "t3_mtl23ls_v2.safetensors", "s3gen.pt",
125
+ "grapheme_mtl_merged_expanded_v1.json", "conds.pt",
126
+ "Cangjie5_TC.json"
 
 
 
127
  ],
128
+ token=os.getenv("HF_TOKEN"),
129
  ))
130
 
131
  model = cls.from_local(ckpt_dir, device)
132
+
133
+ # Ensure all params on CPU and eval
134
+ model.t3.to(device).eval()
135
+ model.s3gen.to(device).eval()
136
+ model.ve.to(device).eval()
137
+ if model.conds:
138
+ for k, v in model.conds.gen.items():
139
+ if torch.is_tensor(v):
140
+ model.conds.gen[k] = v.to(device)
141
  return model
142
 
143
+ @torch.no_grad()
144
+ def generate(self, text: str, speaker_embedding: torch.Tensor = None) -> torch.Tensor:
145
+ """
146
+ Generate audio waveform (numpy array) from text.
147
+ CPU-compatible.
148
+ """
149
+ text = punc_norm(text)
150
+ token_ids = self.tokenizer.encode(text)
151
+ token_ids = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0).to(self.device)
152
+
153
+ conds = self.conds.gen if self.conds else {}
154
+ t3_out = self.t3(token_ids, **conds)
155
+ audio = self.s3gen(t3_out, **conds)
156
+
157
+ if isinstance(audio, torch.Tensor):
158
+ audio = audio.squeeze(0).cpu().numpy()
159
+
160
+ return audio