|
|
|
|
|
""" |
|
|
data_io |
|
|
|
|
|
Interface to load data |
|
|
|
|
|
""" |
|
|
from __future__ import absolute_import |
|
|
|
|
|
import os |
|
|
import sys |
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.utils.data |
|
|
|
|
|
import core_scripts.other_tools.list_tools as nii_list_tools |
|
|
import core_scripts.other_tools.display as nii_warn |
|
|
import core_scripts.other_tools.str_tools as nii_str_tk |
|
|
import core_scripts.data_io.io_tools as nii_io_tk |
|
|
import core_scripts.data_io.wav_tools as nii_wav_tk |
|
|
import core_scripts.data_io.text_process.text_io as nii_text_tk |
|
|
import core_scripts.data_io.conf as nii_dconf |
|
|
|
|
|
import core_scripts.data_io.seq_info as nii_seqinfo |
|
|
import core_scripts.math_tools.stats as nii_stats |
|
|
import core_scripts.data_io.customize_collate_fn as nii_collate_fn |
|
|
import core_scripts.data_io.customize_sampler as nii_sampler_fn |
|
|
|
|
|
# Module authorship / copyright metadata
__author__ = "Xin Wang"
__email__ = "wangxin@nii.ac.jp"
__copyright__ = "Copyright 2020, Xin Wang"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _data_reader(file_path, dim, flag_lang):
    """Load one data file, choosing the reader from its extension.

    Args:
      file_path: path of the file to load
      dim: feature dimension, only used for raw binary matrices
      flag_lang: language flag forwarded to the text loader (.txt only)

    Returns:
      the loaded data. For .wav / .flac files only the waveform is
      returned; the sampling rate reported by the reader is discarded.
    """
    ext = os.path.splitext(file_path)[1]
    if ext == '.txt':
        return nii_text_tk.textloader(file_path, flag_lang)
    if ext == '.wav':
        _sr, waveform = nii_wav_tk.waveReadAsFloat(file_path)
        return waveform
    if ext == '.flac':
        _sr, waveform = nii_wav_tk.flacReadAsFloat(file_path)
        return waveform
    # anything else is treated as a raw binary matrix
    return nii_io_tk.f_read_raw_mat(file_path, dim)
|
|
|
|
|
def _data_writer(data, file_path, sr = 16000):
    """Write data to file_path, choosing the writer from its extension.

    Args:
      data: data to write
      file_path: destination path
      sr: sampling rate, only used when writing a .wav file

    .wav files are written with nii_wav_tk.waveFloatToPCMFile; writing
    .txt files is not supported (aborts via nii_warn.f_die); any other
    extension is written as a raw binary matrix.
    """
    ext = os.path.splitext(file_path)[1]
    if ext == '.txt':
        nii_warn.f_die("Cannot write to %s" % (file_path))
    elif ext == '.wav':
        nii_wav_tk.waveFloatToPCMFile(data, file_path, sr = sr)
    else:
        nii_io_tk.f_write_raw_mat(data, file_path)
    return
