Upload folder using huggingface_hub
Browse files
- asr_diarization/pipeline.py +10 -10
asr_diarization/pipeline.py
CHANGED
|
@@ -150,14 +150,14 @@ class ASR_Diarization:
|
|
| 150 |
for t, _, spk in diarization.itertracks(yield_label=True)
|
| 151 |
]
|
| 152 |
|
| 153 |
-
print(f"
|
| 154 |
|
| 155 |
# Step 2: Calculate SNR for adaptive processing
|
| 156 |
snr = self.calculate_snr(audio_path)
|
| 157 |
|
| 158 |
# Step 3: Apply VAD filtering ONLY if low SNR
|
| 159 |
if snr < self.snr_threshold and self.use_vad:
|
| 160 |
-
print(f"
|
| 161 |
filtered_segments = []
|
| 162 |
|
| 163 |
for seg in diar_segments:
|
|
@@ -172,7 +172,7 @@ class ASR_Diarization:
|
|
| 172 |
if speech_ratio >= self.vad_threshold:
|
| 173 |
filtered_segments.append(seg)
|
| 174 |
else:
|
| 175 |
-
print(f"
|
| 176 |
|
| 177 |
diar_segments = filtered_segments
|
| 178 |
else:
|
|
@@ -233,11 +233,11 @@ class ASR_Diarization:
|
|
| 233 |
# Different speaker or large gap - keep as separate segment
|
| 234 |
merged_segments.append(seg)
|
| 235 |
|
| 236 |
-
print(f"
|
| 237 |
return merged_segments
|
| 238 |
|
| 239 |
def run_transcription(self, audio_path, diar_json):
|
| 240 |
-
"""
|
| 241 |
# Load and standardize audio
|
| 242 |
audio, sr = torchaudio.load(audio_path)
|
| 243 |
|
|
@@ -259,7 +259,7 @@ class ASR_Diarization:
|
|
| 259 |
# Skip segments that are too short for Whisper
|
| 260 |
segment_duration = end - start
|
| 261 |
if segment_duration < self.min_whisper_duration:
|
| 262 |
-
print(f"
|
| 263 |
continue
|
| 264 |
|
| 265 |
start_sample, end_sample = int(start * sr), int(end * sr)
|
|
@@ -285,7 +285,7 @@ class ASR_Diarization:
|
|
| 285 |
reduced = chunk
|
| 286 |
|
| 287 |
try:
|
| 288 |
-
#
|
| 289 |
result = self.asr_pipeline(
|
| 290 |
reduced,
|
| 291 |
generate_kwargs={
|
|
@@ -295,7 +295,7 @@ class ASR_Diarization:
|
|
| 295 |
}
|
| 296 |
)
|
| 297 |
except Exception as e:
|
| 298 |
-
print(f"
|
| 299 |
continue
|
| 300 |
|
| 301 |
# Extract just the text (no timestamp processing)
|
|
@@ -338,7 +338,7 @@ class ASR_Diarization:
|
|
| 338 |
diar_json = self.run_diarization(audio_path)
|
| 339 |
merged_segments, speakers = self.run_transcription(audio_path, diar_json)
|
| 340 |
|
| 341 |
-
#
|
| 342 |
merged_segments = self.merge_consecutive_speaker_segments(merged_segments)
|
| 343 |
|
| 344 |
# Map speaker labels to match original format (A, B, C, D)
|
|
@@ -411,7 +411,7 @@ class ASR_Diarization:
|
|
| 411 |
data = json.load(open(path))
|
| 412 |
# Filter out NSE events for WER calculation (only use speech)
|
| 413 |
speech_segments = [seg for seg in data if seg.get("speaker") != "NSE"]
|
| 414 |
-
#
|
| 415 |
return " ".join([seg["text"] for seg in speech_segments])
|
| 416 |
|
| 417 |
def load_words_from_reference(path):
|
|
|
|
| 150 |
for t, _, spk in diarization.itertracks(yield_label=True)
|
| 151 |
]
|
| 152 |
|
| 153 |
+
print(f"Diarization found {len(diar_segments)} segments")
|
| 154 |
|
| 155 |
# Step 2: Calculate SNR for adaptive processing
|
| 156 |
snr = self.calculate_snr(audio_path)
|
| 157 |
|
| 158 |
# Step 3: Apply VAD filtering ONLY if low SNR
|
| 159 |
if snr < self.snr_threshold and self.use_vad:
|
| 160 |
+
print(f"Low SNR ({snr:.1f} dB), applying VAD filtering")
|
| 161 |
filtered_segments = []
|
| 162 |
|
| 163 |
for seg in diar_segments:
|
|
|
|
| 172 |
if speech_ratio >= self.vad_threshold:
|
| 173 |
filtered_segments.append(seg)
|
| 174 |
else:
|
| 175 |
+
print(f"Filtered low-speech segment: {seg['start']:.2f}-{seg['end']:.2f} (speech: {speech_ratio:.1%})")
|
| 176 |
|
| 177 |
diar_segments = filtered_segments
|
| 178 |
else:
|
|
|
|
| 233 |
# Different speaker or large gap - keep as separate segment
|
| 234 |
merged_segments.append(seg)
|
| 235 |
|
| 236 |
+
print(f"Reduced {len(segments)} segments to {len(merged_segments)} while preserving order")
|
| 237 |
return merged_segments
|
| 238 |
|
| 239 |
def run_transcription(self, audio_path, diar_json):
|
| 240 |
+
"""Segment-level transcription without word timestamps"""
|
| 241 |
# Load and standardize audio
|
| 242 |
audio, sr = torchaudio.load(audio_path)
|
| 243 |
|
|
|
|
| 259 |
# Skip segments that are too short for Whisper
|
| 260 |
segment_duration = end - start
|
| 261 |
if segment_duration < self.min_whisper_duration:
|
| 262 |
+
print(f"Skipping short segment for Whisper: {start:.2f}-{end:.2f} ({segment_duration:.2f}s)")
|
| 263 |
continue
|
| 264 |
|
| 265 |
start_sample, end_sample = int(start * sr), int(end * sr)
|
|
|
|
| 285 |
reduced = chunk
|
| 286 |
|
| 287 |
try:
|
| 288 |
+
# Get text without timestamps
|
| 289 |
result = self.asr_pipeline(
|
| 290 |
reduced,
|
| 291 |
generate_kwargs={
|
|
|
|
| 295 |
}
|
| 296 |
)
|
| 297 |
except Exception as e:
|
| 298 |
+
print(f"Whisper failed on segment {start:.2f}-{end:.2f}: {e}")
|
| 299 |
continue
|
| 300 |
|
| 301 |
# Extract just the text (no timestamp processing)
|
|
|
|
| 338 |
diar_json = self.run_diarization(audio_path)
|
| 339 |
merged_segments, speakers = self.run_transcription(audio_path, diar_json)
|
| 340 |
|
| 341 |
+
# Merge consecutive segments by same speaker
|
| 342 |
merged_segments = self.merge_consecutive_speaker_segments(merged_segments)
|
| 343 |
|
| 344 |
# Map speaker labels to match original format (A, B, C, D)
|
|
|
|
| 411 |
data = json.load(open(path))
|
| 412 |
# Filter out NSE events for WER calculation (only use speech)
|
| 413 |
speech_segments = [seg for seg in data if seg.get("speaker") != "NSE"]
|
| 414 |
+
# Directly use segment text instead of tokens
|
| 415 |
return " ".join([seg["text"] for seg in speech_segments])
|
| 416 |
|
| 417 |
def load_words_from_reference(path):
|