File size: 1,847 Bytes
60c77e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9f76952
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import os
import torch
import matplotlib.pyplot as plt
from pyannote.metrics.diarization import DiarizationErrorRate

# WORKAROUND for PyTorch 2.6+, where torch.load defaults to weights_only=True
# and refuses the pickled objects stored in the pyannote checkpoints.
# This replaces torch.load process-wide: EVERY call runs with
# weights_only=False, even a call that explicitly passes weights_only=True.
# SECURITY: weights_only=False unpickles arbitrary objects, which can execute
# code embedded in a checkpoint. Only load models from sources you trust.
import functools
import torch.serialization
original_load = torch.load


@functools.wraps(original_load)
def patched_load(*args, **kwargs):
    """Call the original ``torch.load`` with ``weights_only`` forced to False."""
    # Deliberately overwrite (not setdefault): callers that pass
    # weights_only explicitly must also be forced onto the permissive path.
    kwargs['weights_only'] = False
    return original_load(*args, **kwargs)


torch.load = patched_load

# IMPORTS
from pyannote.core import notebook
from pyannote.audio import Pipeline
from pyannote.database.util import load_rttm

AUDIO_PATH = r"dataset/audio/clip_03.wav"
RTTM_PATH = r"dataset/rttm/clip_03.rttm"

# Hugging Face access token for the gated pyannote model, read from the
# environment. May be None when HF_TOKEN is unset.
# NOTE(review): with None, the hub client presumably falls back to any locally
# stored login (e.g. `huggingface-cli login`) — confirm for your hub version.
hf_token = os.getenv("HF_TOKEN")

# INITIALIZE PIPELINE
print("Initializing AI Pipeline...")
pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1",
    use_auth_token=hf_token,  # the variable, not the literal string "hf_token"
)
# pyannote.audio 3.x signals download/authentication failures by returning
# None rather than raising; stop here with an actionable message instead of
# failing later with "'NoneType' object is not callable".
if pipeline is None:
    raise SystemExit(
        "Could not load 'pyannote/speaker-diarization-3.1'. Set the HF_TOKEN "
        "environment variable to a Hugging Face token that has accepted the "
        "model's user conditions."
    )

# --- RUN DIARIZATION ---
print("AI is analyzing the audio...")
prediction = pipeline(AUDIO_PATH)

# --- LOAD GROUND TRUTH ---
# load_rttm returns a mapping keyed by recording uri. Select the entry for
# this clip (uri = audio file name without extension) instead of whichever
# key happens to come first, so a multi-recording RTTM is not silently
# evaluated against the wrong recording.
gt_dict = load_rttm(RTTM_PATH)
audio_uri = os.path.splitext(os.path.basename(AUDIO_PATH))[0]
if audio_uri in gt_dict:
    uri = audio_uri
elif len(gt_dict) == 1:
    # Single-recording RTTM whose uri differs from the audio file name:
    # its only entry is unambiguous, so use it.
    uri = next(iter(gt_dict))
else:
    raise SystemExit(
        f"{RTTM_PATH}: expected a single recording or one named "
        f"{audio_uri!r}, found {sorted(gt_dict) or 'none'}"
    )
ground_truth = gt_dict[uri]

# --- CALCULATE DER ---
metric = DiarizationErrorRate()
# Calling the metric scores this reference/hypothesis pair, accumulates it
# into `metric` for the report below, and returns this file's error rate.
# (The former `notebook=True` argument is not a metric option; it was
# swallowed by **kwargs and had no effect.)
der = metric(ground_truth, prediction)
# display=False: the table is printed once below; display=True would make
# report() print it as well, duplicating the output.
report = metric.report(display=False)

print("\n" + "="*50)
print("FINAL EVALUATION REPORT")
print(report)
print(f"Diarization Error Rate: {der:.2%}")
print("="*50 + "\n")

# --- VISUALIZATION ---
# Two stacked timelines sharing the time axis: reference on top, hypothesis
# below.
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 8), sharex=True)

# Give the shared pyannote notebook an explicit crop spanning BOTH
# annotations so neither plot is clipped to the other's extent.
# NOTE(review): when `crop` is unset, pyannote.core's notebook appears to
# derive it from the first annotation it plots and crops later plots to it —
# confirm against the installed pyannote.core version.
notebook.crop = (
    ground_truth.get_timeline().extent() | prediction.get_timeline().extent()
)

plt.sca(ax1)
notebook.plot_annotation(ground_truth, ax=ax1)
ax1.set_title("REFERENCE (Ground Truth)", fontsize=14, fontweight='bold')

plt.sca(ax2)
notebook.plot_annotation(prediction, ax=ax2)
ax2.set_title("HYPOTHESIS (Model Prediction)", fontsize=14, fontweight='bold')

# Label the bottom axis explicitly rather than via pyplot's "current axes".
ax2.set_xlabel("Time (seconds)", fontsize=12)
fig.tight_layout()

print("Diarization complete! Displaying plot...")
plt.show()