"""
SEMACS Speech Tokenizer — encode audio → VQ codes.

Encodes raw waveforms into discrete speech tokens using:
  1. CodecEncoder_Transformer  — acoustic feature extraction
  2. WhisperVQEncoder (frozen) — semantic feature extraction
  3. SemanticEncoder           — semantic projection
  4. ResidualFSQ quantizer     — vector quantization → integer codes

Output codes are at 12.5 Hz (one code every 80 ms at 16 kHz).

Usage
-----
    from speech_tokenizer import SpeechTokenizer, SemacsConfig
    from safetensors.torch import load_file

    config = SemacsConfig.from_json_file("config/config.json")
    model = SpeechTokenizer(config)
    model.load_state_dict(load_file("pretrained/model.safetensors"), strict=False)
    model.eval()

    # From file path
    codes = model.encode("audio.wav")          # (1, 8, T_codes)

    # From a batched waveform tensor (B, T) at 16 kHz plus a batched
    # Whisper feature dict (input_features, attention_mask)
    codes = model(wavs, feats)["vq_code"]      # (B, 8, T_codes)
"""

import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from transformers import PreTrainedModel, WhisperFeatureExtractor

from .config import SemacsConfig
from .vq.codec_encoder import CodecEncoder_Transformer
from .vq.codec_decoder_vocos import CodecDecoderVocos
from .vq.module import SemanticEncoder
from .vq.whisper_encoder.modeling_whisper import WhisperVQEncoder


