import io import os import pickle from typing import Optional, Dict, Callable import numpy as np import torch import torch.nn as nn import torchaudio from transformers import PreTrainedModel from .configuration_speech_encoder import SpeechEncoderConfig def wrap_bos_eos( units: torch.Tensor, durations: torch.Tensor, f0: torch.Tensor | None, dense_features: torch.Tensor, bos: torch.Tensor, eos: torch.Tensor, ): # bos/eos are 1-element tensors on the right device/dtype one = durations.new_ones(1) units = torch.cat([bos.to(units.device), units, eos.to(units.device)], dim=0) durations = torch.cat([one, durations, one], dim=0) if f0 is not None: # pad f0 with edge values f0 = torch.cat([f0[:1], f0, f0[-1:]], dim=0) return units, durations, f0, dense_features # ---------------------------- # Dense feature backends (HuBERT) # ---------------------------- class _FairseqHubertDense(nn.Module): """ Loads a fairseq HuBERT checkpoint (.pt) and exposes extract_features() at a given transformer layer. """ def __init__(self, ckpt_path: str, layer: int, expected_sr: int = 16000, hop: int = 320): super().__init__() try: from fairseq import checkpoint_utils except Exception as e: raise ImportError( "fairseq is required to load a .pt HuBERT checkpoint. " "Please `pip install fairseq` in your runtime." ) from e models, _, _ = checkpoint_utils.load_model_ensemble_and_task([ckpt_path]) self.model = models[0] self.model.eval() for p in self.model.parameters(): p.requires_grad_(False) self.output_layer = int(layer) self.expected_sample_rate = int(expected_sr) self.code_hop_size = int(hop) @torch.no_grad() def forward(self, waveform: torch.Tensor) -> torch.Tensor: if waveform.ndim > 1: waveform = waveform.mean(0) wav = waveform.unsqueeze(0) # (1, T) # fairseq HuBERT exposes extract_features(...) feats, _ = self.model.extract_features(wav, output_layer=self.output_layer) # feats: (B, T, C) return feats[0] # (T, C) class _TransformersHubertDense(nn.Module): """ Uses transformers' facebook/hubert-* checkpoints. """ def __init__(self, hf_name: str, layer: int, expected_sr: int = 16000, hop: int = 320): super().__init__() from transformers import AutoModel self.backbone = AutoModel.from_pretrained(hf_name) self.backbone.eval() for p in self.backbone.parameters(): p.requires_grad_(False) self.layer = int(layer) self.expected_sample_rate = int(expected_sr) self.code_hop_size = int(hop) @torch.no_grad() def forward(self, waveform: torch.Tensor) -> torch.Tensor: if waveform.ndim > 1: waveform = waveform.mean(0) # transformers hubert expects (batch, time); we pass raw PCM; # hidden_states=True to get all layers out = self.backbone( inputs_embeds=None, input_values=waveform.unsqueeze(0), output_hidden_states=True, ) # hidden_states is a tuple: [emb, layer1, ..., layerN] hidden = out.hidden_states[self.layer] # (B, T, C) return hidden[0] # ---------------------------- # KMeans quantizer # ---------------------------- class KMeansQuantizer(nn.Module): """ Simple KMeans quantizer: nearest center assignment per frame. Loads centers from: * .pt (Tensor or dict with keys: cluster_centers, cluster_centers_, centroids, centers) * .npy * pickle/joblib of a scikit KMeans with .cluster_centers_ """ def __init__(self, centers: torch.Tensor): super().__init__() assert centers.ndim == 2, "centers must be (K, D)" self.register_buffer("centers", centers.float()) @property def vocab_size(self) -> int: return int(self.centers.size(0)) @staticmethod def _to_tensor(x): if torch.is_tensor(x): return x return torch.from_numpy(np.asarray(x)) @classmethod def from_file(cls, path: str, key: str = "") -> "KMeansQuantizer": path = os.fspath(path) if not os.path.exists(path): raise FileNotFoundError(f"KMeans file not found: {path}") centers = None if path.endswith(".pt") or path.endswith(".pth"): obj = torch.load(path, map_location="cpu") if torch.is_tensor(obj): centers = obj elif isinstance(obj, dict): for k in [key, "cluster_centers", "cluster_centers_", "centroids", "centers"]: if k and k in obj: centers = cls._to_tensor(obj[k]); break if centers is None: # Some dumps wrap centers deeper: {'state': {'centers': ...}} for v in obj.values(): if isinstance(v, dict): for k in ["cluster_centers", "cluster_centers_", "centroids", "centers"]: if k in v: centers = cls._to_tensor(v[k]); break if centers is not None: break if centers is None and path.endswith(".npy"): centers = torch.from_numpy(np.load(path)) if centers is None: # Try joblib/pickle try: import joblib obj = joblib.load(path) except Exception: with open(path, "rb") as f: obj = pickle.load(f) if hasattr(obj, "cluster_centers_"): centers = torch.from_numpy(np.asarray(obj.cluster_centers_)) if centers is None: raise ValueError( f"Could not load KMeans centers from {path}. " "Supported: .pt (tensor/dict), .npy, pickled sklearn KMeans." ) return cls(centers.float()) @torch.no_grad() def forward(self, dense_features: torch.Tensor) -> torch.Tensor: """ dense_features: (T, D) or (B, T, D) -> returns (T,) or (B,T,) int64 """ x = dense_features if x.ndim == 2: # (T, D) dist = torch.cdist(x.to(self.centers.dtype), self.centers) # (T, K) return torch.argmin(dist, dim=-1).to(torch.long) elif x.ndim == 3: # (B, T, D) B, T, D = x.shape x2 = x.reshape(B * T, D) dist = torch.cdist(x2.to(self.centers.dtype), self.centers) # (B*T, K) ids = torch.argmin(dist, dim=-1).to(torch.long).view(B, T) return ids else: raise ValueError("dense_features must be (T,D) or (B,T,D)") # ---------------------------- # SpeechEncoder (HF-ready) # ---------------------------- F0_FRAME_SPACE = 0.01 # seconds; kept for API completeness (we don't compute F0 here) class SpeechEncoder(PreTrainedModel): """ Hugging Face–ready port of the Textless 'SpeechEncoder'. * Has the same public methods as the original (by_name, maybe_resample, forward, properties). * Loads your uploaded HuBERT checkpoint and KMeans centers from the repo. * `need_f0` is supported as a flag, but F0 extraction is not implemented in this minimal port. """ config_class = SpeechEncoderConfig def __init__( self, dense_model: nn.Module, quantizer_model: nn.Module, deduplicate: bool, add_bos_eos: bool = False, need_f0: bool = False, f0_normalizer: Optional[Callable] = None, f0_quantizer: Optional[Callable] = None, config: Optional[SpeechEncoderConfig] = None, ): super().__init__(config if config is not None else SpeechEncoderConfig()) self.dense_model = dense_model self.quantizer_model = quantizer_model self.deduplicate = bool(deduplicate) self.add_bos_eos = bool(add_bos_eos) self.need_f0 = bool(need_f0) self.f0_normalizer = f0_normalizer self.f0_quantizer = f0_quantizer self.unit_vocab_size = int(self.quantizer_model.vocab_size) bos_id = self.config.bos_id if self.config and self.config.bos_id is not None else self.unit_vocab_size eos_id = self.config.eos_id if self.config and self.config.eos_id is not None else self.unit_vocab_size + 1 self.register_buffer("bos", torch.tensor([bos_id], dtype=torch.long)) self.register_buffer("eos", torch.tensor([eos_id], dtype=torch.long)) # used only for device tracking (mimics the original) self.register_buffer("_float_tensor", torch.tensor([0.0], dtype=torch.float)) # Optional feature normalization before K-Means self._feature_norm = getattr(self.config, "feature_norm", None) # ---------- HF convenience: override from_pretrained to pick up assets ---------- @classmethod def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): """ Loads config, constructs dense+quantizer from files inside the repo, and returns a ready-to-use SpeechEncoder (no weights to load into state_dict). """ config = SpeechEncoderConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) # Resolve local paths to uploaded assets repo_root = os.fspath(pretrained_model_name_or_path) hubert_path = os.path.join(repo_root, config.hubert_ckpt) quant_path = os.path.join(repo_root, config.quantizer_file) # Dense backend if config.hubert_backend == "fairseq": dense = _FairseqHubertDense( ckpt_path=hubert_path, layer=config.hubert_layer, expected_sr=config.expected_sample_rate, hop=config.code_hop_size, ) elif config.hubert_backend == "transformers": dense = _TransformersHubertDense( hf_name=config.hubert_hf_name, layer=config.hubert_layer, expected_sr=config.expected_sample_rate, hop=config.code_hop_size, ) else: raise ValueError("hubert_backend must be 'fairseq' or 'transformers'") # Quantizer quant = KMeansQuantizer.from_file(quant_path, key=config.quantizer_key) # Construct the encoder (HF PreTrainedModel base will still attach config) model = cls( dense_model=dense, quantizer_model=quant, deduplicate=config.deduplicate, add_bos_eos=config.add_bos_eos, need_f0=config.need_f0, f0_normalizer=None, f0_quantizer=None, config=config, ) return model # ---------- Original "by_name" API (kept for drop-in parity) ---------- @classmethod def by_name( cls, dense_model_name: str, quantizer_model_name: str, vocab_size: int, deduplicate: bool, add_bos_eos: bool = False, need_f0: bool = False, f0_normalizer: Optional[Callable] = None, f0_quantizer: Optional[Callable] = None, # HF-specific args to locate assets if you don't use .from_pretrained(...) hubert_backend: str = "fairseq", hubert_ckpt: Optional[str] = None, hubert_hf_name: str = "facebook/hubert-base-ls960", hubert_layer: int = 9, quantizer_file: Optional[str] = None, quantizer_key: str = "", expected_sample_rate: int = 16000, code_hop_size: int = 320, ) -> "SpeechEncoder": """ Mirrors textlesslib's SpeechEncoder.by_name. For HF usage prefer: AutoModel.from_pretrained(repo, trust_remote_code=True) """ # dense_model_name is kept for parity; we only support HuBERT here if hubert_backend == "fairseq": if not hubert_ckpt: raise ValueError("Provide hubert_ckpt (path to .pt) when hubert_backend='fairseq'.") dense = _FairseqHubertDense(hubert_ckpt, layer=hubert_layer, expected_sr=expected_sample_rate, hop=code_hop_size) elif hubert_backend == "transformers": dense = _TransformersHubertDense(hubert_hf_name, layer=hubert_layer, expected_sr=expected_sample_rate, hop=code_hop_size) else: raise ValueError("hubert_backend must be 'fairseq' or 'transformers'") if quantizer_model_name.lower() != "kmeans": raise ValueError("Only 'kmeans' quantizer is supported in this port.") if not quantizer_file: raise ValueError("Provide quantizer_file (path to centers).") quant = KMeansQuantizer.from_file(quantizer_file, key=quantizer_key) # Sanity check on vocab size if user passed it if vocab_size is not None and int(vocab_size) != quant.vocab_size: raise ValueError(f"vocab_size={vocab_size} does not match centers K={quant.vocab_size}") cfg = SpeechEncoderConfig( hubert_backend=hubert_backend, hubert_ckpt=hubert_ckpt or "", hubert_hf_name=hubert_hf_name, hubert_layer=hubert_layer, expected_sample_rate=expected_sample_rate, code_hop_size=code_hop_size, quantizer_file=os.path.basename(quantizer_file), deduplicate=deduplicate, add_bos_eos=add_bos_eos, need_f0=need_f0, ) return cls(dense, quant, deduplicate, add_bos_eos, need_f0, f0_normalizer, f0_quantizer, config=cfg) # ---------- Properties (parity) ---------- @property def device(self) -> torch.device: return self._float_tensor.device @property def vocab_size(self) -> int: return self.quantizer_model.vocab_size @property def code_hop_size(self) -> int: return getattr(self.dense_model, "code_hop_size", 320) @property def expected_sample_rate(self) -> int: return getattr(self.dense_model, "expected_sample_rate", 16000) @property def f0_code_ratio(self) -> float: # F0 frames per unit frame return self.code_hop_size / self.expected_sample_rate / F0_FRAME_SPACE # ---------- Resampling ---------- def maybe_resample(self, waveform: torch.Tensor, input_sample_rate: int) -> torch.Tensor: if int(input_sample_rate) == int(self.expected_sample_rate): return waveform return torchaudio.functional.resample( waveform, int(input_sample_rate), int(self.expected_sample_rate) ) # ---------- Forward (parity) ---------- @torch.no_grad() def forward(self, waveform: torch.Tensor, speaker: Optional[str] = None) -> Dict[str, torch.Tensor]: """ Returns: { "units": LongTensor [U], "durations": LongTensor [U], (frame counts) "dense": FloatTensor [T, D], (optional) "f0": FloatTensor [U] or [T_f0] if implemented } """ # 1) Dense features at HuBERT frame rate dense_features = self.dense_model(waveform) # (T, D) # optional feature normalization before KMeans (kept simple) if self._feature_norm == "unit": eps = 1e-6 dense_features = dense_features / (dense_features.norm(dim=-1, keepdim=True) + eps) elif self._feature_norm == "layernorm": mean = dense_features.mean(dim=-1, keepdim=True) std = dense_features.std(dim=-1, keepdim=True).clamp_min(1e-5) dense_features = (dense_features - mean) / std # 2) KMeans quantization → unit ids (per-frame) ids_per_frame = self.quantizer_model(dense_features) # (T,) # 3) Dedup → durations if self.deduplicate: units, durations = torch.unique_consecutive(ids_per_frame, return_counts=True) else: units = ids_per_frame durations = torch.ones_like(units, dtype=torch.long) # 4) (Optional) F0 path — not bundled here f0 = None if self.need_f0: raise NotImplementedError( "F0 extraction is not included in this minimal HF port. " "Set need_f0=False (as in the reference pipeline)." ) # 5) BOS/EOS wrap (if requested) if self.add_bos_eos: units, durations, f0, dense_features = wrap_bos_eos( units, durations, f0, dense_features, self.bos, self.eos ) item = { "units": units.to(self.device), "durations": durations.to(self.device), "dense": dense_features.to(self.device), } if f0 is not None: item["f0"] = f0.to(self.device) return item