import os
import pickle
from typing import Optional, Dict, Callable

import numpy as np
import torch
import torch.nn as nn
import torchaudio

from transformers import PreTrainedModel

from .configuration_speech_encoder import SpeechEncoderConfig

def wrap_bos_eos(
    units: torch.Tensor,
    durations: torch.Tensor,
    f0: Optional[torch.Tensor],
    dense_features: torch.Tensor,
    bos: torch.Tensor,
    eos: torch.Tensor,
):
    """Wrap a unit sequence with BOS/EOS ids, keeping durations (and f0) aligned."""
    one = durations.new_ones(1)
    units = torch.cat([bos.to(units.device), units, eos.to(units.device)], dim=0)
    durations = torch.cat([one, durations, one], dim=0)
    if f0 is not None:
        # Replicate the first and last F0 frames so f0 stays the same length as units.
        f0 = torch.cat([f0[:1], f0, f0[-1:]], dim=0)
    return units, durations, f0, dense_features
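
# A minimal shape sketch (illustrative values, not from a real checkpoint):
#
#     units = torch.tensor([5, 7, 5])          # (U,)
#     durations = torch.tensor([3, 2, 1])      # (U,)
#     dense = torch.randn(6, 768)              # (T, D), T == durations.sum()
#     bos, eos = torch.tensor([100]), torch.tensor([101])
#     u, d, f, _ = wrap_bos_eos(units, durations, None, dense, bos, eos)
#     # u -> [100, 5, 7, 5, 101]; d -> [1, 3, 2, 1, 1]; f stays None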

class _FairseqHubertDense(nn.Module):
    """
    Loads a fairseq HuBERT checkpoint (.pt) and exposes extract_features() at a
    given transformer layer.
    """

    def __init__(self, ckpt_path: str, layer: int, expected_sr: int = 16000, hop: int = 320):
        super().__init__()
        try:
            from fairseq import checkpoint_utils
        except Exception as e:
            raise ImportError(
                "fairseq is required to load a .pt HuBERT checkpoint. "
                "Please `pip install fairseq` in your runtime."
            ) from e

        models, _, _ = checkpoint_utils.load_model_ensemble_and_task([ckpt_path])
        self.model = models[0]
        self.model.eval()
        for p in self.model.parameters():
            p.requires_grad_(False)

        self.output_layer = int(layer)
        self.expected_sample_rate = int(expected_sr)
        self.code_hop_size = int(hop)

    @torch.no_grad()
    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        # Downmix multi-channel audio to mono before feature extraction.
        if waveform.ndim > 1:
            waveform = waveform.mean(0)
        wav = waveform.unsqueeze(0)  # (1, S): fairseq expects a batch dimension

        feats, _ = self.model.extract_features(wav, output_layer=self.output_layer)

        return feats[0]  # (T, D)

class _TransformersHubertDense(nn.Module):
    """
    Uses transformers' facebook/hubert-* checkpoints.
    """

    def __init__(self, hf_name: str, layer: int, expected_sr: int = 16000, hop: int = 320):
        super().__init__()
        from transformers import AutoModel

        self.backbone = AutoModel.from_pretrained(hf_name)
        self.backbone.eval()
        for p in self.backbone.parameters():
            p.requires_grad_(False)
        self.layer = int(layer)
        self.expected_sample_rate = int(expected_sr)
        self.code_hop_size = int(hop)

    @torch.no_grad()
    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        # Downmix multi-channel audio to mono before feature extraction.
        if waveform.ndim > 1:
            waveform = waveform.mean(0)

        out = self.backbone(
            input_values=waveform.unsqueeze(0),
            output_hidden_states=True,
        )

        # hidden_states[0] is the pre-transformer feature projection, so index
        # `layer` selects the output of the `layer`-th transformer block, which
        # lines up with fairseq's 1-based `output_layer` convention.
        hidden = out.hidden_states[self.layer]
        return hidden[0]  # (T, D)
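
# Either backend yields a (T, D) float tensor of frame-level features. A hedged
# sketch (assumes network access to the `facebook/hubert-base-ls960` weights):
#
#     dense = _TransformersHubertDense("facebook/hubert-base-ls960", layer=9)
#     feats = dense(torch.randn(16000))  # 1 s of 16 kHz audio -> roughly (49, 768)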

class KMeansQuantizer(nn.Module):
    """
    Simple KMeans quantizer: nearest-center assignment per frame.
    Loads centers from:

    * .pt (Tensor or dict with keys: cluster_centers, cluster_centers_, centroids, centers)
    * .npy
    * pickle/joblib of a scikit-learn KMeans with .cluster_centers_
    """

    def __init__(self, centers: torch.Tensor):
        super().__init__()
        assert centers.ndim == 2, "centers must be (K, D)"
        self.register_buffer("centers", centers.float())

    @property
    def vocab_size(self) -> int:
        return int(self.centers.size(0))

    @staticmethod
    def _to_tensor(x):
        if torch.is_tensor(x):
            return x
        return torch.from_numpy(np.asarray(x))

    @classmethod
    def from_file(cls, path: str, key: str = "") -> "KMeansQuantizer":
        path = os.fspath(path)
        if not os.path.exists(path):
            raise FileNotFoundError(f"KMeans file not found: {path}")

        centers = None

        if path.endswith(".pt") or path.endswith(".pth"):
            obj = torch.load(path, map_location="cpu")
            if torch.is_tensor(obj):
                centers = obj
            elif isinstance(obj, dict):
                # Try the user-provided key first, then common center keys.
                for k in [key, "cluster_centers", "cluster_centers_", "centroids", "centers"]:
                    if k and k in obj:
                        centers = cls._to_tensor(obj[k])
                        break
                if centers is None:
                    # Fall back to scanning one level of nested dicts.
                    for v in obj.values():
                        if isinstance(v, dict):
                            for k in ["cluster_centers", "cluster_centers_", "centroids", "centers"]:
                                if k in v:
                                    centers = cls._to_tensor(v[k])
                                    break
                        if centers is not None:
                            break

        if centers is None and path.endswith(".npy"):
            centers = torch.from_numpy(np.load(path))

        if centers is None:
            # Last resort: a pickled / joblib-dumped sklearn KMeans object.
            try:
                import joblib

                obj = joblib.load(path)
            except Exception:
                with open(path, "rb") as f:
                    obj = pickle.load(f)
            if hasattr(obj, "cluster_centers_"):
                centers = torch.from_numpy(np.asarray(obj.cluster_centers_))

        if centers is None:
            raise ValueError(
                f"Could not load KMeans centers from {path}. "
                "Supported: .pt (tensor/dict), .npy, pickled sklearn KMeans."
            )

        return cls(centers.float())

    @torch.no_grad()
    def forward(self, dense_features: torch.Tensor) -> torch.Tensor:
        """
        dense_features: (T, D) or (B, T, D) -> returns (T,) or (B, T) int64
        """
        x = dense_features
        if x.ndim == 2:
            dist = torch.cdist(x.to(self.centers.dtype), self.centers)
            return torch.argmin(dist, dim=-1).to(torch.long)
        elif x.ndim == 3:
            B, T, D = x.shape
            x2 = x.reshape(B * T, D)
            dist = torch.cdist(x2.to(self.centers.dtype), self.centers)
            ids = torch.argmin(dist, dim=-1).to(torch.long).view(B, T)
            return ids
        else:
            raise ValueError("dense_features must be (T, D) or (B, T, D)")
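
# Hedged usage sketch with a synthetic codebook (real centers come from a file):
#
#     km = KMeansQuantizer(torch.randn(100, 768))   # K=100 centers, D=768
#     ids = km(torch.randn(250, 768))               # -> LongTensor, shape (250,)
#     batched = km(torch.randn(2, 250, 768))        # -> LongTensor, shape (2, 250)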

# Spacing (in seconds) between F0 analysis frames: 10 ms. Only used by the
# `f0_code_ratio` property below, since F0 extraction itself is not implemented
# in this port.
F0_FRAME_SPACE = 0.01

class SpeechEncoder(PreTrainedModel):
    """
    Hugging Face-ready port of the Textless 'SpeechEncoder'.

    * Has the same public methods as the original (by_name, maybe_resample, forward, properties).
    * Loads your uploaded HuBERT checkpoint and KMeans centers from the repo.
    * `need_f0` is supported as a flag, but F0 extraction is not implemented in this minimal port.
    """

    config_class = SpeechEncoderConfig

    def __init__(
        self,
        dense_model: nn.Module,
        quantizer_model: nn.Module,
        deduplicate: bool,
        add_bos_eos: bool = False,
        need_f0: bool = False,
        f0_normalizer: Optional[Callable] = None,
        f0_quantizer: Optional[Callable] = None,
        config: Optional[SpeechEncoderConfig] = None,
    ):
        super().__init__(config if config is not None else SpeechEncoderConfig())
        self.dense_model = dense_model
        self.quantizer_model = quantizer_model

        self.deduplicate = bool(deduplicate)
        self.add_bos_eos = bool(add_bos_eos)
        self.need_f0 = bool(need_f0)
        self.f0_normalizer = f0_normalizer
        self.f0_quantizer = f0_quantizer

        self.unit_vocab_size = int(self.quantizer_model.vocab_size)

        # BOS/EOS ids default to the two slots just past the unit vocabulary.
        bos_id = self.config.bos_id if self.config and self.config.bos_id is not None else self.unit_vocab_size
        eos_id = self.config.eos_id if self.config and self.config.eos_id is not None else self.unit_vocab_size + 1
        self.register_buffer("bos", torch.tensor([bos_id], dtype=torch.long))
        self.register_buffer("eos", torch.tensor([eos_id], dtype=torch.long))

        # Dummy buffer whose device tracks the module's device (see `device`).
        self.register_buffer("_float_tensor", torch.tensor([0.0], dtype=torch.float))

        # Optional dense-feature normalization mode: None, "unit", or "layernorm".
        self._feature_norm = getattr(self.config, "feature_norm", None)
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """
        Loads config, constructs dense+quantizer from files inside the repo,
        and returns a ready-to-use SpeechEncoder (no weights to load into state_dict).
        """
        config = SpeechEncoderConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        # Resolve checkpoint paths relative to the repo root. Note this assumes
        # a local directory; a hub repo id would need the files downloaded first.
        repo_root = os.fspath(pretrained_model_name_or_path)
        hubert_path = os.path.join(repo_root, config.hubert_ckpt)
        quant_path = os.path.join(repo_root, config.quantizer_file)

        if config.hubert_backend == "fairseq":
            dense = _FairseqHubertDense(
                ckpt_path=hubert_path,
                layer=config.hubert_layer,
                expected_sr=config.expected_sample_rate,
                hop=config.code_hop_size,
            )
        elif config.hubert_backend == "transformers":
            dense = _TransformersHubertDense(
                hf_name=config.hubert_hf_name,
                layer=config.hubert_layer,
                expected_sr=config.expected_sample_rate,
                hop=config.code_hop_size,
            )
        else:
            raise ValueError("hubert_backend must be 'fairseq' or 'transformers'")

        quant = KMeansQuantizer.from_file(quant_path, key=config.quantizer_key)

        model = cls(
            dense_model=dense,
            quantizer_model=quant,
            deduplicate=config.deduplicate,
            add_bos_eos=config.add_bos_eos,
            need_f0=config.need_f0,
            f0_normalizer=None,
            f0_quantizer=None,
            config=config,
        )
        return model
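
    # Hedged usage sketch (assumes a locally cloned repo directory containing
    # the checkpoint files referenced by its config):
    #
    #     enc = SpeechEncoder.from_pretrained("./my-speech-encoder-repo")
    #     out = enc(torch.randn(16000))   # keys: "units", "durations", "dense"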

    @classmethod
    def by_name(
        cls,
        dense_model_name: str,
        quantizer_model_name: str,
        vocab_size: int,
        deduplicate: bool,
        add_bos_eos: bool = False,
        need_f0: bool = False,
        f0_normalizer: Optional[Callable] = None,
        f0_quantizer: Optional[Callable] = None,
        # Backend-specific knobs added by this port:
        hubert_backend: str = "fairseq",
        hubert_ckpt: Optional[str] = None,
        hubert_hf_name: str = "facebook/hubert-base-ls960",
        hubert_layer: int = 9,
        quantizer_file: Optional[str] = None,
        quantizer_key: str = "",
        expected_sample_rate: int = 16000,
        code_hop_size: int = 320,
    ) -> "SpeechEncoder":
        """
        Mirrors textlesslib's SpeechEncoder.by_name. For HF usage prefer:
        AutoModel.from_pretrained(repo, trust_remote_code=True)
        """
        if hubert_backend == "fairseq":
            if not hubert_ckpt:
                raise ValueError("Provide hubert_ckpt (path to .pt) when hubert_backend='fairseq'.")
            dense = _FairseqHubertDense(
                hubert_ckpt, layer=hubert_layer,
                expected_sr=expected_sample_rate, hop=code_hop_size,
            )
        elif hubert_backend == "transformers":
            dense = _TransformersHubertDense(
                hubert_hf_name, layer=hubert_layer,
                expected_sr=expected_sample_rate, hop=code_hop_size,
            )
        else:
            raise ValueError("hubert_backend must be 'fairseq' or 'transformers'")

        if quantizer_model_name.lower() != "kmeans":
            raise ValueError("Only 'kmeans' quantizer is supported in this port.")
        if not quantizer_file:
            raise ValueError("Provide quantizer_file (path to centers).")
        quant = KMeansQuantizer.from_file(quantizer_file, key=quantizer_key)

        # Sanity-check the requested vocab size against the loaded codebook.
        if vocab_size is not None and int(vocab_size) != quant.vocab_size:
            raise ValueError(f"vocab_size={vocab_size} does not match centers K={quant.vocab_size}")

        cfg = SpeechEncoderConfig(
            hubert_backend=hubert_backend,
            hubert_ckpt=hubert_ckpt or "",
            hubert_hf_name=hubert_hf_name,
            hubert_layer=hubert_layer,
            expected_sample_rate=expected_sample_rate,
            code_hop_size=code_hop_size,
            quantizer_file=os.path.basename(quantizer_file),
            deduplicate=deduplicate,
            add_bos_eos=add_bos_eos,
            need_f0=need_f0,
        )
        return cls(dense, quant, deduplicate, add_bos_eos, need_f0, f0_normalizer, f0_quantizer, config=cfg)
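
    # Hedged sketch of direct construction (file names are placeholders):
    #
    #     enc = SpeechEncoder.by_name(
    #         dense_model_name="hubert-base-ls960",
    #         quantizer_model_name="kmeans",
    #         vocab_size=100,
    #         deduplicate=True,
    #         hubert_backend="transformers",
    #         quantizer_file="km100.npy",
    #     )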

    @property
    def device(self) -> torch.device:
        return self._float_tensor.device

    @property
    def vocab_size(self) -> int:
        return self.quantizer_model.vocab_size

    @property
    def code_hop_size(self) -> int:
        return getattr(self.dense_model, "code_hop_size", 320)

    @property
    def expected_sample_rate(self) -> int:
        return getattr(self.dense_model, "expected_sample_rate", 16000)

    @property
    def f0_code_ratio(self) -> float:
        # Number of F0 frames per unit (code) frame.
        return self.code_hop_size / self.expected_sample_rate / F0_FRAME_SPACE
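
    # Worked example with the defaults: 320 / 16000 = 0.02 s per unit frame,
    # and 0.02 / 0.01 = 2.0 F0 frames per unit frame.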

    def maybe_resample(self, waveform: torch.Tensor, input_sample_rate: int) -> torch.Tensor:
        if int(input_sample_rate) == int(self.expected_sample_rate):
            return waveform
        return torchaudio.functional.resample(
            waveform, int(input_sample_rate), int(self.expected_sample_rate)
        )
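
    # E.g. a 44.1 kHz clip is resampled down to the encoder's 16 kHz
    # (`enc` and `wav44` are assumed to exist):
    #
    #     wav16 = enc.maybe_resample(wav44, 44100)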

    @torch.no_grad()
    def forward(self, waveform: torch.Tensor, speaker: Optional[str] = None) -> Dict[str, torch.Tensor]:
        """
        Returns:
            {
                "units": LongTensor [U],
                "durations": LongTensor [U],  (frame counts)
                "dense": FloatTensor [T, D],
                (optional) "f0": FloatTensor [U] or [T_f0] if implemented
            }
        """
        # `speaker` is accepted for interface parity with textlesslib; unused here.

        # 1) Frame-level dense features from the HuBERT backbone: (T, D).
        dense_features = self.dense_model(waveform)

        # 2) Optional feature normalization before quantization.
        if self._feature_norm == "unit":
            eps = 1e-6
            dense_features = dense_features / (dense_features.norm(dim=-1, keepdim=True) + eps)
        elif self._feature_norm == "layernorm":
            mean = dense_features.mean(dim=-1, keepdim=True)
            std = dense_features.std(dim=-1, keepdim=True).clamp_min(1e-5)
            dense_features = (dense_features - mean) / std

        # 3) Nearest-center assignment: (T,) unit ids.
        ids_per_frame = self.quantizer_model(dense_features)

        # 4) Optionally collapse repeats into (unit, duration) pairs,
        #    e.g. [5, 5, 5, 7, 7] -> units [5, 7], durations [3, 2].
        if self.deduplicate:
            units, durations = torch.unique_consecutive(ids_per_frame, return_counts=True)
        else:
            units = ids_per_frame
            durations = torch.ones_like(units, dtype=torch.long)

        f0 = None
        if self.need_f0:
            raise NotImplementedError(
                "F0 extraction is not included in this minimal HF port. "
                "Set need_f0=False (as in the reference pipeline)."
            )

        if self.add_bos_eos:
            units, durations, f0, dense_features = wrap_bos_eos(
                units, durations, f0, dense_features, self.bos, self.eos
            )

        item = {
            "units": units.to(self.device),
            "durations": durations.to(self.device),
            "dense": dense_features.to(self.device),
        }
        if f0 is not None:
            item["f0"] = f0.to(self.device)
        return item
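
# End-to-end hedged sketch, reusing an `enc` built via `from_pretrained` or
# `by_name` above ("sample.wav" is a placeholder path):
#
#     wav, sr = torchaudio.load("sample.wav")   # wav: (channels, S)
#     out = enc(enc.maybe_resample(wav, sr))
#     # With deduplicate=True and add_bos_eos=False, out["durations"].sum()
#     # equals the number of dense frames T.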