binaryMao committed on
Commit
55e23fc
·
verified ·
1 Parent(s): 78be54b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +86 -37
app.py CHANGED
@@ -1,18 +1,14 @@
1
- import os, shlex, subprocess, tempfile, traceback, time, glob, gc, shutil
2
  import torch
3
  from huggingface_hub import snapshot_download
4
  from nemo.collections import asr as nemo_asr
5
  from nemo.collections.asr.models import EncDecCTCModel, EncDecRNNTModel
6
  import gradio as gr
7
 
8
- # Mise à jour de Gradio (optionnel, peut être commenté après la première exécution)
9
- try:
10
- subprocess.check_call([sys.executable, "-m", "pip", "install", "--upgrade", "gradio", "gradio-client"])
11
- except:
12
- pass
13
-
14
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
15
  SEGMENT_DURATION = 10.0
 
16
 
17
  # Dictionnaire des modèles RobotsMali
18
  MODELS = {
@@ -33,30 +29,46 @@ def get_model(name):
33
  if name in _cache:
34
  return _cache[name]
35
 
 
 
36
  # Gestion agressive de la mémoire
37
  if len(_cache) >= 1:
38
  _cache.clear()
39
  gc.collect()
40
- if torch.cuda.is_available(): torch.cuda.empty_cache()
41
-
42
- repo, arch_type = MODELS[name]
43
- folder = snapshot_download(repo, local_dir_use_symlinks=False)
44
- nemo_file = next((os.path.join(folder, f) for f in os.listdir(folder) if f.endswith(".nemo")), None)
45
 
46
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  if arch_type == "ctc":
48
  model = EncDecCTCModel.restore_from(nemo_file, map_location=torch.device(DEVICE))
49
  else:
50
  model = EncDecRNNTModel.restore_from(nemo_file, map_location=torch.device(DEVICE))
51
- except Exception:
52
- model = nemo_asr.models.ASRModel.restore_from(nemo_file, map_location=torch.device(DEVICE))
53
-
54
- model.eval()
55
- if DEVICE == "cuda":
56
- model = model.half()
57
-
58
- _cache[name] = model
59
- return model
 
 
 
 
60
 
61
  def pipeline(audio_in, model_name):
62
  if not audio_in:
@@ -66,17 +78,40 @@ def pipeline(audio_in, model_name):
66
  tmp_dir = tempfile.mkdtemp()
67
  try:
68
  yield "⏳ Traitement de l'audio...", ""
 
 
 
 
 
 
69
  wav_path = os.path.join(tmp_dir, "input.wav")
70
- # Normalisation audio via FFmpeg
71
- subprocess.run(f"ffmpeg -y -i {shlex.quote(audio_in)} -ac 1 -ar 16000 {wav_path}", shell=True, check=True)
72
- subprocess.run(f"ffmpeg -i {wav_path} -f segment -segment_time {SEGMENT_DURATION} -c copy {os.path.join(tmp_dir, 'seg_%03d.wav')}", shell=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
  valid_segments = sorted(glob.glob(os.path.join(tmp_dir, "seg_*.wav")))
75
  if not valid_segments:
76
  yield "❌ Erreur", "Fichier audio vide ou incompatible."
77
  return
78
-
 
 
79
  model = get_model(model_name)
 
80
  with torch.inference_mode():
81
  batch_hyp = model.transcribe(valid_segments, batch_size=4, return_hypotheses=True)
82
 
@@ -87,28 +122,42 @@ def pipeline(audio_in, model_name):
87
  print(traceback.format_exc())
88
  yield "❌ Erreur", str(e)
89
  finally:
90
- if os.path.exists(tmp_dir): shutil.rmtree(tmp_dir)
 
91
 
92
- # Interface Gradio optimisée
93
- with gr.Blocks(title="RobotsMali ASR") as demo:
94
  gr.Markdown("# 🤖 RobotsMali - Reconnaissance Vocale")
 
95
  with gr.Row():
96
  with gr.Column():
97
- audio_input = gr.Audio(label="Audio", type="filepath", sources=["upload", "microphone"])
98
- model_input = gr.Dropdown(choices=list(MODELS.keys()), value="Soloni V3 (TDT-CTC)", label="Modèle")
 
 
 
 
 
 
 
 
99
  run_btn = gr.Button("🚀 DÉMARRER", variant="primary")
 
100
  with gr.Column():
101
  status = gr.Markdown("### État : En attente")
102
  text_output = gr.Textbox(label="Transcription", lines=12)
103
 
104
- run_btn.click(fn=pipeline, inputs=[audio_input, model_input], outputs=[status, text_output])
 
 
 
 
105
 
 
106
  if __name__ == "__main__":
107
- # Configuration pour les environnements cloud
 
108
  demo.queue().launch(
109
- server_name="0.0.0.0",
110
- server_port=7860,
111
- show_api=False,
112
- share=True, # Important : crée un lien public
113
- debug=True # Ajoutez ceci pour voir plus de détails en cas d'erreur
114
  )
 
1
+ import os, shlex, subprocess, tempfile, traceback, glob, gc, shutil
2
  import torch
3
  from huggingface_hub import snapshot_download
4
  from nemo.collections import asr as nemo_asr
5
  from nemo.collections.asr.models import EncDecCTCModel, EncDecRNNTModel
6
  import gradio as gr
7
 
8
+ # Configuration
 
 
 
 
 
9
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
10
  SEGMENT_DURATION = 10.0
11
+ print(f"✅ Démarrage sur device: {DEVICE}")
12
 
13
  # Dictionnaire des modèles RobotsMali
14
  MODELS = {
 
29
  if name in _cache:
30
  return _cache[name]
31
 
32
+ print(f"📥 Chargement du modèle: {name}")
33
+
34
  # Gestion agressive de la mémoire
35
  if len(_cache) >= 1:
36
  _cache.clear()
37
  gc.collect()
38
+ if torch.cuda.is_available():
39
+ torch.cuda.empty_cache()
 
 
 
40
 
41
  try:
42
+ repo, arch_type = MODELS[name]
43
+ print(f"📦 Téléchargement depuis {repo}...")
44
+
45
+ folder = snapshot_download(repo, local_dir_use_symlinks=False)
46
+ print(f"📁 Dossier: {folder}")
47
+
48
+ nemo_file = next((os.path.join(folder, f) for f in os.listdir(folder) if f.endswith(".nemo")), None)
49
+
50
+ if nemo_file is None:
51
+ raise FileNotFoundError(f"Aucun fichier .nemo trouvé dans {folder}")
52
+
53
+ print(f"🔧 Restauration du modèle depuis {nemo_file}")
54
+
55
  if arch_type == "ctc":
56
  model = EncDecCTCModel.restore_from(nemo_file, map_location=torch.device(DEVICE))
57
  else:
58
  model = EncDecRNNTModel.restore_from(nemo_file, map_location=torch.device(DEVICE))
59
+
60
+ model.eval()
61
+ if DEVICE == "cuda":
62
+ model = model.half()
63
+
64
+ print(f"✅ Modèle {name} chargé avec succès")
65
+ _cache[name] = model
66
+ return model
67
+
68
+ except Exception as e:
69
+ print(f"❌ Erreur lors du chargement du modèle {name}:")
70
+ print(traceback.format_exc())
71
+ raise e
72
 
73
  def pipeline(audio_in, model_name):
74
  if not audio_in:
 
78
  tmp_dir = tempfile.mkdtemp()
79
  try:
80
  yield "⏳ Traitement de l'audio...", ""
81
+
82
+ # Vérification que le fichier audio existe
83
+ if not os.path.exists(audio_in):
84
+ yield "❌ Erreur", f"Fichier audio introuvable: {audio_in}"
85
+ return
86
+
87
  wav_path = os.path.join(tmp_dir, "input.wav")
88
+
89
+ # Conversion audio
90
+ cmd = f"ffmpeg -y -i {shlex.quote(audio_in)} -ac 1 -ar 16000 {shlex.quote(wav_path)}"
91
+ result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
92
+
93
+ if result.returncode != 0:
94
+ yield "❌ Erreur", f"Erreur FFmpeg: {result.stderr}"
95
+ return
96
+
97
+ if not os.path.exists(wav_path) or os.path.getsize(wav_path) == 0:
98
+ yield "❌ Erreur", "Fichier audio converti vide"
99
+ return
100
+
101
+ # Segmentation
102
+ seg_pattern = os.path.join(tmp_dir, 'seg_%03d.wav')
103
+ cmd = f"ffmpeg -i {shlex.quote(wav_path)} -f segment -segment_time {SEGMENT_DURATION} -c copy {shlex.quote(seg_pattern)}"
104
+ subprocess.run(cmd, shell=True, capture_output=True)
105
 
106
  valid_segments = sorted(glob.glob(os.path.join(tmp_dir, "seg_*.wav")))
107
  if not valid_segments:
108
  yield "❌ Erreur", "Fichier audio vide ou incompatible."
109
  return
110
+
111
+ print(f"🔊 {len(valid_segments)} segments à transcrire")
112
+
113
  model = get_model(model_name)
114
+
115
  with torch.inference_mode():
116
  batch_hyp = model.transcribe(valid_segments, batch_size=4, return_hypotheses=True)
117
 
 
122
  print(traceback.format_exc())
123
  yield "❌ Erreur", str(e)
124
  finally:
125
+ if os.path.exists(tmp_dir):
126
+ shutil.rmtree(tmp_dir)
127
 
128
+ # Interface Gradio
129
+ with gr.Blocks(title="RobotsMali ASR", theme=gr.themes.Soft()) as demo:
130
  gr.Markdown("# 🤖 RobotsMali - Reconnaissance Vocale")
131
+
132
  with gr.Row():
133
  with gr.Column():
134
+ audio_input = gr.Audio(
135
+ label="Audio",
136
+ type="filepath",
137
+ sources=["upload", "microphone"]
138
+ )
139
+ model_input = gr.Dropdown(
140
+ choices=list(MODELS.keys()),
141
+ value="Soloni V3 (TDT-CTC)",
142
+ label="Modèle"
143
+ )
144
  run_btn = gr.Button("🚀 DÉMARRER", variant="primary")
145
+
146
  with gr.Column():
147
  status = gr.Markdown("### État : En attente")
148
  text_output = gr.Textbox(label="Transcription", lines=12)
149
 
150
+ run_btn.click(
151
+ fn=pipeline,
152
+ inputs=[audio_input, model_input],
153
+ outputs=[status, text_output]
154
+ )
155
 
156
+ # Point d'entrée - CORRECTION ICI
157
  if __name__ == "__main__":
158
+ print("🚀 Lancement de l'application RobotsMali ASR...")
159
+ # Désactiver l'API pour éviter le bug
160
  demo.queue().launch(
161
+ server_name="0.0.0.0",
162
+ show_api=False # ← Ceci corrige l'erreur
 
 
 
163
  )