# neu / app.py — Hugging Face Space by Rajhuggingface4253
# (commit a1eb108, "Update app.py", 9.64 kB)
# app.py
import os
import io
import asyncio
import time
import psutil
import soundfile as sf
import subprocess
import numpy as np
import librosa # Needed for monkey-patching
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager
import logging
from types import MethodType
import torch
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from fastapi.responses import Response, StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
# This will now work because the Dockerfile clones the repo
# and we add it to the path
import sys
sys.path.append(os.path.join(os.getcwd(), 'neutts-air'))
from neuttsair.neutts import NeuTTSAir
# --- Configuration & Logging ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("NeuTTS-GGUF-API")
# Production-ready configuration via Environment Variables
BACKBONE_MODEL_PATH = os.getenv("BACKBONE_MODEL_PATH", "/app/models/neutts-air.gguf")
CODEC_REPO = os.getenv("CODEC_REPO", "neuphonic/neucodec-onnx-decoder") # Using ONNX for performance
DEVICE = "cpu" # llama-cpp handles its own device (CPU/GPU) management
MAX_WORKERS = int(os.getenv("MAX_WORKERS", "2"))
# Shared pool used to run blocking model calls off the asyncio event loop.
tts_executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
# Output sample rate written into synthesized audio files (Hz).
# NOTE(review): presumably matches the NeuCodec decoder output rate — confirm.
SAMPLE_RATE = 24000
# --- Core Utility Functions ---
async def convert_to_wav_in_memory(upload_file: UploadFile) -> io.BytesIO:
    """Converts uploaded audio to a 16kHz mono WAV for the encoder, in memory.

    The upload is piped through ffmpeg via stdin/stdout, so nothing is
    written to disk.

    Raises:
        HTTPException(400): if ffmpeg cannot decode or convert the input.
    """
    ffmpeg_command = [
        "ffmpeg", "-i", "pipe:0", "-f", "wav", "-ar", "16000",
        "-ac", "1", "-c:a", "pcm_s16le", "pipe:1"
    ]
    proc = await asyncio.create_subprocess_exec(
        *ffmpeg_command, stdin=subprocess.PIPE,
        stdout=subprocess.PIPE, stderr=subprocess.PIPE
    )
    wav_data, stderr_data = await proc.communicate(input=await upload_file.read())
    if proc.returncode != 0:
        # errors="replace" so a non-UTF-8 byte in ffmpeg's output can't raise.
        error_message = stderr_data.decode(errors="replace")
        logger.error(f"In-memory conversion failed: {error_message}")
        # BUGFIX: splitlines()[-1] raised IndexError when stderr was empty.
        stderr_lines = error_message.strip().splitlines()
        error_detail = stderr_lines[-1] if stderr_lines else "unknown ffmpeg error"
        raise HTTPException(status_code=400, detail=f"Audio conversion failed: {error_detail}")
    return io.BytesIO(wav_data)
async def run_blocking_task_async(func, *args, executor=None, **kwargs):
    """Offloads a blocking function call to a ThreadPoolExecutor.

    Args:
        func: Blocking callable to run off the event loop.
        *args / **kwargs: Forwarded to ``func``.
        executor: Optional pool override (keyword-only); defaults to the
            module-level ``tts_executor``, so existing callers are unchanged.

    Returns:
        Whatever ``func(*args, **kwargs)`` returns.
    """
    # BUGFIX: get_event_loop() is deprecated inside a running coroutine
    # (since Python 3.10); get_running_loop() is the correct call here.
    loop = asyncio.get_running_loop()
    pool = tts_executor if executor is None else executor
    return await loop.run_in_executor(pool, lambda: func(*args, **kwargs))
# --- Model Wrapper and Professional Integration ---
def _encode_reference_from_memory(self, ref_audio: io.BytesIO):
    """Memory-based replacement for NeuTTSAir.encode_reference.

    Reads the reference voice from an in-memory BytesIO buffer instead of
    a file path, avoiding a disk round-trip in the API hot path. Bound to
    the model instance via MethodType at load time.
    """
    # Decode to 16 kHz mono float32 — the rate the codec encoder expects.
    waveform, _sr = librosa.load(ref_audio, sr=16000, mono=True)
    # Add batch and channel dims: (T,) -> (1, 1, T).
    audio_tensor = torch.from_numpy(waveform).float()[None, None, :]
    with torch.no_grad():
        codes = self.codec.encode_code(audio_or_path=audio_tensor)
    # Drop the batch/channel dims again before returning the code sequence.
    return codes.squeeze(0).squeeze(0)
