import json
import os
import tempfile
import threading

import noisereduce as nr
import numpy as np
import soundfile as sf
import torch
import torchaudio
from jiwer import Compose, RemoveMultipleSpaces, RemovePunctuation, Strip, ToLowerCase, wer
from numpy.linalg import norm
from scipy import signal
from speechbrain.pretrained import EncoderClassifier, SpeakerRecognition
from transformers import pipeline as hf_pipeline


class ASR_Diarization:
    def __init__(self, HF_TOKEN, asr_model="openai/whisper-medium"):
        self.HF_TOKEN = HF_TOKEN
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._unknown_lock = threading.Lock()

        # ECAPA-TDNN encoder used to extract speaker embeddings per chunk.
        try:
            self.embedding_model = EncoderClassifier.from_hparams(
                source="speechbrain/spkrec-ecapa-voxceleb",
                run_opts={"device": str(self.device)}
            )
            print("[ECAPA] Model loaded successfully.")
        except Exception as e:
            self.embedding_model = None
            print(f"[ERROR] Failed to load ECAPA: {e}")

        # Speaker-verification interface over the same ECAPA weights.
        try:
            self.speaker_diarization = SpeakerRecognition.from_hparams(
                source="speechbrain/spkrec-ecapa-voxceleb",
                savedir="pretrained_models/spkrec-ecapa-voxceleb"
            )
            print("[Speaker Recognition] Model loaded successfully.")
        except Exception as e:
            self.speaker_diarization = None
            print(f"[ERROR] Failed to load Speaker Recognition: {e}")

        # Whisper ASR pipeline; device=0 selects the first GPU, -1 falls back to CPU.
        device_index = 0 if torch.cuda.is_available() else -1
        self.asr_pipeline = hf_pipeline(
            "automatic-speech-recognition",
            model=asr_model,
            device=device_index,
            return_timestamps=True,
        )

    def run_diarization(self, audio_path):
        """Lightweight diarization: embed fixed-size chunks with ECAPA and
        greedily assign speaker labels by cosine similarity."""
        audio, sr = torchaudio.load(audio_path)
        audio_np = audio[0].numpy() if audio.shape[0] == 1 else audio.mean(dim=0).numpy()

        # ECAPA is trained on 16 kHz audio; resample once up front so the
        # 8000-sample (0.5 s) minimum-length check below is meaningful.
        target_sr = 16000
        if sr != target_sr:
            audio_np = signal.resample(audio_np, int(len(audio_np) * target_sr / sr)).astype(np.float32)
            sr = target_sr

        chunk_duration = 2.0
        chunk_size = int(chunk_duration * sr)
        segments = []

        for i in range(0, len(audio_np), chunk_size):
            start_time = i / sr
            end_time = min((i + chunk_size) / sr, len(audio_np) / sr)
            chunk = audio_np[i:i + chunk_size]

            # Skip chunks shorter than 0.5 s; they give unreliable embeddings.
            if len(chunk) < 8000:
                continue

            if self.embedding_model:
                try:
                    chunk_tensor = torch.from_numpy(chunk).unsqueeze(0).to(self.device)
                    with torch.no_grad():
                        embedding = self.embedding_model.encode_batch(chunk_tensor).squeeze().cpu().numpy()

                    speaker_id = self._assign_speaker(embedding, segments)

                    segments.append({
                        "start": start_time,
                        "end": end_time,
                        "speaker": speaker_id,
                        "embedding": embedding,
                    })
                except Exception as e:
                    print(f"Error processing chunk: {e}")
                    continue

        return segments

    def _assign_speaker(self, embedding, existing_segments, threshold=0.7):
        """Assign a speaker label by cosine similarity to recent segments."""
        if not existing_segments:
            return "speaker_1"

        # Only compare against the last 10 segments to bound work per chunk.
        similarities = []
        for seg in existing_segments[-10:]:
            if "embedding" in seg:
                sim = np.dot(embedding.flatten(), seg["embedding"].flatten()) / (
                    norm(embedding.flatten()) * norm(seg["embedding"].flatten())
                )
                similarities.append((seg["speaker"], sim))

        if similarities:
            best_speaker, best_sim = max(similarities, key=lambda x: x[1])
            if best_sim > threshold:
                return best_speaker

        # No close match: mint the lowest unused speaker_<n> label.
        existing_speakers = set(seg["speaker"] for seg in existing_segments)
        speaker_num = 1
        while f"speaker_{speaker_num}" in existing_speakers:
            speaker_num += 1
        return f"speaker_{speaker_num}"

    def load_unknown_speakers(self, unknown_speakers_path):
        if os.path.exists(unknown_speakers_path):
            try:
                with open(unknown_speakers_path, "r") as f:
                    content = f.read().strip()
                if content:
                    return json.loads(content)
            except Exception as e:
                print(f"[WARN] Failed to load unknown speakers ({e}), starting fresh")
        return {}

    def save_unknown_speakers(self, unknown_speakers, unknown_speakers_path):
        # Atomic write: dump to a temp file, fsync, then rename over the target
        # so a crash mid-write cannot leave a truncated JSON file behind.
        try:
            parent = os.path.dirname(unknown_speakers_path)
            if parent:  # dirname is "" for bare filenames; makedirs("") would raise
                os.makedirs(parent, exist_ok=True)
            tmp = unknown_speakers_path + ".tmp"
            with open(tmp, "w", encoding="utf-8") as f:
                json.dump(unknown_speakers, f, indent=2)
                f.flush()
                os.fsync(f.fileno())
            os.replace(tmp, unknown_speakers_path)
            return True
        except Exception as e:
            print(f"[ERROR] Failed to save unknown speakers: {e}")
            return False

    def get_next_unknown_id(self, unknown_speakers):
        if not unknown_speakers:
            return "unknown_1"
        max_id = 0
        for speaker_id in unknown_speakers.keys():
            if speaker_id.startswith("unknown_"):
                try:
                    num = int(speaker_id.split("_")[1])
                    max_id = max(max_id, num)
                except (IndexError, ValueError):
                    continue
        return f"unknown_{max_id + 1}"

    def match_speaker_embedding(self, cluster_embedding, enrolled_speakers_np, unknown_speakers, threshold=0.5):
        # Note: `threshold` is applied by the caller; this method only reports
        # the best-scoring candidate.
        cluster_embedding = cluster_embedding / norm(cluster_embedding)
        best_name, best_score, is_enrolled = None, -1.0, False

        sim_log = []

        # Compare against enrolled speakers first (unit-normalized upstream).
        for name, e_emb in enrolled_speakers_np.items():
            sim = float(np.dot(cluster_embedding, e_emb / norm(e_emb)))
            sim_log.append((name, sim, True))
            if sim > best_score:
                best_name, best_score, is_enrolled = name, sim, True

        # Then against previously seen unknown speakers (stored as JSON lists).
        for u_id, u_emb in unknown_speakers.items():
            u_emb = np.asarray(u_emb)
            sim = float(np.dot(cluster_embedding, u_emb / norm(u_emb)))
            sim_log.append((u_id, sim, False))
            if sim > best_score:
                best_name, best_score, is_enrolled = u_id, sim, False

        print("[MATCH LOG] Cluster embedding compared:", sim_log)
        print(f"[MATCH LOG] Best match: {best_name}, score: {best_score}, enrolled: {is_enrolled}")

        return best_name, best_score, is_enrolled

    def run_transcription(self, audio_path, diar_json, enrolled_speakers=None, unknown_speakers_path=None):
        unknown_speakers_path = unknown_speakers_path or os.path.join(
            os.path.dirname(audio_path), "unknown_speakers.json"
        )

        # The unknown-speaker registry is shared across calls; guard reads/writes.
        with self._unknown_lock:
            unknown_speakers = self.load_unknown_speakers(unknown_speakers_path)

        audio, sr = torchaudio.load(audio_path)
        audio_np = audio[0].numpy() if audio.shape[0] == 1 else audio.mean(dim=0).numpy()
        merged_segments, speaker_segments = [], {}
        # Unit-normalize enrolled embeddings once; skip degenerate zero vectors.
        enrolled_speakers_np = {
            n: np.asarray(v) / norm(v) for n, v in (enrolled_speakers or {}).items() if norm(v) > 0
        }

        target_sr = 16000

        # Group diarization segments by their cluster label.
        clusters = {}
        for seg in diar_json:
            clusters.setdefault(seg["speaker"], []).append(seg)

        # Average the per-segment ECAPA embeddings into one embedding per cluster.
        cluster_embeddings = {}
        for cluster_label, segs in clusters.items():
            seg_embs = []
            for seg in segs:
                start, end = seg["start"], seg["end"]
                start_sample, end_sample = int(start * sr), int(end * sr)
                chunk = audio_np[start_sample:end_sample]
                if chunk.size < 8000:
                    chunk = np.pad(chunk, (0, 8000 - chunk.size), mode="constant")
                if sr != target_sr:
                    chunk = signal.resample(chunk, int(len(chunk) * target_sr / sr)).astype(np.float32)
                if self.embedding_model:
                    tensor = torch.from_numpy(chunk).unsqueeze(0).to(self.device)
                    with torch.no_grad():
                        emb = np.ravel(self.embedding_model.encode_batch(tensor).squeeze().cpu().numpy())
                    if norm(emb) > 0:
                        seg_embs.append(emb / norm(emb))
            if seg_embs:
                cluster_emb = np.mean(np.stack(seg_embs), axis=0)
                cluster_embeddings[cluster_label] = cluster_emb / norm(cluster_emb)

        # Map each cluster to an enrolled speaker, an existing unknown speaker,
        # or a freshly minted unknown ID.
        speaker_map, speakers_updated = {}, False
        threshold = 0.5

        with self._unknown_lock:
            for cluster_label, c_emb in cluster_embeddings.items():
                matched_name, best_score, is_enrolled = self.match_speaker_embedding(
                    c_emb, enrolled_speakers_np, unknown_speakers, threshold
                )

                if best_score >= threshold:
                    speaker_map[cluster_label] = matched_name

                    # Refine the stored unknown embedding with a running average.
                    if not is_enrolled:
                        old_emb = np.array(unknown_speakers[matched_name])
                        new_emb = (old_emb + c_emb) / 2.0
                        unknown_speakers[matched_name] = (new_emb / norm(new_emb)).tolist()
                        speakers_updated = True
                else:
                    # No match above threshold: register a new unknown speaker.
                    new_id = self.get_next_unknown_id(unknown_speakers)
                    unknown_speakers[new_id] = c_emb.tolist()
                    speaker_map[cluster_label] = new_id
                    speakers_updated = True

            if speakers_updated:
                self.save_unknown_speakers(unknown_speakers, unknown_speakers_path)

        # Transcribe each diarized segment individually and attach speaker labels.
        for seg in diar_json:
            start, end, spk = seg["start"], seg["end"], seg["speaker"]
            start_sample, end_sample = int(start * sr), int(end * sr)
            chunk = audio_np[start_sample:end_sample]
            if chunk.size == 0:
                continue
            if sr != target_sr:
                chunk = signal.resample(chunk, int(len(chunk) * target_sr / sr)).astype(np.float32)
                sr_chunk = target_sr
            else:
                sr_chunk = sr
            # Noise reduction is best-effort; fall back to the raw chunk on failure.
            try:
                reduced = nr.reduce_noise(chunk, sr=sr_chunk)
            except Exception:
                reduced = chunk
            # Prefer in-memory input; fall back to a temp WAV if the pipeline
            # rejects the dict format, and remove the file afterwards.
            try:
                result = self.asr_pipeline({"array": reduced, "sampling_rate": sr_chunk})
            except Exception:
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpf:
                    sf.write(tmpf.name, reduced, sr_chunk, subtype="PCM_16")
                try:
                    result = self.asr_pipeline(tmpf.name)
                finally:
                    os.remove(tmpf.name)

            tokens, transcript_text = [], ""
            if isinstance(result, dict) and "chunks" in result:
                for w in result["chunks"]:
                    # Whisper pipelines report timestamps as (start, end) tuples;
                    # check for None explicitly so a 0.0 timestamp is not dropped.
                    ts = w.get("timestamp") or (None, None)
                    start_ts = w.get("start") if w.get("start") is not None else ts[0]
                    end_ts = w.get("end") if w.get("end") is not None else ts[1]
                    word_text = w.get("text", "").strip()
                    tokens.append({"start": start_ts, "end": end_ts, "text": word_text, "tag": "w"})
                    transcript_text += word_text + " "
            else:
                text = result.get("text") if isinstance(result, dict) else str(result)
                transcript_text = text or ""
                tokens.append({"start": None, "end": None, "text": transcript_text, "tag": "w"})

            final_speaker = speaker_map.get(spk, "unknown")
            seg_dict = {"speaker": final_speaker, "start": start, "end": end,
                        "text": transcript_text.strip(), "tokens": tokens}
            merged_segments.append(seg_dict)
            speaker_segments.setdefault(final_speaker, []).append(seg_dict)

        return merged_segments, list(speaker_segments.keys())

    def run_pipeline(self, audio_path, output_dir=None, base_name=None,
                     ref_rttm=None, ref_json=None, enrolled_speakers=None, unknown_speakers_path=None):
        diar_json = self.run_diarization(audio_path)
        merged_segments, speakers = self.run_transcription(
            audio_path, diar_json, enrolled_speakers=enrolled_speakers,
            unknown_speakers_path=unknown_speakers_path
        )

        if output_dir and base_name:
            os.makedirs(output_dir, exist_ok=True)

            # Write diarization output in the standard 10-field RTTM format:
            # SPEAKER <file> <chan> <tbeg> <tdur> <ortho> <stype> <name> <conf> <slat>
            rttm_path = os.path.join(output_dir, f"{base_name}.rttm")
            with open(rttm_path, "w") as f:
                for seg in diar_json:
                    f.write(
                        f"SPEAKER {base_name} 1 {seg['start']:.6f} "
                        f"{seg['end'] - seg['start']:.6f} <NA> <NA> "
                        f"{seg['speaker']} <NA> <NA>\n"
                    )

            # Persist the speaker-attributed transcription.
            merged_path = os.path.join(output_dir, f"{base_name}_merged_transcription.json")
            with open(merged_path, "w") as f:
                json.dump(merged_segments, f, indent=2)

        eval_results = None
        if ref_rttm or ref_json:
            eval_results = self.evaluate(output_dir, base_name,
                                         ref_rttm=ref_rttm, ref_json=ref_json)

        return {
            "speakers": speakers,
            "segments": merged_segments,
            "evaluation": eval_results,
        }

    def evaluate(self, output_dir, base_name, ref_rttm=None, ref_json=None):
        results = {}
        # DER scoring against ref_rttm is not implemented yet; only WER is computed.
        hyp_rttm = os.path.join(output_dir, f"{base_name}.rttm")
        hyp_json = os.path.join(output_dir, f"{base_name}_merged_transcription.json")

        if ref_json:
            def load_words(path):
                with open(path) as f:
                    data = json.load(f)
                return " ".join(tok["text"] for seg in data for tok in seg["tokens"])

            ref_text, hyp_text = load_words(ref_json), load_words(hyp_json)
            # Normalize casing, punctuation, and whitespace before scoring.
            transform = Compose([ToLowerCase(), RemovePunctuation(),
                                 RemoveMultipleSpaces(), Strip()])
            results["WER_raw"] = round(wer(ref_text, hyp_text), 4)
            results["WER_normalized"] = round(wer(transform(ref_text), transform(hyp_text)), 4)

        return results if results else None

    def __call__(self, inputs):
        # Accept raw bytes or a dict carrying them under "audio_bytes"/"audio".
        if isinstance(inputs, dict):
            if "audio_bytes" in inputs:
                audio_bytes = inputs["audio_bytes"]
            elif "audio" in inputs:
                audio_bytes = inputs["audio"]
            else:
                raise ValueError("No audio found in inputs")
        else:
            audio_bytes = inputs

        # Spool the bytes to a temp WAV so torchaudio can load them, and
        # clean the file up once the pipeline has run.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            tmp.write(audio_bytes)
            tmp_path = tmp.name

        try:
            return self.run_pipeline(tmp_path)
        finally:
            os.remove(tmp_path)
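

# A minimal usage sketch, not part of the class above. Assumptions: the
# Hugging Face token is read from the HF_TOKEN environment variable, and
# "meeting.wav" / "outputs" / "meeting" are hypothetical placeholders to
# adapt to your setup.
if __name__ == "__main__":
    pipeline = ASR_Diarization(HF_TOKEN=os.environ.get("HF_TOKEN"))
    result = pipeline.run_pipeline(
        "meeting.wav",         # hypothetical input recording
        output_dir="outputs",  # hypothetical output directory
        base_name="meeting",   # prefix for the .rttm / _merged_transcription.json artifacts
    )
    for seg in result["segments"]:
        print(f"[{seg['start']:.1f}-{seg['end']:.1f}] {seg['speaker']}: {seg['text']}")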