Upload folder using huggingface_hub
- README.md +19 -0
- asr_diarization/__init__.py +1 -0
- asr_diarization/inference.py +24 -0
- asr_diarization/pipeline.py +318 -0
- requirements.txt +8 -0
- setup.py +17 -0
README.md
ADDED
@@ -0,0 +1,19 @@
---
language: en
tags:
- asr
- diarization
pipeline_tag: automatic-speech-recognition
---
# ASR + Diarization Pipeline

This package provides an **Automatic Speech Recognition (ASR) + Speaker Diarization** pipeline using:
- [OpenAI Whisper](https://huggingface.co/openai/whisper-medium)
- [Pyannote diarization](https://huggingface.co/pyannote/speaker-diarization-3.1)

## Install
```bash
pip install git+https://huggingface.co/Capstone04/asr-diarization-pipeline
```

## Speaker Identification
You can enroll known speakers by providing reference audio samples. The pipeline matches incoming speaker segments against the stored embeddings and labels them accordingly; unknown speakers are tracked dynamically per session. A minimal usage sketch follows below.
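
For illustration, a minimal usage sketch (the `meeting.wav` path and the `HF_TOKEN` environment variable are placeholders; the pyannote diarization model is gated, so the token must belong to an account that has accepted its terms):

```python
import os
from asr_diarization import ASR_Diarization

# Placeholder token and audio path -- adapt to your setup.
pipe = ASR_Diarization(os.environ["HF_TOKEN"])

with open("meeting.wav", "rb") as f:
    result = pipe(f.read())

# result is a dict with "speakers", "segments" and "evaluation" keys.
for seg in result["segments"]:
    print(f'{seg["speaker"]} [{seg["start"]:.2f}-{seg["end"]:.2f}]: {seg["text"]}')
```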
asr_diarization/__init__.py
ADDED
@@ -0,0 +1 @@
from .pipeline import ASR_Diarization
asr_diarization/inference.py
ADDED
@@ -0,0 +1,24 @@
import os
import json

import numpy as np

from .pipeline import ASR_Diarization

def load_known_embeddings(path="known_speakers.json"):
    """Load enrolled speaker embeddings ({name: vector}) from a JSON file."""
    if not os.path.exists(path):
        return {}
    with open(path, "r") as f:
        raw = json.load(f)
    return {name: np.array(emb, dtype=np.float32) for name, emb in raw.items()}

HF_TOKEN = os.environ.get("HF_TOKEN", None)
known_embeddings = load_known_embeddings()
pipe = ASR_Diarization(HF_TOKEN, known_embeddings=known_embeddings)

def inference(inputs):
    return pipe(inputs)

def inference_with_eval(inputs, output_dir, base_name, ref_rttm=None, ref_json=None):
    # Forward the output/reference paths so run_pipeline writes the hypothesis
    # RTTM/JSON files that evaluate() then scores against the references.
    return pipe(inputs, output_dir=output_dir, base_name=base_name,
                ref_rttm=ref_rttm, ref_json=ref_json)
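
For context, `load_known_embeddings` expects `known_speakers.json` to map speaker names to embedding vectors, e.g. `{"alice": [0.01, -0.02, ...]}`. Below is a sketch of one way to produce that file with the same ECAPA encoder the pipeline uses; the `enroll` helper and the reference WAV paths are hypothetical, not part of this package:

```python
# enroll.py -- hypothetical enrollment helper (not part of the package).
import json
import torch
import torchaudio
from numpy.linalg import norm
from speechbrain.pretrained import EncoderClassifier

encoder = EncoderClassifier.from_hparams(source="speechbrain/spkrec-ecapa-voxceleb")

def enroll(refs, out_path="known_speakers.json"):
    speakers = {}
    for name, wav_path in refs.items():
        audio, sr = torchaudio.load(wav_path)
        if sr != 16000:  # ECAPA expects 16 kHz audio
            audio = torchaudio.functional.resample(audio, sr, 16000)
        with torch.no_grad():
            emb = encoder.encode_batch(audio.mean(dim=0, keepdim=True)).squeeze().numpy()
        speakers[name] = (emb / norm(emb)).tolist()  # unit-normalised, JSON-serialisable
    with open(out_path, "w") as f:
        json.dump(speakers, f)

enroll({"alice": "alice_ref.wav", "bob": "bob_ref.wav"})
```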
asr_diarization/pipeline.py
ADDED
@@ -0,0 +1,318 @@
import os
import json
import torch

# Enable TF32 matmuls on CUDA: this silences pyannote's TF32 reproducibility
# warning and speeds up inference on Ampere+ GPUs at a small precision cost.
if torch.cuda.is_available():
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

import tempfile
import threading

import numpy as np
import torchaudio
import soundfile as sf
import noisereduce as nr
from scipy import signal
from numpy.linalg import norm
from pyannote.audio import Pipeline
from pyannote.core import Annotation, Segment
from pyannote.metrics.diarization import DiarizationErrorRate
from speechbrain.pretrained import EncoderClassifier  # speechbrain.inference in SpeechBrain >= 1.0
from transformers import pipeline as hf_pipeline
from jiwer import wer, Compose, ToLowerCase, RemovePunctuation, RemoveMultipleSpaces, Strip


class ASR_Diarization:
    def __init__(self, HF_TOKEN,
                 diar_model="pyannote/speaker-diarization-3.1",
                 asr_model="openai/whisper-medium",
                 known_embeddings=None):
        self.HF_TOKEN = HF_TOKEN
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._unknown_lock = threading.Lock()
        # Embeddings of enrolled speakers ({name: vector}); used as the default
        # reference set when run_transcription receives none explicitly.
        self.known_embeddings = known_embeddings or {}

        try:
            self.embedding_model = EncoderClassifier.from_hparams(
                source="speechbrain/spkrec-ecapa-voxceleb",
                run_opts={"device": str(self.device)}
            )
            print("[ECAPA] Model loaded successfully.")
        except Exception as e:
            self.embedding_model = None
            print(f"[ERROR] Failed to load ECAPA: {e}")

        self.diar_pipeline = Pipeline.from_pretrained(diar_model, use_auth_token=HF_TOKEN)
        device_index = 0 if torch.cuda.is_available() else -1
        self.asr_pipeline = hf_pipeline(
            "automatic-speech-recognition",
            model=asr_model,
            device=device_index,
            return_timestamps=True
        )

    def run_diarization(self, audio_path):
        diarization = self.diar_pipeline(audio_path)
        return [
            {"start": t.start, "end": t.end, "speaker": spk}
            for t, _, spk in diarization.itertracks(yield_label=True)
        ]

    def load_unknown_speakers(self, unknown_speakers_path):
        if os.path.exists(unknown_speakers_path):
            try:
                with open(unknown_speakers_path, "r") as f:
                    content = f.read().strip()
                if content:
                    return json.loads(content)
            except Exception as e:
                print(f"[WARN] Failed to load unknown speakers ({e}), starting fresh")
        return {}

    def save_unknown_speakers(self, unknown_speakers, unknown_speakers_path):
        try:
            dirname = os.path.dirname(unknown_speakers_path)
            if dirname:  # os.makedirs("") raises FileNotFoundError
                os.makedirs(dirname, exist_ok=True)
            # Write to a temp file, then rename for an atomic update.
            tmp = unknown_speakers_path + ".tmp"
            with open(tmp, "w", encoding="utf-8") as f:
                json.dump(unknown_speakers, f, indent=2)
                f.flush()
                os.fsync(f.fileno())
            os.replace(tmp, unknown_speakers_path)
            return True
        except Exception as e:
            print(f"[ERROR] Failed to save unknown speakers: {e}")
            return False

    def get_next_unknown_id(self, unknown_speakers):
        if not unknown_speakers:
            return "unknown_1"
        max_id = 0
        for speaker_id in unknown_speakers.keys():
            if speaker_id.startswith("unknown_"):
                try:
                    num = int(speaker_id.split("_")[1])
                    max_id = max(max_id, num)
                except (IndexError, ValueError):
                    continue
        return f"unknown_{max_id + 1}"

    def match_speaker_embedding(self, cluster_embedding, enrolled_speakers_np, unknown_speakers, threshold=0.5):
        # Cosine similarity against every enrolled and unknown speaker;
        # the caller applies `threshold` to the best score.
        cluster_embedding = cluster_embedding / norm(cluster_embedding)
        best_name, best_score, is_enrolled = None, -1.0, False
        sim_log = []

        # Check enrolled speakers
        for name, e_emb in enrolled_speakers_np.items():
            sim = float(np.dot(cluster_embedding, e_emb / norm(e_emb)))
            sim_log.append((name, sim, True))
            if sim > best_score:
                best_name, best_score, is_enrolled = name, sim, True

        # Check unknown speakers
        for u_id, u_emb in unknown_speakers.items():
            sim = float(np.dot(cluster_embedding, np.array(u_emb) / norm(u_emb)))
            sim_log.append((u_id, sim, False))
            if sim > best_score:
                best_name, best_score, is_enrolled = u_id, sim, False

        print("[MATCH LOG] Cluster embedding compared:", sim_log)
        print(f"[MATCH LOG] Best match: {best_name}, score: {best_score}, enrolled: {is_enrolled}")

        return best_name, best_score, is_enrolled

    def run_transcription(self, audio_path, diar_json, enrolled_speakers=None, unknown_speakers_path=None):
        unknown_speakers_path = unknown_speakers_path or os.path.join(os.path.dirname(audio_path), "unknown_speakers.json")
        enrolled_speakers = enrolled_speakers or self.known_embeddings

        # Load unknown speakers safely
        with self._unknown_lock:
            unknown_speakers = self.load_unknown_speakers(unknown_speakers_path)

        audio, sr = torchaudio.load(audio_path)
        audio_np = audio[0].numpy() if audio.shape[0] == 1 else audio.mean(dim=0).numpy()
        merged_segments, speaker_segments = [], {}

        # Unit-normalise enrolled embeddings (accepts lists or arrays)
        enrolled_speakers_np = {}
        for n, v in enrolled_speakers.items():
            v = np.asarray(v, dtype=np.float32)
            if norm(v) > 0:
                enrolled_speakers_np[n] = v / norm(v)

        target_sr = 16000
        clusters = {}
        for seg in diar_json:
            clusters.setdefault(seg["speaker"], []).append(seg)

        # Compute one mean ECAPA embedding per diarization cluster
        cluster_embeddings = {}
        for cluster_label, segs in clusters.items():
            seg_embs = []
            for seg in segs:
                start, end = seg["start"], seg["end"]
                start_sample, end_sample = int(start * sr), int(end * sr)
                chunk = audio_np[start_sample:end_sample]
                if chunk.size < 8000:  # pad very short segments
                    chunk = np.pad(chunk, (0, 8000 - chunk.size), mode='constant')
                if sr != target_sr:
                    chunk = signal.resample(chunk, int(len(chunk) * target_sr / sr)).astype(np.float32)
                if self.embedding_model:
                    tensor = torch.from_numpy(chunk).unsqueeze(0).to(self.device)
                    with torch.no_grad():
                        emb = np.ravel(self.embedding_model.encode_batch(tensor).squeeze().cpu().numpy())
                    if norm(emb) > 0:
                        seg_embs.append(emb / norm(emb))
            if seg_embs:
                cluster_emb = np.mean(np.stack(seg_embs), axis=0)
                cluster_embeddings[cluster_label] = cluster_emb / norm(cluster_emb)

        speaker_map, speakers_updated = {}, False
        threshold = 0.5

        # Thread-safe unknown speaker update
        with self._unknown_lock:
            for cluster_label, c_emb in cluster_embeddings.items():
                matched_name, best_score, is_enrolled = self.match_speaker_embedding(
                    c_emb, enrolled_speakers_np, unknown_speakers, threshold
                )

                if best_score >= threshold:
                    speaker_map[cluster_label] = matched_name
                    # Refine the stored embedding when the match is an unknown
                    if not is_enrolled:
                        old_emb = np.array(unknown_speakers[matched_name])
                        new_emb = (old_emb + c_emb) / 2.0
                        unknown_speakers[matched_name] = (new_emb / norm(new_emb)).tolist()
                        speakers_updated = True
                else:
                    # No sufficient match found, create a new unknown speaker
                    new_id = self.get_next_unknown_id(unknown_speakers)
                    unknown_speakers[new_id] = c_emb.tolist()
                    speaker_map[cluster_label] = new_id
                    speakers_updated = True

            if speakers_updated:
                self.save_unknown_speakers(unknown_speakers, unknown_speakers_path)

        # ASR transcription of each diarized segment
        for seg in diar_json:
            start, end, spk = seg["start"], seg["end"], seg["speaker"]
            start_sample, end_sample = int(start * sr), int(end * sr)
            chunk = audio_np[start_sample:end_sample]
            if chunk.size == 0:
                continue
            if sr != target_sr:
                chunk = signal.resample(chunk, int(len(chunk) * target_sr / sr)).astype(np.float32)
                sr_chunk = target_sr
            else:
                sr_chunk = sr
            try:
                reduced = nr.reduce_noise(chunk, sr=sr_chunk)
            except Exception:
                reduced = chunk
            try:
                result = self.asr_pipeline({"array": reduced, "sampling_rate": sr_chunk})
            except Exception:
                # Fall back to a temporary WAV file if array input fails
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpf:
                    sf.write(tmpf.name, reduced, sr_chunk, subtype="PCM_16")
                try:
                    result = self.asr_pipeline(tmpf.name)
                finally:
                    os.remove(tmpf.name)

            tokens, transcript_text = [], ""
            if isinstance(result, dict) and "chunks" in result:
                for w in result["chunks"]:
                    start_ts = w.get("start") or (w.get("timestamp") and w["timestamp"][0])
                    end_ts = w.get("end") or (w.get("timestamp") and w["timestamp"][1])
                    word_text = w.get("text", "").strip()
                    tokens.append({"start": start_ts, "end": end_ts, "text": word_text, "tag": "w"})
                    transcript_text += word_text + " "
            else:
                text = result.get("text") if isinstance(result, dict) else str(result)
                transcript_text = text or ""
                tokens.append({"start": None, "end": None, "text": transcript_text, "tag": "w"})

            final_speaker = speaker_map.get(spk, "unknown")
            seg_dict = {"speaker": final_speaker, "start": start, "end": end, "text": transcript_text.strip(), "tokens": tokens}
            merged_segments.append(seg_dict)
            speaker_segments.setdefault(final_speaker, []).append(seg_dict)

        return merged_segments, list(speaker_segments.keys())

    def run_pipeline(self, audio_path, output_dir=None, base_name=None,
                     ref_rttm=None, ref_json=None, enrolled_speakers=None, unknown_speakers_path=None):
        diar_json = self.run_diarization(audio_path)
        merged_segments, speakers = self.run_transcription(
            audio_path, diar_json, enrolled_speakers=enrolled_speakers,
            unknown_speakers_path=unknown_speakers_path
        )

        eval_results = None
        if output_dir and base_name:
            os.makedirs(output_dir, exist_ok=True)

            # Save RTTM (fields: SPEAKER <file> <chan> <start> <dur> <NA> <NA> <label> <NA>)
            rttm_path = os.path.join(output_dir, f"{base_name}.rttm")
            with open(rttm_path, "w") as f:
                for seg in diar_json:
                    f.write(
                        f"SPEAKER {base_name} 1 {seg['start']:.6f} "
                        f"{seg['end'] - seg['start']:.6f} <NA> <NA> "
                        f"{seg['speaker']} <NA>\n"
                    )

            # Save transcription
            merged_path = os.path.join(output_dir, f"{base_name}_merged_transcription.json")
            with open(merged_path, "w") as f:
                json.dump(merged_segments, f, indent=2)

            # Evaluation needs the hypothesis files written above
            if ref_rttm or ref_json:
                eval_results = self.evaluate(output_dir, base_name,
                                             ref_rttm=ref_rttm, ref_json=ref_json)

        return {
            "speakers": speakers,
            "segments": merged_segments,
            "evaluation": eval_results
        }

    def evaluate(self, output_dir, base_name, ref_rttm=None, ref_json=None):
        results = {}
        hyp_rttm = os.path.join(output_dir, f"{base_name}.rttm")
        hyp_json = os.path.join(output_dir, f"{base_name}_merged_transcription.json")

        if ref_rttm:
            def load_rttm(path):
                ann = Annotation()
                with open(path) as f:
                    for line in f:
                        if line.startswith("SPEAKER"):
                            p = line.split()
                            start, dur, spk = float(p[3]), float(p[4]), p[7]
                            ann[Segment(start, start + dur)] = spk
                return ann

            der_score = DiarizationErrorRate()(load_rttm(ref_rttm), load_rttm(hyp_rttm))
            results["DER"] = round(der_score * 100, 2)

        if ref_json:
            def load_words(path):
                with open(path) as f:
                    data = json.load(f)
                return " ".join(tok["text"] for seg in data for tok in seg["tokens"])

            ref_text, hyp_text = load_words(ref_json), load_words(hyp_json)
            transform = Compose([ToLowerCase(), RemovePunctuation(),
                                 RemoveMultipleSpaces(), Strip()])
            results["WER_raw"] = round(wer(ref_text, hyp_text), 4)
            results["WER_normalized"] = round(wer(transform(ref_text), transform(hyp_text)), 4)

        return results if results else None

    def __call__(self, inputs, **kwargs):
        if isinstance(inputs, dict):
            if "audio_bytes" in inputs:
                audio_bytes = inputs["audio_bytes"]
            elif "audio" in inputs:
                audio_bytes = inputs["audio"]
            else:
                raise ValueError("No audio found in inputs")
        else:
            audio_bytes = inputs

        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            tmp.write(audio_bytes)
            tmp_path = tmp.name

        try:
            # Extra keyword args (output_dir, base_name, ref_rttm, ...) pass
            # straight through to run_pipeline.
            return self.run_pipeline(tmp_path, **kwargs)
        finally:
            os.remove(tmp_path)
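
The hypothesis RTTM and transcription files are only written when `output_dir` and `base_name` are supplied, and `evaluate` reads those same files back (RTTM fields 3, 4 and 7: start, duration, speaker label). A minimal evaluation sketch, with placeholder paths and ground-truth files you would supply yourself:

```python
import os
from asr_diarization import ASR_Diarization

pipe = ASR_Diarization(os.environ["HF_TOKEN"])
out = pipe.run_pipeline(
    "meeting.wav",          # placeholder input audio
    output_dir="outputs",   # writes outputs/meeting.rttm and
    base_name="meeting",    #   outputs/meeting_merged_transcription.json
    ref_rttm="ref.rttm",    # reference diarization -> DER (pyannote.metrics)
    ref_json="ref.json",    # reference transcript  -> WER (jiwer)
)
print(out["evaluation"])    # {"DER": ..., "WER_raw": ..., "WER_normalized": ...}
```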
requirements.txt
ADDED
@@ -0,0 +1,8 @@
torch
torchaudio
numpy
scipy
soundfile
speechbrain
pyannote.audio
transformers
noisereduce
scikit-learn
jiwer
librosa
setup.py
ADDED
@@ -0,0 +1,17 @@
from setuptools import setup, find_packages

setup(
    name="asr_diarization",
    version="0.1.0",
    packages=find_packages(),
    install_requires=[
        "torch",
        "torchaudio",
        "numpy",
        "scipy",
        "soundfile",
        "speechbrain",
        "pyannote.audio",
        "transformers",
        "noisereduce",
        "scikit-learn",
        "jiwer",
        "librosa",
    ],
)