import torch
from datasets import load_dataset, DatasetDict
from datasets import Audio

from transformers import WhisperFeatureExtractor
from transformers import WhisperTokenizer
from transformers import WhisperProcessor
from transformers import WhisperForConditionalGeneration

from transformers import Seq2SeqTrainingArguments
from transformers import Seq2SeqTrainer

from dataclasses import dataclass
from typing import Any, Dict, List, Union
import evaluate
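

# Collates pre-processed examples into padded training batches: the audio
# features and the token labels are padded separately, and label padding is
# masked with -100 so the loss ignores it.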
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
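        # split inputs and labels since they need different padding methods;
        # first pad the log-Mel input features to torch tensors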
        input_features = [{"input_features": feature["input_features"]}
                          for feature in features]
        batch = self.processor.feature_extractor.pad(
            input_features, return_tensors="pt")
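
        # then pad the tokenized label sequences to the longest in the batch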
        label_features = [{"input_ids": feature["labels"]}
                          for feature in features]
        labels_batch = self.processor.tokenizer.pad(
            label_features, return_tensors="pt")
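
        # replace label padding with -100 so those positions are ignored
        # by the cross-entropy loss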
        labels = labels_batch["input_ids"].masked_fill(
            labels_batch.attention_mask.ne(1), -100)
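
        # if a bos token was prepended during tokenization, cut it here;
        # the decoder start token is prepended again during training anyway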
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch


def main():
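    # compute the word error rate (WER) on decoded predictions; `tokenizer`
    # and `metric` are closed over from main(), where they are defined below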
    def compute_metrics(pred):
        pred_ids = pred.predictions
        label_ids = pred.label_ids

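        # put the pad token back where -100 masked the label padding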
        label_ids[label_ids == -100] = tokenizer.pad_token_id

        pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
        label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

        wer = 100 * metric.compute(predictions=pred_str, references=label_str)

        return {"wer": wer}
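
    # map a raw example to model inputs: log-Mel features from the audio
    # array, label token ids from the transcription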
    def prepare_dataset(batch):
        audio = batch["audio"]

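        # compute log-Mel input features from the (resampled) audio array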
| batch["input_features"] = feature_extractor( |
| audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0] |
|
|
| |
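        # encode the transcription to label token ids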
| batch["labels"] = tokenizer(batch["sentence"]).input_ids |
| return batch |
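
    # names of the audio and transcription columns in the source dataset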
| source = "audio" |
| target = "sentence" |

    speech_data = DatasetDict()
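
    # load the Norwegian Nynorsk (nn-NO) split of Common Voice 11.0; the
    # dataset is gated, so an authenticated Hugging Face login is required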
| speech_data["train"] = load_dataset( |
| "mozilla-foundation/common_voice_11_0", "nn-NO", split="train", use_auth_token=True) |
| speech_data["test"] = load_dataset( |
| "mozilla-foundation/common_voice_11_0", "nn-NO", split="test", use_auth_token=True) |
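
    # if the dataset uses different column names, rename them to the
    # expected "audio" and "sentence"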
| if "audio" not in speech_data.column_names["train"]: |
| speech_data = speech_data.rename_column(source, "audio") |
|
|
| if "sentence" not in speech_data.column_names["train"]: |
| speech_data = speech_data.rename_column(target, "sentence") |
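
    # drop every column except the audio and its transcription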
    remove_list = [i for i in speech_data.column_names["train"]
                   if i not in ["audio", "sentence"]]

    speech_data = speech_data.remove_columns(remove_list)
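
    # Whisper preprocessing: the feature extractor produces log-Mel
    # spectrograms, the tokenizer encodes and decodes the transcriptions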
    feature_extractor = WhisperFeatureExtractor.from_pretrained(
        "openai/whisper-small")
    tokenizer = WhisperTokenizer.from_pretrained(
        "openai/whisper-small", language="Norwegian", task="transcribe")
    processor = WhisperProcessor.from_pretrained(
        "openai/whisper-small", language="Norwegian", task="transcribe")
    data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
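
    # Whisper expects 16 kHz audio; casting the column makes datasets
    # resample on the fly when each example is accessed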
    speech_data = speech_data.cast_column("audio", Audio(sampling_rate=16000))
    speech_data = speech_data.map(
        prepare_dataset, remove_columns=speech_data.column_names["train"], num_proc=1)

    metric = evaluate.load("wer")
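
    # use_cache is incompatible with gradient checkpointing, so disable it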
    model = WhisperForConditionalGeneration.from_pretrained(
        "openai/whisper-small", use_cache=False)
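
    # no forced decoder ids during training, and no suppressed tokens
    # during generation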
    model.config.forced_decoder_ids = None
    model.config.suppress_tokens = []
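
    # training hyperparameters; push_to_hub uploads checkpoints to the Hub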
    training_args = Seq2SeqTrainingArguments(
        output_dir="../whisper-testrun1",
        per_device_train_batch_size=16,
        gradient_accumulation_steps=1,
        learning_rate=2e-5,
        warmup_steps=500,
        max_steps=5000,
        gradient_checkpointing=True,
        group_by_length=True,
        evaluation_strategy="steps",
        per_device_eval_batch_size=8,
        predict_with_generate=True,
        generation_max_length=225,
        save_steps=500,
        eval_steps=500,
        logging_steps=25,
        report_to=["tensorboard"],
        load_best_model_at_end=True,
        metric_for_best_model="wer",
        greater_is_better=False,
        push_to_hub=True,
    )
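
    # the feature extractor is passed as `tokenizer` so its configuration
    # is saved alongside the model checkpoints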
    trainer = Seq2SeqTrainer(
        args=training_args,
        model=model,
        train_dataset=speech_data["train"],
        eval_dataset=speech_data["test"],
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        tokenizer=processor.feature_extractor,
    )

    trainer.train()
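

# entry point for xla_spawn (TPU multiprocessing)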
def _mp_fn(index):
    print("XLA process initialised")
    main()


if __name__ == "__main__":
    main()