Spaces:
Paused
Paused
| import os | |
| import io | |
| import asyncio | |
| import time | |
| import shutil | |
| import numpy as np | |
| import psutil | |
| import soundfile as sf | |
| import subprocess | |
| import tempfile | |
| from concurrent.futures import ThreadPoolExecutor | |
| from typing import Optional, Generator, AsyncGenerator | |
| from contextlib import asynccontextmanager | |
| import logging | |
| import aiofiles | |
| import torch | |
| from fastapi import FastAPI, HTTPException, UploadFile, File, Form, Query, BackgroundTasks | |
| from fastapi.responses import Response, StreamingResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel, Field | |
| import uuid | |
| from dataclasses import dataclass | |
| from queue import Queue, Empty | |
| import threading | |
| # Ensure the cloned neutts-air repository is in the path | |
| import sys | |
| sys.path.append(os.path.join(os.getcwd(), 'neutts-air')) | |
| from neuttsair.neutts import NeuTTSAir | |
# Configure logging
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger("NeuTTS-API")

# --- Configuration & Constants ---
DEVICE = "cpu"  # used for both the backbone and the codec of the model
MAX_WORKERS = 2  # size of the shared pool that runs all blocking work
tts_executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
SAMPLE_RATE = 24000  # Hz; ffmpeg resamples references to it and outputs are encoded at it
CLEANUP_THRESHOLD = 300  # seconds; both the cleanup period and the max file age
TEMP_AUDIO_DIR = "temp_audio"
GENERATED_AUDIO_DIR = "generated_audio"

# Working directories must exist before any request writes into them.
for _dir in (TEMP_AUDIO_DIR, GENERATED_AUDIO_DIR):
    os.makedirs(_dir, exist_ok=True)
del _dir
# --- Data Models ---
# JSON request schema for synthesis. Comments (not a docstring) are used on
# purpose: a class docstring would become the model's JSON-schema description.
# NOTE(review): no visible endpoint uses this model; the handlers below take
# Form fields instead.
class TTSRequestModel(BaseModel):
    # Text to synthesize: 1..1000 characters.
    text: str = Field(
        ...,
        min_length=1,
        max_length=1000,
    )
    # Accepted range 0.5..2.0, defaulting to 1.0.
    speed: float = Field(1.0, ge=0.5, le=2.0)
    # Output container; one of wav / mp3 / flac.
    output_format: str = Field("wav", pattern=r"^(wav|mp3|flac)$")
@dataclass
class SynthesisTask:
    """Record describing one synthesis job.

    Declared as a dataclass so the annotated fields become constructor
    arguments (and get ``__repr__``/``__eq__``); without the decorator the
    annotations alone define no attributes at all.

    NOTE(review): nothing in the code visible here creates these records;
    ``app.state.synthesis_tasks`` is initialised but never populated.
    """

    task_id: str
    text: str
    reference_audio_path: str
    reference_text: str
    output_format: str
    created_at: float  # presumably a time.time() timestamp — confirm with the producer
