binaryMao committed on
Commit 0fb103f · verified · 1 Parent(s): 0456de7

Update app.py

Files changed (1)
  1. app.py +25 -69
app.py CHANGED
@@ -5,66 +5,43 @@ import soundfile as sf
 from moviepy.editor import VideoFileClip, CompositeVideoClip, ImageClip
 from PIL import Image, ImageDraw, ImageFont
 from nemo.collections import asr as nemo_asr
+from huggingface_hub import hf_hub_download
 from ctc_segmentation import ctc_segmentation, CtcSegmentationParameters, prepare_text


-# =============================
-# OFFICIAL LIST OF ROBOTSMALI MODELS
-# =============================
-
 MODELS = {
-    "Soloni V0": "RobotsMali/soloni-114m-tdt-ctc-V0",
-    "Soloni V1": "RobotsMali/soloni-114m-tdt-ctc-V1",
+    "Soloni V0": ("RobotsMali/soloni-114m-tdt-ctc-V0", "soloni-114m-tdt-ctc-V0.nemo", "rnnt"),
+    "Soloni V1": ("RobotsMali/soloni-114m-tdt-ctc-V1", "soloni-114m-tdt-ctc-V1.nemo", "rnnt"),

-    "Soloba V0": "RobotsMali/soloba-ctc-0.6b-V0",
-    "Soloba V1": "RobotsMali/soloba-ctc-0.6b-V1",
+    "Soloba V0": ("RobotsMali/soloba-ctc-0.6b-V0", None, "ctc"),
+    "Soloba V1": ("RobotsMali/soloba-ctc-0.6b-V1", None, "ctc"),

-    "QuartzNet V0": "RobotsMali/stt-bm-quartznet15x5-V0",
-    "QuartzNet V1": "RobotsMali/stt-bm-quartznet15x5-V1"
+    "QuartzNet V0": ("RobotsMali/stt-bm-quartznet15x5-V0", None, "ctc"),
+    "QuartzNet V1": ("RobotsMali/stt-bm-quartznet15x5-V1", None, "ctc"),
 }


-# =============================
-# AUDIO EXTRACTION (ROBUST & HF-COMPATIBLE)
-# =============================
-
 def extract_audio(video_path, wav_path):
-    (
-        VideoFileClip(video_path)
-        .audio
-        .write_audiofile(
-            wav_path,
-            fps=16000,
-            codec="pcm_s16le",
-            verbose=False,
-            logger=None
-        )
-    )
-
-
-# =============================
-# TRANSCRIPTION + ALIGNMENT
-# =============================
+    (VideoFileClip(video_path).audio.write_audiofile(
+        wav_path, fps=16000, codec="pcm_s16le", verbose=False, logger=None
+    ))
+

 def transcribe(model, device, wav, model_name):
     audio, sr = sf.read(wav)
     if audio.ndim == 2:
         audio = np.mean(audio, axis=1)
-
     x = torch.tensor(audio, dtype=torch.float32).unsqueeze(0).to(device)
     ln = torch.tensor([x.shape[1]]).to(device)
     total_s = len(audio) / sr

-    # === Soloni → native timestamps ===
     if "Soloni" in model_name:
         with torch.no_grad():
             proc, plen = model.preprocessor(input_signal=x, input_signal_length=ln)
             hyps = model.decode_and_align(encoder_output=proc, encoded_lengths=plen)
-
         hyp = hyps[0][0] if isinstance(hyps[0], list) else hyps[0]
         return [(w.start_offset_ms/1000, w.end_offset_ms/1000, w.word) for w in hyp.words]

-    # === Soloba & QuartzNet → CTC forced alignment ===
     text = model.transcribe([wav])[0].strip()
     if not text:
         return []
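Between these two hunks the diff leaves the CTC forced-alignment body of transcribe unchanged (old lines 71-83, not shown). For context, the imported ctc_segmentation API is typically driven as in the minimal sketch below; log_probs, vocab, and words are assumed inputs, not names taken from this file:

    # Illustrative use of the ctc-segmentation package, not the exact elided code:
    # align a known word sequence against frame-level CTC log-probabilities.
    from ctc_segmentation import CtcSegmentationParameters, ctc_segmentation, prepare_text

    def align(log_probs, vocab, words):
        cfg = CtcSegmentationParameters()
        cfg.char_list = vocab                                # the model's token list
        ground_truth, utt_begins = prepare_text(cfg, words)  # one "utterance" per word
        timings, char_probs, _ = ctc_segmentation(cfg, log_probs, ground_truth)
        return timings, utt_begins  # boundary positions; scale by seconds-per-frame as needed

The main code's `timings[i] * tps` expressions below follow the same idea: each word's start time is its boundary position scaled by the duration of one encoder frame.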
@@ -84,90 +61,69 @@ def transcribe(model, device, wav, model_name):
              timings[i+1] * tps if i+1 < len(timings) else total_s,
              words[i]) for i in range(len(words))]

-    # Readable grouping (max 4 words per subtitle)
     grouped, temp = [], []
     for w in aligned:
         temp.append(w)
         if len(temp) >= 4:
-            grouped.append(temp)
-            temp = []
-    if temp:
-        grouped.append(temp)
+            grouped.append(temp); temp = []
+    if temp: grouped.append(temp)

     return [(g[0][0], g[-1][1], " ".join([w[2] for w in g])) for g in grouped]


-# =============================
-# SUBTITLE BURN-IN (NO IMAGEMAGICK)
-# =============================
-
 def burn(video, subs):
     clip = VideoFileClip(video)
     W, H = clip.size
-
     try:
         font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", int(H/20))
     except:
         font = ImageFont.load_default()

     layers = []
-    for s, e, text in subs:
-        img = Image.new("RGBA", (W, int(H*0.12)), (0,0,0,140))
+    for s,e,text in subs:
+        img = Image.new("RGBA",(W,int(H*0.12)),(0,0,0,140))
         draw = ImageDraw.Draw(img)
         bbox = draw.textbbox((0,0), text, font=font)
         tw, th = bbox[2]-bbox[0], bbox[3]-bbox[1]
-        draw.text(((W-tw)//2, (int(H*0.12)-th)//2), text, font=font, fill="white")
-        layers.append(ImageClip(np.array(img)).set_start(s).set_duration(e-s).set_position(("center", int(H*0.85))))
+        draw.text(((W-tw)//2,(int(H*0.12)-th)//2), text, font=font, fill="white")
+        layers.append(ImageClip(np.array(img)).set_start(s).set_duration(e-s).set_position(("center",int(H*0.85))))

     final = CompositeVideoClip([clip] + layers)
     out = "RobotsMali_Subtitled.mp4"
     final.write_videofile(out, codec="libx264", audio_codec="aac", fps=clip.fps, verbose=False, logger=None)
-
-    clip.close()
-    final.close()
+    clip.close(); final.close()
     return out


-# =============================
-# MAIN PIPELINE
-# =============================
-
 def pipeline(video_file, model_name):
     if video_file is None:
         return "Veuillez importer une vidéo.", None

+    repo, nemo_file, mode = MODELS[model_name]
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

-    # Correct loading for each model
-    if "Soloni" in model_name:
-        model = nemo_asr.models.EncDecHybridRNNTCTCBPEModel.from_pretrained(model_name=MODELS[model_name])
+    if mode == "rnnt":
+        nemo_path = hf_hub_download(repo, filename=nemo_file)
+        model = nemo_asr.models.EncDecHybridRNNTCTCBPEModel.restore_from(nemo_path)
     else:
-        model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained(model_name=MODELS[model_name])
+        model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained(model_name=repo)

-    model = model.to(device)
-    model.eval()
+    model = model.to(device); model.eval()

     wav = "audio.wav"
     extract_audio(video_file, wav)
     subs = transcribe(model, device, wav, model_name)
     final = burn(video_file, subs)
-
     return "✅ Sous-titres générés.", final


-# =============================
-# INTERFACE (DESIGN PRESERVED)
-# =============================
-
 with gr.Blocks() as demo:
     gr.Markdown("# 🎙️ **RobotsMali — Sous-titrage automatique Bambara**")
-
     video = gr.Video(label="Vidéo")
     model = gr.Dropdown(list(MODELS.keys()), value="Soloni V1", label="Modèle")
     btn = gr.Button("⚡ Générer les sous-titres")
     status = gr.Markdown()
-    out = gr.Video(label="Résultat (avec sous-titres)")
-
+    out = gr.Video(label="Résultat")
     btn.click(pipeline, inputs=[video, model], outputs=[status, out])

 demo.launch()
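The loading change in pipeline is the heart of this commit: the Soloni hybrid TDT/CTC checkpoints are now fetched as explicit .nemo files and restored from a local path, while the pure CTC models keep loading by repo id. A standalone sketch of that pattern, with values taken from the new MODELS table:

    # Minimal sketch of the new loading path (values from the MODELS table above).
    from huggingface_hub import hf_hub_download
    from nemo.collections import asr as nemo_asr

    repo, nemo_file, mode = ("RobotsMali/soloni-114m-tdt-ctc-V1",
                             "soloni-114m-tdt-ctc-V1.nemo", "rnnt")
    if mode == "rnnt":
        path = hf_hub_download(repo, filename=nemo_file)  # download the packaged checkpoint
        model = nemo_asr.models.EncDecHybridRNNTCTCBPEModel.restore_from(path)
    else:
        model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained(model_name=repo)
    model.eval()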
 
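The Space only exercises pipeline through the Gradio button. A quick local smoke test (run before demo.launch()) could look like this; "sample.mp4" is an illustrative path, not a file shipped with the Space:

    # Hypothetical smoke test, bypassing the UI; "sample.mp4" is illustrative.
    status, out_path = pipeline("sample.mp4", "Soloni V1")
    print(status)    # "✅ Sous-titres générés." on success
    print(out_path)  # "RobotsMali_Subtitled.mp4"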