"""
Fine-tune XTTS-v2 on curated Egyptian Arabic data.

Uses the cleaned dataset from Phase 5 (data/Egyption/clean/ — directory name
spelled as on disk) to fine-tune the GPT component of XTTS-v2. The base model
weights are used as the starting point, and only the GPT layers are updated.

Training configuration:
  - 4 epochs (conservative, to avoid overfitting on ~5h of data)
  - Batch size 4, gradient accumulation 2 (effective batch = 8)
  - Learning rate 5e-6 (AdamW)
  - fp32 training (fp16/mixed precision causes NaN losses with XTTS GPT)
  - Saves the best checkpoint plus one every 1000 steps (keeping the last 3)

Usage:
    conda activate new-arabic-tts
    python scripts/train.py

Output:
    models/finetuned/run/training/           (checkpoints, logs, config)
    docs/benchmarks/training_summary.json    (JSON summary of the run)
"""
|
|
| import os |
| import gc |
| import sys |
| import json |
| import time |
| from pathlib import Path |
|
|
| from trainer import Trainer, TrainerArgs |
| from TTS.config.shared_configs import BaseDatasetConfig |
| from TTS.tts.datasets import load_tts_samples |
| from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig |
| from TTS.utils.manage import ModelManager |
|
|
| |
# --- Paths -------------------------------------------------------------------
# Repository root: this script is run as scripts/train.py, so go up one level
# from the scripts/ directory.
PROJECT_ROOT = Path(__file__).resolve().parents[1]

# NOTE: "Egyption" (sic) is the actual directory name on disk; changing the
# spelling here requires renaming the directory too.
DATA_DIR = PROJECT_ROOT.joinpath("data", "Egyption", "clean")
BASE_MODEL_DIR = PROJECT_ROOT.joinpath("models", "base")
OUTPUT_DIR = PROJECT_ROOT.joinpath("models", "finetuned")

# Metadata files for the "coqui" dataset formatter, kept as plain strings.
TRAIN_CSV = os.fspath(DATA_DIR.joinpath("metadata_train.csv"))
EVAL_CSV = os.fspath(DATA_DIR.joinpath("metadata_eval.csv"))

# --- Hyperparameters ---------------------------------------------------------
LANGUAGE = "ar"
NUM_EPOCHS = 4
BATCH_SIZE = 4
GRAD_ACCUM = 2  # effective batch = BATCH_SIZE * GRAD_ACCUM
LEARNING_RATE = 5e-6
MAX_AUDIO_LENGTH = 255_995  # in samples; ~11.6 s at the 22.05 kHz training rate
SAVE_STEP = 1000
|
|
|
|
def _require_files(paths):
    """Raise FileNotFoundError naming every entry of *paths* that is not an existing file.

    Used to fail fast before any download, model construction or training starts.
    """
    missing = [str(p) for p in paths if not os.path.isfile(p)]
    if missing:
        raise FileNotFoundError(
            "Required input file(s) not found:\n  " + "\n  ".join(missing)
        )


def _ensure_dvae_files(checkpoints_out):
    """Make sure the XTTS-v2 DVAE checkpoint and mel-norm stats exist in *checkpoints_out*.

    Downloads both files if either one is missing, then verifies that they exist.

    Returns:
        Tuple ``(dvae_checkpoint, mel_norm_file)`` of string paths.

    Raises:
        FileNotFoundError: if a file is still missing after the download attempt.
    """
    dvae_link = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/dvae.pth"
    mel_norm_link = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/mel_stats.pth"
    dvae_checkpoint = os.path.join(checkpoints_out, "dvae.pth")
    mel_norm_file = os.path.join(checkpoints_out, "mel_stats.pth")

    if os.path.isfile(dvae_checkpoint) and os.path.isfile(mel_norm_file):
        print("[1/4] DVAE files already downloaded")
    else:
        print("[1/4] Downloading DVAE files...")
        # NOTE: _download_model_files is a private ModelManager helper and may
        # change between TTS releases.
        ModelManager._download_model_files(
            [mel_norm_link, dvae_link], checkpoints_out, progress_bar=True
        )
        _require_files([dvae_checkpoint, mel_norm_file])
    return dvae_checkpoint, mel_norm_file


