| |
| import os, shlex, subprocess, tempfile, traceback, time, glob, gc |
| import torch |
| from huggingface_hub import snapshot_download |
| from nemo.collections import asr as nemo_asr |
| import gradio as gr |
|
|
| |
# Inference device: the GPU when torch reports CUDA support, otherwise the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


# UI label -> (Hugging Face repo id, loader mode).
# The loader mode ("ctc" or "rnnt") selects which NeMo model class
# load_model() uses to restore the downloaded checkpoint.
MODELS = {
    "Soloba V3 (CTC)": ("RobotsMali/soloba-ctc-0.6b-v3", "ctc"),
    "Soloba V1.5 (TDT)": ("RobotsMali/soloba-tdt-0.6b-v1.5", "rnnt"),
    "Soloni V3 (TDT-CTC)": ("RobotsMali/soloni-114m-tdt-ctc-v3", "rnnt"),
    "Soloni V2 (TDT-CTC)": ("RobotsMali/soloni-114m-tdt-ctc-v2", "rnnt"),
    "Soloni MSE (Experimental)": ("RobotsMali/lau-soloni-114m-mse-k1", "ctc"),
    "Soloba V0.5 (TDT)": ("RobotsMali/soloba-tdt-0.6b-v0.5", "rnnt"),
}


# Loaded models keyed by UI label (the keys of MODELS).
_cache = {}
|
|
def clear_memory():
    """Drop every cached model and release the memory it was holding.

    Empties the module-level ``_cache``, runs a garbage-collection pass and,
    when CUDA is available, asks torch to release its cached GPU memory.
    """
    # Drop the references first so the collector can actually reclaim them.
    _cache.clear()
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
|
|
def load_model(name):
    """Load the ASR model registered under *name*, reporting progress.

    This is a generator.  It yields human-readable status strings while it
    works and then yields the model object as its last item.  The model is
    also the generator's return value, so it reaches callers that iterate
    with a ``for`` loop as well as callers that use ``yield from`` or read
    ``StopIteration.value``.

    Args:
        name: A key of ``MODELS``.

    Yields:
        ``str`` progress messages, followed by the model (a non-``str``
        object).

    Returns:
        The loaded model.

    Raises:
        KeyError: If *name* is not a key of ``MODELS``.
        FileNotFoundError: If the downloaded snapshot has no ``.nemo`` file.
    """
    if name in _cache:
        # A bare ``return value`` inside a generator is invisible to a
        # ``for`` loop, so the cached model is yielded as well.
        yield _cache[name]
        return _cache[name]

    yield f"⏳ Chargement du modèle {name}..."
    # Evict whatever is currently cached before loading the new model.
    clear_memory()

    repo, mode = MODELS[name]
    folder = snapshot_download(repo, local_dir_use_symlinks=False)
    nemo_file = next(
        (os.path.join(folder, f) for f in os.listdir(folder) if f.endswith(".nemo")),
        None,
    )
    if nemo_file is None:
        # Fail with a clear message instead of passing None to restore_from().
        raise FileNotFoundError(
            f"Aucun fichier .nemo trouvé dans {folder} (dépôt {repo})"
        )

    if mode == "rnnt":
        model = nemo_asr.models.EncDecHybridRNNTCTCBPEModel.restore_from(nemo_file)
    else:
        model = nemo_asr.models.EncDecCTCModelBPE.restore_from(nemo_file)

    model.to(DEVICE).eval()
    if DEVICE == "cuda":
        # Half precision on the GPU only.
        model.half()

    _cache[name] = model
    yield model
    return model
|
|
def format_srt_time(sec):
    """Format a time offset in seconds as an SRT timestamp ``HH:MM:SS,mmm``.

    The value is rounded to the nearest millisecond and then split with
    integer arithmetic, so float representation error cannot drop a
    millisecond (``1.001`` gives ``,001``) and offsets of 24 hours or more
    keep counting hours instead of wrapping around.

    Args:
        sec: Offset in seconds (int or float).  Negative values are clamped
            to zero.

    Returns:
        The timestamp string, e.g. ``"01:01:01,500"``.
    """
    total_ms = max(0, int(round(sec * 1000)))
    hours, remainder = divmod(total_ms, 3_600_000)
    minutes, remainder = divmod(remainder, 60_000)
    seconds, millis = divmod(remainder, 1000)
    return f"{hours:02d}:{minutes:02d}:{seconds:02d},{millis:03d}"
