| ## Test model | |
| To test this model, run the following code (requires `torch`, `torchaudio`, `datasets`, and `transformers`): | |
| ```python | |
| from datasets import load_dataset | |
| from transformers import Wav2Vec2ForCTC | |
| import torchaudio | |
| import torch | |
| ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation") | |
| model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2_tiny_random") | |
| def load_audio(batch): | |
| batch["samples"], _ = torchaudio.load(batch["file"]) | |
| return batch | |
| ds = ds.map(load_audio) | |
| input_values = torch.nn.utils.rnn.pad_sequence([torch.tensor(x[0]) for x in ds["samples"][:10]], batch_first=True) | |
| # forward | |
| logits = model(input_values).logits | |
| pred_ids = torch.argmax(logits, dim=-1) | |
| # dummy loss | |
| dummy_labels = pred_ids.clone() | |
| dummy_labels[dummy_labels == model.config.pad_token_id] = 1 # can't have CTC blank token in label | |
| dummy_labels = dummy_labels[:, -(dummy_labels.shape[1] // 4):] # make sure labels are shorter to avoid "inf" loss (can still happen though...) | |
| loss = model(input_values, labels=dummy_labels).loss | |
| ``` | |