# Author: Moeeldouma — commit 97e3499 ("Add all scripts with prosody improvements", verified)
"""
Fine-tune XTTS-v2 on curated Egyptian Arabic data.
Uses the cleaned dataset from Phase 5 (data/Egyption/clean/) to fine-tune
the GPT component of XTTS-v2. The base model weights are used as a
starting point, and only the GPT layers are updated.
Training configuration:
- 4 epochs (conservative to avoid overfitting on 5h of data)
- Batch size 4, gradient accumulation 2 (effective batch = 8)
- Learning rate 5e-6 (AdamW)
- fp32 training (fp16/mixed precision causes NaN losses with XTTS GPT)
- Saves best checkpoint + every 1000 steps (keeping the 3 most recent step checkpoints)
Usage:
conda activate new-arabic-tts
python scripts/train.py
Output:
models/finetuned/run/training/ (checkpoints, logs, config)
"""
import os
import gc
import sys
import json
import time
from pathlib import Path
from trainer import Trainer, TrainerArgs
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
from TTS.utils.manage import ModelManager
# --- Paths ---
# All locations are anchored at the repository root (the parent of the
# directory that holds this script), so the current working directory
# does not matter when the script is launched.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
DATA_DIR = PROJECT_ROOT.joinpath("data", "Egyption", "clean")  # directory name is spelled "Egyption"
BASE_MODEL_DIR = PROJECT_ROOT.joinpath("models", "base")
OUTPUT_DIR = PROJECT_ROOT.joinpath("models", "finetuned")
# Metadata CSV locations, kept as plain strings for the dataset config.
TRAIN_CSV = str(DATA_DIR.joinpath("metadata_train.csv"))
EVAL_CSV = str(DATA_DIR.joinpath("metadata_eval.csv"))

# --- Training Config ---
LANGUAGE = "ar"  # Arabic
NUM_EPOCHS = 4
BATCH_SIZE = 4
GRAD_ACCUM = 2  # accumulation steps; effective batch = BATCH_SIZE * GRAD_ACCUM
LEARNING_RATE = 5e-6
MAX_AUDIO_LENGTH = 255_995  # samples; ~11.6 seconds at 22050 Hz
SAVE_STEP = 1_000  # training steps between saved checkpoints
def _require_files(*paths):
    """Raise ``FileNotFoundError`` naming every path in *paths* that is not a file.

    Used to fail fast, before any download or model construction, when a
    dataset CSV or base-model file is missing.
    """
    missing = [str(p) for p in paths if not os.path.isfile(p)]
    if missing:
        raise FileNotFoundError(
            "Required input file(s) not found:\n  " + "\n  ".join(missing)
        )


def _write_summary(summary, summary_path):
    """Write *summary* to *summary_path* as indented UTF-8 JSON.

    The parent directory is created first. Without this, a checkout lacking
    ``docs/benchmarks/`` would raise ``FileNotFoundError`` only after the
    (hours-long) training run had already finished.
    """
    summary_path = Path(summary_path)
    summary_path.parent.mkdir(parents=True, exist_ok=True)
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)


