binaryMao commited on
Commit
cc8e3ad
·
verified ·
1 Parent(s): 326676b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -86
app.py CHANGED
@@ -5,119 +5,92 @@ from huggingface_hub import snapshot_download
5
  from nemo.collections import asr as nemo_asr
6
  import gradio as gr
7
 
8
- # 1. CONFIGURATION MATÉRIEL ET MODÈLES
9
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
10
 
11
  MODELS = {
12
  "Soloba V3 (CTC)": ("RobotsMali/soloba-ctc-0.6b-v3", "ctc"),
 
 
 
13
  "Soloba V1.5 (TDT)": ("RobotsMali/soloba-tdt-0.6b-v1.5", "rnnt"),
 
14
  "Soloni V3 (TDT-CTC)": ("RobotsMali/soloni-114m-tdt-ctc-v3", "rnnt"),
15
  "Soloni V2 (TDT-CTC)": ("RobotsMali/soloni-114m-tdt-ctc-v2", "rnnt"),
 
 
16
  "Soloni MSE (Experimental)": ("RobotsMali/lau-soloni-114m-mse-k1", "ctc"),
17
- "Soloba V0.5 (TDT)": ("RobotsMali/soloba-tdt-0.6b-v0.5", "rnnt"),
18
  }
19
 
20
- # 2. LOCALISATION DE LA VIDÉO D'EXEMPLE
21
- def get_absolute_example():
22
- names = ["MARALINKE.mp4", "maralinke.mp4", "example.mp4"]
23
- dirs = [".", "examples", "/home/user/app", "/home/user/app/examples"]
24
- for d in dirs:
25
- for n in names:
26
- p = os.path.join(d, n)
27
- if os.path.exists(p): return os.path.abspath(p)
28
- return None
29
-
30
- EXAMPLE_PATH = get_absolute_example()
31
  _cache = {}
32
 
33
- # 3. GESTION DE LA MÉMOIRE ET CHARGEMENT (CORRIGÉ)
34
  def clear_memory():
35
- """Libère proprement la RAM et la VRAM."""
36
  _cache.clear()
37
  gc.collect()
38
  if torch.cuda.is_available():
39
  torch.cuda.empty_cache()
40
 
41
  def get_model(name):
42
- """Charge le modèle en utilisant ASRModel pour éviter les erreurs de state_dict."""
43
  if name in _cache: return _cache[name]
44
-
45
  clear_memory()
46
  repo, _ = MODELS[name]
47
 
48
- print(f"📥 Téléchargement depuis {repo}...")
49
  folder = snapshot_download(repo, local_dir_use_symlinks=False)
50
  nemo_file = next((os.path.join(folder, f) for f in os.listdir(folder) if f.endswith(".nemo")), None)
51
 
52
- if not nemo_file:
53
- raise FileNotFoundError(f"Aucun fichier .nemo trouvé dans {folder}")
54
-
55
- # SOLUTION : Utilisation de ASRModel.restore_from pour la détection automatique
56
- # Cela évite l'erreur 'Unexpected key(s) in state_dict'
57
- model = nemo_asr.models.ASRModel.restore_from(nemo_file)
 
58
 
59
  model.to(DEVICE).eval()
60
-
61
- # Optimisation FP16
62
  if DEVICE == "cuda":
63
  model.half()
64
 
65
  _cache[name] = model
66
  return model
67
 
68
- # 4. UTILITAIRE TEMPOREL
69
  def format_srt_time(sec):
70
  td = time.gmtime(sec)
71
  ms = int((sec - int(sec)) * 1000)
72
  return f"{time.strftime('%H:%M:%S', td)},{ms:03}"
73
 
74
- # 5. PIPELINE DE TRANSCRIPTION
75
  def pipeline(video_in, model_name):
76
  tmp_dir = tempfile.mkdtemp()
77
  try:
78
- if not video_in: return "❌ Source vide", None
79
 
80
- # A. Extraction Audio
81
  yield "⏳ Extraction audio...", None
82
  full_wav = os.path.join(tmp_dir, "full.wav")
