AnamikaP commited on
Commit
9f76952
·
verified ·
1 Parent(s): e266438

Upload 18 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ outputs/Screenshot[[:space:]]2026-02-03[[:space:]]034427.png filter=lfs diff=lfs merge=lfs -text
37
+ outputs/Screenshot[[:space:]]2026-02-03[[:space:]]051514.png filter=lfs diff=lfs merge=lfs -text
38
+ outputs/Screenshot[[:space:]]2026-02-03[[:space:]]154131.png filter=lfs diff=lfs merge=lfs -text
39
+ test/audio2.wav filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ import torch
4
+ import numpy as np
5
+ import librosa
6
+ import soundfile as sf
7
+ import noisereduce as nr
8
+ import pandas as pd
9
+ import matplotlib.pyplot as plt
10
+ from pyannote.audio import Model, Inference
11
+ from pyannote.audio.utils.signal import Binarize
12
+ from pyannote.core import SlidingWindowFeature, Annotation
13
+ from sklearn.cluster import AgglomerativeClustering
14
+ from sklearn.metrics import silhouette_score
15
+
16
# --- 1. PYTORCH 2.6+ SECURITY FIX ---
# PyTorch >= 2.6 defaults torch.load to weights_only=True, which rejects the
# pickled objects stored inside pyannote checkpoints.  Wrap the stock loader
# so every call runs with weights_only=False.
# SECURITY NOTE: this disables the safe-unpickling guard — only load trusted
# checkpoint files.
import torch.serialization

original_load = torch.load


def forced_load(f, map_location=None, pickle_module=None, **kwargs):
    """Delegate to the original torch.load, always forcing weights_only=False."""
    kwargs['weights_only'] = False
    return original_load(f, map_location=map_location, pickle_module=pickle_module, **kwargs)


torch.load = forced_load
# -------------------------------
24
+
25
st.set_page_config(page_title="Hindi-Bhojpuri Diarization Tool", layout="wide")

st.title("🎙️ Speaker Diarization with De-noising")
st.markdown("""
This tool uses a fine-tuned model to detect speakers.
The system automatically determines the number of speakers based on voice similarity.
""")

# --- SIDEBAR CONFIGURATION (UI CLEANUP) ---
st.sidebar.header("Configuration")
# Path to the fine-tuned pyannote segmentation checkpoint used for inference.
MODEL_PATH = st.sidebar.text_input("Model Checkpoint Path", "training_results/lightning_logs/version_2/checkpoints/epoch=4-step=2960.ckpt")
# Toggle for the noisereduce preprocessing pass before inference.
use_denoise = st.sidebar.checkbox("Enable De-noising", value=True)

st.sidebar.subheader("Advanced Settings")
# VAD threshold slider (range 0.5–0.95, default 0.80); forwarded to
# process_audio as its `sensitivity` argument.
threshold = st.sidebar.slider("AI Sensitivity (VAD)", 0.5, 0.95, 0.80)
40
+
41
@st.cache_resource
def load_cached_model(path):
    """Load a pyannote Model checkpoint once and reuse it across Streamlit reruns.

    Returns None when no file exists at *path* (the caller surfaces the error).
    """
    return Model.from_pretrained(path) if os.path.exists(path) else None
46
+
47
def process_audio(audio_path, model_path, denoise, sensitivity):
    """Run the full diarization pipeline on one uploaded audio file.

    Parameters
    ----------
    audio_path : str
        Path to the uploaded .wav file.
    model_path : str
        Path to the fine-tuned pyannote segmentation checkpoint.
    denoise : bool
        When True, apply spectral noise reduction before inference.
    sensitivity : float
        VAD onset threshold from the sidebar slider (higher = stricter).

    Returns
    -------
    tuple
        (merged Annotation, path of the audio actually fed to the model),
        or (None, None) when the checkpoint is missing.
    """
    # 1. Load audio resampled to the 16 kHz mono format the model expects.
    y, sr = librosa.load(audio_path, sr=16000)

    # 2. CONVERSION STEP: always write a standard PCM_16 .wav so the model
    # receives a known bit depth / sample rate regardless of the upload.
    audio_input = "converted_audio.wav"
    sf.write(audio_input, y, sr, subtype='PCM_16')

    # 3. AGGRESSIVE DE-NOISING (optional)
    if denoise:
        with st.spinner("Step 1: Deep cleaning audio..."):
            # prop_decrease=0.90 suppresses heavy background noise.
            y = nr.reduce_noise(y=y, sr=sr, prop_decrease=0.90, n_fft=2048)
            audio_input = "temp_denoised.wav"
            sf.write(audio_input, y, sr)
    # BUGFIX: previously the un-denoised branch fed the raw upload to the
    # model, silently bypassing the PCM_16/16 kHz conversion above; the
    # converted file is now used in both branches.

    # 4. AI Inference
    with st.spinner("Step 2: AI Neural Analysis..."):
        model = load_cached_model(model_path)
        if model is None:
            return None, None

        inference = Inference(model, window="sliding", duration=2.0, step=0.5)
        seg_output = inference(audio_input)

        data = np.squeeze(seg_output.data)
        if len(data.shape) == 3:
            data = data[:, :, 0]
        clean_scores = SlidingWindowFeature(data, seg_output.sliding_window)

        # 5. BINARIZATION: a long min_duration_on (1.2 s) ignores short
        # noises/coughs/clicks that otherwise explode the speaker count.
        # BUGFIX: the sidebar "AI Sensitivity (VAD)" value was accepted but
        # never used; it now drives the onset threshold (offset trails it
        # by 0.1, mirroring the original 0.85/0.75 gap).
        binarize = Binarize(onset=sensitivity, offset=max(sensitivity - 0.1, 0.0),
                            min_duration_on=1.2, min_duration_off=0.5)
        raw_hyp = binarize(clean_scores)

        # 6. FEATURE EXTRACTION: one mean activation vector per segment
        # acts as a cheap "voiceprint" for clustering.
        embeddings = []
        segments = []
        for segment, track, label in raw_hyp.itertracks(yield_label=True):
            feature_vector = np.mean(seg_output.crop(segment).data, axis=0).flatten()
            embeddings.append(feature_vector)
            segments.append(segment)

        final_hyp = Annotation()

        if len(embeddings) > 1:
            X = np.array(embeddings)

            # --- AUTO-DETECTION OF SPEAKER COUNT ---
            # Try 2..5 clusters and keep the silhouette-best; fall back to 2
            # when scoring fails (e.g. too few segments per cluster).
            try:
                scores = []
                range_n = range(2, min(len(embeddings), 6))
                for n in range_n:
                    clusterer = AgglomerativeClustering(n_clusters=n, metric='euclidean', linkage='ward')
                    labels = clusterer.fit_predict(X)
                    scores.append(silhouette_score(X, labels))
                best_n = range_n[np.argmax(scores)]
            except Exception:  # narrowed from a bare except:
                best_n = 2  # Safe default for OJT demo

            clusterer = AgglomerativeClustering(n_clusters=best_n, metric='euclidean', linkage='ward')
            final_labels = clusterer.fit_predict(X)

            for i, segment in enumerate(segments):
                final_hyp[segment] = f"Speaker {final_labels[i]}"

        elif len(embeddings) == 1:
            # Only one usable segment: trivially a single speaker.
            final_hyp[segments[0]] = "Speaker 0"

    # .support() is CRITICAL: it merges small gaps of the same speaker.
    return final_hyp.support(), audio_input
