Spaces:
Runtime error
Runtime error
| """ | |
| Simplified Dataset Loading for Inference Only | |
| Contains only the audio preprocessing functionality needed for inference | |
| """ | |
| import numpy as np | |
| import librosa | |
| import soundfile as sf | |
| from pathlib import Path | |
| import warnings | |
| import logging | |
| from typing import Dict, Tuple, Optional, List, Any | |
| logger = logging.getLogger(__name__) | |
class DigitDatasetLoader:
    """
    Simplified dataset loader for inference only.

    Contains only the audio preprocessing functionality: resampling,
    down-mixing to mono, amplitude normalization, silence trimming and
    padding/truncation to a fixed number of samples, plus two small
    helpers for validating and describing raw audio.

    Multi-channel arrays are treated as ``(channels, samples)``: that is
    the layout ``librosa.to_mono`` averages over, and the layout implied by
    ``get_audio_info`` reading the channel count from ``shape[0]``.
    NOTE(review): ``soundfile.read`` returns ``(frames, channels)``; callers
    loading with soundfile presumably need to transpose first -- confirm.
    """

    def __init__(self, sample_rate: int = 8000, max_length: float = 1.0,
                 normalize_audio: bool = True):
        """
        Initialize the dataset loader.

        Args:
            sample_rate: Target sample rate for audio
            max_length: Maximum length in seconds
            normalize_audio: Whether to normalize audio amplitude
        """
        self.sample_rate = sample_rate
        self.max_length = max_length
        # Fixed output length (in samples) of preprocess_audio().
        self.max_samples = int(sample_rate * max_length)
        self.normalize_audio = normalize_audio
        logger.debug(f"DatasetLoader initialized: sr={sample_rate}, max_len={max_length}s")

    def preprocess_audio(self, audio: np.ndarray, sr: int) -> np.ndarray:
        """
        Preprocess audio for model inference.

        Steps, in order: cast to float32, resample to ``self.sample_rate``,
        down-mix to mono, optionally remove the DC offset and peak-normalize,
        trim leading/trailing silence, then center-crop or zero-pad
        (symmetrically) to exactly ``self.max_samples`` samples.

        This method never raises: any failure is logged and a buffer of
        silence is returned instead, so the caller always receives an array
        of the expected shape and dtype.

        Args:
            audio: Audio data array (array-likes are coerced with
                ``np.asarray``)
            sr: Original sample rate

        Returns:
            processed_audio: float32 array of length ``self.max_samples``
        """
        try:
            # One call covers both "not an ndarray" and "wrong dtype"; an
            # array that is already float32 is passed through without a copy.
            audio = np.asarray(audio, dtype=np.float32)

            # Resample if needed
            if sr != self.sample_rate:
                audio = librosa.resample(audio, orig_sr=sr, target_sr=self.sample_rate)
                logger.debug(f"Resampled from {sr} to {self.sample_rate} Hz")

            # Ensure mono
            if audio.ndim > 1:
                audio = librosa.to_mono(audio)
                logger.debug("Converted to mono")

            # Normalize amplitude
            if self.normalize_audio:
                # Remove DC offset
                audio = audio - np.mean(audio)

                # Normalize to [-1, 1] range; an all-zero signal is left as is
                # to avoid dividing by zero.
                peak = np.max(np.abs(audio))
                if peak > 0:
                    audio = audio / peak

                # The min/max scans are only worth paying for when the
                # message will actually be emitted.
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug(f"Normalized audio: range=[{np.min(audio):.3f}, {np.max(audio):.3f}]")

            # Trim silence from beginning and end (anything more than 20 dB
            # below the signal's peak counts as silence).
            audio, _ = librosa.effects.trim(audio, top_db=20)

            # Pad or truncate to fixed length
            current = len(audio)
            if current > self.max_samples:
                # Truncate from center to preserve important parts
                start = (current - self.max_samples) // 2
                audio = audio[start:start + self.max_samples]
                logger.debug(f"Truncated audio to {self.max_samples} samples")
            elif current < self.max_samples:
                # Pad with zeros, split as evenly as possible on both sides
                # (the extra sample, if any, goes at the end).
                padding = self.max_samples - current
                pad_before = padding // 2
                pad_after = padding - pad_before
                audio = np.pad(audio, (pad_before, pad_after), mode='constant')
                logger.debug(f"Padded audio to {self.max_samples} samples")

            # Final validation. Explicit raises instead of `assert` so the
            # checks survive `python -O`; both are caught below and turned
            # into the silence fallback, exactly as a failed assert was.
            if len(audio) != self.max_samples:
                raise ValueError(f"Audio length mismatch: {len(audio)} != {self.max_samples}")
            if audio.dtype != np.float32:
                raise TypeError(f"Audio dtype mismatch: {audio.dtype} != float32")

            logger.debug(f"Preprocessing complete: shape={audio.shape}, dtype={audio.dtype}")
            return audio

        except Exception as e:
            # Deliberate best-effort boundary: inference should keep running
            # on a bad clip rather than crash.
            logger.error(f"Audio preprocessing failed: {str(e)}")
            # Return silence as fallback
            return np.zeros(self.max_samples, dtype=np.float32)

    def validate_audio(self, audio: np.ndarray, sr: int) -> bool:
        """
        Validate audio input.

        The audio is rejected when it is empty, when the sample rate is not
        positive, when it contains NaN/Inf values, or when its peak absolute
        amplitude is below 1e-6 (treated as silence). Any exception raised
        while checking is logged and also results in ``False``.

        Args:
            audio: Audio array
            sr: Sample rate

        Returns:
            is_valid: Whether audio is valid
        """
        try:
            # len() alone misses zero-size multi-dimensional arrays such as
            # shape (2, 0); np.size catches those as well.
            if len(audio) == 0 or np.size(audio) == 0:
                logger.warning("Empty audio array")
                return False

            if sr <= 0:
                logger.warning(f"Invalid sample rate: {sr}")
                return False

            # Single pass: isfinite is False for both NaN and +/-Inf.
            if not np.all(np.isfinite(audio)):
                logger.warning("Audio contains NaN or Inf values")
                return False

            # Check if audio is not just silence
            if np.max(np.abs(audio)) < 1e-6:
                logger.warning("Audio appears to be silence")
                return False

            return True

        except Exception as e:
            logger.error(f"Audio validation failed: {str(e)}")
            return False

    def get_audio_info(self, audio: np.ndarray, sr: int) -> Dict[str, Any]:
        """
        Get information about audio file.

        Multi-channel input is read as ``(channels, samples)``: the channel
        count comes from ``shape[0]`` and the sample count (and therefore the
        duration) from the last axis, so the two entries agree with each
        other. For 1-D input the result is the same as counting ``len(audio)``.

        Args:
            audio: Audio array
            sr: Sample rate

        Returns:
            info: Audio information dictionary with keys ``duration``
                (seconds), ``samples``, ``sample_rate``, ``channels``,
                ``dtype``, ``amplitude_range`` (``[min, max]``),
                ``rms_energy`` and ``zero_crossing_rate``.

        Raises:
            ZeroDivisionError: If ``sr`` is 0.
            ValueError: If ``audio`` is empty (min/max of an empty array).
        """
        # Samples live on the last axis; len(audio) would be the channel
        # count for 2-D input.
        n_samples = audio.shape[-1]
        duration = n_samples / sr
        n_channels = 1 if audio.ndim == 1 else audio.shape[0]

        # Squaring integer PCM in its own dtype wraps around (e.g. int16),
        # so only non-float input is upcast; float input is left untouched.
        if np.issubdtype(audio.dtype, np.floating):
            energy_source = audio
        else:
            energy_source = audio.astype(np.float64)
        rms_energy = float(np.sqrt(np.mean(energy_source**2)))

        # Mean of the frame-wise zero-crossing rates; [0] selects the first
        # row of librosa's output.
        zero_crossing_rate = float(np.mean(librosa.feature.zero_crossing_rate(audio)[0]))

        info = {
            'duration': duration,
            'samples': n_samples,
            'sample_rate': sr,
            'channels': n_channels,
            'dtype': str(audio.dtype),
            'amplitude_range': [float(np.min(audio)), float(np.max(audio))],
            'rms_energy': rms_energy,
            'zero_crossing_rate': zero_crossing_rate,
        }
        return info