Commit 9ab2a02
Parent(s): 603c683

add updated files

Files changed:
- config.json +1 -1
- default_config.yaml +2 -2
- run_main.sh +7 -5
- run_pretrain_no_trainer.py +8 -11

config.json CHANGED

@@ -65,7 +65,7 @@
   "mask_time_length": 10,
   "mask_time_min_space": 1,
   "mask_time_other": 0.0,
-  "mask_time_prob": 0.
+  "mask_time_prob": 0.65,
   "mask_time_selection": "static",
   "model_type": "wav2vec2",
   "num_attention_heads": 16,
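
A note on the key change: in the Hugging Face wav2vec2 masking code, mask_time_prob together with mask_time_length sets how many length-10 spans get masked per sequence, so the new 0.65 roughly upper-bounds the fraction of masked frames. A minimal sketch of that arithmetic (the 499-frame sequence length is a hypothetical ~10 s clip at the usual 20 ms stride, not a value from the diff):

    # rough masking budget implied by the new config values
    mask_time_prob = 0.65    # from the diff
    mask_time_length = 10    # from the diff
    seq_len = 499            # assumed example length

    num_spans = int(mask_time_prob * seq_len / mask_time_length)
    masked_frames = num_spans * mask_time_length
    print(f"{num_spans} spans, up to {masked_frames / seq_len:.0%} of frames masked")
    # -> 32 spans, up to 64% of frames masked (less where spans overlap)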

default_config.yaml CHANGED

@@ -1,6 +1,6 @@
 compute_environment: LOCAL_MACHINE
 deepspeed_config:
-  gradient_accumulation_steps: 4
+  gradient_accumulation_steps: 4
   offload_optimizer_device: cpu
   zero_stage: 2
 distributed_type: DEEPSPEED
@@ -10,4 +10,4 @@ main_process_ip: null
 main_process_port: null
 main_training_function: main
 num_machines: 0
-num_processes:
+num_processes: 8
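
gradient_accumulation_steps now appears both here (for DeepSpeed) and as a flag in run_main.sh, and num_processes must match the available GPUs, so a tiny consistency check can catch drift between the two. A sketch only, assuming PyYAML is installed and the file sits at its repo path:

    import yaml  # assumed dependency, not part of the repo

    with open("default_config.yaml") as f:
        cfg = yaml.safe_load(f)

    # values that must agree with run_main.sh / the machine
    assert cfg["deepspeed_config"]["gradient_accumulation_steps"] == 4
    assert cfg["num_processes"] == 8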

run_main.sh CHANGED

@@ -5,16 +5,18 @@ accelerate launch --config_file ./default_config.yaml ./run_pretrain_no_trainer
     --max_train_steps="200000" \
     --num_warmup_steps="100000" \
     --gradient_accumulation_steps="4" \
-    --learning_rate="0.
+    --learning_rate="0.005" \
     --weight_decay="0.01" \
-    --max_duration_in_seconds="
+    --max_duration_in_seconds="10.0" \
     --model_name_or_path="./" \
     --dataset_name="patrickvonplaten/librispeech_local" \
-    --manual_data_dir="/home/
+    --manual_data_dir="/home/ubuntu/wav2vec2_reproduce" \
     --dataset_config_name="clean" \
     --logging_steps="5" \
-    --per_device_train_batch_size="
-    --per_device_eval_batch_size="
+    --per_device_train_batch_size="8" \
+    --per_device_eval_batch_size="8" \
+    #--per_device_train_batch_size="16" \
+    #--per_device_eval_batch_size="16" \
     #--preprocessing_num_workers="4" \
     #--adam_beta1="0.9" \
     #--adam_beta2="0.98" \
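
Taken together with default_config.yaml, the new batch settings fix the effective batch per optimizer step. The arithmetic below uses only numbers from the diff:

    # effective batch implied by the new settings
    num_processes = 8                 # default_config.yaml
    per_device_train_batch_size = 8   # run_main.sh
    gradient_accumulation_steps = 4   # run_main.sh
    max_duration_in_seconds = 10.0    # run_main.sh

    clips_per_step = num_processes * per_device_train_batch_size * gradient_accumulation_steps
    print(clips_per_step)                                 # 256 clips per optimizer step
    print(clips_per_step * max_duration_in_seconds / 60)  # up to ~42.7 min of audio per step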

run_pretrain_no_trainer.py CHANGED

@@ -34,9 +34,6 @@ MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
 MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
 
 
-wandb.init(project="pretraining-wav2vec2")
-
-
 def parse_args():
     parser = argparse.ArgumentParser(description="Finetune a transformers model on a text classification task")
     parser.add_argument(
@@ -330,6 +327,8 @@ def main():
     if accelerator.is_local_main_process:
         datasets.utils.logging.set_verbosity_warning()
         transformers.utils.logging.set_verbosity_info()
+
+        wandb.init(project="pretraining-wav2vec2")
     else:
         datasets.utils.logging.set_verbosity_error()
         transformers.utils.logging.set_verbosity_error()
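
Moving wandb.init out of module scope matters in multi-process training: at import time, every one of the 8 launched workers would have opened its own W&B run. A sketch of the guarded pattern the hunk lands on (same calls as the diff, trimmed to the essentials):

    import wandb
    from accelerate import Accelerator

    accelerator = Accelerator()
    if accelerator.is_local_main_process:
        # exactly one run per machine instead of one per worker
        wandb.init(project="pretraining-wav2vec2")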

@@ -381,9 +380,6 @@ def main():
         split="train",
     )
 
-    # raw_datasets["train"] = raw_datasets["train"].select(range(128))
-    # raw_datasets["validation"] = raw_datasets["validation"].select(range(16))
-
     # only normalized-inputs-training is supported
     feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
         args.model_name_or_path, do_normalize=True
@@ -489,9 +485,9 @@ def main():
             gumbel_temperature = max(args.max_gumbel_temperature * args.gumbel_temperature_decay ** completed_steps, args.min_gumbel_temperature)
 
             if hasattr(model, "module"):
-                model
-
-
+                model.module.set_gumbel_temperature(gumbel_temperature)
+            else:
+                model.set_gumbel_temperature(gumbel_temperature)
 
             if step % args.logging_steps == 0:
                 logs = {
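
This hunk completes the Gumbel-softmax temperature update: under DeepSpeed/DDP the model is wrapped, so the setter lives on model.module; otherwise it is called on the model directly. The schedule itself is an exponential decay with a floor; a sketch with hypothetical hyperparameters (the real ones come from argparse and are not shown in the diff):

    max_gumbel_temperature = 2.0         # assumed default
    min_gumbel_temperature = 0.5         # assumed default
    gumbel_temperature_decay = 0.999995  # assumed default

    for completed_steps in (0, 50_000, 200_000):
        t = max(max_gumbel_temperature * gumbel_temperature_decay ** completed_steps,
                min_gumbel_temperature)
        print(completed_steps, round(t, 4))  # 2.0 -> ~1.5576 -> ~0.7358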

@@ -508,8 +504,9 @@
             for k, v in logs.items():
                 log_str += f"| {k}: {round(v.item(), 5)}"
 
-
-
+            if accelerator.is_local_main_process:
+                wandb.log(logs)
+                progress_bar.write(log_str)
 
             if completed_steps >= args.max_train_steps:
                 break
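
The last hunk routes console output through progress_bar.write and gates both sinks on the local main process, so one worker emits logs instead of eight. The tqdm half of that choice in isolation (a self-contained sketch, not repo code): plain print() would corrupt the in-place bar, while write() prints above it.

    import time
    from tqdm.auto import tqdm

    progress_bar = tqdm(range(3))
    for step in progress_bar:
        time.sleep(0.1)
        progress_bar.write(f"| step: {step} | loss: {1.0 / (step + 1):.5f}")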