# tria/data/dataset.py (TRIA) — initial commit c9f87fa by saumyap29
import csv
from pathlib import Path
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional
from typing import Union
import numpy as np
import soundfile as sf
from audiotools import AudioSignal
from audiotools.core.util import random_state
from torch.utils.data import Dataset
from ..constants import DURATION
from ..constants import SAMPLE_RATE
from ..constants import STEMS
from ..util import collate
from ..util import get_info
from ..util import load_audio
from ..util import rms_salience
################################################################################
# Dataset for loading aligned excerpts across stem classes
################################################################################
class StemDataset(Dataset):
    """
    Load aligned excerpts from specified stem classes given paths in one or more
    CSV manifests. Based on `audiotools.data.datasets.AudioDataset`.

    Every `__getitem__` call picks one manifest row (weighted over sources),
    chooses a single time offset on the salience stem, and loads the SAME
    window from every requested stem so the excerpts stay time-aligned.

    Parameters
    ----------
    sources : Union[str, Path, List[Union[str, Path]]]
        CSV manifest(s) with columns for each requested stem.
    stems : List[str]
        Column names to load, e.g. ["mixture","drums","bass","vocals"].
        The **first** stem is used for salience unless `salience_on` is set.
    sample_rate : int
    duration : float
    n_examples : int
    num_channels : int
    relative_path : str
        Prepended to relative CSV paths.
    strict : bool
        Drop rows with missing stems (True) vs. fill with silence (False).
    with_replacement : bool
        Sampling strategy for rows.
    shuffle_state : int
        Seed for deterministic per-index RNG.
    loudness_cutoff : Optional[float]
        dB LUFS cutoff; if None, take random excerpt (still shared across stems).
    salience_num_tries : int
        Max tries for salient excerpt search (see `AudioSignal.salient_excerpt`).
    salience_on : Optional[str]
        Which stem to use for salience. Defaults to first of `stems`.

    Raises
    ------
    ValueError
        If `sources` is None, `stems` is empty, `salience_on` is not one of
        `stems`, or `source_weights` does not match the number of sources.
    RuntimeError
        If no valid rows survive filtering in any source.
    """

    def __init__(
        self,
        stems: List[str] = STEMS,
        sample_rate: int = SAMPLE_RATE,
        duration: float = DURATION,
        sources: Union[str, Path, List[Union[str, Path]]] = None,
        source_weights: Optional[List[float]] = None,
        n_examples: int = 1000,
        num_channels: int = 1,
        relative_path: str = "",
        strict: bool = True,
        with_replacement: bool = True,
        shuffle_state: int = 0,
        loudness_cutoff: Optional[float] = -40.0,
        salience_num_tries: int = 8,
        salience_on: Optional[str] = None,
    ):
        super().__init__()
        # Explicit validation: `assert` is stripped under `python -O`, so
        # argument checks must raise.
        if sources is None:
            raise ValueError("StemDataset requires at least one CSV manifest via `sources`.")
        if len(stems) < 1:
            raise ValueError("`stems` must contain at least one column name.")
        self.stems = list(stems)
        self.sample_rate = int(sample_rate)
        self.duration = float(duration)
        self.num_channels = int(num_channels)
        self.relative_path = Path(relative_path)
        self.strict = strict
        self.with_replacement = with_replacement
        self.length = int(n_examples)
        self.shuffle_state = int(shuffle_state)
        self.loudness_cutoff = loudness_cutoff
        self.salience_num_tries = int(salience_num_tries)
        self.salience_on = salience_on or self.stems[0]
        if self.salience_on not in self.stems:
            raise ValueError(
                f"`salience_on` ('{self.salience_on}') must be one of {self.stems}"
            )
        # Read and filter manifests; sources with zero surviving rows are
        # dropped entirely (and their weights removed below).
        csv_paths = [sources] if isinstance(sources, (str, Path)) else list(sources)
        self.source_rows: List[List[Dict]] = []
        kept_mask: List[bool] = []
        kept_csvs: List[Path] = []
        for cpath in csv_paths:
            cpath = Path(cpath)
            filtered = self._filter_rows(self._read_manifest(cpath))
            if len(filtered) > 0:
                self.source_rows.append(filtered)
                kept_mask.append(True)
                kept_csvs.append(cpath)
            else:
                kept_mask.append(False)
        if len(self.source_rows) == 0:
            raise RuntimeError(
                "StemDataset: no valid rows after filtering in any source."
            )
        self.csv_paths = kept_csvs
        lengths = [len(lst) for lst in self.source_rows]
        self._source_offsets = np.cumsum([0] + lengths[:-1])  # for global idx
        self._n_rows = int(sum(lengths))
        # Weights over non-empty sources
        if source_weights is None:
            self._weights = None  # uniform sampling over surviving sources
        else:
            if len(source_weights) != len(csv_paths):
                raise ValueError(
                    f"source_weights must match number of sources ({len(csv_paths)}), "
                    f"got {len(source_weights)}"
                )
            w = np.asarray(source_weights, dtype=float)
            # Keep only weights for sources that survived filtering
            w = w[np.array(kept_mask, dtype=bool)]
            w = np.clip(w, 0, None)
            if not np.any(w > 0):
                # Degenerate (all-zero) weights: fall back to uniform.
                w = np.ones_like(w)
            self._weights = (w / w.sum()).tolist()

    def _read_manifest(self, cpath: Path) -> List[Dict]:
        """Parse one CSV manifest into row dicts with resolved stem paths.

        Each entry carries `__manifest__` (source CSV), `paths` (stem name ->
        resolved absolute path or "" if the column was empty), and optional
        `meta` holding any extra CSV columns.
        """
        raw_rows = []
        # `newline=""` is required by the csv module so quoted fields that
        # contain embedded newlines are parsed correctly.
        with open(cpath, "r", newline="") as f:
            reader = csv.DictReader(f)
            for row in reader:
                entry = {"__manifest__": str(cpath)}
                stem_paths = {}
                for s in self.stems:
                    raw = (row.get(s) or "").strip()
                    stem_paths[s] = str(self._resolve_path(raw)) if raw else ""
                entry["paths"] = stem_paths
                # Preserve any non-stem columns as per-row metadata.
                extra = {k: v for k, v in row.items() if k not in self.stems}
                if extra:
                    entry["meta"] = extra
                raw_rows.append(entry)
        return raw_rows

    def _filter_rows(self, raw_rows: List[Dict]) -> List[Dict]:
        """Drop unusable rows and annotate survivors with `min_duration`.

        In strict mode a row is dropped if any stem is missing, unreadable,
        or shorter than `self.duration`. In non-strict mode a row only needs
        one readable stem; missing stems are filled with silence at load time.
        """
        filtered = []
        for r in raw_rows:
            missing = [
                s for s, p in r["paths"].items() if not p or not Path(p).is_file()
            ]
            if self.strict and missing:
                continue
            min_dur = np.inf
            any_valid = False
            for p in r["paths"].values():
                if p and Path(p).is_file():
                    any_valid = True
                    try:
                        min_dur = min(min_dur, float(sf.info(p).duration))
                    except Exception:
                        # Unreadable audio file: fatal for the row in strict
                        # mode, silently skipped otherwise (best-effort).
                        if self.strict:
                            min_dur = -np.inf
                            break
            if not any_valid or not np.isfinite(min_dur):
                continue
            if min_dur < self.duration and self.strict:
                continue
            # min_dur is guaranteed finite here (non-finite rows skipped above).
            r["min_duration"] = min_dur
            filtered.append(r)
        return filtered

    def _resolve_path(self, p: Union[str, Path]) -> Path:
        """Expand `~` and prepend `relative_path` to non-absolute paths."""
        p = Path(p).expanduser()
        if not p.is_absolute():
            p = (self.relative_path / p).expanduser()
        return p

    def _pick_row(self, state: np.random.RandomState):
        """Sample (global_row_index, row_dict) using the source weights."""
        # Sample a non-empty source (p=None means uniform).
        sidx = int(state.choice(len(self.source_rows), p=self._weights))
        n_in_source = len(self.source_rows[sidx])
        item_idx = int(state.randint(n_in_source))
        row = self.source_rows[sidx][item_idx]
        # Map to a global idx for metadata
        ridx_global = int(self._source_offsets[sidx] + item_idx)
        return ridx_global, row

    def __len__(self):
        return self.length

    def __getitem__(self, idx: int):
        """Return {stem: {"signal": AudioSignal, "path": str}, "idx": idx}.

        All stems share the offset chosen on the salience stem, so excerpts
        stay time-aligned. Missing stems become silence; everything is
        resampled to `sample_rate`, channel-formatted, and zero-padded to
        exactly `duration` seconds.
        """
        # Per-index deterministic RNG (mask keeps the seed non-negative).
        state = random_state((self.shuffle_state + int(idx)) & 0x7FFFFFFF)
        ridx, row = self._pick_row(state)
        primary = self.salience_on
        p0 = row["paths"].get(primary, "")
        offset = 0.0
        primary_sig = None
        if p0 and Path(p0).is_file():
            if self.loudness_cutoff is None or not self.salience_num_tries:
                # No salience search: uniform-random offset within the file.
                try:
                    total_sec, _sr = get_info(p0)
                except Exception:
                    total_sec = 0.0
                max_off = max(0.0, total_sec - self.duration)
                offset = float(state.rand() * max_off) if max_off > 0 else 0.0
            else:
                # Search for a window whose RMS clears the loudness cutoff.
                offset = rms_salience(
                    p0,
                    duration=self.duration,
                    cutoff_db=float(self.loudness_cutoff),
                    num_tries=int(self.salience_num_tries),
                    state=state,
                )
            primary_sig = load_audio(p0, offset=offset, duration=self.duration)
        else:
            offset = 0.0
        item: Dict[str, Dict] = {}
        for s in self.stems:
            p = row["paths"][s]
            exists = bool(p) and Path(p).is_file()
            if s == primary and primary_sig is not None:
                sig = primary_sig.clone()  # reuse window we already loaded
            elif exists:
                sig = load_audio(
                    p, offset=offset, duration=self.duration
                )  # windowed load
            else:
                sig = AudioSignal.zeros(
                    self.duration, self.sample_rate, self.num_channels
                )
            # Channel formatting
            if self.num_channels == 1:
                sig = sig.to_mono()
            elif self.num_channels != sig.num_channels:
                # Only mono -> multi-channel broadcast is supported; raise
                # explicitly instead of an `assert` stripped under `-O`.
                if sig.num_channels != 1:
                    raise ValueError(
                        f"Cannot broadcast {sig.num_channels}-channel audio "
                        f"to {self.num_channels} channels."
                    )
                sig.audio_data = sig.audio_data.repeat(1, self.num_channels, 1)
            # Resample/pad to target SR and exact duration
            sig = sig.resample(self.sample_rate)
            if sig.duration < self.duration:
                sig = sig.zero_pad_to(int(self.duration * self.sample_rate))
            # Metadata
            sig.metadata["path"] = p
            sig.metadata["offset"] = offset
            sig.metadata["source_row"] = ridx
            if "meta" in row:
                for k, v in row["meta"].items():
                    sig.metadata[k] = v
            item[s] = {"signal": sig, "path": p}
        item["idx"] = idx
        return item

    @staticmethod
    def collate(list_of_dicts: Union[list, dict], n_splits: int = None):
        """Collate a batch of `__getitem__` outputs (delegates to `util.collate`)."""
        return collate(list_of_dicts, n_splits=n_splits)