import dataclasses
import pprint
import time
from functools import partial
import json
import base64
from multiprocessing import Pool

import h5py
import mlxu
from ml_collections.config_dict import config_dict
from ml_collections import ConfigDict
from tqdm import tqdm, trange
import numpy as np

from datasets import load_dataset, load_from_disk


class DatasetFactory(object):
    """ Dataset builder class. """

    @staticmethod
    def get_default_config(updates=None):
        config = ConfigDict()
        config.type = 'huggingface'
        config.text_processor = TextProcessor.get_default_config()
        config.huggingface_dataset = HuggingfaceDataset.get_default_config()
        config.json_dataset = JsonDataset.get_default_config()

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    @classmethod
    def load_dataset(cls, config, tokenizer, **kwargs):
        config = cls.get_default_config(config)
        text_processor = TextProcessor(config.text_processor, tokenizer)
        if config.type == 'huggingface':
            return HuggingfaceDataset(
                config.huggingface_dataset, tokenizer, text_processor, **kwargs
            )
        elif config.type == 'json':
            return JsonDataset(config.json_dataset, tokenizer, text_processor, **kwargs)
        else:
            raise ValueError(f'Unknown dataset type: {config.type}')

    def __init__(self):
        raise ValueError('DatasetFactory is a static class and should not be instantiated.')
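

# Illustrative usage sketch (not part of the original module): build a dataset through
# DatasetFactory. The 'gpt2' tokenizer and the 'train.jsonl' path are assumptions for
# the example; the file is assumed to contain one JSON object with a 'text' field per line.
def _example_dataset_factory_usage():
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    dataset = DatasetFactory.load_dataset(
        {
            'type': 'json',
            'text_processor': {'fields': 'text'},
            'json_dataset': {'path': 'train.jsonl', 'seq_length': 1024, 'batch_size': 8},
        },
        tokenizer,
    )
    # Each step yields a (batch, metrics) pair; batch holds numpy token arrays.
    batch, metrics = next(iter(dataset))
    return batch, metrics
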
class TextProcessor(object):
    """ Example processor that converts a dictionary of texts into tokens. """

    @staticmethod
    def get_default_config(updates=None):
        config = ConfigDict()
        config.fields_from_example = ''
        config.fields = ''
        config.subfield_separator = ' '
        config.add_bos_token = True
        config.add_eos_token = True
        config.prepend_text = ''
        config.base64_token_dtype = 'i4'
        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    def __init__(self, config, tokenizer):
        self.config = self.get_default_config(config)
        assert self.config.fields != '' or self.config.fields_from_example != '', (
            'Either fields or fields_from_example must be specified.'
        )
        self.tokenizer = tokenizer

    def __call__(self, example, has_aux=False):
        if has_aux:
            example, *aux = example
        else:
            aux = tuple()
        token_buffer = []
        loss_mask_buffer = []

        if self.config.add_bos_token:
            token_buffer.append(self.tokenizer.bos_token_id)
            loss_mask_buffer.append(0.0)

        if self.config.fields_from_example != '':
            fields = example[self.config.fields_from_example].split(',')
        else:
            fields = self.config.fields.split(',')

        for i, field in enumerate(fields):
            if field.startswith('[') and field.endswith(']'):
                # No loss is computed for this field.
                field = field[1:-1]
                mask = 0.0
            else:
                mask = 1.0

            if field.startswith('<|') and field.endswith('|>'):
                # Special token field: <|bos|>, <|eos|>, or a literal token id.
                field = field[2:-2]
                if field == 'bos':
                    token_buffer.append(self.tokenizer.bos_token_id)
                elif field == 'eos':
                    token_buffer.append(self.tokenizer.eos_token_id)
                else:
                    # Token id specified directly in the field.
                    token_buffer.append(int(field))
                loss_mask_buffer.append(mask)
            elif field.startswith('{') and field.endswith('}'):
                field = field[1:-1]
                # Field contains base64-encoded, pre-tokenized token ids.
                tokens = np.frombuffer(
                    base64.b64decode(example[field]),
                    dtype=self.config.base64_token_dtype
                ).tolist()
                token_buffer.extend(tokens)
                loss_mask_buffer.extend([mask for _ in range(len(tokens))])
            else:
                subfields = field.split('+')
                text = self.config.subfield_separator.join(
                    [example[subfield] for subfield in subfields]
                )
                if i == 0:
                    text = self.config.prepend_text + text
                tokens = self.tokenizer.encode(text)
                token_buffer.extend(tokens)
                loss_mask_buffer.extend([mask for _ in range(len(tokens))])

        if self.config.add_eos_token:
            token_buffer.append(self.tokenizer.eos_token_id)
            loss_mask_buffer.append(1.0)

        return token_buffer, loss_mask_buffer, *aux
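

# Illustrative sketch (not part of the original module) of the `fields` mini-grammar
# handled above: `[field]` is tokenized but masked out of the loss, `<|bos|>`/`<|eos|>`
# inject special tokens, `{field}` decodes base64-packed token ids, and `a+b` joins
# subfields with the separator. The toy tokenizer and the 'prompt'/'completion'/'packed'
# example keys are assumptions for the example only.
def _example_text_processor_fields():
    toy_tokenizer = type('ToyTokenizer', (), {
        'bos_token_id': 0,
        'eos_token_id': 1,
        # Stand-in tokenizer: one fake token per whitespace-separated word.
        'encode': lambda self, text: [len(word) + 2 for word in text.split()],
    })()

    # Prompt tokens get loss mask 0.0, completion tokens get loss mask 1.0.
    processor = TextProcessor({'fields': '[prompt],completion'}, toy_tokenizer)
    tokens, loss_masks = processor({'prompt': 'some question', 'completion': 'an answer'})

    # Pre-tokenized data can be passed as base64-packed int32 ids via `{field}`.
    packed = base64.b64encode(np.array([5, 6, 7], dtype='i4').tobytes()).decode('ascii')
    packed_processor = TextProcessor(
        {'fields': '{packed}', 'add_bos_token': False, 'add_eos_token': False},
        toy_tokenizer,
    )
    packed_tokens, packed_masks = packed_processor({'packed': packed})
    return tokens, loss_masks, packed_tokens, packed_masks
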
class HuggingfaceDataset(object):
    """ Huggingface dataset, where the dataset is loaded from disk using the
    huggingface datasets.load_from_disk() function and converted to an iterable
    dataset for streaming.
    """

    @staticmethod
    def get_default_config(updates=None):
        config = ConfigDict()
        config.path = 'c4'
        config.name = 'en'
        config.split = 'train'
        config.streaming = False
        config.seq_length = 1024
        config.batch_size = 8
        config.always_start_with_bos = False
        config.start_seek_loc = 0
        config.tokens_count_at_start = 0
        config.batch_token_dtype = 'i4'
        config.reset_dataset_loc = False

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    def __init__(self, config, tokenizer, text_processor, eval_dataset=False):
        self.config = self.get_default_config(config)
        name = self.config.name if self.config.name != '' else None
        split = self.config.split if self.config.split != '' else None
        self._tokenizer = tokenizer
        self._text_processor = text_processor
        self._dataset = load_from_disk(
            self.config.path
        )[split]
        self._dataset = self._dataset.to_iterable_dataset(
            num_shards=128 if len(self._dataset) > 128 else len(self._dataset)
        )
        self._eval_dataset = eval_dataset
        self._train_epochs = 0
        self._dataset_loc = self.config.start_seek_loc
        self._total_tokens = self.config.tokens_count_at_start
        self._index = 0
        self.reset_dataset_loc = self.config.reset_dataset_loc

    def __iter__(self):
        if not self._eval_dataset and self._train_epochs > 0:
            self._dataset = self._dataset.shuffle(seed=42, buffer_size=10000)
        chunk_size = self.config.batch_size * self.config.seq_length
        while True:
            token_buffer = []
            loss_mask_buffer = []
            if not self._eval_dataset and self._train_epochs > 0:
                self._dataset.set_epoch(self._train_epochs)
            for index, example in enumerate(self._dataset):
                self._index = index
                if not self._eval_dataset and self._dataset_loc > index:
                    continue
                tokens, loss_masks = self.text_processor(example)
                token_buffer.extend(tokens)
                loss_mask_buffer.extend(loss_masks)
                while len(token_buffer) > chunk_size + 1:
                    self._total_tokens += chunk_size
                    metrics = {
                        'dataset_example_index': index,
                        'dataset_total_tokens': self._total_tokens,
                        'epoch': self._train_epochs,
                    }
                    batch = {
                        'input_tokens': np.array(token_buffer[:chunk_size], dtype=self.config.batch_token_dtype).reshape(
                            self.config.batch_size, -1
                        ),
                        'target_tokens': np.array(token_buffer[1:chunk_size + 1], dtype=self.config.batch_token_dtype).reshape(
                            self.config.batch_size, -1
                        ),
                        'loss_masks': np.array(loss_mask_buffer[1:chunk_size + 1], dtype=np.float32).reshape(
                            self.config.batch_size, -1
                        ),
                    }
                    if self.config.always_start_with_bos:
                        batch['input_tokens'][:, 0] = self.tokenizer.bos_token_id
                    yield batch, metrics
                    token_buffer = token_buffer[chunk_size:]
                    loss_mask_buffer = loss_mask_buffer[chunk_size:]

            if self._eval_dataset:
                break
            else:
                if self._train_epochs == 0:
                    self._dataset = self._dataset.shuffle(seed=42, buffer_size=10000)
                self._dataset_loc = 0
                self._train_epochs += 1

    def get_state_dict(self):
        return dict(
            config=self.config,
            dataset_loc=self._index,
            total_tokens=self._total_tokens,
            epochs=self._train_epochs,
        )

    def load_state_dict(self, state_dict):
        if 'config' in state_dict:
            self.config.update(ConfigDict(state_dict['config']))
        self._dataset_loc = state_dict.get('dataset_loc', self.config.start_seek_loc)
        self._total_tokens = state_dict.get('total_tokens', self.config.tokens_count_at_start)
        self._train_epochs = state_dict.get('epochs', 0)
        if self.reset_dataset_loc:
            self._dataset_loc = 0
            self._train_epochs = 0

    @property
    def seq_length(self):
        return self.config.seq_length

    @property
    def tokenizer(self):
        return self._tokenizer

    @property
    def text_processor(self):
        return self._text_processor

    @property
    def dataset(self):
        return self._dataset

    @property
    def vocab_size(self):
        return len(self._tokenizer)
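

# Illustrative checkpoint/resume sketch for HuggingfaceDataset (not part of the original
# module). The '/data/my_saved_dataset' path (a directory produced by save_to_disk) and
# the 'gpt2' tokenizer are assumptions for the example only.
def _example_huggingface_resume():
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    text_processor = TextProcessor({'fields': 'text'}, tokenizer)
    dataset = HuggingfaceDataset(
        {'path': '/data/my_saved_dataset', 'split': 'train',
         'seq_length': 1024, 'batch_size': 8},
        tokenizer, text_processor,
    )
    batch, metrics = next(iter(dataset))

    # Persist the iterator position alongside the model checkpoint ...
    state = dataset.get_state_dict()
    # ... and restore it later so already-consumed examples are skipped on resume.
    dataset.load_state_dict(state)
    return batch, metrics
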
class JsonDataset(object):
    """ JSON dataset, where each line of the data file contains a JSON
    dictionary with text fields.
    """

    @staticmethod
    def get_default_config(updates=None):
        config = ConfigDict()
        config.path = ''
        config.seq_length = 1024
        config.batch_size = 8
        config.always_start_with_bos = False
        config.start_seek_loc = 0
        config.example_index_at_start = 0
        config.tokens_count_at_start = 0
        config.tokenizer_processes = 1
        config.tokenizer_parallel_chunk_size = 32
        config.tokenizer_parallel_batch_size = 1024
        config.throughput_average_window_size = 200

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    def __init__(self, config, tokenizer, text_processor):
        self.config = self.get_default_config(config)
        assert self.config.path != ''
        self._tokenizer = tokenizer
        self._text_processor = text_processor
        self._index = self.config.example_index_at_start
        self._file_loc = self.config.start_seek_loc
        self._total_tokens = self.config.tokens_count_at_start

    def parse_json(self, line):
        if not line or line == '\n':
            return None
        try:
            data = json.loads(line)
        except json.decoder.JSONDecodeError:
            print(f'Error parsing json line:\n{line}')
            return None
        return data

    def json_iterator(self):
        with mlxu.open_file(self.config.path, 'r') as fin:
            fin.seek(self._file_loc)
            while True:
                line = fin.readline()
                self._file_loc = fin.tell()
                if not line:
                    # Reached EOF: wrap around to the beginning of the file.
                    self._index = 0
                    fin.seek(0)
                    continue

                data = self.parse_json(line)
                if data is not None:
                    # Only yield lines that parsed successfully.
                    yield data, self._file_loc, self._index
                self._index += 1

    def batched(self, iterator, batch_size):
        batch = []
        for example in iterator:
            batch.append(example)
            if len(batch) == batch_size:
                yield batch
                batch = []
        if len(batch) > 0:
            yield batch

    def parallel_example_iterator(self):
        if self.config.tokenizer_processes == 1:
            for example, loc, index in self.json_iterator():
                yield self.text_processor((example, loc, index), has_aux=True)
        else:
            process_pool = Pool(self.config.tokenizer_processes)
            batched_iterator = self.batched(
                self.json_iterator(), self.config.tokenizer_parallel_batch_size
            )
            with process_pool as pool:
                map_fn = partial(self.text_processor, has_aux=True)
                next_batch = pool.map_async(
                    map_fn, next(batched_iterator),
                    chunksize=self.config.tokenizer_parallel_chunk_size
                )
                while True:
                    current_batch = next_batch
                    next_batch = pool.map_async(
                        map_fn, next(batched_iterator),
                        chunksize=self.config.tokenizer_parallel_chunk_size
                    )
                    for example in current_batch.get():
                        yield example

    def __iter__(self):
        chunk_size = self.config.batch_size * self.config.seq_length
        token_buffer = []
        loss_mask_buffer = []
        # Initialize to the current time so the first recorded step time is meaningful.
        last_time = time.time()
        step_times = []
        start_time = time.time()
        start_tokens = self._total_tokens
        for tokens, loss_masks, loc, index in self.parallel_example_iterator():
            token_buffer.extend(tokens)
            loss_mask_buffer.extend(loss_masks)
            while len(token_buffer) > chunk_size + 1:
                self._total_tokens += chunk_size
                step_times.append(time.time() - last_time)
                last_time = time.time()
                if len(step_times) > self.config.throughput_average_window_size:
                    step_times = step_times[-self.config.throughput_average_window_size:]
                average_throughput = chunk_size / np.mean(step_times)
                accumulated_throughput = (
                    (self._total_tokens - start_tokens) / (time.time() - start_time)
                )
                metrics = {
                    'dataset_file_loc': loc,
                    'dataset_example_index': index,
                    'dataset_total_tokens': self._total_tokens,
                    'dataset_accumulated_tps': accumulated_throughput,
                    'dataset_average_tps': average_throughput,
                }
                batch = {
                    'input_tokens': np.array(token_buffer[:chunk_size], dtype=np.int32).reshape(
                        self.config.batch_size, -1
                    ),
                    'target_tokens': np.array(token_buffer[1:chunk_size + 1], dtype=np.int32).reshape(
                        self.config.batch_size, -1
                    ),
                    'loss_masks': np.array(loss_mask_buffer[1:chunk_size + 1], dtype=np.float32).reshape(
                        self.config.batch_size, -1
                    ),
                }
                if self.config.always_start_with_bos:
                    batch['input_tokens'][:, 0] = self.tokenizer.bos_token_id
                yield batch, metrics
                token_buffer = token_buffer[chunk_size:]
                loss_mask_buffer = loss_mask_buffer[chunk_size:]

    def get_state_dict(self):
        return dict(
            config=self.config,
            index=self._index,
            file_loc=self._file_loc,
            total_tokens=self._total_tokens,
        )

    def load_state_dict(self, state_dict):
        if 'config' in state_dict:
            self.config.update(ConfigDict(state_dict['config']))
        self._index = state_dict.get('index', self.config.example_index_at_start)
        self._file_loc = state_dict.get('file_loc', self.config.start_seek_loc)
        self._total_tokens = state_dict.get('total_tokens', self.config.tokens_count_at_start)

    @property
    def seq_length(self):
        return self.config.seq_length

    @property
    def tokenizer(self):
        return self._tokenizer

    @property
    def text_processor(self):
        return self._text_processor

    @property
    def vocab_size(self):
        return len(self.tokenizer)
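

# Illustrative sketch of a JsonDataset configured for multi-process tokenization (not
# part of the original module). The 'train.jsonl' path and 'gpt2' tokenizer are
# assumptions; each line of the file is assumed to be a JSON object with a 'text' field.
def _example_json_parallel_tokenization():
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    text_processor = TextProcessor({'fields': 'text'}, tokenizer)
    dataset = JsonDataset(
        {
            'path': 'train.jsonl',
            'seq_length': 2048,
            'batch_size': 4,
            # Spread tokenization over four worker processes, dispatching 1024 lines
            # at a time in chunks of 32 per worker call.
            'tokenizer_processes': 4,
            'tokenizer_parallel_batch_size': 1024,
            'tokenizer_parallel_chunk_size': 32,
        },
        tokenizer, text_processor,
    )
    batch, metrics = next(iter(dataset))
    return batch, metrics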