binaryMao committed on
Commit
a364c86
·
verified ·
1 Parent(s): 84226a5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -25
app.py CHANGED
@@ -1,3 +1,7 @@
 
 
 
 
1
 
2
  import os, shlex, subprocess, tempfile, traceback, time, glob, gc, shutil
3
  import torch
@@ -5,9 +9,9 @@ from huggingface_hub import snapshot_download
5
  from nemo.collections import asr as nemo_asr
6
  import gradio as gr
7
 
8
- # --- CONFIGURATION ---
9
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
10
- SEGMENT_DURATION = 5.0 # Ta préférence pour Soloni
11
 
12
  MODELS = {
13
  "Soloba V3 (CTC)": ("RobotsMali/soloba-ctc-0.6b-v3", "ctc"),
@@ -21,9 +25,27 @@ MODELS = {
21
  "Traduction Soloni (ST)": ("RobotsMali/st-soloni-114m-tdt-ctc", "rnnt"),
22
  }
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  _cache = {}
25
 
26
- # --- GESTION MÉMOIRE ET CHARGEMENT ---
27
  def clear_memory():
28
  _cache.clear()
29
  gc.collect()
@@ -39,32 +61,30 @@ def get_model(name):
39
  nemo_file = next((os.path.join(folder, f) for f in os.listdir(folder) if f.endswith(".nemo")), None)
40
  if not nemo_file: raise FileNotFoundError("Fichier .nemo introuvable.")
41
 
42
- # CORRECTIF SÉCURISÉ POUR L'ERREUR D'INITIALISATION
43
  from nemo.core.connectors.save_restore_connector import SaveRestoreConnector
44
-
45
- # On force l'utilisation d'un connecteur standard pour éviter le bug __init__()
46
  connector = SaveRestoreConnector()
47
-
48
  model = nemo_asr.models.ASRModel.restore_from(
49
  nemo_file,
50
  map_location=torch.device(DEVICE),
51
- save_restore_connector=connector # On passe l'instance déjà créée
52
  )
53
 
54
- model.eval()
55
  if DEVICE == "cuda":
56
  model = model.half()
57
 
58
  _cache[name] = model
59
  return model
60
 
61
- # --- UTILITAIRES ---
62
  def format_srt_time(sec):
63
  td = time.gmtime(max(0, sec))
64
  ms = int((sec - int(sec)) * 1000)
65
  return f"{time.strftime('%H:%M:%S', td)},{ms:03}"
66
 
67
- # --- PIPELINE ---
68
  def pipeline(video_in, model_name):
69
  tmp_dir = tempfile.mkdtemp()
70
  try:
@@ -72,24 +92,21 @@ def pipeline(video_in, model_name):
72
  yield "❌ Aucune vidéo sélectionnée.", None
73
  return
74
 
75
- # Phase 1 : Audio
76
  yield "⏳ Phase 1/4 : Extraction audio...", None
77
  full_wav = os.path.join(tmp_dir, "full.wav")
78
  subprocess.run(f"ffmpeg -y -i {shlex.quote(video_in)} -vn -ac 1 -ar 16000 {full_wav}", shell=True, check=True)
79
 
80
- # Phase 2 : Segmentation
81
  yield f"⏳ Phase 2/4 : Segmentation ({SEGMENT_DURATION}s)...", None
82
  subprocess.run(f"ffmpeg -i {full_wav} -f segment -segment_time {SEGMENT_DURATION} -c copy {os.path.join(tmp_dir, 'seg_%03d.wav')}", shell=True, check=True)
83
 
84
  files = sorted(glob.glob(os.path.join(tmp_dir, "seg_*.wav")))
85
- # Sécurité : ignorer les fichiers corrompus ou trop petits (<1.5 Ko)
86
  valid_segments = [f for f in files if os.path.getsize(f) > 1500]
87
 
88
  if not valid_segments:
89
- yield "❌ Erreur : Audio trop court ou invalide.", None
90
  return
91
 
92
- # Phase 3 : Transcription
93
  yield f"⏳ Phase 3/4 : Chargement de {model_name}...", None
94
  model = get_model(model_name)
95
 
@@ -106,7 +123,6 @@ def pipeline(video_in, model_name):
106
  words = text.split()
107
  if not words: continue
108
 
109
- # Distribution équitable des mots sur les 5 secondes
110
  gap = SEGMENT_DURATION / len(words)
111
  for i, w in enumerate(words):
112
  all_words_ts.append({
@@ -115,20 +131,17 @@ def pipeline(video_in, model_name):
115
  "end": base_time + ((i+1) * gap)
116
  })
117
 
118
- # Phase 4 : Encodage Vidéo
119
- yield "⏳ Phase 4/4 : Encodage final...", None
120
  srt_path = os.path.join(tmp_dir, "final.srt")
121
  with open(srt_path, "w", encoding="utf-8") as f:
122
- for i in range(0, len(all_words_ts), 6): # 6 mots max par ligne
123
  chunk = all_words_ts[i:i+6]
124
  f.write(f"{(i//6)+1}\n{format_srt_time(chunk[0]['start'])} --> {format_srt_time(chunk[-1]['end'])}\n")
125
  f.write(" ".join([c['word'] for c in chunk]) + "\n\n")
126
 
127
  out_path = os.path.abspath(f"resultat_{int(time.time())}.mp4")
128
- # Fix pour le chemin SRT (Windows/Linux)
129
  safe_srt = srt_path.replace("\\", "/").replace(":", "\\:")
130
 
131
- # Style : Couleur Cyan pour la lisibilité
132
  cmd = f"ffmpeg -y -i {shlex.quote(video_in)} -vf \"subtitles='{safe_srt}':force_style='Alignment=2,FontSize=18,PrimaryColour=&H00FFFF'\" -c:v libx264 -preset ultrafast -c:a copy {out_path}"
133
  subprocess.run(cmd, shell=True, check=True)
134
 
@@ -140,17 +153,23 @@ def pipeline(video_in, model_name):
140
  finally:
141
  if os.path.exists(tmp_dir): shutil.rmtree(tmp_dir)
142
 
143
- # --- INTERFACE ---
144
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
145
  gr.HTML("<div style='text-align:center;'><h1>🤖 RobotsMali Speech Lab</h1></div>")
 
146
  with gr.Row():
147
  with gr.Column():
148
  v_input = gr.Video(label="Vidéo Source")
149
  m_input = gr.Dropdown(choices=list(MODELS.keys()), value="Soloni V3 (TDT-CTC)", label="Modèle")
150
- run_btn = gr.Button("🚀 GÉNÉRER SOUS-TITRES", variant="primary")
 
 
 
 
 
151
  with gr.Column():
152
  status = gr.Markdown("### État\nPrêt.")
153
- v_output = gr.Video(label="Résultat")
154
 
155
  run_btn.click(pipeline, [v_input, m_input], [status, v_output])
156
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # !apt-get install -y ffmpeg
3
+ # !pip install gradio huggingface_hub torch
4
+ # !pip install git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[all]
5
 
6
  import os, shlex, subprocess, tempfile, traceback, time, glob, gc, shutil
7
  import torch
 
9
  from nemo.collections import asr as nemo_asr
10
  import gradio as gr
11
 
12
+ # 1. CONFIGURATION ET MODÈLES (LISTE COMPLÈTE)
13
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
14
+ SEGMENT_DURATION = 5.0 # Ta préférence stricte pour Soloni
15
 
16
  MODELS = {
17
  "Soloba V3 (CTC)": ("RobotsMali/soloba-ctc-0.6b-v3", "ctc"),
 
25
  "Traduction Soloni (ST)": ("RobotsMali/st-soloni-114m-tdt-ctc", "rnnt"),
26
  }
27
 
28
# --- EXAMPLE SECTION (restored) ---
def find_example_video():
    """Return a local path to the demo example video, or None.

    Checks a few known local locations first; if none holds a usable file,
    downloads the clip from the Hugging Face Space into ``examples/``.

    Returns:
        str | None: path to the example video, or None when the file could
        not be found locally nor downloaded.
    """
    paths = ["examples/MARALINKE_FIXED.mp4", "examples/MARALINKE.mp4", "MARALINKE.mp4"]
    for p in paths:
        # Skip zero-byte leftovers from a previously interrupted download,
        # which the original `os.path.exists` check would have accepted.
        if os.path.exists(p) and os.path.getsize(p) > 0:
            return p

    print("⬇️ Téléchargement de la vidéo d'exemple...")
    example_url = "https://huggingface.co/spaces/RobotsMali/Soloni-Demo/resolve/main/examples/MARALINKE.mp4"
    target_path = "examples/MARALINKE.mp4"
    os.makedirs("examples", exist_ok=True)
    try:
        # List-form argv with shell=False: no shell quoting/injection concerns
        # (the original interpolated the URL into a shell string).
        subprocess.run(["wget", example_url, "-O", target_path], check=True)
        return target_path
    except Exception as e:
        # Best-effort: the app still works without the example clip.
        print(f"⚠️ Impossible de télécharger l'exemple : {e}")
        return None

EXAMPLE_PATH = find_example_video()
46
  _cache = {}
47
 
48
+ # 2. GESTION MÉMOIRE ET CHARGEMENT SÉCURISÉ (CORRECTIF BUG)
49
  def clear_memory():
50
  _cache.clear()
51
  gc.collect()
 
61
  nemo_file = next((os.path.join(folder, f) for f in os.listdir(folder) if f.endswith(".nemo")), None)
62
  if not nemo_file: raise FileNotFoundError("Fichier .nemo introuvable.")
63
 
64
+ # CORRECTIF : On instancie le connecteur explicitement pour éviter l'erreur object.init()
65
  from nemo.core.connectors.save_restore_connector import SaveRestoreConnector
 
 
66
  connector = SaveRestoreConnector()
67
+
68
  model = nemo_asr.models.ASRModel.restore_from(
69
  nemo_file,
70
  map_location=torch.device(DEVICE),
71
+ save_restore_connector=connector
72
  )
73
 
74
+ model.to(DEVICE).eval()
75
  if DEVICE == "cuda":
76
  model = model.half()
77
 
78
  _cache[name] = model
79
  return model
80
 
81
# 3. UTILITIES
def format_srt_time(sec):
    """Format a time offset in seconds as an SRT timestamp ``HH:MM:SS,mmm``.

    Args:
        sec: offset in seconds (float or int); negative values clamp to 0.

    Returns:
        str: e.g. ``"00:01:01,250"`` for ``sec=61.25``.
    """
    # Clamp BEFORE splitting into whole seconds / milliseconds: the original
    # clamped only inside gmtime() but derived ms from the raw value, so
    # sec=-0.5 produced the malformed timestamp "00:00:00,-500".
    sec = max(0.0, sec)
    whole = int(sec)
    ms = int((sec - whole) * 1000)
    # NOTE(review): offsets >= 24h wrap via gmtime's day rollover — assumed
    # fine for short clips; confirm if long videos must be supported.
    return f"{time.strftime('%H:%M:%S', time.gmtime(whole))},{ms:03}"
86
 
87
+ # 4. PIPELINE DE TRANSCRIPTION
88
  def pipeline(video_in, model_name):
89
  tmp_dir = tempfile.mkdtemp()
90
  try:
 
92
  yield "❌ Aucune vidéo sélectionnée.", None
93
  return
94
 
 
95
  yield "⏳ Phase 1/4 : Extraction audio...", None
96
  full_wav = os.path.join(tmp_dir, "full.wav")
97
  subprocess.run(f"ffmpeg -y -i {shlex.quote(video_in)} -vn -ac 1 -ar 16000 {full_wav}", shell=True, check=True)
98
 
 
99
  yield f"⏳ Phase 2/4 : Segmentation ({SEGMENT_DURATION}s)...", None
100
  subprocess.run(f"ffmpeg -i {full_wav} -f segment -segment_time {SEGMENT_DURATION} -c copy {os.path.join(tmp_dir, 'seg_%03d.wav')}", shell=True, check=True)
101
 
102
  files = sorted(glob.glob(os.path.join(tmp_dir, "seg_*.wav")))
103
+ # Sécurité : Ignorer les fichiers vides (<1.5 Ko)
104
  valid_segments = [f for f in files if os.path.getsize(f) > 1500]
105
 
106
  if not valid_segments:
107
+ yield "❌ Erreur : Audio invalide.", None
108
  return
109
 
 
110
  yield f"⏳ Phase 3/4 : Chargement de {model_name}...", None
111
  model = get_model(model_name)
112
 
 
123
  words = text.split()
124
  if not words: continue
125
 
 
126
  gap = SEGMENT_DURATION / len(words)
127
  for i, w in enumerate(words):
128
  all_words_ts.append({
 
131
  "end": base_time + ((i+1) * gap)
132
  })
133
 
134
+ yield "⏳ Phase 4/4 : Encodage vidéo...", None
 
135
  srt_path = os.path.join(tmp_dir, "final.srt")
136
  with open(srt_path, "w", encoding="utf-8") as f:
137
+ for i in range(0, len(all_words_ts), 6):
138
  chunk = all_words_ts[i:i+6]
139
  f.write(f"{(i//6)+1}\n{format_srt_time(chunk[0]['start'])} --> {format_srt_time(chunk[-1]['end'])}\n")
140
  f.write(" ".join([c['word'] for c in chunk]) + "\n\n")
141
 
142
  out_path = os.path.abspath(f"resultat_{int(time.time())}.mp4")
 
143
  safe_srt = srt_path.replace("\\", "/").replace(":", "\\:")
144
 
 
145
  cmd = f"ffmpeg -y -i {shlex.quote(video_in)} -vf \"subtitles='{safe_srt}':force_style='Alignment=2,FontSize=18,PrimaryColour=&H00FFFF'\" -c:v libx264 -preset ultrafast -c:a copy {out_path}"
146
  subprocess.run(cmd, shell=True, check=True)
147
 
 
153
  finally:
154
  if os.path.exists(tmp_dir): shutil.rmtree(tmp_dir)
155
 
156
# 5. GRADIO INTERFACE
# Layout: left column = inputs (video + model picker + run button + example),
# right column = live status text and the subtitled output video.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.HTML("<div style='text-align:center;'><h1>🤖 RobotsMali Speech Lab</h1></div>")

    with gr.Row():
        with gr.Column():
            v_input = gr.Video(label="Vidéo Source")
            m_input = gr.Dropdown(choices=list(MODELS.keys()), value="Soloni V3 (TDT-CTC)", label="Modèle")
            run_btn = gr.Button("🚀 GÉNÉRER", variant="primary")

            # Show the demo clip only when it was found (or downloaded) at startup.
            if EXAMPLE_PATH:
                gr.Examples(examples=[[EXAMPLE_PATH, "Soloni V3 (TDT-CTC)"]], inputs=[v_input, m_input])

        with gr.Column():
            status = gr.Markdown("### État\nPrêt.")
            v_output = gr.Video(label="Vidéo finale")

    # `pipeline` is a generator: each yield streams a (status, video) update.
    run_btn.click(pipeline, [v_input, m_input], [status, v_output])
175