# Chatterbox-Finnish-ONNX / scripts/browser_pipeline_sim.py
# Uploaded by RASMUS — commit 6264e00 ("Add scripts/browser_pipeline_sim.py")
"""
browser_pipeline_sim.py
Python simulation of the browser worker.ts logic.
Mirrors the EXACT token assembly, CFG loop, sampling, and decoding
that the browser worker does — making it trivial to diff against worker.ts.
If this produces correct audio, worker.ts with the same logic will too.
If there's a bug in the browser, you can find it by diffing this script against worker.ts.
Usage:
cd /workspaces/work
conda run -n chatterbox-onnx python browser_pipeline_sim.py [text] [ref_audio]
Outputs:
_cmp/browser_sim_output.wav
"""
import os, sys, time, json, struct, math
import numpy as np
import onnxruntime as ort
import soundfile as sf
import librosa
import requests
from pathlib import Path
from huggingface_hub import hf_hub_download
sys.path.insert(0, "Chatterbox-Finnish")
# ── Config — mirrors browser worker constants ─────────────────────────────────
# CLI: [text] [ref_audio]; both optional, defaults below.
TEXT = sys.argv[1] if len(sys.argv) > 1 else \
    "Tervetuloa kokeilemaan hienoviritettyΓ€ suomenkielistΓ€ Chatterbox-puhesynteesiΓ€."
REF_AUDIO = sys.argv[2] if len(sys.argv) > 2 else \
    "Chatterbox-Finnish/samples/reference_finnish.wav"

# Hugging Face repos: base multilingual ONNX export + Finnish fine-tuned LM.
HF_BASE = "onnx-community/chatterbox-multilingual-ONNX"
HF_FI = "RASMUS/Chatterbox-Finnish-ONNX"
CACHE_DIR = "_onnx_cache"  # local download cache for hf_hub_download
OUT_DIR = Path("_cmp"); OUT_DIR.mkdir(exist_ok=True)  # output dir, created eagerly

# Worker constants (from worker.ts)
SOT = 255            # [START] text token
EOT = 0              # [STOP] text token
START_SPEECH = 6561  # BOS speech token
STOP_SPEECH = 6562   # EOS speech token
CFG_WEIGHT = 0.5     # classifier-free guidance strength
REP_PENALTY = 1.2    # repetition penalty applied to already-emitted tokens
TEMPERATURE = 0.8    # sampling temperature
EXAGGERATION = 0.6   # NOTE(review): defined but never used in this script — confirm against worker.ts
MIN_SPEECH_TOKENS = 40  # EOS is ignored before this many speech tokens
MAX_STEPS = 1000        # hard cap on autoregressive decode steps
SAMPLE_RATE = 24000     # output wav sample rate (Hz)

# Optional: set GROQ_API_KEY to transcribe the generated wav with Whisper.
GROQ_KEY = os.environ.get("GROQ_API_KEY", "")
# ── puncNorm — mirrors worker.ts puncNorm() ───────────────────────────────────
def punc_norm(text: str) -> str:
    """Mirrors puncNorm() in worker.ts: trim, collapse whitespace, capitalize
    the first letter, tighten space-before-punctuation, and guarantee a
    sentence-final punctuation mark."""
    import re
    cleaned = re.sub(r'\s+', ' ', text.strip())
    if cleaned:
        cleaned = cleaned[0].upper() + cleaned[1:]
    # Remove the space preceding each punctuation mark.
    for spaced, tight in ((' .', '.'), (' ,', ','), (' ?', '?'), (' !', '!')):
        cleaned = cleaned.replace(spaced, tight)
    # Ensure the text ends with terminal punctuation.
    if cleaned and cleaned[-1] not in '.!?…':
        cleaned += '.'
    return cleaned
# ── EnTokenizer (same as Chatterbox-Finnish/src/chatterbox_/models/tokenizers/tokenizer.py)
def load_tokenizer(tokenizer_path: str):
    """Construct the repo's EnTokenizer from a tokenizer.json path."""
    from src.chatterbox_.models.tokenizers import tokenizer as _tok_mod
    return _tok_mod.EnTokenizer(tokenizer_path)
# ── puncNorm + tokenize → mirrors worker.ts tokenize() ───────────────────────
def tokenize(tok, text: str) -> list[int]:
    """Mirrors browser worker tokenize(): puncNorm the text, encode it with the
    grapheme tokenizer (which presumably handles [SPACE] substitution
    internally), then frame the ids as [SOT, ..., EOT]."""
    encoded = tok.encode(punc_norm(text))
    return [SOT, *encoded, EOT]
