# NOTE: "Spaces: Paused" below is a hosting-page status header captured by the
# scrape, not part of the application source. Kept as a comment so the file parses.
# Spaces: Paused / Paused
# app.py
import asyncio
import io
import logging
import os
import subprocess
import sys
import time
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager
from types import MethodType

import librosa  # Needed for monkey-patching encode_reference
import numpy as np
import psutil
import soundfile as sf
import torch
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response, StreamingResponse

# This works because the Dockerfile clones the neutts-air repo; add it to the
# path so `neuttsair` resolves. Must run before the import below.
sys.path.append(os.path.join(os.getcwd(), 'neutts-air'))
from neuttsair.neutts import NeuTTSAir
# --- Configuration & Logging ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("NeuTTS-GGUF-API")

# Production-ready configuration via Environment Variables.
BACKBONE_MODEL_PATH = os.getenv("BACKBONE_MODEL_PATH", "/app/models/neutts-air.gguf")
CODEC_REPO = os.getenv("CODEC_REPO", "neuphonic/neucodec-onnx-decoder")  # Using ONNX for performance
DEVICE = "cpu"  # llama-cpp handles its own device (CPU/GPU) management

# Worker threads for offloading blocking model calls; shared by all endpoints.
MAX_WORKERS = int(os.getenv("MAX_WORKERS", "2"))
tts_executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)

# Output sample rate passed to sf.write when encoding synthesized audio.
# Presumably the model's native output rate — confirm against NeuTTSAir docs.
SAMPLE_RATE = 24000
# --- Core Utility Functions ---
async def convert_to_wav_in_memory(upload_file: UploadFile) -> io.BytesIO:
    """Convert uploaded audio to a 16 kHz mono PCM WAV, entirely in memory.

    Pipes the upload through ffmpeg (stdin -> stdout) so no temp files are
    written and the event loop is never blocked.

    Args:
        upload_file: Client-supplied reference audio in any ffmpeg-readable format.

    Returns:
        io.BytesIO containing the converted WAV bytes.

    Raises:
        HTTPException(400): if ffmpeg fails to decode or convert the input.
    """
    ffmpeg_command = [
        "ffmpeg", "-i", "pipe:0", "-f", "wav", "-ar", "16000",
        "-ac", "1", "-c:a", "pcm_s16le", "pipe:1"
    ]
    proc = await asyncio.create_subprocess_exec(
        *ffmpeg_command, stdin=subprocess.PIPE,
        stdout=subprocess.PIPE, stderr=subprocess.PIPE
    )
    wav_data, stderr_data = await proc.communicate(input=await upload_file.read())
    if proc.returncode != 0:
        # errors="replace" guards against non-UTF-8 bytes in ffmpeg's stderr.
        error_message = stderr_data.decode(errors="replace")
        logger.error(f"In-memory conversion failed: {error_message}")
        # Surface only the last (most relevant) stderr line; the original
        # splitlines()[-1] raised IndexError when stderr was empty.
        error_lines = error_message.strip().splitlines()
        error_detail = error_lines[-1] if error_lines else "unknown ffmpeg error"
        raise HTTPException(status_code=400, detail=f"Audio conversion failed: {error_detail}")
    return io.BytesIO(wav_data)
async def run_blocking_task_async(func, *args, **kwargs):
    """Offload a blocking function call to the shared ThreadPoolExecutor.

    Keeps the event loop responsive while model inference / encoding runs in a
    worker thread.

    Args:
        func: The blocking callable to execute.
        *args, **kwargs: Forwarded to ``func``.

    Returns:
        Whatever ``func`` returns.
    """
    # get_running_loop() is the correct call inside a coroutine;
    # get_event_loop() is deprecated in this context since Python 3.10.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(tts_executor, lambda: func(*args, **kwargs))
# --- Model Wrapper and Professional Integration ---
def _encode_reference_from_memory(self, ref_audio: io.BytesIO):
    """Drop-in replacement for NeuTTSAir.encode_reference.

    Reads the reference audio from an in-memory BytesIO object instead of a
    file path, which avoids disk round-trips in the API hot path.
    """
    waveform, _sr = librosa.load(ref_audio, sr=16000, mono=True)
    # Shape to (1, 1, samples) — the same layout the original patch produced.
    batch = torch.from_numpy(waveform).float().unsqueeze(0).unsqueeze(0)
    with torch.no_grad():
        codes = self.codec.encode_code(audio_or_path=batch)
    return codes.squeeze(0).squeeze(0)
