"""FastAPI service exposing a French phoneme (IPA) transcription endpoint.

Wraps a quantized Wav2Vec2 CTC model exported to ONNX. The model is loaded
once at startup (via the lifespan handler) and shared through the
module-level ``ai_context`` dict.
"""

from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from optimum.onnxruntime import ORTModelForCTC
from transformers import Wav2Vec2Processor
import torch
import librosa
import numpy as np
import io
from contextlib import asynccontextmanager

# --- CONFIGURATION ---
# The directory is copied by the Dockerfile next to app.py.
ONNX_MODEL_DIR = "model_fr_onnx"
ONNX_FILENAME = "model_quantized.onnx"
# Wav2Vec2 expects 16 kHz mono audio.
TARGET_SAMPLE_RATE = 16000

# Shared state populated at startup: "processor", "model", "vocab".
ai_context = {}


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the ONNX model and processor at startup; release them at shutdown."""
    print("🚀 Chargement du modèle Français ONNX...")
    try:
        ai_context["processor"] = Wav2Vec2Processor.from_pretrained(ONNX_MODEL_DIR)
        ai_context["model"] = ORTModelForCTC.from_pretrained(
            ONNX_MODEL_DIR, file_name=ONNX_FILENAME
        )
        ai_context["vocab"] = ai_context["processor"].tokenizer.get_vocab()
        print("✅ Modèle chargé.")
    except Exception as e:
        # Keep the app serving so "/" still answers; /transcribe returns 500.
        print(f"❌ Erreur critique : {e}")
    yield
    ai_context.clear()


app = FastAPI(lifespan=lifespan)


@app.get("/")
def home():
    """Liveness probe with a usage hint."""
    return {"status": "API is running", "help": "POST /transcribe to use"}


def _load_audio(audio_content: bytes):
    """Decode raw audio bytes to a mono waveform at TARGET_SAMPLE_RATE.

    Tries librosa first (it resamples and downmixes itself); falls back to
    soundfile for formats librosa rejects, then normalizes channel count
    and sample rate manually.
    """
    try:
        # librosa with sr=... already returns mono audio at the target rate.
        speech, _ = librosa.load(io.BytesIO(audio_content), sr=TARGET_SAMPLE_RATE)
    except Exception:
        # Fallback when librosa cannot read the container directly.
        import soundfile as sf

        speech, sample_rate = sf.read(io.BytesIO(audio_content))
        if len(speech.shape) > 1:
            speech = speech[:, 0]  # stereo -> mono (keep first channel)
        if sample_rate != TARGET_SAMPLE_RATE:
            speech = librosa.resample(
                speech, orig_sr=sample_rate, target_sr=TARGET_SAMPLE_RATE
            )
    return speech


def _restrict_logits(logits, allowed_phones: str):
    """Mask (in place) every vocab entry not in the caller's allowed set.

    ``allowed_phones`` is a comma-separated list of phone tokens; CTC
    technical tokens (word separator, padding, unknown) are always kept.
    Returns the masked logits tensor.
    """
    user_allowed = [p.strip() for p in allowed_phones.split(",")]
    # NOTE(review): the run of empty strings below looks like special tokens
    # (e.g. <s>, </s>) whose markup was stripped at some point — confirm
    # against the tokenizer vocab. Kept as-is to preserve behavior; the
    # set() collapses them and "" is not expected in the vocab anyway.
    technical_tokens = ["|", "[PAD]", "", "", "", "", "[UNK]"]
    full_allowed = set(user_allowed + technical_tokens)

    vocab = ai_context["vocab"]
    # Boolean mask over the vocab axis: True = forbidden token.
    mask = torch.ones(logits.shape[-1], dtype=torch.bool)
    allowed_indices = [vocab[t] for t in full_allowed if t in vocab]
    if allowed_indices:
        mask[allowed_indices] = False
    logits[:, :, mask] = -float("inf")
    return logits


@app.post("/transcribe")
async def transcribe(
    file: UploadFile = File(...),
    allowed_phones: str = Form(...),
):
    """Transcribe an uploaded audio file to IPA, restricted to allowed phones.

    Args:
        file: audio upload in any format librosa/soundfile can decode.
        allowed_phones: comma-separated phone tokens to keep in the output.

    Raises:
        HTTPException(500): if the model failed to load at startup.
    """
    if "model" not in ai_context:
        raise HTTPException(status_code=500, detail="Model not loaded")

    audio_content = await file.read()
    speech = _load_audio(audio_content)

    # Inference
    processor = ai_context["processor"]
    inputs = processor(
        speech, sampling_rate=TARGET_SAMPLE_RATE, return_tensors="pt", padding=True
    )
    logits = ai_context["model"](inputs.input_values).logits

    # Dynamic masking to the caller-supplied phone inventory.
    logits = _restrict_logits(logits, allowed_phones)

    # Greedy CTC decoding over the restricted vocabulary.
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
    return {"ipa": transcription}