| """ |
| browser_pipeline_sim.py |
| |
| Python simulation of the browser worker.ts logic. |
| Mirrors the EXACT token assembly, CFG loop, sampling, and decoding |
| that the browser worker does — making it trivial to diff against worker.ts. |
| |
| If this produces correct audio, worker.ts with the same logic will too. |
| If there's a bug in the browser, you can find it by diffing this script against worker.ts. |
| |
| Usage: |
| cd /workspaces/work |
| conda run -n chatterbox-onnx python browser_pipeline_sim.py [text] [ref_audio] |
| |
| Outputs: |
| _cmp/browser_sim_output.wav |
| """ |
|
|
| import os, sys, time, json, struct, math |
| import numpy as np |
| import onnxruntime as ort |
| import soundfile as sf |
| import librosa |
| import requests |
| from pathlib import Path |
| from huggingface_hub import hf_hub_download |
|
|
| sys.path.insert(0, "Chatterbox-Finnish") |
|
|
| |
| TEXT = sys.argv[1] if len(sys.argv) > 1 else \ |
| "Tervetuloa kokeilemaan hienoviritettyΓ€ suomenkielistΓ€ Chatterbox-puhesynteesiΓ€." |
| REF_AUDIO = sys.argv[2] if len(sys.argv) > 2 else \ |
| "Chatterbox-Finnish/samples/reference_finnish.wav" |
|
|
| HF_BASE = "onnx-community/chatterbox-multilingual-ONNX" |
| HF_FI = "RASMUS/Chatterbox-Finnish-ONNX" |
| CACHE_DIR = "_onnx_cache" |
| OUT_DIR = Path("_cmp"); OUT_DIR.mkdir(exist_ok=True) |
|
|
| |
| SOT = 255 |
| EOT = 0 |
| START_SPEECH = 6561 |
| STOP_SPEECH = 6562 |
| CFG_WEIGHT = 0.5 |
| REP_PENALTY = 1.2 |
| TEMPERATURE = 0.8 |
| EXAGGERATION = 0.6 |
| MIN_SPEECH_TOKENS = 40 |
| MAX_STEPS = 1000 |
| SAMPLE_RATE = 24000 |
|
|
| GROQ_KEY = os.environ.get("GROQ_API_KEY", "") |
|
|
| |
| def punc_norm(text: str) -> str: |
| """Mirrors puncNorm() in worker.ts""" |
| import re |
| t = text.strip() |
| t = re.sub(r'\s+', ' ', t) |
| t = t[0].upper() + t[1:] if t else t |
| t = t.replace(' .', '.').replace(' ,', ',') |
| t = t.replace(' ?', '?').replace(' !', '!') |
| if t and t[-1] not in '.!?β¦': |
| t += '.' |
| return t |
|
|
|
|
| |
| def load_tokenizer(tokenizer_path: str): |
| from src.chatterbox_.models.tokenizers.tokenizer import EnTokenizer |
| return EnTokenizer(tokenizer_path) |
|
|
|
|
| |
| def tokenize(tok, text: str) -> list[int]: |
| """ |
| Mirrors browser worker tokenize(): |
| 1. puncNorm |
| 2. replace spaces with [SPACE] token |
| 3. encode grapheme by grapheme |
| 4. wrap with [SOT, ..., EOT] |
| """ |
| normed = punc_norm(text) |
| ids = tok.encode(normed) |
| return [SOT] + ids + [EOT] |
|
|
|
|
| |
| def dl(repo_id, filename): |
| return hf_hub_download(repo_id=repo_id, filename=filename, |
| local_dir=CACHE_DIR, local_dir_use_symlinks=False) |
|
|
|
|
| |
| def empty_kv(n_layers=30, kv_dtype=np.float32): |
| return [(np.zeros((1, 16, 0, 64), dtype=kv_dtype), |
| np.zeros((1, 16, 0, 64), dtype=kv_dtype)) |
| for _ in range(n_layers)] |
|
|
|
|
| def make_kv_feeds(kv_cache): |
| feeds = {} |
| for i, (k, v) in enumerate(kv_cache): |
| feeds[f"past_key_values.{i}.key"] = k |
| feeds[f"past_key_values.{i}.value"] = v |
| return feeds |
|
|
|
|
| |
| def apply_rep_penalty(logits, generated_set, penalty): |
| logits = logits.copy() |
| for tok in generated_set: |
| logits[tok] = logits[tok] / penalty if logits[tok] > 0 else logits[tok] * penalty |
| return logits |
|
|
|
|
| def apply_min_p(logits, p=0.05): |
| """Mirrors applyMinP() in worker.ts""" |
| logits = logits.copy() |
| probs = np.exp(logits - logits.max()) |
| probs /= probs.sum() |
| threshold = probs.max() * p |
| logits[probs < threshold] = -1e9 |
| return logits |
|
|
|
|
| def sample_with_temperature(logits, temperature): |
| """Mirrors sampleWithTemperature() in worker.ts""" |
| logits = (logits / temperature).astype(np.float64) |
| logits -= logits.max() |
| probs = np.exp(logits) |
| probs /= probs.sum() |
| return int(np.random.choice(len(probs), p=probs)) |
|
|
|
|
| |
| def main(): |
| providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] |
| opts = ort.SessionOptions() |
| opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL |
|
|
| |
| print("Loading ONNX sessions...") |
| sess_se = ort.InferenceSession(dl(HF_BASE, "onnx/speech_encoder.onnx"), opts, providers) |
| sess_et = ort.InferenceSession(dl(HF_BASE, "onnx/embed_tokens.onnx"), opts, providers) |
| sess_lm = ort.InferenceSession(dl(HF_FI, "onnx/language_model.onnx"), opts, providers) |
| sess_cd = ort.InferenceSession(dl(HF_BASE, "onnx/conditional_decoder.onnx"), opts, providers) |
|
|
| |
| cond_emb_raw = open(dl(HF_FI, "onnx/finnish_cond_emb.bin"), "rb").read() |
| cond_emb = np.frombuffer(cond_emb_raw, dtype=np.float32).reshape(1, 34, 1024) |
| print(f" cond_emb: {cond_emb.shape}") |
|
|
| |
| kv_input_name = "past_key_values.0.key" |
| kv_dtype_str = next( |
| (inp.type for inp in sess_lm.get_inputs() if inp.name == kv_input_name), "tensor(float)" |
| ) |
| kv_dtype = np.float16 if "float16" in kv_dtype_str else np.float32 |
| print(f" KV cache dtype: {kv_dtype}") |
|
|
| |
| print(f"\nStep 1: speech_encoder ({REF_AUDIO})") |
| ref_audio, ref_sr = librosa.load(REF_AUDIO, sr=None) |
| ref_16k = librosa.resample(ref_audio, orig_sr=ref_sr, target_sr=16000).astype(np.float32) |
| ref_input = ref_16k[np.newaxis, :] |
|
|
| se_outs = sess_se.run(None, {"audio": ref_input}) |
| speaker_emb = se_outs[0] |
| prompt_tokens = se_outs[1] |
| print(f" speaker_emb: {speaker_emb.shape}") |
| print(f" prompt_tokens: {prompt_tokens.shape}") |
| |
|
|
| |
| print(f"\nStep 2: tokenize '{TEXT[:60]}...'") |
| tok = load_tokenizer("Chatterbox-Finnish/pretrained_models/tokenizer.json") |
| token_ids = tokenize(tok, TEXT) |
| print(f" token_ids ({len(token_ids)}): {token_ids[:6]}...{token_ids[-3:]}") |
| text_ids_np = np.array([token_ids], dtype=np.int64) |
|
|
| |
| print(f"\nStep 3: embed_tokens") |
| text_embeds = sess_et.run(None, {"input_ids": text_ids_np})[0] |
| print(f" text_embeds: {text_embeds.shape}") |
|
|
| |
| bos_ids = np.array([[START_SPEECH]], dtype=np.int64) |
| bos_emb = sess_et.run(None, {"input_ids": bos_ids})[0] |
|
|
| |
| |
| prefill_cond = np.concatenate([cond_emb, text_embeds, bos_emb], axis=1) |
| zeros_text = np.zeros_like(text_embeds) |
| prefill_uncond = np.concatenate([cond_emb, zeros_text, bos_emb], axis=1) |
|
|
| mask_cond = np.ones((1, prefill_cond.shape[1]), dtype=np.int64) |
| mask_uncond = np.ones((1, prefill_uncond.shape[1]), dtype=np.int64) |
|
|
| kv_empty_layer = np.zeros((1, 16, 0, 64), dtype=kv_dtype) |
| kv_cond = [(kv_empty_layer.copy(), kv_empty_layer.copy()) for _ in range(30)] |
| kv_uncond = [(kv_empty_layer.copy(), kv_empty_layer.copy()) for _ in range(30)] |
|
|
| def lm_step(embeds, mask, kv): |
| feeds = {"inputs_embeds": embeds, "attention_mask": mask} |
| feeds.update(make_kv_feeds(kv)) |
| outs = sess_lm.run(None, feeds) |
| logits = outs[0] |
| new_kv = [(outs[1 + i*2], outs[1 + i*2 + 1]) for i in range(30)] |
| return logits, new_kv |
|
|
| |
| print(f"\nStep 4: prefill ({prefill_cond.shape[1]} tokens)") |
| t0 = time.time() |
| logits_c, kv_cond = lm_step(prefill_cond, mask_cond, kv_cond) |
| logits_uc, kv_uncond = lm_step(prefill_uncond, mask_uncond, kv_uncond) |
| print(f" prefill done ({time.time()-t0:.1f}s)") |
|
|
| |
| print(f"\nStep 5: generate (max {MAX_STEPS} steps)") |
| generated = [START_SPEECH] |
| speech_tokens = [] |
| t0 = time.time() |
|
|
| for step in range(MAX_STEPS): |
| last_ids = np.array([[generated[-1]]], dtype=np.int64) |
| last_emb = sess_et.run(None, {"input_ids": last_ids})[0] |
|
|
| seq_len_c = kv_cond[0][0].shape[2] + 1 |
| seq_len_uc = kv_uncond[0][0].shape[2] + 1 |
| mask_c = np.ones((1, seq_len_c), dtype=np.int64) |
| mask_uc = np.ones((1, seq_len_uc), dtype=np.int64) |
|
|
| logits_c, kv_cond = lm_step(last_emb, mask_c, kv_cond) |
| logits_uc, kv_uncond = lm_step(last_emb, mask_uc, kv_uncond) |
|
|
| |
| lc = logits_c[0, -1].astype(np.float32) |
| luc = logits_uc[0, -1].astype(np.float32) |
| final_logits = lc + CFG_WEIGHT * (lc - luc) |
|
|
| |
| final_logits = apply_rep_penalty(final_logits, set(generated), REP_PENALTY) |
| final_logits = apply_min_p(final_logits, p=0.05) |
| token = sample_with_temperature(final_logits, TEMPERATURE) |
|
|
| if token == STOP_SPEECH and len(speech_tokens) >= MIN_SPEECH_TOKENS: |
| print(f" EOS at step {step} ({len(speech_tokens)} speech tokens)") |
| break |
|
|
| generated.append(token) |
| if token < START_SPEECH: |
| speech_tokens.append(token) |
|
|
| if (step + 1) % 100 == 0: |
| elapsed = time.time() - t0 |
| rate = (step + 1) / elapsed |
| print(f" step {step+1:4d}: {len(speech_tokens):3d} speech tokens ({rate:.1f} tok/s)") |
|
|
| elapsed = time.time() - t0 |
| print(f" generation done: {len(speech_tokens)} speech tokens in {elapsed:.1f}s") |
|
|
| |
| print(f"\nStep 6: conditional_decoder") |
| speech_tok_arr = np.array([speech_tokens], dtype=np.int64) |
| cd_out = sess_cd.run(None, { |
| "speech_tokens": speech_tok_arr, |
| "speaker_embeddings": speaker_emb, |
| }) |
| wav = cd_out[0].squeeze().astype(np.float32) |
|
|
| |
| peak = np.abs(wav).max() |
| if peak < 0.01: |
| print(f" warning: very low amplitude (peak={peak:.4f}), auto-normalizing") |
| wav = wav * (0.9 / peak) |
| wav = np.clip(wav, -1.0, 1.0) |
|
|
| out_path = str(OUT_DIR / "browser_sim_output.wav") |
| sf.write(out_path, wav, SAMPLE_RATE) |
| print(f"\nSaved: {out_path} ({len(wav)/SAMPLE_RATE:.2f}s, peak={np.abs(wav).max():.4f})") |
|
|
| |
| if GROQ_KEY: |
| print("\nTranscribing with Groq Whisper...") |
| with open(out_path, "rb") as f: |
| r = requests.post( |
| "https://api.groq.com/openai/v1/audio/transcriptions", |
| headers={"Authorization": f"Bearer {GROQ_KEY}"}, |
| files={"file": (os.path.basename(out_path), f, "audio/wav")}, |
| data={"model": "whisper-large-v3", "language": "fi", "response_format": "text"}, |
| ) |
| if r.ok: |
| print(f"\nTranscript: '{r.text.strip()}'") |
| print(f"Target text: '{TEXT}'") |
| else: |
| print(f" Groq error: {r.status_code} {r.text}") |
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|