sinhala-tts / scripts /train_f5.py
outlawmold's picture
Add server deployment instructions and cross-platform training fixes
937d0e8
#!/usr/bin/env python3
"""Cross-platform launcher for F5-TTS fine-tuning.
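Auto-detects the compute device (CUDA / MPS / CPU) and the host OS, and adjusts settings:
- bnb_optimizer is only honoured on CUDA (bitsandbytes doesn't support MPS/CPU);
  elsewhere it is turned off and standard AdamW is used instead.
- num_workers is forced to 0 on Windows (soundfile/LibsndfileError) and on
  macOS (DataLoader stability on MPS).
- On Linux, num_workers defaults to 12.
- The F5_NUM_WORKERS env var, when set, overrides num_workers on every platform.
- persistent_workers is disabled whenever num_workers is 0.
- On Windows, platform.uname()'s cache is preloaded to avoid a WMI-related hang
  during the torch import.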
"""
from __future__ import annotations

import functools
import inspect
import os
import platform
import socket
import sys
# ---------------------------------------------------------------------------
# Windows-only: preload platform.uname()'s cache so importing torch does not
# have to run the WMI query that can hang.
# ---------------------------------------------------------------------------
if sys.platform == "win32":
    # NOTE(review): relies on the private ``platform._uname_cache`` attribute
    # and on the 5-field ``uname_result`` constructor (Python 3.9+; earlier
    # versions also require a ``processor`` field). Every value below comes
    # from a source that does not go through WMI: the release stays the fixed
    # "10", the version comes from the kernel via sys.getwindowsversion(), and
    # the machine comes from the same environment variables platform itself
    # falls back to (PROCESSOR_ARCHITEW6432 names the native arch under WOW64).
    _winver = sys.getwindowsversion()
    platform._uname_cache = platform.uname_result(  # type: ignore[attr-defined]
        "Windows",
        socket.gethostname(),
        "10",
        f"{_winver.major}.{_winver.minor}.{_winver.build}",
        os.environ.get("PROCESSOR_ARCHITEW6432")
        or os.environ.get("PROCESSOR_ARCHITECTURE", "AMD64"),
    )
    del _winver  # keep the module namespace clean
# ---------------------------------------------------------------------------
# Device-safe bnb_optimizer patch
# bitsandbytes only supports CUDA. On MPS or CPU we fall back (with a printed
# notice) to standard AdamW so the same CLI flags work everywhere.
# ---------------------------------------------------------------------------
import f5_tts.model.trainer as _trainer_module

_original_trainer_init = _trainer_module.Trainer.__init__
# Resolved once so ``bnb_optimizer`` is found whether the caller passes it by
# keyword or positionally.
_trainer_init_signature = inspect.signature(_original_trainer_init)


@functools.wraps(_original_trainer_init)
def _patched_trainer_init(self, *args, **kwargs):
    """Call the original ``Trainer.__init__``, disabling bnb_optimizer without CUDA.

    If ``bnb_optimizer`` is truthy (passed by keyword or positionally) and
    ``torch.cuda.is_available()`` is False, it is replaced with ``False`` and a
    notice naming the detected fallback device (``mps`` or ``cpu``) is printed.
    All other arguments are forwarded unchanged.
    """
    try:
        bound = _trainer_init_signature.bind(self, *args, **kwargs)
    except TypeError:
        # The call does not match the signature: let the original constructor
        # raise its own, more specific error.
        return _original_trainer_init(self, *args, **kwargs)

    if bound.arguments.get("bnb_optimizer", False):
        import torch  # deferred: only needed when the flag is requested

        if not torch.cuda.is_available():
            bound.arguments["bnb_optimizer"] = False
            dev = "mps" if torch.backends.mps.is_available() else "cpu"
            print(
                f"\n[Device] CUDA not available (detected {dev}). "
                f"Disabling --bnb_optimizer; using standard AdamW.\n"
            )
    # bound.args already contains ``self`` as its first element.
    return _original_trainer_init(*bound.args, **bound.kwargs)


_trainer_module.Trainer.__init__ = _patched_trainer_init
# ---------------------------------------------------------------------------
# Platform-aware DataLoader patches
# ---------------------------------------------------------------------------
_original_train = _trainer_module.Trainer.train


def _resolve_num_workers(requested: int) -> int:
    """Pick the DataLoader worker count for this host and print which rule applied.

    Precedence: the F5_NUM_WORKERS environment variable (on any platform),
    then a forced 0 on Windows and macOS, otherwise *requested* unchanged.

    Raises:
        ValueError: if F5_NUM_WORKERS is set but is not a non-negative integer.
    """
    env_workers = os.getenv("F5_NUM_WORKERS")
    if env_workers is not None:
        invalid = (
            f"F5_NUM_WORKERS must be a non-negative integer, got {env_workers!r}"
        )
        try:
            num_workers = int(env_workers)
        except ValueError as err:
            raise ValueError(invalid) from err
        if num_workers < 0:
            # Fail here with a clear message instead of deep inside DataLoader.
            raise ValueError(invalid)
        print(f"\n[Workers] Using F5_NUM_WORKERS={num_workers}.\n")
        return num_workers
    if sys.platform == "win32":
        print(
            "\n[Windows] Forcing num_workers=0 to avoid soundfile/LibsndfileError.\n"
        )
        return 0
    if sys.platform == "darwin":
        print(
            "\n[macOS] Forcing num_workers=0 for DataLoader stability on MPS.\n"
        )
        return 0
    # Linux (and any other non-Windows, non-macOS host): honour the caller's value.
    print(f"\n[Linux] Using num_workers={requested}.\n")
    return requested


@functools.wraps(_original_train)
def _patched_train(
    self, train_dataset, num_workers=12, resumable_with_seed=None, **kwargs
):
    """Run the original ``Trainer.train`` with a platform-adjusted num_workers.

    ``num_workers`` defaults to 12 here and is then adjusted by
    ``_resolve_num_workers``. Extra keyword arguments are forwarded unchanged,
    so parameters added to ``Trainer.train`` later are not silently dropped.
    """
    return _original_train(
        self,
        train_dataset,
        num_workers=_resolve_num_workers(num_workers),
        resumable_with_seed=resumable_with_seed,
        **kwargs,
    )


_trainer_module.Trainer.train = _patched_train
import torch.utils.data as _data_module

_original_dataloader_init = _data_module.DataLoader.__init__
# Resolved once so ``num_workers`` is read whether it was passed by keyword or
# as the 6th positional argument.
_dataloader_init_signature = inspect.signature(_original_dataloader_init)


@functools.wraps(_original_dataloader_init)
def _patched_dataloader_init(self, *args, **kwargs):
    """Construct a DataLoader, neutralising multiprocessing-only options at num_workers == 0.

    PyTorch rejects ``persistent_workers=True`` and an explicit
    ``prefetch_factor`` when ``num_workers`` is 0. Windows and macOS force
    ``num_workers=0``, so this wrapper turns ``persistent_workers`` off and
    drops ``prefetch_factor``, leaving PyTorch's own default in place.
    """
    try:
        bound = _dataloader_init_signature.bind_partial(self, *args, **kwargs)
    except TypeError:
        # Malformed call: let DataLoader raise its own error.
        return _original_dataloader_init(self, *args, **kwargs)

    # A missing num_workers means DataLoader's default of 0.
    if bound.arguments.get("num_workers", 0) == 0:
        kwargs["persistent_workers"] = False
        kwargs.pop("prefetch_factor", None)
    return _original_dataloader_init(self, *args, **kwargs)


_data_module.DataLoader.__init__ = _patched_dataloader_init
# ---------------------------------------------------------------------------
# Hand off to the stock F5-TTS fine-tuning CLI. This import runs after all of
# the patches in this file have been installed.
# ---------------------------------------------------------------------------
from f5_tts.train.finetune_cli import main

if __name__ == "__main__":
    # Whatever main() returns is used as the process exit status.
    raise SystemExit(main())