Spaces:
Sleeping
Sleeping
| """ | |
| Improved version of app.py with performance optimizations | |
| """ | |
| from flask import Flask, request, jsonify, send_file | |
| from flask_cors import CORS | |
| import torch | |
| from transformers import AutoModelForCTC, AutoProcessor, VitsModel, AutoTokenizer | |
| import librosa | |
| import numpy as np | |
| import io | |
| import logging | |
| import threading | |
| import time | |
| from pathlib import Path | |
# Flask application object; flask_cors is applied with its default settings.
app = Flask(__name__)
CORS(app)

# Logging: INFO level, each record carries timestamp, logger name and level.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
# Processing limits.
SAMPLE_RATE = 16000      # Hz; rate the ASR input audio is resampled to
MAX_AUDIO_LENGTH = 30    # seconds; longer audio is truncated
MAX_TEXT_LENGTH = 1000   # characters; longer TTS text is truncated

# TTS language code -> Hugging Face model id. Every supported language uses
# the "facebook/mms-tts-<code>" checkpoint, so the mapping is derived from the
# list of codes (insertion order is preserved).
_TTS_LANGUAGE_CODES = (
    "beh",
    "bba",
    "ddn",
    "ewe",
    "gej",
    "tbz",
    "yor",
    "eng",
    "fra",
)
LANGUAGE_MAPPING = {
    code: f"facebook/mms-tts-{code}" for code in _TTS_LANGUAGE_CODES
}
# Loaded models, keyed by "asr" or by TTS language code.
# cache_lock guards the load-and-insert sequence in the loader functions.
cache_lock = threading.Lock()
models_cache = {}

# Static API description returned by the index endpoint.
_MODEL_SUMMARY = {
    "asr": "facebook/mms-1b-all (964M parameters)",
    "tts": f"{len(LANGUAGE_MAPPING)} langues supportées",
}
API_METADATA = dict(
    name="Meta MMS ASR/TTS API",
    version="2.0",
    description="Reconnaissance vocale et synthèse vocale multilingue",
    models=_MODEL_SUMMARY,
)
# Result of the one-time CUDA probe; None until get_device() first runs.
_DEVICE = None


def get_device():
    """Return the torch device name used for inference.

    The CUDA probe runs only on the first call and its result is cached, so
    the GPU name is logged once instead of on every request and health check.

    Returns:
        str: "cuda" when a CUDA device is available, otherwise "cpu".
    """
    global _DEVICE
    if _DEVICE is None:
        if torch.cuda.is_available():
            logger.info(f"GPU disponible: {torch.cuda.get_device_name(0)}")
            _DEVICE = "cuda"
        else:
            _DEVICE = "cpu"
    return _DEVICE
def load_asr_model():
    """Return the shared ASR model and processor, loading them on first use.

    The pair is stored in ``models_cache["asr"]``. Loading happens while
    holding ``cache_lock`` so concurrent first requests load the model once.
    The model is put in eval mode; gradient tracking is not disabled here
    (an empty ``torch.no_grad()`` block has no lasting effect), so callers
    must wrap inference in ``torch.no_grad()`` themselves.

    Returns:
        tuple: ``(model, processor)`` for "facebook/mms-1b-all".

    Raises:
        Exception: whatever the download/initialisation raised; it is logged
            and re-raised, and nothing is cached in that case.
    """
    with cache_lock:
        entry = models_cache.get("asr")
        if entry is None:
            try:
                device = get_device()
                logger.info("⏳ Chargement du modèle ASR facebook/mms-1b-all...")
                processor = AutoProcessor.from_pretrained("facebook/mms-1b-all")
                model = AutoModelForCTC.from_pretrained("facebook/mms-1b-all").to(device)
                model.eval()
                entry = {"model": model, "processor": processor}
                models_cache["asr"] = entry
                logger.info("✅ Modèle ASR chargé")
            except Exception as e:
                logger.error(f"❌ Erreur lors du chargement du modèle ASR: {e}")
                raise
    return entry["model"], entry["processor"]
