# Hugging Face Spaces status banner ("Spaces: Paused") — non-code residue from
# the export, converted to a comment so the module remains importable.
# Standard library
import asyncio
import gc
import io
import logging
import os
import subprocess
import sys
import time
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any, AsyncGenerator, Dict, Optional
from uuid import uuid4

# Third-party
import aiofiles
import numpy as np
import psutil
import soundfile as sf
import torch
from fastapi import (
    BackgroundTasks,
    FastAPI,
    File,
    Form,
    HTTPException,
    Request,
    UploadFile,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
from pydantic import BaseModel, Field

# Make the vendored NeuTTS Air package importable.
sys.path.insert(0, "/app/neutts-air")
# Application-wide logging configuration.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
# Configuration
class Config:
    """Static tuning knobs for the TTS service."""

    MAX_TEXT_LENGTH = 1000         # max characters accepted for synthesis
    MIN_AUDIO_DURATION = 2         # seconds: shortest usable reference clip
    MAX_AUDIO_DURATION = 30        # seconds: longest accepted reference clip
    SAMPLE_RATE = 24000            # output sample rate of the TTS model
    REFERENCE_SAMPLE_RATE = 16000  # rate reference audio is resampled to
    CHUNK_SIZE = 4096              # bytes per streamed chunk
    MAX_CONCURRENT_REQUESTS = 3
    REQUEST_TIMEOUT = 120          # seconds


config = Config()
# Global model instance with async support
tts_model = None       # lazily loaded NeuTTSAir instance (see load_tts_model)
model_loading = False  # guards against concurrent load attempts
active_requests = 0    # number of requests currently in flight
request_semaphore = asyncio.Semaphore(config.MAX_CONCURRENT_REQUESTS)

# In-memory audio cache to avoid disk usage
audio_cache = {}              # replaced below by an AudioCache instance
CACHE_MAX_SIZE = 50           # max cached audio files
CACHE_CLEANUP_INTERVAL = 300  # seconds between cache-expiry sweeps (5 min)
class AudioCache:
    """LRU-style in-memory audio cache to avoid disk usage.

    Entries map an audio id to the raw waveform, its sample rate, and
    bookkeeping timestamps; ``access_order`` holds ids oldest-first and
    drives eviction.

    Fix over the original: re-storing an id that is already cached no
    longer evicts an unrelated LRU entry (the cache does not grow in that
    case), and ``access_order`` never accumulates duplicate ids.
    """

    def __init__(self, max_size: int = 50):
        # id -> {'audio', 'sample_rate', 'created_at', 'accessed_at'}
        self.cache: Dict[str, Dict[str, Any]] = {}
        self.max_size = max_size
        self.access_order: list = []  # oldest id first

    async def store_audio(self, audio_id: str, audio_data: np.ndarray, sample_rate: int):
        """Store (or refresh) audio in memory, evicting the LRU entry if full."""
        # Only evict when a genuinely new id would push us past max_size.
        if audio_id not in self.cache and len(self.cache) >= self.max_size:
            await self._remove_oldest()
        now = time.time()
        self.cache[audio_id] = {
            'audio': audio_data,
            'sample_rate': sample_rate,
            'created_at': now,
            'accessed_at': now,
        }
        # Keep exactly one occurrence of the id, at the most-recent end.
        if audio_id in self.access_order:
            self.access_order.remove(audio_id)
        self.access_order.append(audio_id)

    async def get_audio(self, audio_id: str) -> Optional[Dict]:
        """Return the cached entry and mark it most recently used, or None."""
        if audio_id not in self.cache:
            return None
        self.cache[audio_id]['accessed_at'] = time.time()
        if audio_id in self.access_order:
            self.access_order.remove(audio_id)
        self.access_order.append(audio_id)
        return self.cache[audio_id]

    async def _remove_oldest(self):
        """Evict the least recently used entry."""
        if self.access_order:
            oldest_id = self.access_order.pop(0)
            if oldest_id in self.cache:
                del self.cache[oldest_id]
                # Same module logger the file configures at the top.
                logging.getLogger(__name__).debug(f"Removed cached audio: {oldest_id}")
# Module-level cache instance shared by all endpoints.
audio_cache = AudioCache(CACHE_MAX_SIZE)
class AudioStreamProcessor:
    """Stateless helpers that convert and validate reference audio.

    Fixes over the original: methods are marked ``@staticmethod`` (they are
    always called on the class and take no ``self``), two unused BytesIO
    buffers were removed, the bare ``except:`` in temp-file cleanup is
    narrowed to ``OSError``, and a missing ``upload_file.filename`` no
    longer raises ``TypeError``.
    """

    @staticmethod
    async def convert_audio_to_wav_memory(upload_file: UploadFile) -> tuple[bytes, float]:
        """Convert an uploaded audio file to mono 16 kHz PCM WAV bytes.

        Returns ``(wav_bytes, duration_seconds)``. ffmpeg needs real file
        paths, so short-lived files under /tmp are used and always removed.

        Raises:
            Exception: when ffmpeg fails or the upload cannot be read.
        """
        try:
            file_content = await upload_file.read()
            # filename may be absent on the upload; fall back to no suffix.
            suffix = Path(upload_file.filename).suffix if upload_file.filename else ""
            temp_input_path = f"/tmp/input_{uuid4().hex}{suffix}"
            temp_output_path = f"/tmp/output_{uuid4().hex}.wav"
            try:
                async with aiofiles.open(temp_input_path, 'wb') as f:
                    await f.write(file_content)
                # Mono, 16-bit PCM, resampled to the reference rate.
                cmd = [
                    'ffmpeg', '-i', temp_input_path,
                    '-ac', '1',
                    '-ar', str(config.REFERENCE_SAMPLE_RATE),
                    '-acodec', 'pcm_s16le',
                    '-y', temp_output_path
                ]
                process = await asyncio.create_subprocess_exec(
                    *cmd,
                    stdout=asyncio.subprocess.PIPE,
                    stderr=asyncio.subprocess.PIPE
                )
                stdout, stderr = await process.communicate()
                if process.returncode != 0:
                    raise Exception(f"FFmpeg failed: {stderr.decode()}")
                async with aiofiles.open(temp_output_path, 'rb') as f:
                    wav_data = await f.read()
                duration = await AudioStreamProcessor.get_audio_duration_memory(wav_data)
                return wav_data, duration
            finally:
                # Best-effort cleanup; only filesystem errors are ignored.
                for temp_file in (temp_input_path, temp_output_path):
                    if os.path.exists(temp_file):
                        try:
                            os.remove(temp_file)
                        except OSError:
                            pass
        except Exception as e:
            logger.error(f"Audio conversion failed: {e}")
            raise

    @staticmethod
    async def get_audio_duration_memory(audio_data: bytes) -> float:
        """Return the duration in seconds of in-memory WAV data."""
        try:
            with sf.SoundFile(io.BytesIO(audio_data)) as audio_file:
                return len(audio_file) / audio_file.samplerate
        except Exception as e:
            logger.warning(f"SoundFile duration failed: {e}, using librosa")
            # Fallback decoder; imported lazily because it is heavy.
            import librosa
            audio_array, sr = librosa.load(io.BytesIO(audio_data), sr=None)
            return len(audio_array) / sr

    @staticmethod
    async def validate_audio_duration(duration: float):
        """Raise HTTP 400 when the reference clip is too short or too long."""
        if duration < config.MIN_AUDIO_DURATION:
            raise HTTPException(
                status_code=400,
                detail=f"Audio too short: {duration:.1f}s (minimum {config.MIN_AUDIO_DURATION}s)"
            )
        if duration > config.MAX_AUDIO_DURATION:
            raise HTTPException(
                status_code=400,
                detail=f"Audio too long: {duration:.1f}s (maximum {config.MAX_AUDIO_DURATION}s)"
            )
async def load_tts_model():
    """Load the NeuTTS Air model once; concurrent callers return early.

    Sets the module-level ``tts_model``; ``model_loading`` prevents a
    second coroutine from starting a duplicate load.
    """
    global tts_model, model_loading
    # Already loaded, or another task is currently loading it.
    if tts_model is not None or model_loading:
        return
    model_loading = True
    try:
        logger.info("Loading NeuTTS Air model...")
        # Free as much memory as possible before the large allocation.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        try:
            from neuttsair.neutts import NeuTTSAir
        except ImportError as exc:
            logger.error(f"Failed to import NeuTTS Air: {exc}")
            raise
        # CPU-only inference for both backbone and codec.
        tts_model = NeuTTSAir(
            backbone_repo="neuphonic/neutts-air",
            backbone_device="cpu",
            codec_repo="neuphonic/neucodec",
            codec_device="cpu",
        )
        logger.info("NeuTTS Air model loaded successfully!")
    except Exception as exc:
        logger.error(f"Failed to load model: {str(exc)}")
        raise exc
    finally:
        model_loading = False
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Async startup/shutdown manager handed to ``FastAPI(lifespan=...)``.

    Fix: the ``@asynccontextmanager`` decorator was missing (it is imported
    at the top of the file but was never applied), so FastAPI would not
    receive a context manager here.
    """
    # Startup
    logger.info("🚀 Starting NeuTTS Air Streaming API")
    # Load the model in the background so startup is not blocked.
    asyncio.create_task(load_tts_model())
    # Periodically expire stale cached audio.
    asyncio.create_task(cache_cleanup_task())
    yield
    # Shutdown
    logger.info("🛑 Shutting down NeuTTS Air API")
    global tts_model
    if tts_model is not None:
        del tts_model
        tts_model = None
    gc.collect()
# FastAPI application instance.
app = FastAPI(
    title="NeuTTS Air Streaming API",
    description="High-quality on-device TTS with streaming and no disk usage",
    version="2.0.0",
    lifespan=lifespan,
)

# Open CORS policy: accept any origin, method, and header.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Pydantic models
class TTSRequest(BaseModel):
    """Synthesis request body (JSON alternative to the multipart form)."""

    text: str = Field(..., min_length=1, max_length=1000)
    reference_text: str = Field(..., min_length=1, max_length=500)
    reference_audio_path: Optional[str] = None  # optional server-side path
class TTSResponse(BaseModel):
    """Result of a synthesis call, including where to stream the audio."""

    success: bool
    audio_id: Optional[str] = None         # cache key for /stream and /download
    message: Optional[str] = None
    processing_time: Optional[float] = None  # seconds
    audio_duration: Optional[float] = None   # seconds
    stream_url: Optional[str] = None
class HealthResponse(BaseModel):
    """Payload returned by the health-check endpoint."""

    status: str
    model_loaded: bool
    active_requests: int
    cache_size: int
    memory_usage: Dict[str, float]  # GB totals and used percentage
# Async middleware for request limiting
@app.middleware("http")
async def limit_concurrent_requests(request: Request, call_next):
    """Cap concurrent requests; fail fast with 429 when at capacity.

    Fix: registered via ``@app.middleware("http")`` — the function was
    defined but never attached to the app, so the limit was not enforced.
    """
    global active_requests
    # Fail fast instead of queueing on the semaphore.
    if active_requests >= config.MAX_CONCURRENT_REQUESTS:
        return JSONResponse(
            status_code=429,
            content={"detail": "Too many concurrent requests"}
        )
    async with request_semaphore:
        active_requests += 1
        try:
            start_time = time.time()
            response = await call_next(request)
            process_time = time.time() - start_time
            logger.info(f"{request.method} {request.url.path} completed in {process_time:.2f}s")
            return response
        finally:
            active_requests -= 1
@app.get("/")
async def root():
    """Service banner with liveness info.

    NOTE(review): route decorators were missing throughout this file; "/"
    is the conventional root path — confirm against the original deployment.
    """
    return {
        "message": "NeuTTS Air Streaming API",
        "status": "healthy",
        "features": ["streaming", "no_disk_usage", "async", "in_memory_cache"],
        "model_loaded": tts_model is not None,
        "active_requests": active_requests
    }
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check with memory usage.

    Fixes: registered at "/health" (decorator was missing), and the
    degraded branch no longer returns ``{"error": str}`` — that violated
    ``HealthResponse.memory_usage: Dict[str, float]`` and would itself
    raise a validation error. The error is logged instead.
    """
    try:
        memory = psutil.virtual_memory()
        return HealthResponse(
            status="healthy",
            model_loaded=tts_model is not None,
            active_requests=active_requests,
            cache_size=len(audio_cache.cache),
            memory_usage={
                "total_gb": round(memory.total / (1024**3), 2),
                "available_gb": round(memory.available / (1024**3), 2),
                "used_percent": round(memory.percent, 2)
            }
        )
    except Exception as e:
        logger.error(f"Health check memory probe failed: {e}")
        return HealthResponse(
            status="degraded",
            model_loaded=tts_model is not None,
            active_requests=active_requests,
            cache_size=len(audio_cache.cache),
            memory_usage={}
        )
@app.post("/synthesize", response_model=TTSResponse)
async def synthesize_speech(
    reference_text: str = Form(...),
    text: str = Form(...),
    reference_audio: UploadFile = File(...)
):
    """Clone the reference voice, synthesize ``text``, and cache the result.

    The response carries ``stream_url`` pointing at /stream/{audio_id}.
    Fixes: route decorator added (was missing; "/synthesize" is inferred —
    NOTE(review): confirm path against the original deployment) and the
    bare ``except:`` in temp-file cleanup narrowed to ``OSError``.
    """
    start_time = time.time()
    request_id = str(uuid4())[:8]
    logger.info(f"[{request_id}] Starting streaming synthesis")
    if tts_model is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")
    # Validate inputs
    if not reference_text.strip() or not text.strip():
        raise HTTPException(status_code=400, detail="Text fields cannot be empty")
    try:
        # Convert and validate the reference audio entirely in memory.
        wav_data, audio_duration = await AudioStreamProcessor.convert_audio_to_wav_memory(reference_audio)
        await AudioStreamProcessor.validate_audio_duration(audio_duration)
        logger.info(f"[{request_id}] Audio validated: {audio_duration:.2f}s")
        # The model API wants a file path, so spill briefly to /tmp.
        temp_ref_path = f"/tmp/ref_{request_id}.wav"
        try:
            async with aiofiles.open(temp_ref_path, 'wb') as f:
                await f.write(wav_data)
            logger.info(f"[{request_id}] Synthesizing: '{text[:50]}...'")
            # Encode the reference voice, then generate speech.
            ref_codes = tts_model.encode_reference(temp_ref_path)
            wav_output = tts_model.infer(text, ref_codes, reference_text)
            audio_id = f"audio_{request_id}"
            # Keep the result in RAM; clients fetch via /stream or /download.
            await audio_cache.store_audio(audio_id, wav_output, config.SAMPLE_RATE)
            processing_time = time.time() - start_time
            output_duration = len(wav_output) / config.SAMPLE_RATE
            logger.info(f"[{request_id}] Synthesis completed in {processing_time:.2f}s")
            return TTSResponse(
                success=True,
                audio_id=audio_id,
                message="Speech synthesized successfully",
                processing_time=round(processing_time, 2),
                audio_duration=round(output_duration, 2),
                stream_url=f"/stream/{audio_id}"
            )
        finally:
            # Remove the temporary reference file; ignore deletion races.
            if os.path.exists(temp_ref_path):
                try:
                    os.remove(temp_ref_path)
                except OSError:
                    pass
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"[{request_id}] Synthesis error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Synthesis failed: {str(e)}")
@app.get("/stream/{audio_id}")
async def stream_audio(audio_id: str):
    """Stream cached audio in chunks directly from memory.

    Fix: registered at "/stream/{audio_id}" — the path advertised by
    ``TTSResponse.stream_url`` — since the route decorator was missing.
    """
    cached_audio = await audio_cache.get_audio(audio_id)
    if not cached_audio:
        raise HTTPException(status_code=404, detail="Audio not found or expired")
    audio_data = cached_audio['audio']
    sample_rate = cached_audio['sample_rate']
    # Encode the waveform as WAV entirely in memory.
    wav_buffer = io.BytesIO()
    sf.write(wav_buffer, audio_data, sample_rate, format='WAV')
    wav_bytes = wav_buffer.getvalue()

    async def generate_audio_stream():
        chunk_size = config.CHUNK_SIZE
        for i in range(0, len(wav_bytes), chunk_size):
            yield wav_bytes[i:i + chunk_size]
            # Yield control so other requests progress between chunks.
            await asyncio.sleep(0.001)

    return StreamingResponse(
        generate_audio_stream(),
        media_type="audio/wav",
        headers={
            "Content-Disposition": f"attachment; filename=speech_{audio_id}.wav",
            "Cache-Control": "no-cache",
            "Content-Length": str(len(wav_bytes))
        }
    )
