| | import functools |
| | import random |
| | import unittest |
| |
|
| | import torch |
| |
|
| | from TTS.config.shared_configs import BaseDatasetConfig |
| | from TTS.tts.datasets import load_tts_samples |
| | from TTS.tts.utils.data import get_length_balancer_weights |
| | from TTS.tts.utils.languages import get_language_balancer_weights |
| | from TTS.tts.utils.speakers import get_speaker_balancer_weights |
| | from TTS.utils.samplers import BucketBatchSampler, PerfectBatchSampler |
| |
|
| | |
# Seed torch's global RNG before any torch sampler in this module is used.
torch.manual_seed(0)

# English dataset config pointing at the LJSpeech test fixture.
dataset_config_en = BaseDatasetConfig(
    formatter="ljspeech",
    path="tests/data/ljspeech",
    meta_file_train="metadata.csv",
    meta_file_val="metadata.csv",
    language="en",
)

# Same fixture and metadata files as above, but tagged with the "pt-br" language.
dataset_config_pt = BaseDatasetConfig(
    formatter="ljspeech",
    path="tests/data/ljspeech",
    meta_file_train="metadata.csv",
    meta_file_val="metadata.csv",
    language="pt-br",
)
| |
|
| | |
# The English config is listed twice; the language tests below rely on the
# resulting en / pt-br imbalance in the loaded samples.
train_samples, eval_samples = load_tts_samples(
    [dataset_config_en, dataset_config_en, dataset_config_pt],
    eval_split=True,
)

# Tag the first five training samples as one speaker and every later sample as
# a second speaker, giving the speaker tests two unevenly sized classes.
for i, sample in enumerate(train_samples):
    sample["speaker_name"] = "ljspeech-0" if i < 5 else "ljspeech-1"
| |
|
| |
|
def is_balanced(lang_1, lang_2):
    """Return whether two class counts are roughly equal.

    The counts are treated as balanced when the ratio ``lang_1 / lang_2`` lies
    strictly between 0.85 and 1.2.

    Args:
        lang_1: Count observed for the first class.
        lang_2: Count observed for the second class.

    Returns:
        bool: ``True`` when the ratio is inside the tolerance window. When
        ``lang_2`` is zero the counts are reported as unbalanced (``False``)
        instead of raising ``ZeroDivisionError``.
    """
    if lang_2 == 0:
        # The second class was never drawn, so no ratio exists; treat this as
        # unbalanced so callers get a clean boolean rather than a crash.
        return False
    return 0.85 < lang_1 / lang_2 < 1.2
