|
|
|
|
|
"""
customized sampler

1. Block shuffler based on sequence length
   Like BinnedLengthSampler in https://github.com/fatchord/WaveRNN
   e.g., data length [1, 2, 3, 4, 5, 6] -> [3,1,2, 6,5,4] if block size = 3
"""
|
|
|
|
|
from __future__ import absolute_import |
|
|
|
|
|
import os |
|
|
import sys |
|
|
import numpy as np |
|
|
|
|
|
import torch |
|
|
import torch.utils.data |
|
|
import torch.utils.data.sampler as torch_sampler |
|
|
|
|
|
import core_scripts.math_tools.random_tools as nii_rand_tk |
|
|
import core_scripts.other_tools.display as nii_warn |
|
|
|
|
|
__author__ = "Xin Wang" |
|
|
__email__ = "wangxin@nii.ac.jp" |
|
|
__copyright__ = "Copyright 2021, Xin Wang" |
|
|
|
|
|
|
|
|
g_str_sampler_bsbl = 'block_shuffle_by_length' |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SamplerBlockShuffleByLen(torch_sampler.Sampler):
    """ Sampler with block shuffle based on sequence length

    The data indices are sorted by sequence length once, at construction.
    Every call to __iter__ takes a fresh copy of the sorted indices and
    passes it through two in-place shuffle helpers from
    core_scripts.math_tools.random_tools, using a block size of
    4 * batch_size.

    e.g., data length [1, 2, 3, 4, 5, 6] -> [3,1,2, 6,5,4] if block size =3
    """

    def __init__(self, buf_dataseq_length, batch_size):
        """ SamplerBlockShuffleByLength(buf_dataseq_length, batch_size)

        args
        ----
          buf_dataseq_length: list or np.array of int,
                              length of each data in a dataset
          batch_size: int, batch_size (must be larger than 1)
        """
        # The error message demands batch-size>1, so zero and negative
        # values are rejected too (they would give an unusable block size).
        if batch_size <= 1:
            mes = "Sampler block shuffle by length requires batch-size>1"
            nii_warn.f_die(mes)

        # A block spans four batches.
        self.m_block_size = batch_size * 4

        # Indices that sort the data by length (np.argsort order).
        self.m_idx = np.argsort(buf_dataseq_length)
        return

    def __iter__(self):
        """ Return a iterator to be iterated.

        A new list is built on every call, so the sorted order stored in
        self.m_idx is never modified by the in-place shuffles.
        """
        # list() already creates an independent list of the indices
        tmp_list = list(self.m_idx)

        # NOTE(review): judging by its name, this shuffles the indices
        # inside each block -- confirm in random_tools.
        nii_rand_tk.f_shuffle_in_block_inplace(tmp_list, self.m_block_size)

        # NOTE(review): judging by its name, this shuffles the order of
        # the blocks themselves -- confirm in random_tools.
        nii_rand_tk.f_shuffle_blocks_inplace(tmp_list, self.m_block_size)

        return iter(tmp_list)

    def __len__(self):
        """ Sampler requires __len__
        https://pytorch.org/docs/stable/data.html#torch.utils.data.Sampler

        Returns the number of data indices held by this sampler.
        """
        return len(self.m_idx)
|
|
|
|
|
if __name__ == "__main__":
    # Running this file directly only prints a short banner.
    print("Definition of customized_sampler")
|
|
|