"""Audio folder dataset: recursively loads audio files from a directory tree."""
import os
import random
from PIL import Image, ImageFile
from datasets import register
from torch.utils.data import Dataset
from torchvision import transforms
import os
import random
from pathlib import Path
from typing import Optional, Callable
from models.ldm.dac.audiotools import AudioSignal
from models.ldm.dac.audiotools.core import util
# Audio file extensions (from audiotools)
AUDIO_EXTS = ('.wav', '.WAV', '.flac', '.FLAC', '.mp3', '.MP3', '.mp4', '.MP4', '.m4a', '.M4A')
@register('class_folder_audio')
class AudioFolder(Dataset):
    """
    Audio dataset that loads audio files from a folder structure.
    Similar to ClassFolder but for audio files.

    Expected folder structure:
        root_path/
        ├── class1/
        │   ├── audio1.wav
        │   ├── audio2.wav
        │   └── ...
        ├── class2/
        │   ├── audio1.wav
        │   └── ...
        └── ...

    Or for single class (no subfolders):
        root_path/
        ├── audio1.wav
        ├── audio2.wav
        └── ...
    """
    def __init__(
        self,
        root_path: str,
        sample_rate: int = 24000,
        duration: float = 2.0,
        num_channels: int = 1,
        random_crop: bool = True,
        loudness_cutoff: float = -40,
        audio_only: bool = False,
        drop_label_p: float = 0.0,
        shuffle: bool = True,
        shuffle_state: int = 0,
        transform: Optional[Callable] = None,
        normalize: bool = True,
        trim_silence: bool = False,
    ):
        """
        Args:
            root_path: Path to audio files
            sample_rate: Target sample rate for audio
            duration: Duration in seconds for audio clips
            num_channels: Number of channels (1 for mono, 2 for stereo)
            random_crop: Whether to randomly crop audio (vs deterministic)
            loudness_cutoff: Minimum loudness threshold for audio selection
            audio_only: If True, return only audio signal. If False, return dict with labels
            drop_label_p: Probability of dropping labels (for unconditional training)
                NOTE(review): stored but not used anywhere in this class — confirm
                whether a downstream consumer reads it.
            shuffle: Whether to shuffle files
            shuffle_state: Random state for shuffling
            transform: Additional audio transforms
            normalize: Whether to normalize audio amplitude
            trim_silence: Whether to trim silence from audio
        """
        self.root_path = root_path
        self.sample_rate = sample_rate
        self.duration = duration
        self.num_channels = num_channels
        self.random_crop = random_crop
        self.loudness_cutoff = loudness_cutoff
        self.audio_only = audio_only
        self.drop_label_p = drop_label_p
        self.transform = transform
        self.normalize = normalize
        self.trim_silence = trim_silence
        print(f'Audio root_path: {root_path}')
        # Recursively find all audio files under root_path.
        self.files = []
        for dirpath, _dirnames, filenames in os.walk(self.root_path):
            for fname in filenames:
                # lower() makes the match case-insensitive (AUDIO_EXTS lists
                # both cases, but that is redundant with this check).
                if fname.lower().endswith(AUDIO_EXTS):
                    self.files.append(os.path.join(dirpath, fname))
        print(f'Found {len(self.files)} audio files')
        # Reproducible in-place shuffle, seeded by shuffle_state.
        if shuffle:
            util.random_state(shuffle_state).shuffle(self.files)

    def __len__(self):
        """Number of discovered audio files."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load and return the item at ``idx``.

        On a load error, falls forward to the next file so a single corrupt
        file does not crash training. Unlike an unbounded retry, this gives
        up with RuntimeError after one full pass over the dataset (which also
        avoids a ZeroDivisionError on an empty dataset).
        """
        n = len(self.files)
        for attempt in range(n):
            cur = (idx + attempt) % n
            file_path = self.files[cur]
            try:
                return self._load_item(file_path, cur)
            except Exception as e:
                print(f'Error loading audio file {file_path}: {e}')
        raise RuntimeError(
            f'AudioFolder: no loadable audio files (tried {n} files starting at index {idx})'
        )

    def _load_item(self, file_path, idx):
        """Load one audio file, crop/pad to the target length, post-process,
        and package the result.

        Returns the AudioSignal itself when ``audio_only`` is set, otherwise a
        dict with keys 'signal', 'file_path', 'idx'.
        """
        if self.random_crop:
            # Random crop biased toward regions louder than loudness_cutoff.
            signal = AudioSignal.salient_excerpt(
                str(file_path),
                duration=self.duration,
                loudness_cutoff=self.loudness_cutoff,
            )
        else:
            # Deterministic crop from the beginning of the file.
            signal = AudioSignal(
                str(file_path),
                duration=self.duration,
                offset=0.0,
            )
        # Capture the sample rate as loaded, BEFORE resampling — previously
        # this was read after resample(), so it always equalled the target
        # rate and never reflected the source file.
        original_sr = signal.sample_rate
        # Downmix to mono if requested.
        if self.num_channels == 1:
            signal = signal.to_mono()
        # Resample to the target sample rate.
        signal = signal.resample(self.sample_rate)
        # Force exactly duration * sample_rate samples by padding or trimming.
        target_samples = int(self.duration * self.sample_rate)
        if signal.length < target_samples:
            signal = signal.zero_pad_to(target_samples)
        elif signal.length > target_samples:
            signal = signal.truncate_samples(target_samples)
        # Optional audio processing.
        if self.trim_silence:
            signal = signal.trim_silence()
            # Trimming may have shortened the clip; re-pad to target length.
            if signal.length < target_samples:
                signal = signal.zero_pad_to(target_samples)
        if self.normalize:
            signal = signal.normalize()
        # Clamp the waveform to the valid [-1, 1] range.
        signal.audio_data = signal.audio_data.clamp(-1.0, 1.0)
        # Apply additional transforms if provided, seeding their randomness
        # with the item index for reproducibility.
        if self.transform is not None:
            state = util.random_state(idx)
            transform_args = self.transform.instantiate(state, signal=signal)
            signal = self.transform(signal, **transform_args)
        # Store metadata about the source file.
        signal.metadata.update(
            {
                'file_path': str(file_path),
                'original_sr': original_sr,
                'duration': self.duration,
            }
        )
        if self.audio_only:
            return signal
        return {
            'signal': signal,
            'file_path': str(file_path),
            'idx': idx,
        }

    def collate(self, batch):
        """Collate function for DataLoader."""
        if self.audio_only:
            # Batch AudioSignals.
            return AudioSignal.batch(batch)
        else:
            # Collate dictionary batch.
            return util.collate(batch)