# Hugging Face Space: Yoruba / Hausa speech-to-text demo (Flask + Whisper).
import os
import tempfile
import torch
from flask import Flask, request, jsonify, render_template, send_file
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import librosa
import warnings

warnings.filterwarnings('ignore')

app = Flask(__name__)

# ========================
# CONFIGURATION
# ========================
# Supported speech-to-text models, keyed by language name. "language" holds
# the ISO 639-1 code (currently not consumed by the transcription pipeline).
STT_MODELS = {
    "yoruba": {
        "model_id": "ajibs75/whisper-small-yoruba",
        "language": "yo",
    },
    "hausa": {
        "model_id": "NCAIR1/Hausa-ASR",
        "language": "ha",
    },
}

# Prefer GPU when available; every model is moved to this device.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🖥️ Device: {device}")

# Per-language registries, filled by the startup loop below.
models = {}
processors = {}

# ========================
# LOAD MODELS AT STARTUP
# ========================
print("=" * 60)
print("🎤 Chargement des modèles STT...")
print("=" * 60)

for lang, config in STT_MODELS.items():
    print(f"📥 Chargement du modèle {lang}...")
    try:
        # Best-effort loading: a failure for one language leaves the others
        # usable; the endpoints reject languages that are absent from `models`.
        stt_processor = WhisperProcessor.from_pretrained(config["model_id"])
        stt_model = WhisperForConditionalGeneration.from_pretrained(config["model_id"])
        stt_model.to(device)
        stt_model.eval()
        processors[lang] = stt_processor
        models[lang] = stt_model
        print(f"✅ {lang.capitalize()} prêt")
    except Exception as e:
        print(f"❌ Erreur {lang}: {e}")

print("=" * 60)
print("✅ Tous les modèles sont chargés")
print("=" * 60)
| # ======================== | |
| # UTILITY FUNCTIONS | |
| # ======================== | |
def transcribe_audio(audio_path: str, language: str) -> str:
    """Transcribe an audio file with the model loaded for *language*.

    Args:
        audio_path: Path to an audio file readable by librosa.
        language: Key into the module-level ``models``/``processors`` dicts.

    Returns:
        The decoded transcription, stripped of surrounding whitespace.

    Raises:
        ValueError: If no model was loaded for ``language``.
    """
    if language not in models:
        raise ValueError(f"Langue non supportée: {language}")

    stt_processor = processors[language]
    stt_model = models[language]

    # Whisper expects 16 kHz input; librosa resamples while loading.
    waveform, _ = librosa.load(audio_path, sr=16000)

    features = stt_processor(
        waveform,
        sampling_rate=16000,
        return_tensors="pt",
    ).input_features.to(device)

    # Beam-search decoding; no gradients needed at inference time.
    # NOTE(review): STT_MODELS carries a "language" code that is never passed
    # to generate() — confirm whether forced decoder ids should be set.
    with torch.no_grad():
        token_ids = stt_model.generate(
            features,
            max_length=448,
            num_beams=5,
            temperature=0.0,
        )

    text = stt_processor.batch_decode(token_ids, skip_special_tokens=True)[0]
    return text.strip()
| # ======================== | |
| # ROUTES | |
| # ======================== | |
@app.route("/")
def home():
    """Serve the landing page (templates/index.html)."""
    # NOTE(review): the route decorator was missing in the exported source,
    # leaving the handler unregistered; "/" is the conventional path for the
    # landing page — confirm the intended URL.
    return render_template("index.html")
@app.route("/stt", methods=["POST"])
def stt():
    """Transcription endpoint.

    Expects multipart form data with an ``audio`` file and a ``language``
    field matching one of the loaded model keys. Returns a JSON payload with
    the transcription on success, or an error message with a 4xx/5xx status.
    """
    # NOTE(review): the route decorator was missing in the exported source,
    # leaving the handler unregistered — confirm the intended URL path.
    try:
        # Validate the uploaded file.
        if "audio" not in request.files:
            return jsonify({"error": "Aucun fichier audio fourni"}), 400
        audio_file = request.files["audio"]
        if audio_file.filename == "":
            return jsonify({"error": "Nom de fichier vide"}), 400
        # Validate the requested language against the loaded models.
        language = request.form.get("language", "").lower()
        if language not in models:
            return jsonify({"error": f"Langue non supportée: {language}"}), 400
        # Persist the upload to a temp file so librosa can read it from disk.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
            audio_file.save(temp_audio.name)
            temp_path = temp_audio.name
        print(f"📝 Transcription en cours ({language})...")
        try:
            transcription = transcribe_audio(temp_path, language)
        finally:
            # Bug fix: the original only removed the temp file on the success
            # path, leaking it whenever transcription raised.
            try:
                os.remove(temp_path)
            except OSError:
                pass
        print(f"✅ Transcription: {transcription}")
        return jsonify({
            "success": True,
            "language": language,
            "transcription": transcription
        })
    except Exception as e:
        print(f"❌ Erreur: {str(e)}")
        return jsonify({
            "success": False,
            "error": str(e)
        }), 500
@app.route("/health")
def health():
    """Health check: report the inference device and loaded languages."""
    # NOTE(review): the route decorator was missing in the exported source,
    # leaving the handler unregistered — confirm the intended URL path.
    return jsonify({
        "status": "healthy",
        "device": device,
        "models": list(models.keys())
    })
@app.route("/languages")
def languages():
    """List the languages the UI can offer for transcription."""
    # NOTE(review): the route decorator was missing in the exported source,
    # leaving the handler unregistered — confirm the intended URL path.
    # NOTE(review): this list is hard-coded and can drift from STT_MODELS /
    # the models actually loaded — consider deriving it from `models`.
    return jsonify({
        "languages": [
            {"code": "yoruba", "name": "Yoruba", "flag": "🇳🇬"},
            {"code": "hausa", "name": "Hausa", "flag": "🇳🇬"}
        ]
    })
| # ======================== | |
| # RUN | |
| # ======================== | |
if __name__ == "__main__":
    # Bind to all interfaces; the port comes from the environment (Hugging
    # Face Spaces sets PORT), defaulting to 7860.
    server_port = int(os.getenv("PORT", "7860"))
    app.run(host="0.0.0.0", port=server_port, debug=False)