File size: 52,297 Bytes
ca1888b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
#!/usr/bin/env python
"""
data_io

Interface to load data

"""
from __future__ import absolute_import

import os
import sys
import numpy as np
import torch
import torch.utils.data

import core_scripts.other_tools.list_tools as nii_list_tools
import core_scripts.other_tools.display as nii_warn
import core_scripts.other_tools.str_tools as nii_str_tk
import core_scripts.data_io.io_tools as nii_io_tk
import core_scripts.data_io.wav_tools as nii_wav_tk
import core_scripts.data_io.text_process.text_io as nii_text_tk
import core_scripts.data_io.conf as nii_dconf

import core_scripts.data_io.seq_info as nii_seqinfo
import core_scripts.math_tools.stats as nii_stats
import core_scripts.data_io.customize_collate_fn as nii_collate_fn
import core_scripts.data_io.customize_sampler as nii_sampler_fn

__author__ = "Xin Wang"
__email__ = "wangxin@nii.ac.jp"
__copyright__ = "Copyright 2020, Xin Wang"

###
## functions wrappers to read/write data for this data_io
###
def _data_reader(file_path, dim, flag_lang):
    """ Load one data file; the loader is chosen from the file extension

    args
    ----
      file_path: str, path to the file to be loaded
      dim: feature dimension, only forwarded to the raw matrix reader
      flag_lang: language flag, only forwarded to the text loader

    return
    ------
      data: output of the selected loader. For '.wav' and '.flac', the
            loader returns a pair and only its second element is returned
    """
    suffix = os.path.splitext(file_path)[1]

    if suffix == '.txt':
        return nii_text_tk.textloader(file_path, flag_lang)

    if suffix == '.wav':
        _, wave = nii_wav_tk.waveReadAsFloat(file_path)
        return wave

    if suffix == '.flac':
        _, wave = nii_wav_tk.flacReadAsFloat(file_path)
        return wave

    # any other extension is treated as a raw matrix file
    return nii_io_tk.f_read_raw_mat(file_path, dim)

def _data_writer(data, file_path, sr = 16000):
    """ Save data to file_path; the writer is chosen from the file extension

    args
    ----
      data: data to be written
      file_path: str, path to the output file
      sr: sampling rate forwarded to the waveform writer ('.wav' only),
          default 16000

    '.txt' is not supported and triggers nii_warn.f_die.
    Any extension other than '.wav' and '.txt' goes to the raw matrix writer.
    """
    suffix = os.path.splitext(file_path)[1]

    if suffix == '.txt':
        nii_warn.f_die("Cannot write to %s" % (file_path))
    elif suffix == '.wav':
        nii_wav_tk.waveFloatToPCMFile(data, file_path, sr = sr)
    else:
        nii_io_tk.f_write_raw_mat(data, file_path)

def _data_len_reader(file_path):
    """ Return the length of the data stored in file_path

    args
    ----
      file_path: str, path to the data file

    return
    ------
      length: shape[0] of the loaded waveform for '.wav' and '.flac',
              0 for '.txt' (text length is not counted, this is for the
              tts task), and the value given by
              nii_io_tk.f_read_raw_mat_length for any other extension
    """
    suffix = os.path.splitext(file_path)[1]

    if suffix == '.txt':
        # txt, no need to account length
        return 0

    if suffix == '.wav':
        _, wave = nii_wav_tk.waveReadAsFloat(file_path)
        return wave.shape[0]

    if suffix == '.flac':
        _, wave = nii_wav_tk.flacReadAsFloat(file_path)
        return wave.shape[0]

    return nii_io_tk.f_read_raw_mat_length(file_path)

