# GPT-124m / model_core / dataloader.py
# abhinavv3 — commit 498886e: "minor changes"
import os
import numpy as np
import torch
#Data loader
class DataLoader_1:
    """Sequential, rank-interleaved loader over pre-tokenized shard files.

    Every file in ``data_root`` whose name contains ``split`` is a shard.
    Shards are visited in sorted filename order and cycled forever. Within the
    current shard, all ranks share a base offset; on each call to
    :meth:`next_batch`, rank ``process_rank`` reads the ``B*T + 1`` tokens
    starting at ``base + B*T*process_rank``. The ``num_processes`` ranks
    therefore read disjoint, adjacent windows, and then every rank advances
    ``base`` by ``B*T*num_processes``.

    Args:
        B: Number of rows per batch.
        T: Number of tokens per row.
        process_rank: Index of this process, expected in ``[0, num_processes)``.
        num_processes: Total number of processes sharing the data.
        split: ``'train'`` or ``'val'``; matched as a substring of filenames.
        master_process: If true, print how many shards were found.
        data_root: Directory holding the shard files. Defaults to the
            previously hard-coded ``"data/edu_fineweb10B"``.
    """

    def __init__(self, B, T, process_rank, num_processes, split, master_process,
                 *, data_root="data/edu_fineweb10B"):
        self.B = B
        self.T = T
        self.process_rank = process_rank
        self.num_processes = num_processes
        assert split in {'train', 'val'}
        # Substring match: any filename containing the split name is a shard.
        names = sorted(s for s in os.listdir(data_root) if split in s)
        self.shards = [os.path.join(data_root, s) for s in names]
        assert len(self.shards) > 0, f"no shards found for split {split}"
        if master_process:
            print(f"found {len(self.shards)} shards for split {split}")
        self.reset()

    def load_tokens(self, filename):
        """Load one shard with ``np.load`` and return it as an int64 tensor.

        ``next_batch`` slices the result as a flat token stream, so the
        stored array is expected to be 1-D.
        """
        npt = np.load(filename)
        # astype() returns a fresh, contiguous int64 array, which also covers
        # dtypes torch cannot wrap directly (e.g. uint16). from_numpy then
        # shares that buffer instead of copying it a second time.
        return torch.from_numpy(npt.astype(np.int64))

    def reset(self):
        """Rewind to the first shard, at this rank's starting offset."""
        self.current_shard = 0
        self.tokens = self.load_tokens(self.shards[self.current_shard])
        self.current_position = self.B * self.T * self.process_rank

    def next_batch(self):
        """Return this rank's next ``(x, y)`` pair, each viewed as ``(B, T)``.

        ``y`` holds the same tokens as ``x`` shifted left by one position.
        After reading, the position advances by one full stride
        (``B*T*num_processes``). If the next global step would not fit in the
        current shard, every rank moves to the next shard (wrapping around) on
        the same call and restarts at its own rank offset.
        """
        B, T = self.B, self.T
        span = B * T
        stride = span * self.num_processes
        start = self.current_position
        buf = self.tokens[start:start + span + 1]
        x = buf[:-1].view(B, T)
        y = buf[1:].view(B, T)
        self.current_position += stride
        # The next global step reads tokens up to base + stride + 1, where
        # base = current_position - span * process_rank is the same on every
        # rank. Testing that shared bound (rather than this rank's position
        # plus a full stride) keeps all ranks on the same shard, and lets
        # higher ranks use the tail windows that still fit.
        next_end = self.current_position + span * (self.num_processes - self.process_rank) + 1
        if next_end > len(self.tokens):
            self.current_shard = (self.current_shard + 1) % len(self.shards)
            self.tokens = self.load_tokens(self.shards[self.current_shard])
            self.current_position = span * self.process_rank
        return x, y