# gslm-encoder / modeling_speech_encoder.py
# Committed by klemenk — "Update modeling_speech_encoder.py" (commit a68b041, verified)
import io
import os
import pickle
from typing import Optional, Dict, Callable
import numpy as np
import torch
import torch.nn as nn
import torchaudio
from transformers import PreTrainedModel
from .configuration_speech_encoder import SpeechEncoderConfig
def wrap_bos_eos(
units: torch.Tensor,
durations: torch.Tensor,
f0: torch.Tensor | None,
dense_features: torch.Tensor,
bos: torch.Tensor,
eos: torch.Tensor,
):
# bos/eos are 1-element tensors on the right device/dtype
one = durations.new_ones(1)
units = torch.cat([bos.to(units.device), units, eos.to(units.device)], dim=0)
durations = torch.cat([one, durations, one], dim=0)
if f0 is not None:
# pad f0 with edge values
f0 = torch.cat([f0[:1], f0, f0[-1:]], dim=0)
return units, durations, f0, dense_features
# ----------------------------
# Dense feature backends (HuBERT)
# ----------------------------
class _FairseqHubertDense(nn.Module):
    """
    Loads a fairseq HuBERT checkpoint (.pt) and exposes extract_features() at a
    given transformer layer.

    If the checkpoint's task config has ``normalize=True``, the waveform is
    layer-normalized before feature extraction (as fairseq's HuBERT
    feature-dump example does). Otherwise it is fed unchanged.
    """

    def __init__(self, ckpt_path: str, layer: int, expected_sr: int = 16000, hop: int = 320):
        """
        Args:
            ckpt_path: path to the fairseq HuBERT checkpoint.
            layer: value passed as ``output_layer`` to ``extract_features``.
            expected_sr: sample rate the model expects (stored as metadata).
            hop: samples per output frame (stored as metadata).

        Raises:
            ImportError: if fairseq cannot be imported.
        """
        super().__init__()
        try:
            from fairseq import checkpoint_utils
        except Exception as e:
            # Any import failure (missing package, broken install) is reported
            # as ImportError with an actionable hint; the cause is chained.
            raise ImportError(
                "fairseq is required to load a .pt HuBERT checkpoint. "
                "Please `pip install fairseq` in your runtime."
            ) from e
        models, _, task = checkpoint_utils.load_model_ensemble_and_task([ckpt_path])
        self.model = models[0]
        self.model.eval()
        # Frozen feature extractor: no gradients flow into HuBERT.
        for p in self.model.parameters():
            p.requires_grad_(False)
        # Models pre-trained on normalized audio must see normalized input.
        self.normalize = bool(getattr(getattr(task, "cfg", None), "normalize", False))
        self.output_layer = int(layer)
        self.expected_sample_rate = int(expected_sr)
        self.code_hop_size = int(hop)

    @torch.no_grad()
    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """
        Return the features of ``waveform`` at ``self.output_layer``.

        Inputs with more than one dimension are averaged over dim 0 (presumably
        a channel axis — confirm with callers). The result is feature row 0 of
        the model's batched output, i.e. a (T, C) tensor.
        """
        if waveform.ndim > 1:
            waveform = waveform.mean(0)
        wav = waveform
        if self.normalize:
            wav = torch.nn.functional.layer_norm(wav, wav.shape)
        # Run on the model's device so a CPU waveform works with a GPU model.
        param = next(self.model.parameters(), None)
        if param is not None:
            wav = wav.to(param.device)
        # fairseq HuBERT exposes extract_features(...); feats: (B, T, C)
        feats, _ = self.model.extract_features(wav.unsqueeze(0), output_layer=self.output_layer)
        return feats[0]  # (T, C)
