rahul7star committed on
Commit
89ed58d
·
verified ·
1 Parent(s): d610107

Update src/chatterbox/mtl_tts.py

Browse files
Files changed (1) hide show
  1. src/chatterbox/mtl_tts.py +22 -29
src/chatterbox/mtl_tts.py CHANGED
@@ -64,9 +64,7 @@ class ChatterboxMultilingualTTS:
64
  ENC_COND_LEN = 6 * S3_SR
65
  DEC_COND_LEN = 10 * S3GEN_SR
66
 
67
- def __init__(self, t3: T3, s3gen: S3Gen, ve: VoiceEncoder,
68
- tokenizer: MTLTokenizer, device: torch.device,
69
- conds: Conditionals = None):
70
  self.sr = S3GEN_SR
71
  self.t3 = t3
72
  self.s3gen = s3gen
@@ -77,13 +75,7 @@ class ChatterboxMultilingualTTS:
77
  self.watermarker = perth.PerthImplicitWatermarker()
78
 
79
  @classmethod
80
- def get_supported_languages(cls):
81
- return SUPPORTED_LANGUAGES.copy()
82
-
83
- @classmethod
84
- def from_local(cls, ckpt_dir: Path, device: torch.device) -> "ChatterboxMultilingualTTS":
85
- ckpt_dir = Path(ckpt_dir)
86
-
87
  ve = VoiceEncoder()
88
  ve.load_state_dict(torch.load(ckpt_dir / "ve.pt", weights_only=True))
89
  ve.to(device).eval()
@@ -105,10 +97,10 @@ class ChatterboxMultilingualTTS:
105
  if (builtin_voice := ckpt_dir / "conds.pt").exists():
106
  conds = Conditionals.load(builtin_voice)
107
 
108
- return cls(t3, s3gen, ve, tokenizer, device, conds=conds)
109
 
110
  @classmethod
111
- def from_pretrained(cls, device: str | torch.device | None = None) -> "ChatterboxMultilingualTTS":
112
  if device is None:
113
  device = torch.device("cpu")
114
  elif isinstance(device, str):
@@ -122,35 +114,36 @@ class ChatterboxMultilingualTTS:
122
  revision="main",
123
  allow_patterns=[
124
  "ve.pt", "t3_mtl23ls_v2.safetensors", "s3gen.pt",
125
- "grapheme_mtl_merged_expanded_v1.json", "conds.pt",
126
- "Cangjie5_TC.json"
127
  ],
128
  token=os.getenv("HF_TOKEN"),
129
  ))
130
 
131
- model = cls.from_local(ckpt_dir, device)
132
-
133
- # Ensure all params on CPU and eval
134
- model.t3.to(device).eval()
135
- model.s3gen.to(device).eval()
136
- model.ve.to(device).eval()
137
- if model.conds:
138
- for k, v in model.conds.gen.items():
139
- if torch.is_tensor(v):
140
- model.conds.gen[k] = v.to(device)
141
- return model
142
 
143
  @torch.no_grad()
144
- def generate(self, text: str, speaker_embedding: torch.Tensor = None) -> torch.Tensor:
145
  """
146
- Generate audio waveform (numpy array) from text.
147
- CPU-compatible.
148
  """
149
- text = punc_norm(text)
 
 
 
 
 
150
  token_ids = self.tokenizer.encode(text)
151
  token_ids = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0).to(self.device)
152
 
153
  conds = self.conds.gen if self.conds else {}
 
 
 
 
 
 
 
154
  t3_out = self.t3(token_ids, **conds)
155
  audio = self.s3gen(t3_out, **conds)
156
 
 
64
  ENC_COND_LEN = 6 * S3_SR
65
  DEC_COND_LEN = 10 * S3GEN_SR
66
 
67
+ def __init__(self, t3, s3gen, ve, tokenizer, device, conds=None):
 
 
68
  self.sr = S3GEN_SR
69
  self.t3 = t3
70
  self.s3gen = s3gen
 
75
  self.watermarker = perth.PerthImplicitWatermarker()
76
 
77
  @classmethod
78
+ def from_local(cls, ckpt_dir, device=torch.device("cpu")):
 
 
 
 
 
 
79
  ve = VoiceEncoder()
80
  ve.load_state_dict(torch.load(ckpt_dir / "ve.pt", weights_only=True))
81
  ve.to(device).eval()
 
97
  if (builtin_voice := ckpt_dir / "conds.pt").exists():
98
  conds = Conditionals.load(builtin_voice)
99
 
100
+ return cls(t3, s3gen, ve, tokenizer, device, conds)
101
 
102
  @classmethod
103
+ def from_pretrained(cls, device=None):
104
  if device is None:
105
  device = torch.device("cpu")
106
  elif isinstance(device, str):
 
114
  revision="main",
115
  allow_patterns=[
116
  "ve.pt", "t3_mtl23ls_v2.safetensors", "s3gen.pt",
117
+ "grapheme_mtl_merged_expanded_v1.json", "conds.pt"
 
118
  ],
119
  token=os.getenv("HF_TOKEN"),
120
  ))
121
 
122
+ return cls.from_local(ckpt_dir, device)
 
 
 
 
 
 
 
 
 
 
123
 
124
  @torch.no_grad()
125
+ def generate(self, text: str, speaker_embedding=None, language_id=None, **kwargs):
126
  """
127
+ CPU-safe text-to-speech.
128
+ Accepts optional `language_id` and any other kwargs.
129
  """
130
+ # Normalize punctuation
131
+ text = text.strip()
132
+ if not text.endswith("."):
133
+ text += "."
134
+
135
+ # Encode text
136
  token_ids = self.tokenizer.encode(text)
137
  token_ids = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0).to(self.device)
138
 
139
  conds = self.conds.gen if self.conds else {}
140
+
141
+ # Include language_id in conds if provided
142
+ if language_id is not None:
143
+ conds = conds.copy()
144
+ conds['language_id'] = language_id
145
+
146
+ # Run through T3 and S3Gen
147
  t3_out = self.t3(token_ids, **conds)
148
  audio = self.s3gen(t3_out, **conds)
149