import os import sys import torch from transformers import Trainer, TrainingArguments, EarlyStoppingCallback, TrainerCallback from safetensors.torch import save_file class ChatterboxTrainer(Trainer): """Custom Trainer to log sub-losses for both train and eval.""" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._eval_loss_text = [] self._eval_loss_speech = [] def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): outputs = model(**inputs) loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] if isinstance(outputs, dict): if model.training: if self.state.global_step % self.args.logging_steps == 0: if "loss_text" in outputs: self.log({"loss_text": outputs["loss_text"].item()}) if "loss_speech" in outputs: self.log({"loss_speech": outputs["loss_speech"].item()}) else: if "loss_text" in outputs: self._eval_loss_text.append(outputs["loss_text"].item()) if "loss_speech" in outputs: self._eval_loss_speech.append(outputs["loss_speech"].item()) return (loss, outputs) if return_outputs else loss def evaluation_loop(self, *args, **kwargs): self._eval_loss_text = [] self._eval_loss_speech = [] output = super().evaluation_loop(*args, **kwargs) if self._eval_loss_text: output.metrics["eval_loss_text"] = sum(self._eval_loss_text) / len(self._eval_loss_text) if self._eval_loss_speech: output.metrics["eval_loss_speech"] = sum(self._eval_loss_speech) / len(self._eval_loss_speech) return output # Internal Modules from src.config import TrainConfig from src.dataset import ChatterboxDataset, data_collator from src.model import resize_and_load_t3_weights, ChatterboxTrainerWrapper from src.preprocess_ljspeech import preprocess_dataset_ljspeech from src.preprocess_file_based import preprocess_dataset_file_based from src.utils import setup_logger, check_pretrained_models # Chatterbox Imports from src.chatterbox_.tts import ChatterboxTTS from src.chatterbox_.tts_turbo import ChatterboxTurboTTS from src.chatterbox_.models.t3.t3 import T3 os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["WANDB_API_KEY"] = "INSERT_API_KEY_HERE" os.environ["WANDB_PROJECT"] = "chatterbox-finetuning" logger = setup_logger("ChatterboxFinetune") def main(): cfg = TrainConfig() logger.info("--- Starting Chatterbox Finetuning ---") logger.info(f"Mode: {'CHATTERBOX-TURBO' if cfg.is_turbo else 'CHATTERBOX-TTS'}") # 0. CHECK MODEL FILES mode_check = "chatterbox_turbo" if cfg.is_turbo else "chatterbox" if not check_pretrained_models(mode=mode_check): sys.exit(1) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 1. SELECT THE CORRECT ENGINE CLASS if cfg.is_turbo: EngineClass = ChatterboxTurboTTS else: EngineClass = ChatterboxTTS logger.info(f"Device: {device}") logger.info(f"Model Directory: {cfg.model_dir}") # 2. LOAD ORIGINAL MODEL TEMPORARILY logger.info("Loading original model to extract weights...") # Loading on CPU first to save VRAM tts_engine_original = EngineClass.from_local(cfg.model_dir, device="cpu") pretrained_t3_state_dict = tts_engine_original.t3.state_dict() original_t3_config = tts_engine_original.t3.hp # 3. CREATE NEW T3 MODEL WITH NEW VOCAB SIZE logger.info(f"Creating new T3 model with vocab size: {cfg.new_vocab_size}") new_t3_config = original_t3_config new_t3_config.text_tokens_dict_size = cfg.new_vocab_size # We prevent caching during training. if hasattr(new_t3_config, "use_cache"): new_t3_config.use_cache = False else: setattr(new_t3_config, "use_cache", False) new_t3_model = T3(hp=new_t3_config) # 4. TRANSFER WEIGHTS logger.info("Transferring weights...") new_t3_model = resize_and_load_t3_weights(new_t3_model, pretrained_t3_state_dict) # --- SPECIAL SETTING FOR TURBO --- if cfg.is_turbo: logger.info("Turbo Mode: Removing backbone WTE layer...") if hasattr(new_t3_model.tfmr, "wte"): del new_t3_model.tfmr.wte # Clean up memory del tts_engine_original del pretrained_t3_state_dict # 5. PREPARE ENGINE FOR TRAINING # Reload engine components (VoiceEncoder, S3Gen) but inject our new T3 tts_engine_new = EngineClass.from_local(cfg.model_dir, device="cpu") tts_engine_new.t3 = new_t3_model # Freeze other components logger.info("Freezing S3Gen and VoiceEncoder...") for param in tts_engine_new.ve.parameters(): param.requires_grad = False for param in tts_engine_new.s3gen.parameters(): param.requires_grad = False # Enable Training for T3 tts_engine_new.t3.train() for param in tts_engine_new.t3.parameters(): param.requires_grad = True if cfg.preprocess: logger.info("Initializing Preprocess dataset...") if cfg.ljspeech: preprocess_dataset_ljspeech(cfg, tts_engine_new) else: preprocess_dataset_file_based(cfg, tts_engine_new) else: logger.info("Skipping the preprocessing dataset step...") # 6. DATASET & WRAPPER logger.info("Initializing Datasets...") train_ds = ChatterboxDataset(cfg, split="train") val_ds = ChatterboxDataset(cfg, split="val") model_wrapper = ChatterboxTrainerWrapper(tts_engine_new.t3) # 7. TRAINING ARGUMENTS training_args = TrainingArguments( output_dir=cfg.output_dir, per_device_train_batch_size=cfg.batch_size, gradient_accumulation_steps=cfg.grad_accum, learning_rate=cfg.learning_rate, weight_decay=cfg.weight_decay, # Added weight decay num_train_epochs=cfg.num_epochs, evaluation_strategy="epoch", save_strategy="epoch", logging_strategy="steps", logging_steps=10, remove_unused_columns=False, # Required for our custom wrapper dataloader_num_workers=16, report_to=["wandb"], bf16=True if torch.cuda.is_available() else False, # Using bf16 for A100 save_total_limit=5, # Keep all epoch checkpoints gradient_checkpointing=False, # This setting theoretically reduces VRAM usage by 60%. label_names=["speech_tokens", "text_tokens"], load_best_model_at_end=True, lr_scheduler_type="cosine", # Research-optimized scheduler warmup_ratio=0.1, # 10% warmup to handle English-to-Finnish transition ) trainer = ChatterboxTrainer( model=model_wrapper, args=training_args, train_dataset=train_ds, eval_dataset=val_ds, data_collator=data_collator, callbacks=[] # Removed EarlyStopping ) logger.info("Running initial evaluation to establish baseline...") trainer.evaluate() logger.info("Starting Training Loop...") trainer.train() # 8. SAVE FINAL MODEL logger.info("Training complete. Saving model...") os.makedirs(cfg.output_dir, exist_ok=True) filename = "t3_turbo_finetuned.safetensors" if cfg.is_turbo else "t3_finetuned.safetensors" final_model_path = os.path.join(cfg.output_dir, filename) save_file(tts_engine_new.t3.state_dict(), final_model_path) logger.info(f"Model saved to: {final_model_path}") if __name__ == "__main__": main()