# neu/app.py -- NeuTTS Air production API.
# Provenance: Hugging Face Space upload by Rajhuggingface4253,
# commit 63cf9bb ("Update app.py"), 5.77 kB.
import os
import io
import base64
import json
import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from fastapi.responses import Response, JSONResponse
import soundfile as sf
from neutts_wrapper import NeuTTSWrapper
# --- Configuration & Global Objects ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Read device from environment variable, defaulting to 'cpu'
DEVICE = os.getenv("MODEL_DEVICE", "cpu")
# Use a ThreadPoolExecutor to run blocking ML code in a separate thread.
# max_workers=1 serializes all inference requests -- presumably because the
# model is not thread-safe or to bound memory use; TODO confirm before raising.
tts_executor = ThreadPoolExecutor(max_workers=1)
# --- Lifespan Management (Model Loading) ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Manage the model's lifecycle.

    The TTS wrapper is loaded once at startup and stored on ``app.state``.
    If loading fails, ``app.state.tts_wrapper`` is set to None so the
    endpoints can report a degraded service (503) instead of the process
    crashing. At shutdown the inference thread pool is drained and closed.
    """
    logger.info("Application startup...")
    try:
        # Load the model wrapper into the application state
        app.state.tts_wrapper = NeuTTSWrapper(device=DEVICE)
    except Exception:
        # logger.exception records the full traceback; the original
        # f-string error message discarded it.
        logger.exception("FATAL: Model could not be loaded.")
        app.state.tts_wrapper = None
    yield  # The application is now running
    logger.info("Application shutdown...")
    tts_executor.shutdown(wait=True)
# --- FastAPI App Initialization ---
# `lifespan` loads the TTS model at startup and shuts the worker pool down
# at exit (see the lifespan context manager above).
app = FastAPI(
    title="NeuTTS Air Production API",
    description="Production-ready Text-to-Speech with Voice Cloning",
    version="2.0.0",
    lifespan=lifespan
)
# --- Helper function for running blocking code ---
async def run_in_executor(func, *args, executor=None):
    """Run a blocking callable in a worker thread without blocking the event loop.

    Args:
        func: The blocking callable to execute.
        *args: Positional arguments forwarded to ``func``.
        executor: Optional executor override. Defaults to the module-level
            ``tts_executor``, so existing callers are unaffected.

    Returns:
        Whatever ``func`` returns.
    """
    # get_running_loop() is the modern API: unlike the deprecated
    # get_event_loop(), it raises if called outside a running loop instead
    # of silently creating a new one.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(executor or tts_executor, func, *args)
# --- API Endpoints ---
@app.get("/")
async def root():
    """Liveness endpoint: confirms the API process is up and reachable."""
    payload = {"status": "online", "service": "NeuTTS Air API v2"}
    return payload
@app.get("/health")
async def health_check():
    """Report service health: whether the TTS model loaded, and the device."""
    wrapper = app.state.tts_wrapper
    model_status = "loaded" if wrapper else "degraded"
    return {"status": "healthy", "model_status": model_status, "device": DEVICE}
@app.post("/api/v1/synthesize")
async def synthesize_speech(
    ref_text: str = Form(...),
    gen_text: str = Form(...),
    ref_audio: UploadFile = File(...)
):
    """Clone the voice in ``ref_audio``/``ref_text`` and speak ``gen_text``.

    Returns the generated speech as a raw WAV body (``audio/wav``).
    Responds 503 when the model failed to load, 500 on synthesis errors.
    """
    if not app.state.tts_wrapper:
        raise HTTPException(status_code=503, detail="Service unavailable: Model not loaded")
    try:
        ref_audio_bytes = await ref_audio.read()
        # Run blocking ML code in the thread pool so the event loop stays
        # responsive while encoding/inference runs.
        ref_codes = await run_in_executor(app.state.tts_wrapper.encode_reference, ref_audio_bytes)
        wav_data = await run_in_executor(app.state.tts_wrapper.infer, gen_text, ref_codes, ref_text)
        # Serialize the waveform to WAV entirely in memory. 24000 Hz is
        # presumably the model's output sample rate -- TODO confirm.
        buffer = io.BytesIO()
        sf.write(buffer, wav_data, 24000, format='WAV')
        buffer.seek(0)
        return Response(content=buffer.read(), media_type="audio/wav")
    except Exception as e:
        # logger.exception keeps the traceback (the original f-string
        # logger.error call discarded it).
        logger.exception("Synthesis failed")
        raise HTTPException(status_code=500, detail=f"Synthesis failed: {str(e)}") from e
@app.post("/api/v1/synthesize/b64")
async def synthesize_speech_base64(
    ref_text: str = Form(...),
    gen_text: str = Form(...),
    ref_audio: UploadFile = File(...)
):
    """Same synthesis as ``/api/v1/synthesize`` but returns JSON.

    The WAV bytes are base64-encoded into ``{"audio_data": ..., "format": "wav"}``
    for clients that cannot consume a binary response body.
    Responds 503 when the model failed to load, 500 on synthesis errors.
    """
    if not app.state.tts_wrapper:
        raise HTTPException(status_code=503, detail="Service unavailable: Model not loaded")
    try:
        ref_audio_bytes = await ref_audio.read()
        # Run blocking ML code in the thread pool so the event loop stays
        # responsive while encoding/inference runs.
        ref_codes = await run_in_executor(app.state.tts_wrapper.encode_reference, ref_audio_bytes)
        wav_data = await run_in_executor(app.state.tts_wrapper.infer, gen_text, ref_codes, ref_text)
        # Serialize to WAV in memory, then base64-encode for JSON transport.
        buffer = io.BytesIO()
        sf.write(buffer, wav_data, 24000, format='WAV')
        buffer.seek(0)
        audio_b64 = base64.b64encode(buffer.read()).decode('utf-8')
        return JSONResponse({"audio_data": audio_b64, "format": "wav"})
    except Exception as e:
        # logger.exception keeps the traceback (the original f-string
        # logger.error call discarded it).
        logger.exception("Base64 synthesis failed")
        raise HTTPException(status_code=500, detail=f"Base64 synthesis failed: {str(e)}") from e
@app.post("/api/v1/batch-synthesize")
async def batch_synthesize(
    ref_text: str = Form(...),
    ref_audio: UploadFile = File(...),
    texts: str = Form(...)
):
    """Synthesize several texts with a single reference voice.

    ``texts`` is a JSON-encoded array of strings. The reference audio is
    encoded once, then inference runs per text; each clip is returned
    base64-encoded in ``{"generated_clips": [{"text": ..., "audio_data": ...}]}``.
    Responds 400 on malformed ``texts``, 503 when the model failed to load,
    500 on synthesis errors.
    """
    if not app.state.tts_wrapper:
        raise HTTPException(status_code=503, detail="Service unavailable: Model not loaded")
    # Validate the client payload outside the synthesis try-block. In the
    # original, the non-list ValueError was swallowed by the generic
    # `except Exception` handler and surfaced as a 500 instead of a 400.
    try:
        text_list = json.loads(texts)
    except json.JSONDecodeError:
        raise HTTPException(status_code=400, detail="Invalid JSON in 'texts' field.")
    if not isinstance(text_list, list):
        raise HTTPException(status_code=400, detail="Texts must be a JSON array of strings.")
    try:
        ref_audio_bytes = await ref_audio.read()
        # Encode the reference once, in the thread pool; it is reused for
        # every text in the batch.
        ref_codes = await run_in_executor(app.state.tts_wrapper.encode_reference, ref_audio_bytes)
        results = []
        for text in text_list:
            wav_data = await run_in_executor(app.state.tts_wrapper.infer, text, ref_codes, ref_text)
            buffer = io.BytesIO()
            sf.write(buffer, wav_data, 24000, format='WAV')
            buffer.seek(0)
            audio_b64 = base64.b64encode(buffer.read()).decode('utf-8')
            results.append({"text": text, "audio_data": audio_b64})
        return JSONResponse({"generated_clips": results})
    except Exception as e:
        # logger.exception keeps the traceback (the original f-string
        # logger.error call discarded it).
        logger.exception("Batch synthesis failed")
        raise HTTPException(status_code=500, detail=f"Batch synthesis failed: {str(e)}") from e