Pranav Mishra committed on
Commit
09061df
·
1 Parent(s): b5ec43b

Fix missing dependencies by simplifying dataset loader for inference

Browse files

- Replaced full dataset_loader.py with inference-only version
- Removed pandas and datasets dependencies (not needed for inference)
- Kept only audio preprocessing functionality (preprocess_audio method)
- Streamlined requirements.txt to remove heavy dependencies
- Should resolve "ModuleNotFoundError: No module named 'pandas'" error
- Smaller build size and faster deployment

ml_training/data/dataset_loader.py CHANGED
@@ -1,681 +1,167 @@
1
  """
2
- Dataset Loading and Preprocessing for Digit Classification
3
- Handles Free Spoken Digit Dataset (FSDD) loading and preprocessing
4
  """
5
 
6
  import numpy as np
7
- import pandas as pd
8
  import librosa
9
  import soundfile as sf
10
  from pathlib import Path
11
  import warnings
12
  import logging
13
- import subprocess
14
- import tempfile
15
- import os
16
  from typing import Dict, Tuple, Optional, List, Any
17
- from sklearn.model_selection import train_test_split
18
- from sklearn.preprocessing import LabelEncoder
19
- from datasets import load_dataset
20
- import torch
21
 
22
- # Setup logging
23
  logger = logging.getLogger(__name__)
24
 
25
  class DigitDatasetLoader:
26
  """
27
- Comprehensive dataset loader for spoken digit classification.
28
-
29
- Data Flow:
30
- Raw Audio (WAV) -> Preprocessing -> Normalized Arrays
31
- Input: Variable length audio files (8kHz)
32
- Output: Fixed length arrays (8000 samples = 1 second @ 8kHz)
33
  """
34
 
35
- def __init__(self, sample_rate: int = 8000, max_length: int = 8000,
36
- min_length: int = 1000):
37
  """
38
- Initialize dataset loader.
39
 
40
  Args:
41
- sample_rate: Target sampling rate (Hz)
42
- max_length: Maximum audio length in samples
43
- min_length: Minimum audio length in samples (for validation)
44
-
45
- Data Dimensions:
46
- - Raw audio: (batch_size, max_length) = (N, 8000)
47
- - Labels: (batch_size,) = (N,)
48
  """
49
  self.sample_rate = sample_rate
50
  self.max_length = max_length
