Spaces:
Sleeping
Sleeping
Upload 18 files
Browse files- .gitattributes +4 -0
- app.py +179 -0
- outputs/Screenshot 2026-02-03 034427.png +3 -0
- outputs/Screenshot 2026-02-03 051514.png +3 -0
- outputs/Screenshot 2026-02-03 061911.png +0 -0
- outputs/Screenshot 2026-02-03 154131.png +3 -0
- requirements.txt +0 -0
- scripts/audioconversion.ipynb +202 -0
- scripts/check.py +9 -0
- scripts/check_rttm.py +24 -0
- scripts/diarization.ipynb +0 -0
- scripts/diarization_visualization.py +64 -0
- scripts/make_splits.py +49 -0
- scripts/run.py +116 -0
- scripts/segmentation.py +73 -0
- scripts/test_model.py +47 -0
- scripts/test_protocol.py +56 -0
- scripts/visualize_segmentation.py +72 -0
- test/audio2.wav +3 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
outputs/Screenshot[[:space:]]2026-02-03[[:space:]]034427.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
outputs/Screenshot[[:space:]]2026-02-03[[:space:]]051514.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
outputs/Screenshot[[:space:]]2026-02-03[[:space:]]154131.png filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
test/audio2.wav filter=lfs diff=lfs merge=lfs -text
|
app.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import os
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
import librosa
|
| 6 |
+
import soundfile as sf
|
| 7 |
+
import noisereduce as nr
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
from pyannote.audio import Model, Inference
|
| 11 |
+
from pyannote.audio.utils.signal import Binarize
|
| 12 |
+
from pyannote.core import SlidingWindowFeature, Annotation
|
| 13 |
+
from sklearn.cluster import AgglomerativeClustering
|
| 14 |
+
from sklearn.metrics import silhouette_score
|
| 15 |
+
|
| 16 |
+
# --- 1. PYTORCH 2.6+ SECURITY FIX ---
# PyTorch >= 2.6 defaults torch.load(weights_only=True), which rejects the
# pickled Python objects inside pyannote/lightning checkpoints. Wrap
# torch.load so every call explicitly opts back into full unpickling.
# WARNING: weights_only=False executes arbitrary pickle code — only load
# checkpoints from trusted sources.
import torch.serialization

# FIX: Streamlit re-executes this script on every UI interaction; without
# this guard each rerun wrapped the previous wrapper, nesting forced_load
# one level deeper per rerun. Patch exactly once.
if not getattr(torch.load, "_forced_weights_only_patch", False):
    original_load = torch.load

    def forced_load(f, map_location=None, pickle_module=None, **kwargs):
        kwargs['weights_only'] = False
        return original_load(f, map_location=map_location, pickle_module=pickle_module, **kwargs)

    forced_load._forced_weights_only_patch = True
    torch.load = forced_load
# -------------------------------
|
| 24 |
+
|
| 25 |
+
# Page chrome: the wide layout suits the timeline plot and the results table.
st.set_page_config(page_title="Hindi-Bhojpuri Diarization Tool", layout="wide")

st.title("🎙️ Speaker Diarization with De-noising")
st.markdown("""
This tool uses a fine-tuned model to detect speakers.
The system automatically determines the number of speakers based on voice similarity.
""")

# --- SIDEBAR CONFIGURATION (UI CLEANUP) ---
st.sidebar.header("Configuration")
# Path of the fine-tuned segmentation checkpoint; editable so other
# training runs can be pointed at without code changes.
MODEL_PATH = st.sidebar.text_input("Model Checkpoint Path", "training_results/lightning_logs/version_2/checkpoints/epoch=4-step=2960.ckpt")
use_denoise = st.sidebar.checkbox("Enable De-noising", value=True)

st.sidebar.subheader("Advanced Settings")
# Forwarded to process_audio as its `sensitivity` argument (VAD threshold).
threshold = st.sidebar.slider("AI Sensitivity (VAD)", 0.5, 0.95, 0.80)
|
| 40 |
+
|
| 41 |
+
@st.cache_resource
def load_cached_model(path):
    """Return the pyannote Model at *path*, or None when the file is absent.

    Decorated with st.cache_resource so the checkpoint is deserialised once
    per session instead of on every Streamlit rerun.
    """
    model = Model.from_pretrained(path) if os.path.exists(path) else None
    return model
