Commit · 1e275bf
Parent(s): b839dd6

add config for training multi epochs

Browse files:
- callbacks.py +12 -0
- main.py +29 -22
callbacks.py
ADDED

@@ -0,0 +1,12 @@
+from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl, logging
+
+
+class BreakEachEpoch(TrainerCallback):
+    """
+    A :class:`~transformers.TrainerCallback` that stops the training loop at the end of each epoch,
+    so the outer loop can reload the next dataset shard.
+    """
+    def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+        control.should_training_stop = True
+        logging.get_logger().info("Break at end of epoch to reload the next dataset shard")
+        return control
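Setting control.should_training_stop makes Trainer.train() return once the current epoch finishes, which is what lets the loop in main.py swap in the next shard. A minimal smoke test of that contract (a sketch, not part of the commit; assumes transformers is installed and callbacks.py is on the import path):

from transformers import TrainingArguments, TrainerState, TrainerControl
from callbacks import BreakEachEpoch

# Trainer calls on_epoch_end(args, state, control) at the end of every epoch;
# the callback only has to flip should_training_stop on the control object.
callback = BreakEachEpoch()
control = callback.on_epoch_end(TrainingArguments(output_dir="./tmp"),
                                TrainerState(), TrainerControl())
assert control.should_training_stop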
main.py
CHANGED

@@ -1,12 +1,14 @@
-from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor
+from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, \
+    TrainerCallback
 from datasets import load_from_disk
 from data_handler import DataCollatorCTCWithPadding
 from transformers import TrainingArguments
 from transformers import Trainer, logging
 from metric_utils import compute_metrics_fn
 from transformers.trainer_utils import get_last_checkpoint
-import json
+import json
 import os, glob
+from callbacks import BreakEachEpoch
 
 logging.set_verbosity_info()
 
@@ -68,8 +70,8 @@ def load_prepared_dataset(path, processor, cache_file_name):
     dataset = load_from_disk(path)
     processed_dataset = dataset.map(prepare_dataset,
                                     remove_columns=dataset.column_names,
-                                    batch_size=
-                                    num_proc=
+                                    batch_size=32,
+                                    num_proc=4,
                                     batched=True,
                                     fn_kwargs={"processor": processor},
                                     cache_file_name=cache_file_name)
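Here dataset.map() runs prepare_dataset over batches of 32 rows in 4 worker processes and materializes the result into the explicit .arrow cache file, so re-processing the same shard hits the cache instead of recomputing. A toy run of the same map() pattern (a sketch with made-up data; wrap it in an if __name__ == "__main__" guard on platforms that spawn worker processes):

from datasets import Dataset

def upper(batch):
    # with batched=True, map() passes a dict of column -> list of values
    return {"text": [t.upper() for t in batch["text"]]}

ds = Dataset.from_dict({"text": ["a", "b", "c", "d"]})
out = ds.map(upper, batched=True, batch_size=2, num_proc=2,
             cache_file_name="./cache-map-demo.arrow")
print(out["text"])  # ['A', 'B', 'C', 'D']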
@@ -90,8 +92,9 @@ if __name__ == "__main__":
     test_dataset_root_folder = '/content/drive/MyDrive/audio_dataset/test_dataset'
 
     cache_processing_dataset_folder = './data-bin/cache/'
-    if not os.path.exists(cache_processing_dataset_folder):
-        os.makedirs(cache_processing_dataset_folder)
+    if not os.path.exists(os.path.join(cache_processing_dataset_folder, 'train')):
+        os.makedirs(os.path.join(cache_processing_dataset_folder, 'train'))
+        os.makedirs(os.path.join(cache_processing_dataset_folder, 'test'))
     num_train_shards = len(glob.glob(os.path.join(train_dataset_root_folder, 'shard_*')))
     num_test_shards = len(glob.glob(os.path.join(test_dataset_root_folder, 'shard_*')))
    num_epochs = 5000
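One nit in this hunk: the existence check keys on the 'train' subfolder alone, so a cache directory that already has 'train' but no 'test' would skip both makedirs calls. An idempotent variant (a sketch, not part of the commit):

import os

cache_processing_dataset_folder = './data-bin/cache/'
# exist_ok=True makes the calls safe to repeat, one per cache subfolder
for sub in ('train', 'test'):
    os.makedirs(os.path.join(cache_processing_dataset_folder, sub), exist_ok=True)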
@@ -100,20 +103,21 @@ if __name__ == "__main__":
         output_dir=checkpoint_path,
         # fp16=True,
         group_by_length=True,
-        per_device_train_batch_size=
-        per_device_eval_batch_size=
+        per_device_train_batch_size=4,
+        per_device_eval_batch_size=4,
         gradient_accumulation_steps=8,
-        num_train_epochs=
+        num_train_epochs=num_epochs,  # one epoch per data shard
         logging_steps=1,
         learning_rate=1e-4,
         weight_decay=0.005,
-        warmup_steps=
+        warmup_steps=1000,
         save_total_limit=2,
         ignore_data_skip=True,
         logging_dir=os.path.join(checkpoint_path, 'log'),
         metric_for_best_model='wer',
         save_strategy="epoch",
         evaluation_strategy="epoch",
+        greater_is_better=False,
         # save_steps=5,
         # eval_steps=5,
     )
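For scale: with per_device_train_batch_size=4 and gradient_accumulation_steps=8, each optimizer step sees an effective batch of 4 × 8 = 32 examples per device. The new greater_is_better=False pairs with metric_for_best_model='wer', since a lower word error rate is better.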
@@ -143,19 +147,19 @@ if __name__ == "__main__":
         train_dataset = load_prepared_dataset(os.path.join(train_dataset_root_folder,
                                                            'shard_{}'.format(train_dataset_shard_idx)),
                                               w2v_ctc_processor,
-                                              cache_file_name=os.path.join(cache_processing_dataset_folder,
+                                              cache_file_name=os.path.join(cache_processing_dataset_folder, 'train',
                                                                            'cache-train-shard-{}.arrow'.format(
                                                                                train_dataset_shard_idx))
-                                              )
+                                              ).shard(1000, 0)  # remove this shard split when training for real
         # load test shard subset
         test_dataset = load_prepared_dataset(os.path.join(test_dataset_root_folder,
                                                           'shard_{}'.format(test_dataset_shard_idx)),
                                              w2v_ctc_processor,
-                                             cache_file_name=os.path.join(cache_processing_dataset_folder,
+                                             cache_file_name=os.path.join(cache_processing_dataset_folder, 'test',
                                                                           'cache-test-shard-{}.arrow'.format(
                                                                               test_dataset_shard_idx))
-                                             )
-
+                                             )
+        test_dataset = test_dataset.shard(num_test_sub_shard, idx_sub_shard)
         # Init trainer
         trainer = Trainer(
             model=w2v_ctc_model,
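Note the .shard(1000, 0) appended to the train set: datasets.Dataset.shard(num_shards, index) keeps roughly 1/num_shards of the rows, so as committed each "epoch" trains on about 0.1% of the shard until that call is removed, and the eval set is likewise cut down with num_test_sub_shard/idx_sub_shard. A toy illustration (a sketch; contiguous=False, the classic default, selects every num_shards-th row):

from datasets import Dataset

ds = Dataset.from_dict({"x": list(range(10))})
# keep 1 of every 5 rows, starting at index 0
print(ds.shard(num_shards=5, index=0, contiguous=False)["x"])  # [0, 5]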
@@ -164,13 +168,16 @@ if __name__ == "__main__":
             compute_metrics=compute_metrics_fn(w2v_ctc_processor),
             train_dataset=train_dataset,
             eval_dataset=test_dataset,
-            tokenizer=w2v_ctc_processor.feature_extractor
+            tokenizer=w2v_ctc_processor.feature_extractor,
+            callbacks=[BreakEachEpoch()]  # manually break at each epoch end because every epoch loops over one shard
         )
-        # Manual add num_train_epochs because each epoch loop over a shard
-        training_args.num_train_epochs = epoch_idx + 1
 
-
-
+        # training_args.num_train_epochs = epoch_idx + 1
+
+        logging.get_logger().info('Train epoch {}'.format(training_args.num_train_epochs))
+        logging.get_logger().info('Train shard idx: {} / {}'.format(train_dataset_shard_idx + 1, num_train_shards))
+        logging.get_logger().info(
+            'Valid shard idx: {} / {} sub_shard: {}'.format(test_dataset_shard_idx + 1, num_test_shards, idx_sub_shard))
 
         if last_checkpoint_path is not None:
             # start train from a checkpoint if exist
@@ -181,5 +188,5 @@ if __name__ == "__main__":
         last_checkpoint_path = get_last_checkpoint(checkpoint_path)
 
         # Clear cache file to free disk
-
-
+        test_dataset.cleanup_cache_files()
+        train_dataset.cleanup_cache_files()
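Finally, the cleanup calls reclaim disk between shards: Dataset.cleanup_cache_files() deletes the cache-* files in the dataset's cache directory except the one currently backing the dataset, and returns the number of files removed (note the cache file names above all start with 'cache-', matching the pattern cleanup looks for in typical datasets versions). A toy sketch (hypothetical file names):

from datasets import Dataset

ds = Dataset.from_dict({"x": [1, 2, 3]})
step1 = ds.map(lambda b: b, batched=True, cache_file_name="./cache-step1.arrow")
step2 = step1.map(lambda b: {"x": [v + 1 for v in b["x"]]}, batched=True,
                  cache_file_name="./cache-step2.arrow")
# should delete ./cache-step1.arrow and keep ./cache-step2.arrow
print(step2.cleanup_cache_files())  # expected: 1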