File size: 4,073 Bytes
79cf5f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import csv
import random
from collections import defaultdict
from pathlib import Path

import click
import yaml


def _normalize_raw_data_dirs(raw_data_dir):
    """Return the configured raw data dir(s) as a list of ``Path`` objects.

    A single string/Path is wrapped in a one-element list; iterating it directly
    (as ``map(Path, ...)`` would) splits a string into one path per character.
    """
    if isinstance(raw_data_dir, (str, Path)):
        raw_data_dir = [raw_data_dir]
    return [Path(d) for d in raw_data_dir]


def _group_dirs_by_speaker(hparams, rel_path):
    """Validate the speaker settings in ``hparams`` and group dataset dirs by speaker ID.

    Reads ``speakers``, ``raw_data_dir``, optional ``spk_ids`` and ``num_spk``.
    If ``spk_ids`` is missing or empty, dataset ``i`` gets speaker ID ``i``.
    Relative dataset paths are joined onto ``rel_path`` when it is given.

    Returns:
        dict mapping speaker ID -> list of ``(ds_id, raw_data_dir)`` tuples, where
        ``ds_id`` is the dataset's index in ``raw_data_dir``.

    Raises:
        ValueError: if the speaker / dir / ID lists are inconsistent or out of range,
            or if one speaker name is mapped to two different IDs.
    """
    speakers = hparams['speakers']
    raw_data_dirs = _normalize_raw_data_dirs(hparams['raw_data_dir'])
    spk_ids = hparams.get('spk_ids')
    if not isinstance(speakers, list):
        raise ValueError('Speakers must be a list')
    if not raw_data_dirs:
        raise ValueError('At least one raw data dir must be configured.')
    if len(speakers) != len(raw_data_dirs):
        raise ValueError('Number of raw data dirs must equal number of speaker names!')
    if not spk_ids:
        spk_ids = list(range(len(raw_data_dirs)))
    elif len(spk_ids) != len(raw_data_dirs):
        raise ValueError('Length of explicitly given spk_ids must equal the number of raw datasets.')
    if max(spk_ids) >= hparams['num_spk']:
        raise ValueError(
            f'Index in spk_id sequence {spk_ids} is out of range. All values should be smaller than num_spk.'
        )

    spk_map = {}
    path_spk_map = defaultdict(list)
    for ds_id, (spk_name, raw_path, spk_id) in enumerate(zip(speakers, raw_data_dirs, spk_ids)):
        if spk_name in spk_map and spk_map[spk_name] != spk_id:
            raise ValueError(f'Invalid speaker ID assignment. Name \'{spk_name}\' is assigned '
                             f'with different speaker IDs: {spk_map[spk_name]} and {spk_id}.')
        spk_map[spk_name] = spk_id
        path_spk_map[spk_id].append((ds_id, rel_path / raw_path if rel_path else raw_path))
    return path_spk_map


def _collect_training_cases(path_spk_map):
    """Return, per speaker (in ``path_spk_map`` order), the ``'<ds_id>:<name>'`` items with audio.

    Items are read from each dataset's ``transcriptions.csv`` (column ``name``) and kept
    only if ``wavs/<name>.wav`` exists in that dataset dir. Datasets that share a speaker
    ID are merged into one list, so test samples are allocated per speaker.
    """
    training_cases = []
    for spk_raw_dirs in path_spk_map.values():
        cases = []
        for ds_id, raw_data_dir in spk_raw_dirs:
            wav_dir = raw_data_dir / 'wavs'
            # newline='' is what the csv module expects, so quoted fields containing
            # line breaks are parsed correctly.
            with open(raw_data_dir / 'transcriptions.csv', 'r', encoding='utf8', newline='') as f:
                for row in csv.DictReader(f):
                    if (wav_dir / f'{row["name"]}.wav').exists():
                        cases.append(f'{ds_id}:{row["name"]}')
        training_cases.append(cases)
    return training_cases


def _allocate_test_counts(num_speakers, total):
    """Split ``total`` test samples as evenly as possible over ``num_speakers`` speakers.

    The first ``total % num_speakers`` speakers get one extra sample. If ``total`` is
    smaller than ``num_speakers``, every speaker still gets one sample, so the sum of
    the result may exceed ``total``.
    """
    quotient, remainder = divmod(total, num_speakers)
    if quotient == 0:
        return [1] * num_speakers
    return [quotient + 1] * remainder + [quotient] * (num_speakers - remainder)


@click.command(help='Randomly select test samples')
@click.argument(
    'config',
    type=click.Path(
        exists=True, file_okay=True, dir_okay=False, resolve_path=True, writable=True, path_type=Path
    ),
    metavar="CONFIG"
)
@click.option(
    '--rel_path',
    type=click.Path(file_okay=False, dir_okay=True, resolve_path=True, path_type=Path),
    default=None,
    help='Path that is relative to the paths mentioned in the config file.'
)
@click.option(
    '--min', '_min',
    show_default=True,
    type=click.IntRange(min=1),
    default=10,
    help='Minimum number of test samples.'
)
@click.option(
    '--max', '_max',
    show_default=True,
    type=click.IntRange(min=1),
    default=20,
    help='Maximum number of test samples (note that each speaker will have at least one test sample).'
)
@click.option(
    '--per_speaker',
    show_default=True,
    type=click.IntRange(min=1),
    default=4,
    help='Expected number of test samples per speaker.'
)
def select_test_set(config, rel_path, _min, _max, per_speaker):
    """Randomly pick test samples per speaker and store them in CONFIG as ``test_prefixes``.

    ``num_valid_plots`` is set to the number of selected samples. If CONFIG already has
    non-empty ``test_prefixes``, the user is asked before they are overwritten.
    A speaker with fewer usable samples than allocated contributes all of them
    (with a warning on stderr) instead of crashing ``random.sample``.
    """
    # Explicit check instead of assert: asserts are stripped when running with -O.
    if _min > _max:
        raise click.BadParameter('min must be smaller or equal to max', param_hint="'--min' / '--max'")
    with open(config, 'r', encoding='utf8') as f:
        hparams = yaml.safe_load(f)

    path_spk_map = _group_dirs_by_speaker(hparams, rel_path)
    training_cases = _collect_training_cases(path_spk_map)

    total = min(_max, max(_min, per_speaker * len(training_cases)))
    test_counts = _allocate_test_counts(len(training_cases), total)

    test_prefixes = []
    for spk_id, cases, count in zip(path_spk_map, training_cases, test_counts):
        if count > len(cases):
            # random.sample raises ValueError when asked for more items than exist.
            click.echo(f'Warning: speaker {spk_id} has only {len(cases)} usable sample(s); '
                       f'{count} requested.', err=True)
            count = len(cases)
        test_prefixes += sorted(random.sample(cases, count))
    if not test_prefixes:
        raise click.ClickException('No usable samples found; test prefixes not saved.')

    if not hparams.get('test_prefixes') or click.confirm('Overwrite existing test prefixes?', abort=False):
        hparams['test_prefixes'] = test_prefixes
        hparams['num_valid_plots'] = len(test_prefixes)
        with open(config, 'w', encoding='utf8') as f:
            # allow_unicode keeps non-ASCII values (e.g. speaker names) readable in the
            # UTF-8 file instead of writing them as \u escapes.
            yaml.dump(hparams, f, sort_keys=False, allow_unicode=True)
        click.echo('Test prefixes saved.')
    else:
        click.echo('Test prefixes not saved, aborted.')

# Run the click command when this file is executed as a script.
if __name__ == '__main__':
    select_test_set()