# Captured from a Hugging Face Space whose status read "Runtime error".
# Standard library
import gc
import glob
import json
import os
import shlex
import shutil
import subprocess
import tempfile
import time
import traceback

# Third-party
import gradio as gr
import humanize
import psutil
import torch
from huggingface_hub import snapshot_download
from nemo.collections import asr as nemo_asr
from nemo.collections.asr.models import EncDecCTCModel, EncDecRNNTModel
# --- Configuration -----------------------------------------------------------

# Run on the GPU when PyTorch can see CUDA, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Length, in seconds, of each chunk ffmpeg cuts the audio into
# (passed to ffmpeg's `-segment_time`).
SEGMENT_DURATION = 10.0

# Startup diagnostics, printed once at import time.
print(f"✅ Démarrage sur device: {DEVICE}")
print(f"✅ Gradio version: {gr.__version__}")
print(f"✅ Mémoire disponible: {humanize.naturalsize(psutil.virtual_memory().available)}")

# RobotsMali checkpoints offered in the UI, keyed by their display label.
# Each value is (Hugging Face repo id, loader family). get_model() restores
# "ctc" entries with EncDecCTCModel and every other entry with EncDecRNNTModel.
# Insertion order is the order shown in the dropdown.
MODELS = {
    label: (repo_id, family)
    for label, repo_id, family in [
        ("Soloba V3 (CTC)", "RobotsMali/soloba-ctc-0.6b-v3", "ctc"),
        ("Soloba V2 (CTC)", "RobotsMali/soloba-ctc-0.6b-v2", "ctc"),
        ("Soloba V1 (CTC)", "RobotsMali/soloba-ctc-0.6b-v1", "ctc"),
        ("Soloba V1.5 (TDT)", "RobotsMali/soloba-tdt-0.6b-v1.5", "rnnt"),
        ("Soloba V0.5 (TDT)", "RobotsMali/soloba-tdt-0.6b-v0.5", "rnnt"),
        ("Soloni V3 (TDT-CTC)", "RobotsMali/soloni-114m-tdt-ctc-v3", "rnnt"),
        ("Soloni V2 (TDT-CTC)", "RobotsMali/soloni-114m-tdt-ctc-v2", "rnnt"),
        ("Soloni V1 (TDT-CTC)", "RobotsMali/soloni-114m-tdt-ctc-v1", "rnnt"),
        ("Traduction Soloni (ST)", "RobotsMali/st-soloni-114m-tdt-ctc", "rnnt"),
    ]
}

# Loaded models keyed by display label. get_model() empties it before loading
# a new model, so it holds at most one entry.
_cache = {}
def _find_nemo_file(folder):
    """Return the path of a ``.nemo`` archive directly inside *folder*, or ``None``.

    Only the top level of *folder* is searched. Entries are examined in sorted
    order, so the choice is deterministic when a repository ships several archives.
    """
    for entry in sorted(os.listdir(folder)):
        if entry.endswith(".nemo"):
            return os.path.join(folder, entry)
    return None


def _remove_key_recursive(config, key="key_phrase_items_list"):
    """Delete *key* from *config* and from every nested container, in place.

    Only built-in ``dict`` and ``list`` objects are traversed; any other
    container type is left untouched.
    """
    if isinstance(config, dict):
        if key in config:
            del config[key]
            print(f"✅ Clé problématique {key} supprimée")
        children = list(config.values())
    elif isinstance(config, list):
        children = config
    else:
        return
    for child in children:
        if isinstance(child, (dict, list)):
            _remove_key_recursive(child, key)


