# app.py — Multi-Voice Somali Text-to-Speech API (FastAPI)
import os
import re
import uuid
import torch
import torchaudio
import soundfile as sf
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.responses import FileResponse
from pydantic import BaseModel
import logging
import tempfile
from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
from speechbrain.pretrained import EncoderClassifier
# --- Configuration ---
logging.basicConfig(level=logging.INFO)
app = FastAPI(title="Multi-Voice Somali Text-to-Speech API")
# Pick the compute device (GPU if available, otherwise CPU).
device = "cuda" if torch.cuda.is_available() else "cpu"
logging.info(f"Using device: {device}")
# Reference voice files (add your .wav files here).
# These files must live in the same directory as this script.
VOICE_SAMPLE_FILES = ["1.wav"]
# Directory where computed speaker embeddings are persisted between runs.
EMBEDDING_DIR = "speaker_embeddings"
os.makedirs(EMBEDDING_DIR, exist_ok=True)
# --- Model handles (globals), populated once by the startup event ---
processor = None
model = None
vocoder = None
speaker_model = None
# In-memory cache: wav file path -> speaker embedding tensor (on `device`).
speaker_embeddings_cache = {}
@app.on_event("startup")
async def startup_event():
    """
    Runs once when the application starts.

    Loads the SpeechT5 processor/model, the HiFi-GAN vocoder and the
    x-vector speaker encoder onto `device`, then pre-computes a speaker
    embedding for every reference voice so the first /speak request is fast.

    Raises:
        RuntimeError: if any model fails to download or load.
        FileNotFoundError: if a file listed in VOICE_SAMPLE_FILES is missing.
    """
    global processor, model, vocoder, speaker_model
    logging.info("Loading models...")
    try:
        # Tokenizer / feature extractor for SpeechT5.
        processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
        # Fine-tuned Somali SpeechT5 acoustic model.
        model = SpeechT5ForTextToSpeech.from_pretrained("Somalitts/8aad").to(device)
        # HiFi-GAN vocoder converts the model's spectrogram into a waveform.
        vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(device)
        # x-vector speaker encoder used to derive per-voice embeddings.
        speaker_model = EncoderClassifier.from_hparams(
            source="speechbrain/spkrec-xvect-voxceleb",
            run_opts={"device": device},
            savedir=os.path.join("pretrained_models", "spkrec-xvect-voxceleb")
        )
        logging.info("Models loaded successfully.")
    except Exception as e:
        logging.error(f"Error loading models: {e}")
        raise RuntimeError(f"Could not load models: {e}")
    logging.info("Pre-caching speaker embeddings...")
    for voice_file in VOICE_SAMPLE_FILES:
        if not os.path.exists(voice_file):
            raise FileNotFoundError(f"Reference audio file not found: {voice_file}. Make sure it's in the same directory.")
        get_speaker_embedding(voice_file)
    logging.info("Embeddings cached. Application is ready to serve requests.")
def get_speaker_embedding(wav_file_path):
    """
    Return the x-vector speaker embedding for a reference WAV file.

    Results are cached in memory (speaker_embeddings_cache) and on disk
    (EMBEDDING_DIR/<basename>.pt) so the expensive encoder only runs once
    per voice file.

    Args:
        wav_file_path: path to the reference .wav file.

    Returns:
        A 1-D torch tensor on `device` (normalized x-vector).

    Raises:
        HTTPException(500): if the audio file cannot be read or encoded.
    """
    if wav_file_path in speaker_embeddings_cache:
        return speaker_embeddings_cache[wav_file_path]
    embedding_path = os.path.join(EMBEDDING_DIR, f"{os.path.basename(wav_file_path)}.pt")
    if os.path.exists(embedding_path):
        # Disk cache hit: load the previously computed embedding.
        embedding = torch.load(embedding_path, map_location=device)
        speaker_embeddings_cache[wav_file_path] = embedding
        logging.info(f"Loaded cached embedding for {wav_file_path}")
        return embedding
    try:
        audio, sr = torchaudio.load(wav_file_path)
        # The speaker encoder expects 16 kHz mono input.
        if sr != 16000:
            audio = torchaudio.functional.resample(audio, sr, 16000)
        if audio.shape[0] > 1:
            audio = torch.mean(audio, dim=0, keepdim=True)
        with torch.no_grad():
            embedding = speaker_model.encode_batch(audio.to(device))
            embedding = torch.nn.functional.normalize(embedding, dim=2).squeeze()
        # Persist a CPU copy so it loads on any device next run.
        torch.save(embedding.cpu(), embedding_path)
        # Move to the target device once; cache and return the SAME tensor
        # (the original called .to(device) twice, once for the cache and
        # once for the return value).
        embedding = embedding.to(device)
        speaker_embeddings_cache[wav_file_path] = embedding
        logging.info(f"Generated and cached new embedding for {wav_file_path}")
        return embedding
    except Exception as e:
        logging.error(f"Could not process audio file {wav_file_path}. Error: {e}")
        # Chain the original exception so the traceback keeps the root cause.
        raise HTTPException(status_code=500, detail=f"Failed to process reference audio: {wav_file_path}") from e
