AnamikaP commited on
Commit
60c77e8
·
verified ·
1 Parent(s): 695ff37

Update scripts/diarization_visualization.py

Browse files
Files changed (1) hide show
  1. scripts/diarization_visualization.py +65 -63
scripts/diarization_visualization.py CHANGED
@@ -1,64 +1,66 @@
1
- import os
2
- import torch
3
- import matplotlib.pyplot as plt
4
- from pyannote.metrics.diarization import DiarizationErrorRate
5
-
6
- # THE ULTIMATE BYPASS (Fixes PyTorch 2.6 security errors)
7
- import torch.serialization
8
- original_load = torch.load
9
- def patched_load(*args, **kwargs):
10
- kwargs['weights_only'] = False
11
- return original_load(*args, **kwargs)
12
- torch.load = patched_load
13
-
14
- # IMPORTS
15
- from pyannote.core import notebook
16
- from pyannote.audio import Pipeline
17
- from pyannote.database.util import load_rttm
18
-
19
- AUDIO_PATH = r"dataset/audio/clip_03.wav"
20
- RTTM_PATH = r"dataset/rttm/clip_03.rttm"
21
-
22
- # INITIALIZE PIPELINE
23
- print("Initializing AI Pipeline...")
24
- pipeline = Pipeline.from_pretrained(
25
- "pyannote/speaker-diarization-3.1",
26
- use_auth_token="hf_token_here" # Replace with your Hugging Face token
27
- )
28
-
29
- # --- RUN DIARIZATION ---
30
- print("AI is analyzing the audio...")
31
- prediction = pipeline(AUDIO_PATH)
32
-
33
- # --- LOAD GROUND TRUTH ---
34
- gt_dict = load_rttm(RTTM_PATH)
35
- uri = list(gt_dict.keys())[0]
36
- ground_truth = gt_dict[uri]
37
-
38
- # --- FIXED: CALCULATE DER USING REPORT ---
39
- metric = DiarizationErrorRate()
40
- # We process the specific file to get a clean report
41
- metric(ground_truth, prediction, notebook=True)
42
- report = metric.report(display=True)
43
-
44
- print("\n" + "="*50)
45
- print("FINAL EVALUATION REPORT")
46
- print(report)
47
- print("="*50 + "\n")
48
-
49
- ## --- VISUALIZATION (UNCHANGED) ---
50
- fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 8), sharex=True)
51
-
52
- plt.sca(ax1)
53
- notebook.plot_annotation(ground_truth, ax=ax1)
54
- ax1.set_title("REFERENCE (Ground Truth)", fontsize=14, fontweight='bold')
55
-
56
- plt.sca(ax2)
57
- notebook.plot_annotation(prediction, ax=ax2)
58
- ax2.set_title("HYPOTHESIS (Model Prediction)", fontsize=14, fontweight='bold')
59
-
60
- plt.xlabel("Time (seconds)", fontsize=12)
61
- plt.tight_layout()
62
-
63
- print("Diarization complete! Displaying plot...")
 
 
64
  plt.show()
 
1
+ import os
2
+ import torch
3
+ import matplotlib.pyplot as plt
4
+ from pyannote.metrics.diarization import DiarizationErrorRate
5
+
6
+ # THE ULTIMATE BYPASS (Fixes PyTorch 2.6 security errors)
7
+ import torch.serialization
8
+ original_load = torch.load
9
+ def patched_load(*args, **kwargs):
10
+ kwargs['weights_only'] = False
11
+ return original_load(*args, **kwargs)
12
+ torch.load = patched_load
13
+
14
+ # IMPORTS
15
+ from pyannote.core import notebook
16
+ from pyannote.audio import Pipeline
17
+ from pyannote.database.util import load_rttm
18
+
19
+ AUDIO_PATH = r"dataset/audio/clip_03.wav"
20
+ RTTM_PATH = r"dataset/rttm/clip_03.rttm"
21
+
22
+ hf_token = os.getenv("HF_TOKEN")
23
+
24
+ # INITIALIZE PIPELINE
25
+ print("Initializing AI Pipeline...")
26
+ pipeline = Pipeline.from_pretrained(
27
+ "pyannote/speaker-diarization-3.1",
28
+ use_auth_token="hf_token" # Replace with your Hugging Face token
29
+ )
30
+
31
+ # --- RUN DIARIZATION ---
32
+ print("AI is analyzing the audio...")
33
+ prediction = pipeline(AUDIO_PATH)
34
+
35
+ # --- LOAD GROUND TRUTH ---
36
+ gt_dict = load_rttm(RTTM_PATH)
37
+ uri = list(gt_dict.keys())[0]
38
+ ground_truth = gt_dict[uri]
39
+
40
+ # --- FIXED: CALCULATE DER USING REPORT ---
41
+ metric = DiarizationErrorRate()
42
+ # We process the specific file to get a clean report
43
+ metric(ground_truth, prediction, notebook=True)
44
+ report = metric.report(display=True)
45
+
46
+ print("\n" + "="*50)
47
+ print("FINAL EVALUATION REPORT")
48
+ print(report)
49
+ print("="*50 + "\n")
50
+
51
+ ## --- VISUALIZATION (UNCHANGED) ---
52
+ fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 8), sharex=True)
53
+
54
+ plt.sca(ax1)
55
+ notebook.plot_annotation(ground_truth, ax=ax1)
56
+ ax1.set_title("REFERENCE (Ground Truth)", fontsize=14, fontweight='bold')
57
+
58
+ plt.sca(ax2)
59
+ notebook.plot_annotation(prediction, ax=ax2)
60
+ ax2.set_title("HYPOTHESIS (Model Prediction)", fontsize=14, fontweight='bold')
61
+
62
+ plt.xlabel("Time (seconds)", fontsize=12)
63
+ plt.tight_layout()
64
+
65
+ print("Diarization complete! Displaying plot...")
66
  plt.show()