# Scraped from a Hugging Face Spaces file view (file size: 1,847 bytes;
# commit ids shown on the page: 60c77e8, 9f76952).
import os
import torch
import matplotlib.pyplot as plt
from pyannote.metrics.diarization import DiarizationErrorRate
# Compatibility workaround for PyTorch 2.6 checkpoint-loading errors: force torch.load(weights_only=False).
import torch.serialization
# Keep a handle on the unpatched loader so the wrapper can delegate to it.
original_load = torch.load


def patched_load(*args, **kwargs):
    """Call the original ``torch.load`` with ``weights_only`` forced to False.

    Any ``weights_only`` value supplied by the caller is overwritten; every
    other positional and keyword argument is forwarded unchanged.

    SECURITY: ``weights_only=False`` allows full pickle deserialization,
    which can execute arbitrary code. Only load checkpoints from sources
    you trust.

    Returns:
        Whatever the original ``torch.load`` returns for the given arguments.
    """
    kwargs['weights_only'] = False
    return original_load(*args, **kwargs)


# Install the wrapper process-wide so later torch.load calls go through it.
torch.load = patched_load
# IMPORTS
from pyannote.core import notebook
from pyannote.audio import Pipeline
from pyannote.database.util import load_rttm
# Input audio clip and its reference (ground-truth) RTTM annotation.
AUDIO_PATH = r"dataset/audio/clip_03.wav"
RTTM_PATH = r"dataset/rttm/clip_03.rttm"

# Hugging Face access token, read from the environment so it is never
# hard-coded in the script. This is None when HF_TOKEN is not set.
hf_token = os.getenv("HF_TOKEN")

# INITIALIZE PIPELINE
print("Initializing AI Pipeline...")
pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1",
    # Pass the token variable itself, not the literal string "hf_token".
    use_auth_token=hf_token,
)

# NOTE(review): pyannote.audio 3.x appears to print a message and return None
# (rather than raise) when the pipeline cannot be downloaded -- confirm against
# the installed version. Fail here with a clear message instead of a later
# "'NoneType' object is not callable".
if pipeline is None:
    raise RuntimeError(
        "Could not load 'pyannote/speaker-diarization-3.1'. Check that the "
        "HF_TOKEN environment variable holds a valid Hugging Face token and "
        "that the model's user conditions have been accepted."
    )
# --- RUN DIARIZATION ---
print("AI is analyzing the audio...")
prediction = pipeline(AUDIO_PATH)

# --- LOAD GROUND TRUTH ---
# load_rttm returns a mapping keyed by file URI; this script evaluates a
# single clip, so the first (expected: only) entry is used as the reference.
gt_dict = load_rttm(RTTM_PATH)
if not gt_dict:
    # Fail with an actionable message instead of a bare IndexError.
    raise ValueError(f"No annotations found in RTTM file: {RTTM_PATH}")
uri = next(iter(gt_dict))
ground_truth = gt_dict[uri]
# --- CALCULATE DER USING REPORT ---
metric = DiarizationErrorRate()

# Calling the metric accumulates this file's error components so that
# report() can summarise them. The former `notebook=True` keyword is not a
# metric argument (it was silently ignored), so it has been removed.
metric(ground_truth, prediction)

print("\n" + "="*50)
print("FINAL EVALUATION REPORT")
# report(display=True) prints the formatted table itself and returns it, so it
# is invoked between the banner lines and not printed a second time.
report = metric.report(display=True)
print("="*50 + "\n")
# --- VISUALIZATION ---
# Two stacked timelines sharing one time axis: reference on top, model
# prediction below, so segments can be compared by eye.
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 8), sharex=True)

# plt.sca is kept from the original so pyplot's current axes matches the
# axes being drawn on; the target axes is also passed explicitly via ax=.
plt.sca(ax1)
notebook.plot_annotation(ground_truth, ax=ax1)
ax1.set_title("REFERENCE (Ground Truth)", fontsize=14, fontweight='bold')

plt.sca(ax2)
notebook.plot_annotation(prediction, ax=ax2)
ax2.set_title("HYPOTHESIS (Model Prediction)", fontsize=14, fontweight='bold')

# Label the bottom axes explicitly rather than relying on pyplot's implicit
# "current axes" state.
ax2.set_xlabel("Time (seconds)", fontsize=12)

plt.tight_layout()
print("Diarization complete! Displaying plot...")
plt.show()