sunf / speech_tokenizer /tokenizer.py
anhtunguyen98's picture
Upload folder using huggingface_hub
4698bfc verified
"""
SEMACS Speech Tokenizer — encode audio → VQ codes.
Encodes raw waveforms into discrete speech tokens using:
1. CodecEncoder_Transformer — acoustic feature extraction
2. WhisperVQEncoder (frozen) — semantic feature extraction
3. SemanticEncoder — semantic projection
4. ResidualFSQ quantizer — vector quantization → integer codes
Output codes are at 12.5 Hz (one code every 80 ms at 16 kHz).
Usage
-----
from speech_tokenizer import SpeechTokenizer, SemacsConfig
from safetensors.torch import load_file
config = SemacsConfig.from_json_file("config/config.json")
model = SpeechTokenizer(config)
model.load_state_dict(load_file("pretrained/model.safetensors"), strict=False)
model.eval()
# From file path
codes = model.encode("audio.wav") # (1, 8, T_codes)
# From waveform tensor (B, T) at 16 kHz plus batched Whisper features
codes = model(wavs, feats)["vq_code"] # (B, 8, T_codes)
"""
import torch
import torch.nn as nn
import torchaudio
from transformers import PreTrainedModel, WhisperFeatureExtractor
from .config import SemacsConfig
from .vq.codec_encoder import CodecEncoder_Transformer
from .vq.codec_decoder_vocos import CodecDecoderVocos
from .vq.module import SemanticEncoder
from .vq.whisper_encoder.modeling_whisper import WhisperVQEncoder
class SpeechTokenizer(PreTrainedModel):
    """
    Encode-only SEMACS model. Produces discrete VQ codes from audio.

    Removed vs. full SEMACS:
      - SemanticDecoder_module (reconstruction only)
      - fc_post_a              (audio decoder input)
      - vq_upsample            (upsamples for decoder, not needed here)
      - decode() method

    Entry points:
      - ``encode(file_path)``: load one audio file and return its codes.
      - ``forward(wav, feat)``: tokenize a batch from waveforms plus
        pre-computed Whisper features; returns ``{"vq_code": ...}``.
    """

    config_class = SemacsConfig

    # Hub id shared by the frozen WhisperVQEncoder and its feature extractor.
    _SEMANTIC_MODEL_ID = "zai-org/glm-4-voice-tokenizer"
    # Sample rate (Hz) that audio is resampled to and that the Whisper
    # feature extractor is told to expect.
    _SAMPLE_RATE = 16000
    # Zero samples added to EACH side of the waveform before Whisper feature
    # extraction. Only the Whisper branch is padded; the codec branch sees the
    # raw waveform and `_encode_core` truncates both to a common length.
    _WHISPER_EDGE_PAD = 640

    def __init__(self, config: SemacsConfig):
        """
        Build all sub-modules.

        NOTE: construction downloads/loads the Whisper encoder and feature
        extractor through ``from_pretrained``, so it performs I/O.

        Args:
            config: SemacsConfig; reads ``semantic_hidden_size`` and the
                ``model["codec_encoder"]`` / ``model["codec_decoder"]`` dicts.
        """
        super().__init__(config)

        # ── Frozen semantic model (GLM-4-Voice Whisper) ──────────────────
        self.semantic_model = WhisperVQEncoder.from_pretrained(
            self._SEMANTIC_MODEL_ID
        )
        self.semantic_model.eval()
        self.semantic_model._freeze_parameters()
        self.semantic_model.requires_grad_(False)

        # Feature extractor used by encode(file_path)
        self.semantic_processor = WhisperFeatureExtractor.from_pretrained(
            self._SEMANTIC_MODEL_ID
        )

        hidden_size = config.semantic_hidden_size
        self.SemanticEncoder_module = SemanticEncoder(
            hidden_size, hidden_size, hidden_size
        )

        # ── Codec encoder ────────────────────────────────────────────────
        enccfg = config.model["codec_encoder"]
        self.CodecEnc = CodecEncoder_Transformer(
            ngf=enccfg["ngf"],
            up_ratios=enccfg["up_ratios"],
            dilations=enccfg["dilations"],
            hidden_dim=enccfg["hidden_dim"],
            output_dim=enccfg["output_dim"],
            depth=enccfg["depth"],
            heads=enccfg["heads"],
            pos_meb_dim=enccfg["pos_meb_dim"],
        )
        # Stride-4 convolution applied to the codec encoder output along time.
        self.conv_downsampler = nn.Sequential(
            nn.LeakyReLU(0.2),
            nn.Conv1d(
                enccfg["output_dim"],
                enccfg["output_dim"],
                kernel_size=8,
                stride=4,
                padding=2,
            ),
        )

        # ── Fusion + quantizer ───────────────────────────────────────────
        # Semantic and acoustic channels are concatenated, then projected.
        concat_dim = hidden_size + enccfg["output_dim"]
        self.fc_prior = nn.Linear(concat_dim, 2048)

        # The decoder module is instantiated because `_encode_core` calls it
        # with ``vq=True`` to obtain the integer codes.
        deccfg = config.model["codec_decoder"]
        self.generator = CodecDecoderVocos(
            hidden_dim=deccfg["hidden_dim"],
            depth=deccfg["depth"],
            heads=deccfg["heads"],
            pos_meb_dim=deccfg["pos_meb_dim"],
            hop_length=deccfg.get("hop_length", 320),
            vq_num_quantizers=deccfg["vq_num_quantizers"],
            vq_dim=deccfg["vq_dim"],
            vq_commit_weight=deccfg["vq_commit_weight"],
            vq_weight_init=deccfg["vq_weight_init"],
            vq_full_commit_loss=deccfg["vq_full_commit_loss"],
            codebook_size=deccfg["codebook_size"],
            codebook_dim=deccfg["codebook_dim"],
        )

    # ── Internal helpers ─────────────────────────────────────────────────

    def _build_whisper_feat(self, wav_np) -> dict:
        """
        Turn a waveform into the Whisper feature dict consumed by
        ``self.semantic_model``.

        Args:
            wav_np: 1-D numpy array holding a 16 kHz waveform.

        Returns:
            Dict produced by ``self.semantic_processor`` (requested with
            ``return_attention_mask=True``); every tensor value is moved to
            ``self.device``, other values are passed through unchanged.
        """
        import numpy as np

        pooling_kernel_size = self.semantic_model.config.pooling_kernel_size or 1
        # Product of the encoder's conv strides, its pooling kernel and the
        # extractor hop length; the processor pads the input to a multiple
        # of this value.
        stride = (
            self.semantic_model.conv1.stride[0]
            * self.semantic_model.conv2.stride[0]
            * pooling_kernel_size
            * self.semantic_processor.hop_length
        )

        # Zero-pad the last (time) axis only. Done directly in NumPy instead
        # of round-tripping through torch: same values and dtype, and it
        # also accepts read-only / non-contiguous arrays.
        edge = self._WHISPER_EDGE_PAD
        pad_width = [(0, 0)] * (wav_np.ndim - 1) + [(edge, edge)]
        wav_pad = np.pad(wav_np, pad_width)

        feat = self.semantic_processor(
            raw_speech=wav_pad,
            sampling_rate=self._SAMPLE_RATE,
            return_tensors="pt",
            return_attention_mask=True,
            padding="longest",
            pad_to_multiple_of=stride,
        )
        return {
            key: value.to(self.device) if isinstance(value, torch.Tensor) else value
            for key, value in feat.items()
        }

    def _encode_core(self, wav_tensor: torch.Tensor, feat: dict) -> torch.Tensor:
        """
        Shared encode logic. Runs entirely under ``torch.no_grad()``.

        Args:
            wav_tensor: (B, T) float32 waveform at 16 kHz
            feat: Whisper feature dict (input_features, attention_mask),
                unpacked as keyword arguments into ``self.semantic_model``

        Returns:
            vq_code: (B, num_fsq_levels, T_codes)
        """
        with torch.no_grad():
            # Acoustic branch: codec encoder, then strided downsampling.
            acoustic = self.CodecEnc(wav_tensor.unsqueeze(1)).transpose(1, 2)
            acoustic = self.conv_downsampler(acoustic)

            # Semantic branch: frozen Whisper encoder, then projection.
            # (No .detach() needed: nothing here tracks gradients.)
            semantic = self.semantic_model(**feat).last_hidden_state.transpose(1, 2)
            semantic = self.SemanticEncoder_module(semantic)

            # Align time dimensions: the two branches may differ in length,
            # so truncate both to the shorter one.
            n_frames = min(acoustic.shape[-1], semantic.shape[-1])
            acoustic = acoustic[:, :, :n_frames]
            semantic = semantic[:, :, :n_frames]

            # Fuse along the channel axis (semantic first), project with
            # fc_prior over the channel dimension, then quantize.
            fused = torch.cat([semantic, acoustic], dim=1)
            fused = self.fc_prior(fused.transpose(1, 2)).transpose(1, 2)
            _, vq_code, _ = self.generator(fused, vq=True)
            return vq_code  # (B, num_fsq_levels, T_codes)

    # ── Public API ───────────────────────────────────────────────────────

    @torch.no_grad()
    def encode(self, file_path: str) -> torch.Tensor:
        """
        Tokenize a single audio file.

        The file is loaded with torchaudio, averaged over channels to mono
        and resampled to 16 kHz when its native rate differs.

        Args:
            file_path: path to audio (any format supported by torchaudio)

        Returns:
            vq_code: (1, num_fsq_levels, T_codes)
        """
        wav, sr = torchaudio.load(file_path)
        wav = wav.mean(0)  # mono
        if sr != self._SAMPLE_RATE:
            wav = torchaudio.functional.resample(wav, sr, self._SAMPLE_RATE)

        feat = self._build_whisper_feat(wav.numpy())
        return self._encode_core(wav.unsqueeze(0).to(self.device), feat)

    def forward(self, wav: torch.Tensor, feat: dict) -> dict:
        """
        Batch tokenization forward pass. Used by batch_infer.py.

        Inputs are used as given; unlike ``encode`` nothing is moved to
        ``self.device`` here.

        Args:
            wav: (B, T) float32 waveform at 16 kHz
            feat: Whisper feature dict with batched input_features + attention_mask

        Returns:
            {"vq_code": (B, num_fsq_levels, T_codes)}
        """
        return {"vq_code": self._encode_core(wav, feat)}