binaryMao commited on
Commit
78a6ca1
·
verified ·
1 Parent(s): 716a2a9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -36
app.py CHANGED
@@ -5,7 +5,7 @@ from huggingface_hub import snapshot_download
5
  from nemo.collections import asr as nemo_asr
6
  import gradio as gr
7
 
8
- # CONFIGURATION
9
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
10
 
11
  MODELS = {
@@ -23,35 +23,28 @@ MODELS = {
23
  "Traduction Soloni (ST)": ("RobotsMali/st-soloni-114m-tdt-ctc", "rnnt"),
24
  }
25
 
26
- _cache = {}
 
27
 
28
- def clear_memory():
29
- _cache.clear()
30
- gc.collect()
31
- if torch.cuda.is_available():
32
- torch.cuda.empty_cache()
33
 
34
  def get_model(name):
35
  if name in _cache: return _cache[name]
36
- clear_memory()
37
- repo, _ = MODELS[name]
38
 
39
- print(f"📥 Téléchargement de {repo}...")
40
  folder = snapshot_download(repo, local_dir_use_symlinks=False)
41
  nemo_file = next((os.path.join(folder, f) for f in os.listdir(folder) if f.endswith(".nemo")), None)
42
 
43
- # Chargement robuste pour éviter l'erreur Unexpected Key
44
  from nemo.core.connectors.save_restore_connector import SaveRestoreConnector
45
  model = nemo_asr.models.ASRModel.restore_from(
46
  nemo_file,
47
  map_location=torch.device(DEVICE),
48
  save_restore_connector=SaveRestoreConnector()
49
  )
50
-
51
  model.to(DEVICE).eval()
52
- if DEVICE == "cuda":
53
- model.half()
54
-
55
  _cache[name] = model
56
  return model
57
 
@@ -63,34 +56,32 @@ def format_srt_time(sec):
63
  def pipeline(video_in, model_name):
64
  tmp_dir = tempfile.mkdtemp()
65
  try:
66
- if not video_in: return "❌ Vidéo manquante", None
67
 
68
- yield "⏳ Extraction audio...", None
69
  full_wav = os.path.join(tmp_dir, "full.wav")
70
  subprocess.run(f"ffmpeg -y -i {shlex.quote(video_in)} -vn -ac 1 -ar 16000 {full_wav}", shell=True, check=True)
71
 
72
- yield "⏳ Segmentation...", None
73
- seg_time = 20
74
- subprocess.run(f"ffmpeg -i {full_wav} -f segment -segment_time {seg_time} -c copy {os.path.join(tmp_dir, 'seg_%03d.wav')}", shell=True, check=True)
75
  audio_segments = sorted(glob.glob(os.path.join(tmp_dir, "seg_*.wav")))
76
 
77
- yield f"⏳ IA : Chargement {model_name}...", None
78
  model = get_model(model_name)
79
 
80
  all_words_ts = []
81
  for idx, seg_path in enumerate(audio_segments):
82
- base_time = idx * seg_time
83
- yield f" IA : Transcription {idx+1}/{len(audio_segments)}...", None
84
  hyp = model.transcribe([seg_path], return_hypotheses=True)[0]
85
  if isinstance(hyp, list): hyp = hyp[0]
86
-
87
  text = hyp.text if hasattr(hyp, 'text') else str(hyp)
88
  words = text.split()
89
- # On simule un timing si les offsets manquent
90
- gap = float(seg_time) / max(len(words), 1)
91
  for i, w in enumerate(words):
92
  all_words_ts.append({"word": w, "start": base_time + (i * gap), "end": base_time + ((i+1) * gap)})
93
 
 
94
  srt_path = os.path.join(tmp_dir, "final.srt")
95
  with open(srt_path, "w", encoding="utf-8") as f:
96
  for i in range(0, len(all_words_ts), 6):
@@ -98,10 +89,9 @@ def pipeline(video_in, model_name):
98
  f.write(f"{(i//6)+1}\n{format_srt_time(chunk[0]['start'])} --> {format_srt_time(chunk[-1]['end'])}\n")
99
  f.write(" ".join([c['word'] for c in chunk]) + "\n\n")
100
 
101
- yield "⏳ Rendu vidéo...", None
102
- out_path = os.path.abspath(f"resultat_{int(time.time())}.mp4")
103
  safe_srt = srt_path.replace("\\", "/").replace(":", "\\:")
104
- subprocess.run(f"ffmpeg -y -i {shlex.quote(video_in)} -vf \"subtitles='{safe_srt}'\" -c:v libx264 -preset ultrafast -c:a aac {out_path}", shell=True, check=True)
105
 
106
  yield "✅ Terminé !", out_path
107
  except Exception as e:
@@ -110,16 +100,27 @@ def pipeline(video_in, model_name):
110
  if os.path.exists(tmp_dir): shutil.rmtree(tmp_dir)
111
 
112
  # INTERFACE
113
- with gr.Blocks() as demo:
114
- gr.Markdown("# 🤖 RobotsMali ASR Test")
 
115
  with gr.Row():
116
  with gr.Column():
117
- v_in = gr.Video()
118
- m_sel = gr.Dropdown(choices=list(MODELS.keys()), value="Soloba V3 (CTC)")
119
- btn = gr.Button("Lancer", variant="primary")
 
 
 
 
 
 
 
 
 
120
  with gr.Column():
121
  status = gr.Markdown("Prêt.")
122
- v_out = gr.Video()
 
123
  btn.click(pipeline, [v_in, m_sel], [status, v_out])
124
 
125
- demo.launch(debug=True, share=True)
 
5
  from nemo.collections import asr as nemo_asr
6
  import gradio as gr
7
 
8
+ # 1. CONFIGURATION
9
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
10
 
11
  MODELS = {
 
23
  "Traduction Soloni (ST)": ("RobotsMali/st-soloni-114m-tdt-ctc", "rnnt"),
24
  }
25
 
26
+ # Recherche de la vidéo d'exemple dans le dossier courant
27
+ EXAMPLE_VIDEO = "MARALINKE.mp4" if os.path.exists("MARALINKE.mp4") else None
28
 
29
+ _cache = {}
 
 
 
 
30
 
31
  def get_model(name):
32
  if name in _cache: return _cache[name]
33
+ gc.collect()
34
+ if torch.cuda.is_available(): torch.cuda.empty_cache()
35
 
36
+ repo, _ = MODELS[name]
37
  folder = snapshot_download(repo, local_dir_use_symlinks=False)
38
  nemo_file = next((os.path.join(folder, f) for f in os.listdir(folder) if f.endswith(".nemo")), None)
39
 
 
40
  from nemo.core.connectors.save_restore_connector import SaveRestoreConnector
41
  model = nemo_asr.models.ASRModel.restore_from(
42
  nemo_file,
43
  map_location=torch.device(DEVICE),
44
  save_restore_connector=SaveRestoreConnector()
45
  )
 
46
  model.to(DEVICE).eval()
47
+ if DEVICE == "cuda": model.half()
 
 
48
  _cache[name] = model
49
  return model
50
 
 
56
  def pipeline(video_in, model_name):
57
  tmp_dir = tempfile.mkdtemp()
58
  try:
59
+ if not video_in: return "❌ Source vide", None
60
 
61
+ yield "⏳ Phase 1 : Extraction Audio...", None
62
  full_wav = os.path.join(tmp_dir, "full.wav")
63
  subprocess.run(f"ffmpeg -y -i {shlex.quote(video_in)} -vn -ac 1 -ar 16000 {full_wav}", shell=True, check=True)
64
 
65
+ yield "⏳ Phase 2 : Segmentation...", None
66
+ subprocess.run(f"ffmpeg -i {full_wav} -f segment -segment_time 20 -c copy {os.path.join(tmp_dir, 'seg_%03d.wav')}", shell=True, check=True)
 
67
  audio_segments = sorted(glob.glob(os.path.join(tmp_dir, "seg_*.wav")))
68
 
69
+ yield f"⏳ Phase 3 : IA ({model_name})...", None
70
  model = get_model(model_name)
71
 
72
  all_words_ts = []
73
  for idx, seg_path in enumerate(audio_segments):
74
+ base_time = idx * 20
75
+ yield f"🎙️ Transcription {idx+1}/{len(audio_segments)}...", None
76
  hyp = model.transcribe([seg_path], return_hypotheses=True)[0]
77
  if isinstance(hyp, list): hyp = hyp[0]
 
78
  text = hyp.text if hasattr(hyp, 'text') else str(hyp)
79
  words = text.split()
80
+ gap = 20.0 / max(len(words), 1)
 
81
  for i, w in enumerate(words):
82
  all_words_ts.append({"word": w, "start": base_time + (i * gap), "end": base_time + ((i+1) * gap)})
83
 
84
+ yield "⏳ Phase 4 : Rendu Vidéo...", None
85
  srt_path = os.path.join(tmp_dir, "final.srt")
86
  with open(srt_path, "w", encoding="utf-8") as f:
87
  for i in range(0, len(all_words_ts), 6):
 
89
  f.write(f"{(i//6)+1}\n{format_srt_time(chunk[0]['start'])} --> {format_srt_time(chunk[-1]['end'])}\n")
90
  f.write(" ".join([c['word'] for c in chunk]) + "\n\n")
91
 
92
+ out_path = os.path.abspath(f"output_{int(time.time())}.mp4")
 
93
  safe_srt = srt_path.replace("\\", "/").replace(":", "\\:")
94
+ subprocess.run(f"ffmpeg -y -i {shlex.quote(video_in)} -vf \"subtitles='{safe_srt}':force_style='Alignment=2,FontSize=18'\" -c:v libx264 -preset ultrafast -c:a aac {out_path}", shell=True, check=True)
95
 
96
  yield "✅ Terminé !", out_path
97
  except Exception as e:
 
100
  if os.path.exists(tmp_dir): shutil.rmtree(tmp_dir)
101
 
102
  # INTERFACE
103
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
104
+ gr.HTML("<h1 style='text-align:center;'>🤖 RobotsMali Transcription</h1>")
105
+
106
  with gr.Row():
107
  with gr.Column():
108
+ v_in = gr.Video(label="Vidéo Source")
109
+ m_sel = gr.Dropdown(choices=list(MODELS.keys()), value="Soloba V3 (CTC)", label="Modèle")
110
+ btn = gr.Button("GÉNÉRER", variant="primary")
111
+
112
+ # REVOICI L'EXEMPLE ICI
113
+ if EXAMPLE_VIDEO:
114
+ gr.Examples(
115
+ examples=[[EXAMPLE_VIDEO, "Soloba V3 (CTC)"]],
116
+ inputs=[v_in, m_sel],
117
+ label="Exemple disponible"
118
+ )
119
+
120
  with gr.Column():
121
  status = gr.Markdown("Prêt.")
122
+ v_out = gr.Video(label="Résultat")
123
+
124
  btn.click(pipeline, [v_in, m_sel], [status, v_out])
125
 
126
+ demo.launch()