51
- self.min_length = min_length
52
- self.label_encoder = LabelEncoder()
53
-
54
- # Check ffmpeg availability for better audio processing
55
- self._ffmpeg_available = self._check_ffmpeg_available()
56
- if self._ffmpeg_available:
57
- logger.info("ffmpeg detected - will use for high-quality audio resampling")
58
- else:
59
- logger.info("ffmpeg not available - using librosa for resampling")
60
-
61
- logger.info(f"Initialized DataLoader - SR: {sample_rate}Hz, Max Length: {max_length} samples")
62
-
63
- def load_fsdd_dataset(self) -> Optional[Any]:
64
- """
65
- Load Free Spoken Digit Dataset from HuggingFace.
66
-
67
- Returns:
68
- dataset: HuggingFace dataset object or None if failed
69
-
70
- Data Structure:
71
- - 'audio': {'array': np.ndarray, 'sampling_rate': int}
72
- - 'label': int (0-9)
73
- - Total samples: ~3000
74
- - Speakers: 6 different
75
- """
76
- try:
77
- logger.info("Loading Free Spoken Digit Dataset from HuggingFace...")
78
-
79
- # Load the correct HuggingFace dataset
80
- dataset = load_dataset("mteb/free-spoken-digit-dataset", trust_remote_code=True)
81
-
82
- logger.info(f"Dataset loaded successfully")
83
- logger.info(f"Available splits: {list(dataset.keys())}")
84
-
85
- # Check which split to use
86
- if 'train' in dataset:
87
- split_size = len(dataset['train'])
88
- logger.info(f"Train split size: {split_size}")
89
- elif 'test' in dataset:
90
- split_size = len(dataset['test'])
91
- logger.info(f"Test split size: {split_size}")
92
- else:
93
- # Use first available split
94
- first_split = list(dataset.keys())[0]
95
- split_size = len(dataset[first_split])
96
- logger.info(f"Using '{first_split}' split with {split_size} samples")
97
-
98
- # Validate dataset structure
99
- sample = dataset[list(dataset.keys())[0]][0]
100
- logger.info(f"Dataset sample structure: {sample.keys()}")
101
-
102
- if 'audio' in sample:
103
- audio_info = sample['audio']
104
- logger.info(f"Audio info: sampling_rate={audio_info.get('sampling_rate', 'N/A')}, "
105
- f"array_shape={audio_info['array'].shape if 'array' in audio_info else 'N/A'}")
106
-
107
- if 'label' in sample:
108
- logger.info(f"Label type: {type(sample['label'])}, value: {sample['label']}")
109
-
110
- return dataset
111
-
112
- except Exception as e:
113
- logger.error(f"Error loading dataset: {str(e)}")
114
- logger.info("Attempting fallback dataset loading...")
115
-
116
- try:
117
- # Fallback to manual loading
118
- return self._load_fsdd_fallback()
119
- except Exception as fallback_error:
120
- logger.error(f"Fallback loading failed: {str(fallback_error)}")
121
- return None
122
-
123
- def _load_fsdd_fallback(self):
124
- """
125
- Fallback dataset loading method with multiple strategies.
126
-
127
- Returns:
128
- Synthetic dataset or None if all methods fail
129
- """
130
- logger.warning("Using fallback dataset loading - attempting alternative methods")
131
-
132
- # Strategy 1: Try alternative HuggingFace dataset names
133
- alternative_datasets = [
134
- "free-spoken-digit-dataset",
135
- "speech_commands", # Has similar digit data
136
- ]
137
-
138
- for alt_dataset in alternative_datasets:
139
- try:
140
- logger.info(f"Trying alternative dataset: {alt_dataset}")
141
- dataset = load_dataset(alt_dataset, trust_remote_code=True)
142
- logger.info(f"Successfully loaded alternative dataset: {alt_dataset}")
143
- return dataset
144
- except Exception as e:
145
- logger.debug(f"Alternative dataset {alt_dataset} failed: {e}")
146
- continue
147
-
148
- # Strategy 2: Create synthetic dataset for development/testing
149
- logger.info("Creating synthetic dataset for development/testing")
150
- return self._create_synthetic_dataset()
151
-
152
- def _create_synthetic_dataset(self):
153
- """
154
- Create a synthetic dataset with digit-like audio patterns.
155
-
156
- Returns:
157
- Synthetic dataset in HuggingFace format
158
- """
159
- logger.info("Generating synthetic digit dataset...")
160
-
161
- num_samples_per_digit = 50 # 50 samples per digit
162
- num_digits = 10
163
-
164
- synthetic_data = []
165
-
166
- for digit in range(num_digits):
167
- for sample_idx in range(num_samples_per_digit):
168
- # Generate synthetic audio with digit-specific characteristics
169
- audio_array = self._generate_synthetic_audio(digit)
170
-
171
- # Create sample in HuggingFace format
172
- sample = {
173
- 'audio': {
174
- 'array': audio_array.astype(np.float32),
175
- 'sampling_rate': self.sample_rate
176
- },
177
- 'label': digit
178
- }
179
-
180
- synthetic_data.append(sample)
181
-
182
- logger.info(f"Created synthetic dataset with {len(synthetic_data)} samples")
183
-
184
- # Return in HuggingFace dataset-like format
185
- return {'train': synthetic_data}
186
-
187
- def _generate_synthetic_audio(self, digit: int) -> np.ndarray:
188
- """
189
- Generate synthetic audio for a specific digit.
190
-
191
- Args:
192
- digit: Digit (0-9) to generate audio for
193
-
194
- Returns:
195
- Synthetic audio array
196
- """
197
- duration = 1.0 # 1 second
198
- t = np.linspace(0, duration, self.max_length)
199
-
200
- # Create digit-specific frequency patterns
201
- base_freq = 200 + digit * 50 # Different base frequency for each digit
202
-
203
- # Generate harmonic series
204
- audio = np.zeros_like(t)
205
-
206
- # Add fundamental frequency
207
- audio += 0.3 * np.sin(2 * np.pi * base_freq * t)
208
-
209
- # Add harmonics
210
- audio += 0.2 * np.sin(2 * np.pi * base_freq * 2 * t)
211
- audio += 0.1 * np.sin(2 * np.pi * base_freq * 3 * t)
212
-
213
- # Add some noise for realism
214
- noise = np.random.normal(0, 0.05, len(t))
215
- audio += noise
216
-
217
- # Apply envelope (attack-decay-sustain-release)
218
- envelope = np.ones_like(t)
219
-
220
- # Attack (first 10%)
221
- attack_samples = int(0.1 * len(t))
222
- envelope[:attack_samples] = np.linspace(0, 1, attack_samples)
223
-
224
- # Decay (last 20%)
225
- decay_samples = int(0.2 * len(t))
226
- envelope[-decay_samples:] = np.linspace(1, 0, decay_samples)
227
-
228
- audio = audio * envelope
229
-
230
- # Normalize
231
- if np.max(np.abs(audio)) > 0:
232
- audio = audio / np.max(np.abs(audio)) * 0.8
233
-
234
- return audio
235
-
236
- def _check_ffmpeg_available(self) -> bool:
237
- """Check if ffmpeg is available on the system."""
238
- try:
239
- result = subprocess.run(['ffmpeg', '-version'],
240
- capture_output=True,
241
- text=True,
242
- timeout=5)
243
- return result.returncode == 0
244
- except (subprocess.SubprocessError, FileNotFoundError, subprocess.TimeoutExpired):
245
- return False
246
-
247
- def _convert_audio_with_ffmpeg(self, audio_array: np.ndarray, original_sr: int, target_sr: int) -> np.ndarray:
248
- """
249
- Convert audio using ffmpeg for better quality resampling.
250
 
251
- Args:
252
- audio_array: Input audio array
253
- original_sr: Original sampling rate
254
- target_sr: Target sampling rate
255
-
256
- Returns:
257
- Resampled audio array
258
- """
259
- try:
260
- # Create temporary files
261
- with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_input:
262
- with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_output:
263
-
264
- # Write input audio
265
- sf.write(temp_input.name, audio_array, original_sr)
266
-
267
- # Run ffmpeg conversion
268
- ffmpeg_cmd = [
269
- 'ffmpeg',
270
- '-i', temp_input.name,
271
- '-ar', str(target_sr),
272
- '-ac', '1', # Mono
273
- '-acodec', 'pcm_f32le', # 32-bit float
274
- '-y', # Overwrite output
275
- temp_output.name
276
- ]
277
-
278
- result = subprocess.run(ffmpeg_cmd,
279
- capture_output=True,
280
- text=True,
281
- timeout=30)
282
-
283
- if result.returncode == 0:
284
- # Read converted audio
285
- converted_audio, _ = sf.read(temp_output.name, dtype='float32')
286
- logger.debug(f"ffmpeg conversion successful: {original_sr}Hz -> {target_sr}Hz")
287
- return converted_audio
288
- else:
289
- logger.warning(f"ffmpeg conversion failed: {result.stderr}")
290
- return None
291
-
292
- except Exception as e:
293
- logger.warning(f"ffmpeg conversion error: {e}")
294
- return None
295
- finally:
296
- # Clean up temporary files
297
- try:
298
- if 'temp_input' in locals():
299
- os.unlink(temp_input.name)
300
- if 'temp_output' in locals():
301
- os.unlink(temp_output.name)
302
- except:
303
- pass
304
-
305
- return None
306
 
