File size: 7,482 Bytes
9f76952
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
import streamlit as st
import os
import torch
import numpy as np
import librosa
import soundfile as sf
import noisereduce as nr
import pandas as pd
import matplotlib.pyplot as plt
from pyannote.audio import Model, Inference
from pyannote.audio.utils.signal import Binarize
from pyannote.core import SlidingWindowFeature, Annotation
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics import silhouette_score

# --- 1. PYTORCH 2.6+ COMPATIBILITY SHIM ---
# PyTorch 2.6 changed the default of torch.load(weights_only=...) to True,
# which rejects checkpoints that pickle arbitrary Python objects. This shim
# forces weights_only=False for every torch.load call in the process.
#
# SECURITY WARNING: weights_only=False lets the checkpoint's pickle payload run
# arbitrary code while loading. Only point this app at checkpoint files that
# come from a trusted source.
import torch.serialization

# Streamlit re-executes this script on every interaction while the `torch`
# module object stays alive. Resolve the genuine loader through the marker left
# by an earlier run so we never wrap an already-wrapped function again.
original_load = getattr(torch.load, "_unpatched_load", torch.load)


def forced_load(f, map_location=None, pickle_module=None, **kwargs):
    """Call the real ``torch.load`` with ``weights_only`` forced to ``False``.

    Args:
        f: File path or file-like object, forwarded unchanged.
        map_location: Forwarded unchanged.
        pickle_module: Forwarded unchanged.
        **kwargs: Any other ``torch.load`` keyword; a caller-supplied
            ``weights_only`` value is overridden.

    Returns:
        Whatever the underlying ``torch.load`` returns.
    """
    kwargs['weights_only'] = False
    return original_load(f, map_location=map_location, pickle_module=pickle_module, **kwargs)


# Marker read by the getattr() above on the next script run.
forced_load._unpatched_load = original_load
torch.load = forced_load
# -------------------------------

# Page chrome: browser-tab title and wide layout, followed by the heading and
# a short description of what the tool does.
st.set_page_config(layout="wide", page_title="Hindi-Bhojpuri Diarization Tool")
st.title("🎙️ Speaker Diarization with De-noising")
st.markdown(
    "\n\n"
    "This tool uses a fine-tuned model to detect speakers. \n\n"
    "The system automatically determines the number of speakers based on voice similarity.\n\n"
)

# --- SIDEBAR CONFIGURATION ---
# Everything inside this context is rendered in the sidebar; the three values
# read here are module-level settings consumed further down the script.
with st.sidebar:
    st.header("Configuration")
    MODEL_PATH = st.text_input(
        "Model Checkpoint Path",
        "training_results/lightning_logs/version_2/checkpoints/epoch=4-step=2960.ckpt",
    )
    use_denoise = st.checkbox("Enable De-noising", value=True)

    st.subheader("Advanced Settings")
    threshold = st.slider("AI Sensitivity (VAD)", min_value=0.5, max_value=0.95, value=0.80)

@st.cache_resource
def _load_model_from_checkpoint(path):
    """Load the model at *path*; the result is cached per path by Streamlit."""
    return Model.from_pretrained(path)


def load_cached_model(path):
    """Return the model stored at *path*, or ``None`` if the path is missing.

    The existence check is deliberately kept outside the cached loader:
    ``st.cache_resource`` caches whatever the decorated function returns, so a
    cached ``None`` would keep reporting "not found" for a path even after the
    checkpoint file has been put in place.

    Args:
        path: Filesystem path of the checkpoint to load.

    Returns:
        The object returned by ``Model.from_pretrained(path)``, or ``None``
        when ``path`` does not exist on disk.
    """
    if not os.path.exists(path):
        return None
    return _load_model_from_checkpoint(path)

