binaryMao committed on
Commit
9788403
·
verified ·
1 Parent(s): c307877

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -51
app.py CHANGED
@@ -1,11 +1,12 @@
1
  # -*- coding: utf-8 -*-
2
- import os, shlex, subprocess, tempfile, traceback, time, glob, gc, shutil
3
  import torch
 
4
  from huggingface_hub import snapshot_download
5
  from nemo.collections import asr as nemo_asr
6
  import gradio as gr
7
 
8
- # 1. CONFIGURATION ET MODÈLES
9
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
10
  SEGMENT_DURATION = 5.0
11
 
@@ -23,7 +24,7 @@ MODELS = {
23
 
24
  # --- SECTION EXEMPLE ---
25
  def find_example_video():
26
- paths = ["examples/MARALINKE_FIXED.mp4", "examples/MARALINKE.mp4", "MARALINKE.mp4"]
27
  for p in paths:
28
  if os.path.exists(p): return p
29
  return None
@@ -31,109 +32,109 @@ def find_example_video():
31
  EXAMPLE_PATH = find_example_video()
32
  _cache = {}
33
 
34
- # 2. GESTION MÉMOIRE ET CHARGEMENT (NOUVELLE MÉTHODE ANTI-BUG)
35
  def get_model(name):
36
  if name in _cache: return _cache[name]
37
  gc.collect()
38
  if torch.cuda.is_available(): torch.cuda.empty_cache()
39
 
40
- repo, model_type = MODELS[name]
41
  folder = snapshot_download(repo, local_dir_use_symlinks=False)
42
  nemo_file = next((os.path.join(folder, f) for f in os.listdir(folder) if f.endswith(".nemo")), None)
43
-
44
- print(f"📦 Restauration forcée du modèle : {name}")
45
-
46
- # MÉTHODE ALTERNATIVE : On bypass le connecteur automatique
47
  try:
48
- # Tentative avec la méthode classique mais simplifiée au maximum
49
- model = nemo_asr.models.ASRModel.restore_from(nemo_file, map_location=torch.device(DEVICE))
50
- except TypeError:
51
- # Si l'erreur object.init() survient, on utilise le chargement par classe spécifique
52
- print("⚠️ Détection du bug NeMo, basculement sur le chargement spécifique...")
53
- if "ctc" in name.lower() or model_type == "ctc":
 
 
 
 
54
  model = nemo_asr.models.EncDecCTCModel.restore_from(nemo_file, map_location=torch.device(DEVICE))
55
  else:
56
  model = nemo_asr.models.EncDecHybridRNNTCTCModel.restore_from(nemo_file, map_location=torch.device(DEVICE))
57
-
58
  model.eval()
59
  if DEVICE == "cuda": model = model.half()
60
  _cache[name] = model
61
  return model
62
 
63
- # 3. UTILITAIRES
64
  def format_srt_time(sec):
65
  td = time.gmtime(max(0, sec))
66
  ms = int((sec - int(sec)) * 1000)
67
  return f"{time.strftime('%H:%M:%S', td)},{ms:03}"
68
 
69
- # 4. PIPELINE
70
  def pipeline(video_in, model_name):
71
  tmp_dir = tempfile.mkdtemp()
72
  try:
73
- if not video_in: return "❌ Vidéo manquante", None
74
 
75
- # Audio
 
76
  full_wav = os.path.join(tmp_dir, "full.wav")
77
  subprocess.run(f"ffmpeg -y -i {shlex.quote(video_in)} -vn -ac 1 -ar 16000 {full_wav}", shell=True, check=True)
78
-
79
- # Segmentation
80
- yield "⏳ Segmentation...", None
81
  subprocess.run(f"ffmpeg -i {full_wav} -f segment -segment_time {SEGMENT_DURATION} -c copy {os.path.join(tmp_dir, 'seg_%03d.wav')}", shell=True, check=True)
82
 
83
  files = sorted(glob.glob(os.path.join(tmp_dir, "seg_*.wav")))
84
- valid_segments = [f for f in files if os.path.getsize(f) > 1500]
85
 
86
  # Transcription
87
- yield f"🎙️ Chargement & Transcription ({model_name})...", None
88
  model = get_model(model_name)
89
 
90
  with torch.inference_mode():
91
- batch_hypotheses = model.transcribe(valid_segments, batch_size=16 if DEVICE=="cuda" else 2, return_hypotheses=True)
92
 
93
- # SRT Generation
94
- all_words_ts = []
95
- for idx, hyp in enumerate(batch_hypotheses):
96
- base_time = idx * SEGMENT_DURATION
97
  text = hyp.text if hasattr(hyp, 'text') else str(hyp)
98
  words = text.split()
99
  if not words: continue
100
- gap = SEGMENT_DURATION / len(words)
101
  for i, w in enumerate(words):
102
- all_words_ts.append({"word": w, "start": base_time + (i * gap), "end": base_time + ((i+1) * gap)})
103
 
104
- # Encodage
105
- srt_path = os.path.join(tmp_dir, "final.srt")
106
  with open(srt_path, "w", encoding="utf-8") as f:
107
- for i in range(0, len(all_words_ts), 6):
108
- chunk = all_words_ts[i:i+6]
109
- f.write(f"{(i//6)+1}\n{format_srt_time(chunk[0]['start'])} --> {format_srt_time(chunk[-1]['end'])}\n")
110
- f.write(" ".join([c['word'] for c in chunk]) + "\n\n")
111
 
112
- out_path = os.path.abspath(f"resultat_{int(time.time())}.mp4")
 
 
113
  safe_srt = srt_path.replace("\\", "/").replace(":", "\\:")
114
- cmd = f"ffmpeg -y -i {shlex.quote(video_in)} -vf \"subtitles='{safe_srt}':force_style='Alignment=2,FontSize=18'\" -c:v libx264 -preset ultrafast -c:a copy {out_path}"
 
115
  subprocess.run(cmd, shell=True, check=True)
116
 
117
- yield "✅ Succès !", out_path
118
 
119
  except Exception as e:
120
  yield f"❌ Erreur : {str(e)}", None
121
  finally:
122
  if os.path.exists(tmp_dir): shutil.rmtree(tmp_dir)
123
 
124
- # 5. INTERFACE
125
- with gr.Blocks() as demo:
126
- gr.HTML("<h1 style='text-align:center;'>RobotsMali Speech Lab</h1>")
127
  with gr.Row():
128
  with gr.Column():
129
- v_input = gr.Video(label="Source")
130
- m_input = gr.Dropdown(choices=list(MODELS.keys()), value="Soloni V3 (TDT-CTC)", label="Modèle")
131
- run_btn = gr.Button("🚀 GÉNÉRER", variant="primary")
132
- if EXAMPLE_PATH: gr.Examples([[EXAMPLE_PATH, "Soloni V3 (TDT-CTC)"]], [v_input, m_input])
133
  with gr.Column():
134
- status = gr.Markdown("Prêt.")
135
- v_output = gr.Video(label="Résultat")
136
 
137
- run_btn.click(pipeline, [v_input, m_input], [status, v_output])
138
 
139
  demo.launch()
 
1
  # -*- coding: utf-8 -*-
2
+ import os, shlex, subprocess, tempfile, traceback, time, glob, gc, shutil, tarfile
3
  import torch
4
+ import yaml
5
  from huggingface_hub import snapshot_download
6
  from nemo.collections import asr as nemo_asr
7
  import gradio as gr
8
 
9
+ # 1. CONFIGURATION
10
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
11
  SEGMENT_DURATION = 5.0
12
 
 
24
 
25
# --- EXAMPLE SECTION ---
def find_example_video():
    """Return the first existing demo video path, or None when none is present."""
    candidates = ["examples/MARALINKE.mp4", "MARALINKE.mp4"]
    return next((candidate for candidate in candidates if os.path.exists(candidate)), None)
 
32
# Demo video offered in gr.Examples; None when not shipped with the app.
EXAMPLE_PATH = find_example_video()
# Model name -> loaded NeMo model; avoids re-restoring checkpoints across requests.
_cache = {}
34
 
35
# 2. "HARDCORE" LOADING (WORKS AROUND THE NeMo INIT BUG)
def get_model(name):
    """Download (if needed) and load the ASR model registered under *name*.

    Results are cached in the module-level ``_cache`` so each checkpoint is
    restored at most once per process.

    Raises:
        FileNotFoundError: if the downloaded snapshot contains no ``.nemo`` file.
    """
    if name in _cache:
        return _cache[name]
    # Release as much memory as possible before restoring a new checkpoint.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    repo, _ = MODELS[name]
    folder = snapshot_download(repo, local_dir_use_symlinks=False)
    nemo_file = next((os.path.join(folder, f) for f in os.listdir(folder) if f.endswith(".nemo")), None)
    if nemo_file is None:
        # Fail fast with a clear message instead of passing None to restore_from,
        # which would crash with an opaque error deep inside NeMo.
        raise FileNotFoundError(f"No .nemo checkpoint found in snapshot of '{repo}'")

    # Preferred path: generic restore with an explicit SaveRestoreConnector.
    # If the known NeMo object.__init__() bug strikes, fall back to an
    # explicitly chosen model class based on the model name.
    try:
        from nemo.core.connectors.save_restore_connector import SaveRestoreConnector
        model = nemo_asr.models.ASRModel.restore_from(
            nemo_file,
            map_location=torch.device(DEVICE),
            save_restore_connector=SaveRestoreConnector(),
        )
    except Exception:
        print("🔧 Mode de secours activé pour le chargement...")
        if "ctc" in name.lower() or "Soloba" in name:
            model = nemo_asr.models.EncDecCTCModel.restore_from(nemo_file, map_location=torch.device(DEVICE))
        else:
            model = nemo_asr.models.EncDecHybridRNNTCTCModel.restore_from(nemo_file, map_location=torch.device(DEVICE))

    model.eval()
    if DEVICE == "cuda":
        model = model.half()  # fp16 halves VRAM usage on GPU
    _cache[name] = model
    return model
66
 
67
# 3. UTILS & PIPELINE
def format_srt_time(sec):
    """Format a time offset in seconds as an SRT timestamp ``HH:MM:SS,mmm``.

    Negative inputs are clamped to zero *before* the millisecond part is
    computed; previously only the seconds were clamped, so values like -0.5
    produced malformed timestamps such as ``00:00:00,-500``.
    """
    sec = max(0, sec)
    ms = int((sec - int(sec)) * 1000)
    return f"{time.strftime('%H:%M:%S', time.gmtime(sec))},{ms:03}"
72
 
 
73
def pipeline(video_in, model_name):
    """Gradio generator: extract audio, transcribe per 5s segment, burn subtitles.

    Yields ``(status_message, output_video_path_or_None)`` tuples so the UI
    updates progressively. The temporary working directory is always removed.
    """
    tmp_dir = tempfile.mkdtemp()
    try:
        if not video_in:
            # Must *yield* the message: a bare `return value` inside a
            # generator is discarded by Gradio, so the user saw nothing.
            yield "❌ Vidéo absente", None
            return

        # Audio extraction & segmentation (all paths quoted for the shell).
        yield "⏳ Extraction & Segmentation...", None
        full_wav = os.path.join(tmp_dir, "full.wav")
        subprocess.run(
            f"ffmpeg -y -i {shlex.quote(video_in)} -vn -ac 1 -ar 16000 {shlex.quote(full_wav)}",
            shell=True, check=True,
        )
        subprocess.run(
            f"ffmpeg -i {shlex.quote(full_wav)} -f segment -segment_time {SEGMENT_DURATION} -c copy "
            f"{shlex.quote(os.path.join(tmp_dir, 'seg_%03d.wav'))}",
            shell=True, check=True,
        )

        files = sorted(glob.glob(os.path.join(tmp_dir, "seg_*.wav")))
        # Drop near-empty trailing segments (header-only WAV files).
        valid_segments = [f for f in files if os.path.getsize(f) > 1000]

        # Transcription
        yield f"🎙️ Transcription ({model_name})...", None
        model = get_model(model_name)
        with torch.inference_mode():
            results = model.transcribe(
                valid_segments,
                batch_size=16 if DEVICE == "cuda" else 2,
                return_hypotheses=True,
            )

        # Approximate word timestamps: spread each segment's words evenly
        # across the segment's duration.
        all_words = []
        for idx, hyp in enumerate(results):
            text = hyp.text if hasattr(hyp, 'text') else str(hyp)
            words = text.split()
            if not words:
                continue
            step = SEGMENT_DURATION / len(words)
            base = idx * SEGMENT_DURATION
            for i, w in enumerate(words):
                all_words.append({"w": w, "s": base + i * step, "e": base + (i + 1) * step})

        # Write the SRT, 6 words per cue.
        srt_path = os.path.join(tmp_dir, "f.srt")
        with open(srt_path, "w", encoding="utf-8") as f:
            for i in range(0, len(all_words), 6):
                c = all_words[i:i + 6]
                f.write(f"{(i // 6) + 1}\n{format_srt_time(c[0]['s'])} --> {format_srt_time(c[-1]['e'])}\n{' '.join([x['w'] for x in c])}\n\n")

        # Burn subtitles into the video. The SRT path is escaped for ffmpeg's
        # filter syntax (Linux/Windows drive-letter colons).
        out_path = os.path.abspath(f"output_{int(time.time())}.mp4")
        safe_srt = srt_path.replace("\\", "/").replace(":", "\\:")
        cmd = (
            f"ffmpeg -y -i {shlex.quote(video_in)} -vf \"subtitles='{safe_srt}'\" "
            f"-c:v libx264 -preset ultrafast -c:a copy {shlex.quote(out_path)}"
        )
        subprocess.run(cmd, shell=True, check=True)

        yield "✅ Terminé !", out_path

    except Exception as e:
        yield f"❌ Erreur : {str(e)}", None
    finally:
        # Always clean the scratch directory; out_path lives outside it.
        if os.path.exists(tmp_dir):
            shutil.rmtree(tmp_dir)
124
 
125
# 4. INTERFACE
# Top-level Gradio UI: two columns — left holds the inputs, right shows the
# streamed status and the final subtitled video.
with gr.Blocks(theme=gr.themes.Default()) as demo:
    gr.Markdown("# 🤖 RobotsMali Speech Lab")
    with gr.Row():
        with gr.Column():
            # Input column: source video, model picker, launch button.
            v_in = gr.Video()
            m_in = gr.Dropdown(choices=list(MODELS.keys()), value="Soloni V3 (TDT-CTC)")
            btn = gr.Button("🚀 GÉNÉRER", variant="primary")
            # Offer the example row only when the demo video exists on disk.
            if EXAMPLE_PATH: gr.Examples([[EXAMPLE_PATH, "Soloni V3 (TDT-CTC)"]], [v_in, m_in])
        with gr.Column():
            # Output column: live status text plus the resulting video.
            stat = gr.Markdown("Prêt.")
            v_out = gr.Video()

    # pipeline is a generator, so status/result update progressively in the UI.
    btn.click(pipeline, [v_in, m_in], [stat, v_out])

# NOTE(review): launches at import time — presumably intended for a hosted Space.
demo.launch()