Capstone04 commited on
Commit
e89eb31
·
verified ·
1 Parent(s): c300c20

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. asr_diarization/pipeline.py +40 -2
asr_diarization/pipeline.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
  import json
 
3
  import torch
4
  import torchaudio
5
  import noisereduce as nr
@@ -65,6 +66,40 @@ class ASR_Diarization:
65
  return_timestamps=True
66
  )
67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  def calculate_snr(self, audio_path):
69
  """NEW: Calculate SNR using RMS energy"""
70
  try:
@@ -301,12 +336,15 @@ class ASR_Diarization:
301
  # Extract just the text (no timestamp processing)
302
  text = result.get("text", "").strip()
303
 
304
- if text:
 
 
 
305
  seg_dict = {
306
  "speaker": spk,
307
  "start": start, # Keep segment boundaries
308
  "end": end, # Keep segment boundaries
309
- "text": text, # Just the full segment text
310
  "rms_energy": float(rms_energy)
311
  }
312
  merged_segments.append(seg_dict)
 
1
  import os
2
  import json
3
+ import re
4
  import torch
5
  import torchaudio
6
  import noisereduce as nr
 
66
  return_timestamps=True
67
  )
68
 
69
+ def clean_transcription_text(self, text):
70
+ """Clean ASR text for better TTS performance"""
71
+ if not text:
72
+ return ""
73
+
74
+ # Basic cleaning
75
+ text = text.strip()
76
+
77
+ # Fix punctuation spacing for TTS
78
+ text = re.sub(r'\s+([.,!?;:])', r'\1', text) # Remove space before punctuation
79
+ text = re.sub(r'([.,!?;:])(?=\w)', r'\1 ', text) # Add space after punctuation
80
+
81
+ # Normalize whitespace
82
+ text = re.sub(r'\s+', ' ', text)
83
+
84
+ return text.strip()
85
+
86
+ def should_keep_segment(self, text, duration, rms_energy):
87
+ """Generalized segment quality assessment"""
88
+ # Duration too short
89
+ if duration < self.min_whisper_duration:
90
+ return False
91
+
92
+ # Energy too low (likely noise)
93
+ if rms_energy < 0.001:
94
+ return False
95
+
96
+ # Text too short or just punctuation
97
+ clean_text = text.strip()
98
+ if len(clean_text) <= 1:
99
+ return False
100
+
101
+ return True
102
+
103
  def calculate_snr(self, audio_path):
104
  """NEW: Calculate SNR using RMS energy"""
105
  try:
 
336
  # Extract just the text (no timestamp processing)
337
  text = result.get("text", "").strip()
338
 
339
+ # Clean the text for TTS and apply quality filtering
340
+ clean_text = self.clean_transcription_text(text)
341
+
342
+ if clean_text and self.should_keep_segment(clean_text, segment_duration, rms_energy):
343
  seg_dict = {
344
  "speaker": spk,
345
  "start": start, # Keep segment boundaries
346
  "end": end, # Keep segment boundaries
347
+ "text": clean_text, # Use cleaned text
348
  "rms_energy": float(rms_energy)
349
  }
350
  merged_segments.append(seg_dict)