ONNX Turbo Multilingual

#3
by rlimonta - opened

Is there any expectation regarding a multilingual model? This ONNX Turbo version is excellent!

Yeah, Bahasa Indonesia would be great; I've been waiting for years.

Anyway, here's a simple FastAPI project to test/host Chatterbox Turbo ONNX, a down payment on the Bahasa Indonesia language request 🫡

# main.py

import onnxruntime
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download
import numpy as np
from tqdm import trange
import librosa
import soundfile as sf
import time
import os
import sounddevice as sd
from dotenv import load_dotenv

load_dotenv()

MODEL_ID = os.getenv("MODEL_ID", "ResembleAI/chatterbox-turbo-ONNX")
SAMPLE_RATE = 24000
START_SPEECH_TOKEN = 6561
STOP_SPEECH_TOKEN = 6562
SILENCE_TOKEN = 4299
NUM_KV_HEADS = 16
HEAD_DIM = 64

class RepetitionPenaltyLogitsProcessor:
    def __init__(self, penalty: float):
        if not isinstance(penalty, float) or not (penalty > 0):
            raise ValueError(f"`penalty` must be a strictly positive float, but is {penalty}")
        self.penalty = penalty

    def __call__(self, input_ids: np.ndarray, scores: np.ndarray) -> np.ndarray:
        score = np.take_along_axis(scores, input_ids, axis=1)
        score = np.where(score < 0, score * self.penalty, score / self.penalty)
        scores_processed = scores.copy()
        np.put_along_axis(scores_processed, input_ids, score, axis=1)
        return scores_processed
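
# Example effect: with penalty=1.2, the logit of every token already generated
# is divided by 1.2 if positive (or multiplied by 1.2 if negative), so repeats
# become progressively less likely under the greedy decoding used below.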

def download_model(name: str, dtype: str = "fp32") -> str:
    filename = f"{name}{'' if dtype == 'fp32' else '_quantized' if dtype == 'q8' else f'_{dtype}'}.onnx"
    graph = hf_hub_download(MODEL_ID, subfolder="onnx", filename=filename)      # Download graph
    hf_hub_download(MODEL_ID, subfolder="onnx", filename=f"{filename}_data")    # Download weights
    return graph
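
# Resulting repo paths, e.g. for name="language_model":
#   fp32 -> onnx/language_model.onnx, q8 -> onnx/language_model_quantized.onnx,
#   fp16/q4/q4f16 -> onnx/language_model_<dtype>.onnx, each with a *_data sidecar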

# Download models
## dtype options: fp32, fp16, q8, q4, q4f16
conditional_decoder_path = download_model("conditional_decoder", dtype="fp32")
speech_encoder_path = download_model("speech_encoder", dtype="fp32")
embed_tokens_path = download_model("embed_tokens", dtype="fp32")
language_model_path = download_model("language_model", dtype="fp32")

# Create ONNX sessions
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] if "CUDAExecutionProvider" in onnxruntime.get_available_providers() else ["CPUExecutionProvider"]
print(f"Using providers: {providers}")

speech_encoder_session = onnxruntime.InferenceSession(speech_encoder_path, providers=providers)
embed_tokens_session = onnxruntime.InferenceSession(embed_tokens_path, providers=providers)
language_model_session = onnxruntime.InferenceSession(language_model_path, providers=providers)
cond_decoder_session = onnxruntime.InferenceSession(conditional_decoder_path, providers=providers)

def generate_speech(text, target_voice_path="ref_speaker.wav", max_new_tokens=1024, repetition_penalty=1.2, apply_watermark=False):
    start_time = time.time()
    
    # Prepare audio input
    audio_values, _ = librosa.load(target_voice_path, sr=SAMPLE_RATE)
    audio_values = audio_values[np.newaxis, :].astype(np.float32)

    # Prepare text input
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    input_ids = tokenizer(text, return_tensors="np")["input_ids"].astype(np.int64)

    # Generation loop
    repetition_penalty_processor = RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)
    generate_tokens = np.array([[START_SPEECH_TOKEN]], dtype=np.int64)
    
    # trange shows a progress bar when run interactively; swap in plain range()
    # if the tqdm output is too noisy for server logs.
    for i in trange(max_new_tokens, desc="Sampling", dynamic_ncols=True):
        inputs_embeds = embed_tokens_session.run(None, {"input_ids": input_ids})[0]

        if i == 0:
            ort_speech_encoder_input = {"audio_values": audio_values}
            cond_emb, prompt_token, speaker_embeddings, speaker_features = speech_encoder_session.run(None, ort_speech_encoder_input)
            inputs_embeds = np.concatenate((cond_emb, inputs_embeds), axis=1)

            # Initialize cache and LLM inputs
            batch_size, seq_len, _ = inputs_embeds.shape
            past_key_values = {
                i.name: np.zeros([batch_size, NUM_KV_HEADS, 0, HEAD_DIM], dtype=np.float16 if i.type == 'tensor(float16)' else np.float32)
                for i in language_model_session.get_inputs()
                if "past_key_values" in i.name
            }
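            # The cache starts with sequence length 0; every step the model
            # returns present_key_values grown by the tokens just processed,
            # and feeding them back lets later steps run only the new token.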
            attention_mask = np.ones((batch_size, seq_len), dtype=np.int64)
            position_ids = np.arange(seq_len, dtype=np.int64).reshape(1, -1).repeat(batch_size, axis=0)

        logits, *present_key_values = language_model_session.run(None, dict(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            **past_key_values,
        ))

        logits = logits[:, -1, :]
        next_token_logits = repetition_penalty_processor(generate_tokens, logits)

        input_ids = np.argmax(next_token_logits, axis=-1, keepdims=True).astype(np.int64)
        generate_tokens = np.concatenate((generate_tokens, input_ids), axis=-1)
        if (input_ids.flatten() == STOP_SPEECH_TOKEN).all():
            break

        # Update values for next generation loop
        attention_mask = np.concatenate([attention_mask, np.ones((batch_size, 1), dtype=np.int64)], axis=1)
        position_ids = position_ids[:, -1:] + 1
        for j, key in enumerate(past_key_values):
            past_key_values[key] = present_key_values[j]

    # Decode audio
    speech_tokens = generate_tokens[:, 1:-1]
    silence_tokens = np.full((speech_tokens.shape[0], 3), SILENCE_TOKEN, dtype=np.int64) # Add silence at the end
    speech_tokens = np.concatenate([prompt_token, speech_tokens, silence_tokens], axis=1)

    wav = cond_decoder_session.run(None, dict(
        speech_tokens=speech_tokens,
        speaker_embeddings=speaker_embeddings,
        speaker_features=speaker_features,
    ))[0].squeeze(axis=0)

    # Optional: Apply watermark
    if apply_watermark:
        import perth
        watermarker = perth.PerthImplicitWatermarker()
        wav = watermarker.apply_watermark(wav, sample_rate=SAMPLE_RATE)
    
    end_time = time.time()
    print(f"Total encoded/decoding time: {end_time - start_time:.2f}s")
    
    return wav

def play_audio(wav, sr=SAMPLE_RATE):
    try:
        sd.play(wav, sr)
        sd.wait()
    except Exception as e:
        print(f"Error playing audio: {e}")

if __name__ == "__main__":
    # Generation parameters
    text = """
    Oh, wow, really? [laugh] That is absolute madness! I can't believe you actually managed to pull that off. Okay, okay, hold on. [chuckle] Let me get this straight. You walked in there, looked him right in the eye, and said... that? Man, you are braver than I am. I think I would have just frozen on the spot. Um, anyway, what happened after you left the building? Did they follow you?
    """
    target_voice_path = "ref_speaker.wav" # <-- change this to the voice sample you want to clone, use the same ref_speaker.wav as fallback
    output_file_name = "output.wav"
    max_new_tokens = 1024
    repetition_penalty = 1.2
    apply_watermark = False

    wav = generate_speech(text, target_voice_path, max_new_tokens, repetition_penalty, apply_watermark)
    sf.write(output_file_name, wav, SAMPLE_RATE)
    
    # Optional: uncomment to play automatically in script mode
    # play_audio(wav)

# api.py

import os
import io
import uvicorn
from fastapi import FastAPI, Body
from fastapi.responses import Response
import soundfile as sf
from main import generate_speech, play_audio, SAMPLE_RATE
from pydub import AudioSegment
from dotenv import load_dotenv

load_dotenv()

app = FastAPI()

PORT = int(os.getenv("PORT", 8000))
HOST = os.getenv("HOST", "0.0.0.0")


@app.post("/generate")
async def generate(
    text: str = Body(..., embed=True),
    voice: str = Body("ref_speaker.wav", embed=True),
    format: str = Body("wav", embed=True),
    autoplay: bool = Body(False, embed=True)
):
    # Resolve the voice path; fall back to the bundled ref_speaker.wav when the
    # requested file does not exist (librosa.load would otherwise raise).
    target_voice_path = voice if os.path.exists(voice) else "ref_speaker.wav"

    wav_data = generate_speech(text, target_voice_path=target_voice_path)

    # Autoplay if requested
    if autoplay:
        play_audio(wav_data)

    # Convert to requested format
    buffer = io.BytesIO()
    if format.lower() == "wav":
        sf.write(buffer, wav_data, SAMPLE_RATE, format="WAV")
        media_type = "audio/wav"
    elif format.lower() == "ogg":
        sf.write(buffer, wav_data, SAMPLE_RATE, format="OGG")
        media_type = "audio/ogg"
    elif format.lower() == "mp3":
        # soundfile's MP3 support depends on the libsndfile build, so go through
        # pydub/ffmpeg instead: write a WAV buffer first, then transcode it.
        wav_buffer = io.BytesIO()
        sf.write(wav_buffer, wav_data, SAMPLE_RATE, format="WAV")
        wav_buffer.seek(0)
        audio = AudioSegment.from_wav(wav_buffer)
        audio.export(buffer, format="mp3")
        media_type = "audio/mpeg"
    else:
        # Default to wav
        sf.write(buffer, wav_data, SAMPLE_RATE, format="WAV")
        media_type = "audio/wav"

    buffer.seek(0)
    return Response(content=buffer.read(), media_type=media_type)

if __name__ == "__main__":
    uvicorn.run(app, host=HOST, port=PORT)

# .env

PORT=4455
HOST=0.0.0.0
MODEL_ID="ResembleAI/chatterbox-turbo-ONNX"

# requirements.txt
# uv pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu128 -U
onnxruntime-gpu  # if the CPU-only package is installed, run `pip uninstall onnxruntime` first
transformers
huggingface-hub
numpy
tqdm
librosa
soundfile
fastapi
uvicorn
python-dotenv
sounddevice
pydub

Chatterbox Turbo TTS

High-performance ONNX-based Text-to-Speech engine featuring CUDA acceleration and a FastAPI server.

Features

  • 🚀 Accelerated Inference: Optimized for CUDA to provide fast generation speeds.
  • 🌐 API Server: Built with FastAPI for easy integration and remote generation.
  • 🎵 Multi-Format Support: Generate audio in WAV, OGG, or MP3.
  • 🔊 Server-Side Autoplay: Optional capability to play generated audio immediately on the host machine.
  • ⚙️ Configurable: Easy configuration using .env files.

Installation

  1. Prerequisites:

    • Python 3.10 or higher.
    • CUDA-compatible GPU (recommended for best performance).
    • FFmpeg (required for MP3 encoding).
  2. Install Dependencies:

    pip install -r requirements.txt
    

Configuration

The application is configured via a .env file in the project root. Create one with the following contents and adjust as needed:

PORT=4455
HOST=0.0.0.0
MODEL_ID="ResembleAI/chatterbox-turbo-ONNX"

Usage

🖥️ API Server

Start the API server to handle HTTP requests:

uvicorn api:app --port 4455 --host 0.0.0.0

Alternatively, python api.py starts uvicorn with the HOST and PORT values from .env.

Endpoints

POST /generate

Generates audio from the provided text.

Request Body (JSON):

{
  "text": "Hello, this is a test.",
  "voice": "ref_speaker.wav",
  "format": "wav",
  "autoplay": false
}

Field      Type      Default           Description
--------   -------   ---------------   ---------------------------------------------------
text       string    (required)        The text to be converted to speech.
voice      string    ref_speaker.wav   Path to the reference audio file for voice cloning.
format     string    wav               Output audio format (wav, ogg, or mp3).
autoplay   boolean   false             If true, plays the audio on the server's speakers.

Example Request:

curl -X POST "http://localhost:4455/generate" \
     -H "Content-Type: application/json" \
     -d '{"text": "Hello world", "format": "mp3"}' \
     --output response.mp3
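
Or the same request from Python, assuming the server is running locally on port 4455 (the requests package is an extra dependency, not listed in requirements.txt):

# client.py (example only)
import requests

resp = requests.post(
    "http://localhost:4455/generate",
    json={"text": "Hello world", "voice": "ref_speaker.wav", "format": "wav"},
)
resp.raise_for_status()
with open("response.wav", "wb") as f:
    f.write(resp.content)  # raw audio bytes in the requested format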

⌨️ CLI Script

You can also run the generation logic directly as a script:

python main.py

Note: Edit the text variable in main.py to change the input for the CLI script.

Troubleshooting

  • Connection Refused: Ensure the server is running (uvicorn api:app ...) before making requests.
  • MP3 Errors: Verify that ffmpeg is installed and added to your system PATH.
  • CUDA Warnings: Warnings about Memcpy or ScatterND during startup are normal ONNX Runtime optimization logs and do not indicate failure. To confirm which providers are actually in use, see the check below.
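
A minimal provider check (run it in the same environment as the server; main.py also prints the selected providers when it is imported):

# check_providers.py
import onnxruntime
print(onnxruntime.get_available_providers())
# GPU builds should list CUDAExecutionProvider; if only CPUExecutionProvider
# appears, onnxruntime-gpu is missing or its CUDA libraries were not found.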
