Spaces:
Paused
Paused
import os
import io
import base64
import json
import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from fastapi.responses import Response, JSONResponse
import soundfile as sf

from neutts_wrapper import NeuTTSWrapper

# --- Configuration & Global Objects ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Read device from environment variable, defaulting to 'cpu'
# (e.g. MODEL_DEVICE=cuda to select a GPU).
DEVICE = os.getenv("MODEL_DEVICE", "cpu")

# Use a ThreadPoolExecutor to run blocking ML code in a separate thread.
# max_workers=1 also serializes all inference calls, so requests queue up
# rather than run concurrently -- presumably because the model is not
# thread-safe (TODO confirm).
tts_executor = ThreadPoolExecutor(max_workers=1)
# --- Lifespan Management (Model Loading) ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Manage the model's lifecycle: load at startup, clean up at shutdown.

    BUG FIX: FastAPI's ``lifespan=`` parameter requires an async context
    manager. The ``@asynccontextmanager`` decorator was missing (it is
    imported at the top of the file but was never applied), so the bare
    async generator would not work as a lifespan handler.
    """
    logger.info("Application startup...")
    try:
        # Load the model wrapper into the application state.
        app.state.tts_wrapper = NeuTTSWrapper(device=DEVICE)
    except Exception as e:
        # Keep the server up in a degraded state so /health can report it.
        logger.error(f"FATAL: Model could not be loaded. {e}")
        app.state.tts_wrapper = None
    yield  # The application is now running.
    logger.info("Application shutdown...")
    # Drain any in-flight inference before the process exits.
    tts_executor.shutdown(wait=True)
# --- FastAPI App Initialization ---
# The lifespan handler loads the model at startup and releases resources
# at shutdown (see lifespan above).
app = FastAPI(
    title="NeuTTS Air Production API",
    description="Production-ready Text-to-Speech with Voice Cloning",
    version="2.0.0",
    lifespan=lifespan
)
# --- Helper function for running blocking code ---
async def run_in_executor(func, *args):
    """Run blocking *func(*args)* in the TTS thread pool without blocking the server.

    BUG FIX: ``asyncio.get_event_loop()`` is deprecated when called from a
    coroutine and can have surprising behavior across threads;
    ``asyncio.get_running_loop()`` is the correct call inside async code.
    """
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(tts_executor, func, *args)
# --- API Endpoints ---
# BUG FIX: none of the endpoint functions in this file carried a route
# decorator, so FastAPI never registered them. Paths below are the
# conventional choices -- TODO confirm against the intended client.
@app.get("/")
async def root():
    """Liveness endpoint: confirms the service process is running."""
    return {"status": "online", "service": "NeuTTS Air API v2"}
@app.get("/health")  # BUG FIX: route decorator was missing; path assumed -- confirm.
async def health_check():
    """Readiness endpoint: reports whether the TTS model actually loaded.

    Returns "degraded" when startup failed to load the model (the lifespan
    handler sets app.state.tts_wrapper to None in that case).
    """
    model_status = "loaded" if app.state.tts_wrapper else "degraded"
    return {"status": "healthy", "model_status": model_status, "device": DEVICE}
@app.post("/synthesize")  # BUG FIX: route decorator was missing; path assumed -- confirm.
async def synthesize_speech(
    ref_text: str = Form(...),
    gen_text: str = Form(...),
    ref_audio: UploadFile = File(...)
):
    """Clone the voice in *ref_audio* (transcribed by *ref_text*) and speak *gen_text*.

    Returns:
        Raw WAV bytes (audio/wav).

    Raises:
        HTTPException 503: model failed to load at startup.
        HTTPException 500: any failure during encoding/inference.
    """
    if not app.state.tts_wrapper:
        raise HTTPException(status_code=503, detail="Service unavailable: Model not loaded")
    try:
        ref_audio_bytes = await ref_audio.read()
        # Run blocking ML code in the thread pool so the event loop stays free.
        ref_codes = await run_in_executor(app.state.tts_wrapper.encode_reference, ref_audio_bytes)
        wav_data = await run_in_executor(app.state.tts_wrapper.infer, gen_text, ref_codes, ref_text)
        # Serialize to WAV entirely in memory; 24000 Hz is presumably the
        # model's native sample rate -- TODO confirm with NeuTTSWrapper.
        buffer = io.BytesIO()
        sf.write(buffer, wav_data, 24000, format='WAV')
        buffer.seek(0)
        return Response(content=buffer.read(), media_type="audio/wav")
    except Exception as e:
        # logger.exception includes the traceback, unlike the old logger.error.
        logger.exception("Synthesis failed")
        raise HTTPException(status_code=500, detail=f"Synthesis failed: {str(e)}") from e
@app.post("/synthesize/base64")  # BUG FIX: route decorator was missing; path assumed -- confirm.
async def synthesize_speech_base64(
    ref_text: str = Form(...),
    gen_text: str = Form(...),
    ref_audio: UploadFile = File(...)
):
    """Same as /synthesize, but returns the WAV as base64 inside a JSON body.

    Returns:
        JSON: {"audio_data": <base64 WAV>, "format": "wav"}.

    Raises:
        HTTPException 503: model failed to load at startup.
        HTTPException 500: any failure during encoding/inference.
    """
    if not app.state.tts_wrapper:
        raise HTTPException(status_code=503, detail="Service unavailable: Model not loaded")
    try:
        ref_audio_bytes = await ref_audio.read()
        # Run blocking ML code in the thread pool so the event loop stays free.
        ref_codes = await run_in_executor(app.state.tts_wrapper.encode_reference, ref_audio_bytes)
        wav_data = await run_in_executor(app.state.tts_wrapper.infer, gen_text, ref_codes, ref_text)
        # Serialize to WAV in memory, then base64-encode for JSON transport.
        buffer = io.BytesIO()
        sf.write(buffer, wav_data, 24000, format='WAV')
        buffer.seek(0)
        audio_b64 = base64.b64encode(buffer.read()).decode('utf-8')
        return JSONResponse({"audio_data": audio_b64, "format": "wav"})
    except Exception as e:
        # logger.exception includes the traceback, unlike the old logger.error.
        logger.exception("Base64 synthesis failed")
        raise HTTPException(status_code=500, detail=f"Base64 synthesis failed: {str(e)}") from e
@app.post("/synthesize/batch")  # BUG FIX: route decorator was missing; path assumed -- confirm.
async def batch_synthesize(
    ref_text: str = Form(...),
    ref_audio: UploadFile = File(...),
    texts: str = Form(...)
):
    """Synthesize several texts with one reference voice, encoding the reference once.

    Args:
        ref_text: Transcript of the reference audio.
        ref_audio: Reference voice sample to clone.
        texts: JSON-encoded array of strings to synthesize.

    Returns:
        JSON: {"generated_clips": [{"text": ..., "audio_data": <base64 WAV>}, ...]}.

    Raises:
        HTTPException 503: model failed to load at startup.
        HTTPException 400: *texts* is not valid JSON or not a JSON array.
        HTTPException 500: any failure during encoding/inference.
    """
    if not app.state.tts_wrapper:
        raise HTTPException(status_code=503, detail="Service unavailable: Model not loaded")

    # Validate input before the broad except below. BUG FIX: the original
    # raised ValueError for a non-list payload *inside* the big try, so the
    # generic handler swallowed it and reported a 500; a bad request is 400.
    try:
        text_list = json.loads(texts)
    except json.JSONDecodeError:
        raise HTTPException(status_code=400, detail="Invalid JSON in 'texts' field.")
    if not isinstance(text_list, list):
        raise HTTPException(status_code=400, detail="Texts must be a JSON array of strings.")

    try:
        ref_audio_bytes = await ref_audio.read()
        # Encode the reference voice once, in the thread pool.
        ref_codes = await run_in_executor(app.state.tts_wrapper.encode_reference, ref_audio_bytes)
        results = []
        for text in text_list:
            # Infer sequentially for each text (the executor has one worker).
            wav_data = await run_in_executor(app.state.tts_wrapper.infer, text, ref_codes, ref_text)
            buffer = io.BytesIO()
            sf.write(buffer, wav_data, 24000, format='WAV')
            buffer.seek(0)
            audio_b64 = base64.b64encode(buffer.read()).decode('utf-8')
            results.append({"text": text, "audio_data": audio_b64})
        return JSONResponse({"generated_clips": results})
    except Exception as e:
        # logger.exception includes the traceback, unlike the old logger.error.
        logger.exception("Batch synthesis failed")
        raise HTTPException(status_code=500, detail=f"Batch synthesis failed: {str(e)}") from e