# Uploaded by AnamikaP - commit 9f76952 ("Upload 18 files", verified).
import os
# MUST BE AT THE VERY TOP
os.environ["SPEECHBRAIN_LOCAL_STRATEGY"] = "copy"
import torch
import torchaudio
import pandas as pd
from pyannote.audio import Model
from pyannote.audio.pipelines import SpeakerDiarization
from pyannote.database.util import load_rttm
from pyannote.metrics.diarization import DiarizationErrorRate, DiarizationPurity, DiarizationCoverage
# --- THE DEFINITIVE FIX FOR PYTORCH 2.6+ SECURITY ERRORS ---
# PyTorch 2.6 made weights_only=True the default for torch.load, which refuses
# to unpickle arbitrary Python objects. This patch replaces torch.load
# process-wide so that EVERY call (including calls made inside third-party
# libraries that load checkpoints) runs with weights_only=False, even when
# the caller passed a different value explicitly.
#
# SECURITY: weights_only=False performs full pickle deserialization, which can
# execute arbitrary code. Only load checkpoints/models from trusted sources
# while this patch is active.
import functools
import torch.serialization

# Keep a handle on the real implementation so the wrapper can delegate to it.
original_load = torch.load


@functools.wraps(original_load)
def forced_load(f, map_location=None, pickle_module=None, **kwargs):
    """Delegate to the original ``torch.load`` with ``weights_only`` forced to False.

    Args:
        f: file path or file-like object, passed through unchanged.
        map_location: passed through unchanged.
        pickle_module: passed through unchanged.
        **kwargs: remaining keyword arguments for ``torch.load``; any
            ``weights_only`` value supplied by the caller is overwritten.

    Returns:
        Whatever the original ``torch.load`` returns.
    """
    # Overwrite unconditionally: some callers pass weights_only explicitly,
    # and honouring that would bring back the load failure this patch fixes.
    kwargs["weights_only"] = False
    return original_load(
        f, map_location=map_location, pickle_module=pickle_module, **kwargs
    )


torch.load = forced_load
# -----------------------------------------------------------
# ---------------------------------------------------------------------------
# Configuration - adjust these paths to match your project layout.
# ---------------------------------------------------------------------------

# Fine-tuned segmentation checkpoint loaded by the evaluation pipeline.
CHECKPOINT_PATH: str = "training_results/lightning_logs/version_2/checkpoints/epoch=4-step=2960.ckpt"

# Directory containing one "<uri>.wav" audio file per test recording.
AUDIO_DIR: str = "dataset/audio"
# Directory containing one "<uri>.rttm" reference annotation per recording.
RTTM_DIR: str = "dataset/rttm"
# Test split: one recording per line; the first whitespace-separated field is the URI.
TEST_LIST_PATH: str = "dataset/splits/test.txt"

# Where the per-file DER breakdown is written as CSV.
OUTPUT_CSV: str = "overall_model_performance.csv"
def _build_pipeline():
    """Load the fine-tuned segmentation model and return an instantiated diarization pipeline.

    The checkpoint at ``CHECKPOINT_PATH`` is combined with the
    "speechbrain/spkrec-ecapa-voxceleb" embedding model and agglomerative
    clustering, then the hand-tuned hyper-parameters below are applied.

    Raises:
        RuntimeError: if the segmentation checkpoint could not be loaded.
    """
    # 1. Load the fine-tuned model
    print(f"Loading fine-tuned model from: {CHECKPOINT_PATH}")
    seg_model = Model.from_pretrained(CHECKPOINT_PATH)
    # Guard in case from_pretrained reports failure by returning None instead
    # of raising; fail here with a clear message rather than deep inside the
    # pipeline constructor.
    if seg_model is None:
        raise RuntimeError(f"Could not load segmentation model from {CHECKPOINT_PATH!r}")

    # 2. Initialize the Diarization Pipeline
    print("Initializing Pipeline...")
    pipeline = SpeakerDiarization(
        segmentation=seg_model,
        embedding="speechbrain/spkrec-ecapa-voxceleb",
        clustering="AgglomerativeClustering",
    )
    # Hand-tuned hyper-parameters intended to balance recordings with
    # diverse speaker counts.
    params = {
        "segmentation": {
            "threshold": 0.58,  # Raised threshold, chosen to suppress false alarms
            "min_duration_off": 0.2,  # Intended to prevent "flickering" between speakers
        },
        "clustering": {
            "method": "centroid",
            "threshold": 0.62,  # Lowered threshold, chosen to encourage speaker separation
            "min_cluster_size": 1,
        },
    }
    pipeline.instantiate(params)
    return pipeline