def get_model(name):
    """Return the NeMo ASR model registered under *name*, loading it if needed.

    At most one model is kept in memory. Loading a model that is not cached
    first evicts whatever is cached and releases Python/CUDA memory.

    Args:
        name: A key of ``MODELS``.

    Returns:
        The restored model in eval mode (converted to half precision on CUDA).

    Raises:
        RuntimeError: if *name* is unknown, the download fails, the repository
            contains no ``.nemo`` archive, or restoration fails. The underlying
            exception is chained as ``__cause__``.
    """
    if name in _cache:
        print(f"✅ Modèle {name} déjà en cache")
        return _cache[name]

    print(f"📥 Chargement du modèle: {name}")
    # Aggressive memory management: drop the previous model before loading a new one.
    if _cache:
        print("🧹 Nettoyage du cache...")
        _cache.clear()
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            print(f"🧹 Mémoire GPU libérée: {torch.cuda.memory_allocated()/1e9:.2f}GB")

    try:
        repo, arch_type = MODELS[name]
        print(f"📦 Téléchargement depuis {repo}...")
        start_time = time.time()
        token = os.environ.get("HF_TOKEN")
        if not token:
            print("⚠️ Attention: HF_TOKEN non trouvé dans les variables d'environnement")
        # `local_dir_use_symlinks` is deprecated and only meaningful together
        # with `local_dir`, so it is not passed.
        folder = snapshot_download(repo, token=token)
        download_time = time.time() - start_time
        print(f"📁 Dossier: {folder} (téléchargé en {download_time:.1f}s)")

        nemo_file = _find_nemo_file(folder)
        if nemo_file is None:
            raise FileNotFoundError(f"Aucun fichier .nemo trouvé dans {folder}")
        print(f"🔧 Restauration du modèle depuis {nemo_file}")
        print(f"📊 Taille du fichier: {humanize.naturalsize(os.path.getsize(nemo_file))}")

        # Restore through the class family declared for this checkpoint in MODELS.
        if arch_type == "ctc":
            print("📥 Chargement en tant que modèle CTC...")
            model_cls = EncDecCTCModel
        else:  # rnnt
            print("📥 Chargement en tant que modèle RNNT...")
            model_cls = EncDecRNNTModel
        try:
            model = model_cls.restore_from(nemo_file, map_location=torch.device(DEVICE))
        except Exception as e:
            print(f"❌ Échec du chargement spécifique: {e}")
            print("⚠️ Les modèles RobotsMali doivent être chargés avec leur classe spécifique")
            print("⚠️ Vérifiez que le type de modèle correspond à l'architecture")
            raise RuntimeError(f"Impossible de charger le modèle {name}. Type attendu: {arch_type}") from e

        # Best-effort config patch: a failure here must not prevent loading.
        # NOTE(review): NeMo's `model.cfg` is usually an OmegaConf DictConfig,
        # which is not a `dict` subclass, so this traversal may never match.
        # Confirm whether the patch is still needed.
        try:
            if hasattr(model, "cfg"):
                _remove_key_recursive(model.cfg)
        except Exception as e:
            print(f"⚠️ Avertissement lors du patch de config: {e}")

        model.eval()
        if DEVICE == "cuda":
            # NOTE(review): fp16 weights halve GPU memory. Depending on the NeMo
            # version, transcribe() may feed float32 features into the fp16
            # encoder and fail with a dtype mismatch. Confirm on the deployed stack.
            model = model.half()
            print("🎯 Modèle converti en half precision")
        print(f"✅ Modèle {name} chargé avec succès")
        if DEVICE == "cuda":
            print(f"📊 Mémoire GPU utilisée: {torch.cuda.memory_allocated()/1e9:.2f}GB")
        _cache[name] = model
        return model
    except Exception as e:
        print(f"❌ Erreur lors du chargement du modèle {name}:")
        print(traceback.format_exc())
        raise RuntimeError(f"Échec du chargement du modèle {name}: {str(e)}") from e
def get_audio_info(filepath, timeout=30):
    """Return basic metadata for the first audio stream of *filepath*.

    The file is inspected with ``ffprobe`` (JSON output).

    Args:
        filepath: Path of a local media file.
        timeout: Seconds to wait for ``ffprobe`` before giving up.

    Returns:
        A dict with these keys, or ``None``:

        - ``duration``: float seconds, 0 when ffprobe reports none.
        - ``sample_rate``, ``channels``, ``codec``: as reported by ffprobe,
          ``'?'`` when missing.
        - ``size``: human-readable string.

        ``None`` is returned when *filepath* is empty, ffprobe is missing,
        fails or times out, its output cannot be parsed, or the file has no
        audio stream.
    """
    if not filepath:
        return None
    # Argument list without a shell: the path (often a user upload) is never
    # interpreted by a shell, and `-i` keeps it from being parsed as an option.
    cmd = [
        "ffprobe", "-v", "quiet", "-print_format", "json",
        "-show_format", "-show_streams", "-i", str(filepath),
    ]
    try:
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout)
    except (OSError, subprocess.SubprocessError):
        # ffprobe not installed, not executable, or timed out.
        return None
    if result.returncode != 0:
        return None
    try:
        info = json.loads(result.stdout)
        fmt = info.get("format") or {}
        streams = info.get("streams") or []
        audio_stream = next((s for s in streams if s.get("codec_type") == "audio"), None)
        if audio_stream is None:
            return None
        return {
            "duration": float(fmt.get("duration", 0)),
            "sample_rate": audio_stream.get("sample_rate", "?"),
            "channels": audio_stream.get("channels", "?"),
            "codec": audio_stream.get("codec_name", "?"),
            "size": humanize.naturalsize(int(fmt.get("size", 0))),
        }
    except (ValueError, TypeError, AttributeError):
        # Malformed or unexpected ffprobe output (json.JSONDecodeError is a ValueError).
        return None
