Spaces:
Paused
Paused
| import os | |
| import io | |
| import asyncio | |
| import time | |
| import shutil | |
| import numpy as np | |
| import psutil | |
| import soundfile as sf | |
| import subprocess | |
| import tempfile | |
| from concurrent.futures import ThreadPoolExecutor | |
| from typing import Optional, Generator | |
| from contextlib import asynccontextmanager | |
| import logging | |
| import aiofiles | |
| import torch | |
| from fastapi import FastAPI, HTTPException, UploadFile, File, Form, Query | |
| from fastapi.responses import Response, StreamingResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel, Field | |
| import re | |
| import hashlib | |
| from functools import lru_cache | |
| # Ensure the cloned neutts-air repository is in the path | |
| import sys | |
| sys.path.append(os.path.join(os.getcwd(), 'neutts-air')) | |
| from neuttsair.neutts import NeuTTSAir | |
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("NeuTTS-API")

# --- Configuration & Utility Functions ---

# Explicitly use CPU as per Dockerfile and Hugging Face free tier compatibility
DEVICE = "cpu"

# Configure Max Workers for concurrent synthesis threads (1-2 is safe for CPU-only)
MAX_WORKERS = 2
tts_executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)

# Sample rate in Hz, used both for FFmpeg reference conversion and for encoding output audio
SAMPLE_RATE = 24000

# Minimum file age, in seconds, before the cleanup routine deletes a file.
CLEANUP_THRESHOLD = 300  # 5 minutes in seconds

# Working directories; created at import time so later code can assume they exist
TEMP_AUDIO_DIR = "temp_audio"
GENERATED_AUDIO_DIR = "generated_audio"
os.makedirs(TEMP_AUDIO_DIR, exist_ok=True)
os.makedirs(GENERATED_AUDIO_DIR, exist_ok=True)
class TTSRequestModel(BaseModel):
    """Validated non-file parameters for the synthesis and streaming requests."""

    # Text to synthesize: required, between 1 and 1000 characters.
    text: str = Field(
        ...,
        min_length=1,
        max_length=1000,
    )
    # Speed value constrained to the inclusive range 0.5 to 2.0; defaults to 1.0.
    speed: float = Field(
        default=1.0,
        ge=0.5,
        le=2.0,
    )
    # Output container: one of "wav", "mp3" or "flac"; defaults to "wav".
    output_format: str = Field(
        default="wav",
        pattern="^(wav|mp3|flac)$",
    )