# --- Enhanced Audio Conversion with Async Support ---
async def convert_to_wav_async(input_path: str) -> str:
    """Transcode ``input_path`` to mono 16-bit PCM WAV at SAMPLE_RATE with ffmpeg.

    The result is written to a new temporary file inside TEMP_AUDIO_DIR; the
    caller owns that file and must delete it.

    Returns:
        Path of the converted WAV file.

    Raises:
        HTTPException: 400 if ffmpeg exits non-zero (detail carries ffmpeg's
            last stderr line), 504 if it does not finish within 30 seconds,
            500 for any other failure (e.g. ffmpeg missing). In every failure
            case the temporary output file is removed and ffmpeg is stopped.
    """
    with tempfile.NamedTemporaryFile(suffix=".wav", dir=TEMP_AUDIO_DIR, delete=False) as tmp:
        output_path = tmp.name
    logger.info(f"Converting '{os.path.basename(input_path)}' to WAV")
    command = [
        "ffmpeg", "-y", "-i", input_path,
        "-f", "wav", "-ar", str(SAMPLE_RATE),
        "-ac", "1", "-c:a", "pcm_s16le", output_path,
    ]

    def _discard_output():
        # Remove the (possibly partial) output so failures don't leak files.
        if os.path.exists(output_path):
            os.unlink(output_path)

    def _kill(proc):
        # Best-effort stop; the process may have exited on its own meanwhile.
        if proc is not None and proc.returncode is None:
            try:
                proc.kill()
            except ProcessLookupError:
                pass

    process = None
    try:
        process = await asyncio.create_subprocess_exec(
            *command,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        _, stderr = await asyncio.wait_for(process.communicate(), timeout=30)
    except asyncio.TimeoutError:
        logger.error("FFmpeg conversion timed out")
        # wait_for only cancels communicate(); ffmpeg itself keeps running
        # until it is killed and reaped here.
        if process is not None:
            _kill(process)
            await process.wait()
        _discard_output()
        raise HTTPException(status_code=504, detail="Audio conversion timed out")
    except asyncio.CancelledError:
        # Request cancelled mid-conversion: stop ffmpeg and drop partial output.
        _kill(process)
        _discard_output()
        raise
    except Exception as e:
        logger.error(f"Conversion error: {e}")
        _kill(process)
        _discard_output()
        raise HTTPException(status_code=500, detail="Unexpected conversion error") from e

    # Checked outside the try above so this 400 is not re-wrapped as a 500.
    if process.returncode != 0:
        # ffmpeg prints the actual error on its last stderr line.
        stderr_lines = stderr.decode(errors="replace").strip().splitlines() if stderr else []
        error_detail = stderr_lines[-1] if stderr_lines else "Unknown FFmpeg error"
        logger.error(f"FFmpeg conversion failed: {error_detail}")
        _discard_output()
        raise HTTPException(status_code=400, detail=f"Audio conversion failed: {error_detail}")

    logger.info("FFmpeg conversion successful")
    return output_path
# --- Enhanced Model Wrapper with Async Streaming ---
class NeuTTSWrapper:
    """Async facade over a single shared NeuTTSAir model instance.

    Blocking work (reference encoding, inference, audio encoding) runs on the
    module-level ``tts_executor`` thread pool. Every encode/infer call into
    ``self.tts_model`` is serialized with ``self._model_lock``, in both the
    one-shot and the streaming path, so two requests never use the model from
    two executor threads at the same time.
    """

    def __init__(self, device: str = "cpu"):
        """Store ``device`` and load the model eagerly (re-raises load errors)."""
        self.tts_model = None
        self.device = device
        # asyncio lock: acquired only from coroutines, around the executor
        # calls that touch the model.
        self._model_lock = asyncio.Lock()
        self.load_model()

    def load_model(self):
        """Instantiate NeuTTSAir with backbone and codec on ``self.device``.

        Raises:
            Exception: whatever NeuTTSAir raises, after it has been logged.
        """
        try:
            logger.info(f"Loading NeuTTSAir model on device: {self.device}")
            self.tts_model = NeuTTSAir(backbone_device=self.device, codec_device=self.device)
            logger.info("✅ NeuTTSAir model loaded successfully")
        except Exception as e:
            logger.error(f"❌ Model loading failed: {e}")
            raise

    def _convert_to_streamable_format(self, audio_data: np.ndarray, audio_format: str) -> bytes:
        """Encode ``audio_data`` at SAMPLE_RATE into an in-memory ``audio_format`` file.

        Each call produces a complete, self-contained file (``sf.write``
        writes the full container), including when used per streaming chunk.
        """
        audio_buffer = io.BytesIO()
        try:
            sf.write(audio_buffer, audio_data, SAMPLE_RATE, format=audio_format)
        except Exception as e:
            logger.error(f"Failed to write audio data to format {audio_format}: {e}")
            raise
        return audio_buffer.getvalue()

    def _split_text_into_chunks(self, text: str, max_chunk_length: int = 100) -> list[str]:
        """Greedily pack whitespace-separated words into chunks for streaming.

        Words are appended to the current chunk while it stays within
        ``max_chunk_length`` characters. The split is not sentence-aware, and a
        single word longer than the limit becomes its own (oversized) chunk.

        Returns:
            The list of chunks, or ``[text]`` when ``text`` contains no words.
        """
        chunks = []
        current_chunk = ""
        for word in text.split():
            candidate = f"{current_chunk} {word}".strip()
            if len(candidate) <= max_chunk_length:
                current_chunk = candidate
            else:
                if current_chunk:
                    chunks.append(current_chunk)
                current_chunk = word
        if current_chunk:
            chunks.append(current_chunk)
        return chunks or [text]

    async def generate_speech_async(self, text: str, ref_audio_path: str, reference_text: str) -> np.ndarray:
        """Synthesize ``text`` in one shot, holding the model lock throughout."""
        async with self._model_lock:
            return await asyncio.get_running_loop().run_in_executor(
                tts_executor,
                self._generate_speech_blocking,
                text, ref_audio_path, reference_text,
            )

    def _generate_speech_blocking(self, text: str, ref_audio_path: str, reference_text: str) -> np.ndarray:
        """Encode the reference, then synthesize ``text`` (runs in the thread pool)."""
        ref_s = self._encode_reference_blocking(ref_audio_path)
        return self._infer_chunk(text, ref_s, reference_text)

    async def stream_speech_async(
        self,
        text: str,
        ref_audio_path: str,
        reference_text: str,
        audio_format: str
    ) -> AsyncGenerator[bytes, None]:
        """Yield one encoded audio file per text chunk as soon as it is synthesized.

        The reference is encoded once; then each chunk from
        ``_split_text_into_chunks`` is inferred and encoded as a complete
        ``audio_format`` file. The model lock is taken separately for the
        reference encoding and for each chunk's inference, so concurrent
        requests interleave between chunks but never share the model at once.
        The lock is never held across ``yield``, so a slow or disconnected
        client cannot block other requests.
        """
        logger.info(f"Starting true streaming synthesis for text length: {len(text)}")
        loop = asyncio.get_running_loop()

        async with self._model_lock:
            ref_s = await loop.run_in_executor(
                tts_executor,
                self._encode_reference_blocking,
                ref_audio_path,
            )

        sentences = self._split_text_into_chunks(text)
        logger.info(f"Split text into {len(sentences)} chunks for streaming")

        for chunk_number, sentence in enumerate(sentences, start=1):
            if not sentence.strip():
                continue
            logger.debug(f"Generating streaming chunk {chunk_number}: '{sentence[:30]}...'")
            # Inference touches the model, so it must be serialized like the
            # one-shot path.
            async with self._model_lock:
                audio_chunk = await loop.run_in_executor(
                    tts_executor,
                    self._infer_chunk,
                    sentence, ref_s, reference_text,
                )
            # Encoding does not use the model, so it runs outside the lock.
            chunk_bytes = await loop.run_in_executor(
                tts_executor,
                self._convert_to_streamable_format,
                audio_chunk, audio_format,
            )
            yield chunk_bytes
            logger.debug(f"Yielded chunk {chunk_number} ({len(chunk_bytes)} bytes)")
        logger.info("Streaming synthesis complete")

    def _encode_reference_blocking(self, ref_audio_path: str):
        """Return ``encode_reference`` output for the clip (runs in the thread pool).

        The result is passed unchanged to ``infer``. It runs under
        ``torch.no_grad`` like inference; that is a no-op if the model already
        disables autograd internally.
        """
        with torch.no_grad():
            return self.tts_model.encode_reference(ref_audio_path)

    def _infer_chunk(self, sentence: str, ref_s, reference_text: str) -> np.ndarray:
        """Synthesize one text chunk without autograd (runs in the thread pool)."""
        with torch.no_grad():
            return self.tts_model.infer(sentence, ref_s, reference_text)
# --- Async Utility Functions ---
async def save_upload_file_async(upload_file: UploadFile) -> str:
    """Stream an uploaded file to a fresh, unique path inside TEMP_AUDIO_DIR.

    The client-supplied filename is untrusted (it may be None or contain
    ``../`` or ``\\`` components). Only a short alphanumeric extension is kept
    from it, so the file cannot be written outside TEMP_AUDIO_DIR. The rest of
    the name is a random UUID, so concurrent uploads never collide.

    Returns:
        Path of the saved file; the caller is responsible for deleting it.

    Raises:
        HTTPException: 500 if the upload cannot be written; any partially
            written file is removed first.
    """
    client_name = (upload_file.filename or "").replace("\\", "/")
    extension = os.path.splitext(os.path.basename(client_name))[1][:16]
    # Keep the extension only when it is a plain token like ".mp3".
    if not extension[1:].isalnum():
        extension = ""
    temp_filename = os.path.join(TEMP_AUDIO_DIR, f"{uuid.uuid4().hex}{extension}")
    try:
        async with aiofiles.open(temp_filename, 'wb') as out_file:
            # Copy in 1 MiB pieces so large uploads are never fully buffered.
            while content := await upload_file.read(1024 * 1024):
                await out_file.write(content)
    except Exception as e:
        logger.error(f"Error saving file: {e}")
        try:
            os.unlink(temp_filename)
        except OSError:
            pass  # nothing was written, or it is already gone
        raise HTTPException(status_code=500, detail="Could not save reference audio file") from e
    return temp_filename
async def cleanup_file_async(file_path: str):
    """Delete ``file_path`` if it exists.

    Never raises for ordinary failures: errors are logged as warnings, and a
    successful deletion is logged at debug level.
    """
    try:
        if not os.path.exists(file_path):
            return
        os.unlink(file_path)
    except Exception as exc:
        logger.warning(f"Failed to cleanup file {file_path}: {exc}")
    else:
        logger.debug(f"Cleaned up file: {file_path}")
async def scheduled_cleanup_task():
    """Run ``cleanup_files_async`` every CLEANUP_THRESHOLD seconds, forever.

    A failed pass is logged with its traceback and the loop continues.
    Cancelling the task stops it, because ``asyncio.CancelledError`` is not an
    ``Exception`` subclass (Python 3.8+) and so escapes the handler below.
    """
    while True:
        # The period is CLEANUP_THRESHOLD seconds, the same value used as the
        # maximum file age inside cleanup_files_async.
        await asyncio.sleep(CLEANUP_THRESHOLD)
        logger.info("Running scheduled cleanup of old audio files...")
        try:
            await cleanup_files_async()
        except Exception as e:
            logger.error(f"Scheduled cleanup task failed: {e}", exc_info=True)
# --- FastAPI Lifespan Manager ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model on startup and release resources on shutdown.

    Startup stores a ``NeuTTSWrapper`` on ``app.state.tts_wrapper``,
    initialises ``app.state.synthesis_tasks`` and starts the periodic cleanup
    task, whose handle is kept on ``app.state.cleanup_task``. If model
    construction fails, the executor is shut down and ``RuntimeError`` is
    raised, chained to the original error, which aborts startup.

    Shutdown always cancels the cleanup task, then shuts down
    ``tts_executor`` and waits for in-flight work.
    """
    try:
        app.state.tts_wrapper = NeuTTSWrapper(device=DEVICE)
        app.state.synthesis_tasks = {}  # Track active tasks
    except Exception as e:
        logger.error(f"Fatal startup error: {e}")
        tts_executor.shutdown(wait=False)
        raise RuntimeError("Model initialization failed") from e

    # Keep a strong reference: the event loop only holds weak references to
    # tasks, so an unreferenced task can be garbage-collected mid-run.
    cleanup_task = asyncio.create_task(scheduled_cleanup_task())
    app.state.cleanup_task = cleanup_task
    logger.info("✅ Application startup complete")
    try:
        yield
    finally:
        cleanup_task.cancel()
        try:
            await cleanup_task
        except asyncio.CancelledError:
            pass  # expected: we just cancelled it
        logger.info("Shutting down ThreadPoolExecutor")
        tts_executor.shutdown(wait=True)
# --- FastAPI Application Setup ---
# The application object; model loading and cleanup scheduling happen in
# `lifespan`.
app = FastAPI(
    lifespan=lifespan,
    title="NeuTTS Air Instant Cloning API - Enhanced",
    version="3.0.0-PROD-STREAMING",
    docs_url="/docs",
)

# Fully permissive CORS: any origin, method and header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
# --- Enhanced Endpoints ---
async def root():
    """Return a fixed banner identifying the API version."""
    banner = "NeuTTS Air API v3.0 - True Streaming Ready"
    return dict(message=banner)
async def health_check():
    """Report liveness together with model, concurrency and host resource figures.

    ``status`` is always "healthy"; whether the model is usable is reported
    separately in ``model_loaded``. Sizes are in GiB, rounded to 2 decimals.
    """
    mem = psutil.virtual_memory()
    disk = psutil.disk_usage('/')
    state = app.state
    gib = 1024 ** 3
    active_tasks = len(getattr(state, 'synthesis_tasks', {}))
    model_loaded = hasattr(state, 'tts_wrapper') and state.tts_wrapper.tts_model is not None
    memory_usage = {"total_gb": round(mem.total / gib, 2), "used_percent": mem.percent}
    disk_usage = {"total_gb": round(disk.total / gib, 2), "used_percent": disk.percent}
    return {
        "status": "healthy",
        "model_loaded": model_loaded,
        "device": DEVICE,
        "concurrency_limit": MAX_WORKERS,
        "active_synthesis_tasks": active_tasks,
        "memory_usage": memory_usage,
        "disk_usage": disk_usage,
    }
async def text_to_speech(
    text: str = Form(...),
    reference_text: str = Form(...),
    speed: float = Form(1.0, ge=0.5, le=2.0),
    output_format: str = Form("wav", pattern="^(wav|mp3|flac)$"),
    reference_audio: UploadFile = File(...),
    background_tasks: BackgroundTasks = None
):
    """Synthesize ``text`` in the voice of ``reference_audio`` and return the whole file.

    Steps:
        1. Save the upload.
        2. Transcode it to WAV.
        3. Run the model.
        4. Encode the result as ``output_format``.
        5. Persist a copy in GENERATED_AUDIO_DIR and return the bytes.

    Response headers report the processing time and the audio duration.
    Temporary reference files are always cleaned up, either via
    ``background_tasks`` or inline when none was injected.

    NOTE(review): ``speed`` is validated but never passed to the model.

    Raises:
        HTTPException: 503 if the model is not loaded; 4xx/5xx propagated from
            the upload/conversion helpers; 500 for any other failure.
    """
    if not hasattr(app.state, 'tts_wrapper'):
        raise HTTPException(status_code=503, detail="Service unavailable: Model not loaded")
    start_time = time.time()
    temp_ref_path = None
    converted_wav_path = None
    try:
        # 1. Save uploaded file
        temp_ref_path = await save_upload_file_async(reference_audio)
        # 2. Convert to WAV
        converted_wav_path = await convert_to_wav_async(temp_ref_path)
        # 3. Generate speech (the wrapper serializes model access)
        audio_data = await app.state.tts_wrapper.generate_speech_async(
            text, converted_wav_path, reference_text
        )
        # 4. Encode to the requested format off the event loop
        audio_bytes = await asyncio.get_running_loop().run_in_executor(
            tts_executor,
            app.state.tts_wrapper._convert_to_streamable_format,
            audio_data, output_format
        )
        # 5. Save to disk (optional - can be disabled in production).
        # The random suffix keeps concurrent requests finishing in the same
        # second from overwriting each other's output.
        audio_filename = f"tts_{int(time.time())}_{uuid.uuid4().hex[:8]}.{output_format}"
        final_path = os.path.join(GENERATED_AUDIO_DIR, audio_filename)
        async with aiofiles.open(final_path, 'wb') as f:
            await f.write(audio_bytes)
        processing_time = time.time() - start_time
        audio_duration = len(audio_data) / SAMPLE_RATE
        media_subtype = 'mpeg' if output_format == 'mp3' else output_format
        return Response(
            content=audio_bytes,
            media_type=f"audio/{media_subtype}",
            headers={
                "Content-Disposition": f"attachment; filename={audio_filename}",
                "X-Processing-Time": f"{processing_time:.2f}s",
                "X-Audio-Duration": f"{audio_duration:.2f}s",
                "X-First-Chunk-Time": f"{processing_time:.2f}s"  # For comparison
            }
        )
    except HTTPException as e:
        logger.error(f"Synthesis error: {e}")
        raise
    except Exception as e:
        logger.error(f"Synthesis error: {e}")
        raise HTTPException(status_code=500, detail=f"Synthesis failed: {e}") from e
    finally:
        leftovers = [path for path in (temp_ref_path, converted_wav_path) if path]
        if background_tasks is not None:
            # Delete after the response has been sent.
            for path in leftovers:
                background_tasks.add_task(cleanup_file_async, path)
        else:
            # cleanup_file_async logs failures instead of raising, so this
            # cannot mask an exception already propagating from the try.
            for path in leftovers:
                await cleanup_file_async(path)
async def stream_text_to_speech(
    text: str = Form(..., min_length=1, max_length=5000),
    reference_text: str = Form(...),
    speed: float = Form(1.0, ge=0.5, le=2.0),
    output_format: str = Form("mp3", pattern="^(wav|mp3|flac)$"),
    reference_audio: UploadFile = File(...)
):
    """Stream synthesized speech chunk by chunk as the model produces it.

    Before the response starts, the reference upload is saved and transcoded
    to WAV, and the raw upload is deleted. The transcoded WAV is deleted by
    the stream generator once streaming ends, whether normally or not.

    NOTE(review): every yielded chunk is a complete, independently encoded
    file (one ``sf.write`` per chunk). Concatenated MP3 typically still plays,
    but concatenated WAV/FLAC files are likely not decodable as one file
    beyond the first chunk — confirm client behavior before relying on them.
    NOTE(review): ``speed`` is validated but never used.

    Raises:
        HTTPException: 503 if the model is not loaded; 4xx/5xx propagated from
            the upload/conversion helpers; 500 for any other setup failure.
    """
    if not hasattr(app.state, 'tts_wrapper'):
        raise HTTPException(status_code=503, detail="Service unavailable: Model not loaded")
    temp_ref_path = None
    converted_wav_path = None
    try:
        # 1. Save and convert reference audio
        temp_ref_path = await save_upload_file_async(reference_audio)
        converted_wav_path = await convert_to_wav_async(temp_ref_path)
        # 2. The raw upload is no longer needed once transcoded
        if temp_ref_path and os.path.exists(temp_ref_path):
            await cleanup_file_async(temp_ref_path)
            temp_ref_path = None

        # 3. Create async generator for streaming
        async def generate_audio_stream():
            """Yield encoded chunks as produced; always delete the transcoded reference."""
            try:
                stream_start = time.time()
                chunk_count = 0
                async for chunk_bytes in app.state.tts_wrapper.stream_speech_async(
                    text, converted_wav_path, reference_text, output_format
                ):
                    chunk_count += 1
                    if chunk_count == 1:
                        first_chunk_latency = time.time() - stream_start
                        logger.info(f"First audio chunk delivered in {first_chunk_latency:.2f}s")
                    yield chunk_bytes
            except Exception as e:
                logger.error(f"Stream generation error: {e}")
                raise
            finally:
                if converted_wav_path and os.path.exists(converted_wav_path):
                    await cleanup_file_async(converted_wav_path)

        # 4. Return streaming response
        media_subtype = 'mpeg' if output_format == 'mp3' else output_format
        return StreamingResponse(
            generate_audio_stream(),
            media_type=f"audio/{media_subtype}",
            headers={
                # Name the download after the format actually produced.
                "Content-Disposition": f"attachment; filename=tts_live_stream.{output_format}",
                "Transfer-Encoding": "chunked",
                "Cache-Control": "no-cache",
                "X-Accel-Buffering": "no",
                "X-Streaming": "true"
            }
        )
    except Exception as e:
        logger.error(f"Streaming setup error: {e}")
        # Cleanup on error
        if temp_ref_path and os.path.exists(temp_ref_path):
            await cleanup_file_async(temp_ref_path)
        if converted_wav_path and os.path.exists(converted_wav_path):
            await cleanup_file_async(converted_wav_path)
        if isinstance(e, HTTPException):
            raise
        raise HTTPException(status_code=500, detail=f"Streaming setup failed: {e}") from e
async def get_audio(filename: str):
    """Return a previously generated file from GENERATED_AUDIO_DIR as an attachment.

    ``filename`` comes from the client. Only a bare file name is accepted:
    path separators, ``.``, ``..`` and the empty string are rejected, so
    lookups cannot escape GENERATED_AUDIO_DIR. MP3 files are served as
    ``audio/mpeg``, matching the synthesis endpoints.

    Raises:
        HTTPException: 404 if the name is invalid or no such regular file
            exists (including one removed while it is being served).
    """
    if not filename or filename in (".", "..") or "/" in filename or "\\" in filename:
        raise HTTPException(status_code=404, detail="Audio file not found")
    file_path = os.path.join(GENERATED_AUDIO_DIR, filename)
    # isfile rather than exists: a directory must not count as "found".
    if not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="Audio file not found")
    try:
        async with aiofiles.open(file_path, "rb") as f:
            content = await f.read()
    except FileNotFoundError:
        # Deleted (e.g. by the periodic cleanup) between the check and the open.
        raise HTTPException(status_code=404, detail="Audio file not found") from None
    extension = filename.split('.')[-1]
    media_subtype = "mpeg" if extension == "mp3" else extension
    return Response(
        content=content,
        media_type=f"audio/{media_subtype}",
        headers={"Content-Disposition": f"attachment; filename={filename}"}
    )
async def cleanup_files():
    """Run the expired-file cleanup on demand and report how many files were removed."""
    removed = await cleanup_files_async()
    message = f"Cleanup completed: {removed} files removed"
    return {"message": message}
async def cleanup_files_async():
    """Delete expired files from GENERATED_AUDIO_DIR and TEMP_AUDIO_DIR.

    A regular file is expired when its ctime is more than CLEANUP_THRESHOLD
    seconds ago. Only the top level of each directory is scanned, and missing
    directories are skipped. A file that disappears concurrently (e.g. removed
    by a request's own cleanup) is skipped silently; other per-file errors are
    logged as warnings.

    Returns:
        int: the number of files actually removed by this call.
    """
    now = time.time()
    deleted_count = 0
    for directory in (GENERATED_AUDIO_DIR, TEMP_AUDIO_DIR):
        try:
            with os.scandir(directory) as scan:
                entries = list(scan)
        except FileNotFoundError:
            continue
        for entry in entries:
            try:
                if not entry.is_file() or now - entry.stat().st_ctime <= CLEANUP_THRESHOLD:
                    continue
                os.unlink(entry.path)
            except FileNotFoundError:
                continue  # already gone; not a failure and not ours to count
            except OSError as e:
                logger.warning(f"Failed to delete {entry.path}: {e}")
                continue
            # Counted only after unlink succeeded.
            deleted_count += 1
            logger.debug(f"Cleaned up file: {entry.path}")
    logger.info(f"Cleanup completed: {deleted_count} files removed")
    return deleted_count
# Performance monitoring endpoint
async def get_metrics():
    """Report thread count, pending executor work items and process RSS.

    ``executor_queue_size`` reads ThreadPoolExecutor's private ``_work_queue``
    and falls back to 0 when that attribute is absent. ``memory_usage_mb`` is
    RSS divided by 1024*1024 (i.e. MiB).
    """
    active_threads = threading.active_count()
    work_queue = getattr(tts_executor, '_work_queue', None)
    pending_items = work_queue.qsize() if work_queue is not None else 0
    rss_mib = psutil.Process().memory_info().rss / 1024 / 1024
    return {
        "active_threads": active_threads,
        "executor_queue_size": pending_items,
        "memory_usage_mb": rss_mib,
    }
if __name__ == "__main__":
    import uvicorn

    # Exactly one worker: the model lives in this process's memory, and the
    # service is not designed for several worker processes.
    # NOTE(review): "app:app" assumes this module is importable as `app`.
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=7860,
        loop="asyncio",
        workers=1,
        access_log=True,
    )