| import torch | |
| from torch.utils.data import IterableDataset | |
| import random | |
| import jsonlines | |
| from tqdm import tqdm | |
| from tokenizer import load_tokenizer | |
| import concurrent.futures | |
class MultiSourceDataset(IterableDataset):
    """Infinite stream of JSONL records drawn from weighted groups of files.

    Sampling works in two steps:

    1. A source group ``i`` is chosen with weight ``probs[i]``. A file is then
       picked uniformly from that group and loaded fully into memory.
    2. ``mini_batch`` records are drawn from that file, uniformly and with
       replacement. Then the next file is chosen.
    """

    def __init__(self, source_files, probs, mini_batch=4):
        """
        source_files example: [[file1.jsonl, file2.jsonl],[file3.jsonl]]
        probs example: [0.3, 0.7]

        Args:
            source_files: one list of JSONL paths per source group.
            probs: sampling weight per source group, aligned with
                ``source_files``. ``torch.multinomial`` treats the values as
                weights, so they do not have to sum to 1.
            mini_batch: number of records yielded from each loaded file.

        Raises:
            ValueError: if ``mini_batch`` is smaller than 1.
        """
        super().__init__()
        if mini_batch < 1:
            # With mini_batch <= 0 the counter below never returns to 0, so
            # only a single file would ever be loaded.
            raise ValueError(f"mini_batch must be >= 1, got {mini_batch}")
        self.sources = source_files
        # torch.multinomial only accepts floating-point input; integer weights
        # such as [1, 3] would otherwise fail on the first draw.
        self.probs = torch.tensor(probs, dtype=torch.float)
        self.mini_batch = mini_batch
        # Iteration state. It lives on the instance, so every iterator created
        # from this object shares it.
        self.curr_count = 0  # records drawn from the currently loaded file
        self.curr_data = []  # records of the currently loaded file

    def _load_random_file(self):
        """Pick a group by weight and a file within it; return all its records."""
        idx = torch.multinomial(self.probs, 1).item()
        filename = random.choice(self.sources[idx])
        with jsonlines.open(filename, 'r') as f:
            return list(f)

    def __iter__(self):
        """Yield records forever, switching to a new random file every ``mini_batch`` draws.

        Raises:
            IndexError: from ``random.choice`` if a loaded file has no records.
        """
        while True:
            if self.curr_count == 0:
                self.curr_data = self._load_random_file()
            self.curr_count += 1
            if self.curr_count == self.mini_batch:
                self.curr_count = 0
            yield random.choice(self.curr_data)
