binaryMao commited on
Commit
84226a5
·
verified ·
1 Parent(s): 5484bb2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -140
app.py CHANGED
@@ -5,8 +5,9 @@ from huggingface_hub import snapshot_download
5
  from nemo.collections import asr as nemo_asr
6
  import gradio as gr
7
 
8
- # 1. CONFIGURATION ET MODÈLES
9
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
10
 
11
  MODELS = {
12
  "Soloba V3 (CTC)": ("RobotsMali/soloba-ctc-0.6b-v3", "ctc"),
@@ -20,27 +21,9 @@ MODELS = {
20
  "Traduction Soloni (ST)": ("RobotsMali/st-soloni-114m-tdt-ctc", "rnnt"),
21
  }
22
 
23
- def find_example_video():
24
- paths = ["examples/MARALINKE_FIXED.mp4", "examples/MARALINKE.mp4", "MARALINKE.mp4"]
25
- for p in paths:
26
- if os.path.exists(p): return p
27
-
28
- # Si aucun fichier local, on télécharge un exemple
29
- print("⬇️ Téléchargement de la vidéo d'exemple...")
30
- example_url = "https://huggingface.co/spaces/RobotsMali/Soloni-Demo/resolve/main/examples/MARALINKE.mp4"
31
- target_path = "examples/MARALINKE.mp4"
32
- os.makedirs("examples", exist_ok=True)
33
- try:
34
- subprocess.run(f"wget {example_url} -O {target_path}", shell=True, check=True)
35
- return target_path
36
- except Exception as e:
37
- print(f"⚠️ Impossible de télécharger l'exemple : {e}")
38
- return None
39
-
40
- EXAMPLE_PATH = find_example_video()
41
  _cache = {}
42
 
43
- # 2. GESTION MÉMOIRE ET CHARGEMENT (AVEC CORRECTIF STATE_DICT)
44
  def clear_memory():
45
  _cache.clear()
46
  gc.collect()
@@ -54,105 +37,34 @@ def get_model(name):
54
 
55
  folder = snapshot_download(repo, local_dir_use_symlinks=False)
56
  nemo_file = next((os.path.join(folder, f) for f in os.listdir(folder) if f.endswith(".nemo")), None)
57
-
58
  if not nemo_file: raise FileNotFoundError("Fichier .nemo introuvable.")
59
 
60
- # Méthode la plus simple et robuste pour charger le modèle
61
- print(f"📦 Chargement du modèle depuis : {nemo_file}")
62
 
63
- # On laisse NeMo gérer tout automatiquement
64
- model = nemo_asr.models.ASRModel.restore_from(nemo_file)
65
 
66
- # Déplacement vers le device approprié
67
- model = model.to(DEVICE)
68
- model.eval()
 
 
69
 
70
- # Optimisation mémoire pour GPU
71
  if DEVICE == "cuda":
72
- try:
73
- model = model.half()
74
- except Exception as e:
75
- print(f"⚠️ Impossible de convertir en half precision: {e}")
76
 
77
  _cache[name] = model
78
  return model
79
 
80
-
81
- # 3. UTILITAIRES
82
  def format_srt_time(sec):
83
- td = time.gmtime(sec)
84
  ms = int((sec - int(sec)) * 1000)
85
  return f"{time.strftime('%H:%M:%S', td)},{ms:03}"
86
 
87
- # 4. PIPELINE DE TRANSCRIPTION (OPTIMISÉ)
88
- def detect_silences(path, min_silence_len=0.3, silence_thresh=-35):
89
- """Detects silence intervals using ffmpeg"""
90
- cmd = (
91
- f"ffmpeg -i {shlex.quote(path)} -af "
92
- f"silencedetect=noise={silence_thresh}dB:d={min_silence_len} "
93
- f"-f null -"
94
- )
95
- result = subprocess.run(cmd, shell=True, stderr=subprocess.PIPE, text=True)
96
- silences = []
97
- for line in result.stderr.splitlines():
98
- if "silence_start" in line:
99
- start = float(line.split("silence_start: ")[1])
100
- silences.append({"start": start, "end": None})
101
- elif "silence_end" in line and silences:
102
- end = float(line.split("silence_end: ")[1].split(" ")[0])
103
- silences[-1]["end"] = end
104
- return [s for s in silences if s["end"] is not None]
105
-
106
- def smart_segment_audio(audio_path, target_duration=5.0):
107
- """Segments audio at silence points closest to target_duration"""
108
- silences = detect_silences(audio_path)
109
- segments_cuts = [0.0]
110
- last_cut = 0.0
111
-
112
- # Si aucun silence détecté, on fallback sur du découpage régulier
113
- if not silences:
114
- return None
115
-
116
- # On cherche le meilleur point de coupe
117
- duration = float(subprocess.check_output(
118
- f"ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 {shlex.quote(audio_path)}",
119
- shell=True
120
- ).strip())
121
-
122
- current_pos = 0.0
123
- while current_pos < duration:
124
- target_pos = current_pos + target_duration
125
- if target_pos >= duration:
126
- break
127
-
128
- # Trouver le silence le plus proche du target_pos
129
- best_cut = None
130
- min_dist = float('inf')
131
-
132
- for s in silences:
133
- # On coupe au milieu du silence
134
- mid_silence = (s["start"] + s["end"]) / 2
135
- if mid_silence <= current_pos: continue
136
-
137
- dist = abs(mid_silence - target_pos)
138
- if dist < min_dist:
139
- min_dist = dist
140
- best_cut = mid_silence
141
-
142
- # Optimisation: inutile de chercher trop loin
143
- if mid_silence > target_pos + 10: break
144
-
145
- if best_cut and abs(best_cut - current_pos) > 1.0: # Éviter segments trop courts
146
- segments_cuts.append(best_cut)
147
- current_pos = best_cut
148
- else:
149
- # Pas de silence proche, on force la coupe (fallback)
150
- current_pos += target_duration
151
- segments_cuts.append(current_pos)
152
-
153
- segments_cuts.append(duration)
154
- return segments_cuts
155
-
156
  def pipeline(video_in, model_name):
157
  tmp_dir = tempfile.mkdtemp()
158
  try:
@@ -160,86 +72,86 @@ def pipeline(video_in, model_name):
160
  yield "❌ Aucune vidéo sélectionnée.", None
161
  return
162
 
 
163
  yield "⏳ Phase 1/4 : Extraction audio...", None
164
  full_wav = os.path.join(tmp_dir, "full.wav")
165
- subprocess.run(f"ffmpeg -y -threads 0 -i {shlex.quote(video_in)} -vn -ac 1 -ar 16000 {full_wav}", shell=True, check=True)
166
 
167
- yield "⏳ Phase 2/4 : Segmentation (5s optimisé Soloni)...", None
 
 
168
 
169
- # Segmentation fixe 5s (optimal pour Soloni V2/V3)
170
- subprocess.run(f"ffmpeg -i {full_wav} -f segment -segment_time 5 -c copy {os.path.join(tmp_dir, 'seg_%03d.wav')}", shell=True, check=True)
171
  files = sorted(glob.glob(os.path.join(tmp_dir, "seg_*.wav")))
 
 
172
 
173
- segment_files = []
174
- for i, f in enumerate(files):
175
- segment_files.append({"file": f, "start_offset": i * 5.0})
176
 
 
177
  yield f"⏳ Phase 3/4 : Chargement de {model_name}...", None
178
  model = get_model(model_name)
179
 
180
- yield f"🎙️ Transcription de {len(segment_files)} segments...", None
181
- # Optimisation batch size pour Colab (souvent T4/V100)
182
  b_size = 16 if DEVICE == "cuda" else 2
183
 
184
- audio_paths = [s["file"] for s in segment_files]
185
-
186
- # Utilisation de torch.inference_mode pour gain perf
187
  with torch.inference_mode():
188
- batch_hypotheses = model.transcribe(audio_paths, batch_size=b_size, return_hypotheses=True)
189
 
190
  all_words_ts = []
191
  for idx, hyp in enumerate(batch_hypotheses):
192
- yield f"📝 Traitement : {idx+1}/{len(segment_files)}...", None
193
- base_time = segment_files[idx]["start_offset"]
194
-
195
- if isinstance(hyp, list): hyp = hyp[0]
196
  text = hyp.text if hasattr(hyp, 'text') else str(hyp)
197
  words = text.split()
198
- # Ajustement temporel plus précis
199
- segment_duration = segment_files[idx+1]["start_offset"] - base_time if idx < len(segment_files)-1 else 5.0
200
 
201
- gap = segment_duration / max(len(words), 1)
 
202
  for i, w in enumerate(words):
203
- all_words_ts.append({"word": w, "start": base_time + (i * gap), "end": base_time + ((i+1) * gap)})
204
-
205
- yield " Phase 4/4 : Encodage vidéo...", None
 
 
 
 
 
206
  srt_path = os.path.join(tmp_dir, "final.srt")
207
  with open(srt_path, "w", encoding="utf-8") as f:
208
- for i in range(0, len(all_words_ts), 6):
209
  chunk = all_words_ts[i:i+6]
210
  f.write(f"{(i//6)+1}\n{format_srt_time(chunk[0]['start'])} --> {format_srt_time(chunk[-1]['end'])}\n")
211
  f.write(" ".join([c['word'] for c in chunk]) + "\n\n")
212
 
213
  out_path = os.path.abspath(f"resultat_{int(time.time())}.mp4")
 
214
  safe_srt = srt_path.replace("\\", "/").replace(":", "\\:")
215
 
216
- cmd = f"ffmpeg -y -threads 0 -i {shlex.quote(video_in)} -vf \"subtitles='{safe_srt}':force_style='Alignment=2,FontSize=18,PrimaryColour=&H00FFFF'\" -c:v libx264 -preset ultrafast -c:a copy {out_path}"
 
217
  subprocess.run(cmd, shell=True, check=True)
218
 
219
  yield "✅ Terminé !", out_path
220
 
221
  except Exception as e:
 
222
  yield f"❌ Erreur : {str(e)}", None
223
  finally:
224
  if os.path.exists(tmp_dir): shutil.rmtree(tmp_dir)
225
 
226
- # 5. INTERFACE GRADIO
227
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
228
  gr.HTML("<div style='text-align:center;'><h1>🤖 RobotsMali Speech Lab</h1></div>")
229
-
230
  with gr.Row():
231
  with gr.Column():
232
  v_input = gr.Video(label="Vidéo Source")
233
  m_input = gr.Dropdown(choices=list(MODELS.keys()), value="Soloni V3 (TDT-CTC)", label="Modèle")
234
- run_btn = gr.Button("🚀 GÉNÉRER", variant="primary")
235
-
236
- if EXAMPLE_PATH:
237
- gr.Examples(examples=[[EXAMPLE_PATH, "Soloni V3 (TDT-CTC)"]], inputs=[v_input, m_input])
238
-
239
  with gr.Column():
240
- status = gr.Markdown("### État\nEn attente...")
241
- v_output = gr.Video(label="Vidéo finale")
242
 
243
  run_btn.click(pipeline, [v_input, m_input], [status, v_output])
244
 
245
- demo.queue().launch()
 
5
  from nemo.collections import asr as nemo_asr
6
  import gradio as gr
7
 
8
+ # --- CONFIGURATION ---
9
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
10
+ SEGMENT_DURATION = 5.0 # Ta préférence pour Soloni
11
 
12
  MODELS = {
13
  "Soloba V3 (CTC)": ("RobotsMali/soloba-ctc-0.6b-v3", "ctc"),
 
21
  "Traduction Soloni (ST)": ("RobotsMali/st-soloni-114m-tdt-ctc", "rnnt"),
22
  }
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  _cache = {}
25
 
26
+ # --- GESTION MÉMOIRE ET CHARGEMENT ---
27
  def clear_memory():
28
  _cache.clear()
29
  gc.collect()
 
37
 
38
  folder = snapshot_download(repo, local_dir_use_symlinks=False)
39
  nemo_file = next((os.path.join(folder, f) for f in os.listdir(folder) if f.endswith(".nemo")), None)
 
40
  if not nemo_file: raise FileNotFoundError("Fichier .nemo introuvable.")
41
 
42
+ # CORRECTIF SÉCURISÉ POUR L'ERREUR D'INITIALISATION
43
+ from nemo.core.connectors.save_restore_connector import SaveRestoreConnector
44
 
45
+ # On force l'utilisation d'un connecteur standard pour éviter le bug __init__()
46
+ connector = SaveRestoreConnector()
47
 
48
+ model = nemo_asr.models.ASRModel.restore_from(
49
+ nemo_file,
50
+ map_location=torch.device(DEVICE),
51
+ save_restore_connector=connector # On passe l'instance déjà créée
52
+ )
53
 
54
+ model.eval()
55
  if DEVICE == "cuda":
56
+ model = model.half()
 
 
 
57
 
58
  _cache[name] = model
59
  return model
60
 
61
+ # --- UTILITAIRES ---
 
62
def format_srt_time(sec):
    """Format a time offset in seconds as an SRT timestamp 'HH:MM:SS,mmm'.

    Negative inputs are clamped to zero. Unlike the previous
    time.gmtime/strftime approach, hours do not wrap at 24, so media longer
    than a day still gets monotonically increasing timestamps, and the
    millisecond part is derived from the same clamped value (the old code
    clamped only the H:M:S part, leaving a bogus ms for negative input).
    """
    total_ms = max(0, int(round(sec * 1000)))  # clamp, then round to nearest ms
    hours, rem = divmod(total_ms, 3_600_000)
    minutes, rem = divmod(rem, 60_000)
    seconds, ms = divmod(rem, 1000)
    return f"{hours:02}:{minutes:02}:{seconds:02},{ms:03}"
66
 
67
+ # --- PIPELINE ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
def pipeline(video_in, model_name):
    """Generate a subtitled video from `video_in` using the selected ASR model.

    Generator yielding `(status_message, output_path_or_None)` tuples so the
    Gradio UI can stream progress. Phases: 1) extract mono 16 kHz audio,
    2) split into fixed-length segments, 3) batch-transcribe, 4) write an SRT
    and burn it into the video with ffmpeg. The temp directory is always
    removed, and any exception is reported as an error status instead of
    propagating into the UI.
    """
    import traceback  # local import: print_exc() is used below; keep the block self-contained

    tmp_dir = tempfile.mkdtemp()
    try:
        if not video_in:
            yield "❌ Aucune vidéo sélectionnée.", None
            return

        # Phase 1: extract mono 16 kHz audio (the input format the ASR models expect).
        yield "⏳ Phase 1/4 : Extraction audio...", None
        full_wav = os.path.join(tmp_dir, "full.wav")
        # List-form argv (shell=False) so paths with spaces/quotes can't break
        # the command or be interpreted by a shell.
        subprocess.run(
            ["ffmpeg", "-y", "-i", video_in, "-vn", "-ac", "1", "-ar", "16000", full_wav],
            check=True,
        )

        # Phase 2: fixed-length segmentation (stream copy, no re-encode).
        yield f"⏳ Phase 2/4 : Segmentation ({SEGMENT_DURATION}s)...", None
        subprocess.run(
            ["ffmpeg", "-i", full_wav, "-f", "segment",
             "-segment_time", str(SEGMENT_DURATION), "-c", "copy",
             os.path.join(tmp_dir, "seg_%03d.wav")],
            check=True,
        )

        files = sorted(glob.glob(os.path.join(tmp_dir, "seg_*.wav")))
        # Safety: drop corrupt or near-empty trailing segments (< 1.5 kB).
        valid_segments = [f for f in files if os.path.getsize(f) > 1500]
        if not valid_segments:
            yield "❌ Erreur : Audio trop court ou invalide.", None
            return

        # Phase 3: transcription.
        yield f"⏳ Phase 3/4 : Chargement de {model_name}...", None
        model = get_model(model_name)

        yield f"🎙️ Transcription de {len(valid_segments)} segments...", None
        b_size = 16 if DEVICE == "cuda" else 2
        with torch.inference_mode():
            batch_hypotheses = model.transcribe(
                valid_segments, batch_size=b_size, return_hypotheses=True
            )

        all_words_ts = []
        for idx, hyp in enumerate(batch_hypotheses):
            base_time = idx * SEGMENT_DURATION
            # Some decoders return a list of hypotheses per segment; keep the best.
            if isinstance(hyp, list):
                hyp = hyp[0] if hyp else ""
            text = hyp.text if hasattr(hyp, "text") else str(hyp)
            words = text.split()
            if not words:
                continue
            # No per-word timestamps from the model: spread words evenly
            # across the segment's duration.
            gap = SEGMENT_DURATION / len(words)
            for i, w in enumerate(words):
                all_words_ts.append({
                    "word": w,
                    "start": base_time + i * gap,
                    "end": base_time + (i + 1) * gap,
                })

        # Phase 4: write the SRT and burn it into the video.
        yield "⏳ Phase 4/4 : Encodage final...", None
        srt_path = os.path.join(tmp_dir, "final.srt")
        with open(srt_path, "w", encoding="utf-8") as f:
            for i in range(0, len(all_words_ts), 6):  # max 6 words per cue
                chunk = all_words_ts[i:i + 6]
                f.write(
                    f"{(i // 6) + 1}\n"
                    f"{format_srt_time(chunk[0]['start'])} --> {format_srt_time(chunk[-1]['end'])}\n"
                )
                f.write(" ".join(c["word"] for c in chunk) + "\n\n")

        out_path = os.path.abspath(f"resultat_{int(time.time())}.mp4")
        # ffmpeg subtitles-filter path syntax: forward slashes, escaped colons
        # (needed on Windows drive letters, harmless on Linux).
        safe_srt = srt_path.replace("\\", "/").replace(":", "\\:")
        subprocess.run(
            ["ffmpeg", "-y", "-i", video_in,
             "-vf",
             f"subtitles='{safe_srt}':force_style='Alignment=2,FontSize=18,PrimaryColour=&H00FFFF'",
             "-c:v", "libx264", "-preset", "ultrafast", "-c:a", "copy", out_path],
            check=True,
        )

        yield "✅ Terminé !", out_path

    except Exception as e:
        traceback.print_exc()
        yield f"❌ Erreur : {str(e)}", None
    finally:
        if os.path.exists(tmp_dir):
            shutil.rmtree(tmp_dir)
142
 
143
+ # --- INTERFACE ---
144
# Build the Gradio UI: one input column (video + model picker + run button)
# and one output column (streamed status + subtitled result).
model_names = list(MODELS.keys())

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.HTML("<div style='text-align:center;'><h1>🤖 RobotsMali Speech Lab</h1></div>")

    with gr.Row():
        with gr.Column():
            v_input = gr.Video(label="Vidéo Source")
            m_input = gr.Dropdown(
                choices=model_names,
                value="Soloni V3 (TDT-CTC)",
                label="Modèle",
            )
            run_btn = gr.Button("🚀 GÉNÉRER SOUS-TITRES", variant="primary")
        with gr.Column():
            status = gr.Markdown("### État\nPrêt.")
            v_output = gr.Video(label="Résultat")

    # `pipeline` is a generator, so status/result stream to the UI as it runs.
    run_btn.click(pipeline, [v_input, m_input], [status, v_output])

demo.queue().launch()