binaryMao commited on
Commit
599bc18
·
verified ·
1 Parent(s): 7b18ccd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -89
app.py CHANGED
@@ -1,134 +1,136 @@
1
  # -*- coding: utf-8 -*-
2
- import os
3
- import tempfile
4
- import glob
5
- import shutil
6
  import torch
7
- import gradio as gr
8
  from huggingface_hub import snapshot_download
9
  from nemo.collections import asr as nemo_asr
 
10
 
11
- # Configuration
12
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
13
- SEGMENT_DURATION = 10.0
14
 
15
- # Modèle unique pour CPU (CTC est plus léger)
16
  MODELS = {
17
- "Soloni V3 (Recommandé CPU)": "RobotsMali/soloni-114m-tdt-ctc-v3"
 
 
 
 
 
 
 
 
18
  }
19
 
20
  _cache = {}
21
 
22
- # ==========================
23
- # CHARGEMENT DU MODÈLE
24
- # ==========================
25
- def get_model(model_name):
26
- if model_name in _cache:
27
- return _cache[model_name]
28
-
29
- repo_id = MODELS[model_name]
30
- print(f"⏳ Chargement du modèle : {repo_id}")
31
 
32
- # Téléchargement local
33
- folder = snapshot_download(repo_id, local_dir_use_symlinks=False)
34
-
35
- # Trouve le fichier .nemo
36
- try:
37
- nemo_file = next(
38
- os.path.join(folder, f)
39
- for f in os.listdir(folder)
40
- if f.endswith(".nemo")
41
- )
42
- except StopIteration:
43
- raise FileNotFoundError("Fichier .nemo introuvable.")
44
-
45
- # Restauration générique (Auto-détecte CTC/RNNT)
46
- model = nemo_asr.models.ASRModel.restore_from(
47
- nemo_file,
48
- map_location=torch.device(DEVICE)
49
- )
50
 
 
 
51
  model.eval()
52
- _cache[model_name] = model
 
 
 
 
53
  return model
54
 
55
- # ==========================
56
- # LOGIQUE DE TRANSCRIPTION
57
- # ==========================
58
- def pipeline(audio_path, model_name):
59
- if not audio_path:
60
- yield "❌ Erreur", "Veuillez fournir un fichier audio."
61
- return
62
-
63
  tmp_dir = tempfile.mkdtemp()
64
-
65
  try:
66
- yield "⏳ Traitement audio...", ""
 
 
67
 
68
- # Segmentation FFmpeg (16kHz Mono requis)
69
- # On utilise une syntaxe simple pour éviter les erreurs de shell
70
- os.system(
71
- f"ffmpeg -y -i '{audio_path}' -f segment -segment_time {SEGMENT_DURATION} "
72
- f"-ac 1 -ar 16000 {tmp_dir}/seg_%03d.wav > /dev/null 2>&1"
73
- )
 
 
74
 
75
- segments = sorted(glob.glob(os.path.join(tmp_dir, "seg_*.wav")))
76
- segments = [s for s in segments if os.path.getsize(s) > 1000]
77
 
78
- if not segments:
79
- yield "❌ Erreur", "Impossible de segmenter l'audio."
80
  return
81
 
82
- yield f"🎙️ Transcription ({len(segments)} segments)...", ""
83
-
84
  model = get_model(model_name)
85
 
86
  with torch.inference_mode():
87
- # Transcription par batch pour le CPU
88
- results = model.transcribe(segments, batch_size=2, num_workers=0)
 
 
 
 
 
89
 
90
- # Gestion des différents formats de sortie de NeMo
91
- if isinstance(results, tuple):
92
- text_list = results[0]
93
- else:
94
- text_list = results
 
95
 
96
- final_text = " ".join(text_list)
97
- yield "✅ Terminé", final_text
 
98
 
99
  except Exception as e:
100
- yield "❌ Erreur Système", str(e)
101
-
102
  finally:
103
  if os.path.exists(tmp_dir):
