Spaces: Running on Zero
Commit b6577ee · Parent(s): c77a697
Ahmed Wasfy committed: New model changes

Browse files:
- src/chatterbox/models/t3/t3.py  +118 -57
- src/chatterbox/mtl_tts.py       +114 -58
src/chatterbox/models/t3/t3.py
CHANGED
@@ -10,7 +10,11 @@ import torch
 import torch.nn.functional as F
 from torch import nn, Tensor
 from transformers import LlamaModel, LlamaConfig
-from transformers.generation.logits_process import
+from transformers.generation.logits_process import (
+    TopPLogitsWarper,
+    RepetitionPenaltyLogitsProcessor,
+    MinPLogitsWarper,
+)
 
 from .modules.learned_pos_emb import LearnedPositionEmbeddings
 
@@ -27,8 +31,12 @@ logger = logging.getLogger(__name__)
 
 def _ensure_BOT_EOT(text_tokens: Tensor, hp):
     B = text_tokens.size(0)
-    assert (
+    assert (
+        text_tokens == hp.start_text_token
+    ).int().sum() >= B, "missing start_text_token"
+    assert (
+        text_tokens == hp.stop_text_token
+    ).int().sum() >= B, "missing stop_text_token"
 
 
 class T3(nn.Module):
@@ -43,7 +51,9 @@ class T3(nn.Module):
 
     def __init__(self, hp=None):
         if hp is None:
-            hp =
+            hp = (
+                T3Config.english_only()
+            )  # Default to English-only config for backward compatibility
         super().__init__()
         self.hp = hp
         self.cfg = LlamaConfig(**LLAMA_CONFIGS[hp.llama_config_name])
@@ -65,8 +75,12 @@ class T3(nn.Module):
         self.speech_pos_emb = LearnedPositionEmbeddings(max_mel_seq_len, self.dim)
 
         # logit projection
-        self.text_head = nn.Linear(
+        self.text_head = nn.Linear(
+            self.cfg.hidden_size, hp.text_tokens_dict_size, bias=False
+        )
+        self.speech_head = nn.Linear(
+            self.cfg.hidden_size, hp.speech_tokens_dict_size, bias=False
+        )
         self.compiled = False
 
     @property
@@ -77,9 +91,13 @@ class T3(nn.Module):
         """
        Token cond data needs to be embedded, so that needs to be here instead of in `T3CondEnc`.
        """
-        if
-        t3_cond.
+        if (
+            t3_cond.cond_prompt_speech_tokens is not None
+            and t3_cond.cond_prompt_speech_emb is None
+        ):
+            t3_cond.cond_prompt_speech_emb = self.speech_emb(
+                t3_cond.cond_prompt_speech_tokens
+            ) + self.speech_pos_emb(t3_cond.cond_prompt_speech_tokens)
         return self.cond_enc(t3_cond)  # (B, len_cond, dim)
 
     def prepare_input_embeds(
@@ -103,13 +121,15 @@ class T3(nn.Module):
         len_cond = cond_emb.size(1)
 
         if cond_emb.size(0) != text_emb.size(0):
+            cond_emb = cond_emb.expand(text_emb.size(0), -1, -1)
 
         # concat
-        embeds = torch.stack(
+        embeds = torch.stack(
+            [
+                torch.cat((ce, te, se))
+                for ce, te, se in zip(cond_emb, text_emb, speech_emb)
+            ]
+        )  # (B, length, dim)
         return embeds, len_cond
 
     def forward(
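For orientation, prepare_input_embeds lays each batch item out as [conditioning | text | speech] along the sequence axis, broadcasting a single conditioning embedding across the batch when needed. A minimal standalone sketch with toy shapes (not repo code):

import torch

B, dim = 2, 8
len_cond, len_text, len_speech = 3, 5, 7
cond_emb = torch.randn(1, len_cond, dim).expand(B, -1, -1)  # broadcast cond over batch
text_emb = torch.randn(B, len_text, dim)
speech_emb = torch.randn(B, len_speech, dim)

# Same concat as the hunk above, one row per batch item
embeds = torch.stack(
    [torch.cat((ce, te, se)) for ce, te, se in zip(cond_emb, text_emb, speech_emb)]
)
assert embeds.shape == (B, len_cond + len_text + len_speech, dim)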
@@ -140,7 +160,9 @@ class T3(nn.Module):
             return_dict=True,
             use_cache=(not training),
         )
-        hidden_states = tfmr_out.hidden_states[
+        hidden_states = tfmr_out.hidden_states[
+            -1
+        ]  # final tfmr layer output, (B, seq, dim)
 
         # post-processing: splice out text and speech parts of hidden states
         len_text = text_tokens.size(1)
@@ -154,8 +176,8 @@ class T3(nn.Module):
             text_end = len_cond + ttl[i].item()
             speech_start = len_cond + text_tokens.size(1)
             speech_end = speech_start + stl[i].item()
-            text_latents[i, :ttl[i]] = hidden_states[i, len_cond:text_end]
-            speech_latents[i, :stl[i]] = hidden_states[i, speech_start:speech_end]
+            text_latents[i, : ttl[i]] = hidden_states[i, len_cond:text_end]
+            speech_latents[i, : stl[i]] = hidden_states[i, speech_start:speech_end]
 
         # logit projection
         text_logits = self.text_head(text_latents)
@@ -173,17 +195,21 @@ class T3(nn.Module):
         self,
         *,
         t3_cond: T3Cond,
-        text_tokens: torch.LongTensor,
-        text_token_lens: torch.LongTensor,
-        speech_tokens: torch.LongTensor,
-        speech_token_lens: torch.LongTensor,
+        text_tokens: torch.LongTensor,  # (B, S_text_padded), includes BOS & EOS
+        text_token_lens: torch.LongTensor,  # (B,), actual lengths including BOS & EOS
+        speech_tokens: torch.LongTensor,  # (B, S_speech_padded), includes BOS & EOS
+        speech_token_lens: torch.LongTensor,  # (B,), actual lengths including BOS & EOS
+        labels_text: torch.LongTensor,  # (B, S_text_padded-1), already masked with -100
+        labels_speech: torch.LongTensor,  # (B, S_speech_padded-1), already masked with -100
     ):
-        "
+        """
+        Compute text and speech cross-entropy using pre-masked labels from the collator.
+        Assumes:
+        - labels_text corresponds to predicting text_tokens[:, 1:] with -100 where ignored
+        - labels_speech corresponds to predicting speech_tokens[:, 1:] with -100 where ignored
+        """
 
+        # 1) Run model to get logits
         out = self.forward(
             t3_cond=t3_cond,
             text_tokens=text_tokens,
@@ -191,19 +217,42 @@ class T3(nn.Module):
             speech_tokens=speech_tokens,
             speech_token_lens=speech_token_lens,
             training=True,
-        )
-        #
-        IGNORE_ID = -100
+        )
+        # out.text_logits:   (B, S_text_padded, V_text)
+        # out.speech_logits: (B, S_speech_padded, V_speech)
         device = out.text_logits.device
-        mask_speech = torch.arange(len_speech, device=device)[None] >= speech_token_lens[:, None]  # (B, len_speech)
-        masked_text = text_tokens.masked_fill(mask_text, IGNORE_ID)
-        masked_speech = speech_tokens.masked_fill(mask_speech, IGNORE_ID)
-        loss_text = F.cross_entropy(out.text_logits, masked_text, ignore_index=IGNORE_ID)
-        loss_speech = F.cross_entropy(out.speech_logits, masked_speech, ignore_index=IGNORE_ID)
+        IGNORE_ID = -100
 
+        # --- Text Loss (use labels_text directly) ---
+        # Align logits: predict t₁..EOS from inputs [BOS, t₁..]
+        logits_for_text = out.text_logits[
+            :, :-1, :
+        ].contiguous()  # (B, S_text_padded-1, V_text)
+        # labels_text already has shape (B, S_text_padded-1) with -100 where masked
+        if logits_for_text.size(1) == 0:
+            loss_text = torch.tensor(0.0, device=device, requires_grad=self.training)
+        else:
+            loss_text = F.cross_entropy(
+                logits_for_text.transpose(1, 2),  # (B, V_text, S_text_padded-1)
+                labels_text,  # (B, S_text_padded-1), ignore_index=-100
+                ignore_index=IGNORE_ID,
+            )
+
+        # --- Speech Loss (use labels_speech directly) ---
+        logits_for_speech = out.speech_logits[
+            :, :-1, :
+        ].contiguous()  # (B, S_speech_padded-1, V_speech)
+        # labels_speech already has shape (B, S_speech_padded-1) with -100 where masked
+        if logits_for_speech.size(1) == 0:
+            loss_speech = torch.tensor(0.0, device=device, requires_grad=self.training)
+        else:
+            loss_speech = F.cross_entropy(
+                logits_for_speech.transpose(1, 2),  # (B, V_speech, S_speech_padded-1)
+                labels_speech,  # (B, S_speech_padded-1), ignore_index=-100
+                ignore_index=IGNORE_ID,
+            )
+
+        return loss_text, loss_speech, out.speech_logits
 
     @torch.inference_mode()
     def inference(
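The reworked loss() no longer builds masks from token lengths; it expects the data collator to hand over next-token labels already masked with -100. A minimal sketch of such a collator helper (hypothetical, not part of this commit) that matches the docstring's alignment of labels with logits[:, :-1, :]:

import torch

IGNORE_ID = -100

def build_labels(tokens: torch.LongTensor, lens: torch.LongTensor) -> torch.LongTensor:
    """tokens: (B, S) padded, includes BOS/EOS; lens: (B,) true lengths.
    Returns (B, S-1) labels: predict t1..EOS from [BOS, t1, ...], padding masked."""
    labels = tokens[:, 1:].clone()
    positions = torch.arange(labels.size(1))[None]        # (1, S-1)
    labels[positions >= (lens[:, None] - 1)] = IGNORE_ID  # mask pad and beyond EOS
    return labels

tokens = torch.tensor([[1, 5, 6, 2, 0, 0]])     # BOS=1, EOS=2, pad=0
print(build_labels(tokens, torch.tensor([4])))  # tensor([[   5,    6,    2, -100, -100]])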
@@ -211,11 +260,9 @@ class T3(nn.Module):
         *,
         t3_cond: T3Cond,
         text_tokens: Tensor,
-        initial_speech_tokens: Optional[Tensor]=None,
+        initial_speech_tokens: Optional[Tensor] = None,
         # misc conditioning
-        prepend_prompt_speech_tokens: Optional[Tensor]=None,
+        prepend_prompt_speech_tokens: Optional[Tensor] = None,
         # HF generate args
         num_return_sequences=1,
         max_new_tokens=None,
@@ -235,11 +282,15 @@ class T3(nn.Module):
         # Validate / sanitize inputs
         assert prepend_prompt_speech_tokens is None, "not implemented"
         _ensure_BOT_EOT(text_tokens, self.hp)
-        text_tokens = torch.atleast_2d(text_tokens).to(
+        text_tokens = torch.atleast_2d(text_tokens).to(
+            dtype=torch.long, device=self.device
+        )
 
         # Default initial speech to a single start-of-speech token
         if initial_speech_tokens is None:
-            initial_speech_tokens = self.hp.start_speech_token * torch.ones_like(
+            initial_speech_tokens = self.hp.start_speech_token * torch.ones_like(
+                text_tokens[:, :1]
+            )
 
         # Prepare custom input embeds
         embeds, len_cond = self.prepare_input_embeds(
@@ -264,7 +315,7 @@ class T3(nn.Module):
             self.tfmr,
             None,
             text_tokens_slice=(len_cond, len_cond + text_tokens.size(-1)),
-            alignment_layer_idx=9,
+            alignment_layer_idx=9,  # TODO: hparam or something?
             eos_idx=self.hp.stop_speech_token,
         )
         assert alignment_stream_analyzer.eos_idx == self.hp.stop_speech_token
@@ -298,7 +349,9 @@ class T3(nn.Module):
 
         device = embeds.device
 
-        bos_token = torch.tensor(
+        bos_token = torch.tensor(
+            [[self.hp.start_speech_token]], dtype=torch.long, device=device
+        )
         bos_embed = self.speech_emb(bos_token)  # shape: (B, 1, embed_dim)
         bos_embed = bos_embed + self.speech_pos_emb.get_fixed_embedding(0)
 
@@ -316,7 +369,9 @@ class T3(nn.Module):
         top_p_warper = TopPLogitsWarper(top_p=top_p)
         min_p_warper = MinPLogitsWarper(min_p=min_p)
         top_p_warper = TopPLogitsWarper(top_p=top_p)
-        repetition_penalty_processor = RepetitionPenaltyLogitsProcessor(
+        repetition_penalty_processor = RepetitionPenaltyLogitsProcessor(
+            penalty=float(repetition_penalty)
+        )
 
         # ---- Initial Forward Pass (no kv_cache yet) ----
         output = self.patched_model(
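All three sampling processors instantiated above come from transformers (MinPLogitsWarper needs a fairly recent release) and share one call convention: processor(input_ids, scores) returns filtered scores. A standalone sketch of the convention the loop below relies on:

import torch
from transformers.generation.logits_process import (
    TopPLogitsWarper,
    RepetitionPenaltyLogitsProcessor,
    MinPLogitsWarper,
)

ids = torch.tensor([[7, 7, 3]])  # previously generated ids, (B=1, T)
scores = torch.randn(1, 10)      # next-token logits, (B, V)

scores = RepetitionPenaltyLogitsProcessor(penalty=1.2)(ids, scores)  # damp tokens 7 and 3
scores = MinPLogitsWarper(min_p=0.05)(ids, scores)  # drop tokens below 5% of the top prob
scores = TopPLogitsWarper(top_p=0.9)(ids, scores)   # keep smallest 90%-mass nucleus
next_token = torch.multinomial(scores.softmax(dim=-1), num_samples=1)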
@@ -332,29 +387,33 @@ class T3(nn.Module):
 
         # ---- Generation Loop using kv_cache ----
         for i in tqdm(range(max_new_tokens), desc="Sampling", dynamic_ncols=True):
-            logits_step = output.logits[:, -1, :]
+            logits_step = output.logits[:, -1, :]
             # CFG combine → (1, V)
-            cond
+            cond = logits_step[0:1, :]
             uncond = logits_step[1:2, :]
             cfg = torch.as_tensor(cfg_weight, device=cond.device, dtype=cond.dtype)
             logits = cond + cfg * (cond - uncond)
 
             # Apply alignment stream analyzer integrity checks
             if self.patched_model.alignment_stream_analyzer is not None:
-                if logits.dim() == 1:
-                    logits = logits.unsqueeze(0)
+                if logits.dim() == 1:  # guard in case something upstream squeezed
+                    logits = logits.unsqueeze(0)  # (1, V)
                 # Pass the last generated token for repetition tracking
-                last_token =
+                last_token = (
+                    generated_ids[0, -1].item() if len(generated_ids[0]) > 0 else None
+                )
+                logits = self.patched_model.alignment_stream_analyzer.step(
+                    logits, next_token=last_token
+                )  # (1, V)
 
             # Apply repetition penalty
-            ids_for_proc = generated_ids[:1, ...]
+            ids_for_proc = generated_ids[:1, ...]  # batch = 1
             logits = repetition_penalty_processor(ids_for_proc, logits)  # expects (B,V)
 
             # Apply temperature scaling.
             if temperature != 1.0:
                 logits = logits / temperature
 
             # Apply min_p and top_p filtering
             logits = min_p_warper(ids_for_proc, logits)
             logits = top_p_warper(ids_for_proc, logits)
@@ -373,7 +432,9 @@ class T3(nn.Module):
 
             # Get embedding for the new token.
             next_token_embed = self.speech_emb(next_token)
-            next_token_embed =
+            next_token_embed = (
+                next_token_embed + self.speech_pos_emb.get_fixed_embedding(i + 1)
+            )
 
             # For CFG
             next_token_embed = torch.cat([next_token_embed, next_token_embed])
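The CFG combine in this loop is plain extrapolation away from the unconditional row of the doubled batch: logits = cond + cfg_weight * (cond - uncond), with cfg_weight = 0 reducing to the conditional logits. A toy check of the arithmetic:

import torch

logits_step = torch.tensor([[2.0, 0.0], [1.0, 1.0]])  # row 0: cond, row 1: uncond
cond, uncond = logits_step[0:1, :], logits_step[1:2, :]

cfg = 0.5
print(cond + cfg * (cond - uncond))  # tensor([[ 2.5000, -0.5000]])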
src/chatterbox/mtl_tts.py
CHANGED
@@ -22,36 +22,36 @@ REPO_ID = "ResembleAI/chatterbox"
 
 # Supported languages for the multilingual model
 SUPPORTED_LANGUAGES = {
+    "ar": "Arabic",
+    "da": "Danish",
+    "de": "German",
+    "el": "Greek",
+    "en": "English",
+    "es": "Spanish",
+    "fi": "Finnish",
+    "fr": "French",
+    "he": "Hebrew",
+    "hi": "Hindi",
+    "it": "Italian",
+    "ja": "Japanese",
+    "ko": "Korean",
+    "ms": "Malay",
+    "nl": "Dutch",
+    "no": "Norwegian",
+    "pl": "Polish",
+    "pt": "Portuguese",
+    "ru": "Russian",
+    "sv": "Swedish",
+    "sw": "Swahili",
+    "tr": "Turkish",
+    "zh": "Chinese",
 }
 
 
 def punc_norm(text: str) -> str:
     """
+    Quick cleanup func for punctuation from LLMs or
+    containing chars not seen often in the dataset
     """
     if len(text) == 0:
         return "You need to add some text for me to talk."
@@ -73,8 +73,8 @@ def punc_norm(text: str) -> str:
         ("—", "-"),
         ("–", "-"),
         (" ,", ","),
-        ("“", "
-        ("”", "
+        ("“", '"'),
+        ("”", '"'),
         ("‘", "'"),
         ("’", "'"),
     ]
@@ -83,7 +83,7 @@ def punc_norm(text: str) -> str:
 
     # Add full stop if no ending punc
     text = text.rstrip(" ")
-    sentence_enders = {".", "!", "?", "-", ",","、",",","。","?","!"}
+    sentence_enders = {".", "!", "?", "-", ",", "、", ",", "。", "?", "!"}
     if not any(text.endswith(p) for p in sentence_enders):
         text += "."
 
@@ -107,6 +107,7 @@ class Conditionals:
     - prompt_feat_len
     - embedding
     """
+
     t3: T3Cond
     gen: dict
 
@@ -118,16 +119,13 @@ class Conditionals:
         return self
 
     def save(self, fpath: Path):
-        arg_dict = dict(
-            t3=self.t3.__dict__,
-            gen=self.gen
-        )
+        arg_dict = dict(t3=self.t3.__dict__, gen=self.gen)
         torch.save(arg_dict, fpath)
 
     @classmethod
     def load(cls, fpath, map_location="cpu"):
         kwargs = torch.load(fpath, map_location=map_location, weights_only=True)
-        return cls(T3Cond(**kwargs[
+        return cls(T3Cond(**kwargs["t3"]), kwargs["gen"])
 
 
 class ChatterboxMultilingualTTS:
@@ -158,13 +156,11 @@ class ChatterboxMultilingualTTS:
         return SUPPORTED_LANGUAGES.copy()
 
     @classmethod
-    def from_local(cls, ckpt_dir, device) ->
+    def from_local(cls, ckpt_dir, device) -> "ChatterboxMultilingualTTS":
         ckpt_dir = Path(ckpt_dir)
 
         ve = VoiceEncoder()
-        ve.load_state_dict(
-            torch.load(ckpt_dir / "ve.pt", weights_only=True)
-        )
+        ve.load_state_dict(torch.load(ckpt_dir / "ve.pt", weights_only=True))
         ve.to(device).eval()
 
         t3 = T3(T3Config.multilingual())
@@ -175,14 +171,10 @@ class ChatterboxMultilingualTTS:
         t3.to(device).eval()
 
         s3gen = S3Gen()
-        s3gen.load_state_dict(
-            torch.load(ckpt_dir / "s3gen.pt", weights_only=True)
-        )
+        s3gen.load_state_dict(torch.load(ckpt_dir / "s3gen.pt", weights_only=True))
        s3gen.to(device).eval()
 
-        tokenizer = MTLTokenizer(
-            str(ckpt_dir / "grapheme_mtl_merged_expanded_v1.json")
-        )
+        tokenizer = MTLTokenizer(str(ckpt_dir / "grapheme_mtl_merged_expanded_v1.json"))
 
         conds = None
         if (builtin_voice := ckpt_dir / "conds.pt").exists():
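From the replacement table and sentence-ender check above: curly quotes become straight quotes, and a period is appended only when the text does not already end in one of the listed Western or CJK enders. A quick illustration (behaviour inferred from the visible hunks only):

print(punc_norm('He said “hello”'))  # He said "hello".  (quotes normalized, stop added)
print(punc_norm("もう行くよ。"))      # unchanged ending: "。" is in sentence_enders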
@@ -191,36 +183,94 @@ class ChatterboxMultilingualTTS:
         return cls(t3, s3gen, ve, tokenizer, device, conds=conds)
 
     @classmethod
-    def from_pretrained(cls, device: torch.device) ->
+    def from_pretrained(cls, device: torch.device) -> "ChatterboxMultilingualTTS":
         ckpt_dir = Path(
             snapshot_download(
                 repo_id=REPO_ID,
                 repo_type="model",
-                revision="main",
-                allow_patterns=[
+                revision="main",
+                allow_patterns=[
+                    "ve.pt",
+                    "t3_mtl23ls_v2.safetensors",
+                    "s3gen.pt",
+                    "grapheme_mtl_merged_expanded_v1.json",
+                    "conds.pt",
+                    "Cangjie5_TC.json",
+                ],
                 token=os.getenv("HF_TOKEN"),
             )
         )
         return cls.from_local(ckpt_dir, device)
 
+    @classmethod
+    def from_checkpoint(
+        cls, save_dir, device: torch.device
+    ) -> "ChatterboxMultilingualTTS":
+        ckpt_dir = Path(
+            snapshot_download(
+                repo_id=REPO_ID,
+                repo_type="model",
+                revision="main",
+                allow_patterns=[
+                    "ve.pt",
+                    "t3_mtl23ls_v2.safetensors",
+                    "s3gen.pt",
+                    "grapheme_mtl_merged_expanded_v1.json",
+                    "conds.pt",
+                    "Cangjie5_TC.json",
+                ],
+                token=os.getenv("HF_TOKEN"),
+            )
+        )
+        ckpt_dir = Path(ckpt_dir)
+
+        ve = VoiceEncoder()
+        ve.load_state_dict(torch.load(ckpt_dir / "ve.pt", weights_only=True))
+        ve.to(device).eval()
+
+        t3 = T3(T3Config.multilingual())
+        t3_state = load_safetensors(save_dir + "t3_mtl23ls_v2.safetensors")
+        if "model" in t3_state.keys():
+            t3_state = t3_state["model"][0]
+        t3.load_state_dict(t3_state)
+        t3.to(device).eval()
+
+        s3gen = S3Gen()
+        s3gen.load_state_dict(torch.load(ckpt_dir / "s3gen.pt", weights_only=True))
+        s3gen.to(device).eval()
+
+        tokenizer = MTLTokenizer(str(ckpt_dir / "grapheme_mtl_merged_expanded_v1.json"))
+
+        conds = Conditionals.load(save_dir + "conds.pt").to(device)
+
+        return cls(t3, s3gen, ve, tokenizer, device, conds=conds)
+
     def prepare_conditionals(self, wav_fpath, exaggeration=0.5):
         ## Load reference wav
         s3gen_ref_wav, _sr = librosa.load(wav_fpath, sr=S3GEN_SR)
 
         ref_16k_wav = librosa.resample(s3gen_ref_wav, orig_sr=S3GEN_SR, target_sr=S3_SR)
 
-        s3gen_ref_wav = s3gen_ref_wav[:self.DEC_COND_LEN]
-        s3gen_ref_dict = self.s3gen.embed_ref(
+        s3gen_ref_wav = s3gen_ref_wav[: self.DEC_COND_LEN]
+        s3gen_ref_dict = self.s3gen.embed_ref(
+            s3gen_ref_wav, S3GEN_SR, device=self.device
+        )
 
         # Speech cond prompt tokens
         t3_cond_prompt_tokens = None
         if plen := self.t3.hp.speech_cond_prompt_len:
             s3_tokzr = self.s3gen.tokenizer
-            t3_cond_prompt_tokens, _ = s3_tokzr.forward(
+            t3_cond_prompt_tokens, _ = s3_tokzr.forward(
+                [ref_16k_wav[: self.ENC_COND_LEN]], max_len=plen
+            )
+            t3_cond_prompt_tokens = torch.atleast_2d(t3_cond_prompt_tokens).to(
+                self.device
+            )
 
         # Voice-encoder speaker embedding
-        ve_embed = torch.from_numpy(
+        ve_embed = torch.from_numpy(
+            self.ve.embeds_from_wavs([ref_16k_wav], sample_rate=S3_SR)
+        )
         ve_embed = ve_embed.mean(axis=0, keepdim=True).to(self.device)
 
         t3_cond = T3Cond(
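Note that from_checkpoint downloads the same snapshot as from_pretrained but reads the T3 weights and conds.pt via plain string concatenation on save_dir rather than from ckpt_dir, so save_dir must end with a path separator. A hedged usage sketch (paths hypothetical):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = ChatterboxMultilingualTTS.from_pretrained(device)

# Or point t3/conds loading at a training-run directory; note the trailing slash,
# since from_checkpoint joins with `save_dir + "t3_mtl23ls_v2.safetensors"`:
model = ChatterboxMultilingualTTS.from_checkpoint("runs/exp1/", device)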
@@ -249,11 +299,13 @@ class ChatterboxMultilingualTTS:
                 f"Unsupported language_id '{language_id}'. "
                 f"Supported languages: {supported_langs}"
             )
 
         if audio_prompt_path:
             self.prepare_conditionals(audio_prompt_path, exaggeration=exaggeration)
         else:
-            assert
+            assert (
+                self.conds is not None
+            ), "Please `prepare_conditionals` first or specify `audio_prompt_path`"
 
         # Update exaggeration if needed
         if float(exaggeration) != float(self.conds.t3.emotion_adv[0, 0, 0].item()):
@@ -266,8 +318,12 @@ class ChatterboxMultilingualTTS:
 
         # Norm and tokenize text
         text = punc_norm(text)
-        text_tokens = self.tokenizer.text_to_tokens(
+        text_tokens = self.tokenizer.text_to_tokens(
+            text, language_id=language_id.lower() if language_id else None
+        ).to(self.device)
+        text_tokens = torch.cat(
+            [text_tokens, text_tokens], dim=0
+        )  # Need two seqs for CFG
 
         sot = self.t3.hp.start_text_token
         eot = self.t3.hp.stop_text_token
@@ -297,5 +353,5 @@ class ChatterboxMultilingualTTS:
             ref_dict=self.conds.gen,
         )
         wav = wav.squeeze(0).detach().cpu().numpy()
-        return torch.from_numpy(
+        # wav = self.watermarker.apply_watermark(wav, sample_rate=self.sr)
+        return torch.from_numpy(wav).unsqueeze(0)
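End-to-end, the method patched above returns a (1, num_samples) tensor; assuming the enclosing method is the public generate entry point (its name sits outside the visible hunks), usage looks roughly like:

import torchaudio

wav = model.generate(
    "Bonjour tout le monde !",
    language_id="fr",                   # must be a key of SUPPORTED_LANGUAGES
    audio_prompt_path="reference.wav",  # hypothetical reference clip
    exaggeration=0.5,
)
torchaudio.save("out.wav", wav, model.sr)  # wav is (1, N) after unsqueeze(0)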