File size: 3,043 Bytes
9f76952
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os
import torch
import torchaudio
import matplotlib.pyplot as plt
import numpy as np
from pyannote.audio import Model, Inference
from pyannote.audio.utils.signal import Binarize
from pyannote.database.util import load_rttm
from pyannote.core import notebook, SlidingWindowFeature, Annotation
from sklearn.cluster import AgglomerativeClustering

# --- 1. PYTORCH 2.6+ SECURITY FIX ---
# PyTorch >= 2.6 changed the default of torch.load to weights_only=True,
# which breaks pyannote checkpoints that pickle full model objects.
# We wrap torch.load so weights_only DEFAULTS to False for library-internal
# calls, while still respecting a caller that explicitly passes weights_only.
# SECURITY NOTE: weights_only=False permits arbitrary code execution through
# pickle — only load checkpoints from trusted sources.
import torch.serialization
original_load = torch.load
def forced_load(f, map_location=None, pickle_module=None, **kwargs):
    """torch.load wrapper defaulting weights_only to False.

    Uses setdefault (not a hard override) so an explicit
    weights_only=True from a caller is honored.
    """
    kwargs.setdefault('weights_only', False)
    return original_load(f, map_location=map_location, pickle_module=pickle_module, **kwargs)
torch.load = forced_load
# ------------------------------------

def visualize_audio_file(audio_path, rttm_path, checkpoint_path):
    """Run segmentation inference on one audio file and plot the cleaned
    hypothesis below its RTTM ground truth.

    Parameters
    ----------
    audio_path : str
        Path to the .wav file to process. Its basename (without extension)
        must match the URI used in the RTTM file.
    rttm_path : str
        Path to the ground-truth RTTM annotation.
    checkpoint_path : str
        Path to a pyannote segmentation model checkpoint.
    """
    # splitext is safer than replace('.wav', ''): replace would also strip
    # '.wav' occurring mid-name and silently no-op for other extensions.
    file_id = os.path.splitext(os.path.basename(audio_path))[0]
    print(f"--- Processing: {file_id} ---")

    # 1. Load model & run sliding-window inference.
    model = Model.from_pretrained(checkpoint_path)
    inference = Inference(model, window="sliding", duration=2.0, step=0.5)
    seg_output = inference(audio_path)

    # 2. Reshape and binarize (high onset threshold suppresses background noise).
    data = np.squeeze(seg_output.data)
    if len(data.shape) == 3:
        # Keep first slice of the last axis — assumes channel-last layout; TODO confirm.
        data = data[:, :, 0]

    # Higher onset (0.8) ignores the "messy" low-volume background activations.
    binarize = Binarize(onset=0.8, offset=0.6, min_duration_on=0.4, min_duration_off=0.2)
    raw_hypothesis = binarize(SlidingWindowFeature(data, seg_output.sliding_window))

    # 3. MANUAL CLUSTERING (the fix for the rainbow/messy graph): collapse the
    # raw per-class labels into at most 5 simplified speaker names.
    print("Clustering segments to simplify speakers...")
    final_hypothesis = Annotation(uri=file_id)
    for segment, track, label in raw_hypothesis.itertracks(yield_label=True):
        # NOTE(review): '% 5' assumes Binarize yields integer class labels;
        # a string label would raise TypeError — verify against the
        # installed pyannote version.
        final_hypothesis[segment, track] = f"Speaker_{label % 5}"

    # 4. Load ground truth; a KeyError here means the RTTM URI != file_id.
    reference = load_rttm(rttm_path)[file_id]

    # 5. Plot reference (top) and hypothesis (bottom) for visual comparison.
    print("Generating Clean Graph...")
    fig, ax = plt.subplots(nrows=2, ncols=1, figsize=(15, 8))

    notebook.plot_annotation(reference, ax=ax[0], time=True, legend=True)
    ax[0].set_title(f"GROUND TRUTH: {file_id}")

    notebook.plot_annotation(final_hypothesis, ax=ax[1], time=True, legend=True)
    ax[1].set_title("CLEANED AI HYPOTHESIS (Clustered & Filtered)")

    plt.tight_layout()
    plt.show()

if __name__ == "__main__":
    AUDIO_FILE = "dataset/audio/bhojpuri_chunk_20.wav"
    RTTM_FILE = "dataset/rttm/bhojpuri_chunk_20.rttm"
    MODEL_CHECKPOINT = "training_results/lightning_logs/version_2/checkpoints/epoch=4-step=2960.ckpt"

    if os.path.exists(AUDIO_FILE):
        visualize_audio_file(AUDIO_FILE, RTTM_FILE, MODEL_CHECKPOINT)
    else:
        # Fail loudly instead of silently doing nothing on a bad path.
        raise SystemExit(f"Audio file not found: {AUDIO_FILE}")