# neu / app.py — Hugging Face Space file by Rajhuggingface4253
# Last commit: "Update app.py" (afe82fe, verified), 10.9 kB
import os
import sys
import time
import gc
import torch
import numpy as np
import aiofiles
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.responses import JSONResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional, Dict, Any
import psutil
import logging
# Add NeuTTS Air to path.
# The CWD-relative entry only works when the server is started from the app
# directory; the entry resolved relative to this file works from anywhere.
# Both are kept so existing launch setups behave as before.
_NEUTTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "neutts-air")
for _neutts_path in ("neutts-air", _NEUTTS_DIR):
    if _neutts_path not in sys.path:
        sys.path.append(_neutts_path)
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# FastAPI application; the metadata below shows up in the generated OpenAPI docs.
_app_metadata = {
    "title": "NeuTTS Air API",
    "description": "High-quality on-device Text-to-Speech with instant voice cloning",
    "version": "1.0.0",
}
app = FastAPI(**_app_metadata)

# CORS middleware: accept requests from any origin, with any method and header.
# NOTE(review): allow_credentials=True is combined with a "*" origin. Starlette's
# CORSMiddleware is believed to echo the caller's Origin back in that case, which
# would let any site send credentialed requests. Nothing in this file uses cookies
# or auth headers, so credentials may be unnecessary — confirm before tightening.
_cors_options = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)
# Global model instance (assigned by load_tts_model; stays None until a load
# succeeds) and a flag that is True while a load is in progress.
tts_model, model_loading = None, False
# Pydantic models
# JSON request body for POST /synthesize-with-url.
class TTSRequest(BaseModel):
    # Text to synthesize; passed to tts_model.infer as the text argument.
    text: str
    # Text accompanying the reference clip (presumably its transcript — confirm);
    # passed to tts_model.infer together with the encoded reference.
    reference_text: str
    # Path on the server's filesystem to the reference audio clip. Optional at the
    # schema level, but the endpoint rejects requests where it is missing or absent.
    reference_audio_path: Optional[str] = None
# Response body returned by the synthesis endpoints.
class TTSResponse(BaseModel):
    # True when synthesis finished and the output WAV was written.
    success: bool
    # Relative URL of the generated WAV, served by GET /audio/{filename}.
    audio_url: Optional[str] = None
    # Human-readable status text.
    message: Optional[str] = None
    # Seconds spent handling the request (endpoints round to 2 decimals).
    processing_time: Optional[float] = None
    # Length of the generated audio in seconds: sample count / 24000,
    # rounded to 2 decimals by the endpoints.
    audio_duration: Optional[float] = None
# Response body for GET /health.
class HealthResponse(BaseModel):
    # Overall service state; health_check sets "healthy" or "degraded".
    status: str
    # Whether the global TTS model has been loaded.
    model_loaded: bool
    # RAM statistics keyed by name (total_gb, available_gb, used_percent).
    # NOTE(review): values must be numeric — pydantic rejects non-float values
    # such as error strings for this field.
    memory_usage: Dict[str, float]
    # Root-filesystem statistics keyed by name (total_gb, free_gb, used_percent).
    disk_usage: Dict[str, float]
def load_tts_model():
    """Load the NeuTTS Air model into the module-level ``tts_model`` global.

    Returns immediately if a model is already loaded or a load is in progress
    (``model_loading``). Backbone and codec are both placed on the CPU.

    Raises:
        ImportError: if ``neuttsair`` cannot be imported from the current
            ``sys.path`` nor from ``/app/neutts-air``.
        Exception: anything raised by the ``NeuTTSAir`` constructor. The
            error is logged with its traceback and re-raised unchanged.
    """
    global tts_model, model_loading
    if tts_model is not None or model_loading:
        return
    model_loading = True
    try:
        logger.info("Loading NeuTTS Air model...")
        try:
            from neuttsair.neutts import NeuTTSAir
        except ImportError as e:
            # Not fatal yet: retry with the container checkout on sys.path.
            logger.warning(f"Failed to import NeuTTS Air: {e}")
            fallback_path = "/app/neutts-air"
            if fallback_path not in sys.path:  # avoid piling up duplicates on retries
                sys.path.insert(0, fallback_path)
            from neuttsair.neutts import NeuTTSAir
        # Use CPU for Hugging Face free tier
        tts_model = NeuTTSAir(
            backbone_repo="neuphonic/neutts-air",
            backbone_device="cpu",
            codec_repo="neuphonic/neucodec",
            codec_device="cpu",
        )
        logger.info("NeuTTS Air model loaded successfully!")
    except Exception as e:
        logger.exception(f"Failed to load model: {str(e)}")
        raise
    finally:
        # Always clear the in-progress flag so a later call can retry.
        model_loading = False
