"""NeuTTS Air streaming TTS API.

Synthesizes speech for a text prompt conditioned on an uploaded reference audio clip
plus its reference text, keeps the generated audio in an in-memory LRU cache, and
serves it through streaming / download endpoints.
"""
import asyncio
import gc
import io
import logging
import os
import subprocess
import sys
import tempfile
import time
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Optional, Dict, Any, AsyncGenerator
from uuid import uuid4

import aiofiles
import numpy as np
import psutil
import soundfile as sf
import torch
from fastapi import FastAPI, UploadFile, File, Form, HTTPException, BackgroundTasks, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse, Response
from pydantic import BaseModel, Field

# Add NeuTTS Air to path
sys.path.insert(0, "/app/neutts-air")

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


# Configuration
class Config:
    """Service-wide tunables."""

    MAX_TEXT_LENGTH = 1000
    MIN_AUDIO_DURATION = 2  # seconds of reference audio
    MAX_AUDIO_DURATION = 30  # seconds of reference audio
    SAMPLE_RATE = 24000  # sample rate used when encoding model output
    REFERENCE_SAMPLE_RATE = 16000  # reference audio is resampled to this by ffmpeg
    CHUNK_SIZE = 4096  # For streaming
    MAX_CONCURRENT_REQUESTS = 3
    REQUEST_TIMEOUT = 120  # seconds; bounds the ffmpeg conversion
    # Upper bound on the raw reference upload, enforced while reading so an oversized
    # body is rejected before it is fully buffered in memory.
    MAX_UPLOAD_BYTES = 50 * 1024 * 1024


config = Config()

# Global model instance with async support
tts_model = None
model_loading = False
active_requests = 0
request_semaphore = asyncio.Semaphore(config.MAX_CONCURRENT_REQUESTS)

CACHE_MAX_SIZE = 50  # Max cached audio files
CACHE_CLEANUP_INTERVAL = 300  # 5 minutes

# Paths that bypass the concurrency limit so health probes still answer under load.
RATE_LIMIT_EXEMPT_PATHS = frozenset({"/health"})


class AudioCache:
    """In-memory LRU store for synthesized audio, so nothing is written to disk.

    ``cache`` maps audio_id -> {'audio', 'sample_rate', 'created_at', 'accessed_at'};
    ``access_order`` lists ids from least to most recently used. Cache-management code
    elsewhere in this module reads and mutates both attributes directly, so they remain
    a plain dict and a plain list.
    """

    def __init__(self, max_size: int = 50):
        self.cache: Dict[str, Dict[str, Any]] = {}
        self.max_size = max_size
        self.access_order: list[str] = []

    def _mark_recent(self, audio_id: str) -> None:
        """Move *audio_id* to the most-recently-used end of ``access_order``."""
        # Drop every prior occurrence so each id is listed exactly once.
        self.access_order[:] = [a for a in self.access_order if a != audio_id]
        self.access_order.append(audio_id)

    async def store_audio(self, audio_id: str, audio_data: np.ndarray, sample_rate: int):
        """Store audio in memory under *audio_id*, evicting LRU entries when full.

        Re-storing an id that is already cached replaces its entry in place and does
        not evict anything else.
        """
        if audio_id not in self.cache:
            # `while` rather than `if`: stale ids in access_order (already gone from
            # the dict) must not let the cache grow past max_size.
            while len(self.cache) >= self.max_size and self.access_order:
                await self._remove_oldest()
        now = time.time()
        self.cache[audio_id] = {
            'audio': audio_data,
            'sample_rate': sample_rate,
            'created_at': now,
            'accessed_at': now,
        }
        self._mark_recent(audio_id)

    async def get_audio(self, audio_id: str) -> Optional[Dict]:
        """Return the cached entry for *audio_id* (marking it recently used), or None."""
        entry = self.cache.get(audio_id)
        if entry is None:
            return None
        entry['accessed_at'] = time.time()
        self._mark_recent(audio_id)
        return entry

    async def _remove_oldest(self):
        """Remove least recently used audio"""
        if self.access_order:
            oldest_id = self.access_order.pop(0)
            if oldest_id in self.cache:
                del self.cache[oldest_id]
                logger.debug(f"Removed cached audio: {oldest_id}")


