Spaces:
Paused
Paused
| import os | |
| import io | |
| import asyncio | |
| import time | |
| import numpy as np | |
| import psutil | |
| import soundfile as sf | |
| import subprocess | |
| from concurrent.futures import ThreadPoolExecutor | |
| from typing import Generator | |
| from contextlib import asynccontextmanager | |
| import logging | |
| import torch | |
| from fastapi import FastAPI, HTTPException, UploadFile, File, Form | |
| from fastapi.responses import Response, StreamingResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel, Field | |
| import re | |
| import hashlib | |
| from functools import lru_cache | |
| # Ensure the cloned neutts-air repository is in the path | |
| import sys | |
| sys.path.append(os.path.join(os.getcwd(), 'neutts-air')) | |
| from neuttsair.neutts import NeuTTSAir | |
# --- Logging ---
# One named logger is shared by every function in this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("NeuTTS-API")

# --- Configuration & Utility Functions ---
# Inference is pinned to the CPU (matches the Dockerfile / Hugging Face free tier note).
DEVICE = "cpu"

# Sample rate, in Hz, used for the converted reference audio and for all encoded output.
SAMPLE_RATE = 24000

# Number of synthesis threads allowed to run at once (1-2 is considered safe on CPU-only hosts).
MAX_WORKERS = 2

# Shared pool that all blocking model work is offloaded to.
tts_executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
class TTSRequestModel(BaseModel):
    """Validation schema for the non-file synthesis inputs.

    Attributes are documented inline. No other block in this section
    references the model.
    """

    # Text to synthesize; required, between 1 and 1000 characters.
    text: str = Field(
        min_length=1,
        max_length=1000,
    )
    # Output container/codec; one of "wav", "mp3" or "flac".
    output_format: str = Field(
        default="wav",
        pattern="^(wav|mp3|flac)$",
    )
async def convert_to_wav_in_memory(upload_file: UploadFile) -> io.BytesIO:
    """Convert an uploaded audio file to mono 16-bit PCM WAV at SAMPLE_RATE.

    The upload is piped through an ``ffmpeg`` subprocess (stdin -> stdout), so
    no intermediate file is written to disk.

    Args:
        upload_file: The uploaded audio; its whole content is read into memory.

    Returns:
        A ``BytesIO`` positioned at offset 0 holding the converted WAV data.

    Raises:
        HTTPException: 400 when the upload is empty or ffmpeg exits non-zero.
    """
    source_bytes = await upload_file.read()
    if not source_bytes:
        # Nothing to decode; fail fast instead of spawning ffmpeg on empty input.
        raise HTTPException(
            status_code=400,
            detail="Audio format conversion failed: uploaded file is empty.",
        )

    ffmpeg_command = [
        "ffmpeg",
        "-i", "pipe:0",  # Read from stdin
        "-f", "wav",
        "-ar", str(SAMPLE_RATE),
        "-ac", "1",
        "-c:a", "pcm_s16le",
        "pipe:1",  # Write to stdout
    ]

    proc = await asyncio.create_subprocess_exec(
        *ffmpeg_command,
        stdin=asyncio.subprocess.PIPE,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )

    # communicate() feeds stdin and drains stdout/stderr concurrently, which
    # avoids the pipe-buffer deadlock of writing and reading sequentially.
    wav_data, stderr_data = await proc.communicate(input=source_bytes)

    if proc.returncode != 0:
        # ffmpeg may echo arbitrary bytes (e.g. from file metadata); a strict
        # decode could raise and mask the real 400 with an unrelated 500.
        error_message = stderr_data.decode(errors="replace")
        logger.error(f"In-memory conversion failed: {error_message}")
        # Report the last non-blank stderr line to the user.
        error_lines = error_message.strip().splitlines()
        error_detail = error_lines[-1] if error_lines else "Unknown FFmpeg error."
        raise HTTPException(
            status_code=400,
            detail=f"Audio format conversion failed: {error_detail}",
        )

    logger.info("In-memory FFmpeg conversion successful.")
    return io.BytesIO(wav_data)
