import csv
from dataclasses import dataclass
from pathlib import Path

import torch
import torchaudio
from torch.utils.data import Dataset

from ..util import _load_audio_internal, get_logger

logger = get_logger()


@dataclass
class AudioItem:
    """A single loaded audio example (or chunk) together with its metadata."""

    waveform: torch.Tensor
    audio_id: str
    path: Path
    sample_rate: int
    frame_offset: int | None = None  # For chunked audio


def convert_to_mono(waveform: torch.Tensor) -> torch.Tensor:
    """Downmix a (channels, samples) waveform to (1, samples) by averaging channels."""
    if waveform.shape[0] > 1:
        return torch.mean(waveform, dim=0, keepdim=True)
    return waveform


def resample_audio(waveform: torch.Tensor, orig_freq: int, new_freq: int) -> torch.Tensor:
    """Resample from orig_freq to new_freq; no-op if the rates already match."""
    if orig_freq != new_freq:
        resampler = torchaudio.transforms.Resample(orig_freq=orig_freq, new_freq=new_freq)
        return resampler(waveform)
    return waveform


def normalize_audio(waveform: torch.Tensor) -> torch.Tensor:
    """Peak-normalize to [-1, 1]; the epsilon guards against division by zero on silent input."""
    max_val = torch.max(torch.abs(waveform)) + 1e-8
    return waveform / max_val


def preprocess_audio(
    waveform: torch.Tensor, sample_rate: int, mono: bool, normalize: bool, target_sample_rate: int | None = None
) -> tuple[torch.Tensor, int]:
    """Apply mono downmix, resampling, and normalization; returns the waveform and its (possibly new) sample rate."""
    # Convert to mono if needed
    if mono:
        waveform = convert_to_mono(waveform)
    # Resample if needed
    if target_sample_rate is not None and sample_rate != target_sample_rate:
        waveform = resample_audio(waveform, sample_rate, target_sample_rate)
        sample_rate = target_sample_rate
    # Normalize if needed
    if normalize:
        waveform = normalize_audio(waveform)
    return waveform, sample_rate
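

# Example (a minimal sketch; the shapes below are illustrative, not taken from
# the dataset code): one second of stereo audio at 44.1 kHz becomes mono 16 kHz:
#
#     wav = torch.randn(2, 44100)
#     out, sr = preprocess_audio(wav, 44100, mono=True, normalize=True, target_sample_rate=16000)
#     # out.shape == (1, 16000), sr == 16000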


def pad_audio(waveform: torch.Tensor, target_length: int) -> torch.Tensor:
    """Right-pad the waveform with zeros along the time axis to target_length samples."""
    current_length = waveform.shape[1]
    if current_length >= target_length:
        return waveform
    # Calculate padding needed
    pad_length = target_length - current_length
    # Pad with zeros at the end
    padding = torch.zeros((waveform.shape[0], pad_length), dtype=waveform.dtype, device=waveform.device)
    padded_waveform = torch.cat([waveform, padding], dim=1)
    return padded_waveform
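

# Example (illustrative): pad_audio(torch.zeros(1, 300), 500) returns shape (1, 500);
# inputs already at or beyond target_length are returned unchanged.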


@dataclass
class ChunkInfo:
    """Location of one chunk within a source file, expressed in target-sample-rate frames."""

    audio_id: str
    frame_offset: int  # In target sample rate
    num_frames: int  # In target sample rate


class ChunkedAudioDataset(Dataset):
    """
    Dataset that loads audio from CSV with optional chunking.

    Args:
        csv_path: Path to the CSV file with columns: audio_id, path, length, sample_rate
        audio_root: Root directory for audio files (prepended to paths in CSV)
        chunk_size: Size of each chunk in frames (None = no chunking)
        hop_size: Hop size between chunks in frames (None = use chunk_size)
        mono: Convert to mono if True
        normalize: Normalize audio if True
        target_sample_rate: Resample to this sample rate if provided
    """

    def __init__(
        self,
        csv_path: str,
        audio_root: str,
        chunk_size: int | None = None,
        hop_size: int | None = None,
        mono: bool = True,
        normalize: bool = True,
        target_sample_rate: int | None = None,
    ):
        self.csv_path = csv_path
        self.audio_root = audio_root
        self.chunk_size = chunk_size
        self.hop_size = hop_size if hop_size is not None else chunk_size
        self.mono = mono
        self.normalize = normalize
        self.target_sample_rate = target_sample_rate

        # Load CSV and compute chunks
        self.file_entries = self._load_csv()
        self.chunks = self._compute_chunks()
        logger.info(f"Loaded dataset from {csv_path}: {len(self.file_entries)} files, {len(self.chunks)} chunks")

    def _load_csv(self) -> dict[str, dict]:
        """Load audio metadata from CSV."""
        entries = {}
        with open(self.csv_path, "r", encoding="utf-8") as f:
            reader = csv.DictReader(f)
            for row in reader:
                entries[row["audio_id"]] = {
                    "path": row["path"],
                    "length": int(row["length"]),
                    "sample_rate": int(row["sample_rate"]),
                }
        return entries
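
    # With the hypothetical CSV above, _load_csv() would return, e.g.:
    #   {"clip_001": {"path": "speech/clip_001.wav", "length": 480000, "sample_rate": 48000}, ...}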

    def _compute_chunks(self) -> list[ChunkInfo]:
        """Compute all chunks from the file entries."""
        chunks = []
        for audio_id, entry in self.file_entries.items():
            length = entry["length"]
            sample_rate = entry["sample_rate"]
            # Adjust length if resampling to target sample rate
            if self.target_sample_rate is not None and sample_rate != self.target_sample_rate:
                length = int(length * self.target_sample_rate / sample_rate)
                sample_rate = self.target_sample_rate
            if self.chunk_size is None or length <= self.chunk_size:
                # No chunking, or file is shorter than chunk size: use entire file
                chunks.append(ChunkInfo(audio_id=audio_id, frame_offset=0, num_frames=length))
            else:
                # Chunking: compute all chunks with last chunk aligned to end
                frame_offset = 0
                while frame_offset + self.chunk_size <= length:
                    chunks.append(ChunkInfo(audio_id=audio_id, frame_offset=frame_offset, num_frames=self.chunk_size))
                    frame_offset += self.hop_size
                # Add a final end-aligned chunk if the loop left a tail uncovered
                last_start = length - self.chunk_size
                if last_start > frame_offset - self.hop_size:
                    chunks.append(ChunkInfo(audio_id=audio_id, frame_offset=last_start, num_frames=self.chunk_size))
        return chunks
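
    # Worked example (illustrative numbers): length=9, chunk_size=4, hop_size=2.
    # The loop emits offsets 0, 2, 4 (each satisfies offset + 4 <= 9) and stops
    # at offset 6 (10 > 9). last_start = 9 - 4 = 5 > 6 - 2 = 4, so a final
    # end-aligned chunk at offset 5 is added: chunks cover [0,4), [2,6), [4,8), [5,9).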

    def __len__(self) -> int:
        return len(self.chunks)

    def __getitem__(self, idx: int) -> AudioItem:
        """Load and return a single audio chunk."""
        chunk = self.chunks[idx]
        entry = self.file_entries[chunk.audio_id]
        orig_sample_rate = entry["sample_rate"]
        full_path = Path(self.audio_root) / entry["path"]

        # Calculate start frame and num frames in original sample rate
        if self.target_sample_rate is not None and orig_sample_rate != self.target_sample_rate:
            orig_frame_offset = int(chunk.frame_offset * orig_sample_rate / self.target_sample_rate)
            orig_num_frames = int(chunk.num_frames * orig_sample_rate / self.target_sample_rate)
        else:
            orig_frame_offset = chunk.frame_offset
            orig_num_frames = chunk.num_frames

        waveform, sample_rate = _load_audio_internal(
            full_path, frame_offset=orig_frame_offset, num_frames=orig_num_frames
        )
        waveform, sample_rate = preprocess_audio(
            waveform=waveform,
            sample_rate=sample_rate,
            mono=self.mono,
            normalize=self.normalize,
            target_sample_rate=self.target_sample_rate,
        )
        # Pad if necessary (in case file is shorter than expected)
        if self.chunk_size is not None and waveform.shape[1] < self.chunk_size:
            waveform = pad_audio(waveform, self.chunk_size)

        return AudioItem(
            waveform=waveform,
            audio_id=chunk.audio_id,
            path=full_path,
            sample_rate=sample_rate,
            frame_offset=chunk.frame_offset,
        )
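

if __name__ == "__main__":
    # Minimal usage sketch, assuming a CSV index and audio root laid out as in
    # the example above; the paths here are hypothetical placeholders. Because
    # of the relative import at the top, run this module with
    # `python -m <package>.<module>` rather than directly as a script.
    dataset = ChunkedAudioDataset(
        csv_path="data/index.csv",  # hypothetical path
        audio_root="data/audio",  # hypothetical path
        chunk_size=16000,  # 1-second chunks at 16 kHz
        target_sample_rate=16000,
    )
    item = dataset[0]
    logger.info(f"{item.audio_id}: shape={tuple(item.waveform.shape)} sr={item.sample_rate} offset={item.frame_offset}")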