# Hugging Face Space status when this file was captured: Sleeping
import logging
import os
import tempfile
import uuid
from typing import Optional

import torch
from fastapi import BackgroundTasks, FastAPI, File, Form, HTTPException, UploadFile
from fastapi.responses import FileResponse
from huggingface_hub import snapshot_download
from pydantic import BaseModel
from TTS.api import TTS
# Environment flag read by Coqui TTS.
# NOTE(review): presumably pre-accepts the Coqui terms-of-service prompt so
# that model downloads do not block waiting for input - confirm.
os.environ["COQUI_TOS_AGREED"] = "1"

# Module-wide logging: INFO and above, through one named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Metadata shown in the generated OpenAPI documentation.
_app_metadata = {
    "title": "Coqui TTS C-3PO API",
    "description": "Text-to-Speech API using Coqui TTS with C-3PO fine-tuned voice model",
    "version": "1.0.0",
}
app = FastAPI(**_app_metadata)
class TTSRequest(BaseModel):
    """Body of a JSON text-to-speech request."""

    # Text to convert to speech.
    text: str
    # Language code handed to the synthesizer; "en" when the field is omitted.
    language: str = "en"
class CoquiTTSService:
    """Load one Coqui TTS model and synthesize speech to WAV files with it.

    Construction tries the C-3PO fine-tuned XTTS v2 checkpoint from Hugging
    Face first and falls back to the stock XTTS v2 model when that fails.
    After construction the instance exposes:

    * ``device`` - ``"cuda"`` or ``"cpu"``.
    * ``tts`` - the loaded ``TTS`` object.
    * ``model_path`` - directory of the downloaded checkpoint, or ``None``
      when the download/load failed and the stock model is in use.
    * ``default_speaker`` - first preset speaker name reported by the model,
      or ``None`` when the model reports none.
    """

    # Hugging Face repository holding the C-3PO fine-tuned checkpoint.
    C3PO_REPO_ID = "Borcherding/XTTS-v2_C3PO"
    # Directory the repository snapshot is downloaded into.
    C3PO_LOCAL_DIR = "./models/XTTS-v2_C3PO"
    # Stock model used whenever the fine-tuned checkpoint cannot be used.
    FALLBACK_MODEL_NAME = "tts_models/multilingual/multi-dataset/xtts_v2"
    # File names probed, in this order, for a C-3PO reference clip.
    REFERENCE_AUDIO_NAMES = (
        "reference.wav",
        "speaker.wav",
        "c3po.wav",
        "sample.wav",
        "reference_audio.wav",
    )
    # Inclusive bounds on the number of characters accepted per request.
    MIN_TEXT_LENGTH = 2
    MAX_TEXT_LENGTH = 500

    def __init__(self):
        """Pick a device and load a model, falling back to stock XTTS v2.

        Raises:
            Exception: whatever the fallback load raised, when both the
                fine-tuned and the stock model fail to load.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {self.device}")

        try:
            self.tts, self.model_path = self._load_c3po_model()
        except Exception as e:
            logger.error(f"Failed to load C-3PO model: {e}")
            logger.info("Falling back to standard XTTS v2 model...")
            self.model_path = None
            try:
                self.tts = self._load_fallback_model()
            except Exception as fallback_error:
                logger.error(f"Fallback model also failed: {fallback_error}")
                raise

        # Probe for preset speakers on every load path, so the stock model's
        # speakers are usable no matter how we ended up loading it.
        self.default_speaker = self._pick_default_speaker()

    def _load_c3po_model(self):
        """Download the fine-tuned checkpoint and return ``(tts, model_path)``."""
        logger.info("Downloading C-3PO fine-tuned XTTS model from Hugging Face...")
        model_path = snapshot_download(
            repo_id=self.C3PO_REPO_ID,
            local_dir=self.C3PO_LOCAL_DIR,
            local_dir_use_symlinks=False,
        )
        logger.info(f"Model downloaded to: {model_path}")

        config_path = os.path.join(model_path, "config.json")
        if not os.path.exists(config_path):
            # Without a config the checkpoint cannot be loaded from disk; the
            # download directory is still kept so reference audio can be found.
            logger.info("Config not found, trying to load by repo ID...")
            return self._load_fallback_model(), model_path

        logger.info("Loading C-3PO fine-tuned model...")
        tts = TTS(
            model_path=model_path,
            config_path=config_path,
            progress_bar=False,
            gpu=torch.cuda.is_available(),
        ).to(self.device)
        logger.info("C-3PO fine-tuned model loaded successfully!")
        return tts, model_path

    def _load_fallback_model(self):
        """Load the stock XTTS v2 model onto ``self.device`` and return it."""
        tts = TTS(self.FALLBACK_MODEL_NAME).to(self.device)
        logger.info("Fallback XTTS v2 model loaded!")
        return tts

    def _pick_default_speaker(self):
        """Return the first preset speaker of the loaded model, or ``None``."""
        try:
            speakers = getattr(self.tts, "speakers", None)
            if speakers:
                names = list(speakers)
                logger.info(f"Available speakers: {len(names)}")
                return names[0]
        except Exception as e:
            # Speaker metadata is optional; never fail start-up because of it.
            logger.warning(f"Could not inspect preset speakers: {e}")
        logger.info("No preset speakers available - voice cloning mode")
        return None

    def get_c3po_reference_audio(self):
        """Return the path of a C-3PO reference clip, or ``None``.

        Looks in ``self.model_path`` for the first regular file whose name is
        in ``REFERENCE_AUDIO_NAMES``. Returns ``None`` when no checkpoint
        directory is known or none of the candidates is a file.
        """
        if not self.model_path:
            return None
        for ref_file in self.REFERENCE_AUDIO_NAMES:
            ref_path = os.path.join(self.model_path, ref_file)
            # isfile, not exists: a directory with a matching name is not audio.
            if os.path.isfile(ref_path):
                logger.info(f"Found C-3PO reference audio: {ref_path}")
                return ref_path
        return None

    def generate_speech(self, text: str, speaker_wav_path: Optional[str] = None,
                        language: str = "en", use_c3po_voice: bool = True) -> str:
        """Synthesize ``text`` into a new WAV file and return its path.

        Speaker selection, in priority order: the caller's
        ``speaker_wav_path``; the C-3PO reference clip (only when
        ``use_c3po_voice`` is true and a clip exists); the model's default
        preset speaker; no speaker at all.

        Args:
            text: Text to speak; must be 2-500 characters long.
            speaker_wav_path: Optional reference audio file for voice cloning.
            language: Language code passed to the model.
            use_c3po_voice: Whether to fall back to the C-3PO reference clip
                when no ``speaker_wav_path`` is given.

        Returns:
            Path of the generated file inside the system temp directory. The
            caller owns the file and is responsible for deleting it.

        Raises:
            HTTPException: 400 when the text length is out of bounds, 500
                when synthesis fails or produces no file.
        """
        if len(text) < self.MIN_TEXT_LENGTH:
            raise HTTPException(status_code=400, detail="Text too short")
        if len(text) > self.MAX_TEXT_LENGTH:
            raise HTTPException(status_code=400, detail="Text too long (max 500 characters)")

        output_filename = f"c3po_tts_output_{uuid.uuid4().hex}.wav"
        output_path = os.path.join(tempfile.gettempdir(), output_filename)

        final_speaker_wav = speaker_wav_path
        if not final_speaker_wav and use_c3po_voice:
            c3po_ref = self.get_c3po_reference_audio()
            if c3po_ref:
                final_speaker_wav = c3po_ref
                logger.info("Using C-3PO reference audio for voice synthesis")

        try:
            if final_speaker_wav:
                # Voice cloning. Written through tts_to_file like the other
                # branches so the library picks the sample rate, instead of
                # saving the raw samples at a hard-coded 22050 Hz.
                logger.info("Generating speech with voice cloning...")
                self.tts.tts_to_file(
                    text=text,
                    speaker_wav=final_speaker_wav,
                    language=language,
                    file_path=output_path,
                )
            elif self.default_speaker:
                logger.info(f"Generating speech with preset speaker: {self.default_speaker}")
                self.tts.tts_to_file(
                    text=text,
                    speaker=self.default_speaker,
                    language=language,
                    file_path=output_path,
                )
            else:
                # Some models can synthesize without any speaker information.
                logger.info("Generating speech without specific speaker...")
                self.tts.tts_to_file(
                    text=text,
                    language=language,
                    file_path=output_path,
                )
        except Exception as e:
            logger.error(f"Error generating speech: {e}")
            # Do not leave a partially written file behind.
            if os.path.exists(output_path):
                try:
                    os.remove(output_path)
                except OSError as cleanup_error:
                    logger.warning(f"Could not remove partial output {output_path}: {cleanup_error}")
            if isinstance(e, HTTPException):
                raise
            raise HTTPException(status_code=500, detail=f"Speech generation failed: {str(e)}") from e

        if not os.path.exists(output_path):
            logger.error("Error generating speech: no output file was produced")
            raise HTTPException(status_code=500, detail="Failed to generate audio file")

        logger.info(f"Speech generated successfully: {output_path}")
        return output_path
# Build the shared TTS service once at import time. A failure is deliberately
# not fatal: tts_service is left as None and the endpoints that need it answer
# 503. The full traceback is logged, because this is the only place the
# reason for a permanently unavailable service is recorded.
logger.info("Initializing Coqui TTS service...")
try:
    tts_service = CoquiTTSService()
except Exception as e:
    logger.exception(f"Failed to initialize TTS service: {e}")
    tts_service = None
else:
    logger.info("TTS service initialized successfully")
# NOTE(review): this handler had no route decorator, so it was never
# registered on `app` and could not be reached. "/" is chosen to match the
# "Root endpoint" docstring - confirm against any deployed clients.
@app.get("/")
async def root():
    """Root endpoint with API information"""
    return {
        "message": "Coqui TTS C-3PO API",
        # "error" means the module-level service failed to initialize.
        "status": "healthy" if tts_service else "error",
        "model": "XTTS v2",
        "voice_cloning": True,
    }
# NOTE(review): this handler had no route decorator, so it was never
# registered on `app`. The "/health" path is a reviewer assumption - confirm.
@app.get("/health")
async def health_check():
    """Health check endpoint

    Returns the device, the default speaker, the model directory and whether
    a C-3PO reference clip was found. Responds 503 when the TTS service
    failed to initialize.
    """
    if not tts_service:
        raise HTTPException(status_code=503, detail="TTS service not available")

    c3po_ref_available = tts_service.get_c3po_reference_audio() is not None
    return {
        "status": "healthy",
        "device": tts_service.device,
        "model": "C-3PO Fine-tuned XTTS v2 (Coqui TTS)",
        "default_speaker": tts_service.default_speaker,
        "voice_cloning_available": True,
        "c3po_voice_available": c3po_ref_available,
        "model_path": getattr(tts_service, "model_path", None),
    }
# NOTE(review): this handler had no route decorator, so it was never
# registered on `app`. POST is required by the form/file parameters; the
# "/tts" path is a reviewer assumption - confirm.
@app.post("/tts")
async def text_to_speech(
    text: str = Form(...),
    language: str = Form("en"),
    speaker_file: UploadFile = File(None),
    use_c3po_voice: bool = Form(True),
):
    """
    Convert text to speech using Coqui TTS

    - **text**: Text to convert to speech (2-500 characters)
    - **language**: Language code (default: "en")
    - **speaker_file**: Reference audio file for voice cloning (optional)
    - **use_c3po_voice**: Use C-3PO voice if no speaker file provided (default: True)
    """
    if not tts_service:
        raise HTTPException(status_code=503, detail="TTS service not available")
    if not text.strip():
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    speaker_temp_path = None
    try:
        if speaker_file is not None:
            content_type = speaker_file.content_type or ""
            if not content_type.startswith("audio/"):
                raise HTTPException(status_code=400, detail="Speaker file must be an audio file")

            # Persist the upload: the synthesizer takes a file path.
            speaker_temp_path = os.path.join(
                tempfile.gettempdir(),
                f"speaker_{uuid.uuid4().hex}.wav",
            )
            content = await speaker_file.read()
            with open(speaker_temp_path, "wb") as buffer:
                buffer.write(content)
            logger.info(f"Speaker file saved: {speaker_temp_path}")

        output_path = tts_service.generate_speech(text, speaker_temp_path, language, use_c3po_voice)
    except Exception as e:
        logger.error(f"Error in TTS endpoint: {e}")
        if isinstance(e, HTTPException):
            raise
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # The reference clip is only needed during synthesis; remove it on
        # success and on failure alike.
        if speaker_temp_path is not None and os.path.exists(speaker_temp_path):
            try:
                os.remove(speaker_temp_path)
            except OSError as cleanup_error:
                logger.warning(f"Could not remove {speaker_temp_path}: {cleanup_error}")

    voice_type = "custom" if speaker_file else ("c3po" if use_c3po_voice else "default")

    # Delete the generated file once the response has been sent; otherwise
    # every request leaves a WAV behind in the temp directory.
    cleanup = BackgroundTasks()
    cleanup.add_task(os.remove, output_path)

    # No explicit Content-Disposition header: FileResponse builds
    # 'attachment; filename="..."' itself, and a caller-supplied header would
    # take precedence and drop the filename.
    return FileResponse(
        output_path,
        media_type="audio/wav",
        filename=f"c3po_tts_{voice_type}_{uuid.uuid4().hex}.wav",
        background=cleanup,
    )
# NOTE(review): this handler had no route decorator, so it was never
# registered on `app`. POST is required by the form parameters; the
# "/tts-c3po" path is a reviewer assumption - confirm.
@app.post("/tts-c3po")
async def text_to_speech_c3po(
    text: str = Form(...),
    language: str = Form("en"),
):
    """
    Convert text to speech using C-3PO voice specifically

    - **text**: Text to convert to speech (2-500 characters)
    - **language**: Language code (default: "en")
    """
    if not tts_service:
        raise HTTPException(status_code=503, detail="TTS service not available")
    if not text.strip():
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    # This endpoint promises the C-3PO voice, so refuse rather than silently
    # falling back to another speaker when no reference clip exists.
    c3po_ref = tts_service.get_c3po_reference_audio()
    if not c3po_ref:
        raise HTTPException(status_code=503, detail="C-3PO reference audio not available")

    try:
        output_path = tts_service.generate_speech(text, None, language, use_c3po_voice=True)
    except Exception as e:
        logger.error(f"Error in C-3PO TTS endpoint: {e}")
        if isinstance(e, HTTPException):
            raise
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Delete the generated file once the response has been sent; otherwise
    # every request leaves a WAV behind in the temp directory.
    cleanup = BackgroundTasks()
    cleanup.add_task(os.remove, output_path)

    # No explicit Content-Disposition header: FileResponse builds
    # 'attachment; filename="..."' itself, and a caller-supplied header would
    # take precedence and drop the filename.
    return FileResponse(
        output_path,
        media_type="audio/wav",
        filename=f"c3po_voice_{uuid.uuid4().hex}.wav",
        background=cleanup,
    )
# NOTE(review): this handler had no route decorator, so it was never
# registered on `app`. POST is required by the JSON body; the "/tts-json"
# path is a reviewer assumption - confirm.
@app.post("/tts-json")
async def text_to_speech_json(request: TTSRequest):
    """
    Convert text to speech using JSON request with C-3PO voice

    - **request**: TTSRequest containing text and language
    """
    if not tts_service:
        raise HTTPException(status_code=503, detail="TTS service not available")
    if not request.text.strip():
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    try:
        # C-3PO voice by default; generate_speech falls back to other
        # speakers when no reference clip is available.
        output_path = tts_service.generate_speech(
            request.text, None, request.language, use_c3po_voice=True
        )
    except Exception as e:
        logger.error(f"Error in TTS JSON endpoint: {e}")
        if isinstance(e, HTTPException):
            raise
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Delete the generated file once the response has been sent; otherwise
    # every request leaves a WAV behind in the temp directory.
    cleanup = BackgroundTasks()
    cleanup.add_task(os.remove, output_path)

    # No explicit Content-Disposition header: FileResponse builds
    # 'attachment; filename="..."' itself, and a caller-supplied header would
    # take precedence and drop the filename.
    return FileResponse(
        output_path,
        media_type="audio/wav",
        filename=f"c3po_tts_{request.language}_{uuid.uuid4().hex}.wav",
        background=cleanup,
    )
# NOTE(review): this handler had no route decorator, so it was never
# registered on `app`. The "/models" path is a reviewer assumption - confirm.
@app.get("/models")
async def list_models():
    """List available TTS models

    Returns at most the first 20 model names. Responds 500 when the listing
    cannot be obtained.
    """
    try:
        # A model-less TTS instance is enough to query the catalogue.
        listing = TTS().list_models()
        # NOTE(review): some Coqui TTS releases appear to return a manager
        # object exposing its own list_models() instead of a plain list -
        # confirm against the pinned version. Both shapes are accepted here.
        if not isinstance(listing, (list, tuple)) and hasattr(listing, "list_models"):
            listing = listing.list_models()
        return {"models": list(listing)[:20]}
    except Exception as e:
        logger.error(f"Error listing models: {e}")
        raise HTTPException(status_code=500, detail="Failed to list models") from e
if __name__ == "__main__":
    # Imported here so uvicorn is only required when run as a script.
    import uvicorn

    # Listen on every interface, port 7860.
    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)