| import os |
| import numpy as np |
| import torch |
|
|
| script_dir = os.path.dirname(__file__) |
|
|
|
|
| class DataLoaderLite: |
| """ A simple dataloader for FineWebEdu-10B dataset """ |
|
|
| def __init__(self, B, T, process_rank, num_processes, split='train'): |
| super().__init__() |
| self.B, self.T = B, T |
| self.process_rank = process_rank |
| self.num_processes = num_processes |
| assert split in {'train', 'val'} |
| |
| |
| data_root = os.path.join(script_dir, "../data/edu_fineweb10B") |
| shard_filenames = os.listdir(data_root) |
| shard_filenames = sorted([filename for filename in shard_filenames if split in filename]) |
| self.shard_filepaths = [os.path.join(data_root, filename) for filename in shard_filenames] |
| assert len(self.shard_filepaths) > 0, f'no shards found for split {split}' |
| master_process = process_rank == 0 |
| if master_process: |
| print(f'found {len(self.shard_filepaths)} shards for split {split}') |
| self.reset() |
|
|
| def load_tokens(self, filepath): |
| tokens = torch.tensor(np.load(filepath).astype(np.int32), dtype=torch.long) |
| return tokens |
|
|
| def reset(self): |
| |
| self.curr_shard = 0 |
| self.tokens = self.load_tokens(self.shard_filepaths[self.curr_shard]) |
| self.curr_pos = self.B * self.T * self.process_rank |
|
|
| def next_batch(self): |
| B, T = self.B, self.T |
| batch = self.tokens[self.curr_pos : self.curr_pos + B*T + 1] |
| x_batch = batch[:-1].view(B, T) |
| y_batch = batch[1:].view(B, T) |
| self.curr_pos += B * T * self.num_processes |
| if self.curr_pos + (B * T + 1) > len(self.tokens): |
| self.curr_shard = (self.curr_shard + 1) % len(self.shard_filepaths) |
| self.tokens = self.load_tokens(self.shard_filepaths[self.curr_shard]) |
| self.curr_pos = self.B * self.T * self.process_rank |
| return x_batch, y_batch |