| import torch | |
| import os | |
| from collections import OrderedDict | |
| import numpy as np | |
| import tempfile | |
| import numpy as np | |
| import mmap | |
| import pickle | |
| import signal | |
| import sys | |
| import pdb | |
| from ..utils import xtqdm as tqdm | |
# Public API of this module: everything else (e.g. _truncate_segments) is internal.
__all__=['ExampleInstance', 'example_to_feature', 'ExampleSet']
class ExampleInstance:
    """A single example: a sequence of text segments plus an optional label.

    Arbitrary extra attributes may be attached via keyword arguments; they are
    stored directly on the instance.
    """

    def __init__(self, segments, label=None, **kwv):
        self.segments = segments
        self.label = label
        # Attach any additional metadata (e.g. guid, weight) as attributes.
        self.__dict__.update(kwv)

    def __repr__(self):
        return 'segments: {}\nlabel: {}'.format(self.segments, self.label)

    def __getitem__(self, i):
        # Index straight into the underlying segment list.
        return self.segments[i]

    def __len__(self):
        # Number of text segments in this example.
        return len(self.segments)
class ExampleSet:
    """Pickle-backed, indexable collection of examples.

    Each item is serialized with ``pickle`` at construction time and
    deserialized lazily on access, keeping large datasets compact in memory.
    """

    def __init__(self, pairs):
        # dtype=object keeps each pickled payload as an exact ``bytes`` object.
        # A bare np.array() would infer a fixed-width 'S' dtype, padding every
        # record to the longest pickle and relying on numpy stripping trailing
        # NULs on access to recover the payload — fragile and memory-hungry.
        self._data = np.array([pickle.dumps(p) for p in pairs], dtype=object)
        self.total = len(self._data)

    def __getitem__(self, idx):
        """Return the deserialized example at ``idx``.

        ``idx`` may also be a tuple ``(index, rng, ext_params)``; the extra
        fields are accepted for interface compatibility but are unused here.
        """
        if isinstance(idx, tuple):
            idx, rng, ext_params = idx
        else:
            rng, ext_params = None, None
        content = self._data[idx]
        example = pickle.loads(content)
        return example

    def __len__(self):
        return self.total

    def __iter__(self):
        # Yield examples in insertion order, deserializing one at a time.
        for i in range(self.total):
            yield self[i]
| def _truncate_segments(segments, max_num_tokens, rng): | |
| """ | |
| Truncate sequence pair according to original BERT implementation: | |
| https://github.com/google-research/bert/blob/master/create_pretraining_data.py#L391 | |
| """ | |
| while True: | |
| if sum(len(s) for s in segments)<=max_num_tokens: | |
| break | |
| segments = sorted(segments, key=lambda s:len(s), reverse=True) | |
| trunc_tokens = segments[0] | |
| assert len(trunc_tokens) >= 1 | |
| if rng.random() < 0.5: | |
| trunc_tokens.pop(0) | |
| else: | |
| trunc_tokens.pop() | |
| return segments | |
def example_to_feature(tokenizer, example, max_seq_len=512, rng=None, mask_generator=None, ext_params=None, label_type='int', **kwargs):
    """Convert an example into a dict of fixed-length model-input tensors.

    Tokenizes each text segment, truncates the segments to fit
    ``max_seq_len`` (reserving room for ``[CLS]`` and one ``[SEP]`` per
    segment), optionally applies MLM masking, then pads everything to
    ``max_seq_len`` and converts to ``torch.int`` tensors.

    Args:
        tokenizer: object exposing ``tokenize(str)`` and
            ``convert_tokens_to_ids(list)``.
        example: object with ``.segments`` (list of strings) and ``.label``.
        max_seq_len: total output length of every feature tensor.
        rng: random source used for truncation/masking; defaults to the
            stdlib ``random`` module.
        mask_generator: optional MLM masker exposing
            ``mask_tokens(tokens, rng)`` returning ``(tokens, lm_labels)``.
        ext_params: unused; accepted for interface compatibility.
        label_type: 'int' for integer labels, anything else for float.
        **kwargs: ignored; accepted for interface compatibility.

    Returns:
        OrderedDict of tensors: input_ids, type_ids, position_ids,
        input_mask, optionally lm_labels, and labels when a label is present.
    """
    if not rng:
        # BUGFIX: `random` was referenced here without ever being imported,
        # raising NameError whenever no rng was supplied. Import locally so
        # the module-level namespace is unchanged.
        import random
        rng = random

    # Reserve room for [CLS] plus one [SEP] per segment.
    max_num_tokens = max_seq_len - len(example.segments) - 1
    segments = _truncate_segments([tokenizer.tokenize(s) for s in example.segments], max_num_tokens, rng)

    tokens = ['[CLS]']
    type_ids = [0]
    for i, s in enumerate(segments):
        tokens.extend(s)
        tokens.append('[SEP]')
        # Segment i's tokens AND its trailing [SEP] share type id i.
        type_ids.extend([i] * (len(s) + 1))

    if mask_generator:
        tokens, lm_labels = mask_generator.mask_tokens(tokens, rng)

    token_ids = tokenizer.convert_tokens_to_ids(tokens)
    pos_ids = list(range(len(token_ids)))
    input_mask = [1] * len(token_ids)
    features = OrderedDict(input_ids=token_ids,
                           type_ids=type_ids,
                           position_ids=pos_ids,
                           input_mask=input_mask)
    if mask_generator:
        features['lm_labels'] = lm_labels

    # Right-pad every feature with zeros to max_seq_len, then tensorize.
    padding_size = max(0, max_seq_len - len(token_ids))
    for f in features:
        features[f].extend([0] * padding_size)
        features[f] = torch.tensor(features[f], dtype=torch.int)

    # Use a distinct local name instead of shadowing the `label_type` param.
    label_dtype = torch.int if label_type == 'int' else torch.float
    if example.label is not None:
        features['labels'] = torch.tensor(example.label, dtype=label_dtype)
    return features