binaryMao commited on
Commit
0dbfd3b
·
verified ·
1 Parent(s): 42c75c9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +90 -45
app.py CHANGED
@@ -1,110 +1,155 @@
1
  # -*- coding: utf-8 -*-
2
- import os, tempfile, glob, gc
 
 
 
3
  import torch
4
  import gradio as gr
5
  from huggingface_hub import snapshot_download
6
  from nemo.collections import asr as nemo_asr
7
 
 
8
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
9
  SEGMENT_DURATION = 10.0
10
 
11
- # ⚠️ CPU Basic → modèle léger uniquement
12
  MODELS = {
13
- "Soloni V3 (Recommandé CPU)": ("RobotsMali/soloni-114m-tdt-ctc-v3", "rnnt")
14
  }
15
 
 
16
  _cache = {}
17
 
18
  # ==========================
19
- # MODEL LOADER (identique logique vidéo)
20
  # ==========================
21
- def get_model(name):
22
- if name in _cache:
23
- return _cache[name]
24
 
25
- repo, _ = MODELS[name]
26
- folder = snapshot_download(repo, local_dir_use_symlinks=False)
 
 
 
27
 
28
- nemo_file = next(
29
- os.path.join(folder, f)
30
- for f in os.listdir(folder)
31
- if f.endswith(".nemo")
32
- )
 
 
 
 
33
 
 
 
 
34
  model = nemo_asr.models.ASRModel.restore_from(
35
  nemo_file,
36
  map_location=torch.device(DEVICE)
37
  )
38
 
39
  model.eval()
40
-
41
- _cache[name] = model
42
  return model
43
 
44
  # ==========================
45
- # PIPELINE (même logique que vidéo)
46
  # ==========================
47
- def pipeline(audio_file, model_name):
48
- tmp_dir = tempfile.mkdtemp()
 
 
49
 
 
 
 
50
  try:
51
- if not audio_file:
52
- yield "❌ Aucun fichier audio", None
53
- return
54
-
55
- yield "⏳ Segmentation...", None
56
 
57
- # Découpe en segments
58
- os.system(
59
- f"ffmpeg -y -i '{audio_file}' -f segment -segment_time {SEGMENT_DURATION} "
60
- f"-ac 1 -ar 16000 -c copy {tmp_dir}/seg_%03d.wav"
 
61
  )
 
62
 
63
- segments = sorted(glob.glob(f"{tmp_dir}/seg_*.wav"))
 
 
 
64
  segments = [s for s in segments if os.path.getsize(s) > 1000]
65
 
66
- yield f"🎙️ Transcription {len(segments)} segments...", None
 
 
 
 
67
 
 
68
  model = get_model(model_name)
69
 
 
70
  with torch.inference_mode():
71
  results = model.transcribe(
72
  segments,
73
  batch_size=4,
74
- num_workers=0
75
  )
76
 
77
- text = "\n".join(results)
 
 
 
 
 
 
 
78
 
79
- yield "✅ Transcription terminée", text
80
 
81
  except Exception as e:
82
- yield f"❌ Erreur : {str(e)}", None
83
 
84
  finally:
 
85
  if os.path.exists(tmp_dir):
86
- import shutil
87
  shutil.rmtree(tmp_dir)
88
 
89
  # ==========================
90
- # UI
91
  # ==========================
92
- with gr.Blocks() as demo:
93
- gr.Markdown("# 🤖 RobotsMali ASR – CPU Edition")
 
 
 
94
 
95
  with gr.Row():
96
  with gr.Column():
97
- audio_input = gr.Audio(type="filepath", label="Audio")
98
  model_input = gr.Dropdown(
99
  choices=list(MODELS.keys()),
100
- value="Soloni V3 (Recommandé CPU)"
 
101
  )
102
- btn = gr.Button("🚀 Transcrire")
103
 
104
  with gr.Column():
105
- status = gr.Markdown("Prêt.")
106
- output = gr.Textbox(lines=15)
 
 
 
 
 
 
 
107
 
108
- btn.click(pipeline, [audio_input, model_input], [status, output])
109
 
110
- demo.launch()
 
 
1
  # -*- coding: utf-8 -*-
2
+ import os
3
+ import tempfile
4
+ import glob
5
+ import shutil
6
  import torch
7
  import gradio as gr
8
  from huggingface_hub import snapshot_download
9
  from nemo.collections import asr as nemo_asr
10
 
11
+ # Configuration du matériel
12
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
13
  SEGMENT_DURATION = 10.0
14
 
15
+ # Dictionnaire des modèles
16
  MODELS = {
17
+ "Soloni V3 (Recommandé CPU)": "RobotsMali/soloni-114m-tdt-ctc-v3"
18
  }
19
 
20
+ # Cache pour éviter de recharger le modèle à chaque clic
21
  _cache = {}
22
 
23
  # ==========================
24
+ # CHARGEMENT DU MODÈLE
25
  # ==========================
26
+ def get_model(model_name):
27
+ if model_name in _cache:
28
+ return _cache[model_name]
29
 
