proz commited on
Commit
b0b4663
·
verified ·
1 Parent(s): fc1a510

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -46
app.py CHANGED
@@ -6,26 +6,26 @@ import io
6
  import numpy as np
7
  from contextlib import asynccontextmanager
8
 
9
- # Configuration V1
10
  MODEL_ID = "Cnam-LMSSC/wav2vec2-french-phonemizer"
 
11
  ai_context = {}
12
 
13
  @asynccontextmanager
14
  async def lifespan(app: FastAPI):
15
- print(f"🚀 Chargement du modèle V1 {MODEL_ID}...")
16
  try:
17
- # Chargement du processeur et du modèle
18
  processor = Wav2Vec2Processor.from_pretrained(MODEL_ID)
19
  model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
20
- model.eval() # Mode lecture seule (plus rapide)
21
 
22
  ai_context["processor"] = processor
23
  ai_context["model"] = model
24
-
25
- # On stocke le vocabulaire pour le masque (ex: 'a': 12, 'b': 14)
26
  ai_context["vocab"] = processor.tokenizer.get_vocab()
 
 
27
 
28
- print("✅ Modèle V1 prêt et vocabulaire indexé.")
29
  except Exception as e:
30
  print(f"❌ Erreur critique : {e}")
31
  yield
@@ -35,84 +35,88 @@ app = FastAPI(lifespan=lifespan)
35
 
36
  @app.get("/")
37
  def home():
38
- return {"status": "Model Cnam V1 Masked is running"}
39
-
40
- # --- AJOUT POUR DIAGNOSTIC ---
41
- @app.get("/vocab")
42
- def get_vocab():
43
- """Renvoie le dictionnaire complet de la V1 pour comparaison"""
44
- if "vocab" not in ai_context:
45
- return {"error": "Modèle non chargé"}
46
-
47
- # On trie le dictionnaire par ordre alphabétique des clés pour faciliter la lecture
48
- sorted_vocab = dict(sorted(ai_context["vocab"].items()))
49
-
50
- return {
51
- "model": MODEL_ID,
52
- "total_tokens": len(sorted_vocab),
53
- "tokens": sorted_vocab
54
- }
55
- # -----------------------------
56
 
57
  @app.post("/transcribe")
58
  async def transcribe(
59
  file: UploadFile = File(...),
60
- allowed_phones: str = Form(...) # Ce champ est OBLIGATOIRE
61
  ):
62
  if "model" not in ai_context:
63
  raise HTTPException(status_code=500, detail="Modèle non chargé")
64
 
65
- # 1. Lecture Audio avec Librosa (force 16kHz)
66
  try:
67
  content = await file.read()
68
- # On utilise io.BytesIO pour lire depuis la mémoire sans fichier temporaire
69
  audio_array, _ = librosa.load(io.BytesIO(content), sr=16000)
70
  except Exception as e:
71
- raise HTTPException(status_code=400, detail=f"Erreur fichier audio: {str(e)}")
72
 
73
  # 2. Préparation Modèle
74
  processor = ai_context["processor"]
75
  model = ai_context["model"]
 
 
76
 
77
  inputs = processor(audio_array, sampling_rate=16000, return_tensors="pt", padding=True)
78
 
79
- # 3. Calcul des Logits (Probabilités brutes avant décision)
80
  with torch.no_grad():
81
  logits = model(inputs.input_values).logits
82
 
83
- # --- 4. APPLICATION DU MASQUE BINAIRE ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
- # On récupère la liste demandée (ex: "a,i,o")
86
  requested_phones = [p.strip() for p in allowed_phones.split(',') if p.strip()]
87
 
88
  if requested_phones:
89
- vocab = ai_context["vocab"]
90
-
91
- # Tokens techniques indispensables pour que le CTC fonctionne (silence, padding...)
92
- # Le modèle Cnam utilise '|' comme séparateur de mot/silence
93
  technical_tokens = ["<pad>", "<s>", "</s>", "<unk>", "|", "[PAD]", "[UNK]"]
94
-
95
- # On construit l'ensemble des tokens autorisés
96
  full_allowed_set = set(requested_phones + technical_tokens)
97
 
98
- # On trouve leurs positions numériques (ID) dans le cerveau du modèle
 
99
  allowed_indices = [vocab[t] for t in full_allowed_set if t in vocab]
100
 
101
  if allowed_indices:
102
- # Création du masque : Par défaut, tout est interdit (-Infini)
103
  mask = torch.full((logits.shape[-1],), float('-inf'))
104
-
105
- # On ouvre les portes seulement pour les indices autorisés (0.0)
106
  mask[allowed_indices] = 0.0
107
-
108
- # On applique le masque aux logits
109
  logits = logits + mask
110
 
111
- # 5. Décodage final (Argmax)
112
  predicted_ids = torch.argmax(logits, dim=-1)
113
  transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
114
 
115
  return {
116
  "ipa": transcription,
117
- "allowed_used": allowed_phones
 
 
 
118
  }
 
6
  import numpy as np
7
  from contextlib import asynccontextmanager
8
 