# --- Shaqooyinka Hagaajinta Qoraalka (Text Processing) ---
# (Kuwani sidoodii hore ayay u fiican yihiin)
# Somali number vocabulary: units, teens, tens, and the two scale words
# ("boqol" = hundred, "kun" = thousand) used by number_to_words_recursive.
number_words = {
    0: "eber", 1: "kow", 2: "labo", 3: "saddex", 4: "afar", 5: "shan",
    6: "lix", 7: "toddobo", 8: "siddeed", 9: "sagaal", 10: "toban",
    11: "kow iyo toban", 12: "labo iyo toban", 13: "saddex iyo toban",
    14: "afar iyo toban", 15: "shan iyo toban", 16: "lix iyo toban",
    17: "toddobo iyo toban", 18: "siddeed iyo toban", 19: "sagaal iyo toban",
    20: "labaatan", 30: "soddon", 40: "afartan", 50: "konton",
    60: "lixdan", 70: "toddobaatan", 80: "siddeetan", 90: "sagaashan",
    100: "boqol", 1000: "kun",
}


def number_to_words_recursive(n):
    """
    Spell out a non-negative integer in Somali words.

    Supports 0 .. 999,999; larger values fall back to the digit string.
    """
    if n in number_words:
        return number_words[n]
    if n < 100:
        # Somali reads the unit before the ten ("kow iyo labaatan" = 21),
        # matching the 11-19 entries in number_words above (the original
        # emitted ten-first, e.g. "labaatan iyo kow"). n % 10 is never 0
        # here because every exact ten is already in the dictionary.
        return number_words[n % 10] + " iyo " + number_words[n // 10 * 10]
    if n < 1000:
        hundreds = n // 100
        # A single hundred is the bare word "boqol", not "kow boqol".
        words = "boqol" if hundreds == 1 else number_to_words_recursive(hundreds) + " boqol"
        if n % 100:
            words += " iyo " + number_to_words_recursive(n % 100)
        return words
    if n < 1000000:
        thousands = n // 1000
        # Mirror the hundreds rule: a single thousand is bare "kun"
        # (the original inconsistently produced "kow kun" for 1xxx).
        words = "kun" if thousands == 1 else number_to_words_recursive(thousands) + " kun"
        if n % 1000:
            words += " iyo " + number_to_words_recursive(n % 1000)
        return words
    # Out of supported range: return the digits unchanged.
    return str(n)
def replace_numbers_with_words(text):
    """Replace every standalone run of digits with its Somali spelling."""
    def _spell(match):
        return number_to_words_recursive(int(match.group()))

    return re.sub(r'\b\d+\b', _spell, text)
def normalize_text(text):
    """
    Prepare raw text for the TTS model: lower-case it, spell digits out
    as Somali words, strip punctuation (keeping apostrophes), and collapse
    runs of whitespace into single spaces.
    """
    cleaned = replace_numbers_with_words(text.lower())
    cleaned = re.sub(r'[^\w\s\']', '', cleaned)
    return re.sub(r'\s+', ' ', cleaned).strip()
# --- Qaabka Codsiga API-ga (Pydantic Model) ---
class TTSRequest(BaseModel):
    """Request body for POST /speak."""
    text: str
    # Default voice used when the client does not specify one.
    voice_choice: str = "1.wav"
# --- Endpoints-ka API-ga ---
@app.get("/voices", summary="Soo Hel Codadka La Heli Karo")
async def get_available_voices():
    """
    Return the list of reference voice files that are available.
    """
    return {"available_voices": VOICE_SAMPLE_FILES}
@app.post("/speak", summary="Abuur Cod Qoraal ka timid")
async def text_to_speech_endpoint(payload: TTSRequest, background_tasks: BackgroundTasks):
    """
    Convert text to a .wav audio response.

    - **text**: the text you want converted to speech.
    - **voice_choice**: the reference voice file to use (e.g. "1.wav").
    """
    if not payload.text or not payload.text.strip():
        raise HTTPException(status_code=400, detail="Qoraalku ma bannaanaan karo (Text cannot be empty).")
    if payload.voice_choice not in VOICE_SAMPLE_FILES:
        raise HTTPException(status_code=400, detail=f"Codka la doortay '{payload.voice_choice}' lama helin.")
    try:
        # Normally a cache hit: embeddings are pre-computed at startup.
        speaker_embedding = get_speaker_embedding(payload.voice_choice)
    except FileNotFoundError:
        raise HTTPException(status_code=404, detail=f"Faylka codka ee '{payload.voice_choice}' lama helin.")
    normalized_text = normalize_text(payload.text)
    logging.info(f"Generating speech for: '{normalized_text}' with voice '{payload.voice_choice}'")
    inputs = processor(text=normalized_text, return_tensors="pt").to(device)
    with torch.no_grad():
        speech = model.generate_speech(
            inputs["input_ids"],
            speaker_embedding.unsqueeze(0),  # add the batch dimension the model expects
            vocoder=vocoder
        )
    # Save to a temporary file (delete=False so it survives past this block).
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
        sf.write(tmp_file.name, speech.cpu().numpy(), 16000)
    # Schedule deletion of the temp file after the response has been sent.
    background_tasks.add_task(os.remove, tmp_file.name)
    # Return the audio file with a random download name.
    return FileResponse(
        path=tmp_file.name,
        media_type="audio/wav",
        filename=f"{uuid.uuid4()}.wav"
    )