from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import torch
import librosa
import io
import numpy as np
from contextlib import asynccontextmanager

# --- CONFIGURATION V1 ---
MODEL_ID = "Cnam-LMSSC/wav2vec2-french-phonemizer"

ai_context = {}


@asynccontextmanager
async def lifespan(app: FastAPI):
    print(f"🚀 Loading V1 (Stable): {MODEL_ID}...")
    try:
        processor = Wav2Vec2Processor.from_pretrained(MODEL_ID)
        model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
        model.eval()
        ai_context["processor"] = processor
        ai_context["model"] = model
        ai_context["vocab"] = processor.tokenizer.get_vocab()
        # Invert the vocab (id -> token) for debugging
        ai_context["id2vocab"] = {v: k for k, v in ai_context["vocab"].items()}
        print("✅ V1 model ready.")
    except Exception as e:
        print(f"❌ Critical error: {e}")
    yield
    ai_context.clear()


app = FastAPI(lifespan=lifespan)


@app.get("/")
def home():
    return {"status": "API V1 Stable running", "model": MODEL_ID}


@app.post("/transcribe")
async def transcribe(
    file: UploadFile = File(...),
    allowed_phones: str = Form(...)
):
    if "model" not in ai_context:
        raise HTTPException(status_code=500, detail="Model not loaded")

    # 1. Audio loading (simple, no boost or looping)
    try:
        content = await file.read()
        audio_array, _ = librosa.load(io.BytesIO(content), sr=16000)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Audio error: {str(e)}")

    # 2. Model preparation
    processor = ai_context["processor"]
    model = ai_context["model"]
    vocab = ai_context["vocab"]
    id2vocab = ai_context["id2vocab"]

    inputs = processor(audio_array, sampling_rate=16000, return_tensors="pt", padding=True)

    # 3. Logits computation
    with torch.no_grad():
        logits = model(inputs.input_values).logits

    # --- DEBUG TRACE (useful to see what V1 actually hears) ---
    probs = torch.nn.functional.softmax(logits, dim=-1)
    top3_probs, top3_ids = torch.topk(probs, 3, dim=-1)

    debug_trace = []
    time_steps = top3_ids.shape[1]
    for t in range(time_steps):
        best_id = top3_ids[0, t, 0].item()
        best_token = id2vocab.get(best_id, "")
        best_prob = top3_probs[0, t, 0].item()

        # Skip the standard Wav2Vec2 special tokens and low-confidence frames
        if best_token not in ["<pad>", "<s>", "</s>", "|"] and best_prob > 0.05:
            frame_info = []
            for k in range(3):
                alt_id = top3_ids[0, t, k].item()
                alt_token = id2vocab.get(alt_id, "UNK")
                alt_prob = top3_probs[0, t, k].item()
                frame_info.append(f"{alt_token} ({int(alt_prob*100)}%)")
            debug_trace.append(f"T{t}: " + " | ".join(frame_info))

    # Capture the raw transcription before masking
    raw_ids = torch.argmax(logits, dim=-1)
    raw_ipa = processor.batch_decode(raw_ids, skip_special_tokens=True)[0]

    # --- 4. BINARY MASK APPLICATION (V1 logic) ---
    # V1 handles whole blocks (e.g. ɑ̃); no complex decomposition needed
    requested_phones = [p.strip() for p in allowed_phones.split(',') if p.strip()]

    if requested_phones:
        # V1 technical tokens (the V1 model uses '|' and the standard special tokens)
        technical_tokens = ["<pad>", "<s>", "</s>", "<unk>", "|", "[PAD]", "[UNK]"]
        full_allowed_set = set(requested_phones + technical_tokens)

        # Map the allowed tokens to their numeric IDs
        # Note: V1 is "friendly": if you ask for 'ɑ̃', it has a dedicated token for it.
        allowed_indices = [vocab[t] for t in full_allowed_set if t in vocab]

        if allowed_indices:
            mask = torch.full((logits.shape[-1],), float('-inf'))
            mask[allowed_indices] = 0.0
            logits = logits + mask

    # 5. Final decoding
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

    return {
        "ipa": transcription,
        "raw_ipa_before_mask": raw_ipa,
        "debug_trace": debug_trace,
        "allowed_used": allowed_phones,
        "version": "V1_Classic"
    }
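

# --- Usage sketch (not part of the original file) ---
# A minimal client call, assuming the app is served locally with
# `uvicorn main:app --port 8000`; the module name `main` and the
# file `sample.wav` are placeholder assumptions.
#
# import requests
#
# with open("sample.wav", "rb") as f:
#     resp = requests.post(
#         "http://localhost:8000/transcribe",
#         files={"file": ("sample.wav", f, "audio/wav")},
#         data={"allowed_phones": "a,ɑ̃,s,l"},  # comma-separated IPA tokens
#     )
# print(resp.json()["ipa"])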