binaryMao committed
Commit 60b8ac2 · verified · 1 Parent(s): 0fb103f

Update app.py

Files changed (1)
  1. app.py +163 -48
app.py CHANGED
@@ -2,32 +2,67 @@ import gradio as gr
import numpy as np
import torch
import soundfile as sf
+ import os
+ import tempfile
from moviepy.editor import VideoFileClip, CompositeVideoClip, ImageClip
from PIL import Image, ImageDraw, ImageFont
from nemo.collections import asr as nemo_asr
- from huggingface_hub import hf_hub_download
+ from huggingface_hub import hf_hub_download, snapshot_download
from ctc_segmentation import ctc_segmentation, CtcSegmentationParameters, prepare_text

-
MODELS = {
    "Soloni V0": ("RobotsMali/soloni-114m-tdt-ctc-V0", "soloni-114m-tdt-ctc-V0.nemo", "rnnt"),
    "Soloni V1": ("RobotsMali/soloni-114m-tdt-ctc-V1", "soloni-114m-tdt-ctc-V1.nemo", "rnnt"),
-
    "Soloba V0": ("RobotsMali/soloba-ctc-0.6b-V0", None, "ctc"),
    "Soloba V1": ("RobotsMali/soloba-ctc-0.6b-V1", None, "ctc"),
-
    "QuartzNet V0": ("RobotsMali/stt-bm-quartznet15x5-V0", None, "ctc"),
    "QuartzNet V1": ("RobotsMali/stt-bm-quartznet15x5-V1", None, "ctc"),
}

+ def load_ctc_model_safe(repo_id):
+     """Load CTC models robustly."""
+     try:
+         # Attempt 1: standard loading
+         return nemo_asr.models.EncDecCTCModelBPE.from_pretrained(model_name=repo_id)
+     except Exception as e:
+         print(f"Erreur lors du chargement standard: {e}")
+
+         # Attempt 2: manual download via snapshot
+         try:
+             print("Tentative de téléchargement manuel...")
+             model_path = snapshot_download(
+                 repo_id=repo_id,
+                 cache_dir=tempfile.mkdtemp(),
+                 local_dir_use_symlinks=False
+             )
+
+             # Look for the .nemo file
+             nemo_file = None
+             for file in os.listdir(model_path):
+                 if file.endswith('.nemo'):
+                     nemo_file = os.path.join(model_path, file)
+                     break
+
+             if nemo_file and os.path.exists(nemo_file):
+                 print(f"Chargement depuis: {nemo_file}")
+                 return nemo_asr.models.EncDecCTCModelBPE.restore_from(nemo_file)
+             else:
+                 raise FileNotFoundError("Fichier .nemo non trouvé dans le repo")
+
+         except Exception as e2:
+             print(f"Échec du téléchargement manuel: {e2}")
+             raise

def extract_audio(video_path, wav_path):
-     (VideoFileClip(video_path).audio.write_audiofile(
+     """Extract the audio from the video."""
+     video = VideoFileClip(video_path)
+     video.audio.write_audiofile(
        wav_path, fps=16000, codec="pcm_s16le", verbose=False, logger=None
-     ))
-
+     )
+     video.close()

def transcribe(model, device, wav, model_name):
+     """Transcribe the audio with time alignment."""
    audio, sr = sf.read(wav)
    if audio.ndim == 2:
        audio = np.mean(audio, axis=1)
@@ -35,6 +70,7 @@ def transcribe(model, device, wav, model_name):
    ln = torch.tensor([x.shape[1]]).to(device)
    total_s = len(audio) / sr

+     # RNNT models (Soloni)
    if "Soloni" in model_name:
        with torch.no_grad():
            proc, plen = model.preprocessor(input_signal=x, input_signal_length=ln)
@@ -42,6 +78,7 @@ def transcribe(model, device, wav, model_name):
            hyp = hyps[0][0] if isinstance(hyps[0], list) else hyps[0]
            return [(w.start_offset_ms/1000, w.end_offset_ms/1000, w.word) for w in hyp.words]

+     # CTC models (Soloba, QuartzNet)
    text = model.transcribe([wav])[0].strip()
    if not text:
        return []
@@ -50,6 +87,9 @@ def transcribe(model, device, wav, model_name):
    logits, logit_len = model.forward(input_signal=x, input_signal_length=ln)

    words = text.split()
+     if not words:
+         return []
+
    config = CtcSegmentationParameters()
    config.char_list = list(model.tokenizer.vocab.keys())
    gt, _ = prepare_text(config, words)
@@ -61,69 +101,144 @@
            timings[i+1] * tps if i+1 < len(timings) else total_s,
            words[i]) for i in range(len(words))]

+     # Group the words
    grouped, temp = [], []
    for w in aligned:
        temp.append(w)
-         if len(temp) >= 4:
-             grouped.append(temp); temp = []
-     if temp: grouped.append(temp)
+         if len(temp) >= 4:  # groups of 4 words
+             grouped.append(temp)
+             temp = []
+     if temp:
+         grouped.append(temp)

    return [(g[0][0], g[-1][1], " ".join([w[2] for w in g])) for g in grouped]

-
def burn(video, subs):
+     """Add the subtitles to the video."""
    clip = VideoFileClip(video)
    W, H = clip.size
+
+     # Try to load a font
    try:
-         font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", int(H/20))
+         font_size = max(int(H/20), 20)  # minimum size
+         font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", font_size)
    except:
-         font = ImageFont.load_default()
+         try:
+             font = ImageFont.load_default()
+         except:
+             font = None

    layers = []
-     for s,e,text in subs:
-         img = Image.new("RGBA",(W,int(H*0.12)),(0,0,0,140))
+     for start, end, text in subs:
+         # Create the subtitle image
+         img_height = int(H * 0.12)
+         img = Image.new("RGBA", (W, img_height), (0, 0, 0, 140))
        draw = ImageDraw.Draw(img)
-         bbox = draw.textbbox((0,0), text, font=font)
-         tw, th = bbox[2]-bbox[0], bbox[3]-bbox[1]
-         draw.text(((W-tw)//2,(int(H*0.12)-th)//2), text, font=font, fill="white")
-         layers.append(ImageClip(np.array(img)).set_start(s).set_duration(e-s).set_position(("center",int(H*0.85))))
-
+
+         if font:
+             bbox = draw.textbbox((0, 0), text, font=font)
+             tw, th = bbox[2] - bbox[0], bbox[3] - bbox[1]
+             draw.text(((W - tw) // 2, (img_height - th) // 2), text, font=font, fill="white")
+         else:
+             # Fallback if no font is available
+             draw.text((W//2, img_height//2), text, fill="white", anchor="mm")
+
+         # Create the subtitle clip
+         subtitle_clip = ImageClip(np.array(img)).set_start(start).set_duration(end - start)
+         subtitle_clip = subtitle_clip.set_position(("center", int(H * 0.85)))
+         layers.append(subtitle_clip)
+
+     # Final composition
    final = CompositeVideoClip([clip] + layers)
-     out = "RobotsMali_Subtitled.mp4"
-     final.write_videofile(out, codec="libx264", audio_codec="aac", fps=clip.fps, verbose=False, logger=None)
-     clip.close(); final.close()
-     return out
-
+     out_path = "RobotsMali_Subtitled.mp4"
+
+     # Write the final video
+     final.write_videofile(
+         out_path,
+         codec="libx264",
+         audio_codec="aac",
+         fps=clip.fps,
+         verbose=False,
+         logger=None,
+         temp_audiofile="temp-audio.m4a",
+         remove_temp=True
+     )
+
+     # Cleanup
+     clip.close()
+     final.close()
+     for layer in layers:
+         layer.close()
+
+     return out_path

def pipeline(video_file, model_name):
+     """Main processing pipeline."""
    if video_file is None:
        return "Veuillez importer une vidéo.", None

    repo, nemo_file, mode = MODELS[model_name]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

-     if mode == "rnnt":
-         nemo_path = hf_hub_download(repo, filename=nemo_file)
-         model = nemo_asr.models.EncDecHybridRNNTCTCBPEModel.restore_from(nemo_path)
-     else:
-         model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained(model_name=repo)
-
-     model = model.to(device); model.eval()
-
-     wav = "audio.wav"
-     extract_audio(video_file, wav)
-     subs = transcribe(model, device, wav, model_name)
-     final = burn(video_file, subs)
-     return "✅ Sous-titres générés.", final
-
-
- with gr.Blocks() as demo:
-     gr.Markdown("# 🎙️ **RobotsMali — Sous-titrage automatique Bambara**")
-     video = gr.Video(label="Vidéo")
-     model = gr.Dropdown(list(MODELS.keys()), value="Soloni V1", label="Modèle")
-     btn = gr.Button("⚡ Générer les sous-titres")
-     status = gr.Markdown()
-     out = gr.Video(label="Résultat")
+     try:
+         # Load the model
+         if mode == "rnnt":
+             nemo_path = hf_hub_download(repo, filename=nemo_file)
+             model = nemo_asr.models.EncDecHybridRNNTCTCBPEModel.restore_from(nemo_path)
+         else:
+             model = load_ctc_model_safe(repo)  # use the safe loading helper
+
+         model = model.to(device)
+         model.eval()
+
+         # Processing
+         wav_path = "audio.wav"
+         extract_audio(video_file, wav_path)
+         subs = transcribe(model, device, wav_path, model_name)
+         final_video = burn(video_file, subs)
+
+         # Clean up temporary files
+         if os.path.exists(wav_path):
+             os.remove(wav_path)
+
+         return "✅ Sous-titres générés avec succès!", final_video
+
+     except Exception as e:
+         print(f"Erreur dans le pipeline: {e}")
+         return f"❌ Erreur: {str(e)}", None
+
+ # Gradio interface
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
+     gr.Markdown("""
+     # 🎙️ **RobotsMali — Sous-titrage automatique Bambara**
+     *Générez automatiquement des sous-titres en Bambara pour vos vidéos*
+     """)
+
+     with gr.Row():
+         with gr.Column():
+             video = gr.Video(label="Vidéo d'entrée", height=300)
+             model = gr.Dropdown(
+                 list(MODELS.keys()),
+                 value="Soloni V1",
+                 label="Modèle de reconnaissance vocale",
+                 info="Soloni: plus précis • Soloba/QuartzNet: plus rapide"
+             )
+             btn = gr.Button("⚡ Générer les sous-titres", variant="primary")
+
+         with gr.Column():
+             status = gr.Markdown("Prêt à traiter...")
+             out = gr.Video(label="Vidéo sous-titrée", height=300)
+
+     # Examples
+     gr.Examples(
+         examples=[],
+         inputs=[video, model],
+         outputs=[status, out],
+         fn=pipeline,
+         cache_examples=False,
+     )
+
    btn.click(pipeline, inputs=[video, model], outputs=[status, out])

- demo.launch()
+ if __name__ == "__main__":
+     demo.launch(share=True, server_port=7860)
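
For a quick local check of this change, a minimal smoke test could look like the sketch below. It is not part of the commit and rests on a few assumptions: the updated file is saved as app.py with its dependencies installed (nemo_toolkit, moviepy, ctc_segmentation, gradio, huggingface_hub), the Hugging Face Hub is reachable, and a short local clip named sample.mp4 exists (the clip name is hypothetical).

# Hypothetical smoke test for the updated app.py (not part of this commit).
# Importing app builds the Gradio Blocks but no longer starts the server,
# since demo.launch() is now behind the `if __name__ == "__main__":` guard.
from app import load_ctc_model_safe, pipeline

# Exercise the new fallback loader on one of the CTC repos listed in MODELS.
model = load_ctc_model_safe("RobotsMali/stt-bm-quartznet15x5-V1")
print(type(model).__name__)

# Run the full pipeline; it returns (status_message, subtitled_video_path).
status, out_path = pipeline("sample.mp4", "QuartzNet V1")
print(status, out_path)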