import importlib
import importlib.abc
import importlib.machinery
import os
import sys
import time
import types
import warnings
from pathlib import Path

import torch


# Stub speechbrain optional integrations (numba, k2_fsa) to prevent import errors on Windows.
# These are optional dependencies that may fail to import but are not required for inference.
class SpeechbrainIntegrationStubLoader(importlib.abc.Loader):
    """Loader that materialises empty placeholder modules for stubbed integrations."""

    def create_module(self, spec):
        """Return a fresh, empty module named after ``spec``."""
        return types.ModuleType(spec.name)

    def exec_module(self, module):
        """Set the bookkeeping attributes of a stub module; no code is executed."""
        spec = module.__spec__
        is_package = spec.submodule_search_locations is not None
        module.__file__ = ""
        # A package is its own ``__package__``; a plain module belongs to its parent.
        module.__package__ = spec.name if is_package else spec.name.rpartition(".")[0]
        if is_package:
            module.__path__ = []
        module.__all__ = []


class SpeechbrainIntegrationStubFinder(importlib.abc.MetaPathFinder):
    """Meta-path finder that stubs ``speechbrain.integrations`` modules that cannot be found."""

    NAMESPACE = "speechbrain.integrations"

    def find_spec(self, fullname, path, target=None):
        """Return a stub spec for an unresolvable module inside ``NAMESPACE``.

        Returns ``None`` (deferring to the other finders) when ``fullname`` is
        outside the namespace, is already in ``sys.modules``, or can be located
        by the regular path finder.
        """
        # Match the namespace itself or a dotted child only; a bare
        # ``startswith`` would also claim e.g. "speechbrain.integrationsfoo".
        in_namespace = fullname == self.NAMESPACE or fullname.startswith(self.NAMESPACE + ".")
        if not in_namespace:
            return None
        if fullname in sys.modules:
            return None
        if importlib.machinery.PathFinder.find_spec(fullname, path) is not None:
            # A real implementation exists; let the normal machinery load it.
            return None

        # NOTE(review): every name in this namespace contains a dot, so this is
        # always False and specs built here are never packages. Left unchanged
        # to preserve behaviour -- confirm whether stubs were meant to be packages.
        is_pkg = "." not in fullname
        spec = importlib.machinery.ModuleSpec(
            fullname, SpeechbrainIntegrationStubLoader(), is_package=is_pkg
        )
        if is_pkg:
            spec.submodule_search_locations = []
        # Register a placeholder right away so lookups for this name succeed
        # even if the returned spec is never loaded by the caller.
        placeholder = types.ModuleType(fullname)
        placeholder.__all__ = []
        sys.modules[fullname] = placeholder
        return spec


def _install_speechbrain_optional_integration_stub_finder():
    """Install the stub finder and pre-register the known integration stubs.

    Safe to call repeatedly: the finder is inserted at most once and entries
    already present in ``sys.modules`` are left untouched.
    """
    if not any(isinstance(finder, SpeechbrainIntegrationStubFinder) for finder in sys.meta_path):
        sys.meta_path.insert(0, SpeechbrainIntegrationStubFinder())

    if "speechbrain.integrations" not in sys.modules:
        base_mod = types.ModuleType("speechbrain.integrations")
        base_mod.__path__ = []
        sys.modules["speechbrain.integrations"] = base_mod

    for submodule in ["huggingface", "numba", "k2_fsa", "nlp"]:
        fullname = f"speechbrain.integrations.{submodule}"
        if fullname not in sys.modules:
            submod = types.ModuleType(fullname)
            submod.__path__ = []
            submod.__package__ = "speechbrain.integrations"
            submod.__all__ = []
            sys.modules[fullname] = submod


_install_speechbrain_optional_integration_stub_finder()

# NOTE: the "SEPFORNER" spelling is kept on purpose -- these constant and
# environment-variable names are part of the module's external interface.
SEPFORNER_MODEL_SOURCE = os.environ.get("SEPFORNER_MODEL_SOURCE", "speechbrain/sepformer-libri3mix")
SEPFORNER_MODEL_REVISION = os.environ.get("SEPFORNER_MODEL_REVISION", "main")
SEPFORNER_REQUIRED_FILES = ("hyperparams.yaml", "encoder.ckpt", "decoder.ckpt", "masknet.ckpt")


def _local_sepformer_dir() -> Path:
    """Return the absolute path of ``./pretrained_sepformer`` (relative to the CWD)."""
    return Path(os.path.abspath("./pretrained_sepformer"))


