Spaces:
Running
Running
| from fastapi import FastAPI, File, UploadFile, Form, HTTPException | |
| from optimum.onnxruntime import ORTModelForCTC | |
| from transformers import Wav2Vec2Processor | |
| import torch | |
| import librosa | |
| import numpy as np | |
| import io | |
| from contextlib import asynccontextmanager | |
# --- CONFIGURATION ---
# The model folder is copied by the Dockerfile next to app.py.
ONNX_MODEL_DIR = "model_fr_onnx"
# Quantized ONNX weights file inside ONNX_MODEL_DIR.
ONNX_FILENAME = "model_quantized.onnx"
# Shared state filled at startup by the lifespan handler:
# "processor" (Wav2Vec2Processor), "model" (ORTModelForCTC),
# "vocab" (token -> id mapping from the tokenizer).
ai_context = {}
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan handler: load the French ONNX model at startup.

    Populates ``ai_context`` with the processor, the quantized ONNX CTC
    model and the tokenizer vocabulary, and clears it at shutdown.
    Loading errors are logged but do not abort startup, so the app stays
    reachable; /transcribe then answers with HTTP 500.

    Note: ``asynccontextmanager`` is imported at the top of the file but
    was never applied — without this decorator FastAPI cannot use the
    generator as a lifespan context manager.
    """
    print("🚀 Chargement du modèle Français ONNX...")
    try:
        ai_context["processor"] = Wav2Vec2Processor.from_pretrained(ONNX_MODEL_DIR)
        ai_context["model"] = ORTModelForCTC.from_pretrained(
            ONNX_MODEL_DIR, file_name=ONNX_FILENAME
        )
        # token -> id mapping, used to build the decoding mask in /transcribe
        ai_context["vocab"] = ai_context["processor"].tokenizer.get_vocab()
        print("✅ Modèle chargé.")
    except Exception as e:
        # Deliberately best-effort: keep serving so the health endpoint works.
        print(f"❌ Erreur critique : {e}")
    yield
    ai_context.clear()
app = FastAPI(lifespan=lifespan)


@app.get("/")
def home():
    """Health-check endpoint: confirm the API is up and point to /transcribe.

    Fix: the function was defined but never registered as a route (the
    decorator was missing), so it was dead code.
    """
    return {"status": "API is running", "help": "POST /transcribe to use"}
@app.post("/transcribe")
async def transcribe(
    file: UploadFile = File(...),
    allowed_phones: str = Form(...)
):
    """Transcribe an uploaded audio file to IPA, restricted to allowed phones.

    Args:
        file: uploaded audio in any container librosa/soundfile can decode.
        allowed_phones: comma-separated phone tokens permitted in the output;
            every other vocabulary entry is masked out before decoding.

    Returns:
        ``{"ipa": transcription}`` — greedy CTC decoding of the masked logits.

    Raises:
        HTTPException: 500 when the model failed to load at startup.
    """
    if "model" not in ai_context:
        raise HTTPException(status_code=500, detail="Model not loaded")

    # --- Audio decoding (target format: 16 kHz mono float) ---
    audio_content = await file.read()
    try:
        # librosa.load with sr=16000 already returns 16 kHz mono.
        speech, _ = librosa.load(io.BytesIO(audio_content), sr=16000)
    except Exception:
        # Fallback when librosa rejects the container format: decode with
        # soundfile, then downmix/resample ourselves.  (Was a bare except;
        # the sample rate was bound to "_" and then read — renamed to sr.)
        import soundfile as sf
        speech, sr = sf.read(io.BytesIO(audio_content))
        if len(speech.shape) > 1:
            speech = speech[:, 0]  # stereo -> mono (keep first channel)
        if sr != 16000:
            speech = librosa.resample(speech, orig_sr=sr, target_sr=16000)

    # --- Inference ---
    processor = ai_context["processor"]
    inputs = processor(speech, sampling_rate=16000, return_tensors="pt", padding=True)
    logits = ai_context["model"](inputs.input_values).logits

    # --- Dynamic masking: forbid every token not explicitly allowed ---
    user_allowed = [p.strip() for p in allowed_phones.split(',')]
    # CTC/technical tokens must always remain available for decoding.
    technical_tokens = ["|", "[PAD]", "<s>", "</s>", "<pad>", "<unk>", "[UNK]"]
    full_allowed = set(user_allowed + technical_tokens)
    vocab = ai_context["vocab"]
    mask = torch.ones(logits.shape[-1], dtype=torch.bool)
    allowed_indices = [vocab[t] for t in full_allowed if t in vocab]
    if allowed_indices:
        mask[allowed_indices] = False
        # Apply only when at least one token resolved; otherwise the mask is
        # all-True and would set every logit to -inf, destroying the output.
        logits[:, :, mask] = -float('inf')

    # --- Greedy CTC decoding ---
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
    return {"ipa": transcription}