TGPro1 commited on
Commit
43de3a2
·
verified ·
1 Parent(s): fd6462f

Upload chatterbox_utils.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. chatterbox_utils.py +195 -0
chatterbox_utils.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import json
4
+ import base64
5
+ import tempfile
6
+ import numpy as np
7
+ import onnxruntime
8
+ import soundfile as sf
9
+ import librosa
10
+ from tqdm import tqdm
11
+ from huggingface_hub import hf_hub_download
12
+ from transformers import AutoTokenizer
13
+ from unicodedata import category
14
+
15
# Constants from model card
S3GEN_SR = 24000  # sample rate (Hz): reference audio is loaded at this rate and output is written at it
START_SPEECH_TOKEN = 6561  # token id that seeds autoregressive speech-token generation
STOP_SPEECH_TOKEN = 6562  # token id that terminates generation
MODEL_ID = "onnx-community/chatterbox-multilingual-ONNX"

# Cache for sessions and helpers.
# Filled lazily: ONNX sessions + tokenizer by load_chatterbox(); "cangjie" by
# prepare_language() on first 'zh' request; "kakasi" by hiragana_normalize().
SESSIONS = {
    "speech_encoder": None,
    "embed_tokens": None,
    "language_model": None,
    "conditional_decoder": None,
    "tokenizer": None,
    "cangjie": None,
    "kakasi": None
}
31
+
32
class RepetitionPenaltyLogitsProcessor:
    """Penalize logits of token ids that were already generated.

    Positive logits are divided by ``penalty`` and negative logits are
    multiplied by it, so previously seen tokens become less likely either way.
    The input ``scores`` array is never modified in place.
    """

    def __init__(self, penalty: float):
        self.penalty = penalty

    def __call__(self, input_ids: np.ndarray, scores: np.ndarray) -> np.ndarray:
        # Gather the logits belonging to every previously generated token id.
        seen = np.take_along_axis(scores, input_ids, axis=1)
        # Push them toward "less likely": shrink positives, amplify negatives.
        penalized = np.where(seen < 0, seen * self.penalty, seen / self.penalty)
        # Scatter the adjusted values into a copy so the caller's array is untouched.
        result = scores.copy()
        np.put_along_axis(result, input_ids, penalized, axis=1)
        return result
41
+
42
class ChineseCangjieConverter:
    """Convert Chinese glyphs into Cangjie code tokens (e.g. ``[cj_o][cj_n][cj_.]``).

    The glyph-to-code table is downloaded from the model repo on construction;
    an optional pkuseg segmenter inserts spaces between words before encoding.
    Both setup steps are best-effort: on failure, glyphs pass through unencoded.
    """

    def __init__(self):
        self.word2cj = {}  # glyph -> Cangjie code
        self.cj2word = {}  # Cangjie code -> list of glyphs sharing that code
        self.segmenter = None  # optional pkuseg word segmenter
        self._load_cangjie_mapping()
        self._init_segmenter()

    def _load_cangjie_mapping(self):
        """Download and parse the Cangjie5 table (one "glyph\\tcode" entry per item)."""
        try:
            cangjie_file = hf_hub_download(repo_id=MODEL_ID, filename="Cangjie5_TC.json")
            with open(cangjie_file, "r", encoding="utf-8") as fp:
                data = json.load(fp)
            for entry in data:
                word, code = entry.split("\t")[:2]
                self.word2cj[word] = code
                # Several glyphs can share one code; keep them in first-seen order
                # so the disambiguation index in _cangjie_encode is stable.
                self.cj2word.setdefault(code, []).append(word)
        except Exception as e:
            # Best-effort: without the table, __call__ leaves glyphs unchanged.
            print(f"Cangjie error: {e}")

    def _init_segmenter(self):
        """Try to build a pkuseg segmenter; fall back to no segmentation."""
        try:
            from pkuseg import pkuseg
            self.segmenter = pkuseg()
        except Exception:
            # BUG FIX: was a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit. Exception keeps the best-effort
            # fallback while letting those propagate.
            self.segmenter = None

    def _cangjie_encode(self, glyph: str):
        """Return the Cangjie code for *glyph*, or None if unknown.

        When several glyphs share one code, a positional index (> 0) is
        appended to disambiguate, e.g. "onf1" for the second glyph.
        """
        code = self.word2cj.get(glyph)
        if code is None:
            return None
        index = self.cj2word[code].index(glyph)
        return code + (str(index) if index > 0 else "")

    def __call__(self, text):
        """Encode every other-letter ("Lo") character in *text* as Cangjie tokens."""
        if self.segmenter:
            text = " ".join(self.segmenter.cut(text))
        output = []
        for t in text:
            if category(t) == "Lo":
                cangjie = self._cangjie_encode(t)
                if not cangjie:
                    output.append(t)
                    continue
                # One [cj_X] token per code character, terminated by [cj_.]
                output.append("".join(f"[cj_{c}]" for c in cangjie) + "[cj_.]")
            else:
                output.append(t)
        return "".join(output)
