Commit ·
cb2b82e
1
Parent(s): cbf9056
filter wav 10s and new pretrained model
Browse files- main.py +27 -13
- model-bin/pretrained/base/pytorch_model.bin +1 -1
main.py
CHANGED
|
@@ -45,6 +45,7 @@ def load_pretrained_model(checkpoint_path=None):
|
|
| 45 |
)
|
| 46 |
# model.freeze_feature_extractor()
|
| 47 |
|
|
|
|
| 48 |
model_total_params = sum(p.numel() for p in model.parameters())
|
| 49 |
model_total_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 50 |
print(model)
|
|
@@ -68,15 +69,19 @@ def prepare_dataset(batch, processor):
|
|
| 68 |
return batch
|
| 69 |
|
| 70 |
|
| 71 |
-
def load_prepared_dataset(path, processor,
|
| 72 |
dataset = load_from_disk(path)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
processed_dataset = dataset.map(prepare_dataset,
|
| 74 |
remove_columns=dataset.column_names,
|
| 75 |
batch_size=32,
|
| 76 |
-
num_proc=
|
| 77 |
batched=True,
|
| 78 |
fn_kwargs={"processor": processor},
|
| 79 |
-
cache_file_name=
|
| 80 |
return processed_dataset
|
| 81 |
|
| 82 |
|
|
@@ -105,9 +110,9 @@ if __name__ == "__main__":
|
|
| 105 |
output_dir=checkpoint_path,
|
| 106 |
fp16=True,
|
| 107 |
group_by_length=True,
|
| 108 |
-
per_device_train_batch_size=
|
| 109 |
-
per_device_eval_batch_size=
|
| 110 |
-
gradient_accumulation_steps=
|
| 111 |
num_train_epochs=num_epochs, # each epoch per shard data
|
| 112 |
logging_steps=1,
|
| 113 |
learning_rate=1e-4,
|
|
@@ -150,17 +155,26 @@ if __name__ == "__main__":
|
|
| 150 |
train_dataset = load_prepared_dataset(os.path.join(train_dataset_root_folder,
|
| 151 |
'shard_{}'.format(train_dataset_shard_idx)),
|
| 152 |
w2v_ctc_processor,
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
# load test shard subset
|
| 158 |
test_dataset = load_prepared_dataset(os.path.join(test_dataset_root_folder,
|
| 159 |
'shard_{}'.format(test_dataset_shard_idx)),
|
| 160 |
w2v_ctc_processor,
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
)
|
| 165 |
test_dataset = test_dataset.shard(num_test_sub_shard, idx_sub_shard)
|
| 166 |
# Init trainer
|
|
|
|
| 45 |
)
|
| 46 |
# model.freeze_feature_extractor()
|
| 47 |
|
| 48 |
+
# model = Wav2Vec2ForCTC(model.config)
|
| 49 |
model_total_params = sum(p.numel() for p in model.parameters())
|
| 50 |
model_total_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 51 |
print(model)
|
|
|
|
| 69 |
return batch
|
| 70 |
|
| 71 |
|
| 72 |
+
def load_prepared_dataset(path, processor, cache_file_filter_name, cache_file_map_name, num_proc=8):
|
| 73 |
dataset = load_from_disk(path)
|
| 74 |
+
dataset = dataset.filter(lambda example: len(example['speech']) < 160000,
|
| 75 |
+
batch_size=32,
|
| 76 |
+
num_proc=num_proc,
|
| 77 |
+
cache_file_name=cache_file_filter_name)
|
| 78 |
processed_dataset = dataset.map(prepare_dataset,
|
| 79 |
remove_columns=dataset.column_names,
|
| 80 |
batch_size=32,
|
| 81 |
+
num_proc=num_proc,
|
| 82 |
batched=True,
|
| 83 |
fn_kwargs={"processor": processor},
|
| 84 |
+
cache_file_name=cache_file_map_name)
|
| 85 |
return processed_dataset
|
| 86 |
|
| 87 |
|
|
|
|
| 110 |
output_dir=checkpoint_path,
|
| 111 |
fp16=True,
|
| 112 |
group_by_length=True,
|
| 113 |
+
per_device_train_batch_size=32,
|
| 114 |
+
per_device_eval_batch_size=32,
|
| 115 |
+
gradient_accumulation_steps=2,
|
| 116 |
num_train_epochs=num_epochs, # each epoch per shard data
|
| 117 |
logging_steps=1,
|
| 118 |
learning_rate=1e-4,
|
|
|
|
| 155 |
train_dataset = load_prepared_dataset(os.path.join(train_dataset_root_folder,
|
| 156 |
'shard_{}'.format(train_dataset_shard_idx)),
|
| 157 |
w2v_ctc_processor,
|
| 158 |
+
cache_file_filter_name=os.path.join(cache_processing_dataset_folder,
|
| 159 |
+
'train',
|
| 160 |
+
'cache-train-filter-shard-{}.arrow'.format(
|
| 161 |
+
train_dataset_shard_idx)),
|
| 162 |
+
cache_file_map_name=os.path.join(cache_processing_dataset_folder,
|
| 163 |
+
'train',
|
| 164 |
+
'cache-train-map-shard-{}.arrow'.format(
|
| 165 |
+
train_dataset_shard_idx)),
|
| 166 |
+
) #.shard(1000, 0) # Remove shard split when train
|
| 167 |
# load test shard subset
|
| 168 |
test_dataset = load_prepared_dataset(os.path.join(test_dataset_root_folder,
|
| 169 |
'shard_{}'.format(test_dataset_shard_idx)),
|
| 170 |
w2v_ctc_processor,
|
| 171 |
+
cache_file_filter_name=os.path.join(cache_processing_dataset_folder,
|
| 172 |
+
'test',
|
| 173 |
+
'cache-test-filter-shard-{}.arrow'.format(
|
| 174 |
+
test_dataset_shard_idx)),
|
| 175 |
+
cache_file_map_name=os.path.join(cache_processing_dataset_folder, 'test',
|
| 176 |
+
'cache-test-map-shard-{}.arrow'.format(
|
| 177 |
+
test_dataset_shard_idx))
|
| 178 |
)
|
| 179 |
test_dataset = test_dataset.shard(num_test_sub_shard, idx_sub_shard)
|
| 180 |
# Init trainer
|
model-bin/pretrained/base/pytorch_model.bin
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 380261837
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b8fc5e67c00d407cd160a238034677db5670cbc77fe766c53d1042478509574d
|
| 3 |
size 380261837
|