307
- def preprocess_audio(self, audio_array: np.ndarray, sr: int) -> np.ndarray:
308
  """
309
- Standardize audio preprocessing across all pipelines.
310
 
311
  Args:
312
- audio_array: Input audio signal
313
- sr: Original sampling rate
314
 
315
  Returns:
316
  processed_audio: Preprocessed audio array
317
-
318
- Data Flow:
319
- Input: (variable_length,) -> Output: (max_length,) = (8000,)
320
-
321
- Processing Steps:
322
- 1. Resample to target SR if needed
323
- 2. Amplitude normalization
324
- 3. Pad or truncate to fixed length
325
- 4. Validate output dimensions
326
  """
327
  try:
328
- logger.debug(f"Preprocessing audio - Original shape: {audio_array.shape}, SR: {sr}")
329
-
330
- # Input validation
331
- if len(audio_array) < self.min_length:
332
- logger.warning(f"Audio too short ({len(audio_array)} < {self.min_length}), padding")
333
 
334
- # Step 1: Resample if necessary
335
  if sr != self.sample_rate:
336
- logger.debug(f"Resampling from {sr}Hz to {self.sample_rate}Hz")
 
 
 
 
 
 
 
 
 
 
 
337
 
338
- # Try ffmpeg first for better quality, then fall back to librosa
339
- if hasattr(self, '_ffmpeg_available') and self._ffmpeg_available:
340
- ffmpeg_result = self._convert_audio_with_ffmpeg(audio_array, sr, self.sample_rate)
341
- if ffmpeg_result is not None:
342
- audio_array = ffmpeg_result
343
- logger.debug("Used ffmpeg for resampling")
344
- else:
345
- # Fall back to librosa
346
- audio_array = librosa.resample(
347
- audio_array,
348
- orig_sr=sr,
349
- target_sr=self.sample_rate
350
- )
351
- logger.debug("Fell back to librosa for resampling")
352
- else:
353
- # Use librosa
354
- audio_array = librosa.resample(
355
- audio_array,
356
- orig_sr=sr,
357
- target_sr=self.sample_rate
358
- )
359
-
360
- # Step 2: Amplitude normalization
361
- max_amplitude = np.max(np.abs(audio_array))
362
- if max_amplitude > 1e-8:
363
- audio_array = audio_array / max_amplitude
364
- else:
365
- logger.warning("Audio signal has very low amplitude")
366
-
367
- # Step 3: Pad or truncate to fixed length
368
- if len(audio_array) > self.max_length:
369
- # Truncate from center to preserve important content
370
- start = (len(audio_array) - self.max_length) // 2
371
- audio_array = audio_array[start:start + self.max_length]
372
- logger.debug(f"Truncated audio from center")
373
- else:
374
  # Pad with zeros
