tts_gallery / src /models /tts /kokoro_model.py
Michael Hu
fix: update Kokoro TTS model to use new pipeline interface and 24 kHz sample rate
3ab544b
import tempfile
import os
from kokoro import KPipeline
from ..base import TTSModel
class KokoroTTSModel(TTSModel):
    """Kokoro TTS model implementation backed by ``kokoro.KPipeline``.

    The pipeline is built lazily for the currently selected language code and
    rebuilt whenever a different language code is requested.
    """

    # Sample rate used when writing the generated WAV file.
    SAMPLE_RATE = 24000

    # Display name -> Kokoro language code for the languages this wrapper exposes.
    _LANGUAGE_CODES = {
        "American English": "a",
        "British English": "b",
    }

    # Voice used when the caller supplies none (KPipeline cannot synthesize
    # without one). NOTE(review): confirm these voice ids ship with the
    # installed Kokoro release.
    _DEFAULT_VOICES = {
        "a": "af_heart",
        "b": "bf_emma",
    }

    def __init__(self):
        self._model = None
        self._initialized = False
        self._lang_code = 'a'  # Default to American English

    @property
    def name(self):
        return "hexgrad/kokoro"

    @property
    def description(self):
        return "Lightweight TTS model with 82M parameters, Apache-licensed for production and personal use"

    def initialize(self):
        """Initialize the Kokoro pipeline for the current language code.

        Returns:
            bool: True if the pipeline is ready, False if construction failed.
            Failures are printed rather than raised.
        """
        if self._initialized:
            return True
        try:
            self._model = KPipeline(lang_code=self._lang_code)
            self._initialized = True
            return True
        except Exception as e:
            print(f"Error initializing Kokoro model: {e}")
            return False

    def generate_speech(self, text, lang_code=None, voice_name=None, **kwargs):
        """
        Generate speech from text using Kokoro TTS

        Args:
            text (str): Text to convert to speech
            lang_code (str, optional): Language code ('a' for American English,
                'b' for British English). Switching codes rebuilds the pipeline.
            voice_name (str, optional): Kokoro voice id. Falls back to
                kwargs['voice'], then to a per-language default voice.
            **kwargs: Additional parameters for generation. 'voice', 'speed'
                and 'split_pattern' are forwarded to the Kokoro pipeline.

        Returns:
            str: Path to the generated WAV file (the caller owns and should
            delete it).

        Raises:
            RuntimeError: If the pipeline cannot be initialized or synthesis fails.
        """
        self._switch_language(lang_code)
        if not self.initialize():
            raise RuntimeError("Failed to initialize Kokoro model")

        try:
            import numpy as np

            call_kwargs = {'voice': self._resolve_voice(voice_name, kwargs)}
            for key in ('speed', 'split_pattern'):
                if key in kwargs:
                    call_kwargs[key] = kwargs[key]

            # KPipeline yields one (graphemes, phonemes, audio) result per text
            # segment. Keep every segment so that longer or multi-line input is
            # not truncated to its first chunk.
            chunks = [
                np.asarray(audio)
                for _, _, audio in self._model(text, **call_kwargs)
                if audio is not None
            ]
            if not chunks:
                raise RuntimeError("No audio was generated by Kokoro")
            return self._write_wav(np.concatenate(chunks))
        except Exception as e:
            print(f"Error generating speech with Kokoro: {e}")
            raise RuntimeError(f"Failed to generate speech: {str(e)}") from e

    def _switch_language(self, lang_code):
        """Select ``lang_code`` and rebuild the pipeline if the code changed.

        If the new pipeline cannot be built, the previous language (and its
        pipeline, if one existed) is restored, so one bad code does not break
        later calls that rely on the current language.
        """
        if not lang_code or lang_code == self._lang_code:
            return
        previous_code = self._lang_code
        self._lang_code = lang_code
        self._initialized = False
        if not self.initialize():
            # initialize() assigns self._model only on success, so any existing
            # pipeline still belongs to previous_code.
            self._lang_code = previous_code
            self._initialized = self._model is not None
            raise RuntimeError("Failed to initialize Kokoro model")

    def _resolve_voice(self, voice_name, kwargs):
        """Return the explicit voice, else kwargs['voice'], else the language default."""
        voice = voice_name or kwargs.get('voice')
        if voice:
            return voice
        return self._DEFAULT_VOICES.get(self._lang_code)

    def _write_wav(self, audio):
        """Write ``audio`` to a new temporary WAV file and return its path."""
        import soundfile as sf

        # Close the descriptor before writing by path: reopening a file that is
        # still open is not portable (it fails on Windows).
        fd, path = tempfile.mkstemp(suffix=".wav")
        os.close(fd)
        try:
            sf.write(path, audio, self.SAMPLE_RATE)
        except Exception:
            # Do not leak an empty/partial temp file on failure.
            os.remove(path)
            raise
        return path

    def get_supported_languages(self):
        """Return the display names of the supported languages."""
        return list(self._LANGUAGE_CODES)

    def get_language_codes(self):
        """Get mapping of language names to language codes"""
        return dict(self._LANGUAGE_CODES)