Capstone04 committed
Commit 0b3a48c · verified · 1 parent: 273f384

Upload folder using huggingface_hub

README.md ADDED
@@ -0,0 +1,20 @@
+ ---
+ language: en
+ tags:
+ - asr
+ - diarization
+ pipeline_tag: automatic-speech-recognition
+ ---
+ # ASR + Diarization Pipeline
+
+ This package provides an **Automatic Speech Recognition (ASR) + Speaker Diarization** pipeline built on:
+ - [OpenAI Whisper](https://huggingface.co/openai/whisper-medium)
+ - [Pyannote diarization](https://huggingface.co/pyannote/speaker-diarization-3.1)
+
+ ## Install
+ ```bash
+ pip install git+https://huggingface.co/Capstone04/asr-diarization-pipeline
+ ```
+
+ ## Speaker Identification
+ You can enroll known speakers by providing reference audio samples. The pipeline matches incoming speaker segments against the stored embeddings and labels them accordingly; unknown speakers are tracked dynamically per session. A sketch of the enrollment step follows below.
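A minimal enrollment sketch (the `enroll_speaker` helper, the file names, and the `known_speakers.json` location are illustrative, not a packaged API); it writes the name → unit-normalized ECAPA embedding map that `load_known_embeddings` in `asr_diarization/inference.py` reads:

```python
import json
import os

import numpy as np
import torch
import torchaudio
from speechbrain.pretrained import EncoderClassifier

# Same ECAPA checkpoint the pipeline itself loads for speaker embeddings.
encoder = EncoderClassifier.from_hparams(source="speechbrain/spkrec-ecapa-voxceleb")

def enroll_speaker(name, wav_path, store_path="known_speakers.json"):
    """Hypothetical helper: embed one reference clip and store it under `name`."""
    audio, sr = torchaudio.load(wav_path)                       # (channels, samples)
    mono = audio.mean(dim=0, keepdim=True)                      # collapse to mono
    if sr != 16000:
        mono = torchaudio.functional.resample(mono, sr, 16000)  # ECAPA expects 16 kHz
    with torch.no_grad():
        emb = encoder.encode_batch(mono).squeeze().cpu().numpy()
    emb = emb / np.linalg.norm(emb)                             # unit norm, as the pipeline assumes
    store = {}
    if os.path.exists(store_path):
        with open(store_path) as f:
            store = json.load(f)
    store[name] = emb.tolist()
    with open(store_path, "w") as f:
        json.dump(store, f, indent=2)

enroll_speaker("alice", "alice_reference.wav")  # placeholder name and file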
asr_diarization/__init__.py ADDED
@@ -0,0 +1 @@
+ from .pipeline import ASR_Diarization
asr_diarization/inference.py ADDED
@@ -0,0 +1,25 @@
+ import os
+ import json
+
+ import numpy as np
+
+ from .pipeline import ASR_Diarization
+
+ def load_known_embeddings(path="known_speakers.json"):
+     if not os.path.exists(path):
+         return {}
+     with open(path, "r") as f:
+         raw = json.load(f)
+     return {name: np.array(emb, dtype=np.float32) for name, emb in raw.items()}
+
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
+ known_embeddings = load_known_embeddings()
+ pipe = ASR_Diarization(HF_TOKEN, known_embeddings=known_embeddings)
+
+ def inference(inputs):
+     return pipe(inputs)
+
+ def inference_with_eval(inputs, output_dir, base_name, ref_rttm=None, ref_json=None):
+     result = pipe(inputs)
+     pipe.evaluate(output_dir, base_name, ref_rttm, ref_json)
+     return result
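For local testing, a minimal driver sketch; `sample.wav` is a placeholder, and `HF_TOKEN` must be set in the environment before import, since importing the module constructs the pipeline:

```python
# inference() accepts raw audio bytes; ASR_Diarization.__call__ writes them
# to a temporary .wav before running diarization + ASR.
from asr_diarization.inference import inference

with open("sample.wav", "rb") as f:  # placeholder input file
    result = inference(f.read())

for seg in result["segments"]:
    print(f'{seg["speaker"]} [{seg["start"]:.2f}-{seg["end"]:.2f}]: {seg["text"]}')
```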
asr_diarization/pipeline.py ADDED
@@ -0,0 +1,323 @@
+ import os
+ import json
+ import torch
+
+ # Explicitly enable TF32 on CUDA: faster matmuls on Ampere+ GPUs, and setting it up front silences the TF32 reproducibility warning.
+ if torch.cuda.is_available():
+     torch.backends.cuda.matmul.allow_tf32 = True
+     torch.backends.cudnn.allow_tf32 = True
+
+ import tempfile
+ import torchaudio
+ import threading
+ import numpy as np
+ import soundfile as sf
+ import noisereduce as nr
+ from scipy import signal
+ from numpy.linalg import norm
+ from pyannote.audio import Pipeline
+ from speechbrain.pretrained import EncoderClassifier  # moved to speechbrain.inference in SpeechBrain >= 1.0
+ from pyannote.core import Annotation, Segment
+ from transformers import pipeline as hf_pipeline
+ from pyannote.metrics.diarization import DiarizationErrorRate
+ from jiwer import wer, Compose, ToLowerCase, RemovePunctuation, RemoveMultipleSpaces, Strip
+
+ class ASR_Diarization:
+     def __init__(self, HF_TOKEN, known_embeddings=None,
+                  diar_model="pyannote/speaker-diarization-3.1",
+                  asr_model="openai/whisper-medium"):
+         self.HF_TOKEN = HF_TOKEN
+         # Optional {name: embedding} map of enrolled speakers, forwarded to run_pipeline by __call__.
+         self.known_embeddings = known_embeddings or {}
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self._unknown_lock = threading.Lock()
+
+         try:
+             self.embedding_model = EncoderClassifier.from_hparams(
+                 source="speechbrain/spkrec-ecapa-voxceleb",
+                 run_opts={"device": str(self.device)}
+             )
+             print("[ECAPA] Model loaded successfully.")
+         except Exception as e:
+             self.embedding_model = None
+             print(f"[ERROR] Failed to load ECAPA: {e}")
+
+         self.diar_pipeline = Pipeline.from_pretrained(diar_model, use_auth_token=HF_TOKEN)
+         device_index = 0 if torch.cuda.is_available() else -1
+         self.asr_pipeline = hf_pipeline(
+             "automatic-speech-recognition",
+             model=asr_model,
+             device=device_index,
+             return_timestamps=True
+         )
+
+     def run_diarization(self, audio_path):
+         diarization = self.diar_pipeline(audio_path)
+         return [
+             {"start": t.start, "end": t.end, "speaker": spk}
+             for t, _, spk in diarization.itertracks(yield_label=True)
+         ]
+
+     def load_unknown_speakers(self, unknown_speakers_path):
+         if os.path.exists(unknown_speakers_path):
+             try:
+                 with open(unknown_speakers_path, "r") as f:
+                     content = f.read().strip()
+                 if content:
+                     return json.loads(content)
+             except Exception as e:
+                 print(f"[WARN] Failed to load unknown speakers ({e}), starting fresh")
+         return {}
+
+     def save_unknown_speakers(self, unknown_speakers, unknown_speakers_path):
+         try:
+             os.makedirs(os.path.dirname(unknown_speakers_path), exist_ok=True)
+             tmp = unknown_speakers_path + ".tmp"
+             with open(tmp, "w", encoding="utf-8") as f:
+                 json.dump(unknown_speakers, f, indent=2)
+                 f.flush()
+                 os.fsync(f.fileno())
+             os.replace(tmp, unknown_speakers_path)
+             return True
+         except Exception as e:
+             print(f"[ERROR] Failed to save unknown speakers: {e}")
+             return False
+
+     def get_next_unknown_id(self, unknown_speakers):
+         if not unknown_speakers:
+             return "unknown_1"
+         max_id = 0
+         for speaker_id in unknown_speakers.keys():
+             if speaker_id.startswith("unknown_"):
+                 try:
+                     num = int(speaker_id.split("_")[1])
+                     max_id = max(max_id, num)
+                 except (IndexError, ValueError):
+                     continue
+         return f"unknown_{max_id + 1}"
+
+     def match_speaker_embedding(self, cluster_embedding, enrolled_speakers_np, unknown_speakers, threshold=0.5):
+         cluster_embedding = cluster_embedding / norm(cluster_embedding)
+         best_name, best_score, is_enrolled = None, -1.0, False
+
+         # Log all similarities
+         sim_log = []
+
+         # Check enrolled speakers
+         for name, e_emb in enrolled_speakers_np.items():
+             sim = float(np.dot(cluster_embedding, e_emb / norm(e_emb)))
+             sim_log.append((name, sim, True))
+             if sim > best_score:
+                 best_name, best_score, is_enrolled = name, sim, True
+
+         # Check unknown speakers
+         for u_id, u_emb in unknown_speakers.items():
+             sim = float(np.dot(cluster_embedding, np.array(u_emb) / norm(u_emb)))
+             sim_log.append((u_id, sim, False))
+             if sim > best_score:
+                 best_name, best_score, is_enrolled = u_id, sim, False
+
+         # Log before creating new unknown
+         print("[MATCH LOG] Cluster embedding compared:", sim_log)
+         print(f"[MATCH LOG] Best match: {best_name}, score: {best_score}, enrolled: {is_enrolled}")
+
+         return best_name, best_score, is_enrolled
+
+
+     def run_transcription(self, audio_path, diar_json, enrolled_speakers=None, unknown_speakers_path=None):
+         unknown_speakers_path = unknown_speakers_path or os.path.join(os.path.dirname(audio_path), "unknown_speakers.json")
+
+         # Load unknown speakers safely
+         with self._unknown_lock:
+             unknown_speakers = self.load_unknown_speakers(unknown_speakers_path)
+
+         audio, sr = torchaudio.load(audio_path)
+         audio_np = audio[0].numpy() if audio.shape[0] == 1 else audio.mean(dim=0).numpy()
+         merged_segments, speaker_segments = [], {}
+         enrolled_speakers_np = {n: v / norm(v) for n, v in (enrolled_speakers or {}).items() if norm(v) > 0}
+
+         target_sr = 16000
+         clusters = {}
+         for seg in diar_json:
+             clusters.setdefault(seg["speaker"], []).append(seg)
+
+         # Compute cluster embeddings
+         cluster_embeddings = {}
+         for cluster_label, segs in clusters.items():
+             seg_embs = []
+             for seg in segs:
+                 start, end = seg["start"], seg["end"]
+                 start_sample, end_sample = int(start * sr), int(end * sr)
+                 chunk = audio_np[start_sample:end_sample]
+                 if chunk.size < 8000:
+                     chunk = np.pad(chunk, (0, 8000 - chunk.size), mode="constant")
+                 if sr != target_sr:
+                     chunk = signal.resample(chunk, int(len(chunk) * target_sr / sr)).astype(np.float32)
+                 if self.embedding_model:
+                     tensor = torch.from_numpy(chunk).unsqueeze(0).to(self.device)
+                     with torch.no_grad():
+                         emb = np.ravel(self.embedding_model.encode_batch(tensor).squeeze().cpu().numpy())
+                     if norm(emb) > 0:
+                         seg_embs.append(emb / norm(emb))
+             if seg_embs:
+                 cluster_emb = np.mean(np.stack(seg_embs), axis=0)
+                 cluster_embeddings[cluster_label] = cluster_emb / norm(cluster_emb)
+
+         speaker_map, speakers_updated = {}, False
+         threshold = 0.5
+
+         # Thread-safe unknown speaker update
+         with self._unknown_lock:
+             for cluster_label, c_emb in cluster_embeddings.items():
+                 matched_name, best_score, is_enrolled = self.match_speaker_embedding(
+                     c_emb, enrolled_speakers_np, unknown_speakers, threshold
+                 )
+
+                 if best_score >= threshold:
+                     speaker_map[cluster_label] = matched_name
+                     # Update unknown embedding if matched_name is an unknown
+                     if not is_enrolled:
+                         old_emb = np.array(unknown_speakers[matched_name])
+                         new_emb = (old_emb + c_emb) / 2.0
+                         unknown_speakers[matched_name] = (new_emb / norm(new_emb)).tolist()
+                         speakers_updated = True
+                 else:
+                     # No sufficient match found, create new unknown
+                     new_id = self.get_next_unknown_id(unknown_speakers)
+                     unknown_speakers[new_id] = c_emb.tolist()
+                     speaker_map[cluster_label] = new_id
+                     speakers_updated = True
+
+             if speakers_updated:
+                 self.save_unknown_speakers(unknown_speakers, unknown_speakers_path)
+
+         # ASR transcription per diarized segment
+         for seg in diar_json:
+             start, end, spk = seg["start"], seg["end"], seg["speaker"]
+             start_sample, end_sample = int(start * sr), int(end * sr)
+             chunk = audio_np[start_sample:end_sample]
+             if chunk.size == 0: continue
+             if sr != target_sr:
+                 chunk = signal.resample(chunk, int(len(chunk) * target_sr / sr)).astype(np.float32)
+                 sr_chunk = target_sr
+             else:
+                 sr_chunk = sr
+             try:
+                 reduced = nr.reduce_noise(chunk, sr=sr_chunk)
+             except Exception:
+                 reduced = chunk
+             try:
+                 result = self.asr_pipeline({"array": reduced, "sampling_rate": sr_chunk})
+             except Exception:
+                 with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpf:
+                     sf.write(tmpf.name, reduced, sr_chunk, subtype="PCM_16")
+                 result = self.asr_pipeline(tmpf.name)
+
+             tokens, transcript_text = [], ""
+             if isinstance(result, dict) and "chunks" in result:
+                 for w in result["chunks"]:
+                     start_ts = w.get("start") or (w.get("timestamp") and w["timestamp"][0])
+                     end_ts = w.get("end") or (w.get("timestamp") and w["timestamp"][1])
+                     word_text = w.get("text", "").strip()
+                     tokens.append({"start": start_ts, "end": end_ts, "text": word_text, "tag": "w"})
+                     transcript_text += word_text + " "
+             else:
+                 text = result.get("text") if isinstance(result, dict) else str(result)
+                 transcript_text = text or ""
+                 tokens.append({"start": None, "end": None, "text": transcript_text, "tag": "w"})
+
+             final_speaker = speaker_map.get(spk, "unknown")
+             seg_dict = {"speaker": final_speaker, "start": start, "end": end, "text": transcript_text.strip(), "tokens": tokens}
+             merged_segments.append(seg_dict)
+             speaker_segments.setdefault(final_speaker, []).append(seg_dict)
+
+         return merged_segments, list(speaker_segments.keys())
+
+     def run_pipeline(self, audio_path, output_dir=None, base_name=None,
+                      ref_rttm=None, ref_json=None, enrolled_speakers=None, unknown_speakers_path=None):
+         diar_json = self.run_diarization(audio_path)
+         merged_segments, speakers = self.run_transcription(
+             audio_path, diar_json, enrolled_speakers=enrolled_speakers,
+             unknown_speakers_path=unknown_speakers_path
+         )
+
+         if output_dir and base_name:
+             os.makedirs(output_dir, exist_ok=True)
+
+             # Save RTTM
+             rttm_path = os.path.join(output_dir, f"{base_name}.rttm")
+             with open(rttm_path, "w") as f:
+                 for seg in diar_json:
+                     f.write(
+                         f"SPEAKER {base_name} 1 {seg['start']:.6f} "
+                         f"{seg['end'] - seg['start']:.6f} <NA> <NA> "
+                         f"{seg['speaker']} <NA>\n"
+                     )
+
+             # Save transcription
+             merged_path = os.path.join(output_dir, f"{base_name}_merged_transcription.json")
+             with open(merged_path, "w") as f:
+                 json.dump(merged_segments, f, indent=2)
+
+         # Evaluation
+         eval_results = None
+         if ref_rttm or ref_json:
+             eval_results = self.evaluate(output_dir, base_name,
+                                          ref_rttm=ref_rttm, ref_json=ref_json)
+
+         return {
+             "speakers": speakers,
+             "segments": merged_segments,
+             "evaluation": eval_results
+         }
+
+     def evaluate(self, output_dir, base_name, ref_rttm=None, ref_json=None):
+         results = {}
+         hyp_rttm = os.path.join(output_dir, f"{base_name}.rttm")
+         hyp_json = os.path.join(output_dir, f"{base_name}_merged_transcription.json")
+
+         if ref_rttm:
+             def load_rttm(path):
+                 ann = Annotation()
+                 for line in open(path):
+                     if line.startswith("SPEAKER"):
+                         p = line.split()
+                         start, dur, spk = float(p[3]), float(p[4]), p[7]
+                         ann[Segment(start, start + dur)] = spk
+                 return ann
+
+             der_score = DiarizationErrorRate()(load_rttm(ref_rttm), load_rttm(hyp_rttm))
+             results["DER"] = round(der_score * 100, 2)
+
+         if ref_json:
+             def load_words(path):
+                 data = json.load(open(path))
+                 return " ".join([tok["text"] for seg in data for tok in seg["tokens"]])
+
+             ref_text, hyp_text = load_words(ref_json), load_words(hyp_json)
+             transform = Compose([ToLowerCase(), RemovePunctuation(),
+                                  RemoveMultipleSpaces(), Strip()])
+             results["WER_raw"] = round(wer(ref_text, hyp_text), 4)
+             results["WER_normalized"] = round(wer(transform(ref_text), transform(hyp_text)), 4)
+
+         return results if results else None
+
+     def __call__(self, inputs):
+         if isinstance(inputs, dict):
+             if "audio_bytes" in inputs:
+                 audio_bytes = inputs["audio_bytes"]
+             elif "audio" in inputs:
+                 audio_bytes = inputs["audio"]
+             else:
+                 raise ValueError("No audio found in inputs")
+         else:
+             audio_bytes = inputs
+
+         with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
+             tmp.write(audio_bytes)
+             tmp_path = tmp.name
+
+         try:
+             return self.run_pipeline(tmp_path, enrolled_speakers=self.known_embeddings)
+         finally:
+             os.remove(tmp_path)  # clean up the temporary wav
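For reference, a minimal sketch of calling the pipeline directly rather than through `inference.py`; `meeting.wav`, `out/`, and the reference RTTM path are placeholders, and `HF_TOKEN` must grant access to the gated pyannote/speaker-diarization-3.1 checkpoint:

```python
import os

from asr_diarization import ASR_Diarization

# Placeholder paths; run_pipeline writes out/meeting.rttm and
# out/meeting_merged_transcription.json, and computes DER when ref_rttm is given.
pipe = ASR_Diarization(os.environ["HF_TOKEN"])
result = pipe.run_pipeline(
    "meeting.wav",
    output_dir="out",
    base_name="meeting",
    ref_rttm="refs/meeting.rttm",
)
print(result["speakers"])
print(result["evaluation"])  # e.g. {"DER": ...} when a reference RTTM is supplied
```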
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ torch
+ torchaudio
+ pyannote.audio
+ transformers
+ speechbrain
+ noisereduce
+ soundfile
+ scipy
+ scikit-learn
+ jiwer
+ librosa
setup.py ADDED
@@ -0,0 +1,20 @@
+ from setuptools import setup, find_packages
+
+ setup(
+     name="asr_diarization",
+     version="0.1.0",
+     packages=find_packages(),
+     install_requires=[
+         "torch",
+         "torchaudio",
+         "pyannote.audio",
+         "transformers",
+         "speechbrain",
+         "noisereduce",
+         "soundfile",
+         "scipy",
+         "scikit-learn",
+         "jiwer",
+         "librosa"
+     ],
+ )