375
- pad_length = self.max_length - len(audio_array)
376
- audio_array = np.pad(audio_array, (0, pad_length), mode='constant')
377
- logger.debug(f"Padded audio with {pad_length} zeros")
 
 
378
 
379
- # Step 4: Validate output
380
- assert len(audio_array) == self.max_length, f"Expected length {self.max_length}, got {len(audio_array)}"
381
- assert not np.isnan(audio_array).any(), "NaN values found in processed audio"
382
- assert not np.isinf(audio_array).any(), "Infinite values found in processed audio"
383
 
384
- logger.debug(f"Preprocessing complete - Output shape: {audio_array.shape}")
385
- return audio_array.astype(np.float32)
 
386
 
387
  except Exception as e:
388
  logger.error(f"Audio preprocessing failed: {str(e)}")
389
- # Return zeros as fallback
390
- return np.zeros(self.max_length, dtype=np.float32)
391
 
392
- def create_train_test_split(self, dataset, test_size: float = 0.2,
393
- val_size: float = 0.1, random_state: int = 42) -> Dict[str, Any]:
394
  """
395
- Create stratified train/test/validation splits.
396
 
397
  Args:
398
- dataset: HuggingFace dataset
399
- test_size: Fraction for test set
400
- val_size: Fraction for validation set
401
- random_state: Random seed
402
 
403
  Returns:
404
- Dictionary containing splits and metadata
405
-
406
- Data Structure:
407
- Input: HuggingFace dataset
408
- Output: {
409
- 'X_train': (n_train, max_length) = (N*0.7, 8000)
410
- 'X_val': (n_val, max_length) = (N*0.1, 8000)
411
- 'X_test': (n_test, max_length) = (N*0.2, 8000)
412
- 'y_train': (n_train,)
413
- 'y_val': (n_val,)
414
- 'y_test': (n_test,)
415
- 'label_encoder': LabelEncoder object
416
- 'dataset_info': metadata dictionary
417
- }
418
  """
419
  try:
