versae commited on
Commit
eac79a2
·
1 Parent(s): a325bea

Clean up files

Browse files
run_flax_speech_recognition_seq2seq_streaming.py CHANGED
@@ -18,6 +18,7 @@ Fine-tuning the Flax library models for sequence to sequence speech recognition.
18
  """
19
  # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
20
 
 
21
  import logging
22
  import math
23
  import os
@@ -26,7 +27,7 @@ import time
26
  from dataclasses import field
27
  from functools import partial
28
  from pathlib import Path
29
- from typing import Any, Callable, Dict, List, Optional, Union
30
 
31
  import datasets
32
  import flax
@@ -35,7 +36,7 @@ import jax.numpy as jnp
35
  import numpy as np
36
  import optax
37
  import torch
38
- from datasets import Dataset,DatasetDict, IterableDatasetDict, interleave_datasets, load_dataset
39
  from torch.utils.data import IterableDataset
40
  from flax import jax_utils, traverse_util
41
  from flax.jax_utils import pad_shard_unpad, unreplicate
@@ -66,7 +67,8 @@ from transformers.utils.versions import require_version
66
  # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
67
  check_min_version("4.27.0.dev0")
68
 
69
- require_version("datasets>=1.18.2", "To fix: pip install -r examples/flax/speech-recogintion/requirements.txt")
 
70
 
71
  logger = logging.getLogger(__name__)
72
 
@@ -78,7 +80,8 @@ class ModelArguments:
78
  """
79
 
80
  model_name_or_path: str = field(
81
- metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
 
82
  )
83
  config_name: Optional[str] = field(
84
  default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
@@ -91,15 +94,18 @@ class ModelArguments:
91
  )
92
  cache_dir: Optional[str] = field(
93
  default=None,
94
- metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
 
95
  )
96
  use_fast_tokenizer: bool = field(
97
  default=True,
98
- metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
 
99
  )
100
  model_revision: str = field(
101
  default="main",
102
- metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
 
103
  )
104
  use_auth_token: bool = field(
105
  default=False,
@@ -142,7 +148,8 @@ class DataTrainingArguments:
142
  )
143
  text_column: Optional[str] = field(
144
  default=None,
145
- metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
 
146
  )
147
  dataset_cache_dir: Optional[str] = field(
148
  default=None, metadata={"help": "Path to cache directory for saving and loading datasets"}
@@ -170,23 +177,28 @@ class DataTrainingArguments:
170
  )
171
  audio_column_name: str = field(
172
  default="audio",
173
- metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
 
174
  )
175
  text_column_name: str = field(
176
  default="text",
177
- metadata={"help": "The name of the dataset column containing the text data. Defaults to 'text'"},
 
178
  )
179
  max_duration_in_seconds: float = field(
180
  default=30.0,
181
- metadata={"help": "Filter audio files that are longer than `max_duration_in_seconds` seconds"},
 
182
  )
183
  min_duration_in_seconds: float = field(
184
  default=0.0,
185
- metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"},
 
186
  )
187
  max_label_length: float = field(
188
  default=128,
189
- metadata={"help": "Truncate transcriptions that are longer `max_eval_length` tokens."},
 
190
  )
191
  pad_input_to_multiple_of: Optional[int] = field(
192
  default=None,
@@ -229,11 +241,13 @@ class DataTrainingArguments:
229
  )
230
  do_remove_punctuation: bool = field(
231
  default=False,
232
- metadata={"help": "Whether the target text should be striped of punctuation."},
 
233
  )
234
  do_normalize_eval: bool = field(
235
  default=True,
236
- metadata={"help": "Whether to normalise the references and predictions in the eval WER calculation."},
 
237
  )
238
  language: str = field(
239
  default=None,
@@ -246,10 +260,11 @@ class DataTrainingArguments:
246
  )
247
  task: str = field(
248
  default="transcribe",
249
- metadata={"help": "Task, either `transcribe` for speech recognition or `translate` for speech translation."},
 
250
  )
251
- num_train_steps: int = field(default=50000, metadata={"help": "The number of training steps."})
252
- # num_eval_samples: int = field(default=50000, metadata={"help": "The number of samples to be used for evaluation"})
253
  shuffle_buffer_size: Optional[int] = field(
254
  default=500,
255
  metadata={
@@ -261,9 +276,11 @@ class DataTrainingArguments:
261
  )
262
  streaming: bool = field(
263
  default=True,
264
- metadata={"help": "Whether to use streaming mode to load and pre-process the data."},
 
265
  )
266
 
 
267
  def shift_tokens_right(label_ids: np.array, decoder_start_token_id: int) -> np.ndarray:
268
  """
269
  Shift label ids one token to the right.
@@ -348,17 +365,19 @@ class FlaxDataCollatorSpeechSeq2SeqWithPadding:
348
  labels = labels[:, 1:]
349
  labels_batch.attention_mask = labels_batch.attention_mask[:, 1:]
350
 
351
- decoder_input_ids = shift_tokens_right(labels, self.decoder_start_token_id)
 
352
 
353
  # replace padding with -100 to ignore correctly when computing the loss
354
- labels = np.ma.array(labels, mask=np.not_equal(labels_batch.attention_mask, 1))
 
355
  labels = labels.filled(fill_value=-100)
356
 
357
  batch["labels"] = labels
358
  batch["decoder_input_ids"] = decoder_input_ids
359
 
360
  return batch
361
-
362
 
363
  def load_maybe_streaming_dataset(dataset_name, dataset_config_name, split="train", streaming=True, **kwargs):
364
  """
@@ -369,7 +388,8 @@ def load_maybe_streaming_dataset(dataset_name, dataset_config_name, split="train
369
  if "+" in split:
370
  # load multiple splits separated by the `+` symbol with streaming mode
371
  dataset_splits = [
372
- load_dataset(dataset_name, dataset_config_name, split=split_name, streaming=streaming, **kwargs)
 
373
  for split_name in split.split("+")
374
  ]
375
  # interleave multiple splits to form one dataset
@@ -377,32 +397,33 @@ def load_maybe_streaming_dataset(dataset_name, dataset_config_name, split="train
377
  return interleaved_dataset
378
  else:
379
  # load a single split *with* streaming mode
380
- dataset = load_dataset(dataset_name, dataset_config_name, split=split, streaming=streaming, **kwargs)
 
381
  return dataset
382
 
383
 
384
- def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True):
 
 
 
 
 
 
 
 
 
385
  """
386
  Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete,
387
  and range in size from 1 to `batch_size`. Shuffle batches if `shuffle` is `True`.
388
  """
389
- if shuffle:
390
- batch_idx = jax.random.permutation(rng, len(dataset))
391
- batch_idx = np.asarray(batch_idx)
392
- else:
393
- batch_idx = np.arange(len(dataset))
394
-
395
- if drop_last:
396
- steps_per_epoch = len(dataset) // batch_size
397
- batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
398
- batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
399
- else:
400
- steps_per_epoch = math.ceil(len(dataset) / batch_size)
401
- batch_idx = np.array_split(batch_idx, steps_per_epoch)
402
-
403
- for idx in batch_idx:
404
- batch = dataset[idx]
405
- yield batch
406
 
407
 
408
  class TrainState(train_state.TrainState):
@@ -426,14 +447,16 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
426
 
427
 
428
  def create_learning_rate_fn(
429
- num_train_steps: int, num_warmup_steps: int, learning_rate: float
430
  ) -> Callable[[int], jnp.array]:
431
  """Returns a linear warmup, linear_decay learning rate function."""
432
- warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
 
433
  decay_fn = optax.linear_schedule(
434
- init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
435
  )
436
- schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
 
437
  return schedule_fn
438
 
439
 
@@ -442,18 +465,21 @@ def main():
442
  # See all possible arguments in src/transformers/training_args.py
443
  # or by passing the --help flag to this script.
444
  # We now keep distinct sets of args, for a cleaner separation of concerns.
445
- parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
 
446
 
447
  if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
448
  # If we pass only one argument to the script and it's the path to a json file,
449
  # let's parse it to get our arguments.
450
- model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
 
451
  else:
452
  model_args, data_args, training_args = parser.parse_args_into_dataclasses()
453
 
454
  # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
455
  # information sent is the one passed as arguments along with your JAX/Flax versions.
456
- send_example_telemetry("run_speech_recognition_seq2seq", model_args, data_args, framework="flax")
 
457
 
458
  # 2. Setup logging
459
  # Make one log on every process with the configuration for debugging.
@@ -464,7 +490,8 @@ def main():
464
  )
465
  # Set the verbosity to info of the Transformers logger.
466
  # We only want one process per machine to log things on the screen.
467
- logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
 
468
  if jax.process_index() == 0:
469
  datasets.utils.logging.set_verbosity_warning()
470
  transformers.utils.logging.set_verbosity_info()
@@ -490,18 +517,20 @@ def main():
490
  if training_args.push_to_hub:
491
  if training_args.hub_model_id is None:
492
  repo_name = get_full_repo_name(
493
- Path(training_args.output_dir).absolute().name, token=training_args.hub_token
 
494
  )
495
  else:
496
  repo_name = training_args.hub_model_id
497
  create_repo(repo_name, exist_ok=True, token=training_args.hub_token)
498
- repo = Repository(training_args.output_dir, clone_from=repo_name, token=training_args.hub_token)
 
499
 
500
  # 3. Load dataset
501
  raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
502
-
503
  if training_args.do_train:
504
- raw_datasets[data_args.train_split_name] = load_maybe_streaming_dataset(
505
  data_args.dataset_name,
506
  data_args.dataset_config_name,
507
  split=data_args.train_split_name,
@@ -511,7 +540,7 @@ def main():
511
  )
512
 
513
  if training_args.do_eval:
514
- raw_datasets[data_args.eval_split_name] = load_maybe_streaming_dataset(
515
  data_args.dataset_name,
516
  data_args.dataset_config_name,
517
  split=data_args.eval_split_name,
@@ -519,13 +548,14 @@ def main():
519
  streaming=data_args.streaming,
520
  use_auth_token=True if model_args.use_auth_token else None,
521
  )
522
-
523
  if not training_args.do_train and not training_args.do_eval:
524
  raise ValueError(
525
  "Cannot not train and not do evaluation. At least one of training or evaluation has to be performed."
526
  )
527
 
528
- raw_datasets_features = list(next(iter(raw_datasets.values())).features.keys())
 
529
 
530
  if data_args.audio_column_name not in raw_datasets_features:
531
  raise ValueError(
@@ -572,21 +602,26 @@ def main():
572
  )
573
 
574
  if model.config.decoder_start_token_id is None:
575
- raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
 
576
 
577
  # 6. Resample speech dataset: `datasets` takes care of automatically loading and resampling the audio,
578
  # so we just need to set the correct target sampling rate.
579
- dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
580
-
 
581
  if dataset_sampling_rate != feature_extractor.sampling_rate:
582
  raw_datasets = raw_datasets.cast_column(
583
- data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
 
584
  )
585
 
586
  # 7. Preprocessing the datasets.
587
  # We need to read the audio files as arrays and tokenize the targets.
588
- max_input_length = int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate)
589
- min_input_length = int(data_args.min_duration_in_seconds * feature_extractor.sampling_rate)
 
 
590
  max_label_length = (
591
  data_args.max_label_length if data_args.max_label_length is not None else model.config.max_length
592
  )
@@ -602,18 +637,21 @@ def main():
602
 
603
  if data_args.language is not None:
604
  # We only need to set the task id when the language is specified (i.e. in a multilingual setting)
605
- tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task)
 
606
 
607
  def prepare_dataset(batch):
608
  # process audio
609
  sample = batch[audio_column_name]
610
- inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
 
611
  # process audio length
612
  batch[model_input_name] = inputs.get(model_input_name)[0]
613
  batch["input_length"] = len(sample["array"])
614
 
615
  # process targets
616
- input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
 
617
  if do_remove_punctuation:
618
  input_str = normalizer(input_str).strip()
619
  batch["labels"] = tokenizer(input_str).input_ids
@@ -623,52 +661,52 @@ def main():
623
  vectorized_datasets = raw_datasets.map(
624
  prepare_dataset,
625
  remove_columns=raw_datasets_features,
626
- ).with_format("torch")
627
 
628
-
629
  # filter training data with inputs longer than max_input_length
630
  def is_audio_in_length_range(length):
631
  return min_input_length < length < max_input_length
632
 
633
-
634
  if training_args.do_train:
635
- vectorized_datasets[data_args.train_split_name] = vectorized_datasets[data_args.train_split_name].filter(
636
  is_audio_in_length_range,
637
  input_columns=["input_length"],
638
  )
 
639
  if training_args.do_eval:
640
- vectorized_datasets[data_args.eval_split_name] = vectorized_datasets[data_args.eval_split_name].filter(
641
  is_audio_in_length_range,
642
  input_columns=["input_length"],
643
  )
644
-
645
  # 8. Load Metric
646
- metric = evaluate.load("wer")
 
647
  do_normalize_eval = data_args.do_normalize_eval
648
 
649
- def compute_metrics(preds, labels):
650
  # replace padded labels by the padding token
651
- for idx in range(len(labels)):
652
- labels[idx][labels[idx] == -100] = tokenizer.pad_token_id
653
 
654
- pred_str = tokenizer.batch_decode(preds, skip_special_tokens=True)
655
  # we do not want to group tokens when computing the metrics
656
- label_str = tokenizer.batch_decode(labels, skip_special_tokens=True)
657
-
658
- # TODO
659
- # We should implement and test this as well
660
- #if do_normalize_eval:
661
- # pred_str = [normalizer(pred) for pred in pred_str]
662
- # label_str = [normalizer(label) for label in label_str]
663
- # # filtering step to only evaluate the samples that correspond to non-zero references:
664
- # pred_str = [pred_str[i] for i in range(len(pred_str)) if len(label_str[i]) > 0]
665
- # label_str = [label_str[i] for i in range(len(label_str)) if len(label_str[i]) > 0]
666
-
667
-
668
- wer = metric.compute(predictions=pred_str, references=label_str)
669
- return {"wer": wer}
670
-
671
-
672
  # 9. Save feature extractor, tokenizer and config
673
  feature_extractor.save_pretrained(training_args.output_dir)
674
  tokenizer.save_pretrained(training_args.output_dir)
@@ -692,7 +730,8 @@ def main():
692
  try:
693
  from flax.metrics.tensorboard import SummaryWriter
694
 
695
- summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
 
696
  except ImportError as ie:
697
  has_tensorboard = False
