File size: 15,779 Bytes
55e23fc
1e5b3af
7d87345
 
697b4cb
599bc18
6a368d1
 
 
32e487a
55e23fc
7d87345
5c26bdc
55e23fc
7b84d76
6a368d1
7d87345
f58a689
4ff231b
599bc18
 
 
 
 
 
 
 
 
4ff231b
 
8e3b228
4ff231b
599bc18
 
6a368d1
599bc18
0dbfd3b
55e23fc
 
f58a689
0c8d6fd
6a368d1
599bc18
 
55e23fc
 
6a368d1
4ff231b
697b4cb
55e23fc
 
 
6a368d1
3578cc2
 
 
 
 
6a368d1
 
55e23fc
 
 
 
 
 
 
6a368d1
55e23fc
fb26d45
6a368d1
 
fb26d45
6a368d1
fb26d45
 
6a368d1
 
fb26d45
 
 
 
6a368d1
fb26d45
6a368d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55e23fc
 
 
6a368d1
55e23fc
 
6a368d1
 
 
55e23fc
 
 
 
 
 
fb26d45
4ff231b
6a368d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fb26d45
 
6a368d1
 
 
 
 
0c8d6fd
6a368d1
0c8d6fd
 
0dbfd3b
8e3b228
6a368d1
 
 
 
 
 
 
 
fb26d45
 
 
 
 
 
6a368d1
 
 
 
55e23fc
6a368d1
 
 
 
 
 
 
 
 
55e23fc
6a368d1
55e23fc
 
6a368d1
55e23fc
 
 
 
6a368d1
 
55e23fc
 
 
6a368d1
55e23fc
6a368d1
 
 
 
 
 
 
 
 
55e23fc
 
 
4ff231b
599bc18
 
6a368d1
0dbfd3b
55e23fc
6a368d1
fb26d45
6a368d1
 
 
fb26d45
55e23fc
fb26d45
 
 
 
 
55e23fc
6a368d1
 
 
 
 
 
7d87345
6a368d1
 
fb26d45
 
 
 
 
 
 
 
 
6a368d1
 
 
 
 
 
 
 
 
 
 
 
 
 
fb26d45
 
 
 
 
 
6a368d1
 
 
f746159
4ff231b
f58a689
6a368d1
 
8e3b228
55e23fc
 
6a368d1
3578cc2
 
 
4ff231b
55e23fc
 
6a368d1
 
 
 
 
 
55e23fc
8e3b228
6a368d1
55e23fc
6a368d1
55e23fc
6a368d1
 
 
 
 
55e23fc
6a368d1
55e23fc
 
fb26d45
6a368d1
 
55e23fc
6a368d1
 
55e23fc
6a368d1
 
 
 
 
8e3b228
6a368d1
 
 
 
 
 
 
 
3578cc2
 
6a368d1
 
 
 
 
 
 
 
 
 
 
55e23fc
 
 
6a368d1
 
 
 
 
 
 
 
 
 
fb26d45
 
 
 
 
 
6a368d1
 
 
 
 
 
 
 
 
55e23fc
599bc18
7b84d76
0dbfd3b
55e23fc
6a368d1
 
7b84d76
 
92d00a6
55e23fc
6a368d1
 
92d00a6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
import os, shlex, subprocess, tempfile, traceback, glob, gc, shutil
import torch
from huggingface_hub import snapshot_download
from nemo.collections import asr as nemo_asr
from nemo.collections.asr.models import EncDecCTCModel, EncDecRNNTModel
import gradio as gr
import time
import psutil
import humanize

# Configuration
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SEGMENT_DURATION = 10.0
print(f"✅ Démarrage sur device: {DEVICE}")
print(f"✅ Gradio version: {gr.__version__}")
print(f"✅ Mémoire disponible: {humanize.naturalsize(psutil.virtual_memory().available)}")

# Dictionnaire des modèles RobotsMali
MODELS = {
    "Soloba V3 (CTC)":           ("RobotsMali/soloba-ctc-0.6b-v3", "ctc"),
    "Soloba V2 (CTC)":           ("RobotsMali/soloba-ctc-0.6b-v2", "ctc"),
    "Soloba V1 (CTC)":           ("RobotsMali/soloba-ctc-0.6b-v1", "ctc"),
    "Soloba V1.5 (TDT)":         ("RobotsMali/soloba-tdt-0.6b-v1.5", "rnnt"),
    "Soloba V0.5 (TDT)":         ("RobotsMali/soloba-tdt-0.6b-v0.5", "rnnt"),
    "Soloni V3 (TDT-CTC)":       ("RobotsMali/soloni-114m-tdt-ctc-v3", "rnnt"),
    "Soloni V2 (TDT-CTC)":       ("RobotsMali/soloni-114m-tdt-ctc-v2", "rnnt"),
    "Soloni V1 (TDT-CTC)":       ("RobotsMali/soloni-114m-tdt-ctc-v1", "rnnt"),
    "Traduction Soloni (ST)":    ("RobotsMali/st-soloni-114m-tdt-ctc", "rnnt"),
}

