Spaces:
Sleeping
Sleeping
File size: 3,879 Bytes
c9f87fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 |
import csv
import os
from pathlib import Path
from typing import Callable, Dict, Tuple, Union, Optional, Any
from rich.progress import track
import numpy as np
from audiotools.core.util import random_state
from ..util import ensure_dir
SplitType = Union[Tuple[float, float, float], Callable[[Path], str]]
def create_manifests(
    data_dir: Union[str, Path],
    ext: str,
    output_dir: Union[str, Path],
    split: SplitType,
    attributes: Dict[str, Callable[[Path], Any]],
    seed: Optional[int] = 0,
) -> Dict[str, Path]:
    """
    Create CSV manifests (train/val/test) for an audio dataset.

    Parameters
    ----------
    data_dir : str or Path
        Dataset root directory to search recursively for files
    ext : str
        Audio file extension
    output_dir : str or Path
        Directory to which to write manifests
    split : SplitType
        Either a 3-tuple containing (train, val, test) proportions summing to 1
        or a Callable that returns "train", "val", or "test" given a filepath
    attributes : dict
        Dictionary mapping column names to Callables for extracting values
        given filepaths; for example {'path': lambda p: str(p)}
    seed : int, optional
        Random seed; used only for the proportional (tuple) split

    Returns
    -------
    dict
        Mapping from split name ("train", "val", "test") to the path of the
        CSV written for that split.

    Raises
    ------
    ValueError
        If a callable ``split`` returns an unknown split name, or a tuple
        ``split`` does not have length 3 or does not sum to 1.
    """
    data_dir = Path(data_dir)
    output_dir = Path(output_dir)
    ensure_dir(output_dir)

    # Case-insensitive sort keeps the file ordering deterministic across
    # platforms/filesystems, which keeps the seeded split reproducible.
    all_files = sorted(
        [p for p in data_dir.rglob(f"*{ext}") if p.is_file()],
        key=lambda p: str(p).lower(),
    )

    splits = {"train": [], "val": [], "test": []}
    if callable(split):
        # Callable split: apply given function to file paths to obtain
        # train/val/test assignments.
        for p in all_files:
            s = split(p)
            if s not in splits:
                raise ValueError(
                    f"Split function must return one of "
                    f"{list(splits.keys())}, got {s!r} for {p}"
                )
            splits[s].append(p)
    else:
        # Proportional split: randomly shuffle indices and carve out
        # contiguous chunks according to the given proportions.
        if not (isinstance(split, tuple) and len(split) == 3):
            raise ValueError(
                f"Split proportions tuple must have length 3 (got {split!r})."
            )
        p_train, p_val, p_test = split
        total = float(p_train + p_val + p_test)
        if not np.isclose(total, 1.0, atol=1e-6):
            raise ValueError(f"Split proportions must sum to 1.0 (got {total}).")
        rs = random_state(seed)
        # rs.permutation already returns an ndarray; no extra wrapping needed.
        idx = rs.permutation(len(all_files))
        n = len(idx)
        n_train = int(np.floor(p_train * n))
        n_val = int(np.floor(p_val * n))
        # Test takes every remaining index, so floor-rounding never drops a
        # file.
        for i in idx[:n_train]:
            splits["train"].append(all_files[int(i)])
        for i in idx[n_train:n_train + n_val]:
            splits["val"].append(all_files[int(i)])
        for i in idx[n_train + n_val:]:
            splits["test"].append(all_files[int(i)])

    columns = list(attributes.keys())

    # Write one CSV per split, one column per attribute extractor.
    out_paths: Dict[str, Path] = {}
    for s in ("train", "val", "test"):
        out_csv = output_dir / f"{s}.csv"
        out_paths[s] = out_csv
        with out_csv.open("w", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=columns)
            writer.writeheader()
            for p in track(
                splits[s],
                description=f"Writing {s}.csv",
                total=len(splits[s])
            ):
                # Best-effort per file: a failing attribute extractor skips
                # that file instead of aborting the whole manifest. Only the
                # extractor calls are guarded; a writerow failure on a
                # well-formed row would indicate a real bug and should raise.
                try:
                    row = {col: fn(p) for col, fn in attributes.items()}
                except Exception as e:
                    print(
                        f"Error at path {p}:\n"
                        f"{e}\n"
                        f"Skipping..."
                    )
                    continue
                writer.writerow(row)
    return out_paths
|