698
  logger.warning(
@@ -707,18 +746,36 @@ def main():
707
  # Initialize our training
708
  rng = jax.random.PRNGKey(training_args.seed)
709
  rng, dropout_rng = jax.random.split(rng)
 
710
 
711
  # Store some constant
712
  #num_epochs = int(training_args.num_train_epochs)
713
- train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
714
- eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
715
-
 
716
 
717
  # Create learning rate schedule
 
 
 
 
 
 
 
 
 
 
 
 
 
 
718
  linear_decay_lr_schedule_fn = create_learning_rate_fn(
719
- data_args.num_train_steps*train_batch_size,
720
  training_args.warmup_steps,
721
  training_args.learning_rate,
 
 
722
  )
723
 
724
  # We use Optax's "masking" functionality to not apply weight decay
@@ -737,7 +794,8 @@ def main():
737
  if layer_norm_name in "".join(layer).lower()
738
  ]
739
  )
740
- flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
 
741
  return traverse_util.unflatten_dict(flat_mask)
742
 
743
  # create adam optimizer
@@ -751,7 +809,8 @@ def main():
751
  )
752
 
753
  # Setup train state
754
- state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
 
755
 
756
  # label smoothed cross entropy
757
  def loss_fn(logits, labels, label_smoothing_factor=0.0):
@@ -763,9 +822,11 @@ def main():
763
  confidence = 1.0 - label_smoothing_factor
764
  low_confidence = (1.0 - confidence) / (vocab_size - 1)
765
  normalizing_constant = -(
766
- confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
 
767
  )
768
- soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)
 
769
 
770
  loss = optax.softmax_cross_entropy(logits, soft_labels)
771
  loss = loss - normalizing_constant
@@ -783,7 +844,8 @@ def main():
783
 
784
  def compute_loss(params):
785
  labels = batch.pop("labels")
786
- logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
 
787
  loss, num_labels = loss_fn(logits, labels, label_smoothing_factor)
788
  return loss, num_labels
789
 
@@ -798,9 +860,11 @@ def main():
798
  # true grad = total grad / total samples
799
  grad = jax.lax.psum(grad, "batch")
800
  grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
801
- new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
 
802
 
803
- metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
 
804
  return new_state, metrics
805
 
806
  # Define eval fn
@@ -824,84 +888,78 @@ def main():
824
 
825
  def generate_step(params, batch):
826
  model.params = params
827
- output_ids = model.generate(batch[model_input_name], attention_mask=batch.get("attention_mask"), **gen_kwargs)
 
828
  return output_ids.sequences
829
 
830
  # Create parallel version of the train and eval step
