| import csv |
| from dataclasses import dataclass |
| from pathlib import Path |
|
|
| import torch |
| import torchaudio |
| from torch.utils.data import Dataset |
|
|
| from ..util import _load_audio_internal, get_logger |
|
|
| logger = get_logger() |
|
|
|
|
@dataclass
class AudioItem:
    """A single loaded audio chunk plus its provenance metadata."""

    waveform: torch.Tensor  # audio samples; shape (channels, num_frames) per usage in pad_audio/convert_to_mono
    audio_id: str  # identifier of the source recording (CSV `audio_id` column)
    path: Path  # resolved on-disk path the waveform was loaded from
    sample_rate: int  # sample rate of `waveform` (after any resampling)
    frame_offset: int | None = None  # start frame of this chunk within the (resampled) file; None if not chunked
|
|
|
|
def convert_to_mono(waveform: torch.Tensor) -> torch.Tensor:
    """Collapse a multi-channel waveform to one channel by averaging.

    A mono (single-channel) input is returned unchanged.
    """
    num_channels = waveform.shape[0]
    if num_channels <= 1:
        return waveform
    return waveform.mean(dim=0, keepdim=True)
|
|
|
|
def resample_audio(waveform: torch.Tensor, orig_freq: int, new_freq: int) -> torch.Tensor:
    """Resample `waveform` from `orig_freq` Hz to `new_freq` Hz.

    When the two rates already match, the input tensor is returned as-is
    (no copy, no transform construction).
    """
    if orig_freq == new_freq:
        return waveform
    transform = torchaudio.transforms.Resample(orig_freq=orig_freq, new_freq=new_freq)
    return transform(waveform)
|
|
|
|
def normalize_audio(waveform: torch.Tensor) -> torch.Tensor:
    """Scale `waveform` so its peak absolute value is (approximately) 1.

    The small epsilon keeps the division finite for an all-zero signal.
    """
    peak = waveform.abs().max()
    return waveform / (peak + 1e-8)
|
|
|
|
def preprocess_audio(
    waveform: torch.Tensor, sample_rate: int, mono: bool, normalize: bool, target_sample_rate: int | None = None
) -> tuple[torch.Tensor, int]:
    """Run the standard preprocessing pipeline on a raw waveform.

    Steps, in order: optional mono mixdown, optional resampling to
    `target_sample_rate`, optional peak normalization. Returns the processed
    waveform together with its (possibly updated) sample rate.
    """
    if mono:
        waveform = convert_to_mono(waveform)

    needs_resample = target_sample_rate is not None and sample_rate != target_sample_rate
    if needs_resample:
        waveform = resample_audio(waveform, sample_rate, target_sample_rate)
        sample_rate = target_sample_rate

    if normalize:
        waveform = normalize_audio(waveform)

    return waveform, sample_rate
|
|
|
|
def pad_audio(waveform: torch.Tensor, target_length: int) -> torch.Tensor:
    """Right-pad `waveform` with zeros along the frame axis up to `target_length`.

    Inputs already at least `target_length` frames long are returned unchanged.
    The zero buffer matches the input's dtype and device.
    """
    num_channels, num_frames = waveform.shape
    if num_frames >= target_length:
        return waveform

    padded = torch.zeros((num_channels, target_length), dtype=waveform.dtype, device=waveform.device)
    padded[:, :num_frames] = waveform
    return padded
|
|
|
|
@dataclass
class ChunkInfo:
    """Location of a single chunk within a source audio file.

    Offsets/lengths are expressed at the dataset's target sample rate when
    resampling is configured (see ChunkedAudioDataset._compute_chunks).
    """

    audio_id: str  # key into the dataset's file_entries table
    frame_offset: int  # start frame of the chunk
    num_frames: int  # number of frames in the chunk