122
+
123
# --- MAIN UI ---
uploaded_file = st.file_uploader("Upload .wav file", type=["wav"])

if uploaded_file is not None:
    # Persist the upload so librosa/pyannote can read it by path.
    with open("temp_upload.wav", "wb") as f:
        f.write(uploaded_file.getbuffer())

    col1, col2 = st.columns(2)
    with col1:
        st.subheader("Original Audio")
        st.audio("temp_upload.wav")

    if st.button("Start AI Analysis"):
        hyp, final_audio = process_audio("temp_upload.wav", MODEL_PATH, use_denoise, threshold)

        if hyp is None:
            st.error("Model not found!")
        else:
            with col2:
                if use_denoise:
                    st.subheader("Denoised Version")
                    st.audio(final_audio)

            st.divider()

            unique_speakers = sorted(hyp.labels())
            st.subheader(f"📊 Speaker Timeline ({len(unique_speakers)} Speakers Detected)")

            if len(unique_speakers) > 0:
                fig, ax = plt.subplots(figsize=(12, len(unique_speakers) * 0.8 + 1.5))
                # BUGFIX: plt.cm.get_cmap was deprecated in Matplotlib 3.7
                # and removed in 3.9; plt.get_cmap is the supported API.
                colors = plt.get_cmap('tab10', len(unique_speakers))

                # One horizontal bar row per speaker on the shared timeline.
                for i, speaker in enumerate(unique_speakers):
                    speaker_segments = hyp.label_timeline(speaker)
                    intervals = [(s.start, s.duration) for s in speaker_segments]
                    ax.broken_barh(intervals, (i * 10 + 2, 6), facecolors=colors(i))

                ax.set_yticks([i * 10 + 5 for i in range(len(unique_speakers))])
                ax.set_yticklabels(unique_speakers)
                ax.set_xlabel("Time (seconds)")
                ax.grid(axis='x', linestyle='--', alpha=0.5)
                st.pyplot(fig)

                # Tabular view of every speaker turn, plus CSV export.
                timestamp_list = []
                for segment, track, label in hyp.itertracks(yield_label=True):
                    timestamp_list.append({
                        "Speaker ID": label,
                        "Start (s)": round(segment.start, 2),
                        "End (s)": round(segment.end, 2),
                        "Duration (s)": round(segment.duration, 2)
                    })

                df = pd.DataFrame(timestamp_list)
                st.dataframe(df, use_container_width=True)
                st.download_button("📩 Download CSV", df.to_csv(index=False).encode('utf-8'), "diarization.csv", "text/csv")
            else:
                st.warning("No speech detected.")
outputs/Screenshot 2026-02-03 034427.png ADDED

Git LFS Details

  • SHA256: 9184b1e148610b6f07d12271613b37d100a752733bfcbcc7fa52bea5de888489
  • Pointer size: 131 Bytes
  • Size of remote file: 156 kB
outputs/Screenshot 2026-02-03 051514.png ADDED

Git LFS Details

  • SHA256: 5c9741ff4d11ded797fcc26c7925a4f586b973ab71a94d49390cfef3caeb1172
  • Pointer size: 131 Bytes
  • Size of remote file: 161 kB
outputs/Screenshot 2026-02-03 061911.png ADDED
outputs/Screenshot 2026-02-03 154131.png ADDED

Git LFS Details

  • SHA256: 9c17650e8e30f93225e9c51c16261c71c9489d2a89d22f9f8304770e2061497f
  • Pointer size: 131 Bytes
  • Size of remote file: 112 kB
requirements.txt CHANGED
Binary files a/requirements.txt and b/requirements.txt differ
 