30
+ repo_id = MODELS[model_name]
31
+ print(f"⏳ Téléchargement du modèle depuis {repo_id}...")
32
+
33
+ # Téléchargement depuis Hugging Face
34
+ folder = snapshot_download(repo_id, local_dir_use_symlinks=False)
35
 
36
+ # Recherche automatique du fichier .nemo dans le dossier téléchargé
37
+ try:
38
+ nemo_file = next(
39
+ os.path.join(folder, f)
40
+ for f in os.listdir(folder)
41
+ if f.endswith(".nemo")
42
+ )
43
+ except StopIteration:
44
+ raise FileNotFoundError("Aucun fichier .nemo trouvé dans le dépôt.")
45
 
46
+ print(f"📦 Restauration du modèle NeMo sur {DEVICE}...")
47
+
48
+ # Utilisation de ASRModel pour l'auto-détection (CTC/RNNT)
49
  model = nemo_asr.models.ASRModel.restore_from(
50
  nemo_file,
51
  map_location=torch.device(DEVICE)
52
  )
53
 
54
  model.eval()
55
+ _cache[model_name] = model
 
56
  return model
57
 
58
  # ==========================
59
+ # PIPELINE DE TRANSCRIPTION
60
  # ==========================
61
+ def pipeline(audio_path, model_name):
62
+ if not audio_path:
63
+ yield "❌ Aucun fichier audio fourni", ""
64
+ return
65
 
66
+ # Création d'un dossier temporaire unique pour les segments
67
+ tmp_dir = tempfile.mkdtemp()
68
+
69
  try:
70
+ yield "⏳ Préparation de l'audio (FFmpeg)...", ""
 
 
 
 
71
 
72
+ # Normalisation et segmentation de l'audio
73
+ # -ac 1 (mono), -ar 16000 (16kHz requis par NeMo)
74
+ command = (
75
+ f"ffmpeg -y -i '{audio_path}' -f segment -segment_time {SEGMENT_DURATION} "
76
+ f"-ac 1 -ar 16000 {tmp_dir}/seg_%03d.wav > /dev/null 2>&1"
77
  )
78
+ os.system(command)
79
 
80
+ # Liste et tri des segments générés
81
+ segments = sorted(glob.glob(os.path.join(tmp_dir, "seg_*.wav")))
82
+
83
+ # Filtrer les fichiers trop petits (silences ou erreurs FFmpeg)
84
  segments = [s for s in segments if os.path.getsize(s) > 1000]
85
 
86
+ if not segments:
87
+ yield "❌ Erreur lors de la segmentation audio", ""
88
+ return
89
+
90
+ yield f"🎙️ Transcription de {len(segments)} segments en cours...", ""
91
 
92
+ # Récupération du modèle
93
  model = get_model(model_name)
94
 
95
+ # Inférence
96
  with torch.inference_mode():
97
  results = model.transcribe(
98
  segments,
99
  batch_size=4,
100
+ num_workers=0 # Important sur CPU pour éviter les fuites de mémoire
101
  )
102
 
103
+ # Reconstruction du texte
104
+ # Note: selon le modèle, results peut être une liste de chaînes ou un objet complexe
105
+ if isinstance(results, tuple):
106
+ text_results = results[0]
107
+ else:
108
+ text_results = results
109
+
110
+ final_text = " ".join(text_results)
111
 
112
+ yield "✅ Transcription terminée", final_text
113
 
114
  except Exception as e:
115
+ yield f"❌ Erreur système : {str(e)}", ""
116
 
117
  finally:
118
+ # Nettoyage automatique du dossier temporaire
119
  if os.path.exists(tmp_dir):
 
120
  shutil.rmtree(tmp_dir)
121
 
122
  # ==========================
123
+ # INTERFACE GRADIO (UI)
124
  # ==========================
125
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
126
+ gr.Markdown("""
127
+ # 🤖 RobotsMali ASR – Edition CPU
128
+ Déposez un fichier audio pour obtenir une transcription automatique.
129
+ """)
130
 
131
  with gr.Row():
132
  with gr.Column():
133
+ audio_input = gr.Audio(type="filepath", label="Fichier Audio (WAV, MP3, etc.)")
134
  model_input = gr.Dropdown(
135
  choices=list(MODELS.keys()),
136
+ value="Soloni V3 (Recommandé CPU)",
137
+ label="Sélection du Modèle"
138
  )
139
+ btn = gr.Button("🚀 Lancer la Transcription", variant="primary")
140
 
141
  with gr.Column():
142
+ status = gr.Markdown("### Statut : Prêt")
143
+ output = gr.Textbox(label="Résultat de la transcription", lines=12, show_copy_button=True)
144
+
145
+ # Interaction
146
+ btn.click(
147
+ fn=pipeline,
148
+ inputs=[audio_input, model_input],
149
+ outputs=[status, output]
150
+ )
151
 
152
+ gr.Markdown("--- \n *Optimisé pour les environnements à ressources limitées.*")
153
 
154
+ if __name__ == "__main__":
155
+ demo.launch()