pere commited on
Commit
ce5e747
·
1 Parent(s): 1426b10
run_flax_speech_recognition_seq2seq_streaming_v3.py CHANGED
@@ -35,7 +35,7 @@ import jax.numpy as jnp
35
  import numpy as np
36
  import optax
37
  import torch
38
- from datasets import Dataset,DatasetDict, IterableDatasetDict, interleave_datasets, load_dataset
39
  from torch.utils.data import IterableDataset
40
  from flax import jax_utils, traverse_util
41
  from flax.jax_utils import pad_shard_unpad, unreplicate
@@ -66,7 +66,8 @@ from transformers.utils.versions import require_version
66
  # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
67
  check_min_version("4.27.0.dev0")
68
 
69
- require_version("datasets>=1.18.2", "To fix: pip install -r examples/flax/speech-recogintion/requirements.txt")
 
70
 
71
  logger = logging.getLogger(__name__)
72
 
@@ -78,7 +79,8 @@ class ModelArguments:
78
  """
79
 
80
  model_name_or_path: str = field(
81
- metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
 
82
  )
83
  config_name: Optional[str] = field(
84
  default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
@@ -91,15 +93,18 @@ class ModelArguments:
91
  )
92
  cache_dir: Optional[str] = field(
93
  default=None,
94
- metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
 
95
  )
96
  use_fast_tokenizer: bool = field(
97
  default=True,
98
- metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
 
99
  )
100
  model_revision: str = field(
101
  default="main",
102
- metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
 
103
  )
104
  use_auth_token: bool = field(
105
  default=False,
@@ -142,7 +147,8 @@ class DataTrainingArguments:
142
  )
143
  text_column: Optional[str] = field(
144
  default=None,
145
- metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
 
146
  )
147
  dataset_cache_dir: Optional[str] = field(
148
  default=None, metadata={"help": "Path to cache directory for saving and loading datasets"}
@@ -170,23 +176,28 @@ class DataTrainingArguments:
170
  )
171
  audio_column_name: str = field(
172
  default="audio",
173
- metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
 
174
  )
175
  text_column_name: str = field(
176
  default="text",
177
- metadata={"help": "The name of the dataset column containing the text data. Defaults to 'text'"},
 
178
  )
179
  max_duration_in_seconds: float = field(
180
  default=30.0,
181
- metadata={"help": "Filter audio files that are longer than `max_duration_in_seconds` seconds"},
 
182
  )
183
  min_duration_in_seconds: float = field(
184
  default=0.0,
185
- metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"},
 
186
  )
187
  max_label_length: float = field(
188
  default=128,
189
- metadata={"help": "Truncate transcriptions that are longer `max_eval_length` tokens."},
 
190
  )
191
  pad_input_to_multiple_of: Optional[int] = field(
192
  default=None,
@@ -229,11 +240,13 @@ class DataTrainingArguments:
229
  )
230
  do_remove_punctuation: bool = field(
231
  default=False,
232
- metadata={"help": "Whether the target text should be striped of punctuation."},
 
233
  )
234
  do_normalize_eval: bool = field(
235
  default=True,
236
- metadata={"help": "Whether to normalise the references and predictions in the eval WER calculation."},
 
237
  )
238
  language: str = field(
239
  default=None,
@@ -246,9 +259,11 @@ class DataTrainingArguments:
246
  )
247
  task: str = field(
248
  default="transcribe",
249
- metadata={"help": "Task, either `transcribe` for speech recognition or `translate` for speech translation."},
 
250
  )
251
- num_train_steps: int = field(default=50000, metadata={"help": "The number of training steps."})
 
252
  # num_eval_samples: int = field(default=50000, metadata={"help": "The number of samples to be used for evaluation"})
253
  shuffle_buffer_size: Optional[int] = field(
254
  default=500,
@@ -261,9 +276,11 @@ class DataTrainingArguments:
261
  )
262
  streaming: bool = field(
263
  default=True,
264
- metadata={"help": "Whether to use streaming mode to load and pre-process the data."},
 
265
  )
266
 
 
267
  def shift_tokens_right(label_ids: np.array, decoder_start_token_id: int) -> np.ndarray:
268
  """
269
  Shift label ids one token to the right.
@@ -348,17 +365,19 @@ class FlaxDataCollatorSpeechSeq2SeqWithPadding:
348
  labels = labels[:, 1:]
349
  labels_batch.attention_mask = labels_batch.attention_mask[:, 1:]
350
 
351
- decoder_input_ids = shift_tokens_right(labels, self.decoder_start_token_id)
 
352
 
353
  # replace padding with -100 to ignore correctly when computing the loss
354
- labels = np.ma.array(labels, mask=np.not_equal(labels_batch.attention_mask, 1))
 
355
  labels = labels.filled(fill_value=-100)
356
 
357
  batch["labels"] = labels
358
  batch["decoder_input_ids"] = decoder_input_ids
359
 
360
  return batch
361
-
362
 