# Initialize cache
audio_cache = AudioCache(max_size=CACHE_MAX_SIZE)


class AudioStreamProcessor:
    """Reference-audio helpers: upload reading, ffmpeg conversion, duration checks."""

    @staticmethod
    async def _read_upload(upload_file: UploadFile) -> bytes:
        """Read the whole upload, refusing bodies larger than ``config.MAX_UPLOAD_BYTES``.

        Raises:
            HTTPException: 413 if the upload is too large, 400 if it is empty.
        """
        chunks: list[bytes] = []
        total = 0
        while chunk := await upload_file.read(1024 * 1024):
            total += len(chunk)
            if total > config.MAX_UPLOAD_BYTES:
                raise HTTPException(
                    status_code=413,
                    detail=f"Reference audio too large (maximum {config.MAX_UPLOAD_BYTES // (1024 * 1024)} MB)"
                )
            chunks.append(chunk)
        if not chunks:
            raise HTTPException(status_code=400, detail="Reference audio file is empty")
        return b"".join(chunks)

    @staticmethod
    async def _run_ffmpeg(input_path: str, output_path: str) -> None:
        """Convert *input_path* to mono 16-bit PCM WAV at REFERENCE_SAMPLE_RATE.

        Raises:
            HTTPException: 400 if ffmpeg cannot decode the input.
            RuntimeError: if ffmpeg exceeds ``config.REQUEST_TIMEOUT`` seconds.
        """
        cmd = [
            'ffmpeg', '-nostdin', '-i', input_path,
            '-ac', '1', '-ar', str(config.REFERENCE_SAMPLE_RATE),
            '-acodec', 'pcm_s16le', '-y', output_path,
        ]
        process = await asyncio.create_subprocess_exec(
            *cmd,
            stdin=asyncio.subprocess.DEVNULL,  # ffmpeg must never wait on our stdin
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        try:
            _, stderr = await asyncio.wait_for(
                process.communicate(), timeout=config.REQUEST_TIMEOUT
            )
        except asyncio.TimeoutError:
            raise RuntimeError(f"FFmpeg timed out after {config.REQUEST_TIMEOUT}s") from None
        finally:
            # Never leave ffmpeg running after a timeout or a cancelled request.
            if process.returncode is None:
                try:
                    process.kill()
                except ProcessLookupError:
                    pass
                await process.wait()

        if process.returncode != 0:
            # Undecodable upload is a client error; keep ffmpeg's stderr in the log only.
            logger.warning(f"FFmpeg failed: {stderr.decode(errors='replace')}")
            raise HTTPException(status_code=400, detail="Could not decode reference audio")

    @staticmethod
    async def convert_audio_to_wav_memory(upload_file: UploadFile) -> tuple[bytes, float]:
        """Convert an uploaded audio file to WAV bytes and measure its duration.

        ffmpeg needs real files, so input and output live briefly in a private
        temporary directory that is removed on exit.

        Returns:
            (wav_bytes, duration_seconds)
        Raises:
            HTTPException: for empty, oversized or undecodable uploads.
        """
        try:
            file_content = await AudioStreamProcessor._read_upload(upload_file)

            # The extension only hints ffmpeg's format probe; keep it only when it is a
            # plain alphanumeric suffix (filename may also be missing entirely).
            suffix = Path(upload_file.filename or "").suffix
            if not suffix[1:].isalnum():
                suffix = ""

            with tempfile.TemporaryDirectory(prefix="neutts_") as tmp_dir:
                input_path = os.path.join(tmp_dir, f"input{suffix}")
                output_path = os.path.join(tmp_dir, "output.wav")

                async with aiofiles.open(input_path, 'wb') as f:
                    await f.write(file_content)

                await AudioStreamProcessor._run_ffmpeg(input_path, output_path)

                async with aiofiles.open(output_path, 'rb') as f:
                    wav_data = await f.read()

            duration = await AudioStreamProcessor.get_audio_duration_memory(wav_data)
            return wav_data, duration

        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"Audio conversion failed: {e}")
            raise

    @staticmethod
    async def get_audio_duration_memory(audio_data: bytes) -> float:
        """Get audio duration (seconds) from in-memory audio data.

        Tries soundfile first and falls back to librosa if soundfile cannot read it.
        """
        try:
            with sf.SoundFile(io.BytesIO(audio_data)) as audio_file:
                return len(audio_file) / audio_file.samplerate
        except Exception as e:
            logger.warning(f"SoundFile duration failed: {e}, using librosa")
            import librosa  # optional fallback dependency, imported only when needed
            audio_array, sr = librosa.load(io.BytesIO(audio_data), sr=None)
            return len(audio_array) / sr

    @staticmethod
    async def validate_audio_duration(duration: float):
        """Raise HTTP 400 unless MIN_AUDIO_DURATION <= duration <= MAX_AUDIO_DURATION."""
        if duration < config.MIN_AUDIO_DURATION:
            raise HTTPException(
                status_code=400,
                detail=f"Audio too short: {duration:.1f}s (minimum {config.MIN_AUDIO_DURATION}s)"
            )
        if duration > config.MAX_AUDIO_DURATION:
            raise HTTPException(
                status_code=400,
                detail=f"Audio too long: {duration:.1f}s (maximum {config.MAX_AUDIO_DURATION}s)"
            )