|
|
# Length in seconds of each audio chunk handed to the model.
_SEGMENT_SECONDS = 10.0
# Number of words grouped into one subtitle cue.
_WORDS_PER_SUBTITLE = 6
# Display duration in seconds given to a word placed from a model offset.
_WORD_DURATION = 0.45


def _extract_audio_segments(video_in, tmp_dir):
    """Extract 16 kHz mono audio from *video_in* and split it into WAV chunks.

    Returns the chunk paths in playback order.
    """
    full_wav = os.path.join(tmp_dir, "full.wav")
    # Argument lists (no shell) keep paths with spaces or quotes intact.
    subprocess.run(
        ["ffmpeg", "-y", "-i", video_in, "-vn", "-ac", "1", "-ar", "16000", full_wav],
        check=True,
    )
    segment_pattern = os.path.join(tmp_dir, "seg_%03d.wav")
    subprocess.run(
        [
            "ffmpeg", "-i", full_wav,
            "-f", "segment", "-segment_time", f"{_SEGMENT_SECONDS:g}",
            "-c", "copy", segment_pattern,
        ],
        check=True,
    )
    # Sort on the numeric index so a four-digit index still orders correctly.
    return sorted(
        glob.glob(os.path.join(tmp_dir, "seg_*.wav")),
        key=lambda p: int(os.path.basename(p)[len("seg_"):-len(".wav")]),
    )


def _words_with_times(hyp, base_time, stride):
    """Return ``{"word", "start", "end"}`` dicts for one transcribed chunk.

    *base_time* is the chunk's start within the full audio, in seconds.
    """
    # NOTE(review): `word_offsets` is used as a sequence of frame indices that
    # are multiplied by `stride`; whether the hypothesis object provides it in
    # that form cannot be seen from this file -- confirm.
    offsets = getattr(hyp, "word_offsets", None)
    text = hyp.text if hasattr(hyp, "text") else str(hyp)
    words = (text or "").split()

    if offsets and len(offsets) == len(words):
        timed = []
        for word, offset in zip(words, offsets):
            start_t = base_time + (offset * stride)
            timed.append({"word": word, "start": start_t, "end": start_t + _WORD_DURATION})
        return timed

    # No usable offsets: spread the words evenly over the chunk.
    gap = _SEGMENT_SECONDS / max(len(words), 1)
    return [
        {"word": w, "start": base_time + (i * gap), "end": base_time + ((i + 1) * gap)}
        for i, w in enumerate(words)
    ]


def _write_srt(words, srt_path, words_per_line=_WORDS_PER_SUBTITLE):
    """Write *words* to *srt_path* as SRT cues of *words_per_line* words each."""
    with open(srt_path, "w", encoding="utf-8") as f:
        starts = range(0, len(words), words_per_line)
        for cue_no, i in enumerate(starts, start=1):
            chunk = words[i:i + words_per_line]
            f.write(f"{cue_no}\n")
            f.write(f"{format_srt_time(chunk[0]['start'])} --> {format_srt_time(chunk[-1]['end'])}\n")
            f.write(" ".join(c["word"] for c in chunk) + "\n\n")