|
|
|
|
class ChunkedAudioDataset(Dataset):
    """
    Dataset that loads audio from CSV with optional chunking.

    Args:
        csv_path: Path to the CSV file with columns: audio_id, path, length, sample_rate
        audio_root: Root directory for audio files (prepended to paths in CSV)
        chunk_size: Size of each chunk in frames (None = no chunking)
        hop_size: Hop size between chunks in frames (None = use chunk_size)
        mono: Convert to mono if True
        normalize: Normalize audio if True
        target_sample_rate: Resample to this sample rate if provided
    """

    def __init__(
        self,
        csv_path: str,
        audio_root: str,
        chunk_size: int | None = None,
        hop_size: int | None = None,
        mono: bool = True,
        normalize: bool = True,
        target_sample_rate: int | None = None,
    ):
        self.csv_path = csv_path
        self.audio_root = audio_root
        self.chunk_size = chunk_size
        # Default hop equal to chunk_size yields non-overlapping chunks.
        self.hop_size = hop_size if hop_size is not None else chunk_size
        self.mono = mono
        self.normalize = normalize
        self.target_sample_rate = target_sample_rate

        self.file_entries = self._load_csv()
        self.chunks = self._compute_chunks()

        logger.info(f"Loaded dataset from {csv_path}: {len(self.file_entries)} files, {len(self.chunks)} chunks")

    def _load_csv(self) -> dict[str, dict]:
        """Load audio metadata from CSV, keyed by audio_id."""
        entries = {}
        with open(self.csv_path, "r", encoding="utf-8") as f:
            reader = csv.DictReader(f)
            for row in reader:
                entries[row["audio_id"]] = {
                    "path": row["path"],
                    "length": int(row["length"]),
                    "sample_rate": int(row["sample_rate"]),
                }
        return entries

    def _compute_chunks(self) -> list[ChunkInfo]:
        """Compute all chunks from the file entries.

        Offsets and lengths are expressed at the target sample rate (when
        resampling is configured) so chunk boundaries line up with the
        post-resampling waveform returned by __getitem__.
        """
        chunks = []
        for audio_id, entry in self.file_entries.items():
            length = entry["length"]
            sample_rate = entry["sample_rate"]

            # Rescale the file length into target-rate frames before chunking.
            if self.target_sample_rate is not None and sample_rate != self.target_sample_rate:
                length = int(length * self.target_sample_rate / sample_rate)
                sample_rate = self.target_sample_rate

            if self.chunk_size is None or length <= self.chunk_size:
                # File fits in a single (possibly later padded) chunk.
                chunks.append(ChunkInfo(audio_id=audio_id, frame_offset=0, num_frames=length))
            else:
                # Emit full-size chunks every hop_size frames.
                frame_offset = 0
                while frame_offset + self.chunk_size <= length:
                    chunks.append(ChunkInfo(audio_id=audio_id, frame_offset=frame_offset, num_frames=self.chunk_size))
                    frame_offset += self.hop_size

                # Add a final chunk flush with the end of the file, unless the
                # last emitted chunk (at frame_offset - hop_size) already covers it.
                last_start = length - self.chunk_size
                if last_start > frame_offset - self.hop_size:
                    chunks.append(ChunkInfo(audio_id=audio_id, frame_offset=last_start, num_frames=self.chunk_size))

        return chunks

    def __len__(self) -> int:
        return len(self.chunks)

    def __getitem__(self, idx: int) -> AudioItem:
        """Load and return a single audio chunk as an AudioItem."""
        chunk = self.chunks[idx]
        entry = self.file_entries[chunk.audio_id]
        orig_sample_rate = entry["sample_rate"]
        full_path = Path(self.audio_root) / entry["path"]

        # Chunk offsets are stored at the target rate; map them back to the
        # file's native rate before reading from disk.
        if self.target_sample_rate is not None and orig_sample_rate != self.target_sample_rate:
            orig_frame_offset = int(chunk.frame_offset * orig_sample_rate / self.target_sample_rate)
            orig_num_frames = int(chunk.num_frames * orig_sample_rate / self.target_sample_rate)
        else:
            orig_frame_offset = chunk.frame_offset
            orig_num_frames = chunk.num_frames

        waveform, sample_rate = _load_audio_internal(
            full_path, frame_offset=orig_frame_offset, num_frames=orig_num_frames
        )

        waveform, sample_rate = preprocess_audio(
            waveform=waveform,
            sample_rate=sample_rate,
            mono=self.mono,
            normalize=self.normalize,
            target_sample_rate=self.target_sample_rate,
        )

        # BUGFIX: the truncating int() rate conversion above can leave the
        # resampled chunk a frame or two OFF the requested size in either
        # direction. Previously only the shortfall was handled (padding),
        # so overshoot produced inconsistent item lengths that break
        # batching. Now: truncate overshoot AND pad shortfall so every
        # chunked item has exactly chunk_size frames.
        if self.chunk_size is not None:
            if waveform.shape[1] > self.chunk_size:
                waveform = waveform[:, : self.chunk_size]
            elif waveform.shape[1] < self.chunk_size:
                waveform = pad_audio(waveform, self.chunk_size)

        return AudioItem(
            waveform=waveform,
            audio_id=chunk.audio_id,
            path=full_path,
            sample_rate=sample_rate,
            frame_offset=chunk.frame_offset,
        )
|
|