420
- logger.info("Creating train/test/validation splits...")
421
-
422
- if dataset is None:
423
- raise ValueError("Dataset is None - cannot create splits")
424
-
425
- # Extract features and labels
426
- audio_data = []
427
- labels = []
428
- sample_rates = []
429
-
430
- # Use train split or first available split
431
- split_name = 'train' if 'train' in dataset else list(dataset.keys())[0]
432
- data_split = dataset[split_name]
433
-
434
- logger.info(f"Processing {len(data_split)} samples from '{split_name}' split")
435
-
436
- for idx, item in enumerate(data_split):
437
- try:
438
- # Handle different data formats
439
- if isinstance(item, dict):
440
- # HuggingFace dataset format
441
- if 'audio' in item and isinstance(item['audio'], dict):
442
- audio = item['audio']['array']
443
- sr = item['audio']['sampling_rate']
444
- elif 'audio' in item and isinstance(item['audio'], np.ndarray):
445
- # Alternative format where audio is directly an array
446
- audio = item['audio']
447
- sr = item.get('sampling_rate', self.sample_rate)
448
- else:
449
- logger.warning(f"Unexpected audio format in sample {idx}: {item.keys()}")
450
- continue
451
-
452
- # Extract label
453
- label = item.get('label', self._extract_label_from_filename(item))
454
-
455
- # Convert label to int if it's a string
456
- if isinstance(label, str) and label.isdigit():
457
- label = int(label)
458
- elif not isinstance(label, (int, np.integer)):
459
- logger.warning(f"Invalid label format in sample {idx}: {label}")
460
- continue
461
-
462
- else:
463
- logger.warning(f"Unexpected item format in sample {idx}: {type(item)}")
464
- continue
465
-
466
- # Validate audio data
467
- if not isinstance(audio, (np.ndarray, list)):
468
- logger.warning(f"Audio data is not array-like in sample {idx}")
469
- continue
470
-
471
- # Convert to numpy array if needed
472
- if not isinstance(audio, np.ndarray):
473
- audio = np.array(audio, dtype=np.float32)
474
-
475
- # Validate audio shape and content
476
- if audio.size == 0:
477
- logger.warning(f"Empty audio array in sample {idx}")
478
- continue
479
-
480
- if np.all(audio == 0):
481
- logger.warning(f"Silent audio in sample {idx}")
482
- continue
483
-
484
- # Preprocess audio
485
- processed_audio = self.preprocess_audio(audio, sr)
486
-
487
- # Validate preprocessing result
488
- if processed_audio is None or len(processed_audio) != self.max_length:
489
- logger.warning(f"Preprocessing failed for sample {idx}")
490
- continue
491
-
492
- audio_data.append(processed_audio)
493
- labels.append(label)
494
- sample_rates.append(sr)
495
-
496
- if (idx + 1) % 100 == 0:
497
- logger.info(f"Processed {idx + 1}/{len(data_split)} samples")
498
-
499
- except Exception as e:
500
- logger.warning(f"Failed to process sample {idx}: {str(e)}")
501
- logger.debug(f"Sample {idx} content: {str(item)[:200]}...")
502
- continue
503
-
504
- if len(audio_data) == 0:
505
- raise ValueError("No valid audio samples found")
506
-
507
- # Convert to numpy arrays
508
- audio_data = np.array(audio_data, dtype=np.float32)
509
- labels = np.array(labels)
510
-
511
- logger.info(f"Audio data shape: {audio_data.shape}")
512
- logger.info(f"Labels shape: {labels.shape}")
513
- logger.info(f"Unique labels: {np.unique(labels)}")
514
-
515
- # Encode labels
516
- labels_encoded = self.label_encoder.fit_transform(labels)
517
- logger.info(f"Label encoding: {dict(zip(self.label_encoder.classes_, range(len(self.label_encoder.classes_))))}")
518
-
519
- # Create stratified splits
520
- logger.info(f"Creating splits - Test: {test_size:.1%}, Val: {val_size:.1%}")
521
-
522
- # First split: train+val and test
523
- X_temp, X_test, y_temp, y_test = train_test_split(
524
- audio_data, labels_encoded,
525
- test_size=test_size,
526
- stratify=labels_encoded,
527
- random_state=random_state
528
- )
529
 
530
- # Second split: train and validation
531
- if val_size > 0:
532
- val_size_adjusted = val_size / (1 - test_size)
533
- X_train, X_val, y_train, y_val = train_test_split(
534
- X_temp, y_temp,
535
- test_size=val_size_adjusted,
536
- stratify=y_temp,
537
- random_state=random_state
538
- )
539
- else:
540
- X_train, X_val, y_train, y_val = X_temp, None, y_temp, None
541
 