@app.get("/download/{audio_id}")
async def download_audio(audio_id: str):
    """Return the cached audio as one complete WAV response.

    NOTE(review): the route decorator was missing; "/download/{audio_id}"
    is inferred — confirm against the original deployment.
    """
    cached_audio = await audio_cache.get_audio(audio_id)
    if not cached_audio:
        raise HTTPException(status_code=404, detail="Audio not found or expired")
    audio_data = cached_audio['audio']
    sample_rate = cached_audio['sample_rate']
    # Encode to WAV in memory — no disk usage.
    wav_buffer = io.BytesIO()
    sf.write(wav_buffer, audio_data, sample_rate, format='WAV')
    wav_bytes = wav_buffer.getvalue()
    return Response(
        content=wav_bytes,
        media_type="audio/wav",
        headers={
            "Content-Disposition": f"attachment; filename=speech_{audio_id}.wav",
            "Content-Length": str(len(wav_bytes))
        }
    )
@app.post("/synthesize/stream")
async def synthesize_and_stream(
    reference_text: str = Form(...),
    text: str = Form(...),
    reference_audio: UploadFile = File(...)
):
    """Synthesize and stream the result in one request (no caching).

    Fixes: route decorator added (was missing; NOTE(review): the path is
    inferred — confirm against the original deployment); HTTPException is
    re-raised instead of being converted to a 500 by the generic handler
    (unlike /synthesize, this endpoint swallowed 400 validation errors);
    bare ``except:`` in cleanup narrowed to ``OSError``.
    """
    start_time = time.time()
    if tts_model is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")
    try:
        # Convert and validate the reference audio in memory.
        wav_data, audio_duration = await AudioStreamProcessor.convert_audio_to_wav_memory(reference_audio)
        await AudioStreamProcessor.validate_audio_duration(audio_duration)
        # The model needs a real file path; use a short-lived /tmp file.
        temp_ref_path = f"/tmp/ref_stream_{uuid4().hex}.wav"
        try:
            async with aiofiles.open(temp_ref_path, 'wb') as f:
                await f.write(wav_data)
            ref_codes = tts_model.encode_reference(temp_ref_path)
            wav_output = tts_model.infer(text, ref_codes, reference_text)
            processing_time = time.time() - start_time
            logger.info(f"Real-time synthesis completed in {processing_time:.2f}s")
            # Encode to WAV bytes in memory.
            wav_buffer = io.BytesIO()
            sf.write(wav_buffer, wav_output, config.SAMPLE_RATE, format='WAV')
            wav_bytes = wav_buffer.getvalue()

            async def generate_stream():
                chunk_size = config.CHUNK_SIZE
                for i in range(0, len(wav_bytes), chunk_size):
                    yield wav_bytes[i:i + chunk_size]
                    await asyncio.sleep(0.001)

            return StreamingResponse(
                generate_stream(),
                media_type="audio/wav",
                headers={
                    "Content-Disposition": "attachment; filename=speech_stream.wav",
                    "Cache-Control": "no-cache",
                    "X-Processing-Time": f"{processing_time:.2f}"
                }
            )
        finally:
            if os.path.exists(temp_ref_path):
                try:
                    os.remove(temp_ref_path)
                except OSError:
                    pass
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Stream synthesis error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Stream synthesis failed: {str(e)}")
@app.delete("/cache/{audio_id}")
async def clear_cached_audio(audio_id: str):
    """Remove one entry from the in-memory audio cache.

    NOTE(review): the route decorator was missing; DELETE on
    "/cache/{audio_id}" is inferred — confirm against the original deployment.
    """
    if audio_id in audio_cache.cache:
        del audio_cache.cache[audio_id]
        # Keep the LRU bookkeeping consistent with the cache dict.
        if audio_id in audio_cache.access_order:
            audio_cache.access_order.remove(audio_id)
        return {"message": f"Audio {audio_id} cleared from cache"}
    else:
        raise HTTPException(status_code=404, detail="Audio not found in cache")