104
  shutil.rmtree(tmp_dir)
105
 
106
- # ==========================
107
- # INTERFACE UTILISATEUR
108
- # ==========================
109
- with gr.Blocks() as demo:
110
- gr.Markdown("# 🤖 RobotsMali ASR")
111
- gr.Markdown("Outil de transcription automatique optimisé pour le CPU.")
112
-
113
  with gr.Row():
114
  with gr.Column():
115
- audio_input = gr.Audio(type="filepath", label="Audio")
116
  model_input = gr.Dropdown(
117
- choices=list(MODELS.keys()),
118
- value="Soloni V3 (Recommandé CPU)",
119
- label="Modèle"
120
  )
121
- btn = gr.Button("🚀 Transcrire", variant="primary")
122
-
123
  with gr.Column():
124
- status_output = gr.Textbox(label="Statut", interactive=False)
125
- text_output = gr.Textbox(label="Texte Transcrit", lines=10)
126
 
127
- btn.click(
128
- fn=pipeline,
129
- inputs=[audio_input, model_input],
130
- outputs=[status_output, text_output]
131
  )
132
 
 
 
133
  if __name__ == "__main__":
134
  demo.launch()
 
1
  # -*- coding: utf-8 -*-
2
+ import os, shlex, subprocess, tempfile, traceback, time, glob, gc, shutil
 
 
 
3
  import torch
 
4
  from huggingface_hub import snapshot_download
5
  from nemo.collections import asr as nemo_asr
6
+ import gradio as gr
7
 
 
8
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
9
+ SEGMENT_DURATION = 10.0 # Augmenté à 10s pour l'audio pur (plus efficace)
10
 
11
+ # --- CONFIGURATION DES MODÈLES (Identique au script vidéo) ---
12
  MODELS = {
13
+ "Soloba V3 (CTC)": ("RobotsMali/soloba-ctc-0.6b-v3", "ctc"),
14
+ "Soloba V2 (CTC)": ("RobotsMali/soloba-ctc-0.6b-v2", "ctc"),
15
+ "Soloba V1 (CTC)": ("RobotsMali/soloba-ctc-0.6b-v1", "ctc"),
16
+ "Soloba V1.5 (TDT)": ("RobotsMali/soloba-tdt-0.6b-v1.5", "rnnt"),
17
+ "Soloba V0.5 (TDT)": ("RobotsMali/soloba-tdt-0.6b-v0.5", "rnnt"),
18
+ "Soloni V3 (TDT-CTC)": ("RobotsMali/soloni-114m-tdt-ctc-v3", "rnnt"),
19
+ "Soloni V2 (TDT-CTC)": ("RobotsMali/soloni-114m-tdt-ctc-v2", "rnnt"),
20
+ "Soloni V1 (TDT-CTC)": ("RobotsMali/soloni-114m-tdt-ctc-v1", "rnnt"),
21
+ "Traduction Soloni (ST)": ("RobotsMali/st-soloni-114m-tdt-ctc", "rnnt"),
22
  }
23
 
24
  _cache = {}
25
 
26
+ def get_model(name):
27
+ """Charge le modèle avec gestion de la mémoire."""
28
+ if name in _cache:
29
+ return _cache[name]
 
 
 
 
 
30
 
31
+ # Libérer la mémoire des anciens modèles si nécessaire
32
+ if len(_cache) > 0:
33
+ _cache.clear()
34
+ gc.collect()
35
+ if torch.cuda.is_available(): torch.cuda.empty_cache()
36
+
37
+ repo, _ = MODELS[name]
38
+ print(f"⏳ Chargement de {name} depuis {repo}...")
39
+
40
+ folder = snapshot_download(repo, local_dir_use_symlinks=False)
41
+ nemo_file = next((os.path.join(folder, f) for f in os.listdir(folder) if f.endswith(".nemo")), None)
 
 
 
 
 
 
 
42
 
