import torch
from torch.utils.data import IterableDataset
import random
import jsonlines
from tqdm import tqdm
from tokenizer import load_tokenizer
import concurrent.futures

# Number of consecutive refill rounds that may yield no samples before
# MultiSourceDatasetV2 gives up instead of looping forever (e.g. when every
# source file is empty or mini_batch / buffer_size_per_worker is 0).
_MAX_EMPTY_ROUNDS = 100


def _read_jsonl(filename):
    """Load every record of the JSON Lines file *filename* into a list."""
    with jsonlines.open(filename, 'r') as f:
        return list(f)


def _pick_source_file(sources, probs):
    """Pick a source group by weight, then one of its files uniformly.

    Args:
        sources: list of file-name lists, one inner list per source group.
        probs: float tensor of per-group weights, passed to torch.multinomial.

    Returns:
        The chosen file name.
    """
    idx = torch.multinomial(probs, 1).item()
    return random.choice(sources[idx])


class MultiSourceDataset(IterableDataset):
    """Infinite stream of records sampled from weighted groups of JSONL files.

    A file is chosen (group by ``probs``, then a file uniformly within the
    group) and loaded fully into memory. ``mini_batch`` records are then
    drawn from it with replacement before the next file is chosen.
    """

    def __init__(self, source_files, probs, mini_batch=4):
        """
        source_files example: [[file1.jsonl, file2.jsonl],[file3.jsonl]]
        probs example: [0.3, 0.7]
        """
        super().__init__()
        self.sources = source_files
        # torch.multinomial needs a floating-point tensor; casting lets
        # integer weights such as [1, 3] work as well.
        self.probs = torch.as_tensor(probs, dtype=torch.float)
        self.mini_batch = mini_batch
        # Records yielded from the currently loaded file; 0 means "load a new file".
        self.curr_count = 0

    def __iter__(self):
        while True:
            if self.curr_count == 0:
                self.curr_data = _read_jsonl(_pick_source_file(self.sources, self.probs))
            self.curr_count += 1
            if self.curr_count == self.mini_batch:
                self.curr_count = 0
            yield random.choice(self.curr_data)


class MultiSourceDatasetV2(IterableDataset):
    """Infinite stream of samples refilled in large buffers by a thread pool.

    Each refill round runs ``num_workers`` concurrent ``_get_buffer`` calls.
    Each call draws ``buffer_size_per_worker`` files and takes ``mini_batch``
    records from each one. With a tokenizer, every record's ``'text'`` field
    is encoded and turned into an ``(inputs, targets)`` tensor pair.
    Otherwise the raw records are yielded.
    """

    def __init__(self, source_files, probs, mini_batch=4, buffer_size_per_worker=12000,
                 num_workers=10, tokenizer_path="tokenizer.model",
                 block_size=1024, eos_id=50303):
        """
        source_files example: [[file1.jsonl, file2.jsonl],[file3.jsonl]]
        probs example: [0.3, 0.7]

        block_size / eos_id are forwarded to ``_transform_ids`` when a
        tokenizer is loaded. Pass a falsy ``tokenizer_path`` to yield raw records.
        """
        super().__init__()
        self.sources = source_files
        # torch.multinomial needs a floating-point tensor.
        self.probs = torch.as_tensor(probs, dtype=torch.float)
        self.mini_batch = mini_batch
        self.buffer_size_per_worker = buffer_size_per_worker
        self.num_workers = num_workers
        self.block_size = block_size
        self.eos_id = eos_id
        self.tokenizer = load_tokenizer(tokenizer_path) if tokenizer_path else None

    def _transform_ids(self, ids, block_size=1024, eos_id=50303):
        """Turn token ids into a fixed-length ``(inputs, targets)`` pair.

        ``eos_id`` is appended, then the sequence is made exactly
        ``block_size + 1`` long. Longer sequences are cropped at a random
        offset; shorter ones are right-padded with ``eos_id``. ``ids`` is
        expected to be a Python list (it is concatenated with ``[eos_id]``).

        Returns:
            Two int64 tensors of length ``block_size``. ``targets`` is
            ``inputs`` shifted by one position.
        """
        ids = ids + [eos_id]  # new list: the caller's list is not mutated
        if len(ids) > block_size + 1:
            start = random.randint(0, len(ids) - block_size - 1)
            ids = ids[start:start + block_size + 1]
        elif len(ids) < block_size + 1:
            ids += [eos_id] * (block_size + 1 - len(ids))
        ids = torch.tensor(ids, dtype=torch.int64)
        return ids[:-1], ids[1:]

    def _draw_group(self, records):
        """Return ``mini_batch`` records drawn from one file's *records*."""
        if len(records) >= self.mini_batch:
            return random.sample(records, self.mini_batch)
        if records:
            # Too few records to draw without replacement (random.sample would
            # raise ValueError). Draw with replacement so the group keeps its size.
            return random.choices(records, k=self.mini_batch)
        return []  # empty file: contributes nothing

    def _get_buffer(self):
        """Build and return one worker's list of samples."""
        print('\nGetting data buffer...')
        buffer = []
        for _ in tqdm(range(self.buffer_size_per_worker)):
            records = _read_jsonl(_pick_source_file(self.sources, self.probs))
            group = self._draw_group(records)
            if self.tokenizer is None:
                buffer.extend(group)
            else:
                buffer.extend(
                    self._transform_ids(self.tokenizer.encode(item['text']),
                                        block_size=self.block_size, eos_id=self.eos_id)
                    for item in group
                )
        return buffer

    def __iter__(self):
        """Yield samples forever, refilling the buffer with a thread pool each round.

        The workers' results are concatenated in submission order.

        Raises:
            Any exception raised inside a worker (re-raised by ``Future.result``).
            RuntimeError: if ``_MAX_EMPTY_ROUNDS`` consecutive rounds yield nothing.
        """
        empty_rounds = 0
        while True:
            with concurrent.futures.ThreadPoolExecutor(max_workers=self.num_workers) as executor:
                futures = [executor.submit(self._get_buffer) for _ in range(self.num_workers)]
            # result() re-raises worker exceptions. Submitting alone would drop
            # them and silently lose that worker's whole buffer.
            self.buffer = [item for future in futures for item in future.result()]
            if not self.buffer:
                empty_rounds += 1
                if empty_rounds >= _MAX_EMPTY_ROUNDS:
                    raise RuntimeError(
                        f'MultiSourceDatasetV2 produced no samples in {empty_rounds} '
                        'consecutive rounds; check the source files and size settings.'
                    )
                continue
            empty_rounds = 0
            yield from self.buffer