async def load_tts_model():
    """Load the NeuTTS Air model into the global ``tts_model`` (no-op if loaded/loading).

    Model construction is synchronous (and may fetch weights), so it runs in a worker
    thread to keep the event loop serving requests such as /health meanwhile.
    """
    global tts_model, model_loading

    if tts_model is not None or model_loading:
        return

    model_loading = True
    try:
        logger.info("Loading NeuTTS Air model...")

        # Clear memory before loading
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        try:
            from neuttsair.neutts import NeuTTSAir
        except ImportError as e:
            logger.error(f"Failed to import NeuTTS Air: {e}")
            raise

        tts_model = await asyncio.to_thread(
            NeuTTSAir,
            backbone_repo="neuphonic/neutts-air",
            backbone_device="cpu",
            codec_repo="neuphonic/neucodec",
            codec_device="cpu",
        )
        logger.info("NeuTTS Air model loaded successfully!")

    except Exception as e:
        logger.error(f"Failed to load model: {str(e)}")
        raise
    finally:
        model_loading = False


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Start background work (model load, cache cleanup) and tear it down on shutdown."""
    global tts_model

    logger.info("🚀 Starting NeuTTS Air Streaming API")
    # Keep references: the event loop only weakly references tasks, and we need them
    # to cancel / collect results at shutdown.
    background_tasks = [
        asyncio.create_task(load_tts_model()),
        asyncio.create_task(cache_cleanup_task()),
    ]
    try:
        yield
    finally:
        logger.info("🛑 Shutting down NeuTTS Air API")
        for task in background_tasks:
            task.cancel()
        # Collect outcomes so a failed model load is not reported as "never retrieved".
        await asyncio.gather(*background_tasks, return_exceptions=True)
        tts_model = None
        gc.collect()


app = FastAPI(
    title="NeuTTS Air Streaming API",
    description="High-quality on-device TTS with streaming and no disk usage",
    version="2.0.0",
    lifespan=lifespan
)

# CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Pydantic models
class TTSRequest(BaseModel):
    text: str = Field(..., min_length=1, max_length=1000)
    reference_text: str = Field(..., min_length=1, max_length=500)
    reference_audio_path: Optional[str] = None


class TTSResponse(BaseModel):
    success: bool
    audio_id: Optional[str] = None
    message: Optional[str] = None
    processing_time: Optional[float] = None
    audio_duration: Optional[float] = None
    stream_url: Optional[str] = None


class HealthResponse(BaseModel):
    status: str
    model_loaded: bool
    active_requests: int
    cache_size: int
    memory_usage: Dict[str, float]


@app.middleware("http")
async def limit_concurrent_requests(request: Request, call_next):
    """Reject with 429 once MAX_CONCURRENT_REQUESTS are in flight; log request timing.

    Paths in RATE_LIMIT_EXEMPT_PATHS bypass the limit so health probes keep working.
    """
    global active_requests

    if request.url.path in RATE_LIMIT_EXEMPT_PATHS:
        return await call_next(request)

    if active_requests >= config.MAX_CONCURRENT_REQUESTS:
        return JSONResponse(
            status_code=429,
            content={"detail": "Too many concurrent requests"}
        )

    async with request_semaphore:
        active_requests += 1
        try:
            start_time = time.time()
            response = await call_next(request)
            process_time = time.time() - start_time
            logger.info(f"{request.method} {request.url.path} completed in {process_time:.2f}s")
            return response
        finally:
            active_requests -= 1


@app.get("/")
async def root():
    """Service banner with model status and current load."""
    return {
        "message": "NeuTTS Air Streaming API",
        "status": "healthy",
        "features": ["streaming", "no_disk_usage", "async", "in_memory_cache"],
        "model_loaded": tts_model is not None,
        "active_requests": active_requests
    }
# Seconds a cached entry may go unaccessed before cache_cleanup_task evicts it.
CACHE_ENTRY_TTL = 3600
# Mirrors TTSRequest.reference_text (max_length=500) so the form endpoints enforce the same bound.
MAX_REFERENCE_TEXT_LENGTH = 500

# Lock serialising all use of the shared tts_model; created lazily by _get_model_lock().
_model_lock = None


def _get_model_lock():
    """Return the process-wide lock guarding ``tts_model``, creating it on first use."""
    global _model_lock
    if _model_lock is None:
        import threading  # local import keeps this helper self-contained
        _model_lock = threading.Lock()
    return _model_lock


def _validate_texts(text: str, reference_text: str) -> None:
    """Raise HTTP 400 if either text field is blank or longer than its limit."""
    if not reference_text.strip() or not text.strip():
        raise HTTPException(status_code=400, detail="Text fields cannot be empty")
    if len(text) > config.MAX_TEXT_LENGTH:
        raise HTTPException(
            status_code=400,
            detail=f"Text too long: {len(text)} characters (maximum {config.MAX_TEXT_LENGTH})"
        )
    if len(reference_text) > MAX_REFERENCE_TEXT_LENGTH:
        raise HTTPException(
            status_code=400,
            detail=f"Reference text too long: {len(reference_text)} characters (maximum {MAX_REFERENCE_TEXT_LENGTH})"
        )


def _encode_wav(audio, sample_rate: int) -> bytes:
    """Serialise an audio array to WAV bytes entirely in memory."""
    buffer = io.BytesIO()
    sf.write(buffer, audio, sample_rate, format='WAV')
    return buffer.getvalue()


async def _iter_chunks(data: bytes) -> AsyncGenerator[bytes, None]:
    """Yield *data* in ``config.CHUNK_SIZE`` pieces for a StreamingResponse."""
    chunk_size = config.CHUNK_SIZE
    for start in range(0, len(data), chunk_size):
        yield data[start:start + chunk_size]


def _remove_temp_file(path: str) -> None:
    """Best-effort removal of a temporary file; failures are logged, not raised."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass
    except OSError as e:
        logger.warning(f"Could not remove temp file {path}: {e}")


