binaryMao commited on
Commit
db04ae4
·
verified ·
1 Parent(s): 9788403

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -64
app.py CHANGED
@@ -1,12 +1,15 @@
1
  # -*- coding: utf-8 -*-
2
- import os, shlex, subprocess, tempfile, traceback, time, glob, gc, shutil, tarfile
3
  import torch
4
- import yaml
5
  from huggingface_hub import snapshot_download
6
  from nemo.collections import asr as nemo_asr
7
  import gradio as gr
8
 
9
- # 1. CONFIGURATION
 
 
 
10
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
11
  SEGMENT_DURATION = 5.0
12
 
@@ -22,7 +25,6 @@ MODELS = {
22
  "Traduction Soloni (ST)": ("RobotsMali/st-soloni-114m-tdt-ctc", "rnnt"),
23
  }
24
 
25
- # --- SECTION EXEMPLE ---
26
  def find_example_video():
27
  paths = ["examples/MARALINKE.mp4", "MARALINKE.mp4"]
28
  for p in paths:
@@ -32,109 +34,130 @@ def find_example_video():
32
  EXAMPLE_PATH = find_example_video()
33
  _cache = {}
34
 
35
- # 2. CHARGEMENT "HARDCORE" (BYPASS DU BUG INIT)
36
  def get_model(name):
37
  if name in _cache: return _cache[name]
38
- gc.collect()
39
- if torch.cuda.is_available(): torch.cuda.empty_cache()
40
 
41
- repo, _ = MODELS[name]
 
 
 
42
  folder = snapshot_download(repo, local_dir_use_symlinks=False)
43
  nemo_file = next((os.path.join(folder, f) for f in os.listdir(folder) if f.endswith(".nemo")), None)
 
 
 
 
44
 
45
- # MÉTHODE DE SECOURS : Chargement par restauration standard
46
- # Si ça échoue encore, on utilise une approche par classe explicite
47
  try:
 
48
  from nemo.core.connectors.save_restore_connector import SaveRestoreConnector
49
- model = nemo_asr.models.ASRModel.restore_from(
50
- nemo_file,
51
- map_location=torch.device(DEVICE),
52
- save_restore_connector=SaveRestoreConnector()
53
- )
54
- except Exception:
55
- # Si le bug persiste, on force le type selon le nom
56
- print("🔧 Mode de secours activé pour le chargement...")
57
- if "ctc" in name.lower() or "Soloba" in name:
58
- model = nemo_asr.models.EncDecCTCModel.restore_from(nemo_file, map_location=torch.device(DEVICE))
59
- else:
60
- model = nemo_asr.models.EncDecHybridRNNTCTCModel.restore_from(nemo_file, map_location=torch.device(DEVICE))
 
 
 
 
 
 
 
61
 
62
  model.eval()
63
  if DEVICE == "cuda": model = model.half()
64
  _cache[name] = model
65
  return model
66
 
67
- # 3. UTILS & PIPELINE
68
- def format_srt_time(sec):
69
- td = time.gmtime(max(0, sec))
70
- ms = int((sec - int(sec)) * 1000)
71
- return f"{time.strftime('%H:%M:%S', td)},{ms:03}"
72
-
73
  def pipeline(video_in, model_name):
74
  tmp_dir = tempfile.mkdtemp()
 
 
 
 
 
 
 
75
  try:
76
- if not video_in: return "❌ Vidéo absente", None
77
-
78
- # Audio & Segmentation
79
- yield "⏳ Extraction & Segmentation...", None
 
80
  full_wav = os.path.join(tmp_dir, "full.wav")
81
- subprocess.run(f"ffmpeg -y -i {shlex.quote(video_in)} -vn -ac 1 -ar 16000 {full_wav}", shell=True, check=True)
82
- subprocess.run(f"ffmpeg -i {full_wav} -f segment -segment_time {SEGMENT_DURATION} -c copy {os.path.join(tmp_dir, 'seg_%03d.wav')}", shell=True, check=True)
 
 
 
83
 
84
  files = sorted(glob.glob(os.path.join(tmp_dir, "seg_*.wav")))
85
  valid_segments = [f for f in files if os.path.getsize(f) > 1000]