| # --- Model Wrapper and Logic --- | |
class NeuTTSWrapper:
    """Owns the NeuTTSAir model and the helpers built around it.

    The constructor loads the model eagerly. Reference encodings are kept in
    a small per-instance least-recently-used cache keyed by the hash of the
    reference audio, so repeated requests with the same voice skip encoding.
    """

    # Maximum number of reference encodings kept in memory per instance.
    REFERENCE_CACHE_SIZE = 32

    def __init__(self, device: str = "cpu"):
        """Create the wrapper and load the model on ``device``.

        Raises:
            Exception: whatever ``load_model`` raises when loading fails.
        """
        # Local import keeps this class self-contained; the module does not
        # import ``threading`` at file level.
        import threading

        self.tts_model = None
        self.device = device
        # Plain dict used as an LRU: insertion order == recency order.
        self._reference_cache = {}
        # Methods of this class are run on a thread pool, so cache
        # bookkeeping is serialized with a lock.
        self._reference_cache_lock = threading.Lock()
        self.load_model()

    def load_model(self):
        """Instantiate NeuTTSAir on ``self.device``; log and re-raise on failure."""
        try:
            logger.info(f"Loading NeuTTSAir model on device: {self.device}")
            # Backbone and codec are both placed on the configured device.
            self.tts_model = NeuTTSAir(backbone_device=self.device, codec_device=self.device)
            logger.info("✅ NeuTTSAir model loaded successfully.")
        except Exception as e:
            logger.error(f"❌ Model loading failed: {e}")
            raise

    def _convert_to_streamable_format(self, audio_data: np.ndarray, audio_format: str) -> bytes:
        """Encode a NumPy audio array at SAMPLE_RATE and return the bytes.

        Args:
            audio_data: Samples to encode.
            audio_format: Format name handed to ``soundfile.write``.

        Raises:
            Exception: re-raised after logging when ``soundfile`` cannot encode.
        """
        audio_buffer = io.BytesIO()
        try:
            sf.write(audio_buffer, audio_data, SAMPLE_RATE, format=audio_format)
        except Exception as e:
            logger.error(f"Failed to write audio data to format {audio_format}: {e}")
            raise
        return audio_buffer.getvalue()

    def _split_text_into_chunks(self, text: str) -> list[str]:
        """Split text into sentence/clause chunks at ``.``, ``,``, ``!`` and ``?``.

        Each chunk is a run of non-punctuation characters followed by any
        adjacent punctuation; chunks are stripped and empty ones dropped.
        Text made only of punctuation/whitespace yields an empty list.
        """
        # NOTE(review): the pattern also splits inside numbers such as "3.14"
        # or "1,000" — confirm whether that is acceptable for synthesis.
        chunks = re.findall(r'[^.,!?]+[.,!?]*', text)
        return [c.strip() for c in chunks if c.strip()]

    def _get_or_create_reference_encoding(self, audio_content_hash: str, audio_bytes: bytes) -> torch.Tensor:
        """Return the reference encoding for ``audio_bytes``, caching by hash.

        Args:
            audio_content_hash: Cache key; callers pass the SHA-256 hex digest
                of ``audio_bytes``.
            audio_bytes: WAV data of the reference voice.

        Returns:
            Whatever ``tts_model.encode_reference`` produced for this audio
            (annotated as a ``torch.Tensor``).
        """
        cache = self._reference_cache
        with self._reference_cache_lock:
            if audio_content_hash in cache:
                # Re-insert so this key becomes the most recently used.
                encoding = cache.pop(audio_content_hash)
                cache[audio_content_hash] = encoding
                return encoding

        logger.info(f"Cache miss for hash: {audio_content_hash[:10]}... Encoding new reference.")
        # Encoding runs outside the lock so different voices can be encoded
        # concurrently; at worst the same voice is encoded twice.
        # The model is handed a file-like object (BytesIO) rather than a path.
        encoding = self.tts_model.encode_reference(io.BytesIO(audio_bytes))

        with self._reference_cache_lock:
            cache[audio_content_hash] = encoding
            # Evict least recently used entries (front of the dict) over the cap.
            while len(cache) > self.REFERENCE_CACHE_SIZE:
                del cache[next(iter(cache))]
        return encoding

    def generate_speech_blocking(self, text: str, ref_audio_bytes: bytes, reference_text: str) -> np.ndarray:
        """Synthesize ``text`` in the reference voice; blocks until done.

        Args:
            text: Text to speak.
            ref_audio_bytes: WAV data of the reference voice.
            reference_text: Passed to the model together with the encoding.

        Returns:
            The audio returned by ``tts_model.infer``.
        """
        # The content hash is the cache key for the reference encoding.
        audio_hash = hashlib.sha256(ref_audio_bytes).hexdigest()
        ref_s = self._get_or_create_reference_encoding(audio_hash, ref_audio_bytes)
        # Inference only: no gradients needed.
        with torch.no_grad():
            audio = self.tts_model.infer(text, ref_s, reference_text)
        return audio
| # --- Asynchronous Offloading --- | |
async def run_blocking_task_async(func, *args, **kwargs):
    """Run ``func(*args, **kwargs)`` on ``tts_executor`` and await its result.

    Any exception raised by ``func`` propagates to the awaiting caller.
    """
    def _invoke():
        # run_in_executor forwards positional arguments only, so the keyword
        # arguments are bound through this closure.
        return func(*args, **kwargs)

    running_loop = asyncio.get_running_loop()
    pending = running_loop.run_in_executor(tts_executor, _invoke)
    return await pending