async def _run_tts(wav_data: bytes, text: str, reference_text: str):
    """Encode the reference WAV and synthesize *text* with the shared model.

    The model's ``encode_reference`` takes a file path, so the WAV bytes are written to
    a short-lived temp file. Both model calls are blocking, so they run in a worker
    thread to keep the event loop responsive. They are serialised with a lock because
    the model instance is shared and not known to be thread-safe; the worker thread
    holds the lock itself, so it stays held until the model call really finishes even
    if the awaiting request is cancelled.

    Returns:
        Whatever ``tts_model.infer`` returns (the callers treat it as an audio array
        sampled at ``config.SAMPLE_RATE``).
    Raises:
        HTTPException: 503 if the model is not loaded.
    """
    model = tts_model  # snapshot: shutdown may reset the global while we run
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")
    lock = _get_model_lock()

    def _synthesize_blocking(ref_path: str):
        with lock:
            ref_codes = model.encode_reference(ref_path)
            return model.infer(text, ref_codes, reference_text)

    temp_ref_path = f"/tmp/ref_{uuid4().hex}.wav"
    try:
        async with aiofiles.open(temp_ref_path, 'wb') as f:
            await f.write(wav_data)
        return await asyncio.to_thread(_synthesize_blocking, temp_ref_path)
    finally:
        _remove_temp_file(temp_ref_path)