363
  def load_maybe_streaming_dataset(dataset_name, dataset_config_name, split="train", streaming=True, **kwargs):
364
  """
@@ -369,7 +388,8 @@ def load_maybe_streaming_dataset(dataset_name, dataset_config_name, split="train
369
  if "+" in split:
370
  # load multiple splits separated by the `+` symbol with streaming mode
371
  dataset_splits = [
372
- load_dataset(dataset_name, dataset_config_name, split=split_name, streaming=streaming, **kwargs)
 
373
  for split_name in split.split("+")
374
  ]
375
  # interleave multiple splits to form one dataset
@@ -377,7 +397,8 @@ def load_maybe_streaming_dataset(dataset_name, dataset_config_name, split="train
377
  return interleaved_dataset
378
  else:
379
  # load a single split *with* streaming mode
380
- dataset = load_dataset(dataset_name, dataset_config_name, split=split, streaming=streaming, **kwargs)
 
381
  return dataset
382
 
383
 
@@ -394,7 +415,8 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
394
 
395
  if drop_last:
396
  steps_per_epoch = len(dataset) // batch_size
397
- batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
 
398
  batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
399
  else:
400
  steps_per_epoch = math.ceil(len(dataset) / batch_size)
@@ -429,11 +451,13 @@ def create_learning_rate_fn(
429
  num_train_steps: int, num_warmup_steps: int, learning_rate: float
430
  ) -> Callable[[int], jnp.array]:
431
  """Returns a linear warmup, linear_decay learning rate function."""
432
- warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
 
433
  decay_fn = optax.linear_schedule(
434
  init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
435
  )
436
- schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
 
437
  return schedule_fn
438
 
439
 
@@ -442,18 +466,21 @@ def main():
442
  # See all possible arguments in src/transformers/training_args.py
443
  # or by passing the --help flag to this script.
444
  # We now keep distinct sets of args, for a cleaner separation of concerns.
445
- parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
 
446
 
447
  if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
448
  # If we pass only one argument to the script and it's the path to a json file,
449
  # let's parse it to get our arguments.
450
- model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
 
451
  else:
452
  model_args, data_args, training_args = parser.parse_args_into_dataclasses()
453
 
454
  # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
455
  # information sent is the one passed as arguments along with your JAX/Flax versions.
456
- send_example_telemetry("run_speech_recognition_seq2seq", model_args, data_args, framework="flax")
 
457
 
458
  # 2. Setup logging
459
  # Make one log on every process with the configuration for debugging.
@@ -464,7 +491,8 @@ def main():
464
  )
465
  # Set the verbosity to info of the Transformers logger.
466
  # We only want one process per machine to log things on the screen.
467
- logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
 
468
  if jax.process_index() == 0:
469
  datasets.utils.logging.set_verbosity_warning()
470
  transformers.utils.logging.set_verbosity_info()
@@ -490,16 +518,18 @@ def main():
490
  if training_args.push_to_hub:
491
  if training_args.hub_model_id is None:
492
  repo_name = get_full_repo_name(
493
- Path(training_args.output_dir).absolute().name, token=training_args.hub_token
 
494
  )
495
  else:
496
  repo_name = training_args.hub_model_id
497
  create_repo(repo_name, exist_ok=True, token=training_args.hub_token)
498
- repo = Repository(training_args.output_dir, clone_from=repo_name, token=training_args.hub_token)
 
499
 
500
  # 3. Load dataset
501
  raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
502
-
503
  if training_args.do_train:
504
  raw_datasets["train"] = load_maybe_streaming_dataset(
505
  data_args.dataset_name,
@@ -519,13 +549,14 @@ def main():
519
  streaming=data_args.streaming,
520
  use_auth_token=True if model_args.use_auth_token else None,
521
  )
522
-
523
  if not training_args.do_train and not training_args.do_eval:
524
  raise ValueError(
525
  "Cannot not train and not do evaluation. At least one of training or evaluation has to be performed."
526
  )
527
 
528
- raw_datasets_features = list(next(iter(raw_datasets.values())).features.keys())
 
529
 
530
  if data_args.audio_column_name not in raw_datasets_features:
531
  raise ValueError(
@@ -572,21 +603,26 @@ def main():
572
  )
573
 
574
  if model.config.decoder_start_token_id is None:
575
- raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
 
576
 
577
  # 6. Resample speech dataset: `datasets` takes care of automatically loading and resampling the audio,
578
  # so we just need to set the correct target sampling rate.
579
- dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
580
-
 
581
  if dataset_sampling_rate != feature_extractor.sampling_rate:
582
  raw_datasets = raw_datasets.cast_column(
583
- data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
 
584
  )
585
 
586
  # 7. Preprocessing the datasets.
587
  # We need to read the audio files as arrays and tokenize the targets.
588
- max_input_length = int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate)
589
- min_input_length = int(data_args.min_duration_in_seconds * feature_extractor.sampling_rate)
 
 
590
  max_label_length = (
591
  data_args.max_label_length if data_args.max_label_length is not None else model.config.max_length
592
  )
@@ -602,18 +638,21 @@ def main():
602
 
603
  if data_args.language is not None:
604
  # We only need to set the task id when the language is specified (i.e. in a multilingual setting)
605
- tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task)
 
606
 
607
  def prepare_dataset(batch):
608
  # process audio
609
  sample = batch[audio_column_name]
610
- inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
 
611
  # process audio length
612
  batch[model_input_name] = inputs.get(model_input_name)[0]
613
  batch["input_length"] = len(sample["array"])
614
 
615
  # process targets
616
- input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
 
617
  if do_remove_punctuation:
618
  input_str = normalizer(input_str).strip()
619
  batch["labels"] = tokenizer(input_str).input_ids
@@ -624,7 +663,7 @@ def main():
624
  prepare_dataset,
625
  remove_columns=raw_datasets_features,
626
  ).with_format("torch")
627
-
628
  # filter training data with inputs longer than max_input_length
629
  def is_audio_in_length_range(length):
630
  return min_input_length < length < max_input_length
@@ -634,14 +673,13 @@ def main():
634
  is_audio_in_length_range,
635
  input_columns=["input_length"],
636
  )
637
-
638
  if training_args.do_eval:
639
  vectorized_datasets["eval"] = vectorized_datasets["eval"].filter(
640
  is_audio_in_length_range,
641
  input_columns=["input_length"],
642
  )
643
 
644
-
645
  # 8. Load Metric
646
  metric = evaluate.load("wer")
647
  do_normalize_eval = data_args.do_normalize_eval
@@ -660,8 +698,10 @@ def main():
660
  pred_str = [normalizer(pred) for pred in pred_str]
661
  label_str = [normalizer(label) for label in label_str]
662
  # filtering step to only evaluate the samples that correspond to non-zero references:
663
- pred_str = [pred_str[i] for i in range(len(pred_str)) if len(label_str[i]) > 0]
664
- label_str = [label_str[i] for i in range(len(label_str)) if len(label_str[i]) > 0]
 
 
665
 
666
  wer = 100 * metric.compute(predictions=pred_str, references=label_str)
667
 
@@ -690,7 +730,8 @@ def main():
690
  try:
691
  from flax.metrics.tensorboard import SummaryWriter
692
 
693
- summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
 
694
  except ImportError as ie:
695
  has_tensorboard = False
696
  logger.warning(
@@ -708,10 +749,10 @@ def main():
708
 
709
  # Store some constant
710
  #num_epochs = int(training_args.num_train_epochs)
711
- train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
712
- eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
713
-
714
-
715
 
716
  # Create learning rate schedule
717
  linear_decay_lr_schedule_fn = create_learning_rate_fn(
@@ -736,7 +777,8 @@ def main():
736
  if layer_norm_name in "".join(layer).lower()
737
  ]
738
  )
739
- flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
 
740
  return traverse_util.unflatten_dict(flat_mask)
741
 
742
  # create adam optimizer
@@ -750,7 +792,8 @@ def main():
750
  )
751
 
752
  # Setup train state
753
- state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
 
754
 
755
  # label smoothed cross entropy
756
  def loss_fn(logits, labels, label_smoothing_factor=0.0):
@@ -762,9 +805,11 @@ def main():
762
  confidence = 1.0 - label_smoothing_factor
763
  low_confidence = (1.0 - confidence) / (vocab_size - 1)
764
  normalizing_constant = -(
765
- confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
 
766
  )
767
- soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)
 
768
 
769
  loss = optax.softmax_cross_entropy(logits, soft_labels)
770
  loss = loss - normalizing_constant
@@ -782,7 +827,8 @@ def main():
782
 
783
  def compute_loss(params):
784
  labels = batch.pop("labels")
785
- logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
 
786
  loss, num_labels = loss_fn(logits, labels, label_smoothing_factor)
787
  return loss, num_labels
788
 
@@ -797,9 +843,11 @@ def main():
797
  # true grad = total grad / total samples
798
  grad = jax.lax.psum(grad, "batch")
799
  grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
800
- new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
 
801
 
802
- metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
 
803
  return new_state, metrics
804
 
805
  # Define eval fn
@@ -823,27 +871,32 @@ def main():
823
 
824
  def generate_step(params, batch):
825
  model.params = params
826
- output_ids = model.generate(batch[model_input_name], attention_mask=batch.get("attention_mask"), **gen_kwargs)
 
827
  return output_ids.sequences
828
 
829
  # Create parallel version of the train and eval step
830
  p_train_step = jax.pmap(
831
  partial(train_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch", donate_argnums=(0,)
832
  )
833
- p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch")
 
834
  p_generate_step = jax.pmap(generate_step, "batch")
835
 
836
  # Replicate the train state on each device
837
  state = state.replicate()
838
 
839
  logger.info("***** Running training *****")
840
- logger.info(f" Num examples = {data_args.num_train_steps*train_batch_size}")
841
- logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
842
- logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
 
 
 
843
  logger.info(f" Total optimization steps = {data_args.num_train_steps}")
844
 
845
  train_time = 0
846
-
847
  # ======================== Training ================================
848
  train_start = time.time()
849
 
@@ -859,29 +912,32 @@ def main():
859
  num_workers = 0
860
  # This is not working
861
  # vectorized_datasets["train"] = vectorized_datasets["train"].shuffle()
862
- train_data_loader = torch.utils.data.DataLoader( batch_size=train_batch_size, dataset=vectorized_datasets["train"], num_workers=num_workers, collate_fn=collate_batch, drop_last=True)
863
- train_data_iterator = torch.utils.data.dataloader._SingleProcessDataLoaderIter(train_data_loader)
864
-
865
-
 
866
  # train
867
- for step in tqdm(range(data_args.num_train_steps), desc="Training...", position=1, leave=False):
868
-
869
  try:
870
  samples = next(train_data_iterator)
871
-
872
  except StopIteration:
873
  epoch += 1
874
- train_data_loader = torch.utils.data.DataLoader( batch_size=train_batch_size, dataset=vectorized_datasets["train"], num_workers=num_workers, collate_fn=collate_batch, drop_last=True)
875
- train_data_iterator = torch.utils.data.dataloader._SingleProcessDataLoaderIter(train_data_loader)
 
 
876
  samples = next(train_data_iterator)
877
-
878
  logger.info(
879
  f"Completed epoch ({epoch} | Loss: {train_metric['loss']}, Learning Rate:"
880
  f" {train_metric['learning_rate']})"
881
  )
882
-
883
  # reshaped_samples = {key: [feature[key] for feature in samples] for key in samples[0].keys()}
884
- #breakpoint()
885
  batch = data_collator(samples)
886
  batch = shard(batch.data)
887
  state, train_metric = p_train_step(state, batch)
@@ -896,8 +952,10 @@ def main():
896
  eval_labels = []
897
 
898
  #eval_loader = data_loader(input_rng, vectorized_datasets["eval"], eval_batch_size, drop_last=False)
899
- eval_data_loader = torch.utils.data.DataLoader( batch_size=eval_batch_size, dataset=vectorized_datasets["eval"], num_workers=num_workers, collate_fn=collate_batch, drop_last=False)
900
- eval_data_iterator = torch.utils.data.dataloader._SingleProcessDataLoaderIter(eval_data_loader)
 
 
901
 
902
  for _ in tqdm(range(training_args.eval_steps), desc="Evaluating...", position=2, leave=False):
903
  # Model forward
@@ -912,10 +970,12 @@ def main():
912
 
913
  # generation
914
  if training_args.predict_with_generate:
915
- generated_ids = pad_shard_unpad(p_generate_step)(state.params, batch.data)
916
- eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
 
 
917
  eval_labels.extend(labels)
918
-
919
  # normalize eval metrics
920
  eval_metrics = get_metrics(eval_metrics)
921
  eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
@@ -925,7 +985,8 @@ def main():
925
  if training_args.predict_with_generate:
926
  wer_metric = compute_metrics(eval_preds, eval_labels)
927
  eval_metrics.update(wer_metric)
928
- wer_desc = " ".join([f"Eval {key}: {value} |" for key, value in wer_metric.items()])
 
929
 
930
  # Print metrics
931
  desc = f"Epoch... ({epoch} | Eval Loss: {eval_metrics['loss']} | {wer_desc})"
@@ -933,15 +994,18 @@ def main():
933
 
934
  # Save metrics
935
  if has_tensorboard and jax.process_index() == 0:
936
- write_metric(summary_writer, train_metrics, eval_metrics, train_time, step)
 
937
 
938
  # save checkpoint after each epoch and push checkpoint to the hub
939
  if jax.process_index() == 0:
940
- params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
 
941
  model.save_pretrained(training_args.output_dir, params=params)
942
  tokenizer.save_pretrained(training_args.output_dir)
943
  if training_args.push_to_hub:
944
- repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)
 
945
 
946
 
947
  if __name__ == "__main__":
 
35
  import numpy as np
36
  import optax
37
  import torch
38
+ from datasets import Dataset, DatasetDict, IterableDatasetDict, interleave_datasets, load_dataset
39
  from torch.utils.data import IterableDataset
40
  from flax import jax_utils, traverse_util
41
  from flax.jax_utils import pad_shard_unpad, unreplicate
 
66
  # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
67
  check_min_version("4.27.0.dev0")
68
 
69
+ require_version("datasets>=1.18.2",
70
+ "To fix: pip install -r examples/flax/speech-recogintion/requirements.txt")
71
 
72
  logger = logging.getLogger(__name__)
73
 
 
79
  """
