File size: 2,734 Bytes
7eac826
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from optimum.onnxruntime import ORTModelForCTC
from transformers import Wav2Vec2Processor
import torch
import librosa
import numpy as np
import io
from contextlib import asynccontextmanager

# --- CONFIGURATION ---
# The folder is copied by the Dockerfile to the same level as app.py.
ONNX_MODEL_DIR = "model_fr_onnx"
ONNX_FILENAME = "model_quantized.onnx"

# Module-level store for the processor/model/vocab loaded in `lifespan`;
# /transcribe checks for the "model" key before doing any work.
ai_context = {}

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the French ONNX ASR model once at startup; free it on shutdown.

    Populates `ai_context` with "processor", "model" and "vocab". A load
    failure is logged rather than raised so the app still starts; the
    /transcribe endpoint guards on the missing "model" key.
    """
    print("🚀 Chargement du modèle Français ONNX...")
    ctx = ai_context
    try:
        ctx["processor"] = Wav2Vec2Processor.from_pretrained(ONNX_MODEL_DIR)
        ctx["model"] = ORTModelForCTC.from_pretrained(
            ONNX_MODEL_DIR, file_name=ONNX_FILENAME
        )
        ctx["vocab"] = ctx["processor"].tokenizer.get_vocab()
        print("✅ Modèle chargé.")
    except Exception as exc:
        # Log and continue: the endpoint reports "Model not loaded" itself.
        print(f"❌ Erreur critique : {exc}")
    yield
    ai_context.clear()

app = FastAPI(lifespan=lifespan)

@app.get("/")
def home():
    """Health-check endpoint pointing callers at the main route."""
    payload = {
        "status": "API is running",
        "help": "POST /transcribe to use",
    }
    return payload

def _decode_audio_16k(audio_content: bytes) -> np.ndarray:
    """Decode raw audio bytes into a mono 16 kHz waveform.

    Tries librosa first; falls back to soundfile for containers librosa's
    loader rejects, downmixing to the first channel and resampling to
    16 kHz when needed. Raises on undecodable input.
    """
    try:
        speech, _sr = librosa.load(io.BytesIO(audio_content), sr=16000)
    except Exception:
        # Narrowed from a bare `except:` — a bare except also swallows
        # KeyboardInterrupt/SystemExit.
        import soundfile as sf
        speech, sample_rate = sf.read(io.BytesIO(audio_content))
        if len(speech.shape) > 1:
            speech = speech[:, 0]  # stereo -> mono (keep first channel)
        if sample_rate != 16000:
            speech = librosa.resample(speech, orig_sr=sample_rate, target_sr=16000)
    return speech


@app.post("/transcribe")
async def transcribe(
    file: UploadFile = File(...),
    allowed_phones: str = Form(...)
):
    """Transcribe an uploaded audio file to IPA, restricted to a phone set.

    Args:
        file: uploaded audio in any format librosa/soundfile can decode.
        allowed_phones: comma-separated phoneme tokens permitted in the
            output; everything else in the vocabulary is masked out.

    Returns:
        ``{"ipa": <transcription>}`` — greedy CTC decode over the masked logits.

    Raises:
        HTTPException: 500 if the model failed to load at startup,
            400 if the audio cannot be decoded.
    """
    if "model" not in ai_context:
        raise HTTPException(status_code=500, detail="Model not loaded")

    # Audio decoding — fail with a client error, not an opaque 500.
    audio_content = await file.read()
    try:
        speech = _decode_audio_16k(audio_content)
    except Exception as exc:
        raise HTTPException(status_code=400, detail=f"Unreadable audio: {exc}")

    # Inference
    processor = ai_context["processor"]
    inputs = processor(speech, sampling_rate=16000, return_tensors="pt", padding=True)

    logits = ai_context["model"](inputs.input_values).logits

    # Dynamic masking: allow only the caller's phones plus the technical
    # tokens the CTC decoder needs. Blank/whitespace entries are dropped.
    user_allowed = [p.strip() for p in allowed_phones.split(',') if p.strip()]
    technical_tokens = ["|", "[PAD]", "<s>", "</s>", "<pad>", "<unk>", "[UNK]"]
    full_allowed = set(user_allowed) | set(technical_tokens)

    vocab = ai_context["vocab"]
    mask = torch.ones(logits.shape[-1], dtype=torch.bool)

    allowed_indices = [vocab[t] for t in full_allowed if t in vocab]
    if allowed_indices:
        mask[allowed_indices] = False

    # Disallowed vocabulary entries can never win the argmax.
    logits[:, :, mask] = -float('inf')

    # Greedy CTC decoding.
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

    return {"ipa": transcription}