542
- # Create dataset info
543
- dataset_info = {
544
- 'total_samples': len(audio_data),
545
- 'train_samples': len(X_train),
546
- 'val_samples': len(X_val) if X_val is not None else 0,
547
- 'test_samples': len(X_test),
548
- 'sample_rate': self.sample_rate,
549
- 'max_length': self.max_length,
550
- 'num_classes': len(self.label_encoder.classes_),
551
- 'class_names': self.label_encoder.classes_.tolist(),
552
- 'audio_shape': audio_data.shape,
553
- 'mean_sample_rate': np.mean(sample_rates),
554
- 'std_sample_rate': np.std(sample_rates)
555
- }
556
 
557
- logger.info(f"Dataset splits created successfully:")
558
- logger.info(f" Train: {dataset_info['train_samples']} samples")
559
- logger.info(f" Val: {dataset_info['val_samples']} samples")
560
- logger.info(f" Test: {dataset_info['test_samples']} samples")
561
- logger.info(f" Classes: {dataset_info['num_classes']} ({dataset_info['class_names']})")
562
 
563
- return {
564
- 'X_train': X_train, 'y_train': y_train,
565
- 'X_val': X_val, 'y_val': y_val,
566
- 'X_test': X_test, 'y_test': y_test,
567
- 'label_encoder': self.label_encoder,
568
- 'dataset_info': dataset_info
569
- }
570
 
571
  except Exception as e:
572
- logger.error(f"Failed to create dataset splits: {str(e)}")
573
- raise
574
-
575
- def _extract_label_from_filename(self, item: Dict) -> int:
576
- """Extract digit label from filename if not in metadata."""
577
- try:
578
- # Attempt to extract from filename or path
579
- filename = item.get('path', item.get('filename', ''))
580
- # Look for digit in filename (0-9)
581
- for digit in range(10):
582
- if str(digit) in filename:
583
- return digit
584
- return 0 # Default fallback
585
- except:
586
- return 0
587
 
588
- def validate_splits(self, data_splits: Dict[str, Any]) -> bool:
589
  """
590
- Validate dataset splits for correctness.
591
 
592
  Args:
593
- data_splits: Dictionary containing dataset splits
 
594
 
595
  Returns:
596
- bool: True if validation passes
597
- """
598
- try:
599
- logger.info("Validating dataset splits...")
600
-
601
- required_keys = ['X_train', 'y_train', 'X_test', 'y_test', 'label_encoder', 'dataset_info']
602
- for key in required_keys:
603
- if key not in data_splits:
604
- logger.error(f"Missing required key: {key}")
605
- return False
606
-
607
- # Check shapes
608
- X_train, y_train = data_splits['X_train'], data_splits['y_train']
609
- X_test, y_test = data_splits['X_test'], data_splits['y_test']
610
-
611
- assert X_train.shape[0] == y_train.shape[0], "Train X/y shape mismatch"
612
- assert X_test.shape[0] == y_test.shape[0], "Test X/y shape mismatch"
613
- assert X_train.shape[1] == self.max_length, f"Audio length mismatch: {X_train.shape[1]} != {self.max_length}"
614
-
615
- if data_splits['X_val'] is not None:
616
- X_val, y_val = data_splits['X_val'], data_splits['y_val']
617
- assert X_val.shape[0] == y_val.shape[0], "Val X/y shape mismatch"
618
- assert X_val.shape[1] == self.max_length, f"Val audio length mismatch"
619
-
620
- # Check label distribution
621
- unique_labels = np.unique(np.concatenate([y_train, y_test]))
622
- expected_labels = np.arange(len(data_splits['label_encoder'].classes_))
623
- assert np.array_equal(unique_labels, expected_labels), "Label distribution mismatch"
624
-
625
- logger.info("Dataset validation passed")
626
- return True
627
-
628
- except Exception as e:
629
- logger.error(f"Dataset validation failed: {str(e)}")
630
- return False
631
-
632
- def load_and_prepare_data(sample_rate: int = 8000, max_length: int = 8000,
633
- test_size: float = 0.2, val_size: float = 0.1) -> Dict[str, Any]:
634
- """
635
- Convenience function to load and prepare dataset.
636
-
637
- Returns:
638
- Dictionary containing prepared dataset splits
639
- """
640
- logger.info("Loading and preparing digit dataset...")
641
-
642
- try:
643
- # Initialize loader
644
- data_loader = DigitDatasetLoader(sample_rate=sample_rate, max_length=max_length)
645
-
646
- # Load dataset
647
- dataset = data_loader.load_fsdd_dataset()
648
-
649
- if dataset is None:
650
- logger.error("Failed to load dataset")
651
- return None
652
-
653
- # Create splits
654
- data_splits = data_loader.create_train_test_split(
655
- dataset, test_size=test_size, val_size=val_size
656
- )
657
-
658
- # Validate
659
- if not data_loader.validate_splits(data_splits):
660
- logger.error("Dataset validation failed")
661
- return None
662
-
663
- logger.info("Dataset preparation completed successfully")
664
- return data_splits
665
-
666
- except Exception as e:
667
- logger.error(f"Dataset preparation failed: {str(e)}")
668
- return None
669
-
670
- if __name__ == "__main__":
671
- # Test dataset loading
672
- logging.basicConfig(level=logging.INFO)
673
-
674
- logger.info("Testing dataset loader...")
675
- data_splits = load_and_prepare_data()
676
-
677
- if data_splits:
678
- info = data_splits['dataset_info']
679
- logger.info(f"Dataset loaded successfully: {info}")
680
- else:
681
- logger.error("Dataset loading failed")
 
