"""Speaker diarization + Whisper ASR pipeline with SNR-adaptive VAD filtering and noise reduction."""

import os
import json
import re
import tempfile

import numpy as np
import noisereduce as nr
import torch
import torchaudio
from jiwer import wer, Compose, ToLowerCase, RemovePunctuation, RemoveMultipleSpaces, Strip
from pyannote.audio import Pipeline
from pyannote.core import Annotation, Segment
from pyannote.metrics.diarization import DiarizationErrorRate
from transformers import WhisperProcessor, WhisperForConditionalGeneration, pipeline as hf_pipeline

class ASR_Diarization:
    """Combines pyannote speaker diarization with Whisper ASR, plus optional WebRTC VAD filtering."""

    def __init__(self, HF_TOKEN,
                 diar_model="pyannote/speaker-diarization-3.1",
                 asr_model="Capstone04/TrainedWhisper_Medium",
                 model_path=None,
                 use_vad=True,
                 vad_threshold=0.3,
                 min_segment_duration=0.5,
                 snr_threshold=15.0,
                 min_whisper_duration=0.3):
        self.HF_TOKEN = HF_TOKEN
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.use_vad = use_vad
        self.vad_threshold = vad_threshold
        self.min_segment_duration = min_segment_duration
        self.snr_threshold = snr_threshold
        self.min_whisper_duration = min_whisper_duration

        # Speaker diarization pipeline (pyannote), moved to GPU when available
        self.diar_pipeline = Pipeline.from_pretrained(diar_model, use_auth_token=HF_TOKEN)
        self.diar_pipeline = self.diar_pipeline.to(torch.device(self.device))

        # Optional WebRTC VAD, used to filter diarization output on noisy audio
        if self.use_vad:
            try:
                import webrtcvad
                self.vad = webrtcvad.Vad(2)  # aggressiveness 2 (0 = least, 3 = most aggressive)
                print("WebRTC VAD loaded for post-diarization filtering")
            except ImportError:
                print("WebRTC VAD not available")
                self.use_vad = False

        # Prefer a local fine-tuned checkpoint when one is provided
        if model_path and os.path.exists(model_path):
            print(f"Loading custom ASR model from: {model_path}")
            actual_asr_model = model_path
        else:
            print(f"Loading default ASR model: {asr_model}")
            actual_asr_model = asr_model

        processor = WhisperProcessor.from_pretrained(actual_asr_model, token=HF_TOKEN)
        model = WhisperForConditionalGeneration.from_pretrained(actual_asr_model, token=HF_TOKEN).to(self.device)

        self.asr_pipeline = hf_pipeline(
            "automatic-speech-recognition",
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            device=0 if self.device == "cuda" else -1,
            return_timestamps=True,
        )

    def clean_transcription_text(self, text):
        """Normalize spacing and punctuation in ASR output for downstream TTS."""
        if not text:
            return ""

        text = text.strip()

        # Remove space before punctuation, then ensure one space after it
        text = re.sub(r'\s+([.,!?;:])', r'\1', text)
        text = re.sub(r'([.,!?;:])(?=\w)', r'\1 ', text)

        # Collapse runs of whitespace
        text = re.sub(r'\s+', ' ', text)

        return text.strip()

    def should_keep_segment(self, text, duration, rms_energy):
        """Quality gate: decide whether a transcribed segment is worth keeping."""
        # Too short for Whisper to transcribe reliably
        if duration < self.min_whisper_duration:
            return False

        # Effectively silent audio
        if rms_energy < 0.001:
            return False

        # Empty or single-character transcriptions are almost always noise
        clean_text = text.strip()
        if len(clean_text) <= 1:
            return False

        return True

    def calculate_snr(self, audio_path):
        """Estimate SNR in dB from the ratio of high- to low-energy RMS frames."""
        try:
            import librosa

            y, sr = librosa.load(audio_path, sr=16000, mono=True)

            rms = librosa.feature.rms(y=y)[0]
            if len(rms) == 0:
                return float('inf')

            # Treat the loudest ~30% of frames as signal and the quietest ~30% as noise
            high_rms = rms[rms > np.percentile(rms, 70)]
            low_rms = rms[rms <= np.percentile(rms, 30)]

            if len(high_rms) == 0 or len(low_rms) == 0:
                return float('inf')

            signal_power = np.mean(high_rms)
            noise_power = np.mean(low_rms)

            if noise_power == 0:
                return float('inf')

            return 10 * np.log10(signal_power / noise_power)

        except Exception as e:
            print(f"SNR calculation failed: {e}")
            return float('inf')

    def calculate_rms_energy(self, audio_chunk):
        """Calculate RMS energy for an audio chunk."""
        return np.sqrt(np.mean(audio_chunk**2))

    def run_webrtc_vad_on_segment(self, audio_path, segment_start, segment_end):
        """Return the fraction of 30 ms frames in a segment that WebRTC VAD marks as speech."""
        if not self.use_vad:
            return 1.0

        try:
            import wave

            # Assumes 16-bit mono PCM at a WebRTC-supported rate (8/16/32/48 kHz)
            with wave.open(audio_path, "rb") as wf:
                sample_rate = wf.getframerate()
                n_frames = wf.getnframes()
                audio_data = wf.readframes(n_frames)

            audio_array = np.frombuffer(audio_data, dtype=np.int16)
            start_sample = int(segment_start * sample_rate)
            end_sample = int(segment_end * sample_rate)
            segment_audio = audio_array[start_sample:end_sample]
            segment_bytes = segment_audio.tobytes()

            # WebRTC VAD consumes fixed-size frames: 30 ms of 16-bit samples
            frame_duration = 30  # ms
            bytes_per_sample = 2
            frame_size = int(sample_rate * frame_duration / 1000) * bytes_per_sample

            speech_frames = 0
            total_frames = 0

            for i in range(0, len(segment_bytes) - frame_size + 1, frame_size):
                frame = segment_bytes[i:i + frame_size]
                if len(frame) == frame_size:
                    if self.vad.is_speech(frame, sample_rate):
                        speech_frames += 1
                    total_frames += 1

            return speech_frames / total_frames if total_frames > 0 else 0.0

        except Exception as e:
            print(f"WebRTC VAD failed: {e}")
            return 0.0

    def run_diarization(self, audio_path):
        """Run diarization, then filter segments with WebRTC VAD when the audio is noisy."""
        diarization = self.diar_pipeline(audio_path)
        diar_segments = [
            {"start": t.start, "end": t.end, "speaker": spk}
            for t, _, spk in diarization.itertracks(yield_label=True)
        ]

        raw_speakers = list(set(seg["speaker"] for seg in diar_segments))
        print(f"Diarization detected {len(raw_speakers)} speakers: {sorted(raw_speakers)}")

        snr = self.calculate_snr(audio_path)

        # Apply VAD filtering only when the audio is noisy enough to need it
        if snr < self.snr_threshold and self.use_vad:
            print(f"Low SNR ({snr:.1f} dB), applying VAD filtering")
            filtered_segments = []

            for seg in diar_segments:
                # Segments under 0.2 s are too short for a meaningful speech ratio
                if (seg["end"] - seg["start"]) < 0.2:
                    continue

                speech_ratio = self.run_webrtc_vad_on_segment(
                    audio_path, seg["start"], seg["end"]
                )

                if speech_ratio >= self.vad_threshold:
                    filtered_segments.append(seg)
                else:
                    print(f"Filtered low-speech segment: {seg['start']:.2f}-{seg['end']:.2f} (speech: {speech_ratio:.1%})")

            diar_segments = filtered_segments
        else:
            print(f"Good SNR ({snr:.1f} dB), using all diarization segments")

        # Drop segments too short for Whisper regardless of the VAD decision
        filtered_segments = [
            seg for seg in diar_segments
            if (seg["end"] - seg["start"]) >= self.min_whisper_duration
        ]

        print(f"Final: {len(filtered_segments)} segments for Whisper")
        return filtered_segments

    def merge_consecutive_speaker_segments(self, segments):
        """Merge consecutive same-speaker segments (gap under 2 s) while preserving order."""
        if not segments:
            return []

        segments.sort(key=lambda x: x["start"])

        merged_segments = []

        for seg in segments:
            if not merged_segments:
                merged_segments.append(seg)
            else:
                last_seg = merged_segments[-1]

                # Merge only if the speaker matches and the gap is under 2 seconds
                if (seg["speaker"] == last_seg["speaker"] and
                        (seg["start"] - last_seg["end"]) < 2.0):
                    last_seg["text"] += " " + seg["text"]
                    last_seg["end"] = seg["end"]
                else:
                    merged_segments.append(seg)

        print(f"Reduced {len(segments)} segments to {len(merged_segments)} while preserving order")
        return merged_segments

    def run_transcription(self, audio_path, diar_json):
        """Transcribe each diarized segment with Whisper (segment-level, no word timestamps)."""
        audio, sr = torchaudio.load(audio_path)

        # Whisper expects 16 kHz input
        if sr != 16000:
            resampler = torchaudio.transforms.Resample(sr, 16000)
            audio = resampler(audio)
            sr = 16000

        merged_segments = []
        speaker_segments = {}

        # File-level SNR drives how aggressively each chunk is denoised below
        snr = self.calculate_snr(audio_path)

        for seg in diar_json:
            start, end, spk = seg["start"], seg["end"], seg["speaker"]

            segment_duration = end - start
            if segment_duration < self.min_whisper_duration:
                print(f"Skipping short segment for Whisper: {start:.2f}-{end:.2f} ({segment_duration:.2f}s)")
                continue

            start_sample, end_sample = int(start * sr), int(end * sr)

            # Downmix to mono if needed
            if audio.shape[0] > 1:
                chunk = torch.mean(audio[:, start_sample:end_sample], dim=0).numpy()
            else:
                chunk = audio[0, start_sample:end_sample].numpy()

            rms_energy = self.calculate_rms_energy(chunk)

            # Adaptive noise reduction: denoise harder on low-SNR or quiet chunks
            if len(chunk) > int(0.1 * sr):
                if snr < 10 or rms_energy < 0.01:
                    reduced = nr.reduce_noise(y=chunk, sr=sr, stationary=True, prop_decrease=0.8)
                elif snr < 20:
                    reduced = nr.reduce_noise(y=chunk, sr=sr, stationary=True, prop_decrease=0.5)
                else:
                    reduced = chunk
            else:
                reduced = chunk

            try:
                result = self.asr_pipeline(
                    reduced,
                    generate_kwargs={
                        "task": "transcribe",
                        "language": "en",
                        "temperature": 0.0
                    }
                )
            except Exception as e:
                print(f"Whisper failed on segment {start:.2f}-{end:.2f}: {e}")
                continue

            text = result.get("text", "").strip()

            clean_text = self.clean_transcription_text(text)

            if clean_text and self.should_keep_segment(clean_text, segment_duration, rms_energy):
                seg_dict = {
                    "speaker": spk,
                    "start": start,
                    "end": end,
                    "text": clean_text,
                    "rms_energy": float(rms_energy)
                }
                merged_segments.append(seg_dict)

                speaker_segments.setdefault(spk, []).append(seg_dict)

        return merged_segments, list(speaker_segments.keys())

    def run_pipeline(self, audio_path, output_dir=None, base_name=None,
                     ref_rttm=None, ref_json=None, nse_events=None):
        """Validate input, diarize, transcribe, merge, write RTTM/JSON, and optionally evaluate."""
        if not os.path.exists(audio_path):
            raise FileNotFoundError(f"Audio file not found: {audio_path}")

        try:
            audio, sr = torchaudio.load(audio_path)
            if audio.numel() == 0:
                raise ValueError("Audio file is empty")
        except Exception as e:
            raise ValueError(f"Invalid audio file: {e}")

        print(f"Processing with VAD: {'ON' if self.use_vad else 'OFF'}")

        diar_json = self.run_diarization(audio_path)
        merged_segments, speakers = self.run_transcription(audio_path, diar_json)

        merged_segments = self.merge_consecutive_speaker_segments(merged_segments)

        # Interleave non-speech events (NSE) with ASR segments in time order
        if nse_events:
            print(f"Combining {len(merged_segments)} ASR segments with {len(nse_events)} NSE events")
            all_segments = merged_segments + nse_events
            all_segments.sort(key=lambda x: x["start"])
        else:
            all_segments = merged_segments

        if output_dir and base_name:
            os.makedirs(output_dir, exist_ok=True)

            # Write the diarization hypothesis in standard 10-field RTTM format
            rttm_path = os.path.join(output_dir, f"{base_name}.rttm")
            with open(rttm_path, "w") as f:
                for seg in diar_json:
                    f.write(
                        f"SPEAKER {base_name} 1 {seg['start']:.3f} "
                        f"{seg['end']-seg['start']:.3f} <NA> <NA> "
                        f"{seg['speaker']} <NA> <NA>\n"
                    )

            merged_path = os.path.join(output_dir, f"{base_name}_merged_transcription.json")
            with open(merged_path, "w") as f:
                json.dump(all_segments, f, indent=2)

        # Optional scoring against reference RTTM/JSON
        eval_results = None
        if ref_rttm or ref_json:
            eval_results = self.evaluate(output_dir, base_name,
                                         ref_rttm=ref_rttm, ref_json=ref_json)

        return {
            "speakers": speakers,
            "segments": all_segments,
            "evaluation": eval_results
        }

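    # One hypothesis RTTM line is written above per diarized segment, e.g. (with a
    # hypothetical base_name of "meeting"):
    #   SPEAKER meeting 1 0.031 2.406 <NA> <NA> SPEAKER_00 <NA> <NA>
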
    def evaluate(self, output_dir, base_name, ref_rttm=None, ref_json=None):
        """Score the hypothesis files against references: DER for diarization, WER for ASR."""
        if not output_dir or not base_name:
            return None

        results = {}
        hyp_rttm = os.path.join(output_dir, f"{base_name}.rttm")
        hyp_json = os.path.join(output_dir, f"{base_name}_merged_transcription.json")

        if ref_rttm and os.path.exists(hyp_rttm):
            def load_rttm(path):
                ann = Annotation()
                with open(path) as f:
                    for line in f:
                        if line.startswith("SPEAKER"):
                            p = line.split()
                            start, dur, spk = float(p[3]), float(p[4]), p[7]
                            ann[Segment(start, start + dur)] = spk
                return ann

            der_score = DiarizationErrorRate()(load_rttm(ref_rttm), load_rttm(hyp_rttm))
            results["DER"] = round(der_score * 100, 2)

        if ref_json and os.path.exists(hyp_json):
            def load_words_from_hypothesis(path):
                """Hypothesis segments carry a flat 'text' field."""
                with open(path) as f:
                    data = json.load(f)
                # Exclude non-speech events from WER scoring
                speech_segments = [seg for seg in data if seg.get("speaker") != "NSE"]
                return " ".join(seg["text"] for seg in speech_segments)

            def load_words_from_reference(path):
                """Reference segments carry per-word 'tokens' lists."""
                with open(path) as f:
                    data = json.load(f)
                speech_segments = [seg for seg in data if seg.get("speaker") != "NSE"]
                return " ".join(tok["text"] for seg in speech_segments for tok in seg["tokens"])

            ref_text = load_words_from_reference(ref_json)
            hyp_text = load_words_from_hypothesis(hyp_json)

            transform = Compose([ToLowerCase(), RemovePunctuation(),
                                 RemoveMultipleSpaces(), Strip()])
            results["WER_raw"] = round(wer(ref_text, hyp_text), 4)
            results["WER_normalized"] = round(wer(transform(ref_text), transform(hyp_text)), 4)

        return results if results else None

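    # Sketch of the two JSON shapes evaluate() consumes (field names inferred from
    # the loaders above; the exact reference schema is an assumption):
    #   hypothesis: [{"speaker": "SPEAKER_00", "start": 0.0, "end": 2.1,
    #                 "text": "hello there", "rms_energy": 0.04}, ...]
    #   reference:  [{"speaker": "SPK1",
    #                 "tokens": [{"text": "hello"}, {"text": "there"}]}, ...]
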
    def __call__(self, inputs, nse_events=None):
        """Accept raw audio bytes (or a dict containing them) and run the full pipeline."""
        if isinstance(inputs, dict):
            if "audio_bytes" in inputs:
                audio_bytes = inputs["audio_bytes"]
            elif "audio" in inputs:
                audio_bytes = inputs["audio"]
            else:
                raise ValueError("No audio found in inputs")
        else:
            audio_bytes = inputs

        tmp_path = None
        try:
            # Write the bytes to a temporary WAV file for the file-based pipeline
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                tmp.write(audio_bytes)
                tmp_path = tmp.name

            return self.run_pipeline(tmp_path, nse_events=nse_events)
        finally:
            # Always remove the temporary file, even on failure
            if tmp_path and os.path.exists(tmp_path):
                os.unlink(tmp_path)
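

# Minimal usage sketch. The input path, output names, and the HF_TOKEN environment
# variable are assumptions for illustration, not part of the class itself;
# run_pipeline() writes <base_name>.rttm and <base_name>_merged_transcription.json
# under output_dir.
if __name__ == "__main__":
    pipeline = ASR_Diarization(HF_TOKEN=os.environ["HF_TOKEN"])
    result = pipeline.run_pipeline(
        "meeting.wav",          # hypothetical input; 16 kHz 16-bit mono WAV suits the VAD path
        output_dir="outputs",
        base_name="meeting",
    )
    print(f"Speakers: {result['speakers']}")
    for seg in result["segments"]:
        print(f"[{seg['start']:7.2f}-{seg['end']:7.2f}] {seg['speaker']}: {seg['text']}")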