|
| 46 |
+
|
| 47 |
+
def process_audio(audio_path, model_path, denoise, sensitivity):
    """Run the full diarization chain on one uploaded file.

    Steps: load/resample -> PCM_16 conversion -> optional de-noising ->
    neural segmentation -> binarization -> embedding clustering with an
    automatically chosen speaker count.

    Parameters:
        audio_path (str): path of the uploaded .wav file.
        model_path (str): checkpoint path of the fine-tuned pyannote model.
        denoise (bool): apply spectral noise reduction before inference.
        sensitivity (float): VAD onset threshold from the sidebar slider
            (higher = stricter speech detection).

    Returns:
        tuple: (merged Annotation hypothesis, path of the audio actually
        analysed), or (None, None) when the checkpoint is missing.
    """
    # 1. Load Audio at the 16 kHz mono rate the segmentation model expects.
    y, sr = librosa.load(audio_path, sr=16000)

    # CONVERSION STEP: explicitly write a standard PCM_16 .wav at 16 kHz so
    # the model always receives a uniform format regardless of the upload.
    audio_input = "converted_audio.wav"
    sf.write(audio_input, y, sr, subtype='PCM_16')

    # 2. AGGRESSIVE DE-NOISING (prop_decrease=0.90 suppresses heavy
    #    background noise at the cost of some speech colouration).
    if denoise:
        with st.spinner("Step 1: Deep cleaning audio..."):
            y = nr.reduce_noise(y=y, sr=sr, prop_decrease=0.90, n_fft=2048)
            audio_input = "temp_denoised.wav"
            sf.write(audio_input, y, sr)
    # FIX: the non-denoise branch previously fed the *raw upload* back to the
    # model, bypassing the PCM_16 conversion written above; the converted
    # file is now always used.

    # 3. AI Inference
    with st.spinner("Step 2: AI Neural Analysis..."):
        model = load_cached_model(model_path)
        if model is None:
            return None, None

        inference = Inference(model, window="sliding", duration=2.0, step=0.5)
        seg_output = inference(audio_input)

        # Collapse any trailing axis so scores are 2-D (frames x classes).
        data = np.squeeze(seg_output.data)
        if len(data.shape) == 3:
            data = data[:, :, 0]
        clean_scores = SlidingWindowFeature(data, seg_output.sliding_window)

        # 4. BINARIZATION
        # FIX: the "AI Sensitivity (VAD)" slider value was accepted but never
        # used (onset was hard-coded to 0.85). It now drives the onset, with
        # the offset tracking 0.10 below it (mirroring the old 0.85/0.75
        # gap). min_duration_on=1.2 s ignores short noises/coughs/clicks
        # that previously inflated the speaker count.
        binarize = Binarize(
            onset=sensitivity,
            offset=max(sensitivity - 0.10, 0.0),
            min_duration_on=1.2,
            min_duration_off=0.5,
        )
        raw_hyp = binarize(clean_scores)

        # 5. FEATURE EXTRACTION: one mean score vector per detected segment,
        # used as a crude voiceprint for clustering.
        embeddings = []
        segments = []
        for segment, track, label in raw_hyp.itertracks(yield_label=True):
            feature_vector = np.mean(seg_output.crop(segment).data, axis=0).flatten()
            embeddings.append(feature_vector)
            segments.append(segment)

        final_hyp = Annotation()

        if len(embeddings) > 1:
            X = np.array(embeddings)

            # --- AUTO-DETECTION LOGIC ---
            # Choose the cluster count (2..5) with the best silhouette score;
            # fall back to 2 speakers when scoring is impossible.
            best_n = 2
            try:
                range_n = range(2, min(len(embeddings), 6))
                scores = []
                for n in range_n:
                    clusterer = AgglomerativeClustering(n_clusters=n, metric='euclidean', linkage='ward')
                    labels = clusterer.fit_predict(X)
                    scores.append(silhouette_score(X, labels))
                # FIX: with exactly 2 segments, range(2, 2) is empty and
                # np.argmax([]) would raise; previously only the bare except
                # masked this. Handle it explicitly.
                if scores:
                    best_n = range_n[int(np.argmax(scores))]
            except Exception:  # FIX: bare except also swallowed KeyboardInterrupt
                best_n = 2  # Safe default for OJT demo

            clusterer = AgglomerativeClustering(n_clusters=best_n, metric='euclidean', linkage='ward')
            final_labels = clusterer.fit_predict(X)

            for i, segment in enumerate(segments):
                final_hyp[segment] = f"Speaker {final_labels[i]}"

        elif len(embeddings) == 1:
            final_hyp[segments[0]] = "Speaker 0"

    # .support() is CRITICAL: it merges small gaps of the same speaker.
    return final_hyp.support(), audio_input
|
| 122 |
+
|
| 123 |
+
# --- MAIN UI ---
uploaded_file = st.file_uploader("Upload .wav file", type=["wav"])

if uploaded_file is not None:
    # Persist the upload so the loader/model can read it from disk.
    with open("temp_upload.wav", "wb") as f:
        f.write(uploaded_file.getbuffer())

    col1, col2 = st.columns(2)
    with col1:
        st.subheader("Original Audio")
        st.audio("temp_upload.wav")

    if st.button("Start AI Analysis"):
        hyp, final_audio = process_audio("temp_upload.wav", MODEL_PATH, use_denoise, threshold)

        if hyp is None:
            st.error("Model not found!")
        else:
            with col2:
                if use_denoise:
                    st.subheader("Denoised Version")
                    st.audio(final_audio)

            st.divider()

            unique_speakers = sorted(hyp.labels())
            st.subheader(f"📊 Speaker Timeline ({len(unique_speakers)} Speakers Detected)")

            if len(unique_speakers) > 0:
                # One horizontal lane per speaker on a broken-bar timeline.
                fig, ax = plt.subplots(figsize=(12, len(unique_speakers) * 0.8 + 1.5))
                colors = plt.cm.get_cmap('tab10', len(unique_speakers))

                for row, speaker in enumerate(unique_speakers):
                    lane = [(seg.start, seg.duration) for seg in hyp.label_timeline(speaker)]
                    ax.broken_barh(lane, (row * 10 + 2, 6), facecolors=colors(row))

                ax.set_yticks([row * 10 + 5 for row in range(len(unique_speakers))])
                ax.set_yticklabels(unique_speakers)
                ax.set_xlabel("Time (seconds)")
                ax.grid(axis='x', linestyle='--', alpha=0.5)
                st.pyplot(fig)

                # Tabular view of every (segment, speaker) pair.
                timestamp_list = [
                    {
                        "Speaker ID": label,
                        "Start (s)": round(segment.start, 2),
                        "End (s)": round(segment.end, 2),
                        "Duration (s)": round(segment.duration, 2),
                    }
                    for segment, _track, label in hyp.itertracks(yield_label=True)
                ]

                df = pd.DataFrame(timestamp_list)
                st.dataframe(df, use_container_width=True)
                st.download_button("📩 Download CSV", df.to_csv(index=False).encode('utf-8'), "diarization.csv", "text/csv")
            else:
                st.warning("No speech detected.")
