from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from optimum.onnxruntime import ORTModelForCTC
from transformers import Wav2Vec2Processor
import torch
import librosa
import numpy as np
import io
from contextlib import asynccontextmanager
# --- CONFIGURATION ---
# The model directory is copied by the Dockerfile to the same level as app.py
ONNX_MODEL_DIR = "model_fr_onnx"
ONNX_FILENAME = "model_quantized.onnx"
# Shared state filled by the lifespan handler at startup:
# keys "processor", "model", "vocab" — cleared again at shutdown.
ai_context = {}
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the French ONNX model once at startup; release it at shutdown.

    On success, populates ``ai_context`` with the Wav2Vec2 processor, the
    ONNX CTC model and the tokenizer vocabulary. On failure the error is
    printed and the app still starts — endpoints must check ``ai_context``.
    """
    print("🚀 Chargement du modèle Français ONNX...")
    try:
        processor = Wav2Vec2Processor.from_pretrained(ONNX_MODEL_DIR)
        model = ORTModelForCTC.from_pretrained(ONNX_MODEL_DIR, file_name=ONNX_FILENAME)
        ai_context.update(
            processor=processor,
            model=model,
            vocab=processor.tokenizer.get_vocab(),
        )
        print("✅ Modèle chargé.")
    except Exception as e:
        print(f"❌ Erreur critique : {e}")
    yield
    # Shutdown: drop references so the model can be garbage-collected.
    ai_context.clear()
# FastAPI application; model load/unload is tied to the lifespan handler.
app = FastAPI(lifespan=lifespan)
@app.get("/")
def home():
    """Health-check endpoint: confirms the API is up and points at /transcribe."""
    payload = {
        "status": "API is running",
        "help": "POST /transcribe to use",
    }
    return payload
@app.post("/transcribe")
async def transcribe(
    file: UploadFile = File(...),
    allowed_phones: str = Form(...)
):
    """Transcribe an uploaded audio file to IPA, restricted to a given phone set.

    Parameters
    ----------
    file : audio upload in any container librosa or soundfile can decode.
    allowed_phones : comma-separated list of phone tokens; only these (plus
        CTC service tokens) may appear in the output.

    Returns ``{"ipa": <transcription>}``.
    Raises HTTP 500 when the model failed to load at startup.
    """
    if "model" not in ai_context:
        raise HTTPException(status_code=500, detail="Model not loaded")

    # --- Audio decoding (target: 16 kHz mono float) ---
    audio_content = await file.read()
    try:
        # librosa resamples to 16 kHz and downmixes to mono in one call.
        speech, _ = librosa.load(io.BytesIO(audio_content), sr=16000)
    except Exception:  # was a bare `except:` — would swallow KeyboardInterrupt
        # Fallback when librosa rejects the container format: decode with
        # soundfile, then downmix and resample manually.
        import soundfile as sf
        speech, sample_rate = sf.read(io.BytesIO(audio_content))
        if speech.ndim > 1:
            speech = speech[:, 0]  # stereo -> mono (keep first channel)
        if sample_rate != 16000:
            speech = librosa.resample(speech, orig_sr=sample_rate, target_sr=16000)

    # --- Inference ---
    processor = ai_context["processor"]
    inputs = processor(speech, sampling_rate=16000, return_tensors="pt", padding=True)
    logits = ai_context["model"](inputs.input_values).logits

    # --- Dynamic masking: keep only the caller's phones + CTC service tokens ---
    user_allowed = [p.strip() for p in allowed_phones.split(',')]
    # NOTE(review): the bare empty strings below look like special tokens
    # (<s>, </s>, <pad>, <unk>?) lost in an encoding pass — confirm against
    # the tokenizer's special-token map before changing them.
    technical_tokens = ["|", "[PAD]", "", "", "", "", "[UNK]"]
    full_allowed = set(user_allowed + technical_tokens)
    vocab = ai_context["vocab"]
    mask = torch.ones(logits.shape[-1], dtype=torch.bool)
    allowed_indices = [vocab[t] for t in full_allowed if t in vocab]
    if allowed_indices:
        # Set every vocab entry outside the allow-list to -inf so argmax
        # can never pick it. Guarded: with no allowed indices, masking
        # would suppress the entire vocabulary.
        mask[allowed_indices] = False
        logits[:, :, mask] = -float('inf')

    # --- Greedy CTC decoding ---
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
    return {"ipa": transcription}