Spaces:
Paused
Paused
| import os | |
| import io | |
| import asyncio | |
| import time | |
| import shutil | |
| import numpy as np | |
| import psutil | |
| import soundfile as sf | |
| import subprocess | |
| import tempfile | |
| from concurrent.futures import ThreadPoolExecutor | |
| from typing import Optional, Generator | |
| from contextlib import asynccontextmanager | |
| import logging | |
| import aiofiles | |
| import torch | |
| from fastapi import FastAPI, HTTPException, UploadFile, File, Form, Query | |
| from fastapi.responses import Response, StreamingResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel, Field | |
| # Ensure the cloned neutts-air repository is in the path | |
| import sys | |
| sys.path.append(os.path.join(os.getcwd(), 'neutts-air')) | |
| from neuttsair.neutts import NeuTTSAir | |
# --- Logging ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("NeuTTS-API")

# --- Configuration & Utility Functions ---
# Inference is pinned to the CPU (per the Dockerfile and Hugging Face free tier).
DEVICE = "cpu"

# Upper bound on concurrent synthesis threads (1-2 is safe for CPU-only hosts).
MAX_WORKERS = 2
tts_executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)

# Sample rate in Hz used when converting reference audio and writing output.
SAMPLE_RATE = 24000

# Files older than this many seconds are eligible for cleanup (1 hour).
CLEANUP_THRESHOLD = 60 * 60

# Working directories; both are created eagerly at import time.
TEMP_AUDIO_DIR = "temp_audio"
GENERATED_AUDIO_DIR = "generated_audio"
for _audio_dir in (TEMP_AUDIO_DIR, GENERATED_AUDIO_DIR):
    os.makedirs(_audio_dir, exist_ok=True)
class TTSRequestModel(BaseModel):
    """Validated non-file parameters for the synthesis and streaming requests."""

    # Text to synthesize: required, between 1 and 1000 characters.
    text: str = Field(default=..., min_length=1, max_length=1000)
    # Speed factor, constrained to the inclusive range [0.5, 2.0].
    speed: float = Field(1.0, ge=0.5, le=2.0)
    # Output container; the pattern admits exactly "wav", "mp3" or "flac".
    output_format: str = Field("wav", pattern="^(wav|mp3|flac)$")
def convert_to_wav_blocking(input_path: str) -> str:
    """Convert an uploaded audio file to a mono, 16-bit PCM WAV using FFmpeg.

    The output is written to a uniquely named ``.wav`` file inside
    ``TEMP_AUDIO_DIR`` at ``SAMPLE_RATE`` Hz. This call blocks on a subprocess
    and is meant to be run on a worker thread, not on the event loop.

    Args:
        input_path: Path of the uploaded file in any format FFmpeg can decode.

    Returns:
        Path of the converted WAV file. The caller owns it and must delete it.

    Raises:
        HTTPException: 400 when FFmpeg exits with an error, 504 when the
            conversion exceeds 30 seconds, 500 for any other failure. The
            partially written output file is removed in every failure case.
    """
    # Reserve a unique output path. The handle is closed immediately so that
    # FFmpeg can (over)write the file itself; delete=False keeps it on disk.
    with tempfile.NamedTemporaryFile(suffix=".wav", dir=TEMP_AUDIO_DIR, delete=False) as tmp:
        output_path = tmp.name

    logger.info(f"Converting '{os.path.basename(input_path)}' to WAV (24kHz, mono) at {os.path.basename(output_path)}")

    # FFmpeg command details:
    #   -y               overwrite the (already reserved) output file
    #   -i               input file path
    #   -f wav           output container is WAV
    #   -ar              output sample rate
    #   -ac 1            downmix to a single (mono) channel
    #   -c:a pcm_s16le   uncompressed 16-bit little-endian PCM
    command = [
        "ffmpeg",
        "-y",
        "-i", input_path,
        "-f", "wav",
        "-ar", str(SAMPLE_RATE),
        "-ac", "1",
        "-c:a", "pcm_s16le",
        output_path,
    ]

    def _discard_output() -> None:
        # Best-effort removal of the reserved/partial output file.
        try:
            os.unlink(output_path)
        except FileNotFoundError:
            pass

    try:
        # stdin is detached because FFmpeg enables interactive stdin handling
        # by default; the timeout guards against runaway conversions.
        subprocess.run(
            command,
            check=True,
            capture_output=True,
            text=True,
            timeout=30,
            stdin=subprocess.DEVNULL,
        )
    except subprocess.CalledProcessError as e:
        logger.error(f"FFmpeg conversion failed: {e.stderr}")
        _discard_output()
        # Surface only the last line of FFmpeg's stderr to the client.
        error_detail = e.stderr.splitlines()[-1] if e.stderr else "Unknown FFmpeg error."
        raise HTTPException(status_code=400, detail=f"Audio format conversion failed: {error_detail}") from e
    except subprocess.TimeoutExpired as e:
        logger.error("FFmpeg conversion timed out.")
        _discard_output()
        raise HTTPException(status_code=504, detail="Audio conversion timed out after 30 seconds.") from e
    except Exception as e:
        # E.g. the ffmpeg binary is missing from PATH.
        logger.error(f"General conversion error: {e}")
        _discard_output()
        raise HTTPException(status_code=500, detail="An unexpected error occurred during audio conversion.") from e
    else:
        logger.info("FFmpeg conversion successful.")
        return output_path