|
outputs/Screenshot 2026-02-03 034427.png
ADDED
|
Git LFS Details
|
outputs/Screenshot 2026-02-03 051514.png
ADDED
|
Git LFS Details
|
outputs/Screenshot 2026-02-03 061911.png
ADDED
|
outputs/Screenshot 2026-02-03 154131.png
ADDED
|
Git LFS Details
|
requirements.txt
CHANGED
|
Binary files a/requirements.txt and b/requirements.txt differ
|
|
|
scripts/audioconversion.ipynb
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"nbformat": 4,
|
| 3 |
+
"nbformat_minor": 0,
|
| 4 |
+
"metadata": {
|
| 5 |
+
"colab": {
|
| 6 |
+
"provenance": []
|
| 7 |
+
},
|
| 8 |
+
"kernelspec": {
|
| 9 |
+
"name": "python3",
|
| 10 |
+
"display_name": "Python 3"
|
| 11 |
+
},
|
| 12 |
+
"language_info": {
|
| 13 |
+
"name": "python"
|
| 14 |
+
}
|
| 15 |
+
},
|
| 16 |
+
"cells": [
|
| 17 |
+
{
|
| 18 |
+
"cell_type": "code",
|
| 19 |
+
"source": [
|
| 20 |
+
"!pip install yt-dlp pydub ffmpeg-python\n",
|
| 21 |
+
"!apt-get install ffmpeg -y # For Colab, to make pydub work"
|
| 22 |
+
],
|
| 23 |
+
"metadata": {
|
| 24 |
+
"colab": {
|
| 25 |
+
"base_uri": "https://localhost:8080/"
|
| 26 |
+
},
|
| 27 |
+
"id": "QURofI_GiSTK",
|
| 28 |
+
"outputId": "7c46dcd6-d49c-4172-bf27-d9d408352cdf"
|
| 29 |
+
},
|
| 30 |
+
"execution_count": null,
|
| 31 |
+
"outputs": [
|
| 32 |
+
{
|
| 33 |
+
"output_type": "stream",
|
| 34 |
+
"name": "stdout",
|
| 35 |
+
"text": [
|
| 36 |
+
"Collecting yt-dlp\n",
|
| 37 |
+
" Downloading yt_dlp-2025.12.8-py3-none-any.whl.metadata (180 kB)\n",
|
| 38 |
+
"\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/180.3 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m180.3/180.3 kB\u001b[0m \u001b[31m6.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
| 39 |
+
"\u001b[?25hRequirement already satisfied: pydub in /usr/local/lib/python3.12/dist-packages (0.25.1)\n",
|
| 40 |
+
"Collecting ffmpeg-python\n",
|
| 41 |
+
" Downloading ffmpeg_python-0.2.0-py3-none-any.whl.metadata (1.7 kB)\n",
|
| 42 |
+
"Requirement already satisfied: future in /usr/local/lib/python3.12/dist-packages (from ffmpeg-python) (1.0.0)\n",
|
| 43 |
+
"Downloading yt_dlp-2025.12.8-py3-none-any.whl (3.3 MB)\n",
|
| 44 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.3/3.3 MB\u001b[0m \u001b[31m55.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
| 45 |
+
"\u001b[?25hDownloading ffmpeg_python-0.2.0-py3-none-any.whl (25 kB)\n",
|
| 46 |
+
"Installing collected packages: yt-dlp, ffmpeg-python\n",
|
| 47 |
+
"Successfully installed ffmpeg-python-0.2.0 yt-dlp-2025.12.8\n",
|
| 48 |
+
"Reading package lists... Done\n",
|
| 49 |
+
"Building dependency tree... Done\n",
|
| 50 |
+
"Reading state information... Done\n",
|
| 51 |
+
"ffmpeg is already the newest version (7:4.4.2-0ubuntu0.22.04.1).\n",
|
| 52 |
+
"0 upgraded, 0 newly installed, 0 to remove and 2 not upgraded.\n"
|
| 53 |
+
]
|
| 54 |
+
}
|
| 55 |
+
]
|
| 56 |
+
},
|
| 57 |
+
{
|
| 58 |
+
"cell_type": "code",
|
| 59 |
+
"source": [
|
| 60 |
+
"import os\n",
|
| 61 |
+
"from pydub import AudioSegment\n",
|
| 62 |
+
"import yt_dlp\n",
|
| 63 |
+
"\n",
|
| 64 |
+
"# -------------------------------\n",
|
| 65 |
+
"# 1️⃣ YouTube video URL\n",
|
| 66 |
+
"# -------------------------------\n",
|
| 67 |
+
"url = \"https://youtu.be/uvOF0qn_r_0?si=-Zd4-p22-bjgEAWT\" # Replace VIDEO_ID with your YouTube link\n",
|
| 68 |
+
"\n",
|
| 69 |
+
"# -------------------------------\n",
|
| 70 |
+
"# 2️⃣ Download audio using yt-dlp\n",
|
| 71 |
+
"# -------------------------------\n",
|
| 72 |
+
"ydl_opts = {\n",
|
| 73 |
+
" 'format': 'bestaudio/best',\n",
|
| 74 |
+
" 'outtmpl': 'video_audio.%(ext)s',\n",
|
| 75 |
+
" 'postprocessors': [{\n",
|
| 76 |
+
" 'key': 'FFmpegExtractAudio',\n",
|
| 77 |
+
" 'preferredcodec': 'mp3', # download as mp3 first\n",
|
| 78 |
+
" 'preferredquality': '192',\n",
|
| 79 |
+
" }],\n",
|
| 80 |
+
"}\n",
|
| 81 |
+
"\n",
|
| 82 |
+
"print(\"Downloading audio from YouTube...\")\n",
|
| 83 |
+
"with yt_dlp.YoutubeDL(ydl_opts) as ydl:\n",
|
| 84 |
+
" ydl.download([url])\n",
|
| 85 |
+
"\n",
|
| 86 |
+
"# Find downloaded file\n",
|
| 87 |
+
"for file in os.listdir():\n",
|
| 88 |
+
" if file.startswith(\"video_audio\") and file.endswith(\".mp3\"):\n",
|
| 89 |
+
" audio_file = file\n",
|
| 90 |
+
" break\n",
|
| 91 |
+
"\n",
|
| 92 |
+
"print(\"Downloaded:\", audio_file)\n",
|
| 93 |
+
"\n",
|
| 94 |
+
"# -------------------------------\n",
|
| 95 |
+
"# 3️⃣ Convert audio to WAV (16kHz, mono)\n",
|
| 96 |
+
"# -------------------------------\n",
|
| 97 |
+
"audio = AudioSegment.from_file(audio_file)\n",
|
| 98 |
+
"audio = audio.set_channels(1) # mono\n",
|
| 99 |
+
"audio = audio.set_frame_rate(16000) # 16 kHz\n",
|
| 100 |
+
"wav_filename = \"znmd.wav\"\n",
|
| 101 |
+
"audio.export(wav_filename, format=\"wav\")\n",
|
| 102 |
+
"print(\"Converted to WAV:\", wav_filename)\n",
|
| 103 |
+
"\n",
|
| 104 |
+
"# -------------------------------\n",
|
| 105 |
+
"# 4️⃣ Split WAV into 1-minute chunks\n",
|
| 106 |
+
"# -------------------------------\n",
|
| 107 |
+
"chunk_length_ms = 60 * 1000 # 1 minute\n",
|
| 108 |
+
"for i, start in enumerate(range(0, len(audio), chunk_length_ms)):\n",
|
| 109 |
+
" chunk = audio[start:start+chunk_length_ms]\n",
|
| 110 |
+
" chunk_filename = f\"ZNMD_chunk_{i}.wav\"\n",
|
| 111 |
+
" chunk.export(chunk_filename, format=\"wav\")\n",
|
| 112 |
+
" print(\"Saved chunk:\", chunk_filename)\n",
|
| 113 |
+
"\n",
|
| 114 |
+
"print(\"All steps completed! Your audio is ready for diarization.\")\n"
|
| 115 |
+
],
|
| 116 |
+
"metadata": {
|
| 117 |
+
"id": "oPEdKqgLLUEw",
|
| 118 |
+
"colab": {
|
| 119 |
+
"base_uri": "https://localhost:8080/"
|
| 120 |
+
},
|
| 121 |
+
"outputId": "0c940970-dd15-40f4-9bfb-9fb18120d879"
|
| 122 |
+
},
|
| 123 |
+
"execution_count": null,
|
| 124 |
+
"outputs": [
|
| 125 |
+
{
|
| 126 |
+
"output_type": "stream",
|
| 127 |
+
"name": "stdout",
|
| 128 |
+
"text": [
|
| 129 |
+
"Downloading audio from YouTube...\n",
|
| 130 |
+
"[youtube] Extracting URL: https://youtu.be/uvOF0qn_r_0?si=-Zd4-p22-bjgEAWT\n",
|
| 131 |
+
"[youtube] uvOF0qn_r_0: Downloading webpage\n"
|
| 132 |
+
]
|
| 133 |
+
},
|
| 134 |
+
{
|
| 135 |
+
"output_type": "stream",
|
| 136 |
+
"name": "stderr",
|
| 137 |
+
"text": [
|
| 138 |
+
"WARNING: [youtube] No supported JavaScript runtime could be found. Only deno is enabled by default; to use another runtime add --js-runtimes RUNTIME[:PATH] to your command/config. YouTube extraction without a JS runtime has been deprecated, and some formats may be missing. See https://github.com/yt-dlp/yt-dlp/wiki/EJS for details on installing one\n"
|
| 139 |
+
]
|
| 140 |
+
},
|
| 141 |
+
{
|
| 142 |
+
"output_type": "stream",
|
| 143 |
+
"name": "stdout",
|
| 144 |
+
"text": [
|
| 145 |
+
"[youtube] uvOF0qn_r_0: Downloading android sdkless player API JSON\n",
|
| 146 |
+
"[youtube] uvOF0qn_r_0: Downloading web safari player API JSON\n"
|
| 147 |
+
]
|
| 148 |
+
},
|
| 149 |
+
{
|
| 150 |
+
"output_type": "stream",
|
| 151 |
+
"name": "stderr",
|
| 152 |
+
"text": [
|
| 153 |
+
"WARNING: [youtube] uvOF0qn_r_0: Some web_safari client https formats have been skipped as they are missing a url. YouTube is forcing SABR streaming for this client. See https://github.com/yt-dlp/yt-dlp/issues/12482 for more details\n"
|
| 154 |
+
]
|
| 155 |
+
},
|
| 156 |
+
{
|
| 157 |
+
"output_type": "stream",
|
| 158 |
+
"name": "stdout",
|
| 159 |
+
"text": [
|
| 160 |
+
"[youtube] uvOF0qn_r_0: Downloading m3u8 information\n"
|
| 161 |
+
]
|
| 162 |
+
},
|
| 163 |
+
{
|
| 164 |
+
"output_type": "stream",
|
| 165 |
+
"name": "stderr",
|
| 166 |
+
"text": [
|
| 167 |
+
"WARNING: [youtube] uvOF0qn_r_0: Some web client https formats have been skipped as they are missing a url. YouTube is forcing SABR streaming for this client. See https://github.com/yt-dlp/yt-dlp/issues/12482 for more details\n"
|
| 168 |
+
]
|
| 169 |
+
},
|
| 170 |
+
{
|
| 171 |
+
"output_type": "stream",
|
| 172 |
+
"name": "stdout",
|
| 173 |
+
"text": [
|
| 174 |
+
"[info] uvOF0qn_r_0: Downloading 1 format(s): 251\n",
|
| 175 |
+
"[download] Destination: video_audio.webm\n",
|
| 176 |
+
"[download] 100% of 3.73MiB in 00:00:00 at 13.24MiB/s \n",
|
| 177 |
+
"[ExtractAudio] Destination: video_audio.mp3\n",
|
| 178 |
+
"Deleting original file video_audio.webm (pass -k to keep)\n",
|
| 179 |
+
"Downloaded: video_audio.mp3\n",
|
| 180 |
+
"Converted to WAV: znmd.wav\n",
|
| 181 |
+
"Saved chunk: ZNMD_chunk_0.wav\n",
|
| 182 |
+
"Saved chunk: ZNMD_chunk_1.wav\n",
|
| 183 |
+
"Saved chunk: ZNMD_chunk_2.wav\n",
|
| 184 |
+
"Saved chunk: ZNMD_chunk_3.wav\n",
|
| 185 |
+
"Saved chunk: ZNMD_chunk_4.wav\n",
|
| 186 |
+
"Saved chunk: ZNMD_chunk_5.wav\n",
|
| 187 |
+
"All steps completed! Your audio is ready for diarization.\n"
|
| 188 |
+
]
|
| 189 |
+
}
|
| 190 |
+
]
|
| 191 |
+
},
|
| 192 |
+
{
|
| 193 |
+
"cell_type": "code",
|
| 194 |
+
"source": [],
|
| 195 |
+
"metadata": {
|
| 196 |
+
"id": "UXUHD4sTIjhe"
|
| 197 |
+
},
|
| 198 |
+
"execution_count": null,
|
| 199 |
+
"outputs": []
|
| 200 |
+
}
|
| 201 |
+
]
|
| 202 |
+
}
|
scripts/check.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Environment probe: checks only whether the installed pyannote.audio
# version exposes the Segmentation *task class* (used later for training).
# Segmentation is just a class name (a tool), not a model, not training.

try:
    from pyannote.audio.tasks import Segmentation
    print("Success! Segmentation task imported.")
except ImportError as e:
    print(f"Still failing: {e}")
|
scripts/check_rttm.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os

# Sanity-check one RTTM annotation file: it must exist on disk, and the URI
# stored in column 2 of its first line must match the expected basename —
# a mismatch makes pyannote's protocol silently drop the annotation.
uri = "hindi_chunk_22"
rttm_path = f"dataset/rttm/{uri}.rttm"

print(f"Checking for RTTM at: {os.path.abspath(rttm_path)}")

if not os.path.exists(rttm_path):
    print("RTTM file NOT found at that path!")
else:
    print("RTTM file exists!")
    with open(rttm_path, 'r') as f:
        first_line = f.readline()
    print(f"First line of RTTM: {first_line.strip()}")

    parts = first_line.split()
    if len(parts) > 1:
        rttm_uri = parts[1]
        if rttm_uri == uri:
            print(f"URI Match: '{rttm_uri}' matches '{uri}'")
        else:
            print(f"URI MISMATCH: RTTM says '{rttm_uri}' but protocol expects '{uri}'")
|
scripts/diarization.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
scripts/diarization_visualization.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
import torch
import matplotlib.pyplot as plt
from pyannote.metrics.diarization import DiarizationErrorRate

# THE ULTIMATE BYPASS (Fixes PyTorch 2.6 security errors)
# PyTorch >= 2.6 defaults torch.load(weights_only=True), which rejects the
# pickled objects inside pyannote checkpoints; force full unpickling.
# WARNING: weights_only=False executes arbitrary pickle code — trusted
# checkpoints only.
import torch.serialization
original_load = torch.load
def patched_load(*args, **kwargs):
    kwargs['weights_only'] = False
    return original_load(*args, **kwargs)
torch.load = patched_load

# IMPORTS
from pyannote.core import notebook
from pyannote.audio import Pipeline
from pyannote.database.util import load_rttm

# Clip to diarize and its hand-labelled reference annotation.
AUDIO_PATH = r"dataset/audio/clip_03.wav"
RTTM_PATH = r"dataset/rttm/clip_03.rttm"

# INITIALIZE PIPELINE
print("Initializing AI Pipeline...")
pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1",
    use_auth_token="hf_token_here"  # Replace with your Hugging Face token
)

