"""Evaluate a fine-tuned pyannote segmentation model on a test split.

The script builds a ``SpeakerDiarization`` pipeline around the checkpoint at
``CHECKPOINT_PATH``, runs it on every file listed in ``TEST_LIST_PATH``,
accumulates the Diarization Error Rate (DER) against the reference RTTM files,
prints a global report and writes the per-file breakdown to ``OUTPUT_CSV``.
"""
import os

# MUST BE AT THE VERY TOP
# (set before the audio/model libraries below are imported)
os.environ["SPEECHBRAIN_LOCAL_STRATEGY"] = "copy"

import torch
import torchaudio
import pandas as pd
from pyannote.audio import Model
from pyannote.audio.pipelines import SpeakerDiarization
from pyannote.database.util import load_rttm
from pyannote.metrics.diarization import DiarizationErrorRate, DiarizationPurity, DiarizationCoverage

# --- THE DEFINITIVE FIX FOR PYTORCH 2.6+ SECURITY ERRORS ---
import torch.serialization

original_load = torch.load


def forced_load(f, map_location=None, pickle_module=None, **kwargs):
    """Call the original ``torch.load`` with ``weights_only`` forced to False.

    Any ``weights_only`` value supplied by the caller is overridden; all other
    arguments are forwarded unchanged.

    SECURITY NOTE: ``weights_only=False`` allows full unpickling, which can
    execute arbitrary code embedded in a checkpoint. Because this wrapper
    replaces ``torch.load`` for the whole process, only load checkpoints from
    sources you trust.
    """
    kwargs['weights_only'] = False
    return original_load(f, map_location=map_location, pickle_module=pickle_module, **kwargs)


torch.load = forced_load
# -----------------------------------------------------------

# Configuration - Update these paths to match your project structure
CHECKPOINT_PATH = "training_results/lightning_logs/version_2/checkpoints/epoch=4-step=2960.ckpt"
TEST_LIST_PATH = "dataset/splits/test.txt"
AUDIO_DIR = "dataset/audio"
RTTM_DIR = "dataset/rttm"
OUTPUT_CSV = "overall_model_performance.csv"


def run_global_evaluation() -> None:
    """Run diarization on the whole test split and report the global DER.

    Files whose audio is missing or whose reference RTTM cannot be loaded are
    skipped with a warning. If no file could be evaluated at all, the function
    prints a warning and returns without producing a report or a CSV, because
    a score computed from zero files would be meaningless.
    """
    # 1. Load the fine-tuned model
    print(f"Loading fine-tuned model from: {CHECKPOINT_PATH}")
    seg_model = Model.from_pretrained(CHECKPOINT_PATH)

    # 2. Initialize the Diarization Pipeline
    print("Initializing Pipeline...")
    pipeline = SpeakerDiarization(
        segmentation=seg_model,
        embedding="speechbrain/spkrec-ecapa-voxceleb",
        clustering="AgglomerativeClustering",
    )

    # Balanced parameters for diverse speaker counts
    params = {
        "segmentation": {
            "threshold": 0.58,  # High threshold to kill False Alarms
            "min_duration_off": 0.2,  # Prevents fragmented "flickering" between speakers
        },
        "clustering": {
            "method": "centroid",
            "threshold": 0.62,  # Lower threshold to encourage speaker separation
            "min_cluster_size": 1,
        },
    }
    pipeline.instantiate(params)

    # 3. Initialize Metrics
    # Using 'total' metrics to accumulate across all files
    total_der_metric = DiarizationErrorRate()

    # 4. Load filenames from test.txt
    with open(TEST_LIST_PATH, 'r') as f:
        # Take the first whitespace-separated token of each non-blank line as the URI.
        # Adjust the split logic if your test.txt has a different format.
        test_files = [line.strip().split()[0] for line in f if line.strip()]

    print(f"Found {len(test_files)} files in test set. Starting Batch Processing...")
    print("-" * 50)

    evaluated_count = 0  # files that actually contributed to the metric

    for uri in test_files:
        audio_path = os.path.join(AUDIO_DIR, f"{uri}.wav")
        rttm_path = os.path.join(RTTM_DIR, f"{uri}.rttm")

        if not os.path.exists(audio_path):
            print(f"Warning: Audio file not found for {uri}. Skipping.")
            continue

        # Load Reference RTTM
        # Deliberately best-effort: any failure here (missing file, URI not
        # present in the RTTM, parse error) skips the file instead of aborting
        # the whole batch.
        try:
            reference = load_rttm(rttm_path)[uri]
        except Exception as e:
            print(f"Warning: Could not load RTTM for {uri}. Error: {e}")
            continue

        # Run Diarization
        waveform, sample_rate = torchaudio.load(audio_path)
        test_file = {"waveform": waveform, "sample_rate": sample_rate, "uri": uri}

        # We allow the AI to determine speaker count dynamically (min 2, max 7)
        hypothesis = pipeline(test_file, min_speakers=2, max_speakers=7)

        # Accumulate the metric
        total_der_metric(reference, hypothesis, detailed=True)
        evaluated_count += 1
        print(f"Done: {uri}")

    # Without at least one evaluated file the accumulated metric is empty and
    # any "accuracy" derived from it would be misleading, so stop here.
    if evaluated_count == 0:
        print("Warning: No files were evaluated. Skipping report generation.")
        return

    # 5. Final Calculations
    print("\n" + "="*50)
    print(" FINAL GLOBAL REPORT")
    print("="*50)

    # This creates a detailed table per file
    report_df = total_der_metric.report(display=True)

    # Global DER is the value of the metric after processing all files
    global_der = abs(total_der_metric)
    # DER can exceed 1.0, so clamp the derived accuracy at 0.
    global_accuracy = max(0, (1 - global_der) * 100)

    print(f"\nOVERALL SYSTEM ACCURACY : {global_accuracy:.2f}%")
    print(f"GLOBAL DIARIZATION ERROR: {global_der * 100:.2f}%")
    print("="*50)

    # Save detailed report to CSV for your documentation
    report_df.to_csv(OUTPUT_CSV)
    print(f"Detailed file-by-file breakdown saved to: {OUTPUT_CSV}")


if __name__ == "__main__":
    run_global_evaluation()