proz commited on
Commit
fc1a510
·
verified ·
1 Parent(s): e1da5ec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -13
app.py CHANGED
@@ -6,24 +6,26 @@ import io
6
  import numpy as np
7
  from contextlib import asynccontextmanager
8
 
9
- # --- CONFIGURATION V1 ---
10
  MODEL_ID = "Cnam-LMSSC/wav2vec2-french-phonemizer"
11
-
12
  ai_context = {}
13
 
14
  @asynccontextmanager
15
  async def lifespan(app: FastAPI):
16
  print(f"🚀 Chargement du modèle V1 {MODEL_ID}...")
17
  try:
 
18
  processor = Wav2Vec2Processor.from_pretrained(MODEL_ID)
19
  model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
20
- model.eval()
21
 
22
  ai_context["processor"] = processor
23
  ai_context["model"] = model
 
 
24
  ai_context["vocab"] = processor.tokenizer.get_vocab()
25
 
26
- print("✅ Modèle V1 prêt.")
27
  except Exception as e:
28
  print(f"❌ Erreur critique : {e}")
29
  yield
@@ -33,49 +35,80 @@ app = FastAPI(lifespan=lifespan)
33
 
34
  @app.get("/")
35
  def home():
36
- return {"status": "API V1 running", "model": MODEL_ID}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  @app.post("/transcribe")
39
  async def transcribe(
40
  file: UploadFile = File(...),
41
- allowed_phones: str = Form(...)
42
  ):
43
  if "model" not in ai_context:
44
  raise HTTPException(status_code=500, detail="Modèle non chargé")
45
 
46
- # 1. Lecture Audio
47
  try:
48
  content = await file.read()
 
49
  audio_array, _ = librosa.load(io.BytesIO(content), sr=16000)
50
  except Exception as e:
51
- raise HTTPException(status_code=400, detail=f"Erreur audio: {str(e)}")
52
 
53
- # 2. Préparation
54
  processor = ai_context["processor"]
55
  model = ai_context["model"]
 
56
  inputs = processor(audio_array, sampling_rate=16000, return_tensors="pt", padding=True)
57
 
58
- # 3. Logits
59
  with torch.no_grad():
60
  logits = model(inputs.input_values).logits
61
 
62
- # 4. Masque Binaire Strict
 
 
63
  requested_phones = [p.strip() for p in allowed_phones.split(',') if p.strip()]
64
 
65
  if requested_phones:
66
  vocab = ai_context["vocab"]
67
- # Tokens techniques V1 (Le modèle V1 utilise beaucoup le pipe '|')
 
 
68
  technical_tokens = ["<pad>", "<s>", "</s>", "<unk>", "|", "[PAD]", "[UNK]"]
 
 
69
  full_allowed_set = set(requested_phones + technical_tokens)
70
 
 
71
  allowed_indices = [vocab[t] for t in full_allowed_set if t in vocab]
72
 
73
  if allowed_indices:
 
74
  mask = torch.full((logits.shape[-1],), float('-inf'))
 
 
75
  mask[allowed_indices] = 0.0
 
 
76
  logits = logits + mask
77
 
78
- # 5. Décodage
79
  predicted_ids = torch.argmax(logits, dim=-1)
80
  transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
81
 
 
6
  import numpy as np
7
  from contextlib import asynccontextmanager
8
 
9
+ # Configuration V1
10
  MODEL_ID = "Cnam-LMSSC/wav2vec2-french-phonemizer"
 
11
  ai_context = {}
12
 
13
  @asynccontextmanager
14
  async def lifespan(app: FastAPI):
15
  print(f"🚀 Chargement du modèle V1 {MODEL_ID}...")
16
  try:
17
+ # Chargement du processeur et du modèle
18
  processor = Wav2Vec2Processor.from_pretrained(MODEL_ID)
19
  model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
20
+ model.eval() # Mode lecture seule (plus rapide)
21
 
22
  ai_context["processor"] = processor
23
  ai_context["model"] = model