def _estimate_speaker_count(features, max_speakers=5):
    """Choose the number of clusters with the best silhouette score.

    Candidate counts run from 2 up to ``max_speakers``, capped at
    ``len(features) - 1`` because the silhouette score is only defined for
    fewer clusters than samples.

    Args:
        features: 2-D array with one row per detected segment.
        max_speakers: Largest number of speakers to consider.

    Returns:
        The best-scoring cluster count, or 2 when no candidate can be scored.
    """
    candidates = range(2, min(len(features), max_speakers + 1))
    if not candidates:
        # Only two segments: there is nothing to compare, so fall back to 2.
        return 2
    try:
        scores = [
            silhouette_score(
                features,
                AgglomerativeClustering(
                    n_clusters=n, metric='euclidean', linkage='ward'
                ).fit_predict(features),
            )
            for n in candidates
        ]
    except ValueError:
        # Raised by scikit-learn for degenerate inputs; use the safe default.
        return 2
    return candidates[int(np.argmax(scores))]


def process_audio(audio_path, model_path, denoise, sensitivity):
    """Run de-noising, segmentation and speaker clustering on one audio file.

    Side effects: writes ``converted_audio.wav`` (and ``temp_denoised.wav``
    when *denoise* is true) into the current working directory. The file names
    are fixed, so concurrent sessions sharing a working directory overwrite
    each other's files.

    Args:
        audio_path: Path of the audio file to analyse.
        model_path: Checkpoint path handed to ``load_cached_model``.
        denoise: When true, apply noise reduction before inference.
        sensitivity: Onset threshold for binarizing the model scores. The
            offset threshold is kept 0.10 below it (never below 0).

    Returns:
        ``(annotation, audio_file)`` where ``annotation`` is the merged
        ``Annotation`` of speaker turns and ``audio_file`` is the path of the
        audio that was fed to the model; ``(None, None)`` when the model
        checkpoint cannot be found.
    """
    # Fail fast: no point converting / de-noising audio we cannot analyse.
    model = load_cached_model(model_path)
    if model is None:
        return None, None

    # 1. Load audio, resampled to 16 kHz.
    y, sr = librosa.load(audio_path, sr=16000)

    # CONVERSION STEP: explicitly write out a standard 16 kHz PCM_16 .wav so
    # the model always receives the same format, whatever was uploaded.
    audio_input = "converted_audio.wav"
    sf.write(audio_input, y, sr, subtype='PCM_16')

    # 2. AGGRESSIVE DE-NOISING (optional). Without it, the converted file
    # written above is what gets analysed.
    if denoise:
        with st.spinner("Step 1: Deep cleaning audio..."):
            # prop_decrease=0.90 removes most of the estimated background noise.
            y = nr.reduce_noise(y=y, sr=sr, prop_decrease=0.90, n_fft=2048)
            audio_input = "temp_denoised.wav"
            sf.write(audio_input, y, sr, subtype='PCM_16')

    # 3. AI Inference
    with st.spinner("Step 2: AI Neural Analysis..."):
        inference = Inference(model, window="sliding", duration=2.0, step=0.5)
        seg_output = inference(audio_input)

        # Drop singleton axes; if three axes remain keep only index 0 of the
        # last one so the binarizer receives a 2-D score matrix.
        data = np.squeeze(seg_output.data)
        if data.ndim == 3:
            data = data[:, :, 0]
        clean_scores = SlidingWindowFeature(data, seg_output.sliding_window)

        # 4. BINARIZATION: min_duration_on=1.2 s discards short bursts
        # (coughs, clicks) that would otherwise become spurious speakers.
        onset = float(sensitivity)
        offset = max(onset - 0.10, 0.0)
        binarize = Binarize(onset=onset, offset=offset, min_duration_on=1.2, min_duration_off=0.5)
        raw_hyp = binarize(clean_scores)

        # 5. FEATURE EXTRACTION: one vector per detected segment, the mean of
        # the raw model output over the whole segment.
        embeddings = []
        segments = []
        for segment, _track, _label in raw_hyp.itertracks(yield_label=True):
            cropped = np.asarray(seg_output.crop(segment).data)
            if cropped.size == 0:
                # An empty crop would produce NaN features and break clustering.
                continue
            embeddings.append(np.mean(cropped, axis=0).flatten())
            segments.append(segment)

        final_hyp = Annotation()

        if len(embeddings) > 1:
            X = np.array(embeddings)

            # --- AUTO-DETECTION LOGIC ---
            best_n = _estimate_speaker_count(X)

            clusterer = AgglomerativeClustering(n_clusters=best_n, metric='euclidean', linkage='ward')
            final_labels = clusterer.fit_predict(X)

            for segment, cluster_id in zip(segments, final_labels):
                final_hyp[segment] = f"Speaker {cluster_id}"

        elif len(embeddings) == 1:
            final_hyp[segments[0]] = "Speaker 0"

    # .support() merges adjacent/overlapping turns of the same speaker.
    return final_hyp.support(), audio_input

