File size: 3,879 Bytes
c9f87fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import csv
import os
from pathlib import Path
from typing import Callable, Dict, Tuple, Union, Optional, Any
from rich.progress import track

import numpy as np

from audiotools.core.util import random_state
from ..util import ensure_dir

SplitType = Union[Tuple[float, float, float], Callable[[Path], str]]


def create_manifests(
    data_dir: Union[str, Path],
    ext: str,
    output_dir: Union[str, Path],
    split: SplitType,
    attributes: Dict[str, Callable[[Path], Any]],
    seed: Optional[int] = 0,
) -> Dict[str, Path]:
    """
    Create CSV manifests for an audio dataset.

    Recursively collects files matching ``ext`` under ``data_dir``, assigns
    each file to a train/val/test split, and writes one CSV per split
    (``train.csv``, ``val.csv``, ``test.csv``) to ``output_dir``.

    Parameters
    ----------
    data_dir : str or Path
        Dataset root directory to search recursively for files.
    ext : str
        Audio file extension.
    output_dir : str or Path
        Directory to which to write manifests.
    split : SplitType
        Either a 3-tuple containing (train, val, test) proportions summing to 1
        or a Callable that returns "train", "val", or "test" given a filepath.
    attributes : dict
        Dictionary mapping column names to Callables for extracting values
        given filepaths; for example {'path': lambda p: str(p)}.
    seed : int, optional
        Random seed for the proportional split. Unused when ``split`` is a
        Callable.

    Returns
    -------
    dict
        Mapping from split name ("train", "val", "test") to the Path of the
        CSV that was written for it.

    Raises
    ------
    ValueError
        If a split Callable returns an unrecognized split name, or if the
        proportions tuple is not a 3-tuple or does not sum to 1.
    """
    data_dir = Path(data_dir)
    output_dir = Path(output_dir)
    ensure_dir(output_dir)

    # Case-insensitive sort makes manifest ordering deterministic across
    # filesystems with differing directory-listing order.
    all_files = sorted(
        (p for p in data_dir.rglob(f"*{ext}") if p.is_file()),
        key=lambda p: str(p).lower(),
    )

    splits: Dict[str, list] = {"train": [], "val": [], "test": []}

    if callable(split):
        # Callable split: the user-supplied function decides each file's
        # train/val/test assignment.
        for p in all_files:
            s = split(p)
            if s not in splits:
                raise ValueError(
                    f"Split function must return one of "
                    f"{list(splits.keys())}, got {s!r} for {p}"
                )
            splits[s].append(p)
    else:
        # Proportional split: deterministically shuffle, then partition by
        # the given (train, val, test) proportions.
        if not (isinstance(split, tuple) and len(split) == 3):
            raise ValueError("Split proportions tuple must have length 3")
        p_train, p_val, p_test = split
        total = float(p_train + p_val + p_test)
        if not np.isclose(total, 1.0, atol=1e-6):
            raise ValueError(f"Split proportions must sum to 1.0 (got {total}).")

        rs = random_state(seed)
        idx = rs.permutation(len(all_files))  # already an ndarray
        n = len(idx)
        n_train = int(np.floor(p_train * n))
        n_val = int(np.floor(p_val * n))

        # Test receives the remainder, so every file is assigned exactly once.
        splits["train"] = [all_files[int(i)] for i in idx[:n_train]]
        splits["val"] = [all_files[int(i)] for i in idx[n_train:n_train + n_val]]
        splits["test"] = [all_files[int(i)] for i in idx[n_train + n_val:]]

    columns = list(attributes.keys())

    # Write one CSV per split.
    out_paths: Dict[str, Path] = {}
    for s in ("train", "val", "test"):
        out_csv = output_dir / f"{s}.csv"
        out_paths[s] = out_csv

        with out_csv.open("w", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=columns)
            writer.writeheader()

            for p in track(splits[s], description=f"Writing {s}.csv"):
                # Best-effort: skip files whose attribute extraction fails
                # rather than aborting the whole manifest.
                try:
                    writer.writerow({col: fn(p) for col, fn in attributes.items()})
                except Exception as e:
                    print(
                        f"Error at path {p}:\n"
                        f"{e}\n"
                        f"Skipping..."
                    )

    return out_paths