# NOTE(review): the following header was extraction residue (file size, commit
# hash, and a line-number dump) that is not valid Python; converted to a comment.
# File size: 7,797 Bytes | commit 67ea4ca
import os
import sys
import torch
from transformers import Trainer, TrainingArguments, EarlyStoppingCallback, TrainerCallback
from safetensors.torch import save_file
class ChatterboxTrainer(Trainer):
    """Custom Trainer that logs the text/speech sub-losses for train and eval.

    The wrapped model is expected to return a dict containing "loss" and,
    optionally, "loss_text" / "loss_speech" entries (see compute_loss).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Per-batch sub-losses accumulated during an evaluation pass.
        self._eval_loss_text = []
        self._eval_loss_speech = []

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Forward pass; logs sub-losses while training, collects them in eval."""
        outputs = model(**inputs)
        loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
        if isinstance(outputs, dict):
            if model.training:
                # Log sub-losses at the same cadence as the main training loss.
                if self.state.global_step % self.args.logging_steps == 0:
                    if "loss_text" in outputs:
                        self.log({"loss_text": outputs["loss_text"].item()})
                    if "loss_speech" in outputs:
                        self.log({"loss_speech": outputs["loss_speech"].item()})
            else:
                if "loss_text" in outputs:
                    self._eval_loss_text.append(outputs["loss_text"].item())
                if "loss_speech" in outputs:
                    self._eval_loss_speech.append(outputs["loss_speech"].item())
        return (loss, outputs) if return_outputs else loss

    def evaluation_loop(self, *args, **kwargs):
        """Run the standard evaluation loop, then attach averaged sub-losses.

        Fix: the metric keys now honor the caller's ``metric_key_prefix``
        (e.g. ``"test"`` when invoked via ``Trainer.predict``) instead of
        being hard-coded to ``"eval"``.
        """
        self._eval_loss_text = []
        self._eval_loss_speech = []
        output = super().evaluation_loop(*args, **kwargs)
        # Trainer passes metric_key_prefix as a keyword argument; default
        # matches the base-class default of "eval".
        prefix = kwargs.get("metric_key_prefix", "eval")
        if self._eval_loss_text:
            output.metrics[f"{prefix}_loss_text"] = sum(self._eval_loss_text) / len(self._eval_loss_text)
        if self._eval_loss_speech:
            output.metrics[f"{prefix}_loss_speech"] = sum(self._eval_loss_speech) / len(self._eval_loss_speech)
        return output