86
-
87
- # Transcription
88
- yield f"🎙️ Transcription ({model_name})...", None
89
  model = get_model(model_name)
90
 
 
91
  with torch.inference_mode():
92
- results = model.transcribe(valid_segments, batch_size=16 if DEVICE=="cuda" else 2, return_hypotheses=True)
93
 
94
- # SRT
95
  all_words = []
96
- for idx, hyp in enumerate(results):
97
  text = hyp.text if hasattr(hyp, 'text') else str(hyp)
98
  words = text.split()
99
  if not words: continue
100
- step = SEGMENT_DURATION / len(words)
101
  for i, w in enumerate(words):
102
- all_words.append({"w": w, "s": (idx * SEGMENT_DURATION) + (i * step), "e": (idx * SEGMENT_DURATION) + ((i+1) * step)})
103
 
104
- srt_path = os.path.join(tmp_dir, "f.srt")
 
105
  with open(srt_path, "w", encoding="utf-8") as f:
106
  for i in range(0, len(all_words), 6):
107
- c = all_words[i:i+6]
108
- f.write(f"{(i//6)+1}\n{format_srt_time(c[0]['s'])} --> {format_srt_time(c[-1]['e'])}\n{' '.join([x['w'] for x in c])}\n\n")
 
 
109
 
110
- # Vidéo
111
- out_path = os.path.abspath(f"output_{int(time.time())}.mp4")
112
- # Formatage du chemin SRT pour ffmpeg (Linux/Windows)
113
  safe_srt = srt_path.replace("\\", "/").replace(":", "\\:")
 
114
 
115
- cmd = f"ffmpeg -y -i {shlex.quote(video_in)} -vf \"subtitles='{safe_srt}'\" -c:v libx264 -preset ultrafast -c:a copy {out_path}"
116
- subprocess.run(cmd, shell=True, check=True)
117
-
118
- yield "✅ Terminé !", out_path
119
 
120
  except Exception as e:
121
- yield f"❌ Erreur : {str(e)}", None
 
 
122
  finally:
123
  if os.path.exists(tmp_dir): shutil.rmtree(tmp_dir)
124
 
125
- # 4. INTERFACE
126
- with gr.Blocks(theme=gr.themes.Default()) as demo:
127
- gr.Markdown("# 🤖 RobotsMali Speech Lab")
128
  with gr.Row():
129
  with gr.Column():
130
- v_in = gr.Video()
131
- m_in = gr.Dropdown(choices=list(MODELS.keys()), value="Soloni V3 (TDT-CTC)")
132
- btn = gr.Button("🚀 GÉNÉRER", variant="primary")
133
- if EXAMPLE_PATH: gr.Examples([[EXAMPLE_PATH, "Soloni V3 (TDT-CTC)"]], [v_in, m_in])
 
134
  with gr.Column():
135
- stat = gr.Markdown("Prêt.")
136
- v_out = gr.Video()
137
 
138
- btn.click(pipeline, [v_in, m_in], [stat, v_out])
139
 
140
- demo.launch()
 
1
  # -*- coding: utf-8 -*-
2
+ import os, shlex, subprocess, tempfile, traceback, time, glob, gc, shutil
3
  import torch
4
+ import logging
5
  from huggingface_hub import snapshot_download
6
  from nemo.collections import asr as nemo_asr
7
  import gradio as gr
8
 
9
# Log configuration so we can see what happens under the hood
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Inference device and fixed audio chunk length (seconds) used by the pipeline
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SEGMENT_DURATION = 5.0
15
 
 
25
  "Traduction Soloni (ST)": ("RobotsMali/st-soloni-114m-tdt-ctc", "rnnt"),
26
  }
27
 
 
28
  def find_example_video():
29
  paths = ["examples/MARALINKE.mp4", "MARALINKE.mp4"]
30
  for p in paths:
 
34
  EXAMPLE_PATH = find_example_video()
35
  _cache = {}
36
 