80
 
81
  model_name_or_path: str = field(
82
+ metadata={
83
+ "help": "Path to pretrained model or model identifier from huggingface.co/models"}
84
  )
85
  config_name: Optional[str] = field(
86
  default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
 
93
  )
94
  cache_dir: Optional[str] = field(
95
  default=None,
96
+ metadata={
97
+ "help": "Where to store the pretrained models downloaded from huggingface.co"},
98
  )
99
  use_fast_tokenizer: bool = field(
100
  default=True,
101
+ metadata={
102
+ "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
103
  )
104
  model_revision: str = field(
105
  default="main",
106
+ metadata={
107
+ "help": "The specific model version to use (can be a branch name, tag name or commit id)."},
108
  )
109
  use_auth_token: bool = field(
110
  default=False,
 
147
  )
148
  text_column: Optional[str] = field(
149
  default=None,
150
+ metadata={
151
+ "help": "The name of the column in the datasets containing the full texts (for summarization)."},
152
  )
153
  dataset_cache_dir: Optional[str] = field(
154
  default=None, metadata={"help": "Path to cache directory for saving and loading datasets"}
 
176
  )
177
  audio_column_name: str = field(
178
  default="audio",
179
+ metadata={
180
+ "help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
181
  )
182
  text_column_name: str = field(
183
  default="text",
184
+ metadata={
185
+ "help": "The name of the dataset column containing the text data. Defaults to 'text'"},
186
  )
187
  max_duration_in_seconds: float = field(
188
  default=30.0,
189
+ metadata={
190
+ "help": "Filter audio files that are longer than `max_duration_in_seconds` seconds"},
191
  )
192
  min_duration_in_seconds: float = field(
193
  default=0.0,
194
+ metadata={
195
+ "help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"},
196
  )
197
  max_label_length: float = field(
198
  default=128,
199
+ metadata={
200
+ "help": "Truncate transcriptions that are longer `max_eval_length` tokens."},
201
  )
202
  pad_input_to_multiple_of: Optional[int] = field(
203
  default=None,
 
240
  )
241
  do_remove_punctuation: bool = field(
242
  default=False,
243
+ metadata={
244
+ "help": "Whether the target text should be striped of punctuation."},
245
  )
246
  do_normalize_eval: bool = field(
247
  default=True,
248
+ metadata={
249
+ "help": "Whether to normalise the references and predictions in the eval WER calculation."},
250
  )
251
  language: str = field(
252
  default=None,
 
259
  )
260
  task: str = field(
261
  default="transcribe",
262
+ metadata={
263
+ "help": "Task, either `transcribe` for speech recognition or `translate` for speech translation."},
264
  )
265
+ num_train_steps: int = field(default=50000, metadata={
266
+ "help": "The number of training steps."})
267
  # num_eval_samples: int = field(default=50000, metadata={"help": "The number of samples to be used for evaluation"})
268
  shuffle_buffer_size: Optional[int] = field(
269
  default=500,
 
276
  )
277
  streaming: bool = field(
278
  default=True,
279
+ metadata={
280
+ "help": "Whether to use streaming mode to load and pre-process the data."},
281
  )
282
 
283
+
284
  def shift_tokens_right(label_ids: np.array, decoder_start_token_id: int) -> np.ndarray:
285
  """
286
  Shift label ids one token to the right.
 
365
  labels = labels[:, 1:]
366
  labels_batch.attention_mask = labels_batch.attention_mask[:, 1:]
367
 
368
+ decoder_input_ids = shift_tokens_right(
369
+ labels, self.decoder_start_token_id)
370
 
371
  # replace padding with -100 to ignore correctly when computing the loss
372
+ labels = np.ma.array(labels, mask=np.not_equal(
373
+ labels_batch.attention_mask, 1))
374
  labels = labels.filled(fill_value=-100)
375
 
376
  batch["labels"] = labels
377
  batch["decoder_input_ids"] = decoder_input_ids
378
 
379
  return batch
380
+
381
 
382
  def load_maybe_streaming_dataset(dataset_name, dataset_config_name, split="train", streaming=True, **kwargs):
383
  """
 
388
  if "+" in split:
389
  # load multiple splits separated by the `+` symbol with streaming mode
390
  dataset_splits = [
391
+ load_dataset(dataset_name, dataset_config_name,
392
+ split=split_name, streaming=streaming, **kwargs)
393
  for split_name in split.split("+")
394
  ]
395
  # interleave multiple splits to form one dataset
 
397
  return interleaved_dataset
398
  else:
399
  # load a single split *with* streaming mode
400
+ dataset = load_dataset(
401
+ dataset_name, dataset_config_name, split=split, streaming=streaming, **kwargs)
402
  return dataset
403
 
404
 
 
415
 
416
  if drop_last:
417
  steps_per_epoch = len(dataset) // batch_size
418
+ # Skip incomplete batch.
419
+ batch_idx = batch_idx[: steps_per_epoch * batch_size]
420
  batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
421
  else:
422
  steps_per_epoch = math.ceil(len(dataset) / batch_size)
 
451
  num_train_steps: int, num_warmup_steps: int, learning_rate: float
452
  ) -> Callable[[int], jnp.array]:
453
  """Returns a linear warmup, linear_decay learning rate function."""
454
+ warmup_fn = optax.linear_schedule(
455
+ init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
456
  decay_fn = optax.linear_schedule(
457
  init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
458
  )
459
+ schedule_fn = optax.join_schedules(
460
+ schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
461
  return schedule_fn
462
 
463
 
 
466
  # See all possible arguments in src/transformers/training_args.py
467
  # or by passing the --help flag to this script.
468
  # We now keep distinct sets of args, for a cleaner separation of concerns.
469
+ parser = HfArgumentParser(
470
+ (ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
471
 
472
  if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
473
  # If we pass only one argument to the script and it's the path to a json file,
474
  # let's parse it to get our arguments.
475
+ model_args, data_args, training_args = parser.parse_json_file(
476
+ json_file=os.path.abspath(sys.argv[1]))
477
  else:
478
  model_args, data_args, training_args = parser.parse_args_into_dataclasses()
479
 
480
  # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
481
  # information sent is the one passed as arguments along with your JAX/Flax versions.
482
+ send_example_telemetry("run_speech_recognition_seq2seq",
483
+ model_args, data_args, framework="flax")
484
 
485
  # 2. Setup logging
486
  # Make one log on every process with the configuration for debugging.
 
491
  )
492
  # Set the verbosity to info of the Transformers logger.
493
  # We only want one process per machine to log things on the screen.
494
+ logger.setLevel(logging.INFO if jax.process_index()
495
+ == 0 else logging.ERROR)
496
  if jax.process_index() == 0:
497
  datasets.utils.logging.set_verbosity_warning()
498
  transformers.utils.logging.set_verbosity_info()
 
518
  if training_args.push_to_hub:
519
  if training_args.hub_model_id is None:
520
  repo_name = get_full_repo_name(
521
+ Path(training_args.output_dir).absolute(
522
+ ).name, token=training_args.hub_token
523
  )
524
  else:
525
  repo_name = training_args.hub_model_id
526
  create_repo(repo_name, exist_ok=True, token=training_args.hub_token)
527
+ repo = Repository(training_args.output_dir,
528
+ clone_from=repo_name, token=training_args.hub_token)
529
 
530
  # 3. Load dataset
531
  raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
532
+
533
  if training_args.do_train:
534
  raw_datasets["train"] = load_maybe_streaming_dataset(
535
  data_args.dataset_name,
 
549
  streaming=data_args.streaming,
550
  use_auth_token=True if model_args.use_auth_token else None,
551
  )
552
+
553
  if not training_args.do_train and not training_args.do_eval:
554
  raise ValueError(
555
  "Cannot not train and not do evaluation. At least one of training or evaluation has to be performed."
556
  )
557
 
558
+ raw_datasets_features = list(
559
+ next(iter(raw_datasets.values())).features.keys())
560
 
561
  if data_args.audio_column_name not in raw_datasets_features:
562
  raise ValueError(
 
603
  )
604
 
605
  if model.config.decoder_start_token_id is None:
606
+ raise ValueError(
607
+ "Make sure that `config.decoder_start_token_id` is correctly defined")
608
 
609
  # 6. Resample speech dataset: `datasets` takes care of automatically loading and resampling the audio,
610
  # so we just need to set the correct target sampling rate.
611
+ dataset_sampling_rate = next(
612
+ iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
613
+
614
  if dataset_sampling_rate != feature_extractor.sampling_rate:
615
  raw_datasets = raw_datasets.cast_column(
616
+ data_args.audio_column_name, datasets.features.Audio(
617
+ sampling_rate=feature_extractor.sampling_rate)
618
  )
619
 
620
  # 7. Preprocessing the datasets.
621
  # We need to read the audio files as arrays and tokenize the targets.
622
+ max_input_length = int(
623
+ data_args.max_duration_in_seconds * feature_extractor.sampling_rate)
624
+ min_input_length = int(
625
+ data_args.min_duration_in_seconds * feature_extractor.sampling_rate)
626
  max_label_length = (
627
  data_args.max_label_length if data_args.max_label_length is not None else model.config.max_length
628
  )
 
638
 
639
  if data_args.language is not None:
640
  # We only need to set the task id when the language is specified (i.e. in a multilingual setting)
641
+ tokenizer.set_prefix_tokens(
642
+ language=data_args.language, task=data_args.task)
643
 
644
  def prepare_dataset(batch):
645
  # process audio
646
  sample = batch[audio_column_name]
647
+ inputs = feature_extractor(
648
+ sample["array"], sampling_rate=sample["sampling_rate"])
649
  # process audio length
650
  batch[model_input_name] = inputs.get(model_input_name)[0]
651
  batch["input_length"] = len(sample["array"])
652
 
653
  # process targets
654
+ input_str = batch[text_column_name].lower(
655
+ ) if do_lower_case else batch[text_column_name]
656
  if do_remove_punctuation:
657
  input_str = normalizer(input_str).strip()
658
  batch["labels"] = tokenizer(input_str).input_ids
 
663
  prepare_dataset,
664
  remove_columns=raw_datasets_features,
665
  ).with_format("torch")
666
+
667
  # filter training data with inputs longer than max_input_length
668
  def is_audio_in_length_range(length):
669
  return min_input_length < length < max_input_length
 
673
  is_audio_in_length_range,
674
  input_columns=["input_length"],
675
  )
676
+
677
  if training_args.do_eval:
678
  vectorized_datasets["eval"] = vectorized_datasets["eval"].filter(
679
  is_audio_in_length_range,
680
  input_columns=["input_length"],
681
  )
682
 
 
683
  # 8. Load Metric
684
  metric = evaluate.load("wer")
685
  do_normalize_eval = data_args.do_normalize_eval
 
698
  pred_str = [normalizer(pred) for pred in pred_str]
699
  label_str = [normalizer(label) for label in label_str]
700
  # filtering step to only evaluate the samples that correspond to non-zero references:
701
+ pred_str = [pred_str[i]
702
+ for i in range(len(pred_str)) if len(label_str[i]) > 0]
703
+ label_str = [label_str[i]
704
+ for i in range(len(label_str)) if len(label_str[i]) > 0]
705
 
706
  wer = 100 * metric.compute(predictions=pred_str, references=label_str)
707
 
 
730
  try:
731
  from flax.metrics.tensorboard import SummaryWriter
732
 
733
+ summary_writer = SummaryWriter(
734
+ log_dir=Path(training_args.output_dir))
735
  except ImportError as ie:
736
  has_tensorboard = False
737
  logger.warning(
 
749
 
750
  # Store some constant
751
  #num_epochs = int(training_args.num_train_epochs)
752
+ train_batch_size = int(
753
+ training_args.per_device_train_batch_size) * jax.device_count()
754
+ eval_batch_size = int(
755
+ training_args.per_device_eval_batch_size) * jax.device_count()
756
 
757
  # Create learning rate schedule
758
  linear_decay_lr_schedule_fn = create_learning_rate_fn(
 
777
  if layer_norm_name in "".join(layer).lower()
778
  ]
779
  )
780
+ flat_mask = {path: (path[-1] != "bias" and path[-2:]
781
+ not in layer_norm_named_params) for path in flat_params}
782
  return traverse_util.unflatten_dict(flat_mask)
783
 
784
  # create adam optimizer
 
792
  )
793
 
794
  # Setup train state
795
+ state = TrainState.create(
796
+ apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
797
 
798
  # label smoothed cross entropy
799
  def loss_fn(logits, labels, label_smoothing_factor=0.0):
 
805
  confidence = 1.0 - label_smoothing_factor
806
  low_confidence = (1.0 - confidence) / (vocab_size - 1)
807
  normalizing_constant = -(
808
+ confidence * jnp.log(confidence) + (vocab_size - 1) *
809
+ low_confidence * jnp.log(low_confidence + 1e-20)
810
  )
811
+ soft_labels = onehot(labels, vocab_size,
812
+ on_value=confidence, off_value=low_confidence)
813
 
814
  loss = optax.softmax_cross_entropy(logits, soft_labels)
815
  loss = loss - normalizing_constant
 
827
 
828
  def compute_loss(params):
829
  labels = batch.pop("labels")
830
+ logits = state.apply_fn(
831
+ **batch, params=params, dropout_rng=dropout_rng, train=True)[0]
832
  loss, num_labels = loss_fn(logits, labels, label_smoothing_factor)
833
  return loss, num_labels
834
 
 
843
  # true grad = total grad / total samples
844
  grad = jax.lax.psum(grad, "batch")
845
  grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
846
+ new_state = state.apply_gradients(
847
+ grads=grad, dropout_rng=new_dropout_rng)
848
 
849
+ metrics = {"loss": loss,
850
+ "learning_rate": linear_decay_lr_schedule_fn(state.step)}
851
  return new_state, metrics
852
 
853
  # Define eval fn
 
871
 
872
  def generate_step(params, batch):
873
  model.params = params
874
+ output_ids = model.generate(batch[model_input_name], attention_mask=batch.get(
875
+ "attention_mask"), **gen_kwargs)
876
  return output_ids.sequences
877
 
878
  # Create parallel version of the train and eval step
879
  p_train_step = jax.pmap(
880
  partial(train_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch", donate_argnums=(0,)
881
  )
882
+ p_eval_step = jax.pmap(partial(
883
+ eval_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch")
884
  p_generate_step = jax.pmap(generate_step, "batch")
885
 
886
  # Replicate the train state on each device
887
  state = state.replicate()
888
 
889
  logger.info("***** Running training *****")
890
+ logger.info(
891
+ f" Num examples = {data_args.num_train_steps*train_batch_size}")
892
+ logger.info(
893
+ f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
894
+ logger.info(
895
+ f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
896
  logger.info(f" Total optimization steps = {data_args.num_train_steps}")
897
 
898
  train_time = 0
899
+
900
  # ======================== Training ================================
901
  train_start = time.time()
902
 
 
912
  num_workers = 0
913
  # FIXME: shuffling the vectorized train split does not work here, so the
  # call below stays disabled:
914
  # vectorized_datasets["train"] = vectorized_datasets["train"].shuffle()
915
+ train_data_loader = torch.utils.data.DataLoader(
916
+ batch_size=train_batch_size, dataset=vectorized_datasets["train"], num_workers=num_workers, collate_fn=collate_batch, drop_last=True)
917
+ train_data_iterator = torch.utils.data.dataloader._SingleProcessDataLoaderIter(
918
+ train_data_loader)
919
+
920
  # train
921
+ for step in tqdm(range(data_args.num_train_steps), desc="Training...", position=1, leave=False):
922
+
923
  try:
924
  samples = next(train_data_iterator)
925
+
926
  except StopIteration:
927
  epoch += 1
928
+ train_data_loader = torch.utils.data.DataLoader(
929
+ batch_size=train_batch_size, dataset=vectorized_datasets["train"], num_workers=num_workers, collate_fn=collate_batch, drop_last=True)
930
+ train_data_iterator = torch.utils.data.dataloader._SingleProcessDataLoaderIter(
931
+ train_data_loader)
932
  samples = next(train_data_iterator)
933
+
934
  logger.info(
935
  f"Completed epoch ({epoch} | Loss: {train_metric['loss']}, Learning Rate:"
936
  f" {train_metric['learning_rate']})"
937
  )
938
+
939
  # reshaped_samples = {key: [feature[key] for feature in samples] for key in samples[0].keys()}
940
+ # breakpoint()
941
  batch = data_collator(samples)
942
  batch = shard(batch.data)
943
  state, train_metric = p_train_step(state, batch)
 
952
  eval_labels = []
953
 
954
  #eval_loader = data_loader(input_rng, vectorized_datasets["eval"], eval_batch_size, drop_last=False)
955
+ eval_data_loader = torch.utils.data.DataLoader(
956
+ batch_size=eval_batch_size, dataset=vectorized_datasets["eval"], num_workers=num_workers, collate_fn=collate_batch, drop_last=False)
957
+ eval_data_iterator = torch.utils.data.dataloader._SingleProcessDataLoaderIter(
958
+ eval_data_loader)
959
 
960
  for _ in tqdm(range(training_args.eval_steps), desc="Evaluating...", position=2, leave=False):
961
  # Model forward
 
970
 
971
  # generation
972
  if training_args.predict_with_generate:
973
+ generated_ids = pad_shard_unpad(
974
+ p_generate_step)(state.params, batch.data)
975
+ eval_preds.extend(jax.device_get(
976
+ generated_ids.reshape(-1, gen_kwargs["max_length"])))
977
  eval_labels.extend(labels)
978
+ breakpoint()
979
  # normalize eval metrics
980
  eval_metrics = get_metrics(eval_metrics)
981
  eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
 
985
  if training_args.predict_with_generate:
986
  wer_metric = compute_metrics(eval_preds, eval_labels)
987
  eval_metrics.update(wer_metric)
988
+ wer_desc = " ".join(
989
+ [f"Eval {key}: {value} |" for key, value in wer_metric.items()])
990
 
991
  # Print metrics
992
  desc = f"Epoch... ({epoch} | Eval Loss: {eval_metrics['loss']} | {wer_desc})"
 
994
 
995
  # Save metrics
996
  if has_tensorboard and jax.process_index() == 0:
997
+ write_metric(summary_writer, train_metrics,
998
+ eval_metrics, train_time, step)
999
 
1000
  # save checkpoint after each epoch and push checkpoint to the hub
1001
  if jax.process_index() == 0:
1002
+ params = jax.device_get(
1003
+ jax.tree_util.tree_map(lambda x: x[0], state.params))
1004
  model.save_pretrained(training_args.output_dir, params=params)
1005
  tokenizer.save_pretrained(training_args.output_dir)
1006
  if training_args.push_to_hub:
1007
+ repo.push_to_hub(
1008
+ commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)
1009
 
1010
 
1011
  if __name__ == "__main__":
run_streaming.sh CHANGED
@@ -2,6 +2,7 @@ python run_flax_speech_recognition_seq2seq_streaming_v3.py \
2
  --model_name_or_path openai/whisper-tiny.en \
3
  --dataset_name mozilla-foundation/common_voice_11_0 \
4
  --dataset_config es \
 
5
  --text_column_name sentence \
6
  --train_split_name test\
7
  --eval_split_name test\
 
2
  --model_name_or_path openai/whisper-tiny.en \
3
  --dataset_name mozilla-foundation/common_voice_11_0 \
4
  --dataset_config es \
5
+ --language es \
6
  --text_column_name sentence \
7
  --train_split_name test\
8
  --eval_split_name test\