test
Browse files
run_flax_speech_recognition_seq2seq_streaming_v3.py
CHANGED
|
@@ -35,7 +35,7 @@ import jax.numpy as jnp
|
|
| 35 |
import numpy as np
|
| 36 |
import optax
|
| 37 |
import torch
|
| 38 |
-
from datasets import Dataset,DatasetDict, IterableDatasetDict, interleave_datasets, load_dataset
|
| 39 |
from torch.utils.data import IterableDataset
|
| 40 |
from flax import jax_utils, traverse_util
|
| 41 |
from flax.jax_utils import pad_shard_unpad, unreplicate
|
|
@@ -66,7 +66,8 @@ from transformers.utils.versions import require_version
|
|
| 66 |
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
| 67 |
check_min_version("4.27.0.dev0")
|
| 68 |
|
| 69 |
-
require_version("datasets>=1.18.2",
|
|
|
|
| 70 |
|
| 71 |
logger = logging.getLogger(__name__)
|
| 72 |
|
|
@@ -78,7 +79,8 @@ class ModelArguments:
|
|
| 78 |
"""
|
| 79 |
|
| 80 |
model_name_or_path: str = field(
|
| 81 |
-
metadata={
|
|
|
|
| 82 |
)
|
| 83 |
config_name: Optional[str] = field(
|
| 84 |
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
|
@@ -91,15 +93,18 @@ class ModelArguments:
|
|
| 91 |
)
|
| 92 |
cache_dir: Optional[str] = field(
|
| 93 |
default=None,
|
| 94 |
-
metadata={
|
|
|
|
| 95 |
)
|
| 96 |
use_fast_tokenizer: bool = field(
|
| 97 |
default=True,
|
| 98 |
-
metadata={
|
|
|
|
| 99 |
)
|
| 100 |
model_revision: str = field(
|
| 101 |
default="main",
|
| 102 |
-
metadata={
|
|
|
|
| 103 |
)
|
| 104 |
use_auth_token: bool = field(
|
| 105 |
default=False,
|
|
@@ -142,7 +147,8 @@ class DataTrainingArguments:
|
|
| 142 |
)
|
| 143 |
text_column: Optional[str] = field(
|
| 144 |
default=None,
|
| 145 |
-
metadata={
|
|
|
|
| 146 |
)
|
| 147 |
dataset_cache_dir: Optional[str] = field(
|
| 148 |
default=None, metadata={"help": "Path to cache directory for saving and loading datasets"}
|
|
@@ -170,23 +176,28 @@ class DataTrainingArguments:
|
|
| 170 |
)
|
| 171 |
audio_column_name: str = field(
|
| 172 |
default="audio",
|
| 173 |
-
metadata={
|
|
|
|
| 174 |
)
|
| 175 |
text_column_name: str = field(
|
| 176 |
default="text",
|
| 177 |
-
metadata={
|
|
|
|
| 178 |
)
|
| 179 |
max_duration_in_seconds: float = field(
|
| 180 |
default=30.0,
|
| 181 |
-
metadata={
|
|
|
|
| 182 |
)
|
| 183 |
min_duration_in_seconds: float = field(
|
| 184 |
default=0.0,
|
| 185 |
-
metadata={
|
|
|
|
| 186 |
)
|
| 187 |
max_label_length: float = field(
|
| 188 |
default=128,
|
| 189 |
-
metadata={
|
|
|
|
| 190 |
)
|
| 191 |
pad_input_to_multiple_of: Optional[int] = field(
|
| 192 |
default=None,
|
|
@@ -229,11 +240,13 @@ class DataTrainingArguments:
|
|
| 229 |
)
|
| 230 |
do_remove_punctuation: bool = field(
|
| 231 |
default=False,
|
| 232 |
-
metadata={
|
|
|
|
| 233 |
)
|
| 234 |
do_normalize_eval: bool = field(
|
| 235 |
default=True,
|
| 236 |
-
metadata={
|
|
|
|
| 237 |
)
|
| 238 |
language: str = field(
|
| 239 |
default=None,
|
|
@@ -246,9 +259,11 @@ class DataTrainingArguments:
|
|
| 246 |
)
|
| 247 |
task: str = field(
|
| 248 |
default="transcribe",
|
| 249 |
-
metadata={
|
|
|
|
| 250 |
)
|
| 251 |
-
num_train_steps: int = field(default=50000, metadata={
|
|
|
|
| 252 |
# num_eval_samples: int = field(default=50000, metadata={"help": "The number of samples to be used for evaluation"})
|
| 253 |
shuffle_buffer_size: Optional[int] = field(
|
| 254 |
default=500,
|
|
@@ -261,9 +276,11 @@ class DataTrainingArguments:
|
|
| 261 |
)
|
| 262 |
streaming: bool = field(
|
| 263 |
default=True,
|
| 264 |
-
metadata={
|
|
|
|
| 265 |
)
|
| 266 |
|
|
|
|
| 267 |
def shift_tokens_right(label_ids: np.array, decoder_start_token_id: int) -> np.ndarray:
|
| 268 |
"""
|
| 269 |
Shift label ids one token to the right.
|
|
@@ -348,17 +365,19 @@ class FlaxDataCollatorSpeechSeq2SeqWithPadding:
|
|
| 348 |
labels = labels[:, 1:]
|
| 349 |
labels_batch.attention_mask = labels_batch.attention_mask[:, 1:]
|
| 350 |
|
| 351 |
-
decoder_input_ids = shift_tokens_right(
|
|
|
|
| 352 |
|
| 353 |
# replace padding with -100 to ignore correctly when computing the loss
|
| 354 |
-
labels = np.ma.array(labels, mask=np.not_equal(
|
|
|
|
| 355 |
labels = labels.filled(fill_value=-100)
|
| 356 |
|
| 357 |
batch["labels"] = labels
|
| 358 |
batch["decoder_input_ids"] = decoder_input_ids
|
| 359 |
|
| 360 |
return batch
|
| 361 |
-
|
| 362 |
|
| 363 |
def load_maybe_streaming_dataset(dataset_name, dataset_config_name, split="train", streaming=True, **kwargs):
|
| 364 |
"""
|
|
@@ -369,7 +388,8 @@ def load_maybe_streaming_dataset(dataset_name, dataset_config_name, split="train
|
|
| 369 |
if "+" in split:
|
| 370 |
# load multiple splits separated by the `+` symbol with streaming mode
|
| 371 |
dataset_splits = [
|
| 372 |
-
load_dataset(dataset_name, dataset_config_name,
|
|
|
|
| 373 |
for split_name in split.split("+")
|
| 374 |
]
|
| 375 |
# interleave multiple splits to form one dataset
|
|
@@ -377,7 +397,8 @@ def load_maybe_streaming_dataset(dataset_name, dataset_config_name, split="train
|
|
| 377 |
return interleaved_dataset
|
| 378 |
else:
|
| 379 |
# load a single split *with* streaming mode
|
| 380 |
-
dataset = load_dataset(
|
|
|
|
| 381 |
return dataset
|
| 382 |
|
| 383 |
|
|
@@ -394,7 +415,8 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
|
|
| 394 |
|
| 395 |
if drop_last:
|
| 396 |
steps_per_epoch = len(dataset) // batch_size
|
| 397 |
-
|
|
|
|
| 398 |
batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
|
| 399 |
else:
|
| 400 |
steps_per_epoch = math.ceil(len(dataset) / batch_size)
|
|
@@ -429,11 +451,13 @@ def create_learning_rate_fn(
|
|
| 429 |
num_train_steps: int, num_warmup_steps: int, learning_rate: float
|
| 430 |
) -> Callable[[int], jnp.array]:
|
| 431 |
"""Returns a linear warmup, linear_decay learning rate function."""
|
| 432 |
-
warmup_fn = optax.linear_schedule(
|
|
|
|
| 433 |
decay_fn = optax.linear_schedule(
|
| 434 |
init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
|
| 435 |
)
|
| 436 |
-
schedule_fn = optax.join_schedules(
|
|
|
|
| 437 |
return schedule_fn
|
| 438 |
|
| 439 |
|
|
@@ -442,18 +466,21 @@ def main():
|
|
| 442 |
# See all possible arguments in src/transformers/training_args.py
|
| 443 |
# or by passing the --help flag to this script.
|
| 444 |
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
| 445 |
-
parser = HfArgumentParser(
|
|
|
|
| 446 |
|
| 447 |
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
| 448 |
# If we pass only one argument to the script and it's the path to a json file,
|
| 449 |
# let's parse it to get our arguments.
|
| 450 |
-
model_args, data_args, training_args = parser.parse_json_file(
|
|
|
|
| 451 |
else:
|
| 452 |
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
| 453 |
|
| 454 |
# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
|
| 455 |
# information sent is the one passed as arguments along with your JAX/Flax versions.
|
| 456 |
-
send_example_telemetry("run_speech_recognition_seq2seq",
|
|
|
|
| 457 |
|
| 458 |
# 2. Setup logging
|
| 459 |
# Make one log on every process with the configuration for debugging.
|
|
@@ -464,7 +491,8 @@ def main():
|
|
| 464 |
)
|
| 465 |
# Set the verbosity to info of the Transformers logger.
|
| 466 |
# We only want one process per machine to log things on the screen.
|
| 467 |
-
logger.setLevel(logging.INFO if jax.process_index()
|
|
|
|
| 468 |
if jax.process_index() == 0:
|
| 469 |
datasets.utils.logging.set_verbosity_warning()
|
| 470 |
transformers.utils.logging.set_verbosity_info()
|
|
@@ -490,16 +518,18 @@ def main():
|
|
| 490 |
if training_args.push_to_hub:
|
| 491 |
if training_args.hub_model_id is None:
|
| 492 |
repo_name = get_full_repo_name(
|
| 493 |
-
Path(training_args.output_dir).absolute(
|
|
|
|
| 494 |
)
|
| 495 |
else:
|
| 496 |
repo_name = training_args.hub_model_id
|
| 497 |
create_repo(repo_name, exist_ok=True, token=training_args.hub_token)
|
| 498 |
-
repo = Repository(training_args.output_dir,
|
|
|
|
| 499 |
|
| 500 |
# 3. Load dataset
|
| 501 |
raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
|
| 502 |
-
|
| 503 |
if training_args.do_train:
|
| 504 |
raw_datasets["train"] = load_maybe_streaming_dataset(
|
| 505 |
data_args.dataset_name,
|
|
@@ -519,13 +549,14 @@ def main():
|
|
| 519 |
streaming=data_args.streaming,
|
| 520 |
use_auth_token=True if model_args.use_auth_token else None,
|
| 521 |
)
|
| 522 |
-
|
| 523 |
if not training_args.do_train and not training_args.do_eval:
|
| 524 |
raise ValueError(
|
| 525 |
"Cannot not train and not do evaluation. At least one of training or evaluation has to be performed."
|
| 526 |
)
|
| 527 |
|
| 528 |
-
raw_datasets_features = list(
|
|
|
|
| 529 |
|
| 530 |
if data_args.audio_column_name not in raw_datasets_features:
|
| 531 |
raise ValueError(
|
|
@@ -572,21 +603,26 @@ def main():
|
|
| 572 |
)
|
| 573 |
|
| 574 |
if model.config.decoder_start_token_id is None:
|
| 575 |
-
raise ValueError(
|
|
|
|
| 576 |
|
| 577 |
# 6. Resample speech dataset: `datasets` takes care of automatically loading and resampling the audio,
|
| 578 |
# so we just need to set the correct target sampling rate.
|
| 579 |
-
dataset_sampling_rate = next(
|
| 580 |
-
|
|
|
|
| 581 |
if dataset_sampling_rate != feature_extractor.sampling_rate:
|
| 582 |
raw_datasets = raw_datasets.cast_column(
|
| 583 |
-
data_args.audio_column_name, datasets.features.Audio(
|
|
|
|
| 584 |
)
|
| 585 |
|
| 586 |
# 7. Preprocessing the datasets.
|
| 587 |
# We need to read the audio files as arrays and tokenize the targets.
|
| 588 |
-
max_input_length = int(
|
| 589 |
-
|
|
|
|
|
|
|
| 590 |
max_label_length = (
|
| 591 |
data_args.max_label_length if data_args.max_label_length is not None else model.config.max_length
|
| 592 |
)
|
|
@@ -602,18 +638,21 @@ def main():
|
|
| 602 |
|
| 603 |
if data_args.language is not None:
|
| 604 |
# We only need to set the task id when the language is specified (i.e. in a multilingual setting)
|
| 605 |
-
tokenizer.set_prefix_tokens(
|
|
|
|
| 606 |
|
| 607 |
def prepare_dataset(batch):
|
| 608 |
# process audio
|
| 609 |
sample = batch[audio_column_name]
|
| 610 |
-
inputs = feature_extractor(
|
|
|
|
| 611 |
# process audio length
|
| 612 |
batch[model_input_name] = inputs.get(model_input_name)[0]
|
| 613 |
batch["input_length"] = len(sample["array"])
|
| 614 |
|
| 615 |
# process targets
|
| 616 |
-
input_str = batch[text_column_name].lower(
|
|
|
|
| 617 |
if do_remove_punctuation:
|
| 618 |
input_str = normalizer(input_str).strip()
|
| 619 |
batch["labels"] = tokenizer(input_str).input_ids
|
|
@@ -624,7 +663,7 @@ def main():
|
|
| 624 |
prepare_dataset,
|
| 625 |
remove_columns=raw_datasets_features,
|
| 626 |
).with_format("torch")
|
| 627 |
-
|
| 628 |
# filter training data with inputs longer than max_input_length
|
| 629 |
def is_audio_in_length_range(length):
|
| 630 |
return min_input_length < length < max_input_length
|
|
@@ -634,14 +673,13 @@ def main():
|
|
| 634 |
is_audio_in_length_range,
|
| 635 |
input_columns=["input_length"],
|
| 636 |
)
|
| 637 |
-
|
| 638 |
if training_args.do_eval:
|
| 639 |
vectorized_datasets["eval"] = vectorized_datasets["eval"].filter(
|
| 640 |
is_audio_in_length_range,
|
| 641 |
input_columns=["input_length"],
|
| 642 |
)
|
| 643 |
|
| 644 |
-
|
| 645 |
# 8. Load Metric
|
| 646 |
metric = evaluate.load("wer")
|
| 647 |
do_normalize_eval = data_args.do_normalize_eval
|
|
@@ -660,8 +698,10 @@ def main():
|
|
| 660 |
pred_str = [normalizer(pred) for pred in pred_str]
|
| 661 |
label_str = [normalizer(label) for label in label_str]
|
| 662 |
# filtering step to only evaluate the samples that correspond to non-zero references:
|
| 663 |
-
pred_str = [pred_str[i]
|
| 664 |
-
|
|
|
|
|
|
|
| 665 |
|
| 666 |
wer = 100 * metric.compute(predictions=pred_str, references=label_str)
|
| 667 |
|
|
@@ -690,7 +730,8 @@ def main():
|
|
| 690 |
try:
|
| 691 |
from flax.metrics.tensorboard import SummaryWriter
|
| 692 |
|
| 693 |
-
summary_writer = SummaryWriter(
|
|
|
|
| 694 |
except ImportError as ie:
|
| 695 |
has_tensorboard = False
|
| 696 |
logger.warning(
|
|
@@ -708,10 +749,10 @@ def main():
|
|
| 708 |
|
| 709 |
# Store some constant
|
| 710 |
#num_epochs = int(training_args.num_train_epochs)
|
| 711 |
-
train_batch_size = int(
|
| 712 |
-
|
| 713 |
-
|
| 714 |
-
|
| 715 |
|
| 716 |
# Create learning rate schedule
|
| 717 |
linear_decay_lr_schedule_fn = create_learning_rate_fn(
|
|
@@ -736,7 +777,8 @@ def main():
|
|
| 736 |
if layer_norm_name in "".join(layer).lower()
|
| 737 |
]
|
| 738 |
)
|
| 739 |
-
flat_mask = {path: (path[-1] != "bias" and path[-2:]
|
|
|
|
| 740 |
return traverse_util.unflatten_dict(flat_mask)
|
| 741 |
|
| 742 |
# create adam optimizer
|
|
@@ -750,7 +792,8 @@ def main():
|
|
| 750 |
)
|
| 751 |
|
| 752 |
# Setup train state
|
| 753 |
-
state = TrainState.create(
|
|
|
|
| 754 |
|
| 755 |
# label smoothed cross entropy
|
| 756 |
def loss_fn(logits, labels, label_smoothing_factor=0.0):
|
|
@@ -762,9 +805,11 @@ def main():
|
|
| 762 |
confidence = 1.0 - label_smoothing_factor
|
| 763 |
low_confidence = (1.0 - confidence) / (vocab_size - 1)
|
| 764 |
normalizing_constant = -(
|
| 765 |
-
confidence * jnp.log(confidence) + (vocab_size - 1) *
|
|
|
|
| 766 |
)
|
| 767 |
-
soft_labels = onehot(labels, vocab_size,
|
|
|
|
| 768 |
|
| 769 |
loss = optax.softmax_cross_entropy(logits, soft_labels)
|
| 770 |
loss = loss - normalizing_constant
|
|
@@ -782,7 +827,8 @@ def main():
|
|
| 782 |
|
| 783 |
def compute_loss(params):
|
| 784 |
labels = batch.pop("labels")
|
| 785 |
-
logits = state.apply_fn(
|
|
|
|
| 786 |
loss, num_labels = loss_fn(logits, labels, label_smoothing_factor)
|
| 787 |
return loss, num_labels
|
| 788 |
|
|
@@ -797,9 +843,11 @@ def main():
|
|
| 797 |
# true grad = total grad / total samples
|
| 798 |
grad = jax.lax.psum(grad, "batch")
|
| 799 |
grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
|
| 800 |
-
new_state = state.apply_gradients(
|
|
|
|
| 801 |
|
| 802 |
-
metrics = {"loss": loss,
|
|
|
|
| 803 |
return new_state, metrics
|
| 804 |
|
| 805 |
# Define eval fn
|
|
@@ -823,27 +871,32 @@ def main():
|
|
| 823 |
|
| 824 |
def generate_step(params, batch):
|
| 825 |
model.params = params
|
| 826 |
-
output_ids = model.generate(batch[model_input_name], attention_mask=batch.get(
|
|
|
|
| 827 |
return output_ids.sequences
|
| 828 |
|
| 829 |
# Create parallel version of the train and eval step
|
| 830 |
p_train_step = jax.pmap(
|
| 831 |
partial(train_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch", donate_argnums=(0,)
|
| 832 |
)
|
| 833 |
-
p_eval_step = jax.pmap(partial(
|
|
|
|
| 834 |
p_generate_step = jax.pmap(generate_step, "batch")
|
| 835 |
|
| 836 |
# Replicate the train state on each device
|
| 837 |
state = state.replicate()
|
| 838 |
|
| 839 |
logger.info("***** Running training *****")
|
| 840 |
-
logger.info(
|
| 841 |
-
|
| 842 |
-
logger.info(
|
|
|
|
|
|
|
|
|
|
| 843 |
logger.info(f" Total optimization steps = {data_args.num_train_steps}")
|
| 844 |
|
| 845 |
train_time = 0
|
| 846 |
-
|
| 847 |
# ======================== Training ================================
|
| 848 |
train_start = time.time()
|
| 849 |
|
|
@@ -859,29 +912,32 @@ def main():
|
|
| 859 |
num_workers = 0
|
| 860 |
# This is not working
|
| 861 |
# vectorized_datasets["train"] = vectorized_datasets["train"].shuffle()
|
| 862 |
-
train_data_loader = torch.utils.data.DataLoader(
|
| 863 |
-
|
| 864 |
-
|
| 865 |
-
|
|
|
|
| 866 |
# train
|
| 867 |
-
for step in tqdm(range(data_args.num_train_steps), desc="Training...", position=1, leave=False):
|
| 868 |
-
|
| 869 |
try:
|
| 870 |
samples = next(train_data_iterator)
|
| 871 |
-
|
| 872 |
except StopIteration:
|
| 873 |
epoch += 1
|
| 874 |
-
train_data_loader = torch.utils.data.DataLoader(
|
| 875 |
-
|
|
|
|
|
|
|
| 876 |
samples = next(train_data_iterator)
|
| 877 |
-
|
| 878 |
logger.info(
|
| 879 |
f"Completed epoch ({epoch} | Loss: {train_metric['loss']}, Learning Rate:"
|
| 880 |
f" {train_metric['learning_rate']})"
|
| 881 |
)
|
| 882 |
-
|
| 883 |
# reshaped_samples = {key: [feature[key] for feature in samples] for key in samples[0].keys()}
|
| 884 |
-
#breakpoint()
|
| 885 |
batch = data_collator(samples)
|
| 886 |
batch = shard(batch.data)
|
| 887 |
state, train_metric = p_train_step(state, batch)
|
|
@@ -896,8 +952,10 @@ def main():
|
|
| 896 |
eval_labels = []
|
| 897 |
|
| 898 |
#eval_loader = data_loader(input_rng, vectorized_datasets["eval"], eval_batch_size, drop_last=False)
|
| 899 |
-
eval_data_loader = torch.utils.data.DataLoader(
|
| 900 |
-
|
|
|
|
|
|
|
| 901 |
|
| 902 |
for _ in tqdm(range(training_args.eval_steps), desc="Evaluating...", position=2, leave=False):
|
| 903 |
# Model forward
|
|
@@ -912,10 +970,12 @@ def main():
|
|
| 912 |
|
| 913 |
# generation
|
| 914 |
if training_args.predict_with_generate:
|
| 915 |
-
generated_ids = pad_shard_unpad(
|
| 916 |
-
|
|
|
|
|
|
|
| 917 |
eval_labels.extend(labels)
|
| 918 |
-
|
| 919 |
# normalize eval metrics
|
| 920 |
eval_metrics = get_metrics(eval_metrics)
|
| 921 |
eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
|
|
@@ -925,7 +985,8 @@ def main():
|
|
| 925 |
if training_args.predict_with_generate:
|
| 926 |
wer_metric = compute_metrics(eval_preds, eval_labels)
|
| 927 |
eval_metrics.update(wer_metric)
|
| 928 |
-
wer_desc = " ".join(
|
|
|
|
| 929 |
|
| 930 |
# Print metrics
|
| 931 |
desc = f"Epoch... ({epoch} | Eval Loss: {eval_metrics['loss']} | {wer_desc})"
|
|
@@ -933,15 +994,18 @@ def main():
|
|
| 933 |
|
| 934 |
# Save metrics
|
| 935 |
if has_tensorboard and jax.process_index() == 0:
|
| 936 |
-
write_metric(summary_writer, train_metrics,
|
|
|
|
| 937 |
|
| 938 |
# save checkpoint after each epoch and push checkpoint to the hub
|
| 939 |
if jax.process_index() == 0:
|
| 940 |
-
params = jax.device_get(
|
|
|
|
| 941 |
model.save_pretrained(training_args.output_dir, params=params)
|
| 942 |
tokenizer.save_pretrained(training_args.output_dir)
|
| 943 |
if training_args.push_to_hub:
|
| 944 |
-
repo.push_to_hub(
|
|
|
|
| 945 |
|
| 946 |
|
| 947 |
if __name__ == "__main__":
|
|
|
|
| 35 |
import numpy as np
|
| 36 |
import optax
|
| 37 |
import torch
|
| 38 |
+
from datasets import Dataset, DatasetDict, IterableDatasetDict, interleave_datasets, load_dataset
|
| 39 |
from torch.utils.data import IterableDataset
|
| 40 |
from flax import jax_utils, traverse_util
|
| 41 |
from flax.jax_utils import pad_shard_unpad, unreplicate
|
|
|
|
| 66 |
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
| 67 |
check_min_version("4.27.0.dev0")
|
| 68 |
|
| 69 |
+
require_version("datasets>=1.18.2",
|
| 70 |
+
"To fix: pip install -r examples/flax/speech-recogintion/requirements.txt")
|
| 71 |
|
| 72 |
logger = logging.getLogger(__name__)
|
| 73 |
|
|
|
|
| 79 |
"""
|
| 80 |
|
| 81 |
model_name_or_path: str = field(
|
| 82 |
+
metadata={
|
| 83 |
+
"help": "Path to pretrained model or model identifier from huggingface.co/models"}
|
| 84 |
)
|
| 85 |
config_name: Optional[str] = field(
|
| 86 |
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
|
|
|
| 93 |
)
|
| 94 |
cache_dir: Optional[str] = field(
|
| 95 |
default=None,
|
| 96 |
+
metadata={
|
| 97 |
+
"help": "Where to store the pretrained models downloaded from huggingface.co"},
|
| 98 |
)
|
| 99 |
use_fast_tokenizer: bool = field(
|
| 100 |
default=True,
|
| 101 |
+
metadata={
|
| 102 |
+
"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
|
| 103 |
)
|
| 104 |
model_revision: str = field(
|
| 105 |
default="main",
|
| 106 |
+
metadata={
|
| 107 |
+
"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
|
| 108 |
)
|
| 109 |
use_auth_token: bool = field(
|
| 110 |
default=False,
|
|
|
|
| 147 |
)
|
| 148 |
text_column: Optional[str] = field(
|
| 149 |
default=None,
|
| 150 |
+
metadata={
|
| 151 |
+
"help": "The name of the column in the datasets containing the full texts (for summarization)."},
|
| 152 |
)
|
| 153 |
dataset_cache_dir: Optional[str] = field(
|
| 154 |
default=None, metadata={"help": "Path to cache directory for saving and loading datasets"}
|
|
|
|
| 176 |
)
|
| 177 |
audio_column_name: str = field(
|
| 178 |
default="audio",
|
| 179 |
+
metadata={
|
| 180 |
+
"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
|
| 181 |
)
|
| 182 |
text_column_name: str = field(
|
| 183 |
default="text",
|
| 184 |
+
metadata={
|
| 185 |
+
"help": "The name of the dataset column containing the text data. Defaults to 'text'"},
|
| 186 |
)
|
| 187 |
max_duration_in_seconds: float = field(
|
| 188 |
default=30.0,
|
| 189 |
+
metadata={
|
| 190 |
+
"help": "Filter audio files that are longer than `max_duration_in_seconds` seconds"},
|
| 191 |
)
|
| 192 |
min_duration_in_seconds: float = field(
|
| 193 |
default=0.0,
|
| 194 |
+
metadata={
|
| 195 |
+
"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"},
|
| 196 |
)
|
| 197 |
max_label_length: float = field(
|
| 198 |
default=128,
|
| 199 |
+
metadata={
|
| 200 |
+
"help": "Truncate transcriptions that are longer `max_eval_length` tokens."},
|
| 201 |
)
|
| 202 |
pad_input_to_multiple_of: Optional[int] = field(
|
| 203 |
default=None,
|
|
|
|
| 240 |
)
|
| 241 |
do_remove_punctuation: bool = field(
|
| 242 |
default=False,
|
| 243 |
+
metadata={
|
| 244 |
+
"help": "Whether the target text should be striped of punctuation."},
|
| 245 |
)
|
| 246 |
do_normalize_eval: bool = field(
|
| 247 |
default=True,
|
| 248 |
+
metadata={
|
| 249 |
+
"help": "Whether to normalise the references and predictions in the eval WER calculation."},
|
| 250 |
)
|
| 251 |
language: str = field(
|
| 252 |
default=None,
|
|
|
|
| 259 |
)
|
| 260 |
task: str = field(
|
| 261 |
default="transcribe",
|
| 262 |
+
metadata={
|
| 263 |
+
"help": "Task, either `transcribe` for speech recognition or `translate` for speech translation."},
|
| 264 |
)
|
| 265 |
+
num_train_steps: int = field(default=50000, metadata={
|
| 266 |
+
"help": "The number of training steps."})
|
| 267 |
# num_eval_samples: int = field(default=50000, metadata={"help": "The number of samples to be used for evaluation"})
|
| 268 |
shuffle_buffer_size: Optional[int] = field(
|
| 269 |
default=500,
|
|
|
|
| 276 |
)
|
| 277 |
streaming: bool = field(
|
| 278 |
default=True,
|
| 279 |
+
metadata={
|
| 280 |
+
"help": "Whether to use streaming mode to load and pre-process the data."},
|
| 281 |
)
|
| 282 |
|
| 283 |
+
|
| 284 |
def shift_tokens_right(label_ids: np.array, decoder_start_token_id: int) -> np.ndarray:
|
| 285 |
"""
|
| 286 |
Shift label ids one token to the right.
|
|
|
|
| 365 |
labels = labels[:, 1:]
|
| 366 |
labels_batch.attention_mask = labels_batch.attention_mask[:, 1:]
|
| 367 |
|
| 368 |
+
decoder_input_ids = shift_tokens_right(
|
| 369 |
+
labels, self.decoder_start_token_id)
|
| 370 |
|
| 371 |
# replace padding with -100 to ignore correctly when computing the loss
|
| 372 |
+
labels = np.ma.array(labels, mask=np.not_equal(
|
| 373 |
+
labels_batch.attention_mask, 1))
|
| 374 |
labels = labels.filled(fill_value=-100)
|
| 375 |
|
| 376 |
batch["labels"] = labels
|
| 377 |
batch["decoder_input_ids"] = decoder_input_ids
|
| 378 |
|
| 379 |
return batch
|
| 380 |
+
|
| 381 |
|
| 382 |
def load_maybe_streaming_dataset(dataset_name, dataset_config_name, split="train", streaming=True, **kwargs):
|
| 383 |
"""
|
|
|
|
| 388 |
if "+" in split:
|
| 389 |
# load multiple splits separated by the `+` symbol with streaming mode
|
| 390 |
dataset_splits = [
|
| 391 |
+
load_dataset(dataset_name, dataset_config_name,
|
| 392 |
+
split=split_name, streaming=streaming, **kwargs)
|
| 393 |
for split_name in split.split("+")
|
| 394 |
]
|
| 395 |
# interleave multiple splits to form one dataset
|
|
|
|
| 397 |
return interleaved_dataset
|
| 398 |
else:
|
| 399 |
# load a single split *with* streaming mode
|
| 400 |
+
dataset = load_dataset(
|
| 401 |
+
dataset_name, dataset_config_name, split=split, streaming=streaming, **kwargs)
|
| 402 |
return dataset
|
| 403 |
|
| 404 |
|
|
|
|
| 415 |
|
| 416 |
if drop_last:
|
| 417 |
steps_per_epoch = len(dataset) // batch_size
|
| 418 |
+
# Skip incomplete batch.
|
| 419 |
+
batch_idx = batch_idx[: steps_per_epoch * batch_size]
|
| 420 |
batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
|
| 421 |
else:
|
| 422 |
steps_per_epoch = math.ceil(len(dataset) / batch_size)
|
|
|
|
| 451 |
num_train_steps: int, num_warmup_steps: int, learning_rate: float
|
| 452 |
) -> Callable[[int], jnp.array]:
|
| 453 |
"""Returns a linear warmup, linear_decay learning rate function."""
|
| 454 |
+
warmup_fn = optax.linear_schedule(
|
| 455 |
+
init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
|
| 456 |
decay_fn = optax.linear_schedule(
|
| 457 |
init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
|
| 458 |
)
|
| 459 |
+
schedule_fn = optax.join_schedules(
|
| 460 |
+
schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
|
| 461 |
return schedule_fn
|
| 462 |
|
| 463 |
|
|
|
|
| 466 |
# See all possible arguments in src/transformers/training_args.py
|
| 467 |
# or by passing the --help flag to this script.
|
| 468 |
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
| 469 |
+
parser = HfArgumentParser(
|
| 470 |
+
(ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
|
| 471 |
|
| 472 |
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
| 473 |
# If we pass only one argument to the script and it's the path to a json file,
|
| 474 |
# let's parse it to get our arguments.
|
| 475 |
+
model_args, data_args, training_args = parser.parse_json_file(
|
| 476 |
+
json_file=os.path.abspath(sys.argv[1]))
|
| 477 |
else:
|
| 478 |
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
| 479 |
|
| 480 |
# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
|
| 481 |
# information sent is the one passed as arguments along with your JAX/Flax versions.
|
| 482 |
+
send_example_telemetry("run_speech_recognition_seq2seq",
|
| 483 |
+
model_args, data_args, framework="flax")
|
| 484 |
|
| 485 |
# 2. Setup logging
|
| 486 |
# Make one log on every process with the configuration for debugging.
|
|
|
|
| 491 |
)
|
| 492 |
# Set the verbosity to info of the Transformers logger.
|
| 493 |
# We only want one process per machine to log things on the screen.
|
| 494 |
+
logger.setLevel(logging.INFO if jax.process_index()
|
| 495 |
+
== 0 else logging.ERROR)
|
| 496 |
if jax.process_index() == 0:
|
| 497 |
datasets.utils.logging.set_verbosity_warning()
|
| 498 |
transformers.utils.logging.set_verbosity_info()
|
|
|
|
| 518 |
if training_args.push_to_hub:
|
| 519 |
if training_args.hub_model_id is None:
|
| 520 |
repo_name = get_full_repo_name(
|
| 521 |
+
Path(training_args.output_dir).absolute(
|
| 522 |
+
).name, token=training_args.hub_token
|
| 523 |
)
|
| 524 |
else:
|
| 525 |
repo_name = training_args.hub_model_id
|
| 526 |
create_repo(repo_name, exist_ok=True, token=training_args.hub_token)
|
| 527 |
+
repo = Repository(training_args.output_dir,
|
| 528 |
+
clone_from=repo_name, token=training_args.hub_token)
|
| 529 |
|
| 530 |
# 3. Load dataset
|
| 531 |
raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
|
| 532 |
+
|
| 533 |
if training_args.do_train:
|
| 534 |
raw_datasets["train"] = load_maybe_streaming_dataset(
|
| 535 |
data_args.dataset_name,
|
|
|
|
| 549 |
streaming=data_args.streaming,
|
| 550 |
use_auth_token=True if model_args.use_auth_token else None,
|
| 551 |
)
|
| 552 |
+
|
| 553 |
if not training_args.do_train and not training_args.do_eval:
|
| 554 |
raise ValueError(
|
| 555 |
"Cannot not train and not do evaluation. At least one of training or evaluation has to be performed."
|
| 556 |
)
|
| 557 |
|
| 558 |
+
raw_datasets_features = list(
|
| 559 |
+
next(iter(raw_datasets.values())).features.keys())
|
| 560 |
|
| 561 |
if data_args.audio_column_name not in raw_datasets_features:
|
| 562 |
raise ValueError(
|
|
|
|
| 603 |
)
|
| 604 |
|
| 605 |
if model.config.decoder_start_token_id is None:
|
| 606 |
+
raise ValueError(
|
| 607 |
+
"Make sure that `config.decoder_start_token_id` is correctly defined")
|
| 608 |
|
| 609 |
# 6. Resample speech dataset: `datasets` takes care of automatically loading and resampling the audio,
|
| 610 |
# so we just need to set the correct target sampling rate.
|
| 611 |
+
dataset_sampling_rate = next(
|
| 612 |
+
iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
|
| 613 |
+
|
| 614 |
if dataset_sampling_rate != feature_extractor.sampling_rate:
|
| 615 |
raw_datasets = raw_datasets.cast_column(
|
| 616 |
+
data_args.audio_column_name, datasets.features.Audio(
|
| 617 |
+
sampling_rate=feature_extractor.sampling_rate)
|
| 618 |
)
|
| 619 |
|
| 620 |
# 7. Preprocessing the datasets.
|
| 621 |
# We need to read the audio files as arrays and tokenize the targets.
|
| 622 |
+
max_input_length = int(
|
| 623 |
+
data_args.max_duration_in_seconds * feature_extractor.sampling_rate)
|
| 624 |
+
min_input_length = int(
|
| 625 |
+
data_args.min_duration_in_seconds * feature_extractor.sampling_rate)
|
| 626 |
max_label_length = (
|
| 627 |
data_args.max_label_length if data_args.max_label_length is not None else model.config.max_length
|
| 628 |
)
|
|
|
|
| 638 |
|
| 639 |
if data_args.language is not None:
|
| 640 |
# We only need to set the task id when the language is specified (i.e. in a multilingual setting)
|
| 641 |
+
tokenizer.set_prefix_tokens(
|
| 642 |
+
language=data_args.language, task=data_args.task)
|
| 643 |
|
| 644 |
def prepare_dataset(batch):
|
| 645 |
# process audio
|
| 646 |
sample = batch[audio_column_name]
|
| 647 |
+
inputs = feature_extractor(
|
| 648 |
+
sample["array"], sampling_rate=sample["sampling_rate"])
|
| 649 |
# process audio length
|
| 650 |
batch[model_input_name] = inputs.get(model_input_name)[0]
|
| 651 |
batch["input_length"] = len(sample["array"])
|
| 652 |
|
| 653 |
# process targets
|
| 654 |
+
input_str = batch[text_column_name].lower(
|
| 655 |
+
) if do_lower_case else batch[text_column_name]
|
| 656 |
if do_remove_punctuation:
|
| 657 |
input_str = normalizer(input_str).strip()
|
| 658 |
batch["labels"] = tokenizer(input_str).input_ids
|
|
|
|
| 663 |
prepare_dataset,
|
| 664 |
remove_columns=raw_datasets_features,
|
| 665 |
).with_format("torch")
|
| 666 |
+
|
| 667 |
# filter training data with inputs longer than max_input_length
|
| 668 |
def is_audio_in_length_range(length):
|
| 669 |
return min_input_length < length < max_input_length
|
|
|
|
| 673 |
is_audio_in_length_range,
|
| 674 |
input_columns=["input_length"],
|
| 675 |
)
|
| 676 |
+
|
| 677 |
if training_args.do_eval:
|
| 678 |
vectorized_datasets["eval"] = vectorized_datasets["eval"].filter(
|
| 679 |
is_audio_in_length_range,
|
| 680 |
input_columns=["input_length"],
|
| 681 |
)
|
| 682 |
|
|
|
|
| 683 |
# 8. Load Metric
|
| 684 |
metric = evaluate.load("wer")
|
| 685 |
do_normalize_eval = data_args.do_normalize_eval
|
|
|
|
| 698 |
pred_str = [normalizer(pred) for pred in pred_str]
|
| 699 |
label_str = [normalizer(label) for label in label_str]
|
| 700 |
# filtering step to only evaluate the samples that correspond to non-zero references:
|
| 701 |
+
pred_str = [pred_str[i]
|
| 702 |
+
for i in range(len(pred_str)) if len(label_str[i]) > 0]
|
| 703 |
+
label_str = [label_str[i]
|
| 704 |
+
for i in range(len(label_str)) if len(label_str[i]) > 0]
|
| 705 |
|
| 706 |
wer = 100 * metric.compute(predictions=pred_str, references=label_str)
|
| 707 |
|
|
|
|
| 730 |
try:
|
| 731 |
from flax.metrics.tensorboard import SummaryWriter
|
| 732 |
|
| 733 |
+
summary_writer = SummaryWriter(
|
| 734 |
+
log_dir=Path(training_args.output_dir))
|
| 735 |
except ImportError as ie:
|
| 736 |
has_tensorboard = False
|
| 737 |
logger.warning(
|
|
|
|
| 749 |
|
| 750 |
# Store some constant
|
| 751 |
#num_epochs = int(training_args.num_train_epochs)
|
| 752 |
+
train_batch_size = int(
|
| 753 |
+
training_args.per_device_train_batch_size) * jax.device_count()
|
| 754 |
+
eval_batch_size = int(
|
| 755 |
+
training_args.per_device_eval_batch_size) * jax.device_count()
|
| 756 |
|
| 757 |
# Create learning rate schedule
|
| 758 |
linear_decay_lr_schedule_fn = create_learning_rate_fn(
|
|
|
|
| 777 |
if layer_norm_name in "".join(layer).lower()
|
| 778 |
]
|
| 779 |
)
|
| 780 |
+
flat_mask = {path: (path[-1] != "bias" and path[-2:]
|
| 781 |
+
not in layer_norm_named_params) for path in flat_params}
|
| 782 |
return traverse_util.unflatten_dict(flat_mask)
|
| 783 |
|
| 784 |
# create adam optimizer
|
|
|
|
| 792 |
)
|
| 793 |
|
| 794 |
# Setup train state
|
| 795 |
+
state = TrainState.create(
|
| 796 |
+
apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
|
| 797 |
|
| 798 |
# label smoothed cross entropy
|
| 799 |
def loss_fn(logits, labels, label_smoothing_factor=0.0):
|
|
|
|
| 805 |
confidence = 1.0 - label_smoothing_factor
|
| 806 |
low_confidence = (1.0 - confidence) / (vocab_size - 1)
|
| 807 |
normalizing_constant = -(
|
| 808 |
+
confidence * jnp.log(confidence) + (vocab_size - 1) *
|
| 809 |
+
low_confidence * jnp.log(low_confidence + 1e-20)
|
| 810 |
)
|
| 811 |
+
soft_labels = onehot(labels, vocab_size,
|
| 812 |
+
on_value=confidence, off_value=low_confidence)
|
| 813 |
|
| 814 |
loss = optax.softmax_cross_entropy(logits, soft_labels)
|
| 815 |
loss = loss - normalizing_constant
|
|
|
|
| 827 |
|
| 828 |
def compute_loss(params):
    # One forward pass with dropout enabled, followed by the label-smoothed
    # cross-entropy. Closes over `batch`, `state`, `dropout_rng` and
    # `label_smoothing_factor` from the enclosing train step.
    # NOTE: pops "labels" out of `batch` so the remaining keys can be
    # splatted straight into the model call.
    targets = batch.pop("labels")
    model_outputs = state.apply_fn(
        **batch, params=params, dropout_rng=dropout_rng, train=True)
    logits = model_outputs[0]
    loss, num_labels = loss_fn(logits, targets, label_smoothing_factor)
    return loss, num_labels
|
| 834 |
|
|
|
|
| 843 |
# true grad = total grad / total samples
|
| 844 |
grad = jax.lax.psum(grad, "batch")
|
| 845 |
grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
|
| 846 |
+
new_state = state.apply_gradients(
|
| 847 |
+
grads=grad, dropout_rng=new_dropout_rng)
|
| 848 |
|
| 849 |
+
metrics = {"loss": loss,
|
| 850 |
+
"learning_rate": linear_decay_lr_schedule_fn(state.step)}
|
| 851 |
return new_state, metrics
|
| 852 |
|
| 853 |
# Define eval fn
|
|
|
|
| 871 |
|
| 872 |
def generate_step(params, batch):
    # Autoregressive generation for one device shard. Closes over `model`,
    # `model_input_name` and `gen_kwargs`; the replicated params are written
    # back onto the model before generating.
    model.params = params
    input_features = batch[model_input_name]
    mask = batch.get("attention_mask")
    generated = model.generate(input_features, attention_mask=mask, **gen_kwargs)
    return generated.sequences
|
| 877 |
|
| 878 |
# Create parallel version of the train and eval step
|
| 879 |
p_train_step = jax.pmap(
|
| 880 |
partial(train_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch", donate_argnums=(0,)
|
| 881 |
)
|
| 882 |
+
p_eval_step = jax.pmap(partial(
|
| 883 |
+
eval_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch")
|
| 884 |
p_generate_step = jax.pmap(generate_step, "batch")
|
| 885 |
|
| 886 |
# Replicate the train state on each device
|
| 887 |
state = state.replicate()
|
| 888 |
|
| 889 |
logger.info("***** Running training *****")
|
| 890 |
+
logger.info(
|
| 891 |
+
f" Num examples = {data_args.num_train_steps*train_batch_size}")
|
| 892 |
+
logger.info(
|
| 893 |
+
f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
|
| 894 |
+
logger.info(
|
| 895 |
+
f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
|
| 896 |
logger.info(f" Total optimization steps = {data_args.num_train_steps}")
|
| 897 |
|
| 898 |
train_time = 0
|
| 899 |
+
|
| 900 |
# ======================== Training ================================
|
| 901 |
train_start = time.time()
|
| 902 |
|
|
|
|
| 912 |
num_workers = 0
|
| 913 |
# This is not working
|
| 914 |
# vectorized_datasets["train"] = vectorized_datasets["train"].shuffle()
|
| 915 |
+
train_data_loader = torch.utils.data.DataLoader(
|
| 916 |
+
batch_size=train_batch_size, dataset=vectorized_datasets["train"], num_workers=num_workers, collate_fn=collate_batch, drop_last=True)
|
| 917 |
+
train_data_iterator = torch.utils.data.dataloader._SingleProcessDataLoaderIter(
|
| 918 |
+
train_data_loader)
|
| 919 |
+
|
| 920 |
# train
|
| 921 |
+
for step in tqdm(range(data_args.num_train_steps), desc="Training...", position=1, leave=False):
|
| 922 |
+
|
| 923 |
try:
|
| 924 |
samples = next(train_data_iterator)
|
| 925 |
+
|
| 926 |
except StopIteration:
|
| 927 |
epoch += 1
|
| 928 |
+
train_data_loader = torch.utils.data.DataLoader(
|
| 929 |
+
batch_size=train_batch_size, dataset=vectorized_datasets["train"], num_workers=num_workers, collate_fn=collate_batch, drop_last=True)
|
| 930 |
+
train_data_iterator = torch.utils.data.dataloader._SingleProcessDataLoaderIter(
|
| 931 |
+
train_data_loader)
|
| 932 |
samples = next(train_data_iterator)
|
| 933 |
+
|
| 934 |
logger.info(
|
| 935 |
f"Completed epoch ({epoch} | Loss: {train_metric['loss']}, Learning Rate:"
|
| 936 |
f" {train_metric['learning_rate']})"
|
| 937 |
)
|
| 938 |
+
|
| 939 |
# reshaped_samples = {key: [feature[key] for feature in samples] for key in samples[0].keys()}
|
| 940 |
+
# breakpoint()
|
| 941 |
batch = data_collator(samples)
|
| 942 |
batch = shard(batch.data)
|
| 943 |
state, train_metric = p_train_step(state, batch)
|
|
|
|
| 952 |
eval_labels = []
|
| 953 |
|
| 954 |
#eval_loader = data_loader(input_rng, vectorized_datasets["eval"], eval_batch_size, drop_last=False)
|
| 955 |
+
eval_data_loader = torch.utils.data.DataLoader(
|
| 956 |
+
batch_size=eval_batch_size, dataset=vectorized_datasets["eval"], num_workers=num_workers, collate_fn=collate_batch, drop_last=False)
|
| 957 |
+
eval_data_iterator = torch.utils.data.dataloader._SingleProcessDataLoaderIter(
|
| 958 |
+
eval_data_loader)
|
| 959 |
|
| 960 |
for _ in tqdm(range(training_args.eval_steps), desc="Evaluating...", position=2, leave=False):
|
| 961 |
# Model forward
|
|
|
|
| 970 |
|
| 971 |
# generation
|
| 972 |
if training_args.predict_with_generate:
|
| 973 |
+
generated_ids = pad_shard_unpad(
|
| 974 |
+
p_generate_step)(state.params, batch.data)
|
| 975 |
+
eval_preds.extend(jax.device_get(
|
| 976 |
+
generated_ids.reshape(-1, gen_kwargs["max_length"])))
|
| 977 |
eval_labels.extend(labels)
|
| 978 |
+
breakpoint()
|
| 979 |
# normalize eval metrics
|
| 980 |
eval_metrics = get_metrics(eval_metrics)
|
| 981 |
eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
|
|
|
|
| 985 |
if training_args.predict_with_generate:
|
| 986 |
wer_metric = compute_metrics(eval_preds, eval_labels)
|
| 987 |
eval_metrics.update(wer_metric)
|
| 988 |
+
wer_desc = " ".join(
|
| 989 |
+
[f"Eval {key}: {value} |" for key, value in wer_metric.items()])
|
| 990 |
|
| 991 |
# Print metrics
|
| 992 |
desc = f"Epoch... ({epoch} | Eval Loss: {eval_metrics['loss']} | {wer_desc})"
|
|
|
|
| 994 |
|
| 995 |
# Save metrics
|
| 996 |
if has_tensorboard and jax.process_index() == 0:
|
| 997 |
+
write_metric(summary_writer, train_metrics,
|
| 998 |
+
eval_metrics, train_time, step)
|
| 999 |
|
| 1000 |
# save checkpoint after each epoch and push checkpoint to the hub
|
| 1001 |
if jax.process_index() == 0:
|
| 1002 |
+
params = jax.device_get(
|
| 1003 |
+
jax.tree_util.tree_map(lambda x: x[0], state.params))
|
| 1004 |
model.save_pretrained(training_args.output_dir, params=params)
|
| 1005 |
tokenizer.save_pretrained(training_args.output_dir)
|
| 1006 |
if training_args.push_to_hub:
|
| 1007 |
+
repo.push_to_hub(
|
| 1008 |
+
commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)
|
| 1009 |
|
| 1010 |
|
| 1011 |
if __name__ == "__main__":
|
run_streaming.sh
CHANGED
|
@@ -2,6 +2,7 @@ python run_flax_speech_recognition_seq2seq_streaming_v3.py \
|
|
| 2 |
--model_name_or_path openai/whisper-tiny.en \
|
| 3 |
--dataset_name mozilla-foundation/common_voice_11_0 \
|
| 4 |
--dataset_config es \
|
|
|
|
| 5 |
--text_column_name sentence \
|
| 6 |
--train_split_name test\
|
| 7 |
--eval_split_name test\
|
|
|
|
| 2 |
--model_name_or_path openai/whisper-tiny.en \
|
| 3 |
--dataset_name mozilla-foundation/common_voice_11_0 \
|
| 4 |
--dataset_config es \
|
| 5 |
+
--language es \
|
| 6 |
--text_column_name sentence \
|
| 7 |
--train_split_name test\
|
| 8 |
--eval_split_name test\
|