37
# --- MODEL LOADING WITH LOGS AND FALLBACK ---
def get_model(name):
    """Load and cache the NeMo ASR model registered under *name* in MODELS.

    Strategy:
      1. Standard ``ASRModel.restore_from`` with an explicit connector.
      2. On failure, force the concrete class picked from the declared type.

    Raises FileNotFoundError when no ``.nemo`` checkpoint is found in the
    downloaded snapshot, RuntimeError when both loading strategies fail.
    """
    if name in _cache:
        return _cache[name]

    # Free as much memory as possible before loading a large checkpoint
    # (this pre-load cleanup existed previously and was lost in a refactor).
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    repo, m_type = MODELS[name]
    # Use the module logger instead of print(): logging is already configured
    # at module level, so all load diagnostics go through one channel.
    logger.info("🔍 LOG: Tentative de chargement du modèle: %s", name)
    logger.info("🔍 LOG: Repo HF: %s | Device: %s", repo, DEVICE)

    folder = snapshot_download(repo, local_dir_use_symlinks=False)
    nemo_file = next((os.path.join(folder, f) for f in os.listdir(folder) if f.endswith(".nemo")), None)

    if not nemo_file:
        logger.error("❌ LOG: Erreur - Fichier .nemo introuvable dans %s", folder)
        raise FileNotFoundError("Fichier .nemo manquant.")

    # Attempt 1: standard restore with an explicit save/restore connector.
    try:
        logger.info("🔍 LOG: Essai Méthode 1 (Standard Restore)...")
        from nemo.core.connectors.save_restore_connector import SaveRestoreConnector
        connector = SaveRestoreConnector()
        model = nemo_asr.models.ASRModel.restore_from(
            nemo_file,
            map_location=torch.device(DEVICE),
            save_restore_connector=connector,
        )
        logger.info("✅ LOG: Succès avec Méthode 1")
    except Exception as e:
        # BUG FIX: this previously caught only TypeError, so any other restore
        # failure (config/key errors, version mismatches) escaped without ever
        # trying the class-forcing fallback below.
        logger.warning("⚠️ LOG: Échec Méthode 1 (Erreur init): %s", e)
        # Attempt 2: force the concrete model class from the declared type.
        try:
            logger.info("🔍 LOG: Essai Méthode 2 (Forçage Classe %s)...", m_type)
            if "ctc" in name.lower() or m_type == "ctc":
                model = nemo_asr.models.EncDecCTCModel.restore_from(nemo_file, map_location=torch.device(DEVICE))
            else:
                model = nemo_asr.models.EncDecHybridRNNTCTCModel.restore_from(nemo_file, map_location=torch.device(DEVICE))
            logger.info("✅ LOG: Succès avec Méthode 2")
        except Exception as e2:
            logger.error("❌ LOG: Échec critique Méthode 2: %s", e2)
            traceback.print_exc()
            # Chain the cause so the original traceback is preserved.
            raise RuntimeError(f"Impossible de charger le modèle après 2 tentatives. Erreur: {e2}") from e2

    model.eval()
    # fp16 halves GPU memory; CPU inference stays in fp32.
    if DEVICE == "cuda":
        model = model.half()
    _cache[name] = model
    return model
80
 
81
+ # --- PIPELINE ---
 
 
 
 
 
82
  def pipeline(video_in, model_name):
83
  tmp_dir = tempfile.mkdtemp()
84
+ log_messages = []
85
+
86
+ def add_log(msg):
87
+ print(f"📋 PIPELINE: {msg}")
88
+ log_messages.append(msg)
89
+ return "\n".join(log_messages)
90
+
91
  try:
92
+ if not video_in:
93
+ yield "❌ Vidéo manquante", None
94
+ return
95
+
96
+ yield add_log("Phase 1: Extraction audio..."), None
97
  full_wav = os.path.join(tmp_dir, "full.wav")
98
+ res = subprocess.run(f"ffmpeg -y -i {shlex.quote(video_in)} -vn -ac 1 -ar 16000 {full_wav}", shell=True, capture_output=True)
99
+ if res.returncode != 0: raise RuntimeError(f"FFmpeg Audio Error: {res.stderr.decode()}")
100
+
101
+ yield add_log(f"Phase 2: Découpage en blocs de {SEGMENT_DURATION}s..."), None
102
+ subprocess.run(f"ffmpeg -i {full_wav} -f segment -segment_time {SEGMENT_DURATION} -c copy {os.path.join(tmp_dir, 'seg_%03d.wav')}", shell=True)
103
 
