import json
import logging
import math
import pickle
import random
from typing import List, Iterator, Callable

from torch import Tensor as T

logger = logging.getLogger()

def read_serialized_data_from_files(paths: List[str]) -> List:
    results = []
    for path in paths:
        with open(path, "rb") as reader:
            logger.info("Reading file %s", path)
            data = pickle.load(reader)
            results.extend(data)
            logger.info("Aggregated data size: {}".format(len(results)))
    logger.info("Total data size: {}".format(len(results)))
    return results

def read_data_from_json_files(paths: List[str], upsample_rates: List = None) -> List:
    results = []
    if upsample_rates is None:
        upsample_rates = [1] * len(paths)

    assert len(upsample_rates) == len(
        paths
    ), "up-sample rates parameter doesn't match input files amount"

    for i, path in enumerate(paths):
        with open(path, "r", encoding="utf-8") as f:
            logger.info("Reading file %s", path)
            data = json.load(f)
            upsample_factor = int(upsample_rates[i])
            data = data * upsample_factor
            results.extend(data)
            logger.info("Aggregated data size: {}".format(len(results)))
    return results
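
# Illustrative sketch (not part of the original module): shows how the JSON reader is
# typically called. The file paths below are hypothetical placeholders. With upsample_rates,
# each file's samples are repeated element-wise before aggregation, e.g. a rate of 2 puts
# every sample from that file into the merged list twice.
def _example_read_json_files():
    paths = ["train-part1.json", "train-part2.json"]  # hypothetical paths
    samples = read_data_from_json_files(paths, upsample_rates=[2, 1])
    logger.info("Loaded %d training samples", len(samples))
    return samples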

class ShardedDataIterator(object):
    """
    General purpose data iterator for PyTorch's DDP mode where every node should handle its own part of the data.
    Instead of cutting data shards to their minimum size, it sets the number of iterations by the maximum shard
    size and pads shorter shards by re-using their first samples.
    It can also optionally enforce an identical batch size for all iterations (might be useful for DP mode).
    """

    def __init__(
        self,
        data: list,
        shard_id: int = 0,
        num_shards: int = 1,
        batch_size: int = 1,
        shuffle=True,
        shuffle_seed: int = 0,
        offset: int = 0,
        strict_batch_size: bool = False,
    ):
        self.data = data
        total_size = len(data)

        self.shards_num = max(num_shards, 1)
        self.shard_id = max(shard_id, 0)

        samples_per_shard = math.ceil(total_size / self.shards_num)

        self.shard_start_idx = self.shard_id * samples_per_shard
        self.shard_end_idx = min(self.shard_start_idx + samples_per_shard, total_size)

        if strict_batch_size:
            self.max_iterations = math.ceil(samples_per_shard / batch_size)
        else:
            self.max_iterations = int(samples_per_shard / batch_size)

        logger.debug(
            "samples_per_shard=%d, shard_start_idx=%d, shard_end_idx=%d, max_iterations=%d",
            samples_per_shard,
            self.shard_start_idx,
            self.shard_end_idx,
            self.max_iterations,
        )

        self.iteration = offset  # to track in-shard iteration status
        self.shuffle = shuffle
        self.batch_size = batch_size
        self.shuffle_seed = shuffle_seed
        self.strict_batch_size = strict_batch_size

    def total_data_len(self) -> int:
        return len(self.data)

    def iterate_data(self, epoch: int = 0) -> Iterator[List]:
        if self.shuffle:
            # to be able to resume, the same shuffling should be used when starting from a failed/stopped iteration
            epoch_rnd = random.Random(self.shuffle_seed + epoch)
            epoch_rnd.shuffle(self.data)

        # if resuming iteration somewhere in the middle of the epoch, one needs to adjust max_iterations
        max_iterations = self.max_iterations - self.iteration

        shard_samples = self.data[self.shard_start_idx : self.shard_end_idx]
        for i in range(
            self.iteration * self.batch_size, len(shard_samples), self.batch_size
        ):
            items = shard_samples[i : i + self.batch_size]
            if self.strict_batch_size and len(items) < self.batch_size:
                logger.debug("Extending batch to max size")
                items.extend(shard_samples[0 : self.batch_size - len(items)])
            self.iteration += 1
            yield items

        # some shards may be done iterating while others are still on their last batch. Just return the first batch
        while self.iteration < max_iterations:
            logger.debug("Fulfilling non complete shard={}".format(self.shard_id))
            self.iteration += 1
            batch = shard_samples[0 : self.batch_size]
            yield batch

        logger.debug(
            "Finished iterating, iteration={}, shard={}".format(
                self.iteration, self.shard_id
            )
        )
        # reset the iteration status
        self.iteration = 0

    def get_iteration(self) -> int:
        return self.iteration

    def apply(self, visitor_func: Callable):
        for sample in self.data:
            visitor_func(sample)
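
# Illustrative sketch (not part of the original module): how the iterator is typically driven
# in DDP training. `rank` and `world_size` would normally come from the distributed setup
# (e.g. torch.distributed); here they are hypothetical placeholders.
def _example_sharded_iteration(data: list, rank: int = 0, world_size: int = 2):
    it = ShardedDataIterator(
        data,
        shard_id=rank,
        num_shards=world_size,
        batch_size=4,
        shuffle=True,
        shuffle_seed=0,
    )
    for epoch in range(2):
        # every rank performs the same number of iterations, so collective ops stay in sync
        for batch in it.iterate_data(epoch=epoch):
            pass  # forward/backward pass would go here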

def normalize_question(question: str) -> str:
    if question.endswith("?"):
        question = question[:-1]
    return question

class Tensorizer(object):
    """
    Component for all text to model input data conversions and related utility methods
    """

    # Note: title, if present, is supposed to be put before text (i.e. optional title + document body)
    def text_to_tensor(
        self, text: str, title: str = None, add_special_tokens: bool = True
    ):
        raise NotImplementedError

    def get_pair_separator_ids(self) -> T:
        raise NotImplementedError

    def get_pad_id(self) -> int:
        raise NotImplementedError

    def get_attn_mask(self, tokens_tensor: T):
        raise NotImplementedError

    def is_sub_word_id(self, token_id: int):
        raise NotImplementedError

    def to_string(self, token_ids, skip_special_tokens=True):
        raise NotImplementedError

    def set_pad_to_max(self, pad: bool):
        raise NotImplementedError
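
# Illustrative sketch (not part of the original module): a deliberately tiny Tensorizer
# implementation based on a whitespace vocabulary, only to show how the interface fits together.
# A real implementation would wrap a proper sub-word tokenizer (e.g. BERT's); the vocabulary,
# id conventions, and max_length below are made up, and only a subset of methods is implemented.
import torch  # used only by the illustrative example below


class _ToyWhitespaceTensorizer(Tensorizer):
    def __init__(self, vocab: dict, max_length: int = 16):
        self.vocab = vocab  # token -> id; id 0 is reserved for padding, 1 for unknown tokens
        self.max_length = max_length
        self.pad_to_max = True

    def text_to_tensor(self, text: str, title: str = None, add_special_tokens: bool = True):
        # optional title is prepended to the document body, mirroring the note above
        tokens = ((title + " " + text) if title else text).lower().split()
        ids = [self.vocab.get(t, 1) for t in tokens][: self.max_length]
        if self.pad_to_max:
            ids = ids + [self.get_pad_id()] * (self.max_length - len(ids))
        return torch.tensor(ids, dtype=torch.long)

    def get_pad_id(self) -> int:
        return 0

    def get_attn_mask(self, tokens_tensor: T):
        # attend to every non-padding position
        return tokens_tensor != self.get_pad_id()

    def set_pad_to_max(self, pad: bool):
        self.pad_to_max = pad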