# fon-essai / app.py
# Author: Ronaldodev — commit 8983ab0 ("Update app.py", verified)
import os
import tempfile
import torch
import librosa
import numpy as np
from flask import Flask, request, jsonify, render_template
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import warnings
warnings.filterwarnings('ignore')
# Flask application instance served by this module.
app = Flask(__name__)
# ========================
# CONFIGURATION
# ========================
# Fine-tuned Whisper checkpoint for Fon ASR (ships model weights only).
MODEL_NAME = "Ronaldodev/whisper-fon-v4"
# Base Whisper checkpoint providing the feature extractor / tokenizer.
PROCESSOR_NAME = "openai/whisper-small"
# Hugging Face access token (required if the model repo is private/gated).
TOKEN = os.getenv("HF_TOKEN")
# Prefer GPU when available; fall back to CPU otherwise.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("=" * 60)
print(f"🎤 Chargement du modèle ASR Fon")
print(f"📦 Modèle: {MODEL_NAME}")
print(f"🖥️ Device: {device}")
print("=" * 60)
# ========================
# LOAD MODEL AT STARTUP
# ========================
# The model is loaded once at import time so every request reuses it.
# `model_loaded` is consulted by the routes to fail gracefully when
# loading did not succeed (e.g. missing token, network error).
try:
    print("📥 Chargement du processor...")
    # Processor comes from the base checkpoint; the fine-tuned repo
    # only contains the model weights.
    processor = WhisperProcessor.from_pretrained(PROCESSOR_NAME, token=TOKEN)
    print("📥 Chargement du modèle...")
    model = WhisperForConditionalGeneration.from_pretrained(MODEL_NAME, token=TOKEN)
    model.to(device)
    model.eval()  # inference mode: disables dropout etc.
    print("✅ Modèle chargé avec succès")
    print("=" * 60)
    model_loaded = True
except Exception as e:  # broad by design: any startup failure puts the app in degraded mode
    print(f"❌ Erreur lors du chargement: {str(e)}")
    print("=" * 60)
    model_loaded = False
# ========================
# TRANSCRIPTION FUNCTIONS
# ========================
def transcribe_short(audio_path: str) -> str:
    """Transcribe a short audio file (<= 30 seconds) in a single pass.

    Args:
        audio_path: Path to an audio file readable by librosa.

    Returns:
        The decoded transcription as a plain string.
    """
    # Whisper expects 16 kHz mono input; librosa resamples on load.
    speech, _ = librosa.load(audio_path, sr=16000)
    inputs = processor(
        speech,
        sampling_rate=16000,
        return_tensors="pt",
    ).input_features.to(device)
    with torch.no_grad():
        ids = model.generate(
            inputs,
            # 448 is Whisper's decoder length limit; the previous value of 300
            # could truncate dense 30 s utterances and was inconsistent with
            # transcribe_long(), which already used 448.
            max_length=448,
            task="transcribe",
        )[0]
    return processor.decode(ids, skip_special_tokens=True)
def transcribe_long(audio_path: str, chunk_seconds: int = 30, overlap_seconds: int = 5) -> str:
    """Transcribe a long audio file by splitting it into overlapping chunks.

    Args:
        audio_path: Path to an audio file readable by librosa.
        chunk_seconds: Length of each chunk (Whisper handles at most 30 s).
        overlap_seconds: Overlap between consecutive chunks; must be smaller
            than chunk_seconds or the loop could never terminate.

    Returns:
        The concatenated transcription of all chunks.

    Raises:
        ValueError: If overlap_seconds >= chunk_seconds.
    """
    # Guard: with overlap >= chunk, `start` would never advance (infinite loop).
    if overlap_seconds >= chunk_seconds:
        raise ValueError("overlap_seconds must be smaller than chunk_seconds")
    speech, sr = librosa.load(audio_path, sr=16000)
    chunk_size = chunk_seconds * sr
    overlap = overlap_seconds * sr
    start = 0
    pieces = []  # join once at the end instead of quadratic string +=
    chunk_count = 0
    while start < len(speech):
        end = min(start + chunk_size, len(speech))
        chunk = speech[start:end]
        inputs = processor(
            chunk,
            sampling_rate=16000,
            return_tensors="pt",
        ).input_features.to(device)
        with torch.no_grad():
            ids = model.generate(inputs, max_length=448)[0]
        text = processor.decode(ids, skip_special_tokens=True)
        pieces.append(text)
        chunk_count += 1
        print(f"✅ Chunk {chunk_count} transcrit: {text[:50]}...")
        # Bug fix: stop once the end of the signal has been reached. The old
        # code kept advancing and re-transcribed a short tail that this chunk
        # had already covered, duplicating text in the output.
        if end == len(speech):
            break
        start += chunk_size - overlap
    return " ".join(pieces).strip()
