binaryMao committed on
Commit
5c26bdc
·
verified ·
1 Parent(s): 3fc0845

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -42
app.py CHANGED
@@ -1,16 +1,13 @@
1
- # -*- coding: utf-8 -*-
2
  import os, shlex, subprocess, tempfile, traceback, time, glob, gc, shutil
3
  import torch
4
  from huggingface_hub import snapshot_download
5
  from nemo.collections import asr as nemo_asr
6
- # Imports spécifiques pour éviter l'erreur "Abstract Class"
7
  from nemo.collections.asr.models import EncDecCTCModel, EncDecRNNTModel
8
  import gradio as gr
9
 
10
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
11
- SEGMENT_DURATION = 10.0
12
 
13
- # Dictionnaire complet (Nom: (Repo, Type))
14
  MODELS = {
15
  "Soloba V3 (CTC)": ("RobotsMali/soloba-ctc-0.6b-v3", "ctc"),
16
  "Soloba V2 (CTC)": ("RobotsMali/soloba-ctc-0.6b-v2", "ctc"),
@@ -26,36 +23,30 @@ MODELS = {
26
  _cache = {}
27
 
28
  def get_model(name):
29
- """Charge le modèle en forçant la classe concrète (CTC ou RNNT)."""
30
  if name in _cache:
31
  return _cache[name]
32
 
33
- # Libération agressive de la RAM avant chargement
34
  if len(_cache) >= 1:
35
  _cache.clear()
36
  gc.collect()
37
  if torch.cuda.is_available(): torch.cuda.empty_cache()
38
 
39
  repo, arch_type = MODELS[name]
40
- print(f"⏳ Préparation du modèle {name} ({arch_type})...")
41
-
42
  folder = snapshot_download(repo, local_dir_use_symlinks=False)
43
  nemo_file = next((os.path.join(folder, f) for f in os.listdir(folder) if f.endswith(".nemo")), None)
44
 
45
- # Utilisation de la classe spécifique pour contourner l'erreur ASRModel
46
  try:
47
  if arch_type == "ctc":
48
  model = EncDecCTCModel.restore_from(nemo_file, map_location=torch.device(DEVICE))
49
  else:
50
  model = EncDecRNNTModel.restore_from(nemo_file, map_location=torch.device(DEVICE))
51
- except Exception as e:
52
- print(f"⚠️ Erreur de chargement spécifique, tentative générique : {e}")
53
  model = nemo_asr.models.ASRModel.restore_from(nemo_file, map_location=torch.device(DEVICE))
54
 
55
  model.eval()
56
  if DEVICE == "cuda":
57
  model = model.half()
58
-
59
  _cache[name] = model
60
  return model
61
 
@@ -67,59 +58,34 @@ def pipeline(audio_in, model_name):
67
  tmp_dir = tempfile.mkdtemp()
68
  try:
69
  yield "⏳ Traitement de l'audio...", ""
70
-
71
- # Normalisation FFmpeg
72
  wav_path = os.path.join(tmp_dir, "input.wav")
73
  subprocess.run(f"ffmpeg -y -i {shlex.quote(audio_in)} -ac 1 -ar 16000 {wav_path}", shell=True, check=True)
74
-
75
- # Segmentation
76
  subprocess.run(f"ffmpeg -i {wav_path} -f segment -segment_time {SEGMENT_DURATION} -c copy {os.path.join(tmp_dir, 'seg_%03d.wav')}", shell=True)
77
 
78
  valid_segments = sorted(glob.glob(os.path.join(tmp_dir, "seg_*.wav")))
79
- valid_segments = [f for f in valid_segments if os.path.getsize(f) > 1000]
80
-
81
  if not valid_segments:
82
  yield "❌ Erreur", "Fichier audio vide ou incompatible."
83
  return
84
 
85
- yield f"🎙️ Transcription ({len(valid_segments)} segments)...", ""
86
-
87
  model = get_model(model_name)
88
-
89
  with torch.inference_mode():
90
- # Mode stable sans Lhotse
91
- batch_hyp = model.transcribe(
92
- valid_segments,
93
- batch_size=4,
94
- return_hypotheses=True,
95
- num_workers=0
96
- )
97
-
98
- results = []
99
- for hyp in batch_hyp:
100
- # Gère les formats de sortie CTC et RNNT
101
- text = hyp.text if hasattr(hyp, 'text') else str(hyp)
102
- if text: results.append(text)
103
 
 
104
  yield "✅ Succès", " ".join(results)
105
 
106
  except Exception as e:
107
- print(traceback.format_exc())
108
  yield "❌ Erreur", str(e)
109
  finally:
110
- if os.path.exists(tmp_dir):
111
- shutil.rmtree(tmp_dir)
112
 
113
- # --- UI GRADIO ---
114
- with gr.Blocks(theme=gr.themes.Default()) as demo:
115
  gr.Markdown("# 🤖 RobotsMali - Reconnaissance Vocale")
116
-
117
  with gr.Row():
118
  with gr.Column():
119
  audio_input = gr.Audio(label="Audio", type="filepath", sources=["upload", "microphone"])
120
  model_input = gr.Dropdown(choices=list(MODELS.keys()), value="Soloni V3 (TDT-CTC)", label="Modèle")
121
  run_btn = gr.Button("🚀 DÉMARRER", variant="primary")
122
-
123
  with gr.Column():
124
  status = gr.Markdown("### État : En attente")
125
  text_output = gr.Textbox(label="Transcription", lines=12)
@@ -127,4 +93,5 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
127
  run_btn.click(fn=pipeline, inputs=[audio_input, model_input], outputs=[status, text_output])
128
 
129
  if __name__ == "__main__":
130
- demo.launch()
 
 
 
1
  import os, shlex, subprocess, tempfile, traceback, time, glob, gc, shutil
2
  import torch
3
  from huggingface_hub import snapshot_download
4
  from nemo.collections import asr as nemo_asr
 
5
  from nemo.collections.asr.models import EncDecCTCModel, EncDecRNNTModel
6
  import gradio as gr
7
 
8
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
9
+ SEGMENT_DURATION = 10.0
10
 
 
11
  MODELS = {
12
  "Soloba V3 (CTC)": ("RobotsMali/soloba-ctc-0.6b-v3", "ctc"),
13
  "Soloba V2 (CTC)": ("RobotsMali/soloba-ctc-0.6b-v2", "ctc"),
 
23
  _cache = {}
24
 
25
  def get_model(name):
 
26
  if name in _cache:
27
  return _cache[name]
28
 
 
29
  if len(_cache) >= 1:
30
  _cache.clear()
31
  gc.collect()
32
  if torch.cuda.is_available(): torch.cuda.empty_cache()
33
 
34
  repo, arch_type = MODELS[name]
 
 
35
  folder = snapshot_download(repo, local_dir_use_symlinks=False)
36
  nemo_file = next((os.path.join(folder, f) for f in os.listdir(folder) if f.endswith(".nemo")), None)
37
 
 
38
  try:
39
  if arch_type == "ctc":
40
  model = EncDecCTCModel.restore_from(nemo_file, map_location=torch.device(DEVICE))
41
  else:
42
  model = EncDecRNNTModel.restore_from(nemo_file, map_location=torch.device(DEVICE))
43
+ except Exception:
 
44
  model = nemo_asr.models.ASRModel.restore_from(nemo_file, map_location=torch.device(DEVICE))
45
 
46
  model.eval()
47
  if DEVICE == "cuda":
48
  model = model.half()
49
+
50
  _cache[name] = model
51
  return model
52
 
 
58
  tmp_dir = tempfile.mkdtemp()
59
  try:
60
  yield "⏳ Traitement de l'audio...", ""
 
 
61
  wav_path = os.path.join(tmp_dir, "input.wav")
62
  subprocess.run(f"ffmpeg -y -i {shlex.quote(audio_in)} -ac 1 -ar 16000 {wav_path}", shell=True, check=True)
 
 
63
  subprocess.run(f"ffmpeg -i {wav_path} -f segment -segment_time {SEGMENT_DURATION} -c copy {os.path.join(tmp_dir, 'seg_%03d.wav')}", shell=True)
64
 
65
  valid_segments = sorted(glob.glob(os.path.join(tmp_dir, "seg_*.wav")))
 
 
66
  if not valid_segments:
67
  yield "❌ Erreur", "Fichier audio vide ou incompatible."
68
  return
69
 
 
 
70
  model = get_model(model_name)
 
71
  with torch.inference_mode():
72
+ batch_hyp = model.transcribe(valid_segments, batch_size=4, return_hypotheses=True)
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
+ results = [hyp.text if hasattr(hyp, 'text') else str(hyp) for hyp in batch_hyp]
75
  yield "✅ Succès", " ".join(results)
76
 
77
  except Exception as e:
 
78
  yield "❌ Erreur", str(e)
79
  finally:
80
+ if os.path.exists(tmp_dir): shutil.rmtree(tmp_dir)
 
81
 
82
+ with gr.Blocks(title="RobotsMali ASR") as demo:
 
83
  gr.Markdown("# 🤖 RobotsMali - Reconnaissance Vocale")
 
84
  with gr.Row():
85
  with gr.Column():
86
  audio_input = gr.Audio(label="Audio", type="filepath", sources=["upload", "microphone"])
87
  model_input = gr.Dropdown(choices=list(MODELS.keys()), value="Soloni V3 (TDT-CTC)", label="Modèle")
88
  run_btn = gr.Button("🚀 DÉMARRER", variant="primary")
 
89
  with gr.Column():
90
  status = gr.Markdown("### État : En attente")
91
  text_output = gr.Textbox(label="Transcription", lines=12)
 
93
  run_btn.click(fn=pipeline, inputs=[audio_input, model_input], outputs=[status, text_output])
94
 
95
  if __name__ == "__main__":
96
+ # Paramètres CRITIQUES pour Docker sur Hugging Face
97
+ demo.launch(server_name="0.0.0.0", server_port=7860, show_api=False)