import os
import io
import re
import tempfile
# Removed heavy imports from top to speed up startup:
# import torch
# import numpy as np
# import soundfile as sf
from flask import Flask, request, jsonify, send_file, render_template
from flask_cors import CORS
from gtts import gTTS
from gtts.tts import gTTSError
# Removed top-level transformers import to lazy-load MMS:
# from transformers import VitsModel, AutoTokenizer

# Lazy MMS globals
mms_model = None
mms_tokenizer = None

# Define a writable cache directory for Hugging Face models
CACHE_DIR = os.environ.get("TRANSFORMERS_CACHE")
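# When TRANSFORMERS_CACHE is unset this is None, and from_pretrained()
# simply falls back to the default Hugging Face cache location.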


def load_mms():
    global mms_model, mms_tokenizer
    if mms_model and mms_tokenizer:
        return
    print("Loading Facebook MMS-TTS model for Amharic...")
    print(f"Using cache directory: {CACHE_DIR}")
    from transformers import VitsModel, AutoTokenizer
    mms_model_id = "facebook/mms-tts-amh"
    # Explicitly pass the cache_dir to from_pretrained
    mms_model = VitsModel.from_pretrained(mms_model_id, cache_dir=CACHE_DIR)
    mms_tokenizer = AutoTokenizer.from_pretrained(mms_model_id, cache_dir=CACHE_DIR)
    print("MMS-TTS model loaded successfully.")


app = Flask(__name__, static_folder='static', template_folder='templates')
CORS(app)


@app.route('/')
def index():
    return render_template('index.html')


# Health check
@app.route('/health')  # route path assumed
def health():
    return jsonify({
        "ok": True,
        "mms_loaded": bool(mms_model and mms_tokenizer)
    })
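# e.g. `curl http://localhost:7860/health` (assuming the path above) returns
# {"ok": true, "mms_loaded": false} until the first MMS request loads the model.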


@app.route('/api/tts', methods=['POST'])  # route path assumed; must match the frontend
def text_to_speech():
    data = request.get_json()
    if not data or 'text' not in data or not data['text'].strip():
        return jsonify({"error": "Text is required."}), 400
    text = data.get('text')
    model = data.get('model', 'gtts')
    speed = float(data.get('speed', 1.0))
    print(f"--- Received TTS Request for model: {model} ---")
    try:
        if model == 'gtts':
            try:
                print("Attempting gTTS synthesis with default endpoint (tld='com')...")
                tts = gTTS(text=text, lang='am', slow=(speed < 1.0), lang_check=False)
                with tempfile.NamedTemporaryFile(delete=False, suffix='.mp3') as tmp:
                    tmp_path = tmp.name
                try:
                    tts.save(tmp_path)
                    with open(tmp_path, 'rb') as f:
                        data_bytes = f.read()
                finally:
                    try:
                        os.remove(tmp_path)
                    except OSError:
                        pass
                if not data_bytes:
                    raise RuntimeError("gTTS produced empty audio stream")
                audio_fp = io.BytesIO(data_bytes)
                audio_fp.seek(0)
                print("Successfully generated audio with gTTS.")
                return send_file(audio_fp, mimetype='audio/mpeg')
            except gTTSError as ge:
                msg = ("gTTS failed using the default endpoint (Google TTS). "
                       "Please try again later or use the MMS model.")
                print(f"gTTS gTTSError: {ge}")
                return jsonify({"error": msg, "details": str(ge)}), 502
            except Exception as ge:
                msg = "gTTS failed unexpectedly on the default endpoint."
                print(f"gTTS unexpected error: {ge}")
                return jsonify({"error": msg, "details": str(ge)}), 502
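        # MMS: Meta's Massively Multilingual Speech VITS model, loaded
        # lazily on first use so server startup stays fast.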
        elif model == 'mms':
            try:
                load_mms()
            except Exception as e:
                print(f"Failed to load MMS: {e}")
                return jsonify({"error": "MMS-TTS model is not available on the server.", "details": str(e)}), 500
            print("Generating audio with MMS-TTS...")
            # Heavy imports only used here
            import torch
            import soundfile as sf
            # The transformers tokenizer will automatically use uroman if it's installed.
            # No explicit call is needed.
            if re.search(r"[^A-Za-z0-9\s\.,\?!;:'\"\-]", text):
                print("Text contains non-Roman characters. Relying on tokenizer's automatic romanization.")
            inputs = mms_tokenizer(text, return_tensors="pt")
            try:
                input_len = inputs["input_ids"].shape[-1]
            except Exception:
                input_len = 0
            if input_len == 0:
                msg = ("MMS-TTS received text that tokenized to length 0. "
                       "Install 'uroman' (Python >= 3.10) or provide romanized Latin text.")
                print(msg)
                return jsonify({"error": msg}), 400
            with torch.no_grad():
                output = mms_model(**inputs).waveform
            sampling_rate = mms_model.config.sampling_rate
            speech_waveform = output.cpu().numpy().squeeze()
            audio_fp = io.BytesIO()
            sf.write(audio_fp, speech_waveform, sampling_rate, format='WAV')
            audio_fp.seek(0)
            print("Successfully generated audio with MMS-TTS.")
            return send_file(audio_fp, mimetype='audio/wav')
        elif model in ['openai', 'azure']:
            return jsonify({"error": "The keys for this model have expired. Please use other models."}), 403
        else:
            return jsonify({"error": f"The model '{model}' is not implemented yet."}), 501
    except Exception as e:
        print(f"An error occurred: {e}")
        return jsonify({"error": f"An unexpected error occurred during TTS generation: {str(e)}"}), 500


if __name__ == '__main__':
    port = int(os.getenv('PORT', 7860))
    app.run(debug=False, port=port, host='0.0.0.0')
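
# --- Usage sketch (hypothetical client; the route path and port mirror the
# assumptions made in the decorators above) ---
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:7860/api/tts",
#       json={"text": "ሰላም", "model": "gtts", "speed": 1.0},
#   )
#   resp.raise_for_status()
#   with open("output.mp3", "wb") as f:  # the MMS branch returns WAV instead
#       f.write(resp.content)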