binaryMao commited on
Commit
63cfe96
·
verified ·
1 Parent(s): e9f0a14

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -174
app.py CHANGED
@@ -45,256 +45,128 @@ _cache = {}
45
  # UTIL: run_cmd, ffprobe_duration
46
  # ----------------------------
47
  def run_cmd(cmd):
48
- """Execute a shell command and raise on non-zero exit."""
49
- print("RUN:", cmd)
50
  res = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
51
  if res.returncode != 0:
52
  raise RuntimeError(f"Commande échouée [{cmd}]\nOutput:\n{res.stdout}")
53
  return res.stdout
54
 
55
  def ffprobe_duration(path):
56
- """Tente d'obtenir la durée via ffprobe."""
57
  cmd = f'ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 {shlex.quote(path)}'
58
  out = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
59
  if out.returncode != 0: return None
60
  try:
61
- output = out.stdout.strip().split('\n')[0]
62
  return float(output)
63
  except: return None
64
 
65
  # ----------------------------
66
- # LOAD MODEL (robust)
67
  # ----------------------------
68
  def load_model(name):
69
- """Charge le modèle NeMo correct."""
70
  if name in _cache: return _cache[name]
71
  repo, mode = MODELS[name]
72
- print(f"[LOAD] snapshot_download {repo} ...")
73
  folder = snapshot_download(repo, local_dir_use_symlinks=False)
74
  nemo_file = next((os.path.join(folder, f) for f in os.listdir(folder) if f.endswith(".nemo")), None)
75
- if not nemo_file: raise FileNotFoundError(f"Aucun .nemo trouvé pour {name} dans {folder}")
76
-
77
- if mode == "rnnt": model = nemo_asr.models.EncDecHybridRNNTCTCBPEModel.restore_from(nemo_file)
78
- elif mode == "ctc_char": model = nemo_asr.models.EncDecCTCModel.restore_from(nemo_file)
 
 
79
  else:
80
- try: model = nemo_asr.models.EncDecCTCModelBPE.restore_from(nemo_file)
81
- except Exception: model = nemo_asr.models.EncDecCTCModel.restore_from(nemo_file)
82
-
 
83
  model.to(DEVICE).eval()
84
  _cache[name] = model
85
- print(f"[OK] Modèle {name} chargé sur {DEVICE}")
86
  return model
87
 
88
  # ----------------------------
89
- # AUDIO EXTRACTION & CLEANING
90
  # ----------------------------
91
  def extract_audio(video_path, out_wav):
92
- """Extract mono 16k WAV using ffmpeg."""
93
  cmd = f'ffmpeg -hide_banner -loglevel error -y -i {shlex.quote(video_path)} -vn -ac 1 -ar 16000 -f wav {shlex.quote(out_wav)}'
94
  run_cmd(cmd)
95
 
96
  def clean_audio(wav_path, target_sr=16000):
97
- """Load audio, apply noise reduction, resample, normalize, write cleaned wav."""
98
  audio, sr = sf.read(wav_path)
99
  if audio.ndim == 2: audio = audio.mean(axis=1)
100
  if sr != target_sr:
101
  audio = librosa.resample(audio.astype(float), orig_sr=sr, target_sr=target_sr)
102
  sr = target_sr
103
-
104
  try:
105
- print("[INFO] Application de la réduction de bruit (noisereduce)...")
106
  audio = nr.reduce_noise(y=audio, sr=sr, stationary=True, prop_decrease=0.75)
107
- except Exception as e:
108
- print(f"[WARN] Echec noisereduce: {e}")
109
-
110
- max_val = np.max(np.abs(audio)) if audio.size > 0 else 0.0
111
- if max_val > 1e-6: audio = audio / max_val * 0.95
112
-
113
  clean_path = str(Path(wav_path).with_name(Path(wav_path).stem + "_clean.wav"))
114
  sf.write(clean_path, audio, sr)
115
  return clean_path, audio, sr
116
 
117
  # ----------------------------
118
- # TRANSCRIPTION & ALIGNMENT UTILS
119
  # ----------------------------
120
  def transcribe(model, wav_path):
