"""Chatterbox multilingual text-to-speech inference on top of ONNX Runtime.

The module downloads the ONNX graphs and tokenizer of ``MODEL_ID`` from the
Hugging Face Hub, caches the loaded objects in ``SESSIONS`` and exposes
``run_chatterbox_inference`` which turns text into WAV-encoded bytes.
"""

import os
import time
import json
import base64
import tempfile
import io
import logging
import unicodedata
from unicodedata import category

import numpy as np
import onnxruntime
import soundfile as sf
import librosa
from tqdm import tqdm
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer

# Constants from model card
S3GEN_SR = 24000
START_SPEECH_TOKEN = 6561
STOP_SPEECH_TOKEN = 6562
MODEL_ID = "onnx-community/chatterbox-multilingual-ONNX"

# ONNX graphs that make up the pipeline, in the order they are loaded.
_ONNX_SESSION_NAMES = (
    "speech_encoder",
    "embed_tokens",
    "conditional_decoder",
    "language_model",
)

# Key/value cache geometry used to build the empty initial cache for the
# language model (values follow the model card's sample loop).
_NUM_HIDDEN_LAYERS = 30
_NUM_KEY_VALUE_HEADS = 16
_HEAD_DIM = 64

# Cache for sessions and helpers
SESSIONS = {
    "speech_encoder": None,
    "embed_tokens": None,
    "language_model": None,
    "conditional_decoder": None,
    "tokenizer": None,
    "cangjie": None,
    "kakasi": None,
}


class RepetitionPenaltyLogitsProcessor:
    """Down-weight the logits of tokens that were already generated.

    Positive logits of seen tokens are divided by ``penalty`` and negative
    ones are multiplied by it.
    """

    def __init__(self, penalty: float):
        """Store the penalty factor.

        Raises:
            ValueError: if ``penalty`` is not strictly positive (a zero
                penalty would divide by zero in ``__call__``).
        """
        if not penalty > 0:
            raise ValueError(f"`penalty` must be strictly positive, but is {penalty}")
        self.penalty = penalty

    def __call__(self, input_ids: np.ndarray, scores: np.ndarray) -> np.ndarray:
        """Return a penalised copy of ``scores``; the input is not modified.

        Args:
            input_ids: integer array of already generated token ids, used as
                indices into ``scores`` along axis 1.
            scores: logits for the next token.
        """
        seen = np.take_along_axis(scores, input_ids, axis=1)
        seen = np.where(seen < 0, seen * self.penalty, seen / self.penalty)
        scores_processed = scores.copy()
        np.put_along_axis(scores_processed, input_ids, seen, axis=1)
        return scores_processed


class ChineseCangjieConverter:
    """Rewrite Chinese text as ``[cj_x]`` Cangjie-code tokens."""

    def __init__(self):
        self.word2cj = {}   # glyph -> Cangjie code
        self.cj2word = {}   # Cangjie code -> glyphs sharing it, in file order
        self.segmenter = None
        self._load_cangjie_mapping()
        self._init_segmenter()

    def _load_cangjie_mapping(self):
        """Fill the lookup tables from ``Cangjie5_TC.json`` on the Hub.

        Best effort: any failure is printed and the tables keep whatever was
        loaded so far, so the converter degrades to passing text through.
        """
        try:
            cangjie_file = hf_hub_download(repo_id=MODEL_ID, filename="Cangjie5_TC.json")
            with open(cangjie_file, "r", encoding="utf-8") as fp:
                data = json.load(fp)
            for entry in data:
                fields = entry.split("\t")
                if len(fields) < 2:
                    # Skip a malformed row instead of aborting the whole load.
                    continue
                word, code = fields[:2]
                self.word2cj[word] = code
                self.cj2word.setdefault(code, []).append(word)
        except Exception as e:
            print(f"Cangjie error: {e}")

    def _init_segmenter(self):
        """Use jieba for word segmentation when it can be imported."""
        try:
            import jieba

            # Silence jieba logs
            jieba.setLogLevel(logging.ERROR)
            self.segmenter = jieba
        except Exception:
            # jieba is optional; without it the text is processed unsegmented.
            self.segmenter = None

    def _cangjie_encode(self, glyph: str):
        """Return the Cangjie code of ``glyph`` or ``None`` if it is unknown.

        Glyphs that share a code are told apart by appending the glyph's
        position in that code's list (omitted for the first one).
        """
        code = self.word2cj.get(glyph)
        if code is None:
            return None
        index = self.cj2word[code].index(glyph)
        return code + (str(index) if index > 0 else "")

    def __call__(self, text):
        """Convert every ``Lo``-category character of ``text`` to Cangjie tokens.

        Characters without a known code and all other characters are kept
        as they are. When a segmenter is available, words are separated by
        spaces first.
        """
        if self.segmenter:
            text = " ".join(self.segmenter.cut(text))
        output = []
        for char in text:
            if category(char) != "Lo":
                output.append(char)
                continue
            cangjie = self._cangjie_encode(char)
            if not cangjie:
                output.append(char)
                continue
            output.append("".join(f"[cj_{c}]" for c in cangjie) + "[cj_.]")
        return "".join(output)


