|
|
import os |
|
|
import sys |
|
|
import torch |
|
|
from transformers import Trainer, TrainingArguments, EarlyStoppingCallback, TrainerCallback |
|
|
from safetensors.torch import save_file |
|
|
|
|
|
class ChatterboxTrainer(Trainer):
    """Trainer subclass that surfaces the per-modality sub-losses.

    During training, ``loss_text`` / ``loss_speech`` are logged every
    ``logging_steps`` steps; during evaluation they are accumulated per
    batch and averaged into ``eval_loss_text`` / ``eval_loss_speech``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Per-batch sub-loss accumulators, filled only while evaluating.
        self._eval_loss_text = []
        self._eval_loss_speech = []

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Forward the batch and capture sub-losses when the model exposes them."""
        outputs = model(**inputs)
        loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

        # Sub-losses are only available when the wrapper returns a dict.
        if isinstance(outputs, dict):
            if not model.training:
                # Evaluation: stash raw values for averaging in evaluation_loop.
                if "loss_text" in outputs:
                    self._eval_loss_text.append(outputs["loss_text"].item())
                if "loss_speech" in outputs:
                    self._eval_loss_speech.append(outputs["loss_speech"].item())
            elif self.state.global_step % self.args.logging_steps == 0:
                # Training: log directly, but only on logging-step boundaries
                # to avoid a GPU sync (.item()) on every step.
                if "loss_text" in outputs:
                    self.log({"loss_text": outputs["loss_text"].item()})
                if "loss_speech" in outputs:
                    self.log({"loss_speech": outputs["loss_speech"].item()})

        if return_outputs:
            return loss, outputs
        return loss

    def evaluation_loop(self, *args, **kwargs):
        """Run the stock evaluation loop, then attach averaged sub-losses."""
        # Reset accumulators so the metrics reflect only this evaluation pass.
        self._eval_loss_text, self._eval_loss_speech = [], []

        result = super().evaluation_loop(*args, **kwargs)

        text_losses = self._eval_loss_text
        if text_losses:
            result.metrics["eval_loss_text"] = sum(text_losses) / len(text_losses)

        speech_losses = self._eval_loss_speech
        if speech_losses:
            result.metrics["eval_loss_speech"] = sum(speech_losses) / len(speech_losses)

        return result
|
|
|
|
|
|
|
|
from src.config import TrainConfig |
|
|
from src.dataset import ChatterboxDataset, data_collator |
|
|
from src.model import resize_and_load_t3_weights, ChatterboxTrainerWrapper |
|
|
from src.preprocess_ljspeech import preprocess_dataset_ljspeech |
|
|
from src.preprocess_file_based import preprocess_dataset_file_based |
|
|
from src.utils import setup_logger, check_pretrained_models |
|
|
|
|
|
|
|
|
from src.chatterbox_.tts import ChatterboxTTS |
|
|
from src.chatterbox_.tts_turbo import ChatterboxTurboTTS |
|
|
from src.chatterbox_.models.t3.t3 import T3 |
|
|
|
|
|
# Silence the fork warning from HF fast tokenizers when DataLoader workers fork.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Weights & Biases configuration. setdefault keeps any values the user already
# exported in the environment instead of clobbering them with the placeholders
# below.
# SECURITY NOTE(review): never commit a real API key to source control —
# export WANDB_API_KEY in the shell environment instead.
os.environ.setdefault("WANDB_API_KEY", "INSERT_API_KEY_HERE")
os.environ.setdefault("WANDB_PROJECT", "chatterbox-finetuning")