scripts/audioconversion.ipynb ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": []
7
+ },
8
+ "kernelspec": {
9
+ "name": "python3",
10
+ "display_name": "Python 3"
11
+ },
12
+ "language_info": {
13
+ "name": "python"
14
+ }
15
+ },
16
+ "cells": [
17
+ {
18
+ "cell_type": "code",
19
+ "source": [
20
+ "!pip install yt-dlp pydub ffmpeg-python\n",
21
+ "!apt-get install ffmpeg -y # For Colab, to make pydub work"
22
+ ],
23
+ "metadata": {
24
+ "colab": {
25
+ "base_uri": "https://localhost:8080/"
26
+ },
27
+ "id": "QURofI_GiSTK",
28
+ "outputId": "7c46dcd6-d49c-4172-bf27-d9d408352cdf"
29
+ },
30
+ "execution_count": null,
31
+ "outputs": [
32
+ {
33
+ "output_type": "stream",
34
+ "name": "stdout",
35
+ "text": [
36
+ "Collecting yt-dlp\n",
37
+ " Downloading yt_dlp-2025.12.8-py3-none-any.whl.metadata (180 kB)\n",
38
+ "\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/180.3 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m180.3/180.3 kB\u001b[0m \u001b[31m6.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
39
+ "\u001b[?25hRequirement already satisfied: pydub in /usr/local/lib/python3.12/dist-packages (0.25.1)\n",
40
+ "Collecting ffmpeg-python\n",
41
+ " Downloading ffmpeg_python-0.2.0-py3-none-any.whl.metadata (1.7 kB)\n",
42
+ "Requirement already satisfied: future in /usr/local/lib/python3.12/dist-packages (from ffmpeg-python) (1.0.0)\n",
43
+ "Downloading yt_dlp-2025.12.8-py3-none-any.whl (3.3 MB)\n",
44
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.3/3.3 MB\u001b[0m \u001b[31m55.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
45
+ "\u001b[?25hDownloading ffmpeg_python-0.2.0-py3-none-any.whl (25 kB)\n",
46
+ "Installing collected packages: yt-dlp, ffmpeg-python\n",
47
+ "Successfully installed ffmpeg-python-0.2.0 yt-dlp-2025.12.8\n",
48
+ "Reading package lists... Done\n",
49
+ "Building dependency tree... Done\n",
50
+ "Reading state information... Done\n",
51
+ "ffmpeg is already the newest version (7:4.4.2-0ubuntu0.22.04.1).\n",
52
+ "0 upgraded, 0 newly installed, 0 to remove and 2 not upgraded.\n"
53
+ ]
54
+ }
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "code",
59
+ "source": [
60
+ "import os\n",
61
+ "from pydub import AudioSegment\n",
62
+ "import yt_dlp\n",
63
+ "\n",
64
+ "# -------------------------------\n",
65
+ "# 1️⃣ YouTube video URL\n",
66
+ "# -------------------------------\n",
67
+ "url = \"https://youtu.be/uvOF0qn_r_0?si=-Zd4-p22-bjgEAWT\" # Replace VIDEO_ID with your YouTube link\n",
68
+ "\n",
69
+ "# -------------------------------\n",
70
+ "# 2️⃣ Download audio using yt-dlp\n",
71
+ "# -------------------------------\n",
72
+ "ydl_opts = {\n",
73
+ " 'format': 'bestaudio/best',\n",
74
+ " 'outtmpl': 'video_audio.%(ext)s',\n",
75
+ " 'postprocessors': [{\n",
76
+ " 'key': 'FFmpegExtractAudio',\n",
77
+ " 'preferredcodec': 'mp3', # download as mp3 first\n",
78
+ " 'preferredquality': '192',\n",
79
+ " }],\n",
80
+ "}\n",
81
+ "\n",
82
+ "print(\"Downloading audio from YouTube...\")\n",
83
+ "with yt_dlp.YoutubeDL(ydl_opts) as ydl:\n",
84
+ " ydl.download([url])\n",
85
+ "\n",
86
+ "# Find downloaded file\n",
87
+ "for file in os.listdir():\n",
88
+ " if file.startswith(\"video_audio\") and file.endswith(\".mp3\"):\n",
89
+ " audio_file = file\n",
90
+ " break\n",
91
+ "\n",
92
+ "print(\"Downloaded:\", audio_file)\n",
93
+ "\n",
94
+ "# -------------------------------\n",
95
+ "# 3️⃣ Convert audio to WAV (16kHz, mono)\n",
96
+ "# -------------------------------\n",
97
+ "audio = AudioSegment.from_file(audio_file)\n",
98
+ "audio = audio.set_channels(1) # mono\n",
99
+ "audio = audio.set_frame_rate(16000) # 16 kHz\n",
100
+ "wav_filename = \"znmd.wav\"\n",
101
+ "audio.export(wav_filename, format=\"wav\")\n",
102
+ "print(\"Converted to WAV:\", wav_filename)\n",
103
+ "\n",
104
+ "# -------------------------------\n",
105
+ "# 4️⃣ Split WAV into 1-minute chunks\n",
106
+ "# -------------------------------\n",
107
+ "chunk_length_ms = 60 * 1000 # 1 minute\n",
108
+ "for i, start in enumerate(range(0, len(audio), chunk_length_ms)):\n",
109
+ " chunk = audio[start:start+chunk_length_ms]\n",
110
+ " chunk_filename = f\"ZNMD_chunk_{i}.wav\"\n",
111
+ " chunk.export(chunk_filename, format=\"wav\")\n",
112
+ " print(\"Saved chunk:\", chunk_filename)\n",
113
+ "\n",
114
+ "print(\"All steps completed! Your audio is ready for diarization.\")\n"
115
+ ],
116
+ "metadata": {
117
+ "id": "oPEdKqgLLUEw",
118
+ "colab": {
119
+ "base_uri": "https://localhost:8080/"
120
+ },
121
+ "outputId": "0c940970-dd15-40f4-9bfb-9fb18120d879"
122
+ },
123
+ "execution_count": null,
124
+ "outputs": [
125
+ {
126
+ "output_type": "stream",
127
+ "name": "stdout",
128
+ "text": [
129
+ "Downloading audio from YouTube...\n",
130
+ "[youtube] Extracting URL: https://youtu.be/uvOF0qn_r_0?si=-Zd4-p22-bjgEAWT\n",
131
+ "[youtube] uvOF0qn_r_0: Downloading webpage\n"
132
+ ]
133
+ },
134
+ {
135
+ "output_type": "stream",
136
+ "name": "stderr",
137
+ "text": [
138
+ "WARNING: [youtube] No supported JavaScript runtime could be found. Only deno is enabled by default; to use another runtime add --js-runtimes RUNTIME[:PATH] to your command/config. YouTube extraction without a JS runtime has been deprecated, and some formats may be missing. See https://github.com/yt-dlp/yt-dlp/wiki/EJS for details on installing one\n"
139
+ ]
140
+ },
141
+ {
142
+ "output_type": "stream",
143
+ "name": "stdout",
144
+ "text": [
145
+ "[youtube] uvOF0qn_r_0: Downloading android sdkless player API JSON\n",
146
+ "[youtube] uvOF0qn_r_0: Downloading web safari player API JSON\n"
147
+ ]
148
+ },
149
+ {
150
+ "output_type": "stream",
151
+ "name": "stderr",
152
+ "text": [
153
+ "WARNING: [youtube] uvOF0qn_r_0: Some web_safari client https formats have been skipped as they are missing a url. YouTube is forcing SABR streaming for this client. See https://github.com/yt-dlp/yt-dlp/issues/12482 for more details\n"
154
+ ]
155
+ },
156
+ {
157
+ "output_type": "stream",
158
+ "name": "stdout",
159
+ "text": [
160
+ "[youtube] uvOF0qn_r_0: Downloading m3u8 information\n"
161
+ ]
162
+ },
163
+ {
164
+ "output_type": "stream",
165
+ "name": "stderr",
166
+ "text": [
167
+ "WARNING: [youtube] uvOF0qn_r_0: Some web client https formats have been skipped as they are missing a url. YouTube is forcing SABR streaming for this client. See https://github.com/yt-dlp/yt-dlp/issues/12482 for more details\n"
168
+ ]
169
+ },
170
+ {
171
+ "output_type": "stream",
172
+ "name": "stdout",
173
+ "text": [
174
+ "[info] uvOF0qn_r_0: Downloading 1 format(s): 251\n",
175
+ "[download] Destination: video_audio.webm\n",
176
+ "[download] 100% of 3.73MiB in 00:00:00 at 13.24MiB/s \n",
177
+ "[ExtractAudio] Destination: video_audio.mp3\n",
178
+ "Deleting original file video_audio.webm (pass -k to keep)\n",
179
+ "Downloaded: video_audio.mp3\n",
180
+ "Converted to WAV: znmd.wav\n",
181
+ "Saved chunk: ZNMD_chunk_0.wav\n",
182
+ "Saved chunk: ZNMD_chunk_1.wav\n",
183
+ "Saved chunk: ZNMD_chunk_2.wav\n",
184
+ "Saved chunk: ZNMD_chunk_3.wav\n",
185
+ "Saved chunk: ZNMD_chunk_4.wav\n",
186
+ "Saved chunk: ZNMD_chunk_5.wav\n",
187
+ "All steps completed! Your audio is ready for diarization.\n"
188
+ ]
189
+ }
190
+ ]
191
+ },
192
+ {
193
+ "cell_type": "code",
194
+ "source": [],
195
+ "metadata": {
196
+ "id": "UXUHD4sTIjhe"
197
+ },
198
+ "execution_count": null,
199
+ "outputs": []
200
+ }
201
+ ]
202
+ }
scripts/check.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
# Quick environment sanity check: can the Segmentation task be imported from
# pyannote.audio?  Segmentation here is just a class (the fine-tuning task
# definition) — not a trained model and not a training run.
try:
    from pyannote.audio.tasks import Segmentation
except ImportError as e:
    print(f"Still failing: {e}")
else:
    print("Success! Segmentation task imported.")
scripts/check_rttm.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Sanity-check one RTTM file: confirm it exists, then verify that the second
# whitespace-separated field of its first line (the file URI) matches the URI
# the pyannote protocol expects.
import os

uri = "hindi_chunk_22"
rttm_path = f"dataset/rttm/{uri}.rttm"

print(f"Checking for RTTM at: {os.path.abspath(rttm_path)}")

if not os.path.exists(rttm_path):
    print("RTTM file NOT found at that path!")
else:
    print("RTTM file exists!")
    with open(rttm_path, 'r') as handle:
        first_line = handle.readline()
        print(f"First line of RTTM: {first_line.strip()}")

    # RTTM format: field 0 is the record type, field 1 is the file URI.
    parts = first_line.split()
    if len(parts) > 1:
        rttm_uri = parts[1]
        if rttm_uri == uri:
            print(f"URI Match: '{rttm_uri}' matches '{uri}'")
        else:
            print(f"URI MISMATCH: RTTM says '{rttm_uri}' but protocol expects '{uri}'")
scripts/diarization.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
scripts/diarization_visualization.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import torch
import matplotlib.pyplot as plt
from pyannote.metrics.diarization import DiarizationErrorRate

# THE ULTIMATE BYPASS (Fixes PyTorch 2.6 security errors)
# torch.load defaults to weights_only=True on PyTorch >= 2.6, which rejects
# the pickled objects inside pyannote pipelines; force the legacy behaviour.
# SECURITY NOTE: only load checkpoints/pipelines you trust.
import torch.serialization
original_load = torch.load
def patched_load(*args, **kwargs):
    # Force weights_only=False on every torch.load call.
    kwargs['weights_only'] = False
    return original_load(*args, **kwargs)
torch.load = patched_load

# IMPORTS
from pyannote.core import notebook
from pyannote.audio import Pipeline
from pyannote.database.util import load_rttm

# One audio clip and its matching ground-truth RTTM annotation.
AUDIO_PATH = r"dataset/audio/clip_03.wav"
RTTM_PATH = r"dataset/rttm/clip_03.rttm"

# INITIALIZE PIPELINE (off-the-shelf pretrained pipeline, not the fine-tuned one)
print("Initializing AI Pipeline...")
pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1",
    use_auth_token="hf_token_here"  # Replace with your Hugging Face token
)

# --- RUN DIARIZATION ---
print("AI is analyzing the audio...")
prediction = pipeline(AUDIO_PATH)

# --- LOAD GROUND TRUTH ---
# load_rttm returns {uri: Annotation}; this file contains a single entry.
gt_dict = load_rttm(RTTM_PATH)
uri = list(gt_dict.keys())[0]
ground_truth = gt_dict[uri]

# --- FIXED: CALCULATE DER USING REPORT ---
metric = DiarizationErrorRate()
# Accumulate this file into the metric, then render the per-file report.
# NOTE(review): `notebook=True` is passed straight through to the metric
# call — confirm the installed pyannote.metrics version accepts that kwarg.
metric(ground_truth, prediction, notebook=True)
report = metric.report(display=True)

print("\n" + "="*50)
print("FINAL EVALUATION REPORT")
print(report)
print("="*50 + "\n")

## --- VISUALIZATION (UNCHANGED) ---
# Two stacked timelines sharing the time axis: reference above, model
# prediction below, for visual side-by-side comparison.
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 8), sharex=True)

