import torch
import pickle
import random
from collections import OrderedDict

import numpy as np

__all__ = ['ExampleInstance', 'example_to_feature', 'ExampleSet']
class ExampleInstance:
  """A single example: a list of raw text segments plus an optional label.
  Extra keyword arguments are attached as attributes."""

  def __init__(self, segments, label=None, **kwv):
    self.segments = segments
    self.label = label
    self.__dict__.update(kwv)

  def __repr__(self):
    return f'segments: {self.segments}\nlabel: {self.label}'

  def __getitem__(self, i):
    return self.segments[i]

  def __len__(self):
    return len(self.segments)
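# Illustrative usage (a sketch, not part of the module API):
#
#   example = ExampleInstance(segments=['premise text', 'hypothesis text'], label=1)
#   len(example)   # -> 2
#   example[0]     # -> 'premise text'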
class ExampleSet:
  """An indexable, iterable collection of examples, stored pickled in a
  numpy object array to keep per-item memory overhead low."""

  def __init__(self, pairs):
    self._data = np.array([pickle.dumps(p) for p in pairs])
    self.total = len(self._data)

  def __getitem__(self, idx):
    """Return the example at `idx`. A tuple index of the form
    (idx, rng, ext_params) is also accepted; the extra elements are unused here."""
    if isinstance(idx, tuple):
      idx, rng, ext_params = idx
    else:
      rng, ext_params = None, None
    content = self._data[idx]
    return pickle.loads(content)

  def __len__(self):
    return self.total

  def __iter__(self):
    for i in range(self.total):
      yield self[i]
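# Illustrative usage (a sketch; any picklable objects work as items):
#
#   dataset = ExampleSet([ExampleInstance(['a b c'], label=0),
#                         ExampleInstance(['d e'], label=1)])
#   len(dataset)               # -> 2
#   dataset[0].label           # -> 0
#   dataset[(1, None, None)]   # tuple indexing; rng/ext_params are ignored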
def _truncate_segments(segments, max_num_tokens, rng):
  """
  Truncate a list of token segments to at most `max_num_tokens` tokens in
  total, following the original BERT implementation:
  https://github.com/google-research/bert/blob/master/create_pretraining_data.py#L391
  One token at a time is removed from the currently longest segment, randomly
  from its head or its tail, so truncation is spread across segments.
  """
  while True:
    if sum(len(s) for s in segments) <= max_num_tokens:
      break
    # Always trim the longest remaining segment.
    segments = sorted(segments, key=lambda s: len(s), reverse=True)
    trunc_tokens = segments[0]
    assert len(trunc_tokens) >= 1
    # Randomly drop from the front or the back to avoid a positional bias.
    if rng.random() < 0.5:
      trunc_tokens.pop(0)
    else:
      trunc_tokens.pop()
  return segments
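# Worked example (sketch). With max_num_tokens=6 and segments of lengths 5
# and 4, exactly three tokens are popped, each from whichever segment is
# longest at that step. Note the returned list is re-sorted longest-first
# once any truncation occurs:
#
#   rng = random.Random(0)
#   segs = _truncate_segments([list('abcde'), list('wxyz')], 6, rng)
#   sum(len(s) for s in segs)  # -> 6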
def example_to_feature(tokenizer, example, max_seq_len=512, rng=None, mask_generator=None, ext_params=None, label_type='int', **kwargs):
  """Convert an ExampleInstance into padded model input features."""
  if not rng:
    rng = random
  # Reserve room for [CLS] plus one [SEP] after each segment.
  max_num_tokens = max_seq_len - len(example.segments) - 1
  segments = _truncate_segments([tokenizer.tokenize(s) for s in example.segments], max_num_tokens, rng)
  tokens = ['[CLS]']
  type_ids = [0]
  for i, s in enumerate(segments):
    tokens.extend(s)
    tokens.append('[SEP]')
    type_ids.extend([i] * (len(s) + 1))
  if mask_generator:
    tokens, lm_labels = mask_generator.mask_tokens(tokens, rng)
  token_ids = tokenizer.convert_tokens_to_ids(tokens)
  pos_ids = list(range(len(token_ids)))
  input_mask = [1] * len(token_ids)
  features = OrderedDict(input_ids=token_ids,
                         type_ids=type_ids,
                         position_ids=pos_ids,
                         input_mask=input_mask)
  if mask_generator:
    features['lm_labels'] = lm_labels
  # Zero-pad every field to max_seq_len and convert to tensors.
  padding_size = max(0, max_seq_len - len(token_ids))
  for f in features:
    features[f].extend([0] * padding_size)
    features[f] = torch.tensor(features[f], dtype=torch.int)
  label_type = torch.int if label_type == 'int' else torch.float
  if example.label is not None:
    features['labels'] = torch.tensor(example.label, dtype=label_type)
  return features
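# Minimal usage sketch. The tokenizer is assumed to expose the BERT-style API
# used above (`tokenize`, `convert_tokens_to_ids`); `ToyTokenizer` below is a
# hypothetical stand-in for illustration only.
#
#   class ToyTokenizer:
#     def __init__(self):
#       self.vocab = {'[CLS]': 0, '[SEP]': 1}
#     def tokenize(self, text):
#       return text.split()
#     def convert_tokens_to_ids(self, tokens):
#       return [self.vocab.setdefault(t, len(self.vocab)) for t in tokens]
#
#   example = ExampleInstance(['hello world', 'foo bar'], label=1)
#   features = example_to_feature(ToyTokenizer(), example, max_seq_len=16)
#   features['input_ids'].shape  # -> torch.Size([16])
#   features['labels']           # -> tensor(1, dtype=torch.int32)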