import json
import logging
import math
import pickle
import random
from typing import List, Iterator, Callable

from torch import Tensor as T

logger = logging.getLogger()


def read_serialized_data_from_files(paths: List[str]) -> List:
    """Load pickled data from several files and concatenate it into one list.

    Each file is unpickled and its content is appended item by item
    (via ``list.extend``), so every file is expected to hold an iterable.

    Args:
        paths: paths of the pickle files to read, in order.

    Returns:
        A flat list containing the items of all files, in file order.
    """
    results = []
    for path in paths:
        with open(path, "rb") as reader:
            logger.info("Reading file %s", path)
            # NOTE(security): pickle.load can execute arbitrary code while
            # deserializing; only pass files that come from a trusted source.
            data = pickle.load(reader)
            results.extend(data)
            logger.info("Aggregated data size: %s", len(results))
    logger.info("Total data size: %s", len(results))
    return results


def read_data_from_json_files(paths: List[str], upsample_rates: List = None) -> List:
    """Load JSON data from several files, optionally up-sample each, and concatenate.

    Args:
        paths: paths of UTF-8 encoded JSON files. Each file's top-level value
            is repeated with ``*`` and appended with ``list.extend``, so it is
            expected to be a JSON array.
        upsample_rates: one repetition factor per file (converted with
            ``int()``); the items of ``paths[i]`` are repeated
            ``upsample_rates[i]`` times. Defaults to 1 for every file.

    Returns:
        The concatenated (and up-sampled) items of all files, in file order.

    Raises:
        AssertionError: (when assertions are enabled) if ``upsample_rates``
            and ``paths`` have different lengths.
    """
    results = []
    if upsample_rates is None:
        upsample_rates = [1] * len(paths)

    assert len(upsample_rates) == len(
        paths
    ), "up-sample rates parameter doesn't match input files amount"

    for path, rate in zip(paths, upsample_rates):
        with open(path, "r", encoding="utf-8") as f:
            logger.info("Reading file %s", path)
            data = json.load(f)
            upsample_factor = int(rate)
            data = data * upsample_factor
            results.extend(data)
            logger.info("Aggregated data size: %s", len(results))
    return results