| # --- Model Wrapper and Logic --- | |
class NeuTTSWrapper:
    """Wrapper around a ``NeuTTSAir`` model for voice-cloning synthesis.

    The model is loaded eagerly by ``__init__``. All synthesis methods are
    blocking and are intended to be offloaded to a worker thread.
    """

    def __init__(self, device: str = "cpu"):
        """Store the target device and load the model immediately."""
        self.tts_model = None
        self.device = device
        self.load_model()

    def load_model(self):
        """Instantiate ``NeuTTSAir`` on ``self.device``; log and re-raise on failure."""
        try:
            logger.info(f"Loading NeuTTSAir model on device: {self.device}")
            # Backbone and codec are both placed on the configured device.
            self.tts_model = NeuTTSAir(backbone_device=self.device, codec_device=self.device)
            logger.info("✅ NeuTTSAir model loaded successfully.")
        except Exception as e:
            logger.error(f"❌ Model loading failed: {e}")
            raise

    @staticmethod
    def _to_numpy(audio) -> np.ndarray:
        """Return model output as a NumPy array, whether tensor or array-like."""
        if isinstance(audio, torch.Tensor):
            return audio.detach().cpu().numpy()
        return np.asarray(audio)

    def _convert_to_streamable_format(self, audio_data: np.ndarray, audio_format: str) -> bytes:
        """Encode a NumPy audio array as bytes in ``audio_format`` at ``SAMPLE_RATE``.

        Raises:
            Exception: whatever ``soundfile.write`` raises; it is logged first.
        """
        audio_buffer = io.BytesIO()
        try:
            sf.write(audio_buffer, audio_data, SAMPLE_RATE, format=audio_format)
        except Exception as e:
            logger.error(f"Failed to write audio data to format {audio_format}: {e}")
            raise
        return audio_buffer.getvalue()

    def _split_text_into_chunks(self, text: str) -> list[str]:
        """Split ``text`` on '.' into stripped, non-empty sentences.

        The separators are dropped. When nothing non-empty remains, the
        stripped input is returned as the single chunk.
        """
        sentences = [s.strip() for s in text.split('.') if s.strip()]
        if not sentences:
            sentences = [text.strip()]
        return sentences

    def generate_speech_blocking(self, text: str, ref_audio_path: str, reference_text: str) -> np.ndarray:
        """Synthesize ``text`` in one pass using the reference voice (blocking).

        Args:
            text: Text to synthesize.
            ref_audio_path: Path of the reference audio passed to
                ``encode_reference``.
            reference_text: Text forwarded to ``infer`` alongside the encoded
                reference (presumably the transcript of the reference audio;
                confirm against the NeuTTSAir API).

        Returns:
            The synthesized audio as a NumPy array.
        """
        ref_s = self.tts_model.encode_reference(ref_audio_path)
        with torch.no_grad():
            audio = self.tts_model.infer(text, ref_s, reference_text)
        return self._to_numpy(audio)

    def stream_speech_blocking(
        self,
        text: str,
        ref_audio_path: str,
        reference_text: str,
        speed: float,
        audio_format: str,
    ) -> Generator[bytes, None, None]:
        """Synthesize ``text`` sentence by sentence, yielding encoded audio (blocking).

        Each yielded item comes from its own ``_convert_to_streamable_format``
        call, so every chunk is encoded independently of the others.

        Args:
            text: Text to synthesize; it is split with
                ``_split_text_into_chunks``.
            ref_audio_path: Path of the reference audio passed to
                ``encode_reference``.
            reference_text: Text forwarded to ``infer`` alongside the encoded
                reference.
            speed: Accepted for interface parity with the endpoints; this
                method does not apply it.
            audio_format: Output format name handed to ``soundfile.write``.

        Yields:
            Encoded audio bytes, one item per non-empty sentence.
        """
        logger.info(f"Starting streaming synthesis for text length: {len(text)}")
        # The reference is encoded once and reused for every sentence.
        ref_s = self.tts_model.encode_reference(ref_audio_path)
        sentences = self._split_text_into_chunks(text)
        for i, sentence in enumerate(sentences):
            if not sentence.strip():
                continue
            logger.debug(f"Generating streaming chunk {i+1}: '{sentence[:30]}...'")
            with torch.no_grad():
                audio_chunk = self.tts_model.infer(sentence, ref_s, reference_text)
            yield self._convert_to_streamable_format(self._to_numpy(audio_chunk), audio_format)
        logger.info("Streaming synthesis complete.")