###
# Definition of DataSet
###
class NIIDataSet(torch.utils.data.Dataset):
    """ General class for NII speech dataset
    For definition of customized Dataset, please refer to 
    https://pytorch.org/tutorials/beginner/data_loading_tutorial.html
    """
    def __init__(self,
                 dataset_name, \
                 file_list, \
                 input_dirs, input_exts, input_dims, input_reso, \
                 input_norm, \
                 output_dirs, output_exts, output_dims, output_reso, \
                 output_norm, \
                 stats_path, \
                 data_format = nii_dconf.h_dtype_str, \
                 truncate_seq = None, \
                 min_seq_len = None, \
                 save_mean_std = True, \
                 wav_samp_rate = None, \
                 flag_lang = 'EN', \
                 global_arg = None):
        """
        Store the configuration, validate it, check the file list, and
        load (or compute) the data length and mean/std statistics.

        args
        ----
          dataset_name: name of this data set
          file_list: a list of file name strings (without extension)
                     or, path to the file that contains the file names
          input_dirs: a list of dirs from which input feature is loaded
          input_exts: a list of input feature name extensions
          input_dims: a list of input feature dimensions
          input_reso: a list of input feature temporal resolutions,
                      None: use 1 for every input feature
          input_norm: a list of bool, whether normalize input feature or not
                      None: use True for every input feature
          output_dirs: a list of dirs from which output feature is loaded
                       (can be empty, output features are then not loaded)
          output_exts: a list of output feature name extensions
          output_dims: a list of output feature dimensions
          output_reso: a list of output feature temporal resolutions
                       None: use 1 for every output feature
          output_norm: a list of bool, whether normalize target feature or not
                       None: use True for every output feature
          stats_path: path to the directory that saves mean/std,
                      utterance length
          data_format: method to load the data. Only nii_dconf.h_dtype_str
                       (default) is accepted; any other value aborts
                       through nii_warn.f_die
          truncate_seq: None (default) or int, truncate sequence into trunks.
                        truncate_seq > 0 specifies the trunk length
          min_seq_len: None (default) or int, minimum length of an utterance
                        utterance shorter than min_seq_len will be ignored
          save_mean_std: bool, True (default): save mean and std
          wav_samp_rate: None (default) or int, if input data has waveform,
                         please set sampling rate. It is used by _data_writer
          flag_lang: str, 'EN' (default), if input data has text, the text will
                     be converted into code indices. flag_lang indicates the
                     language for the text processor. It is used by
                     _data_reader
          global_arg: argument parser returned by arg_parse.f_args_parsed()
                      default None. Only opt_wav_silence_handler is read here
        """
        # initialization
        self.m_set_name = dataset_name
        self.m_file_list = file_list
        self.m_input_dirs = input_dirs
        self.m_input_exts = input_exts
        self.m_input_dims = input_dims

        self.m_output_dirs = output_dirs
        self.m_output_exts = output_exts
        self.m_output_dims = output_dims

        # input dirs, exts, and dims must be given one-to-one
        if len(self.m_input_dirs) != len(self.m_input_exts) or \
           len(self.m_input_dirs) != len(self.m_input_dims):
            nii_warn.f_print("Input dirs, exts, dims, unequal length",
                             'error')
            nii_warn.f_print(str(self.m_input_dirs), 'error')
            nii_warn.f_print(str(self.m_input_exts), 'error')
            nii_warn.f_print(str(self.m_input_dims), 'error')
            nii_warn.f_die("Please check input dirs, exts, dims")

        # output dirs are allowed to be empty; they are only compared
        # against output exts when they are provided
        if len(self.m_output_dims) != len(self.m_output_exts) or \
           (self.m_output_dirs and \
            len(self.m_output_dirs) != len(self.m_output_exts)):
            nii_warn.f_print("Output dirs, exts, dims, unequal length", \
                             'error')
            nii_warn.f_die("Please check output dirs, exts, dims")

        # fill in m_*_reso and m_*_norm
        def _tmp_f(list2, default_value, length):
            # return list2, or a list of default values if list2 is None
            if list2 is None:
                return [default_value for x in range(length)]
            else:
                return list2

        self.m_input_reso = _tmp_f(input_reso, 1, len(input_dims))
        self.m_input_norm = _tmp_f(input_norm, True, len(input_dims))
        self.m_output_reso = _tmp_f(output_reso, 1, len(output_dims))
        self.m_output_norm = _tmp_f(output_norm, True, len(output_dims))
        if len(self.m_input_reso) != len(self.m_input_dims):
            nii_warn.f_die("len(input_reso) != len(input_dims) in config")
        if len(self.m_output_reso) != len(self.m_output_dims):
            # the message names output_dims, which is what is compared
            nii_warn.f_die("len(output_reso) != len(output_dims) in config")
        if len(self.m_input_norm) != len(self.m_input_dims):
            nii_warn.f_die("len(input_norm) != len(input_dims) in config")
        if len(self.m_output_norm) != len(self.m_output_dims):
            nii_warn.f_die("len(output_norm) != len(output_dims) in config")

        # dimensions
        self.m_input_all_dim = sum(self.m_input_dims)
        self.m_output_all_dim = sum(self.m_output_dims)
        self.m_io_dim = self.m_input_all_dim + self.m_output_all_dim

        self.m_truncate_seq = truncate_seq
        self.m_min_seq_len = min_seq_len
        self.m_save_ms = save_mean_std

        # in case there is waveform data in input or output features
        self.m_wav_sr = wav_samp_rate
        # option to process waveform with simple VAD
        if global_arg is not None:
            self.m_opt_wav_handler = global_arg.opt_wav_silence_handler
        else:
            self.m_opt_wav_handler = 0

        # in case there is text data in input or output features
        self.m_flag_lang = flag_lang

        # sanity check on resolution configuration
        # all input features must share one resolution, and all output
        # features must share one resolution
        if any([x != self.m_input_reso[0] for x in self.m_input_reso]):
            nii_warn.f_print("input_reso: %s" % (str(self.m_input_reso)),\
                             'error')
            nii_warn.f_print("NIIDataSet not support", 'error', end='')
            nii_warn.f_die(" different input_reso")

        if any([x != self.m_output_reso[0] for x in self.m_output_reso]):
            nii_warn.f_print("output_reso: %s" % (str(self.m_output_reso)),\
                             'error')
            nii_warn.f_print("NIIDataSet not support", 'error', end='')
            nii_warn.f_die(" different output_reso")
        if np.any(np.array(self.m_output_reso) < 0):
            nii_warn.f_print("NIIDataSet not support negative reso",
                             'error', end='')
            nii_warn.f_die(" Output reso: %s" % (str(self.m_output_reso)))
        if np.any(np.array(self.m_input_reso) < 0):
            # negative input resolution marks input that is not aligned
            # with the output; truncation is switched off in that case
            nii_warn.f_print("Input resolution: %s" % (str(self.m_input_reso)))
            nii_warn.f_print("Data IO for unaligned input and output pairs")
            if truncate_seq is not None:
                nii_warn.f_print("truncate is set to None", 'warning')
                self.m_truncate_seq = None
                self.m_min_seq_len = None


        # no need to constrain output_reso = 1
        #if any([x != 1 for x in self.m_output_reso]):
        #    nii_warn.f_print("NIIDataSet only supports", 'error', end='')
        #    nii_warn.f_die(" output_reso = [1, 1, ... 1]")
        #self.m_single_reso = self.m_input_reso[0]
        # largest resolution among all input and output features
        self.m_single_reso = np.max(self.m_input_reso + self.m_output_reso)

        # To make sure that target waveform length is exactly equal
        #  to the up-sampled sequence length
        # self.m_truncate_seq must be changed to be N * up_sample
        if self.m_truncate_seq is not None:
            # assume input resolution is the same
            self.m_truncate_seq = self.f_adjust_len(self.m_truncate_seq)

        # similarly on self.m_min_seq_len
        if self.m_min_seq_len is not None:
            # assume input resolution is the same
            self.m_min_seq_len = self.f_adjust_len(self.m_min_seq_len)

        # method to load/write raw data
        if data_format == nii_dconf.h_dtype_str:
            self.f_load_data = lambda x, y: _data_reader(x, y, self.m_flag_lang)
            self.f_length_data = _data_len_reader
            self.f_write_data = lambda x, y: _data_writer(x, y, self.m_wav_sr)
        else:
            nii_warn.f_print("Unsupported dtype %s" % (data_format))
            nii_warn.f_die("Only supports %s " % (nii_dconf.h_dtype_str))

        # check the validity of data
        self.f_check_file_list()

        # log down statistics
        #  1. length of each data utterance
        #  2. mean / std of feature feature file
        def get_name(stats_path, set_name, file_name):
            # statistics files are named <set_name>_<file_name>
            tmp = set_name + '_' + file_name
            return os.path.join(stats_path, tmp)

        self.m_ms_input_path = get_name(stats_path, self.m_set_name, \
                                        nii_dconf.mean_std_i_file)
        self.m_ms_output_path = get_name(stats_path, self.m_set_name, \
                                         nii_dconf.mean_std_o_file)
        self.m_data_len_path = get_name(stats_path, self.m_set_name, \
                                        nii_dconf.data_len_file)

        # initialize data length and mean /std, read prepared data stats
        flag_cal_len = self.f_init_data_len_stats(self.m_data_len_path)
        flag_cal_mean_std = self.f_init_mean_std(self.m_ms_input_path,
                                                 self.m_ms_output_path)

        # if data information is not available, read it again from data
        if flag_cal_len or flag_cal_mean_std:
            self.f_calculate_stats(flag_cal_len, flag_cal_mean_std)

        # check
        if self.__len__() < 1:
            nii_warn.f_print("Fail to load any data", "error")
            nii_warn.f_print("Possible reasons: ", "error")
            mes = "1. Old cache %s. Please delete it." % (self.m_data_len_path)
            mes += "\n2. input_dirs, input_exts, "
            mes += "output_dirs, or output_exts incorrect."
            mes += "\n3. all data are less than minimum_len in length. "
            mes += "\nThe last case may happen if truncate_seq == mininum_len "
            mes += "and truncate_seq % input_reso != 0. Then, the actual "
            mes += "truncate_seq becomes truncate_seq//input_reso*input_reso "
            mes += "and it will be shorter than minimum_len. Please change "
            mes += "truncate_seq and minimum_len so that "
            mes += "truncate_seq % input_reso == 0."
            nii_warn.f_print(mes, "error")
            nii_warn.f_die("Please check configuration file")
        # done
        return
        
    def __len__(self):
        """ __len__():
        Return the number of samples in the list
        """
        return len(self.m_seq_info)

    def __getitem__(self, idx):
        """ __getitem__(self, idx):
        Load the features of the idx-th entry in self.m_seq_info

        Return (in_data, out_data, seq_info_str, idx)
          in_data: np.array, input features concatenated along the
                   feature dimension
          out_data: np.array with the output features, or [] when
                    self.m_output_dirs is empty (e.g., test set data)
          seq_info_str: result of print_to_str() of the sequence info
          idx: index of the sample
        """
        def _fill_columns(buf, feat, col_s, col_e, feat_dim,
                          row_s, row_e, feat_path):
            # Copy one loaded feature into the columns col_s:col_e of buf.
            # A feature with a single frame is duplicated over all rows,
            # otherwise rows row_s:row_e of the feature are copied.
            one_frame = feat.shape[0] == 1
            if feat.ndim > 1:
                # multi-dimension data
                if one_frame:
                    buf[:, col_s:col_e] = feat[0, :]
                else:
                    buf[:, col_s:col_e] = feat[row_s:row_e, :]
            elif feat_dim == 1:
                # one-dimension data
                if one_frame:
                    buf[:, col_s] = feat
                else:
                    buf[:, col_s] = feat[row_s:row_e]
            else:
                nii_warn.f_die("Dimension wrong %s" % (feat_path))

        try:
            this_seq = self.m_seq_info[idx]
        except IndexError:
            nii_warn.f_die("Sample %d is not in seq_info" % (idx))

        # name shared by all feature files of this sample
        utt_name = this_seq.seq_tag()

        ###
        # input features
        ###
        in_reso = self.m_input_reso[0]
        num_row = int(this_seq.seq_length() // in_reso)
        row_s = int(this_seq.seq_start_pos() // in_reso)
        row_e = row_s + num_row

        # the length is negative when it cannot be derived from the
        # sequence info (input_reso < 0); use an empty buffer then
        if num_row < 0:
            num_row, row_s, row_e = 0, 0, 0

        in_data = np.zeros([num_row, self.m_input_all_dim],
                           dtype=nii_dconf.h_dtype)

        col_s = 0
        for feat_dir, feat_ext, feat_dim, feat_reso in zip(
                self.m_input_dirs, self.m_input_exts,
                self.m_input_dims, self.m_input_reso):
            col_e = col_s + feat_dim

            feat_path = nii_str_tk.f_realpath(feat_dir, utt_name, feat_ext)
            try:
                feat = self.f_load_data(feat_path, feat_dim)
            except IOError:
                nii_warn.f_die("Cannot find %s" % (feat_path))

            if feat_reso < 0:
                # input not aligned with output: the loaded data replaces
                # the buffer, reshaped to (length, dim) if necessary
                if feat.ndim == 1:
                    in_data = np.expand_dims(feat, axis=1)
                elif feat.ndim == 2:
                    in_data = feat
                else:
                    nii_warn.f_die("Default IO cannot handle %s" % (feat_path))
            else:
                _fill_columns(in_data, feat, col_s, col_e, feat_dim,
                              row_s, row_e, feat_path)
            col_s = col_e

        ###
        # output features
        ###
        if self.m_output_dirs:
            out_reso = self.m_output_reso[0]
            num_row = int(this_seq.seq_length() // out_reso)
            row_s = int(this_seq.seq_start_pos() // out_reso)
            row_e = row_s + num_row

            out_data = np.zeros([num_row, self.m_output_all_dim],
                                dtype = nii_dconf.h_dtype)

            col_s = 0
            for feat_dir, feat_ext, feat_dim in zip(
                    self.m_output_dirs, self.m_output_exts,
                    self.m_output_dims):
                col_e = col_s + feat_dim

                feat_path = nii_str_tk.f_realpath(feat_dir, utt_name, feat_ext)
                try:
                    feat = self.f_load_data(feat_path, feat_dim)
                except IOError:
                    nii_warn.f_die("Cannot find %s" % (feat_path))

                _fill_columns(out_data, feat, col_s, col_e, feat_dim,
                              row_s, row_e, feat_path)
                col_s = col_e
        else:
            # no output directory: nothing to load
            out_data = []

        # post processing if necessary
        in_data, out_data, this_seq, idx = self.f_post_data_process(
            in_data, out_data, this_seq, idx)

        # return data
        return in_data, out_data, this_seq.print_to_str(), idx


    def f_post_data_process(self, in_data, out_data, seq_info, idx):
        """A wrapper to process the data after loading from files
        """

        if self.m_opt_wav_handler > 0:
        
            # Do post processing one by one
            tmp_seq_info = nii_seqinfo.SeqInfo(
                seq_info.length, seq_info.seq_name, seq_info.seg_idx,
                seq_info.start_pos, seq_info.info_id)
        
            # waveform silence handler
            if len(self.m_input_exts) == 1 \
               and self.m_input_exts[0][-3:] == 'wav':
                in_data_n = nii_wav_tk.silence_handler(
                    in_data[:, 0], self.m_wav_sr, 
                    flag_output = self.m_opt_wav_handler)
                in_data_n = np.expand_dims(in_data_n, axis=1)
            
                # this is temporary setting, use length if it is compatible
                if tmp_seq_info.length == in_data.shape[0]:
                    tmp_seq_info.length = in_data_n.shape[0]
            else:
                in_data_n = in_data

            if len(self.m_output_exts) == 1 \
               and self.m_output_exts[0][-3:] == 'wav':
                out_data_n = nii_wav_tk.silence_handler(
                    out_data[:,0], self.m_wav_sr, 
                    flag_output = self.m_opt_wav_handler)
                out_data_n = np.expand_dims(out_data_n, axis=1)
            
                # this is temporary setting, use length if it is compatible
                if tmp_seq_info.length == out_data.shape[0]:
                    tmp_seq_info.length = out_data_n.shape[0]
            else:
                out_data_n = out_data

            return in_data_n, out_data_n, tmp_seq_info, idx

        else:
            return in_data, out_data, seq_info, idx
        
    
    def f_get_num_seq(self):
        """ __len__():
        Return the number of samples in the list
        """
        return len(self.m_seq_info)

    def f_get_seq_len_list(self):
        """ Return length of each sequence as list
        """
        return [x.seq_length() for x in self.m_seq_info]
    
    def f_get_mean_std_tuple(self):
        return (self.m_input_mean, self.m_input_std,
                self.m_output_mean, self.m_output_std)

    
    def f_check_file_list(self):
        """ f_check_file_list():
            Check the file list after initialization
            Make sure that the file in file_list appears in every
            input/output feature directory.
            If not, get a file_list in which every file is available
            in every input/output directory

            self.m_file_list can be a list, a path to a text file that
            contains the list, or anything else (e.g., None). In the last
            case, the list is created from the first input directory.
            The result is stored in self.m_file_list. The function dies
            through nii_warn.f_die if the final list is empty.
        """
        if not isinstance(self.m_file_list, list):
            if isinstance(self.m_file_list, str) and \
               os.path.isfile(self.m_file_list):
                # read the list if m_file_list is a str
                self.m_file_list = nii_list_tools.read_list_from_text(
                    self.m_file_list)
            else:
                # m_file_list may be None or another non-str object here;
                # convert to str first because "{:s}".format(None) raises
                # TypeError
                nii_warn.f_print(
                    "Cannot read {:s}".format(str(self.m_file_list)))
                nii_warn.f_print("Read file list from directories")
                self.m_file_list = None

        #  get a initial file list
        if self.m_file_list is None:
            self.m_file_list = nii_list_tools.listdir_with_ext(
                self.m_input_dirs[0], self.m_input_exts[0])

        # check the list of files exist in all input/output directories
        # (common_members presumably keeps the names found in both lists)
        for tmp_d, tmp_e in zip(self.m_input_dirs, \
                                self.m_input_exts):
            tmp_list = nii_list_tools.listdir_with_ext(tmp_d, tmp_e)
            self.m_file_list = nii_list_tools.common_members(
                tmp_list, self.m_file_list)

        if len(self.m_file_list) < 1:
            nii_warn.f_print("No input features found after scannning", 'error')
            nii_warn.f_print("Please check %s" \
                             % (str(self.m_input_dirs)), 'error')
            nii_warn.f_print("They should contain all files in file list",
                             'error')
            nii_warn.f_print("Please also check filename extentions %s" \
                             % (str(self.m_input_exts)), 'error')
            nii_warn.f_print("They should be correctly specified", 'error')
            nii_warn.f_die("Failed to read input features")

        # check output files if necessary
        if self.m_output_dirs:
            for tmp_d, tmp_e in zip(self.m_output_dirs, \
                                    self.m_output_exts):
                tmp_list = nii_list_tools.listdir_with_ext(tmp_d, tmp_e)
                self.m_file_list = nii_list_tools.common_members(
                    tmp_list, self.m_file_list)

            if len(self.m_file_list) < 1:
                nii_warn.f_print("No output data found", 'error')
                nii_warn.f_print("Please check %s" \
                                 % (str(self.m_output_dirs)), 'error')
                nii_warn.f_print("They should contain all files in file list",
                                 'error')
                nii_warn.f_print("Please also check filename extentions %s" \
                                 % (str(self.m_output_exts)), 'error')
                nii_warn.f_print("They should be correctly specified", 'error')
                nii_warn.f_die("Failed to read output features")
        else:
            # no output directory, output features are not checked
            pass

        # done
        return
        

    def f_valid_len(self, t_1, t_2, min_length):
        """ f_valid_len(t_1, t_2, min_length)
        When either t_1 > min_length or t_2 > min_length, check whether
        t_1 and t_2 are too different

        args
        ----
          t_1: length (number of time steps) of the first sequence
          t_2: length (number of time steps) of the second sequence
          min_length: the lengths are only compared when at least one of
                      them is larger than min_length

        return
        ------
          False if the lengths are compared and |t_1 - t_2| / t_1 > 0.1,
          True otherwise
        """
        if max(t_1, t_2) <= min_length:
            # both are short, no need to compare
            return True

        if t_1 == 0:
            # the relative difference is not defined for t_1 = 0;
            # the two lengths only match if t_2 is 0 as well
            return bool(t_2 == 0)

        # difference relative to t_1 must not exceed 10%
        return not ((np.abs(t_1 - t_2) * 1.0 / t_1) > 0.1)

    def f_check_specific_data(self, file_name):
        """ f_check_specific_data(file_name)
        Print the data length of every input/output feature of file_name.

        For each feature type, the file path is built from the feature
        directory, file_name and the extension. If the file exists, its
        path, length (self.f_length_data(path) // dim), dimension and
        resolution are printed. Otherwise nii_warn.f_die is called.
        """
        # input feature types first, output feature types after them
        all_dirs = list(self.m_input_dirs) + list(self.m_output_dirs)
        all_exts = list(self.m_input_exts) + list(self.m_output_exts)
        all_dims = list(self.m_input_dims) + list(self.m_output_dims)
        all_reso = list(self.m_input_reso) + list(self.m_output_reso)

        for feat_dir, feat_ext, feat_dim, feat_reso in \
            zip(all_dirs, all_exts, all_dims, all_reso):

            file_path = nii_str_tk.f_realpath(feat_dir, file_name, feat_ext)
            if nii_io_tk.file_exist(file_path):
                num_steps = self.f_length_data(file_path) // feat_dim
                print("%s, length %d, dim %d, reso: %d" % \
                      (file_path, num_steps, feat_dim, feat_reso))
            else:
                nii_warn.f_die("%s not found" % (file_path))
        return

        
    def f_log_data_len(self, file_name, t_len, t_reso):
        """ f_log_data_len(file_name, t_len, t_reso):
        Log down the length of the data file.

        When comparing the different input/output features for the same
        file_name, only keep the shortest length

        Args
        ----
            file_name: key used in self.m_data_length
            t_len: length (number of frames) of this feature
            t_reso: temporal resolution (up-sampling rate) of this feature

        The logged length is t_len * t_reso. If this length and the one
        already logged for file_name are judged as too different by
        self.f_valid_len, the lengths of all features are printed and
        nii_warn.f_die is called.
        """
        # We need to exclude features that should not be considered when
        #  calculating the sequence length
        #  1. sentence-level vector (t_len = 1)
        #  2. unaligned feature (text in text-to-speech) (t_reso < 0)
        if not (t_len > 1 and t_reso > 0):
            return

        # the length for the sequence with the fast temporal rate
        # For example, acoustic-feature -> waveform 16kHz,
        # if acoustic-feature is one frame per 5ms,
        #  tmp_len = acoustic feature frame length * (5 * 16)
        # where t_reso = 5*16 is the up-sampling rate of acoustic feature
        tmp_len = t_len * t_reso

        # save length when have not read the file
        if file_name not in self.m_data_length:
            self.m_data_length[file_name] = tmp_len

        logged_len = self.m_data_length[file_name]
        if self.f_valid_len(logged_len, tmp_len,
                            nii_dconf.data_seq_min_length):
            # the difference in length is small, keep the shorter one
            self.m_data_length[file_name] = min(logged_len, tmp_len)
        else:
            nii_warn.f_print("Sequence length mismatch:", 'error')
            self.f_check_specific_data(file_name)
            nii_warn.f_print("Please check the above features", 'error')
            nii_warn.f_die("Possible invalid data %s" % (file_name))

        # adjust the length so that, when reso is used,
        # the sequence length will be N * reso
        self.m_data_length[file_name] = self.f_adjust_len(
            self.m_data_length[file_name])
        return

    def f_adjust_len(self, length):
        """ f_adjust_len(length)
        Round length down to the nearest multiple of self.m_single_reso.

        When input data will be up-sampled by self.m_single_reso, this
        makes the sequence length at the up-sampled level equal to
        N * self.m_single_reso.
        For data without up-sampling m_single_reso = 1
        """
        reso = self.m_single_reso
        num_blocks = divmod(length, reso)[0]
        return num_blocks * reso

    def f_precheck_data_length(self):
        """ f_precheck_data_length()
        For unaligned input and output, there is no way to know the
        target sequence length before hand during inference stage, and
        self.m_data_length will be empty.

        When self.m_data_length is empty, no output directory is given
        and every input resolution is negative, a fake length of 0 is
        logged for each file in self.m_file_list.
        """
        if self.m_data_length or self.m_output_dirs:
            # lengths have been logged, or output features are available
            return
        if not all(x < 0 for x in self.m_input_reso):
            # at least one input feature is not marked as unaligned
            return

        # inference stage, when only input is given
        # manually create a fake data length for each utterance
        for utt_name in self.m_file_list:
            self.m_data_length[utt_name] = 0
        return

        
    
    def f_log_seq_info(self):
        """ f_log_seq_info()
        After m_data_length has been created, create seq_info

        Files in self.m_file_list without an entry in self.m_data_length
        are reported and skipped. For the other files, one SeqInfo is
        created for the whole file (no truncation), or one SeqInfo per
        segment of at most self.m_truncate_seq steps. A SeqInfo is saved
        to self.m_seq_info only if self.m_min_seq_len is None or its
        length is not shorter than self.m_min_seq_len.

        self.m_data_total_length is updated at the end.
        """
        def _long_enough(num_steps):
            # whether a sequence (segment) of num_steps should be kept
            return self.m_min_seq_len is None or \
                num_steps >= self.m_min_seq_len

        for utt_name in self.m_file_list:

            # if the file is not logged, ignore this file
            if utt_name not in self.m_data_length:
                nii_warn.f_eprint("Exclude %s from dataset" % (utt_name))
                continue

            remain = self.m_data_length[utt_name]
            offset, seg_cnt = 0, 0

            if self.m_truncate_seq is None:
                # no truncation, one entry for the whole file
                entry = nii_seqinfo.SeqInfo(remain, utt_name, seg_cnt,
                                            offset, len(self.m_seq_info))
                if _long_enough(remain):
                    self.m_seq_info.append(entry)
                continue

            # truncation, one entry for each segment
            while remain > 0:
                chunk = min(self.m_truncate_seq, remain)
                entry = nii_seqinfo.SeqInfo(chunk, utt_name, seg_cnt,
                                            offset, len(self.m_seq_info))
                if _long_enough(chunk):
                    self.m_seq_info.append(entry)
                    seg_cnt += 1
                offset += chunk
                remain -= chunk

        # get the total length
        self.m_data_total_length = self.f_sum_data_length()
        return
        
    def f_init_mean_std(self, ms_input_path, ms_output_path):
        """ flag = f_init_mean_std(ms_input_path, ms_output_path)
        Initialize mean and std vectors for input and output

        Mean is set to 0 and std to 1 first. If both ms_input_path and
        ms_output_path exist and the loaded vectors have the expected
        size (2 * feature dimension), mean/std are taken from the files.

        Return
        ------
            flag: bool, True if mean/std should be computed from the data.
                  False if self.m_save_ms is not set, if no feature
                  needs normalization, or if mean/std have been loaded
        """
        in_dim = self.m_input_all_dim
        out_dim = self.m_output_all_dim

        self.m_input_mean = np.zeros([in_dim])
        self.m_input_std = np.ones([in_dim])
        self.m_output_mean = np.zeros([out_dim])
        self.m_output_std = np.ones([out_dim])

        # if not m_save_ms, assume mean/std will be loaded from the
        # network, for example, for validation and test sets
        wants_save = bool(self.m_save_ms)
        # whether any of the input / output features needs norm
        wants_norm = any(self.m_input_norm + self.m_output_norm)
        need_compute = wants_save and wants_norm

        if not (os.path.isfile(ms_input_path) and
                os.path.isfile(ms_output_path)):
            return need_compute

        # load mean and std if exists
        ms_input = self.f_load_data(ms_input_path, 1)
        ms_output = self.f_load_data(ms_output_path, 1)
        input_ok = ms_input.shape[0] == (in_dim * 2)
        output_ok = ms_output.shape[0] == (out_dim * 2)

        if input_ok and output_ok:
            # first half is the mean, second half is the std
            self.m_input_mean = ms_input[:in_dim]
            self.m_input_std = ms_input[in_dim:]
            self.m_output_mean = ms_output[:out_dim]
            self.m_output_std = ms_output[out_dim:]
            nii_warn.f_print("Load mean/std from %s and %s" % \
                             (ms_input_path, ms_output_path))
            return False

        if not input_ok:
            nii_warn.f_print("%s incompatible" % (ms_input_path),
                             'warning')
        if not output_ok:
            nii_warn.f_print("%s incompatible" % (ms_output_path),
                             'warning')
        nii_warn.f_print("mean/std will be recomputed", 'warning')
        return need_compute


    def f_sum_data_length(self):
        """ f_sum_data_length()
        Return the sum of seq_length() over all the entries in
        self.m_seq_info (0 if self.m_seq_info is empty)
        """
        return sum(seq_entry.seq_length() for seq_entry in self.m_seq_info)
        
    def f_init_data_len_stats(self, data_path):
        """ flag = f_init_data_len_stats(data_path)
        Check whether data length has been stored in data_path.
        If yes, load data_path and return False
        Else, return True

        self.m_seq_info, self.m_data_length and self.m_data_total_length
        are always reset first. They are filled from data_path only when
        the names stored there are compatible with self.m_file_list.

        Args
        ----
            data_path: str, path to the pre-stored sequence info (*.dic)

        Return
        ------
            flag: bool, True if data length has to be computed again
        """
        self.m_seq_info = []
        self.m_data_length = {}
        self.m_data_total_length = 0

        if not os.path.isfile(data_path):
            return True

        # load data length from pre-stored *.dic
        # (read the file that has just been checked, i.e., data_path)
        dic_seq_infos = nii_io_tk.read_dic(data_path)
        for dic_seq_info in dic_seq_infos:
            seq_info = nii_seqinfo.SeqInfo()
            seq_info.load_from_dic(dic_seq_info)
            self.m_seq_info.append(seq_info)
            # segments with the same tag are accumulated
            seq_tag = seq_info.seq_tag()
            self.m_data_length[seq_tag] = \
                self.m_data_length.get(seq_tag, 0) + seq_info.seq_length()
        self.m_data_total_length = self.f_sum_data_length()

        # check whether *.dic contains files in filelist
        # note: a file in the file list may have no entry in the *.dic,
        #  for example when its sequences were excluded as too short
        logged_names = self.m_data_length.keys()
        if nii_list_tools.list_identical(self.m_file_list, logged_names):
            nii_warn.f_print("Read sequence info: %s" % (data_path))
            return False

        if nii_list_tools.list_b_in_list_a(self.m_file_list, logged_names):
            nii_warn.f_print("Read sequence info: %s" % (data_path))
            nii_warn.f_print(
                "However %d samples are ignored" % \
                (len(self.m_file_list) - len(self.m_data_length)))
            excluded = nii_list_tools.members_in_a_not_in_b(
                self.m_file_list, logged_names)
            for tmp_name in excluded:
                nii_warn.f_eprint("Exclude %s from dataset" % (tmp_name))
            return False

        # the stored info is not compatible with the file list,
        # discard what has been loaded and compute again
        self.m_seq_info = []
        self.m_data_length = {}
        self.m_data_total_length = 0
        return True

    def f_save_data_len(self, data_len_path):
        """ f_save_data_len(data_len_path)
        Convert each entry of self.m_seq_info with print_to_dic() and
        write the resulting list to data_len_path (nii_io_tk.write_dic)
        """
        seq_dics = [seq_entry.print_to_dic() for seq_entry in self.m_seq_info]
        nii_io_tk.write_dic(seq_dics, data_len_path)
        
    def f_save_mean_std(self, ms_input_path, ms_output_path):
        """ f_save_mean_std(ms_input_path, ms_output_path)
        Save mean and std of the input (output) features to
        ms_input_path (ms_output_path) through self.f_write_data.

        Each saved vector has 2 * dim elements: the mean vector first,
        followed by the std vector.
        """
        def _stack_mean_std(mean_vec, std_vec, dim):
            # layout: [mean (dim elements), std (dim elements)]
            stacked = np.zeros([dim * 2])
            stacked[:dim] = mean_vec
            stacked[dim:] = std_vec
            return stacked

        self.f_write_data(
            _stack_mean_std(self.m_input_mean, self.m_input_std,
                            self.m_input_all_dim),
            ms_input_path)

        self.f_write_data(
            _stack_mean_std(self.m_output_mean, self.m_output_std,
                            self.m_output_all_dim),
            ms_output_path)

        return

    def f_print_info(self):
        """ f_print_info()
        Compose a summary of this dataset (name, total time steps,
        truncation length, number of sequences, max/min sequence length,
        configuration of input and output features) and print it
        through nii_warn.f_print_message
        """
        def _feature_lines(header, dirs, exts, dims, reso, norm):
            # lines describing either the input or the output features
            lines = [header]
            lines.extend("\n        {:s}".format(subdir) for subdir in dirs)
            lines.append("\n    Exts:{:s}".format(str(exts)))
            lines.append("\n    Dims:{:s}".format(str(dims)))
            lines.append("\n    Reso:{:s}".format(str(reso)))
            lines.append("\n    Norm:{:s}".format(str(norm)))
            return lines

        pieces = ["Dataset {}:".format(self.m_set_name)]
        pieces.append(
            "\n  Time steps: {:d} ".format(self.m_data_total_length))
        if self.m_truncate_seq is not None:
            pieces.append(
                "\n  Truncate length: {:d}".format(self.m_truncate_seq))
        pieces.append(
            "\n  Data sequence num: {:d}".format(len(self.m_seq_info)))

        seq_lens = [x.seq_length() for x in self.m_seq_info]
        shortest = min(seq_lens)
        longest = max(seq_lens)
        pieces.append("\n  Maximum sequence length: {:d}".format(longest))
        pieces.append("\n  Minimum sequence length: {:d}".format(shortest))
        if self.m_min_seq_len is not None:
            pieces.append("\n  Shorter sequences are ignored")

        pieces.extend(_feature_lines(
            "\n  Inputs\n    Dirs:",
            self.m_input_dirs, self.m_input_exts, self.m_input_dims,
            self.m_input_reso, self.m_input_norm))
        pieces.extend(_feature_lines(
            "\n  Outputs\n    Dirs:",
            self.m_output_dirs, self.m_output_exts, self.m_output_dims,
            self.m_output_reso, self.m_output_norm))

        if self.m_opt_wav_handler > 0:
            pieces.append("\n  Waveform silence handler will be used")
        nii_warn.f_print_message("".join(pieces))
        return
    
    def f_calculate_stats(self, flag_cal_data_len, flag_cal_mean_std):
        """ f_calculate_stats(flag_cal_data_len, flag_cal_mean_std)
        Log down the number of time steps for each file
        Calculate the mean/std

        Args
        ----
            flag_cal_data_len: bool, whether to log the length of each
                  file, create seq_info and save it to m_data_len_path
            flag_cal_mean_std: bool, whether to accumulate mean/std of
                  each feature type and save them to m_ms_input_path
                  and m_ms_output_path
        """
        # directory, extension, dimension, resolution and norm flag of
        # each feature type: input features first, output features next
        feat_dirs = list(self.m_input_dirs) + list(self.m_output_dirs)
        feat_exts = list(self.m_input_exts) + list(self.m_output_exts)
        feat_dims = list(self.m_input_dims) + list(self.m_output_dims)
        feat_reso = list(self.m_input_reso) + list(self.m_output_reso)
        feat_norm = list(self.m_input_norm) + list(self.m_output_norm)

        # ending dimension of the previous feature type
        dim_end = 0

        # loop over each input/output feature type
        for t_dir, t_ext, t_dim, t_reso, t_norm in \
            zip(feat_dirs, feat_exts, feat_dims, feat_reso, feat_norm):

            # starting / ending dimension of this feature type
            dim_start, dim_end = dim_end, dim_end + t_dim
            frame_cnt = 0
            mean_acc = np.zeros([t_dim])
            var_acc = np.zeros([t_dim])

            # loop over all the data
            for file_name in self.m_file_list:
                file_path = nii_str_tk.f_realpath(t_dir, file_name, t_ext)
                if not nii_io_tk.file_exist(file_path):
                    nii_warn.f_die("%s not found" % (file_path))

                # read the length of the data
                if flag_cal_data_len:
                    num_steps = self.f_length_data(file_path) // t_dim
                    self.f_log_data_len(file_name, num_steps, t_reso)

                if not flag_cal_mean_std:
                    continue

                # accumulate the mean/std recursively
                feat_data = self.f_load_data(file_path, t_dim)

                # if this is F0 data, only consider voiced data
                if t_ext in nii_dconf.f0_unvoiced_dic:
                    unvoiced_value = nii_dconf.f0_unvoiced_dic[t_ext]
                    feat_data = feat_data[feat_data > unvoiced_value]

                # online accumulation of mean, variance and count
                mean_acc, var_acc, frame_cnt = nii_stats.f_online_mean_std(
                    feat_data, mean_acc, var_acc, frame_cnt)

            if not flag_cal_mean_std:
                continue

            # save mean and std for one feature type
            # if not normalize this dimension, set mean=0, std=1
            if not t_norm:
                mean_acc[:] = 0
                var_acc[:] = 1

            if dim_start < self.m_input_all_dim:
                # this feature type belongs to the input
                self.m_input_mean[dim_start:dim_end] = mean_acc
                self.m_input_std[dim_start:dim_end] = \
                    nii_stats.f_var2std(var_acc)
            else:
                # this feature type belongs to the output, the index
                # is counted from the end of the input dimensions
                out_start = dim_start - self.m_input_all_dim
                out_end = dim_end - self.m_input_all_dim
                self.m_output_mean[out_start:out_end] = mean_acc
                self.m_output_std[out_start:out_end] = \
                    nii_stats.f_var2std(var_acc)

        if flag_cal_data_len:
            # create fake length if necessary
            self.f_precheck_data_length()
            # create seq_info
            self.f_log_seq_info()
            # save len information
            self.f_save_data_len(self.m_data_len_path)

        if flag_cal_mean_std:
            self.f_save_mean_std(self.m_ms_input_path,
                                 self.m_ms_output_path)
        # done
        return
        
    def f_putitem(self, output_data, save_dir, data_infor_str):
        """ f_putitem(output_data, save_dir, data_infor_str)
        Decompose the output data along the feature dimension and save
        each output feature type into a separate file in save_dir.

        Args
        ----
            output_data: array in shape (1, length, dim) or (1, length),
                         where dim must be equal to self.m_output_all_dim
            save_dir: str, directory to save the files, it will be
                      created if it does not exist
            data_infor_str: str, sequence information parsed by
                      SeqInfo.parse_from_str. The sequence tag is used
                      as the name of the output files

        nii_warn.f_die is called if the data shape is not supported, the
        dimension is not the expected one, or save_dir cannot be created.
        """
        # Change the dimension to (length, dim)
        if output_data.ndim == 3 and output_data.shape[0] == 1:
            # When input data is (batchsize=1, length, dim)
            output_data = output_data[0]
        elif output_data.ndim == 2 and output_data.shape[0] == 1:
            # When input data is (batchsize=1, length)
            output_data = np.expand_dims(output_data[0], -1)
        else:
            nii_warn.f_print("Output data format not supported.", "error")
            nii_warn.f_print("Format is not (batch, len, dim)", "error")
            nii_warn.f_die("Please use batch_size = 1 in generation")

        # Check the dimension
        if output_data.shape[1] != self.m_output_all_dim:
            nii_warn.f_print("Output data dim != expected dim", "error")
            nii_warn.f_print("Output:%d" % (output_data.shape[1]),
                             "error")
            nii_warn.f_print("Expected:%d" % (self.m_output_all_dim),
                             "error")
            nii_warn.f_die("Please check configuration")

        if not os.path.isdir(save_dir):
            try:
                os.mkdir(save_dir)
            except OSError:
                nii_warn.f_die("Cannot create {}".format(save_dir))

        # read the sentence information
        tmp_seq_info = nii_seqinfo.SeqInfo()
        tmp_seq_info.parse_from_str(data_infor_str)

        # write the data, one file for each output feature type
        file_name = tmp_seq_info.seq_tag()
        s_dim = 0
        for t_ext, t_dim in zip(self.m_output_exts, self.m_output_dims):
            e_dim = s_dim + t_dim
            file_path = nii_str_tk.f_realpath(save_dir, file_name, t_ext)
            self.f_write_data(output_data[:, s_dim:e_dim], file_path)
            # the next feature type starts where this one ends;
            # otherwise every file would receive columns [0:t_dim]
            s_dim = e_dim

        return

    def f_input_dim(self):
        """ f_input_dim()
        Return self.m_input_all_dim, the total dimension of the
        input features
        """
        total_dim = self.m_input_all_dim
        return total_dim
    
    def f_output_dim(self):
        """ f_output_dim()
        Return self.m_output_all_dim, the total dimension of the
        output features
        """
        total_dim = self.m_output_all_dim
        return total_dim

    def f_adjust_idx(self, data_tuple, idx_shift):
        """ f_adjust_idx(data_tuple, idx_shift)

        This is to be used by customize_dataset for idx adjustment.
        When multiple data sets are merged, the idx from __getitem__
        should be adjusted.

        Only data_io itself knows how to identify idx from the output of
        __getitem__, so the function is defined here.

        Every element of data_tuple[-1] is increased by idx_shift in
        place, and data_tuple itself is returned.
        """
        idx_holder = data_tuple[-1]
        for pos in range(len(idx_holder)):
            idx_holder[pos] += idx_shift
        return data_tuple
    
class NIIDataSetLoader:
    """ NIIDataSetLoader:
    A wrapper over torch.utils.data.DataLoader 
    
    self.m_dataset will be the dataset
    self.m_loader  will be the dataloader
    self.m_params  will be a copy of the loader configuration, saved
                   before the sampler is initialized
    """
    def __init__(self,
                 dataset_name, \
                 file_list, \
                 input_dirs, input_exts, input_dims, input_reso, \
                 input_norm, \
                 output_dirs, output_exts, output_dims, output_reso, \
                 output_norm, \
                 stats_path, \
                 data_format = nii_dconf.h_dtype_str, \
                 params = None, \
                 truncate_seq = None, \
                 min_seq_len = None,
                 save_mean_std = True, \
                 wav_samp_rate = None, \
                 flag_lang = 'EN',
                 global_arg = None):
        """
        NIIDataSetLoader(
               data_set_name,
               file_list,
               input_dirs, input_exts, input_dims, input_reso, input_norm,
               output_dirs, output_exts, output_dims, output_reso, output_norm,
               stats_path,
               data_format = '<f4',
               params = None,
               truncate_seq = None,
               min_seq_len = None,
               save_mean_std = True, \
               wav_samp_rate = None, \
               flag_lang = 'EN',
               global_arg = None):
        Args
        ----
            data_set_name: a string to name this dataset
                           this will be used to name the statistics files
                           such as the mean/std for this dataset
            file_list: a list of file name strings (without extension)
                     or, path to the file that contains the file names
            input_dirs: a list of dirs from which input feature is loaded
            input_exts: a list of input feature name extensions
            input_dims: a list of input feature dimensions
            input_reso: a list of input feature temporal resolution,
                        or None
            input_norm: a list of bool, whether normalize input feature or not

            output_dirs: a list of dirs from which output feature is loaded
            output_exts: a list of output feature name extensions
            output_dims: a list of output feature dimensions
            output_reso: a list of output feature temporal resolution, 
                         or None
            output_norm: a list of bool, whether normalize target feature or not

            stats_path: path to the directory of statistics(mean/std)
            data_format: method to load the data
                    '<f4' (default): load data as float32, little-endian
                    'htk': load data as htk format
            params: parameter for torch.utils.data.DataLoader
                    if None, a copy of nii_dconf.default_loader_conf is used

            truncate_seq: None or int, 
                          truncate data sequence into smaller chunks
                          truncate_seq > 0 specifies the chunk length
            min_seq_len: None (default) or int, minimum length of an utterance
                         utterance shorter than min_seq_len will be ignored
            save_mean_std: bool, True (default): save mean and std 
            wav_samp_rate: None (default) or int, if input data has  waveform, 
                         please set sampling rate. It is used by _data_writer
            flag_lang: str, 'EN' (default), if input data has text, text will
                       be converted into code indices. flag_lang indicates the 
                       language for the text processor, used by _data_reader
            global_arg: argument parser returned by arg_parse.f_args_parsed()
                      default None
        Methods
        -------
            get_loader(): return a torch.utils.data.DataLoader
            get_dataset(): return a torch.utils.data.Dataset
        """
        nii_warn.f_print_w_date("Loading dataset %s" % (dataset_name),
                                level="h")
        
        # create torch.utils.data.Dataset
        self.m_dataset = NIIDataSet(dataset_name, \
                                    file_list, \
                                    input_dirs, input_exts, \
                                    input_dims, input_reso, \
                                    input_norm, \
                                    output_dirs, output_exts, \
                                    output_dims, output_reso, \
                                    output_norm, \
                                    stats_path, data_format, \
                                    truncate_seq, min_seq_len,\
                                    save_mean_std, \
                                    wav_samp_rate, \
                                    flag_lang, \
                                    global_arg)
        
        # configuration for torch.utils.data.DataLoader
        # Always work on a private copy: the configuration is modified
        # below (shuffle, sampler), and neither the caller's dictionary
        # nor the shared default in nii_dconf should be changed
        if params is None:
            tmp_params = nii_dconf.default_loader_conf.copy()
        else:
            tmp_params = params.copy()
            
        # save parameters
        self.m_params = tmp_params.copy()
        
        # initialize sampler if necessary
        if 'sampler' in tmp_params:
            # any sampler name that is not recognized leads to None,
            # i.e., no customized sampler is given to the DataLoader
            tmp_sampler = None
            if tmp_params['sampler'] == nii_sampler_fn.g_str_sampler_bsbl:
                if 'batch_size' in tmp_params:
                    # initialize the sampler
                    tmp_sampler = nii_sampler_fn.SamplerBlockShuffleByLen(
                        self.m_dataset.f_get_seq_len_list(), 
                        tmp_params['batch_size'])
                    # turn off automatic shuffle
                    tmp_params['shuffle'] = False                    
                else:
                    nii_warn.f_die("Sampler requires batch size > 1")
            tmp_params['sampler'] = tmp_sampler
            

        # collate function
        if 'batch_size' in tmp_params and tmp_params['batch_size'] > 1:
            # for batch-size > 1, use customize_collate to handle
            # data with different length
            collate_fn = nii_collate_fn.customize_collate
        else:
            collate_fn = None
            
        self.m_loader = torch.utils.data.DataLoader(
            self.m_dataset, collate_fn=collate_fn, **tmp_params)

        # done
        return
        
    def get_loader_params(self):
        """ get_loader_params():
        Return the loader configuration saved in __init__ (a dictionary
        copied before the sampler was initialized)
        """
        return self.m_params
    
    def get_loader(self):
        """ get_loader():
        Return the dataLoader (torch.utils.data.DataLoader)
        """
        return self.m_loader
    
    def get_dataset(self):
        """ get_dataset():
        Return the dataset (torch.utils.data.Dataset)
        """
        return self.m_dataset

    def get_data_mean_std(self):
        """ get_data_mean_std():
        Return the output of self.m_dataset.f_get_mean_std_tuple()
        """
        return self.m_dataset.f_get_mean_std_tuple()

    def print_info(self):
        """ print_info():
        Print the information of the dataset (f_print_info) and the
        loader configuration
        """
        self.m_dataset.f_print_info()
        print(str(self.m_params))
        return

    def putitem(self, output_data, save_dir, data_infor_str):
        """ Decompose the output_data from network into
        separate files
        """
        self.m_dataset.f_putitem(output_data, save_dir, data_infor_str)

    def get_in_dim(self):
        """ Return the dimension of input features
        """ 
        return self.m_dataset.f_input_dim()

    def get_out_dim(self):
        """ Return the dimension of output features
        """
        return self.m_dataset.f_output_dim()

    def get_seq_num(self):
        """ Return the number of sequences (after truncation)
        """ 
        return self.m_dataset.f_get_num_seq()

    def adjust_utt_idx(self, data_tuple, utt_idx_shift):
        """ Return data tuple with adjusted utterance index in merged dataset
        
        This is used by customize_dataset.
        """
        return self.m_dataset.f_adjust_idx(data_tuple, utt_idx_shift)
    
# Nothing is executed when this file is run as a script.
if __name__ == "__main__":
    pass