Spaces:
Build error
Build error
| import csv | |
| from pathlib import Path | |
| from typing import Callable | |
| from typing import Dict | |
| from typing import List | |
| from typing import Optional | |
| from typing import Union | |
| import numpy as np | |
| import soundfile as sf | |
| from audiotools import AudioSignal | |
| from audiotools.core.util import random_state | |
| from torch.utils.data import Dataset | |
| from ..constants import DURATION | |
| from ..constants import SAMPLE_RATE | |
| from ..constants import STEMS | |
| from ..util import collate | |
| from ..util import get_info | |
| from ..util import load_audio | |
| from ..util import rms_salience | |
| ################################################################################ | |
| # Dataset for loading aligned excerpts across stem classes | |
| ################################################################################ | |
class StemDataset(Dataset):
    """
    Load aligned excerpts from specified stem classes given paths in one or more
    CSV manifests. Based on `audiotools.data.datasets.AudioDataset`.

    Every item shares a single excerpt offset across all requested stems, so
    the returned signals are time-aligned. The offset is chosen on the
    `salience_on` stem, either by `rms_salience` or uniformly at random.

    Parameters
    ----------
    stems : List[str]
        Column names to load, e.g. ["mixture","drums","bass","vocals"].
        The **first** stem is used for salience unless `salience_on` is set.
    sample_rate : int
        Every returned signal is resampled to this rate.
    duration : float
        Excerpt length in seconds; shorter excerpts are zero-padded.
    sources : Union[str, Path, List[Union[str, Path]]]
        CSV manifest(s) with columns for each requested stem. Required.
    source_weights : Optional[List[float]]
        One sampling weight per entry of `sources`. Weights of manifests that
        end up with no usable rows are dropped, negative weights are clipped
        to zero, and the remainder is normalized to sum to one. If no positive
        weight remains, the surviving manifests are weighted equally. If
        None, manifests are sampled uniformly.
    n_examples : int
        Value reported by `__len__`.
    num_channels : int
        Channel count of returned signals. With 1, signals are downmixed to
        mono; otherwise mono signals are repeated across channels.
    relative_path : str
        Prepended to relative CSV paths.
    strict : bool
        Drop rows with missing stems (True) vs. fill with silence (False).
        In strict mode rows with an unreadable stem, or whose shortest stem
        is shorter than `duration`, are dropped as well.
    with_replacement : bool
        Sampling strategy for rows. NOTE: the value is stored on the instance
        but `_pick_row` does not consult it; rows are always drawn
        independently for each index.
    shuffle_state : int
        Seed for deterministic per-index RNG.
    loudness_cutoff : Optional[float]
        Cutoff in dB handed to `rms_salience` as `cutoff_db`; if None, take a
        random excerpt (still shared across stems).
    salience_num_tries : int
        Max tries for the salient excerpt search, handed to `rms_salience`
        as `num_tries`. A value of 0 also selects the random excerpt.
    salience_on : Optional[str]
        Which stem to use for salience. Defaults to first of `stems`.

    Raises
    ------
    ValueError
        If `sources` is None, `stems` is empty, `salience_on` is not one of
        `stems`, or `source_weights` does not match `sources` in length.
    RuntimeError
        If no manifest has any usable row after filtering.
    """

    def __init__(
        self,
        stems: List[str] = STEMS,
        sample_rate: int = SAMPLE_RATE,
        duration: float = DURATION,
        sources: Union[str, Path, List[Union[str, Path]]] = None,
        source_weights: Optional[List[float]] = None,
        n_examples: int = 1000,
        num_channels: int = 1,
        relative_path: str = "",
        strict: bool = True,
        with_replacement: bool = True,
        shuffle_state: int = 0,
        loudness_cutoff: Optional[float] = -40.0,
        salience_num_tries: int = 8,
        salience_on: Optional[str] = None,
    ):
        super().__init__()

        # Explicit raises instead of `assert`, which is stripped under -O.
        if sources is None:
            raise ValueError("`sources` must be provided")
        if len(stems) < 1:
            raise ValueError("`stems` must contain at least one stem")

        self.stems = list(stems)
        self.sample_rate = int(sample_rate)
        self.duration = float(duration)
        self.num_channels = int(num_channels)
        self.relative_path = Path(relative_path)
        self.strict = strict
        self.with_replacement = with_replacement
        self.length = int(n_examples)
        self.shuffle_state = int(shuffle_state)
        self.loudness_cutoff = loudness_cutoff
        self.salience_num_tries = int(salience_num_tries)

        self.salience_on = salience_on or self.stems[0]
        if self.salience_on not in self.stems:
            raise ValueError(
                f"`salience_on` ('{self.salience_on}') must be one of {self.stems}"
            )

        # Read and filter manifests. `kept_mask` has one flag per entry of
        # `sources` so that `source_weights` can be aligned with survivors.
        csv_paths = [sources] if isinstance(sources, (str, Path)) else list(sources)
        self.source_rows: List[List[Dict]] = []
        kept_mask: List[bool] = []
        kept_csvs: List[Path] = []

        for cpath in csv_paths:
            cpath = Path(cpath)
            filtered = []
            for entry in self._read_manifest(cpath):
                min_dur = self._row_min_duration(entry["paths"])
                if min_dur is None:
                    continue
                entry["min_duration"] = min_dur
                filtered.append(entry)

            kept_mask.append(bool(filtered))
            if filtered:
                self.source_rows.append(filtered)
                kept_csvs.append(cpath)

        if not self.source_rows:
            raise RuntimeError(
                "StemDataset: no valid rows after filtering in any source."
            )

        self.csv_paths = kept_csvs
        lengths = [len(rows) for rows in self.source_rows]
        # Index of each source's first row in the flattened row numbering.
        self._source_offsets = np.cumsum([0] + lengths[:-1])
        self._n_rows = int(sum(lengths))

        # Weights over non-empty sources
        if source_weights is None:
            self._weights = None
        else:
            if len(source_weights) != len(csv_paths):
                raise ValueError(
                    f"source_weights must match number of sources ({len(csv_paths)}), "
                    f"got {len(source_weights)}"
                )
            w = np.asarray(source_weights, dtype=float)
            # Keep only weights for sources that survived filtering
            w = w[np.array(kept_mask, dtype=bool)]
            w = np.clip(w, 0, None)
            if not np.any(w > 0):
                # No usable weight left: fall back to uniform sampling.
                w = np.ones_like(w)
            self._weights = (w / w.sum()).tolist()

    def _read_manifest(self, cpath: Path) -> List[Dict]:
        """Parse one CSV manifest into row entries with resolved stem paths.

        Each entry holds the manifest path under "__manifest__", a
        stem -> path mapping under "paths" (empty string when the cell is
        blank or the column is absent) and, when the row has columns other
        than the requested stems, those columns under "meta".
        """
        entries: List[Dict] = []
        # newline="" is what the csv module asks for so that quoted fields
        # containing line breaks are parsed correctly.
        with open(cpath, "r", newline="") as f:
            for row in csv.DictReader(f):
                stem_paths = {}
                for s in self.stems:
                    raw = (row.get(s) or "").strip()
                    stem_paths[s] = str(self._resolve_path(raw)) if raw else ""
                entry = {"__manifest__": str(cpath), "paths": stem_paths}
                extra = {k: v for k, v in row.items() if k not in self.stems}
                if extra:
                    entry["meta"] = extra
                entries.append(entry)
        return entries

    def _row_min_duration(self, paths: Dict[str, str]) -> Optional[float]:
        """Return the shortest stem duration of a row, or None to drop the row.

        A row is dropped when no stem file can be measured. In strict mode it
        is also dropped when any stem is missing or unreadable, or when the
        shortest stem is shorter than `self.duration`.
        """
        # Stat each file once and reuse the result below.
        present = [p for p in paths.values() if p and Path(p).is_file()]
        if self.strict and len(present) < len(paths):
            return None

        min_dur = np.inf
        for p in present:
            try:
                min_dur = min(min_dur, float(sf.info(p).duration))
            except Exception:
                # Deliberately broad: any failure to inspect the file means
                # "unreadable". Strict mode rejects the row, otherwise the
                # stem is simply left out of the duration estimate.
                if self.strict:
                    return None

        if not np.isfinite(min_dur):
            # Nothing could be measured (no files, or none readable).
            return None
        if self.strict and min_dur < self.duration:
            return None
        return float(min_dur)

    def _resolve_path(self, p: Union[str, Path]) -> Path:
        """Expand `~` and anchor relative paths at `self.relative_path`."""
        p = Path(p).expanduser()
        if not p.is_absolute():
            p = (self.relative_path / p).expanduser()
        return p

    def _pick_row(self, state: np.random.RandomState):
        """Draw a source (by weight) and then a row within it.

        Returns
        -------
        tuple
            `(global_row_index, row)` where the index counts rows across all
            kept sources in order.
        """
        # Sample a non-empty source
        sidx = int(state.choice(len(self.source_rows), p=self._weights))
        rows = self.source_rows[sidx]
        item_idx = int(state.randint(len(rows)))
        # Map to a global idx for metadata
        ridx_global = int(self._source_offsets[sidx] + item_idx)
        return ridx_global, rows[item_idx]

    def __len__(self):
        """Return the configured number of examples (`n_examples`)."""
        return self.length

    def __getitem__(self, idx: int):
        """Load one aligned excerpt per stem for the given index.

        The RNG is seeded from `shuffle_state + idx`, so the same index
        always yields the same row and offset.

        Returns
        -------
        dict
            For every stem, `{"signal": AudioSignal, "path": str}`, plus the
            requested index under "idx".
        """
        state = random_state((self.shuffle_state + int(idx)) & 0x7FFFFFFF)
        ridx, row = self._pick_row(state)

        primary = self.salience_on
        p0 = row["paths"].get(primary, "")

        # The offset is decided on the primary stem and reused for the rest.
        # It stays 0.0 when the primary file is unavailable.
        offset = 0.0
        primary_sig = None
        if p0 and Path(p0).is_file():
            if self.loudness_cutoff is None or not self.salience_num_tries:
                # Random excerpt: uniform offset over the valid range.
                try:
                    total_sec, _sr = get_info(p0)
                except Exception:
                    # Best effort: unknown length falls back to offset 0.
                    total_sec = 0.0
                max_off = max(0.0, total_sec - self.duration)
                offset = float(state.rand() * max_off) if max_off > 0 else 0.0
            else:
                offset = rms_salience(
                    p0,
                    duration=self.duration,
                    cutoff_db=float(self.loudness_cutoff),
                    num_tries=int(self.salience_num_tries),
                    state=state,
                )
            primary_sig = load_audio(p0, offset=offset, duration=self.duration)

        item: Dict[str, Union[Dict, int]] = {}
        for s in self.stems:
            p = row["paths"][s]

            if s == primary and primary_sig is not None:
                sig = primary_sig.clone()  # reuse window we already loaded
            elif p and Path(p).is_file():
                sig = load_audio(p, offset=offset, duration=self.duration)
            else:
                # Missing stem (non-strict mode): substitute silence.
                sig = AudioSignal.zeros(
                    self.duration, self.sample_rate, self.num_channels
                )

            # Channel formatting
            if self.num_channels == 1:
                sig = sig.to_mono()
            elif self.num_channels != sig.num_channels:
                # Only mono sources are upmixed, by repeating the channel.
                assert sig.num_channels == 1
                sig.audio_data = sig.audio_data.repeat(1, self.num_channels, 1)

            # Resample/pad to target SR and exact duration
            sig = sig.resample(self.sample_rate)
            if sig.duration < self.duration:
                sig = sig.zero_pad_to(int(self.duration * self.sample_rate))

            # Metadata
            sig.metadata["path"] = p
            sig.metadata["offset"] = offset
            sig.metadata["source_row"] = ridx
            if "meta" in row:
                for k, v in row["meta"].items():
                    sig.metadata[k] = v

            item[s] = {"signal": sig, "path": p}

        item["idx"] = idx
        return item

    @staticmethod
    def collate(list_of_dicts: Union[list, dict], n_splits: int = None):
        """Collate dataset items by delegating to the `collate` helper.

        Declared static so it can be called on the class or on an instance
        (e.g. `collate_fn=dataset.collate`). Inside the body, `collate`
        resolves to the module-level function imported from `..util`, not to
        this method.
        """
        return collate(list_of_dicts, n_splits=n_splits)