def format_time(seconds):
    """Render a duration in seconds as a zero-padded ``MM:SS`` string.

    Non-positive durations render as ``"00:00"``. Fractional seconds are
    truncated, and minutes are not folded into hours (3600 -> ``"60:00"``).
    """
    if seconds <= 0:
        return "00:00"
    mins, secs = divmod(seconds, 60)
    return "{:02d}:{:02d}".format(int(mins), int(secs))
def _hypotheses_to_texts(hypotheses):
    """Turn the result of ``model.transcribe(..., return_hypotheses=True)`` into strings.

    Each hypothesis contributes its ``.text`` attribute when it has one, else
    its ``str()``.
    """
    # NOTE(review): some NeMo releases return a `(best_hypotheses,
    # all_hypotheses)` tuple for RNNT/TDT models instead of a flat list. In that
    # case only the best hypotheses are kept. Confirm against the installed NeMo.
    if isinstance(hypotheses, tuple) and hypotheses and isinstance(hypotheses[0], list):
        hypotheses = hypotheses[0]
    return [hyp.text if hasattr(hyp, "text") else str(hyp) for hyp in hypotheses]


def _join_texts(texts):
    """Join the non-blank transcription fragments with single spaces."""
    return " ".join(text.strip() for text in texts if text and text.strip())


