Capstone04 committed on
Commit
c300c20
·
verified ·
1 Parent(s): 24ed202

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. asr_diarization/pipeline.py +10 -10
asr_diarization/pipeline.py CHANGED
@@ -150,14 +150,14 @@ class ASR_Diarization:
150
  for t, _, spk in diarization.itertracks(yield_label=True)
151
  ]
152
 
153
- print(f"🎯 Diarization found {len(diar_segments)} segments")
154
 
155
  # Step 2: Calculate SNR for adaptive processing
156
  snr = self.calculate_snr(audio_path)
157
 
158
  # Step 3: Apply VAD filtering ONLY if low SNR
159
  if snr < self.snr_threshold and self.use_vad:
160
- print(f"🔇 Low SNR ({snr:.1f} dB), applying VAD filtering")
161
  filtered_segments = []
162
 
163
  for seg in diar_segments:
@@ -172,7 +172,7 @@ class ASR_Diarization:
172
  if speech_ratio >= self.vad_threshold:
173
  filtered_segments.append(seg)
174
  else:
175
- print(f"🔇 Filtered low-speech segment: {seg['start']:.2f}-{seg['end']:.2f} (speech: {speech_ratio:.1%})")
176
 
177
  diar_segments = filtered_segments
178
  else:
@@ -233,11 +233,11 @@ class ASR_Diarization:
233
  # Different speaker or large gap - keep as separate segment
234
  merged_segments.append(seg)
235
 
236
- print(f"🔀 Reduced {len(segments)} segments to {len(merged_segments)} while preserving order")
237
  return merged_segments
238
 
239
  def run_transcription(self, audio_path, diar_json):
240
- """SIMPLIFIED: Segment-level transcription without word timestamps"""
241
  # Load and standardize audio
242
  audio, sr = torchaudio.load(audio_path)
243
 
@@ -259,7 +259,7 @@ class ASR_Diarization:
259
  # Skip segments that are too short for Whisper
260
  segment_duration = end - start
261
  if segment_duration < self.min_whisper_duration:
262
- print(f"Skipping short segment for Whisper: {start:.2f}-{end:.2f} ({segment_duration:.2f}s)")
263
  continue
264
 
265
  start_sample, end_sample = int(start * sr), int(end * sr)
@@ -285,7 +285,7 @@ class ASR_Diarization:
285
  reduced = chunk
286
 
287
  try:
288
- # SIMPLIFIED: Get text without timestamps
289
  result = self.asr_pipeline(
290
  reduced,
291
  generate_kwargs={
@@ -295,7 +295,7 @@ class ASR_Diarization:
295
  }
296
  )
297
  except Exception as e:
298
- print(f"⚠️ Whisper failed on segment {start:.2f}-{end:.2f}: {e}")
299
  continue
300
 
301
  # Extract just the text (no timestamp processing)
@@ -338,7 +338,7 @@ class ASR_Diarization:
338
  diar_json = self.run_diarization(audio_path)
339
  merged_segments, speakers = self.run_transcription(audio_path, diar_json)
340
 
341
- # NEW: Merge consecutive segments by same speaker
342
  merged_segments = self.merge_consecutive_speaker_segments(merged_segments)
343
 
344
  # Map speaker labels to match original format (A, B, C, D)
@@ -411,7 +411,7 @@ class ASR_Diarization:
411
  data = json.load(open(path))
412
  # Filter out NSE events for WER calculation (only use speech)
413
  speech_segments = [seg for seg in data if seg.get("speaker") != "NSE"]
414
- # NEW: Directly use segment text instead of tokens
415
  return " ".join([seg["text"] for seg in speech_segments])
416
 
417
  def load_words_from_reference(path):
 
150
  for t, _, spk in diarization.itertracks(yield_label=True)
151
  ]
152
 
153
+ print(f"Diarization found {len(diar_segments)} segments")
154
 
155
  # Step 2: Calculate SNR for adaptive processing
156
  snr = self.calculate_snr(audio_path)
157
 
158
  # Step 3: Apply VAD filtering ONLY if low SNR
159
  if snr < self.snr_threshold and self.use_vad:
160
+ print(f"Low SNR ({snr:.1f} dB), applying VAD filtering")
161
  filtered_segments = []
162
 
163
  for seg in diar_segments:
 
172
  if speech_ratio >= self.vad_threshold:
173
  filtered_segments.append(seg)
174
  else:
175
+ print(f"Filtered low-speech segment: {seg['start']:.2f}-{seg['end']:.2f} (speech: {speech_ratio:.1%})")
176
 
177
  diar_segments = filtered_segments
178
  else:
 
233
  # Different speaker or large gap - keep as separate segment
234
  merged_segments.append(seg)
235
 
236
+ print(f"Reduced {len(segments)} segments to {len(merged_segments)} while preserving order")
237
  return merged_segments
238
 
239
  def run_transcription(self, audio_path, diar_json):
240
+ """Segment-level transcription without word timestamps"""
241
  # Load and standardize audio
242
  audio, sr = torchaudio.load(audio_path)
243
 
 
259
  # Skip segments that are too short for Whisper
260
  segment_duration = end - start
261
  if segment_duration < self.min_whisper_duration:
262
+ print(f"Skipping short segment for Whisper: {start:.2f}-{end:.2f} ({segment_duration:.2f}s)")
263
  continue
264
 
265
  start_sample, end_sample = int(start * sr), int(end * sr)
 
285
  reduced = chunk
286
 
287
  try:
288
+ # Get text without timestamps
289
  result = self.asr_pipeline(
290
  reduced,
291
  generate_kwargs={
 
295
  }
296
  )
297
  except Exception as e:
298
+ print(f"Whisper failed on segment {start:.2f}-{end:.2f}: {e}")
299
  continue
300
 
301
  # Extract just the text (no timestamp processing)
 
338
  diar_json = self.run_diarization(audio_path)
339
  merged_segments, speakers = self.run_transcription(audio_path, diar_json)
340
 
341
+ # Merge consecutive segments by same speaker
342
  merged_segments = self.merge_consecutive_speaker_segments(merged_segments)
343
 
344
  # Map speaker labels to match original format (A, B, C, D)
 
411
  data = json.load(open(path))
412
  # Filter out NSE events for WER calculation (only use speech)
413
  speech_segments = [seg for seg in data if seg.get("speaker") != "NSE"]
414
+ # Directly use segment text instead of tokens
415
  return " ".join([seg["text"] for seg in speech_segments])
416
 
417
  def load_words_from_reference(path):