Ranam Hamoud committed on
Commit
8e8675d
·
1 Parent(s): 4ec806c

Fix audio classifier model loading and label mapping, update Gradio compatibility

Browse files
Files changed (3) hide show
  1. app.py +5 -2
  2. audio_classifier.py +61 -14
  3. requirements.txt +1 -1
app.py CHANGED
@@ -271,9 +271,12 @@ def create_interface():
271
  }
272
  """
273
 
274
- with gr.Blocks(css=custom_css, title="Authenticity Detection System") as demo:
275
 
276
- gr.HTML("""
 
 
 
277
  <header style='background: white; border-bottom: 1px solid #e5e7eb; margin-bottom: 32px;'>
278
  <div style='padding: 16px 0;'>
279
  <div style='display: flex; align-items: center; gap: 12px;'>
 
271
  }
272
  """
273
 
274
+ with gr.Blocks(title="Authenticity Detection System") as demo:
275
 
276
+ gr.HTML(f"""
277
+ <style>
278
+ {custom_css}
279
+ </style>
280
  <header style='background: white; border-bottom: 1px solid #e5e7eb; margin-bottom: 32px;'>
281
  <div style='padding: 16px 0;'>
282
  <div style='display: flex; align-items: center; gap: 12px;'>
audio_classifier.py CHANGED
@@ -84,7 +84,7 @@ class AudioClassifier:
84
  }
85
 
86
  @classmethod
87
- def get_model_path(cls, model_name: str = '4s_window') -> str:
88
  import os
89
  if model_name not in cls.AVAILABLE_MODELS:
90
  raise ValueError(f"Unknown model: {model_name}. Available: {list(cls.AVAILABLE_MODELS.keys())}")
@@ -100,16 +100,17 @@ class AudioClassifier:
100
 
101
  if model_path is None:
102
  import os
103
- model_path = os.path.join(os.path.dirname(__file__), 'spectrogram_cnn_4s_window.pth')
104
 
105
  try:
 
106
  state_dict = torch.load(model_path, map_location=self.device)
107
  self.model.load_state_dict(state_dict)
108
- print(f"Successfully loaded model from: {model_path}")
109
  except FileNotFoundError:
110
- print(f"Warning: Model file not found at {model_path}. Using untrained model.")
111
  except Exception as e:
112
- print(f"Warning: Error loading model from {model_path}: {e}. Using untrained model.")
113
 
114
  self.model.eval()
115
 
@@ -118,16 +119,53 @@ class AudioClassifier:
118
  self.n_fft = 2048
119
  self.hop_length = 512
120
 
121
- def extract_mel_spectrogram(self, audio_path: str) -> np.ndarray:
 
122
  audio, sr = librosa.load(audio_path, sr=self.sample_rate)
123
 
124
- mel_spec = librosa.feature.melspectrogram(
125
- y=audio,
126
- sr=sr,
127
- n_mels=self.n_mels,
128
- n_fft=self.n_fft,
129
- hop_length=self.hop_length
130
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
  mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)
133
 
@@ -287,6 +325,11 @@ class AudioClassifier:
287
  probabilities = F.softmax(logits, dim=1)
288
  predicted_class = torch.argmax(probabilities, dim=1).item()
289
  cnn_confidence = probabilities[0, predicted_class].item()
 
 
 
 
 
290
 
291
  acoustic_features = self.extract_acoustic_features(audio_path)
292
 
@@ -294,7 +337,11 @@ class AudioClassifier:
294
  prosody_classification = prosody_scores['classification']
295
  prosody_confidence = prosody_scores['confidence']
296
 
297
- cnn_class_name = 'read' if predicted_class == 0 else 'spontaneous'
 
 
 
 
298
 
299
  if cnn_class_name == prosody_classification:
300
  final_confidence = min(0.95, (cnn_confidence * 0.7 + prosody_confidence * 0.3))
 
84
  }
85
 
86
  @classmethod
87
+ def get_model_path(cls, model_name: str = '3s_window') -> str:
88
  import os
89
  if model_name not in cls.AVAILABLE_MODELS:
90
  raise ValueError(f"Unknown model: {model_name}. Available: {list(cls.AVAILABLE_MODELS.keys())}")
 
100
 
101
  if model_path is None:
102
  import os
103
+ model_path = os.path.join(os.path.dirname(__file__), 'spectrogram_cnn_3s_window (1).pth')
104
 
105
  try:
106
+ print(f"Attempting to load model from: {model_path}")
107
  state_dict = torch.load(model_path, map_location=self.device)