|
|
|
|
|
def _data_len_reader(file_path): |
|
|
""" A wrapper to read length of data |
|
|
""" |
|
|
file_name, file_ext = os.path.splitext(file_path) |
|
|
if file_ext == '.wav': |
|
|
sr, data = nii_wav_tk.waveReadAsFloat(file_path) |
|
|
length = data.shape[0] |
|
|
elif file_ext == '.flac': |
|
|
sr, data = nii_wav_tk.flacReadAsFloat(file_path) |
|
|
length = data.shape[0] |
|
|
elif file_ext == '.txt': |
|
|
|
|
|
|
|
|
length = 0 |
|
|
else: |
|
|
length = nii_io_tk.f_read_raw_mat_length(file_path) |
|
|
return length |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class NIIDataSet(torch.utils.data.Dataset): |
|
|
""" General class for NII speech dataset |
|
|
For definition of customized Dataset, please refer to |
|
|
https://pytorch.org/tutorials/beginner/data_loading_tutorial.html |
|
|
""" |
|
|
def __init__(self,
             dataset_name, \
             file_list, \
             input_dirs, input_exts, input_dims, input_reso, \
             input_norm, \
             output_dirs, output_exts, output_dims, output_reso, \
             output_norm, \
             stats_path, \
             data_format = nii_dconf.h_dtype_str, \
             truncate_seq = None, \
             min_seq_len = None, \
             save_mean_std = True, \
             wav_samp_rate = None, \
             flag_lang = 'EN', \
             global_arg = None):
    """
    args
    ----
      dataset_name: name of this data set
      file_list: a list of file name strings (without extension)
                 or, path to the file that contains the file names
      input_dirs: a list of dirs from which input feature is loaded
      input_exts: a list of input feature name extensions
      input_dims: a list of input feature dimensions
      input_reso: a list of input feature temporal resolutions
                  (None: resolution 1 for every input)
      input_norm: a list of bool, whether normalize input feature or not
                  (None: True for every input)
      output_dirs: a list of dirs from which output feature is loaded
                  (may be empty, e.g. for test data)
      output_exts: a list of output feature name extensions
      output_dims: a list of output feature dimensions
      output_reso: a list of output feature temporal resolutions
                  (None: resolution 1 for every output)
      output_norm: a list of bool, whether normalize target feature or not
                  (None: True for every output)
      stats_path: path to the directory that saves mean/std,
                 utterance length
      data_format: method to load the data. Only nii_dconf.h_dtype_str
                 (the default) is supported; any other value aborts.
      truncate_seq: None (default) or int, truncate sequence into chunks.
                 truncate_seq > 0 specifies the chunk length
      min_seq_len: None (default) or int, minimum length of an utterance
                 utterance shorter than min_seq_len will be ignored
      save_mean_std: bool, True (default): save mean and std
      wav_samp_rate: None (default) or int, if input data has waveform,
                 please set sampling rate. It is used by _data_writer
      flag_lang: str, 'EN' (default), if input data has text, the text will
                 be converted into code indices. flag_lang indicates the
                 language for the text processer. It is used by _data_reader
      global_arg: argument parser returned by arg_parse.f_args_parsed()
                 default None
    """
    # basic configuration
    self.m_set_name = dataset_name
    self.m_file_list = file_list
    self.m_input_dirs = input_dirs
    self.m_input_exts = input_exts
    self.m_input_dims = input_dims

    self.m_output_dirs = output_dirs
    self.m_output_exts = output_exts
    self.m_output_dims = output_dims

    if len(self.m_input_dirs) != len(self.m_input_exts) or \
       len(self.m_input_dirs) != len(self.m_input_dims):
        nii_warn.f_print("Input dirs, exts, dims, unequal length",
                         'error')
        nii_warn.f_print(str(self.m_input_dirs), 'error')
        nii_warn.f_print(str(self.m_input_exts), 'error')
        nii_warn.f_print(str(self.m_input_dims), 'error')
        nii_warn.f_die("Please check input dirs, exts, dims")

    # output_dirs may be empty (no target data), exts/dims must still match
    if len(self.m_output_dims) != len(self.m_output_exts) or \
       (self.m_output_dirs and \
        len(self.m_output_dirs) != len(self.m_output_exts)):
        nii_warn.f_print("Output dirs, exts, dims, unequal length", \
                         'error')
        nii_warn.f_die("Please check output dirs, exts, dims")

    # replace a None configuration by a list of default values
    def _tmp_f(list2, default_value, length):
        if list2 is None:
            return [default_value for x in range(length)]
        else:
            return list2

    self.m_input_reso = _tmp_f(input_reso, 1, len(input_dims))
    self.m_input_norm = _tmp_f(input_norm, True, len(input_dims))
    self.m_output_reso = _tmp_f(output_reso, 1, len(output_dims))
    self.m_output_norm = _tmp_f(output_norm, True, len(output_dims))
    if len(self.m_input_reso) != len(self.m_input_dims):
        nii_warn.f_die("len(input_reso) != len(input_dims) in config")
    if len(self.m_output_reso) != len(self.m_output_dims):
        nii_warn.f_die("len(output_reso) != len(output_dims) in config")
    if len(self.m_input_norm) != len(self.m_input_dims):
        nii_warn.f_die("len(input_norm) != len(input_dims) in config")
    if len(self.m_output_norm) != len(self.m_output_dims):
        nii_warn.f_die("len(output_norm) != len(output_dims) in config")

    # total feature dimensions
    self.m_input_all_dim = sum(self.m_input_dims)
    self.m_output_all_dim = sum(self.m_output_dims)
    self.m_io_dim = self.m_input_all_dim + self.m_output_all_dim

    self.m_truncate_seq = truncate_seq
    self.m_min_seq_len = min_seq_len
    self.m_save_ms = save_mean_std

    # sampling rate used when writing waveforms
    self.m_wav_sr = wav_samp_rate

    # waveform silence handler option (0: disabled)
    if global_arg is not None:
        self.m_opt_wav_handler = global_arg.opt_wav_silence_handler
    else:
        self.m_opt_wav_handler = 0

    # language flag for the text loader
    self.m_flag_lang = flag_lang

    # all inputs must share one resolution, and so must all outputs
    if any([x != self.m_input_reso[0] for x in self.m_input_reso]):
        nii_warn.f_print("input_reso: %s" % (str(self.m_input_reso)),\
                         'error')
        nii_warn.f_print("NIIDataSet not support", 'error', end='')
        nii_warn.f_die(" different input_reso")

    if any([x != self.m_output_reso[0] for x in self.m_output_reso]):
        nii_warn.f_print("output_reso: %s" % (str(self.m_output_reso)),\
                         'error')
        nii_warn.f_print("NIIDataSet not support", 'error', end='')
        nii_warn.f_die(" different output_reso")
    if np.any(np.array(self.m_output_reso) < 0):
        nii_warn.f_print("NIIDataSet not support negative reso",
                         'error', end='')
        nii_warn.f_die(" Output reso: %s" % (str(self.m_output_reso)))
    # negative input resolution: input and output are not aligned,
    # so truncation is disabled
    if np.any(np.array(self.m_input_reso) < 0):
        nii_warn.f_print("Input resolution: %s" % (str(self.m_input_reso)))
        nii_warn.f_print("Data IO for unaligned input and output pairs")
        if truncate_seq is not None:
            nii_warn.f_print("truncate is set to None", 'warning')
            self.m_truncate_seq = None
            self.m_min_seq_len = None

    # largest resolution over all features; lengths are rounded down
    # to a multiple of it (see f_adjust_len)
    self.m_single_reso = np.max(self.m_input_reso + self.m_output_reso)

    if self.m_truncate_seq is not None:
        self.m_truncate_seq = self.f_adjust_len(self.m_truncate_seq)

    if self.m_min_seq_len is not None:
        self.m_min_seq_len = self.f_adjust_len(self.m_min_seq_len)

    # data reader / length reader / writer
    if data_format == nii_dconf.h_dtype_str:
        self.f_load_data = lambda x, y: _data_reader(x, y, self.m_flag_lang)
        self.f_length_data = _data_len_reader
        self.f_write_data = lambda x, y: _data_writer(x, y, self.m_wav_sr)
    else:
        nii_warn.f_print("Unsupported dtype %s" % (data_format))
        nii_warn.f_die("Only supports %s " % (nii_dconf.h_dtype_str))

    # keep only files available in every input/output directory
    self.f_check_file_list()

    # paths of the cached statistics
    def get_name(stats_path, set_name, file_name):
        tmp = set_name + '_' + file_name
        return os.path.join(stats_path, tmp)

    self.m_ms_input_path = get_name(stats_path, self.m_set_name, \
                                    nii_dconf.mean_std_i_file)
    self.m_ms_output_path = get_name(stats_path, self.m_set_name, \
                                     nii_dconf.mean_std_o_file)
    self.m_data_len_path = get_name(stats_path, self.m_set_name, \
                                    nii_dconf.data_len_file)

    # load cached length / mean-std if available
    flag_cal_len = self.f_init_data_len_stats(self.m_data_len_path)
    flag_cal_mean_std = self.f_init_mean_std(self.m_ms_input_path,
                                             self.m_ms_output_path)

    # otherwise compute them
    if flag_cal_len or flag_cal_mean_std:
        self.f_calculate_stats(flag_cal_len, flag_cal_mean_std)

    if self.__len__() < 1:
        nii_warn.f_print("Fail to load any data", "error")
        nii_warn.f_print("Possible reasons: ", "error")
        mes = "1. Old cache %s. Please delete it." % (self.m_data_len_path)
        mes += "\n2. input_dirs, input_exts, "
        mes += "output_dirs, or output_exts incorrect."
        mes += "\n3. all data are less than minimum_len in length. "
        mes += "\nThe last case may happen if truncate_seq == mininum_len "
        mes += "and truncate_seq % input_reso != 0. Then, the actual "
        mes += "truncate_seq becomes truncate_seq//input_reso*input_reso "
        mes += "and it will be shorter than minimum_len. Please change "
        mes += "truncate_seq and minimum_len so that "
        mes += "truncate_seq % input_reso == 0."
        nii_warn.f_print(mes, "error")
        nii_warn.f_die("Please check configuration file")

    return
