patrickvonplaten committed on
Commit
a1b0827
·
1 Parent(s): 0369ec9

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +5 -3
README.md CHANGED
@@ -21,10 +21,12 @@ ds = ds.map(load_audio)
21
  input_values = torch.nn.utils.rnn.pad_sequence([torch.tensor(x[0]) for x in ds["samples"][:10]], batch_first=True)
22
 
23
  # forward
24
- logits = model(input_ids).logits
25
  pred_ids = torch.argmax(logits, dim=-1)
26
 
27
  # dummy loss
28
- dummy_labels = torch.zeros((logits.shape[0], logits.shape[0] - 10))
29
- loss = model(input_ids, labels=dummy_labels).loss
 
 
30
  ```
 
21
  input_values = torch.nn.utils.rnn.pad_sequence([torch.tensor(x[0]) for x in ds["samples"][:10]], batch_first=True)
22
 
23
  # forward
24
+ logits = model(input_values).logits
25
  pred_ids = torch.argmax(logits, dim=-1)
26
 
27
  # dummy loss
28
+ dummy_labels = pred_ids.clone()
29
+ dummy_labels[dummy_labels == model.config.pad_token_id] = 1 # can't have CTC blank token in label
30
+ dummy_labels = dummy_labels[:, -(dummy_labels.shape[1] // 4):] # make sure labels are shorter to avoid "inf" loss (can still happen though...)
31
+ loss = model(input_values, labels=dummy_labels).loss
32
  ```