Commit · 062f3fc
1 Parent(s): 8bb6457
Update training script
Browse files
- run_wav2vec2_pretrain_flax.py +3 -0
- train.sh +6 -7
run_wav2vec2_pretrain_flax.py
CHANGED
|
@@ -174,6 +174,7 @@ class FlaxDataCollatorForWav2Vec2Pretraining:
|
|
| 174 |
|
| 175 |
batch_size = batch["input_values"].shape[0]
|
| 176 |
|
|
|
|
| 177 |
if batch["attention_mask"] is not None:
|
| 178 |
output_lengths = self.model._get_feat_extract_output_lengths(batch["attention_mask"].sum(-1))
|
| 179 |
attention_mask = np.zeros((batch_size, mask_indices_seq_length), dtype=np.int8)
|
|
@@ -196,6 +197,7 @@ class FlaxDataCollatorForWav2Vec2Pretraining:
|
|
| 196 |
batch["sampled_negative_indices"] = _sample_negative_indices(
|
| 197 |
(batch["mask_time_indices"].shape + (self.model.config.proj_codevector_dim,)),
|
| 198 |
self.model.config.num_negatives,
|
|
|
|
| 199 |
)
|
| 200 |
|
| 201 |
return batch
|
|
@@ -342,6 +344,7 @@ def main():
|
|
| 342 |
def normalize(batch):
|
| 343 |
return feature_extractor(batch["speech"], sampling_rate=feature_extractor.sampling_rate)
|
| 344 |
|
|
|
|
| 345 |
# normalize and transform to `BatchFeatures`
|
| 346 |
vectorized_datasets = vectorized_datasets.map(
|
| 347 |
normalize,
|
|
|
|
| 174 |
|
| 175 |
batch_size = batch["input_values"].shape[0]
|
| 176 |
|
| 177 |
+
attention_mask = None
|
| 178 |
if batch["attention_mask"] is not None:
|
| 179 |
output_lengths = self.model._get_feat_extract_output_lengths(batch["attention_mask"].sum(-1))
|
| 180 |
attention_mask = np.zeros((batch_size, mask_indices_seq_length), dtype=np.int8)
|
|
|
|
| 197 |
batch["sampled_negative_indices"] = _sample_negative_indices(
|
| 198 |
(batch["mask_time_indices"].shape + (self.model.config.proj_codevector_dim,)),
|
| 199 |
self.model.config.num_negatives,
|
| 200 |
+
attention_mask=attention_mask,
|
| 201 |
)
|
| 202 |
|
| 203 |
return batch
|
|
|
|
| 344 |
def normalize(batch):
|
| 345 |
return feature_extractor(batch["speech"], sampling_rate=feature_extractor.sampling_rate)
|
| 346 |
|
| 347 |
+
batch_size = 64
|
| 348 |
# normalize and transform to `BatchFeatures`
|
| 349 |
vectorized_datasets = vectorized_datasets.map(
|
| 350 |
normalize,
|
train.sh
CHANGED
|
@@ -1,22 +1,21 @@
|
|
| 1 |
#!/usr/bin/env bash
|
| 2 |
-
./
|
| 3 |
-
--output_dir="./
|
| 4 |
--num_train_epochs="5" \
|
| 5 |
-
--per_device_train_batch_size="
|
| 6 |
-
--per_device_eval_batch_size="
|
| 7 |
--learning_rate="5e-4" \
|
| 8 |
--weight_decay="0.01" \
|
| 9 |
-
--warmup_steps="
|
| 10 |
--model_name_or_path="./" \
|
| 11 |
--dataset_name="common_voice" \
|
| 12 |
--dataset_config_name="es" \
|
| 13 |
-
--preprocessing_num_workers="
|
| 14 |
--max_duration_in_seconds="10.0" \
|
| 15 |
--adam_beta1="0.9" \
|
| 16 |
--adam_beta2="0.98" \
|
| 17 |
--pad_to_multiple_of="16384" \
|
| 18 |
--validation_split_percentage="5" \
|
| 19 |
--speech_file_column="path" \
|
| 20 |
-
--dtype="bfloat16" \
|
| 21 |
--cache_dir="./data_cache" \
|
| 22 |
--push_to_hub
|
|
|
|
| 1 |
#!/usr/bin/env bash
|
| 2 |
+
./run_wav2vec2_pretrain_flax.py \
|
| 3 |
+
--output_dir="./wav2vec2-spanish" \
|
| 4 |
--num_train_epochs="5" \
|
| 5 |
+
--per_device_train_batch_size="16" \
|
| 6 |
+
--per_device_eval_batch_size="16" \
|
| 7 |
--learning_rate="5e-4" \
|
| 8 |
--weight_decay="0.01" \
|
| 9 |
+
--warmup_steps="1000" \
|
| 10 |
--model_name_or_path="./" \
|
| 11 |
--dataset_name="common_voice" \
|
| 12 |
--dataset_config_name="es" \
|
| 13 |
+
--preprocessing_num_workers="32" \
|
| 14 |
--max_duration_in_seconds="10.0" \
|
| 15 |
--adam_beta1="0.9" \
|
| 16 |
--adam_beta2="0.98" \
|
| 17 |
--pad_to_multiple_of="16384" \
|
| 18 |
--validation_split_percentage="5" \
|
| 19 |
--speech_file_column="path" \
|
|
|
|
| 20 |
--cache_dir="./data_cache" \
|
| 21 |
--push_to_hub
|