"""Flask web service exposing automatic speech recognition (ASR) for the Fon
language, backed by a fine-tuned Whisper checkpoint from the Hugging Face Hub.

Routes:
    GET  /           — landing page (or a 500 error page if the model failed).
    POST /transcribe — multipart upload under key "audio"; returns JSON.
    GET  /health     — liveness/readiness info for orchestrators.
"""

import os
import tempfile
import warnings

import librosa
import numpy as np  # noqa: F401 — unused here, kept intentionally
import torch
from flask import Flask, jsonify, render_template, request
from transformers import WhisperForConditionalGeneration, WhisperProcessor

warnings.filterwarnings('ignore')

app = Flask(__name__)

# ========================
# CONFIGURATION
# ========================
MODEL_NAME = "Ronaldodev/whisper-fon-v4"   # fine-tuned Fon ASR checkpoint
PROCESSOR_NAME = "openai/whisper-small"    # base processor (feature extractor + tokenizer)
TOKEN = os.getenv("HF_TOKEN")              # Hugging Face token (private-repo access)
SAMPLE_RATE = 16000                        # Whisper expects 16 kHz mono audio
MAX_DECODE_LENGTH = 448                    # Whisper decoder's maximum sequence length

device = "cuda" if torch.cuda.is_available() else "cpu"

print("=" * 60)
print("🎤 Chargement du modèle ASR Fon")
print(f"📦 Modèle: {MODEL_NAME}")
print(f"🖥️ Device: {device}")
print("=" * 60)

# ========================
# LOAD MODEL AT STARTUP
# ========================
try:
    print("📥 Chargement du processor...")
    processor = WhisperProcessor.from_pretrained(PROCESSOR_NAME, token=TOKEN)
    print("📥 Chargement du modèle...")
    model = WhisperForConditionalGeneration.from_pretrained(MODEL_NAME, token=TOKEN)
    model.to(device)
    model.eval()  # inference only — disable dropout etc.
    print("✅ Modèle chargé avec succès")
    print("=" * 60)
    model_loaded = True
except Exception as e:
    # Keep the app importable even when loading fails; routes answer 500.
    print(f"❌ Erreur lors du chargement: {str(e)}")
    print("=" * 60)
    model_loaded = False

# ========================
# TRANSCRIPTION FUNCTIONS
# ========================

def _extract_features(speech) -> "torch.Tensor":
    """Convert a 16 kHz waveform (1-D array) into Whisper input features on `device`."""
    return processor(
        speech, sampling_rate=SAMPLE_RATE, return_tensors="pt"
    ).input_features.to(device)


def _generate_text(features) -> str:
    """Run Whisper generation on prepared input features and decode to text.

    A single code path for short and chunked audio so both use the same
    generation settings (max_length=448, forced "transcribe" task).
    """
    with torch.no_grad():
        ids = model.generate(
            features,
            max_length=MAX_DECODE_LENGTH,
            task="transcribe",
        )[0]
    return processor.decode(ids, skip_special_tokens=True)


def transcribe_short(audio_path: str) -> str:
    """Transcribe an audio file shorter than ~30 seconds in a single pass."""
    speech, _ = librosa.load(audio_path, sr=SAMPLE_RATE)
    return _generate_text(_extract_features(speech))