def pipeline(video_in, model_name):
    """Transcribe a video and burn the resulting subtitles into a copy of it.

    This is a generator meant to drive the UI: every item is a
    ``(status, video_path, srt_path)`` tuple, where the paths are ``None``
    until the corresponding file exists.

    Args:
        video_in: Path of the source video; a falsy value reports an error.
        model_name: A key of ``MODELS`` naming the model to use.

    Yields:
        ``(status, video_path, srt_path)`` tuples.  Failures are reported as
        a final tuple whose status starts with ``❌`` rather than raised.
    """
    if not video_in:
        # `return value` in a generator is never shown; yield the message.
        yield "❌ Source vide", None, None
        return

    tmp_dir = tempfile.mkdtemp()
    try:
        yield "⏳ Extraction de l'audio...", None, None
        audio_segments = _extract_audio_segments(video_in, tmp_dir)

        # Drive the loader by hand so the model is picked up whether it is
        # yielded as the last item or handed back as the return value.
        model = None
        loader = load_model(model_name)
        while True:
            try:
                update = next(loader)
            except StopIteration as finished:
                if finished.value is not None:
                    model = finished.value
                break
            if isinstance(update, str):
                yield update, None, None
            else:
                model = update
        if model is None:
            raise RuntimeError(f"Le modèle {model_name} n'a pas pu être chargé")

        # Seconds per offset unit; 0.02 unless the model exposes a featurizer.
        # NOTE(review): hop_length / sample_rate is the feature-frame step; if
        # the offsets are counted in another unit this needs scaling -- confirm.
        stride = 0.02
        if hasattr(model, "preprocessor") and hasattr(model.preprocessor, "featurizer"):
            featurizer = model.preprocessor.featurizer
            stride = featurizer.hop_length / featurizer.sample_rate

        all_words_ts = []
        for idx, seg_path in enumerate(audio_segments):
            yield f"⏳ IA : Transcription segment {idx+1}/{len(audio_segments)}...", None, None

            hyp = model.transcribe([seg_path], return_hypotheses=True)[0]
            # NOTE(review): unwrap one more level if the first element is
            # itself a list/tuple of hypotheses -- confirm against the
            # installed NeMo version.
            if isinstance(hyp, (list, tuple)) and hyp:
                hyp = hyp[0]
            all_words_ts.extend(_words_with_times(hyp, idx * _SEGMENT_SECONDS, stride))

        srt_path = os.path.join(tmp_dir, "final.srt")
        _write_srt(all_words_ts, srt_path)

        yield "⏳ Rendu vidéo final ...", None, srt_path
        out_path = os.path.abspath(f"robotsmali_final_{int(time.time())}.mp4")

        # The subtitles filter needs forward slashes and escaped colons.
        safe_srt = srt_path.replace("\\", "/").replace(":", "\\:")
        subprocess.run(
            [
                "ffmpeg", "-y", "-i", video_in,
                "-vf",
                f"subtitles='{safe_srt}':force_style='Alignment=2,FontSize=18,OutlineColour=&H80000000,BorderStyle=4'",
                "-c:v", "libx264", "-preset", "fast", "-pix_fmt", "yuv420p",
                "-movflags", "+faststart", "-c:a", "aac", out_path,
            ],
            check=True,
        )

        yield "✅ Transcription et Incrustation Terminées !", out_path, srt_path

    except Exception as e:
        # Top-level boundary for the UI: log the traceback, show the message.
        traceback.print_exc()
        yield f"❌ Erreur : {str(e)}", None, None
    finally:
        # The WAVs are only needed for transcription.  The SRT stays in
        # tmp_dir because the UI still has to serve it.
        for wav in glob.glob(os.path.join(tmp_dir, "*.wav")):
            try:
                os.remove(wav)
            except OSError:
                pass  # best-effort cleanup; a leftover file is harmless
|
|
| |
# Gradio UI: inputs in the left column, status and results in the right one.
with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}") as demo:
    gr.HTML("<h1 style='text-align:center; color:#EAB308;'>🤖 ROBOTSMALI TRANSCRIPTION PRO</h1>")

    with gr.Row():
        with gr.Column(scale=1):
            v_in = gr.Video(label="Vidéo Source (Upload ou Webcam)", sources=["upload", "webcam"])
            m_sel = gr.Dropdown(choices=list(MODELS), value="Soloba V3 (CTC)", label="Choisir le Modèle IA")
            btn_run = gr.Button("🚀 GÉNÉRER LES SOUS-TITRES", variant="primary")

        with gr.Column(scale=1):
            status = gr.Markdown("### État\nPrêt à l'emploi.")
            v_out = gr.Video(label="Vidéo Finale Incrustée")
            f_srt = gr.File(label="Fichier Sous-titres (.SRT)")

    # Registered inside the Blocks context; each tuple yielded by pipeline()
    # updates (status, v_out, f_srt) in that order.
    btn_run.click(pipeline, [v_in, m_sel], [status, v_out, f_srt])


if __name__ == "__main__":
    demo.launch(debug=True, show_error=True)