File size: 5,999 Bytes
5738fbf
326676b
60f6dea
f6e735c
6ec5f30
64e18a4
5738fbf
24a54d7
6ec5f30
24a54d7
60f6dea
3b5085e
cc8e3ad
 
3b5085e
cc8e3ad
3b5085e
 
cc8e3ad
 
60f6dea
6ec5f30
d436569
173bdc2
d436569
173bdc2
d436569
 
 
78a6ca1
3b5085e
24a54d7
 
74445b0
 
24a54d7
 
 
 
 
 
74445b0
24a54d7
224a0d9
900f511
95a2204
24a54d7
 
cc8e3ad
24a54d7
 
cc8e3ad
24a54d7
 
 
 
cc8e3ad
24a54d7
6ec5f30
24a54d7
 
 
900f511
6ec5f30
e685733
24a54d7
93438d8
 
 
bde1ae6
224a0d9
24a54d7
bde1ae6
93438d8
e685733
74445b0
24a54d7
74445b0
24a54d7
 
93438d8
173bdc2
24a54d7
 
78a6ca1
93438d8
88d36f5
24a54d7
 
74445b0
24a54d7
74445b0
 
173bdc2
93438d8
173bdc2
24a54d7
78a6ca1
cc8e3ad
326676b
 
78a6ca1
cc8e3ad
 
bde1ae6
24a54d7
93438d8
bde1ae6
88d36f5
 
 
93438d8
bde1ae6
74445b0
3b5085e
d436569
24a54d7
d436569
88d36f5
24a54d7
d436569
77e790d
24a54d7
326676b
cc8e3ad
e685733
24a54d7
78a6ca1
eb25311
78a6ca1
e685733
b01f955
24a54d7
6d5ada0
24a54d7
78a6ca1
d436569
eb25311
78a6ca1
b01f955
24a54d7
74445b0
78a6ca1
d436569
e685733
24a54d7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
# -*- coding: utf-8 -*-
import os, shlex, subprocess, tempfile, traceback, time, glob, gc, shutil
import torch
from huggingface_hub import snapshot_download
from nemo.collections import asr as nemo_asr
import gradio as gr

# 1. CONFIGURATION AND MODELS
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# One row per selectable model: (UI label, Hugging Face repo id, type tag).
# The type tag is "ctc" or "rnnt".
# NOTE(review): the type tag is not read anywhere in the visible code.
_MODEL_TABLE = (
    ("Soloba V3 (CTC)", "RobotsMali/soloba-ctc-0.6b-v3", "ctc"),
    ("Soloba V2 (CTC)", "RobotsMali/soloba-ctc-0.6b-v2", "ctc"),
    ("Soloba V1 (CTC)", "RobotsMali/soloba-ctc-0.6b-v1", "ctc"),
    ("Soloba V1.5 (TDT)", "RobotsMali/soloba-tdt-0.6b-v1.5", "rnnt"),
    ("Soloba V0.5 (TDT)", "RobotsMali/soloba-tdt-0.6b-v0.5", "rnnt"),
    ("Soloni V3 (TDT-CTC)", "RobotsMali/soloni-114m-tdt-ctc-v3", "rnnt"),
    ("Soloni V2 (TDT-CTC)", "RobotsMali/soloni-114m-tdt-ctc-v2", "rnnt"),
    ("Soloni V1 (TDT-CTC)", "RobotsMali/soloni-114m-tdt-ctc-v1", "rnnt"),
    ("Traduction Soloni (ST)", "RobotsMali/st-soloni-114m-tdt-ctc", "rnnt"),
)

# UI label -> (repo id, type tag); insertion order is the dropdown order.
MODELS = {label: (repo, kind) for label, repo, kind in _MODEL_TABLE}

def find_example_video():
    """Return the first existing demo-video path among the known candidates.

    Candidates are checked in priority order, relative to the current working
    directory. Returns ``None`` when none of them exists.
    """
    candidates = (
        "examples/MARALINKE_FIXED.mp4",
        "examples/MARALINKE.mp4",
        "MARALINKE.mp4",
    )
    return next((path for path in candidates if os.path.exists(path)), None)

# Loaded-model cache keyed by display name (populated by get_model).
_cache = dict()
# Demo video offered in the UI, or None when no candidate file exists.
EXAMPLE_PATH = find_example_video()

# 2. GESTION MÉMOIRE ET CHARGEMENT (AVEC CORRECTIF STATE_DICT)
def clear_memory():
    """Drop every cached model and release the memory it held.

    Empties ``_cache``, runs the garbage collector, and, when CUDA is
    available, asks PyTorch to release its cached GPU memory.
    """
    _cache.clear()
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()

