versae commited on
Commit
a325bea
·
2 Parent(s): fbb405375676fa
run_flax_speech_recognition_seq2seq_streaming_v3_pere.py ADDED
@@ -0,0 +1,1030 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the Flax library models for sequence to sequence speech recognition.
18
+ """
19
+ # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
20
+
21
+ import logging
22
+ import math
23
+ import os
24
+ import sys
25
+ import time
26
+ from dataclasses import field
27
+ from functools import partial
28
+ from pathlib import Path
29
+ from typing import Any, Callable, Dict, List, Optional, Union
30
+
31
+ import datasets
32
+ import flax
33
+ import jax
34
+ import jax.numpy as jnp
35
+ import numpy as np
36
+ import optax
37
+ import torch
38
+ from datasets import Dataset, DatasetDict, IterableDatasetDict, interleave_datasets, load_dataset
39
+ from torch.utils.data import IterableDataset
40
+ from flax import jax_utils, traverse_util
41
+ from flax.jax_utils import pad_shard_unpad, unreplicate
42
+ from flax.training import train_state
43
+ from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
44
+ from huggingface_hub import Repository, create_repo
45
+ from tqdm import tqdm
46
+
47
+ import evaluate
48
+ import transformers
49
+ from transformers import (
50
+ AutoConfig,
51
+ AutoFeatureExtractor,
52
+ AutoProcessor,
53
+ AutoTokenizer,
54
+ FlaxAutoModelForSpeechSeq2Seq,
55
+ HfArgumentParser,
56
+ Seq2SeqTrainingArguments,
57
+ is_tensorboard_available,
58
+ )
59
+
60
+ from transformers.models.whisper.english_normalizer import BasicTextNormalizer
61
+ from transformers.file_utils import get_full_repo_name
62
+ from transformers.utils import check_min_version, send_example_telemetry
63
+ from transformers.utils.versions import require_version
64
+
65
+
66
+ # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
67
+ check_min_version("4.27.0.dev0")
68
+
69
+ require_version("datasets>=1.18.2",
70
+ "To fix: pip install -r examples/flax/speech-recogintion/requirements.txt")
71
+
72
+ logger = logging.getLogger(__name__)
73
+
74
+
75
+ @flax.struct.dataclass
76
+ class ModelArguments:
77
+ """
78
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
79
+ """
80
+
81
+ model_name_or_path: str = field(
82
+ metadata={
83
+ "help": "Path to pretrained model or model identifier from huggingface.co/models"}
84
+ )
85
+ config_name: Optional[str] = field(
86
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
87
+ )
88
+ tokenizer_name: Optional[str] = field(
89
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
90
+ )
91
+ feature_extractor_name: Optional[str] = field(
92
+ default=None, metadata={"help": "feature extractor name or path if not the same as model_name"}
93
+ )
94
+ cache_dir: Optional[str] = field(
95
+ default=None,
96
+ metadata={
97
+ "help": "Where to store the pretrained models downloaded from huggingface.co"},
98
+ )
99
+ use_fast_tokenizer: bool = field(
100
+ default=True,
101
+ metadata={
102
+ "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
103
+ )
104
+ model_revision: str = field(
105
+ default="main",
106
+ metadata={
107
+ "help": "The specific model version to use (can be a branch name, tag name or commit id)."},
108
+ )
109
+ use_auth_token: bool = field(
110
+ default=False,
111
+ metadata={
112
+ "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
113
+ "with private models)."
114
+ },
115
+ )
116
+ dtype: Optional[str] = field(
117
+ default="float32",
118
+ metadata={
119
+ "help": (
120
+ "Floating-point format in which the model weights should be initialized and trained. Choose one of"
121
+ " `[float32, float16, bfloat16]`."
122
+ )
123
+ },
124
+ )
125
+ num_beams: Optional[int] = field(
126
+ default=None,
127
+ metadata={
128
+ "help": (
129
+ "Number of beams to use for evaluation. This argument will be passed to `model.generate`, "
130
+ "which is used during evaluation."
131
+ )
132
+ },
133
+ )
134
+
135
+
136
+ @flax.struct.dataclass
137
+ class DataTrainingArguments:
138
+ """
139
+ Arguments pertaining to what data we are going to input our model for training and eval.
140
+ """
141
+
142
+ dataset_name: str = field(
143
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
144
+ )
145
+ dataset_config_name: Optional[str] = field(
146
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
147
+ )
148
+ text_column: Optional[str] = field(
149
+ default=None,
150
+ metadata={
151
+ "help": "The name of the column in the datasets containing the full texts (for summarization)."},
152
+ )
153
+ dataset_cache_dir: Optional[str] = field(
154
+ default=None, metadata={"help": "Path to cache directory for saving and loading datasets"}
155
+ )
156
+ overwrite_cache: bool = field(
157
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
158
+ )
159
+ preprocessing_num_workers: Optional[int] = field(
160
+ default=None,
161
+ metadata={"help": "The number of processes to use for the preprocessing."},
162
+ )
163
+ max_train_samples: Optional[int] = field(
164
+ default=None,
165
+ metadata={
166
+ "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
167
+ "value if set."
168
+ },
169
+ )
170
+ max_eval_samples: Optional[int] = field(
171
+ default=None,
172
+ metadata={
173
+ "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
174
+ "value if set."
175
+ },
176
+ )
177
+ audio_column_name: str = field(
178
+ default="audio",
179
+ metadata={
180
+ "help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
181
+ )
182
+ text_column_name: str = field(
183
+ default="text",
184
+ metadata={
185
+ "help": "The name of the dataset column containing the text data. Defaults to 'text'"},
186
+ )
187
+ max_duration_in_seconds: float = field(
188
+ default=30.0,
189
+ metadata={
190
+ "help": "Filter audio files that are longer than `max_duration_in_seconds` seconds"},
191
+ )
192
+ min_duration_in_seconds: float = field(
193
+ default=0.0,
194
+ metadata={
195
+ "help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"},
196
+ )
197
+ max_label_length: float = field(
198
+ default=128,
199
+ metadata={
200
+ "help": "Truncate transcriptions that are longer `max_eval_length` tokens."},
201
+ )
202
+ pad_input_to_multiple_of: Optional[int] = field(
203
+ default=None,
204
+ metadata={
205
+ "help": "If set will pad the input sequence to a multiple of the provided value. "
206
+ "This is important to avoid triggering recompilations on TPU. If unspecified, will default to padding the inputs to max length."
207
+ },
208
+ )
209
+ pad_target_to_multiple_of: Optional[int] = field(
210
+ default=None,
211
+ metadata={
212
+ "help": "If set will pad the target sequence to a multiple of the provided value. "
213
+ "This is important to avoid triggering recompilations on TPU. If unspecified, will default to padding the targets to max length."
214
+ },
215
+ )
216
+ preprocessing_only: bool = field(
217
+ default=False,
218
+ metadata={
219
+ "help": "Whether to only do data preprocessing and skip training. "
220
+ "This is especially useful when data preprocessing errors out in distributed training due to timeout. "
221
+ "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
222
+ "so that the cached datasets can consequently be loaded in distributed training"
223
+ },
224
+ )
225
+ train_split_name: str = field(
226
+ default="train",
227
+ metadata={
228
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
229
+ },
230
+ )
231
+ eval_split_name: str = field(
232
+ default="validation",
233
+ metadata={
234
+ "help": "The name of the evaluation data set split to use (via the datasets library). Defaults to 'validation'"
235
+ },
236
+ )
237
+ do_lower_case: bool = field(
238
+ default=True,
239
+ metadata={"help": "Whether the target text should be lower cased."},
240
+ )
241
+ do_remove_punctuation: bool = field(
242
+ default=False,
243
+ metadata={
244
+ "help": "Whether the target text should be striped of punctuation."},
245
+ )
246
+ do_normalize_eval: bool = field(
247
+ default=True,
248
+ metadata={
249
+ "help": "Whether to normalise the references and predictions in the eval WER calculation."},
250
+ )
251
+ language: str = field(
252
+ default=None,
253
+ metadata={
254
+ "help": (
255
+ "Language for multilingual fine-tuning. This argument should be set for multilingual fine-tuning "
256
+ "only. For English speech recognition, it should be set to `None`."
257
+ )
258
+ },
259
+ )
260
+ task: str = field(
261
+ default="transcribe",
262
+ metadata={
263
+ "help": "Task, either `transcribe` for speech recognition or `translate` for speech translation."},
264
+ )
265
+ num_train_steps: int = field(default=50000, metadata={
266
+ "help": "The number of training steps."})
267
+ # num_eval_samples: int = field(default=50000, metadata={"help": "The number of samples to be used for evaluation"})
268
+ shuffle_buffer_size: Optional[int] = field(
269
+ default=500,
270
+ metadata={
271
+ "help": (
272
+ "The number of streamed examples to download before shuffling them. The large the buffer, "
273
+ "the closer it is to real offline shuffling."
274
+ )
275
+ },
276
+ )
277
+ streaming: bool = field(
278
+ default=True,
279
+ metadata={
280
+ "help": "Whether to use streaming mode to load and pre-process the data."},
281
+ )
282
+
283
+
284
+ def shift_tokens_right(label_ids: np.array, decoder_start_token_id: int) -> np.ndarray:
285
+ """
286
+ Shift label ids one token to the right.
287
+ """
288
+ shifted_label_ids = np.zeros_like(label_ids)
289
+ shifted_label_ids[:, 1:] = label_ids[:, :-1]
290
+ shifted_label_ids[:, 0] = decoder_start_token_id
291
+
292
+ return shifted_label_ids
293
+
294
+
295
+ @flax.struct.dataclass
296
+ class FlaxDataCollatorSpeechSeq2SeqWithPadding:
297
+ """
298
+ Data collator that will dynamically pad the inputs received.
299
+ Args:
300
+ processor ([`Wav2Vec2Processor`])
301
+ The processor used for proccessing the data.
302
+ decoder_start_token_id (:obj: `int`)
303
+ The begin-of-sentence of the decoder.
304
+ input_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
305
+ Select a strategy to pad the returned input sequences (according to the model's padding side and padding index)
306
+ among:
307
+ * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
308
+ sequence if provided).
309
+ * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
310
+ maximum acceptable input length for the model if that argument is not provided.
311
+ * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
312
+ different lengths).
313
+ target_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
314
+ Select a strategy to pad the returned target sequences (according to the model's padding side and padding index).
315
+ See above for details.
316
+ max_input_length (:obj:`float`, `optional`):
317
+ Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
318
+ max_target_length (:obj:`int`, `optional`):
319
+ Maximum length of the ``labels`` of the returned list and optionally padding length (see above).
320
+ pad_input_to_multiple_of (:obj:`int`, `optional`):
321
+ If set will pad the input sequence to a multiple of the provided value.
322
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
323
+ 7.5 (Volta).
324
+ pad_target_to_multiple_of (:obj:`int`, `optional`):
325
+ If set will pad the target sequence to a multiple of the provided value.
326
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
327
+ 7.5 (Volta).
328
+ """
329
+
330
+ processor: Any
331
+ decoder_start_token_id: int
332
+ input_padding: Union[bool, str] = "longest"
333
+ target_padding: Union[bool, str] = "max_length"
334
+ max_input_length: Optional[float] = None
335
+ max_target_length: Optional[int] = None
336
+ pad_input_to_multiple_of: Optional[int] = None
337
+ pad_target_to_multiple_of: Optional[int] = None
338
+
339
+ def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> Dict[str, np.ndarray]:
340
+ model_input_name = self.processor.model_input_names[0]
341
+ input_features = {model_input_name: features[model_input_name]}
342
+ label_features = {"input_ids": features["labels"]}
343
+
344
+ # reformat list to dict and set to pytorch format
345
+ batch = self.processor.feature_extractor.pad(
346
+ input_features,
347
+ max_length=self.max_input_length,
348
+ padding=self.input_padding,
349
+ pad_to_multiple_of=self.pad_input_to_multiple_of,
350
+ return_tensors="np",
351
+ )
352
+
353
+ labels_batch = self.processor.tokenizer.pad(
354
+ label_features,
355
+ max_length=self.max_target_length,
356
+ padding=self.target_padding,
357
+ pad_to_multiple_of=self.pad_target_to_multiple_of,
358
+ return_tensors="np",
359
+ )
360
+
361
+ # if bos token is appended in previous tokenization step,
362
+ # cut bos token here as it's append later anyways
363
+ labels = labels_batch["input_ids"]
364
+ if (labels[:, 0] == self.decoder_start_token_id).all().item():
365
+ labels = labels[:, 1:]
366
+ labels_batch.attention_mask = labels_batch.attention_mask[:, 1:]
367
+
368
+ decoder_input_ids = shift_tokens_right(
369
+ labels, self.decoder_start_token_id)
370
+
371
+ # replace padding with -100 to ignore correctly when computing the loss
372
+ labels = np.ma.array(labels, mask=np.not_equal(
373
+ labels_batch.attention_mask, 1))
374
+ labels = labels.filled(fill_value=-100)
375
+
376
+ batch["labels"] = labels
377
+ batch["decoder_input_ids"] = decoder_input_ids
378
+
379
+ return batch
380
+
381
+
382
+ def load_maybe_streaming_dataset(dataset_name, dataset_config_name, split="train", streaming=True, **kwargs):
383
+ """
384
+ Utility function to load a dataset in streaming mode. For datasets with multiple splits,
385
+ each split is loaded individually and then splits combined by taking alternating examples from
386
+ each (interleaving).
387
+ """
388
+ if "+" in split:
389
+ # load multiple splits separated by the `+` symbol with streaming mode
390
+ dataset_splits = [
391
+ load_dataset(dataset_name, dataset_config_name,
392
+ split=split_name, streaming=streaming, **kwargs)
393
+ for split_name in split.split("+")
394
+ ]
395
+ # interleave multiple splits to form one dataset
396
+ interleaved_dataset = interleave_datasets(dataset_splits)
397
+ return interleaved_dataset
398
+ else:
399
+ # load a single split *with* streaming mode
400
+ dataset = load_dataset(
401
+ dataset_name, dataset_config_name, split=split, streaming=streaming, **kwargs)
402
+ return dataset
403
+
404
+
405
+ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True):
406
+ """
407
+ Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete,
408
+ and range in size from 1 to `batch_size`. Shuffle batches if `shuffle` is `True`.
409
+ """
410
+ if shuffle:
411
+ batch_idx = jax.random.permutation(rng, len(dataset))
412
+ batch_idx = np.asarray(batch_idx)
413
+ else:
414
+ batch_idx = np.arange(len(dataset))
415
+
416
+ if drop_last:
417
+ steps_per_epoch = len(dataset) // batch_size
418
+ # Skip incomplete batch.
419
+ batch_idx = batch_idx[: steps_per_epoch * batch_size]
420
+ batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
421
+ else:
422
+ steps_per_epoch = math.ceil(len(dataset) / batch_size)
423
+ batch_idx = np.array_split(batch_idx, steps_per_epoch)
424
+
425
+ for idx in batch_idx:
426
+ batch = dataset[idx]
427
+ yield batch
428
+
429
+
430
+ class TrainState(train_state.TrainState):
431
+ dropout_rng: jnp.ndarray
432
+
433
+ def replicate(self):
434
+ return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
435
+
436
+
437
+ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
438
+ summary_writer.scalar("train_time", train_time, step)
439
+
440
+ train_metrics = get_metrics(train_metrics)
441
+ for key, vals in train_metrics.items():
442
+ tag = f"train_{key}"
443
+ for i, val in enumerate(vals):
444
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
445
+
446
+ for metric_name, value in eval_metrics.items():
447
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
448
+
449
+
450
+ def create_learning_rate_fn(
451
+ num_train_steps: int, num_warmup_steps: int, learning_rate: float
452
+ ) -> Callable[[int], jnp.array]:
453
+ """Returns a linear warmup, linear_decay learning rate function."""
454
+ warmup_fn = optax.linear_schedule(
455
+ init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
456
+ decay_fn = optax.linear_schedule(
457
+ init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
458
+ )
459
+ schedule_fn = optax.join_schedules(
460
+ schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
461
+ return schedule_fn
462
+
463
+
464
+ def main():
465
+ # 1. Parse input arguments
466
+ # See all possible arguments in src/transformers/training_args.py
467
+ # or by passing the --help flag to this script.
468
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
469
+ parser = HfArgumentParser(
470
+ (ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
471
+
472
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
473
+ # If we pass only one argument to the script and it's the path to a json file,
474
+ # let's parse it to get our arguments.
475
+ model_args, data_args, training_args = parser.parse_json_file(
476
+ json_file=os.path.abspath(sys.argv[1]))
477
+ else:
478
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
479
+
480
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
481
+ # information sent is the one passed as arguments along with your JAX/Flax versions.
482
+ send_example_telemetry("run_speech_recognition_seq2seq",
483
+ model_args, data_args, framework="flax")
484
+
485
+ # 2. Setup logging
486
+ # Make one log on every process with the configuration for debugging.
487
+ logging.basicConfig(
488
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
489
+ datefmt="%m/%d/%Y %H:%M:%S",
490
+ handlers=[logging.StreamHandler(sys.stdout)],
491
+ )
492
+ # Set the verbosity to info of the Transformers logger.
493
+ # We only want one process per machine to log things on the screen.
494
+ logger.setLevel(logging.INFO if jax.process_index()
495
+ == 0 else logging.ERROR)
496
+ if jax.process_index() == 0:
497
+ datasets.utils.logging.set_verbosity_warning()
498
+ transformers.utils.logging.set_verbosity_info()
499
+ else:
500
+ datasets.utils.logging.set_verbosity_error()
501
+ transformers.utils.logging.set_verbosity_error()
502
+
503
+ logger.info("Training/evaluation parameters %s", training_args)
504
+
505
+ # Check the output dir is valid
506
+ if (
507
+ os.path.exists(training_args.output_dir)
508
+ and os.listdir(training_args.output_dir)
509
+ and training_args.do_train
510
+ and not training_args.overwrite_output_dir
511
+ ):
512
+ raise ValueError(
513
+ f"Output directory ({training_args.output_dir}) already exists and is not empty."
514
+ "Use `--overwrite_output_dir` to overcome."
515
+ )
516
+
517
+ # Handle the repository creation
518
+ if training_args.push_to_hub:
519
+ if training_args.hub_model_id is None:
520
+ repo_name = get_full_repo_name(
521
+ Path(training_args.output_dir).absolute(
522
+ ).name, token=training_args.hub_token
523
+ )
524
+ else:
525
+ repo_name = training_args.hub_model_id
526
+ create_repo(repo_name, exist_ok=True, token=training_args.hub_token)
527
+ repo = Repository(training_args.output_dir,
528
+ clone_from=repo_name, token=training_args.hub_token)
529
+
530
+ # 3. Load dataset
531
+ raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
532
+
533
+ if training_args.do_train:
534
+ raw_datasets["train"] = load_maybe_streaming_dataset(
535
+ data_args.dataset_name,
536
+ data_args.dataset_config_name,
537
+ split=data_args.train_split_name,
538
+ cache_dir=data_args.dataset_cache_dir,
539
+ streaming=data_args.streaming,
540
+ use_auth_token=True if model_args.use_auth_token else None,
541
+ )
542
+
543
+ if training_args.do_eval:
544
+ raw_datasets["eval"] = load_maybe_streaming_dataset(
545
+ data_args.dataset_name,
546
+ data_args.dataset_config_name,
547
+ split=data_args.eval_split_name,
548
+ cache_dir=data_args.dataset_cache_dir,
549
+ streaming=data_args.streaming,
550
+ use_auth_token=True if model_args.use_auth_token else None,
551
+ )
552
+
553
+ if not training_args.do_train and not training_args.do_eval:
554
+ raise ValueError(
555
+ "Cannot not train and not do evaluation. At least one of training or evaluation has to be performed."
556
+ )
557
+
558
+ raw_datasets_features = list(
559
+ next(iter(raw_datasets.values())).features.keys())
560
+
561
+ if data_args.audio_column_name not in raw_datasets_features:
562
+ raise ValueError(
563
+ f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
564
+ "Make sure to set `--audio_column_name` to the correct audio column - one of "
565
+ f"{', '.join(raw_datasets_features)}."
566
+ )
567
+
568
+ if data_args.text_column_name not in raw_datasets_features:
569
+ raise ValueError(
570
+ f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
571
+ "Make sure to set `--text_column_name` to the correct text column - one of "
572
+ f"{', '.join(raw_datasets_features)}."
573
+ )
574
+
575
+ # 5. Load pretrained model, tokenizer, and feature extractor
576
+ config = AutoConfig.from_pretrained(
577
+ model_args.config_name if model_args.config_name else model_args.model_name_or_path,
578
+ cache_dir=model_args.cache_dir,
579
+ revision=model_args.model_revision,
580
+ use_auth_token=True if model_args.use_auth_token else None,
581
+ )
582
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
583
+ model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
584
+ cache_dir=model_args.cache_dir,
585
+ revision=model_args.model_revision,
586
+ use_auth_token=True if model_args.use_auth_token else None,
587
+ )
588
+ tokenizer = AutoTokenizer.from_pretrained(
589
+ model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
590
+ cache_dir=model_args.cache_dir,
591
+ use_fast=model_args.use_fast_tokenizer,
592
+ revision=model_args.model_revision,
593
+ use_auth_token=True if model_args.use_auth_token else None,
594
+ )
595
+
596
+ model = FlaxAutoModelForSpeechSeq2Seq.from_pretrained(
597
+ model_args.model_name_or_path,
598
+ config=config,
599
+ dtype=getattr(jnp, model_args.dtype),
600
+ cache_dir=model_args.cache_dir,
601
+ revision=model_args.model_revision,
602
+ use_auth_token=True if model_args.use_auth_token else None,
603
+ )
604
+
605
+ if model.config.decoder_start_token_id is None:
606
+ raise ValueError(
607
+ "Make sure that `config.decoder_start_token_id` is correctly defined")
608
+
609
+ # 6. Resample speech dataset: `datasets` takes care of automatically loading and resampling the audio,
610
+ # so we just need to set the correct target sampling rate.
611
+ dataset_sampling_rate = next(
612
+ iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
613
+
614
+ if dataset_sampling_rate != feature_extractor.sampling_rate:
615
+ raw_datasets = raw_datasets.cast_column(
616
+ data_args.audio_column_name, datasets.features.Audio(
617
+ sampling_rate=feature_extractor.sampling_rate)
618
+ )
619
+
620
+ # 7. Preprocessing the datasets.
621
+ # We need to read the audio files as arrays and tokenize the targets.
622
+ max_input_length = int(
623
+ data_args.max_duration_in_seconds * feature_extractor.sampling_rate)
624
+ min_input_length = int(
625
+ data_args.min_duration_in_seconds * feature_extractor.sampling_rate)
626
+ max_label_length = (
627
+ data_args.max_label_length if data_args.max_label_length is not None else model.config.max_length
628
+ )
629
+ pad_input_to_multiple_of = data_args.pad_input_to_multiple_of
630
+ pad_target_to_multiple_of = data_args.pad_target_to_multiple_of
631
+ audio_column_name = data_args.audio_column_name
632
+ num_workers = data_args.preprocessing_num_workers
633
+ text_column_name = data_args.text_column_name
634
+ model_input_name = feature_extractor.model_input_names[0]
635
+ do_lower_case = data_args.do_lower_case
636
+ do_remove_punctuation = data_args.do_remove_punctuation
637
+ normalizer = BasicTextNormalizer() # 'official' text normalizer from OpenAI
638
+
639
+ if data_args.language is not None:
640
+ # We only need to set the task id when the language is specified (i.e. in a multilingual setting)
641
+ tokenizer.set_prefix_tokens(
642
+ language=data_args.language, task=data_args.task)
643
+
644
+ def prepare_dataset(batch):
645
+ # process audio
646
+ sample = batch[audio_column_name]
647
+ inputs = feature_extractor(
648
+ sample["array"], sampling_rate=sample["sampling_rate"])
649
+ # process audio length
650
+ batch[model_input_name] = inputs.get(model_input_name)[0]
651
+ batch["input_length"] = len(sample["array"])
652
+
653
+ # process targets
654
+ input_str = batch[text_column_name].lower(
655
+ ) if do_lower_case else batch[text_column_name]
656
+ if do_remove_punctuation:
657
+ input_str = normalizer(input_str).strip()
658
+ batch["labels"] = tokenizer(input_str).input_ids
659
+ return batch
660
+
661
+ with training_args.main_process_first(desc="dataset map pre-processing"):
662
+ vectorized_datasets = raw_datasets.map(
663
+ prepare_dataset,
664
+ remove_columns=raw_datasets_features,
665
+ ).with_format("torch")
666
+
667
+ # filter training data with inputs longer than max_input_length
668
+ def is_audio_in_length_range(length):
669
+ return min_input_length < length < max_input_length
670
+
671
+ if training_args.do_train:
672
+ vectorized_datasets["train"] = vectorized_datasets["train"].filter(
673
+ is_audio_in_length_range,
674
+ input_columns=["input_length"],
675
+ )
676
+
677
+ if training_args.do_eval:
678
+ vectorized_datasets["eval"] = vectorized_datasets["eval"].filter(
679
+ is_audio_in_length_range,
680
+ input_columns=["input_length"],
681
+ )
682
+
683
+ # 8. Load Metric
684
+ metric = evaluate.load("wer")
685
+ do_normalize_eval = data_args.do_normalize_eval
686
+
687
+ def compute_metrics(pred_ids, label_ids):
688
+ # replace padded labels by the padding token
689
+ for idx in range(len(label_ids)):
690
+ label_ids[idx][label_ids[idx] == -100] = tokenizer.pad_token_id
691
+ #label_ids[label_ids == -100] = tokenizer.pad_token_id
692
+
693
+ pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
694
+ # we do not want to group tokens when computing the metrics
695
+ label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
696
+
697
+ if do_normalize_eval:
698
+ pred_str = [normalizer(pred) for pred in pred_str]
699
+ label_str = [normalizer(label) for label in label_str]
700
+ # filtering step to only evaluate the samples that correspond to non-zero references:
701
+ pred_str = [pred_str[i]
702
+ for i in range(len(pred_str)) if len(label_str[i]) > 0]
703
+ label_str = [label_str[i]
704
+ for i in range(len(label_str)) if len(label_str[i]) > 0]
705
+
706
+ wer = 100 * metric.compute(predictions=pred_str, references=label_str)
707
+
708
+ return {"wer": wer}
709
+
710
+ def write_stats(eval_metrics, pred_ids, label_ids):
711
+ import pandas as pd
712
+ df = pd.DataFrame(columns=['source', 'prediction'])
713
+
714
+
715
+ for pred,label in zip(pred_ids,label_ids):
716
+ pred_text = tokenizer.decode(pred,skip_special_tokens=True)
717
+ label_text = tokenizer.decode(label,skip_special_tokens=True)
718
+ df = df.append({'source': label_text, 'column2': pred_text}, ignore_index=True)
719
+
720
+ breakpoint()
721
+
722
+
723
+ print("Writing stats")
724
+
725
+
726
+ # 9. Save feature extractor, tokenizer and config
727
+ feature_extractor.save_pretrained(training_args.output_dir)
728
+ tokenizer.save_pretrained(training_args.output_dir)
729
+ config.save_pretrained(training_args.output_dir)
730
+
731
+ processor = AutoProcessor.from_pretrained(training_args.output_dir)
732
+
733
+ data_collator = FlaxDataCollatorSpeechSeq2SeqWithPadding(
734
+ processor=processor,
735
+ decoder_start_token_id=model.config.decoder_start_token_id,
736
+ input_padding="longest",
737
+ target_padding="longest",
738
+ max_target_length=max_label_length,
739
+ pad_input_to_multiple_of=pad_input_to_multiple_of,
740
+ pad_target_to_multiple_of=pad_target_to_multiple_of if pad_target_to_multiple_of else max_label_length,
741
+ )
742
+
743
+ # Enable tensorboard only on the master node
744
+ has_tensorboard = is_tensorboard_available()
745
+ if has_tensorboard and jax.process_index() == 0:
746
+ try:
747
+ from flax.metrics.tensorboard import SummaryWriter
748
+
749
+ summary_writer = SummaryWriter(
750
+ log_dir=Path(training_args.output_dir))
751
+ except ImportError as ie:
752
+ has_tensorboard = False
753
+ logger.warning(
754
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
755
+ )
756
+ else:
757
+ logger.warning(
758
+ "Unable to display metrics through TensorBoard because the package is not installed: "
759
+ "Please run pip install tensorboard to enable."
760
+ )
761
+
762
+ # Initialize our training
763
+ rng = jax.random.PRNGKey(training_args.seed)
764
+ rng, dropout_rng = jax.random.split(rng)
765
+
766
+ # Store some constant
767
+ #num_epochs = int(training_args.num_train_epochs)
768
+ train_batch_size = int(
769
+ training_args.per_device_train_batch_size) * jax.device_count()
770
+ eval_batch_size = int(
771
+ training_args.per_device_eval_batch_size) * jax.device_count()
772
+
773
+ # Create learning rate schedule
774
+ linear_decay_lr_schedule_fn = create_learning_rate_fn(
775
+ data_args.num_train_steps*train_batch_size,
776
+ training_args.warmup_steps,
777
+ training_args.learning_rate,
778
+ )
779
+
780
+ # We use Optax's "masking" functionality to not apply weight decay
781
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
782
+ # mask boolean with the same structure as the parameters.
783
+ # The mask is True for parameters that should be decayed.
784
+ def decay_mask_fn(params):
785
+ flat_params = traverse_util.flatten_dict(params)
786
+ # find out all LayerNorm parameters
787
+ layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
788
+ layer_norm_named_params = set(
789
+ [
790
+ layer[-2:]
791
+ for layer_norm_name in layer_norm_candidates
792
+ for layer in flat_params.keys()
793
+ if layer_norm_name in "".join(layer).lower()
794
+ ]
795
+ )
796
+ flat_mask = {path: (path[-1] != "bias" and path[-2:]
797
+ not in layer_norm_named_params) for path in flat_params}
798
+ return traverse_util.unflatten_dict(flat_mask)
799
+
800
+ # create adam optimizer
801
+ adamw = optax.adamw(
802
+ learning_rate=linear_decay_lr_schedule_fn,
803
+ b1=training_args.adam_beta1,
804
+ b2=training_args.adam_beta2,
805
+ eps=training_args.adam_epsilon,
806
+ weight_decay=training_args.weight_decay,
807
+ mask=decay_mask_fn,
808
+ )
809
+
810
+ # Setup train state
811
+ state = TrainState.create(
812
+ apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
813
+
814
+ # label smoothed cross entropy
815
+ def loss_fn(logits, labels, label_smoothing_factor=0.0):
816
+ """
817
+ The label smoothing implementation is adapted from Flax's official example:
818
+ https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
819
+ """
820
+ vocab_size = logits.shape[-1]
821
+ confidence = 1.0 - label_smoothing_factor
822
+ low_confidence = (1.0 - confidence) / (vocab_size - 1)
823
+ normalizing_constant = -(
824
+ confidence * jnp.log(confidence) + (vocab_size - 1) *
825
+ low_confidence * jnp.log(low_confidence + 1e-20)
826
+ )
827
+ soft_labels = onehot(labels, vocab_size,
828
+ on_value=confidence, off_value=low_confidence)
829
+
830
+ loss = optax.softmax_cross_entropy(logits, soft_labels)
831
+ loss = loss - normalizing_constant
832
+
833
+ # ignore padded tokens from loss, i.e. where labels are not set to -100
834
+ padding_mask = labels >= 0
835
+ loss = loss * padding_mask
836
+ loss = loss.sum()
837
+ num_labels = padding_mask.sum()
838
+ return loss, num_labels
839
+
840
+ # Define gradient update step fn
841
+ def train_step(state, batch, label_smoothing_factor=0.0):
842
+ dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
843
+
844
+ def compute_loss(params):
845
+ labels = batch.pop("labels")
846
+ logits = state.apply_fn(
847
+ **batch, params=params, dropout_rng=dropout_rng, train=True)[0]
848
+ loss, num_labels = loss_fn(logits, labels, label_smoothing_factor)
849
+ return loss, num_labels
850
+
851
+ grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
852
+ (loss, num_labels), grad = grad_fn(state.params)
853
+ num_labels = jax.lax.psum(num_labels, "batch")
854
+
855
+ # true loss = total loss / total samples
856
+ loss = jax.lax.psum(loss, "batch")
857
+ loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
858
+
859
+ # true grad = total grad / total samples
860
+ grad = jax.lax.psum(grad, "batch")
861
+ grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
862
+ new_state = state.apply_gradients(
863
+ grads=grad, dropout_rng=new_dropout_rng)
864
+
865
+ metrics = {"loss": loss,
866
+ "learning_rate": linear_decay_lr_schedule_fn(state.step)}
867
+ return new_state, metrics
868
+
869
+ # Define eval fn
870
+ def eval_step(params, batch, label_smoothing_factor=0.0):
871
+ labels = batch.pop("labels")
872
+ logits = model(**batch, params=params, train=False)[0]
873
+
874
+ loss, num_labels = loss_fn(logits, labels, label_smoothing_factor)
875
+ num_labels = jax.lax.psum(num_labels, "batch")
876
+
877
+ # true loss = total loss / total samples
878
+ loss = jax.lax.psum(loss, "batch")
879
+ loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
880
+
881
+ metrics = {"loss": loss}
882
+ return metrics
883
+
884
+ # Define generation function
885
+ num_beams = model_args.num_beams if model_args.num_beams is not None else model.config.num_beams
886
+ gen_kwargs = {"max_length": max_label_length, "num_beams": num_beams}
887
+
888
+ def generate_step(params, batch):
889
+ model.params = params
890
+ output_ids = model.generate(batch[model_input_name], attention_mask=batch.get(
891
+ "attention_mask"), **gen_kwargs)
892
+ return output_ids.sequences
893
+
894
+ # Create parallel version of the train and eval step
895
+ p_train_step = jax.pmap(
896
+ partial(train_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch", donate_argnums=(0,)
897
+ )
898
+ p_eval_step = jax.pmap(partial(
899
+ eval_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch")
900
+ p_generate_step = jax.pmap(generate_step, "batch")
901
+
902
+ # Replicate the train state on each device
903
+ state = state.replicate()
904
+
905
+ logger.info("***** Running training *****")
906
+ logger.info(
907
+ f" Num examples = {data_args.num_train_steps*train_batch_size}")
908
+ logger.info(
909
+ f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
910
+ logger.info(
911
+ f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
912
+ logger.info(f" Total optimization steps = {data_args.num_train_steps}")
913
+
914
+ train_time = 0
915
+
916
+ # ======================== Training ================================
917
+ train_start = time.time()
918
+
919
+ # Create sampling rng
920
+ #rng, input_rng = jax.random.split(rng)
921
+ train_metrics = []
922
+ epoch = 0
923
+
924
+ def collate_batch(samples):
925
+ return {key: [feature[key] for feature in samples] for key in samples[0].keys()}
926
+
927
+ # Create a batched data iterator
928
+ num_workers = 0
929
+ # This is not working
930
+ # vectorized_datasets["train"] = vectorized_datasets["train"].shuffle()
931
+ train_data_loader = torch.utils.data.DataLoader(
932
+ batch_size=train_batch_size, dataset=vectorized_datasets["train"], num_workers=num_workers, collate_fn=collate_batch, drop_last=True)
933
+ train_data_iterator = torch.utils.data.dataloader._SingleProcessDataLoaderIter(
934
+ train_data_loader)
935
+
936
+ # train
937
+ for step in tqdm(range(data_args.num_train_steps), desc="Training...", position=1, leave=False):
938
+
939
+ try:
940
+ samples = next(train_data_iterator)
941
+
942
+ except StopIteration:
943
+ epoch += 1
944
+ train_data_loader = torch.utils.data.DataLoader(
945
+ batch_size=train_batch_size, dataset=vectorized_datasets["train"], num_workers=num_workers, collate_fn=collate_batch, drop_last=True)
946
+ train_data_iterator = torch.utils.data.dataloader._SingleProcessDataLoaderIter(
947
+ train_data_loader)
948
+ samples = next(train_data_iterator)
949
+
950
+ logger.info(
951
+ f"Completed epoch ({epoch} | Loss: {train_metric['loss']}, Learning Rate:"
952
+ f" {train_metric['learning_rate']})"
953
+ )
954
+
955
+ # reshaped_samples = {key: [feature[key] for feature in samples] for key in samples[0].keys()}
956
+ # breakpoint()
957
+ batch = data_collator(samples)
958
+ batch = shard(batch.data)
959
+ state, train_metric = p_train_step(state, batch)
960
+ train_metrics.append(train_metric)
961
+
962
+ train_time += time.time() - train_start
963
+ train_metric = unreplicate(train_metric)
964
+ # ======================== Evaluating ==============================
965
+ if step % training_args.eval_steps == 0 and step > 0:
966
+ eval_metrics = []
967
+ eval_preds = []
968
+ eval_labels = []
969
+
970
+ #eval_loader = data_loader(input_rng, vectorized_datasets["eval"], eval_batch_size, drop_last=False)
971
+ eval_data_loader = torch.utils.data.DataLoader(
972
+ batch_size=eval_batch_size, dataset=vectorized_datasets["eval"], num_workers=num_workers, collate_fn=collate_batch, drop_last=False)
973
+ eval_data_iterator = torch.utils.data.dataloader._SingleProcessDataLoaderIter(
974
+ eval_data_loader)
975
+
976
+ for _ in tqdm(range(training_args.eval_steps), desc="Evaluating...", position=2, leave=False):
977
+ # Model forward
978
+ samples = next(eval_data_iterator)
979
+ batch = data_collator(samples)
980
+ labels = batch["labels"]
981
+
982
+ metrics = pad_shard_unpad(p_eval_step, static_return=True)(
983
+ state.params, batch.data, min_device_batch=training_args.per_device_eval_batch_size
984
+ )
985
+ eval_metrics.append(metrics)
986
+
987
+ # generation
988
+ if training_args.predict_with_generate:
989
+ generated_ids = pad_shard_unpad(
990
+ p_generate_step)(state.params, batch.data)
991
+ eval_preds.extend(jax.device_get(
992
+ generated_ids.reshape(-1, gen_kwargs["max_length"])))
993
+ eval_labels.extend(labels)
994
+
995
+
996
+ # normalize eval metrics
997
+ eval_metrics = get_metrics(eval_metrics)
998
+ eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
999
+
1000
+ # compute WER metric
1001
+ wer_desc = ""
1002
+ if training_args.predict_with_generate:
1003
+ wer_metric = compute_metrics(eval_preds, eval_labels)
1004
+ eval_metrics.update(wer_metric)
1005
+ wer_desc = " ".join(
1006
+ [f"Eval {key}: {value} |" for key, value in wer_metric.items()])
1007
+ write_stats(eval_metrics, eval_preds, eval_labels)
1008
+
1009
+ # Print metrics
1010
+ desc = f"Epoch... ({epoch} | Eval Loss: {eval_metrics['loss']} | {wer_desc})"
1011
+ logger.info(desc)
1012
+
1013
+ # Save metrics
1014
+ if has_tensorboard and jax.process_index() == 0:
1015
+ write_metric(summary_writer, train_metrics,
1016
+ eval_metrics, train_time, step)
1017
+
1018
+ # save checkpoint after each epoch and push checkpoint to the hub
1019
+ if jax.process_index() == 0:
1020
+ params = jax.device_get(
1021
+ jax.tree_util.tree_map(lambda x: x[0], state.params))
1022
+ model.save_pretrained(training_args.output_dir, params=params)
1023
+ tokenizer.save_pretrained(training_args.output_dir)
1024
+ if training_args.push_to_hub:
1025
+ repo.push_to_hub(
1026
+ commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)
1027
+
1028
+
1029
+ if __name__ == "__main__":
1030
+ main()