| # --- FastAPI Lifespan Manager (Kokoro Feature) --- | |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: load the model on startup, stop the pool on exit.

    On startup a ``NeuTTSWrapper`` is stored on ``app.state.tts_wrapper``.

    Raises:
        RuntimeError: when the model cannot be initialized; the original
            error is attached as ``__cause__``.
    """
    try:
        app.state.tts_wrapper = NeuTTSWrapper(device=DEVICE)
    except Exception as e:
        logger.error(f"Fatal startup error: {e}")
        # Terminate the application if the model can't load.
        tts_executor.shutdown(wait=False)
        raise RuntimeError("Model initialization failed.") from e

    try:
        yield  # Application serves requests
    finally:
        # Runs even when the server exits the context with an error.
        logger.info("Shutting down ThreadPoolExecutor.")
        tts_executor.shutdown(wait=False)
| # --- FastAPI Application Setup --- | |
# The application object; model loading and executor shutdown are handled by
# the lifespan callable passed here.
app = FastAPI(
    title="NeuTTS Air Instant Cloning API",
    version="2.0.0-PROD-ENHANCED",
    docs_url="/docs",
    lifespan=lifespan,
)

# CORS is fully open: every origin, method and header is accepted.
_cors_wildcard = "*"
app.add_middleware(
    CORSMiddleware,
    allow_origins=[_cors_wildcard],
    allow_methods=[_cors_wildcard],
    allow_headers=[_cors_wildcard],
)
| # --- New Endpoints and Enhancements --- | |
async def root():
    """Return a static banner message identifying the API."""
    # NOTE(review): no route decorator is visible on this handler in this
    # section — confirm it is registered on the application.
    banner = "NeuTTS Air API v2.0 - Ready for Instant Voice Cloning"
    return {"message": banner}
async def health_check():
    """Report service status plus host memory and root-disk usage.

    Returns a dict with a fixed ``"healthy"`` status, whether a model is
    loaded, the configured device and worker limit, and memory/disk totals
    (in GiB, two decimals) with their used percentages.
    """
    bytes_per_gb = 1024**3

    state = app.state
    model_ready = hasattr(state, 'tts_wrapper') and state.tts_wrapper.tts_model is not None

    memory = psutil.virtual_memory()
    root_disk = psutil.disk_usage('/')

    memory_report = {
        "total_gb": round(memory.total / bytes_per_gb, 2),
        "used_percent": memory.percent,
    }
    disk_report = {
        "total_gb": round(root_disk.total / bytes_per_gb, 2),
        "used_percent": root_disk.percent,
    }

    return {
        "status": "healthy",
        "model_loaded": model_ready,
        "device": DEVICE,
        "concurrency_limit": MAX_WORKERS,
        "memory_usage": memory_report,
        "disk_usage": disk_report,
    }
| # --- Core Synthesis Endpoints --- | |
async def text_to_speech(
        text: str = Form(...),
        reference_text: str = Form(...),
        output_format: str = Form("wav", pattern="^(wav|mp3|flac)$"),
        reference_audio: UploadFile = File(...)):
    """Synthesize ``text`` in the uploaded reference voice and return one file.

    The reference audio is converted to WAV in memory, the model runs on the
    shared thread pool, and the result is encoded to ``output_format``.

    Returns:
        A ``Response`` carrying the encoded audio, with ``X-Processing-Time``
        and ``X-Audio-Duration`` headers.

    Raises:
        HTTPException: 503 when no model is loaded, 400 when the reference
            audio cannot be converted, 500 for any other failure.
    """
    wrapper = getattr(app.state, 'tts_wrapper', None)
    if wrapper is None or wrapper.tts_model is None:
        raise HTTPException(status_code=503, detail="Service unavailable: Model not loaded")

    # Monotonic clock: the elapsed time cannot be skewed by wall-clock changes.
    start_time = time.perf_counter()
    try:
        # 1. Convert the uploaded file to WAV directly in memory.
        converted_wav_buffer = await convert_to_wav_in_memory(reference_audio)
        ref_audio_bytes = converted_wav_buffer.getvalue()

        # 2. Offload the blocking model call to the thread pool.
        audio_data = await run_blocking_task_async(
            wrapper.generate_speech_blocking,
            text,
            ref_audio_bytes,  # Pass bytes, not a path
            reference_text,
        )

        # 3. Encode to the requested output format, also off the event loop.
        audio_bytes = await run_blocking_task_async(
            wrapper._convert_to_streamable_format,
            audio_data,
            output_format,
        )
    except HTTPException:
        # Already a well-formed client/server error (e.g. the 400 raised by
        # the conversion step, which logs its own details); pass it through.
        raise
    except Exception as e:
        # logger.exception keeps the traceback; chaining keeps the cause.
        logger.exception("Synthesis error: %s", e)
        raise HTTPException(status_code=500, detail=f"Synthesis failed: {e}") from e

    processing_time = time.perf_counter() - start_time
    # assumes audio_data is a 1-D sample array at SAMPLE_RATE — TODO confirm
    audio_duration = len(audio_data) / SAMPLE_RATE

    return Response(
        content=audio_bytes,
        media_type=f"audio/{'mpeg' if output_format == 'mp3' else output_format}",
        headers={
            "Content-Disposition": f"attachment; filename=tts_output.{output_format}",
            "X-Processing-Time": f"{processing_time:.2f}s",
            "X-Audio-Duration": f"{audio_duration:.2f}s",
        },
    )
async def stream_text_to_speech_cloning(
        text: str = Form(..., min_length=1, max_length=5000),
        reference_text: str = Form(...),
        output_format: str = Form("mp3", pattern="^(wav|mp3|flac)$"),
        reference_audio: UploadFile = File(...)):
    """Stream synthesized speech chunk by chunk (one sentence/clause each).

    The reference audio is converted and encoded before the response starts,
    so failures there still produce a proper HTTP error. Chunks are then
    synthesized on the shared thread pool while earlier chunks are being
    sent; a bounded queue limits how far synthesis runs ahead.

    Returns:
        A ``StreamingResponse`` yielding each chunk's encoded bytes in order.

    Raises:
        HTTPException: 503 when no model is loaded, 400 when the reference
            audio cannot be converted or the text has no synthesizable chunk,
            500 when the reference encoding fails.
    """
    if not hasattr(app.state, 'tts_wrapper'):
        raise HTTPException(status_code=503, detail="Service unavailable: Model not loaded")
    wrapper = app.state.tts_wrapper
    loop = asyncio.get_running_loop()

    # Read and prepare the upload while the request is still being handled.
    # Once streaming has begun the status line is already sent, and the
    # framework may have closed the uploaded file by then.
    try:
        converted_wav_buffer = await convert_to_wav_in_memory(reference_audio)
        ref_audio_bytes = converted_wav_buffer.getvalue()
        # One-time voice encoding, keyed by the content hash.
        audio_hash = hashlib.sha256(ref_audio_bytes).hexdigest()
        ref_s = await loop.run_in_executor(
            tts_executor,
            wrapper._get_or_create_reference_encoding,
            audio_hash,
            ref_audio_bytes,
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Reference preparation failed: %s", e)
        raise HTTPException(status_code=500, detail=f"Synthesis failed: {e}") from e

    sentences = wrapper._split_text_into_chunks(text)
    if not sentences:
        # e.g. text made only of punctuation: nothing would ever be streamed.
        raise HTTPException(status_code=400, detail="Text contains no synthesizable content.")

    def process_chunk(sentence_text):
        """Synthesize one chunk and encode it; runs on a worker thread."""
        with torch.no_grad():
            audio_chunk = wrapper.tts_model.infer(sentence_text, ref_s, reference_text)
        # NOTE(review): every chunk is encoded on its own, so "wav"/"flac"
        # streams presumably contain one header per chunk — confirm clients
        # can play the concatenation.
        return wrapper._convert_to_streamable_format(audio_chunk, output_format)

    async def stream_generator():
        # Bounded queue: the producer blocks once this many chunks are
        # scheduled but not yet consumed.
        q = asyncio.Queue(maxsize=MAX_WORKERS + 1)

        async def producer():
            try:
                # Work starts as soon as it is scheduled, so later chunks are
                # synthesized while earlier ones are being sent.
                for sentence in sentences:
                    await q.put(loop.run_in_executor(tts_executor, process_chunk, sentence))
            except Exception as e:
                logger.error(f"Error in producer task: {e}")
                await q.put(e)
            # End-of-stream marker. Deliberately not in a ``finally``: on
            # cancellation nobody reads the queue and the put could block.
            await q.put(None)

        producer_task = asyncio.create_task(producer())
        try:
            while True:
                item = await q.get()
                if item is None:
                    break
                if isinstance(item, Exception):
                    raise item
                yield await item
        finally:
            # Reached on normal completion, on a failed chunk, and when the
            # client disconnects: stop scheduling and drop queued work.
            producer_task.cancel()
            while not q.empty():
                leftover = q.get_nowait()
                if asyncio.isfuture(leftover):
                    leftover.cancel()

    return StreamingResponse(
        stream_generator(),
        media_type=f"audio/{'mpeg' if output_format == 'mp3' else output_format}"
    )