# --- MAIN UI ---
# Flow: upload a .wav -> preview it -> on button press run the analysis ->
# show the (optionally de-noised) audio, a per-speaker timeline, a table of
# speaker turns and a CSV download.
uploaded_file = st.file_uploader("Upload .wav file", type=["wav"])

if uploaded_file is not None:
    # Persist the upload to disk: the audio player and the processing
    # function below both take a file path.
    with open("temp_upload.wav", "wb") as f:
        f.write(uploaded_file.getbuffer())

    col1, col2 = st.columns(2)
    with col1:
        st.subheader("Original Audio")
        st.audio("temp_upload.wav")

    if st.button("Start AI Analysis"):
        hyp, final_audio = process_audio("temp_upload.wav", MODEL_PATH, use_denoise, threshold)

        if hyp is None:
            # process_audio returns (None, None) when the checkpoint is missing.
            st.error("Model not found!")
        else:
            with col2:
                if use_denoise:
                    st.subheader("Denoised Version")
                    st.audio(final_audio)

            st.divider()

            unique_speakers = sorted(hyp.labels())
            st.subheader(f"📊 Speaker Timeline ({len(unique_speakers)} Speakers Detected)")

            if len(unique_speakers) > 0:
                # One horizontal lane per speaker; figure height grows with
                # the number of lanes.
                fig, ax = plt.subplots(figsize=(12, len(unique_speakers) * 0.8 + 1.5))
                # pyplot.get_cmap replaces matplotlib.cm.get_cmap, which was
                # removed in Matplotlib 3.9.
                colors = plt.get_cmap('tab10', len(unique_speakers))

                for i, speaker in enumerate(unique_speakers):
                    speaker_segments = hyp.label_timeline(speaker)
                    intervals = [(s.start, s.duration) for s in speaker_segments]
                    # Lane i occupies y in [10*i + 2, 10*i + 8].
                    ax.broken_barh(intervals, (i*10 + 2, 6), facecolors=colors(i))

                # Tick labels sit at the vertical centre of each lane.
                ax.set_yticks([i*10 + 5 for i in range(len(unique_speakers))])
                ax.set_yticklabels(unique_speakers)
                ax.set_xlabel("Time (seconds)")
                ax.grid(axis='x', linestyle='--', alpha=0.5)
                st.pyplot(fig)
                # pyplot keeps every figure it creates alive until closed;
                # release it so reruns do not accumulate figures in memory.
                plt.close(fig)

                timestamp_list = [
                    {
                        "Speaker ID": label,
                        "Start (s)": round(segment.start, 2),
                        "End (s)": round(segment.end, 2),
                        "Duration (s)": round(segment.duration, 2),
                    }
                    for segment, _track, label in hyp.itertracks(yield_label=True)
                ]

                df = pd.DataFrame(timestamp_list)
                st.dataframe(df, use_container_width=True)
                st.download_button("📩 Download CSV", df.to_csv(index=False).encode('utf-8'), "diarization.csv", "text/csv")
            else:
                st.warning("No speech detected.")