proz commited on
Commit
284104a
·
verified ·
1 Parent(s): f76a6e9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -36
app.py CHANGED
@@ -7,19 +7,19 @@ import numpy as np
7
  from contextlib import asynccontextmanager
8
 
9
  # --- CONFIGURATION ---
10
- # Utilisation de la V2 (Meilleure sur voix réelles/CommonVoice)
11
  MODEL_ID = "Cnam-LMSSC/wav2vec2-french-phonemizer-v2"
12
 
13
  ai_context = {}
14
 
15
  @asynccontextmanager
16
  async def lifespan(app: FastAPI):
17
- print(f"🚀 Chargement du modèle {MODEL_ID}...")
18
  try:
19
  # Chargement du processeur et du modèle
20
  processor = Wav2Vec2Processor.from_pretrained(MODEL_ID)
21
  model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
22
- model.eval() # Mode lecture seule (plus rapide)
23
 
24
  ai_context["processor"] = processor
25
  ai_context["model"] = model
@@ -27,9 +27,9 @@ async def lifespan(app: FastAPI):
27
  # On stocke le vocabulaire pour le masque
28
  ai_context["vocab"] = processor.tokenizer.get_vocab()
29
 
30
- print("✅ Modèle chargé, VAD actif et vocabulaire indexé.")
31
  except Exception as e:
32
- print(f"❌ Erreur critique au chargement : {e}")
33
  yield
34
  ai_context.clear()
35
 
@@ -37,7 +37,7 @@ app = FastAPI(lifespan=lifespan)
37
 
38
  @app.get("/")
39
  def home():
40
- return {"status": "API V2 running", "model": MODEL_ID}
41
 
42
  @app.post("/transcribe")
