"""Streamlit app for Hindi-Bhojpuri speaker diarization with optional de-noising."""

import streamlit as st
import os
import torch
import numpy as np
import librosa
import soundfile as sf
import noisereduce as nr
import pandas as pd
import matplotlib.pyplot as plt
from pyannote.audio import Model, Inference
from pyannote.audio.utils.signal import Binarize
from pyannote.core import SlidingWindowFeature, Annotation
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics import silhouette_score

# --- 1. PYTORCH 2.6+ SECURITY FIX ---
# torch.load defaults to weights_only=True in recent PyTorch releases, which breaks
# loading Lightning checkpoints that pickle arbitrary objects. Force the old behaviour.
import torch.serialization

original_load = torch.load


def forced_load(f, map_location=None, pickle_module=None, **kwargs):
    kwargs['weights_only'] = False
    return original_load(f, map_location=map_location, pickle_module=pickle_module, **kwargs)


torch.load = forced_load
# -------------------------------------

st.set_page_config(page_title="Hindi-Bhojpuri Diarization Tool", layout="wide")
st.title("🎙️ Speaker Diarization with De-noising")
st.markdown("""
This tool uses a fine-tuned model to detect speakers.
The system automatically determines the number of speakers based on voice similarity.
""")

# --- SIDEBAR CONFIGURATION ---
st.sidebar.header("Configuration")
MODEL_PATH = st.sidebar.text_input(
    "Model Checkpoint Path",
    "training_results/lightning_logs/version_2/checkpoints/epoch=4-step=2960.ckpt",
)
use_denoise = st.sidebar.checkbox("Enable De-noising", value=True)

st.sidebar.subheader("Advanced Settings")
threshold = st.sidebar.slider("AI Sensitivity (VAD)", 0.5, 0.95, 0.80)


@st.cache_resource
def load_cached_model(path):
    if not os.path.exists(path):
        return None
    return Model.from_pretrained(path)


def process_audio(audio_path, model_path, denoise, sensitivity):
    # 1. Load audio at 16 kHz mono
    y, sr = librosa.load(audio_path, sr=16000)

    # CONVERSION STEP: explicitly write out a standard .wav so the model
    # receives a 16 kHz, 16-bit PCM file regardless of the upload's encoding.
    audio_input = "converted_audio.wav"
    sf.write(audio_input, y, sr, subtype='PCM_16')

    # 2. AGGRESSIVE DE-NOISING
    if denoise:
        with st.spinner("Step 1: Deep cleaning audio..."):
            # prop_decrease=0.90 to suppress heavy background noise
            y = nr.reduce_noise(y=y, sr=sr, prop_decrease=0.90, n_fft=2048)
            audio_input = "temp_denoised.wav"
            sf.write(audio_input, y, sr, subtype='PCM_16')
    # When de-noising is disabled, the converted .wav written above is used as-is.

    # 3. AI inference
    with st.spinner("Step 2: AI Neural Analysis..."):
        model = load_cached_model(model_path)
        if model is None:
            return None, None

        inference = Inference(model, window="sliding", duration=2.0, step=0.5)
        seg_output = inference(audio_input)

        data = np.squeeze(seg_output.data)
        if len(data.shape) == 3:
            data = data[:, :, 0]
        clean_scores = SlidingWindowFeature(data, seg_output.sliding_window)

        # 4. BINARIZATION: min_duration_on=1.2 s ignores short noises, coughs and
        # background clicks that would otherwise inflate the speaker count.
        # The sidebar "AI Sensitivity (VAD)" slider controls the onset threshold.
        binarize = Binarize(onset=sensitivity, offset=0.75,
                            min_duration_on=1.2, min_duration_off=0.5)
        raw_hyp = binarize(clean_scores)
    # 5. FEATURE EXTRACTION
    embeddings = []
    segments = []
    for segment, track, label in raw_hyp.itertracks(yield_label=True):
        # Average the frame-level activations over the segment to form a simple voiceprint
        feature_vector = np.mean(seg_output.crop(segment).data, axis=0).flatten()
        embeddings.append(feature_vector)
        segments.append(segment)

    final_hyp = Annotation()

    if len(embeddings) > 1:
        X = np.array(embeddings)

        # --- AUTO-DETECTION LOGIC ---
        # Try speaker counts from 2 up to 5 and keep the one with the best
        # silhouette score; if scoring fails, fall back to 2 speakers.
        try:
            scores = []
            range_n = range(2, min(len(embeddings), 6))
            for n in range_n:
                clusterer = AgglomerativeClustering(n_clusters=n, metric='euclidean', linkage='ward')
                labels = clusterer.fit_predict(X)
                scores.append(silhouette_score(X, labels))
            best_n = range_n[np.argmax(scores)]
        except Exception:
            best_n = 2  # Safe default for OJT demo

        clusterer = AgglomerativeClustering(n_clusters=best_n, metric='euclidean', linkage='ward')
        final_labels = clusterer.fit_predict(X)

        for i, segment in enumerate(segments):
            final_hyp[segment] = f"Speaker {final_labels[i]}"
    elif len(embeddings) == 1:
        final_hyp[segments[0]] = "Speaker 0"

    # .support() merges contiguous segments that share the same speaker label
    return final_hyp.support(), audio_input


# --- MAIN UI ---
uploaded_file = st.file_uploader("Upload .wav file", type=["wav"])

if uploaded_file is not None:
    with open("temp_upload.wav", "wb") as f:
        f.write(uploaded_file.getbuffer())

    col1, col2 = st.columns(2)
    with col1:
        st.subheader("Original Audio")
        st.audio("temp_upload.wav")

    if st.button("Start AI Analysis"):
        hyp, final_audio = process_audio("temp_upload.wav", MODEL_PATH, use_denoise, threshold)

        if hyp is None:
            st.error("Model not found!")
        else:
            with col2:
                if use_denoise:
                    st.subheader("Denoised Version")
                    st.audio(final_audio)

            st.divider()
            unique_speakers = sorted(hyp.labels())
            st.subheader(f"📊 Speaker Timeline ({len(unique_speakers)} Speakers Detected)")

            if len(unique_speakers) > 0:
                fig, ax = plt.subplots(figsize=(12, len(unique_speakers) * 0.8 + 1.5))
                # tab10 is indexed directly with an integer; this avoids plt.cm.get_cmap,
                # which was removed in recent matplotlib releases.
                colors = plt.cm.tab10

                for i, speaker in enumerate(unique_speakers):
                    speaker_segments = hyp.label_timeline(speaker)
                    intervals = [(s.start, s.duration) for s in speaker_segments]
                    ax.broken_barh(intervals, (i * 10 + 2, 6), facecolors=colors(i % 10))

                ax.set_yticks([i * 10 + 5 for i in range(len(unique_speakers))])
                ax.set_yticklabels(unique_speakers)
                ax.set_xlabel("Time (seconds)")
                ax.grid(axis='x', linestyle='--', alpha=0.5)
                st.pyplot(fig)

                timestamp_list = []
                for segment, track, label in hyp.itertracks(yield_label=True):
                    timestamp_list.append({
                        "Speaker ID": label,
                        "Start (s)": round(segment.start, 2),
                        "End (s)": round(segment.end, 2),
                        "Duration (s)": round(segment.duration, 2),
                    })

                df = pd.DataFrame(timestamp_list)
                st.dataframe(df, use_container_width=True)
                st.download_button(
                    "📩 Download CSV",
                    df.to_csv(index=False).encode('utf-8'),
                    "diarization.csv",
                    "text/csv",
                )
            else:
                st.warning("No speech detected.")
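
# --- How to run (a minimal sketch, not part of the app itself) ---
# The file name app.py below is an assumption; save this script under any name and
# adjust the command accordingly. The sidebar checkpoint path must point to an
# existing pyannote segmentation checkpoint on disk.
#
#   pip install streamlit torch librosa soundfile noisereduce pyannote.audio scikit-learn pandas matplotlib
#   streamlit run app.py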