class NeuTTSWrapper:
    """Owns the NeuTTSAir GGUF model instance plus audio-format helpers."""

    def __init__(self):
        # Model handle; populated by load_model() (raises on failure).
        self.tts_model: NeuTTSAir | None = None
        self.load_model()

    def load_model(self):
        """Instantiate NeuTTSAir and patch it for in-memory reference encoding."""
        try:
            logger.info(f"Loading NeuTTSAir GGUF model from: {BACKBONE_MODEL_PATH}")
            self.tts_model = NeuTTSAir(
                backbone_repo=BACKBONE_MODEL_PATH,
                codec_repo=CODEC_REPO,
                backbone_device=DEVICE,
                codec_device=DEVICE,
            )
            # Monkey-patch: swap the library's file-based encode_reference for
            # our BytesIO-based variant without modifying its source code.
            patched = MethodType(_encode_reference_from_memory, self.tts_model)
            self.tts_model.encode_reference = patched
            logger.info("✅ NeuTTSAir GGUF model loaded and patched successfully.")
        except Exception as e:
            logger.error(f"❌ Model loading failed: {e}", exc_info=True)
            raise

    def convert_to_streamable_format(self, audio_data: np.ndarray, audio_format: str) -> bytes:
        """Encode a NumPy waveform as bytes in the requested container format."""
        buffer = io.BytesIO()
        sf.write(buffer, audio_data, SAMPLE_RATE, format=audio_format)
        return buffer.getvalue()
# --- FastAPI Application Setup ---
@asynccontextmanager
async def lifespan(app: "FastAPI"):
    """FastAPI lifespan: load the model on startup, drain the executor on shutdown.

    The @asynccontextmanager decorator is required — FastAPI's ``lifespan=``
    parameter expects an async context manager factory, not a bare async
    generator function; without it the app fails at startup.

    Raises:
        RuntimeError: if the model cannot be loaded (chained from the cause).
    """
    try:
        app.state.tts_wrapper = NeuTTSWrapper()
    except Exception as e:
        logger.error(f"Fatal startup error: Model could not be loaded. {e}")
        # Release worker threads if the model never loads, so the process can exit.
        tts_executor.shutdown(wait=False, cancel_futures=True)
        raise RuntimeError("Model initialization failed. Application cannot start.") from e
    yield
    logger.info("Shutting down ThreadPoolExecutor.")
    tts_executor.shutdown(wait=True)
# Application instance; model lifecycle is managed by the lifespan handler.
app = FastAPI(
    title="NeuTTS Air GGUF Cloning API",
    version="3.0.0-PROD-GGUF",
    lifespan=lifespan
)
# NOTE(review): CORS is wide open (any origin/method/header) — restrict for production.
app.add_middleware(
    CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"],
)
# --- API Endpoints ---
@app.get("/")
async def root():
    """Liveness message for the service root.

    The route decorator was missing; without it this handler is never
    registered with the application.
    """
    return {"message": "NeuTTS Air GGUF API - Ready for High-Speed Voice Cloning"}
@app.get("/health")
async def health_check():
    """Report service health: model-load status, model paths, and host memory.

    The route decorator was missing; without it this handler is never
    registered with the application.
    """
    mem = psutil.virtual_memory()
    return {
        "status": "healthy",
        # hasattr guard: app.state.tts_wrapper only exists after a successful startup.
        "model_loaded": hasattr(app.state, 'tts_wrapper') and app.state.tts_wrapper.tts_model is not None,
        "model_type": "GGUF",
        "backbone_path": BACKBONE_MODEL_PATH,
        "codec_repo": CODEC_REPO,
        "memory_usage_percent": mem.percent
    }
