binaryMao committed
Commit 605a27b · verified · Parent: 8cd7de6

Update app.py

Files changed (1): app.py (+49 -55)
app.py CHANGED
@@ -5,43 +5,46 @@ import torch
 import soundfile as sf
 from moviepy.editor import VideoFileClip, CompositeVideoClip, ImageClip
 from PIL import Image, ImageDraw, ImageFont
-
 from nemo.collections import asr as nemo_asr
-from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis
 from ctc_segmentation import ctc_segmentation, CtcSegmentationParameters, prepare_text
 
 
 # =============================
-# ROBOTSMALI MODEL LIST
+# OFFICIAL ROBOTSMALI MODEL LIST
 # =============================
 
 MODELS = {
-    "Soloni 114M TDT CTC v1": "RobotsMali/soloni-114m-tdt-ctc-v1",
-    "Soloni 350M TDT CTC v1": "RobotsMali/soloni-350m-tdt-ctc-v1",
+    "Soloni V0": "RobotsMali/soloni-114m-tdt-ctc-V0",
+    "Soloni V1": "RobotsMali/soloni-114m-tdt-ctc-V1",
 
-    "Soloba CTC 0.6B v0": "RobotsMali/soloba-ctc-0.6b-v0",
-    "Soloba CTC 0.6B v1": "RobotsMali/soloba-ctc-0.6b-v1",
+    "Soloba V0": "RobotsMali/soloba-ctc-0.6b-V0",
+    "Soloba V1": "RobotsMali/soloba-ctc-0.6b-V1",
 
-    "QuartzNet Bambara v1": "RobotsMali/stt-bm-quartznet15x5-v1",
-    "QuartzNet Bambara v2": "RobotsMali/stt-bm-quartznet15x5-v2"
+    "QuartzNet V0": "RobotsMali/stt-bm-quartznet15x5-V0",
+    "QuartzNet V1": "RobotsMali/stt-bm-quartznet15x5-V1"
 }
 
 
 # =============================
-# FUNCTION: EXTRACT AUDIO
+# AUDIO EXTRACTION (RELIABLE + HF & COLAB COMPATIBLE)
 # =============================
 
 def extract_audio(video_path, wav_path):
-    clip = VideoFileClip(video_path)
-    audio = clip.audio.to_soundarray(fps=16000)
-    if audio.ndim == 2:
-        audio = np.mean(audio, axis=1)
-    sf.write(wav_path, audio, 16000)
-    clip.close()
+    (
+        VideoFileClip(video_path)
+        .audio
+        .write_audiofile(
+            wav_path,
+            fps=16000,
+            codec="pcm_s16le",
+            verbose=False,
+            logger=None
+        )
+    )
 
 
 # =============================
-# FUNCTION: TRANSCRIPTION + TIMESTAMPS
+# TRANSCRIPTION + ALIGNMENT
 # =============================
 
 def transcribe(model, device, wav, model_name):
@@ -51,18 +54,19 @@ def transcribe(model, device, wav, model_name):
 
     x = torch.tensor(audio, dtype=torch.float32).unsqueeze(0).to(device)
     ln = torch.tensor([x.shape[1]]).to(device)
+    total_s = len(audio) / sr
 
-    # === Case 1: Soloni → native timestamps ===
-    if "Soloni" in model_name and hasattr(model, "decode_and_align"):
+    # === Soloni → native timestamps ===
+    if "Soloni" in model_name:
         with torch.no_grad():
             proc, plen = model.preprocessor(input_signal=x, input_signal_length=ln)
             hyps = model.decode_and_align(encoder_output=proc, encoded_lengths=plen)
+
         hyp = hyps[0][0] if isinstance(hyps[0], list) else hyps[0]
         return [(w.start_offset_ms/1000, w.end_offset_ms/1000, w.word) for w in hyp.words]
 
-    # === Case 2: Soloba & QuartzNet → CTC forced alignment ===
-    text = model.transcribe([wav])[0]
-    text = text.strip()
+    # === Soloba / QuartzNet → CTC forced alignment ===
+    text = model.transcribe([wav])[0].strip()
     if not text:
         return []
 
@@ -72,37 +76,29 @@
     words = text.split()
     config = CtcSegmentationParameters()
     config.char_list = list(model.tokenizer.vocab.keys())
-    gt, utt = prepare_text(config, words)
+    gt, _ = prepare_text(config, words)
 
     timings, _, _ = ctc_segmentation(config, logits.cpu().numpy()[0], gt)
-    total_s = len(audio) / sr
     tps = total_s / logit_len.cpu().numpy()[0]
 
-    word_times = []
-    for i, w in enumerate(words):
-        s = timings[i] * tps
-        e = timings[i+1] * tps if i+1 < len(timings) else total_s
-        word_times.append((s, e, w))
-
-    # Readable grouping: 3-5 words per line
-    grouped, block = [], []
-    for w in word_times:
-        block.append(w)
-        if len(block) >= 4:
-            grouped.append(block)
-            block = []
-    if block:
-        grouped.append(block)
+    aligned = [(timings[i] * tps,
+                timings[i+1] * tps if i+1 < len(timings) else total_s,
+                words[i]) for i in range(len(words))]
 
-    subs = []
-    for g in grouped:
-        subs.append((g[0][0], g[-1][1], " ".join([w[2] for w in g])))
+    grouped, temp = [], []
+    for w in aligned:
+        temp.append(w)
+        if len(temp) >= 4:
+            grouped.append(temp)
+            temp = []
+    if temp:
+        grouped.append(temp)
 
-    return subs
+    return [(g[0][0], g[-1][1], " ".join([w[2] for w in g])) for g in grouped]
 
 
 # =============================
-# FUNCTION: SUBTITLE BURN-IN
+# SUBTITLE BURN-IN (NO IMAGEMAGICK)
 # =============================
 
 def burn(video, subs):
@@ -121,10 +117,7 @@ def burn(video, subs):
         bbox = draw.textbbox((0,0), text, font=font)
         tw, th = bbox[2]-bbox[0], bbox[3]-bbox[1]
         draw.text(((W-tw)//2, (int(H*0.12)-th)//2), text, font=font, fill="white")
-
-        layers.append(ImageClip(np.array(img))
-                      .set_start(s).set_duration(e-s)
-                      .set_position(("center", int(H*0.85))))
+        layers.append(ImageClip(np.array(img)).set_start(s).set_duration(e-s).set_position(("center", int(H*0.85))))
 
     final = CompositeVideoClip([clip] + layers)
     out = "RobotsMali_Subtitled.mp4"
@@ -145,22 +138,23 @@ def pipeline(video_file, model_name):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     model = nemo_asr.models.ASRModel.from_pretrained(MODELS[model_name]).to(device)
 
-    wav = "temp.wav"
+    wav = "audio.wav"
    extract_audio(video_file, wav)
     subs = transcribe(model, device, wav, model_name)
-    out = burn(video_file, subs)
-    return "✅ Subtitles generated successfully.", out
+    final = burn(video_file, subs)
+
+    return "✅ Subtitles generated.", final
 
 
 # =============================
-# GRADIO INTERFACE
+# INTERFACE (unchanged)
 # =============================
 
 with gr.Blocks() as demo:
-    gr.Markdown("# 🎙️ RobotsMali Subtitle Generator")
+    gr.Markdown("# 🎙️ **RobotsMali - Automatic Bambara Subtitling**")
 
-    video = gr.Video(label="Import a video")
-    model = gr.Dropdown(list(MODELS.keys()), value="Soloni 114M TDT CTC v1", label="Model selection")
+    video = gr.Video(label="Video")
+    model = gr.Dropdown(list(MODELS.keys()), value="Soloni V1", label="Model")
     btn = gr.Button("⚡ Generate subtitles")
     status = gr.Markdown()
     out = gr.Video(label="Result")
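
A note on the alignment arithmetic this commit converges on: `ctc_segmentation` hands back frame-index timings, `total_s / logit_len` converts encoder frames to seconds, and words are grouped four per subtitle line. A minimal standalone sketch of that same logic (the helper `frames_to_subs` and every value below are illustrative assumptions, not part of the commit):

```python
# Sketch only: mirrors the `tps`, `aligned`, and grouping steps in transcribe().
# All inputs are dummy values, chosen so the floats come out exact.

def frames_to_subs(timings, words, total_s, n_frames, group=4):
    # Seconds per encoder frame, as in `tps = total_s / logit_len`.
    tps = total_s / n_frames

    # Word-level (start, end, word) triples, as in the `aligned` comprehension;
    # the last word falls back to the full duration if timings runs out.
    aligned = [(timings[i] * tps,
                timings[i + 1] * tps if i + 1 < len(timings) else total_s,
                words[i]) for i in range(len(words))]

    # Blocks of `group` words per subtitle line, as in the grouping loop.
    blocks = [aligned[i:i + group] for i in range(0, len(aligned), group)]
    return [(b[0][0], b[-1][1], " ".join(w[2] for w in b)) for b in blocks]

# Six words over 4 s of audio and 128 encoder frames (made-up alignment).
print(frames_to_subs(
    timings=[0, 16, 32, 48, 64, 80, 96],
    words=["i", "ni", "ce", "aw", "ni", "tile"],
    total_s=4.0,
    n_frames=128,
))
# → [(0.0, 2.0, 'i ni ce aw'), (2.0, 3.0, 'ni tile')]
```

In the committed code, `n_frames` is `logit_len.cpu().numpy()[0]` and `timings` comes straight from `ctc_segmentation`; the comprehension-plus-slicing form here is just a compact restatement of the diff's grouping loop.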