# respiratory-symptom-api / audio_preprocessing.py
"""
Audio Preprocessing Module for Respiratory Symptom Analysis
Updated for 39% F1-Macro Model (128x431 mel-spectrograms)
Version: 3.0.0
"""
import os
import warnings
from typing import Dict, Optional, Tuple, Union

import librosa
import numpy as np
import soundfile as sf
import torch
from scipy import signal
# Work around Numba caching issues in Docker containers: point the cache at a
# writable directory and keep JIT enabled.
os.environ['NUMBA_CACHE_DIR'] = '/tmp'
os.environ['NUMBA_DISABLE_JIT'] = '0'
warnings.filterwarnings('ignore')  # silence noisy third-party warnings
class RespiratoryAudioPreprocessor:
    """
    Audio preprocessor matching the 39% F1-macro training pipeline.
    Output mel-spectrogram shape: (128, 431), as in the training data.
    """
    def __init__(self,
                 target_sr: int = 22050,
                 n_mels: int = 128,
                 n_fft: int = 2048,
                 hop_length: int = 512,
                 win_length: Optional[int] = None,
                 window: str = 'hann',
                 fmin: float = 0.0,
                 fmax: Optional[float] = None,
                 power: float = 2.0,
                 duration: float = 10.0):  # 10 s clips, matching training
"""Initialize preprocessing parameters to match training"""
self.target_sr = target_sr
self.n_mels = n_mels
self.n_fft = n_fft
self.hop_length = hop_length
self.win_length = win_length
self.window = window
self.fmin = fmin
        self.fmax = fmax if fmax is not None else target_sr / 2
self.power = power
self.duration = duration
self.target_length = int(target_sr * duration)
        # Expected model input shape: (batch, channels, n_mels, time) = (1, 1, 128, 431)
        self.expected_shape = (1, 1, 128, 431)
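        # Why 431 time frames: with center=True, librosa produces
        # 1 + floor(n_samples / hop_length) frames, and for a 10 s clip at
        # 22050 Hz that is 1 + 220500 // 512 = 431.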
# Pre-warm librosa
self._warmup_librosa()
    def _warmup_librosa(self):
        """Exercise librosa once so any lazy compilation happens at startup."""
try:
dummy_audio = np.random.randn(1024).astype(np.float32)
_ = librosa.feature.melspectrogram(
y=dummy_audio,
sr=self.target_sr,
n_mels=32,
n_fft=512,
hop_length=256
)
print("✅ Librosa functions warmed up successfully")
except Exception as e:
print(f"⚠️ Librosa warmup warning: {str(e)}")
    def scipy_resample(self, audio_data: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
        """
        Resample with scipy.signal instead of resampy, avoiding the resampy/numba dependency.
        """
if orig_sr == target_sr:
return audio_data
        try:
            # FFT-based resampling to the proportional sample count at the target rate
            target_length = int(len(audio_data) * target_sr / orig_sr)
            resampled_audio = signal.resample(audio_data, target_length)
            return resampled_audio.astype(np.float32)
except Exception as e:
print(f"⚠️ Scipy resampling failed: {e}, using original audio")
return audio_data
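    # Note: signal.resample is FFT-based, which assumes a periodic signal and can
    # ring at clip edges. A polyphase alternative (a sketch, not what this class
    # uses) would be:
    #
    #     from math import gcd
    #     g = gcd(target_sr, orig_sr)
    #     out = signal.resample_poly(audio_data, target_sr // g, orig_sr // g)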
def load_and_normalize_audio(self, audio_input: Union[str, np.ndarray, tuple]) -> np.ndarray:
"""Load audio file without resampy dependency"""
try:
if isinstance(audio_input, str):
# Load with soundfile first
try:
audio_data, sr = sf.read(audio_input)
# Convert to mono if stereo
if len(audio_data.shape) > 1:
audio_data = np.mean(audio_data, axis=1)
# Resample using scipy if needed
if sr != self.target_sr:
audio_data = self.scipy_resample(audio_data, sr, self.target_sr)
            except Exception as sf_error:
                # Fallback: load with librosa at the native rate, then resample with scipy
try:
# Load with original sample rate first
audio_data, sr = librosa.load(audio_input, sr=None)
# Convert to mono if stereo
if len(audio_data.shape) > 1:
audio_data = np.mean(audio_data, axis=1)
# Manual resampling with scipy
if sr != self.target_sr:
audio_data = self.scipy_resample(audio_data, sr, self.target_sr)
# Limit duration manually
if len(audio_data) > self.target_length:
audio_data = audio_data[:self.target_length]
except Exception as librosa_error:
raise RuntimeError(f"Failed to load audio. SoundFile: {sf_error}. Librosa: {librosa_error}")
elif isinstance(audio_input, tuple):
# (sample_rate, audio_array) from uploads
sr, audio_data = audio_input
# Convert to float32
if audio_data.dtype != np.float32:
if audio_data.dtype == np.int16:
audio_data = audio_data.astype(np.float32) / 32767.0
elif audio_data.dtype == np.int32:
audio_data = audio_data.astype(np.float32) / 2147483647.0
else:
audio_data = audio_data.astype(np.float32)
# Convert to mono if stereo
if len(audio_data.shape) > 1:
audio_data = np.mean(audio_data, axis=1)
# Resample using scipy
if sr != self.target_sr:
audio_data = self.scipy_resample(audio_data, sr, self.target_sr)
# Trim duration
if len(audio_data) > self.target_length:
audio_data = audio_data[:self.target_length]
elif isinstance(audio_input, np.ndarray):
# Raw audio array (assume target_sr)
audio_data = audio_input.astype(np.float32)
# Convert to mono if stereo
if len(audio_data.shape) > 1:
audio_data = np.mean(audio_data, axis=1)
if len(audio_data) > self.target_length:
audio_data = audio_data[:self.target_length]
else:
raise ValueError(f"Unsupported audio input type: {type(audio_input)}")
# Ensure 1D
if len(audio_data.shape) > 1:
audio_data = audio_data.flatten()
# Pad if too short
if len(audio_data) < self.target_length:
audio_data = np.pad(
audio_data,
(0, self.target_length - len(audio_data)),
mode='constant',
constant_values=0
)
            # Peak-normalize amplitude to [-1, 1]
max_val = np.max(np.abs(audio_data))
if max_val > 0:
audio_data = audio_data / max_val
return audio_data
except Exception as e:
raise RuntimeError(f"Failed to load audio: {str(e)}")
def extract_mel_spectrogram(self, audio_data: np.ndarray) -> np.ndarray:
"""Extract mel spectrogram matching training configuration"""
try:
            # Ensure float32, 1-D input
audio_data = np.asarray(audio_data, dtype=np.float32)
if len(audio_data.shape) > 1:
audio_data = audio_data.flatten()
# Extract mel spectrogram with exact training parameters
try:
mel_spec = librosa.feature.melspectrogram(
y=audio_data,
sr=self.target_sr,
n_mels=self.n_mels,
n_fft=self.n_fft,
hop_length=self.hop_length,
win_length=self.win_length,
window=self.window,
fmin=self.fmin,
fmax=self.fmax,
power=self.power,
center=True,
pad_mode='constant'
)
except Exception as mel_error:
# Simplified fallback
print(f"⚠️ Using simplified mel spectrogram extraction: {mel_error}")
mel_spec = librosa.feature.melspectrogram(
y=audio_data,
sr=self.target_sr,
n_mels=self.n_mels
)
            # Floor tiny values to avoid log(0), then convert power to dB relative to the peak
mel_spec = np.maximum(mel_spec, 1e-10)
mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)
return mel_spec_db
except Exception as e:
raise RuntimeError(f"Failed to extract mel spectrogram: {str(e)}")
    def normalize_spectrogram(self, mel_spec: np.ndarray) -> np.ndarray:
        """Per-spectrogram z-score normalization, clipped to [-5, 5] (matches training)."""
try:
mean = np.mean(mel_spec)
std = np.std(mel_spec)
if std == 0:
normalized = mel_spec - mean
else:
normalized = (mel_spec - mean) / (std + 1e-8)
# Clip to prevent extreme values
normalized = np.clip(normalized, -5.0, 5.0)
return normalized
except Exception as e:
raise RuntimeError(f"Failed to normalize spectrogram: {str(e)}")
    def resize_spectrogram(self, mel_spec: np.ndarray, target_width: int = 431) -> np.ndarray:
        """
        Pad or crop the time axis so the spectrogram is (128, target_width), matching training.
        """
try:
current_height, current_width = mel_spec.shape
# Handle height (should be 128 already)
if current_height != 128:
print(f"⚠️ Unexpected height: {current_height}, expected 128")
# Handle width
if current_width == target_width:
return mel_spec
if current_width < target_width:
# Pad to target width
pad_width = target_width - current_width
mel_spec = np.pad(
mel_spec,
((0, 0), (0, pad_width)),
mode='constant',
constant_values=0
)
else:
# Crop to target width
mel_spec = mel_spec[:, :target_width]
return mel_spec
except Exception as e:
raise RuntimeError(f"Failed to resize spectrogram: {str(e)}")
def preprocess_audio(self, audio_input: Union[str, np.ndarray, tuple]) -> torch.Tensor:
"""
        Complete preprocessing pipeline matching the training setup.
        Output: (1, 1, 128, 431) float tensor
"""
try:
# Load audio
audio_data = self.load_and_normalize_audio(audio_input)
# Extract mel-spectrogram
mel_spec = self.extract_mel_spectrogram(audio_data)
# Normalize
mel_spec_norm = self.normalize_spectrogram(mel_spec)
# Resize to (128, 431)
mel_spec_resized = self.resize_spectrogram(mel_spec_norm, target_width=431)
            # Convert to a (1, 1, 128, 431) tensor: add batch and channel dimensions
            tensor_input = torch.from_numpy(np.ascontiguousarray(mel_spec_resized)).float()
            tensor_input = tensor_input.unsqueeze(0).unsqueeze(0)
# Verify shape
if tensor_input.shape != self.expected_shape:
print(f"⚠️ Shape mismatch: got {tensor_input.shape}, expected {self.expected_shape}")
# Force resize using interpolation as last resort
tensor_input = torch.nn.functional.interpolate(
tensor_input,
size=self.expected_shape[2:],
mode='bilinear',
align_corners=False
)
return tensor_input
except Exception as e:
raise RuntimeError(f"Preprocessing failed: {str(e)}")
def get_preprocessing_info(self) -> Dict:
"""Get preprocessing configuration info"""
return {
'target_sr': self.target_sr,
'n_mels': self.n_mels,
'n_fft': self.n_fft,
'hop_length': self.hop_length,
'duration': self.duration,
'output_shape': self.expected_shape,
'resampling_method': 'scipy.signal',
'normalization': 'z-score (mean=0, std=1)',
'db_scale': True
}
    def validate_audio_file(self, audio_path: str) -> Tuple[bool, str]:
        """Validate an audio file before processing."""
        if not audio_path:
            return False, "No audio file provided"
        try:
            duration = sf.info(audio_path).duration
            if duration < 0.5:
                return False, f"Audio too short ({duration:.1f}s). Minimum: 0.5s"
            if duration > 30.0:
                return False, f"Audio too long ({duration:.1f}s). Maximum: 30s"
            return True, "Audio file is valid"
        except Exception as e:
            return False, f"Error validating audio: {str(e)}"