def load_tts_model(language_code):
    """Return the TTS model and tokenizer for a language, loading on first use.

    The language is validated against ``LANGUAGE_MAPPING`` before the cache is
    consulted. The cache is shared with the ASR entry (key ``"asr"``), so
    checking the cache first would let an arbitrary key such as ``"asr"``
    match an entry that has no tokenizer.

    Args:
        language_code: one of the keys of ``LANGUAGE_MAPPING``.

    Returns:
        tuple: ``(model, tokenizer)`` for the language's checkpoint.

    Raises:
        ValueError: if ``language_code`` is not a supported language
            (including non-string values).
        Exception: whatever the download/initialisation raised; it is logged
            and re-raised, and nothing is cached in that case.
    """
    # isinstance guard first: an unhashable value (e.g. a JSON list) would
    # otherwise raise TypeError from the dict lookup.
    if not isinstance(language_code, str) or language_code not in LANGUAGE_MAPPING:
        raise ValueError(f"Langue non supportée: {language_code}")
    model_id = LANGUAGE_MAPPING[language_code]

    with cache_lock:
        entry = models_cache.get(language_code)
        if entry is None:
            try:
                device = get_device()
                logger.info(f"⏳ Chargement du modèle TTS {language_code} ({model_id})...")
                model = VitsModel.from_pretrained(model_id).to(device)
                tokenizer = AutoTokenizer.from_pretrained(model_id)
                model.eval()
                entry = {"model": model, "tokenizer": tokenizer}
                models_cache[language_code] = entry
                logger.info(f"✅ Modèle TTS {language_code} chargé")
            except Exception as e:
                logger.error(f"❌ Erreur lors du chargement du modèle TTS {language_code}: {e}")
                raise
    return entry["model"], entry["tokenizer"]
def process_audio(audio_data, target_sr=SAMPLE_RATE):
    """Decode, resample, peak-normalise and truncate an audio clip.

    Args:
        audio_data: either the raw bytes of an audio file (decoded with
            librosa as mono at its native rate) or an array-like of samples,
            which is assumed to already be at ``SAMPLE_RATE``.
        target_sr: sampling rate, in Hz, of the returned signal.

    Returns:
        numpy.ndarray: samples at ``target_sr``, scaled so the peak absolute
        value is 1 (silent audio is left untouched) and cut to at most
        ``MAX_AUDIO_LENGTH`` seconds.

    Raises:
        ValueError: if the clip contains no samples.
        Exception: any decoding/resampling error; it is logged and re-raised.
    """
    try:
        if isinstance(audio_data, (bytes, bytearray)):
            audio, sr = librosa.load(io.BytesIO(bytes(audio_data)), sr=None, mono=True)
        else:
            # Accept lists/tuples as well as arrays so the arithmetic below works.
            audio = np.asarray(audio_data, dtype=np.float32)
            sr = SAMPLE_RATE

        # Fail with a clear message instead of letting np.max() choke on an
        # empty array further down.
        if audio.size == 0:
            raise ValueError("Audio vide ou illisible")

        # Resample if needed.
        if sr != target_sr:
            audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr)

        # Peak-normalise; the peak is computed once and silence is skipped to
        # avoid dividing by zero.
        peak = np.max(np.abs(audio))
        if peak > 0:
            audio = audio / peak

        # Truncate if too long.
        max_samples = MAX_AUDIO_LENGTH * target_sr
        if len(audio) > max_samples:
            audio = audio[:max_samples]
            logger.warning(f"Audio tronqué à {MAX_AUDIO_LENGTH}s")
        return audio
    except Exception as e:
        logger.error(f"❌ Erreur lors du traitement audio: {e}")
        raise
def index():
    """Return the API description as JSON.

    The payload is ``API_METADATA`` extended with the current device, the
    table of endpoints and a link to the documentation.
    """
    # NOTE(review): no route decorator is visible on this function in this
    # view; presumably it is registered as "GET /" — confirm.
    endpoints = {
        "GET /health": "État du service",
        "GET /supported-languages": "Langues supportées",
        "POST /asr": "Audio → Texte",
        "POST /tts": "Texte → Audio",
        "GET /models-info": "Infos sur les modèles",
    }
    payload = dict(API_METADATA)
    payload["device"] = get_device()
    payload["endpoints"] = endpoints
    payload["docs"] = "https://github.com/ronaldodev/mms-asr-tts"
    return jsonify(payload)
def health():
    """Report service status as JSON.

    Returns the device in use, the current timestamp and the keys of the
    model cache. If building the report raises, responds with HTTP 500 and
    the error message.
    """
    # NOTE(review): no route decorator is visible on this function in this
    # view; the index() endpoint table lists "GET /health" — confirm.
    try:
        report = {
            "status": "healthy",
            "device": get_device(),
            "timestamp": time.time(),
            "cached_models": list(models_cache),
        }
        return jsonify(report)
    except Exception as e:
        failure = {"status": "error", "error": str(e)}
        return jsonify(failure), 500
