| |
| """Cross-platform launcher for F5-TTS fine-tuning. |
| |
| Auto-detects the compute device (CUDA / MPS / CPU) and adjusts settings: |
| - bnb_optimizer is only enabled on CUDA (bitsandbytes doesn't support MPS/CPU). |
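| - num_workers is forced to 0 on Windows (soundfile/LibsndfileError) and on macOS (DataLoader stability on MPS). |
| - On Linux, num_workers defaults to 12. On every platform, the F5_NUM_WORKERS env var overrides the choice. |
| - DataLoader persistent_workers is switched off whenever num_workers is 0. |
| - On Windows, the platform.uname() cache is pre-seeded so no WMI query is made; that query can hang torch imports. |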
| """ |
|
|
| from __future__ import annotations |
|
|
| import os |
| import platform |
| import socket |
| import sys |
|
|
| |
| |
| |
def _windows_uname():
    """Build a ``platform.uname_result`` for Windows without querying WMI.

    The machine architecture comes from the Windows ``PROCESSOR_*``
    environment variables. ``PROCESSOR_ARCHITEW6432`` is checked first
    because it holds the native architecture inside a WOW64 (32-bit on
    64-bit) process. ``"AMD64"`` is the fallback when neither is set.
    Release ``"10"`` and version ``"Windows"`` are fixed placeholders,
    not the real OS build.
    """
    machine = (
        os.environ.get("PROCESSOR_ARCHITEW6432")
        or os.environ.get("PROCESSOR_ARCHITECTURE")
        or "AMD64"
    )
    return platform.uname_result(
        "Windows", socket.gethostname(), "10", "Windows", machine
    )


if sys.platform == "win32":
    # Once the private cache is populated, platform.uname() returns it
    # directly instead of probing the system (which may use WMI and hang).
    platform._uname_cache = _windows_uname()
|
|
| |
| |
| |
| |
| |
| import f5_tts.model.trainer as _trainer_module |
|
|
_original_trainer_init = _trainer_module.Trainer.__init__


def _patched_trainer_init(self, *args, **kwargs):
    """Run ``Trainer.__init__``, clearing ``bnb_optimizer`` when CUDA is absent.

    bitsandbytes does not support MPS/CPU. When the flag is requested on a
    host without CUDA, it is therefore forced to ``False`` before the
    original initializer runs. The arguments are bound against the original
    signature, so the flag is found whether it was passed by keyword or by
    position.
    """
    import inspect

    try:
        bound = inspect.signature(_original_trainer_init).bind(self, *args, **kwargs)
    except TypeError:
        # Malformed call: let the original initializer raise its own error.
        return _original_trainer_init(self, *args, **kwargs)

    if bound.arguments.get("bnb_optimizer", False):
        import torch

        if not torch.cuda.is_available():
            bound.arguments["bnb_optimizer"] = False
            dev = "mps" if torch.backends.mps.is_available() else "cpu"
            print(
                f"\n[Device] CUDA not available (detected {dev}). "
                f"Disabling --bnb_optimizer; using standard AdamW.\n"
            )
    # BoundArguments rebuilds args/kwargs, including any value updated above.
    return _original_trainer_init(*bound.args, **bound.kwargs)


_trainer_module.Trainer.__init__ = _patched_trainer_init
|
|
| |
| |
| |
_original_train = _trainer_module.Trainer.train


def _resolve_num_workers(requested):
    """Return the DataLoader worker count to use on this platform.

    Precedence:
    1. A non-empty ``F5_NUM_WORKERS`` environment variable wins on every
       platform.
    2. Otherwise, Windows and macOS are forced to 0.
    3. Any other platform keeps *requested*.

    Raises:
        ValueError: if ``F5_NUM_WORKERS`` is set but is not a non-negative
            integer.
    """
    env_workers = os.getenv("F5_NUM_WORKERS", "").strip()
    if env_workers:
        try:
            num_workers = int(env_workers)
        except ValueError:
            raise ValueError(
                f"F5_NUM_WORKERS must be a non-negative integer, got {env_workers!r}"
            ) from None
        if num_workers < 0:
            raise ValueError(
                f"F5_NUM_WORKERS must be a non-negative integer, got {num_workers}"
            )
        print(f"\n[Workers] Using F5_NUM_WORKERS={num_workers}.\n")
        return num_workers
    if sys.platform == "win32":
        print(
            "\n[Windows] Forcing num_workers=0 to avoid soundfile/LibsndfileError.\n"
        )
        return 0
    if sys.platform == "darwin":
        print(
            "\n[macOS] Forcing num_workers=0 for DataLoader stability on MPS.\n"
        )
        return 0
    print(f"\n[Linux] Using num_workers={requested}.\n")
    return requested


def _patched_train(
    self, train_dataset, num_workers=12, resumable_with_seed=None, **kwargs
):
    """Call ``Trainer.train`` with a platform-appropriate ``num_workers``.

    The default of 12 only takes effect on platforms that are not forced
    to 0 and when ``F5_NUM_WORKERS`` is unset. Any extra keyword arguments
    are forwarded unchanged to the original method.
    """
    return _original_train(
        self,
        train_dataset,
        num_workers=_resolve_num_workers(num_workers),
        resumable_with_seed=resumable_with_seed,
        **kwargs,
    )


_trainer_module.Trainer.train = _patched_train
|
|
import torch.utils.data as _data_module


_original_dataloader_init = _data_module.DataLoader.__init__


def _patched_dataloader_init(self, *args, **kwargs):
    """Run ``DataLoader.__init__`` with worker-only options cleared when ``num_workers`` is 0.

    PyTorch rejects ``persistent_workers=True`` without worker processes.
    Recent versions also reject an explicit ``prefetch_factor`` in that case.
    The upstream trainer presumably requests ``persistent_workers=True``
    (verify), which would fail once this launcher forces ``num_workers=0``.
    This patch is installed on the DataLoader class, so it applies to every
    DataLoader created in the process.
    """
    import inspect

    try:
        bound = inspect.signature(_original_dataloader_init).bind(self, *args, **kwargs)
    except TypeError:
        # Malformed call: let the original initializer raise its own error.
        return _original_dataloader_init(self, *args, **kwargs)

    # Binding exposes num_workers whether it was passed by keyword or by position.
    if bound.arguments.get("num_workers", 0) == 0:
        # Both options are keyword-only in recent DataLoader signatures, so
        # they can only arrive through kwargs.
        kwargs["persistent_workers"] = False
        kwargs.pop("prefetch_factor", None)
    return _original_dataloader_init(self, *args, **kwargs)


_data_module.DataLoader.__init__ = _patched_dataloader_init
|
|
| |
| |
| |
# Imported last, after this file's module-level patches have been applied.
from f5_tts.train.finetune_cli import main

if __name__ == "__main__":
    # Hand main()'s return value to the interpreter as the exit status
    # (equivalent to sys.exit).
    raise SystemExit(main())
|
|