def get_model(name):
    """Return the NeMo ASR model registered under *name*, loading it if needed.

    At most one model is kept in ``_cache``. Loading a model that is not
    cached first empties the cache through ``clear_memory()``.

    Args:
        name: A display name, i.e. a key of ``MODELS``.

    Returns:
        The restored model, moved to ``DEVICE`` and put in eval mode
        (converted to half precision when ``DEVICE`` is ``"cuda"``).

    Raises:
        KeyError: *name* is not a key of ``MODELS``.
        FileNotFoundError: the downloaded repository contains no ``.nemo`` file.
    """
    if name in _cache:
        return _cache[name]
    # Resolve the repo before evicting anything, so an unknown name does not
    # throw away the model that is currently loaded.
    repo, _kind = MODELS[name]
    clear_memory()

    # local_dir_use_symlinks only mattered together with local_dir (and is
    # deprecated), so it is not passed; the snapshot stays in the HF cache.
    folder = snapshot_download(repo)
    # os.listdir order is arbitrary: sort so the choice is deterministic when a
    # repository ships more than one .nemo archive.
    nemo_names = sorted(f for f in os.listdir(folder) if f.endswith(".nemo"))
    if not nemo_names:
        raise FileNotFoundError("Fichier .nemo introuvable.")
    nemo_file = os.path.join(folder, nemo_names[0])

    from nemo.core.connectors.save_restore_connector import SaveRestoreConnector

    # strict=False works around unexpected "embedding_model" keys in the
    # checkpoint that a strict state-dict load would reject.
    model = nemo_asr.models.ASRModel.restore_from(
        nemo_file,
        map_location=torch.device(DEVICE),
        save_restore_connector=SaveRestoreConnector(),
        strict=False,
    )

    model.to(DEVICE).eval()
    if DEVICE == "cuda":
        # NOTE(review): fp16 weights; confirm transcribe() handles the dtype
        # with the NeMo version in use.
        model.half()

    _cache[name] = model
    return model

# 3. UTILITAIRES
def format_srt_time(sec):
    """Format a time offset in seconds as an SRT timestamp ``HH:MM:SS,mmm``.

    The offset is rounded to the nearest millisecond. Truncating the float
    fraction would turn e.g. 1.001 s into ``,000``. Hours keep counting past
    24 instead of wrapping, and negative offsets are clamped to zero.

    Args:
        sec: Offset in seconds (int or float).

    Returns:
        The timestamp string, e.g. ``format_srt_time(3661.25) == "01:01:01,250"``.
    """
    total_ms = max(0, round(sec * 1000))
    hours, rem = divmod(total_ms, 3_600_000)
    minutes, rem = divmod(rem, 60_000)
    seconds, millis = divmod(rem, 1000)
    return f"{hours:02}:{minutes:02}:{seconds:02},{millis:03}"

# 4. PIPELINE DE TRANSCRIPTION
# Length, in seconds, of the audio chunks handed to the ASR model.
_SEGMENT_SECONDS = 20
# Number of consecutive words grouped into one subtitle cue.
_WORDS_PER_CUE = 6


def _run_ffmpeg(args):
    """Run ffmpeg with an argument list (no shell); raise CalledProcessError on failure."""
    subprocess.run(["ffmpeg", *args], check=True)


def _wav_duration(path, default):
    """Return the duration of the WAV file *path* in seconds, or *default*.

    Falls back to *default* when the header cannot be read or reports a
    length outside (0, 2 * default], which a segment should never have.
    """
    import wave

    try:
        with wave.open(path, "rb") as wav:
            rate = wav.getframerate()
            frames = wav.getnframes()
    except (wave.Error, EOFError, OSError):
        return default
    duration = frames / rate if rate else 0.0
    return duration if 0 < duration <= 2 * default else default


def _hypothesis_text(hyp):
    """Extract the transcript string from one entry returned by ``transcribe``.

    Accepts an object with a ``.text`` attribute, a plain string, or a
    list/tuple wrapping either (its first element is used). Returns "" for an
    empty wrapper or a ``None`` text.
    """
    if isinstance(hyp, (list, tuple)):
        hyp = hyp[0] if hyp else ""
    text = hyp.text if hasattr(hyp, "text") else hyp
    return "" if text is None else str(text)


def _spread_words(text, start, duration):
    """Split *text* into words and spread them evenly over [start, start + duration]."""
    words = text.split()
    gap = duration / max(len(words), 1)
    return [
        {"word": w, "start": start + i * gap, "end": start + (i + 1) * gap}
        for i, w in enumerate(words)
    ]


def _write_srt(words_ts, srt_path, words_per_cue=_WORDS_PER_CUE):
    """Write timed words to *srt_path* as numbered SRT cues of *words_per_cue* words."""
    with open(srt_path, "w", encoding="utf-8") as srt:
        for cue_no, first in enumerate(range(0, len(words_ts), words_per_cue), start=1):
            chunk = words_ts[first:first + words_per_cue]
            start = format_srt_time(chunk[0]["start"])
            end = format_srt_time(chunk[-1]["end"])
            srt.write(f"{cue_no}\n{start} --> {end}\n")
            srt.write(" ".join(c["word"] for c in chunk) + "\n\n")


