File size: 4,287 Bytes
3789ef8
 
 
 
 
 
 
 
b0b4663
e1da5ec
b0b4663
3789ef8
 
 
 
b0b4663
3789ef8
 
 
b0b4663
3789ef8
 
 
 
b0b4663
 
3789ef8
b0b4663
3789ef8
284104a
3789ef8
 
 
 
 
 
 
b0b4663
3789ef8
 
 
 
b0b4663
3789ef8
 
 
 
b0b4663
3789ef8
 
 
 
b0b4663
f76a6e9
fc1a510
3789ef8
 
b0b4663
 
fc1a510
284104a
3789ef8
b0b4663
3789ef8
 
 
b0b4663
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc1a510
3789ef8
 
 
b0b4663
3789ef8
 
 
b0b4663
 
3789ef8
 
 
 
 
 
 
b0b4663
284104a
3789ef8
 
 
 
b0b4663
 
 
 
3789ef8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import torch
import librosa
import io
import numpy as np
from contextlib import asynccontextmanager

# --- V1 CONFIGURATION ---
# Hugging Face model id of the French wav2vec2 phonemizer (CTC head).
MODEL_ID = "Cnam-LMSSC/wav2vec2-french-phonemizer"

# Shared state filled in by the lifespan handler at startup:
# "processor", "model", "vocab" (token -> id) and "id2vocab" (id -> token).
# Empty until the model has loaded; /transcribe checks for "model".
ai_context = {}

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the wav2vec2 phonemizer once at startup, release it on shutdown.

    Populates ``ai_context`` with the processor, the model (in eval mode),
    the tokenizer vocab and its inverse. If loading fails, the app still
    starts; /transcribe then answers 500 because "model" is missing.
    """
    print(f"🚀 Chargement V1 (Stable) : {MODEL_ID}...")
    try:
        proc = Wav2Vec2Processor.from_pretrained(MODEL_ID)
        net = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
        net.eval()  # inference only: disable dropout

        vocab_map = proc.tokenizer.get_vocab()
        ai_context["processor"] = proc
        ai_context["model"] = net
        ai_context["vocab"] = vocab_map
        # Inverse lookup (id -> token), used by the per-frame debug trace.
        ai_context["id2vocab"] = {idx: tok for tok, idx in vocab_map.items()}

        print("✅ Modèle V1 prêt.")
    except Exception as e:
        print(f"❌ Erreur critique : {e}")
    yield
    ai_context.clear()

app = FastAPI(lifespan=lifespan)

@app.get("/")
def home():
    """Health-check endpoint: report service status and the model in use."""
    payload = {"status": "API V1 Stable running", "model": MODEL_ID}
    return payload

@app.post("/transcribe")
async def transcribe(
    file: UploadFile = File(...),
    allowed_phones: str = Form(...) 
):
    if "model" not in ai_context:
        raise HTTPException(status_code=500, detail="Modèle non chargé")

    # 1. Lecture Audio (Simple, sans boost ni loop)
    try:
        content = await file.read()
        audio_array, _ = librosa.load(io.BytesIO(content), sr=16000)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Erreur audio: {str(e)}")

    # 2. Préparation Modèle
    processor = ai_context["processor"]
    model = ai_context["model"]
    vocab = ai_context["vocab"]
    id2vocab = ai_context["id2vocab"]
    
    inputs = processor(audio_array, sampling_rate=16000, return_tensors="pt", padding=True)

    # 3. Calcul des Logits
    with torch.no_grad():
        logits = model(inputs.input_values).logits

    # --- DEBUG TRACE (Utile pour voir ce que la V1 entend vraiment) ---
    probs = torch.nn.functional.softmax(logits, dim=-1)
    top3_probs, top3_ids = torch.topk(probs, 3, dim=-1)
    
    debug_trace = []
    time_steps = top3_ids.shape[1]
    
    for t in range(time_steps):
        best_id = top3_ids[0, t, 0].item()
        best_token = id2vocab.get(best_id, "")
        best_prob = top3_probs[0, t, 0].item()

        if best_token not in ["<pad>", "<s>", "</s>", "|"] and best_prob > 0.05:
            frame_info = []
            for k in range(3):
                alt_id = top3_ids[0, t, k].item()
                alt_token = id2vocab.get(alt_id, "UNK")
                alt_prob = top3_probs[0, t, k].item()
                frame_info.append(f"{alt_token} ({int(alt_prob*100)}%)")
            debug_trace.append(f"T{t}: " + " | ".join(frame_info))

    # Capture du Raw avant masque
    raw_ids = torch.argmax(logits, dim=-1)
    raw_ipa = processor.batch_decode(raw_ids, skip_special_tokens=True)[0]


    # --- 4. APPLICATION DU MASQUE BINAIRE (Logique V1) ---
    # La V1 gère les blocs entiers (ex: ɑ̃), pas besoin de décomposition complexe
    
    requested_phones = [p.strip() for p in allowed_phones.split(',') if p.strip()]
    
    if requested_phones:
        # Tokens techniques V1 (Le modèle V1 utilise '|' et d'autres)
        technical_tokens = ["<pad>", "<s>", "</s>", "<unk>", "|", "[PAD]", "[UNK]"]
        full_allowed_set = set(requested_phones + technical_tokens)
        
        # On trouve leurs positions numériques (ID)
        # Note : La V1 est "gentille", si on demande 'ɑ̃', elle a un token pour ça.
        allowed_indices = [vocab[t] for t in full_allowed_set if t in vocab]
        
        if allowed_indices:
            mask = torch.full((logits.shape[-1],), float('-inf'))
            mask[allowed_indices] = 0.0
            logits = logits + mask

    # 5. Décodage final
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
    
    return {
        "ipa": transcription,
        "raw_ipa_before_mask": raw_ipa,
        "debug_trace": debug_trace,
        "allowed_used": allowed_phones,
        "version": "V1_Classic"
    }