binaryMao committed on
Commit 49bfd29 · verified · 1 Parent(s): bfcc59d

Update app.py

Files changed (1): app.py (+269, -140)
app.py CHANGED
@@ -1,10 +1,20 @@
  # -*- coding: utf-8 -*-
  """
- ROBOTSMALI V38 FINAL BAMBARA SUBTITLING (NETFLIX STYLE)
- V38.1 fix: fixed FFmpeg (no -c copy), exact duration, working QuartzNet
  """

- import os, tempfile, traceback, random, textwrap
  import numpy as np
  import torch
  import soundfile as sf
@@ -12,7 +22,6 @@ import librosa
  from huggingface_hub import snapshot_download
  from nemo.collections import asr as nemo_asr
  import gradio as gr
- from moviepy.editor import VideoFileClip

  # ----------------------------
  # CONFIG
@@ -25,211 +34,331 @@ torch.manual_seed(1234)
  MODELS = {
      "Soloni V1 (RNNT)": ("RobotsMali/soloni-114m-tdt-ctc-v1", "rnnt"),
      "Soloni V0 (RNNT)": ("RobotsMali/soloni-114m-tdt-ctc-v0", "rnnt"),
-     "Soloba V1 (CTC)": ("RobotsMali/soloba-ctc-0.6b-v1", "ctc"),
-     "Soloba V0 (CTC)": ("RobotsMali/soloba-ctc-0.6b-v0", "ctc"),
-     "QuartzNet V1 (CTC-char)": ("RobotsMali/stt-bm-quartznet15x5-v1", "ctc_char"),
-     "QuartzNet V0 (CTC-char)": ("RobotsMali/stt-bm-quartznet15x5-v0", "ctc_char"),
  }

  _cache = {}

  # ----------------------------
- # MODEL LOADING
  # ----------------------------
  def load_model(name):
-     if name in _cache: return _cache[name]
      repo, mode = MODELS[name]
      folder = snapshot_download(repo, local_dir_use_symlinks=False)
      nemo_file = next((os.path.join(folder, f) for f in os.listdir(folder) if f.endswith(".nemo")), None)
      if not nemo_file:
-         raise FileNotFoundError(f"Aucun .nemo trouvé pour {name}")
-     model = (
-         nemo_asr.models.EncDecHybridRNNTCTCBPEModel.restore_from(nemo_file)
-         if mode == "rnnt"
-         else nemo_asr.models.EncDecCTCModelBPE.restore_from(nemo_file)
-     )
      model.to(DEVICE).eval()
      _cache[name] = model
      return model

  # ----------------------------
  # AUDIO EXTRACTION & CLEANING
  # ----------------------------
- def extract_audio(video, wav):
-     os.system(f'ffmpeg -y -i "{video}" -ar 16000 -ac 1 -vn "{wav}"')

- def clean_audio(wav, top_db=35):
-     audio, sr = sf.read(wav)
      if audio.ndim == 2:
-         audio = audio.mean(1)
-     max_val = np.max(np.abs(audio)) if audio.size > 0 else 0
      if max_val > 1e-6:
          audio = audio / max_val * 0.9
-     clean = wav.replace(".wav", "_clean.wav")
-     sf.write(clean, audio, sr)
-     return clean, audio, sr

  # ----------------------------
  # TRANSCRIPTION
  # ----------------------------
- def transcribe(model, wav):
-     out = model.transcribe([wav])
      if isinstance(out, list):
-         if out and hasattr(out[0], "text"):
-             return out[0].text.strip()
-         if out and isinstance(out[0], str):
-             return out[0].strip()
      if hasattr(out, "text"):
          return out.text.strip()
      return str(out).strip()

  # ----------------------------
- # UTILITIES
  # ----------------------------
  def keep_bambara(words):
-     res=[]
      for w in words:
-         wl=w.lower()
-         if any(c in wl for c in ["ɛ","ɔ","ŋ"]) or sum(c in "aeiou" for c in wl)>=2:
              res.append(w)
      return res

- MAX_CHARS=45; MIN_DUR=0.3; MAX_DUR=3.2; MAX_WORDS=8

  def wrap2(txt):
-     parts=textwrap.wrap(txt,MAX_CHARS)
-     if len(parts)<=1: return txt
-     mid=len(txt)//2
-     left=txt.rfind(" ",0,mid)
-     right=txt.find(" ",mid)
-     cut=left if (mid-left)<=(right-mid if right!=-1 else 1e9) else right
-     l1=txt[:cut].strip(); l2=txt[cut:].strip()
-     return l1+"\n"+l2 if l2 else l1
-
- def pack(spans,total):
-     tmp=[]
-     for s,e,t in spans:
-         s=max(0,min(s,total)); e=max(0,min(e,total))
-         if e<=s or not t.strip(): continue
-         tmp.append((s,e,t.strip()))
-     merged=[]
      for seg in tmp:
-         if not merged: merged.append(seg); continue
-         ps,pe,pt=merged[-1]; s,e,t=seg
-         if (e-s)<MIN_DUR or (s-pe)<0.1:
-             merged[-1]=(ps,max(pe,e),(pt+" "+t).strip())
-         else: merged.append(seg)
-     out=[]; last_end=0
-     for s,e,t in merged:
-         dur=e-s; words=t.split()
-         blocks=[" ".join(words[i:i+MAX_WORDS]) for i in range(0,len(words),MAX_WORDS)]
-         step=dur/max(1,len(blocks)); base=s
          for b in blocks:
-             st=base; en=min(base+step,e); base=en
-             if en<=st: en=min(st+0.05,total)
-             txt=wrap2(b)
-             if st<last_end: st=last_end+1e-3; en=max(en,st+0.05)
-             out.append((st,en,txt)); last_end=en
      return out

  # ----------------------------
- # SIMPLE ALIGNMENT (VAD)
  # ----------------------------
- def align_vad(text,audio,sr,total_dur,top_db=28):
-     words=keep_bambara(text.split())
-     total=total_dur
-     iv=librosa.effects.split(audio,top_db=top_db)
-     if len(iv)==0 or not words:
-         return pack([(0,total," ".join(words[:MAX_WORDS]))],total)
-     spans=[]; L=sum(e-s for s,e in iv); idx=0
-     for s,e in iv:
-         seg=e-s; segt=seg/sr; k=max(1,int(round(len(words)*(seg/L))))
-         chunk=words[idx:idx+k]; idx+=k
          if not chunk: continue
-         lines=[chunk[i:i+MAX_WORDS] for i in range(0,len(chunk),MAX_WORDS)]
-         step=max(MIN_DUR,min(MAX_DUR,segt/len(lines))); base=s/sr
-         for j,ln in enumerate(lines):
-             st=base+j*step; en=base+(j+1)*step
-             spans.append((st,en," ".join(ln)))
-     return pack(spans,total)

  # ----------------------------
- # SRT SUBTITLES + FFmpeg
  # ----------------------------
- def burn(video, subs):
-     tmp_srt = tempfile.mktemp(suffix=".srt")
-     out_file = "RobotsMali_Subtitled.mp4"
-
      def sec_to_srt(t):
-         h = int(t // 3600)
-         m = int((t % 3600) // 60)
-         s = int(t % 60)
-         ms = int((t - int(t)) * 1000)
          return f"{h:02}:{m:02}:{s:02},{ms:03}"
-
-     # Write the SRT file
      with open(tmp_srt, "w", encoding="utf-8") as f:
          for i, (start, end, text) in enumerate(subs, 1):
              f.write(f"{i}\n{sec_to_srt(start)} --> {sec_to_srt(end)}\n{text}\n\n")

-     # Burn in with re-encoding (fixed)
-     cmd = (
-         f'ffmpeg -y -i "{video}" '
-         f'-vf "subtitles={tmp_srt}:force_style=\'Fontsize=24,PrimaryColour=&HFFFFFF&,OutlineColour=&H000000&\'" '
-         f'-c:v libx264 -preset ultrafast -crf 23 -c:a copy "{out_file}"'
-     )
-     os.system(cmd)
-
-     if os.path.exists(tmp_srt):
-         os.remove(tmp_srt)
-     return out_file

  # ----------------------------
- # MAIN PIPELINE
  # ----------------------------
- def pipeline(video, model_name):
      try:
-         wav=tempfile.mktemp(suffix=".wav")
-         extract_audio(video,wav)
-         clean,audio,sr=clean_audio(wav)
-         model=load_model(model_name)
-         text=transcribe(model,clean)
-         mode=MODELS[model_name][1]
-         if mode=="rnnt":
-             from ctc_segmentation import ctc_segmentation,CtcSegmentationParameters,prepare_text
-             words=keep_bambara(text.split())
-             if not words: return "⚠️ Aucun sous-titre utilisable",None
-             x=torch.tensor(audio).float().unsqueeze(0).to(DEVICE)
-             ln=torch.tensor([x.shape[1]]).to(DEVICE)
-             with torch.no_grad(): logits=model(input_signal=x,input_signal_length=ln)[0]
-             tps=VideoFileClip(video).duration/logits.shape[1]
-             raw=model.tokenizer.vocab
-             vocab=list(raw.keys()) if isinstance(raw,dict) else list(raw)
-             cfg=CtcSegmentationParameters(); cfg.char_list=vocab
-             gt=prepare_text(cfg,words)[0]
-             timing,_,_=ctc_segmentation(cfg,logits.detach().cpu().numpy()[0],gt)
-             spans=[(timing[i]*tps,timing[i+1]*tps,words[i]) for i in range(len(words))]
-             subs=pack(spans,VideoFileClip(video).duration)
          else:
-             subs=align_vad(text,audio,sr,VideoFileClip(video).duration)
-         if not subs: return "⚠️ Aucun sous-titre utilisable",None
-         out=burn(video,subs)
-         return "✅ Terminé avec succès",out
-     except Exception:
          traceback.print_exc()
-         return "❌ Erreur — voir logs ci-dessus",None

  # ----------------------------
- # GRADIO INTERFACE
  # ----------------------------
- with gr.Blocks(title="RobotsMali V38.1 Final") as demo:
-     gr.Markdown("## ⚡ RobotsMali V38.1 — Sous-titrage Style Netflix (QuartzNet & RNNT stable)")
      v = gr.Video(label="Vidéo à sous-titrer")
      m = gr.Dropdown(list(MODELS.keys()), value="Soloba V1 (CTC)", label="Modèle ASR")
      b = gr.Button("▶️ Générer")
      s = gr.Markdown()
      o = gr.Video(label="Vidéo sous-titrée")
-
      b.click(pipeline, [v, m], [s, o])

- demo.launch(share=True, debug=False)
  # -*- coding: utf-8 -*-
  """
+ ROBOTSMALI V41 - Bambara subtitling (QuartzNet fix + robust RNNT/CTC)
+ - Loading: handles RNNT, CTC-BPE (Soloba) and CTC-char (QuartzNet) correctly
+ - Segmentation: guarded ctc_segmentation -> VAD fallback
+ - Subtitle burn-in: re-encode (libx264) whenever a subtitles filter is applied
  """

+ import os
+ import shlex
+ import subprocess
+ import tempfile
+ import traceback
+ import random
+ import textwrap
+ from pathlib import Path
+
  import numpy as np
  import torch
  import soundfile as sf
  import librosa
  from huggingface_hub import snapshot_download
  from nemo.collections import asr as nemo_asr
  import gradio as gr

  # ----------------------------
  # CONFIG
  ...
  MODELS = {
      "Soloni V1 (RNNT)": ("RobotsMali/soloni-114m-tdt-ctc-v1", "rnnt"),
      "Soloni V0 (RNNT)": ("RobotsMali/soloni-114m-tdt-ctc-v0", "rnnt"),
+     "Soloba V1 (CTC)": ("RobotsMali/soloba-ctc-0.6b-v1", "ctc"),  # BPE
+     "Soloba V0 (CTC)": ("RobotsMali/soloba-ctc-0.6b-v0", "ctc"),  # BPE
+     "QuartzNet V1 (CTC-char)": ("RobotsMali/stt-bm-quartznet15x5-v1", "ctc_char"),  # char
+     "QuartzNet V0 (CTC-char)": ("RobotsMali/stt-bm-quartznet15x5-v0", "ctc_char"),  # char
  }

  _cache = {}

  # ----------------------------
+ # UTIL: run_cmd, ffprobe_duration
+ # ----------------------------
+ def run_cmd(cmd):
+     """Execute a shell command and raise on non-zero exit."""
+     print("RUN:", cmd)
+     res = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
+     if res.returncode != 0:
+         raise RuntimeError(f"Commande échouée [{cmd}]\nOutput:\n{res.stdout}")
+     return res.stdout
+
+ def ffprobe_duration(path):
+     cmd = f'ffprobe -v error -select_streams v:0 -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 {shlex.quote(path)}'
+     out = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
+     if out.returncode != 0:
+         print("ffprobe erreur:", out.stderr)
+         return None
+     try:
+         return float(out.stdout.strip())
+     except Exception:
+         return None
+
+ # ----------------------------
+ # LOAD MODEL (robust)
  # ----------------------------
  def load_model(name):
+     """Load the right NeMo model class for the given type (rnnt / ctc / ctc_char)."""
+     if name in _cache:
+         return _cache[name]
+
      repo, mode = MODELS[name]
+     print(f"[LOAD] snapshot_download {repo} ...")
      folder = snapshot_download(repo, local_dir_use_symlinks=False)
      nemo_file = next((os.path.join(folder, f) for f in os.listdir(folder) if f.endswith(".nemo")), None)
      if not nemo_file:
+         raise FileNotFoundError(f"Aucun .nemo trouvé pour {name} dans {folder}")
+
+     print(f"[LOAD] .nemo trouvé: {nemo_file}; mode={mode}")
+
+     # Select the NeMo class according to the mode
+     if mode == "rnnt":
+         model = nemo_asr.models.EncDecHybridRNNTCTCBPEModel.restore_from(nemo_file)
+     elif mode == "ctc_char":
+         # QuartzNet (char): no BPE tokenizer in cfg -> use EncDecCTCModel
+         model = nemo_asr.models.EncDecCTCModel.restore_from(nemo_file)
+     else:  # mode == "ctc" (BPE)
+         try:
+             model = nemo_asr.models.EncDecCTCModelBPE.restore_from(nemo_file)
+         except Exception as e:
+             # cautious fallback to EncDecCTCModel if the BPE restore fails
+             print(f"[WARN] EncDecCTCModelBPE failed ({e}), fallback EncDecCTCModel")
+             model = nemo_asr.models.EncDecCTCModel.restore_from(nemo_file)
+
      model.to(DEVICE).eval()
      _cache[name] = model
+     print(f"[OK] Modèle {name} chargé sur {DEVICE}")
      return model

  # ----------------------------
  # AUDIO EXTRACTION & CLEANING
  # ----------------------------
+ def extract_audio(video_path, out_wav):
+     """Extract mono 16 kHz WAV using ffmpeg."""
+     cmd = f'ffmpeg -hide_banner -loglevel error -y -i {shlex.quote(video_path)} -vn -ac 1 -ar 16000 -f wav {shlex.quote(out_wav)}'
+     run_cmd(cmd)

+ def clean_audio(wav_path, target_sr=16000):
+     """Load audio, ensure mono, resample to target_sr, normalize, write cleaned wav."""
+     audio, sr = sf.read(wav_path)
      if audio.ndim == 2:
+         audio = audio.mean(axis=1)
+     if sr != target_sr:
+         audio = librosa.resample(audio.astype(float), orig_sr=sr, target_sr=target_sr)
+         sr = target_sr
+     max_val = np.max(np.abs(audio)) if audio.size > 0 else 0.0
      if max_val > 1e-6:
          audio = audio / max_val * 0.9
+     clean_path = str(Path(wav_path).with_name(Path(wav_path).stem + "_clean.wav"))
+     sf.write(clean_path, audio, sr)
+     return clean_path, audio, sr

  # ----------------------------
  # TRANSCRIPTION
  # ----------------------------
+ def transcribe(model, wav_path):
+     """Robust: call model.transcribe and normalize the output."""
+     if not hasattr(model, "transcribe"):
+         raise RuntimeError("Le modèle ne supporte pas model.transcribe()")
+     out = model.transcribe([wav_path])
+     # The output can take several shapes
      if isinstance(out, list):
+         if len(out) == 0:
+             return ""
+         first = out[0]
+         if isinstance(first, str):
+             return first.strip()
+         if hasattr(first, "text"):
+             return first.text.strip()
+         return str(first).strip()
      if hasattr(out, "text"):
          return out.text.strip()
      return str(out).strip()

  # ----------------------------
+ # Subtitle / packing utilities
  # ----------------------------
  def keep_bambara(words):
+     res = []
      for w in words:
+         wl = w.lower()
+         if any(c in wl for c in ["ɛ","ɔ","ŋ"]) or sum(1 for c in wl if c in "aeiou") >= 2:
              res.append(w)
      return res

+ MAX_CHARS = 45; MIN_DUR = 0.3; MAX_DUR = 3.2; MAX_WORDS = 8

  def wrap2(txt):
+     parts = textwrap.wrap(txt, MAX_CHARS)
+     if len(parts) <= 1:
+         return txt
+     mid = len(txt) // 2
+     left = txt.rfind(" ", 0, mid)
+     right = txt.find(" ", mid)
+     cut = left if (mid - left) <= ((right - mid) if right != -1 else 1e9) else right
+     l1 = txt[:cut].strip(); l2 = txt[cut:].strip()
+     return l1 + "\n" + l2 if l2 else l1
+
+ def pack(spans, total):
+     tmp = []
+     for s, e, t in spans:
+         s = max(0, min(s, total)); e = max(0, min(e, total))
+         if e <= s or not t.strip(): continue
+         tmp.append((s, e, t.strip()))
+     merged = []
      for seg in tmp:
+         if not merged:
+             merged.append(seg); continue
+         ps, pe, pt = merged[-1]; s, e, t = seg
+         if (e - s) < MIN_DUR or (s - pe) < 0.1:
+             merged[-1] = (ps, max(pe, e), (pt + " " + t).strip())
+         else:
+             merged.append(seg)
+     out = []; last_end = 0
+     for s, e, t in merged:
+         dur = e - s; words = t.split()
+         blocks = [" ".join(words[i:i+MAX_WORDS]) for i in range(0, len(words), MAX_WORDS)]
+         step = dur / max(1, len(blocks))
+         base = s
          for b in blocks:
+             st = base; en = min(base + step, e); base = en
+             if en <= st: en = min(st + 0.05, total)
+             txt = wrap2(b)
+             if st < last_end:
+                 st = last_end + 1e-3; en = max(en, st + 0.05)
+             out.append((st, en, txt)); last_end = en
      return out

  # ----------------------------
+ # VAD ALIGN (fallback alignment)
  # ----------------------------
+ def align_vad(text, audio, sr, total_dur, top_db=28):
+     words = keep_bambara(text.split())
+     total = total_dur
+     if audio is None or len(audio) == 0 or not words:
+         return pack([(0, total, " ".join(words[:MAX_WORDS]))], total)
+     iv = librosa.effects.split(audio, top_db=top_db)
+     if len(iv) == 0:
+         return pack([(0, total, " ".join(words[:MAX_WORDS]))], total)
+     spans = []
+     L = sum(e - s for s, e in iv)
+     idx = 0
+     for s, e in iv:
+         seg = e - s; segt = seg / sr
+         k = max(1, int(round(len(words) * (seg / L))))
+         chunk = words[idx:idx+k]; idx += k
          if not chunk: continue
+         lines = [chunk[i:i+MAX_WORDS] for i in range(0, len(chunk), MAX_WORDS)]
+         step = max(MIN_DUR, min(MAX_DUR, segt / max(1, len(lines))))
+         base = s / sr
+         for j, ln in enumerate(lines):
+             st = base + j * step; en = base + (j + 1) * step
+             spans.append((st, en, " ".join(ln)))
+     return pack(spans, total)

  # ----------------------------
+ # SRT writing + burn-in (re-encode)
  # ----------------------------
+ def burn(video_path, subs, output_path=None):
+     if output_path is None:
+         output_path = "RobotsMali_Subtitled.mp4"
+     tmp_fd, tmp_srt = tempfile.mkstemp(suffix=".srt")
+     os.close(tmp_fd)
      def sec_to_srt(t):
+         h = int(t // 3600); m = int((t % 3600) // 60); s = int(t % 60); ms = int((t - int(t)) * 1000)
          return f"{h:02}:{m:02}:{s:02},{ms:03}"
      with open(tmp_srt, "w", encoding="utf-8") as f:
          for i, (start, end, text) in enumerate(subs, 1):
              f.write(f"{i}\n{sec_to_srt(start)} --> {sec_to_srt(end)}\n{text}\n\n")

+     # Re-encode (libx264) because a subtitles filter is applied
+     vf = f"subtitles={shlex.quote(tmp_srt)}:force_style='Fontsize=22,PrimaryColour=&HFFFFFF&,OutlineColour=&H000000&'"
+     cmd = f'ffmpeg -hide_banner -loglevel error -y -i {shlex.quote(video_path)} -vf {shlex.quote(vf)} -c:v libx264 -preset fast -crf 23 -c:a aac -b:a 192k {shlex.quote(output_path)}'
+     try:
+         run_cmd(cmd)
+     finally:
+         if os.path.exists(tmp_srt):
+             os.remove(tmp_srt)
+     return output_path

  # ----------------------------
+ # MAIN PIPELINE (V41)
  # ----------------------------
+ def pipeline(video_input, model_name):
+     """
+     video_input : path or Gradio dict (tmp_path)
+     model_name  : key into MODELS
+     """
      try:
+         # support Gradio dict (tmp_path)
+         if isinstance(video_input, dict) and "tmp_path" in video_input:
+             video_path = video_input["tmp_path"]
          else:
+             video_path = video_input
+
+         duration = ffprobe_duration(video_path)
+         if duration is None:
+             raise RuntimeError("Impossible d'obtenir la durée de la vidéo via ffprobe")
+
+         # temporary files
+         tmp_fd, tmp_wav = tempfile.mkstemp(suffix=".wav")
+         os.close(tmp_fd)
+
+         # extraction + cleaning
+         extract_audio(video_path, tmp_wav)
+         clean_wav, audio, sr = clean_audio(tmp_wav)
+
+         # load the model and transcribe
+         model = load_model(model_name)
+         text = transcribe(model, clean_wav)
+         mode = MODELS[model_name][1]
+
+         # segmentation / alignment
+         subs = None
+         if mode == "rnnt":
+             # RNNT: try segmentation via logits + ctc_segmentation when available
+             try:
+                 from ctc_segmentation import ctc_segmentation, CtcSegmentationParameters, prepare_text
+                 words = keep_bambara(text.split())
+                 if not words:
+                     return ("⚠️ Aucun sous-titre utilisable (texte vide après filtrage)", None)
+                 x = torch.tensor(audio).float().unsqueeze(0).to(DEVICE)
+                 ln = torch.tensor([x.shape[1]]).to(DEVICE)
+                 with torch.no_grad():
+                     logits = model(input_signal=x, input_signal_length=ln)[0]
+                 # heuristic frames -> seconds mapping
+                 time_per_frame = duration / max(1, logits.shape[1])
+                 # build the char list
+                 try:
+                     raw = model.tokenizer.vocab
+                     vocab = list(raw.keys()) if isinstance(raw, dict) else list(raw)
+                 except Exception:
+                     vocab = None
+                 cfg = CtcSegmentationParameters()
+                 if vocab:
+                     cfg.char_list = vocab
+                 gt = prepare_text(cfg, words)[0]
+                 try:
+                     timing, _, _ = ctc_segmentation(cfg, logits.detach().cpu().numpy()[0], gt)
+                     spans = [(timing[i] * time_per_frame, timing[i+1] * time_per_frame, words[i]) for i in range(len(words) - 1)]
+                     subs = pack(spans, duration)
+                 except AssertionError:
+                     print("[WARN] Audio shorter than text -> fallback to VAD alignment")
+                     subs = align_vad(text, audio, sr, duration)
+             except Exception as e:
+                 print(f"[WARN] ctc_segmentation not available or failed ({e}) -> fallback VAD")
+                 subs = align_vad(text, audio, sr, duration)
+
+         elif mode == "ctc_char":
+             # QuartzNet: no BPE tokenizer, so align with VAD (the model rarely exposes timestamps)
+             subs = align_vad(text, audio, sr, duration)
+
+         else:  # ctc (BPE)
+             # For CTC-BPE models, VAD alignment remains a reasonable option when no segmentation is available
+             subs = align_vad(text, audio, sr, duration)
+
+         if not subs:
+             return ("⚠️ Aucun sous-titre utilisable (sub list vide)", None)
+
+         out_video = burn(video_path, subs)
+         return ("✅ Terminé avec succès", out_video)
+
+     except Exception as e:
          traceback.print_exc()
+         return (f"❌ Erreur — {str(e)}", None)

  # ----------------------------
+ # GRADIO INTERFACE (optional)
  # ----------------------------
+ with gr.Blocks(title="RobotsMali V41 - Sous-titrage") as demo:
+     gr.Markdown("## ⚡ RobotsMali V41 — Sous-titrage (QuartzNet fix + RNNT/CTC robust)")
      v = gr.Video(label="Vidéo à sous-titrer")
      m = gr.Dropdown(list(MODELS.keys()), value="Soloba V1 (CTC)", label="Modèle ASR")
      b = gr.Button("▶️ Générer")
      s = gr.Markdown()
      o = gr.Video(label="Vidéo sous-titrée")
      b.click(pipeline, [v, m], [s, o])

+ # To run the interface without debug output:
+ # demo.launch(share=True, debug=False)
+ demo.launch(share=True, debug=True)
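
Usage note: a minimal sketch of driving the new pipeline headlessly, outside the Gradio UI. The file name sample.mp4 is illustrative and not part of the commit; also note that app.py as committed calls demo.launch() at module level, so importing it launches the UI unless that line is disabled first.

# Hypothetical smoke test (not part of the commit).
# Assumes ffmpeg/ffprobe on PATH, NeMo and the models downloadable,
# a local "sample.mp4", and the module-level demo.launch(...) commented out.
from app import pipeline

status, out_video = pipeline("sample.mp4", "Soloba V1 (CTC)")
print(status)  # "✅ Terminé avec succès" on success, "❌ Erreur — ..." otherwise
if out_video:
    print("Subtitled video:", out_video)  # RobotsMali_Subtitled.mp4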