class NeuTTSWrapper:
    """Owns the NeuTTSAir GGUF model instance and audio-encoding helpers."""
    def __init__(self):
        # Loaded eagerly so the app fails fast at startup if the model is missing.
        self.tts_model: NeuTTSAir | None = None
        self.load_model()
    def load_model(self):
        """Instantiates the NeuTTSAir backbone/codec and patches reference encoding.

        Raises:
            Exception: any failure from NeuTTSAir construction is logged and re-raised
                so the lifespan handler can abort startup.
        """
        try:
            logger.info(f"Loading NeuTTSAir GGUF model from: {BACKBONE_MODEL_PATH}")
            self.tts_model = NeuTTSAir(
                backbone_repo=BACKBONE_MODEL_PATH,
                codec_repo=CODEC_REPO,
                backbone_device=DEVICE,
                codec_device=DEVICE
            )
            # ** MONKEY-PATCHING **: This is the professional way to adapt the library
            # without changing its source code. We replace its file-based function
            # with our memory-based one.
            self.tts_model.encode_reference = MethodType(_encode_reference_from_memory, self.tts_model)
            logger.info("✅ NeuTTSAir GGUF model loaded and patched successfully.")
        except Exception as e:
            logger.error(f"❌ Model loading failed: {e}", exc_info=True)
            raise
    def convert_to_streamable_format(self, audio_data: np.ndarray, audio_format: str) -> bytes:
        """Converts NumPy audio array to bytes in the specified format.

        NOTE(review): 'mp3' output requires soundfile backed by libsndfile >= 1.1.0 —
        confirm the deployed image ships a new enough libsndfile.
        """
        with io.BytesIO() as audio_buffer:
            sf.write(audio_buffer, audio_data, SAMPLE_RATE, format=audio_format)
            return audio_buffer.getvalue()
# --- FastAPI Application Setup ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup/shutdown hook: load the model eagerly, then tear down the pool."""
    try:
        wrapper = NeuTTSWrapper()
    except Exception as e:
        logger.error(f"Fatal startup error: Model could not be loaded. {e}")
        # Don't leave worker threads dangling if startup is aborted.
        tts_executor.shutdown(wait=False, cancel_futures=True)
        raise RuntimeError("Model initialization failed. Application cannot start.") from e
    # Only publish the wrapper on app.state once it loaded successfully.
    app.state.tts_wrapper = wrapper
    yield
    # Normal shutdown path: drain in-flight work before exiting.
    logger.info("Shutting down ThreadPoolExecutor.")
    tts_executor.shutdown(wait=True)
# FastAPI application; the lifespan handler loads the model before serving.
app = FastAPI(
    title="NeuTTS Air GGUF Cloning API",
    version="3.0.0-PROD-GGUF",
    lifespan=lifespan
)
# NOTE(review): wide-open CORS ("*") is fine for a public demo Space, but
# lock down allow_origins before reusing this service in production.
app.add_middleware(
    CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"],
)
# --- API Endpoints ---
@app.get("/")
async def root():
    """Liveness banner for the API root."""
    banner = {"message": "NeuTTS Air GGUF API - Ready for High-Speed Voice Cloning"}
    return banner
@app.get("/health")
async def health_check():
    """Reports process health: model readiness, config, and memory pressure."""
    # Model is ready only once lifespan has published a wrapper holding a model.
    model_ready = (
        hasattr(app.state, 'tts_wrapper')
        and app.state.tts_wrapper.tts_model is not None
    )
    memory = psutil.virtual_memory()
    return {
        "status": "healthy",
        "model_loaded": model_ready,
        "model_type": "GGUF",
        "backbone_path": BACKBONE_MODEL_PATH,
        "codec_repo": CODEC_REPO,
        "memory_usage_percent": memory.percent
    }