# ── Download helpers ───────────────────────────────────────────────────────────
def dl(repo_id, filename):
    """Fetch one file from the HF Hub into the local on-disk cache and return its path."""
    return hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        local_dir=CACHE_DIR,
        local_dir_use_symlinks=False,
    )
# ── KV cache helpers ───────────────────────────────────────────────────────────
def empty_kv(n_layers=30, kv_dtype=np.float32):
    """Build a zero-length KV cache: one (key, value) pair of shape
    [1, 16, 0, 64] per transformer layer."""
    def _fresh_layer():
        return (np.zeros((1, 16, 0, 64), dtype=kv_dtype),
                np.zeros((1, 16, 0, 64), dtype=kv_dtype))
    return [_fresh_layer() for _ in range(n_layers)]
def make_kv_feeds(kv_cache):
    """Flatten [(key, value), ...] into the ONNX feed dict keyed as
    past_key_values.{layer}.{key|value}."""
    return {
        f"past_key_values.{layer}.{slot}": tensor
        for layer, pair in enumerate(kv_cache)
        for slot, tensor in zip(("key", "value"), pair)
    }
# ── Sampling helpers ───────────────────────────────────────────────────────────
def apply_rep_penalty(logits, generated_set, penalty):
    """Return a penalized copy of `logits`: for every token already generated,
    divide a positive logit by `penalty` or multiply a negative one by it
    (standard HF-style repetition penalty). Input array is left untouched."""
    penalized = logits.copy()
    for tok in generated_set:
        current = penalized[tok]
        penalized[tok] = current / penalty if current > 0 else current * penalty
    return penalized
def apply_min_p(logits, p=0.05):
    """Mirrors applyMinP() in worker.ts: compute softmax probabilities and
    force the logit of every token whose probability is below p * max_prob
    down to -1e9 (effectively masking it). Returns a new array."""
    masked = logits.copy()
    shifted = np.exp(logits - logits.max())
    probs = shifted / shifted.sum()
    masked[probs < probs.max() * p] = -1e9
    return masked
def sample_with_temperature(logits, temperature):
    """Mirrors sampleWithTemperature() in worker.ts: softmax(logits / T) in
    float64 for numerical stability, then draw one index from the resulting
    categorical distribution."""
    scaled = (logits / temperature).astype(np.float64)
    exp = np.exp(scaled - scaled.max())
    probs = exp / exp.sum()
    return int(np.random.choice(len(probs), p=probs))