def _evict_cached_audio(audio_id: str) -> bool:
    """Remove *audio_id* from ``audio_cache``; return True if it had a cache entry."""
    found = audio_cache.cache.pop(audio_id, None) is not None
    # Drop every occurrence from the LRU order (it may contain duplicates or stale ids).
    audio_cache.access_order[:] = [a for a in audio_cache.access_order if a != audio_id]
    return found


@app.get("/health")
async def health_check():
    """Health check with memory usage.

    Reports "degraded" (with empty memory stats) if system memory cannot be read.
    """
    try:
        memory = psutil.virtual_memory()
    except Exception as e:
        # memory_usage is Dict[str, float]; an error string there would fail validation.
        logger.warning(f"Health check could not read memory stats: {e}")
        return HealthResponse(
            status="degraded",
            model_loaded=tts_model is not None,
            active_requests=active_requests,
            cache_size=len(audio_cache.cache),
            memory_usage={}
        )
    return HealthResponse(
        status="healthy",
        model_loaded=tts_model is not None,
        active_requests=active_requests,
        cache_size=len(audio_cache.cache),
        memory_usage={
            "total_gb": round(memory.total / (1024**3), 2),
            "available_gb": round(memory.available / (1024**3), 2),
            "used_percent": round(memory.percent, 2)
        }
    )


@app.post("/synthesize", response_model=TTSResponse)
async def synthesize_speech(
    reference_text: str = Form(...),
    text: str = Form(...),
    reference_audio: UploadFile = File(...)
):
    """
    Synthesize speech and store it in the in-memory cache.

    Returns the cache id plus a ``/stream/{audio_id}`` URL to fetch the audio from.
    """
    start_time = time.time()
    request_id = str(uuid4())[:8]
    logger.info(f"[{request_id}] Starting streaming synthesis")

    if tts_model is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")

    _validate_texts(text, reference_text)

    try:
        wav_data, audio_duration = await AudioStreamProcessor.convert_audio_to_wav_memory(reference_audio)
        await AudioStreamProcessor.validate_audio_duration(audio_duration)
        logger.info(f"[{request_id}] Audio validated: {audio_duration:.2f}s")

        logger.info(f"[{request_id}] Synthesizing: '{text[:50]}...'")
        wav_output = await _run_tts(wav_data, text, reference_text)

        audio_id = f"audio_{request_id}"
        await audio_cache.store_audio(audio_id, wav_output, config.SAMPLE_RATE)

        processing_time = time.time() - start_time
        output_duration = len(wav_output) / config.SAMPLE_RATE
        logger.info(f"[{request_id}] Synthesis completed in {processing_time:.2f}s")

        return TTSResponse(
            success=True,
            audio_id=audio_id,
            message="Speech synthesized successfully",
            processing_time=round(processing_time, 2),
            audio_duration=round(output_duration, 2),
            stream_url=f"/stream/{audio_id}"
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"[{request_id}] Synthesis error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Synthesis failed: {str(e)}") from e


