# NeuTTS Air API — FastAPI service for on-device text-to-speech with voice
# cloning, deployed as a Hugging Face Space (listens on port 7860).
# Standard library
import gc
import logging
import os
import sys
import time
from typing import Any, Dict, Optional

# Third-party
import aiofiles
import numpy as np
import psutil
import torch
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse
from pydantic import BaseModel

# Make the vendored NeuTTS Air checkout importable.
sys.path.append("neutts-air")

# Module-level logger for the whole service.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# The FastAPI application object.
app = FastAPI(
    title="NeuTTS Air API",
    description="High-quality on-device Text-to-Speech with instant voice cloning",
    version="1.0.0",
)

# Wide-open CORS: this is a public demo API consumed from arbitrary origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Lazily-populated model singleton, plus a flag so overlapping callers do not
# start a second load while one is already in progress.
tts_model = None
model_loading = False
# ---- Pydantic request/response schemas ----

class TTSRequest(BaseModel):
    """Request body for synthesis from a server-side reference audio path."""
    text: str                                   # text to synthesize
    reference_text: str                         # transcript of the reference audio
    reference_audio_path: Optional[str] = None  # path to reference audio on the server
class TTSResponse(BaseModel):
    """Result of a synthesis call."""
    success: bool                            # whether synthesis completed
    audio_url: Optional[str] = None          # relative URL of the generated WAV
    message: Optional[str] = None            # human-readable status message
    processing_time: Optional[float] = None  # wall-clock seconds spent
    audio_duration: Optional[float] = None   # duration of generated audio in seconds
class HealthResponse(BaseModel):
    """Service health snapshot returned by the health endpoint."""
    status: str                    # "healthy" or "degraded"
    model_loaded: bool             # True once the TTS model singleton exists
    memory_usage: Dict[str, float] # host RAM stats (GB / percent)
    disk_usage: Dict[str, float]   # host disk stats (GB / percent)
def load_tts_model():
    """Load the NeuTTS Air model into the module-level ``tts_model`` singleton.

    No-op if the model is already loaded or a load is in progress.

    Raises:
        Exception: whatever the import or model construction raised.  The
        ``model_loading`` flag is always cleared (via ``finally``) so a
        failed load can be retried on a later call — the original code reset
        the flag on both paths but via duplicated statements.
    """
    global tts_model, model_loading
    if tts_model is not None or model_loading:
        return
    model_loading = True
    try:
        logger.info("Loading NeuTTS Air model...")
        # Try the normal import first, then the container install location.
        try:
            from neuttsair.neutts import NeuTTSAir
        except ImportError as e:
            logger.error(f"Failed to import NeuTTS Air: {e}")
            sys.path.insert(0, "/app/neutts-air")
            from neuttsair.neutts import NeuTTSAir
        # CPU-only devices: targets the Hugging Face free tier.
        tts_model = NeuTTSAir(
            backbone_repo="neuphonic/neutts-air",
            backbone_device="cpu",
            codec_repo="neuphonic/neucodec",
            codec_device="cpu",
        )
        logger.info("NeuTTS Air model loaded successfully!")
    except Exception as e:
        logger.error(f"Failed to load model: {str(e)}")
        raise  # bare raise preserves the original traceback (was `raise e`)
    finally:
        model_loading = False
# NOTE(review): no decorator survived in the retrieved source; this handler is
# clearly intended to run at application startup — confirm against the
# original deployment.
@app.on_event("startup")
async def startup_event():
    """Load model on startup with error handling.

    Failures are logged but not re-raised so the app still comes up and can
    report its state via the health endpoint.
    """
    try:
        load_tts_model()
    except Exception as e:
        logger.error(f"Startup model loading failed: {e}")
# NOTE(review): route decorator missing in the retrieved source; "/" is the
# conventional path for this liveness message — confirm.
@app.get("/")
async def root():
    """Liveness endpoint: confirms the API process is running."""
    return {"message": "NeuTTS Air API is running!", "status": "healthy"}
