|
|
from dataclasses import dataclass |
|
|
from typing import Union, Optional |
|
|
|
|
|
from easy_tpp.preprocess.event_tokenizer import EventTokenizer |
|
|
from easy_tpp.utils import PaddingStrategy, TruncationStrategy |
|
|
|
|
|
|
|
|
@dataclass
class TPPDataCollator:
    """Collate event-sequence features into a single, dynamically padded batch.

    Calling an instance hands the features to ``tokenizer.pad`` together with
    the padding / truncation settings stored on the instance.

    Args:
        tokenizer ([`EventTokenizer`]):
            Tokenizer that encoded the data; its ``pad`` method builds the batch.
        padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
            Strategy used to pad the returned sequences (following the model's
            padding side and padding index). One of:

            - `True` or `'longest'` (default): pad up to the longest sequence
              in the batch (no padding when only one sequence is given).
            - `'max_length'`: pad up to the length given by `max_length`, or
              up to the maximum input length the model accepts when
              `max_length` is not set.
            - `False` or `'do_not_pad'`: leave sequences unpadded, so a batch
              may hold sequences of different lengths.
        max_length (`int`, *optional*):
            Maximum length of the returned list and, optionally, the padding
            length (see `padding`).
        truncation (`bool`, `str` or [`~utils.TruncationStrategy`], *optional*, defaults to `False`):
            Truncation setting; passed through unchanged to ``tokenizer.pad``.
        return_tensors (`str`, defaults to `"pt"`):
            Kind of tensor to return. Accepted values are "np", "pt" and "tf".
    """

    tokenizer: EventTokenizer
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    truncation: Union[bool, str, TruncationStrategy] = False
    return_tensors: str = "pt"

    def __call__(self, features, return_tensors=None):
        """Pad ``features`` into a batch via ``self.tokenizer.pad``.

        Args:
            features: Items to collate; handed to ``tokenizer.pad`` as-is.
            return_tensors: Per-call override of the tensor type. When left
                as ``None`` the instance-level ``return_tensors`` is used.

        Returns:
            Whatever ``self.tokenizer.pad`` returns for the given arguments.
        """
        # An explicit per-call value takes precedence over the instance default.
        tensor_type = (
            return_tensors if return_tensors is not None else self.return_tensors
        )
        return self.tokenizer.pad(
            features,
            padding=self.padding,
            max_length=self.max_length,
            truncation=self.truncation,
            return_tensors=tensor_type,
        )
|
|
|