# moss-tok / app.py
# Humair332's picture
# Update app.py
# ca20f23 verified
import time
import torch
import soundfile as sf
import torchaudio
import gradio as gr
import numpy as np
from transformers import HiggsAudioV2TokenizerModel, AutoFeatureExtractor
# ── Model loading ──────────────────────────────────────────────────────────────
# Hugging Face Hub repository id passed to both from_pretrained() calls in load_model().
REPO_ID = "eustlb/higgs-audio-v2-tokenizer"
# Upper bound and default of the "RVQ codebooks" slider in the UI.
NUM_CODEBOOKS = 8 # OmniVoice uses exactly 8 RVQ codebooks
# Lazily-initialised singletons: both stay None until load_model() fills them in.
model = None
feature_extractor = None
def load_model():
    """Return the shared ``(model, feature_extractor)`` pair, loading it on first use.

    The first call downloads/loads the feature extractor and the tokenizer
    model from ``REPO_ID``, switches the model to eval mode and caches both in
    the module-level globals.  Later calls return the cached objects.
    """
    global model, feature_extractor

    # Fast path: everything was already loaded by an earlier call.
    if model is not None:
        return model, feature_extractor

    feature_extractor = AutoFeatureExtractor.from_pretrained(REPO_ID)
    model = HiggsAudioV2TokenizerModel.from_pretrained(REPO_ID, device_map="auto")
    model.eval()
    return model, feature_extractor
# ── Core inference ─────────────────────────────────────────────────────────────
def _sync_device(device):
    """Wait for queued CUDA work on *device* to finish; no-op on non-CUDA devices."""
    if device.type == "cuda":
        torch.cuda.synchronize(device)


def _format_token_table(codes_qt):
    """Render a (codebooks, frames) tensor as one ``RVQ-NN: [...]`` line per codebook."""
    table = []
    for q_idx, row in enumerate(codes_qt.cpu().tolist(), start=1):
        row_str = " ".join(f"{int(v):4d}" for v in row)
        table.append(f"RVQ-{q_idx:02d}: [{row_str}]")
    return "\n".join(table)


def run_tokenizer(audio_path: str, num_rvq_layers: int):
    """Encode an audio file to RVQ tokens, decode it back, and report statistics.

    Args:
        audio_path: Path of the audio file to process, or ``None`` when nothing
            was uploaded.
        num_rvq_layers: How many of the leading RVQ codebooks to keep for the
            reconstruction.  Coerced to ``int`` and clamped to
            ``[1, available codebooks]``.

    Returns:
        A 3-tuple ``(audio, token_text, stats_md)``:

        * ``audio`` - ``(sample_rate, samples)`` for the reconstructed signal,
          or ``None`` when there is nothing to play;
        * ``token_text`` - one text line of token indices per codebook used;
        * ``stats_md`` - a Markdown table of metrics, or a short error message
          when the input could not be processed.

    Raises:
        ValueError: If the encoder returns codes that are neither 2-D nor 3-D.
    """
    if audio_path is None:
        return None, "", ""

    m, fx = load_model()
    device = next(m.parameters()).device

    # Load & resample to the feature extractor's sampling rate
    data, sr = sf.read(audio_path, always_2d=True)  # (T, C)
    if data.shape[0] == 0:
        # Zero samples would otherwise surface later as a ZeroDivisionError.
        return None, "", "**Error:** the uploaded audio contains no samples."
    wav = torch.from_numpy(data.T).float()  # (C, T)
    target_sr = fx.sampling_rate
    if sr != target_sr:
        wav = torchaudio.functional.resample(wav, sr, target_sr)

    # Mix to mono for feature extractor
    wav_mono = wav.mean(0).numpy()  # (T,)
    duration_s = len(wav_mono) / target_sr

    # Feature extraction
    inputs = fx(
        raw_audio=wav_mono,
        sampling_rate=target_sr,
        return_tensors="pt",
    )
    input_values = inputs["input_values"].to(device)

    # ── Encode ────────────────────────────────────────────────────────────────
    # CUDA kernels run asynchronously: synchronise before reading the clock so
    # the measured time is the encode itself, not just the kernel launches.
    _sync_device(device)
    t0 = time.perf_counter()
    with torch.no_grad():
        enc = m.encode(input_values)
    _sync_device(device)
    encode_time = time.perf_counter() - t0

    # audio_codes: normalise to (N_q, T_tok) regardless of batch/dim ordering
    codes = enc.audio_codes
    time_major = False  # True when the raw layout looked like (B, T, N_q)
    if codes.dim() == 3:
        c = codes[0]  # drop batch dim → (T_tok, N_q) or (N_q, T_tok)
        # NOTE(review): the layout is guessed by assuming there are more frames
        # than codebooks; a clip with fewer frames than codebooks would be
        # mis-detected.  Confirm the model's documented output layout.
        time_major = c.shape[0] > c.shape[1]
        if time_major:
            c = c.T  # → (N_q, T_tok)
    elif codes.dim() == 2:
        # Treated as (B, T_tok) with a single codebook — TODO confirm.
        c = codes[0].unsqueeze(0)  # → (1, T_tok)
    else:
        raise ValueError(f"Unexpected audio_codes shape: {tuple(codes.shape)}")
    Q, T_tok = c.shape
    if T_tok == 0:
        # No frames means nothing to decode and undefined per-frame statistics.
        return None, "", "**Error:** the audio is too short to produce any tokens."

    # Slice to requested number of RVQ layers.  The value may arrive as a float
    # from the UI, so coerce it before using it as a slice bound.
    layers_used = max(1, min(int(num_rvq_layers), Q))
    c_sliced = c[:layers_used]  # (layers_used, T_tok)

    # ── Rebuild codes tensor for decode in original batch shape ───────────────
    if time_major:
        codes_for_decode = c_sliced.T.unsqueeze(0)  # (1, T_tok, layers_used)
    else:
        codes_for_decode = c_sliced.unsqueeze(0)  # (1, layers_used, T_tok)

    # ── Decode ────────────────────────────────────────────────────────────────
    _sync_device(device)
    t1 = time.perf_counter()
    with torch.no_grad():
        dec = m.decode(codes_for_decode)
    _sync_device(device)
    decode_time = time.perf_counter() - t1

    # .float() keeps the later .numpy() call valid if the model emits a
    # reduced-precision dtype; it is a no-op for float32 output.
    out_wav = dec.audio_values.squeeze(0).float().cpu()  # (C, T) or (T,)
    if out_wav.dim() == 1:
        out_wav = out_wav.unsqueeze(0)  # → (1, T)

    # ── Token table ───────────────────────────────────────────────────────────
    token_text = _format_token_table(c_sliced)

    # ── Stats ─────────────────────────────────────────────────────────────────
    frame_rate = T_tok / duration_s
    total_tokens = layers_used * T_tok
    enc_speed = T_tok / encode_time
    dec_speed = T_tok / decode_time
    compression = len(wav_mono) / T_tok
    bitrate_kbps = (layers_used * frame_rate * 10) / 1000  # 10 bits/token (1024-entry codebook)
    stats_md = f"""| Metric | Value |
|---|---|
| Audio duration | {duration_s:.3f} s |
| Sample rate | {target_sr:,} Hz |
| RVQ codebooks used | {layers_used} / {Q} |
| Codebook size | 1024 (10 bit) |
| Frame rate | {frame_rate:.2f} tok/s |
| Token frames (T) | {T_tok:,} |
| Total tokens (N_q × T) | {total_tokens:,} |
| Bitrate | {bitrate_kbps:.2f} kbps |
| Encode speed | {enc_speed:,.1f} frames/s ({duration_s/encode_time:.1f}× real-time) |
| Decode speed | {dec_speed:,.1f} frames/s ({duration_s/decode_time:.1f}× real-time) |
| Compression ratio | {compression:.1f}× |
| Encode time | {encode_time*1000:.1f} ms |
| Decode time | {decode_time*1000:.1f} ms |
| Device | {str(device).upper()} |"""

    # Return as numpy for Gradio
    out_np = out_wav.numpy()
    if out_np.shape[0] == 1:
        out_np = out_np[0]  # mono
    else:
        out_np = out_np.T  # multi-channel → (T, C)
    return (target_sr, out_np), token_text, stats_md