class ShardedDataIterator(object):
    """
    General purpose data iterator to be used for PyTorch's DDP mode where every node
    should handle its own part of the data.
    Instead of cutting data shards by their min size, it sets the number of iterations
    by the maximum shard size. A shard that runs out of samples early keeps yielding
    its first batch until that number of iterations is reached.
    It can also optionally enforce an identical batch size for all iterations (might be
    useful for DP mode).
    """

    def __init__(
        self,
        data: list,
        shard_id: int = 0,
        num_shards: int = 1,
        batch_size: int = 1,
        shuffle=True,
        shuffle_seed: int = 0,
        offset: int = 0,
        strict_batch_size: bool = False,
    ):
        """
        Args:
            data: the full (un-sharded) list of samples.
            shard_id: index of the shard handled by this iterator (clamped to >= 0).
            num_shards: total number of shards (clamped to >= 1).
            batch_size: number of samples per yielded batch.
            shuffle: whether to shuffle the samples each epoch.
            shuffle_seed: base seed; the shuffle for an epoch uses ``shuffle_seed + epoch``.
            offset: number of batches already consumed, used to resume mid-epoch.
            strict_batch_size: if True, a trailing short batch is padded with the
                shard's first samples and the iteration budget is rounded up.
        """
        self.data = data
        total_size = len(data)

        self.shards_num = max(num_shards, 1)
        self.shard_id = max(shard_id, 0)

        # Every shard gets the same nominal size; the last shard(s) may be shorter.
        samples_per_shard = math.ceil(total_size / self.shards_num)

        self.shard_start_idx = self.shard_id * samples_per_shard
        self.shard_end_idx = min(self.shard_start_idx + samples_per_shard, total_size)

        # Iteration budget shared by all shards, derived from the largest shard size.
        if strict_batch_size:
            self.max_iterations = math.ceil(samples_per_shard / batch_size)
        else:
            self.max_iterations = int(samples_per_shard / batch_size)

        logger.debug(
            "samples_per_shard=%d, shard_start_idx=%d, shard_end_idx=%d, max_iterations=%d",
            samples_per_shard,
            self.shard_start_idx,
            self.shard_end_idx,
            self.max_iterations,
        )

        self.iteration = offset  # to track in-shard iteration status
        self.shuffle = shuffle
        self.batch_size = batch_size
        self.shuffle_seed = shuffle_seed
        self.strict_batch_size = strict_batch_size

    def total_data_len(self) -> int:
        """Return the number of samples across all shards."""
        return len(self.data)

    def _shard_indices(self, epoch: int) -> List[int]:
        """Return the positions in ``self.data`` belonging to this shard for ``epoch``."""
        indices = list(range(len(self.data)))
        if self.shuffle:
            # Shuffle a fresh index list with a seed that depends only on
            # (shuffle_seed, epoch): a restarted process then reproduces exactly the
            # same order for that epoch, and all shards see the same permutation.
            epoch_rnd = random.Random(self.shuffle_seed + epoch)
            epoch_rnd.shuffle(indices)
        return indices[self.shard_start_idx : self.shard_end_idx]

    def iterate_data(self, epoch: int = 0) -> Iterator[List]:
        """Yield this shard's batches for ``epoch``.

        Starts after the ``self.iteration`` batches already consumed (resume
        support), then, if this shard ran out of samples, keeps yielding its
        first batch until ``self.max_iterations`` batches have been produced in
        total. Resets ``self.iteration`` to 0 once the epoch is exhausted.
        ``self.data`` itself is not reordered.

        Note: an empty shard yields empty lists as padding batches.
        """
        shard_indices = self._shard_indices(epoch)

        # Skip the batches already consumed when resuming from ``offset``.
        for i in range(
            self.iteration * self.batch_size, len(shard_indices), self.batch_size
        ):
            batch_indices = shard_indices[i : i + self.batch_size]
            if self.strict_batch_size and len(batch_indices) < self.batch_size:
                logger.debug("Extending batch to max size")
                batch_indices.extend(
                    shard_indices[0 : self.batch_size - len(batch_indices)]
                )
            self.iteration += 1
            yield [self.data[j] for j in batch_indices]

        # Some shards may be done iterating while others are still at their last
        # batch; pad by returning the first batch. ``self.iteration`` is an absolute
        # count (it already includes the resume offset), so it must be compared with
        # the full budget, not with ``max_iterations - offset``.
        while self.iteration < self.max_iterations:
            logger.debug("Fulfilling non complete shard=%s", self.shard_id)
            self.iteration += 1
            yield [self.data[j] for j in shard_indices[0 : self.batch_size]]

        logger.debug(
            "Finished iterating, iteration=%s, shard=%s", self.iteration, self.shard_id
        )
        # reset the iteration status
        self.iteration = 0

    def get_iteration(self) -> int:
        """Return the number of batches consumed so far in the current epoch."""
        return self.iteration

    def apply(self, visitor_func: Callable):
        """Call ``visitor_func`` on every sample of the full data set; results are ignored."""
        for sample in self.data:
            visitor_func(sample)


def normalize_question(question: str) -> str:
    """Return ``question`` with one trailing ``?`` removed, if present.

    Empty strings are returned unchanged.
    """
    if question.endswith("?"):
        question = question[:-1]
    return question


class Tensorizer(object):
    """
    Component for all text to model input data conversions and related utility methods.
    Every method here is abstract and raises NotImplementedError; subclasses implement them.
    """

    # Note: title, if present, is supposed to be put before text (i.e. optional title + document body)
    def text_to_tensor(
        self, text: str, title: str = None, add_special_tokens: bool = True
    ):
        """Convert ``text`` (optionally preceded by ``title``) to model input. Abstract."""
        raise NotImplementedError

    def get_pair_separator_ids(self) -> T:
        """Return the separator token ids as a tensor. Abstract."""
        raise NotImplementedError

    def get_pad_id(self) -> int:
        """Return the padding token id. Abstract."""
        raise NotImplementedError

    def get_attn_mask(self, tokens_tensor: T):
        """Return an attention mask for ``tokens_tensor``. Abstract."""
        raise NotImplementedError

    def is_sub_word_id(self, token_id: int):
        """Return whether ``token_id`` denotes a sub-word piece. Abstract."""
        raise NotImplementedError

    def to_string(self, token_ids, skip_special_tokens=True):
        """Convert ``token_ids`` back to text. Abstract."""
        raise NotImplementedError

    def set_pad_to_max(self, pad: bool):
        """Set whether inputs are padded to the maximum length. Abstract."""
        raise NotImplementedError