import torch
import torch.nn.functional as F

from src.chatterbox_.models.t3.modules.cond_enc import T3Cond
from src.config import TrainConfig
from src.utils import setup_logger

logger = setup_logger(__name__)


def resize_and_load_t3_weights(new_model: torch.nn.Module, pretrained_state_dict: dict):
    """
    Load pretrained weights into a new T3 model with a different vocabulary size.
    Rows for newly added tokens (text embedding and text head) are initialized
    with the mean of the existing token rows.
    """
    new_model_state_dict = new_model.state_dict()

    embedding_layer_name = "text_emb.weight"
    output_head_name = "text_head.weight"

    mean_init_applied = False

    # Copy every layer except the text embedding and text head verbatim,
    # as long as the shapes match.
    for name, param in pretrained_state_dict.items():
        if name not in [embedding_layer_name, output_head_name]:
            if name in new_model_state_dict and new_model_state_dict[name].shape == param.shape:
                new_model_state_dict[name].copy_(param)
            else:
                logger.warning(f"Layer skipped (mismatch): {name}")

    # Text embedding: keep the pretrained rows, mean-initialize any new ones.
    if embedding_layer_name in pretrained_state_dict:
        old_emb_weights = pretrained_state_dict[embedding_layer_name]
        old_vocab_size, _ = old_emb_weights.shape
        new_vocab_size = new_model_state_dict[embedding_layer_name].shape[0]

        new_model_state_dict[embedding_layer_name][:old_vocab_size, :].copy_(old_emb_weights)
        logger.info(f"Embedding layer: {old_vocab_size} tokens preserved.")

        if new_vocab_size > old_vocab_size:
            mean_emb = old_emb_weights.mean(dim=0)
            num_new_tokens = new_vocab_size - old_vocab_size
            new_model_state_dict[embedding_layer_name][old_vocab_size:, :].copy_(
                mean_emb.unsqueeze(0).expand(num_new_tokens, -1)
            )
            logger.info(f"Embedding layer: {num_new_tokens} new tokens initialized with mean.")
            mean_init_applied = True

    # Output head: same treatment as the embedding layer.
    if output_head_name in pretrained_state_dict:
        old_head_weights = pretrained_state_dict[output_head_name]
        old_vocab_size, _ = old_head_weights.shape
        new_vocab_size = new_model_state_dict[output_head_name].shape[0]

        new_model_state_dict[output_head_name][:old_vocab_size, :].copy_(old_head_weights)
        logger.info(f"Output head: {old_vocab_size} tokens preserved.")

        if new_vocab_size > old_vocab_size:
            mean_head = old_head_weights.mean(dim=0)
            num_new_tokens = new_vocab_size - old_vocab_size
            new_model_state_dict[output_head_name][old_vocab_size:, :].copy_(
                mean_head.unsqueeze(0).expand(num_new_tokens, -1)
            )
            logger.info(f"Output head: {num_new_tokens} new neurons initialized with mean.")
            mean_init_applied = True

    new_model.load_state_dict(new_model_state_dict)

    if mean_init_applied:
        logger.info("All weights transferred successfully (with mean initialization for new tokens)!")
    else:
        logger.info("All weights transferred successfully (direct copy, no resizing needed)!")

    return new_model
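

# A minimal usage sketch (hypothetical: the T3 constructor and the checkpoint
# loading shown here are assumptions, not part of this module):
#
#     new_t3 = T3(hp_with_extended_vocab)
#     pretrained_sd = torch.load("t3_checkpoint.pt", map_location="cpu")
#     new_t3 = resize_and_load_t3_weights(new_t3, pretrained_sd)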


class ChatterboxTrainerWrapper(torch.nn.Module):
    """
    Wrapper that computes the loss inside the forward pass so the T3 model can be
    trained with the HuggingFace Trainer.
    """

    def __init__(self, t3_model):
        super().__init__()
        self.t3 = t3_model
        self.cfg = TrainConfig()

        # Length of the conditioning speech prompt; fall back to 150 tokens if the
        # hyperparameter is not defined.
        if hasattr(t3_model.hp, 'speech_cond_prompt_len'):
            self.prompt_token_len = t3_model.hp.speech_cond_prompt_len
        else:
            self.prompt_token_len = 150

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        self.t3.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)

    def get_input_embeddings(self):
        return self.t3.get_input_embeddings()

    def forward(
        self,
        text_tokens,
        text_token_lens,
        speech_tokens,
        speech_token_lens,
        speaker_emb,
        prompt_tokens,
    ):
        device = text_tokens.device
        batch_size = text_tokens.size(0)

        # Constant emotion conditioning value (0.5) for every sample in the batch.
        emotion_adv = 0.5 * torch.ones(batch_size, 1, 1, device=device)

        t3_cond = T3Cond(
            speaker_emb=speaker_emb,
            cond_prompt_speech_tokens=prompt_tokens,
            emotion_adv=emotion_adv,
        )

        out = self.t3.forward(
            t3_cond=t3_cond,
            text_tokens=text_tokens,
            text_token_lens=text_token_lens,
            speech_tokens=speech_tokens,
            speech_token_lens=speech_token_lens,
            training=True,
        )
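
        # Loss targets follow standard next-token prediction: logits at position t
        # predict the token at t + 1, so logits drop the last step and labels drop
        # the first. Positions that should not contribute to the loss (padding and,
        # in the non-turbo case, the conditioning prompt) are set to IGNORE_ID.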
        IGNORE_ID = -100

        speech_logits = out.speech_logits[:, :-1, :].transpose(1, 2)
        speech_labels = speech_tokens[:, 1:]

        curr_speech_len = speech_labels.size(1)
        mask_speech_pad = torch.arange(curr_speech_len, device=device)[None, :] >= (speech_token_lens[:, None] - 1)

        if self.cfg.is_turbo:
            speech_labels = speech_labels.masked_fill(mask_speech_pad, IGNORE_ID)
        else:
            # Also exclude the first `actual_prompt_len` positions (the conditioning
            # prompt tokens) from the speech loss.
            actual_prompt_len = prompt_tokens.size(1)
            mask_prompt = torch.arange(curr_speech_len, device=device)[None, :] < actual_prompt_len
            speech_labels = speech_labels.masked_fill(mask_speech_pad | mask_prompt, IGNORE_ID)

        loss_speech = F.cross_entropy(speech_logits, speech_labels, ignore_index=IGNORE_ID)

        # Same shift-and-mask scheme for the text stream.
        text_logits = out.text_logits[:, :-1, :].transpose(1, 2)
        text_labels = text_tokens[:, 1:]

        curr_text_len = text_labels.size(1)
        mask_text_pad = torch.arange(curr_text_len, device=device)[None, :] >= (text_token_lens[:, None] - 1)

        text_labels = text_labels.masked_fill(mask_text_pad, IGNORE_ID)

        loss_text = F.cross_entropy(text_logits, text_labels, ignore_index=IGNORE_ID)

        total_loss = loss_text + loss_speech

        return {
            "loss": total_loss,
            "loss_text": loss_text.detach(),
            "loss_speech": loss_speech.detach(),
        }
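

# A minimal sketch of how this wrapper might be handed to the HuggingFace Trainer.
# The dataset, data collator, and TrainingArguments values below are hypothetical;
# the collator is assumed to yield batches keyed exactly like forward()'s arguments.
#
#     from transformers import Trainer, TrainingArguments
#
#     trainer = Trainer(
#         model=ChatterboxTrainerWrapper(t3_model),
#         args=TrainingArguments(output_dir="out", remove_unused_columns=False),
#         train_dataset=train_dataset,
#         data_collator=collator,
#     )
#     trainer.train()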