class _TransformersHubertDense(nn.Module):
    """
    Uses transformers' facebook/hubert-* checkpoints.

    Features are taken from ``hidden_states[layer]`` of the backbone output:
    index 0 is the embedding output, index i the output of transformer layer i.
    """

    def __init__(self, hf_name: str, layer: int, expected_sr: int = 16000, hop: int = 320):
        """
        Args:
            hf_name: model id or path passed to ``AutoModel.from_pretrained``.
            layer: index into the backbone's ``hidden_states`` tuple.
            expected_sr: sample rate the model expects (stored as metadata).
            hop: samples per output frame (stored as metadata).
        """
        super().__init__()
        from transformers import AutoModel

        self.backbone = AutoModel.from_pretrained(hf_name)
        self.backbone.eval()
        # Frozen feature extractor: no gradients flow into HuBERT.
        for p in self.backbone.parameters():
            p.requires_grad_(False)
        self.layer = int(layer)
        self.expected_sample_rate = int(expected_sr)
        self.code_hop_size = int(hop)

    @torch.no_grad()
    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """
        Return the hidden states at ``self.layer`` for ``waveform`` as (T, C).

        Inputs with more than one dimension are averaged over dim 0 (presumably
        a channel axis — confirm with callers). Raw PCM is passed as
        ``input_values`` without any feature-extractor preprocessing.
        """
        if waveform.ndim > 1:
            waveform = waveform.mean(0)
        wav = waveform.unsqueeze(0)  # (1, T): transformers HuBERT expects (batch, time)
        # Run on the backbone's device so a CPU waveform works with a GPU model.
        param = next(self.backbone.parameters(), None)
        if param is not None:
            wav = wav.to(param.device)
        # HubertModel.forward has no `inputs_embeds` parameter (passing it raises
        # TypeError). return_dict=True guarantees the `.hidden_states` attribute.
        out = self.backbone(
            input_values=wav,
            output_hidden_states=True,
            return_dict=True,
        )
        # hidden_states is a tuple: [emb, layer1, ..., layerN], each (B, T, C)
        return out.hidden_states[self.layer][0]
# ----------------------------
# KMeans quantizer
# ----------------------------
class KMeansQuantizer(nn.Module):
    """
    Simple KMeans quantizer: nearest center assignment per frame.
    Loads centers from:
    * .pt / .pth: a Tensor; a dict with keys cluster_centers, cluster_centers_,
      centroids or centers (also searched one level deep); or an object with
      .cluster_centers_
    * .npy
    * any other file: joblib/pickle of a scikit KMeans with .cluster_centers_

    NOTE(security): pickle/joblib loading (and torch.load when it is not
    restricted to weights) can execute code embedded in the file. Only load
    centers from trusted sources.
    """

    # Keys probed, in order, when centers are stored inside a dict.
    _CENTER_KEYS = ("cluster_centers", "cluster_centers_", "centroids", "centers")

    def __init__(self, centers: torch.Tensor):
        """
        Args:
            centers: (K, D) tensor of cluster centers; stored as a float32 buffer.

        Raises:
            ValueError: if ``centers`` is not 2-D.
        """
        super().__init__()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if centers.ndim != 2:
            raise ValueError("centers must be (K, D)")
        self.register_buffer("centers", centers.float())

    @property
    def vocab_size(self) -> int:
        """Number of clusters K."""
        return int(self.centers.size(0))

    @staticmethod
    def _to_tensor(x):
        """Return ``x`` unchanged if it is a tensor, else wrap it via NumPy."""
        if torch.is_tensor(x):
            return x
        return torch.from_numpy(np.asarray(x))

    @classmethod
    def _centers_from_mapping(cls, obj: dict, key: str = ""):
        """Find centers in a dict (top level, then one level deep); None if absent."""
        for k in (key, *cls._CENTER_KEYS):
            if k and k in obj:
                return cls._to_tensor(obj[k])
        # Some dumps wrap centers deeper: {'state': {'centers': ...}}
        for v in obj.values():
            if isinstance(v, dict):
                for k in cls._CENTER_KEYS:
                    if k in v:
                        return cls._to_tensor(v[k])
        return None

    @staticmethod
    def _load_pickled(path: str):
        """Load a joblib dump, falling back to plain pickle."""
        try:
            import joblib

            return joblib.load(path)
        except Exception:
            # joblib missing or unable to read the file: try stdlib pickle.
            with open(path, "rb") as f:
                return pickle.load(f)

    @classmethod
    def from_file(cls, path: str, key: str = "") -> "KMeansQuantizer":
        """
        Build a quantizer from a centers file.

        Args:
            path: file path (str or PathLike).
            key: extra dict key tried first for .pt/.pth dicts.

        Raises:
            FileNotFoundError: if ``path`` does not exist.
            ValueError: if no centers can be found in the loaded object.
        """
        path = os.fspath(path)
        if not os.path.exists(path):
            raise FileNotFoundError(f"KMeans file not found: {path}")

        centers = None
        if path.endswith((".pt", ".pth")):
            obj = torch.load(path, map_location="cpu")
            if torch.is_tensor(obj):
                centers = obj
            elif isinstance(obj, dict):
                centers = cls._centers_from_mapping(obj, key)
            elif hasattr(obj, "cluster_centers_"):
                centers = cls._to_tensor(obj.cluster_centers_)
            # Do NOT fall through to joblib/pickle here: a torch file is not a
            # plain pickle, and the resulting UnpicklingError would hide the
            # clear ValueError below.
        elif path.endswith(".npy"):
            centers = torch.from_numpy(np.load(path))
        else:
            obj = cls._load_pickled(path)
            if hasattr(obj, "cluster_centers_"):
                centers = cls._to_tensor(obj.cluster_centers_)

        if centers is None:
            raise ValueError(
                f"Could not load KMeans centers from {path}. "
                "Supported: .pt (tensor/dict), .npy, pickled sklearn KMeans."
            )
        return cls(centers.float())

    @torch.no_grad()
    def forward(self, dense_features: torch.Tensor) -> torch.Tensor:
        """
        dense_features: (T, D) or (B, T, D) -> returns (T,) or (B, T) int64
        """
        if dense_features.ndim not in (2, 3):
            raise ValueError("dense_features must be (T,D) or (B,T,D)")
        lead_shape = dense_features.shape[:-1]
        flat = dense_features.reshape(-1, dense_features.size(-1)).to(self.centers.dtype)
        # Euclidean distance to every center; the nearest center's index is the unit id.
        dist = torch.cdist(flat, self.centers)  # (N, K)
        return torch.argmin(dist, dim=-1).to(torch.long).view(lead_shape)