def main():
    """Fine-tune the XTTS-v2 GPT component on the cleaned Egyptian Arabic dataset.

    Pipeline:
      1. Validate that the train/eval CSVs and base model files exist.
      2. Fetch the DVAE checkpoint and mel-norm stats if they are not present.
      3. Build the dataset, GPT model and trainer configurations (fp32 only).
      4. Train, then write a JSON run summary to
         docs/benchmarks/training_summary.json.

    Raises:
        FileNotFoundError: if a required input file is missing, or the DVAE
            files are still absent after the download attempt.
    """
    print("=" * 70)
    print(" XTTS-v2 Fine-Tuning — Egyptian Arabic")
    print("=" * 70)
    t_start = time.time()

    tokenizer_file = str(BASE_MODEL_DIR / "vocab.json")
    xtts_checkpoint = str(BASE_MODEL_DIR / "model.pth")

    # Fail fast: otherwise a missing input only surfaces after the DVAE
    # download and model construction.
    _require_files([TRAIN_CSV, EVAL_CSV, tokenizer_file, xtts_checkpoint])

    out_path = str(OUTPUT_DIR / "run" / "training")
    os.makedirs(out_path, exist_ok=True)

    # The DVAE and mel-norm stats are kept next to the training output.
    checkpoints_out = os.path.join(out_path, "XTTS_v2.0_original_model_files")
    os.makedirs(checkpoints_out, exist_ok=True)
    dvae_checkpoint, mel_norm_file = _ensure_dvae_files(checkpoints_out)

    print(f"[2/4] Base model: {BASE_MODEL_DIR}")
    print(f" Train CSV: {TRAIN_CSV}")
    print(f" Eval CSV: {EVAL_CSV}")

    config_dataset = BaseDatasetConfig(
        formatter="coqui",
        dataset_name="egyptian_arabic_v2",
        path=str(DATA_DIR),
        meta_file_train=TRAIN_CSV,
        meta_file_val=EVAL_CSV,
        language=LANGUAGE,
    )

    model_args = GPTArgs(
        max_conditioning_length=132300,  # 6 s at 22.05 kHz
        min_conditioning_length=66150,  # 3 s at 22.05 kHz
        debug_loading_failures=False,
        max_wav_length=MAX_AUDIO_LENGTH,
        max_text_length=200,
        mel_norm_file=mel_norm_file,
        dvae_checkpoint=dvae_checkpoint,
        xtts_checkpoint=xtts_checkpoint,
        tokenizer_file=tokenizer_file,
        gpt_num_audio_tokens=1026,
        gpt_start_audio_token=1024,
        gpt_stop_audio_token=1025,
        gpt_use_masking_gt_prompt_approach=True,
        gpt_use_perceiver_resampler=True,
    )

    audio_config = XttsAudioConfig(
        sample_rate=22050,
        dvae_sample_rate=22050,
        output_sample_rate=24000,
    )

    config = GPTTrainerConfig(
        epochs=NUM_EPOCHS,
        output_path=out_path,
        model_args=model_args,
        run_name="GPT_XTTS_AR_FT",
        project_name="Arabic_TTS",
        run_description="Fine-tuning XTTS-v2 GPT on Egyptian Arabic v2 (cleaned, ~10.6k clips, single speaker)",
        dashboard_logger="tensorboard",
        audio=audio_config,
        batch_size=BATCH_SIZE,
        batch_group_size=48,
        eval_batch_size=BATCH_SIZE,
        num_loader_workers=8,
        eval_split_max_size=256,
        print_step=50,
        plot_step=100,
        log_model_step=100,
        save_step=SAVE_STEP,
        save_n_checkpoints=3,
        save_checkpoints=True,
        print_eval=False,
        # Pin fp32 explicitly: fp16/mixed precision gives NaN losses with XTTS GPT,
        # so this must not depend on the library default.
        mixed_precision=False,
        optimizer="AdamW",
        optimizer_wd_only_on_weights=True,
        optimizer_params={"betas": [0.9, 0.96], "eps": 1e-8, "weight_decay": 1e-2},
        lr=LEARNING_RATE,
        lr_scheduler="MultiStepLR",
        # Milestones are presumably never reached in a 4-epoch run on this
        # dataset, so the LR effectively stays at LEARNING_RATE. Confirm this
        # if the dataset or epoch count grows.
        lr_scheduler_params={
            "milestones": [50000 * 18, 150000 * 18, 300000 * 18],
            "gamma": 0.5,
            "last_epoch": -1,
        },
        test_sentences=[],
    )

    print("[3/4] Initializing model...")
    model = GPTTrainer.init_from_config(config)

    train_samples, eval_samples = load_tts_samples(
        [config_dataset],
        eval_split=True,
        eval_split_max_size=config.eval_split_max_size,
        eval_split_size=config.eval_split_size,
    )
    print(f" Train samples: {len(train_samples)}")
    print(f" Eval samples: {len(eval_samples)}")

    print("[4/4] Starting training...")
    print(f" Epochs: {NUM_EPOCHS}")
    print(f" Batch size: {BATCH_SIZE} (x{GRAD_ACCUM} accum = {BATCH_SIZE * GRAD_ACCUM} effective)")
    print(f" LR: {LEARNING_RATE}")
    print(f" Save every: {SAVE_STEP} steps")
    print(f" Output: {out_path}")
    print()

    trainer = Trainer(
        TrainerArgs(
            restore_path=None,
            skip_train_epoch=False,
            start_with_eval=False,
            grad_accum_steps=GRAD_ACCUM,
        ),
        config,
        output_path=out_path,
        model=model,
        train_samples=train_samples,
        eval_samples=eval_samples,
    )
    trainer.fit()

    elapsed = (time.time() - t_start) / 3600
    # The trainer reports its own output path, which is printed instead of out_path.
    run_output = str(trainer.output_path)
    print(f"\n{'='*70}")
    print(" Training Complete!")
    print(f" Total time: {elapsed:.1f} hours")
    print(f" Output: {run_output}")
    print(f"{'='*70}")

    summary = {
        "date": time.strftime("%Y-%m-%d %H:%M"),
        "dataset": "egyptian_arabic_v2 (Egyption/clean)",
        "train_clips": len(train_samples),
        "eval_clips": len(eval_samples),
        "epochs": NUM_EPOCHS,
        "batch_size": BATCH_SIZE,
        "grad_accum": GRAD_ACCUM,
        "learning_rate": LEARNING_RATE,
        "training_hours": round(elapsed, 2),
        "output_path": run_output,
        "base_model": str(BASE_MODEL_DIR),
    }
    summary_path = PROJECT_ROOT / "docs" / "benchmarks" / "training_summary.json"
    # Create the directory first: otherwise a missing docs/benchmarks/ would
    # crash here after the whole training run and lose the summary.
    summary_path.parent.mkdir(parents=True, exist_ok=True)
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)
    print(f" Summary: {summary_path}")

    # Drop references to large objects and force a collection before exit.
    del model, trainer, train_samples, eval_samples
    gc.collect()
|
|
|
|
if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    main()
|
|