"""Flask TTS service exposing Facebook MMS VITS models for Ewe and Mina (Gen).

POST /tts with JSON {"text": ..., "language": "ewe"|"mina"} returns a WAV file.
"""

import io

from flask import Flask, request, jsonify, send_file
from flask_cors import CORS
from transformers import VitsModel, AutoTokenizer
import torch
import scipy.io.wavfile

app = Flask(__name__)
CORS(app)

# --- Load the MMS TTS models and tokenizers ---
# Note: Models are loaded on startup. For a larger number of languages,
# lazy loading would be better.
MODELS = {
    "ewe": "facebook/mms-tts-ewe",
    "mina": "facebook/mms-tts-gej",  # Mina (Gen)
}

loaded_models = {}
loaded_tokenizers = {}

print("Loading models...")
for lang, model_id in MODELS.items():
    print(f"Loading {lang} model: {model_id}")
    try:
        loaded_models[lang] = VitsModel.from_pretrained(model_id)
        loaded_tokenizers[lang] = AutoTokenizer.from_pretrained(model_id)
    except Exception as e:
        # A language that fails to load is simply absent from loaded_models;
        # the /tts endpoint then reports it as unsupported instead of crashing.
        print(f"Failed to load {lang} model: {e}")
# NOTE(review): this prints even if some (or all) loads failed above; the
# failures are only visible in the per-language messages.
print("Models loaded successfully.")


@app.route('/tts', methods=['POST'])
def text_to_speech():
    """Synthesize speech for the posted text.

    Expects JSON: {"text": str, "language": str (optional, default "ewe")}.
    Returns: a mono WAV audio response (audio/wav) on success;
             a JSON {"error": ...} with status 400 for bad input or an
             unsupported language, 500 for unexpected failures.
    """
    # silent=True: a missing/malformed JSON body yields None instead of an
    # exception, so we can return our own 400 rather than an opaque 500.
    # (The original called data.get(...) on a possibly-None value.)
    data = request.get_json(silent=True) or {}
    text = data.get('text')
    lang = data.get('language', 'ewe')  # Default to Ewe

    if not text:
        return jsonify({"error": "No text provided"}), 400

    if lang not in loaded_models:
        return jsonify({"error": f"Language '{lang}' not supported"}), 400

    try:
        model = loaded_models[lang]
        tokenizer = loaded_tokenizers[lang]

        # Tokenize the input text
        inputs = tokenizer(text, return_tensors="pt")

        # Generate the speech waveform (inference only, no gradients)
        with torch.no_grad():
            output = model(**inputs).waveform

        # Write the WAV into memory instead of a shared "output.wav" on disk:
        # a fixed filename is a race under concurrent requests (one request
        # could overwrite the file while another is still sending it).
        # output is (1, samples); .T gives scipy the (samples, channels) layout.
        buffer = io.BytesIO()
        sampling_rate = model.config.sampling_rate
        scipy.io.wavfile.write(buffer, rate=sampling_rate,
                               data=output.float().numpy().T)
        buffer.seek(0)  # rewind so send_file streams from the start

        return send_file(buffer, mimetype="audio/wav")

    except Exception as e:
        # Top-level boundary: report the failure as JSON instead of an
        # HTML traceback page.
        return jsonify({"error": f"An unexpected error occurred: {e}"}), 500


if __name__ == '__main__':
    # It's recommended to use a production-ready WSGI server instead of
    # app.run in production
    app.run(debug=True)