# proz's picture
# Upload 11 files
# 7eac826 verified
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from optimum.onnxruntime import ORTModelForCTC
from transformers import Wav2Vec2Processor
import torch
import librosa
import numpy as np
import io
from contextlib import asynccontextmanager
# --- CONFIGURATION ---
# Le dossier sera copié par le Dockerfile au même niveau que app.py
ONNX_MODEL_DIR = "model_fr_onnx"
ONNX_FILENAME = "model_quantized.onnx"
ai_context = {}
@asynccontextmanager
async def lifespan(app: FastAPI):
print("🚀 Chargement du modèle Français ONNX...")
try:
ai_context["processor"] = Wav2Vec2Processor.from_pretrained(ONNX_MODEL_DIR)
ai_context["model"] = ORTModelForCTC.from_pretrained(ONNX_MODEL_DIR, file_name=ONNX_FILENAME)
ai_context["vocab"] = ai_context["processor"].tokenizer.get_vocab()
print("✅ Modèle chargé.")
except Exception as e:
print(f"❌ Erreur critique : {e}")
yield
ai_context.clear()
app = FastAPI(lifespan=lifespan)
@app.get("/")
def home():
return {"status": "API is running", "help": "POST /transcribe to use"}
@app.post("/transcribe")
async def transcribe(
file: UploadFile = File(...),
allowed_phones: str = Form(...)
):
if "model" not in ai_context:
raise HTTPException(status_code=500, detail="Model not loaded")
# Lecture Audio
audio_content = await file.read()
try:
speech, _ = librosa.load(io.BytesIO(audio_content), sr=16000)
except:
# Fallback si librosa n'aime pas le format direct, on peut essayer soundfile
import soundfile as sf
speech, _ = sf.read(io.BytesIO(audio_content))
if len(speech.shape) > 1: speech = speech[:, 0] # Stereo to mono
if _ != 16000: speech = librosa.resample(speech, orig_sr=_, target_sr=16000)
# Inférence
processor = ai_context["processor"]
inputs = processor(speech, sampling_rate=16000, return_tensors="pt", padding=True)
logits = ai_context["model"](inputs.input_values).logits
# Masquage Dynamique
user_allowed = [p.strip() for p in allowed_phones.split(',')]
technical_tokens = ["|", "[PAD]", "<s>", "</s>", "<pad>", "<unk>", "[UNK]"]
full_allowed = set(user_allowed + technical_tokens)
vocab = ai_context["vocab"]
mask = torch.ones(logits.shape[-1], dtype=torch.bool)
allowed_indices = [vocab[t] for t in full_allowed if t in vocab]
if allowed_indices:
mask[allowed_indices] = False
logits[:, :, mask] = -float('inf')
# Décodage
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
return {"ipa": transcription}