# yo-ha-tts / app.py
# Source: Hugging Face Space by Ronaldodev — "Update app.py", commit 38a631c (verified)
import io
import os
import tempfile

import numpy as np
import scipy.io.wavfile as wavfile
import torch
from flask import Flask, request, send_file, jsonify
from pydub import AudioSegment
from transformers import VitsModel, AutoTokenizer
app = Flask(__name__)

# ---------------- Configuration ----------------
# Per-language Hugging Face repo ids. For both languages the tokenizer
# and the model weights live in the same repository.
TTS_MODELS = {
    "yoruba": {
        "tokenizer": "FarmerlineML/yoruba_tts-2025",
        "checkpoint": "FarmerlineML/yoruba_tts-2025",
    },
    "hausa": {
        "tokenizer": "FarmerlineML/main_hausa_TTS",
        "checkpoint": "FarmerlineML/main_hausa_TTS",
    },
}

device = "cuda" if torch.cuda.is_available() else "cpu"

# Caches keyed by language name, filled once at startup.
models = {}
tokenizers = {}

# ---------------- Load models at startup ----------------
print("Loading TTS models...")
for lang, cfg in TTS_MODELS.items():
    print(f"Loading {lang} model...")
    tts_model = VitsModel.from_pretrained(cfg["checkpoint"]).to(device)
    tts_model.eval()
    # VITS sampling knobs set on the model instance — presumably read by
    # VitsModel.forward; confirm against the installed transformers version.
    tts_model.noise_scale = 0.5
    tts_model.noise_scale_duration = 0.5
    tts_model.speaking_rate = 0.9
    models[lang] = tts_model
    tokenizers[lang] = AutoTokenizer.from_pretrained(cfg["tokenizer"])
print("Models loaded.")
# ---------------- Utils ----------------
def wav_to_mp3(wave: np.ndarray, sr: int) -> str:
    """Encode a waveform as a 128 kbps mono MP3 and return the file path.

    Args:
        wave: Audio samples. Float arrays are peak-normalized and cast to
            int16; int16 arrays are written as-is.
        sr: Sample rate in Hz.

    Returns:
        Path to a temporary ``.mp3`` file. The caller is responsible for
        deleting it.
    """
    # Safety: convert float samples to int16, peak-normalizing first so
    # the cast cannot clip.
    if wave.dtype != np.int16:
        max_val = np.max(np.abs(wave))
        if max_val > 0:
            wave = wave / max_val
        wave = (wave * 32767).astype(np.int16)

    # Reserve a temp path, then close the handle BEFORE scipy re-opens it
    # (re-opening a still-open NamedTemporaryFile fails on Windows).
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        wav_path = f.name
    wavfile.write(wav_path, sr, wave)

    try:
        # Load with pydub, force mono, and normalize loudness before export.
        audio = AudioSegment.from_wav(wav_path)
        audio = audio.set_channels(1)
        audio = audio.normalize()

        # splitext is robust even if ".wav" appears elsewhere in the path
        # (str.replace would corrupt such a path).
        mp3_path = f"{os.path.splitext(wav_path)[0]}.mp3"
        audio.export(
            mp3_path,
            format="mp3",
            bitrate="128k",
        )
    finally:
        # Always remove the intermediate WAV, even if pydub/ffmpeg fails.
        os.remove(wav_path)

    return mp3_path
# ---------------- Routes ----------------
@app.route("/tts", methods=["POST"])
def tts():
    """Synthesize speech from a JSON body ``{"language": ..., "text": ...}``.

    Returns:
        An ``audio/mpeg`` attachment named ``<language>.mp3`` on success,
        or a JSON error with HTTP 400 on bad input.
    """
    data = request.get_json()
    if not data:
        return jsonify({"error": "JSON body required"}), 400

    # str() guards against non-string JSON values (previously an
    # AttributeError -> HTTP 500 on e.g. {"language": 1}).
    language = str(data.get("language", "")).lower()
    text = data.get("text", "")
    if language not in models:
        return jsonify({"error": "Unsupported language"}), 400
    if not isinstance(text, str) or not text.strip():
        return jsonify({"error": "Text is empty"}), 400

    tokenizer = tokenizers[language]
    model = models[language]

    inputs = tokenizer(text, return_tensors="pt").to(device)
    with torch.no_grad():
        wave = model(**inputs).waveform[0].cpu().numpy()

    audio_path = wav_to_mp3(wave, model.config.sampling_rate)
    # Read the MP3 into memory and delete the temp file immediately —
    # the original leaked one temp file per request.
    with open(audio_path, "rb") as f:
        payload = io.BytesIO(f.read())
    os.remove(audio_path)

    return send_file(
        payload,
        mimetype="audio/mpeg",
        as_attachment=True,
        download_name=f"{language}.mp3",
    )
# ---------------- Health check ----------------
@app.route("/", methods=["GET"])
def health():
    """Liveness probe: report that the service is up."""
    status = {"status": "ok"}
    return status
# ---------------- Run ----------------
if __name__ == "__main__":
    # Bind on all interfaces; 7860 is the Hugging Face Spaces default port.
    app.run(host="0.0.0.0", port=7860)