binaryMao commited on
Commit
6a368d1
·
verified ·
1 Parent(s): 7b84d76

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +255 -38
app.py CHANGED
@@ -4,12 +4,16 @@ from huggingface_hub import snapshot_download
4
  from nemo.collections import asr as nemo_asr
5
  from nemo.collections.asr.models import EncDecCTCModel, EncDecRNNTModel
6
  import gradio as gr
 
 
 
7
 
8
  # Configuration
9
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
10
  SEGMENT_DURATION = 10.0
11
  print(f"✅ Démarrage sur device: {DEVICE}")
12
  print(f"✅ Gradio version: {gr.__version__}")
 
13
 
14
  # Dictionnaire des modèles RobotsMali
15
  MODELS = {
@@ -28,23 +32,28 @@ _cache = {}
28
 
29
  def get_model(name):
30
  if name in _cache:
 
31
  return _cache[name]
32
 
33
  print(f"📥 Chargement du modèle: {name}")
34
 
35
  # Gestion agressive de la mémoire
36
  if len(_cache) >= 1:
 
37
  _cache.clear()
38
  gc.collect()
39
  if torch.cuda.is_available():
40
  torch.cuda.empty_cache()
 
41
 
42
  try:
43
  repo, arch_type = MODELS[name]
44
  print(f"📦 Téléchargement depuis {repo}...")
45
 
 
46
  folder = snapshot_download(repo, local_dir_use_symlinks=False)
47
- print(f"📁 Dossier: {folder}")
 
48
 
49
  nemo_file = next((os.path.join(folder, f) for f in os.listdir(folder) if f.endswith(".nemo")), None)
50
 
@@ -52,17 +61,54 @@ def get_model(name):
52
  raise FileNotFoundError(f"Aucun fichier .nemo trouvé dans {folder}")
53
 
54
  print(f"🔧 Restauration du modèle depuis {nemo_file}")
 
55
 
56
- if arch_type == "ctc":
57
- model = EncDecCTCModel.restore_from(nemo_file, map_location=torch.device(DEVICE))
58
- else:
59
- model = EncDecRNNTModel.restore_from(nemo_file, map_location=torch.device(DEVICE))
60
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  model.eval()
62
  if DEVICE == "cuda":
63
  model = model.half()
 
64
 
65
  print(f"✅ Modèle {name} chargé avec succès")
 
 
 
66
  _cache[name] = model
67
  return model
68
 
@@ -71,96 +117,267 @@ def get_model(name):
71
  print(traceback.format_exc())
72
  raise e
73
 
74
- def pipeline(audio_in, model_name):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  if not audio_in:
76
- yield "❌ Erreur", "Aucun audio détecté."
77
  return
78
 
79
  tmp_dir = tempfile.mkdtemp()
80
  try:
81
- yield "⏳ Traitement de l'audio...", ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
- # Vérification que le fichier audio existe
 
 
 
 
 
 
 
 
84
  if not os.path.exists(audio_in):
85
- yield "❌ Erreur", f"Fichier audio introuvable: {audio_in}"
86
  return
87
-
88
- wav_path = os.path.join(tmp_dir, "input.wav")
89
 
90
- # Conversion audio
91
  cmd = f"ffmpeg -y -i {shlex.quote(audio_in)} -ac 1 -ar 16000 {shlex.quote(wav_path)}"
92
  result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
93
 
94
  if result.returncode != 0:
95
- yield "❌ Erreur", f"Erreur FFmpeg: {result.stderr}"
 
96
  return
97
 
98
  if not os.path.exists(wav_path) or os.path.getsize(wav_path) == 0:
99
- yield "❌ Erreur", "Fichier audio converti vide"
100
  return
101
-
102
- # Segmentation
 
 
 
 
 
 
 
103
  seg_pattern = os.path.join(tmp_dir, 'seg_%03d.wav')
104
  cmd = f"ffmpeg -i {shlex.quote(wav_path)} -f segment -segment_time {SEGMENT_DURATION} -c copy {shlex.quote(seg_pattern)}"
105
  subprocess.run(cmd, shell=True, capture_output=True)
106
 
107
  valid_segments = sorted(glob.glob(os.path.join(tmp_dir, "seg_*.wav")))
108
  if not valid_segments:
109
- yield "❌ Erreur", "Fichier audio vide ou incompatible."
110
  return
111
 
112
- print(f"🔊 {len(valid_segments)} segments à transcrire")
 
 
 
 
 
 
 
 
 
 
113
 
114
  model = get_model(model_name)
115
 
 
 
 
 
 
 
116
  with torch.inference_mode():
117
- batch_hyp = model.transcribe(valid_segments, batch_size=4, return_hypotheses=True)
118
-
119
- results = [hyp.text if hasattr(hyp, 'text') else str(hyp) for hyp in batch_hyp]
120
- yield "✅ Succès", " ".join(results)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
  except Exception as e:
123
  print(traceback.format_exc())
124
- yield "❌ Erreur", str(e)
 
125
  finally:
126
  if os.path.exists(tmp_dir):
127
  shutil.rmtree(tmp_dir)
 
128
 
129
  # Interface Gradio
130
  with gr.Blocks(title="RobotsMali ASR", theme=gr.themes.Soft()) as demo:
131
- gr.Markdown("# 🤖 RobotsMali - Reconnaissance Vocale")
 
 
 
 
 
132
 
133
  with gr.Row():
134
- with gr.Column():
135
  audio_input = gr.Audio(
136
- label="Audio",
137
  type="filepath",
138
- sources=["upload", "microphone"]
 
 
 
 
139
  )
 
140
  model_input = gr.Dropdown(
141
  choices=list(MODELS.keys()),
142
- value="Soloni V3 (TDT-CTC)",
143
- label="Modèle"
 
144
  )
145
- run_btn = gr.Button("🚀 DÉMARRER", variant="primary")
 
146
 
 
 
 
 
 
147
  with gr.Column():
148
- status = gr.Markdown("### État : En attente")
149
- text_output = gr.Textbox(label="Transcription", lines=12)
150
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  run_btn.click(
152
  fn=pipeline,
153
  inputs=[audio_input, model_input],
154
- outputs=[status, text_output]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  )
156
 
157
  # Point d'entrée - Configuration pour Hugging Face Spaces
158
  if __name__ == "__main__":
159
  print("🚀 Lancement de l'application RobotsMali ASR...")
 
 
160
 
161
  # Configuration simple pour Spaces
162
  demo.queue().launch(
163
  server_name="0.0.0.0",
164
- server_port=7860
165
- # Pas de show_api, pas de share - juste l'essentiel
166
  )
 
4
  from nemo.collections import asr as nemo_asr
5
  from nemo.collections.asr.models import EncDecCTCModel, EncDecRNNTModel
6
  import gradio as gr
7
+ import time
8
+ import psutil
9
+ import humanize
10
 
11
  # Configuration
12
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
13
  SEGMENT_DURATION = 10.0
14
  print(f"✅ Démarrage sur device: {DEVICE}")
15
  print(f"✅ Gradio version: {gr.__version__}")
16
+ print(f"✅ Mémoire disponible: {humanize.naturalsize(psutil.virtual_memory().available)}")
17
 
18
  # Dictionnaire des modèles RobotsMali
19
  MODELS = {
 
32
 
33
  def get_model(name):
34
  if name in _cache:
35
+ print(f"✅ Modèle {name} déjà en cache")
36
  return _cache[name]
37
 
38
  print(f"📥 Chargement du modèle: {name}")
39
 
40
  # Gestion agressive de la mémoire
41
  if len(_cache) >= 1:
42
+ print("🧹 Nettoyage du cache...")
43
  _cache.clear()
44
  gc.collect()
45
  if torch.cuda.is_available():
46
  torch.cuda.empty_cache()
47
+ print(f"🧹 Mémoire GPU libérée: {torch.cuda.memory_allocated()/1e9:.2f}GB")
48
 
49
  try:
50
  repo, arch_type = MODELS[name]
51
  print(f"📦 Téléchargement depuis {repo}...")
52
 
53
+ start_time = time.time()
54
  folder = snapshot_download(repo, local_dir_use_symlinks=False)
55
+ download_time = time.time() - start_time
56
+ print(f"📁 Dossier: {folder} (téléchargé en {download_time:.1f}s)")
57
 
58
  nemo_file = next((os.path.join(folder, f) for f in os.listdir(folder) if f.endswith(".nemo")), None)
59
 
 
61
  raise FileNotFoundError(f"Aucun fichier .nemo trouvé dans {folder}")
62
 
63
  print(f"🔧 Restauration du modèle depuis {nemo_file}")
64
+ print(f"📊 Taille du fichier: {humanize.naturalsize(os.path.getsize(nemo_file))}")
65
 
66
+ # Chargement avec gestion d'erreur
67
+ try:
68
+ if arch_type == "ctc":
69
+ model = EncDecCTCModel.restore_from(nemo_file, map_location=torch.device(DEVICE))
70
+ else:
71
+ model = EncDecRNNTModel.restore_from(nemo_file, map_location=torch.device(DEVICE))
72
+ except Exception as e:
73
+ print(f"⚠️ Erreur de chargement spécifique, tentative avec ASRModel: {e}")
74
+ model = nemo_asr.models.ASRModel.restore_from(nemo_file, map_location=torch.device(DEVICE))
75
+
76
+ # === CORRECTION : Patch de la configuration pour éviter l'erreur key_phrase_items_list ===
77
+ try:
78
+ if hasattr(model, 'cfg'):
79
+ def remove_key_phrase_items_list(config):
80
+ if isinstance(config, dict):
81
+ if 'key_phrase_items_list' in config:
82
+ del config['key_phrase_items_list']
83
+ print("✅ Clé problématique key_phrase_items_list supprimée")
84
+ for key, value in config.items():
85
+ if isinstance(value, (dict, list)):
86
+ remove_key_phrase_items_list(value)
87
+ elif isinstance(config, list):
88
+ for item in config:
89
+ if isinstance(item, (dict, list)):
90
+ remove_key_phrase_items_list(item)
91
+
92
+ remove_key_phrase_items_list(model.cfg)
93
+
94
+ # Désactiver boosting_tree si présent
95
+ if 'decoding' in model.cfg:
96
+ if 'greedy' in model.cfg.decoding:
97
+ if 'boosting_tree' in model.cfg.decoding.greedy:
98
+ model.cfg.decoding.greedy.boosting_tree = None
99
+ print("✅ Boosting_tree désactivé")
100
+ except Exception as e:
101
+ print(f"⚠️ Avertissement lors du patch de config: {e}")
102
+
103
  model.eval()
104
  if DEVICE == "cuda":
105
  model = model.half()
106
+ print(f"🎯 Modèle converti en half precision")
107
 
108
  print(f"✅ Modèle {name} chargé avec succès")
109
+ if DEVICE == "cuda":
110
+ print(f"📊 Mémoire GPU utilisée: {torch.cuda.memory_allocated()/1e9:.2f}GB")
111
+
112
  _cache[name] = model
113
  return model
114
 
 
117
  print(traceback.format_exc())
118
  raise e
119
 
120
+ def get_audio_info(filepath):
121
+ """Récupère les informations d'un fichier audio"""
122
+ try:
123
+ cmd = f"ffprobe -v quiet -print_format json -show_format -show_streams {shlex.quote(filepath)}"
124
+ result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
125
+ if result.returncode == 0:
126
+ import json
127
+ info = json.loads(result.stdout)
128
+ streams = info.get('streams', [])
129
+ audio_stream = next((s for s in streams if s.get('codec_type') == 'audio'), None)
130
+ if audio_stream:
131
+ duration = float(info['format'].get('duration', 0))
132
+ return {
133
+ 'duration': duration,
134
+ 'sample_rate': audio_stream.get('sample_rate', '?'),
135
+ 'channels': audio_stream.get('channels', '?'),
136
+ 'codec': audio_stream.get('codec_name', '?'),
137
+ 'size': humanize.naturalsize(int(info['format'].get('size', 0)))
138
+ }
139
+ except:
140
+ pass
141
+ return None
142
+
143
+ def format_time(seconds):
144
+ """Formate le temps en MM:SS"""
145
+ minutes = int(seconds // 60)
146
+ seconds = int(seconds % 60)
147
+ return f"{minutes:02d}:{seconds:02d}"
148
+
149
+ def pipeline(audio_in, model_name, progress=gr.Progress()):
150
  if not audio_in:
151
+ yield "❌ Erreur", "Aucun audio détecté.", gr.update(visible=False)
152
  return
153
 
154
  tmp_dir = tempfile.mkdtemp()
155
  try:
156
+ # === PHASE 1: Analyse de l'audio original ===
157
+ yield "🔍 Analyse du fichier audio...", "", gr.update(visible=False)
158
+
159
+ audio_info = get_audio_info(audio_in)
160
+ if audio_info:
161
+ duration = audio_info['duration']
162
+ duration_str = format_time(duration)
163
+ info_text = f"""
164
+ 📊 **Informations audio :**
165
+ - Durée : {duration_str} ({duration:.1f} secondes)
166
+ - Taille : {audio_info['size']}
167
+ - Fréquence : {audio_info['sample_rate']} Hz
168
+ - Canaux : {audio_info['channels']}
169
+ - Codec : {audio_info['codec']}
170
+ """
171
+ else:
172
+ duration = 0
173
+ info_text = "ℹ️ Impossible de lire les métadonnées audio"
174
 
175
+ yield f"⏳ Préparation... ({duration_str if duration > 0 else '??'})", info_text, gr.update(visible=False)
176
+
177
+ # === PHASE 2: Conversion audio ===
178
+ yield f"🔄 Conversion audio (étape 1/3)...", info_text, gr.update(visible=False)
179
+ progress(0.1, desc="Conversion audio...")
180
+
181
+ wav_path = os.path.join(tmp_dir, "input.wav")
182
+
183
+ # Vérification que le fichier source existe
184
  if not os.path.exists(audio_in):
185
+ yield "❌ Erreur", f"Fichier audio introuvable: {audio_in}", gr.update(visible=False)
186
  return
 
 
187
 
188
+ # Conversion avec FFmpeg
189
  cmd = f"ffmpeg -y -i {shlex.quote(audio_in)} -ac 1 -ar 16000 {shlex.quote(wav_path)}"
190
  result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
191
 
192
  if result.returncode != 0:
193
+ error_msg = f"Erreur FFmpeg: {result.stderr[:200]}..."
194
+ yield "❌ Erreur", error_msg, gr.update(visible=False)
195
  return
196
 
197
  if not os.path.exists(wav_path) or os.path.getsize(wav_path) == 0:
198
+ yield "❌ Erreur", "Fichier audio converti vide", gr.update(visible=False)
199
  return
200
+
201
+ # Info sur l'audio converti
202
+ converted_size = humanize.naturalsize(os.path.getsize(wav_path))
203
+ info_text += f"\n- Après conversion : {converted_size}"
204
+
205
+ # === PHASE 3: Segmentation ===
206
+ yield f"✂️ Segmentation audio (étape 2/3)...", info_text, gr.update(visible=False)
207
+ progress(0.3, desc="Segmentation...")
208
+
209
  seg_pattern = os.path.join(tmp_dir, 'seg_%03d.wav')
210
  cmd = f"ffmpeg -i {shlex.quote(wav_path)} -f segment -segment_time {SEGMENT_DURATION} -c copy {shlex.quote(seg_pattern)}"
211
  subprocess.run(cmd, shell=True, capture_output=True)
212
 
213
  valid_segments = sorted(glob.glob(os.path.join(tmp_dir, "seg_*.wav")))
214
  if not valid_segments:
215
+ yield "❌ Erreur", "Fichier audio vide ou incompatible après segmentation.", gr.update(visible=False)
216
  return
217
 
218
+ nb_segments = len(valid_segments)
219
+ estimated_time = nb_segments * 2 # Estimation ~2 secondes par segment
220
+
221
+ info_text += f"\n📦 **Segments :** {nb_segments} (durée ~{format_time(duration)})"
222
+ yield f"🎯 Transcription en cours (étape 3/3)... {nb_segments} segments", info_text, gr.update(visible=False)
223
+ progress(0.5, desc=f"Transcription de {nb_segments} segments...")
224
+
225
+ print(f"🔊 {nb_segments} segments à transcrire")
226
+
227
+ # === PHASE 4: Chargement du modèle ===
228
+ yield f"🤖 Chargement du modèle {model_name}...", info_text, gr.update(visible=False)
229
 
230
  model = get_model(model_name)
231
 
232
+ # === PHASE 5: Transcription ===
233
+ yield f"📝 Transcription en cours... (0/{nb_segments})", info_text, gr.update(visible=False)
234
+
235
+ all_results = []
236
+ batch_size = 4
237
+
238
  with torch.inference_mode():
239
+ for i in range(0, len(valid_segments), batch_size):
240
+ batch = valid_segments[i:i+batch_size]
241
+ batch_hyp = model.transcribe(batch, batch_size=len(batch), return_hypotheses=True)
242
+
243
+ batch_results = [hyp.text if hasattr(hyp, 'text') else str(hyp) for hyp in batch_hyp]
244
+ all_results.extend(batch_results)
245
+
246
+ # Mise à jour de la progression
247
+ processed = min(i + batch_size, nb_segments)
248
+ progress_val = 0.5 + (0.5 * processed / nb_segments)
249
+ progress(progress_val, desc=f"Transcription {processed}/{nb_segments}")
250
+
251
+ # Afficher les résultats partiels
252
+ partial_text = " ".join(all_results)
253
+ yield f"📝 Transcription en cours... ({processed}/{nb_segments})", info_text, gr.update(value=partial_text, visible=True)
254
+
255
+ final_text = " ".join(all_results)
256
+
257
+ # === FIN ===
258
+ success_text = f"""
259
+ ✅ **Transcription terminée !**
260
+ - Modèle : {model_name}
261
+ - Durée audio : {format_time(duration) if duration > 0 else '?'}
262
+ - Segments : {nb_segments}
263
+
264
+ {info_text}
265
+ """
266
+
267
+ yield "✅ Succès", success_text, gr.update(value=final_text, visible=True)
268
 
269
  except Exception as e:
270
  print(traceback.format_exc())
271
+ error_msg = f"❌ Erreur: {str(e)}"
272
+ yield error_msg, "", gr.update(visible=False)
273
  finally:
274
  if os.path.exists(tmp_dir):
275
  shutil.rmtree(tmp_dir)
276
+ print(f"🧹 Nettoyage du répertoire temporaire: {tmp_dir}")
277
 
278
  # Interface Gradio
279
  with gr.Blocks(title="RobotsMali ASR", theme=gr.themes.Soft()) as demo:
280
+ gr.Markdown("""
281
+ # 🤖 RobotsMali - Reconnaissance Vocale
282
+ **Transcription automatique de l'audio en texte** avec les modèles de RobotsMali.
283
+
284
+ Choisissez un modèle, uploadez un fichier audio ou enregistrez-vous, et lancez la transcription !
285
+ """)
286
 
287
  with gr.Row():
288
+ with gr.Column(scale=1):
289
  audio_input = gr.Audio(
290
+ label="🎤 Audio",
291
  type="filepath",
292
+ sources=["upload", "microphone"],
293
+ waveform_options=gr.WaveformOptions(
294
+ waveform_color="#3498db",
295
+ waveform_progress_color="#2ecc71",
296
+ )
297
  )
298
+
299
  model_input = gr.Dropdown(
300
  choices=list(MODELS.keys()),
301
+ value=list(MODELS.keys())[4], # Soloni V3 par défaut
302
+ label="🧠 Modèle",
303
+ info="Sélectionnez le modèle de transcription"
304
  )
305
+
306
+ run_btn = gr.Button("🚀 DÉMARRER LA TRANSCRIPTION", variant="primary", size="lg")
307
 
308
+ with gr.Column(scale=1):
309
+ status = gr.Markdown("### 📊 État : En attente")
310
+ audio_info = gr.Markdown("ℹ️ Chargez un audio pour voir ses informations")
311
+
312
+ with gr.Row():
313
  with gr.Column():
314
+ text_output = gr.Textbox(
315
+ label="📝 Transcription",
316
+ lines=8,
317
+ placeholder="La transcription apparaîtra ici...",
318
+ interactive=False,
319
+ visible=False
320
+ )
321
+
322
+ with gr.Row():
323
+ gr.Examples(
324
+ examples=[
325
+ ["exemples/audio1.wav"],
326
+ ["exemples/audio2.mp3"],
327
+ ],
328
+ inputs=audio_input,
329
+ label="📋 Exemples (si disponibles)"
330
+ )
331
+
332
+ # Footer
333
+ gr.Markdown("""
334
+ ---
335
+ ### 📌 Notes
336
+ - Les fichiers audio sont traités localement et supprimés après transcription
337
+ - Durée maximale recommandée : 5 minutes
338
+ - Modèles entraînés par [RobotsMali](https://huggingface.co/RobotsMali)
339
+ """)
340
+
341
+ # Gestionnaire d'événements
342
  run_btn.click(
343
  fn=pipeline,
344
  inputs=[audio_input, model_input],
345
+ outputs=[status, audio_info, text_output]
346
+ )
347
+
348
+ # Afficher les infos audio quand un fichier est chargé
349
+ def on_audio_upload(audio):
350
+ if audio:
351
+ info = get_audio_info(audio)
352
+ if info:
353
+ duration_str = format_time(info['duration'])
354
+ return f"""
355
+ ### ✅ Audio chargé
356
+ - Durée : {duration_str} ({info['duration']:.1f}s)
357
+ - Taille : {info['size']}
358
+ - Format : {info['sample_rate']}Hz, {info['channels']} canaux
359
+
360
+ Cliquez sur **DÉMARRER** pour transcrire !
361
+ """
362
+ else:
363
+ return "✅ Audio chargé. Prêt pour la transcription."
364
+ return "ℹ️ Chargez un audio pour voir ses informations"
365
+
366
+ audio_input.change(
367
+ fn=on_audio_upload,
368
+ inputs=audio_input,
369
+ outputs=audio_info
370
  )
371
 
372
  # Point d'entrée - Configuration pour Hugging Face Spaces
373
  if __name__ == "__main__":
374
  print("🚀 Lancement de l'application RobotsMali ASR...")
375
+ print(f"📊 Mémoire système: {humanize.naturalsize(psutil.virtual_memory().total)}")
376
+ print(f"🎯 Device: {DEVICE}")
377
 
378
  # Configuration simple pour Spaces
379
  demo.queue().launch(
380
  server_name="0.0.0.0",
381
+ server_port=7860,
382
+ show_error=True
383
  )