Ranam Hamoud commited on
Commit
67597e5
·
1 Parent(s): 8b3fa78

Add audio validation and fix tensor reshape error for short/invalid audio

Browse files
Files changed (1) hide show
  1. speech_recognizer.py +87 -41
speech_recognizer.py CHANGED
@@ -4,6 +4,7 @@ import numpy as np
4
  import re
5
  from typing import Dict, Optional, List
6
  import warnings
 
7
  warnings.filterwarnings("ignore")
8
 
9
 
@@ -19,6 +20,26 @@ class SpeechRecognizer:
19
  print(f"Whisper model loaded successfully.")
20
 
21
  self.model_size = model_size
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  def transcribe(
24
  self,
@@ -26,8 +47,13 @@ class SpeechRecognizer:
26
  language: Optional[str] = None,
27
  task: str = "transcribe"
28
  ) -> Dict[str, any]:
29
- # Transcribe with Whisper (with word-level timestamps for better pause detection)
30
- # Using fp16=False to avoid KV cache issues in production environments
 
 
 
 
 
31
  try:
32
  result = self.model.transcribe(
33
  audio_path,
@@ -38,21 +64,36 @@ class SpeechRecognizer:
38
  fp16=False # Disable fp16 to avoid KV cache KeyError
39
  )
40
  except (KeyError, RuntimeError) as e:
41
- # Fallback: transcribe without word timestamps if KV cache fails
42
- print(f"Warning: Word-level timestamps failed ({e}), retrying without them...")
43
- result = self.model.transcribe(
44
- audio_path,
45
- language=language,
46
- task=task,
47
- verbose=False,
48
- word_timestamps=False,
49
- fp16=False
50
- )
 
 
 
 
 
 
 
 
 
 
51
 
52
  transcription = result['text'].strip()
53
  detected_language = result.get('language', 'unknown')
54
  segments = result.get('segments', [])
55
 
 
 
 
 
 
56
  analysis = self._analyze_transcription(transcription, segments)
57
 
58
  duration = analysis['duration'] if analysis['duration'] > 0 else 1.0
@@ -76,6 +117,39 @@ class SpeechRecognizer:
76
  'interpretation': self._interpret_speech_patterns(analysis, kopparapu_features, kopparapu_score)
77
  }
78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  def _analyze_transcription(self, text: str, segments: List[Dict]) -> Dict:
80
  words = text.split()
81
  word_count = len(words)
@@ -160,10 +234,6 @@ class SpeechRecognizer:
160
  self, text: str, duration_sec: float,
161
  segments: List[Dict] = None, pause_patterns: Dict = None
162
  ) -> Dict:
163
- """
164
- Extract enhanced Kopparapu-like linguistic features for read speech detection.
165
- Based on: https://arxiv.org/pdf/2306.08012 with extensions.
166
- """
167
  text = text.strip()
168
  if len(text) == 0:
169
  return {
@@ -259,11 +329,6 @@ class SpeechRecognizer:
259
  }
260
 
261
  def _compute_rate_variability(self, segments: List[Dict]) -> float:
262
- """
263
- Compute speech rate variability across segments.
264
- Read speech has consistent rate; spontaneous varies with thinking.
265
- Returns 0-1 where higher = more variable = more spontaneous.
266
- """
267
  if not segments or len(segments) < 3:
268
  return 0.0
269
 
@@ -287,11 +352,6 @@ class SpeechRecognizer:
287
  return float(min(1.0, cv / 0.5)) # CV of 0.5+ maps to 1.0
288
 
289
  def _compute_sentence_variance(self, text: str) -> float:
290
- """
291
- Compute variance in sentence lengths.
292
- Read/scripted text tends to have more uniform sentence structure.
293
- Returns 0-1 where higher = more variance = more spontaneous.
294
- """
295
  # Split into sentences
296
  sentences = re.split(r'[.!?]+', text)
297
  sentences = [s.strip() for s in sentences if s.strip()]
