File size: 4,073 Bytes
79cf5f5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
import csv
import random
from collections import defaultdict
from pathlib import Path
import click
import yaml
# noinspection PyShadowingBuiltins
@click.command(help='Randomly select test samples')
@click.argument(
    'config',
    type=click.Path(file_okay=True, dir_okay=False, resolve_path=True, writable=True, path_type=Path),
    metavar="CONFIG"
)
@click.option(
    '--rel_path',
    type=click.Path(file_okay=False, dir_okay=True, resolve_path=True, path_type=Path),
    default=None,
    help='Path that is relative to the paths mentioned in the config file.'
)
@click.option(
    '--min', '_min',
    show_default=True,
    type=click.IntRange(min=1),
    default=10,
    help='Minimum number of test samples.'
)
@click.option(
    '--max', '_max',
    show_default=True,
    type=click.IntRange(min=1),
    default=20,
    help='Maximum number of test samples (note that each speaker will have at least one test sample).'
)
@click.option(
    '--per_speaker',
    show_default=True,
    type=click.IntRange(min=1),
    default=4,
    help='Expected number of test samples per speaker.'
)
def select_test_set(config, rel_path, _min, _max, per_speaker):
    """Randomly select test sample prefixes from the datasets listed in CONFIG.

    Reads the YAML config, groups raw datasets by speaker ID, collects every
    transcription row whose wav file exists, then samples a roughly even
    number of prefixes per speaker (bounded by --min/--max) and writes the
    result back into the config under ``test_prefixes`` after confirmation.

    :param config: path to the YAML config file (read and rewritten in place)
    :param rel_path: optional base directory prepended to each raw data dir
    :param _min: lower bound on the total number of test samples
    :param _max: upper bound on the total number of test samples
    :param per_speaker: target number of test samples per speaker
    """
    assert _min <= _max, 'min must be smaller or equal to max'
    with open(config, 'r', encoding='utf8') as f:
        hparams = yaml.safe_load(f)
    # Use .get() so configs that omit 'spk_ids' fall back to sequential IDs
    # instead of raising KeyError.
    spk_ids = hparams.get('spk_ids')
    speakers = hparams['speakers']
    raw_data_dirs = list(map(Path, hparams['raw_data_dir']))
    assert isinstance(speakers, list), 'Speakers must be a list'
    assert len(speakers) == len(raw_data_dirs), \
        'Number of raw data dirs must equal number of speaker names!'
    if not spk_ids:
        spk_ids = list(range(len(raw_data_dirs)))
    else:
        assert len(spk_ids) == len(raw_data_dirs), \
            'Length of explicitly given spk_ids must equal the number of raw datasets.'
        assert max(spk_ids) < hparams['num_spk'], \
            f'Index in spk_id sequence {spk_ids} is out of range. All values should be smaller than num_spk.'

    # Map speaker names to IDs, ensuring one name is never given two different IDs,
    # and group raw data dirs by speaker ID (a speaker may span several datasets).
    spk_map = {}
    path_spk_map = defaultdict(list)
    for ds_id, (spk_name, raw_path, spk_id) in enumerate(zip(speakers, raw_data_dirs, spk_ids)):
        if spk_name in spk_map and spk_map[spk_name] != spk_id:
            raise ValueError(f'Invalid speaker ID assignment. Name \'{spk_name}\' is assigned '
                             f'with different speaker IDs: {spk_map[spk_name]} and {spk_id}.')
        spk_map[spk_name] = spk_id
        path_spk_map[spk_id].append((ds_id, rel_path / raw_path if rel_path else raw_path))

    training_cases = []
    for spk_raw_dirs in path_spk_map.values():
        training_case = []
        # training cases from the same speaker are grouped together
        for ds_id, raw_data_dir in spk_raw_dirs:
            with open(raw_data_dir / 'transcriptions.csv', 'r', encoding='utf8') as f:
                reader = csv.DictReader(f)
                for row in reader:
                    # Only keep entries whose wav file is actually present on disk.
                    if (raw_data_dir / 'wavs' / f'{row["name"]}.wav').exists():
                        training_case.append(f'{ds_id}:{row["name"]}')
        training_cases.append(training_case)
    # Guard against an empty config / empty datasets, which would otherwise
    # cause a ZeroDivisionError below.
    assert training_cases and any(training_cases), \
        'No training samples found. Check raw_data_dir and transcriptions.csv paths.'

    test_prefixes = []
    total = min(_max, max(_min, per_speaker * len(training_cases)))
    quotient, remainder = total // len(training_cases), total % len(training_cases)
    if quotient == 0:
        # More speakers than the total budget: still give each speaker one sample.
        test_counts = [1] * len(training_cases)
    else:
        # Distribute 'total' as evenly as possible; the first 'remainder'
        # speakers get one extra sample.
        test_counts = [quotient + 1] * remainder + [quotient] * (len(training_cases) - remainder)
    for training_case, count in zip(training_cases, test_counts):
        # Clamp to the available samples so random.sample() cannot raise
        # ValueError for speakers with fewer samples than their quota.
        test_prefixes += sorted(random.sample(training_case, min(count, len(training_case))))

    # Only overwrite existing prefixes with explicit user confirmation.
    if not hparams.get('test_prefixes') or click.confirm('Overwrite existing test prefixes?', abort=False):
        hparams['test_prefixes'] = test_prefixes
        hparams['num_valid_plots'] = len(test_prefixes)
        with open(config, 'w', encoding='utf8') as f:
            yaml.dump(hparams, f, sort_keys=False)
        print('Test prefixes saved.')
    else:
        print('Test prefixes not saved, aborted.')
# CLI entry point: click supplies the arguments when the command is invoked.
if __name__ == '__main__':
    select_test_set()
|