File size: 7,797 Bytes
67ea4ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
import os
import sys
import torch
from transformers import Trainer, TrainingArguments, EarlyStoppingCallback, TrainerCallback
from safetensors.torch import save_file

class ChatterboxTrainer(Trainer):
    """Trainer that also reports the model's text and speech sub-losses.

    The model's forward output is treated as a mapping when it is a ``dict``;
    otherwise element 0 is taken as the loss. When the mapping contains
    ``loss_text`` / ``loss_speech`` (assumed to be scalar tensors -- TODO
    confirm against the model wrapper), those values are:

    * during training, logged once per logging-interval global step;
    * during evaluation, averaged over batches and added to the loop's metrics
      as ``<metric_key_prefix>_loss_text`` / ``<metric_key_prefix>_loss_speech``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Per-batch sub-loss values gathered during the current evaluation loop.
        self._eval_loss_text = []
        self._eval_loss_speech = []
        # Global step whose training sub-losses have already been logged.
        # compute_loss runs once per micro-batch, so with gradient accumulation
        # the same step would otherwise be logged grad_accum times.
        self._sub_loss_logged_step = None

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Run the model, record its sub-losses, and return the total loss.

        ``num_items_in_batch`` is accepted to match the Trainer signature but is
        unused: the loss is taken directly from the model output.
        """
        outputs = model(**inputs)
        loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

        if isinstance(outputs, dict):
            if model.training:
                self._log_train_sub_losses(outputs)
            else:
                if "loss_text" in outputs:
                    self._eval_loss_text.append(outputs["loss_text"].item())
                if "loss_speech" in outputs:
                    self._eval_loss_speech.append(outputs["loss_speech"].item())

        return (loss, outputs) if return_outputs else loss

    def _log_train_sub_losses(self, outputs):
        """Log training sub-losses at most once per logging-interval global step."""
        step = self.state.global_step
        if step % self.args.logging_steps != 0 or step == self._sub_loss_logged_step:
            return
        logs = {
            key: outputs[key].item()
            for key in ("loss_text", "loss_speech")
            if key in outputs
        }
        if logs:
            self._sub_loss_logged_step = step
            # A single log call keeps both values in the same log-history entry.
            self.log(logs)

    def evaluation_loop(self, *args, **kwargs):
        """Run the standard evaluation loop and add averaged sub-loss metrics."""
        self._eval_loss_text = []
        self._eval_loss_speech = []
        output = super().evaluation_loop(*args, **kwargs)

        # Name the metrics with the caller's prefix ("eval" from evaluate(),
        # "test" from predict()); metric_key_prefix is the 5th positional
        # parameter of Trainer.evaluation_loop.
        prefix = kwargs.get("metric_key_prefix", args[4] if len(args) > 4 else "eval")
        for name, values in (
            ("loss_text", self._eval_loss_text),
            ("loss_speech", self._eval_loss_speech),
        ):
            if values:
                output.metrics[f"{prefix}_{name}"] = sum(values) / len(values)
        return output

# Internal Modules
from src.config import TrainConfig
from src.dataset import ChatterboxDataset, data_collator
from src.model import resize_and_load_t3_weights, ChatterboxTrainerWrapper
from src.preprocess_ljspeech import preprocess_dataset_ljspeech
from src.preprocess_file_based import preprocess_dataset_file_based
from src.utils import setup_logger, check_pretrained_models

# Chatterbox Imports
from src.chatterbox_.tts import ChatterboxTTS
from src.chatterbox_.tts_turbo import ChatterboxTurboTTS
from src.chatterbox_.models.t3.t3 import T3

# Disable HF tokenizers' internal parallelism (presumably to avoid its fork
# warning once the DataLoader starts worker processes).
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# W&B settings: respect anything the user has already exported.
# WANDB_API_KEY is deliberately NOT set here. A hard-coded value would override
# a real key from the environment or from `wandb login` and break
# authentication. Provide it via `export WANDB_API_KEY=...` or `wandb login`.
os.environ.setdefault("WANDB_PROJECT", "chatterbox-finetuning")

logger = setup_logger("ChatterboxFinetune")


def _build_finetune_t3(cfg, engine_class):
    """Build a T3 with the enlarged text vocabulary, initialised from the pretrained T3."""
    logger.info("Loading original model to extract weights...")
    # Loading on CPU first to save VRAM
    tts_engine_original = engine_class.from_local(cfg.model_dir, device="cpu")

    pretrained_t3_state_dict = tts_engine_original.t3.state_dict()
    # NOTE: this is the original engine's hp object itself, not a copy, and it
    # is mutated below. That is fine here because the original engine is
    # discarded when this function returns.
    new_t3_config = tts_engine_original.t3.hp

    logger.info(f"Creating new T3 model with vocab size: {cfg.new_vocab_size}")
    new_t3_config.text_tokens_dict_size = cfg.new_vocab_size
    # We prevent caching during training (assignment creates the attribute if absent).
    new_t3_config.use_cache = False

    new_t3_model = T3(hp=new_t3_config)

    logger.info("Transferring weights...")
    new_t3_model = resize_and_load_t3_weights(new_t3_model, pretrained_t3_state_dict)

    # --- SPECIAL SETTING FOR TURBO ---
    if cfg.is_turbo:
        logger.info("Turbo Mode: Removing backbone WTE layer...")
        if hasattr(new_t3_model.tfmr, "wte"):
            del new_t3_model.tfmr.wte

    # The original engine and its state dict go out of scope on return, so their
    # memory can be reclaimed before the caller loads the engine again.
    return new_t3_model