| # --- Asynchronous Offloading --- | |
async def run_blocking_task_async(func, *args, **kwargs):
    """Run a blocking callable on ``tts_executor`` and await its result.

    Args:
        func: The blocking callable to execute.
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        Whatever ``func`` returns. Exceptions raised by ``func`` propagate to
        the awaiting coroutine.
    """
    # get_running_loop() is the documented way to obtain the loop from inside
    # a coroutine; it never creates a new loop implicitly.
    loop = asyncio.get_running_loop()
    # run_in_executor only forwards positional arguments, hence the closure.
    return await loop.run_in_executor(
        tts_executor,
        lambda: func(*args, **kwargs)
    )
async def save_upload_file_async(upload_file: UploadFile) -> str:
    """Save an uploaded file into ``TEMP_AUDIO_DIR`` and return its path.

    The file is copied in 1 MiB chunks. The stored name is the current
    timestamp followed by the sanitized client filename.

    Args:
        upload_file: The incoming upload.

    Returns:
        Path of the saved file. The caller is responsible for deleting it.

    Raises:
        HTTPException: 500 when the file cannot be written; any partially
            written file is removed first.
    """
    # The filename is client-controlled: keep only its last path component so
    # that values such as "../../x" cannot escape TEMP_AUDIO_DIR. Backslashes
    # are normalized first so Windows-style separators are handled too.
    raw_name = (upload_file.filename or "").replace("\\", "/")
    safe_name = os.path.basename(raw_name) or "upload"
    temp_filename = os.path.join(TEMP_AUDIO_DIR, f"{time.time()}_{safe_name}")
    try:
        async with aiofiles.open(temp_filename, 'wb') as out_file:
            while content := await upload_file.read(1024 * 1024):
                await out_file.write(content)
    except Exception as e:
        logger.error(f"Error saving file: {e}")
        # Do not leave a truncated upload behind.
        try:
            os.unlink(temp_filename)
        except OSError as cleanup_error:
            logger.warning(f"Could not remove partial upload {temp_filename}: {cleanup_error}")
        raise HTTPException(status_code=500, detail="Could not save reference audio file") from e
    return temp_filename
