add prefetch
main.py CHANGED
@@ -10,6 +10,7 @@ import json
 import os, glob
 from callbacks import BreakEachEpoch
 import subprocess
+from multiprocessing import Process
 
 logging.set_verbosity_info()
 
@@ -70,8 +71,34 @@ def prepare_dataset(batch, processor):
     return batch
 
 
-def load_prepared_dataset(path, processor, cache_file_filter_name, cache_file_map_name, num_proc=
+def load_prepared_dataset(path, processor, cache_file_filter_name, cache_file_map_name, num_proc=6):
     dataset = load_from_disk(path)
+    list_cache_prefetch_files = glob.glob(
+        cache_file_map_name.replace(cache_processing_dataset_folder, cache_processing_dataset_folder_prefetch).replace(
+            '.arrow', '*'))
+
+    # Do not re-compute what already in cache folder
+    if cache_file_map_name.startswith(cache_processing_dataset_folder_prefetch):
+        if len(glob.glob(cache_file_map_name.replace(cache_processing_dataset_folder_prefetch,
+                                                     cache_processing_dataset_folder).replace('.arrow', '*'))) > 0:
+            return
+        if len(list_cache_prefetch_files) > 0:
+            return
+
+    # check cache file
+    if len(glob.glob(cache_file_map_name.replace('.arrow', '*'))) == 0 and len(list_cache_prefetch_files) > 0:
+        for item_file in list_cache_prefetch_files:
+            os.rename(item_file, item_file.replace(cache_processing_dataset_folder_prefetch,
+                                                   cache_processing_dataset_folder))
+    if len(glob.glob(cache_file_map_name.replace('.arrow', '*'))) > 0:
+        return dataset.map(prepare_dataset,
+                           remove_columns=dataset.column_names,
+                           batch_size=32,
+                           num_proc=num_proc,
+                           batched=True,
+                           fn_kwargs={"processor": processor},
+                           cache_file_name=cache_file_map_name)
+
     dataset = dataset.filter(lambda example: len(example['speech']) < 160000,
                              batch_size=32,
                              num_proc=num_proc,
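The added block makes load_prepared_dataset() aware of the two cache folders: a call that targets the prefetch folder returns early when either folder already holds the shard, while a call that targets the live folder first promotes any prefetched .arrow files by renaming them, so the dataset.map() that follows resolves to a cache hit. A minimal sketch of that promotion step, with illustrative paths and a hypothetical helper name (promote_prefetched is not part of main.py):

import glob
import os

cache_dir = './data-bin/cache/'              # live cache (illustrative)
prefetch_dir = './data-bin/cache_prefetch/'  # prefetch cache (illustrative)

def promote_prefetched(cache_file_map_name):
    # Rename cache files written by a prefetch worker into the live cache
    # folder, so datasets' map() sees them as an existing cache and skips
    # recomputation.
    prefetched = glob.glob(cache_file_map_name
                           .replace(cache_dir, prefetch_dir)
                           .replace('.arrow', '*'))
    if not glob.glob(cache_file_map_name.replace('.arrow', '*')):
        for item in prefetched:
            os.rename(item, item.replace(prefetch_dir, cache_dir))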
@@ -83,6 +110,7 @@ def load_prepared_dataset(path, processor, cache_file_filter_name, cache_file_ma
                                     batched=True,
                                     fn_kwargs={"processor": processor},
                                     cache_file_name=cache_file_map_name)
+    processed_dataset.cleanup_cache_files()
     return processed_dataset
 
 
@@ -95,6 +123,44 @@ def commit_checkpoint():
     for command in submit_commands:
         print(subprocess.run(command.split(), stdout=subprocess.PIPE).stdout.decode('utf-8'))
 
+
+def get_train_test_shard_id(epoch_count):
+    # loop over training shards
+    _train_dataset_shard_idx = epoch_count % num_train_shards
+    # Get test shard depend on train shard id
+    _test_dataset_shard_idx = round(_train_dataset_shard_idx / (num_train_shards / num_test_shards))
+    _num_test_sub_shard = 8  # Split test shard into subset. Default is 8
+    _idx_sub_shard = _train_dataset_shard_idx % _num_test_sub_shard  # loop over test shard subset
+    return _train_dataset_shard_idx, _test_dataset_shard_idx, _num_test_sub_shard, _idx_sub_shard
+
+
+def process_prefetch_epoch(epoch_count):
+    train_shard_idx, test_shard_idx, _, _ = get_train_test_shard_id(epoch_count)
+    load_prepared_dataset(os.path.join(train_dataset_root_folder,
+                                       'shard_{}'.format(train_shard_idx)),
+                          w2v_ctc_processor,
+                          cache_file_filter_name=os.path.join(cache_processing_dataset_folder_prefetch,
+                                                              'train',
+                                                              'cache-train-filter-shard-{}.arrow'.format(
+                                                                  train_shard_idx)),
+                          cache_file_map_name=os.path.join(cache_processing_dataset_folder_prefetch,
+                                                           'train',
+                                                           'cache-train-map-shard-{}.arrow'.format(
+                                                               train_shard_idx)),
+                          )
+    load_prepared_dataset(os.path.join(test_dataset_root_folder,
+                                       'shard_{}'.format(test_shard_idx)),
+                          w2v_ctc_processor,
+                          cache_file_filter_name=os.path.join(cache_processing_dataset_folder_prefetch,
+                                                              'test',
+                                                              'cache-test-filter-shard-{}.arrow'.format(
+                                                                  test_shard_idx)),
+                          cache_file_map_name=os.path.join(cache_processing_dataset_folder_prefetch, 'test',
+                                                           'cache-test-map-shard-{}.arrow'.format(
+                                                               test_shard_idx))
+                          )
+
+
 if __name__ == "__main__":
 
     checkpoint_path = "./model-bin/finetune/base/"
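For a concrete feel for get_train_test_shard_id(), a worked example with hypothetical shard counts (none of these numbers come from the commit):

num_train_shards, num_test_shards = 10, 2  # hypothetical counts
epoch_count = 7
train_idx = epoch_count % num_train_shards                          # -> 7
test_idx = round(train_idx / (num_train_shards / num_test_shards))  # round(1.4) -> 1
sub_idx = train_idx % 8                                             # -> 7

One caveat worth knowing: for the largest train-shard index this rounding can reach num_test_shards itself (round(9 / 5) == 2 in this example), which points one past the last test shard, shard_1.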
@@ -106,9 +172,13 @@ if __name__ == "__main__":
     test_dataset_root_folder = '/content/drive/MyDrive/audio_dataset/test_dataset'
 
     cache_processing_dataset_folder = './data-bin/cache/'
+    cache_processing_dataset_folder_prefetch = './data-bin/cache_prefetch/'
     if not os.path.exists(os.path.join(cache_processing_dataset_folder, 'train')):
         os.makedirs(os.path.join(cache_processing_dataset_folder, 'train'))
         os.makedirs(os.path.join(cache_processing_dataset_folder, 'test'))
+    if not os.path.exists(os.path.join(cache_processing_dataset_folder_prefetch, 'train')):
+        os.makedirs(os.path.join(cache_processing_dataset_folder_prefetch, 'train'))
+        os.makedirs(os.path.join(cache_processing_dataset_folder_prefetch, 'test'))
     num_train_shards = len(glob.glob(os.path.join(train_dataset_root_folder, 'shard_*')))
     num_test_shards = len(glob.glob(os.path.join(test_dataset_root_folder, 'shard_*')))
     num_epochs = 5000
@@ -121,7 +191,7 @@ if __name__ == "__main__":
         per_device_eval_batch_size=32,
         gradient_accumulation_steps=2,
         num_train_epochs=num_epochs,  # each epoch per shard data
-        logging_steps=
+        logging_steps=5,
         learning_rate=1e-5,
         weight_decay=0.005,
         warmup_steps=1000,
@@ -150,13 +220,23 @@ if __name__ == "__main__":
     w2v_ctc_model, w2v_ctc_processor = load_pretrained_model()
     data_collator = DataCollatorCTCWithPadding(processor=w2v_ctc_processor, padding=True)
 
+    prefetch_process = []
+
     for epoch_idx in range(last_epoch_idx, num_epochs):
-        # loop over training shards
-        train_dataset_shard_idx = epoch_idx % num_train_shards
-        # Get test shard depend on train shard id
-        test_dataset_shard_idx = round(train_dataset_shard_idx / (num_train_shards / num_test_shards))
-        num_test_sub_shard = 8  # Split test shard into subset. Default is 8
-        idx_sub_shard = train_dataset_shard_idx % num_test_sub_shard  # loop over test shard subset
+        # # loop over training shards
+        # train_dataset_shard_idx = epoch_idx % num_train_shards
+        # # Get test shard depend on train shard id
+        # test_dataset_shard_idx = round(train_dataset_shard_idx / (num_train_shards / num_test_shards))
+        # num_test_sub_shard = 8  # Split test shard into subset. Default is 8
+        # idx_sub_shard = train_dataset_shard_idx % num_test_sub_shard  # loop over test shard subset
+
+        train_dataset_shard_idx, test_dataset_shard_idx, num_test_sub_shard, idx_sub_shard = get_train_test_shard_id(
+            epoch_idx)
+
+        # waiting for all prefetch process done
+        for process_instance in prefetch_process:
+            process_instance.join()
+        prefetch_process.clear()
 
         # load train shard
         train_dataset = load_prepared_dataset(os.path.join(train_dataset_root_folder,
@@ -170,7 +250,7 @@ if __name__ == "__main__":
                                               'train',
                                               'cache-train-map-shard-{}.arrow'.format(
                                                   train_dataset_shard_idx)),
-                                              )
+                                              )  # .shard(1000, 0) # Remove shard split when train
         # load test shard subset
         test_dataset = load_prepared_dataset(os.path.join(test_dataset_root_folder,
                                              'shard_{}'.format(test_dataset_shard_idx)),
@@ -184,6 +264,12 @@ if __name__ == "__main__":
                                                  test_dataset_shard_idx))
                                             )
         test_dataset = test_dataset.shard(num_test_sub_shard, idx_sub_shard)
+
+        # Prefetch_dataset
+        prefetch_process.append(Process(target=process_prefetch_epoch, args=(epoch_idx + 1,)))
+        for process_instance in prefetch_process:
+            process_instance.start()
+
         # Init trainer
         if trainer is None:
             trainer = Trainer(
@@ -216,5 +302,5 @@ if __name__ == "__main__":
         test_dataset.cleanup_cache_files()
         train_dataset.cleanup_cache_files()
 
-        if epoch_idx %
-
+        if epoch_idx % 5 == 0:
+            commit_checkpoint()
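Taken together, the main loop now overlaps data preparation with training: it joins last epoch's prefetch workers, loads the current shard from the promoted cache, forks a worker that prepares epoch_idx + 1 into the prefetch folder, then trains. A minimal self-contained sketch of the pattern, with prepare_shard standing in for process_prefetch_epoch:

from multiprocessing import Process

def prepare_shard(epoch):
    # Stand-in for process_prefetch_epoch: warm the caches for `epoch`.
    print('prefetching shard for epoch', epoch)

if __name__ == '__main__':
    prefetchers = []
    for epoch in range(3):
        # Wait for the previous epoch's prefetch so its cache files are
        # complete before they are promoted and used.
        for p in prefetchers:
            p.join()
        prefetchers.clear()

        # ... load the (already cached) train/test shards for `epoch` ...

        # Kick off preparation of the next epoch's shard in the background.
        p = Process(target=prepare_shard, args=(epoch + 1,))
        prefetchers.append(p)
        p.start()

        # ... trainer.train() would run here, overlapping with the prefetch ...

    for p in prefetchers:
        p.join()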
|