Spaces:
Build error
Build error
import asyncio
import base64
import logging
import os
import uuid
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timedelta

import torch
import torchaudio
import uvicorn
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from pydantic import BaseModel
from transformers import AutoModelForCausalLM
from yarngpt.audiotokenizer import AudioTokenizerV2
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="Nigerian TTS API", version="1.0.0")

# Allow cross-origin requests from any site.
# Credentials are deliberately NOT allowed. The CORS spec forbids combining a
# wildcard origin with credentials, and Starlette's CORSMiddleware handles that
# combination by echoing back whatever Origin the caller sent, which would let
# any website make credentialed requests. Nothing in this module reads cookies
# or auth headers, so there is nothing for credentials to carry.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Runtime model state. load_models() fills these in; root(), health_check()
# and text_to_speech() read them.
audio_tokenizer = model = None
model_loaded, loading_error = False, None

# Hugging Face Hub repo id, used both for the text tokenizer and for the
# causal LM weights.
tokenizer_path = "saheedniyi/YarnGPT2"

# WavTokenizer files. download_model_files() fetches them into /tmp when
# they are not already present.
wav_tokenizer_config_path = "/tmp/wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml"
wav_tokenizer_model_path = "/tmp/wavtokenizer_large_speech_320_24k.ckpt"

# Speaker names accepted by the TTS endpoint, grouped by gender.
AVAILABLE_VOICES = dict(
    female=["zainab", "idera", "regina", "chinenye", "joke", "remi"],
    male=["jude", "tayo", "umar", "osagie", "onye", "emma"],
)

# Language codes accepted by the TTS endpoint.
AVAILABLE_LANGUAGES = [
    "english",
    "yoruba",
    "igbo",
    "hausa",
]
# Request body for /tts-style calls. The endpoint (not this model) checks
# `language` against AVAILABLE_LANGUAGES and `voice` against AVAILABLE_VOICES.
# Comments are used instead of a class docstring on purpose: pydantic would copy
# a docstring into the generated JSON schema.
class TTSRequest(BaseModel):
    text: str  # text to synthesize
    language: str = "english"  # expected: one of AVAILABLE_LANGUAGES
    voice: str = "idera"  # expected: a name from AVAILABLE_VOICES


# Response body. It carries the generated WAV inline plus a relative URL to
# fetch the same file again, and echoes back the request's text, voice and
# language.
class TTSResponse(BaseModel):
    audio_base64: str  # base64-encoded bytes of the saved .wav file
    audio_url: str  # relative URL of the form "/audio/<uuid>.wav"
    text: str
    voice: str
    language: str
async def download_model_files():
    """Make sure the WavTokenizer config and checkpoint exist locally.

    A file is downloaded only if it is missing at ``wav_tokenizer_config_path``
    or ``wav_tokenizer_model_path``. The blocking HTTP transfers run in the
    event loop's default thread-pool executor, so the loop is not stalled
    during the large checkpoint download.

    Raises:
        Exception: any download or file error. The error is logged first and
            recorded in the module-level ``loading_error``.
    """
    global loading_error
    try:
        logger.info("Starting model file downloads...")
        # URLs for the model files
        config_url = "https://huggingface.co/saheedniyi/YarnGPT2/resolve/main/wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml"
        model_url = "https://huggingface.co/saheedniyi/YarnGPT2/resolve/main/wavtokenizer_large_speech_320_24k.ckpt"

        # Create the target directories if they don't exist.
        for target in (wav_tokenizer_config_path, wav_tokenizer_model_path):
            os.makedirs(os.path.dirname(target) or ".", exist_ok=True)

        loop = asyncio.get_running_loop()

        if not os.path.exists(wav_tokenizer_config_path):
            logger.info("Downloading tokenizer config...")
            await loop.run_in_executor(None, _download_file, config_url, wav_tokenizer_config_path)
            logger.info("Config file downloaded successfully")

        if not os.path.exists(wav_tokenizer_model_path):
            logger.info("Downloading tokenizer model (this may take a while)...")
            await loop.run_in_executor(None, _download_file, model_url, wav_tokenizer_model_path)
            logger.info("Model file downloaded successfully")

        logger.info("All model files are ready")
    except Exception as e:
        error_msg = f"Error downloading model files: {str(e)}"
        logger.error(error_msg)
        loading_error = error_msg
        raise