# --- FastAPI Lifespan Manager (Kokoro Feature) ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model on startup and shut the worker pool down on exit.

    Args:
        app: The FastAPI application; the wrapper is stored on
            ``app.state.tts_wrapper``.

    Raises:
        RuntimeError: when the model cannot be initialized. The executor is
            shut down before the error is raised.
    """
    try:
        app.state.tts_wrapper = NeuTTSWrapper(device=DEVICE)
    except Exception as e:
        logger.error(f"Fatal startup error: {e}")
        # Abort startup: without a model the service cannot do anything.
        tts_executor.shutdown(wait=False)
        raise RuntimeError("Model initialization failed.") from e
    try:
        yield  # Application serves requests
    finally:
        # Runs on normal shutdown and also if the serving phase is cancelled.
        logger.info("Shutting down ThreadPoolExecutor.")
        tts_executor.shutdown(wait=False)
# --- FastAPI Application Setup ---
app = FastAPI(
    lifespan=lifespan,
    docs_url="/docs",
    version="2.0.0-PROD-ENHANCED",
    title="NeuTTS Air Instant Cloning API",
)

# CORS is fully permissive: every origin, method and header is allowed.
_cors_options = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
| # --- New Endpoints and Enhancements --- | |
# NOTE(review): no route decorator is visible above this handler in this view;
# confirm it is registered on the app.
async def root():
    """Return a static banner identifying the API."""
    banner = "NeuTTS Air API v2.0 - Ready for Instant Voice Cloning"
    return {"message": banner}
async def health_check():
    """Report service status together with host memory and disk usage.

    Returns:
        A dict with a fixed "healthy" status, whether the model wrapper holds
        a loaded model, the configured device and concurrency limit, and the
        total size (GiB, two decimals) and used percentage of RAM and of the
        '/' filesystem.
    """
    bytes_per_gib = 1024**3
    mem = psutil.virtual_memory()
    disk = psutil.disk_usage('/')

    # True only when startup stored a wrapper and that wrapper holds a model.
    model_ready = hasattr(app.state, 'tts_wrapper') and app.state.tts_wrapper.tts_model is not None

    memory_report = {
        "total_gb": round(mem.total / bytes_per_gib, 2),
        "used_percent": mem.percent,
    }
    disk_report = {
        "total_gb": round(disk.total / bytes_per_gib, 2),
        "used_percent": disk.percent,
    }
    return {
        "status": "healthy",
        "model_loaded": model_ready,
        "device": DEVICE,
        "concurrency_limit": MAX_WORKERS,
        "memory_usage": memory_report,
        "disk_usage": disk_report,
    }
async def cleanup_files():
    """Run the file sweep on the worker pool and acknowledge once it returns.

    The number of deleted files is not included in the response.
    """
    await run_blocking_task_async(cleanup_files_blocking)
    acknowledgement = {"message": "Cleanup initiated successfully."}
    return acknowledgement
def cleanup_files_blocking():
    """Delete stale files from the generated and temporary audio directories.

    A regular file is removed when its ctime is more than
    ``CLEANUP_THRESHOLD`` seconds before the moment the sweep started.
    Per-file failures are logged as warnings and do not stop the sweep.

    Returns:
        The number of files removed.
    """
    sweep_started = time.time()
    deleted_count = 0
    for folder in (GENERATED_AUDIO_DIR, TEMP_AUDIO_DIR):
        for entry_name in os.listdir(folder):
            filepath = os.path.join(folder, entry_name)
            # Only plain files are candidates; sub-directories are left alone.
            if not os.path.isfile(filepath):
                continue
            try:
                age_seconds = sweep_started - os.path.getctime(filepath)
                if age_seconds > CLEANUP_THRESHOLD:
                    os.remove(filepath)
                    deleted_count += 1
            except Exception as e:
                logger.warning(f"Failed to delete {filepath}: {e}")
    logger.info(f"Cleanup completed: {deleted_count} files removed.")
    return deleted_count
| # --- Core Synthesis Endpoints --- | |
async def text_to_speech(
    text: str = Form(...),
    reference_text: str = Form(...),
    speed: float = Form(1.0, ge=0.5, le=2.0),
    output_format: str = Form("wav", pattern="^(wav|mp3|flac)$"),
    reference_audio: UploadFile = File(...)):
    """Synthesize ``text`` in the voice of ``reference_audio`` and return the audio.

    The upload is converted to WAV with FFmpeg, the whole text is synthesized
    in one pass, and the result is encoded to ``output_format``, stored under
    ``GENERATED_AUDIO_DIR`` and returned in the response body.

    Args:
        text: Text to synthesize.
        reference_text: Text forwarded to the model together with the
            reference audio.
        speed: Validated to [0.5, 2.0]; not used by this handler.
        output_format: One of "wav", "mp3" or "flac".
        reference_audio: Uploaded reference recording.

    Returns:
        A ``Response`` carrying the encoded audio, with processing-time and
        audio-duration headers.

    Raises:
        HTTPException: 503 when the model wrapper is missing; the status set
            by the conversion step when conversion fails; 500 otherwise.
    """
    if not hasattr(app.state, 'tts_wrapper'):
        raise HTTPException(status_code=503, detail="Service unavailable: Model not loaded")

    def _write_bytes(path: str, payload: bytes) -> None:
        # A context manager guarantees the handle is closed after the write.
        with open(path, 'wb') as out_file:
            out_file.write(payload)

    # 1. Save the original upload to disk.
    temp_ref_path = await save_upload_file_async(reference_audio)
    converted_wav_path = None  # Set once conversion succeeds; used for cleanup.
    start_time = time.time()
    try:
        # 2. Convert the upload (WebM, etc.) to a WAV file using FFmpeg.
        converted_wav_path = await run_blocking_task_async(
            convert_to_wav_blocking,
            temp_ref_path
        )
        # 3. Run reference encoding and inference on a worker thread.
        audio_data = await run_blocking_task_async(
            app.state.tts_wrapper.generate_speech_blocking,
            text,
            converted_wav_path,  # The CONVERTED WAV, not the raw upload.
            reference_text
        )
        # 4. Encode to the requested output format.
        audio_bytes = await run_blocking_task_async(
            app.state.tts_wrapper._convert_to_streamable_format,
            audio_data,
            output_format
        )
        # 5. Persist a copy on disk.
        audio_filename = f"tts_{time.time()}.{output_format}"
        final_path = os.path.join(GENERATED_AUDIO_DIR, audio_filename)
        await run_blocking_task_async(_write_bytes, final_path, audio_bytes)

        processing_time = time.time() - start_time
        # assumes audio_data is 1-D (one element per sample) — TODO confirm
        audio_duration = len(audio_data) / SAMPLE_RATE
        return Response(
            content=audio_bytes,
            media_type=f"audio/{'mpeg' if output_format == 'mp3' else output_format}",
            headers={
                "Content-Disposition": f"attachment; filename={audio_filename}",
                "X-Processing-Time": f"{processing_time:.2f}s",
                "X-Audio-Duration": f"{audio_duration:.2f}s"
            }
        )
    except Exception as e:
        logger.error(f"Synthesis error: {e}")
        # HTTPExceptions (e.g. from the conversion step) keep their status.
        if isinstance(e, HTTPException):
            raise
        raise HTTPException(status_code=500, detail=f"Synthesis failed: {e}") from e
    finally:
        # 6. Remove BOTH the original upload and the converted WAV.
        if os.path.exists(temp_ref_path):
            os.unlink(temp_ref_path)
        if converted_wav_path and os.path.exists(converted_wav_path):
            os.unlink(converted_wav_path)
async def stream_text_to_speech_cloning(
    text: str = Form(..., min_length=1, max_length=5000),
    reference_text: str = Form(...),
    speed: float = Form(1.0, ge=0.5, le=2.0),
    output_format: str = Form("mp3", pattern="^(wav|mp3|flac)$"),
    reference_audio: UploadFile = File(...)):
    """Stream synthesized speech sentence by sentence.

    The upload is converted to WAV, the original upload is deleted, and a
    ``StreamingResponse`` is returned whose generator synthesizes one sentence
    at a time. The converted WAV is deleted by the generator itself once
    generation ends, because it must outlive this handler's return.

    Args:
        text: Text to synthesize (1 to 5000 characters).
        reference_text: Text forwarded to the model together with the
            reference audio.
        speed: Validated to [0.5, 2.0] and forwarded to the wrapper.
        output_format: One of "wav", "mp3" or "flac".
        reference_audio: Uploaded reference recording.

    Returns:
        A ``StreamingResponse`` yielding encoded audio chunks.

    Raises:
        HTTPException: 503 when the model wrapper is missing; the status set
            by the conversion step when conversion fails; 500 for any other
            setup failure.
    """
    if not hasattr(app.state, 'tts_wrapper'):
        raise HTTPException(status_code=503, detail="Service unavailable: Model not loaded")
    # 1. Save the original upload to disk.
    temp_ref_path = await save_upload_file_async(reference_audio)
    converted_wav_path = None  # Set once conversion succeeds; used for cleanup.
    try:
        # 2. Convert the upload (WebM, etc.) to a WAV file.
        converted_wav_path = await run_blocking_task_async(
            convert_to_wav_blocking,
            temp_ref_path
        )
        # 2.5. The original upload is no longer needed after conversion.
        if os.path.exists(temp_ref_path):
            os.unlink(temp_ref_path)

        # 3. Synchronous generator consumed by StreamingResponse.
        def stream_generator(path_to_delete: str):
            try:
                for chunk_bytes in app.state.tts_wrapper.stream_speech_blocking(
                    text,
                    path_to_delete,  # The CONVERTED WAV path.
                    reference_text,
                    speed,
                    output_format
                ):
                    yield chunk_bytes
            except Exception as e:
                logger.error(f"Streaming generator error: {e}")
                raise  # Re-raise so the stream terminates.
            finally:
                # 4. Delete the converted file only AFTER generation is done.
                # NOTE(review): this runs only once iteration has started; if
                # the response is dropped before the first chunk is requested
                # the file stays until the periodic cleanup removes it.
                if os.path.exists(path_to_delete):
                    os.unlink(path_to_delete)
                    logger.info(f"Cleaned up converted file: {path_to_delete}")

        return StreamingResponse(
            stream_generator(converted_wav_path),
            media_type=f"audio/{'mpeg' if output_format == 'mp3' else output_format}",
            headers={
                # The extension follows the requested format instead of
                # always claiming ".mp3".
                "Content-Disposition": f"attachment; filename=tts_live_stream.{output_format}",
                "Transfer-Encoding": "chunked",
                "Cache-Control": "no-cache"
            }
        )
    except Exception as e:
        logger.error(f"Streaming setup error: {e}")
        # Setup failed before the generator took ownership of the files.
        if os.path.exists(temp_ref_path):
            os.unlink(temp_ref_path)
        if converted_wav_path and os.path.exists(converted_wav_path):
            os.unlink(converted_wav_path)
        # HTTPExceptions (e.g. from the conversion step) keep their status.
        if isinstance(e, HTTPException):
            raise
        raise HTTPException(status_code=500, detail=f"Streaming synthesis failed: {e}") from e
    # No outer 'finally': cleanup is handled by steps 2.5 and 4 above.
async def get_audio(filename: str):
    """Serve a previously generated audio file from ``GENERATED_AUDIO_DIR``.

    Args:
        filename: Bare name of the file to return; it must not contain any
            directory component.

    Returns:
        A ``Response`` with the file contents as an attachment.

    Raises:
        HTTPException: 404 when the name is not a plain file name or no such
            regular file exists.
    """
    # The name comes from the request: accept only a bare file name so the
    # lookup can never leave GENERATED_AUDIO_DIR.
    is_bare_name = (
        filename not in ("", ".", "..")
        and "\\" not in filename
        and filename == os.path.basename(filename)
    )
    if not is_bare_name:
        raise HTTPException(status_code=404, detail="Audio file not found")

    file_path = os.path.join(GENERATED_AUDIO_DIR, filename)
    # isfile() also rejects directories, which exists() would let through.
    if not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="Audio file not found")

    with open(file_path, "rb") as audio_file:
        content = audio_file.read()

    # Simple media type detection from the extension; mp3 maps to "mpeg" to
    # match the synthesis endpoints.
    extension = filename.split('.')[-1]
    media_subtype = "mpeg" if extension == "mp3" else extension
    return Response(
        content=content,
        media_type=f"audio/{media_subtype}",
        headers={"Content-Disposition": f"attachment; filename={filename}"}
    )