# gilbert-pyannote-diarization / diarization_pyannote_gilbert.py
# Uploaded by MEscriva via huggingface_hub (commit c62be5a, verified)
#!/usr/bin/env python3
"""
Gilbert - Modèle de diarisation pyannote personnalisé
Version propriétaire optimisée pour le projet Gilbert avec :
- Post-traitement intelligent des segments
- Fusion des segments courts
- Détection d'overlap améliorée
- Configuration optimisée pour les réunions
- Statistiques avancées
Usage:
python diarization_pyannote_gilbert.py <input_audio.wav> [--output_dir OUTPUT_DIR]
"""
import argparse
import json
import os
import sys
from pathlib import Path
from typing import List, Dict, Any
from collections import defaultdict
# Import pyannote lazily so a missing installation produces a clear,
# actionable message instead of a raw traceback.
try:
    # Disable the optional NeMo backend before importing pyannote; it is
    # not needed here and can fail to load in some environments.
    os.environ['PYANNOTE_DISABLE_NEMO'] = '1'
    from pyannote.audio import Pipeline
    from pyannote.core import Annotation
    try:
        # ProgressHook only exists in recent pyannote.audio releases.
        from pyannote.audio.pipelines.utils.hook import ProgressHook
        HAS_PROGRESS_HOOK = True
    except ImportError:
        HAS_PROGRESS_HOOK = False
except ImportError as e:
    # Include the underlying import error (previously captured but unused)
    # so the user can see exactly what failed to load.
    print(f"ERREUR: pyannote.audio n'est pas installé. ({e})")
    sys.exit(1)
import torch
# Work around PyTorch >= 2.6 defaulting torch.load to weights_only=True:
# allow-list TorchVersion so pyannote checkpoints can still be deserialized.
if hasattr(torch.serialization, 'add_safe_globals'):
    try:
        torch.serialization.add_safe_globals([torch.torch_version.TorchVersion])
    except Exception:
        # Best-effort only: some torch builds lack torch_version; a failure
        # here must not prevent the script from running. Narrowed from a
        # bare `except:` so SystemExit/KeyboardInterrupt are not swallowed.
        pass
def load_gilbert_pipeline(
    model_name: str = "pyannote/speaker-diarization-3.1",
    token: str = None
) -> Pipeline:
    """
    Load the pyannote pipeline with Gilbert's meeting-oriented setup.

    Args:
        model_name: Hugging Face model identifier.
        token: Authentication token; falls back to the HF_TOKEN env var.

    Returns:
        The configured pyannote Pipeline.

    Raises:
        Exception: re-raised from Pipeline.from_pretrained on failure.
    """
    print(f"[Gilbert] Chargement du pipeline: {model_name}")
    # Resolve the token from the environment when not given explicitly.
    if token is None:
        token = os.environ.get("HF_TOKEN")
    if token:
        # Best-effort login; loading may still succeed without it.
        try:
            from huggingface_hub import login
            login(token=token, add_to_git_credential=False)
        except Exception:
            pass
    else:
        print("[Gilbert] ATTENTION: Token non défini. Certains modèles peuvent nécessiter un token.")
    try:
        pipeline = Pipeline.from_pretrained(model_name)
        # Meeting-oriented tuning hook (currently a no-op): longer segments
        # and less fragmentation could be configured here.
        if hasattr(pipeline, 'segmentation'):
            pass
        print("[Gilbert] Pipeline chargé et optimisé pour les réunions")
        return pipeline
    except Exception as e:
        print(f"[Gilbert] ERREUR lors du chargement: {e}")
        raise
def post_process_segments_gilbert(
    segments: List[Dict[str, Any]],
    min_segment_duration: float = 0.5,
    max_gap: float = 0.3
) -> List[Dict[str, Any]]:
    """
    Intelligent post-processing of diarization segments for Gilbert.

    - Drops segments shorter than ``min_segment_duration``
    - Merges consecutive same-speaker segments separated by at most
      ``max_gap`` seconds
    - Tuned for meeting recordings

    Args:
        segments: Raw segments (dicts with "start", "end", "speaker").
        min_segment_duration: Minimum segment duration in seconds.
        max_gap: Maximum gap (seconds) between same-speaker segments to merge.

    Returns:
        Post-processed segments, sorted by start time.
    """
    if not segments:
        return segments
    # Process in chronological order.
    segments = sorted(segments, key=lambda x: x["start"])
    # Walk the timeline, merging close same-speaker segments.
    processed = []
    current_segment = None
    for seg in segments:
        duration = seg["end"] - seg["start"]
        # Drop segments too short to be meaningful speech.
        if duration < min_segment_duration:
            continue
        if current_segment is None:
            current_segment = seg.copy()
        elif (current_segment["speaker"] == seg["speaker"] and
              seg["start"] - current_segment["end"] <= max_gap):
            # Merge into the previous segment. Use max() so a segment fully
            # contained in (or overlapping) the current one cannot SHRINK
            # the merged end (the original assignment could).
            current_segment["end"] = max(current_segment["end"], seg["end"])
        else:
            # Close the current segment and start a new one.
            processed.append(current_segment)
            current_segment = seg.copy()
    # Flush the last open segment.
    if current_segment:
        processed.append(current_segment)
    return processed