class MultiSourceDatasetV2(IterableDataset):
    """Infinite stream of samples built from weighted groups of JSONL files.

    Samples are produced in rounds:

    - ``num_workers`` threads each build a buffer by running
      ``buffer_size_per_worker`` draws.
    - A draw picks a source group by weight and a file uniformly within that
      group, loads the file, and takes up to ``mini_batch`` distinct records.
    - The per-thread buffers are concatenated and yielded in order before the
      next round starts.

    With a tokenizer, each record's ``'text'`` field is encoded and turned into
    an ``(input, target)`` pair of int64 tensors (see ``_transform_ids``).
    Without one, the raw records are yielded.
    """

    def __init__(self, source_files, probs, mini_batch=4, buffer_size_per_worker=12000,
                 num_workers=10, tokenizer_path="tokenizer.model", block_size=1024, eos_id=50303):
        """
        source_files example: [[file1.jsonl, file2.jsonl],[file3.jsonl]]
        probs example: [0.3, 0.7]

        Args:
            source_files: one list of JSONL paths per source group.
            probs: sampling weight per source group, aligned with
                ``source_files``. The values do not have to sum to 1.
            mini_batch: maximum number of distinct records taken from each
                loaded file.
            buffer_size_per_worker: number of file draws each thread performs
                per round.
            num_workers: number of threads used to fill the buffer.
            tokenizer_path: passed to ``load_tokenizer``. A falsy value
                disables tokenisation, and raw records are yielded.
            block_size: sequence length of the tensors produced when
                tokenising.
            eos_id: token id appended to each sequence and used as padding.
                NOTE(review): presumably the tokenizer's EOS id; confirm.

        Raises:
            ValueError: if ``mini_batch`` is smaller than 1.
        """
        super().__init__()
        if mini_batch < 1:
            raise ValueError(f"mini_batch must be >= 1, got {mini_batch}")
        self.sources = source_files
        # torch.multinomial only accepts floating-point input.
        self.probs = torch.tensor(probs, dtype=torch.float)
        self.mini_batch = mini_batch
        self.buffer_size_per_worker = buffer_size_per_worker
        self.num_workers = num_workers
        self.block_size = block_size
        self.eos_id = eos_id
        if tokenizer_path:
            self.tokenizer = load_tokenizer(tokenizer_path)
        else:
            self.tokenizer = None

    def _transform_ids(self, ids, block_size=1024, eos_id=50303):
        """Turn a token-id sequence into a fixed-length ``(input, target)`` pair.

        ``eos_id`` is appended first. The sequence is then brought to exactly
        ``block_size + 1`` tokens:

        - longer sequences are cropped to a random contiguous window;
        - shorter ones are right-padded with ``eos_id``.

        Returns:
            Two int64 tensors of length ``block_size``: ``ids[:-1]`` and
            ``ids[1:]``, i.e. the targets are the inputs shifted by one.
            NOTE(review): padded positions appear in the targets as
            ``eos_id``. Confirm the downstream loss is meant to train on them.
        """
        window = block_size + 1
        # list() makes this work for any sequence of ints (e.g. tuples), not only lists.
        ids = list(ids) + [eos_id]
        if len(ids) > window:
            start = random.randint(0, len(ids) - window)
            ids = ids[start:start + window]
        else:
            ids.extend([eos_id] * (window - len(ids)))
        ids = torch.tensor(ids, dtype=torch.int64)
        return ids[:-1], ids[1:]

    def _get_buffer(self):
        """Build and return one thread's list of samples.

        Runs ``buffer_size_per_worker`` draws. Each draw loads a randomly
        chosen file and keeps up to ``mini_batch`` distinct records from it.
        """
        print('\nGetting data buffer...')
        buffer = []
        for _ in tqdm(range(self.buffer_size_per_worker)):
            idx = torch.multinomial(self.probs, 1).item()
            filename = random.choice(self.sources[idx])
            with jsonlines.open(filename, 'r') as f:
                curr_data = list(f)
            # random.sample raises ValueError when asked for more items than
            # exist, so files shorter than mini_batch contribute what they have.
            new_buffer = random.sample(curr_data, min(self.mini_batch, len(curr_data)))
            if self.tokenizer is None:
                buffer.extend(new_buffer)
            else:
                buffer.extend(
                    self._transform_ids(self.tokenizer.encode(item['text']), self.block_size, self.eos_id)
                    for item in new_buffer
                )
        return buffer

    def __iter__(self):
        """Yield samples forever, refilling the buffer with a thread pool each round.

        Raises:
            RuntimeError: if a round produces no samples; otherwise the loop
                would spin forever without yielding.
            Any exception raised while building a buffer, such as an
            unreadable file or a tokenizer error.
        """
        while True:
            with concurrent.futures.ThreadPoolExecutor(max_workers=self.num_workers) as executor:
                futures = [executor.submit(self._get_buffer) for _ in range(self.num_workers)]
                # .result() re-raises any exception from a worker thread
                # instead of letting it vanish with the discarded future.
                self.buffer = [item for future in futures for item in future.result()]
            if not self.buffer:
                raise RuntimeError(
                    "MultiSourceDatasetV2 produced an empty buffer; "
                    "check the source files, mini_batch and buffer_size_per_worker"
                )
            yield from self.buffer