def _download_file(url, dest_path, chunk_size=8192, timeout=60.0):
    """Stream *url* into *dest_path* atomically (blocking). Logs progress about every 10 MB.

    The data is written to ``dest_path + ".part"`` and renamed into place only
    after the transfer completes. An interrupted download therefore never leaves
    a truncated file at *dest_path*; such a file would pass the existence check
    in download_model_files() and never be re-fetched.
    """
    import requests

    tmp_path = dest_path + ".part"
    log_step = 10 * 1024 * 1024  # log roughly every 10 MB
    try:
        # `with` closes the streamed connection even if writing fails.
        # The timeout applies to connecting and to each socket read.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            total_size = int(response.headers.get("content-length", 0))
            downloaded = 0
            next_log_at = log_step
            with open(tmp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    f.write(chunk)
                    downloaded += len(chunk)
                    # Threshold-based rather than `downloaded % step == 0`: chunk
                    # sizes can vary, so an exact multiple may never be hit.
                    if total_size > 0 and downloaded >= next_log_at:
                        progress = (downloaded / total_size) * 100
                        logger.info(f"Download progress: {progress:.1f}%")
                        next_log_at = (downloaded // log_step + 1) * log_step
        os.replace(tmp_path, dest_path)
    except BaseException:
        # Remove the partial file, then re-raise the original error.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
async def load_models():
    """Download the WavTokenizer files, then load the YarnGPT audio tokenizer and model.

    On success, sets the module-level ``audio_tokenizer`` and ``model`` and sets
    ``model_loaded`` to True. Ordinary exceptions are not propagated. Any failure
    is logged (with traceback), stored in ``loading_error``, and leaves
    ``model_loaded`` False, so root() and health_check() can report it.
    """
    global audio_tokenizer, model, model_loaded, loading_error
    try:
        logger.info("Loading YarnGPT model and tokenizer...")
        # First download the required files
        await download_model_files()
        # Building the tokenizer and loading the weights are blocking calls.
        # Running them in the default executor keeps the event loop responsive.
        # Both globals are assigned together, only after both loads succeed.
        loop = asyncio.get_running_loop()
        audio_tokenizer, model = await loop.run_in_executor(None, _load_model_components)
        model_loaded = True
        logger.info("Model loaded successfully!")
    except Exception as e:
        error_msg = f"Error loading models: {str(e)}"
        logger.exception(error_msg)
        loading_error = error_msg
        model_loaded = False


def _load_model_components():
    """Build the audio tokenizer and the causal LM on the tokenizer's device (blocking).

    Returns:
        ``(audio_tokenizer, model)`` tuple.
    """
    logger.info("Initializing audio tokenizer...")
    tokenizer = AudioTokenizerV2(
        tokenizer_path,
        wav_tokenizer_model_path,
        wav_tokenizer_config_path,
    )
    logger.info("Loading main model...")
    lm = AutoModelForCausalLM.from_pretrained(
        tokenizer_path,
        torch_dtype="auto",
    ).to(tokenizer.device)
    return tokenizer, lm
async def startup_event():
    """Start loading models as a background task when the API starts.

    Returns immediately after scheduling load_models(); progress and failures
    are visible through ``model_loaded`` / ``loading_error``.
    """
    # Keep a strong reference to the task. The event loop holds only weak
    # references, so an unreferenced task can be garbage-collected mid-run.
    app.state.model_loading_task = asyncio.create_task(load_models())
async def root():
    """Basic liveness/info endpoint.

    Reports whether the models have finished loading, any loading error, and
    the supported languages and voices.
    """
    if model_loaded:
        status, message = "ok", "Nigerian TTS API is running"
    else:
        status, message = "loading", "Models are loading..."

    return {
        "status": status,
        "message": message,
        "model_loaded": model_loaded,
        "loading_error": loading_error,
        "available_languages": AVAILABLE_LANGUAGES,
        "available_voices": AVAILABLE_VOICES,
    }
async def health_check():
    """Detailed health snapshot: readiness, any loading error, and the server time.

    The timestamp comes from ``datetime.now()``: local time with no UTC offset.
    """
    ready = model_loaded
    return dict(
        status="healthy" if ready else "loading",
        model_loaded=ready,
        loading_error=loading_error,
        timestamp=datetime.now().isoformat(),
    )
async def text_to_speech(request: TTSRequest, background_tasks: BackgroundTasks):
    """Convert text to Nigerian-accented speech.

    Validates the request and synthesizes audio with the loaded YarnGPT model.
    The audio is saved as ``audio_files/<uuid>.wav`` and returned both inline
    (base64) and as a relative URL. Schedules cleanup_old_files() as a
    background task.

    Raises:
        HTTPException: 503 if the models are still loading or failed to load;
            400 for blank text or an unknown language/voice;
            500 if generation or file I/O fails.
    """
    # Check if models are loaded
    if not model_loaded:
        if loading_error:
            raise HTTPException(status_code=503, detail=f"Model loading failed: {loading_error}")
        raise HTTPException(status_code=503, detail="Models are still loading. Please try again in a moment.")

    # Validate inputs. Blank text is rejected up front rather than spending a
    # full generation pass on it.
    if not request.text.strip():
        raise HTTPException(status_code=400, detail="Text must not be empty")
    if request.language not in AVAILABLE_LANGUAGES:
        raise HTTPException(status_code=400, detail=f"Language must be one of {AVAILABLE_LANGUAGES}")
    all_voices = AVAILABLE_VOICES["female"] + AVAILABLE_VOICES["male"]
    if request.voice not in all_voices:
        raise HTTPException(status_code=400, detail=f"Voice must be one of {all_voices}")

    # Generate unique filename
    audio_id = str(uuid.uuid4())
    output_path = f"audio_files/{audio_id}.wav"
    os.makedirs("audio_files", exist_ok=True)

    try:
        logger.info(f"Generating TTS for text: '{request.text[:50]}...' with voice: {request.voice}")
        # Generation is blocking model inference. It runs on a dedicated
        # single-worker executor: requests are still processed one at a time,
        # but the event loop stays free to answer other endpoints meanwhile.
        loop = asyncio.get_running_loop()
        audio_bytes = await loop.run_in_executor(
            _INFERENCE_EXECUTOR,
            _synthesize_to_file,
            request.text,
            request.language,
            request.voice,
            output_path,
        )
    except Exception as e:
        logger.exception(f"Error generating audio: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error generating audio: {str(e)}") from e

    audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
    # Clean up old files after a while
    background_tasks.add_task(cleanup_old_files)
    return TTSResponse(
        audio_base64=audio_base64,
        audio_url=f"/audio/{audio_id}.wav",
        text=request.text,
        voice=request.voice,
        language=request.language,
    )


# One worker, so model.generate() calls never run concurrently on the shared model.
_INFERENCE_EXECUTOR = ThreadPoolExecutor(max_workers=1, thread_name_prefix="tts-inference")


def _synthesize_to_file(text, language, voice, output_path):
    """Generate speech for *text*, save it as a 24 kHz WAV at *output_path*, and return the file's bytes (blocking)."""
    prompt = audio_tokenizer.create_prompt(
        text,
        lang=language,
        speaker_name=voice,
    )
    input_ids = audio_tokenizer.tokenize_prompt(prompt)
    output = model.generate(
        input_ids=input_ids,
        temperature=0.1,
        repetition_penalty=1.1,
        max_length=4000,
    )
    codes = audio_tokenizer.get_codes(output)
    audio = audio_tokenizer.get_audio(codes)
    # Save audio file
    torchaudio.save(output_path, audio, sample_rate=24000)
    logger.info(f"Audio saved to {output_path}")
    with open(output_path, "rb") as audio_file:
        return audio_file.read()
async def get_audio(filename: str):
    """Serve a previously generated WAV file from ``audio_files/``.

    *filename* comes from the client, so only a bare ``*.wav`` file name is
    accepted. Anything with a directory component (e.g. ``../x.wav`` or an
    absolute path) gets the same 404 as a missing file. This prevents the
    endpoint from reading files outside ``audio_files/``.

    Raises:
        HTTPException: 404 if the name is not a plain ``.wav`` file name, or no
            such regular file exists.
    """
    if os.path.basename(filename) != filename or not filename.endswith(".wav"):
        raise HTTPException(status_code=404, detail="Audio file not found")
    file_path = os.path.join("audio_files", filename)
    # isfile (not exists) so that a directory entry is never handed to FileResponse.
    if not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="Audio file not found")
    return FileResponse(file_path, media_type="audio/wav")
def cleanup_old_files(audio_dir="audio_files", max_age=timedelta(hours=6)):
    """Delete ``*.wav`` files in *audio_dir* last modified more than *max_age* ago.

    The defaults (``"audio_files"``, 6 hours) are what the no-argument
    background-task call from text_to_speech() uses. This is best-effort: errors
    are logged, never raised. Each file is handled independently, so one failure
    does not abort the rest of the sweep. One example is a file already deleted
    by an overlapping cleanup run.

    Args:
        audio_dir: Directory to sweep. If it does not exist, nothing happens.
        max_age: Files whose modification time is older than this are removed.
    """
    if not os.path.exists(audio_dir):
        return
    try:
        filenames = os.listdir(audio_dir)
    except OSError as e:
        logger.error(f"Error cleaning up old files: {e}")
        return

    cutoff = datetime.now() - max_age
    for filename in filenames:
        if not filename.endswith(".wav"):
            continue
        file_path = os.path.join(audio_dir, filename)
        try:
            modified = datetime.fromtimestamp(os.path.getmtime(file_path))
            if modified < cutoff:
                os.remove(file_path)
                logger.info(f"Deleted old audio file: {filename}")
        except FileNotFoundError:
            # Vanished between listdir() and here, e.g. removed by another cleanup task.
            continue
        except Exception as e:
            # Background-task boundary: log and move on to the next file.
            logger.error(f"Error cleaning up old files: {e}")
def _serve(host="0.0.0.0", port=7860):
    """Run the API under uvicorn on *host*:*port*; blocks until the server stops."""
    logger.info("Starting Nigerian TTS API server...")
    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    _serve()