# NOTE: Hugging Face Spaces status header captured by the scraper
# ("Spaces: Sleeping") — not part of the module source.
| """ | |
| Audio Preprocessing Module for Respiratory Symptom Analysis | |
| Updated for 39% F1-Macro Model (128x431 mel-spectrograms) | |
| Version: 3.0.0 | |
| """ | |
import os
import warnings

# Fix for Numba caching issues in Docker containers.
# These MUST be set before librosa is imported: Numba reads its cache/JIT
# configuration at import time, so setting them after `import librosa`
# (as the original code did) has no effect.
os.environ['NUMBA_CACHE_DIR'] = '/tmp'
os.environ['NUMBA_DISABLE_JIT'] = '0'
warnings.filterwarnings('ignore')

import librosa
import numpy as np
import torch
import soundfile as sf
from scipy import signal
from typing import Union, Tuple, Dict
class RespiratoryAudioPreprocessor:
    """
    Audio preprocessor matching the 39% F1-Macro training pipeline.

    Converts a file path, a raw numpy array, or a (sample_rate, array)
    tuple into a model-ready tensor of shape (1, 1, n_mels, n_frames) —
    (1, 1, 128, 431) with the default parameters, matching training data.
    """

    def __init__(self,
                 target_sr: int = 22050,
                 n_mels: int = 128,
                 n_fft: int = 2048,
                 hop_length: int = 512,
                 win_length: Union[int, None] = None,
                 window: str = 'hann',
                 fmin: float = 0.0,
                 fmax: Union[float, None] = None,
                 power: float = 2.0,
                 duration: float = 10.0):
        """
        Initialize preprocessing parameters to match training.

        Args:
            target_sr: Sample rate all audio is resampled to (Hz).
            n_mels: Number of mel bands (spectrogram height).
            n_fft: FFT window size in samples.
            hop_length: Hop between successive STFT frames (samples).
            win_length: STFT window length; librosa uses n_fft when None.
            window: Window function name passed through to librosa.
            fmin: Lowest mel filterbank frequency (Hz).
            fmax: Highest mel filterbank frequency; Nyquist when None.
            power: Spectrogram magnitude exponent (2.0 = power spectrum).
            duration: Fixed clip length in seconds (10.0 matches training;
                previously 3.0).
        """
        self.target_sr = target_sr
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = window
        self.fmin = fmin
        # `is None` (not `or`) so an explicit fmax is always honored,
        # even an (unusual) value of 0.0.
        self.fmax = fmax if fmax is not None else target_sr // 2
        self.power = power
        self.duration = duration
        self.target_length = int(target_sr * duration)
        # Frame count of a centered STFT over target_length samples:
        # floor(len / hop) + 1 -> 431 with the defaults, matching training.
        self.n_frames = self.target_length // self.hop_length + 1
        # Expected model input: (batch, channel, n_mels, time).
        self.expected_shape = (1, 1, self.n_mels, self.n_frames)
        # Pre-warm librosa so the first real request isn't slow.
        self._warmup_librosa()

    def _warmup_librosa(self):
        """Run a tiny mel-spectrogram so librosa/Numba pre-compile.

        Best-effort only: any failure is reported and ignored.
        """
        try:
            dummy_audio = np.random.randn(1024).astype(np.float32)
            _ = librosa.feature.melspectrogram(
                y=dummy_audio,
                sr=self.target_sr,
                n_mels=32,
                n_fft=512,
                hop_length=256
            )
            print("✅ Librosa functions warmed up successfully")
        except Exception as e:
            print(f"⚠️ Librosa warmup warning: {str(e)}")

    def scipy_resample(self, audio_data: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
        """
        Resample with scipy.signal (FFT-based) instead of resampy.

        Args:
            audio_data: 1-D audio signal.
            orig_sr: Original sample rate (Hz).
            target_sr: Desired sample rate (Hz).

        Returns:
            Resampled float32 signal; on failure the input is returned
            unchanged (best-effort, matching the rest of the pipeline).
        """
        if orig_sr == target_sr:
            return audio_data
        try:
            resample_ratio = target_sr / orig_sr
            target_length = int(len(audio_data) * resample_ratio)
            resampled_audio = signal.resample(audio_data, target_length)
            return resampled_audio.astype(np.float32)
        except Exception as e:
            print(f"⚠️ Scipy resampling failed: {e}, using original audio")
            return audio_data

    def _to_mono_float32(self, audio_data: np.ndarray) -> np.ndarray:
        """Convert an array to float32 (int PCM rescaled to [-1, 1]) and mix to mono."""
        if audio_data.dtype == np.int16:
            audio_data = audio_data.astype(np.float32) / 32767.0
        elif audio_data.dtype == np.int32:
            audio_data = audio_data.astype(np.float32) / 2147483647.0
        elif audio_data.dtype != np.float32:
            audio_data = audio_data.astype(np.float32)
        if audio_data.ndim > 1:
            # Average channels to mono.
            audio_data = np.mean(audio_data, axis=1)
        return audio_data

    def _load_from_path(self, path: str) -> np.ndarray:
        """Read an audio file: soundfile first, librosa (native sr) as fallback."""
        try:
            audio_data, sr = sf.read(path)
        except Exception as sf_error:
            try:
                # sr=None keeps the native rate so resampy isn't needed.
                audio_data, sr = librosa.load(path, sr=None)
            except Exception as librosa_error:
                raise RuntimeError(
                    f"Failed to load audio. SoundFile: {sf_error}. Librosa: {librosa_error}"
                )
        audio_data = self._to_mono_float32(np.asarray(audio_data))
        if sr != self.target_sr:
            audio_data = self.scipy_resample(audio_data, sr, self.target_sr)
        return audio_data

    def _load_from_tuple(self, audio_input: tuple) -> np.ndarray:
        """Handle a (sample_rate, array) tuple, e.g. from a web upload."""
        sr, audio_data = audio_input
        audio_data = self._to_mono_float32(np.asarray(audio_data))
        if sr != self.target_sr:
            audio_data = self.scipy_resample(audio_data, sr, self.target_sr)
        return audio_data

    def load_and_normalize_audio(self, audio_input: Union[str, np.ndarray, tuple]) -> np.ndarray:
        """
        Load audio and return a mono float32 signal at target_sr, trimmed
        or zero-padded to exactly target_length and peak-normalized.

        Args:
            audio_input: File path, raw array (assumed already at
                target_sr), or (sample_rate, array) tuple.

        Returns:
            1-D float array of length target_length in [-1, 1].

        Raises:
            RuntimeError: if the input cannot be loaded or decoded.
        """
        try:
            if isinstance(audio_input, str):
                audio_data = self._load_from_path(audio_input)
            elif isinstance(audio_input, tuple):
                audio_data = self._load_from_tuple(audio_input)
            elif isinstance(audio_input, np.ndarray):
                audio_data = self._to_mono_float32(audio_input)
            else:
                raise ValueError(f"Unsupported audio input type: {type(audio_input)}")

            # Ensure 1-D.
            if audio_data.ndim > 1:
                audio_data = audio_data.flatten()

            # Fix duration on the COMMON path. (The original only trimmed
            # inside some load branches, so the soundfile path could hand
            # back a clip longer than target_length.)
            if len(audio_data) > self.target_length:
                audio_data = audio_data[:self.target_length]
            elif len(audio_data) < self.target_length:
                audio_data = np.pad(
                    audio_data,
                    (0, self.target_length - len(audio_data)),
                    mode='constant',
                    constant_values=0
                )

            # Peak-normalize amplitude; skip silent clips to avoid /0.
            max_val = np.max(np.abs(audio_data))
            if max_val > 0:
                audio_data = audio_data / max_val
            return audio_data
        except Exception as e:
            raise RuntimeError(f"Failed to load audio: {str(e)}")

    def extract_mel_spectrogram(self, audio_data: np.ndarray) -> np.ndarray:
        """
        Extract a dB-scaled mel spectrogram with the training parameters.

        Args:
            audio_data: 1-D float audio at target_sr.

        Returns:
            (n_mels, n_frames) array in dB (referenced to the max bin).

        Raises:
            RuntimeError: if extraction fails entirely.
        """
        try:
            audio_data = np.asarray(audio_data, dtype=np.float32)
            if audio_data.ndim > 1:
                audio_data = audio_data.flatten()
            try:
                mel_spec = librosa.feature.melspectrogram(
                    y=audio_data,
                    sr=self.target_sr,
                    n_mels=self.n_mels,
                    n_fft=self.n_fft,
                    hop_length=self.hop_length,
                    win_length=self.win_length,
                    window=self.window,
                    fmin=self.fmin,
                    fmax=self.fmax,
                    power=self.power,
                    center=True,
                    pad_mode='constant'
                )
            except Exception as mel_error:
                # Fallback: librosa defaults for everything but n_mels.
                print(f"⚠️ Using simplified mel spectrogram extraction: {mel_error}")
                mel_spec = librosa.feature.melspectrogram(
                    y=audio_data,
                    sr=self.target_sr,
                    n_mels=self.n_mels
                )
            # Floor at 1e-10 so the log conversion never sees zeros.
            mel_spec = np.maximum(mel_spec, 1e-10)
            mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)
            return mel_spec_db
        except Exception as e:
            raise RuntimeError(f"Failed to extract mel spectrogram: {str(e)}")

    def normalize_spectrogram(self, mel_spec: np.ndarray) -> np.ndarray:
        """
        Z-score normalize a spectrogram (mean 0, std 1), clipped to ±5.

        Raises:
            RuntimeError: if normalization fails.
        """
        try:
            mean = np.mean(mel_spec)
            std = np.std(mel_spec)
            if std == 0:
                # Constant input: just center it.
                normalized = mel_spec - mean
            else:
                normalized = (mel_spec - mean) / (std + 1e-8)
            # Clip to prevent extreme values from dominating the model.
            normalized = np.clip(normalized, -5.0, 5.0)
            return normalized
        except Exception as e:
            raise RuntimeError(f"Failed to normalize spectrogram: {str(e)}")

    def resize_spectrogram(self, mel_spec: np.ndarray, target_width: int = 431) -> np.ndarray:
        """
        Pad or crop a spectrogram's time axis to target_width frames.

        Args:
            mel_spec: (height, width) spectrogram; height should equal
                n_mels (a warning is printed otherwise).
            target_width: Desired number of time frames (431 = training).

        Raises:
            RuntimeError: if resizing fails.
        """
        try:
            current_height, current_width = mel_spec.shape
            # Height is fixed by the mel filterbank and should already match.
            if current_height != self.n_mels:
                print(f"⚠️ Unexpected height: {current_height}, expected {self.n_mels}")
            if current_width == target_width:
                return mel_spec
            if current_width < target_width:
                # Zero-pad on the right (end of clip).
                pad_width = target_width - current_width
                mel_spec = np.pad(
                    mel_spec,
                    ((0, 0), (0, pad_width)),
                    mode='constant',
                    constant_values=0
                )
            else:
                # Crop trailing frames.
                mel_spec = mel_spec[:, :target_width]
            return mel_spec
        except Exception as e:
            raise RuntimeError(f"Failed to resize spectrogram: {str(e)}")

    def preprocess_audio(self, audio_input: Union[str, np.ndarray, tuple]) -> torch.Tensor:
        """
        Full pipeline: load -> mel spectrogram -> normalize -> resize.

        Args:
            audio_input: File path, raw array, or (sample_rate, array).

        Returns:
            Float tensor of expected_shape, (1, 1, 128, 431) by default.

        Raises:
            RuntimeError: if any pipeline stage fails.
        """
        try:
            audio_data = self.load_and_normalize_audio(audio_input)
            mel_spec = self.extract_mel_spectrogram(audio_data)
            mel_spec_norm = self.normalize_spectrogram(mel_spec)
            # self.n_frames == 431 with the default parameters.
            mel_spec_resized = self.resize_spectrogram(mel_spec_norm, target_width=self.n_frames)
            # Add batch and channel dims: (H, W) -> (1, 1, H, W).
            tensor_input = torch.as_tensor(mel_spec_resized, dtype=torch.float32)
            tensor_input = tensor_input.unsqueeze(0).unsqueeze(0)
            if tensor_input.shape != self.expected_shape:
                print(f"⚠️ Shape mismatch: got {tensor_input.shape}, expected {self.expected_shape}")
                # Force resize with bilinear interpolation as a last resort.
                tensor_input = torch.nn.functional.interpolate(
                    tensor_input,
                    size=self.expected_shape[2:],
                    mode='bilinear',
                    align_corners=False
                )
            return tensor_input
        except Exception as e:
            raise RuntimeError(f"Preprocessing failed: {str(e)}")

    def get_preprocessing_info(self) -> Dict:
        """Return the preprocessing configuration as a plain dict."""
        return {
            'target_sr': self.target_sr,
            'n_mels': self.n_mels,
            'n_fft': self.n_fft,
            'hop_length': self.hop_length,
            'duration': self.duration,
            'output_shape': self.expected_shape,
            'resampling_method': 'scipy.signal',
            'normalization': 'z-score (mean=0, std=1)',
            'db_scale': True
        }

    def validate_audio_file(self, audio_path: str) -> Tuple[bool, str]:
        """
        Validate an audio file before processing.

        Checks that a path was given and that the clip duration is within
        [0.5, 30] seconds (read via soundfile without decoding the audio).

        Returns:
            (is_valid, message) — never raises.
        """
        if not audio_path:
            return False, "No audio file provided"
        try:
            info = sf.info(audio_path)
            duration = info.duration
            if duration < 0.5:
                return False, f"Audio too short ({duration:.1f}s). Minimum: 0.5s"
            if duration > 30.0:
                return False, f"Audio too long ({duration:.1f}s). Maximum: 30s"
            return True, "Audio file is valid"
        except Exception as e:
            return False, f"Error validating audio: {str(e)}"