Dalzymodderever
Intial Commit
2cba492
import csv
from dataclasses import dataclass
from pathlib import Path
import torch
import torchaudio
from torch.utils.data import Dataset
from ..util import _load_audio_internal, get_logger
logger = get_logger()
@dataclass
class AudioItem:
waveform: torch.Tensor
audio_id: str
path: Path
sample_rate: int
frame_offset: int | None = None # For chunked audio
def convert_to_mono(waveform: torch.Tensor) -> torch.Tensor:
# (1, samples)
if waveform.shape[0] > 1:
return torch.mean(waveform, dim=0, keepdim=True)
return waveform
def resample_audio(waveform: torch.Tensor, orig_freq: int, new_freq: int) -> torch.Tensor:
if orig_freq != new_freq:
resampler = torchaudio.transforms.Resample(orig_freq=orig_freq, new_freq=new_freq)
return resampler(waveform)
return waveform
def normalize_audio(waveform: torch.Tensor) -> torch.Tensor:
max_val = torch.max(torch.abs(waveform)) + 1e-8
return waveform / max_val
def preprocess_audio(
waveform: torch.Tensor, sample_rate: int, mono: bool, normalize: bool, target_sample_rate: int | None = None
) -> tuple[torch.Tensor, int]:
# Convert to mono if needed
if mono:
waveform = convert_to_mono(waveform)
# Resample if needed
if target_sample_rate is not None and sample_rate != target_sample_rate:
waveform = resample_audio(waveform, sample_rate, target_sample_rate)
sample_rate = target_sample_rate
# Normalize if needed
if normalize:
waveform = normalize_audio(waveform)
return waveform, sample_rate
def pad_audio(waveform: torch.Tensor, target_length: int) -> torch.Tensor:
current_length = waveform.shape[1]
if current_length >= target_length:
return waveform
# Calculate padding needed
pad_length = target_length - current_length
# Pad with zeros at the end
padding = torch.zeros((waveform.shape[0], pad_length), dtype=waveform.dtype, device=waveform.device)
padded_waveform = torch.cat([waveform, padding], dim=1)
return padded_waveform
@dataclass
class ChunkInfo:
audio_id: str
frame_offset: int # In target sample rate
num_frames: int # In target sample rate
class ChunkedAudioDataset(Dataset):
"""
Dataset that loads audio from CSV with optional chunking.
Args:
csv_path: Path to the CSV file with columns: audio_id, path, length, sample_rate
audio_root: Root directory for audio files (prepended to paths in CSV)
chunk_size: Size of each chunk in frames (None = no chunking)
hop_size: Hop size between chunks in frames (None = use chunk_size)
mono: Convert to mono if True
normalize: Normalize audio if True
target_sample_rate: Resample to this sample rate if provided
"""
def __init__(
self,
csv_path: str,
audio_root: str,
chunk_size: int | None = None,
hop_size: int | None = None,
mono: bool = True,
normalize: bool = True,
target_sample_rate: int | None = None,
):
self.csv_path = csv_path
self.audio_root = audio_root
self.chunk_size = chunk_size
self.hop_size = hop_size if hop_size is not None else chunk_size
self.mono = mono
self.normalize = normalize
self.target_sample_rate = target_sample_rate
# Load CSV and compute chunks
self.file_entries = self._load_csv()
self.chunks = self._compute_chunks()
logger.info(f"Loaded dataset from {csv_path}: {len(self.file_entries)} files, {len(self.chunks)} chunks")
def _load_csv(self) -> dict[str, dict]:
"""Load audio metadata from CSV."""
entries = {}
with open(self.csv_path, "r", encoding="utf-8") as f:
reader = csv.DictReader(f)
for row in reader:
entries[row["audio_id"]] = {
"path": row["path"],
"length": int(row["length"]),
"sample_rate": int(row["sample_rate"]),
}
return entries
def _compute_chunks(self) -> list[ChunkInfo]:
"""Compute all chunks from the file entries."""
chunks = []
for audio_id, entry in self.file_entries.items():
length = entry["length"]
sample_rate = entry["sample_rate"]
# Adjust length if resampling to target sample rate
if self.target_sample_rate is not None and sample_rate != self.target_sample_rate:
length = int(length * self.target_sample_rate / sample_rate)
sample_rate = self.target_sample_rate
if self.chunk_size is None or length <= self.chunk_size:
# No chunking, or file is shorter than chunk size: use entire file
chunks.append(ChunkInfo(audio_id=audio_id, frame_offset=0, num_frames=length))
else:
# Chunking: compute all chunks with last chunk aligned to end
frame_offset = 0
while frame_offset + self.chunk_size <= length:
chunks.append(ChunkInfo(audio_id=audio_id, frame_offset=frame_offset, num_frames=self.chunk_size))
frame_offset += self.hop_size
# Add the last chunk aligned to the end
last_start = length - self.chunk_size
if last_start > frame_offset - self.hop_size:
chunks.append(ChunkInfo(audio_id=audio_id, frame_offset=last_start, num_frames=self.chunk_size))
return chunks
def __len__(self) -> int:
return len(self.chunks)
def __getitem__(self, idx: int) -> AudioItem:
"""Load and return a single audio chunk."""
chunk = self.chunks[idx]
entry = self.file_entries[chunk.audio_id]
orig_sample_rate = entry["sample_rate"]
full_path = Path(self.audio_root) / entry["path"]
# Calculate start frame and num frames in original sample rate
if self.target_sample_rate is not None and orig_sample_rate != self.target_sample_rate:
orig_frame_offset = int(chunk.frame_offset * orig_sample_rate / self.target_sample_rate)
orig_num_frames = int(chunk.num_frames * orig_sample_rate / self.target_sample_rate)
else:
orig_frame_offset = chunk.frame_offset
orig_num_frames = chunk.num_frames
waveform, sample_rate = _load_audio_internal(
full_path, frame_offset=orig_frame_offset, num_frames=orig_num_frames
)
waveform, sample_rate = preprocess_audio(
waveform=waveform,
sample_rate=sample_rate,
mono=self.mono,
normalize=self.normalize,
target_sample_rate=self.target_sample_rate,
)
# Pad if necessary (in case file is shorter than expected)
if self.chunk_size is not None and waveform.shape[1] < self.chunk_size:
waveform = pad_audio(waveform, self.chunk_size)
return AudioItem(
waveform=waveform,
audio_id=chunk.audio_id,
path=full_path,
sample_rate=sample_rate,
frame_offset=chunk.frame_offset,
)