108
  self.model.load_state_dict(state_dict)
109
+ print(f"Successfully loaded trained model from: {model_path}")
110
  except FileNotFoundError:
111
+ raise FileNotFoundError(f"Model file not found at {model_path}. Please ensure the model file exists.")
112
  except Exception as e:
113
+ raise RuntimeError(f"Error loading model from {model_path}: {e}")
114
 
115
  self.model.eval()
116
 
 
119
  self.n_fft = 2048
120
  self.hop_length = 512
121
 
122
+ def extract_mel_spectrogram(self, audio_path: str, window_size: float = 3.0) -> np.ndarray:
123
+ """Extract mel spectrogram from audio, using windowing if audio is longer than window_size."""
124
  audio, sr = librosa.load(audio_path, sr=self.sample_rate)
125
 
126
+ # If audio is longer than window_size, take multiple windows and average
127
+ window_samples = int(window_size * sr)
128
+
129
+ if len(audio) > window_samples * 1.5: # If significantly longer
130
+ # Split into overlapping windows
131
+ hop_samples = window_samples // 2
132
+ windows = []
133
+ for start in range(0, len(audio) - window_samples, hop_samples):
134
+ window = audio[start:start + window_samples]
135
+ windows.append(window)
136
+
137
+ # Also add the last window
138
+ if len(audio) > window_samples:
139
+ windows.append(audio[-window_samples:])
140
+
141
+ # Compute mel spectrogram for each window and average
142
+ mel_specs = []
143
+ for window in windows[:5]: # Limit to 5 windows to avoid too much computation
144
+ mel_spec = librosa.feature.melspectrogram(
145
+ y=window,
146
+ sr=sr,
147
+ n_mels=self.n_mels,
148
+ n_fft=self.n_fft,
149
+ hop_length=self.hop_length
150
+ )
151
+ mel_specs.append(mel_spec)
152
+
153
+ # Average the spectrograms
154
+ mel_spec = np.mean(mel_specs, axis=0)
155
+ else:
156
+ # Pad or use as-is for short audio
157
+ if len(audio) < window_samples:
158
+ audio = np.pad(audio, (0, window_samples - len(audio)), mode='constant')
159
+ else:
160
+ audio = audio[:window_samples]
161
+
162
+ mel_spec = librosa.feature.melspectrogram(
163
+ y=audio,
164
+ sr=sr,
165
+ n_mels=self.n_mels,
166
+ n_fft=self.n_fft,
167
+ hop_length=self.hop_length
168
+ )
169
 
170
  mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)
171
 
 
325
  probabilities = F.softmax(logits, dim=1)
326
  predicted_class = torch.argmax(probabilities, dim=1).item()
327
  cnn_confidence = probabilities[0, predicted_class].item()
328
+
329
+ # Debug output
330
+ print(f"CNN Logits: {logits[0].cpu().numpy()}")
331
+ print(f"CNN Probabilities: Class 0 (read)={probabilities[0, 0].item():.3f}, Class 1 (spontaneous)={probabilities[0, 1].item():.3f}")
332
+ print(f"CNN Prediction: Class {predicted_class} ({['read', 'spontaneous'][predicted_class]}) with confidence {cnn_confidence:.3f}")
333
 
334
  acoustic_features = self.extract_acoustic_features(audio_path)
335
 
 
337
  prosody_classification = prosody_scores['classification']
338
  prosody_confidence = prosody_scores['confidence']
339
 
340
+ # Try reversing labels if model was trained with opposite mapping
341
+ # Original: 0=read, 1=spontaneous
342
+ # Reversed: 0=spontaneous, 1=read
343
+ cnn_class_name = 'spontaneous' if predicted_class == 0 else 'read' # REVERSED LABELS
344
+ print(f"Final CNN classification: {cnn_class_name}")
345
 
346
  if cnn_class_name == prosody_classification:
347
  final_confidence = min(0.95, (cnn_confidence * 0.7 + prosody_confidence * 0.3))
requirements.txt CHANGED
@@ -2,7 +2,7 @@ torch>=2.0.0
2
  torchaudio>=2.0.0
3
  openai-whisper>=20230314
4
  transformers>=4.30.0
5
- gradio>=4.0.0
6
  numpy>=1.24.0
7
  scikit-learn>=1.3.0
8
  librosa>=0.10.0
 
2
  torchaudio>=2.0.0
3
  openai-whisper>=20230314
4
  transformers>=4.30.0
5
+ gradio==4.44.0
6
  numpy>=1.24.0
7
  scikit-learn>=1.3.0
8
  librosa>=0.10.0