# ── UI ─────────────────────────────────────────────────────────────────────────
# Custom stylesheet passed to demo.launch(): gives the textarea inside the
# element with id "token-box" a small monospace font with whitespace preserved,
# and centres the text of elements carrying the "title" class.
CSS = """
#token-box textarea {
font-family: "JetBrains Mono", "Fira Code", monospace;
font-size: 12px;
white-space: pre;
}
.title { text-align: center; }
"""
# Build the Gradio interface: inputs on the left, reconstructed audio on the
# right, then the token dump and the statistics table underneath.
with gr.Blocks(title="Higgs Audio V2 Tokenizer (OmniVoice)") as demo:
    # Header (text is kept flush-left so Markdown does not treat it as a code block).
    gr.Markdown(
        """# 🎙️ Higgs Audio V2 Tokenizer — as used in OmniVoice
**[`eustlb/higgs-audio-v2-tokenizer`](https://huggingface.co/eustlb/higgs-audio-v2-tokenizer)**
· 8-codebook RVQ · 25 Hz frame rate · 24 kHz · Speech + Music + Sound Events
· Used by [`k2-fsa/OmniVoice`](https://github.com/k2-fsa/OmniVoice) for 600+ language TTS""",
        elem_classes="title",
    )
    with gr.Row():
        with gr.Column():
            # type="filepath" hands run_tokenizer a path on disk (or None).
            audio_in = gr.Audio(
                label="📂 Upload Audio",
                type="filepath",
            )
            rvq_slider = gr.Slider(
                minimum=1,
                maximum=NUM_CODEBOOKS,
                value=NUM_CODEBOOKS,
                step=1,
                label=f"RVQ codebooks used for decode (fewer = lower quality / lower bitrate, OmniVoice uses all {NUM_CODEBOOKS})",
            )
            run_btn = gr.Button("▶ Tokenize & Reconstruct", variant="primary")
        with gr.Column():
            # type="numpy" matches the (sample_rate, ndarray) tuple run_tokenizer returns.
            audio_out = gr.Audio(
                label="🔊 Reconstructed Audio",
                type="numpy",
            )
    gr.Markdown("### 🔢 Tokens")
    token_box = gr.Textbox(
        label="Per-RVQ-codebook token indices (codebook size: 1024)",
        lines=10,
        interactive=False,
        elem_id="token-box",
    )
    gr.Markdown("### 📊 Stats")
    stats_box = gr.Markdown()
    # Wire the button: (audio path, slider value) → (audio, token text, stats table).
    run_btn.click(
        fn=run_tokenizer,
        inputs=[audio_in, rvq_slider],
        outputs=[audio_out, token_box, stats_box],
    )
# Start the Gradio app only when this file is executed as a script; the custom
# stylesheet is supplied at launch time.
if __name__ == "__main__":
    demo.launch(css=CSS)