Commit ·
6ea4d4a
1
Parent(s): ee5b1b2
allow multiple languages and datasets
Browse files
run_speech_recognition_seq2seq_streaming.py
CHANGED
|
@@ -49,6 +49,7 @@ from transformers import (
|
|
| 49 |
set_seed,
|
| 50 |
)
|
| 51 |
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
|
|
|
|
| 52 |
from transformers.trainer_pt_utils import IterableDatasetShard
|
| 53 |
from transformers.trainer_utils import get_last_checkpoint, is_main_process
|
| 54 |
from transformers.utils import check_min_version, send_example_telemetry
|
|
@@ -61,6 +62,9 @@ require_version("datasets>=1.18.2", "To fix: pip install -r examples/pytorch/spe
|
|
| 61 |
|
| 62 |
logger = logging.getLogger(__name__)
|
| 63 |
|
|
|
|
|
|
|
|
|
|
| 64 |
wandb_token = os.environ.get("WANDB_TOKEN", "None")
|
| 65 |
hf_token = os.environ.get("HF_TOKEN", None)
|
| 66 |
if (hf_token is None or wandb_token == "None") and os.path.exists("./creds.txt"):
|
|
@@ -160,10 +164,16 @@ class DataTrainingArguments:
|
|
| 160 |
Arguments pertaining to what data we are going to input our model for training and eval.
|
| 161 |
"""
|
| 162 |
|
| 163 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
| 165 |
)
|
| 166 |
-
|
| 167 |
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
| 168 |
)
|
| 169 |
text_column: Optional[str] = field(
|
|
@@ -232,7 +242,16 @@ class DataTrainingArguments:
|
|
| 232 |
default=True,
|
| 233 |
metadata={"help": "Whether to normalise the references and predictions in the eval WER calculation."},
|
| 234 |
)
|
| 235 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 236 |
default=None,
|
| 237 |
metadata={
|
| 238 |
"help": (
|
|
@@ -273,6 +292,7 @@ class DataCollatorSpeechSeq2SeqWithPadding:
|
|
| 273 |
|
| 274 |
processor: Any
|
| 275 |
decoder_start_token_id: int
|
|
|
|
| 276 |
|
| 277 |
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
| 278 |
# split inputs and labels since they have to be of different lengths and need
|
|
@@ -280,6 +300,7 @@ class DataCollatorSpeechSeq2SeqWithPadding:
|
|
| 280 |
model_input_name = self.processor.model_input_names[0]
|
| 281 |
input_features = [{model_input_name: feature[model_input_name]} for feature in features]
|
| 282 |
label_features = [{"input_ids": feature["labels"]} for feature in features]
|
|
|
|
| 283 |
|
| 284 |
batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
|
| 285 |
|
|
@@ -292,6 +313,15 @@ class DataCollatorSpeechSeq2SeqWithPadding:
|
|
| 292 |
# cut bos token here as it's append later anyways
|
| 293 |
if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
|
| 294 |
labels = labels[:, 1:]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 295 |
|
| 296 |
batch["labels"] = labels
|
| 297 |
|
|
@@ -316,7 +346,7 @@ def notify_me(recipient, message=None):
|
|
| 316 |
from email.mime.text import MIMEText
|
| 317 |
|
| 318 |
msg = MIMEText(message)
|
| 319 |
-
msg["Subject"] = "Training
|
| 320 |
msg["From"] = "marinone.auto@gmail.com"
|
| 321 |
msg["To"] = recipient
|
| 322 |
|
|
@@ -334,16 +364,26 @@ def load_maybe_streaming_dataset(dataset_names, dataset_config_names, split="tra
|
|
| 334 |
each split is loaded individually and then splits combined by taking alternating examples from
|
| 335 |
each (interleaving).
|
| 336 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 337 |
if "," in dataset_names or "+" in split:
|
| 338 |
# load multiple splits separated by the `+` symbol with streaming mode
|
| 339 |
dataset_splits = []
|
| 340 |
-
for dataset_name, dataset_config_name, split_names in zip(
|
| 341 |
-
dataset_names.split(","), dataset_config_names.split(","), split.split(",")
|
| 342 |
):
|
| 343 |
for split_name in split_names.split("+"):
|
| 344 |
-
dataset = load_dataset(dataset_name, dataset_config_name, split=split_name, streaming=streaming, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 345 |
dataset_splits.append(dataset)
|
| 346 |
-
|
| 347 |
# interleave multiple splits to form one dataset
|
| 348 |
interleaved_dataset = interleave_datasets(dataset_splits)
|
| 349 |
return interleaved_dataset
|
|
@@ -426,20 +466,23 @@ def main():
|
|
| 426 |
|
| 427 |
if training_args.do_train:
|
| 428 |
raw_datasets["train"] = load_maybe_streaming_dataset(
|
| 429 |
-
data_args.
|
| 430 |
-
data_args.
|
| 431 |
split=data_args.train_split_name,
|
| 432 |
use_auth_token=hf_token if model_args.use_auth_token else None,
|
| 433 |
streaming=data_args.streaming,
|
|
|
|
|
|
|
| 434 |
)
|
| 435 |
|
| 436 |
if training_args.do_eval:
|
| 437 |
raw_datasets["eval"] = load_maybe_streaming_dataset(
|
| 438 |
-
data_args.
|
| 439 |
-
data_args.
|
| 440 |
split=data_args.eval_split_name,
|
| 441 |
use_auth_token=hf_token if model_args.use_auth_token else None,
|
| 442 |
streaming=data_args.streaming,
|
|
|
|
| 443 |
)
|
| 444 |
|
| 445 |
raw_datasets_features = list(next(iter(raw_datasets.values())).features.keys())
|
|
@@ -451,6 +494,7 @@ def main():
|
|
| 451 |
f"{', '.join(raw_datasets_features)}."
|
| 452 |
)
|
| 453 |
|
|
|
|
| 454 |
if data_args.text_column_name not in raw_datasets_features:
|
| 455 |
raise ValueError(
|
| 456 |
f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
|
|
@@ -504,9 +548,13 @@ def main():
|
|
| 504 |
if model_args.freeze_encoder:
|
| 505 |
model.freeze_encoder()
|
| 506 |
|
| 507 |
-
if data_args.language is not None:
|
| 508 |
# We only need to set the task id when the language is specified (i.e. in a multilingual setting)
|
|
|
|
| 509 |
tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task)
|
|
|
|
|
|
|
|
|
|
| 510 |
|
| 511 |
# 6. Resample speech dataset if necessary
|
| 512 |
logger.info("*** Resample dataset ***")
|
|
@@ -558,6 +606,7 @@ def main():
|
|
| 558 |
return batch
|
| 559 |
|
| 560 |
with training_args.main_process_first(desc="dataset map pre-processing"):
|
|
|
|
| 561 |
vectorized_datasets = raw_datasets.map(
|
| 562 |
prepare_dataset,
|
| 563 |
remove_columns=raw_datasets_features,
|
|
@@ -617,9 +666,14 @@ def main():
|
|
| 617 |
processor = AutoProcessor.from_pretrained(training_args.output_dir)
|
| 618 |
|
| 619 |
# 10. Define data collator
|
|
|
|
|
|
|
|
|
|
|
|
|
| 620 |
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
|
| 621 |
processor=processor,
|
| 622 |
decoder_start_token_id=model.config.decoder_start_token_id,
|
|
|
|
| 623 |
)
|
| 624 |
|
| 625 |
# 11. Configure Trainer
|
|
@@ -716,20 +770,24 @@ def main():
|
|
| 716 |
if model_args.model_index_name is not None:
|
| 717 |
kwargs["model_name"] = model_args.model_index_name
|
| 718 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 719 |
if training_args.push_to_hub:
|
| 720 |
logger.info("*** Pushing to hub ***")
|
| 721 |
trainer.push_to_hub(**kwargs)
|
| 722 |
logger.info("*** Pushed to hub ***")
|
|
|
|
|
|
|
| 723 |
else:
|
| 724 |
logger.info("*** Creating model card ***")
|
| 725 |
trainer.create_model_card(**kwargs)
|
| 726 |
logger.info("*** Model card created ***")
|
| 727 |
-
|
| 728 |
-
|
| 729 |
-
logger.info("*** Sending notification ***")
|
| 730 |
-
notify_me(recipient="marinone94@gmail.com", message=json.dumps(kwargs, indent=4))
|
| 731 |
-
|
| 732 |
-
logger.info("*** Training complete!!! ***")
|
| 733 |
|
| 734 |
return results
|
| 735 |
|
|
|
|
| 49 |
set_seed,
|
| 50 |
)
|
| 51 |
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
|
| 52 |
+
from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE
|
| 53 |
from transformers.trainer_pt_utils import IterableDatasetShard
|
| 54 |
from transformers.trainer_utils import get_last_checkpoint, is_main_process
|
| 55 |
from transformers.utils import check_min_version, send_example_telemetry
|
|
|
|
| 62 |
|
| 63 |
logger = logging.getLogger(__name__)
|
| 64 |
|
| 65 |
+
SENDING_NOTIFICATION = "*** Sending notification to email ***"
|
| 66 |
+
RECIPIENT_ADDRESS = "marinone94@gmail.com"
|
| 67 |
+
|
| 68 |
wandb_token = os.environ.get("WANDB_TOKEN", "None")
|
| 69 |
hf_token = os.environ.get("HF_TOKEN", None)
|
| 70 |
if (hf_token is None or wandb_token == "None") and os.path.exists("./creds.txt"):
|
|
|
|
| 164 |
Arguments pertaining to what data we are going to input our model for training and eval.
|
| 165 |
"""
|
| 166 |
|
| 167 |
+
dataset_train_name: str = field(
|
| 168 |
+
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
| 169 |
+
)
|
| 170 |
+
dataset_train_config_name: Optional[str] = field(
|
| 171 |
+
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
| 172 |
+
)
|
| 173 |
+
dataset_eval_name: str = field(
|
| 174 |
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
| 175 |
)
|
| 176 |
+
dataset_eval_config_name: Optional[str] = field(
|
| 177 |
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
| 178 |
)
|
| 179 |
text_column: Optional[str] = field(
|
|
|
|
| 242 |
default=True,
|
| 243 |
metadata={"help": "Whether to normalise the references and predictions in the eval WER calculation."},
|
| 244 |
)
|
| 245 |
+
language_train: str = field(
|
| 246 |
+
default=None,
|
| 247 |
+
metadata={
|
| 248 |
+
"help": (
|
| 249 |
+
"Language for multilingual fine-tuning. This argument should be set for multilingual fine-tuning "
|
| 250 |
+
"only. For English speech recognition, it should be set to `None`."
|
| 251 |
+
)
|
| 252 |
+
},
|
| 253 |
+
)
|
| 254 |
+
language_eval: str = field(
|
| 255 |
default=None,
|
| 256 |
metadata={
|
| 257 |
"help": (
|
|
|
|
| 292 |
|
| 293 |
processor: Any
|
| 294 |
decoder_start_token_id: int
|
| 295 |
+
task_id: int
|
| 296 |
|
| 297 |
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
| 298 |
# split inputs and labels since they have to be of different lengths and need
|
|
|
|
| 300 |
model_input_name = self.processor.model_input_names[0]
|
| 301 |
input_features = [{model_input_name: feature[model_input_name]} for feature in features]
|
| 302 |
label_features = [{"input_ids": feature["labels"]} for feature in features]
|
| 303 |
+
lang_features = [f"<|{TO_LANGUAGE_CODE[feature['language']]}|>" for feature in features]
|
| 304 |
|
| 305 |
batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
|
| 306 |
|
|
|
|
| 313 |
# cut bos token here as it's append later anyways
|
| 314 |
if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
|
| 315 |
labels = labels[:, 1:]
|
| 316 |
+
lang_token_ids = self.processor.tokenizer(lang_features).input_ids
|
| 317 |
+
# Replace language and task if they are in the beginning, otherwise add them
|
| 318 |
+
if (labels[:, 1] == self.task_id).all().cpu().item():
|
| 319 |
+
labels[:, 0] = lang_token_ids
|
| 320 |
+
labels[:, 1] = torch.full_like(labels[:, 1], self.task_id)
|
| 321 |
+
else:
|
| 322 |
+
# convert task id to tensor of labels dim to concatenate
|
| 323 |
+
task_id = torch.full_like(labels[:, 0], self.task_id)
|
| 324 |
+
labels = torch.cat((lang_token_ids, task_id, labels), dim=1)
|
| 325 |
|
| 326 |
batch["labels"] = labels
|
| 327 |
|
|
|
|
| 346 |
from email.mime.text import MIMEText
|
| 347 |
|
| 348 |
msg = MIMEText(message)
|
| 349 |
+
msg["Subject"] = "Training updates..."
|
| 350 |
msg["From"] = "marinone.auto@gmail.com"
|
| 351 |
msg["To"] = recipient
|
| 352 |
|
|
|
|
| 364 |
each split is loaded individually and then splits combined by taking alternating examples from
|
| 365 |
each (interleaving).
|
| 366 |
"""
|
| 367 |
+
column_names = None
|
| 368 |
+
if "column_names" in kwargs:
|
| 369 |
+
column_names = kwargs.pop("column_names").split(",")
|
| 370 |
+
|
| 371 |
if "," in dataset_names or "+" in split:
|
| 372 |
# load multiple splits separated by the `+` symbol with streaming mode
|
| 373 |
dataset_splits = []
|
| 374 |
+
for dataset_name, dataset_config_name, split_names, lang in zip(
|
| 375 |
+
dataset_names.split(","), dataset_config_names.split(","), split.split(","), kwargs.pop("language").split(",")
|
| 376 |
):
|
| 377 |
for split_name in split_names.split("+"):
|
| 378 |
+
dataset = load_dataset(dataset_name, dataset_config_name, split=split_name, streaming=streaming, **kwargs)
|
| 379 |
+
raw_datasets_features = list(next(iter(dataset.values())).features.keys())
|
| 380 |
+
if column_names[0] not in raw_datasets_features:
|
| 381 |
+
if len(column_names) == 1 or column_names[1] not in raw_datasets_features:
|
| 382 |
+
raise ValueError("Column name not found in dataset.")
|
| 383 |
+
dataset = dataset.rename_columns(column_names[1], column_names[0])
|
| 384 |
+
dataset["language"] = lang
|
| 385 |
dataset_splits.append(dataset)
|
| 386 |
+
|
| 387 |
# interleave multiple splits to form one dataset
|
| 388 |
interleaved_dataset = interleave_datasets(dataset_splits)
|
| 389 |
return interleaved_dataset
|
|
|
|
| 466 |
|
| 467 |
if training_args.do_train:
|
| 468 |
raw_datasets["train"] = load_maybe_streaming_dataset(
|
| 469 |
+
data_args.dataset_train_name,
|
| 470 |
+
data_args.dataset_train_config_name,
|
| 471 |
split=data_args.train_split_name,
|
| 472 |
use_auth_token=hf_token if model_args.use_auth_token else None,
|
| 473 |
streaming=data_args.streaming,
|
| 474 |
+
column_names=data_args.text_column_name,
|
| 475 |
+
language=data_args.language_train
|
| 476 |
)
|
| 477 |
|
| 478 |
if training_args.do_eval:
|
| 479 |
raw_datasets["eval"] = load_maybe_streaming_dataset(
|
| 480 |
+
data_args.dataset_eval_name,
|
| 481 |
+
data_args.dataset_eval_config_name,
|
| 482 |
split=data_args.eval_split_name,
|
| 483 |
use_auth_token=hf_token if model_args.use_auth_token else None,
|
| 484 |
streaming=data_args.streaming,
|
| 485 |
+
language=data_args.language_eval
|
| 486 |
)
|
| 487 |
|
| 488 |
raw_datasets_features = list(next(iter(raw_datasets.values())).features.keys())
|
|
|
|
| 494 |
f"{', '.join(raw_datasets_features)}."
|
| 495 |
)
|
| 496 |
|
| 497 |
+
data_args.text_column_name = data_args.text_column_name.split(",")[0]
|
| 498 |
if data_args.text_column_name not in raw_datasets_features:
|
| 499 |
raise ValueError(
|
| 500 |
f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
|
|
|
|
| 548 |
if model_args.freeze_encoder:
|
| 549 |
model.freeze_encoder()
|
| 550 |
|
| 551 |
+
if data_args.language is not None and len(data_args.language.split(",")) == 1:
|
| 552 |
# We only need to set the task id when the language is specified (i.e. in a multilingual setting)
|
| 553 |
+
# If more than a langugae is specified, it will be specified in the data collator
|
| 554 |
tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task)
|
| 555 |
+
elif data_args.language is not None and len(data_args.language.split(",")) > 1:
|
| 556 |
+
# make sure language and task are not stored in the model config
|
| 557 |
+
model.config.forced_decoder_ids = None
|
| 558 |
|
| 559 |
# 6. Resample speech dataset if necessary
|
| 560 |
logger.info("*** Resample dataset ***")
|
|
|
|
| 606 |
return batch
|
| 607 |
|
| 608 |
with training_args.main_process_first(desc="dataset map pre-processing"):
|
| 609 |
+
raw_datasets_features.remove("language")
|
| 610 |
vectorized_datasets = raw_datasets.map(
|
| 611 |
prepare_dataset,
|
| 612 |
remove_columns=raw_datasets_features,
|
|
|
|
| 666 |
processor = AutoProcessor.from_pretrained(training_args.output_dir)
|
| 667 |
|
| 668 |
# 10. Define data collator
|
| 669 |
+
task_token = data_args.task
|
| 670 |
+
if not task_token.startswith('<|'):
|
| 671 |
+
task_token = f'<{task_token}>'
|
| 672 |
+
task_id = tokenizer(task_token).input_ids[0]
|
| 673 |
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
|
| 674 |
processor=processor,
|
| 675 |
decoder_start_token_id=model.config.decoder_start_token_id,
|
| 676 |
+
task_id=task_id
|
| 677 |
)
|
| 678 |
|
| 679 |
# 11. Configure Trainer
|
|
|
|
| 770 |
if model_args.model_index_name is not None:
|
| 771 |
kwargs["model_name"] = model_args.model_index_name
|
| 772 |
|
| 773 |
+
# Training complete notification
|
| 774 |
+
logger.info(SENDING_NOTIFICATION)
|
| 775 |
+
notify_me(recipient=RECIPIENT_ADDRESS, message=json.dumps(kwargs, indent=4))
|
| 776 |
+
logger.info("*** Training complete!!! ***")
|
| 777 |
+
|
| 778 |
+
|
| 779 |
if training_args.push_to_hub:
|
| 780 |
logger.info("*** Pushing to hub ***")
|
| 781 |
trainer.push_to_hub(**kwargs)
|
| 782 |
logger.info("*** Pushed to hub ***")
|
| 783 |
+
logger.info(SENDING_NOTIFICATION)
|
| 784 |
+
notify_me(recipient=RECIPIENT_ADDRESS, message="Model pushed to hub")
|
| 785 |
else:
|
| 786 |
logger.info("*** Creating model card ***")
|
| 787 |
trainer.create_model_card(**kwargs)
|
| 788 |
logger.info("*** Model card created ***")
|
| 789 |
+
logger.info(SENDING_NOTIFICATION)
|
| 790 |
+
notify_me(recipient=RECIPIENT_ADDRESS, message="Model card created")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 791 |
|
| 792 |
return results
|
| 793 |
|
test_run_nordic.sh
CHANGED
|
@@ -1,9 +1,12 @@
|
|
| 1 |
python $1run_speech_recognition_seq2seq_streaming.py \
|
| 2 |
--model_name_or_path="openai/whisper-tiny" \
|
| 3 |
-
--
|
| 4 |
-
--
|
| 5 |
-
--
|
| 6 |
-
--train_split_name="train+validation,train+validation,train+validation,train,train+test,
|
|
|
|
|
|
|
|
|
|
| 7 |
--eval_split_name="test" \
|
| 8 |
--model_index_name="Whisper Tiny Swedish" \
|
| 9 |
--max_train_samples="64" \
|
|
@@ -22,7 +25,7 @@ python $1run_speech_recognition_seq2seq_streaming.py \
|
|
| 22 |
--generation_max_length="225" \
|
| 23 |
--length_column_name="input_length" \
|
| 24 |
--max_duration_in_seconds="30" \
|
| 25 |
-
--text_column_name="sentence" \
|
| 26 |
--freeze_feature_encoder="False" \
|
| 27 |
--report_to="wandb" \
|
| 28 |
--metric_for_best_model="wer" \
|
|
|
|
| 1 |
python $1run_speech_recognition_seq2seq_streaming.py \
|
| 2 |
--model_name_or_path="openai/whisper-tiny" \
|
| 3 |
+
--dataset_train_name="mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0,babelbox/babelbox_voice,NbAiLab/NST,arpelarpe/nota,NbAiLab/NPSC,google/fleurs,google/fleurs,google/fleurs" \
|
| 4 |
+
--dataset_train_config_name="sv-SE,da,nn-NO,,no-distant,,16k_mp3_nynorsk,sv_se,da_dk,nb_no" \
|
| 5 |
+
--language_train="swedish,danish,norwegian,swedish,norwegian,danish,norwegian,swedish,danish,norwegian" \
|
| 6 |
+
--train_split_name="train+validation,train+validation,train+validation,train,train+test,train,train+validation,train+validation,train+validation,train+validation" \
|
| 7 |
+
--dataset_eval_name="mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0" \
|
| 8 |
+
--dataset_eval_config_name="sv-SE,da,nn-NO" \
|
| 9 |
+
--language_eval="swedish,danish,norwegian" \
|
| 10 |
--eval_split_name="test" \
|
| 11 |
--model_index_name="Whisper Tiny Swedish" \
|
| 12 |
--max_train_samples="64" \
|
|
|
|
| 25 |
--generation_max_length="225" \
|
| 26 |
--length_column_name="input_length" \
|
| 27 |
--max_duration_in_seconds="30" \
|
| 28 |
+
--text_column_name="sentence,text" \
|
| 29 |
--freeze_feature_encoder="False" \
|
| 30 |
--report_to="wandb" \
|
| 31 |
--metric_for_best_model="wer" \
|