|
|
|
|
|
import os, shlex, subprocess, tempfile, traceback, time, glob, gc, shutil |
|
|
import torch |
|
|
from huggingface_hub import snapshot_download |
|
|
from nemo.collections import asr as nemo_asr |
|
|
import gradio as gr |
|
|
|
|
|
|
|
|
# Inference device: the CUDA GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# UI label -> (Hugging Face repo id, model type tag "ctc" / "rnnt").
# Row order here is the order of the model dropdown in the UI.
_MODEL_TABLE = (
    ("Soloba V3 (CTC)", "RobotsMali/soloba-ctc-0.6b-v3", "ctc"),
    ("Soloba V2 (CTC)", "RobotsMali/soloba-ctc-0.6b-v2", "ctc"),
    ("Soloba V1 (CTC)", "RobotsMali/soloba-ctc-0.6b-v1", "ctc"),
    ("Soloba V1.5 (TDT)", "RobotsMali/soloba-tdt-0.6b-v1.5", "rnnt"),
    ("Soloba V0.5 (TDT)", "RobotsMali/soloba-tdt-0.6b-v0.5", "rnnt"),
    ("Soloni V3 (TDT-CTC)", "RobotsMali/soloni-114m-tdt-ctc-v3", "rnnt"),
    ("Soloni V2 (TDT-CTC)", "RobotsMali/soloni-114m-tdt-ctc-v2", "rnnt"),
    ("Soloni V1 (TDT-CTC)", "RobotsMali/soloni-114m-tdt-ctc-v1", "rnnt"),
    ("Traduction Soloni (ST)", "RobotsMali/st-soloni-114m-tdt-ctc", "rnnt"),
)
MODELS = {label: (repo_id, kind) for label, repo_id, kind in _MODEL_TABLE}
|
|
|
|
|
def find_example_video():
    """Return the path of the first demo video present on disk, or None.

    Candidates are relative to the current working directory and are checked
    in priority order: the fixed MP4 under ``examples/``, the original MP4
    under ``examples/``, then ``MARALINKE.mp4`` at the top level.
    """
    candidates = (
        "examples/MARALINKE_FIXED.mp4",
        "examples/MARALINKE.mp4",
        "MARALINKE.mp4",
    )
    return next(filter(os.path.exists, candidates), None)
|
|
|
|
|
# Models restored by get_model(), keyed by their MODELS label. get_model() empties
# it (through clear_memory()) before each new load, so it holds at most one model.
_cache = {}

# Demo clip offered under the UI examples; None when no candidate file exists.
EXAMPLE_PATH = find_example_video()
|
|
|
|
|
|
|
|
def clear_memory():
    """Forget every cached model and give freed memory back.

    Empties the module-level model cache, runs a garbage-collection pass so the
    dropped model objects can be reclaimed, then, only when CUDA is available,
    asks PyTorch's caching allocator to release its unoccupied GPU memory.
    """
    _cache.clear()
    gc.collect()
    cuda_ok = torch.cuda.is_available()
    if cuda_ok:
        torch.cuda.empty_cache()
|
|
|
|
|
def get_model(name):
    """Return the NeMo ASR model registered under ``name``, loading it on first use.

    Only one model is kept in memory: when ``name`` is not already cached, the
    cache is emptied via ``clear_memory()`` before the new model is loaded.

    Args:
        name: A key of ``MODELS`` (the label shown in the UI dropdown).

    Returns:
        The restored model, moved to ``DEVICE``, switched to eval mode and,
        on CUDA, converted to half precision.

    Raises:
        KeyError: if ``name`` is not a key of ``MODELS``.
        FileNotFoundError: if the downloaded repository contains no ``.nemo`` file.
    """
    if name in _cache:
        return _cache[name]

    # Drop the previously loaded model before bringing a new one onto the device.
    clear_memory()

    repo, _ = MODELS[name]

    # Only the .nemo archive is used below, so skip every other file in the repo.
    # (The former `local_dir_use_symlinks` argument had no effect without
    # `local_dir` and is deprecated in huggingface_hub.)
    folder = snapshot_download(repo, allow_patterns=["*.nemo"])

    # Sorted so the checkpoint choice is deterministic if a repo ships several .nemo files.
    nemo_names = sorted(f for f in os.listdir(folder) if f.endswith(".nemo"))
    if not nemo_names:
        raise FileNotFoundError("Fichier .nemo introuvable.")
    nemo_file = os.path.join(folder, nemo_names[0])

    from nemo.core.connectors.save_restore_connector import SaveRestoreConnector

    model = nemo_asr.models.ASRModel.restore_from(
        nemo_file,
        map_location=torch.device(DEVICE),
        save_restore_connector=SaveRestoreConnector(),
        strict=False,  # tolerate missing/unexpected keys in the checkpoint state dict
    )

    model.to(DEVICE).eval()
    if DEVICE == "cuda":
        # NOTE(review): whole-model fp16 halves GPU memory; confirm the audio
        # preprocessor stays numerically stable in half precision for these models.
        model.half()

    _cache[name] = model
    return model