def pipeline(video_in, model_name):
    """Transcribe a video and burn the transcript into it as subtitles.

    This is a generator that yields ``(status_markdown, video_path_or_None)``
    pairs so the UI can show progress. The last pair carries the path of the
    rendered MP4 on success, or an error message and ``None`` on failure.

    Steps: extract mono 16 kHz audio, cut it into ``_SEGMENT_SECONDS``-second
    WAV chunks, transcribe them with the selected model, spread each chunk's
    words evenly over that chunk's real duration, write an SRT file, and
    re-encode the video with the subtitles drawn in.

    Args:
        video_in: Path of the source video (falsy when nothing was selected).
        model_name: A key of ``MODELS``.
    """
    tmp_dir = tempfile.mkdtemp()
    try:
        if not video_in:
            yield "❌ Aucune vidéo sélectionnée.", None
            return

        yield "⏳ Phase 1/4 : Extraction audio...", None
        full_wav = os.path.join(tmp_dir, "full.wav")
        # Argument lists (no shell): paths with spaces or shell metacharacters
        # are passed through verbatim.
        _run_ffmpeg([
            "-y", "-threads", "0", "-i", video_in,
            "-vn", "-ac", "1", "-ar", "16000", full_wav,
        ])

        yield "⏳ Phase 2/4 : Segmentation...", None
        _run_ffmpeg([
            "-i", full_wav, "-f", "segment",
            "-segment_time", str(_SEGMENT_SECONDS), "-c", "copy",
            os.path.join(tmp_dir, "seg_%03d.wav"),
        ])
        audio_segments = sorted(glob.glob(os.path.join(tmp_dir, "seg_*.wav")))
        if not audio_segments:
            yield "❌ Aucun segment audio produit.", None
            return

        yield f"⏳ Phase 3/4 : Chargement de {model_name}...", None
        model = get_model(model_name)

        yield f"🎙️ Transcription de {len(audio_segments)} segments...", None
        b_size = 2 if DEVICE == "cpu" else 4
        hypotheses = model.transcribe(audio_segments, batch_size=b_size, return_hypotheses=True)
        # Some NeMo versions return (best_hypotheses, all_hypotheses) for
        # transducer models; keep only the best ones.
        if isinstance(hypotheses, tuple):
            hypotheses = hypotheses[0]

        all_words_ts = []
        offset = 0.0
        for idx, (seg_path, hyp) in enumerate(zip(audio_segments, hypotheses)):
            yield f"📝 Traitement : {idx+1}/{len(audio_segments)}...", None
            # Use the real segment length: the last chunk is usually shorter
            # than _SEGMENT_SECONDS, and assuming a full chunk pushed its
            # subtitles past the end of the audio.
            duration = _wav_duration(seg_path, float(_SEGMENT_SECONDS))
            all_words_ts.extend(_spread_words(_hypothesis_text(hyp), offset, duration))
            offset += duration

        if not all_words_ts:
            # An empty SRT would make the subtitles filter fail downstream.
            yield "❌ Aucun texte reconnu.", None
            return

        yield "⏳ Phase 4/4 : Encodage vidéo...", None
        srt_path = os.path.join(tmp_dir, "final.srt")
        _write_srt(all_words_ts, srt_path)

        out_path = os.path.abspath(f"resultat_{int(time.time())}.mp4")
        # Filtergraph escaping: forward slashes, and ':' escaped so it is not
        # read as an option separator.
        safe_srt = srt_path.replace("\\", "/").replace(":", "\\:")
        vf = (
            f"subtitles='{safe_srt}'"
            ":force_style='Alignment=2,FontSize=18,PrimaryColour=&H00FFFF'"
        )
        _run_ffmpeg([
            "-y", "-threads", "0", "-i", video_in, "-vf", vf,
            "-c:v", "libx264", "-preset", "ultrafast", "-c:a", "copy", out_path,
        ])

        yield "✅ Terminé !", out_path

    except Exception as e:
        # Keep the full traceback in the server log; the UI only gets the message.
        traceback.print_exc()
        yield f"❌ Erreur : {str(e)}", None
    finally:
        if os.path.exists(tmp_dir):
            shutil.rmtree(tmp_dir, ignore_errors=True)

# 5. INTERFACE GRADIO
_DEFAULT_MODEL = "Soloni V3 (TDT-CTC)"

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.HTML("<div style='text-align:center;'><h1>🤖 RobotsMali Speech Lab</h1></div>")

    with gr.Row():
        # Left column: source video, model picker, trigger button, example.
        with gr.Column():
            v_input = gr.Video(label="Vidéo Source")
            m_input = gr.Dropdown(
                choices=list(MODELS.keys()),
                value=_DEFAULT_MODEL,
                label="Modèle",
            )
            run_btn = gr.Button("🚀 GÉNÉRER", variant="primary")
            if EXAMPLE_PATH:
                gr.Examples(
                    examples=[[EXAMPLE_PATH, _DEFAULT_MODEL]],
                    inputs=[v_input, m_input],
                )

        # Right column: progress text and the rendered video.
        with gr.Column():
            status = gr.Markdown("### État\nEn attente...")
            v_output = gr.Video(label="Vidéo finale")

    # pipeline is a generator; each yielded (status, video) pair updates the outputs.
    run_btn.click(fn=pipeline, inputs=[v_input, m_input], outputs=[status, v_output])

demo.queue()
demo.launch()