def _missing_sepformer_files(local_dir: Path):
    """Return the required file names that are absent or zero-length in ``local_dir``."""
    return [
        filename
        for filename in SEPFORNER_REQUIRED_FILES
        if not (local_dir / filename).is_file() or (local_dir / filename).stat().st_size == 0
    ]


def _download_missing_sepformer_files(local_dir: Path) -> None:
    """Download every missing or empty required asset into ``local_dir``.

    Each file is attempted up to three times with exponential backoff.

    Raises:
        ModuleNotFoundError: if ``huggingface_hub`` is not installed.
        RuntimeError: if a file still cannot be downloaded after all attempts.
    """
    local_dir.mkdir(parents=True, exist_ok=True)
    try:
        from huggingface_hub import hf_hub_download
    except ModuleNotFoundError as exc:
        raise ModuleNotFoundError(
            "huggingface_hub is required to download SepFormer assets. Install it with `pip install huggingface_hub`."
        ) from exc

    print(f"SepFormer source: {SEPFORNER_MODEL_SOURCE}@{SEPFORNER_MODEL_REVISION}")
    max_retries = 3
    for filename in SEPFORNER_REQUIRED_FILES:
        local_path = local_dir / filename
        is_file = local_path.is_file()
        file_size = local_path.stat().st_size if is_file else 0
        if is_file and file_size > 0:
            print(f"Using existing SepFormer asset: {local_path}")
            continue

        status_msg = "missing" if not is_file else "empty"
        print(f"Local asset '{filename}' is {status_msg}. Downloading from '{SEPFORNER_MODEL_SOURCE}' to '{local_dir}'...")
        last_error = None
        for attempt in range(max_retries):
            try:
                hf_hub_download(
                    repo_id=SEPFORNER_MODEL_SOURCE,
                    filename=filename,
                    revision=SEPFORNER_MODEL_REVISION,
                    local_dir=str(local_dir),
                    local_dir_use_symlinks=False,
                )
                break
            except Exception as exc:
                last_error = exc
                if attempt + 1 < max_retries:
                    wait_time = 2 ** attempt
                    print(f"Attempt {attempt + 1}/{max_retries} failed for '{filename}'. Retrying in {wait_time}s...")
                    time.sleep(wait_time)
                else:
                    # Final attempt: nothing left to retry, so do not announce a
                    # retry or sleep before the error below is raised.
                    print(f"Attempt {attempt + 1}/{max_retries} failed for '{filename}'.")
        else:
            # Reached only when no attempt hit ``break``, i.e. all of them failed.
            raise RuntimeError(
                f"Failed to download '{filename}' from '{SEPFORNER_MODEL_SOURCE}' after {max_retries} attempts. "
                f"Check network connectivity or Hugging Face Hub status. Original error: {last_error}"
            ) from last_error


def ensure_local_sepformer_assets() -> Path:
    """Ensure all required SepFormer files exist locally and return their directory.

    Missing or empty files are downloaded first.

    Raises:
        FileNotFoundError: if required files are still missing after the download step.
    """
    local_dir = _local_sepformer_dir()
    missing = _missing_sepformer_files(local_dir)
    if missing:
        _download_missing_sepformer_files(local_dir)
        missing = _missing_sepformer_files(local_dir)
    if missing:
        raise FileNotFoundError(
            f"Local pretrained SepFormer directory '{local_dir}' is missing required files: {missing}. "
            f"Set SEPFORNER_MODEL_SOURCE to a valid SpeechBrain SepFormer model and rerun the application."
        )
    return local_dir


class UnifiedSepFormer(torch.nn.Module):
    """Single module chaining the encoder, mask network and decoder of a SepFormer."""

    def __init__(self, modules_dict):
        """Store the ``'encoder'``, ``'masknet'`` and ``'decoder'`` entries of ``modules_dict``."""
        super().__init__()
        self.encoder = modules_dict['encoder']
        self.masknet = modules_dict['masknet']
        self.decoder = modules_dict['decoder']

    def forward(self, mix):
        """Separate ``mix`` and return the decoded sources stacked on a new last axis.

        One source is produced per entry along the first axis of the mask
        returned by ``masknet``.
        """
        mix_w = self.encoder(mix)
        est_mask = self.masknet(mix_w)
        decoded_sources = []
        for i in range(est_mask.shape[0]):
            sep_h_i = mix_w * est_mask[i]
            est_source_i = self.decoder(sep_h_i)
            decoded_sources.append(est_source_i.unsqueeze(-1))
        return torch.cat(decoded_sources, dim=-1)


