Delete dataloader.py
dataloader.py +0 -49
dataloader.py
DELETED
@@ -1,49 +0,0 @@
-import os
-import numpy as np
-import torch
-
-script_dir = os.path.dirname(__file__)
-
-
-class DataLoaderLite:
-    """ A simple dataloader for FineWebEdu-10B dataset """
-
-    def __init__(self, B, T, process_rank, num_processes, split='train'):
-        super().__init__()
-        self.B, self.T = B, T
-        self.process_rank = process_rank
-        self.num_processes = num_processes
-        assert split in {'train', 'val'}
-
-        # get the shard filenames
-        data_root = os.path.join(script_dir, "../data/edu_fineweb10B")
-        shard_filenames = os.listdir(data_root)
-        shard_filenames = sorted([filename for filename in shard_filenames if split in filename])
-        self.shard_filepaths = [os.path.join(data_root, filename) for filename in shard_filenames]
-        assert len(self.shard_filepaths) > 0, f'no shards found for split {split}'
-        master_process = process_rank == 0
-        if master_process:
-            print(f'found {len(self.shard_filepaths)} shards for split {split}')
-        self.reset()
-
-    def load_tokens(self, filepath):
-        tokens = torch.tensor(np.load(filepath).astype(np.int32), dtype=torch.long)
-        return tokens
-
-    def reset(self):
-        # state, init at shard 0
-        self.curr_shard = 0
-        self.tokens = self.load_tokens(self.shard_filepaths[self.curr_shard])
-        self.curr_pos = self.B * self.T * self.process_rank
-
-    def next_batch(self):
-        B, T = self.B, self.T
-        batch = self.tokens[self.curr_pos : self.curr_pos + B*T + 1]
-        x_batch = batch[:-1].view(B, T)
-        y_batch = batch[1:].view(B, T)
-        self.curr_pos += B * T * self.num_processes
-        if self.curr_pos + (B * T + 1) > len(self.tokens):
-            self.curr_shard = (self.curr_shard + 1) % len(self.shard_filepaths)
-            self.tokens = self.load_tokens(self.shard_filepaths[self.curr_shard])
-            self.curr_pos = self.B * self.T * self.process_rank
-        return x_batch, y_batch
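For context, a minimal usage sketch of the class this commit removes, assuming it was driven from a training script. The batch shape (B=4, T=1024) and the single-process rank values are illustrative assumptions, not taken from this commit:

# Hypothetical sketch (not from this commit): driving DataLoaderLite.
# Assumes token shards exist under ../data/edu_fineweb10B relative to
# the module. Single-process case; under DDP, process_rank and
# num_processes would come from the distributed environment
# (e.g. RANK and WORLD_SIZE).
from dataloader import DataLoaderLite  # the module deleted by this commit

train_loader = DataLoaderLite(B=4, T=1024, process_rank=0,
                              num_processes=1, split='train')

for step in range(10):
    # x, y are (B, T) token tensors; y is x shifted by one token,
    # i.e. next-token-prediction targets
    x, y = train_loader.next_batch()
    # forward/backward pass on (x, y) would go here
    print(step, x.shape, y.shape)

Note that next_batch advances its position by B * T * num_processes, so each rank reads a disjoint interleaved slice of every shard, and all ranks roll over to the next shard together.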