File size: 6,841 Bytes
308155b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import torch
import torch.nn.functional as F
from src.chatterbox_.models.t3.modules.cond_enc import T3Cond
from src.config import TrainConfig
from src.utils import setup_logger


logger = setup_logger(__name__)


def resize_and_load_t3_weights(new_model: torch.nn.Module, pretrained_state_dict: dict):
    """
    Load pretrained weights into a new T3 model whose text vocabulary size
    may differ from the checkpoint's.

    Every layer except the text embedding and output head is copied directly
    when shapes match (mismatches are skipped with a warning). For the two
    vocabulary-sized tensors, the overlapping rows are preserved and any
    newly added rows are initialized with the AVERAGE of the pretrained rows.

    Args:
        new_model: Freshly constructed T3 model (possibly with a resized vocab).
        pretrained_state_dict: State dict of the pretrained checkpoint.

    Returns:
        ``new_model`` with the transferred weights loaded.
    """
    new_model_state_dict = new_model.state_dict()

    # (state-dict key, log label, log noun for newly created rows)
    vocab_specs = (
        ("text_emb.weight", "Embedding layer", "tokens"),
        ("text_head.weight", "Output head", "neurons"),
    )
    vocab_keys = tuple(spec[0] for spec in vocab_specs)

    mean_init_applied = False

    # Step 1: Copy weights for ALL matching non-vocabulary layers.
    for name, param in pretrained_state_dict.items():
        if name in vocab_keys:
            continue
        if name in new_model_state_dict and new_model_state_dict[name].shape == param.shape:
            new_model_state_dict[name].copy_(param)
        else:
            logger.warning(f"Layer skipped (mismatch): {name}")

    # Step 2: Smart copy for the vocabulary-sized tensors (Average Init).
    for key, label, noun in vocab_specs:
        if key not in pretrained_state_dict:
            continue

        old_weights = pretrained_state_dict[key]
        old_vocab_size = old_weights.shape[0]
        new_vocab_size = new_model_state_dict[key].shape[0]

        # A) Preserve the overlapping rows. min() guards against a *shrunken*
        #    vocabulary, which would otherwise overrun the destination slice.
        num_copied = min(old_vocab_size, new_vocab_size)
        new_model_state_dict[key][:num_copied, :].copy_(old_weights[:num_copied, :])
        logger.info(f"{label}: {num_copied} tokens preserved.")

        # B) Initialize newly added rows with the mean of the pretrained rows —
        #    a common warm-start that keeps new logits in-distribution.
        if new_vocab_size > old_vocab_size:
            mean_row = old_weights.mean(dim=0)
            num_new_tokens = new_vocab_size - old_vocab_size
            new_model_state_dict[key][old_vocab_size:, :].copy_(
                mean_row.unsqueeze(0).expand(num_new_tokens, -1)
            )
            logger.info(f"{label}: {num_new_tokens} new {noun} initialized with mean.")
            mean_init_applied = True

    # Step 3: Load the merged state dict back into the new model.
    new_model.load_state_dict(new_model_state_dict)

    if mean_init_applied:
        logger.info("All weights transferred successfully (with mean initialization for new tokens)!")
    else:
        logger.info("All weights transferred successfully (direct copy, no resizing needed)!")

    return new_model


class ChatterboxTrainerWrapper(torch.nn.Module):
    """
    Wrapper that computes the loss inside the forward pass so the T3 model
    can be driven directly by the HuggingFace Trainer.

    The wrapped model produces text and speech logits; ``forward`` builds the
    conditioning object, runs the model, masks padding (and — in non-turbo
    mode — the speech prompt region) out of the labels, and returns the
    combined cross-entropy loss in the dict format the Trainer expects.
    """

    def __init__(self, t3_model):
        """
        Args:
            t3_model: The T3 model to wrap; must expose ``hp``,
                ``gradient_checkpointing_enable`` and ``get_input_embeddings``.
        """
        super().__init__()
        self.t3 = t3_model

        self.cfg = TrainConfig()

        # Configured speech-prompt length; falls back to 150 when the model
        # hyperparameters do not define it.
        self.prompt_token_len = getattr(t3_model.hp, 'speech_cond_prompt_len', 150)

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        # Delegate so the Trainer's checkpointing hook reaches the real model.
        self.t3.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)

    def get_input_embeddings(self):
        return self.t3.get_input_embeddings()

    def forward(
            self,
            text_tokens, 
            text_token_lens,
            speech_tokens, 
            speech_token_lens,
            speaker_emb, 
            prompt_tokens):
        """
        Run the T3 forward pass and compute text + speech losses.

        Returns:
            dict with "loss" (text + speech cross-entropy, used by the
            Trainer for backprop) and detached "loss_text"/"loss_speech"
            entries for logging.
        """
        device = text_tokens.device
        batch_size = text_tokens.size(0)

        # Fixed emotion conditioning of 0.5 for every sample.
        # NOTE(review): hard-coded — confirm it matches the inference default.
        emotion_adv = torch.full((batch_size, 1, 1), 0.5, device=device)

        t3_cond = T3Cond(
            speaker_emb=speaker_emb,
            cond_prompt_speech_tokens=prompt_tokens,
            emotion_adv=emotion_adv
        )

        # Forward Pass
        out = self.t3.forward(
            t3_cond=t3_cond,
            text_tokens=text_tokens,
            text_token_lens=text_token_lens,
            speech_tokens=speech_tokens,
            speech_token_lens=speech_token_lens,
            training=True
        )

        IGNORE_ID = -100  # matches F.cross_entropy's default ignore_index

        # Next-token prediction: logits at step t predict token t+1.
        # transpose(1, 2) yields the (N, C, T) layout cross_entropy expects.
        speech_logits = out.speech_logits[:, :-1, :].transpose(1, 2)
        speech_labels = speech_tokens[:, 1:] 

        curr_speech_len = speech_labels.size(1)
        # Positions at/after each sequence's (shifted) length are padding.
        mask_speech_pad = torch.arange(curr_speech_len, device=device)[None, :] >= (speech_token_lens[:, None] - 1)

        if self.cfg.is_turbo:
            speech_labels = speech_labels.masked_fill(mask_speech_pad, IGNORE_ID)
        else:
            # Non-turbo: also exclude the conditioning-prompt tokens at the
            # start of the speech sequence from the loss. Use this batch's
            # actual prompt length rather than the configured maximum.
            actual_prompt_len = prompt_tokens.size(1)
            mask_prompt = torch.arange(curr_speech_len, device=device)[None, :] < actual_prompt_len

            speech_labels = speech_labels.masked_fill(mask_speech_pad | mask_prompt, IGNORE_ID)

        loss_speech = F.cross_entropy(speech_logits, speech_labels, ignore_index=IGNORE_ID)

        # Same shift-and-mask scheme for the text stream (no prompt region).
        text_logits = out.text_logits[:, :-1, :].transpose(1, 2)
        text_labels = text_tokens[:, 1:]

        curr_text_len = text_labels.size(1)
        mask_text_pad = torch.arange(curr_text_len, device=device)[None, :] >= (text_token_lens[:, None] - 1)

        text_labels = text_labels.masked_fill(mask_text_pad, IGNORE_ID)

        loss_text = F.cross_entropy(text_logits, text_labels, ignore_index=IGNORE_ID)

        total_loss = loss_text + loss_speech

        # Return as dictionary - Trainer expects this format
        # During training: uses "loss", during eval: uses "eval_loss"
        return {
            "loss": total_loss,
            "loss_text": loss_text.detach(),
            "loss_speech": loss_speech.detach()
        }