_cache = {}

def get_model(name):
    if name in _cache:
        print(f"✅ Modèle {name} déjà en cache")
        return _cache[name]
    
    print(f"📥 Chargement du modèle: {name}")
    
    # Gestion agressive de la mémoire
    if len(_cache) >= 1:
        print("🧹 Nettoyage du cache...")
        _cache.clear()
        gc.collect()
        if torch.cuda.is_available(): 
            torch.cuda.empty_cache()
            print(f"🧹 Mémoire GPU libérée: {torch.cuda.memory_allocated()/1e9:.2f}GB")

    try:
        repo, arch_type = MODELS[name]
        print(f"📦 Téléchargement depuis {repo}...")
        
        start_time = time.time()
        token = os.environ.get("HF_TOKEN")
        if not token:
            print("⚠️ Attention: HF_TOKEN non trouvé dans les variables d'environnement")
        
        folder = snapshot_download(repo, local_dir_use_symlinks=False, token=token)
        download_time = time.time() - start_time
        print(f"📁 Dossier: {folder} (téléchargé en {download_time:.1f}s)")
        
        nemo_file = next((os.path.join(folder, f) for f in os.listdir(folder) if f.endswith(".nemo")), None)
        
        if nemo_file is None:
            raise FileNotFoundError(f"Aucun fichier .nemo trouvé dans {folder}")
        
        print(f"🔧 Restauration du modèle depuis {nemo_file}")
        print(f"📊 Taille du fichier: {humanize.naturalsize(os.path.getsize(nemo_file))}")

        # Chargement direct selon le type
        try:
            if arch_type == "ctc":
                print("📥 Chargement en tant que modèle CTC...")
                model = EncDecCTCModel.restore_from(nemo_file, map_location=torch.device(DEVICE))
            else:  # rnnt
                print("📥 Chargement en tant que modèle RNNT...")
                model = EncDecRNNTModel.restore_from(nemo_file, map_location=torch.device(DEVICE))
        except Exception as e:
            print(f"❌ Échec du chargement spécifique: {e}")
            print("⚠️ Les modèles RobotsMali doivent être chargés avec leur classe spécifique")
            print("⚠️ Vérifiez que le type de modèle correspond à l'architecture")
            raise RuntimeError(f"Impossible de charger le modèle {name}. Type attendu: {arch_type}") from e
        
        # Patch de la configuration
        try:
            if hasattr(model, 'cfg'):
                def remove_key_phrase_items_list(config):
                    if isinstance(config, dict):
                        if 'key_phrase_items_list' in config:
                            del config['key_phrase_items_list']
                            print("✅ Clé problématique key_phrase_items_list supprimée")
                        for key, value in config.items():
                            if isinstance(value, (dict, list)):
                                remove_key_phrase_items_list(value)
                    elif isinstance(config, list):
                        for item in config:
                            if isinstance(item, (dict, list)):
                                remove_key_phrase_items_list(item)
                
                remove_key_phrase_items_list(model.cfg)
        except Exception as e:
            print(f"⚠️ Avertissement lors du patch de config: {e}")
        
        model.eval()
        if DEVICE == "cuda":
            model = model.half()
            print(f"🎯 Modèle converti en half precision")
        
        print(f"✅ Modèle {name} chargé avec succès")
        if DEVICE == "cuda":
            print(f"📊 Mémoire GPU utilisée: {torch.cuda.memory_allocated()/1e9:.2f}GB")
        
        _cache[name] = model
        return model
        
    except Exception as e:
        print(f"❌ Erreur lors du chargement du modèle {name}:")
        print(traceback.format_exc())
        raise RuntimeError(f"Échec du chargement du modèle {name}: {str(e)}")

def get_audio_info(filepath):
    """Récupère les informations d'un fichier audio"""
    try:
        cmd = f"ffprobe -v quiet -print_format json -show_format -show_streams {shlex.quote(filepath)}"
        result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
        if result.returncode == 0:
            import json
            info = json.loads(result.stdout)
            streams = info.get('streams', [])
            audio_stream = next((s for s in streams if s.get('codec_type') == 'audio'), None)
            if audio_stream:
                duration = float(info['format'].get('duration', 0))
                return {
                    'duration': duration,
                    'sample_rate': audio_stream.get('sample_rate', '?'),
                    'channels': audio_stream.get('channels', '?'),
                    'codec': audio_stream.get('codec_name', '?'),
                    'size': humanize.naturalsize(int(info['format'].get('size', 0)))
                }
    except:
        pass
    return None

