| | --- |
| | language: en |
| | datasets: |
| | - librispeech_asr |
| | tags: |
| | - automatic-speech-recognition |
| | license: apache-2.0 |
| | --- |
| | |
| | ## Test model |
| |
|
| | To test this model, run the following code: |
| |
|
| | ```python |
| | from datasets import load_dataset |
| | from transformers import Wav2Vec2ForCTC |
| | import torchaudio |
| | import torch |
| | |
| | ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation") |
| | |
| | model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2_tiny_random_robust") |
| | |
| | def load_audio(batch): |
| | batch["samples"], _ = torchaudio.load(batch["file"]) |
| | return batch |
| | |
| | ds = ds.map(load_audio) |
| | |
| | input_values = torch.nn.utils.rnn.pad_sequence([torch.tensor(x[0]) for x in ds["samples"][:10]], batch_first=True) |
| | |
| | # forward |
| | logits = model(input_values).logits |
| | pred_ids = torch.argmax(logits, dim=-1) |
| | |
| | # dummy loss |
| | dummy_labels = pred_ids.clone() |
| | dummy_labels[dummy_labels == model.config.pad_token_id] = 1 # can't have CTC blank token in label |
| | dummy_labels = dummy_labels[:, -(dummy_labels.shape[1] // 4):] # make sure labels are shorter to avoid "inf" loss (can still happen though...) |
| | loss = model(input_values, labels=dummy_labels).loss |
| | ``` |
| |
|