@app.on_event("startup")
async def startup_event():
    """Try to load the TTS model once, when the application starts.

    Any exception from ``load_tts_model`` is logged and swallowed so the server
    still starts. With no model loaded, the synthesis endpoints reply 503.
    """
    try:
        load_tts_model()
    except Exception as load_error:
        # Deliberately non-fatal: "/" and "/health" keep working without a model.
        logger.error(f"Startup model loading failed: {load_error}")
@app.get("/")
async def root():
    """Return a static message confirming the API process is serving requests."""
    payload = dict(message="NeuTTS Air API is running!", status="healthy")
    return payload
@app.get("/health")
async def health_check():
    """Health check endpoint

    Reports whether the TTS model is loaded plus host RAM and root-disk usage.

    Returns:
        A ``HealthResponse`` with ``status="healthy"`` when the psutil probes
        succeed. If a probe raises, a ``"degraded"`` JSON body (HTTP 200) with
        the error text under ``memory_usage``/``disk_usage`` is returned instead.
    """
    gib = 1024 ** 3
    try:
        memory = psutil.virtual_memory()
        disk = psutil.disk_usage('/')
    except Exception as e:
        logger.warning(f"Health probe failed: {e}")
        # HealthResponse types these fields as Dict[str, float], so building it
        # with string error values fails validation and would turn the degraded
        # report into a 500. Send the same shape as a plain JSON body instead.
        return JSONResponse(content={
            "status": "degraded",
            "model_loaded": tts_model is not None,
            "memory_usage": {"error": str(e)},
            "disk_usage": {"error": str(e)},
        })
    return HealthResponse(
        status="healthy",
        model_loaded=tts_model is not None,
        memory_usage={
            "total_gb": round(memory.total / gib, 2),
            "available_gb": round(memory.available / gib, 2),
            "used_percent": round(memory.percent, 2),
        },
        disk_usage={
            "total_gb": round(disk.total / gib, 2),
            "free_gb": round(disk.free / gib, 2),
            "used_percent": round(disk.percent, 2),
        },
    )