def _read_test_uris(path):
    """Return the recording URIs listed in the test split file at *path*.

    Each non-blank line contributes its first whitespace-separated field, so
    both bare "<uri>" lines and "<uri> <extra columns...>" lines are accepted.
    """
    with open(path, "r", encoding="utf-8") as f:
        return [line.split()[0] for line in f if line.strip()]


def _load_reference(uri):
    """Return the reference annotation for *uri* from ``RTTM_DIR``, or None.

    Any failure (missing file, parse error, URI not present in the RTTM) is
    reported as a warning so one bad reference does not abort the evaluation.
    """
    rttm_path = os.path.join(RTTM_DIR, f"{uri}.rttm")
    try:
        return load_rttm(rttm_path)[uri]
    except Exception as e:  # per-file boundary: report and skip this file
        print(f"Warning: Could not load RTTM for {uri}. Error: {e}")
        return None


def run_global_evaluation(*, min_speakers=2, max_speakers=7):
    """Diarize every file in the test split and report the global DER.

    For each URI listed in ``TEST_LIST_PATH``:
    1. Load ``AUDIO_DIR/<uri>.wav`` and ``RTTM_DIR/<uri>.rttm``.
    2. Run the fine-tuned pipeline on the audio.
    3. Add the result to a single DER metric shared by all files.

    Files with missing audio or an unreadable reference are skipped with a
    warning. The per-file report is printed and saved to ``OUTPUT_CSV``.

    Args:
        min_speakers: lower bound on the speaker count passed to the pipeline.
        max_speakers: upper bound on the speaker count passed to the pipeline.

    Returns:
        The global DER as a fraction, or None if no file could be evaluated.
        In that case no report is written.
    """
    pipeline = _build_pipeline()

    # 3. Initialize Metrics
    # One metric object accumulates error components across all files, so its
    # final value is the corpus-level DER, not an average of per-file DERs.
    total_der_metric = DiarizationErrorRate()

    # 4. Load filenames from test.txt
    test_files = _read_test_uris(TEST_LIST_PATH)
    print(f"Found {len(test_files)} files in test set. Starting Batch Processing...")
    print("-" * 50)

    evaluated = 0
    skipped = []
    for uri in test_files:
        audio_path = os.path.join(AUDIO_DIR, f"{uri}.wav")
        if not os.path.exists(audio_path):
            print(f"Warning: Audio file not found for {uri}. Skipping.")
            skipped.append(uri)
            continue

        reference = _load_reference(uri)
        if reference is None:
            skipped.append(uri)
            continue

        # Run Diarization on the in-memory waveform; the pipeline picks the
        # speaker count per file within [min_speakers, max_speakers].
        waveform, sample_rate = torchaudio.load(audio_path)
        test_file = {"waveform": waveform, "sample_rate": sample_rate, "uri": uri}
        hypothesis = pipeline(test_file, min_speakers=min_speakers, max_speakers=max_speakers)

        # Accumulate the metric
        total_der_metric(reference, hypothesis, detailed=True)
        evaluated += 1
        print(f"Done: {uri}")

    if skipped:
        print(f"Skipped {len(skipped)} file(s): {', '.join(skipped)}")

    if evaluated == 0:
        # Nothing was accumulated, so the metric's denominator is zero. Its
        # value would be meaningless (a bogus perfect score), so stop here
        # instead of printing it.
        print("Error: no test files could be evaluated; no report written.")
        return None

    # 5. Final Calculations
    print("\n" + "=" * 50)
    print(" FINAL GLOBAL REPORT")
    print("=" * 50)
    # Per-file breakdown table (display=True also prints it)
    report_df = total_der_metric.report(display=True)
    # Global DER is the value of the metric after processing all files
    global_der = abs(total_der_metric)
    # "Accuracy" is simply 1 - DER, clamped at 0 because DER can exceed 100%.
    global_accuracy = max(0, (1 - global_der) * 100)
    print(f"\nOVERALL SYSTEM ACCURACY : {global_accuracy:.2f}%")
    print(f"GLOBAL DIARIZATION ERROR: {global_der * 100:.2f}%")
    print("=" * 50)

    # Save detailed report to CSV for your documentation
    report_df.to_csv(OUTPUT_CSV)
    print(f"Detailed file-by-file breakdown saved to: {OUTPUT_CSV}")
    return global_der


if __name__ == "__main__":
    run_global_evaluation()