def detect_overlaps_gilbert(segments: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Detect overlapping speech between distinct speakers (Gilbert).

    Args:
        segments: Segments (dicts with "start", "end", "speaker").

    Returns:
        One record per overlapping pair, each with "start", "end",
        "speakers" and "duration" keys.
    """
    # Flatten segments into (start, end, speaker) spans.
    spans = [(seg["start"], seg["end"], seg["speaker"]) for seg in segments]
    overlaps = []
    # Examine each unordered pair of spans exactly once.
    for i, (start_a, end_a, spk_a) in enumerate(spans):
        for start_b, end_b, spk_b in spans[i + 1:]:
            # The same speaker overlapping themselves is not an event.
            if spk_a == spk_b:
                continue
            # Intersection of the two time spans.
            begin = max(start_a, start_b)
            finish = min(end_a, end_b)
            if begin < finish:
                overlaps.append({
                    "start": begin,
                    "end": finish,
                    "speakers": [spk_a, spk_b],
                    "duration": finish - begin
                })
    return overlaps
def convert_audio_if_needed(audio_path: str) -> str:
    """Convert the audio to WAV when its format requires it.

    Returns the path of a pyannote-friendly file: the original path for
    natively supported formats (and unknown extensions, or when the
    conversion fails), otherwise the converted 16 kHz mono WAV path.
    """
    ext = Path(audio_path).suffix.lower()
    # Formats pyannote reads directly: nothing to do.
    if ext in {'.wav', '.flac', '.ogg'}:
        return audio_path
    # Unknown extensions are handed over unchanged.
    if ext not in {'.m4a', '.mp3', '.mp4', '.aac'}:
        return audio_path
    print(f"[Gilbert] Conversion de {ext} en WAV...")
    import librosa
    import soundfile as sf
    wav_path = str(Path(audio_path).with_suffix('.wav'))
    # Reuse a previously converted file when one exists.
    if os.path.exists(wav_path):
        print(f"[Gilbert] Fichier WAV existant: {wav_path}")
        return wav_path
    try:
        # 16 kHz mono is what the diarization pipeline expects.
        y, sr = librosa.load(audio_path, sr=16000, mono=True)
        sf.write(wav_path, y, sr)
        print(f"[Gilbert] ✅ Converti en WAV: {wav_path}")
        return wav_path
    except Exception as e:
        # Best effort: on failure, fall back to the original file.
        print(f"[Gilbert] ATTENTION: Erreur conversion: {e}")
        return audio_path
def run_gilbert_diarization(
    audio_path: str,
    output_dir: str = "outputs/gilbert",
    model_name: str = "pyannote/speaker-diarization-3.1",
    num_speakers: int = None,
    min_speakers: int = None,
    max_speakers: int = None,
    use_exclusive: bool = False,
    show_progress: bool = True,
    min_segment_duration: float = 0.5,
    merge_gaps: float = 0.3
) -> Dict[str, Any]:
    """
    Run the Gilbert speaker diarization (customised version).

    Args:
        audio_path: Path to the input audio file.
        output_dir: Output directory (created if missing).
        model_name: Hugging Face model name.
        num_speakers: Exact number of speakers, when known.
        min_speakers: Minimum number of speakers.
        max_speakers: Maximum number of speakers.
        use_exclusive: Use exclusive_speaker_diarization when available.
        show_progress: Display a progress bar (requires ProgressHook).
        min_segment_duration: Minimum segment duration for post-processing
            (seconds; 0 disables the short-segment filter).
        merge_gaps: Maximum same-speaker gap merged during post-processing
            (seconds; 0 disables merging).

    Returns:
        Result dict with processed segments, raw segments, overlaps,
        detected speaker count, total duration, per-speaker statistics,
        model name and version tag.
    """
    print("="*70)
    print("GILBERT - Diarisation de locuteurs (version propriétaire)")
    print("="*70)
    # Convert the audio to a supported format first, if needed.
    audio_path = convert_audio_if_needed(audio_path)
    print(f"[Gilbert] Audio: {audio_path}")
    os.makedirs(output_dir, exist_ok=True)
    # Load the pyannote pipeline (meeting-optimised configuration).
    pipeline = load_gilbert_pipeline(model_name)
    # Speaker-count constraints forwarded to the pipeline call.
    diarization_options = {}
    if num_speakers is not None:
        diarization_options["num_speakers"] = num_speakers
        print(f"[Gilbert] Nombre de locuteurs: {num_speakers}")
    if min_speakers is not None:
        diarization_options["min_speakers"] = min_speakers
    if max_speakers is not None:
        diarization_options["max_speakers"] = max_speakers
    # Run the diarization itself.
    print("[Gilbert] Exécution de la diarisation...")
    try:
        if show_progress and HAS_PROGRESS_HOOK:
            with ProgressHook() as hook:
                diarization = pipeline(audio_path, hook=hook, **diarization_options)
        else:
            diarization = pipeline(audio_path, **diarization_options)
    except Exception as e:
        print(f"[Gilbert] ERREUR: {e}")
        sys.exit(1)
    # Prefer the exclusive (non-overlapping) view when requested and available.
    if use_exclusive and hasattr(diarization, 'exclusive_speaker_diarization'):
        print("[Gilbert] Utilisation de exclusive_speaker_diarization")
        annotation = diarization.exclusive_speaker_diarization
    else:
        annotation = diarization
    # Convert the pyannote annotation into plain segment dicts.
    segments = annotation_to_segments(annotation)
    # Gilbert post-processing (optional - skipped when both knobs are 0).
    if min_segment_duration > 0 or merge_gaps > 0:
        print("[Gilbert] Post-traitement intelligent des segments...")
        segments_processed = post_process_segments_gilbert(
            segments,
            min_segment_duration=min_segment_duration,
            max_gap=merge_gaps
        )
    else:
        print("[Gilbert] Post-traitement désactivé (segments bruts conservés)")
        segments_processed = segments
    # Detect overlapping speech between speakers.
    overlaps = detect_overlaps_gilbert(segments_processed)
    # Global statistics.
    num_speakers_detected = len(set(s["speaker"] for s in segments_processed))
    duration = max(s["end"] for s in segments_processed) if segments_processed else 0.0
    # Advanced per-speaker statistics.
    speaker_stats = compute_gilbert_stats(segments_processed, overlaps)
    print(f"[Gilbert] ✅ Diarisation terminée")
    print(f"[Gilbert] Locuteurs: {num_speakers_detected}")
    print(f"[Gilbert] Segments: {len(segments)}{len(segments_processed)} (après post-traitement)")
    print(f"[Gilbert] Overlaps détectés: {len(overlaps)}")
    return {
        "segments": segments_processed,
        "segments_raw": segments,  # Raw segments before post-processing
        "overlaps": overlaps,
        "num_speakers": num_speakers_detected,
        "duration": duration,
        "stats": speaker_stats,
        "model": model_name,
        "version": "Gilbert-v1.0"
    }
def annotation_to_segments(annotation: Annotation) -> List[Dict[str, Any]]:
    """Convert a pyannote annotation into a sorted list of segment dicts.

    Speaker labels are renamed to stable SPEAKER_00, SPEAKER_01, ... codes,
    assigned in sorted label order.
    """
    # Deterministic speaker naming: sort the labels before numbering.
    label_to_code = {
        label: f"SPEAKER_{idx:02d}"
        for idx, label in enumerate(sorted(annotation.labels()))
    }
    segments = [
        {
            "speaker": label_to_code.get(label, label),
            "start": round(segment.start, 2),
            "end": round(segment.end, 2)
        }
        for segment, track, label in annotation.itertracks(yield_label=True)
    ]
    segments.sort(key=lambda s: s["start"])
    return segments
def compute_gilbert_stats(
    segments: List[Dict[str, Any]],
    overlaps: List[Dict[str, Any]]
) -> Dict[str, Any]:
    """Compute advanced per-speaker statistics for Gilbert.

    Returns a plain dict mapping each speaker to total speaking time,
    segment count, mean segment duration and accumulated overlap time.
    """
    def _fresh_entry():
        # Default statistics record for a speaker seen for the first time.
        return {
            "total_duration": 0.0,
            "num_segments": 0,
            "avg_segment_duration": 0.0,
            "overlap_duration": 0.0
        }
    stats = defaultdict(_fresh_entry)
    # Accumulate speaking time and segment counts per speaker.
    for seg in segments:
        entry = stats[seg["speaker"]]
        entry["total_duration"] += seg["end"] - seg["start"]
        entry["num_segments"] += 1
    # Derive the mean segment duration for each speaker.
    for entry in stats.values():
        if entry["num_segments"] > 0:
            entry["avg_segment_duration"] = (
                entry["total_duration"] / entry["num_segments"]
            )
    # Attribute each overlap's duration to every speaker involved.
    for overlap in overlaps:
        for speaker in overlap["speakers"]:
            stats[speaker]["overlap_duration"] += overlap["duration"]
    return dict(stats)
def write_rttm(segments: List[Dict[str, Any]], output_path: str, audio_name: str):
    """Write the segments to an RTTM file (one SPEAKER line per segment)."""
    lines = []
    for seg in segments:
        length = seg["end"] - seg["start"]
        lines.append(
            f"SPEAKER {audio_name} 1 {seg['start']:.3f} {length:.3f} <NA> <NA> {seg['speaker']} <NA> <NA>\n"
        )
    with open(output_path, 'w') as f:
        f.writelines(lines)
def write_json(segments: List[Dict[str, Any]], output_path: str):
    """Serialise the segments to a pretty-printed UTF-8 JSON file."""
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(json.dumps(segments, indent=2, ensure_ascii=False))
def main():
    """CLI entry point: parse arguments, run diarization, write RTTM/JSON/stats."""
    parser = argparse.ArgumentParser(
        description="Gilbert - Diarisation pyannote personnalisée",
        formatter_class=argparse.RawDescriptionHelpFormatter
    )
    parser.add_argument("audio_path", type=str, help="Chemin vers le fichier audio")
    parser.add_argument("--output_dir", type=str, default="outputs/gilbert", help="Répertoire de sortie")
    parser.add_argument("--model", type=str, default="pyannote/speaker-diarization-3.1", help="Modèle pyannote")
    parser.add_argument("--num_speakers", type=int, default=None, help="Nombre exact de locuteurs")
    parser.add_argument("--min_speakers", type=int, default=None, help="Nombre minimum de locuteurs")
    parser.add_argument("--max_speakers", type=int, default=None, help="Nombre maximum de locuteurs")
    parser.add_argument("--exclusive", action="store_true", help="Utiliser exclusive_speaker_diarization")
    parser.add_argument("--no-progress", action="store_true", help="Ne pas afficher la progression")
    parser.add_argument("--min-segment", type=float, default=0.0, help="Durée minimale des segments (s). 0 = désactivé (recommandé pour meilleure précision)")
    parser.add_argument("--merge-gaps", type=float, default=0.0, help="Gaps à fusionner (s). 0 = désactivé (recommandé pour meilleure précision)")
    args = parser.parse_args()
    # Fail fast on a missing input file.
    if not os.path.exists(args.audio_path):
        print(f"[Gilbert] ERREUR: Fichier introuvable: {args.audio_path}")
        sys.exit(1)
    # Run the diarization.
    results = run_gilbert_diarization(
        args.audio_path,
        args.output_dir,
        args.model,
        args.num_speakers,
        args.min_speakers,
        args.max_speakers,
        args.exclusive,
        not args.no_progress,
        args.min_segment,
        args.merge_gaps
    )
    # Build the output file paths from the audio file's base name.
    audio_name = Path(args.audio_path).stem
    rttm_path = os.path.join(args.output_dir, f"{audio_name}.rttm")
    json_path = os.path.join(args.output_dir, f"{audio_name}.json")
    stats_path = os.path.join(args.output_dir, f"{audio_name}_stats.json")
    # Write the RTTM and JSON segment files.
    write_rttm(results["segments"], rttm_path, audio_name)
    write_json(results["segments"], json_path)
    # Write the statistics file.
    with open(stats_path, 'w', encoding='utf-8') as f:
        json.dump({
            "version": results["version"],
            "model": results["model"],
            "num_speakers": results["num_speakers"],
            "duration": results["duration"],
            "num_segments": len(results["segments"]),
            "num_segments_raw": len(results["segments_raw"]),
            "num_overlaps": len(results["overlaps"]),
            "speaker_stats": results["stats"]
        }, f, indent=2, ensure_ascii=False)
    # Print the summary.
    print("\n" + "="*70)
    print("STATISTIQUES GILBERT")
    print("="*70)
    print(f"Version: {results['version']}")
    print(f"Modèle: {results['model']}")
    print(f"Locuteurs: {results['num_speakers']}")
    print(f"Segments: {len(results['segments_raw'])}{len(results['segments'])} (post-traitement)")
    print(f"Overlaps: {len(results['overlaps'])}")
    print(f"\nStatistiques par locuteur:")
    for speaker, stats in sorted(results["stats"].items()):
        print(f"  {speaker}: {stats['num_segments']} segments, "
              f"{stats['total_duration']:.2f}s total, "
              f"{stats['avg_segment_duration']:.2f}s moyenne, "
              f"{stats['overlap_duration']:.2f}s overlap")
    print(f"\nFichiers générés:")
    print(f"  RTTM: {rttm_path}")
    print(f"  JSON: {json_path}")
    print(f"  Stats: {stats_path}")
# Standard script entry point: run the CLI only when executed directly.
if __name__ == "__main__":
    main()