ONNX Turbo Multilingual
Are there any plans for a multilingual model? This ONNX Turbo version is excellent!
Yeah, Bahasa Indonesia would be great, been waiting for years.
Anyway, here's a simple FastAPI project to test/host Chatterbox Turbo ONNX, a down payment on the Bahasa Indonesia language request 🫡
# main.py
import onnxruntime
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download
import numpy as np
from tqdm import trange
import librosa
import soundfile as sf
import time
import os
import sounddevice as sd
from dotenv import load_dotenv
load_dotenv()
MODEL_ID = os.getenv("MODEL_ID", "ResembleAI/chatterbox-turbo-ONNX")
SAMPLE_RATE = 24000
START_SPEECH_TOKEN = 6561
STOP_SPEECH_TOKEN = 6562
SILENCE_TOKEN = 4299
NUM_KV_HEADS = 16
HEAD_DIM = 64
class RepetitionPenaltyLogitsProcessor:
    def __init__(self, penalty: float):
        if not isinstance(penalty, float) or not (penalty > 0):
            raise ValueError(f"`penalty` must be a strictly positive float, but is {penalty}")
        self.penalty = penalty

    def __call__(self, input_ids: np.ndarray, scores: np.ndarray) -> np.ndarray:
        # Penalize tokens that were already generated to discourage repetition
        score = np.take_along_axis(scores, input_ids, axis=1)
        score = np.where(score < 0, score * self.penalty, score / self.penalty)
        scores_processed = scores.copy()
        np.put_along_axis(scores_processed, input_ids, score, axis=1)
        return scores_processed
def download_model(name: str, dtype: str = "fp32") -> str:
    filename = f"{name}{'' if dtype == 'fp32' else '_quantized' if dtype == 'q8' else f'_{dtype}'}.onnx"
    graph = hf_hub_download(MODEL_ID, subfolder="onnx", filename=filename)  # Download graph
    hf_hub_download(MODEL_ID, subfolder="onnx", filename=f"{filename}_data")  # Download weights
    return graph
# Download models
## dtype options: fp32, fp16, q8, q4, q4f16
conditional_decoder_path = download_model("conditional_decoder", dtype="fp32")
speech_encoder_path = download_model("speech_encoder", dtype="fp32")
embed_tokens_path = download_model("embed_tokens", dtype="fp32")
language_model_path = download_model("language_model", dtype="fp32")
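# Note: the quantized variants (q8, q4, q4f16) are smaller downloads; the exact
# quality/speed tradeoff varies by hardware (an assumption, not benchmarked here). E.g.:
# language_model_path = download_model("language_model", dtype="q4f16")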
# Create ONNX sessions
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] if "CUDAExecutionProvider" in onnxruntime.get_available_providers() else ["CPUExecutionProvider"]
print(f"Using providers: {providers}")
speech_encoder_session = onnxruntime.InferenceSession(speech_encoder_path, providers=providers)
embed_tokens_session = onnxruntime.InferenceSession(embed_tokens_path, providers=providers)
language_model_session = onnxruntime.InferenceSession(language_model_path, providers=providers)
cond_decoder_session = onnxruntime.InferenceSession(conditional_decoder_path, providers=providers)
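# Tip: when adapting dtypes or debugging shapes, you can inspect a session's
# expected inputs (purely diagnostic, safe to remove):
# for inp in language_model_session.get_inputs():
#     print(inp.name, inp.shape, inp.type)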
def generate_speech(text, target_voice_path="ref_speaker.wav", max_new_tokens=1024, repetition_penalty=1.2, apply_watermark=False):
    start_time = time.time()

    # Prepare audio input
    audio_values, _ = librosa.load(target_voice_path, sr=SAMPLE_RATE)
    audio_values = audio_values[np.newaxis, :].astype(np.float32)

    # Prepare text input
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    input_ids = tokenizer(text, return_tensors="np")["input_ids"].astype(np.int64)

    # Generation loop
    repetition_penalty_processor = RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)
    generate_tokens = np.array([[START_SPEECH_TOKEN]], dtype=np.int64)

    # trange prints a progress bar; swap in plain range() if that is too noisy in server logs
    for i in trange(max_new_tokens, desc="Sampling", dynamic_ncols=True):
        inputs_embeds = embed_tokens_session.run(None, {"input_ids": input_ids})[0]
        if i == 0:
            # First step: encode the reference audio and prepend the conditioning embeddings
            ort_speech_encoder_input = {"audio_values": audio_values}
            cond_emb, prompt_token, speaker_embeddings, speaker_features = speech_encoder_session.run(None, ort_speech_encoder_input)
            inputs_embeds = np.concatenate((cond_emb, inputs_embeds), axis=1)

            # Initialize empty KV cache and LLM inputs
            batch_size, seq_len, _ = inputs_embeds.shape
            past_key_values = {
                inp.name: np.zeros([batch_size, NUM_KV_HEADS, 0, HEAD_DIM], dtype=np.float16 if inp.type == 'tensor(float16)' else np.float32)
                for inp in language_model_session.get_inputs()
                if "past_key_values" in inp.name
            }
            attention_mask = np.ones((batch_size, seq_len), dtype=np.int64)
            position_ids = np.arange(seq_len, dtype=np.int64).reshape(1, -1).repeat(batch_size, axis=0)

        logits, *present_key_values = language_model_session.run(None, dict(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            **past_key_values,
        ))
        logits = logits[:, -1, :]
        next_token_logits = repetition_penalty_processor(generate_tokens, logits)
        input_ids = np.argmax(next_token_logits, axis=-1, keepdims=True).astype(np.int64)
        generate_tokens = np.concatenate((generate_tokens, input_ids), axis=-1)
        if (input_ids.flatten() == STOP_SPEECH_TOKEN).all():
            break

        # Update values for the next generation step
        attention_mask = np.concatenate([attention_mask, np.ones((batch_size, 1), dtype=np.int64)], axis=1)
        position_ids = position_ids[:, -1:] + 1
        for j, key in enumerate(past_key_values):
            past_key_values[key] = present_key_values[j]

    # Decode audio
    speech_tokens = generate_tokens[:, 1:-1]  # strip the start/stop tokens
    silence_tokens = np.full((speech_tokens.shape[0], 3), SILENCE_TOKEN, dtype=np.int64)  # add silence at the end
    speech_tokens = np.concatenate([prompt_token, speech_tokens, silence_tokens], axis=1)
    wav = cond_decoder_session.run(None, dict(
        speech_tokens=speech_tokens,
        speaker_embeddings=speaker_embeddings,
        speaker_features=speaker_features,
    ))[0].squeeze(axis=0)

    # Optional: apply watermark
    if apply_watermark:
        import perth
        watermarker = perth.PerthImplicitWatermarker()
        wav = watermarker.apply_watermark(wav, sample_rate=SAMPLE_RATE)

    end_time = time.time()
    print(f"Total encoding/decoding time: {end_time - start_time:.2f}s")
    return wav
def play_audio(wav, sr=SAMPLE_RATE):
    try:
        sd.play(wav, sr)
        sd.wait()
    except Exception as e:
        print(f"Error playing audio: {e}")
if __name__ == "__main__":
    # Generation parameters
    text = """
Oh, wow, really? [laugh] That is absolute madness! I can't believe you actually managed to pull that off. Okay, okay, hold on. [chuckle] Let me get this straight. You walked in there, looked him right in the eye, and said... that? Man, you are braver than I am. I think I would have just frozen on the spot. Um, anyway, what happened after you left the building? Did they follow you?
"""
    target_voice_path = "ref_speaker.wav"  # <-- change this to the voice sample you want to clone
    output_file_name = "output.wav"
    max_new_tokens = 1024
    repetition_penalty = 1.2
    apply_watermark = False

    wav = generate_speech(text, target_voice_path, max_new_tokens, repetition_penalty, apply_watermark)
    sf.write(output_file_name, wav, SAMPLE_RATE)

    # Optional: uncomment to play automatically in script mode
    # play_audio(wav)
# api.py
import os
import io
import uvicorn
from fastapi import FastAPI, UploadFile, Body
from fastapi.responses import Response
import soundfile as sf
import numpy as np
from main import generate_speech, play_audio, SAMPLE_RATE
from pydub import AudioSegment
from dotenv import load_dotenv
load_dotenv()
app = FastAPI()
PORT = int(os.getenv("PORT", 8000))
HOST = os.getenv("HOST", "0.0.0.0")
@app.post("/generate")
async def generate(
    text: str = Body(..., embed=True),
    voice: str = Body("ref_speaker.wav", embed=True),
    format: str = Body("wav", embed=True),
    autoplay: bool = Body(False, embed=True)
):
    # Voice files are resolved relative to the working directory, or given as absolute paths.
    # If the file is missing, librosa.load inside generate_speech raises a clear error.
    target_voice_path = voice
    wav_data = generate_speech(text, target_voice_path=target_voice_path)

    # Autoplay on the server if requested
    if autoplay:
        play_audio(wav_data)

    # Convert to the requested format
    buffer = io.BytesIO()
    if format.lower() == "wav":
        sf.write(buffer, wav_data, SAMPLE_RATE, format="WAV")
        media_type = "audio/wav"
    elif format.lower() == "ogg":
        sf.write(buffer, wav_data, SAMPLE_RATE, format="OGG")
        media_type = "audio/ogg"
    elif format.lower() == "mp3":
        # soundfile can't reliably write MP3, so round-trip through pydub (requires ffmpeg)
        wav_buffer = io.BytesIO()
        sf.write(wav_buffer, wav_data, SAMPLE_RATE, format="WAV")
        wav_buffer.seek(0)
        audio = AudioSegment.from_wav(wav_buffer)
        audio.export(buffer, format="mp3")
        media_type = "audio/mpeg"
    else:
        # Unknown format: default to WAV
        sf.write(buffer, wav_data, SAMPLE_RATE, format="WAV")
        media_type = "audio/wav"

    buffer.seek(0)
    return Response(content=buffer.read(), media_type=media_type)
if __name__ == "__main__":
    uvicorn.run(app, host=HOST, port=PORT)
# .env
PORT=4455
HOST=0.0.0.0
MODEL_ID="ResembleAI/chatterbox-turbo-ONNX"
# requirements.txt
# uv pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu128 -U
onnxruntime-gpu  # uninstall any existing CPU-only onnxruntime package first
transformers
huggingface-hub
numpy
tqdm
librosa
soundfile
fastapi
uvicorn
python-dotenv
sounddevice
pydub
Chatterbox Turbo TTS
High-performance ONNX-based Text-to-Speech engine featuring CUDA acceleration and a FastAPI server.
Features
- 🚀 Accelerated Inference: Optimized for CUDA to provide fast generation speeds.
- 🌐 API Server: Built with FastAPI for easy integration and remote generation.
- 🎵 Multi-Format Support: Generate audio in WAV, OGG, or MP3.
- 🔊 Server-Side Autoplay: Optional capability to play generated audio immediately on the host machine.
- ⚙️ Configurable: Easy configuration using `.env` files.
Installation
Prerequisites:
- Python 3.10 or higher.
- CUDA-compatible GPU (recommended for best performance).
- FFmpeg (required for MP3 encoding).
Install Dependencies:
pip install -r requirements.txt
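After installing, you can confirm that ONNX Runtime sees your GPU (the script falls back to CPU automatically if it doesn't):

```python
import onnxruntime
# 'CUDAExecutionProvider' should be listed on a working GPU setup
print(onnxruntime.get_available_providers())
```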
Configuration
The application is configured via a `.env` file. If none is present, built-in defaults are used; you can customize it like so:
PORT=4455
HOST=0.0.0.0
MODEL_ID="ResembleAI/chatterbox-turbo-ONNX"
Usage
🖥️ API Server
Start the API server to handle HTTP requests:
uvicorn api:app --port 4455 --host 0.0.0.0
Endpoints
POST /generate
Generates audio from the provided text.
Request Body (JSON):
{
"text": "Hello, this is a test.",
"voice": "ref_speaker.wav",
"format": "wav",
"autoplay": false
}
| Field | Type | Default | Description |
|---|---|---|---|
| `text` | string | Required | The text to be converted to speech. |
| `voice` | string | `ref_speaker.wav` | Path to the reference audio file for voice cloning. |
| `format` | string | `wav` | Output audio format (`wav`, `ogg`, `mp3`). |
| `autoplay` | boolean | `false` | If `true`, plays the audio on the server's speakers. |
Example Request:
curl -X POST "http://localhost:4455/generate" \
-H "Content-Type: application/json" \
-d '{"text": "Hello world", "format": "mp3"}' \
--output response.mp3
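The same request from Python, as a minimal client sketch using the `requests` library (not part of the project; adjust the host/port to match your `.env`):

```python
import requests

resp = requests.post(
    "http://localhost:4455/generate",
    json={"text": "Hello world", "format": "mp3"},
    timeout=600,  # generation can take a while on CPU
)
resp.raise_for_status()
with open("response.mp3", "wb") as f:
    f.write(resp.content)
```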
⌨️ CLI Script
You can also run the generation logic directly as a script:
python main.py
Note: Edit the text variable in main.py to change the input for the CLI script.
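You can also import `generate_speech` directly into your own scripts; a minimal sketch, assuming a `ref_speaker.wav` sits next to `main.py`:

```python
from main import generate_speech, SAMPLE_RATE
import soundfile as sf

# Clone the reference voice and save the result
wav = generate_speech("Selamat pagi, apa kabar?", target_voice_path="ref_speaker.wav")
sf.write("pagi.wav", wav, SAMPLE_RATE)
```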
Troubleshooting
- Connection Refused: Ensure the server is running (`uvicorn api:app ...`) before making requests.
- MP3 Errors: Verify that `ffmpeg` is installed and added to your system PATH.
- CUDA Warnings: Warnings about `Memcpy` or `ScatterND` during startup are normal ONNX Runtime optimization logs and do not indicate failure.
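
If those startup logs are too noisy, ONNX Runtime's global log level can be raised before the sessions are created (a sketch; 3 logs errors and above only):

```python
import onnxruntime
onnxruntime.set_default_logger_severity(3)  # 0 = verbose ... 4 = fatal
```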