gilbert-stt-diarization / diarization_pyannote_demo.py
mathisescriva
Initial commit: pyannote diarization Space
704669a
#!/usr/bin/env python3
"""
Script de diarisation utilisant pyannote.audio (Community-1 ou 3.1).
Ce script prend en entrée un fichier audio et génère :
- Un fichier RTTM
- Un fichier JSON avec les segments de diarisation
Le modèle Community-1 est utilisé par défaut (meilleur que 3.1 selon les benchmarks).
Usage:
python diarization_pyannote_demo.py <input_audio.wav> [--output_dir OUTPUT_DIR]
python diarization_pyannote_demo.py audio.wav --num_speakers 3
python diarization_pyannote_demo.py audio.wav --model pyannote/speaker-diarization-precision-2
"""
import argparse
import json
import os
import sys
from pathlib import Path
from typing import List, Dict, Any
try:
# Importer pyannote en évitant les imports NeMo si possible
import os
# Désactiver temporairement l'import NeMo dans pyannote si nécessaire
os.environ['PYANNOTE_DISABLE_NEMO'] = '1'
from pyannote.audio import Pipeline
from pyannote.core import Annotation
try:
from pyannote.audio.pipelines.utils.hook import ProgressHook
HAS_PROGRESS_HOOK = True
except ImportError:
HAS_PROGRESS_HOOK = False
except ImportError as e:
print("ERREUR: pyannote.audio n'est pas installé. Voir INSTALL.md pour les instructions.")
print(f"Détails: {e}")
sys.exit(1)
except Exception as e:
# Si l'import échoue à cause de NeMo, donner des instructions
if 'nemo' in str(e).lower() or 'transformers' in str(e).lower():
print("ERREUR: Conflit de dépendances avec NeMo/transformers.")
print("Solution recommandée: Utiliser un environnement conda dédié.")
print("Exécuter: ./setup_nemo_env.sh")
print(f"Détails: {e}")
else:
print(f"ERREUR: {e}")
sys.exit(1)
import torch
# Corriger le problème PyTorch 2.6 avec weights_only
if hasattr(torch.serialization, 'add_safe_globals'):
try:
torch.serialization.add_safe_globals([torch.torch_version.TorchVersion])
except:
pass
def load_pyannote_pipeline(
model_name: str = "pyannote/speaker-diarization-community-1",
token: str = None
) -> Pipeline:
"""
Charge le pipeline de diarisation pyannote.
Args:
model_name: Nom du modèle Hugging Face
- "pyannote/speaker-diarization-community-1" (défaut, meilleur que 3.1)
- "pyannote/speaker-diarization-3.1" (legacy)
- "pyannote/speaker-diarization-precision-2" (nécessite API key pyannoteAI)
token: Token d'authentification (HF_TOKEN ou API key pyannoteAI)
Returns:
Pipeline pyannote configuré
"""
print(f"Chargement du pipeline pyannote: {model_name}")
# Déterminer le token à utiliser
if token is None:
# Pour precision-2, utiliser l'API key pyannoteAI si disponible
if "precision-2" in model_name:
token = os.environ.get("PYANNOTEAI_API_KEY") or os.environ.get("HF_TOKEN")
else:
token = os.environ.get("HF_TOKEN")
# Configurer le token dans huggingface_hub si disponible
if token:
try:
from huggingface_hub import login
login(token=token, add_to_git_credential=False)
except Exception:
# Si login échoue, on essaiera quand même avec use_auth_token
pass
if not token:
print("ATTENTION: Token d'authentification non défini.")
if "precision-2" in model_name:
print("Pour precision-2, définir: export PYANNOTEAI_API_KEY='votre_api_key'")
else:
print("Définir: export HF_TOKEN='votre_token'")
print("Note: Le script fonctionnera mais le téléchargement du modèle peut échouer.")
try:
# Ne pas passer use_auth_token car il cause des erreurs avec les nouvelles versions
# Le token est déjà configuré via huggingface_hub.login() si disponible
pipeline = Pipeline.from_pretrained(model_name)
# Déplacer sur GPU si disponible
if torch.cuda.is_available():
pipeline = pipeline.to(torch.device("cuda"))
print("Pipeline chargé sur GPU")
else:
print("Pipeline chargé sur CPU")
return pipeline
except Exception as e:
print(f"ERREUR lors du chargement du pipeline: {e}")
print("\nSolutions possibles:")
print("1. Vérifier que vous avez accepté les conditions d'utilisation sur Hugging Face")
print("2. Configurer un token: export HF_TOKEN='votre_token'")
if "precision-2" in model_name:
print("3. Pour precision-2, créer une API key sur pyannoteAI dashboard")
print("4. Vérifier votre connexion internet")
sys.exit(1)
def convert_audio_if_needed(audio_path: str) -> str:
"""
Convertit l'audio en WAV si nécessaire (pour les formats non supportés).
Args:
audio_path: Chemin vers le fichier audio
Returns:
Chemin vers le fichier audio (converti si nécessaire)
"""
ext = Path(audio_path).suffix.lower()
# Formats supportés directement par pyannote
supported_formats = {'.wav', '.flac', '.ogg'}
if ext in supported_formats:
return audio_path
# Convertir en WAV si nécessaire
if ext in {'.m4a', '.mp3', '.mp4', '.aac'}:
print(f"Conversion de {ext} en WAV...")
import librosa
import soundfile as sf
wav_path = str(Path(audio_path).with_suffix('.wav'))
# Vérifier si le fichier WAV existe déjà
if os.path.exists(wav_path):
print(f"Fichier WAV existant trouvé: {wav_path}")
return wav_path
try:
y, sr = librosa.load(audio_path, sr=16000, mono=True)
sf.write(wav_path, y, sr)
print(f"✅ Converti en WAV: {wav_path}")
return wav_path
except Exception as e:
print(f"ATTENTION: Erreur lors de la conversion, utilisation du fichier original: {e}")
return audio_path
return audio_path
def run_pyannote_diarization(
audio_path: str,
output_dir: str = "outputs/pyannote",
model_name: str = "pyannote/speaker-diarization-community-1",
num_speakers: int = None,
min_speakers: int = None,
max_speakers: int = None,
use_exclusive: bool = False,
show_progress: bool = True
) -> Dict[str, Any]:
"""
Exécute le pipeline de diarisation pyannote.
Args:
audio_path: Chemin vers le fichier audio
output_dir: Répertoire de sortie
model_name: Nom du modèle à utiliser
num_speakers: Nombre exact de locuteurs (si connu)
min_speakers: Nombre minimum de locuteurs
max_speakers: Nombre maximum de locuteurs
use_exclusive: Utiliser exclusive_speaker_diarization (Community-1+)
show_progress: Afficher la progression
Returns:
Dictionnaire contenant les résultats de diarisation
"""
# Convertir l'audio si nécessaire
audio_path = convert_audio_if_needed(audio_path)
print(f"Chargement de l'audio: {audio_path}")
# Créer le répertoire de sortie si nécessaire
os.makedirs(output_dir, exist_ok=True)
# Charger le pipeline
pipeline = load_pyannote_pipeline(model_name)
# Préparer les options de diarisation
diarization_options = {}
if num_speakers is not None:
diarization_options["num_speakers"] = num_speakers
print(f"Nombre de locuteurs fixé: {num_speakers}")
if min_speakers is not None:
diarization_options["min_speakers"] = min_speakers
print(f"Nombre minimum de locuteurs: {min_speakers}")
if max_speakers is not None:
diarization_options["max_speakers"] = max_speakers
print(f"Nombre maximum de locuteurs: {max_speakers}")
# Exécuter la diarisation
print("Exécution de la diarisation...")
try:
if show_progress and HAS_PROGRESS_HOOK:
with ProgressHook() as hook:
diarization = pipeline(audio_path, hook=hook, **diarization_options)
else:
diarization = pipeline(audio_path, **diarization_options)
except Exception as e:
print(f"ERREUR lors de la diarisation: {e}")
sys.exit(1)
# Utiliser exclusive_speaker_diarization si disponible et demandé
if use_exclusive and hasattr(diarization, 'exclusive_speaker_diarization'):
print("Utilisation de exclusive_speaker_diarization")
annotation = diarization.exclusive_speaker_diarization
else:
annotation = diarization
# Convertir l'annotation pyannote en format standard
segments = annotation_to_segments(annotation)
# Calculer les statistiques
num_speakers_detected = len(set(s["speaker"] for s in segments))
# Calculer la durée totale
if segments:
duration = max(s["end"] for s in segments)
else:
duration = 0.0
return {
"segments": segments,
"num_speakers": num_speakers_detected,
"duration": duration
}
def annotation_to_segments(annotation: Annotation) -> List[Dict[str, Any]]:
"""
Convertit une annotation pyannote en liste de segments.
Args:
annotation: Annotation pyannote
Returns:
Liste de segments au format [{"speaker": "...", "start": ..., "end": ...}]
"""
segments = []
# Obtenir tous les locuteurs uniques
speakers = sorted(annotation.labels())
# Créer un mapping pour normaliser les IDs
speaker_mapping = {}
for idx, speaker in enumerate(speakers):
speaker_mapping[speaker] = f"SPEAKER_{idx:02d}"
# Parcourir tous les segments
for segment, track, speaker in annotation.itertracks(yield_label=True):
normalized_speaker = speaker_mapping.get(speaker, speaker)
segments.append({
"speaker": normalized_speaker,
"start": round(segment.start, 2),
"end": round(segment.end, 2)
})
# Trier par temps de début
segments.sort(key=lambda x: x["start"])
return segments
def write_rttm(segments: List[Dict[str, Any]], output_path: str, audio_name: str):
"""
Écrit un fichier RTTM à partir des segments.
Args:
segments: Liste de segments
output_path: Chemin du fichier RTTM de sortie
audio_name: Nom du fichier audio (sans extension)
"""
with open(output_path, 'w') as f:
for seg in segments:
duration = seg["end"] - seg["start"]
# Format RTTM: SPEAKER <file> 1 <start> <duration> <NA> <NA> <speaker_id> <NA> <NA>
f.write(f"SPEAKER {audio_name} 1 {seg['start']:.3f} {duration:.3f} <NA> <NA> {seg['speaker']} <NA> <NA>\n")
def write_json(segments: List[Dict[str, Any]], output_path: str):
"""
Écrit un fichier JSON à partir des segments.
Args:
segments: Liste de segments
output_path: Chemin du fichier JSON de sortie
"""
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(segments, f, indent=2, ensure_ascii=False)
def main():
parser = argparse.ArgumentParser(
description="Diarisation avec pyannote.audio 3.x",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog=__doc__
)
parser.add_argument(
"audio_path",
type=str,
help="Chemin vers le fichier audio"
)
parser.add_argument(
"--output_dir",
type=str,
default="outputs/pyannote",
help="Répertoire de sortie (défaut: outputs/pyannote)"
)
parser.add_argument(
"--model",
type=str,
default="pyannote/speaker-diarization-community-1",
help="Nom du modèle Hugging Face (défaut: pyannote/speaker-diarization-community-1). "
"Options: community-1, 3.1, precision-2 (nécessite API key pyannoteAI)"
)
parser.add_argument(
"--num_speakers",
type=int,
default=None,
help="Nombre exact de locuteurs (si connu à l'avance)"
)
parser.add_argument(
"--min_speakers",
type=int,
default=None,
help="Nombre minimum de locuteurs"
)
parser.add_argument(
"--max_speakers",
type=int,
default=None,
help="Nombre maximum de locuteurs"
)
parser.add_argument(
"--exclusive",
action="store_true",
help="Utiliser exclusive_speaker_diarization (Community-1+, simplifie la réconciliation avec transcription)"
)
parser.add_argument(
"--no-progress",
action="store_true",
help="Ne pas afficher la barre de progression"
)
args = parser.parse_args()
if not os.path.exists(args.audio_path):
print(f"ERREUR: Fichier audio introuvable: {args.audio_path}")
sys.exit(1)
# Normaliser le nom du modèle si version courte fournie
model_name = args.model
if model_name == "community-1":
model_name = "pyannote/speaker-diarization-community-1"
elif model_name == "3.1":
model_name = "pyannote/speaker-diarization-3.1"
elif model_name == "precision-2":
model_name = "pyannote/speaker-diarization-precision-2"
# Exécuter la diarisation
results = run_pyannote_diarization(
args.audio_path,
args.output_dir,
model_name,
num_speakers=args.num_speakers,
min_speakers=args.min_speakers,
max_speakers=args.max_speakers,
use_exclusive=args.exclusive,
show_progress=not args.no_progress
)
# Préparer les chemins de sortie
audio_name = Path(args.audio_path).stem
rttm_path = os.path.join(args.output_dir, f"{audio_name}.rttm")
json_path = os.path.join(args.output_dir, f"{audio_name}.json")
# Écrire les fichiers de sortie
write_rttm(results["segments"], rttm_path, audio_name)
write_json(results["segments"], json_path)
# Afficher les statistiques
print("\n" + "="*50)
print("RÉSULTATS DE LA DIARISATION")
print("="*50)
print(f"Nombre de locuteurs détectés: {results['num_speakers']}")
print(f"Durée totale: {results['duration']:.2f} secondes")
print(f"Nombre de segments: {len(results['segments'])}")
# Statistiques par locuteur
speaker_stats = {}
for seg in results["segments"]:
speaker = seg["speaker"]
duration = seg["end"] - seg["start"]
if speaker not in speaker_stats:
speaker_stats[speaker] = {"total_duration": 0.0, "num_segments": 0}
speaker_stats[speaker]["total_duration"] += duration
speaker_stats[speaker]["num_segments"] += 1
print("\nStatistiques par locuteur:")
for speaker, stats in sorted(speaker_stats.items()):
avg_duration = stats["total_duration"] / stats["num_segments"] if stats["num_segments"] > 0 else 0
print(f" {speaker}: {stats['num_segments']} segments, "
f"{stats['total_duration']:.2f}s total, "
f"{avg_duration:.2f}s moyenne/segment")
print(f"\nFichiers générés:")
print(f" RTTM: {rttm_path}")
print(f" JSON: {json_path}")
if __name__ == "__main__":
main()