831
  p_train_step = jax.pmap(
832
- partial(train_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch", donate_argnums=(0,)
833
  )
834
- p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch")
 
835
  p_generate_step = jax.pmap(generate_step, "batch")
836
 
837
  # Replicate the train state on each device
838
  state = state.replicate()
839
 
840
  logger.info("***** Running training *****")
841
- logger.info(f" Num examples = {data_args.num_train_steps*train_batch_size}")
842
- logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
843
- logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
 
 
 
844
  logger.info(f" Total optimization steps = {data_args.num_train_steps}")
845
 
846
  train_time = 0
847
- eval_metrics = []
848
- eval_preds = []
849
- eval_labels = []
850
-
851
-
852
  # ======================== Training ================================
853
  train_start = time.time()
854
 
855
- # Create sampling rng
856
- #rng, input_rng = jax.random.split(rng)
857
  train_metrics = []
858
-
859
- # TODO
860
- # Do the reset epoch stuff to shuffle
861
-
862
- epoch = 1
863
-
864
- # Create a batched data iterator
865
- num_workers = 0
866
- batched_data_loader = torch.utils.data.DataLoader( batch_size=train_batch_size, dataset=vectorized_datasets[data_args.train_split_name], num_workers=num_workers, collate_fn=lambda x: x )
867
- batched_data_iterator = iter(batched_data_loader)
868
-
869
  # train
870
- for step in tqdm(range(data_args.num_train_steps), desc="Training...", position=1, leave=False):
871
-
872
  try:
873
- samples = next(batched_data_iterator)
874
-
875
  except StopIteration:
876
  epoch += 1
877
- # TODO - Not currently shuffled
878
- batched_data_loader = torch.utils.data.DataLoader( batch_size=train_batch_size, dataset=vectorized_datasets[data_args.train_split_name], num_workers=num_workers, collate_fn=lambda x: x )
879
- batched_data_iterator = iter(batched_data_loader)
880
- samples = next(batched_data_iterator)
881
-
882
  logger.info(
883
  f"Completed epoch ({epoch} | Loss: {train_metric['loss']}, Learning Rate:"
884
  f" {train_metric['learning_rate']})"
885
  )
886
-
887
- reshaped_samples = {key: [feature[key] for feature in samples] for key in samples[0].keys()}
888
- batch = data_collator(reshaped_samples)
889
  batch = shard(batch.data)
890
  state, train_metric = p_train_step(state, batch)
891
-
 
892
  train_time += time.time() - train_start
893
  train_metric = unreplicate(train_metric)
894
-
895
  # ======================== Evaluating ==============================
896
  if step % training_args.eval_steps == 0 and step > 0:
897
- batched_data_eval_loader = torch.utils.data.DataLoader( batch_size=eval_batch_size, dataset=vectorized_datasets[data_args.eval_split_name], num_workers=num_workers, collate_fn=lambda x: x )
898
- batched_data_eval_iterator = iter(batched_data_eval_loader)
899
-
900
- for _ in tqdm(range(data_args.max_eval_samples//eval_batch_size), desc="Evaluating...", position=2, leave=False):
 
 
 
 
 
901
  # Model forward
902
- samples = next(batched_data_eval_iterator)
903
- reshaped_samples = {key: [feature[key] for feature in samples] for key in samples[0].keys()}
904
- batch = data_collator(reshaped_samples)
 
 
905
  labels = batch["labels"]
906
 
907
  metrics = pad_shard_unpad(p_eval_step, static_return=True)(
@@ -909,43 +967,44 @@ def main():
909
  )
910
  eval_metrics.append(metrics)
911
 
912
- # generation
913
- if training_args.predict_with_generate:
914
- generated_ids = pad_shard_unpad(p_generate_step)(state.params, batch.data)
915
- eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
916
- eval_labels.extend(labels)
917
-
 
 
918
  # normalize eval metrics
919
  eval_metrics = get_metrics(eval_metrics)
920
  eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
921
-
922
- # compute WER metric
923
- wer_desc = ""
924
  if training_args.predict_with_generate:
925
- wer_metric = compute_metrics(eval_preds, eval_labels)
926
- eval_metrics.update(wer_metric)
927
- wer_desc = " ".join([f"Eval {key}: {value} |" for key, value in wer_metric.items()])
 
928
 
929
  # Print metrics
930
- desc = f"Epoch... ({epoch} | Eval Loss: {eval_metrics['loss']} | {wer_desc})"
931
  logger.info(desc)
932
 
933
  # Save metrics
934
  if has_tensorboard and jax.process_index() == 0:
935
- #TODO
936
- breakpoint()
937
- # cur_step = epoch * (len(vectorized_datasets[data_args.train_split_name]) // train_batch_size)
938
- write_metric(summary_writer, train_metrics, eval_metrics, train_time, data_args.num_train_steps)
939
 
940
- # TODO THis is not happening at every epoch!!!
941
- breakpoint()
942
  # save checkpoint after each epoch and push checkpoint to the hub
943
  if jax.process_index() == 0:
944
- params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
 
945
  model.save_pretrained(training_args.output_dir, params=params)
946
  tokenizer.save_pretrained(training_args.output_dir)
947
  if training_args.push_to_hub:
948
- repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)
 
949
 
950
 
951
  if __name__ == "__main__":
 
18
  """
19
  # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
20
 
21
+ import itertools
22
  import logging
23
  import math
24
  import os
 
27
  from dataclasses import field
28
  from functools import partial
29
  from pathlib import Path
30
+ from typing import Any, Callable, Dict, Generator, List, Optional, Union
31
 
32
  import datasets
33
  import flax
 
36
  import numpy as np
37
  import optax
38
  import torch
39
+ from datasets import Dataset, DatasetDict, IterableDatasetDict, interleave_datasets, load_dataset
40
  from torch.utils.data import IterableDataset
41
  from flax import jax_utils, traverse_util
42
  from flax.jax_utils import pad_shard_unpad, unreplicate
 
67
  # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
68
  check_min_version("4.27.0.dev0")
69
 
70
+ require_version("datasets>=1.18.2",
71
+ "To fix: pip install -r examples/flax/speech-recogintion/requirements.txt")
72
 
73
  logger = logging.getLogger(__name__)
74
 
 
80
  """
81
 
82
  model_name_or_path: str = field(
83
+ metadata={
84
+ "help": "Path to pretrained model or model identifier from huggingface.co/models"}
85
  )
86
  config_name: Optional[str] = field(
87
  default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
 
94
  )
95
  cache_dir: Optional[str] = field(
96
  default=None,
97
+ metadata={
98
+ "help": "Where to store the pretrained models downloaded from huggingface.co"},
99
  )
100
  use_fast_tokenizer: bool = field(
101
  default=True,
102
+ metadata={
103
+ "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
104
  )
105
  model_revision: str = field(
106
  default="main",
107
+ metadata={
108
+ "help": "The specific model version to use (can be a branch name, tag name or commit id)."},
109
  )
110
  use_auth_token: bool = field(
111
  default=False,
 
148
  )
149
  text_column: Optional[str] = field(
150
  default=None,
151
+ metadata={
152
+ "help": "The name of the column in the datasets containing the full texts (for summarization)."},
153
  )
154
  dataset_cache_dir: Optional[str] = field(
155
  default=None, metadata={"help": "Path to cache directory for saving and loading datasets"}
 
177
  )
178
  audio_column_name: str = field(
179
  default="audio",
180
+ metadata={
181
+ "help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
182
  )
183
  text_column_name: str = field(
184
  default="text",
185
+ metadata={
186
+ "help": "The name of the dataset column containing the text data. Defaults to 'text'"},
187
  )
188
  max_duration_in_seconds: float = field(
189
  default=30.0,
190
+ metadata={
191
+ "help": "Filter audio files that are longer than `max_duration_in_seconds` seconds"},
192
  )
193
  min_duration_in_seconds: float = field(
194
  default=0.0,
195
+ metadata={
196
+ "help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"},
197
  )
198
  max_label_length: float = field(
199
  default=128,
200
+ metadata={
201
+ "help": "Truncate transcriptions that are longer `max_eval_length` tokens."},
202
  )
203
  pad_input_to_multiple_of: Optional[int] = field(
204
  default=None,
 
241
  )
242
  do_remove_punctuation: bool = field(
243
  default=False,
244
+ metadata={
245
+ "help": "Whether the target text should be striped of punctuation."},
246
  )
247
  do_normalize_eval: bool = field(
248
  default=True,
249
+ metadata={
250
+ "help": "Whether to normalise the references and predictions in the eval WER calculation."},
251
  )
252
  language: str = field(
253
  default=None,
 
260
  )
261
  task: str = field(
262
  default="transcribe",
263
+ metadata={
264
+ "help": "Task, either `transcribe` for speech recognition or `translate` for speech translation."},
265
  )
266
+ num_train_steps: int = field(default=50000, metadata={
267
+ "help": "The number of training steps."})
268
  shuffle_buffer_size: Optional[int] = field(
269
  default=500,
270
  metadata={
 
276
  )
277
  streaming: bool = field(
278
  default=True,
279
+ metadata={
280
+ "help": "Whether to use streaming mode to load and pre-process the data."},
281
  )
282
 
283
+
284
  def shift_tokens_right(label_ids: np.array, decoder_start_token_id: int) -> np.ndarray:
285
  """
286
  Shift label ids one token to the right.
 
365
  labels = labels[:, 1:]
366
  labels_batch.attention_mask = labels_batch.attention_mask[:, 1:]
367
 
368
+ decoder_input_ids = shift_tokens_right(
369
+ labels, self.decoder_start_token_id)
370
 
371
  # replace padding with -100 to ignore correctly when computing the loss
372
+ labels = np.ma.array(labels, mask=np.not_equal(
373
+ labels_batch.attention_mask, 1))
374
  labels = labels.filled(fill_value=-100)
375
 
376
  batch["labels"] = labels
377
  batch["decoder_input_ids"] = decoder_input_ids
378
 
379
  return batch
380
+
381
 
382
  def load_maybe_streaming_dataset(dataset_name, dataset_config_name, split="train", streaming=True, **kwargs):
383
  """
 
388
  if "+" in split:
389
  # load multiple splits separated by the `+` symbol with streaming mode
390
  dataset_splits = [
391
+ load_dataset(dataset_name, dataset_config_name,
392
+ split=split_name, streaming=streaming, **kwargs)
393
  for split_name in split.split("+")
394
  ]
395
  # interleave multiple splits to form one dataset
 
397
  return interleaved_dataset
398
  else:
399
  # load a single split *with* streaming mode
400
+ dataset = load_dataset(
401
+ dataset_name, dataset_config_name, split=split, streaming=streaming, **kwargs)
402
  return dataset
403
 
404
 
405
+ def collate_batch(samples):
406
+ return {key: [feature[key] for feature in samples] for key in samples[0]}
407
+
408
+
409
+ def data_loader(
410
+ dataset: Dataset,
411
+ batch_size: int,
412
+ drop_last: bool=True,
413
+ num_workers: int=0,
414
+ ) -> Generator:
415
  """
416
  Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete,
417
  and range in size from 1 to `batch_size`. Shuffle batches if `shuffle` is `True`.
418
  """
419
+ data_loader_iterator = iter(torch.utils.data.DataLoader(
420
+ batch_size=batch_size,
421
+ dataset=dataset.with_format("torch"),
422
+ num_workers=num_workers,
423
+ collate_fn=collate_batch,
424
+ drop_last=drop_last,
425
+ ))
426
+ return data_loader_iterator
 
 
 
 
 
 
 
 
 
427
 
428
 
429
  class TrainState(train_state.TrainState):
 
447
 
448
 
449
  def create_learning_rate_fn(
450
+ num_train_steps: int, num_warmup_steps: int, learning_rate: float, warmup_init_value: float=0.0, decay_end_value: float=0.0,
451
  ) -> Callable[[int], jnp.array]:
452
  """Returns a linear warmup, linear_decay learning rate function."""
453
+ warmup_fn = optax.linear_schedule(
454
+ init_value=warmup_init_value, end_value=learning_rate, transition_steps=num_warmup_steps)
455
  decay_fn = optax.linear_schedule(
456
+ init_value=learning_rate, end_value=decay_end_value, transition_steps=num_train_steps - num_warmup_steps
457
  )
458
+ schedule_fn = optax.join_schedules(
459
+ schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
460
  return schedule_fn
461
 
462
 
 
465
  # See all possible arguments in src/transformers/training_args.py
466
  # or by passing the --help flag to this script.
467
  # We now keep distinct sets of args, for a cleaner separation of concerns.
468
+ parser = HfArgumentParser(
469
+ (ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
470
 
471
  if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
472
  # If we pass only one argument to the script and it's the path to a json file,
473
  # let's parse it to get our arguments.
474
+ model_args, data_args, training_args = parser.parse_json_file(
475
+ json_file=os.path.abspath(sys.argv[1]))
476
  else:
477
  model_args, data_args, training_args = parser.parse_args_into_dataclasses()
478
 
479
  # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
480
  # information sent is the one passed as arguments along with your JAX/Flax versions.
481
+ send_example_telemetry("run_speech_recognition_seq2seq",
482
+ model_args, data_args, framework="flax")
483
 
484
  # 2. Setup logging
485
  # Make one log on every process with the configuration for debugging.
 
490
  )
491
  # Set the verbosity to info of the Transformers logger.
492
  # We only want one process per machine to log things on the screen.
493
+ logger.setLevel(logging.INFO if jax.process_index()
494
+ == 0 else logging.ERROR)
495
  if jax.process_index() == 0:
496
  datasets.utils.logging.set_verbosity_warning()
497
  transformers.utils.logging.set_verbosity_info()
 
517
  if training_args.push_to_hub:
518
  if training_args.hub_model_id is None:
519
  repo_name = get_full_repo_name(
520
+ Path(training_args.output_dir).absolute(
521
+ ).name, token=training_args.hub_token
522
  )
523
  else:
524
  repo_name = training_args.hub_model_id
525
  create_repo(repo_name, exist_ok=True, token=training_args.hub_token)
526
+ repo = Repository(training_args.output_dir,
527
+ clone_from=repo_name, token=training_args.hub_token)
528
 
529
  # 3. Load dataset
530
  raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
531
+
532
  if training_args.do_train:
533
+ raw_datasets["train"] = load_maybe_streaming_dataset(
534
  data_args.dataset_name,
535
  data_args.dataset_config_name,
536
  split=data_args.train_split_name,
 
540
  )
541
 
542
  if training_args.do_eval:
543
+ raw_datasets["eval"] = load_maybe_streaming_dataset(
544
  data_args.dataset_name,
545
  data_args.dataset_config_name,
546
  split=data_args.eval_split_name,
 
548
  streaming=data_args.streaming,
549
  use_auth_token=True if model_args.use_auth_token else None,
550
  )
551
+
552
  if not training_args.do_train and not training_args.do_eval:
553
  raise ValueError(
554
  "Cannot not train and not do evaluation. At least one of training or evaluation has to be performed."
555
  )
556
 
557
+ raw_datasets_features = list(
558
+ next(iter(raw_datasets.values())).features.keys())
559
 
560
  if data_args.audio_column_name not in raw_datasets_features:
561
  raise ValueError(
 
602
  )
603
 
604
  if model.config.decoder_start_token_id is None:
605
+ raise ValueError(
606
+ "Make sure that `config.decoder_start_token_id` is correctly defined")
607
 
608
  # 6. Resample speech dataset: `datasets` takes care of automatically loading and resampling the audio,
609
  # so we just need to set the correct target sampling rate.
610
+ dataset_sampling_rate = next(
611
+ iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
612
+
613
  if dataset_sampling_rate != feature_extractor.sampling_rate:
614
  raw_datasets = raw_datasets.cast_column(
615
+ data_args.audio_column_name, datasets.features.Audio(
616
+ sampling_rate=feature_extractor.sampling_rate)
617
  )
618
 
619
  # 7. Preprocessing the datasets.
620
  # We need to read the audio files as arrays and tokenize the targets.
621
+ max_input_length = int(
622
+ data_args.max_duration_in_seconds * feature_extractor.sampling_rate)
623
+ min_input_length = int(
624
+ data_args.min_duration_in_seconds * feature_extractor.sampling_rate)
625
  max_label_length = (
626
  data_args.max_label_length if data_args.max_label_length is not None else model.config.max_length
627
  )
 
637
 
638
  if data_args.language is not None:
639
  # We only need to set the task id when the language is specified (i.e. in a multilingual setting)
640
+ tokenizer.set_prefix_tokens(
641
+ language=data_args.language, task=data_args.task)
642
 
643
  def prepare_dataset(batch):
644
  # process audio
645
  sample = batch[audio_column_name]
646
+ inputs = feature_extractor(
647
+ sample["array"], sampling_rate=sample["sampling_rate"])
648
  # process audio length
649
  batch[model_input_name] = inputs.get(model_input_name)[0]
650
  batch["input_length"] = len(sample["array"])
651
 
652
  # process targets
653
+ input_str = batch[text_column_name].lower(
654
+ ) if do_lower_case else batch[text_column_name]
655
  if do_remove_punctuation:
656
  input_str = normalizer(input_str).strip()
657
  batch["labels"] = tokenizer(input_str).input_ids
 
661
  vectorized_datasets = raw_datasets.map(
662
  prepare_dataset,
663
  remove_columns=raw_datasets_features,
664
+ )
665
 
 
666
  # filter training data with inputs longer than max_input_length
667
  def is_audio_in_length_range(length):
668
  return min_input_length < length < max_input_length
669
 
 
670
  if training_args.do_train:
671
+ vectorized_datasets["train"] = vectorized_datasets["train"].filter(
672
  is_audio_in_length_range,
673
  input_columns=["input_length"],
674
  )
675
+
676
  if training_args.do_eval:
677
+ vectorized_datasets["eval"] = vectorized_datasets["eval"].filter(
678
  is_audio_in_length_range,
679
  input_columns=["input_length"],
680
  )
681
+
682
  # 8. Load Metric
683
+ metric_wer = evaluate.load("wer")
684
+ metric_cer = evaluate.load("cer")
685
  do_normalize_eval = data_args.do_normalize_eval
686
 
687
+ def compute_metrics(pred_ids, label_ids):
688
  # replace padded labels by the padding token
689
+ for idx in range(len(label_ids)):
690
+ label_ids[idx][label_ids[idx] == -100] = tokenizer.pad_token_id
691
 
692
+ pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
693
  # we do not want to group tokens when computing the metrics
694
+ label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
695
+
696
+ if do_normalize_eval:
697
+ pred_str = [normalizer(pred) for pred in pred_str]
698
+ label_str = [normalizer(label) for label in label_str]
699
+ # filtering step to only evaluate the samples that correspond to non-zero references:
700
+ pred_str = [pred_str[i]
701
+ for i in range(len(pred_str)) if len(label_str[i]) > 0]
702
+ label_str = [label_str[i]
703
+ for i in range(len(label_str)) if len(label_str[i]) > 0]
704
+
705
+ wer = 100 * metric_wer.compute(predictions=pred_str, references=label_str)
706
+ cer = 100 * metric_cer.compute(predictions=pred_str, references=label_str)
707
+
708
+ return {"wer": wer, "cer": cer}
709
+
710
  # 9. Save feature extractor, tokenizer and config
711
  feature_extractor.save_pretrained(training_args.output_dir)
712
  tokenizer.save_pretrained(training_args.output_dir)
 
730
  try:
731
  from flax.metrics.tensorboard import SummaryWriter
732
 
733
+ summary_writer = SummaryWriter(
734
+ log_dir=Path(training_args.output_dir))
735
  except ImportError as ie:
736
  has_tensorboard = False
737
  logger.warning(
 
746
  # Initialize our training
747
  rng = jax.random.PRNGKey(training_args.seed)
748
  rng, dropout_rng = jax.random.split(rng)
749
+ # rng, input_rng = jax.random.split(rng)
750
 
751
  # Store some constant
752
  #num_epochs = int(training_args.num_train_epochs)
753
+ train_batch_size = int(
754
+ training_args.per_device_train_batch_size) * jax.device_count()
755
+ eval_batch_size = int(
756
+ training_args.per_device_eval_batch_size) * jax.device_count()
757
 
758
  # Create learning rate schedule
759
+ lr_scheduler_types = {"linear", "constant", "constant_with_warmup"}
760
+ if training_args.lr_scheduler_type not in lr_scheduler_types:
761
+ raise ValueError(
762
+ f"lr_scheduler_type of type {training_args.lr_scheduler_type} not supported, choose from {lr_scheduler_types}."
763
+ )
764
+ elif training_args.lr_scheduler_type == "constant":
765
+ warmup_init_value = training_args.learning_rate
766
+ decay_end_value = training_args.learning_rate
767
+ elif training_args.lr_scheduler_type == "constant_with_warmup":
768
+ warmup_init_value = 0.0
769
+ decay_end_value = training_args.learning_rate
770
+ else:
771
+ warmup_init_value = 0.0
772
+ decay_end_value = 0.0
773
  linear_decay_lr_schedule_fn = create_learning_rate_fn(
774
+ data_args.num_train_steps * train_batch_size,
775
  training_args.warmup_steps,
776
  training_args.learning_rate,
777
+ warmup_init_value=warmup_init_value,
778
+ decay_end_value=decay_end_value,
779
  )
780
 
781
  # We use Optax's "masking" functionality to not apply weight decay
 
794
  if layer_norm_name in "".join(layer).lower()
795
  ]
796
  )
797
+ flat_mask = {path: (path[-1] != "bias" and path[-2:]
798
+ not in layer_norm_named_params) for path in flat_params}
799
  return traverse_util.unflatten_dict(flat_mask)
800
 
801
  # create adam optimizer
 
809
  )
810
 
811
  # Setup train state
812
+ state = TrainState.create(
813
+ apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
814
 
815
  # label smoothed cross entropy
816
  def loss_fn(logits, labels, label_smoothing_factor=0.0):
 
822
  confidence = 1.0 - label_smoothing_factor
823
  low_confidence = (1.0 - confidence) / (vocab_size - 1)
824
  normalizing_constant = -(
825
+ confidence * jnp.log(confidence) + (vocab_size - 1) *
826
+ low_confidence * jnp.log(low_confidence + 1e-20)
827
  )
828
+ soft_labels = onehot(labels, vocab_size,
829
+ on_value=confidence, off_value=low_confidence)
830
 
831
  loss = optax.softmax_cross_entropy(logits, soft_labels)
832
  loss = loss - normalizing_constant
 
844
 
845
  def compute_loss(params):
846
  labels = batch.pop("labels")
847
+ logits = state.apply_fn(
848
+ **batch, params=params, dropout_rng=dropout_rng, train=True)[0]
849
  loss, num_labels = loss_fn(logits, labels, label_smoothing_factor)
850
  return loss, num_labels
851
 
 
860
  # true grad = total grad / total samples
861
  grad = jax.lax.psum(grad, "batch")
862
  grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
863
+ new_state = state.apply_gradients(
864
+ grads=grad, dropout_rng=new_dropout_rng)
865
 
866
+ metrics = {"loss": loss,
867
+ "learning_rate": linear_decay_lr_schedule_fn(state.step)}
868
  return new_state, metrics
869
 
870
  # Define eval fn
 
888
 
889
  def generate_step(params, batch):
890
  model.params = params
891
+ output_ids = model.generate(batch[model_input_name], attention_mask=batch.get(
892
+ "attention_mask"), **gen_kwargs)
893
  return output_ids.sequences
894
 
895
  # Create parallel version of the train and eval step
896
  p_train_step = jax.pmap(
897
+ partial(train_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch", donate_argnums=(0, )
898
  )
899
+ p_eval_step = jax.pmap(partial(
900
+ eval_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch")
901
  p_generate_step = jax.pmap(generate_step, "batch")
902
 
903
  # Replicate the train state on each device
904
  state = state.replicate()
905
 
906
  logger.info("***** Running training *****")
907
+ logger.info(
908
+ f" Num examples = {data_args.num_train_steps * train_batch_size}")
909
+ logger.info(
910
+ f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
911
+ logger.info(
912
+ f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
913
  logger.info(f" Total optimization steps = {data_args.num_train_steps}")
914
 
915
  train_time = 0
916
+
 
 
 
 
917
  # ======================== Training ================================
918
  train_start = time.time()
919
 
 
 
920
  train_metrics = []
921
+ epoch = 0
922
+ train_dataset = vectorized_datasets["train"].shuffle(seed=training_args.seed)
923
+ eval_dataset = vectorized_datasets["eval"]
924
+ train_loader = data_loader(train_dataset, train_batch_size)
 
 
 
 
 
 
 
925
  # train
926
+ for step in tqdm(range(data_args.num_train_steps), desc="Training...", position=1, leave=False):
 
927
  try:
928
+ samples = next(train_loader)
 
929
  except StopIteration:
930
  epoch += 1
931
+ train_dataset.set_epoch(epoch)
932
+ train_loader = data_loader(train_dataset, train_batch_size)
933
+ samples = next(train_loader)
 
 
934
  logger.info(
935
  f"Completed epoch ({epoch} | Loss: {train_metric['loss']}, Learning Rate:"
936
  f" {train_metric['learning_rate']})"
937
  )
938
+
939
+ batch = data_collator(samples)
 
940
  batch = shard(batch.data)
941
  state, train_metric = p_train_step(state, batch)
942
+ train_metrics.append(train_metric)
943
+
944
  train_time += time.time() - train_start
945
  train_metric = unreplicate(train_metric)
 
946
  # ======================== Evaluating ==============================
947
  if step % training_args.eval_steps == 0 and step > 0:
948
+ eval_metrics = []
949
+ eval_preds = []
950
+ eval_labels = []
951
+ eval_loader = data_loader(eval_dataset, eval_batch_size, drop_last=False)
952
+ if data_args.max_eval_samples:
953
+ max_eval_steps_iter = range(1 + data_args.max_eval_samples // eval_batch_size)
954
+ else:
955
+ max_eval_steps_iter = itertools.repeat(None)
956
+ for _ in tqdm(max_eval_steps_iter, desc="Evaluating...", position=2, leave=False):
957
  # Model forward
958
+ try:
959
+ samples = next(eval_loader)
960
+ except StopIteration:
961
+ break
962
+ batch = data_collator(samples)
963
  labels = batch["labels"]
964
 
965
  metrics = pad_shard_unpad(p_eval_step, static_return=True)(
 
967
  )
968
  eval_metrics.append(metrics)
969
 
970
+ # generation
971
+ if training_args.predict_with_generate:
972
+ generated_ids = pad_shard_unpad(
973
+ p_generate_step)(state.params, batch.data)
974
+ eval_preds.extend(jax.device_get(
975
+ generated_ids.reshape(-1, gen_kwargs["max_length"])))
976
+ eval_labels.extend(labels)
977
+
978
  # normalize eval metrics
979
  eval_metrics = get_metrics(eval_metrics)
980
  eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
981
+
982
+ # compute metrics
983
+ metric_desc = ""
984
  if training_args.predict_with_generate:
985
+ metric_values = compute_metrics(eval_preds, eval_labels)
986
+ eval_metrics.update(metric_values)
987
+ metric_desc = " ".join(
988
+ [f"Eval {key}: {value} |" for key, value in metric_values.items()])
989
 
990
  # Print metrics
991
+ desc = f"Epoch... ({epoch} | Eval Loss: {eval_metrics['loss']} | {metric_desc})"
992
  logger.info(desc)
993
 
994
  # Save metrics
995
  if has_tensorboard and jax.process_index() == 0:
996
+ write_metric(summary_writer, train_metrics,
997
+ eval_metrics, train_time, step)
 
 
998
 
 
 
999
  # save checkpoint after each epoch and push checkpoint to the hub
1000
  if jax.process_index() == 0:
1001
+ params = jax.device_get(
1002
+ jax.tree_util.tree_map(lambda x: x[0], state.params))
1003
  model.save_pretrained(training_args.output_dir, params=params)
1004
  tokenizer.save_pretrained(training_args.output_dir)
1005
  if training_args.push_to_hub:
1006
+ repo.push_to_hub(
1007
+ commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)
1008
 
1009
 
1010
  if __name__ == "__main__":
run_flax_speech_recognition_seq2seq_streaming_v2.py DELETED
@@ -1,972 +0,0 @@
1
- #!/usr/bin/env python
2
- # coding=utf-8
3
- # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
- """
17
- Fine-tuning the Flax library models for sequence to sequence speech recognition.
18
- """
19
- # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
20
-
21
- import logging
22
- import math
23
- import os
24
- import sys
25
- import time
26
- from dataclasses import field
27
- from functools import partial
28
- from pathlib import Path
29
- from typing import Any, Callable, Dict, List, Optional, Union
30
-
31
- import datasets
32
- import flax
33
- import jax
34
- import jax.numpy as jnp
35
- import numpy as np
36
- import optax
37
- import torch
38
- from datasets import Dataset,DatasetDict, IterableDatasetDict, interleave_datasets, load_dataset
39
- from torch.utils.data import IterableDataset
40
- from flax import jax_utils, traverse_util
41
- from flax.jax_utils import pad_shard_unpad, unreplicate
42
- from flax.training import train_state
43
- from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
44
- from huggingface_hub import Repository, create_repo
45
- from tqdm import tqdm
46
-
47
- import evaluate
48
- import transformers
49
- from transformers import (
50
- AutoConfig,
51
- AutoFeatureExtractor,
52
- AutoProcessor,
53
- AutoTokenizer,
54
- FlaxAutoModelForSpeechSeq2Seq,
55
- HfArgumentParser,
56
- Seq2SeqTrainingArguments,
57
- is_tensorboard_available,
58
- )
59
-
60
- from transformers.models.whisper.english_normalizer import BasicTextNormalizer
61
- from transformers.file_utils import get_full_repo_name
62
- from transformers.utils import check_min_version, send_example_telemetry
63
- from transformers.utils.versions import require_version
64
-
65
-
66
- # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
67
- check_min_version("4.27.0.dev0")
68
-
69
- require_version("datasets>=1.18.2", "To fix: pip install -r examples/flax/speech-recogintion/requirements.txt")
70
-
71
- logger = logging.getLogger(__name__)
72
-
73
-
74
- @flax.struct.dataclass
75
- class ModelArguments:
76
- """
77
- Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
78
- """
79
-
80
- model_name_or_path: str = field(
81
- metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
82
- )
83
- config_name: Optional[str] = field(
84
- default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
85
- )
86
- tokenizer_name: Optional[str] = field(
87
- default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
88
- )
89
- feature_extractor_name: Optional[str] = field(
90
- default=None, metadata={"help": "feature extractor name or path if not the same as model_name"}
91
- )
92
- cache_dir: Optional[str] = field(
93
- default=None,
94
- metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
95
- )
96
- use_fast_tokenizer: bool = field(
97
- default=True,
98
- metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
99
- )
100
- model_revision: str = field(
101
- default="main",
102
- metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
103
- )
104
- use_auth_token: bool = field(
105
- default=False,
106
- metadata={
107
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
108
- "with private models)."
109
- },
110
- )
111
- dtype: Optional[str] = field(
112
- default="float32",
113
- metadata={
114
- "help": (
115
- "Floating-point format in which the model weights should be initialized and trained. Choose one of"
116
- " `[float32, float16, bfloat16]`."
117
- )
118
- },
119
- )
120
- num_beams: Optional[int] = field(
121
- default=None,
122
- metadata={
123
- "help": (
124
- "Number of beams to use for evaluation. This argument will be passed to `model.generate`, "
125
- "which is used during evaluation."
126
- )
127
- },
128
- )
129
-
130
-
131
- @flax.struct.dataclass
132
- class DataTrainingArguments:
133
- """
134
- Arguments pertaining to what data we are going to input our model for training and eval.
135
- """
136
-
137
- dataset_name: str = field(
138
- default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
139
- )
140
- dataset_config_name: Optional[str] = field(
141
- default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
142
- )
143
- text_column: Optional[str] = field(
144
- default=None,
145
- metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
146
- )
147
- dataset_cache_dir: Optional[str] = field(
148
- default=None, metadata={"help": "Path to cache directory for saving and loading datasets"}
149
- )
150
- overwrite_cache: bool = field(
151
- default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
152
- )
153
- preprocessing_num_workers: Optional[int] = field(
154
- default=None,
155
- metadata={"help": "The number of processes to use for the preprocessing."},
156
- )
157
- max_train_samples: Optional[int] = field(
158
- default=None,
159
- metadata={
160
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
161
- "value if set."
162
- },
163
- )
164
- max_eval_samples: Optional[int] = field(
165
- default=None,
166
- metadata={
167
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
168
- "value if set."
169
- },
170
- )
171
- audio_column_name: str = field(
172
- default="audio",
173
- metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
174
- )
175
- text_column_name: str = field(
176
- default="text",
177
- metadata={"help": "The name of the dataset column containing the text data. Defaults to 'text'"},
178
- )
179
- max_duration_in_seconds: float = field(
180
- default=20.0,
181
- metadata={"help": "Filter audio files that are longer than `max_duration_in_seconds` seconds"},
182
- )
183
- min_duration_in_seconds: float = field(
184
- default=0.0,
185
- metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"},
186
- )
187
- max_label_length: float = field(
188
- default=128,
189
- metadata={"help": "Truncate transcriptions that are longer `max_eval_length` tokens."},
190
- )
191
- pad_input_to_multiple_of: Optional[int] = field(
192
- default=None,
193
- metadata={
194
- "help": "If set will pad the input sequence to a multiple of the provided value. "
195
- "This is important to avoid triggering recompilations on TPU. If unspecified, will default to padding the inputs to max length."
196
- },
197
- )
198
- pad_target_to_multiple_of: Optional[int] = field(
199
- default=None,
200
- metadata={
201
- "help": "If set will pad the target sequence to a multiple of the provided value. "
202
- "This is important to avoid triggering recompilations on TPU. If unspecified, will default to padding the targets to max length."
203
- },
204
- )
205
- preprocessing_only: bool = field(
206
- default=False,
207
- metadata={
208
- "help": "Whether to only do data preprocessing and skip training. "
209
- "This is especially useful when data preprocessing errors out in distributed training due to timeout. "
210
- "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
211
- "so that the cached datasets can consequently be loaded in distributed training"
212
- },
213
- )
214
- train_split_name: str = field(
215
- default="train",
216
- metadata={
217
- "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
218
- },
219
- )
220
- eval_split_name: str = field(
221
- default="validation",
222
- metadata={
223
- "help": "The name of the evaluation data set split to use (via the datasets library). Defaults to 'validation'"
224
- },
225
- )
226
- do_lower_case: bool = field(
227
- default=True,
228
- metadata={"help": "Whether the target text should be lower cased."},
229
- )
230
- do_remove_punctuation: bool = field(
231
- default=False,
232
- metadata={"help": "Whether the target text should be striped of punctuation."},
233
- )
234
- do_normalize_eval: bool = field(
235
- default=True,
236
- metadata={"help": "Whether to normalise the references and predictions in the eval WER calculation."},
237
- )
238
- language: str = field(
239
- default=None,
240
- metadata={
241
- "help": (
242
- "Language for multilingual fine-tuning. This argument should be set for multilingual fine-tuning "
243
- "only. For English speech recognition, it should be set to `None`."
244
- )
245
- },
246
- )
247
- task: str = field(
248
- default="transcribe",
249
- metadata={"help": "Task, either `transcribe` for speech recognition or `translate` for speech translation."},
250
- )
251
- num_train_steps: int = field(default=50000, metadata={"help": "The number of training steps."})
252
- # num_eval_samples: int = field(default=50000, metadata={"help": "The number of samples to be used for evaluation"})
253
- shuffle_buffer_size: Optional[int] = field(
254
- default=500,
255
- metadata={
256
- "help": (
257
- "The number of streamed examples to download before shuffling them. The large the buffer, "
258
- "the closer it is to real offline shuffling."
259
- )
260
- },
261
- )
262
- streaming: bool = field(
263
- default=True,
264
- metadata={"help": "Whether to use streaming mode to load and pre-process the data."},
265
- )
266
-
267
- def shift_tokens_right(label_ids: np.array, decoder_start_token_id: int) -> np.ndarray:
268
- """
269
- Shift label ids one token to the right.
270
- """
271
- shifted_label_ids = np.zeros_like(label_ids)
272
- shifted_label_ids[:, 1:] = label_ids[:, :-1]
273
- shifted_label_ids[:, 0] = decoder_start_token_id
274
-
275
- return shifted_label_ids
276
-
277
-
278
- @flax.struct.dataclass
279
- class FlaxDataCollatorSpeechSeq2SeqWithPadding:
280
- """
281
- Data collator that will dynamically pad the inputs received.
282
- Args:
283
- processor ([`Wav2Vec2Processor`])
284
- The processor used for proccessing the data.
285
- decoder_start_token_id (:obj: `int`)
286
- The begin-of-sentence of the decoder.
287
- input_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
288
- Select a strategy to pad the returned input sequences (according to the model's padding side and padding index)
289
- among:
290
- * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
291
- sequence if provided).
292
- * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
293
- maximum acceptable input length for the model if that argument is not provided.
294
- * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
295
- different lengths).
296
- target_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
297
- Select a strategy to pad the returned target sequences (according to the model's padding side and padding index).
298
- See above for details.
299
- max_input_length (:obj:`float`, `optional`):
300
- Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
301
- max_target_length (:obj:`int`, `optional`):
302
- Maximum length of the ``labels`` of the returned list and optionally padding length (see above).
303
- pad_input_to_multiple_of (:obj:`int`, `optional`):
304
- If set will pad the input sequence to a multiple of the provided value.
305
- This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
306
- 7.5 (Volta).
307
- pad_target_to_multiple_of (:obj:`int`, `optional`):
308
- If set will pad the target sequence to a multiple of the provided value.
309
- This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
310
- 7.5 (Volta).
311
- """
312
-
313
- processor: Any
314
- decoder_start_token_id: int
315
- input_padding: Union[bool, str] = "longest"
316
- target_padding: Union[bool, str] = "max_length"
317
- max_input_length: Optional[float] = None
318
- max_target_length: Optional[int] = None
319
- pad_input_to_multiple_of: Optional[int] = None
320
- pad_target_to_multiple_of: Optional[int] = None
321
-
322
- def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> Dict[str, np.ndarray]:
323
- model_input_name = self.processor.model_input_names[0]
324
- input_features = {model_input_name: features[model_input_name]}
325
- label_features = {"input_ids": features["labels"]}
326
-
327
- # reformat list to dict and set to pytorch format
328
- batch = self.processor.feature_extractor.pad(
329
- input_features,
330
- max_length=self.max_input_length,
331
- padding=self.input_padding,
332
- pad_to_multiple_of=self.pad_input_to_multiple_of,
333
- return_tensors="np",
334
- )
335
-
336
- labels_batch = self.processor.tokenizer.pad(
337
- label_features,
338
- max_length=self.max_target_length,
339
- padding=self.target_padding,
340
- pad_to_multiple_of=self.pad_target_to_multiple_of,
341
- return_tensors="np",
342
- )
343
-
344
- # if bos token is appended in previous tokenization step,
345
- # cut bos token here as it's append later anyways
346
- labels = labels_batch["input_ids"]
347
- if (labels[:, 0] == self.decoder_start_token_id).all().item():
348
- labels = labels[:, 1:]
349
- labels_batch.attention_mask = labels_batch.attention_mask[:, 1:]
350
-
351
- decoder_input_ids = shift_tokens_right(labels, self.decoder_start_token_id)
352
-
353
- # replace padding with -100 to ignore correctly when computing the loss
354
- labels = np.ma.array(labels, mask=np.not_equal(labels_batch.attention_mask, 1))
355
- labels = labels.filled(fill_value=-100)
356
-
357
- batch["labels"] = labels
358
- batch["decoder_input_ids"] = decoder_input_ids
359
-
360
- return batch
361
-
362
-
363
- def load_maybe_streaming_dataset(dataset_name, dataset_config_name, split="train", streaming=True, **kwargs):
364
- """
365
- Utility function to load a dataset in streaming mode. For datasets with multiple splits,
366
- each split is loaded individually and then splits combined by taking alternating examples from
367
- each (interleaving).
368
- """
369
- if "+" in split:
370
- # load multiple splits separated by the `+` symbol with streaming mode
371
- dataset_splits = [
372
- load_dataset(dataset_name, dataset_config_name, split=split_name, streaming=streaming, **kwargs)
373
- for split_name in split.split("+")
374
- ]
375
- # interleave multiple splits to form one dataset
376
- interleaved_dataset = interleave_datasets(dataset_splits)
377
- return interleaved_dataset
378
- else:
379
- # load a single split *with* streaming mode
380
- dataset = load_dataset(dataset_name, dataset_config_name, split=split, streaming=streaming, **kwargs)
381
- return dataset
382
-
383
-
384
def data_loader(rng, dataset, batch_size: int, shuffle: bool = False, drop_last: bool = True):
    """
    Yield batches of size `batch_size` from `dataset`.

    Args:
        rng: a `jax.random.PRNGKey`; only consumed when `shuffle` is True.
        dataset: an indexable collection supporting `len()` and NumPy
            fancy-indexing (e.g. a `datasets.Dataset` or a NumPy array).
        batch_size: number of examples per batch.
        shuffle: if True, draw examples in a random permutation.
        drop_last: if True, drop the trailing incomplete batch; otherwise the
            final batch may contain between 1 and `batch_size` examples.

    Yields:
        Batches obtained by indexing `dataset` with an index array.
    """
    # Robustness fix: an empty dataset previously crashed with
    # `np.array_split(..., 0)` when drop_last=False; now it simply yields nothing.
    if len(dataset) == 0:
        return

    if shuffle:
        # Permute indices on-device, then bring them to host for NumPy slicing.
        batch_idx = jax.random.permutation(rng, len(dataset))
        batch_idx = np.asarray(batch_idx)
    else:
        batch_idx = np.arange(len(dataset))

    if drop_last:
        steps_per_epoch = len(dataset) // batch_size
        batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
        batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
    else:
        steps_per_epoch = math.ceil(len(dataset) / batch_size)
        batch_idx = np.array_split(batch_idx, steps_per_epoch)

    for idx in batch_idx:
        yield dataset[idx]
406
-
407
-
408
class TrainState(train_state.TrainState):
    """Flax train state extended with a dropout PRNG key threaded through training steps."""

    # PRNG key consumed (and re-split) each train step so dropout masks differ
    # across steps and, after `replicate()`, across devices.
    dropout_rng: jnp.ndarray

    def replicate(self):
        """Replicate the state to all local devices, giving each device its own dropout key."""
        return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
413
-
414
-
415
def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
    """Log accumulated train and eval metrics to TensorBoard.

    Per-step training series are back-filled so that the last value of each
    series lands exactly on `step`; evaluation metrics are all logged at `step`.
    """
    summary_writer.scalar("train_time", train_time, step)

    stacked = get_metrics(train_metrics)
    for name, series in stacked.items():
        # The series covers the last len(series) steps ending at `step`.
        first_step = step - len(series) + 1
        for offset, value in enumerate(series):
            summary_writer.scalar(f"train_{name}", value, first_step + offset)

    for name, value in eval_metrics.items():
        summary_writer.scalar(f"eval_{name}", value, step)
426
-
427
-
428
def create_learning_rate_fn(
    num_train_steps: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.array]:
    """Build a schedule that warms up linearly to `learning_rate`, then decays linearly to zero."""
    ramp_up = optax.linear_schedule(
        init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps
    )
    ramp_down = optax.linear_schedule(
        init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
    )
    # Switch from warmup to decay once `num_warmup_steps` steps have elapsed.
    return optax.join_schedules(schedules=[ramp_up, ramp_down], boundaries=[num_warmup_steps])
438
-
439
-
440
def main():
    """Fine-tune a Flax speech seq2seq model with optional dataset streaming.

    Parses CLI/JSON arguments, loads and preprocesses the (possibly streaming)
    dataset, builds the model/optimizer/train state, runs a step-based, pmapped
    training loop with periodic evaluation, and saves/pushes checkpoints.
    """
    # 1. Parse input arguments
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your JAX/Flax versions.
    send_example_telemetry("run_speech_recognition_seq2seq", model_args, data_args, framework="flax")

    # 2. Setup logging
    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    # Set the verbosity to info of the Transformers logger.
    # We only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    logger.info("Training/evaluation parameters %s", training_args)

    # Check the output dir is valid: refuse to overwrite existing results unless asked.
    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty."
            "Use `--overwrite_output_dir` to overcome."
        )

    # Handle the repository creation (clone/create the Hub repo into output_dir).
    if training_args.push_to_hub:
        if training_args.hub_model_id is None:
            repo_name = get_full_repo_name(
                Path(training_args.output_dir).absolute().name, token=training_args.hub_token
            )
        else:
            repo_name = training_args.hub_model_id
        create_repo(repo_name, exist_ok=True, token=training_args.hub_token)
        repo = Repository(training_args.output_dir, clone_from=repo_name, token=training_args.hub_token)

    # 3. Load dataset
    # Streaming mode yields IterableDataset splits; otherwise map-style datasets.
    raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()

    if training_args.do_train:
        raw_datasets["train"] = load_maybe_streaming_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            split=data_args.train_split_name,
            cache_dir=data_args.dataset_cache_dir,
            streaming=data_args.streaming,
            use_auth_token=True if model_args.use_auth_token else None,
        )

    if training_args.do_eval:
        raw_datasets["eval"] = load_maybe_streaming_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            split=data_args.eval_split_name,
            cache_dir=data_args.dataset_cache_dir,
            streaming=data_args.streaming,
            use_auth_token=True if model_args.use_auth_token else None,
        )

    if not training_args.do_train and not training_args.do_eval:
        raise ValueError(
            "Cannot not train and not do evaluation. At least one of training or evaluation has to be performed."
        )

    # Feature names are taken from whichever split was loaded first.
    raw_datasets_features = list(next(iter(raw_datasets.values())).features.keys())

    if data_args.audio_column_name not in raw_datasets_features:
        raise ValueError(
            f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
            "Make sure to set `--audio_column_name` to the correct audio column - one of "
            f"{', '.join(raw_datasets_features)}."
        )

    if data_args.text_column_name not in raw_datasets_features:
        raise ValueError(
            f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
            "Make sure to set `--text_column_name` to the correct text column - one of "
            f"{', '.join(raw_datasets_features)}."
        )

    # 5. Load pretrained model, tokenizer, and feature extractor
    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    feature_extractor = AutoFeatureExtractor.from_pretrained(
        model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )

    model = FlaxAutoModelForSpeechSeq2Seq.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        dtype=getattr(jnp, model_args.dtype),
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )

    if model.config.decoder_start_token_id is None:
        raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")

    # 6. Resample speech dataset: `datasets` takes care of automatically loading and resampling the audio,
    # so we just need to set the correct target sampling rate.
    dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate

    if dataset_sampling_rate != feature_extractor.sampling_rate:
        raw_datasets = raw_datasets.cast_column(
            data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
        )

    # 7. Preprocessing the datasets.
    # We need to read the audio files as arrays and tokenize the targets.
    # Copy frequently-used settings into locals so the closures below don't
    # capture the whole args objects.
    max_input_length = int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate)
    min_input_length = int(data_args.min_duration_in_seconds * feature_extractor.sampling_rate)
    max_label_length = (
        data_args.max_label_length if data_args.max_label_length is not None else model.config.max_length
    )
    pad_input_to_multiple_of = data_args.pad_input_to_multiple_of
    pad_target_to_multiple_of = data_args.pad_target_to_multiple_of
    audio_column_name = data_args.audio_column_name
    num_workers = data_args.preprocessing_num_workers
    text_column_name = data_args.text_column_name
    model_input_name = feature_extractor.model_input_names[0]
    do_lower_case = data_args.do_lower_case
    do_remove_punctuation = data_args.do_remove_punctuation
    normalizer = BasicTextNormalizer()  # 'official' text normalizer from OpenAI

    if data_args.language is not None:
        # We only need to set the task id when the language is specified (i.e. in a multilingual setting)
        tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task)

    def prepare_dataset(batch):
        # Map a raw example to model inputs + tokenized labels.
        # process audio
        sample = batch[audio_column_name]
        inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
        # process audio length
        batch[model_input_name] = inputs.get(model_input_name)[0]
        batch["input_length"] = len(sample["array"])

        # process targets
        input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
        if do_remove_punctuation:
            input_str = normalizer(input_str).strip()
        batch["labels"] = tokenizer(input_str).input_ids
        return batch

    with training_args.main_process_first(desc="dataset map pre-processing"):
        vectorized_datasets = raw_datasets.map(
            prepare_dataset,
            remove_columns=raw_datasets_features,
        ).with_format("torch")

    # Moving this to later. Better ways of doing this
    #if training_args.do_train and data_args.streaming:
    # manually shuffle if streaming (done by the trainer for non-streaming)
    #vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(
    #    buffer_size=data_args.shuffle_buffer_size,
    #    seed=training_args.seed,
    #)


    # filter training data with inputs longer than max_input_length
    def is_audio_in_length_range(length):
        # Keep only samples whose raw length (in audio frames) lies strictly
        # within the configured min/max duration window.
        return min_input_length < length < max_input_length

    # For debugging
    #def is_audio_in_length_range(length):
    #    if min_input_length < length < max_input_length:
    #        return True
    #    else:
    #        print(f"Warning: Input length {length} is not within the expected range [{min_input_length}, {max_input_length}].")
    #        return False


    if training_args.do_train:
        vectorized_datasets["train"] = vectorized_datasets["train"].filter(
            is_audio_in_length_range,
            input_columns=["input_length"],
        )

    # for large datasets it is advised to run the preprocessing on a
    # single machine first with `args.preprocessing_only` since there will mostly likely
    # be a timeout when running the script in distributed mode.
    # In a second step `args.preprocessing_only` can then be set to `False` to load the
    # cached dataset
    # Not really needed for streaming
    # if data_args.preprocessing_only:
    #     cache = {k: v.cache_files for k, v in vectorized_datasets.items()}
    #     logger.info(f"Data preprocessing finished. Files cached at {cache}.")
    #     return

    # 8. Load Metric
    metric = evaluate.load("wer")
    do_normalize_eval = data_args.do_normalize_eval

    def compute_metrics(pred_ids, label_ids):
        # Decode predictions/labels and compute word error rate (in percent).
        # replace padded labels by the padding token
        for idx in range(len(label_ids)):
            label_ids[idx][label_ids[idx] == -100] = tokenizer.pad_token_id
        #label_ids[label_ids == -100] = tokenizer.pad_token_id

        pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
        # we do not want to group tokens when computing the metrics
        label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

        if do_normalize_eval:
            pred_str = [normalizer(pred) for pred in pred_str]
            label_str = [normalizer(label) for label in label_str]
            # filtering step to only evaluate the samples that correspond to non-zero references:
            pred_str = [pred_str[i] for i in range(len(pred_str)) if len(label_str[i]) > 0]
            label_str = [label_str[i] for i in range(len(label_str)) if len(label_str[i]) > 0]

        wer = 100 * metric.compute(predictions=pred_str, references=label_str)

        return {"wer": wer}

    # 9. Save feature extractor, tokenizer and config
    feature_extractor.save_pretrained(training_args.output_dir)
    tokenizer.save_pretrained(training_args.output_dir)
    config.save_pretrained(training_args.output_dir)

    # The processor is re-loaded from the just-saved artifacts so it bundles
    # the exact feature extractor + tokenizer written above.
    processor = AutoProcessor.from_pretrained(training_args.output_dir)

    data_collator = FlaxDataCollatorSpeechSeq2SeqWithPadding(
        processor=processor,
        decoder_start_token_id=model.config.decoder_start_token_id,
        input_padding="longest",
        target_padding="longest",
        max_target_length=max_label_length,
        pad_input_to_multiple_of=pad_input_to_multiple_of,
        pad_target_to_multiple_of=pad_target_to_multiple_of if pad_target_to_multiple_of else max_label_length,
    )

    # Enable tensorboard only on the master node
    has_tensorboard = is_tensorboard_available()
    if has_tensorboard and jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
            )
    else:
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run pip install tensorboard to enable."
        )

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)
    rng, dropout_rng = jax.random.split(rng)

    # Store some constant
    #num_epochs = int(training_args.num_train_epochs)
    train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
    per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
    eval_batch_size = per_device_eval_batch_size * jax.device_count()


    # Create learning rate schedule
    # NOTE(review): the decay horizon is `num_train_steps*train_batch_size`
    # although the schedule is queried per optimizer step — looks like the
    # decay will effectively never finish; confirm this is intentional.
    linear_decay_lr_schedule_fn = create_learning_rate_fn(
        data_args.num_train_steps*train_batch_size,
        training_args.warmup_steps,
        training_args.learning_rate,
    )

    # We use Optax's "masking" functionality to not apply weight decay
    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
    # mask boolean with the same structure as the parameters.
    # The mask is True for parameters that should be decayed.
    def decay_mask_fn(params):
        flat_params = traverse_util.flatten_dict(params)
        # find out all LayerNorm parameters
        layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
        layer_norm_named_params = set(
            [
                layer[-2:]
                for layer_norm_name in layer_norm_candidates
                for layer in flat_params.keys()
                if layer_norm_name in "".join(layer).lower()
            ]
        )
        flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
        return traverse_util.unflatten_dict(flat_mask)

    # create adam optimizer
    adamw = optax.adamw(
        learning_rate=linear_decay_lr_schedule_fn,
        b1=training_args.adam_beta1,
        b2=training_args.adam_beta2,
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
        mask=decay_mask_fn,
    )

    # Setup train state
    state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)

    # label smoothed cross entropy
    def loss_fn(logits, labels, label_smoothing_factor=0.0):
        """
        The label smoothing implementation is adapted from Flax's official example:
        https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
        """
        vocab_size = logits.shape[-1]
        confidence = 1.0 - label_smoothing_factor
        low_confidence = (1.0 - confidence) / (vocab_size - 1)
        normalizing_constant = -(
            confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
        )
        soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)

        loss = optax.softmax_cross_entropy(logits, soft_labels)
        loss = loss - normalizing_constant

        # ignore padded tokens from loss, i.e. where labels are not set to -100
        padding_mask = labels >= 0
        loss = loss * padding_mask
        loss = loss.sum()
        num_labels = padding_mask.sum()
        return loss, num_labels

    # Define gradient update step fn
    def train_step(state, batch, label_smoothing_factor=0.0):
        # Split off a fresh dropout key for this step; the new key is stored
        # back in the returned state.
        dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)

        def compute_loss(params):
            labels = batch.pop("labels")
            logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
            loss, num_labels = loss_fn(logits, labels, label_smoothing_factor)
            return loss, num_labels

        grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
        (loss, num_labels), grad = grad_fn(state.params)
        num_labels = jax.lax.psum(num_labels, "batch")

        # true loss = total loss / total samples
        loss = jax.lax.psum(loss, "batch")
        loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)

        # true grad = total grad / total samples
        grad = jax.lax.psum(grad, "batch")
        grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
        new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)

        metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
        return new_state, metrics

    # Define eval fn
    def eval_step(params, batch, label_smoothing_factor=0.0):
        labels = batch.pop("labels")
        logits = model(**batch, params=params, train=False)[0]

        loss, num_labels = loss_fn(logits, labels, label_smoothing_factor)
        num_labels = jax.lax.psum(num_labels, "batch")

        # true loss = total loss / total samples
        loss = jax.lax.psum(loss, "batch")
        loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)

        metrics = {"loss": loss}
        return metrics

    # Define generation function
    num_beams = model_args.num_beams if model_args.num_beams is not None else model.config.num_beams
    gen_kwargs = {"max_length": max_label_length, "num_beams": num_beams}

    def generate_step(params, batch):
        # Mutates model.params so generate() picks up the current parameters.
        model.params = params
        output_ids = model.generate(batch[model_input_name], attention_mask=batch.get("attention_mask"), **gen_kwargs)
        return output_ids.sequences

    # Create parallel version of the train and eval step
    p_train_step = jax.pmap(
        partial(train_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch", donate_argnums=(0,)
    )
    p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch")
    p_generate_step = jax.pmap(generate_step, "batch")

    # Replicate the train state on each device
    state = state.replicate()

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {data_args.num_train_steps*train_batch_size}")
    logger.info(f"  Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel & distributed) = {train_batch_size}")
    logger.info(f"  Total optimization steps = {data_args.num_train_steps}")

    train_time = 0

    # ======================== Training ================================
    train_start = time.time()

    # Create sampling rng
    #rng, input_rng = jax.random.split(rng)
    train_metrics = []
    epoch = 0

    def collate_batch(samples):
        # Transpose a list of example dicts into a dict of lists, keyed by
        # the first example's keys.
        return {key: [feature[key] for feature in samples] for key in samples[0].keys()}

    # Create a batched data iterator
    num_workers = 0
    # This is not working
    # vectorized_datasets["train"] = vectorized_datasets["train"].shuffle()
    # NOTE(review): `_SingleProcessDataLoaderIter` is a private torch API used
    # here to get an explicitly restartable iterator — may break across
    # torch versions; `iter(train_data_loader)` looks equivalent, confirm.
    train_data_loader = torch.utils.data.DataLoader( batch_size=train_batch_size, dataset=vectorized_datasets["train"], num_workers=num_workers, collate_fn=collate_batch, drop_last=True)
    train_data_iterator = torch.utils.data.dataloader._SingleProcessDataLoaderIter(train_data_loader)


    # train
    for step in tqdm(range(data_args.num_train_steps), desc="Training...", position=1, leave=False):

        try:
            samples = next(train_data_iterator)

        except StopIteration:
            # Dataset exhausted: start a new epoch with a fresh iterator.
            epoch += 1
            train_data_loader = torch.utils.data.DataLoader( batch_size=train_batch_size, dataset=vectorized_datasets["train"], num_workers=num_workers, collate_fn=collate_batch, drop_last=True)
            train_data_iterator = torch.utils.data.dataloader._SingleProcessDataLoaderIter(train_data_loader)
            samples = next(train_data_iterator)

            # NOTE(review): `train_metric` is only assigned later in the loop
            # body — if the dataset is exhausted before the first train step
            # this raises NameError; verify this cannot happen in practice.
            logger.info(
                f"Completed epoch ({epoch} | Loss: {train_metric['loss']}, Learning Rate:"
                f" {train_metric['learning_rate']})"
            )

        # reshaped_samples = {key: [feature[key] for feature in samples] for key in samples[0].keys()}
        #breakpoint()
        batch = data_collator(samples)
        batch = shard(batch.data)
        state, train_metric = p_train_step(state, batch)

        # ======================== Evaluating ==============================
        if step % training_args.eval_steps == 0 and step > 0:
            eval_metrics = []
            eval_preds = []
            eval_labels = []

            #eval_loader = data_loader(input_rng, vectorized_datasets["eval"], eval_batch_size, drop_last=False)
            eval_data_loader = torch.utils.data.DataLoader( batch_size=eval_batch_size, dataset=vectorized_datasets["eval"], num_workers=num_workers, collate_fn=collate_batch, drop_last=False)
            eval_data_iterator = torch.utils.data.dataloader._SingleProcessDataLoaderIter(eval_data_iterator if False else eval_data_loader)

            # NOTE(review): evaluates exactly `eval_steps` batches; if the eval
            # split has fewer batches, the bare next() below will raise
            # StopIteration — confirm the eval set is always large enough.
            for _ in tqdm(range(training_args.eval_steps), desc="Evaluating...", position=2, leave=False):
                # Model forward
                samples = next(eval_data_iterator)
                batch = data_collator(samples)
                labels = batch["labels"]

                metrics = pad_shard_unpad(p_eval_step, static_return=True)(
                    state.params, batch.data, min_device_batch=per_device_eval_batch_size
                )
                eval_metrics.append(metrics)

                # generation
                if training_args.predict_with_generate:
                    generated_ids = pad_shard_unpad(p_generate_step)(state.params, batch.data)
                    eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
                    eval_labels.extend(labels)

            # normalize eval metrics
            eval_metrics = get_metrics(eval_metrics)
            eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)

            # compute WER metric
            wer_desc = ""
            if training_args.predict_with_generate:
                wer_metric = compute_metrics(eval_preds, eval_labels)
                eval_metrics.update(wer_metric)
                wer_desc = " ".join([f"Eval {key}: {value} |" for key, value in wer_metric.items()])

            # Print metrics
            desc = f"Epoch... ({epoch} | Eval Loss: {eval_metrics['loss']} | {wer_desc})"
            logger.info(desc)

            train_time += time.time() - train_start
            train_metric = unreplicate(train_metric)


            # Save metrics
            if has_tensorboard and jax.process_index() == 0:
                # NOTE(review): len() on a streaming IterableDataset raises —
                # this branch presumably assumes non-streaming mode; confirm.
                cur_step = epoch * (len(vectorized_datasets["train"]) // train_batch_size)
                write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)

            # save checkpoint after each epoch and push checkpoint to the hub
            if jax.process_index() == 0:
                params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
                model.save_pretrained(training_args.output_dir, params=params)
                tokenizer.save_pretrained(training_args.output_dir)
                if training_args.push_to_hub:
                    repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)