9
+ # --- CONFIGURATION V1 ---
10
  MODEL_ID = "Cnam-LMSSC/wav2vec2-french-phonemizer"
11
+
12
  ai_context = {}
13
 
14
  @asynccontextmanager
15
  async def lifespan(app: FastAPI):
16
+ print(f"🚀 Chargement V1 (Stable) : {MODEL_ID}...")
17
  try:
 
18
  processor = Wav2Vec2Processor.from_pretrained(MODEL_ID)
19
  model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
20
+ model.eval()
21
 
22
  ai_context["processor"] = processor
23
  ai_context["model"] = model
 
 
24
  ai_context["vocab"] = processor.tokenizer.get_vocab()
25
+ # Inversion vocab pour le debug
26
+ ai_context["id2vocab"] = {v: k for k, v in ai_context["vocab"].items()}
27
 
28
+ print("✅ Modèle V1 prêt.")
29
  except Exception as e:
30
  print(f"❌ Erreur critique : {e}")
31
  yield
 
35
 
36
  @app.get("/")
37
  def home():
38
+ return {"status": "API V1 Stable running", "model": MODEL_ID}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  @app.post("/transcribe")
41
  async def transcribe(
42
  file: UploadFile = File(...),
43
+ allowed_phones: str = Form(...)
44
  ):
45
  if "model" not in ai_context:
46
  raise HTTPException(status_code=500, detail="Modèle non chargé")
47
 
48
+ # 1. Lecture Audio (Simple, sans boost ni loop)
49
  try:
50
  content = await file.read()
 
51
  audio_array, _ = librosa.load(io.BytesIO(content), sr=16000)
52
  except Exception as e:
53
+ raise HTTPException(status_code=400, detail=f"Erreur audio: {str(e)}")
54
 
55
  # 2. Préparation Modèle
56
  processor = ai_context["processor"]
57
  model = ai_context["model"]
58
+ vocab = ai_context["vocab"]
59
+ id2vocab = ai_context["id2vocab"]
60
 
61
  inputs = processor(audio_array, sampling_rate=16000, return_tensors="pt", padding=True)
62
 
63
+ # 3. Calcul des Logits
64
  with torch.no_grad():
65
  logits = model(inputs.input_values).logits
66
 
67
+ # --- DEBUG TRACE (Utile pour voir ce que la V1 entend vraiment) ---
68
+ probs = torch.nn.functional.softmax(logits, dim=-1)
69
+ top3_probs, top3_ids = torch.topk(probs, 3, dim=-1)
70
+
71
+ debug_trace = []
72
+ time_steps = top3_ids.shape[1]
73
+
74
+ for t in range(time_steps):
75
+ best_id = top3_ids[0, t, 0].item()
76
+ best_token = id2vocab.get(best_id, "")
77
+ best_prob = top3_probs[0, t, 0].item()
78
+
79
+ if best_token not in ["<pad>", "<s>", "</s>", "|"] and best_prob > 0.05:
80
+ frame_info = []
81
+ for k in range(3):
82
+ alt_id = top3_ids[0, t, k].item()
83
+ alt_token = id2vocab.get(alt_id, "UNK")
84
+ alt_prob = top3_probs[0, t, k].item()
85
+ frame_info.append(f"{alt_token} ({int(alt_prob*100)}%)")
86
+ debug_trace.append(f"T{t}: " + " | ".join(frame_info))
87
+
88
+ # Capture du Raw avant masque
89
+ raw_ids = torch.argmax(logits, dim=-1)
90
+ raw_ipa = processor.batch_decode(raw_ids, skip_special_tokens=True)[0]
91
+
92
+
93
+ # --- 4. APPLICATION DU MASQUE BINAIRE (Logique V1) ---
94
+ # La V1 gère les blocs entiers (ex: ɑ̃), pas besoin de décomposition complexe
95
 
 
96
  requested_phones = [p.strip() for p in allowed_phones.split(',') if p.strip()]
97
 
98
  if requested_phones:
99
+ # Tokens techniques V1 (Le modèle V1 utilise '|' et d'autres)
 
 
 
100
  technical_tokens = ["<pad>", "<s>", "</s>", "<unk>", "|", "[PAD]", "[UNK]"]
 
 
101
  full_allowed_set = set(requested_phones + technical_tokens)
102
 
103
+ # On trouve leurs positions numériques (ID)
104
+ # Note : La V1 est "gentille", si on demande 'ɑ̃', elle a un token pour ça.
105
  allowed_indices = [vocab[t] for t in full_allowed_set if t in vocab]
106
 
107
  if allowed_indices:
 
108
  mask = torch.full((logits.shape[-1],), float('-inf'))
 
 
109
  mask[allowed_indices] = 0.0
 
 
110
  logits = logits + mask
111
 
112
+ # 5. Décodage final
113
  predicted_ids = torch.argmax(logits, dim=-1)
114
  transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
115
 
116
  return {
117
  "ipa": transcription,
118
+ "raw_ipa_before_mask": raw_ipa,
119
+ "debug_trace": debug_trace,
120
+ "allowed_used": allowed_phones,
121
+ "version": "V1_Classic"
122
  }