82
+
83
def hiragana_normalize(text):
    """Replace kanji segments of Japanese *text* with hiragana readings.

    Uses a cached pykakasi converter. Best-effort: returns *text* unchanged
    if pykakasi is unavailable or conversion fails for any reason.
    """
    try:
        import pykakasi
        import unicodedata  # hoisted: was imported mid-function, after the loop
        if not SESSIONS["kakasi"]:
            SESSIONS["kakasi"] = pykakasi.kakasi()  # build once, reuse across calls
        result = SESSIONS["kakasi"].convert(text)
        out = []
        for r in result:
            inp, hira = r['orig'], r['hira']
            # Swap in the hiragana reading only if the segment contains at
            # least one CJK unified ideograph (U+4E00..U+9FFF).
            # PERF: generator instead of the original throwaway list in any().
            if any(19968 <= ord(c) <= 40959 for c in inp):
                out.append(hira)
            else:
                out.append(inp)
        # NFKD-normalize to decompose compatibility forms (e.g. full-width chars).
        return unicodedata.normalize('NFKD', "".join(out))
    except Exception:
        # BUG FIX: was a bare `except:` — keep the best-effort fallback but
        # let KeyboardInterrupt/SystemExit propagate.
        return text
96
+
97
def korean_normalize(text):
    """Decompose precomposed Hangul syllables in *text* into conjoining Jamo.

    Characters outside the Hangul Syllables range (U+AC00..U+D7AF) pass
    through unchanged; the result is stripped of surrounding whitespace.
    """
    pieces = []
    for ch in text:
        code = ord(ch)
        if not (0xAC00 <= code <= 0xD7AF):
            pieces.append(ch)
            continue
        offset = code - 0xAC00
        # Standard syllable arithmetic: 588 = 21 vowels * 28 finals.
        lead = chr(0x1100 + offset // 588)
        vowel = chr(0x1161 + (offset % 588) // 28)
        tail_index = offset % 28
        tail = chr(0x11A7 + tail_index) if tail_index > 0 else ''
        pieces.append(lead + vowel + tail)
    return "".join(pieces).strip()
104
+
105
def prepare_language(txt, lang_id):
    """Normalize *txt* for the given language and prefix it with a language tag.

    'zh' Cangjie-encodes glyphs (converter cached in SESSIONS), 'ja' swaps
    kanji for hiragana, 'ko' decomposes Hangul; every other id passes text
    through untouched. A falsy *lang_id* returns the text without a tag.
    """
    if lang_id == 'zh':
        converter = SESSIONS["cangjie"]
        if not converter:
            converter = ChineseCangjieConverter()
            SESSIONS["cangjie"] = converter
        txt = converter(txt)
    elif lang_id == 'ja':
        txt = hiragana_normalize(txt)
    elif lang_id == 'ko':
        txt = korean_normalize(txt)
    if not lang_id:
        return txt
    return f"[{lang_id.lower()}]{txt}"
112
+
113
def load_chatterbox(device="cuda"):
    """Pre-load the four ONNX sessions and the tokenizer into SESSIONS.

    Idempotent: returns immediately if the speech encoder is already loaded.

    Args:
        device: "cuda" selects the CUDA execution provider, anything else
            falls back to CPU.
    """
    if SESSIONS["speech_encoder"]:
        return
    print("🚀 Loading Chatterbox ONNX...")
    opts = onnxruntime.SessionOptions()
    provs = ["CUDAExecutionProvider"] if device == "cuda" else ["CPUExecutionProvider"]

    for sess_name in ["speech_encoder", "embed_tokens", "conditional_decoder", "language_model"]:
        # BUG FIX: the original conditional produced the same filename in both
        # branches (`sess_name + ".onnx"` vs "language_model.onnx" when
        # sess_name == "language_model"), so it was dead code.
        fname = f"onnx/{sess_name}.onnx"
        path = hf_hub_download(repo_id=MODEL_ID, filename=fname)
        hf_hub_download(repo_id=MODEL_ID, filename=fname + "_data")  # Ensure sidecar data is present
        # FIX: `opts` was created but never passed to the session before.
        SESSIONS[sess_name] = onnxruntime.InferenceSession(path, sess_options=opts, providers=provs)

    SESSIONS["tokenizer"] = AutoTokenizer.from_pretrained(MODEL_ID)
127
+
128
def run_chatterbox_inference(text, lang_id, speaker_wav_path=None):
    """Ported logic from model card with session reuse

    Synthesizes speech for *text* (tagged/normalized per *lang_id*) in the
    voice of *speaker_wav_path* (defaults to the repo's default_voice.wav)
    and returns the rendered WAV file as raw bytes.
    """
    load_chatterbox() # Ensure sessions ready

    if not speaker_wav_path:
        speaker_wav_path = hf_hub_download(repo_id=MODEL_ID, filename="default_voice.wav")

    # Reference audio, resampled to the model rate, shaped (1, num_samples).
    audio_values, _ = librosa.load(speaker_wav_path, sr=S3GEN_SR)
    audio_values = audio_values[np.newaxis, :].astype(np.float32)

    text = prepare_language(text, lang_id)
    input_ids = SESSIONS["tokenizer"](text, return_tensors="np")["input_ids"].astype(np.int64)

    # Positions start at -1 for text tokens; speech-range ids get position 0.
    # NOTE(review): this -1/0 scheme comes from the model card — confirm
    # against the exported embed_tokens graph before changing.
    position_ids = np.where(input_ids >= START_SPEECH_TOKEN, 0, np.arange(input_ids.shape[1])[np.newaxis, :] - 1)
    ort_embed_tokens_inputs = {
        "input_ids": input_ids,
        "position_ids": position_ids.astype(np.int64),
        "exaggeration": np.array([0.5], dtype=np.float32)  # fixed expressiveness knob
    }

    repartition_penalty_processor = RepetitionPenaltyLogitsProcessor(penalty=1.2)
    generate_tokens = np.array([[START_SPEECH_TOKEN]])  # running list of generated speech tokens

    # Simple loop as per model card
    # Architecture constants below (layers/heads/head_dim) mirror the exported
    # language model; they size the initially empty KV cache.
    batch_size = 1
    num_hidden_layers = 30
    num_key_value_heads = 16
    head_dim = 64
    max_tokens = 256  # hard cap on generated speech tokens

    past_key_values = None
    attention_mask = None

    for i in range(max_tokens):
        # Embed the current input ids (full prompt on i == 0, single token after).
        inputs_embeds = SESSIONS["embed_tokens"].run(None, ort_embed_tokens_inputs)[0]
        if i == 0:
            # First step: encode the reference audio into conditioning embeddings
            # plus the prompt speech tokens / speaker vector / features needed
            # later by the decoder; prepend conditioning to the text embeddings.
            cond_emb, prompt_token, ref_x_vector, prompt_feat = SESSIONS["speech_encoder"].run(None, {"audio_values": audio_values})
            inputs_embeds = np.concatenate((cond_emb, inputs_embeds), axis=1)
            # Empty (length-0) KV cache entries for every layer's key and value.
            past_key_values = { f"past_key_values.{layer}.{kv}": np.zeros([batch_size, num_key_value_heads, 0, head_dim], dtype=np.float32)
                for layer in range(num_hidden_layers) for kv in ("key", "value") }
            attention_mask = np.ones((batch_size, inputs_embeds.shape[1]), dtype=np.int64)

        logits, *present_key_values = SESSIONS["language_model"].run(None, {**{"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}, **past_key_values})
        logits = logits[:, -1, :]  # only the last position predicts the next token
        # Greedy decoding with repetition penalty over already-generated tokens.
        next_token_logits = repartition_penalty_processor(generate_tokens, logits)
        next_token = np.argmax(next_token_logits, axis=-1, keepdims=True).astype(np.int64)
        generate_tokens = np.concatenate((generate_tokens, next_token), axis=-1)

        if (next_token.flatten() == STOP_SPEECH_TOKEN).all(): break

        # Feed only the new token next iteration; grow mask and swap in the
        # updated KV cache returned by the model.
        ort_embed_tokens_inputs["input_ids"] = next_token
        ort_embed_tokens_inputs["position_ids"] = np.full((1, 1), i + 1, dtype=np.int64)
        attention_mask = np.concatenate([attention_mask, np.ones((batch_size, 1), dtype=np.int64)], axis=1)
        for j, key in enumerate(past_key_values): past_key_values[key] = present_key_values[j]

    # Final Decode
    # Strip the leading START token and trailing STOP token, then prepend the
    # reference prompt tokens so the decoder sees voice context first.
    speech_tokens = generate_tokens[:, 1:-1]
    speech_tokens = np.concatenate([prompt_token, speech_tokens], axis=1)
    wav = SESSIONS["conditional_decoder"].run(None, {"speech_tokens": speech_tokens, "speaker_embeddings": ref_x_vector, "speaker_features": prompt_feat})[0]
    wav = np.squeeze(wav, axis=0)  # (1, n) -> (n,)

    # Return bytes directly
    # delete=False so the path survives the context manager; cleaned up below.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        path = f.name
        sf.write(path, wav, S3GEN_SR)
    with open(path, "rb") as f: audio = f.read()
    if os.path.exists(path): os.unlink(path)
    return audio