binaryMao committed on
Commit 9683e37 · verified · 1 Parent(s): 4d32742

Update app.py: replace the uniform word-timing fallback with CTC forced alignment (ctc_segmentation) and group aligned words into subtitle lines

Files changed (1)
  1. app.py +58 -33
app.py CHANGED
@@ -13,6 +13,7 @@ from moviepy.editor import VideoFileClip, CompositeVideoClip, ImageClip
 from PIL import Image, ImageDraw, ImageFont
 
 from nemo.collections import asr as nemo_asr
+from ctc_segmentation import ctc_segmentation, CtcSegmentationParameters, prepare_text
 
 
 # ---------------- CONFIG ---------------- #
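The three helpers come from the standalone ctc-segmentation package, which NeMo does not bundle; assuming the Space's requirements.txt does not already pin it, it installs from PyPI:

```bash
pip install ctc-segmentation
```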
@@ -55,45 +56,74 @@ def extract_audio(video_path, wav_path):
     return len(audio)/sr
 
 
-# ---------------- TRANSCRIBE ---------------- #
+# ---------------- TRANSCRIBE (with forced alignment) ---------------- #
 def transcribe(model, device, wav_path, model_key):
     audio, sr = sf.read(wav_path)
     if audio.ndim == 2:
-        audio = np.mean(audio, axis=1).astype(np.float32)
-    if np.max(np.abs(audio)) > 1:
-        audio = audio / np.max(np.abs(audio))
+        audio = np.mean(audio, axis=1)
+    total_s = len(audio) / sr
 
-    total_s = len(audio)/sr
     x = torch.tensor(audio, dtype=torch.float32).unsqueeze(0).to(device)
     ln = torch.tensor([x.shape[1]]).to(device)
 
-    # Soloni real timestamps
+    # --- Case 1 : Soloni true word timestamps ---
     if "Soloni" in model_key and hasattr(model, "decode_and_align"):
         try:
             with torch.no_grad():
                 proc, plen = model.preprocessor(input_signal=x, input_signal_length=ln)
                 hyps = model.decode_and_align(encoder_output=proc, encoded_lengths=plen)
                 hyp = hyps[0][0] if isinstance(hyps[0], list) else hyps[0]
-                if hasattr(hyp, "words") and hyp.words:
-                    return [(w.start_offset_ms/1000, w.end_offset_ms/1000, w.word) for w in hyp.words]
+                return [(w.start_offset_ms/1000, w.end_offset_ms/1000, w.word) for w in hyp.words]
         except:
             pass
 
-    # Universal fallback for Soloba + QuartzNet
-    out = model.transcribe([wav_path])[0]
-    text = out.text.strip() if hasattr(out, "text") else str(out).strip()
+    # --- Case 2 : Soloba / QuartzNet → CTC forced alignment ---
+    with torch.no_grad():
+        # EncDecCTCModel.forward returns (log_probs, encoded_len, greedy_preds)
+        logits, logits_len, *_ = model.forward(input_signal=x, input_signal_length=ln)
+
+    text = model.transcribe([wav_path])[0].text.strip()
     words = text.split()
     if not words:
         return []
 
-    wps = max(2.0, len(words) / total_s)
-    subs, t = [], 0
-    for w in words:
-        d = 1 / wps
-        subs.append((t, min(total_s, t+d), w))
-        t += d
-        if t >= total_s: break
-    return subs
+    config = CtcSegmentationParameters()
+    config.char_list = (
+        list(model.tokenizer.vocab.keys()) if hasattr(model, "tokenizer")
+        else list(model.decoder.vocabulary)  # char-based models like QuartzNet
+    )
+    # Seconds of audio represented by one CTC output frame.
+    config.index_duration = total_s / logits_len.cpu().numpy()[0]
+
+    ground_truth_mat, utt_begin_indices = prepare_text(config, words)
+
+    timings, _, _ = ctc_segmentation(
+        config,
+        logits.cpu().numpy()[0],
+        ground_truth_mat
+    )
+
+    # With index_duration set, timings are already in seconds; word i begins
+    # at the ground-truth row marked by utt_begin_indices[i].
+    word_times = []
+    for i, w in enumerate(words):
+        start = timings[utt_begin_indices[i]]
+        end = (timings[utt_begin_indices[i + 1]]
+               if i + 1 < len(utt_begin_indices) else total_s)
+        word_times.append((start, end, w))
+
+    # --- Group words into subtitle lines (max 4 words per line) ---
+    grouped = []
+    segment = []
+    for w in word_times:
+        segment.append(w)
+        if len(segment) >= 4:  # max words per line
+            grouped.append(segment)
+            segment = []
+    if segment:
+        grouped.append(segment)
+
+    subtitles = []
+    for seg in grouped:
+        s = seg[0][0]
+        e = seg[-1][1]
+        text = " ".join([w[2] for w in seg])
+        subtitles.append((s, e, text))
+
+    return subtitles
 
 
 # ---------------- BURN SUBTITLES (NO IMAGEMAGICK) ---------------- #
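The package also ships determine_utterance_segments, which does the same boundary bookkeeping (half-way points between neighbouring characters, blank handling) instead of indexing timings by hand. A minimal self-contained sketch of that documented flow, assuming (T, V) log-posteriors for a single clip; the helper name align_words is illustrative, not part of the commit:

```python
import numpy as np
from ctc_segmentation import (
    CtcSegmentationParameters,
    ctc_segmentation,
    determine_utterance_segments,
    prepare_text,
)

def align_words(log_probs: np.ndarray, words: list, char_list: list, total_s: float):
    """Return one (start_s, end_s, word) triple per word of one utterance."""
    config = CtcSegmentationParameters()
    config.char_list = char_list
    # Seconds of audio represented by one CTC output frame.
    config.index_duration = total_s / log_probs.shape[0]
    ground_truth_mat, utt_begin_indices = prepare_text(config, words)
    timings, char_probs, _ = ctc_segmentation(config, log_probs, ground_truth_mat)
    # One (start, end, confidence) triple per word, in seconds.
    segments = determine_utterance_segments(
        config, utt_begin_indices, char_probs, timings, words
    )
    return [(start, end, w) for (start, end, _), w in zip(segments, words)]
```

Called with logits.cpu().numpy()[0] and the char_list built in the hunk above, this returns the same kind of word_times list the commit assembles manually, plus a per-word confidence that could be used to drop low-quality lines.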
@@ -107,29 +137,24 @@ def burn(video_path, subs):
     font = ImageFont.load_default()
 
     layers = []
-    for s, e, w in subs:
-        if e <= s: continue
-
+    for s, e, text in subs:
         img = Image.new("RGBA", (W, int(H*0.12)), (0,0,0,140))
         draw = ImageDraw.Draw(img)
 
-        text = w.upper()
-
-        # ✅ Pillow 10+ compatible text size
         try:
             bbox = draw.textbbox((0,0), text, font=font)
             tw, th = bbox[2]-bbox[0], bbox[3]-bbox[1]
         except:
             tw, th = draw.textsize(text, font=font)
 
-        x = (W - tw) // 2
-        y = (int(H*0.12) - th) // 2
-        draw.text((x, y), text, font=font, fill=(255,255,255))
+        x = (W-tw)//2
+        y = (int(H*0.12)-th)//2
+        draw.text((x,y), text, font=font, fill=(255,255,255))
 
         img_clip = ImageClip(np.array(img)).set_start(s).set_duration(e-s).set_position(("center", int(H*0.85)))
         layers.append(img_clip)
 
     final = CompositeVideoClip([clip] + layers)
     out = "RobotsMali_Subtitled.mp4"
     final.write_videofile(out, codec="libx264", audio_codec="aac", fps=clip.fps, verbose=False, logger=None)
     clip.close()
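The try/except above is a Pillow compatibility shim: ImageDraw.textsize was deprecated in Pillow 9.2 and removed in 10.0, where textbbox is the replacement. The same logic as an explicit capability check (the name measure_text is illustrative, not part of the commit):

```python
from PIL import ImageDraw

def measure_text(draw, text, font):
    # textbbox exists since Pillow 8.0; textsize was removed in Pillow 10.
    if hasattr(draw, "textbbox"):
        left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
        return right - left, bottom - top
    return draw.textsize(text, font=font)  # pre-Pillow-10 fallback
```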
@@ -139,7 +164,7 @@ def burn(video_path, subs):
 
 # ---------------- PIPELINE ---------------- #
 def pipeline(video, model_name, progress=gr.Progress()):
-    progress(0.2, "📦 Chargement du modèle…")
+    progress(0.3, "📦 Chargement du modèle…")
     model, device = load_model(model_name)
 
     with tempfile.TemporaryDirectory() as td:
@@ -147,10 +172,10 @@ def pipeline(video, model_name, progress=gr.Progress()):
         progress(0.5, "🔊 Extraction audio…")
         extract_audio(video, wav)
 
-        progress(0.75, "🧠 Transcription…")
+        progress(0.75, "🧠 Alignement temporel…")
         subs = transcribe(model, device, wav, model_name)
 
-        progress(0.95, "🎞️ Incrustation…")
+        progress(0.95, "🎞️ Incrustation des sous-titres…")
         out = burn(video, subs)
         return f"✅ Sous-titrage généré avec **{model_name}**", out
 
@@ -163,7 +188,7 @@ h1 { text-align:center; font-weight:800; color:#005BFF; margin-bottom:6px; }
 """
 
 with gr.Blocks(css=CSS, title="RobotsMali Caption Studio") as demo:
-    gr.Markdown("<h1>RobotsMali Caption Studio</h1><p>Sous-titrage automatique en Bambara</p>")
+    gr.Markdown("<h1>RobotsMali Caption Studio</h1><p>Sous-titrage automatique en Bambara (Alignement Professionnel)</p>")
     video = gr.File(label="🎥 Importer une vidéo")
     model = gr.Dropdown(list(ASR_MODELS.keys()), value="Soloni 114M TDT CTC V1", label="🧠 Modèle ASR")
     run = gr.Button("🚀 Générer les sous-titres")
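The hunk's context stops at the button, so the event wiring is not part of the diff; a hypothetical wiring consistent with pipeline's two return values (the status and out_video names are assumptions, not code from this commit):

```python
# Hypothetical wiring below the visible context; component names are illustrative.
status = gr.Markdown()
out_video = gr.Video(label="🎬 Vidéo sous-titrée")
run.click(pipeline, inputs=[video, model], outputs=[status, out_video])
```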
 