plt.sca(ax1)
notebook.plot_annotation(ground_truth, ax=ax1)
ax1.set_title("REFERENCE (Ground Truth)", fontsize=14, fontweight='bold')

plt.sca(ax2)
notebook.plot_annotation(prediction, ax=ax2)
ax2.set_title("HYPOTHESIS (Model Prediction)", fontsize=14, fontweight='bold')

plt.xlabel("Time (seconds)", fontsize=12)
plt.tight_layout()

print("Diarization complete! Displaying plot...")
plt.show()
scripts/make_splits.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Divide the audio dataset into train / dev / test URI lists (80/10/10) that
# pyannote will later consume as protocol subsets.

import os
import random

# Path to audio folder
audio_dir = "dataset/audio"

# Collect all URIs (wav filenames without the extension).
# BUGFIX: str.replace(".wav", "") would also corrupt names containing
# ".wav" in the middle; strip only the suffix.
uris = [
    f[:-len(".wav")]
    for f in os.listdir(audio_dir)
    if f.endswith(".wav")
]

# Sanity check against the expected corpus size
if len(uris) != 89:
    print(f"Warning: expected 89 files, found {len(uris)}")

# Deterministic shuffle so the split is reproducible
random.seed(42)
random.shuffle(uris)

# BUGFIX: the split points were hard-coded for exactly 89 files even though
# the warning above anticipates other counts; compute them proportionally
# (80/10/10) instead.  For 89 files this still yields the original 71/9/9.
n_train = int(len(uris) * 0.8)
n_dev = int(len(uris) * 0.9)
train = uris[:n_train]
dev = uris[n_train:n_dev]
test = uris[n_dev:]