# Internal Modules
from src.config import TrainConfig
from src.dataset import ChatterboxDataset, data_collator
from src.model import resize_and_load_t3_weights, ChatterboxTrainerWrapper
from src.preprocess_ljspeech import preprocess_dataset_ljspeech
from src.preprocess_file_based import preprocess_dataset_file_based
from src.utils import setup_logger, check_pretrained_models
# Chatterbox Imports
from src.chatterbox_.tts import ChatterboxTTS
from src.chatterbox_.tts_turbo import ChatterboxTurboTTS
from src.chatterbox_.models.t3.t3 import T3
# Silence HF tokenizers' fork warnings when dataloader workers are used.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Use setdefault so real wandb credentials/config already present in the
# environment are not clobbered by these in-file placeholders.
os.environ.setdefault("WANDB_API_KEY", "INSERT_API_KEY_HERE")
os.environ.setdefault("WANDB_PROJECT", "chatterbox-finetuning")
logger = setup_logger("ChatterboxFinetune")
def main():
    """End-to-end finetuning entry point.

    Loads the pretrained Chatterbox (or Turbo) engine, rebuilds T3 with a
    resized text vocabulary, trains it with a HF Trainer while keeping the
    other components frozen, and saves the finetuned T3 weights.
    """
    cfg = TrainConfig()
    logger.info("--- Starting Chatterbox Finetuning ---")
    logger.info(f"Mode: {'CHATTERBOX-TURBO' if cfg.is_turbo else 'CHATTERBOX-TTS'}")

    # 0. CHECK MODEL FILES
    mode_check = "chatterbox_turbo" if cfg.is_turbo else "chatterbox"
    if not check_pretrained_models(mode=mode_check):
        sys.exit(1)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 1. SELECT THE CORRECT ENGINE CLASS
    EngineClass = ChatterboxTurboTTS if cfg.is_turbo else ChatterboxTTS
    logger.info(f"Device: {device}")
    logger.info(f"Model Directory: {cfg.model_dir}")

    # 2. LOAD ORIGINAL MODEL TEMPORARILY (on CPU first to save VRAM)
    logger.info("Loading original model to extract weights...")
    tts_engine_original = EngineClass.from_local(cfg.model_dir, device="cpu")
    pretrained_t3_state_dict = tts_engine_original.t3.state_dict()
    original_t3_config = tts_engine_original.t3.hp

    # 3. CREATE NEW T3 MODEL WITH NEW VOCAB SIZE
    # NOTE(review): this aliases (does not copy) the original hp object, so the
    # mutations below also affect tts_engine_original.t3.hp. Harmless today
    # because the original engine is deleted just below — confirm if that
    # engine is ever reused before deletion.
    logger.info(f"Creating new T3 model with vocab size: {cfg.new_vocab_size}")
    new_t3_config = original_t3_config
    new_t3_config.text_tokens_dict_size = cfg.new_vocab_size
    # Disable KV caching during training. Plain assignment covers both the
    # "attribute exists" and "attribute missing" cases, so the previous
    # hasattr/else-setattr branches (which did the same thing) are gone.
    new_t3_config.use_cache = False
    new_t3_model = T3(hp=new_t3_config)

    # 4. TRANSFER WEIGHTS into the resized model
    logger.info("Transferring weights...")
    new_t3_model = resize_and_load_t3_weights(new_t3_model, pretrained_t3_state_dict)

    # --- SPECIAL SETTING FOR TURBO ---
    if cfg.is_turbo:
        logger.info("Turbo Mode: Removing backbone WTE layer...")
        if hasattr(new_t3_model.tfmr, "wte"):
            del new_t3_model.tfmr.wte

    # Clean up memory before loading the second engine copy.
    del tts_engine_original
    del pretrained_t3_state_dict

    # 5. PREPARE ENGINE FOR TRAINING
    # Reload engine components (VoiceEncoder, S3Gen) but inject our new T3.
    tts_engine_new = EngineClass.from_local(cfg.model_dir, device="cpu")
    tts_engine_new.t3 = new_t3_model

    # Freeze everything except T3.
    logger.info("Freezing S3Gen and VoiceEncoder...")
    for param in tts_engine_new.ve.parameters():
        param.requires_grad = False
    for param in tts_engine_new.s3gen.parameters():
        param.requires_grad = False

    # Enable training for T3 only.
    tts_engine_new.t3.train()
    for param in tts_engine_new.t3.parameters():
        param.requires_grad = True

    if cfg.preprocess:
        logger.info("Initializing Preprocess dataset...")
        if cfg.ljspeech:
            preprocess_dataset_ljspeech(cfg, tts_engine_new)
        else:
            preprocess_dataset_file_based(cfg, tts_engine_new)
    else:
        logger.info("Skipping the preprocessing dataset step...")

    # 6. DATASET & WRAPPER
    logger.info("Initializing Datasets...")
    train_ds = ChatterboxDataset(cfg, split="train")
    val_ds = ChatterboxDataset(cfg, split="val")
    model_wrapper = ChatterboxTrainerWrapper(tts_engine_new.t3)

    # 7. TRAINING ARGUMENTS
    training_args = TrainingArguments(
        output_dir=cfg.output_dir,
        per_device_train_batch_size=cfg.batch_size,
        gradient_accumulation_steps=cfg.grad_accum,
        learning_rate=cfg.learning_rate,
        weight_decay=cfg.weight_decay,
        num_train_epochs=cfg.num_epochs,
        # NOTE(review): recent transformers releases renamed this parameter to
        # `eval_strategy` — confirm against the pinned transformers version.
        evaluation_strategy="epoch",
        save_strategy="epoch",
        logging_strategy="steps",
        logging_steps=10,
        remove_unused_columns=False,  # required: the custom wrapper consumes all columns
        dataloader_num_workers=16,
        report_to=["wandb"],
        bf16=torch.cuda.is_available(),  # bf16 on GPU (e.g. A100), fp32 on CPU
        save_total_limit=5,  # keep only the 5 most recent checkpoints
        gradient_checkpointing=False,  # enable to trade compute for lower VRAM
        label_names=["speech_tokens", "text_tokens"],
        load_best_model_at_end=True,
        lr_scheduler_type="cosine",
        warmup_ratio=0.1,  # 10% warmup
    )

    trainer = ChatterboxTrainer(
        model=model_wrapper,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        data_collator=data_collator,
        callbacks=[],  # early stopping intentionally disabled
    )

    logger.info("Running initial evaluation to establish baseline...")
    trainer.evaluate()

    logger.info("Starting Training Loop...")
    trainer.train()

    # 8. SAVE FINAL MODEL (T3 weights only; VE/S3Gen are unchanged)
    logger.info("Training complete. Saving model...")
    os.makedirs(cfg.output_dir, exist_ok=True)
    filename = "t3_turbo_finetuned.safetensors" if cfg.is_turbo else "t3_finetuned.safetensors"
    final_model_path = os.path.join(cfg.output_dir, filename)
    save_file(tts_engine_new.t3.state_dict(), final_model_path)
    logger.info(f"Model saved to: {final_model_path}")


if __name__ == "__main__":
    main()
|