# api_cnam / app.py — FastAPI service exposing a French phonemizer (Wav2Vec2 + CTC).
# Hugging Face Space metadata: last update by proz, commit b0b4663 (verified).
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import torch
import librosa
import io
import numpy as np
from contextlib import asynccontextmanager
# --- CONFIGURATION V1 ---
# Hugging Face model ID: a wav2vec2 model with a CTC head fine-tuned to emit
# French IPA phonemes.
MODEL_ID = "Cnam-LMSSC/wav2vec2-french-phonemizer"
# Shared mutable state populated by the lifespan handler at startup:
# keys "processor", "model", "vocab" (token -> id) and "id2vocab" (id -> token).
ai_context = {}
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the phonemizer model once at startup; clear shared state on shutdown.

    On any load failure the error is printed and the app still starts, so the
    /transcribe endpoint can report the missing model explicitly.
    """
    print(f"🚀 Chargement V1 (Stable) : {MODEL_ID}...")
    try:
        proc = Wav2Vec2Processor.from_pretrained(MODEL_ID)
        net = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
        net.eval()
        token_to_id = proc.tokenizer.get_vocab()
        ai_context.update(
            processor=proc,
            model=net,
            vocab=token_to_id,
            # Inverted vocab (id -> token), used when building debug traces.
            id2vocab={idx: tok for tok, idx in token_to_id.items()},
        )
        print("✅ Modèle V1 prêt.")
    except Exception as e:
        print(f"❌ Erreur critique : {e}")
    yield
    ai_context.clear()
# Wire the lifespan handler so the model is loaded before requests are served.
app = FastAPI(lifespan=lifespan)
@app.get("/")
def home():
    """Health-check endpoint: reports service status and the configured model."""
    return {
        "status": "API V1 Stable running",
        "model": MODEL_ID,
    }
async def _read_audio_16k(file: UploadFile) -> np.ndarray:
    """Decode an uploaded audio file into a mono 16 kHz float waveform.

    Raises HTTPException(400) with the decoder error message on failure,
    chaining the original exception for server-side debugging.
    """
    try:
        content = await file.read()
        audio_array, _ = librosa.load(io.BytesIO(content), sr=16000)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Erreur audio: {str(e)}") from e
    return audio_array


def _build_debug_trace(logits: "torch.Tensor", id2vocab: dict) -> list:
    """Per-frame top-3 candidate tokens, for frames whose best token matters.

    A frame is kept only when its best token is not a structural token
    (<pad>, <s>, </s>, |) and its probability exceeds 5%.
    """
    probs = torch.nn.functional.softmax(logits, dim=-1)
    top3_probs, top3_ids = torch.topk(probs, 3, dim=-1)
    trace = []
    for t in range(top3_ids.shape[1]):
        best_id = top3_ids[0, t, 0].item()
        best_token = id2vocab.get(best_id, "")
        best_prob = top3_probs[0, t, 0].item()
        if best_token in ("<pad>", "<s>", "</s>", "|") or best_prob <= 0.05:
            continue
        frame_info = []
        for k in range(3):
            alt_token = id2vocab.get(top3_ids[0, t, k].item(), "UNK")
            alt_prob = top3_probs[0, t, k].item()
            frame_info.append(f"{alt_token} ({int(alt_prob*100)}%)")
        trace.append(f"T{t}: " + " | ".join(frame_info))
    return trace


def _apply_phoneme_mask(logits: "torch.Tensor", vocab: dict, requested_phones: list) -> "torch.Tensor":
    """Restrict logits to the requested phonemes plus technical tokens.

    Tokens outside the allowed set get -inf added so argmax can never pick
    them. If none of the allowed tokens exist in the vocab, logits are
    returned unchanged (matching the original best-effort behavior).
    """
    # Technical tokens the V1 model needs to keep emitting (padding, word sep).
    technical_tokens = ["<pad>", "<s>", "</s>", "<unk>", "|", "[PAD]", "[UNK]"]
    full_allowed_set = set(requested_phones + technical_tokens)
    # Note: the V1 vocab has whole-block tokens (e.g. 'ɑ̃'), so a direct
    # token lookup is enough — no grapheme decomposition required.
    allowed_indices = [vocab[t] for t in full_allowed_set if t in vocab]
    if not allowed_indices:
        return logits
    # dtype matches logits so the addition never upcasts the tensor.
    mask = torch.full((logits.shape[-1],), float('-inf'), dtype=logits.dtype)
    mask[allowed_indices] = 0.0
    return logits + mask


@app.post("/transcribe")
async def transcribe(
    file: UploadFile = File(...),
    allowed_phones: str = Form(...)
):
    """Transcribe an audio upload to IPA, optionally masked to allowed phonemes.

    Form fields:
        file: audio file in any format librosa can decode.
        allowed_phones: comma-separated IPA tokens; empty entries are ignored.

    Returns a dict with the masked transcription, the raw (unmasked)
    transcription, a per-frame debug trace, and the mask that was applied.

    Raises:
        HTTPException(500): model was not loaded at startup.
        HTTPException(400): the audio could not be decoded.
    """
    if "model" not in ai_context:
        raise HTTPException(status_code=500, detail="Modèle non chargé")
    # 1. Audio decoding (simple, no boost or looping).
    audio_array = await _read_audio_16k(file)
    # 2. Model preparation.
    processor = ai_context["processor"]
    model = ai_context["model"]
    vocab = ai_context["vocab"]
    id2vocab = ai_context["id2vocab"]
    inputs = processor(audio_array, sampling_rate=16000, return_tensors="pt", padding=True)
    # 3. Logits computation.
    with torch.no_grad():
        logits = model(inputs.input_values).logits
    # Debug trace: what the V1 model actually "hears" before masking.
    debug_trace = _build_debug_trace(logits, id2vocab)
    # Raw transcription captured before the mask is applied.
    raw_ids = torch.argmax(logits, dim=-1)
    raw_ipa = processor.batch_decode(raw_ids, skip_special_tokens=True)[0]
    # 4. Binary mask restricted to the caller's phoneme set (V1 logic).
    requested_phones = [p.strip() for p in allowed_phones.split(',') if p.strip()]
    if requested_phones:
        logits = _apply_phoneme_mask(logits, vocab, requested_phones)
    # 5. Final decoding.
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
    return {
        "ipa": transcription,
        "raw_ipa_before_mask": raw_ipa,
        "debug_trace": debug_trace,
        "allowed_used": allowed_phones,
        "version": "V1_Classic"
    }