969
-
970
-
971
# Script entry point: only run training when executed directly, not on import.
if __name__ == "__main__":
    main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
run_flax_speech_recognition_seq2seq_streaming_v3.py DELETED
@@ -1,1011 +0,0 @@
1
- #!/usr/bin/env python
2
- # coding=utf-8
3
- # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
- """
17
- Fine-tuning the Flax library models for sequence to sequence speech recognition.
18
- """
19
- # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
20
-
21
- import itertools
22
- import logging
23
- import math
24
- import os
25
- import sys
26
- import time
27
- from dataclasses import field
28
- from functools import partial
29
- from pathlib import Path
30
- from typing import Any, Callable, Dict, Generator, List, Optional, Union
31
-
32
- import datasets
33
- import flax
34
- import jax
35
- import jax.numpy as jnp
36
- import numpy as np
37
- import optax
38
- import torch
39
- from datasets import Dataset, DatasetDict, IterableDatasetDict, interleave_datasets, load_dataset
40
- from torch.utils.data import IterableDataset
41
- from flax import jax_utils, traverse_util
42
- from flax.jax_utils import pad_shard_unpad, unreplicate
43
- from flax.training import train_state
44
- from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
45
- from huggingface_hub import Repository, create_repo
46
- from tqdm import tqdm
47
-
48
- import evaluate
49
- import transformers
50
- from transformers import (
51
- AutoConfig,
52
- AutoFeatureExtractor,
53
- AutoProcessor,
54
- AutoTokenizer,
55
- FlaxAutoModelForSpeechSeq2Seq,
56
- HfArgumentParser,
57
- Seq2SeqTrainingArguments,
58
- is_tensorboard_available,
59
- )
60
-
61
- from transformers.models.whisper.english_normalizer import BasicTextNormalizer
62
- from transformers.file_utils import get_full_repo_name
63
- from transformers.utils import check_min_version, send_example_telemetry
64
- from transformers.utils.versions import require_version
65
-
66
-
67
- # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
68
- check_min_version("4.27.0.dev0")
69
-
70
- require_version("datasets>=1.18.2",
71
- "To fix: pip install -r examples/flax/speech-recogintion/requirements.txt")
72
-
73
- logger = logging.getLogger(__name__)
74
-
75
-
76
- @flax.struct.dataclass
77
- class ModelArguments:
78
- """
79
- Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
80
- """
81
-
82
- model_name_or_path: str = field(
83
- metadata={
84
- "help": "Path to pretrained model or model identifier from huggingface.co/models"}
85
- )
86
- config_name: Optional[str] = field(
87
- default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
88
- )
89
- tokenizer_name: Optional[str] = field(
90
- default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
91
- )
92
- feature_extractor_name: Optional[str] = field(
93
- default=None, metadata={"help": "feature extractor name or path if not the same as model_name"}
94
- )
95
- cache_dir: Optional[str] = field(
96
- default=None,
97
- metadata={
98
- "help": "Where to store the pretrained models downloaded from huggingface.co"},
99
- )
100
- use_fast_tokenizer: bool = field(
101
- default=True,
102
- metadata={
103
- "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
104
- )
105
- model_revision: str = field(
106
- default="main",
107
- metadata={
108
- "help": "The specific model version to use (can be a branch name, tag name or commit id)."},
109
- )
110
- use_auth_token: bool = field(
111
- default=False,
112
- metadata={
113
- "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
114
- "with private models)."
115
- },
116
- )
117
- dtype: Optional[str] = field(
118
- default="float32",
119
- metadata={
120
- "help": (
121
- "Floating-point format in which the model weights should be initialized and trained. Choose one of"
122
- " `[float32, float16, bfloat16]`."
123
- )
124
- },
125
- )
126
- num_beams: Optional[int] = field(
127
- default=None,
128
- metadata={
129
- "help": (
130
- "Number of beams to use for evaluation. This argument will be passed to `model.generate`, "
131
- "which is used during evaluation."
132
- )
133
- },
134
- )
135
-
136
-
137
- @flax.struct.dataclass
138
- class DataTrainingArguments:
139
- """
140
- Arguments pertaining to what data we are going to input our model for training and eval.
141
- """
142
-
143
- dataset_name: str = field(
144
- default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
145
- )
146
- dataset_config_name: Optional[str] = field(
147
- default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
148
- )
149
- text_column: Optional[str] = field(
150
- default=None,
151
- metadata={
152
- "help": "The name of the column in the datasets containing the full texts (for summarization)."},
153
- )
154
- dataset_cache_dir: Optional[str] = field(
155
- default=None, metadata={"help": "Path to cache directory for saving and loading datasets"}
156
- )
157
- overwrite_cache: bool = field(
158
- default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
159
- )
160
- preprocessing_num_workers: Optional[int] = field(
161
- default=None,
162
- metadata={"help": "The number of processes to use for the preprocessing."},
163
- )
164
- max_train_samples: Optional[int] = field(
165
- default=None,
166
- metadata={
167
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
168
- "value if set."
169
- },
170
- )
171
- max_eval_samples: Optional[int] = field(
172
- default=None,
173
- metadata={
174
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
175
- "value if set."
176
- },
177
- )
178
- audio_column_name: str = field(
179
- default="audio",
180
- metadata={
181
- "help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
182
- )
183
- text_column_name: str = field(
184
- default="text",
185
- metadata={
186
- "help": "The name of the dataset column containing the text data. Defaults to 'text'"},
187
- )
188
- max_duration_in_seconds: float = field(
189
- default=30.0,
190
- metadata={
191
- "help": "Filter audio files that are longer than `max_duration_in_seconds` seconds"},
192
- )
193
- min_duration_in_seconds: float = field(
194
- default=0.0,
195
- metadata={
196
- "help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"},
197
- )
198
- max_label_length: float = field(
199
- default=128,
200
- metadata={
201
- "help": "Truncate transcriptions that are longer `max_eval_length` tokens."},
202
- )
203
- pad_input_to_multiple_of: Optional[int] = field(
204
- default=None,
205
- metadata={
206
- "help": "If set will pad the input sequence to a multiple of the provided value. "
207
- "This is important to avoid triggering recompilations on TPU. If unspecified, will default to padding the inputs to max length."
208
- },
209
- )
210
- pad_target_to_multiple_of: Optional[int] = field(
211
- default=None,
212
- metadata={
213
- "help": "If set will pad the target sequence to a multiple of the provided value. "
214
- "This is important to avoid triggering recompilations on TPU. If unspecified, will default to padding the targets to max length."
215
- },
216
- )
217
- preprocessing_only: bool = field(
218
- default=False,
219
- metadata={
220
- "help": "Whether to only do data preprocessing and skip training. "
221
- "This is especially useful when data preprocessing errors out in distributed training due to timeout. "
222
- "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
223
- "so that the cached datasets can consequently be loaded in distributed training"
224
- },
225
- )
226
- train_split_name: str = field(
227
- default="train",
228
- metadata={
229
- "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
230
- },
231
- )
232
- eval_split_name: str = field(
233
- default="validation",
234
- metadata={
235
- "help": "The name of the evaluation data set split to use (via the datasets library). Defaults to 'validation'"
236
- },
237
- )
238
- do_lower_case: bool = field(
239
- default=True,
240
- metadata={"help": "Whether the target text should be lower cased."},
241
- )
242
- do_remove_punctuation: bool = field(
243
- default=False,
244
- metadata={
245
- "help": "Whether the target text should be striped of punctuation."},
246
- )
247
- do_normalize_eval: bool = field(
248
- default=True,
249
- metadata={
250
- "help": "Whether to normalise the references and predictions in the eval WER calculation."},
251
- )
252
- language: str = field(
253
- default=None,
254
- metadata={
255
- "help": (
256
- "Language for multilingual fine-tuning. This argument should be set for multilingual fine-tuning "
257
- "only. For English speech recognition, it should be set to `None`."
258
- )
259
- },
260
- )
261
- task: str = field(
262
- default="transcribe",
263
- metadata={
264
- "help": "Task, either `transcribe` for speech recognition or `translate` for speech translation."},
265
- )
266
- num_train_steps: int = field(default=50000, metadata={
267
- "help": "The number of training steps."})
268
- shuffle_buffer_size: Optional[int] = field(
269
- default=500,
270
- metadata={
271
- "help": (
272
- "The number of streamed examples to download before shuffling them. The large the buffer, "
273
- "the closer it is to real offline shuffling."
274
- )
275
- },
276
- )
277
- streaming: bool = field(
278
- default=True,
279
- metadata={
280
- "help": "Whether to use streaming mode to load and pre-process the data."},
281
- )
282
-
283
-
284
- def shift_tokens_right(label_ids: np.array, decoder_start_token_id: int) -> np.ndarray:
285
- """
286
- Shift label ids one token to the right.
287
- """
288
- shifted_label_ids = np.zeros_like(label_ids)
289
- shifted_label_ids[:, 1:] = label_ids[:, :-1]
290
- shifted_label_ids[:, 0] = decoder_start_token_id
291
-
292
- return shifted_label_ids
293
-
294
-
295
- @flax.struct.dataclass
296
- class FlaxDataCollatorSpeechSeq2SeqWithPadding:
297
- """
298
- Data collator that will dynamically pad the inputs received.
299
- Args:
300
- processor ([`Wav2Vec2Processor`])
301
- The processor used for proccessing the data.
302
- decoder_start_token_id (:obj: `int`)
303
- The begin-of-sentence of the decoder.
304
- input_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
305
- Select a strategy to pad the returned input sequences (according to the model's padding side and padding index)
306
- among:
307
- * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
308
- sequence if provided).
309
- * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
310
- maximum acceptable input length for the model if that argument is not provided.
311
- * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
312
- different lengths).
313
- target_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
314
- Select a strategy to pad the returned target sequences (according to the model's padding side and padding index).
315
- See above for details.
316
- max_input_length (:obj:`float`, `optional`):
317
- Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
318
- max_target_length (:obj:`int`, `optional`):
319
- Maximum length of the ``labels`` of the returned list and optionally padding length (see above).
320
- pad_input_to_multiple_of (:obj:`int`, `optional`):
321
- If set will pad the input sequence to a multiple of the provided value.
322
- This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
323
- 7.5 (Volta).
324
- pad_target_to_multiple_of (:obj:`int`, `optional`):
325
- If set will pad the target sequence to a multiple of the provided value.
326
- This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
327
- 7.5 (Volta).
328
- """
329
-
330
- processor: Any
331
- decoder_start_token_id: int
332
- input_padding: Union[bool, str] = "longest"
333
- target_padding: Union[bool, str] = "max_length"
334
- max_input_length: Optional[float] = None
335
- max_target_length: Optional[int] = None
336
- pad_input_to_multiple_of: Optional[int] = None
337
- pad_target_to_multiple_of: Optional[int] = None
338
-
339
- def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> Dict[str, np.ndarray]:
340
- model_input_name = self.processor.model_input_names[0]
341
- input_features = {model_input_name: features[model_input_name]}
342
- label_features = {"input_ids": features["labels"]}
343
-
344
- # reformat list to dict and set to pytorch format
345
- batch = self.processor.feature_extractor.pad(
346
- input_features,
347
- max_length=self.max_input_length,
348
- padding=self.input_padding,
349
- pad_to_multiple_of=self.pad_input_to_multiple_of,
350
- return_tensors="np",
351
- )
352
-
353
- labels_batch = self.processor.tokenizer.pad(
354
- label_features,
355
- max_length=self.max_target_length,
356
- padding=self.target_padding,
357
- pad_to_multiple_of=self.pad_target_to_multiple_of,
358
- return_tensors="np",
359
- )
360
-
361
- # if bos token is appended in previous tokenization step,
362
- # cut bos token here as it's append later anyways
363
- labels = labels_batch["input_ids"]
364
- if (labels[:, 0] == self.decoder_start_token_id).all().item():
365
- labels = labels[:, 1:]
366
- labels_batch.attention_mask = labels_batch.attention_mask[:, 1:]
367
-
368
- decoder_input_ids = shift_tokens_right(
369
- labels, self.decoder_start_token_id)
370
-
371
- # replace padding with -100 to ignore correctly when computing the loss
372
- labels = np.ma.array(labels, mask=np.not_equal(
373
- labels_batch.attention_mask, 1))
374
- labels = labels.filled(fill_value=-100)
375
-
376
- batch["labels"] = labels
377
- batch["decoder_input_ids"] = decoder_input_ids
378
-
379
- return batch
380
-
381
-
382
- def load_maybe_streaming_dataset(dataset_name, dataset_config_name, split="train", streaming=True, **kwargs):
383
- """
384
- Utility function to load a dataset in streaming mode. For datasets with multiple splits,
385
- each split is loaded individually and then splits combined by taking alternating examples from
386
- each (interleaving).
387
- """
388
- if "+" in split:
389
- # load multiple splits separated by the `+` symbol with streaming mode
390
- dataset_splits = [
391
- load_dataset(dataset_name, dataset_config_name,
392
- split=split_name, streaming=streaming, **kwargs)
393
- for split_name in split.split("+")
394
- ]
395
- # interleave multiple splits to form one dataset
396
- interleaved_dataset = interleave_datasets(dataset_splits)
397
- return interleaved_dataset
398
- else:
399
- # load a single split *with* streaming mode
400
- dataset = load_dataset(
401
- dataset_name, dataset_config_name, split=split, streaming=streaming, **kwargs)
402
- return dataset
403
-
404
-
405
- def collate_batch(samples):
406
- return {key: [feature[key] for feature in samples] for key in samples[0]}
407
-
408
-
409
- def data_loader(
410
- dataset: Dataset,
411
- batch_size: int,
412
- drop_last: bool=True,
413
- num_workers: int=0,
414
- ) -> Generator:
415
- """
416
- Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete,
417
- and range in size from 1 to `batch_size`. Shuffle batches if `shuffle` is `True`.
418
- """
419
- data_loader_iterator = iter(torch.utils.data.DataLoader(
420
- batch_size=batch_size,
421
- dataset=dataset.with_format("torch"),
422
- num_workers=num_workers,
423
- collate_fn=collate_batch,
424
- drop_last=drop_last,
425
- ))
426
- return data_loader_iterator
427
-
428
-
429
- class TrainState(train_state.TrainState):
430
- dropout_rng: jnp.ndarray
431
-
432
- def replicate(self):
433
- return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
434
-
435
-
436
- def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
437
- summary_writer.scalar("train_time", train_time, step)
438
-
439
- train_metrics = get_metrics(train_metrics)
440
- for key, vals in train_metrics.items():
441
- tag = f"train_{key}"
442
- for i, val in enumerate(vals):
443
- summary_writer.scalar(tag, val, step - len(vals) + i + 1)
444
-
445
- for metric_name, value in eval_metrics.items():
446
- summary_writer.scalar(f"eval_{metric_name}", value, step)
447
-
448
-
449
- def create_learning_rate_fn(
450
- num_train_steps: int, num_warmup_steps: int, learning_rate: float, warmup_init_value: float=0.0, decay_end_value: float=0.0,
451
- ) -> Callable[[int], jnp.array]:
452
- """Returns a linear warmup, linear_decay learning rate function."""
453
- warmup_fn = optax.linear_schedule(
454
- init_value=warmup_init_value, end_value=learning_rate, transition_steps=num_warmup_steps)
455
- decay_fn = optax.linear_schedule(
456
- init_value=learning_rate, end_value=decay_end_value, transition_steps=num_train_steps - num_warmup_steps
457
- )
458
- schedule_fn = optax.join_schedules(
459
- schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
460
- return schedule_fn
461
-
462
-
463
- def main():
464
- # 1. Parse input arguments
465
- # See all possible arguments in src/transformers/training_args.py
466
- # or by passing the --help flag to this script.
467
- # We now keep distinct sets of args, for a cleaner separation of concerns.
468
- parser = HfArgumentParser(
469
- (ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
470
-
471
- if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
472
- # If we pass only one argument to the script and it's the path to a json file,
473
- # let's parse it to get our arguments.
474
- model_args, data_args, training_args = parser.parse_json_file(
475
- json_file=os.path.abspath(sys.argv[1]))
476
- else:
477
- model_args, data_args, training_args = parser.parse_args_into_dataclasses()
478
-
479
- # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
480
- # information sent is the one passed as arguments along with your JAX/Flax versions.
481
- send_example_telemetry("run_speech_recognition_seq2seq",
482
- model_args, data_args, framework="flax")
483
-
484
- # 2. Setup logging
485
- # Make one log on every process with the configuration for debugging.
486
- logging.basicConfig(
487
- format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
488
- datefmt="%m/%d/%Y %H:%M:%S",
489
- handlers=[logging.StreamHandler(sys.stdout)],
490
- )
491
- # Set the verbosity to info of the Transformers logger.
492
- # We only want one process per machine to log things on the screen.
493
- logger.setLevel(logging.INFO if jax.process_index()
494
- == 0 else logging.ERROR)
495
- if jax.process_index() == 0:
496
- datasets.utils.logging.set_verbosity_warning()
497
- transformers.utils.logging.set_verbosity_info()
498
- else:
499
- datasets.utils.logging.set_verbosity_error()
500
- transformers.utils.logging.set_verbosity_error()
501
-
502
- logger.info("Training/evaluation parameters %s", training_args)
503
-
504
- # Check the output dir is valid
505
- if (
506
- os.path.exists(training_args.output_dir)
507
- and os.listdir(training_args.output_dir)
508
- and training_args.do_train
509
- and not training_args.overwrite_output_dir
510
- ):
511
- raise ValueError(
512
- f"Output directory ({training_args.output_dir}) already exists and is not empty."
513
- "Use `--overwrite_output_dir` to overcome."
514
- )
515
-
516
- # Handle the repository creation
517
- if training_args.push_to_hub:
518
- if training_args.hub_model_id is None:
519
- repo_name = get_full_repo_name(
520
- Path(training_args.output_dir).absolute(
521
- ).name, token=training_args.hub_token
522
- )
523
- else:
524
- repo_name = training_args.hub_model_id
525
- create_repo(repo_name, exist_ok=True, token=training_args.hub_token)
526
- repo = Repository(training_args.output_dir,
527
- clone_from=repo_name, token=training_args.hub_token)
528
-
529
- # 3. Load dataset
530
- raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
531
-
532
- if training_args.do_train:
533
- raw_datasets["train"] = load_maybe_streaming_dataset(
534
- data_args.dataset_name,
535
- data_args.dataset_config_name,
536
- split=data_args.train_split_name,
537
- cache_dir=data_args.dataset_cache_dir,
538
- streaming=data_args.streaming,
539
- use_auth_token=True if model_args.use_auth_token else None,
540
- )
541
-
542
- if training_args.do_eval:
543
- raw_datasets["eval"] = load_maybe_streaming_dataset(
544
- data_args.dataset_name,
545
- data_args.dataset_config_name,
546
- split=data_args.eval_split_name,
547
- cache_dir=data_args.dataset_cache_dir,
548
- streaming=data_args.streaming,
549
- use_auth_token=True if model_args.use_auth_token else None,
550
- )
551
-
552
- if not training_args.do_train and not training_args.do_eval:
553
- raise ValueError(
554
- "Cannot not train and not do evaluation. At least one of training or evaluation has to be performed."
555
- )
556
-
557
- raw_datasets_features = list(
558
- next(iter(raw_datasets.values())).features.keys())
559
-
560
- if data_args.audio_column_name not in raw_datasets_features:
561
- raise ValueError(
562
- f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
563
- "Make sure to set `--audio_column_name` to the correct audio column - one of "
564
- f"{', '.join(raw_datasets_features)}."
565
- )
566
-
567
- if data_args.text_column_name not in raw_datasets_features:
568
- raise ValueError(
569
- f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
570
- "Make sure to set `--text_column_name` to the correct text column - one of "
571
- f"{', '.join(raw_datasets_features)}."
572
- )
573
-
574
- # 5. Load pretrained model, tokenizer, and feature extractor
575
- config = AutoConfig.from_pretrained(
576
- model_args.config_name if model_args.config_name else model_args.model_name_or_path,
577
- cache_dir=model_args.cache_dir,
578
- revision=model_args.model_revision,
579
- use_auth_token=True if model_args.use_auth_token else None,
580
- )
581
- feature_extractor = AutoFeatureExtractor.from_pretrained(
582
- model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
583
- cache_dir=model_args.cache_dir,
584
- revision=model_args.model_revision,
585
- use_auth_token=True if model_args.use_auth_token else None,
586
- )
587
- tokenizer = AutoTokenizer.from_pretrained(
588
- model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
589
- cache_dir=model_args.cache_dir,
590
- use_fast=model_args.use_fast_tokenizer,
591
- revision=model_args.model_revision,
592
- use_auth_token=True if model_args.use_auth_token else None,
593
- )
594
-
595
- model = FlaxAutoModelForSpeechSeq2Seq.from_pretrained(
596
- model_args.model_name_or_path,
597
- config=config,
598
- dtype=getattr(jnp, model_args.dtype),
599
- cache_dir=model_args.cache_dir,
600
- revision=model_args.model_revision,
601
- use_auth_token=True if model_args.use_auth_token else None,
602
- )
603
-
604
- if model.config.decoder_start_token_id is None:
605
- raise ValueError(
606
- "Make sure that `config.decoder_start_token_id` is correctly defined")
607
-
608
- # 6. Resample speech dataset: `datasets` takes care of automatically loading and resampling the audio,
609
- # so we just need to set the correct target sampling rate.
610
- dataset_sampling_rate = next(
611
- iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
612
-
613
- if dataset_sampling_rate != feature_extractor.sampling_rate:
614
- raw_datasets = raw_datasets.cast_column(
615
- data_args.audio_column_name, datasets.features.Audio(
616
- sampling_rate=feature_extractor.sampling_rate)
617
- )
618
-
619
- # 7. Preprocessing the datasets.
620
- # We need to read the audio files as arrays and tokenize the targets.
621
- max_input_length = int(
622
- data_args.max_duration_in_seconds * feature_extractor.sampling_rate)
623
- min_input_length = int(
624
- data_args.min_duration_in_seconds * feature_extractor.sampling_rate)
625
- max_label_length = (
626
- data_args.max_label_length if data_args.max_label_length is not None else model.config.max_length
627
- )
628
- pad_input_to_multiple_of = data_args.pad_input_to_multiple_of
629
- pad_target_to_multiple_of = data_args.pad_target_to_multiple_of
630
- audio_column_name = data_args.audio_column_name
631
- num_workers = data_args.preprocessing_num_workers
632
- text_column_name = data_args.text_column_name
633
- model_input_name = feature_extractor.model_input_names[0]
634
- do_lower_case = data_args.do_lower_case
635
- do_remove_punctuation = data_args.do_remove_punctuation
636
- normalizer = BasicTextNormalizer() # 'official' text normalizer from OpenAI
637
-
638
- if data_args.language is not None:
639
- # We only need to set the task id when the language is specified (i.e. in a multilingual setting)
640
- tokenizer.set_prefix_tokens(
641
- language=data_args.language, task=data_args.task)
642
-
643
- def prepare_dataset(batch):
644
- # process audio
645
- sample = batch[audio_column_name]
646
- inputs = feature_extractor(
647
- sample["array"], sampling_rate=sample["sampling_rate"])
648
- # process audio length
649
- batch[model_input_name] = inputs.get(model_input_name)[0]
650
- batch["input_length"] = len(sample["array"])
651
-
652
- # process targets
653
- input_str = batch[text_column_name].lower(
654
- ) if do_lower_case else batch[text_column_name]
655
- if do_remove_punctuation:
656
- input_str = normalizer(input_str).strip()
657
- batch["labels"] = tokenizer(input_str).input_ids
658
- return batch
659
-
660
- with training_args.main_process_first(desc="dataset map pre-processing"):
661
- vectorized_datasets = raw_datasets.map(
662
- prepare_dataset,
663
- remove_columns=raw_datasets_features,
664
- )
665
-
666
- # filter training data with inputs longer than max_input_length
667
- def is_audio_in_length_range(length):
668
- return min_input_length < length < max_input_length
669
-
670
- if training_args.do_train:
671
- vectorized_datasets["train"] = vectorized_datasets["train"].filter(
672
- is_audio_in_length_range,
673
- input_columns=["input_length"],
674
- )
675
-
676
- if training_args.do_eval:
677
- vectorized_datasets["eval"] = vectorized_datasets["eval"].filter(
678
- is_audio_in_length_range,
679
- input_columns=["input_length"],
680
- )
681
-
682
- # 8. Load Metric
683
- metric_wer = evaluate.load("wer")
684
- metric_cer = evaluate.load("cer")
685
- do_normalize_eval = data_args.do_normalize_eval
686
-
687
- def compute_metrics(pred_ids, label_ids):
688
- # replace padded labels by the padding token
689
- for idx in range(len(label_ids)):
690
- label_ids[idx][label_ids[idx] == -100] = tokenizer.pad_token_id
691
-
692
- pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
693
- # we do not want to group tokens when computing the metrics
694
- label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
695
-
696
- if do_normalize_eval:
697
- pred_str = [normalizer(pred) for pred in pred_str]
698
- label_str = [normalizer(label) for label in label_str]
699
- # filtering step to only evaluate the samples that correspond to non-zero references:
700
- pred_str = [pred_str[i]
701
- for i in range(len(pred_str)) if len(label_str[i]) > 0]
702
- label_str = [label_str[i]
703
- for i in range(len(label_str)) if len(label_str[i]) > 0]
704
-
705
- wer = 100 * metric_wer.compute(predictions=pred_str, references=label_str)
706
- cer = 100 * metric_cer.compute(predictions=pred_str, references=label_str)
707
-
708
- return {"wer": wer, "cer": cer}
709
-
710
- # 9. Save feature extractor, tokenizer and config
711
- feature_extractor.save_pretrained(training_args.output_dir)
712
- tokenizer.save_pretrained(training_args.output_dir)
713
- config.save_pretrained(training_args.output_dir)
714
-
715
- processor = AutoProcessor.from_pretrained(training_args.output_dir)
716
-
717
- data_collator = FlaxDataCollatorSpeechSeq2SeqWithPadding(
718
- processor=processor,
719
- decoder_start_token_id=model.config.decoder_start_token_id,
720
- input_padding="longest",
721
- target_padding="longest",
722
- max_target_length=max_label_length,
723
- pad_input_to_multiple_of=pad_input_to_multiple_of,
724
- pad_target_to_multiple_of=pad_target_to_multiple_of if pad_target_to_multiple_of else max_label_length,
725
- )
726
-
727
- # Enable tensorboard only on the master node
728
- has_tensorboard = is_tensorboard_available()
729
- if has_tensorboard and jax.process_index() == 0:
730
- try:
731
- from flax.metrics.tensorboard import SummaryWriter
732
-
733
- summary_writer = SummaryWriter(
734
- log_dir=Path(training_args.output_dir))
735
- except ImportError as ie:
736
- has_tensorboard = False
737
- logger.warning(
738
- f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
739
- )
740
- else:
741
- logger.warning(
742
- "Unable to display metrics through TensorBoard because the package is not installed: "
743
- "Please run pip install tensorboard to enable."
744
- )
745
-
746
- # Initialize our training
747
- rng = jax.random.PRNGKey(training_args.seed)
748
- rng, dropout_rng = jax.random.split(rng)
749
- # rng, input_rng = jax.random.split(rng)
750
-
751
- # Store some constant
752
- #num_epochs = int(training_args.num_train_epochs)
753
- train_batch_size = int(
754
- training_args.per_device_train_batch_size) * jax.device_count()
755
- eval_batch_size = int(
756
- training_args.per_device_eval_batch_size) * jax.device_count()
757
-
758
- # Create learning rate schedule
759
- lr_scheduler_types = {"linear", "constant", "constant_with_warmup"}
760
- if training_args.lr_scheduler_type not in lr_scheduler_types:
761
- raise ValueError(
762
- f"lr_scheduler_type of type {training_args.lr_scheduler_type} not supported, choose from {lr_scheduler_types}."
763
- )
764
- elif training_args.lr_scheduler_type == "constant":
765
- warmup_init_value = training_args.learning_rate
766
- decay_end_value = training_args.learning_rate
767
- elif training_args.lr_scheduler_type == "constant_with_warmup":
768
- warmup_init_value = 0.0
769
- decay_end_value = training_args.learning_rate
770
- else:
771
- warmup_init_value = 0.0
772
- decay_end_value = 0.0
773
- linear_decay_lr_schedule_fn = create_learning_rate_fn(
774
- data_args.num_train_steps * train_batch_size,
775
- training_args.warmup_steps,
776
- training_args.learning_rate,
777
- warmup_init_value=warmup_init_value,
778
- decay_end_value=decay_end_value,
779
- )
780
-
781
# We use Optax's "masking" functionality to not apply weight decay to bias and
# LayerNorm scale parameters.  decay_mask_fn returns a boolean pytree with the
# same structure as the parameters; True marks parameters that SHOULD decay.
def decay_mask_fn(params):
    """Build a pytree mask selecting the parameters eligible for weight decay."""
    flat_params = traverse_util.flatten_dict(params)
    # Collect the trailing (module, leaf) path pairs belonging to any
    # LayerNorm-like module.
    layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
    layer_norm_named_params = {
        path[-2:]
        for candidate in layer_norm_candidates
        for path in flat_params.keys()
        if candidate in "".join(path).lower()
    }
    flat_mask = {}
    for path in flat_params:
        is_bias = path[-1] == "bias"
        is_layer_norm = path[-2:] in layer_norm_named_params
        # Decay everything that is neither a bias nor a LayerNorm parameter.
        flat_mask[path] = not (is_bias or is_layer_norm)
    return traverse_util.unflatten_dict(flat_mask)
800
-
801
# Create the AdamW optimizer: the schedule above drives the learning rate and
# decay_mask_fn keeps weight decay away from bias/LayerNorm parameters.
adamw = optax.adamw(
    learning_rate=linear_decay_lr_schedule_fn,
    b1=training_args.adam_beta1,
    b2=training_args.adam_beta2,
    eps=training_args.adam_epsilon,
    weight_decay=training_args.weight_decay,
    mask=decay_mask_fn,
)

# Bundle the model apply fn, parameters, optimizer and dropout RNG into a
# single train state.
state = TrainState.create(
    apply_fn=model.__call__,
    params=model.params,
    tx=adamw,
    dropout_rng=dropout_rng,
)
814
-
815
# label smoothed cross entropy
def loss_fn(logits, labels, label_smoothing_factor=0.0):
    """
    The label smoothing implementation is adapted from Flax's official example:
    https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
    """
    vocab_size = logits.shape[-1]
    confidence = 1.0 - label_smoothing_factor
    low_confidence = (1.0 - confidence) / (vocab_size - 1)
    # Subtracting this constant makes the minimum achievable loss zero.
    normalizing_constant = -(
        confidence * jnp.log(confidence)
        + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
    )
    soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)

    loss = optax.softmax_cross_entropy(logits, soft_labels) - normalizing_constant

    # Ignore padded tokens: positions with labels < 0 (e.g. -100) are masked out.
    padding_mask = labels >= 0
    total_loss = (loss * padding_mask).sum()
    num_labels = padding_mask.sum()
    return total_loss, num_labels
840
-
841
# Define gradient update step fn
def train_step(state, batch, label_smoothing_factor=0.0):
    """Run one optimizer step on a sharded batch; returns (new_state, metrics)."""
    dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)

    def compute_loss(params):
        labels = batch.pop("labels")
        logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
        return loss_fn(logits, labels, label_smoothing_factor)

    grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
    (loss, num_labels), grad = grad_fn(state.params)
    num_labels = jax.lax.psum(num_labels, "batch")

    # True loss/grad = device-summed totals divided by the global label count.
    loss = jax.lax.psum(loss, "batch")
    loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
    grad = jax.lax.psum(grad, "batch")
    grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)

    new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
    metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
    return new_state, metrics
869
-
870
# Define eval fn
def eval_step(params, batch, label_smoothing_factor=0.0):
    """Compute the (label-smoothed) evaluation loss for one batch."""
    labels = batch.pop("labels")
    logits = model(**batch, params=params, train=False)[0]

    loss, num_labels = loss_fn(logits, labels, label_smoothing_factor)
    num_labels = jax.lax.psum(num_labels, "batch")

    # True loss = device-summed total loss divided by the global label count.
    summed_loss = jax.lax.psum(loss, "batch")
    mean_loss = jax.tree_util.tree_map(lambda x: x / num_labels, summed_loss)
    return {"loss": mean_loss}
884
-
885
# Define generation function
num_beams = model_args.num_beams if model_args.num_beams is not None else model.config.num_beams
gen_kwargs = {"max_length": max_label_length, "num_beams": num_beams}

def generate_step(params, batch):
    """Autoregressively generate token ids for one batch with the given params."""
    model.params = params
    attention_mask = batch.get("attention_mask")
    output_ids = model.generate(batch[model_input_name], attention_mask=attention_mask, **gen_kwargs)
    return output_ids.sequences
894
-
895
# Create parallel version of the train and eval step
p_train_step = jax.pmap(
    partial(train_step, label_smoothing_factor=training_args.label_smoothing_factor),
    "batch",
    donate_argnums=(0,),
)
p_eval_step = jax.pmap(
    partial(eval_step, label_smoothing_factor=training_args.label_smoothing_factor),
    "batch",
)
p_generate_step = jax.pmap(generate_step, "batch")

# Replicate the train state on each device
state = state.replicate()
905
-
906
logger.info("***** Running training *****")
logger.info(f" Num examples = {data_args.num_train_steps * train_batch_size}")
logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
logger.info(f" Total optimization steps = {data_args.num_train_steps}")

train_time = 0

# ======================== Training ================================
train_start = time.time()

train_metrics = []
epoch = 0
# Shuffle once up-front; the streaming loop below re-seeds per epoch via
# set_epoch().
train_dataset = vectorized_datasets["train"].shuffle(seed=training_args.seed)
eval_dataset = vectorized_datasets["eval"]
train_loader = data_loader(train_dataset, train_batch_size)
925
# train: step-based loop over a (possibly streaming) dataset; epochs are
# detected by the loader running dry (StopIteration).
for step in tqdm(range(data_args.num_train_steps), desc="Training...", position=1, leave=False):
    try:
        samples = next(train_loader)
    except StopIteration:
        # Stream exhausted: advance the epoch, reshuffle, rebuild the loader.
        epoch += 1
        train_dataset.set_epoch(epoch)
        train_loader = data_loader(train_dataset, train_batch_size)
        samples = next(train_loader)
        # Guard: train_metric only exists after the first completed step; if
        # the dataset is smaller than one batch we'd hit a NameError here.
        if train_metrics:
            logger.info(
                f"Completed epoch ({epoch} | Loss: {train_metric['loss']}, Learning Rate:"
                f" {train_metric['learning_rate']})"
            )

    batch = data_collator(samples)
    batch = shard(batch.data)
    state, train_metric = p_train_step(state, batch)
    train_metrics.append(train_metric)

    # Cumulative wall-clock training time.  Plain assignment, not "+=": the
    # original accumulated the *total* elapsed time on every step, which made
    # the reported train_time grow quadratically.
    train_time = time.time() - train_start
    train_metric = unreplicate(train_metric)
    # ======================== Evaluating ==============================
    if step % training_args.eval_steps == 0 and step > 0:
        eval_metrics = []
        eval_preds = []
        eval_labels = []
        eval_loader = data_loader(eval_dataset, eval_batch_size, drop_last=False)
        if data_args.max_eval_samples:
            max_eval_steps_iter = range(1 + data_args.max_eval_samples // eval_batch_size)
        else:
            # No cap: iterate until the eval loader runs dry.
            max_eval_steps_iter = itertools.repeat(None)
        for _ in tqdm(max_eval_steps_iter, desc="Evaluating...", position=2, leave=False):
            # Model forward
            try:
                samples = next(eval_loader)
            except StopIteration:
                break
            batch = data_collator(samples)
            labels = batch["labels"]

            metrics = pad_shard_unpad(p_eval_step, static_return=True)(
                state.params, batch.data, min_device_batch=training_args.per_device_eval_batch_size
            )
            eval_metrics.append(metrics)

            # generation
            if training_args.predict_with_generate:
                generated_ids = pad_shard_unpad(p_generate_step)(state.params, batch.data)
                eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
                eval_labels.extend(labels)

        # normalize eval metrics (mean over batches and devices)
        eval_metrics = get_metrics(eval_metrics)
        eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)

        # compute generation metrics (e.g. WER) when requested
        metric_desc = ""
        if training_args.predict_with_generate:
            metric_values = compute_metrics(eval_preds, eval_labels)
            eval_metrics.update(metric_values)
            metric_desc = " ".join([f"Eval {key}: {value} |" for key, value in metric_values.items()])

        # Print metrics
        desc = f"Epoch... ({epoch} | Eval Loss: {eval_metrics['loss']} | {metric_desc})"
        logger.info(desc)

        # Save metrics
        if has_tensorboard and jax.process_index() == 0:
            write_metric(summary_writer, train_metrics, eval_metrics, train_time, step)

        # Save checkpoint (main process only) and optionally push to the hub.
        if jax.process_index() == 0:
            params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
            model.save_pretrained(training_args.output_dir, params=params)
            tokenizer.save_pretrained(training_args.output_dir)
            if training_args.push_to_hub:
                repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)
1008
-
1009
-
1010
if __name__ == "__main__":
    # Script entry point.
    main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
run_streaming.sh CHANGED
@@ -1,23 +1,20 @@
1
- python run_flax_speech_recognition_seq2seq_streaming_v3.py \
2
  --model_name_or_path openai/whisper-tiny.en \
3
  --dataset_name mozilla-foundation/common_voice_11_0 \
4
- --dataset_config es \
5
- --language es \
6
- --text_column_name sentence \
7
  --train_split_name test\
8
  --eval_split_name test\
9
- --output_dir whisper-small-flaxtest\
10
  --overwrite_output_dir\
11
- --num_train_epochs=1\
12
  --warmup_steps=8 \
13
  --do_train \
14
  --do_eval \
15
- --num_train_steps 200 \
16
- --max_eval_samples 100 \
17
- --eval_steps 50 \
18
  --learning_rate=2e-4 \
19
- --per_device_train_batch_size=4 \
20
  --per_device_eval_batch_size=2 \
21
  --predict_with_generate \
22
- --streaming=True \
23
-
 
1
+ python run_flax_speech_recognition_seq2seq_streaming.py \
2
  --model_name_or_path openai/whisper-tiny.en \
3
  --dataset_name mozilla-foundation/common_voice_11_0 \
4
+ --dataset_config nn-NO \
5
+ --text_column_name sentence \
 
6
  --train_split_name test\
7
  --eval_split_name test\
8
+ --output_dir whisper-tiny-ft-dummy\
9
  --overwrite_output_dir\
10
+ --max_eval_samples=16\
11
  --warmup_steps=8 \
12
  --do_train \
13
  --do_eval \
14
+ --num_train_steps 30 \
15
+ --eval_steps 10 \
 
16
  --learning_rate=2e-4 \
17
+ --per_device_train_batch_size=2 \
18
  --per_device_eval_batch_size=2 \
19
  --predict_with_generate \
20
+ --streaming=True
 
run_streaming_v3.sh CHANGED
@@ -1,4 +1,4 @@
1
- python run_flax_speech_recognition_seq2seq_streaming_v3.py \
2
  --model_name_or_path openai/whisper-tiny.en \
3
  --dataset_name mozilla-foundation/common_voice_11_0 \
4
  --dataset_config nn-NO \
 
1
+ python run_flax_speech_recognition_seq2seq_streaming_v3_pere.py \
2
  --model_name_or_path openai/whisper-tiny.en \
3
  --dataset_name mozilla-foundation/common_voice_11_0 \
4
  --dataset_config nn-NO \