class SpeechTokenizer(PreTrainedModel):
    """
    Encode-only SEMACS model. Produces discrete VQ codes from audio.

    Removed vs. full SEMACS:
      - SemanticDecoder_module (reconstruction only)
      - fc_post_a (audio decoder input)
      - vq_upsample (upsamples for decoder, not needed here)
      - decode() method

    ``self.generator`` (a ``CodecDecoderVocos``) is still constructed because
    ``_encode_core`` obtains the integer codes from it via
    ``self.generator(vq_emb, vq=True)``.
    """

    config_class = SemacsConfig

    def __init__(self, config: SemacsConfig):
        super().__init__(config)

        # ── Frozen semantic model (GLM-4-Voice Whisper) ───────────────────
        # Weights are fetched from the Hugging Face hub at construction time.
        self.semantic_model = WhisperVQEncoder.from_pretrained(
            "zai-org/glm-4-voice-tokenizer"
        )
        self.semantic_model.eval()
        self.semantic_model._freeze_parameters()
        self.semantic_model.requires_grad_(False)

        # Feature extractor used by encode(file_path)
        self.semantic_processor = WhisperFeatureExtractor.from_pretrained(
            "zai-org/glm-4-voice-tokenizer"
        )

        hidden_size = config.semantic_hidden_size
        self.SemanticEncoder_module = SemanticEncoder(hidden_size, hidden_size, hidden_size)

        # ── Codec encoder ─────────────────────────────────────────────────
        enccfg = config.model["codec_encoder"]
        self.CodecEnc = CodecEncoder_Transformer(
            ngf=enccfg["ngf"],
            up_ratios=enccfg["up_ratios"],
            dilations=enccfg["dilations"],
            hidden_dim=enccfg["hidden_dim"],
            output_dim=enccfg["output_dim"],
            depth=enccfg["depth"],
            heads=enccfg["heads"],
            pos_meb_dim=enccfg["pos_meb_dim"],
        )
        # Stride-4 conv: reduces the codec feature frame rate by 4x so it can
        # be aligned with the semantic stream in _encode_core.
        self.conv_downsampler = nn.Sequential(
            nn.LeakyReLU(0.2),
            nn.Conv1d(
                enccfg["output_dim"],
                enccfg["output_dim"],
                kernel_size=8,
                stride=4,
                padding=2,
            ),
        )

        # ── Fusion + quantizer ────────────────────────────────────────────
        # Semantic and acoustic features are concatenated on the channel axis
        # and projected to 2048 channels before quantization.
        concat_dim = hidden_size + enccfg["output_dim"]
        self.fc_prior = nn.Linear(concat_dim, 2048)

        deccfg = config.model["codec_decoder"]
        self.generator = CodecDecoderVocos(
            hidden_dim=deccfg["hidden_dim"],
            depth=deccfg["depth"],
            heads=deccfg["heads"],
            pos_meb_dim=deccfg["pos_meb_dim"],
            hop_length=deccfg.get("hop_length", 320),
            vq_num_quantizers=deccfg["vq_num_quantizers"],
            vq_dim=deccfg["vq_dim"],
            vq_commit_weight=deccfg["vq_commit_weight"],
            vq_weight_init=deccfg["vq_weight_init"],
            vq_full_commit_loss=deccfg["vq_full_commit_loss"],
            codebook_size=deccfg["codebook_size"],
            codebook_dim=deccfg["codebook_dim"],
        )

    # ── Internal helpers ──────────────────────────────────────────────────────

    def _build_whisper_feat(self, wav_np):
        """
        Build Whisper input features for one waveform.

        Args:
            wav_np: 1-D numpy array of audio samples at 16 kHz.

        Returns:
            The feature extractor's output dict (``input_features``,
            ``attention_mask``) with every tensor moved to ``self.device``.

        Warns:
            UserWarning if the padded input exceeds the feature extractor's
            window. The extractor truncates such input, and ``_encode_core``
            then trims the codec stream to match. The resulting codes
            therefore cover only the start of the audio.
        """
        pooling_kernel_size = self.semantic_model.config.pooling_kernel_size or 1
        # Raw samples per semantic frame: mel hop × conv1/conv2 strides × pooling.
        # Features are padded to a multiple of this (presumably so the
        # semantic frame count comes out whole).
        stride = (
            self.semantic_model.conv1.stride[0]
            * self.semantic_model.conv2.stride[0]
            * pooling_kernel_size
            * self.semantic_processor.hop_length
        )
        # 640 samples of zero padding on each side before feature extraction.
        wav_pad = F.pad(torch.from_numpy(wav_np), (640, 640)).numpy()

        # With its default truncation=True, WhisperFeatureExtractor cuts input
        # longer than n_samples, after rounding that limit up to a multiple of
        # pad_to_multiple_of. Surface this instead of dropping audio silently.
        max_samples = getattr(self.semantic_processor, "n_samples", None)
        if max_samples:
            limit = ((max_samples + stride - 1) // stride) * stride
            if wav_pad.shape[-1] > limit:
                warnings.warn(
                    f"Input audio is {wav_pad.shape[-1] / 16000:.1f} s after padding, "
                    f"but Whisper features are truncated to {limit / 16000:.1f} s; "
                    "the returned codes will only cover the beginning of the audio.",
                    stacklevel=2,
                )

        feat = self.semantic_processor(
            raw_speech=wav_pad,
            sampling_rate=16000,
            return_tensors="pt",
            return_attention_mask=True,
            padding="longest",
            pad_to_multiple_of=stride,
        )
        return {
            k: v.to(self.device) if isinstance(v, torch.Tensor) else v
            for k, v in feat.items()
        }

    def _encode_core(self, wav_tensor: torch.Tensor, feat: dict) -> torch.Tensor:
        """
        Shared encode logic for ``encode`` and ``forward``.

        Args:
            wav_tensor: (B, T) float32 waveform at 16 kHz
            feat: Whisper feature dict (input_features, attention_mask)

        Returns:
            vq_code: (B, num_fsq_levels, T_codes)
        """
        with torch.no_grad():
            # Acoustic branch: (B, T) -> (B, 1, T) -> encoder; transpose to
            # channel-first, then 4x temporal downsampling.
            vq_emb = self.CodecEnc(wav_tensor.unsqueeze(1)).transpose(1, 2)
            vq_emb = self.conv_downsampler(vq_emb)

            # Semantic branch: frozen Whisper encoder -> channel-first -> projection.
            semantic_out = self.semantic_model(**feat)
            semantic = semantic_out.last_hidden_state.detach().transpose(1, 2)
            semantic = self.SemanticEncoder_module(semantic)

            # Align time dimensions by trimming both streams to the shorter one.
            if vq_emb.shape[-1] != semantic.shape[-1]:
                T = min(vq_emb.shape[-1], semantic.shape[-1])
                vq_emb = vq_emb[:, :, :T]
                semantic = semantic[:, :, :T]

            # Fuse on the channel axis; fc_prior acts on the last axis, hence
            # the transpose round-trip.
            vq_emb = torch.cat([semantic, vq_emb], dim=1)
            vq_emb = self.fc_prior(vq_emb.transpose(1, 2)).transpose(1, 2)

            # In quantize mode the generator returns a 3-tuple; the second
            # element holds the integer codes.
            _, vq_code, _ = self.generator(vq_emb, vq=True)
        return vq_code  # (B, num_fsq_levels, T_codes)

    # ── Public API ────────────────────────────────────────────────────────────

    @torch.no_grad()
    def encode(self, file_path: str) -> torch.Tensor:
        """
        Tokenize a single audio file.

        Multi-channel audio is averaged to mono, and any sample rate other
        than 16 kHz is resampled to 16 kHz.

        Args:
            file_path: path to audio (any format supported by torchaudio)

        Returns:
            vq_code: (1, num_fsq_levels, T_codes)
        """
        wav, sr = torchaudio.load(file_path)
        wav = wav.mean(0)  # mono
        if sr != 16000:
            wav = torchaudio.functional.resample(wav, sr, 16000)
        feat = self._build_whisper_feat(wav.numpy())
        return self._encode_core(wav.unsqueeze(0).to(self.device), feat)

    def forward(self, wav: torch.Tensor, feat: dict) -> dict:
        """
        Batch tokenization forward pass. Used by batch_infer.py.

        Args:
            wav: (B, T) float32 waveform at 16 kHz
            feat: Whisper feature dict with batched input_features + attention_mask

        Returns:
            {"vq_code": (B, num_fsq_levels, T_codes)}
        """
        vq_code = self._encode_core(wav, feat)
        return {"vq_code": vq_code}