Commit ·
cbf9056
1
Parent(s): f1bbf33
fix trainer
Browse files
main.py
CHANGED
|
@@ -123,6 +123,7 @@ if __name__ == "__main__":
|
|
| 123 |
# save_steps=5,
|
| 124 |
# eval_steps=5,
|
| 125 |
)
|
|
|
|
| 126 |
|
| 127 |
# PretrainedConfig.from_json_file(os.path.join(resume_from_checkpoint, CONFIG_NAME))
|
| 128 |
last_checkpoint_path = None
|
|
@@ -163,16 +164,20 @@ if __name__ == "__main__":
|
|
| 163 |
)
|
| 164 |
test_dataset = test_dataset.shard(num_test_sub_shard, idx_sub_shard)
|
| 165 |
# Init trainer
|
| 166 |
-
trainer
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
|
| 177 |
logging.get_logger().info('Train shard idx: {} / {}'.format(train_dataset_shard_idx + 1, num_train_shards))
|
| 178 |
logging.get_logger().info(
|
|
|
|
| 123 |
# save_steps=5,
|
| 124 |
# eval_steps=5,
|
| 125 |
)
|
| 126 |
+
trainer = None
|
| 127 |
|
| 128 |
# PretrainedConfig.from_json_file(os.path.join(resume_from_checkpoint, CONFIG_NAME))
|
| 129 |
last_checkpoint_path = None
|
|
|
|
| 164 |
)
|
| 165 |
test_dataset = test_dataset.shard(num_test_sub_shard, idx_sub_shard)
|
| 166 |
# Init trainer
|
| 167 |
+
if trainer is None:
|
| 168 |
+
trainer = Trainer(
|
| 169 |
+
model=w2v_ctc_model,
|
| 170 |
+
data_collator=data_collator,
|
| 171 |
+
args=training_args,
|
| 172 |
+
compute_metrics=compute_metrics_fn(w2v_ctc_processor),
|
| 173 |
+
train_dataset=train_dataset,
|
| 174 |
+
eval_dataset=test_dataset,
|
| 175 |
+
tokenizer=w2v_ctc_processor.feature_extractor,
|
| 176 |
+
callbacks=[BreakEachEpoch()] # Manually break at the end of each epoch, because each epoch loops over a shard
|
| 177 |
+
)
|
| 178 |
+
else:
|
| 179 |
+
trainer.train_dataset = train_dataset
|
| 180 |
+
trainer.eval_dataset = test_dataset
|
| 181 |
|
| 182 |
logging.get_logger().info('Train shard idx: {} / {}'.format(train_dataset_shard_idx + 1, num_train_shards))
|
| 183 |
logging.get_logger().info(
|