@@ -307,23 +367,9 @@ class SpeechRecognizer:
307
  cv = std_len / mean_len if mean_len > 0 else 0
308
  return float(min(1.0, cv / 0.6)) # CV of 0.6+ maps to 1.0
309
 
310
- def _logistic(self, x: float, a: float, b: float) -> float:
311
- """Sigmoid function centered at 'a' with steepness 'b'."""
312
- return 1.0 / (1.0 + np.exp(-(x - a) / b))
313
 
314
  def _calculate_kopparapu_score(self, features: Dict) -> float:
315
- """
316
- Calculate enhanced Kopparapu score for read vs spontaneous classification.
317
- Score closer to 1 = more likely READ, closer to 0 = more likely SPONTANEOUS.
318
-
319
- Key signals for READ speech:
320
- - Higher chars_per_word (formal vocabulary)
321
- - Faster, steadier words_per_sec
322
- - Lower filler rate and disfluencies
323
- - Regular pause patterns (pause_regularity high)
324
- - Low speech rate variability
325
- - Uniform sentence lengths
326
- """
327
  # L1: Vocabulary complexity - higher chars/word = more formal = read
328
  f1 = features['chars_per_word']
329
  L1 = self._logistic(f1, a=4.8, b=1.2)
 
4
  import re
5
  from typing import Dict, Optional, List
6
  import warnings
7
+ import librosa
8
  warnings.filterwarnings("ignore")
9
 
10
 
 
20
  print(f"Whisper model loaded successfully.")
21
 
22
  self.model_size = model_size
23
+
24
+ def _validate_audio(self, audio_path: str) -> tuple[bool, str, float]:
25
+ """Validate audio file before transcription."""
26
+ try:
27
+ # Load audio to check if it's valid
28
+ audio, sr = librosa.load(audio_path, sr=16000)
29
+ duration = len(audio) / sr
30
+
31
+ # Check if audio is too short
32
+ if duration < 0.1:
33
+ return False, "Audio is too short (< 0.1 seconds)", duration
34
+
35
+ # Check if audio is empty or silent
36
+ if np.max(np.abs(audio)) < 0.001:
37
+ return False, "Audio appears to be silent or empty", duration
38
+
39
+ return True, "Valid", duration
40
+
41
+ except Exception as e:
42
+ return False, f"Failed to load audio: {str(e)}", 0.0
43
 
44
  def transcribe(
45
  self,
 
47
  language: Optional[str] = None,
48
  task: str = "transcribe"
49
  ) -> Dict[str, any]:
50
+ # Validate audio first
51
+ is_valid, message, audio_duration = self._validate_audio(audio_path)
52
+ if not is_valid:
53
+ print(f"Audio validation failed: {message}")
54
+ # Return minimal valid response for invalid audio
55
+ return self._get_empty_response(message, audio_duration)
56
+
57
  try:
58
  result = self.model.transcribe(
59
  audio_path,
 
64
  fp16=False # Disable fp16 to avoid KV cache KeyError
65
  )
66
  except (KeyError, RuntimeError) as e:
67
+ error_msg = str(e)
68
+ # Check if it's a tensor shape error (empty audio issue)
69
+ if "reshape tensor of 0 elements" in error_msg or "ambiguous" in error_msg:
70
+ print(f"Audio processing failed: Audio may be too short or corrupted")
71
+ return self._get_empty_response("Audio too short or corrupted", audio_duration)
72
+
73
+ # Fallback: transcribe without word timestamps for other errors
74
+ print(f"Warning: Transcription failed ({error_msg[:100]}), retrying without word timestamps...")
75
+ try:
76
+ result = self.model.transcribe(
77
+ audio_path,
78
+ language=language,
79
+ task=task,
80
+ verbose=False,
81
+ word_timestamps=False,
82
+ fp16=False
83
+ )
84
+ except Exception as e2:
85
+ print(f"Transcription completely failed: {e2}")
86
+ return self._get_empty_response(f"Transcription failed: {str(e2)[:100]}", audio_duration)
87
 
