| | from fastapi import FastAPI, File, UploadFile, Form, HTTPException |
| | from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC |
| | import torch |
| | import librosa |
| | import io |
| | import numpy as np |
| | from contextlib import asynccontextmanager |
| |
|
| | |
# Hugging Face model ID: wav2vec2 CTC model fine-tuned for French phonemization.
MODEL_ID = "Cnam-LMSSC/wav2vec2-french-phonemizer"

# Global holder for the loaded processor/model/vocab; populated by lifespan()
# at startup and cleared at shutdown.
ai_context = {}
| |
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the wav2vec2 phonemizer once at startup; release it on shutdown.

    Populates the module-level ``ai_context`` with the processor, the model
    (in eval mode), the tokenizer vocab and its id->token reverse mapping.
    On load failure the error is printed and the app still starts, leaving
    ``ai_context`` empty (the /transcribe endpoint then returns 500).
    """
    print(f"🚀 Chargement V1 (Stable) : {MODEL_ID}...")
    try:
        proc = Wav2Vec2Processor.from_pretrained(MODEL_ID)
        net = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
        net.eval()  # inference only: disable dropout etc.
        token_to_id = proc.tokenizer.get_vocab()
        ai_context.update(
            processor=proc,
            model=net,
            vocab=token_to_id,
            # Reverse mapping used to turn argmax/topk ids back into tokens.
            id2vocab={idx: tok for tok, idx in token_to_id.items()},
        )
        print("✅ Modèle V1 prêt.")
    except Exception as e:
        print(f"❌ Erreur critique : {e}")
    yield
    # Shutdown: drop references so the model can be garbage-collected.
    ai_context.clear()
| |
|
| | app = FastAPI(lifespan=lifespan) |
| |
|
| | @app.get("/") |
| | def home(): |
| | return {"status": "API V1 Stable running", "model": MODEL_ID} |
| |
|
| | @app.post("/transcribe") |
| | async def transcribe( |
| | file: UploadFile = File(...), |
| | allowed_phones: str = Form(...) |
| | ): |
| | if "model" not in ai_context: |
| | raise HTTPException(status_code=500, detail="Modèle non chargé") |
| |
|
| | |
| | try: |
| | content = await file.read() |
| | audio_array, _ = librosa.load(io.BytesIO(content), sr=16000) |
| | except Exception as e: |
| | raise HTTPException(status_code=400, detail=f"Erreur audio: {str(e)}") |
| |
|
| | |
| | processor = ai_context["processor"] |
| | model = ai_context["model"] |
| | vocab = ai_context["vocab"] |
| | id2vocab = ai_context["id2vocab"] |
| | |
| | inputs = processor(audio_array, sampling_rate=16000, return_tensors="pt", padding=True) |
| |
|
| | |
| | with torch.no_grad(): |
| | logits = model(inputs.input_values).logits |
| |
|
| | |
| | probs = torch.nn.functional.softmax(logits, dim=-1) |
| | top3_probs, top3_ids = torch.topk(probs, 3, dim=-1) |
| | |
| | debug_trace = [] |
| | time_steps = top3_ids.shape[1] |
| | |
| | for t in range(time_steps): |
| | best_id = top3_ids[0, t, 0].item() |
| | best_token = id2vocab.get(best_id, "") |
| | best_prob = top3_probs[0, t, 0].item() |
| |
|
| | if best_token not in ["<pad>", "<s>", "</s>", "|"] and best_prob > 0.05: |
| | frame_info = [] |
| | for k in range(3): |
| | alt_id = top3_ids[0, t, k].item() |
| | alt_token = id2vocab.get(alt_id, "UNK") |
| | alt_prob = top3_probs[0, t, k].item() |
| | frame_info.append(f"{alt_token} ({int(alt_prob*100)}%)") |
| | debug_trace.append(f"T{t}: " + " | ".join(frame_info)) |
| |
|
| | |
| | raw_ids = torch.argmax(logits, dim=-1) |
| | raw_ipa = processor.batch_decode(raw_ids, skip_special_tokens=True)[0] |
| |
|
| |
|
| | |
| | |
| | |
| | requested_phones = [p.strip() for p in allowed_phones.split(',') if p.strip()] |
| | |
| | if requested_phones: |
| | |
| | technical_tokens = ["<pad>", "<s>", "</s>", "<unk>", "|", "[PAD]", "[UNK]"] |
| | full_allowed_set = set(requested_phones + technical_tokens) |
| | |
| | |
| | |
| | allowed_indices = [vocab[t] for t in full_allowed_set if t in vocab] |
| | |
| | if allowed_indices: |
| | mask = torch.full((logits.shape[-1],), float('-inf')) |
| | mask[allowed_indices] = 0.0 |
| | logits = logits + mask |
| |
|
| | |
| | predicted_ids = torch.argmax(logits, dim=-1) |
| | transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0] |
| | |
| | return { |
| | "ipa": transcription, |
| | "raw_ipa_before_mask": raw_ipa, |
| | "debug_trace": debug_trace, |
| | "allowed_used": allowed_phones, |
| | "version": "V1_Classic" |
| | } |