Spaces:
Running
Running
File size: 7,227 Bytes
2cba492 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 |
import csv
from dataclasses import dataclass
from pathlib import Path
import torch
import torchaudio
from torch.utils.data import Dataset
from ..util import _load_audio_internal, get_logger
logger = get_logger()
@dataclass
class AudioItem:
    """One loaded (and optionally chunked) piece of audio plus its metadata."""

    waveform: torch.Tensor  # audio samples; presumably (channels, samples) — matches convert_to_mono's layout
    audio_id: str  # identifier of the source recording (CSV "audio_id" column)
    path: Path  # full filesystem path the audio was loaded from
    sample_rate: int  # sample rate of `waveform` after any resampling
    frame_offset: int | None = None  # For chunked audio: start frame of this chunk; None when unchunked
def convert_to_mono(waveform: torch.Tensor) -> torch.Tensor:
    """Collapse a (channels, samples) waveform to a single channel.

    Single-channel input is returned as-is; multi-channel input is averaged
    across channels, keeping the leading dim so the result is (1, samples).
    """
    if waveform.shape[0] <= 1:
        return waveform
    return waveform.mean(dim=0, keepdim=True)
def resample_audio(waveform: torch.Tensor, orig_freq: int, new_freq: int) -> torch.Tensor:
    """Resample `waveform` from `orig_freq` to `new_freq` Hz.

    Returns the input untouched when the two rates already match.
    """
    if orig_freq == new_freq:
        return waveform
    return torchaudio.transforms.Resample(orig_freq=orig_freq, new_freq=new_freq)(waveform)
def normalize_audio(waveform: torch.Tensor) -> torch.Tensor:
    """Scale so the peak absolute amplitude is ~1; eps guards an all-zero signal."""
    peak = waveform.abs().max()
    return waveform / (peak + 1e-8)
def preprocess_audio(
    waveform: torch.Tensor, sample_rate: int, mono: bool, normalize: bool, target_sample_rate: int | None = None
) -> tuple[torch.Tensor, int]:
    """Apply mono mixdown, resampling, and peak normalization — in that order.

    Returns the processed waveform together with its (possibly updated)
    sample rate; the rate only changes when resampling actually happens.
    """
    if mono:
        waveform = convert_to_mono(waveform)

    needs_resample = target_sample_rate is not None and sample_rate != target_sample_rate
    if needs_resample:
        waveform = resample_audio(waveform, sample_rate, target_sample_rate)
        sample_rate = target_sample_rate

    if normalize:
        waveform = normalize_audio(waveform)

    return waveform, sample_rate
def pad_audio(waveform: torch.Tensor, target_length: int) -> torch.Tensor:
    """Right-pad a (channels, samples) waveform with zeros to `target_length`.

    Waveforms already at least `target_length` samples long come back unchanged.
    The zero tail inherits the input's dtype and device.
    """
    shortfall = target_length - waveform.shape[1]
    if shortfall <= 0:
        return waveform
    tail = waveform.new_zeros((waveform.shape[0], shortfall))
    return torch.cat((waveform, tail), dim=1)
@dataclass
class ChunkInfo:
    """Location of one chunk within a source audio file.

    Offsets and lengths are expressed in the dataset's target sample rate
    (i.e. after any resampling), not necessarily the file's native rate.
    """

    audio_id: str  # key into the dataset's file-entry table
    frame_offset: int  # In target sample rate
    num_frames: int  # In target sample rate