def pipeline(audio_in, model_name, progress=gr.Progress()):
    """Transcribe an audio file, streaming UI updates as a generator.

    Steps:

    1. Probe the file's metadata.
    2. Convert it to 16 kHz mono WAV with ffmpeg.
    3. Split it into ``SEGMENT_DURATION``-second chunks.
    4. Load the selected model.
    5. Transcribe the chunks in batches of 4.

    All temporary files are removed at the end.

    Args:
        audio_in: Path of the uploaded/recorded audio, or a falsy value.
        model_name: Key of ``MODELS``.
        progress: Gradio progress tracker (injected by Gradio).

    Yields:
        ``(status_markdown, info_markdown, transcript_update)`` triples for the
        status panel, the info panel and the transcript textbox.
    """
    if not audio_in:
        yield "❌ Erreur", "Aucun audio détecté.", gr.update(visible=False)
        return
    # Fail fast, before creating a temp dir or probing the file.
    if not os.path.exists(audio_in):
        yield "❌ Erreur", f"Fichier audio introuvable: {audio_in}", gr.update(visible=False)
        return

    tmp_dir = tempfile.mkdtemp()
    try:
        # === PHASE 1: inspect the original audio ===
        yield "🔍 Analyse du fichier audio...", "", gr.update(visible=False)
        audio_info = get_audio_info(audio_in)
        if audio_info:
            duration = audio_info["duration"]
            info_text = (
                "\n📊 **Informations audio :**\n"
                f"- Durée : {format_time(duration)} ({duration:.1f} secondes)\n"
                f"- Taille : {audio_info['size']}\n"
                f"- Fréquence : {audio_info['sample_rate']} Hz\n"
                f"- Canaux : {audio_info['channels']}\n"
                f"- Codec : {audio_info['codec']}\n"
            )
        else:
            duration = 0
            info_text = "ℹ️ Impossible de lire les métadonnées audio"
        # Only display a duration when one is actually known.
        duration_str = format_time(duration) if duration > 0 else None
        yield f"⏳ Préparation... ({duration_str or '??'})", info_text, gr.update(visible=False)

        # === PHASE 2: convert to 16 kHz mono WAV ===
        yield "🔄 Conversion audio (étape 1/3)...", info_text, gr.update(visible=False)
        progress(0.1, desc="Conversion audio...")
        wav_path = os.path.join(tmp_dir, "input.wav")
        # Argument lists (no shell): the uploaded path is never shell-parsed.
        result = subprocess.run(
            ["ffmpeg", "-y", "-i", audio_in, "-ac", "1", "-ar", "16000", wav_path],
            capture_output=True,
            text=True,
        )
        if result.returncode != 0:
            # ffmpeg prints its version banner first; the actual error is at the end.
            error_tail = result.stderr.strip()[-200:]
            yield "❌ Erreur", f"Erreur FFmpeg: ...{error_tail}", gr.update(visible=False)
            return
        if not os.path.exists(wav_path) or os.path.getsize(wav_path) == 0:
            yield "❌ Erreur", "Fichier audio converti vide", gr.update(visible=False)
            return
        converted_size = humanize.naturalsize(os.path.getsize(wav_path))
        info_text += f"\n- Après conversion : {converted_size}"

        # === PHASE 3: split into fixed-length segments ===
        yield "✂️ Segmentation audio (étape 2/3)...", info_text, gr.update(visible=False)
        progress(0.3, desc="Segmentation...")
        # Five-digit numbering keeps the lexicographic sort below in time
        # order beyond 999 segments.
        seg_pattern = os.path.join(tmp_dir, "seg_%05d.wav")
        subprocess.run(
            ["ffmpeg", "-i", wav_path, "-f", "segment",
             "-segment_time", str(SEGMENT_DURATION), "-c", "copy", seg_pattern],
            capture_output=True,
            text=True,
        )
        valid_segments = sorted(glob.glob(os.path.join(tmp_dir, "seg_*.wav")))
        if not valid_segments:
            yield "❌ Erreur", "Fichier audio vide ou incompatible après segmentation.", gr.update(visible=False)
            return
        nb_segments = len(valid_segments)
        info_text += f"\n📦 **Segments :** {nb_segments}"

        # === PHASE 4: load the model ===
        yield f"🤖 Chargement du modèle {model_name}...", info_text, gr.update(visible=False)
        progress(0.5, desc="Chargement du modèle...")
        try:
            model = get_model(model_name)
        except Exception as e:
            yield "❌ Erreur modèle", f"Impossible de charger le modèle: {str(e)}", gr.update(visible=False)
            return

        # === PHASE 5: batched transcription ===
        yield f"📝 Transcription en cours... (0/{nb_segments})", info_text, gr.update(visible=False)
        all_results = []
        failed_segments = 0
        batch_size = 4
        for start in range(0, nb_segments, batch_size):
            batch = valid_segments[start:start + batch_size]
            try:
                # inference_mode is scoped to the model call only. It is
                # thread-local, and Gradio may resume this generator on a
                # different worker thread after each `yield`.
                with torch.inference_mode():
                    batch_hyp = model.transcribe(batch, batch_size=len(batch), return_hypotheses=True)
                all_results.extend(_hypotheses_to_texts(batch_hyp))
            except Exception as e:
                # One bad batch must not abort the whole file. Count its
                # segments so the final report does not claim full success.
                print(f"⚠️ Erreur sur le batch {start}: {e}")
                failed_segments += len(batch)
            processed = min(start + batch_size, nb_segments)
            progress(0.5 + 0.5 * processed / nb_segments, desc=f"Transcription {processed}/{nb_segments}")
            yield (
                f"📝 Transcription en cours... ({processed}/{nb_segments})",
                info_text,
                gr.update(value=_join_texts(all_results), visible=True),
            )

        # === END ===
        if failed_segments == nb_segments:
            yield (
                "❌ Erreur",
                f"Aucun segment n'a pu être transcrit avec {model_name} (voir les logs).\n{info_text}",
                gr.update(visible=False),
            )
            return
        final_text = _join_texts(all_results)
        report_lines = [
            "",
            "✅ **Transcription terminée !**",
            f"- Modèle : {model_name}",
            f"- Durée audio : {duration_str or '?'}",
            f"- Segments : {nb_segments}",
        ]
        if failed_segments:
            report_lines.append(f"- Segments en échec : {failed_segments}")
        report_lines.append(info_text)
        success_text = "\n".join(report_lines) + "\n"
        status_text = "⚠️ Transcription partielle" if failed_segments else "✅ Succès"
        yield status_text, success_text, gr.update(value=final_text, visible=True)
    except Exception as e:
        print(traceback.format_exc())
        yield f"❌ Erreur: {str(e)}", "", gr.update(visible=False)
    finally:
        # Cleanup must never raise out of the generator.
        shutil.rmtree(tmp_dir, ignore_errors=True)
        print(f"🧹 Nettoyage du répertoire temporaire: {tmp_dir}")
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
# --- Gradio user interface ---------------------------------------------------

