import contextlib
import functools
import logging
import os

import torch

logger = logging.getLogger(__name__)

# fairseq is optional. The broad ``except`` is deliberate: any failure while
# importing it (not only ImportError) just selects the torchaudio fallback in
# ``load_hubert``.
try:
    from fairseq import checkpoint_utils

    FAIRSEQ_AVAILABLE = True
except Exception:
    FAIRSEQ_AVAILABLE = False

# fairseq HuBERT checkpoint, relative to the current working directory.
_HUBERT_CHECKPOINT = "assets/hubert/hubert_base.pt"


def _patch_torch_load():
    """Patch torch.load to default weights_only=False for fairseq compatibility (PyTorch 2.6+).

    Nothing is installed here: the current ``torch.load`` and a wrapper around
    it are returned, and the caller decides when to swap them. The wrapper only
    fills in ``weights_only=False`` when the caller did not pass it explicitly.

    SECURITY: ``weights_only=False`` unpickles arbitrary objects, so the
    wrapper must only be used on trusted checkpoint files.

    Returns:
        tuple: ``(original, patched)``.
    """
    original = torch.load

    @functools.wraps(original)
    def patched(*args, **kwargs):
        kwargs.setdefault("weights_only", False)
        return original(*args, **kwargs)

    return original, patched


@contextlib.contextmanager
def _torch_load_allowing_pickles():
    """Temporarily install the ``_patch_torch_load`` wrapper as ``torch.load``.

    The original ``torch.load`` is restored on exit, even if the body raises.
    NOTE: this rebinds the attribute on the global ``torch`` module, so any
    other code calling ``torch.load`` meanwhile also sees the patched default.
    """
    original, patched = _patch_torch_load()
    torch.load = patched
    try:
        yield
    finally:
        torch.load = original


def get_index_path_from_model(sid):
    """Return the first ``.index`` file under ``$index_root`` that matches *sid*.

    The directory named by the ``index_root`` environment variable is walked
    bottom-up (``topdown=False``). Files whose name contains ``"trained"`` are
    skipped. A file matches when the part of *sid* before its first ``"."``
    occurs anywhere in the file's full path (directory components included).

    Args:
        sid: Model file name, e.g. ``"voice.pth"``.

    Returns:
        str: Path of the first matching index file, or ``""`` when nothing
        matches or ``index_root`` is unset/empty.
    """
    index_root = os.getenv("index_root")
    if not index_root:
        # os.walk(None) raises TypeError; treat a missing setting as "no index".
        return ""
    stem = sid.split(".")[0]
    candidates = (
        os.path.join(root, name)
        for root, _, files in os.walk(index_root, topdown=False)
        for name in files
        if name.endswith(".index") and "trained" not in name
    )
    # Lazy: stops walking the tree as soon as a match is found.
    return next((path for path in candidates if stem in path), "")


class HubertWrapper:
    """Expose a torchaudio HuBERT model through a fairseq-style interface.

    Provides ``extract_features(source, padding_mask=..., output_layer=...)``
    and a ``final_proj`` attribute. ``to``/``half``/``float``/``eval`` forward
    to the wrapped model and return the wrapper so calls can be chained.
    """

    def __init__(self, model):
        self.model = model
        # If the wrapped model has no ``final_proj``, an identity module stands
        # in, so features pass through unprojected.
        # NOTE(review): confirm callers do not rely on a projected feature size.
        self.final_proj = getattr(model, "final_proj", torch.nn.Identity())

    def extract_features(self, source, padding_mask=None, output_layer=None):
        """Return ``(features, None)`` for the requested transformer layer.

        Args:
            source: Passed unchanged to the wrapped model's
                ``extract_features``.
            padding_mask: Accepted for interface compatibility; ignored.
            output_layer: 1-based layer number. ``None`` or any value below 1
                selects the last layer. Values past the last layer are clamped
                to it.

        Returns:
            tuple: ``(features, None)``.
        """
        layer_outputs, _ = self.model.extract_features(source)
        if output_layer is None or output_layer < 1:
            # Negative values previously indexed from the end by accident.
            index = -1
        else:
            index = min(output_layer - 1, len(layer_outputs) - 1)
        return (layer_outputs[index], None)

    def to(self, device):
        self.model = self.model.to(device)
        return self

    def half(self):
        self.model = self.model.half()
        return self

    def float(self):
        self.model = self.model.float()
        return self

    def eval(self):
        self.model.eval()
        return self


def _prepare_model(model, config):
    """Move *model* to ``config.device``, apply fp16/fp32 per ``config.is_half``, set eval mode."""
    model = model.to(config.device)
    model = model.half() if config.is_half else model.float()
    return model.eval()


def _load_fairseq_hubert(config):
    """Load the fairseq checkpoint with torch.load temporarily patched."""
    with _torch_load_allowing_pickles():
        models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
            [_HUBERT_CHECKPOINT],
            suffix="",
        )
    return _prepare_model(models[0], config)


def _load_torchaudio_hubert(config):
    """Load torchaudio's HUBERT_BASE model wrapped in ``HubertWrapper``.

    Any failure (missing torchaudio, download error, device error) is
    re-raised as RuntimeError, with the original exception chained.
    """
    try:
        import torchaudio

        model = torchaudio.pipelines.HUBERT_BASE.get_model()
        return _prepare_model(HubertWrapper(model), config)
    except Exception as e:
        raise RuntimeError(
            "HuBERT 模型加载失败,请检查 fairseq 和 torchaudio 是否已安装"
        ) from e


def load_hubert(config):
    """Load the HuBERT feature extractor.

    Uses fairseq and ``assets/hubert/hubert_base.pt`` when fairseq imported
    successfully. Otherwise it falls back to torchaudio's ``HUBERT_BASE``
    pipeline model wrapped in ``HubertWrapper``.

    Args:
        config: Object exposing ``device`` (passed to ``.to()``) and
            ``is_half`` (True -> fp16, False -> fp32).

    Returns:
        The model in eval mode, on ``config.device``, in the requested
        precision.

    Raises:
        RuntimeError: If the torchaudio fallback fails. Errors from the fairseq
            path propagate unchanged.
    """
    if FAIRSEQ_AVAILABLE:
        return _load_fairseq_hubert(config)
    logger.warning("fairseq is unavailable; loading HuBERT via torchaudio instead")
    return _load_torchaudio_hubert(config)