def models_info():
    """Return static details about the ASR and TTS models as JSON.

    The TTS language count in the description is derived from
    ``LANGUAGE_MAPPING`` so it cannot drift from the actual mapping.
    """
    # NOTE(review): no route decorator is visible on this function in this
    # view; the index() endpoint table lists "GET /models-info" — confirm.
    asr_info = {
        "model_id": "facebook/mms-1b-all",
        "parameters": "964.8M",
        "architecture": "wav2vec2",
        "languages": 100,
        "description": "Automatic Speech Recognition multilingue",
    }
    tts_info = {
        "model_family": "facebook/mms-tts-*",
        "architecture": "VITS",
        # VITS/MMS-TTS checkpoints are configured for 16 kHz output
        # (VitsConfig.sampling_rate) — confirm against model.config.
        "sample_rate": 16000,
        "supported_languages": LANGUAGE_MAPPING,
        "description": f"Text-to-Speech pour {len(LANGUAGE_MAPPING)} langues",
    }
    return jsonify({"asr": asr_info, "tts": tts_info})
def supported_languages():
    """Return the languages supported by the ASR and TTS models as JSON."""
    # NOTE(review): no route decorator is visible on this function in this
    # view; the index() endpoint table lists "GET /supported-languages" —
    # confirm.
    asr_info = {
        "model": "facebook/mms-1b-all",
        "languages": 100,
        "description": "Support de 100+ langues ISO 639-3",
    }
    tts_info = {
        "languages": LANGUAGE_MAPPING,
        "count": len(LANGUAGE_MAPPING),
        # VITS/MMS-TTS checkpoints are configured for 16 kHz output
        # (VitsConfig.sampling_rate) — confirm against model.config.
        "sample_rate": 16000,
    }
    return jsonify({"asr": asr_info, "tts": tts_info})
# The ASR model and processor are shared by all request threads, and selecting
# a language mutates both (tokenizer vocabulary + adapter weights). This lock
# keeps "select language, run inference, decode" atomic per request.
_asr_inference_lock = threading.Lock()


def asr():
    """Transcribe an uploaded audio file (ASR).

    Expects a multipart form with an ``audio`` file and an optional
    ``language`` field (defaults to ``"eng"``).

    The requested language is applied through the MMS mechanism documented by
    Transformers (``tokenizer.set_target_lang`` followed by
    ``model.load_adapter``); assigning ``processor.current_lang`` is not part
    of that documented procedure. The adapter is only reloaded when the
    language actually changes.

    Returns:
        A JSON response with the transcription, language, audio length and
        processing time; HTTP 400 for a missing/invalid upload or any
        ``ValueError`` (e.g. unsupported language, empty audio); HTTP 500 for
        any other failure.
    """
    # NOTE(review): no route decorator is visible on this function in this
    # view; the index() endpoint table lists "POST /asr" — confirm.
    start_time = time.time()
    try:
        if "audio" not in request.files:
            return jsonify({"error": "Pas de fichier audio fourni"}), 400
        audio_file = request.files["audio"]
        language = request.form.get("language", "eng")
        logger.info(f"📥 ASR demandé: language={language}, file={audio_file.filename}")

        # Validate the upload.
        if not audio_file.filename:
            return jsonify({"error": "Nom de fichier invalide"}), 400

        # Decode and normalise the audio.
        audio = process_audio(audio_file.read())
        logger.info(f" Audio chargé: {len(audio)/SAMPLE_RATE:.2f}s")

        model, processor = load_asr_model()
        device = get_device()
        tokenizer = processor.tokenizer

        with _asr_inference_lock:
            # Check both sides so a previous half-applied switch (tokenizer
            # changed, adapter load failed) is repaired on the next request.
            if (
                getattr(tokenizer, "target_lang", None) != language
                or getattr(model, "target_lang", None) != language
            ):
                tokenizer.set_target_lang(language)  # ValueError if unknown
                model.load_adapter(language)

            with torch.no_grad():
                inputs = processor(
                    audio, sampling_rate=SAMPLE_RATE, return_tensors="pt"
                ).to(device)
                logits = model(**inputs).logits
            ids = torch.argmax(logits, dim=-1)[0]
            # Decoding uses the tokenizer's current language, so it stays
            # inside the lock.
            transcription = processor.decode(ids)

        elapsed = time.time() - start_time
        logger.info(f"✅ ASR complété en {elapsed:.2f}s: {transcription}")
        return jsonify({
            "transcription": transcription,
            "language": language,
            "audio_length": len(audio) / SAMPLE_RATE,
            "processing_time": elapsed,
            "confidence": "not_available",
        })
    except ValueError as e:
        # Same convention as tts(): bad input values are a client error.
        logger.error(f"❌ Erreur ASR (valeur): {e}")
        return jsonify({"error": str(e)}), 400
    except Exception as e:
        logger.error(f"❌ Erreur ASR: {e}")
        return jsonify({"error": str(e)}), 500