104
  files = sorted(glob.glob(os.path.join(tmp_dir, "seg_*.wav")))
105
  valid_segments = [f for f in files if os.path.getsize(f) > 1000]
106
+ yield add_log(f"Segments valides trouvés: {len(valid_segments)}"), None
107
+
108
+ yield add_log(f"Phase 3: Initialisation de {model_name}..."), None
109
  model = get_model(model_name)
110
 
111
+ yield add_log("Phase 4: Transcription en cours..."), None
112
  with torch.inference_mode():
113
+ batch_hyp = model.transcribe(valid_segments, batch_size=8, return_hypotheses=True)
114
 
115
+ # Traitement SRT
116
  all_words = []
117
+ for idx, hyp in enumerate(batch_hyp):
118
  text = hyp.text if hasattr(hyp, 'text') else str(hyp)
119
  words = text.split()
120
  if not words: continue
121
+ gap = SEGMENT_DURATION / len(words)
122
  for i, w in enumerate(words):
123
+ all_words.append({"w": w, "s": (idx * SEGMENT_DURATION) + (i * gap), "e": (idx * SEGMENT_DURATION) + ((i+1) * gap)})
124
 
125
+ yield add_log("Phase 5: Création de la vidéo finale..."), None
126
+ srt_path = os.path.join(tmp_dir, "sub.srt")
127
  with open(srt_path, "w", encoding="utf-8") as f:
128
  for i in range(0, len(all_words), 6):
129
+ chunk = all_words[i:i+6]
130
+ start_f = time.strftime('%H:%M:%S', time.gmtime(chunk[0]['s'])) + f",{int((chunk[0]['s']%1)*1000):03d}"
131
+ end_f = time.strftime('%H:%M:%S', time.gmtime(chunk[-1]['e'])) + f",{int((chunk[-1]['e']%1)*1000):03d}"
132
+ f.write(f"{(i//6)+1}\n{start_f} --> {end_f}\n{' '.join([x['w'] for x in chunk])}\n\n")
133
 
134
+ out_path = os.path.abspath(f"result_{int(time.time())}.mp4")
 
 
135
  safe_srt = srt_path.replace("\\", "/").replace(":", "\\:")
136
+ subprocess.run(f"ffmpeg -y -i {shlex.quote(video_in)} -vf \"subtitles='{safe_srt}'\" -c:v libx264 -preset superfast -c:a copy {out_path}", shell=True, check=True)
137
 
138
+ yield add_log(" Terminé avec succès !"), out_path
 
 
 
139
 
140
  except Exception as e:
141
+ err_msg = f"❌ ERREUR: {str(e)}\n{traceback.format_exc()}"
142
+ print(err_msg)
143
+ yield add_log(err_msg), None
144
  finally:
145
  if os.path.exists(tmp_dir): shutil.rmtree(tmp_dir)
146
 
147
# --- INTERFACE ---
# Two-column layout: inputs (video, model choice, run button, example) on the
# left, streamed execution logs and the subtitled result on the right.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🚀 RobotsMali Speech Lab (Debug Mode)")
    with gr.Row():
        with gr.Column():
            video_widget = gr.Video(label="Vidéo")
            model_widget = gr.Dropdown(choices=list(MODELS.keys()), value="Soloni V3 (TDT-CTC)", label="Modèle")
            launch_button = gr.Button("DÉMARRER LA TRANSCRIPTION", variant="primary")
            # Only expose the canned example when the sample video shipped.
            if EXAMPLE_PATH:
                gr.Examples([[EXAMPLE_PATH, "Soloni V3 (TDT-CTC)"]], [video_widget, model_widget])
        with gr.Column():
            log_box = gr.Textbox(label="Logs d'exécution", lines=10, interactive=False)
            result_widget = gr.Video(label="Résultat")

    # pipeline is a generator, so logs stream into the textbox as they arrive.
    launch_button.click(pipeline, [video_widget, model_widget], [log_box, result_widget])

demo.launch(debug=True)