43
+ # Chargement via ASRModel (Factory)
44
+ model = nemo_asr.models.ASRModel.restore_from(nemo_file, map_location=torch.device(DEVICE))
45
  model.eval()
46
+
47
+ if DEVICE == "cuda":
48
+ model = model.half()
49
+
50
+ _cache[name] = model
51
  return model
52
 
53
+ # --- PIPELINE ASR ---
54
+ def pipeline(audio_in, model_name):
 
 
 
 
 
 
55
  tmp_dir = tempfile.mkdtemp()
 
56
  try:
57
+ if not audio_in:
58
+ yield "❌ Fichier audio manquant", ""
59
+ return
60
 
61
+ yield "⏳ Traitement audio & Segmentation...", ""
62
+
63
+ # Normalisation : Mono, 16kHz
64
+ wav_path = os.path.join(tmp_dir, "clean.wav")
65
+ subprocess.run(f"ffmpeg -y -i {shlex.quote(audio_in)} -ac 1 -ar 16000 {wav_path}", shell=True, check=True)
66
+
67
+ # Découpage en segments
68
+ subprocess.run(f"ffmpeg -i {wav_path} -f segment -segment_time {SEGMENT_DURATION} -c copy {os.path.join(tmp_dir, 'seg_%03d.wav')}", shell=True)
69
 
70
+ valid_segments = sorted(glob.glob(os.path.join(tmp_dir, "seg_*.wav")))
71
+ valid_segments = [f for f in valid_segments if os.path.getsize(f) > 1000]
72
 
73
+ if not valid_segments:
74
+ yield "❌ Erreur de segmentation", ""
75
  return
76
 
77
+ yield f"🎙️ Transcription de {len(valid_segments)} segments...", ""
78
+
79
  model = get_model(model_name)
80
 
81
  with torch.inference_mode():
82
+ # Utilisation de la méthode stable sans Lhotse
83
+ batch_hyp = model.transcribe(
84
+ valid_segments,
85
+ batch_size=4,
86
+ return_hypotheses=True,
87
+ num_workers=0
88
+ )
89
 
90
+ # Extraction du texte
91
+ results = []
92
+ for hyp in batch_hyp:
93
+ text = hyp.text if hasattr(hyp, 'text') else str(hyp)
94
+ if text:
95
+ results.append(text)
96
 
97
+ final_text = " ".join(results)
98
+
99
+ yield "✅ Transcription terminée", final_text
100
 
101
  except Exception as e:
102
+ print(traceback.format_exc())
103
+ yield f"❌ Erreur : {str(e)}", ""
104
  finally:
105
  if os.path.exists(tmp_dir):
106
  shutil.rmtree(tmp_dir)
107
 
108
+ # --- INTERFACE GRADIO ---
109
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
110
+ gr.Markdown("# 🤖 RobotsMali Speech-to-Text")
111
+ gr.Markdown("Transcription multi-modèles pour les langues du Mali.")
112
+
 
 
113
  with gr.Row():
114
  with gr.Column():
115
+ audio_input = gr.Audio(label="Audio (Fichier ou Micro)", type="filepath")
116
  model_input = gr.Dropdown(
117
+ choices=list(MODELS.keys()),
118
+ value="Soloni V3 (TDT-CTC)",
119
+ label="Choisir un modèle"
120
  )
121
+ run_btn = gr.Button("🚀 TRANSCRIRE", variant="primary")
122
+
123
  with gr.Column():
124
+ status = gr.Markdown("### État : Prêt")
125
+ text_output = gr.Textbox(label="Résultat", lines=15, show_copy_button=True)
126
 
127
+ run_btn.click(
128
+ fn=pipeline,
129
+ inputs=[audio_input, model_input],
130
+ outputs=[status, text_output]
131
  )
132
 
133
+ gr.HTML("<center><p style='color: gray;'>Optimisé pour CPU & GPU Hugging Face</p></center>")
134
+
135
  if __name__ == "__main__":
136
  demo.launch()