import os
from typing import Callable, Optional

from torch.utils.data import Dataset

from datasets import register
from models.ldm.dac.audiotools import AudioSignal
from models.ldm.dac.audiotools.core import util

# Audio file extensions (from audiotools); matching below is case-insensitive
AUDIO_EXTS = ('.wav', '.flac', '.mp3', '.mp4', '.m4a')
class AudioFolder(Dataset):
    """
    Audio dataset that loads audio files from a folder structure.
    Similar to ClassFolder, but for audio files.

    Expected folder structure:

        root_path/
        ├── class1/
        │   ├── audio1.wav
        │   ├── audio2.wav
        │   └── ...
        ├── class2/
        │   ├── audio1.wav
        │   └── ...
        └── ...

    Or, for a single class (no subfolders):

        root_path/
        ├── audio1.wav
        ├── audio2.wav
        └── ...
    """
    def __init__(
        self,
        root_path: str,
        sample_rate: int = 24000,
        duration: float = 2.0,
        num_channels: int = 1,
        random_crop: bool = True,
        loudness_cutoff: float = -40,
        audio_only: bool = False,
        drop_label_p: float = 0.0,
        shuffle: bool = True,
        shuffle_state: int = 0,
        transform: Optional[Callable] = None,
        normalize: bool = True,
        trim_silence: bool = False,
    ):
| """ | |
| Args: | |
| root_path: Path to audio files | |
| sample_rate: Target sample rate for audio | |
| duration: Duration in seconds for audio clips | |
| num_channels: Number of channels (1 for mono, 2 for stereo) | |
| random_crop: Whether to randomly crop audio (vs deterministic) | |
| loudness_cutoff: Minimum loudness threshold for audio selection | |
| audio_only: If True, return only audio signal. If False, return dict with labels | |
| drop_label_p: Probability of dropping labels (for unconditional training) | |
| shuffle: Whether to shuffle files | |
| shuffle_state: Random state for shuffling | |
| transform: Additional audio transforms | |
| normalize: Whether to normalize audio amplitude | |
| trim_silence: Whether to trim silence from audio | |
| """ | |
        self.root_path = root_path
        self.sample_rate = sample_rate
        self.duration = duration
        self.num_channels = num_channels
        self.random_crop = random_crop
        self.loudness_cutoff = loudness_cutoff
        self.audio_only = audio_only
        self.drop_label_p = drop_label_p
        self.transform = transform
        self.normalize = normalize
        self.trim_silence = trim_silence

        print(f'Audio root_path: {root_path}')

        # Find all audio files recursively under root_path
        self.files = []
        for root, _, files in os.walk(self.root_path):
            for file in files:
                if file.lower().endswith(AUDIO_EXTS):
                    self.files.append(os.path.join(root, file))
        print(f'Found {len(self.files)} audio files')

        # Shuffle the file list in place if requested (seeded for reproducibility)
        if shuffle:
            state = util.random_state(shuffle_state)
            state.shuffle(self.files)
    def __len__(self):
        return len(self.files)
    def __getitem__(self, idx):
        try:
            file_path = self.files[idx]

            # Load audio using AudioSignal
            if self.random_crop:
                # Use a salient excerpt: a random crop that passes the loudness filter
                signal = AudioSignal.salient_excerpt(
                    str(file_path),
                    duration=self.duration,
                    loudness_cutoff=self.loudness_cutoff,
                )
            else:
                # Deterministic load from the beginning of the file
                signal = AudioSignal(
                    str(file_path),
                    duration=self.duration,
                    offset=0.0,
                )
            # Record the source sample rate before resampling, for metadata
            original_sr = signal.sample_rate

            # Convert to mono if requested
            if self.num_channels == 1:
                signal = signal.to_mono()

            # Resample to the target sample rate
            signal = signal.resample(self.sample_rate)

            # Enforce the target duration by padding or trimming
            target_samples = int(self.duration * self.sample_rate)
            if signal.length < target_samples:
                signal = signal.zero_pad_to(target_samples)
            elif signal.length > target_samples:
                signal = signal.truncate_samples(target_samples)
            # Optional audio processing
            if self.trim_silence:
                signal = signal.trim_silence()
                # Re-pad if trimming made the clip too short
                if signal.length < target_samples:
                    signal = signal.zero_pad_to(target_samples)

            if self.normalize:
                signal = signal.normalize()

            # Clamp audio to the [-1, 1] range
            signal.audio_data = signal.audio_data.clamp(-1.0, 1.0)
            # Apply additional transforms if provided
            if self.transform is not None:
                # Seed the transform's randomness with the item index
                state = util.random_state(idx)
                transform_args = self.transform.instantiate(state, signal=signal)
                signal = self.transform(signal, **transform_args)
            # Store metadata (original_sr was captured before resampling above,
            # so it reflects the source file rather than the target rate)
            signal.metadata.update(
                {
                    'file_path': str(file_path),
                    'original_sr': original_sr,
                    'duration': self.duration,
                }
            )
            if self.audio_only:
                return signal
            else:
                return {
                    'signal': signal,
                    'file_path': str(file_path),
                    'idx': idx,
                }
        except Exception as e:
            print(f'Error loading audio file {self.files[idx]}: {e}')
            # Fall back to the next file on error to avoid crashing training
            return self.__getitem__((idx + 1) % len(self))
    def collate(self, batch):
        """Collate function for a DataLoader."""
        if self.audio_only:
            # Batch a list of AudioSignals into a single batched AudioSignal
            return AudioSignal.batch(batch)
        else:
            # Collate a list of dicts key-by-key
            return util.collate(batch)
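
# Example usage: a minimal sketch of how the dataset might be driven. The
# './audio_data' path and the batch size are placeholder assumptions; the key
# detail is passing the dataset's own `collate` method to the DataLoader,
# since PyTorch's default collate function cannot stack AudioSignal objects.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    dataset = AudioFolder(
        root_path='./audio_data',  # hypothetical directory of audio files
        sample_rate=24000,
        duration=2.0,
        audio_only=True,
    )
    loader = DataLoader(
        dataset,
        batch_size=4,
        shuffle=False,  # the file list is already shuffled in __init__
        collate_fn=dataset.collate,
    )

    batch = next(iter(loader))
    # With audio_only=True, the batch is one AudioSignal whose audio_data
    # tensor has shape (batch, channels, samples), e.g. (4, 1, 48000) here.
    print(batch.audio_data.shape)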