def hiragana_normalize(text):
    """Replace kanji-containing segments of ``text`` by their hiragana reading.

    Uses pykakasi (imported lazily, instance cached in ``SESSIONS``) and
    applies NFKD normalisation to the result. Best effort: if pykakasi is
    missing or fails, ``text`` is returned unchanged.
    """
    try:
        import pykakasi

        if SESSIONS["kakasi"] is None:
            SESSIONS["kakasi"] = pykakasi.kakasi()
        pieces = []
        for item in SESSIONS["kakasi"].convert(text):
            original, hira = item['orig'], item['hira']
            # 19968..40959 is U+4E00..U+9FFF (CJK Unified Ideographs).
            has_kanji = any(19968 <= ord(c) <= 40959 for c in original)
            pieces.append(hira if has_kanji else original)
        return unicodedata.normalize('NFKD', "".join(pieces))
    except Exception:
        return text


def korean_normalize(text):
    """Decompose precomposed Hangul syllables into jamo and strip whitespace."""

    def decomp(char):
        # Precomposed syllables occupy U+AC00..U+D7A3; leave the rest alone.
        if not ('\uac00' <= char <= '\ud7a3'):
            return char
        base = ord(char) - 0xAC00
        # 588 = 21 medials * 28 finals per initial consonant.
        initial = chr(0x1100 + base // 588)
        medial = chr(0x1161 + (base % 588) // 28)
        final = chr(0x11A7 + base % 28) if base % 28 > 0 else ''
        return initial + medial + final

    return "".join(decomp(c) for c in text).strip()


def prepare_language(txt, lang_id):
    """Apply language-specific preprocessing and prefix the language tag.

    Args:
        txt: input text.
        lang_id: language code such as ``"zh"``, ``"ja"`` or ``"ko"``; it is
            matched case-insensitively. A falsy value returns ``txt`` as is.

    Returns:
        ``"[<lang>]<processed text>"`` with the code in lower case.
    """
    if not lang_id:
        return txt
    # Lower-case once so the dispatch agrees with the tag that is emitted.
    lang = lang_id.lower()
    if lang == 'zh':
        if SESSIONS["cangjie"] is None:
            SESSIONS["cangjie"] = ChineseCangjieConverter()
        txt = SESSIONS["cangjie"](txt)
    elif lang == 'ja':
        txt = hiragana_normalize(txt)
    elif lang == 'ko':
        txt = korean_normalize(txt)
    return f"[{lang}]{txt}"


def _onnx_filename(sess_name):
    """Return the repository path of the ONNX graph for ``sess_name``."""
    return "onnx/" + sess_name + ".onnx"


def load_chatterbox(device="cpu"):
    """Pre-load ONNX sessions - v111: Forced CPU for stability

    Loads every ONNX graph plus the tokenizer into ``SESSIONS``. Nothing is
    published until all of them loaded, so a failure part-way through does
    not leave a half-initialised cache that later calls would trust.

    Args:
        device: accepted for backward compatibility; it is ignored because
            only ``CPUExecutionProvider`` is used.
    """
    required = _ONNX_SESSION_NAMES + ("tokenizer",)
    if all(SESSIONS[name] is not None for name in required):
        return
    print("🚀 Loading Chatterbox ONNX into CPU (ZeroGPU Safe Mode)...")
    providers = ["CPUExecutionProvider"]
    loaded = {}
    for sess_name in _ONNX_SESSION_NAMES:
        fname = _onnx_filename(sess_name)
        path = hf_hub_download(repo_id=MODEL_ID, filename=fname)
        # Ensure sidecar data is present
        hf_hub_download(repo_id=MODEL_ID, filename=fname + "_data", local_files_only=False)
        loaded[sess_name] = onnxruntime.InferenceSession(path, providers=providers)
    loaded["tokenizer"] = AutoTokenizer.from_pretrained(MODEL_ID)
    SESSIONS.update(loaded)


def warmup_chatterbox():
    """v92: Pre-download model files in background

    Fetches the tokenizer, the default voice and every ONNX graph with its
    ``_data`` sidecar. Failures are reported as a warning, never raised.
    """
    print("📥 Caching Chatterbox weights (ONNX)...")
    try:
        AutoTokenizer.from_pretrained(MODEL_ID)
        hf_hub_download(repo_id=MODEL_ID, filename="default_voice.wav")
        for sess_name in _ONNX_SESSION_NAMES:
            fname = _onnx_filename(sess_name)
            hf_hub_download(repo_id=MODEL_ID, filename=fname)
            hf_hub_download(repo_id=MODEL_ID, filename=fname + "_data")
        print("✅ Chatterbox cached.")
    except Exception as e:
        print(f"⚠️ Chatterbox cache warning: {e}")


def _wav_to_bytes(wav, sample_rate):
    """Encode ``wav`` as a WAV file in memory and return its bytes."""
    buffer = io.BytesIO()
    # A file-like target has no extension, so the format must be explicit.
    sf.write(buffer, wav, sample_rate, format="WAV")
    return buffer.getvalue()


def run_chatterbox_inference(text, lang_id, speaker_wav_path=None, *,
                             max_new_tokens=256, exaggeration=0.5,
                             repetition_penalty=1.2):
    """Ported logic from model card with session reuse

    Synthesises ``text`` with greedy decoding and returns the audio as the
    bytes of a WAV file sampled at ``S3GEN_SR``.

    Args:
        text: text to speak.
        lang_id: language code forwarded to ``prepare_language``.
        speaker_wav_path: reference voice recording; when falsy the
            repository's ``default_voice.wav`` is downloaded and used.
        max_new_tokens: upper bound on generated speech tokens.
        exaggeration: value passed as the ``exaggeration`` input of the
            ``embed_tokens`` graph.
        repetition_penalty: factor for ``RepetitionPenaltyLogitsProcessor``.

    Returns:
        WAV-encoded audio as ``bytes``.

    Raises:
        ValueError: if ``max_new_tokens`` is smaller than 1.
    """
    if max_new_tokens < 1:
        raise ValueError(f"`max_new_tokens` must be at least 1, but is {max_new_tokens}")

    load_chatterbox()  # Ensure sessions ready
    if not speaker_wav_path:
        speaker_wav_path = hf_hub_download(repo_id=MODEL_ID, filename="default_voice.wav")
    audio_values, _ = librosa.load(speaker_wav_path, sr=S3GEN_SR)
    audio_values = audio_values[np.newaxis, :].astype(np.float32)

    text = prepare_language(text, lang_id)
    input_ids = SESSIONS["tokenizer"](text, return_tensors="np")["input_ids"].astype(np.int64)
    # Position formula taken over from the model card: ids at or above
    # START_SPEECH_TOKEN get position 0, the rest their index minus one.
    position_ids = np.where(
        input_ids >= START_SPEECH_TOKEN,
        0,
        np.arange(input_ids.shape[1])[np.newaxis, :] - 1,
    )
    embed_inputs = {
        "input_ids": input_ids,
        "position_ids": position_ids.astype(np.int64),
        "exaggeration": np.array([exaggeration], dtype=np.float32),
    }
    penalty_processor = RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)
    generate_tokens = np.array([[START_SPEECH_TOKEN]], dtype=np.int64)

    # The reference voice is encoded once; its outputs condition both the
    # language model (cond_emb) and the final decoder (the other three).
    cond_emb, prompt_token, ref_x_vector, prompt_feat = SESSIONS["speech_encoder"].run(
        None, {"audio_values": audio_values}
    )

    # Simple loop as per model card
    batch_size = 1
    past_key_values = {
        f"past_key_values.{layer}.{kv}": np.zeros(
            [batch_size, _NUM_KEY_VALUE_HEADS, 0, _HEAD_DIM], dtype=np.float32
        )
        for layer in range(_NUM_HIDDEN_LAYERS)
        for kv in ("key", "value")
    }
    attention_mask = None
    hit_stop = False
    for step in range(max_new_tokens):
        inputs_embeds = SESSIONS["embed_tokens"].run(None, embed_inputs)[0]
        if step == 0:
            # Only the first step carries the conditioning prefix; later
            # steps feed a single new token and rely on the KV cache.
            inputs_embeds = np.concatenate((cond_emb, inputs_embeds), axis=1)
            attention_mask = np.ones((batch_size, inputs_embeds.shape[1]), dtype=np.int64)

        logits, *present_key_values = SESSIONS["language_model"].run(
            None,
            {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask, **past_key_values},
        )
        next_token_logits = penalty_processor(generate_tokens, logits[:, -1, :])
        next_token = np.argmax(next_token_logits, axis=-1, keepdims=True).astype(np.int64)
        generate_tokens = np.concatenate((generate_tokens, next_token), axis=-1)
        if (next_token.flatten() == STOP_SPEECH_TOKEN).all():
            hit_stop = True
            break

        embed_inputs["input_ids"] = next_token
        embed_inputs["position_ids"] = np.full((1, 1), step + 1, dtype=np.int64)
        attention_mask = np.concatenate(
            [attention_mask, np.ones((batch_size, 1), dtype=np.int64)], axis=1
        )
        # The graph returns the updated cache in the same order as the inputs.
        past_key_values = dict(zip(past_key_values, present_key_values))

    # Final Decode
    # Drop the leading START token, and the trailing STOP token only when it
    # was really produced; otherwise the last real speech token would be lost
    # whenever generation ends by reaching max_new_tokens.
    speech_tokens = generate_tokens[:, 1:-1] if hit_stop else generate_tokens[:, 1:]
    speech_tokens = np.concatenate([prompt_token, speech_tokens], axis=1)
    wav = SESSIONS["conditional_decoder"].run(
        None,
        {
            "speech_tokens": speech_tokens,
            "speaker_embeddings": ref_x_vector,
            "speaker_features": prompt_feat,
        },
    )[0]
    wav = np.squeeze(wav, axis=0)

    # Return bytes directly; encoding in memory avoids leaving a temporary
    # file behind when writing or reading fails.
    return _wav_to_bytes(wav, S3GEN_SR)