# --- RUN DIARIZATION ---
print("AI is analyzing the audio...")
prediction = pipeline(AUDIO_PATH)

# --- LOAD GROUND TRUTH ---
# load_rttm returns {uri: Annotation}; this file holds a single entry.
gt_dict = load_rttm(RTTM_PATH)
uri = list(gt_dict.keys())[0]
ground_truth = gt_dict[uri]

# --- FIXED: CALCULATE DER USING REPORT ---
metric = DiarizationErrorRate()
# We process the specific file to get a clean report
# NOTE(review): `notebook=True` is not a documented DiarizationErrorRate
# keyword — confirm it is accepted rather than raising a TypeError.
metric(ground_truth, prediction, notebook=True)
report = metric.report(display=True)

print("\n" + "="*50)
print("FINAL EVALUATION REPORT")
print(report)
print("="*50 + "\n")

## --- VISUALIZATION (UNCHANGED) ---
# Reference on top, model prediction below, sharing one time axis so
# speaker-boundary disagreements line up visually.
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 8), sharex=True)

plt.sca(ax1)
notebook.plot_annotation(ground_truth, ax=ax1)
ax1.set_title("REFERENCE (Ground Truth)", fontsize=14, fontweight='bold')

plt.sca(ax2)
notebook.plot_annotation(prediction, ax=ax2)
ax2.set_title("HYPOTHESIS (Model Prediction)", fontsize=14, fontweight='bold')

