import torch
import torch.nn.functional as F
from src.chatterbox_.models.t3.modules.cond_enc import T3Cond
from src.config import TrainConfig
from src.utils import setup_logger
logger = setup_logger(__name__)
def resize_and_load_t3_weights(new_model: torch.nn.Module, pretrained_state_dict: dict):
"""
Loads pretrained weights into a new T3 model with a different vocabulary size.
Features: Initialize new tokens with the AVERAGE of existing tokens.
"""
new_model_state_dict = new_model.state_dict()
embedding_layer_name = "text_emb.weight"
output_head_name = "text_head.weight"
mean_init_applied = False
# Step 1: Copy weights for ALL matching layers
for name, param in pretrained_state_dict.items():
if name not in [embedding_layer_name, output_head_name]:
if name in new_model_state_dict and new_model_state_dict[name].shape == param.shape:
new_model_state_dict[name].copy_(param)
else:
                logger.warning(f"Skipping layer {name}: missing in new model or shape mismatch.")
# Step 2: Smart copy for Embedding Layer (Average Init)
if embedding_layer_name in pretrained_state_dict:
old_emb_weights = pretrained_state_dict[embedding_layer_name]
old_vocab_size, _ = old_emb_weights.shape
new_vocab_size = new_model_state_dict[embedding_layer_name].shape[0]
# A) Copy old weights
new_model_state_dict[embedding_layer_name][:old_vocab_size, :].copy_(old_emb_weights)
logger.info(f"Embedding layer: {old_vocab_size} tokens preserved.")
# B) Initialize new tokens with average
if new_vocab_size > old_vocab_size:
mean_emb = old_emb_weights.mean(dim=0)
num_new_tokens = new_vocab_size - old_vocab_size
new_model_state_dict[embedding_layer_name][old_vocab_size:, :].copy_(mean_emb.unsqueeze(0).expand(num_new_tokens, -1))
logger.info(f"Embedding layer: {num_new_tokens} new tokens initialized with mean.")
mean_init_applied = True
# Step 3: Smart copy for Output Head (Average Init)
if output_head_name in pretrained_state_dict:
old_head_weights = pretrained_state_dict[output_head_name]
old_vocab_size, _ = old_head_weights.shape
new_vocab_size = new_model_state_dict[output_head_name].shape[0]
# A) Copy old weights
new_model_state_dict[output_head_name][:old_vocab_size, :].copy_(old_head_weights)
logger.info(f"Output head: {old_vocab_size} tokens preserved.")
# B) Initialize new neurons with average
if new_vocab_size > old_vocab_size:
mean_head = old_head_weights.mean(dim=0)
num_new_tokens = new_vocab_size - old_vocab_size
new_model_state_dict[output_head_name][old_vocab_size:, :].copy_(mean_head.unsqueeze(0).expand(num_new_tokens, -1))
logger.info(f"Output head: {num_new_tokens} new neurons initialized with mean.")
mean_init_applied = True
# Step 4: Load the updated state dict into the new model
new_model.load_state_dict(new_model_state_dict)
if mean_init_applied:
logger.info("All weights transferred successfully (with mean initialization for new tokens)!")
else:
logger.info("All weights transferred successfully (direct copy, no resizing needed)!")
return new_model
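
# The following is a minimal, self-contained sketch (an addition, not part of the original
# training pipeline) that illustrates the mean-initialization behaviour above. It builds a
# toy module that only reuses the "text_emb" / "text_head" layer names expected by
# resize_and_load_t3_weights; run it manually if you want to verify the resize logic.
def _mean_init_sanity_check():
    class _ToyT3(torch.nn.Module):
        def __init__(self, vocab_size: int, dim: int = 8):
            super().__init__()
            self.text_emb = torch.nn.Embedding(vocab_size, dim)
            self.text_head = torch.nn.Linear(dim, vocab_size, bias=False)

    old_model = _ToyT3(vocab_size=10)
    new_model = _ToyT3(vocab_size=12)  # two extra (e.g. Finnish-specific) text tokens
    resize_and_load_t3_weights(new_model, old_model.state_dict())

    # Old rows are copied verbatim; new rows equal the mean of the old ones.
    assert torch.equal(new_model.text_emb.weight[:10], old_model.text_emb.weight)
    assert torch.allclose(new_model.text_emb.weight[10:],
                          old_model.text_emb.weight.mean(dim=0).expand(2, -1))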
class ChatterboxTrainerWrapper(torch.nn.Module):
"""
Wrapper class to calculate Loss inside the Forward pass for HuggingFace Trainer.
"""
def __init__(self, t3_model):
super().__init__()
self.t3 = t3_model
self.cfg = TrainConfig()
        # Length of the speech conditioning prompt; fall back to 150 tokens if the
        # hyperparameter is missing.
        self.prompt_token_len = getattr(t3_model.hp, 'speech_cond_prompt_len', 150)
def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
self.t3.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)
def get_input_embeddings(self):
return self.t3.get_input_embeddings()
def forward(
self,
text_tokens,
text_token_lens,
speech_tokens,
speech_token_lens,
speaker_emb,
prompt_tokens):
device = text_tokens.device
batch_size = text_tokens.size(0)
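        # Fixed emotion/exaggeration conditioning for every item in the batch
        # (0.5 matches Chatterbox's default exaggeration value).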
emotion_adv = 0.5 * torch.ones(batch_size, 1, 1).to(device)
t3_cond = T3Cond(
speaker_emb=speaker_emb,
cond_prompt_speech_tokens=prompt_tokens,
emotion_adv=emotion_adv
)
# Forward Pass
out = self.t3.forward(
t3_cond=t3_cond,
text_tokens=text_tokens,
text_token_lens=text_token_lens,
speech_tokens=speech_tokens,
speech_token_lens=speech_token_lens,
training=True
)
IGNORE_ID = -100
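        # Teacher forcing: the logit at position t predicts the token at t+1, so drop the
        # last logit and the first label. The "- 1" below shifts each sequence length
        # accordingly so that padding positions are ignored in the loss.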
speech_logits = out.speech_logits[:, :-1, :].transpose(1, 2)
speech_labels = speech_tokens[:, 1:]
curr_speech_len = speech_labels.size(1)
mask_speech_pad = torch.arange(curr_speech_len, device=device)[None, :] >= (speech_token_lens[:, None] - 1)
        if self.cfg.is_turbo:
            speech_labels = speech_labels.masked_fill(mask_speech_pad, IGNORE_ID)
        else:
            # Also mask the conditioning-prompt region, using the actual prompt length in
            # this batch rather than the configured speech_cond_prompt_len.
            actual_prompt_len = prompt_tokens.size(1)
            mask_prompt = torch.arange(curr_speech_len, device=device)[None, :] < actual_prompt_len
            speech_labels = speech_labels.masked_fill(mask_speech_pad | mask_prompt, IGNORE_ID)
loss_speech = F.cross_entropy(speech_logits, speech_labels, ignore_index=IGNORE_ID)
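        # Text loss: same shift-and-mask scheme applied to the text token stream.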
text_logits = out.text_logits[:, :-1, :].transpose(1, 2)
text_labels = text_tokens[:, 1:]
curr_text_len = text_labels.size(1)
mask_text_pad = torch.arange(curr_text_len, device=device)[None, :] >= (text_token_lens[:, None] - 1)
text_labels = text_labels.masked_fill(mask_text_pad, IGNORE_ID)
loss_text = F.cross_entropy(text_logits, text_labels, ignore_index=IGNORE_ID)
total_loss = loss_text + loss_speech
        # Return a dict: the HuggingFace Trainer reads the "loss" key during training
        # and reports it as "eval_loss" during evaluation.
return {
"loss": total_loss,
"loss_text": loss_text.detach(),
"loss_speech": loss_speech.detach()
}
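
# Usage sketch (an addition for illustration): `t3`, `train_ds`, `eval_ds`, and `collate_fn`
# are assumed to come from the rest of this training codebase and are not defined here; the
# TrainingArguments values are illustrative only. Because forward() returns a dict with a
# "loss" key, the wrapper can be passed to transformers.Trainer as a plain nn.Module.
#
#   from transformers import Trainer, TrainingArguments
#
#   model = ChatterboxTrainerWrapper(t3)
#   args = TrainingArguments(
#       output_dir="checkpoints",
#       per_device_train_batch_size=4,
#       gradient_checkpointing=True,
#       remove_unused_columns=False,  # keep custom batch keys such as speaker_emb
#   )
#   trainer = Trainer(model=model, args=args, train_dataset=train_ds,
#                     eval_dataset=eval_ds, data_collator=collate_fn)
#   trainer.train()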