# avatar-v2 / Back-end / services / stt_service.py
# Initial commit: HOLOKIA-AVATAR for Hugging Face Spaces (commit 69aa271)
import tempfile
from fastapi import FastAPI, UploadFile, File, HTTPException
import logging
import os
import uvicorn
from faster_whisper import WhisperModel
from tempfile import NamedTemporaryFile
from fastapi.middleware.cors import CORSMiddleware
from pydub import AudioSegment
from langdetect import detect, DetectorFactory, LangDetectException
import glob
# Logger configuration: timestamped records, DEBUG level for development.
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger("STT_Service")

# FastAPI application instance.
app = FastAPI()

# CORS: allow only the local frontend (Vite dev-server default origins).
origins = ["http://localhost:5173", "http://127.0.0.1:5173"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Tunable parameters.
MODEL_SIZE = "small"  # Whisper model size: light and fast
MIN_AUDIO_SIZE = 1024  # bytes (~0.03 s at 16 kHz); smaller uploads are skipped
MIN_AUDIO_DURATION = 0.75  # minimum seconds of audio worth transcribing
SAMPLE_RATE = 16000  # Hz, 16-bit mono assumed by the duration estimate
MAX_TEMP_FILES = 1000  # cap on leftover temporary files before cleanup

# Load the Whisper model once at startup; int8 on CPU keeps memory/latency low.
logger.info(f"Chargement du modèle Faster-Whisper ({MODEL_SIZE})")
try:
    model = WhisperModel(MODEL_SIZE, device="cpu", compute_type="int8")
    logger.info("Modèle Whisper chargé avec succès")
except Exception as e:
    logger.error(f"Erreur lors du chargement du modèle Whisper: {e}")
    model = None  # endpoints check for None and answer 503

# Make langdetect deterministic across runs.
DetectorFactory.seed = 0

# Short frequent words that language detectors misclassify; matched exactly
# (lowercased, stripped) to force the language in validate_language().
EN_WORDS = {"hi", "hello", "hey", "ok", "thanks", "bye", "yes", "no"}
AR_WORDS = {"salam", "salaam", "marhaban", "مرحبا", "سلام", "شكرا", "أهلا"}
def validate_language(transcript: str, whisper_lang: str, probability: float) -> str:
    """Validate or correct the language detected by Whisper using langdetect.

    Resolution order: forced code for known short greetings, langdetect's
    guess when Whisper's confidence is low, otherwise Whisper's own result
    (falling back to "fr" when nothing usable is available).
    """
    fallback = whisper_lang or "fr"

    # Nothing to analyse: keep Whisper's answer (or the default).
    if not transcript.strip():
        return fallback

    try:
        independent_guess = detect(transcript)
    except LangDetectException:
        return fallback

    # Exact-match overrides for short words that detectors get wrong.
    normalized = transcript.lower().strip()
    if normalized in EN_WORDS:
        return "en"
    if normalized in AR_WORDS:
        return "ar"

    # Low Whisper confidence (<0.9): prefer langdetect for supported languages.
    if probability < 0.9 and independent_guess in {"fr", "en", "ar"}:
        logger.debug(f"Langue corrigée par langdetect: {whisper_lang} -> {independent_guess}")
        return independent_guess

    # Otherwise trust Whisper.
    return fallback
def clean_temp_files():
    """Delete the oldest temporary .wav files once they exceed MAX_TEMP_FILES."""
    candidates = glob.glob(f"{tempfile.gettempdir()}/*.wav")
    excess = len(candidates) - MAX_TEMP_FILES
    if excess <= 0:
        return
    # Oldest first, so only the newest MAX_TEMP_FILES survive.
    candidates.sort(key=os.path.getmtime)
    for stale in candidates[:excess]:
        try:
            os.unlink(stale)
            logger.info(f"Fichier temporaire supprimé: {stale}")
        except Exception as e:
            # Best effort: a locked/vanished file must not break the request.
            logger.warning(f"Erreur lors de la suppression de {stale}: {e}")
@app.post("/transcribe")
async def transcribe_audio(file: UploadFile = File(...), language: str = None):
    """
    Transcribe an audio file to text with Faster-Whisper.

    - file: audio upload (WAV, MP3, M4A, WebM, etc.)
    - language: optional target language code ("en", "ar", "fr", ...);
      when omitted, the language is auto-detected then validated.

    Returns a dict with "transcript", "lang" and "status"
    ("success" / "empty" / "skipped").

    Raises:
        HTTPException 503 if the Whisper model failed to load.
        HTTPException 400 if the audio cannot be converted to WAV.
        HTTPException 500 on any other transcription failure.
    """
    logger.info(f"Fichier reçu: filename={file.filename}, content_type={file.content_type}")

    # The model is loaded at import time; without it the service cannot work.
    if model is None:
        logger.error("Modèle Whisper non chargé")
        raise HTTPException(status_code=503, detail="Service STT non disponible")

    temp_audio_path = None
    wav_path = None  # initialised up front so the finally block can test it safely
    try:
        # Prune leftover temp files before creating new ones.
        clean_temp_files()

        # Persist the upload to a temp file (filename may be absent on some clients).
        suffix = os.path.splitext(file.filename or "")[1].lower()
        content = await file.read()
        if len(content) < MIN_AUDIO_SIZE:
            logger.debug(f"Audio trop petit ({len(content)} bytes), skip")
            return {"transcript": "", "lang": language or "unknown", "status": "skipped"}

        with NamedTemporaryFile(delete=False, suffix=suffix) as temp_audio:
            temp_audio.write(content)
            temp_audio_path = temp_audio.name

        # Rough duration estimate assuming raw 16 kHz 16-bit mono PCM.
        # NOTE(review): this mis-estimates compressed formats (MP3/WebM);
        # acceptable here because it is only a "too short" gate — confirm.
        duration = os.path.getsize(temp_audio_path) / (SAMPLE_RATE * 2)
        logger.info(f"Processing audio, duration: {duration:.3f}s")
        if duration < MIN_AUDIO_DURATION:
            logger.debug(f"Audio trop court ({duration:.3f}s), skip")
            return {"transcript": "", "lang": language or "unknown", "status": "skipped"}

        # Convert anything that is not already WAV (pydub/ffmpeg handles the rest).
        wav_path = temp_audio_path
        if suffix != ".wav":
            try:
                audio = AudioSegment.from_file(temp_audio_path)
                with NamedTemporaryFile(delete=False, suffix=".wav") as wav_fd:
                    audio.export(wav_fd.name, format="wav")
                    wav_path = wav_fd.name
                logger.info(f"Converti en WAV: {wav_path}")
            except Exception as e:
                logger.error(f"Erreur conversion WAV: {e}")
                raise HTTPException(status_code=400, detail="Erreur conversion audio")

        # Transcription (Whisper auto-detects the language when None is passed).
        segments, info = model.transcribe(wav_path, language=language if language else None)
        transcript = " ".join(s.text.strip() for s in segments if s.text.strip())

        # Language resolution: honour the caller's explicit choice, otherwise
        # cross-check Whisper's guess against langdetect.
        detected_lang = language or validate_language(transcript, info.language, info.language_probability)
        logger.info(f"Langue détectée: {info.language}, Probabilité: {info.language_probability:.2f}, Langue finale: {detected_lang}")
        logger.info(f"Transcript: '{transcript}'")

        return {
            "transcript": transcript,
            "lang": detected_lang,
            "status": "success" if transcript else "empty",
        }
    except HTTPException:
        # Bug fix: let deliberate HTTP errors (e.g. the 400 above) propagate
        # instead of being masked as a generic 500 by the handler below.
        raise
    except Exception as e:
        logger.error(f"Échec transcription: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="Erreur STT")
    finally:
        # Best-effort cleanup of both temporary files.
        try:
            if temp_audio_path and os.path.exists(temp_audio_path):
                os.unlink(temp_audio_path)
            if wav_path and wav_path != temp_audio_path and os.path.exists(wav_path):
                os.unlink(wav_path)
        except Exception as e:
            logger.warning(f"Erreur nettoyage fichiers: {e}")
@app.get("/health")
async def health_check():
    """Liveness probe: report that the STT service is up."""
    status_payload = {"status": "ok", "service": "stt"}
    return status_payload
def run_service():
    """Launch the STT service with uvicorn on all interfaces, port 5001."""
    bind_host = "0.0.0.0"
    bind_port = 5001
    uvicorn.run(app, host=bind_host, port=bind_port)


if __name__ == "__main__":
    run_service()