@app.post("/synthesize")
async def synthesize_speech(
    reference_text: str = Form(...),
    text: str = Form(...),
    reference_audio: UploadFile = File(...)
):
    """
    Synthesize speech using reference audio and text

    The uploaded ``reference_audio`` clip is encoded with the loaded model and
    ``text`` is generated from it together with ``reference_text``.

    Side effects:
        The upload is stored under ``temp_audio/`` and removed before
        returning. The result is written as a 24000 Hz WAV to
        ``generated_audio/``.

    Returns:
        A ``TTSResponse`` whose ``audio_url`` points at ``/audio/<file>``.

    Raises:
        HTTPException: 503 if the model is not loaded; 400 for empty or
            over-long text, an unreadable reference clip, or a clip outside
            the accepted 2-30 s range; 500 for any other failure.
    """
    import uuid

    start_time = time.time()
    if tts_model is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")
    # Validate inputs
    if not reference_text.strip() or not text.strip():
        raise HTTPException(status_code=400, detail="Text fields cannot be empty")
    if len(text) > 1000:
        raise HTTPException(status_code=400, detail="Text too long. Maximum 1000 characters allowed.")

    temp_ref_path = None
    try:
        # Save uploaded file temporarily
        temp_dir = "temp_audio"
        os.makedirs(temp_dir, exist_ok=True)
        # UploadFile.filename can be None. Only trust a plain alphanumeric
        # suffix from the client and fall back to .wav otherwise.
        file_extension = os.path.splitext(reference_audio.filename or "")[1]
        if not file_extension[1:].isalnum():
            file_extension = ".wav"
        # Random name instead of a per-second timestamp: two uploads in the same
        # second would otherwise overwrite, and then delete, each other's clip.
        temp_ref_path = os.path.join(temp_dir, f"ref_{uuid.uuid4().hex}{file_extension}")
        async with aiofiles.open(temp_ref_path, 'wb') as out_file:
            # NOTE(review): the whole upload is read into memory with no size cap.
            content = await reference_audio.read()
            await out_file.write(content)

        # Validate audio file. Only the decode step is wrapped, so the duration
        # check's own 400 below is not re-wrapped as "Invalid audio file".
        try:
            import librosa
            ref_duration = librosa.get_duration(path=temp_ref_path)
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Invalid audio file: {str(e)}") from e
        if ref_duration < 2 or ref_duration > 30:
            raise HTTPException(
                status_code=400,
                detail=f"Audio duration ({ref_duration:.1f}s) should be between 2-30 seconds",
            )

        # Perform TTS
        logger.info(f"Starting synthesis for text: {text[:50]}...")
        # NOTE(review): encode_reference/infer are plain synchronous calls inside an
        # async handler, so the event loop is blocked while they run. Offloading to
        # a worker thread needs the model to be safe for concurrent use — confirm first.
        ref_codes = tts_model.encode_reference(temp_ref_path)
        wav = tts_model.infer(text, ref_codes, reference_text)

        # Save output (unique name for the same reason as the temp file above)
        output_dir = "generated_audio"
        os.makedirs(output_dir, exist_ok=True)
        output_filename = f"output_{uuid.uuid4().hex}.wav"
        output_path = os.path.join(output_dir, output_filename)
        import soundfile as sf
        sample_rate = 24000
        sf.write(output_path, wav, sample_rate)

        processing_time = time.time() - start_time
        output_duration = len(wav) / sample_rate
        logger.info(f"Synthesis completed in {processing_time:.2f}s")
        return TTSResponse(
            success=True,
            audio_url=f"/audio/{output_filename}",
            message="Speech synthesized successfully",
            processing_time=round(processing_time, 2),
            audio_duration=round(output_duration, 2),
        )
    except HTTPException:
        # Validation errors already carry their status code. Without this
        # clause, the generic handler below would turn every 400 into a 500.
        raise
    except Exception as e:
        logger.exception(f"Synthesis error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Synthesis failed: {str(e)}") from e
    finally:
        # Clean up temporary file
        if temp_ref_path and os.path.exists(temp_ref_path):
            try:
                os.remove(temp_ref_path)
            except OSError as cleanup_error:
                # Best effort (the /cleanup endpoint sweeps temp_audio), but record it.
                logger.warning(f"Could not remove temp file {temp_ref_path}: {cleanup_error}")
@app.get("/audio/{filename}")
async def get_audio_file(filename: str):
    """Serve generated audio files

    ``filename`` comes straight from the URL. It is only served when it names a
    regular file located directly inside ``generated_audio/`` (after resolving
    ``..`` and symlinks). Anything else — parent references, absolute paths,
    directories, missing files — gets a 404.
    """
    output_dir = os.path.realpath("generated_audio")
    file_path = os.path.realpath(os.path.join(output_dir, filename))
    if os.path.dirname(file_path) != output_dir or not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="Audio file not found")
    return FileResponse(
        file_path,
        media_type="audio/wav",
        filename=f"generated_speech_{filename}",
    )
