# neu / app.py
# Source: Hugging Face Space file view (uploader: Rajhuggingface4253).
# Commit 24bb5f8 (verified): "Update app.py" -- file size 13 kB.
import asyncio
import hashlib
import io
import logging
import os
import re
import subprocess
import threading
import time
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager
from functools import lru_cache
from typing import Generator

import numpy as np
import psutil
import soundfile as sf
import torch
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from fastapi.responses import Response, StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
# Ensure the cloned neutts-air repository is in the path
import sys
sys.path.append(os.path.join(os.getcwd(), 'neutts-air'))
from neuttsair.neutts import NeuTTSAir
# --- Logging ---
# One named logger is shared by every function in this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("NeuTTS-API")

# --- Configuration & Utility Functions ---
# Sample rate (Hz) used both for the FFmpeg reference conversion and for
# encoding synthesized audio.
SAMPLE_RATE = 24000

# Inference is pinned to the CPU (Dockerfile / Hugging Face free tier).
DEVICE = "cpu"

# Size of the shared synthesis thread pool (1-2 is safe for CPU-only hosts).
MAX_WORKERS = 2
tts_executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
class TTSRequestModel(BaseModel):
    """Schema for the non-file inputs of a synthesis / streaming request.

    NOTE(review): the endpoints visible in this file declare their inputs as
    individual ``Form`` fields and do not reference this model.
    """

    # Text to synthesize: required, 1 to 1000 characters.
    text: str = Field(
        ...,
        min_length=1,
        max_length=1000,
    )
    # Output audio format: one of "wav", "mp3" or "flac" (defaults to "wav").
    output_format: str = Field(
        default="wav",
        pattern="^(wav|mp3|flac)$",
    )
