csun22's picture
Upload 59 files
ca1888b verified
#!/usr/bin/env python
"""
Customized samplers.

1. Block shuffler based on sequence length,
   similar to BinnedLengthSampler in https://github.com/fatchord/WaveRNN
   e.g., data length [1, 2, 3, 4, 5, 6] -> [3,1,2, 6,5,4] if block size = 3
"""
from __future__ import absolute_import
import os
import sys
import numpy as np
import torch
import torch.utils.data
import torch.utils.data.sampler as torch_sampler
import core_scripts.math_tools.random_tools as nii_rand_tk
import core_scripts.other_tools.display as nii_warn
__author__ = "Xin Wang"
__email__ = "wangxin@nii.ac.jp"
__copyright__ = "Copyright 2021, Xin Wang"
# Name (string identifier) of the block-shuffle-by-length sampler
# defined in this module
g_str_sampler_bsbl = 'block_shuffle_by_length'
###############################################
# Sampler definition
###############################################
class SamplerBlockShuffleByLen(torch_sampler.Sampler):
    """ Sampler with block shuffle based on sequence length

    Data indices are sorted by sequence length once, at construction.
    Every time the sampler is iterated, the sorted indices are shuffled
    inside each block, and then the blocks themselves are shuffled.

    e.g., data length [1, 2, 3, 4, 5, 6] -> [3,1,2, 6,5,4] if block size = 3
    """
    def __init__(self, buf_dataseq_length, batch_size):
        """ SamplerBlockShuffleByLen(buf_dataseq_length, batch_size)

        args
        ----
          buf_dataseq_length: list or np.array of int,
                              length of each data in a dataset
          batch_size: int, batch_size, must be larger than 1
        """
        # batch_size == 1 is rejected by design; batch_size < 1 is rejected
        # too because it would lead to a non-positive block size below
        if batch_size <= 1:
            mes = "Sampler block shuffle by length requires batch-size>1"
            nii_warn.f_die(mes)

        # hyper-parameter, just let block_size = batch_size * 4
        self.m_block_size = batch_size * 4
        # idx sorted based on sequence length (shortest sequence first)
        self.m_idx = np.argsort(buf_dataseq_length)
        return

    def __iter__(self):
        """ Return an iterator over the block-shuffled data indices.

        A new shuffled order is produced on every call; self.m_idx itself
        is never modified.
        """
        # list() builds a new list, so the in-place shuffles below
        # do not touch self.m_idx
        tmp_list = list(self.m_idx)

        # shuffle within each block
        # e.g., [1,2,3,4,5,6], block_size=3 -> [3,1,2,5,4,6]
        nii_rand_tk.f_shuffle_in_block_inplace(tmp_list, self.m_block_size)

        # shuffle blocks
        # e.g., [3,1,2,5,4,6], block_size=3 -> [5,4,6,3,1,2]
        nii_rand_tk.f_shuffle_blocks_inplace(tmp_list, self.m_block_size)

        # return an iterator, a list is iterable but not an iterator
        # https://www.programiz.com/python-programming/iterator
        return iter(tmp_list)

    def __len__(self):
        """ Return the number of data indices held by the sampler.

        Sampler requires __len__
        https://pytorch.org/docs/stable/data.html#torch.utils.data.Sampler
        """
        return len(self.m_idx)
if __name__ == "__main__":
    # Running this module as a script only prints a short description
    print("Definition of customized_sampler")