# tria/data/preprocess.py
# Author: saumyap29 — initial commit (c9f87fa)
import csv
import os
from pathlib import Path
from typing import Callable, Dict, Tuple, Union, Optional, Any
from rich.progress import track
import numpy as np
from audiotools.core.util import random_state
from ..util import ensure_dir
SplitType = Union[Tuple[float, float, float], Callable[[Path], str]]
def create_manifests(
    data_dir: Union[str, Path],
    ext: str,
    output_dir: Union[str, Path],
    split: SplitType,
    attributes: Dict[str, Callable[[Path], Any]],
    seed: Optional[int] = 0,
) -> Dict[str, Path]:
    """
    Create CSV manifests (train.csv, val.csv, test.csv) for an audio dataset.

    Parameters
    ----------
    data_dir : str or Path
        Dataset root directory to search recursively for files
    ext : str
        Audio file extension
    output_dir : str or Path
        Directory to which to write manifests
    split : SplitType
        Either a 3-tuple containing (train, val, test) proportions summing to 1
        or a Callable that returns "train", "val", or "test" given a filepath
    attributes : dict
        Dictionary mapping column names to Callables for extracting values
        given filepaths; for example {'path': lambda p: str(p)}
    seed : int, optional
        Random seed; only used for the proportional (tuple) split

    Returns
    -------
    dict
        Mapping from split name ("train", "val", "test") to the written
        CSV manifest path.

    Raises
    ------
    ValueError
        If a callable split returns an unknown split name, or if a tuple
        split is not a length-3 tuple summing to 1.0.
    """
    data_dir = Path(data_dir)
    output_dir = Path(output_dir)
    ensure_dir(output_dir)

    # Sort case-insensitively so the file ordering (and thus the split,
    # given a fixed seed) is deterministic across filesystems.
    all_files = sorted(
        [p for p in data_dir.rglob(f"*{ext}") if p.is_file()],
        key=lambda p: str(p).lower(),
    )
    splits = {"train": [], "val": [], "test": []}
    if callable(split):
        # Callable split: apply the given function to each file path to
        # obtain its train/val/test assignment.
        for p in all_files:
            s = split(p)
            if s not in splits:
                raise ValueError(
                    f"Split function must return one of "
                    f"{list(splits.keys())}, got {s!r} for {p}"
                )
            splits[s].append(p)
    else:
        # Proportional split: randomly permute files and slice according to
        # the given proportions.
        if not (isinstance(split, tuple) and len(split) == 3):
            raise ValueError("Split proportions tuple must have length 3")
        p_train, p_val, p_test = split
        total = float(p_train + p_val + p_test)
        if not np.isclose(total, 1.0, atol=1e-6):
            raise ValueError(f"Split proportions must sum to 1.0 (got {total}).")
        rs = random_state(seed)
        # permutation() already returns an ndarray; no extra wrapping needed.
        idx = rs.permutation(len(all_files))
        n = len(idx)
        n_train = int(np.floor(p_train * n))
        n_val = int(np.floor(p_val * n))
        # Test receives the remainder so no file is ever dropped by floor().
        splits["train"] = [all_files[int(i)] for i in idx[:n_train]]
        splits["val"] = [all_files[int(i)] for i in idx[n_train:n_train + n_val]]
        splits["test"] = [all_files[int(i)] for i in idx[n_train + n_val:]]

    columns = list(attributes.keys())
    # Write one CSV per split, with one column per attribute.
    out_paths: Dict[str, Path] = {}
    for s in ("train", "val", "test"):
        out_csv = output_dir / f"{s}.csv"
        out_paths[s] = out_csv
        # newline="" is required by the csv module; explicit utf-8 keeps
        # output platform-independent.
        with out_csv.open("w", newline="", encoding="utf-8") as f:
            writer = csv.DictWriter(f, fieldnames=columns)
            writer.writeheader()
            for p in track(splits[s], description=f"Writing {s}.csv"):
                try:
                    writer.writerow({col: fn(p) for col, fn in attributes.items()})
                except Exception as e:
                    # Best-effort: skip files whose attribute extraction
                    # fails rather than aborting the whole manifest.
                    print(
                        f"Error at path {p}:\n"
                        f"{e}\n"
                        f"Skipping..."
                    )
    return out_paths