83
  subprocess.run(f"ffmpeg -y -i {shlex.quote(video_in)} -vn -ac 1 -ar 16000 {full_wav}", shell=True, check=True)
84
 
85
- # B. Segmentation
86
  seg_time = 20
87
- segment_pattern = os.path.join(tmp_dir, "seg_%03d.wav")
88
- subprocess.run(f"ffmpeg -i {full_wav} -f segment -segment_time {seg_time} -c copy {segment_pattern}", shell=True, check=True)
89
  audio_segments = sorted(glob.glob(os.path.join(tmp_dir, "seg_*.wav")))
90
 
91
- # C. Initialisation Modèle
92
- yield f"⏳ IA : Chargement de {model_name}...", None
93
  model = get_model(model_name)
94
 
95
- stride = 0.02
96
- if hasattr(model, 'preprocessor') and hasattr(model.preprocessor, 'featurizer'):
97
- stride = model.preprocessor.featurizer.hop_length / model.preprocessor.featurizer.sample_rate
98
-
99
- # D. Transcription
100
  all_words_ts = []
101
  for idx, seg_path in enumerate(audio_segments):
102
  base_time = idx * seg_time
103
  yield f"⏳ IA : Transcription {idx+1}/{len(audio_segments)}...", None
104
-
105
  hyp = model.transcribe([seg_path], return_hypotheses=True)[0]
106
- # Gestion des différents formats de retour NeMo
 
107
  text = hyp.text if hasattr(hyp, 'text') else str(hyp)
108
  words = text.split()
109
- offsets = getattr(hyp, 'word_offsets', None)
110
-
111
- if offsets and len(offsets) == len(words):
112
- for i, word in enumerate(words):
113
- t_start = base_time + (offsets[i] * stride)
114
- all_words_ts.append({"word": word, "start": t_start, "end": t_start + 0.45})
115
- else:
116
- gap = float(seg_time) / max(len(words), 1)
117
- for i, w in enumerate(words):
118
- all_words_ts.append({"word": w, "start": base_time + (i * gap), "end": base_time + ((i+1) * gap)})
119
 
120
- # E. Génération du SRT
121
  srt_path = os.path.join(tmp_dir, "final.srt")
122
  with open(srt_path, "w", encoding="utf-8") as f:
123
  for i in range(0, len(all_words_ts), 6):
@@ -125,46 +98,28 @@ def pipeline(video_in, model_name):
125
  f.write(f"{(i//6)+1}\n{format_srt_time(chunk[0]['start'])} --> {format_srt_time(chunk[-1]['end'])}\n")
126
  f.write(" ".join([c['word'] for c in chunk]) + "\n\n")
127
 
128
- # F. Encodage Vidéo Final
129
- yield "⏳ Rendu vidéo final...", None
130
- out_path = os.path.abspath(f"robotsmali_final_{int(time.time())}.mp4")
131
  safe_srt = srt_path.replace("\\", "/").replace(":", "\\:")
 
132
 
133
- cmd_ffmpeg = (
134
- f"ffmpeg -y -i {shlex.quote(video_in)} "
135
- f"-vf \"subtitles='{safe_srt}':force_style='Alignment=2,FontSize=18'\" "
136
- f"-c:v libx264 -preset ultrafast -pix_fmt yuv420p -c:a aac {out_path}"
137
- )
138
- subprocess.run(cmd_ffmpeg, shell=True, check=True)
139
-
140
- yield "✅ Terminé avec succès !", out_path
141
-
142
  except Exception as e:
143
- traceback.print_exc()
144
  yield f"❌ Erreur : {str(e)}", None
145
  finally:
146
- # Nettoyage des fichiers temporaires
147
- if os.path.exists(tmp_dir):
148
- shutil.rmtree(tmp_dir)
149
 
150
- # 6. INTERFACE GRADIO
151
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
152
- gr.HTML("<h1 style='text-align:center;'>🤖 ROBOTSMALI TRANSCRIPTION</h1>")
153
-
154
  with gr.Row():
