# ash_voice / app.py
# (Hugging Face Space page header residue: "Somalitts's picture",
#  "Update app.py", commit "fc0b1ae verified")
import os
import re
import uuid
import torch
import torchaudio
import soundfile as sf
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.responses import FileResponse
from pydantic import BaseModel
import logging
import tempfile
# --- MOST IMPORTANT CHANGE ---
# Set environment variables to force Hugging Face libraries to cache under /tmp
# (the only writable location on a Space). This MUST run before any
# transformers imports, which is why these imports come after the assignments.
CACHE_DIR = "/tmp/huggingface_cache"
os.environ['HF_HOME'] = CACHE_DIR
os.environ['TRANSFORMERS_CACHE'] = CACHE_DIR
os.environ['HF_DATASETS_CACHE'] = CACHE_DIR
from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
from speechbrain.inference.speaker import EncoderClassifier
# --- Setup and configuration ---
logging.basicConfig(level=logging.INFO)
app = FastAPI(title="Multi-Voice Somali Text-to-Speech API")
# Pick the device (GPU if available, otherwise CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
logging.info(f"Using device: {device}")
# Reference voice sample files (must exist in the repository root)
VOICE_SAMPLE_FILES = ["1.wav"]
# Writable directory where computed speaker embeddings are persisted
EMBEDDING_DIR = "/tmp/speaker_embeddings"
os.makedirs(EMBEDDING_DIR, exist_ok=True)
# Global variables for models — populated once by the startup event
processor = None
model = None
vocoder = None
speaker_model = None
# In-process cache: wav path -> speaker-embedding tensor
speaker_embeddings_cache = {}
@app.on_event("startup")
async def startup_event():
    """
    Runs exactly once when the application starts.

    Loads the SpeechT5 processor/model/vocoder and the SpeechBrain x-vector
    speaker encoder into the module-level globals, then pre-computes the
    speaker embeddings for every reference voice file.

    Raises:
        RuntimeError: if any model fails to load.
        FileNotFoundError: if a reference voice file is missing.
    """
    global processor, model, vocoder, speaker_model
    logging.info(f"Models will be cached in: {os.environ.get('HF_HOME')}")
    try:
        # Passing 'cache_dir' explicitly is no longer needed: the HF_HOME
        # environment variable set before the transformers imports handles it.
        processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
        model = SpeechT5ForTextToSpeech.from_pretrained("Somalitts/8aad").to(device)
        vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(device)
        # 'savedir' is still required by speechbrain (it ignores HF_HOME).
        speaker_model = EncoderClassifier.from_hparams(
            source="speechbrain/spkrec-xvect-voxceleb",
            run_opts={"device": device},
            savedir=os.path.join(CACHE_DIR, "spkrec-xvect-voxceleb")
        )
        logging.info("Models loaded successfully.")
    except Exception as e:
        logging.error(f"Error loading models: {e}")
        # Log the full traceback so startup failures are diagnosable from logs.
        import traceback
        logging.error(traceback.format_exc())
        raise RuntimeError(f"Could not load models: {e}")
    logging.info("Pre-caching speaker embeddings...")
    for voice_file in VOICE_SAMPLE_FILES:
        if not os.path.exists(voice_file):
            raise FileNotFoundError(f"Reference audio file not found: {voice_file}. Make sure it's in your repository.")
        get_speaker_embedding(voice_file)
    logging.info("Embeddings cached. Application is ready.")