plt.xlabel("Time (seconds)", fontsize=12)
plt.tight_layout()

print("Diarization complete! Displaying plot...")
plt.show()
|
scripts/make_splits.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Divide the audio dataset into train / dev / test URI lists for pyannote."""

import os
import random

# Path to audio folder
audio_dir = "dataset/audio"

# Every .wav basename (extension stripped) becomes a URI.
uris = [f.replace(".wav", "") for f in os.listdir(audio_dir) if f.endswith(".wav")]

# Safety check — the corpus is expected to contain exactly 89 clips.
if len(uris) != 89:
    print(f"Warning: expected 89 files, found {len(uris)}")

# Fixed seed keeps the shuffle (and therefore the splits) reproducible.
random.seed(42)
random.shuffle(uris)

# 71 train / 9 dev / 9 test out of 89 files.
train, dev, test = uris[:71], uris[71:80], uris[80:89]

os.makedirs("dataset/splits", exist_ok=True)


def write_split(name, data):
    """Write one URI per line to dataset/splits/<name>.txt."""
    with open(f"dataset/splits/{name}.txt", "w", encoding="utf-8") as f:
        f.writelines(uri + "\n" for uri in data)


for split_name, split_uris in (("train", train), ("dev", dev), ("test", test)):
    write_split(split_name, split_uris)

# Print summary
print("Dataset split completed:")
print(f" Train: {len(train)} files")
print(f" Dev : {len(dev)} files")
print(f" Test : {len(test)} files")
|
scripts/run.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
# MUST BE AT THE VERY TOP
|
| 3 |
+
os.environ["SPEECHBRAIN_LOCAL_STRATEGY"] = "copy"
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torchaudio
|
| 7 |
+
import pandas as pd
|
| 8 |
+
from pyannote.audio import Model
|
| 9 |
+
from pyannote.audio.pipelines import SpeakerDiarization
|
| 10 |
+
from pyannote.database.util import load_rttm
|
| 11 |
+
from pyannote.metrics.diarization import DiarizationErrorRate, DiarizationPurity, DiarizationCoverage
|
| 12 |
+
|
| 13 |
+
# --- THE DEFINITIVE FIX FOR PYTORCH 2.6+ SECURITY ERRORS ---
# PyTorch >= 2.6 defaults torch.load(weights_only=True), which rejects the
# pickled objects inside pyannote/lightning checkpoints. Route every
# torch.load call through a wrapper that opts back into full unpickling.
# WARNING: weights_only=False executes arbitrary pickle code — only load
# checkpoints from trusted sources.
import torch.serialization

_real_torch_load = torch.load


def forced_load(f, map_location=None, pickle_module=None, **kwargs):
    kwargs['weights_only'] = False
    return _real_torch_load(f, map_location=map_location, pickle_module=pickle_module, **kwargs)


torch.load = forced_load
# -----------------------------------------------------------
|
| 21 |
+
|
| 22 |
+
# Configuration - Update these paths to match your project structure
CHECKPOINT_PATH = "training_results/lightning_logs/version_2/checkpoints/epoch=4-step=2960.ckpt"  # fine-tuned segmentation checkpoint
TEST_LIST_PATH = "dataset/splits/test.txt"  # one URI per line (held-out split)
AUDIO_DIR = "dataset/audio"  # contains <uri>.wav files
RTTM_DIR = "dataset/rttm"  # contains <uri>.rttm reference annotations
OUTPUT_CSV = "overall_model_performance.csv"  # per-file DER breakdown destination
|
| 28 |
+
|
| 29 |
+
def run_global_evaluation():
    """Evaluate the fine-tuned segmentation model on the whole test split.

    Runs the diarization pipeline over every URI listed in TEST_LIST_PATH,
    accumulates a global Diarization Error Rate against the reference RTTMs,
    prints a per-file report, and saves the breakdown to OUTPUT_CSV.
    """
    # 1. Load the fine-tuned model
    print(f"Loading fine-tuned model from: {CHECKPOINT_PATH}")
    seg_model = Model.from_pretrained(CHECKPOINT_PATH)

    # 2. Initialize the Diarization Pipeline
    print("Initializing Pipeline...")
    pipeline = SpeakerDiarization(
        segmentation=seg_model,
        embedding="speechbrain/spkrec-ecapa-voxceleb",
        clustering="AgglomerativeClustering",
    )

    # Balanced hyper-parameters for clips with diverse speaker counts.
    pipeline.instantiate({
        "segmentation": {
            "threshold": 0.58,        # high threshold to kill false alarms
            "min_duration_off": 0.2,  # prevents fragmented speaker "flickering"
        },
        "clustering": {
            "method": "centroid",
            "threshold": 0.62,        # encourages separating distinct voices
            "min_cluster_size": 1,
        },
    })

    # 3. A single metric instance accumulates across all files, so its final
    # value is the corpus-level DER rather than a per-file average.
    total_der_metric = DiarizationErrorRate()

    # 4. Load URIs (filename stems) from the test list; blank lines skipped.
    with open(TEST_LIST_PATH, 'r') as f:
        test_files = [line.strip().split()[0] for line in f if line.strip()]

    print(f"Found {len(test_files)} files in test set. Starting Batch Processing...")
    print("-" * 50)

    for uri in test_files:
        audio_path = os.path.join(AUDIO_DIR, f"{uri}.wav")
        rttm_path = os.path.join(RTTM_DIR, f"{uri}.rttm")

        if not os.path.exists(audio_path):
            print(f"Warning: Audio file not found for {uri}. Skipping.")
            continue

        # Load the reference annotation; skip files with unreadable RTTMs.
        try:
            reference = load_rttm(rttm_path)[uri]
        except Exception as e:
            print(f"Warning: Could not load RTTM for {uri}. Error: {e}")
            continue

        # Run diarization on the in-memory waveform.
        waveform, sample_rate = torchaudio.load(audio_path)
        test_file = {"waveform": waveform, "sample_rate": sample_rate, "uri": uri}

        # Speaker count is determined dynamically within [2, 7].
        hypothesis = pipeline(test_file, min_speakers=2, max_speakers=7)

        # Accumulate this file into the global metric.
        total_der_metric(reference, hypothesis, detailed=True)
        print(f"Done: {uri}")

    # 5. Final Calculations
    print("\n" + "="*50)
    print(" FINAL GLOBAL REPORT")
    print("="*50)

    # Detailed per-file table.
    report_df = total_der_metric.report(display=True)

    # abs(metric) yields the accumulated global DER.
    global_der = abs(total_der_metric)
    global_accuracy = max(0, (1 - global_der) * 100)

    print(f"\nOVERALL SYSTEM ACCURACY : {global_accuracy:.2f}%")
    print(f"GLOBAL DIARIZATION ERROR: {global_der * 100:.2f}%")
    print("="*50)

    # Save detailed report to CSV for documentation.
    report_df.to_csv(OUTPUT_CSV)
    print(f"Detailed file-by-file breakdown saved to: {OUTPUT_CSV}")


if __name__ == "__main__":
    run_global_evaluation()
|
scripts/segmentation.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import torch
|
| 5 |
+
import torchaudio
|
| 6 |
+
import torch.serialization
|
| 7 |
+
from pyannote.core import Segment, Timeline
|
| 8 |
+
|
| 9 |
+
# --- 1. MONKEY PATCH (Fixes PyTorch 2.6 Security Error) ---
# torch >= 2.6 defaults torch.load to weights_only=True, which rejects the
# pickled objects stored inside pyannote checkpoints. Wrap the stock loader
# so every call runs with weights_only=False.
# NOTE(review): this disables torch's unpickling safeguard — only load
# checkpoints from trusted sources.
original_load = torch.serialization.load

def forced_load(f, map_location=None, pickle_module=None, **kwargs):
    """Delegate to the original torch loader, forcing weights_only=False."""
    merged = {**kwargs, "weights_only": False}
    return original_load(f, map_location=map_location, pickle_module=pickle_module, **merged)

torch.load = forced_load
torch.serialization.load = forced_load
# ---------------------------------------------------------
|
| 17 |
+
|
| 18 |
+
# Model: the neural network; Segmentation: the training task; get_protocol: the dataset loader; pl.Trainer: the training engine.
|
| 19 |
+
|
| 20 |
+
from pyannote.audio import Model
|
| 21 |
+
from pyannote.audio.tasks import Segmentation
|
| 22 |
+
from pyannote.database import get_protocol, FileFinder
|
| 23 |
+
import pytorch_lightning as pl
|
| 24 |
+
|
| 25 |
+
# Tell pyannote.database where the corpus description lives.
# NOTE(review): relative path — resolved against the current working
# directory, so this script must be launched from the directory containing
# database.yml (test_protocol.py builds an absolute path instead).
os.environ["PYANNOTE_DATABASE_CONFIG"] = "database.yml"
|
| 26 |
+
|
| 27 |
+
def train_segmentation():
    """Fine-tune pyannote's segmentation model on the Hindi/Bhojpuri corpus.

    Transfer learning: start from the pretrained segmentation checkpoint
    (trained on other languages) and adapt it to the local protocol rather
    than training from scratch.
    """

    def get_annotated(file):
        """Return a Timeline spanning the whole recording.

        pyannote expects an 'annotated' region per file; here the full
        duration (total frames / sample rate) is treated as annotated.
        """
        meta = torchaudio.info(file["audio"])
        full_length = meta.num_frames / meta.sample_rate
        return Timeline([Segment(0, full_length)])

    # Preprocessors: resolve audio paths and supply the annotated region.
    preprocessors = {
        "audio": FileFinder(),
        "annotated": get_annotated,
    }

    # Load the dataset protocol declared in database.yml.
    print("Loading Hindi-Bhojpuri Protocol...")
    protocol = get_protocol(
        'HindiBhojpuri.SpeakerDiarization.Segmentation',
        preprocessors=preprocessors,
    )

    # Training task: 2-second chunks, small batches, no worker processes.
    seg_task = Segmentation(protocol, duration=2.0, batch_size=4, num_workers=0)

    # Start from the pretrained segmentation model and attach the task.
    print("Attempting to load model...")
    model = Model.from_pretrained("pyannote/segmentation-3.0")
    model.task = seg_task

    # CPU trainer, 5 epochs; logs/checkpoints go under training_results/.
    trainer = pl.Trainer(
        accelerator="cpu",
        max_epochs=5,
        default_root_dir="training_results",
    )

    print("--- Starting Fine-tuning ---")
    trainer.fit(model)
|
| 71 |
+
|
| 72 |
+
# Script entry point: kick off fine-tuning when run directly.
if __name__ == "__main__":
    train_segmentation()
|
scripts/test_model.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torchaudio
|
| 3 |
+
from pyannote.audio import Model
|
| 4 |
+
from pyannote.core import Annotation, Segment
|
| 5 |
+
|
| 6 |
+
# 1. PATHS
# Fine-tuned segmentation checkpoint produced by scripts/segmentation.py
# (pytorch-lightning writes it under training_results/lightning_logs/...).
CHECKPOINT_PATH = "training_results/lightning_logs/version_2/checkpoints/epoch=4-step=2960.ckpt"
# Audio clip used for this smoke test.
TEST_AUDIO = "dataset/audio/clip_07.wav"
|
| 9 |
+
|
| 10 |
+
def run_test():
    """Smoke-test the fine-tuned checkpoint on a single audio clip.

    Loads the model straight from the lightning checkpoint, runs one raw
    forward pass, then prints, per speaker slot, the span between the first
    and last frame whose activation exceeds 0.5.
    """
    print("Loading model directly...")
    model = Model.from_pretrained(CHECKPOINT_PATH)
    model.eval()  # inference mode

    # Load the clip; the model expects [batch, channels, samples].
    waveform, sample_rate = torchaudio.load(TEST_AUDIO)
    if waveform.ndim == 2:
        waveform = waveform.unsqueeze(0)

    print("Running raw inference...")
    with torch.no_grad():
        # scores: [batch, frames, speakers] — per-frame speaker activations.
        scores = model(waveform)

    print("\n--- Raw Model Detections ---")

    # Each output channel acts as a temporary speaker "slot"; threshold at
    # 0.5 and report only the first/last active frame to keep output clean.
    slot_count = scores.shape[-1]
    for slot in range(slot_count):
        active = (scores[0, :, slot] > 0.5).nonzero(as_tuple=True)[0]
        if len(active) == 0:
            continue
        # 0.016 s is an approximate frame shift — TODO confirm against the
        # model's actual output resolution.
        first_t = active[0].item() * 0.016
        last_t = active[-1].item() * 0.016
        print(f"Speaker Slot {slot}: Detected activity between {first_t:.2f}s and {last_t:.2f}s")
|
| 45 |
+
|
| 46 |
+
# Run the smoke test when executed directly.
if __name__ == "__main__":
    run_test()
|
scripts/test_protocol.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Sanity-check that pyannote can correctly read the dataset (audio + RTTM)
# described in database.yml.
# get_protocol asks pyannote for the dataset declared in database.yml;
# FileFinder resolves each file URI to its audio file on disk.
import os
from pyannote.database import get_protocol, FileFinder

# 1. Point pyannote at the database.yml that sits next to this script
# (absolute path, so the check works from any working directory).
current_dir = os.path.dirname(os.path.abspath(__file__))
config_path = os.path.join(current_dir, "database.yml")
os.environ["PYANNOTE_DATABASE_CONFIG"] = config_path

# 2. Preprocessor mapping each URI to its audio path.
preprocessors = {'audio': FileFinder()}

# 3. Load the protocol. On failure, exit with a NON-ZERO status so callers
# (CI, shell scripts) can detect the broken configuration — the previous
# bare exit() reported success (status 0) even when loading failed.
try:
    protocol = get_protocol(
        'HindiBhojpuri.SpeakerDiarization.Segmentation',
        preprocessors=preprocessors
    )
    print("Protocol loaded successfully!")
except Exception as e:
    print(f"Failed to load protocol: {e}")
    raise SystemExit(1)

# 4. Detailed data verification: inspect the first test file only.
for file in protocol.test():
    print("\n" + "="*30)
    print(f"FILE URI: {file['uri']}")
    print(f"AUDIO PATH: {file['audio']}")

    # The annotation carries the RTTM speaker segments.
    annotation = file['annotation']
    print(f"SEGMENTS FOUND: {len(annotation)}")

    print("-" * 30)
    print("START | END | SPEAKER")
    print("-" * 30)

    # Show at most 5 segments to keep the output readable.
    for i, (segment, track, label) in enumerate(annotation.itertracks(yield_label=True)):
        if i >= 5:
            print("... (and more)")
            break
        print(f"{segment.start:9.2f}s | {segment.end:9.2f}s | {label}")

    print("="*30)

    # Only check the first file for now.
    break


# If this runs cleanly: database.yml is correct, audio paths resolve,
# RTTM files load, speaker segments exist —
# segmentation training CAN start.
|
scripts/visualize_segmentation.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import torchaudio
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
import numpy as np
|
| 6 |
+
from pyannote.audio import Model, Inference
|
| 7 |
+
from pyannote.audio.utils.signal import Binarize
|
| 8 |
+
from pyannote.database.util import load_rttm
|
| 9 |
+
from pyannote.core import notebook, SlidingWindowFeature, Annotation
|
| 10 |
+
from sklearn.cluster import AgglomerativeClustering
|
| 11 |
+
|
| 12 |
+
# --- 1. PYTORCH 2.6+ SECURITY FIX ---
# torch >= 2.6 defaults torch.load to weights_only=True, which rejects the
# pickled objects inside pyannote checkpoints; force the old behaviour.
# NOTE(review): this bypasses torch's unpickling safeguard — only load
# trusted checkpoints.
import torch.serialization

original_load = torch.load

def forced_load(f, map_location=None, pickle_module=None, **kwargs):
    """Call the stock torch.load with weights_only forced to False."""
    options = {**kwargs, "weights_only": False}
    return original_load(f, map_location=map_location, pickle_module=pickle_module, **options)

torch.load = forced_load
# ------------------------------------
|
| 20 |
+
|
| 21 |
+
def visualize_audio_file(audio_path, rttm_path, checkpoint_path):
    """Plot ground-truth vs model speaker segmentation for one audio file.

    Runs sliding-window inference with the fine-tuned checkpoint, binarizes
    the activations, simplifies the speaker labels, then draws the reference
    RTTM annotation above the model's hypothesis.

    Args:
        audio_path: path to the .wav file to analyse.
        rttm_path: path to the reference RTTM whose URI matches the file stem.
        checkpoint_path: pytorch-lightning checkpoint of the segmentation model.
    """
    # Derive the URI from the file name. splitext strips only the final
    # extension; the previous .replace('.wav', '') would also mangle names
    # containing '.wav' in the middle.
    file_id = os.path.splitext(os.path.basename(audio_path))[0]
    print(f"--- Processing: {file_id} ---")

    # 1. Load Model & Run Inference (2 s windows, 0.5 s hop).
    model = Model.from_pretrained(checkpoint_path)
    inference = Inference(model, window="sliding", duration=2.0, step=0.5)
    seg_output = inference(audio_path)

    # 2. Reshape and Binarize. A high onset (0.8) discards low-confidence
    # background noise; the min durations smooth out fragmented segments.
    data = np.squeeze(seg_output.data)
    if len(data.shape) == 3: data = data[:, :, 0]

    binarize = Binarize(onset=0.8, offset=0.6, min_duration_on=0.4, min_duration_off=0.2)
    raw_hypothesis = binarize(SlidingWindowFeature(data, seg_output.sliding_window))

    # 3. Simplify labels so the plot stays readable: collapse the raw class
    # indices into at most 5 "Speaker_N" buckets.
    # NOTE(review): `label % 5` assumes the binarized labels are integers —
    # confirm against the pyannote version in use.
    print("Clustering segments to simplify speakers...")
    final_hypothesis = Annotation(uri=file_id)

    for segment, track, label in raw_hypothesis.itertracks(yield_label=True):
        final_hypothesis[segment, track] = f"Speaker_{label % 5}"

    # 4. Load Ground Truth (the RTTM is keyed by the file URI).
    reference = load_rttm(rttm_path)[file_id]

    # 5. Plot: reference on top, cleaned hypothesis below.
    print("Generating Clean Graph...")
    fig, ax = plt.subplots(nrows=2, ncols=1, figsize=(15, 8))

    notebook.plot_annotation(reference, ax=ax[0], time=True, legend=True)
    ax[0].set_title(f"GROUND TRUTH: {file_id}")

    notebook.plot_annotation(final_hypothesis, ax=ax[1], time=True, legend=True)
    ax[1].set_title(f"CLEANED AI HYPOTHESIS (Clustered & Filtered)")

    plt.tight_layout()
    plt.show()
|
| 65 |
+
|
| 66 |
+
# Entry point: visualize one hard-coded example clip.
if __name__ == "__main__":
    AUDIO_FILE = "dataset/audio/bhojpuri_chunk_20.wav"
    RTTM_FILE = "dataset/rttm/bhojpuri_chunk_20.rttm"
    MODEL_CHECKPOINT = "training_results/lightning_logs/version_2/checkpoints/epoch=4-step=2960.ckpt"

    if os.path.exists(AUDIO_FILE):
        visualize_audio_file(AUDIO_FILE, RTTM_FILE, MODEL_CHECKPOINT)
    else:
        # Previously the script exited silently when the clip was missing;
        # report the problem instead.
        print(f"Audio file not found: {AUDIO_FILE}")
|
test/audio2.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:daf78623168634fb66384ae3341f9bf5ab1e57fc4694c4e34b68f022a1527478
|
| 3 |
+
size 842157
|