# Create splits folder if not exists
os.makedirs("dataset/splits", exist_ok=True)

def write_split(name, data):
    """Write one URI per line to dataset/splits/<name>.txt (UTF-8)."""
    with open(f"dataset/splits/{name}.txt", "w", encoding="utf-8") as f:
        for uri in data:
            f.write(uri + "\n")

write_split("train", train)
write_split("dev", dev)
write_split("test", test)

# Print summary
print("Dataset split completed:")
print(f"  Train: {len(train)} files")
print(f"  Dev  : {len(dev)} files")
print(f"  Test : {len(test)} files")
scripts/run.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
# MUST BE AT THE VERY TOP: force SpeechBrain to copy fetched model files
# instead of symlinking (must run before any speechbrain-backed import).
os.environ["SPEECHBRAIN_LOCAL_STRATEGY"] = "copy"

import torch
import torchaudio
import pandas as pd
from pyannote.audio import Model
from pyannote.audio.pipelines import SpeakerDiarization
from pyannote.database.util import load_rttm
from pyannote.metrics.diarization import DiarizationErrorRate, DiarizationPurity, DiarizationCoverage

# --- THE DEFINITIVE FIX FOR PYTORCH 2.6+ SECURITY ERRORS ---
# torch.load defaults to weights_only=True on PyTorch >= 2.6, rejecting
# pickled pyannote checkpoints; wrap the loader to restore old behaviour.
# SECURITY NOTE: this disables safe unpickling — only load trusted files.
import torch.serialization
original_load = torch.load
def forced_load(f, map_location=None, pickle_module=None, **kwargs):
    # Force weights_only=False on every torch.load call.
    kwargs['weights_only'] = False
    return original_load(f, map_location=map_location, pickle_module=pickle_module, **kwargs)