@app.get("/stream/{audio_id}")
async def stream_audio(audio_id: str):
    """
    Stream cached audio as a chunked WAV response (404 if unknown or evicted).
    """
    cached_audio = await audio_cache.get_audio(audio_id)
    if not cached_audio:
        raise HTTPException(status_code=404, detail="Audio not found or expired")

    wav_bytes = _encode_wav(cached_audio['audio'], cached_audio['sample_rate'])

    return StreamingResponse(
        _iter_chunks(wav_bytes),
        media_type="audio/wav",
        headers={
            "Content-Disposition": f"attachment; filename=speech_{audio_id}.wav",
            "Cache-Control": "no-cache",
            "Content-Length": str(len(wav_bytes))
        }
    )


@app.get("/download/{audio_id}")
async def download_audio(audio_id: str):
    """
    Download cached audio as a single WAV response (404 if unknown or evicted).
    """
    cached_audio = await audio_cache.get_audio(audio_id)
    if not cached_audio:
        raise HTTPException(status_code=404, detail="Audio not found or expired")

    wav_bytes = _encode_wav(cached_audio['audio'], cached_audio['sample_rate'])

    return Response(
        content=wav_bytes,
        media_type="audio/wav",
        headers={
            "Content-Disposition": f"attachment; filename=speech_{audio_id}.wav",
            "Content-Length": str(len(wav_bytes))
        }
    )


@app.post("/synthesize-and-stream")
async def synthesize_and_stream(
    reference_text: str = Form(...),
    text: str = Form(...),
    reference_audio: UploadFile = File(...)
):
    """
    Synthesize and stream the WAV in one request, without touching the cache.

    Client errors (bad text, undecodable/out-of-range audio) are returned as-is
    rather than being wrapped into a 500.
    """
    start_time = time.time()

    if tts_model is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")

    _validate_texts(text, reference_text)

    try:
        wav_data, audio_duration = await AudioStreamProcessor.convert_audio_to_wav_memory(reference_audio)
        await AudioStreamProcessor.validate_audio_duration(audio_duration)

        wav_output = await _run_tts(wav_data, text, reference_text)

        processing_time = time.time() - start_time
        logger.info(f"Real-time synthesis completed in {processing_time:.2f}s")

        wav_bytes = _encode_wav(wav_output, config.SAMPLE_RATE)

    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Stream synthesis error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Stream synthesis failed: {str(e)}") from e

    return StreamingResponse(
        _iter_chunks(wav_bytes),
        media_type="audio/wav",
        headers={
            "Content-Disposition": "attachment; filename=speech_stream.wav",
            "Cache-Control": "no-cache",
            "Content-Length": str(len(wav_bytes)),
            "X-Processing-Time": f"{processing_time:.2f}"
        }
    )


@app.delete("/cache/{audio_id}")
async def clear_cached_audio(audio_id: str):
    """Clear specific audio from cache (404 if it is not cached)."""
    if not _evict_cached_audio(audio_id):
        raise HTTPException(status_code=404, detail="Audio not found in cache")
    return {"message": f"Audio {audio_id} cleared from cache"}


@app.delete("/cache")
async def clear_all_cache():
    """Clear all audio cache"""
    cache_size = len(audio_cache.cache)
    audio_cache.cache.clear()
    audio_cache.access_order.clear()
    return {"message": f"Cleared all {cache_size} cached audio files"}


async def cache_cleanup_task():
    """Every CACHE_CLEANUP_INTERVAL seconds, evict entries idle longer than CACHE_ENTRY_TTL.

    Loops forever; errors in one sweep are logged and the loop continues.
    """
    while True:
        await asyncio.sleep(CACHE_CLEANUP_INTERVAL)
        try:
            now = time.time()
            expired_ids = [
                audio_id
                for audio_id, entry in audio_cache.cache.items()
                if now - entry['accessed_at'] > CACHE_ENTRY_TTL
            ]
            for audio_id in expired_ids:
                _evict_cached_audio(audio_id)

            if expired_ids:
                logger.info(f"Cache cleanup removed {len(expired_ids)} expired entries")

        except Exception as e:
            logger.error(f"Cache cleanup error: {e}")


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        workers=1,
        log_level="info"
    )