LightNovelModel-Alpha / dataset_loader.py
hugfaceguy0001's picture
upload model and train/infer codes
e10f35b verified
Raw
History Blame Contribute Delete
3.41 kB
import torch
from torch.utils.data import IterableDataset
import random
import jsonlines
from tqdm import tqdm
from tokenizer import load_tokenizer
import concurrent.futures
class MultiSourceDataset(IterableDataset):
    """Infinite stream of JSON-lines records drawn from weighted source groups.

    Whenever a new mini batch starts, a source group is chosen with
    probability proportional to ``probs``. One file of that group is picked
    uniformly at random and loaded fully into memory. Then ``mini_batch``
    records are drawn from that file uniformly *with replacement*.
    """

    def __init__(self, source_files, probs, mini_batch=4):
        """
        Args:
            source_files: One list of JSON-lines paths per source, e.g.
                ``[["file1.jsonl", "file2.jsonl"], ["file3.jsonl"]]``.
            probs: Sampling weight of each source group, e.g. ``[0.3, 0.7]``.
                Needs one entry per group. Integer weights are accepted; they
                are cast to float because ``torch.multinomial`` requires it.
            mini_batch: Number of records yielded from a loaded file before a
                new file is sampled. Must be at least 1.

        Raises:
            ValueError: If ``probs`` and ``source_files`` differ in length, or
                if ``mini_batch`` is smaller than 1.
        """
        super().__init__()
        if len(probs) != len(source_files):
            raise ValueError(
                f"probs has {len(probs)} entries but there are "
                f"{len(source_files)} source groups"
            )
        if mini_batch < 1:
            # With mini_batch < 1 the counter in __iter__ never wraps back to
            # 0, so the stream would silently stay on the first file forever.
            raise ValueError(f"mini_batch must be >= 1, got {mini_batch}")
        self.sources = source_files
        # torch.multinomial rejects integer tensors, so store float weights.
        self.probs = torch.as_tensor(probs, dtype=torch.float)
        self.mini_batch = mini_batch
        # Records already yielded from the current file; 0 means "load a new file".
        self.curr_count = 0

    def _load_random_file(self):
        """Pick a source group by weight, then one of its files uniformly; return all records."""
        idx = torch.multinomial(self.probs, 1).item()
        filename = random.choice(self.sources[idx])
        with jsonlines.open(filename, 'r') as f:
            return list(f)

    def __iter__(self):
        """Yield records forever, switching to a freshly sampled file every ``mini_batch`` records.

        The position within the current file is stored on the instance
        (``curr_count`` / ``curr_data``), not locally in this generator.
        """
        while True:
            if self.curr_count == 0:
                self.curr_data = self._load_random_file()
            self.curr_count += 1
            if self.curr_count == self.mini_batch:
                self.curr_count = 0
            # Sampling with replacement; an empty file makes random.choice
            # raise IndexError here.
            yield random.choice(self.curr_data)