def main():
    """Fine-tune the XTTS-v2 GPT component on the Egyptian Arabic dataset.

    Steps:
      1. Ensure the DVAE checkpoint and mel-normalisation stats exist under the
         run's output directory, downloading them if missing.
      2. Use the local base XTTS-v2 files in ``BASE_MODEL_DIR``.
      3. Build the ``GPTTrainerConfig`` and initialise ``GPTTrainer``.
      4. Load train/eval samples and run ``Trainer.fit()``.

    When training returns, a JSON summary of the run is written to
    ``docs/benchmarks/training_summary.json`` under ``PROJECT_ROOT``.

    Raises:
        FileNotFoundError: if a dataset CSV, the tokenizer file or the base
            model checkpoint is missing.
    """
    print("=" * 70)
    print(" XTTS-v2 Fine-Tuning — Egyptian Arabic")
    print("=" * 70)
    t_start = time.time()

    # --- Local base model files (used in place, never downloaded) ---
    TOKENIZER_FILE = str(BASE_MODEL_DIR / "vocab.json")
    XTTS_CHECKPOINT = str(BASE_MODEL_DIR / "model.pth")
    # Fail fast on missing inputs before downloading anything or building the model.
    _require_files(TRAIN_CSV, EVAL_CSV, TOKENIZER_FILE, XTTS_CHECKPOINT)

    OUT_PATH = str(OUTPUT_DIR / "run" / "training")
    os.makedirs(OUT_PATH, exist_ok=True)

    # --- Download DVAE and mel norm files ---
    CHECKPOINTS_OUT = os.path.join(OUT_PATH, "XTTS_v2.0_original_model_files")
    os.makedirs(CHECKPOINTS_OUT, exist_ok=True)
    DVAE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/dvae.pth"
    MEL_NORM_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/mel_stats.pth"
    DVAE_CHECKPOINT = os.path.join(CHECKPOINTS_OUT, "dvae.pth")
    MEL_NORM_FILE = os.path.join(CHECKPOINTS_OUT, "mel_stats.pth")
    if not os.path.isfile(DVAE_CHECKPOINT) or not os.path.isfile(MEL_NORM_FILE):
        print("[1/4] Downloading DVAE files...")
        # NOTE(review): _download_model_files is a private ModelManager helper;
        # its signature may change between TTS releases.
        ModelManager._download_model_files([MEL_NORM_LINK, DVAE_LINK], CHECKPOINTS_OUT, progress_bar=True)
    else:
        print("[1/4] DVAE files already downloaded")

    print(f"[2/4] Base model: {BASE_MODEL_DIR}")
    print(f" Train CSV: {TRAIN_CSV}")
    print(f" Eval CSV: {EVAL_CSV}")

    # --- Dataset config ---
    config_dataset = BaseDatasetConfig(
        formatter="coqui",
        dataset_name="egyptian_arabic_v2",
        path=str(DATA_DIR),
        meta_file_train=TRAIN_CSV,
        meta_file_val=EVAL_CSV,
        language=LANGUAGE,
    )

    # --- Model args ---
    model_args = GPTArgs(
        max_conditioning_length=132300,  # 6 seconds at 22050 Hz
        min_conditioning_length=66150,  # 3 seconds at 22050 Hz
        debug_loading_failures=False,
        max_wav_length=MAX_AUDIO_LENGTH,
        max_text_length=200,
        mel_norm_file=MEL_NORM_FILE,
        dvae_checkpoint=DVAE_CHECKPOINT,
        xtts_checkpoint=XTTS_CHECKPOINT,
        tokenizer_file=TOKENIZER_FILE,
        gpt_num_audio_tokens=1026,
        gpt_start_audio_token=1024,
        gpt_stop_audio_token=1025,
        gpt_use_masking_gt_prompt_approach=True,
        gpt_use_perceiver_resampler=True,
    )
    audio_config = XttsAudioConfig(
        sample_rate=22050,
        dvae_sample_rate=22050,
        output_sample_rate=24000,
    )

    # --- Trainer config ---
    config = GPTTrainerConfig(
        epochs=NUM_EPOCHS,
        output_path=OUT_PATH,
        model_args=model_args,
        run_name="GPT_XTTS_AR_FT",
        project_name="Arabic_TTS",
        run_description="Fine-tuning XTTS-v2 GPT on Egyptian Arabic v2 (cleaned, ~10.6k clips, single speaker)",
        dashboard_logger="tensorboard",
        audio=audio_config,
        batch_size=BATCH_SIZE,
        batch_group_size=48,
        eval_batch_size=BATCH_SIZE,
        num_loader_workers=8,
        eval_split_max_size=256,
        print_step=50,
        plot_step=100,
        log_model_step=100,
        save_step=SAVE_STEP,
        save_n_checkpoints=3,
        save_checkpoints=True,
        print_eval=False,
        # Pin fp32 explicitly: fp16/mixed precision produces NaN losses for the
        # XTTS GPT, so do not rely on the library default.
        mixed_precision=False,
        optimizer="AdamW",
        optimizer_wd_only_on_weights=True,
        optimizer_params={"betas": [0.9, 0.96], "eps": 1e-8, "weight_decay": 1e-2},
        lr=LEARNING_RATE,
        lr_scheduler="MultiStepLR",
        # The first milestone is at 900k scheduler steps. A 4-epoch run over
        # ~10.6k clips never gets there, so in practice the LR stays at
        # LEARNING_RATE for the whole run.
        lr_scheduler_params={
            "milestones": [50000 * 18, 150000 * 18, 300000 * 18],
            "gamma": 0.5,
            "last_epoch": -1,
        },
        test_sentences=[],
    )

    # --- Init model ---
    print("[3/4] Initializing model...")
    model = GPTTrainer.init_from_config(config)

    # --- Load data ---
    # meta_file_val is set, so eval samples come from EVAL_CSV rather than a
    # random split of the training metadata.
    train_samples, eval_samples = load_tts_samples(
        [config_dataset],
        eval_split=True,
        eval_split_max_size=config.eval_split_max_size,
        eval_split_size=config.eval_split_size,
    )
    print(f" Train samples: {len(train_samples)}")
    print(f" Eval samples: {len(eval_samples)}")

    # --- Train ---
    print("[4/4] Starting training...")
    print(f" Epochs: {NUM_EPOCHS}")
    print(f" Batch size: {BATCH_SIZE} (x{GRAD_ACCUM} accum = {BATCH_SIZE * GRAD_ACCUM} effective)")
    print(f" LR: {LEARNING_RATE}")
    print(f" Save every: {SAVE_STEP} steps")
    print(f" Output: {OUT_PATH}")
    print()
    trainer = Trainer(
        TrainerArgs(
            restore_path=None,
            skip_train_epoch=False,
            start_with_eval=False,
            grad_accum_steps=GRAD_ACCUM,
        ),
        config,
        output_path=OUT_PATH,
        model=model,
        train_samples=train_samples,
        eval_samples=eval_samples,
    )
    trainer.fit()

    # Wall-clock hours for the whole script run, including setup and downloads.
    elapsed = (time.time() - t_start) / 3600
    print(f"\n{'='*70}")
    print(" Training Complete!")
    print(f" Total time: {elapsed:.1f} hours")
    print(f" Output: {trainer.output_path}")
    print(f"{'='*70}")

    # Save training summary
    summary = {
        "date": time.strftime("%Y-%m-%d %H:%M"),
        "dataset": "egyptian_arabic_v2 (Egyption/clean)",
        "train_clips": len(train_samples),
        "eval_clips": len(eval_samples),
        "epochs": NUM_EPOCHS,
        "batch_size": BATCH_SIZE,
        "grad_accum": GRAD_ACCUM,
        "learning_rate": LEARNING_RATE,
        "training_hours": round(elapsed, 2),
        "output_path": trainer.output_path,
        "base_model": str(BASE_MODEL_DIR),
    }
    _write_summary(summary, PROJECT_ROOT / "docs" / "benchmarks" / "training_summary.json")

    # Drop references to the model, trainer and datasets and force a
    # collection pass before returning.
    del model, trainer, train_samples, eval_samples
    gc.collect()


if __name__ == "__main__":
    main()