1
  """
2
+ Simplified Dataset Loading for Inference Only
3
+ Contains only the audio preprocessing functionality needed for inference
4
  """
5
 
6
  import numpy as np
 
7
  import librosa
8
  import soundfile as sf
9
  from pathlib import Path
10
  import warnings
11
  import logging
 
 
 
12
  from typing import Dict, Tuple, Optional, List, Any
 
 
 
 
13
 
 
14
  logger = logging.getLogger(__name__)
15
 
16
  class DigitDatasetLoader:
17
  """
18
+ Simplified dataset loader for inference only.
19
+ Contains only the audio preprocessing functionality.
 
 
 
 
20
  """
21
 
22
+ def __init__(self, sample_rate: int = 8000, max_length: float = 1.0,
23
+ normalize_audio: bool = True):
24
  """
25
+ Initialize the dataset loader.
26
 
27
  Args:
28
+ sample_rate: Target sample rate for audio
29
+ max_length: Maximum length in seconds
30
+ normalize_audio: Whether to normalize audio amplitude
 
 
 
 
31
  """
32
  self.sample_rate = sample_rate
33
  self.max_length = max_length
34
+ self.max_samples = int(sample_rate * max_length)
35
+ self.normalize_audio = normalize_audio
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
+ logger.debug(f"DatasetLoader initialized: sr={sample_rate}, max_len={max_length}s")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
+ def preprocess_audio(self, audio: np.ndarray, sr: int) -> np.ndarray:
40
  """
41
+ Preprocess audio for model inference.
42
 
43
  Args:
44
+ audio: Audio data array
45
+ sr: Original sample rate
46
 
47
  Returns:
48
  processed_audio: Preprocessed audio array
 
 
 
 
 
 
 
 
 
49
  """
50
  try:
51
+ # Convert to float32 if needed
52
+ if audio.dtype != np.float32:
53
+ audio = audio.astype(np.float32)
 
 
54
 
55
+ # Resample if needed
56
  if sr != self.sample_rate:
57
+ audio = librosa.resample(audio, orig_sr=sr, target_sr=self.sample_rate)
58
+ logger.debug(f"Resampled from {sr} to {self.sample_rate} Hz")
59
+
60
+ # Ensure mono
61
+ if len(audio.shape) > 1:
62
+ audio = librosa.to_mono(audio)
63
+ logger.debug("Converted to mono")
64
+
65
+ # Normalize amplitude
66
+ if self.normalize_audio:
67
+ # Remove DC offset
68
+ audio = audio - np.mean(audio)
69
 
70
+ # Normalize to [-1, 1] range
71
+ max_val = np.max(np.abs(audio))
72
+ if max_val > 0:
73
+ audio = audio / max_val
74
+
75
+ logger.debug(f"Normalized audio: range=[{np.min(audio):.3f}, {np.max(audio):.3f}]")
76
+
77
+ # Trim silence from beginning and end
78
+ audio, _ = librosa.effects.trim(audio, top_db=20)
79
+
80
+ # Pad or truncate to fixed length
81
+ if len(audio) > self.max_samples:
82
+ # Truncate from center to preserve important parts
83
+ excess = len(audio) - self.max_samples
84
+ start = excess // 2
85
+ audio = audio[start:start + self.max_samples]
86
+ logger.debug(f"Truncated audio to {self.max_samples} samples")
87
+ elif len(audio) < self.max_samples:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  # Pad with zeros
89
+ padding = self.max_samples - len(audio)
90
+ pad_before = padding // 2
91
+ pad_after = padding - pad_before
92
+ audio = np.pad(audio, (pad_before, pad_after), mode='constant')
93
+ logger.debug(f"Padded audio to {self.max_samples} samples")
94
 