@app.post("/synthesize", response_class=Response)
async def text_to_speech(
    text: str = Form(...),
    reference_text: str = Form(...),
    output_format: str = Form("wav", pattern="^(wav|mp3|flac)$"),
    reference_audio: UploadFile = File(...)
):
    """Standard blocking TTS endpoint optimized for GGUF.

    Pipeline: convert reference audio -> encode speaker reference ->
    synthesize -> encode output bytes, with every blocking stage offloaded
    to the thread pool.

    Raises:
        HTTPException(400): invalid reference audio (from conversion).
        HTTPException(500): any other synthesis failure.
    """
    start_time = time.time()
    try:
        converted_wav_buffer = await convert_to_wav_in_memory(reference_audio)
        # Encode the speaker reference (blocking -> thread pool).
        ref_codes = await run_blocking_task_async(
            app.state.tts_wrapper.tts_model.encode_reference,
            converted_wav_buffer
        )
        # Generate speech conditioned on the reference voice.
        audio_data = await run_blocking_task_async(
            app.state.tts_wrapper.tts_model.infer,
            text, ref_codes, reference_text
        )
        audio_bytes = await run_blocking_task_async(
            app.state.tts_wrapper.convert_to_streamable_format,
            audio_data, output_format
        )
        processing_time = time.time() - start_time
        return Response(
            content=audio_bytes,
            media_type=f"audio/{'mpeg' if output_format == 'mp3' else output_format}",
            headers={"X-Processing-Time": f"{processing_time:.2f}s"}
        )
    except HTTPException:
        # BUGFIX: the blanket handler below used to swallow HTTPExceptions
        # (e.g. the 400 from bad reference audio) and re-emit them as 500s.
        raise
    except Exception as e:
        logger.error(f"Synthesis error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="An internal error occurred during synthesis.")
@app.post("/synthesize/stream")
async def stream_text_to_speech_cloning(
    text: str = Form(..., min_length=1),
    reference_text: str = Form(...),
    output_format: str = Form("mp3", pattern="^(wav|mp3|flac)$"),
    reference_audio: UploadFile = File(...)
):
    """High-performance, sentence-by-sentence streaming using the GGUF backend.

    The reference is encoded up front; synthesis chunks are produced by a
    blocking generator running in the thread pool and relayed to the client
    through an asyncio.Queue.

    Raises:
        HTTPException(400): invalid reference audio (from conversion).
        HTTPException(500): reference pre-processing failure.
    """
    try:
        converted_wav_buffer = await convert_to_wav_in_memory(reference_audio)
        ref_codes = await run_blocking_task_async(
            app.state.tts_wrapper.tts_model.encode_reference,
            converted_wav_buffer
        )
    except HTTPException:
        # BUGFIX: preserve the original status (e.g. 400 for bad reference
        # audio) instead of collapsing it into a generic 500 below.
        raise
    except Exception as e:
        logger.error(f"Error during pre-processing for stream: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="Failed to prepare reference audio for streaming.")
    async def stream_generator():
        # The model's infer_stream is a blocking generator. We must run it in
        # the executor and hand chunks back to the event loop via a queue.
        # BUGFIX: get_event_loop() is deprecated inside coroutines (3.10+).
        loop = asyncio.get_running_loop()
        queue = asyncio.Queue()
        def producer():
            try:
                # This loop blocks in the worker thread, not the event loop.
                for audio_chunk in app.state.tts_wrapper.tts_model.infer_stream(text, ref_codes, reference_text):
                    # NOTE(review): for 'wav'/'flac' every chunk carries its own
                    # container header; players may only accept the first chunk.
                    # 'mp3' frames concatenate cleanly — confirm before exposing
                    # wav/flac streaming to clients.
                    chunk_bytes = app.state.tts_wrapper.convert_to_streamable_format(audio_chunk, output_format)
                    # asyncio.Queue is not thread-safe; marshal via the loop.
                    loop.call_soon_threadsafe(queue.put_nowait, chunk_bytes)
            except Exception as e:
                logger.error(f"Error in streaming producer thread: {e}", exc_info=True)
                loop.call_soon_threadsafe(queue.put_nowait, e)
            finally:
                loop.call_soon_threadsafe(queue.put_nowait, None)  # Signal end of stream
        # Start the blocking producer in the thread pool.
        producer_task = loop.run_in_executor(tts_executor, producer)
        try:
            # The consumer runs in the main async event loop.
            while True:
                item = await queue.get()
                if item is None:
                    break
                if isinstance(item, Exception):
                    raise item
                yield item
        finally:
            # BUGFIX: previously the producer future was only awaited on the
            # happy path; ensure it is awaited even if the client disconnects
            # mid-stream (the generator is closed with GeneratorExit).
            await producer_task
    return StreamingResponse(
        stream_generator(),
        media_type=f"audio/{'mpeg' if output_format == 'mp3' else output_format}"
    )