import os
import numpy as np
import torch

# Data loader
class DataLoader_1:
    def __init__(self, B, T, process_rank, num_processes, split, master_process):
        self.B = B
        self.T = T
        self.process_rank = process_rank
        self.num_processes = num_processes
        assert split in {'train', 'val'}

        # discover the shard files for this split
        data_root = "data/edu_fineweb10B"
        shards = os.listdir(data_root)
        shards = [s for s in shards if split in s]
        shards = sorted(shards)
        shards = [os.path.join(data_root, s) for s in shards]
        self.shards = shards
        assert len(shards) > 0, f"no shards found for split {split}"
        if master_process:
            print(f"found {len(shards)} shards for split {split}")
        self.reset()

    def load_tokens(self, filename):
        # shards are numpy arrays of token ids; cast to int32 first because
        # torch.tensor cannot convert some numpy dtypes (e.g. uint16) directly
        npt = np.load(filename)
        npt = npt.astype(np.int32)
        ptt = torch.tensor(npt, dtype=torch.long)
        return ptt

    def reset(self):
        # state: start at shard 0; each rank starts at its own B*T offset
        self.current_shard = 0
        self.tokens = self.load_tokens(self.shards[self.current_shard])
        self.current_position = self.B * self.T * self.process_rank

    def next_batch(self):
        B, T = self.B, self.T
        # take B*T+1 tokens: inputs are all but the last, targets all but the first
        buf = self.tokens[self.current_position:self.current_position + B * T + 1]
        x = buf[:-1].view(B, T)  # inputs
        y = buf[1:].view(B, T)   # next-token targets
        # all ranks stride through the shard together
        self.current_position += B * T * self.num_processes
        # if the next batch would run past the end of the shard, advance to the next one
        if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens):
            self.current_shard = (self.current_shard + 1) % len(self.shards)
            self.tokens = self.load_tokens(self.shards[self.current_shard])
            self.current_position = B * T * self.process_rank
        return x, y
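
# Example usage: a minimal sketch, not part of the original file. It assumes
# token shards already exist under data/edu_fineweb10B with "train"/"val" in
# their filenames, and a single-process run (rank 0 of 1).
if __name__ == "__main__":
    loader = DataLoader_1(B=4, T=32, process_rank=0, num_processes=1,
                          split='train', master_process=True)
    x, y = loader.next_batch()
    print(x.shape, y.shape)  # both torch.Size([4, 32]); y is x shifted by one token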