| """ |
| SEMACS Speech Tokenizer — encode audio → VQ codes. |
| |
| Encodes raw waveforms into discrete speech tokens using: |
| 1. CodecEncoder_Transformer — acoustic feature extraction |
| 2. WhisperVQEncoder (frozen) — semantic feature extraction |
| 3. SemanticEncoder — semantic projection |
| 4. ResidualFSQ quantizer — vector quantization → integer codes |
| |
| Output codes are at 12.5 Hz (one code every 80 ms at 16 kHz). |
| |
| Usage |
| ----- |
| from speech_tokenizer import SpeechTokenizer, SemacsConfig |
| from safetensors.torch import load_file |
| |
| config = SemacsConfig.from_json_file("config/config.json") |
| model = SpeechTokenizer(config) |
| model.load_state_dict(load_file("pretrained/model.safetensors"), strict=False) |
| model.eval() |
| |
| # From file path |
| codes = model.encode("audio.wav") # (1, 8, T_codes) |
| |
| # From a waveform batch (B, T) at 16 kHz plus its Whisper feature dict |
| codes = model(wavs, feats)["vq_code"] # (B, 8, T_codes) |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torchaudio |
| from transformers import PreTrainedModel, WhisperFeatureExtractor |
|
|
| from .config import SemacsConfig |
| from .vq.codec_encoder import CodecEncoder_Transformer |
| from .vq.codec_decoder_vocos import CodecDecoderVocos |
| from .vq.module import SemanticEncoder |
| from .vq.whisper_encoder.modeling_whisper import WhisperVQEncoder |
|
|
|
|
class SpeechTokenizer(PreTrainedModel):
    """
    Encode-only SEMACS model. Produces discrete VQ codes from audio.

    The encode path (see ``_encode_core``) runs the waveform through the
    codec encoder and a strided convolution, runs the Whisper features
    through the frozen semantic encoder and the semantic projection,
    concatenates both streams along the channel axis, projects them with
    ``fc_prior`` and finally asks ``generator`` (called with ``vq=True``)
    for the codes.

    Removed vs. full SEMACS:
        - SemanticDecoder_module  (reconstruction only)
        - fc_post_a               (audio decoder input)
        - vq_upsample             (upsamples for decoder, not needed here)
        - decode() method

    NOTE: the submodule attribute names (``semantic_model``,
    ``SemanticEncoder_module``, ``CodecEnc``, ``conv_downsampler``,
    ``fc_prior``, ``generator``) are state-dict key prefixes; renaming any
    of them would break checkpoint loading.
    """

    config_class = SemacsConfig

    # Hub repository that provides both the Whisper VQ encoder weights and
    # the matching feature-extractor configuration.
    _SEMANTIC_REPO = "zai-org/glm-4-voice-tokenizer"
    # Sample rate (Hz) every waveform is expected to be at / resampled to.
    _SAMPLE_RATE = 16000
    # Zero samples added to each side of the waveform before Whisper
    # feature extraction.
    _FEAT_EDGE_PAD = 640

    def __init__(self, config: SemacsConfig):
        """
        Build all submodules needed for encoding.

        Args:
            config: model configuration; ``config.semantic_hidden_size`` and
                the ``codec_encoder`` / ``codec_decoder`` entries of
                ``config.model`` are read here.
        """
        super().__init__(config)

        # Semantic branch: pretrained encoder, put in eval mode and frozen.
        self.semantic_model = WhisperVQEncoder.from_pretrained(self._SEMANTIC_REPO)
        self.semantic_model.eval()
        self.semantic_model._freeze_parameters()
        self.semantic_model.requires_grad_(False)

        # Feature extractor that turns raw audio into the semantic model's inputs.
        self.semantic_processor = WhisperFeatureExtractor.from_pretrained(
            self._SEMANTIC_REPO
        )

        hidden_size = config.semantic_hidden_size
        self.SemanticEncoder_module = SemanticEncoder(
            hidden_size, hidden_size, hidden_size
        )

        # Acoustic branch.
        enccfg = config.model["codec_encoder"]
        self.CodecEnc = CodecEncoder_Transformer(
            ngf=enccfg["ngf"],
            up_ratios=enccfg["up_ratios"],
            dilations=enccfg["dilations"],
            hidden_dim=enccfg["hidden_dim"],
            output_dim=enccfg["output_dim"],
            depth=enccfg["depth"],
            heads=enccfg["heads"],
            pos_meb_dim=enccfg["pos_meb_dim"],
        )

        # Stride-4 convolution: shortens the acoustic time axis by about 4x.
        self.conv_downsampler = nn.Sequential(
            nn.LeakyReLU(0.2),
            nn.Conv1d(
                enccfg["output_dim"],
                enccfg["output_dim"],
                kernel_size=8,
                stride=4,
                padding=2,
            ),
        )

        # Projection applied to the channel-wise concatenation of the
        # semantic and acoustic streams.
        concat_dim = hidden_size + enccfg["output_dim"]
        self.fc_prior = nn.Linear(concat_dim, 2048)

        # Only used through ``self.generator(x, vq=True)`` in this class,
        # i.e. to obtain the codes.
        deccfg = config.model["codec_decoder"]
        self.generator = CodecDecoderVocos(
            hidden_dim=deccfg["hidden_dim"],
            depth=deccfg["depth"],
            heads=deccfg["heads"],
            pos_meb_dim=deccfg["pos_meb_dim"],
            hop_length=deccfg.get("hop_length", 320),
            vq_num_quantizers=deccfg["vq_num_quantizers"],
            vq_dim=deccfg["vq_dim"],
            vq_commit_weight=deccfg["vq_commit_weight"],
            vq_weight_init=deccfg["vq_weight_init"],
            vq_full_commit_loss=deccfg["vq_full_commit_loss"],
            codebook_size=deccfg["codebook_size"],
            codebook_dim=deccfg["codebook_dim"],
        )

    def _build_whisper_feat(self, wav_np) -> dict:
        """
        Build the Whisper feature dict for a single waveform.

        Args:
            wav_np: 1-D numpy array holding the waveform at 16 kHz.

        Returns:
            The feature extractor's output as a plain dict, with every
            tensor value moved to ``self.device``.
        """
        import torch.nn.functional as F

        # Number of input samples covered by one semantic-model output frame:
        # both conv strides, the pooling kernel (1 when unset) and the
        # feature extractor's hop length.
        pooling_kernel_size = self.semantic_model.config.pooling_kernel_size or 1
        stride = (
            self.semantic_model.conv1.stride[0]
            * self.semantic_model.conv2.stride[0]
            * pooling_kernel_size
            * self.semantic_processor.hop_length
        )

        pad = self._FEAT_EDGE_PAD
        wav_pad = F.pad(torch.from_numpy(wav_np), (pad, pad)).numpy()

        feat = self.semantic_processor(
            raw_speech=wav_pad,
            sampling_rate=self._SAMPLE_RATE,
            return_tensors="pt",
            return_attention_mask=True,
            padding="longest",
            # Keep the padded length a whole number of output frames.
            pad_to_multiple_of=stride,
        )
        return {
            key: value.to(self.device) if isinstance(value, torch.Tensor) else value
            for key, value in feat.items()
        }

    @torch.no_grad()
    def _encode_core(self, wav_tensor: torch.Tensor, feat: dict) -> torch.Tensor:
        """
        Shared encode logic.

        The whole method runs without gradient tracking: its only output is
        the discrete code tensor, so recording an autograd graph for any of
        the intermediate steps would only cost memory.

        Args:
            wav_tensor: (B, T) float32 waveform at 16 kHz
            feat:       Whisper feature dict (input_features, attention_mask)

        Returns:
            vq_code: (B, num_fsq_levels, T_codes)
        """
        # Acoustic stream: add a channel axis, encode, move channels to
        # dim 1, then downsample along time.
        acoustic = self.CodecEnc(wav_tensor.unsqueeze(1)).transpose(1, 2)
        acoustic = self.conv_downsampler(acoustic)

        # Semantic stream: frozen encoder followed by the semantic projection.
        semantic = self.semantic_model(**feat).last_hidden_state.transpose(1, 2)
        semantic = self.SemanticEncoder_module(semantic)

        # The two streams may disagree on the frame count; trim both to the
        # shorter one (a no-op when they already match).
        num_frames = min(acoustic.shape[-1], semantic.shape[-1])
        acoustic = acoustic[:, :, :num_frames]
        semantic = semantic[:, :, :num_frames]

        # Concatenate along channels (semantic first), then project.
        fused = torch.cat([semantic, acoustic], dim=1)
        fused = self.fc_prior(fused.transpose(1, 2)).transpose(1, 2)

        _, vq_code, _ = self.generator(fused, vq=True)
        return vq_code

    @torch.no_grad()
    def encode(self, file_path: str) -> torch.Tensor:
        """
        Tokenize a single audio file.

        The file is loaded with torchaudio, averaged over dim 0 to a single
        channel, and resampled to 16 kHz when its sample rate differs.

        Args:
            file_path: path to audio (any format supported by torchaudio)

        Returns:
            vq_code: (1, num_fsq_levels, T_codes)
        """
        wav, sr = torchaudio.load(file_path)
        wav = wav.mean(0)
        if sr != self._SAMPLE_RATE:
            wav = torchaudio.functional.resample(wav, sr, self._SAMPLE_RATE)

        # Features are built from the CPU waveform; the waveform itself is
        # batched and moved to the model device for the codec encoder.
        feat = self._build_whisper_feat(wav.numpy())
        return self._encode_core(wav.unsqueeze(0).to(self.device), feat)

    def forward(self, wav: torch.Tensor, feat: dict) -> dict:
        """
        Batch tokenization forward pass. Used by batch_infer.py.

        Args:
            wav:  (B, T) float32 waveform at 16 kHz
            feat: Whisper feature dict with batched input_features + attention_mask

        Returns:
            {"vq_code": (B, num_fsq_levels, T_codes)}
        """
        return {"vq_code": self._encode_core(wav, feat)}
|
|