def transcribe_long(audio_path: str, chunk_seconds: int = 30, overlap_seconds: int = 5) -> str:
    """Transcribe a long audio file by splitting it into overlapping chunks.

    Args:
        audio_path: Path to the audio file.
        chunk_seconds: Length of each chunk (should be <= 30, Whisper's window).
        overlap_seconds: Overlap between consecutive chunks so words are not
            cut at chunk boundaries. Must be strictly less than chunk_seconds.

    Returns:
        The concatenated transcription of all chunks.

    Raises:
        ValueError: If overlap_seconds >= chunk_seconds (the loop could not advance).
    """
    if overlap_seconds >= chunk_seconds:
        raise ValueError("overlap_seconds must be smaller than chunk_seconds")

    speech, sr = librosa.load(audio_path, sr=SAMPLE_RATE)
    chunk_size = chunk_seconds * sr
    step = chunk_size - overlap_seconds * sr  # positive thanks to the guard above

    parts = []
    start = 0
    while start < len(speech):
        end = min(start + chunk_size, len(speech))
        text = _generate_text(_extract_features(speech[start:end]))
        parts.append(text)
        print(f"✅ Chunk {len(parts)} transcrit: {text[:50]}...")
        if end == len(speech):
            # BUGFIX: stop once the final (capped) chunk has been processed.
            # The original kept stepping, which could re-transcribe the tail
            # and duplicate text when the audio ended mid-chunk.
            break
        start += step

    return " ".join(parts).strip()


def get_audio_duration(audio_path: str) -> float:
    """Return the duration of the audio file in seconds (resampled to 16 kHz)."""
    y, sr = librosa.load(audio_path, sr=SAMPLE_RATE)
    return len(y) / sr

# ========================
# ROUTES
# ========================

@app.route("/")
def home():
    """Landing page; returns a 500 error page when the model failed to load."""
    if not model_loaded:
        return """
Le modèle n'a pas pu être chargé. Vérifiez les logs.
""", 500
    return render_template("index.html")


@app.route("/transcribe", methods=["POST"])
def transcribe():
    """Transcription endpoint.

    Expects a multipart form upload under the key "audio". Responds with JSON:
    {"success": True, "transcription": str, "duration": float, "language": "fon"}
    on success, or {"success": False, "error": str} with a 4xx/5xx status.
    """
    print("=" * 60)
    print("🎤 Requête de transcription reçue")
    print("=" * 60)

    if not model_loaded:
        return jsonify({
            "success": False,
            "error": "Le modèle n'est pas chargé"
        }), 500

    # Validate the upload before touching the filesystem.
    if "audio" not in request.files:
        return jsonify({
            "success": False,
            "error": "Aucun fichier audio fourni"
        }), 400

    audio_file = request.files["audio"]
    if audio_file.filename == "":
        return jsonify({
            "success": False,
            "error": "Nom de fichier vide"
        }), 400

    temp_path = None
    try:
        # Persist the upload to a temp file so librosa can read it by path.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
            audio_file.save(temp_audio.name)
            temp_path = temp_audio.name
        print(f"📁 Fichier temporaire: {temp_path}")

        duration = get_audio_duration(temp_path)
        print(f"⏱️ Durée: {duration:.2f}s")

        # Whisper's native window is 30 s: single pass when the audio fits,
        # overlapping chunks otherwise.
        if duration <= 30:
            print("📝 Transcription courte...")
            transcription = transcribe_short(temp_path)
        else:
            print("📝 Transcription longue (avec chunks)...")
            transcription = transcribe_long(temp_path)

        print(f"✅ Transcription: {transcription}")
        print("=" * 60)
        return jsonify({
            "success": True,
            "transcription": transcription,
            "duration": round(duration, 2),
            "language": "fon"
        })
    except Exception as e:
        print(f"❌ Erreur: {str(e)}")
        print("=" * 60)
        return jsonify({
            "success": False,
            "error": str(e)
        }), 500
    finally:
        # BUGFIX: cleanup now runs on every path (success and failure); the
        # original removed the file only on success or via a bare `except:`.
        if temp_path is not None:
            try:
                os.remove(temp_path)
            except OSError:
                pass


@app.route("/health")
def health():
    """Health check: reports model/device status for load balancers and probes."""
    return jsonify({
        "status": "healthy" if model_loaded else "unhealthy",
        "model": MODEL_NAME,
        "device": device,
        "model_loaded": model_loaded,
        "language": "fon"
    })

# ========================
# RUN
# ========================
if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))
    app.run(host="0.0.0.0", port=port, debug=False)