@app.post("/tts")
async def text_to_speech(
    text: str = Form(...),
    reference_text: str = Form(...),
    output_format: str = Form("wav", pattern="^(wav|mp3|flac)$"),
    reference_audio: UploadFile = File(...)
):
    """Standard blocking TTS endpoint optimized for GGUF.

    Pipeline: convert reference audio -> encode reference codes -> synthesize
    -> encode to the requested container format. All blocking model work runs
    on the thread pool.

    Returns:
        Response with the synthesized audio bytes and an X-Processing-Time header.

    Raises:
        HTTPException(400): invalid reference audio (from conversion).
        HTTPException(500): any other synthesis failure.
    """
    start_time = time.time()
    try:
        converted_wav_buffer = await convert_to_wav_in_memory(reference_audio)
        ref_codes = await run_blocking_task_async(
            app.state.tts_wrapper.tts_model.encode_reference,
            converted_wav_buffer
        )
        audio_data = await run_blocking_task_async(
            app.state.tts_wrapper.tts_model.infer,
            text, ref_codes, reference_text
        )
        audio_bytes = await run_blocking_task_async(
            app.state.tts_wrapper.convert_to_streamable_format,
            audio_data, output_format
        )
        processing_time = time.time() - start_time
        return Response(
            content=audio_bytes,
            media_type=f"audio/{'mpeg' if output_format == 'mp3' else output_format}",
            headers={"X-Processing-Time": f"{processing_time:.2f}s"}
        )
    except HTTPException:
        # Preserve deliberate HTTP errors (e.g. the 400 from audio conversion)
        # instead of masking them as 500s — the original blanket handler
        # converted every HTTPException into a 500.
        raise
    except Exception as e:
        logger.error(f"Synthesis error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="An internal error occurred during synthesis.")
@app.post("/tts/stream")
async def stream_text_to_speech_cloning(
    text: str = Form(..., min_length=1),
    reference_text: str = Form(...),
    output_format: str = Form("mp3", pattern="^(wav|mp3|flac)$"),
    reference_audio: UploadFile = File(...)
):
    """High-performance, sentence-by-sentence streaming using the GGUF backend.

    Reference audio is prepared up front; chunks from the model's blocking
    ``infer_stream`` generator are produced on the thread pool and relayed to
    the client through an asyncio.Queue.

    Raises:
        HTTPException(400): invalid reference audio (propagated from conversion).
        HTTPException(500): reference encoding failure.
    """
    try:
        converted_wav_buffer = await convert_to_wav_in_memory(reference_audio)
        ref_codes = await run_blocking_task_async(
            app.state.tts_wrapper.tts_model.encode_reference,
            converted_wav_buffer
        )
    except HTTPException:
        # Keep the original status (e.g. 400 for bad audio) instead of a blanket 500.
        raise
    except Exception as e:
        logger.error(f"Error during pre-processing for stream: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="Failed to prepare reference audio for streaming.")

    async def stream_generator():
        # infer_stream is a blocking generator: run it in the executor and hand
        # chunks to the event loop via a thread-safe queue bridge.
        loop = asyncio.get_running_loop()  # preferred over deprecated get_event_loop()
        queue: asyncio.Queue = asyncio.Queue()

        def producer():
            try:
                # This loop blocks in the worker thread, not the event loop.
                for audio_chunk in app.state.tts_wrapper.tts_model.infer_stream(text, ref_codes, reference_text):
                    # NOTE(review): each chunk is encoded as a standalone file, so
                    # every chunk carries its own container header. Fine for mp3
                    # concatenation; strict wav/flac players may object — confirm.
                    chunk_bytes = app.state.tts_wrapper.convert_to_streamable_format(audio_chunk, output_format)
                    loop.call_soon_threadsafe(queue.put_nowait, chunk_bytes)
            except Exception as e:
                logger.error(f"Error in streaming producer thread: {e}", exc_info=True)
                # Ship the exception across the thread boundary for the consumer to raise.
                loop.call_soon_threadsafe(queue.put_nowait, e)
            finally:
                loop.call_soon_threadsafe(queue.put_nowait, None)  # end-of-stream sentinel

        producer_task = loop.run_in_executor(tts_executor, producer)

        # Consumer: runs on the event loop, yields chunks as they arrive.
        while True:
            item = await queue.get()
            if item is None:
                break
            if isinstance(item, Exception):
                raise item
            yield item
        await producer_task  # ensure the producer thread has fully finished

    return StreamingResponse(
        stream_generator(),
        media_type=f"audio/{'mpeg' if output_format == 'mp3' else output_format}"
    )