def format_time(seconds):
    """Formate le temps en MM:SS"""
    if seconds <= 0:
        return "00:00"
    minutes = int(seconds // 60)
    seconds = int(seconds % 60)
    return f"{minutes:02d}:{seconds:02d}"

def pipeline(audio_in, model_name, progress=gr.Progress()):
    if not audio_in:
        yield "❌ Erreur", "Aucun audio détecté.", gr.update(visible=False)
        return

    tmp_dir = tempfile.mkdtemp()
    try:
        # === PHASE 1: Analyse de l'audio original ===
        yield "🔍 Analyse du fichier audio...", "", gr.update(visible=False)
        
        audio_info = get_audio_info(audio_in)
        if audio_info:
            duration = audio_info['duration']
            duration_str = format_time(duration)
            info_text = f"""
📊 **Informations audio :**
- Durée : {duration_str} ({duration:.1f} secondes)
- Taille : {audio_info['size']}
- Fréquence : {audio_info['sample_rate']} Hz
- Canaux : {audio_info['channels']}
- Codec : {audio_info['codec']}
            """
        else:
            duration = 0
            info_text = "ℹ️ Impossible de lire les métadonnées audio"
        
        yield f"⏳ Préparation... ({duration_str if duration > 0 else '??'})", info_text, gr.update(visible=False)
        
        # === PHASE 2: Conversion audio ===
        yield f"🔄 Conversion audio (étape 1/3)...", info_text, gr.update(visible=False)
        progress(0.1, desc="Conversion audio...")
        
        wav_path = os.path.join(tmp_dir, "input.wav")
        
        # Vérification que le fichier source existe
        if not os.path.exists(audio_in):
            yield "❌ Erreur", f"Fichier audio introuvable: {audio_in}", gr.update(visible=False)
            return
        
        # Conversion avec FFmpeg
        cmd = f"ffmpeg -y -i {shlex.quote(audio_in)} -ac 1 -ar 16000 {shlex.quote(wav_path)}"
        result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
        
        if result.returncode != 0:
            error_msg = f"Erreur FFmpeg: {result.stderr[:200]}..."
            yield "❌ Erreur", error_msg, gr.update(visible=False)
            return
        
        if not os.path.exists(wav_path) or os.path.getsize(wav_path) == 0:
            yield "❌ Erreur", "Fichier audio converti vide", gr.update(visible=False)
            return
        
        # Info sur l'audio converti
        converted_size = humanize.naturalsize(os.path.getsize(wav_path))
        info_text += f"\n- Après conversion : {converted_size}"
        
        # === PHASE 3: Segmentation ===
        yield f"✂️ Segmentation audio (étape 2/3)...", info_text, gr.update(visible=False)
        progress(0.3, desc="Segmentation...")
        
        seg_pattern = os.path.join(tmp_dir, 'seg_%03d.wav')
        cmd = f"ffmpeg -i {shlex.quote(wav_path)} -f segment -segment_time {SEGMENT_DURATION} -c copy {shlex.quote(seg_pattern)}"
        subprocess.run(cmd, shell=True, capture_output=True)

        valid_segments = sorted(glob.glob(os.path.join(tmp_dir, "seg_*.wav")))
        if not valid_segments:
            yield "❌ Erreur", "Fichier audio vide ou incompatible après segmentation.", gr.update(visible=False)
            return
        
        nb_segments = len(valid_segments)
        info_text += f"\n📦 **Segments :** {nb_segments}"
        
        # === PHASE 4: Chargement du modèle ===
        yield f"🤖 Chargement du modèle {model_name}...", info_text, gr.update(visible=False)
        progress(0.5, desc="Chargement du modèle...")
        
        try:
            model = get_model(model_name)
        except Exception as e:
            yield "❌ Erreur modèle", f"Impossible de charger le modèle: {str(e)}", gr.update(visible=False)
            return
        
        # === PHASE 5: Transcription ===
        yield f"📝 Transcription en cours... (0/{nb_segments})", info_text, gr.update(visible=False)
        
        all_results = []
        batch_size = 4
        
        with torch.inference_mode():
            for i in range(0, len(valid_segments), batch_size):
                batch = valid_segments[i:i+batch_size]
                try:
                    batch_hyp = model.transcribe(batch, batch_size=len(batch), return_hypotheses=True)
                    
                    batch_results = [hyp.text if hasattr(hyp, 'text') else str(hyp) for hyp in batch_hyp]
                    all_results.extend(batch_results)
                except Exception as e:
                    print(f"⚠️ Erreur sur le batch {i}: {e}")
                    # Continuer avec le batch suivant
                    continue
                
                # Mise à jour de la progression
                processed = min(i + batch_size, nb_segments)
                progress_val = 0.5 + (0.5 * processed / nb_segments)
                progress(progress_val, desc=f"Transcription {processed}/{nb_segments}")
                
                # Afficher les résultats partiels
                partial_text = " ".join(all_results)
                yield f"📝 Transcription en cours... ({processed}/{nb_segments})", info_text, gr.update(value=partial_text, visible=True)
        
        final_text = " ".join(all_results)
        
        # === FIN ===
        success_text = f"""
✅ **Transcription terminée !**
- Modèle : {model_name}
- Durée audio : {duration_str if duration > 0 else '?'}
- Segments : {nb_segments}

{info_text}
        """
        
        yield "✅ Succès", success_text, gr.update(value=final_text, visible=True)

    except Exception as e:
        print(traceback.format_exc())
        error_msg = f"❌ Erreur: {str(e)}"
        yield error_msg, "", gr.update(visible=False)
    finally:
        if os.path.exists(tmp_dir): 
            shutil.rmtree(tmp_dir)
            print(f"🧹 Nettoyage du répertoire temporaire: {tmp_dir}")
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

# Interface Gradio
with gr.Blocks(title="RobotsMali ASR", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🤖 RobotsMali - Reconnaissance Vocale
    **Transcription automatique de l'audio en texte** avec les modèles de RobotsMali.
    
    Choisissez un modèle, uploadez un fichier audio ou enregistrez-vous, et lancez la transcription !
    """)
    
    with gr.Row():
        with gr.Column(scale=1):
            audio_input = gr.Audio(
                label="🎤 Audio", 
                type="filepath", 
                sources=["upload", "microphone"],
                waveform_options=gr.WaveformOptions(
                    waveform_color="#3498db",
                    waveform_progress_color="#2ecc71",
                )
            )
            
            model_input = gr.Dropdown(
                choices=list(MODELS.keys()), 
                value="Soloni V3 (TDT-CTC)", 
                label="🧠 Modèle",
                info="Sélectionnez le modèle de transcription"
            )
            
            run_btn = gr.Button("🚀 DÉMARRER LA TRANSCRIPTION", variant="primary", size="lg")
        
        with gr.Column(scale=1):
            status = gr.Markdown("### 📊 État : En attente")
            audio_info = gr.Markdown("ℹ️ Chargez un audio pour voir ses informations")
    
    with gr.Row():
        with gr.Column():
            text_output = gr.Textbox(
                label="📝 Transcription", 
                lines=8,
                placeholder="La transcription apparaîtra ici...",
                interactive=False,
                visible=False
            )
    
    # Les exemples ont été retirés car les fichiers audio sont manquants dans le dépôt.
    # Pour les rajouter, créez un dossier 'exemples/' et ajoutez-y les fichiers.
    
    # Footer
    gr.Markdown("""
    ---
    ### 📌 Notes
    - Les fichiers audio sont traités localement et supprimés après transcription
    - Durée maximale recommandée : 5 minutes
    - Modèles entraînés par [RobotsMali](https://huggingface.co/RobotsMali)
    """)
    
    # Gestionnaire d'événements
    run_btn.click(
        fn=pipeline, 
        inputs=[audio_input, model_input], 
        outputs=[status, audio_info, text_output]
    )
    
    # Afficher les infos audio quand un fichier est chargé
    def on_audio_upload(audio):
        if audio:
            info = get_audio_info(audio)
            if info:
                duration_str = format_time(info['duration'])
                return f"""
### ✅ Audio chargé
- Durée : {duration_str} ({info['duration']:.1f}s)
- Taille : {info['size']}
- Format : {info['sample_rate']}Hz, {info['channels']} canaux

Cliquez sur **DÉMARRER** pour transcrire !
                """
            else:
                return "✅ Audio chargé. Prêt pour la transcription."
        return "ℹ️ Chargez un audio pour voir ses informations"
    
    audio_input.change(
        fn=on_audio_upload,
        inputs=audio_input,
        outputs=audio_info
    )

# Point d'entrée - Configuration pour Hugging Face Spaces
if __name__ == "__main__":
    print("🚀 Lancement de l'application RobotsMali ASR...")
    print(f"📊 Mémoire système: {humanize.naturalsize(psutil.virtual_memory().total)}")
    print(f"🎯 Device: {DEVICE}")
    
    # Configuration simple pour Spaces
    demo.queue().launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True
    )