async def convert_to_wav_in_memory(upload_file: UploadFile) -> io.BytesIO:
    """Convert an uploaded audio file to mono 16-bit PCM WAV at ``SAMPLE_RATE``.

    The upload is piped through an FFmpeg subprocess (stdin -> stdout), so no
    intermediate file is written to disk.

    Args:
        upload_file: The uploaded reference audio; it is read in full.

    Returns:
        A ``BytesIO`` holding the WAV bytes FFmpeg wrote to stdout.

    Raises:
        HTTPException: 400 when the upload is empty or FFmpeg exits with a
            non-zero status; 500 when the ``ffmpeg`` executable is missing.
    """
    source_bytes = await upload_file.read()
    if not source_bytes:
        # Nothing to convert: fail fast instead of spawning FFmpeg on no input.
        raise HTTPException(
            status_code=400,
            detail="Audio format conversion failed: uploaded file is empty.",
        )

    ffmpeg_command = [
        "ffmpeg",
        "-i", "pipe:0",  # Read from stdin
        "-f", "wav",
        "-ar", str(SAMPLE_RATE),
        "-ac", "1",
        "-c:a", "pcm_s16le",
        "pipe:1",  # Write to stdout
    ]

    try:
        proc = await asyncio.create_subprocess_exec(
            *ffmpeg_command,
            stdin=asyncio.subprocess.PIPE,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
    except FileNotFoundError as exc:
        # A missing binary is a deployment problem, not a client error.
        logger.error("In-memory conversion failed: ffmpeg executable not found.")
        raise HTTPException(
            status_code=500,
            detail="Audio format conversion failed: FFmpeg is not installed.",
        ) from exc

    # Feed the upload to stdin and collect the WAV output from stdout.
    wav_data, stderr_data = await proc.communicate(input=source_bytes)

    if proc.returncode != 0:
        # FFmpeg may echo arbitrary metadata bytes; never let decoding raise.
        error_message = stderr_data.decode(errors="replace")
        logger.error(f"In-memory conversion failed: {error_message}")
        # Report the last non-blank stderr line to the client.
        meaningful_lines = [ln for ln in error_message.splitlines() if ln.strip()]
        error_detail = meaningful_lines[-1] if meaningful_lines else "Unknown FFmpeg error."
        raise HTTPException(
            status_code=400,
            detail=f"Audio format conversion failed: {error_detail}",
        )

    logger.info("In-memory FFmpeg conversion successful.")
    return io.BytesIO(wav_data)
# --- Model Wrapper and Logic ---
class NeuTTSWrapper:
    """Holds the NeuTTSAir model plus the helpers the API endpoints call.

    Reference encodings are cached per instance, keyed only by the content
    hash supplied by the caller, so the raw audio bytes are not retained.
    """

    # Maximum number of reference encodings kept in the per-instance cache.
    _REFERENCE_CACHE_SIZE = 32

    def __init__(self, device: str = "cpu"):
        """Create the wrapper and load the model immediately.

        Args:
            device: Device name passed to both the backbone and the codec.

        Raises:
            Exception: Whatever ``load_model`` raises when loading fails.
        """
        self.tts_model = None
        self.device = device
        # LRU cache: content hash -> reference encoding (oldest entry first).
        self._reference_cache = OrderedDict()
        # Synthesis runs on a thread pool, so cache access is lock-guarded.
        self._reference_cache_lock = threading.Lock()
        self.load_model()

    def load_model(self):
        """Instantiate ``NeuTTSAir`` on ``self.device``; log and re-raise on failure."""
        try:
            logger.info(f"Loading NeuTTSAir model on device: {self.device}")
            # Ensure we respect the configured device for both components.
            self.tts_model = NeuTTSAir(backbone_device=self.device, codec_device=self.device)
            logger.info("✅ NeuTTSAir model loaded successfully.")
        except Exception as e:
            logger.error(f"❌ Model loading failed: {e}")
            raise

    def _convert_to_streamable_format(self, audio_data: np.ndarray, audio_format: str) -> bytes:
        """Encode a NumPy audio array into bytes of the requested format.

        Args:
            audio_data: Samples to encode at ``SAMPLE_RATE``.
            audio_format: Format name handed to ``soundfile.write``.

        Returns:
            The encoded audio file contents.

        Raises:
            Exception: Re-raised (after logging) if ``soundfile`` cannot write.
        """
        audio_buffer = io.BytesIO()
        try:
            sf.write(audio_buffer, audio_data, SAMPLE_RATE, format=audio_format)
        except Exception as e:
            logger.error(f"Failed to write audio data to format {audio_format}: {e}")
            raise
        return audio_buffer.getvalue()

    def _split_text_into_chunks(self, text: str) -> list[str]:
        """Split text into sentence/clause chunks for incremental synthesis.

        A boundary is a run of whitespace that directly follows one of
        ``. , ! ?``. Punctuation inside a token (``3.14``, ``1,000``) is
        therefore not a boundary. Chunks without any alphanumeric character
        (e.g. a stray ``...``) are dropped because there is nothing to speak.

        Args:
            text: Raw input text.

        Returns:
            Stripped, non-empty chunks in their original order.
        """
        pieces = re.split(r'(?<=[.,!?])\s+', text)
        chunks = []
        for piece in pieces:
            piece = piece.strip()
            if piece and any(ch.isalnum() for ch in piece):
                chunks.append(piece)
        return chunks

    def _get_or_create_reference_encoding(self, audio_content_hash: str, audio_bytes: bytes) -> torch.Tensor:
        """Return the reference encoding for ``audio_bytes``, using the LRU cache.

        Args:
            audio_content_hash: Cache key; callers pass the SHA-256 hex digest
                of ``audio_bytes``.
            audio_bytes: Reference audio, only used on a cache miss.

        Returns:
            The value returned by the model's ``encode_reference``.
        """
        with self._reference_cache_lock:
            if audio_content_hash in self._reference_cache:
                # Mark as most recently used.
                self._reference_cache.move_to_end(audio_content_hash)
                return self._reference_cache[audio_content_hash]

        logger.info(f"Cache miss for hash: {audio_content_hash[:10]}... Encoding new reference.")
        # Encode outside the lock so one slow encode does not block cache hits.
        # Two concurrent misses for the same hash may both encode; the result
        # is simply stored twice.
        encoding = self.tts_model.encode_reference(io.BytesIO(audio_bytes))

        with self._reference_cache_lock:
            self._reference_cache[audio_content_hash] = encoding
            self._reference_cache.move_to_end(audio_content_hash)
            # Evict least recently used entries beyond the size limit.
            while len(self._reference_cache) > self._REFERENCE_CACHE_SIZE:
                self._reference_cache.popitem(last=False)
        return encoding

    def generate_speech_blocking(self, text: str, ref_audio_bytes: bytes, reference_text: str) -> np.ndarray:
        """Synthesize ``text`` in one blocking call using a cached reference encoding.

        Args:
            text: Text to synthesize.
            ref_audio_bytes: Reference audio bytes (hashed to form the cache key).
            reference_text: Transcript passed to the model alongside the encoding.

        Returns:
            The audio returned by the model's ``infer``.
        """
        # 1. Hash the audio bytes to get a cache key.
        audio_hash = hashlib.sha256(ref_audio_bytes).hexdigest()
        # 2. Fetch the encoding from the cache, creating it on a miss.
        ref_s = self._get_or_create_reference_encoding(audio_hash, ref_audio_bytes)
        # 3. Run inference over the full text without tracking gradients.
        with torch.no_grad():
            audio = self.tts_model.infer(text, ref_s, reference_text)
        return audio
# --- Asynchronous Offloading ---
async def run_blocking_task_async(func, *args, **kwargs):
    """Run a blocking callable on the shared TTS thread pool and await its result.

    Args:
        func: The blocking callable to execute.
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        Whatever ``func(*args, **kwargs)`` returns.

    Raises:
        Exception: Any exception raised by ``func`` propagates to the awaiter.
    """
    # Inside a coroutine the running loop is the correct one to use;
    # get_event_loop() is discouraged here.
    loop = asyncio.get_running_loop()
    # run_in_executor only forwards positional args, so bind kwargs in a closure.
    return await loop.run_in_executor(
        tts_executor,
        lambda: func(*args, **kwargs),
    )
# --- FastAPI Lifespan Manager (Kokoro Feature) ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model at startup and shut the thread pool down on exit.

    Args:
        app: The FastAPI application; the wrapper is stored on
            ``app.state.tts_wrapper``.

    Raises:
        RuntimeError: If the model wrapper cannot be constructed. The original
            error is chained as the cause.
    """
    try:
        app.state.tts_wrapper = NeuTTSWrapper(device=DEVICE)
    except Exception as e:
        logger.error(f"Fatal startup error: {e}")
        # Terminate the application if the model can't load.
        tts_executor.shutdown(wait=False)
        raise RuntimeError("Model initialization failed.") from e

    try:
        yield  # Application serves requests
    finally:
        # Runs even when an exception is thrown into the context manager.
        logger.info("Shutting down ThreadPoolExecutor.")
        tts_executor.shutdown(wait=False)
# --- FastAPI Application Setup ---
# The application object; `lifespan` handles model startup and executor shutdown.
app = FastAPI(
    lifespan=lifespan,
    title="NeuTTS Air Instant Cloning API",
    version="2.0.0-PROD-ENHANCED",
    docs_url="/docs",
)

# CORS is fully open: any origin, method and header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
# --- New Endpoints and Enhancements ---
@app.get("/")
async def root():
    """Return a static banner message for the API root."""
    banner = "NeuTTS Air API v2.0 - Ready for Instant Voice Cloning"
    return {"message": banner}
@app.get("/health")
async def health_check():
    """Report service status, model availability and host memory/disk usage.

    The ``status`` field is always ``"healthy"``; use ``model_loaded`` to tell
    whether the wrapper exists on ``app.state`` and holds a model.
    """
    bytes_per_gib = 1024**3

    mem = psutil.virtual_memory()
    disk = psutil.disk_usage('/')

    if hasattr(app.state, 'tts_wrapper'):
        model_loaded = app.state.tts_wrapper.tts_model is not None
    else:
        model_loaded = False

    memory_usage = {
        "total_gb": round(mem.total / bytes_per_gib, 2),
        "used_percent": mem.percent,
    }
    disk_usage = {
        "total_gb": round(disk.total / bytes_per_gib, 2),
        "used_percent": disk.percent,
    }

    return {
        "status": "healthy",
        "model_loaded": model_loaded,
        "device": DEVICE,
        "concurrency_limit": MAX_WORKERS,
        "memory_usage": memory_usage,
        "disk_usage": disk_usage,
    }
# --- Core Synthesis Endpoints ---
@app.post("/synthesize", response_class=Response)
async def text_to_speech(
    text: str = Form(...),
    reference_text: str = Form(...),
    output_format: str = Form("wav", pattern="^(wav|mp3|flac)$"),
    reference_audio: UploadFile = File(...),
):
    """Synthesize the whole text in one request and return a single audio file.

    The reference audio is converted to WAV in memory, the text is synthesized
    on the thread pool, and the result is encoded in ``output_format``.

    Args:
        text: Text to synthesize.
        reference_text: Transcript of the reference audio.
        output_format: One of ``wav``, ``mp3`` or ``flac``.
        reference_audio: Uploaded voice sample to clone.

    Returns:
        A ``Response`` with the encoded audio plus ``X-Processing-Time`` and
        ``X-Audio-Duration`` headers.

    Raises:
        HTTPException: 503 when the model wrapper is not on ``app.state``;
            HTTP errors from the conversion step are passed through unchanged;
            any other failure becomes a 500.
    """
    if not hasattr(app.state, 'tts_wrapper'):
        raise HTTPException(status_code=503, detail="Service unavailable: Model not loaded")

    wrapper = app.state.tts_wrapper
    # perf_counter is monotonic, so elapsed time is immune to clock changes.
    start_time = time.perf_counter()

    try:
        # 1. Convert the uploaded file to WAV directly in memory.
        converted_wav_buffer = await convert_to_wav_in_memory(reference_audio)
        ref_audio_bytes = converted_wav_buffer.getvalue()

        # 2. Offload the blocking synthesis (reference encoding is cached).
        audio_data = await run_blocking_task_async(
            wrapper.generate_speech_blocking,
            text,
            ref_audio_bytes,  # Pass bytes, not a path
            reference_text,
        )

        # 3. Encode into the requested output format.
        audio_bytes = await run_blocking_task_async(
            wrapper._convert_to_streamable_format,
            audio_data,
            output_format,
        )

        processing_time = time.perf_counter() - start_time
        audio_duration = len(audio_data) / SAMPLE_RATE

        return Response(
            content=audio_bytes,
            media_type=f"audio/{'mpeg' if output_format == 'mp3' else output_format}",
            headers={
                "Content-Disposition": f"attachment; filename=tts_output.{output_format}",
                "X-Processing-Time": f"{processing_time:.2f}s",
                "X-Audio-Duration": f"{audio_duration:.2f}s",
            },
        )
    except HTTPException:
        # Already a well-formed HTTP error (e.g. 400 from conversion).
        raise
    except Exception as e:
        # logger.exception keeps the traceback; chaining keeps the cause.
        logger.exception(f"Synthesis error: {e}")
        raise HTTPException(status_code=500, detail=f"Synthesis failed: {e}") from e
@app.post("/synthesize/stream")
async def stream_text_to_speech_cloning(
    text: str = Form(..., min_length=1, max_length=5000),
    reference_text: str = Form(...),
    output_format: str = Form("mp3", pattern="^(wav|mp3|flac)$"),
    reference_audio: UploadFile = File(...),
):
    """Stream synthesized audio chunk by chunk with bounded look-ahead.

    The reference audio is converted and encoded before the response starts,
    so failures there are reported with a proper HTTP status instead of a
    truncated 200 stream. The text is then split into chunks. Up to
    ``MAX_WORKERS + 1`` chunks are scheduled on the thread pool at a time and
    their encoded bytes are yielded in order.

    Args:
        text: Text to synthesize (1 to 5000 characters).
        reference_text: Transcript of the reference audio.
        output_format: One of ``wav``, ``mp3`` or ``flac``.
        reference_audio: Uploaded voice sample to clone.

    Returns:
        A ``StreamingResponse`` whose body is the concatenation of the
        per-chunk encoded audio.

    Raises:
        HTTPException: 503 when the model wrapper is not on ``app.state``;
            400 when conversion fails or the text has no chunk to synthesize;
            500 when reference encoding fails.
    """
    if not hasattr(app.state, 'tts_wrapper'):
        raise HTTPException(status_code=503, detail="Service unavailable: Model not loaded")

    wrapper = app.state.tts_wrapper
    loop = asyncio.get_running_loop()

    # Read the upload and encode the voice while the request is still being
    # handled. Once streaming starts the status line has been sent, and the
    # uploaded file may already have been closed by the framework.
    try:
        converted_wav_buffer = await convert_to_wav_in_memory(reference_audio)
        ref_audio_bytes = converted_wav_buffer.getvalue()
        # One-time voice encoding, served from the cache when available.
        audio_hash = hashlib.sha256(ref_audio_bytes).hexdigest()
        ref_s = await loop.run_in_executor(
            tts_executor,
            wrapper._get_or_create_reference_encoding,
            audio_hash,
            ref_audio_bytes,
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Synthesis error: {e}")
        raise HTTPException(status_code=500, detail=f"Synthesis failed: {e}") from e

    sentences = wrapper._split_text_into_chunks(text)
    if not sentences:
        # Nothing to speak: say so rather than return an empty audio body.
        raise HTTPException(status_code=400, detail="No synthesizable text found in input.")

    # Number of chunks allowed in flight: keeps every worker busy with one
    # more chunk queued behind them.
    lookahead = MAX_WORKERS + 1

    def process_chunk(sentence_text):
        """Synthesize one chunk and encode it; runs on an executor thread."""
        with torch.no_grad():
            audio_chunk = wrapper.tts_model.infer(sentence_text, ref_s, reference_text)
        # NOTE(review): every chunk is encoded as a standalone file, so for
        # wav/flac the concatenated stream carries one header per chunk --
        # confirm clients handle this.
        return wrapper._convert_to_streamable_format(audio_chunk, output_format)

    async def stream_generator():
        in_flight = []  # Scheduled futures, in the order they must be yielded.
        next_index = 0
        try:
            while next_index < len(sentences) or in_flight:
                # Top up the window so later chunks synthesize while earlier
                # ones are being sent.
                while next_index < len(sentences) and len(in_flight) < lookahead:
                    in_flight.append(
                        loop.run_in_executor(tts_executor, process_chunk, sentences[next_index])
                    )
                    next_index += 1
                chunk_bytes = await in_flight.pop(0)
                yield chunk_bytes
        except Exception as e:
            logger.error(f"Error while streaming synthesis: {e}")
            raise
        finally:
            # On error or client disconnect, drop work that has not started
            # and consume results of work that already finished, so nothing
            # is leaked or reported as "never retrieved".
            for fut in in_flight:
                if not fut.cancel() and not fut.cancelled():
                    fut.exception()

    return StreamingResponse(
        stream_generator(),
        media_type=f"audio/{'mpeg' if output_format == 'mp3' else output_format}",
    )