def tts():
    """Synthesise speech from text (TTS) and return it as a WAV attachment.

    Expects a JSON body with ``text`` (string, required) and ``language``
    (string, optional, defaults to ``"eng"``). Text longer than
    ``MAX_TEXT_LENGTH`` characters is truncated.

    The WAV header uses the sampling rate declared by the loaded model's
    config rather than a hard-coded value, so playback speed and pitch match
    what the model generated.

    Returns:
        The WAV file via ``send_file``; HTTP 400 for a missing/invalid body,
        empty text or any ``ValueError`` (e.g. unsupported language); HTTP 500
        for any other failure.
    """
    # NOTE(review): no route decorator is visible on this function in this
    # view; the index() endpoint table lists "POST /tts" — confirm.
    start_time = time.time()
    try:
        # silent=True: a missing or malformed JSON body yields None instead of
        # raising, so it is reported as a 400 below rather than a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or "text" not in data:
            return jsonify({"error": "Paramètre 'text' requis"}), 400

        text = data["text"]
        language = data.get("language", "eng")
        if not isinstance(text, str):
            return jsonify({"error": "Le paramètre 'text' doit être une chaîne"}), 400
        if not isinstance(language, str):
            return jsonify({"error": f"Langue non supportée: {language}"}), 400

        text = text.strip()
        if not text:
            return jsonify({"error": "Le texte ne peut pas être vide"}), 400

        # Cap the length.
        if len(text) > MAX_TEXT_LENGTH:
            text = text[:MAX_TEXT_LENGTH]
            logger.warning(f"Texte tronqué à {MAX_TEXT_LENGTH} caractères")
        logger.info(f"📥 TTS demandé: language={language}, text_len={len(text)}")

        model, tokenizer = load_tts_model(language)

        device = get_device()
        with torch.no_grad():
            inputs = tokenizer(text, return_tensors="pt").to(device)
            outputs = model(**inputs)
        waveform = outputs.waveform.cpu().numpy().flatten()

        # Encode as WAV at the model's own output rate.
        import soundfile as sf
        sample_rate = model.config.sampling_rate
        audio_bytes = io.BytesIO()
        sf.write(audio_bytes, waveform, sample_rate, format="WAV")
        audio_bytes.seek(0)

        elapsed = time.time() - start_time
        logger.info(f"✅ TTS complété en {elapsed:.2f}s: {len(waveform)} samples")
        return send_file(
            audio_bytes,
            mimetype="audio/wav",
            as_attachment=True,
            download_name=f"tts_{language}.wav",
        )
    except ValueError as e:
        logger.error(f"❌ Erreur TTS (valeur): {e}")
        return jsonify({"error": str(e)}), 400
    except Exception as e:
        logger.error(f"❌ Erreur TTS: {e}")
        return jsonify({"error": str(e)}), 500
def not_found(e):
    """Return the JSON body and 404 status for an unknown endpoint."""
    # NOTE(review): no errorhandler decorator is visible in this view;
    # presumably registered for HTTP 404 — confirm.
    body = jsonify({"error": "Endpoint non trouvé"})
    return body, 404
def server_error(e):
    """Return the JSON body and 500 status for an internal server error."""
    # NOTE(review): no errorhandler decorator is visible in this view;
    # presumably registered for HTTP 500 — confirm.
    body = jsonify({"error": "Erreur serveur interne"})
    return body, 500
if __name__ == "__main__":
    # Script entry point: log a startup banner, then serve on all interfaces
    # with Flask's built-in threaded server (debug disabled).
    host, port = "0.0.0.0", 7860
    logger.info("🚀 Démarrage de l'API MMS")
    logger.info("📊 Device: %s", get_device())
    logger.info("🌐 Démarrage sur %s:%d", host, port)
    app.run(host=host, port=port, debug=False, threaded=True)