logger = setup_logger("ChatterboxFinetune")
|
|
|
|
|
|
|
|
def main():
    """End-to-end Chatterbox finetuning entry point.

    Loads the pretrained Chatterbox (or Turbo) engine, rebuilds the T3 model
    with a resized text vocabulary, transfers the pretrained weights,
    optionally preprocesses the dataset, trains T3 with the HF Trainer while
    keeping S3Gen and the voice encoder frozen, and finally saves the
    finetuned T3 weights as a safetensors file in ``cfg.output_dir``.
    """
    cfg = TrainConfig()

    logger.info("--- Starting Chatterbox Finetuning ---")
    logger.info(f"Mode: {'CHATTERBOX-TURBO' if cfg.is_turbo else 'CHATTERBOX-TTS'}")

    # Abort early if the required pretrained checkpoints are not on disk.
    mode_check = "chatterbox_turbo" if cfg.is_turbo else "chatterbox"
    if not check_pretrained_models(mode=mode_check):
        sys.exit(1)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    EngineClass = ChatterboxTurboTTS if cfg.is_turbo else ChatterboxTTS

    logger.info(f"Device: {device}")
    logger.info(f"Model Directory: {cfg.model_dir}")

    # Load the original engine on CPU purely to harvest the pretrained T3
    # weights and hyperparameters; it is deleted once the transfer is done.
    logger.info("Loading original model to extract weights...")
    tts_engine_original = EngineClass.from_local(cfg.model_dir, device="cpu")

    pretrained_t3_state_dict = tts_engine_original.t3.state_dict()
    original_t3_config = tts_engine_original.t3.hp

    logger.info(f"Creating new T3 model with vocab size: {cfg.new_vocab_size}")

    # NOTE(review): this aliases (does not copy) the original config object,
    # so the mutations below also touch tts_engine_original.t3.hp. Harmless
    # here because the original engine is deleted shortly after — confirm if
    # that ever changes.
    new_t3_config = original_t3_config
    new_t3_config.text_tokens_dict_size = cfg.new_vocab_size

    # Disable KV caching for training (it only helps generation). A plain
    # assignment covers both the attribute-present and attribute-missing
    # cases; the previous hasattr/setattr branches did the exact same thing.
    new_t3_config.use_cache = False

    new_t3_model = T3(hp=new_t3_config)

    logger.info("Transferring weights...")
    new_t3_model = resize_and_load_t3_weights(new_t3_model, pretrained_t3_state_dict)

    if cfg.is_turbo:
        # Turbo mode drops the backbone word-token embedding layer if present.
        logger.info("Turbo Mode: Removing backbone WTE layer...")
        if hasattr(new_t3_model.tfmr, "wte"):
            del new_t3_model.tfmr.wte

    # Free the CPU copies of the pretrained engine and state dict before
    # loading the fresh engine, to keep peak memory down.
    del tts_engine_original
    del pretrained_t3_state_dict

    # Reload a fresh engine and swap in the resized, weight-transferred T3.
    tts_engine_new = EngineClass.from_local(cfg.model_dir, device="cpu")
    tts_engine_new.t3 = new_t3_model

    # Only T3 is finetuned; the voice encoder and S3Gen vocoder stay frozen.
    logger.info("Freezing S3Gen and VoiceEncoder...")
    for param in tts_engine_new.ve.parameters():
        param.requires_grad = False

    for param in tts_engine_new.s3gen.parameters():
        param.requires_grad = False

    tts_engine_new.t3.train()
    for param in tts_engine_new.t3.parameters():
        param.requires_grad = True

    if cfg.preprocess:
        logger.info("Initializing Preprocess dataset...")
        if cfg.ljspeech:
            preprocess_dataset_ljspeech(cfg, tts_engine_new)
        else:
            preprocess_dataset_file_based(cfg, tts_engine_new)
    else:
        logger.info("Skipping the preprocessing dataset step...")

    logger.info("Initializing Datasets...")
    train_ds = ChatterboxDataset(cfg, split="train")
    val_ds = ChatterboxDataset(cfg, split="val")

    model_wrapper = ChatterboxTrainerWrapper(tts_engine_new.t3)

    # NOTE(review): `evaluation_strategy` was renamed `eval_strategy` in
    # recent transformers releases — confirm against the pinned version.
    training_args = TrainingArguments(
        output_dir=cfg.output_dir,
        per_device_train_batch_size=cfg.batch_size,
        gradient_accumulation_steps=cfg.grad_accum,
        learning_rate=cfg.learning_rate,
        weight_decay=cfg.weight_decay,
        num_train_epochs=cfg.num_epochs,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        logging_strategy="steps",
        logging_steps=10,
        remove_unused_columns=False,
        dataloader_num_workers=16,
        report_to=["wandb"],
        bf16=torch.cuda.is_available(),  # bf16 only when a CUDA device exists
        save_total_limit=5,
        gradient_checkpointing=False,
        label_names=["speech_tokens", "text_tokens"],
        load_best_model_at_end=True,
        lr_scheduler_type="cosine",
        warmup_ratio=0.1,
    )

    trainer = ChatterboxTrainer(
        model=model_wrapper,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        data_collator=data_collator,
        callbacks=[],
    )

    # Baseline eval before any weight updates, so the first logged metrics
    # reflect the untouched (weight-transferred) model.
    logger.info("Running initial evaluation to establish baseline...")
    trainer.evaluate()

    logger.info("Starting Training Loop...")
    trainer.train()

    logger.info("Training complete. Saving model...")
    os.makedirs(cfg.output_dir, exist_ok=True)

    filename = "t3_turbo_finetuned.safetensors" if cfg.is_turbo else "t3_finetuned.safetensors"
    final_model_path = os.path.join(cfg.output_dir, filename)

    save_file(tts_engine_new.t3.state_dict(), final_model_path)
    logger.info(f"Model saved to: {final_model_path}")


if __name__ == "__main__":
    main()
|
|
|