proz committed
Commit e1da5ec · verified · parent: cb4c63f

Update app.py

Files changed (1)
  1. app.py +12 -23
app.py CHANGED
@@ -6,28 +6,24 @@ import io
 import numpy as np
 from contextlib import asynccontextmanager
 
-# --- CONFIGURATION ---
-# Keep only V2, since it performs better on real voices
-MODEL_ID = "Cnam-LMSSC/wav2vec2-french-phonemizer-v2"
+# --- V1 CONFIGURATION ---
+MODEL_ID = "Cnam-LMSSC/wav2vec2-french-phonemizer"
 
 ai_context = {}
 
 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    print(f"🚀 Loading stable model {MODEL_ID}...")
+    print(f"🚀 Loading V1 model {MODEL_ID}...")
     try:
-        # Load the processor and the model
         processor = Wav2Vec2Processor.from_pretrained(MODEL_ID)
         model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
-        model.eval()  # Inference (read-only) mode
+        model.eval()
 
         ai_context["processor"] = processor
         ai_context["model"] = model
-
-        # Store the vocabulary for the mask
         ai_context["vocab"] = processor.tokenizer.get_vocab()
 
-        print("✅ V2 model loaded and ready.")
+        print("✅ V1 model ready.")
     except Exception as e:
         print(f"❌ Critical error: {e}")
     yield
@@ -37,7 +33,7 @@ app = FastAPI(lifespan=lifespan)
 
 @app.get("/")
 def home():
-    return {"status": "API V2 Stable running", "model": MODEL_ID}
+    return {"status": "API V1 running", "model": MODEL_ID}
 
 @app.post("/transcribe")
 async def transcribe(
@@ -47,46 +43,39 @@ async def transcribe(
     if "model" not in ai_context:
         raise HTTPException(status_code=500, detail="Model not loaded")
 
-    # 1. Audio loading (simple and robust)
+    # 1. Audio loading
     try:
         content = await file.read()
-        # Just read the file; no silence trimming (bug-prone)
         audio_array, _ = librosa.load(io.BytesIO(content), sr=16000)
     except Exception as e:
         raise HTTPException(status_code=400, detail=f"Audio error: {str(e)}")
 
-    # 2. Model preparation
+    # 2. Preparation
     processor = ai_context["processor"]
     model = ai_context["model"]
-
     inputs = processor(audio_array, sampling_rate=16000, return_tensors="pt", padding=True)
 
-    # 3. Logits computation
+    # 3. Logits
     with torch.no_grad():
         logits = model(inputs.input_values).logits
 
-    # --- 4. BINARY MASK APPLICATION ---
-
+    # 4. Strict binary mask
     requested_phones = [p.strip() for p in allowed_phones.split(',') if p.strip()]
 
     if requested_phones:
         vocab = ai_context["vocab"]
-        # Essential technical tokens
+        # V1 technical tokens (the V1 model makes heavy use of the pipe '|')
         technical_tokens = ["<pad>", "<s>", "</s>", "<unk>", "|", "[PAD]", "[UNK]"]
         full_allowed_set = set(requested_phones + technical_tokens)
 
-        # Map to the model's token IDs
         allowed_indices = [vocab[t] for t in full_allowed_set if t in vocab]
 
         if allowed_indices:
-            # Forbid everything (-inf)
            mask = torch.full((logits.shape[-1],), float('-inf'))
-            # Allow the whitelist (0.0)
            mask[allowed_indices] = 0.0
-            # Apply the mask
            logits = logits + mask
 
-    # 5. Decoding (winner takes all)
+    # 5. Decoding
     predicted_ids = torch.argmax(logits, dim=-1)
     transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
 
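For context on step 4 above: the whitelist works by adding a per-vocabulary mask of 0.0 (allowed) or -inf (forbidden) to the logits of every CTC frame, so the argmax in step 5 can only ever pick whitelisted tokens. A minimal, self-contained sketch of the same technique; the shapes and token IDs here are toy values for illustration, not taken from the real model:

import torch

# Toy logits: batch of 1, 3 CTC frames, vocabulary of 5 tokens (illustrative)
logits = torch.randn(1, 3, 5)

# Forbid everything (-inf), then re-allow a whitelist of token IDs (0.0)
allowed_indices = [0, 2]  # hypothetical whitelisted IDs
mask = torch.full((logits.shape[-1],), float('-inf'))
mask[allowed_indices] = 0.0

# Broadcasting adds the mask to every frame; -inf plus any finite
# logit stays -inf, so forbidden tokens can never win the argmax
masked = logits + mask
predicted_ids = torch.argmax(masked, dim=-1)
assert set(predicted_ids.flatten().tolist()) <= set(allowed_indices)
print(predicted_ids)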
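And a hedged usage sketch for the /transcribe endpoint: the field names file and allowed_phones match the handler above, but the host/port, the audio file, and how allowed_phones is declared (form field vs. query parameter; the signature is truncated in this diff) are assumptions:

import requests

URL = "http://localhost:8000/transcribe"  # assumed local dev server

with open("sample.wav", "rb") as f:  # hypothetical audio file
    resp = requests.post(
        URL,
        files={"file": ("sample.wav", f, "audio/wav")},
        # Comma-separated whitelist, parsed server-side via allowed_phones.split(',')
        # If allowed_phones is a query parameter rather than a form field,
        # pass it via params= instead of data=.
        data={"allowed_phones": "a,e,i,o,u"},
    )

print(resp.status_code, resp.json())

The phoneme list here is illustrative; a real whitelist must use tokens present in the model's vocabulary, since unknown tokens are silently dropped by the `if t in vocab` filter.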