async def convert_to_wav_in_memory(upload_file: UploadFile) -> io.BytesIO:
    """
    Convert uploaded audio to mono 16-bit PCM WAV at SAMPLE_RATE using FFmpeg pipes.

    The upload is read fully into memory, piped to FFmpeg's stdin, and the
    converted WAV is collected from FFmpeg's stdout; this function writes no
    intermediate files itself.

    Args:
        upload_file: The uploaded reference audio in any format FFmpeg can decode.

    Returns:
        A BytesIO buffer holding the converted WAV data.

    Raises:
        HTTPException: 400 if the upload is empty or FFmpeg rejects the input;
            500 if the FFmpeg executable cannot be found.
    """
    source_bytes = await upload_file.read()
    if not source_bytes:
        # Fail fast with a clear message instead of a cryptic FFmpeg parse error.
        raise HTTPException(status_code=400, detail="Audio format conversion failed: uploaded file is empty.")

    ffmpeg_command = [
        "ffmpeg",
        "-i", "pipe:0",  # Read from stdin
        "-f", "wav",
        "-ar", str(SAMPLE_RATE),
        "-ac", "1",
        "-c:a", "pcm_s16le",
        "pipe:1",  # Write to stdout
    ]

    try:
        # Start the subprocess with pipes for stdin, stdout, and stderr
        proc = await asyncio.create_subprocess_exec(
            *ffmpeg_command,
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
    except FileNotFoundError as exc:
        # A missing binary is a server configuration problem, not a client error.
        logger.error("FFmpeg executable not found; cannot convert reference audio.")
        raise HTTPException(status_code=500, detail="Audio conversion unavailable: FFmpeg is not installed.") from exc

    # Feed the whole upload to FFmpeg and collect the WAV bytes and diagnostics.
    wav_data, stderr_data = await proc.communicate(input=source_bytes)

    if proc.returncode != 0:
        # FFmpeg may echo non-UTF-8 metadata; never let decoding mask the real error.
        error_message = stderr_data.decode(errors="replace")
        logger.error(f"In-memory conversion failed: {error_message}")
        # Provide the last non-blank line of the FFmpeg error to the user
        error_lines = error_message.strip().splitlines()
        error_detail = error_lines[-1] if error_lines else "Unknown FFmpeg error."
        raise HTTPException(status_code=400, detail=f"Audio format conversion failed: {error_detail}")

    logger.info("In-memory FFmpeg conversion successful.")
    # Return the raw WAV data in a BytesIO buffer, ready for the model
    return io.BytesIO(wav_data)
| # --- Model Wrapper and Logic --- | |
class NeuTTSWrapper:
    """Wraps a NeuTTSAir model: loading, reference-voice caching and synthesis."""

    # Maximum number of distinct reference voices whose encodings are kept cached.
    REFERENCE_CACHE_SIZE = 32

    # One chunk = a run of text plus its trailing punctuation. A '.' or ','
    # sitting between two digits ("3.14", "1,000") is treated as part of the
    # text so numbers are never cut in half.
    _CHUNK_PATTERN = re.compile(r'(?:[^.,!?]|(?<=\d)[.,](?=\d))+[.,!?]*')

    def __init__(self, device: str = "cpu"):
        """Create the wrapper and load the model immediately.

        Args:
            device: Device name passed to NeuTTSAir for both backbone and codec.

        Raises:
            Exception: Whatever NeuTTSAir raises if the model cannot be loaded.
        """
        self.tts_model = None
        self.device = device
        # Per-instance, bounded, thread-safe cache of reference encodings. It is
        # built here (not as a method decorator) so the cache lives and dies
        # with this wrapper. The key is (content hash, audio bytes).
        self._cached_reference_encoding = lru_cache(maxsize=self.REFERENCE_CACHE_SIZE)(
            self._encode_reference
        )
        self.load_model()

    def load_model(self):
        """Instantiate NeuTTSAir on ``self.device``; log and re-raise on failure."""
        try:
            logger.info(f"Loading NeuTTSAir model on device: {self.device}")
            # Ensure we respect the CPU configuration
            self.tts_model = NeuTTSAir(backbone_device=self.device, codec_device=self.device)
            logger.info("✅ NeuTTSAir model loaded successfully.")
        except Exception as e:
            # logger.exception keeps the traceback, which is needed to debug load failures.
            logger.exception(f"❌ Model loading failed: {e}")
            raise

    def _convert_to_streamable_format(self, audio_data: np.ndarray, audio_format: str) -> bytes:
        """Encode a NumPy audio array at SAMPLE_RATE into bytes of ``audio_format``.

        Raises:
            Exception: Re-raises whatever ``soundfile.write`` raises, after logging it.
        """
        audio_buffer = io.BytesIO()
        try:
            sf.write(audio_buffer, audio_data, SAMPLE_RATE, format=audio_format)
        except Exception as e:
            logger.error(f"Failed to write audio data to format {audio_format}: {e}")
            raise
        return audio_buffer.getvalue()

    def _split_text_into_chunks(self, text: str) -> list[str]:
        """
        Split text into sentence or clause chunks at '.', ',', '!' and '?'.

        Trailing punctuation stays attached to its chunk, chunks are stripped,
        and empty chunks are dropped. Decimal points and thousands separators
        between digits do not split a chunk.
        """
        pieces = (piece.strip() for piece in self._CHUNK_PATTERN.findall(text))
        return [piece for piece in pieces if piece]

    def _encode_reference(self, audio_content_hash: str, audio_bytes: bytes) -> torch.Tensor:
        """Encode a reference voice with the model; runs only on a cache miss."""
        logger.info(f"Cache miss for hash: {audio_content_hash[:10]}... Encoding new reference.")
        # The model's encode_reference can take a file-like object (BytesIO)
        return self.tts_model.encode_reference(io.BytesIO(audio_bytes))

    def _get_or_create_reference_encoding(self, audio_content_hash: str, audio_bytes: bytes) -> torch.Tensor:
        """
        Return the reference encoding for this audio, using an in-memory LRU cache.

        Repeated calls with the same audio return the cached object.
        NOTE(review): assumes the model does not modify the encoding in place
        during inference, since the same object is shared across requests — confirm.
        """
        return self._cached_reference_encoding(audio_content_hash, audio_bytes)

    def generate_speech_blocking(self, text: str, ref_audio_bytes: bytes, reference_text: str) -> np.ndarray:
        """Synthesize the full text in one blocking call, using the cached reference encoding.

        Args:
            text: Text to synthesize.
            ref_audio_bytes: Reference voice audio as WAV bytes.
            reference_text: Transcript passed to the model alongside the reference encoding.

        Returns:
            The audio returned by the model's ``infer`` call.
        """
        # 1. Hash the audio bytes to get a cache key
        audio_hash = hashlib.sha256(ref_audio_bytes).hexdigest()
        # 2. Get the encoding from the cache (or create it if new)
        ref_s = self._get_or_create_reference_encoding(audio_hash, ref_audio_bytes)
        # 3. Infer full text
        with torch.no_grad():
            audio = self.tts_model.infer(text, ref_s, reference_text)
        return audio

    def stream_speech_blocking(self, text: str, ref_audio_bytes: bytes, reference_text: str, speed: float, audio_format: str) -> Generator[bytes, None, None]:
        """Synthesize chunk by chunk, yielding each chunk encoded in ``audio_format``.

        Args:
            text: Text to synthesize; split with ``_split_text_into_chunks``.
            ref_audio_bytes: Reference voice audio as WAV bytes.
            reference_text: Transcript passed to the model alongside the reference encoding.
            speed: Accepted for interface compatibility; this method does not use it.
            audio_format: Output format name handed to ``_convert_to_streamable_format``.

        Yields:
            One encoded audio payload per text chunk.
        """
        logger.info(f"Starting streaming synthesis for text length: {len(text)}")
        # 1. Hash the audio bytes once
        audio_hash = hashlib.sha256(ref_audio_bytes).hexdigest()
        # 2. Get the reference encoding from cache, once for the whole stream
        ref_s = self._get_or_create_reference_encoding(audio_hash, ref_audio_bytes)
        # 3. Split text; the splitter already strips and drops empty chunks
        sentences = self._split_text_into_chunks(text)
        # 4. Stream chunks
        for chunk_number, sentence in enumerate(sentences, start=1):
            logger.debug(f"Generating streaming chunk {chunk_number}: '{sentence[:30]}...'")
            with torch.no_grad():
                audio_chunk = self.tts_model.infer(sentence, ref_s, reference_text)
            yield self._convert_to_streamable_format(audio_chunk, audio_format)
        logger.info("Streaming synthesis complete.")
| # --- Asynchronous Offloading --- | |
async def run_blocking_task_async(func, *args, **kwargs):
    """Run a blocking callable on the shared ThreadPoolExecutor and await its result.

    Args:
        func: The blocking callable to execute.
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        Whatever ``func`` returns; exceptions raised by ``func`` propagate to the caller.
    """
    # Inside a coroutine the running loop is the one to use.
    loop = asyncio.get_running_loop()
    # run_in_executor only forwards positional arguments, so a closure carries the kwargs.
    return await loop.run_in_executor(
        tts_executor,
        lambda: func(*args, **kwargs),
    )
async def save_upload_file_async(upload_file: UploadFile) -> str:
    """Save an uploaded file into TEMP_AUDIO_DIR without blocking the event loop.

    Args:
        upload_file: The upload to persist.

    Returns:
        Path of the saved file: ``TEMP_AUDIO_DIR/<timestamp>_<sanitised name>``.

    Raises:
        HTTPException: 500 if the file cannot be written.
    """
    # The filename is client-controlled: keep only its final path component so
    # values such as "../../x" cannot escape TEMP_AUDIO_DIR.
    client_name = (upload_file.filename or "upload").replace("\\", "/")
    safe_name = os.path.basename(client_name) or "upload"
    temp_filename = os.path.join(TEMP_AUDIO_DIR, f"{time.time()}_{safe_name}")
    try:
        # Copy in 1 MiB chunks so large uploads are never held fully in memory
        async with aiofiles.open(temp_filename, 'wb') as out_file:
            while content := await upload_file.read(1024 * 1024):
                await out_file.write(content)
    except Exception as e:
        logger.error(f"Error saving file: {e}")
        # Best effort: do not leave a truncated file behind.
        try:
            os.remove(temp_filename)
        except OSError:
            pass
        raise HTTPException(status_code=500, detail="Could not save reference audio file") from e
    return temp_filename
| # --- FastAPI Lifespan Manager (Kokoro Feature) --- | |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: load the model on startup, shut the executor down on exit.

    Decorated with ``asynccontextmanager`` so that it is the async context
    manager factory that the ``lifespan=`` argument of ``FastAPI`` expects.

    Raises:
        RuntimeError: If the model cannot be initialised; the original error is chained.
    """
    try:
        app.state.tts_wrapper = NeuTTSWrapper(device=DEVICE)
    except Exception as e:
        logger.error(f"Fatal startup error: {e}")
        # Terminate the application if the model can't load
        tts_executor.shutdown(wait=False)
        raise RuntimeError("Model initialization failed.") from e
    try:
        yield  # Application serves requests
    finally:
        # Shutdown runs even when the serving phase ends with an error.
        logger.info("Shutting down ThreadPoolExecutor.")
        tts_executor.shutdown(wait=False)
| # --- FastAPI Application Setup --- | |
# The application object; startup and shutdown work is delegated to `lifespan`.
app = FastAPI(
    title="NeuTTS Air Instant Cloning API",
    version="2.0.0-PROD-ENHANCED",
    docs_url="/docs",
    lifespan=lifespan,
)

# CORS is fully open: any origin, method and header is allowed.
_cors_options = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
| # --- New Endpoints and Enhancements --- | |
# NOTE(review): no route decorator is visible on this handler in this view —
# confirm it is registered on `app`.
async def root():
    """Return a static banner message identifying the API."""
    banner = "NeuTTS Air API v2.0 - Ready for Instant Voice Cloning"
    return dict(message=banner)
# NOTE(review): no route decorator is visible on this handler in this view —
# confirm it is registered on `app`.
async def health_check():
    """Report service status, model availability and host memory/disk usage."""
    memory_stats = psutil.virtual_memory()
    disk_stats = psutil.disk_usage('/')
    bytes_per_gib = 1024**3

    # The model counts as loaded only when the wrapper exists and holds a model.
    model_ready = hasattr(app.state, 'tts_wrapper') and app.state.tts_wrapper.tts_model is not None

    memory_report = {
        "total_gb": round(memory_stats.total / bytes_per_gib, 2),
        "used_percent": memory_stats.percent,
    }
    disk_report = {
        "total_gb": round(disk_stats.total / bytes_per_gib, 2),
        "used_percent": disk_stats.percent,
    }
    return {
        "status": "healthy",
        "model_loaded": model_ready,
        "device": DEVICE,
        "concurrency_limit": MAX_WORKERS,
        "memory_usage": memory_report,
        "disk_usage": disk_report,
    }
# NOTE(review): no route decorator is visible on this handler in this view —
# confirm it is registered on `app`.
async def cleanup_files():
    """Run the blocking file cleanup on the thread pool, then acknowledge it.

    The cleanup has already finished when the response is returned, because
    the offloaded call is awaited first.
    """
    await run_blocking_task_async(cleanup_files_blocking)
    acknowledgement = {"message": "Cleanup initiated successfully."}
    return acknowledgement
def cleanup_files_blocking():
    """Delete regular files older than CLEANUP_THRESHOLD from the audio directories.

    Scans GENERATED_AUDIO_DIR and TEMP_AUDIO_DIR (non-recursively). A file's
    age is measured with ``os.path.getctime``. Failures on individual files
    are logged and skipped; a missing directory is treated as already clean.

    Returns:
        The number of files removed.
    """
    now = time.time()
    deleted_count = 0
    for directory in (GENERATED_AUDIO_DIR, TEMP_AUDIO_DIR):
        try:
            filenames = os.listdir(directory)
        except FileNotFoundError:
            # One missing directory must not stop the other from being cleaned.
            logger.warning(f"Cleanup skipped missing directory: {directory}")
            continue
        for filename in filenames:
            filepath = os.path.join(directory, filename)
            if not os.path.isfile(filepath):
                continue
            try:
                # Original cleanup logic: delete if older than CLEANUP_THRESHOLD
                if now - os.path.getctime(filepath) > CLEANUP_THRESHOLD:
                    os.remove(filepath)
                    deleted_count += 1
            except OSError as e:
                # The file may vanish or be locked between the checks; keep going.
                logger.warning(f"Failed to delete {filepath}: {e}")
    logger.info(f"Cleanup completed: {deleted_count} files removed.")
    return deleted_count
| # --- Core Synthesis Endpoints --- | |
# NOTE(review): no route decorator is visible on this handler in this view —
# confirm it is registered on `app`.
async def text_to_speech(
    text: str = Form(...),
    reference_text: str = Form(...),
    speed: float = Form(1.0, ge=0.5, le=2.0),
    output_format: str = Form("wav", pattern="^(wav|mp3|flac)$"),
    reference_audio: UploadFile = File(...)):
    """
    Synthesize the whole text in one response, cloning the uploaded reference voice.

    The reference audio is converted to WAV in memory, synthesis and encoding
    run on the thread pool, and the encoded audio is returned as an attachment
    with ``X-Processing-Time`` and ``X-Audio-Duration`` headers.

    Args:
        text: Text to synthesize.
        reference_text: Transcript passed to the model with the reference voice.
        speed: Validated to 0.5-2.0; NOTE(review): not passed to the model in
            this handler — confirm whether it should be.
        output_format: One of "wav", "mp3" or "flac".
        reference_audio: Uploaded reference voice recording.

    Raises:
        HTTPException: 503 if the model is not loaded; 400 if the reference
            audio cannot be converted; 500 for any other synthesis failure.
    """
    if not hasattr(app.state, 'tts_wrapper'):
        raise HTTPException(status_code=503, detail="Service unavailable: Model not loaded")

    # perf_counter is monotonic, so the measured duration cannot go negative
    # if the system clock is adjusted mid-request.
    start_time = time.perf_counter()
    try:
        # 1. Convert the uploaded file to WAV directly in memory
        converted_wav_buffer = await convert_to_wav_in_memory(reference_audio)
        ref_audio_bytes = converted_wav_buffer.getvalue()

        # 2. Offload the blocking AI process to the thread pool
        audio_data = await run_blocking_task_async(
            app.state.tts_wrapper.generate_speech_blocking,
            text,
            ref_audio_bytes,  # Pass bytes, not a path
            reference_text,
        )

        # 3. Convert to requested output format
        audio_bytes = await run_blocking_task_async(
            app.state.tts_wrapper._convert_to_streamable_format,
            audio_data,
            output_format,
        )

        processing_time = time.perf_counter() - start_time
        # assumes audio_data is a 1-D sample array, so len() is the sample count — TODO confirm
        audio_duration = len(audio_data) / SAMPLE_RATE

        return Response(
            content=audio_bytes,
            media_type=f"audio/{'mpeg' if output_format == 'mp3' else output_format}",
            headers={
                "Content-Disposition": f"attachment; filename=tts_output.{output_format}",
                "X-Processing-Time": f"{processing_time:.2f}s",
                "X-Audio-Duration": f"{audio_duration:.2f}s",
            },
        )
    except HTTPException:
        # Deliberate HTTP errors (e.g. a 400 from conversion) pass through unchanged.
        raise
    except Exception as e:
        # logger.exception records the traceback, which logger.error discarded.
        logger.exception(f"Synthesis error: {e}")
        raise HTTPException(status_code=500, detail=f"Synthesis failed: {e}") from e
# NOTE(review): no route decorator is visible on this handler in this view —
# confirm it is registered on `app`.
async def stream_text_to_speech_cloning(
    text: str = Form(..., min_length=1, max_length=5000),
    reference_text: str = Form(...),
    speed: float = Form(1.0, ge=0.5, le=2.0),
    output_format: str = Form("mp3", pattern="^(wav|mp3|flac)$"),
    reference_audio: UploadFile = File(...)):
    """
    Stream synthesized speech chunk by chunk through a producer/consumer pipeline.

    A background producer task synthesizes one text chunk at a time on the
    thread pool and pushes the encoded bytes into a bounded queue. The
    response generator consumes that queue, so synthesis of the next chunk
    overlaps with sending the previous one.

    Args:
        text: Text to synthesize (1-5000 characters).
        reference_text: Transcript passed to the model with the reference voice.
        speed: Validated to 0.5-2.0; NOTE(review): not passed to the model in
            this handler — confirm whether it should be.
        output_format: One of "wav", "mp3" or "flac".
        reference_audio: Uploaded reference voice recording.

    Raises:
        HTTPException: 503 if the model is not loaded; 400 if the reference
            audio cannot be converted. Both are raised before streaming starts.
    """
    if not hasattr(app.state, 'tts_wrapper'):
        raise HTTPException(status_code=503, detail="Service unavailable: Model not loaded")
    tts_wrapper = app.state.tts_wrapper

    # Read and convert the upload before returning the response. The generator
    # below only runs after this handler has returned, when the uploaded file
    # may already have been closed by the framework. Converting here also lets
    # a bad reference produce a proper 400 instead of an aborted stream.
    converted_wav_buffer = await convert_to_wav_in_memory(reference_audio)
    ref_audio_bytes = converted_wav_buffer.getvalue()

    def process_chunk(sentence: str, ref_s) -> bytes:
        """Blocking work for one chunk: infer, then encode to ``output_format``."""
        with torch.no_grad():
            audio_chunk = tts_wrapper.tts_model.infer(sentence, ref_s, reference_text)
        # Each chunk is encoded as a complete, standalone payload.
        # NOTE(review): concatenated wav/flac payloads each carry their own
        # header — confirm that clients can play the combined stream.
        return tts_wrapper._convert_to_streamable_format(audio_chunk, output_format)

    async def stream_generator():
        loop = asyncio.get_running_loop()
        # A small bound keeps the producer at most two chunks ahead of the network.
        q = asyncio.Queue(maxsize=2)

        async def producer():
            try:
                # One-time setup cost: encode the reference voice (cached by content hash).
                audio_hash = hashlib.sha256(ref_audio_bytes).hexdigest()
                ref_s = await loop.run_in_executor(
                    tts_executor,
                    tts_wrapper._get_or_create_reference_encoding,
                    audio_hash,
                    ref_audio_bytes,
                )
                for sentence in tts_wrapper._split_text_into_chunks(text):
                    # Offload the blocking work to the thread pool
                    chunk_bytes = await loop.run_in_executor(
                        tts_executor, process_chunk, sentence, ref_s
                    )
                    await q.put(chunk_bytes)
            except Exception as e:
                logger.error(f"Error in producer task: {e}")
                # Hand the error to the consumer, which re-raises it.
                await q.put(e)
            else:
                # End-of-stream sentinel. It is deliberately not in a `finally`:
                # during cancellation an awaited put on a full queue would hang.
                await q.put(None)

        # Start the producer as a background task. It starts working immediately.
        producer_task = asyncio.create_task(producer())
        try:
            # The main loop acts as the CONSUMER.
            while True:
                result = await q.get()
                if result is None:
                    break
                if isinstance(result, Exception):
                    logger.error(f"Terminating stream due to producer error: {result}")
                    raise result
                # While the network sends this chunk, the producer is already
                # working on the next one in the background.
                yield result
        finally:
            # Runs on normal completion, on error and on client disconnect.
            # Without the cancel, a disconnected client would leave the producer
            # synthesizing and then blocked forever on the full queue.
            if not producer_task.done():
                producer_task.cancel()
            try:
                await producer_task
            except asyncio.CancelledError:
                pass

    return StreamingResponse(
        stream_generator(),
        media_type=f"audio/{'mpeg' if output_format == 'mp3' else output_format}",
    )
# NOTE(review): no route decorator is visible on this handler in this view —
# confirm it is registered on `app`.
async def get_audio(filename: str):
    """Serve a previously generated audio file from GENERATED_AUDIO_DIR as an attachment.

    Args:
        filename: Bare file name (no directory components) inside GENERATED_AUDIO_DIR.

    Raises:
        HTTPException: 404 if the name contains path components or does not
            refer to an existing regular file.
    """
    # Accept bare file names only, so a crafted name cannot reach outside the
    # audio directory. The same 404 is used to avoid revealing what exists.
    bare_name = os.path.basename(filename.replace("\\", "/"))
    if bare_name != filename or filename in ("", ".", ".."):
        raise HTTPException(status_code=404, detail="Audio file not found")

    file_path = os.path.join(GENERATED_AUDIO_DIR, filename)
    # isfile (not exists) so a directory cannot slip through and fail in open().
    if not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="Audio file not found")

    # Context manager closes the handle promptly instead of leaking it.
    with open(file_path, "rb") as audio_file:
        content = audio_file.read()

    # Simple media type detection from the extension; mp3 maps to audio/mpeg,
    # matching the synthesis endpoints.
    extension = filename.rsplit(".", 1)[-1].lower()
    subtype = "mpeg" if extension == "mp3" else extension
    return Response(
        content=content,
        media_type=f"audio/{subtype}",
        headers={"Content-Disposition": f"attachment; filename={filename}"},
    )