torch.load = forced_load
# -----------------------------------------------------------

# Configuration - Update these paths to match your project structure
CHECKPOINT_PATH = "training_results/lightning_logs/version_2/checkpoints/epoch=4-step=2960.ckpt"  # fine-tuned segmentation checkpoint
TEST_LIST_PATH = "dataset/splits/test.txt"    # one URI per line
AUDIO_DIR = "dataset/audio"                   # {uri}.wav
RTTM_DIR = "dataset/rttm"                     # {uri}.rttm ground truth
OUTPUT_CSV = "overall_model_performance.csv"  # per-file DER breakdown

def run_global_evaluation():
    """Evaluate the fine-tuned diarization pipeline over the whole test set.

    Builds a SpeakerDiarization pipeline around the fine-tuned segmentation
    checkpoint, diarizes every file listed in TEST_LIST_PATH, accumulates a
    global Diarization Error Rate against the RTTM references, prints a
    summary, and saves the per-file report to OUTPUT_CSV.
    """
    # 1. Load the fine-tuned model
    print(f"Loading fine-tuned model from: {CHECKPOINT_PATH}")
    seg_model = Model.from_pretrained(CHECKPOINT_PATH)

    # 2. Initialize the Diarization Pipeline
    print("Initializing Pipeline...")
    pipeline = SpeakerDiarization(
        segmentation=seg_model,
        embedding="speechbrain/spkrec-ecapa-voxceleb",
        clustering="AgglomerativeClustering",
    )

    # Balanced parameters for diverse speaker counts
    params = {
        "segmentation": {
            "threshold": 0.58,  # High threshold to kill False Alarms
            "min_duration_off": 0.2,  # Prevents fragmented "flickering" between speakers
        },
        "clustering": {
            "method": "centroid",
            "threshold": 0.62,  # Lower threshold to encourage speaker separation
            "min_cluster_size": 1,
        },
    }
    pipeline.instantiate(params)

    # 3. Initialize Metrics
    # Using 'total' metrics to accumulate across all files
    total_der_metric = DiarizationErrorRate()

    # 4. Load filenames from test.txt
    with open(TEST_LIST_PATH, 'r') as f:
        # Extract the URI (filename without extension) from each line.
        # Adjust the split logic if your test.txt has a different format
        # (e.g., space-separated).
        test_files = [line.strip().split()[0] for line in f if line.strip()]

    print(f"Found {len(test_files)} files in test set. Starting Batch Processing...")
    print("-" * 50)

    for uri in test_files:
        audio_path = os.path.join(AUDIO_DIR, f"{uri}.wav")
        rttm_path = os.path.join(RTTM_DIR, f"{uri}.rttm")

        if not os.path.exists(audio_path):
            print(f"Warning: Audio file not found for {uri}. Skipping.")
            continue

        # Load Reference RTTM; skip (don't abort the batch) on failure.
        try:
            reference = load_rttm(rttm_path)[uri]
        except Exception as e:
            print(f"Warning: Could not load RTTM for {uri}. Error: {e}")
            continue

        # Run Diarization on the in-memory waveform.
        waveform, sample_rate = torchaudio.load(audio_path)
        test_file = {"waveform": waveform, "sample_rate": sample_rate, "uri": uri}

        # We allow the AI to determine speaker count dynamically (min 2, max 7)
        hypothesis = pipeline(test_file, min_speakers=2, max_speakers=7)

        # Accumulate the metric (the metric object keeps running totals).
        total_der_metric(reference, hypothesis, detailed=True)
        print(f"Done: {uri}")

    # 5. Final Calculations
    print("\n" + "="*50)
    print("   FINAL GLOBAL REPORT")
    print("="*50)

    # This creates a detailed table per file
    report_df = total_der_metric.report(display=True)

    # Global DER is the value of the metric after processing all files
    global_der = abs(total_der_metric)
    global_accuracy = max(0, (1 - global_der) * 100)

    print(f"\nOVERALL SYSTEM ACCURACY : {global_accuracy:.2f}%")
    print(f"GLOBAL DIARIZATION ERROR: {global_der * 100:.2f}%")
    print("="*50)

    # Save detailed report to CSV for your documentation
    report_df.to_csv(OUTPUT_CSV)
    print(f"Detailed file-by-file breakdown saved to: {OUTPUT_CSV}")

if __name__ == "__main__":
    run_global_evaluation()
scripts/segmentation.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
import os
import torch
import torchaudio
import torch.serialization
from pyannote.core import Segment, Timeline

# --- 1. MONKEY PATCH (Fixes PyTorch 2.6 Security Error) ---
# Force weights_only=False on both torch.load and torch.serialization.load
# so pickled pyannote checkpoints can be deserialized.
# SECURITY NOTE: only load trusted checkpoint files.
original_load = torch.serialization.load
def forced_load(f, map_location=None, pickle_module=None, **kwargs):
    kwargs['weights_only'] = False
    return original_load(f, map_location=map_location, pickle_module=pickle_module, **kwargs)
torch.load = forced_load
torch.serialization.load = forced_load
# ---------------------------------------------------------

# Model: neural network, Segmentation: training logic,
# get_protocol: dataset loader, pl.Trainer: training engine

