import os import json import torch import tempfile import torchaudio import threading import numpy as np import soundfile as sf import noisereduce as nr from scipy import signal from numpy.linalg import norm from speechbrain.pretrained import SpeakerRecognition, EncoderClassifier from speechbrain.pretrained import SpectralMaskEnhancement from transformers import pipeline as hf_pipeline from jiwer import wer, Compose, ToLowerCase, RemovePunctuation, RemoveMultipleSpaces, Strip class ASR_Diarization: def __init__(self, HF_TOKEN, asr_model="openai/whisper-medium"): self.HF_TOKEN = HF_TOKEN self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self._unknown_lock = threading.Lock() # Load SpeechBrain models try: self.embedding_model = EncoderClassifier.from_hparams( source="speechbrain/spkrec-ecapa-voxceleb", run_opts={"device": str(self.device)} ) print("[ECAPA] Model loaded successfully.") except Exception as e: self.embedding_model = None print(f"[ERROR] Failed to load ECAPA: {e}") try: self.speaker_diarization = SpeakerRecognition.from_hparams( source="speechbrain/spkrec-ecapa-voxceleb", savedir="pretrained_models/spkrec-ecapa-voxceleb" ) print("[Speaker Recognition] Model loaded successfully.") except Exception as e: self.speaker_diarization = None print(f"[ERROR] Failed to load Speaker Recognition: {e}") # Load ASR pipeline device_index = 0 if torch.cuda.is_available() else -1 self.asr_pipeline = hf_pipeline( "automatic-speech-recognition", model=asr_model, device=device_index, return_timestamps=True ) def run_diarization(self, audio_path): """Simple diarization using SpeechBrain embedding clustering""" audio, sr = torchaudio.load(audio_path) audio_np = audio[0].numpy() if audio.shape[0] == 1 else audio.mean(dim=0).numpy() # Segment audio into chunks for diarization chunk_duration = 2.0 # 2-second chunks chunk_size = int(chunk_duration * sr) segments = [] for i in range(0, len(audio_np), chunk_size): start_time = i / sr end_time = min((i + chunk_size) / sr, len(audio_np) / sr) chunk = audio_np[i:i+chunk_size] if len(chunk) < 8000: # Skip very short chunks continue # Get speaker embedding for this chunk if self.embedding_model: try: chunk_tensor = torch.from_numpy(chunk).unsqueeze(0).to(self.device) with torch.no_grad(): embedding = self.embedding_model.encode_batch(chunk_tensor).squeeze().cpu().numpy() # Simple speaker assignment based on embedding similarity speaker_id = self._assign_speaker(embedding, segments) segments.append({ "start": start_time, "end": end_time, "speaker": speaker_id, "embedding": embedding }) except Exception as e: print(f"Error processing chunk: {e}") continue return segments def _assign_speaker(self, embedding, existing_segments, threshold=0.7): """Assign speaker based on embedding similarity""" if not existing_segments: return "speaker_1" # Calculate similarity with existing speakers similarities = [] for seg in existing_segments[-10:]: # Check last 10 segments if "embedding" in seg: sim = np.dot(embedding.flatten(), seg["embedding"].flatten()) / ( norm(embedding.flatten()) * norm(seg["embedding"].flatten()) ) similarities.append((seg["speaker"], sim)) if similarities: best_speaker, best_sim = max(similarities, key=lambda x: x[1]) if best_sim > threshold: return best_speaker # Create new speaker existing_speakers = set(seg["speaker"] for seg in existing_segments) speaker_num = 1 while f"speaker_{speaker_num}" in existing_speakers: speaker_num += 1 return f"speaker_{speaker_num}" def load_unknown_speakers(self, unknown_speakers_path): if os.path.exists(unknown_speakers_path): try: with open(unknown_speakers_path, "r") as f: content = f.read().strip() if content: return json.loads(content) except Exception as e: print(f"[WARN] Failed to load unknown speakers ({e}), starting fresh") return {} def save_unknown_speakers(self, unknown_speakers, unknown_speakers_path): try: os.makedirs(os.path.dirname(unknown_speakers_path), exist_ok=True) tmp = unknown_speakers_path + ".tmp" with open(tmp, "w", encoding="utf-8") as f: json.dump(unknown_speakers, f, indent=2) f.flush() os.fsync(f.fileno()) os.replace(tmp, unknown_speakers_path) return True except Exception as e: print(f"[ERROR] Failed to save unknown speakers: {e}") return False def get_next_unknown_id(self, unknown_speakers): if not unknown_speakers: return "unknown_1" max_id = 0 for speaker_id in unknown_speakers.keys(): if speaker_id.startswith("unknown_"): try: num = int(speaker_id.split("_")[1]) max_id = max(max_id, num) except (IndexError, ValueError): continue return f"unknown_{max_id + 1}" def match_speaker_embedding(self, cluster_embedding, enrolled_speakers_np, unknown_speakers, threshold=0.5): cluster_embedding = cluster_embedding / norm(cluster_embedding) best_name, best_score, is_enrolled = None, -1.0, False # Log all similarities sim_log = [] # Check enrolled speakers for name, e_emb in enrolled_speakers_np.items(): sim = float(np.dot(cluster_embedding, e_emb / norm(e_emb))) sim_log.append((name, sim, True)) if sim > best_score: best_name, best_score, is_enrolled = name, sim, True # Check unknown speakers for u_id, u_emb in unknown_speakers.items(): sim = float(np.dot(cluster_embedding, np.array(u_emb) / norm(u_emb))) sim_log.append((u_id, sim, False)) if sim > best_score: best_name, best_score, is_enrolled = u_id, sim, False # Log before creating new unknown print("[MATCH LOG] Cluster embedding compared:", sim_log) print(f"[MATCH LOG] Best match: {best_name}, score: {best_score}, enrolled: {is_enrolled}") return best_name, best_score, is_enrolled def run_transcription(self, audio_path, diar_json, enrolled_speakers=None, unknown_speakers_path=None): unknown_speakers_path = unknown_speakers_path or os.path.join(os.path.dirname(audio_path), "unknown_speakers.json") # Load unknown speakers safely with self._unknown_lock: unknown_speakers = self.load_unknown_speakers(unknown_speakers_path) audio, sr = torchaudio.load(audio_path) audio_np = audio[0].numpy() if audio.shape[0] == 1 else audio.mean(dim=0).numpy() merged_segments, speaker_segments = [], {} enrolled_speakers_np = {n: v/norm(v) for n,v in (enrolled_speakers or {}).items() if norm(v) > 0} target_sr = 16000 # Group segments by speaker for clustering clusters = {} for seg in diar_json: clusters.setdefault(seg["speaker"], []).append(seg) # Compute cluster embeddings cluster_embeddings = {} for cluster_label, segs in clusters.items(): seg_embs = [] for seg in segs: start, end = seg["start"], seg["end"] start_sample, end_sample = int(start*sr), int(end*sr) chunk = audio_np[start_sample:end_sample] if chunk.size < 8000: chunk = np.pad(chunk, (0, 8000 - chunk.size), mode='constant') if sr != target_sr: chunk = signal.resample(chunk, int(len(chunk)*target_sr/sr)).astype(np.float32) if self.embedding_model: tensor = torch.from_numpy(chunk).unsqueeze(0).to(self.device) with torch.no_grad(): emb = np.ravel(self.embedding_model.encode_batch(tensor).squeeze().cpu().numpy()) if norm(emb) > 0: seg_embs.append(emb / norm(emb)) if seg_embs: cluster_emb = np.mean(np.stack(seg_embs), axis=0) cluster_embeddings[cluster_label] = cluster_emb / norm(cluster_emb) speaker_map, speakers_updated = {}, {} threshold = 0.5 # Thread-safe unknown speaker update with self._unknown_lock: for cluster_label, c_emb in cluster_embeddings.items(): matched_name, best_score, is_enrolled = self.match_speaker_embedding( c_emb, enrolled_speakers_np, unknown_speakers, threshold ) if best_score >= threshold: speaker_map[cluster_label] = matched_name # Update unknown embedding if matched_name is an unknown if not is_enrolled: old_emb = np.array(unknown_speakers[matched_name]) new_emb = (old_emb + c_emb) / 2.0 unknown_speakers[matched_name] = (new_emb / norm(new_emb)).tolist() speakers_updated = True else: # No sufficient match found, create new unknown new_id = self.get_next_unknown_id(unknown_speakers) unknown_speakers[new_id] = c_emb.tolist() speaker_map[cluster_label] = new_id speakers_updated = True if speakers_updated: self.save_unknown_speakers(unknown_speakers, unknown_speakers_path) # ASR transcription for seg in diar_json: start, end, spk = seg["start"], seg["end"], seg["speaker"] start_sample, end_sample = int(start*sr), int(end*sr) chunk = audio_np[start_sample:end_sample] if chunk.size == 0: continue if sr != target_sr: chunk = signal.resample(chunk, int(len(chunk)*target_sr/sr)).astype(np.float32) sr_chunk = target_sr else: sr_chunk = sr try: reduced = nr.reduce_noise(chunk, sr=sr_chunk) except Exception: reduced = chunk try: result = self.asr_pipeline({"array": reduced, "sampling_rate": sr_chunk}) except Exception: with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpf: sf.write(tmpf.name, reduced, sr_chunk, subtype="PCM_16") result = self.asr_pipeline(tmpf.name) tokens, transcript_text = [], "" if isinstance(result, dict) and "chunks" in result: for w in result["chunks"]: start_ts = w.get("start") or (w.get("timestamp") and w["timestamp"][0]) end_ts = w.get("end") or (w.get("timestamp") and w["timestamp"][1]) word_text = w.get("text","").strip() tokens.append({"start":start_ts,"end":end_ts,"text":word_text,"tag":"w"}) transcript_text += word_text + " " else: text = result.get("text") if isinstance(result, dict) else str(result) transcript_text = text or "" tokens.append({"start":None,"end":None,"text":transcript_text,"tag":"w"}) final_speaker = speaker_map.get(spk,"unknown") seg_dict = {"speaker":final_speaker,"start":start,"end":end,"text":transcript_text.strip(),"tokens":tokens} merged_segments.append(seg_dict) speaker_segments.setdefault(final_speaker,[]).append(seg_dict) return merged_segments, list(speaker_segments.keys()) def run_pipeline(self, audio_path, output_dir=None, base_name=None, ref_rttm=None, ref_json=None, enrolled_speakers=None, unknown_speakers_path=None): diar_json = self.run_diarization(audio_path) merged_segments, speakers = self.run_transcription( audio_path, diar_json, enrolled_speakers=enrolled_speakers, unknown_speakers_path=unknown_speakers_path ) if output_dir and base_name: os.makedirs(output_dir, exist_ok=True) # Save RTTM rttm_path = os.path.join(output_dir, f"{base_name}.rttm") with open(rttm_path, "w") as f: for seg in diar_json: f.write( f"SPEAKER {base_name} 1 {seg['start']:.6f} " f"{seg['end']-seg['start']:.6f} " f"{seg['speaker']} \n" ) # Save transcription merged_path = os.path.join(output_dir, f"{base_name}_merged_transcription.json") with open(merged_path, "w") as f: json.dump(merged_segments, f, indent=2) # Evaluation eval_results = None if ref_rttm or ref_json: eval_results = self.evaluate(output_dir, base_name, ref_rttm=ref_rttm, ref_json=ref_json) return { "speakers": speakers, "segments": merged_segments, "evaluation": eval_results } def evaluate(self, output_dir, base_name, ref_rttm=None, ref_json=None): results = {} hyp_rttm = os.path.join(output_dir, f"{base_name}.rttm") hyp_json = os.path.join(output_dir, f"{base_name}_merged_transcription.json") if ref_json: def load_words(path): data = json.load(open(path)) return " ".join([tok["text"] for seg in data for tok in seg["tokens"]]) ref_text, hyp_text = load_words(ref_json), load_words(hyp_json) transform = Compose([ToLowerCase(), RemovePunctuation(), RemoveMultipleSpaces(), Strip()]) results["WER_raw"] = round(wer(ref_text, hyp_text), 4) results["WER_normalized"] = round(wer(transform(ref_text), transform(hyp_text)), 4) return results if results else None def __call__(self, inputs): if isinstance(inputs, dict): if "audio_bytes" in inputs: audio_bytes = inputs["audio_bytes"] elif "audio" in inputs: audio_bytes = inputs["audio"] else: raise ValueError("No audio found in inputs") else: audio_bytes = inputs with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp: tmp.write(audio_bytes) tmp_path = tmp.name return self.run_pipeline(tmp_path)