def _freeze_all_but_t3(engine):
    """Freeze the VoiceEncoder and S3Gen parameters and make T3 trainable."""
    logger.info("Freezing S3Gen and VoiceEncoder...")
    for frozen_module in (engine.ve, engine.s3gen):
        for param in frozen_module.parameters():
            param.requires_grad = False

    # Enable Training for T3
    engine.t3.train()
    for param in engine.t3.parameters():
        param.requires_grad = True


def _build_training_args(cfg):
    """Assemble the HuggingFace TrainingArguments used for T3 fine-tuning."""
    # TrainingArguments raises if bf16 is requested on hardware without bf16
    # support, so enable it only when the current GPU actually supports it.
    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

    return TrainingArguments(
        output_dir=cfg.output_dir,
        per_device_train_batch_size=cfg.batch_size,
        gradient_accumulation_steps=cfg.grad_accum,
        learning_rate=cfg.learning_rate,
        weight_decay=cfg.weight_decay,
        num_train_epochs=cfg.num_epochs,
        # `evaluation_strategy` was deprecated (transformers 4.41) in favour of
        # `eval_strategy` and later removed; the old name raises TypeError on the
        # recent releases whose compute_loss signature this file already uses.
        eval_strategy="epoch",
        save_strategy="epoch",
        logging_strategy="steps",
        logging_steps=10,
        remove_unused_columns=False,  # Required for our custom wrapper
        # Don't start more loader workers than the machine has CPUs.
        dataloader_num_workers=min(16, os.cpu_count() or 1),
        report_to=["wandb"],
        bf16=use_bf16,
        # Keep only the 5 most recent checkpoints; with load_best_model_at_end
        # the best checkpoint is also retained.
        save_total_limit=5,
        # Enabling this trades extra compute for lower activation memory.
        gradient_checkpointing=False,
        label_names=["speech_tokens", "text_tokens"],
        load_best_model_at_end=True,
        lr_scheduler_type="cosine",
        warmup_ratio=0.1,  # 10% warmup to handle English-to-Finnish transition
    )


def main():
    """Fine-tune the Chatterbox T3 model with an enlarged text vocabulary.

    Steps:
    1. Verify the pretrained files.
    2. Build a resized T3 from the pretrained weights.
    3. Inject it into a freshly loaded engine with VoiceEncoder/S3Gen frozen.
    4. Optionally preprocess the dataset.
    5. Train with ChatterboxTrainer.
    6. Write the final T3 weights as a safetensors file in ``cfg.output_dir``.

    Exits with status 1 if the pretrained model files are missing.
    """
    cfg = TrainConfig()

    logger.info("--- Starting Chatterbox Finetuning ---")
    logger.info(f"Mode: {'CHATTERBOX-TURBO' if cfg.is_turbo else 'CHATTERBOX-TTS'}")

    # 0. CHECK MODEL FILES
    mode_check = "chatterbox_turbo" if cfg.is_turbo else "chatterbox"
    if not check_pretrained_models(mode=mode_check):
        sys.exit(1)

    # Only reported here: models are loaded on CPU and placed by the Trainer.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 1. SELECT THE CORRECT ENGINE CLASS
    engine_class = ChatterboxTurboTTS if cfg.is_turbo else ChatterboxTTS

    logger.info(f"Device: {device}")
    logger.info(f"Model Directory: {cfg.model_dir}")

    # 2-4. BUILD THE RESIZED T3 AND TRANSFER THE PRETRAINED WEIGHTS
    new_t3_model = _build_finetune_t3(cfg, engine_class)

    # 5. PREPARE ENGINE FOR TRAINING
    # Reload engine components (VoiceEncoder, S3Gen) but inject our new T3
    tts_engine_new = engine_class.from_local(cfg.model_dir, device="cpu")
    tts_engine_new.t3 = new_t3_model
    _freeze_all_but_t3(tts_engine_new)

    if cfg.preprocess:
        logger.info("Initializing Preprocess dataset...")
        if cfg.ljspeech:
            preprocess_dataset_ljspeech(cfg, tts_engine_new)
        else:
            preprocess_dataset_file_based(cfg, tts_engine_new)
    else:
        logger.info("Skipping the preprocessing dataset step...")

    # 6. DATASET & WRAPPER
    logger.info("Initializing Datasets...")
    train_ds = ChatterboxDataset(cfg, split="train")
    val_ds = ChatterboxDataset(cfg, split="val")

    model_wrapper = ChatterboxTrainerWrapper(tts_engine_new.t3)

    # 7. TRAINING ARGUMENTS
    training_args = _build_training_args(cfg)

    trainer = ChatterboxTrainer(
        model=model_wrapper,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        data_collator=data_collator,
        callbacks=[],  # Early stopping intentionally disabled
    )

    logger.info("Running initial evaluation to establish baseline...")
    trainer.evaluate()

    logger.info("Starting Training Loop...")
    trainer.train()

    # 8. SAVE FINAL MODEL
    # With load_best_model_at_end=True the Trainer reloads the best checkpoint
    # into model_wrapper. That wrapper presumably holds tts_engine_new.t3 as a
    # submodule (TODO confirm), in which case the weights saved below are the
    # best ones.
    logger.info("Training complete. Saving model...")
    os.makedirs(cfg.output_dir, exist_ok=True)

    filename = "t3_turbo_finetuned.safetensors" if cfg.is_turbo else "t3_finetuned.safetensors"
    final_model_path = os.path.join(cfg.output_dir, filename)

    save_file(tts_engine_new.t3.state_dict(), final_model_path)
    logger.info(f"Model saved to: {final_model_path}")


if __name__ == "__main__":
    # Start fine-tuning only when run as a script, never on import.
    main()