155
  with gr.Column():
156
- v_in = gr.Video(label="Source Vidéo")
157
- m_sel = gr.Dropdown(choices=list(MODELS.keys()), value="Soloba V3 (CTC)", label="Modèle IA")
158
- btn_run = gr.Button("🚀 GÉNÉRER SOUS-TITRES", variant="primary")
159
-
160
- if EXAMPLE_PATH:
161
- gr.Examples(examples=[[EXAMPLE_PATH, "Soloba V3 (CTC)"]], inputs=[v_in, m_sel])
162
-
163
  with gr.Column():
164
- status = gr.Markdown("### État\nPrêt.")
165
- v_out = gr.Video(label="Résultat Final")
166
-
167
- btn_run.click(pipeline, [v_in, m_sel], [status, v_out])
168
 
169
- if __name__ == "__main__":
170
- demo.launch()
 
5
  from nemo.collections import asr as nemo_asr
6
  import gradio as gr
7
 
8
+ # CONFIGURATION
9
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
10
 
11
  MODELS = {
12
  "Soloba V3 (CTC)": ("RobotsMali/soloba-ctc-0.6b-v3", "ctc"),
13
+ "Soloba V2 (CTC)": ("RobotsMali/soloba-ctc-0.6b-v2", "ctc"),
14
+ "Soloba V1 (CTC)": ("RobotsMali/soloba-ctc-0.6b-v1", "ctc"),
15
+ "Soloba V0 (CTC)": ("RobotsMali/soloba-ctc-0.6b-v0", "ctc"),
16
  "Soloba V1.5 (TDT)": ("RobotsMali/soloba-tdt-0.6b-v1.5", "rnnt"),
17
+ "Soloba V0.5 (TDT)": ("RobotsMali/soloba-tdt-0.6b-v0.5", "rnnt"),
18
  "Soloni V3 (TDT-CTC)": ("RobotsMali/soloni-114m-tdt-ctc-v3", "rnnt"),
19
  "Soloni V2 (TDT-CTC)": ("RobotsMali/soloni-114m-tdt-ctc-v2", "rnnt"),
20
+ "Soloni V1 (TDT-CTC)": ("RobotsMali/soloni-114m-tdt-ctc-v1", "rnnt"),
21
+ "Soloni V0 (TDT-CTC)": ("RobotsMali/soloni-114m-tdt-ctc-v0", "rnnt"),
22
  "Soloni MSE (Experimental)": ("RobotsMali/lau-soloni-114m-mse-k1", "ctc"),
23
+ "Traduction Soloni (ST)": ("RobotsMali/st-soloni-114m-tdt-ctc", "rnnt"),
24
  }
25
 
 
 
 
 
 
 
 
 
 
 
 
26
  _cache = {}
27
 
 
28
  def clear_memory():
 
29
  _cache.clear()
30
  gc.collect()
31
  if torch.cuda.is_available():
32
  torch.cuda.empty_cache()
33
 
34
  def get_model(name):
 
35
  if name in _cache: return _cache[name]
 
36
  clear_memory()
37
  repo, _ = MODELS[name]
38
 
39
+ print(f"📥 Téléchargement de {repo}...")
40
  folder = snapshot_download(repo, local_dir_use_symlinks=False)
41
  nemo_file = next((os.path.join(folder, f) for f in os.listdir(folder) if f.endswith(".nemo")), None)
42
 
43
+ # Chargement robuste pour éviter l'erreur Unexpected Key
44
+ from nemo.core.connectors.save_restore_connector import SaveRestoreConnector
45
+ model = nemo_asr.models.ASRModel.restore_from(
46
+ nemo_file,
47
+ map_location=torch.device(DEVICE),
48
+ save_restore_connector=SaveRestoreConnector()
49
+ )
50
 
51
  model.to(DEVICE).eval()
 
 
52
  if DEVICE == "cuda":
53
  model.half()
54
 
55
  _cache[name] = model
56
  return model
57
 
 
58
  def format_srt_time(sec):
59
  td = time.gmtime(sec)
