binaryMao commited on
Commit
697b4cb
·
verified ·
1 Parent(s): 5824c95

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -19
app.py CHANGED
@@ -3,11 +3,14 @@ import os, shlex, subprocess, tempfile, traceback, time, glob, gc, shutil
3
  import torch
4
  from huggingface_hub import snapshot_download
5
  from nemo.collections import asr as nemo_asr
 
 
6
  import gradio as gr
7
 
8
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
9
  SEGMENT_DURATION = 10.0
10
 
 
11
  MODELS = {
12
  "Soloba V3 (CTC)": ("RobotsMali/soloba-ctc-0.6b-v3", "ctc"),
13
  "Soloba V2 (CTC)": ("RobotsMali/soloba-ctc-0.6b-v2", "ctc"),
@@ -23,24 +26,33 @@ MODELS = {
23
  _cache = {}
24
 
25
  def get_model(name):
 
26
  if name in _cache:
27
  return _cache[name]
28
 
29
- # NETTOYAGE MÉMOIRE AGRESSIF (Crucial pour les modèles 0.6b sur CPU Free)
30
  if len(_cache) >= 1:
31
  _cache.clear()
32
  gc.collect()
33
  if torch.cuda.is_available(): torch.cuda.empty_cache()
34
 
35
- repo, _ = MODELS[name]
36
- print(f"⏳ Chargement de {name}...")
37
 
38
  folder = snapshot_download(repo, local_dir_use_symlinks=False)
39
  nemo_file = next((os.path.join(folder, f) for f in os.listdir(folder) if f.endswith(".nemo")), None)
40
 
41
- model = nemo_asr.models.ASRModel.restore_from(nemo_file, map_location=torch.device(DEVICE))
 
 
 
 
 
 
 
 
 
42
  model.eval()
43
-
44
  if DEVICE == "cuda":
45
  model = model.half()
46
 
@@ -49,15 +61,15 @@ def get_model(name):
49
 
50
  def pipeline(audio_in, model_name):
51
  if not audio_in:
52
- yield "❌ Erreur", "Veuillez fournir un fichier audio."
53
  return
54
 
55
  tmp_dir = tempfile.mkdtemp()
56
  try:
57
- yield "⏳ Préparation de l'audio...", ""
58
 
59
- wav_path = os.path.join(tmp_dir, "clean.wav")
60
- # Conversion simple et robuste
61
  subprocess.run(f"ffmpeg -y -i {shlex.quote(audio_in)} -ac 1 -ar 16000 {wav_path}", shell=True, check=True)
62
 
63
  # Segmentation
@@ -67,7 +79,7 @@ def pipeline(audio_in, model_name):
67
  valid_segments = [f for f in valid_segments if os.path.getsize(f) > 1000]
68
 
69
  if not valid_segments:
70
- yield "❌ Erreur", "Audio trop court ou illisible."
71
  return
72
 
73
  yield f"🎙️ Transcription ({len(valid_segments)} segments)...", ""
@@ -75,6 +87,7 @@ def pipeline(audio_in, model_name):
75
  model = get_model(model_name)
76
 
77
  with torch.inference_mode():
 
78
  batch_hyp = model.transcribe(
79
  valid_segments,
80
  batch_size=4,
@@ -84,31 +97,32 @@ def pipeline(audio_in, model_name):
84
 
85
  results = []
86
  for hyp in batch_hyp:
 
87
  text = hyp.text if hasattr(hyp, 'text') else str(hyp)
88
  if text: results.append(text)
89
 
90
- yield "✅ Terminé", " ".join(results)
91
 
92
  except Exception as e:
 
93
  yield "❌ Erreur", str(e)
94
  finally:
95
  if os.path.exists(tmp_dir):
96
  shutil.rmtree(tmp_dir)
97
 
98
- # --- INTERFACE (Argument show_copy_button supprimé) ---
99
- with gr.Blocks() as demo:
100
- gr.Markdown("# 🤖 RobotsMali Speech-to-Text")
101
 
102
  with gr.Row():
103
  with gr.Column():
104
- audio_input = gr.Audio(label="Audio", type="filepath")
105
  model_input = gr.Dropdown(choices=list(MODELS.keys()), value="Soloni V3 (TDT-CTC)", label="Modèle")
106
- run_btn = gr.Button("🚀 TRANSCRIRE", variant="primary")
107
 
108
  with gr.Column():
109
- status = gr.Markdown("### État : Prêt")
110
- # Correction ici : show_copy_button est retiré pour compatibilité
111
- text_output = gr.Textbox(label="Résultat", lines=15)
112
 
113
  run_btn.click(fn=pipeline, inputs=[audio_input, model_input], outputs=[status, text_output])
114
 
 
3
  import torch
4
  from huggingface_hub import snapshot_download
5
  from nemo.collections import asr as nemo_asr
6
+ # Imports spécifiques pour éviter l'erreur "Abstract Class"
7
+ from nemo.collections.asr.models import EncDecCTCModel, EncDecRNNTModel
8
  import gradio as gr
9
 
10
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
11
  SEGMENT_DURATION = 10.0
12
 
13
+ # Dictionnaire complet (Nom: (Repo, Type))
14
  MODELS = {
15
  "Soloba V3 (CTC)": ("RobotsMali/soloba-ctc-0.6b-v3", "ctc"),
16
  "Soloba V2 (CTC)": ("RobotsMali/soloba-ctc-0.6b-v2", "ctc"),
 
26
  _cache = {}
27
 
28
  def get_model(name):
29
+ """Charge le modèle en forçant la classe concrète (CTC ou RNNT)."""
30
  if name in _cache:
31
  return _cache[name]
32
 
33
+ # Libération agressive de la RAM avant chargement
34
  if len(_cache) >= 1:
35
  _cache.clear()
36
  gc.collect()
37
  if torch.cuda.is_available(): torch.cuda.empty_cache()
38
 
39
+ repo, arch_type = MODELS[name]
40
+ print(f"⏳ Préparation du modèle {name} ({arch_type})...")
41
 
42
  folder = snapshot_download(repo, local_dir_use_symlinks=False)
43
  nemo_file = next((os.path.join(folder, f) for f in os.listdir(folder) if f.endswith(".nemo")), None)
44
 
45
+ # Utilisation de la classe spécifique pour contourner l'erreur ASRModel
46
+ try:
47
+ if arch_type == "ctc":
48
+ model = EncDecCTCModel.restore_from(nemo_file, map_location=torch.device(DEVICE))
49
+ else:
50
+ model = EncDecRNNTModel.restore_from(nemo_file, map_location=torch.device(DEVICE))
51
+ except Exception as e:
52
+ print(f"⚠️ Erreur de chargement spécifique, tentative générique : {e}")
53
+ model = nemo_asr.models.ASRModel.restore_from(nemo_file, map_location=torch.device(DEVICE))
54
+
55
  model.eval()
 
56
  if DEVICE == "cuda":
57
  model = model.half()
58
 
 
61
 
62
  def pipeline(audio_in, model_name):
63
  if not audio_in:
64
+ yield "❌ Erreur", "Aucun audio détecté."
65
  return
66
 
67
  tmp_dir = tempfile.mkdtemp()
68
  try:
69
+ yield "⏳ Traitement de l'audio...", ""
70
 
71
+ # Normalisation FFmpeg
72
+ wav_path = os.path.join(tmp_dir, "input.wav")
73
  subprocess.run(f"ffmpeg -y -i {shlex.quote(audio_in)} -ac 1 -ar 16000 {wav_path}", shell=True, check=True)
74
 
75
  # Segmentation
 
79
  valid_segments = [f for f in valid_segments if os.path.getsize(f) > 1000]
80
 
81
  if not valid_segments:
82
+ yield "❌ Erreur", "Fichier audio vide ou incompatible."
83
  return
84
 
85
  yield f"🎙️ Transcription ({len(valid_segments)} segments)...", ""
 
87
  model = get_model(model_name)
88
 
89
  with torch.inference_mode():
90
+ # Mode stable sans Lhotse
91
  batch_hyp = model.transcribe(
92
  valid_segments,
93
  batch_size=4,
 
97
 
98
  results = []
99
  for hyp in batch_hyp:
100
+ # Gère les formats de sortie CTC et RNNT
101
  text = hyp.text if hasattr(hyp, 'text') else str(hyp)
102
  if text: results.append(text)
103
 
104
+ yield "✅ Succès", " ".join(results)
105
 
106
  except Exception as e:
107
+ print(traceback.format_exc())
108
  yield "❌ Erreur", str(e)
109
  finally:
110
  if os.path.exists(tmp_dir):
111
  shutil.rmtree(tmp_dir)
112
 
113
+ # --- UI GRADIO ---
114
+ with gr.Blocks(theme=gr.themes.Default()) as demo:
115
+ gr.Markdown("# 🤖 RobotsMali - Reconnaissance Vocale")
116
 
117
  with gr.Row():
118
  with gr.Column():
119
+ audio_input = gr.Audio(label="Audio", type="filepath", sources=["upload", "microphone"])
120
  model_input = gr.Dropdown(choices=list(MODELS.keys()), value="Soloni V3 (TDT-CTC)", label="Modèle")
121
+ run_btn = gr.Button("🚀 DÉMARRER", variant="primary")
122
 
123
  with gr.Column():
124
+ status = gr.Markdown("### État : En attente")
125
+ text_output = gr.Textbox(label="Transcription", lines=12)
 
126
 
127
  run_btn.click(fn=pipeline, inputs=[audio_input, model_input], outputs=[status, text_output])
128