import os
import io
import torch
import numpy as np
import base64
import tempfile
from fastapi import FastAPI, HTTPException, Depends, Security
from fastapi.security import APIKeyHeader
from fastapi.responses import Response
from pydantic import BaseModel
from typing import Optional, List
import soundfile as sf
from pydub import AudioSegment
from kokoro import KModel, KPipeline
from pocket_tts import TTSModel
import logging
import re
import asyncio
from concurrent.futures import ThreadPoolExecutor
import time


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Configuration
SECRET_KEY = os.getenv("API_SECRET_KEY", "your-default-secret-key")
CUDA_AVAILABLE = torch.cuda.is_available()

# Maximum input length per request; falls back to 5000 on a bad value
try:
    CHAR_LIMIT = int(os.getenv("CHAR_LIMIT", "5000"))
except ValueError:
    CHAR_LIMIT = 5000


app = FastAPI(title="Kokoro TTS API", version="1.0.0")


# API key authentication: clients must send the key in the X-API-Key header
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)


async def verify_api_key(api_key: str = Security(api_key_header)):
    if api_key != SECRET_KEY:
        raise HTTPException(
            status_code=403,
            detail="Invalid API Key"
        )
    return api_key
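
# Example authenticated request (assuming the default port 7860 used in the
# __main__ block at the bottom of this file):
#
#   curl -H "X-API-Key: your-default-secret-key" http://localhost:7860/voices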


logger.info(f"Initializing models... CUDA Available: {CUDA_AVAILABLE}")
models = {}
pipelines = {}


LANGUAGES = {
    'a': '🇺🇸 American English',
    'b': '🇬🇧 British English',
    'e': '🇪🇸 Spanish',
    'f': '🇫🇷 French',
    'h': '🇮🇳 Hindi',
    'i': '🇮🇹 Italian',
    'j': '🇯🇵 Japanese',
    'p': '🇧🇷 Brazilian Portuguese',
    'z': '🇨🇳 Mandarin Chinese'
}


VOICE_CHOICES = {
    'af_heart': '🇺🇸 🚺 Heart ❤️',
    'af_bella': '🇺🇸 🚺 Bella 🔥',
    'af_nicole': '🇺🇸 🚺 Nicole 🎧',
    'af_aoede': '🇺🇸 🚺 Aoede',
    'af_kore': '🇺🇸 🚺 Kore',
    'af_sarah': '🇺🇸 🚺 Sarah',
    'af_nova': '🇺🇸 🚺 Nova',
    'af_sky': '🇺🇸 🚺 Sky',
    'af_alloy': '🇺🇸 🚺 Alloy',
    'af_jessica': '🇺🇸 🚺 Jessica',
    'af_river': '🇺🇸 🚺 River',
    'am_michael': '🇺🇸 🚹 Michael',
    'am_fenrir': '🇺🇸 🚹 Fenrir',
    'am_puck': '🇺🇸 🚹 Puck',
    'am_echo': '🇺🇸 🚹 Echo',
    'am_eric': '🇺🇸 🚹 Eric',
    'am_liam': '🇺🇸 🚹 Liam',
    'am_onyx': '🇺🇸 🚹 Onyx',
    'am_santa': '🇺🇸 🚹 Santa',
    'am_adam': '🇺🇸 🚹 Adam',
    'bf_emma': '🇬🇧 🚺 Emma',
    'bf_isabella': '🇬🇧 🚺 Isabella',
    'bf_alice': '🇬🇧 🚺 Alice',
    'bf_lily': '🇬🇧 🚺 Lily',
    'bm_george': '🇬🇧 🚹 George',
    'bm_fable': '🇬🇧 🚹 Fable',
    'bm_lewis': '🇬🇧 🚹 Lewis',
    'bm_daniel': '🇬🇧 🚹 Daniel',
}
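
# Voice IDs encode language and gender in their prefix: the first letter is the
# language code used to pick a pipeline ('a' = American, 'b' = British), and the
# second is the speaker's gender ('f' = female, 'm' = male), e.g. 'bm_george'.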


class TTSRequest(BaseModel):
    """Request model for standard TTS generation"""
    text: str
    voice: str = "af_heart"
    language: Optional[str] = None
    use_gpu: Optional[bool] = None
    speed: float = 1.0


class VoiceCloningRequest(BaseModel):
    """Request model for voice cloning TTS generation"""
    text: str
    voice_sample_base64: str
    speed: float = 1.0


# pocket_tts voice cloning model, loaded at startup (None if unavailable)
voice_cloning_model = None


@app.on_event("startup")
async def startup_event():
    global models, pipelines, voice_cloning_model

    try:
        # CPU model is always loaded; add a GPU model when CUDA is available
        models = {
            False: KModel().to('cpu').eval()
        }
        if CUDA_AVAILABLE:
            models[True] = KModel().to('cuda').eval()

        # One KPipeline per language; model=False so they share the models above
        for lang_code in LANGUAGES.keys():
            try:
                pipelines[lang_code] = KPipeline(lang_code=lang_code, model=False)
                logger.info(f"Initialized pipeline for language: {lang_code} - {LANGUAGES[lang_code]}")
            except Exception as e:
                logger.warning(f"Could not initialize pipeline for {lang_code}: {e}")

        # Pin the pronunciation of 'kokoro' in the English lexicons
        if 'a' in pipelines:
            pipelines['a'].g2p.lexicon.golds['kokoro'] = 'kˈOkəɹO'
        if 'b' in pipelines:
            pipelines['b'].g2p.lexicon.golds['kokoro'] = 'kˈQkəɹQ'

        # Preload each voice into the pipeline matching its language prefix
        for voice_code in VOICE_CHOICES.keys():
            try:
                pipelines[voice_code[0]].load_voice(voice_code)
            except Exception as e:
                logger.warning(f"Could not preload voice {voice_code}: {e}")

        # Voice cloning is optional; the endpoint reports 503 if unavailable
        try:
            voice_cloning_model = TTSModel.load_model()
            logger.info("Voice cloning model initialized successfully")
        except Exception as e:
            logger.warning(f"Could not initialize voice cloning model: {e}")
            voice_cloning_model = None

        logger.info("Models and pipelines initialized successfully")
    except Exception as e:
        logger.error(f"Failed to initialize models: {e}")
        raise


def split_text_into_chunks(text: str, max_chars: int = 500) -> List[str]:
    """Split text into chunks at sentence boundaries"""
    sentences = re.split(r'(?<=[.!?])\s+', text)
    chunks = []
    current_chunk = ""

    for sentence in sentences:
        if len(current_chunk) + len(sentence) + 1 <= max_chars:
            current_chunk += (" " if current_chunk else "") + sentence
        else:
            if current_chunk:
                chunks.append(current_chunk)
            if len(sentence) > max_chars:
                # Sentence longer than the limit: fall back to word-level splits
                words = sentence.split()
                current_chunk = ""
                for word in words:
                    if len(current_chunk) + len(word) + 1 <= max_chars:
                        current_chunk += (" " if current_chunk else "") + word
                    else:
                        if current_chunk:
                            chunks.append(current_chunk)
                        current_chunk = word
            else:
                current_chunk = sentence

    if current_chunk:
        chunks.append(current_chunk)

    return chunks
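
# For illustration:
#   split_text_into_chunks("One two. Three four. Five six.", max_chars=25)
#   -> ["One two. Three four.", "Five six."]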


def preprocess_text_for_phonemizer(text: str) -> str:
    """Clean text to avoid phonemizer issues"""
    text = text.strip()

    # Collapse runs of whitespace
    text = re.sub(r'\s+', ' ', text)

    # Collapse repeated terminal punctuation (e.g. '...' -> '.')
    text = re.sub(r'\.{2,}', '.', text)
    text = re.sub(r'\?{2,}', '?', text)
    text = re.sub(r'!{2,}', '!', text)

    # Remove control characters
    text = re.sub(r'[\x00-\x1f\x7f-\x9f]', '', text)

    # Normalize en/em dashes to plain hyphens
    text = re.sub(r'[–—]', '-', text)

    # Drop URLs and email addresses, which phonemizers handle poorly
    text = re.sub(r'https?://\S+', '', text)
    text = re.sub(r'\S+@\S+\.\S+', '', text)

    # Trim leftover whitespace and dangling punctuation
    text = text.strip(' \t\n\r.,;:')

    return text
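
# For illustration:
#   preprocess_text_for_phonemizer("Wait...   what?!  see https://example.com")
#   -> "Wait. what?! see"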


def split_into_sentences(text: str) -> List[str]:
    """Split text into individual sentences for fallback processing"""
    sentences = re.split(r'(?<=[.!?])\s+', text)
    return [s.strip() for s in sentences if s.strip()]


def generate_audio_chunk(text: str, voice: str, speed: float, use_gpu: bool, lang_code: str):
    """Generate audio for a single text chunk.

    IMPORTANT: For non-English languages, we use the English phonemizer because
    the Spanish/French/etc. phonemizers have known issues with the 'espeak-ng'
    backend. The voice model still sounds correct - only phoneme conversion
    uses English rules.
    """
    text = preprocess_text_for_phonemizer(text)

    if not text or len(text) < 2:
        logger.warning("Text too short after preprocessing, skipping")
        return None

    # Only the English phonemizers are considered reliable; see the docstring
    STABLE_LANGUAGES = {'a', 'b'}

    if lang_code in STABLE_LANGUAGES:
        pipeline = pipelines.get(lang_code)
    else:
        # Fall back to the American English phonemizer for stability
        pipeline = pipelines.get('a')
        logger.debug(f"Using English phonemizer for lang={lang_code} (stability)")

    if not pipeline:
        pipeline = pipelines.get('b', list(pipelines.values())[0] if pipelines else None)
    if not pipeline:
        logger.error("No pipeline available")
        return None

    try:
        pack = pipeline.load_voice(voice)

        # Chunks are short and whitespace-collapsed, so the pipeline yields a
        # single segment; return the audio for the first one.
        for _, ps, _ in pipeline(text, voice, speed):
            ref_s = pack[len(ps) - 1]

            try:
                with torch.no_grad():
                    if use_gpu and True in models:
                        audio = models[True](ps, ref_s, speed)
                    else:
                        audio = models[False](ps, ref_s, speed)

                return audio.numpy()
            except Exception as e:
                if use_gpu and False in models:
                    logger.warning(f"GPU processing failed, falling back to CPU: {e}")
                    with torch.no_grad():
                        audio = models[False](ps, ref_s, speed)
                    return audio.numpy()
                else:
                    raise

        return None

    except Exception as e:
        logger.error(f"Failed to generate audio chunk: {e}")
        return None


async def generate_audio(text: str, voice: str = 'af_heart', speed: float = 1.0, use_gpu: Optional[bool] = None, lang_code: str = 'a'):
    """Generate audio from text using Kokoro TTS, with parallel chunking for long input"""
    text = text.strip()

    # Default to GPU when available; never force GPU without CUDA
    if use_gpu is None:
        use_gpu = CUDA_AVAILABLE
    else:
        use_gpu = use_gpu and CUDA_AVAILABLE

    if lang_code not in pipelines:
        raise ValueError(f"Language '{lang_code}' not supported or not initialized")

    chunks = split_text_into_chunks(text, max_chars=500)
    if not chunks:
        return None, 0
    logger.info(f"Split text into {len(chunks)} chunks for parallel processing")

    start_time = time.time()

    # Synthesize chunks concurrently in a small thread pool
    loop = asyncio.get_running_loop()
    max_parallel = min(len(chunks), 2)
    with ThreadPoolExecutor(max_workers=max_parallel) as executor:
        tasks = []
        for chunk in chunks:
            task = loop.run_in_executor(
                executor,
                generate_audio_chunk,
                chunk,
                voice,
                speed,
                use_gpu,
                lang_code
            )
            tasks.append(task)

        audio_results = await asyncio.gather(*tasks)

    process_time = time.time() - start_time
    logger.info(f"Parallel processing completed in {process_time:.2f}s")

    # Join chunks with 100 ms of silence so sentence boundaries stay natural
    sample_rate = 24000
    silence_gap = np.zeros(int(0.1 * sample_rate), dtype=np.float32)

    audio_chunks = []
    for i, audio_chunk in enumerate(audio_results):
        if audio_chunk is not None:
            audio_chunks.append(audio_chunk)
            if i < len(audio_results) - 1:
                audio_chunks.append(silence_gap)

    if not audio_chunks:
        return None, 0

    if len(audio_chunks) == 1:
        return audio_chunks[0], process_time

    merged_audio = np.concatenate(audio_chunks)
    logger.info(f"Successfully merged {len(chunks)} chunks into final audio of {len(merged_audio)} samples ({process_time:.2f}s total)")

    return merged_audio, process_time


def numpy_to_mp3(audio_array: np.ndarray, sample_rate: int = 24000) -> bytes:
    """Convert a float numpy array in [-1.0, 1.0] to MP3 bytes"""
    # Clip and scale float audio to 16-bit PCM
    audio_int16 = (np.clip(audio_array, -1.0, 1.0) * 32767).astype(np.int16)

    # Write the PCM data to an in-memory WAV file
    wav_buffer = io.BytesIO()
    sf.write(wav_buffer, audio_int16, sample_rate, format='WAV', subtype='PCM_16')
    wav_buffer.seek(0)

    # Re-encode to MP3 via pydub (requires ffmpeg on the host)
    audio_segment = AudioSegment.from_wav(wav_buffer)
    mp3_buffer = io.BytesIO()
    audio_segment.export(mp3_buffer, format="mp3", bitrate="192k")
    mp3_buffer.seek(0)

    return mp3_buffer.read()
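
# For illustration (encodes one second of silence; needs ffmpeg installed):
#   mp3_bytes = numpy_to_mp3(np.zeros(24000, dtype=np.float32))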


@app.get("/")
async def root():
    return {"message": "Kokoro TTS API is running", "cuda_available": CUDA_AVAILABLE}


@app.get("/health")
async def health_check():
    return {"status": "healthy", "cuda_available": CUDA_AVAILABLE}


@app.post("/generate")
async def generate_tts(
    request: TTSRequest,
    api_key: str = Depends(verify_api_key)
):
    """Generate TTS audio from text"""
    try:
        if request.voice not in VOICE_CHOICES:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid voice. Available voices: {list(VOICE_CHOICES.keys())}"
            )

        # Derive the language from the explicit field or the voice prefix
        # ('a' = American, 'b' = British); generate_audio_chunk falls back to
        # the English phonemizer for anything outside its stable set.
        requested_lang = request.language if request.language else request.voice[0]
        lang_code = requested_lang if requested_lang in pipelines else 'a'

        if not request.text or len(request.text.strip()) == 0:
            raise HTTPException(
                status_code=400,
                detail="Text cannot be empty"
            )

        # Enforce the configured character limit
        if len(request.text) > CHAR_LIMIT:
            raise HTTPException(
                status_code=400,
                detail=f"Text exceeds the {CHAR_LIMIT} character limit"
            )

        logger.info(f"Generating audio for voice: {request.voice}, requested_lang: {requested_lang}, using_lang: {lang_code}, text length: {len(request.text)}")
        audio_array, generation_time = await generate_audio(
            text=request.text,
            voice=request.voice,
            speed=request.speed,
            use_gpu=request.use_gpu,
            lang_code=lang_code
        )

        if audio_array is None:
            raise HTTPException(
                status_code=500,
                detail="Failed to generate audio"
            )

        # Duration in seconds at Kokoro's native 24 kHz sample rate
        sample_rate = 24000
        audio_duration = len(audio_array) / sample_rate

        mp3_bytes = numpy_to_mp3(audio_array, sample_rate)

        # Return the MP3 with timing metadata in response headers
        return Response(
            content=mp3_bytes,
            media_type="audio/mpeg",
            headers={
                "Content-Disposition": "attachment; filename=tts_output.mp3",
                "X-Audio-Duration": str(audio_duration),
                "X-Generation-Time": str(generation_time),
                "X-Sample-Rate": str(sample_rate)
            }
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error generating TTS: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Internal server error: {str(e)}"
        )
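
# Example request (assuming the default port 7860 from the __main__ block):
#
#   curl -X POST http://localhost:7860/generate \
#     -H "X-API-Key: your-default-secret-key" \
#     -H "Content-Type: application/json" \
#     -d '{"text": "Hello world", "voice": "af_heart", "speed": 1.0}' \
#     -o tts_output.mp3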


@app.get("/voices")
async def get_voices(api_key: str = Depends(verify_api_key)):
    """Get available voices"""
    return {"voices": VOICE_CHOICES}

def generate_cloned_audio_chunk(text: str, voice_state, speed: float = 1.0):
    """Generate audio for a single text chunk using voice cloning.

    Args:
        text: Text to convert to speech
        voice_state: Voice state from the pocket_tts model
        speed: Speech speed multiplier (accepted for API symmetry; not
            currently applied by the pocket_tts generation call below)

    Returns:
        numpy array of audio data, or None on failure
    """
    global voice_cloning_model

    if voice_cloning_model is None:
        logger.error("Voice cloning model not initialized")
        return None

    text = preprocess_text_for_phonemizer(text)

    if not text or len(text) < 2:
        logger.warning("Text too short after preprocessing, skipping")
        return None

    try:
        audio = voice_cloning_model.generate_audio(voice_state, text)
        return audio.numpy()
    except Exception as e:
        logger.error(f"Failed to generate cloned audio chunk: {e}")
        return None


async def generate_cloned_audio(text: str, voice_sample_base64: str, speed: float = 1.0):
    """Generate audio from text using voice cloning with parallel chunking.

    Args:
        text: Text to convert to speech
        voice_sample_base64: Base64-encoded voice sample audio (WAV/MP3/any format)
        speed: Speech speed multiplier

    Returns:
        Tuple of (audio_array, generation_time, sample_rate)
    """
    global voice_cloning_model

    if voice_cloning_model is None:
        raise ValueError("Voice cloning model not initialized")

    text = text.strip()

    try:
        voice_sample_bytes = base64.b64decode(voice_sample_base64)
    except Exception as e:
        raise ValueError(f"Invalid base64 voice sample: {e}")

    # Sniff the container format from the file header and decode accordingly
    try:
        audio_buffer = io.BytesIO(voice_sample_bytes)

        audio_buffer.seek(0)
        header = audio_buffer.read(12)
        audio_buffer.seek(0)

        if header[:4] == b'RIFF':
            # WAV
            audio_segment = AudioSegment.from_wav(audio_buffer)
        elif header[:3] == b'ID3' or header[:2] in (b'\xff\xfb', b'\xff\xfa'):
            # MP3 (ID3 tag or MPEG frame sync)
            audio_segment = AudioSegment.from_mp3(audio_buffer)
        elif header[:4] == b'fLaC':
            # FLAC
            audio_segment = AudioSegment.from_file(audio_buffer, format="flac")
        elif header[4:8] == b'ftyp':
            # MP4/M4A
            audio_segment = AudioSegment.from_file(audio_buffer, format="m4a")
        elif header[:4] == b'OggS':
            # Ogg
            audio_segment = AudioSegment.from_ogg(audio_buffer)
        else:
            # Unknown header: let ffmpeg auto-detect the format
            audio_segment = AudioSegment.from_file(audio_buffer)

        # Cap the reference sample at 30 seconds
        max_duration_ms = 30 * 1000
        if len(audio_segment) > max_duration_ms:
            audio_segment = audio_segment[:max_duration_ms]
            logger.info("Trimmed voice sample to 30 seconds")

        # Downmix to mono and resample to 16 kHz for the cloning model
        if audio_segment.channels > 1:
            audio_segment = audio_segment.set_channels(1)
        audio_segment = audio_segment.set_frame_rate(16000)

        # Export the normalized sample as WAV bytes
        wav_buffer = io.BytesIO()
        audio_segment.export(wav_buffer, format="wav")
        wav_buffer.seek(0)
        wav_bytes = wav_buffer.read()

    except Exception as e:
        raise ValueError(f"Failed to convert audio format: {e}")

    # pocket_tts expects a file path, so stage the sample in a temp file
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
        tmp_file.write(wav_bytes)
        voice_sample_path = tmp_file.name

    try:
        start_time = time.time()

        # Compute the speaker state once, then reuse it for every chunk
        voice_state = voice_cloning_model.get_state_for_audio_prompt(voice_sample_path)

        chunks = split_text_into_chunks(text, max_chars=500)
        if not chunks:
            return None, 0, voice_cloning_model.sample_rate
        logger.info(f"Split text into {len(chunks)} chunks for voice cloning")

        loop = asyncio.get_running_loop()
        max_parallel = min(len(chunks), 2)

        with ThreadPoolExecutor(max_workers=max_parallel) as executor:
            tasks = []
            for chunk in chunks:
                task = loop.run_in_executor(
                    executor,
                    generate_cloned_audio_chunk,
                    chunk,
                    voice_state,
                    speed
                )
                tasks.append(task)

            audio_results = await asyncio.gather(*tasks)

        process_time = time.time() - start_time
        logger.info(f"Voice cloning processing completed in {process_time:.2f}s")

        # Join chunks with 100 ms of silence at the model's native sample rate
        sample_rate = voice_cloning_model.sample_rate
        silence_gap = np.zeros(int(0.1 * sample_rate), dtype=np.float32)

        audio_chunks = []
        for i, audio_chunk in enumerate(audio_results):
            if audio_chunk is not None:
                audio_chunks.append(audio_chunk)
                if i < len(audio_results) - 1:
                    audio_chunks.append(silence_gap)

        if not audio_chunks:
            return None, 0, sample_rate

        if len(audio_chunks) == 1:
            return audio_chunks[0], process_time, sample_rate

        merged_audio = np.concatenate(audio_chunks)
        logger.info(f"Successfully merged {len(chunks)} chunks into final cloned audio of {len(merged_audio)} samples ({process_time:.2f}s total)")

        return merged_audio, process_time, sample_rate

    finally:
        # Always remove the temporary voice sample file
        try:
            os.unlink(voice_sample_path)
        except OSError:
            pass


@app.post("/generate-cloned")
async def generate_cloned_tts(
    request: VoiceCloningRequest,
    api_key: str = Depends(verify_api_key)
):
    """Generate TTS audio from text using a voice cloned from the provided sample"""
    try:
        # The cloning model is optional at startup, so check it per request
        if voice_cloning_model is None:
            raise HTTPException(
                status_code=503,
                detail="Voice cloning model not available"
            )

        if not request.text or len(request.text.strip()) == 0:
            raise HTTPException(
                status_code=400,
                detail="Text cannot be empty"
            )

        if not request.voice_sample_base64:
            raise HTTPException(
                status_code=400,
                detail="Voice sample is required"
            )

        logger.info(f"Generating cloned audio for text length: {len(request.text)}")
        audio_array, generation_time, sample_rate = await generate_cloned_audio(
            text=request.text,
            voice_sample_base64=request.voice_sample_base64,
            speed=request.speed
        )

        if audio_array is None:
            raise HTTPException(
                status_code=500,
                detail="Failed to generate cloned audio"
            )

        audio_duration = len(audio_array) / sample_rate

        mp3_bytes = numpy_to_mp3(audio_array, sample_rate)

        # Return the MP3 with timing metadata in response headers
        return Response(
            content=mp3_bytes,
            media_type="audio/mpeg",
            headers={
                "Content-Disposition": "attachment; filename=cloned_tts_output.mp3",
                "X-Audio-Duration": str(audio_duration),
                "X-Generation-Time": str(generation_time),
                "X-Sample-Rate": str(sample_rate)
            }
        )

    except HTTPException:
        raise
    except ValueError as e:
        # Invalid input (bad base64, undecodable audio) surfaces as 400
        raise HTTPException(
            status_code=400,
            detail=str(e)
        )
    except Exception as e:
        logger.error(f"Error generating cloned TTS: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Internal server error: {str(e)}"
        )
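
# Example client call (illustrative; 'sample.wav' is a placeholder local file):
#
#   import base64, requests
#   sample_b64 = base64.b64encode(open("sample.wav", "rb").read()).decode()
#   resp = requests.post(
#       "http://localhost:7860/generate-cloned",
#       headers={"X-API-Key": "your-default-secret-key"},
#       json={"text": "Hello in my voice", "voice_sample_base64": sample_b64},
#   )
#   open("cloned_tts_output.mp3", "wb").write(resp.content)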


@app.get("/languages")
async def get_languages(api_key: str = Depends(verify_api_key)):
    """Get available languages"""
    return {"languages": LANGUAGES}


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)