# Hugging Face Space page header (extraction artifact): "Spaces — Sleeping"
import os
import tempfile

import numpy as np
import scipy.io.wavfile as wavfile
import torch
from flask import Flask, request, send_file, jsonify
from pydub import AudioSegment
from transformers import VitsModel, AutoTokenizer
app = Flask(__name__)

# ---------------- Configuration ----------------
# Supported languages mapped to their Hugging Face repo IDs. For both
# languages the tokenizer and checkpoint live in the same repo.
TTS_MODELS = {
    "yoruba": {
        "tokenizer": "FarmerlineML/yoruba_tts-2025",
        "checkpoint": "FarmerlineML/yoruba_tts-2025"
    },
    "hausa": {
        "tokenizer": "FarmerlineML/main_hausa_TTS",
        "checkpoint": "FarmerlineML/main_hausa_TTS"
    }
}

# Run inference on GPU when available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Per-language caches, populated once at startup: language -> model/tokenizer.
models = {}
tokenizers = {}
# ---------------- Load models at startup ----------------
# Eagerly load every configured model so the first request pays no
# cold-start cost (beyond process startup itself).
print("Loading TTS models...")
for lang, cfg in TTS_MODELS.items():
    print(f"Loading {lang} model...")
    model = VitsModel.from_pretrained(cfg["checkpoint"]).to(device)
    model.eval()
    # VITS sampling knobs: lower noise scales make the output more
    # deterministic; speaking_rate < 1.0 slows speech down slightly.
    model.noise_scale = 0.5
    model.noise_scale_duration = 0.5
    model.speaking_rate = 0.9
    models[lang] = model
    tokenizers[lang] = AutoTokenizer.from_pretrained(cfg["tokenizer"])
print("Models loaded.")
# ---------------- Utils ----------------
def wav_to_mp3(wave: np.ndarray, sr: int) -> str:
    """Convert a waveform array to a 128 kbps mono MP3 file on disk.

    Args:
        wave: waveform samples. Float input is peak-normalized and
            scaled to int16 PCM; int16 input is written as-is.
        sr: sample rate in Hz.

    Returns:
        Path to a temporary ``.mp3`` file. The caller is responsible
        for deleting it when done (it is NOT cleaned up here).
    """
    # Safety: convert non-int16 (typically float32 model output) to
    # int16 PCM, peak-normalizing first so quiet audio is not lost.
    if wave.dtype != np.int16:
        max_val = np.max(np.abs(wave))
        if max_val > 0:
            wave = wave / max_val
        wave = (wave * 32767).astype(np.int16)

    # Reserve a temp path, then write AFTER the handle is closed —
    # writing through a still-open NamedTemporaryFile fails on Windows.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        wav_path = f.name
    wavfile.write(wav_path, sr, wave)

    # Load with pydub, force mono, and normalize loudness.
    audio = AudioSegment.from_wav(wav_path)
    audio = audio.set_channels(1)
    audio = audio.normalize()

    # Swap the real extension instead of a substring replace, which
    # would corrupt any path that happens to contain ".wav" elsewhere.
    mp3_path = os.path.splitext(wav_path)[0] + ".mp3"
    audio.export(
        mp3_path,
        format="mp3",
        bitrate="128k"
    )

    os.remove(wav_path)
    return mp3_path
# ---------------- Routes ----------------
# NOTE(review): the route decorator was lost in extraction — without it
# this handler is never registered and the server exposes no endpoint.
# "/tts" + POST is assumed from the handler reading request.get_json();
# confirm the original path.
@app.route("/tts", methods=["POST"])
def tts():
    """Synthesize speech from a JSON body {"language": ..., "text": ...}.

    Returns an MP3 attachment on success; a JSON error with HTTP 400
    for a missing body, unsupported language, or empty text.
    """
    data = request.get_json()
    if not data:
        return jsonify({"error": "JSON body required"}), 400

    language = data.get("language", "").lower()
    text = data.get("text", "")

    if language not in models:
        return jsonify({"error": "Unsupported language"}), 400
    if not text.strip():
        return jsonify({"error": "Text is empty"}), 400

    tokenizer = tokenizers[language]
    model = models[language]

    inputs = tokenizer(text, return_tensors="pt").to(device)
    with torch.no_grad():
        # waveform is batched; take the first (only) item as a numpy array.
        wave = model(**inputs).waveform[0].cpu().numpy()

    audio_path = wav_to_mp3(wave, model.config.sampling_rate)
    return send_file(
        audio_path,
        mimetype="audio/mpeg",
        as_attachment=True,
        download_name=f"{language}.mp3"
    )
# ---------------- Health check ----------------
# NOTE(review): route decorator lost in extraction; registered here so
# the probe endpoint is actually reachable — confirm the original path.
@app.route("/health")
def health():
    """Liveness probe: always reports the service as up."""
    return {"status": "ok"}
# ---------------- Run ----------------
if __name__ == "__main__":
    # Bind on all interfaces; 7860 is the conventional Hugging Face
    # Spaces port (consistent with the Space header above).
    app.run(host="0.0.0.0", port=7860)