""" browser_pipeline_sim.py Python simulation of the browser worker.ts logic. Mirrors the EXACT token assembly, CFG loop, sampling, and decoding that the browser worker does — making it trivial to diff against worker.ts. If this produces correct audio, worker.ts with the same logic will too. If there's a bug in the browser, you can find it by diffing this script against worker.ts. Usage: cd /workspaces/work conda run -n chatterbox-onnx python browser_pipeline_sim.py [text] [ref_audio] Outputs: _cmp/browser_sim_output.wav """ import os, sys, time, json, struct, math import numpy as np import onnxruntime as ort import soundfile as sf import librosa import requests from pathlib import Path from huggingface_hub import hf_hub_download sys.path.insert(0, "Chatterbox-Finnish") # ── Config — mirrors browser worker constants ───────────────────────────────── TEXT = sys.argv[1] if len(sys.argv) > 1 else \ "Tervetuloa kokeilemaan hienoviritettyä suomenkielistä Chatterbox-puhesynteesiä." REF_AUDIO = sys.argv[2] if len(sys.argv) > 2 else \ "Chatterbox-Finnish/samples/reference_finnish.wav" HF_BASE = "onnx-community/chatterbox-multilingual-ONNX" HF_FI = "RASMUS/Chatterbox-Finnish-ONNX" CACHE_DIR = "_onnx_cache" OUT_DIR = Path("_cmp"); OUT_DIR.mkdir(exist_ok=True) # Worker constants (from worker.ts) SOT = 255 # [START] token EOT = 0 # [STOP] token START_SPEECH = 6561 # BOS speech token STOP_SPEECH = 6562 # EOS speech token CFG_WEIGHT = 0.5 REP_PENALTY = 1.2 TEMPERATURE = 0.8 EXAGGERATION = 0.6 MIN_SPEECH_TOKENS = 40 MAX_STEPS = 1000 SAMPLE_RATE = 24000 GROQ_KEY = os.environ.get("GROQ_API_KEY", "") # ── puncNorm — mirrors worker.ts puncNorm() ─────────────────────────────────── def punc_norm(text: str) -> str: """Mirrors puncNorm() in worker.ts""" import re t = text.strip() t = re.sub(r'\s+', ' ', t) # normalize whitespace t = t[0].upper() + t[1:] if t else t # capitalize first letter t = t.replace(' .', '.').replace(' ,', ',') # remove space before punct t = t.replace(' ?', '?').replace(' !', '!') if t and t[-1] not in '.!?…': # ensure ending punctuation t += '.' return t # ── EnTokenizer (same as Chatterbox-Finnish/src/chatterbox_/models/tokenizers/tokenizer.py) def load_tokenizer(tokenizer_path: str): from src.chatterbox_.models.tokenizers.tokenizer import EnTokenizer return EnTokenizer(tokenizer_path) # ── puncNorm + tokenize → mirrors worker.ts tokenize() ─────────────────────── def tokenize(tok, text: str) -> list[int]: """ Mirrors browser worker tokenize(): 1. puncNorm 2. replace spaces with [SPACE] token 3. encode grapheme by grapheme 4. wrap with [SOT, ..., EOT] """ normed = punc_norm(text) ids = tok.encode(normed) return [SOT] + ids + [EOT] # ── Download helpers ─────────────────────────────────────────────────────────── def dl(repo_id, filename): return hf_hub_download(repo_id=repo_id, filename=filename, local_dir=CACHE_DIR, local_dir_use_symlinks=False) # ── KV cache helpers ─────────────────────────────────────────────────────────── def empty_kv(n_layers=30, kv_dtype=np.float32): return [(np.zeros((1, 16, 0, 64), dtype=kv_dtype), np.zeros((1, 16, 0, 64), dtype=kv_dtype)) for _ in range(n_layers)] def make_kv_feeds(kv_cache): feeds = {} for i, (k, v) in enumerate(kv_cache): feeds[f"past_key_values.{i}.key"] = k feeds[f"past_key_values.{i}.value"] = v return feeds # ── Sampling helpers ─────────────────────────────────────────────────────────── def apply_rep_penalty(logits, generated_set, penalty): logits = logits.copy() for tok in generated_set: logits[tok] = logits[tok] / penalty if logits[tok] > 0 else logits[tok] * penalty return logits def apply_min_p(logits, p=0.05): """Mirrors applyMinP() in worker.ts""" logits = logits.copy() probs = np.exp(logits - logits.max()) probs /= probs.sum() threshold = probs.max() * p logits[probs < threshold] = -1e9 return logits def sample_with_temperature(logits, temperature): """Mirrors sampleWithTemperature() in worker.ts""" logits = (logits / temperature).astype(np.float64) logits -= logits.max() probs = np.exp(logits) probs /= probs.sum() return int(np.random.choice(len(probs), p=probs)) # ── Main pipeline ───────────────────────────────────────────────────────────── def main(): providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] opts = ort.SessionOptions() opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL # ── Load sessions ── print("Loading ONNX sessions...") sess_se = ort.InferenceSession(dl(HF_BASE, "onnx/speech_encoder.onnx"), opts, providers) sess_et = ort.InferenceSession(dl(HF_BASE, "onnx/embed_tokens.onnx"), opts, providers) sess_lm = ort.InferenceSession(dl(HF_FI, "onnx/language_model.onnx"), opts, providers) sess_cd = ort.InferenceSession(dl(HF_BASE, "onnx/conditional_decoder.onnx"), opts, providers) # ── Precomputed Finnish conditioning ── cond_emb_raw = open(dl(HF_FI, "onnx/finnish_cond_emb.bin"), "rb").read() cond_emb = np.frombuffer(cond_emb_raw, dtype=np.float32).reshape(1, 34, 1024) print(f" cond_emb: {cond_emb.shape}") # ── KV dtype from model ── kv_input_name = "past_key_values.0.key" kv_dtype_str = next( (inp.type for inp in sess_lm.get_inputs() if inp.name == kv_input_name), "tensor(float)" ) kv_dtype = np.float16 if "float16" in kv_dtype_str else np.float32 print(f" KV cache dtype: {kv_dtype}") # ── Step 1: Reference audio → speaker embeddings ── print(f"\nStep 1: speech_encoder ({REF_AUDIO})") ref_audio, ref_sr = librosa.load(REF_AUDIO, sr=None) ref_16k = librosa.resample(ref_audio, orig_sr=ref_sr, target_sr=16000).astype(np.float32) ref_input = ref_16k[np.newaxis, :] # [1, T] se_outs = sess_se.run(None, {"audio": ref_input}) speaker_emb = se_outs[0] # [1, 256] prompt_tokens = se_outs[1] # [1, N] print(f" speaker_emb: {speaker_emb.shape}") print(f" prompt_tokens: {prompt_tokens.shape}") # Note: se_outs[2] is cond_emb from base model — we ignore it, use precomputed Finnish version # ── Step 2: Tokenize text (mirrors worker.ts tokenize()) ── print(f"\nStep 2: tokenize '{TEXT[:60]}...'") tok = load_tokenizer("Chatterbox-Finnish/pretrained_models/tokenizer.json") token_ids = tokenize(tok, TEXT) print(f" token_ids ({len(token_ids)}): {token_ids[:6]}...{token_ids[-3:]}") text_ids_np = np.array([token_ids], dtype=np.int64) # [1, T] # ── Step 3: Embed text tokens ── print(f"\nStep 3: embed_tokens") text_embeds = sess_et.run(None, {"input_ids": text_ids_np})[0] # [1, T, 1024] print(f" text_embeds: {text_embeds.shape}") # Embed BOS speech token bos_ids = np.array([[START_SPEECH]], dtype=np.int64) bos_emb = sess_et.run(None, {"input_ids": bos_ids})[0] # [1, 1, 1024] # ── Step 4: Build prefill ── # Matches PyTorch: inputs_embeds = cat([cond_emb, text_emb, bos_emb]) prefill_cond = np.concatenate([cond_emb, text_embeds, bos_emb], axis=1) zeros_text = np.zeros_like(text_embeds) prefill_uncond = np.concatenate([cond_emb, zeros_text, bos_emb], axis=1) mask_cond = np.ones((1, prefill_cond.shape[1]), dtype=np.int64) mask_uncond = np.ones((1, prefill_uncond.shape[1]), dtype=np.int64) kv_empty_layer = np.zeros((1, 16, 0, 64), dtype=kv_dtype) kv_cond = [(kv_empty_layer.copy(), kv_empty_layer.copy()) for _ in range(30)] kv_uncond = [(kv_empty_layer.copy(), kv_empty_layer.copy()) for _ in range(30)] def lm_step(embeds, mask, kv): feeds = {"inputs_embeds": embeds, "attention_mask": mask} feeds.update(make_kv_feeds(kv)) outs = sess_lm.run(None, feeds) logits = outs[0] # [1, seq, vocab] new_kv = [(outs[1 + i*2], outs[1 + i*2 + 1]) for i in range(30)] return logits, new_kv # ── Step 5: Prefill both streams ── print(f"\nStep 4: prefill ({prefill_cond.shape[1]} tokens)") t0 = time.time() logits_c, kv_cond = lm_step(prefill_cond, mask_cond, kv_cond) logits_uc, kv_uncond = lm_step(prefill_uncond, mask_uncond, kv_uncond) print(f" prefill done ({time.time()-t0:.1f}s)") # ── Step 6: Autoregressive generation ── print(f"\nStep 5: generate (max {MAX_STEPS} steps)") generated = [START_SPEECH] speech_tokens = [] t0 = time.time() for step in range(MAX_STEPS): last_ids = np.array([[generated[-1]]], dtype=np.int64) last_emb = sess_et.run(None, {"input_ids": last_ids})[0] # [1, 1, 1024] seq_len_c = kv_cond[0][0].shape[2] + 1 seq_len_uc = kv_uncond[0][0].shape[2] + 1 mask_c = np.ones((1, seq_len_c), dtype=np.int64) mask_uc = np.ones((1, seq_len_uc), dtype=np.int64) logits_c, kv_cond = lm_step(last_emb, mask_c, kv_cond) logits_uc, kv_uncond = lm_step(last_emb, mask_uc, kv_uncond) # CFG: final = cond + cfg_weight * (cond - uncond) lc = logits_c[0, -1].astype(np.float32) luc = logits_uc[0, -1].astype(np.float32) final_logits = lc + CFG_WEIGHT * (lc - luc) # Apply rep penalty + min_p + temperature sample final_logits = apply_rep_penalty(final_logits, set(generated), REP_PENALTY) final_logits = apply_min_p(final_logits, p=0.05) token = sample_with_temperature(final_logits, TEMPERATURE) if token == STOP_SPEECH and len(speech_tokens) >= MIN_SPEECH_TOKENS: print(f" EOS at step {step} ({len(speech_tokens)} speech tokens)") break generated.append(token) if token < START_SPEECH: speech_tokens.append(token) if (step + 1) % 100 == 0: elapsed = time.time() - t0 rate = (step + 1) / elapsed print(f" step {step+1:4d}: {len(speech_tokens):3d} speech tokens ({rate:.1f} tok/s)") elapsed = time.time() - t0 print(f" generation done: {len(speech_tokens)} speech tokens in {elapsed:.1f}s") # ── Step 7: Decode → waveform ── print(f"\nStep 6: conditional_decoder") speech_tok_arr = np.array([speech_tokens], dtype=np.int64) cd_out = sess_cd.run(None, { "speech_tokens": speech_tok_arr, "speaker_embeddings": speaker_emb, }) wav = cd_out[0].squeeze().astype(np.float32) # Normalize (mirrors browser worker) peak = np.abs(wav).max() if peak < 0.01: print(f" warning: very low amplitude (peak={peak:.4f}), auto-normalizing") wav = wav * (0.9 / peak) wav = np.clip(wav, -1.0, 1.0) out_path = str(OUT_DIR / "browser_sim_output.wav") sf.write(out_path, wav, SAMPLE_RATE) print(f"\nSaved: {out_path} ({len(wav)/SAMPLE_RATE:.2f}s, peak={np.abs(wav).max():.4f})") # ── Transcribe ── if GROQ_KEY: print("\nTranscribing with Groq Whisper...") with open(out_path, "rb") as f: r = requests.post( "https://api.groq.com/openai/v1/audio/transcriptions", headers={"Authorization": f"Bearer {GROQ_KEY}"}, files={"file": (os.path.basename(out_path), f, "audio/wav")}, data={"model": "whisper-large-v3", "language": "fi", "response_format": "text"}, ) if r.ok: print(f"\nTranscript: '{r.text.strip()}'") print(f"Target text: '{TEXT}'") else: print(f" Groq error: {r.status_code} {r.text}") if __name__ == "__main__": main()