from pyannote.audio import Model
from pyannote.audio.tasks import Segmentation
from pyannote.database import get_protocol, FileFinder
import pytorch_lightning as pl

# Tell pyannote.database where the protocol definition lives.
os.environ["PYANNOTE_DATABASE_CONFIG"] = "database.yml"

def train_segmentation():
    """Fine-tune pyannote's pretrained segmentation model on the
    Hindi/Bhojpuri protocol defined in database.yml (CPU, 5 epochs)."""
    # 2. PREPROCESSORS
    def get_annotated(file):
        # The protocol lacks an 'annotated' region; derive it as the full
        # duration of each audio file.
        info = torchaudio.info(file["audio"])
        # Calculate duration: total frames / sample rate
        duration = info.num_frames / info.sample_rate
        # Return the 'Timeline' object the library is looking for
        return Timeline([Segment(0, duration)])

    preprocessors = {
        "audio": FileFinder(),
        "annotated": get_annotated,
    }

    # 3. LOAD PROTOCOL
    print("Loading Hindi-Bhojpuri Protocol...")
    protocol = get_protocol(
        'HindiBhojpuri.SpeakerDiarization.Segmentation',
        preprocessors=preprocessors
    )

    # 4. SETUP TASK
    # 2-second training windows, small batch, single-process data loading
    # (num_workers=0 suits the CPU-only run below).
    seg_task = Segmentation(
        protocol,
        duration=2.0,
        batch_size=4,
        num_workers=0
    )

    # 5. LOAD MODEL — transfer learning: start from pyannote's pretrained
    # segmentation model and adapt it to Hindi/Bhojpuri, not from scratch.
    print("Attempting to load model...")
    model = Model.from_pretrained("pyannote/segmentation-3.0")
    model.task = seg_task

    # 6. TRAINER
    trainer = pl.Trainer(
        accelerator="cpu",
        max_epochs=5,
        default_root_dir="training_results"
    )

    # 7. START
    print("--- Starting Fine-tuning ---")
    trainer.fit(model)

if __name__ == "__main__":
    train_segmentation()
scripts/test_model.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
import torchaudio
from pyannote.audio import Model
from pyannote.core import Annotation, Segment

# 1. PATHS
# Fine-tuned segmentation checkpoint and a single clip to sanity-check.
CHECKPOINT_PATH = "training_results/lightning_logs/version_2/checkpoints/epoch=4-step=2960.ckpt"
TEST_AUDIO = "dataset/audio/clip_07.wav"

def run_test():
    """Feed one clip straight through the raw segmentation model and print
    per-slot speech activity — a debugging aid that bypasses the full
    diarization pipeline to show what the model itself detects."""
    print(f"Loading model directly...")
    model = Model.from_pretrained(CHECKPOINT_PATH)
    model.eval()  # Set to evaluation mode

    # 2. Load Audio Manually
    waveform, sample_rate = torchaudio.load(TEST_AUDIO)

    # Model expects [batch, channels, samples] - adding a batch dimension
    if waveform.ndim == 2:
        waveform = waveform.unsqueeze(0)

    print("Running raw inference...")
    with torch.no_grad():
        # Get raw scores [batch, frames, speakers]
        # This returns probabilities for each speaker class
        scores = model(waveform)

    # 3. Simple thresholding to find speakers
    # If score > 0.5, we consider that speaker "active"
    print("\n--- Raw Model Detections ---")

    # The output has several speaker 'slots' (last tensor dimension).
    num_speakers = scores.shape[-1]

    # Simplified debugging logic: report only the first and last active
    # frame per slot, not every contiguous region.
    for s in range(num_speakers):
        active_frames = torch.where(scores[0, :, s] > 0.5)[0]
        if len(active_frames) > 0:
            # NOTE(review): 0.016 s is only an approximate frame shift —
            # confirm against the model's actual frame resolution.
            start_time = active_frames[0] * 0.016  # Approximate frame shift
            end_time = active_frames[-1] * 0.016
            print(f"Speaker Slot {s}: Detected activity between {start_time:.2f}s and {end_time:.2f}s")

if __name__ == "__main__":
    run_test()
scripts/test_protocol.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# It checks whether pyannote can correctly read your dataset (audio + RTTM)
# using database.yml.
# get_protocol asks pyannote for the dataset described in database.yml;
# FileFinder helps pyannote locate the audio files on disk.
import os
from pyannote.database import get_protocol, FileFinder

# 1. Setup paths
# Resolve database.yml relative to this script so it works from any CWD.
current_dir = os.path.dirname(os.path.abspath(__file__))
config_path = os.path.join(current_dir, "database.yml")
os.environ["PYANNOTE_DATABASE_CONFIG"] = config_path

# 2. Initialize preprocessor
preprocessors = {'audio': FileFinder()}

# 3. Load the protocol; abort the script if the config is broken.
try:
    protocol = get_protocol(
        'HindiBhojpuri.SpeakerDiarization.Segmentation',
        preprocessors=preprocessors
    )
    print("Protocol loaded successfully!")
except Exception as e:
    print(f"Failed to load protocol: {e}")
    exit()

# 4. Detailed Data Verification
# Inspect only the first test file: print its URI, audio path, and the first
# few annotated speaker segments.
for file in protocol.test():
    print("\n" + "="*30)
    print(f"FILE URI: {file['uri']}")
    print(f"AUDIO PATH: {file['audio']}")

    # Load the annotation (the RTTM data)
    annotation = file['annotation']
    print(f"SEGMENTS FOUND: {len(annotation)}")

    print("-" * 30)
    print("START | END | SPEAKER")
    print("-" * 30)

    # Iterate through the first 5 segments to keep the output clean
    for i, (segment, track, label) in enumerate(annotation.itertracks(yield_label=True)):
        if i >= 5:
            print("... (and more)")
            break
        print(f"{segment.start:9.2f}s | {segment.end:9.2f}s | {label}")

    print("="*30)

    # Only check the first file for now
    break