# ----------------------------
# SpeechEncoder (HF-ready)
# ----------------------------
F0_FRAME_SPACE = 0.01  # seconds; kept for API completeness (we don't compute F0 here)


class SpeechEncoder(PreTrainedModel):
    """
    Hugging Face–ready port of the Textless 'SpeechEncoder'.
    * Has the same public methods as the original (by_name, maybe_resample, forward, properties).
    * Loads your uploaded HuBERT checkpoint and KMeans centers from the repo
      (a local directory or a Hub repo id).
    * `need_f0` is supported as a flag, but F0 extraction is not implemented in this minimal port.
    """

    config_class = SpeechEncoderConfig

    # Download options that from_pretrained(**kwargs) forwards to
    # transformers.utils.cached_file when assets must come from a Hub repo.
    _HUB_KWARGS = (
        "cache_dir",
        "force_download",
        "proxies",
        "token",
        "use_auth_token",
        "revision",
        "local_files_only",
    )

    def __init__(
        self,
        dense_model: nn.Module,
        quantizer_model: nn.Module,
        deduplicate: bool,
        add_bos_eos: bool = False,
        need_f0: bool = False,
        f0_normalizer: Optional[Callable] = None,
        f0_quantizer: Optional[Callable] = None,
        config: Optional[SpeechEncoderConfig] = None,
    ):
        """
        Args:
            dense_model: module mapping a waveform to (T, D) frame features.
            quantizer_model: module mapping (T, D) features to (T,) unit ids;
                must expose ``vocab_size``.
            deduplicate: collapse runs of identical units and report run lengths.
            add_bos_eos: wrap the unit sequence with BOS/EOS ids.
            need_f0: request F0 output (not implemented; ``forward`` raises).
            f0_normalizer, f0_quantizer: stored for API parity; unused here.
            config: model config; a default ``SpeechEncoderConfig`` if None.
        """
        super().__init__(config if config is not None else SpeechEncoderConfig())
        self.dense_model = dense_model
        self.quantizer_model = quantizer_model
        self.deduplicate = bool(deduplicate)
        self.add_bos_eos = bool(add_bos_eos)
        self.need_f0 = bool(need_f0)
        self.f0_normalizer = f0_normalizer
        self.f0_quantizer = f0_quantizer
        self.unit_vocab_size = int(self.quantizer_model.vocab_size)
        # Special ids default to the first two ids after the unit vocabulary.
        bos_id = getattr(self.config, "bos_id", None)
        if bos_id is None:
            bos_id = self.unit_vocab_size
        eos_id = getattr(self.config, "eos_id", None)
        if eos_id is None:
            eos_id = self.unit_vocab_size + 1
        self.register_buffer("bos", torch.tensor([bos_id], dtype=torch.long))
        self.register_buffer("eos", torch.tensor([eos_id], dtype=torch.long))
        # used only for device tracking (mimics the original)
        self.register_buffer("_float_tensor", torch.tensor([0.0], dtype=torch.float))
        # Optional feature normalization before K-Means ("unit", "layernorm" or None)
        self._feature_norm = getattr(self.config, "feature_norm", None)

    # ---------- HF convenience: override from_pretrained to pick up assets ----------
    @staticmethod
    def _resolve_asset(repo: str, filename: str, hub_kwargs: dict) -> str:
        """
        Return a local path for ``filename`` stored in ``repo``.

        If ``repo`` is a local directory, or ``filename`` is absolute, the path
        is joined directly. Otherwise ``repo`` is treated as a Hub repo id and
        the file is downloaded (or read from cache) via
        ``transformers.utils.cached_file``.
        """
        if os.path.isabs(filename) or os.path.isdir(repo):
            return os.path.join(repo, filename)
        from transformers.utils import cached_file

        return cached_file(repo, filename, **hub_kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """
        Loads config, constructs dense+quantizer from files inside the repo
        (local directory or Hub repo id), and returns a ready-to-use
        SpeechEncoder in eval mode (no weights to load into state_dict).

        ``model_args`` are accepted for signature compatibility and ignored.
        ``kwargs`` are passed to ``SpeechEncoderConfig.from_pretrained``. Hub
        options among them (cache_dir, revision, token, ...) are also used to
        download assets.
        """
        config = SpeechEncoderConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        repo_root = os.fspath(pretrained_model_name_or_path)
        hub_kwargs = {k: kwargs[k] for k in cls._HUB_KWARGS if k in kwargs}

        # Dense backend (the HuBERT .pt file is only needed for fairseq)
        if config.hubert_backend == "fairseq":
            dense = _FairseqHubertDense(
                ckpt_path=cls._resolve_asset(repo_root, config.hubert_ckpt, hub_kwargs),
                layer=config.hubert_layer,
                expected_sr=config.expected_sample_rate,
                hop=config.code_hop_size,
            )
        elif config.hubert_backend == "transformers":
            dense = _TransformersHubertDense(
                hf_name=config.hubert_hf_name,
                layer=config.hubert_layer,
                expected_sr=config.expected_sample_rate,
                hop=config.code_hop_size,
            )
        else:
            raise ValueError("hubert_backend must be 'fairseq' or 'transformers'")

        # Quantizer
        quant_path = cls._resolve_asset(repo_root, config.quantizer_file, hub_kwargs)
        quant = KMeansQuantizer.from_file(quant_path, key=config.quantizer_key)

        # Construct the encoder (HF PreTrainedModel base will still attach config)
        model = cls(
            dense_model=dense,
            quantizer_model=quant,
            deduplicate=config.deduplicate,
            add_bos_eos=config.add_bos_eos,
            need_f0=config.need_f0,
            f0_normalizer=None,
            f0_quantizer=None,
            config=config,
        )
        # Match the HF convention: models come back from from_pretrained in eval mode.
        model.eval()
        return model

    # ---------- Original "by_name" API (kept for drop-in parity) ----------
    @classmethod
    def by_name(
        cls,
        dense_model_name: str,
        quantizer_model_name: str,
        vocab_size: Optional[int],
        deduplicate: bool,
        add_bos_eos: bool = False,
        need_f0: bool = False,
        f0_normalizer: Optional[Callable] = None,
        f0_quantizer: Optional[Callable] = None,
        # HF-specific args to locate assets if you don't use .from_pretrained(...)
        hubert_backend: str = "fairseq",
        hubert_ckpt: Optional[str] = None,
        hubert_hf_name: str = "facebook/hubert-base-ls960",
        hubert_layer: int = 9,
        quantizer_file: Optional[str] = None,
        quantizer_key: str = "",
        expected_sample_rate: int = 16000,
        code_hop_size: int = 320,
    ) -> "SpeechEncoder":
        """
        Mirrors textlesslib's SpeechEncoder.by_name. For HF usage prefer:
        AutoModel.from_pretrained(repo, trust_remote_code=True)

        ``vocab_size`` may be None to skip the check against the number of centers.

        Raises:
            ValueError: on unsupported backend/quantizer, missing asset paths,
                or a ``vocab_size`` that disagrees with the loaded centers.
        """
        # dense_model_name is kept for parity; we only support HuBERT here
        if hubert_backend == "fairseq":
            if not hubert_ckpt:
                raise ValueError("Provide hubert_ckpt (path to .pt) when hubert_backend='fairseq'.")
            dense = _FairseqHubertDense(
                hubert_ckpt, layer=hubert_layer, expected_sr=expected_sample_rate, hop=code_hop_size
            )
        elif hubert_backend == "transformers":
            dense = _TransformersHubertDense(
                hubert_hf_name, layer=hubert_layer, expected_sr=expected_sample_rate, hop=code_hop_size
            )
        else:
            raise ValueError("hubert_backend must be 'fairseq' or 'transformers'")

        if quantizer_model_name.lower() != "kmeans":
            raise ValueError("Only 'kmeans' quantizer is supported in this port.")
        if not quantizer_file:
            raise ValueError("Provide quantizer_file (path to centers).")
        quant = KMeansQuantizer.from_file(quantizer_file, key=quantizer_key)

        # Sanity check on vocab size if user passed it
        if vocab_size is not None and int(vocab_size) != quant.vocab_size:
            raise ValueError(f"vocab_size={vocab_size} does not match centers K={quant.vocab_size}")

        cfg = SpeechEncoderConfig(
            hubert_backend=hubert_backend,
            hubert_ckpt=hubert_ckpt or "",
            hubert_hf_name=hubert_hf_name,
            hubert_layer=hubert_layer,
            expected_sample_rate=expected_sample_rate,
            code_hop_size=code_hop_size,
            quantizer_file=os.path.basename(quantizer_file),
            # Persist the key so from_pretrained reloads the same centers.
            quantizer_key=quantizer_key,
            deduplicate=deduplicate,
            add_bos_eos=add_bos_eos,
            need_f0=need_f0,
        )
        return cls(dense, quant, deduplicate, add_bos_eos, need_f0, f0_normalizer, f0_quantizer, config=cfg)

    # ---------- Properties (parity) ----------
    @property
    def device(self) -> torch.device:
        """Device of the module's buffers (tracked via ``_float_tensor``)."""
        return self._float_tensor.device

    @property
    def vocab_size(self) -> int:
        """Number of unit ids produced by the quantizer (excluding BOS/EOS)."""
        return self.quantizer_model.vocab_size

    @property
    def code_hop_size(self) -> int:
        """Samples per unit frame, from the dense model (default 320)."""
        return getattr(self.dense_model, "code_hop_size", 320)

    @property
    def expected_sample_rate(self) -> int:
        """Sample rate the dense model expects (default 16000)."""
        return getattr(self.dense_model, "expected_sample_rate", 16000)

    @property
    def f0_code_ratio(self) -> float:
        """F0 frames per unit frame."""
        return self.code_hop_size / self.expected_sample_rate / F0_FRAME_SPACE

    # ---------- Resampling ----------
    def maybe_resample(self, waveform: torch.Tensor, input_sample_rate: int) -> torch.Tensor:
        """Resample ``waveform`` to ``expected_sample_rate``; no-op if rates match."""
        if int(input_sample_rate) == int(self.expected_sample_rate):
            return waveform
        return torchaudio.functional.resample(
            waveform, int(input_sample_rate), int(self.expected_sample_rate)
        )

    # ---------- Forward (parity) ----------
    @torch.no_grad()
    def forward(self, waveform: torch.Tensor, speaker: Optional[str] = None) -> Dict[str, torch.Tensor]:
        """
        Encode a waveform into discrete units.

        Args:
            waveform: audio, expected at ``expected_sample_rate`` (see
                ``maybe_resample``). It is moved to this model's device first.
            speaker: unused; kept for parity with the original signature.

        Returns:
            {
              "units": LongTensor [U],
              "durations": LongTensor [U], (frame counts)
              "dense": FloatTensor [T, D],
              (optional) "f0": FloatTensor [U] or [T_f0] if implemented
            }

        Raises:
            NotImplementedError: if ``need_f0`` is set.
        """
        # F0 path is not bundled here; fail before spending time on HuBERT.
        if self.need_f0:
            raise NotImplementedError(
                "F0 extraction is not included in this minimal HF port. "
                "Set need_f0=False (as in the reference pipeline)."
            )
        f0 = None

        # 1) Dense features at HuBERT frame rate
        dense_features = self.dense_model(waveform.to(self.device))  # (T, D)

        # optional feature normalization before KMeans (kept simple)
        if self._feature_norm == "unit":
            eps = 1e-6
            dense_features = dense_features / (dense_features.norm(dim=-1, keepdim=True) + eps)
        elif self._feature_norm == "layernorm":
            mean = dense_features.mean(dim=-1, keepdim=True)
            std = dense_features.std(dim=-1, keepdim=True).clamp_min(1e-5)
            dense_features = (dense_features - mean) / std

        # 2) KMeans quantization → unit ids (per-frame)
        ids_per_frame = self.quantizer_model(dense_features)  # (T,)

        # 3) Dedup → durations
        if self.deduplicate:
            units, durations = torch.unique_consecutive(ids_per_frame, return_counts=True)
        else:
            units = ids_per_frame
            durations = torch.ones_like(units, dtype=torch.long)

        # 4) BOS/EOS wrap (if requested)
        if self.add_bos_eos:
            units, durations, f0, dense_features = wrap_bos_eos(
                units, durations, f0, dense_features, self.bos, self.eos
            )

        item = {
            "units": units.to(self.device),
            "durations": durations.to(self.device),
            "dense": dense_features.to(self.device),
        }
        if f0 is not None:
            item["f0"] = f0.to(self.device)
        return item