# ── Main pipeline ─────────────────────────────────────────────────────────────
def main():
    """Run the full browser-worker simulation end to end.

    Pipeline (mirroring worker.ts): load the four ONNX sessions, encode the
    reference audio into speaker conditioning, tokenize + embed the text,
    prefill conditional and unconditional LM streams, autoregressively sample
    speech tokens with CFG / repetition penalty / min-p / temperature, decode
    the tokens to a waveform, and optionally transcribe it via Groq Whisper.
    Writes _cmp/browser_sim_output.wav as its output artifact.
    """
    # CUDA first, CPU fallback; onnxruntime skips unavailable providers.
    providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
    opts = ort.SessionOptions()
    opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

    # ── Load sessions ──
    # Base multilingual repo supplies encoder/embedder/decoder; only the
    # language model comes from the Finnish fine-tune repo.
    print("Loading ONNX sessions...")
    sess_se = ort.InferenceSession(dl(HF_BASE, "onnx/speech_encoder.onnx"), opts, providers)
    sess_et = ort.InferenceSession(dl(HF_BASE, "onnx/embed_tokens.onnx"), opts, providers)
    sess_lm = ort.InferenceSession(dl(HF_FI, "onnx/language_model.onnx"), opts, providers)
    sess_cd = ort.InferenceSession(dl(HF_BASE, "onnx/conditional_decoder.onnx"), opts, providers)

    # ── Precomputed Finnish conditioning ──
    # Raw float32 blob reshaped to the fixed [1, 34, 1024] conditioning tensor.
    cond_emb_raw = open(dl(HF_FI, "onnx/finnish_cond_emb.bin"), "rb").read()
    cond_emb = np.frombuffer(cond_emb_raw, dtype=np.float32).reshape(1, 34, 1024)
    print(f" cond_emb: {cond_emb.shape}")

    # ── KV dtype from model ──
    # Inspect the LM's declared KV input type so the empty cache matches (fp16 vs fp32).
    kv_input_name = "past_key_values.0.key"
    kv_dtype_str = next(
        (inp.type for inp in sess_lm.get_inputs() if inp.name == kv_input_name), "tensor(float)"
    )
    kv_dtype = np.float16 if "float16" in kv_dtype_str else np.float32
    print(f" KV cache dtype: {kv_dtype}")

    # ── Step 1: Reference audio → speaker embeddings ──
    print(f"\nStep 1: speech_encoder ({REF_AUDIO})")
    ref_audio, ref_sr = librosa.load(REF_AUDIO, sr=None)
    # The speech encoder expects 16 kHz float32 input.
    ref_16k = librosa.resample(ref_audio, orig_sr=ref_sr, target_sr=16000).astype(np.float32)
    ref_input = ref_16k[np.newaxis, :]  # [1, T]
    se_outs = sess_se.run(None, {"audio": ref_input})
    speaker_emb = se_outs[0]    # [1, 256]
    prompt_tokens = se_outs[1]  # [1, N] — NOTE(review): printed but unused below; confirm worker.ts also ignores it
    print(f" speaker_emb: {speaker_emb.shape}")
    print(f" prompt_tokens: {prompt_tokens.shape}")
    # Note: se_outs[2] is cond_emb from base model — we ignore it, use precomputed Finnish version

    # ── Step 2: Tokenize text (mirrors worker.ts tokenize()) ──
    print(f"\nStep 2: tokenize '{TEXT[:60]}...'")
    tok = load_tokenizer("Chatterbox-Finnish/pretrained_models/tokenizer.json")
    token_ids = tokenize(tok, TEXT)
    print(f" token_ids ({len(token_ids)}): {token_ids[:6]}...{token_ids[-3:]}")
    text_ids_np = np.array([token_ids], dtype=np.int64)  # [1, T]

    # ── Step 3: Embed text tokens ──
    print(f"\nStep 3: embed_tokens")
    text_embeds = sess_et.run(None, {"input_ids": text_ids_np})[0]  # [1, T, 1024]
    print(f" text_embeds: {text_embeds.shape}")
    # Embed BOS speech token separately; it terminates the prefill sequence.
    bos_ids = np.array([[START_SPEECH]], dtype=np.int64)
    bos_emb = sess_et.run(None, {"input_ids": bos_ids})[0]  # [1, 1, 1024]

    # ── Step 4: Build prefill ──
    # Matches PyTorch: inputs_embeds = cat([cond_emb, text_emb, bos_emb])
    prefill_cond = np.concatenate([cond_emb, text_embeds, bos_emb], axis=1)
    # Unconditional stream for CFG: identical layout but text embeddings zeroed.
    zeros_text = np.zeros_like(text_embeds)
    prefill_uncond = np.concatenate([cond_emb, zeros_text, bos_emb], axis=1)
    mask_cond = np.ones((1, prefill_cond.shape[1]), dtype=np.int64)
    mask_uncond = np.ones((1, prefill_uncond.shape[1]), dtype=np.int64)
    # Empty per-layer KV cache: 30 layers of [1, 16, 0, 64] (same shape as empty_kv()).
    kv_empty_layer = np.zeros((1, 16, 0, 64), dtype=kv_dtype)
    kv_cond = [(kv_empty_layer.copy(), kv_empty_layer.copy()) for _ in range(30)]
    kv_uncond = [(kv_empty_layer.copy(), kv_empty_layer.copy()) for _ in range(30)]

    def lm_step(embeds, mask, kv):
        # One LM forward pass: returns last logits and the updated 30-layer KV cache.
        feeds = {"inputs_embeds": embeds, "attention_mask": mask}
        feeds.update(make_kv_feeds(kv))
        outs = sess_lm.run(None, feeds)
        logits = outs[0]  # [1, seq, vocab]
        # Outputs after logits alternate key/value per layer.
        new_kv = [(outs[1 + i*2], outs[1 + i*2 + 1]) for i in range(30)]
        return logits, new_kv

    # ── Step 5: Prefill both streams ──
    # NOTE(review): printed step numbers lag the comment numbering by one from here on.
    print(f"\nStep 4: prefill ({prefill_cond.shape[1]} tokens)")
    t0 = time.time()
    logits_c, kv_cond = lm_step(prefill_cond, mask_cond, kv_cond)
    logits_uc, kv_uncond = lm_step(prefill_uncond, mask_uncond, kv_uncond)
    print(f" prefill done ({time.time()-t0:.1f}s)")

    # ── Step 6: Autoregressive generation ──
    print(f"\nStep 5: generate (max {MAX_STEPS} steps)")
    generated = [START_SPEECH]  # full token history (feeds the repetition penalty)
    speech_tokens = []          # only decodable speech-codebook tokens (< START_SPEECH)
    t0 = time.time()
    for step in range(MAX_STEPS):
        # Feed back the previously sampled token as a single-step embedding.
        last_ids = np.array([[generated[-1]]], dtype=np.int64)
        last_emb = sess_et.run(None, {"input_ids": last_ids})[0]  # [1, 1, 1024]
        # Attention mask must cover the cached sequence plus the new token.
        seq_len_c = kv_cond[0][0].shape[2] + 1
        seq_len_uc = kv_uncond[0][0].shape[2] + 1
        mask_c = np.ones((1, seq_len_c), dtype=np.int64)
        mask_uc = np.ones((1, seq_len_uc), dtype=np.int64)
        logits_c, kv_cond = lm_step(last_emb, mask_c, kv_cond)
        logits_uc, kv_uncond = lm_step(last_emb, mask_uc, kv_uncond)
        # CFG: final = cond + cfg_weight * (cond - uncond)
        lc = logits_c[0, -1].astype(np.float32)
        luc = logits_uc[0, -1].astype(np.float32)
        final_logits = lc + CFG_WEIGHT * (lc - luc)
        # Apply rep penalty + min_p + temperature sample
        final_logits = apply_rep_penalty(final_logits, set(generated), REP_PENALTY)
        final_logits = apply_min_p(final_logits, p=0.05)
        token = sample_with_temperature(final_logits, TEMPERATURE)
        # Honour EOS only after the minimum token budget; an early STOP_SPEECH
        # is appended to history but never to speech_tokens (it is >= START_SPEECH).
        if token == STOP_SPEECH and len(speech_tokens) >= MIN_SPEECH_TOKENS:
            print(f" EOS at step {step} ({len(speech_tokens)} speech tokens)")
            break
        generated.append(token)
        if token < START_SPEECH:
            speech_tokens.append(token)
        if (step + 1) % 100 == 0:
            elapsed = time.time() - t0
            rate = (step + 1) / elapsed
            print(f" step {step+1:4d}: {len(speech_tokens):3d} speech tokens ({rate:.1f} tok/s)")
    elapsed = time.time() - t0
    print(f" generation done: {len(speech_tokens)} speech tokens in {elapsed:.1f}s")

    # ── Step 7: Decode → waveform ──
    print(f"\nStep 6: conditional_decoder")
    speech_tok_arr = np.array([speech_tokens], dtype=np.int64)
    cd_out = sess_cd.run(None, {
        "speech_tokens": speech_tok_arr,
        "speaker_embeddings": speaker_emb,
    })
    wav = cd_out[0].squeeze().astype(np.float32)
    # Normalize (mirrors browser worker): rescale only near-silent output.
    peak = np.abs(wav).max()
    if peak < 0.01:
        print(f" warning: very low amplitude (peak={peak:.4f}), auto-normalizing")
        wav = wav * (0.9 / peak)
    wav = np.clip(wav, -1.0, 1.0)
    out_path = str(OUT_DIR / "browser_sim_output.wav")
    sf.write(out_path, wav, SAMPLE_RATE)
    print(f"\nSaved: {out_path} ({len(wav)/SAMPLE_RATE:.2f}s, peak={np.abs(wav).max():.4f})")

    # ── Transcribe ──
    # Optional sanity check: round-trip the wav through Groq's Whisper endpoint.
    if GROQ_KEY:
        print("\nTranscribing with Groq Whisper...")
        with open(out_path, "rb") as f:
            r = requests.post(
                "https://api.groq.com/openai/v1/audio/transcriptions",
                headers={"Authorization": f"Bearer {GROQ_KEY}"},
                files={"file": (os.path.basename(out_path), f, "audio/wav")},
                data={"model": "whisper-large-v3", "language": "fi", "response_format": "text"},
            )
        if r.ok:
            print(f"\nTranscript: '{r.text.strip()}'")
            print(f"Target text: '{TEXT}'")
        else:
            print(f" Groq error: {r.status_code} {r.text}")
# Script entry point — run the full browser-pipeline simulation.
if __name__ == "__main__":
    main()