# NOTE(review): route decorator missing in the retrieved source; "/health" is
# assumed — confirm the original path.
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint: model state plus host memory/disk statistics.

    Returns a "degraded" response (with empty stats) if psutil fails rather
    than erroring out.
    """
    try:
        memory = psutil.virtual_memory()
        disk = psutil.disk_usage('/')
        return HealthResponse(
            status="healthy",
            model_loaded=tts_model is not None,
            memory_usage={
                "total_gb": round(memory.total / (1024**3), 2),
                "available_gb": round(memory.available / (1024**3), 2),
                "used_percent": round(memory.percent, 2),
            },
            disk_usage={
                "total_gb": round(disk.total / (1024**3), 2),
                "free_gb": round(disk.free / (1024**3), 2),
                "used_percent": round(disk.percent, 2),
            },
        )
    except Exception as e:
        # Log the failure; the response fields are Dict[str, float], so the
        # previous {"error": str(e)} payload failed pydantic validation and
        # turned a degraded report into a 500.  Empty dicts stay valid.
        logger.error(f"Health check stats unavailable: {e}")
        return HealthResponse(
            status="degraded",
            model_loaded=tts_model is not None,
            memory_usage={},
            disk_usage={},
        )
# NOTE(review): route decorator missing in the retrieved source; a POST
# "/synthesize" endpoint is assumed — confirm the original path.
@app.post("/synthesize", response_model=TTSResponse)
async def synthesize_speech(
    reference_text: str = Form(...),
    text: str = Form(...),
    reference_audio: UploadFile = File(...),
):
    """Synthesize speech using an uploaded reference audio clip and its transcript.

    Args:
        reference_text: transcript of the reference audio.
        text: text to synthesize (max 1000 characters).
        reference_audio: uploaded reference clip (2-30 s).

    Raises:
        HTTPException: 503 if the model is not loaded, 400 on invalid input,
        500 on synthesis failure.
    """
    start_time = time.time()
    if tts_model is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")
    # Validate inputs.
    if not reference_text.strip() or not text.strip():
        raise HTTPException(status_code=400, detail="Text fields cannot be empty")
    if len(text) > 1000:
        raise HTTPException(status_code=400, detail="Text too long. Maximum 1000 characters allowed.")

    temp_ref_path = None
    try:
        # Persist the upload to a temp file so librosa/NeuTTS can read a path.
        temp_dir = "temp_audio"
        os.makedirs(temp_dir, exist_ok=True)
        file_extension = os.path.splitext(reference_audio.filename)[1] or ".wav"
        temp_ref_path = os.path.join(temp_dir, f"ref_{int(time.time())}{file_extension}")
        async with aiofiles.open(temp_ref_path, 'wb') as out_file:
            content = await reference_audio.read()
            await out_file.write(content)

        # Validate the audio file.  Only decode errors are wrapped as
        # "Invalid audio file"; previously the duration HTTPException was
        # swallowed by this handler and re-raised with the wrong detail.
        try:
            import librosa
            audio_duration = librosa.get_duration(path=temp_ref_path)
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Invalid audio file: {str(e)}")
        if audio_duration < 2 or audio_duration > 30:
            # Message now matches the actual 2-30 s check (was "3-15 seconds").
            raise HTTPException(
                status_code=400,
                detail=f"Audio duration ({audio_duration:.1f}s) should be between 2 and 30 seconds",
            )

        # Perform TTS.
        logger.info(f"Starting synthesis for text: {text[:50]}...")
        ref_codes = tts_model.encode_reference(temp_ref_path)
        wav = tts_model.infer(text, ref_codes, reference_text)

        # Save output (24 kHz WAV, the model's output rate).
        output_dir = "generated_audio"
        os.makedirs(output_dir, exist_ok=True)
        output_filename = f"output_{int(time.time())}.wav"
        output_path = os.path.join(output_dir, output_filename)
        import soundfile as sf
        sf.write(output_path, wav, 24000)

        processing_time = time.time() - start_time
        audio_duration = len(wav) / 24000
        logger.info(f"Synthesis completed in {processing_time:.2f}s")
        return TTSResponse(
            success=True,
            audio_url=f"/audio/{output_filename}",
            message="Speech synthesized successfully",
            processing_time=round(processing_time, 2),
            audio_duration=round(audio_duration, 2),
        )
    except HTTPException:
        # Let deliberate 4xx responses through; previously the generic
        # handler below converted them into 500s.
        raise
    except Exception as e:
        logger.error(f"Synthesis error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Synthesis failed: {str(e)}")
    finally:
        # Best-effort removal of the temporary reference file.
        if temp_ref_path and os.path.exists(temp_ref_path):
            try:
                os.remove(temp_ref_path)
            except OSError:
                pass
# NOTE(review): route decorator missing in the retrieved source; the path is
# grounded in the "/audio/{filename}" URLs returned by the synthesis endpoints.
@app.get("/audio/{filename}")
async def get_audio_file(filename: str):
    """Serve a previously generated audio file by name.

    Raises:
        HTTPException: 404 if the file does not exist.
    """
    # basename() defends against path traversal (e.g. "../../etc/passwd").
    safe_name = os.path.basename(filename)
    file_path = os.path.join("generated_audio", safe_name)
    if not os.path.exists(file_path):
        raise HTTPException(status_code=404, detail="Audio file not found")
    return FileResponse(
        file_path,
        media_type="audio/wav",
        # Fixed: the download name was the literal "generated_speech_(unknown)"
        # (a leaked placeholder); interpolate the requested filename instead.
        filename=f"generated_speech_{safe_name}",
    )
# NOTE(review): route decorator missing in the retrieved source; a POST
# "/synthesize-url" endpoint is assumed — confirm the original path.
@app.post("/synthesize-url", response_model=TTSResponse)
async def synthesize_with_url(request: TTSRequest):
    """Synthesize speech using a pre-uploaded reference audio file path.

    Raises:
        HTTPException: 503 if the model is not loaded, 400 on invalid input,
        500 on synthesis failure.
    """
    start_time = time.time()
    if tts_model is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")
    if not request.reference_audio_path or not os.path.exists(request.reference_audio_path):
        raise HTTPException(status_code=400, detail="Reference audio path not found")
    try:
        # Validate the reference audio's duration.
        import librosa
        audio_duration = librosa.get_duration(path=request.reference_audio_path)
        if audio_duration < 2 or audio_duration > 30:
            # Message now matches the actual 2-30 s check (was "3-15 seconds").
            raise HTTPException(
                status_code=400,
                detail=f"Audio duration ({audio_duration:.1f}s) should be between 2 and 30 seconds",
            )

        # Perform TTS.
        logger.info(f"Starting synthesis for text: {request.text[:50]}...")
        ref_codes = tts_model.encode_reference(request.reference_audio_path)
        wav = tts_model.infer(request.text, ref_codes, request.reference_text)

        # Save output (24 kHz WAV, the model's output rate).
        output_dir = "generated_audio"
        os.makedirs(output_dir, exist_ok=True)
        output_filename = f"output_{int(time.time())}.wav"
        output_path = os.path.join(output_dir, output_filename)
        import soundfile as sf
        sf.write(output_path, wav, 24000)

        processing_time = time.time() - start_time
        audio_duration = len(wav) / 24000
        return TTSResponse(
            success=True,
            audio_url=f"/audio/{output_filename}",
            message="Speech synthesized successfully",
            processing_time=round(processing_time, 2),
            audio_duration=round(audio_duration, 2),
        )
    except HTTPException:
        # Let deliberate 4xx responses through; previously the generic
        # handler below converted the duration 400 into a 500.
        raise
    except Exception as e:
        logger.error(f"Synthesis error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Synthesis failed: {str(e)}")
def _purge_old_files(directory: str, max_age_seconds: float, now: float) -> int:
    """Delete regular files in `directory` older than `max_age_seconds`; return the count."""
    deleted = 0
    if os.path.exists(directory):
        for filename in os.listdir(directory):
            file_path = os.path.join(directory, filename)
            # getctime: age is measured from creation/metadata-change time.
            if os.path.isfile(file_path) and now - os.path.getctime(file_path) > max_age_seconds:
                os.remove(file_path)
                deleted += 1
    return deleted


# NOTE(review): route decorator missing in the retrieved source; a POST
# "/cleanup" endpoint is assumed — confirm the original path/method.
@app.post("/cleanup")
async def cleanup_audio_files():
    """Clean up generated and temporary audio files older than 1 hour.

    Raises:
        HTTPException: 500 if the cleanup sweep fails.
    """
    try:
        now = time.time()
        # Same 1-hour sweep for both directories; loop factored into a helper
        # (the original duplicated the walk/remove logic verbatim).
        deleted_count = _purge_old_files("generated_audio", 3600, now)
        deleted_count += _purge_old_files("temp_audio", 3600, now)
        return {"message": f"Cleaned up {deleted_count} files"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Cleanup failed: {str(e)}")
if __name__ == "__main__":
    # Direct execution: serve on all interfaces at the Spaces-standard port.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)