# backend/app.py — Flask TTS backend: gTTS (Google) and Facebook MMS-TTS (Amharic).
import os
import io
import re
import tempfile
# Removed heavy imports from top to speed up startup:
# import torch
# import numpy as np
# import soundfile as sf
from flask import Flask, request, jsonify, send_file, render_template
from flask_cors import CORS
from gtts import gTTS
from gtts.tts import gTTSError
# Removed top-level transformers import to lazy-load MMS:
# from transformers import VitsModel, AutoTokenizer
# Lazy MMS globals
# Model and tokenizer are loaded on first MMS request (see load_mms) to keep
# startup fast; both stay None until then (/health reports this via mms_loaded).
mms_model = None
mms_tokenizer = None
# Define a writable cache directory for Hugging Face models
# NOTE(review): None when TRANSFORMERS_CACHE is unset, in which case
# from_pretrained falls back to its default cache location — confirm the
# deployment always sets this env var if a specific path is required.
CACHE_DIR = os.environ.get("TRANSFORMERS_CACHE")
def load_mms():
    """Lazily load the Facebook MMS-TTS Amharic model and tokenizer.

    Idempotent: returns immediately once both module-level globals are set.
    Raises whatever ``from_pretrained`` raises (network/cache errors); callers
    are expected to catch and translate to an HTTP error.
    """
    global mms_model, mms_tokenizer
    if mms_model and mms_tokenizer:
        return
    print("Loading Facebook MMS-TTS model for Amharic...")
    print(f"Using cache directory: {CACHE_DIR}")
    # Imported here (not at module top) to keep app startup fast.
    from transformers import VitsModel, AutoTokenizer
    mms_model_id = "facebook/mms-tts-amh"
    # Explicitly pass the cache_dir to from_pretrained so downloads land in a
    # known-writable location.
    # Load into locals first and commit both globals together: previously the
    # model global was assigned before the tokenizer finished loading, so a
    # tokenizer-load failure left the module half-initialized.
    model = VitsModel.from_pretrained(mms_model_id, cache_dir=CACHE_DIR)
    tokenizer = AutoTokenizer.from_pretrained(mms_model_id, cache_dir=CACHE_DIR)
    mms_model, mms_tokenizer = model, tokenizer
    print("MMS-TTS model loaded successfully.")
# Flask app serving the SPA template from /templates and assets from /static.
app = Flask(__name__, static_folder='static', template_folder='templates')
# Allow cross-origin requests (the frontend may be served from another origin).
CORS(app)
@app.route('/')
def index():
    """Serve the single-page frontend."""
    page = render_template('index.html')
    return page
# Health check
@app.route('/health')
def health():
    """Liveness probe: always ok, plus whether the MMS model is in memory."""
    loaded = bool(mms_model and mms_tokenizer)
    payload = {"ok": True, "mms_loaded": loaded}
    return jsonify(payload)
def _synthesize_gtts(text, speed):
    """Synthesize `text` via Google TTS (gTTS); return a Flask response.

    gTTS only writes to disk, so the MP3 goes to a temp file that is read back
    into memory and always deleted. Any gTTS failure maps to HTTP 502.
    """
    try:
        print("Attempting gTTS synthesis with default endpoint (tld='com')...")
        # gTTS has no fine-grained rate control: only slow=True for sub-normal speed.
        tts = gTTS(text=text, lang='am', slow=(speed < 1.0), lang_check=False)
        # Close the temp file before gTTS writes to it (required on Windows).
        with tempfile.NamedTemporaryFile(delete=False, suffix='.mp3') as tmp:
            tmp_path = tmp.name
        try:
            tts.save(tmp_path)
            with open(tmp_path, 'rb') as f:
                data_bytes = f.read()
        finally:
            # Best-effort cleanup; never mask the real error with an OSError.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
        if not data_bytes:
            raise RuntimeError("gTTS produced empty audio stream")
        audio_fp = io.BytesIO(data_bytes)
        audio_fp.seek(0)
        print("Successfully generated audio with gTTS.")
        return send_file(audio_fp, mimetype='audio/mpeg')
    except gTTSError as ge:
        msg = ("gTTS failed using the default endpoint (Google TTS). "
               "Please try again later or use the MMS model.")
        print(f"gTTS gTTSError: {ge}")
        return jsonify({"error": msg, "details": str(ge)}), 502
    except Exception as ge:
        msg = "gTTS failed unexpectedly on the default endpoint."
        print(f"gTTS unexpected error: {ge}")
        return jsonify({"error": msg, "details": str(ge)}), 502


