Capstone04 committed
Commit 83ddb5c · verified · 1 Parent(s): 9ebf74d

Upload folder using huggingface_hub

Files changed (1): asr_diarization/pipeline.py (+89 −101)
asr_diarization/pipeline.py CHANGED
@@ -17,11 +17,11 @@ class ASR_Diarization:
                  diar_model="pyannote/speaker-diarization-3.1",
                  asr_model="Capstone04/TrainedWhisper_Medium",
                  model_path=None,
-                 use_vad=True,               # NEW: VAD after diarization
-                 vad_threshold=0.3,          # NEW: VAD speech ratio threshold
-                 min_segment_duration=0.5,   # NEW: Minimum segment duration
-                 snr_threshold=15.0,         # NEW: SNR threshold for adaptive processing
-                 min_whisper_duration=0.3):  # NEW: Minimum duration for Whisper
+                 use_vad=True,
+                 vad_threshold=0.3,
+                 min_segment_duration=0.5,
+                 snr_threshold=15.0,
+                 min_whisper_duration=0.3):
 
         self.HF_TOKEN = HF_TOKEN
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -31,26 +31,26 @@ class ASR_Diarization:
         self.snr_threshold = snr_threshold
         self.min_whisper_duration = min_whisper_duration
 
-        # Load diarization model - FIX: Add device
+        # Load diarization model
         self.diar_pipeline = Pipeline.from_pretrained(diar_model, use_auth_token=HF_TOKEN)
         self.diar_pipeline = self.diar_pipeline.to(torch.device(self.device))
 
-        # Load WebRTC VAD for post-diarization filtering - NEW
+        # Load WebRTC VAD for post-diarization filtering
         if self.use_vad:
             try:
                 import webrtcvad
-                self.vad = webrtcvad.Vad(2)  # Medium aggressiveness
-                print("WebRTC VAD loaded for post-diarization filtering")
+                self.vad = webrtcvad.Vad(2)
+                print("WebRTC VAD loaded for post-diarization filtering")
             except ImportError:
-                print("⚠️ WebRTC VAD not available")
+                print("WebRTC VAD not available")
                 self.use_vad = False
 
         # Load ASR model
         if model_path and os.path.exists(model_path):
-            print(f"🔄 Loading custom ASR model from: {model_path}")
+            print(f"Loading custom ASR model from: {model_path}")
             actual_asr_model = model_path
         else:
-            print(f"🔄 Loading default ASR model: {asr_model}")
+            print(f"Loading default ASR model: {asr_model}")
             actual_asr_model = asr_model
 
         processor = WhisperProcessor.from_pretrained(actual_asr_model, token=HF_TOKEN)
@@ -93,15 +93,15 @@ class ASR_Diarization:
             return snr
 
         except Exception as e:
-            print(f"⚠️ SNR calculation failed: {e}")
+            print(f"SNR calculation failed: {e}")
             return float('inf')
 
     def calculate_rms_energy(self, audio_chunk):
-        """NEW: Calculate RMS energy for audio chunk"""
+        """Calculate RMS energy for audio chunk"""
         return np.sqrt(np.mean(audio_chunk**2))
 
     def run_webrtc_vad_on_segment(self, audio_path, segment_start, segment_end):
-        """NEW: Run WebRTC VAD on segment to get speech ratio"""
+        """Run WebRTC VAD on segment to get speech ratio"""
         if not self.use_vad:
             return 1.0
 
@@ -138,11 +138,11 @@ class ASR_Diarization:
             return speech_frames / total_frames if total_frames > 0 else 0.0
 
         except Exception as e:
-            print(f"⚠️ WebRTC VAD failed: {e}")
+            print(f"WebRTC VAD failed: {e}")
             return 0.0
 
     def run_diarization(self, audio_path):
-        """FIXED: Run diarization with VAD AFTER approach"""
+        """Run diarization with VAD AFTER approach"""
         # Step 1: Diarization sees FULL audio first
         diarization = self.diar_pipeline(audio_path)
         diar_segments = [
@@ -176,7 +176,7 @@ class ASR_Diarization:
 
             diar_segments = filtered_segments
         else:
-            print(f"Good SNR ({snr:.1f} dB), using all diarization segments")
+            print(f"Good SNR ({snr:.1f} dB), using all diarization segments")
 
         # Step 4: Duration filtering for Whisper
         filtered_segments = [
@@ -184,11 +184,11 @@ class ASR_Diarization:
             if (seg["end"] - seg["start"]) >= self.min_whisper_duration
         ]
 
-        print(f"🎯 Final: {len(filtered_segments)} segments for Whisper")
+        print(f"Final: {len(filtered_segments)} segments for Whisper")
         return filtered_segments
 
     def map_speaker_labels(self, segments, original_speakers=['A', 'B', 'C', 'D']):
-        """NEW: Map SPEAKER_XX labels to A, B, C, D format to match original"""
+        """Map SPEAKER_XX labels to A, B, C, D format to match original"""
         unique_speakers = list(set([seg['speaker'] for seg in segments]))
         speaker_map = {}
 
@@ -205,12 +205,43 @@ class ASR_Diarization:
 
         return segments, list(speaker_map.values())
 
+    def merge_consecutive_speaker_segments(self, segments):
+        """Merge only consecutive segments from the same speaker while preserving order"""
+        if not segments:
+            return []
+
+        # Sort by start time to ensure correct order
+        segments.sort(key=lambda x: x["start"])
+
+        merged_segments = []
+
+        for seg in segments:
+            if not merged_segments:
+                # First segment
+                merged_segments.append(seg)
+            else:
+                last_seg = merged_segments[-1]
+
+                # Check if same speaker AND consecutive (small gap < 2 seconds)
+                if (seg["speaker"] == last_seg["speaker"] and
+                        (seg["start"] - last_seg["end"]) < 2.0):
+                    # Merge with previous segment
+                    last_seg["text"] += " " + seg["text"]
+                    last_seg["end"] = seg["end"]
+                else:
+                    # Different speaker or large gap - keep as separate segment
+                    merged_segments.append(seg)
+
+        print(f"🔀 Reduced {len(segments)} segments to {len(merged_segments)} while preserving order")
+        return merged_segments
+
     def run_transcription(self, audio_path, diar_json):
-        """FIXED: Transcription with proper word-level timestamp extraction"""
-        # FIX: Load and standardize audio
+        """SIMPLIFIED: Segment-level transcription without word timestamps"""
+        # Load and standardize audio
         audio, sr = torchaudio.load(audio_path)
 
-        # FIX: Resample to 16kHz for consistency
+        # Resample to 16kHz for consistency
         if sr != 16000:
             resampler = torchaudio.transforms.Resample(sr, 16000)
             audio = resampler(audio)
@@ -219,13 +250,13 @@ class ASR_Diarization:
         merged_segments = []
         speaker_segments = {}
 
-        # NEW: Calculate SNR for adaptive noise reduction
+        # Calculate SNR for adaptive noise reduction
         snr = self.calculate_snr(audio_path)
 
         for seg in diar_json:
             start, end, spk = seg["start"], seg["end"], seg["speaker"]
 
-            # NEW: Skip segments that are too short for Whisper
+            # Skip segments that are too short for Whisper
             segment_duration = end - start
             if segment_duration < self.min_whisper_duration:
                 print(f"⏩ Skipping short segment for Whisper: {start:.2f}-{end:.2f} ({segment_duration:.2f}s)")
@@ -233,16 +264,16 @@ class ASR_Diarization:
 
             start_sample, end_sample = int(start * sr), int(end * sr)
 
-            # FIX: Handle both mono and stereo audio
+            # Handle both mono and stereo audio
             if audio.shape[0] > 1:  # Stereo
                 chunk = torch.mean(audio[:, start_sample:end_sample], dim=0).numpy()
             else:  # Mono
                 chunk = audio[0, start_sample:end_sample].numpy()
 
-            # NEW: Calculate RMS energy for this segment
+            # Calculate RMS energy for this segment
             rms_energy = self.calculate_rms_energy(chunk)
 
-            # NEW: Adaptive noise reduction based on SNR + RMS
+            # Adaptive noise reduction based on SNR + RMS
             if len(chunk) > int(0.1 * sr):
                 if snr < 10 or rms_energy < 0.01:  # Very noisy or low energy
                     reduced = nr.reduce_noise(y=chunk, sr=sr, stationary=True, prop_decrease=0.8)
@@ -254,112 +285,68 @@ class ASR_Diarization:
                 reduced = chunk
 
             try:
-                # FIX: Force word-level timestamps and better configuration
+                # SIMPLIFIED: Get text without timestamps
                 result = self.asr_pipeline(
                     reduced,
-                    return_timestamps="word",  # FORCE word-level timestamps
                     generate_kwargs={
                         "task": "transcribe",
-                        "language": "en"
+                        "language": "en",
+                        "temperature": 0.0  # More accurate transcription
                     }
                 )
             except Exception as e:
                 print(f"⚠️ Whisper failed on segment {start:.2f}-{end:.2f}: {e}")
                 continue
 
-            tokens = []
-            segment_text = ""
-
-            # FIXED: Proper word-level timestamp extraction
-            if "chunks" in result:
-                for chunk_info in result["chunks"]:
-                    timestamp = chunk_info.get("timestamp")
-                    text = chunk_info.get("text", "").strip()
-
-                    if text and timestamp:
-                        chunk_start, chunk_end = timestamp
-
-                        # Validate and convert to absolute time
-                        if 0 <= chunk_start <= chunk_end <= (end - start):
-                            abs_start = start + chunk_start
-                            abs_end = start + chunk_end
-                        else:
-                            # Fallback: use segment boundaries
-                            abs_start = start
-                            abs_end = end
-
-                        # NEW: Split into individual words with distributed timestamps
-                        words = text.split()
-                        if len(words) == 1:
-                            # Single word - use original timestamp
-                            tokens.append({
-                                "start": abs_start,
-                                "end": abs_end,
-                                "text": text,
-                                "tag": "w"
-                            })
-                        else:
-                            # Multiple words - distribute time evenly
-                            word_duration = (abs_end - abs_start) / len(words)
-                            for i, word in enumerate(words):
-                                word_start = abs_start + (i * word_duration)
-                                word_end = word_start + word_duration
-                                tokens.append({
-                                    "start": word_start,
-                                    "end": word_end,
-                                    "text": word,
-                                    "tag": "w"
-                                })
-
-                        segment_text += text + " "
-
-            # NEW: Only add segment if we got content
-            if tokens or segment_text.strip():
+            # Extract just the text (no timestamp processing)
+            text = result.get("text", "").strip()
+
+            if text:
                 seg_dict = {
                     "speaker": spk,
-                    "start": start,
-                    "end": end,
-                    "tokens": tokens,
-                    "text": segment_text.strip(),  # NEW: Add full segment text
-                    "rms_energy": float(rms_energy)  # NEW: Store RMS energy
+                    "start": start,  # Keep segment boundaries
+                    "end": end,  # Keep segment boundaries
+                    "text": text,  # Just the full segment text
+                    "rms_energy": float(rms_energy)
                 }
                 merged_segments.append(seg_dict)
 
                 if spk not in speaker_segments:
                     speaker_segments[spk] = []
                 speaker_segments[spk].append(seg_dict)
-            else:
-                print(f"🔇 Empty transcription for segment {start:.2f}-{end:.2f}")
 
         return merged_segments, list(speaker_segments.keys())
 
     def run_pipeline(self, audio_path, output_dir=None, base_name=None,
-                     ref_rttm=None, ref_json=None, nse_events=None):  # NEW: nse_events parameter
-        """FIXED: Add input validation and proper RTTM format"""
-        # NEW: Validate input audio file
+                     ref_rttm=None, ref_json=None, nse_events=None):
+        """Add input validation and proper RTTM format"""
+        # Validate input audio file
         if not os.path.exists(audio_path):
             raise FileNotFoundError(f"Audio file not found: {audio_path}")
 
         try:
-            # NEW: Quick validation that it's loadable audio
+            # Quick validation that it's loadable audio
            audio, sr = torchaudio.load(audio_path)
            if audio.numel() == 0:
                raise ValueError("Audio file is empty")
        except Exception as e:
            raise ValueError(f"Invalid audio file: {e}")
 
-        print(f"🔊 Processing with VAD: {'ON' if self.use_vad else 'OFF'}")
+        print(f"Processing with VAD: {'ON' if self.use_vad else 'OFF'}")
 
        # Run diarization and transcription
        diar_json = self.run_diarization(audio_path)
        merged_segments, speakers = self.run_transcription(audio_path, diar_json)
 
-        # NEW: Map speaker labels to match original format (A, B, C, D)
+        # NEW: Merge consecutive segments by same speaker
+        merged_segments = self.merge_consecutive_speaker_segments(merged_segments)
+
+        # Map speaker labels to match original format (A, B, C, D)
        merged_segments, speakers = self.map_speaker_labels(merged_segments)
 
-        # NEW: Combine ASR segments with NSE events if provided
+        # Combine ASR segments with NSE events if provided
        if nse_events:
-            print(f"🔊 Combining {len(merged_segments)} ASR segments with {len(nse_events)} NSE events")
+            print(f"Combining {len(merged_segments)} ASR segments with {len(nse_events)} NSE events")
            all_segments = merged_segments + nse_events
            # Sort by start time for proper timeline
            all_segments.sort(key=lambda x: x["start"])
@@ -369,14 +356,14 @@ class ASR_Diarization:
 
        if output_dir and base_name:
            os.makedirs(output_dir, exist_ok=True)
 
-            # FIX: Save RTTM with standard format and precision
+            # Save RTTM with standard format and precision
            rttm_path = os.path.join(output_dir, f"{base_name}.rttm")
            with open(rttm_path, "w") as f:
                for seg in diar_json:
                    f.write(
                        f"SPEAKER {base_name} 1 {seg['start']:.3f} "
                        f"{seg['end']-seg['start']:.3f} <NA> <NA> "
-                        f"{seg['speaker']} <NA> <NA>\n"  # FIX: Standard 9 fields
+                        f"{seg['speaker']} <NA> <NA>\n"
                    )
 
            # Save transcription (with NSE events if available)
@@ -397,7 +384,7 @@ class ASR_Diarization:
        }
 
    def evaluate(self, output_dir, base_name, ref_rttm=None, ref_json=None):
-        # FIX: Add output_dir validation
+        # Add output_dir validation
        if not output_dir or not base_name:
            return None
 
@@ -421,9 +408,10 @@ class ASR_Diarization:
        if ref_json and os.path.exists(hyp_json):
            def load_words(path):
                data = json.load(open(path))
-                # NEW: Filter out NSE events for WER calculation (only use speech)
+                # Filter out NSE events for WER calculation (only use speech)
                speech_segments = [seg for seg in data if seg.get("speaker") != "NSE"]
-                return " ".join([tok["text"] for seg in speech_segments for tok in seg["tokens"]])
+                # NEW: Directly use segment text instead of tokens
+                return " ".join([seg["text"] for seg in speech_segments])
 
            ref_text, hyp_text = load_words(ref_json), load_words(hyp_json)
            transform = Compose([ToLowerCase(), RemovePunctuation(),
@@ -433,7 +421,7 @@ class ASR_Diarization:
 
        return results if results else None
 
-    def __call__(self, inputs, nse_events=None):  # NEW: nse_events parameter
+    def __call__(self, inputs, nse_events=None):
        """FIXED: Add proper temporary file cleanup"""
        if isinstance(inputs, dict):
            if "audio_bytes" in inputs:
@@ -456,6 +444,6 @@ class ASR_Diarization:
            result = self.run_pipeline(tmp_path, nse_events=nse_events)
            return result
        finally:
-            # FIX: Always clean up temporary file
+            # Always clean up temporary file
            if tmp_path and os.path.exists(tmp_path):
                os.unlink(tmp_path)
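
For orientation, here is a minimal usage sketch of the class this commit modifies. The constructor keywords and the `run_pipeline` signature are taken from the diff above; the token value, audio path, and output names are hypothetical placeholders, and the import assumes the repo layout shown in the file header.

```python
# Minimal usage sketch, assuming the class as shown in this diff.
# The HF token, audio path, and output names below are placeholders.
from asr_diarization.pipeline import ASR_Diarization

pipe = ASR_Diarization(
    HF_TOKEN="hf_...",             # hypothetical token for pyannote/Whisper access
    use_vad=True,                  # enable post-diarization WebRTC VAD filtering
    min_whisper_duration=0.3,      # segments shorter than this skip Whisper
)

# Writes <base_name>.rttm plus a transcription file into output_dir.
result = pipe.run_pipeline(
    "meeting.wav",                 # hypothetical input file
    output_dir="outputs",
    base_name="meeting",
)
```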
 
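The headline change is the new `merge_consecutive_speaker_segments` method, which collapses same-speaker turns separated by less than two seconds. Its merge rule can be checked in isolation; the sketch below copies the logic out of the diff into a standalone function (the function name and toy data are illustrative, not part of the repo):

```python
# Standalone sketch of the merge rule added in this commit
# (logic mirrors merge_consecutive_speaker_segments in the diff).
def merge_consecutive(segments, max_gap=2.0):
    segments = sorted(segments, key=lambda s: s["start"])
    merged = []
    for seg in segments:
        if merged and seg["speaker"] == merged[-1]["speaker"] \
                and (seg["start"] - merged[-1]["end"]) < max_gap:
            merged[-1]["text"] += " " + seg["text"]  # fold into previous turn
            merged[-1]["end"] = seg["end"]
        else:
            merged.append(dict(seg))                 # copy to avoid mutating input
    return merged

segments = [
    {"speaker": "SPEAKER_00", "start": 0.0, "end": 2.0, "text": "hello"},
    {"speaker": "SPEAKER_00", "start": 2.5, "end": 4.0, "text": "there"},  # 0.5 s gap -> merged
    {"speaker": "SPEAKER_01", "start": 4.2, "end": 5.0, "text": "hi"},     # speaker change -> kept
]
print(merge_consecutive(segments))
# -> two segments; the first spans 0.0-4.0 with text "hello there"
```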
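The diff only shows the tail of `run_webrtc_vad_on_segment` (the `speech_frames / total_frames` ratio), so for readers unfamiliar with webrtcvad, here is a hedged sketch of how such a speech ratio is typically computed. The helper name and framing parameters are assumptions, not code from this repo; WebRTC VAD itself accepts only 10, 20, or 30 ms frames of 16-bit mono PCM at 8, 16, 32, or 48 kHz.

```python
# Hedged sketch of a post-diarization VAD check: speech ratio over 30 ms
# frames of 16 kHz, 16-bit mono PCM, using the same Vad(2) aggressiveness
# as the constructor in this diff.
import numpy as np
import webrtcvad

def speech_ratio(samples_int16: np.ndarray, sample_rate: int = 16000,
                 frame_ms: int = 30, aggressiveness: int = 2) -> float:
    vad = webrtcvad.Vad(aggressiveness)
    frame_len = sample_rate * frame_ms // 1000            # samples per frame
    speech_frames = total_frames = 0
    for i in range(0, len(samples_int16) - frame_len + 1, frame_len):
        frame = samples_int16[i:i + frame_len].tobytes()  # 16-bit PCM bytes
        if vad.is_speech(frame, sample_rate):
            speech_frames += 1
        total_frames += 1
    return speech_frames / total_frames if total_frames > 0 else 0.0
```

A segment whose ratio falls below `vad_threshold` (0.3 by default in this diff) would then be dropped before Whisper sees it.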