def get_audio_duration(audio_path: str) -> float:
    """Return the duration of an audio file in seconds.

    Loads at the file's native sampling rate (sr=None) instead of forcing
    16 kHz: duration is invariant under resampling, so the previous resample
    pass over the whole signal was wasted work.
    """
    y, sr = librosa.load(audio_path, sr=None)
    return librosa.get_duration(y=y, sr=sr)
# ========================
# ROUTES
# ========================
@app.route("/")
def home():
    """Landing page: serve the UI template, or a static error page when the
    model failed to load at startup (see the module-level try/except)."""
    if not model_loaded:
        # Inline HTML fallback with HTTP 500 so probes see the failure.
        return """
    <html>
    <body style="font-family: Arial; text-align: center; padding: 50px;">
        <h1>❌ Erreur</h1>
        <p>Le modèle n'a pas pu être chargé. Vérifiez les logs.</p>
    </body>
    </html>
    """, 500
    return render_template("index.html")
@app.route("/transcribe", methods=["POST"])
def transcribe():
    """POST endpoint: transcribe an uploaded `audio` file to Fon text.

    Expects a multipart form with an `audio` file field. Returns JSON
    {success, transcription, duration, language} on success, or
    {success: False, error} with a 400/500 status on failure.
    """
    print("=" * 60)
    print("🎤 Requête de transcription reçue")
    print("=" * 60)
    if not model_loaded:
        return jsonify({
            "success": False,
            "error": "Le modèle n'est pas chargé"
        }), 500
    # Validate the upload before creating any temporary file.
    if "audio" not in request.files:
        return jsonify({
            "success": False,
            "error": "Aucun fichier audio fourni"
        }), 400
    audio_file = request.files["audio"]
    if audio_file.filename == "":
        return jsonify({
            "success": False,
            "error": "Nom de fichier vide"
        }), 400
    temp_path = None
    try:
        # Persist the upload to disk so librosa can read it by path.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
            audio_file.save(temp_audio.name)
            temp_path = temp_audio.name
        print(f"📁 Fichier temporaire: {temp_path}")
        duration = get_audio_duration(temp_path)
        print(f"⏱️ Durée: {duration:.2f}s")
        # Whisper handles at most 30 s in one pass; longer audio is chunked.
        if duration <= 30:
            print("📝 Transcription courte...")
            transcription = transcribe_short(temp_path)
        else:
            print("📝 Transcription longue (avec chunks)...")
            transcription = transcribe_long(temp_path)
        print(f"✅ Transcription: {transcription}")
        print("=" * 60)
        return jsonify({
            "success": True,
            "transcription": transcription,
            "duration": round(duration, 2),
            "language": "fon"
        })
    except Exception as e:
        print(f"❌ Erreur: {str(e)}")
        print("=" * 60)
        return jsonify({
            "success": False,
            "error": str(e)
        }), 500
    finally:
        # Bug fix: cleanup previously relied on a `locals()` check inside the
        # except branch plus a bare `except: pass`; `finally` guarantees the
        # temp file is removed on both success and failure paths, and the
        # swallowed exception is narrowed to OSError.
        if temp_path is not None:
            try:
                os.remove(temp_path)
            except OSError:
                pass
@app.route("/health")
def health():
    """Lightweight health-check endpoint reporting model and device status."""
    status = "healthy" if model_loaded else "unhealthy"
    payload = {
        "status": status,
        "model": MODEL_NAME,
        "device": device,
        "model_loaded": model_loaded,
        "language": "fon"
    }
    return jsonify(payload)
# ========================
# RUN
# ========================
if __name__ == "__main__":
    # Hugging Face Spaces injects PORT; default to 7860 for local runs.
    listen_port = int(os.environ.get("PORT", 7860))
    # Bind on all interfaces so the container's mapped port is reachable.
    app.run(host="0.0.0.0", port=listen_port, debug=False)