def load_model(checkpoint_path=None):
    """Build a ``UnifiedSepFormer`` from the local pretrained SepFormer assets.

    Args:
        checkpoint_path: optional path to a checkpoint whose weights replace the
            pretrained ones. The checkpoint may be a bare state dict or a dict
            holding it under ``"model_state_dict"`` or ``"state_dict"``.

    Returns:
        The model in eval mode. If the checkpoint path does not exist, or the
        checkpoint cannot be loaded even with ``strict=False``, the pretrained
        weights are used and a warning is printed.

    Raises:
        ModuleNotFoundError: if SpeechBrain is not installed.
        ImportError: if SpeechBrain fails while importing an optional integration.
        FileNotFoundError / RuntimeError: propagated from asset preparation.
    """
    try:
        speechbrain_inference = importlib.import_module("speechbrain.inference.separation")
        speechbrain_fetching = importlib.import_module("speechbrain.utils.fetching")
    except ModuleNotFoundError as exc:
        raise ModuleNotFoundError(
            "SpeechBrain is required for SepFormer model loading. Install it with `pip install speechbrain` and a compatible `k2` package, or use a separate environment where SpeechBrain is supported."
        ) from exc

    _install_speechbrain_optional_integration_stub_finder()
    SepformerSeparation = getattr(speechbrain_inference, "SepformerSeparation")
    LocalStrategy = getattr(speechbrain_fetching, "LocalStrategy")
    local_sepformer_dir = ensure_local_sepformer_assets()

    try:
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=DeprecationWarning)
            model_hub = SepformerSeparation.from_hparams(
                source=str(local_sepformer_dir),
                savedir=str(local_sepformer_dir),
                local_strategy=LocalStrategy.COPY_SKIP_CACHE,
            )
    except ImportError as exc:
        msg = str(exc)
        if "speechbrain.integrations.k2_fsa" in msg or "Please install k2 to use k2" in msg or "No module named '_k2'" in msg:
            raise ImportError(
                "SpeechBrain attempted to load the optional k2 integration and failed. "
                "This often happens on Windows because k2 is not available or the installed wheel is incompatible. "
                "If you do not need k2 features, use a SpeechBrain install that does not require k2 or run this project on Linux. "
                "Original error: " + msg
            ) from exc
        raise

    model = UnifiedSepFormer(model_hub.mods)
    if checkpoint_path is None:
        model.eval()
        return model
    if not os.path.exists(checkpoint_path):
        print(f"WARNING: checkpoint '{checkpoint_path}' not found. Using local pretrained model instead.")
        model.eval()
        return model

    # SECURITY NOTE: torch.load unpickles the file; only load checkpoints from
    # trusted sources (or pass weights_only=True where the torch version allows).
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", message=r"TypedStorage is deprecated.*")
        checkpoint = torch.load(
            checkpoint_path,
            map_location="cpu"
        )

    if isinstance(checkpoint, dict):
        if "model_state_dict" in checkpoint:
            state_dict = checkpoint["model_state_dict"]
        elif "state_dict" in checkpoint:
            state_dict = checkpoint["state_dict"]
        else:
            state_dict = checkpoint
    else:
        state_dict = checkpoint

    # Snapshot the pretrained weights first: a failing load_state_dict may
    # already have copied the compatible tensors before raising, which would
    # otherwise leave a mix of checkpoint and pretrained weights behind.
    pretrained_state = {key: value.detach().clone() for key, value in model.state_dict().items()}

    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        print("WARNING: checkpoint is incompatible with the local SepFormer architecture.")
        print("Attempting relaxed load with strict=False.")
        try:
            load_result = model.load_state_dict(state_dict, strict=False)
        except RuntimeError as err2:
            print("Relaxed checkpoint load also failed. Using local pretrained SepFormer weights from './pretrained_sepformer' instead.")
            print(err2)
            # Undo any partial copy so the message above is actually true.
            model.load_state_dict(pretrained_state)
        else:
            missing = getattr(load_result, "missing_keys", None)
            unexpected = getattr(load_result, "unexpected_keys", None)
            if missing:
                print("Missing keys from checkpoint:", missing)
            if unexpected:
                print("Unexpected keys in checkpoint:", unexpected)
            print("Relaxed checkpoint load succeeded. Using loaded weights where possible.")

    model.eval()
    return model