class ChunkedAudioDataset(Dataset):
    """
    Dataset that loads audio from CSV with optional chunking.

    Each __getitem__ loads exactly one chunk (or one whole file when
    chunking is disabled), runs the mono/resample/normalize pipeline on it,
    and zero-pads the result up to `chunk_size` if the decoded audio came
    up short.

    Args:
        csv_path: Path to the CSV file with columns: audio_id, path, length, sample_rate
        audio_root: Root directory for audio files (prepended to paths in CSV)
        chunk_size: Size of each chunk in frames (None = no chunking)
        hop_size: Hop size between chunks in frames (None = use chunk_size)
        mono: Convert to mono if True
        normalize: Normalize audio if True
        target_sample_rate: Resample to this sample rate if provided
    """

    def __init__(
        self,
        csv_path: str,
        audio_root: str,
        chunk_size: int | None = None,
        hop_size: int | None = None,
        mono: bool = True,
        normalize: bool = True,
        target_sample_rate: int | None = None,
    ):
        self.csv_path = csv_path
        self.audio_root = audio_root
        self.chunk_size = chunk_size
        # Default hop equals the chunk size, i.e. non-overlapping chunks.
        self.hop_size = hop_size if hop_size is not None else chunk_size
        self.mono = mono
        self.normalize = normalize
        self.target_sample_rate = target_sample_rate
        # Load CSV and compute chunks
        self.file_entries = self._load_csv()
        self.chunks = self._compute_chunks()
        logger.info(f"Loaded dataset from {csv_path}: {len(self.file_entries)} files, {len(self.chunks)} chunks")

    def _load_csv(self) -> dict[str, dict]:
        """Load audio metadata from CSV.

        Returns a dict keyed by audio_id with "path" (str, relative to
        audio_root), "length" (int frames at the file's native rate), and
        "sample_rate" (int Hz). NOTE(review): "length" is assumed to be in
        the file's native sample rate — confirm against the CSV producer.
        """
        entries = {}
        with open(self.csv_path, "r", encoding="utf-8") as f:
            reader = csv.DictReader(f)
            for row in reader:
                entries[row["audio_id"]] = {
                    "path": row["path"],
                    "length": int(row["length"]),
                    "sample_rate": int(row["sample_rate"]),
                }
        return entries

    def _compute_chunks(self) -> list[ChunkInfo]:
        """Compute all chunks from the file entries.

        All offsets/lengths in the returned ChunkInfo objects are in the
        target sample rate when resampling is configured.
        """
        chunks = []
        for audio_id, entry in self.file_entries.items():
            length = entry["length"]
            sample_rate = entry["sample_rate"]
            # Adjust length if resampling to target sample rate
            # (int() truncates, so this can undershoot by a frame).
            if self.target_sample_rate is not None and sample_rate != self.target_sample_rate:
                length = int(length * self.target_sample_rate / sample_rate)
                sample_rate = self.target_sample_rate
            if self.chunk_size is None or length <= self.chunk_size:
                # No chunking, or file is shorter than chunk size: use entire file
                chunks.append(ChunkInfo(audio_id=audio_id, frame_offset=0, num_frames=length))
            else:
                # Chunking: compute all chunks with last chunk aligned to end
                frame_offset = 0
                while frame_offset + self.chunk_size <= length:
                    chunks.append(ChunkInfo(audio_id=audio_id, frame_offset=frame_offset, num_frames=self.chunk_size))
                    frame_offset += self.hop_size
                # Add the last chunk aligned to the end.
                # `frame_offset - self.hop_size` is the offset of the chunk
                # just appended, so this only fires when the hop grid did not
                # land exactly on the end of the file (avoids a duplicate).
                last_start = length - self.chunk_size
                if last_start > frame_offset - self.hop_size:
                    chunks.append(ChunkInfo(audio_id=audio_id, frame_offset=last_start, num_frames=self.chunk_size))
        return chunks

    def __len__(self) -> int:
        # One dataset item per chunk, not per file.
        return len(self.chunks)

    def __getitem__(self, idx: int) -> AudioItem:
        """Load and return a single audio chunk."""
        chunk = self.chunks[idx]
        entry = self.file_entries[chunk.audio_id]
        orig_sample_rate = entry["sample_rate"]
        full_path = Path(self.audio_root) / entry["path"]
        # Calculate start frame and num frames in original sample rate
        # (chunk coordinates are in the target rate; int() truncation here
        # may read a frame or two short — padding below compensates).
        if self.target_sample_rate is not None and orig_sample_rate != self.target_sample_rate:
            orig_frame_offset = int(chunk.frame_offset * orig_sample_rate / self.target_sample_rate)
            orig_num_frames = int(chunk.num_frames * orig_sample_rate / self.target_sample_rate)
        else:
            orig_frame_offset = chunk.frame_offset
            orig_num_frames = chunk.num_frames
        # Decode only the requested window of the file.
        waveform, sample_rate = _load_audio_internal(
            full_path, frame_offset=orig_frame_offset, num_frames=orig_num_frames
        )
        waveform, sample_rate = preprocess_audio(
            waveform=waveform,
            sample_rate=sample_rate,
            mono=self.mono,
            normalize=self.normalize,
            target_sample_rate=self.target_sample_rate,
        )
        # Pad if necessary (in case file is shorter than expected)
        if self.chunk_size is not None and waveform.shape[1] < self.chunk_size:
            waveform = pad_audio(waveform, self.chunk_size)
        return AudioItem(
            waveform=waveform,
            audio_id=chunk.audio_id,
            path=full_path,
            sample_rate=sample_rate,
            frame_offset=chunk.frame_offset,
        )
|