95
+ # Final validation
96
+ assert len(audio) == self.max_samples, f"Audio length mismatch: {len(audio)} != {self.max_samples}"
97
+ assert audio.dtype == np.float32, f"Audio dtype mismatch: {audio.dtype} != float32"
 
98
 
99
+ logger.debug(f"Preprocessing complete: shape={audio.shape}, dtype={audio.dtype}")
100
+
101
+ return audio
102
 
103
  except Exception as e:
104
  logger.error(f"Audio preprocessing failed: {str(e)}")
105
+ # Return silence as fallback
106
+ return np.zeros(self.max_samples, dtype=np.float32)
107
 
108
+ def validate_audio(self, audio: np.ndarray, sr: int) -> bool:
 
109
  """
110
+ Validate audio input.
111
 
112
  Args:
113
+ audio: Audio array
114
+ sr: Sample rate
 
 
115
 
116
  Returns:
117
+ is_valid: Whether audio is valid
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  """
119
  try:
120
+ if len(audio) == 0:
121
+ logger.warning("Empty audio array")
122
+ return False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
+ if sr <= 0:
125
+ logger.warning(f"Invalid sample rate: {sr}")
126
+ return False
 
 
 
 
 
 
 
 
127
 
128
+ if np.any(np.isnan(audio)) or np.any(np.isinf(audio)):
129
+ logger.warning("Audio contains NaN or Inf values")
130
+ return False
 
 
 
 
 
 
 
 
 
 
 
131
 
132
+ # Check if audio is not just silence
133
+ if np.max(np.abs(audio)) < 1e-6:
134
+ logger.warning("Audio appears to be silence")
135
+ return False
 
136
 
137
+ return True
 
 
 
 
 
 
138
 
139
  except Exception as e:
140
+ logger.error(f"Audio validation failed: {str(e)}")
141
+ return False
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
143
+ def get_audio_info(self, audio: np.ndarray, sr: int) -> Dict[str, Any]:
144
  """
145
+ Get information about audio file.
146
 
147
  Args:
148
+ audio: Audio array
149
+ sr: Sample rate
150
 
151
  Returns:
152
+ info: Audio information dictionary
153
+ """
154
+ duration = len(audio) / sr
155
+
156
+ info = {
157
+ 'duration': duration,
158
+ 'samples': len(audio),
159
+ 'sample_rate': sr,
160
+ 'channels': 1 if len(audio.shape) == 1 else audio.shape[0],
161
+ 'dtype': str(audio.dtype),
162
+ 'amplitude_range': [float(np.min(audio)), float(np.max(audio))],
163
+ 'rms_energy': float(np.sqrt(np.mean(audio**2))),
164
+ 'zero_crossing_rate': float(np.mean(librosa.feature.zero_crossing_rate(audio)[0]))
165
+ }
166
+
167
+ return info
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements_hf.txt CHANGED
@@ -1,4 +1,4 @@
1
- # HF Spaces Requirements - Python 3.10 compatible versions
2
  # Core Flask API
3
  Flask==2.3.3
4
  Flask-CORS==4.0.0
@@ -15,7 +15,7 @@ soundfile==0.12.1
15
  torch==2.0.1
16
  torchaudio==2.0.2
17
 
18
- # Essential ML utilities
19
  scikit-learn==1.3.0
20
 
21
  # Logging and utilities
 
1
+ # HF Spaces Requirements - Python 3.10 compatible versions (streamlined)
2
  # Core Flask API
3
  Flask==2.3.3
4
  Flask-CORS==4.0.0
 
15
  torch==2.0.1
16
  torchaudio==2.0.2
17
 
18
+ # Essential ML utilities (no pandas/datasets needed for inference)
19
  scikit-learn==1.3.0
20
 
21
  # Logging and utilities