# NeuTTS Air Production API — FastAPI Text-to-Speech service with voice cloning.
import os
import io
import base64
import json
import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from fastapi.responses import Response, JSONResponse
import soundfile as sf
from neutts_wrapper import NeuTTSWrapper
# --- Configuration & Global Objects ---
# Module-level logger configured once at import time.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Read device from environment variable, defaulting to 'cpu'
DEVICE = os.getenv("MODEL_DEVICE", "cpu")
# Use a ThreadPoolExecutor to run blocking ML code in a separate thread
# max_workers=1 serializes all inference calls — presumably the model
# wrapper is not safe for concurrent use; confirm before raising this.
tts_executor = ThreadPoolExecutor(max_workers=1)
# --- Lifespan Management (Model Loading) ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Manage the model's lifecycle: load at startup, clean up at shutdown.

    If loading fails, the app still starts but with
    ``app.state.tts_wrapper = None`` so endpoints can respond 503
    instead of the process dying. The inference thread pool is drained
    on shutdown.
    """
    logger.info("Application startup...")
    try:
        # Load the model wrapper into the application state
        app.state.tts_wrapper = NeuTTSWrapper(device=DEVICE)
    except Exception:
        # logger.exception records the full traceback, which logger.error
        # with an f-string did not — essential for diagnosing load failures.
        logger.exception("FATAL: Model could not be loaded.")
        app.state.tts_wrapper = None
    yield  # The application is now running
    logger.info("Application shutdown...")
    tts_executor.shutdown(wait=True)
# --- FastAPI App Initialization ---
# The lifespan handler above loads the model at startup and shuts down
# the inference thread pool on exit.
app = FastAPI(
    title="NeuTTS Air Production API",
    description="Production-ready Text-to-Speech with Voice Cloning",
    version="2.0.0",
    lifespan=lifespan
)
# --- Helper function for running blocking code ---
async def run_in_executor(func, *args):
    """Run a blocking callable in the TTS thread pool without blocking the server.

    Args:
        func: Blocking callable (e.g. a model inference method).
        *args: Positional arguments forwarded to ``func``.

    Returns:
        Whatever ``func(*args)`` returns, awaited from the event loop.
    """
    # get_running_loop() is the correct API inside a coroutine:
    # get_event_loop() is deprecated here (since 3.10) and may create a
    # spurious loop when none is running.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(tts_executor, func, *args)
# --- API Endpoints ---
@app.get("/")
async def root():
    """Liveness probe: report the service as online."""
    payload = {"status": "online", "service": "NeuTTS Air API v2"}
    return payload
@app.get("/health")
async def health_check():
    """Readiness probe: report model load status and the configured device."""
    if app.state.tts_wrapper:
        model_status = "loaded"
    else:
        model_status = "degraded"
    return {"status": "healthy", "model_status": model_status, "device": DEVICE}
@app.post("/api/v1/synthesize")
async def synthesize_speech(
    ref_text: str = Form(...),
    gen_text: str = Form(...),
    ref_audio: UploadFile = File(...)
):
    """Clone the voice in ``ref_audio`` and speak ``gen_text``; return raw WAV bytes.

    Args:
        ref_text: Transcript of the reference audio clip.
        gen_text: Text to synthesize in the cloned voice.
        ref_audio: Uploaded reference audio file.

    Raises:
        HTTPException: 503 if the model failed to load, 500 on synthesis error.
    """
    if not app.state.tts_wrapper:
        raise HTTPException(status_code=503, detail="Service unavailable: Model not loaded")
    try:
        ref_audio_bytes = await ref_audio.read()
        # Run blocking ML code in the thread pool so the event loop stays responsive
        ref_codes = await run_in_executor(app.state.tts_wrapper.encode_reference, ref_audio_bytes)
        wav_data = await run_in_executor(app.state.tts_wrapper.infer, gen_text, ref_codes, ref_text)
        # Encode the waveform to WAV entirely in memory.
        # 24000 Hz output rate — assumed to match the model's native rate; TODO confirm.
        buffer = io.BytesIO()
        sf.write(buffer, wav_data, 24000, format='WAV')
        # getvalue() returns the full buffer without the seek(0)/read() dance
        return Response(content=buffer.getvalue(), media_type="audio/wav")
    except Exception as e:
        # logger.exception preserves the traceback, which logger.error did not
        logger.exception("Synthesis failed: %s", e)
        raise HTTPException(status_code=500, detail=f"Synthesis failed: {str(e)}") from e
@app.post("/api/v1/synthesize/b64")
async def synthesize_speech_base64(
    ref_text: str = Form(...),
    gen_text: str = Form(...),
    ref_audio: UploadFile = File(...)
):
    """Same as ``/api/v1/synthesize`` but returns the WAV as base64 in a JSON body.

    Args:
        ref_text: Transcript of the reference audio clip.
        gen_text: Text to synthesize in the cloned voice.
        ref_audio: Uploaded reference audio file.

    Returns:
        JSON ``{"audio_data": <base64 WAV>, "format": "wav"}``.

    Raises:
        HTTPException: 503 if the model failed to load, 500 on synthesis error.
    """
    if not app.state.tts_wrapper:
        raise HTTPException(status_code=503, detail="Service unavailable: Model not loaded")
    try:
        ref_audio_bytes = await ref_audio.read()
        # Run blocking ML code in the thread pool so the event loop stays responsive
        ref_codes = await run_in_executor(app.state.tts_wrapper.encode_reference, ref_audio_bytes)
        wav_data = await run_in_executor(app.state.tts_wrapper.infer, gen_text, ref_codes, ref_text)
        # Encode the waveform to WAV entirely in memory (24 kHz — assumed model rate; TODO confirm)
        buffer = io.BytesIO()
        sf.write(buffer, wav_data, 24000, format='WAV')
        audio_b64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
        return JSONResponse({"audio_data": audio_b64, "format": "wav"})
    except Exception as e:
        # logger.exception preserves the traceback, which logger.error did not
        logger.exception("Base64 synthesis failed: %s", e)
        raise HTTPException(status_code=500, detail=f"Base64 synthesis failed: {str(e)}") from e
@app.post("/api/v1/batch-synthesize")
async def batch_synthesize(
    ref_text: str = Form(...),
    ref_audio: UploadFile = File(...),
    texts: str = Form(...)
):
    """Synthesize several texts with one reference voice; return base64 WAV clips.

    The reference audio is encoded once, then each entry in ``texts``
    (a JSON array of strings) is synthesized sequentially.

    Raises:
        HTTPException: 400 for malformed ``texts``, 503 if the model is not
        loaded, 500 on synthesis error.
    """
    if not app.state.tts_wrapper:
        raise HTTPException(status_code=503, detail="Service unavailable: Model not loaded")
    try:
        text_list = json.loads(texts)
    except json.JSONDecodeError:
        raise HTTPException(status_code=400, detail="Invalid JSON in 'texts' field.")
    # Validate the payload shape up front. The original raised ValueError
    # inside the broad try-block, so a non-list payload surfaced as a 500
    # server error instead of the client error it actually is.
    if not isinstance(text_list, list):
        raise HTTPException(status_code=400, detail="Texts must be a JSON array of strings.")
    try:
        ref_audio_bytes = await ref_audio.read()
        # Encode reference once, in the thread pool
        ref_codes = await run_in_executor(app.state.tts_wrapper.encode_reference, ref_audio_bytes)
        results = []
        for text in text_list:
            # Infer for each text sequentially (the pool has a single worker)
            wav_data = await run_in_executor(app.state.tts_wrapper.infer, text, ref_codes, ref_text)
            buffer = io.BytesIO()
            sf.write(buffer, wav_data, 24000, format='WAV')
            audio_b64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
            results.append({"text": text, "audio_data": audio_b64})
        return JSONResponse({"generated_clips": results})
    except Exception as e:
        # logger.exception preserves the traceback, which logger.error did not
        logger.exception("Batch synthesis failed: %s", e)
        raise HTTPException(status_code=500, detail=f"Batch synthesis failed: {str(e)}") from e