def _synthesize_mms(text):
    """Synthesize `text` with the Facebook MMS-TTS Amharic model; return a Flask response.

    Lazily loads the model (500 if unavailable), rejects text that tokenizes
    to nothing (400), otherwise streams a WAV (send_file, audio/wav).
    """
    try:
        load_mms()
    except Exception as e:
        print(f"Failed to load MMS: {e}")
        return jsonify({"error": "MMS-TTS model is not available on the server.", "details": str(e)}), 500
    print("Generating audio with MMS-TTS...")
    # Heavy imports only used here, to keep startup fast.
    import torch
    import soundfile as sf
    # The transformers tokenizer will automatically use uroman if it's installed;
    # no explicit romanization call is needed.
    if re.search(r"[^A-Za-z0-9\s\.,\?!;:'\"\-]", text):
        print("Text contains non-Roman characters. Relying on tokenizer's automatic romanization.")
    inputs = mms_tokenizer(text, return_tensors="pt")
    try:
        input_len = inputs["input_ids"].shape[-1]
    except Exception:
        input_len = 0
    if input_len == 0:
        # Without uroman, non-Latin text can tokenize to nothing — a client error.
        msg = ("MMS-TTS received text that tokenized to length 0. "
               "Install 'uroman' (Python >= 3.10) or provide romanized Latin text.")
        print(msg)
        return jsonify({"error": msg}), 400
    with torch.no_grad():
        output = mms_model(**inputs).waveform
    sampling_rate = mms_model.config.sampling_rate
    speech_waveform = output.cpu().numpy().squeeze()
    audio_fp = io.BytesIO()
    sf.write(audio_fp, speech_waveform, sampling_rate, format='WAV')
    audio_fp.seek(0)
    print("Successfully generated audio with MMS-TTS.")
    return send_file(audio_fp, mimetype='audio/wav')


@app.route('/api/tts', methods=['POST'])
def text_to_speech():
    """POST /api/tts — synthesize speech from JSON {"text", "model"?, "speed"?}.

    Returns audio (MP3 for gtts, WAV for mms) on success; JSON error with an
    appropriate status (400/403/500/501/502) otherwise.
    """
    data = request.get_json()
    # 'text' must be a present, non-empty string. (A non-string value used to
    # crash on .strip() with AttributeError -> opaque 500.)
    if not data or not isinstance(data.get('text'), str) or not data['text'].strip():
        return jsonify({"error": "Text is required."}), 400
    text = data['text']
    model = data.get('model', 'gtts')
    # A non-numeric 'speed' used to raise before the error handler below was
    # in scope -> unhandled 500; reject it explicitly as a client error.
    try:
        speed = float(data.get('speed', 1.0))
    except (TypeError, ValueError):
        return jsonify({"error": "'speed' must be a number."}), 400
    print(f"--- Received TTS Request for model: {model} ---")
    try:
        if model == 'gtts':
            return _synthesize_gtts(text, speed)
        elif model == 'mms':
            return _synthesize_mms(text)
        elif model in ['openai', 'azure']:
            return jsonify({"error": "The keys for this model have expired. Please use other models."}), 403
        else:
            return jsonify({"error": f"The model '{model}' is not implemented yet."}), 501
    except Exception as e:
        # Last-resort guard so synthesis bugs surface as JSON, not an HTML 500.
        print(f"An error occurred: {e}")
        return jsonify({"error": f"An unexpected error occurred during TTS generation: {str(e)}"}), 500
# Entry point: bind on all interfaces at $PORT (7860 is the Hugging Face
# Spaces default).
if __name__ == '__main__':
    serve_port = int(os.getenv('PORT', 7860))
    app.run(debug=False, port=serve_port, host='0.0.0.0')