binaryMao committed on
Commit 60f6dea · verified · 1 Parent(s): 9683e37

Update app.py

Files changed (1)
  app.py +101 -131

app.py CHANGED
@@ -1,134 +1,112 @@
- import os, warnings, logging, tempfile
- warnings.filterwarnings("ignore")
- logging.getLogger("nemo_logger").setLevel(logging.ERROR)
-
- import torch
- torch.set_grad_enabled(False)
-
  import gradio as gr
  import numpy as np
  import soundfile as sf
-
  from moviepy.editor import VideoFileClip, CompositeVideoClip, ImageClip
  from PIL import Image, ImageDraw, ImageFont

  from nemo.collections import asr as nemo_asr
  from ctc_segmentation import ctc_segmentation, CtcSegmentationParameters, prepare_text


- # ---------------- CONFIG ---------------- #
- os.environ["NEMO_FORCE_CPU"] = "1"
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
-
- SR = 16000
- MAX_VIDEO_BYTES = 200_000_000

- ASR_MODELS = {
-     "Soloba CTC 0.6B V0": "RobotsMali/soloba-ctc-0.6b-v0",
-     "Soloba CTC 0.6B V1": "RobotsMali/soloba-ctc-0.6b-v1",
-     "Soloni 114M TDT CTC V0": "RobotsMali/soloni-114m-tdt-ctc-V0",
-     "Soloni 114M TDT CTC V1": "RobotsMali/soloni-114m-tdt-ctc-v1",
-     "QuartzNet BM V0": "RobotsMali/stt-bm-quartznet15x5-V0",
-     "QuartzNet BM V1": "RobotsMali/stt-bm-quartznet15x5-V1"
- }

- _CACHE = {}


- # ---------------- LOAD MODEL ---------------- #
- def load_model(name):
-     if name in _CACHE:
-         return _CACHE[name]
-     device = "cuda" if torch.cuda.is_available() else "cpu"
-     model = nemo_asr.models.ASRModel.from_pretrained(
-         model_name=ASR_MODELS[name]
-     ).to(device).eval()
-     _CACHE[name] = (model, device)
-     return model, device


- # ---------------- EXTRACT AUDIO ---------------- #
  def extract_audio(video_path, wav_path):
-     if os.path.getsize(video_path) > MAX_VIDEO_BYTES:
-         raise RuntimeError("⚠️ Vidéo > 200MB. Compressez avant l’upload.")
-     os.system(f"ffmpeg -y -i '{video_path}' -ac 1 -ar {SR} -vn '{wav_path}' >/dev/null 2>&1")
-     audio, sr = sf.read(wav_path)
-     return len(audio)/sr


- # ---------------- TRANSCRIBE (with forced alignment) ---------------- #
- def transcribe(model, device, wav_path, model_key):
-     audio, sr = sf.read(wav_path)
      if audio.ndim == 2:
          audio = np.mean(audio, axis=1)
-     total_s = len(audio) / sr

      x = torch.tensor(audio, dtype=torch.float32).unsqueeze(0).to(device)
      ln = torch.tensor([x.shape[1]]).to(device)

-     # --- Case 1 : Soloni → true word timestamps ---
-     if "Soloni" in model_key and hasattr(model, "decode_and_align"):
-         try:
-             with torch.no_grad():
-                 proc, plen = model.preprocessor(input_signal=x, input_signal_length=ln)
-                 hyps = model.decode_and_align(encoder_output=proc, encoded_lengths=plen)
-             hyp = hyps[0][0] if isinstance(hyps[0], list) else hyps[0]
-             return [(w.start_offset_ms/1000, w.end_offset_ms/1000, w.word) for w in hyp.words]
-         except:
-             pass
-
-     # --- Case 2 : Soloba / QuartzNet → forced alignment CTC ---
      with torch.no_grad():
-         logits, logits_len = model.forward(input_signal=x, input_signal_length=ln)

-     text = model.transcribe([wav_path])[0].text.strip()
      words = text.split()
-     if not words:
-         return []
-
      config = CtcSegmentationParameters()
      config.char_list = list(model.tokenizer.vocab.keys())

-     ground_truth_mat, _ = prepare_text(config, words)
-
-     timings, _, _ = ctc_segmentation(
-         config,
-         logits.cpu().numpy()[0],
-         ground_truth_mat
-     )
-
-     time_per_step = total_s / logits_len.cpu().numpy()[0]

      word_times = []
      for i, w in enumerate(words):
-         start = timings[i] * time_per_step
-         end = timings[i+1] * time_per_step if i+1 < len(timings) else total_s
-         word_times.append((start, end, w))

-     # --- Segment mode B (2 to 5 words per subtitle line) ---
-     grouped = []
-     segment = []
      for w in word_times:
-         segment.append(w)
-         if len(segment) >= 4:  # max words per line
-             grouped.append(segment)
-             segment = []
-     if segment:
-         grouped.append(segment)

-     subtitles = []
-     for seg in grouped:
-         s = seg[0][0]
-         e = seg[-1][1]
-         text = " ".join([w[2] for w in seg])
-         subtitles.append((s, e, text))

-     return subtitles


- # ---------------- BURN SUBTITLES (NO IMAGEMAGICK) ---------------- #
- def burn(video_path, subs):
-     clip = VideoFileClip(video_path)
      W, H = clip.size

      try:
@@ -140,21 +118,15 @@ def burn(video_path, subs):
      for s, e, text in subs:
          img = Image.new("RGBA", (W, int(H*0.12)), (0,0,0,140))
          draw = ImageDraw.Draw(img)

-         try:
-             bbox = draw.textbbox((0,0), text, font=font)
-             tw, th = bbox[2]-bbox[0], bbox[3]-bbox[1]
-         except:
-             tw, th = draw.textsize(text, font=font)

-         x = (W-tw)//2
-         y = (int(H*0.12)-th)//2
-         draw.text((x,y), text, font=font, fill=(255,255,255))
-
-         img_clip = ImageClip(np.array(img)).set_start(s).set_duration(e-s).set_position(("center", int(H*0.85)))
-         layers.append(img_clip)
-
-     final = CompositeVideoClip([clip] + layers)
      out = "RobotsMali_Subtitled.mp4"
      final.write_videofile(out, codec="libx264", audio_codec="aac", fps=clip.fps, verbose=False, logger=None)
      clip.close()
@@ -162,39 +134,37 @@ def burn(video_path, subs):
      return out


- # ---------------- PIPELINE ---------------- #
- def pipeline(video, model_name, progress=gr.Progress()):
-     progress(0.3, "📦 Chargement du modèle…")
-     model, device = load_model(model_name)

-     with tempfile.TemporaryDirectory() as td:
-         wav = f"{td}/audio.wav"
-         progress(0.5, "🔊 Extraction audio…")
-         extract_audio(video, wav)

-         progress(0.75, "🧠 Alignement temporel…")
-         subs = transcribe(model, device, wav, model_name)

-         progress(0.95, "🎞️ Incrustation des sous-titres…")
-         out = burn(video, subs)
-     return f"✅ Sous-titrage généré avec **{model_name}**", out


- # ---------------- UI ---------------- #
- CSS = """
- body { background:#F5F8FF; font-family:Inter, sans-serif; }
- h1 { text-align:center; font-weight:800; color:#005BFF; margin-bottom:6px; }
- .gr-button { background:#005BFF !important; color:white !important; border-radius:8px; font-weight:700; }
- """

- with gr.Blocks(css=CSS, title="RobotsMali Caption Studio") as demo:
-     gr.Markdown("<h1>RobotsMali Caption Studio</h1><p>Sous-titrage automatique en Bambara (Alignement Professionnel)</p>")
-     video = gr.File(label="🎥 Importer une vidéo")
-     model = gr.Dropdown(list(ASR_MODELS.keys()), value="Soloni 114M TDT CTC V1", label="🧠 Modèle ASR")
-     run = gr.Button("🚀 Générer les sous-titres")
      status = gr.Markdown()
-     output = gr.Video()

-     run.click(pipeline, inputs=[video, model], outputs=[status, output])

- demo.launch(server_name="0.0.0.0", server_port=7860, share=False, ssr_mode=False)
@@ -1,134 +1,112 @@
  import gradio as gr
+ import os
  import numpy as np
+ import torch
  import soundfile as sf
  from moviepy.editor import VideoFileClip, CompositeVideoClip, ImageClip
  from PIL import Image, ImageDraw, ImageFont

  from nemo.collections import asr as nemo_asr
+ from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis
  from ctc_segmentation import ctc_segmentation, CtcSegmentationParameters, prepare_text


+ # =============================
+ # ROBOTSMALI MODEL LIST
+ # =============================

+ MODELS = {
+     "Soloni 114M TDT CTC v1": "RobotsMali/soloni-114m-tdt-ctc-v1",
+     "Soloni 350M TDT CTC v1": "RobotsMali/soloni-350m-tdt-ctc-v1",

+     "Soloba CTC 0.6B v0": "RobotsMali/soloba-ctc-0.6b-v0",
+     "Soloba CTC 0.6B v1": "RobotsMali/soloba-ctc-0.6b-v1",

+     "QuartzNet Bambara v1": "RobotsMali/stt-bm-quartznet15x5-v1",
+     "QuartzNet Bambara v2": "RobotsMali/stt-bm-quartznet15x5-v2"
+ }


+ # =============================
+ # FUNCTION: EXTRACT AUDIO
+ # =============================

  def extract_audio(video_path, wav_path):
+     clip = VideoFileClip(video_path)
+     audio = clip.audio.to_soundarray(fps=16000)
+     if audio.ndim == 2:
+         audio = np.mean(audio, axis=1)
+     sf.write(wav_path, audio, 16000)
+     clip.close()
+

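The removed version shelled out to ffmpeg via `os.system` for this step; the moviepy route above avoids the external process but decodes the whole track into memory. For reference, a subprocess-based sketch of the old approach (`extract_audio_ffmpeg` is a hypothetical helper, not part of this commit):

    import subprocess

    def extract_audio_ffmpeg(video_path, wav_path, sr=16000):
        # Mono, 16 kHz, audio-only WAV: same output as the moviepy version above.
        subprocess.run(
            ["ffmpeg", "-y", "-i", video_path, "-ac", "1", "-ar", str(sr), "-vn", wav_path],
            check=True, capture_output=True,
        )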
+ # =============================
+ # FUNCTION: TRANSCRIPTION + TIMESTAMPS
+ # =============================

+ def transcribe(model, device, wav, model_name):
+     audio, sr = sf.read(wav)
      if audio.ndim == 2:
          audio = np.mean(audio, axis=1)

      x = torch.tensor(audio, dtype=torch.float32).unsqueeze(0).to(device)
      ln = torch.tensor([x.shape[1]]).to(device)

+     # === Case 1: Soloni → native timestamps ===
+     if "Soloni" in model_name and hasattr(model, "decode_and_align"):
+         with torch.no_grad():
+             proc, plen = model.preprocessor(input_signal=x, input_signal_length=ln)
+             hyps = model.decode_and_align(encoder_output=proc, encoded_lengths=plen)
+         hyp = hyps[0][0] if isinstance(hyps[0], list) else hyps[0]
+         return [(w.start_offset_ms/1000, w.end_offset_ms/1000, w.word) for w in hyp.words]
+
+     # === Case 2: Soloba & QuartzNet → CTC forced alignment ===
+     hyp = model.transcribe([wav])[0]
+     # transcribe() returns a str or a Hypothesis depending on the NeMo version
+     text = (hyp.text if isinstance(hyp, Hypothesis) else hyp).strip()
+     if not text:
+         return []
+
      with torch.no_grad():
+         logits, logit_len = model.forward(input_signal=x, input_signal_length=ln)

      words = text.split()
      config = CtcSegmentationParameters()
      config.char_list = list(model.tokenizer.vocab.keys())
+     gt, utt = prepare_text(config, words)

+     timings, _, _ = ctc_segmentation(config, logits.cpu().numpy()[0], gt)
+     total_s = len(audio) / sr
+     tps = total_s / logit_len.cpu().numpy()[0]

      word_times = []
      for i, w in enumerate(words):
+         s = timings[i] * tps
+         e = timings[i+1] * tps if i+1 < len(timings) else total_s
+         word_times.append((s, e, w))

+     # Readable grouping: up to 4 words per subtitle line
+     grouped, block = [], []
      for w in word_times:
+         block.append(w)
+         if len(block) >= 4:
+             grouped.append(block)
+             block = []
+     if block:
+         grouped.append(block)

+     subs = []
+     for g in grouped:
+         subs.append((g[0][0], g[-1][1], " ".join([w[2] for w in g])))

+     return subs
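
`transcribe` returns plain `(start_seconds, end_seconds, text)` triples, which `burn` renders directly onto the video; nothing in this commit writes a subtitle file. A minimal SRT export sketch for the same triples (`to_srt` is a hypothetical helper, not in app.py):

    def to_srt(subs, path="subtitles.srt"):
        # subs: list of (start_s, end_s, text) as returned by transcribe()
        def ts(t):
            h, rem = divmod(int(t * 1000), 3_600_000)
            m, rem = divmod(rem, 60_000)
            s, ms = divmod(rem, 1_000)
            return f"{h:02}:{m:02}:{s:02},{ms:03}"
        with open(path, "w", encoding="utf-8") as f:
            for i, (start, end, text) in enumerate(subs, 1):
                f.write(f"{i}\n{ts(start)} --> {ts(end)}\n{text}\n\n")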


+ # =============================
+ # FUNCTION: BURN IN SUBTITLES
+ # =============================
+
+ def burn(video, subs):
+     clip = VideoFileClip(video)
      W, H = clip.size

      try:
@@ -140,21 +118,15 @@ def burn(video_path, subs):
      for s, e, text in subs:
          img = Image.new("RGBA", (W, int(H*0.12)), (0,0,0,140))
          draw = ImageDraw.Draw(img)
+         bbox = draw.textbbox((0,0), text, font=font)
+         tw, th = bbox[2]-bbox[0], bbox[3]-bbox[1]
+         draw.text(((W-tw)//2, (int(H*0.12)-th)//2), text, font=font, fill="white")

+         layers.append(ImageClip(np.array(img))
+                       .set_start(s).set_duration(e-s)
+                       .set_position(("center", int(H*0.85))))

+     final = CompositeVideoClip([clip] + layers)
      out = "RobotsMali_Subtitled.mp4"
      final.write_videofile(out, codec="libx264", audio_codec="aac", fps=clip.fps, verbose=False, logger=None)
      clip.close()
@@ -162,39 +134,37 @@ def burn(video_path, subs):
      return out


+ # =============================
+ # PIPELINE
+ # =============================
+
+ def pipeline(video_file, model_name):
+     if video_file is None:
+         return "Veuillez importer une vidéo.", None

+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     model = nemo_asr.models.ASRModel.from_pretrained(MODELS[model_name]).to(device)

+     wav = "temp.wav"
+     extract_audio(video_file, wav)
+     subs = transcribe(model, device, wav, model_name)
+     out = burn(video_file, subs)
+     return "✅ Sous-titres générés avec succès.", out


+ # =============================
+ # GRADIO INTERFACE
+ # =============================

+ with gr.Blocks() as demo:
+     gr.Markdown("# 🎙️ RobotsMali Subtitle Generator")

+     video = gr.Video(label="Importer une vidéo")
+     model = gr.Dropdown(list(MODELS.keys()), value="Soloni 114M TDT CTC v1", label="Sélection du modèle")
+     btn = gr.Button("Générer les sous-titres")
      status = gr.Markdown()
+     out = gr.Video(label="Résultat")

+     btn.click(pipeline, inputs=[video, model], outputs=[status, out])

+ demo.launch()
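
A note on the forced-alignment math in `transcribe`: multiplying `timings[i]` by `tps` assumes uniformly spaced CTC frames. The `ctc_segmentation` package can do that conversion itself: `CtcSegmentationParameters` exposes an `index_duration` field (seconds per frame), and `determine_utterance_segments` turns the raw alignment into per-word segments with confidence scores. A sketch of that variant, reusing the `prepare_text` and `ctc_segmentation` imports already at the top of app.py and assuming the same `config`, `words`, and log-probabilities as in `transcribe` (untested against these checkpoints):

    from ctc_segmentation import determine_utterance_segments

    def align_words(config, words, log_probs, total_s):
        # Seconds per CTC frame, so the package returns timings in seconds directly.
        config.index_duration = total_s / log_probs.shape[0]
        ground_truth_mat, utt_begin_indices = prepare_text(config, words)
        timings, char_probs, _ = ctc_segmentation(config, log_probs, ground_truth_mat)
        # One (start_s, end_s, score) tuple per entry in `words`.
        segments = determine_utterance_segments(config, utt_begin_indices, char_probs, timings, words)
        return [(start, end, w) for (start, end, _), w in zip(segments, words)]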