Bot
Fix: HF token auth and remove missing examples
3578cc2
raw
history blame
15.8 kB
import os, shlex, subprocess, tempfile, traceback, glob, gc, shutil
import torch
from huggingface_hub import snapshot_download
from nemo.collections import asr as nemo_asr
from nemo.collections.asr.models import EncDecCTCModel, EncDecRNNTModel
import gradio as gr
import time
import psutil
import humanize
# Configuration
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SEGMENT_DURATION = 10.0
print(f"✅ Démarrage sur device: {DEVICE}")
print(f"✅ Gradio version: {gr.__version__}")
print(f"✅ Mémoire disponible: {humanize.naturalsize(psutil.virtual_memory().available)}")
# Dictionnaire des modèles RobotsMali
MODELS = {
"Soloba V3 (CTC)": ("RobotsMali/soloba-ctc-0.6b-v3", "ctc"),
"Soloba V2 (CTC)": ("RobotsMali/soloba-ctc-0.6b-v2", "ctc"),
"Soloba V1 (CTC)": ("RobotsMali/soloba-ctc-0.6b-v1", "ctc"),
"Soloba V1.5 (TDT)": ("RobotsMali/soloba-tdt-0.6b-v1.5", "rnnt"),
"Soloba V0.5 (TDT)": ("RobotsMali/soloba-tdt-0.6b-v0.5", "rnnt"),
"Soloni V3 (TDT-CTC)": ("RobotsMali/soloni-114m-tdt-ctc-v3", "rnnt"),
"Soloni V2 (TDT-CTC)": ("RobotsMali/soloni-114m-tdt-ctc-v2", "rnnt"),
"Soloni V1 (TDT-CTC)": ("RobotsMali/soloni-114m-tdt-ctc-v1", "rnnt"),
"Traduction Soloni (ST)": ("RobotsMali/st-soloni-114m-tdt-ctc", "rnnt"),
}
_cache = {}
def get_model(name):
if name in _cache:
print(f"✅ Modèle {name} déjà en cache")
return _cache[name]
print(f"📥 Chargement du modèle: {name}")
# Gestion agressive de la mémoire
if len(_cache) >= 1:
print("🧹 Nettoyage du cache...")
_cache.clear()
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
print(f"🧹 Mémoire GPU libérée: {torch.cuda.memory_allocated()/1e9:.2f}GB")
try:
repo, arch_type = MODELS[name]
print(f"📦 Téléchargement depuis {repo}...")
start_time = time.time()
token = os.environ.get("HF_TOKEN")
if not token:
print("⚠️ Attention: HF_TOKEN non trouvé dans les variables d'environnement")
folder = snapshot_download(repo, local_dir_use_symlinks=False, token=token)
download_time = time.time() - start_time
print(f"📁 Dossier: {folder} (téléchargé en {download_time:.1f}s)")
nemo_file = next((os.path.join(folder, f) for f in os.listdir(folder) if f.endswith(".nemo")), None)
if nemo_file is None:
raise FileNotFoundError(f"Aucun fichier .nemo trouvé dans {folder}")
print(f"🔧 Restauration du modèle depuis {nemo_file}")
print(f"📊 Taille du fichier: {humanize.naturalsize(os.path.getsize(nemo_file))}")
# Chargement direct selon le type
try:
if arch_type == "ctc":
print("📥 Chargement en tant que modèle CTC...")
model = EncDecCTCModel.restore_from(nemo_file, map_location=torch.device(DEVICE))
else: # rnnt
print("📥 Chargement en tant que modèle RNNT...")
model = EncDecRNNTModel.restore_from(nemo_file, map_location=torch.device(DEVICE))
except Exception as e:
print(f"❌ Échec du chargement spécifique: {e}")
print("⚠️ Les modèles RobotsMali doivent être chargés avec leur classe spécifique")
print("⚠️ Vérifiez que le type de modèle correspond à l'architecture")
raise RuntimeError(f"Impossible de charger le modèle {name}. Type attendu: {arch_type}") from e
# Patch de la configuration
try:
if hasattr(model, 'cfg'):
def remove_key_phrase_items_list(config):
if isinstance(config, dict):
if 'key_phrase_items_list' in config:
del config['key_phrase_items_list']
print("✅ Clé problématique key_phrase_items_list supprimée")
for key, value in config.items():
if isinstance(value, (dict, list)):
remove_key_phrase_items_list(value)
elif isinstance(config, list):
for item in config:
if isinstance(item, (dict, list)):
remove_key_phrase_items_list(item)
remove_key_phrase_items_list(model.cfg)
except Exception as e:
print(f"⚠️ Avertissement lors du patch de config: {e}")
model.eval()
if DEVICE == "cuda":
model = model.half()
print(f"🎯 Modèle converti en half precision")
print(f"✅ Modèle {name} chargé avec succès")
if DEVICE == "cuda":
print(f"📊 Mémoire GPU utilisée: {torch.cuda.memory_allocated()/1e9:.2f}GB")
_cache[name] = model
return model
except Exception as e:
print(f"❌ Erreur lors du chargement du modèle {name}:")
print(traceback.format_exc())
raise RuntimeError(f"Échec du chargement du modèle {name}: {str(e)}")
def get_audio_info(filepath):
"""Récupère les informations d'un fichier audio"""
try:
cmd = f"ffprobe -v quiet -print_format json -show_format -show_streams {shlex.quote(filepath)}"
result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
if result.returncode == 0:
import json
info = json.loads(result.stdout)
streams = info.get('streams', [])
audio_stream = next((s for s in streams if s.get('codec_type') == 'audio'), None)
if audio_stream:
duration = float(info['format'].get('duration', 0))
return {
'duration': duration,
'sample_rate': audio_stream.get('sample_rate', '?'),
'channels': audio_stream.get('channels', '?'),
'codec': audio_stream.get('codec_name', '?'),
'size': humanize.naturalsize(int(info['format'].get('size', 0)))
}
except:
pass
return None
def format_time(seconds):
"""Formate le temps en MM:SS"""
if seconds <= 0:
return "00:00"
minutes = int(seconds // 60)
seconds = int(seconds % 60)
return f"{minutes:02d}:{seconds:02d}"
def pipeline(audio_in, model_name, progress=gr.Progress()):
if not audio_in:
yield "❌ Erreur", "Aucun audio détecté.", gr.update(visible=False)
return
tmp_dir = tempfile.mkdtemp()
try:
# === PHASE 1: Analyse de l'audio original ===
yield "🔍 Analyse du fichier audio...", "", gr.update(visible=False)
audio_info = get_audio_info(audio_in)
if audio_info:
duration = audio_info['duration']
duration_str = format_time(duration)
info_text = f"""
📊 **Informations audio :**
- Durée : {duration_str} ({duration:.1f} secondes)
- Taille : {audio_info['size']}
- Fréquence : {audio_info['sample_rate']} Hz
- Canaux : {audio_info['channels']}
- Codec : {audio_info['codec']}
"""
else:
duration = 0
info_text = "ℹ️ Impossible de lire les métadonnées audio"
yield f"⏳ Préparation... ({duration_str if duration > 0 else '??'})", info_text, gr.update(visible=False)
# === PHASE 2: Conversion audio ===
yield f"🔄 Conversion audio (étape 1/3)...", info_text, gr.update(visible=False)
progress(0.1, desc="Conversion audio...")
wav_path = os.path.join(tmp_dir, "input.wav")
# Vérification que le fichier source existe
if not os.path.exists(audio_in):
yield "❌ Erreur", f"Fichier audio introuvable: {audio_in}", gr.update(visible=False)
return
# Conversion avec FFmpeg
cmd = f"ffmpeg -y -i {shlex.quote(audio_in)} -ac 1 -ar 16000 {shlex.quote(wav_path)}"
result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
if result.returncode != 0:
error_msg = f"Erreur FFmpeg: {result.stderr[:200]}..."
yield "❌ Erreur", error_msg, gr.update(visible=False)
return
if not os.path.exists(wav_path) or os.path.getsize(wav_path) == 0:
yield "❌ Erreur", "Fichier audio converti vide", gr.update(visible=False)
return
# Info sur l'audio converti
converted_size = humanize.naturalsize(os.path.getsize(wav_path))
info_text += f"\n- Après conversion : {converted_size}"
# === PHASE 3: Segmentation ===
yield f"✂️ Segmentation audio (étape 2/3)...", info_text, gr.update(visible=False)
progress(0.3, desc="Segmentation...")
seg_pattern = os.path.join(tmp_dir, 'seg_%03d.wav')
cmd = f"ffmpeg -i {shlex.quote(wav_path)} -f segment -segment_time {SEGMENT_DURATION} -c copy {shlex.quote(seg_pattern)}"
subprocess.run(cmd, shell=True, capture_output=True)
valid_segments = sorted(glob.glob(os.path.join(tmp_dir, "seg_*.wav")))
if not valid_segments:
yield "❌ Erreur", "Fichier audio vide ou incompatible après segmentation.", gr.update(visible=False)
return
nb_segments = len(valid_segments)
info_text += f"\n📦 **Segments :** {nb_segments}"
# === PHASE 4: Chargement du modèle ===
yield f"🤖 Chargement du modèle {model_name}...", info_text, gr.update(visible=False)
progress(0.5, desc="Chargement du modèle...")
try:
model = get_model(model_name)
except Exception as e:
yield "❌ Erreur modèle", f"Impossible de charger le modèle: {str(e)}", gr.update(visible=False)
return
# === PHASE 5: Transcription ===
yield f"📝 Transcription en cours... (0/{nb_segments})", info_text, gr.update(visible=False)
all_results = []
batch_size = 4
with torch.inference_mode():
for i in range(0, len(valid_segments), batch_size):
batch = valid_segments[i:i+batch_size]
try:
batch_hyp = model.transcribe(batch, batch_size=len(batch), return_hypotheses=True)
batch_results = [hyp.text if hasattr(hyp, 'text') else str(hyp) for hyp in batch_hyp]
all_results.extend(batch_results)
except Exception as e:
print(f"⚠️ Erreur sur le batch {i}: {e}")
# Continuer avec le batch suivant
continue
# Mise à jour de la progression
processed = min(i + batch_size, nb_segments)
progress_val = 0.5 + (0.5 * processed / nb_segments)
progress(progress_val, desc=f"Transcription {processed}/{nb_segments}")
# Afficher les résultats partiels
partial_text = " ".join(all_results)
yield f"📝 Transcription en cours... ({processed}/{nb_segments})", info_text, gr.update(value=partial_text, visible=True)
final_text = " ".join(all_results)
# === FIN ===
success_text = f"""
✅ **Transcription terminée !**
- Modèle : {model_name}
- Durée audio : {duration_str if duration > 0 else '?'}
- Segments : {nb_segments}
{info_text}
"""
yield "✅ Succès", success_text, gr.update(value=final_text, visible=True)
except Exception as e:
print(traceback.format_exc())
error_msg = f"❌ Erreur: {str(e)}"
yield error_msg, "", gr.update(visible=False)
finally:
if os.path.exists(tmp_dir):
shutil.rmtree(tmp_dir)
print(f"🧹 Nettoyage du répertoire temporaire: {tmp_dir}")
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Interface Gradio
with gr.Blocks(title="RobotsMali ASR", theme=gr.themes.Soft()) as demo:
gr.Markdown("""
# 🤖 RobotsMali - Reconnaissance Vocale
**Transcription automatique de l'audio en texte** avec les modèles de RobotsMali.
Choisissez un modèle, uploadez un fichier audio ou enregistrez-vous, et lancez la transcription !
""")
with gr.Row():
with gr.Column(scale=1):
audio_input = gr.Audio(
label="🎤 Audio",
type="filepath",
sources=["upload", "microphone"],
waveform_options=gr.WaveformOptions(
waveform_color="#3498db",
waveform_progress_color="#2ecc71",
)
)
model_input = gr.Dropdown(
choices=list(MODELS.keys()),
value="Soloni V3 (TDT-CTC)",
label="🧠 Modèle",
info="Sélectionnez le modèle de transcription"
)
run_btn = gr.Button("🚀 DÉMARRER LA TRANSCRIPTION", variant="primary", size="lg")
with gr.Column(scale=1):
status = gr.Markdown("### 📊 État : En attente")
audio_info = gr.Markdown("ℹ️ Chargez un audio pour voir ses informations")
with gr.Row():
with gr.Column():
text_output = gr.Textbox(
label="📝 Transcription",
lines=8,
placeholder="La transcription apparaîtra ici...",
interactive=False,
visible=False
)
# Les exemples ont été retirés car les fichiers audio sont manquants dans le dépôt.
# Pour les rajouter, créez un dossier 'exemples/' et ajoutez-y les fichiers.
# Footer
gr.Markdown("""
---
### 📌 Notes
- Les fichiers audio sont traités localement et supprimés après transcription
- Durée maximale recommandée : 5 minutes
- Modèles entraînés par [RobotsMali](https://huggingface.co/RobotsMali)
""")
# Gestionnaire d'événements
run_btn.click(
fn=pipeline,
inputs=[audio_input, model_input],
outputs=[status, audio_info, text_output]
)
# Afficher les infos audio quand un fichier est chargé
def on_audio_upload(audio):
if audio:
info = get_audio_info(audio)
if info:
duration_str = format_time(info['duration'])
return f"""
### ✅ Audio chargé
- Durée : {duration_str} ({info['duration']:.1f}s)
- Taille : {info['size']}
- Format : {info['sample_rate']}Hz, {info['channels']} canaux
Cliquez sur **DÉMARRER** pour transcrire !
"""
else:
return "✅ Audio chargé. Prêt pour la transcription."
return "ℹ️ Chargez un audio pour voir ses informations"
audio_input.change(
fn=on_audio_upload,
inputs=audio_input,
outputs=audio_info
)
# Point d'entrée - Configuration pour Hugging Face Spaces
if __name__ == "__main__":
print("🚀 Lancement de l'application RobotsMali ASR...")
print(f"📊 Mémoire système: {humanize.naturalsize(psutil.virtual_memory().total)}")
print(f"🎯 Device: {DEVICE}")
# Configuration simple pour Spaces
demo.queue().launch(
server_name="0.0.0.0",
server_port=7860,
show_error=True
)