import time

import gradio as gr
import numpy as np
import soundfile as sf
import torch
import torchaudio
from transformers import AutoFeatureExtractor, HiggsAudioV2TokenizerModel

# ── Model loading ──────────────────────────────────────────────────────────────
REPO_ID = "eustlb/higgs-audio-v2-tokenizer"
NUM_CODEBOOKS = 8  # OmniVoice uses exactly 8 RVQ codebooks

model = None
feature_extractor = None


def load_model():
    """Lazily load and cache the tokenizer model and its feature extractor.

    The first call loads both from ``REPO_ID`` into the module-level ``model`` /
    ``feature_extractor`` globals and puts the model in eval mode; later calls
    reuse the cached objects.

    Returns:
        ``(model, feature_extractor)``.
    """
    global model, feature_extractor
    if model is None:
        feature_extractor = AutoFeatureExtractor.from_pretrained(REPO_ID)
        model = HiggsAudioV2TokenizerModel.from_pretrained(REPO_ID, device_map="auto")
        model.eval()
    return model, feature_extractor


# ── Core inference ─────────────────────────────────────────────────────────────
def _sync(device: torch.device) -> None:
    """Wait for queued CUDA work on *device* so wall-clock timings are meaningful."""
    # CUDA kernels launch asynchronously; without a sync, perf_counter() would
    # measure only the launch overhead, not the actual encode/decode work.
    if device.type == "cuda":
        torch.cuda.synchronize(device)


def _to_codebook_major(codes: torch.Tensor) -> tuple[torch.Tensor, bool]:
    """Normalise raw ``audio_codes`` to a ``(N_q, T_tok)`` tensor.

    Accepts a batched 3-D tensor (only the first batch item is used) or an
    unbatched 2-D tensor, stored either codebook-major or frame-major.

    Returns:
        ``(codes, codebooks_last)``: ``codes`` is ``(N_q, T_tok)``;
        ``codebooks_last`` is True when the input kept codebooks on its last axis.

    Raises:
        ValueError: if ``codes`` is neither 2-D nor 3-D.
    """
    if codes.dim() == 3:
        codes = codes[0]  # drop batch dim
    elif codes.dim() != 2:
        raise ValueError(f"Unexpected audio_codes shape: {tuple(codes.shape)}")

    rows, cols = codes.shape
    # Identify the codebook axis by its expected size first. Comparing the two
    # sizes alone mis-orients clips that have fewer frames than codebooks
    # (very short audio). Ties resolve to codebook-major, as the size rule did.
    # NOTE(review): assumes the model emits NUM_CODEBOOKS quantizers.
    if rows == NUM_CODEBOOKS:
        codebooks_last = False
    elif cols == NUM_CODEBOOKS:
        codebooks_last = True
    else:
        codebooks_last = rows > cols  # fallback: more frames than codebooks
    return (codes.T if codebooks_last else codes), codebooks_last


def _format_tokens(codes: torch.Tensor) -> str:
    """Render ``(N_q, T_tok)`` codes as one ``RVQ-XX: [...]`` line per codebook."""
    lines = []
    for q_idx, row in enumerate(codes.cpu().tolist()):
        row_str = " ".join(f"{int(v):4d}" for v in row)
        lines.append(f"RVQ-{q_idx+1:02d}: [{row_str}]")
    return "\n".join(lines)