| |
|
| |
|
class TestSamplers(unittest.TestCase):
    """Checks for the balancing and bucketing samplers, run against the module-level ``train_samples``."""

    def test_language_random_sampler(self):
        # Baseline: 100 passes of a plain random sampler, tallied by language.
        random_sampler = torch.utils.data.RandomSampler(train_samples)
        drawn = [idx for _ in range(100) for idx in random_sampler]
        en = sum(1 for idx in drawn if train_samples[idx]["language"] == "en")
        pt = len(drawn) - en

        assert not is_balanced(en, pt), "Random sampler is supposed to be unbalanced"

    def test_language_weighted_random_sampler(self):
        # Same tally as the baseline, but sampling with per-language balancing weights.
        language_weights = get_language_balancer_weights(train_samples)
        weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(language_weights, len(train_samples))
        drawn = [idx for _ in range(100) for idx in weighted_sampler]
        en = sum(1 for idx in drawn if train_samples[idx]["language"] == "en")
        pt = len(drawn) - en

        assert is_balanced(en, pt), "Language Weighted sampler is supposed to be balanced"

    def test_speaker_weighted_random_sampler(self):
        # Sample with per-speaker balancing weights and tally draws per speaker label.
        speaker_weights = get_speaker_balancer_weights(train_samples)
        weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(speaker_weights, len(train_samples))
        drawn = [idx for _ in range(100) for idx in weighted_sampler]
        spk1 = sum(1 for idx in drawn if train_samples[idx]["speaker_name"] == "ljspeech-0")
        spk2 = len(drawn) - spk1

        assert is_balanced(spk1, spk2), "Speaker Weighted sampler is supposed to be balanced"

    def test_perfect_sampler(self):
        # Unshuffled sampler with incomplete batches dropped.
        speaker_names = {item["speaker_name"] for item in train_samples}

        sampler = PerfectBatchSampler(
            train_samples,
            speaker_names,
            batch_size=2 * 3,
            num_classes_in_batch=2,
            label_key="speaker_name",
            shuffle=False,
            drop_last=True,
        )
        batches = [batch for _ in range(100) for batch in sampler]
        for batch in batches:
            # Every batch must hold the same number of items from each speaker.
            spk1 = sum(1 for idx in batch if train_samples[idx]["speaker_name"] == "ljspeech-0")
            spk2 = len(batch) - spk1
            assert spk1 == spk2, "PerfectBatchSampler is supposed to be perfectly balanced"

    def test_perfect_sampler_shuffle(self):
        # Shuffled sampler that keeps the final (possibly short) batch.
        speaker_names = {item["speaker_name"] for item in train_samples}

        sampler = PerfectBatchSampler(
            train_samples,
            speaker_names,
            batch_size=2 * 3,
            num_classes_in_batch=2,
            label_key="speaker_name",
            shuffle=True,
            drop_last=False,
        )
        batches = [batch for _ in range(100) for batch in sampler]
        for batch in batches:
            # Every batch must hold the same number of items from each speaker.
            spk1 = sum(1 for idx in batch if train_samples[idx]["speaker_name"] == "ljspeech-0")
            spk2 = len(batch) - spk1
            assert spk1 == spk2, "PerfectBatchSampler is supposed to be perfectly balanced"

    def test_length_weighted_random_sampler(self):
        for _ in range(1000):
            # Pick a short and a long base length for this round.
            min_audio = random.randrange(1, 22050)
            max_audio = random.randrange(44100, 220500)
            for position, item in enumerate(train_samples):
                # The jitter is drawn for every sample before the base length is chosen.
                random_increase = random.randrange(100, 1000)
                base_length = min_audio if position < 5 else max_audio
                item["audio_length"] = base_length + random_increase

            length_weights = get_length_balancer_weights(train_samples, num_buckets=2)
            weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(length_weights, len(train_samples))
            drawn = [idx for _ in range(100) for idx in weighted_sampler]
            # Anything shorter than max_audio belongs to the "short" group.
            len1 = sum(1 for idx in drawn if train_samples[idx]["audio_length"] < max_audio)
            len2 = len(drawn) - len1
            assert is_balanced(len1, len2), "Length Weighted sampler is supposed to be balanced"

    def test_bucket_batch_sampler(self):
        bucket_size_multiplier = 2
        index_source = range(len(train_samples))
        sampler = BucketBatchSampler(
            index_source,
            data=train_samples,
            batch_size=7,
            drop_last=True,
            sort_key=lambda x: len(x["text"]),
            bucket_size_multiplier=bucket_size_multiplier,
        )

        # Collect batches until a batch number divisible by the multiplier is reached,
        # then verify the collected indices are in non-decreasing text length.
        previous_text_len = 0
        pending_indices = []
        for batch_number, batch in enumerate(sampler, start=1):
            if batch_number % bucket_size_multiplier != 0:
                pending_indices.extend(batch)
                continue
            for sample_index in pending_indices:
                text_len = len(train_samples[sample_index]["text"])
                self.assertLessEqual(previous_text_len, text_len)
                previous_text_len = text_len
            # Start over for the next group of batches.
            previous_text_len = 0
            pending_indices = []

        # The sampler must report one batch per full group of 7 samples.
        self.assertEqual(len(sampler), len(train_samples) // 7)
| |
|