# Pranav Mishra
# Fix missing dependencies by simplifying dataset loader for inference
# 09061df
"""
Simplified Dataset Loading for Inference Only
Contains only the audio preprocessing functionality needed for inference
"""
import numpy as np
import librosa
import soundfile as sf
from pathlib import Path
import warnings
import logging
from typing import Dict, Tuple, Optional, List, Any
logger = logging.getLogger(__name__)
class DigitDatasetLoader:
    """
    Simplified dataset loader for inference only.

    Wraps the audio preprocessing pipeline (float32 cast, mono downmix,
    resample, optional normalization, silence trim, center pad/crop) needed
    to produce a fixed-length model input.
    """

    def __init__(self, sample_rate: int = 8000, max_length: float = 1.0,
                 normalize_audio: bool = True):
        """
        Initialize the dataset loader.

        Args:
            sample_rate: Target sample rate for audio (Hz).
            max_length: Maximum clip length in seconds.
            normalize_audio: Whether to remove DC offset and peak-normalize.
        """
        self.sample_rate = sample_rate
        self.max_length = max_length
        # Fixed model input length in samples.
        self.max_samples = int(sample_rate * max_length)
        self.normalize_audio = normalize_audio
        # Lazy %-args: message is only formatted when DEBUG is enabled.
        logger.debug("DatasetLoader initialized: sr=%s, max_len=%ss",
                     sample_rate, max_length)

    def preprocess_audio(self, audio: np.ndarray, sr: int) -> np.ndarray:
        """
        Preprocess audio for model inference.

        Pipeline: float32 cast -> mono downmix -> resample -> (optional) DC
        removal and peak normalization -> silence trim -> center pad/crop to
        ``self.max_samples``.

        Args:
            audio: Audio data array (1-D mono, or multi-channel — assumed
                channels-first as librosa expects; TODO confirm with callers).
            sr: Original sample rate of ``audio`` (Hz).

        Returns:
            A float32 array of exactly ``self.max_samples`` samples.  On any
            processing error, returns silence (all zeros) of the same fixed
            length as a best-effort fallback.
        """
        try:
            # Convert to float32 if needed.
            if audio.dtype != np.float32:
                audio = audio.astype(np.float32)

            # Downmix first so resampling only processes one channel.
            # (Mean and resampling are both linear, so swapping the order
            # does not change the result beyond float rounding.)
            if audio.ndim > 1:
                audio = librosa.to_mono(audio)
                logger.debug("Converted to mono")

            # Resample if needed.
            if sr != self.sample_rate:
                audio = librosa.resample(audio, orig_sr=sr,
                                         target_sr=self.sample_rate)
                logger.debug("Resampled from %s to %s Hz", sr, self.sample_rate)

            if self.normalize_audio:
                # Remove DC offset, then scale peak amplitude to 1.0.
                audio = audio - np.mean(audio)
                max_val = np.max(np.abs(audio))
                if max_val > 0:
                    audio = audio / max_val
                logger.debug("Normalized audio: range=[%.3f, %.3f]",
                             np.min(audio), np.max(audio))

            # Trim leading/trailing silence (below 20 dB of the peak).
            audio, _ = librosa.effects.trim(audio, top_db=20)

            if len(audio) > self.max_samples:
                # Crop from the center to preserve the most informative part.
                excess = len(audio) - self.max_samples
                start = excess // 2
                audio = audio[start:start + self.max_samples]
                logger.debug("Truncated audio to %s samples", self.max_samples)
            elif len(audio) < self.max_samples:
                # Zero-pad symmetrically to the fixed length.
                padding = self.max_samples - len(audio)
                pad_before = padding // 2
                pad_after = padding - pad_before
                audio = np.pad(audio, (pad_before, pad_after), mode='constant')
                logger.debug("Padded audio to %s samples", self.max_samples)

            # Explicit checks instead of `assert`: asserts are stripped under
            # `python -O`, and an AssertionError here would be silently turned
            # into the silence fallback below, masking real bugs.
            if len(audio) != self.max_samples:
                raise RuntimeError(
                    f"Audio length mismatch: {len(audio)} != {self.max_samples}")
            if audio.dtype != np.float32:
                raise RuntimeError(
                    f"Audio dtype mismatch: {audio.dtype} != float32")

            logger.debug("Preprocessing complete: shape=%s, dtype=%s",
                         audio.shape, audio.dtype)
            return audio
        except Exception as e:
            logger.error("Audio preprocessing failed: %s", e)
            # Best-effort fallback: always hand callers a valid fixed-length
            # model input rather than propagating the error.
            return np.zeros(self.max_samples, dtype=np.float32)

    def validate_audio(self, audio: np.ndarray, sr: int) -> bool:
        """
        Validate audio input before preprocessing.

        Args:
            audio: Audio array.
            sr: Sample rate (Hz); must be positive.

        Returns:
            True if the audio is non-empty, has a positive sample rate,
            contains only finite values, and is not near-silence.
        """
        try:
            if len(audio) == 0:
                logger.warning("Empty audio array")
                return False
            if sr <= 0:
                logger.warning("Invalid sample rate: %s", sr)
                return False
            # Single finiteness pass covers both NaN and +/-Inf.
            if not np.all(np.isfinite(audio)):
                logger.warning("Audio contains NaN or Inf values")
                return False
            # Reject clips whose peak amplitude is effectively zero.
            if np.max(np.abs(audio)) < 1e-6:
                logger.warning("Audio appears to be silence")
                return False
            return True
        except Exception as e:
            # Defensive boundary: any unexpected failure means "invalid".
            logger.error("Audio validation failed: %s", e)
            return False

    def get_audio_info(self, audio: np.ndarray, sr: int) -> Dict[str, Any]:
        """
        Collect descriptive statistics for an audio clip.

        Args:
            audio: Audio array.
            sr: Sample rate (Hz); must be positive — no guard here, call
                ``validate_audio`` first to avoid a ZeroDivisionError.

        Returns:
            Dict with duration, sample/channel counts, dtype, amplitude
            range, RMS energy and mean zero-crossing rate.
        """
        duration = len(audio) / sr
        info = {
            'duration': duration,
            'samples': len(audio),
            'sample_rate': sr,
            # NOTE(review): assumes 2-D audio is (channels, samples) —
            # confirm channel layout against callers.
            'channels': 1 if audio.ndim == 1 else audio.shape[0],
            'dtype': str(audio.dtype),
            'amplitude_range': [float(np.min(audio)), float(np.max(audio))],
            'rms_energy': float(np.sqrt(np.mean(audio ** 2))),
            'zero_crossing_rate': float(
                np.mean(librosa.feature.zero_crossing_rate(audio)[0])),
        }
        return info