|
|
|
|
|
|
|
|
def format_srt_time(sec):
    """Format an offset in seconds as an SRT timestamp ``HH:MM:SS,mmm``.

    The value is rounded to the nearest millisecond and split with integer
    arithmetic. Float noise (e.g. ``1.001``) therefore no longer truncates to
    the previous millisecond, and hours keep counting past 24 instead of
    wrapping to 00. Negative offsets are clamped to zero.

    Args:
        sec: Offset in seconds (int or float).

    Returns:
        The timestamp string, e.g. ``"01:01:01,500"`` for ``3661.5``.
    """
    total_ms = max(int(round(sec * 1000)), 0)
    hours, rest = divmod(total_ms, 3_600_000)
    minutes, rest = divmod(rest, 60_000)
    seconds, millis = divmod(rest, 1000)
    return f"{hours:02}:{minutes:02}:{seconds:02},{millis:03}"
|
|
|
|
|
|
|
|
# Length, in seconds, of the audio chunks fed to the ASR model.
_SEGMENT_SECONDS = 20

# Number of consecutive words grouped into one subtitle cue.
_WORDS_PER_CUE = 6


def _run_ffmpeg(args, cwd=None):
    """Run ffmpeg with an argument list (no shell) and raise RuntimeError with its error output on failure."""
    proc = subprocess.run(
        ["ffmpeg", "-hide_banner", "-loglevel", "error", "-y", *args],
        cwd=cwd,
        capture_output=True,
        encoding="utf-8",
        errors="replace",
    )
    if proc.returncode != 0:
        detail = "\n".join(proc.stderr.strip().splitlines()[-5:])
        raise RuntimeError(f"ffmpeg a échoué (code {proc.returncode}) : {detail}")


def _wav_duration(path, default=_SEGMENT_SECONDS):
    """Return the duration in seconds of a WAV file, or ``default`` if its header cannot be read."""
    import wave

    try:
        with wave.open(path, "rb") as wav:
            rate = wav.getframerate()
            return wav.getnframes() / rate if rate else float(default)
    except (wave.Error, EOFError, OSError):
        return float(default)


def _write_srt(words, path):
    """Write timed words (dicts with ``word``/``start``/``end``) to ``path`` as numbered SRT cues."""
    with open(path, "w", encoding="utf-8") as f:
        for cue_no, first in enumerate(range(0, len(words), _WORDS_PER_CUE), start=1):
            chunk = words[first:first + _WORDS_PER_CUE]
            f.write(f"{cue_no}\n{format_srt_time(chunk[0]['start'])} --> {format_srt_time(chunk[-1]['end'])}\n")
            f.write(" ".join(c["word"] for c in chunk) + "\n\n")


