|
|
|
|
|
|
|
|
import gzip |
|
|
import json |
|
|
import math |
|
|
import random |
|
|
import shelve |
|
|
import torch |
|
|
|
|
|
import subprocess as sp |
|
|
|
|
|
from math import ceil |
|
|
from torch.utils.data import DataLoader, Sampler, Dataset |
|
|
from torch.nn.utils.rnn import pad_sequence |
|
|
|
|
|
from env import END_OF_TEXT_TOKEN |
|
|
from gpt2_training.train_utils import (InputFeatures, InputFeatures_train, |
|
|
RedditExample) |
|
|
|
|
|
|
|
|
class BucketSampler(Sampler):
    """Batch sampler that groups indices of similar sequence length.

    The index list (optionally shuffled) is cut into consecutive buckets of
    ``bucket_size`` indices. Each bucket is sorted by descending length and
    sliced into batches of ``batch_size``. With ``droplast`` any batch shorter
    than ``batch_size`` is discarded; with ``shuffle`` the final batch order is
    randomized as well.
    """

    def __init__(self, lens, bucket_size, batch_size,
                 droplast=False, shuffle=True):
        self._lens = lens
        self._batch_size = batch_size
        self._bucket_size = bucket_size
        self._droplast = droplast
        self._shuf = shuffle

    def __iter__(self):
        order = list(range(len(self._lens)))
        if self._shuf:
            random.shuffle(order)

        batches = []
        for start in range(0, len(order), self._bucket_size):
            # sorted() with reverse=True is still stable for equal lengths.
            bucket = sorted(order[start:start + self._bucket_size],
                            key=self._lens.__getitem__, reverse=True)
            for offset in range(0, len(bucket), self._batch_size):
                batch = bucket[offset:offset + self._batch_size]
                if self._droplast and len(batch) != self._batch_size:
                    continue
                batches.append(batch)

        if self._shuf:
            random.shuffle(batches)
        return iter(batches)

    def __len__(self):
        n_full, remainder = divmod(len(self._lens), self._bucket_size)
        per_bucket = [self._bucket_size] * n_full + [remainder]
        if self._droplast:
            return sum(size // self._batch_size for size in per_bucket)
        return sum(math.ceil(size / self._batch_size) for size in per_bucket)
|
|
|
|
|
|
|
|
class GPT2FeatureDataset(Dataset):
    """ pytorch dataset for GPT2 training

    Wraps an indexable collection of feature dicts and turns each one into an
    ``InputFeatures_train``. If ``max_len`` is set and a feature's
    ``input_len`` exceeds it, the per-token sequences keep only their last
    ``max_len`` entries.
    """

    # Per-token fields truncated together so they stay position-aligned.
    _SEQ_KEYS = ('input_ids', 'position_ids', 'token_type_ids', 'lm_labels')
    # Extra fields removed before building InputFeatures_train
    # (presumably not accepted by its constructor -- TODO confirm).
    _DROP_KEYS = ('context_len', 'response_len')

    def __init__(self, features, max_len=None):
        self.features = features
        self.max_len = max_len

    def __getitem__(self, i):
        """Return feature ``i`` as an ``InputFeatures_train``.

        The stored dict is shallow-copied first, so truncation and key removal
        never mutate ``self.features``. The lists are replaced by slices, not
        edited, so a shallow copy is sufficient.
        """
        feat_dict = dict(self.features[i])
        if self.max_len is not None and feat_dict['input_len'] > self.max_len:
            # Keep the tail (most recent tokens) of every per-token sequence.
            for key in self._SEQ_KEYS:
                feat_dict[key] = feat_dict[key][-self.max_len:]
        for key in self._DROP_KEYS:
            feat_dict.pop(key, None)
        return InputFeatures_train(**feat_dict)

    def __len__(self):
        return len(self.features)

    @staticmethod
    def collate(features):
        """Pad a list of features into batch tensors.

        Returns ``(input_ids, position_ids, token_type_ids, labels)``, each a
        ``torch.long`` tensor with the batch as the first dimension. Inputs
        are right-padded with 0 and labels with -1.
        """
        def pad(rows, value):
            return pad_sequence([torch.tensor(r, dtype=torch.long)
                                 for r in rows],
                                batch_first=True, padding_value=value)

        input_ids = pad([f.input_ids for f in features], 0)
        position_ids = pad([f.position_ids for f in features], 0)
        token_type_ids = pad([f.token_type_ids for f in features], 0)
        labels = pad([f.lm_labels for f in features], -1)
        return (input_ids, position_ids, token_type_ids, labels)
|
|
|
|
|
|
|
|
class BucketingDataLoader(object):
    """ this loads shelve db chunks and then convert to mini-batch loader

    Each db value is a gzip-compressed UTF-8 JSON list of feature dicts.
    Chunks are visited in (optionally shuffled) key order. Features whose
    ``input_len`` exceeds ``max_seq_len`` are skipped, and the rest are batched
    through a ``BucketSampler`` so that each batch holds similar lengths.

    The db is closed by ``close()``, by leaving a ``with`` block, or (as a
    fallback) when the object is garbage-collected.
    """

    def __init__(self, db_name, batch_size, max_seq_len,
                 bucket=100, shuffle=True):
        self.db = shelve.open(f'{db_name}/db', 'r')
        self.batch_size = batch_size
        self.max_len = max_seq_len
        # bucket is expressed in batches; the sampler wants it in examples.
        self.bucket_size = bucket * batch_size
        self.shuffle = shuffle

    def _get_keys(self):
        """Return the db keys (chunks) this loader should visit."""
        return list(self.db.keys())

    def __iter__(self):
        keys = self._get_keys()
        if self.shuffle:
            random.shuffle(keys)
        for key in keys:
            chunk = json.loads(gzip.decompress(self.db[key]).decode('utf-8'))
            # Over-long examples are dropped here rather than truncated.
            kept = [feat for feat in chunk if feat['input_len'] <= self.max_len]
            lens = [feat['input_len'] for feat in kept]

            dataset = GPT2FeatureDataset(kept, self.max_len)
            sampler = BucketSampler(lens, self.bucket_size, self.batch_size,
                                    droplast=True, shuffle=self.shuffle)
            loader = DataLoader(dataset, batch_sampler=sampler,
                                num_workers=0,
                                collate_fn=GPT2FeatureDataset.collate)
            yield from loader

    def __len__(self):
        raise NotImplementedError()

    def close(self):
        """Close the shelve db; safe to call repeatedly or after a failed init."""
        # getattr: __init__ may have raised before self.db was assigned.
        db = getattr(self, 'db', None)
        if db is not None:
            db.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False

    def __del__(self):
        self.close()
|
|
|
|
|
|
|
|
class DistributedBucketingDataLoader(BucketingDataLoader):
    """ distributed version

    Each replica reads only its share of the db chunks. Keys are dealt
    round-robin: replica ``rank`` takes every ``num_replica``-th key,
    starting at position ``rank``. All other arguments are forwarded
    unchanged to ``BucketingDataLoader``.
    """

    def __init__(self, rank, num_replica, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.rank = rank
        self.num_replica = num_replica

    def _get_keys(self):
        all_keys = list(self.db.keys())
        first, stride = self.rank, self.num_replica
        return all_keys[first::stride]
|
|
|
|
|
|
|
|
def convert_examples_to_features_dynamic(examples, tokenizer,
                                         max_seq_length=512):
    """Tokenize (context, response) examples into unpadded ``InputFeatures``.

    Each example becomes ``context + [EOS] + response + [EOS]``. When that
    exceeds ``max_seq_length``, the context is trimmed from the left if it is
    long enough to absorb the whole overflow; otherwise the response is
    trimmed from the right. Examples that still cannot fit are dropped.

    ``lm_labels`` hides the context with -1 and, shifted by one position,
    holds the response tokens followed by EOS (the last position is -1).

    Returns a list of ``InputFeatures``; no padding is applied.
    """
    def featurize(example):
        conv_id = example.conv_id
        ctx = tokenizer.encode(example.context)
        eos = tokenizer.encoder[END_OF_TEXT_TOKEN]
        rsp = tokenizer.encode(example.response)

        # +2 accounts for the two EOS separators.
        overflow = len(ctx) + len(rsp) + 2 - max_seq_length
        if overflow > 0:
            if len(ctx) > overflow:
                ctx = ctx[overflow:]
            else:
                budget = max_seq_length - len(ctx) - 2
                if budget < 0:
                    return None
                rsp = rsp[:budget]

        input_ids = ctx + [eos] + rsp + [eos]
        lm_labels = [-1] * len(ctx) + rsp + [eos] + [-1]
        seq_len = len(input_ids)
        return InputFeatures(conv_id, input_ids, list(range(seq_len)),
                             [0] * seq_len, lm_labels, len(ctx), len(rsp))

    features = []
    for ex in examples:
        feat = featurize(ex)
        if feat is not None:
            features.append(feat)
    return features
|
|
|
|
|
|
|
|
class DynamicBatchingLoader(object):
    """ this loader takes raw text file, used for validate perplexity

    Each corpus line is ``source<TAB>target[<TAB>target...]``. Every
    (source, target) pair becomes one ``RedditExample`` whose id is a running
    per-example counter. Examples are tokenized lazily and collated into
    batches of ``batch_size`` examples; the final, possibly smaller, batch is
    also yielded.
    """

    def __init__(self, corpus_file, tokenizer, normalize_data,
                 batch_size, max_seq_length):
        self.corpus = corpus_file
        self.toker = tokenizer
        self.norm = normalize_data
        self.bs = batch_size
        self.max_seq_length = max_seq_length
        self.num_examples = self.get_len(corpus_file)

    def __iter__(self, epoch=1):
        """Yield batches for ``epoch`` passes, or forever if ``epoch <= 0``."""
        if epoch > 0:
            for _ in range(epoch):
                yield from self._iter_epoch()
        else:
            while True:
                yield from self._iter_epoch()

    def __len__(self):
        # Batches per epoch, assuming one target per line and no filtering.
        return ceil(self.num_examples / self.bs)

    def _iter_epoch(self):
        """Yield every batch of one pass over the corpus, tail included."""
        with open(self.corpus, 'r', encoding="utf-8") as corpus:
            conv_id = 0
            examples = []
            for line in corpus:
                src, *tgt_all = line.split('\t')
                for tgt in tgt_all:
                    if self.norm:
                        src_line = ' '.join(src.strip().split())
                        tgt_line = ' '.join(tgt.strip().split())
                    else:
                        src_line = src.strip()
                        tgt_line = tgt.strip()
                    examples.append(RedditExample(conv_id, src_line, tgt_line))
                    conv_id += 1
                if len(examples) >= self.bs:
                    batch = self._examples_to_batch(examples)
                    examples = []
                    if batch is not None:
                        yield batch
            # Examples left over when the file ends mid-batch.
            if examples:
                batch = self._examples_to_batch(examples)
                if batch is not None:
                    yield batch

    def _examples_to_batch(self, examples):
        """Tokenize and collate ``examples``; None if all were filtered out."""
        features = convert_examples_to_features_dynamic(
            examples, self.toker, self.max_seq_length)
        if not features:
            # pad_sequence cannot handle an empty list.
            return None
        return self._batch_feature(features)

    def _batch_feature(self, features):
        """Pad features into ``torch.long`` batch tensors.

        Returns ``(input_ids, position_ids, token_type_ids, labels,
        context_len, response_len)``. Inputs are padded with 0 and labels
        with -1.
        """
        def pad(rows, value):
            return pad_sequence([torch.tensor(r, dtype=torch.long)
                                 for r in rows],
                                batch_first=True, padding_value=value)

        input_ids = pad([f.choices_features['input_ids'] for f in features], 0)
        position_ids = pad(
            [f.choices_features['position_ids'] for f in features], 0)
        token_type_ids = pad(
            [f.choices_features['token_type_ids'] for f in features], 0)
        labels = pad([f.lm_labels for f in features], -1)
        context_len = torch.tensor([f.context_len for f in features],
                                   dtype=torch.long)
        response_len = torch.tensor([f.response_len for f in features],
                                    dtype=torch.long)
        return (input_ids, position_ids, token_type_ids, labels,
                context_len, response_len)

    def get_len(self, corpus):
        """Return the number of lines ``_iter_epoch`` will read from ``corpus``.

        Counted in Python with the same open mode as ``_iter_epoch``, so paths
        with spaces work and a final line without a trailing newline is counted.
        """
        with open(corpus, 'r', encoding="utf-8") as f:
            return sum(1 for _ in f)
|
|
|