import os

import numpy as np
import torch


# Data loader: streams (x, y) token batches from pre-tokenized .npy shards,
# striding across processes so each rank sees a disjoint slice of the data
class DataLoader_1:
    def __init__(self, B, T, process_rank, num_processes, split, master_process):
        self.B = B
        self.T = T
        self.process_rank = process_rank
        self.num_processes = num_processes
        assert split in {'train', 'val'}

        # collect the shard files for this split
        data_root = "data/edu_fineweb10B"
        shards = os.listdir(data_root)
        shards = [s for s in shards if split in s]
        shards = sorted(shards)
        shards = [os.path.join(data_root, s) for s in shards]
        self.shards = shards
        assert len(shards) > 0, f"no shards found for split {split}"
        if master_process:
            print(f"found {len(shards)} shards for split {split}")
        self.reset()

    def load_tokens(self, filename):
        # shards are stored as numpy arrays; widen to int32 first, then to a
        # long tensor, since embedding layers expect int64 indices
        npt = np.load(filename)
        npt = npt.astype(np.int32)
        ptt = torch.tensor(npt, dtype=torch.long)
        return ptt

    def reset(self):
        # state: init at shard 0, offset by rank so processes don't overlap
        self.current_shard = 0
        self.tokens = self.load_tokens(self.shards[self.current_shard])
        self.current_position = self.B * self.T * self.process_rank

    def next_batch(self):
        B, T = self.B, self.T
        # grab B*T+1 tokens: inputs are buf[:-1], targets are buf shifted by one
        buf = self.tokens[self.current_position : self.current_position + B * T + 1]
        x = buf[:-1].view(B, T)  # inputs
        y = buf[1:].view(B, T)   # targets (next-token labels)
        # advance by the total tokens consumed across all processes this step
        self.current_position += B * T * self.num_processes
        # if the next batch would run past this shard, move on to the next one
        if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens):
            self.current_shard = (self.current_shard + 1) % len(self.shards)
            self.tokens = self.load_tokens(self.shards[self.current_shard])
            self.current_position = B * T * self.process_rank
        return x, y
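

# --- Usage sketch (an illustration, not part of the loader itself) ---
# Assumptions: a single-process run (process_rank=0, num_processes=1), and
# that data/edu_fineweb10B contains .npy shards of token ids with "train" in
# their filenames. The B and T values below are illustrative, not prescribed.
if __name__ == "__main__":
    loader = DataLoader_1(B=4, T=1024, process_rank=0, num_processes=1,
                          split='train', master_process=True)
    x, y = loader.next_batch()
    # both are (B, T); y is x shifted left by one token
    print(x.shape, y.shape)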