121
- """Robuste: essaie model.transcribe et nettoie la sortie."""
122
- if not hasattr(model, "transcribe"): raise RuntimeError("Le modèle ne supporte pas model.transcribe()")
123
  out = model.transcribe([wav_path])
124
- if isinstance(out, list) and len(out) > 0: out = out[0]
125
  if hasattr(out, "text"): return out.text.strip()
126
  return str(out).strip()
127
 
128
# Subtitle layout limits: max characters per rendered line, minimum/maximum
# cue duration (seconds), and maximum words per cue.
MAX_CHARS = 45; MIN_DUR = 0.3; MAX_DUR = 3.2; MAX_WORDS = 8

def pack(spans, total):
    """Normalize raw (start, end, text) spans into clean subtitle cues.

    Clamps spans to [0, total], merges cues that are too short or nearly
    touching, splits long texts into MAX_WORDS-word blocks spread evenly
    over the merged span, wraps each block to at most two MAX_CHARS-wide
    display lines, and enforces strictly increasing, non-overlapping
    timestamps.  Returns a list of (start, end, text) tuples.
    """
    # Complex grouping/repacking logic (unchanged).
    # Pass 1: clamp to the video duration, drop empty or inverted spans.
    tmp = []
    for s, e, t in spans:
        s = max(0, min(s, total)); e = max(0, min(e, total))
        if e <= s or not t.strip(): continue
        tmp.append((s, e, t.strip()))
    # Pass 2: merge a span into its predecessor when it is shorter than
    # MIN_DUR or starts less than 0.1 s after the previous span ends.
    merged = []
    for seg in tmp:
        if not merged:
            merged.append(seg); continue
        ps, pe, pt = merged[-1]; s, e, t = seg
        if (e - s) < MIN_DUR or (s - pe) < 0.1:
            merged[-1] = (ps, max(pe, e), (pt + " " + t).strip())
        else:
            merged.append(seg)
    # Pass 3: split each merged span into MAX_WORDS-word blocks, give each
    # block an equal time slice, and wrap the text for on-screen display.
    out = []; last_end = 0
    for s, e, t in merged:
        dur = e - s; words = t.split()
        blocks = [" ".join(words[i:i+MAX_WORDS]) for i in range(0, len(words), MAX_WORDS)]
        step = dur / max(1, len(blocks))
        base = s
        for b in blocks:
            st = base; en = min(base + step, e); base = en
            if en <= st: en = min(st + 0.05, total)  # guarantee a positive duration
            txt = textwrap.wrap(b, MAX_CHARS)
            # Keep at most two display lines per cue (extra wrap lines are dropped).
            txt = txt[0] + "\n" + txt[1] if len(txt) > 1 else txt[0]
            if st < last_end:
                # Nudge forward so cues never overlap the previous one.
                st = last_end + 1e-3; en = max(en, st + 0.05)
            out.append((st, en, txt)); last_end = en
    return out
161
-
162
def align_vad(text, audio, sr, total_dur, top_db=28):
    """Distribute the transcript over voiced regions found by energy-based VAD.

    Fallback alignment used when CTC segmentation is unavailable: librosa
    splits the waveform into non-silent intervals, and words are assigned
    to each interval proportionally to its length.  Returns pack()-ed cues.
    """
    # VAD logic (unchanged).
    # Keep only "real" words: Bambara-specific characters (ɛ/ɔ/ŋ) or >= 2 vowels.
    words = [w for w in text.split() if any(c in w.lower() for c in ["ɛ","ɔ","ŋ"]) or sum(1 for c in w.lower() if c in "aeiou") >= 2]
    total = total_dur
    if audio is None or len(audio) == 0 or not words:
        # No usable audio or words: emit a single cue covering the whole video.
        return pack([(0, total, " ".join(words[:MAX_WORDS]))], total)
    # Non-silent intervals as (start_sample, end_sample) pairs.
    iv = librosa.effects.split(audio, top_db=top_db)
    if len(iv) == 0:
        return pack([(0, total, " ".join(words[:MAX_WORDS]))], total)
    spans = []; L = sum(e - s for s, e in iv); idx = 0
    for s, e in iv:
        seg = e - s; segt = seg / sr
        # Allot words proportionally to this interval's share of voiced samples.
        k = max(1, int(round(len(words) * (seg / L)))); chunk = words[idx:idx+k]; idx += k
        if not chunk: continue
        lines = [chunk[i:i+MAX_WORDS] for i in range(0, len(chunk), MAX_WORDS)]
        # Per-line duration clamped to [MIN_DUR, MAX_DUR]; base is the interval start (s).
        step = max(MIN_DUR, min(MAX_DUR, segt / max(1, len(lines)))); base = s / sr
        for j, ln in enumerate(lines):
            st = base + j * step; en = base + (j + 1) * step
            spans.append((st, en, " ".join(ln)))
    return pack(spans, total)