def pipeline(video_in, model_name):
    """Transcribe a video with the chosen model and burn the text in as subtitles.

    This is a generator used as a Gradio event handler. Every ``yield`` is a
    ``(status_message, output_video_path_or_None)`` pair. The final pair carries
    the path of the subtitled MP4 written to the current working directory.
    Errors are reported as a status message (and logged with a traceback)
    instead of being raised. The per-request temporary directory is always
    removed.

    Args:
        video_in: Path of the uploaded video, or a falsy value when none was given.
        model_name: Key of ``MODELS`` selecting the ASR model.
    """
    tmp_dir = tempfile.mkdtemp()
    out_path = None
    try:
        if not video_in:
            yield "❌ Aucune vidéo sélectionnée.", None
            return
        # Absolute path, because the final ffmpeg call runs with cwd=tmp_dir.
        video_path = os.path.abspath(video_in)

        yield "⏳ Phase 1/4 : Extraction audio...", None
        full_wav = os.path.join(tmp_dir, "full.wav")
        _run_ffmpeg(["-threads", "0", "-i", video_path, "-vn", "-ac", "1", "-ar", "16000", full_wav])

        yield "⏳ Phase 2/4 : Segmentation...", None
        # Five digits keep the lexical sort below in chronological order for long videos.
        _run_ffmpeg([
            "-i", full_wav, "-f", "segment", "-segment_time", str(_SEGMENT_SECONDS),
            "-c", "copy", os.path.join(tmp_dir, "seg_%05d.wav"),
        ])
        audio_segments = sorted(glob.glob(os.path.join(tmp_dir, "seg_*.wav")))
        if not audio_segments:
            yield "❌ Erreur : aucune piste audio exploitable.", None
            return

        yield f"⏳ Phase 3/4 : Chargement de {model_name}...", None
        model = get_model(model_name)

        yield f"🎙️ Transcription de {len(audio_segments)} segments...", None
        b_size = 2 if DEVICE == "cpu" else 4
        hypotheses = model.transcribe(audio_segments, batch_size=b_size, return_hypotheses=True)
        # NOTE(review): some NeMo releases return (best_hypotheses, all_hypotheses) for
        # RNNT/TDT models; keep only the best list so each entry maps to one segment.
        if isinstance(hypotheses, tuple):
            hypotheses = hypotheses[0]

        all_words_ts = []
        offset = 0.0
        for idx, (seg_path, hyp) in enumerate(zip(audio_segments, hypotheses)):
            yield f"📝 Traitement : {idx+1}/{len(audio_segments)}...", None
            if isinstance(hyp, list):
                hyp = hyp[0]
            text = hyp.text if hasattr(hyp, "text") else str(hyp)
            words = (text or "").split()
            # Use the real chunk length; the last chunk is usually shorter than _SEGMENT_SECONDS.
            duration = _wav_duration(seg_path)
            # No per-word timestamps are available here, so spread the words evenly over the chunk.
            gap = duration / max(len(words), 1)
            all_words_ts.extend(
                {"word": w, "start": offset + i * gap, "end": offset + (i + 1) * gap}
                for i, w in enumerate(words)
            )
            offset += duration

        if not all_words_ts:
            # An empty .srt makes ffmpeg's subtitles filter fail, so stop here.
            yield "⚠️ Aucune parole détectée : aucun sous-titre à incruster.", None
            return

        yield "⏳ Phase 4/4 : Encodage vidéo...", None
        srt_path = os.path.join(tmp_dir, "final.srt")
        _write_srt(all_words_ts, srt_path)

        # Reserve a unique output name so concurrent runs in the same second don't collide.
        with tempfile.NamedTemporaryFile(
            prefix=f"resultat_{int(time.time())}_", suffix=".mp4", dir=os.getcwd(), delete=False
        ) as placeholder:
            out_path = placeholder.name

        # Running from tmp_dir lets the subtitles filter take a bare file name, which
        # avoids filtergraph path escaping. AAC is re-encoded because the source audio
        # codec (e.g. Vorbis from WebM) may not be storable in MP4.
        _run_ffmpeg([
            "-threads", "0", "-i", video_path,
            "-vf", "subtitles=final.srt:force_style='Alignment=2,FontSize=18,PrimaryColour=&H00FFFF'",
            "-c:v", "libx264", "-preset", "ultrafast", "-c:a", "aac", out_path,
        ], cwd=tmp_dir)

        yield "✅ Terminé !", out_path

    except Exception as e:
        traceback.print_exc()
        if out_path and os.path.exists(out_path):
            try:
                os.remove(out_path)
            except OSError:
                pass  # best-effort cleanup; the error reported below is what matters
        yield f"❌ Erreur : {str(e)}", None

    finally:
        shutil.rmtree(tmp_dir, ignore_errors=True)
|
|
|
|
|
|
|
|
# Gradio UI: inputs (video + model picker) on the left, status and result on the right.
# Components render in the order they are created inside the `with` contexts.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.HTML("<div style='text-align:center;'><h1>🤖 RobotsMali Speech Lab</h1></div>")

    with gr.Row():
        # Left column: source video, model choice (labels of MODELS) and the run button.
        with gr.Column():
            v_input = gr.Video(label="Vidéo Source")
            m_input = gr.Dropdown(choices=list(MODELS.keys()), value="Soloni V3 (TDT-CTC)", label="Modèle")
            run_btn = gr.Button("🚀 GÉNÉRER", variant="primary")

            # Offer the bundled demo clip only when find_example_video() located one.
            if EXAMPLE_PATH:
                gr.Examples(examples=[[EXAMPLE_PATH, "Soloni V3 (TDT-CTC)"]], inputs=[v_input, m_input])

        # Right column: status messages and the subtitled result video.
        with gr.Column():
            status = gr.Markdown("### État\nEn attente...")
            v_output = gr.Video(label="Vidéo finale")

    # pipeline is a generator; each (status, video) pair it yields is routed to these two outputs.
    run_btn.click(pipeline, [v_input, m_input], [status, v_output])

# Enable the request queue, then start the web server.
demo.queue().launch()