class MultiSourceDatasetV2(IterableDataset):
    """Infinite, buffered stream of samples drawn from weighted source groups.

    Samples are produced in rounds. In each round ``num_workers`` threads each
    run ``buffer_size_per_worker`` draws. A draw picks a source group with
    probability proportional to ``probs``, loads one random file of that
    group, and takes ``mini_batch`` distinct records from it (fewer if the
    file is shorter). All samples of a round are yielded before the next
    round is built.

    With a tokenizer, every sample is an ``(x, y)`` pair of int64 tensors of
    length ``block_size``, where ``y`` is ``x`` shifted by one token. Without
    a tokenizer, samples are the raw JSON-lines records.
    """

    def __init__(self, source_files, probs, mini_batch=4, buffer_size_per_worker=12000,
                 num_workers=10, tokenizer_path="tokenizer.model", block_size=1024,
                 eos_id=50303):
        """
        Args:
            source_files: One list of JSON-lines paths per source, e.g.
                ``[["file1.jsonl", "file2.jsonl"], ["file3.jsonl"]]``.
                When a tokenizer is used, each record needs a ``'text'`` field.
            probs: Sampling weight of each source group, e.g. ``[0.3, 0.7]``.
                Needs one entry per group. Integer weights are accepted.
            mini_batch: Maximum number of distinct records taken from each
                loaded file. Must be at least 1.
            buffer_size_per_worker: Draws each worker thread performs per round.
            num_workers: Number of threads that fill the buffer each round.
            tokenizer_path: Path passed to ``load_tokenizer``. A falsy value
                disables tokenization, and raw records are yielded instead.
            block_size: Length of each ``x`` / ``y`` tensor. The default
                matches ``_transform_ids``.
            eos_id: Id appended after each document and used as padding. The
                default matches ``_transform_ids``; it is presumably the
                tokenizer's end-of-text id — confirm.

        Raises:
            ValueError: If ``probs`` and ``source_files`` differ in length, or
                if ``mini_batch`` is smaller than 1.
        """
        super().__init__()
        if len(probs) != len(source_files):
            raise ValueError(
                f"probs has {len(probs)} entries but there are "
                f"{len(source_files)} source groups"
            )
        if mini_batch < 1:
            raise ValueError(f"mini_batch must be >= 1, got {mini_batch}")
        self.sources = source_files
        # torch.multinomial rejects integer tensors, so store float weights.
        self.probs = torch.as_tensor(probs, dtype=torch.float)
        self.mini_batch = mini_batch
        self.buffer_size_per_worker = buffer_size_per_worker
        self.num_workers = num_workers
        self.block_size = block_size
        self.eos_id = eos_id
        # Samples of the current round; rebuilt by __iter__.
        self.buffer = []
        if tokenizer_path:
            self.tokenizer = load_tokenizer(tokenizer_path)
        else:
            self.tokenizer = None

    def _transform_ids(self, ids, block_size=1024, eos_id=50303):
        """Turn a token-id sequence into a fixed-length ``(x, y)`` training pair.

        ``eos_id`` is appended to ``ids``. A result longer than
        ``block_size + 1`` is cut to a random contiguous window of that
        length; a shorter one is right-padded with ``eos_id``.

        Returns:
            ``(x, y)``: two int64 tensors of length ``block_size``, with
            ``y[i] == x[i + 1]`` for all ``i < block_size - 1``.
        """
        window = block_size + 1
        # Build a new list so the caller's sequence is never mutated.
        ids = list(ids) + [eos_id]
        if len(ids) > window:
            start = random.randint(0, len(ids) - window)
            ids = ids[start:start + window]
        elif len(ids) < window:
            # NOTE(review): padded targets are real eos ids, not an ignore
            # index; confirm the training loss is meant to include them.
            ids += [eos_id] * (window - len(ids))
        ids = torch.tensor(ids, dtype=torch.int64)
        return ids[:-1], ids[1:]

    def _read_random_file(self):
        """Pick a source group by weight, then one of its files uniformly; return all records."""
        idx = torch.multinomial(self.probs, 1).item()
        filename = random.choice(self.sources[idx])
        with jsonlines.open(filename, 'r') as f:
            return list(f)

    def _get_buffer(self):
        """Run ``buffer_size_per_worker`` draws and return the resulting samples.

        Runs concurrently in worker threads started by ``__iter__``. It only
        fills and returns a local list and does not touch ``self.buffer``;
        ``__iter__`` merges the results on the calling thread.
        """
        print('\nGetting data buffer...')
        buffer = []
        for _ in tqdm(range(self.buffer_size_per_worker)):
            records = self._read_random_file()
            # Take at most len(records), so short or empty files don't make
            # random.sample raise ValueError.
            picked = random.sample(records, min(self.mini_batch, len(records)))
            if self.tokenizer is None:
                buffer.extend(picked)
            else:
                # Assumes tokenizer.encode returns a sequence of int ids — TODO confirm.
                buffer.extend(
                    self._transform_ids(
                        self.tokenizer.encode(item['text']),
                        block_size=self.block_size,
                        eos_id=self.eos_id,
                    )
                    for item in picked
                )
        return buffer

    def __iter__(self):
        """Yield samples forever, rebuilding the buffer with a thread pool each round.

        Raises:
            RuntimeError: If a whole round produced no samples (for example,
                every sampled file was empty). Without this check the loop
                would spin forever without yielding.
            Exception: Any error raised while a worker reads or tokenizes data
                is re-raised here.
        """
        while True:
            self.buffer = []
            with concurrent.futures.ThreadPoolExecutor(max_workers=self.num_workers) as executor:
                futures = [executor.submit(self._get_buffer) for _ in range(self.num_workers)]
                for future in futures:
                    # .result() re-raises worker exceptions instead of dropping them.
                    self.buffer.extend(future.result())
            if not self.buffer:
                raise RuntimeError(
                    "MultiSourceDatasetV2 produced no samples in a full round; "
                    "check that the source files contain records"
                )
            yield from self.buffer