def run_tokenizer(audio_path: str, num_rvq_layers: int):
    """Encode an audio file to RVQ tokens and decode it back with a codebook subset.

    Args:
        audio_path: Path of the uploaded audio file, or None if nothing was uploaded.
        num_rvq_layers: How many leading RVQ codebooks to keep for decoding;
            cast to int and clamped to ``[1, N_q]``.

    Returns:
        ``((sample_rate, waveform), token_text, stats_markdown)`` for the three
        Gradio outputs, or ``(None, "", "")`` when ``audio_path`` is None.

    Raises:
        gr.Error: if the audio has no samples or encodes to no token frames.
    """
    if audio_path is None:
        return None, "", ""

    m, fx = load_model()
    device = next(m.parameters()).device

    # Load & resample to the feature extractor's sampling rate
    data, sr = sf.read(audio_path, always_2d=True)  # (T, C)
    if data.shape[0] == 0:
        raise gr.Error("The uploaded audio contains no samples.")
    wav = torch.from_numpy(data.T).float()  # (C, T)
    target_sr = fx.sampling_rate  # 24000
    if sr != target_sr:
        wav = torchaudio.functional.resample(wav, sr, target_sr)

    # Mix to mono for feature extractor
    wav_mono = wav.mean(0).numpy()  # (T,)
    duration_s = len(wav_mono) / target_sr

    # Feature extraction
    inputs = fx(
        raw_audio=wav_mono,
        sampling_rate=target_sr,
        return_tensors="pt",
    )
    input_values = inputs["input_values"].to(device)

    # ── Encode ────────────────────────────────────────────────────────────────
    _sync(device)
    t0 = time.perf_counter()
    with torch.no_grad():
        enc = m.encode(input_values)
    _sync(device)
    encode_time = time.perf_counter() - t0

    raw_codes = enc.audio_codes
    c, codebooks_last = _to_codebook_major(raw_codes)  # (N_q, T_tok)
    Q, T_tok = c.shape
    if Q == 0 or T_tok == 0:
        raise gr.Error("The tokenizer produced no tokens for this audio (it may be too short).")

    # Slice to requested number of RVQ layers (OmniVoice uses up to 8).
    # int() guards against the slider delivering a float.
    layers_used = max(1, min(int(num_rvq_layers), Q))
    c_sliced = c[:layers_used]  # (layers_used, T_tok)

    # ── Rebuild codes tensor for decode in original batch shape ───────────────
    # A batched frame-major encode output (B, T, N_q) goes back as
    # (1, T_tok, layers_used); everything else as (1, layers_used, T_tok).
    if raw_codes.dim() == 3 and codebooks_last:
        codes_for_decode = c_sliced.T.unsqueeze(0)
    else:
        codes_for_decode = c_sliced.unsqueeze(0)

    # ── Decode ────────────────────────────────────────────────────────────────
    _sync(device)
    t1 = time.perf_counter()
    with torch.no_grad():
        dec = m.decode(codes_for_decode)
    _sync(device)
    decode_time = time.perf_counter() - t1

    out_wav = dec.audio_values.squeeze(0).cpu()  # (C, T) or (T,)
    if out_wav.dim() == 1:
        out_wav = out_wav.unsqueeze(0)  # → (1, T)

    # ── Token table ───────────────────────────────────────────────────────────
    token_text = _format_tokens(c_sliced)

    # ── Stats ─────────────────────────────────────────────────────────────────
    frame_rate = T_tok / duration_s
    total_tokens = layers_used * T_tok
    enc_speed = T_tok / encode_time
    dec_speed = T_tok / decode_time
    compression = len(wav_mono) / T_tok
    bitrate_kbps = (layers_used * frame_rate * 10) / 1000  # 10 bits/token (1024-entry codebook)

    stats_md = f"""| Metric | Value |
|---|---|
| Audio duration | {duration_s:.3f} s |
| Sample rate | {target_sr:,} Hz |
| RVQ codebooks used | {layers_used} / {Q} |
| Codebook size | 1024 (10 bit) |
| Frame rate | {frame_rate:.2f} tok/s |
| Token frames (T) | {T_tok:,} |
| Total tokens (N_q × T) | {total_tokens:,} |
| Bitrate | {bitrate_kbps:.2f} kbps |
| Encode speed | {enc_speed:,.1f} frames/s ({duration_s/encode_time:.1f}× real-time) |
| Decode speed | {dec_speed:,.1f} frames/s ({duration_s/decode_time:.1f}× real-time) |
| Compression ratio | {compression:.1f}× |
| Encode time | {encode_time*1000:.1f} ms |
| Decode time | {decode_time*1000:.1f} ms |
| Device | {str(device).upper()} |"""

    # Return as numpy for Gradio: mono → (T,), multi-channel → (T, C)
    out_np = out_wav.numpy()
    out_np = out_np[0] if out_np.shape[0] == 1 else out_np.T

    return (target_sr, out_np), token_text, stats_md


# ── UI ─────────────────────────────────────────────────────────────────────────
CSS = """
#token-box textarea {
    font-family: "JetBrains Mono", "Fira Code", monospace;
    font-size: 12px;
    white-space: pre;
}
.title { text-align: center; }
"""

with gr.Blocks(title="Higgs Audio V2 Tokenizer (OmniVoice)") as demo:
    gr.Markdown(
        """# 🎙️ Higgs Audio V2 Tokenizer — as used in OmniVoice

**[`eustlb/higgs-audio-v2-tokenizer`](https://huggingface.co/eustlb/higgs-audio-v2-tokenizer)** · 8-codebook RVQ · 25 Hz frame rate · 24 kHz · Speech + Music + Sound Events · Used by [`k2-fsa/OmniVoice`](https://github.com/k2-fsa/OmniVoice) for 600+ language TTS""",
        elem_classes="title",
    )

    with gr.Row():
        with gr.Column():
            audio_in = gr.Audio(
                label="📂 Upload Audio",
                type="filepath",
            )
            rvq_slider = gr.Slider(
                minimum=1,
                maximum=NUM_CODEBOOKS,
                value=NUM_CODEBOOKS,
                step=1,
                label=f"RVQ codebooks used for decode (fewer = lower quality / lower bitrate, OmniVoice uses all {NUM_CODEBOOKS})",
            )
            run_btn = gr.Button("▶ Tokenize & Reconstruct", variant="primary")

        with gr.Column():
            audio_out = gr.Audio(
                label="🔊 Reconstructed Audio",
                type="numpy",
            )

    gr.Markdown("### 🔢 Tokens")
    token_box = gr.Textbox(
        label="Per-RVQ-codebook token indices (codebook size: 1024)",
        lines=10,
        interactive=False,
        elem_id="token-box",
    )

    gr.Markdown("### 📊 Stats")
    stats_box = gr.Markdown()

    run_btn.click(
        fn=run_tokenizer,
        inputs=[audio_in, rvq_slider],
        outputs=[audio_out, token_box, stats_box],
    )

if __name__ == "__main__":
    # NOTE(review): `css` is passed to launch() here (Gradio 6-style API); on
    # older Gradio releases it belongs on gr.Blocks(...) instead — confirm the
    # pinned Gradio version.
    demo.launch(css=CSS)