Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, HTTPException, UploadFile, File, Form | |
| from fastapi.responses import FileResponse | |
| from pydantic import BaseModel | |
| from TTS.api import TTS | |
| import os | |
| import tempfile | |
| import uuid | |
| import torch | |
| from typing import Optional | |
| import logging | |
# Module-wide logging: INFO and above goes to the root handler.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The ASGI application object; the metadata below is shown in the OpenAPI docs.
app = FastAPI(
    title="TTS API",
    description="Text-to-Speech API using XTTS-v2",
    version="1.0.0",
)
class TTSRequest(BaseModel):
    """Request body carrying the text to synthesize and its language code."""

    # Text to convert to speech (required).
    text: str
    # Language code passed through to the TTS model; defaults to "en".
    language: str = "en"
class TTSService:
    """Load an XTTS-v2 model once and synthesize speech with it.

    The constructor prefers the custom checkpoint in ``MODEL_DIR`` and falls
    back to the stock XTTS-v2 model when the custom config is missing or the
    custom checkpoint fails to load.

    Attributes:
        device: ``"cuda"`` when torch reports a GPU, otherwise ``"cpu"``.
        tts: The loaded ``TTS`` instance used for synthesis.
    """

    # Custom checkpoint directory, resolved against the current working
    # directory (the same relative location the loader is given).
    MODEL_DIR = "XTTS-v2_C3PO"
    # Stock model used whenever the custom checkpoint is unavailable.
    DEFAULT_MODEL_NAME = "tts_models/multilingual/multi-dataset/xtts_v2"

    def __init__(self):
        """Pick a device and load the custom model, else the default one.

        Raises:
            Exception: Whatever the TTS library raised if the default model
                cannot be loaded either.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {self.device}")

        model_path = "XTTS-v2_C3PO/"
        config_path = "XTTS-v2_C3PO/config.json"

        if not os.path.exists(config_path):
            logger.warning(f"Custom model config not found at {config_path}")
            self._log_model_dir_contents()
            self.tts = self._load_default_model()
            return

        try:
            self.tts = TTS(
                model_path=model_path,
                config_path=config_path,
                progress_bar=False,
                gpu=torch.cuda.is_available()
            ).to(self.device)
        except Exception as e:
            logger.error(f"Failed to load custom TTS model: {e}")
            self.tts = self._load_default_model()
        else:
            logger.info("Custom TTS model loaded successfully")

    def _log_model_dir_contents(self):
        """Log what is in the custom model directory, to help debug a missing config."""
        # Inspect the same directory the loader uses, and log it as an absolute
        # path so the message is unambiguous regardless of the working directory.
        model_dir = os.path.abspath(self.MODEL_DIR)
        if os.path.exists(model_dir):
            logger.info(f"Contents of {model_dir}: {os.listdir(model_dir)}")
        else:
            logger.warning(f"Model directory {model_dir} does not exist")

    def _load_default_model(self):
        """Load the stock XTTS-v2 model onto ``self.device`` and return it."""
        logger.info("Falling back to default XTTS model")
        try:
            tts = TTS(self.DEFAULT_MODEL_NAME).to(self.device)
        except Exception as e:
            logger.error(f"Failed to load default TTS model: {e}")
            # Bare raise keeps the original traceback intact.
            raise
        logger.info("Default TTS model loaded successfully")
        return tts

    def generate_speech(self, text: str, speaker_wav_path: str, language: str = "en") -> str:
        """Synthesize ``text`` in the voice of ``speaker_wav_path``.

        Args:
            text: Text to convert to speech.
            speaker_wav_path: Path to the reference speaker audio file.
            language: Language code handed to the model.

        Returns:
            Path to a newly written, uniquely named ``.wav`` file in the
            system temp directory. The caller owns (and should delete) it.

        Raises:
            HTTPException: Status 500 if synthesis fails for any reason.
        """
        # Unique name so concurrent requests never write to the same file.
        output_filename = f"output_{uuid.uuid4().hex}.wav"
        output_path = os.path.join(tempfile.gettempdir(), output_filename)

        try:
            self.tts.tts_to_file(
                text=text,
                file_path=output_path,
                speaker_wav=speaker_wav_path,
                language=language
            )
        except Exception as e:
            logger.error(f"Error generating speech: {e}")
            # Do not leave a partially written file behind in the temp dir.
            if os.path.exists(output_path):
                os.remove(output_path)
            raise HTTPException(status_code=500, detail=f"Failed to generate speech: {str(e)}") from e
        return output_path
# Module-level singleton: constructing TTSService loads the model, so this
# runs once at import time and the endpoints below share the instance.
tts_service = TTSService()
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered on `app`.
async def root():
    """Return a static payload confirming the API is running."""
    payload = {"message": "TTS API is running", "status": "healthy"}
    return payload
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered on `app`.
async def health_check():
    """Return a health payload that includes the device the TTS service selected."""
    device = tts_service.device
    return {"status": "healthy", "device": device}
def _remove_file_quietly(path: str) -> None:
    """Delete ``path``; a file that is already gone is not an error."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass
    except OSError as e:
        logger.warning(f"Could not remove temporary file {path}: {e}")


# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered on `app`.
async def text_to_speech(
    text: str = Form(...),
    language: str = Form("en"),
    speaker_file: UploadFile = File(...)
):
    """
    Convert text to speech using a reference speaker voice

    - **text**: The text to convert to speech
    - **language**: Language code (default: "en")
    - **speaker_file**: Audio file containing the reference speaker voice

    Returns the generated WAV as a file download. Responds 400 for empty
    text, a non-audio or empty upload, and 500 if synthesis fails.
    """
    # Local import: fastapi is already a dependency of this module.
    from fastapi import BackgroundTasks

    if not text.strip():
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    # content_type is None when the client sends no Content-Type for the part.
    content_type = speaker_file.content_type or ""
    if not content_type.startswith('audio/'):
        raise HTTPException(status_code=400, detail="Speaker file must be an audio file")

    speaker_temp_path = None
    try:
        content = await speaker_file.read()
        if not content:
            raise HTTPException(status_code=400, detail="Speaker file is empty")

        # Per-request temp file: a fixed shared path would let concurrent
        # requests overwrite each other's reference audio.
        with tempfile.NamedTemporaryFile(prefix="speaker_", suffix=".wav", delete=False) as buffer:
            speaker_temp_path = buffer.name
            buffer.write(content)

        # NOTE(review): this call is synchronous and blocks the event loop
        # for the duration of synthesis.
        output_path = tts_service.generate_speech(text, speaker_temp_path, language)
    except HTTPException:
        # Already carries the right status/detail; do not re-wrap as a 500.
        raise
    except Exception as e:
        logger.error(f"Error in TTS endpoint: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # The reference audio is only needed while synthesis runs.
        if speaker_temp_path is not None:
            _remove_file_quietly(speaker_temp_path)

    # Delete the generated file once the response has been sent.
    cleanup = BackgroundTasks()
    cleanup.add_task(_remove_file_quietly, output_path)

    # No explicit Content-Disposition header: FileResponse builds
    # `attachment; filename=...` itself, and a caller-supplied header would
    # take precedence and drop the filename.
    return FileResponse(
        output_path,
        media_type="audio/wav",
        filename=f"tts_output_{uuid.uuid4().hex}.wav",
        background=cleanup,
    )
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered on `app`.
async def text_to_speech_with_url(request: TTSRequest, speaker_wav_url: str):
    """
    Convert text to speech using a reference speaker voice from URL

    - **request**: TTSRequest containing text and language
    - **speaker_wav_url**: URL to the reference speaker audio file

    Not implemented yet: responds 400 for empty text, otherwise 501.
    """
    if not request.text.strip():
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    # TODO: validate the URL and download the reference audio before synthesis.
    # Raised outside any try/except so the client really sees 501; wrapping it
    # in `except Exception` would turn it into a generic 500.
    raise HTTPException(status_code=501, detail="URL-based speaker input not implemented yet")