|
|
|
|
|
def __len__(self): |
|
|
""" __len__(): |
|
|
Return the number of samples in the list |
|
|
""" |
|
|
return len(self.m_seq_info) |
|
|
|
|
|
def __getitem__(self, idx):
    """ __getitem__(self, idx):
    Load the features of the idx-th entry of m_seq_info.

    Returns (in_data, out_data, seq_info_str, idx):
      in_data: numpy array of shape (seq_len, input_all_dim) holding all
               input features concatenated on the last axis; for inputs
               with negative resolution the loaded data is returned as-is
               (expanded to 2-D if it is 1-D)
      out_data: numpy array of shape (seq_len, output_all_dim), or an
               empty list when no output directories are configured
               (e.g. test set data)
      seq_info_str: tmp_seq_info.print_to_str()
      idx: the index
    All four values pass through self.f_post_data_process first.
    """
    try:
        tmp_seq_info = self.m_seq_info[idx]
    except IndexError:
        nii_warn.f_die("Sample %d is not in seq_info" % (idx))

    # file name (without extension) shared by every feature file
    file_name = tmp_seq_info.seq_tag()

    # length / start position expressed in input frames; __init__
    # requires all input resolutions to be identical, so use the first
    input_reso = self.m_input_reso[0]
    seq_len = int(tmp_seq_info.seq_length() // input_reso)
    s_idx = int(tmp_seq_info.seq_start_pos() // input_reso)
    e_idx = s_idx + seq_len

    # seq_len can be negative when input_reso < 0 (unaligned input);
    # the zero-size buffer is then replaced by the loaded data below
    if seq_len < 0:
        seq_len = 0
        s_idx = 0
        e_idx = 0

    # buffer for all input features
    input_dim = self.m_input_all_dim
    in_data = np.zeros([seq_len, input_dim], dtype=nii_dconf.h_dtype)
    s_dim = 0
    e_dim = 0

    # load each input feature into columns s_dim:e_dim
    for t_dir, t_ext, t_dim, t_res in \
        zip(self.m_input_dirs, self.m_input_exts, \
            self.m_input_dims, self.m_input_reso):
        e_dim = s_dim + t_dim

        file_path = nii_str_tk.f_realpath(t_dir, file_name, t_ext)
        try:
            tmp_d = self.f_load_data(file_path, t_dim)
        except IOError:
            nii_warn.f_die("Cannot find %s" % (file_path))

        if t_res < 0:
            # unaligned input: use the whole loaded data as in_data
            if tmp_d.ndim == 1:
                in_data = np.expand_dims(tmp_d, axis=1)
            elif tmp_d.ndim == 2:
                in_data = tmp_d
            else:
                nii_warn.f_die("Default IO cannot handle %s" % (file_path))
        elif tmp_d.shape[0] == 1:
            # a single frame: broadcast it over every time step
            if tmp_d.ndim > 1:
                in_data[:,s_dim:e_dim] = tmp_d[0,:]
            elif t_dim == 1:
                in_data[:,s_dim] = tmp_d
            else:
                nii_warn.f_die("Dimension wrong %s" % (file_path))
        else:
            # copy the frames of this (possibly truncated) segment
            if tmp_d.ndim > 1:
                in_data[:,s_dim:e_dim] = tmp_d[s_idx:e_idx,:]
            elif t_dim == 1:
                in_data[:,s_dim] = tmp_d[s_idx:e_idx]
            else:
                nii_warn.f_die("Dimension wrong %s" % (file_path))
        s_dim = e_dim

    # load output features, if any, in the output frame resolution
    if self.m_output_dirs:
        output_reso = self.m_output_reso[0]
        seq_len = int(tmp_seq_info.seq_length() // output_reso)
        s_idx = int(tmp_seq_info.seq_start_pos() // output_reso)
        e_idx = s_idx + seq_len

        out_dim = self.m_output_all_dim
        out_data = np.zeros([seq_len, out_dim], \
                            dtype = nii_dconf.h_dtype)
        s_dim = 0
        e_dim = 0
        for t_dir, t_ext, t_dim in zip(self.m_output_dirs, \
                                       self.m_output_exts, \
                                       self.m_output_dims):
            e_dim = s_dim + t_dim

            file_path = nii_str_tk.f_realpath(t_dir, file_name, t_ext)
            try:
                tmp_d = self.f_load_data(file_path, t_dim)
            except IOError:
                nii_warn.f_die("Cannot find %s" % (file_path))

            if tmp_d.shape[0] == 1:
                # a single frame: broadcast it over every time step
                if tmp_d.ndim > 1:
                    out_data[:,s_dim:e_dim] = tmp_d[0,:]
                elif t_dim == 1:
                    out_data[:,s_dim]=tmp_d
                else:
                    nii_warn.f_die("Dimension wrong %s" % (file_path))
            else:
                if tmp_d.ndim > 1:
                    out_data[:,s_dim:e_dim] = tmp_d[s_idx:e_idx,:]
                elif t_dim == 1:
                    out_data[:,s_dim]=tmp_d[s_idx:e_idx]
                else:
                    nii_warn.f_die("Dimension wrong %s" % (file_path))
            s_dim = s_dim + t_dim
    else:
        # no target data (e.g. test set)
        out_data = []

    # optional post-processing (e.g. waveform silence handling)
    in_data, out_data, tmp_seq_info, idx = self.f_post_data_process(
        in_data, out_data, tmp_seq_info, idx)

    return in_data, out_data, tmp_seq_info.print_to_str(), idx
|
|
|
|
|
|
|
|
def f_post_data_process(self, in_data, out_data, seq_info, idx):
    """A wrapper to process the data after loading from files.

    When the waveform silence handler is enabled (m_opt_wav_handler > 0),
    a single waveform input and/or a single waveform output is passed
    through nii_wav_tk.silence_handler (flag_output=m_opt_wav_handler),
    and the length in a *copy* of seq_info is updated if it equalled the
    number of samples before processing. Otherwise all arguments are
    returned unchanged.

    Returns: (in_data, out_data, seq_info, idx)
    """
    if self.m_opt_wav_handler <= 0:
        return in_data, out_data, seq_info, idx

    # copy, so that the cached entry in self.m_seq_info is not modified
    tmp_seq_info = nii_seqinfo.SeqInfo(
        seq_info.length, seq_info.seq_name, seq_info.seg_idx,
        seq_info.start_pos, seq_info.info_id)

    if len(self.m_input_exts) == 1 \
       and self.m_input_exts[0][-3:] == 'wav':
        in_data_n = nii_wav_tk.silence_handler(
            in_data[:, 0], self.m_wav_sr,
            flag_output = self.m_opt_wav_handler)
        in_data_n = np.expand_dims(in_data_n, axis=1)

        # only update the length when it is counted in samples
        if tmp_seq_info.length == in_data.shape[0]:
            tmp_seq_info.length = in_data_n.shape[0]
    else:
        in_data_n = in_data

    # out_data is an empty list when there are no output directories
    # (e.g. inference), even though output_exts may name a waveform;
    # indexing out_data[:, 0] would then raise TypeError
    if isinstance(out_data, np.ndarray) \
       and len(self.m_output_exts) == 1 \
       and self.m_output_exts[0][-3:] == 'wav':
        out_data_n = nii_wav_tk.silence_handler(
            out_data[:,0], self.m_wav_sr,
            flag_output = self.m_opt_wav_handler)
        out_data_n = np.expand_dims(out_data_n, axis=1)

        if tmp_seq_info.length == out_data.shape[0]:
            tmp_seq_info.length = out_data_n.shape[0]
    else:
        out_data_n = out_data

    return in_data_n, out_data_n, tmp_seq_info, idx
|
|
|
|
|
|
|
|
def f_get_num_seq(self):
    """ f_get_num_seq():
    Number of entries in m_seq_info (same value as __len__())
    """
    num_seq = len(self.m_seq_info)
    return num_seq
|
|
|
|
|
def f_get_seq_len_list(self):
    """ Return a list holding seq_length() of every entry in m_seq_info
    """
    return list(map(lambda info: info.seq_length(), self.m_seq_info))
|
|
|
|
|
def f_get_mean_std_tuple(self):
    """ Return (input_mean, input_std, output_mean, output_std)
    """
    input_stats = (self.m_input_mean, self.m_input_std)
    output_stats = (self.m_output_mean, self.m_output_std)
    return input_stats + output_stats
|
|
|
|
|
|
|
|
def f_check_file_list(self):
    """ f_check_file_list():
    Check the file list after initialization.

    self.m_file_list may be a list of file names, a path to a text file
    listing file names, or anything else (e.g. None). In the last case
    the list is built by scanning the first input directory.

    Then keep only the files that appear in every input/output feature
    directory; abort (nii_warn.f_die) if nothing remains.
    """
    if not isinstance(self.m_file_list, list):
        if isinstance(self.m_file_list, str) and \
           os.path.isfile(self.m_file_list):
            # read the list from the text file
            self.m_file_list = nii_list_tools.read_list_from_text(
                self.m_file_list)
        else:
            # "{}" instead of "{:s}": m_file_list may be None here, and
            # format(None, 's') raises TypeError
            nii_warn.f_print("Cannot read {}".format(self.m_file_list))
            nii_warn.f_print("Read file list from directories")
            self.m_file_list = None

    # no usable list: take all files found in the first input directory
    if self.m_file_list is None:
        self.m_file_list = nii_list_tools.listdir_with_ext(
            self.m_input_dirs[0], self.m_input_exts[0])

    # keep files available in every input directory
    for tmp_d, tmp_e in zip(self.m_input_dirs, \
                            self.m_input_exts):
        tmp_list = nii_list_tools.listdir_with_ext(tmp_d, tmp_e)
        self.m_file_list = nii_list_tools.common_members(
            tmp_list, self.m_file_list)

    if len(self.m_file_list) < 1:
        nii_warn.f_print("No input features found after scannning", 'error')
        nii_warn.f_print("Please check %s" \
                         % (str(self.m_input_dirs)), 'error')
        nii_warn.f_print("They should contain all files in file list",
                         'error')
        nii_warn.f_print("Please also check filename extentions %s" \
                         % (str(self.m_input_exts)), 'error')
        nii_warn.f_print("They should be correctly specified", 'error')
        nii_warn.f_die("Failed to read input features")

    # keep files available in every output directory (if any)
    if self.m_output_dirs:
        for tmp_d, tmp_e in zip(self.m_output_dirs, \
                                self.m_output_exts):
            tmp_list = nii_list_tools.listdir_with_ext(tmp_d, tmp_e)
            self.m_file_list = nii_list_tools.common_members(
                tmp_list, self.m_file_list)

        if len(self.m_file_list) < 1:
            nii_warn.f_print("No output data found", 'error')
            nii_warn.f_print("Please check %s" \
                             % (str(self.m_output_dirs)), 'error')
            nii_warn.f_print("They should contain all files in file list",
                             'error')
            nii_warn.f_print("Please also check filename extentions %s" \
                             % (str(self.m_output_exts)), 'error')
            nii_warn.f_print("They should be correctly specified", 'error')
            nii_warn.f_die("Failed to read output features")
    else:
        # no output data (e.g. test set): nothing to check
        pass

    return
|
|
|
|
|
|
|
|
def f_valid_len(self, t_1, t_2, min_length):
    """ f_valid_len(t_1, t_2, min_length)
    Check whether two lengths of the same utterance are consistent.

    If neither t_1 nor t_2 exceeds min_length, return True. Otherwise
    return False when |t_1 - t_2| / t_1 is larger than 0.1.
    """
    if t_1 <= min_length and t_2 <= min_length:
        return True
    relative_diff = np.abs(t_1 - t_2) * 1.0 / t_1
    return not relative_diff > 0.1
|
|
|
|
|
def f_check_specific_data(self, file_name):
    """ Print path, length, dimension and resolution of every
    input/output feature file of file_name (used to diagnose length
    mismatches). Aborts via nii_warn.f_die if a file is missing.
    """
    feature_groups = (
        (self.m_input_dirs, self.m_input_exts,
         self.m_input_dims, self.m_input_reso),
        (self.m_output_dirs, self.m_output_exts,
         self.m_output_dims, self.m_output_reso))

    for dirs, exts, dims, resos in feature_groups:
        for t_dir, t_ext, t_dim, t_res in zip(dirs, exts, dims, resos):
            path = nii_str_tk.f_realpath(t_dir, file_name, t_ext)
            if nii_io_tk.file_exist(path):
                frames = self.f_length_data(path) // t_dim
                print("%s, length %d, dim %d, reso: %d" % \
                      (path, frames, t_dim, t_res))
            else:
                nii_warn.f_die("%s not found" % (path))
    return
|
|
|
|
|
|
|
|
def f_log_data_len(self, file_name, t_len, t_reso):
    """ f_log_data_len(file_name, t_len, t_reso):
    Log down the length of one feature file of file_name.

    The length is stored at the finest time resolution, t_len * t_reso.
    When several input/output features of the same file_name are logged,
    only the shortest length is kept, after f_valid_len confirms that the
    lengths do not differ too much (otherwise abort). The stored length
    is then rounded down by f_adjust_len.

    Features with t_len <= 1 (e.g. a single vector per file) or
    t_reso <= 0 (unaligned data) carry no usable length and are ignored.
    """
    if t_len <= 1 or t_reso <= 0:
        return

    tmp_len = t_len * t_reso

    # first feature of this file
    if file_name not in self.m_data_length:
        self.m_data_length[file_name] = tmp_len

    if self.f_valid_len(self.m_data_length[file_name], tmp_len, \
                        nii_dconf.data_seq_min_length):
        # keep the shortest length
        if self.m_data_length[file_name] > tmp_len:
            self.m_data_length[file_name] = tmp_len
    else:
        nii_warn.f_print("Sequence length mismatch:", 'error')
        self.f_check_specific_data(file_name)
        nii_warn.f_print("Please check the above features", 'error')
        nii_warn.f_die("Possible invalid data %s" % (file_name))

    # round down to a multiple of m_single_reso
    tmp = self.m_data_length[file_name]
    self.m_data_length[file_name] = self.f_adjust_len(tmp)
    return
|
|
|
|
|
def f_adjust_len(self, length):
    """ Round length down to the nearest multiple of self.m_single_reso.

    This keeps up-sampled sequence lengths equal to N * m_single_reso.
    With m_single_reso == 1 the length is unchanged.
    """
    reso = self.m_single_reso
    num_units = length // reso
    return num_units * reso
|
|
|
|
|
def f_precheck_data_length(self):
    """ For unaligned input and output (all input_reso < 0) without
    output directories (inference), the target length cannot be known
    beforehand and self.m_data_length is still empty. In that case,
    register every file in m_file_list with length 0.
    """
    is_unaligned_inference = (not self.m_data_length) and \
        (not self.m_output_dirs) and \
        all(reso < 0 for reso in self.m_input_reso)
    if is_unaligned_inference:
        self.m_data_length.update(dict.fromkeys(self.m_file_list, 0))
    return
|
|
|
|
|
|
|
|
|
|
|
def f_log_seq_info(self):
    """ After m_data_length has been created, create seq_info.

    For every file in m_file_list with a logged length, append SeqInfo
    entries to m_seq_info: one per file, or, if m_truncate_seq is set,
    one per consecutive segment of at most m_truncate_seq. Entries
    shorter than m_min_seq_len (if set) are skipped. Files without a
    logged length are excluded. Finally update m_data_total_length.
    """
    # m_truncate_seq is rounded down by f_adjust_len in __init__, so it
    # can become 0; a non-positive chunk length would make the while
    # loop below never terminate
    if self.m_truncate_seq is not None and self.m_truncate_seq <= 0:
        nii_warn.f_die(
            "truncate_seq must be positive after being rounded down to a "
            "multiple of the data resolution, but it is %s" % (
                str(self.m_truncate_seq)))

    for file_name in self.m_file_list:
        if file_name not in self.m_data_length:
            nii_warn.f_eprint("Exclude %s from dataset" % (file_name))
            continue

        length_remain = self.m_data_length[file_name]
        start_pos = 0
        seg_idx = 0
        if self.m_truncate_seq is not None:
            # split into chunks of at most m_truncate_seq
            while(length_remain > 0):
                info_idx = len(self.m_seq_info)
                seg_length = min(self.m_truncate_seq, length_remain)
                seq_info = nii_seqinfo.SeqInfo(seg_length,
                                               file_name, seg_idx,
                                               start_pos, info_idx)
                if self.m_min_seq_len is None or \
                   seg_length >= self.m_min_seq_len:
                    self.m_seq_info.append(seq_info)
                    seg_idx += 1
                start_pos += seg_length
                length_remain -= seg_length
        else:
            # the whole file as one sequence
            info_idx = len(self.m_seq_info)
            seq_info = nii_seqinfo.SeqInfo(length_remain,
                                           file_name, seg_idx,
                                           start_pos, info_idx)
            if self.m_min_seq_len is None or \
               length_remain >= self.m_min_seq_len:
                self.m_seq_info.append(seq_info)

    self.m_data_total_length = self.f_sum_data_length()
    return
|
|
|
|
|
def f_init_mean_std(self, ms_input_path, ms_output_path):
    """ f_init_mean_std
    Initialize mean (zeros) and std (ones) vectors for input and output.
    If both ms_input_path and ms_output_path exist and each holds
    2 * feature-dimension values, load [mean, std] from them.

    Returns True when mean/std still have to be computed; False when
    save_mean_std is off, no feature needs normalization, or the
    statistics were loaded from file.
    """
    in_dim = self.m_input_all_dim
    out_dim = self.m_output_all_dim
    self.m_input_mean = np.zeros([in_dim])
    self.m_input_std = np.ones([in_dim])
    self.m_output_mean = np.zeros([out_dim])
    self.m_output_std = np.ones([out_dim])

    any_norm = any(self.m_input_norm + self.m_output_norm)
    need_compute = bool(self.m_save_ms) and any_norm

    have_files = os.path.isfile(ms_input_path) and \
        os.path.isfile(ms_output_path)
    if not have_files:
        return need_compute

    ms_input = self.f_load_data(ms_input_path, 1)
    ms_output = self.f_load_data(ms_output_path, 1)
    input_ok = ms_input.shape[0] == (in_dim * 2)
    output_ok = ms_output.shape[0] == (out_dim * 2)

    if input_ok and output_ok:
        self.m_input_mean = ms_input[0:in_dim]
        self.m_input_std = ms_input[in_dim:]
        self.m_output_mean = ms_output[0:out_dim]
        self.m_output_std = ms_output[out_dim:]
        nii_warn.f_print("Load mean/std from %s and %s" % \
                         (ms_input_path, ms_output_path))
        return False

    if not input_ok:
        nii_warn.f_print("%s incompatible" % (ms_input_path),
                         'warning')
    if not output_ok:
        nii_warn.f_print("%s incompatible" % (ms_output_path),
                         'warning')
    nii_warn.f_print("mean/std will be recomputed", 'warning')
    return need_compute
|
|
|
|
|
|
|
|
def f_sum_data_length(self):
    """ Return the sum of seq_length() over all entries in m_seq_info
    """
    total = 0
    for info in self.m_seq_info:
        total += info.seq_length()
    return total
|
|
|
|
|
def f_init_data_len_stats(self, data_path):
    """
    flag = f_init_data_len_stats(self, data_path)
    Check whether data length has been stored in data_path.
    If yes, load data_path and return False
    Else, return True

    The cached sequence information is kept if its file names equal
    m_file_list, or (as judged by nii_list_tools.list_b_in_list_a) are
    contained in it; in the latter case the missing files are reported
    and excluded. Otherwise the cache is discarded and True is returned.
    """
    self.m_seq_info = []
    self.m_data_length = {}
    self.m_data_total_length = 0

    flag = True
    if os.path.isfile(data_path):
        # read the file that was just checked (previously this read
        # self.m_data_len_path regardless of the data_path argument)
        dic_seq_infos = nii_io_tk.read_dic(data_path)
        for dic_seq_info in dic_seq_infos:
            seq_info = nii_seqinfo.SeqInfo()
            seq_info.load_from_dic(dic_seq_info)
            self.m_seq_info.append(seq_info)
            seq_tag = seq_info.seq_tag()
            # a truncated file contributes several segments: sum them
            self.m_data_length[seq_tag] = \
                self.m_data_length.get(seq_tag, 0) + seq_info.seq_length()
        self.m_data_total_length = self.f_sum_data_length()

        if nii_list_tools.list_identical(self.m_file_list,\
                                         self.m_data_length.keys()):
            nii_warn.f_print("Read sequence info: %s" % (data_path))
            flag = False
        elif nii_list_tools.list_b_in_list_a(self.m_file_list,
                                             self.m_data_length.keys()):
            nii_warn.f_print("Read sequence info: %s" % (data_path))
            nii_warn.f_print(
                "However %d samples are ignoed" % \
                (len(self.m_file_list)-len(self.m_data_length)))
            tmp = nii_list_tools.members_in_a_not_in_b(
                self.m_file_list, self.m_data_length.keys())
            for tmp_name in tmp:
                nii_warn.f_eprint("Exclude %s from dataset" % (tmp_name))
            flag = False
        else:
            # cache does not match the file list: discard it
            self.m_seq_info = []
            self.m_data_length = {}
            self.m_data_total_length = 0

    return flag
|
|
|
|
|
def f_save_data_len(self, data_len_path):
    """ Save the sequence information (print_to_dic() of every entry in
    m_seq_info) to data_len_path through nii_io_tk.write_dic
    """
    seq_dics = [info.print_to_dic() for info in self.m_seq_info]
    nii_io_tk.write_dic(seq_dics, data_len_path)
|
|
|
|
|
def f_save_mean_std(self, ms_input_path, ms_output_path):
    """ Save [mean, std] of the inputs to ms_input_path and [mean, std]
    of the outputs to ms_output_path. Each is written with
    self.f_write_data as one flat vector of 2 * feature-dimension values.
    """
    def _stack(mean_vec, std_vec, dim):
        # float64 buffer: first half mean, second half std
        buf = np.zeros([dim * 2])
        buf[:dim] = mean_vec
        buf[dim:] = std_vec
        return buf

    in_dim = self.m_input_all_dim
    self.f_write_data(
        _stack(self.m_input_mean, self.m_input_std, in_dim),
        ms_input_path)

    out_dim = self.m_output_all_dim
    self.f_write_data(
        _stack(self.m_output_mean, self.m_output_std, out_dim),
        ms_output_path)
    return
|
|
|
|
|
def f_print_info(self):
    """ Print a summary of the data set (time steps, sequence counts and
    lengths, input/output configuration) via nii_warn.f_print_message
    """
    parts = ["Dataset {}:".format(self.m_set_name),
             "\n Time steps: {:d} ".format(self.m_data_total_length)]
    if self.m_truncate_seq is not None:
        parts.append("\n Truncate length: {:d}".format(self.m_truncate_seq))
    parts.append("\n Data sequence num: {:d}".format(len(self.m_seq_info)))

    seq_lens = [x.seq_length() for x in self.m_seq_info]
    shortest = min(seq_lens)
    longest = max(seq_lens)
    parts.append("\n Maximum sequence length: {:d}".format(longest))
    parts.append("\n Minimum sequence length: {:d}".format(shortest))
    if self.m_min_seq_len is not None:
        parts.append("\n Shorter sequences are ignored")

    feature_groups = (
        ("\n Inputs\n Dirs:", self.m_input_dirs, self.m_input_exts,
         self.m_input_dims, self.m_input_reso, self.m_input_norm),
        ("\n Outputs\n Dirs:", self.m_output_dirs, self.m_output_exts,
         self.m_output_dims, self.m_output_reso, self.m_output_norm))
    for header, dirs, exts, dims, resos, norms in feature_groups:
        parts.append(header)
        parts.extend("\n {:s}".format(subdir) for subdir in dirs)
        parts.append("\n Exts:{:s}".format(str(exts)))
        parts.append("\n Dims:{:s}".format(str(dims)))
        parts.append("\n Reso:{:s}".format(str(resos)))
        parts.append("\n Norm:{:s}".format(str(norms)))

    if self.m_opt_wav_handler > 0:
        parts.append("\n Waveform silence handler will be used")
    nii_warn.f_print_message("".join(parts))
    return
|
|
|
|
|
def f_calculate_stats(self, flag_cal_data_len, flag_cal_mean_std):
    """ f_calculate_stats
    Scan every input/output feature file of every file in m_file_list.

    flag_cal_data_len: if True, log the length of each file
        (f_log_data_len), then build m_seq_info (f_precheck_data_length,
        f_log_seq_info) and save it to m_data_len_path
    flag_cal_mean_std: if True, accumulate per-feature mean/variance with
        nii_stats.f_online_mean_std, fill m_input_mean/std and
        m_output_mean/std, and save them to m_ms_input_path /
        m_ms_output_path. Features with norm == False get mean 0 and
        variance 1.
    """
    # one list of (dir, ext, dim, reso, norm) over inputs then outputs
    tmp_dirs = self.m_input_dirs.copy()
    tmp_exts = self.m_input_exts.copy()
    tmp_dims = self.m_input_dims.copy()
    tmp_reso = self.m_input_reso.copy()
    tmp_norm = self.m_input_norm.copy()
    tmp_dirs.extend(self.m_output_dirs)
    tmp_exts.extend(self.m_output_exts)
    tmp_dims.extend(self.m_output_dims)
    tmp_reso.extend(self.m_output_reso)
    tmp_norm.extend(self.m_output_norm)

    # column range [s_dim, e_dim) of the current feature in the
    # concatenated input+output dimension
    s_dim = 0
    e_dim = 0

    # zip stops at the shortest list: when output_dirs is empty only
    # the input features are scanned
    for t_dir, t_ext, t_dim, t_reso, t_norm in \
        zip(tmp_dirs, tmp_exts, tmp_dims, tmp_reso, tmp_norm):
        s_dim = e_dim
        e_dim = s_dim + t_dim
        t_cnt = 0
        mean_i, var_i = np.zeros([t_dim]), np.zeros([t_dim])

        for file_name in self.m_file_list:
            file_path = nii_str_tk.f_realpath(t_dir, file_name, t_ext)
            if not nii_io_tk.file_exist(file_path):
                nii_warn.f_die("%s not found" % (file_path))

            if flag_cal_data_len:
                # length reader result divided by dim; presumably the
                # number of frames -- TODO confirm for raw files
                t_len = self.f_length_data(file_path) // t_dim
                self.f_log_data_len(file_name, t_len, t_reso)

            if flag_cal_mean_std:
                t_data = self.f_load_data(file_path, t_dim)
                # for extensions listed in f0_unvoiced_dic, keep only
                # values above the unvoiced marker (boolean indexing
                # flattens t_data)
                if t_ext in nii_dconf.f0_unvoiced_dic:
                    unvoiced_value = nii_dconf.f0_unvoiced_dic[t_ext]
                    t_data = t_data[t_data > unvoiced_value]
                mean_i, var_i, t_cnt = nii_stats.f_online_mean_std(
                    t_data, mean_i, var_i, t_cnt)

        if flag_cal_mean_std:
            # un-normalized features: identity normalization
            if not t_norm:
                mean_i[:] = 0
                var_i[:] = 1
            # decide whether [s_dim, e_dim) belongs to input or output
            if s_dim < self.m_input_all_dim:
                self.m_input_mean[s_dim:e_dim] = mean_i
                std_i = nii_stats.f_var2std(var_i)
                self.m_input_std[s_dim:e_dim] = std_i
            else:
                tmp_s = s_dim - self.m_input_all_dim
                tmp_e = e_dim - self.m_input_all_dim
                self.m_output_mean[tmp_s:tmp_e] = mean_i
                std_i = nii_stats.f_var2std(var_i)
                self.m_output_std[tmp_s:tmp_e] = std_i

    if flag_cal_data_len:
        # unaligned inference case: register files with length 0
        self.f_precheck_data_length()
        # build m_seq_info from m_data_length
        self.f_log_seq_info()
        # cache the sequence information
        self.f_save_data_len(self.m_data_len_path)

    if flag_cal_mean_std:
        self.f_save_mean_std(self.m_ms_input_path,
                             self.m_ms_output_path)
    return
|
|
|
|
|
def f_putitem(self, output_data, save_dir, data_infor_str):
    """ f_putitem(output_data, save_dir, data_infor_str)

    Split the generated output_data into one file per output feature
    and write them into save_dir.

    Args
    ----
      output_data: generated data with batch size 1, either of shape
                   (1, length, dim) or (1, length); the latter is treated
                   as a one-dimensional feature of shape (length, 1).
                   Any other shape is rejected through nii_warn.f_die.
      save_dir: output directory; it (and missing parents) is created
                if it does not exist
      data_infor_str: sequence-information string; it is parsed by
                      nii_seqinfo.SeqInfo and its seq_tag() is used as
                      the output file name

    The feature dimension of output_data must equal
    self.m_output_all_dim. Output feature k (following the order of
    self.m_output_exts / self.m_output_dims) takes the next
    self.m_output_dims[k] columns and is written with self.f_write_data
    to the path nii_str_tk.f_realpath(save_dir, seq_tag, ext_k).
    """
    # remove the batch dimension (only batch size 1 is supported)
    if output_data.ndim == 3 and output_data.shape[0] == 1:
        output_data = output_data[0]
    elif output_data.ndim == 2 and output_data.shape[0] == 1:
        # (1, length) -> (length, 1)
        output_data = np.expand_dims(output_data[0], -1)
    else:
        nii_warn.f_print("Output data format not supported.", "error")
        nii_warn.f_print("Format is not (batch, len, dim)", "error")
        nii_warn.f_die("Please use batch_size = 1 in generation")

    # total feature dimension must match the configured output dimensions
    if output_data.shape[1] != self.m_output_all_dim:
        nii_warn.f_print("Output data dim != expected dim", "error")
        nii_warn.f_print("Output:%d" % (output_data.shape[1]), "error")
        nii_warn.f_print("Expected:%d" % (self.m_output_all_dim), "error")
        nii_warn.f_die("Please check configuration")

    # create the output directory; exist_ok avoids a race with other
    # processes creating the same directory after the isdir check
    if not os.path.isdir(save_dir):
        try:
            os.makedirs(save_dir, exist_ok=True)
        except OSError:
            nii_warn.f_die("Cannot create {}".format(save_dir))

    # the sequence tag in the information string names the output files
    tmp_seq_info = nii_seqinfo.SeqInfo()
    tmp_seq_info.parse_from_str(data_infor_str)
    file_name = tmp_seq_info.seq_tag()

    # write each output feature from its own slice of columns
    s_dim = 0
    for t_ext, t_dim in zip(self.m_output_exts, self.m_output_dims):
        e_dim = s_dim + t_dim
        file_path = nii_str_tk.f_realpath(save_dir, file_name, t_ext)
        self.f_write_data(output_data[:, s_dim:e_dim], file_path)
        # advance to the first column of the next feature
        s_dim = e_dim
    return
|
|
|
|
|
def f_input_dim(self):
    """ f_input_dim()

    Return the total dimension of all input features
    (the value stored in self.m_input_all_dim).
    """
    total_dim = self.m_input_all_dim
    return total_dim
|
|
|
|
|
def f_output_dim(self):
    """ f_output_dim()

    Return the total dimension of all output features
    (the value stored in self.m_output_all_dim).
    """
    total_dim = self.m_output_all_dim
    return total_dim
|
|
|
|
|
def f_adjust_idx(self, data_tuple, idx_shift):
    """ f_adjust_idx(data_tuple, idx_shift)

    Add idx_shift, in place, to every entry of data_tuple[-1] and
    return data_tuple itself.

    customize_dataset uses this when several datasets are merged: the
    utterance indices produced by __getitem__ must then be shifted.
    The indices live in the last element of the data tuple, and only
    this data_io class knows that layout, so the adjustment is
    implemented here.
    """
    # the index container is mutated in place; the tuple is not rebuilt
    idx_container = data_tuple[-1]
    for pos in range(len(idx_container)):
        idx_container[pos] += idx_shift
    return data_tuple
|
|
|
|
|
class NIIDataSetLoader:
    """ NIIDataSetLoader:
    A wrapper over torch.utils.data.DataLoader

    self.m_dataset will be the dataset (an NIIDataSet object)
    self.m_loader will be the torch.utils.data.DataLoader over it
    self.m_params will be the loader configuration (see get_loader_params)
    """
    def __init__(self,
                 dataset_name,
                 file_list,
                 input_dirs, input_exts, input_dims, input_reso,
                 input_norm,
                 output_dirs, output_exts, output_dims, output_reso,
                 output_norm,
                 stats_path,
                 data_format = nii_dconf.h_dtype_str,
                 params = None,
                 truncate_seq = None,
                 min_seq_len = None,
                 save_mean_std = True,
                 wav_samp_rate = None,
                 flag_lang = 'EN',
                 global_arg = None):
        """
        NIIDataSetLoader(
               dataset_name,
               file_list,
               input_dirs, input_exts, input_dims, input_reso, input_norm,
               output_dirs, output_exts, output_dims, output_reso, output_norm,
               stats_path,
               data_format = '<f4',
               params = None,
               truncate_seq = None,
               min_seq_len = None,
               save_mean_std = True,
               wav_samp_rate = None,
               flag_lang = 'EN',
               global_arg = None)
        Args
        ----
          dataset_name: a string to name this dataset
                        this will be used to name the statistics files
                        such as the mean/std for this dataset
          file_list: a list of file name strings (without extension)
                     or, path to the file that contains the file names
          input_dirs: a list of dirs from which input feature is loaded
          input_exts: a list of input feature name extensions
          input_dims: a list of input feature dimensions
          input_reso: a list of input feature temporal resolution,
                      or None
          input_norm: a list of bool, whether normalize input feature or not

          output_dirs: a list of dirs from which output feature is loaded
          output_exts: a list of output feature name extensions
          output_dims: a list of output feature dimensions
          output_reso: a list of output feature temporal resolution,
                       or None
          output_norm: a list of bool, whether normalize target feature or not

          stats_path: path to the directory of statistics(mean/std)
          data_format: method to load the data
                    '<f4' (default): load data as float32, little-endian
                    'htk': load data as htk format
          params: dict of keyword arguments for torch.utils.data.DataLoader,
                  or None to use nii_dconf.default_loader_conf.
                  If params contains 'sampler' and its value equals
                  nii_sampler_fn.g_str_sampler_bsbl, it is replaced by a
                  SamplerBlockShuffleByLen object and 'shuffle' is set to
                  False (this requires the 'batch_size' key); any other
                  'sampler' value is replaced by None.
                  The dict passed in is not modified (a copy is used).

          truncate_seq: None or int,
                        truncate data sequences into smaller chunks
                        truncate_seq > 0 specifies the chunk length
          min_seq_len: None (default) or int, minimum length of an utterance
                       utterance shorter than min_seq_len will be ignored
          save_mean_std: bool, True (default): save mean and std
          wav_samp_rate: None (default) or int, if input data has waveform,
                         please set sampling rate. It is used by _data_writer
          flag_lang: str, 'EN' (default), if input data has text, text will
                     be converted into code indices. flag_lang indicates the
                     language for the text processor, used by _data_reader
          global_arg: argument parser returned by arg_parse.f_args_parsed()
                      default None
        Methods
        -------
          get_loader(): return the torch.utils.data.DataLoader
          get_dataset(): return the NIIDataSet dataset object
        """
        nii_warn.f_print_w_date("Loading dataset %s" % (dataset_name),
                                level="h")

        # create the dataset
        self.m_dataset = NIIDataSet(dataset_name,
                                    file_list,
                                    input_dirs, input_exts,
                                    input_dims, input_reso,
                                    input_norm,
                                    output_dirs, output_exts,
                                    output_dims, output_reso,
                                    output_norm,
                                    stats_path, data_format,
                                    truncate_seq, min_seq_len,
                                    save_mean_std,
                                    wav_samp_rate,
                                    flag_lang,
                                    global_arg)

        # Work on a private copy of the loader configuration: 'sampler'
        # and 'shuffle' may be overwritten below, and neither the caller's
        # dict nor the module-level default dict must be modified.
        if params is None:
            tmp_params = dict(nii_dconf.default_loader_conf)
        else:
            tmp_params = params.copy()

        # keep the configuration as given, taken before the sampler name
        # is replaced by a sampler object
        self.m_params = tmp_params.copy()

        # replace the sampler name by a sampler object
        if 'sampler' in tmp_params:
            tmp_sampler = None
            if tmp_params['sampler'] == nii_sampler_fn.g_str_sampler_bsbl:
                if 'batch_size' in tmp_params:
                    # block-shuffle-by-length sampler over sequence lengths
                    tmp_sampler = nii_sampler_fn.SamplerBlockShuffleByLen(
                        self.m_dataset.f_get_seq_len_list(),
                        tmp_params['batch_size'])
                    # DataLoader does not accept shuffle=True together
                    # with a custom sampler
                    tmp_params['shuffle'] = False
                else:
                    nii_warn.f_die("Sampler requires batch size > 1")
            tmp_params['sampler'] = tmp_sampler

        # batches of more than one sequence use the project's customized
        # collate function (presumably to merge sequences of different
        # lengths); otherwise DataLoader's default collation is used
        if 'batch_size' in tmp_params and tmp_params['batch_size'] > 1:
            collate_fn = nii_collate_fn.customize_collate
        else:
            collate_fn = None

        self.m_loader = torch.utils.data.DataLoader(
            self.m_dataset, collate_fn=collate_fn, **tmp_params)
        return

    def get_loader_params(self):
        """ get_loader_params():
        Return the loader configuration dict (params, or the default
        configuration when params is None), copied before the sampler
        name was replaced by a sampler object
        """
        return self.m_params

    def get_loader(self):
        """ get_loader():
        Return the dataLoader (torch.utils.data.DataLoader)
        """
        return self.m_loader

    def get_dataset(self):
        """ get_dataset():
        Return the dataset (the NIIDataSet object)
        """
        return self.m_dataset

    def get_data_mean_std(self):
        """ get_data_mean_std():
        Return self.m_dataset.f_get_mean_std_tuple(), presumably the
        mean/std statistics of the input and output features
        """
        return self.m_dataset.f_get_mean_std_tuple()

    def print_info(self):
        """ print_info():
        Print the dataset information, then the loader configuration
        """
        self.m_dataset.f_print_info()
        print(str(self.m_params))
        return

    def putitem(self, output_data, save_dir, data_infor_str):
        """ Decompose the output_data from network into
        separate files (delegates to NIIDataSet.f_putitem)
        """
        self.m_dataset.f_putitem(output_data, save_dir, data_infor_str)

    def get_in_dim(self):
        """ Return the dimension of input features
        """
        return self.m_dataset.f_input_dim()

    def get_out_dim(self):
        """ Return the dimension of output features
        """
        return self.m_dataset.f_output_dim()

    def get_seq_num(self):
        """ Return the number of sequences (after truncation)
        """
        return self.m_dataset.f_get_num_seq()

    def adjust_utt_idx(self, data_tuple, utt_idx_shift):
        """ Return data tuple with adjusted utterance index in merged dataset

        This is used by customize_dataset.
        """
        return self.m_dataset.f_adjust_idx(data_tuple, utt_idx_shift)
|
|
|
|
|
if __name__ == "__main__":
    # This module only provides data-loading definitions; running it
    # directly does nothing.
    pass
|
|
|