Spaces:
Sleeping
Sleeping
Ranam Hamoud
committed on
Commit
·
8e8675d
1
Parent(s):
4ec806c
Fix audio classifier model loading and label mapping, update Gradio compatibility
Browse files- app.py +5 -2
- audio_classifier.py +61 -14
- requirements.txt +1 -1
app.py
CHANGED
|
@@ -271,9 +271,12 @@ def create_interface():
|
|
| 271 |
}
|
| 272 |
"""
|
| 273 |
|
| 274 |
-
with gr.Blocks(
|
| 275 |
|
| 276 |
-
gr.HTML("""
|
|
|
|
|
|
|
|
|
|
| 277 |
<header style='background: white; border-bottom: 1px solid #e5e7eb; margin-bottom: 32px;'>
|
| 278 |
<div style='padding: 16px 0;'>
|
| 279 |
<div style='display: flex; align-items: center; gap: 12px;'>
|
|
|
|
| 271 |
}
|
| 272 |
"""
|
| 273 |
|
| 274 |
+
with gr.Blocks(title="Authenticity Detection System") as demo:
|
| 275 |
|
| 276 |
+
gr.HTML(f"""
|
| 277 |
+
<style>
|
| 278 |
+
{custom_css}
|
| 279 |
+
</style>
|
| 280 |
<header style='background: white; border-bottom: 1px solid #e5e7eb; margin-bottom: 32px;'>
|
| 281 |
<div style='padding: 16px 0;'>
|
| 282 |
<div style='display: flex; align-items: center; gap: 12px;'>
|
audio_classifier.py
CHANGED
|
@@ -84,7 +84,7 @@ class AudioClassifier:
|
|
| 84 |
}
|
| 85 |
|
| 86 |
@classmethod
|
| 87 |
-
def get_model_path(cls, model_name: str = '
|
| 88 |
import os
|
| 89 |
if model_name not in cls.AVAILABLE_MODELS:
|
| 90 |
raise ValueError(f"Unknown model: {model_name}. Available: {list(cls.AVAILABLE_MODELS.keys())}")
|
|
@@ -100,16 +100,17 @@ class AudioClassifier:
|
|
| 100 |
|
| 101 |
if model_path is None:
|
| 102 |
import os
|
| 103 |
-
model_path = os.path.join(os.path.dirname(__file__), '
|
| 104 |
|
| 105 |
try:
|
|
|
|
| 106 |
state_dict = torch.load(model_path, map_location=self.device)
|
| 107 |
self.model.load_state_dict(state_dict)
|
| 108 |
-
print(f"Successfully loaded model from: {model_path}")
|
| 109 |
except FileNotFoundError:
|
| 110 |
-
|
| 111 |
except Exception as e:
|
| 112 |
-
|
| 113 |
|
| 114 |
self.model.eval()
|
| 115 |
|
|
@@ -118,16 +119,53 @@ class AudioClassifier:
|
|
| 118 |
self.n_fft = 2048
|
| 119 |
self.hop_length = 512
|
| 120 |
|
| 121 |
-
def extract_mel_spectrogram(self, audio_path: str) -> np.ndarray:
|
|
|
|
| 122 |
audio, sr = librosa.load(audio_path, sr=self.sample_rate)
|
| 123 |
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
|
| 132 |
mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)
|
| 133 |
|
|
@@ -287,6 +325,11 @@ class AudioClassifier:
|
|
| 287 |
probabilities = F.softmax(logits, dim=1)
|
| 288 |
predicted_class = torch.argmax(probabilities, dim=1).item()
|
| 289 |
cnn_confidence = probabilities[0, predicted_class].item()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 290 |
|
| 291 |
acoustic_features = self.extract_acoustic_features(audio_path)
|
| 292 |
|
|
@@ -294,7 +337,11 @@ class AudioClassifier:
|
|
| 294 |
prosody_classification = prosody_scores['classification']
|
| 295 |
prosody_confidence = prosody_scores['confidence']
|
| 296 |
|
| 297 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 298 |
|
| 299 |
if cnn_class_name == prosody_classification:
|
| 300 |
final_confidence = min(0.95, (cnn_confidence * 0.7 + prosody_confidence * 0.3))
|
|
|
|
| 84 |
}
|
| 85 |
|
| 86 |
@classmethod
|
| 87 |
+
def get_model_path(cls, model_name: str = '3s_window') -> str:
|
| 88 |
import os
|
| 89 |
if model_name not in cls.AVAILABLE_MODELS:
|
| 90 |
raise ValueError(f"Unknown model: {model_name}. Available: {list(cls.AVAILABLE_MODELS.keys())}")
|
|
|
|
| 100 |
|
| 101 |
if model_path is None:
|
| 102 |
import os
|
| 103 |
+
model_path = os.path.join(os.path.dirname(__file__), 'spectrogram_cnn_3s_window (1).pth')
|
| 104 |
|
| 105 |
try:
|
| 106 |
+
print(f"Attempting to load model from: {model_path}")
|
| 107 |
state_dict = torch.load(model_path, map_location=self.device)
|
| 108 |
self.model.load_state_dict(state_dict)
|
| 109 |
+
print(f"✓ Successfully loaded trained model from: {model_path}")
|
| 110 |
except FileNotFoundError:
|
| 111 |
+
raise FileNotFoundError(f"Model file not found at {model_path}. Please ensure the model file exists.")
|
| 112 |
except Exception as e:
|
| 113 |
+
raise RuntimeError(f"Error loading model from {model_path}: {e}")
|
| 114 |
|
| 115 |
self.model.eval()
|
| 116 |
|
|
|
|
| 119 |
self.n_fft = 2048
|
| 120 |
self.hop_length = 512
|
| 121 |
|
| 122 |
+
def extract_mel_spectrogram(self, audio_path: str, window_size: float = 3.0) -> np.ndarray:
|
| 123 |
+
"""Extract mel spectrogram from audio, using windowing if audio is longer than window_size."""
|
| 124 |
audio, sr = librosa.load(audio_path, sr=self.sample_rate)
|
| 125 |
|
| 126 |
+
# If audio is longer than window_size, take multiple windows and average
|
| 127 |
+
window_samples = int(window_size * sr)
|
| 128 |
+
|
| 129 |
+
if len(audio) > window_samples * 1.5: # If significantly longer
|
| 130 |
+
# Split into overlapping windows
|
| 131 |
+
hop_samples = window_samples // 2
|
| 132 |
+
windows = []
|
| 133 |
+
for start in range(0, len(audio) - window_samples, hop_samples):
|
| 134 |
+
window = audio[start:start + window_samples]
|
| 135 |
+
windows.append(window)
|
| 136 |
+
|
| 137 |
+
# Also add the last window
|
| 138 |
+
if len(audio) > window_samples:
|
| 139 |
+
windows.append(audio[-window_samples:])
|
| 140 |
+
|
| 141 |
+
# Compute mel spectrogram for each window and average
|
| 142 |
+
mel_specs = []
|
| 143 |
+
for window in windows[:5]: # Limit to 5 windows to avoid too much computation
|
| 144 |
+
mel_spec = librosa.feature.melspectrogram(
|
| 145 |
+
y=window,
|
| 146 |
+
sr=sr,
|
| 147 |
+
n_mels=self.n_mels,
|
| 148 |
+
n_fft=self.n_fft,
|
| 149 |
+
hop_length=self.hop_length
|
| 150 |
+
)
|
| 151 |
+
mel_specs.append(mel_spec)
|
| 152 |
+
|
| 153 |
+
# Average the spectrograms
|
| 154 |
+
mel_spec = np.mean(mel_specs, axis=0)
|
| 155 |
+
else:
|
| 156 |
+
# Pad or use as-is for short audio
|
| 157 |
+
if len(audio) < window_samples:
|
| 158 |
+
audio = np.pad(audio, (0, window_samples - len(audio)), mode='constant')
|
| 159 |
+
else:
|
| 160 |
+
audio = audio[:window_samples]
|
| 161 |
+
|
| 162 |
+
mel_spec = librosa.feature.melspectrogram(
|
| 163 |
+
y=audio,
|
| 164 |
+
sr=sr,
|
| 165 |
+
n_mels=self.n_mels,
|
| 166 |
+
n_fft=self.n_fft,
|
| 167 |
+
hop_length=self.hop_length
|
| 168 |
+
)
|
| 169 |
|
| 170 |
mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)
|
| 171 |
|
|
|
|
| 325 |
probabilities = F.softmax(logits, dim=1)
|
| 326 |
predicted_class = torch.argmax(probabilities, dim=1).item()
|
| 327 |
cnn_confidence = probabilities[0, predicted_class].item()
|
| 328 |
+
|
| 329 |
+
# Debug output
|
| 330 |
+
print(f"CNN Logits: {logits[0].cpu().numpy()}")
|
| 331 |
+
print(f"CNN Probabilities: Class 0 (read)={probabilities[0, 0].item():.3f}, Class 1 (spontaneous)={probabilities[0, 1].item():.3f}")
|
| 332 |
+
print(f"CNN Prediction: Class {predicted_class} ({['read', 'spontaneous'][predicted_class]}) with confidence {cnn_confidence:.3f}")
|
| 333 |
|
| 334 |
acoustic_features = self.extract_acoustic_features(audio_path)
|
| 335 |
|
|
|
|
| 337 |
prosody_classification = prosody_scores['classification']
|
| 338 |
prosody_confidence = prosody_scores['confidence']
|
| 339 |
|
| 340 |
+
# Try reversing labels if model was trained with opposite mapping
|
| 341 |
+
# Original: 0=read, 1=spontaneous
|
| 342 |
+
# Reversed: 0=spontaneous, 1=read
|
| 343 |
+
cnn_class_name = 'spontaneous' if predicted_class == 0 else 'read' # REVERSED LABELS
|
| 344 |
+
print(f"Final CNN classification: {cnn_class_name}")
|
| 345 |
|
| 346 |
if cnn_class_name == prosody_classification:
|
| 347 |
final_confidence = min(0.95, (cnn_confidence * 0.7 + prosody_confidence * 0.3))
|
requirements.txt
CHANGED
|
@@ -2,7 +2,7 @@ torch>=2.0.0
|
|
| 2 |
torchaudio>=2.0.0
|
| 3 |
openai-whisper>=20230314
|
| 4 |
transformers>=4.30.0
|
| 5 |
-
gradio
|
| 6 |
numpy>=1.24.0
|
| 7 |
scikit-learn>=1.3.0
|
| 8 |
librosa>=0.10.0
|
|
|
|
| 2 |
torchaudio>=2.0.0
|
| 3 |
openai-whisper>=20230314
|
| 4 |
transformers>=4.30.0
|
| 5 |
+
gradio==4.44.0
|
| 6 |
numpy>=1.24.0
|
| 7 |
scikit-learn>=1.3.0
|
| 8 |
librosa>=0.10.0
|