binaryMao committed on
Commit
0c8d6fd
·
verified ·
1 Parent(s): 599bc18

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -43
app.py CHANGED
@@ -6,9 +6,8 @@ from nemo.collections import asr as nemo_asr
6
  import gradio as gr
7
 
8
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
9
- SEGMENT_DURATION = 10.0 # Augmenté à 10s pour l'audio pur (plus efficace)
10
 
11
- # --- CONFIGURATION DES MODÈLES (Identique au script vidéo) ---
12
  MODELS = {
13
  "Soloba V3 (CTC)": ("RobotsMali/soloba-ctc-0.6b-v3", "ctc"),
14
  "Soloba V2 (CTC)": ("RobotsMali/soloba-ctc-0.6b-v2", "ctc"),
@@ -24,23 +23,21 @@ MODELS = {
24
  _cache = {}
25
 
26
  def get_model(name):
27
- """Charge le modèle avec gestion de la mémoire."""
28
  if name in _cache:
29
  return _cache[name]
30
 
31
- # Libérer la mémoire des anciens modèles si nécessaire
32
- if len(_cache) > 0:
33
  _cache.clear()
34
  gc.collect()
35
  if torch.cuda.is_available(): torch.cuda.empty_cache()
36
 
37
  repo, _ = MODELS[name]
38
- print(f"⏳ Chargement de {name} depuis {repo}...")
39
 
40
  folder = snapshot_download(repo, local_dir_use_symlinks=False)
41
  nemo_file = next((os.path.join(folder, f) for f in os.listdir(folder) if f.endswith(".nemo")), None)
42
 
43
- # Chargement via ASRModel (Factory)
44
  model = nemo_asr.models.ASRModel.restore_from(nemo_file, map_location=torch.device(DEVICE))
45
  model.eval()
46
 
@@ -50,36 +47,34 @@ def get_model(name):
50
  _cache[name] = model
51
  return model
52
 
53
- # --- PIPELINE ASR ---
54
  def pipeline(audio_in, model_name):
 
 
 
 
55
  tmp_dir = tempfile.mkdtemp()
56
  try:
57
- if not audio_in:
58
- yield "❌ Fichier audio manquant", ""
59
- return
60
-
61
- yield "⏳ Traitement audio & Segmentation...", ""
62
 
63
- # Normalisation : Mono, 16kHz
64
  wav_path = os.path.join(tmp_dir, "clean.wav")
 
65
  subprocess.run(f"ffmpeg -y -i {shlex.quote(audio_in)} -ac 1 -ar 16000 {wav_path}", shell=True, check=True)
66
 
67
- # Découpage en segments
68
  subprocess.run(f"ffmpeg -i {wav_path} -f segment -segment_time {SEGMENT_DURATION} -c copy {os.path.join(tmp_dir, 'seg_%03d.wav')}", shell=True)
69
 
70
  valid_segments = sorted(glob.glob(os.path.join(tmp_dir, "seg_*.wav")))
71
  valid_segments = [f for f in valid_segments if os.path.getsize(f) > 1000]
72
 
73
  if not valid_segments:
74
- yield "❌ Erreur de segmentation", ""
75
  return
76
 
77
- yield f"🎙️ Transcription de {len(valid_segments)} segments...", ""
78
 
79
  model = get_model(model_name)
80
 
81
  with torch.inference_mode():
82
- # Utilisation de la méthode stable sans Lhotse
83
  batch_hyp = model.transcribe(
84
  valid_segments,
85
  batch_size=4,
@@ -87,50 +82,35 @@ def pipeline(audio_in, model_name):
87
  num_workers=0
88
  )
89
 
90
- # Extraction du texte
91
  results = []
92
  for hyp in batch_hyp:
93
  text = hyp.text if hasattr(hyp, 'text') else str(hyp)
94
- if text:
95
- results.append(text)
96
 
97
- final_text = " ".join(results)
98
-
99
- yield "✅ Transcription terminée", final_text
100
 
101
  except Exception as e:
102
- print(traceback.format_exc())
103
- yield f"❌ Erreur : {str(e)}", ""
104
  finally:
105
  if os.path.exists(tmp_dir):
106
  shutil.rmtree(tmp_dir)
107
 
108
- # --- INTERFACE GRADIO ---
109
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
110
  gr.Markdown("# 🤖 RobotsMali Speech-to-Text")
111
- gr.Markdown("Transcription multi-modèles pour les langues du Mali.")
112
 
113
  with gr.Row():
114
  with gr.Column():
115
- audio_input = gr.Audio(label="Audio (Fichier ou Micro)", type="filepath")
116
- model_input = gr.Dropdown(
117
- choices=list(MODELS.keys()),
118
- value="Soloni V3 (TDT-CTC)",
119
- label="Choisir un modèle"
120
- )
121
  run_btn = gr.Button("🚀 TRANSCRIRE", variant="primary")
122
 
123
  with gr.Column():
124
  status = gr.Markdown("### État : Prêt")
125
- text_output = gr.Textbox(label="Résultat", lines=15, show_copy_button=True)
126
-
127
- run_btn.click(
128
- fn=pipeline,
129
- inputs=[audio_input, model_input],
130
- outputs=[status, text_output]
131
- )
132
 
133
- gr.HTML("<center><p style='color: gray;'>Optimisé pour CPU & GPU Hugging Face</p></center>")
134
 
135
  if __name__ == "__main__":
136
  demo.launch()
 
6
  import gradio as gr
7
 
8
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
9
+ SEGMENT_DURATION = 10.0
10
 
 
11
  MODELS = {
12
  "Soloba V3 (CTC)": ("RobotsMali/soloba-ctc-0.6b-v3", "ctc"),
13
  "Soloba V2 (CTC)": ("RobotsMali/soloba-ctc-0.6b-v2", "ctc"),
 
23
  _cache = {}
24
 
25
  def get_model(name):
 
26
  if name in _cache:
27
  return _cache[name]
28
 
29
+ # NETTOYAGE MÉMOIRE AGRESSIF (Crucial pour les modèles 0.6b sur CPU Free)
30
+ if len(_cache) >= 1:
31
  _cache.clear()
32
  gc.collect()
33
  if torch.cuda.is_available(): torch.cuda.empty_cache()
34
 
35
  repo, _ = MODELS[name]
36
+ print(f"⏳ Chargement de {name}...")
37
 
38
  folder = snapshot_download(repo, local_dir_use_symlinks=False)
39
  nemo_file = next((os.path.join(folder, f) for f in os.listdir(folder) if f.endswith(".nemo")), None)
40
 
 
41
  model = nemo_asr.models.ASRModel.restore_from(nemo_file, map_location=torch.device(DEVICE))
42
  model.eval()
43
 
 
47
  _cache[name] = model
48
  return model
49
 
 
50
  def pipeline(audio_in, model_name):
51
+ if not audio_in:
52
+ yield "❌ Erreur", "Veuillez fournir un fichier audio."
53
+ return
54
+
55
  tmp_dir = tempfile.mkdtemp()
56
  try:
57
+ yield "⏳ Préparation de l'audio...", ""
 
 
 
 
58
 
 
59
  wav_path = os.path.join(tmp_dir, "clean.wav")
60
+ # Conversion simple et robuste
61
  subprocess.run(f"ffmpeg -y -i {shlex.quote(audio_in)} -ac 1 -ar 16000 {wav_path}", shell=True, check=True)
62
 
63
+ # Segmentation
64
  subprocess.run(f"ffmpeg -i {wav_path} -f segment -segment_time {SEGMENT_DURATION} -c copy {os.path.join(tmp_dir, 'seg_%03d.wav')}", shell=True)
65
 
66
  valid_segments = sorted(glob.glob(os.path.join(tmp_dir, "seg_*.wav")))
67
  valid_segments = [f for f in valid_segments if os.path.getsize(f) > 1000]
68
 
69
  if not valid_segments:
70
+ yield "❌ Erreur", "Audio trop court ou illisible."
71
  return
72
 
73
+ yield f"🎙️ Transcription ({len(valid_segments)} segments)...", ""
74
 
75
  model = get_model(model_name)
76
 
77
  with torch.inference_mode():
 
78
  batch_hyp = model.transcribe(
79
  valid_segments,
80
  batch_size=4,
 
82
  num_workers=0
83
  )
84
 
 
85
  results = []
86
  for hyp in batch_hyp:
87
  text = hyp.text if hasattr(hyp, 'text') else str(hyp)
88
+ if text: results.append(text)
 
89
 
90
+ yield "✅ Terminé", " ".join(results)
 
 
91
 
92
  except Exception as e:
93
+ yield "❌ Erreur", str(e)
 
94
  finally:
95
  if os.path.exists(tmp_dir):
96
  shutil.rmtree(tmp_dir)
97
 
98
+ # --- INTERFACE (Argument show_copy_button supprimé) ---
99
+ with gr.Blocks() as demo:
100
  gr.Markdown("# 🤖 RobotsMali Speech-to-Text")
 
101
 
102
  with gr.Row():
103
  with gr.Column():
104
+ audio_input = gr.Audio(label="Audio", type="filepath")
105
+ model_input = gr.Dropdown(choices=list(MODELS.keys()), value="Soloni V3 (TDT-CTC)", label="Modèle")
 
 
 
 
106
  run_btn = gr.Button("🚀 TRANSCRIRE", variant="primary")
107
 
108
  with gr.Column():
109
  status = gr.Markdown("### État : Prêt")
110
+ # Correction ici : show_copy_button est retiré pour compatibilité
111
+ text_output = gr.Textbox(label="Résultat", lines=15)
 
 
 
 
 
112
 
113
+ run_btn.click(fn=pipeline, inputs=[audio_input, model_input], outputs=[status, text_output])
114
 
115
  if __name__ == "__main__":
116
  demo.launch()