Spaces:
Sleeping
Sleeping
| """ | |
| Minimal Text-to-Speech API using Coqui TTS VITS model | |
| FastAPI application for Hugging Face Spaces | |
| """ | |
import logging
import os
import tempfile
import threading
from pathlib import Path
from typing import Optional

import uvicorn
from fastapi import BackgroundTasks, FastAPI, Form, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import FileResponse
from pydantic import BaseModel
# Import TTS
try:
    from TTS.api import TTS
except ImportError as exc:
    # Re-raise with an install hint, keeping the original import failure as the cause.
    raise ImportError("TTS library not found. Install with: pip install TTS") from exc

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(
    title="Text-to-Speech API",
    description="Minimal TTS API using Coqui TTS VITS model",
    version="1.0.0",
)

# Global TTS model instance: assigned by startup_event(); None until that runs.
tts_model = None
# Request model
class TTSRequest(BaseModel):
    """Body of a POST TTS request carrying the text to synthesize."""

    # Text to convert to speech; blank values are rejected by the endpoint, not here.
    text: str
@app.on_event("startup")
async def startup_event():
    """
    Load the TTS model once at startup.

    Uses the VITS model for the LJSpeech dataset and stores the instance in the
    module-level ``tts_model``. A loading failure is logged with its traceback
    and re-raised, so it propagates into the application's startup sequence.
    """
    global tts_model
    try:
        logger.info("Loading TTS model...")
        # Use the specific VITS model requested
        model_name = "tts_models/en/ljspeech/vits"
        tts_model = TTS(model_name=model_name, progress_bar=False)
        logger.info("TTS model loaded successfully!")
    except Exception as e:
        logger.exception("Failed to load TTS model: %s", e)
        # Bare raise keeps the original traceback untouched.
        raise
# NOTE(review): "/" path reconstructed (the route decorator was absent) - confirm.
@app.get("/")
async def root():
    """Health check endpoint: report that the API is running and which model/engine it uses."""
    return {
        "status": "healthy",
        "message": "Text-to-Speech API is running",
        "model": "tts_models/en/ljspeech/vits",
        "engine": "Coqui TTS",
    }
@app.get("/tts")
async def tts_get(text: str):
    """
    Simple GET endpoint for TTS.

    Usage: GET /tts?text=Hello%20world

    Args:
        text: Text to synthesize, taken from the ``text`` query parameter.

    Returns:
        The audio response produced by ``generate_speech``.

    Raises:
        HTTPException: 400 if ``text`` is empty or whitespace-only.
    """
    if not text or not text.strip():
        raise HTTPException(status_code=400, detail="Text parameter is required")
    return await generate_speech(text)
@app.post("/tts")
async def tts_post(
    request: Optional[TTSRequest] = None,
    text: Optional[str] = Form(None),
):
    """
    POST endpoint for TTS.

    Intended to accept either a JSON body (``TTSRequest``) or a ``text`` form field.
    The model takes precedence when both are present.

    Args:
        request: Parsed request model, if one was supplied.
        text: Value of the ``text`` form field, if one was supplied.

    Returns:
        The audio response produced by ``generate_speech``.

    Raises:
        HTTPException: 400 if no text was supplied, or if it is empty/whitespace.
    """
    # NOTE(review): per the FastAPI docs, declaring any Form() parameter makes the
    # body be parsed as form data, so a JSON body probably never populates
    # `request`. Confirm this, and add a separate JSON-only route if JSON must work.
    if request is not None:
        input_text = request.text
    elif text:
        input_text = text
    else:
        raise HTTPException(status_code=400, detail="Text is required")

    if not input_text or not input_text.strip():
        raise HTTPException(status_code=400, detail="Text cannot be empty")
    return await generate_speech(input_text)
# Serializes model inference. Synthesis now runs in a worker thread; without this
# lock, concurrent requests could drive the shared model in parallel. The previous
# inline blocking call serialized them implicitly.
# NOTE(review): drop the lock if the Coqui model is confirmed thread-safe.
_SYNTHESIS_LOCK = threading.Lock()


def _synthesize_to_file(text: str, output_path: str) -> None:
    """Run the blocking VITS synthesis into ``output_path``, one call at a time."""
    with _SYNTHESIS_LOCK:
        tts_model.tts_to_file(text=text, file_path=output_path)


def _remove_file(path: str) -> None:
    """Delete a temporary audio file; ignore a missing file and log other OS errors."""
    try:
        os.unlink(path)
    except FileNotFoundError:
        pass
    except OSError as exc:
        logger.warning("Could not remove temporary file %s: %s", path, exc)


async def generate_speech(text: str):
    """
    Generate speech from text using the VITS model and return it as a WAV download.

    The audio is written to a temporary ``.wav`` file. Synthesis runs in a worker
    thread so the event loop stays responsive while the model works. The temporary
    file is removed by a background task after the response has been sent, or
    immediately if generation fails.

    Args:
        text: Text to synthesize. The endpoints validate that it is non-blank.

    Returns:
        FileResponse delivering the audio as ``speech.wav``.

    Raises:
        HTTPException: 503 if the model has not been loaded; 500 if synthesis
            fails or produces an empty/missing file.
    """
    if tts_model is None:
        raise HTTPException(status_code=503, detail="TTS model not loaded")

    output_path = None
    try:
        # Only reserve a unique path; the handle is closed so the TTS writer can reopen it.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
            output_path = tmp_file.name

        logger.info("Generating speech for text: '%s...'", text[:50])
        await run_in_threadpool(_synthesize_to_file, text, output_path)

        # Verify the file was created and has content
        try:
            file_size = os.path.getsize(output_path)
        except OSError:
            file_size = 0
        if file_size == 0:
            raise RuntimeError("Generated audio file is empty or was not created")
        logger.info("Speech generated successfully, file size: %d bytes", file_size)

        # Delete the temp file once the response body has been sent.
        cleanup = BackgroundTasks()
        cleanup.add_task(_remove_file, output_path)
        return FileResponse(
            path=output_path,
            media_type="audio/wav",
            filename="speech.wav",
            headers={
                "Content-Disposition": "attachment; filename=speech.wav",
                "Cache-Control": "no-cache",
            },
            background=cleanup,
        )
    except Exception as e:
        logger.exception("Error generating speech: %s", e)
        # Clean up output file on error (best effort).
        if output_path is not None:
            _remove_file(output_path)
        raise HTTPException(status_code=500, detail=f"Failed to generate speech: {str(e)}") from e
# NOTE(review): "/health" path reconstructed (the route decorator was absent) - confirm.
@app.get("/health")
async def health_check():
    """Health check endpoint that also reports whether the TTS model has been loaded."""
    return {
        "status": "healthy",
        "model_loaded": tts_model is not None,
        "model_name": "tts_models/en/ljspeech/vits",
    }
if __name__ == "__main__":
    # Direct-execution entry point, used for local development and HF Spaces:
    # serve the app on every interface at port 7860.
    server_host, server_port = "0.0.0.0", 7860
    uvicorn.run(app, host=server_host, port=server_port)