#!/usr/bin/env python3
"""Cross-platform launcher for F5-TTS fine-tuning.

Auto-detects the compute device (CUDA / MPS / CPU) and adjusts settings:

- bnb_optimizer is only enabled on CUDA (bitsandbytes doesn't support MPS/CPU).
- num_workers is set to 0 on Windows & macOS (soundfile/LibsndfileError).
- On other platforms the caller's num_workers is kept (default 12).
- The F5_NUM_WORKERS env var, when set to a non-negative integer, takes
  precedence over the platform rules above on every platform.
- Windows WMI cache is preloaded to avoid torch import hangs.
"""

from __future__ import annotations

import functools
import os
import platform
import socket
import sys

# ---------------------------------------------------------------------------
# Windows-only: preload WMI cache to avoid hang during torch import
# ---------------------------------------------------------------------------
# NOTE: this must run before anything imports torch, which is why the
# third-party imports below are deliberately not at the top of the file.
if sys.platform == "win32":
    platform._uname_cache = platform.uname_result(  # type: ignore[attr-defined]
        "Windows", socket.gethostname(), "10", "Windows", "AMD64"
    )

# ---------------------------------------------------------------------------
# Device-safe bnb_optimizer patch
# bitsandbytes only supports CUDA. On MPS or CPU we silently fall back to
# standard AdamW so the same CLI flags work everywhere.
# ---------------------------------------------------------------------------
import f5_tts.model.trainer as _trainer_module  # noqa: E402

_original_trainer_init = _trainer_module.Trainer.__init__


@functools.wraps(_original_trainer_init)
def _patched_trainer_init(self, *args, **kwargs):
    """Wrap ``Trainer.__init__`` and drop ``bnb_optimizer`` without CUDA.

    Only a ``bnb_optimizer`` passed by keyword is inspected; every other
    argument is forwarded to the original initializer untouched.
    """
    if kwargs.get("bnb_optimizer", False):
        import torch

        if not torch.cuda.is_available():
            kwargs["bnb_optimizer"] = False
            # Guard the attribute lookup: this value is only used for the
            # message below, so a torch build without an ``mps`` backend
            # must not turn the fallback into an AttributeError.
            mps_backend = getattr(torch.backends, "mps", None)
            mps_ok = mps_backend is not None and mps_backend.is_available()
            dev = "mps" if mps_ok else "cpu"
            print(
                f"\n[Device] CUDA not available (detected {dev}). "
                f"Disabling --bnb_optimizer; using standard AdamW.\n"
            )
    return _original_trainer_init(self, *args, **kwargs)


_trainer_module.Trainer.__init__ = _patched_trainer_init

# ---------------------------------------------------------------------------
# Platform-aware DataLoader patches
# ---------------------------------------------------------------------------
_original_train = _trainer_module.Trainer.train


def _read_env_num_workers() -> int | None:
    """Return the worker count requested via ``F5_NUM_WORKERS``.

    Returns:
        The parsed non-negative integer, or ``None`` when the variable is
        unset or empty (an empty value is treated the same as unset).

    Raises:
        ValueError: If the variable is set to something that is not a
            non-negative integer.
    """
    raw = os.getenv("F5_NUM_WORKERS")
    if raw is None or not raw.strip():
        return None
    try:
        workers = int(raw)
    except ValueError:
        raise ValueError(
            f"F5_NUM_WORKERS must be a non-negative integer, got {raw!r}."
        ) from None
    if workers < 0:
        # Fail here with a clear message instead of later inside DataLoader.
        raise ValueError(
            f"F5_NUM_WORKERS must be a non-negative integer, got {raw!r}."
        )
    return workers


# Not decorated with functools.wraps: this wrapper declares its own explicit
# signature and defaults, and introspection should report those.
def _patched_train(self, train_dataset, num_workers=12, resumable_with_seed=None):
    """Wrap ``Trainer.train`` and pick a platform-appropriate worker count.

    Precedence: ``F5_NUM_WORKERS`` (any platform), then 0 on Windows and
    macOS, otherwise the ``num_workers`` value passed by the caller.

    Raises:
        ValueError: If ``F5_NUM_WORKERS`` is set to an invalid value.
    """
    env_workers = _read_env_num_workers()
    if env_workers is not None:
        num_workers = env_workers
        print(f"\n[Workers] Using F5_NUM_WORKERS={num_workers}.\n")
    elif sys.platform == "win32":
        print(
            "\n[Windows] Forcing num_workers=0 to avoid soundfile/LibsndfileError.\n"
        )
        num_workers = 0
    elif sys.platform == "darwin":
        print(
            "\n[macOS] Forcing num_workers=0 for DataLoader stability on MPS.\n"
        )
        num_workers = 0
    else:
        # Linux — keep the default (12) unless overridden
        print(f"\n[Linux] Using num_workers={num_workers}.\n")
    return _original_train(
        self,
        train_dataset,
        num_workers=num_workers,
        resumable_with_seed=resumable_with_seed,
    )


_trainer_module.Trainer.train = _patched_train

import torch.utils.data as _data_module  # noqa: E402

_original_dataloader_init = _data_module.DataLoader.__init__

# Index of ``num_workers`` among DataLoader's positional parameters once
# ``self`` is excluded:
# (dataset, batch_size, shuffle, sampler, batch_sampler, num_workers, ...).
_NUM_WORKERS_POSITION = 5


@functools.wraps(_original_dataloader_init)
def _patched_dataloader_init(self, *args, **kwargs):
    """Wrap ``DataLoader.__init__`` and clear ``persistent_workers`` at 0 workers.

    persistent_workers=True is only safe when num_workers > 0.
    On Windows/macOS we force num_workers=0, so disable it there.
    """
    if "num_workers" in kwargs:
        workers = kwargs["num_workers"]
    elif len(args) > _NUM_WORKERS_POSITION:
        # num_workers was supplied positionally; honour it so a non-zero
        # positional value does not get persistent_workers switched off.
        workers = args[_NUM_WORKERS_POSITION]
    else:
        workers = 0
    if workers == 0:
        kwargs["persistent_workers"] = False
    return _original_dataloader_init(self, *args, **kwargs)


_data_module.DataLoader.__init__ = _patched_dataloader_init

# ---------------------------------------------------------------------------
# Delegate to F5-TTS CLI
# ---------------------------------------------------------------------------
from f5_tts.train.finetune_cli import main  # noqa: E402

if __name__ == "__main__":
    sys.exit(main())