43
  async def transcribe(
@@ -47,67 +47,50 @@ async def transcribe(
47
  if "model" not in ai_context:
48
  raise HTTPException(status_code=500, detail="Modèle non chargé")
49
 
50
- # 1. Lecture Audio
51
  try:
52
  content = await file.read()
 
53
  audio_array, _ = librosa.load(io.BytesIO(content), sr=16000)
54
  except Exception as e:
55
- raise HTTPException(status_code=400, detail=f"Erreur lecture audio: {str(e)}")
56
 
57
- # 2. NETTOYAGE AUDIO (VAD / Trimming)
58
- # On coupe le silence au début et à la fin.
59
- # top_db=20 signifie qu'on coupe tout ce qui est 20dB sous le pic sonore (bruit de fond).
60
- audio_trimmed, _ = librosa.effects.trim(audio_array, top_db=20)
61
-
62
- # Sécurité : Si l'audio est vide après nettoyage (que du silence), on renvoie vide
63
- if len(audio_trimmed) < 1000: # moins de 0.06 seconde
64
- return {"ipa": "", "confidence": 0.0, "status": "empty_audio"}
65
-
66
- # 3. Préparation Modèle
67
  processor = ai_context["processor"]
68
  model = ai_context["model"]
69
 
70
- inputs = processor(audio_trimmed, sampling_rate=16000, return_tensors="pt", padding=True)
71
 
72
- # 4. Calcul des Logits
73
  with torch.no_grad():
74
  logits = model(inputs.input_values).logits
75
 
76
- # --- 5. APPLICATION DU MASQUE BINAIRE ---
77
 
78
  requested_phones = [p.strip() for p in allowed_phones.split(',') if p.strip()]
79
 
80
  if requested_phones:
81
  vocab = ai_context["vocab"]
82
- # Tokens techniques pour le modèle Cnam
83
  technical_tokens = ["<pad>", "<s>", "</s>", "<unk>", "|", "[PAD]", "[UNK]"]
84
  full_allowed_set = set(requested_phones + technical_tokens)
85
 
 
86
  allowed_indices = [vocab[t] for t in full_allowed_set if t in vocab]
87
 
88
  if allowed_indices:
89
- # Masque d'interdiction totale (-Infini)
90
  mask = torch.full((logits.shape[-1],), float('-inf'))
91
- # Ouverture des phonèmes autorisés
92
  mask[allowed_indices] = 0.0
93
  # Application
94
  logits = logits + mask
95
 
96
- # 6. CALCUL DU SCORE DE CONFIANCE
97
- # On transforme les logits en probabilités (0% à 100%)
98
- probs = torch.nn.functional.softmax(logits, dim=-1)
99
-
100
- # On prend la probabilité la plus haute pour chaque instant (frame)
101
- max_probs, predicted_ids = torch.max(probs, dim=-1)
102
-
103
- # Le score de confiance est la moyenne de certitude sur toute la séquence
104
- confidence_score = float(torch.mean(max_probs))
105
-
106
- # 7. Décodage
107
  transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
108
 
109
  return {
110
  "ipa": transcription,
111
- "confidence": confidence_score, # ex: 0.95
112
  "allowed_used": allowed_phones
113
  }
 
7
  from contextlib import asynccontextmanager
8
 
9
  # --- CONFIGURATION ---
10
+ # On garde uniquement la V2 car elle est meilleure sur les voix réelles
11
  MODEL_ID = "Cnam-LMSSC/wav2vec2-french-phonemizer-v2"
12
 
13
  ai_context = {}
14
 
15
  @asynccontextmanager
16
  async def lifespan(app: FastAPI):
17
+ print(f"🚀 Chargement du modèle stable {MODEL_ID}...")
18
  try:
19
  # Chargement du processeur et du modèle
20
  processor = Wav2Vec2Processor.from_pretrained(MODEL_ID)
21
  model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
22
+ model.eval() # Mode lecture seule
23
 
24
  ai_context["processor"] = processor
25
  ai_context["model"] = model
 
27
  # On stocke le vocabulaire pour le masque
28
  ai_context["vocab"] = processor.tokenizer.get_vocab()
29
 
30
+ print("✅ Modèle V2 chargé et prêt.")
31
  except Exception as e:
32
+ print(f"❌ Erreur critique : {e}")
33
  yield
34
  ai_context.clear()
35
 
 
37
 
38
  @app.get("/")
39
  def home():
40
+ return {"status": "API V2 Stable running", "model": MODEL_ID}
41
 
42
  @app.post("/transcribe")
43
  async def transcribe(
 
47
  if "model" not in ai_context:
48
  raise HTTPException(status_code=500, detail="Modèle non chargé")
49
 
50
+ # 1. Lecture Audio (Simple et robuste)
51
  try:
52
  content = await file.read()
53
+ # On lit juste le fichier, sans essayer de couper les silences (risque de bugs)
54
  audio_array, _ = librosa.load(io.BytesIO(content), sr=16000)
55
  except Exception as e:
56
+ raise HTTPException(status_code=400, detail=f"Erreur audio: {str(e)}")
57
 
58
+ # 2. Préparation Modèle
 
 
 
 
 
 
 
 
 
59
  processor = ai_context["processor"]
60
  model = ai_context["model"]
61
 
62
+ inputs = processor(audio_array, sampling_rate=16000, return_tensors="pt", padding=True)
63
 
64
+ # 3. Calcul des Logits
65
  with torch.no_grad():
66
  logits = model(inputs.input_values).logits
67
 
68
+ # --- 4. APPLICATION DU MASQUE BINAIRE ---
69
 
70
  requested_phones = [p.strip() for p in allowed_phones.split(',') if p.strip()]
71
 
72
  if requested_phones:
73
  vocab = ai_context["vocab"]
74
+ # Tokens techniques indispensables
75
  technical_tokens = ["<pad>", "<s>", "</s>", "<unk>", "|", "[PAD]", "[UNK]"]
76
  full_allowed_set = set(requested_phones + technical_tokens)
77
 
78
+ # Mapping vers les ID du modèle
79
  allowed_indices = [vocab[t] for t in full_allowed_set if t in vocab]
80
 
81
  if allowed_indices:
82
+ # On interdit tout (-Infini)
83
  mask = torch.full((logits.shape[-1],), float('-inf'))
84
+ # On autorise la liste blanche (0.0)
85
  mask[allowed_indices] = 0.0
86
  # Application
87
  logits = logits + mask
88
 
89
+ # 5. Décodage (Le gagnant prend tout)
90
+ predicted_ids = torch.argmax(logits, dim=-1)
 
 
 
 
 
 
 
 
 
91
  transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
92
 
93
  return {
94
  "ipa": transcription,
 
95
  "allowed_used": allowed_phones
96
  }