@app.delete("/cache")
async def clear_all_cache():
    """Drop every cached audio entry and report how many were removed.

    NOTE(review): the route decorator was missing; DELETE on "/cache" is
    inferred — confirm against the original deployment.
    """
    cache_size = len(audio_cache.cache)
    audio_cache.cache.clear()
    audio_cache.access_order.clear()
    return {"message": f"Cleared all {cache_size} cached audio files"}
async def cache_cleanup_task():
    """Background loop that evicts cache entries idle for over an hour.

    Runs forever; started from the app lifespan. Errors are logged and
    the loop keeps going.
    """
    while True:
        await asyncio.sleep(CACHE_CLEANUP_INTERVAL)
        try:
            # Entries untouched for more than an hour are expired.
            cutoff = time.time() - 3600
            expired_ids = [
                audio_id
                for audio_id, entry in audio_cache.cache.items()
                if entry['accessed_at'] < cutoff
            ]
            for audio_id in expired_ids:
                audio_cache.cache.pop(audio_id, None)
                if audio_id in audio_cache.access_order:
                    audio_cache.access_order.remove(audio_id)
            if expired_ids:
                logger.info(f"Cache cleanup removed {len(expired_ids)} expired entries")
        except Exception as e:
            logger.error(f"Cache cleanup error: {e}")
# Run the API with uvicorn when executed directly.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        workers=1,
        log_level="info",
    )