# database.yml is correct, audio paths are correct
# RTTM files load correctly
# speaker segments exist
# segmentation training CAN start
scripts/visualize_segmentation.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torchaudio
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ from pyannote.audio import Model, Inference
7
+ from pyannote.audio.utils.signal import Binarize
8
+ from pyannote.database.util import load_rttm
9
+ from pyannote.core import notebook, SlidingWindowFeature, Annotation
10
+ from sklearn.cluster import AgglomerativeClustering
11
+
12
# --- 1. PYTORCH 2.6+ SECURITY FIX ---
# PyTorch 2.6 flipped torch.load's default to weights_only=True, which breaks
# loading pyannote checkpoints that contain pickled non-tensor objects.
import torch.serialization
original_load = torch.load
def forced_load(f, map_location=None, pickle_module=None, **kwargs):
    """torch.load wrapper that defaults weights_only to False.

    Uses setdefault rather than an unconditional assignment so a caller
    who explicitly passes weights_only=True is honored instead of being
    silently clobbered.

    WARNING: weights_only=False unpickles arbitrary objects -- only load
    checkpoints from trusted sources.
    """
    kwargs.setdefault('weights_only', False)
    return original_load(f, map_location=map_location, pickle_module=pickle_module, **kwargs)
torch.load = forced_load
# ------------------------------------
20
+
21
def visualize_audio_file(audio_path, rttm_path, checkpoint_path):
    """Plot the ground-truth RTTM annotation against the model's cleaned output.

    Loads a segmentation checkpoint, runs sliding-window inference over the
    audio, binarizes the raw activations with an aggressive onset threshold,
    folds the per-class labels into at most 5 speaker names, and shows a
    two-row matplotlib figure (reference on top, hypothesis below).

    Args:
        audio_path: Path to the .wav file to analyse.
        rttm_path: Ground-truth RTTM whose file id matches the wav basename.
        checkpoint_path: pyannote.audio Lightning checkpoint (.ckpt) to load.
    """
    # uri must equal the RTTM file id so the load_rttm lookup below succeeds.
    file_id = os.path.basename(audio_path).replace('.wav', '')
    print(f"--- Processing: {file_id} ---")

    # 1. Load Model & Run Inference (2.0 s sliding windows, 0.5 s hop)
    model = Model.from_pretrained(checkpoint_path)
    inference = Inference(model, window="sliding", duration=2.0, step=0.5)
    seg_output = inference(audio_path)

    # 2. Reshape and Binarize (Using a high threshold to remove background noise)
    data = np.squeeze(seg_output.data)
    # NOTE(review): if the squeezed output is still 3-D this keeps only
    # channel 0 of the trailing axis -- confirm against the model's layout.
    if len(data.shape) == 3: data = data[:, :, 0]

    # Higher onset (0.8) ignores the "messy" low-volume background noises
    binarize = Binarize(onset=0.8, offset=0.6, min_duration_on=0.4, min_duration_off=0.2)
    raw_hypothesis = binarize(SlidingWindowFeature(data, seg_output.sliding_window))

    # 3. MANUAL CLUSTERING (The fix for the rainbow/messy graph)
    print("Clustering segments to simplify speakers...")
    final_hypothesis = Annotation(uri=file_id)

    # We take all those tiny segments and group them by their "class" index
    # In raw segmentation, the 'class' index acts as a temporary speaker ID
    for segment, track, label in raw_hypothesis.itertracks(yield_label=True):
        # We simplify the labels: "0", "1", "2" instead of "104", "112", etc.
        # NOTE(review): `label % 5` assumes labels are ints (class indices);
        # a string label would raise TypeError here -- verify Binarize output.
        final_hypothesis[segment, track] = f"Speaker_{label % 5}"

    # 4. Load Ground Truth (load_rttm returns a dict keyed by file id)
    reference = load_rttm(rttm_path)[file_id]

    # 5. Plotting: two stacked timelines for visual comparison.
    print("Generating Clean Graph...")
    fig, ax = plt.subplots(nrows=2, ncols=1, figsize=(15, 8))

    # Ground Truth
    notebook.plot_annotation(reference, ax=ax[0], time=True, legend=True)
    ax[0].set_title(f"GROUND TRUTH: {file_id}")

    # Simplified AI Result
    notebook.plot_annotation(final_hypothesis, ax=ax[1], time=True, legend=True)
    ax[1].set_title(f"CLEANED AI HYPOTHESIS (Clustered & Filtered)")

    plt.tight_layout()
    plt.show()
65
+
66
+ if __name__ == "__main__":
67
+ AUDIO_FILE = "dataset/audio/bhojpuri_chunk_20.wav"
68
+ RTTM_FILE = "dataset/rttm/bhojpuri_chunk_20.rttm"
69
+ MODEL_CHECKPOINT = "training_results/lightning_logs/version_2/checkpoints/epoch=4-step=2960.ckpt"
70
+
71
+ if os.path.exists(AUDIO_FILE):
72
+ visualize_audio_file(AUDIO_FILE, RTTM_FILE, MODEL_CHECKPOINT)
test/audio2.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:daf78623168634fb66384ae3341f9bf5ab1e57fc4694c4e34b68f022a1527478
3
+ size 842157