Commit fbb4053 committed by versae (1 parent: f12972c)

Adding proper data_loader, shuffling, evaluation, lr_schedulers. Refactor.
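The core of the change is a single `data_loader` helper that wraps a Hugging Face `datasets.Dataset` in a PyTorch `DataLoader` and hands back an iterator which the training loop restarts whenever it is exhausted. As a minimal, self-contained sketch of that consumption pattern (the `collate_batch` and `data_loader` definitions mirror the diff below, while the toy dataset, batch size, and step count are illustrative stand-ins, not values from the commit):

# Minimal sketch of the new data_loader pattern described in this commit; the toy
# dataset, batch size, and number of steps are illustrative stand-ins.
import torch
from datasets import Dataset


def collate_batch(samples):
    # Turn a list of per-example feature dicts into a dict of per-feature lists.
    return {key: [feature[key] for feature in samples] for key in samples[0]}


def data_loader(dataset, batch_size, drop_last=True, num_workers=0):
    # Wrap the Hugging Face dataset in a torch DataLoader and return a fresh iterator.
    return iter(torch.utils.data.DataLoader(
        batch_size=batch_size,
        dataset=dataset.with_format("torch"),
        num_workers=num_workers,
        collate_fn=collate_batch,
        drop_last=drop_last,
    ))


train_dataset = Dataset.from_dict(
    {"input_features": [[0.0], [1.0], [2.0], [3.0]], "labels": [[1], [2], [3], [4]]})
train_loader = data_loader(train_dataset, batch_size=2)
epoch = 0
for step in range(5):  # a fixed number of optimizer steps, as in the script
    try:
        samples = next(train_loader)
    except StopIteration:
        # Dataset exhausted: bump the epoch counter and restart the iterator.
        epoch += 1
        train_loader = data_loader(train_dataset, batch_size=2)
        samples = next(train_loader)
    print(step, epoch, samples["labels"])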

run_flax_speech_recognition_seq2seq_streaming_v3.py CHANGED
@@ -18,6 +18,7 @@ Fine-tuning the Flax library models for sequence to sequence speech recognition.
 """
 # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
 
+import itertools
 import logging
 import math
 import os
@@ -26,7 +27,7 @@ import time
 from dataclasses import field
 from functools import partial
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, Generator, List, Optional, Union
 
 import datasets
 import flax
@@ -264,7 +265,6 @@ class DataTrainingArguments:
     )
     num_train_steps: int = field(default=50000, metadata={
                                  "help": "The number of training steps."})
-    # num_eval_samples: int = field(default=50000, metadata={"help": "The number of samples to be used for evaluation"})
     shuffle_buffer_size: Optional[int] = field(
         default=500,
         metadata={
@@ -402,29 +402,28 @@ def load_maybe_streaming_dataset(dataset_name, dataset_config_name, split="train"
     return dataset
 
 
-def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True):
+def collate_batch(samples):
+    return {key: [feature[key] for feature in samples] for key in samples[0]}
+
+
+def data_loader(
+    dataset: Dataset,
+    batch_size: int,
+    drop_last: bool=True,
+    num_workers: int=0,
+) -> Generator:
     """
     Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete,
     and range in size from 1 to `batch_size`. Shuffle batches if `shuffle` is `True`.
     """
-    if shuffle:
-        batch_idx = jax.random.permutation(rng, len(dataset))
-        batch_idx = np.asarray(batch_idx)
-    else:
-        batch_idx = np.arange(len(dataset))
-
-    if drop_last:
-        steps_per_epoch = len(dataset) // batch_size
-        # Skip incomplete batch.
-        batch_idx = batch_idx[: steps_per_epoch * batch_size]
-        batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
-    else:
-        steps_per_epoch = math.ceil(len(dataset) / batch_size)
-        batch_idx = np.array_split(batch_idx, steps_per_epoch)
-
-    for idx in batch_idx:
-        batch = dataset[idx]
-        yield batch
+    data_loader_iterator = iter(torch.utils.data.DataLoader(
+        batch_size=batch_size,
+        dataset=dataset.with_format("torch"),
+        num_workers=num_workers,
+        collate_fn=collate_batch,
+        drop_last=drop_last,
+    ))
+    return data_loader_iterator
 
 
 class TrainState(train_state.TrainState):
@@ -448,13 +447,13 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
 
 
 def create_learning_rate_fn(
-    num_train_steps: int, num_warmup_steps: int, learning_rate: float
+    num_train_steps: int, num_warmup_steps: int, learning_rate: float, warmup_init_value: float=0.0, decay_end_value: float=0.0,
 ) -> Callable[[int], jnp.array]:
     """Returns a linear warmup, linear_decay learning rate function."""
     warmup_fn = optax.linear_schedule(
-        init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
+        init_value=warmup_init_value, end_value=learning_rate, transition_steps=num_warmup_steps)
     decay_fn = optax.linear_schedule(
-        init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
+        init_value=learning_rate, end_value=decay_end_value, transition_steps=num_train_steps - num_warmup_steps
     )
     schedule_fn = optax.join_schedules(
         schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
@@ -662,7 +661,7 @@ def main():
     vectorized_datasets = raw_datasets.map(
        prepare_dataset,
        remove_columns=raw_datasets_features,
-    ).with_format("torch")
+    )
 
     # filter training data with inputs longer than max_input_length
     def is_audio_in_length_range(length):
@@ -681,14 +680,14 @@
     )
 
     # 8. Load Metric
-    metric = evaluate.load("wer")
+    metric_wer = evaluate.load("wer")
+    metric_cer = evaluate.load("cer")
     do_normalize_eval = data_args.do_normalize_eval
 
     def compute_metrics(pred_ids, label_ids):
         # replace padded labels by the padding token
         for idx in range(len(label_ids)):
             label_ids[idx][label_ids[idx] == -100] = tokenizer.pad_token_id
-        #label_ids[label_ids == -100] = tokenizer.pad_token_id
 
         pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
         # we do not want to group tokens when computing the metrics
@@ -703,9 +702,10 @@
         label_str = [label_str[i]
                      for i in range(len(label_str)) if len(label_str[i]) > 0]
 
-        wer = 100 * metric.compute(predictions=pred_str, references=label_str)
+        wer = 100 * metric_wer.compute(predictions=pred_str, references=label_str)
+        cer = 100 * metric_cer.compute(predictions=pred_str, references=label_str)
 
-        return {"wer": wer}
+        return {"wer": wer, "cer": cer}
 
     # 9. Save feature extractor, tokenizer and config
     feature_extractor.save_pretrained(training_args.output_dir)
@@ -746,6 +746,7 @@
     # Initialize our training
     rng = jax.random.PRNGKey(training_args.seed)
    rng, dropout_rng = jax.random.split(rng)
+    # rng, input_rng = jax.random.split(rng)
 
     # Store some constant
     #num_epochs = int(training_args.num_train_epochs)
@@ -755,10 +756,26 @@
         training_args.per_device_eval_batch_size) * jax.device_count()
 
     # Create learning rate schedule
+    lr_scheduler_types = {"linear", "constant", "constant_with_warmup"}
+    if training_args.lr_scheduler_type not in lr_scheduler_types:
+        raise ValueError(
+            f"lr_scheduler_type of type {training_args.lr_scheduler_type} not supported, choose from {lr_scheduler_types}."
+        )
+    elif training_args.lr_scheduler_type == "constant":
+        warmup_init_value = training_args.learning_rate
+        decay_end_value = training_args.learning_rate
+    elif training_args.lr_scheduler_type == "constant_with_warmup":
+        warmup_init_value = 0.0
+        decay_end_value = training_args.learning_rate
+    else:
+        warmup_init_value = 0.0
+        decay_end_value = 0.0
     linear_decay_lr_schedule_fn = create_learning_rate_fn(
-        data_args.num_train_steps*train_batch_size,
+        data_args.num_train_steps * train_batch_size,
         training_args.warmup_steps,
         training_args.learning_rate,
+        warmup_init_value=warmup_init_value,
+        decay_end_value=decay_end_value,
     )
 
     # We use Optax's "masking" functionality to not apply weight decay
@@ -877,7 +894,7 @@
 
     # Create parallel version of the train and eval step
     p_train_step = jax.pmap(
-        partial(train_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch", donate_argnums=(0,)
+        partial(train_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch", donate_argnums=(0, )
    )
     p_eval_step = jax.pmap(partial(
         eval_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch")
@@ -888,7 +905,7 @@
 
     logger.info("***** Running training *****")
     logger.info(
-        f" Num examples = {data_args.num_train_steps*train_batch_size}")
+        f" Num examples = {data_args.num_train_steps * train_batch_size}")
     logger.info(
         f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
     logger.info(
@@ -900,44 +917,25 @@
     # ======================== Training ================================
     train_start = time.time()
 
-    # Create sampling rng
-    #rng, input_rng = jax.random.split(rng)
     train_metrics = []
     epoch = 0
-
-    def collate_batch(samples):
-        return {key: [feature[key] for feature in samples] for key in samples[0].keys()}
-
-    # Create a batched data iterator
-    num_workers = 0
-    # This is not working
-    # vectorized_datasets["train"] = vectorized_datasets["train"].shuffle()
-    train_data_loader = torch.utils.data.DataLoader(
-        batch_size=train_batch_size, dataset=vectorized_datasets["train"], num_workers=num_workers, collate_fn=collate_batch, drop_last=True)
-    train_data_iterator = torch.utils.data.dataloader._SingleProcessDataLoaderIter(
-        train_data_loader)
-
+    train_dataset = vectorized_datasets["train"].shuffle(seed=training_args.seed)
+    eval_dataset = vectorized_datasets["eval"]
+    train_loader = data_loader(train_dataset, train_batch_size)
     # train
     for step in tqdm(range(data_args.num_train_steps), desc="Training...", position=1, leave=False):
-
         try:
-            samples = next(train_data_iterator)
-
+            samples = next(train_loader)
         except StopIteration:
            epoch += 1
-            train_data_loader = torch.utils.data.DataLoader(
-                batch_size=train_batch_size, dataset=vectorized_datasets["train"], num_workers=num_workers, collate_fn=collate_batch, drop_last=True)
-            train_data_iterator = torch.utils.data.dataloader._SingleProcessDataLoaderIter(
-                train_data_loader)
-            samples = next(train_data_iterator)
-
+            train_dataset.set_epoch(epoch)
+            train_loader = data_loader(train_dataset, train_batch_size)
+            samples = next(train_loader)
            logger.info(
                f"Completed epoch ({epoch} | Loss: {train_metric['loss']}, Learning Rate:"
                f" {train_metric['learning_rate']})"
            )
 
-        # reshaped_samples = {key: [feature[key] for feature in samples] for key in samples[0].keys()}
-        # breakpoint()
         batch = data_collator(samples)
         batch = shard(batch.data)
         state, train_metric = p_train_step(state, batch)
@@ -950,16 +948,17 @@
             eval_metrics = []
             eval_preds = []
             eval_labels = []
-
-            #eval_loader = data_loader(input_rng, vectorized_datasets["eval"], eval_batch_size, drop_last=False)
-            eval_data_loader = torch.utils.data.DataLoader(
-                batch_size=eval_batch_size, dataset=vectorized_datasets["eval"], num_workers=num_workers, collate_fn=collate_batch, drop_last=False)
-            eval_data_iterator = torch.utils.data.dataloader._SingleProcessDataLoaderIter(
-                eval_data_loader)
-
-            for _ in tqdm(range(training_args.eval_steps), desc="Evaluating...", position=2, leave=False):
+            eval_loader = data_loader(eval_dataset, eval_batch_size, drop_last=False)
+            if data_args.max_eval_samples:
+                max_eval_steps_iter = range(1 + data_args.max_eval_samples // eval_batch_size)
+            else:
+                max_eval_steps_iter = itertools.repeat(None)
+            for _ in tqdm(max_eval_steps_iter, desc="Evaluating...", position=2, leave=False):
                # Model forward
-                samples = next(eval_data_iterator)
+                try:
+                    samples = next(eval_loader)
+                except StopIteration:
+                    break
                batch = data_collator(samples)
                labels = batch["labels"]
 
@@ -968,29 +967,28 @@
                )
                eval_metrics.append(metrics)
 
-                # generation
-                if training_args.predict_with_generate:
-                    generated_ids = pad_shard_unpad(
-                        p_generate_step)(state.params, batch.data)
-                    eval_preds.extend(jax.device_get(
-                        generated_ids.reshape(-1, gen_kwargs["max_length"])))
-                    eval_labels.extend(labels)
-
-
+                # generation
+                if training_args.predict_with_generate:
+                    generated_ids = pad_shard_unpad(
+                        p_generate_step)(state.params, batch.data)
+                    eval_preds.extend(jax.device_get(
+                        generated_ids.reshape(-1, gen_kwargs["max_length"])))
+                    eval_labels.extend(labels)
+
            # normalize eval metrics
            eval_metrics = get_metrics(eval_metrics)
            eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
 
-            # compute WER metric
-            wer_desc = ""
+            # compute metrics
+            metric_desc = ""
            if training_args.predict_with_generate:
-                wer_metric = compute_metrics(eval_preds, eval_labels)
-                eval_metrics.update(wer_metric)
-                wer_desc = " ".join(
-                    [f"Eval {key}: {value} |" for key, value in wer_metric.items()])
+                metric_values = compute_metrics(eval_preds, eval_labels)
+                eval_metrics.update(metric_values)
+                metric_desc = " ".join(
+                    [f"Eval {key}: {value} |" for key, value in metric_values.items()])
 
            # Print metrics
-            desc = f"Epoch... ({epoch} | Eval Loss: {eval_metrics['loss']} | {wer_desc})"
+            desc = f"Epoch... ({epoch} | Eval Loss: {eval_metrics['loss']} | {metric_desc})"
            logger.info(desc)
 
            # Save metrics
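The evaluation loop above now bounds itself with `max_eval_samples` instead of `training_args.eval_steps`: when a cap is given it iterates over a finite `range`, otherwise it keeps going until the eval iterator is exhausted. A small stand-alone sketch of that control flow, with illustrative values for the cap, batch size, and dummy batches:

# Stand-alone sketch of the capped evaluation loop above; the cap, batch size,
# and dummy batches are illustrative, not values from the commit.
import itertools


def eval_steps_iterator(max_eval_samples, eval_batch_size):
    # With a cap, evaluate for about max_eval_samples // eval_batch_size batches;
    # without one, iterate until the eval data loader raises StopIteration.
    if max_eval_samples:
        return range(1 + max_eval_samples // eval_batch_size)
    return itertools.repeat(None)


eval_batches = iter([{"labels": [1, 2]}, {"labels": [3, 4]}, {"labels": [5, 6]}])
for _ in eval_steps_iterator(max_eval_samples=4, eval_batch_size=2):
    try:
        batch = next(eval_batches)
    except StopIteration:
        break
    print(batch["labels"])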
run_streaming_v2.sh → run_streaming_v3.sh RENAMED
@@ -1,4 +1,4 @@
-python run_flax_speech_recognition_seq2seq_streaming_v2.py \
+python run_flax_speech_recognition_seq2seq_streaming_v3.py \
    --model_name_or_path openai/whisper-tiny.en \
    --dataset_name mozilla-foundation/common_voice_11_0 \
    --dataset_config nn-NO \
@@ -7,11 +7,11 @@ python run_flax_speech_recognition_seq2seq_streaming_v2.py \
    --eval_split_name test\
    --output_dir whisper-tiny-ft-dummy\
    --overwrite_output_dir\
-    --num_train_epochs=1\
+    --max_eval_samples=16\
    --warmup_steps=8 \
    --do_train \
    --do_eval \
-    --num_train_steps 1000 \
+    --num_train_steps 30 \
    --eval_steps 10 \
    --learning_rate=2e-4 \
    --per_device_train_batch_size=2 \
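For reference, the scheduler handling added in the Python script boils down to picking the warmup start and decay end values before building an optax schedule. The sketch below reuses the shape of `create_learning_rate_fn` from the diff; the learning rate and step counts are the dummy values from this shell script, used purely for illustration:

# Sketch of the lr_scheduler_type handling introduced above; not part of the commit itself.
import optax


def create_learning_rate_fn(num_train_steps, num_warmup_steps, learning_rate,
                            warmup_init_value=0.0, decay_end_value=0.0):
    # Linear warmup from warmup_init_value to learning_rate, then linear decay to decay_end_value.
    warmup_fn = optax.linear_schedule(
        init_value=warmup_init_value, end_value=learning_rate, transition_steps=num_warmup_steps)
    decay_fn = optax.linear_schedule(
        init_value=learning_rate, end_value=decay_end_value,
        transition_steps=num_train_steps - num_warmup_steps)
    return optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])


learning_rate, warmup_steps, total_steps = 2e-4, 8, 30

# "linear" (the default branch): warm up from 0.0 and decay back to 0.0.
linear_fn = create_learning_rate_fn(total_steps, warmup_steps, learning_rate)

# "constant": start and end at the peak rate, so the schedule stays flat.
constant_fn = create_learning_rate_fn(
    total_steps, warmup_steps, learning_rate,
    warmup_init_value=learning_rate, decay_end_value=learning_rate)

# "constant_with_warmup": warm up from 0.0, then hold the peak rate.
constant_with_warmup_fn = create_learning_rate_fn(
    total_steps, warmup_steps, learning_rate,
    warmup_init_value=0.0, decay_end_value=learning_rate)

print([float(linear_fn(step)) for step in (0, warmup_steps, total_steps)])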