24
+
25
+ # On stocke le vocabulaire pour le masque (ex: 'a': 12, 'b': 14)
26
  ai_context["vocab"] = processor.tokenizer.get_vocab()
27
 
28
+ print("✅ Modèle V1 prêt et vocabulaire indexé.")
29
  except Exception as e:
30
  print(f"❌ Erreur critique : {e}")
31
  yield
 
35
 
36
  @app.get("/")
37
  def home():
38
+ return {"status": "Model Cnam V1 Masked is running"}
39
+
40
+ # --- AJOUT POUR DIAGNOSTIC ---
41
+ @app.get("/vocab")
42
+ def get_vocab():
43
+ """Renvoie le dictionnaire complet de la V1 pour comparaison"""
44
+ if "vocab" not in ai_context:
45
+ return {"error": "Modèle non chargé"}
46
+
47
+ # On trie le dictionnaire par ordre alphabétique des clés pour faciliter la lecture
48
+ sorted_vocab = dict(sorted(ai_context["vocab"].items()))
49
+
50
+ return {
51
+ "model": MODEL_ID,
52
+ "total_tokens": len(sorted_vocab),
53
+ "tokens": sorted_vocab
54
+ }
55
+ # -----------------------------
56
 
57
  @app.post("/transcribe")
58
  async def transcribe(
59
  file: UploadFile = File(...),
60
+ allowed_phones: str = Form(...) # Ce champ est OBLIGATOIRE
61
  ):
62
  if "model" not in ai_context:
63
  raise HTTPException(status_code=500, detail="Modèle non chargé")
64
 
65
+ # 1. Lecture Audio avec Librosa (force 16kHz)
66
  try:
67
  content = await file.read()
68
+ # On utilise io.BytesIO pour lire depuis la mémoire sans fichier temporaire
69
  audio_array, _ = librosa.load(io.BytesIO(content), sr=16000)
70
  except Exception as e:
71
+ raise HTTPException(status_code=400, detail=f"Erreur fichier audio: {str(e)}")
72
 
73
+ # 2. Préparation Modèle
74
  processor = ai_context["processor"]
75
  model = ai_context["model"]
76
+
77
  inputs = processor(audio_array, sampling_rate=16000, return_tensors="pt", padding=True)
78
 
79
+ # 3. Calcul des Logits (Probabilités brutes avant décision)
80
  with torch.no_grad():
81
  logits = model(inputs.input_values).logits
82
 
83
+ # --- 4. APPLICATION DU MASQUE BINAIRE ---
84
+
85
+ # On récupère la liste demandée (ex: "a,i,o")
86
  requested_phones = [p.strip() for p in allowed_phones.split(',') if p.strip()]
87
 
88
  if requested_phones:
89
  vocab = ai_context["vocab"]
90
+
91
+ # Tokens techniques indispensables pour que le CTC fonctionne (silence, padding...)
92
+ # Le modèle Cnam utilise '|' comme séparateur de mot/silence
93
  technical_tokens = ["<pad>", "<s>", "</s>", "<unk>", "|", "[PAD]", "[UNK]"]
94
+
95
+ # On construit l'ensemble des tokens autorisés
96
  full_allowed_set = set(requested_phones + technical_tokens)
97
 
98
+ # On trouve leurs positions numériques (ID) dans le cerveau du modèle
99
  allowed_indices = [vocab[t] for t in full_allowed_set if t in vocab]
100
 
101
  if allowed_indices:
102
+ # Création du masque : Par défaut, tout est interdit (-Infini)
103
  mask = torch.full((logits.shape[-1],), float('-inf'))
104
+
105
+ # On ouvre les portes seulement pour les indices autorisés (0.0)
106
  mask[allowed_indices] = 0.0
107
+
108
+ # On applique le masque aux logits
109
  logits = logits + mask
110
 
111
+ # 5. Décodage final (Argmax)
112
  predicted_ids = torch.argmax(logits, dim=-1)
113
  transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
114