Spaces:
Sleeping
Sleeping
| import csv | |
| import os | |
| from pathlib import Path | |
| from typing import Callable, Dict, Tuple, Union, Optional, Any | |
| from rich.progress import track | |
| import numpy as np | |
| from audiotools.core.util import random_state | |
| from ..util import ensure_dir | |
# Type of the `split` argument: either fixed (train, val, test) proportions
# or a function mapping a file path to its split name ("train"/"val"/"test").
SplitType = Union[Tuple[float, float, float], Callable[[Path], str]]
def create_manifests(
    data_dir: Union[str, Path],
    ext: str,
    output_dir: Union[str, Path],
    split: SplitType,
    attributes: Dict[str, Callable[[Path], Any]],
    seed: Optional[int] = 0,
) -> Dict[str, Path]:
    """
    Create CSV manifests (train.csv, val.csv, test.csv) for an audio dataset.

    Parameters
    ----------
    data_dir : str or Path
        Dataset root directory to search recursively for files
    ext : str
        Audio file extension
    output_dir : str or Path
        Directory to which to write manifests
    split : SplitType
        Either a 3-tuple containing (train, val, test) proportions summing to 1
        or a Callable that returns "train", "val", or "test" given a filepath
    attributes : dict
        Dictionary mapping column names to Callables for extracting values
        given filepaths; for example {'path': lambda p: str(p)}
    seed : int, optional
        Random seed, used only when `split` is a proportions tuple

    Returns
    -------
    dict
        Mapping from split name ("train", "val", "test") to the path of the
        CSV manifest written for that split.

    Raises
    ------
    ValueError
        If a split callable returns an unknown split name, or if the split
        tuple is not a 3-tuple or its proportions do not sum to 1.
    """
    data_dir = Path(data_dir)
    output_dir = Path(output_dir)
    ensure_dir(output_dir)

    # Sort case-insensitively so the manifest order is deterministic across
    # platforms/filesystems before any seeded shuffling.
    all_files = sorted(
        [p for p in data_dir.rglob(f"*{ext}") if p.is_file()],
        key=lambda p: str(p).lower(),
    )

    splits: Dict[str, list] = {"train": [], "val": [], "test": []}
    if callable(split):
        # Callable split: apply given function to file paths to obtain
        # train/val/test assignments.
        for p in all_files:
            s = split(p)
            if s not in splits:
                raise ValueError(
                    f"Split function must return one of "
                    f"{list(splits.keys())}, got {s!r} for {p}"
                )
            splits[s].append(p)
    else:
        # Proportional split: shuffle deterministically, then carve the
        # permutation into contiguous train/val/test segments.
        if not (isinstance(split, tuple) and len(split) == 3):
            raise ValueError("Split proportions tuple must have length 3")
        p_train, p_val, p_test = split
        total = float(p_train + p_val + p_test)
        if not np.isclose(total, 1.0, atol=1e-6):
            raise ValueError(f"Split proportions must sum to 1.0 (got {total}).")
        rs = random_state(seed)
        # rs.permutation already returns an ndarray; no extra np.array needed.
        idx = rs.permutation(len(all_files))
        n = len(idx)
        n_train = int(np.floor(p_train * n))
        n_val = int(np.floor(p_val * n))
        # Rounding leftovers go to test via the open-ended final slice.
        splits["train"] = [all_files[int(i)] for i in idx[:n_train]]
        splits["val"] = [all_files[int(i)] for i in idx[n_train:n_train + n_val]]
        splits["test"] = [all_files[int(i)] for i in idx[n_train + n_val:]]

    columns = list(attributes.keys())

    # Write one CSV per split.
    out_paths: Dict[str, Path] = {}
    for s in ("train", "val", "test"):
        out_csv = output_dir / f"{s}.csv"
        out_paths[s] = out_csv
        with out_csv.open("w", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=columns)
            writer.writeheader()
            for p in track(
                splits[s],
                description=f"Writing {s}.csv",
                total=len(splits[s]),
            ):
                try:
                    # Build the whole row first so a failing attribute never
                    # leaves a partially written row in the CSV.
                    row = {col: fn(p) for col, fn in attributes.items()}
                    writer.writerow(row)
                except Exception as e:
                    # Best-effort: report and skip files whose attributes
                    # cannot be extracted, rather than aborting the manifest.
                    print(
                        f"Error at path {p}:\n"
                        f"{e}\n"
                        f"Skipping..."
                    )
    return out_paths