88
  transcription = result['text'].strip()
89
  detected_language = result.get('language', 'unknown')
90
  segments = result.get('segments', [])
91
 
92
+ # Handle empty transcription
93
+ if not transcription or len(transcription.strip()) == 0:
94
+ print("Warning: Transcription is empty")
95
+ return self._get_empty_response("No speech detected in audio", audio_duration)
96
+
97
  analysis = self._analyze_transcription(transcription, segments)
98
 
99
  duration = analysis['duration'] if analysis['duration'] > 0 else 1.0
 
117
  'interpretation': self._interpret_speech_patterns(analysis, kopparapu_features, kopparapu_score)
118
  }
119
 
120
+ def _get_empty_response(self, reason: str, duration: float = 0.0) -> Dict[str, any]:
121
+ """Return a valid empty response when transcription fails."""
122
+ return {
123
+ 'transcription': f"[Error: {reason}]",
124
+ 'language': 'unknown',
125
+ 'segments': [],
126
+ 'word_count': 0,
127
+ 'duration': duration,
128
+ 'speech_rate': 0.0,
129
+ 'pause_patterns': {
130
+ 'avg_pause': 0.0,
131
+ 'max_pause': 0.0,
132
+ 'num_pauses': 0,
133
+ 'pause_variability': 0.0
134
+ },
135
+ 'filler_words': {
136
+ 'count': 0,
137
+ 'ratio': 0.0,
138
+ 'details': {}
139
+ },
140
+ 'kopparapu_features': {
141
+ 'chars_per_word': 0.0,
142
+ 'words_per_sec': 0.0,
143
+ 'nonalpha_per_sec': 0.0,
144
+ 'filler_rate': 0.0,
145
+ 'repetition_count': 0,
146
+ 'alpha_ratio': 0.0
147
+ },
148
+ 'kopparapu_score': 0.5,
149
+ 'kopparapu_classification': 'unknown',
150
+ 'interpretation': f"⚠️ Audio processing failed: {reason}\n\nPlease ensure:\n- Audio is at least 1 second long\n- Audio contains actual speech\n- Audio file is not corrupted"
151
+ }
152
+
153
  def _analyze_transcription(self, text: str, segments: List[Dict]) -> Dict:
154
  words = text.split()
155
  word_count = len(words)
 
234
  self, text: str, duration_sec: float,
235
  segments: List[Dict] = None, pause_patterns: Dict = None
236
  ) -> Dict:
 
 
 
 
237
  text = text.strip()
238
  if len(text) == 0:
239
  return {
 
329
  }
330
 
331
  def _compute_rate_variability(self, segments: List[Dict]) -> float:
 
 
 
 
 
332
  if not segments or len(segments) < 3:
333
  return 0.0
334
 
 
352
  return float(min(1.0, cv / 0.5)) # CV of 0.5+ maps to 1.0
353
 
354
  def _compute_sentence_variance(self, text: str) -> float:
 
 
 
 
 
355
  # Split into sentences
356
  sentences = re.split(r'[.!?]+', text)
357
  sentences = [s.strip() for s in sentences if s.strip()]
 
367
  cv = std_len / mean_len if mean_len > 0 else 0
368
  return float(min(1.0, cv / 0.6)) # CV of 0.6+ maps to 1.0
369
 
370
+ def _logistic(self, x: float, a: float, b: float) -> float: return 1.0 / (1.0 + np.exp(-(x - a) / b))
 
 
371
 
372
  def _calculate_kopparapu_score(self, features: Dict) -> float:
 
 
 
 
 
 
 
 
 
 
 
 
373
  # L1: Vocabulary complexity - higher chars/word = more formal = read
374
  f1 = features['chars_per_word']
375
  L1 = self._logistic(f1, a=4.8, b=1.2)