# yo-ha-asr / app.py
# Flask speech-to-text (Whisper) service for Yoruba and Hausa.
# (Header reconstructed from Hugging Face Space page residue; original
# upload: Ronaldodev, commit 5839ad6 "Create app.py", verified.)
import os
import tempfile
import torch
from flask import Flask, request, jsonify, render_template, send_file
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import librosa
import warnings
warnings.filterwarnings('ignore')
app = Flask(__name__)
# ========================
# CONFIGURATION
# ========================
# One fine-tuned Whisper checkpoint per supported language.
# "language" is the ISO 639-1 code for each language; NOTE(review): it is
# defined here but not read anywhere in the visible code — confirm intent.
STT_MODELS = {
    "yoruba": {
        "model_id": "ajibs75/whisper-small-yoruba",
        "language": "yo"
    },
    "hausa": {
        "model_id": "NCAIR1/Hausa-ASR",
        "language": "ha"
    }
}
# Prefer GPU when available; falls back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🖥️ Device: {device}")
# Populated at startup: language key -> loaded model / processor.
models = {}
processors = {}
# ========================
# LOAD MODELS AT STARTUP
# ========================
print("=" * 60)
print("🎤 Chargement des modèles STT...")
print("=" * 60)
# Load one (processor, model) pair per configured language. A failure for
# one language is logged and does not prevent the others from loading.
for language_name, config in STT_MODELS.items():
    print(f"📥 Chargement du modèle {language_name}...")
    try:
        stt_processor = WhisperProcessor.from_pretrained(config["model_id"])
        stt_model = WhisperForConditionalGeneration.from_pretrained(config["model_id"]).to(device)
        stt_model.eval()
        processors[language_name] = stt_processor
        models[language_name] = stt_model
        print(f"✅ {language_name.capitalize()} prêt")
    except Exception as load_error:
        print(f"❌ Erreur {language_name}: {load_error}")
print("=" * 60)
print("✅ Tous les modèles sont chargés")
print("=" * 60)
# ========================
# UTILITY FUNCTIONS
# ========================
def transcribe_audio(audio_path: str, language: str) -> str:
    """Transcribe an audio file with the model loaded for *language*.

    Args:
        audio_path: Path to an audio file readable by librosa.
        language: Key into the loaded ``models``/``processors`` registries
            (e.g. ``"yoruba"`` or ``"hausa"``).

    Returns:
        The transcription, stripped of surrounding whitespace.

    Raises:
        ValueError: If no model is loaded for ``language``.
    """
    if language not in models:
        raise ValueError(f"Langue non supportée: {language}")
    processor = processors[language]
    model = models[language]
    # Load the audio with librosa, resampled to Whisper's expected 16 kHz.
    audio, _sr = librosa.load(audio_path, sr=16000)
    # Convert the waveform to log-mel input features on the model's device.
    input_features = processor(
        audio,
        sampling_rate=16000,
        return_tensors="pt"
    ).input_features.to(device)
    # Force decoding in the configured language instead of relying on
    # Whisper's language auto-detection. (The "language" field in STT_MODELS
    # was previously configured but never used.)
    lang_code = STT_MODELS.get(language, {}).get("language")
    forced_decoder_ids = (
        processor.get_decoder_prompt_ids(language=lang_code, task="transcribe")
        if lang_code
        else None
    )
    # Beam search is deterministic, so the previous temperature=0.0 argument
    # was a no-op (and triggers warnings in recent transformers); dropped.
    with torch.no_grad():
        predicted_ids = model.generate(
            input_features,
            max_length=448,
            num_beams=5,
            forced_decoder_ids=forced_decoder_ids
        )
    # Decode token ids back to text, skipping special tokens.
    transcription = processor.batch_decode(
        predicted_ids,
        skip_special_tokens=True
    )[0]
    return transcription.strip()
# ========================
# ROUTES
# ========================
@app.route("/")
def home():
    """Render the application's landing page."""
    template_name = "index.html"
    return render_template(template_name)
@app.route("/stt", methods=["POST"])
def stt():
    """Transcription endpoint.

    Expects a multipart/form-data POST with:
      * ``audio``: the audio file to transcribe
      * ``language``: one of the loaded model keys ("yoruba", "hausa")

    Returns JSON with the transcription on success (200), a validation
    error (400), or an internal error payload (500).
    """
    try:
        # Validate the uploaded file.
        if "audio" not in request.files:
            return jsonify({"error": "Aucun fichier audio fourni"}), 400
        audio_file = request.files["audio"]
        if audio_file.filename == "":
            return jsonify({"error": "Nom de fichier vide"}), 400
        # Validate the requested language against the loaded models.
        language = request.form.get("language", "").lower()
        if language not in models:
            return jsonify({"error": f"Langue non supportée: {language}"}), 400
        # Persist the upload to a temporary file so librosa can read it.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
            audio_file.save(temp_audio.name)
            temp_path = temp_audio.name
        print(f"📝 Transcription en cours ({language})...")
        try:
            transcription = transcribe_audio(temp_path, language)
        finally:
            # Bug fix: the temp file used to leak when transcription raised
            # (it was only removed on the success path). Always clean up.
            try:
                os.remove(temp_path)
            except OSError:
                pass
        print(f"✅ Transcription: {transcription}")
        return jsonify({
            "success": True,
            "language": language,
            "transcription": transcription
        })
    except Exception as e:
        print(f"❌ Erreur: {str(e)}")
        return jsonify({
            "success": False,
            "error": str(e)
        }), 500
@app.route("/health")
def health():
    """Health check: report compute device and the loaded model keys."""
    payload = {
        "status": "healthy",
        "device": device,
        "models": list(models)
    }
    return jsonify(payload)
@app.route("/languages")
def languages():
    """List the languages this service can transcribe."""
    supported = [
        {"code": "yoruba", "name": "Yoruba", "flag": "🇳🇬"},
        {"code": "hausa", "name": "Hausa", "flag": "🇳🇬"}
    ]
    return jsonify({"languages": supported})
# ========================
# RUN
# ========================
if __name__ == "__main__":
    # Bind on all interfaces; the port comes from the environment
    # (e.g. Hugging Face Spaces sets PORT), defaulting to 7860.
    port = int(os.getenv("PORT", 7860))
    app.run(host="0.0.0.0", port=port, debug=False)