@app.post("/synthesize-with-url")
async def synthesize_with_url(request: TTSRequest):
    """
    Synthesize speech using a pre-uploaded reference audio file path

    Same pipeline as ``/synthesize``, but the reference clip is read from
    ``request.reference_audio_path`` on the server's filesystem instead of
    being uploaded with the request. The result is written as a 24000 Hz WAV
    to ``generated_audio/``.

    Raises:
        HTTPException: 503 if the model is not loaded; 400 for empty or
            over-long text, a missing reference path, an unreadable clip, or a
            clip outside the accepted 2-30 s range; 500 for any other failure.
    """
    import uuid

    start_time = time.time()
    if tts_model is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")
    # Same text limits as /synthesize so both endpoints share one contract.
    if not request.reference_text.strip() or not request.text.strip():
        raise HTTPException(status_code=400, detail="Text fields cannot be empty")
    if len(request.text) > 1000:
        raise HTTPException(status_code=400, detail="Text too long. Maximum 1000 characters allowed.")
    # NOTE(review): this is a raw server-side path chosen by the client, so any
    # readable file on the host can be opened here, and its existence probed via
    # the error responses. Consider restricting it to a known upload directory.
    ref_path = request.reference_audio_path
    if not ref_path or not os.path.isfile(ref_path):
        raise HTTPException(status_code=400, detail="Reference audio path not found")

    try:
        # Validate audio file. Only the decode step is wrapped, so the duration
        # check's own 400 is not re-wrapped.
        try:
            import librosa
            ref_duration = librosa.get_duration(path=ref_path)
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Invalid audio file: {str(e)}") from e
        if ref_duration < 2 or ref_duration > 30:
            raise HTTPException(
                status_code=400,
                detail=f"Audio duration ({ref_duration:.1f}s) should be between 2-30 seconds",
            )

        # Perform TTS
        logger.info(f"Starting synthesis for text: {request.text[:50]}...")
        # NOTE(review): synchronous model calls inside an async handler block the
        # event loop while they run; see the same note on /synthesize.
        ref_codes = tts_model.encode_reference(ref_path)
        wav = tts_model.infer(request.text, ref_codes, request.reference_text)

        # Save output. A random name instead of a per-second timestamp keeps
        # concurrent requests from overwriting each other's result.
        output_dir = "generated_audio"
        os.makedirs(output_dir, exist_ok=True)
        output_filename = f"output_{uuid.uuid4().hex}.wav"
        output_path = os.path.join(output_dir, output_filename)
        import soundfile as sf
        sample_rate = 24000
        sf.write(output_path, wav, sample_rate)

        processing_time = time.time() - start_time
        output_duration = len(wav) / sample_rate
        return TTSResponse(
            success=True,
            audio_url=f"/audio/{output_filename}",
            message="Speech synthesized successfully",
            processing_time=round(processing_time, 2),
            audio_duration=round(output_duration, 2),
        )
    except HTTPException:
        # Keep validation errors as 400s instead of rewrapping them as 500s.
        raise
    except Exception as e:
        logger.exception(f"Synthesis error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Synthesis failed: {str(e)}") from e
@app.delete("/cleanup")
async def cleanup_audio_files():
    """Clean up generated audio files older than 1 hour

    Sweeps ``generated_audio/`` and ``temp_audio/`` and deletes regular files
    whose last modification is more than an hour ago. Missing directories are
    skipped.

    Returns:
        ``{"message": "Cleaned up N files"}``.

    Raises:
        HTTPException: 500 on any unexpected error.
    """
    max_age_seconds = 3600  # 1 hour
    try:
        deleted_count = 0
        current_time = time.time()
        for directory in ("generated_audio", "temp_audio"):
            if not os.path.isdir(directory):
                continue
            for filename in os.listdir(directory):
                file_path = os.path.join(directory, filename)
                try:
                    if not os.path.isfile(file_path):
                        continue
                    # mtime rather than ctime: on POSIX ctime is the last metadata
                    # change, not creation; mtime is when the audio was written.
                    if current_time - os.path.getmtime(file_path) <= max_age_seconds:
                        continue
                    os.remove(file_path)
                except FileNotFoundError:
                    # Removed concurrently (e.g. by a request's own temp-file
                    # cleanup); that must not abort the whole sweep.
                    continue
                deleted_count += 1
        return {"message": f"Cleaned up {deleted_count} files"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Cleanup failed: {str(e)}") from e
if __name__ == "__main__":
    # Direct-run entry point: serve the app with uvicorn on all interfaces.
    import uvicorn

    _listen = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **_listen)