import torch
import torch.nn.functional as F

from src.chatterbox_.models.t3.modules.cond_enc import T3Cond
from src.config import TrainConfig
from src.utils import setup_logger

logger = setup_logger(__name__)


def resize_and_load_t3_weights(new_model: torch.nn.Module, pretrained_state_dict: dict):
    """
    Load pretrained weights into a new T3 model with a different vocabulary size.

    New rows in the text embedding and output head are initialized with the
    average of the existing rows.
    """
    new_model_state_dict = new_model.state_dict()
    embedding_layer_name = "text_emb.weight"
    output_head_name = "text_head.weight"
    mean_init_applied = False

    # Step 1: Copy weights for all layers except the resized embedding/head
    for name, param in pretrained_state_dict.items():
        if name not in (embedding_layer_name, output_head_name):
            if name in new_model_state_dict and new_model_state_dict[name].shape == param.shape:
                new_model_state_dict[name].copy_(param)
            else:
                logger.warning(f"Layer skipped (missing or shape mismatch): {name}")
    # Step 2: Smart copy for the embedding layer (average init for new rows)
    if embedding_layer_name in pretrained_state_dict:
        old_emb_weights = pretrained_state_dict[embedding_layer_name]
        old_vocab_size, _ = old_emb_weights.shape
        new_vocab_size = new_model_state_dict[embedding_layer_name].shape[0]

        # A) Copy the old rows
        new_model_state_dict[embedding_layer_name][:old_vocab_size, :].copy_(old_emb_weights)
        logger.info(f"Embedding layer: {old_vocab_size} tokens preserved.")

        # B) Initialize new tokens with the average embedding
        if new_vocab_size > old_vocab_size:
            mean_emb = old_emb_weights.mean(dim=0)
            num_new_tokens = new_vocab_size - old_vocab_size
            new_model_state_dict[embedding_layer_name][old_vocab_size:, :].copy_(
                mean_emb.unsqueeze(0).expand(num_new_tokens, -1)
            )
            logger.info(f"Embedding layer: {num_new_tokens} new tokens initialized with mean.")
            mean_init_applied = True
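
    # Note: mean initialization keeps the new rows close to the existing
    # weight distribution, which avoids the initial loss spike that
    # randomly initialized vocabulary rows typically cause.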
    # Step 3: Smart copy for the output head (average init for new rows)
    if output_head_name in pretrained_state_dict:
        old_head_weights = pretrained_state_dict[output_head_name]
        old_vocab_size, _ = old_head_weights.shape
        new_vocab_size = new_model_state_dict[output_head_name].shape[0]

        # A) Copy the old weights
        new_model_state_dict[output_head_name][:old_vocab_size, :].copy_(old_head_weights)
        logger.info(f"Output head: {old_vocab_size} tokens preserved.")

        # B) Initialize new output neurons with the average row
        if new_vocab_size > old_vocab_size:
            mean_head = old_head_weights.mean(dim=0)
            num_new_tokens = new_vocab_size - old_vocab_size
            new_model_state_dict[output_head_name][old_vocab_size:, :].copy_(
                mean_head.unsqueeze(0).expand(num_new_tokens, -1)
            )
            logger.info(f"Output head: {num_new_tokens} new neurons initialized with mean.")
            mean_init_applied = True
    # Step 4: Load the updated state dict into the new model
    new_model.load_state_dict(new_model_state_dict)
    if mean_init_applied:
        logger.info("All weights transferred successfully (with mean initialization for new tokens)!")
    else:
        logger.info("All weights transferred successfully (direct copy, no resizing needed)!")
    return new_model
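

# Usage sketch (illustrative only): `T3`, its import path, the hyperparameter
# object `hp`, and the checkpoint filename are assumptions, not part of this
# module.
#
#   from src.chatterbox_.models.t3.t3 import T3  # hypothetical import path
#
#   new_t3 = T3(hp)  # hp configured with the enlarged text vocabulary
#   pretrained = torch.load("t3_base.pt", map_location="cpu")
#   new_t3 = resize_and_load_t3_weights(new_t3, pretrained)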


class ChatterboxTrainerWrapper(torch.nn.Module):
    """
    Wrapper that computes the loss inside forward() so the model can be used
    directly with the HuggingFace Trainer.
    """

    def __init__(self, t3_model):
        super().__init__()
        self.t3 = t3_model
        self.cfg = TrainConfig()
        if hasattr(t3_model.hp, 'speech_cond_prompt_len'):
            self.prompt_token_len = t3_model.hp.speech_cond_prompt_len
        else:
            self.prompt_token_len = 150  # fallback when hp does not define a prompt length

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        self.t3.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)

    def get_input_embeddings(self):
        return self.t3.get_input_embeddings()

    def forward(
        self,
        text_tokens,
        text_token_lens,
        speech_tokens,
        speech_token_lens,
        speaker_emb,
        prompt_tokens,
    ):
        device = text_tokens.device
        batch_size = text_tokens.size(0)

        # Constant emotion conditioning (0.5) for every training sample
        emotion_adv = 0.5 * torch.ones(batch_size, 1, 1, device=device)
        t3_cond = T3Cond(
            speaker_emb=speaker_emb,
            cond_prompt_speech_tokens=prompt_tokens,
            emotion_adv=emotion_adv,
        )

        # Forward pass
        out = self.t3.forward(
            t3_cond=t3_cond,
            text_tokens=text_tokens,
            text_token_lens=text_token_lens,
            speech_tokens=speech_tokens,
            speech_token_lens=speech_token_lens,
            training=True,
        )
        IGNORE_ID = -100  # label value excluded from the cross-entropy loss

        # Speech loss: next-token prediction, so logits and labels are shifted by one
        speech_logits = out.speech_logits[:, :-1, :].transpose(1, 2)  # (B, vocab, T-1)
        speech_labels = speech_tokens[:, 1:]                          # (B, T-1)
        curr_speech_len = speech_labels.size(1)

        # After the shift, only the first (len - 1) positions hold valid labels
        mask_speech_pad = torch.arange(curr_speech_len, device=device)[None, :] >= (speech_token_lens[:, None] - 1)
        if self.cfg.is_turbo:
            speech_labels = speech_labels.masked_fill(mask_speech_pad, IGNORE_ID)
        else:
            # Also mask the conditioning prompt so it does not contribute to the
            # loss; use the actual prompt length rather than the fixed
            # self.prompt_token_len previously assumed here.
            actual_prompt_len = prompt_tokens.size(1)
            mask_prompt = torch.arange(curr_speech_len, device=device)[None, :] < actual_prompt_len
            speech_labels = speech_labels.masked_fill(mask_speech_pad | mask_prompt, IGNORE_ID)
        loss_speech = F.cross_entropy(speech_logits, speech_labels, ignore_index=IGNORE_ID)
        # Text loss: same shift-and-mask scheme as the speech loss
        text_logits = out.text_logits[:, :-1, :].transpose(1, 2)
        text_labels = text_tokens[:, 1:]
        curr_text_len = text_labels.size(1)
        mask_text_pad = torch.arange(curr_text_len, device=device)[None, :] >= (text_token_lens[:, None] - 1)
        text_labels = text_labels.masked_fill(mask_text_pad, IGNORE_ID)
        loss_text = F.cross_entropy(text_logits, text_labels, ignore_index=IGNORE_ID)

        total_loss = loss_text + loss_speech

        # Return a dict: the HuggingFace Trainer reads "loss" during training
        # and reports it as "eval_loss" during evaluation.
        return {
            "loss": total_loss,
            "loss_text": loss_text.detach(),
            "loss_speech": loss_speech.detach(),
        }
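

# Trainer hookup sketch (illustrative only): the dataset, collator, and
# argument values are assumptions, not defined in this module. Setting
# remove_unused_columns=False is the usual safeguard so the Trainer does not
# drop dataset columns before they reach forward().
#
#   from transformers import Trainer, TrainingArguments
#
#   wrapper = ChatterboxTrainerWrapper(t3_model)
#   trainer = Trainer(
#       model=wrapper,
#       args=TrainingArguments(output_dir="out", remove_unused_columns=False),
#       train_dataset=train_dataset,  # must yield the forward() keyword args
#       data_collator=collate_fn,     # hypothetical collator building batched tensors
#   )
#   trainer.train()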