"""
Simplified Dataset Loading for Inference Only
Contains only the audio preprocessing functionality needed for inference
"""

import numpy as np
import librosa
import soundfile as sf
from pathlib import Path
import warnings
import logging
from typing import Dict, Tuple, Optional, List, Any

logger = logging.getLogger(__name__)


class DigitDatasetLoader:
    """
    Simplified dataset loader for inference only.
    Contains only the audio preprocessing functionality.
    """

    def __init__(self,
                 sample_rate: int = 8000,
                 max_length: float = 1.0,
                 normalize_audio: bool = True):
        """
        Initialize the dataset loader.

        Args:
            sample_rate: Target sample rate for audio
            max_length: Maximum length in seconds
            normalize_audio: Whether to normalize audio amplitude
        """
        self.sample_rate = sample_rate
        self.max_length = max_length
        # Fixed output length, in samples, of every preprocessed clip.
        self.max_samples = int(sample_rate * max_length)
        self.normalize_audio = normalize_audio

        logger.debug("DatasetLoader initialized: sr=%s, max_len=%ss",
                     sample_rate, max_length)

    def _fit_length(self, audio: np.ndarray) -> np.ndarray:
        """Center-crop or symmetrically zero-pad ``audio`` to ``max_samples``."""
        surplus = len(audio) - self.max_samples

        if surplus > 0:
            # Drop equal amounts from both ends so the middle is preserved.
            start = surplus // 2
            audio = audio[start:start + self.max_samples]
            logger.debug("Truncated audio to %s samples", self.max_samples)
        elif surplus < 0:
            missing = -surplus
            pad_before = missing // 2
            pad_after = missing - pad_before  # takes the odd sample, if any
            audio = np.pad(audio, (pad_before, pad_after), mode='constant')
            logger.debug("Padded audio to %s samples", self.max_samples)

        return audio

    def preprocess_audio(self, audio: np.ndarray, sr: int) -> np.ndarray:
        """
        Preprocess audio for model inference.

        Steps, in order: cast to float32, resample to ``self.sample_rate``,
        downmix to mono, optionally remove the DC offset and peak-normalize,
        trim leading/trailing silence (``top_db=20``), then center-crop or
        zero-pad to exactly ``self.max_samples`` samples.

        This method never raises: any failure is logged and a silent
        (all-zero) clip of the expected length is returned instead.

        NOTE(review): multi-channel input is assumed to be laid out as
        (channels, samples), matching ``get_audio_info``'s use of
        ``shape[0]`` as the channel count -- confirm against callers.

        Args:
            audio: Audio data array
            sr: Original sample rate

        Returns:
            processed_audio: float32 array of length ``self.max_samples``
        """
        try:
            # Also accepts plain sequences; copies only when a cast is needed.
            audio = np.asarray(audio, dtype=np.float32)

            # Resample if needed
            if sr != self.sample_rate:
                audio = librosa.resample(audio, orig_sr=sr, target_sr=self.sample_rate)
                logger.debug("Resampled from %s to %s Hz", sr, self.sample_rate)

            # Ensure mono
            if audio.ndim > 1:
                audio = librosa.to_mono(audio)
                logger.debug("Converted to mono")

            # Normalize amplitude
            if self.normalize_audio:
                # Remove DC offset
                audio = audio - np.mean(audio)

                # Scale the peak to 1; an all-zero signal is left untouched.
                max_val = np.max(np.abs(audio))
                if max_val > 0:
                    audio = audio / max_val

                # Only pay for the min/max scan when it will actually be logged.
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug("Normalized audio: range=[%.3f, %.3f]",
                                 np.min(audio), np.max(audio))

            # Trim silence from beginning and end
            audio, _ = librosa.effects.trim(audio, top_db=20)

            # Pad or truncate to fixed length
            audio = self._fit_length(audio)

            # Final validation. Explicit raises (rather than ``assert``) so the
            # checks still run under ``python -O``; the handler below turns
            # them into the silent fallback.
            if len(audio) != self.max_samples:
                raise ValueError(
                    f"Audio length mismatch: {len(audio)} != {self.max_samples}")
            if audio.dtype != np.float32:
                raise ValueError(
                    f"Audio dtype mismatch: {audio.dtype} != float32")

            logger.debug("Preprocessing complete: shape=%s, dtype=%s",
                         audio.shape, audio.dtype)
            return audio

        except Exception as e:
            # Deliberate best-effort: inference should keep running on bad input.
            logger.error("Audio preprocessing failed: %s", e)
            # Return silence as fallback
            return np.zeros(self.max_samples, dtype=np.float32)

    def validate_audio(self, audio: np.ndarray, sr: int) -> bool:
        """
        Validate audio input.

        The input is rejected when it is empty, when ``sr`` is not positive,
        when it contains NaN/Inf, or when its peak absolute amplitude is
        below 1e-6. Any exception raised while checking is logged and
        reported as invalid.

        Args:
            audio: Audio array
            sr: Sample rate

        Returns:
            is_valid: Whether audio is valid
        """
        try:
            samples = np.asarray(audio)

            # ``size`` (not ``len``) so empty multi-dimensional arrays such as
            # shape (2, 0) are caught here as well.
            if samples.size == 0:
                logger.warning("Empty audio array")
                return False

            if sr <= 0:
                logger.warning(f"Invalid sample rate: {sr}")
                return False

            if not np.all(np.isfinite(samples)):
                logger.warning("Audio contains NaN or Inf values")
                return False

            # Check if audio is not just silence
            if np.max(np.abs(samples)) < 1e-6:
                logger.warning("Audio appears to be silence")
                return False

            return True

        except Exception as e:
            logger.error(f"Audio validation failed: {str(e)}")
            return False

    def get_audio_info(self, audio: np.ndarray, sr: int) -> Dict[str, Any]:
        """
        Get information about audio file.

        Multi-dimensional input is treated as (channels, samples): the first
        axis is reported as the channel count and the last axis as the
        number of samples.

        Args:
            audio: Audio array
            sr: Sample rate

        Returns:
            info: Audio information dictionary with keys ``duration``,
                ``samples``, ``sample_rate``, ``channels``, ``dtype``,
                ``amplitude_range``, ``rms_energy`` and ``zero_crossing_rate``

        Raises:
            ZeroDivisionError: If ``sr`` is 0.
        """
        audio = np.asarray(audio)

        # For multi-channel input ``len(audio)`` would be the channel count,
        # so the sample count is read from the last axis instead.
        num_samples = audio.shape[-1] if audio.ndim > 1 else len(audio)
        duration = num_samples / sr

        # float64 copy: squaring integer PCM in its own dtype would overflow.
        # The widening cast leaves the signs, and so the zero crossings,
        # unchanged.
        as_float = audio.astype(np.float64)

        info = {
            'duration': duration,
            'samples': num_samples,
            'sample_rate': sr,
            'channels': 1 if audio.ndim == 1 else audio.shape[0],
            'dtype': str(audio.dtype),
            'amplitude_range': [float(np.min(audio)), float(np.max(audio))],
            'rms_energy': float(np.sqrt(np.mean(np.square(as_float)))),
            'zero_crossing_rate': float(np.mean(librosa.feature.zero_crossing_rate(as_float)[0]))
        }

        return info