scasutt committed
Commit a848498 · 1 Parent(s): bed1127

Create DataCollatorCTCWithPadding.py

Files changed (1):
  DataCollatorCTCWithPadding.py  +67 -0
DataCollatorCTCWithPadding.py ADDED
@@ -0,0 +1,67 @@
+# Create the data collator for CTC training
+import torch
+
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Union
+
+from transformers import Wav2Vec2Processor
+
+
+@dataclass
+class DataCollatorCTCWithPadding:
+    """
+    Data collator that will dynamically pad the inputs received.
+    Args:
+        processor (:class:`~transformers.Wav2Vec2Processor`):
+            The processor used for processing the data.
+        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
+            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
+            among:
+            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
+              single sequence is provided).
+            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
+              maximum acceptable input length for the model if that argument is not provided.
+            * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
+              different lengths).
+        max_length (:obj:`int`, `optional`):
+            Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
+        max_length_labels (:obj:`int`, `optional`):
+            Maximum length of the ``labels`` of the returned list and optionally padding length (see above).
+        pad_to_multiple_of (:obj:`int`, `optional`):
+            If set, will pad the sequence to a multiple of the provided value.
+            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
+            >= 7.5 (Volta).
+    """
+
+    processor: Wav2Vec2Processor
+    padding: Union[bool, str] = True
+    max_length: Optional[int] = None
+    max_length_labels: Optional[int] = None
+    pad_to_multiple_of: Optional[int] = None
+    pad_to_multiple_of_labels: Optional[int] = None
+
+    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
+        # Split inputs and labels since they have different lengths and need
+        # different padding methods.
+        input_features = [{"input_values": feature["input_values"]} for feature in features]
+        label_features = [{"input_ids": feature["labels"]} for feature in features]
+
+        # Pad the raw audio inputs with the feature extractor.
+        batch = self.processor.pad(
+            input_features,
+            padding=self.padding,
+            max_length=self.max_length,
+            pad_to_multiple_of=self.pad_to_multiple_of,
+            return_tensors="pt",
+        )
+        # Pad the tokenized transcriptions with the tokenizer.
+        with self.processor.as_target_processor():
+            labels_batch = self.processor.pad(
+                label_features,
+                padding=self.padding,
+                max_length=self.max_length_labels,
+                pad_to_multiple_of=self.pad_to_multiple_of_labels,
+                return_tensors="pt",
+            )
+
+        # Replace label padding with -100 so those positions are ignored by the loss.
+        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
+
+        batch["labels"] = labels
+
+        return batch
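
For context, a minimal usage sketch of this collator follows. The checkpoint name and the toy features are illustrative assumptions for demonstration only and are not part of the commit.

# Usage sketch (assumptions: the checkpoint name and toy features below are hypothetical).
import torch
from transformers import Wav2Vec2Processor

from DataCollatorCTCWithPadding import DataCollatorCTCWithPadding

# Any Wav2Vec2 processor with a CTC tokenizer works; this checkpoint is just an example.
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")

collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

# Two toy examples of different lengths, shaped like the output of a typical
# dataset-preparation step: raw waveform values plus tokenized transcription ids.
features = [
    {"input_values": torch.randn(16000).tolist(),
     "labels": processor.tokenizer("HELLO WORLD").input_ids},
    {"input_values": torch.randn(8000).tolist(),
     "labels": processor.tokenizer("HELLO").input_ids},
]

batch = collator(features)
print(batch["input_values"].shape)  # both inputs padded to the longest waveform in the batch
print(batch["labels"])              # padded label positions are replaced with -100

The collator is typically passed as ``data_collator`` to a ``transformers.Trainer`` so each batch is padded dynamically at training time.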