# Info-panel text shown while no audio is loaded. on_audio_upload() also
# returns it when the audio input is cleared.
_NO_AUDIO_MD = "ℹ️ Chargez un audio pour voir ses informations"

_HEADER_MD = """
# 🤖 RobotsMali - Reconnaissance Vocale
**Transcription automatique de l'audio en texte** avec les modèles de RobotsMali.
Choisissez un modèle, uploadez un fichier audio ou enregistrez-vous, et lancez la transcription !
"""

_FOOTER_MD = """
---
### 📌 Notes
- Les fichiers audio sont traités localement et supprimés après transcription
- Durée maximale recommandée : 5 minutes
- Modèles entraînés par [RobotsMali](https://huggingface.co/RobotsMali)
"""


def on_audio_upload(audio):
    """Build the info-panel Markdown for the currently selected audio.

    Args:
        audio: File path from the audio input, or a falsy value when cleared.

    Returns:
        Markdown text, chosen as follows:

        - ffprobe details when they can be read;
        - a generic "ready" message when they cannot;
        - the idle placeholder when no audio is set.
    """
    if not audio:
        return _NO_AUDIO_MD
    details = get_audio_info(audio)
    if not details:
        return "✅ Audio chargé. Prêt pour la transcription."
    seconds = details['duration']
    return f"""
### ✅ Audio chargé
- Durée : {format_time(seconds)} ({seconds:.1f}s)
- Taille : {details['size']}
- Format : {details['sample_rate']}Hz, {details['channels']} canaux
Cliquez sur **DÉMARRER** pour transcrire !
"""


with gr.Blocks(title="RobotsMali ASR", theme=gr.themes.Soft()) as demo:
    gr.Markdown(_HEADER_MD)

    with gr.Row():
        # Left column: audio source, model choice and the run button.
        with gr.Column(scale=1):
            audio_input = gr.Audio(
                label="🎤 Audio",
                type="filepath",
                sources=["upload", "microphone"],
                waveform_options=gr.WaveformOptions(
                    waveform_color="#3498db",
                    waveform_progress_color="#2ecc71",
                ),
            )
            model_input = gr.Dropdown(
                choices=list(MODELS),
                value="Soloni V3 (TDT-CTC)",
                label="🧠 Modèle",
                info="Sélectionnez le modèle de transcription",
            )
            run_btn = gr.Button("🚀 DÉMARRER LA TRANSCRIPTION", variant="primary", size="lg")
        # Right column: run status and audio metadata.
        with gr.Column(scale=1):
            status = gr.Markdown("### 📊 État : En attente")
            audio_info = gr.Markdown(_NO_AUDIO_MD)

    # Transcript box: starts hidden; pipeline() updates its visibility.
    with gr.Row():
        with gr.Column():
            text_output = gr.Textbox(
                label="📝 Transcription",
                lines=8,
                placeholder="La transcription apparaîtra ici...",
                interactive=False,
                visible=False,
            )

    # The examples were removed because their audio files are missing from the
    # repository. To add them back, create an 'exemples/' folder and put the
    # files in it.

    gr.Markdown(_FOOTER_MD)

    # Event wiring: pipeline() yields (status, info, transcript) updates.
    run_btn.click(
        fn=pipeline,
        inputs=[audio_input, model_input],
        outputs=[status, audio_info, text_output],
    )
    # Refresh the info panel whenever the audio input changes (including clears).
    audio_input.change(fn=on_audio_upload, inputs=audio_input, outputs=audio_info)
# Entry point, configured for Hugging Face Spaces.
if __name__ == "__main__":
    print("🚀 Lancement de l'application RobotsMali ASR...")
    total_memory = humanize.naturalsize(psutil.virtual_memory().total)
    print(f"📊 Mémoire système: {total_memory}")
    print(f"🎯 Device: {DEVICE}")
    # Listen on all interfaces on port 7860. show_error=True asks Gradio to
    # display handler errors in the browser.
    launch_kwargs = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
    }
    demo.queue().launch(**launch_kwargs)