"""Flask TTS service: synthesizes Yoruba/Hausa speech with VITS and returns MP3."""

import os
import tempfile

import numpy as np
import scipy.io.wavfile as wavfile
import torch
from flask import Flask, after_this_request, jsonify, request, send_file
from pydub import AudioSegment
from transformers import AutoTokenizer, VitsModel

app = Flask(__name__)

# ---------------- Configuration ----------------
TTS_MODELS = {
    "yoruba": {
        "tokenizer": "FarmerlineML/yoruba_tts-2025",
        "checkpoint": "FarmerlineML/yoruba_tts-2025",
    },
    "hausa": {
        "tokenizer": "FarmerlineML/main_hausa_TTS",
        "checkpoint": "FarmerlineML/main_hausa_TTS",
    },
}

device = "cuda" if torch.cuda.is_available() else "cpu"

models = {}
tokenizers = {}

# ---------------- Load models at startup ----------------
print("Loading TTS models...")
for lang, cfg in TTS_MODELS.items():
    print(f"Loading {lang} model...")
    model = VitsModel.from_pretrained(cfg["checkpoint"]).to(device)
    model.eval()
    # Inference-time VITS knobs: reduce sampling noise and slow speech
    # slightly (values carried over from the original configuration).
    model.noise_scale = 0.5
    model.noise_scale_duration = 0.5
    model.speaking_rate = 0.9
    models[lang] = model
    tokenizers[lang] = AutoTokenizer.from_pretrained(cfg["tokenizer"])
print("Models loaded.")


# ---------------- Utils ----------------
def wav_to_mp3(wave: np.ndarray, sr: int) -> str:
    """Convert a waveform array to a temporary mono 128 kbps MP3 file.

    Args:
        wave: 1-D array of audio samples (float or int16).
        sr: sampling rate in Hz.

    Returns:
        Path to the MP3 file. The caller is responsible for deleting it.
    """
    # Safety: non-int16 input -> peak-normalize, then scale to int16 PCM.
    if wave.dtype != np.int16:
        max_val = np.max(np.abs(wave))
        if max_val > 0:
            wave = wave / max_val
        wave = (wave * 32767).astype(np.int16)

    # Write an intermediate WAV to a temp file.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        wavfile.write(f.name, sr, wave)
        wav_path = f.name

    try:
        # Load with pydub, force mono, and normalize loudness (important
        # for consistent playback volume).
        audio = AudioSegment.from_wav(wav_path)
        audio = audio.set_channels(1)
        audio = audio.normalize()

        mp3_path = wav_path.replace(".wav", ".mp3")
        audio.export(mp3_path, format="mp3", bitrate="128k")
    finally:
        # Fix: remove the intermediate WAV even if conversion fails;
        # the original leaked it on any pydub error.
        os.remove(wav_path)

    return mp3_path


# ---------------- Routes ----------------
@app.route("/tts", methods=["POST"])
def tts():
    """Synthesize speech from JSON {"language": ..., "text": ...} as MP3.

    Returns 400 on a missing body, unsupported language, or empty text;
    otherwise streams the generated MP3 as an attachment.
    """
    data = request.get_json()
    if not data:
        return jsonify({"error": "JSON body required"}), 400

    language = data.get("language", "").lower()
    text = data.get("text", "")

    if language not in models:
        return jsonify({"error": "Unsupported language"}), 400
    if not text.strip():
        return jsonify({"error": "Text is empty"}), 400

    tokenizer = tokenizers[language]
    model = models[language]

    inputs = tokenizer(text, return_tensors="pt").to(device)
    with torch.no_grad():
        wave = model(**inputs).waveform[0].cpu().numpy()

    audio_path = wav_to_mp3(wave, model.config.sampling_rate)

    # Fix: delete the temp MP3 once the response is finalized — otherwise
    # every request leaks a file in the temp directory. Best-effort: on
    # platforms where the file is still held open, the OSError is ignored.
    @after_this_request
    def _cleanup(response):
        try:
            os.remove(audio_path)
        except OSError:
            pass
        return response

    return send_file(
        audio_path,
        mimetype="audio/mpeg",
        as_attachment=True,
        download_name=f"{language}.mp3",
    )


# ---------------- Health check ----------------
@app.route("/", methods=["GET"])
def health():
    """Liveness probe."""
    return {"status": "ok"}


# ---------------- Run ----------------
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)