TTS_models / app.py
h-rand's picture
Update app.py
eb7d099 verified
from fastapi import FastAPI, Response, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pocket_tts import TTSModel
import scipy.io.wavfile
import io
import torch
app = FastAPI()
# Configuration CORS pour que votre site puisse appeler l'API
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"],
)
# --- CHARGEMENT ---
print("⏳ Chargement de Pocket TTS...")
try:
# Chargement du modèle sans arguments (utilise le défaut de la lib)
tts_model = TTSModel.load_model()
print("✅ Modèle chargé !")
except Exception as e:
print(f"❌ ERREUR CRITIQUE CHARGEMENT MODÈLE: {e}")
# On ne quitte pas pour laisser le serveur démarrer et voir les logs,
# mais ça ne marchera pas sans modèle.
# Dictionnaire des voix officielles
# Cela permet d'envoyer juste "voice": "marius" depuis le Javascript
VOICES = {
"alba": "hf://kyutai/tts-voices/alba-mackenna/casual.wav",
"marius": "hf://kyutai/tts-voices/marius-kuntze/casual.wav",
"javert": "hf://kyutai/tts-voices/javert-mccall/casual.wav",
"jean": "hf://kyutai/tts-voices/jean-valjean/casual.wav",
"fantine": "hf://kyutai/tts-voices/fantine-becker/casual.wav",
"cosette": "hf://kyutai/tts-voices/cosette-moore/casual.wav",
"eponine": "hf://kyutai/tts-voices/eponine-frazier/casual.wav",
"azelma": "hf://kyutai/tts-voices/azelma-frazier/casual.wav"
}
# Cache pour garder les voix en mémoire (comme recommandé dans la doc)
loaded_voice_states = {}
def get_voice_state(voice_name="alba"):
# Si on a déjà chargé cette voix, on la renvoie direct (rapide)
if voice_name in loaded_voice_states:
return loaded_voice_states[voice_name]
# Sinon on la charge (lent la première fois)
print(f"📥 Chargement de la voix : {voice_name}")
full_path = VOICES.get(voice_name, VOICES["alba"])
try:
state = tts_model.get_state_for_audio_prompt(full_path)
loaded_voice_states[voice_name] = state
return state
except Exception as e:
print(f"Erreur chargement voix {voice_name}: {e}")
# Fallback sur Alba si erreur
if voice_name != "alba":
return get_voice_state("alba")
raise e
# Pré-chargement de la voix par défaut au démarrage
try:
get_voice_state("alba")
print("✅ Voix par défaut (Alba) prête !")
except:
pass
@app.post("/tts")
async def generate_speech(data: dict):
text = data.get("text", "")
voice_name = data.get("voice", "alba") # Par défaut 'alba'
if not text:
raise HTTPException(status_code=400, detail="Texte vide")
print(f"🗣️ Génération ({voice_name}): {text[:30]}...")
try:
# Récupération de l'état de la voix
voice_state = get_voice_state(voice_name)
# Génération
audio_tensor = tts_model.generate_audio(voice_state, text)
# Conversion WAV
buffer = io.BytesIO()
scipy.io.wavfile.write(buffer, tts_model.sample_rate, audio_tensor.numpy())
buffer.seek(0)
return Response(content=buffer.read(), media_type="audio/wav")
except Exception as e:
print(f"❌ Erreur génération : {e}")
return Response(content=str(e), status_code=500)
@app.get("/")
def home():
return {"status": "Pocket TTS API Ready", "available_voices": list(VOICES.keys())}