Update app.py
Browse files

app.py CHANGED
@@ -13,6 +13,7 @@ from moviepy.editor import VideoFileClip, CompositeVideoClip, ImageClip
 from PIL import Image, ImageDraw, ImageFont
 
 from nemo.collections import asr as nemo_asr
+from ctc_segmentation import ctc_segmentation, CtcSegmentationParameters, prepare_text
 
 
 # ---------------- CONFIG ---------------- #
@@ -55,45 +56,74 @@ def extract_audio(video_path, wav_path):
     return len(audio)/sr
 
 
-# ---------------- TRANSCRIBE ---------------- #
+# ---------------- TRANSCRIBE (with forced alignment) ---------------- #
 def transcribe(model, device, wav_path, model_key):
     audio, sr = sf.read(wav_path)
     if audio.ndim == 2:
-        audio = np.mean(audio, axis=1)
-
-    audio = audio / np.max(np.abs(audio))
+        audio = np.mean(audio, axis=1)
+    total_s = len(audio) / sr
 
-    total_s = len(audio)/sr
     x = torch.tensor(audio, dtype=torch.float32).unsqueeze(0).to(device)
     ln = torch.tensor([x.shape[1]]).to(device)
 
-    #
+    # --- Case 1 : Soloni → true word timestamps ---
     if "Soloni" in model_key and hasattr(model, "decode_and_align"):
         try:
             with torch.no_grad():
                 proc, plen = model.preprocessor(input_signal=x, input_signal_length=ln)
                 hyps = model.decode_and_align(encoder_output=proc, encoded_lengths=plen)
                 hyp = hyps[0][0] if isinstance(hyps[0], list) else hyps[0]
-
-            return [(w.start_offset_ms/1000, w.end_offset_ms/1000, w.word) for w in hyp.words]
+            return [(w.start_offset_ms/1000, w.end_offset_ms/1000, w.word) for w in hyp.words]
         except:
             pass
 
-    #
-
-
+    # --- Case 2 : Soloba / QuartzNet → forced alignment CTC ---
+    with torch.no_grad():
+        logits, logits_len = model.forward(input_signal=x, input_signal_length=ln)
+
+    text = model.transcribe([wav_path])[0].text.strip()
     words = text.split()
     if not words:
         return []
 
-
-
-
-
-
-
-
-
+    config = CtcSegmentationParameters()
+    config.char_list = list(model.tokenizer.vocab.keys())
+
+    ground_truth_mat, _ = prepare_text(config, words)
+
+    timings, _, _ = ctc_segmentation(
+        config,
+        logits.cpu().numpy()[0],
+        ground_truth_mat
+    )
+
+    time_per_step = total_s / logits_len.cpu().numpy()[0]
+
+    word_times = []
+    for i, w in enumerate(words):
+        start = timings[i] * time_per_step
+        end = timings[i+1] * time_per_step if i+1 < len(timings) else total_s
+        word_times.append((start, end, w))
+
+    # --- Segment mode B (max 4 words per subtitle line) ---
+    grouped = []
+    segment = []
+    for w in word_times:
+        segment.append(w)
+        if len(segment) >= 4:  # max words per line
+            grouped.append(segment)
+            segment = []
+    if segment:
+        grouped.append(segment)
+
+    subtitles = []
+    for seg in grouped:
+        s = seg[0][0]
+        e = seg[-1][1]
+        text = " ".join([w[2] for w in seg])
+        subtitles.append((s, e, text))
+
+    return subtitles
 
 
 # ---------------- BURN SUBTITLES (NO IMAGEMAGICK) ---------------- #
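Note on the Case 2 path: it converts CTC frame indices to seconds by hand via `time_per_step = total_s / logits_len`, then indexes `timings` per word. For comparison, the `ctc-segmentation` package also ships `determine_utterance_segments`, which performs the index-to-time conversion and per-word confidences itself. A minimal sketch of that upstream flow, assuming a recent release where `CtcSegmentationParameters` exposes `index_duration`; here `log_probs` is a `(T, V)` NumPy array of CTC log-probabilities, `vocab` the model's character list, and the helper name `align_words` is illustrative:

```python
import numpy as np
from ctc_segmentation import (
    CtcSegmentationParameters,
    ctc_segmentation,
    determine_utterance_segments,
    prepare_text,
)

def align_words(log_probs: np.ndarray, vocab: list, words: list, total_s: float):
    config = CtcSegmentationParameters()
    config.char_list = vocab
    # Seconds of audio covered by one CTC output frame.
    config.index_duration = total_s / log_probs.shape[0]
    ground_truth_mat, utt_begin_indices = prepare_text(config, words)
    timings, char_probs, _ = ctc_segmentation(config, log_probs, ground_truth_mat)
    # One (start_s, end_s, confidence) triple per entry in `words`.
    segments = determine_utterance_segments(
        config, utt_begin_indices, char_probs, timings, words
    )
    return [(start, end, word) for (start, end, _), word in zip(segments, words)]
```

Indexing `timings` by word position, as the diff does, assumes one timing row per word; `utt_begin_indices` encodes the actual word boundaries, so the upstream helper is the safer route.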
@@ -107,29 +137,24 @@ def burn(video_path, subs):
     font = ImageFont.load_default()
 
     layers = []
-    for s, e, w in subs:
-        if e <= s: continue
-
+    for s, e, text in subs:
         img = Image.new("RGBA", (W, int(H*0.12)), (0,0,0,140))
         draw = ImageDraw.Draw(img)
 
-        text = w.upper()
-
-        # ✅ Pillow 10+ compatible text size
         try:
             bbox = draw.textbbox((0,0), text, font=font)
             tw, th = bbox[2]-bbox[0], bbox[3]-bbox[1]
         except:
             tw, th = draw.textsize(text, font=font)
 
-        x = (W-tw)//2
-        y = (int(H*0.12)-th)//2
-        draw.text((x,y), text, font=font, fill=(255,255,255))
+        x = (W-tw)//2
+        y = (int(H*0.12)-th)//2
+        draw.text((x,y), text, font=font, fill=(255,255,255))
 
         img_clip = ImageClip(np.array(img)).set_start(s).set_duration(e-s).set_position(("center", int(H*0.85)))
         layers.append(img_clip)
 
-    final = CompositeVideoClip([clip] + layers)
+    final = CompositeVideoClip([clip] + layers)
     out = "RobotsMali_Subtitled.mp4"
     final.write_videofile(out, codec="libx264", audio_codec="aac", fps=clip.fps, verbose=False, logger=None)
     clip.close()
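The try/except around `draw.textbbox` exists because `ImageDraw.textsize` was removed in Pillow 10, while `textbbox` has been available since Pillow 8.0. A feature test reads a little clearer than catching the exception; a minimal self-contained sketch, with the hypothetical helper name `measure_text`:

```python
from PIL import Image, ImageDraw, ImageFont

def measure_text(draw, text, font):
    # Pillow >= 8.0: textbbox returns (left, top, right, bottom).
    if hasattr(draw, "textbbox"):
        left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
        return right - left, bottom - top
    # Legacy Pillow only: textsize was removed in Pillow 10.
    return draw.textsize(text, font=font)

img = Image.new("RGBA", (1280, 86), (0, 0, 0, 140))
draw = ImageDraw.Draw(img)
tw, th = measure_text(draw, "I NI CE", ImageFont.load_default())
```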
@@ -139,7 +164,7 @@ def burn(video_path, subs):
 
 # ---------------- PIPELINE ---------------- #
 def pipeline(video, model_name, progress=gr.Progress()):
-    progress(0.
+    progress(0.3, "📦 Chargement du modèle…")
     model, device = load_model(model_name)
 
     with tempfile.TemporaryDirectory() as td:
@@ -147,10 +172,10 @@ def pipeline(video, model_name, progress=gr.Progress()):
         progress(0.5, "🔊 Extraction audio…")
         extract_audio(video, wav)
 
-        progress(0.75, "🧠
+        progress(0.75, "🧠 Alignement temporel…")
         subs = transcribe(model, device, wav, model_name)
 
-        progress(0.95, "🎞️ Incrustation…")
+        progress(0.95, "🎞️ Incrustation des sous-titres…")
         out = burn(video, subs)
         return f"✅ Sous-titrage généré avec **{model_name}**", out
 
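`gr.Progress` objects are callable with a completion fraction and a description string, which is how each stage reports above. A stripped-down, runnable sketch of the same pattern; `fake_pipeline` and the sleeps are stand-ins for the real `load_model`/`transcribe`/`burn` calls:

```python
import time
import gradio as gr

def fake_pipeline(name, progress=gr.Progress()):
    progress(0.3, "Loading model…")    # fraction first, description second
    time.sleep(0.5)                    # stand-in for load_model()
    progress(0.75, "Aligning…")
    time.sleep(0.5)                    # stand-in for transcribe() + burn()
    return f"Done: {name}"

demo = gr.Interface(fn=fake_pipeline, inputs=gr.Textbox(), outputs=gr.Textbox())

if __name__ == "__main__":
    demo.launch()
```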
@@ -163,7 +188,7 @@ h1 { text-align:center; font-weight:800; color:#005BFF; margin-bottom:6px; }
 """
 
 with gr.Blocks(css=CSS, title="RobotsMali Caption Studio") as demo:
-    gr.Markdown("<h1>RobotsMali Caption Studio</h1><p>Sous-titrage automatique en Bambara</p>")
+    gr.Markdown("<h1>RobotsMali Caption Studio</h1><p>Sous-titrage automatique en Bambara (Alignement Professionnel)</p>")
     video = gr.File(label="🎥 Importer une vidéo")
     model = gr.Dropdown(list(ASR_MODELS.keys()), value="Soloni 114M TDT CTC V1", label="🧠 Modèle ASR")
     run = gr.Button("🚀 Générer les sous-titres")
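The click handler sits outside the lines shown in this diff. Presumably the components are wired along these lines; `status` and `player` are hypothetical output components, and `CSS`, `ASR_MODELS`, and `pipeline` come from the surrounding app.py:

```python
with gr.Blocks(css=CSS, title="RobotsMali Caption Studio") as demo:
    gr.Markdown("<h1>RobotsMali Caption Studio</h1><p>Sous-titrage automatique en Bambara (Alignement Professionnel)</p>")
    video = gr.File(label="🎥 Importer une vidéo")
    model = gr.Dropdown(list(ASR_MODELS.keys()), value="Soloni 114M TDT CTC V1", label="🧠 Modèle ASR")
    run = gr.Button("🚀 Générer les sous-titres")
    status = gr.Markdown()                  # hypothetical: not shown in the diff
    player = gr.Video(label="🎬 Résultat")  # hypothetical: not shown in the diff
    run.click(fn=pipeline, inputs=[video, model], outputs=[status, player])
```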