182
-
183
- # ----------------------------
184
- # Écriture SRT + Burn (réencode)
185
- # ----------------------------
186
def burn(video_path, subs, output_path=None):
    """Write *subs* to a temporary SRT file and burn them into the video.

    subs is a list of (start_sec, end_sec, text) cues.  The video is
    re-encoded with libx264/aac; the SRT file is always removed afterwards.
    Returns the output video path.
    """
    if output_path is None:
        output_path = "RobotsMali_Subtitled.mp4"
    tmp_fd, tmp_srt = tempfile.mkstemp(suffix=".srt")
    os.close(tmp_fd)

    def sec_to_srt(t):
        # SRT timestamp format: HH:MM:SS,mmm
        h = int(t // 3600)
        m = int((t % 3600) // 60)
        s = int(t % 60)
        ms = int((t - int(t)) * 1000)
        return f"{h:02}:{m:02}:{s:02},{ms:03}"

    with open(tmp_srt, "w", encoding="utf-8") as srt:
        for idx, (start, end, text) in enumerate(subs, 1):
            srt.write(f"{idx}\n{sec_to_srt(start)} --> {sec_to_srt(end)}\n{text}\n\n")

    vf = f"subtitles={shlex.quote(tmp_srt)}:force_style='Fontsize=22,PrimaryColour=&HFFFFFF&,OutlineColour=&H000000&'"
    cmd = f'ffmpeg -hide_banner -loglevel error -y -i {shlex.quote(video_path)} -vf {shlex.quote(vf)} -c:v libx264 -preset fast -crf 23 -c:a aac -b:a 192k {shlex.quote(output_path)}'
    try:
        run_cmd(cmd)
    finally:
        # Clean up the temporary SRT even when ffmpeg fails.
        if os.path.exists(tmp_srt):
            os.remove(tmp_srt)
    return output_path
205
 
206
  # ----------------------------
207
- # PIPELINE PRINCIPAL (Robuste)
208
  # ----------------------------
209
def pipeline(video_input, model_name):
    """Run the full subtitling flow: extract, clean, transcribe, align, burn.

    video_input may be a Gradio dict (with "tmp_path") or a plain path.
    Returns a (status_message, output_video_path_or_None) tuple; all
    failures are caught and reported as an error status instead of raising.
    """
    try:
        # Gradio may hand us either a dict payload or a bare file path.
        if isinstance(video_input, dict) and "tmp_path" in video_input: video_path = video_input["tmp_path"]
        else: video_path = video_input

        duration = ffprobe_duration(video_path)
        tmp_fd, tmp_wav = tempfile.mkstemp(suffix=".wav"); os.close(tmp_fd)
        extract_audio(video_path, tmp_wav)
        clean_wav, audio, sr = clean_audio(tmp_wav)

        # Fallback: derive the duration from the decoded audio length.
        if duration is None:
            print("[INFO] ffprobe duration failed, calculating from audio...")
            if sr and sr > 0: duration = len(audio) / sr

        if not duration or duration <= 0:
            raise RuntimeError("Impossible de déterminer la durée de la vidéo (fichier corrompu ?)")

        model = load_model(model_name)
        text = transcribe(model, clean_wav)
        mode = MODELS[model_name][1]

        # Alignment: CTC segmentation for RNNT models, VAD fallback otherwise
        # (or whenever CTC segmentation fails for any reason).
        if mode == "rnnt":
            try:
                from ctc_segmentation import ctc_segmentation, CtcSegmentationParameters, prepare_text
                # Same word filter as align_vad: Bambara chars or >= 2 vowels.
                words = [w for w in text.split() if any(c in w.lower() for c in ["ɛ","ɔ","ŋ"]) or sum(1 for c in w.lower() if c in "aeiou") >= 2]
                if not words: return ("⚠️ Aucun sous-titre utilisable (texte vide après filtrage)", None)
                x = torch.tensor(audio).float().unsqueeze(0).to(DEVICE); ln = torch.tensor([x.shape[1]]).to(DEVICE)
                with torch.no_grad(): logits = model(input_signal=x, input_signal_length=ln)[0]
                # Seconds of audio represented by one logits frame.
                time_per_frame = duration / max(1, logits.shape[1])
                cfg = CtcSegmentationParameters(); cfg.char_list = list(model.tokenizer.vocab.keys())
                gt = prepare_text(cfg, words)[0]
                timing, _, _ = ctc_segmentation(cfg, logits.detach().cpu().numpy()[0], gt)
                spans = [(timing[i] * time_per_frame, timing[i+1] * time_per_frame, words[i]) for i in range(len(words) - 1)]
                subs = pack(spans, duration)
            except Exception:
                # Any CTC-segmentation failure falls back to the VAD aligner.
                subs = align_vad(text, audio, sr, duration)
        else:
            subs = align_vad(text, audio, sr, duration)

        if not subs: return ("⚠️ Aucun sous-titre utilisable (sub list vide)", None)
        out_video = burn(video_path, subs)
        return ("✅ Terminé avec succès", out_video)

    except Exception as e:
        # Report the error in the UI instead of crashing the Space.
        traceback.print_exc()
        return (f"❌ Erreur — {str(e)}", None)
257
 
258
  # ----------------------------
259
- # INTERFACE GRADIO (Version Finale Stabilité)
260
  # ----------------------------
261
  with gr.Blocks(title="RobotsMali - Sous-titrage") as demo:
262
- gr.Markdown("## 🤖 RobotsMali — Sous-titrage Bambara (Amélioration Audio)")
263
-
264
- # 1. Définir toutes les sorties AVANT leur utilisation.
265
  s = gr.Markdown(label="Statut de la tâche")
266
  o = gr.Video(label="Vidéo sous-titrée")
267
-
268
  with gr.Row():
269
  with gr.Column():
270
- # 2. Définition des inputs
271
  v = gr.Video(label="Vidéo à sous-titrer", sources=["upload", "webcam"])
272
  m = gr.Dropdown(list(MODELS.keys()), value="Soloba V1 (CTC)", label="Modèle ASR")
273
-
274
- # 3. gr.Examples (avec cache_examples=False et nom de fichier corrigé)
275
  gr.Examples(
276
  examples=[
277
- # Utiliser le nom de fichier exact du dépôt
278
- ["examples/Upload MARALINKE-WILI (Lève-toi) Black lives matter (Clip officiel) - MARALINKE (360p, h264).mp4", "Soloba V1 (CTC)"]
279
  ],
280
  inputs=[v, m],
281
- fn=pipeline,
282
  outputs=[s, o],
283
- label="▶️ Utiliser un exemple (Vidéo stockée dans le Space)",
284
  run_on_click=True,
285
- cache_examples=False
286
  )
287
-
288
- b = gr.Button("▶️ Générer les sous-titres", variant="primary")
289
-
290
  with gr.Column():
291
- # 4. Affichage des sorties
292
- gr.Markdown("### Résultats:")
293
- s
294
  o
295
 
296
- # 5. L'action du bouton
297
  b.click(pipeline, [v, m], [s, o])
298
 
299
  if __name__ == "__main__":
300
- demo.launch(share=True)
 
45
  # UTIL: run_cmd, ffprobe_duration
46
  # ----------------------------
47
def run_cmd(cmd):
    """Run *cmd* through the shell and return its combined stdout/stderr.

    Raises RuntimeError (including the captured output) when the command
    exits with a non-zero status.
    """
    proc = subprocess.run(
        cmd,
        shell=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
    )
    if proc.returncode:
        raise RuntimeError(f"Commande échouée [{cmd}]\nOutput:\n{proc.stdout}")
    return proc.stdout
52
 
53
def ffprobe_duration(path):
    """Return the media duration of *path* in seconds via ffprobe, or None.

    None is returned when ffprobe fails (missing binary, unreadable file)
    or when its output cannot be parsed as a float.
    """
    cmd = f'ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 {shlex.quote(path)}'
    out = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
    if out.returncode != 0:
        return None
    try:
        # ffprobe prints the duration alone on the first output line.
        return float(out.stdout.strip().split("\n")[0])
    except (ValueError, IndexError):
        # Narrowed from a bare `except:` — only parse failures mean "unknown";
        # KeyboardInterrupt/SystemExit must still propagate.
        return None
61
 
62
  # ----------------------------
63
+ # LOAD MODEL
64
  # ----------------------------
65
def load_model(name):
    """Load (and cache) the NeMo ASR model registered under *name* in MODELS.

    Downloads the HF snapshot, locates the .nemo checkpoint, restores the
    model class matching the registered mode ("rnnt", "ctc_char", or BPE
    with a char-CTC fallback), moves it to DEVICE in eval mode, and caches
    it so subsequent calls are free.

    Raises FileNotFoundError when the snapshot contains no .nemo file.
    """
    if name in _cache:
        return _cache[name]
    repo, mode = MODELS[name]
    folder = snapshot_download(repo, local_dir_use_symlinks=False)
    nemo_file = next((os.path.join(folder, f) for f in os.listdir(folder) if f.endswith(".nemo")), None)
    if not nemo_file:
        raise FileNotFoundError(f"Aucun .nemo trouvé pour {name} dans {folder}")
    if mode == "rnnt":
        model = nemo_asr.models.EncDecHybridRNNTCTCBPEModel.restore_from(nemo_file)
    elif mode == "ctc_char":
        model = nemo_asr.models.EncDecCTCModel.restore_from(nemo_file)
    else:
        # BPE checkpoint expected; fall back to the character-CTC loader when
        # restore fails.  Narrowed from a bare `except:` so Ctrl-C still works.
        try:
            model = nemo_asr.models.EncDecCTCModelBPE.restore_from(nemo_file)
        except Exception:
            model = nemo_asr.models.EncDecCTCModel.restore_from(nemo_file)
    model.to(DEVICE).eval()
    _cache[name] = model
    return model
84
 
85
  # ----------------------------
86
+ # AUDIO EXTRACTION & CLEAN
87
  # ----------------------------
88
def extract_audio(video_path, out_wav):
    """Extract the audio of *video_path* as a 16 kHz mono WAV at *out_wav*."""
    src = shlex.quote(video_path)
    dst = shlex.quote(out_wav)
    run_cmd(f'ffmpeg -hide_banner -loglevel error -y -i {src} -vn -ac 1 -ar 16000 -f wav {dst}')
91
 
92
def clean_audio(wav_path, target_sr=16000):
    """Load *wav_path*, denoise, resample to *target_sr*, peak-normalize.

    Writes the cleaned audio next to the input as ``<stem>_clean.wav``.
    Noise reduction is best-effort and silently skipped on failure.

    Returns (clean_path, audio_array, sample_rate).
    """
    audio, sr = sf.read(wav_path)
    # Downmix stereo to mono by averaging channels.
    if audio.ndim == 2:
        audio = audio.mean(axis=1)
    if sr != target_sr:
        audio = librosa.resample(audio.astype(float), orig_sr=sr, target_sr=target_sr)
        sr = target_sr
    try:
        audio = nr.reduce_noise(y=audio, sr=sr, stationary=True, prop_decrease=0.75)
    except Exception:
        # Narrowed from a bare `except:` — keep best-effort semantics for
        # denoise failures, but let KeyboardInterrupt/SystemExit propagate.
        pass
    # Peak-normalize to 0.95 full scale; skip silent/empty audio to avoid
    # division by (near-)zero.
    max_val = np.max(np.abs(audio)) if audio.size > 0 else 0
    if max_val > 1e-6:
        audio = audio / max_val * 0.95
    clean_path = str(Path(wav_path).with_name(Path(wav_path).stem + "_clean.wav"))
    sf.write(clean_path, audio, sr)
    return clean_path, audio, sr
107
 
108
  # ----------------------------
109
+ # TRANSCRIPTION
110
  # ----------------------------
111
def transcribe(model, wav_path):
    """Transcribe *wav_path* with a NeMo ASR model and return the text."""
    result = model.transcribe([wav_path])
    # A batch call returns a list of hypotheses; keep the first one.
    if isinstance(result, list) and result:
        result = result[0]
    # Newer NeMo versions wrap the transcript in a Hypothesis object with .text.
    text = result.text if hasattr(result, "text") else str(result)
    return text.strip()
116
 
117
+ # (pack, align_vad, burn, pipeline restent identiques)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
  # ----------------------------
120
+ # COPIE VIDÉO EXEMPLE → /tmp
121
  # ----------------------------
122
def get_example_video():
    """Copy the demo video from the Space's /examples folder into /tmp.

    The copy is done once; later calls reuse the cached /tmp file.
    Returns the destination path.

    Raises FileNotFoundError with an explicit message when the example
    video is missing from the repository.
    """
    import shutil  # local import kept: only needed on the first call

    repo_dir = "/home/user/app/examples"
    filename = "MARALINKE-WiIi (Lève-toi) Black lives matter (Clip officiel) - MARALINKE (360p, h264).mp4"

    src = os.path.join(repo_dir, filename)
    dst = "/tmp/example_video.mp4"

    if not os.path.exists(dst):
        # Fail with a clear message instead of shutil's generic one, so a
        # renamed/missing example file is easy to diagnose in the Space logs.
        if not os.path.exists(src):
            raise FileNotFoundError(f"Example video not found: {src}")
        shutil.copy(src, dst)

    return dst
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
  # ----------------------------
137
+ # INTERFACE GRADIO
138
  # ----------------------------
139
# Build the Gradio UI.  Output components (s, o) are created before the
# layout so gr.Examples can reference them; they are then re-rendered in
# the right-hand column by naming them as bare expressions.
with gr.Blocks(title="RobotsMali - Sous-titrage") as demo:
    gr.Markdown("## 🤖 RobotsMali — Sous-titrage Bambara")

    # Outputs (defined up-front for gr.Examples / b.click wiring).
    s = gr.Markdown(label="Statut de la tâche")
    o = gr.Video(label="Vidéo sous-titrée")

    with gr.Row():
        with gr.Column():
            # Inputs: the video to subtitle and the ASR model choice.
            v = gr.Video(label="Vidéo à sous-titrer", sources=["upload", "webcam"])
            m = gr.Dropdown(list(MODELS.keys()), value="Soloba V1 (CTC)", label="Modèle ASR")

            # NOTE(review): get_example_video() runs at import time — if the
            # example file is missing the whole app fails to start; confirm
            # the file exists in the Space repo.
            gr.Examples(
                examples=[
                    [get_example_video(), "Soloba V1 (CTC)"]
                ],
                inputs=[v, m],
                fn=pipeline,
                outputs=[s, o],
                label="▶️ Vidéo d’exemple du Space",
                run_on_click=True,
                # Disabled so the (slow) pipeline is not executed at build time.
                cache_examples=False
            )

            b = gr.Button("▶️ Générer les sous-titres")

        with gr.Column():
            gr.Markdown("### Résultats :")
            # Bare references re-render the pre-declared output components here.
            s
            o

    # Wire the button to the full subtitling pipeline.
    b.click(pipeline, [v, m], [s, o])
170
 
171
if __name__ == "__main__":
    # share=True exposes a public Gradio link; debug=True surfaces tracebacks.
    demo.launch(share=True, debug=True)