def get_speaker_embedding(wav_file_path):
    """Return the x-vector speaker embedding for a reference WAV file.

    Lookups cascade through three levels: the in-process dict
    ``speaker_embeddings_cache``, a serialized tensor on disk under
    ``EMBEDDING_DIR``, and finally a fresh computation with the SpeechBrain
    ``speaker_model``. Fresh results are persisted to both caches.

    Args:
        wav_file_path: Path to a reference audio file (any sample rate;
            resampled to 16 kHz mono before encoding).

    Returns:
        A 1-D L2-normalized speaker-embedding tensor resident on ``device``.

    Raises:
        HTTPException: 500 if the audio file cannot be read or encoded.
    """
    # Fast path: already computed during this process's lifetime.
    if wav_file_path in speaker_embeddings_cache:
        return speaker_embeddings_cache[wav_file_path]
    embedding_path = os.path.join(EMBEDDING_DIR, f"{os.path.basename(wav_file_path)}.pt")
    if os.path.exists(embedding_path):
        # weights_only=True: the file contains only a tensor we wrote ourselves,
        # so refuse arbitrary pickled objects (safe torch.load).
        embedding = torch.load(embedding_path, map_location=device, weights_only=True)
        speaker_embeddings_cache[wav_file_path] = embedding
        logging.info(f"Loaded cached embedding for {wav_file_path}")
        return embedding
    try:
        audio, sr = torchaudio.load(wav_file_path)
        # The x-vector encoder expects 16 kHz mono input.
        if sr != 16000:
            audio = torchaudio.functional.resample(audio, sr, 16000)
        if audio.shape[0] > 1:
            audio = torch.mean(audio, dim=0, keepdim=True)
        with torch.no_grad():
            embedding = speaker_model.encode_batch(audio.to(device))
            # Normalize along the feature dim, then drop the (batch, chunk) dims.
            embedding = torch.nn.functional.normalize(embedding, dim=2).squeeze()
        # Persist a CPU copy so it can be reloaded regardless of device.
        torch.save(embedding.cpu(), embedding_path)
        # Move to the target device once and reuse that tensor for both the
        # cache entry and the return value (the original did this twice).
        embedding = embedding.to(device)
        speaker_embeddings_cache[wav_file_path] = embedding
        logging.info(f"Generated and cached new embedding for {wav_file_path}")
        return embedding
    except Exception as e:
        logging.error(f"Could not process audio file {wav_file_path}. Error: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to process reference audio: {wav_file_path}")
# --- The rest of the code is unchanged ---
class TTSRequest(BaseModel):
    """Request body for the /speak endpoint."""
    text: str  # Text to synthesize; must be non-empty after stripping.
    voice_choice: str = "1.wav"  # Must be one of VOICE_SAMPLE_FILES.
@app.get("/voices")
async def get_available_voices():
    """Return the reference voice files that /speak accepts."""
    return dict(available_voices=VOICE_SAMPLE_FILES)
# ... (the rest of the code is as before)
def normalize_text(text):
    """Normalize input text before it is tokenized for synthesis.

    Currently trims leading/trailing whitespace and collapses internal runs
    of whitespace (spaces, tabs, newlines) into single spaces, so stray
    formatting in the request body does not reach the TTS model.
    Extend here with Somali-specific normalization (numbers, abbreviations)
    as needed.

    Args:
        text: Raw input string from the request.

    Returns:
        The cleaned-up string.
    """
    return re.sub(r"\s+", " ", text).strip()
@app.post("/speak")
async def text_to_speech_endpoint(payload: TTSRequest, background_tasks: BackgroundTasks):
    """
    Synthesize speech for ``payload.text`` with the reference voice
    ``payload.voice_choice`` and return it as a WAV file response.

    Raises:
        HTTPException: 400 for empty text or an unknown voice choice;
            500 (via get_speaker_embedding) if the reference audio fails.
    """
    if not payload.text or not payload.text.strip():
        raise HTTPException(status_code=400, detail="Text cannot be empty.")
    if payload.voice_choice not in VOICE_SAMPLE_FILES:
        raise HTTPException(status_code=400, detail=f"Voice choice '{payload.voice_choice}' not available.")
    # Cached after the first lookup — cheap on subsequent requests.
    speaker_embedding = get_speaker_embedding(payload.voice_choice)
    normalized_text = normalize_text(payload.text)
    inputs = processor(text=normalized_text, return_tensors="pt").to(device)
    with torch.no_grad():
        speech = model.generate_speech(
            inputs["input_ids"],
            speaker_embedding.unsqueeze(0),  # model expects a batched (1, dim) speaker tensor
            vocoder=vocoder
        )
    # Write the waveform to a unique temp file; 16 kHz matches SpeechT5 output.
    fd, tmp_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    sf.write(tmp_path, speech.cpu().numpy(), 16000)
    # Schedule deletion of the temp file after the response has been sent.
    background_tasks.add_task(os.remove, tmp_path)
    return FileResponse(path=tmp_path, media_type="audio/wav", filename="generated_voice.wav")