Capstone04 commited on
Commit
9ebf74d
·
verified ·
1 Parent(s): 1a15c26

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. asr_diarization/pipeline.py +66 -25
  2. requirements.txt +2 -0
asr_diarization/pipeline.py CHANGED
@@ -187,8 +187,26 @@ class ASR_Diarization:
187
  print(f"🎯 Final: {len(filtered_segments)} segments for Whisper")
188
  return filtered_segments
189
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
  def run_transcription(self, audio_path, diar_json):
191
- """FIXED: Transcription with proper timestamp conversion and error handling"""
192
  # FIX: Load and standardize audio
193
  audio, sr = torchaudio.load(audio_path)
194
 
@@ -236,42 +254,62 @@ class ASR_Diarization:
236
  reduced = chunk
237
 
238
  try:
239
- result = self.asr_pipeline(reduced)
 
 
 
 
 
 
 
 
240
  except Exception as e:
241
  print(f"⚠️ Whisper failed on segment {start:.2f}-{end:.2f}: {e}")
242
  continue
243
 
244
  tokens = []
245
  segment_text = ""
246
-
 
247
  if "chunks" in result:
248
- for word_info in result["chunks"]:
249
- # FIX: Convert relative timestamps to absolute
250
- timestamp = word_info.get("timestamp")
251
- text = word_info.get("text", "").strip()
252
 
253
- if text:
254
- if timestamp and isinstance(timestamp, (list, tuple)) and len(timestamp) == 2:
255
- rel_start, rel_end = timestamp
256
- # Validate timestamps are reasonable
257
- if 0 <= rel_start < rel_end <= (end - start):
258
- abs_start = start + rel_start # Convert to absolute time
259
- abs_end = start + rel_end # Convert to absolute time
260
- else:
261
- # Invalid timestamps, use segment boundaries
262
- abs_start = start
263
- abs_end = end
264
  else:
265
- # No timestamps from Whisper, use segment boundaries
266
  abs_start = start
267
  abs_end = end
268
 
269
- tokens.append({
270
- "start": abs_start, # Store absolute time
271
- "end": abs_end, # Store absolute time
272
- "text": text,
273
- "tag": "w"
274
- })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
275
 
276
  segment_text += text + " "
277
 
@@ -316,6 +354,9 @@ class ASR_Diarization:
316
  diar_json = self.run_diarization(audio_path)
317
  merged_segments, speakers = self.run_transcription(audio_path, diar_json)
318
 
 
 
 
319
  # NEW: Combine ASR segments with NSE events if provided
320
  if nse_events:
321
  print(f"🔊 Combining {len(merged_segments)} ASR segments with {len(nse_events)} NSE events")
 
187
  print(f"🎯 Final: {len(filtered_segments)} segments for Whisper")
188
  return filtered_segments
189
 
190
+ def map_speaker_labels(self, segments, original_speakers=['A', 'B', 'C', 'D']):
191
+ """NEW: Map SPEAKER_XX labels to A, B, C, D format to match original"""
192
+ unique_speakers = list(set([seg['speaker'] for seg in segments]))
193
+ speaker_map = {}
194
+
195
+ # Create mapping from SPEAKER_00 -> A, SPEAKER_01 -> B, etc.
196
+ for i, spk in enumerate(sorted(unique_speakers)):
197
+ if i < len(original_speakers):
198
+ speaker_map[spk] = original_speakers[i]
199
+ else:
200
+ speaker_map[spk] = f"SPK_{i}"
201
+
202
+ # Apply mapping to all segments
203
+ for seg in segments:
204
+ seg['speaker'] = speaker_map[seg['speaker']]
205
+
206
+ return segments, list(speaker_map.values())
207
+
208
  def run_transcription(self, audio_path, diar_json):
209
+ """FIXED: Transcription with proper word-level timestamp extraction"""
210
  # FIX: Load and standardize audio
211
  audio, sr = torchaudio.load(audio_path)
212
 
 
254
  reduced = chunk
255
 
256
  try:
257
+ # FIX: Force word-level timestamps and better configuration
258
+ result = self.asr_pipeline(
259
+ reduced,
260
+ return_timestamps="word", # FORCE word-level timestamps
261
+ generate_kwargs={
262
+ "task": "transcribe",
263
+ "language": "en"
264
+ }
265
+ )
266
  except Exception as e:
267
  print(f"⚠️ Whisper failed on segment {start:.2f}-{end:.2f}: {e}")
268
  continue
269
 
270
  tokens = []
271
  segment_text = ""
272
+
273
+ # FIXED: Proper word-level timestamp extraction
274
  if "chunks" in result:
275
+ for chunk_info in result["chunks"]:
276
+ timestamp = chunk_info.get("timestamp")
277
+ text = chunk_info.get("text", "").strip()
 
278
 
279
+ if text and timestamp:
280
+ chunk_start, chunk_end = timestamp
281
+
282
+ # Validate and convert to absolute time
283
+ if 0 <= chunk_start <= chunk_end <= (end - start):
284
+ abs_start = start + chunk_start
285
+ abs_end = start + chunk_end
 
 
 
 
286
  else:
287
+ # Fallback: use segment boundaries
288
  abs_start = start
289
  abs_end = end
290
 
291
+ # NEW: Split into individual words with distributed timestamps
292
+ words = text.split()
293
+ if len(words) == 1:
294
+ # Single word - use original timestamp
295
+ tokens.append({
296
+ "start": abs_start,
297
+ "end": abs_end,
298
+ "text": text,
299
+ "tag": "w"
300
+ })
301
+ else:
302
+ # Multiple words - distribute time evenly
303
+ word_duration = (abs_end - abs_start) / len(words)
304
+ for i, word in enumerate(words):
305
+ word_start = abs_start + (i * word_duration)
306
+ word_end = word_start + word_duration
307
+ tokens.append({
308
+ "start": word_start,
309
+ "end": word_end,
310
+ "text": word,
311
+ "tag": "w"
312
+ })
313
 
314
  segment_text += text + " "
315
 
 
354
  diar_json = self.run_diarization(audio_path)
355
  merged_segments, speakers = self.run_transcription(audio_path, diar_json)
356
 
357
+ # NEW: Map speaker labels to match original format (A, B, C, D)
358
+ merged_segments, speakers = self.map_speaker_labels(merged_segments)
359
+
360
  # NEW: Combine ASR segments with NSE events if provided
361
  if nse_events:
362
  print(f"🔊 Combining {len(merged_segments)} ASR segments with {len(nse_events)} NSE events")
requirements.txt CHANGED
@@ -5,3 +5,5 @@ transformers
5
  noisereduce
6
  jiwer
7
  librosa
 
 
 
5
  noisereduce
6
  jiwer
7
  librosa
8
+ webrtcvad
9
+ resampy