Capstone04 committed
Commit af763b6 · verified · 1 Parent(s): 67dfefe

Upload folder using huggingface_hub

Files changed (1)
  1. asr_diarization/pipeline.py +292 -44
asr_diarization/pipeline.py CHANGED
@@ -3,6 +3,7 @@ import json
 import torch
 import torchaudio
 import noisereduce as nr
+import numpy as np
 from pyannote.audio import Pipeline
 from transformers import WhisperProcessor, WhisperForConditionalGeneration, pipeline as hf_pipeline
 import tempfile
@@ -15,14 +16,36 @@ class ASR_Diarization:
     def __init__(self, HF_TOKEN,
                  diar_model="pyannote/speaker-diarization-3.1",
                  asr_model="Capstone04/TrainedWhisper_Medium",
-                 model_path=None): # NEW: model_path parameter
+                 model_path=None,
+                 use_vad=True,              # NEW: VAD after diarization
+                 vad_threshold=0.3,         # NEW: VAD speech ratio threshold
+                 min_segment_duration=0.5,  # NEW: Minimum segment duration
+                 snr_threshold=15.0,        # NEW: SNR threshold for adaptive processing
+                 min_whisper_duration=0.3): # NEW: Minimum duration for Whisper
+
         self.HF_TOKEN = HF_TOKEN
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.use_vad = use_vad
+        self.vad_threshold = vad_threshold
+        self.min_segment_duration = min_segment_duration
+        self.snr_threshold = snr_threshold
+        self.min_whisper_duration = min_whisper_duration

-        # Load diarization model
+        # Load diarization model - FIX: Add device
         self.diar_pipeline = Pipeline.from_pretrained(diar_model, use_auth_token=HF_TOKEN)
+        self.diar_pipeline = self.diar_pipeline.to(torch.device(self.device))

-        # MODIFIED: Use custom model_path if provided, otherwise use asr_model
+        # Load WebRTC VAD for post-diarization filtering - NEW
+        if self.use_vad:
+            try:
+                import webrtcvad
+                self.vad = webrtcvad.Vad(2)  # Medium aggressiveness
+                print("✅ WebRTC VAD loaded for post-diarization filtering")
+            except ImportError:
+                print("⚠️ WebRTC VAD not available")
+                self.use_vad = False
+
+        # Load ASR model
         if model_path and os.path.exists(model_path):
             print(f"🔄 Loading custom ASR model from: {model_path}")
             actual_asr_model = model_path
@@ -42,75 +65,285 @@ class ASR_Diarization:
             return_timestamps=True
         )

+    def calculate_snr(self, audio_path):
+        """NEW: Calculate SNR using RMS energy"""
+        try:
+            import librosa
+            y, sr = librosa.load(audio_path, sr=16000, mono=True)
+
+            # RMS-based SNR
+            rms = librosa.feature.rms(y=y)[0]
+            if len(rms) == 0:
+                return float('inf')
+
+            # Signal = high RMS regions, Noise = low RMS regions
+            high_rms = rms[rms > np.percentile(rms, 70)]
+            low_rms = rms[rms <= np.percentile(rms, 30)]
+
+            if len(high_rms) == 0 or len(low_rms) == 0:
+                return float('inf')
+
+            signal_power = np.mean(high_rms)
+            noise_power = np.mean(low_rms)
+
+            if noise_power == 0:
+                return float('inf')
+
+            snr = 10 * np.log10(signal_power / noise_power)
+            return snr
+
+        except Exception as e:
+            print(f"⚠️ SNR calculation failed: {e}")
+            return float('inf')
+
+    def calculate_rms_energy(self, audio_chunk):
+        """NEW: Calculate RMS energy for audio chunk"""
+        return np.sqrt(np.mean(audio_chunk**2))
+
+    def run_webrtc_vad_on_segment(self, audio_path, segment_start, segment_end):
+        """NEW: Run WebRTC VAD on segment to get speech ratio"""
+        if not self.use_vad:
+            return 1.0
+
+        try:
+            import wave
+            # Load audio
+            with wave.open(audio_path, "rb") as wf:
+                sample_rate = wf.getframerate()
+                n_frames = wf.getnframes()
+                audio_data = wf.readframes(n_frames)
+
+            audio_array = np.frombuffer(audio_data, dtype=np.int16)
+            start_sample = int(segment_start * sample_rate)
+            end_sample = int(segment_end * sample_rate)
+            segment_audio = audio_array[start_sample:end_sample]
+            segment_bytes = segment_audio.tobytes()
+
+            # WebRTC VAD processing (30ms frames)
+            frame_duration = 30
+            bytes_per_sample = 2
+            frame_size = int(sample_rate * frame_duration / 1000) * bytes_per_sample
+
+            speech_frames = 0
+            total_frames = 0
+
+            for i in range(0, len(segment_bytes) - frame_size + 1, frame_size):
+                frame = segment_bytes[i:i + frame_size]
+                if len(frame) == frame_size:
+                    is_speech = self.vad.is_speech(frame, sample_rate)
+                    if is_speech:
+                        speech_frames += 1
+                    total_frames += 1
+
+            return speech_frames / total_frames if total_frames > 0 else 0.0
+
+        except Exception as e:
+            print(f"⚠️ WebRTC VAD failed: {e}")
+            return 0.0
+
     def run_diarization(self, audio_path):
+        """FIXED: Run diarization with VAD AFTER approach"""
+        # Step 1: Diarization sees FULL audio first
         diarization = self.diar_pipeline(audio_path)
-        return [
+        diar_segments = [
             {"start": t.start, "end": t.end, "speaker": spk}
             for t, _, spk in diarization.itertracks(yield_label=True)
         ]
+
+        print(f"🎯 Diarization found {len(diar_segments)} segments")
+
+        # Step 2: Calculate SNR for adaptive processing
+        snr = self.calculate_snr(audio_path)
+
+        # Step 3: Apply VAD filtering ONLY if low SNR
+        if snr < self.snr_threshold and self.use_vad:
+            print(f"🔇 Low SNR ({snr:.1f} dB), applying VAD filtering")
+            filtered_segments = []
+
+            for seg in diar_segments:
+                # Skip VAD for very short segments
+                if (seg["end"] - seg["start"]) < 0.2:
+                    continue
+
+                speech_ratio = self.run_webrtc_vad_on_segment(
+                    audio_path, seg["start"], seg["end"]
+                )
+
+                if speech_ratio >= self.vad_threshold:
+                    filtered_segments.append(seg)
+                else:
+                    print(f"🔇 Filtered low-speech segment: {seg['start']:.2f}-{seg['end']:.2f} (speech: {speech_ratio:.1%})")
+
+            diar_segments = filtered_segments
+        else:
+            print(f"✅ Good SNR ({snr:.1f} dB), using all diarization segments")
+
+        # Step 4: Duration filtering for Whisper
+        filtered_segments = [
+            seg for seg in diar_segments
+            if (seg["end"] - seg["start"]) >= self.min_whisper_duration
+        ]
+
+        print(f"🎯 Final: {len(filtered_segments)} segments for Whisper")
+        return filtered_segments

     def run_transcription(self, audio_path, diar_json):
+        """FIXED: Transcription with proper timestamp conversion and error handling"""
+        # FIX: Load and standardize audio
         audio, sr = torchaudio.load(audio_path)
+
+        # FIX: Resample to 16kHz for consistency
+        if sr != 16000:
+            resampler = torchaudio.transforms.Resample(sr, 16000)
+            audio = resampler(audio)
+            sr = 16000
+
         merged_segments = []
         speaker_segments = {}
+
+        # NEW: Calculate SNR for adaptive noise reduction
+        snr = self.calculate_snr(audio_path)

         for seg in diar_json:
             start, end, spk = seg["start"], seg["end"], seg["speaker"]
+
+            # NEW: Skip segments that are too short for Whisper
+            segment_duration = end - start
+            if segment_duration < self.min_whisper_duration:
+                print(f"⏩ Skipping short segment for Whisper: {start:.2f}-{end:.2f} ({segment_duration:.2f}s)")
+                continue
+
             start_sample, end_sample = int(start * sr), int(end * sr)
-            chunk = audio[0, start_sample:end_sample].numpy()
+
+            # FIX: Handle both mono and stereo audio
+            if audio.shape[0] > 1:  # Stereo
+                chunk = torch.mean(audio[:, start_sample:end_sample], dim=0).numpy()
+            else:  # Mono
+                chunk = audio[0, start_sample:end_sample].numpy()
+
+            # NEW: Calculate RMS energy for this segment
+            rms_energy = self.calculate_rms_energy(chunk)
+
+            # NEW: Adaptive noise reduction based on SNR + RMS
+            if len(chunk) > int(0.1 * sr):
+                if snr < 10 or rms_energy < 0.01:  # Very noisy or low energy
+                    reduced = nr.reduce_noise(y=chunk, sr=sr, stationary=True, prop_decrease=0.8)
+                elif snr < 20:  # Moderately noisy
+                    reduced = nr.reduce_noise(y=chunk, sr=sr, stationary=True, prop_decrease=0.5)
+                else:  # Clean audio
+                    reduced = chunk
+            else:
+                reduced = chunk

-            reduced = nr.reduce_noise(y=chunk, sr=sr)
-            result = self.asr_pipeline(reduced)
+            try:
+                result = self.asr_pipeline(reduced)
+            except Exception as e:
+                print(f"⚠️ Whisper failed on segment {start:.2f}-{end:.2f}: {e}")
+                continue

             tokens = []
+            segment_text = ""
+
             if "chunks" in result:
                 for word_info in result["chunks"]:
-                    start_ts, end_ts = word_info.get("timestamp", (None, None)) or (None, None)
-                    tokens.append({
-                        "start": start_ts,
-                        "end": end_ts,
-                        "text": word_info["text"],
-                        "tag": "w"
-                    })
-
-            seg_dict = {
-                "speaker": spk,
-                "start": start,
-                "end": end,
-                "tokens": tokens
-            }
-            merged_segments.append(seg_dict)
-
-            if spk not in speaker_segments:
-                speaker_segments[spk] = []
-            speaker_segments[spk].append(seg_dict)
+                    # FIX: Convert relative timestamps to absolute
+                    timestamp = word_info.get("timestamp")
+                    text = word_info.get("text", "").strip()
+
+                    if text:
+                        if timestamp and isinstance(timestamp, (list, tuple)) and len(timestamp) == 2:
+                            rel_start, rel_end = timestamp
+                            # Validate timestamps are reasonable
+                            if 0 <= rel_start < rel_end <= (end - start):
+                                abs_start = start + rel_start  # Convert to absolute time
+                                abs_end = start + rel_end      # Convert to absolute time
+                            else:
+                                # Invalid timestamps, use segment boundaries
+                                abs_start = start
+                                abs_end = end
+                        else:
+                            # No timestamps from Whisper, use segment boundaries
+                            abs_start = start
+                            abs_end = end
+
+                        tokens.append({
+                            "start": abs_start,  # Store absolute time
+                            "end": abs_end,      # Store absolute time
+                            "text": text,
+                            "tag": "w"
+                        })
+
+                        segment_text += text + " "
+
+            # NEW: Only add segment if we got content
+            if tokens or segment_text.strip():
+                seg_dict = {
+                    "speaker": spk,
+                    "start": start,
+                    "end": end,
+                    "tokens": tokens,
+                    "text": segment_text.strip(),    # NEW: Add full segment text
+                    "rms_energy": float(rms_energy)  # NEW: Store RMS energy
+                }
+                merged_segments.append(seg_dict)
+
+                if spk not in speaker_segments:
+                    speaker_segments[spk] = []
+                speaker_segments[spk].append(seg_dict)
+            else:
+                print(f"🔇 Empty transcription for segment {start:.2f}-{end:.2f}")

         return merged_segments, list(speaker_segments.keys())

     def run_pipeline(self, audio_path, output_dir=None, base_name=None,
-                     ref_rttm=None, ref_json=None):
+                     ref_rttm=None, ref_json=None, nse_events=None):  # NEW: nse_events parameter
+        """FIXED: Add input validation and proper RTTM format"""
+        # NEW: Validate input audio file
+        if not os.path.exists(audio_path):
+            raise FileNotFoundError(f"Audio file not found: {audio_path}")
+
+        try:
+            # NEW: Quick validation that it's loadable audio
+            audio, sr = torchaudio.load(audio_path)
+            if audio.numel() == 0:
+                raise ValueError("Audio file is empty")
+        except Exception as e:
+            raise ValueError(f"Invalid audio file: {e}")
+
+        print(f"🔊 Processing with VAD: {'ON' if self.use_vad else 'OFF'}")
+
+        # Run diarization and transcription
         diar_json = self.run_diarization(audio_path)
         merged_segments, speakers = self.run_transcription(audio_path, diar_json)

+        # NEW: Combine ASR segments with NSE events if provided
+        if nse_events:
+            print(f"🔊 Combining {len(merged_segments)} ASR segments with {len(nse_events)} NSE events")
+            all_segments = merged_segments + nse_events
+            # Sort by start time for proper timeline
+            all_segments.sort(key=lambda x: x["start"])
+        else:
+            all_segments = merged_segments
+
         if output_dir and base_name:
             os.makedirs(output_dir, exist_ok=True)

-            # Save RTTM
+            # FIX: Save RTTM with standard format and precision
             rttm_path = os.path.join(output_dir, f"{base_name}.rttm")
             with open(rttm_path, "w") as f:
                 for seg in diar_json:
                     f.write(
-                        f"SPEAKER {base_name} 1 {seg['start']:.6f} "
-                        f"{seg['end']-seg['start']:.6f} <NA> <NA> "
-                        f"{seg['speaker']} <NA>\n"
+                        f"SPEAKER {base_name} 1 {seg['start']:.3f} "
+                        f"{seg['end']-seg['start']:.3f} <NA> <NA> "
+                        f"{seg['speaker']} <NA> <NA>\n"  # FIX: Standard 9 fields
                     )

-            # Save transcription
+            # Save transcription (with NSE events if available)
             merged_path = os.path.join(output_dir, f"{base_name}_merged_transcription.json")
             with open(merged_path, "w") as f:
-                json.dump(merged_segments, f, indent=2)
+                json.dump(all_segments, f, indent=2)

-        # --- evaluation if refs are provided ---
+        # Evaluation if refs are provided
         eval_results = None
         if ref_rttm or ref_json:
             eval_results = self.evaluate(output_dir, base_name,
@@ -118,17 +351,20 @@ class ASR_Diarization:

         return {
             "speakers": speakers,
-            "segments": merged_segments,
+            "segments": all_segments,  # Return combined segments
             "evaluation": eval_results
         }

     def evaluate(self, output_dir, base_name, ref_rttm=None, ref_json=None):
-        results = {}
+        # FIX: Add output_dir validation
+        if not output_dir or not base_name:
+            return None

+        results = {}
         hyp_rttm = os.path.join(output_dir, f"{base_name}.rttm")
         hyp_json = os.path.join(output_dir, f"{base_name}_merged_transcription.json")

-        if ref_rttm:
+        if ref_rttm and os.path.exists(hyp_rttm):
             def load_rttm(path):
                 ann = Annotation()
                 for line in open(path):
@@ -141,10 +377,12 @@ class ASR_Diarization:
             der_score = DiarizationErrorRate()(load_rttm(ref_rttm), load_rttm(hyp_rttm))
             results["DER"] = round(der_score * 100, 2)

-        if ref_json:
+        if ref_json and os.path.exists(hyp_json):
             def load_words(path):
                 data = json.load(open(path))
-                return " ".join([tok["text"] for seg in data for tok in seg["tokens"]])
+                # NEW: Filter out NSE events for WER calculation (only use speech)
+                speech_segments = [seg for seg in data if seg.get("speaker") != "NSE"]
+                return " ".join([tok["text"] for seg in speech_segments for tok in seg["tokens"]])

             ref_text, hyp_text = load_words(ref_json), load_words(hyp_json)
             transform = Compose([ToLowerCase(), RemovePunctuation(),
@@ -154,7 +392,8 @@ class ASR_Diarization:

         return results if results else None

-    def __call__(self, inputs):
+    def __call__(self, inputs, nse_events=None):  # NEW: nse_events parameter
+        """FIXED: Add proper temporary file cleanup"""
         if isinstance(inputs, dict):
             if "audio_bytes" in inputs:
                 audio_bytes = inputs["audio_bytes"]
@@ -165,8 +404,17 @@ class ASR_Diarization:
         else:
             audio_bytes = inputs

-        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
-            tmp.write(audio_bytes)
-            tmp_path = tmp.name
+        tmp_path = None
+        try:
+            # Create temporary file for processing
+            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
+                tmp.write(audio_bytes)
+                tmp_path = tmp.name

-        return self.run_pipeline(tmp_path)
+            # Run pipeline with NSE events
+            result = self.run_pipeline(tmp_path, nse_events=nse_events)
+            return result
+        finally:
+            # FIX: Always clean up temporary file
+            if tmp_path and os.path.exists(tmp_path):
+                os.unlink(tmp_path)
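For reviewers who want to exercise the updated interface, a minimal usage sketch follows. It is not part of this commit: the import path, the HF_TOKEN environment variable, the sample file names, and the exact shape of the NSE (non-speech event) dictionaries are assumptions inferred from how run_pipeline() sorts nse_events by "start" and how evaluate() skips the "NSE" speaker when computing WER.

# Hypothetical driver script -- not part of this commit.
# Assumptions: the module import path, the HF_TOKEN environment variable,
# the sample file names, and the NSE dictionary layout.
import os
from asr_diarization.pipeline import ASR_Diarization

asr = ASR_Diarization(
    HF_TOKEN=os.environ["HF_TOKEN"],   # pyannote/speaker-diarization-3.1 is a gated model
    model_path=None,                   # None -> fall back to the hosted asr_model checkpoint
    use_vad=True,                      # WebRTC VAD filtering is applied only when SNR < snr_threshold
    vad_threshold=0.3,
    snr_threshold=15.0,
    min_whisper_duration=0.3,
)

# Path-based entry point: writes <base_name>.rttm and <base_name>_merged_transcription.json.
result = asr.run_pipeline("meeting.wav", output_dir="outputs", base_name="meeting")

# Bytes-based entry point with optional non-speech events merged into the timeline.
with open("meeting.wav", "rb") as f:
    audio_bytes = f.read()

nse_events = [  # hypothetical NSE entries; "start" is what run_pipeline() sorts on
    {"speaker": "NSE", "start": 12.4, "end": 13.1, "text": "[door slam]", "tokens": []},
]
result = asr({"audio_bytes": audio_bytes}, nse_events=nse_events)

print(result["speakers"])
for seg in result["segments"]:
    print(f'{seg["speaker"]} [{seg["start"]:.2f}-{seg["end"]:.2f}]: {seg.get("text", "")}')

Because __call__ always unlinks the temporary .wav in its finally block, the bytes-based entry point leaves no files behind; use run_pipeline() with output_dir and base_name when the RTTM and merged JSON artifacts need to be kept for evaluate().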