import os
import json
import re
import tempfile

import numpy as np
import torch
import torchaudio
import noisereduce as nr
from pyannote.audio import Pipeline
from pyannote.core import Annotation, Segment
from pyannote.metrics.diarization import DiarizationErrorRate
from transformers import WhisperProcessor, WhisperForConditionalGeneration, pipeline as hf_pipeline
from jiwer import wer, Compose, ToLowerCase, RemovePunctuation, RemoveMultipleSpaces, Strip


class ASR_Diarization:
    def __init__(self, HF_TOKEN, diar_model="pyannote/speaker-diarization-3.1",
                 asr_model="Capstone04/TrainedWhisper_Medium", model_path=None,
                 use_vad=True, vad_threshold=0.3, min_segment_duration=0.5,
                 snr_threshold=15.0, min_whisper_duration=0.3):
        self.HF_TOKEN = HF_TOKEN
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.use_vad = use_vad
        self.vad_threshold = vad_threshold
        self.min_segment_duration = min_segment_duration
        self.snr_threshold = snr_threshold
        self.min_whisper_duration = min_whisper_duration

        # Load diarization model
        self.diar_pipeline = Pipeline.from_pretrained(diar_model, use_auth_token=HF_TOKEN)
        self.diar_pipeline = self.diar_pipeline.to(torch.device(self.device))

        # Load WebRTC VAD for post-diarization filtering
        if self.use_vad:
            try:
                import webrtcvad
                self.vad = webrtcvad.Vad(2)
                print("WebRTC VAD loaded for post-diarization filtering")
            except ImportError:
                print("WebRTC VAD not available")
                self.use_vad = False

        # Load ASR model (custom checkpoint if a valid path is given, otherwise the default)
        if model_path and os.path.exists(model_path):
            print(f"Loading custom ASR model from: {model_path}")
            actual_asr_model = model_path
        else:
            print(f"Loading default ASR model: {asr_model}")
            actual_asr_model = asr_model

        processor = WhisperProcessor.from_pretrained(actual_asr_model, token=HF_TOKEN)
        model = WhisperForConditionalGeneration.from_pretrained(actual_asr_model, token=HF_TOKEN).to(self.device)
        self.asr_pipeline = hf_pipeline(
            "automatic-speech-recognition",
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            device=0 if self.device == "cuda" else -1,
            return_timestamps=True,
        )

    def clean_transcription_text(self, text):
        """Clean ASR text for better TTS performance"""
        if not text:
            return ""

        # Basic cleaning
        text = text.strip()

        # Fix punctuation spacing for TTS
        text = re.sub(r'\s+([.,!?;:])', r'\1', text)      # Remove space before punctuation
        text = re.sub(r'([.,!?;:])(?=\w)', r'\1 ', text)  # Add space after punctuation

        # Normalize whitespace
        text = re.sub(r'\s+', ' ', text)
        return text.strip()

    def should_keep_segment(self, text, duration, rms_energy):
        """Generalized segment quality assessment"""
        # Duration too short
        if duration < self.min_whisper_duration:
            return False
        # Energy too low (likely noise)
        if rms_energy < 0.001:
            return False
        # Text too short or just punctuation
        clean_text = text.strip()
        if len(clean_text) <= 1:
            return False
        return True

    def calculate_snr(self, audio_path):
        """Estimate SNR (in dB) from the spread of RMS energy"""
        try:
            import librosa
            y, sr = librosa.load(audio_path, sr=16000, mono=True)

            # RMS-based SNR
            rms = librosa.feature.rms(y=y)[0]
            if len(rms) == 0:
                return float('inf')

            # Signal = high-RMS regions, noise = low-RMS regions
            high_rms = rms[rms > np.percentile(rms, 70)]
            low_rms = rms[rms <= np.percentile(rms, 30)]
            if len(high_rms) == 0 or len(low_rms) == 0:
                return float('inf')

            signal_power = np.mean(high_rms)
            noise_power = np.mean(low_rms)
            if noise_power == 0:
                return float('inf')

            snr = 10 * np.log10(signal_power / noise_power)
            return snr
        except Exception as e:
            print(f"SNR calculation failed: {e}")
            return float('inf')

    def calculate_rms_energy(self, audio_chunk):
        """Calculate RMS energy for an audio chunk"""
        return np.sqrt(np.mean(audio_chunk ** 2))

    def run_webrtc_vad_on_segment(self, audio_path, segment_start, segment_end):
        """Run WebRTC VAD on a segment and return its speech ratio"""
        if not self.use_vad:
            return 1.0
        try:
            import wave

            # Load audio (assumes 16-bit PCM WAV)
            with wave.open(audio_path, "rb") as wf:
                sample_rate = wf.getframerate()
                n_frames = wf.getnframes()
                audio_data = wf.readframes(n_frames)

            audio_array = np.frombuffer(audio_data, dtype=np.int16)
            start_sample = int(segment_start * sample_rate)
            end_sample = int(segment_end * sample_rate)
            segment_audio = audio_array[start_sample:end_sample]
            segment_bytes = segment_audio.tobytes()

            # WebRTC VAD processing (30 ms frames)
            frame_duration = 30
            bytes_per_sample = 2
            frame_size = int(sample_rate * frame_duration / 1000) * bytes_per_sample

            speech_frames = 0
            total_frames = 0
            for i in range(0, len(segment_bytes) - frame_size + 1, frame_size):
                frame = segment_bytes[i:i + frame_size]
                if len(frame) == frame_size:
                    is_speech = self.vad.is_speech(frame, sample_rate)
                    if is_speech:
                        speech_frames += 1
                    total_frames += 1

            return speech_frames / total_frames if total_frames > 0 else 0.0
        except Exception as e:
            print(f"WebRTC VAD failed: {e}")
            return 0.0

    def run_diarization(self, audio_path):
        """Diarize the full audio first, then apply VAD filtering only when SNR is low"""
        # Step 1: Diarization sees the FULL audio first
        diarization = self.diar_pipeline(audio_path)
        diar_segments = [
            {"start": t.start, "end": t.end, "speaker": spk}
            for t, _, spk in diarization.itertracks(yield_label=True)
        ]
        raw_speakers = list(set(seg["speaker"] for seg in diar_segments))
        print(f"Diarization detected {len(raw_speakers)} speakers: {sorted(raw_speakers)}")

        # Step 2: Calculate SNR for adaptive processing
        snr = self.calculate_snr(audio_path)

        # Step 3: Apply VAD filtering ONLY if SNR is low
        if snr < self.snr_threshold and self.use_vad:
            print(f"Low SNR ({snr:.1f} dB), applying VAD filtering")
            filtered_segments = []
            for seg in diar_segments:
                # Drop segments too short for reliable VAD
                if (seg["end"] - seg["start"]) < 0.2:
                    continue
                speech_ratio = self.run_webrtc_vad_on_segment(
                    audio_path, seg["start"], seg["end"]
                )
                if speech_ratio >= self.vad_threshold:
                    filtered_segments.append(seg)
                else:
                    print(f"Filtered low-speech segment: {seg['start']:.2f}-{seg['end']:.2f} "
                          f"(speech: {speech_ratio:.1%})")
            diar_segments = filtered_segments
        else:
            print(f"Good SNR ({snr:.1f} dB), using all diarization segments")

        # Step 4: Duration filtering for Whisper
        filtered_segments = [
            seg for seg in diar_segments
            if (seg["end"] - seg["start"]) >= self.min_whisper_duration
        ]
        print(f"Final: {len(filtered_segments)} segments for Whisper")
        return filtered_segments

    def merge_consecutive_speaker_segments(self, segments):
        """Merge only consecutive segments from the same speaker while preserving order"""
        if not segments:
            return []

        # Sort by start time to ensure correct order
        segments.sort(key=lambda x: x["start"])

        merged_segments = []
        for seg in segments:
            if not merged_segments:
                # First segment
                merged_segments.append(seg)
            else:
                last_seg = merged_segments[-1]
                # Check if same speaker AND consecutive (gap < 2 seconds)
                if (seg["speaker"] == last_seg["speaker"]
                        and (seg["start"] - last_seg["end"]) < 2.0):
                    # Merge with previous segment
                    last_seg["text"] += " " + seg["text"]
                    last_seg["end"] = seg["end"]
                else:
                    # Different speaker or large gap - keep as a separate segment
                    merged_segments.append(seg)

        print(f"Reduced {len(segments)} segments to {len(merged_segments)} while preserving order")
        return merged_segments

    def run_transcription(self, audio_path, diar_json):
        """Segment-level transcription without word timestamps"""
        # Load and standardize audio
        audio, sr = torchaudio.load(audio_path)

        # Resample to 16 kHz for consistency
        if sr != 16000:
            resampler = torchaudio.transforms.Resample(sr, 16000)
            audio = resampler(audio)
            sr = 16000

        merged_segments = []
        speaker_segments = {}

        # Calculate SNR once for adaptive noise reduction
        snr = self.calculate_snr(audio_path)

        for seg in diar_json:
            start, end, spk = seg["start"], seg["end"], seg["speaker"]

            # Skip segments that are too short for Whisper
            segment_duration = end - start
            if segment_duration < self.min_whisper_duration:
                print(f"Skipping short segment for Whisper: {start:.2f}-{end:.2f} ({segment_duration:.2f}s)")
                continue

            start_sample, end_sample = int(start * sr), int(end * sr)

            # Handle both mono and stereo audio
            if audio.shape[0] > 1:  # Stereo
                chunk = torch.mean(audio[:, start_sample:end_sample], dim=0).numpy()
            else:  # Mono
                chunk = audio[0, start_sample:end_sample].numpy()

            # Calculate RMS energy for this segment
            rms_energy = self.calculate_rms_energy(chunk)

            # Adaptive noise reduction based on SNR + RMS
            if len(chunk) > int(0.1 * sr):
                if snr < 10 or rms_energy < 0.01:  # Very noisy or low energy
                    reduced = nr.reduce_noise(y=chunk, sr=sr, stationary=True, prop_decrease=0.8)
                elif snr < 20:  # Moderately noisy
                    reduced = nr.reduce_noise(y=chunk, sr=sr, stationary=True, prop_decrease=0.5)
                else:  # Clean audio
                    reduced = chunk
            else:
                reduced = chunk

            try:
                # Get text without timestamps
                result = self.asr_pipeline(
                    reduced,
                    generate_kwargs={
                        "task": "transcribe",
                        "language": "en",
                        "temperature": 0.0,  # Deterministic decoding for more accurate transcription
                    },
                )
            except Exception as e:
                print(f"Whisper failed on segment {start:.2f}-{end:.2f}: {e}")
                continue

            # Extract just the text (no timestamp processing)
            text = result.get("text", "").strip()

            # Clean the text for TTS and apply quality filtering
            clean_text = self.clean_transcription_text(text)
            if clean_text and self.should_keep_segment(clean_text, segment_duration, rms_energy):
                seg_dict = {
                    "speaker": spk,
                    "start": start,   # Keep diarization segment boundaries
                    "end": end,       # Keep diarization segment boundaries
                    "text": clean_text,
                    "rms_energy": float(rms_energy),
                }
                merged_segments.append(seg_dict)
                if spk not in speaker_segments:
                    speaker_segments[spk] = []
                speaker_segments[spk].append(seg_dict)

        return merged_segments, list(speaker_segments.keys())

    def run_pipeline(self, audio_path, output_dir=None, base_name=None,
                     ref_rttm=None, ref_json=None, nse_events=None):
        """Validate the input, run diarization + transcription, optionally write outputs and evaluate"""
        # Validate input audio file
        if not os.path.exists(audio_path):
            raise FileNotFoundError(f"Audio file not found: {audio_path}")
        try:
            # Quick check that the file is loadable audio
            audio, sr = torchaudio.load(audio_path)
            if audio.numel() == 0:
                raise ValueError("Audio file is empty")
        except Exception as e:
            raise ValueError(f"Invalid audio file: {e}")

        print(f"Processing with VAD: {'ON' if self.use_vad else 'OFF'}")

        # Run diarization and transcription
        diar_json = self.run_diarization(audio_path)
        merged_segments, speakers = self.run_transcription(audio_path, diar_json)

        # Merge consecutive segments by the same speaker
        merged_segments = self.merge_consecutive_speaker_segments(merged_segments)

        # Combine ASR segments with NSE events if provided
        if nse_events:
            print(f"Combining {len(merged_segments)} ASR segments with {len(nse_events)} NSE events")
            all_segments = merged_segments + nse_events
            # Sort by start time for a proper timeline
            all_segments.sort(key=lambda x: x["start"])
        else:
            all_segments = merged_segments

        if output_dir and base_name:
            os.makedirs(output_dir, exist_ok=True)

            # Save RTTM in the standard 10-field format with millisecond precision
            rttm_path = os.path.join(output_dir, f"{base_name}.rttm")
            with open(rttm_path, "w") as f:
                for seg in diar_json:
                    f.write(
                        f"SPEAKER {base_name} 1 {seg['start']:.3f} "
                        f"{seg['end'] - seg['start']:.3f} "
                        f"<NA> <NA> {seg['speaker']} <NA> <NA>\n"
                    )

            # Save transcription (with NSE events if available)
            merged_path = os.path.join(output_dir, f"{base_name}_merged_transcription.json")
            with open(merged_path, "w") as f:
                json.dump(all_segments, f, indent=2)

        # Evaluation if references are provided
        eval_results = None
        if ref_rttm or ref_json:
            eval_results = self.evaluate(output_dir, base_name, ref_rttm=ref_rttm, ref_json=ref_json)

        return {
            "speakers": speakers,
            "segments": all_segments,  # Combined ASR + NSE segments
            "evaluation": eval_results,
        }

    def evaluate(self, output_dir, base_name, ref_rttm=None, ref_json=None):
        # Nothing to evaluate without written hypothesis files
        if not output_dir or not base_name:
            return None

        results = {}
        hyp_rttm = os.path.join(output_dir, f"{base_name}.rttm")
        hyp_json = os.path.join(output_dir, f"{base_name}_merged_transcription.json")

        if ref_rttm and os.path.exists(hyp_rttm):
            def load_rttm(path):
                ann = Annotation()
                with open(path) as f:
                    for line in f:
                        if line.startswith("SPEAKER"):
                            p = line.split()
                            start, dur, spk = float(p[3]), float(p[4]), p[7]
                            ann[Segment(start, start + dur)] = spk
                return ann

            der_score = DiarizationErrorRate()(load_rttm(ref_rttm), load_rttm(hyp_rttm))
            results["DER"] = round(der_score * 100, 2)

        if ref_json and os.path.exists(hyp_json):
            def load_words_from_hypothesis(path):
                """Load text from the pipeline's own output (segments carry a 'text' field)"""
                with open(path) as f:
                    data = json.load(f)
                # Filter out NSE events for WER calculation (only use speech)
                speech_segments = [seg for seg in data if seg.get("speaker") != "NSE"]
                # Use segment text directly instead of tokens
                return " ".join(seg["text"] for seg in speech_segments)

            def load_words_from_reference(path):
                """Load text from the reference file (segments carry a 'tokens' field)"""
                with open(path) as f:
                    data = json.load(f)
                # Filter out NSE events for WER calculation (only use speech)
                speech_segments = [seg for seg in data if seg.get("speaker") != "NSE"]
                # Reference format has tokens, not direct text
                return " ".join(tok["text"] for seg in speech_segments for tok in seg["tokens"])

            # Use the appropriate loader for each file
            ref_text = load_words_from_reference(ref_json)
            hyp_text = load_words_from_hypothesis(hyp_json)

            transform = Compose([ToLowerCase(), RemovePunctuation(), RemoveMultipleSpaces(), Strip()])
            results["WER_raw"] = round(wer(ref_text, hyp_text), 4)
            results["WER_normalized"] = round(wer(transform(ref_text), transform(hyp_text)), 4)

        return results if results else None

    def __call__(self, inputs, nse_events=None):
        """Run the pipeline on raw audio bytes, always cleaning up the temporary file"""
        if isinstance(inputs, dict):
            if "audio_bytes" in inputs:
                audio_bytes = inputs["audio_bytes"]
            elif "audio" in inputs:
                audio_bytes = inputs["audio"]
            else:
                raise ValueError("No audio found in inputs")
        else:
            audio_bytes = inputs

        tmp_path = None
        try:
            # Create a temporary file for processing
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                tmp.write(audio_bytes)
                tmp_path = tmp.name

            # Run the pipeline with optional NSE events
            result = self.run_pipeline(tmp_path, nse_events=nse_events)
            return result
        finally:
            # Always clean up the temporary file
            if tmp_path and os.path.exists(tmp_path):
                os.unlink(tmp_path)
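

# Example usage (illustrative sketch only): the token source, audio path, and output
# locations below are placeholders, not part of the class above; adjust them for your
# environment. Assumes a local 16-bit PCM WAV file.
if __name__ == "__main__":
    hf_token = os.environ.get("HF_TOKEN", "")  # hypothetical: token read from the environment
    asr_diar = ASR_Diarization(hf_token)

    # Run on a local WAV file and write <base_name>.rttm plus the merged transcription JSON
    result = asr_diar.run_pipeline(
        "example_audio.wav",       # placeholder path
        output_dir="outputs",      # placeholder directory
        base_name="example_audio",
    )
    print("Speakers:", result["speakers"])
    for seg in result["segments"]:
        print(f"[{seg['start']:.2f}-{seg['end']:.2f}] {seg['speaker']}: {seg['text']}")

    # Raw audio bytes can also be passed through __call__, which handles the temporary file:
    # with open("example_audio.wav", "rb") as f:
    #     result = asr_diar(f.read())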