from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import torch
import librosa
import io
import numpy as np
from contextlib import asynccontextmanager
# --- CONFIGURATION V1 ---
# Hugging Face model id: French phonemizer (CTC head on wav2vec2).
MODEL_ID = "Cnam-LMSSC/wav2vec2-french-phonemizer"
# Shared state populated by the lifespan handler at startup:
# keys "processor", "model", "vocab" (token -> id), "id2vocab" (id -> token).
ai_context = {}
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: load the model once at startup, free it at shutdown.

    Populates the module-level ``ai_context`` dict with the processor, the
    eval-mode model, the tokenizer vocabulary and its inverse mapping.
    A load failure is logged but not raised, so the API still starts; the
    /transcribe endpoint then answers 500 because "model" is missing.
    """
    print(f"🚀 Chargement V1 (Stable) : {MODEL_ID}...")
    try:
        loaded_processor = Wav2Vec2Processor.from_pretrained(MODEL_ID)
        loaded_model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
        loaded_model.eval()
        vocab_map = loaded_processor.tokenizer.get_vocab()
        ai_context.update(
            processor=loaded_processor,
            model=loaded_model,
            vocab=vocab_map,
            # Inverse mapping (id -> token), used to render the debug trace.
            id2vocab={token_id: token for token, token_id in vocab_map.items()},
        )
        print("✅ Modèle V1 prêt.")
    except Exception as e:
        print(f"❌ Erreur critique : {e}")
    yield
    # Shutdown: drop every reference so the model can be garbage-collected.
    ai_context.clear()
app = FastAPI(lifespan=lifespan)


@app.get("/")
def home():
    """Health-check endpoint: reports service status and the configured model id."""
    payload = {"status": "API V1 Stable running", "model": MODEL_ID}
    return payload
def _build_debug_trace(top3_probs, top3_ids, id2vocab):
    """Build a per-frame top-3 summary like ``T12: a (80%) | ɑ̃ (15%) | e (3%)``.

    Frames whose best token is a blank/separator, or whose best probability
    is <= 0.05, are skipped to keep the trace readable.
    Tensors are assumed shaped (1, time, 3) — batch size 1, as produced by
    the single-file /transcribe endpoint.
    """
    debug_trace = []
    time_steps = top3_ids.shape[1]
    for t in range(time_steps):
        best_id = top3_ids[0, t, 0].item()
        best_token = id2vocab.get(best_id, "")
        best_prob = top3_probs[0, t, 0].item()
        if best_token not in ["<pad>", "<s>", "</s>", "|"] and best_prob > 0.05:
            frame_info = []
            for k in range(3):
                alt_id = top3_ids[0, t, k].item()
                alt_token = id2vocab.get(alt_id, "UNK")
                alt_prob = top3_probs[0, t, k].item()
                frame_info.append(f"{alt_token} ({int(alt_prob*100)}%)")
            debug_trace.append(f"T{t}: " + " | ".join(frame_info))
    return debug_trace


def _apply_phone_mask(logits, allowed_phones, vocab):
    """Return logits with every token outside the allowed set forced to -inf.

    ``allowed_phones`` is a comma-separated list of phoneme tokens. CTC
    technical tokens are always kept so decoding stays well-formed. Logits
    are returned unchanged when the request is empty or when none of the
    requested tokens exists in the vocabulary (best-effort, like V1).
    """
    requested_phones = [p.strip() for p in allowed_phones.split(',') if p.strip()]
    if not requested_phones:
        return logits
    # Technical tokens the V1 model/tokenizer relies on ('|' is the word separator).
    technical_tokens = ["<pad>", "<s>", "</s>", "<unk>", "|", "[PAD]", "[UNK]"]
    full_allowed_set = set(requested_phones + technical_tokens)
    # Map tokens to vocabulary ids; unknown tokens are silently dropped.
    # Note: V1 has whole-block tokens (e.g. 'ɑ̃'), no decomposition needed.
    allowed_indices = [vocab[t] for t in full_allowed_set if t in vocab]
    if not allowed_indices:
        return logits
    mask = torch.full((logits.shape[-1],), float('-inf'))
    mask[allowed_indices] = 0.0
    return logits + mask


@app.post("/transcribe")
async def transcribe(
    file: UploadFile = File(...),
    allowed_phones: str = Form(...)
):
    """Transcribe an uploaded audio file to IPA phonemes.

    Args:
        file: audio upload, decoded with librosa and resampled to 16 kHz mono.
        allowed_phones: comma-separated phoneme tokens; when non-empty, the
            prediction is restricted to these tokens (plus technical tokens).

    Returns:
        dict with the masked transcription (``ipa``), the unmasked one
        (``raw_ipa_before_mask``), a per-frame top-3 ``debug_trace``,
        the echoed ``allowed_used`` string and the ``version`` tag.

    Raises:
        HTTPException 500 if the model failed to load at startup,
        HTTPException 400 if the audio cannot be decoded or is empty.
    """
    if "model" not in ai_context:
        raise HTTPException(status_code=500, detail="Modèle non chargé")
    # 1. Audio decoding (plain load, no boost/loop post-processing).
    try:
        content = await file.read()
        audio_array, _ = librosa.load(io.BytesIO(content), sr=16000)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Erreur audio: {str(e)}")
    if len(audio_array) == 0:
        # A decodable container with zero samples would otherwise crash the
        # processor later and surface as an opaque 500.
        raise HTTPException(status_code=400, detail="Erreur audio: audio vide")
    # 2. Model preparation.
    processor = ai_context["processor"]
    model = ai_context["model"]
    vocab = ai_context["vocab"]
    id2vocab = ai_context["id2vocab"]
    inputs = processor(audio_array, sampling_rate=16000, return_tensors="pt", padding=True)
    # 3. Forward pass (inference only — no gradients).
    with torch.no_grad():
        logits = model(inputs.input_values).logits
    # Debug trace: what the model "hears" frame by frame, before any masking.
    probs = torch.nn.functional.softmax(logits, dim=-1)
    top3_probs, top3_ids = torch.topk(probs, 3, dim=-1)
    debug_trace = _build_debug_trace(top3_probs, top3_ids, id2vocab)
    # Raw transcription captured before the mask is applied, for comparison.
    raw_ids = torch.argmax(logits, dim=-1)
    raw_ipa = processor.batch_decode(raw_ids, skip_special_tokens=True)[0]
    # 4. Binary mask restricting the output to the requested phonemes.
    logits = _apply_phone_mask(logits, allowed_phones, vocab)
    # 5. Final greedy CTC decode on the (possibly masked) logits.
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
    return {
        "ipa": transcription,
        "raw_ipa_before_mask": raw_ipa,
        "debug_trace": debug_trace,
        "allowed_used": allowed_phones,
        "version": "V1_Classic"
    }