60
  ms = int((sec - int(sec)) * 1000)
61
  return f"{time.strftime('%H:%M:%S', td)},{ms:03}"
62
 
 
63
  def pipeline(video_in, model_name):
64
  tmp_dir = tempfile.mkdtemp()
65
  try:
66
+ if not video_in: return "❌ Vidéo manquante", None
67
 
 
68
  yield "⏳ Extraction audio...", None
69
  full_wav = os.path.join(tmp_dir, "full.wav")
70
  subprocess.run(f"ffmpeg -y -i {shlex.quote(video_in)} -vn -ac 1 -ar 16000 {full_wav}", shell=True, check=True)
71
 
72
+ yield "⏳ Segmentation...", None
73
  seg_time = 20
74
+ subprocess.run(f"ffmpeg -i {full_wav} -f segment -segment_time {seg_time} -c copy {os.path.join(tmp_dir, 'seg_%03d.wav')}", shell=True, check=True)
 
75
  audio_segments = sorted(glob.glob(os.path.join(tmp_dir, "seg_*.wav")))
76
 
77
+ yield f"⏳ IA : Chargement {model_name}...", None
 
78
  model = get_model(model_name)
79
 
 
 
 
 
 
80
  all_words_ts = []
81
  for idx, seg_path in enumerate(audio_segments):
82
  base_time = idx * seg_time
83
  yield f"⏳ IA : Transcription {idx+1}/{len(audio_segments)}...", None
 
84
  hyp = model.transcribe([seg_path], return_hypotheses=True)[0]
85
+ if isinstance(hyp, list): hyp = hyp[0]
86
+
87
  text = hyp.text if hasattr(hyp, 'text') else str(hyp)
88
  words = text.split()
89
+ # On simule un timing si les offsets manquent
90
+ gap = float(seg_time) / max(len(words), 1)
91
+ for i, w in enumerate(words):
92
+ all_words_ts.append({"word": w, "start": base_time + (i * gap), "end": base_time + ((i+1) * gap)})
 
 
 
 
 
 
93
 
 
94
  srt_path = os.path.join(tmp_dir, "final.srt")
95
  with open(srt_path, "w", encoding="utf-8") as f:
96
  for i in range(0, len(all_words_ts), 6):
 
98
  f.write(f"{(i//6)+1}\n{format_srt_time(chunk[0]['start'])} --> {format_srt_time(chunk[-1]['end'])}\n")
99
  f.write(" ".join([c['word'] for c in chunk]) + "\n\n")
100
 
101
+ yield "⏳ Rendu vidéo...", None
102
+ out_path = os.path.abspath(f"resultat_{int(time.time())}.mp4")
 
103
  safe_srt = srt_path.replace("\\", "/").replace(":", "\\:")
104
+ subprocess.run(f"ffmpeg -y -i {shlex.quote(video_in)} -vf \"subtitles='{safe_srt}'\" -c:v libx264 -preset ultrafast -c:a aac {out_path}", shell=True, check=True)
105
 
106
+ yield "✅ Terminé !", out_path
 
 
 
 
 
 
 
 
107
  except Exception as e:
 
108
  yield f"❌ Erreur : {str(e)}", None
109
  finally:
110
+ if os.path.exists(tmp_dir): shutil.rmtree(tmp_dir)
 
 
111
 
112
+ # INTERFACE
113
+ with gr.Blocks() as demo:
114
+ gr.Markdown("# 🤖 RobotsMali ASR Test")
 
115
  with gr.Row():
116
  with gr.Column():
117
+ v_in = gr.Video()
118
+ m_sel = gr.Dropdown(choices=list(MODELS.keys()), value="Soloba V3 (CTC)")
119
+ btn = gr.Button("Lancer", variant="primary")
 
 
 
 
120
  with gr.Column():
121
+ status = gr.Markdown("Prêt.")
122
+ v_out = gr.Video()
123
+ btn.click(pipeline, [v_in, m_sel], [status, v_out])
 
124
 
125
+ demo.launch(debug=True; share=True)